Skip to content

Commit c3f646f

Browse files
authored
Merge pull request serizba#201 from ljn917/cppflow2-per-model-status
Use per model TF_Status instead of the global one
2 parents e389df0 + 45aa7eb commit c3f646f

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

include/cppflow/context.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ namespace cppflow {
2424
class context {
2525
public:
2626
static TFE_Context* get_context();
27+
28+
// only use get_status() for eager ops
2729
static TF_Status* get_status();
2830

2931
private:

include/cppflow/model.h

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ namespace cppflow {
4242
private:
4343
TF_Buffer * readGraph(const std::string& filename);
4444

45+
std::shared_ptr<TF_Status> status;
4546
std::shared_ptr<TF_Graph> graph;
4647
std::shared_ptr<TF_Session> session;
4748
};
@@ -51,14 +52,15 @@ namespace cppflow {
5152
namespace cppflow {
5253

5354
inline model::model(const std::string &filename, const TYPE type) {
55+
this->status = {TF_NewStatus(), &TF_DeleteStatus};
5456
this->graph = {TF_NewGraph(), TF_DeleteGraph};
5557

5658
// Create the session.
5759
std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)> session_options = {TF_NewSessionOptions(), TF_DeleteSessionOptions};
5860

59-
auto session_deleter = [](TF_Session* sess) {
60-
TF_DeleteSession(sess, context::get_status());
61-
status_check(context::get_status());
61+
auto session_deleter = [this](TF_Session* sess) {
62+
TF_DeleteSession(sess, this->status.get());
63+
status_check(this->status.get());
6264
};
6365

6466
if (type == TYPE::SAVED_MODEL) {
@@ -68,12 +70,12 @@ namespace cppflow {
6870
int tag_len = 1;
6971
const char* tag = "serve";
7072
this->session = {TF_LoadSessionFromSavedModel(session_options.get(), run_options.get(), filename.c_str(),
71-
&tag, tag_len, this->graph.get(), meta_graph.get(), context::get_status()),
73+
&tag, tag_len, this->graph.get(), meta_graph.get(), this->status.get()),
7274
session_deleter};
7375
}
7476
else if (type == TYPE::FROZEN_GRAPH) {
75-
this->session = {TF_NewSession(this->graph.get(), session_options.get(), context::get_status()), session_deleter};
76-
status_check(context::get_status());
77+
this->session = {TF_NewSession(this->graph.get(), session_options.get(), this->status.get()), session_deleter};
78+
status_check(this->status.get());
7779

7880
// Import the graph definition
7981
TF_Buffer* def = readGraph(filename);
@@ -82,14 +84,14 @@ namespace cppflow {
8284
}
8385

8486
std::unique_ptr<TF_ImportGraphDefOptions, decltype(&TF_DeleteImportGraphDefOptions)> graph_opts = {TF_NewImportGraphDefOptions(), TF_DeleteImportGraphDefOptions};
85-
TF_GraphImportGraphDef(this->graph.get(), def, graph_opts.get(), context::get_status());
87+
TF_GraphImportGraphDef(this->graph.get(), def, graph_opts.get(), this->status.get());
8688
TF_DeleteBuffer(def);
8789
}
8890
else {
8991
throw std::runtime_error("Model type unknown");
9092
}
9193

92-
status_check(context::get_status());
94+
status_check(this->status.get());
9395
}
9496

9597
inline std::vector<std::string> model::get_operations() const {
@@ -122,16 +124,16 @@ namespace cppflow {
122124
// DIMENSIONS
123125

124126
// Get number of dimensions
125-
int n_dims = TF_GraphGetTensorNumDims(this->graph.get(), out_op, context::get_status());
127+
int n_dims = TF_GraphGetTensorNumDims(this->graph.get(), out_op, this->status.get());
126128

127129
// If is not a scalar
128130
if (n_dims > 0) {
129131
// Get dimensions
130132
auto* dims = new int64_t[n_dims];
131-
TF_GraphGetTensorShape(this->graph.get(), out_op, dims, n_dims, context::get_status());
133+
TF_GraphGetTensorShape(this->graph.get(), out_op, dims, n_dims, this->status.get());
132134

133135
// Check error on Model Status
134-
status_check(context::get_status());
136+
status_check(this->status.get());
135137

136138
shape = std::vector<int64_t>(dims, dims + n_dims);
137139

@@ -181,8 +183,8 @@ namespace cppflow {
181183
TF_SessionRun(this->session.get(), NULL,
182184
inp_ops.data(), inp_val.data(), static_cast<int>(inputs.size()),
183185
out_ops.data(), out_val.get(), static_cast<int>(outputs.size()),
184-
NULL, 0,NULL , context::get_status());
185-
status_check(context::get_status());
186+
NULL, 0,NULL , this->status.get());
187+
status_check(this->status.get());
186188

187189
std::vector<tensor> result;
188190
result.reserve(outputs.size());

0 commit comments

Comments
 (0)