Skip to content

Commit f001a22

Browse files
committed
Simplify usage of random, by removing fakeSeed
1 parent a950eb2 commit f001a22

File tree

1 file changed

+15
-25
lines changed
  • src/Codec/CBOR/Cuddle/CBOR

1 file changed

+15
-25
lines changed

src/Codec/CBOR/Cuddle/CBOR/Gen.hs

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
{-# LANGUAGE LambdaCase #-}
77
{-# LANGUAGE OverloadedStrings #-}
88
{-# LANGUAGE PatternSynonyms #-}
9+
{-# LANGUAGE ScopedTypeVariables #-}
10+
{-# LANGUAGE TypeApplications #-}
911
{-# LANGUAGE TypeFamilies #-}
1012
{-# LANGUAGE ViewPatterns #-}
1113

@@ -70,10 +72,8 @@ import System.Random.Stateful (
7072
--------------------------------------------------------------------------------
7173

7274
-- | Generator context, parametrised over the type of the random seed
73-
data GenEnv g = GenEnv
75+
newtype GenEnv = GenEnv
7476
{ cddl :: CTreeRoot' Identity MonoRef
75-
, fakeSeed :: CapGenM g
76-
-- ^ Access the "fake" seed, necessary to recursively call generators
7777
}
7878
deriving (Generic)
7979

@@ -88,34 +88,28 @@ data GenState g = GenState
8888
}
8989
deriving (Generic)
9090

91-
newtype M g a = M {runM :: StateT (GenState g) (Reader (GenEnv g)) a}
91+
newtype M g a = M {runM :: StateT (GenState g) (Reader GenEnv) a}
9292
deriving (Functor, Applicative, Monad)
9393
deriving
9494
(HasSource "randomSeed" g, HasSink "randomSeed" g, HasState "randomSeed" g)
9595
via Field
9696
"randomSeed"
9797
()
98-
(MonadState (StateT (GenState g) (Reader (GenEnv g))))
98+
(MonadState (StateT (GenState g) (Reader GenEnv)))
9999
deriving
100100
(HasSource "depth" Int, HasSink "depth" Int, HasState "depth" Int)
101101
via Field
102102
"depth"
103103
()
104-
(MonadState (StateT (GenState g) (Reader (GenEnv g))))
104+
(MonadState (StateT (GenState g) (Reader GenEnv)))
105105
deriving
106106
( HasSource "cddl" (CTreeRoot' Identity MonoRef)
107107
, HasReader "cddl" (CTreeRoot' Identity MonoRef)
108108
)
109109
via Field
110110
"cddl"
111111
()
112-
(Lift (StateT (GenState g) (MonadReader (Reader (GenEnv g)))))
113-
deriving
114-
(HasSource "fakeSeed" (CapGenM g), HasReader "fakeSeed" (CapGenM g))
115-
via Field
116-
"fakeSeed"
117-
()
118-
(Lift (StateT (GenState g) (MonadReader (Reader (GenEnv g)))))
112+
(Lift (StateT (GenState g) (MonadReader (Reader GenEnv))))
119113

120114
-- | Opaque type carrying the type of a pure PRNG inside a capability-style
121115
-- state monad.
@@ -143,21 +137,18 @@ instance RandomGen r => RandomGenM (CapGenM r) r (M r) where
143137
applyRandomGenM f _ = state @"randomSeed" f
144138
#endif
145139

146-
runGen :: M g a -> GenEnv g -> GenState g -> (a, GenState g)
140+
runGen :: M g a -> GenEnv -> GenState g -> (a, GenState g)
147141
runGen m env st = runReader (runStateT (runM m) st) env
148142

149-
evalGen :: M g a -> GenEnv g -> GenState g -> a
143+
evalGen :: M g a -> GenEnv -> GenState g -> a
150144
evalGen m env = fst . runGen m env
151145

152-
asksM :: forall tag r m a. HasReader tag r m => (r -> m a) -> m a
153-
asksM f = f =<< ask @tag
154-
155146
--------------------------------------------------------------------------------
156147
-- Wrappers around some Random function in Gen
157148
--------------------------------------------------------------------------------
158149

159150
genUniformRM :: forall a g. (UniformRange a, RandomGen g) => (a, a) -> M g a
160-
genUniformRM = asksM @"fakeSeed" . uniformRM
151+
genUniformRM r = uniformRM r (CapGenM @g)
161152

162153
-- | Generate a random number in a given range, biased increasingly towards the
163154
-- lower end as the depth parameter increases.
@@ -167,9 +158,8 @@ genDepthBiasedRM ::
167158
(a, a) ->
168159
M g a
169160
genDepthBiasedRM bounds = do
170-
fs <- ask @"fakeSeed"
171161
d <- get @"depth"
172-
samples <- replicateM d (uniformRM bounds fs)
162+
samples <- replicateM d (genUniformRM bounds)
173163
pure $ minimum samples
174164

175165
-- | Generates a bool, increasingly likely to be 'False' as the depth increases.
@@ -179,10 +169,10 @@ genDepthBiasedBool = do
179169
and <$> replicateM d genRandomM
180170

181171
genRandomM :: forall g a. (Random a, RandomGen g) => M g a
182-
genRandomM = asksM @"fakeSeed" randomM
172+
genRandomM = randomM (CapGenM @g)
183173

184174
genBytes :: forall g. RandomGen g => Int -> M g ByteString
185-
genBytes n = asksM @"fakeSeed" $ uniformByteStringM n
175+
genBytes n = uniformByteStringM n (CapGenM @g)
186176

187177
genText :: forall g. RandomGen g => Int -> M g Text
188178
genText n = pure $ T.pack . take n . join $ repeat ['a' .. 'z']
@@ -460,12 +450,12 @@ genValueVariant (VBool b) = pure $ TBool b
460450

461451
generateCBORTerm :: RandomGen g => CTreeRoot' Identity MonoRef -> Name -> g -> Term
462452
generateCBORTerm cddl n stdGen =
463-
let genEnv = GenEnv {cddl, fakeSeed = CapGenM}
453+
let genEnv = GenEnv {cddl}
464454
genState = GenState {randomSeed = stdGen, depth = 1}
465455
in evalGen (genForName n) genEnv genState
466456

467457
generateCBORTerm' :: RandomGen g => CTreeRoot' Identity MonoRef -> Name -> g -> (Term, g)
468458
generateCBORTerm' cddl n stdGen =
469-
let genEnv = GenEnv {cddl, fakeSeed = CapGenM}
459+
let genEnv = GenEnv {cddl}
470460
genState = GenState {randomSeed = stdGen, depth = 1}
471461
in second randomSeed $ runGen (genForName n) genEnv genState

0 commit comments

Comments
 (0)