Skip to content

Commit 6e84a01

Browse files
t-vi authored and facebook-github-bot committed
move to non-legacy magma v2 headers (pytorch#49978)
Summary: We recently (pytorch#7582) dropped magma v1 support, but we were still including the legacy compatibility headers and using functions only provided by them. This changes the includes to the new magma_v2 header and fixes the triangular solve functions to use the v2-style magma_queue-using API.

Pull Request resolved: pytorch#49978

Reviewed By: mrshenli

Differential Revision: D25752499

Pulled By: ngimel

fbshipit-source-id: 26d916bc5ce63978b341aefb072af228f140637d
1 parent fdb81c5 commit 6e84a01

File tree

5 files changed

+116
-33
lines changed

5 files changed

+116
-33
lines changed

aten/src/ATen/cuda/detail/CUDAHooks.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
#endif
2222

2323
#ifdef USE_MAGMA
24-
#include <magma.h>
24+
#include <magma_v2.h>
2525
#endif
2626

2727
#ifdef __HIP_PLATFORM_HCC__

aten/src/ATen/native/cuda/BatchLinearAlgebra.cu

Lines changed: 112 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
#include <THC/THC.h> // for USE_MAGMA
1717

1818
#ifdef USE_MAGMA
19-
#include <magma.h>
2019
#include <magma_types.h>
20+
#include <magma_v2.h>
2121

2222
const bool use_magma_ = true;
2323
#else
@@ -95,10 +95,18 @@ void magmaCholeskyBatched(
9595
magma_uplo_t uplo, magma_int_t n, scalar_t** dA_array, magma_int_t ldda,
9696
magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue);
9797

98-
template<class scalar_t>
98+
template <class scalar_t>
9999
void magmaTriangularSolve(
100-
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
101-
scalar_t* dA, magma_int_t ldda, scalar_t* dB, magma_int_t lddb);
100+
magma_uplo_t uplo,
101+
magma_trans_t trans,
102+
magma_diag_t diag,
103+
magma_int_t m,
104+
magma_int_t n,
105+
scalar_t* dA,
106+
magma_int_t ldda,
107+
scalar_t* dB,
108+
magma_int_t lddb,
109+
const MAGMAQueue& magma_queue);
102110

103111
template<class scalar_t>
104112
void magmaTriangularSolveBatched(
@@ -662,45 +670,117 @@ void magmaCholeskyBatched<c10::complex<float>>(
662670
AT_CUDA_CHECK(cudaGetLastError());
663671
}
664672

665-
template<>
673+
template <>
666674
void magmaTriangularSolve<double>(
667-
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
668-
double* dA, magma_int_t ldda, double* dB, magma_int_t lddb) {
669-
MagmaStreamSyncGuard guard;
670-
magma_dtrsm(MagmaLeft, uplo, trans, diag, m, n, 1, dA, ldda, dB, lddb);
675+
magma_uplo_t uplo,
676+
magma_trans_t trans,
677+
magma_diag_t diag,
678+
magma_int_t m,
679+
magma_int_t n,
680+
double* dA,
681+
magma_int_t ldda,
682+
double* dB,
683+
magma_int_t lddb,
684+
const MAGMAQueue& magma_queue) {
685+
magma_dtrsm(
686+
MagmaLeft,
687+
uplo,
688+
trans,
689+
diag,
690+
m,
691+
n,
692+
1,
693+
dA,
694+
ldda,
695+
dB,
696+
lddb,
697+
magma_queue.get_queue());
671698
AT_CUDA_CHECK(cudaGetLastError());
672699
}
673700

674-
template<>
701+
template <>
675702
void magmaTriangularSolve<float>(
676-
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
677-
float* dA, magma_int_t ldda, float* dB, magma_int_t lddb) {
678-
MagmaStreamSyncGuard guard;
679-
magma_strsm(MagmaLeft, uplo, trans, diag, m, n, 1, dA, ldda, dB, lddb);
703+
magma_uplo_t uplo,
704+
magma_trans_t trans,
705+
magma_diag_t diag,
706+
magma_int_t m,
707+
magma_int_t n,
708+
float* dA,
709+
magma_int_t ldda,
710+
float* dB,
711+
magma_int_t lddb,
712+
const MAGMAQueue& magma_queue) {
713+
magma_strsm(
714+
MagmaLeft,
715+
uplo,
716+
trans,
717+
diag,
718+
m,
719+
n,
720+
1,
721+
dA,
722+
ldda,
723+
dB,
724+
lddb,
725+
magma_queue.get_queue());
680726
AT_CUDA_CHECK(cudaGetLastError());
681727
}
682728

683-
template<>
729+
template <>
684730
void magmaTriangularSolve<c10::complex<double>>(
685-
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
686-
c10::complex<double>* dA, magma_int_t ldda, c10::complex<double>* dB, magma_int_t lddb) {
687-
MagmaStreamSyncGuard guard;
731+
magma_uplo_t uplo,
732+
magma_trans_t trans,
733+
magma_diag_t diag,
734+
magma_int_t m,
735+
magma_int_t n,
736+
c10::complex<double>* dA,
737+
magma_int_t ldda,
738+
c10::complex<double>* dB,
739+
magma_int_t lddb,
740+
const MAGMAQueue& magma_queue) {
688741
magmaDoubleComplex alpha({1, 0});
689-
magma_ztrsm(MagmaLeft, uplo, trans, diag, m, n, alpha,
690-
reinterpret_cast<magmaDoubleComplex*>(dA), ldda,
691-
reinterpret_cast<magmaDoubleComplex*>(dB), lddb);
742+
magma_ztrsm(
743+
MagmaLeft,
744+
uplo,
745+
trans,
746+
diag,
747+
m,
748+
n,
749+
alpha,
750+
reinterpret_cast<magmaDoubleComplex*>(dA),
751+
ldda,
752+
reinterpret_cast<magmaDoubleComplex*>(dB),
753+
lddb,
754+
magma_queue.get_queue());
692755
AT_CUDA_CHECK(cudaGetLastError());
693756
}
694757

695-
template<>
758+
template <>
696759
void magmaTriangularSolve<c10::complex<float>>(
697-
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
698-
c10::complex<float>* dA, magma_int_t ldda, c10::complex<float>* dB, magma_int_t lddb) {
699-
MagmaStreamSyncGuard guard;
760+
magma_uplo_t uplo,
761+
magma_trans_t trans,
762+
magma_diag_t diag,
763+
magma_int_t m,
764+
magma_int_t n,
765+
c10::complex<float>* dA,
766+
magma_int_t ldda,
767+
c10::complex<float>* dB,
768+
magma_int_t lddb,
769+
const MAGMAQueue& magma_queue) {
700770
magmaFloatComplex alpha({1, 0});
701-
magma_ctrsm(MagmaLeft, uplo, trans, diag, m, n, alpha,
702-
reinterpret_cast<magmaFloatComplex*>(dA), ldda,
703-
reinterpret_cast<magmaFloatComplex*>(dB), lddb);
771+
magma_ctrsm(
772+
MagmaLeft,
773+
uplo,
774+
trans,
775+
diag,
776+
m,
777+
n,
778+
alpha,
779+
reinterpret_cast<magmaFloatComplex*>(dA),
780+
ldda,
781+
reinterpret_cast<magmaFloatComplex*>(dB),
782+
lddb,
783+
magma_queue.get_queue());
704784
AT_CUDA_CHECK(cudaGetLastError());
705785
}
706786

@@ -1636,11 +1716,14 @@ AT_ERROR("triangular_solve: MAGMA library not found in "
16361716
magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)");
16371717
magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount");
16381718

1719+
MAGMAQueue magma_queue(b.get_device());
1720+
16391721
// batch_size == 1 implies that:
16401722
// 1. the RHS and LHS tensors have 2 dimensions, or
16411723
// 2. the RHS and LHS tensors have more than 2 dimensions but all batch dimensions are 1
16421724
if (batch_size == 1) {
1643-
magmaTriangularSolve<scalar_t>(uplo, trans, diag, n, nrhs, A_data, n, b_data, n);
1725+
magmaTriangularSolve<scalar_t>(
1726+
uplo, trans, diag, n, nrhs, A_data, n, b_data, n, magma_queue);
16441727
} else {
16451728
auto A_mat_stride = matrixStride(A);
16461729
auto b_mat_stride = matrixStride(b);

aten/src/ATen/native/cuda/MiscUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
#include <THC/THC.h> // for USE_MAGMA
77

88
#ifdef USE_MAGMA
9-
#include <magma.h>
109
#include <magma_types.h>
10+
#include <magma_v2.h>
1111
#endif
1212

1313
namespace at {

aten/src/THC/THCTensorMathMagma.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include <ATen/native/cuda/MiscUtils.h>
99

1010
#ifdef USE_MAGMA
11-
#include <magma.h>
11+
#include <magma_v2.h>
1212
#endif
1313

1414
#ifndef DIVUP

aten/src/THC/THCTensorMathMagma.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#define THC_TENSOR_MATH_MAGMA_CUH
33

44
#ifdef USE_MAGMA
5-
#include <magma.h>
5+
#include <magma_v2.h>
66
#endif
77

88
#ifdef USE_MAGMA

0 commit comments

Comments
 (0)