Skip to content

Commit 0599cc7

Browse files
authored
Merge pull request #92 from slimgroup/logdet_glow
add logdet option to glow and example of training with flux
2 parents a5ef4ed + 827638d commit 0599cc7

File tree

5 files changed

+217
-164
lines changed

5 files changed

+217
-164
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "InvertibleNetworks"
22
uuid = "b7115f24-5f92-4794-81e8-23b0ddb121d3"
33
authors = ["Philipp Witte <[email protected]>", "Ali Siahkoohi <[email protected]>", "Mathias Louboutin <[email protected]>", "Gabrio Rizzuti <[email protected]>", "Rafael Orozco <[email protected]>", "Felix J. herrmann <[email protected]>"]
4-
version = "2.2.5"
4+
version = "2.2.6"
55

66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Train networks with flux. Only guaranteed to work with logdet=false for now.
2+
# So you can train them as invertible networks like this, not as normalizing flows.
3+
using InvertibleNetworks, Flux
4+
5+
# Glow Network
6+
model = NetworkGlow(2, 32, 2, 5; logdet=false)
7+
8+
# dummy input & target
9+
X = randn(Float32, 16, 16, 2, 2)
10+
Y = 2 .* X .+ 1
11+
12+
# loss fn
13+
loss(model, X, Y) = Flux.mse(Y, model(X))
14+
15+
θ = Flux.params(model)
16+
opt = ADAM(0.0001f0)
17+
18+
for i = 1:500
19+
l, grads = Flux.withgradient(θ) do
20+
loss(model, X, Y)
21+
end
22+
@show l
23+
Flux.update!(opt, θ, grads)
24+
end

src/networks/invertible_network_glow.jl

+14-11
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ export NetworkGlow, NetworkGlow3D
4040
4141
- `squeeze_type` : squeeze type that happens at each multiscale level
4242
43+
- `logdet` : boolean to turn on/off logdet term tracking and gradient calculation
44+
4345
*Output*:
4446
4547
- `G`: invertible Glow network.
@@ -67,12 +69,13 @@ struct NetworkGlow <: InvertibleNetwork
6769
K::Int64
6870
squeezer::Squeezer
6971
split_scales::Bool
72+
logdet::Bool
7073
end
7174

7275
@Flux.functor NetworkGlow
7376

7477
# Constructor
75-
function NetworkGlow(n_in, n_hidden, L, K; nx=nothing, dense=false, freeze_conv=false, split_scales=false, k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, ndims=2, squeezer::Squeezer=ShuffleLayer(), activation::ActivationFunction=SigmoidLayer())
78+
function NetworkGlow(n_in, n_hidden, L, K; logdet=true,nx=nothing, dense=false, freeze_conv=false, split_scales=false, k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, ndims=2, squeezer::Squeezer=ShuffleLayer(), activation::ActivationFunction=SigmoidLayer())
7679
(n_in == 1) && (split_scales = true) # Need extra channels for coupling layer
7780
(dense && isnothing(nx)) && error("Dense network needs nx as kwarg input")
7881

@@ -91,29 +94,28 @@ function NetworkGlow(n_in, n_hidden, L, K; nx=nothing, dense=false, freeze_conv=
9194
n_in *= channel_factor # squeeze if split_scales is turned on
9295
(dense && split_scales) && (nx = Int64(nx/2))
9396
for j=1:K
94-
AN[i, j] = ActNorm(n_in; logdet=true)
95-
CL[i, j] = CouplingLayerGlow(n_in, n_hidden; nx=nx, dense=dense, freeze_conv=freeze_conv, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=true, activation=activation, ndims=ndims)
97+
AN[i, j] = ActNorm(n_in; logdet=logdet)
98+
CL[i, j] = CouplingLayerGlow(n_in, n_hidden; nx=nx, dense=dense, freeze_conv=freeze_conv, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, activation=activation, ndims=ndims)
9699
end
97100
(i < L && split_scales) && (n_in = Int64(n_in/2); ) # split
98101
end
99102

100-
return NetworkGlow(AN, CL, Z_dims, L, K, squeezer, split_scales)
103+
return NetworkGlow(AN, CL, Z_dims, L, K, squeezer, split_scales,logdet)
101104
end
102105

103106
NetworkGlow3D(args; kw...) = NetworkGlow(args...; kw..., ndims=3)
104107

105108
# Forward pass and compute logdet
106-
function forward(X::AbstractArray{T, N}, G::NetworkGlow) where {T, N}
109+
function forward(X::AbstractArray{T, N}, G::NetworkGlow;) where {T, N}
107110
G.split_scales && (Z_save = array_of_array(X, max(G.L-1,1)))
108111

109-
110-
logdet = 0
112+
logdet_ = 0
111113
for i=1:G.L
112114
(G.split_scales) && (X = G.squeezer.forward(X))
113115
for j=1:G.K
114-
X, logdet1 = G.AN[i, j].forward(X)
115-
X, logdet2 = G.CL[i, j].forward(X)
116-
logdet += (logdet1 + logdet2)
116+
G.logdet ? (X, logdet1) = G.AN[i, j].forward(X) : X = G.AN[i, j].forward(X)
117+
G.logdet ? (X, logdet2) = G.CL[i, j].forward(X) : X = G.CL[i, j].forward(X)
118+
G.logdet && (logdet_ += (logdet1 + logdet2))
117119
end
118120
if G.split_scales && (i < G.L || i == 1) # don't split after last iteration
119121
X, Z = tensor_split(X)
@@ -122,7 +124,8 @@ function forward(X::AbstractArray{T, N}, G::NetworkGlow) where {T, N}
122124
end
123125
end
124126
G.split_scales && (X = cat_states(Z_save, X))
125-
return X, logdet
127+
128+
G.logdet ? (return X, logdet_) : (return X)
126129
end
127130

128131
# Inverse pass

test/runtests.jl

+18-3
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,30 @@ if test_suite == "all" || test_suite == "layers"
6767
end
6868
end
6969

70-
# Networks
70+
max_attempts=3
7171
if test_suite == "all" || test_suite == "networks"
7272
@testset verbose = true "Networks" begin
7373
for t=networks
74-
@testset "Test $t" begin
75-
@timeit TIMEROUTPUT "$t" begin include(t) end
74+
for attempt in 1:max_attempts
75+
println("Running tests, attempt $attempt...")
76+
try
77+
results = @testset "Test $t" begin
78+
@timeit TIMEROUTPUT "$t" begin include(t) end
79+
end
80+
81+
if all(record->record.status == :pass, results.results)
82+
println("Tests passed on attempt $attempt.")
83+
return
84+
end
85+
catch e
86+
println("Tests failed on attempt $attempt. Retrying...")
87+
end
7688
end
89+
println("Tests failed after $max_attempts attempts.")
7790
end
7891
end
7992
end
8093

94+
95+
8196
show(TIMEROUTPUT; compact=true, sortby=:firstexec)

0 commit comments

Comments
 (0)