@@ -42,6 +42,7 @@ namespace cppflow {
42
42
private:
43
43
TF_Buffer * readGraph (const std::string& filename);
44
44
45
+ std::shared_ptr<TF_Status> status;
45
46
std::shared_ptr<TF_Graph> graph;
46
47
std::shared_ptr<TF_Session> session;
47
48
};
@@ -51,14 +52,15 @@ namespace cppflow {
51
52
namespace cppflow {
52
53
53
54
inline model::model (const std::string &filename, const TYPE type) {
55
+ this ->status = {TF_NewStatus (), &TF_DeleteStatus};
54
56
this ->graph = {TF_NewGraph (), TF_DeleteGraph};
55
57
56
58
// Create the session.
57
59
std::unique_ptr<TF_SessionOptions, decltype (&TF_DeleteSessionOptions)> session_options = {TF_NewSessionOptions (), TF_DeleteSessionOptions};
58
60
59
- auto session_deleter = [](TF_Session* sess) {
60
- TF_DeleteSession (sess, context::get_status ());
61
- status_check (context::get_status ());
61
+ auto session_deleter = [this ](TF_Session* sess) {
62
+ TF_DeleteSession (sess, this -> status . get ());
63
+ status_check (this -> status . get ());
62
64
};
63
65
64
66
if (type == TYPE::SAVED_MODEL) {
@@ -68,12 +70,12 @@ namespace cppflow {
68
70
int tag_len = 1 ;
69
71
const char * tag = " serve" ;
70
72
this ->session = {TF_LoadSessionFromSavedModel (session_options.get (), run_options.get (), filename.c_str (),
71
- &tag, tag_len, this ->graph .get (), meta_graph.get (), context::get_status ()),
73
+ &tag, tag_len, this ->graph .get (), meta_graph.get (), this -> status . get ()),
72
74
session_deleter};
73
75
}
74
76
else if (type == TYPE::FROZEN_GRAPH) {
75
- this ->session = {TF_NewSession (this ->graph .get (), session_options.get (), context::get_status ()), session_deleter};
76
- status_check (context::get_status ());
77
+ this ->session = {TF_NewSession (this ->graph .get (), session_options.get (), this -> status . get ()), session_deleter};
78
+ status_check (this -> status . get ());
77
79
78
80
// Import the graph definition
79
81
TF_Buffer* def = readGraph (filename);
@@ -82,14 +84,14 @@ namespace cppflow {
82
84
}
83
85
84
86
std::unique_ptr<TF_ImportGraphDefOptions, decltype (&TF_DeleteImportGraphDefOptions)> graph_opts = {TF_NewImportGraphDefOptions (), TF_DeleteImportGraphDefOptions};
85
- TF_GraphImportGraphDef (this ->graph .get (), def, graph_opts.get (), context::get_status ());
87
+ TF_GraphImportGraphDef (this ->graph .get (), def, graph_opts.get (), this -> status . get ());
86
88
TF_DeleteBuffer (def);
87
89
}
88
90
else {
89
91
throw std::runtime_error (" Model type unknown" );
90
92
}
91
93
92
- status_check (context::get_status ());
94
+ status_check (this -> status . get ());
93
95
}
94
96
95
97
inline std::vector<std::string> model::get_operations () const {
@@ -122,16 +124,16 @@ namespace cppflow {
122
124
// DIMENSIONS
123
125
124
126
// Get number of dimensions
125
- int n_dims = TF_GraphGetTensorNumDims (this ->graph .get (), out_op, context::get_status ());
127
+ int n_dims = TF_GraphGetTensorNumDims (this ->graph .get (), out_op, this -> status . get ());
126
128
127
129
// If is not a scalar
128
130
if (n_dims > 0 ) {
129
131
// Get dimensions
130
132
auto * dims = new int64_t [n_dims];
131
- TF_GraphGetTensorShape (this ->graph .get (), out_op, dims, n_dims, context::get_status ());
133
+ TF_GraphGetTensorShape (this ->graph .get (), out_op, dims, n_dims, this -> status . get ());
132
134
133
135
// Check error on Model Status
134
- status_check (context::get_status ());
136
+ status_check (this -> status . get ());
135
137
136
138
shape = std::vector<int64_t >(dims, dims + n_dims);
137
139
@@ -181,8 +183,8 @@ namespace cppflow {
181
183
TF_SessionRun (this ->session .get (), NULL ,
182
184
inp_ops.data (), inp_val.data (), static_cast <int >(inputs.size ()),
183
185
out_ops.data (), out_val.get (), static_cast <int >(outputs.size ()),
184
- NULL , 0 ,NULL , context::get_status ());
185
- status_check (context::get_status ());
186
+ NULL , 0 ,NULL , this -> status . get ());
187
+ status_check (this -> status . get ());
186
188
187
189
std::vector<tensor> result;
188
190
result.reserve (outputs.size ());
0 commit comments