TensorFlow模型存储为PB形式【BILSTM+CRF 】

## TensorFlow模型存储为PB形式【BILSTM+CRF 】
为什么要采用SavedModel格式呢?其主要优点是SaveModel与语言无关,比如可以使用python语言训练模型,然后在Java中非常方便的加载模型。当然这也不是说checkpoints模型格式做不到,只是在跨语言时比较麻烦。另外如果使用Tensorflow Serving server来部署模型,必须选择SavedModel格式。

**SavedModel包含啥?**
一个比较完整的SavedModel模型包含以下内容:

> assets/
> assets.extra/
> variables/
>         - variables.data-*****-of-*****
>         - variables.index
> saved_model.pb

saved_model.pb是MetaGraphDef,它包含图形结构。
variables文件夹保存训练所习得的权重。
assets文件夹可以添加可能需要的外部文件,assets.extra是一个库可以添加其特定assets的地方。


MetaGraph是一个数据流图,加上其相关的变量、assets和签名。MetaGraphDef是MetaGraph的Protocol Buffer表示。

assets和assets.extra是可选的,比如本次保存的模型只包含以下的内容:

> variables/
>        variables.data-*****-of-*****
>        variables.index
> saved_model.pb

一、 训练结果checkPoint转存为PB
1.读取checkPoint数据

graph2 = tf.Graph()
with graph2.as_default():
m = BiLSTM_CRF(args)
saver = tf.train.import_meta_graph("{}model-35640.meta".format('model_path'))
with tf.Session(graph=graph2) as session:
saver.restore(session, tf.train.latest_checkpoint('model_path')) #加载ckpt模型
export_model(session, m)

2.转存为PB形式
**参数定义**

a = session.graph.get_tensor_by_name("a:0")
b = session.graph.get_tensor_by_name("b:0")
c = session.graph.get_tensor_by_name("c:0")
d = session.graph.get_tensor_by_name("d:0")

x = session.graph.get_tensor_by_name('x:0')
y = session.graph.get_tensor_by_name('y:0')
**结构定义**

**signature对象**,这个对象包含了计算图中输入与输出张量的键值对信息,键即是张量名,值即是protobuff结构的张量。

prediction_signature = signature_def_utils.build_signature_def(
inputs={"a": utils.build_tensor_info(a), # 将张量转为protobuff结构的快捷方法,也就是说下面的输入abcd 以及输出 x y都是经过该函数处理之后的结果。
"b": utils.build_tensor_info(b), # Protobuf是一种平台无关、语言无关、可扩展且轻便高效的序列化数据结构的协议,可以用于网络通信和数据存储。
"c": utils.build_tensor_info(c),
"d": utils.build_tensor_info(d)},

outputs={
"x": utils.build_tensor_info(x),
"y": utils.build_tensor_info(y)},

method_name=signature_constants.PREDICT_METHOD_NAME)

export_path = 'result_path'
if os.path.exists(export_path):
os.system("rm -rf " + export_path)
print("Export the model to {}".format(export_path))


**图定义、存储**

try:
legacy_init_op = tf.group(
tf.tables_initializer(), name='legacy_init_op')
builder = saved_model_builder.SavedModelBuilder(export_path)

#可以自己定义tag,在签名的定义上更加灵活。
一个模型可以包含不同的MetaGraphDef,保存图形的CPU版本和GPU版本,或者你想区分训练和发布版本。这个时候tag就可以用来区分不同的MetaGraphDef,加载的时候能够根据tag来加载模型的不同计算图。

builder.add_meta_graph_and_variables(
session, [tag_constants.SERVING], # 系统会给一个默认的tag: “serve”,也可以用tag_constants.SERVING这个常量。
clear_devices=True,
signature_def_map={
'predict_images':
prediction_signature,
},
# legacy_init_op=legacy_init_op,
main_op=tf.tables_initializer(),
strip_default_attrs=True
)


builder.save()
print('Done exporting!')

except Exception as e:
print("Fail to export saved model, exception: {}".format(e))


2.加载PB Model

session = tf.Session(graph=tf.Graph())
model_file_path = 'data_path_save'
meta_graph = tf.saved_model.loader.load(session, [tf.saved_model.tag_constants.SERVING], model_file_path)
model_graph_signature = list(meta_graph.signature_def.items())[0][1] #存在多个图结构时
output_tensor_names = []
output_op_names = []
for output_item in model_graph_signature.outputs.items():
output_op_name = output_item[0]
output_op_names.append(output_op_name)
output_tensor_name = output_item[1].name
output_tensor_names.append(output_tensor_name)
print("load model finish!")


3.进行预测


sentences = {}
# 测试pb模型
for test_x in [['周杰伦爱喝奶茶'],['习主席在陕西'],['刘德华的老婆是不是朱丽倩'],['张学友的老婆是谁咱不知道'],['解放牌卡车是不是在北京生产的']]:
sentences, seq_len_list = _preprocess(test_x)

feed_dict_map = {}
for input_item in model_graph_signature.inputs.items():
input_op_name = input_item[0]
input_tensor_name = input_item[1].name
feed_dict_map[input_tensor_name] = sentences[input_op_name]

logits, transition_params = session.run(output_tensor_names, feed_dict=feed_dict_map)

tag = predict(logits,transition_params, seq_len_list)

label2tag = {}
for tag, label in tag2label.items():
label2tag[label] = tag if label != 0 else label

tag = [label2tag[label] for label in label_list[0]]

PER, LOC, ORG = get_entity(tag, list(''.join(test_x).strip()))
print('PER: {} LOC: {} ORG: {}'.format(PER, LOC, ORG))

原文地址:https://www.cnblogs.com/hanhaotian/p/12875695.html