@@ -78,26 +78,26 @@ def code(self):
78
78
'string' : '''
79
79
std::vector<std::size_t> {0}_sizes; {0}_sizes.reserve({0}.size());
80
80
std::transform({0}.begin(), {0}.end(), std::back_inserter({0}_sizes), [](const auto& s) {{ return s.size();}});
81
- TFE_OpSetAttrStringList(op.get(), "{orig:}", reinterpret_cast<const void *const *>({0}.data()), {0}_sizes.data(), {0}.size());
81
+ TFE_OpSetAttrStringList(op.get(), "{orig:}", reinterpret_cast<const void *const *>({0}.data()), {0}_sizes.data(), static_cast<int>( {0}.size() ));
82
82
''' ,
83
- 'int' : 'TFE_OpSetAttrIntList(op.get(), "{orig:}", {0}.data(), {0}.size());' ,
84
- 'float' : 'TFE_OpSetAttrFloatList(op.get(), "{orig:}", {0}.data(), {0}.size());' ,
83
+ 'int' : 'TFE_OpSetAttrIntList(op.get(), "{orig:}", {0}.data(), static_cast<int>( {0}.size() ));' ,
84
+ 'float' : 'TFE_OpSetAttrFloatList(op.get(), "{orig:}", {0}.data(), static_cast<int>( {0}.size() ));' ,
85
85
'bool' : 'TFE_OpSetAttrBoolList(op.get(), "{orig:}", std::vector<unsigned char>({0}.begin(), {0}.end()).data(), {0}.size());' ,
86
- 'type' : 'TFE_OpSetAttrTypeList(op.get(), "{orig:}", reinterpret_cast<const enum TF_DataType *>({0}.data()), {0}.size());' ,
86
+ 'type' : 'TFE_OpSetAttrTypeList(op.get(), "{orig:}", reinterpret_cast<const enum TF_DataType *>({0}.data()), static_cast<int>( {0}.size() ));' ,
87
87
'shape' : '''
88
88
std::vector<const int64_t*> {0}_values; {0}_values.reserve({0}.size());
89
89
std::vector<int> {0}_ndims; {0}_ndims.reserve({0}.size());
90
90
std::transform({0}.begin(), {0}.end(), std::back_inserter({0}_values), [](const auto& v) {{ return v.data();}});
91
- std::transform({0}.begin(), {0}.end(), std::back_inserter({0}_ndims), [](const auto& v) {{ return v.size();}});
92
- TFE_OpSetAttrShapeList(op.get(), "{orig:}", {0}_values.data(), {0}_ndims.data(), {0}.size(), context::get_status());
91
+ std::transform({0}.begin(), {0}.end(), std::back_inserter({0}_ndims), [](const auto& v) {{ return static_cast<int>( v.size() );}});
92
+ TFE_OpSetAttrShapeList(op.get(), "{orig:}", {0}_values.data(), {0}_ndims.data(), static_cast<int>( {0}.size() ), context::get_status());
93
93
status_check(context::get_status());
94
94
''' ,
95
95
}[self .type ].format (self .name .replace ('template' , 'template_arg' ), orig = self .name )).replace ('\n ' , '\n ' )
96
96
97
97
else :
98
98
return textwrap .dedent ({
99
99
'shape' : '''
100
- TFE_OpSetAttrShape(op.get(), "{orig:}", {0}.data(), {0}.size(), context::get_status());
100
+ TFE_OpSetAttrShape(op.get(), "{orig:}", {0}.data(), static_cast<int>( {0}.size() ), context::get_status());
101
101
status_check(context::get_status());
102
102
''' ,
103
103
'int' : 'TFE_OpSetAttrInt(op.get(), "{orig:}", {0});' ,
@@ -172,7 +172,7 @@ def code(self):
172
172
add_inputs_list = textwrap .dedent ('''
173
173
std::vector<TFE_TensorHandle*> {0}_handles; {0}_handles.reserve({0}.size());
174
174
std::transform({0}.begin(), {0}.end(), std::back_inserter({0}_handles), [](const auto& t) {{ return t.tfe_handle.get();}});
175
- TFE_OpAddInputList(op.get(), {0}_handles.data(), {0}.size(), context::get_status());
175
+ TFE_OpAddInputList(op.get(), {0}_handles.data(), static_cast<int>( {0}.size() ), context::get_status());
176
176
status_check(context::get_status());
177
177
''' ).replace ('\n ' , '\n ' )
178
178
0 commit comments