Skip to content

Commit 12f1b4f

Browse files
committed
Merge commit '84bdbe5ab4b602b021ff494487c8ad57457052d3'
2 parents 8595403 + 84bdbe5 commit 12f1b4f

File tree

1 file changed

+50
-13
lines changed

1 file changed

+50
-13
lines changed

torch/lib/TH/generic/THTensorLapack.c

Lines changed: 50 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1012,12 +1012,31 @@ void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinf
10121012

10131013
void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor *pivots)
10141014
{
1015-
THArgCheck(THTensor_(nDimension)(atf) == 3, 1, "expected 3D tensor, got %dD", THTensor_(nDimension)(atf));
1015+
THArgCheck(THTensor_(nDimension)(atf) == 3, 1, "expected 3D tensor, got %dD",
1016+
THTensor_(nDimension)(atf));
1017+
THArgCheck(THTensor_(nDimension)(b) == 3 ||
1018+
THTensor_(nDimension)(b) == 2, 4, "expected 2D or 3D tensor");
1019+
THArgCheck(THTensor_(size)(atf, 0) ==
1020+
THTensor_(size)(b, 0), 3, "number of batches must be equal");
1021+
THArgCheck(THTensor_(size)(atf, 1) ==
1022+
THTensor_(size)(atf, 2), 3, "A matrices must be square");
1023+
THArgCheck(THTensor_(size)(atf, 1) ==
1024+
THTensor_(size)(b, 1), 3, "dimensions of A and b must be equal");
10161025

1017-
int lda;
1026+
if (rb_ != b) {
1027+
THTensor_(resizeAs)(rb_, b);
1028+
THTensor_(copy)(rb_, b);
1029+
}
1030+
1031+
long num_batches = atf->size[0];
1032+
long n = atf->size[1];
1033+
int nrhs = rb_->nDimension > 2 ? rb_->size[2] : 1;
1034+
1035+
int lda, ldb;
10181036
THTensor *atf_;
1037+
THTensor *rb__;
10191038

1020-
// correct ordering of A_a
1039+
// correct ordering of A
10211040
if (atf->stride[1] == 1) {
10221041
// column ordered, what BLAS wants
10231042
lda = atf->stride[2];
@@ -1034,14 +1053,29 @@ void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor
10341053
lda = atf_->stride[2];
10351054
}
10361055

1037-
if (rb_ != b) {
1038-
THTensor_(resizeAs)(rb_, b);
1039-
THTensor_(copy)(rb_, b);
1056+
// correct ordering of B
1057+
if (rb_->stride[1] == 1) {
1058+
// column ordered
1059+
if (rb_->nDimension == 2 || rb_->size[2] == 1) {
1060+
ldb = n;
1061+
} else {
1062+
ldb = rb_->stride[2];
1063+
}
1064+
rb__ = rb_;
1065+
} else {
1066+
// make column ordered
1067+
if (rb_->nDimension > 2) {
1068+
THTensor *transp_r_ = THTensor_(newTranspose)(rb_, 1, 2);
1069+
rb__ = THTensor_(newClone)(transp_r_);
1070+
THTensor_(free)(transp_r_);
1071+
THTensor_(transpose)(rb__, NULL, 1, 2);
1072+
ldb = rb__->stride[2];
1073+
} else {
1074+
rb__ = THTensor_(newClone)(rb_);
1075+
ldb = n;
1076+
}
10401077
}
10411078

1042-
long num_batches = atf->size[0];
1043-
long n = atf->size[1];
1044-
10451079
THTensor *ai = THTensor_(new)();
10461080
THTensor *rbi = THTensor_(new)();
10471081
THIntTensor *pivoti = THIntTensor_new();
@@ -1052,14 +1086,14 @@ void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor
10521086

10531087
for (long batch = 0; batch < num_batches; ++batch) {
10541088
THTensor_(select)(ai, atf_, 0, batch);
1055-
THTensor_(select)(rbi, rb_, 0, batch);
1089+
THTensor_(select)(rbi, rb__, 0, batch);
10561090
THIntTensor_select(pivoti, pivots, 0, batch);
10571091

10581092
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
10591093
int info;
1060-
THLapack_(getrs)('N', n, 1, THTensor_(data)(ai), lda,
1094+
THLapack_(getrs)('N', n, nrhs, THTensor_(data)(ai), lda,
10611095
THIntTensor_data(pivoti), THTensor_(data)(rbi),
1062-
n, &info);
1096+
ldb, &info);
10631097
if (info != 0) {
10641098
THError("Error: Nonzero info.");
10651099
}
@@ -1075,7 +1109,10 @@ void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor
10751109
if (atf_ != atf) {
10761110
THTensor_(free)(atf_);
10771111
}
1078-
}
10791112

1113+
if (rb__ != rb_) {
1114+
THTensor_(freeCopyTo)(rb__, rb_);
1115+
}
1116+
}
10801117

10811118
#endif

0 commit comments

Comments
 (0)