@@ -344,6 +344,31 @@ void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, long m,
344344 (int )batchCount));
345345}
346346
#if CUDA_VERSION >= 8000
/*
 * Single-precision strided-batched GEMM: thin wrapper that forwards to
 * cublasSgemmStridedBatched on the current THC BLAS handle and stream.
 *
 * Only compiled when CUDA_VERSION >= 8000.
 *
 * state             THC state used to look up the current cuBLAS handle and stream.
 * transa, transb    transpose flags, translated by convertTransToCublasOperation.
 * m, n, k           problem dimensions; narrowed to int for cuBLAS.
 * alpha, beta       scalars, passed to cuBLAS by address.
 * a, b, c           operand pointers, forwarded unchanged.
 * lda, ldb, ldc     leading dimensions; may be rewritten by adjustLd before
 *                   being narrowed to int.
 * strideA/B/C       per-batch strides, forwarded unchanged (not range-checked here).
 * batchCount        number of batches; narrowed to int for cuBLAS.
 *
 * Any of m, n, k, lda, ldb, ldc, batchCount that is >= INT_MAX is reported
 * through THError. The cuBLAS status is checked with THCublasCheck.
 */
void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,
                                    float alpha, const float *a, long lda, long strideA, const float *b, long ldb, long strideB,
                                    float beta, float *c, long ldc, long strideC, long batchCount)
{
  /* These arguments are cast to int below, so refuse anything at or above INT_MAX. */
  const int dimTooLarge = (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX);
  const int ldTooLarge = (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX);
  const int batchTooLarge = (batchCount >= INT_MAX);

  if (dimTooLarge || ldTooLarge || batchTooLarge) {
    THError(" Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount"
            " with the bound [val] <= %d", INT_MAX);
  }

  /* adjustLd may update the leading dimensions in place; the casts happen afterwards. */
  adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc);
  cublasOperation_t opA = convertTransToCublasOperation(transa);
  cublasOperation_t opB = convertTransToCublasOperation(transb);

  /* Bind the current stream to the handle before issuing the call. */
  cublasHandle_t blasHandle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(blasHandle, THCState_getCurrentStream(state));

  THCublasCheck(cublasSgemmStridedBatched(blasHandle,
                                          opA, opB,
                                          (int)m, (int)n, (int)k,
                                          &alpha,
                                          a, (int)lda, strideA,
                                          b, (int)ldb, strideB,
                                          &beta,
                                          c, (int)ldc, strideC,
                                          (int)batchCount));
}
#endif
371+
347372void THCudaBlas_DgemmBatched (THCState *state, char transa, char transb, long m, long n, long k,
348373 double alpha, const double *a[], long lda, const double *b[], long ldb,
349374 double beta, double *c[], long ldc, long batchCount)
@@ -366,6 +391,30 @@ void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, long m,
366391 (int )batchCount));
367392}
368393
#if CUDA_VERSION >= 8000
/*
 * Double-precision strided-batched GEMM: thin wrapper that forwards to
 * cublasDgemmStridedBatched on the current THC BLAS handle and stream.
 *
 * Only compiled when CUDA_VERSION >= 8000.
 *
 * state             THC state used to look up the current cuBLAS handle and stream.
 * transa, transb    transpose flags, translated by convertTransToCublasOperation.
 * m, n, k           problem dimensions; narrowed to int for cuBLAS.
 * alpha, beta       scalars, passed to cuBLAS by address.
 * a, b, c           operand pointers, forwarded unchanged.
 * lda, ldb, ldc     leading dimensions; may be rewritten by adjustLd before
 *                   being narrowed to int.
 * strideA/B/C       per-batch strides, forwarded unchanged (not range-checked here).
 * batchCount        number of batches; narrowed to int for cuBLAS.
 *
 * Any of m, n, k, lda, ldb, ldc, batchCount that is >= INT_MAX is reported
 * through THError. The cuBLAS status is checked with THCublasCheck.
 */
void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,
                                    double alpha, const double *a, long lda, long strideA, const double *b, long ldb, long strideB,
                                    double beta, double *c, long ldc, long strideC, long batchCount)
{
  /* These arguments are cast to int below, so refuse anything at or above INT_MAX. */
  if ( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) ||
       (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) ||
       (batchCount >= INT_MAX) )
  {
    /* Name this routine (not the pointer-array DgemmBatched variant) in the message. */
    THError(" Cublas_DgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount"
            " with the bound [val] <= %d", INT_MAX);
  }

  /* adjustLd may update the leading dimensions in place; the casts happen afterwards. */
  adjustLd(transa, transb, m, n, k, &lda, &ldb, &ldc);
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);

  /* Bind the current stream to the handle before issuing the call. */
  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
  cublasSetStream(handle, THCState_getCurrentStream(state));
  THCublasCheck(cublasDgemmStridedBatched(handle,
                                          opa, opb, (int)m, (int)n, (int)k,
                                          &alpha, a, (int)lda, strideA, b, (int)ldb, strideB, &beta, c, (int)ldc, strideC,
                                          (int)batchCount));
}
#endif
417+
369418/* Inverse */
370419void THCudaBlas_Sgetrf (THCState *state, int n, float **a, int lda, int *pivot, int *info, int batchSize) {
371420 if ( (n >= INT_MAX) || (lda >= INT_MAX) || (batchSize >= INT_MAX) )
0 commit comments