1
- {-# LANGUAGE TemplateHaskell #-}
2
- {-# LANGUAGE RankNTypes #-}
3
- {-# LANGUAGE ScopedTypeVariables #-}
4
- {-# LANGUAGE TypeOperators #-}
5
1
{-# LANGUAGE BangPatterns #-}
6
- {-# LANGUAGE TypeApplications #-}
7
2
{-# LANGUAGE DeriveDataTypeable #-}
8
3
{-# LANGUAGE FlexibleContexts #-}
4
+ {-# LANGUAGE RankNTypes #-}
5
+ {-# LANGUAGE ScopedTypeVariables #-}
6
+ {-# LANGUAGE TemplateHaskell #-}
7
+ {-# LANGUAGE TypeApplications #-}
8
+ {-# LANGUAGE TypeOperators #-}
9
+
9
10
11
+ module HedgehogTest where
12
+
13
+ import ChebApproxAcc
14
+ import ChebTypingAcc
15
+
16
+ import Hedgehog
17
+ import qualified Hedgehog.Gen as Gen
18
+ import qualified Hedgehog.Range as Range
19
+ import Prelude as P
20
+ import Data.Array.Accelerate as A
21
+ import Data.Array.Accelerate.Debug as A
22
+ import Data.Array.Accelerate.Interpreter as I
23
+ import Data.Array.Accelerate.LLVM.Native as CPU
24
+ import Data.Array.Accelerate.Array.Sugar as Sugar ( (!) , Arrays , Array , Shape , Elt , DIM0 , DIM1 , DIM2 , DIM3 , Z (.. ), (:.) (.. ), fromList , size )
25
+ import Data.Array.Accelerate.Test.Similar
26
+ import Test.Tasty
27
+ import Test.Tasty.Hedgehog
28
+ import Data.Typeable
29
+
30
+
31
+ {- -
32
+ data Expr =
33
+ Var String
34
+ | Lam String Expr
35
+ | App Expr Expr
36
+
37
+ -- Assuming we have a name generator
38
+ genName :: MonadGen m => m String
39
+
40
+ -- We can write a generator for expressions
41
+ genExpr :: MonadGen m => m Expr
42
+ genExpr =
43
+ Gen.recursive Gen.choice [
44
+ -- non-recursive generators
45
+ Var <$> genName
46
+ ] [
47
+ -- recursive generators
48
+ Gen.subtermM genExpr (x -> Lam <$> genName <*> pure x)
49
+ , Gen.subterm2 genExpr genExpr App
50
+ ]
51
+ --}
52
+
53
+
54
+ -- Operator terms we want to generate random combinations of
55
+ data Term
56
+ = Val
57
+ | Sin Term
58
+ | Cos Term
59
+ | Exponential Term
60
+ | Divide Term Term
61
+ deriving Show
62
+
63
+ -- we can write a generator for expressions
64
+ genTerm :: MonadGen m => m Term
65
+ genTerm =
66
+ Gen. recursive Gen. choice
67
+ -- non-recursive generators
68
+ [ pure Val ]
69
+ -- recursive generators
70
+ [ Gen. subtermM genTerm (pure . Sin )
71
+ , Gen. subtermM genTerm (pure . Cos )
72
+ , Gen. subtermM genTerm (pure . Exponential )
73
+ ]
74
+
75
+ evalTerm :: Term -> Double -> Double
76
+ evalTerm Val x = x
77
+ evalTerm (Sin t) x = P. sin (evalTerm t x)
78
+ evalTerm (Cos t) x = P. cos (evalTerm t x)
79
+ evalTerm (Exponential t) x = P. exp (evalTerm t x)
80
+
81
+ termToAcc :: Term -> Exp Double -> Exp Double
82
+ termToAcc Val x = x
83
+ termToAcc (Sin t) x = A. sin (termToAcc t x)
84
+ termToAcc (Cos t) x = A. cos (termToAcc t x)
85
+ termToAcc (Exponential t) x = A. exp (termToAcc t x)
86
+
87
+
88
+ type Run = forall a . Arrays a = > Acc a -> a
89
+ type RunN = forall f . Afunction f = > f -> AfunctionR f
90
+
91
+ dim0 :: Gen DIM0
92
+ dim0 = return Z
93
+
94
+ dim1 :: Gen DIM1
95
+ dim1 = (Z :. ) <$> Gen. int (Range. linear 0 1024 )
96
+
97
+ f64 :: Gen Double
98
+ f64 = Gen. double (Range. linearFracFrom 0 (- log_flt_max) log_flt_max)
99
+
100
+ array :: (Shape sh , Elt e ) => sh -> Gen e -> Gen (Array sh e )
101
+ array sh gen = fromList sh <$> Gen. list (Range. singleton (Sugar. size sh)) gen
102
+
103
+ log_flt_max :: P. RealFloat a => a
104
+ log_flt_max = log flt_max
105
+
106
+ flt_max :: P. RealFloat a => a
107
+ flt_max = x
108
+ where
109
+ n = P. floatDigits x
110
+ b = P. floatRadix x
111
+ (_,u) = P. floatRange x
112
+ x = P. encodeFloat (b P. ^ n - 1 ) (u - n)
113
+
114
+ flt_min :: P. RealFloat a => a
115
+ flt_min = x
116
+ where
117
+ n = P. floatDigits x
118
+ b = P. floatRadix x
119
+ (l,_) = P. floatRange x
120
+ x = P. encodeFloat (b P. ^ n - 1 ) (l - n - 1 )
121
+
122
+ -- except :: Gen e -> (e -> Bool) -> Gen e
123
+ -- except gen f = do
124
+ -- v <- gen
125
+ -- when (f v) Gen.discard
126
+ -- return v
127
+
128
+ prop_commutative :: Property
129
+ prop_commutative =
130
+ property $ do
131
+ sh <- forAll dim1
132
+ xs <- forAll (array sh f64)
133
+ ys <- forAll (array sh f64)
134
+ let go = CPU. runN (\ x y -> fromCheb (toCheb x + toCheb y))
135
+ go xs ys === go ys xs
136
+
137
+ prop_approx_error :: Property
138
+ prop_approx_error =
139
+ property $ do
140
+ term <- forAll genTerm
141
+ let f = termToAcc term
142
+ pol = chebf f 20 -- TODO
143
+ -- pol = chebfPrecise f -- TODO
144
+ err = approxError f pol
145
+ ok = A. all (A. < 1.0E-15 ) err
146
+ --
147
+ True === indexArray (CPU. run ok) Z
148
+
149
+ fullrange :: P. RealFloat e => (Range e -> Gen e ) -> Gen e
150
+ fullrange gen = gen (Range. linearFracFrom 0 (- flt_max) flt_max)
151
+
152
+ tests :: IO Bool
153
+ tests =
154
+ checkParallel $$ (discover)
10
155
11
- module HedgehogTest
12
- where
13
- import Hedgehog
14
- import qualified Hedgehog.Gen as Gen
15
- import qualified Hedgehog.Range as Range
16
- import ChebApproxAcc
17
- import Prelude as P
18
- import Data.Array.Accelerate as A
19
- import Data.Array.Accelerate.Debug as A
20
- import Data.Array.Accelerate.Interpreter as I
21
- import Data.Array.Accelerate.LLVM.Native as CPU
22
- import Data.Array.Accelerate.Array.Sugar as Sugar ( (!) , Arrays , Array , Shape , Elt , DIM0 , DIM1 , DIM2 , DIM3 , Z (.. ), (:.) (.. ), fromList , size )
23
- import Data.Array.Accelerate.Test.Similar
24
- import Test.Tasty
25
- import Test.Tasty.Hedgehog
26
- import Data.Typeable
27
-
28
-
29
- type Run = forall a . Arrays a = > Acc a -> a
30
- type RunN = forall f . Afunction f = > f -> AfunctionR f
31
-
32
- dim0 :: Gen DIM0
33
- dim0 = return Z
34
-
35
- dim1 :: Gen DIM1
36
- dim1 = (Z :. ) <$> Gen. int (Range. linear 0 1024 )
37
-
38
- array :: (Shape sh , Elt e ) => sh -> Gen e -> Gen (Array sh e )
39
- array sh gen = fromList sh <$> Gen. list (Range. singleton (Sugar. size sh)) gen
40
-
41
-
42
- flt_max :: A. RealFloat a => a
43
- flt_max = x
44
- where
45
- n = A. floatDigits x
46
- b = A. floatRadix x
47
- (_,u) = A. floatRange x
48
- x = A. encodeFloat (b A. ^ n - 1 ) (u - n)
49
-
50
- flt_min :: A. RealFloat a => a
51
- flt_min = x
52
- where
53
- n = A. floatDigits x
54
- b = A. floatRadix x
55
- (l,_) = A. floatRange x
56
- x = A. encodeFloat (b A. ^ n - 1 ) (l - n - 1 )
57
-
58
- {- except :: Gen e -> (e -> Bool) -> Gen e
59
- except gen f = do
60
- v <- gen
61
- when (f v) Gen.discard
62
- return v
63
-
64
- splitEvery :: Int -> [a] -> [[a]]
65
- splitEvery _ [] = cycle [[]]
66
- splitEvery n xs =
67
- let (h,t) = splitAt n xs
68
- in h : splitEvery n t
69
-
70
- splitPlaces :: [Int] -> [a] -> [[a]]
71
- splitPlaces [] _ = []
72
- splitPlaces (i:is) vs =
73
- let (h,t) = splitAt i vs
74
- in h : splitPlaces is t -}
75
-
76
-
77
- prop_commutative :: Property
78
- prop_commutative =
79
- property $ do
80
- xs <- forAll $ (Gen. list (Range. linear 0 100 ) (Gen. double (Range. linearFrac 0 2 )))
81
- ys <- forAll $ (Gen. list (Range. linear 0 100 ) (Gen. double (Range. linearFrac 0 2 )))
82
- (CPU. run $ use (fromList (Z :. (100 :: Int )) xs)) === (CPU. run $ use (fromList (Z :. (100 :: Int )) xs))
83
- -- (CPU.run $ (sumVectors ( use (fromList (Z :. 100) xs)) (use (fromList (Z :. 100 ) ys)))) === (CPU.run $ (sumVectors ( use (fromList (Z :. 100) xs)) (use (fromList (Z :. 100 ) ys))) )
84
-
85
- tests :: IO Bool
86
- tests =
87
- checkParallel $$ (discover)
88
-
89
-
90
- -- test = testProperty "commutative" $ test_sqrt CPU.runN sh (e (Range.linearFrac 0 flt_max))
91
-
92
- test_sqrt
93
- :: (Shape sh , Similar e , P. Eq sh , P. Floating e , A. Floating e )
94
- => RunN
95
- -> Gen sh
96
- -> Gen e
97
- -> Property
98
- test_sqrt runN dim e =
99
- property $ do
100
- sh <- forAll dim
101
- xs <- forAll (array sh e)
102
- let ! go = runN (A. map sqrt ) in go xs ~~~ mapRef sqrt xs
103
-
104
- mapRef :: (Shape sh , Elt a , Elt b ) => (a -> b ) -> Array sh a -> Array sh b
105
- mapRef f xs = fromFunction (arrayShape xs) (\ ix -> f (xs Sugar. ! ix))
106
-
107
- newtype TestDouble = TestDouble Bool deriving (P.Eq , P.Show , Typeable )
108
-
109
- test_map :: RunN -> TestTree
110
- test_map runN =
111
- testGroup " map"
112
- [
113
- at @ TestDouble $ testFloatingElt Gen. double
114
- ]
115
- where
116
-
117
- testFloatingElt
118
- :: forall a . (P. RealFloat a , A. Floating a , A. RealFrac a , Similar a )
119
- => (Range a -> Gen a )
120
- -> TestTree
121
- testFloatingElt e =
122
- testGroup (show (typeOf (undefined :: a )))
123
- [
124
- testDim dim1
125
- ]
126
- where
127
- testDim
128
- :: forall sh . (Shape sh , P. Eq sh )
129
- => Gen sh
130
- -> TestTree
131
- testDim sh =
132
- testGroup (" DIM" P. ++ show (rank @ sh ))
133
- [ -- operators on Num
134
- testProperty " sqrt" $ test_sqrt runN sh (e (Range. linearFrac 0 flt_max))
135
- ]
136
-
137
- fullrange :: P. RealFloat e => (Range e -> Gen e ) -> Gen e
138
- fullrange gen = gen (Range. linearFracFrom 0 (- flt_max) flt_max)
139
-
0 commit comments