Skip to content

Commit adcc5e8

Browse files
authored
Merge pull request #84 from input-output-hk/lehins/simplify-random-usage
Alternative usage of `random`
2 parents d704372 + c30a073 commit adcc5e8

File tree

2 files changed

+51
-45
lines changed

2 files changed

+51
-45
lines changed

cuddle.cabal

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ library
8181
ordered-containers ^>=0.2.4,
8282
parser-combinators ^>=1.3,
8383
prettyprinter ^>=1.7.1,
84-
random <1.3,
84+
random >=1.2,
8585
regex-tdfa ^>=1.3.2,
8686
scientific ^>=0.3.8,
8787
text >=2.0.2 && <2.2,

src/Codec/CBOR/Cuddle/CBOR/Gen.hs

Lines changed: 50 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,24 @@
11
{-# LANGUAGE AllowAmbiguousTypes #-}
2+
{-# LANGUAGE CPP #-}
23
{-# LANGUAGE DataKinds #-}
34
{-# LANGUAGE DerivingVia #-}
45
{-# LANGUAGE GADTs #-}
56
{-# LANGUAGE LambdaCase #-}
67
{-# LANGUAGE OverloadedStrings #-}
78
{-# LANGUAGE PatternSynonyms #-}
9+
{-# LANGUAGE ScopedTypeVariables #-}
810
{-# LANGUAGE ViewPatterns #-}
911

12+
#if MIN_VERSION_random(1,3,0)
13+
{-# OPTIONS_GHC -Wno-deprecations #-} -- Due to usage of `split`
14+
#endif
1015
-- | Generate example CBOR given a CDDL specification
1116
module Codec.CBOR.Cuddle.CBOR.Gen (generateCBORTerm, generateCBORTerm') where
1217

1318
import Capability.Reader
1419
import Capability.Sink (HasSink)
1520
import Capability.Source (HasSource, MonadState (..))
16-
import Capability.State (HasState, get, modify, state)
21+
import Capability.State (HasState, get, modify)
1722
import Codec.CBOR.Cuddle.CDDL (
1823
Name (..),
1924
OccurrenceIndicator (..),
@@ -31,6 +36,7 @@ import Codec.CBOR.Write qualified as CBOR
3136
import Control.Monad (join, replicateM, (<=<))
3237
import Control.Monad.Reader (Reader, runReader)
3338
import Control.Monad.State.Strict (StateT, runStateT)
39+
import Control.Monad.State.Strict qualified as MTL
3440
import Data.Bifunctor (second)
3541
import Data.ByteString (ByteString)
3642
import Data.ByteString.Base16 qualified as Base16
@@ -45,24 +51,25 @@ import Data.Word (Word32, Word64)
4551
import GHC.Generics (Generic)
4652
import System.Random.Stateful (
4753
Random,
48-
RandomGen (genShortByteString, genWord32, genWord64),
49-
RandomGenM,
50-
StatefulGen (..),
54+
RandomGen (..),
55+
StateGenM (..),
5156
UniformRange (uniformRM),
52-
applyRandomGenM,
5357
randomM,
5458
uniformByteStringM,
5559
)
60+
#if MIN_VERSION_random(1,3,0)
61+
import System.Random.Stateful (
62+
SplitGen (..)
63+
)
64+
#endif
5665

5766
--------------------------------------------------------------------------------
5867
-- Generator infrastructure
5968
--------------------------------------------------------------------------------
6069

6170
-- | Generator context, parametrised over the type of the random seed
62-
data GenEnv g = GenEnv
71+
newtype GenEnv = GenEnv
6372
{ cddl :: CTreeRoot' Identity MonoRef
64-
, fakeSeed :: CapGenM g
65-
-- ^ Access the "fake" seed, necessary to recursively call generators
6673
}
6774
deriving (Generic)
6875

@@ -77,63 +84,63 @@ data GenState g = GenState
7784
}
7885
deriving (Generic)
7986

80-
newtype M g a = M {runM :: StateT (GenState g) (Reader (GenEnv g)) a}
81-
deriving (Functor, Applicative, Monad)
87+
instance RandomGen g => RandomGen (GenState g) where
88+
genWord8 = withRandomSeed genWord8
89+
genWord16 = withRandomSeed genWord16
90+
genWord32 = withRandomSeed genWord32
91+
genWord64 = withRandomSeed genWord64
92+
split = splitGenStateWith split
93+
94+
#if MIN_VERSION_random(1,3,0)
95+
instance SplitGen g => SplitGen (GenState g) where
96+
splitGen = splitGenStateWith splitGen
97+
#endif
98+
99+
splitGenStateWith :: (g -> (g, g)) -> GenState g -> (GenState g, GenState g)
100+
splitGenStateWith f s =
101+
case f (randomSeed s) of
102+
(gen', gen) -> (s {randomSeed = gen'}, s {randomSeed = gen})
103+
104+
withRandomSeed :: (t -> (a, g)) -> GenState t -> (a, GenState g)
105+
withRandomSeed f s =
106+
case f (randomSeed s) of
107+
(r, gen) -> (r, s {randomSeed = gen})
108+
109+
newtype M g a = M {runM :: StateT (GenState g) (Reader GenEnv) a}
110+
deriving (Functor, Applicative, Monad, MTL.MonadState (GenState g))
82111
deriving
83112
(HasSource "randomSeed" g, HasSink "randomSeed" g, HasState "randomSeed" g)
84113
via Field
85114
"randomSeed"
86115
()
87-
(MonadState (StateT (GenState g) (Reader (GenEnv g))))
116+
(MonadState (StateT (GenState g) (Reader GenEnv)))
88117
deriving
89118
(HasSource "depth" Int, HasSink "depth" Int, HasState "depth" Int)
90119
via Field
91120
"depth"
92121
()
93-
(MonadState (StateT (GenState g) (Reader (GenEnv g))))
122+
(MonadState (StateT (GenState g) (Reader GenEnv)))
94123
deriving
95124
( HasSource "cddl" (CTreeRoot' Identity MonoRef)
96125
, HasReader "cddl" (CTreeRoot' Identity MonoRef)
97126
)
98127
via Field
99128
"cddl"
100129
()
101-
(Lift (StateT (GenState g) (MonadReader (Reader (GenEnv g)))))
102-
deriving
103-
(HasSource "fakeSeed" (CapGenM g), HasReader "fakeSeed" (CapGenM g))
104-
via Field
105-
"fakeSeed"
106-
()
107-
(Lift (StateT (GenState g) (MonadReader (Reader (GenEnv g)))))
108-
109-
-- | Opaque type carrying the type of a pure PRNG inside a capability-style
110-
-- state monad.
111-
data CapGenM g = CapGenM
130+
(Lift (StateT (GenState g) (MonadReader (Reader GenEnv))))
112131

113-
instance RandomGen g => StatefulGen (CapGenM g) (M g) where
114-
uniformWord64 _ = state @"randomSeed" genWord64
115-
uniformWord32 _ = state @"randomSeed" genWord32
116-
117-
uniformShortByteString n _ = state @"randomSeed" (genShortByteString n)
118-
119-
instance RandomGen r => RandomGenM (CapGenM r) r (M r) where
120-
applyRandomGenM f _ = state @"randomSeed" f
121-
122-
runGen :: M g a -> GenEnv g -> GenState g -> (a, GenState g)
132+
runGen :: M g a -> GenEnv -> GenState g -> (a, GenState g)
123133
runGen m env st = runReader (runStateT (runM m) st) env
124134

125-
evalGen :: M g a -> GenEnv g -> GenState g -> a
135+
evalGen :: M g a -> GenEnv -> GenState g -> a
126136
evalGen m env = fst . runGen m env
127137

128-
asksM :: forall tag r m a. HasReader tag r m => (r -> m a) -> m a
129-
asksM f = f =<< ask @tag
130-
131138
--------------------------------------------------------------------------------
132139
-- Wrappers around some Random function in Gen
133140
--------------------------------------------------------------------------------
134141

135142
genUniformRM :: forall a g. (UniformRange a, RandomGen g) => (a, a) -> M g a
136-
genUniformRM = asksM @"fakeSeed" . uniformRM
143+
genUniformRM r = uniformRM r (StateGenM @(GenState g))
137144

138145
-- | Generate a random number in a given range, biased increasingly towards the
139146
-- lower end as the depth parameter increases.
@@ -143,9 +150,8 @@ genDepthBiasedRM ::
143150
(a, a) ->
144151
M g a
145152
genDepthBiasedRM bounds = do
146-
fs <- ask @"fakeSeed"
147153
d <- get @"depth"
148-
samples <- replicateM d (uniformRM bounds fs)
154+
samples <- replicateM d (genUniformRM bounds)
149155
pure $ minimum samples
150156

151157
-- | Generates a bool, increasingly likely to be 'False' as the depth increases.
@@ -155,10 +161,10 @@ genDepthBiasedBool = do
155161
and <$> replicateM d genRandomM
156162

157163
genRandomM :: forall g a. (Random a, RandomGen g) => M g a
158-
genRandomM = asksM @"fakeSeed" randomM
164+
genRandomM = randomM (StateGenM @(GenState g))
159165

160166
genBytes :: forall g. RandomGen g => Int -> M g ByteString
161-
genBytes n = asksM @"fakeSeed" $ uniformByteStringM n
167+
genBytes n = uniformByteStringM n (StateGenM @(GenState g))
162168

163169
genText :: forall g. RandomGen g => Int -> M g Text
164170
genText n = pure $ T.pack . take n . join $ repeat ['a' .. 'z']
@@ -436,12 +442,12 @@ genValueVariant (VBool b) = pure $ TBool b
436442

437443
generateCBORTerm :: RandomGen g => CTreeRoot' Identity MonoRef -> Name -> g -> Term
438444
generateCBORTerm cddl n stdGen =
439-
let genEnv = GenEnv {cddl, fakeSeed = CapGenM}
445+
let genEnv = GenEnv {cddl}
440446
genState = GenState {randomSeed = stdGen, depth = 1}
441447
in evalGen (genForName n) genEnv genState
442448

443449
generateCBORTerm' :: RandomGen g => CTreeRoot' Identity MonoRef -> Name -> g -> (Term, g)
444450
generateCBORTerm' cddl n stdGen =
445-
let genEnv = GenEnv {cddl, fakeSeed = CapGenM}
451+
let genEnv = GenEnv {cddl}
446452
genState = GenState {randomSeed = stdGen, depth = 1}
447453
in second randomSeed $ runGen (genForName n) genEnv genState

0 commit comments

Comments
 (0)