Skip to content

Commit 0536866

Browse files
authored
Merge pull request serizba#1 from CarlPoirier/master
add tensor shape getter
2 parents 438c35e + c996c98 commit 0536866

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

src/Tensor.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,9 @@ std::vector<T> Tensor::get_data() {
147147
return std::vector<T>(T_data, T_data + size);
148148
}
149149

150-
150+
std::vector<int64_t> Tensor::get_shape() {
151+
return shape;
152+
}
151153

152154
template<typename T>
153155
TF_DataType Tensor::deduce_type() {

src/Tensor.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class Tensor {
4040
template<typename T>
4141
std::vector<T> get_data();
4242

43+
std::vector<int64_t> get_shape();
44+
4345
private:
4446
TF_Tensor* val;
4547
TF_Output op;

0 commit comments

Comments
 (0)