Skip to content

Commit b2b8ba3

Browse files
committed
Merge pull request torch#373 from bartvm/ormqr
LAPACK ormqr routine
2 parents 727be61 + c7c9c88 commit b2b8ba3

File tree

7 files changed

+133
-1
lines changed

7 files changed

+133
-1
lines changed

TensorMath.lua

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,6 +1156,22 @@ static void THTensor_random1__(THTensor *self, THGenerator *gen, long b)
11561156
{name=Tensor},
11571157
{name=Tensor}}
11581158
)
1159+
interface:wrap("ormqr",
1160+
cname("ormqr"),
1161+
{{name=Tensor, returned=true},
1162+
{name=Tensor},
1163+
{name=Tensor},
1164+
{name=Tensor},
1165+
{name='charoption', values={'L', 'R'}, default='L'},
1166+
{name='charoption', values={'N', 'T'}, default='N'}},
1167+
cname("ormqr"),
1168+
{{name=Tensor, default=true, returned=true, invisible=true},
1169+
{name=Tensor},
1170+
{name=Tensor},
1171+
{name=Tensor},
1172+
{name='charoption', values={'L', 'R'}, default='L'},
1173+
{name='charoption', values={'N', 'T'}, default='N'}}
1174+
)
11591175
end
11601176

11611177
method:register(string.format("m_torch_%sMath__", Tensor))

doc/maths.md

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1953,7 +1953,8 @@ z = x * y
19531953
### torch.qr([q, r], x) ###
19541954

19551955
Compute a QR decomposition of the matrix `x`: matrices `q` and `r` such that
1956-
`x = q * r`, with `q` orthogonal and `r` upper triangular.
1956+
`x = q * r`, with `q` orthogonal and `r` upper triangular. This returns
1957+
the thin (reduced) QR factorization.
19571958

19581959
`=torch.qr(x)` returns the Q and R components as new matrices.
19591960

@@ -2028,6 +2029,23 @@ given by `torch.geqrf`. See
20282029
[LAPACK documentation](http://www.netlib.org/netlib/lapack/double/dorgqr.f) for
20292030
further details.
20302031

2032+
<a name="torch.ormqr"></a>
2033+
### torch.ormqr([res], m, tau, mat [, 'L' or 'R'] [, 'N' or 'T']) ###
2034+
2035+
Multiply a matrix with `Q` as defined by the elementary reflectors and
2036+
scalar factors returned by `geqrf`. This is a low-level function for
2037+
calling LAPACK directly. You'll generally want to use `torch.qr()`
2038+
instead.
2039+
2040+
* `side` (`'L'` or `'R'`) specifies whether `mat` should be
2041+
left-multiplied, `Q * mat`, or right-multiplied, `mat * Q`.
2042+
* `trans` (`'N'` or `'T`') specifies whether `Q` should be transposed
2043+
before being multiplied.
2044+
2045+
See [LAPACK
2046+
documentation](http://www.netlib.org/netlib/lapack/double/dormqr.f) for
2047+
further details.
2048+
20312049
<a name="torch.logical.dok"></a>
20322050
## Logical Operations on Tensors ##
20332051

lib/TH/generic/THLapack.c

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ TH_EXTERNC void sgeqrf_(int *m, int *n, float *a, int *lda, float *tau, float *w
2929
TH_EXTERNC void dgeqrf_(int *m, int *n, double *a, int *lda, double *tau, double *work, int *lwork, int *info);
3030
TH_EXTERNC void sorgqr_(int *m, int *n, int *k, float *a, int *lda, float *tau, float *work, int *lwork, int *info);
3131
TH_EXTERNC void dorgqr_(int *m, int *n, int *k, double *a, int *lda, double *tau, double *work, int *lwork, int *info);
32+
TH_EXTERNC void sormqr_(char *side, char *trans, int *m, int *n, int *k, float *a, int *lda, float *tau, float *c, int *ldc, float *work, int *lwork, int *info);
33+
TH_EXTERNC void dormqr_(char *side, char *trans, int *m, int *n, int *k, double *a, int *lda, double *tau, double *c, int *ldc, double *work, int *lwork, int *info);
3234

3335

3436
/* Compute the solution to a real system of linear equations A * X = B */
@@ -218,5 +220,19 @@ void THLapack_(orgqr)(int m, int n, int k, real *a, int lda, real *tau, real *wo
218220
#endif
219221
}
220222

223+
/* Multiply Q with a matrix using the output of geqrf */
224+
void THLapack_(ormqr)(char side, char trans, int m, int n, int k, real *a, int lda, real *tau, real *c, int ldc, real *work, int lwork, int *info)
225+
{
226+
#ifdef USE_LAPACK
227+
#if defined(TH_REAL_IS_DOUBLE)
228+
dormqr_(&side, &trans, &m, &n, &k, a, &lda, tau, c, &ldc, work, &lwork, info);
229+
#else
230+
sormqr_(&side, &trans, &m, &n, &k, a, &lda, tau, c, &ldc, work, &lwork, info);
231+
#endif
232+
#else
233+
THError("ormqr: Lapack library not found in compile time\n");
234+
#endif
235+
}
236+
221237

222238
#endif

lib/TH/generic/THLapack.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,7 @@ void THLapack_(potrs)(char uplo, int n, int nrhs, real *a, int lda, real *b, int
3131
void THLapack_(geqrf)(int m, int n, real *a, int lda, real *tau, real *work, int lwork, int *info);
3232
/* Build Q from output of geqrf */
3333
void THLapack_(orgqr)(int m, int n, int k, real *a, int lda, real *tau, real *work, int lwork, int *info);
34+
/* Multiply Q with a matrix from output of geqrf */
35+
void THLapack_(ormqr)(char side, char trans, int m, int n, int k, real *a, int lda, real *tau, real *c, int ldc, real *work, int lwork, int *info);
3436

3537
#endif

lib/TH/generic/THTensorLapack.c

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,4 +637,63 @@ void THTensor_(orgqr)(THTensor *ra_, THTensor *a, THTensor *tau)
637637
THTensor_(free)(work);
638638
}
639639

640+
/*
641+
The ormqr function multiplies Q with another matrix from a sequence of
642+
elementary reflectors, such as is produced by the geqrf function.
643+
644+
Args:
645+
* `ra_` - result Tensor, which will contain the matrix Q' c.
646+
* `a` - input Tensor, which should be a matrix with the directions of the
647+
elementary reflectors below the diagonal. If NULL, `ra_` is used as
648+
input.
649+
* `tau` - input Tensor, containing the magnitudes of the elementary
650+
reflectors.
651+
* `c` - input Tensor, containing the matrix to be multiplied.
652+
* `side` - char, determining whether c is left- or right-multiplied with Q.
653+
* `trans` - char, determining whether to transpose Q before multiplying.
654+
655+
For further details, please see the LAPACK documentation.
656+
657+
*/
658+
void THTensor_(ormqr)(THTensor *ra_, THTensor *a, THTensor *tau, THTensor *c, const char *side, const char *trans)
659+
{
660+
if (a == NULL) a = ra_;
661+
THArgCheck(a->nDimension == 2, 1, "A should be 2 dimensional");
662+
663+
THTensor *ra__ = NULL;
664+
ra__ = THTensor_(cloneColumnMajor)(ra_, c);
665+
666+
int m = c->size[0];
667+
int n = c->size[1];
668+
int k = tau->size[0];
669+
int lda;
670+
if (*side == 'L')
671+
{
672+
lda = m;
673+
}
674+
else
675+
{
676+
lda = n;
677+
}
678+
int ldc = m;
679+
680+
/* Dry-run to query the suggested size of the workspace. */
681+
int info = 0;
682+
real wkopt = 0;
683+
THLapack_(ormqr)(side[0], trans[0], m, n, k, THTensor_(data)(a), lda,
684+
THTensor_(data)(tau), THTensor_(data)(ra__), ldc,
685+
&wkopt, -1, &info);
686+
687+
/* Allocate the workspace and call LAPACK to do the real work. */
688+
int lwork = (int)wkopt;
689+
THTensor *work = THTensor_(newWithSize1d)(lwork);
690+
THLapack_(ormqr)(side[0], trans[0], m, n, k, THTensor_(data)(a), lda,
691+
THTensor_(data)(tau), THTensor_(data)(ra__), ldc,
692+
THTensor_(data)(work), lwork, &info);
693+
694+
THLapackCheck(" Lapack Error %s : unknown Lapack error. info = %i", "ormqr", info);
695+
THTensor_(freeCopyTo)(ra__, ra_);
696+
THTensor_(free)(work);
697+
}
698+
640699
#endif

lib/TH/generic/THTensorLapack.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@ TH_API void THTensor_(potrf)(THTensor *ra_, THTensor *a);
1515
TH_API void THTensor_(qr)(THTensor *rq_, THTensor *rr_, THTensor *a);
1616
TH_API void THTensor_(geqrf)(THTensor *ra_, THTensor *rtau_, THTensor *a);
1717
TH_API void THTensor_(orgqr)(THTensor *ra_, THTensor *a, THTensor *tau);
18+
TH_API void THTensor_(ormqr)(THTensor *ra_, THTensor *a, THTensor *tau, THTensor *c, const char *side, const char *trans);
1819

1920
#endif

test/test_qr.lua

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,21 @@ local function qrManual(geqrfFunc, orgqrFunc)
6060
end
6161
end
6262

63+
-- Check that Q multiplied with a matrix with ormqr gives the correct result
64+
local function checkQM(testOpts, mat1, mat2)
65+
local q, r = torch.qr(mat1)
66+
local m, tau = torch.geqrf(mat1)
67+
local requiredPrecision = 1e-5
68+
tester:assertTensorEq(torch.mm(q, mat2), torch.ormqr(m, tau, mat2),
69+
requiredPrecision)
70+
tester:assertTensorEq(torch.mm(mat2, q), torch.ormqr(m, tau, mat2, 'R'),
71+
requiredPrecision)
72+
tester:assertTensorEq(torch.mm(q:t(), mat2),
73+
torch.ormqr(m, tau, mat2, 'L', 'T'), requiredPrecision)
74+
tester:assertTensorEq(torch.mm(mat2, q:t()),
75+
torch.ormqr(m, tau, mat2, 'R', 'T'), requiredPrecision)
76+
end
77+
6378
-- Check that the given `q`, `r` matrices are a valid QR decomposition of `a`.
6479
local function checkQR(testOpts, a, q, r)
6580
local qrFunc = testOpts.qr
@@ -250,5 +265,10 @@ addTestVariations(tests, 'randomNonContiguous', function(testOpts)
250265
end
251266
end)
252267

268+
function tests.testQM()
269+
checkQM({}, torch.randn(10, 10), torch.randn(10, 10))
270+
-- checkQM({}, torch.randn(20, 10), torch.randn(20, 20))
271+
end
272+
253273
tester:add(tests)
254274
tester:run()

0 commit comments

Comments
 (0)