Skip to content

Commit a92f4f3

Browse files
committed
Model run is invoked with call operator
1 parent 22e0bc4 commit a92f4f3

File tree

8 files changed

+99
-44
lines changed

8 files changed

+99
-44
lines changed

examples/efficientnet/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@ find_library(TENSORFLOW_LIB tensorflow HINT ../../libtensorflow2/lib)
55

66
set(CMAKE_CXX_STANDARD 17)
77

8-
add_executable(example main.cpp)
8+
add_executable(example main.cpp ../../include/cppflow/cppflow.h)
99
target_include_directories(example PRIVATE ../../include ../../libtensorflow2/include)
1010
target_link_libraries (example "${TENSORFLOW_LIB}")

examples/efficientnet/main.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
#include <iostream>
22

3-
#include "cppflow/ops.h"
4-
#include "cppflow/model.h"
3+
#include "cppflow/cppflow.h"
54

65

76
int main() {
87

98
auto input = cppflow::decode_jpeg(cppflow::read_file(std::string("../my_cat.jpg")));
109
input = cppflow::cast(input, TF_UINT8, TF_FLOAT);
1110
input = cppflow::expand_dims(input, 0);
12-
cppflow::model m("../model");
13-
auto output = m.run(input);
11+
cppflow::model model("../model");
12+
auto output = model(input);
1413

1514
std::cout << "It's a tiger cat: " << cppflow::arg_max(output, 1) << std::endl;
1615

examples/load_model/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ project(example)
44
find_library(TENSORFLOW_LIB tensorflow HINT ../../libtensorflow2/lib)
55

66
set(CMAKE_CXX_STANDARD 17)
7-
set(CMAKE_CXX_FLAGS "-g -O3 -Wall -Wextra")
8-
97

108
add_executable(example main.cpp)
119
target_include_directories(example PRIVATE ../../include ../../libtensorflow2/include)

examples/load_model/main.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
int main() {
88

99
auto input = cppflow::fill({10, 5}, 1.0f);
10-
cppflow::model m("../model");
11-
auto output = m.run(input);
10+
cppflow::model model("../model");
11+
auto output = model(input);
1212

1313
std::cout << output << std::endl;
14-
14+
1515
return 0;
1616
}

include/cppflow/cppflow.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//
2+
// Created by serizba on 17/9/20.
3+
//
4+
5+
#ifndef EXAMPLE_CPPFLOW_H
6+
#define EXAMPLE_CPPFLOW_H
7+
8+
#include "tensor.h"
9+
#include "model.h"
10+
#include "raw_ops.h"
11+
#include "ops.h"
12+
#include "datatype.h"
13+
14+
#include <tensorflow/c/c_api.h>
15+
16+
namespace cppflow {
17+
18+
/**
19+
* Version of TensorFlow and CppFlow
20+
* @return A string containing the version of TensorFow and CppFlow
21+
*/
22+
std::string version();
23+
24+
}
25+
26+
/******************************
27+
* IMPLEMENTATION DETAILS *
28+
******************************/
29+
30+
namespace cppflow {
31+
std::string version() {
32+
return "TensorFlow: " + std::string(TF_Version()) + " CppFlow: 2.0.0";
33+
}
34+
}
35+
36+
#endif //EXAMPLE_CPPFLOW_H

include/cppflow/model.h

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,34 +23,10 @@ namespace cppflow {
2323
std::vector<std::string> get_operations() const;
2424

2525
//std::vector<tensor> operator()(std::vector<std::tuple<std::string, tensor>> inputs, std::vector<tensor> outputs);
26-
//std::vector<tensor> operator()(s)
2726

28-
tensor run(const tensor& input) {
27+
// At the moment only default run with one input and one output is implemented
28+
tensor operator()(const tensor& input);
2929

30-
auto inputs = new TF_Output[1];
31-
inputs[0].oper = TF_GraphOperationByName(this->graph, "serving_default_input_1");
32-
inputs[0].index = 0;
33-
34-
TF_Output op2[1];
35-
op2[0].oper = TF_GraphOperationByName(this->graph, "StatefulPartitionedCall");
36-
op2[0].index = 0;
37-
38-
39-
// TRY TO RUN
40-
//********* Allocate data for inputs & outputs
41-
auto inp_tensor = TFE_TensorHandleResolve(input.tfe_handle.get(), context::get_status());
42-
status_check(context::get_status());
43-
44-
45-
TF_Tensor* inpvals[1] = {inp_tensor};
46-
TF_Tensor* outvals[1] = {nullptr};
47-
48-
49-
TF_SessionRun(this->session, NULL, inputs, inpvals, 1, op2, outvals, 1, NULL, 0,NULL , context::get_status());
50-
status_check(context::get_status());
51-
52-
return tensor(outvals[0]);
53-
}
5430

5531
private:
5632

@@ -90,6 +66,31 @@ namespace cppflow {
9066
}
9167
return result;
9268
}
69+
70+
tensor model::operator()(const tensor& input) {
71+
auto inputs = new TF_Output[1];
72+
inputs[0].oper = TF_GraphOperationByName(this->graph, "serving_default_input_1");
73+
inputs[0].index = 0;
74+
75+
TF_Output op2[1];
76+
op2[0].oper = TF_GraphOperationByName(this->graph, "StatefulPartitionedCall");
77+
op2[0].index = 0;
78+
79+
80+
//********* Allocate data for inputs & outputs
81+
auto inp_tensor = TFE_TensorHandleResolve(input.tfe_handle.get(), context::get_status());
82+
status_check(context::get_status());
83+
84+
85+
TF_Tensor* inpvals[1] = {inp_tensor};
86+
TF_Tensor* outvals[1] = {nullptr};
87+
88+
89+
TF_SessionRun(this->session, NULL, inputs, inpvals, 1, op2, outvals, 1, NULL, 0,NULL , context::get_status());
90+
status_check(context::get_status());
91+
92+
return tensor(outvals[0]);
93+
}
9394
}
9495

9596
#endif //CPPFLOW2_MODEL_H

include/cppflow/ops.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ namespace cppflow {
7373
}
7474

7575
std::ostream& operator<<(std::ostream& os, const cppflow::tensor& t) {
76-
os << to_string(t);
77-
return os;
76+
std::string res = to_string(t);
77+
return os << res;
7878
}
7979

8080

@@ -83,9 +83,17 @@ namespace cppflow {
8383
auto res_tensor_h = TFE_TensorHandleResolve(res_tensor.tfe_handle.get(), context::get_status());
8484
status_check(context::get_status());
8585

86-
auto *t_str = static_cast<TF_TString *>(TF_TensorData(res_tensor_h));
86+
// For future version TensorFlow 2.4
87+
//auto *t_str = reinterpret_cast<TF_TString *>(TF_TensorData(res_tensor_h));
88+
//auto *t_str = (TF_TString *)(TF_TensorData(res_tensor_h));
89+
//auto result = std::string(TF_TString_GetDataPointer(t_str), TF_TString_GetSize(t_str));
90+
91+
const char* dst[1] = {nullptr};
92+
size_t dst_len[1] = {3};
93+
TF_StringDecode(static_cast<char*>(TF_TensorData(res_tensor_h)) + 8, TF_TensorByteSize(res_tensor_h), dst, dst_len, context::get_status());
94+
status_check(context::get_status());
95+
auto result = std::string(dst[0], *dst_len);
8796

88-
auto result = std::string(TF_TString_GetDataPointer(t_str), TF_TString_GetSize(t_str));
8997
TF_DeleteTensor(res_tensor_h);
9098

9199
return result;

include/cppflow/tensor.h

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,26 @@ namespace cppflow {
125125
tensor::tensor(const T& value) :
126126
tensor(std::vector<T>({value}), {}) {}
127127

128+
// For future version TensorFlow 2.4
129+
//template<>
130+
//tensor::tensor(const std::string& value) {
131+
// TF_TString tstr[1];
132+
// TF_TString_Init(&tstr[0]);
133+
// TF_TString_Copy(&tstr[0], value.c_str(), value.size());
134+
//
135+
// *this = tensor(static_cast<enum TF_DataType>(TF_STRING), (void *) tstr, TF_TString_GetSize(tstr), {});
136+
//}
137+
128138
template<>
129139
tensor::tensor(const std::string& value) {
130-
TF_TString tstr[1];
131-
TF_TString_Init(&tstr[0]);
132-
TF_TString_Copy(&tstr[0], value.c_str(), value.size());
140+
size_t size = 8 + TF_StringEncodedSize(value.length());
141+
char* data = new char[value.size() + 8];
142+
for (int i=0; i<8; i++) {data[i]=0;}
143+
TF_StringEncode(value.c_str(), value.size(), data + 8, size - 8, context::get_status());
144+
status_check(context::get_status());
133145

134-
*this = tensor(static_cast<enum TF_DataType>(TF_STRING), (void *) tstr, TF_TString_GetSize(tstr), {});
146+
*this = tensor(static_cast<enum TF_DataType>(TF_STRING), (void *) data, size, {});
147+
delete [] data;
135148
}
136149

137150
tensor::tensor(TFE_TensorHandle* handle) {

0 commit comments

Comments
 (0)