Skip to content

Commit 3f4b46a

Browse files
committed
Add potrs with MAGMA
1 parent bd38b9c commit 3f4b46a

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

THCTensorMath.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ THC_API void THCudaTensor_gesvd2(THCState *state, THCudaTensor *ru_, THCudaTenso
9595
THC_API void THCudaTensor_getri(THCState *state, THCudaTensor *ra_, THCudaTensor *a);
9696
THC_API void THCudaTensor_potri(THCState *state, THCudaTensor *ra_, THCudaTensor *a);
9797
THC_API void THCudaTensor_potrf(THCState *state, THCudaTensor *ra_, THCudaTensor *a);
98+
THC_API void THCudaTensor_potrs(THCState *state, THCudaTensor *rb_, THCudaTensor *a, THCudaTensor *b);
9899
THC_API void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *rr_, THCudaTensor *a);
99100

100101
THC_API void THCudaTensor_cat(THCState *state, THCudaTensor *result, THCudaTensor *ta, THCudaTensor *tb, int dimension);

THCTensorMathMagma.cu

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,33 @@ void THCudaTensor_potrf(THCState *state, THCudaTensor *ra_, THCudaTensor *a)
507507
#endif
508508
}
509509

510+
void THCudaTensor_potrs(THCState *state, THCudaTensor *rb_, THCudaTensor *b, THCudaTensor *a)
511+
{
512+
#ifdef USE_MAGMA
513+
THArgCheck(a->size[0] == a->size[1], 2, "A should be square");
514+
515+
int n = a->size[0];
516+
int nrhs = b->size[1];
517+
518+
THCudaTensor *b_ = THCudaTensor_newColumnMajor(state, rb_, b);
519+
float *b_data = THCudaTensor_data(state, b_);
520+
THCudaTensor *a_ = THCudaTensor_newColumnMajor(state, a, a);
521+
float *a_data = THCudaTensor_data(state, a_);
522+
523+
int info;
524+
magma_spotrs_gpu(MagmaUpper, n, nrhs, a_data, n, b_data, n, &info);
525+
526+
// check error value
527+
if (info < 0)
528+
THError("MAGMA potrs : Argument %d : illegal value", -info);
529+
530+
THCudaTensor_freeCopyTo(state, b_, rb_);
531+
THCudaTensor_free(state, a_);
532+
#else
533+
THError(NoMagma(potrs));
534+
#endif
535+
}
536+
510537
void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *rr_, THCudaTensor *a_)
511538
{
512539
#ifdef USE_MAGMA

0 commit comments

Comments
 (0)