TensorFlow2_200729系列---18、手写数字识别(层方式)

TensorFlow2_200729系列---18、手写数字识别(层方式)

一、总结

一句话总结:

之前是用张量(tensor)的方式手动实现,能体现细节和原理;现在改用层(layer)的方式搭建网络,更加简洁方便。
# Fully connected stack 784 -> 256 -> 128 -> 64 -> 32 -> 10, built layer by layer.
model = Sequential()
for units in (256, 128, 64, 32):
    # Hidden layers use ReLU, e.g. [b, 784] => [b, 256] for the first one.
    model.add(layers.Dense(units, activation=tf.nn.relu))
# Output layer has no activation (raw logits): [b, 32] => [b, 10], 330 = 32*10 + 10
model.add(layers.Dense(10))
model.build(input_shape=[None, 28*28])
model.summary()

使用模型:
logits = model(x)  # forward pass: x is [b, 784] (flattened images), logits is [b, 10]

1、tensorflow.keras 模块中常用的组件有 datasets、layers、optimizers、Sequential、metrics,导入方式如下:

from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics

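二、手写数字识别(层方式)

注:下面的代码实际加载的是 Fashion-MNIST(`datasets.fashion_mnist`,10 类服装图片)。它的数据格式与手写数字 MNIST 相同,所以测试准确率约为 0.89。把它换成 `datasets.mnist`,就是真正的手写数字识别。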

博客对应课程的视频位置:

import  os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from    tensorflow import keras
from    tensorflow.keras import datasets, layers, optimizers, Sequential, metrics

assert tf.__version__.startswith('2.')

# Preprocessing function applied to every (image, label) pair of a Dataset.
# Data normalization.
def preprocess(x, y):
    """Cast one sample to training dtypes.

    The image is cast to float32 and divided by 255 (assuming uint8 pixels in
    0..255, this maps them into [0, 1]); the label is cast to int32 so it can
    be compared with the int32 predictions later.

    Returns:
        A ``(image, label)`` tuple.
    """
    return (
        tf.cast(x, dtype=tf.float32) / 255.,
        tf.cast(y, dtype=tf.int32),
    )

# Load the dataset through Keras.
# NOTE: despite the article title, this is `fashion_mnist` (10 clothing
# classes), not the handwritten-digit `mnist`; the printed shapes are
# (60000, 28, 28) (60000,) for the training split.
(x, y), (x_test, y_test) = datasets.fashion_mnist.load_data()
print(x.shape, y.shape)

# Batch size: number of samples per training step.
batchsz = 128

# Training data.
# Creates a `Dataset` whose elements are slices of the given tensors.
db = tf.data.Dataset.from_tensor_slices((x, y))
# Normalize, shuffle with a 10000-sample buffer, then split into batches.
db = db.map(preprocess).shuffle(10000).batch(batchsz)

# Test data: evaluation does not need shuffling.
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.map(preprocess).batch(batchsz)

# Pull one batch to sanity-check the shapes.
db_iter = iter(db)
sample = next(db_iter)
print('batch:', sample[0].shape, sample[1].shape)
# batch: (128, 28, 28) (128,)

# Five Dense layers; the last has no activation, so it outputs raw logits.
model = Sequential([
    layers.Dense(256, activation=tf.nn.relu), # [b, 784] => [b, 256]
    layers.Dense(128, activation=tf.nn.relu), # [b, 256] => [b, 128]
    layers.Dense(64, activation=tf.nn.relu), # [b, 128] => [b, 64]
    layers.Dense(32, activation=tf.nn.relu), # [b, 64] => [b, 32]
    layers.Dense(10) # [b, 32] => [b, 10], 330 = 32*10 + 10
])
model.build(input_shape=[None, 28*28])
model.summary()
# Adam optimizer (adaptive per-parameter steps, not plain w = w - lr*grad).
# Use `learning_rate`: the old `lr` alias is deprecated and recent Keras
# releases no longer accept it.
optimizer = optimizers.Adam(learning_rate=1e-3)

def main(epochs=30):
    """Train the module-level `model` on `db` and report test accuracy.

    Uses the module-level `model`, `optimizer`, `db` and `db_test`. Every 100
    steps it prints the epoch, the step, the cross-entropy loss (the training
    objective) and an MSE diagnostic. After each epoch it prints the accuracy
    on `db_test`.

    Args:
        epochs: number of full passes over the training set (default 30,
            the value previously hard-coded).
    """
    for epoch in range(epochs):

        for step, (x, y) in enumerate(db):

            # x: [b, 28, 28] => [b, 784]
            # y: [b]
            x = tf.reshape(x, [-1, 28*28])
            # [b] => [b, 10]; not differentiated, so computed outside the tape.
            y_onehot = tf.one_hot(y, depth=10)

            with tf.GradientTape() as tape:
                # [b, 784] => [b, 10] raw (unnormalized) scores
                logits = model(x)
                # Per-sample cross-entropy [b], averaged to a scalar.
                # from_logits=True lets the loss apply softmax itself.
                loss_ce = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)
                loss_ce = tf.reduce_mean(loss_ce)

            # model.trainable_variables: all trainable weights and biases (w, b).
            grads = tape.gradient(loss_ce, model.trainable_variables)
            # Update the variables in place.
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            if step % 100 == 0:
                # MSE is logged for comparison only (it is not trained on).
                # Compare one-hot labels with probabilities, not raw logits:
                # against unbounded logits the MSE keeps growing with the
                # logits' magnitude even while accuracy improves.
                prob = tf.nn.softmax(logits, axis=1)
                loss_mse = tf.reduce_mean(tf.losses.MSE(y_onehot, prob))
                print(epoch, step, 'loss:', float(loss_ce), float(loss_mse))


        # test
        total_correct = 0
        total_num = 0
        for x, y in db_test:

            # x: [b, 28, 28] => [b, 784]
            # y: [b]
            x = tf.reshape(x, [-1, 28*28])
            # [b, 10]
            logits = model(x)
            # logits => prob, [b, 10]
            prob = tf.nn.softmax(logits, axis=1)
            # [b, 10] => [b], int64
            pred = tf.argmax(prob, axis=1)
            # Match the int32 labels produced by `preprocess`.
            pred = tf.cast(pred, dtype=tf.int32)
            # correct: [b], True where pred equals y
            correct = tf.equal(pred, y)
            correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))

            total_correct += int(correct)
            total_num += x.shape[0]

        acc = total_correct / total_num
        print(epoch, 'test acc:', acc)

# Start training only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()
(60000, 28, 28) (60000,)
batch: (128, 28, 28) (128,)
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                multiple                  200960    
_________________________________________________________________
dense_1 (Dense)              multiple                  32896     
_________________________________________________________________
dense_2 (Dense)              multiple                  8256      
_________________________________________________________________
dense_3 (Dense)              multiple                  2080      
_________________________________________________________________
dense_4 (Dense)              multiple                  330       
=================================================================
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________
0 0 loss: 2.2831406593322754 0.12128783017396927
0 100 loss: 0.6292909383773804 18.471372604370117
0 200 loss: 0.4110489785671234 18.012493133544922
0 300 loss: 0.4045095443725586 14.464678764343262
0 400 loss: 0.5457848310470581 18.89963722229004
0 test acc: 0.8464
1 0 loss: 0.3746733069419861 17.479198455810547
1 100 loss: 0.4358266592025757 22.785991668701172
1 200 loss: 0.34541454911231995 21.296964645385742
1 300 loss: 0.6067743301391602 15.614158630371094
1 400 loss: 0.3972930908203125 19.6492862701416
1 test acc: 0.8624
2 0 loss: 0.38962411880493164 21.97345733642578
2 100 loss: 0.32657697796821594 24.56244659423828
2 200 loss: 0.33351215720176697 21.94455909729004
2 300 loss: 0.33902692794799805 26.513694763183594
2 400 loss: 0.374667763710022 24.564464569091797
2 test acc: 0.87
3 0 loss: 0.3156976103782654 26.225505828857422
3 100 loss: 0.25029417872428894 31.27167320251465
3 200 loss: 0.2644597291946411 30.24835205078125
3 300 loss: 0.27939510345458984 28.048099517822266
3 400 loss: 0.28537383675575256 29.133506774902344
3 test acc: 0.8732
4 0 loss: 0.3248022496700287 28.466659545898438
4 100 loss: 0.2858565151691437 30.900304794311523
4 200 loss: 0.27638182044029236 29.498920440673828
4 300 loss: 0.230974480509758 30.112049102783203
4 400 loss: 0.2854992747306824 30.296920776367188
4 test acc: 0.8702
5 0 loss: 0.3455154001712799 40.34062194824219
5 100 loss: 0.1356138437986374 44.47637939453125
5 200 loss: 0.35171258449554443 36.933929443359375
5 300 loss: 0.1946554183959961 40.9914436340332
5 400 loss: 0.317558228969574 43.98896026611328
5 test acc: 0.8766
6 0 loss: 0.220485657453537 43.64817810058594
6 100 loss: 0.21728813648223877 39.432037353515625
6 200 loss: 0.1949092447757721 48.853843688964844
6 300 loss: 0.15513527393341064 44.78699493408203
6 400 loss: 0.21645382046699524 49.537960052490234
6 test acc: 0.8784
7 0 loss: 0.3043939471244812 43.625526428222656
7 100 loss: 0.3803304433822632 47.13135528564453
7 200 loss: 0.19133441150188446 45.429298400878906
7 300 loss: 0.1776151806116104 41.924537658691406
7 400 loss: 0.16863960027694702 42.79949951171875
7 test acc: 0.8738
8 0 loss: 0.1874811351299286 44.651611328125
8 100 loss: 0.18167194724082947 51.90435028076172
8 200 loss: 0.20349520444869995 49.169429779052734
8 300 loss: 0.2610611915588379 48.555458068847656
8 400 loss: 0.2457474023103714 55.176734924316406
8 test acc: 0.8802
9 0 loss: 0.2630460858345032 53.88508224487305
9 100 loss: 0.20558330416679382 63.04207992553711
9 200 loss: 0.18517211079597473 65.81611633300781
9 300 loss: 0.20496012270450592 55.272369384765625
9 400 loss: 0.22070546448230743 59.18791198730469
9 test acc: 0.8839
10 0 loss: 0.15226173400878906 50.21383285522461
10 100 loss: 0.10652273893356323 66.88746643066406
10 200 loss: 0.15289798378944397 69.40467071533203
10 300 loss: 0.24505120515823364 52.915138244628906
10 400 loss: 0.21931703388690948 62.56816101074219
10 test acc: 0.8879
11 0 loss: 0.28399163484573364 66.48797607421875
11 100 loss: 0.2144084870815277 68.24934387207031
11 200 loss: 0.2513824999332428 49.95161056518555
11 300 loss: 0.23070569336414337 53.46424102783203
11 400 loss: 0.1969212144613266 59.240577697753906
11 test acc: 0.8885
12 0 loss: 0.1924872249364853 55.86700439453125
12 100 loss: 0.21166521310806274 69.66253662109375
12 200 loss: 0.09095969796180725 76.57353210449219
12 300 loss: 0.15812699496746063 67.44322204589844
12 400 loss: 0.20802710950374603 84.34611511230469
12 test acc: 0.8818
13 0 loss: 0.22456292808055878 63.22731018066406
13 100 loss: 0.1939781904220581 75.43051147460938
13 200 loss: 0.3054753839969635 81.66931915283203
13 300 loss: 0.23840418457984924 63.1304931640625
13 400 loss: 0.2474619597196579 69.8863525390625
13 test acc: 0.8885
14 0 loss: 0.16988956928253174 85.44975280761719
14 100 loss: 0.2409886121749878 94.68711853027344
14 200 loss: 0.19825829565525055 75.6685791015625
14 300 loss: 0.26892679929733276 102.5340576171875
14 400 loss: 0.10896225273609161 90.99492645263672
14 test acc: 0.8929
15 0 loss: 0.20140430331230164 88.07635498046875
15 100 loss: 0.11349458247423172 94.00519561767578
15 200 loss: 0.1777578443288803 75.71000671386719
15 300 loss: 0.27039724588394165 74.60504150390625
15 400 loss: 0.2390979528427124 95.77617645263672
15 test acc: 0.8875
16 0 loss: 0.22253449261188507 74.66510009765625
16 100 loss: 0.15607573091983795 91.78387451171875
16 200 loss: 0.15405383706092834 109.38969421386719
16 300 loss: 0.10432792454957962 93.83760070800781
16 400 loss: 0.127157062292099 81.89163208007812
16 test acc: 0.8902
17 0 loss: 0.1748599112033844 74.17550659179688
17 100 loss: 0.21128180623054504 101.88328552246094
17 200 loss: 0.213323712348938 99.44528198242188
17 300 loss: 0.1905888170003891 92.34651947021484
17 400 loss: 0.08545292168855667 118.2466049194336
17 test acc: 0.892
18 0 loss: 0.13534505665302277 107.93522644042969
18 100 loss: 0.10933603346347809 120.73545837402344
18 200 loss: 0.21846728026866913 107.94190979003906
18 300 loss: 0.2655482292175293 107.08270263671875
18 400 loss: 0.23332582414150238 110.47785949707031
18 test acc: 0.892
19 0 loss: 0.16872575879096985 112.55984497070312
19 100 loss: 0.2029556930065155 105.87848663330078
19 200 loss: 0.13815325498580933 110.57797241210938
19 300 loss: 0.26082828640937805 106.12140655517578
19 400 loss: 0.15341421961784363 129.8838348388672
19 test acc: 0.8934
20 0 loss: 0.29162901639938354 111.24371337890625
20 100 loss: 0.23025716841220856 105.3729248046875
20 200 loss: 0.13770082592964172 119.57967376708984
20 300 loss: 0.24651116132736206 120.30937957763672
20 400 loss: 0.21254345774650574 107.91622161865234
20 test acc: 0.8917
21 0 loss: 0.09702333062887192 100.24187469482422
21 100 loss: 0.15910854935646057 129.3473358154297
21 200 loss: 0.0851014256477356 138.20169067382812
21 300 loss: 0.1595071405172348 118.91499328613281
21 400 loss: 0.2024853229522705 109.33180236816406
21 test acc: 0.8899
22 0 loss: 0.13461339473724365 140.96359252929688
22 100 loss: 0.18555812537670135 163.75796508789062
22 200 loss: 0.20990914106369019 129.61654663085938
22 300 loss: 0.13982388377189636 127.89125061035156
22 400 loss: 0.15919993817806244 130.1854248046875
22 test acc: 0.8945
23 0 loss: 0.1364736706018448 139.3319549560547
23 100 loss: 0.15799979865550995 164.3890380859375
23 200 loss: 0.13190290331840515 147.40843200683594
23 300 loss: 0.21002498269081116 128.39404296875
23 400 loss: 0.20846235752105713 150.4744873046875
23 test acc: 0.8977
24 0 loss: 0.180350199341774 148.44241333007812
24 100 loss: 0.13309326767921448 141.54718017578125
24 200 loss: 0.09000922739505768 144.95884704589844
24 300 loss: 0.09340814501047134 149.47250366210938
24 400 loss: 0.11350023001432419 135.44602966308594
24 test acc: 0.8948
25 0 loss: 0.10290056467056274 129.1949920654297
25 100 loss: 0.10859610140323639 147.93728637695312
25 200 loss: 0.15649116039276123 157.18661499023438
25 300 loss: 0.09786863625049591 163.36807250976562
25 400 loss: 0.13727730512619019 151.49111938476562
25 test acc: 0.8928
26 0 loss: 0.14575082063674927 152.08493041992188
26 100 loss: 0.08358915150165558 157.06666564941406
26 200 loss: 0.13103337585926056 141.79519653320312
26 300 loss: 0.1875842809677124 163.2027587890625
26 400 loss: 0.2265387624502182 177.1208038330078
26 test acc: 0.896
27 0 loss: 0.12106870114803314 160.7439422607422
27 100 loss: 0.11055881530046463 181.66207885742188
27 200 loss: 0.08392684161663055 155.105712890625
27 300 loss: 0.14919903874397278 153.12997436523438
27 400 loss: 0.11113433539867401 176.20297241210938
27 test acc: 0.8837
28 0 loss: 0.20585989952087402 201.09121704101562
28 100 loss: 0.1687045842409134 180.786865234375
28 200 loss: 0.13319478929042816 169.2735595703125
28 300 loss: 0.08168964087963104 159.86993408203125
28 400 loss: 0.11371222138404846 196.85433959960938
28 test acc: 0.8918
29 0 loss: 0.120663121342659 209.4883575439453
29 100 loss: 0.12449634820222855 171.43267822265625
29 200 loss: 0.18858373165130615 171.96971130371094
29 300 loss: 0.07583779096603394 193.2484130859375
29 400 loss: 0.10753950476646423 199.6551513671875
29 test acc: 0.8944
 
我的编程视频学习网站 fanrenyi.com 完全免费,偏公益性质。它旨在让学过的东西不再忘记,主要使用艾宾浩斯遗忘曲线算法及其它智能学习复习算法,有各种前端、后端、算法、大数据、人工智能等课程。
博主25岁,前端后端算法大数据人工智能都有兴趣。
大家有啥都可以加博主联系方式(qq404006308,微信fan404006308)互相交流。工作、生活、心境,可以互相启迪。
聊技术,交朋友,修心境,qq404006308,微信fan404006308
26岁,真心找女朋友,非诚勿扰,微信fan404006308,qq404006308
人工智能群:939687837

作者相关推荐

原文地址:https://www.cnblogs.com/Renyi-Fan/p/13439979.html