module Data.Matrix.Banded.Class.Internal (
IOBanded,
STBanded,
unsafeIOBandedToSTBanded,
unsafeSTBandedToIOBanded,
BaseBanded(..),
ReadBanded,
WriteBanded,
bandedViewMatrix,
matrixFromBanded,
ldaOfBanded,
isHermBanded,
hermBanded,
bandwidth,
numLower,
numUpper,
coerceBanded,
newBanded_,
newZeroBanded,
setZeroBanded,
newConstantBanded,
setConstantBanded,
modifyWithBanded,
canModifyElemBanded,
unsafeWriteElemBanded,
unsafeRowViewBanded,
unsafeColViewBanded,
unsafeGetRowBanded,
unsafeGetColBanded,
shapeBanded,
boundsBanded,
withBandedPtr,
withBandedElemPtr,
indexOfBanded,
indicesBanded,
gbmv,
gbmm,
) where
import Control.Monad
import Control.Monad.ST
import Data.Ix
import Data.List( foldl' )
import Foreign
import Unsafe.Coerce
import BLAS.Elem
import BLAS.C.Types
import qualified BLAS.C.Level2 as BLAS
import BLAS.Internal( diagLen )
import BLAS.UnsafeIOToM
import BLAS.Matrix.Base hiding ( BaseMatrix )
import qualified BLAS.Matrix.Base as BLAS
import BLAS.Matrix.Mutable
import BLAS.Tensor
import Data.Vector.Dense.Class.Internal( IOVector, STVector,
BaseVector(..), ReadVector, WriteVector, doConjVector,
withVectorPtr, stride, isConj )
import Data.Vector.Dense.Class.Creating( newListVector )
import Data.Vector.Dense.Class.Operations( getConjVector )
import Data.Matrix.Dense.Class( BaseMatrix, ReadMatrix, WriteMatrix,
isHermMatrix, arrayFromMatrix, matrixViewArray, colViews )
-- | Types that can present a raw foreign array as a banded matrix.  The
-- functional dependency @a -> x@ fixes the dense vector type @x@ used for
-- row and column views of the matrix type @a@.
class (BLAS.BaseMatrix a, BaseVector x) =>
    BaseBanded a x | a -> x where
    -- | Build a banded-matrix view from its storage descriptor: owning
    -- foreign pointer, base pointer, stored row count @m@, stored column
    -- count @n@, lower bandwidth @kl@, upper bandwidth @ku@, leading
    -- dimension @ld@, and the Hermitian (conjugate-transpose) flag.
    bandedViewArray :: ForeignPtr e -> Ptr e -> Int -> Int -> Int -> Int -> Int -> Bool -> a mn e
    -- | Recover the storage descriptor, in the same order that
    -- 'bandedViewArray' takes it.
    arrayFromBanded :: a mn e -> (ForeignPtr e, Ptr e, Int, Int, Int, Int, Int, Bool)
-- | Banded matrices whose elements can be read in monad @m@.  The class
-- has no methods of its own; it bundles the element-reading
-- ('ReadTensor') and row/column-vector-reading ('ReadVector')
-- constraints.
class (UnsafeIOToM m, ReadTensor a (Int,Int) m,
        BaseBanded a x, ReadVector x m) =>
    ReadBanded a x m | a -> x where
-- | Banded matrices whose elements can be written in monad @m@.  The
-- class has no methods of its own.  The dependencies @a -> m@ and
-- @m -> a@ pair each mutable matrix type with exactly one monad.
class (WriteTensor a (Int,Int) m,
        WriteVector x m, ReadBanded a x m) =>
    WriteBanded a x m | a -> m, m -> a, a -> x where
-- | Run an 'IO' action on the base pointer of a banded matrix's storage.
-- The owning 'ForeignPtr' is touched after the action completes, which
-- keeps the storage alive while the raw pointer is in use.
withBandedPtr :: (BaseBanded a x, Storable e) =>
    a mn e -> (Ptr e -> IO b) -> IO b
withBandedPtr a f = do
    result <- f basePtr
    touchForeignPtr owner
    return result
  where
    (owner, basePtr, _, _, _, _, _, _) = arrayFromBanded a
-- | Stored row count @m@.  Ignores the Hermitian flag; see 'shapeBanded'
-- for the logical shape.
size1 :: (BaseBanded a x) => a mn e -> Int
size1 a = case arrayFromBanded a of (_,_,rows,_,_,_,_,_) -> rows

-- | Stored column count @n@.  Ignores the Hermitian flag.
size2 :: (BaseBanded a x) => a mn e -> Int
size2 a = case arrayFromBanded a of (_,_,_,cols,_,_,_,_) -> cols

-- | Stored lower bandwidth @kl@.  Ignores the Hermitian flag; see
-- 'numLower'.
lowBW :: (BaseBanded a x) => a mn e -> Int
lowBW a = case arrayFromBanded a of (_,_,_,_,lower,_,_,_) -> lower

-- | Stored upper bandwidth @ku@.  Ignores the Hermitian flag; see
-- 'numUpper'.
upBW :: (BaseBanded a x) => a mn e -> Int
upBW a = case arrayFromBanded a of (_,_,_,_,_,upper,_,_) -> upper

-- | Leading dimension of the band storage (the stride between stored
-- columns).
ldaOfBanded :: (BaseBanded a x) => a mn e -> Int
ldaOfBanded a = case arrayFromBanded a of (_,_,_,_,_,_,lda,_) -> lda

-- | Whether this is a Hermitian (conjugate-transpose) view of its storage.
isHermBanded :: (BaseBanded a x) => a mn e -> Bool
isHermBanded a = case arrayFromBanded a of (_,_,_,_,_,_,_,herm) -> herm
-- | Expose the band storage of a banded matrix as a dense matrix view.
--
-- Returns the stored shape @(m,n)@, the stored bandwidths @(kl,ku)@, a
-- non-Hermitian dense view of the @(kl+1+ku) x n@ storage array (sharing
-- memory, with the banded matrix's leading dimension), and the banded
-- matrix's Hermitian flag.
matrixFromBanded :: (BaseBanded b x, BaseMatrix a x) =>
    b mn e -> ((Int,Int), (Int,Int), a mn' e, Bool)
matrixFromBanded b = ((m,n), (kl,ku), storageView, h)
  where
    (f,p,m,n,kl,ku,ld,h) = arrayFromBanded b
    storageView = matrixViewArray f p (kl+1+ku) n ld False
-- | View a dense matrix as the band storage of an @m x n@ banded matrix
-- with bandwidths @(kl,ku)@ and Hermitian flag @h@.  The two share memory.
--
-- Returns 'Nothing' if the dense matrix is itself a Hermitian view.
-- Calls 'error' if its row count is not @kl+1+ku@ or its column count is
-- not @n@.
bandedViewMatrix :: (BaseMatrix a x, BaseBanded b x) =>
    (Int,Int) -> (Int,Int) -> a mn e -> Bool -> Maybe (b mn' e)
bandedViewMatrix (m,n) (kl,ku) a h
    | isHermMatrix a =
        Nothing
    | rows /= kl+1+ku =
        error $ "bandedViewMatrix:"
             ++ " number of rows must be equal to number of diagonals"
    | cols /= n =
        error $ "bandedViewMatrix:"
             ++ " numbers of columns must be equal"
    | otherwise =
        Just $ bandedViewArray f p m n kl ku ld h
  where
    -- Only forced once the Hermitian check has passed.
    (f,p,rows,cols,ld,_) = arrayFromMatrix a
-- | Band limits of the logical matrix, as
-- @(negate (numLower a), numUpper a)@.
bandwidth :: (BaseBanded a x) => a mn e -> (Int,Int)
bandwidth a = (negate (numLower a), numUpper a)

-- | Lower bandwidth of the logical matrix.  For a Hermitian view this is
-- the stored upper bandwidth.
numLower :: (BaseBanded a x) => a mn e -> Int
numLower a = if isHermBanded a then upBW a else lowBW a

-- | Upper bandwidth of the logical matrix.  For a Hermitian view this is
-- the stored lower bandwidth.
numUpper :: (BaseBanded a x) => a mn e -> Int
numUpper a = if isHermBanded a then lowBW a else upBW a
-- | Change the shape type index of a banded matrix without touching its
-- storage or descriptor.
--
-- Safety: 'IOBanded' never mentions @mn@ in its constructor fields, and
-- 'STBanded' is a newtype over 'IOBanded', so for both instances @mn@ is
-- a phantom parameter and the coercion does not change the runtime
-- representation.  Any other 'BaseBanded' instance must likewise treat
-- @mn@ as phantom for this 'unsafeCoerce' to be sound.
coerceBanded :: (BaseBanded a x) => a mn e -> a mn' e
coerceBanded = unsafeCoerce
-- | Logical @(rows, columns)@ of the matrix.  For a Hermitian view the
-- stored dimensions are swapped.
shapeBanded :: (BaseBanded a x) => a mn e -> (Int,Int)
shapeBanded a = if isHermBanded a then (cols, rows) else (rows, cols)
  where
    rows = size1 a
    cols = size2 a
-- | Index bounds @((0,0), (m-1,n-1))@ of the logical matrix, where
-- @(m,n)@ is 'shapeBanded'.  When a dimension is zero the upper bound is
-- @-1@, so the index range is empty.
boundsBanded :: (BaseBanded a x) => a mn e -> ((Int,Int), (Int,Int))
boundsBanded a = ((0,0), (m - 1, n - 1))
  where
    (m,n) = shapeBanded a
-- | Conjugate-transpose view of a banded matrix.  It shares storage and
-- descriptor with the original; only the Hermitian flag is flipped.
hermBanded :: (BaseBanded a x) => a (m,n) e -> a (n,m) e
hermBanded a = bandedViewArray f p rows cols lower upper ld (not herm)
  where
    (f,p,rows,cols,lower,upper,ld,herm) = arrayFromBanded a
-- | Number of stored elements, as a monadic action.
getSizeBanded :: (ReadBanded a x m) => a mn e -> m Int
getSizeBanded a = return (sizeBanded a)

-- | Indices with storage, as a monadic action (see 'indicesBanded').
getIndicesBanded :: (ReadBanded a x m) => a mn e -> m [(Int,Int)]
getIndicesBanded a = return (indicesBanded a)

-- | Stored element values, read lazily (see 'getAssocsBanded').
getElemsBanded :: (ReadBanded a x m, Elem e) => a mn e -> m [e]
getElemsBanded a = liftM (map snd) (getAssocsBanded a)

-- | @(index, value)@ pairs for every stored element.  The reads are
-- deferred with 'unsafeInterleaveM' until the list is demanded.
getAssocsBanded :: (ReadBanded a x m, Elem e) => a mn e -> m [((Int,Int),e)]
getAssocsBanded a = do
    is <- getIndicesBanded a
    unsafeInterleaveM $ forM is $ \i -> do
        e <- unsafeReadElem a i
        return (i,e)

-- | Strict counterpart of 'getIndicesBanded' (identical, since the indices
-- do not depend on the contents).
getIndicesBanded' :: (ReadBanded a x m) => a mn e -> m [(Int,Int)]
getIndicesBanded' a = getIndicesBanded a

-- | Stored element values, read eagerly.
getElemsBanded' :: (ReadBanded a x m, Elem e) => a mn e -> m [e]
getElemsBanded' a = liftM (map snd) (getAssocsBanded' a)

-- | @(index, value)@ pairs for every stored element, read eagerly.
getAssocsBanded' :: (ReadBanded a x m, Elem e) => a mn e -> m [((Int,Int),e)]
getAssocsBanded' a =
    getIndicesBanded a >>= mapM (\i -> liftM ((,) i) (unsafeReadElem a i))
-- | Read element @(i,j)@ without bounds checking.
--
-- * For a Hermitian view, reads @(j,i)@ of the underlying view and
--   conjugates the result.
-- * Indices outside the stored band read as @0@.
unsafeReadElemBanded :: (ReadBanded a x m, Elem e) => a mn e -> (Int,Int) -> m e
unsafeReadElemBanded a (i,j)
    | isHermBanded a =
        liftM conj $ unsafeReadElemBanded (hermBanded (coerceBanded a)) (j,i)
    | hasStorageBanded a (i,j) =
        unsafeIOToM (withBandedElemPtr a (i,j) peek)
    | otherwise =
        return 0
-- | Allocate an uninitialised @m x n@ banded matrix with lower bandwidth
-- @kl@ and upper bandwidth @ku@.
--
-- The storage is a fresh array of @(kl+1+ku) * n@ elements with leading
-- dimension @kl+1+ku@, and the Hermitian flag is cleared.
--
-- Fails (via 'fail') when:
--
-- * a dimension or bandwidth is negative;
-- * @kl >= m@ with @m /= 0@;
-- * @ku >= n@ with @n /= 0@.
newBanded_ :: (WriteBanded a x m, Elem e) => (Int,Int) -> (Int,Int) -> m (a mn e)
newBanded_ (m,n) (kl,ku)
    | m < 0 || n < 0 =
        err "dimensions must be non-negative."
    | kl < 0 =
        err "lower bandwidth must be non-negative."
    | m /= 0 && kl >= m =
        err "lower bandwidth must be less than m."
    | ku < 0 =
        err "upper bandwidth must be non-negative."
    | n /= 0 && ku >= n =
        err "upper bandwidth must be less than n."
    | otherwise =
        let m' = kl + 1 + ku   -- one storage row per stored diagonal
            l  = m'
            h  = False
        in unsafeIOToM $ do
            fp <- mallocForeignPtrArray (m' * n)
            -- fp is stored in the result alongside the raw pointer;
            -- withBandedPtr touches it so the allocation stays alive.
            let p = unsafeForeignPtrToPtr fp
            return $ bandedViewArray fp p m n kl ku l h
  where
    err s = fail $ "newBanded_ " ++ show (m,n) ++ " " ++ show (kl,ku) ++ ": " ++ s
-- | Allocate a banded matrix (see 'newBanded_') with every stored element
-- set to zero.
newZeroBanded :: (WriteBanded a x m, Elem e) => (Int,Int) -> (Int,Int) -> m (a mn e)
newZeroBanded mn bw =
    newBanded_ mn bw >>= \a -> setZeroBanded a >> return a

-- | Allocate a banded matrix (see 'newBanded_') with every stored element
-- set to @e@.
newConstantBanded :: (WriteBanded a x m, Elem e) => (Int,Int) -> (Int,Int) -> e -> m (a mn e)
newConstantBanded mn bw e =
    newBanded_ mn bw >>= \a -> setConstantBanded e a >> return a

-- | Set every stored element to zero.
setZeroBanded :: (WriteBanded a x m, Elem e) => a mn e -> m ()
setZeroBanded a = setConstantBanded 0 a
-- | Set every stored element to @e@.  For a Hermitian view, the
-- underlying view is filled with @conj e@ instead.
setConstantBanded :: (WriteBanded a x m, Elem e) => e -> a mn e -> m ()
setConstantBanded e a =
    if isHermBanded a
        then setConstantBanded (conj e) (hermBanded (coerceBanded a))
        else getIndicesBanded a >>= mapM_ (\ij -> unsafeWriteElemBanded a ij e)
-- | Write @e@ to element @(i,j)@ without checking bounds or band
-- membership.  For a Hermitian view, @conj e@ is written to @(j,i)@ of
-- the underlying view.
unsafeWriteElemBanded :: (WriteBanded a x m, Elem e) =>
    a mn e -> (Int,Int) -> e -> m ()
unsafeWriteElemBanded a (i,j) e
    | isHermBanded a =
        unsafeWriteElemBanded (hermBanded (coerceBanded a)) (j,i) (conj e)
    | otherwise =
        unsafeIOToM $ withBandedElemPtr a (i,j) $ \ptr -> poke ptr e
-- | Replace every stored element @v@ with @f v@, in place.
--
-- Each element is read and then immediately rewritten.  Reading eagerly
-- (rather than through the 'unsafeInterleaveM'-deferred
-- 'getAssocsBanded') keeps the order of reads and writes to the shared
-- storage explicit, instead of depending on when a lazy list is forced.
-- Each stored index maps to a distinct storage slot, so writing one
-- element never affects a later read.
modifyWithBanded :: (WriteBanded a x m, Elem e) => (e -> e) -> a mn e -> m ()
modifyWithBanded f a = do
    is <- getIndicesBanded a
    mapM_ (\ij -> unsafeReadElem a ij >>= unsafeWriteElemBanded a ij . f) is
-- | An element can be modified only if it lies inside the stored band
-- (see 'hasStorageBanded').
canModifyElemBanded :: (WriteBanded a x m) => a mn e -> (Int,Int) -> m Bool
canModifyElemBanded a = return . hasStorageBanded a
-- | View row @i@ of a banded matrix as a strided vector over its stored
-- band, together with the number of implicit zeros around it.
--
-- Returns @(nb, v, na)@.  For an ordinary view:
--
-- * @nb@ zeros precede the band;
-- * @v@ views the @n - (nb + na)@ stored entries;
-- * @na@ zeros follow the band.
--
-- If @nb + na@ would exceed @n@, the result is @(n, <empty view>, 0)@.
-- For a Hermitian view, the row is the conjugate of the corresponding
-- column of the underlying storage.  The row index is not
-- bounds-checked.
unsafeRowViewBanded :: (BaseBanded a x, Storable e) =>
    a mn e -> Int -> (Int, x k e, Int)
unsafeRowViewBanded a i =
    if h then
        case unsafeColViewBanded a' i of (nb, v, na) -> (nb, conj v, na)
    else
        let nb  = max (i - kl) 0            -- zeros left of the band
            na  = max (n - 1 - i - ku) 0    -- zeros right of the band
            -- Entry (i,j) lives at storage row ku + i - j of storage
            -- column j.  The first stored entry of row i is in column c.
            r   = min (ku + i) (kl + ku)
            c   = max (i - kl) 0
            p'  = p `advancePtr` (r + c * ld)
            -- One column to the right: +ld for the column, -1 for the row.
            inc = ld - 1
            len = n - (nb + na)
        in if len >= 0
            then (nb, vectorViewArray f p' len inc False, na)
            else (n , vectorViewArray f p' 0   inc False, 0)
  where
    (f,p,_,n,kl,ku,ld,h) = arrayFromBanded a
    a' = (hermBanded . coerceBanded) a
-- | View column @j@ of a banded matrix as a contiguous vector over its
-- stored band, together with the number of implicit zeros around it.
--
-- Returns @(nb, v, na)@.  For an ordinary view:
--
-- * @nb@ zeros lie above the band;
-- * @v@ views the @m - (nb + na)@ stored entries;
-- * @na@ zeros lie below the band.
--
-- If @nb + na@ would exceed @m@, the result is @(m, <empty view>, 0)@.
-- For a Hermitian view, the column is the conjugate of the corresponding
-- row of the underlying storage.  The column index is not
-- bounds-checked.
unsafeColViewBanded :: (BaseBanded a x, Storable e) =>
    a mn e -> Int -> (Int, x k e, Int)
unsafeColViewBanded a j =
    if h then
        case unsafeRowViewBanded a' j of (nb, v, na) -> (nb, conj v, na)
    else
        let nb  = max (j - ku) 0            -- zeros above the band
            na  = max (m - 1 - j - kl) 0    -- zeros below the band
            -- Storage row of entry (nb, j): ku + nb - j.
            r   = max (ku - j) 0
            c   = j
            p'  = p `advancePtr` (r + c * ld)
            inc = 1                         -- a stored column is contiguous
            len = m - (nb + na)
        in if len >= 0
            then (nb, vectorViewArray f p' len inc False, na)
            else (m , vectorViewArray f p' 0   inc False, 0)
  where
    (f,p,m,_,kl,ku,ld,h) = arrayFromBanded a
    a' = (hermBanded . coerceBanded) a
-- | Copy row @i@ into a freshly allocated dense vector of length
-- @numCols a@.  Positions outside the stored band are filled with zeros.
-- The row index is not bounds-checked.
unsafeGetRowBanded :: (ReadBanded a x m, WriteVector y m, Elem e) =>
    a (k,l) e -> Int -> m (y l e)
unsafeGetRowBanded a i = do
    stored <- getElems band
    newListVector (numCols a) (zeros before ++ stored ++ zeros after)
  where
    (before, band, after) = unsafeRowViewBanded a i
    zeros count = replicate count 0
-- | Copy column @j@ into a freshly allocated dense vector.  Row @j@ of
-- the conjugate transpose is copied, and the copy is conjugated back.
unsafeGetColBanded :: (ReadBanded a x m, WriteVector y m, Elem e) =>
    a (k,l) e -> Int -> m (y k e)
unsafeGetColBanded a j = liftM conj (unsafeGetRowBanded (hermBanded a) j)
-- | Banded matrix-vector multiply-add, @y := alpha * a * x + beta * y@,
-- performed by the BLAS @gbmv@ routine.
--
-- * If @a@ has no rows or no columns, @y@ is only scaled by @beta@.
-- * If @x@ is a conjugated view, a conjugated copy is materialised first.
-- * If @y@ is a conjugated view, its storage is conjugated in place, the
--   product is accumulated through the unconjugated view, and the storage
--   is then conjugated back.
gbmv :: (ReadBanded a z m, ReadVector x m, WriteVector y m, BLAS2 e) =>
    e -> a (k,l) e -> x l e -> e -> y k e -> m ()
gbmv alpha a x beta y
    | numRows a == 0 || numCols a == 0 =
        scaleBy beta y
    | isConj x =
        getConjVector (conj x) >>= \x' -> gbmv alpha a x' beta y
    | isConj y = do
        doConjVector y
        gbmv alpha a x beta (conj y)
        doConjVector y
    | otherwise =
        callGbmv
  where
    herm = isHermBanded a
    -- For a Hermitian view these undo the swap made by 'shape' and by
    -- 'numLower'/'numUpper', so they describe the underlying storage.
    -- BLAS then applies that storage with 'conjTrans' (see 'blasTransOf').
    (m,n)   = if herm then flipShape (shape a) else shape a
    (kl,ku) = if herm then (numUpper a, numLower a)
                      else (numLower a, numUpper a)
    callGbmv =
        unsafeIOToM $
            withBandedPtr a $ \pA ->
            withVectorPtr x $ \pX ->
            withVectorPtr y $ \pY ->
                BLAS.gbmv colMajor (blasTransOf a) m n kl ku
                          alpha pA (ldaOfBanded a) pX (stride x)
                          beta pY (stride y)
-- | Banded matrix times dense matrix, computed one column at a time.
-- Each column pair is updated as
-- @c_j := alpha * a * b_j + beta * c_j@ via 'gbmv'.  Pairing stops at the
-- shorter of the two column lists.
gbmm :: (ReadBanded a x m, ReadMatrix b y m, WriteMatrix c z m, BLAS2 e) =>
    e -> a (r,s) e -> b (s,t) e -> e -> c (r,t) e -> m ()
gbmm alpha a b beta c =
    zipWithM_ (\bCol cCol -> gbmv alpha a bCol beta cCol) (colViews b) (colViews c)
-- | Run an 'IO' action on a pointer to element @(i,j)@ of the band
-- storage.  For a Hermitian view, @(j,i)@ of the underlying view is used.
-- Band membership is not checked (see 'hasStorageBanded').
withBandedElemPtr :: (BaseBanded a x, Storable e) =>
    a mn e -> (Int,Int) -> (Ptr e -> IO b) -> IO b
withBandedElemPtr a ij@(i,j) f =
    if isHermBanded a
        then withBandedElemPtr (hermBanded (coerceBanded a)) (j,i) f
        else withBandedPtr a $ \base -> f (base `advancePtr` indexOfBanded a ij)
-- | Offset, in elements from the base pointer, of entry @(i,j)@.
--
-- Entry @(i,j)@ of the stored matrix lives at row @ku + i - j@ of
-- storage column @j@, i.e. at offset @ku + (i - j) + j * ld@.  For a
-- Hermitian view the index is transposed first.  The caller must ensure
-- that @(i,j)@ lies inside the band (see 'hasStorageBanded'); this is not
-- checked.
indexOfBanded :: (BaseBanded a x) => a mn e -> (Int,Int) -> Int
indexOfBanded a (i,j) =
    let (_,_,_,_,_,ku,ld,h) = arrayFromBanded a
        (i',j') = if h then (j,i) else (i,j)
    in ku + (i' - j') + j' * ld
-- | Whether entry @(i,j)@ has a storage slot, i.e. lies inside the band.
-- Entries outside the band are implicit zeros.
--
-- For a Hermitian view the index is transposed first.  In the stored
-- matrix, column @j'@ holds rows @max 0 (j' - ku)@ to
-- @min (m - 1) (j' + kl)@ inclusive.
hasStorageBanded :: (BaseBanded a x) => a mn e -> (Int,Int) -> Bool
hasStorageBanded a (i,j) =
    let (_,_,m,_,kl,ku,_,h) = arrayFromBanded a
        (i',j') = if h then (j,i) else (i,j)
    in inRange (max 0 (j' - ku), min (m - 1) (j' + kl)) i'
-- | Number of stored elements: the sum of 'diagLen' over diagonals
-- @-kl .. ku@ of the stored @m x n@ matrix.  This assumes 'diagLen' uses
-- negative indices for sub-diagonals (TODO confirm against BLAS.Internal).
sizeBanded :: (BaseBanded a x) => a mn e -> Int
sizeBanded a =
    let (_,_,m,n,kl,ku,_,_) = arrayFromBanded a
    in foldl' (+) 0 $ map (diagLen (m,n)) [negate kl .. ku]
-- | All logical indices that have storage, listed in storage order.
--
-- * Ordinary view: column by column.
-- * Hermitian view: row by row, which walks the underlying storage
--   column by column.
--
-- Only the band is enumerated, so the cost is proportional to the number
-- of stored elements rather than to @m * n@.  Entry @(i,j)@ is in the
-- band exactly when @-numLower a <= j - i <= numUpper a@.
indicesBanded :: (BaseBanded a x) => a mn e -> [(Int,Int)]
indicesBanded a
    | isHermBanded a =
        [ (i,j) | i <- [0 .. m - 1]
                , j <- [max 0 (i - kl) .. min (n - 1) (i + ku)] ]
    | otherwise =
        [ (i,j) | j <- [0 .. n - 1]
                , i <- [max 0 (j - ku) .. min (m - 1) (j + kl)] ]
  where
    (m,n) = shapeBanded a
    kl    = numLower a
    ku    = numUpper a
-- | BLAS transpose flag for a banded view: 'conjTrans' for a Hermitian
-- view, 'noTrans' otherwise.
blasTransOf :: (BaseBanded a x) => a mn e -> CBLASTrans
blasTransOf a
    | isHermBanded a = conjTrans
    | otherwise      = noTrans
-- | Swap the two components of a @(rows, columns)@ shape.
flipShape :: (Int,Int) -> (Int,Int)
flipShape shp = case shp of (rows, cols) -> (cols, rows)
-- | Banded matrix usable in 'IO'.  The @mn@ type parameter is a phantom
-- shape index (no field mentions it) and @e@ is the element type.  The
-- fields form the storage descriptor used by 'bandedViewArray' and
-- 'arrayFromBanded'.
data IOBanded mn e =
    BM !(ForeignPtr e)   -- ^ owner of the storage (touched by 'withBandedPtr')
       !(Ptr e)          -- ^ base pointer of the band storage
       !Int              -- ^ stored row count @m@
       !Int              -- ^ stored column count @n@
       !Int              -- ^ stored lower bandwidth @kl@
       !Int              -- ^ stored upper bandwidth @ku@
       !Int              -- ^ leading dimension @ld@ (stride between stored columns)
       !Bool             -- ^ Hermitian (conjugate-transpose) view flag
-- | Banded matrix usable in the 'ST' monad with state thread @s@.  It is a
-- newtype over 'IOBanded', so both share one runtime representation.
newtype STBanded s mn e = ST (IOBanded mn e)
-- | Reinterpret an 'IOBanded' as an 'STBanded'.  No copy is made; the
-- caller must ensure the storage is not also mutated from 'IO'
-- concurrently with the 'ST' computation.
unsafeIOBandedToSTBanded :: IOBanded mn e -> STBanded s mn e
unsafeIOBandedToSTBanded a = ST a

-- | Reinterpret an 'STBanded' as an 'IOBanded'.  No copy is made, so the
-- result aliases the 'ST' matrix's storage.
unsafeSTBandedToIOBanded :: STBanded s mn e -> IOBanded mn e
unsafeSTBandedToIOBanded st = case st of ST a -> a
-- | 'IOBanded' holds the storage descriptor directly in 'BM'.  Its rows
-- and columns are viewed as 'IOVector's.
instance BaseBanded IOBanded IOVector where
    bandedViewArray f p m n kl ku ld h = BM f p m n kl ku ld h
    arrayFromBanded (BM f p m n kl ku ld h) = (f,p,m,n,kl,ku,ld,h)

-- | 'STBanded' wraps the same 'BM' descriptor in its 'ST' newtype.  Its
-- rows and columns are viewed as 'STVector's.
instance BaseBanded (STBanded s) (STVector s) where
    bandedViewArray f p m n kl ku ld h = ST (BM f p m n kl ku ld h)
    arrayFromBanded (ST (BM f p m n kl ku ld h)) = (f,p,m,n,kl,ku,ld,h)
-- | Shape and bounds are those of the logical (Hermitian-aware) matrix.
instance BaseTensor IOBanded (Int,Int) where
    shape = shapeBanded
    bounds = boundsBanded

-- | Shape and bounds are those of the logical (Hermitian-aware) matrix.
instance BaseTensor (STBanded s) (Int,Int) where
    shape = shapeBanded
    bounds = boundsBanded

-- | Conjugate transpose is a storage-sharing view (see 'hermBanded').
instance BLAS.BaseMatrix IOBanded where
    herm = hermBanded

-- | Conjugate transpose is a storage-sharing view (see 'hermBanded').
instance BLAS.BaseMatrix (STBanded s) where
    herm = hermBanded
-- | 'IOBanded' is readable in 'IO' (constraint-only class, no methods).
instance ReadBanded IOBanded IOVector IO
-- | 'STBanded' is readable in 'ST' (constraint-only class, no methods).
instance ReadBanded (STBanded s) (STVector s) (ST s)

-- | Element reads for 'IOBanded'.  The unprimed 'getAssocs'/'getElems'
-- defer their reads with 'unsafeInterleaveM'; the primed variants read
-- eagerly.  Indices outside the band read as zero.
instance ReadTensor IOBanded (Int,Int) IO where
    getSize = getSizeBanded
    getAssocs = getAssocsBanded
    getIndices = getIndicesBanded
    getElems = getElemsBanded
    getAssocs' = getAssocsBanded'
    getIndices' = getIndicesBanded'
    getElems' = getElemsBanded'
    unsafeReadElem = unsafeReadElemBanded

-- | Element reads for 'STBanded'; same behaviour as the 'IOBanded'
-- instance.
instance ReadTensor (STBanded s) (Int,Int) (ST s) where
    getSize = getSizeBanded
    getAssocs = getAssocsBanded
    getIndices = getIndicesBanded
    getElems = getElemsBanded
    getAssocs' = getAssocsBanded'
    getIndices' = getIndicesBanded'
    getElems' = getElemsBanded'
    unsafeReadElem = unsafeReadElemBanded
-- | 'IOBanded' is writable in 'IO' (constraint-only class, no methods).
instance WriteBanded IOBanded IOVector IO where
-- | 'STBanded' is writable in 'ST' (constraint-only class, no methods).
instance WriteBanded (STBanded s) (STVector s) (ST s) where

-- | Element writes for 'IOBanded'.  Only indices inside the band can be
-- modified ('canModifyElem').
instance WriteTensor IOBanded (Int,Int) IO where
    setConstant = setConstantBanded
    setZero = setZeroBanded
    modifyWith = modifyWithBanded
    unsafeWriteElem = unsafeWriteElemBanded
    canModifyElem = canModifyElemBanded

-- | Element writes for 'STBanded'; same behaviour as the 'IOBanded'
-- instance.
instance WriteTensor (STBanded s) (Int,Int) (ST s) where
    setConstant = setConstantBanded
    setZero = setZeroBanded
    modifyWith = modifyWithBanded
    unsafeWriteElem = unsafeWriteElemBanded
    canModifyElem = canModifyElemBanded
-- | Mutable-matrix operations for 'IOBanded':
--
-- * matrix-vector products via BLAS 'gbmv';
-- * matrix-matrix products as column-wise 'gbmv' ('gbmm');
-- * row and column copies into dense vectors.
instance (BLAS2 e) => MMatrix IOBanded e IO where
    unsafeDoSApplyAdd = gbmv
    unsafeDoSApplyAddMat = gbmm
    unsafeGetRow = unsafeGetRowBanded
    unsafeGetCol = unsafeGetColBanded

-- | Mutable-matrix operations for 'STBanded'; same implementations as
-- the 'IOBanded' instance.
instance (BLAS2 e) => MMatrix (STBanded s) e (ST s) where
    unsafeDoSApplyAdd = gbmv
    unsafeDoSApplyAddMat = gbmm
    unsafeGetRow = unsafeGetRowBanded
    unsafeGetCol = unsafeGetColBanded