1
+ #include < utility>
2
+
1
3
//
2
4
// Created by sergio on 13/05/19.
3
5
//
@@ -19,9 +21,6 @@ Tensor::Tensor(const Model& model, const std::string& operation) {
19
21
// Get number of dimensions
20
22
int n_dims = TF_GraphGetTensorNumDims (model.graph , this ->op , model.status );
21
23
22
- // Dimension is not known
23
- error_check (n_dims != -1 , " Shape of tensors must be known" );
24
-
25
24
// DataType
26
25
this ->type = TF_OperationOutputType (this ->op );
27
26
@@ -37,8 +36,7 @@ Tensor::Tensor(const Model& model, const std::string& operation) {
37
36
this ->shape = std::vector<int64_t >(dims, dims + n_dims);
38
37
39
38
// Only one dimension can be unknown using this constructor
40
- error_check (std::count (this ->shape .begin (), this ->shape .end (), -1 ) <= 1 ,
41
- " At most one dimension can be unknown" );
39
+ // error_check(std::count(this->shape.begin(), this->shape.end(), -1) <= 1, "At most one dimension can be unknown");
42
40
}
43
41
44
42
this ->flag = 0 ;
@@ -82,6 +80,12 @@ void Tensor::set_data(std::vector<T> new_data) {
82
80
// Check type
83
81
this ->error_check (deduce_type<T>() == this ->type , " Provided type is different from Tensor expected type" );
84
82
83
+ // Dimensions must be known
84
+ this ->error_check (!this ->shape .empty (), " Shape of the input Tensor is not known, please provide a shape" );
85
+
86
+ // At most one dimension can be unknown
87
+ this ->error_check (std::count (this ->shape .begin (), this ->shape .end (), -1 ) >= -1 , " At most one dimension can be unknown, please provide a shape" );
88
+
85
89
// Check number of elements
86
90
auto exp_size = std::abs (std::accumulate (this ->shape .begin (), this ->shape .end (), 1 , std::multiplies<>()));
87
91
@@ -107,6 +111,17 @@ void Tensor::set_data(std::vector<T> new_data) {
107
111
this ->flag = 1 ;
108
112
}
109
113
114
+ template <typename T> void Tensor::set_data (std::vector<T> new_data, const std::vector<int64_t >& new_shape) {
115
+
116
+ this ->error_check (this ->shape .empty () || this ->shape .size () == new_shape.size (), " Provided shape has different number of dimensions" );
117
+ auto old_shape = this ->shape ;
118
+
119
+ this ->shape = new_shape;
120
+ this ->set_data (new_data);
121
+
122
+ this ->shape = old_shape;
123
+ }
124
+
110
125
template <typename T>
111
126
std::vector<T> Tensor::get_data () {
112
127
@@ -159,6 +174,20 @@ TF_DataType Tensor::deduce_type() {
159
174
return TF_UINT64;
160
175
}
161
176
177
+ void Tensor::deduce_shape (const Model& model) {
178
+ // Get number of dimensions
179
+ int n_dims = TF_NumDims (this ->val );
180
+
181
+ // If is not a scalar
182
+ if (n_dims > 0 ) {
183
+ // Get dimensions
184
+ this ->shape = std::vector<int64_t >(n_dims, -1 );
185
+ for (int i=0 ; i<n_dims; i++) {
186
+ this ->shape [i] = TF_Dim (this ->val , i);
187
+ }
188
+ }
189
+ }
190
+
162
191
163
192
// VALID deduce_type TEMPLATES
164
193
template TF_DataType Tensor::deduce_type<float >();
@@ -199,3 +228,15 @@ template void Tensor::set_data<uint16_t>(std::vector<uint16_t> new_data);
199
228
template void Tensor::set_data<uint32_t >(std::vector<uint32_t > new_data);
200
229
template void Tensor::set_data<uint64_t >(std::vector<uint64_t > new_data);
201
230
231
+ // VALID set_data TEMPLATES
232
+ template void Tensor::set_data<float >(std::vector<float > new_data, const std::vector<int64_t >& new_shape);
233
+ template void Tensor::set_data<double >(std::vector<double > new_data, const std::vector<int64_t >& new_shape);
234
+ // template void Tensor::set_data<bool>(std::vector<bool> new_data, const std::vector<int64_t>& new_shape);
235
+ template void Tensor::set_data<int8_t >(std::vector<int8_t > new_data, const std::vector<int64_t >& new_shape);
236
+ template void Tensor::set_data<int16_t >(std::vector<int16_t > new_data, const std::vector<int64_t >& new_shape);
237
+ template void Tensor::set_data<int32_t >(std::vector<int32_t > new_data, const std::vector<int64_t >& new_shape);
238
+ template void Tensor::set_data<int64_t >(std::vector<int64_t > new_data, const std::vector<int64_t >& new_shape);
239
+ template void Tensor::set_data<uint8_t >(std::vector<uint8_t > new_data, const std::vector<int64_t >& new_shape);
240
+ template void Tensor::set_data<uint16_t >(std::vector<uint16_t > new_data, const std::vector<int64_t >& new_shape);
241
+ template void Tensor::set_data<uint32_t >(std::vector<uint32_t > new_data, const std::vector<int64_t >& new_shape);
242
+ template void Tensor::set_data<uint64_t >(std::vector<uint64_t > new_data, const std::vector<int64_t >& new_shape);
0 commit comments