Calling a TF Serving model over gRPC (Python + Java)

For TF Serving model deployment, see: https://www.cnblogs.com/bincoding/p/13266685.html
Demo code: https://github.com/haibincoder/tf_tools

The corresponding RESTful request body:

{
    "inputs": {
        "input": [[13, 45, 13, 13, 49, 1, 49, 196, 594, 905, 48, 231, 318, 712, 1003, 477, 259, 291, 287, 161, 65, 62, 82, 68, 2, 10]],
        "drop_out": 1,
        "sequence_length": [26]
        },
    "signature_name":"predict"
}
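
For reference, here is a minimal sketch of sending this body to the RESTful predict endpoint with the Python requests library (the standard TF Serving REST path is assumed, using the REST port 8501 from the metadata URL in the notes below):

import requests

# Standard TF Serving REST predict endpoint (REST port 8501 assumed, model name "ner")
url = 'http://localhost:8501/v1/models/ner:predict'
body = {
    "inputs": {
        "input": [[13, 45, 13, 13, 49, 1, 49, 196, 594, 905, 48, 231, 318, 712,
                   1003, 477, 259, 291, 287, 161, 65, 62, 82, 68, 2, 10]],
        "drop_out": 1,
        "sequence_length": [26]
    },
    "signature_name": "predict"
}
response = requests.post(url, json=body)
print(response.json())  # outputs of the "predict" signature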

Python code:

import grpc
import tensorflow as tf

from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc

# Create the channel and stub (8500 is the gRPC port)
channel = grpc.insecure_channel('localhost:8500')
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

# Model name and signature
request = predict_pb2.PredictRequest()
request.model_spec.name = 'ner'
# request.model_spec.version.value = 1  # pin a specific model version (the latest is served by default)
request.model_spec.signature_name = 'predict'

# Build the request inputs
x_data = [[13, 45, 13, 13, 49, 1, 49, 196, 594, 905, 48, 231, 318, 712, 1003, 477, 259, 291, 287, 161, 65, 62, 82, 68, 2, 10]]
drop_out = 1
sequence_length = [26]
request.inputs['input'].CopyFrom(tf.make_tensor_proto(x_data, dtype=tf.int32))
request.inputs['sequence_length'].CopyFrom(tf.make_tensor_proto(sequence_length, dtype=tf.int32))
request.inputs['drop_out'].CopyFrom(tf.make_tensor_proto(drop_out, dtype=tf.float32))

# The CRF model returns the emission probability matrix and the transition probability matrix
result = stub.Predict(request, 10.0)  # 10 secs timeout

print(result)
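
Continuing from the result above, the PredictResponse carries each output as a TensorProto; a minimal sketch of turning them into numpy arrays (the output key names depend on the model's signature, so the sketch simply iterates over whatever keys come back):

# Each entry in result.outputs is a TensorProto; tf.make_ndarray converts it to a numpy array
for name, tensor_proto in result.outputs.items():
    array = tf.make_ndarray(tensor_proto)
    print(name, array.shape)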

Java pom.xml dependencies:

<dependencies>
    <dependency>
        <groupId>com.yesup.oss</groupId>
        <artifactId>tensorflow-client</artifactId>
        <version>1.4-2</version>
    </dependency>
    <dependency>
        <groupId>io.grpc</groupId>
        <artifactId>grpc-netty-shaded</artifactId>
        <version>1.14.0</version>
    </dependency>
    <dependency>
        <groupId>io.grpc</groupId>
        <artifactId>grpc-protobuf</artifactId>
        <version>1.14.0</version>
    </dependency>
    <dependency>
        <groupId>io.grpc</groupId>
        <artifactId>grpc-stub</artifactId>
        <version>1.14.0</version>
    </dependency>
</dependencies>

Java code:

public static void main(String[] args) {
        ManagedChannel channel = ManagedChannelBuilder.forAddress("localhost", 8500).usePlaintext(true).build();
        // Use the blocking stub here
        PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);
        // Build the request
        Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder();
        // Set the model name and signature name
        Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
        modelSpecBuilder.setName("ner");
        modelSpecBuilder.setSignatureName("predict");
        predictRequestBuilder.setModelSpec(modelSpecBuilder);
        // Build the inputs; the latest model version is served by default,
        // a specific version can be pinned on the ModelSpec builder
        List<Integer> input = Arrays.asList(13, 45, 13, 13, 49, 1, 49, 196, 594, 905, 48, 231, 318, 712, 1003, 477, 259, 291, 287, 161, 65, 62, 82, 68, 2, 10);
        TensorProto.Builder inputTensorProto = TensorProto.newBuilder();
        inputTensorProto.setDtype(DataType.DT_INT32);
        // DT_INT32 tensors carry their values in the int_val field, not float_val
        inputTensorProto.addAllIntVal(input);
        TensorShapeProto.Builder inputShapeBuilder = TensorShapeProto.newBuilder();
        inputShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
        inputShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(input.size()));
        inputTensorProto.setTensorShape(inputShapeBuilder.build());
        
        float dropout = 1.0f;
        TensorProto.Builder dropoutTensorProto = TensorProto.newBuilder();
        dropoutTensorProto.setDtype(DataType.DT_FLOAT);
        // DT_FLOAT tensors carry their values in the float_val field
        dropoutTensorProto.addFloatVal(dropout);
        TensorShapeProto.Builder dropoutShapeBuilder = TensorShapeProto.newBuilder();
        dropoutShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
        dropoutTensorProto.setTensorShape(dropoutShapeBuilder.build());

        List<Integer> seqLength = Collections.singletonList(26);
        TensorProto.Builder seqLengthTensorProto = TensorProto.newBuilder();
        seqLengthTensorProto.setDtype(DataType.DT_INT32);
        seqLengthTensorProto.addAllIntVal(seqLength);
        TensorShapeProto.Builder seqLengthShapeBuilder = TensorShapeProto.newBuilder();
        seqLengthShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
        seqLengthTensorProto.setTensorShape(seqLengthShapeBuilder.build());

        predictRequestBuilder.putInputs("input", inputTensorProto.build());
        predictRequestBuilder.putInputs("drop_out", dropoutTensorProto.build());
        predictRequestBuilder.putInputs("sequence_length", seqLengthTensorProto.build());

        // Send the request and read the result
        Predict.PredictResponse predictResponse = stub.withDeadlineAfter(3, TimeUnit.SECONDS).predict(predictRequestBuilder.build());
        Map<String, TensorProto> result = predictResponse.getOutputsMap();
        // CRF model outputs: the emission probability matrix and the transition probability matrix
        System.out.println("Prediction: " + result.toString());
    }

Notes:

  1. The dtypes in the request must match the dtypes defined in the model; the model's expected inputs can be checked on the TF Serving metadata page (see the sketch below).
    Otherwise TF Serving returns an error such as: Expects arg[0] to be float but int32 is provided
    TF Serving RESTful metadata page: http://localhost:8501/v1/models/ner/metadata
    TF Serving deployment guide: https://www.cnblogs.com/bincoding/p/13266685.html
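
A minimal sketch of reading the expected dtypes from that metadata endpoint with the requests library (the key layout follows the standard TF Serving metadata response and may differ slightly between versions):

import requests

# Fetch the model metadata and print the dtype declared for each input of the "predict" signature
meta = requests.get('http://localhost:8501/v1/models/ner/metadata').json()
signature = meta['metadata']['signature_def']['signature_def']['predict']
for name, info in signature['inputs'].items():
    print(name, info['dtype'])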
Original article: https://www.cnblogs.com/bincoding/p/13274948.html