|
16 | 16 | #include <THC/THC.h> // for USE_MAGMA |
17 | 17 |
|
18 | 18 | #ifdef USE_MAGMA |
19 | | -#include <magma.h> |
20 | 19 | #include <magma_types.h> |
| 20 | +#include <magma_v2.h> |
21 | 21 |
|
22 | 22 | const bool use_magma_ = true; |
23 | 23 | #else |
@@ -95,10 +95,18 @@ void magmaCholeskyBatched( |
95 | 95 | magma_uplo_t uplo, magma_int_t n, scalar_t** dA_array, magma_int_t ldda, |
96 | 96 | magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue); |
97 | 97 |
|
98 | | -template<class scalar_t> |
| 98 | +template <class scalar_t> |
99 | 99 | void magmaTriangularSolve( |
100 | | - magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, |
101 | | - scalar_t* dA, magma_int_t ldda, scalar_t* dB, magma_int_t lddb); |
| 100 | + magma_uplo_t uplo, |
| 101 | + magma_trans_t trans, |
| 102 | + magma_diag_t diag, |
| 103 | + magma_int_t m, |
| 104 | + magma_int_t n, |
| 105 | + scalar_t* dA, |
| 106 | + magma_int_t ldda, |
| 107 | + scalar_t* dB, |
| 108 | + magma_int_t lddb, |
| 109 | + const MAGMAQueue& magma_queue); |
102 | 110 |
|
103 | 111 | template<class scalar_t> |
104 | 112 | void magmaTriangularSolveBatched( |
@@ -662,45 +670,117 @@ void magmaCholeskyBatched<c10::complex<float>>( |
662 | 670 | AT_CUDA_CHECK(cudaGetLastError()); |
663 | 671 | } |
664 | 672 |
|
665 | | -template<> |
| 673 | +template <> |
666 | 674 | void magmaTriangularSolve<double>( |
667 | | - magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, |
668 | | - double* dA, magma_int_t ldda, double* dB, magma_int_t lddb) { |
669 | | - MagmaStreamSyncGuard guard; |
670 | | - magma_dtrsm(MagmaLeft, uplo, trans, diag, m, n, 1, dA, ldda, dB, lddb); |
| 675 | + magma_uplo_t uplo, |
| 676 | + magma_trans_t trans, |
| 677 | + magma_diag_t diag, |
| 678 | + magma_int_t m, |
| 679 | + magma_int_t n, |
| 680 | + double* dA, |
| 681 | + magma_int_t ldda, |
| 682 | + double* dB, |
| 683 | + magma_int_t lddb, |
| 684 | + const MAGMAQueue& magma_queue) { |
| 685 | + magma_dtrsm( |
| 686 | + MagmaLeft, |
| 687 | + uplo, |
| 688 | + trans, |
| 689 | + diag, |
| 690 | + m, |
| 691 | + n, |
| 692 | + 1, |
| 693 | + dA, |
| 694 | + ldda, |
| 695 | + dB, |
| 696 | + lddb, |
| 697 | + magma_queue.get_queue()); |
671 | 698 | AT_CUDA_CHECK(cudaGetLastError()); |
672 | 699 | } |
673 | 700 |
|
674 | | -template<> |
| 701 | +template <> |
675 | 702 | void magmaTriangularSolve<float>( |
676 | | - magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, |
677 | | - float* dA, magma_int_t ldda, float* dB, magma_int_t lddb) { |
678 | | - MagmaStreamSyncGuard guard; |
679 | | - magma_strsm(MagmaLeft, uplo, trans, diag, m, n, 1, dA, ldda, dB, lddb); |
| 703 | + magma_uplo_t uplo, |
| 704 | + magma_trans_t trans, |
| 705 | + magma_diag_t diag, |
| 706 | + magma_int_t m, |
| 707 | + magma_int_t n, |
| 708 | + float* dA, |
| 709 | + magma_int_t ldda, |
| 710 | + float* dB, |
| 711 | + magma_int_t lddb, |
| 712 | + const MAGMAQueue& magma_queue) { |
| 713 | + magma_strsm( |
| 714 | + MagmaLeft, |
| 715 | + uplo, |
| 716 | + trans, |
| 717 | + diag, |
| 718 | + m, |
| 719 | + n, |
| 720 | + 1, |
| 721 | + dA, |
| 722 | + ldda, |
| 723 | + dB, |
| 724 | + lddb, |
| 725 | + magma_queue.get_queue()); |
680 | 726 | AT_CUDA_CHECK(cudaGetLastError()); |
681 | 727 | } |
682 | 728 |
|
683 | | -template<> |
| 729 | +template <> |
684 | 730 | void magmaTriangularSolve<c10::complex<double>>( |
685 | | - magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, |
686 | | - c10::complex<double>* dA, magma_int_t ldda, c10::complex<double>* dB, magma_int_t lddb) { |
687 | | - MagmaStreamSyncGuard guard; |
| 731 | + magma_uplo_t uplo, |
| 732 | + magma_trans_t trans, |
| 733 | + magma_diag_t diag, |
| 734 | + magma_int_t m, |
| 735 | + magma_int_t n, |
| 736 | + c10::complex<double>* dA, |
| 737 | + magma_int_t ldda, |
| 738 | + c10::complex<double>* dB, |
| 739 | + magma_int_t lddb, |
| 740 | + const MAGMAQueue& magma_queue) { |
688 | 741 | magmaDoubleComplex alpha({1, 0}); |
689 | | - magma_ztrsm(MagmaLeft, uplo, trans, diag, m, n, alpha, |
690 | | - reinterpret_cast<magmaDoubleComplex*>(dA), ldda, |
691 | | - reinterpret_cast<magmaDoubleComplex*>(dB), lddb); |
| 742 | + magma_ztrsm( |
| 743 | + MagmaLeft, |
| 744 | + uplo, |
| 745 | + trans, |
| 746 | + diag, |
| 747 | + m, |
| 748 | + n, |
| 749 | + alpha, |
| 750 | + reinterpret_cast<magmaDoubleComplex*>(dA), |
| 751 | + ldda, |
| 752 | + reinterpret_cast<magmaDoubleComplex*>(dB), |
| 753 | + lddb, |
| 754 | + magma_queue.get_queue()); |
692 | 755 | AT_CUDA_CHECK(cudaGetLastError()); |
693 | 756 | } |
694 | 757 |
|
695 | | -template<> |
| 758 | +template <> |
696 | 759 | void magmaTriangularSolve<c10::complex<float>>( |
697 | | - magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, |
698 | | - c10::complex<float>* dA, magma_int_t ldda, c10::complex<float>* dB, magma_int_t lddb) { |
699 | | - MagmaStreamSyncGuard guard; |
| 760 | + magma_uplo_t uplo, |
| 761 | + magma_trans_t trans, |
| 762 | + magma_diag_t diag, |
| 763 | + magma_int_t m, |
| 764 | + magma_int_t n, |
| 765 | + c10::complex<float>* dA, |
| 766 | + magma_int_t ldda, |
| 767 | + c10::complex<float>* dB, |
| 768 | + magma_int_t lddb, |
| 769 | + const MAGMAQueue& magma_queue) { |
700 | 770 | magmaFloatComplex alpha({1, 0}); |
701 | | - magma_ctrsm(MagmaLeft, uplo, trans, diag, m, n, alpha, |
702 | | - reinterpret_cast<magmaFloatComplex*>(dA), ldda, |
703 | | - reinterpret_cast<magmaFloatComplex*>(dB), lddb); |
| 771 | + magma_ctrsm( |
| 772 | + MagmaLeft, |
| 773 | + uplo, |
| 774 | + trans, |
| 775 | + diag, |
| 776 | + m, |
| 777 | + n, |
| 778 | + alpha, |
| 779 | + reinterpret_cast<magmaFloatComplex*>(dA), |
| 780 | + ldda, |
| 781 | + reinterpret_cast<magmaFloatComplex*>(dB), |
| 782 | + lddb, |
| 783 | + magma_queue.get_queue()); |
704 | 784 | AT_CUDA_CHECK(cudaGetLastError()); |
705 | 785 | } |
706 | 786 |
|
@@ -1636,11 +1716,14 @@ AT_ERROR("triangular_solve: MAGMA library not found in " |
1636 | 1716 | magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)"); |
1637 | 1717 | magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount"); |
1638 | 1718 |
|
| 1719 | + MAGMAQueue magma_queue(b.get_device()); |
| 1720 | + |
1639 | 1721 | // batch_size == 1 implies that: |
1640 | 1722 | // 1. the RHS and LHS tensors have 2 dimensions, or |
1641 | 1723 | // 2. the RHS and LHS tensors have more than 2 dimensions but all batch dimensions are 1 |
1642 | 1724 | if (batch_size == 1) { |
1643 | | - magmaTriangularSolve<scalar_t>(uplo, trans, diag, n, nrhs, A_data, n, b_data, n); |
| 1725 | + magmaTriangularSolve<scalar_t>( |
| 1726 | + uplo, trans, diag, n, nrhs, A_data, n, b_data, n, magma_queue); |
1644 | 1727 | } else { |
1645 | 1728 | auto A_mat_stride = matrixStride(A); |
1646 | 1729 | auto b_mat_stride = matrixStride(b); |
|
0 commit comments