@@ -1471,10 +1471,18 @@ Tensor _inverse_helper_cuda_legacy(const Tensor& self) {
 
 Tensor _inverse_helper_cuda(const Tensor& self) {
 #ifdef USE_CUSOLVER
-  if ((self.dim() == 2) || (/* self.dim() > 2 && */ batchCount(self) <= 2) || !use_magma_) {
-    return _inverse_helper_cuda_lib(self);    // cusolver or cublas
-  } else {
-    return _inverse_helper_cuda_legacy(self); // magma-cuda
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Cusolver:
+      return _inverse_helper_cuda_lib(self);    // cusolver or cublas
+    case at::LinalgBackend::Magma:
+      return _inverse_helper_cuda_legacy(self); // magma-cuda
+    default:
+      if (batchCount(self) <= 2 || !use_magma_) {
+        return _inverse_helper_cuda_lib(self);    // cusolver or cublas
+      } else {
+        return _inverse_helper_cuda_legacy(self); // magma-cuda
+      }
   }
 #else
   return _inverse_helper_cuda_legacy(self); // magma-cuda
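
The pattern introduced in this hunk and repeated below reads a process-wide preference from the global ATen context. As a minimal sketch of how a caller might steer that dispatch from C++, assuming the matching setter setLinalgPreferredBackend() exists alongside the linalgPreferredBackend() getter used above (hypothetical usage, not part of this diff):

#include <ATen/ATen.h>
#include <ATen/Context.h>

// Sketch: request the cuSOLVER path for a single inverse, then restore the
// previous preference. Assumes setLinalgPreferredBackend() mirrors the
// getter used in the dispatch above.
at::Tensor invert_with_cusolver(const at::Tensor& self) {
  auto& ctx = at::globalContext();
  const auto previous = ctx.linalgPreferredBackend();
  ctx.setLinalgPreferredBackend(at::LinalgBackend::Cusolver);
  at::Tensor result = at::inverse(self);  // routed to the cuSOLVER branch above
  ctx.setLinalgPreferredBackend(previous);
  return result;
}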
@@ -1503,10 +1511,18 @@ Tensor& _linalg_inv_out_helper_cuda(Tensor &result, Tensor& infos_lu, Tensor& in
 // This function calculates the inverse matrix in-place
 // result should be in column major order and contain matrices to invert
 #ifdef USE_CUSOLVER
-  if ((result.dim() == 2) || (/* result.dim() > 2 && */ batchCount(result) <= 2) || !use_magma_) {
-    return _linalg_inv_out_helper_cuda_lib(result, infos_lu, infos_getri);    // cusolver or cublas
-  } else {
-    return _linalg_inv_out_helper_cuda_legacy(result, infos_lu, infos_getri); // magma-cuda
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Cusolver:
+      return _linalg_inv_out_helper_cuda_lib(result, infos_lu, infos_getri);    // cusolver or cublas
+    case at::LinalgBackend::Magma:
+      return _linalg_inv_out_helper_cuda_legacy(result, infos_lu, infos_getri); // magma-cuda
+    default:
+      if (batchCount(result) <= 2 || !use_magma_) {
+        return _linalg_inv_out_helper_cuda_lib(result, infos_lu, infos_getri);    // cusolver or cublas
+      } else {
+        return _linalg_inv_out_helper_cuda_legacy(result, infos_lu, infos_getri); // magma-cuda
+      }
   }
 #else
   return _linalg_inv_out_helper_cuda_legacy(result, infos_lu, infos_getri); // magma-cuda
@@ -1600,10 +1616,18 @@ Tensor _cholesky_solve_helper_cuda_magma(const Tensor& self, const Tensor& A, bo
 // Batched cholesky_solve is dispatched to magma.
 Tensor _cholesky_solve_helper_cuda(const Tensor& self, const Tensor& A, bool upper) {
 #ifdef USE_CUSOLVER
-  if (batchCount(self) == 1 || !use_magma_) {
-    return _cholesky_solve_helper_cuda_cusolver(self, A, upper);
-  } else {
-    return _cholesky_solve_helper_cuda_magma(self, A, upper);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Cusolver:
+      return _cholesky_solve_helper_cuda_cusolver(self, A, upper);
+    case at::LinalgBackend::Magma:
+      return _cholesky_solve_helper_cuda_magma(self, A, upper);
+    default:
+      if (batchCount(self) == 1 || !use_magma_) {
+        return _cholesky_solve_helper_cuda_cusolver(self, A, upper);
+      } else {
+        return _cholesky_solve_helper_cuda_magma(self, A, upper);
+      }
   }
 #else
   return _cholesky_solve_helper_cuda_magma(self, A, upper);
@@ -1706,10 +1730,20 @@ void cholesky_helper_magma(const Tensor& input, bool upper, const Tensor& info)
 
 static void cholesky_kernel(const Tensor& input, const Tensor& info, bool upper) {
 #ifdef USE_CUSOLVER
-  if (batchCount(input) == 1 || !use_magma_ || use_cusolver_potrf_batched_) {
-    cholesky_helper_cusolver(input, upper, info);
-  } else {
-    cholesky_helper_magma(input, upper, info);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Cusolver:
+      cholesky_helper_cusolver(input, upper, info);
+      break;
+    case at::LinalgBackend::Magma:
+      cholesky_helper_magma(input, upper, info);
+      break;
+    default:
+      if (batchCount(input) == 1 || !use_magma_ || use_cusolver_potrf_batched_) {
+        cholesky_helper_cusolver(input, upper, info);
+      } else {
+        cholesky_helper_magma(input, upper, info);
+      }
   }
 #else
   cholesky_helper_magma(input, upper, info);
@@ -1777,10 +1811,19 @@ Tensor& cholesky_inverse_kernel_impl(Tensor &result, Tensor& infos, bool upper)
 // result should be in column major order and contain matrices to invert
 // the content of result is overwritten by 'apply_cholesky_inverse'
 #ifdef USE_CUSOLVER
-  if (batchCount(result) == 1 || !use_magma_) {
-    return cholesky_inverse_kernel_impl_cusolver(result, infos, upper);
-  } else {
-    return cholesky_inverse_kernel_impl_magma(result, infos, upper);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Cusolver:
+      return cholesky_inverse_kernel_impl_cusolver(result, infos, upper);
+    case at::LinalgBackend::Magma:
+      return cholesky_inverse_kernel_impl_magma(result, infos, upper);
+    default:
+      if (batchCount(result) == 1 ||
+          !use_magma_) {
+        return cholesky_inverse_kernel_impl_cusolver(result, infos, upper);
+      } else {
+        return cholesky_inverse_kernel_impl_magma(result, infos, upper);
+      }
   }
 #else
   return cholesky_inverse_kernel_impl_magma(result, infos, upper);
@@ -1944,20 +1987,39 @@ static void lu_batched_magma(const Tensor& input, const Tensor& pivots, const Te
 static void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
   int64_t batch_size = batchCount(input);
 #ifdef USE_CUSOLVER
-  // Use a heuristic to determine that cusolver is faster than MAGMA for the following sizes.
-  auto m = input.size(-2);
-  // exclude complex128 since nan_to_num_ does not work with it.
-  if ((batch_size == 1 || (batch_size <= 8 && m <= 16) || !use_magma_ ) && !input.is_complex()) {
-    lu_looped_cusolver(input, pivots, infos, compute_pivots);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Cusolver:
+      lu_looped_cusolver(input, pivots, infos, compute_pivots);
+      break;
+    case at::LinalgBackend::Magma:
+      if (batch_size == 1) {
+        lu_looped_magma(input, pivots, infos, compute_pivots);
+      } else {
+        lu_batched_magma(input, pivots, infos, compute_pivots);
+      }
+      break;
+    default:
+      // Use a heuristic to determine that cusolver is faster than MAGMA for the following sizes.
+      auto m = input.size(-2);
+      // exclude complex128 since nan_to_num_ does not work with it.
+      if ((batch_size == 1 ||
+          (batch_size <= 8 && m <= 16) ||
+          !use_magma_)
+          && !input.is_complex()) {
+        lu_looped_cusolver(input, pivots, infos, compute_pivots);
+      } else {
+        lu_batched_magma(input, pivots, infos, compute_pivots);
+      }
   }
 #else
   if (batch_size == 1) {
     lu_looped_magma(input, pivots, infos, compute_pivots);
   }
-#endif // USE_CUSOLVER
   else {
     lu_batched_magma(input, pivots, infos, compute_pivots);
   }
+#endif // USE_CUSOLVER
 }
 
 REGISTER_CUDA_DISPATCH(lu_stub, &apply_lu);
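
The default branch above keeps the pre-existing size heuristic; pulled out of the switch it reads as the predicate below. This is only a restatement for clarity, assuming it sits in the same translation unit so the file-local use_magma_ flag is visible:

// Restates the heuristic in apply_lu's default branch: prefer the looped
// cuSOLVER getrf for a single matrix or for small batches of tiny matrices,
// and whenever MAGMA is unavailable; complex inputs are excluded because
// nan_to_num_ does not support complex128.
static bool prefer_cusolver_for_lu(int64_t batch_size, int64_t m, bool is_complex) {
  const bool small_problem = (batch_size == 1) || (batch_size <= 8 && m <= 16);
  return (small_problem || !use_magma_) && !is_complex;
}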
@@ -2064,12 +2126,12 @@ Tensor& orgqr_kernel_impl(Tensor& result, const Tensor& tau) {
 // See discussions in https://github.com/pytorch/pytorch/pull/51348 for comparison of cuSOLVER-MAGMA
 // and Windows failure.
 // For reference here is the MAGMA-based implementation: https://gist.github.com/IvanYashchuk/2db50002c9d3c1462ff769e6410ad983
-#if defined(USE_CUSOLVER)
-  return orgqr_helper_cusolver(result, tau); // cusolver
-#else
-  TORCH_CHECK(false, "Calling torch.orgqr on a CUDA tensor requires compiling ",
-    "PyTorch with cuSOLVER. Please use PyTorch built with cuSOLVER support.");
-#endif
+#if defined(USE_CUSOLVER)
+  return orgqr_helper_cusolver(result, tau); // cusolver
+#else
+  TORCH_CHECK(false, "Calling torch.orgqr on a CUDA tensor requires compiling ",
+    "PyTorch with cuSOLVER. Please use PyTorch built with cuSOLVER support.");
+#endif
 }
 
 REGISTER_CUDA_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
@@ -2136,7 +2198,14 @@ void geqrf_magma(const Tensor& input, const Tensor& tau) {
 // This is a backend library dispatching helper function for calling looped batch implementation
 void geqrf_looped(const Tensor& input, const Tensor& tau) {
 #if defined(USE_CUSOLVER)
-  return geqrf_cusolver(input, tau);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Magma:
+      return geqrf_magma(input, tau);
+    case at::LinalgBackend::Cusolver:
+    default:
+      return geqrf_cusolver(input, tau);
+  }
 #else
   return geqrf_magma(input, tau);
 #endif
@@ -2273,9 +2342,16 @@ std::tuple<Tensor, Tensor> linalg_qr_helper_magma(const Tensor& self, c10::strin
 
 std::tuple<Tensor, Tensor> _linalg_qr_helper_cuda(const Tensor& input, c10::string_view mode) {
 #if defined(USE_CUSOLVER)
-  // _linalg_qr_helper_default is a generic function that is implemented using
-  // geqrf_stub and orgqr_stub. It dispatches to cuSOLVER for CUDA inputs if USE_CUSOLVER is defined
-  return _linalg_qr_helper_default(input, mode);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Magma:
+      return linalg_qr_helper_magma(input, mode);
+    case at::LinalgBackend::Cusolver:
+    default:
+      // _linalg_qr_helper_default is a generic function that is implemented using
+      // geqrf_stub and orgqr_stub. It dispatches to cuSOLVER for CUDA inputs if USE_CUSOLVER is defined
+      return _linalg_qr_helper_default(input, mode);
+  }
 #else
   return linalg_qr_helper_magma(input, mode);
 #endif
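
For readers unfamiliar with the comment in that hunk: "implemented using geqrf_stub and orgqr_stub" means the default QR path assembles Q and R from those two primitives. A rough sketch of the same composition through the public ATen ops (at::geqrf / at::orgqr), for a square, real-valued input; illustrative only, not the helper's actual code:

#include <ATen/ATen.h>
#include <tuple>

// QR from the two primitives: geqrf packs R and the Householder reflectors
// into one matrix, orgqr expands the reflectors into an explicit Q.
std::tuple<at::Tensor, at::Tensor> qr_from_geqrf_orgqr(const at::Tensor& a) {
  at::Tensor packed, tau;
  std::tie(packed, tau) = at::geqrf(a);   // packed reflectors + tau
  at::Tensor q = at::orgqr(packed, tau);  // explicit Q from the reflectors
  at::Tensor r = packed.triu();           // R lives in the upper triangle
  return std::make_tuple(q, r);
}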
@@ -2432,7 +2508,15 @@ void linalg_eigh_magma(const Tensor& eigenvalues, const Tensor& eigenvectors, co
 
 void linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
 #if defined(USE_CUSOLVER)
-  linalg_eigh_cusolver(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Magma:
+      linalg_eigh_magma(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
+      break;
+    case at::LinalgBackend::Cusolver:
+    default:
+      linalg_eigh_cusolver(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
+  }
 #else
   linalg_eigh_magma(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
 #endif
@@ -2731,7 +2815,14 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda_legacy(const Tensor& self, b
 
 std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda(const Tensor& self, bool some, bool compute_uv) {
 #ifdef USE_CUSOLVER
-  return _svd_helper_cuda_lib(self, some, compute_uv);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Magma:
+      return _svd_helper_cuda_legacy(self, some, compute_uv);
+    case at::LinalgBackend::Cusolver:
+    default:
+      return _svd_helper_cuda_lib(self, some, compute_uv);
+  }
 #else
   return _svd_helper_cuda_legacy(self, some, compute_uv);
 #endif
@@ -3046,10 +3137,17 @@ void linalg_lstsq_gels(const Tensor& A, const Tensor& B, const Tensor& /*infos*/
 
 void gels_looped(const Tensor& a, Tensor& b, Tensor& infos) {
 #if defined(USE_CUSOLVER)
-  // linalg_lstsq_gels is a generic function that is implemented using
-  // geqrf_stub, ormqr_stub, and triangular_solve_stub
-  // It dispatches to cuSOLVER for CUDA inputs if USE_CUSOLVER is defined
-  return linalg_lstsq_gels(a, b, infos);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Magma:
+      return gels_magma(a, b, infos);
+    case at::LinalgBackend::Cusolver:
+    default:
+      // linalg_lstsq_gels is a generic function that is implemented using
+      // geqrf_stub, ormqr_stub, and triangular_solve_stub
+      // It dispatches to cuSOLVER for CUDA inputs if USE_CUSOLVER is defined
+      return linalg_lstsq_gels(a, b, infos);
+  }
 #else
   return gels_magma(a, b, infos);
 #endif
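
As with the QR helper, the comment in that cuSOLVER branch describes a composition of primitives. A sketch of the same least-squares recipe through the public ops (at::geqrf, at::ormqr, at::triangular_solve), assuming a real-valued, full-column-rank a of shape (m, n) with m >= n and a 2-D b; illustrative only, not the dispatch code itself:

#include <ATen/ATen.h>
#include <tuple>

// min ||a x - b||: factor a = Q R, apply Q^T to b without forming Q
// explicitly, then back-substitute against R.
at::Tensor lstsq_via_qr(const at::Tensor& a, const at::Tensor& b) {
  at::Tensor packed, tau;
  std::tie(packed, tau) = at::geqrf(a);                    // a = Q R, packed form
  at::Tensor qtb = at::ormqr(packed, tau, b,
                             /*left=*/true, /*transpose=*/true);  // Q^T b
  const auto n = a.size(-1);
  at::Tensor r = packed.narrow(-2, 0, n).triu();           // top n rows hold R
  at::Tensor x = std::get<0>(at::triangular_solve(
      qtb.narrow(-2, 0, n), r, /*upper=*/true));           // solve R x = (Q^T b)[:n]
  return x;
}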