Skip to content

Commit b8d0c7f

Browse files
committed
checked cast does it all
1 parent f4c502e commit b8d0c7f

File tree

3 files changed

+7
-12
lines changed

3 files changed

+7
-12
lines changed

Local.cwrap

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,5 @@
8080
- TensorList tensors
8181
- int dim
8282
aten_custom_call: |
83-
size_t inputs_ = tensors_.size();
84-
std::vector<${THTensor}*> t_;
85-
t_.reserve(inputs_);
86-
for (unsigned int i = 0; i < inputs_; ++i) {
87-
t_.push_back(tensors_[i]->tensor);
88-
}
89-
${THTensor}_catArray(${state,}self_->tensor, t_.data(), inputs_, dim);
83+
${THTensor}_catArray(${state,}self_->tensor, tensors_.data(), tensors_.size(), dim);
9084
]]

Utils.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ static inline T* checked_cast(Base* expr, const char * name, int pos) {
2121
T::typeString(),expr->type().toString(),pos,name);
2222
}
2323

24-
template <typename T, typename TBase>
25-
static inline ArrayRef<T*> tensor_list_checked_cast(ArrayRef<TBase> tensors, const char * name, int pos) {
26-
std::vector<T*> casted(tensors.size());
24+
// Converts a TensorList (i.e. ArrayRef<Tensor> to the underlying TH* Tensor Pointer)
25+
template <typename T, typename TBase, typename TH>
26+
static inline std::vector<TH*> tensor_list_checked_cast(ArrayRef<TBase> tensors, const char * name, int pos) {
27+
std::vector<TH*> casted(tensors.size());
2728
for (unsigned int i = 0; i < tensors.size(); ++i) {
2829
auto *expr = tensors[i].pImpl;
2930
if (!expr) {
@@ -33,7 +34,7 @@ static inline ArrayRef<T*> tensor_list_checked_cast(ArrayRef<TBase> tensors, con
3334
}
3435
auto result = dynamic_cast<T*>(expr);
3536
if (result) {
36-
casted.push_back(result);
37+
casted.push_back(result->tensor);
3738
} else {
3839
runtime_error("Expected a Tensor of type %s but found a type %s for sequence element %u "
3940
" in sequence argument at position #%d '%s'",

function_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def __init__(self, reason):
9999
'THStride*': CodeTemplate('THLongStorageView::make(${arg_name},true)'),
100100
'real': CodeTemplate('${arg_name}.to${ScalarName}()'),
101101
'accreal': CodeTemplate('${arg_name}.to${AccScalarName}()'),
102-
'TensorList': CodeTemplate('tensor_list_checked_cast<${Tensor}>(${arg_name},"${arg_name}",${arg_pos})'),
102+
'TensorList': CodeTemplate('tensor_list_checked_cast<${Tensor}, Tensor, ${THTensor}>(${arg_name},"${arg_name}",${arg_pos})'),
103103
}
104104

105105
CHECKED_USE = {

0 commit comments

Comments
 (0)