Skip to content

Commit 9510431

Browse files
Filip Binkiewiczapaszke
authored andcommitted
Fix nonzero bug
1 parent e72e2db commit 9510431

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

torch/lib/THD/master_worker/master/generic/THDTensorMath.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ void THDTensor_(nonzero)(THDLongTensor *subscript, THDTensor *tensor) {
7070
),
7171
THDState::s_current_worker
7272
);
73-
ptrdiff_t numel = receiveValueFromWorker<ptrdiff_t>(tensor->storage->node_id);
73+
long long numel = receiveValueFromWorker<long long>(tensor->storage->node_id);
7474
THDLongTensor__resize2d(subscript, numel, tensor->nDimension);
7575
}
7676

torch/lib/THD/master_worker/worker/dispatch/TensorMath.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ static void tensorNonzero(rpc::RPCMessage& raw_message) {
5353
thpp::Tensor *tensor = unpackRetrieveTensor(raw_message);
5454
finalize(raw_message);
5555
tensor->nonzero(*subscript);
56-
sendValueToMaster((double)subscript->numel());
56+
long long numel = subscript->sizes().size() > 0 ? subscript->sizes()[0] : 0;
57+
sendValueToMaster(numel);
5758
}
5859

5960
static void tensorIndexSelect(rpc::RPCMessage& raw_message) {

0 commit comments

Comments
 (0)