16
16
17
17
function inv_test (nx, ny, n_in, batchsize, logdet, squeeze_type, split_scales)
18
18
print (" \n Multiscale Conditional HINT invertibility test with squeeze_type=$(squeeze_type) , split_scales=$(split_scales) , logdet=$(logdet) \n " )
19
- CH = NetworkMultiScaleConditionalHINT (n_in, n_hidden, L, K; squeezer = squeeze_type, logdet= logdet, split_scales= split_scales)
19
+ CH = NetworkMultiScaleConditionalHINT (n_in, n_hidden, L, K; squeezer = squeeze_type () , logdet= logdet, split_scales= split_scales)
20
20
21
21
# Input image and data
22
22
X = randn (Float32, nx, ny, n_in, batchsize)
@@ -61,8 +61,8 @@ function loss(CH, X, Y)
61
61
end
62
62
63
63
function grad_test_X (nx, ny, n_channel, batchsize, logdet, squeeze_type, split_scales)
64
- print (" \n Multiscale Conditional HINT invertibility test with squeeze_type=$(squeeze_type) , split_scales=$(split_scales) , logdet=$(logdet) \n " )
65
- CH = NetworkMultiScaleConditionalHINT (n_in, n_hidden, L, K; squeezer = squeeze_type, logdet= logdet, split_scales= split_scales)
64
+ print (" \n Multiscale Conditional HINT gradient test with squeeze_type=$(squeeze_type) , split_scales=$(split_scales) , logdet=$(logdet) \n " )
65
+ CH = NetworkMultiScaleConditionalHINT (n_in, n_hidden, L, K; squeezer = squeeze_type () , logdet= logdet, split_scales= split_scales)
66
66
67
67
68
68
# Input image
@@ -71,12 +71,12 @@ function grad_test_X(nx, ny, n_channel, batchsize, logdet, squeeze_type, split_s
71
71
72
72
# Input data
73
73
Y0 = randn (Float32, nx, ny, n_channel, batchsize)
74
- dY = randn (Float32, nx, ny, n_channel, batchsize)
74
+ dY = 10 * randn (Float32, nx, ny, n_channel, batchsize)
75
75
76
76
f0, gX, gY = loss (CH, X0, Y0)[1 : 3 ]
77
77
78
78
maxiter = 5
79
- h = 0.1f0
79
+ h = 0.1f0
80
80
err1 = zeros (Float32, maxiter)
81
81
err2 = zeros (Float32, maxiter)
82
82
@@ -92,11 +92,7 @@ function grad_test_X(nx, ny, n_channel, batchsize, logdet, squeeze_type, split_s
92
92
@test isapprox (err2[end ] / (err2[1 ]/ 4 ^ (maxiter- 1 )), 1f0 ; atol= 1f1 )
93
93
end
94
94
95
- shuffle_sq = ShuffleLayer ()
96
- wavelet_sq = WaveletLayer ()
97
- Haar_sq = HaarLayer ()
98
-
99
- for squeeze_i in [shuffle_sq, wavelet_sq, Haar_sq]
95
+ for squeeze_i in [ShuffleLayer, WaveletLayer, HaarLayer]
100
96
for split_scales in [true , false ]
101
97
for logdet in [false , true ]
102
98
inv_test (nx, ny, n_in, batchsize, logdet, squeeze_i, split_scales)
@@ -111,11 +107,11 @@ end
111
107
# Gradient test
112
108
113
109
# Initialization
114
- CH = NetworkMultiScaleConditionalHINT (n_in, n_hidden, L, K; split_scales= false , k1= 3 , k2= 1 , p1= 1 , p2= 0 , squeezer = shuffle_sq );
110
+ CH = NetworkMultiScaleConditionalHINT (n_in, n_hidden, L, K; split_scales= false , k1= 3 , k2= 1 , p1= 1 , p2= 0 , squeezer = ShuffleLayer () );
115
111
CH. forward (randn (Float32, nx, ny, n_in, batchsize), randn (Float32, nx, ny, n_in, batchsize))
116
112
θ = deepcopy (get_params (CH))
117
113
118
- CH0 = NetworkMultiScaleConditionalHINT (n_in, n_hidden, L, K; split_scales= false , k1= 3 , k2= 1 , p1= 1 , p2= 0 , squeezer = shuffle_sq );
114
+ CH0 = NetworkMultiScaleConditionalHINT (n_in, n_hidden, L, K; split_scales= false , k1= 3 , k2= 1 , p1= 1 , p2= 0 , squeezer = ShuffleLayer () );
119
115
120
116
CH0. forward (randn (Float32, nx, ny, n_in, batchsize), randn (Float32, nx, ny, n_in, batchsize))
121
117
θ0 = deepcopy (get_params (CH0))
0 commit comments