@@ -40,6 +40,8 @@ export NetworkGlow, NetworkGlow3D
40
40
41
41
- `squeezer` : squeeze operation applied at each multiscale level (e.g. `ShuffleLayer()`)
42
42
43
+ - `logdet` : boolean to turn on/off logdet term tracking and gradient calculation
44
+
43
45
*Output*:
44
46
45
47
- `G`: invertible Glow network.
@@ -67,12 +69,13 @@ struct NetworkGlow <: InvertibleNetwork
67
69
K:: Int64
68
70
squeezer:: Squeezer
69
71
split_scales:: Bool
72
+ logdet:: Bool
70
73
end
71
74
72
75
@Flux . functor NetworkGlow
73
76
74
77
# Constructor
75
- function NetworkGlow (n_in, n_hidden, L, K; nx= nothing , dense= false , freeze_conv= false , split_scales= false , k1= 3 , k2= 1 , p1= 1 , p2= 0 , s1= 1 , s2= 1 , ndims= 2 , squeezer:: Squeezer = ShuffleLayer (), activation:: ActivationFunction = SigmoidLayer ())
78
+ function NetworkGlow (n_in, n_hidden, L, K; logdet = true , nx= nothing , dense= false , freeze_conv= false , split_scales= false , k1= 3 , k2= 1 , p1= 1 , p2= 0 , s1= 1 , s2= 1 , ndims= 2 , squeezer:: Squeezer = ShuffleLayer (), activation:: ActivationFunction = SigmoidLayer ())
76
79
(n_in == 1 ) && (split_scales = true ) # Need extra channels for coupling layer
77
80
(dense && isnothing (nx)) && error (" Dense network needs nx as kwarg input" )
78
81
@@ -91,29 +94,28 @@ function NetworkGlow(n_in, n_hidden, L, K; nx=nothing, dense=false, freeze_conv=
91
94
n_in *= channel_factor # squeeze if split_scales is turned on
92
95
(dense && split_scales) && (nx = Int64 (nx/ 2 ))
93
96
for j= 1 : K
94
- AN[i, j] = ActNorm (n_in; logdet= true )
95
- CL[i, j] = CouplingLayerGlow (n_in, n_hidden; nx= nx, dense= dense, freeze_conv= freeze_conv, k1= k1, k2= k2, p1= p1, p2= p2, s1= s1, s2= s2, logdet= true , activation= activation, ndims= ndims)
97
+ AN[i, j] = ActNorm (n_in; logdet= logdet )
98
+ CL[i, j] = CouplingLayerGlow (n_in, n_hidden; nx= nx, dense= dense, freeze_conv= freeze_conv, k1= k1, k2= k2, p1= p1, p2= p2, s1= s1, s2= s2, logdet= logdet , activation= activation, ndims= ndims)
96
99
end
97
100
(i < L && split_scales) && (n_in = Int64 (n_in/ 2 ); ) # split
98
101
end
99
102
100
- return NetworkGlow (AN, CL, Z_dims, L, K, squeezer, split_scales)
103
+ return NetworkGlow (AN, CL, Z_dims, L, K, squeezer, split_scales,logdet )
101
104
end
102
105
103
106
# Convenience constructor for a 3D Glow network: forwards every positional and
# keyword argument to `NetworkGlow`, forcing `ndims=3`.
# Fix: the signature must splat `args...` — with a single positional `args`,
# the usual call `NetworkGlow3D(n_in, n_hidden, L, K)` would throw a
# MethodError, and the body's `args...` splat shows varargs was the intent.
NetworkGlow3D(args...; kw...) = NetworkGlow(args...; kw..., ndims=3)
104
107
105
108
# Forward pass and compute logdet
106
- function forward (X:: AbstractArray{T, N} , G:: NetworkGlow ) where {T, N}
109
+ function forward (X:: AbstractArray{T, N} , G:: NetworkGlow ; ) where {T, N}
107
110
G. split_scales && (Z_save = array_of_array (X, max (G. L- 1 ,1 )))
108
111
109
-
110
- logdet = 0
112
+ logdet_ = 0
111
113
for i= 1 : G. L
112
114
(G. split_scales) && (X = G. squeezer. forward (X))
113
115
for j= 1 : G. K
114
- X, logdet1 = G. AN[i, j]. forward (X)
115
- X, logdet2 = G. CL[i, j]. forward (X)
116
- logdet += (logdet1 + logdet2)
116
+ G . logdet ? ( X, logdet1) = G . AN[i, j] . forward (X) : X = G. AN[i, j]. forward (X)
117
+ G . logdet ? ( X, logdet2) = G . CL[i, j] . forward (X) : X = G. CL[i, j]. forward (X)
118
+ G . logdet && (logdet_ += (logdet1 + logdet2) )
117
119
end
118
120
if G. split_scales && (i < G. L || i == 1 ) # don't split after last iteration
119
121
X, Z = tensor_split (X)
@@ -122,7 +124,8 @@ function forward(X::AbstractArray{T, N}, G::NetworkGlow) where {T, N}
122
124
end
123
125
end
124
126
G. split_scales && (X = cat_states (Z_save, X))
125
- return X, logdet
127
+
128
+ G. logdet ? (return X, logdet_) : (return X)
126
129
end
127
130
128
131
# Inverse pass
0 commit comments