Skip to content

Commit 004301c

Browse files
ebrevdotensorflower-gardener
authored andcommitted
Bugfix to bug created in previous TensorArray change:
TF_REGISTER_ALL_TYPES does not register complex types on android builds, but matmul etc. require it; and use SetZero functor. Reverted that change, and cleaned up the TensorArray LockedRead to handle only numeric types with SetZero; and then only call it when the size of the Tensor is > 0. Change: 121546296
1 parent 54c55d7 commit 004301c

File tree

3 files changed

+16
-5
lines changed

3 files changed

+16
-5
lines changed

tensorflow/core/kernels/constant_op.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,16 @@ struct SetZeroFunctor<CPUDevice, T> {
122122
};
123123

124124
#define DEFINE_SETZERO_CPU(T) template struct SetZeroFunctor<CPUDevice, T>;
125-
TF_CALL_ALL_TYPES(DEFINE_SETZERO_CPU)
125+
DEFINE_SETZERO_CPU(Eigen::half);
126+
DEFINE_SETZERO_CPU(float);
127+
DEFINE_SETZERO_CPU(double);
128+
DEFINE_SETZERO_CPU(uint8);
129+
DEFINE_SETZERO_CPU(int8);
130+
DEFINE_SETZERO_CPU(int16);
131+
DEFINE_SETZERO_CPU(int32);
132+
DEFINE_SETZERO_CPU(int64);
133+
DEFINE_SETZERO_CPU(complex64);
134+
DEFINE_SETZERO_CPU(complex128);
126135
#undef DEFINE_SETZERO_CPU
127136

128137
} // end namespace functor

tensorflow/core/kernels/tensor_array.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
6060
}
6161

6262
#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
63-
TF_CALL_ALL_TYPES(TENSOR_ARRAY_SET_ZERO_CPU)
63+
TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU)
6464
#undef TENSOR_ARRAY_SET_ZERO_CPU
6565

6666
#if GOOGLE_CUDA

tensorflow/core/kernels/tensor_array.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ Status TensorSetZero(OpKernelContext* ctx, Tensor* value) {
7777
Status TensorSetZero<Device, T>(OpKernelContext * ctx, Tensor * value);
7878

7979
#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
80-
TF_CALL_ALL_TYPES(TENSOR_ARRAY_SET_ZERO_CPU)
80+
TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU)
8181
#undef TENSOR_ARRAY_SET_ZERO_CPU
8282

8383
#if GOOGLE_CUDA
@@ -469,8 +469,10 @@ Status TensorArray::LockedRead(OpKernelContext* ctx, const int32 index,
469469
Tensor* tensor_t;
470470
TF_RETURN_IF_ERROR(
471471
ctx->allocate_persistent(dtype_, t.shape, &t.tensor, &tensor_t));
472-
Status s = tensor_array::TensorSetZero<Device, T>(ctx, tensor_t);
473-
if (!s.ok()) return s;
472+
if (t.shape.num_elements() > 0) {
473+
Status s = tensor_array::TensorSetZero<Device, T>(ctx, tensor_t);
474+
if (!s.ok()) return s;
475+
}
474476
}
475477

476478
// Data is available inside the tensor, copy the reference over.

0 commit comments

Comments
 (0)