Skip to content

Commit 4136153

Browse files
nkaretnikovpytorchmergebot
authored andcommitted
[pt2] add SymInt support for tensordot and inner (#100356)
Pull Request resolved: #100356 Approved by: https://github.com/ezyang
1 parent 4582ceb commit 4136153

File tree

4 files changed

+17
-20
lines changed

4 files changed

+17
-20
lines changed

aten/src/ATen/native/Linear.cpp

+14-13
Original file line numberDiff line numberDiff line change
@@ -727,12 +727,12 @@ Tensor bilinear(const Tensor& input1, const Tensor& input2, const Tensor& weight
727727
Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2) {
728728
TORCH_CHECK(dims1.size() == dims2.size(), "both dimension lists should have same length");
729729
TORCH_CHECK(input1.scalar_type() == input2.scalar_type(), "both inputs should have same dtype");
730-
int64_t csize = 1; // total size of the contracted dimensions
730+
SymInt csize = 1; // total size of the contracted dimensions
731731
Tensor t1 = input1;
732732
Tensor t2 = input2;
733733
for (const auto i : c10::irange(dims1.size())) {
734-
int s1 = input1.size(dims1[i]);
735-
int s2 = input2.size(dims2[i]);
734+
SymInt s1 = input1.sym_size(dims1[i]);
735+
SymInt s2 = input2.sym_size(dims2[i]);
736736
if (s2 == 1) { // broadcasted dimensions can be summed right away
737737
t1 = t1.sum(dims1[i], true);
738738
} else if (s1 == 1) {
@@ -746,19 +746,20 @@ Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1,
746746

747747
auto cdims1 = at::dim_list_to_bitset(dims1, input1.dim());
748748
auto cdims2 = at::dim_list_to_bitset(dims2, input2.dim());
749-
std::vector<int64_t> p1, p2, rsizes; // p1, p2: input permutations, rsizes: sizes of the result
749+
std::vector<int64_t> p1, p2; // p1, p2: input permutations
750+
std::vector<SymInt> rsizes; // rsizes: sizes of the result
750751
p1.reserve(input1.dim());
751752
p2.reserve(input2.dim());
752753
rsizes.reserve(input1.dim() + input2.dim() - (int64_t) dims1.size());
753-
int64_t size1 = 1; // number of non-contracted elements in input1
754-
int64_t size2 = 1; // number of non-contracted elements in input2
754+
SymInt size1 = 1; // number of non-contracted elements in input1
755+
SymInt size2 = 1; // number of non-contracted elements in input2
755756

756757
// fill the permutations and compute sizes
757758
for (const auto i : c10::irange(input1.dim())) {
758759
if (! cdims1[i]) {
759760
p1.emplace_back(i);
760-
size1 *= t1.size(i);
761-
rsizes.emplace_back(t1.size(i));
761+
size1 *= t1.sym_size(i);
762+
rsizes.emplace_back(t1.sym_size(i));
762763
}
763764
}
764765
for (const auto x : dims1) {
@@ -770,15 +771,15 @@ Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1,
770771
for (const auto i : c10::irange(input2.dim())) {
771772
if (! cdims2[i]) {
772773
p2.emplace_back(i);
773-
size2 *= t2.size(i);
774-
rsizes.emplace_back(t2.size(i));
774+
size2 *= t2.sym_size(i);
775+
rsizes.emplace_back(t2.sym_size(i));
775776
}
776777
}
777778
// permut and reshape for matrix multiplication
778-
t1 = t1.permute(p1).reshape({size1, csize});
779-
t2 = t2.permute(p2).reshape({csize, size2});
779+
t1 = t1.permute(p1).reshape_symint({size1, csize});
780+
t2 = t2.permute(p2).reshape_symint({csize, size2});
780781
// multiply and reshape to target size
781-
return at::mm(t1, t2).reshape(rsizes);
782+
return at::mm(t1, t2).reshape_symint(rsizes);
782783
}
783784

784785
Tensor &tensordot_out(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2, Tensor& result) {

aten/src/ATen/native/LinearAlgebra.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -1284,11 +1284,11 @@ Tensor inner(const Tensor& self, const Tensor& other) {
12841284

12851285
// Last dimension should match (tensordot does not enforce this)
12861286
TORCH_CHECK(
1287-
self.size(-1) == other.size(-1),
1287+
self.sym_size(-1) == other.sym_size(-1),
12881288
"inner() the last dimension must match on both input tensors but got shapes ",
1289-
self.sizes(),
1289+
self.sym_sizes(),
12901290
" and ",
1291-
other.sizes());
1291+
other.sym_sizes());
12921292

12931293
return at::tensordot(self, other, -1, -1);
12941294
}

test/functorch/test_aotdispatch.py

-2
Original file line numberDiff line numberDiff line change
@@ -2502,7 +2502,6 @@ def forward(self, x):
25022502
xfail('hsplit', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
25032503
xfail('i0', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
25042504
xfail('index_fill', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
2505-
xfail('inner', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
25062505
xfail('kron', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
25072506
xfail('kthvalue', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
25082507
xfail('linalg.cholesky_ex', ''), # could not find kernel for aten.linalg_solve_triangular.default
@@ -2612,7 +2611,6 @@ def forward(self, x):
26122611
xfail('svd', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
26132612
xfail('svd_lowrank', ''), # could not find kernel
26142613
xfail('take_along_dim', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
2615-
xfail('tensordot', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
26162614
xfail('trace', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
26172615
xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/de...
26182616
xfail('_upsample_bilinear2d_aa'), # RuntimeError: isIntList() INTERNAL ASSERT FAILED Expected IntList but got GenericList

test/test_proxy_tensor.py

-2
Original file line numberDiff line numberDiff line change
@@ -1455,7 +1455,6 @@ def f(a, b, c, d, e):
14551455
xfail('histogramdd', ''), # aten._histogramdd_bin_edges.default - couldn't find symbolic meta function/decomposition
14561456
xfail('hsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
14571457
xfail('index_reduce', ''), # Float
1458-
xfail('inner', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
14591458
xfail('isin', ''), # aten.isin.Tensor_Tensor - couldn't find symbolic meta function/decomposition
14601459
xfail('kron', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
14611460
xfail('kthvalue', ''), # aten.kthvalue.default - couldn't find symbolic meta function/decomposition
@@ -1559,7 +1558,6 @@ def f(a, b, c, d, e):
15591558
xfail('stft', ''), # argument 'size' must be tuple of ints, but found element of type torch._C.SymIntNode at...
15601559
xfail('svd_lowrank', ''), # aten.mm.default - couldn't find symbolic meta function/decomposition
15611560
xfail('take_along_dim', ''), # dtype of indices should be Long but got Float
1562-
xfail('tensordot', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
15631561
xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/decomposition
15641562
xfail('vsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
15651563
xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition

0 commit comments

Comments
 (0)