Skip to content

random-1.3 support #85

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cuddle.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ library
ordered-containers ^>=0.2.4,
parser-combinators ^>=1.3,
prettyprinter ^>=1.7.1,
random <1.3,
random >=1.2,
regex-tdfa ^>=1.3.2,
scientific ^>=0.3.8,
text >=2.0.2 && <2.2,
Expand Down
72 changes: 43 additions & 29 deletions src/Codec/CBOR/Cuddle/CBOR/Gen.hs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}

-- | Generate example CBOR given a CDDL specification
Expand Down Expand Up @@ -45,24 +49,31 @@ import Data.Word (Word32, Word64)
import GHC.Generics (Generic)
import System.Random.Stateful (
Random,
RandomGen (genShortByteString, genWord32, genWord64),
RandomGenM,
RandomGen (..),
StatefulGen (..),
UniformRange (uniformRM),
applyRandomGenM,
randomM,
uniformByteStringM,
)

#if MIN_VERSION_random(1,3,0)
import Data.Coerce
import System.Random.Stateful (
FrozenGen (..),
uniformByteArray,
)
#else
import System.Random.Stateful (
RandomGenM,
applyRandomGenM,
)
#endif
--------------------------------------------------------------------------------
-- Generator infrastructure
--------------------------------------------------------------------------------

-- | Generator context, parametrised over the type of the random seed
data GenEnv g = GenEnv
newtype GenEnv = GenEnv
{ cddl :: CTreeRoot' Identity MonoRef
, fakeSeed :: CapGenM g
-- ^ Access the "fake" seed, necessary to recursively call generators
}
deriving (Generic)

Expand All @@ -77,63 +88,67 @@ data GenState g = GenState
}
deriving (Generic)

newtype M g a = M {runM :: StateT (GenState g) (Reader (GenEnv g)) a}
newtype M g a = M {runM :: StateT (GenState g) (Reader GenEnv) a}
deriving (Functor, Applicative, Monad)
deriving
(HasSource "randomSeed" g, HasSink "randomSeed" g, HasState "randomSeed" g)
via Field
"randomSeed"
()
(MonadState (StateT (GenState g) (Reader (GenEnv g))))
(MonadState (StateT (GenState g) (Reader GenEnv)))
deriving
(HasSource "depth" Int, HasSink "depth" Int, HasState "depth" Int)
via Field
"depth"
()
(MonadState (StateT (GenState g) (Reader (GenEnv g))))
(MonadState (StateT (GenState g) (Reader GenEnv)))
deriving
( HasSource "cddl" (CTreeRoot' Identity MonoRef)
, HasReader "cddl" (CTreeRoot' Identity MonoRef)
)
via Field
"cddl"
()
(Lift (StateT (GenState g) (MonadReader (Reader (GenEnv g)))))
deriving
(HasSource "fakeSeed" (CapGenM g), HasReader "fakeSeed" (CapGenM g))
via Field
"fakeSeed"
()
(Lift (StateT (GenState g) (MonadReader (Reader (GenEnv g)))))
(Lift (StateT (GenState g) (MonadReader (Reader GenEnv))))

-- | Opaque type carrying the type of a pure PRNG inside a capability-style
-- state monad.
data CapGenM g = CapGenM

newtype CapGen g = CapGen g
deriving (RandomGen, Eq)

instance RandomGen g => StatefulGen (CapGenM g) (M g) where
uniformWord64 _ = state @"randomSeed" genWord64
uniformWord32 _ = state @"randomSeed" genWord32

#if MIN_VERSION_random(1,3,0)
uniformByteArrayM isPinned n _ = state @"randomSeed" (uniformByteArray isPinned n)
#else
uniformShortByteString n _ = state @"randomSeed" (genShortByteString n)
#endif

#if MIN_VERSION_random(1,3,0)
instance RandomGen r => FrozenGen (CapGen r) (M r) where
type MutableGen (CapGen r) (M r) = CapGenM r
modifyGen CapGenM f = state @"randomSeed" (coerce f)
#else
instance RandomGen r => RandomGenM (CapGenM r) r (M r) where
applyRandomGenM f _ = state @"randomSeed" f
#endif

runGen :: M g a -> GenEnv g -> GenState g -> (a, GenState g)
runGen :: M g a -> GenEnv -> GenState g -> (a, GenState g)
runGen m env st = runReader (runStateT (runM m) st) env

evalGen :: M g a -> GenEnv g -> GenState g -> a
evalGen :: M g a -> GenEnv -> GenState g -> a
evalGen m env = fst . runGen m env

asksM :: forall tag r m a. HasReader tag r m => (r -> m a) -> m a
asksM f = f =<< ask @tag

--------------------------------------------------------------------------------
-- Wrappers around some Random function in Gen
--------------------------------------------------------------------------------

genUniformRM :: forall a g. (UniformRange a, RandomGen g) => (a, a) -> M g a
genUniformRM = asksM @"fakeSeed" . uniformRM
genUniformRM r = uniformRM r (CapGenM @g)

-- | Generate a random number in a given range, biased increasingly towards the
-- lower end as the depth parameter increases.
Expand All @@ -143,9 +158,8 @@ genDepthBiasedRM ::
(a, a) ->
M g a
genDepthBiasedRM bounds = do
fs <- ask @"fakeSeed"
d <- get @"depth"
samples <- replicateM d (uniformRM bounds fs)
samples <- replicateM d (genUniformRM bounds)
pure $ minimum samples

-- | Generates a bool, increasingly likely to be 'False' as the depth increases.
Expand All @@ -155,10 +169,10 @@ genDepthBiasedBool = do
and <$> replicateM d genRandomM

genRandomM :: forall g a. (Random a, RandomGen g) => M g a
genRandomM = asksM @"fakeSeed" randomM
genRandomM = randomM (CapGenM @g)

genBytes :: forall g. RandomGen g => Int -> M g ByteString
genBytes n = asksM @"fakeSeed" $ uniformByteStringM n
genBytes n = uniformByteStringM n (CapGenM @g)

genText :: forall g. RandomGen g => Int -> M g Text
genText n = pure $ T.pack . take n . join $ repeat ['a' .. 'z']
Expand Down Expand Up @@ -436,12 +450,12 @@ genValueVariant (VBool b) = pure $ TBool b

generateCBORTerm :: RandomGen g => CTreeRoot' Identity MonoRef -> Name -> g -> Term
generateCBORTerm cddl n stdGen =
let genEnv = GenEnv {cddl, fakeSeed = CapGenM}
let genEnv = GenEnv {cddl}
genState = GenState {randomSeed = stdGen, depth = 1}
in evalGen (genForName n) genEnv genState

generateCBORTerm' :: RandomGen g => CTreeRoot' Identity MonoRef -> Name -> g -> (Term, g)
generateCBORTerm' cddl n stdGen =
let genEnv = GenEnv {cddl, fakeSeed = CapGenM}
let genEnv = GenEnv {cddl}
genState = GenState {randomSeed = stdGen, depth = 1}
in second randomSeed $ runGen (genForName n) genEnv genState