#define TH_GENERIC_FILE "generic/THTensorCopy.c"
#else

+int THTensor_(copyTransposeValid)(THTensor *tensor, THTensor *src) {
+  const int MIN_SZ = 60 * 60;
+  return THTensor_(isContiguous)(tensor) &&
+         THTensor_(nDimension)(src) == 2 &&
+         THTensor_(stride)(src, 0) == 1 &&
+         THTensor_(stride)(src, 1) == THTensor_(size)(src, 0) &&
+         THTensor_(nElement)(tensor) >= MIN_SZ;
+}
+
+// Special-case copy where tensor is contiguous and src is a transposed matrix.
+// This can be generalized to most copies, but it's trickier.
+void THTensor_(copyTranspose)(THTensor *tensor, THTensor *src) {
+  #define MIN(x, y) (((x) < (y)) ? (x) : (y))
+  #define MAX(x, y) (((x) > (y)) ? (x) : (y))
+
+#ifdef TH_REAL_IS_BYTE
+  const int BLOCK_SZ = 120;
+#else
+  const int BLOCK_SZ = 60;
+#endif
+
+  THTensor *buf = THTensor_(newWithSize2d)(BLOCK_SZ, BLOCK_SZ);
+  real *sp = THTensor_(data)(src);
+  real *rp = THTensor_(data)(tensor);
+  real *bp = THTensor_(data)(buf);
+
+  long NR = THTensor_(size)(src, 0);
+  long NC = THTensor_(size)(src, 1);
+  for (long R = 0; R < NR; R += BLOCK_SZ) {
+    for (long C = 0; C < NC; C += BLOCK_SZ) {
+      real *spo = sp + R + C * NR;
+      real *rpo = rp + C + R * NC;
+
+      int nr = MIN(NR - R, BLOCK_SZ);
+      int nc = MIN(NC - C, BLOCK_SZ);
+
+      // 1. copy columns from src to buf
+      for (int c = 0; c < nc; c++) {
+        memcpy(bp + c * BLOCK_SZ, spo + c * NR, nr * sizeof(real));
+      }
+
+      // 2. transpose buf in place
+      int rc_max = MAX(nr, nc);
+      int rc_min = MIN(nr, nc);
+      for (int r = 0; r < rc_max; r++) {
+        int end = MIN(r, rc_min);
+        for (int c = 0; c < end; c++) {
+          real tmp = bp[r + BLOCK_SZ * c];
+          bp[r + BLOCK_SZ * c] = bp[r * BLOCK_SZ + c];
+          bp[r * BLOCK_SZ + c] = tmp;
+        }
+      }
+
+      // 3. copy rows from buf to dst
+      for (int r = 0; r < nr; r++) {
+        memcpy(rpo + r * NC, bp + r * BLOCK_SZ, nc * sizeof(real));
+      }
+    }
+  }
+  THTensor_(free)(buf);
+  #undef MIN
+  #undef MAX
+}
+
void THTensor_(copy)(THTensor *tensor, THTensor *src)
{
  if (THTensor_(isContiguous)(tensor) && THTensor_(isContiguous)(src) && THTensor_(nElement)(tensor) == THTensor_(nElement)(src)) {
    real *sp = THTensor_(data)(src);
    real *rp = THTensor_(data)(tensor);
    ptrdiff_t sz = THTensor_(nElement)(tensor);
#ifndef TH_REAL_IS_HALF
-    THVector_(copy)(rp, sp, sz);
+    THVector_(copy)(rp, sp, sz);
#else
    memcpy(rp, sp, sz * sizeof(real));
+#endif
+#ifndef TH_REAL_IS_HALF
+  } else if (THTensor_(copyTransposeValid)(tensor, src)) {
+    THTensor_(copyTranspose)(tensor, src);
#endif
  } else {
    TH_TENSOR_APPLY2(real, tensor, real, src, *tensor_data = *src_data;)
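
Outside the TH macro machinery, the same cache-blocking idea can be sketched in a few lines of plain C. The snippet below is only an illustration of the technique, not the patch itself: the names copy_transpose_blocked, NR, and NC, the float element type, and the fixed 60-element block are assumptions made for the example, and the patch's in-place transpose of the staging tile is folded into the write-back loop for brevity.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define BLOCK_SZ 60
#define MIN(x, y) (((x) < (y)) ? (x) : (y))

/* Copy a column-major NR x NC matrix src into a row-major matrix dst,
 * staging one BLOCK_SZ x BLOCK_SZ tile at a time so that both the strided
 * reads from src and the contiguous writes to dst stay cache-friendly. */
static void copy_transpose_blocked(float *dst, const float *src, long NR, long NC) {
  float buf[BLOCK_SZ * BLOCK_SZ];
  for (long R = 0; R < NR; R += BLOCK_SZ) {
    for (long C = 0; C < NC; C += BLOCK_SZ) {
      long nr = MIN(NR - R, BLOCK_SZ);
      long nc = MIN(NC - C, BLOCK_SZ);
      /* 1. gather the tile: column C+c of src becomes row c of buf */
      for (long c = 0; c < nc; c++)
        memcpy(buf + c * BLOCK_SZ, src + R + (C + c) * NR, nr * sizeof(float));
      /* 2. write the tile out row by row, reading buf transposed */
      for (long r = 0; r < nr; r++)
        for (long c = 0; c < nc; c++)
          dst[(R + r) * NC + (C + c)] = buf[c * BLOCK_SZ + r];
    }
  }
}

int main(void) {
  const long NR = 130, NC = 75;
  float *src = malloc(NR * NC * sizeof(float));
  float *dst = malloc(NR * NC * sizeof(float));
  if (!src || !dst) return 1;
  /* fill src in column-major order: element (r, c) lives at r + c * NR */
  for (long c = 0; c < NC; c++)
    for (long r = 0; r < NR; r++)
      src[r + c * NR] = (float)(r * NC + c);
  copy_transpose_blocked(dst, src, NR, NC);
  /* dst should now be the row-major copy: element (r, c) at r * NC + c */
  for (long i = 0; i < NR * NC; i++)
    if (dst[i] != (float)i) { printf("mismatch at %ld\n", i); return 1; }
  printf("ok\n");
  free(src);
  free(dst);
  return 0;
}

The staging buffer is what buys the speedup: copying a transposed matrix in destination order touches a new cache line for almost every element read, whereas gathering a square tile first keeps those strided reads inside a buffer that fits in cache. That is presumably also why the patch uses a larger block for byte tensors (120 vs. 60, since more elements fit per cache line), and why copyTransposeValid only takes this path for tensors of at least 60 * 60 elements, so small copies skip the extra buffering.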