Skip to content

Commit 9ee9cad

Browse files
committed
Add a way to build a random network
1 parent 36efb81 commit 9ee9cad

File tree

4 files changed

+49
-6
lines changed

4 files changed

+49
-6
lines changed

IndexedTardis.hs

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
{-# LANGUAGE PolyKinds #-}
12
module IndexedTardis where
23

34
import Control.Monad.Indexed

Main.hs

+28-4
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
{-# LANGUAGE GADTs #-}
2-
{-# LANGUAGE Rank2Types #-}
3-
{-# LANGUAGE RecordWildCards #-}
1+
{-# LANGUAGE FlexibleContexts #-}
2+
{-# LANGUAGE GADTs #-}
3+
{-# LANGUAGE Rank2Types #-}
4+
{-# LANGUAGE RecordWildCards #-}
5+
{-# LANGUAGE PolyKinds #-}
46
module Main where
57

68
import Control.Category
79
import Control.Monad.Indexed
810

911
import Data.Composition
1012
import Data.Distributive
13+
import Data.Random
1114

1215
import Linear
1316
import Linear.V
@@ -17,6 +20,7 @@ import Numeric.AD
1720
import Prelude hiding (id, (.))
1821

1922
import IndexedTardis
23+
import WrapIndex
2024

2125
-- | The type representing a single sigmoid layer
2226
data Layer f a b = Layer {weights :: V b (V a f), biases :: V b f}
@@ -134,5 +138,25 @@ runMiniBatch xys lr =
134138
backPropogate (collect fst xys, collect snd xys) (pure lr) .
135139
batch
136140

141+
emptyLayer :: (Dim a, Dim b) => Layer () a b
142+
emptyLayer = Layer {weights = pure (pure ()), biases = pure ()}
143+
144+
randomizeLayer' :: (Dim a, Dim b, Distribution d t) => d t -> Layer x a b -> RVar (Layer t a b)
145+
randomizeLayer' dist Layer{..} =
146+
Layer <$>
147+
traverse (traverse (\_ -> rvar dist)) weights <*>
148+
traverse (\_ -> rvar dist) biases
149+
150+
randomize' :: Distribution d t => d t -> Network x a b -> RVar (Network t a b)
151+
randomize' dist = iunwrap . traverseNetwork (IWrap . randomizeLayer' dist)
152+
153+
randomize :: Distribution Normal t => Network x a b -> RVar (Network t a b)
154+
randomize = randomize' StdNormal
155+
137156
main :: IO ()
138-
main = putStrLn "Hello, Haskell!"
157+
main = do
158+
let net0 = Lr (emptyLayer :: Layer () 10 3) Id .
159+
Lr (emptyLayer :: Layer () 10 10) Id .
160+
Lr (emptyLayer :: Layer () 3 10) Id
161+
net <- runRVar (randomize net0) StdRandom :: IO (Network Double 3 3)
162+
return ()

WrapIndex.hs

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{-# LANGUAGE PolyKinds #-}
2+
module WrapIndex where
3+
4+
import Control.Monad.Indexed
5+
6+
newtype IWrap m i j a = IWrap {iunwrap :: m a}
7+
8+
instance Functor m => IxFunctor (IWrap m) where
9+
imap f = IWrap . fmap f . iunwrap
10+
11+
instance Applicative m => IxPointed (IWrap m) where
12+
ireturn = IWrap . pure
13+
14+
instance Applicative m => IxApplicative (IWrap m) where
15+
IWrap f `iap` IWrap x = IWrap $ f <*> x
16+
17+
instance Monad m => IxMonad (IWrap m) where
18+
f `ibind` IWrap x = IWrap $ x >>= iunwrap . f

net.cabal

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ cabal-version: >=1.10
1717

1818
executable net
1919
main-is: Main.hs
20-
other-modules: IndexedTardis
20+
other-modules: IndexedTardis, WrapIndex
2121
-- other-extensions:
22-
build-depends: base >=4.9 && <4.10, linear, ad, distributive, indexed, composition
22+
build-depends: base >=4.9 && <4.10, linear, ad, distributive, indexed, composition, random-fu
2323
-- hs-source-dirs:
2424
default-language: Haskell2010

0 commit comments

Comments
 (0)