极客兔兔

TensorFlow入门(二) - mnist手写数字识别(模型保存加载)

源代码/数据集已上传到 Github - tensorflow-tutorial-samples

这篇文章是 TensorFlow Tutorial 入门教程的第二篇文章。

上一篇文章TensorFlow入门(一) - mnist手写数字识别(网络搭建)介绍了神经网络输入输出独热编码损失函数等最基本的知识,并且演示了如何用最简单的模型实现mnist手写数字识别91%的正确率。但是遗留的问题是,模型保存在内存中,每次都得重新开始训练。

这篇文章解决的就是这个问题。将依次介绍tensorflow中如何保存已经训练好的模型,如何在某个训练步数的基础上继续训练,最后将演示如何加载模型,并借助pillow(Python2中称为PIL)库实现真实手写数字图片的识别。

模型的保存

  • 首先看一下项目的目录结构
1
2
3
4
5
6
7
8
9
10
11
|--mnist/
|--data_set/ 训练以及测试数据集
|--test_images/ 多张测试图片
|--0.png
|--1.png
|--4.png
|--v2/
|--ckpt/ 模型保存在这里!!!
|--model.py 网络模型
|--train.py 训练代码
|--predict.py 预测代码

第一步更改模型,记录global_step

每一次训练,会进行一次梯度下降,传入的global_step的值会自增1,因此,可以通过计算global_step这个张量的值,知道当前训练了多少步。

1
2
3
4
5
6
7
8
9
10
11
12
# model.py
class Network:
def __init__(self):
# 记录已经训练的次数
self.global_step = tf.Variable(0, trainable=False)

# ... 中间省略网络结构

# minimize 可传入参数 global_step, 每次训练 global_step的值会增加1
# 因此,可以通过计算self.global_step这个张量的值,知道当前训练了多少步
self.train = tf.train.GradientDescentOptimizer(0.001).minimize(
self.loss, global_step=self.global_step)

第二步,每隔N步保存

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
CKPT_DIR = 'ckpt' # 定义模型存储的位置
net = Network()
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# tf.train.Saver是用来保存训练结果的。
# max_to_keep 用来设置最多保存多少个模型,默认是5
# 如果保存的模型超过这个值,最旧的模型将被删除
saver = tf.train.Saver(max_to_keep=10)

train_step = 10000 # 总的训练次数10000
step = 0 # 记录训练次数, 初始化为0
save_interval = 1000 # 每隔1000步保存模型

while step < train_step:
# ...省略训练代码

step = sess.run(net.global_step)
# 模型保存在ckpt文件夹下
# 模型文件名最后会增加global_step的值,比如1000的模型文件名为 model-1000
if step % save_interval == 0:
saver.save(sess, CKPT_DIR + '/model', global_step=step)
  • 最终保存的模型如下所示

假设训练到了2000步,保存了2次模型。ckpt文件夹下会生成7个文件,第一个文件是 checkpoint文件,保存了所有的模型的路径。其中第一行代表当前的状态,即在加载模型时,使用哪一个模型是由第一行决定的。

每个模型包含3个文件,分别是

  1. model-xxx.data-00000-of-00001
  2. model-xxx.index
  3. model-xxx.meta

checkpoint文件

1
2
3
model_checkpoint_path: "model-2000"
all_model_checkpoint_paths: "model-1000"
all_model_checkpoint_paths: "model-2000"

目录结构

1
2
3
4
5
6
7
8
9
10
11
12
|--v2/  
|--ckpt/ 模型保存在这里!!!
|--checkpoint
|--model-1000.data-00000-of-00001
|--model-1000.index
|--model-1000.meta
|--model-2000.data-00000-of-00001
|--model-2000.index
|--model-2000.meta
|--model.py 网络模型
|--train.py 训练代码
|--predict.py 预测代码

加载模型与继续训练(train.py)

假设我们当前模型已经训练到了2000步,但是由于某种原因停止了。那么是否可以在2000步的基础上继续训练呢?

  • 只需一步,训练前保存的模型restore到session中即可。这里需要注意的是,创建 tf.train.Saver对象一定要在创建tf.Session之后。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
CKPT_DIR = 'ckpt'
net = Network()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(max_to_keep=10)

train_step = 10000
step = 0
save_interval = 1000

# 开始训练前,检查ckpt文件夹,看是否有checkpoint文件存在。
# 如果存在,则读取checkpoint文件指向的模型,restore到sess中。
ckpt = tf.train.get_checkpoint_state(CKPT_DIR)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
# 读取网络中的global_step的值,即当前已经训练的次数
step = sess.run(net.global_step)
print('Continue from')
print(' -> Minibatch update : ', step)

while step < train_step:
# ...省略训练代码
  • 再次运行代码,将打印出
1
2
3
Continue from
-> Minibatch update : 2000
第 3000步,...
  • 如果将checkpoint文件的第一行改为如下,训练将从1000开始,再次训练到2000时,会将原来的2000的模型覆盖。所以restore哪一个模型,只与checkpoint的第一行有关,即只与model_checkpoint_path有关。
    1
    model_checkpoint_path: "model-1000"
1
2
3
Continue from
-> Minibatch update : 1000
第 2000步,...

使用模型预测数字(predict.py)

第一步,restore模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import numpy as np
from PIL import Image


class Predict:
def __init__(self):
self.net = Network()
self.sess = tf.Session()
self.sess.run(tf.global_variables_initializer())
self.restore() # 加载模型到sess中

def restore(self):
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state(CKPT_DIR)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(self.sess, ckpt.model_checkpoint_path)
else:
raise FileNotFoundError("未保存任何模型")

def predict(self, image_path):
# ...省略

第二步读入图片并预测

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class Predict:
# ...

def predict(self, image_path):
# 读图片并转为黑白的
img = Image.open(image_path).convert('L')
flatten_img = np.reshape(img, 784)
x = np.array([1 - flatten_img])
y = self.sess.run(self.net.y, feed_dict={self.net.x: x})

# 因为x只传入了一张图片,取y[0]即可
# np.argmax()取得独热编码最大值的下标,即代表的数字
print(image_path)
print(' -> Predict digit', np.argmax(y[0]))
  • test_images目录下的0.png1.png4.png三张图片的预测结果。
    1
    2
    3
    4
    app = Predict()
    app.predict('../test_images/0.png')
    app.predict('../test_images/1.png')
    app.predict('../test_images/4.png')

最后的结果

  • 第一次 python train.py

    1
    2
    第 1000步,当前loss:26.94
    第 2000步,当前loss:28.36
  • 2000步时停止,第二次 python train.py

    1
    2
    3
    4
    5
    Continue from
    -> Minibatch update : 2000
    第 3000步,当前loss:23.49
    第 4000步,当前loss:20.40
    第 5000步,当前loss:11.65
  • python predict.py

    1
    2
    3
    4
    5
    6
    ../test_images/0.png
    -> Predict digit 0
    ../test_images/1.png
    -> Predict digit 1
    ../test_images/4.png
    -> Predict digit 4

源代码&数据集已上传到 Github

觉得还不错,不要吝惜你的star,支持是持续不断更新的动力。


专题:

本文发表于 2017-12-17,最后修改于 2020-04-01。

本站永久域名geektutu.com,也可搜索「 极客兔兔 」找到我。

期待关注我的 知乎专栏微博 ,查看最近的文章和动态。


上一篇 « TensorFlow入门(一) - mnist手写数字识别(网络搭建) 下一篇 » Pandas 数据处理(一) - DataFrame 与 Series

赞赏支持

请我吃胡萝卜 =^_^=

i ali

支付宝

i wechat

微信

推荐阅读

Big Image