BERT(TensorFlow 版本)线上预测 Demo

在模型上线预测时,使用 pb 格式的模型,只需确定输入 tensor 和输出 tensor 对应的节点名称即可。以下代码是最近完成的 NER 模型的推理(infer)部分,大家可以参照修改,适配自己的模型。

import os
import pickle
import time

import tensorflow as tf

from bert_crf import tokenization

# Locations of the training artifacts, the exported frozen graph and the
# BERT vocabulary directory.
model_dir = r'crf_output_bak/'
output_graph = './pb_model/query_model.pb'
bert_dir = r'chinese_L-12_H-768_A-12'

# label -> id map pickled at training time; its inverse turns predicted ids
# back into label strings.
with open(os.path.join(model_dir, 'label2id.pkl'), 'rb') as rf:
    label2id = pickle.load(rf)
id2label = {idx: label for label, idx in label2id.items()}

# Ordered list of label strings, also pickled at training time.
with open(os.path.join(model_dir, 'label_list.pkl'), 'rb') as rf:
    label_list = pickle.load(rf)
num_labels = len(label_list)

# Tokenizer used only for its vocabulary lookup (token -> id).
tokenizer = tokenization.FullTokenizer(
    vocab_file=os.path.join(bert_dir, 'vocab.txt'),
    do_lower_case=True,
)


def load_pb_predict(text=None):
    """Load the frozen ``.pb`` NER graph and label ``text`` with it.

    Each sentence is split into single characters and turned into padded
    ``input_ids`` / ``input_mask`` by ``convert``. These are fed to the
    graph's ``input_ids:0`` / ``input_mask:0`` tensors, and the ids read
    from ``viterbi/ReverseSequence_1:0`` are decoded with ``id2label``.

    :param text: list of sentences to label. Defaults to the demo sentence
        ``['深汕特别合作区']``.
    :return: one list of predicted label strings per input sentence.
    """
    if text is None:
        text = ['深汕特别合作区']
    if not text:
        # Avoid loading the graph for an empty batch and dividing by zero
        # in the timing report below.
        return []

    # Character-level tokenisation: every character is one token.
    sentence = [list(str(each)) for each in text]
    input_ids, input_mask = convert(sentence)

    with tf.Graph().as_default():
        # The tf.compat.v1 spellings resolve on later TF 1.x and on TF 2.x
        # (matching the tf.compat.v1.Session used below); the bare
        # tf.GraphDef / tf.import_graph_def names are not in the TF 2.x root.
        output_graph_def = tf.compat.v1.GraphDef()
        with open(output_graph, "rb") as f:
            output_graph_def.ParseFromString(f.read())
        tf.compat.v1.import_graph_def(output_graph_def, name="")
        # Print every node name to help locate the input/output tensors.
        for node in output_graph_def.node:
            print(node.name)

        # No global_variables_initializer here: importing a GraphDef does not
        # populate the variable collections, so it would only run a no-op.
        with tf.compat.v1.Session() as sess:
            input_ids_p = sess.graph.get_tensor_by_name("input_ids:0")
            input_mask_p = sess.graph.get_tensor_by_name("input_mask:0")
            # Presumably the CRF decode output (one tag id per token).
            output_tensor = sess.graph.get_tensor_by_name(
                "viterbi/ReverseSequence_1:0")
            feed_dict = {input_ids_p: input_ids,
                         input_mask_p: input_mask}

            t1 = time.perf_counter()
            out = sess.run(output_tensor, feed_dict)
            pred_label_result = convert_id_to_label(out, id2label)
            elapsed = time.perf_counter() - t1

            num_samples = len(input_ids)
            # Report per-sample latency and throughput (samples/second)
            # as separate, correctly-labelled quantities.
            print('单条平均耗时:{:.6f} 秒'.format(elapsed / num_samples))
            if elapsed > 0:
                print('模型预测吞吐量:{:.2f} 条/秒'.format(num_samples / elapsed))
            print(pred_label_result)
    return pred_label_result


def convert_id_to_label(pred_ids_result, idx2label,
                        skip_labels=('[CLS]', '[SEP]'), pad_id=0):
    """Decode predicted label ids into label strings, one list per row.

    :param pred_ids_result: 2-D iterable of label ids, one row per sample
        (e.g. the array returned by ``sess.run`` on the decode tensor).
    :param idx2label: mapping from label id to label string.
    :param skip_labels: label strings dropped from the output; defaults to
        the BERT ``[CLS]`` / ``[SEP]`` markers.
    :param pad_id: id treated as padding and dropped; defaults to 0.
    :return: list of label-string lists, one per input row.
    :raises KeyError: if a non-padding id is missing from ``idx2label``.
    """
    skip = frozenset(skip_labels)  # built once, O(1) membership per token
    result = []
    for row in pred_ids_result:
        labels = []
        for label_id in row:
            if label_id == pad_id:
                continue
            label = idx2label[label_id]
            if label not in skip:
                labels.append(label)
        result.append(labels)
    return result


def convert(samples, max_seq_length=25):
    """Convert token lists into padded ``input_ids`` / ``input_mask`` batches.

    :param samples: iterable of token lists, one per sentence (single
        characters in this script).
    :param max_seq_length: length every sample is truncated/padded to.
        Defaults to 25, the value previously hard-coded here; presumably it
        must match the sequence length the exported graph expects — TODO
        confirm.
    :return: ``(input_ids_list, input_mask_list)``, each holding one
        ``max_seq_length``-long list per sample.
    """
    input_ids_list = []
    input_mask_list = []
    for index, line in enumerate(samples):
        feature = convert_single_example(index, line, label_list, max_seq_length)
        input_ids_list.append(feature.input_ids)
        input_mask_list.append(feature.input_mask)
    return input_ids_list, input_mask_list


def convert_single_example(ex_index, example, label_list, max_seq_length):
    """Convert one tokenised sample into a padded ``InputFeatures``.

    The tokens are truncated to ``max_seq_length - 2``, wrapped in ``[CLS]``
    and ``[SEP]``, mapped to vocabulary ids and zero-padded to
    ``max_seq_length``.

    :param ex_index: index of the sample; not used by this function.
    :param example: token list for one sentence (one character per token in
        this script).
    :param label_list: label strings; they are numbered starting from 1.
    :param max_seq_length: final length of every feature sequence (>= 2).
    :return: ``InputFeatures`` whose ``input_ids``, ``input_mask``,
        ``segment_ids`` and ``label_ids`` all have length ``max_seq_length``.
    :raises ValueError: if ``max_seq_length`` < 2 (no room for the markers).

    Side effect: when ``<model_dir>/label2id.pkl`` does not exist, the
    label -> id map built here is pickled to that path.
    """
    if max_seq_length < 2:
        raise ValueError(
            'max_seq_length must be >= 2, got {}'.format(max_seq_length))

    # Number labels from 1; id 0 is used for padding below.
    label_map = {label: i for i, label in enumerate(label_list, 1)}
    label2id_path = os.path.join(model_dir, 'label2id.pkl')
    if not os.path.exists(label2id_path):
        with open(label2id_path, 'wb') as w:
            pickle.dump(label_map, w)

    # Truncate, keeping two positions for the [CLS] and [SEP] markers.
    tokens = list(example[:max_seq_length - 2])
    ntokens = ['[CLS]'] + tokens + ['[SEP]']
    segment_ids = [0] * len(ntokens)
    # Only input_ids / input_mask are fed at inference; label_ids keeps the
    # original layout, including the string placeholders for [CLS]/[SEP].
    label_ids = ['[CLS]'] + [0] * len(tokens) + ['[SEP]']

    input_ids = []
    for token in ntokens:
        try:
            input_ids.extend(tokenizer.convert_tokens_to_ids([token]))
        except KeyError:
            # Assumes the project's tokenization raises KeyError for tokens
            # absent from vocab.txt (as Google BERT's convert_by_vocab does);
            # map such characters to [UNK] instead of aborting inference.
            input_ids.extend(tokenizer.convert_tokens_to_ids(['[UNK]']))
    input_mask = [1] * len(input_ids)

    # Zero-pad every sequence up to max_seq_length.
    pad = [0] * (max_seq_length - len(input_ids))
    input_ids += pad
    input_mask += pad
    segment_ids += pad
    label_ids += pad

    assert (len(input_ids) == len(input_mask) == len(segment_ids)
            == len(label_ids) == max_seq_length)

    return InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_ids=label_ids,
    )


class InputFeatures:
    """Plain container for the feature sequences built from one sample.

    Every constructor argument is stored unchanged as an attribute of the
    same name; nothing is copied or validated.
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_ids,
                 is_real_example=True):
        # Per-token id / mask / segment / label sequences.
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_ids = label_ids
        # Flag kept for interface compatibility; not read in this script.
        self.is_real_example = is_real_example


if __name__ == '__main__':
    # Script entry point: run the demo prediction on the default sentence.
    load_pb_predict()
原文地址:https://www.cnblogs.com/demo-deng/p/13625161.html