1
+ # Generative model w/ Glow architecture from Kingma & Dhariwal (2018)
2
+ # Network layers are made conditional with CIIN type layers
3
+ # Author: Rafael Orozco, [email protected]
4
+ # Date: March 2023
5
+
6
+ using InvertibleNetworks, LinearAlgebra, Flux
7
+
8
+ device = InvertibleNetworks. CUDA. functional () ? gpu : cpu
9
+
10
+ nx = 32 # must be multiple of 2^L where L is the multiscale level of the network
11
+ ny = 32 # must be multiple of 2^L where L is the multiscale level of the network
12
+ n_in = 4
13
+ n_cond = 4
14
+ n_hidden = 32
15
+ batchsize = 5
16
+ L = 2 # number of scales
17
+ K = 2 # number of flow steps per scale
18
+
19
+ # Input
20
+ X = rand (Float32, nx, ny, n_in, batchsize) |> device;
21
+
22
+ # Condition
23
+ Y = rand (Float32, nx, ny, n_in, batchsize) |> device;
24
+
25
+ # Glow network
26
+ G = NetworkConditionalGlow (n_in, n_cond, n_hidden, L, K) |> device
27
+
28
+ # Objective function
29
+ function loss (G, X, Y)
30
+ ZX, ZY, logdet = G. forward (X, Y)
31
+ f = .5f0 / batchsize* norm (ZX)^ 2 - logdet
32
+ G. backward (1f0 ./ batchsize* ZX, ZX, ZY)
33
+ return f
34
+ end
35
+
36
+ # Evaluate loss
37
+ f = loss (G, X, Y)
38
+
39
+ # Update weights
40
+ opt = Flux. ADAM ()
41
+ Params = get_params (G)
42
+ for p in Params
43
+ Flux. update! (opt, p. data, p. grad)
44
+ end
45
+ clear_grad! (G)
46
+
47
+ # ############### 3D example: To do with 3 spatial dimensions you need to set ndims on network.
48
+ # ############################# or use NetworkConditionalGlow3D
49
+ nz = 32
50
+
51
+ # 3D Input
52
+ X_3d = rand (Float32, nx, ny, nz, n_in, batchsize) |> device;
53
+
54
+ # #dCondition
55
+ Y_3d = rand (Float32, nx, ny, nz, n_in, batchsize) |> device;
56
+
57
+ # 3D Glow network
58
+ G_3d = NetworkConditionalGlow (n_in, n_cond, n_hidden, L, K; ndims= 3 ) |> device
59
+
60
+ # Evaluate loss
61
+ f = loss (G_3d, X_3d, Y_3d)
62
+
63
+ # Update weights
64
+ opt = Flux. ADAM ()
65
+ Params = get_params (G_3d)
66
+ for p in Params
67
+ Flux. update! (opt, p. data, p. grad)
68
+ end
69
+ clear_grad! (G_3d)
0 commit comments