Skip to content

Commit a10a1c9

Browse files
committed
start adding rules to propagate scalar to results
1 parent bb6908e commit a10a1c9

File tree

5 files changed

+31
-18
lines changed

5 files changed

+31
-18
lines changed

Formatting.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -230,9 +230,9 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesi
230230
stream << "[ Tensor (empty) ]";
231231
} else {
232232
Tensor tensor = tensor_.toType(getType(kCPU,kDouble)).contiguous();
233-
if(tensor.ndimension() == 0) {
233+
if(tensor_.ndimension() == 0) {
234234
stream << std::defaultfloat << tensor.data<double>()[0] << std::endl;
235-
stream << "[" << tensor_.pImpl->toString() << " of size {}]";
235+
stream << "[ " << tensor_.pImpl->toString() << "{} ]";
236236
} else if(tensor.ndimension() == 1) {
237237
double scale;
238238
int64_t sz;
@@ -244,17 +244,17 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesi
244244
for(int64_t i = 0; i < tensor.size(0); i++) {
245245
stream << std::setw(sz) << tensor_p[i]/scale << std::endl;
246246
}
247-
stream << "[" << tensor_.pImpl->toString() << " of size " << tensor.size(0) << "]";
247+
stream << "[ " << tensor_.pImpl->toString() << "{" << tensor.size(0) << "} ]";
248248
} else if(tensor.ndimension() == 2) {
249249
__printMatrix(stream, tensor, linesize, 0);
250-
stream << "[" << tensor_.pImpl->toString() <<" of size " << tensor.size(0) << "x" << tensor.size(1) << "]";
250+
stream << "[ " << tensor_.pImpl->toString() << "{" << tensor.size(0) << "," << tensor.size(1) << "} ]";
251251
} else {
252252
__printTensor(stream, tensor, linesize);
253-
stream << "[" << tensor_.pImpl->toString() << " of size " << tensor.size(0);
253+
stream << "[ " << tensor_.pImpl->toString() << "{" << tensor.size(0);
254254
for(int64_t i = 1; i < tensor.ndimension(); i++) {
255-
stream << "x" << tensor.size(i);
255+
stream << "," << tensor.size(i);
256256
}
257-
stream << "]";
257+
stream << "} ]";
258258
}
259259
}
260260
return stream;

TensorImpl.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,14 @@ struct TensorImpl {
3838
bool isScalar() const {
3939
return is_scalar;
4040
}
41-
void setScalar(bool s) {
42-
is_scalar = s;
41+
// this is called by the generated wrapper code when there are conditions
42+
// when this output tensor should be a scalar. e.g. when all inputs
43+
// to a function 'add' were scalars, then condition_when_scalar == true.
44+
// we also prevent this from getting marked as a scalar if it is not
45+
// the right shape afterall.
46+
TensorImpl* maybeScalar(bool condition_when_scalar) {
47+
is_scalar = condition_when_scalar && (dim() == 0 || dim() == 1 && sizes()[0] == 1);
48+
return this;
4349
}
4450

4551
private:

copy_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
throw std::runtime_error("unsupported type in copy");
3333
break;
3434
}
35-
dst.pImpl->setScalar(src.pImpl->isScalar());
35+
dst.pImpl->maybeScalar(src.pImpl->isScalar());
3636
}
3737
""")
3838

function_wrapper.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -318,10 +318,12 @@ def emit_body(env, option):
318318
# referencing another
319319
seen_names = set()
320320
count = 0
321+
scalar_check = None
321322
for arg in option['arguments']:
322323
if is_real_argument_to_wrapper(arg):
323324
count += 1
324-
325+
if arg['type'] == 'THSize*':
326+
scalar_check = '{}.size() == 0'.format(arg['name'])
325327
# only generated checked casts the first time we see it
326328
if not arg['name'] in seen_names and requires_checked_cast(arg):
327329
seen_names.add(arg['name'])
@@ -369,24 +371,32 @@ def emit_body(env, option):
369371

370372
call = prefix + CodeTemplate("${cname}(${derived_actuals})").substitute(env)
371373
ret = option['return']
374+
372375
if ret['kind'] == 'arguments':
373376
body.append(call + ";")
374377
arguments_indices = ret['arguments']
378+
arguments = [option['arguments'][argi]
379+
for argi in arguments_indices]
380+
if scalar_check is not None:
381+
for arg in arguments:
382+
body.append("bool maybe_scalar = {};".format(scalar_check))
383+
body.append("{}_->maybeScalar(maybe_scalar);".format(arg['name']))
375384
if len(arguments_indices) == 1:
376-
arg = option['arguments'][arguments_indices[0]]
385+
arg = arguments[0]
377386
body.append("return {};".format(arg['name']))
378387
else:
379-
arguments = [option['arguments'][argi]
380-
for argi in arguments_indices]
381388
types = [to_return_type(arg, option) for arg in arguments]
382389
# TODO: check for move semantics...
383390
names = [arg['name'] for arg in arguments]
384391
body.append(CodeTemplate("return std::tuple<${types}>(${names});").substitute(
385392
types=types, names=names))
386393
elif ret['kind'] == 'type':
387394
if ret['type'] == 'THTensor*':
395+
maybe_scalar = "->maybeScalar({})".format(scalar_check) \
396+
if scalar_check is not None \
397+
else ""
388398
body.append(CodeTemplate(
389-
"return Tensor(new ${Tensor}(context,${arg_name}),false);").substitute(env, arg_name=call))
399+
"return Tensor((new ${Tensor}(context,${arg_name}))${maybe_scalar},false);").substitute(env, arg_name=call,maybe_scalar=maybe_scalar))
390400
else:
391401
# we using int64_t for long in the API, so correct it here...
392402
if is_actual_return_long(ret):

test/basic.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,6 @@ static void test(Type & type) {
172172
std::cout << "zero-dim: " << std::endl;
173173
Tensor a = type.scalarTensor(4); //type.rand({1});
174174

175-
// TODO: automate
176-
a.pImpl->setScalar(true);
177-
178175
std::cout << a << "dims: " << a.dim() << std::endl;
179176
std::cout << Scalar(a) << std::endl;
180177
Tensor b = type.rand({3,4});

0 commit comments

Comments
 (0)