
Commit bfe5ad2

xwang233 authored and facebook-github-bot committed
[Linalg] Add a runtime switch to let pytorch prefer a backend impl in linalg functions on GPU (#67980)
Summary: Per title. This PR introduces a global flag that lets PyTorch prefer one of the many backend implementations when calling linear algebra functions on GPU.

Usage:

```python
torch.backends.cuda.preferred_linalg_library('cusolver')
```

Available options (str): `'default'`, `'cusolver'`, `'magma'`.

Issue #63992 inspired me to write this PR. No heuristic is perfect on all devices, library versions, matrix shapes, workloads, etc. We can obtain better performance if we can conveniently switch linear algebra backends at runtime. Performance of linear algebra operators after this PR should be no worse than before. The flag is set to **`'default'`** by default, which keeps everything the same as before this PR. The implementation of this PR basically follows that of #67790.

Pull Request resolved: #67980

Reviewed By: mruberry

Differential Revision: D32849457

Pulled By: ngimel

fbshipit-source-id: 679fee7744a03af057995aef06316306073010a6
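A quick usage sketch (hedged: `torch.linalg.inv` is just one representative routine that reaches these GPU code paths, and the assumption that a bare call returns the current preference reflects the documented Python binding, not code shown in this diff):

```python
import torch

# Query the current preference. Per this PR the process-wide default is
# 'default', which keeps the pre-existing cusolver/magma heuristics.
# (Assumption: calling with no argument only reads the setting.)
prev = torch.backends.cuda.preferred_linalg_library()

# Prefer cuSOLVER for subsequent linalg calls on CUDA tensors.
torch.backends.cuda.preferred_linalg_library('cusolver')

a = torch.randn(64, 64, device='cuda')
torch.linalg.inv(a)  # routed to the cuSOLVER implementation where one exists

# Restore the previous preference.
torch.backends.cuda.preferred_linalg_library(prev)
```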
1 parent 9663e08 · commit bfe5ad2

10 files changed (+303, −43 lines)


aten/src/ATen/Context.cpp (+15)

```diff
@@ -147,6 +147,21 @@ void Context::setAllowTF32CuBLAS(bool b) {
   allow_tf32_cublas = b;
 }
 
+at::LinalgBackend Context::linalgPreferredBackend() const {
+  return linalg_preferred_backend;
+}
+
+void Context::setLinalgPreferredBackend(at::LinalgBackend b) {
+  linalg_preferred_backend = b;
+  if (b != at::LinalgBackend::Default) {
+    TORCH_WARN_ONCE(
+      "torch.backends.cuda.preferred_linalg_library is an experimental feature. "
+      "If you see any error or unexpected behavior when this flag is set "
+      "please file an issue on GitHub."
+    );
+  }
+}
+
 bool Context::allowFP16ReductionCuBLAS() const {
   return allow_fp16_reduction_cublas;
 }
```
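Note the `TORCH_WARN_ONCE` in the setter: the first time a process selects a non-default backend, an experimental-feature warning fires, and only that once. A hedged sketch of how this looks from Python (assuming, as is typical for `TORCH_WARN_ONCE`, that it surfaces as a `UserWarning`):

```python
import warnings
import torch

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torch.backends.cuda.preferred_linalg_library('magma')     # first non-default set: warns
    torch.backends.cuda.preferred_linalg_library('cusolver')  # gated once in C++: silent

# Expect a single experimental-feature warning per process.
print([str(w.message) for w in caught])
```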

aten/src/ATen/Context.h (+5)

```diff
@@ -3,6 +3,7 @@
 #include <ATen/core/ATenGeneral.h>
 #include <ATen/core/Generator.h>
 #include <ATen/CPUGeneratorImpl.h>
+#include <ATen/LinalgBackend.h>
 #include <ATen/core/LegacyTypeDispatch.h>
 #include <ATen/core/DeprecatedTypeProperties.h>
 #include <ATen/detail/CUDAHooksInterface.h>
@@ -128,6 +129,9 @@ class TORCH_API Context {
   bool deterministicCuDNN() const;
   void setDeterministicCuDNN(bool);
 
+  at::LinalgBackend linalgPreferredBackend() const;
+  void setLinalgPreferredBackend(at::LinalgBackend);
+
   // Note [Enabling Deterministic Operations]
   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   // Operations in PyTorch that normally act nondeterministically, but have an alternate
@@ -249,6 +253,7 @@ class TORCH_API Context {
   bool allow_tf32_cublas = true;
   bool allow_fp16_reduction_cublas = true;
   bool enabled_mkldnn = true;
+  at::LinalgBackend linalg_preferred_backend = at::LinalgBackend::Default;
 #ifdef C10_MOBILE
   bool release_original_weights = true;
 #else
```

aten/src/ATen/LinalgBackend.h (+31, new file)

```diff
@@ -0,0 +1,31 @@
+#pragma once
+
+#include <c10/util/Exception.h>
+
+#include <ostream>
+#include <string>
+
+namespace at {
+
+enum class LinalgBackend : int8_t { Default, Cusolver, Magma };
+
+inline std::string LinalgBackendToString(at::LinalgBackend backend) {
+  switch (backend) {
+    case LinalgBackend::Default:
+      return "at::LinalgBackend::Default";
+    case LinalgBackend::Cusolver:
+      return "at::LinalgBackend::Cusolver";
+    case LinalgBackend::Magma:
+      return "at::LinalgBackend::Magma";
+    default:
+      TORCH_CHECK(false, "Unknown linalg backend");
+  }
+}
+
+inline std::ostream& operator<<(
+    std::ostream& stream,
+    at::LinalgBackend backend) {
+  return stream << LinalgBackendToString(backend);
+}
+
+} // namespace at
```

aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp (+140, −42)

```diff
@@ -1471,10 +1471,18 @@ Tensor _inverse_helper_cuda_legacy(const Tensor& self) {
 
 Tensor _inverse_helper_cuda(const Tensor& self) {
 #ifdef USE_CUSOLVER
-  if ((self.dim() == 2) || (/* self.dim() > 2 && */ batchCount(self) <= 2) || !use_magma_) {
-    return _inverse_helper_cuda_lib(self);    // cusolver or cublas
-  } else {
-    return _inverse_helper_cuda_legacy(self); // magma-cuda
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Cusolver:
+      return _inverse_helper_cuda_lib(self);    // cusolver or cublas
+    case at::LinalgBackend::Magma:
+      return _inverse_helper_cuda_legacy(self); // magma-cuda
+    default:
+      if (batchCount(self) <= 2 || !use_magma_) {
+        return _inverse_helper_cuda_lib(self);    // cusolver or cublas
+      } else {
+        return _inverse_helper_cuda_legacy(self); // magma-cuda
+      }
   }
 #else
   return _inverse_helper_cuda_legacy(self); // magma-cuda
@@ -1503,10 +1511,18 @@ Tensor& _linalg_inv_out_helper_cuda(Tensor &result, Tensor& infos_lu, Tensor& infos_getri) {
 // This function calculates the inverse matrix in-place
 // result should be in column major order and contain matrices to invert
 #ifdef USE_CUSOLVER
-  if ((result.dim() == 2) || (/* result.dim() > 2 && */ batchCount(result) <= 2) || !use_magma_) {
-    return _linalg_inv_out_helper_cuda_lib(result, infos_lu, infos_getri);    // cusolver or cublas
-  } else {
-    return _linalg_inv_out_helper_cuda_legacy(result, infos_lu, infos_getri); // magma-cuda
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Cusolver:
+      return _linalg_inv_out_helper_cuda_lib(result, infos_lu, infos_getri);    // cusolver or cublas
+    case at::LinalgBackend::Magma:
+      return _linalg_inv_out_helper_cuda_legacy(result, infos_lu, infos_getri); // magma-cuda
+    default:
+      if (batchCount(result) <= 2 || !use_magma_) {
+        return _linalg_inv_out_helper_cuda_lib(result, infos_lu, infos_getri);    // cusolver or cublas
+      } else {
+        return _linalg_inv_out_helper_cuda_legacy(result, infos_lu, infos_getri); // magma-cuda
+      }
   }
 #else
   return _linalg_inv_out_helper_cuda_legacy(result, infos_lu, infos_getri); // magma-cuda
@@ -1600,10 +1616,18 @@ Tensor _cholesky_solve_helper_cuda_magma(const Tensor& self, const Tensor& A, bool upper) {
 // Batched cholesky_solve is dispatched to magma.
 Tensor _cholesky_solve_helper_cuda(const Tensor& self, const Tensor& A, bool upper) {
 #ifdef USE_CUSOLVER
-  if (batchCount(self) == 1 || !use_magma_) {
-    return _cholesky_solve_helper_cuda_cusolver(self, A, upper);
-  } else {
-    return _cholesky_solve_helper_cuda_magma(self, A, upper);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Cusolver:
+      return _cholesky_solve_helper_cuda_cusolver(self, A, upper);
+    case at::LinalgBackend::Magma:
+      return _cholesky_solve_helper_cuda_magma(self, A, upper);
+    default:
+      if (batchCount(self) == 1 || !use_magma_) {
+        return _cholesky_solve_helper_cuda_cusolver(self, A, upper);
+      } else {
+        return _cholesky_solve_helper_cuda_magma(self, A, upper);
+      }
   }
 #else
   return _cholesky_solve_helper_cuda_magma(self, A, upper);
@@ -1706,10 +1730,20 @@ void cholesky_helper_magma(const Tensor& input, bool upper, const Tensor& info) {
 
 static void cholesky_kernel(const Tensor& input, const Tensor& info, bool upper) {
 #ifdef USE_CUSOLVER
-  if (batchCount(input) == 1 || !use_magma_ || use_cusolver_potrf_batched_) {
-    cholesky_helper_cusolver(input, upper, info);
-  } else {
-    cholesky_helper_magma(input, upper, info);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Cusolver:
+      cholesky_helper_cusolver(input, upper, info);
+      break;
+    case at::LinalgBackend::Magma:
+      cholesky_helper_magma(input, upper, info);
+      break;
+    default:
+      if (batchCount(input) == 1 || !use_magma_ || use_cusolver_potrf_batched_) {
+        cholesky_helper_cusolver(input, upper, info);
+      } else {
+        cholesky_helper_magma(input, upper, info);
+      }
   }
 #else
   cholesky_helper_magma(input, upper, info);
@@ -1777,10 +1811,19 @@ Tensor& cholesky_inverse_kernel_impl(Tensor &result, Tensor& infos, bool upper) {
 // result should be in column major order and contain matrices to invert
 // the content of result is overwritten by 'apply_cholesky_inverse'
 #ifdef USE_CUSOLVER
-  if (batchCount(result) == 1 || !use_magma_) {
-    return cholesky_inverse_kernel_impl_cusolver(result, infos, upper);
-  } else {
-    return cholesky_inverse_kernel_impl_magma(result, infos, upper);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Cusolver:
+      return cholesky_inverse_kernel_impl_cusolver(result, infos, upper);
+    case at::LinalgBackend::Magma:
+      return cholesky_inverse_kernel_impl_magma(result, infos, upper);
+    default:
+      if (batchCount(result) == 1 ||
+          !use_magma_) {
+        return cholesky_inverse_kernel_impl_cusolver(result, infos, upper);
+      } else {
+        return cholesky_inverse_kernel_impl_magma(result, infos, upper);
+      }
   }
 #else
   return cholesky_inverse_kernel_impl_magma(result, infos, upper);
@@ -1944,20 +1987,39 @@ static void lu_batched_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
 static void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
   int64_t batch_size = batchCount(input);
 #ifdef USE_CUSOLVER
-  // Use a heuristic to determine that cusolver is faster than MAGMA for the following sizes.
-  auto m = input.size(-2);
-  // exclude complex128 since nan_to_num_ does not work with it.
-  if ((batch_size == 1 || (batch_size <= 8 && m <= 16) || !use_magma_ ) && !input.is_complex()) {
-    lu_looped_cusolver(input, pivots, infos, compute_pivots);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Cusolver:
+      lu_looped_cusolver(input, pivots, infos, compute_pivots);
+      break;
+    case at::LinalgBackend::Magma:
+      if (batch_size == 1) {
+        lu_looped_magma(input, pivots, infos, compute_pivots);
+      } else {
+        lu_batched_magma(input, pivots, infos, compute_pivots);
+      }
+      break;
+    default:
+      // Use a heuristic to determine that cusolver is faster than MAGMA for the following sizes.
+      auto m = input.size(-2);
+      // exclude complex128 since nan_to_num_ does not work with it.
+      if ((batch_size == 1 ||
+           (batch_size <= 8 && m <= 16) ||
+           !use_magma_)
+          && !input.is_complex()) {
+        lu_looped_cusolver(input, pivots, infos, compute_pivots);
+      } else {
+        lu_batched_magma(input, pivots, infos, compute_pivots);
+      }
   }
 #else
   if (batch_size == 1) {
     lu_looped_magma(input, pivots, infos, compute_pivots);
   }
-#endif // USE_CUSOLVER
   else {
     lu_batched_magma(input, pivots, infos, compute_pivots);
   }
+#endif // USE_CUSOLVER
 }
 
 REGISTER_CUDA_DISPATCH(lu_stub, &apply_lu);
@@ -2064,12 +2126,12 @@ Tensor& orgqr_kernel_impl(Tensor& result, const Tensor& tau) {
 // See discussions in https://github.com/pytorch/pytorch/pull/51348 for comparison of cuSOLVER-MAGMA
 // and Windows failure.
 // For reference here is the MAGMA-based implementation: https://gist.github.com/IvanYashchuk/2db50002c9d3c1462ff769e6410ad983
-#if defined(USE_CUSOLVER)
-  return orgqr_helper_cusolver(result, tau); // cusolver
-#else
-  TORCH_CHECK(false, "Calling torch.orgqr on a CUDA tensor requires compiling ",
-    "PyTorch with cuSOLVER. Please use PyTorch built with cuSOLVER support.");
-#endif
+  #if defined(USE_CUSOLVER)
+    return orgqr_helper_cusolver(result, tau); // cusolver
+  #else
+    TORCH_CHECK(false, "Calling torch.orgqr on a CUDA tensor requires compiling ",
+      "PyTorch with cuSOLVER. Please use PyTorch built with cuSOLVER support.");
+  #endif
 }
 
 REGISTER_CUDA_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
@@ -2136,7 +2198,14 @@ void geqrf_magma(const Tensor& input, const Tensor& tau) {
 // This is a backend library dispatching helper function for calling looped batch implementation
 void geqrf_looped(const Tensor& input, const Tensor& tau) {
 #if defined(USE_CUSOLVER)
-  return geqrf_cusolver(input, tau);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Magma:
+      return geqrf_magma(input, tau);
+    case at::LinalgBackend::Cusolver:
+    default:
+      return geqrf_cusolver(input, tau);
+  }
 #else
   return geqrf_magma(input, tau);
 #endif
@@ -2273,9 +2342,16 @@ std::tuple<Tensor, Tensor> linalg_qr_helper_magma(const Tensor& self, c10::string_view mode) {
 
 std::tuple<Tensor, Tensor> _linalg_qr_helper_cuda(const Tensor& input, c10::string_view mode) {
 #if defined(USE_CUSOLVER)
-  // _linalg_qr_helper_default is a generic function that is implemented using
-  // geqrf_stub and orgqr_stub. It dispatches to cuSOLVER for CUDA inputs if USE_CUSOLVER is defined
-  return _linalg_qr_helper_default(input, mode);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Magma:
+      return linalg_qr_helper_magma(input, mode);
+    case at::LinalgBackend::Cusolver:
+    default:
+      // _linalg_qr_helper_default is a generic function that is implemented using
+      // geqrf_stub and orgqr_stub. It dispatches to cuSOLVER for CUDA inputs if USE_CUSOLVER is defined
+      return _linalg_qr_helper_default(input, mode);
+  }
 #else
   return linalg_qr_helper_magma(input, mode);
 #endif
@@ -2432,7 +2508,15 @@ void linalg_eigh_magma(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
 
 void linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
 #if defined(USE_CUSOLVER)
-  linalg_eigh_cusolver(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Magma:
+      linalg_eigh_magma(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
+      break;
+    case at::LinalgBackend::Cusolver:
+    default:
+      linalg_eigh_cusolver(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
+  }
 #else
   linalg_eigh_magma(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
 #endif
@@ -2731,7 +2815,14 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda_legacy(const Tensor& self, bool some, bool compute_uv) {
 
 std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda(const Tensor& self, bool some, bool compute_uv) {
 #ifdef USE_CUSOLVER
-  return _svd_helper_cuda_lib(self, some, compute_uv);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Magma:
+      return _svd_helper_cuda_legacy(self, some, compute_uv);
+    case at::LinalgBackend::Cusolver:
+    default:
+      return _svd_helper_cuda_lib(self, some, compute_uv);
+  }
 #else
   return _svd_helper_cuda_legacy(self, some, compute_uv);
 #endif
@@ -3046,10 +3137,17 @@ void linalg_lstsq_gels(const Tensor& A, const Tensor& B, const Tensor& /*infos*/) {
 
 void gels_looped(const Tensor& a, Tensor& b, Tensor& infos) {
 #if defined(USE_CUSOLVER)
-  // linalg_lstsq_gels is a generic function that is implemented using
-  // geqrf_stub, ormqr_stub, and triangular_solve_stub
-  // It dispatches to cuSOLVER for CUDA inputs if USE_CUSOLVER is defined
-  return linalg_lstsq_gels(a, b, infos);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Magma:
+      return gels_magma(a, b, infos);
+    case at::LinalgBackend::Cusolver:
+    default:
+      // linalg_lstsq_gels is a generic function that is implemented using
+      // geqrf_stub, ormqr_stub, and triangular_solve_stub
+      // It dispatches to cuSOLVER for CUDA inputs if USE_CUSOLVER is defined
+      return linalg_lstsq_gels(a, b, infos);
+  }
 #else
   return gels_magma(a, b, infos);
 #endif
```
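As the summary argues, no fixed heuristic wins on every device, library version, and shape, so the practical way to use this switch is to measure. A hedged sketch of timing the two backends for one workload and keeping the faster one (the harness, shapes, and iteration counts are ours, not part of this PR):

```python
import time
import torch

def time_backend(backend, fn, warmup=3, iters=10):
    # Set the process-wide preference, then take a rough wall-clock average of fn().
    torch.backends.cuda.preferred_linalg_library(backend)
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

# A batch of well-conditioned matrices; matrix inverse is one of the ops rewired above.
a = torch.randn(8, 256, 256, device='cuda') + 256 * torch.eye(256, device='cuda')
for backend in ('cusolver', 'magma'):
    ms = time_backend(backend, lambda: torch.linalg.inv(a)) * 1e3
    print(f'{backend}: {ms:.3f} ms')
```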

docs/source/backends.rst (+2)

```diff
@@ -45,6 +45,8 @@ torch.backends.cuda
 
     Clears the cuFFT plan cache.
 
+.. autofunction:: torch.backends.cuda.preferred_linalg_library
+
 
 torch.backends.cudnn
 ^^^^^^^^^^^^^^^^^^^^
```
