module Data.Matrix.Dense.Class.Operations (
getConjMatrix,
getScaledMatrix,
getShiftedMatrix,
getAddMatrix,
getSubMatrix,
getMulMatrix,
getDivMatrix,
addMatrix,
subMatrix,
axpyMatrix,
mulMatrix,
divMatrix,
unsafeGetAddMatrix,
unsafeGetSubMatrix,
unsafeGetMulMatrix,
unsafeGetDivMatrix,
unsafeAddMatrix,
unsafeSubMatrix,
unsafeAxpyMatrix,
unsafeMulMatrix,
unsafeDivMatrix,
) where
import BLAS.Elem( BLAS1 )
import BLAS.Internal( checkBinaryOp )
import BLAS.Tensor( BaseTensor(..) )
import Data.Matrix.Dense.Class.Internal
-- | Return a newly allocated copy of the given matrix with 'doConjMatrix'
-- applied to that copy (presumably element-wise complex conjugation).
-- The argument itself is only copied from, never written to.
getConjMatrix :: (ReadMatrix a x m, WriteMatrix b y m, BLAS1 e) =>
    a mn e -> m (b mn e)
getConjMatrix src = getUnaryOp doConjMatrix src
-- | Return a newly allocated copy of the matrix with @'scaleByMatrix' k@
-- applied to the copy (presumably every entry multiplied by @k@).
-- The source matrix is left unchanged.
getScaledMatrix :: (ReadMatrix a x m, WriteMatrix b y m, BLAS1 e) =>
    e -> a mn e -> m (b mn e)
getScaledMatrix k src = getUnaryOp (scaleByMatrix k) src
-- | Return a newly allocated copy of the matrix with @'shiftByMatrix' k@
-- applied to the copy (presumably @k@ added to every entry).
-- The source matrix is left unchanged.
getShiftedMatrix :: (ReadMatrix a x m, WriteMatrix b y m, BLAS1 e) =>
    e -> a mn e -> m (b mn e)
getShiftedMatrix k src = getUnaryOp (shiftByMatrix k) src
-- | Allocate a new matrix holding the sum of the two operands.
--
-- The operands' shapes are compared first via 'checkTensorOp2'
-- (NOTE(review): what happens on a mismatch is up to 'checkBinaryOp' —
-- presumably an error), and the work is then done by 'unsafeGetAddMatrix'.
getAddMatrix ::
    (ReadMatrix a x m, ReadMatrix b x m, WriteMatrix c z m, BLAS1 e) =>
    a mn e -> b mn e -> m (c mn e)
getAddMatrix lhs rhs = checkTensorOp2 unsafeGetAddMatrix lhs rhs
-- | Unchecked counterpart of 'getAddMatrix'.
--
-- It copies the first operand into a fresh matrix and then runs
-- 'unsafeAddMatrix' on that copy with the second operand. Callers must
-- guarantee that the shapes agree.
unsafeGetAddMatrix ::
    (ReadMatrix a x m, ReadMatrix b x m, WriteMatrix c z m, BLAS1 e) =>
    a mn e -> b mn e -> m (c mn e)
unsafeGetAddMatrix lhs rhs = unsafeGetBinaryOp unsafeAddMatrix lhs rhs
-- | Allocate a new matrix intended to hold @lhs - rhs@.
--
-- The operands' shapes are compared via 'checkTensorOp2' before the
-- work is delegated to 'unsafeGetSubMatrix'.
getSubMatrix ::
    (ReadMatrix a x m, ReadMatrix b x m, WriteMatrix c z m, BLAS1 e) =>
    a mn e -> b mn e -> m (c mn e)
getSubMatrix lhs rhs = checkTensorOp2 unsafeGetSubMatrix lhs rhs
-- | Unchecked counterpart of 'getSubMatrix'.
--
-- It copies the first operand into a fresh matrix and runs
-- 'unsafeSubMatrix' on that copy with the second operand. Callers must
-- guarantee that the shapes agree.
unsafeGetSubMatrix ::
    (ReadMatrix a x m, ReadMatrix b x m, WriteMatrix c z m, BLAS1 e) =>
    a mn e -> b mn e -> m (c mn e)
unsafeGetSubMatrix lhs rhs = unsafeGetBinaryOp unsafeSubMatrix lhs rhs
-- | Allocate a new matrix holding the product computed by
-- 'unsafeMulMatrix' (presumably element-wise), after checking via
-- 'checkTensorOp2' that both operands have the same shape.
getMulMatrix ::
    (ReadMatrix a x m, ReadMatrix b x m, WriteMatrix c z m, BLAS1 e) =>
    a mn e -> b mn e -> m (c mn e)
getMulMatrix lhs rhs = checkTensorOp2 unsafeGetMulMatrix lhs rhs
-- | Unchecked counterpart of 'getMulMatrix'.
--
-- It copies the first operand and applies 'unsafeMulMatrix' to the copy
-- with the second operand. Callers must guarantee that the shapes agree.
unsafeGetMulMatrix ::
    (ReadMatrix a x m, ReadMatrix b x m, WriteMatrix c z m, BLAS1 e) =>
    a mn e -> b mn e -> m (c mn e)
unsafeGetMulMatrix lhs rhs = unsafeGetBinaryOp unsafeMulMatrix lhs rhs
-- | Allocate a new matrix holding the quotient computed by
-- 'unsafeDivMatrix' (presumably element-wise), after checking via
-- 'checkTensorOp2' that both operands have the same shape.
getDivMatrix ::
    (ReadMatrix a x m, ReadMatrix b x m, WriteMatrix c z m, BLAS1 e) =>
    a mn e -> b mn e -> m (c mn e)
getDivMatrix lhs rhs = checkTensorOp2 unsafeGetDivMatrix lhs rhs
-- | Unchecked counterpart of 'getDivMatrix'.
--
-- It copies the first operand and applies 'unsafeDivMatrix' to the copy
-- with the second operand. Callers must guarantee that the shapes agree.
unsafeGetDivMatrix ::
    (ReadMatrix a x m, ReadMatrix b x m, WriteMatrix c z m, BLAS1 e) =>
    a mn e -> b mn e -> m (c mn e)
unsafeGetDivMatrix lhs rhs = unsafeGetBinaryOp unsafeDivMatrix lhs rhs
-- | Shape-checked wrapper around 'unsafeAxpyMatrix'.
--
-- It verifies that @src@ and @dst@ report the same 'shape', then runs
-- @unsafeAxpyMatrix alpha src dst@ (presumably the BLAS-style update
-- @dst <- alpha*src + dst@).
axpyMatrix :: (ReadMatrix a x m, WriteMatrix b y m, BLAS1 e) =>
    e -> a n e -> b n e -> m ()
axpyMatrix alpha src dst =
    checkBinaryOp srcShape dstShape (unsafeAxpyMatrix alpha src dst)
  where
    srcShape = shape src
    dstShape = shape dst
-- | Add the second matrix into the first, in place, after checking that
-- the two shapes agree. The unchecked variant is 'unsafeAddMatrix'.
addMatrix :: (WriteMatrix b y m, ReadMatrix a x m, BLAS1 e) =>
    b n e -> a n e -> m ()
addMatrix dst src =
    checkBinaryOp (shape dst) (shape src) (unsafeAddMatrix dst src)
-- | In-place @dst += src@ without a shape check.
--
-- It is expressed as an axpy with coefficient @1@. This assumes
-- 'unsafeAxpyMatrix' follows the BLAS convention @y <- alpha*x + y@,
-- with @dst@ in the @y@ position.
unsafeAddMatrix :: (WriteMatrix b y m, ReadMatrix a x m, BLAS1 e) =>
    b n e -> a n e -> m ()
unsafeAddMatrix dst src = unsafeAxpyMatrix 1 src dst
-- | Shape-checked wrapper around 'unsafeSubMatrix', which is intended to
-- subtract the second matrix from the first in place.
subMatrix :: (WriteMatrix b y m, ReadMatrix a x m, BLAS1 e) =>
    b n e -> a n e -> m ()
subMatrix dst src
    = checkBinaryOp (shape dst) (shape src)
    $ unsafeSubMatrix dst src
-- | In-place @b -= a@ without a shape check; callers must guarantee that
-- the shapes agree (use 'subMatrix' for the checked version).
--
-- It is implemented as an axpy with coefficient @-1@, mirroring
-- 'unsafeAddMatrix', which uses @1@. This assumes 'unsafeAxpyMatrix'
-- follows the BLAS convention @y <- alpha*x + y@, with @b@ in the @y@
-- position, so @b@ ends up holding @b - a@.
unsafeSubMatrix :: (WriteMatrix b y m, ReadMatrix a x m, BLAS1 e) =>
    b n e -> a n e -> m ()
-- The negative coefficient is what makes this a subtraction; a
-- coefficient of 1 would make it identical to 'unsafeAddMatrix'.
unsafeSubMatrix b a = unsafeAxpyMatrix (-1) a b
-- | Shape-checked wrapper around 'unsafeMulMatrix', which updates the
-- first matrix in place using the second (presumably element-wise).
mulMatrix :: (WriteMatrix b y m, ReadMatrix a x m, BLAS1 e) =>
    b n e -> a n e -> m ()
mulMatrix dst src = checkBinaryOp (shape dst) (shape src) multiply
  where
    multiply = unsafeMulMatrix dst src
-- | Shape-checked wrapper around 'unsafeDivMatrix', which updates the
-- first matrix in place using the second (presumably element-wise).
divMatrix :: (WriteMatrix b y m, ReadMatrix a x m, BLAS1 e) =>
    b n e -> a n e -> m ()
divMatrix dst src =
    let guarded = checkBinaryOp (shape dst) (shape src)
    in  guarded (unsafeDivMatrix dst src)
-- | Apply a binary operation after comparing the 'shape's of its two
-- tensor arguments. The comparison, and whatever happens on a mismatch,
-- is delegated to 'checkBinaryOp'.
checkTensorOp2 :: (BaseTensor x i, BaseTensor y i) =>
    (x n e -> y n e -> a) ->
    x n e -> y n e -> a
checkTensorOp2 op t1 t2 = checkBinaryOp (shape t1) (shape t2) (op t1 t2)
-- | Lift an in-place unary action into one that allocates its result.
--
-- The source is copied with 'newCopyMatrix', the action is run on the
-- copy, and the copy is returned. The source is only read.
getUnaryOp :: (ReadMatrix a x m, WriteMatrix b y m, BLAS1 e) =>
    (b mn e -> m ()) -> a mn e -> m (b mn e)
getUnaryOp op src =
    newCopyMatrix src >>= \dst ->
    op dst >> return dst
-- | Lift an in-place binary action into one that allocates its result.
--
-- The left operand is copied with 'newCopyMatrix', the action is run
-- with that copy as its destination and the right operand as its source,
-- and the copy is returned. No shape check is performed here.
unsafeGetBinaryOp ::
    (WriteMatrix c z m, ReadMatrix a x m, ReadMatrix b x m, BLAS1 e) =>
    (c n e -> b n e -> m ()) ->
    a n e -> b n e -> m (c n e)
unsafeGetBinaryOp op lhs rhs =
    newCopyMatrix lhs >>= \result ->
    op result rhs >> return result