Skip to content

Commit c6d6cbe

Browse files
committed
Check that all tensors are on the same GPU in cuDNN bindings
1 parent 85e82e8 commit c6d6cbe

File tree

3 files changed

+34
-0
lines changed

3 files changed

+34
-0
lines changed

torch/csrc/cudnn/BatchNorm.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ void cudnn_batch_norm_forward(
6262
THVoidTensor* save_mean, THVoidTensor* save_var, bool training,
6363
double exponential_average_factor, double epsilon)
6464
{
65+
assertSameGPU(dataType, input, output, weight, bias, running_mean, running_var,
66+
save_mean, save_var);
6567
cudnnBatchNormMode_t mode;
6668
if (input->nDimension == 2) {
6769
mode = CUDNN_BATCHNORM_PER_ACTIVATION;
@@ -120,6 +122,8 @@ void cudnn_batch_norm_backward(
120122
THVoidTensor* save_mean, THVoidTensor* save_var, bool training,
121123
double epsilon)
122124
{
125+
assertSameGPU(dataType, input, grad_output, grad_input, grad_weight, grad_bias, weight,
126+
running_mean, running_var, save_mean, save_var);
123127
cudnnBatchNormMode_t mode;
124128
if (input->nDimension == 2) {
125129
mode = CUDNN_BATCHNORM_PER_ACTIVATION;

torch/csrc/cudnn/Conv.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ void cudnn_convolution_forward(
285285
THVoidTensor* input, THVoidTensor* weight, THVoidTensor* output,
286286
Convolution* info, bool benchmark)
287287
{
288+
assertSameGPU(dataType, input, weight, output);
288289
int groups = info->groups;
289290

290291
cudnnConvolutionFwdAlgo_t fwdAlg;
@@ -309,6 +310,7 @@ void cudnn_convolution_add_bias(
309310
THVoidTensor* bias, THVoidTensor* output,
310311
Convolution* info)
311312
{
313+
assertSameGPU(dataType, bias, output);
312314
CHECK_ARG(output->nDimension <= 5);
313315
TensorDescriptor& bdesc = info->bdesc;
314316

@@ -329,6 +331,7 @@ void cudnn_convolution_backward_data(
329331
THVoidTensor* gradOutput, THVoidTensor* gradInput, THVoidTensor* weight,
330332
Convolution* info, bool benchmark)
331333
{
334+
assertSameGPU(dataType, gradOutput, gradInput, weight);
332335
int groups = info->params.groups;
333336

334337
cudnnConvolutionBwdDataAlgo_t bwdDataAlg;
@@ -353,6 +356,7 @@ void cudnn_convolution_backward_filter(
353356
THVoidTensor* gradOutput, THVoidTensor* input, THVoidTensor* gradWeight,
354357
Convolution* info, bool benchmark)
355358
{
359+
assertSameGPU(dataType, gradOutput, input, gradWeight);
356360
int groups = info->params.groups;
357361

358362
cudnnConvolutionBwdFilterAlgo_t bwdFilterAlg;
@@ -380,6 +384,7 @@ void cudnn_convolution_backward_bias(
380384
THCState* state, cudnnHandle_t handle, cudnnDataType_t dataType,
381385
THVoidTensor* gradOutput, THVoidTensor* gradBias, Convolution* info)
382386
{
387+
assertSameGPU(dataType, gradOutput, gradBias);
383388
Constant one(dataType, 1);
384389
Constant zero(dataType, 0);
385390
void* gradOutput_ptr = tensorPointer(dataType, gradOutput, 0, 1, 0);

torch/csrc/cudnn/Exceptions.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,42 @@
11
#ifndef THP_CUDNN_EXCEPTIONS_INC
22
#define THP_CUDNN_EXCEPTIONS_INC
33

4+
#include <THC/THC.h>
45
#include <cudnn.h>
56
#include <string>
67
#include <stdexcept>
78
#include <sstream>
89

10+
#include "Types.h"
911

1012
#define CHECK_ARG(cond) _CHECK_ARG(cond, #cond, __FILE__, __LINE__)
1113

14+
extern THCState* state;
1215

1316
namespace torch { namespace cudnn {
1417

18+
template<typename ...T>
19+
void assertSameGPU(cudnnDataType_t dataType, T* ... tensors) {
20+
static_assert(std::is_same<THVoidTensor, typename std::common_type<T...>::type>::value,
21+
"all arguments to assertSameGPU have to be THVoidTensor*");
22+
int is_same;
23+
if (dataType == CUDNN_DATA_FLOAT) {
24+
is_same = THCudaTensor_checkGPU(state, sizeof...(T),
25+
reinterpret_cast<THCudaTensor*>(tensors)...);
26+
} else if (dataType == CUDNN_DATA_HALF) {
27+
is_same = THCudaHalfTensor_checkGPU(state, sizeof...(T),
28+
reinterpret_cast<THCudaHalfTensor*>(tensors)...);
29+
} else if (dataType == CUDNN_DATA_DOUBLE) {
30+
is_same = THCudaDoubleTensor_checkGPU(state, sizeof...(T),
31+
reinterpret_cast<THCudaDoubleTensor*>(tensors)...);
32+
} else {
33+
throw std::runtime_error("unknown cuDNN data type");
34+
}
35+
if (!is_same) {
36+
throw std::runtime_error("tensors are on different GPUs");
37+
}
38+
}
39+
1540
class cudnn_exception : public std::runtime_error {
1641
public:
1742
cudnnStatus_t status;

0 commit comments

Comments
 (0)