Commit 30ad808

Merge pull request #37 from slimgroup/typefix
Fix signature types

2 parents 424ffcb + 3c23f52

42 files changed: +458 -471 lines (large commit; only part of the diff is shown below)

Project.toml

+2 -2

@@ -1,7 +1,7 @@
 name = "InvertibleNetworks"
 uuid = "b7115f24-5f92-4794-81e8-23b0ddb121d3"
-authors = ["Philipp Witte <[email protected]>", "Ali Siahkoohi <[email protected]>", "Mathias Louboutin <[email protected]>", "Gabrio Rizzuti <[email protected]>", "Felix J. herrmann <[email protected]>"]
-version = "2.0.1"
+authors = ["Philipp Witte <[email protected]>", "Ali Siahkoohi <[email protected]>", "Mathias Louboutin <[email protected]>", "Gabrio Rizzuti <[email protected]>", "Rafael Orozco <[email protected]>", "Felix J. herrmann <[email protected]>"]
+version = "2.0.2"

 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

src/InvertibleNetworks.jl

+2 -2

@@ -28,7 +28,7 @@ kernel_size(::DenseConvDims{N,K,C_in,C_out,S,P,D,F}) where {N,K,C_in,C_out,S,P,D
 channels_in(::DenseConvDims{N,K,C_in,C_out,S,P,D,F}) where {N,K,C_in,C_out,S,P,D,F} = C_in
 channels_out(::DenseConvDims{N,K,C_in,C_out,S,P,D,F}) where {N,K,C_in,C_out,S,P,D,F} = C_out

-function DCDims(X::AbstractArray{Float32, N}, W::AbstractArray{Float32, N}; stride=1, padding=1, nc=nothing) where N
+function DCDims(X::AbstractArray{T, N}, W::AbstractArray{T, N}; stride=1, padding=1, nc=nothing) where {T, N}
     sw = size(W)
     isnothing(nc) && (nc = sw[N-1])
     sx = (size(X)[1:N-2]..., nc, size(X)[end])
@@ -76,6 +76,6 @@ include("networks/invertible_network_conditional_hint_multiscale.jl")
 include("utils/jacobian.jl")

 # gpu
-include("utils/cuda.jl")
+include("utils/compute_utils.jl")

 end

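The recurring change in this commit is loosening AbstractArray{Float32, N} annotations to AbstractArray{T, N} where {T, N}. A toy sketch of the dispatch difference (standalone functions written for this note, not package code):

# Illustration only: the old-style annotation restricts dispatch to Float32
# arrays, the new-style annotation accepts any element type.
f_old(X::AbstractArray{Float32, N}) where N = "Float32 arrays only"
f_new(X::AbstractArray{T, N}) where {T, N}  = "arrays of any eltype, here $T"

f_new(randn(Float32, 8, 8, 2, 1))    # matches, as before
f_new(randn(Float64, 8, 8, 2, 1))    # matches only with the generic signature
# f_old(randn(Float64, 8, 8, 2, 1))  # would throw a MethodError
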
src/conditional_layers/conditional_layer_hint.jl

+16 -16

@@ -82,7 +82,7 @@ end
 # 3D Constructor from input dimensions
 ConditionalLayerHINT3D(args...; kw...) = ConditionalLayerHINT(args...; kw..., ndims=3)

-function forward(X, Y, CH::ConditionalLayerHINT; logdet=nothing)
+function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing) where {T, N}
     isnothing(logdet) ? logdet = (CH.logdet && ~CH.is_reversed) : logdet = logdet

     # Y-lane
@@ -99,7 +99,7 @@ function forward(X, Y, CH::ConditionalLayerHINT; logdet=nothing)
     logdet ? (return Zx, Zy, logdet1 + logdet2 + logdet3) : (return Zx, Zy)
 end

-function inverse(Zx, Zy, CH::ConditionalLayerHINT; logdet=nothing)
+function inverse(Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing) where {T, N}
     isnothing(logdet) ? logdet = (CH.logdet && CH.is_reversed) : logdet = logdet

     # Y-lane
@@ -117,7 +117,7 @@ function inverse(Zx, Zy, CH::ConditionalLayerHINT; logdet=nothing)
     logdet ? (return X, Y, logdet1 + logdet2 + logdet3) : (return X, Y)
 end

-function backward(ΔZx, ΔZy, Zx, Zy, CH::ConditionalLayerHINT; logdet=nothing, set_grad::Bool=true)
+function backward(ΔZx::AbstractArray{T, N}, ΔZy::AbstractArray{T, N}, Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing, set_grad::Bool=true) where {T, N}
     isnothing(logdet) ? logdet = (CH.logdet && ~CH.is_reversed) : logdet = logdet

     # Y-lane
@@ -133,12 +133,12 @@ function backward(ΔZx, ΔZy, Zx, Zy, CH::ConditionalLayerHINT; logdet=nothing,

     # X-lane: conditional layer
     if set_grad
-        ΔYp_, ΔX, X = CH.CL_YX.backward(ΔYp.*0f0, ΔZx, Yp, Zx)[[1,2,4]]
+        ΔYp_, ΔX, X = CH.CL_YX.backward(ΔYp.*0, ΔZx, Yp, Zx)[[1,2,4]]
     else
         if logdet
-            ΔYp_, ΔX, Δθ_CLYX, _, X, ∇logdet_CLYX = CH.CL_YX.backward(ΔYp.*0f0, ΔZx, Yp, Zx; set_grad=set_grad)
+            ΔYp_, ΔX, Δθ_CLYX, _, X, ∇logdet_CLYX = CH.CL_YX.backward(ΔYp.*0, ΔZx, Yp, Zx; set_grad=set_grad)
         else
-            ΔYp_, ΔX, Δθ_CLYX, _, X = CH.CL_YX.backward(ΔYp.*0f0, ΔZx, Yp, Zx; set_grad=set_grad)
+            ΔYp_, ΔX, Δθ_CLYX, _, X = CH.CL_YX.backward(ΔYp.*0, ΔZx, Yp, Zx; set_grad=set_grad)
         end
     end
     ΔYp += ΔYp_
@@ -178,14 +178,14 @@ function backward(ΔZx, ΔZy, Zx, Zy, CH::ConditionalLayerHINT; logdet=nothing,
             return ΔX, ΔY, Δθ, X, Y
         else
             ∇logdet = cat(∇logdet_CLX, ∇logdet_CLY, ∇logdet_CLYX; dims=1)
-            ~isnothing(CH.C_X) && (∇logdet = cat(∇logdet, 0f0*Δθ_CX; dims=1))
-            ~isnothing(CH.C_Y) && (∇logdet = cat(∇logdet, 0f0*Δθ_CY; dims=1))
+            ~isnothing(CH.C_X) && (∇logdet = cat(∇logdet, 0*Δθ_CX; dims=1))
+            ~isnothing(CH.C_Y) && (∇logdet = cat(∇logdet, 0*Δθ_CY; dims=1))
             return ΔX, ΔY, Δθ, X, Y, ∇logdet
         end
     end
 end

-function backward_inv(ΔX, ΔY, X, Y, CH::ConditionalLayerHINT)
+function backward_inv(ΔX::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::ConditionalLayerHINT) where {T, N}

     # 1x1 Convolutions
     if isnothing(CH.C_X) || isnothing(CH.C_Y)
@@ -200,7 +200,7 @@ function backward_inv(ΔX, ΔY, X, Y, CH::ConditionalLayerHINT)
     ΔX, X = backward_inv(ΔXp, Xp, CH.CL_X)

     # X-lane: conditional layer
-    ΔYp_, ΔZx, Zx = backward_inv(ΔYp.*0f0, ΔX, Yp, X, CH.CL_YX)[[1,2,4]]
+    ΔYp_, ΔZx, Zx = backward_inv(ΔYp.*0, ΔX, Yp, X, CH.CL_YX)[[1,2,4]]
     ΔYp += ΔYp_

     # Y-lane
@@ -209,14 +209,14 @@ function backward_inv(ΔX, ΔY, X, Y, CH::ConditionalLayerHINT)
     return ΔZx, ΔZy, Zx, Zy
 end

-function forward_Y(Y, CH::ConditionalLayerHINT)
+function forward_Y(Y::AbstractArray{T, N}, CH::ConditionalLayerHINT) where {T, N}
     ~isnothing(CH.C_Y) ? (Yp = CH.C_Y.forward(Y)) : (Yp = copy(Y))
     Zy = CH.CL_Y.forward(Yp; logdet=false)
     return Zy

 end

-function inverse_Y(Zy, CH::ConditionalLayerHINT)
+function inverse_Y(Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT) where {T, N}
     Yp = CH.CL_Y.inverse(Zy; logdet=false)
     ~isnothing(CH.C_Y) ? (Y = CH.C_Y.inverse(Yp)) : (Y = copy(Yp))
     return Y
@@ -225,7 +225,7 @@ end

 ## Jacobian-related utils

-function jacobian(ΔX, ΔY, Δθ::Array{Parameter, 1}, X, Y, CH::ConditionalLayerHINT; logdet=nothing)
+function jacobian(ΔX::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing) where {T, N}
     isnothing(logdet) ? logdet = (CH.logdet && ~CH.is_reversed) : logdet = logdet

     # Selecting parameters
@@ -274,16 +274,16 @@ function jacobian(ΔX, ΔY, Δθ::Array{Parameter, 1}, X, Y, CH::ConditionalLaye

     if logdet
         GNΔθ = cat(GNΔθ_X, GNΔθ_Y, GNΔθ_YX; dims=1)
-        ~isnothing(CH.C_X) && (GNΔθ = cat(GNΔθ, 0f0.*Δθ_CX; dims=1))
-        ~isnothing(CH.C_Y) && (GNΔθ = cat(GNΔθ, 0f0.*Δθ_CY; dims=1))
+        ~isnothing(CH.C_X) && (GNΔθ = cat(GNΔθ, 0 .*Δθ_CX; dims=1))
+        ~isnothing(CH.C_Y) && (GNΔθ = cat(GNΔθ, 0 .*Δθ_CY; dims=1))
         return ΔZx, ΔZy, Zx, Zy, logdet1 + logdet2 + logdet3, GNΔθ
     else
         return ΔZx, ΔZy, Zx, Zy
     end

 end

-adjointJacobian(ΔZx, ΔZy, Zx, Zy, CH::ConditionalLayerHINT; logdet=nothing) = backward(ΔZx, ΔZy, Zx, Zy, CH; set_grad=false, logdet=logdet)
+adjointJacobian(ΔZx::AbstractArray{T, N}, ΔZy::AbstractArray{T, N}, Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing) where {T, N} = backward(ΔZx, ΔZy, Zx, Zy, CH; set_grad=false, logdet=logdet)


 ## Other utils

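Several scalar literals in this file also change from 0f0 to plain 0. A standalone sketch (not package code) of why that matters once the arrays are type-generic: an integer zero is promotion-neutral, while a Float32 literal silently widens lower-precision element types.

x16 = rand(Float16, 4)

eltype(0f0 .* x16)            # Float32: the Float32 literal widens the Float16 data
eltype(0 .* x16)              # Float16: the integer literal preserves the element type
eltype(0 .* rand(Float64, 4)) # Float64: exact for wider element types as well
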
src/conditional_layers/conditional_layer_residual_block.jl

+2 -2

@@ -84,7 +84,7 @@ function ConditionalResidualBlock(nx1, nx2, nx_in, ny1, ny2, ny_in, n_hidden, ba
     return ConditionalResidualBlock(W0, W1, W2, W3, b0, b1, b2, cdims1, cdims2, cdims3)
 end

-function forward(X0, D, RB::ConditionalResidualBlock; save=false)
+function forward(X0::AbstractArray{T, N}, D::AbstractArray{T, N}, RB::ConditionalResidualBlock; save=false) where {T, N}

     # Dimensions of input image X
     nx1, nx2, nx_in, batchsize = size(X0)
@@ -110,7 +110,7 @@ function forward(X0, D, RB::ConditionalResidualBlock; save=false)
 end


-function backward(ΔX4, ΔD, X0, D, RB::ConditionalResidualBlock; set_grad::Bool=true)
+function backward(ΔX4::AbstractArray{T, N}, ΔD::AbstractArray{T, N}, X0::AbstractArray{T, N}, D::AbstractArray{T, N}, RB::ConditionalResidualBlock; set_grad::Bool=true) where {T, N}

     # Recompute forward states from input X
     Y0, Y1, Y2, Y3, X1, X2, X3 = forward(X0, D, RB; save=true)

src/layers/invertible_layer_actnorm.jl

+8 -8

@@ -57,7 +57,7 @@ function ActNorm(k; logdet=false)
 end

 # 2-3D Foward pass: Input X, Output Y
-function forward(X::AbstractArray{Float32, N}, AN::ActNorm; logdet=nothing) where N
+function forward(X::AbstractArray{T, N}, AN::ActNorm; logdet=nothing) where {T, N}
     isnothing(logdet) ? logdet = (AN.logdet && ~AN.is_reversed) : logdet = logdet
     inds = [i!=(N-1) ? 1 : (:) for i=1:N]
     dims = collect(1:N-1); dims[end] +=1
@@ -67,7 +67,7 @@ function forward(X::AbstractArray{Float32, N}, AN::ActNorm; logdet=nothing) wher
     if isnothing(AN.s.data) && !AN.is_reversed
         μ = mean(X; dims=dims)[inds...]
         σ_sqr = var(X; dims=dims)[inds...]
-        AN.s.data = 1f0 ./ sqrt.(σ_sqr)
+        AN.s.data = 1 ./ sqrt.(σ_sqr)
         AN.b.data = -μ ./ sqrt.(σ_sqr)
     end
     Y = X .* reshape(AN.s.data, inds...) .+ reshape(AN.b.data, inds...)
@@ -77,7 +77,7 @@ function forward(X::AbstractArray{Float32, N}, AN::ActNorm; logdet=nothing) wher
 end

 # 2-3D Inverse pass: Input Y, Output X
-function inverse(Y::AbstractArray{Float32, N}, AN::ActNorm; logdet=nothing) where N
+function inverse(Y::AbstractArray{T, N}, AN::ActNorm; logdet=nothing) where {T, N}
     isnothing(logdet) ? logdet = (AN.logdet && AN.is_reversed) : logdet = logdet
     inds = [i!=(N-1) ? 1 : (:) for i=1:N]
     dims = collect(1:N-1); dims[end] +=1
@@ -97,7 +97,7 @@ function inverse(Y::AbstractArray{Float32, N}, AN::ActNorm; logdet=nothing) wher
 end

 # 2-3D Backward pass: Input (ΔY, Y), Output (ΔY, Y)
-function backward(ΔY::AbstractArray{Float32, N}, Y::AbstractArray{Float32, N}, AN::ActNorm; set_grad::Bool = true) where N
+function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, AN::ActNorm; set_grad::Bool = true) where {T, N}
     inds = [i!=(N-1) ? 1 : (:) for i=1:N]
     dims = collect(1:N-1); dims[end] +=1
     nn = size(ΔY)[1:N-2]
@@ -118,13 +118,13 @@ function backward(ΔY::AbstractArray{Float32, N}, Y::AbstractArray{Float32, N},
     if set_grad
         return ΔX, X
     else
-        AN.logdet ? (return ΔX, Δθ, X, [Parameter(Δs_), Parameter(0f0*Δb)]) : (return ΔX, Δθ, X)
+        AN.logdet ? (return ΔX, Δθ, X, [Parameter(Δs_), Parameter(0*Δb)]) : (return ΔX, Δθ, X)
     end
 end

 ## Reverse-layer functions
 # 2-3D Backward pass (inverse): Input (ΔX, X), Output (ΔX, X)
-function backward_inv(ΔX::AbstractArray{Float32, N}, X::AbstractArray{Float32, N}, AN::ActNorm; set_grad::Bool = true) where N
+function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, AN::ActNorm; set_grad::Bool = true) where {T, N}
     inds = [i!=(N-1) ? 1 : (:) for i=1:N]
     dims = collect(1:N-1); dims[end] +=1
     nn = size(ΔX)[1:N-2]
@@ -151,7 +151,7 @@ end

 ## Jacobian-related functions
 # 2-£D
-function jacobian(ΔX::AbstractArray{Float32, N}, Δθ::AbstractArray{Parameter, 1}, X::AbstractArray{Float32, N}, AN::ActNorm; logdet=nothing) where N
+function jacobian(ΔX::AbstractArray{T, N}, Δθ::AbstractArray{Parameter, 1}, X::AbstractArray{T, N}, AN::ActNorm; logdet=nothing) where {T, N}
     isnothing(logdet) ? logdet = (AN.logdet && ~AN.is_reversed) : logdet = logdet
     inds = [i!=(N-1) ? 1 : (:) for i=1:N]
     nn = size(ΔX)[1:N-2]
@@ -175,7 +175,7 @@ function jacobian(ΔX::AbstractArray{Float32, N}, Δθ::AbstractArray{Parameter,
 end

 # 2D/3D
-function adjointJacobian(ΔY::AbstractArray{Float32, N}, Y::AbstractArray{Float32, N}, AN::ActNorm) where N
+function adjointJacobian(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, AN::ActNorm) where {T, N}
     return backward(ΔY, Y, AN; set_grad=false)
 end

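A minimal usage sketch of the widened ActNorm methods (sizes are illustrative, and the AN.forward/AN.inverse accessor style is assumed from the package examples): Float64 input now dispatches to the same methods that previously accepted only Float32.

using InvertibleNetworks

AN = ActNorm(4; logdet=true)          # layer over 4 channels
X  = randn(Float64, 16, 16, 4, 2)     # nx x ny x n_channel x batchsize
Y, lgdet = AN.forward(X)              # scale/shift are initialized from the statistics of X
X_ = AN.inverse(Y)                    # recovers X up to floating-point error
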
src/layers/invertible_layer_basic.jl

+10 -10

@@ -87,7 +87,7 @@ end
 CouplingLayerBasic3D(args...;kw...) = CouplingLayerBasic(args...; kw..., ndims=3)

 # 2D/3D Forward pass: Input X, Output Y
-function forward(X1::AbstractArray{Float32, N}, X2::AbstractArray{Float32, N}, L::CouplingLayerBasic; save::Bool=false, logdet=nothing) where N
+function forward(X1::AbstractArray{T, N}, X2::AbstractArray{T, N}, L::CouplingLayerBasic; save::Bool=false, logdet=nothing) where {T, N}
     isnothing(logdet) ? logdet = (L.logdet && ~L.is_reversed) : logdet = logdet

     # Coupling layer
@@ -103,13 +103,13 @@ function forward(X1::AbstractArray{Float32, N}, X2::AbstractArray{Float32, N}, L
 end

 # 2D/3D Inverse pass: Input Y, Output X
-function inverse(Y1::AbstractArray{Float32, N}, Y2::AbstractArray{Float32, N}, L::CouplingLayerBasic; save::Bool=false, logdet=nothing) where N
+function inverse(Y1::AbstractArray{T, N}, Y2::AbstractArray{T, N}, L::CouplingLayerBasic; save::Bool=false, logdet=nothing) where {T, N}
     isnothing(logdet) ? logdet = (L.logdet && L.is_reversed) : logdet = logdet

     # Inverse layer
     logS_T1, logS_T2 = tensor_split(L.RB.forward(Y1))
     S = L.activation.forward(logS_T1)
-    X2 = (Y2 - logS_T2) ./ (S .+ eps(1f0)) # add epsilon to avoid division by 0
+    X2 = (Y2 - logS_T2) ./ (S .+ eps(T)) # add epsilon to avoid division by 0

     if logdet
         save == true ? (return Y1, X2, -coupling_logdet_forward(S), S) : (return Y1, X2, -coupling_logdet_forward(S))
@@ -119,7 +119,7 @@ function inverse(Y1::AbstractArray{Float32, N}, Y2::AbstractArray{Float32, N}, L
 end

 # 2D/3D Backward pass: Input (ΔY, Y), Output (ΔX, X)
-function backward(ΔY1, ΔY2, Y1, Y2, L::CouplingLayerBasic; set_grad::Bool=true)
+function backward(ΔY1::AbstractArray{T, N}, ΔY2::AbstractArray{T, N}, Y1::AbstractArray{T, N}, Y2::AbstractArray{T, N}, L::CouplingLayerBasic; set_grad::Bool=true) where {T, N}

     # Recompute forward state
     X1, X2, S = inverse(Y1, Y2, L; save=true, logdet=false)
@@ -136,7 +136,7 @@ function backward(ΔY1, ΔY2, Y1, Y2, L::CouplingLayerBasic; set_grad::Bool=true
     else
         ΔX1, Δθ = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), X1; set_grad=set_grad)
         if L.logdet
-            _, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(coupling_logdet_backward(S), S), 0f0.*ΔT), X1; set_grad=set_grad)
+            _, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(coupling_logdet_backward(S), S), 0 .*ΔT), X1; set_grad=set_grad)
         end
         ΔX1 += ΔY1
     end
@@ -149,7 +149,7 @@ function backward(ΔY1, ΔY2, Y1, Y2, L::CouplingLayerBasic; set_grad::Bool=true
 end

 # 2D/3D Reverse backward pass: Input (ΔX, X), Output (ΔY, Y)
-function backward_inv(ΔX1, ΔX2, X1, X2, L::CouplingLayerBasic; set_grad::Bool=true)
+function backward_inv(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, X1::AbstractArray{T, N}, X2::AbstractArray{T, N}, L::CouplingLayerBasic; set_grad::Bool=true) where {T, N}

     # Recompute inverse state
     Y1, Y2, S = forward(X1, X2, L; save=true, logdet=false)
@@ -179,9 +179,9 @@ end
 ## Jacobian-related functions

 # 2D
-function jacobian(ΔX1::AbstractArray{Float32, N}, ΔX2::AbstractArray{Float32, N}, Δθ::AbstractArray{Parameter, 1},
-                  X1::AbstractArray{Float32, N}, X2::AbstractArray{Float32, N}, L::CouplingLayerBasic;
-                  save=false, logdet=nothing) where N
+function jacobian(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, Δθ::AbstractArray{Parameter, 1},
+                  X1::AbstractArray{T, N}, X2::AbstractArray{T, N}, L::CouplingLayerBasic;
+                  save=false, logdet=nothing) where {T, N}
     isnothing(logdet) ? logdet = (L.logdet && ~L.is_reversed) : logdet = logdet

     logS_T1, logS_T2 = tensor_split(L.RB.forward(X1))
@@ -203,7 +203,7 @@ function jacobian(ΔX1::AbstractArray{Float32, N}, ΔX2::AbstractArray{Float32,
 end

 # 2D/3D
-function adjointJacobian(ΔY1, ΔY2, Y1, Y2, L::CouplingLayerBasic)
+function adjointJacobian(ΔY1::AbstractArray{T, N}, ΔY2::AbstractArray{T, N}, Y1::AbstractArray{T, N}, Y2::AbstractArray{T, N}, L::CouplingLayerBasic) where {T, N}
     return backward(ΔY1, ΔY2, Y1, Y2, L; set_grad=false)
 end

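One change in this file goes beyond the signatures: the division stabilizer in inverse switches from eps(1f0) to eps(T), so the epsilon matches the working precision instead of always being the Float32 machine epsilon. A quick illustration (values as printed by the Julia REPL):

eps(1f0)       # 1.1920929f-7, the Float32 machine epsilon the old code always used
eps(Float32)   # 1.1920929f-7, identical
eps(Float64)   # 2.220446049250313e-16, the much smaller value now used for Float64 data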