TensorFlow C++ API 加载 .pb 模型文件并预测图片

用 TensorFlow Python 创建并训练模型，导出 .pb 模型文件后，再用 C++ API 加载该模型进行预测。

  1 #include <iostream>
  2 #include <map>
  3 
  4 #include "tensorflow/cc/ops/image_ops.h"
  5 #include "tensorflow/cc/ops/standard_ops.h"
  6 #include "tensorflow/core/framework/graph.pb.h"
  7 #include "tensorflow/core/framework/tensor.h"
  8 #include "tensorflow/core/graph/default_device.h"
  9 #include "tensorflow/core/graph/graph_def_builder.h"
 10 #include "tensorflow/core/platform/logging.h"
 11 #include "tensorflow/core/platform/types.h"
 12 #include "tensorflow/core/public/session.h"
 13 
 14 using namespace std ;
 15 using namespace tensorflow;
 16 using tensorflow::Tensor;
 17 using tensorflow::Status;
 18 using tensorflow::string;
 19 using tensorflow::int32;
 20 
 21 
 22 //从文件名中读取数据
 23 Status ReadTensorFromImageFile(string file_name, const int input_height,
 24                                const int input_width,
 25                                vector<Tensor>* out_tensors) {
 26     auto root = Scope::NewRootScope();
 27     using namespace ops;
 28 
 29     auto file_reader = ops::ReadFile(root.WithOpName("file_reader"),file_name);
 30     const int wanted_channels = 1;
 31     Output image_reader;
 32     std::size_t found = file_name.find(".png");
 33     //判断文件格式
 34     if (found!=std::string::npos) {
 35         image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,DecodePng::Channels(wanted_channels));
 36     } 
 37     else {
 38         image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,DecodeJpeg::Channels(wanted_channels));
 39     }
 40     // 下面几步是读取图片并处理
 41     auto float_caster =Cast(root.WithOpName("float_caster"), image_reader, DT_FLOAT);
 42     auto dims_expander = ExpandDims(root, float_caster, 0);
 43     auto resized = ResizeBilinear(root, dims_expander,Const(root.WithOpName("resize"), {input_height, input_width}));
 44     // Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}),{input_std});
 45     Transpose(root.WithOpName("transpose"),resized,{0,2,1,3});
 46 
 47     GraphDef graph;
 48     root.ToGraphDef(&graph);
 49 
 50     unique_ptr<Session> session(NewSession(SessionOptions()));
 51     session->Create(graph);
 52     session->Run({}, {"transpose"}, {}, out_tensors);//Run,获取图片数据保存到Tensor中
 53 
 54     return Status::OK();
 55 }
 56 
 57 int main(int argc, char* argv[]) {
 58 
 59     string graph_path = "aov_crnn.pb";
 60     GraphDef graph_def;
 61     //读取模型文件
 62     if (!ReadBinaryProto(Env::Default(), graph_path, &graph_def).ok()) {
 63         cout << "Read model .pb failed"<<endl;
 64         return -1;
 65     }
 66 
 67     //新建session
 68     unique_ptr<Session> session;
 69     SessionOptions sess_opt;
 70     sess_opt.config.mutable_gpu_options()->set_allow_growth(true);
 71     (&session)->reset(NewSession(sess_opt));
 72     if (!session->Create(graph_def).ok()) {
 73         cout<<"Create graph failed"<<endl;
 74         return -1;
 75     }
 76 
 77     //读取图像到inputs中
 78     int input_height = 40;
 79     int input_width = 240;
 80     vector<Tensor> inputs;
 81     // string image_path(argv[1]);
 82     string image_path("test.jpg");
 83     if (!ReadTensorFromImageFile(image_path, input_height, input_width,&inputs).ok()) {
 84         cout<<"Read image file failed"<<endl;
 85         return -1;
 86     }
 87 
 88     vector<Tensor> outputs;
 89     string input = "inputs_sq";
 90     string output = "results_sq";//graph中的输入节点和输出节点,需要预先知道
 91 
 92     pair<string,Tensor>img(input,inputs[0]);
 93     Status status = session->Run({img},{output}, {}, &outputs);//Run,得到运行结果,存到outputs中
 94     if (!status.ok()) {
 95         cout<<"Running model failed"<<endl;
 96         cout<<status.ToString()<<endl;
 97         return -1;
 98     }
 99 
100 
101     //得到模型运行结果
102     Tensor t = outputs[0];        
103     auto tmap = t.tensor<int64, 2>(); 
104     int output_dim = t.shape().dim_size(1); 
105 
106 
107     return 0;
108 }
# Build the program above (saved as tf_predict.cpp). The -I / -L paths reflect the author's
# install layout (Eigen headers in /usr/include/eigen3, TensorFlow headers in
# /usr/local/include/tf, libtensorflow_cc / libtensorflow_framework in /usr/local/lib);
# adjust them to where these are installed on your machine.
g++ -g  tf_predict.cpp -o tf_predict -I /usr/include/eigen3 -I /usr/local/include/tf  -L/usr/local/lib/ `pkg-config --cflags --libs protobuf`  -ltensorflow_cc  -ltensorflow_framework

也可以用 OpenCV C++ 库把图片读成 cv::Mat，再把像素逐个复制到 Tensor 中：

 1 tensorflow::Tensor readTensor(string filename){
 2     tensorflow::Tensor input_tensor(DT_FLOAT,TensorShape({1,240,40,1}));
 3     Mat src=imread(filename,0);
 4     Mat dst;
 5     resize(src,dst,Size(240,40));//resize
 6     Mat dst_transpose=dst.t();//transpose
 7 
 8     auto tmap=input_tensor.tensor<float,4>();
 9 
10     for(int i=0;i<240;i++){//Mat复制到Tensor
11         for(int j=0;j<40;j++){
12             tmap(0,i,j,0)=dst_transpose.at<uchar>(i,j);
13         }
14     }
15 
16     return input_tensor;
17 }

也可以让 cv::Mat 直接引用 Tensor 的内存（指针方式），由 convertTo 把像素写进 Tensor，省去逐像素复制：

1             tensorflow::Tensor input_tensor(DT_FLOAT,TensorShape({1,height,width,3}));
2         float *tensor_data_ptr = input_tensor.flat<float>().data();              
3         cv::Mat fake_mat(dst.rows, dst.cols, CV_32FC(src.channels()), tensor_data_ptr); 
4         dst.convertTo(fake_mat, CV_32FC3);
原文地址:https://www.cnblogs.com/buyizhiyou/p/10412967.html