Skip to content

Commit c38997c

Browse files
committed
Merge branch 'AfaqSabirIBEX-AddSessionConfigOptions'
2 parents 50930c6 + 6ca8378 commit c38997c

File tree

5 files changed

+32
-4
lines changed

5 files changed

+32
-4
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import tensorflow as tf
2+
3+
def create_serialized_options(fraction, growth):
4+
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=fraction, allow_growth=growth)
5+
config = tf.ConfigProto(gpu_options=gpu_options)
6+
serialized = config.SerializeToString()
7+
return '{' + ','.join(list(map(hex, serialized))) + '}'
8+
9+
if __name__ == "__main__":
10+
print("Create serialized options which allow TF to use a certain percentage of GPU memory and allow TF to expand this memory if required.")
11+
for i in range(1, 10):
12+
memory_fraction_to_use = 0.1
13+
enable_memory_growth = True
14+
print("GPU memory to be used: ", i * 10.0, "%")
15+
print(create_serialized_options(memory_fraction_to_use, enable_memory_growth))

examples/load_model/create_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def example_1():
1313
i = tf.initializers.global_variables()
1414

1515
# Write the model definition
16-
with open('models/load_model.pb', 'wb') as f:
16+
with open('load_model.pb', 'wb') as f:
1717
f.write(tf.get_default_graph().as_graph_def().SerializeToString())
1818

1919

examples/load_model/main.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@
99
#include <iomanip>
1010

1111
int main() {
12+
// Load model with a path to the .pb file.
13+
// An optional std::vector<uint8_t> parameter can be used to supply Tensorflow with
14+
// session options. The vector must represent a serialized ConfigProto which can be
15+
// generated manually in python. See create_config_options.py.
16+
// Example:
17+
// const std::vector<uint8_t> ModelConfigOptions = { 0x32, 0xb, 0x9, 0x9a, 0x99, 0x99, 0x99, 0x99, 0x99, 0xb9, 0x3f, 0x20, 0x1 };
18+
// Model model("../model.pb", ModelConfigOptions);
1219
Model model("../model.pb");
1320
model.init();
1421

include/Model.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ class Tensor;
1919

2020
class Model {
2121
public:
22-
explicit Model(const std::string&);
22+
// Pass a path to the model file and optional Tensorflow config options. See examples/load_model/main.cpp.
23+
explicit Model(const std::string& model_filename, const std::vector<uint8_t>& config_options = {});
2324

2425
// Rule of five, moving is easy as the pointers can be copied, copying not as i have no idea how to copy
2526
// the contents of the pointer (i guess dereferencing won't do a deep copy)

src/Model.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,19 @@
44

55
#include "../include/Model.h"
66

7-
Model::Model(const std::string& model_filename) {
8-
7+
Model::Model(const std::string& model_filename, const std::vector<uint8_t>& config_options) {
98
this->status = TF_NewStatus();
109
this->graph = TF_NewGraph();
1110

1211
// Create the session.
1312
TF_SessionOptions* sess_opts = TF_NewSessionOptions();
1413

14+
if (!config_options.empty())
15+
{
16+
TF_SetConfig(sess_opts, static_cast<const void*>(config_options.data()), config_options.size(), this->status);
17+
this->status_check(true);
18+
}
19+
1520
this->session = TF_NewSession(this->graph, sess_opts, this->status);
1621
TF_DeleteSessionOptions(sess_opts);
1722

0 commit comments

Comments
 (0)