Saving and Loading Models
Install tensorflow-datasets and import the dependencies:
%pip install tensorflow-datasets
import tensorflow as tf
import tensorflow_datasets as tfds
tf.__version__
2.3.0
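mirrored_strategy = tf.distribute.MirroredStrategy()
The line above creates a strategy that distributes the variables and the graph across devices.
How does the tf.distribute.MirroredStrategy strategy work?
All variables and the model graph are replicated on every replica.
The input is distributed evenly across the replicas.
Each replica computes the loss and gradients for the input it received.
The gradients are synchronized across all replicas by summing them.
After synchronization, the same update is applied to the copy of each variable on every replica.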
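The steps above can be sketched without TensorFlow at all. The following is a minimal, framework-free simulation of synchronous data-parallel training on a one-parameter model y = w * x with a squared-error loss. The function names (`local_gradient`, `mirrored_step`), the toy data, and the choice to divide the summed gradient by the global batch size are assumptions made for this sketch; it is not how tf.distribute is implemented.

```python
# Toy simulation of synchronous data-parallel training (not TensorFlow code).
# Model: y = w * x, per-example loss: (w * x - y) ** 2.

def local_gradient(w, shard):
    """Sum of d(loss)/dw over the examples that one replica received."""
    return sum(2.0 * (w * x - y) * x for x, y in shard)

def mirrored_step(replica_weights, batch, learning_rate):
    """One training step: split the batch, sum the gradients, update every copy."""
    n = len(replica_weights)
    # 1. Split the batch evenly across the replicas.
    shards = [batch[i::n] for i in range(n)]
    # 2. Each replica computes the gradient on its shard with its own copy of w.
    grads = [local_gradient(w, shard) for w, shard in zip(replica_weights, shards)]
    # 3. All-reduce by summation: every replica sees the same total gradient.
    total = sum(grads)
    # 4. Every replica applies the identical update (mean over the global batch).
    return [w - learning_rate * total / len(batch) for w in replica_weights]

batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # generated by y = 2 * x
weights = [0.0, 0.0]  # two replicas holding identical initial copies
for _ in range(50):
    weights = mirrored_step(weights, batch, learning_rate=0.05)
print(weights)
```

Because every replica starts from the same value and applies the same summed gradient, the copies never drift apart, which is the property the strategy relies on.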
def get_data():
def get_model():
Train the model
model = get_model()
train_dataset, eval_dataset = get_data()
Downloading and preparing dataset mnist/3.0.1 (download: Unknown size, generated: Unknown size, total: Unknown size) to C:\Users\v-xujwan\tensorflow_datasets\mnist\3.0.1...
Shuffling and writing examples to C:\Users\v-xujwan\tensorflow_datasets\mnist\3.0.1.incompleteI371SH\mnist-train.tfrecord
Shuffling and writing examples to C:\Users\v-xujwan\tensorflow_datasets\mnist\3.0.1.incompleteI371SH\mnist-test.tfrecord
Dataset mnist downloaded and prepared to C:\Users\v-xujwan\tensorflow_datasets\mnist\3.0.1. Subsequent calls will reuse this data.
model.fit(train_dataset, epochs=2)
Epoch 1/2
WARNING:tensorflow:From D:\Anaconda3\envs\py36_tf2\lib\site-packages\tensorflow\python\data\ops\multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
938/938 [==============================] - 18s 19ms/step - loss: 0.2089 - accuracy: 0.9389
Epoch 2/2
938/938 [==============================] - 19s 20ms/step - loss: 0.0689 - accuracy: 0.9798
<tensorflow.python.keras.callbacks.History at 0x2b4c939e278>
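Save and load the model
Now that there is a simple model to work with, let's look at the saving/loading APIs. Two sets of APIs are available:
High-level: Keras model.save and tf.keras.models.load_model
Low-level: tf.saved_model.save and tf.saved_model.load
Using the Keras API
keras_model_path = "./tmp/keras_save"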
WARNING:tensorflow:From D:\Anaconda3\envs\py36_tf2\lib\site-packages\tensorflow\python\training\tracking\tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From D:\Anaconda3\envs\py36_tf2\lib\site-packages\tensorflow\python\training\tracking\tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: ./tmp/keras_save\assets
The model was saved successfully. Here are the saved files:
└───tmp
    └───keras_save
        ├───assets
        ├───variables
        │   ├───variables.data-00000-of-00001
        │   └───variables.index
        └───saved_model.pb
Next, restore the model:
restored_keras_model = tf.keras.models.load_model(keras_model_path)
The restored model can continue training:
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 16s 17ms/step - loss: 0.0494 - accuracy: 0.0989
Epoch 2/2
938/938 [==============================] - 16s 17ms/step - loss: 0.0353 - accuracy: 0.0989
<tensorflow.python.keras.callbacks.History at 0x2b4c5b5a128>
Now load the model and train it using a tf.distribute.Strategy:
another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
Epoch 1/2
938/938 [==============================] - 16s 17ms/step - loss: 0.0501 - accuracy: 0.0990
Epoch 2/2
938/938 [==============================] - 16s 17ms/step - loss: 0.0354 - accuracy: 0.0989
restored_keras_model_ds.predict
<tensorflow.python.keras.engine.sequential.Sequential at 0x2b4c75ea2b0>
Using the tf.saved_model API
Now use the lower-level API. Saving works much like it does with Keras:
model = get_model()
INFO:tensorflow:Assets written to: ./tmp/tf_save\assets
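Loading can be done with tf.saved_model.load(). However, since it is a lower-level API (and therefore has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used for inference.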
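As a minimal sketch of what loading with tf.saved_model.load looks like, the block below saves a tiny module and then runs inference through the loaded signature. To stay self-contained it uses a hypothetical tf.Module named `Scaler` instead of the Keras model from this tutorial; the class, the input name `x`, the output key `y`, and the temporary export directory are all assumptions of this example.

```python
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical stand-in for a trained model: multiplies its input by 3."""

    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(3.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32, name="x")])
    def __call__(self, x):
        return {"y": self.factor * x}


export_dir = tempfile.mkdtemp()
module = Scaler()
# A tf.function with an input signature is exported under "serving_default".
tf.saved_model.save(module, export_dir, signatures=module.__call__)

# The loaded object is not a Keras model; it exposes callable signatures.
loaded = tf.saved_model.load(export_dir)
infer = loaded.signatures["serving_default"]
result = infer(x=tf.constant([1.0, 2.0]))
print(result["y"].numpy())
```

The signature function takes keyword arguments and returns a dictionary of tensors, so the caller needs to know the input and output names rather than a Keras-style predict method.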
You can also load the model and run inference in a distributed manner:
another_strategy = tf.distribute.MirroredStrategy()
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.
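Saving checkpoints
Checkpoints capture the exact value of all parameters (tf.Variable objects) used by a model. Checkpoints do not contain any description of the computation defined by the model, so they are typically only useful when the source code that will use the saved parameter values is available.
The SavedModel format, on the other hand, includes a serialized description of the computation defined by the model in addition to the parameter values (checkpoint). Models in this format are independent of the source code that created them. They are therefore suitable for deployment via TensorFlow Serving, TensorFlow Lite, TensorFlow.js, or programs in other programming languages (the C, C++, Java, Go, Rust, C# etc. TensorFlow APIs).
This guide covers the APIs for writing and reading checkpoints.
Setup
class Net(tf.keras.Model):
net = Net()
net.save_weights('./tmp/easy_checkpoint')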
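Writing checkpoints
The persistent state of a TensorFlow model is stored in tf.Variable objects. These can be constructed directly, but are often created through high-level APIs such as tf.keras.layers or tf.keras.Model.
The easiest way to manage variables is to attach them to Python objects and then reference those objects.
Subclasses of tf.train.Checkpoint, tf.keras.layers.Layer, and tf.keras.Model automatically track variables assigned to their attributes. The example in this section constructs a simple linear model, then writes checkpoints that contain the values of all of the model's variables.
You can easily save a model checkpoint with Model.save_weights.
Manual checkpointing
def toy_dataset():
def train_step(net, example, optimizer):
Create the checkpoint objects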
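To checkpoint manually you need a tf.train.Checkpoint object. The objects you want to checkpoint are set as attributes on that object.
A tf.train.CheckpointManager can also be helpful for managing multiple checkpoints.
opt = tf.keras.optimizers.Adam(0.1)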
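To make the roles of the two objects concrete, here is a small self-contained sketch that checkpoints two plain variables instead of a model and an optimizer. The variable names (`step`, `weight`), the temporary directory, and `max_to_keep=3` are assumptions chosen for illustration.

```python
import tempfile

import tensorflow as tf

# Objects to checkpoint are passed as keyword arguments and become
# attributes of the Checkpoint object (ckpt.step, ckpt.weight).
step = tf.Variable(0, dtype=tf.int64)
weight = tf.Variable(1.0)
ckpt = tf.train.Checkpoint(step=step, weight=weight)

# The manager numbers the checkpoints and deletes those beyond max_to_keep.
manager = tf.train.CheckpointManager(ckpt, tempfile.mkdtemp(), max_to_keep=3)

for _ in range(5):
    step.assign_add(1)
    weight.assign(weight * 2.0)
    manager.save()

kept = list(manager.checkpoints)  # only the three newest prefixes remain
print(kept)

# Overwrite the variable, then restore the most recently saved value.
weight.assign(0.0)
ckpt.restore(manager.latest_checkpoint)
print(int(step.numpy()), float(weight.numpy()))
```

Five checkpoints are written but only three are retained, and restoring brings the variables back to the values they had at the last save.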
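Train and checkpoint the model
The following training loop creates an instance of the model and of an optimizer, then gathers them into a tf.train.Checkpoint object. It calls the training step in a loop on each batch of data and periodically writes checkpoints to disk.
def train_and_checkpoint(net, manager):
train_and_checkpoint(net, manager)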
Starting initialization
Saved checkpoint 10: ./tmp/tf_ckpts\ckpt-1
loss 29.78
Saved checkpoint 20: ./tmp/tf_ckpts\ckpt-2
loss 23.19
Saved checkpoint 30: ./tmp/tf_ckpts\ckpt-3
loss 16.63
Saved checkpoint 40: ./tmp/tf_ckpts\ckpt-4
loss 10.17
Saved checkpoint 50: ./tmp/tf_ckpts\ckpt-5
loss 4.09
Restore and continue training
opt = tf.keras.optimizers.Adam(0.1)
train_and_checkpoint(net, manager)
Restored from: ./tmp/tf_ckpts\ckpt-10
Saved checkpoint 110: ./tmp/tf_ckpts\ckpt-11
loss 0.27
Saved checkpoint 120: ./tmp/tf_ckpts\ckpt-12
loss 0.20
Saved checkpoint 130: ./tmp/tf_ckpts\ckpt-13
loss 0.16
Saved checkpoint 140: ./tmp/tf_ckpts\ckpt-14
loss 0.21
Saved checkpoint 150: ./tmp/tf_ckpts\ckpt-15
loss 0.20
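The pattern behind the runs above (resume from the newest checkpoint if one exists, otherwise start fresh, then save every few steps) does not depend on TensorFlow. The following framework-free analogue stores the state as JSON files; the function names (`latest_checkpoint`, `run_training`), the file naming scheme, and the toy update rule are assumptions made for this sketch and are not TensorFlow's checkpoint format.

```python
import json
import os
import tempfile


def latest_checkpoint(ckpt_dir):
    """Return the path of the highest-numbered checkpoint file, or None."""
    names = [n for n in os.listdir(ckpt_dir) if n.startswith("ckpt-")]
    if not names:
        return None
    newest = max(names, key=lambda n: int(n.split("-")[1].split(".")[0]))
    return os.path.join(ckpt_dir, newest)


def run_training(ckpt_dir, num_steps=50, save_every=10):
    """Resume from the newest checkpoint if any, then train and save periodically."""
    state = {"step": 0, "w": 0.0}
    path = latest_checkpoint(ckpt_dir)
    if path is not None:
        with open(path) as f:
            state = json.load(f)
    for _ in range(num_steps):
        # Toy update: one gradient-descent step on the loss (w - 3) ** 2.
        state["w"] -= 0.1 * 2.0 * (state["w"] - 3.0)
        state["step"] += 1
        if state["step"] % save_every == 0:
            name = "ckpt-{}.json".format(state["step"] // save_every)
            with open(os.path.join(ckpt_dir, name), "w") as f:
                json.dump(state, f)
    return state


ckpt_dir = tempfile.mkdtemp()
first = run_training(ckpt_dir)   # no checkpoint yet: starts from step 0
second = run_training(ckpt_dir)  # resumes from the checkpoint written at step 50
print(first["step"], second["step"])
```

The second call picks up the step counter where the first one stopped, which mirrors how the checkpoint numbers in the output above keep increasing across runs.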
print(manager.checkpoints)
['./tmp/tf_ckpts\\ckpt-13', './tmp/tf_ckpts\\ckpt-14', './tmp/tf_ckpts\\ckpt-15']
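These paths, for example './tmp/tf_ckpts/ckpt-15', are not files on disk. Instead, each is a prefix for an index file and one or more data files that contain the variable values. The prefixes are grouped together in a single checkpoint file ('./tmp/tf_ckpts/checkpoint'), where the CheckpointManager saves its state.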
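A quick way to see the difference between a prefix and the files behind it is to test the prefix with os.path.exists and then glob for files that start with it. The sketch below uses only the standard library and creates empty placeholder files that mimic the layout; the file names follow the `.index` / `.data-00000-of-00001` pattern shown earlier, and the helper `files_for_prefix` is a hypothetical name introduced for this example.

```python
import glob
import os
import tempfile

ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "ckpt-15")

# Empty placeholder files that only mimic the on-disk layout of a checkpoint.
for name in ("ckpt-15.index", "ckpt-15.data-00000-of-00001", "checkpoint"):
    open(os.path.join(ckpt_dir, name), "w").close()


def files_for_prefix(prefix):
    """Return the base names of the files that belong to a checkpoint prefix."""
    pattern = glob.escape(prefix) + ".*"
    return sorted(os.path.basename(p) for p in glob.glob(pattern))


print(os.path.exists(prefix))    # the prefix itself is not a file
print(files_for_prefix(prefix))  # the index file and the data shard
```

This is why code that checks for a checkpoint should not test the prefix path directly.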
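Manual inspection
tf.train.list_variables lists the checkpoint keys and the shapes of the variables in a checkpoint. Checkpoint keys are paths in the checkpoint's object graph.
tf.train.list_variables(tf.train.latest_checkpoint('./tmp/tf_ckpts'))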
[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
('iterator/.ATTRIBUTES/ITERATOR_STATE', [1]),
('net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', [5]),
('net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', [5]),
('net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', [5]),
('net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', [1, 5]),
('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
[1, 5]),
('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
[1, 5]),
('optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', []),
('optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', []),
('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),
('step/.ATTRIBUTES/VARIABLE_VALUE', [])]
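Keys like the ones listed above can be taken apart with plain string handling: the part before `/.ATTRIBUTES/` is the chain of attribute names leading from the root checkpoint object to a value, and the part after it names the saved attribute. The helper below is a hypothetical convenience written for this example, not a TensorFlow API, and it assumes the key layout seen in this listing.

```python
def split_checkpoint_key(key):
    """Split a checkpoint key into (object path, attribute name)."""
    path, sep, attribute = key.partition("/.ATTRIBUTES/")
    if not sep:
        # Keys such as '_CHECKPOINTABLE_OBJECT_GRAPH' carry no attribute part.
        return key, None
    return path, attribute


keys = [
    "_CHECKPOINTABLE_OBJECT_GRAPH",
    "net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE",
    "net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE",
    "step/.ATTRIBUTES/VARIABLE_VALUE",
]
for key in keys:
    print(split_checkpoint_key(key))
```

Read this way, `net/l1/bias` corresponds to following the attributes net, l1, and bias from the checkpoint object, while the `.OPTIMIZER_SLOT` entries hold the optimizer's per-variable state.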
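Saving object-based checkpoints with Estimator
By default, Estimators save checkpoints with variable names rather than the object graph described in the previous sections. tf.train.Checkpoint accepts name-based checkpoints, but variable names may change when parts of a model are moved outside the Estimator's model_fn. Saving object-based checkpoints makes it easier to train a model inside an Estimator and then use it outside of one.
import tensorflow.compat.v1 as tf_compat
def model_fn(features, labels, mode):
tf.keras.backend.clear_session()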
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tmp/tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
rewrite_options {
meta_optimizer_iterations: ONE
}
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From D:\Anaconda3\envs\py36_tf2\lib\site-packages\tensorflow\python\training\training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./tmp/tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.505075, step = 1
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into ./tmp/tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 36.96539.
<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x2b4ccb5f588>
tf.train.latest_checkpoint('./tmp/tf_estimator_example')
'./tmp/tf_estimator_example\\model.ckpt-10'
tf.train.Checkpoint can then load the Estimator's checkpoints from its model_dir.
opt = tf.keras.optimizers.Adam(0.1)
10