Skip to content

Commit 07fcd97

Browse files
killeentsoumith
authored andcommitted
add cudnn data type processing for ATen tensor (pytorch#2087)
1 parent 54cabb8 commit 07fcd97

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

torch/csrc/cudnn/Types.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,19 @@ cudnnDataType_t getCudnnDataType(const thpp::Tensor& tensor)
3434
throw std::runtime_error(msg);
3535
}
3636

37+
cudnnDataType_t getCudnnDataType(const at::Tensor& tensor) {
38+
if (tensor.type().scalarType() == at::kFloat) {
39+
return CUDNN_DATA_FLOAT;
40+
} else if (tensor.type().scalarType() == at::kDouble) {
41+
return CUDNN_DATA_DOUBLE;
42+
} else if (tensor.type().scalarType() == at::kHalf) {
43+
return CUDNN_DATA_HALF;
44+
}
45+
std::string msg("getCudnnDataType() not supported for ");
46+
msg += at::toString(tensor.type().scalarType());
47+
throw std::runtime_error(msg);
48+
}
49+
3750
PyObject * getTensorClass(PyObject *args)
3851
{
3952
for (int i = 0; i < PyTuple_Size(args); i++) {

torch/csrc/cudnn/Types.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
#include <cudnn.h>
88
#include "../Types.h"
99
#include <THPP/THPP.h>
10+
#include <ATen/Tensor.h>
1011

1112
namespace torch { namespace cudnn {
1213

1314
PyObject * getTensorClass(PyObject *args);
1415
cudnnDataType_t getCudnnDataType(PyObject *tensorClass);
1516
cudnnDataType_t getCudnnDataType(const thpp::Tensor& tensor);
17+
cudnnDataType_t getCudnnDataType(const at::Tensor& tensor);
1618
void _THVoidTensor_assertContiguous(THVoidTensor *tensor, const std::string& name);
1719

1820
#define THVoidTensor_assertContiguous(tensor) \

0 commit comments

Comments
 (0)