CNN识别mnist手写数字

mnist数据的下载、读取部分请参见:DNN识别mnist手写数字

为了使读取到的图片数据能输入CNN,需要为图片数据增加channel维度

train_x = np.expand_dims(train_x,axis=-1)
test_x = np.expand_dims(test_x,axis=-1)

查看增维后数据的维度

print(train_x.shape)
print(test_x.shape)

搭建CNN并训练

drop_rate = 0.01
model = keras.Sequential()
model.add(layers.Conv2D(64,(3,3),activation='relu',input_shape=(28,28,1)))
model.add(layers.MaxPooling2D())
model.add(layers.Flatten())
model.add(layers.Dense(200,activation='relu'))
model.add(layers.Dropout(drop_rate))
model.add(layers.Dense(10,activation='softmax'))
adam = keras.optimizers.Adam(lr=0.001)
model.compile(optimizer=adam,loss='sparse_categorical_crossentropy',metrics=['acc'])
model.fit(train_x,train_y,epochs=10,batch_size=512)

经过10轮训练后,CNN在训练集上的loss和准确率如下

CNN在测试集上的loss和准确率如下

model.evaluate(test_x,test_y)

原文地址:https://www.cnblogs.com/bill-h/p/13907572.html