TensorFlow入门(一) - mnist手写数字识别(网络搭建)
源代码/数据集已上传到 Github - tensorflow-tutorial-samples
这篇文章是 TensorFlow Tutorial 入门教程的第一篇文章。主要介绍了如何从0开始用tensorflow搭建最简单的网络进行训练。
mnist数据集
简介
MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片。在机器学习中的地位相当于Python入门的打印Hello World
。官网是THE MNIST DATABASE of handwritten digits
该数据集包含以下四个部分
- train-images-idx3-ubyte.gz: 训练集-图片,6w
- train-labels-idx1-ubyte.gz: 训练集-标签,6w
- t10k-images-idx3-ubyte.gz: 测试集-图片,1w
- t10k-labels-idx1-ubyte.gz: 测试集-标签,1w
图片和标签
mnist数据集里的每张图片大小为28 * 28像素,可以用28 * 28的大小的数组来表示一张图片。
标签用大小为10的数组来表示,这种编码我们称之为One hot(独热编码)。
One-hot编码(独热编码)
独热编码使用N位代表N种状态,任意时候只有其中一位有效。
采用独热编码的例子
1 | 性别: |
独热编码的优点在于
- 能够处理非连续型数值特征
- 在一定程度上也扩充了特征。比如性别本身是一个特征,经过编码以后,就变成了男或女两个特征。
在神经网络中,独热编码其实具有很强的容错性,比如神经网络的输出结果是 [0,0.1,0.2,0.7,0,0,0,0,0, 0]转成独热编码后,表示数字3。即值最大的地方变为1,其余均为0。[0,0.1,0.4,0.5,0,0,0,0,0, 0]也能表示数字3。
numpy中有一个函数,numpy.argmax()可以取得最大值的下标。
神经网络的重要概念
输入(x)输出(y)、标签(label)
- 输入是指传入给网络处理的向量,相当于数学函数中的变量。
- 输出是指网络处理后返回的结果,相当于数据函数中的函数值。
- 标签是指我们期望网络返回的结果。
对于识别mnist图片而言,输入是大小为784(28 * 28)的向量,输出是大小为10的概率向量(概率最大的位置,即预测的数字)。
损失函数(loss function)
损失函数评估网络模型的好坏,值越大,表示模型越差,值越小,表示模型越好。因为传入大量的训练集训练的目标,就是将损失函数的值降到最小。
常见的损失函数定义:
- 差的平方和 sum((y - label)^2)
1 | [0, 0, 1] 与 [0.1, 0.3, 0.6]的差的平方和为 0.01 + 0.09 + 0.16 = 0.26 |
- 交叉熵 -sum(label * log(y))
1 |
|
当label为0时,交叉熵为0,label为1时,交叉熵为-log(y),交叉熵只关注独热编码中有效位的损失。这样屏蔽了无效位值的变化(无效位的值的变化并不会影响最终结果),并且通过取对数放大了有效位的损失。当有效位的值趋近于0时,交叉熵趋近于正无穷大。
回归模型
我们可以将网络理解为一个函数,回归模型,其实是希望对这个函数进行拟合。
比如定义模型为 Y = X * w + b,对应的损失即
1 | loss = (Y - labal)^2 |
可以通过不断地传入X和label的值,来修正w和b,使得最终得到的Y与label的loss最小。这个训练的过程,可以采用梯度下降的方法。通过梯度下降,找到最快的方向,调整w和b值,使得w * X + b的值越来越接近label。
梯度下降的具体过程,就不在这篇文章中展开了。
学习速率
简单说,梯度即一个函数的斜率,找到函数的斜率,其实就知道了w和b的值往哪个方向调整,能够让函数值(loss)降低得最快。那么方向知道了,往这个方向调整多少呢?这个数,神经网络中称之为学习速率。学习速率调得太低,训练速度会很慢,学习速率调得过高,每次迭代波动会很大。
softmax激活函数
本文不展开讲解softmax激活函数。事实上,再计算交叉熵前的Y值是经过softmax后的,经过softmax后的Y,并不影响Y向量的每个位置的值之间的大小关系。大致有2个作用,一是放大效果,二是梯度下降时需要一个可导的函数。
1 | def softmax(x): |
Tensorflow识别手写数字
构造网络 model.py
1 | import tensorflow as tf |
训练 train.py
1 | import tensorflow as tf |
验证准确率 train.py
1 | class Train: |
主函数 train.py
1 | if __name__ == "__main__": |
项目已更新在Github,数据集由于国内网络等因素,有时候不能正确下载,所以数据集也一并同步了。
觉得还不错,不要吝惜你的star,支持是持续不断更新的动力。