
Commit 5e078bb

Scalar flags are added and used to dispatch to the Scalar variant of a function when one exists. Broadcast annotations are used to figure out when the scalar in an expression like s + A should also be converted (expanded to the tensor's shape).
1 parent 278cbba commit 5e078bb
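To make the dispatch concrete: for a hypothetical overload pair add(Tensor self, Tensor other) / add(Tensor self, Scalar other), the ZERO_DIM_CHECK template added below would emit roughly this at the top of the generated Tensor-Tensor wrapper (a sketch, not verbatim generator output):

    // if 'other' is internally flagged as 0-dim, re-dispatch to the Scalar overload
    if(other.dim() == 0) {
      return add(self, other.scalar());
    }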

File tree

10 files changed: +150 −23 lines


Formatting.cpp

Lines changed: 2 additions & 3 deletions
@@ -231,9 +231,8 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesize)
   Tensor tensor = tensor_.toType(getType(kCPU,kDouble)).contiguous();
   if(tensor.ndimension() == 0) {
     stream << std::defaultfloat << tensor.data<double>()[0] << std::endl;
-    // stream << "[Tensor<" << typedesc() << "," << devicedesc() << "> (value)]";
-  }
-  else if(tensor.ndimension() == 1) {
+    stream << "[" << tensor_.pImpl->toString() << " of size {}]";
+  } else if(tensor.ndimension() == 1) {
     double scale;
     int64_t sz;
     std::tie(scale, sz) = __printFormat(stream, tensor);

THLongStorageView.h

Lines changed: 15 additions & 6 deletions
@@ -5,24 +5,33 @@
 namespace at {

 // make a fake storage out of a size, pointer pair...
+// used as an argument where THSize and THStride are passed into TH
 class THLongStorageView {
 public:
-  static THLongStorageView make(ArrayRef<int64_t> ref) {
-    return THLongStorageView(ref);
+  static THLongStorageView make(ArrayRef<int64_t> ref, bool zero_dim_to_one = false) {
+    return THLongStorageView(ref, zero_dim_to_one);
   }
   operator THLongStorage*() {
     return &storage;
   }
 private:
-  THLongStorageView(ArrayRef<int64_t> ref) {
-    storage.data = (long*)(ref.data());
-    storage.size = ref.size();
+  THLongStorageView(ArrayRef<int64_t> ref, bool zero_dim_to_one) {
+    if(zero_dim_to_one && ref.size() == 0) {
+      // turn a size-0 storage into a length-1 storage holding the value 1,
+      // so that our 0-dim tensors get allocated as 1-dim inside TH
+      one = 1;
+      storage.data = &one;
+      storage.size = 1;
+    } else {
+      storage.data = (long*)(ref.data());
+      storage.size = ref.size();
+    }
     storage.refcount = 0;
     storage.flag = 0;
     storage.allocator = nullptr;
     storage.allocatorContext = nullptr;
   }
-
+  long one;
   THLongStorage storage;
 };
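A minimal usage sketch (variable names are illustrative; assumes only the header above): with zero_dim_to_one set, an empty size list reaches TH as the one-element size {1}:

    ArrayRef<int64_t> sizes;  // empty: the logical sizes of a 0-dim tensor
    auto view = THLongStorageView::make(sizes, /*zero_dim_to_one=*/true);
    THLongStorage * th_size = view;  // via the conversion operator; 'view' must outlive 'th_size'
    // th_size->size == 1 and th_size->data[0] == 1, so TH allocates a 1-dim tensor of size {1}

This lets TH, which cannot represent 0-dim tensors, allocate real storage while ATen tracks the 0-dim view on top.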

TensorImpl.h

Lines changed: 17 additions & 3 deletions
@@ -10,14 +10,15 @@ namespace at {
 class Type;
 struct TensorImpl {
   TensorImpl(Type * type)
-  : type_(type), refcount(1) {}
+  : type_(type), refcount(1), is_scalar(false) {}
   Type & type() const {
     return *type_;
   }
-
   virtual const char * toString() const = 0;
   virtual IntList sizes() = 0;
   virtual IntList strides() = 0;
+  virtual int64_t dim() = 0;
+  virtual Scalar localScalar() = 0;
   void retain() {
     ++refcount;
   }
@@ -27,10 +28,23 @@
     }
   }
   virtual ~TensorImpl() {}
-
   friend class Type;
+
+  // 0-dim patchup of TH requires us to have a flag marking
+  // whether a Tensor should be treated as 0-dim.
+  // The generated wrapper manipulates this flag.
+  // The setter should never be exposed in Tensor's public API,
+  // because eventually we would like isScalar() to just be dim() == 0.
+  bool isScalar() const {
+    return is_scalar;
+  }
+  void setScalar(bool s) {
+    is_scalar = s;
+  }
+
 private:
   std::atomic<int> refcount;
+  bool is_scalar;
   Type * type_;
 };

copy_wrapper.py

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@
       throw std::runtime_error("unsupported type in copy");
       break;
   }
+  dst.pImpl->setScalar(src.pImpl->isScalar());
 }
 """)
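The TH copy routines only ever see the 1-dim representation, so the generated copy wrappers must re-propagate the 0-dim flag from source to destination; without this line, copying a scalar tensor would silently produce a 1-dim tensor of size {1}.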

function_wrapper.py

Lines changed: 36 additions & 7 deletions
@@ -45,6 +45,18 @@
 }
 """)

+ZERO_DIM_CHECK = CodeTemplate("""\
+if(${check_name}.dim() == 0) {
+    return ${method_prefix}${api_name}(${zero_dim_actuals});
+}""")
+
+SCALAR_EXPAND = CodeTemplate("""\
+Tensor ${name}__;
+if(${name}_->isScalar()) {
+    ${name}__ = ${name}.expand(${other}.sizes());
+    ${name}_ = static_cast<${Tensor}*>(${name}__.pImpl);
+}
+""")

 class NYIError(Exception):
     """Indicates we don't support this declaration yet"""
@@ -83,8 +95,8 @@ def __init__(self, reason):
     'THIntegerTensor*': CodeTemplate('checked_cast<${Backend}IntTensor>(${arg_name}.pImpl,"${arg_name}",${arg_pos})'),
     'THStorage*': CodeTemplate('checked_cast<${Storage}>(&${arg_name},"${arg_name}",${arg_pos})'),
     'THGenerator*': CodeTemplate('check_generator(&${arg_name})'),
-    'THSize*': CodeTemplate('THLongStorageView::make(${arg_name})'),
-    'THStride*': CodeTemplate('THLongStorageView::make(${arg_name})'),
+    'THSize*': CodeTemplate('THLongStorageView::make(${arg_name},true)'),
+    'THStride*': CodeTemplate('THLongStorageView::make(${arg_name},true)'),
     'real': CodeTemplate('${arg_name}.to${ScalarName}()'),
     'accreal': CodeTemplate('${arg_name}.to${AccScalarName}()'),

@@ -290,16 +302,27 @@ def is_actual_return_long(ret):
     return ret['type'] == 'long' or (backend_type_env['ScalarName'] == 'Long' and
                                      ret['type'] == 'real' or ret['type'] == 'accreal')

+def handle_zero_dim(env, option):
+    if 'zero_dim_dispatch_when_scalar' not in option:
+        return []
+    check_name = option['zero_dim_dispatch_when_scalar']
+    zero_dim_actuals = [arg['name']
+                        if arg['name'] != check_name else arg['name'] + '.scalar()'
+                        for arg in option['formals_list']]
+    return [ZERO_DIM_CHECK.substitute(env, check_name=check_name, zero_dim_actuals=zero_dim_actuals)]
+
 def emit_body(env, option):
     body = []
+    body += handle_zero_dim(env, option)
     # arguments are potentially duplicated because of one argument
     # referencing another
     seen_names = set()
-    # only generated checked casts the first time we see it
     count = 0
     for arg in option['arguments']:
         if is_real_argument_to_wrapper(arg):
             count += 1
+
+        # only generate checked casts the first time we see an argument
         if not arg['name'] in seen_names and requires_checked_cast(arg):
             seen_names.add(arg['name'])
             if arg.get('allocate', False):
@@ -326,19 +349,25 @@ def emit_body(env, option):
                             arg['name'], ','.join(dims)))
                 if arg.get('cpu_zero', False):
                     body.append("{}.zero_();".format(arg['name']))
-
-    option['actuals'] = get_arguments(option)
+            # handle scalars that occur on the LHS of things like a - b
+            if 'broadcast' in arg and 'inplace' not in arg['broadcast']:
+                other = arg['broadcast'].split(' ')[0].split(',')[0]
+                body.append(SCALAR_EXPAND.substitute(env,
+                                                     name=arg['name'],
+                                                     other=other))
+
+    option['derived_actuals'] = get_arguments(option)
     is_cuda = backend_type_env['Backend'] == 'CUDA'
     is_nn = option['mode'] == 'NN'
     if is_cuda or is_nn:
-        option['actuals'] = ['context->thc_state'] + option['actuals']
+        option['derived_actuals'] = ['context->thc_state'] + option['derived_actuals']

     if is_nn:
         prefix = 'THNN_{}'.format(env['THType'])
     else:
         prefix = env['THTensor'] + '_'

-    call = prefix + CodeTemplate("${cname}(${actuals})").substitute(env)
+    call = prefix + CodeTemplate("${cname}(${derived_actuals})").substitute(env)
     ret = option['return']
     if ret['kind'] == 'arguments':
         body.append(call + ";")
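For illustration, hand-substituting SCALAR_EXPAND with name=self, other=other, and Tensor=CPUFloatTensor (an assumed backend) gives roughly:

    Tensor self__;
    if(self_->isScalar()) {
      self__ = self.expand(other.sizes());                 // broadcast the scalar up to other's shape
      self_ = static_cast<CPUFloatTensor*>(self__.pImpl);  // the TH call now sees the expanded impl
    }

The extra self__ temporary keeps the expanded tensor alive for the duration of the call while self_ is repointed at its implementation.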

preprocess_declarations.py

Lines changed: 41 additions & 1 deletion
@@ -2,6 +2,7 @@
 from copy import deepcopy
 from function_wrapper import TYPE_FORMAL_GENERIC
 import common_with_cwrap
+import yaml

 type_map = {
     'floating_point': [
@@ -54,7 +55,7 @@ def expand(pair):


 def exclude(declaration):
-    return 'only_register' in declaration
+    return 'only_register' in declaration or declaration.get('python_name') == 'ndimension'


 def add_variants(option):
@@ -109,6 +110,43 @@ def sanitize_return(option):
 def set_mode(option):
     option['mode'] = option.get('mode', 'TH')

+# To enable 0-dim support in TH operations, we find every option where
+# replacing a single Scalar argument with a Tensor still matches a valid
+# function. We then mark the Tensor variant with a key
+# zero_dim_dispatch_when_scalar: name, where 'name' is the argument that
+# should be a scalar. During dispatch, if that argument is internally
+# marked as holding a scalar, the method dispatches to the Scalar variant.
+def discover_zero_dim_tensor_operations(declaration):
+    def exclude(arg):
+        return arg.get('ignore_check')
+
+    def signature(option, i=None, value=None):
+        elements = [TYPE_FORMAL_GENERIC.get(arg['type'], arg['type'])
+                    if i is None or j != i else value
+                    for j, arg in enumerate(option['arguments'])
+                    if not exclude(arg)]
+        return '#'.join(elements)
+    signature_to_option = {signature(option): option
+                           for option in declaration['options']}
+
+    for option in declaration['options']:
+        for i, arg in enumerate(option['arguments']):
+            if arg['type'] == 'real':
+                signature_of_tensor_version = signature(option, i, 'Tensor &')
+                if signature_of_tensor_version in signature_to_option:
+                    tensor_version = \
+                        signature_to_option[signature_of_tensor_version]
+                    names = [arg['name'] for arg in tensor_version['arguments']
+                             if not exclude(arg)]
+                    tensor_version['zero_dim_dispatch_when_scalar'] = names[i]
+                    print("FOUND " + str(i))
+                    print("Scalar Version ===== ")
+                    print(yaml.dump(option))
+                    print("Tensor Version ===== ")
+                    print(yaml.dump(tensor_version))
+                    print("SHARED " + names[i])
+

 def run(declarations):
     declarations = [d for d in declarations if not exclude(d)]
@@ -120,6 +158,8 @@ def run(declarations):
             type_to_signature=TYPE_FORMAL_GENERIC,
             remove_self=True)
         common_with_cwrap.sort_by_number_of_options(declaration)
+        discover_zero_dim_tensor_operations(declaration)
+
         new_options = []
         for option in declaration['options']:
             set_mode(option)
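For example (declaration and argument names here are illustrative): given the options add(Tensor self, real value) and add(Tensor self, Tensor other), substituting 'Tensor &' for position 1 of the real option yields the same '#'-joined signature as the Tensor option, so the Tensor variant is marked with zero_dim_dispatch_when_scalar: other. The print statements are debug output showing each discovered pair.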

templates/Tensor.h

Lines changed: 10 additions & 3 deletions
@@ -76,6 +76,15 @@ struct Tensor {
   IntList strides() const {
     return pImpl->strides();
   }
+  int64_t dim() const {
+    return pImpl->dim();
+  }
+  int64_t ndimension() const {
+    return dim();
+  }
+  Scalar scalar() const {
+    return pImpl->localScalar();
+  }
   Type & type() const {
     return pImpl->type();
   }
@@ -95,9 +104,7 @@ struct Tensor {
   Tensor toBackend(Backend b) {
     return toType(type().toBackend(b));
   }
-  int64_t dim() const {
-    return ndimension();
-  }
+
   template<typename T>
   T * data() const;
templates/TensorDerived.cpp

Lines changed: 13 additions & 0 deletions
@@ -54,11 +54,24 @@ const char * ${Tensor}::toString() const {
 }

 IntList ${Tensor}::sizes() {
+  if(isScalar())
+    return IntList();
   return IntList(reinterpret_cast<int64_t*>(tensor->size),tensor->nDimension);
 }
 IntList ${Tensor}::strides() {
+  if(isScalar())
+    return IntList();
   return IntList(reinterpret_cast<int64_t*>(tensor->stride),tensor->nDimension);
 }
+int64_t ${Tensor}::dim() {
+  if(isScalar())
+    return 0;
+  return ${THTensor}_nDimension(${state,}tensor);
+}
+Scalar ${Tensor}::localScalar() {
+  AT_ASSERT(isScalar(),"localScalar() called on Tensor with %d dims",sizes().size());
+  return Scalar(${to_at_half}(${THTensor}_get1d(${state,}tensor, 0)));
+}

 const char * ${Tensor}::typeString() {
   return "${Type}";

templates/TensorDerived.h

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,8 @@ struct ${Tensor} : public TensorImpl {
   virtual const char * toString() const override;
   virtual IntList sizes() override;
   virtual IntList strides() override;
+  virtual int64_t dim() override;
+  virtual Scalar localScalar() override;
   static const char * typeString();

   //TODO(zach): sort of friend permissions later so this

test/basic.cpp

Lines changed: 13 additions & 0 deletions
@@ -35,6 +35,7 @@ static void test(Type & type) {
   {
     std::cout << "sort:" << std::endl;
     Tensor b = type.rand({3, 4});
+
     std::cout << b << std::endl;
     auto z = b.sort(1);
     std::cout << std::get<0>(z) << std::endl;
@@ -167,6 +168,18 @@ static void test(Type & type) {
     //std::cout << select(select(a, 1, 3), 0, 2) << std::endl;
   }

+  {
+    std::cout << "zero-dim: " << std::endl;
+    Tensor a = type.rand({1});
+    // TODO
+    a.pImpl->setScalar(true);
+    std::cout << a << "dims: " << a.dim() << std::endl;
+    std::cout << a.scalar() << std::endl;
+    Tensor b = type.rand({3,4});
+    std::cout << b + a << std::endl;
+    std::cout << a + b << std::endl;
+  }
+
 }

 int main()
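The last two prints should exercise both new paths: b + a dispatches on the 0-dim right-hand argument, while a + b goes through the SCALAR_EXPAND broadcast handling for a scalar on the left-hand side.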
