Skip to content

Commit c39d48e

Browse files
adamlerersoumith
authored andcommitted
Fast transposed copy
1 parent 05bc877 commit c39d48e

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

test/test_torch.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3230,6 +3230,14 @@ def test_Size(self):
32303230
self.assertIsInstance(x[:-1], torch.Size)
32313231
self.assertIsInstance(x + x, torch.Size)
32323232

3233+
# unit test for THTensor_(copyTranspose)
3234+
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
3235+
def test_big_transpose(self):
3236+
t = torch.rand(456, 789)
3237+
t1 = t.t().contiguous()
3238+
t2 = torch.from_numpy(t.numpy().transpose())
3239+
self.assertEqual(t1, t2)
3240+
32333241
# Functions to test negative dimension wrapping
32343242
METHOD = 1
32353243
INPLACE_METHOD = 2

torch/csrc/generic/Tensor.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,8 @@ PyTypeObject THPTensorStatelessType = {
816816
void THPTensor_(initCopyMethods)()
817817
{
818818
auto& h = THTensor_(copy_functions);
819+
// copy from same type
820+
THPInsertCopyFunction(h, &THTensor_(copy));
819821
// copy from CPU types
820822
THPInsertCopyFunction(h, &THTensor_(copyByte));
821823
THPInsertCopyFunction(h, &THTensor_(copyChar));

0 commit comments

Comments
 (0)