Skip to content

Commit 581921f

Browse files
committed
support unsafe functions for getting/constructor tensors from TH objects for backward compat.
1 parent f0788af commit 581921f

File tree

15 files changed

+44
-81
lines changed

15 files changed

+44
-81
lines changed

CMakeLists.txt

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -177,15 +177,6 @@ INSTALL(TARGETS ATen
177177
LIBRARY DESTINATION "${TENSOR_LIB_INSTALL_LIB_DIR}"
178178
ARCHIVE DESTINATION "${TENSOR_LIB_INSTALL_LIB_DIR}")
179179

180-
ADD_EXECUTABLE(scalar_test test/scalar_test.cpp)
181-
TARGET_LINK_LIBRARIES(scalar_test ATen)
182-
183-
ADD_EXECUTABLE(basic test/basic.cpp)
184-
TARGET_LINK_LIBRARIES(basic ATen)
185-
186-
add_executable(atest test/atest.cpp)
187-
target_link_libraries(atest ATen)
188-
189180
FOREACH(HEADER ${base_h})
190181
INSTALL(FILES ${HEADER} DESTINATION ${TENSOR_LIB_INSTALL_INCLUDE_DIR}/ATen)
191182
ENDFOREACH()

CheckGenerator.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#pragma once
2+
#include "ATen/CPUGenerator.h"
3+
4+
namespace at {
5+
static inline CPUGenerator * check_generator(Generator* expr) {
6+
if(auto result = dynamic_cast<CPUGenerator*>(expr))
7+
return result;
8+
runtime_error("Expected a 'CPUGenerator' but found 'CUDAGenerator'");
9+
}
10+
}

Context.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class Context {
1717
initCUDAIfNeeded(p);
1818
auto & type = type_registry[static_cast<int>(p)][static_cast<int>(s)];
1919
if(!type)
20-
runtime_error("%s%s%sType is not enabled.",toString(p),toString(s));
20+
runtime_error("%s%sType is not enabled.",toString(p),toString(s));
2121
return *type;
2222
}
2323
Generator & defaultGenerator(Backend p) {

TensorImpl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ struct TensorImpl {
2020
virtual int64_t dim() = 0;
2121
virtual Scalar localScalar() = 0;
2222
virtual void assign_(Scalar s) = 0;
23+
virtual void * unsafeGetTH() = 0;
2324
void retain() {
2425
++refcount;
2526
}

Utils.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#pragma once
22

3-
#include "ATen/CPUGenerator.h"
4-
53
namespace at {
64

75
#define AT_ASSERT(cond, ...) if (! (cond) ) { at::runtime_error(__VA_ARGS__); }
@@ -21,12 +19,4 @@ static inline T* checked_cast(Base* expr, const char * name, int pos) {
2119
T::typeString(),expr->type().toString(),pos,name);
2220
}
2321

24-
struct CPUGenerator;
25-
struct Generator;
26-
static inline CPUGenerator * check_generator(Generator* expr) {
27-
if(auto result = dynamic_cast<CPUGenerator*>(expr))
28-
return result;
29-
runtime_error("Expected a 'CPUGenerator' but found 'CUDAGenerator'");
30-
}
31-
3222
} // at

gen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import function_wrapper
1010
import dispatch_macros
1111
import copy_wrapper
12+
1213
from code_template import CodeTemplate
1314

1415

@@ -197,6 +198,7 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations
197198
top_env['type_registrations'].append(type_register)
198199
top_env['type_headers'].append(
199200
'#include "ATen/{}.h"'.format(env['Type']))
201+
200202
return env
201203

202204

scratch.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

templates/Tensor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ struct Tensor {
109109
template<typename T>
110110
T * data() const;
111111

112+
void * unsafeGetTH() {
113+
return pImpl->unsafeGetTH();
114+
}
115+
112116
//toLongData(), toFloatData() etc.
113117
#define TO_TYPE_DATA(T,name,_) \
114118
T * to##name##Data() const;

templates/TensorDerived.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ int64_t ${Tensor}::dim() {
3232
const char * ${Tensor}::typeString() {
3333
return "${Type}";
3434
}
35+
void * ${Tensor}::unsafeGetTH() {
36+
return tensor;
37+
}
3538

3639
${TensorDenseOrSparse}
3740

templates/TensorDerived.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ struct ${Tensor} : public TensorImpl {
2020
virtual int64_t dim() override;
2121
virtual Scalar localScalar() override;
2222
virtual void assign_(Scalar s) override;
23+
virtual void * unsafeGetTH() override;
2324
static const char * typeString();
2425

2526
//TODO(zach): sort of friend permissions later so this

0 commit comments

Comments
 (0)