module Data.Matrix.Dense.Internal (
Matrix(..),
module BLAS.Tensor.Base,
module BLAS.Matrix.Base,
coerceMatrix,
matrix,
listMatrix,
rowsMatrix,
colsMatrix,
rowMatrix,
colMatrix,
unsafeMatrix,
module BLAS.Tensor.Immutable,
zeroMatrix,
constantMatrix,
identityMatrix,
submatrix,
splitRowsAt,
splitColsAt,
unsafeSubmatrix,
unsafeSplitRowsAt,
unsafeSplitColsAt,
diag,
unsafeDiag,
ldaOfMatrix,
isHermMatrix,
) where
import Data.AEq
import System.IO.Unsafe
import BLAS.Elem ( Elem, BLAS1 )
import BLAS.Internal ( inlinePerformIO )
import BLAS.UnsafeIOToM
import BLAS.Tensor.Base
import BLAS.Tensor.Immutable
import BLAS.Tensor
import BLAS.Matrix.Base hiding ( BaseMatrix )
import qualified BLAS.Matrix.Base as BLAS
import Data.Matrix.Dense.Class.Creating
import Data.Matrix.Dense.Class.Special
import Data.Matrix.Dense.Class.Views( submatrixView, unsafeSubmatrixView,
splitRowsAt, splitColsAt, unsafeSplitRowsAt, unsafeSplitColsAt,
diagView, unsafeDiagView )
import Data.Matrix.Dense.Class.Internal( coerceMatrix, isHermMatrix,
ldaOfMatrix, colViews, BaseMatrix(..), IOMatrix, maybeFromRow,
maybeFromCol, newCopyMatrix, ReadMatrix )
import Data.Matrix.Dense.Class.Operations
import Data.Vector.Dense.Class.Internal
import Data.Vector.Dense
-- | Immutable dense matrix: a newtype wrapper around the mutable 'IOMatrix'.
-- @mn@ is the phantom shape index (used as @(m,n)@ throughout this module)
-- and @e@ is the element type. Values in this module are produced with
-- 'unsafeFreezeIOMatrix', which wraps the storage without copying.
newtype Matrix mn e = M (IOMatrix mn e)
-- | Wrap a mutable matrix as an immutable 'Matrix' without copying.
-- The storage is shared, so the argument presumably must not be written
-- to afterwards.
unsafeFreezeIOMatrix :: IOMatrix mn e -> Matrix mn e
unsafeFreezeIOMatrix a = M a

-- | Unwrap an immutable 'Matrix' to its backing 'IOMatrix' without
-- copying. Writing through the result would be visible to every holder
-- of the original 'Matrix'.
unsafeThawIOMatrix :: Matrix mn e -> IOMatrix mn e
unsafeThawIOMatrix m = case m of
    M a -> a

-- | Apply a function of a mutable matrix to the storage behind a 'Matrix'.
liftMatrix :: (IOMatrix n e -> a) -> Matrix n e -> a
liftMatrix f = f . unsafeThawIOMatrix

-- | Two-argument version of 'liftMatrix'.
liftMatrix2 ::
    (IOMatrix n e -> IOMatrix n e -> a) ->
        Matrix n e -> Matrix n e -> a
liftMatrix2 f (M x) (M y) = f x y

-- | Run an 'IO' action on the storage behind a 'Matrix' under
-- 'unsafePerformIO' and return its result as a pure value. The action
-- is presumably expected to be read-only.
unsafeLiftMatrix :: (IOMatrix n e -> IO a) -> Matrix n e -> a
unsafeLiftMatrix f x = unsafePerformIO (liftMatrix f x)

-- | Two-argument version of 'unsafeLiftMatrix'.
unsafeLiftMatrix2 ::
    (IOMatrix n e -> IOMatrix n e -> IO a) ->
        Matrix n e -> Matrix n e -> a
unsafeLiftMatrix2 f x y = unsafePerformIO (liftMatrix2 f x y)

-- | Like 'unsafeLiftMatrix', but runs the action with 'inlinePerformIO'.
-- Only suitable for actions that merely read the storage.
inlineLiftMatrix :: (IOMatrix n e -> IO a) -> Matrix n e -> a
inlineLiftMatrix f x = inlinePerformIO (liftMatrix f x)
-- | Build a matrix of the given shape from @((row,col), value)@ pairs
-- using 'newMatrix', then freeze it.
matrix :: (BLAS1 e) => (Int,Int) -> [((Int,Int), e)] -> Matrix (m,n) e
matrix mn ies =
    unsafePerformIO (fmap unsafeFreezeIOMatrix (newMatrix mn ies))

-- | Build a matrix of the given shape from a flat list of elements using
-- 'newListMatrix', then freeze it.
listMatrix :: (BLAS1 e) => (Int,Int) -> [e] -> Matrix (m,n) e
listMatrix mn es =
    unsafePerformIO (fmap unsafeFreezeIOMatrix (newListMatrix mn es))

-- | Like 'matrix', but built with 'unsafeNewMatrix' (the unchecked
-- variant).
unsafeMatrix :: (BLAS1 e) => (Int,Int) -> [((Int,Int), e)] -> Matrix (m,n) e
unsafeMatrix mn ies =
    unsafePerformIO (fmap unsafeFreezeIOMatrix (unsafeNewMatrix mn ies))

-- | Build a matrix of the given shape from a list of row vectors using
-- 'newRowsMatrix'.
rowsMatrix :: (BLAS1 e) => (Int,Int) -> [Vector n e] -> Matrix (m,n) e
rowsMatrix mn rs =
    unsafePerformIO (fmap unsafeFreezeIOMatrix (newRowsMatrix mn rs))

-- | Build a matrix of the given shape from a list of column vectors using
-- 'newColsMatrix'.
colsMatrix :: (BLAS1 e) => (Int,Int) -> [Vector m e] -> Matrix (m,n) e
colsMatrix mn cs =
    unsafePerformIO (fmap unsafeFreezeIOMatrix (newColsMatrix mn cs))
-- | View a vector as a one-row matrix. If 'maybeFromRow' can reinterpret
-- the vector's storage directly, that result is used. Otherwise a new
-- matrix is built with 'newRowMatrix'.
rowMatrix :: (BLAS1 e) => Vector n e -> Matrix (one,n) e
rowMatrix x =
    unsafeFreezeIOMatrix $
        maybe (unsafePerformIO $ newRowMatrix x) id
              (maybeFromRow $ asIOVector x)
  where
    asIOVector :: Vector n e -> IOVector n e
    asIOVector = unsafeThawVector

-- | View a vector as a one-column matrix. If 'maybeFromCol' can
-- reinterpret the vector's storage directly, that result is used.
-- Otherwise a new matrix is built with 'newColMatrix'.
colMatrix :: (BLAS1 e) => Vector m e -> Matrix (m,one) e
colMatrix x =
    unsafeFreezeIOMatrix $
        maybe (unsafePerformIO $ newColMatrix x) id
              (maybeFromCol $ asIOVector x)
  where
    asIOVector :: Vector n e -> IOVector n e
    asIOVector = unsafeThawVector
-- | Matrix of the given shape created by 'newZeroMatrix'.
zeroMatrix :: (BLAS1 e) => (Int,Int) -> Matrix (m,n) e
zeroMatrix mn =
    unsafePerformIO (fmap unsafeFreezeIOMatrix (newZeroMatrix mn))

-- | Matrix of the given shape created by 'newConstantMatrix' with the
-- given value.
constantMatrix :: (BLAS1 e) => (Int,Int) -> e -> Matrix (m,n) e
constantMatrix mn e =
    unsafePerformIO (fmap unsafeFreezeIOMatrix (newConstantMatrix mn e))

-- | Matrix of the given shape created by 'newIdentityMatrix'.
identityMatrix :: (BLAS1 e) => (Int,Int) -> Matrix (m,n) e
identityMatrix mn =
    unsafePerformIO (fmap unsafeFreezeIOMatrix (newIdentityMatrix mn))
-- | Submatrix of a matrix selected by two @(Int,Int)@ arguments
-- (presumably offset and shape; confirm against 'submatrixView').
-- Delegates to 'submatrixView'.
submatrix :: (Elem e) => Matrix mn e -> (Int,Int) -> (Int,Int) -> Matrix mn' e
submatrix a ij mn = submatrixView a ij mn

-- | Unchecked counterpart of 'submatrix'; delegates to
-- 'unsafeSubmatrixView'.
unsafeSubmatrix :: (Elem e) => Matrix mn e -> (Int,Int) -> (Int,Int) -> Matrix mn' e
unsafeSubmatrix a ij mn = unsafeSubmatrixView a ij mn

-- | Diagonal of a matrix selected by an 'Int' (presumably the diagonal
-- offset; confirm). Delegates to 'diagView'.
diag :: (Elem e) => Matrix mn e -> Int -> Vector k e
diag a k = diagView a k

-- | Unchecked counterpart of 'diag'; delegates to 'unsafeDiagView'.
unsafeDiag :: (Elem e) => Matrix mn e -> Int -> Vector k e
unsafeDiag a k = unsafeDiagView a k
-- Shape and bounds are read straight from the backing 'IOMatrix'.
instance BaseTensor Matrix (Int,Int) where
    shape  (M a) = shape a
    bounds (M a) = bounds a
-- Immutable tensor operations for 'Matrix'. Read-only queries run the
-- corresponding 'IOMatrix' reader through 'inlineLiftMatrix'. Replacement,
-- mapping, scaling and shifting each produce a new frozen matrix.
instance ITensor Matrix (Int,Int) where
    -- (//) writes with 'writeElem'; 'unsafeReplace' with 'unsafeWriteElem'.
    (//)          = replaceHelp writeElem
    unsafeReplace = replaceHelp unsafeWriteElem

    unsafeAt a ij = inlineLiftMatrix (`unsafeReadElem` ij) a
    size    a     = inlineLiftMatrix getSize a
    elems   a     = inlineLiftMatrix getElems a
    indices a     = inlineLiftMatrix getIndices a
    assocs  a     = inlineLiftMatrix getAssocs a

    -- For a matrix flagged Hermitian, build the transposed (n,m) matrix
    -- from conjugated results and apply 'herm' to return to shape (m,n).
    -- Otherwise build an (m,n) matrix directly from the mapped elements.
    tmap f a
        | isHermMatrix a =
            coerceMatrix $ herm $ listMatrix (n,m) (map (conj . f) es)
        | otherwise =
            coerceMatrix $ listMatrix (m,n) (map f es)
      where
        (m,n) = shape a
        es    = elems a

    k *> x    = unsafeFreezeIOMatrix (unsafeLiftMatrix (getScaledMatrix k) x)
    shift k x = unsafeFreezeIOMatrix (unsafeLiftMatrix (getShiftedMatrix k) x)
-- | Copy a matrix with 'newCopyMatrix', apply each @((row,col), value)@
-- write to the copy using the supplied writer, and return the frozen copy.
-- Writes go to the copy only.
replaceHelp :: (BLAS1 e) =>
    (IOMatrix mn e -> (Int,Int) -> e -> IO ()) ->
        Matrix mn e -> [((Int,Int), e)] -> Matrix mn e
replaceHelp set x ies = unsafePerformIO $
    newCopyMatrix (unsafeThawIOMatrix x) >>= \y ->
        mapM_ (\(ij, v) -> set y ij v) ies >> return (unsafeFreezeIOMatrix y)
-- Monadic read interface for immutable matrices: each query returns the
-- pure 'ITensor' result in any monad.
instance (Monad m) => ReadTensor Matrix (Int,Int) m where
    getSize    a = return (size a)
    getAssocs  a = return (assocs a)
    getIndices a = return (indices a)
    getElems   a = return (elems a)

    -- The primed variants are defined to be the unprimed ones.
    getAssocs'  = getAssocs
    getIndices' = getIndices
    getElems'   = getElems

    unsafeReadElem a i = return (unsafeAt a i)
-- 'herm' on an immutable matrix applies 'herm' to the backing storage.
instance BLAS.BaseMatrix Matrix where
    herm = unsafeFreezeIOMatrix . herm . unsafeThawIOMatrix

-- Array-level view construction and access are delegated to 'IOMatrix'.
instance BaseMatrix Matrix Vector where
    matrixViewArray f p m n l h =
        unsafeFreezeIOMatrix (matrixViewArray f p m n l h)
    arrayFromMatrix a = arrayFromMatrix (unsafeThawIOMatrix a)

-- No methods are overridden; the class defaults apply.
instance (UnsafeIOToM m) => ReadMatrix Matrix Vector m
-- Arithmetic on dense matrices. (+), (-) and (*) delegate to
-- getAddMatrix, getSubMatrix and getMulMatrix on the backing storage
-- (presumably elementwise; any shape checking happens there, so confirm).
-- 'negate' scales every entry by -1, 'abs' and 'signum' map over the
-- entries, and 'fromInteger' yields a 1x1 constant matrix.
instance (BLAS1 e) => Num (Matrix mn e) where
    (+) x y = unsafeFreezeIOMatrix $ unsafeLiftMatrix2 getAddMatrix x y
    (-) x y = unsafeFreezeIOMatrix $ unsafeLiftMatrix2 getSubMatrix x y
    (*) x y = unsafeFreezeIOMatrix $ unsafeLiftMatrix2 getMulMatrix x y
    -- Scaling by -1 (not +1) is what makes this a negation.
    negate = ((-1) *>)
    abs = tmap abs
    signum = tmap signum
    fromInteger = coerceMatrix . constantMatrix (1,1) . fromInteger
-- Division delegates to getDivMatrix on the backing storage, 'recip' maps
-- over the entries, and 'fromRational' yields a 1x1 constant matrix.
instance (BLAS1 e) => Fractional (Matrix mn e) where
    x / y          = unsafeFreezeIOMatrix (unsafeLiftMatrix2 getDivMatrix x y)
    recip a        = tmap recip a
    fromRational q = coerceMatrix (constantMatrix (1,1) (fromRational q))
-- Transcendental functions applied entry by entry with 'tmap'. 'pi' is a
-- 1x1 constant matrix, and (**) combines two matrices of equal shape with
-- 'tzipWith', which calls 'error' on a shape mismatch.
instance (BLAS1 e, Floating e) => Floating (Matrix (m,n) e) where
    pi      = constantMatrix (1,1) pi
    x ** y  = tzipWith (**) x y

    exp   a = tmap exp a
    log   a = tmap log a
    sqrt  a = tmap sqrt a

    sin   a = tmap sin a
    cos   a = tmap cos a
    tan   a = tmap tan a
    asin  a = tmap asin a
    acos  a = tmap acos a
    atan  a = tmap atan a

    sinh  a = tmap sinh a
    cosh  a = tmap cosh a
    tanh  a = tmap tanh a
    asinh a = tmap asinh a
    acosh a = tmap acosh a
    atanh a = tmap atanh a
-- | Combine two matrices of the same shape entry by entry.
--
-- Both inputs are read column by column through 'colViews', so a matrix
-- flagged Hermitian and a plain one line up. The combined entries are
-- packed back with 'listMatrix' using the common shape. Calls 'error' if
-- the shapes differ.
tzipWith :: (BLAS1 e) =>
    (e -> e -> e) -> Matrix mn e -> Matrix mn e -> Matrix mn e
tzipWith f a b
    | sa == sb =
        coerceMatrix (listMatrix sa (zipWith f (byColumns a) (byColumns b)))
    | otherwise =
        error $ "tzipWith: matrix shapes differ; first has shape `"
             ++ show sa ++ "' and second has shape `"
             ++ show sb ++ "'"
  where
    sa = shape a
    sb = shape b
    byColumns c = concatMap elems (colViews (coerceMatrix c))
-- A plain matrix is rendered as @listMatrix <shape> <elems>@. A matrix
-- flagged Hermitian is rendered as @herm (...)@ around the rendering of
-- its 'herm'-ed form.
instance (BLAS1 e, Show e) => Show (Matrix mn e) where
    show a
        | not (isHermMatrix a) =
            concat ["listMatrix ", show (shape a), " ", show (elems a)]
        | otherwise =
            "herm (" ++ show (herm (coerceMatrix a)) ++ ")"
-- | Compare two matrices entry by entry with the given predicate.
--
-- Matrices of different shape compare as 'False'.
--
-- * Both flagged Hermitian: the entries of their 'herm'-ed forms are compared.
-- * Neither flagged: the entries of the matrices themselves are compared.
-- * Flags differ: both are read column by column through 'colViews' so
--   the two layouts line up.
compareHelp :: (BLAS1 e) =>
    (e -> e -> Bool) -> Matrix mn e -> Matrix mn e -> Bool
compareHelp cmp a b
    | shape a /= shape b      = False
    | hermA /= isHermMatrix b = pairwise (byColumns a) (byColumns b)
    | hermA                   = pairwise (unHerm a) (unHerm b)
    | otherwise               = pairwise (elems a) (elems b)
  where
    hermA          = isHermMatrix a
    pairwise xs ys = and (zipWith cmp xs ys)
    byColumns c    = concatMap elems (colViews $ coerceMatrix c)
    unHerm c       = elems (herm (coerceMatrix c))
-- Equality: same shape and pairwise-equal entries, as decided by
-- 'compareHelp'.
instance (BLAS1 e, Eq e) => Eq (Matrix mn e) where
    a == b = compareHelp (==) a b

-- Exact and approximate equality, entry by entry, via 'compareHelp'.
instance (BLAS1 e, AEq e) => AEq (Matrix mn e) where
    a === b = compareHelp (===) a b
    a ~== b = compareHelp (~==) a b