Skip to content

Commit c10e02f

Browse files
luitjensdanpovey
authored andcommitted
[build,src] Enhancements to the cudamatrix/cudavector classes. (kaldi-asr#3373)
* Added CuSolver to the matrix class. This is only supported with Cuda 9.1 or newer. Calling CuSolver code without Cuda 9.1 or newer will result in a runtime error. This change required some changes to the build system which requires versioning the configure script. This forces everyone to reconfigure. Failure to reconfigure would result in linking and build errors on some systems.
1 parent 04cf43b commit c10e02f

File tree

4 files changed

+83
-5
lines changed

4 files changed

+83
-5
lines changed

src/configure

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
# This should be incremented after any significant change to the configure
4141
# script, i.e. any change affecting kaldi.mk or the build system as a whole.
42-
CONFIGURE_VERSION=10
42+
CONFIGURE_VERSION=11
4343

4444
# We support bash version 3.2 (Macs still ship with this version as of 2019)
4545
# and above.
@@ -433,22 +433,32 @@ function configure_cuda {
433433
7_*)
434434
MIN_UNSUPPORTED_GCC_VER="5.0"
435435
MIN_UNSUPPORTED_GCC_VER_NUM=50000;
436+
CUSOLVER=false
436437
;;
437438
8_*)
438439
MIN_UNSUPPORTED_GCC_VER="6.0"
439440
MIN_UNSUPPORTED_GCC_VER_NUM=60000;
441+
CUSOLVER=false
440442
;;
441-
9_0 | 9_1)
443+
9_0)
442444
MIN_UNSUPPORTED_GCC_VER="7.0"
443445
MIN_UNSUPPORTED_GCC_VER_NUM=70000;
446+
CUSOLVER=false
447+
;;
448+
9_1)
449+
MIN_UNSUPPORTED_GCC_VER="7.0"
450+
MIN_UNSUPPORTED_GCC_VER_NUM=70000;
451+
CUSOLVER=true
444452
;;
445453
9_2 | 9_* | 10_0)
446454
MIN_UNSUPPORTED_GCC_VER="8.0"
447455
MIN_UNSUPPORTED_GCC_VER_NUM=80000;
456+
CUSOLVER=true
448457
;;
449458
10_1 | 10_*)
450459
MIN_UNSUPPORTED_GCC_VER="9.0"
451460
MIN_UNSUPPORTED_GCC_VER_NUM=90000;
461+
CUSOLVER=true
452462
;;
453463
*)
454464
echo "Unsupported CUDA_VERSION (CUDA_VERSION=$CUDA_VERSION), please report it to Kaldi mailing list, together with 'nvcc -h' or 'ptxas -h' which lists allowed -gencode values..."; exit 1;
@@ -492,6 +502,8 @@ function configure_cuda {
492502
echo CUDA = true >> kaldi.mk
493503
echo CUDATKDIR = $CUDATKDIR >> kaldi.mk
494504
echo "CUDA_ARCH = $CUDA_ARCH" >> kaldi.mk
505+
506+
495507
echo >> kaldi.mk
496508

497509
# 64bit/32bit? We do not support cross compilation with CUDA so, use direct
@@ -512,6 +524,11 @@ WARNING: CUDA will not be used!
512524
CUDA is not supported with 32-bit builds."
513525
exit 1;
514526
fi
527+
528+
#add cusolver flags for newer toolkits
529+
if [[ $CUSOLVER -eq true ]]; then
530+
echo "CUDA_LDLIBS += -lcusolver" >> kaldi.mk
531+
fi
515532

516533
else
517534
echo "\

src/cudamatrix/cu-common.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,15 @@
5959
} \
6060
}
6161

62+
#define CUSOLVER_SAFE_CALL(fun) \
63+
{ \
64+
int32 ret; \
65+
if ((ret = (fun)) != 0) { \
66+
KALDI_ERR << "cusolverStatus_t " << ret << " : \"" << ret << "\" returned from '" << #fun << "'"; \
67+
} \
68+
}
69+
70+
6271
#define CUSPARSE_SAFE_CALL(fun) \
6372
{ \
6473
int32 ret; \

src/cudamatrix/cu-device.cc

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,21 @@ void CuDevice::Initialize() {
110110
// Initialize CUBLAS.
111111
CUBLAS_SAFE_CALL(cublasCreate(&cublas_handle_));
112112
CUBLAS_SAFE_CALL(cublasSetStream(cublas_handle_, cudaStreamPerThread));
113+
114+
#if CUDA_VERSION >= 9100
115+
CUSOLVER_SAFE_CALL(cusolverDnCreate(&cusolverdn_handle_));
116+
CUSOLVER_SAFE_CALL(cusolverDnSetStream(cusolverdn_handle_,
117+
cudaStreamPerThread));
118+
#endif
113119

114-
#if CUDA_VERSION >= 9000
120+
#if CUDA_VERSION >= 9000
115121
if (device_options_.use_tensor_cores) {
116122
// Enable tensor cores in CUBLAS
117123
// Note if the device does not support tensor cores this will fall back to normal math mode
118124
CUBLAS_SAFE_CALL(cublasSetMathMode(cublas_handle_,
119125
CUBLAS_TENSOR_OP_MATH));
120126
}
121-
#endif
127+
#endif
122128

123129
// Initialize the cuSPARSE library
124130
CUSPARSE_SAFE_CALL(cusparseCreate(&cusparse_handle_));
@@ -130,6 +136,7 @@ void CuDevice::Initialize() {
130136
// To get same random sequence, call srand() before the constructor is invoked,
131137
CURAND_SAFE_CALL(curandSetGeneratorOrdering(
132138
curand_handle_, CURAND_ORDERING_PSEUDO_DEFAULT));
139+
CURAND_SAFE_CALL(curandSetStream(curand_handle_, cudaStreamPerThread));
133140
SeedGpu();
134141
}
135142
}
@@ -263,6 +270,23 @@ void CuDevice::FinalizeActiveGpu() {
263270
// Initialize CUBLAS.
264271
CUBLAS_SAFE_CALL(cublasCreate(&cublas_handle_));
265272
CUBLAS_SAFE_CALL(cublasSetStream(cublas_handle_, cudaStreamPerThread));
273+
274+
#if CUDA_VERSION >= 9100
275+
CUSOLVER_SAFE_CALL(cusolverDnCreate(&cusolverdn_handle_));
276+
CUSOLVER_SAFE_CALL(cusolverDnSetStream(cusolverdn_handle_,
277+
cudaStreamPerThread));
278+
#endif
279+
280+
#if CUDA_VERSION >= 9000
281+
if (device_options_.use_tensor_cores) {
282+
// Enable tensor cores in CUBLAS
283+
// Note if the device does not support tensor cores this will fall back to normal math mode
284+
CUBLAS_SAFE_CALL(cublasSetMathMode(cublas_handle_,
285+
CUBLAS_TENSOR_OP_MATH));
286+
}
287+
#endif
288+
289+
266290
// Initialize the cuSPARSE library
267291
CUSPARSE_SAFE_CALL(cusparseCreate(&cusparse_handle_));
268292
CUSPARSE_SAFE_CALL(cusparseSetStream(cusparse_handle_, cudaStreamPerThread));
@@ -537,7 +561,8 @@ CuDevice::CuDevice():
537561
initialized_(false),
538562
device_id_copy_(-1),
539563
cublas_handle_(NULL),
540-
cusparse_handle_(NULL) {
564+
cusparse_handle_(NULL),
565+
cusolverdn_handle_(NULL) {
541566
}
542567

543568
CuDevice::~CuDevice() {
@@ -548,6 +573,11 @@ CuDevice::~CuDevice() {
548573
if (curand_handle_) {
549574
CURAND_SAFE_CALL(curandDestroyGenerator(curand_handle_));
550575
}
576+
#if CUDA_VERSION >= 9100
577+
if (cusolverdn_handle_) {
578+
CUSOLVER_SAFE_CALL(cusolverDnDestroy(cusolverdn_handle_));
579+
}
580+
#endif
551581
}
552582

553583

src/cudamatrix/cu-device.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@
3737
#include "cudamatrix/cu-allocator.h"
3838
#include "cudamatrix/cu-common.h"
3939

40+
#if CUDA_VERSION >= 9100
41+
#include <cusolverDn.h>
42+
#else
43+
// cusolver not supported.
44+
// Setting a few types to minimize compiler guards.
45+
// If a user tries to use cusovler it will throw an error.
46+
typedef void* cusolverDnHandle_t;
47+
typedef int cusolverStatus_t;
48+
#endif
49+
4050
namespace kaldi {
4151

4252
class CuTimer;
@@ -83,6 +93,13 @@ class CuDevice {
8393
inline cublasHandle_t GetCublasHandle() { return cublas_handle_; }
8494
inline cusparseHandle_t GetCusparseHandle() { return cusparse_handle_; }
8595
inline curandGenerator_t GetCurandHandle() { return curand_handle_; }
96+
inline cusolverDnHandle_t GetCusolverDnHandle() {
97+
#if CUDA_VERSION < 9100
98+
KALDI_ERR << "CUDA VERSION '" << CUDA_VERSION << "' not new enough to support "
99+
<< "cusolver. Upgrade to at least 9.1";
100+
#endif
101+
return cusolverdn_handle_;
102+
}
86103

87104
inline void SeedGpu() {
88105
if (CuDevice::Instantiate().Enabled()) {
@@ -304,6 +321,7 @@ class CuDevice {
304321
cublasHandle_t cublas_handle_;
305322
cusparseHandle_t cusparse_handle_;
306323
curandGenerator_t curand_handle_;
324+
cusolverDnHandle_t cusolverdn_handle_;
307325
}; // class CuDevice
308326

309327

@@ -322,6 +340,10 @@ inline cublasHandle_t GetCublasHandle() {
322340
return CuDevice::Instantiate().GetCublasHandle();
323341
}
324342

343+
inline cusolverDnHandle_t GetCusolverDnHandle() {
344+
return CuDevice::Instantiate().GetCusolverDnHandle();
345+
}
346+
325347
// A more convenient way to get the handle to use cuSPARSE APIs.
326348
inline cusparseHandle_t GetCusparseHandle() {
327349
return CuDevice::Instantiate().GetCusparseHandle();

0 commit comments

Comments
 (0)