Skip to content

Monster PR - vectors, matrices, number theory bindings #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 40 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
c3a13f5
started binding CVec
bollu Nov 21, 2016
41861c1
started writing code to use Symengine exceptions, simplified code to …
bollu Nov 22, 2016
327d249
add a list of things you learnt along the way
bollu Nov 22, 2016
af90964
added error handling code so vectors dont crash on out of bounds
bollu Nov 24, 2016
89da046
implement vector with foreign pointer, started binding dense matrix
bollu Nov 25, 2016
f7ae2e6
started binding dense matrices
bollu Nov 25, 2016
057e308
added test for dense matrices
bollu Nov 25, 2016
de1032b
added checks to bounds in basicvec, as isuruf said that such checks s…
bollu Nov 25, 2016
c991b35
touched README so travis build is triggered
bollu Nov 25, 2016
fd4a5d2
made versions of cabal to be newer
bollu Nov 25, 2016
6a42612
changed a bunch of definitions to make the code simpler
bollu Nov 30, 2016
bb1d570
added a typeclass called Wrapped that represents ForeignPtr's wrapped…
bollu Nov 30, 2016
1777564
implemented getter for matrix
bollu Nov 30, 2016
8d22266
added getter for CDenseMatrix
bollu Nov 30, 2016
7d6ee7d
added dimensions access to dense matrix
bollu Nov 30, 2016
cf027a5
fixed get_size
bollu Nov 30, 2016
4440d9b
changed travis file to use the correct cabal, GHC version. BUMP
bollu Nov 30, 2016
e268442
added more dense matrix code
bollu Dec 8, 2016
dfdc5df
bound dense matrix solves, need to write test cases
bollu Dec 8, 2016
f7a707f
rewrote modules to be split into separate code
bollu Dec 9, 2016
9151a80
no longer allow a densematrix_new and vecbasic_new. Should be IO
bollu Dec 10, 2016
d5538bb
dependant typing is matrix. DenseMatrix size is now dependant typed
bollu Dec 11, 2016
35c7c74
made eye into typed function
bollu Dec 12, 2016
5bd0861
changed densematrix_get to be type level
bollu Dec 12, 2016
d9c5e4b
fully typed densematrix API
bollu Dec 12, 2016
82ec441
made code referentially transparent
bollu Dec 13, 2016
a3fa4ef
add comment about debacle with densematrix_set
bollu Dec 13, 2016
a96e5be
continue building number theory
bollu Dec 13, 2016
30411db
implemented number theory bindings
bollu Dec 14, 2016
16ebc01
changed the way basicsym, densematrix is constructed to now abuse mkF…
bollu Dec 14, 2016
a26e46e
edited VecBasic as well to prevent weird memory races. Hope this is c…
bollu Dec 14, 2016
afe7aed
edited basic_unaryop to do the construction thing
bollu Dec 14, 2016
5add770
identity of + is crashing
bollu Dec 14, 2016
9f549b4
DOES NOT COMPILE: turns out memory is _not_ the problem. changed all …
bollu Dec 15, 2016
ad21095
minimal test case: create and do nothing. crashes
bollu Dec 15, 2016
50f002d
found the error. divide by 0. I was assuming Q, (*) is a group, and n…
bollu Dec 15, 2016
7e2e5fc
crash fixed: removed test case that tried to invert 0. Need to actual…
bollu Dec 15, 2016
7a681fa
changed basicsym_binaryop to deal with exceptions. TODO: edit other c…
bollu Dec 15, 2016
a7ee3fb
added algebra-based test cases. Implemented det, inv, etc.
bollu Dec 15, 2016
27722bd
expose transpose
bollu Dec 15, 2016
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
added algebra-based test cases. Implemented det, inv, etc.
  • Loading branch information
bollu committed Dec 15, 2016
commit a7ee3fb9766734212dcdde8af8e3b7bb08340e8a
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ GHCi session with Symengine loaded
clone `Symengine`, build it with the setting

```
cmake -DBUILD_SHARED_LIBS:BOOL=ON
cmake -DWITH_SYMENGINE_THREAD_SAFE=yes -DBUILD_SHARED_LIBS:BOOL=ON
```

this makes sure that dynamically linked libraries are being built, so we can
Expand Down
12 changes: 6 additions & 6 deletions src/Symengine/BasicSym.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ module Symengine.BasicSym(
complex,
symbol_new,
diff,
expand,
-- HACK: this should be internal :(
basicsym_new,
BasicSym,
Expand Down Expand Up @@ -129,14 +130,14 @@ lift_basicsym_binaryop :: (Ptr CBasicSym -> Ptr CBasicSym -> Ptr CBasicSym -> IO
BasicSym -> BasicSym -> BasicSym
lift_basicsym_binaryop f a b = unsafePerformIO $ do
s <- basicsym_new
exception_id <- with3 s a b f
forceException (liftException exception_id s)
with3 s a b f >>= throwOnSymIntException

return s

lift_basicsym_unaryop :: (Ptr CBasicSym -> Ptr CBasicSym -> IO a) -> BasicSym -> BasicSym
lift_basicsym_unaryop :: (Ptr CBasicSym -> Ptr CBasicSym -> IO CInt) -> BasicSym -> BasicSym
lift_basicsym_unaryop f a = unsafePerformIO $ do
s <- basicsym_new
with2 s a f
with2 s a f >>= throwOnSymIntException
return $ s


Expand Down Expand Up @@ -184,7 +185,7 @@ instance Num BasicSym where
(*) = lift_basicsym_binaryop $ basic_mul_ffi
negate = lift_basicsym_unaryop basic_neg_ffi
abs = lift_basicsym_unaryop basic_abs_ffi
signum = undefined

-- works only for long [-2^32, 2^32 - 1]
fromInteger = basic_from_integer

Expand Down Expand Up @@ -215,7 +216,6 @@ instance Floating BasicSym where

foreign import ccall "symengine/cwrapper.h ascii_art_str" ascii_art_str_ffi :: IO CString
foreign import ccall "symengine/cwrapper.h basic_new_heap" basic_new_heap_ffi :: IO (Ptr CBasicSym)
foreign import ccall "symengine/cwrapper.h basic_init_heap" basic_init_heap_ffi :: Ptr CBasicSym -> IO ()
foreign import ccall "symengine/cwrapper.h &basic_free_heap" ptr_basic_free_heap_ffi :: FunPtr(Ptr CBasicSym -> IO ())

-- constants
Expand Down
142 changes: 109 additions & 33 deletions src/Symengine/DenseMatrix.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ module Symengine.DenseMatrix
densematrix_new_vec,
densematrix_new_eye,
densematrix_new_diag,
densematrix_new_zeros,
densematrix_get,
densematrix_set,
densematrix_size,
Expand All @@ -32,13 +33,21 @@ module Symengine.DenseMatrix
densematrix_add,
densematrix_mul_matrix,
densematrix_mul_scalar,
det,
inv,

--decomposition
L(L), D(D), U(U),
densematrix_lu,
densematrix_ldl,
densematrix_fflu,
densematrix_ffldu,
densematrix_lu_solve
densematrix_lu_solve,

-- custom matrix class
Matrix(..)

--
)
where

Expand All @@ -64,6 +73,14 @@ import Data.Finite -- types to represent numbers
import Symengine.Internal
import Symengine.BasicSym
import Symengine.VecBasic

class Matrix m where
(<>) :: (KnownNat r, KnownNat c, KnownNat k) => m r k -> m k c -> m r c


instance Matrix (DenseMatrix) where
(<>) = densematrix_mul_matrix

data CDenseMatrix
data DenseMatrix :: Nat -> Nat -> * where
-- allow constructing raw DenseMatrix from a constructor
Expand All @@ -83,14 +100,28 @@ instance (KnownNat r, KnownNat c) => Eq (DenseMatrix r c) where
1 == fromIntegral (unsafePerformIO $
with2 mat1 mat2 cdensematrix_eq_ffi)

instance (KnownNat r, KnownNat c) => Num (DenseMatrix r c) where
(+) = densematrix_add
(-) d1 d2 = let
d2_neg = densematrix_mul_scalar d2 (fromInteger (-1))
in d1 + d2_neg
-- TODO: Should be elementwise multiplcation
(*) = undefined
-- TODO: should be elementwise signum
signum = undefined
-- TODO: should be elementwise abs
abs = undefined
-- make a 1x1 matrix
fromInteger = undefined -- densematrix_new_vec

densematrix_new :: (KnownNat r, KnownNat c) => IO (DenseMatrix r c)
densematrix_new = DenseMatrix <$> (mkForeignPtr cdensematrix_new_ffi cdensematrix_free_ffi)

_densematrix_copy :: (KnownNat r, KnownNat c) => DenseMatrix r c -> IO (DenseMatrix r c)
_densematrix_copy mat = do
newmat <- densematrix_new
with2 newmat mat cdensematrix_set_ffi
return newmat
newmat <- densematrix_new
throwOnSymIntException =<< with2 newmat mat cdensematrix_set_ffi
return newmat

densematrix_new_rows_cols :: forall r c . (KnownNat r, KnownNat c) => DenseMatrix r c
densematrix_new_rows_cols =
Expand All @@ -115,10 +146,20 @@ type Offset = Int
densematrix_new_eye :: forall k r c. (KnownNat r, KnownNat c, KnownNat k, KnownNat (r + k), KnownNat (c + k)) => DenseMatrix (r + k) (c + k)
densematrix_new_eye = unsafePerformIO $ do
let mat = densematrix_new_rows_cols
with mat (\m -> cdensematrix_eye_ffi m
throwOnSymIntException =<< with mat (\m -> cdensematrix_eye_ffi m
(fromIntegral . natVal $ (Proxy @ r))
(fromIntegral . natVal $ (Proxy @ c))
(fromIntegral . natVal $ (Proxy @ k)))


return mat

densematrix_new_zeros :: forall r c. (KnownNat r, KnownNat c) => DenseMatrix r c
densematrix_new_zeros = unsafePerformIO $ do
let mat = densematrix_new_rows_cols
throwOnSymIntException =<< with mat (\m -> cdensematrix_zeros_ffi m
(fromIntegral . natVal $ (Proxy @ r))
(fromIntegral . natVal $ (Proxy @ c)))
return mat

-- create a matrix with diagonal elements of length 'd', offset 'k'
Expand All @@ -129,33 +170,35 @@ densematrix_new_diag syms = unsafePerformIO $ do
let dim = offset + diagonal
vecsyms <- vector_to_vecbasic syms
let mat = densematrix_new_rows_cols :: DenseMatrix (d + k) (d + k)
with2 mat vecsyms (\m syms -> cdensematrix_diag_ffi m syms offset)
throwOnSymIntException =<< with2 mat vecsyms (\m syms -> cdensematrix_diag_ffi m syms offset)


return mat

type Row = Int
type Col = Int



densematrix_get :: forall r c. (KnownNat r, KnownNat c) =>
DenseMatrix r c -> Finite r -> Finite c -> BasicSym
densematrix_get mat getr getc = unsafePerformIO $ do
sym <- basicsym_new
let indexr = fromIntegral $ (getFinite getr)
let indexc = fromIntegral $ (getFinite getc)
with2 mat sym (\m s -> cdensematrix_get_basic_ffi s m indexr indexc)
throwOnSymIntException =<< with2 mat sym (\m s -> cdensematrix_get_basic_ffi s m indexr indexc)

return sym

densematrix_set :: forall r c. (KnownNat r, KnownNat c) =>
DenseMatrix r c -> Finite r -> Finite c -> BasicSym -> DenseMatrix r c
densematrix_set mat r c sym = unsafePerformIO $ do
mat' <- _densematrix_copy mat
with2 mat' sym (\m s -> cdensematrix_set_basic_ffi
throwOnSymIntException =<< with2 mat' sym (\m s -> cdensematrix_set_basic_ffi
m
(fromIntegral . getFinite $ r)
(fromIntegral . getFinite $ c)
s)

return mat'


Expand All @@ -164,34 +207,50 @@ type NCols = Int

-- | provides dimenions of matrix. combination of the FFI calls
-- `dense_matrix_rows` and `dense_matrix_cols`
densematrix_size :: forall r c. (KnownNat r, KnownNat c) =>
DenseMatrix r c -> (NRows, NCols)
densematrix_size :: forall r c. (KnownNat r, KnownNat c) => DenseMatrix r c -> (NRows, NCols)
densematrix_size mat =
(fromIntegral . natVal $ (Proxy @ r), fromIntegral . natVal $ (Proxy @ c))

densematrix_add :: forall r c. (KnownNat r, KnownNat c) =>
DenseMatrix r c-> DenseMatrix r c -> DenseMatrix r c
DenseMatrix r c -> DenseMatrix r c -> DenseMatrix r c
densematrix_add mata matb = unsafePerformIO $ do
res <- densematrix_new
with3 res mata matb cdensematrix_add_matrix
throwOnSymIntException =<< with3 res mata matb cdensematrix_add_matrix_ffi
return res


densematrix_mul_matrix :: forall r k c. (KnownNat r, KnownNat k, KnownNat c) =>
DenseMatrix r k -> DenseMatrix k c -> DenseMatrix r c
densematrix_mul_matrix mata matb = unsafePerformIO $ do
res <- densematrix_new
with3 res mata matb cdensematrix_mul_matrix
throwOnSymIntException =<< with3 res mata matb cdensematrix_mul_matrix_ffi
return res


densematrix_mul_scalar :: forall r c. (KnownNat r, KnownNat c) =>
DenseMatrix r c -> BasicSym -> DenseMatrix r c
densematrix_mul_scalar mata sym = unsafePerformIO $ do
res <- densematrix_new
with3 res mata sym cdensematrix_mul_scalar
throwOnSymIntException =<< with3 res mata sym cdensematrix_mul_scalar_ffi
return res

det :: forall r c. (KnownNat r, KnownNat c) => DenseMatrix r c -> BasicSym
det d = unsafePerformIO $ do
sym <- basicsym_new
throwOnSymIntException =<< with2 sym d cdensematrix_det_ffi
return sym

inv :: forall r c. (KnownNat r, KnownNat c) => DenseMatrix r c -> DenseMatrix r c
inv d = unsafePerformIO $ do
m <- densematrix_new
throwOnSymIntException =<< with2 m d cdensematrix_inv_ffi
return m

transpose :: forall r c. (KnownNat r, KnownNat c) => DenseMatrix r c -> DenseMatrix r c
transpose d = unsafePerformIO $ do
m <- densematrix_new
throwOnSymIntException =<< with2 m d cdensematrix_transpose_ffi
return m

newtype L r c = L (DenseMatrix r c)
newtype U r c = U (DenseMatrix r c)
Expand All @@ -200,15 +259,15 @@ densematrix_lu :: (KnownNat r, KnownNat c) => DenseMatrix r c-> (L r c, U r c)
densematrix_lu mat = unsafePerformIO $ do
l <- densematrix_new
u <- densematrix_new
with3 l u mat cdensematrix_lu
throwOnSymIntException =<< with3 l u mat cdensematrix_lu
return (L l, U u)

newtype D r c = D (DenseMatrix r c)
densematrix_ldl :: (KnownNat r, KnownNat c) => DenseMatrix r c-> (L r c, D r c)
densematrix_ldl mat = unsafePerformIO $ do
l <- densematrix_new
d <- densematrix_new
with3 l d mat cdensematrix_ldl
throwOnSymIntException =<< with3 l d mat cdensematrix_ldl

return (L l, D d)

Expand All @@ -217,7 +276,7 @@ newtype FFLU r c = FFLU (DenseMatrix r c)
densematrix_fflu :: (KnownNat r, KnownNat c) => DenseMatrix r c -> FFLU r c
densematrix_fflu mat = unsafePerformIO $ do
fflu <- densematrix_new
with2 fflu mat cdensematrix_fflu
throwOnSymIntException =<< with2 fflu mat cdensematrix_fflu
return (FFLU fflu)


Expand All @@ -228,7 +287,7 @@ densematrix_ffldu mat = unsafePerformIO $ do
d <- densematrix_new
u <- densematrix_new

with4 l d u mat cdensematrix_ffldu
throwOnSymIntException =<< with4 l d u mat cdensematrix_ffldu
return (L l, D d, U u)

-- solve A x = B
Expand All @@ -237,33 +296,50 @@ densematrix_lu_solve :: (KnownNat r, KnownNat c) =>
DenseMatrix r c -> DenseMatrix r c -> DenseMatrix r c
densematrix_lu_solve a b = unsafePerformIO $ do
x <- densematrix_new
with3 x a b cdensematrix_lu_solve
throwOnSymIntException =<< with3 x a b cdensematrix_lu_solve
return x

foreign import ccall unsafe "symengine/cwrapper.h dense_matrix_new" cdensematrix_new_ffi :: IO (Ptr CDenseMatrix)
foreign import ccall unsafe "symengine/cwrapper.h &dense_matrix_free" cdensematrix_free_ffi :: FunPtr ((Ptr CDenseMatrix) -> IO ())
foreign import ccall unsafe "symengine/cwrapper.h dense_matrix_new_rows_cols" cdensematrix_new_rows_cols_ffi :: CUInt -> CUInt -> IO (Ptr CDenseMatrix)
foreign import ccall unsafe "symengine/cwrapper.h dense_matrix_new_vec" cdensematrix_new_vec_ffi :: CUInt -> CUInt -> Ptr CVecBasic -> IO (Ptr CDenseMatrix)
foreign import ccall unsafe "symengine/cwrapper.h dense_matrix_eye" cdensematrix_eye_ffi :: Ptr CDenseMatrix -> CULong -> CULong -> CULong -> IO ()
foreign import ccall unsafe "symengine/cwrapper.h dense_matrix_diag" cdensematrix_diag_ffi :: Ptr CDenseMatrix -> Ptr CVecBasic -> CULong -> IO ()
foreign import ccall unsafe "symengine/cwrapper.h dense_matrix_zeros" cdensematrix_zeros_ffi :: Ptr CDenseMatrix -> CULong -> CULong -> IO CInt
foreign import ccall unsafe "symengine/cwrapper.h dense_matrix_eye" cdensematrix_eye_ffi :: Ptr CDenseMatrix -> CULong -> CULong -> CULong -> IO CInt
foreign import ccall unsafe "symengine/cwrapper.h dense_matrix_diag" cdensematrix_diag_ffi :: Ptr CDenseMatrix -> Ptr CVecBasic -> CULong -> IO CInt
foreign import ccall "symengine/cwrapper.h dense_matrix_eq" cdensematrix_eq_ffi :: Ptr CDenseMatrix -> Ptr CDenseMatrix -> IO CInt
foreign import ccall "symengine/cwrapper.h dense_matrix_set" cdensematrix_set_ffi :: Ptr CDenseMatrix -> Ptr CDenseMatrix -> IO ()
foreign import ccall "symengine/cwrapper.h dense_matrix_set" cdensematrix_set_ffi :: Ptr CDenseMatrix -> Ptr CDenseMatrix -> IO CInt

foreign import ccall "symengine/cwrapper.h dense_matrix_str" cdensematrix_str_ffi :: Ptr CDenseMatrix -> IO CString

foreign import ccall "symengine/cwrapper.h dense_matrix_get_basic" cdensematrix_get_basic_ffi :: Ptr (CBasicSym) -> Ptr CDenseMatrix -> CUInt -> CUInt -> IO (Ptr CDenseMatrix)
foreign import ccall "symengine/cwrapper.h dense_matrix_set_basic" cdensematrix_set_basic_ffi :: Ptr CDenseMatrix -> CUInt -> CUInt -> Ptr (CBasicSym) -> IO ()
foreign import ccall "symengine/cwrapper.h dense_matrix_get_basic" cdensematrix_get_basic_ffi :: Ptr (CBasicSym) -> Ptr CDenseMatrix -> CUInt -> CUInt -> IO CInt
foreign import ccall "symengine/cwrapper.h dense_matrix_set_basic" cdensematrix_set_basic_ffi :: Ptr CDenseMatrix -> CUInt -> CUInt -> Ptr (CBasicSym) -> IO CInt


foreign import ccall "symengine/cwrapper.h dense_matrix_rows" cdensematrix_rows_ffi :: Ptr CDenseMatrix -> IO CULong
foreign import ccall "symengine/cwrapper.h dense_matrix_cols" cdensematrix_cols_ffi :: Ptr CDenseMatrix -> IO CULong

foreign import ccall "symengine/cwrapper.h dense_matrix_add_matrix" cdensematrix_add_matrix :: Ptr CDenseMatrix -> Ptr CDenseMatrix -> Ptr CDenseMatrix -> IO ()
foreign import ccall "symengine/cwrapper.h dense_matrix_mul_matrix" cdensematrix_mul_matrix :: Ptr CDenseMatrix -> Ptr CDenseMatrix -> Ptr CDenseMatrix -> IO ()
foreign import ccall "symengine/cwrapper.h dense_matrix_mul_scalar" cdensematrix_mul_scalar :: Ptr CDenseMatrix -> Ptr CDenseMatrix -> Ptr CBasicSym -> IO ()
foreign import ccall "symengine/cwrapper.h dense_matrix_add_matrix"
cdensematrix_add_matrix_ffi :: Ptr CDenseMatrix -> Ptr CDenseMatrix -> Ptr CDenseMatrix -> IO CInt

foreign import ccall "symengine/cwrapper.h dense_matrix_mul_matrix"
cdensematrix_mul_matrix_ffi :: Ptr CDenseMatrix -> Ptr CDenseMatrix -> Ptr CDenseMatrix -> IO CInt

foreign import ccall "symengine/cwrapper.h dense_matrix_mul_scalar"
cdensematrix_mul_scalar_ffi :: Ptr CDenseMatrix -> Ptr CDenseMatrix -> Ptr CBasicSym -> IO CInt

foreign import ccall "symengine/cwrapper.h dense_matrix_det"
cdensematrix_det_ffi :: Ptr CBasicSym -> Ptr CDenseMatrix -> IO CInt


foreign import ccall "symengine/cwrapper.h dense_matrix_inv"
cdensematrix_inv_ffi :: Ptr CDenseMatrix -> Ptr CDenseMatrix -> IO CInt


foreign import ccall "symengine/cwrapper.h dense_matrix_transpose"
cdensematrix_transpose_ffi :: Ptr CDenseMatrix -> Ptr CDenseMatrix -> IO CInt

foreign import ccall "symengine/cwrapper.h dense_matrix_LU" cdensematrix_lu :: Ptr CDenseMatrix -> Ptr CDenseMatrix -> Ptr CDenseMatrix -> IO ()
foreign import ccall "symengine/cwrapper.h dense_matrix_LDL" cdensematrix_ldl :: Ptr CDenseMatrix -> Ptr CDenseMatrix -> Ptr CDenseMatrix -> IO ()
foreign import ccall "symengine/cwrapper.h dense_matrix_FFLU" cdensematrix_fflu :: Ptr CDenseMatrix -> Ptr CDenseMatrix -> IO ()
foreign import ccall "symengine/cwrapper.h dense_matrix_FFLDU" cdensematrix_ffldu :: Ptr CDenseMatrix -> Ptr CDenseMatrix -> Ptr CDenseMatrix -> Ptr CDenseMatrix -> IO ()
foreign import ccall "symengine/cwrapper.h dense_matrix_LU_solve" cdensematrix_lu_solve:: Ptr CDenseMatrix -> Ptr CDenseMatrix -> Ptr CDenseMatrix -> IO ()
foreign import ccall "symengine/cwrapper.h dense_matrix_LU" cdensematrix_lu :: Ptr CDenseMatrix -> Ptr CDenseMatrix -> Ptr CDenseMatrix -> IO CInt
foreign import ccall "symengine/cwrapper.h dense_matrix_LDL" cdensematrix_ldl :: Ptr CDenseMatrix -> Ptr CDenseMatrix -> Ptr CDenseMatrix -> IO CInt
foreign import ccall "symengine/cwrapper.h dense_matrix_FFLU" cdensematrix_fflu :: Ptr CDenseMatrix -> Ptr CDenseMatrix -> IO CInt
foreign import ccall "symengine/cwrapper.h dense_matrix_FFLDU" cdensematrix_ffldu :: Ptr CDenseMatrix -> Ptr CDenseMatrix -> Ptr CDenseMatrix -> Ptr CDenseMatrix -> IO CInt
foreign import ccall "symengine/cwrapper.h dense_matrix_LU_solve" cdensematrix_lu_solve:: Ptr CDenseMatrix -> Ptr CDenseMatrix -> Ptr CDenseMatrix -> IO CInt
27 changes: 13 additions & 14 deletions src/Symengine/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ module Symengine.Internal
CBasicSym,
CVecBasic,
SymengineException(NoException, RuntimeError, DivByZero, NotImplemented, DomainError, ParseError),
liftException,
forceException
forceException,
throwOnSymIntException
) where

import Prelude
Expand Down Expand Up @@ -46,18 +46,17 @@ data SymengineException = NoException |

instance Exception SymengineException

liftException :: CInt -> a -> Either SymengineException a
liftException exceptid a = let
exception = cIntToEnum exceptid
in
if exception == NoException
then Right a
else Left exception

forceException :: Either SymengineException a -> IO ()
forceException eithera = case eithera of
Left error -> throwIO error
Right a -> return ()

-- interpret the CInt as a SymengineException, and
-- throw if it is actually an error
throwOnSymIntException :: CInt -> IO ()
throwOnSymIntException i = forceException . cIntToEnum $ i

forceException :: SymengineException -> IO ()
forceException exception =
case exception of
NoException -> return ()
error @ _ -> throwIO error

cIntToEnum :: Enum a => CInt -> a
cIntToEnum = toEnum . fromIntegral
Expand Down
Loading