
Commit 81ef037

Merge pull request #24 from slimgroup/reverse_multi
make NetworkMultiScaleConditionalHINT reversible and use logdet correctly
2 parents cf56bd0 + 3b80485 commit 81ef037

13 files changed: +400 −165 lines

src/layers/invertible_layer_glow.jl (+19 −16)
@@ -64,27 +64,28 @@ struct CouplingLayerGlow <: NeuralNetLayer
     C::Conv1x1
     RB::Union{ResidualBlock, FluxBlock}
     logdet::Bool
+    activation::ActivationFunction
 end

 @Flux.functor CouplingLayerGlow

 # Constructor from 1x1 convolution and residual block
-function CouplingLayerGlow(C::Conv1x1, RB::ResidualBlock; logdet=false)
+function CouplingLayerGlow(C::Conv1x1, RB::ResidualBlock; logdet=false, activation::ActivationFunction=SigmoidLayer())
     RB.fan == false && throw("Set ResidualBlock.fan == true")
-    return CouplingLayerGlow(C, RB, logdet)
+    return CouplingLayerGlow(C, RB, logdet, activation)
 end

 # Constructor from 1x1 convolution and residual Flux block
-CouplingLayerGlow(C::Conv1x1, RB::FluxBlock; logdet=false) = CouplingLayerGlow(C, RB, logdet)
+CouplingLayerGlow(C::Conv1x1, RB::FluxBlock; logdet=false, activation::ActivationFunction=SigmoidLayer()) = CouplingLayerGlow(C, RB, logdet, activation)

 # Constructor from input dimensions
-function CouplingLayerGlow(n_in::Int64, n_hidden::Int64; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, logdet=false, ndims=2)
+function CouplingLayerGlow(n_in::Int64, n_hidden::Int64; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, logdet=false, activation::ActivationFunction=SigmoidLayer(), ndims=2)

     # 1x1 Convolution and residual block for invertible layer
     C = Conv1x1(n_in)
     RB = ResidualBlock(Int(n_in/2), n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=true, ndims=ndims)

-    return CouplingLayerGlow(C, RB, logdet)
+    return CouplingLayerGlow(C, RB, logdet, activation)
 end

 CouplingLayerGlow3D(args...;kw...) = CouplingLayerGlow(args...; kw..., ndims=3)
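
For context, a minimal construction sketch based on the updated constructors above; the package import name and the example dimensions are illustrative assumptions, not part of this commit:

# Minimal sketch: build a Glow coupling layer with the new `activation` keyword.
# `using InvertibleNetworks` and the sizes below are assumptions for illustration.
using InvertibleNetworks

n_in, n_hidden = 4, 16
L = CouplingLayerGlow(n_in, n_hidden; logdet=true, activation=SigmoidLayer())

Omitting the `activation` keyword keeps the default SigmoidLayer(), which is intended to match the previously hard-coded Sigmoid behaviour.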
@@ -100,9 +101,10 @@ function forward(X::AbstractArray{T, 4}, L::CouplingLayerGlow) where T

     Y2 = copy(X2)
     logS_T = L.RB.forward(X2)
-    Sm = Sigmoid(logS_T[:,:,1:k,:])
+    Sm = L.activation.forward(logS_T[:,:,1:k,:])
     Tm = logS_T[:, :, k+1:end, :]
     Y1 = Sm.*X1 + Tm
+
     Y = tensor_cat(Y1, Y2)

     L.logdet == true ? (return Y, glow_logdet_forward(Sm)) : (return Y)
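
A forward-pass sketch following the signature in this hunk; the array layout and sizes are assumptions for illustration, and with logdet=true the layer also returns the log-determinant term:

# Sketch: forward pass through the layer built above; X is assumed to be
# nx × ny × n_in × batch, matching the 4D signature of forward(X, L).
X = randn(Float32, 32, 32, n_in, 2)
Y, logdet = forward(X, L)   # logdet term returned because logdet=true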
@@ -117,9 +119,10 @@ function inverse(Y::AbstractArray{T, 4}, L::CouplingLayerGlow; save=false) where

     X2 = copy(Y2)
     logS_T = L.RB.forward(X2)
-    Sm = Sigmoid(logS_T[:,:,1:k,:])
+    Sm = L.activation.forward(logS_T[:,:,1:k,:])
     Tm = logS_T[:, :, k+1:end, :]
     X1 = (Y1 - Tm) ./ (Sm .+ eps(T)) # add epsilon to avoid division by 0
+
     X_ = tensor_cat(X1, X2)
     X = L.C.inverse(X_)
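
A matching round-trip sketch for this hunk (same assumptions as above); the reconstruction is approximate because of the eps(T) stabilisation in the division:

# Sketch: invert the output from the forward sketch above.
X_rec = inverse(Y, L)
isapprox(X, X_rec; rtol=1f-3)   # expected to hold up to the eps(T) term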

@@ -143,10 +146,10 @@ function backward(ΔY::AbstractArray{T, 4}, Y::AbstractArray{T, 4}, L::CouplingL

     ΔX1 = ΔY1 .* S
     if set_grad
-        ΔX2 = L.RB.backward(cat(SigmoidGrad(ΔS, S), ΔT; dims=3), X2) + ΔY2
+        ΔX2 = L.RB.backward(cat(L.activation.backward(ΔS, S), ΔT; dims=3), X2) + ΔY2
     else
-        ΔX2, Δθrb = L.RB.backward(cat(SigmoidGrad(ΔS, S), ΔT; dims=3), X2; set_grad=set_grad)
-        _, ∇logdet = L.RB.backward(cat(SigmoidGrad(ΔS_, S), 0 .*ΔT; dims=3), X2; set_grad=set_grad)
+        ΔX2, Δθrb = L.RB.backward(cat(L.activation.backward(ΔS, S), ΔT; dims=3), X2; set_grad=set_grad)
+        _, ∇logdet = L.RB.backward(cat(L.activation.backward(ΔS_, S), 0f0.*ΔT; dims=3), X2; set_grad=set_grad)
         ΔX2 += ΔY2
     end
     ΔX_ = tensor_cat(ΔX1, ΔX2)
@@ -179,20 +182,20 @@ function jacobian(ΔX::AbstractArray{T, 4}, Δθ::Array{Parameter, 1}, X, L::Cou
     Y2 = copy(X2)
     ΔY2 = copy(ΔX2)
     ΔlogS_T, logS_T = L.RB.jacobian(ΔX2, Δθ[4:end], X2)
-    S = Sigmoid(logS_T[:,:,1:k,:])
-    ΔS = SigmoidGrad(ΔlogS_T[:,:,1:k,:], nothing; x=logS_T[:,:,1:k,:])
+    Sm = L.activation.forward(logS_T[:,:,1:k,:])
+    ΔS = L.activation.backward(ΔlogS_T[:,:,1:k,:], nothing;x=logS_T[:,:,1:k,:])
     Tm = logS_T[:, :, k+1:end, :]
     ΔT = ΔlogS_T[:, :, k+1:end, :]
-    Y1 = S.*X1 + Tm
-    ΔY1 = ΔS.*X1 + S.*ΔX1 + ΔT
+    Y1 = Sm.*X1 + Tm
+    ΔY1 = ΔS.*X1 + Sm.*ΔX1 + ΔT
     Y = tensor_cat(Y1, Y2)
     ΔY = tensor_cat(ΔY1, ΔY2)

     # Gauss-Newton approximation of logdet terms
     JΔθ = L.RB.jacobian(cuzeros(ΔX2, size(ΔX2)), Δθ[4:end], X2)[1][:, :, 1:k, :]
-    GNΔθ = cat(0*Δθ[1:3], -L.RB.adjointJacobian(tensor_cat(SigmoidGrad(JΔθ, S), zeros(Float32, size(S))), X2)[2]; dims=1)
+    GNΔθ = cat(0f0*Δθ[1:3], -L.RB.adjointJacobian(tensor_cat(L.activation.backward(JΔθ, Sm), zeros(Float32, size(Sm))), X2)[2]; dims=1)

-    L.logdet ? (return ΔY, Y, glow_logdet_forward(S), GNΔθ) : (return ΔY, Y)
+    L.logdet ? (return ΔY, Y, glow_logdet_forward(Sm), GNΔθ) : (return ΔY, Y)
 end

 function adjointJacobian(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingLayerGlow) where {T, N}
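
From how this file now uses the activation field, an activation object is expected to expose forward and backward callables, invoked as act.forward(x), act.backward(Δ, y), and act.backward(Δ, nothing; x=x) in the Jacobian. A hedged check against the default activation, with illustrative shapes:

# Sketch: exercise the interface the layer relies on, using the default activation.
act = SigmoidLayer()
x  = randn(Float32, 8, 8, 2, 1)
s  = act.forward(x)             # as in forward/inverse above
ds = act.backward(copy(s), s)   # as in backward above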
