Skip to content

Commit cba22ec

Browse files
committed
Changes to generator file to produce static cast of size_t values being passed into TensorFlow int parameters.
1 parent ad3a948 commit cba22ec

File tree

2 files changed

+35
-9
lines changed

2 files changed

+35
-9
lines changed

include/cppflow/ops_generator/generator.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,26 +78,26 @@ def code(self):
7878
'string' : '''
7979
std::vector<std::size_t> {0}_sizes; {0}_sizes.reserve({0}.size());
8080
std::transform({0}.begin(), {0}.end(), std::back_inserter({0}_sizes), [](const auto& s) {{ return s.size();}});
81-
TFE_OpSetAttrStringList(op.get(), "{orig:}", reinterpret_cast<const void *const *>({0}.data()), {0}_sizes.data(), {0}.size());
81+
TFE_OpSetAttrStringList(op.get(), "{orig:}", reinterpret_cast<const void *const *>({0}.data()), {0}_sizes.data(), static_cast<int>({0}.size()));
8282
''',
83-
'int' : 'TFE_OpSetAttrIntList(op.get(), "{orig:}", {0}.data(), {0}.size());',
84-
'float' : 'TFE_OpSetAttrFloatList(op.get(), "{orig:}", {0}.data(), {0}.size());',
83+
'int' : 'TFE_OpSetAttrIntList(op.get(), "{orig:}", {0}.data(), static_cast<int>({0}.size()));',
84+
'float' : 'TFE_OpSetAttrFloatList(op.get(), "{orig:}", {0}.data(), static_cast<int>({0}.size()));',
8585
'bool' : 'TFE_OpSetAttrBoolList(op.get(), "{orig:}", std::vector<unsigned char>({0}.begin(), {0}.end()).data(), {0}.size());',
86-
'type' : 'TFE_OpSetAttrTypeList(op.get(), "{orig:}", reinterpret_cast<const enum TF_DataType *>({0}.data()), {0}.size());',
86+
'type' : 'TFE_OpSetAttrTypeList(op.get(), "{orig:}", reinterpret_cast<const enum TF_DataType *>({0}.data()), static_cast<int>({0}.size()));',
8787
'shape' : '''
8888
std::vector<const int64_t*> {0}_values; {0}_values.reserve({0}.size());
8989
std::vector<int> {0}_ndims; {0}_ndims.reserve({0}.size());
9090
std::transform({0}.begin(), {0}.end(), std::back_inserter({0}_values), [](const auto& v) {{ return v.data();}});
91-
std::transform({0}.begin(), {0}.end(), std::back_inserter({0}_ndims), [](const auto& v) {{ return v.size();}});
92-
TFE_OpSetAttrShapeList(op.get(), "{orig:}", {0}_values.data(), {0}_ndims.data(), {0}.size(), context::get_status());
91+
std::transform({0}.begin(), {0}.end(), std::back_inserter({0}_ndims), [](const auto& v) {{ return static_cast<int>(v.size());}});
92+
TFE_OpSetAttrShapeList(op.get(), "{orig:}", {0}_values.data(), {0}_ndims.data(), static_cast<int>({0}.size()), context::get_status());
9393
status_check(context::get_status());
9494
''',
9595
}[self.type].format(self.name.replace('template', 'template_arg'), orig=self.name)).replace('\n', '\n ')
9696

9797
else:
9898
return textwrap.dedent({
9999
'shape' : '''
100-
TFE_OpSetAttrShape(op.get(), "{orig:}", {0}.data(), {0}.size(), context::get_status());
100+
TFE_OpSetAttrShape(op.get(), "{orig:}", {0}.data(), static_cast<int>({0}.size()), context::get_status());
101101
status_check(context::get_status());
102102
''',
103103
'int' : 'TFE_OpSetAttrInt(op.get(), "{orig:}", {0});',
@@ -172,7 +172,7 @@ def code(self):
172172
add_inputs_list = textwrap.dedent('''
173173
std::vector<TFE_TensorHandle*> {0}_handles; {0}_handles.reserve({0}.size());
174174
std::transform({0}.begin(), {0}.end(), std::back_inserter({0}_handles), [](const auto& t) {{ return t.tfe_handle.get();}});
175-
TFE_OpAddInputList(op.get(), {0}_handles.data(), {0}.size(), context::get_status());
175+
TFE_OpAddInputList(op.get(), {0}_handles.data(), static_cast<int>({0}.size()), context::get_status());
176176
status_check(context::get_status());
177177
''').replace('\n', '\n ')
178178

include/cppflow/raw_ops.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4127,7 +4127,7 @@ inline tensor choose_fastest_dataset(const std::vector<tensor>&input_datasets, i
41274127

41284128

41294129
// Attributes
4130-
TFE_OpSetAttrInt(op.get(), "N", static_cast<int>(input_datasets.size()));
4130+
TFE_OpSetAttrInt(op.get(), "N", input_datasets.size());
41314131
TFE_OpSetAttrInt(op.get(), "num_experiments", num_experiments);
41324132
TFE_OpSetAttrTypeList(op.get(), "output_types", reinterpret_cast<const enum TF_DataType *>(output_types.data()), static_cast<int>(output_types.size()));
41334133

@@ -6132,6 +6132,32 @@ inline tensor decode_gif(const tensor& contents) {
61326132
}
61336133

61346134

6135+
inline tensor decode_image(const tensor& contents, int64_t channels=0, datatype dtype=static_cast<datatype>(4), bool expand_animations=true) {
6136+
6137+
// Define Op
6138+
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(), "DecodeImage", context::get_status()), &TFE_DeleteOp);
6139+
status_check(context::get_status());
6140+
6141+
// Required input arguments
6142+
6143+
TFE_OpAddInput(op.get(), contents.tfe_handle.get(), context::get_status());
6144+
status_check(context::get_status());
6145+
6146+
6147+
// Attributes
6148+
TFE_OpSetAttrInt(op.get(), "channels", channels);
6149+
TFE_OpSetAttrType(op.get(), "dtype", dtype);
6150+
TFE_OpSetAttrBool(op.get(), "expand_animations", (unsigned char)expand_animations);
6151+
6152+
// Execute Op
6153+
int num_outputs_op = 1;
6154+
TFE_TensorHandle* res[1] = {nullptr};
6155+
TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
6156+
status_check(context::get_status());
6157+
return tensor(res[0]);
6158+
}
6159+
6160+
61356161
inline tensor decode_j_s_o_n_example(const tensor& json_examples) {
61366162

61376163
// Define Op

0 commit comments

Comments
 (0)