diff --git a/src/ArrayFire/Algorithm.hs b/src/ArrayFire/Algorithm.hs index 108beb4..11a054b 100644 --- a/src/ArrayFire/Algorithm.hs +++ b/src/ArrayFire/Algorithm.hs @@ -24,30 +24,26 @@ -------------------------------------------------------------------------------- module ArrayFire.Algorithm where -import ArrayFire.Array import ArrayFire.FFI import ArrayFire.Internal.Algorithm import ArrayFire.Internal.Types -import Foreign.C.Types -import Data.Word - -- | Sum all of the elements in 'Array' along the specified dimension -- -- >>> A.sum (A.vector @Double 10 [1..]) 0 -- 55.0 -- --- >>> A.sum (A.matrix @Double (10,10) [[2..],[2..]]) 0 +-- >>> A.matrix @Double (10,10) $ replicate 10 [1..] -- 65.0 sum :: AFType a => Array a -- ^ Array to sum -> Int - -- ^ Dimension along which to perform sum - -> a + -- ^ 0-based Dimension along which to perform sum + -> Array a -- ^ Will return the sum of all values in the input array along the specified dimension -sum x (fromIntegral -> n) = getScalar (x `op1` (\p a -> af_sum p a n)) +sum x (fromIntegral -> n) = (x `op1` (\p a -> af_sum p a n)) -- | Sum all of the elements in 'Array' along the specified dimension, using a default value for NaN -- @@ -61,9 +57,9 @@ sumNaN -- ^ Dimension along which to perform sum -> Double -- ^ Default value to use in the case of NaN - -> a + -> Array a -- ^ Will return the sum of all values in the input array along the specified dimension, substituted with the default value -sumNaN n (fromIntegral -> i) d = getScalar (n `op1` (\p a -> af_sum_nan p a i d)) +sumNaN n (fromIntegral -> i) d = (n `op1` (\p a -> af_sum_nan p a i d)) -- | Product all of the elements in 'Array' along the specified dimension -- @@ -75,9 +71,9 @@ product -- ^ Array to product -> Int -- ^ Dimension along which to perform product - -> a + -> Array a -- ^ Will return the product of all values in the input array along the specified dimension -product x (fromIntegral -> n) = getScalar (x `op1` (\p a -> af_product p a n)) +product x (fromIntegral -> n) = (x `op1` (\p a -> af_product p a n)) -- | Product all of the elements in 'Array' along the specified dimension, using a default value for NaN -- @@ -91,9 +87,9 @@ productNaN -- ^ Dimension along which to perform product -> Double -- ^ Default value to use in the case of NaN - -> a + -> Array a -- ^ Will return the product of all values in the input array along the specified dimension, substituted with the default value -productNaN n (fromIntegral -> i) d = getScalar (n `op1` (\p a -> af_product_nan p a i d)) +productNaN n (fromIntegral -> i) d = n `op1` (\p a -> af_product_nan p a i d) -- | Take the minimum of an 'Array' along a specific dimension -- @@ -105,9 +101,9 @@ min -- ^ Array input -> Int -- ^ Dimension along which to retrieve the min element - -> a + -> Array a -- ^ Will contain the minimum of all values in the input array along dim -min x (fromIntegral -> n) = getScalar (x `op1` (\p a -> af_min p a n)) +min x (fromIntegral -> n) = x `op1` (\p a -> af_min p a n) -- | Take the maximum of an 'Array' along a specific dimension -- @@ -119,9 +115,9 @@ max -- ^ Array input -> Int -- ^ Dimension along which to retrieve the max element - -> a + -> Array a -- ^ Will contain the maximum of all values in the input array along dim -max x (fromIntegral -> n) = getScalar (x `op1` (\p a -> af_max p a n)) +max x (fromIntegral -> n) = x `op1` (\p a -> af_max p a n) -- | Find if all elements in an 'Array' are 'True' along a dimension -- @@ -133,10 +129,10 @@ allTrue -- ^ Array input -> Int -- ^ Dimension along which to see if all elements are True - -> Bool + -> Array a -- ^ Will contain the maximum of all values in the input array along dim allTrue x (fromIntegral -> n) = - toEnum . fromIntegral $ getScalar @CBool @a (x `op1` (\p a -> af_all_true p a n)) + x `op1` (\p a -> af_all_true p a n) -- | Find if any elements in an 'Array' are 'True' along a dimension -- @@ -148,10 +144,10 @@ anyTrue -- ^ Array input -> Int -- ^ Dimension along which to see if all elements are True - -> Bool + -> Array a -- ^ Returns if all elements are true anyTrue x (fromIntegral -> n) = - toEnum . fromIntegral $ getScalar @CBool @a (x `op1` (\p a -> af_any_true p a n)) + (x `op1` (\p a -> af_any_true p a n)) -- | Count elements in an 'Array' along a dimension -- @@ -163,9 +159,9 @@ count -- ^ Array input -> Int -- ^ Dimension along which to count - -> Int + -> Array Int -- ^ Count of all elements along dimension -count x (fromIntegral -> n) = fromIntegral $ getScalar @Word32 @a (x `op1` (\p a -> af_count p a n)) +count x (fromIntegral -> n) = x `op1d` (\p a -> af_count p a n) -- | Sum all elements in an 'Array' along all dimensions -- diff --git a/src/ArrayFire/Arith.hs b/src/ArrayFire/Arith.hs index fbe2e0c..83c0f90 100644 --- a/src/ArrayFire/Arith.hs +++ b/src/ArrayFire/Arith.hs @@ -35,6 +35,8 @@ import ArrayFire.FFI import ArrayFire.Internal.Arith import ArrayFire.Internal.Types +import Foreign.C.Types + -- | Adds two 'Array' objects -- -- >>> A.scalar @Int 1 `A.add` A.scalar @Int 1 @@ -202,10 +204,10 @@ lt -- ^ First input -> Array a -- ^ Second input - -> Array a + -> Array CBool -- ^ Result of less than lt x y = do - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_lt arr arr1 arr2 1 -- | Test if on 'Array' is less than another 'Array' @@ -224,10 +226,10 @@ ltBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array a + -> Array CBool -- ^ Result of less than ltBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_lt arr arr1 arr2 batch -- | Test if an 'Array' is greater than another 'Array' @@ -244,10 +246,10 @@ gt -- ^ First input -> Array a -- ^ Second input - -> Array a + -> Array CBool -- ^ Result of gt gt x y = do - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_gt arr arr1 arr2 1 -- | Test if an 'Array' is greater than another 'Array' @@ -262,10 +264,10 @@ gtBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array a + -> Array CBool -- ^ Result of gt gtBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_gt arr arr1 arr2 batch -- | Test if one 'Array' is less than or equal to another 'Array' @@ -282,10 +284,10 @@ le -- ^ First input -> Array a -- ^ Second input - -> Array a + -> Array CBool -- ^ Result of less than or equal le x y = do - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_le arr arr1 arr2 1 -- | Test if one 'Array' is less than or equal to another 'Array' @@ -304,10 +306,10 @@ leBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array a + -> Array CBool -- ^ Result of less than or equal leBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_le arr arr1 arr2 batch -- | Test if one 'Array' is greater than or equal to another 'Array' @@ -324,10 +326,10 @@ ge -- ^ First input -> Array a -- ^ Second input - -> Array a + -> Array CBool -- ^ Result of greater than or equal ge x y = do - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_ge arr arr1 arr2 1 -- | Test if one 'Array' is greater than or equal to another 'Array' @@ -343,10 +345,10 @@ geBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array a + -> Array CBool -- ^ Result of greater than or equal geBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_ge arr arr1 arr2 batch -- | Test if one 'Array' is equal to another 'Array' @@ -364,10 +366,10 @@ eq -- ^ First input -> Array a -- ^ Second input - -> Array a + -> Array CBool -- ^ Result of equal eq x y = do - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_eq arr arr1 arr2 1 -- | Test if one 'Array' is equal to another 'Array' @@ -382,10 +384,10 @@ eqBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array a + -> Array CBool -- ^ Result of equal eqBatched x y (fromIntegral . fromEnum -> batch) = - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_eq arr arr1 arr2 batch -- | Test if one 'Array' is not equal to another 'Array' @@ -402,10 +404,10 @@ neq -- ^ First input -> Array a -- ^ Second input - -> Array a + -> Array CBool -- ^ Result of not equal neq x y = - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_neq arr arr1 arr2 1 -- | Test if one 'Array' is not equal to another 'Array' @@ -420,10 +422,10 @@ neqBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array a + -> Array CBool -- ^ Result of not equal neqBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_neq arr arr1 arr2 batch -- | Logical 'and' one 'Array' with another @@ -439,10 +441,10 @@ and -- ^ First input -> Array a -- ^ Second input - -> Array a + -> Array CBool -- ^ Result of and and x y = - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_and arr arr1 arr2 1 -- | Logical 'and' one 'Array' with another @@ -459,10 +461,10 @@ andBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array a + -> Array CBool -- ^ Result of and andBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_and arr arr1 arr2 batch -- | Logical 'or' one 'Array' with another @@ -478,10 +480,10 @@ or -- ^ First input -> Array a -- ^ Second input - -> Array a + -> Array CBool -- ^ Result of or or x y = - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_or arr arr1 arr2 1 -- | Logical 'or' one 'Array' with another @@ -499,10 +501,10 @@ orBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array a + -> Array CBool -- ^ Result of or orBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_or arr arr1 arr2 batch -- | Not the values of an 'Array' @@ -515,9 +517,9 @@ not :: AFType a => Array a -- ^ Input 'Array' - -> Array a + -> Array CBool -- ^ Result of 'not' on an 'Array' -not = flip op1 af_not +not = flip op1d af_not -- | Bitwise and the values in one 'Array' against another 'Array' -- @@ -531,10 +533,10 @@ bitAnd -- ^ First input -> Array a -- ^ Second input - -> Array a + -> Array CBool -- ^ Result of bitwise and bitAnd x y = - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_bitand arr arr1 arr2 1 -- | Bitwise and the values in one 'Array' against another 'Array' @@ -551,10 +553,10 @@ bitAndBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array a + -> Array CBool -- ^ Result of bitwise and bitAndBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_bitand arr arr1 arr2 batch -- | Bitwise or the values in one 'Array' against another 'Array' @@ -569,10 +571,10 @@ bitOr -- ^ First input -> Array a -- ^ Second input - -> Array a + -> Array CBool -- ^ Result of bit or bitOr x y = do - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_bitor arr arr1 arr2 1 -- | Bitwise or the values in one 'Array' against another 'Array' @@ -589,10 +591,10 @@ bitOrBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array a + -> Array CBool -- ^ Result of bit or bitOrBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_bitor arr arr1 arr2 batch -- | Bitwise xor the values in one 'Array' against another 'Array' @@ -607,10 +609,10 @@ bitXor -- ^ First input -> Array a -- ^ Second input - -> Array a + -> Array CBool -- ^ Result of bit xor bitXor x y = do - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_bitxor arr arr1 arr2 1 -- | Bitwise xor the values in one 'Array' against another 'Array' @@ -627,10 +629,10 @@ bitXorBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array a + -> Array CBool -- ^ Result of bit xor bitXorBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_bitxor arr arr1 arr2 batch -- | Left bit shift the values in one 'Array' against another 'Array' @@ -645,10 +647,10 @@ bitShiftL -- ^ First input -> Array a -- ^ Second input - -> Array a + -> Array CBool -- ^ Result of bit shift left bitShiftL x y = - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_bitshiftl arr arr1 arr2 1 -- | Left bit shift the values in one 'Array' against another 'Array' @@ -665,10 +667,10 @@ bitShiftLBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array a + -> Array CBool -- ^ Result of bit shift left bitShiftLBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_bitshiftl arr arr1 arr2 batch -- | Right bit shift the values in one 'Array' against another 'Array' @@ -683,10 +685,10 @@ bitShiftR -- ^ First input -> Array a -- ^ Second input - -> Array a + -> Array CBool -- ^ Result of bit shift right bitShiftR x y = - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_bitshiftr arr arr1 arr2 1 -- | Right bit shift the values in one 'Array' against another 'Array' @@ -703,10 +705,10 @@ bitShiftRBatched -- ^ Second input -> Bool -- ^ Use batch - -> Array a + -> Array CBool -- ^ Result of bit shift left bitShiftRBatched x y (fromIntegral . fromEnum -> batch) = do - x `op2` y $ \arr arr1 arr2 -> + x `op2bool` y $ \arr arr1 arr2 -> af_bitshiftr arr arr1 arr2 batch -- | Cast one 'Array' into another diff --git a/src/ArrayFire/FFI.hs b/src/ArrayFire/FFI.hs index 4e4609d..e56d1f9 100644 --- a/src/ArrayFire/FFI.hs +++ b/src/ArrayFire/FFI.hs @@ -85,6 +85,24 @@ op2 (Array fptr1) (Array fptr2) op = fptr <- newForeignPtr af_release_array_finalizer ptr pure (Array fptr) +op2bool + :: Array b + -> Array a + -> (Ptr AFArray -> AFArray -> AFArray -> IO AFErr) + -> Array CBool +{-# NOINLINE op2bool #-} +op2bool (Array fptr1) (Array fptr2) op = + unsafePerformIO $ do + withForeignPtr fptr1 $ \ptr1 -> + withForeignPtr fptr2 $ \ptr2 -> do + ptr <- + alloca $ \ptrInput -> do + throwAFError =<< op ptrInput ptr1 ptr2 + peek ptrInput + fptr <- newForeignPtr af_release_array_finalizer ptr + pure (Array fptr) + + op2p :: Array a -> (Ptr AFArray -> Ptr AFArray -> AFArray -> IO AFErr) diff --git a/test/ArrayFire/AlgorithmSpec.hs b/test/ArrayFire/AlgorithmSpec.hs index 1121dd3..0541372 100644 --- a/test/ArrayFire/AlgorithmSpec.hs +++ b/test/ArrayFire/AlgorithmSpec.hs @@ -17,8 +17,8 @@ spec = A.sum (A.scalar @A.Word32 10) 0 `shouldBe` 10 A.sum (A.scalar @A.Word64 10) 0 `shouldBe` 10 A.sum (A.scalar @Double 10) 0 `shouldBe` 10.0 - A.sum (A.scalar @(A.Complex Double) (1 A.:+ 1)) 0 `shouldBe` 1 A.:+ 1 - A.sum (A.scalar @(A.Complex Float) (1 A.:+ 1)) 0 `shouldBe` 1 A.:+ 1 + A.sum (A.scalar @(A.Complex Double) (1 A.:+ 1)) 0 `shouldBe` A.scalar (1 A.:+ 1) + A.sum (A.scalar @(A.Complex Float) (1 A.:+ 1)) 0 `shouldBe` A.scalar (1 A.:+ 1) A.sum (A.scalar @A.CBool 1) 0 `shouldBe` 1 A.sum (A.scalar @A.CBool 0) 0 `shouldBe` 0 it "Should sum a vector" $ do @@ -30,15 +30,15 @@ spec = A.sum (A.vector @A.Word32 10 [1..]) 0 `shouldBe` 55 A.sum (A.vector @A.Word64 10 [1..]) 0 `shouldBe` 55 A.sum (A.vector @Double 10 [1..]) 0 `shouldBe` 55.0 - A.sum (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` 10.0 A.:+ 10.0 - A.sum (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` 10.0 A.:+ 10.0 + A.sum (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (10.0 A.:+ 10.0) + A.sum (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (10.0 A.:+ 10.0) A.sum (A.vector @A.CBool 10 (repeat 1)) 0 `shouldBe` 10 A.sum (A.vector @A.CBool 10 (repeat 0)) 0 `shouldBe` 0 it "Should sum a default value to replace NaN" $ do A.sumNaN (A.vector @Float 10 [1..]) 0 1.0 `shouldBe` 55 A.sumNaN (A.vector @Double 2 [acos 2, acos 2]) 0 50 `shouldBe` 100 - A.sumNaN (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 1.0 `shouldBe` 10.0 A.:+ 10.0 - A.sumNaN (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 1.0 `shouldBe` 10.0 A.:+ 10.0 + A.sumNaN (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 1.0 `shouldBe` A.scalar (10.0 A.:+ 10.0) + A.sumNaN (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 1.0 `shouldBe` A.scalar (10.0 A.:+ 10.0) it "Should product a scalar" $ do A.product (A.scalar @Int 10) 0 `shouldBe` 10 A.product (A.scalar @A.Int64 10) 0 `shouldBe` 10 @@ -48,8 +48,8 @@ spec = A.product (A.scalar @A.Word32 10) 0 `shouldBe` 10 A.product (A.scalar @A.Word64 10) 0 `shouldBe` 10 A.product (A.scalar @Double 10) 0 `shouldBe` 10.0 - A.product (A.scalar @(A.Complex Double) (1 A.:+ 1)) 0 `shouldBe` 1 A.:+ 1 - A.product (A.scalar @(A.Complex Float) (1 A.:+ 1)) 0 `shouldBe` 1 A.:+ 1 + A.product (A.scalar @(A.Complex Double) (1 A.:+ 1)) 0 `shouldBe` A.scalar (1 A.:+ 1) + A.product (A.scalar @(A.Complex Float) (1 A.:+ 1)) 0 `shouldBe` A.scalar (1 A.:+ 1) A.product (A.scalar @A.CBool 1) 0 `shouldBe` 1 A.product (A.scalar @A.CBool 0) 0 `shouldBe` 0 it "Should product a vector" $ do @@ -61,15 +61,15 @@ spec = A.product (A.vector @A.Word32 10 [1..]) 0 `shouldBe` 3628800 A.product (A.vector @A.Word64 10 [1..]) 0 `shouldBe` 3628800 A.product (A.vector @Double 10 [1..]) 0 `shouldBe` 3628800.0 - A.product (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` 0.0 A.:+ 32.0 - A.product (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` 0.0 A.:+ 32.0 + A.product (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (0.0 A.:+ 32.0) + A.product (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (0.0 A.:+ 32.0) A.product (A.vector @A.CBool 10 (repeat 1)) 0 `shouldBe` 10 -- FIXME: This is a bug, should be 0 A.product (A.vector @A.CBool 10 (repeat 0)) 0 `shouldBe` 0 it "Should product a default value to replace NaN" $ do A.productNaN (A.vector @Float 10 [1..]) 0 1.0 `shouldBe` 3628800.0 A.productNaN (A.vector @Double 2 [acos 2, acos 2]) 0 50 `shouldBe` 2500 - A.productNaN (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 1.0 `shouldBe` 0.0 A.:+ 32 - A.productNaN (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 1.0 `shouldBe` 0 A.:+ 32 + A.productNaN (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 1.0 `shouldBe` A.scalar (0.0 A.:+ 32) + A.productNaN (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 1.0 `shouldBe` A.scalar (0 A.:+ 32) it "Should take the minimum element of a vector" $ do A.min (A.vector @Int 10 [1..]) 0 `shouldBe` 1 A.min (A.vector @A.Int64 10 [1..]) 0 `shouldBe` 1 @@ -79,19 +79,19 @@ spec = A.min (A.vector @A.Word32 10 [1..]) 0 `shouldBe` 1 A.min (A.vector @A.Word64 10 [1..]) 0 `shouldBe` 1 A.min (A.vector @Double 10 [1..]) 0 `shouldBe` 1 - A.min (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` 1 A.:+ 1 - A.min (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` 1 A.:+ 1 + A.min (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (1 A.:+ 1) + A.min (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` A.scalar (1 A.:+ 1) A.min (A.vector @A.CBool 10 [1..]) 0 `shouldBe` 1 A.min (A.vector @A.CBool 10 [1..]) 0 `shouldBe` 1 it "Should find if all elements are true along dimension" $ do - A.allTrue (A.vector @Double 5 (repeat 12.0)) 0 `shouldBe` True - A.allTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` True - A.allTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` False - A.allTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` False + A.allTrue (A.vector @Double 5 (repeat 12.0)) 0 `shouldBe` 1 + A.allTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` 1 + A.allTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` 0 + A.allTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` 0 it "Should find if any elements are true along dimension" $ do - A.anyTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` True - A.anyTrue (A.vector @Int 5 (repeat 23)) 0 `shouldBe` True - A.anyTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` False + A.anyTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` 1 + A.anyTrue (A.vector @Int 5 (repeat 23)) 0 `shouldBe` 1 + A.anyTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` 0 it "Should get count of all elements" $ do A.count (A.vector @Int 5 (repeat 1)) 0 `shouldBe` 5 A.count (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` 5 diff --git a/test/ArrayFire/ArithSpec.hs b/test/ArrayFire/ArithSpec.hs index d0eb16b..ae03a54 100644 --- a/test/ArrayFire/ArithSpec.hs +++ b/test/ArrayFire/ArithSpec.hs @@ -31,14 +31,15 @@ spec = 3 `shouldBe` cbrt @Double 27 it "Should take square root" $ do 2 `shouldBe` sqrt @Double 4 - it "Should lt Array" $ do - 2 < (3 :: Array Double) `shouldBe` True + it "Should lte Array" $ do - 2 <= (3 :: Array Double) `shouldBe` True + 2 `le` (3 :: Array Double) `shouldBe` 1 it "Should gte Array" $ do - 2 >= (3 :: Array Double) `shouldBe` False + 2 `ge` (3 :: Array Double) `shouldBe` 0 it "Should gt Array" $ do - 2 > (3 :: Array Double) `shouldBe` False + 2 `gt` (3 :: Array Double) `shouldBe` 0 + it "Should lt Array" $ do + 2 `le` (3 :: Array Double) `shouldBe` 1 it "Should eq Array" $ do 3 == (3 :: Array Double) `shouldBe` True it "Should and Array" $ do diff --git a/test/ArrayFire/LAPACKSpec.hs b/test/ArrayFire/LAPACKSpec.hs index 44fc237..2c9f554 100644 --- a/test/ArrayFire/LAPACKSpec.hs +++ b/test/ArrayFire/LAPACKSpec.hs @@ -11,12 +11,12 @@ spec = it "Should have LAPACK available" $ do A.isLAPACKAvailable `shouldBe` True it "Should perform svd" $ do - let (s,v,d) = A.svd $ A.matrix @Double (4,2) [ [1,2], [3,4], [5,6], [7,8] ] + let (s,v,d) = A.svd $ A.matrix @Double (4,2) [ [1,2,3,4], [5,6,7,8] ] A.getDims s `shouldBe` (4,4,1,1) A.getDims v `shouldBe` (2,1,1,1) A.getDims d `shouldBe` (2,2,1,1) it "Should perform svd in place" $ do - let (s,v,d) = A.svdInPlace $ A.matrix @Double (4,2) [ [1,2], [3,4], [5,6], [7,8] ] + let (s,v,d) = A.svdInPlace $ A.matrix @Double (4,2) [ [1,2,3,4], [5,6,7,8] ] A.getDims s `shouldBe` (4,4,1,1) A.getDims v `shouldBe` (2,1,1,1) A.getDims d `shouldBe` (2,2,1,1)