Skip to content

Commit ef6427c

Browse files
committed
simple model class
1 parent f6f8830 commit ef6427c

File tree

3 files changed

+229
-0
lines changed

3 files changed

+229
-0
lines changed

src/ofxMSATFSimpleModel.cpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
#include "ofxMSATFSimpleModel.h"
2+
3+
namespace msa {
4+
namespace tf {
5+
6+
7+
//--------------------------------------------------------------
8+
SimpleModel::SimpleModel(string model_path,
9+
vector<string> input_op_names,
10+
vector<string> output_op_names,
11+
string name,
12+
const string device,
13+
const tensorflow::SessionOptions& session_options) {
14+
setup(model_path, input_op_names, output_op_names, name, device, session_options);
15+
}
16+
17+
18+
//--------------------------------------------------------------
19+
void SimpleModel::setup(string model_path,
20+
vector<string> input_op_names,
21+
vector<string> output_op_names,
22+
string name,
23+
const string device,
24+
const tensorflow::SessionOptions& session_options) {
25+
ofLogVerbose("msa::tf::SimpleModel")
26+
<< "SimpleModel " << name
27+
<< " model_path: " << model_path;
28+
// << " input_op_name: " << input_op_names
29+
// << " output_op_name: " << output_op_names
30+
// << " device: " << device;
31+
32+
close();
33+
34+
this->model_path = model_path;
35+
this->input_op_names = input_op_names;
36+
this->output_op_names = output_op_names;
37+
this->name = (name=="") ? model_path : name;
38+
39+
this->graph_def = load_graph_def(model_path);
40+
this->session = create_session_with_graph(this->graph_def, device, session_options);
41+
42+
// prepare input tensors
43+
// ideally read tensor type & shape from the graph_def and allocate tensors correctly
44+
// for now init empty tensors, and user should call init_inputs()
45+
for(const auto& op_name : input_op_names) this->input_tensors.push_back(make_pair(op_name, tensorflow::Tensor()));
46+
}
47+
48+
49+
//--------------------------------------------------------------
50+
void SimpleModel::init_inputs(tensorflow::DataType type, const tensorflow::TensorShape& shape, int tensor_index) {
51+
this->input_tensors[tensor_index].second = tensorflow::Tensor(type, shape);
52+
}
53+
54+
55+
//--------------------------------------------------------------
56+
bool SimpleModel::run() {
57+
if(!this->is_loaded()) {
58+
ofLogWarning("msa::tf::SimpleModel") << "Trying to run " << name << " when not loaded";
59+
return false;
60+
}
61+
62+
63+
// run graph, feed input tensors, fetch output tensors
64+
tensorflow::Status status = session->Run(this->input_tensors, this->output_op_names, {}, &this->output_tensors);
65+
if(status != tensorflow::Status::OK()) {
66+
ofLogError("msa::tf::SimpleModel") << status.error_message();
67+
return false;
68+
}
69+
70+
return true;
71+
}
72+
73+
74+
//--------------------------------------------------------------
75+
void SimpleModel::close() {
76+
input_op_names.clear();
77+
output_op_names.clear();
78+
79+
session = nullptr;
80+
graph_def = nullptr;
81+
82+
input_tensors.clear();
83+
output_tensors.clear();
84+
}
85+
86+
87+
88+
}
89+
}

src/ofxMSATFSimpleModel.h

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
/*
2+
* a simple wrapper for a basic predictor model
3+
*
4+
*
5+
*/
6+
7+
#pragma once
8+
9+
#include "ofxMSATFIncludes.h"
10+
#include "ofxMSATFUtils.h"
11+
12+
namespace msa {
13+
namespace tf {
14+
15+
class SimpleModel
16+
{
17+
public:
18+
typedef shared_ptr<SimpleModel> Ptr;
19+
20+
//--------------------------------------------------------------
21+
// empty constructor does nothing, call setup later.
22+
SimpleModel() {}
23+
24+
//--------------------------------------------------------------
25+
// or pass everything in constructor (E.g. if using as Ptr)
26+
SimpleModel(string model_path,
27+
vector<string> input_op_names,
28+
vector<string> output_op_names,
29+
string name="",
30+
const string device="", // "/cpu:0", "/gpu:0" etc.
31+
const tensorflow::SessionOptions& session_options=tensorflow::SessionOptions());
32+
33+
34+
//--------------------------------------------------------------
35+
// or call setup
36+
void setup(string model_path,
37+
vector<string> input_op_names,
38+
vector<string> output_op_names,
39+
string name="",
40+
const string device="", // "/cpu:0", "/gpu:0" etc.
41+
const tensorflow::SessionOptions& session_options=tensorflow::SessionOptions());
42+
43+
44+
45+
//--------------------------------------------------------------
46+
// then initialise input tensors to specified type and shape
47+
// tensor_index is which input tensor to init (if there is more than one). order is same as input_op_names
48+
// (ideally the SimpleModel constructor or setup would read this info from the graph_def and call this internally)
49+
void init_inputs(tensorflow::DataType type, const tensorflow::TensorShape& shape, int tensor_index=0);
50+
51+
52+
//--------------------------------------------------------------
53+
// getters
54+
bool is_loaded() const { return session != nullptr; }
55+
56+
string get_name() const { return name; }
57+
string get_model_path() const { return model_path; }
58+
59+
const vector<string>& get_input_op_names() const { return input_op_names; }
60+
const vector<string>& get_output_op_names() const { return output_op_names; }
61+
62+
Session_ptr& get_session() { return session; }
63+
const Session_ptr& get_session() const { return session; }
64+
65+
GraphDef_ptr& get_graph_def() { return graph_def; }
66+
const GraphDef_ptr& get_graph_def() const { return graph_def; }
67+
68+
tensorflow::Tensor& get_input_tensor(int i=0) { return input_tensors[i].second; } // .first is the name
69+
const tensorflow::Tensor& get_input_tensor(int i=0) const { return input_tensors[i].second; }
70+
71+
tensorflow::Tensor& get_output_tensor(int i=0) { return output_tensors[i]; }
72+
const tensorflow::Tensor& get_output_tensors(int i=0) const { return output_tensors[i]; }
73+
74+
75+
//--------------------------------------------------------------
76+
// run the model on this->input_tensors
77+
// output is written to this->output_tensors
78+
// returns true if successful, otherwise returns false
79+
// use tensor <--> OF Format conversion functions in ofxMSATFUtils (to convert ofImage, ofPixels, std::vector <--> tensor)
80+
bool run();
81+
82+
83+
//--------------------------------------------------------------
84+
// convenience methods for run
85+
86+
// if the model expects an image, conversion to tensor done internally
87+
// output written to this->output_tensors as usual
88+
// img_in must be same format (e.g. float32, int etc.) as tensor!
89+
// optional xxx_range parameters are for automatic mapping of values, e.g. 0...1 <--> -1...1 (leave blank to bypass)
90+
// (image_range -> model_input_range before going in. model_output_range -> image_range after coming out)
91+
// TODO: assuming batch size 1 for now
92+
93+
// if model expects an image
94+
template<typename T>
95+
bool run(const ofImage_<T>& img_in, ofVec2f model_in_range=ofVec2f(), ofVec2f image_range=ofVec2f(0, 1));
96+
97+
// if the model also outputs an image, conversion to tensor done internally
98+
// output image written to img_out (doesn't have to be pre-allocated, but if it is pre-allocated, it will be quicker
99+
template<typename T>
100+
bool run(const ofImage_<T>& img_in, ofImage_<T>& img_out, ofVec2f model_in_range=ofVec2f(), ofVec2f model_out_range=ofVec2f(), ofVec2f image_range=ofVec2f(0, 1));
101+
102+
103+
protected:
104+
string model_path; // path to file containing model data
105+
vector<string> input_op_names; // name(s) of operators for input (i.e. to feed)
106+
vector<string> output_op_names; // name(s) of operators for output (i.e. to fetch)
107+
string name; // name of model (e.g. for gui)
108+
109+
msa::tf::Session_ptr session;
110+
msa::tf::GraphDef_ptr graph_def;
111+
112+
vector<pair<string, tensorflow::Tensor> > input_tensors; // input(s) to the model (using tensorflow format vector< pair<name, tensor> >)
113+
vector<tensorflow::Tensor> output_tensors; // output(s) of the model
114+
115+
void close();
116+
};
117+
118+
119+
//--------------------------------------------------------------
120+
template<typename T>
121+
bool SimpleModel::run(const ofImage_<T>& img_in, ofVec2f model_in_range, ofVec2f image_range) {
122+
// dump img_in into input tensor. do not use memcpy. map range as nessecary
123+
msa::tf::image_to_tensor(img_in, this->get_input_tensor(), false, image_range, model_in_range);
124+
return this->run();
125+
}
126+
127+
128+
//--------------------------------------------------------------
129+
template<typename T>
130+
bool SimpleModel::run(const ofImage_<T>& img_in, ofImage_<T>& img_out, ofVec2f model_in_range, ofVec2f model_out_range, ofVec2f image_range) {
131+
if( this->run(img_in, model_in_range, image_range) ) {
132+
// dump output tensor into img_out. do not use memcpy, map range as nessecary
133+
msa::tf::tensor_to_image(this->get_output_tensor(), img_out, false, model_out_range, image_range);
134+
}
135+
}
136+
137+
}
138+
}

src/ofxMSATensorFlow.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#include "ofxMSATFVizUtils.h"
1616
#include "ofxMSATFImageClassifier.h"
1717
#include "ofxMSATFLayerVisualizer.h"
18+
#include "ofxMSATFSimpleModel.h"
19+
1820
#include "ofxMSAMathUtils.h"
1921

2022
namespace msa {

0 commit comments

Comments
 (0)