Skip to content

Commit 04510e2

Browse files
committed
adding potrs and uplo option to potrf
adding tests for torch.potrs and (modified) torch.potrf
1 parent 727be61 commit 04510e2

File tree

6 files changed

+221
-34
lines changed

6 files changed

+221
-34
lines changed

TensorMath.lua

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,20 +1109,33 @@ static void THTensor_random1__(THTensor *self, THGenerator *gen, long b)
11091109
{{name=Tensor, default=true, returned=true, invisible=true},
11101110
{name=Tensor}}
11111111
)
1112-
1113-
interface:wrap("potri",
1114-
cname("potri"),
1112+
interface:wrap("potrf",
1113+
cname("potrf"),
11151114
{{name=Tensor, returned=true},
1116-
{name=Tensor}},
1117-
cname("potri"),
1115+
{name=Tensor},
1116+
{name='charoption', values={'U', 'L'}, default='U'}}, -- uplo
1117+
cname("potrf"),
11181118
{{name=Tensor, default=true, returned=true, invisible=true},
1119-
{name=Tensor}}
1119+
{name=Tensor},
1120+
{name='charoption', values={'U', 'L'}, default='U'}}
11201121
)
1121-
interface:wrap("potrf",
1122-
cname("potrf"),
1122+
interface:wrap("potrs",
1123+
cname("potrs"),
1124+
{{name=Tensor, returned=true},
1125+
{name=Tensor},
1126+
{name=Tensor},
1127+
{name='charoption', values={'U', 'L'}, default='U'}}, -- uplo
1128+
cname("potrs"),
1129+
{{name=Tensor, default=true, returned=true, invisible=true},
1130+
{name=Tensor},
1131+
{name=Tensor},
1132+
{name='charoption', values={'U', 'L'}, default='U'}}
1133+
)
1134+
interface:wrap("potri",
1135+
cname("potri"),
11231136
{{name=Tensor, returned=true},
11241137
{name=Tensor}},
1125-
cname("potrf"),
1138+
cname("potri"),
11261139
{{name=Tensor, default=true, returned=true, invisible=true},
11271140
{name=Tensor}}
11281141
)

doc/maths.md

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1621,6 +1621,102 @@ x = torch.trtrs(b, a)
16211621
4.1895292266754e-15
16221622
```
16231623

1624+
<a name="torch."></a>
1625+
### torch.potrf([res,] A [, 'U' or 'L'] ) ###
1626+
1627+
Cholesky Decomposition of 2D tensor `A`. Matrix `A` has to be a positive-definite and either symetric or complex Hermitian.
1628+
1629+
Optional character `uplo` = {'U', 'L'} specified whether the upper or lower triangular decomposition should be returned. By default, `uplo` = 'U'.
1630+
1631+
`X = torch.potrf(A, 'U')` returns the upper triangular Cholesky decomposition of X.
1632+
1633+
`X = torch.potrf(A, 'L')` returns the lower triangular Cholesky decomposition of X.
1634+
1635+
If tensor `res` is provided, the resulting decomposition will be stored therein.
1636+
1637+
```
1638+
A = torch.Tensor({
1639+
{1.2705, 0.9971, 0.4948, 0.1389, 0.2381},
1640+
{0.9971, 0.9966, 0.6752, 0.0686, 0.1196},
1641+
{0.4948, 0.6752, 1.1434, 0.0314, 0.0582},
1642+
{0.1389, 0.0686, 0.0314, 0.0270, 0.0526},
1643+
{0.2381, 0.1196, 0.0582, 0.0526, 0.3957}})
1644+
1645+
chol = torch.potrf(A)
1646+
> chol
1647+
1.1272 0.8846 0.4390 0.1232 0.2112
1648+
0.0000 0.4626 0.6200 -0.0874 -0.1453
1649+
0.0000 0.0000 0.7525 0.0419 0.0738
1650+
0.0000 0.0000 0.0000 0.0491 0.2199
1651+
0.0000 0.0000 0.0000 0.0000 0.5255
1652+
[torch.DoubleTensor of size 5x5]
1653+
1654+
torch.potrf(chol, A, 'L')
1655+
> chol
1656+
1.1272 0.0000 0.0000 0.0000 0.0000
1657+
0.8846 0.4626 0.0000 0.0000 0.0000
1658+
0.4390 0.6200 0.7525 0.0000 0.0000
1659+
0.1232 -0.0874 0.0419 0.0491 0.0000
1660+
0.2112 -0.1453 0.0738 0.2199 0.5255
1661+
[torch.DoubleTensor of size 5x5]
1662+
```
1663+
1664+
<a name="torch."></a>
1665+
### torch.potrs([res,] chol [, 'U' or 'L'] ) ###
1666+
1667+
Returns the solution to linear system `AX = B` using the Cholesky decomposition `chol` of 2D tensor `A`.
1668+
1669+
Square matrix `chol` should be triangular; and, righthand side matrix `B` should be of full rank.
1670+
1671+
Optional character `uplo` = {'U', 'L'} specified matrix `chol` as being other upper or lower triangular; and, by default, equals 'U'.
1672+
1673+
If tensor `res` is provided, the resulting decomposition will be stored therein.
1674+
1675+
```
1676+
A = torch.Tensor({
1677+
{1.2705, 0.9971, 0.4948, 0.1389, 0.2381},
1678+
{0.9971, 0.9966, 0.6752, 0.0686, 0.1196},
1679+
{0.4948, 0.6752, 1.1434, 0.0314, 0.0582},
1680+
{0.1389, 0.0686, 0.0314, 0.0270, 0.0526},
1681+
{0.2381, 0.1196, 0.0582, 0.0526, 0.3957}})
1682+
1683+
B = torch.Tensor({
1684+
{0.6219, 0.3439, 0.0431},
1685+
{0.5642, 0.1756, 0.0153},
1686+
{0.2334, 0.8594, 0.4103},
1687+
{0.7556, 0.1966, 0.9637},
1688+
{0.1420, 0.7185, 0.7476}})
1689+
1690+
chol = torch.potrf(A)
1691+
> chol
1692+
1.1272 0.8846 0.4390 0.1232 0.2112
1693+
0.0000 0.4626 0.6200 -0.0874 -0.1453
1694+
0.0000 0.0000 0.7525 0.0419 0.0738
1695+
0.0000 0.0000 0.0000 0.0491 0.2199
1696+
0.0000 0.0000 0.0000 0.0000 0.5255
1697+
[torch.DoubleTensor of size 5x5]
1698+
1699+
solve = torch.potrs(B, chol)
1700+
> solve
1701+
12.1945 61.8622 92.6882
1702+
-11.1782 -97.0303 -138.4874
1703+
-15.3442 -76.6562 -116.8218
1704+
6.1930 13.5238 25.2056
1705+
29.9678 251.7346 360.2301
1706+
[torch.DoubleTensor of size 5x3]
1707+
1708+
> A*solve
1709+
0.6219 0.3439 0.0431
1710+
0.5642 0.1756 0.0153
1711+
0.2334 0.8594 0.4103
1712+
0.7556 0.1966 0.9637
1713+
0.1420 0.7185 0.7476
1714+
[torch.DoubleTensor of size 5x3]
1715+
1716+
> B:dist(A*solve)
1717+
4.6783066076306e-14
1718+
```
1719+
16241720
<a name="torch.gels"></a>
16251721
### torch.gels([resb, resa,] b, a) ###
16261722

lib/TH/generic/THLapack.c

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -162,31 +162,31 @@ void THLapack_(potrf)(char uplo, int n, real *a, int lda, int *info)
162162
#endif
163163
}
164164

165-
/* Cholesky factorization based Matrix Inverse */
166-
void THLapack_(potri)(char uplo, int n, real *a, int lda, int *info)
165+
/* Solve A*X = B with a symmetric positive definite matrix A using the Cholesky factorization */
166+
void THLapack_(potrs)(char uplo, int n, int nrhs, real *a, int lda, real *b, int ldb, int *info)
167167
{
168168
#ifdef USE_LAPACK
169169
#if defined(TH_REAL_IS_DOUBLE)
170-
dpotri_(&uplo, &n, a, &lda, info);
170+
dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
171171
#else
172-
spotri_(&uplo, &n, a, &lda, info);
172+
spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
173173
#endif
174174
#else
175-
THError("potri: Lapack library not found in compile time\n");
175+
THError("potrs: Lapack library not found in compile time\n");
176176
#endif
177177
}
178178

179-
/* Solve A*X = B with a symmetric positive definite matrix A using the Cholesky factorization */
180-
void THLapack_(potrs)(char uplo, int n, int nrhs, real *a, int lda, real *b, int ldb, int *info)
179+
/* Cholesky factorization based Matrix Inverse */
180+
void THLapack_(potri)(char uplo, int n, real *a, int lda, int *info)
181181
{
182182
#ifdef USE_LAPACK
183183
#if defined(TH_REAL_IS_DOUBLE)
184-
dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
184+
dpotri_(&uplo, &n, a, &lda, info);
185185
#else
186-
spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
186+
spotri_(&uplo, &n, a, &lda, info);
187187
#endif
188188
#else
189-
THError("potrs: Lapack library not found in compile time\n");
189+
THError("potri: Lapack library not found in compile time\n");
190190
#endif
191191
}
192192

lib/TH/generic/THTensorLapack.c

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -429,16 +429,15 @@ void THTensor_(getri)(THTensor *ra_, THTensor *a)
429429
THTensor_(freeCopyTo)(ra__, ra_);
430430
THTensor_(free)(work);
431431
THIntTensor_free(ipiv);
432-
}
432+
}
433433

434-
void THTensor_(potrf)(THTensor *ra_, THTensor *a)
434+
void THTensor_(potrf)(THTensor *ra_, THTensor *a, const char *uplo)
435435
{
436436
if (a == NULL) a = ra_;
437437
THArgCheck(a->nDimension == 2, 1, "A should be 2 dimensional");
438438
THArgCheck(a->size[0] == a->size[1], 1, "A should be square");
439439

440440
int n, lda, info;
441-
char uplo = 'U';
442441
THTensor *ra__ = NULL;
443442

444443
ra__ = THTensor_(cloneColumnMajor)(ra_, a);
@@ -447,23 +446,64 @@ void THTensor_(potrf)(THTensor *ra_, THTensor *a)
447446
lda = n;
448447

449448
/* Run Factorization */
450-
THLapack_(potrf)(uplo, n, THTensor_(data)(ra__), lda, &info);
449+
THLapack_(potrf)(uplo[0], n, THTensor_(data)(ra__), lda, &info);
451450
THLapackCheck("Lapack Error %s : A(%d,%d) is 0, A cannot be factorized", "potrf", info, info);
452451

453-
/* Build full upper-triangular matrix */
452+
/* Build full matrix */
453+
real *p = THTensor_(data)(ra__);
454+
long i, j;
455+
456+
/* Upper Triangular Case */
457+
if (uplo[0] == 'U')
454458
{
455-
real *p = THTensor_(data)(ra__);
456-
long i,j;
457459
for (i=0; i<n; i++) {
458-
for (j=i+1; j<n; j++) {
459-
p[i*n+j] = 0;
460+
for (j=i+1; j<n; j++) {
461+
p[n*i + j] = 0;
462+
}
463+
}
464+
}
465+
/* Lower Triangular Case */
466+
else
467+
{
468+
for (i=0; i<n; i++) {
469+
for (j=0; j<i; j++) {
470+
p[n*i + j] = 0;
460471
}
461472
}
462473
}
463474

464475
THTensor_(freeCopyTo)(ra__, ra_);
465476
}
466477

478+
void THTensor_(potrs)(THTensor *rb_, THTensor *b, THTensor *a, const char *uplo)
479+
{
480+
if (b == NULL) b = rb_;
481+
482+
THArgCheck(a->size[0] == a->size[1], 2, "A should be square");
483+
THArgCheck(b->size[0] >= b->size[1], 2, "Matrix B is rank-deficient");
484+
485+
int n, nrhs, lda, ldb, info;
486+
THTensor *ra__; // working version of A matrix to be passed into lapack TRTRS
487+
THTensor *rb__; // working version of B matrix to be passed into lapack TRTRS
488+
489+
ra__ = THTensor_(cloneColumnMajor)(NULL, a);
490+
rb__ = THTensor_(cloneColumnMajor)(rb_, b);
491+
492+
n = (int)ra__->size[0];
493+
nrhs = (int)rb__->size[1];
494+
lda = n;
495+
ldb = n;
496+
497+
THLapack_(potrs)(uplo[0], n, nrhs, THTensor_(data)(ra__),
498+
lda, THTensor_(data)(rb__), ldb, &info);
499+
500+
501+
THLapackCheck("Lapack Error in %s : A(%d,%d) is zero, singular A", "potrs", info, info);
502+
503+
THTensor_(free)(ra__);
504+
THTensor_(freeCopyTo)(rb__, rb_);
505+
}
506+
467507
void THTensor_(potri)(THTensor *ra_, THTensor *a)
468508
{
469509
if (a == NULL) a = ra_;

lib/TH/generic/THTensorLapack.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ TH_API void THTensor_(geev)(THTensor *re_, THTensor *rv_, THTensor *a_, const ch
1010
TH_API void THTensor_(gesvd)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *a, const char *jobu);
1111
TH_API void THTensor_(gesvd2)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *ra_, THTensor *a, const char *jobu);
1212
TH_API void THTensor_(getri)(THTensor *ra_, THTensor *a);
13+
TH_API void THTensor_(potrf)(THTensor *ra_, THTensor *a, const char *uplo);
14+
TH_API void THTensor_(potrs)(THTensor *rb_, THTensor *b_, THTensor *a_, const char *uplo);
1315
TH_API void THTensor_(potri)(THTensor *ra_, THTensor *a);
14-
TH_API void THTensor_(potrf)(THTensor *ra_, THTensor *a);
1516
TH_API void THTensor_(qr)(THTensor *rq_, THTensor *rr_, THTensor *a);
1617
TH_API void THTensor_(geqrf)(THTensor *ra_, THTensor *rtau_, THTensor *a);
1718
TH_API void THTensor_(orgqr)(THTensor *ra_, THTensor *a, THTensor *tau);

test/test.lua

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2035,11 +2035,48 @@ function torchtest.testBoxMullerState()
20352035
end
20362036

20372037
function torchtest.testCholesky()
2038-
local x = torch.rand(10,10)
2039-
local A = torch.mm(x, x:t())
2040-
local C = torch.potrf(A)
2041-
local B = torch.mm(C:t(), C)
2042-
mytester:assertTensorEq(A, B, 1e-14, 'potrf did not allow rebuilding the original matrix')
2038+
local x = torch.rand(10,10)
2039+
local A = torch.mm(x, x:t())
2040+
2041+
---- Default Case
2042+
local C = torch.potrf(A)
2043+
local B = torch.mm(C:t(), C)
2044+
mytester:assertTensorEq(A, B, 1e-14, 'potrf did not allow rebuilding the original matrix')
2045+
2046+
---- Test Upper Triangular
2047+
local U = torch.potrf(A, 'U')
2048+
B = torch.mm(U:t(), U)
2049+
mytester:assertTensorEq(A, B, 1e-14, 'potrf (upper) did not allow rebuilding the original matrix')
2050+
2051+
---- Test Lower Triangular
2052+
local L = torch.potrf(A, 'L')
2053+
B = torch.mm(L, L:t())
2054+
mytester:assertTensorEq(A, B, 1e-14, 'potrf (lower) did not allow rebuilding the original matrix')
2055+
end
2056+
2057+
function torchtest.potrs()
2058+
if not torch.potrs then return end
2059+
local a=torch.Tensor({{6.80, -2.11, 5.66, 5.97, 8.23},
2060+
{-6.05, -3.30, 5.36, -4.44, 1.08},
2061+
{-0.45, 2.58, -2.70, 0.27, 9.04},
2062+
{8.32, 2.71, 4.35, -7.17, 2.14},
2063+
{-9.67, -5.14, -7.26, 6.08, -6.87}}):t()
2064+
local b=torch.Tensor({{4.02, 6.19, -8.22, -7.57, -3.03},
2065+
{-1.56, 4.00, -8.67, 1.75, 2.86},
2066+
{9.81, -4.09, -4.57, -8.61, 8.99}}):t()
2067+
2068+
---- Make sure 'a' is symmetric PSD
2069+
a = torch.mm(a, a:t())
2070+
2071+
---- Upper Triangular Test
2072+
local U = torch.potrf(a, 'U')
2073+
local x = torch.potrs(b, U, 'U')
2074+
mytester:assertlt(b:dist(a*x),1e-12,'torch.trtrs')
2075+
2076+
---- Lower Triangular Test
2077+
local L = torch.potrf(a, 'L')
2078+
x = torch.potrs(b, L, 'L')
2079+
mytester:assertlt(b:dist(a*x),1e-12,'torch.trtrs')
20432080
end
20442081

20452082
function torchtest.testNumel()

0 commit comments

Comments
 (0)