1
1
{-# LANGUAGE AllowAmbiguousTypes #-}
2
+ {-# LANGUAGE CPP #-}
2
3
{-# LANGUAGE DataKinds #-}
3
4
{-# LANGUAGE DerivingVia #-}
4
5
{-# LANGUAGE GADTs #-}
5
6
{-# LANGUAGE LambdaCase #-}
6
7
{-# LANGUAGE OverloadedStrings #-}
7
8
{-# LANGUAGE PatternSynonyms #-}
9
+ {-# LANGUAGE ScopedTypeVariables #-}
8
10
{-# LANGUAGE ViewPatterns #-}
9
11
12
+ #if MIN_VERSION_random(1,3,0)
13
+ {-# OPTIONS_GHC -Wno-deprecations #-} -- Due to usage of `split`
14
+ #endif
10
15
-- | Generate example CBOR given a CDDL specification
11
16
module Codec.CBOR.Cuddle.CBOR.Gen (generateCBORTerm , generateCBORTerm' ) where
12
17
13
18
import Capability.Reader
14
19
import Capability.Sink (HasSink )
15
20
import Capability.Source (HasSource , MonadState (.. ))
16
- import Capability.State (HasState , get , modify , state )
21
+ import Capability.State (HasState , get , modify )
17
22
import Codec.CBOR.Cuddle.CDDL (
18
23
Name (.. ),
19
24
OccurrenceIndicator (.. ),
@@ -31,6 +36,7 @@ import Codec.CBOR.Write qualified as CBOR
31
36
import Control.Monad (join , replicateM , (<=<) )
32
37
import Control.Monad.Reader (Reader , runReader )
33
38
import Control.Monad.State.Strict (StateT , runStateT )
39
+ import Control.Monad.State.Strict qualified as MTL
34
40
import Data.Bifunctor (second )
35
41
import Data.ByteString (ByteString )
36
42
import Data.ByteString.Base16 qualified as Base16
@@ -45,24 +51,25 @@ import Data.Word (Word32, Word64)
45
51
import GHC.Generics (Generic )
46
52
import System.Random.Stateful (
47
53
Random ,
48
- RandomGen (genShortByteString , genWord32 , genWord64 ),
49
- RandomGenM ,
50
- StatefulGen (.. ),
54
+ RandomGen (.. ),
55
+ StateGenM (.. ),
51
56
UniformRange (uniformRM ),
52
- applyRandomGenM ,
53
57
randomM ,
54
58
uniformByteStringM ,
55
59
)
60
+ #if MIN_VERSION_random(1,3,0)
61
+ import System.Random.Stateful (
62
+ SplitGen (.. )
63
+ )
64
+ #endif
56
65
57
66
--------------------------------------------------------------------------------
58
67
-- Generator infrastructure
59
68
--------------------------------------------------------------------------------
60
69
61
70
-- | Generator context, parametrised over the type of the random seed
62
- data GenEnv g = GenEnv
71
+ newtype GenEnv = GenEnv
63
72
{ cddl :: CTreeRoot' Identity MonoRef
64
- , fakeSeed :: CapGenM g
65
- -- ^ Access the "fake" seed, necessary to recursively call generators
66
73
}
67
74
deriving (Generic )
68
75
@@ -77,63 +84,63 @@ data GenState g = GenState
77
84
}
78
85
deriving (Generic )
79
86
80
- newtype M g a = M { runM :: StateT (GenState g ) (Reader (GenEnv g )) a }
81
- deriving (Functor , Applicative , Monad )
87
+ instance RandomGen g => RandomGen (GenState g ) where
88
+ genWord8 = withRandomSeed genWord8
89
+ genWord16 = withRandomSeed genWord16
90
+ genWord32 = withRandomSeed genWord32
91
+ genWord64 = withRandomSeed genWord64
92
+ split = splitGenStateWith split
93
+
94
+ #if MIN_VERSION_random(1,3,0)
95
+ instance SplitGen g => SplitGen (GenState g ) where
96
+ splitGen = splitGenStateWith splitGen
97
+ #endif
98
+
99
+ splitGenStateWith :: (g -> (g , g )) -> GenState g -> (GenState g , GenState g )
100
+ splitGenStateWith f s =
101
+ case f (randomSeed s) of
102
+ (gen', gen) -> (s {randomSeed = gen'}, s {randomSeed = gen})
103
+
104
+ withRandomSeed :: (t -> (a , g )) -> GenState t -> (a , GenState g )
105
+ withRandomSeed f s =
106
+ case f (randomSeed s) of
107
+ (r, gen) -> (r, s {randomSeed = gen})
108
+
109
+ newtype M g a = M { runM :: StateT (GenState g ) (Reader GenEnv ) a }
110
+ deriving (Functor , Applicative , Monad , MTL.MonadState (GenState g))
82
111
deriving
83
112
(HasSource " randomSeed" g , HasSink " randomSeed" g , HasState " randomSeed" g )
84
113
via Field
85
114
" randomSeed"
86
115
()
87
- (MonadState (StateT (GenState g ) (Reader ( GenEnv g ) )))
116
+ (MonadState (StateT (GenState g ) (Reader GenEnv )))
88
117
deriving
89
118
(HasSource " depth" Int , HasSink " depth" Int , HasState " depth" Int )
90
119
via Field
91
120
" depth"
92
121
()
93
- (MonadState (StateT (GenState g ) (Reader ( GenEnv g ) )))
122
+ (MonadState (StateT (GenState g ) (Reader GenEnv )))
94
123
deriving
95
124
( HasSource " cddl" (CTreeRoot' Identity MonoRef )
96
125
, HasReader " cddl" (CTreeRoot' Identity MonoRef )
97
126
)
98
127
via Field
99
128
" cddl"
100
129
()
101
- (Lift (StateT (GenState g ) (MonadReader (Reader (GenEnv g )))))
102
- deriving
103
- (HasSource " fakeSeed" (CapGenM g ), HasReader " fakeSeed" (CapGenM g ))
104
- via Field
105
- " fakeSeed"
106
- ()
107
- (Lift (StateT (GenState g ) (MonadReader (Reader (GenEnv g )))))
108
-
109
- -- | Opaque type carrying the type of a pure PRNG inside a capability-style
110
- -- state monad.
111
- data CapGenM g = CapGenM
130
+ (Lift (StateT (GenState g ) (MonadReader (Reader GenEnv ))))
112
131
113
- instance RandomGen g => StatefulGen (CapGenM g ) (M g ) where
114
- uniformWord64 _ = state @ " randomSeed" genWord64
115
- uniformWord32 _ = state @ " randomSeed" genWord32
116
-
117
- uniformShortByteString n _ = state @ " randomSeed" (genShortByteString n)
118
-
119
- instance RandomGen r => RandomGenM (CapGenM r ) r (M r ) where
120
- applyRandomGenM f _ = state @ " randomSeed" f
121
-
122
- runGen :: M g a -> GenEnv g -> GenState g -> (a , GenState g )
132
+ runGen :: M g a -> GenEnv -> GenState g -> (a , GenState g )
123
133
runGen m env st = runReader (runStateT (runM m) st) env
124
134
125
- evalGen :: M g a -> GenEnv g -> GenState g -> a
135
+ evalGen :: M g a -> GenEnv -> GenState g -> a
126
136
evalGen m env = fst . runGen m env
127
137
128
- asksM :: forall tag r m a . HasReader tag r m => (r -> m a ) -> m a
129
- asksM f = f =<< ask @ tag
130
-
131
138
--------------------------------------------------------------------------------
132
139
-- Wrappers around some Random function in Gen
133
140
--------------------------------------------------------------------------------
134
141
135
142
genUniformRM :: forall a g . (UniformRange a , RandomGen g ) => (a , a ) -> M g a
136
- genUniformRM = asksM @ " fakeSeed " . uniformRM
143
+ genUniformRM r = uniformRM r ( StateGenM @ ( GenState g ))
137
144
138
145
-- | Generate a random number in a given range, biased increasingly towards the
139
146
-- lower end as the depth parameter increases.
@@ -143,9 +150,8 @@ genDepthBiasedRM ::
143
150
(a , a ) ->
144
151
M g a
145
152
genDepthBiasedRM bounds = do
146
- fs <- ask @ " fakeSeed"
147
153
d <- get @ " depth"
148
- samples <- replicateM d (uniformRM bounds fs )
154
+ samples <- replicateM d (genUniformRM bounds)
149
155
pure $ minimum samples
150
156
151
157
-- | Generates a bool, increasingly likely to be 'False' as the depth increases.
@@ -155,10 +161,10 @@ genDepthBiasedBool = do
155
161
and <$> replicateM d genRandomM
156
162
157
163
genRandomM :: forall g a . (Random a , RandomGen g ) => M g a
158
- genRandomM = asksM @ " fakeSeed " randomM
164
+ genRandomM = randomM ( StateGenM @ ( GenState g ))
159
165
160
166
genBytes :: forall g . RandomGen g => Int -> M g ByteString
161
- genBytes n = asksM @ " fakeSeed " $ uniformByteStringM n
167
+ genBytes n = uniformByteStringM n ( StateGenM @ ( GenState g ))
162
168
163
169
genText :: forall g . RandomGen g => Int -> M g Text
164
170
genText n = pure $ T. pack . take n . join $ repeat [' a' .. ' z' ]
@@ -436,12 +442,12 @@ genValueVariant (VBool b) = pure $ TBool b
436
442
437
443
generateCBORTerm :: RandomGen g => CTreeRoot' Identity MonoRef -> Name -> g -> Term
438
444
generateCBORTerm cddl n stdGen =
439
- let genEnv = GenEnv {cddl, fakeSeed = CapGenM }
445
+ let genEnv = GenEnv {cddl}
440
446
genState = GenState {randomSeed = stdGen, depth = 1 }
441
447
in evalGen (genForName n) genEnv genState
442
448
443
449
generateCBORTerm' :: RandomGen g => CTreeRoot' Identity MonoRef -> Name -> g -> (Term , g )
444
450
generateCBORTerm' cddl n stdGen =
445
- let genEnv = GenEnv {cddl, fakeSeed = CapGenM }
451
+ let genEnv = GenEnv {cddl}
446
452
genState = GenState {randomSeed = stdGen, depth = 1 }
447
453
in second randomSeed $ runGen (genForName n) genEnv genState
0 commit comments