Saver 保存与读取

tensorflow 框架下的Saver 功能,用以保存和读取运算数据

Saver 保存数据

代码

import tensorflow as tf

# Save to file
#remember to define the same dtype and shape when restore
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')


init = tf.global_variables_initializer()

saver = tf.train.Saver()

with tf.Session() as sess:
   sess.run(init)
   save_path = saver.save(sess, "my_net/save_net.ckpt")
   print("Save to path: ", save_path)

生成文件截图

运行代码,会在当前目录下生成一个新文件夹my_net ,文件夹内容有

解释

  1. 这里我们定义了两个张量,2行3列的W,和1行3列的b。这里强调行列形状 ,原因是只有存储张量的形状和读取时张量形状相同,才能被读取成功。

  2. 并且这里的W和b都定义了name ,name是读取时候对应变量的关键 --'weights'和'biases'。和张量符号W和b没什么关系。

  3. 定义文件扩展名为ckpt ,因为官方是这样定义的。

Saver 读取数据

import tensorflow as tf

W = tf.Variable(tf.zeros([2,3]), dtype=tf.float32, name="weights")
b = tf.Variable(tf.zeros([1,3]), dtype=tf.float32, name="biases")

# not need init step

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "my_net/save_net.ckpt")
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))

打印结果

可以看到,读取数据文件代码定义张量W和b为全0,经过Saver 读取处理后,张量的数值成为保存文件中的数值。

解释

  1. 用Saver从文件读取,然后把读到的张量自动赋值给name相同 的张量

  2. 注意在读取代码中,张量被定义,但是没有初始化环节(sess.run(init)这一步) ,因为读取文件中的张量已经被初始化过了,这里就不用了

原文地址:https://www.cnblogs.com/maskerk/p/9984179.html