Commit 033a4a1

Merge pull request #59 from slimgroup/conditional-glow
Conditional glow network and flexible residual block
2 parents a068423 + 25e999f · commit 033a4a1

11 files changed (+803 −192 lines)

Project.toml (+1 −1)

@@ -1,7 +1,7 @@
 name = "InvertibleNetworks"
 uuid = "b7115f24-5f92-4794-81e8-23b0ddb121d3"
 authors = ["Philipp Witte <[email protected]>", "Ali Siahkoohi <[email protected]>", "Mathias Louboutin <[email protected]>", "Gabrio Rizzuti <[email protected]>", "Rafael Orozco <[email protected]>", "Felix J. herrmann <[email protected]>"]
-version = "2.1.5"
+version = "2.2.0"

 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

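The minor-version bump publishes the new conditional layers as a registered release. A small sketch of pinning this exact version from the Julia REPL, assuming the release is available in the General registry:

```julia
using Pkg

# Pin the release introduced by this merge (assumes it has been registered)
Pkg.add(name="InvertibleNetworks", version="2.2.0")
```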

src/InvertibleNetworks.jl (+2)

@@ -68,7 +68,9 @@ include("networks/invertible_network_glow.jl") # Glow: Dinh et al. (2017), King
 include("networks/invertible_network_hyperbolic.jl") # Hyperbolic: Lensink et al. (2019)

 # Conditional layers and nets
+include("conditional_layers/conditional_layer_glow.jl")
 include("conditional_layers/conditional_layer_hint.jl")
+include("networks/invertible_network_conditional_glow.jl")
 include("networks/invertible_network_conditional_hint.jl")
 include("networks/invertible_network_conditional_hint_multiscale.jl")

src/conditional_layers/conditional_layer_glow.jl (new file, +153)

# Conditional coupling layer based on GLOW and cIIN
# Date: January 2022

export ConditionalLayerGlow, ConditionalLayerGlow3D

"""
    CL = ConditionalLayerGlow(C::Conv1x1, RB::ResidualBlock; logdet=false)

or

    CL = ConditionalLayerGlow(n_in, n_cond, n_hidden; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, logdet=false, ndims=2) (2D)

    CL = ConditionalLayerGlow(n_in, n_cond, n_hidden; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, logdet=false, ndims=3) (3D)

    CL = ConditionalLayerGlow3D(n_in, n_cond, n_hidden; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, logdet=false) (3D)

Create a Real NVP-style invertible conditional coupling layer based on 1x1 convolutions and a residual block.

*Input*:

- `C::Conv1x1`: 1x1 convolution layer

- `RB::ResidualBlock`: residual block layer consisting of 3 convolutional layers with ReLU activations.

- `logdet`: bool to indicate whether to compute the logdet of the layer

or

- `n_in`, `n_cond`, `n_hidden`: number of channels for the input, the conditioning variable and the hidden layer

- `k1`, `k2`: kernel size of convolutions in residual block. `k1` is the kernel of the first and third
   operator, `k2` is the kernel size of the second operator.

- `p1`, `p2`: padding for the first and third convolution (`p1`) and the second convolution (`p2`)

- `s1`, `s2`: stride for the first and third convolution (`s1`) and the second convolution (`s2`)

- `ndims`: number of dimensions

*Output*:

- `CL`: Invertible Real NVP conditional coupling layer.

*Usage:*

- Forward mode: `Y, logdet = CL.forward(X, C)` (if constructed with `logdet=true`)

- Inverse mode: `X = CL.inverse(Y, C)`

- Backward mode: `ΔX, X = CL.backward(ΔY, Y, C)`

*Trainable parameters:*

- None in `CL` itself

- Trainable parameters in residual block `CL.RB` and 1x1 convolution layer `CL.C`

See also: [`Conv1x1`](@ref), [`ResidualBlock`](@ref), [`get_params`](@ref), [`clear_grad!`](@ref)
"""
struct ConditionalLayerGlow <: NeuralNetLayer
    C::Conv1x1
    RB::ResidualBlock
    logdet::Bool
    activation::ActivationFunction
end

@Flux.functor ConditionalLayerGlow

# Constructor from 1x1 convolution and residual block
function ConditionalLayerGlow(C::Conv1x1, RB::ResidualBlock; logdet=false, activation::ActivationFunction=SigmoidLayer())
    RB.fan == false && throw("Set ResidualBlock.fan == true")
    return ConditionalLayerGlow(C, RB, logdet, activation)
end

# Constructor from input dimensions
function ConditionalLayerGlow(n_in::Int64, n_cond::Int64, n_hidden::Int64; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, logdet=false, activation::ActivationFunction=SigmoidLayer(), rb_activation::ActivationFunction=ReLUlayer(), ndims=2)

    # 1x1 Convolution and residual block for invertible layers
    C = Conv1x1(n_in)
    RB = ResidualBlock(Int(n_in/2)+n_cond, n_hidden; n_out=n_in, activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=true, ndims=ndims)

    return ConditionalLayerGlow(C, RB, logdet, activation)
end

ConditionalLayerGlow3D(args...; kw...) = ConditionalLayerGlow(args...; kw..., ndims=3)

# Forward pass: Input X, Output Y
function forward(X::AbstractArray{T, N}, C::AbstractArray{T, N}, L::ConditionalLayerGlow) where {T, N}

    X_ = L.C.forward(X)
    X1, X2 = tensor_split(X_)

    Y2 = copy(X2)

    # Concatenate conditioning variable C with the passive half X2 as residual block input
    logS_T = L.RB.forward(tensor_cat(X2, C))
    logS, log_T = tensor_split(logS_T)

    Sm = L.activation.forward(logS)
    Tm = log_T
    Y1 = Sm.*X1 + Tm

    Y = tensor_cat(Y1, Y2)

    L.logdet == true ? (return Y, glow_logdet_forward(Sm)) : (return Y)
end

# Inverse pass: Input Y, Output X
function inverse(Y::AbstractArray{T, N}, C::AbstractArray{T, N}, L::ConditionalLayerGlow; save=false) where {T, N}

    Y1, Y2 = tensor_split(Y)

    X2 = copy(Y2)
    logS_T = L.RB.forward(tensor_cat(X2, C))
    logS, log_T = tensor_split(logS_T)

    Sm = L.activation.forward(logS)
    Tm = log_T
    X1 = (Y1 - Tm) ./ (Sm .+ eps(T)) # add epsilon to avoid division by 0

    X_ = tensor_cat(X1, X2)
    X = L.C.inverse(X_)

    save == true ? (return X, X1, X2, Sm) : (return X)
end

# Backward pass: Input (ΔY, Y), Output (ΔX, X)
function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, C::AbstractArray{T, N}, L::ConditionalLayerGlow) where {T, N}

    # Recompute forward state
    X, X1, X2, S = inverse(Y, C, L; save=true)

    # Backpropagate residual
    ΔY1, ΔY2 = tensor_split(ΔY)
    ΔT = copy(ΔY1)
    ΔS = ΔY1 .* X1
    ΔX1 = ΔY1 .* S

    if L.logdet
        ΔS -= glow_logdet_backward(S)
    end

    # Backpropagate RB
    ΔX2_ΔC = L.RB.backward(cat(L.activation.backward(ΔS, S), ΔT; dims=3), (tensor_cat(X2, C)))
    ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=Int(size(ΔY)[N-1]/2))
    ΔX2 += ΔY2

    # Backpropagate 1x1 conv
    ΔX = L.C.inverse((tensor_cat(ΔX1, ΔX2), tensor_cat(X1, X2)))[1]

    return ΔX, X, ΔC
end

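The layer is an affine coupling conditioned on C: the 1x1 convolution shuffles channels, the result is split into (X1, X2), the residual block maps the concatenation of X2 and C to scale and shift, and the output is Y1 = S(X2, C) .* X1 + T(X2, C) with Y2 = X2 passed through unchanged. A minimal usage sketch based on the docstring above; the array sizes (64x64 inputs, 4 input channels, 2 conditioning channels, batch of 8) are illustrative and not part of the diff:

```julia
using InvertibleNetworks

# Hypothetical channel counts and sizes
n_in, n_cond, n_hidden = 4, 2, 16
CL = ConditionalLayerGlow(n_in, n_cond, n_hidden; logdet=true)

X = randn(Float32, 64, 64, n_in, 8)    # input tensor
C = randn(Float32, 64, 64, n_cond, 8)  # conditioning tensor

# Forward pass returns the transformed tensor and the log-determinant term
Y, logdet = CL.forward(X, C)

# Inverse pass recovers X from (Y, C) up to floating-point error
X_ = CL.inverse(Y, C)
```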

src/layers/layer_residual_block.jl (+31 −22)

@@ -20,7 +20,13 @@ or

 *Input*:

-- `n_in`, `n_hidden`: number of input and hidden channels
+- `n_in`: number of input channels
+
+- `n_hidden`: number of hidden channels
+
+- `n_out`: number of output channels
+
+- `activation`: activation type between conv layers and final output

 - `k1`, `k2`: kernel size of convolutions in residual block. `k1` is the kernel of the first and third
    operator, `k2` is the kernel size of the second operator.

@@ -67,6 +73,7 @@ struct ResidualBlock <: NeuralNetLayer
     fan::Bool
     strides
     pad
+    activation::ActivationFunction
 end

 @Flux.functor ResidualBlock

@@ -75,22 +82,24 @@ end
 # Constructors

 # Constructor
-function ResidualBlock(n_in, n_hidden; k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2)
+function ResidualBlock(n_in, n_hidden; n_out=nothing, activation::ActivationFunction=ReLUlayer(), k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2)
+    # default/legacy behaviour
+    isnothing(n_out) && (n_out = 2*n_in)

     k1 = Tuple(k1 for i=1:ndims)
     k2 = Tuple(k2 for i=1:ndims)
     # Initialize weights
     W1 = Parameter(glorot_uniform(k1..., n_in, n_hidden))
     W2 = Parameter(glorot_uniform(k2..., n_hidden, n_hidden))
-    W3 = Parameter(glorot_uniform(k1..., 2*n_in, n_hidden))
+    W3 = Parameter(glorot_uniform(k1..., n_out, n_hidden))
     b1 = Parameter(zeros(Float32, n_hidden))
     b2 = Parameter(zeros(Float32, n_hidden))

-    return ResidualBlock(W1, W2, W3, b1, b2, fan, (s1, s2), (p1, p2))
+    return ResidualBlock(W1, W2, W3, b1, b2, fan, (s1, s2), (p1, p2), activation)
 end

 # Constructor for given weights
-function ResidualBlock(W1, W2, W3, b1, b2; p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2)
+function ResidualBlock(W1, W2, W3, b1, b2; activation::ActivationFunction=ReLUlayer(), p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2)

     # Make weights parameters
     W1 = Parameter(W1)

@@ -99,7 +108,7 @@ function ResidualBlock(W1, W2, W3, b1, b2; p1=1, p2=1, s1=1, s2=1, fan=false, nd
     b1 = Parameter(b1)
     b2 = Parameter(b2)

-    return ResidualBlock(W1, W2, W3, b1, b2, fan, (s1, s2), (p1, p2))
+    return ResidualBlock(W1, W2, W3, b1, b2, fan, (s1, s2), (p1, p2), activation)
 end

 ResidualBlock3D(args...; kw...) = ResidualBlock(args...; kw..., ndims=3)

@@ -111,17 +120,17 @@ function forward(X1::AbstractArray{T, N}, RB::ResidualBlock; save=false) where {
     inds = [i!=(N-1) ? 1 : Colon() for i=1:N]

     Y1 = conv(X1, RB.W1.data; stride=RB.strides[1], pad=RB.pad[1]) .+ reshape(RB.b1.data, inds...)
-    X2 = ReLU(Y1)
+    X2 = RB.activation.forward(Y1)

     Y2 = X2 + conv(X2, RB.W2.data; stride=RB.strides[2], pad=RB.pad[2]) .+ reshape(RB.b2.data, inds...)
-    X3 = ReLU(Y2)
+    X3 = RB.activation.forward(Y2)

-    cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=RB.strides[1], padding=RB.pad[1])
+    cdims3 = DCDims(X1, RB.W3.data; stride=RB.strides[1], padding=RB.pad[1])
     Y3 = ∇conv_data(X3, RB.W3.data, cdims3)
     # Return if only recomputing state
     save && (return Y1, Y2, Y3)
     # Finish forward
-    RB.fan == true ? (return ReLU(Y3)) : (return GaLU(Y3))
+    RB.fan == true ? (return RB.activation.forward(Y3)) : (return GaLU(Y3))
 end

 # Backward

@@ -135,21 +144,21 @@ function backward(ΔX4::AbstractArray{T, N}, X1::AbstractArray{T, N},

     # Cdims
     cdims2 = DenseConvDims(Y2, RB.W2.data; stride=RB.strides[2], padding=RB.pad[2])
-    cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=RB.strides[1], padding=RB.pad[1])
+    cdims3 = DCDims(X1, RB.W3.data; stride=RB.strides[1], padding=RB.pad[1])

     # Backpropagate residual ΔX4 and compute gradients
-    RB.fan == true ? (ΔY3 = ReLUgrad(ΔX4, Y3)) : (ΔY3 = GaLUgrad(ΔX4, Y3))
+    RB.fan == true ? (ΔY3 = RB.activation.backward(ΔX4, Y3)) : (ΔY3 = GaLUgrad(ΔX4, Y3))
     ΔX3 = conv(ΔY3, RB.W3.data, cdims3)
-    ΔW3 = ∇conv_filter(ΔY3, ReLU(Y2), cdims3)
+    ΔW3 = ∇conv_filter(ΔY3, RB.activation.forward(Y2), cdims3)

-    ΔY2 = ReLUgrad(ΔX3, Y2)
+    ΔY2 = RB.activation.backward(ΔX3, Y2)
     ΔX2 = ∇conv_data(ΔY2, RB.W2.data, cdims2) + ΔY2
-    ΔW2 = ∇conv_filter(ReLU(Y1), ΔY2, cdims2)
+    ΔW2 = ∇conv_filter(RB.activation.forward(Y1), ΔY2, cdims2)
     Δb2 = sum(ΔY2, dims=dims)[inds...]

     cdims1 = DenseConvDims(X1, RB.W1.data; stride=RB.strides[1], padding=RB.pad[1])

-    ΔY1 = ReLUgrad(ΔX2, Y1)
+    ΔY1 = RB.activation.backward(ΔX2, Y1)
     ΔX1 = ∇conv_data(ΔY1, RB.W1.data, cdims1)
     ΔW1 = ∇conv_filter(X1, ΔY1, cdims1)
     Δb1 = sum(ΔY1, dims=dims)[inds...]

@@ -177,22 +186,22 @@ function jacobian(ΔX1::AbstractArray{T, N}, Δθ::Array{Parameter, 1},

     Y1 = conv(X1, RB.W1.data, cdims1) .+ reshape(RB.b1.data, inds...)
     ΔY1 = conv(ΔX1, RB.W1.data, cdims1) + conv(X1, Δθ[1].data, cdims1) .+ reshape(Δθ[4].data, inds...)
-    X2 = ReLU(Y1)
-    ΔX2 = ReLUgrad(ΔY1, Y1)
+    X2 = RB.activation.forward(Y1)
+    ΔX2 = RB.activation.backward(ΔY1, Y1)

     cdims2 = DenseConvDims(X2, RB.W2.data; stride=RB.strides[2], padding=RB.pad[2])

     Y2 = X2 + conv(X2, RB.W2.data, cdims2) .+ reshape(RB.b2.data, inds...)
     ΔY2 = ΔX2 + conv(ΔX2, RB.W2.data, cdims2) + conv(X2, Δθ[2].data, cdims2) .+ reshape(Δθ[5].data, inds...)
-    X3 = ReLU(Y2)
-    ΔX3 = ReLUgrad(ΔY2, Y2)
+    X3 = RB.activation.forward(Y2)
+    ΔX3 = RB.activation.backward(ΔY2, Y2)

     cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=RB.strides[1], padding=RB.pad[1])
     Y3 = ∇conv_data(X3, RB.W3.data, cdims3)
     ΔY3 = ∇conv_data(ΔX3, RB.W3.data, cdims3) + ∇conv_data(X3, Δθ[3].data, cdims3)
     if RB.fan == true
-        X4 = ReLU(Y3)
-        ΔX4 = ReLUgrad(ΔY3, Y3)
+        X4 = RB.activation.forward(Y3)
+        ΔX4 = RB.activation.backward(ΔY3, Y3)
     else
         ΔX4, X4 = GaLUjacobian(ΔY3, Y3)
     end

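With these changes the residual block's output channel count and internal activation are configurable instead of being hard-wired to 2*n_in and ReLU; the conditional Glow layer above relies on this by building a block with n_out=n_in on the concatenated (X2, C) input. A small sketch of the new keyword arguments; the channel counts and array sizes are illustrative, not taken from the diff:

```julia
using InvertibleNetworks

# Legacy behaviour is unchanged: n_out defaults to 2*n_in with ReLU activations
RB_default = ResidualBlock(8, 32)

# New keywords: explicit output width and a different activation
# (SigmoidLayer() appears elsewhere in this diff; any ActivationFunction should work)
RB_custom = ResidualBlock(8, 32; n_out=4, activation=SigmoidLayer(), fan=true)

X = randn(Float32, 32, 32, 8, 2)   # hypothetical 2D input: 32x32, 8 channels, batch of 2
Y = RB_custom.forward(X)           # output should have n_out (= 4) channels
```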
