
Commit 5f571ec

Merge pull request #68 from slimgroup/activation-hint
add activation parameter to conditional multiscale hint

2 parents f6b2d17 + bfb1149
9 files changed: +44 -41 lines

Summary: this merge threads an activation::ActivationFunction=SigmoidLayer() keyword through the HINT-family constructors (ConditionalLayerHINT, CouplingLayerHINT, NetworkConditionalHINT, NetworkMultiScaleConditionalHINT, NetworkMultiScaleHINT, NetworkLoop), makes forward_Y logdet-aware, un-squeezes the condition gradient ΔC_total in the conditional Glow backward pass, seeds the ActNorm tests, and bumps the package version to 2.2.1.

Project.toml (+1 -1)

@@ -1,7 +1,7 @@
 name = "InvertibleNetworks"
 uuid = "b7115f24-5f92-4794-81e8-23b0ddb121d3"
 authors = ["Philipp Witte <[email protected]>", "Ali Siahkoohi <[email protected]>", "Mathias Louboutin <[email protected]>", "Gabrio Rizzuti <[email protected]>", "Rafael Orozco <[email protected]>", "Felix J. herrmann <[email protected]>"]
-version = "2.2.0"
+version = "2.2.1"

 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

src/conditional_layers/conditional_layer_hint.jl (+7 -9)

@@ -65,12 +65,12 @@ end

 # 2D Constructor from input dimensions
 function ConditionalLayerHINT(n_in::Int64, n_hidden::Int64; k1=3, k2=3, p1=1, p2=1, s1=1, s2=1,
-    logdet=true, permute=true, ndims=2)
+    logdet=true, permute=true, ndims=2, activation::ActivationFunction=SigmoidLayer())

     # Create basic coupling layers
-    CL_X = CouplingLayerHINT(n_in, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, permute="none", ndims=ndims)
-    CL_Y = CouplingLayerHINT(n_in, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, permute="none", ndims=ndims)
-    CL_YX = CouplingLayerBasic(n_in, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, ndims=ndims)
+    CL_X = CouplingLayerHINT(n_in, n_hidden; activation=activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, permute="none", ndims=ndims)
+    CL_Y = CouplingLayerHINT(n_in, n_hidden; activation=activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, permute="none", ndims=ndims)
+    CL_YX = CouplingLayerBasic(n_in, n_hidden; activation=activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, ndims=ndims)

     # Permutation using 1x1 convolution
     permute == true ? (C_X = Conv1x1(n_in)) : (C_X = nothing)

@@ -209,11 +209,10 @@ function backward_inv(ΔX::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, X::Abs
     return ΔZx, ΔZy, Zx, Zy
 end

-function forward_Y(Y::AbstractArray{T, N}, CH::ConditionalLayerHINT) where {T, N}
+function forward_Y(Y::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=false) where {T, N}
     ~isnothing(CH.C_Y) ? (Yp = CH.C_Y.forward(Y)) : (Yp = copy(Y))
-    Zy = CH.CL_Y.forward(Yp; logdet=false)
-    return Zy
-
+    logdet ? (Zy, logdet_) = CH.CL_Y.forward(Yp; logdet=true) : Zy = CH.CL_Y.forward(Yp; logdet=false)
+    logdet ? (return Zy, logdet_) : (return Zy)
 end

 function inverse_Y(Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT) where {T, N}

@@ -222,7 +221,6 @@ function inverse_Y(Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT) where {T,
     return Y
 end

-
 ## Jacobian-related utils

 function jacobian(ΔX::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing) where {T, N}
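A hedged usage sketch of both changes in this file, assuming the package's standard H × W × C × batch tensor layout and the dot-call style (CH.forward, CH.forward_Y) that the code above itself uses:

    using InvertibleNetworks

    # Construct the conditional HINT layer, passing the new activation
    # keyword explicitly (SigmoidLayer() is also the default, so any
    # other exported ActivationFunction could be swapped in here).
    CH = ConditionalLayerHINT(4, 16; logdet=true, permute=true,
                              activation=SigmoidLayer())

    X = randn(Float32, 16, 16, 4, 2)   # sample:    H × W × C × batch
    Y = randn(Float32, 16, 16, 4, 2)   # condition: H × W × C × batch

    Zx, Zy, lgdet = CH.forward(X, Y)

    # New in this commit: forward_Y can also return its log-determinant.
    Zy2, lgdet_Y = CH.forward_Y(Y; logdet=true)
    Zy3 = CH.forward_Y(Y)              # default logdet=false, as before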

src/layers/invertible_layer_hint.jl (+3 -3)

@@ -73,17 +73,17 @@ end

 # Constructor for given coupling layer and 1 x 1 convolution
 CouplingLayerHINT(CL::AbstractArray{CouplingLayerBasic, 1}, C::Union{Conv1x1, Nothing};
-    logdet=false, permute="none") = CouplingLayerHINT(CL, C, logdet, permute, false)
+    logdet=false, permute="none", activation::ActivationFunction=SigmoidLayer()) = CouplingLayerHINT(CL, C, logdet, permute, false)

 # 2D Constructor from input dimensions
 function CouplingLayerHINT(n_in::Int64, n_hidden::Int64; logdet=false, permute="none",
-    k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, ndims=2)
+    k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, ndims=2, activation::ActivationFunction=SigmoidLayer())

     # Create basic coupling layers
     n = get_depth(n_in)
     CL = Array{CouplingLayerBasic}(undef, n)
     for j=1:n
-        CL[j] = CouplingLayerBasic(Int(n_in/2^j), n_hidden; k1=k1, k2=k2, p1=p1, p2=p2,
+        CL[j] = CouplingLayerBasic(Int(n_in/2^j), n_hidden; activation=activation, k1=k1, k2=k2, p1=p1, p2=p2,
             s1=s1, s2=s2, logdet=logdet, ndims=ndims)
     end
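A minimal round-trip sketch for the extended constructor (shapes assumed; n_in must support the layer's recursive channel splitting, so a power of two is used):

    using InvertibleNetworks, LinearAlgebra

    # The activation keyword is forwarded to every recursive
    # CouplingLayerBasic built in the loop above.
    H = CouplingLayerHINT(8, 16; logdet=false, permute="full",
                          activation=SigmoidLayer())

    X  = randn(Float32, 16, 16, 8, 2)
    X_ = H.inverse(H.forward(X))
    @assert isapprox(X, X_; rtol=1f-3)   # round trip up to Float32 error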

src/networks/invertible_network_conditional_glow.jl (+1 -0)

@@ -169,6 +169,7 @@ function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, C::AbstractA
     end

     if G.split_scales
+        ΔC_total = G.squeezer.inverse(ΔC_total)
         C = G.squeezer.inverse(C)
         X = G.squeezer.inverse(X)
         ΔX = G.squeezer.inverse(ΔX)
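The added line un-squeezes the accumulated condition gradient ΔC_total together with the states. A minimal shape sketch of why this matters (sizes assumed; ShuffleLayer is the squeezer default used elsewhere in this diff):

    using InvertibleNetworks

    # A squeezer trades spatial extent for channels: H × W × C × B
    # becomes H/2 × W/2 × 4C × B on forward. A gradient accumulated in
    # the squeezed shape must be un-squeezed too, or it no longer
    # matches the shape of the condition it belongs to.
    sq = ShuffleLayer()
    C  = randn(Float32, 16, 16, 4, 2)
    Cs = sq.forward(C)                       # 8 × 8 × 16 × 2
    @assert size(sq.inverse(Cs)) == size(C)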

src/networks/invertible_network_conditional_hint.jl (+2 -2)

@@ -58,7 +58,7 @@ end
 @Flux.functor NetworkConditionalHINT

 # Constructor
-function NetworkConditionalHINT(n_in, n_hidden, depth; k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, logdet=true, ndims=2)
+function NetworkConditionalHINT(n_in, n_hidden, depth; k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, logdet=true, ndims=2, activation::ActivationFunction=SigmoidLayer())

     AN_X = Array{ActNorm}(undef, depth)
     AN_Y = Array{ActNorm}(undef, depth)

@@ -68,7 +68,7 @@ function NetworkConditionalHINT(n_in, n_hidden, depth; k1=3, k2=3, p1=1, p2=1, s
     for j=1:depth
         AN_X[j] = ActNorm(n_in; logdet=logdet)
         AN_Y[j] = ActNorm(n_in; logdet=logdet)
-        CL[j] = ConditionalLayerHINT(n_in, n_hidden; permute=true, k1=k1, k2=k2, p1=p1, p2=p2,
+        CL[j] = ConditionalLayerHINT(n_in, n_hidden; activation=activation, permute=true, k1=k1, k2=k2, p1=p1, p2=p2,
             s1=s1, s2=s2, logdet=logdet, ndims=ndims)
     end
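A hedged construction sketch (input shapes assumed; with logdet=true the network's forward is expected to return Zx, Zy, and the log-determinant):

    using InvertibleNetworks

    # Depth-3 conditional HINT network with the activation keyword
    # passed explicitly (SigmoidLayer() is also the default).
    CH = NetworkConditionalHINT(4, 32, 3; activation=SigmoidLayer())

    X = randn(Float32, 16, 16, 4, 2)
    Y = randn(Float32, 16, 16, 4, 2)
    Zx, Zy, lgdet = CH.forward(X, Y)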

src/networks/invertible_network_conditional_hint_multiscale.jl (+24 -21)

@@ -71,7 +71,7 @@ end

 # Constructor
 function NetworkMultiScaleConditionalHINT(n_in::Int64, n_hidden::Int64, L::Int64, K::Int64;
-    split_scales=false, k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, logdet=true, ndims=2, squeezer::Squeezer=ShuffleLayer())
+    split_scales=false, k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, logdet=true, ndims=2, squeezer::Squeezer=ShuffleLayer(), activation::ActivationFunction=SigmoidLayer())

     AN_X = Array{ActNorm}(undef, L, K)
     AN_Y = Array{ActNorm}(undef, L, K)

@@ -89,7 +89,7 @@ function NetworkMultiScaleConditionalHINT(n_in::Int64, n_hidden::Int64, L::Int64
         for j=1:K
             AN_X[i, j] = ActNorm(n_in*4; logdet=logdet)
             AN_Y[i, j] = ActNorm(n_in*4; logdet=logdet)
-            CL[i, j] = ConditionalLayerHINT(n_in*4, n_hidden; permute=true, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, ndims=ndims)
+            CL[i, j] = ConditionalLayerHINT(n_in*4, n_hidden; permute=true, activation=activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, ndims=ndims)
         end
         n_in *= channel_factor
     end

@@ -131,6 +131,28 @@ function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::NetworkMult
     logdet ? (return X, Y, logdet_) : (return X, Y)
 end

+# Forward pass of Y only, optionally computing logdet
+function forward_Y(Y::AbstractArray{T, N}, CH::NetworkMultiScaleConditionalHINT; logdet=false) where {T, N}
+    CH.split_scales && (Y_save = array_of_array(Y, CH.L-1))
+
+    logdet_ = 0f0
+    for i=1:CH.L
+        Y = CH.squeezer.forward(Y)
+        for j=1:CH.K
+            logdet ? (Y_, logdet1) = CH.AN_Y[i, j].forward(Y; logdet=true) : Y_ = CH.AN_Y[i, j].forward(Y; logdet=false)
+            logdet ? (Y, logdet2) = CH.CL[i, j].forward_Y(Y_; logdet=true) : Y = CH.CL[i, j].forward_Y(Y_; logdet=false)
+            logdet && (logdet_ += (logdet1 + logdet2))
+        end
+        if CH.split_scales && i < CH.L # don't split after last iteration
+            Y, Zy = tensor_split(Y)
+            Y_save[i] = Zy
+            CH.XY_dims[i] = collect(size(Zy))
+        end
+    end
+    CH.split_scales && (Y = cat_states(Y_save, Y))
+    logdet ? (return Y, logdet_) : (return Y)
+end
+
 # Inverse pass and compute gradients
 function inverse(Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::NetworkMultiScaleConditionalHINT; logdet=nothing) where {T, N}
     isnothing(logdet) ? logdet = (CH.logdet && CH.is_reversed) : logdet = logdet

@@ -234,26 +256,7 @@ function backward_inv(ΔX, ΔY, X, Y, CH::NetworkMultiScaleConditionalHINT)
     end
 end

-# Forward pass and compute logdet
-function forward_Y(Y::AbstractArray{T, N}, CH::NetworkMultiScaleConditionalHINT) where {T, N}
-    CH.split_scales && (Y_save = array_of_array(Y, CH.L-1))
-
-    for i=1:CH.L
-        Y = CH.squeezer.forward(Y)
-        for j=1:CH.K
-            Y_ = CH.AN_Y[i, j].forward(Y; logdet=false)
-            Y = CH.CL[i, j].forward_Y(Y_)
-        end
-        if CH.split_scales && i < CH.L # don't split after last iteration
-            Y, Zy = tensor_split(Y)
-            Y_save[i] = Zy
-            CH.XY_dims[i] = collect(size(Zy))
-        end
-    end
-    CH.split_scales && (Y = cat_states(Y_save, Y))
-    return Y

-end

 # Inverse pass and compute gradients
 function inverse_Y(Zy::AbstractArray{T, N}, CH::NetworkMultiScaleConditionalHINT) where {T, N}
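A hedged sketch of the relocated, now logdet-aware forward_Y (L=2 scales, K=2 flow steps per scale; shapes assumed):

    using InvertibleNetworks

    CH = NetworkMultiScaleConditionalHINT(4, 32, 2, 2;
                                          split_scales=true,
                                          activation=SigmoidLayer())

    Y = randn(Float32, 32, 32, 4, 2)

    # New: accumulate the log-determinant across all scales and steps.
    Zy, lgdet_Y = CH.forward_Y(Y; logdet=true)

    # Old behavior remains the default.
    Zy_only = CH.forward_Y(Y)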

src/networks/invertible_network_hint_multiscale.jl (+2 -2)

@@ -67,7 +67,7 @@ end

 # Constructor
 function NetworkMultiScaleHINT(n_in::Int64, n_hidden::Int64, L::Int64, K::Int64;
-    split_scales=false, k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, ndims=2)
+    split_scales=false, k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, activation::ActivationFunction=SigmoidLayer(), ndims=2)

     AN = Array{ActNorm}(undef, L, K)
     CL = Array{CouplingLayerHINT}(undef, L, K)

@@ -83,7 +83,7 @@ function NetworkMultiScaleHINT(n_in::Int64, n_hidden::Int64, L::Int64, K::Int64;
     for i=1:L
         for j=1:K
             AN[i, j] = ActNorm(n_in*4; logdet=true)
-            CL[i, j] = CouplingLayerHINT(n_in*4, n_hidden; permute="full", k1=k1, k2=k2, p1=p1, p2=p2,
+            CL[i, j] = CouplingLayerHINT(n_in*4, n_hidden; activation=activation, permute="full", k1=k1, k2=k2, p1=p1, p2=p2,
                 s1=s1, s2=s2, logdet=true, ndims=ndims)
         end
         n_in *= channel_factor

src/networks/invertible_network_irim.jl (+2 -2)

@@ -66,7 +66,7 @@ end
 @Flux.functor NetworkLoop

 # 2D Constructor
-function NetworkLoop(n_in, n_hidden, maxiter, Ψ; k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, type="additive", ndims=2)
+function NetworkLoop(n_in, n_hidden, maxiter, Ψ; k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, type="additive", ndims=2, activation::ActivationFunction=SigmoidLayer())

     if type == "additive"
         L = Array{CouplingLayerIRIM}(undef, maxiter)

@@ -77,7 +77,7 @@ function NetworkLoop(n_in, n_hidden, maxiter, Ψ; k1=4, k2=3, p1=0, p2=1, s1=4,
     AN = Array{ActNorm}(undef, maxiter)
     for j=1:maxiter
         if type == "additive"
-            L[j] = CouplingLayerIRIM(n_in, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims)
+            L[j] = CouplingLayerIRIM(n_in, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, activation=activation, ndims=ndims)
         elseif type == "HINT"
             L[j] = CouplingLayerHINT(n_in, n_hidden; logdet=false, permute="both", k1=k1, k2=k2, p1=p1, p2=p2,
                 s1=s1, s2=s2, ndims=ndims)
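Note that only the additive branch forwards the new keyword; the HINT branch of NetworkLoop still builds CouplingLayerHINT without it, so activation only takes effect for type="additive" in this commit. A hedged construction sketch (the identity Ψ below is a hypothetical stand-in for the forward operator the loop expects):

    using InvertibleNetworks

    # Hypothetical stand-in for the physics/forward operator Ψ.
    Ψ(x) = x

    # Additive i-RIM loop; activation is passed to each CouplingLayerIRIM.
    UL = NetworkLoop(8, 16, 4, Ψ; type="additive", activation=SigmoidLayer())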

test/test_layers/test_actnorm.jl (+2 -1)

@@ -2,8 +2,9 @@
 # Date: January 2020

 using InvertibleNetworks, LinearAlgebra, Test, Statistics
+using Random

-
+Random.seed!(11)
 ###############################################################################
 # Test logdet implementation
