@@ -1012,12 +1012,31 @@ void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinf
10121012
// Batched solve: for every batch index, hands one matrix from atf, the matching
// pivot slice and the matching slice of the right-hand sides to LAPACK getrs.
//   rb_    - result tensor; when it is a different tensor from b it is resized
//            like b and filled with a copy of b before solving
//   b      - right-hand sides, 2D or 3D
//   atf    - 3D tensor of square matrices passed to getrs as A
//            (presumably the LU factors produced by btrifact -- confirm)
//   pivots - pivot indices, selected per batch along dimension 0
// NOTE: this listing is a diff. Lines marked '-' are the pre-change code, lines
// marked '+' are the new code, and code between hunks is not shown.
10131013void THTensor_ (btrisolve )(THTensor * rb_ , THTensor * b , THTensor * atf , THIntTensor * pivots )
10141014{
1015- THArgCheck (THTensor_ (nDimension )(atf ) == 3 , 1 , "expected 3D tensor, got %dD" , THTensor_ (nDimension )(atf ));
// Shape validation: atf must be 3D, b must be 2D or 3D, both must agree on the
// batch count (dim 0) and on dim 1, and each matrix in atf must be square.
// NOTE(review): the second THArgCheck argument is presumably the 1-based index
// of the offending parameter; the values used below (1 for atf, 4 for b, 3 for
// the size checks) do not obviously match the parameter order -- confirm.
1015+ THArgCheck (THTensor_ (nDimension )(atf ) == 3 , 1 , "expected 3D tensor, got %dD" ,
1016+ THTensor_ (nDimension )(atf ));
1017+ THArgCheck (THTensor_ (nDimension )(b ) == 3 ||
1018+ THTensor_ (nDimension )(b ) == 2 , 4 , "expected 2D or 3D tensor" );
1019+ THArgCheck (THTensor_ (size )(atf , 0 ) ==
1020+ THTensor_ (size )(b , 0 ), 3 , "number of batches must be equal" );
1021+ THArgCheck (THTensor_ (size )(atf , 1 ) ==
1022+ THTensor_ (size )(atf , 2 ), 3 , "A matrices must be square" );
1023+ THArgCheck (THTensor_ (size )(atf , 1 ) ==
1024+ THTensor_ (size )(b , 1 ), 3 , "dimensions of A and b must be equal" );
10161025
1017- int lda ;
// Seed the result with the right-hand sides unless the caller passed the same
// tensor for both. This now happens before the layout handling of B below,
// which inspects rb_'s strides.
1026+ if (rb_ != b ) {
1027+ THTensor_ (resizeAs )(rb_ , b );
1028+ THTensor_ (copy )(rb_ , b );
1029+ }
1030+
1031+ long num_batches = atf -> size [0 ];
1032+ long n = atf -> size [1 ];
// Number of right-hand-side columns per batch entry: size of dim 2 for a 3D
// rb_, otherwise a single column.
1033+ int nrhs = rb_ -> nDimension > 2 ? rb_ -> size [2 ] : 1 ;
1034+
// lda / ldb are the leading dimensions passed to getrs for A and B.
1035+ int lda , ldb ;
10181036 THTensor * atf_ ;
// rb__ is the tensor actually handed to getrs: either rb_ itself or a
// temporary column-ordered copy of it.
1037+ THTensor * rb__ ;
10191038
1020- // correct ordering of A_a
1039+ // correct ordering of A
10211040 if (atf -> stride [1 ] == 1 ) {
10221041 // column ordered, what BLAS wants
10231042 lda = atf -> stride [2 ];
@@ -1034,14 +1053,29 @@ void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor
// (the start of the non-column-ordered branch for A lies between hunks and is
// not shown)
10341053 lda = atf_ -> stride [2 ];
10351054 }
10361055
1037- if (rb_ != b ) {
1038- THTensor_ (resizeAs )(rb_ , b );
1039- THTensor_ (copy )(rb_ , b );
1056+ // correct ordering of B
1057+ if (rb_ -> stride [1 ] == 1 ) {
1058+ // column ordered
// With a single right-hand-side column (2D rb_, or dim 2 of size 1) the leading
// dimension is set to n instead of being read from stride[2].
1059+ if (rb_ -> nDimension == 2 || rb_ -> size [2 ] == 1 ) {
1060+ ldb = n ;
1061+ } else {
1062+ ldb = rb_ -> stride [2 ];
1063+ }
1064+ rb__ = rb_ ;
1065+ } else {
1066+ // make column ordered
1067+ if (rb_ -> nDimension > 2 ) {
// Clone the (1,2)-transposed view, then transpose the clone back: rb__ ends up
// with rb_'s shape but its own storage. Presumably newClone yields a contiguous
// tensor, which would make dim 1 unit-stride after the second transpose --
// confirm.
1068+ THTensor * transp_r_ = THTensor_ (newTranspose )(rb_ , 1 , 2 );
1069+ rb__ = THTensor_ (newClone )(transp_r_ );
1070+ THTensor_ (free )(transp_r_ );
1071+ THTensor_ (transpose )(rb__ , NULL , 1 , 2 );
1072+ ldb = rb__ -> stride [2 ];
1073+ } else {
// 2D rb_ whose dim 1 is not unit-stride: work on a clone (presumably contiguous
// -- confirm) with leading dimension n.
1074+ rb__ = THTensor_ (newClone )(rb_ );
1075+ ldb = n ;
1076+ }
10401077 }
10411078
1042- long num_batches = atf -> size [0 ];
1043- long n = atf -> size [1 ];
1044-
// Scratch tensors that are re-pointed at one batch slice per loop iteration.
10451079 THTensor * ai = THTensor_ (new )();
10461080 THTensor * rbi = THTensor_ (new )();
10471081 THIntTensor * pivoti = THIntTensor_new ();
@@ -1052,14 +1086,14 @@ void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor
10521086
10531087 for (long batch = 0 ; batch < num_batches ; ++ batch ) {
10541088 THTensor_ (select )(ai , atf_ , 0 , batch );
1055- THTensor_ (select )(rbi , rb_ , 0 , batch );
// Select from rb__ (the column-ordered tensor), not rb_, so the data pointer
// given to getrs matches ldb.
1089+ THTensor_ (select )(rbi , rb__ , 0 , batch );
10561090 THIntTensor_select (pivoti , pivots , 0 , batch );
10571091
// Float/double builds call getrs with trans = 'N'; the change passes nrhs and
// ldb where the old code hard-coded 1 and n. Any nonzero info raises THError.
10581092#if defined(TH_REAL_IS_FLOAT ) || defined(TH_REAL_IS_DOUBLE )
10591093 int info ;
1060- THLapack_ (getrs )('N' , n , 1 , THTensor_ (data )(ai ), lda ,
1094+ THLapack_ (getrs )('N' , n , nrhs , THTensor_ (data )(ai ), lda ,
10611095 THIntTensor_data (pivoti ), THTensor_ (data )(rbi ),
1062- n , & info );
1096+ ldb , & info );
10631097 if (info != 0 ) {
10641098 THError ("Error: Nonzero info." );
10651099 }
@@ -1075,7 +1109,10 @@ void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor
// Release the temporary copy of A if one was made.
10751109 if (atf_ != atf ) {
10761110 THTensor_ (free )(atf_ );
10771111 }
1078- }
10791112
// If B was solved in a temporary, hand it to freeCopyTo with rb_ as the
// destination (presumably copies the result into rb_ and frees the temporary
// -- confirm).
1113+ if (rb__ != rb_ ) {
1114+ THTensor_ (freeCopyTo )(rb__ , rb_ );
1115+ }
1116+ }
10801117
10811118#endif
0 commit comments