Skip to content

Commit 3abe5c8

Browse files
adamlerersoumith
authored andcommitted
Fast transposed copy
1 parent efa913b commit 3abe5c8

File tree

1 file changed

+69
-1
lines changed

1 file changed

+69
-1
lines changed

generic/THTensorCopy.c

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,84 @@
22
#define TH_GENERIC_FILE "generic/THTensorCopy.c"
33
#else
44

5+
int THTensor_(copyTransposeValid)(THTensor *tensor, THTensor *src) {
6+
const int MIN_SZ = 60 * 60;
7+
return THTensor_(isContiguous)(tensor) &&
8+
THTensor_(nDimension)(src) == 2 &&
9+
THTensor_(stride)(src, 0) == 1 &&
10+
THTensor_(stride)(src, 1) == THTensor_(size)(src, 0) &&
11+
THTensor_(nElement)(tensor) >= MIN_SZ;
12+
}
13+
14+
// special case copy where tensor is contiguous and src is a transposed matrix
15+
// This can be generalized to most copies, but it's tricker
16+
void THTensor_(copyTranspose)(THTensor *tensor, THTensor *src) {
17+
#define MIN(x, y) (((x) < (y)) ? (x) : (y))
18+
#define MAX(x, y) (((x) > (y)) ? (x) : (y))
19+
20+
#ifdef TH_REAL_IS_BYTE
21+
const int BLOCK_SZ = 120;
22+
#else
23+
const int BLOCK_SZ = 60;
24+
#endif
25+
26+
THTensor *buf = THTensor_(newWithSize2d)(BLOCK_SZ, BLOCK_SZ);
27+
real *sp = THTensor_(data)(src);
28+
real *rp = THTensor_(data)(tensor);
29+
real *bp = THTensor_(data)(buf);
30+
31+
long NR = THTensor_(size)(src, 0);
32+
long NC = THTensor_(size)(src, 1);
33+
for (long R = 0; R < NR; R += BLOCK_SZ) {
34+
for (long C = 0; C < NC; C += BLOCK_SZ) {
35+
real *spo = sp + R + C * NR;
36+
real *rpo = rp + C + R * NC;
37+
38+
int nr = MIN(NR - R, BLOCK_SZ);
39+
int nc = MIN(NC - C, BLOCK_SZ);
40+
41+
// 1. copy columns from src to buf
42+
for (int c = 0; c < nc; c++) {
43+
memcpy(bp + c * BLOCK_SZ, spo + c * NR, nr * sizeof(real));
44+
}
45+
46+
// 2. transpose buf in place
47+
int rc_max = MAX(nr, nc);
48+
int rc_min = MIN(nr, nc);
49+
for (int r = 0; r < rc_max; r++) {
50+
int end = MIN(r, rc_min);
51+
for (int c = 0; c < end; c++) {
52+
real tmp = bp[r + BLOCK_SZ * c];
53+
bp[r + BLOCK_SZ * c] = bp[r * BLOCK_SZ + c];
54+
bp[r * BLOCK_SZ + c] = tmp;
55+
}
56+
}
57+
58+
// 3. copy rows from buf to dst
59+
for (int r = 0; r < nr; r++) {
60+
memcpy(rpo + r * NC, bp + r * BLOCK_SZ, nc * sizeof(real));
61+
}
62+
}
63+
}
64+
THTensor_(free)(buf);
65+
#undef MIN
66+
#undef MAX
67+
}
68+
569
void THTensor_(copy)(THTensor *tensor, THTensor *src)
670
{
771
if (THTensor_(isContiguous)(tensor) && THTensor_(isContiguous)(src) && THTensor_(nElement)(tensor) == THTensor_(nElement)(src)) {
872
real *sp = THTensor_(data)(src);
973
real *rp = THTensor_(data)(tensor);
1074
ptrdiff_t sz = THTensor_(nElement)(tensor);
1175
#ifndef TH_REAL_IS_HALF
12-
THVector_(copy)(rp, sp, sz);
76+
THVector_(copy)(rp, sp, sz);
1377
#else
1478
memcpy(rp, sp, sz * sizeof(real));
79+
#endif
80+
#ifndef TH_REAL_IS_HALF
81+
} else if (THTensor_(copyTransposeValid)(tensor, src)) {
82+
THTensor_(copyTranspose)(tensor, src);
1583
#endif
1684
} else {
1785
TH_TENSOR_APPLY2(real, tensor, real, src, *tensor_data = *src_data;)

0 commit comments

Comments
 (0)