TensorFlow Handwritten Digit Recognition

By 迦娜王, 3 years ago, in Tensorflow

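Handwritten digit recognition has long been the "Hello World" of getting started with AI. As the saying goes, a sparrow may be small, but it has all its vital organs: this small project touches every essential step.
Environment: Python 2.7, TensorFlow 1.12.0, macOS Mojave 10.14.2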

Preparing the dataset

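This experiment uses MNIST, the simplest of the classic datasets. It contains 60,000 training images and 10,000 test images; by default, read_data_sets holds out 5,000 of the training images as a validation set, so mnist.train contains 55,000. Before using the dataset, it is worth understanding its format. MNIST images are not ordinary 3-channel color images: each one is a single-channel 28×28 grayscale image flattened into a 1D vector of 784 values, with pixel intensities scaled to floats between 0 and 1 (they are grayscale, not binarized, as the fractional values below show). Let's print one image to take a look (the dataset is loaded with one_hot=True, which only affects how the labels are encoded):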

[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.16470589 0.6431373  0.9960785  0.4784314
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.8352942  0.9921569  0.9921569  0.98823535 0.854902   0.36862746
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.5647059  0.9843138  0.9921569
 0.9921569  0.9921569  0.9921569  0.9843138  0.5882353  0.77647066
 0.56078434 0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.16078432 0.9215687  0.9921569  0.654902   0.58431375 0.8980393
 0.9921569  0.9921569  0.9921569  0.9921569  0.86666673 0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.38823533 0.9921569
 0.49803925 0.01960784 0.         0.07843138 0.5686275  0.9921569
 0.9921569  0.9921569  0.86666673 0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.38823533 0.9921569  0.14901961 0.
 0.         0.         0.03137255 0.7490196  0.9921569  0.9921569
 0.86666673 0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.38823533 0.9921569  0.69411767 0.3372549  0.2392157  0.25882354
 0.3372549  0.8078432  0.9921569  0.9921569  0.43137258 0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.38823533 0.9921569
 0.9921569  0.9921569  0.9450981  0.95294124 0.9921569  0.9921569
 0.9921569  0.86666673 0.0627451  0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.38823533 0.9921569  0.9921569  0.9921569
 0.9921569  0.9921569  0.9921569  0.9921569  0.94117653 0.34117648
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.09019608 0.227451   0.65882355 0.7372549  0.7372549  0.7372549
 0.7960785  0.9921569  0.68235296 0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.23137257 0.9921569
 0.30588236 0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.23137257 0.9921569  0.30588236 0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.23137257 0.9921569  0.30588236 0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.23137257 0.9921569
 0.30588236 0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.23137257 0.9921569  0.30588236 0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.23137257 0.9921569  0.61960787 0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.21568629 0.9725491
 0.8117648  0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.01568628 0.7294118  0.91372555 0.18823531
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.227451   0.9843138  0.98823535 0.32156864 0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.7176471
 0.8117648  0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.        ]
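To see the digit hidden in those 784 numbers, the vector can be reshaped back to 28×28; the pixels are stored row by row, so pixel k sits at row k // 28 and column k % 28. The sketch below assumes NumPy and draws any pixel brighter than a threshold as "#". The helper name to_ascii, the 0.5 threshold, and the synthetic vertical stroke used as demo input are illustrative choices, not part of the original code.

```python
import numpy as np


def to_ascii(flat_image, threshold=0.5):
    """Render a flattened 784-value MNIST image as 28 lines of ASCII art (hypothetical helper)."""
    image = np.reshape(flat_image, (28, 28))  # row-major: pixel k -> row k // 28, column k % 28
    lines = []
    for row in image:
        # '#' marks ink (bright pixel), '.' marks the dark background
        lines.append("".join("#" if value > threshold else "." for value in row))
    return "\n".join(lines)


# Synthetic demo input (not real MNIST data): a vertical stroke in column 14, rows 5..22
flat = np.zeros(784, dtype=np.float32)
flat[[row * 28 + 14 for row in range(5, 23)]] = 0.99
art = to_ascii(flat)
print(art)
```

With the real data loaded as in test_data.py, print(to_ascii(mnist.train.images[0])) would draw a training digit in the terminal without needing matplotlib.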

Each image has a corresponding label. Let's print one label as well and look at its structure:

[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]

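If the dataset is not loaded in one_hot mode, each label is a plain integer digit. In one_hot mode, each digit is encoded as a 1D vector of shape (10,) whose elements are only 0 and 1; put simply, the index of the 1 is the digit the vector represents. In the label above the 1 sits at index 5, so the label is 5. This encoding makes the later probability calculations convenient.
By now we have a basic understanding of the MNIST dataset.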
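The conversion between a plain digit label and its one-hot vector is easy to reproduce with NumPy. This is a minimal sketch; to_one_hot and from_one_hot are hypothetical helper names, and read_data_sets(..., one_hot=True) already performs the equivalent encoding for you.

```python
import numpy as np


def to_one_hot(labels, num_classes=10):
    """Encode integer digit labels as one-hot rows (hypothetical helper)."""
    labels = np.asarray(labels, dtype=np.int64)
    one_hot = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    one_hot[np.arange(labels.shape[0]), labels] = 1.0  # put a 1 at each label's index
    return one_hot


def from_one_hot(one_hot):
    """Recover the digits: the index of the 1 in each row."""
    return np.argmax(one_hot, axis=1)


encoded = to_one_hot([5, 0, 9])
decoded = from_one_hot(encoded)
print(encoded[0])  # same layout as the label printed above
print(decoded)
```

The same argmax step is what train.py and predict.py use to turn both the one-hot labels and the network's output back into digits.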

Reading the data

There are several ways to read the data; here we use an open-source reader script from TensorFlow.
input_data.py

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import
import gzip
import os
import tempfile

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
# pylint: enable=unused-import

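Importing this script prints some deprecation warnings. They appear because parts of the tensorflow.contrib.learn MNIST API used here are deprecated as of TensorFlow 1.12.0. Functionality is not affected, so there is no need to worry.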

The following script reads the MNIST dataset:
test_data.py

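import input_data

# Download (if necessary) and load MNIST from the 'data' directory; one_hot=True encodes labels as 10-element vectors
mnist = input_data.read_data_sets('data', one_hot=True)
# Take ONE batch so that the printed image and label belong to the same (shuffled) training sample
batch_images, batch_labels = mnist.train.next_batch(1)
print(batch_images[0])  # pixel data of the sample
print(batch_labels[0])  # one-hot label of the sample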

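Inside mainland China the dataset may not download directly, so you can download these archives manually and put them in the data directory: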

  • t10k-images-idx3-ubyte.gz
  • t10k-labels-idx1-ubyte.gz
  • train-images-idx3-ubyte.gz
  • train-labels-idx1-ubyte.gz

Baidu Netdisk download (leave a comment if the link has expired), password: g9j0
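If you download the four .gz files by hand, it can help to know what is inside them. Each is a gzip-compressed IDX file: a big-endian header (magic number 2051 for images, followed by the image count, rows and columns; magic number 2049 for labels, followed by the label count) and then raw unsigned bytes. The sketch below reads them with only gzip, struct and NumPy and, like read_data_sets, scales pixels to [0, 1] by dividing by 255. The function names and the two tiny synthetic files written for the demo are assumptions made for illustration, not the reader script's own code.

```python
import gzip
import os
import struct
import tempfile

import numpy as np


def read_idx_images(path):
    """Read a gzipped IDX3 image file into float32 rows of rows*cols pixels (hypothetical helper)."""
    with gzip.open(path, "rb") as f:
        magic, count, rows, cols = struct.unpack(">IIII", f.read(16))  # big-endian header
        if magic != 2051:
            raise ValueError("not an IDX image file (magic=%d)" % magic)
        pixels = np.frombuffer(f.read(count * rows * cols), dtype=np.uint8)
    # One flattened image per row, scaled from 0..255 to 0..1
    return pixels.reshape(count, rows * cols).astype(np.float32) / 255.0


def read_idx_labels(path):
    """Read a gzipped IDX1 label file into a uint8 array of digits (hypothetical helper)."""
    with gzip.open(path, "rb") as f:
        magic, count = struct.unpack(">II", f.read(8))
        if magic != 2049:
            raise ValueError("not an IDX label file (magic=%d)" % magic)
        return np.frombuffer(f.read(count), dtype=np.uint8).copy()


# Demo on two tiny synthetic files (NOT real MNIST data) written to a temporary directory
tmp_dir = tempfile.mkdtemp()
image_path = os.path.join(tmp_dir, "demo-images-idx3-ubyte.gz")
label_path = os.path.join(tmp_dir, "demo-labels-idx1-ubyte.gz")
fake_pixels = (np.arange(2 * 28 * 28) % 256).astype(np.uint8)
with gzip.open(image_path, "wb") as f:
    f.write(struct.pack(">IIII", 2051, 2, 28, 28))
    f.write(fake_pixels.tobytes())
with gzip.open(label_path, "wb") as f:
    f.write(struct.pack(">II", 2049, 2))
    f.write(bytes(bytearray([5, 0])))

images = read_idx_images(image_path)
labels = read_idx_labels(label_path)
print(images.shape, labels)
```

Pointed at data/train-images-idx3-ubyte.gz, the same function should return all 60,000 training images as a (60000, 784) array, before any validation split.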

Building the neural network

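  1. Layer 1
    Convolution: 32 kernels of 5x5x1, stride 1, padding SAME
    Activation: ReLU for non-linearity
    Pooling: 2×2 max pooling, stride 2, padding SAME
  2. Layer 2
    Convolution: 64 kernels of 5x5x32, stride 1, padding SAME
    Activation: ReLU for non-linearity
    Pooling: 2×2 max pooling, stride 2, padding SAME
  3. Layer 3
    Fully connected: 1024 neurons (a size chosen empirically)
    Activation: ReLU for non-linearity
    Overfitting: dropout
  4. Layer 4
    Fully connected: maps the 1024 neurons to 10 outputs
    Softmax: turns the 10 outputs into probabilities; the digit with the highest probability is the prediction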
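Softmax itself does not pick a class: it turns the 10 raw scores (logits) into probabilities that sum to 1, and argmax then selects the most confident digit. A small NumPy sketch with made-up logits; subtracting the maximum before exp is a standard numerical-stability trick and does not change the result.

```python
import numpy as np


def softmax(logits):
    """Convert a vector of raw scores into probabilities that sum to 1."""
    shifted = logits - np.max(logits)  # stability: keeps exp from overflowing
    exps = np.exp(shifted)
    return exps / np.sum(exps)


# Made-up scores for digits 0..9 (not produced by the trained model)
logits = np.array([1.0, 0.5, 0.2, 0.1, 0.3, 4.0, 0.0, 0.2, 1.5, 0.4])
probs = softmax(logits)
predicted_digit = int(np.argmax(probs))
print(predicted_digit, round(float(probs[predicted_digit]), 3))
```

Because softmax is monotonic, argmax over the probabilities always equals argmax over the raw logits; softmax matters for the loss, not for choosing the predicted digit.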

Implementation:
model.py

import tensorflow as tf

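def get_weights(shape):
    # tf.get_variable registers the variable under the enclosing tf.variable_scope.
    # With no initializer given, TF 1.x falls back to glorot_uniform_initializer.
    # (An alternative is tf.Variable(tf.truncated_normal(shape, stddev=0.1)).)
    w_variable = tf.get_variable("w_variable", shape=shape, trainable=True)
    return w_variable

def get_bias(shape):
    # Biases are created the same way, with the same default initializer
    b_variable = tf.get_variable("b_variable", shape=shape, trainable=True)
    return b_variable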

def conv2d_relu(input, filter, bias):
    features = tf.nn.conv2d(input=input,
                            filter=filter,
                            strides=[1, 1, 1, 1],
                            padding='SAME')
    return tf.nn.relu(features=(features + bias))

def max_pool(input):
    return tf.nn.max_pool(value=input,
                          ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1],
                          padding='SAME')

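def forward(input_data, keep_prob):
    # MNIST images arrive as flat 784-vectors; reshape to NHWC 28x28 single-channel images
    input_data = tf.reshape(input_data, [-1, 28, 28, 1])
    with tf.variable_scope("layer_1"):
        # 32 filters of 5x5x1 -> [-1, 28, 28, 32]; 2x2 max pool -> [-1, 14, 14, 32]
        w1 = get_weights([5, 5, 1, 32])
        b1 = get_bias([32])
        conv1 = conv2d_relu(input_data, w1, b1)
        pool1 = max_pool(conv1)

    with tf.variable_scope("layer_2"):
        # 64 filters of 5x5x32 -> [-1, 14, 14, 64]; 2x2 max pool -> [-1, 7, 7, 64]
        w2 = get_weights([5, 5, 32, 64])
        b2 = get_bias([64])
        conv2 = conv2d_relu(pool1, w2, b2)
        pool2 = max_pool(conv2)

    with tf.variable_scope("full_connection_1"):
        # Flatten the 7x7x64 feature maps, then a 1024-unit ReLU layer followed by dropout
        pool2_flat = tf.reshape(pool2, shape=[-1, 7 * 7 * 64])
        w3 = get_weights([7 * 7 * 64, 1024])
        b3 = get_bias([1024])
        fc1 = tf.nn.relu(tf.matmul(pool2_flat, w3) + b3)
        fc1_drop = tf.nn.dropout(fc1, keep_prob)

    with tf.variable_scope("full_connection_2"):
        # Map the 1024 features to 10 class scores, then softmax turns them into probabilities
        w4 = get_weights([1024, 10])
        b4 = get_bias([10])
        fc2 = tf.matmul(fc1_drop, w4) + b4
        softmax = tf.nn.softmax(fc2)
    return softmax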

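Note: MNIST images are stored as 1D vectors, so we first reshape them into [-1, 28, 28, 1] grayscale images (-1 lets TensorFlow infer the batch size). When the forward pass reaches the first fully connected layer, we reshape again to [-1, 7 * 7 * 64], i.e. we flatten the feature maps. Here 7×7 is the spatial size of pool2: with SAME padding, the stride-1 convolutions keep the size unchanged and each stride-2 max pool halves it, so 28 → 14 → 7.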
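The 7×7 figure follows from the padding rule. With padding='SAME', TensorFlow produces an output of size ceil(input / stride) in each spatial dimension. The sketch below traces the shapes through the two convolution and pooling layers of model.py; same_output_size is a hypothetical helper written for this illustration.

```python
import math


def same_output_size(size, stride):
    """Spatial output size for padding='SAME' in TensorFlow: ceil(size / stride)."""
    return int(math.ceil(float(size) / stride))


size = 28
shapes = []
# (layer name, stride, output channels) as configured in model.py
for name, stride, channels in [("conv1", 1, 32), ("pool1", 2, 32),
                               ("conv2", 1, 64), ("pool2", 2, 64)]:
    size = same_output_size(size, stride)
    shapes.append((name, size, size, channels))
    print(name, (size, size, channels))

# Length of each flattened feature vector fed to the first fully connected layer
flat_size = shapes[-1][1] * shapes[-1][2] * shapes[-1][3]
print("flattened length:", flat_size)
```

The rounding up matters only for odd sizes: pooling a 7×7 map once more with SAME padding would give 4×4, not 3×3.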

Training the network

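Training the network means letting the algorithm adjust the network's trainable parameters so that they fit the data. We compute the loss with cross-entropy and optimize the network with the Adam optimizer (learning rate 1e-4). Finally, we save the trained model to save_model/model.ckpt.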
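For a one-hot label, the cross-entropy −Σ y·log(p) reduces to −log of the probability the model assigns to the correct digit, so confident correct predictions give a small loss and confident wrong ones a large loss. The NumPy sketch below mirrors the loss used in train.py; the clipping constant 1e-10 is an assumption, added so that a predicted probability of exactly 0 yields a large finite loss instead of NaN.

```python
import numpy as np


def cross_entropy(labels, probs, eps=1e-10):
    """Sum of -y * log(p); probabilities are clipped to [eps, 1] so log(0) never occurs."""
    probs = np.clip(probs, eps, 1.0)
    return -np.sum(labels * np.log(probs))


label = np.zeros(10)
label[5] = 1.0  # one-hot label for the digit 5

confident = np.full(10, 0.01)
confident[5] = 0.91          # 91% on the correct digit
uniform = np.full(10, 0.1)   # no preference at all
hard_zero = np.zeros(10)
hard_zero[0] = 1.0           # all probability on a wrong digit

for name, p in [("confident", confident), ("uniform", uniform), ("hard_zero", hard_zero)]:
    print(name, cross_entropy(label, p))
```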
Implementation:
train.py

import os

import tensorflow as tf
import model
import input_data

def get_loss(input_labels, predict_labels):
    # Cross-entropy summed over the batch. predict_labels are softmax probabilities;
    # clipping keeps tf.log away from log(0), which would turn the loss into NaN.
    cross_entropy = -tf.reduce_sum(input_labels * tf.log(tf.clip_by_value(predict_labels, 1e-10, 1.0)))
    return cross_entropy

# prepare mnist data
mnist = input_data.read_data_sets('data', one_hot=True)

# hyperparameters
learning_rate = 1e-4

# placeholders (named input_images so they do not shadow the input_data module)
input_images = tf.placeholder(tf.float32, [None, 784])
input_labels = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)

forward_result = model.forward(input_images, keep_prob)

loss = get_loss(input_labels, forward_result)

# accuracy: fraction of samples whose most probable class equals the true class
correct_prediction = tf.equal(tf.argmax(forward_result, 1), tf.argmax(input_labels, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# backward pass: Adam minimizes the loss
train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss)

# saver for the trained variables
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(10000):
        batch_data, batch_labels = mnist.train.next_batch(10)
        # keep_prob=0.5: dropout is active during training
        _loss, _, _accuracy = sess.run([loss, train_step, accuracy],
                                       feed_dict={input_images: batch_data, input_labels: batch_labels,
                                                  keep_prob: 0.5})
        print('step = {}, loss = {}, accuracy = {:.2f}'.format(i, _loss, _accuracy))

    # Saver.save expects the target directory to exist
    if not os.path.exists('save_model'):
        os.makedirs('save_model')
    saver.save(sess, 'save_model/model.ckpt')

    # keep_prob=1.0 disables dropout for evaluation
    print('test accuracy {:.4f}'.format(
        accuracy.eval(feed_dict={input_images: mnist.test.images, input_labels: mnist.test.labels, keep_prob: 1.0})))
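tf.train.AdamOptimizer keeps a running average of each parameter's gradient and of its squared gradient, and scales every step by their ratio. The sketch below implements the standard Adam update rule with NumPy on a toy problem; TensorFlow's own implementation arranges the epsilon term slightly differently, so treat this as an illustration of the idea rather than a bit-exact copy. beta1=0.9, beta2=0.999 and epsilon=1e-8 are AdamOptimizer's defaults; the toy loss (x − 3)² and the larger learning rate of 0.1 are chosen only so the demo converges quickly, and adam_minimize is a hypothetical helper.

```python
import numpy as np


def adam_minimize(grad_fn, x0, learning_rate=0.1, beta1=0.9, beta2=0.999,
                  epsilon=1e-8, steps=500):
    """Minimize a function given its gradient, using the standard Adam update (sketch)."""
    x = np.asarray(x0, dtype=np.float64)
    m = np.zeros_like(x)  # first moment: running mean of gradients
    v = np.zeros_like(x)  # second moment: running mean of squared gradients
    for t in range(1, steps + 1):
        g = grad_fn(x)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)  # bias correction for the zero initialization
        v_hat = v / (1 - beta2 ** t)
        x = x - learning_rate * m_hat / (np.sqrt(v_hat) + epsilon)
    return x


# Toy problem: minimize (x - 3)^2, whose gradient is 2 * (x - 3)
result = adam_minimize(lambda x: 2 * (x - 3.0), x0=[0.0])
print(result)
```

Because each step is normalized by the gradient's running magnitude, the step size stays close to the learning rate regardless of how large the raw gradients are, which is one reason a small value such as 1e-4 works for the CNN above.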

Training results


Reference results

keep_prob   Iterations   batch_size   Test accuracy
1.0         1000         50           0.9704
0.5         2000         20           0.9736
0.5         5000         10           0.9824
0.5         10000        10           0.9872

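As the table shows, tuning these hyperparameters improves test accuracy. The settings in train.py (keep_prob 0.5, 10000 iterations, batch size 10) correspond to the last row.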
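keep_prob is the probability that each activation survives dropout. tf.nn.dropout zeroes the other activations and scales the survivors by 1/keep_prob so that the expected value of each activation is unchanged; that is why training feeds keep_prob=0.5 while evaluation and prediction feed keep_prob=1.0, which turns dropout off. A NumPy sketch of the same "inverted dropout" idea; the function name and random seed are illustrative.

```python
import numpy as np


def dropout(x, keep_prob, rng):
    """Inverted dropout: keep each element with probability keep_prob, scale survivors by 1/keep_prob."""
    if keep_prob >= 1.0:
        return x  # keep_prob=1.0 means no dropout (evaluation / prediction)
    mask = rng.uniform(size=x.shape) < keep_prob
    return np.where(mask, x / keep_prob, 0.0)


rng = np.random.RandomState(0)
activations = np.ones(100000)
dropped = dropout(activations, 0.5, rng)
# About half the entries are zero, the rest are 2.0, so the mean stays close to 1.0
print(np.mean(dropped == 0.0), dropped.mean())
```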

Predicting images

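We now use the model we just trained and saved to predict the digit in an image. Since I had no real handwritten images at hand, I used one image from the test set.
The code is as follows: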
predict.py

import tensorflow as tf
import model
import input_data
import numpy as np
import matplotlib.pyplot as plt

def show_image(mnist_image):
    # Reshape the flat 784-vector back to 28x28 and draw it (dark strokes on white)
    image = np.reshape(mnist_image, newshape=[28, -1])
    plt.imshow(image, cmap=plt.get_cmap('gray_r'))
    plt.show()

mnist = input_data.read_data_sets('data', one_hot=True)

# Rebuild the same graph as in training (named input_images to avoid shadowing the input_data module)
input_images = tf.placeholder(tf.float32, [None, 784])
keep_prob = tf.placeholder(tf.float32)
forward_result = model.forward(input_images, keep_prob)
predict_result = tf.argmax(forward_result, 1)  # index of the most probable digit
saver = tf.train.Saver()
index = 1  # index of the test-set image to predict
with tf.Session() as sess:
    image = mnist.test.images[index]
    show_image(image)

    # restore() loads every saved variable, so no separate initializer run is needed
    saver.restore(sess, "save_model/model.ckpt")
    _predict_result = sess.run(predict_result, feed_dict={input_images: [image], keep_prob: 1.0})

    _correct_result = np.argmax(mnist.test.labels[index])

    print('predict={}, correct={}'.format(_predict_result[0], _correct_result))

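Note: a real handwritten digit image must first be converted into the same format as MNIST: a 28x28x1 grayscale image with light strokes on a dark background and pixel values scaled to [0, 1]. Even then, accuracy on such images may be noticeably lower, because handwriting habits abroad, where MNIST was collected, differ from those in China.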
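A minimal NumPy sketch of that conversion, assuming you already have your digit as a 2D grayscale array (values 0 to 255, dark ink on light paper) whose height and width are multiples of 28, for example a 280×280 crop. It inverts the colors, shrinks the image to 28×28 by averaging blocks of pixels, scales to [0, 1] and flattens it to shape (1, 784), the batch shape expected by the 784-wide input placeholder in predict.py. The function name prepare_digit and the block-averaging resize are assumptions for illustration; real photos usually also need cropping and centering, since MNIST digits are size-normalized and centered, which this sketch does not do.

```python
import numpy as np


def prepare_digit(gray, size=28):
    """Turn a dark-on-light grayscale digit (0..255) into a (1, 784) MNIST-style input (hypothetical)."""
    gray = np.asarray(gray, dtype=np.float32)
    height, width = gray.shape
    if height % size or width % size:
        raise ValueError("this sketch expects height and width to be multiples of 28")
    inverted = 255.0 - gray  # MNIST has bright strokes on a black background
    # Block-average down to 28x28: split each axis into 28 groups and take each group's mean
    small = inverted.reshape(size, height // size, size, width // size).mean(axis=(1, 3))
    return (small / 255.0).reshape(1, size * size)  # scale to [0, 1] and flatten


# Synthetic 280x280 "photo": white paper with a black vertical bar (not a real scan)
paper = np.full((280, 280), 255.0)
paper[50:230, 130:150] = 0.0
batch = prepare_digit(paper)
print(batch.shape, batch.min(), batch.max())
```

The resulting batch can be fed to the network in place of [image], with keep_prob set to 1.0.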

Prediction results

You can change the test-image index, index, to predict other images in the test set.

Summary

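Some functions used in this article, such as dropout, softmax, max_pool and conv2d, were not explained here; they will be covered in more detail in upcoming articles.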

This article was originally published by @迦娜王 on 松鼠乐园. Reproduction without permission is prohibited.