
Commit 0300796

Merge pull request tobegit3hub#26 from justinsu/master
add sparse_predict_client.cc
2 parents c29f742 + ee822ad

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
/* A C++ version of sparse_predict_client.
 * Build it like inception_client.cc.
 *=====================================================*/
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include <grpc++/create_channel.h>
#include "tensorflow_serving/apis/prediction_service.grpc.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/util/command_line_flags.h"

using grpc::Channel;
using grpc::ClientContext;
using grpc::Status;

using tensorflow::serving::PredictRequest;
using tensorflow::serving::PredictResponse;
using tensorflow::serving::PredictionService;

typedef google::protobuf::Map<std::string, tensorflow::TensorProto> OutMap;

class ServingClient {
 public:
  ServingClient(std::shared_ptr<Channel> channel)
      : stub_(PredictionService::NewStub(channel)) {}

  std::string callPredict(std::string model_name) {
    PredictRequest predictRequest;
    PredictResponse response;
    ClientContext context;

    predictRequest.mutable_model_spec()->set_name(model_name);

    google::protobuf::Map<std::string, tensorflow::TensorProto>& inputs =
        *predictRequest.mutable_inputs();

    // Example libSVM data (two instances, 14 non-zero features each):
    // 0 5:1 6:1 17:1 21:1 35:1 40:1 53:1 63:1 71:1 73:1 74:1 76:1 80:1 83:1
    // 1 5:1 7:1 17:1 22:1 36:1 40:1 51:1 63:1 67:1 73:1 74:1 76:1 81:1 83:1
    // The input names below ("keys", "indexs", "ids", "values", "shape")
    // must match the signature of the exported sparse model.

    // Generate the keys TensorProto: one int32 key per instance.
    tensorflow::TensorProto keys_tensor_proto;
    keys_tensor_proto.set_dtype(tensorflow::DataType::DT_INT32);
    keys_tensor_proto.add_int_val(1);
    keys_tensor_proto.add_int_val(2);
    keys_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);

    inputs["keys"] = keys_tensor_proto;

    // Generate the indexs TensorProto: the [instance, position] index of
    // each of the 28 non-zero entries, shaped [28, 2].
    tensorflow::TensorProto indexs_tensor_proto;
    indexs_tensor_proto.set_dtype(tensorflow::DataType::DT_INT64);
    long indexs[28][2] = { {0, 0}, {0, 1}, {0, 2}, {0, 3}, {0, 4}, {0, 5},
                           {0, 6}, {0, 7}, {0, 8}, {0, 9}, {0, 10}, {0, 11},
                           {0, 12}, {0, 13}, {1, 0}, {1, 1}, {1, 2}, {1, 3},
                           {1, 4}, {1, 5}, {1, 6}, {1, 7}, {1, 8}, {1, 9},
                           {1, 10}, {1, 11}, {1, 12}, {1, 13} };
    for (int i = 0; i < 28; i++) {
      for (int j = 0; j < 2; j++) {
        indexs_tensor_proto.add_int64_val(indexs[i][j]);
      }
    }
    indexs_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(28);
    indexs_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);

    inputs["indexs"] = indexs_tensor_proto;
    std::cout << "Generated indexs TensorProto." << std::endl;

    // Generate the ids TensorProto: the feature id of each non-zero entry,
    // taken from the libSVM lines above.
    tensorflow::TensorProto ids_tensor_proto;
    ids_tensor_proto.set_dtype(tensorflow::DataType::DT_INT64);
    int ids[28] = {5, 6, 17, 21, 35, 40, 53, 63, 71, 73, 74, 76, 80, 83, 5,
                   7, 17, 22, 36, 40, 51, 63, 67, 73, 74, 76, 81, 83};
    for (int i = 0; i < 28; i++) {
      ids_tensor_proto.add_int64_val(ids[i]);
    }
    ids_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(28);

    inputs["ids"] = ids_tensor_proto;
    std::cout << "Generated ids TensorProto." << std::endl;

    // Generate the values TensorProto: the value of each non-zero entry
    // (all 1.0 in this example).
    tensorflow::TensorProto values_tensor_proto;
    values_tensor_proto.set_dtype(tensorflow::DataType::DT_FLOAT);
    float values[28] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
                        1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
                        1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
    for (int i = 0; i < 28; i++) {
      values_tensor_proto.add_float_val(values[i]);
    }
    values_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(28);

    inputs["values"] = values_tensor_proto;
    std::cout << "Generated values TensorProto." << std::endl;

    // Generate the shape TensorProto: the dense shape of the sparse input.
    tensorflow::TensorProto shape_tensor_proto;
    shape_tensor_proto.set_dtype(tensorflow::DataType::DT_INT64);
    shape_tensor_proto.add_int64_val(2);    // instance num
    shape_tensor_proto.add_int64_val(124);  // feature num
    shape_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);

    inputs["shape"] = shape_tensor_proto;
    std::cout << "Generated shape TensorProto." << std::endl;

    Status status = stub_->Predict(&context, predictRequest, &response);

    std::cout << "Checking status..." << std::endl;

    if (status.ok()) {
      std::cout << "call predict ok" << std::endl;
      std::cout << "outputs size is " << response.outputs_size() << std::endl;
      OutMap& map_outputs = *response.mutable_outputs();
      OutMap::iterator iter;
      int output_index = 0;

      for (iter = map_outputs.begin(); iter != map_outputs.end(); ++iter) {
        tensorflow::TensorProto& result_tensor_proto = iter->second;
        tensorflow::Tensor tensor;
        bool converted = tensor.FromProto(result_tensor_proto);
        if (converted) {
          std::cout << "the " << iter->first << " result tensor["
                    << output_index << "] is:" << std::endl
                    << tensor.SummarizeValue(13) << std::endl;
        } else {
          std::cout << "the " << iter->first << " result tensor["
                    << output_index << "] convert failed." << std::endl;
        }
        ++output_index;
      }
      return "Done.";
    } else {
      std::cout << "gRPC call return code: " << status.error_code() << ": "
                << status.error_message() << std::endl;
      return "gRPC failed.";
    }
  }

 private:
  std::unique_ptr<PredictionService::Stub> stub_;
};

int main(int argc, char** argv) {
  std::string server_port = "localhost:9000";
  std::string model_name = "sparse";
  std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag("server_port", &server_port,
                       "the IP and port of the server"),
      tensorflow::Flag("model_name", &model_name, "name of model")};
  std::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    std::cout << usage;
    return -1;
  }

  ServingClient guide(grpc::CreateChannel(
      server_port, grpc::InsecureChannelCredentials()));
  std::cout << "Calling sparse predictor..." << std::endl;
  std::cout << guide.callPredict(model_name) << std::endl;

  return 0;
}
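
The client hard-codes the indexs/ids/values arrays that correspond to the two libSVM lines in the comment. For real data, the same three parallel arrays can be filled by parsing libSVM text directly. Below is a minimal sketch under that assumption; the helper name ParseLibSvmLine is illustrative and not part of this commit.

#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical helper: parse one libSVM line ("label id:value id:value ...")
// and append one instance's entries to the parallel arrays that
// callPredict() copies into the indexs/ids/values TensorProtos.
void ParseLibSvmLine(const std::string& line, long row,
                     std::vector<std::pair<long, long> >* indexs,
                     std::vector<long>* ids, std::vector<float>* values) {
  std::istringstream iss(line);
  std::string token;
  iss >> token;  // Skip the leading label.
  long position = 0;
  while (iss >> token) {
    std::size_t colon = token.find(':');
    if (colon == std::string::npos) continue;  // Skip malformed tokens.
    indexs->push_back(std::make_pair(row, position++));     // [instance, position]
    ids->push_back(std::stol(token.substr(0, colon)));      // feature id
    values->push_back(std::stof(token.substr(colon + 1)));  // feature value
  }
}

Each parsed row then feeds add_int64_val()/add_float_val() exactly as the hard-coded arrays do, with the shape proto set to {instance num, feature num}. As the header comment says, the binary is built the same way as inception_client.cc; once built it can be invoked as, e.g., ./sparse_predict_client --server_port=localhost:9000 --model_name=sparse (both flags default to the values set in main()).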
