1
- {-# LANGUAGE GADTs #-}
2
- {-# LANGUAGE Rank2Types #-}
3
- {-# LANGUAGE RecordWildCards #-}
1
+ {-# LANGUAGE FlexibleContexts #-}
2
+ {-# LANGUAGE GADTs #-}
3
+ {-# LANGUAGE Rank2Types #-}
4
+ {-# LANGUAGE RecordWildCards #-}
5
+ {-# LANGUAGE PolyKinds #-}
4
6
module Main where
5
7
6
8
import Control.Category
7
9
import Control.Monad.Indexed
8
10
9
11
import Data.Composition
10
12
import Data.Distributive
13
+ import Data.Random
11
14
12
15
import Linear
13
16
import Linear.V
@@ -17,6 +20,7 @@ import Numeric.AD
17
20
import Prelude hiding (id , (.) )
18
21
19
22
import IndexedTardis
23
+ import WrapIndex
20
24
21
25
-- | The type representing a single sigmoid layer
22
26
data Layer f a b = Layer { weights :: V b (V a f ), biases :: V b f }
@@ -134,5 +138,25 @@ runMiniBatch xys lr =
134
138
backPropogate (collect fst xys, collect snd xys) (pure lr) .
135
139
batch
136
140
141
+ emptyLayer :: (Dim a , Dim b ) => Layer () a b
142
+ emptyLayer = Layer {weights = pure (pure () ), biases = pure () }
143
+
144
+ randomizeLayer' :: (Dim a , Dim b , Distribution d t ) => d t -> Layer x a b -> RVar (Layer t a b )
145
+ randomizeLayer' dist Layer {.. } =
146
+ Layer <$>
147
+ traverse (traverse (\ _ -> rvar dist)) weights <*>
148
+ traverse (\ _ -> rvar dist) biases
149
+
150
+ randomize' :: Distribution d t => d t -> Network x a b -> RVar (Network t a b )
151
+ randomize' dist = iunwrap . traverseNetwork (IWrap . randomizeLayer' dist)
152
+
153
+ randomize :: Distribution Normal t => Network x a b -> RVar (Network t a b )
154
+ randomize = randomize' StdNormal
155
+
137
156
main :: IO ()
138
- main = putStrLn " Hello, Haskell!"
157
+ main = do
158
+ let net0 = Lr (emptyLayer :: Layer () 10 3 ) Id .
159
+ Lr (emptyLayer :: Layer () 10 10 ) Id .
160
+ Lr (emptyLayer :: Layer () 3 10 ) Id
161
+ net <- runRVar (randomize net0) StdRandom :: IO (Network Double 3 3 )
162
+ return ()
0 commit comments