Skip to content

Commit 9b05be6

Browse files
committed
Runtime check of asked data in tensor::get_data serizba#91
1 parent 285a22f commit 9b05be6

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

include/cppflow/tensor.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,16 @@ namespace cppflow {
215215

216216
template<typename T>
217217
std::vector<T> tensor::get_data() const {
218+
219+
// Check if asked datatype and tensor datatype match
220+
if (this->dtype() != deduce_tf_type<T>()) {
221+
auto type1 = cppflow::to_string(deduce_tf_type<T>());
222+
auto type2 = cppflow::to_string(this->dtype());
223+
auto error = "Datatype in function get_data (" + type1 + ") does not match tensor datatype (" + type2 + ")";
224+
throw std::runtime_error(error);
225+
}
226+
227+
218228
auto res_tensor = get_tensor();
219229

220230
// Check tensor data is not empty

0 commit comments

Comments
 (0)