@@ -727,12 +727,12 @@ Tensor bilinear(const Tensor& input1, const Tensor& input2, const Tensor& weight
 Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2) {
   TORCH_CHECK(dims1.size() == dims2.size(), "both dimension lists should have same length");
   TORCH_CHECK(input1.scalar_type() == input2.scalar_type(), "both inputs should have same dtype");
-  int64_t csize = 1;  // total size of the contracted dimensions
+  SymInt csize = 1;  // total size of the contracted dimensions
   Tensor t1 = input1;
   Tensor t2 = input2;
   for (const auto i : c10::irange(dims1.size())) {
-    int s1 = input1.size(dims1[i]);
-    int s2 = input2.size(dims2[i]);
+    SymInt s1 = input1.sym_size(dims1[i]);
+    SymInt s2 = input2.sym_size(dims2[i]);
     if (s2 == 1) { // broadcasted dimensions can be summed right away
       t1 = t1.sum(dims1[i], true);
     } else if (s1 == 1) {
@@ -746,19 +746,20 @@ Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1,
 
   auto cdims1 = at::dim_list_to_bitset(dims1, input1.dim());
   auto cdims2 = at::dim_list_to_bitset(dims2, input2.dim());
-  std::vector<int64_t> p1, p2, rsizes;  // p1, p2: input permutations, rsizes: sizes of the result
+  std::vector<int64_t> p1, p2;  // p1, p2: input permutations
+  std::vector<SymInt> rsizes;  // rsizes: sizes of the result
   p1.reserve(input1.dim());
   p2.reserve(input2.dim());
   rsizes.reserve(input1.dim() + input2.dim() - (int64_t) dims1.size());
-  int64_t size1 = 1;  // number of non-contracted elements in input1
-  int64_t size2 = 1;  // number of non-contracted elements in input2
+  SymInt size1 = 1;  // number of non-contracted elements in input1
+  SymInt size2 = 1;  // number of non-contracted elements in input2
 
   // fill the permutations and compute sizes
   for (const auto i : c10::irange(input1.dim())) {
     if (!cdims1[i]) {
       p1.emplace_back(i);
-      size1 *= t1.size(i);
-      rsizes.emplace_back(t1.size(i));
+      size1 *= t1.sym_size(i);
+      rsizes.emplace_back(t1.sym_size(i));
     }
   }
   for (const auto x : dims1) {
@@ -770,15 +771,15 @@ Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1,
   for (const auto i : c10::irange(input2.dim())) {
     if (!cdims2[i]) {
       p2.emplace_back(i);
-      size2 *= t2.size(i);
-      rsizes.emplace_back(t2.size(i));
+      size2 *= t2.sym_size(i);
+      rsizes.emplace_back(t2.sym_size(i));
     }
   }
   // permut and reshape for matrix multiplication
-  t1 = t1.permute(p1).reshape({size1, csize});
-  t2 = t2.permute(p2).reshape({csize, size2});
+  t1 = t1.permute(p1).reshape_symint({size1, csize});
+  t2 = t2.permute(p2).reshape_symint({csize, size2});
   // multiply and reshape to target size
-  return at::mm(t1, t2).reshape(rsizes);
+  return at::mm(t1, t2).reshape_symint(rsizes);
 }
 
 Tensor &tensordot_out(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2, Tensor& result) {
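
The patched function keeps the same strategy as before (permute the non-contracted dimensions to the outside, flatten both inputs to 2-D, run a single at::mm, then reshape to the result sizes); the change is only that sizes flow through SymInt/sym_size/reshape_symint so the contraction works under symbolic shapes. The standalone sketch below is not part of the patch: it assumes an ordinary ATen build, the shapes and variable names (a, b, m1, m2) are made up for illustration, and it compares at::tensordot against the same contraction done "by hand" with the permute -> reshape -> mm -> reshape pipeline.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Contract dim 2 of `a` with dim 0 of `b`, and dim 1 of `a` with dim 1 of `b`.
  at::Tensor a = at::randn({2, 3, 4});
  at::Tensor b = at::randn({4, 3, 5});

  // Library entry point implemented by the function in the diff.
  at::Tensor out = at::tensordot(a, b, /*dims1=*/{2, 1}, /*dims2=*/{0, 1});

  // Same contraction by hand: non-contracted dims of `a` go first, contracted
  // dims last; for `b` the contracted dims go first, in the matching order.
  at::Tensor m1 = a.permute({0, 2, 1}).reshape({2, 4 * 3});   // [size1, csize]
  at::Tensor m2 = b.permute({0, 1, 2}).reshape({4 * 3, 5});   // [csize, size2]
  at::Tensor ref = at::mm(m1, m2).reshape({2, 5});            // rsizes

  std::cout << out.sizes() << "  max|diff| = "
            << (out - ref).abs().max().item<double>() << std::endl;  // [2, 5], ~0
  return 0;
}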