TensorFlow 推理图优化

当我们把训练好的 TensorFlow 图直接拿来做预测时，图中会保留许多只在训练阶段才需要的节点（例如梯度计算、优化器和变量初始化相关的节点）。这些节点在推理时是多余的，应当在部署预测模型之前删除，同时把变量固化为常量。

下面以 BERT 的图为例说明优化过程：以推理模式重新构建图，从 checkpoint 加载参数，把变量固化为常量，再调用 `optimize_for_inference` 删除推理时不需要的节点：

    def optimize_graph(self, checkpoint_file, model_config, output_dir='optimize'):
        """Freeze a BERT classification graph for inference and write it to a file.

        Rebuilds the BERT model in inference mode (``is_training=False``),
        loads its variables from ``checkpoint_file``, and appends a softmax
        classification head whose weights are read from the same checkpoint.
        It then converts all variables to constants and runs
        ``optimize_for_inference`` to strip nodes not needed for prediction.

        Args:
            checkpoint_file: Checkpoint path as accepted by
                ``tf.train.NewCheckpointReader``. Besides the BERT variables
                it must contain ``output_weights`` and ``output_bias`` for the
                classification head.
            model_config: Path to the BERT JSON config file.
            output_dir: Directory the serialized graph is written to. It is
                created if missing. Defaults to ``'optimize'``, relative to
                the current working directory. ``None`` or ``''`` uses the
                system temporary directory.

        Returns:
            str: Path of a new file holding the serialized, optimized
            ``GraphDef``. Its inputs are the placeholders ``input_ids``,
            ``input_mask`` and ``input_type_ids``. Its output node is
            ``final_encodes`` (softmax probabilities).
        """
        import json
        import os
        import tempfile

        tf = self.import_tf()
        from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference

        # The freezing session only folds weights, so force CPU execution.
        config = tf.ConfigProto(device_count={'GPU': 0}, allow_soft_placement=True)

        with tf.gfile.GFile(model_config, 'r') as f:
            bert_config = modeling.BertConfig.from_dict(json.load(f))

        # Build into a private graph. tf.trainable_variables() then only sees
        # this model's variables, and repeated calls do not clash with nodes
        # already living in the caller's default graph.
        graph = tf.Graph()
        with graph.as_default():
            input_ids = tf.placeholder(tf.int32, (None, MAX_SEQ_LENGTH), 'input_ids')
            input_mask = tf.placeholder(tf.int32, (None, MAX_SEQ_LENGTH), 'input_mask')
            input_type_ids = tf.placeholder(tf.int32, (None, MAX_SEQ_LENGTH), 'input_type_ids')
            input_tensors = [input_ids, input_mask, input_type_ids]

            model = modeling.BertModel(
                config=bert_config,
                is_training=False,
                input_ids=input_ids,
                input_mask=input_mask,
                token_type_ids=input_type_ids,
                use_one_hot_embeddings=False)

            # Point the model variables' initializers at the checkpoint values,
            # so global_variables_initializer() below loads the trained weights.
            tvars = tf.trainable_variables()
            assignment_map, _ = modeling.get_assignment_map_from_checkpoint(tvars, checkpoint_file)
            tf.train.init_from_checkpoint(checkpoint_file, assignment_map)

            # The classification head is not part of BertModel. Its weights are
            # read from the checkpoint as arrays and embedded as constants.
            # matmul(..., transpose_b=True) suggests output_weights is laid out
            # as (num_labels, hidden_size) -- presumably; confirm against the
            # training code.
            reader = tf.train.NewCheckpointReader(checkpoint_file)
            output_weights = reader.get_tensor('output_weights')
            output_bias = reader.get_tensor('output_bias')
            logits = tf.nn.bias_add(
                tf.matmul(model.get_pooled_output(), output_weights, transpose_b=True),
                output_bias)
            probabilities = tf.identity(tf.nn.softmax(logits), 'final_encodes')
            output_tensors = [probabilities]

            # Graph-level APIs take op names, i.e. tensor names without ':<index>'.
            input_names = [t.op.name for t in input_tensors]
            output_names = [t.op.name for t in output_tensors]

            # Snapshot the graph before the initializer op is added below.
            graph_def = graph.as_graph_def()

            with tf.Session(graph=graph, config=config) as sess:
                sess.run(tf.global_variables_initializer())
                graph_def = tf.graph_util.convert_variables_to_constants(
                    sess, graph_def, output_names)

        optimized = optimize_for_inference(
            graph_def,
            input_names,
            output_names,
            [t.dtype.as_datatype_enum for t in input_tensors],
            False)  # toco_compatible

        target_dir = output_dir or None
        if target_dir is not None:
            os.makedirs(target_dir, exist_ok=True)
        # mkstemp creates the file and hands back an open descriptor. Close it
        # immediately because the graph is written through tf.gfile instead.
        fd, tmp_file = tempfile.mkstemp(dir=target_dir)
        os.close(fd)
        with tf.gfile.GFile(tmp_file, 'wb') as f:
            f.write(optimized.SerializeToString())
        return tmp_file

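该函数返回一个临时文件的路径（字符串），文件中保存的是序列化后、经过优化的 GraphDef。我们可以像导入原来的模型文件那样读取该文件并恢复图（例如用 `tf.GraphDef().ParseFromString` 解析文件内容，再交给 `tf.import_graph_def`），只不过这个图已经过优化。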

原文地址:https://www.cnblogs.com/callyblog/p/10388487.html