# Constructor
function NetworkMultiScaleConditionalHINT(n_in::Int64, n_hidden::Int64, L::Int64, K::Int64;
-    split_scales=false, k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, logdet=true, ndims=2, squeezer::Squeezer=ShuffleLayer())
+    split_scales=false, k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, logdet=true, ndims=2, squeezer::Squeezer=ShuffleLayer(), activation::ActivationFunction=SigmoidLayer())

    AN_X = Array{ActNorm}(undef, L, K)
    AN_Y = Array{ActNorm}(undef, L, K)
@@ -89,7 +89,7 @@ function NetworkMultiScaleConditionalHINT(n_in::Int64, n_hidden::Int64, L::Int64
        for j=1:K
            AN_X[i, j] = ActNorm(n_in*4; logdet=logdet)
            AN_Y[i, j] = ActNorm(n_in*4; logdet=logdet)
-            CL[i, j] = ConditionalLayerHINT(n_in*4, n_hidden; permute=true, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, ndims=ndims)
+            CL[i, j] = ConditionalLayerHINT(n_in*4, n_hidden; permute=true, activation=activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, ndims=ndims)
        end
        n_in *= channel_factor
    end
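
A minimal construction sketch with the new keyword (hypothetical sizes; as the signature above shows, activation defaults to SigmoidLayer() when omitted):

    # 2D network: 4 input channels, 32 hidden channels, L=2 scales, K=4 blocks per scale.
    # The activation keyword is forwarded to every ConditionalLayerHINT block.
    CH = NetworkMultiScaleConditionalHINT(4, 32, 2, 4; split_scales=true, activation=SigmoidLayer())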
@@ -131,6 +131,28 @@ function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::NetworkMult
    logdet ? (return X, Y, logdet_) : (return X, Y)
end

+# Forward pass and compute logdet
+function forward_Y(Y::AbstractArray{T, N}, CH::NetworkMultiScaleConditionalHINT; logdet=false) where {T, N}
+    CH.split_scales && (Y_save = array_of_array(Y, CH.L-1))
+
+    logdet_ = 0f0
+    for i=1:CH.L
+        Y = CH.squeezer.forward(Y)
+        for j=1:CH.K
+            logdet ? (Y_, logdet1) = CH.AN_Y[i, j].forward(Y; logdet=true) : Y_ = CH.AN_Y[i, j].forward(Y; logdet=false)
+            logdet ? (Y, logdet2) = CH.CL[i, j].forward_Y(Y_; logdet=true) : Y = CH.CL[i, j].forward_Y(Y_; logdet=false)
+            logdet && (logdet_ += (logdet1 + logdet2))
+        end
+        if CH.split_scales && i < CH.L    # don't split after last iteration
+            Y, Zy = tensor_split(Y)
+            Y_save[i] = Zy
+            CH.XY_dims[i] = collect(size(Zy))
+        end
+    end
+    CH.split_scales && (Y = cat_states(Y_save, Y))
+    logdet ? (return Y, logdet_) : (return Y)
+end
+
# Inverse pass and compute gradients
function inverse(Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::NetworkMultiScaleConditionalHINT; logdet=nothing) where {T, N}
    isnothing(logdet) ? logdet = (CH.logdet && CH.is_reversed) : logdet = logdet
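
A call sketch for the reworked forward_Y added above (hypothetical tensor Y; CH as constructed earlier):

    # With logdet=true, the per-block ActNorm and HINT log-determinants are accumulated and returned.
    Zy, logdet_y = forward_Y(Y, CH; logdet=true)
    # The previous call shape still works, since logdet defaults to false:
    Zy = forward_Y(Y, CH)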
@@ -234,26 +256,7 @@ function backward_inv(ΔX, ΔY, X, Y, CH::NetworkMultiScaleConditionalHINT)
    end
end

-# Forward pass and compute logdet
-function forward_Y(Y::AbstractArray{T, N}, CH::NetworkMultiScaleConditionalHINT) where {T, N}
-    CH.split_scales && (Y_save = array_of_array(Y, CH.L-1))
-
-    for i=1:CH.L
-        Y = CH.squeezer.forward(Y)
-        for j=1:CH.K
-            Y_ = CH.AN_Y[i, j].forward(Y; logdet=false)
-            Y = CH.CL[i, j].forward_Y(Y_)
-        end
-        if CH.split_scales && i < CH.L    # don't split after last iteration
-            Y, Zy = tensor_split(Y)
-            Y_save[i] = Zy
-            CH.XY_dims[i] = collect(size(Zy))
-        end
-    end
-    CH.split_scales && (Y = cat_states(Y_save, Y))
-    return Y
-end

# Inverse pass and compute gradients
function inverse_Y(Zy::AbstractArray{T, N}, CH::NetworkMultiScaleConditionalHINT) where {T, N}
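
A round-trip sanity check tying forward_Y to inverse_Y (a sketch, assuming Float32 input with spatial dimensions divisible by 2^L):

    Y = randn(Float32, 16, 16, 4, 2)                     # matches n_in=4 and L=2 from the sketch above
    Zy = forward_Y(Y, CH)                                # populates CH.XY_dims when split_scales=true
    @assert isapprox(Y, inverse_Y(Zy, CH); rtol=1f-3)    # should recover Y up to numerical error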