Saving and Loading TensorFlow Models

Saving and loading models

Install tensorflow-datasets and import the dependencies:

%pip install tensorflow-datasets

import tensorflow_datasets as tfds
import tensorflow as tf

tfds.disable_progress_bar()

tf.__version__
2.3.0

mirrored_strategy = tf.distribute.MirroredStrategy()

This creates a strategy that distributes the variables and the computation graph across the replicas.

How does the tf.distribute.MirroredStrategy strategy work?

All the variables and the model graph are replicated across the replicas.
Input is evenly distributed across the replicas.
Each replica calculates the loss and gradients for the input it received.
The gradients are synced across all the replicas by summing them.
After the sync, the same update is made to the copies of the variables on each replica (a short sketch follows this list).
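To make the replication concrete, here is a minimal sketch (not from the original notebook): a variable created inside the strategy's scope becomes a mirrored variable with one copy per replica, and num_replicas_in_sync reports how many replicas take part in each synchronized update.

with mirrored_strategy.scope():
    v = tf.Variable(1.0)  # created once per replica and kept in sync
print(type(v).__name__)                        # typically MirroredVariable
print(mirrored_strategy.num_replicas_in_sync)  # replicas participating in each update

With the strategy in place, the helpers below build the distributed input pipeline and the model.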
def get_data():
    datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
    mnist_train, mnist_test = datasets['train'], datasets['test']

    BUFFER_SIZE = 10000

    BATCH_SIZE_PER_REPLICA = 64
    BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

    def scale(image, label):
        image = tf.cast(image, tf.float32)
        image /= 255
        return image, label

    train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
    eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

    return train_dataset, eval_dataset
def get_model():
    with mirrored_strategy.scope():
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPool2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10)
        ])

        model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      optimizer=tf.keras.optimizers.Adam(),
                      metrics=['accuracy'])
        return model

Train the model

model = get_model()
train_dataset, eval_dataset = get_data()
Downloading and preparing dataset mnist/3.0.1 (download: Unknown size, generated: Unknown size, total: Unknown size) to C:\Users\v-xujwan\tensorflow_datasets\mnist\3.0.1...
Shuffling and writing examples to C:\Users\v-xujwan\tensorflow_datasets\mnist\3.0.1.incompleteI371SH\mnist-train.tfrecord
Shuffling and writing examples to C:\Users\v-xujwan\tensorflow_datasets\mnist\3.0.1.incompleteI371SH\mnist-test.tfrecord
Dataset mnist downloaded and prepared to C:\Users\v-xujwan\tensorflow_datasets\mnist\3.0.1. Subsequent calls will reuse this data.
model.fit(train_dataset, epochs=2)
Epoch 1/2
WARNING:tensorflow:From D:\Anaconda3\envs\py36_tf2\lib\site-packages\tensorflow\python\data\ops\multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
938/938 [==============================] - 18s 19ms/step - loss: 0.2089 - accuracy: 0.9389
Epoch 2/2
938/938 [==============================] - 19s 20ms/step - loss: 0.0689 - accuracy: 0.9798

<tensorflow.python.keras.callbacks.History at 0x2b4c939e278>

Save and load the model

Now that you have a simple model to work with, let's take a look at the saving/loading APIs. There are two sets of APIs available:

the high-level Keras model.save and tf.keras.models.load_model
the low-level tf.saved_model.save and tf.saved_model.load

Using the Keras APIs

keras_model_path = "./tmp/keras_save"
model.save(keras_model_path)
WARNING:tensorflow:From D:\Anaconda3\envs\py36_tf2\lib\site-packages\tensorflow\python\training\tracking\tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From D:\Anaconda3\envs\py36_tf2\lib\site-packages\tensorflow\python\training\tracking\tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: ./tmp/keras_save\assets

The model was saved successfully. Let's look at the saved files:

└───tmp
    └───keras_save
        ├───assets
        ├───variables
        │   ├───variables.data-00000-of-00001
        │   └───variables.index
        └───saved_model.pb
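As a quick sanity check, a small sketch (not in the original article) using tf.saved_model.contains_saved_model, which reports whether a directory holds a valid SavedModel:

# Returns True if the directory contains a saved_model.pb with a valid layout.
print(tf.saved_model.contains_saved_model(keras_model_path))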

Next, restore the model:

restored_keras_model = tf.keras.models.load_model(keras_model_path)

The restored model can continue training:

restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 16s 17ms/step - loss: 0.0494 - accuracy: 0.0989
Epoch 2/2
938/938 [==============================] - 16s 17ms/step - loss: 0.0353 - accuracy: 0.0989

<tensorflow.python.keras.callbacks.History at 0x2b4c5b5a128>

Now load the model and train it using a tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
    restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
    restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 16s 17ms/step - loss: 0.0501 - accuracy: 0.0990
Epoch 2/2
938/938 [==============================] - 16s 17ms/step - loss: 0.0354 - accuracy: 0.0989
restored_keras_model_ds
<tensorflow.python.keras.engine.sequential.Sequential at 0x2b4c75ea2b0>
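For completeness, here is a small sketch (not part of the original run) of actually running inference with the restored model on one batch of the evaluation data built earlier:

# Take one (images, labels) batch from eval_dataset and predict logits.
for images, labels in eval_dataset.take(1):
    preds = restored_keras_model_ds.predict(images)
    print(preds.shape)  # (batch_size, 10): one logit per MNIST class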

Using the tf.saved_model APIs

Now use the lower-level API. Saving the model is similar to the Keras workflow:

model = get_model()
saved_model_path = "./tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: ./tmp/tf_save\assets

You can load the model back with tf.saved_model.load(). However, since this is a lower-level API (and therefore has a wider range of use cases), it does not return a Keras model. Instead, it returns an object containing functions that can be used for inference. For example:
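The example cell appears to be missing from this export. The following sketch (DEFAULT_FUNCTION_KEY and predict_dataset are defined here for illustration, and are reused by the distributed example below) shows plain, non-distributed inference:

DEFAULT_FUNCTION_KEY = "serving_default"  # the default serving signature name
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

# Build an unlabeled dataset for inference from the evaluation set.
predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
    print(inference_func(batch))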

You can also load and run inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
    loaded = tf.saved_model.load(saved_model_path)
    inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

    dist_predict_dataset = another_strategy.experimental_distribute_dataset(
        predict_dataset)

    # Calling the function in a distributed manner
    for batch in dist_predict_dataset:
        another_strategy.run(inference_func, args=(batch,))
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

Saving checkpoints

Checkpoints capture the exact value of all parameters (tf.Variable objects) used by a model. Checkpoints do not contain any description of the computation defined by the model, so they are typically only useful when the source code that will use the saved parameter values is available.

The SavedModel format, on the other hand, includes a serialized description of the computation defined by the model in addition to the parameter values (checkpoint). Models in this format are independent of the source code that created them. They are therefore suitable for deployment via TensorFlow Serving, TensorFlow Lite, TensorFlow.js, or programs in other programming languages (the C, C++, Java, Go, Rust, C#, and other TensorFlow APIs).

This guide covers the APIs for writing and reading checkpoints.
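As a quick orientation, here is a sketch (assuming model is any tf.keras.Model; the paths are illustrative, not from the original notebook) contrasting the two write APIs:

# Checkpoint: only variable values; the model code is needed to use them.
ckpt = tf.train.Checkpoint(model=model)
ckpt_path = ckpt.save('./tmp/format_demo/ckpt')

# SavedModel: variable values plus a serialized description of the computation.
tf.saved_model.save(model, './tmp/format_demo/saved_model')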

Setup

class Net(tf.keras.Model):
    # A simple linear model

    def __init__(self):
        super(Net, self).__init__()
        self.l1 = tf.keras.layers.Dense(5)

    def call(self, x):
        return self.l1(x)
net = Net()
net.save_weights('./tmp/easy_checkpoint')

Writing checkpoints

The persistent state of a TensorFlow model is stored in tf.Variable objects. These can be constructed directly, but are often created through high-level APIs like tf.keras.layers or tf.keras.Model.

The easiest way to manage variables is by attaching them to Python objects, then referencing those objects.

Subclasses of tf.train.Checkpoint, tf.keras.layers.Layer, and tf.keras.Model automatically track variables assigned to their attributes. The Net class above constructs a simple linear model, and checkpoints can be written which contain values for all of the model's variables.

You can easily save a model checkpoint with Model.save_weights, as was done above.

Manual checkpointing

def toy_dataset():
    inputs = tf.range(10.)[:, None]
    labels = inputs * 5. + tf.range(5.)[None, :]
    return tf.data.Dataset.from_tensor_slices(dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
    with tf.GradientTape() as tape:
        output = net(example['x'])
        loss = tf.reduce_mean(tf.abs(output - example['y']))
    variables = net.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss

Create the checkpoint objects

To make a checkpoint manually you need a tf.train.Checkpoint object, where the objects you want to checkpoint are set as attributes on that object.

A tf.train.CheckpointManager is also helpful for managing multiple checkpoints.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tmp/tf_ckpts', max_to_keep=3)

Train and checkpoint the model

The following training loop uses the model and optimizer instances gathered into the tf.train.Checkpoint object above. It calls the training step in a loop on each batch of data, and periodically writes checkpoints to disk.

def train_and_checkpoint(net, manager):
    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    for _ in range(50):
        example = next(iterator)
        loss = train_step(net, example, opt)
        ckpt.step.assign_add(1)
        if int(ckpt.step) % 10 == 0:
            save_path = manager.save()
            print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
            print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tmp/tf_ckpts\ckpt-1
loss 29.78
Saved checkpoint for step 20: ./tmp/tf_ckpts\ckpt-2
loss 23.19
Saved checkpoint for step 30: ./tmp/tf_ckpts\ckpt-3
loss 16.63
Saved checkpoint for step 40: ./tmp/tf_ckpts\ckpt-4
loss 10.17
Saved checkpoint for step 50: ./tmp/tf_ckpts\ckpt-5
loss 4.09

Restore and continue training

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tmp/tf_ckpts', max_to_keep=3)
train_and_checkpoint(net, manager)
Restored from ./tmp/tf_ckpts\ckpt-10
Saved checkpoint for step 110: ./tmp/tf_ckpts\ckpt-11
loss 0.27
Saved checkpoint for step 120: ./tmp/tf_ckpts\ckpt-12
loss 0.20
Saved checkpoint for step 130: ./tmp/tf_ckpts\ckpt-13
loss 0.16
Saved checkpoint for step 140: ./tmp/tf_ckpts\ckpt-14
loss 0.21
Saved checkpoint for step 150: ./tmp/tf_ckpts\ckpt-15
loss 0.20
print(manager.checkpoints)
['./tmp/tf_ckpts\\ckpt-13', './tmp/tf_ckpts\\ckpt-14', './tmp/tf_ckpts\\ckpt-15']

These paths, e.g. './tmp/tf_ckpts/ckpt-10', are not files on disk. Instead, they are prefixes for an index file and one or more data files which contain the variable values. These prefixes are grouped together in a single checkpoint file ('./tmp/tf_ckpts/checkpoint') where the CheckpointManager saves its state.
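To see the actual files behind these prefixes, a small sketch (standard library only, not from the original notebook):

import os

# One index file and one or more data files per prefix, plus the single
# 'checkpoint' bookkeeping file written by CheckpointManager.
print(sorted(os.listdir('./tmp/tf_ckpts')))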

Manually inspecting checkpoints

tf.train.list_variables lists the checkpoint keys and the shapes of the variables in a checkpoint. The checkpoint keys are paths in the object graph.

tf.train.list_variables(tf.train.latest_checkpoint('./tmp/tf_ckpts'))
[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('iterator/.ATTRIBUTES/ITERATOR_STATE', [1]),
 ('net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('step/.ATTRIBUTES/VARIABLE_VALUE', [])]
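Beyond listing keys, individual values can be read back directly; a hedged sketch using tf.train.load_checkpoint:

# Read one tensor straight from the latest checkpoint, without building the model.
reader = tf.train.load_checkpoint('./tmp/tf_ckpts')
kernel = reader.get_tensor('net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE')
print(kernel.shape)  # (1, 5)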

Saving object-based checkpoints with Estimator

By default, Estimator saves checkpoints with variable names rather than the object graph described in the preceding sections. tf.train.Checkpoint will accept name-based checkpoints, but variable names may change when moving parts of a model outside of the Estimator's model_fn. Saving object-based checkpoints makes it easier to train a model inside an Estimator and then use it outside of one.

import tensorflow.compat.v1 as tf_compat
def model_fn(features, labels, mode):
    net = Net()
    opt = tf.keras.optimizers.Adam(0.1)
    ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                               optimizer=opt, net=net)
    with tf.GradientTape() as tape:
        output = net(features['x'])
        loss = tf.reduce_mean(tf.abs(output - features['y']))
    variables = net.trainable_variables
    gradients = tape.gradient(loss, variables)
    return tf.estimator.EstimatorSpec(
        mode,
        loss=loss,
        train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                          ckpt.step.assign_add(1)),
        scaffold=tf_compat.train.Scaffold(saver=ckpt))
tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tmp/tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tmp/tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From D:\Anaconda3\envs\py36_tf2\lib\site-packages\tensorflow\python\training\training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./tmp/tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.505075, step = 1
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into ./tmp/tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 36.96539.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x2b4ccb5f588>
tf.train.latest_checkpoint('./tmp/tf_estimator_example')
'./tmp/tf_estimator_example\\model.ckpt-10'

tf.train.Checkpoint can then load the Estimator's checkpoints from its model_dir:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tmp/tf_estimator_example/'))
ckpt.step.numpy()
10
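As a final hedged sketch (not in the original notebook), the status object returned by ckpt.restore can verify that the objects you constructed actually found values in the checkpoint:

# Raises if any object built so far had no matching value in the checkpoint.
status = ckpt.restore(tf.train.latest_checkpoint('./tmp/tf_estimator_example/'))
status.assert_existing_objects_matched()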