TensorFlow入门(二) - mnist手写数字识别(模型保存加载)
源代码/数据集已上传到 Github - tensorflow-tutorial-samples
这篇文章是 TensorFlow Tutorial 入门教程的第二篇文章。
上一篇文章TensorFlow入门(一) - mnist手写数字识别(网络搭建)介绍了神经网络输入
、输出
、独热编码
、损失函数
等最基本的知识,并且演示了如何用最简单的模型实现mnist手写数字识别91%的正确率。但是遗留的问题是,模型保存在内存中,每次都得重新开始训练。
这篇文章解决的就是这个问题。将依次介绍tensorflow中如何保存
已经训练好的模型,如何在某个训练步数的基础上继续训练
,最后将演示如何加载模型
,并借助pillow(Python2中称为PIL)库实现真实手写数字图片的识别。
模型的保存
- 首先看一下项目的目录结构
1 | |--mnist/ |
第一步更改模型,记录global_step
每一次训练,会进行一次梯度下降,传入的global_step的值会自增1,因此,可以通过计算global_step这个张量的值,知道当前训练了多少步。
1 | # model.py |
第二步,每隔N步保存
1 | CKPT_DIR = 'ckpt' # 定义模型存储的位置 |
- 最终保存的模型如下所示
假设训练到了2000步,保存了2次模型。ckpt文件夹下会生成7个文件,第一个文件是 checkpoint文件,保存了所有的模型的路径。其中第一行代表当前的状态,即在加载模型时,使用哪一个模型是由第一行决定的。
每个模型包含3个文件,分别是
- model-xxx.data-00000-of-00001
- model-xxx.index
- model-xxx.meta
checkpoint文件
1 | model_checkpoint_path: "model-2000" |
目录结构
1 | |--v2/ |
加载模型与继续训练(train.py)
假设我们当前模型已经训练到了2000步,但是由于某种原因停止了。那么是否可以在2000步的基础上继续训练呢?
- 只需一步,训练前保存的模型restore到session中即可。这里需要注意的是,创建
tf.train.Saver
对象一定要在创建tf.Session
之后。
1 | CKPT_DIR = 'ckpt' |
- 再次运行代码,将打印出
1 | Continue from |
- 如果将checkpoint文件的第一行改为如下,训练将从1000开始,再次训练到2000时,会将原来的2000的模型覆盖。所以restore哪一个模型,只与checkpoint的第一行有关,即只与
model_checkpoint_path
有关。1
model_checkpoint_path: "model-1000"
1 | Continue from |
使用模型预测数字(predict.py)
第一步,restore模型
1 | import numpy as np |
第二步读入图片并预测
1 | class Predict: |
- test_images目录下的
0.png
,1.png
,4.png
三张图片的预测结果。1
2
3
4app = Predict()
app.predict('../test_images/0.png')
app.predict('../test_images/1.png')
app.predict('../test_images/4.png')
最后的结果
第一次 python train.py
1
2第 1000步,当前loss:26.94
第 2000步,当前loss:28.362000步时停止,第二次 python train.py
1
2
3
4
5Continue from
-> Minibatch update : 2000
第 3000步,当前loss:23.49
第 4000步,当前loss:20.40
第 5000步,当前loss:11.65python predict.py
1
2
3
4
5
6../test_images/0.png
-> Predict digit 0
../test_images/1.png
-> Predict digit 1
../test_images/4.png
-> Predict digit 4
源代码&数据集已上传到 Github
觉得还不错,不要吝惜你的star,支持是持续不断更新的动力。
附 推荐
上一篇 « TensorFlow入门(一) - mnist手写数字识别(网络搭建) 下一篇 » Pandas 数据处理(一) - DataFrame 与 Series