Skip to content

Commit fb5e40f

Browse files
committed
Merge commit '3f25232aaba44aa4377c7e5ed670587a72f5886e'
2 parents 469969e + 3f25232 commit fb5e40f

File tree

6 files changed

+109
-5
lines changed

6 files changed

+109
-5
lines changed

torch/lib/THC/THCBlas.cu

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,31 @@ void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, long m,
344344
(int)batchCount));
345345
}
346346

347+
#if CUDA_VERSION >= 8000
348+
/*
 * Single-precision strided batched GEMM: validates the integer-sized
 * arguments and forwards the call to cublasSgemmStridedBatched on the
 * current THC stream.
 *
 * state             THC state; supplies the cuBLAS handle and the stream.
 * transa, transb    transpose flags, translated with
 *                   convertTransToCublasOperation.
 * m, n, k           GEMM dimensions; narrowed to int for cuBLAS.
 * alpha, beta       scalars, passed to cuBLAS by address.
 * a, b, c           base pointers of the batched operands.
 * lda, ldb, ldc     leading dimensions; may be rewritten by adjustLd before
 *                   being narrowed to int.
 * strideA/B/C       per-operand strides, forwarded to cuBLAS as given
 *                   (they are not range-checked here).
 * batchCount        number of GEMMs in the batch; narrowed to int.
 *
 * Calls THError when m, n, k, lda, ldb, ldc or batchCount is >= INT_MAX.
 */
void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,
                                    float alpha, const float *a, long lda, long strideA, const float *b, long ldb, long strideB,
                                    float beta, float *c, long ldc, long strideC, long batchCount)
{
  // These arguments are cast to int below, so reject anything that would not fit.
  if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
  {
    // Adjacent string literals are concatenated verbatim, so the first one
    // must end with a space to keep "batchCount" and "with" apart.
    THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount "
            "with the bound [val] <= %d", INT_MAX);
  }

  // NOTE(review): adjustLd is defined elsewhere in this file; it receives the
  // leading dimensions by pointer and presumably normalises them - confirm.
  adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc);
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);

  // Bind the handle to the current THC stream before issuing the call.
  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasSgemmStridedBatched(handle,
                                          opa, opb, (int)m, (int)n, (int)k,
                                          &alpha, a, (int)lda, strideA, b, (int)ldb, strideB, &beta, c, (int)ldc, strideC,
                                          (int)batchCount));
}
370+
#endif
371+
347372
void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, long m, long n, long k,
348373
double alpha, const double *a[], long lda, const double *b[], long ldb,
349374
double beta, double *c[], long ldc, long batchCount)
@@ -366,6 +391,30 @@ void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, long m,
366391
(int)batchCount));
367392
}
368393

394+
#if CUDA_VERSION >= 8000
395+
/*
 * Double-precision strided batched GEMM: validates the integer-sized
 * arguments and forwards the call to cublasDgemmStridedBatched on the
 * current THC stream.
 *
 * state             THC state; supplies the cuBLAS handle and the stream.
 * transa, transb    transpose flags, translated with
 *                   convertTransToCublasOperation.
 * m, n, k           GEMM dimensions; narrowed to int for cuBLAS.
 * alpha, beta       scalars, passed to cuBLAS by address.
 * a, b, c           base pointers of the batched operands.
 * lda, ldb, ldc     leading dimensions; may be rewritten by adjustLd before
 *                   being narrowed to int.
 * strideA/B/C       per-operand strides, forwarded to cuBLAS as given
 *                   (they are not range-checked here).
 * batchCount        number of GEMMs in the batch; narrowed to int.
 *
 * Calls THError when m, n, k, lda, ldb, ldc or batchCount is >= INT_MAX.
 */
void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,
                                    double alpha, const double *a, long lda, long strideA, const double *b, long ldb, long strideB,
                                    double beta, double *c, long ldc, long strideC, long batchCount)
{
  // These arguments are cast to int below, so reject anything that would not fit.
  if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
  {
    // The message names this function (it previously said DgemmBatched), and
    // the first literal ends with a space so the concatenated text reads
    // "batchCount with" rather than "batchCountwith".
    THError("Cublas_DgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount "
            "with the bound [val] <= %d", INT_MAX);
  }

  // NOTE(review): adjustLd is defined elsewhere in this file; it receives the
  // leading dimensions by pointer and presumably normalises them - confirm.
  adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc);
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);

  // Bind the handle to the current THC stream before issuing the call.
  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasDgemmStridedBatched(handle,
                                          opa, opb, (int)m, (int)n, (int)k,
                                          &alpha, a, (int)lda, strideA, b, (int)ldb, strideB, &beta, c, (int)ldc, strideC,
                                          (int)batchCount));
}
416+
#endif
417+
369418
/* Inverse */
370419
void THCudaBlas_Sgetrf(THCState *state, int n, float **a, int lda, int *pivot, int *info, int batchSize) {
371420
if( (n >= INT_MAX) || (lda >= INT_MAX) || (batchSize >= INT_MAX) )

torch/lib/THC/THCBlas.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,14 @@ THC_API void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb,
3131
THC_API void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, long m, long n, long k,
3232
double alpha, const double *a[], long lda, const double *b[], long ldb,
3333
double beta, double *c[], long ldc, long batchCount);
34-
34+
#if CUDA_VERSION >= 8000
35+
THC_API void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,
36+
float alpha, const float *a, long lda, long strideA, const float *b, long ldb, long strideB,
37+
float beta, float *c, long ldc, long strideC, long batchCount);
38+
THC_API void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,
39+
double alpha, const double *a, long lda, long strideA, const double *b, long ldb, long strideB,
40+
double beta, double *c, long ldc, long strideC, long batchCount);
41+
#endif
3542
/* Inverse */
3643
THC_API void THCudaBlas_Sgetrf(THCState *state, int n, float **a, int lda, int *pivot, int *info, int batchSize);
3744
THC_API void THCudaBlas_Dgetrf(THCState *state, int n, double **a, int lda, int *pivot, int *info, int batchSize);

torch/lib/THC/THCNumerics.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ struct THCNumerics<char> {
4444
static inline __host__ __device__ bool eq(char a, char b) { return a == b; }
4545
static inline __host__ __device__ bool ne(char a, char b) { return a != b; }
4646

47+
// Arithmetic negation of a char value; usable from host and device code.
static inline __host__ __device__ char neg(char a) {
  return -a;
}
4748
static inline __host__ __device__ char add(char a, char b) { return a + b; }
4849
static inline __host__ __device__ char mul(char a, char b) { return a * b; }
4950
static inline __host__ __device__ char sub(char a, char b) { return a - b; }
@@ -63,6 +64,7 @@ struct THCNumerics<short> {
6364
static inline __host__ __device__ bool eq(short a, short b) { return a == b; }
6465
static inline __host__ __device__ bool ne(short a, short b) { return a != b; }
6566

67+
// Arithmetic negation of a short value; usable from host and device code.
static inline __host__ __device__ short neg(short a) {
  return -a;
}
6668
static inline __host__ __device__ short add(short a, short b) { return a + b; }
6769
static inline __host__ __device__ short mul(short a, short b) { return a * b; }
6870
static inline __host__ __device__ short sub(short a, short b) { return a - b; }
@@ -82,6 +84,7 @@ struct THCNumerics<int> {
8284
static inline __host__ __device__ bool eq(int a, int b) { return a == b; }
8385
static inline __host__ __device__ bool ne(int a, int b) { return a != b; }
8486

87+
// Arithmetic negation of an int value; usable from host and device code.
static inline __host__ __device__ int neg(int a) {
  return -a;
}
8588
static inline __host__ __device__ int add(int a, int b) { return a + b; }
8689
static inline __host__ __device__ int mul(int a, int b) { return a * b; }
8790
static inline __host__ __device__ int sub(int a, int b) { return a - b; }
@@ -101,6 +104,7 @@ struct THCNumerics<long> {
101104
static inline __host__ __device__ bool eq(long a, long b) { return a == b; }
102105
static inline __host__ __device__ bool ne(long a, long b) { return a != b; }
103106

107+
// Arithmetic negation of a long value; usable from host and device code.
static inline __host__ __device__ long neg(long a) {
  return -a;
}
104108
static inline __host__ __device__ long add(long a, long b) { return a + b; }
105109
static inline __host__ __device__ long mul(long a, long b) { return a * b; }
106110
static inline __host__ __device__ long sub(long a, long b) { return a - b; }

torch/lib/THC/generic/THCTensorMathBlas.cu

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -537,9 +537,10 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
537537

538538
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
539539
// Compute pointers to matrices in each batch.
540+
#if CUDA_VERSION < 8000
540541
size_t matrices_size = num_batches * sizeof(real*);
541542

542-
// Copy pointers to device.
543+
// Copy pointers to device.
543544
const real **d_matrices1, **d_matrices2;
544545
real **d_result_matrices;
545546
THCudaCheck(THCudaMalloc(state, (void**)&d_matrices1, matrices_size));
@@ -558,7 +559,6 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
558559
createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
559560
(const real**)d_result_matrices, THCTensor_(data)(state,result_),
560561
result_->stride[0], num_batches);
561-
562562
#ifdef THC_REAL_IS_FLOAT
563563
THCudaBlas_SgemmBatched(
564564
state,
@@ -592,6 +592,38 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
592592
THCudaFree(state, d_matrices1);
593593
THCudaFree(state, d_matrices2);
594594
THCudaFree(state, d_result_matrices);
595+
596+
#else
597+
#ifdef THC_REAL_IS_FLOAT
598+
THCudaBlas_SgemmStridedBatched(
599+
state,
600+
transpose_batch1,
601+
transpose_batch2,
602+
result_->size[transpose_result ? 2 : 1],
603+
result_->size[transpose_result ? 1 : 2],
604+
batch1_->size[transpose_result ? 1 : 2],
605+
alpha,
606+
THCTensor_(data)(state, batch1_), lda, batch1_->stride[0],
607+
THCTensor_(data)(state, batch2_), ldb, batch2_->stride[0],
608+
beta,
609+
THCTensor_(data)(state, result_), ldc, result_->stride[0],
610+
num_batches);
611+
#elif defined(THC_REAL_IS_DOUBLE)
612+
THCudaBlas_DgemmStridedBatched(
613+
state,
614+
transpose_batch1,
615+
transpose_batch2,
616+
result_->size[transpose_result ? 2 : 1],
617+
result_->size[transpose_result ? 1 : 2],
618+
batch1_->size[transpose_result ? 1 : 2],
619+
alpha,
620+
THCTensor_(data)(state, batch1_), lda, batch1_->stride[0],
621+
THCTensor_(data)(state, batch2_), ldb, batch2_->stride[0],
622+
beta,
623+
THCTensor_(data)(state, result_), ldc, result_->stride[0],
624+
num_batches);
625+
#endif
626+
#endif
595627

596628
#elif defined(THC_REAL_IS_HALF)
597629
// Currently no HgemmBatched in Cublas

torch/lib/THC/generic/THCTensorMathPointwise.cu

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(rsqrt, THCNumerics<real>::rsqrt, Real)
4646
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( ceil, THCNumerics<real>::ceil, Real)
4747
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(floor, THCNumerics<real>::floor, Real)
4848
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(trunc, THCNumerics<real>::trunc, Real)
49-
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( neg, THCNumerics<real>::neg, Real)
5049

5150
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( acos, THCNumerics<real>::acos, Real)
5251
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( cosh, THCNumerics<real>::cosh, Real)
@@ -61,6 +60,13 @@ IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( cinv, THCNumerics<real>::cinv, Real)
6160

6261
#endif
6362

63+
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) || \
64+
defined(THC_REAL_IS_SHORT) || defined(THC_REAL_IS_INT) || defined(THC_REAL_IS_LONG)
65+
66+
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( neg, THCNumerics<real>::neg, Real)
67+
68+
#endif
69+
6470
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( abs, THCNumerics<real>::abs, Real)
6571

6672
#undef IMPLEMENT_CUDA_TENSOR_BASIC_FUNC_

torch/lib/THC/generic/THCTensorMathPointwise.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,17 @@ THC_API void THCTensor_(trunc)(THCState *state, THCTensor *self, THCTensor *src)
3030
THC_API void THCTensor_(frac)(THCState *state, THCTensor *self, THCTensor *src);
3131
THC_API void THCTensor_(lerp)(THCState *state, THCTensor *result, THCTensor *a, THCTensor *b, real w);
3232

33-
THC_API void THCTensor_(neg)(THCState *state, THCTensor *self, THCTensor *src);
3433
THC_API void THCTensor_(cinv)(THCState *state, THCTensor *self, THCTensor *src);
3534

3635
#endif
3736

37+
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) || \
38+
defined(THC_REAL_IS_SHORT) || defined(THC_REAL_IS_INT) || defined(THC_REAL_IS_LONG)
39+
40+
THC_API void THCTensor_(neg)(THCState *state, THCTensor *self, THCTensor *src);
41+
42+
#endif
43+
3844
THC_API void THCTensor_(abs)(THCState *state, THCTensor *self, THCTensor *src);
3945
THC_API void THCTensor_(sign)(THCState *state, THCTensor *self, THCTensor *src);
4046
THC_API void THCTensor_(clamp)(THCState *state, THCTensor *self, THCTensor *src, real min_value, real max_value);

0 commit comments

Comments
 (0)