Skip to content

Commit 1f8cee0

Browse files
committed
hedgehog random function generator
1 parent 50a77a2 commit 1f8cee0

File tree

2 files changed

+160
-144
lines changed

2 files changed

+160
-144
lines changed

chebApprox.cabal

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,26 @@ library
1515
exposed-modules: ChebApproxAcc, ChebTypingAcc, ChebMath, HedgehogTest
1616
hs-source-dirs: src
1717
build-depends: base >=4.12 && <4.13,
18-
accelerate,
19-
accelerate-llvm-native,
20-
hedgehog,
18+
accelerate,
19+
accelerate-llvm-native,
20+
hedgehog,
2121
tasty-hedgehog,
22-
tasty,
23-
clock,
24-
formatting,
25-
vector
22+
tasty,
23+
clock,
24+
formatting,
25+
vector
2626
default-language: Haskell2010
2727

2828
executable test
2929
main-is: Test.hs
3030
hs-source-dirs: test
3131
build-depends: base >=4.12 && <4.13,
32-
HUnit,
32+
HUnit,
3333
accelerate,
3434
accelerate-fft,
3535
accelerate-llvm-native,
36-
chebApprox,
37-
lens-accelerate,
36+
chebApprox,
37+
lens-accelerate,
3838
diagrams-lib,
3939
diagrams-cairo,
4040
Chart,

src/HedgehogTest.hs

Lines changed: 150 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -1,139 +1,155 @@
1-
{-# LANGUAGE TemplateHaskell #-}
2-
{-# LANGUAGE RankNTypes #-}
3-
{-# LANGUAGE ScopedTypeVariables #-}
4-
{-# LANGUAGE TypeOperators #-}
51
{-# LANGUAGE BangPatterns #-}
6-
{-# LANGUAGE TypeApplications #-}
72
{-# LANGUAGE DeriveDataTypeable #-}
83
{-# LANGUAGE FlexibleContexts #-}
4+
{-# LANGUAGE RankNTypes #-}
5+
{-# LANGUAGE ScopedTypeVariables #-}
6+
{-# LANGUAGE TemplateHaskell #-}
7+
{-# LANGUAGE TypeApplications #-}
8+
{-# LANGUAGE TypeOperators #-}
9+
910

11+
module HedgehogTest where
12+
13+
import ChebApproxAcc
14+
import ChebTypingAcc
15+
16+
import Hedgehog
17+
import qualified Hedgehog.Gen as Gen
18+
import qualified Hedgehog.Range as Range
19+
import Prelude as P
20+
import Data.Array.Accelerate as A
21+
import Data.Array.Accelerate.Debug as A
22+
import Data.Array.Accelerate.Interpreter as I
23+
import Data.Array.Accelerate.LLVM.Native as CPU
24+
import Data.Array.Accelerate.Array.Sugar as Sugar( (!), Arrays, Array, Shape, Elt, DIM0, DIM1, DIM2, DIM3, Z(..), (:.)(..), fromList, size )
25+
import Data.Array.Accelerate.Test.Similar
26+
import Test.Tasty
27+
import Test.Tasty.Hedgehog
28+
import Data.Typeable
29+
30+
31+
{--
32+
data Expr =
33+
Var String
34+
| Lam String Expr
35+
| App Expr Expr
36+
37+
-- Assuming we have a name generator
38+
genName :: MonadGen m => m String
39+
40+
-- We can write a generator for expressions
41+
genExpr :: MonadGen m => m Expr
42+
genExpr =
43+
Gen.recursive Gen.choice [
44+
-- non-recursive generators
45+
Var <$> genName
46+
] [
47+
-- recursive generators
48+
Gen.subtermM genExpr (x -> Lam <$> genName <*> pure x)
49+
, Gen.subterm2 genExpr genExpr App
50+
]
51+
--}
52+
53+
54+
-- Operator terms we want to generate random combinations of
55+
data Term
56+
= Val
57+
| Sin Term
58+
| Cos Term
59+
| Exponential Term
60+
| Divide Term Term
61+
deriving Show
62+
63+
-- we can write a generator for expressions
64+
genTerm :: MonadGen m => m Term
65+
genTerm =
66+
Gen.recursive Gen.choice
67+
-- non-recursive generators
68+
[ pure Val ]
69+
-- recursive generators
70+
[ Gen.subtermM genTerm (pure . Sin)
71+
, Gen.subtermM genTerm (pure . Cos)
72+
, Gen.subtermM genTerm (pure . Exponential)
73+
]
74+
75+
evalTerm :: Term -> Double -> Double
76+
evalTerm Val x = x
77+
evalTerm (Sin t) x = P.sin (evalTerm t x)
78+
evalTerm (Cos t) x = P.cos (evalTerm t x)
79+
evalTerm (Exponential t) x = P.exp (evalTerm t x)
80+
81+
termToAcc :: Term -> Exp Double -> Exp Double
82+
termToAcc Val x = x
83+
termToAcc (Sin t) x = A.sin (termToAcc t x)
84+
termToAcc (Cos t) x = A.cos (termToAcc t x)
85+
termToAcc (Exponential t) x = A.exp (termToAcc t x)
86+
87+
88+
type Run = forall a. Arrays a => Acc a -> a
89+
type RunN = forall f. Afunction f => f -> AfunctionR f
90+
91+
dim0 :: Gen DIM0
92+
dim0 = return Z
93+
94+
dim1 :: Gen DIM1
95+
dim1 = (Z :.) <$> Gen.int (Range.linear 0 1024)
96+
97+
f64 :: Gen Double
98+
f64 = Gen.double (Range.linearFracFrom 0 (-log_flt_max) log_flt_max)
99+
100+
array :: (Shape sh, Elt e) => sh -> Gen e -> Gen (Array sh e)
101+
array sh gen = fromList sh <$> Gen.list (Range.singleton (Sugar.size sh)) gen
102+
103+
log_flt_max :: P.RealFloat a => a
104+
log_flt_max = log flt_max
105+
106+
flt_max :: P.RealFloat a => a
107+
flt_max = x
108+
where
109+
n = P.floatDigits x
110+
b = P.floatRadix x
111+
(_,u) = P.floatRange x
112+
x = P.encodeFloat (b P.^n - 1) (u - n)
113+
114+
flt_min :: P.RealFloat a => a
115+
flt_min = x
116+
where
117+
n = P.floatDigits x
118+
b = P.floatRadix x
119+
(l,_) = P.floatRange x
120+
x = P.encodeFloat (b P.^n - 1) (l - n - 1)
121+
122+
-- except :: Gen e -> (e -> Bool) -> Gen e
123+
-- except gen f = do
124+
-- v <- gen
125+
-- when (f v) Gen.discard
126+
-- return v
127+
128+
prop_commutative :: Property
129+
prop_commutative =
130+
property $ do
131+
sh <- forAll dim1
132+
xs <- forAll (array sh f64)
133+
ys <- forAll (array sh f64)
134+
let go = CPU.runN (\x y -> fromCheb (toCheb x + toCheb y))
135+
go xs ys === go ys xs
136+
137+
prop_approx_error :: Property
138+
prop_approx_error =
139+
property $ do
140+
term <- forAll genTerm
141+
let f = termToAcc term
142+
pol = chebf f 20 -- TODO
143+
-- pol = chebfPrecise f -- TODO
144+
err = approxError f pol
145+
ok = A.all (A.< 1.0E-15) err
146+
--
147+
True === indexArray (CPU.run ok) Z
148+
149+
fullrange :: P.RealFloat e => (Range e -> Gen e) -> Gen e
150+
fullrange gen = gen (Range.linearFracFrom 0 (-flt_max) flt_max)
151+
152+
tests :: IO Bool
153+
tests =
154+
checkParallel $$(discover)
10155

11-
module HedgehogTest
12-
where
13-
import Hedgehog
14-
import qualified Hedgehog.Gen as Gen
15-
import qualified Hedgehog.Range as Range
16-
import ChebApproxAcc
17-
import Prelude as P
18-
import Data.Array.Accelerate as A
19-
import Data.Array.Accelerate.Debug as A
20-
import Data.Array.Accelerate.Interpreter as I
21-
import Data.Array.Accelerate.LLVM.Native as CPU
22-
import Data.Array.Accelerate.Array.Sugar as Sugar( (!), Arrays, Array, Shape, Elt, DIM0, DIM1, DIM2, DIM3, Z(..), (:.)(..), fromList, size )
23-
import Data.Array.Accelerate.Test.Similar
24-
import Test.Tasty
25-
import Test.Tasty.Hedgehog
26-
import Data.Typeable
27-
28-
29-
type Run = forall a. Arrays a => Acc a -> a
30-
type RunN = forall f. Afunction f => f -> AfunctionR f
31-
32-
dim0 :: Gen DIM0
33-
dim0 = return Z
34-
35-
dim1 :: Gen DIM1
36-
dim1 = (Z :.) <$> Gen.int (Range.linear 0 1024)
37-
38-
array :: (Shape sh, Elt e) => sh -> Gen e -> Gen (Array sh e)
39-
array sh gen = fromList sh <$> Gen.list (Range.singleton (Sugar.size sh)) gen
40-
41-
42-
flt_max :: A.RealFloat a => a
43-
flt_max = x
44-
where
45-
n = A.floatDigits x
46-
b = A.floatRadix x
47-
(_,u) = A.floatRange x
48-
x = A.encodeFloat (b A.^n - 1) (u - n)
49-
50-
flt_min :: A.RealFloat a => a
51-
flt_min = x
52-
where
53-
n = A.floatDigits x
54-
b = A.floatRadix x
55-
(l,_) = A.floatRange x
56-
x = A.encodeFloat (b A.^n - 1) (l - n - 1)
57-
58-
{- except :: Gen e -> (e -> Bool) -> Gen e
59-
except gen f = do
60-
v <- gen
61-
when (f v) Gen.discard
62-
return v
63-
64-
splitEvery :: Int -> [a] -> [[a]]
65-
splitEvery _ [] = cycle [[]]
66-
splitEvery n xs =
67-
let (h,t) = splitAt n xs
68-
in h : splitEvery n t
69-
70-
splitPlaces :: [Int] -> [a] -> [[a]]
71-
splitPlaces [] _ = []
72-
splitPlaces (i:is) vs =
73-
let (h,t) = splitAt i vs
74-
in h : splitPlaces is t -}
75-
76-
77-
prop_commutative :: Property
78-
prop_commutative =
79-
property $ do
80-
xs <- forAll $ (Gen.list (Range.linear 0 100) (Gen.double (Range.linearFrac 0 2)))
81-
ys <- forAll $ (Gen.list (Range.linear 0 100) (Gen.double (Range.linearFrac 0 2)))
82-
(CPU.run $ use (fromList (Z :. (100::Int)) xs)) === (CPU.run $ use (fromList (Z :. (100::Int)) xs))
83-
--(CPU.run $ (sumVectors ( use (fromList (Z :. 100) xs)) (use (fromList (Z :. 100 ) ys)))) === (CPU.run $ (sumVectors ( use (fromList (Z :. 100) xs)) (use (fromList (Z :. 100 ) ys))) )
84-
85-
tests :: IO Bool
86-
tests =
87-
checkParallel $$(discover)
88-
89-
90-
--test = testProperty "commutative" $ test_sqrt CPU.runN sh (e (Range.linearFrac 0 flt_max))
91-
92-
test_sqrt
93-
:: (Shape sh, Similar e, P.Eq sh, P.Floating e, A.Floating e)
94-
=> RunN
95-
-> Gen sh
96-
-> Gen e
97-
-> Property
98-
test_sqrt runN dim e =
99-
property $ do
100-
sh <- forAll dim
101-
xs <- forAll (array sh e)
102-
let !go = runN (A.map sqrt) in go xs ~~~ mapRef sqrt xs
103-
104-
mapRef :: (Shape sh, Elt a, Elt b) => (a -> b) -> Array sh a -> Array sh b
105-
mapRef f xs = fromFunction (arrayShape xs) (\ix -> f (xs Sugar.! ix))
106-
107-
newtype TestDouble = TestDouble Bool deriving (P.Eq, P.Show, Typeable)
108-
109-
test_map :: RunN -> TestTree
110-
test_map runN =
111-
testGroup "map"
112-
[
113-
at @TestDouble $ testFloatingElt Gen.double
114-
]
115-
where
116-
117-
testFloatingElt
118-
:: forall a. (P.RealFloat a, A.Floating a, A.RealFrac a, Similar a)
119-
=> (Range a -> Gen a)
120-
-> TestTree
121-
testFloatingElt e =
122-
testGroup (show (typeOf (undefined :: a)))
123-
[
124-
testDim dim1
125-
]
126-
where
127-
testDim
128-
:: forall sh. (Shape sh, P.Eq sh)
129-
=> Gen sh
130-
-> TestTree
131-
testDim sh =
132-
testGroup ("DIM" P.++ show (rank @sh))
133-
[ -- operators on Num
134-
testProperty "sqrt" $ test_sqrt runN sh (e (Range.linearFrac 0 flt_max))
135-
]
136-
137-
fullrange :: P.RealFloat e => (Range e -> Gen e) -> Gen e
138-
fullrange gen = gen (Range.linearFracFrom 0 (-flt_max) flt_max)
139-

0 commit comments

Comments
 (0)