TensorFlow 2 中文文档 - 保存与加载模型
TensorFlow2 文档系列文章链接:
TensorFlow 2 / 2.0 中文文档
(Jul 9, 2019)
TensorFlow 2 中文文档 - MNIST 图像分类
(Jul 9, 2019)
TensorFlow 2 中文文档 - IMDB 文本分类
(Jul 9, 2019)
TensorFlow 2 中文文档 - 特征工程结构化数据分类
(Jul 9, 2019)
TensorFlow 2 中文文档 - 回归预测燃油效率
(Jul 11, 2019)
TensorFlow 2 中文文档 - 过拟合与欠拟合
(Jul 12, 2019)
TensorFlow 2 中文文档 - 保存与加载模型
(Jul 13, 2019)
TensorFlow 2 中文文档 - 卷积神经网络分类 CIFAR-10
(Jul 19, 2019)
TensorFlow 2 中文文档 - TFHub 迁移学习
(Jul 19, 2019)
TensorFlow 2 中文文档 - RNN LSTM 文本分类
(Jul 22, 2019)
源代码/数据集已上传到
Github - tensorflow2-docs-zh
TF2.0 TensorFlow 2 / 2.0 中文文档:保存与加载模型 Save and Restore model
主要内容:使用 tf.keras
接口训练、保存、加载模型,数据集选用 MNIST 。
1 2 $ pip install -q tensorflow==2.0.0-beta1 $ pip install -q h5py pyyaml
准备训练数据 1 2 3 4 5 6 7 8 9 10 11 12 import tensorflow as tffrom tensorflow import kerasfrom tensorflow.keras import datasets, layers, models, callbacksfrom tensorflow.keras.datasets import mnistimport osfile_path = os.path.abspath('./mnist.npz' ) (train_x, train_y), (test_x, test_y) = datasets.mnist.load_data(path=file_path) train_y, test_y = train_y[:1000 ], test_y[:1000 ] train_x = train_x[:1000 ].reshape(-1 , 28 * 28 ) / 255.0 test_x = test_x[:1000 ].reshape(-1 , 28 * 28 ) / 255.0
搭建模型 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 def create_model (): model = models.Sequential([ layers.Dense(512 , activation='relu' , input_shape=(784 ,)), layers.Dropout(0.2 ), layers.Dense(10 , activation='softmax' ) ]) model.compile (optimizer='adam' , metrics=['accuracy' ], loss='sparse_categorical_crossentropy' ) return model def evaluate (target_model ): _, acc = target_model.evaluate(test_x, test_y) print("Restore model, accuracy: {:5.2f}%" .format (100 *acc))
自动保存 checkpoints 这样做,一是训练结束后得到了训练好的模型,使用得不必再重新训练,二是训练过程被中断,可以从断点处继续训练。
设置tf.keras.callbacks.ModelCheckpoint
回调可以实现这一点。
1 2 3 4 5 6 7 8 9 10 11 checkpoint_path = "training_2/cp-{epoch:04d}.ckpt" checkpoint_dir = os.path.dirname(checkpoint_path) cp_callback = callbacks.ModelCheckpoint( checkpoint_path, verbose=1 , save_weights_only=True , period=10 ) model = create_model() model.save_weights(checkpoint_path.format (epoch=0 )) model.fit(train_x, train_y, epochs=50 , callbacks=[cp_callback], validation_data=(test_x, test_y), verbose=0 )
1 2 3 4 5 Epoch 00010: saving model to training_2/cp-0010. ckpt Epoch 00020: saving model to training_2/cp-0020. ckpt Epoch 00030: saving model to training_2/cp-0030. ckpt Epoch 00040: saving model to training_2/cp-0040. ckpt Epoch 00050: saving model to training_2/cp-0050. ckpt
加载权重:
1 2 3 4 5 latest = tf.train.latest_checkpoint(checkpoint_dir) model = create_model() model.load_weights(latest) evaluate(model)
1 2 1000/1000 [===] - 0s 90us/sample - loss: 0.4703 - accuracy: 0.8780 Restore model, accuracy: 87.80%
手动保存权重 1 2 3 4 5 model.save_weights('./checkpoints/mannul_checkpoint' ) model = create_model() model.load_weights('./checkpoints/mannul_checkpoint' ) evaluate(model)
1 2 1000 /1000 [===] - 0s 90us/sample - loss: 0.4703 - accuracy: 0.8780 Restore model, accuracy: 87.80 %
保存整个模型 上面的示例仅仅保存了模型中的权重(weights),模型和优化器都可以一起保存,包括权重(weights)、模型配置(architecture)和优化器配置(optimizer configuration)。这样做的好处是,当你恢复模型时,完全不依赖于原来搭建模型的代码。
保存完整的模型有很多应用场景,比如在浏览器中使用 TensorFlow.js 加载运行,比如在移动设备上使用 TensorFlow Lite 加载运行。
HDF5 直接调用model.save
即可保存为 HDF5 格式的文件。
1 model.save('my_model.h5' )
从 HDF5 中恢复完整的模型。
1 2 new_model = models.load_model('my_model.h5' ) evaluate(new_model)
1 2 1000/1000 [===] - 0s 90us/sample - loss: 0.4703 - accuracy: 0.8780 Restore model, accuracy: 87.80%
saved_model 保存为saved_model
格式。
1 2 3 import timesaved_model_path = "./saved_models/{}" .format (int (time.time())) tf.keras.experimental.export_saved_model(model, saved_model_path)
恢复模型并预测
1 2 new_model = tf.keras.experimental.load_from_saved_model(saved_model_path) model.predict(test_x).shape
saved_model
格式的模型可以直接用来预测(predict),但是 saved_model 没有保存优化器配置,如果要使用evaluate
方法,则需要先 compile。
1 2 3 4 new_model.compile (optimizer=model.optimizer, loss='sparse_categorical_crossentropy' , metrics=['accuracy' ]) evaluate(new_model)
1 2 1000/1000 [===] - 0s 90us/sample - loss: 0.4703 - accuracy: 0.8780 Restore model, accuracy: 87.80%
最后 TensorFlow 中还有其他的方式可以保存模型。
返回文档首页
完整代码:Github - save_restore_model.ipynb 参考文档:Save and restore models
附 推荐
上一篇 « TensorFlow 2 中文文档 - 过拟合与欠拟合
下一篇 » TensorFlow 2 中文文档 - 卷积神经网络分类 CIFAR-10