TensorFlow 模型的保存与调用

TensorFlow 有两种常用的模型保存方式：pb（把变量冻结为常量后的 GraphDef 文件）和 SavedModel，两种方式都可以用来保存并重新调用模型，下面分别演示。

1、pb

1.1 将模型保存为 pb 文件

frozen_pb.py

 1 import tensorflow as tf
 2 from tensorflow.python.framework import graph_util
 3 
 4 
 5 
 6 with tf.Session(graph=tf.Graph()) as sess:
 7     x = tf.placeholder(tf.int32, name='in_x')
 8     y = tf.placeholder(tf.int32, name='in_y')
 9     b = tf.Variable(1, name='b')
10     m = tf.multiply(x, y)
11     a = tf.add(m, b, name='out_add')
12 
13     sess.run(tf.global_variables_initializer())
14     constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['out_add'])
15 
16     feed_dict = {x: 10, y: 3}
17     print(sess.run(a, feed_dict))
18 
19     with tf.gfile.FastGFile('./model.pb', mode='wb') as f:
20         f.write(constant_graph.SerializeToString())

1.2 加载并调用 pb 模型

call_pb.py

 1 import tensorflow as tf
 2 from tensorflow.python.platform import gfile
 3 
 4 
 5 sess = tf.Session()
 6 with gfile.FastGFile('./model.pb', 'rb') as f:
 7     graph_def = tf.GraphDef()
 8     graph_def.ParseFromString(f.read())
 9     sess.graph.as_default()
10     tf.import_graph_def(graph_def, name='')
11 
12 sess.run(tf.global_variables_initializer())
13 #print(sess.run('b:0'))
14 
15 in_x = sess.graph.get_tensor_by_name('in_x:0')
16 in_y = sess.graph.get_tensor_by_name('in_y:0')
17 out_add = sess.graph.get_tensor_by_name('out_add:0')
18 
19 ret = sess.run(out_add, feed_dict={in_x: 8, in_y: 9})
20 print(ret)

2、SavedModel

2.1 将模型保存为 SavedModel

frozen_sm.py

 1 import os
 2 import tensorflow as tf
 3 
 4 saved_model_path = os.getcwd()
 5 
 6 with tf.Session(graph=tf.Graph()) as sess:
 7     x = tf.placeholder(tf.int32, name='in_x')
 8     y = tf.placeholder(tf.int32, name='in_y')
 9     b = tf.Variable(1, name='b')
10     m = tf.multiply(x, y)
11     a = tf.add(m, b, name='out_add')
12 
13     sess.run(tf.global_variables_initializer())
14 
15     tf.saved_model.simple_save(sess, './sm', {'in_x': x, 'in_y': y}, {'out_add': a}, )

2.2 加载并调用 SavedModel 模型

call_sm.py

 1 import tensorflow as tf
 2 
 3 sess = tf.Session()
 4 tf.saved_model.load(sess, [tf.saved_model.tag_constants.SERVING], './sm')
 5 in_x = sess.graph.get_tensor_by_name('in_x:0')
 6 in_y = sess.graph.get_tensor_by_name('in_y:0')
 7 out_add = sess.graph.get_tensor_by_name('out_add:0')
 8 
 9 ret = sess.run(out_add, feed_dict={in_x: 8, in_y: 5})
10 print(ret)
原文地址:https://www.cnblogs.com/vsignsoft/p/14000250.html