Skip to content

Commit 91669a0

Browse files
committed
add example of conditional glow 2d and 3d
1 parent 857bac9 commit 91669a0

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Generative model w/ Glow architecture from Kingma & Dhariwal (2018)
2+
# Network layers are made conditional with CIIN type layers
3+
# Author: Rafael Orozco, [email protected]
4+
# Date: March 2023
5+
6+
using InvertibleNetworks, LinearAlgebra, Flux
7+
8+
device = InvertibleNetworks.CUDA.functional() ? gpu : cpu
9+
10+
nx = 32 # must be multiple of 2^L where L is the multiscale level of the network
11+
ny = 32 # must be multiple of 2^L where L is the multiscale level of the network
12+
n_in = 4
13+
n_cond = 4
14+
n_hidden = 32
15+
batchsize = 5
16+
L = 2 # number of scales
17+
K = 2 # number of flow steps per scale
18+
19+
# Input
20+
X = rand(Float32, nx, ny, n_in, batchsize) |> device;
21+
22+
# Condition
23+
Y = rand(Float32, nx, ny, n_in, batchsize) |> device;
24+
25+
# Glow network
26+
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K) |> device
27+
28+
# Objective function
29+
function loss(G, X, Y)
30+
ZX, ZY, logdet = G.forward(X, Y)
31+
f = .5f0/batchsize*norm(ZX)^2 - logdet
32+
G.backward(1f0./batchsize*ZX, ZX, ZY)
33+
return f
34+
end
35+
36+
# Evaluate loss
37+
f = loss(G, X, Y)
38+
39+
# Update weights
40+
opt = Flux.ADAM()
41+
Params = get_params(G)
42+
for p in Params
43+
Flux.update!(opt, p.data, p.grad)
44+
end
45+
clear_grad!(G)
46+
47+
################ 3D example: To do with 3 spatial dimensions you need to set ndims on network.
48+
############################## or use NetworkConditionalGlow3D
49+
nz = 32
50+
51+
# 3D Input
52+
X_3d = rand(Float32, nx, ny, nz, n_in, batchsize) |> device;
53+
54+
# #dCondition
55+
Y_3d = rand(Float32, nx, ny, nz, n_in, batchsize) |> device;
56+
57+
# 3D Glow network
58+
G_3d = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; ndims=3) |> device
59+
60+
# Evaluate loss
61+
f = loss(G_3d, X_3d, Y_3d)
62+
63+
# Update weights
64+
opt = Flux.ADAM()
65+
Params = get_params(G_3d)
66+
for p in Params
67+
Flux.update!(opt, p.data, p.grad)
68+
end
69+
clear_grad!(G_3d)

0 commit comments

Comments
 (0)