File tree Expand file tree Collapse file tree 2 files changed +15
-0
lines changed Expand file tree Collapse file tree 2 files changed +15
-0
lines changed Original file line number Diff line number Diff line change @@ -34,6 +34,19 @@ cudnnDataType_t getCudnnDataType(const thpp::Tensor& tensor)
3434 throw std::runtime_error (msg);
3535}
3636
37+ cudnnDataType_t getCudnnDataType (const at::Tensor& tensor) {
38+ if (tensor.type ().scalarType () == at::kFloat ) {
39+ return CUDNN_DATA_FLOAT;
40+ } else if (tensor.type ().scalarType () == at::kDouble ) {
41+ return CUDNN_DATA_DOUBLE;
42+ } else if (tensor.type ().scalarType () == at::kHalf ) {
43+ return CUDNN_DATA_HALF;
44+ }
45+ std::string msg (" getCudnnDataType() not supported for " );
46+ msg += at::toString (tensor.type ().scalarType ());
47+ throw std::runtime_error (msg);
48+ }
49+
3750PyObject * getTensorClass (PyObject *args)
3851{
3952 for (int i = 0 ; i < PyTuple_Size (args); i++) {
Original file line number Diff line number Diff line change 77#include < cudnn.h>
88#include " ../Types.h"
99#include < THPP/THPP.h>
10+ #include < ATen/Tensor.h>
1011
1112namespace torch { namespace cudnn {
1213
1314PyObject * getTensorClass (PyObject *args);
1415cudnnDataType_t getCudnnDataType (PyObject *tensorClass);
1516cudnnDataType_t getCudnnDataType (const thpp::Tensor& tensor);
17+ cudnnDataType_t getCudnnDataType (const at::Tensor& tensor);
1618void _THVoidTensor_assertContiguous (THVoidTensor *tensor, const std::string& name);
1719
1820#define THVoidTensor_assertContiguous (tensor ) \
You can’t perform that action at this time.
0 commit comments