Handwritten digit recognition has long been the "Hello World" of getting started with AI. As the saying goes, a sparrow may be small, but it has all the vital organs!
Environment: Python 2.7, TensorFlow 1.12.0, macOS Mojave 10.14.2
Preparing the Dataset
This experiment uses the simplest of datasets: MNIST. As loaded by the script below, it contains 55,000 training images, 5,000 validation images, and 10,000 test images. Before using it, it is worth understanding its format. MNIST images are not ordinary 3-channel color images: each image is a flattened 1-D vector of 784 values (28×28), with pixel intensities normalized to the range [0, 1]. Let's print one image to see for ourselves (the dataset is read with one_hot=True):
```
[0.         0.         0.         0.         0.         0.
 ...
 0.         0.         0.16470589 0.6431373  0.9960785  0.4784314
 ...
 0.8352942  0.9921569  0.9921569  0.98823535 0.854902   0.36862746
 ...
 0.23137257 0.9921569  0.30588236 0.         0.         0.
 ...
 0.         0.         0.         0.        ]
```
(Output truncated here: the full printout is a single vector of 784 floating-point values, most of them 0.)
Each image comes with a corresponding label. Let's also print one label to inspect its structure:
```
[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
```
If the dataset is read without one_hot mode, each label is just a plain decimal digit. With one_hot=True, each label is parsed into a 1-D vector of shape (10,) whose elements are all 0 except for a single 1: the index of that 1 is the digit the vector represents. In the label above, the 1 sits at index 5, so this label stands for the digit 5. This encoding makes the probability computations later on convenient.
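As a quick illustration (a numpy sketch, separate from the loader itself), here is the round trip between an integer label and its one-hot vector:

```python
import numpy as np

label = 5
one_hot = np.eye(10)[label]          # -> [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
recovered = int(np.argmax(one_hot))  # -> 5, the index of the 1
print(one_hot, recovered)
```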
At this point we have a working understanding of the MNIST dataset.
Reading the Data
There are several ways to read the data; here we will use an open-source loader script.
input_data.py
```python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
import gzip
import os
import tempfile
import numpy
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
# pylint: enable=unused-import
```
Calling this script produces a few deprecation warnings: TensorFlow 1.12.0 marks some of these APIs for removal in a future version. They do not affect functionality, so don't worry about them.
The MNIST dataset can then be read with the following script:
test_data.py
```python
import input_data

mnist = input_data.read_data_sets('data', one_hot=True)
batch = mnist.train.next_batch(1)  # fetch a single (images, labels) batch
print(batch[0][0])  # the image data of the first sample
print(batch[1][0])  # its matching label (a second next_batch() call would return a different sample)
```
Inside mainland China the automatic download may fail, in which case you can fetch these four files manually and place them in the data directory:
- t10k-images-idx3-ubyte.gz
- t10k-labels-idx1-ubyte.gz
- train-images-idx3-ubyte.gz
- train-labels-idx1-ubyte.gz
Baidu Cloud download (leave a comment if the link is dead). Password: g9j0
Building the Neural Network
The network has four layers; the sketch after this list traces the feature-map sizes.

- Layer 1
  - Convolution: 32 kernels of size 5×5×1, stride 1, SAME padding
  - Activation: ReLU for the non-linearity
  - Pooling: 2×2 max pooling, stride 2, SAME padding
- Layer 2
  - Convolution: 64 kernels of size 5×5×32, stride 1, SAME padding
  - Activation: ReLU for the non-linearity
  - Pooling: 2×2 max pooling, stride 2, SAME padding
- Layer 3
  - Fully connected: 1024 neurons (a size chosen from experience)
  - Activation: ReLU for the non-linearity
  - Overfitting: handled with dropout
- Layer 4
  - Fully connected: maps the 1024 neurons down to 10
  - Softmax: turns the 10 outputs into probabilities, of which the highest-confidence one is the prediction
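Before the code, here is a minimal sketch of how the feature-map sizes evolve through these layers. With SAME padding, the output spatial size depends only on the stride:

```python
import math

def same_out(size, stride):
    # With SAME padding, output spatial size = ceil(input / stride).
    return int(math.ceil(float(size) / stride))

size = 28
for layer in (1, 2):
    size = same_out(size, 1)  # 5x5 convolution, stride 1: size unchanged
    size = same_out(size, 2)  # 2x2 max pool, stride 2: size halved
print(size)                # 7
print(size * size * 64)    # 3136 = 7 * 7 * 64, the width after flattening
```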
The actual code:
model.py
```python
import tensorflow as tf


def get_weights(shape):
    # tf.get_variable ties the variable to the enclosing variable_scope
    # and uses the default initializer.
    w_variable = tf.get_variable("w_variable", shape=shape, trainable=True)
    return w_variable
    # initial = tf.truncated_normal(shape, stddev=0.1)
    # return tf.Variable(initial)


def get_bias(shape):
    b_variable = tf.get_variable("b_variable", shape=shape, trainable=True)
    return b_variable


def conv2d_relu(input, filter, bias):
    features = tf.nn.conv2d(input=input,
                            filter=filter,
                            strides=[1, 1, 1, 1],
                            padding='SAME')
    return tf.nn.relu(features + bias)


def max_pool(input):
    return tf.nn.max_pool(value=input,
                          ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1],
                          padding='SAME')


def forward(input_data, keep_prob):
    # Un-flatten the 784-value vectors into 28x28 single-channel images.
    input_data = tf.reshape(input_data, [-1, 28, 28, 1])
    with tf.variable_scope("layer_1"):
        w1 = get_weights([5, 5, 1, 32])
        b1 = get_bias([32])
        conv1 = conv2d_relu(input_data, w1, b1)
        pool1 = max_pool(conv1)
    with tf.variable_scope("layer_2"):
        w2 = get_weights([5, 5, 32, 64])
        b2 = get_bias([64])
        conv2 = conv2d_relu(pool1, w2, b2)
        pool2 = max_pool(conv2)
    with tf.variable_scope("full_connection_1"):
        # Flatten the 7x7x64 feature maps for the fully connected layer.
        pool2_flat = tf.reshape(pool2, shape=[-1, 7 * 7 * 64])
        w3 = get_weights([7 * 7 * 64, 1024])
        b3 = get_bias([1024])
        fc1 = tf.nn.relu(tf.matmul(pool2_flat, w3) + b3)
        fc1_drop = tf.nn.dropout(fc1, keep_prob)
    with tf.variable_scope("full_connection_2"):
        w4 = get_weights([1024, 10])
        b4 = get_bias([10])
        fc2 = tf.matmul(fc1_drop, w4) + b4
        softmax = tf.nn.softmax(fc2)
    return softmax
```
Note that MNIST image data arrives in a flattened 1-D format, so we first reshape it into [-1, 28, 28, 1] grayscale images. When the forward pass reaches the fully connected layer, we reshape the data again to [-1, 7 * 7 * 64], i.e. we flatten it (7×7 is the spatial size of pool2: the two stride-2 poolings halve 28 to 14 and then to 7, as traced in the size sketch after the layer list above).
Training the Neural Network
Training the model means letting the optimization algorithm adjust the network's trainable parameters until they fit the data. We compute the loss with cross-entropy and optimize the network with the Adam optimizer (learning rate 1e-4). Finally, the trained model is saved to save_model/model.ckpt.
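For reference, the cross-entropy between a one-hot label vector $y$ and the network's predicted distribution $\hat{y}$ is

$$
L(y, \hat{y}) = -\sum_{i=0}^{9} y_i \log \hat{y}_i
$$

Since $y$ is one-hot, this reduces to $-\log \hat{y}_k$ for the true class $k$; the code below additionally sums the loss over the whole batch.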
The actual code:
train.py
```python
import tensorflow as tf
import model
import input_data


def get_loss(input_labels, predict_labels):
    # Cross-entropy, summed over the whole batch.
    cross_entropy = -tf.reduce_sum(input_labels * tf.log(predict_labels))
    return cross_entropy

# prepare mnist data
mnist = input_data.read_data_sets('data', one_hot=True)
# hyperparameters
learning_rate = 1e-4
# placeholders (named input_images so the input_data module above is not shadowed)
input_images = tf.placeholder(tf.float32, [None, 784])
input_labels = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)
forward_result = model.forward(input_images, keep_prob)
loss = get_loss(input_labels, forward_result)
correct_prediction = tf.equal(tf.argmax(forward_result, 1), tf.argmax(input_labels, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# backward
train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss)
# save model
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(10000):
        batch_data, batch_labels = mnist.train.next_batch(10)
        _loss, _train_step, _accuracy = sess.run(
            [loss, train_step, accuracy],
            feed_dict={input_images: batch_data,
                       input_labels: batch_labels,
                       keep_prob: 0.5})
        print('step = {}, loss = {}, accuracy = {:.2f}'.format(i, _loss, _accuracy))
    saver.save(sess, 'save_model/model.ckpt')
    print('test accuracy {:.4f}'.format(
        accuracy.eval(feed_dict={input_images: mnist.test.images,
                                 input_labels: mnist.test.labels,
                                 keep_prob: 1.0})))
```
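One caveat: get_loss calls tf.log directly on the predicted probabilities, which yields NaN as soon as any probability reaches exactly 0. A common safeguard (a sketch, not the code that produced the numbers below) is to clip the probabilities first:

```python
def get_loss(input_labels, predict_labels):
    # Clip probabilities into [1e-10, 1.0] so tf.log never sees 0.
    clipped = tf.clip_by_value(predict_labels, 1e-10, 1.0)
    return -tf.reduce_sum(input_labels * tf.log(clipped))
```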
Training Results
Reference numbers:
| keep_prob | iterations | batch_size | test-set accuracy |
|---|---|---|---|
| 1.0 | 1000 | 50 | 0.9704 |
| 0.5 | 2000 | 20 | 0.9736 |
| 0.5 | 5000 | 10 | 0.9824 |
| 0.5 | 10000 | 10 | 0.9872 |
Tuning these hyperparameters can raise the accuracy further.
Predicting Images

Now we use the model we just trained and saved to predict the digit in an image. Lacking a real photo at hand, we test with an image from the test set.

The code:
predict.py
```python
import tensorflow as tf
import model
import input_data
import numpy as np
import matplotlib.pyplot as plt


def show_image(mnist_image):
    # Un-flatten the 784-value vector into a 28x28 image for display.
    image = np.reshape(mnist_image, newshape=[28, -1])
    plt.imshow(image, cmap=plt.get_cmap('gray_r'))
    plt.show()

mnist = input_data.read_data_sets('data', one_hot=True)
# named input_images so the input_data module above is not shadowed
input_images = tf.placeholder(tf.float32, [None, 784])
keep_prob = tf.placeholder(tf.float32)
forward_result = model.forward(input_images, keep_prob)
predict_result = tf.argmax(forward_result, 1)
saver = tf.train.Saver()
index = 1  # index of the test-set image to predict
with tf.Session() as sess:
    image = mnist.test.images[index]
    show_image(image)
    sess.run(tf.global_variables_initializer())  # redundant here: restore() overwrites all variables
    saver.restore(sess, "save_model/model.ckpt")
    _predict_result = sess.run([predict_result],
                               feed_dict={input_images: [image], keep_prob: 1.0})
    _correct_result = np.argmax(mnist.test.labels[index])
    print('predict={}, correct={}'.format(_predict_result[0][0], _correct_result))
```
Note that a real handwritten digit image will score noticeably lower on this network: it must first be preprocessed into a 28×28×1 grayscale image normalized the same way as MNIST, and the handwriting styles in MNIST (collected abroad) differ from local writing habits.
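If you do want to try an image of your own handwriting, it has to be converted into the format the network expects. A minimal sketch using Pillow (the file name my_digit.png is hypothetical, and it assumes dark ink on a light background):

```python
import numpy as np
from PIL import Image

# Load, convert to grayscale, and scale down to 28x28 pixels.
img = Image.open('my_digit.png').convert('L').resize((28, 28))
pixels = np.asarray(img, dtype=np.float32) / 255.0  # normalize to [0, 1]
pixels = 1.0 - pixels          # invert: MNIST digits are white on black
flat = pixels.reshape(1, 784)  # the shape the input placeholder expects
```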
Prediction Result

You can change the test-set image index `index` to predict other images in the test set.
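For instance, to check several test images in one run (a sketch that reuses mnist, sess, predict_result, input_images, and keep_prob from predict.py inside the same session block):

```python
images = mnist.test.images[:10]
predictions = sess.run(predict_result,
                       feed_dict={input_images: images, keep_prob: 1.0})
truths = np.argmax(mnist.test.labels[:10], axis=1)
for p, t in zip(predictions, truths):
    print('predict={}, correct={}'.format(p, t))
```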
Summary
A few of the functions used in this article, such as dropout, softmax, max_pool, and conv2d, have not been explained yet; they will be covered in more detail in upcoming articles.