# 3D Constructor from input dimensions
ConditionalLayerHINT3D(args...; kw...) = ConditionalLayerHINT(args...; kw..., ndims=3)

-function forward(X, Y, CH::ConditionalLayerHINT; logdet=nothing)
+function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing) where {T, N}
isnothing(logdet) ? logdet = (CH.logdet && ~CH.is_reversed) : logdet = logdet

# Y-lane
@@ -99,7 +99,7 @@ function forward(X, Y, CH::ConditionalLayerHINT; logdet=nothing)
logdet ? (return Zx, Zy, logdet1 + logdet2 + logdet3) : (return Zx, Zy)
end

-function inverse(Zx, Zy, CH::ConditionalLayerHINT; logdet=nothing)
+function inverse(Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing) where {T, N}
isnothing(logdet) ? logdet = (CH.logdet && CH.is_reversed) : logdet = logdet

# Y-lane
@@ -117,7 +117,7 @@ function inverse(Zx, Zy, CH::ConditionalLayerHINT; logdet=nothing)
logdet ? (return X, Y, logdet1 + logdet2 + logdet3) : (return X, Y)
end

-function backward(ΔZx, ΔZy, Zx, Zy, CH::ConditionalLayerHINT; logdet=nothing, set_grad::Bool=true)
+function backward(ΔZx::AbstractArray{T, N}, ΔZy::AbstractArray{T, N}, Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing, set_grad::Bool=true) where {T, N}
isnothing(logdet) ? logdet = (CH.logdet && ~CH.is_reversed) : logdet = logdet

# Y-lane
@@ -133,12 +133,12 @@ function backward(ΔZx, ΔZy, Zx, Zy, CH::ConditionalLayerHINT; logdet=nothing,

# X-lane: conditional layer
if set_grad
-ΔYp_, ΔX, X = CH.CL_YX.backward(ΔYp.*0f0, ΔZx, Yp, Zx)[[1,2,4]]
+ΔYp_, ΔX, X = CH.CL_YX.backward(ΔYp.*0, ΔZx, Yp, Zx)[[1,2,4]]
else
if logdet
-ΔYp_, ΔX, Δθ_CLYX, _, X, ∇logdet_CLYX = CH.CL_YX.backward(ΔYp.*0f0, ΔZx, Yp, Zx; set_grad=set_grad)
+ΔYp_, ΔX, Δθ_CLYX, _, X, ∇logdet_CLYX = CH.CL_YX.backward(ΔYp.*0, ΔZx, Yp, Zx; set_grad=set_grad)
else
-ΔYp_, ΔX, Δθ_CLYX, _, X = CH.CL_YX.backward(ΔYp.*0f0, ΔZx, Yp, Zx; set_grad=set_grad)
+ΔYp_, ΔX, Δθ_CLYX, _, X = CH.CL_YX.backward(ΔYp.*0, ΔZx, Yp, Zx; set_grad=set_grad)
end
end
ΔYp += ΔYp_
@@ -178,14 +178,14 @@ function backward(ΔZx, ΔZy, Zx, Zy, CH::ConditionalLayerHINT; logdet=nothing,
return ΔX, ΔY, Δθ, X, Y
else
∇logdet = cat(∇logdet_CLX, ∇logdet_CLY, ∇logdet_CLYX; dims=1)
-~isnothing(CH.C_X) && (∇logdet = cat(∇logdet, 0f0*Δθ_CX; dims=1))
-~isnothing(CH.C_Y) && (∇logdet = cat(∇logdet, 0f0*Δθ_CY; dims=1))
+~isnothing(CH.C_X) && (∇logdet = cat(∇logdet, 0*Δθ_CX; dims=1))
+~isnothing(CH.C_Y) && (∇logdet = cat(∇logdet, 0*Δθ_CY; dims=1))
return ΔX, ΔY, Δθ, X, Y, ∇logdet
end
end
end
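
Note on the `ΔYp.*0f0` → `ΔYp.*0` replacements above: an integer zero is promotion-neutral, so the zeroed Y-lane gradient keeps whatever element type the input arrays carry instead of being pinned to `Float32`. A minimal, self-contained sketch of that promotion behaviour (illustrative only, not part of the diff):

```julia
# x .* 0 preserves eltype(x) for any floating-point type, while the Float32
# literal 0f0 widens narrower types (e.g. Float16 is promoted to Float32).
for T in (Float16, Float32, Float64)
    x = ones(T, 4)
    @assert eltype(x .* 0) == T                     # element type preserved
end
@assert eltype(ones(Float16, 4) .* 0f0) == Float32  # what 0f0 would do to Float16
```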

-function backward_inv(ΔX, ΔY, X, Y, CH::ConditionalLayerHINT)
+function backward_inv(ΔX::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::ConditionalLayerHINT) where {T, N}

# 1x1 Convolutions
if isnothing(CH.C_X) || isnothing(CH.C_Y)
@@ -200,7 +200,7 @@ function backward_inv(ΔX, ΔY, X, Y, CH::ConditionalLayerHINT)
ΔX, X = backward_inv(ΔXp, Xp, CH.CL_X)

# X-lane: conditional layer
-ΔYp_, ΔZx, Zx = backward_inv(ΔYp.*0f0, ΔX, Yp, X, CH.CL_YX)[[1,2,4]]
+ΔYp_, ΔZx, Zx = backward_inv(ΔYp.*0, ΔX, Yp, X, CH.CL_YX)[[1,2,4]]
ΔYp += ΔYp_

# Y-lane
@@ -209,14 +209,14 @@ function backward_inv(ΔX, ΔY, X, Y, CH::ConditionalLayerHINT)
return ΔZx, ΔZy, Zx, Zy
end

-function forward_Y(Y, CH::ConditionalLayerHINT)
+function forward_Y(Y::AbstractArray{T, N}, CH::ConditionalLayerHINT) where {T, N}
~isnothing(CH.C_Y) ? (Yp = CH.C_Y.forward(Y)) : (Yp = copy(Y))
Zy = CH.CL_Y.forward(Yp; logdet=false)
return Zy

end

-function inverse_Y(Zy, CH::ConditionalLayerHINT)
+function inverse_Y(Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT) where {T, N}
Yp = CH.CL_Y.inverse(Zy; logdet=false)
~isnothing(CH.C_Y) ? (Y = CH.C_Y.inverse(Yp)) : (Y = copy(Yp))
return Y
end

## Jacobian-related utils

-function jacobian(ΔX, ΔY, Δθ::Array{Parameter, 1}, X, Y, CH::ConditionalLayerHINT; logdet=nothing)
+function jacobian(ΔX::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing) where {T, N}
isnothing(logdet) ? logdet = (CH.logdet && ~CH.is_reversed) : logdet = logdet

# Selecting parameters
@@ -274,16 +274,16 @@ function jacobian(ΔX, ΔY, Δθ::Array{Parameter, 1}, X, Y, CH::ConditionalLaye

if logdet
GNΔθ = cat(GNΔθ_X, GNΔθ_Y, GNΔθ_YX; dims=1)
-~isnothing(CH.C_X) && (GNΔθ = cat(GNΔθ, 0f0 .* Δθ_CX; dims=1))
-~isnothing(CH.C_Y) && (GNΔθ = cat(GNΔθ, 0f0 .* Δθ_CY; dims=1))
+~isnothing(CH.C_X) && (GNΔθ = cat(GNΔθ, 0 .* Δθ_CX; dims=1))
+~isnothing(CH.C_Y) && (GNΔθ = cat(GNΔθ, 0 .* Δθ_CY; dims=1))
return ΔZx, ΔZy, Zx, Zy, logdet1 + logdet2 + logdet3, GNΔθ
else
return ΔZx, ΔZy, Zx, Zy
end

end

-adjointJacobian(ΔZx, ΔZy, Zx, Zy, CH::ConditionalLayerHINT; logdet=nothing) = backward(ΔZx, ΔZy, Zx, Zy, CH; set_grad=false, logdet=logdet)
+adjointJacobian(ΔZx::AbstractArray{T, N}, ΔZy::AbstractArray{T, N}, Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing) where {T, N} = backward(ΔZx, ΔZy, Zx, Zy, CH; set_grad=false, logdet=logdet)

## Other utils
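
With the `AbstractArray{T, N}` annotations in place, every entry point of `ConditionalLayerHINT` (`forward`, `inverse`, `backward`, `jacobian`, `adjointJacobian`) dispatches on arbitrary element types and tensor ranks, including the 5D volume case reached through `ConditionalLayerHINT3D`. A rough usage sketch follows; the constructor arguments `(n_in, n_hidden)` are an assumed form not shown in this diff, and the calls are module-qualified in case `forward`/`inverse` are not exported:

```julia
using InvertibleNetworks

# Assumed constructor form (not shown in this diff): input channels and hidden width.
n_in, n_hidden, batch = 4, 8, 2
CH = ConditionalLayerHINT(n_in, n_hidden)

# 4D Float32 tensors dispatch to the annotated forward/inverse methods;
# 5D volumes do too when the layer is built via ConditionalLayerHINT3D(args...; kw...).
X = randn(Float32, 16, 16, n_in, batch)
Y = randn(Float32, 16, 16, n_in, batch)
Zx, Zy = InvertibleNetworks.forward(X, Y, CH; logdet=false)
Xrec, Yrec = InvertibleNetworks.inverse(Zx, Zy, CH; logdet=false)  # (Xrec, Yrec) ≈ (X, Y)
```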