*Input*:
- - `n_in`, `n_hidden`: number of input and hidden channels
+ - `n_in`: number of input channels
+
+ - `n_hidden`: number of hidden channels
+
+ - `n_out`: number of output channels
+
+ - `activation`: activation function applied between conv layers and to the final output

- `k1`, `k2`: kernel size of convolutions in residual block. `k1` is the kernel of the first and third
  operator, `k2` is the kernel size of the second operator.
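For context, a short usage sketch of the extended constructor; this is not part of the diff, the keyword names follow the docstring above, and the channel counts are illustrative assumptions:

```julia
# Hypothetical usage: a 2D residual block with an explicit output width and
# activation. ReLUlayer() is the default activation introduced by this change.
RB = ResidualBlock(4, 16; n_out=8, activation=ReLUlayer(), k1=3, k2=3)

# Omitting n_out keeps the legacy default of 2*n_in output channels:
RB_legacy = ResidualBlock(4, 16)   # n_out == 2*4 == 8
```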
@@ -67,6 +73,7 @@ struct ResidualBlock <: NeuralNetLayer
fan::Bool
strides
pad
+ activation::ActivationFunction
end

@Flux.functor ResidualBlock
# Constructors

# Constructor
- function ResidualBlock(n_in, n_hidden; k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2)
+ function ResidualBlock(n_in, n_hidden; n_out=nothing, activation::ActivationFunction=ReLUlayer(), k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2)
+ # default/legacy behaviour
+ isnothing(n_out) && (n_out = 2*n_in)

k1 = Tuple(k1 for i=1:ndims)
k2 = Tuple(k2 for i=1:ndims)
# Initialize weights
W1 = Parameter(glorot_uniform(k1..., n_in, n_hidden))
W2 = Parameter(glorot_uniform(k2..., n_hidden, n_hidden))
- W3 = Parameter(glorot_uniform(k1..., 2*n_in, n_hidden))
+ W3 = Parameter(glorot_uniform(k1..., n_out, n_hidden))
b1 = Parameter(zeros(Float32, n_hidden))
b2 = Parameter(zeros(Float32, n_hidden))

- return ResidualBlock(W1, W2, W3, b1, b2, fan, (s1, s2), (p1, p2))
+ return ResidualBlock(W1, W2, W3, b1, b2, fan, (s1, s2), (p1, p2), activation)
end

# Constructor for given weights
- function ResidualBlock(W1, W2, W3, b1, b2; p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2)
+ function ResidualBlock(W1, W2, W3, b1, b2; activation::ActivationFunction=ReLUlayer(), p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2)

# Make weights parameters
W1 = Parameter(W1)
@@ -99,7 +108,7 @@ function ResidualBlock(W1, W2, W3, b1, b2; p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2)
b1 = Parameter(b1)
b2 = Parameter(b2)

- return ResidualBlock(W1, W2, W3, b1, b2, fan, (s1, s2), (p1, p2))
+ return ResidualBlock(W1, W2, W3, b1, b2, fan, (s1, s2), (p1, p2), activation)
end

ResidualBlock3D(args...; kw...) = ResidualBlock(args...; kw..., ndims=3)
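A hedged sanity-check sketch of the 3D convenience constructor, not part of the diff; the array layout is an assumption, and the comment on GaLU assumes the usual gated-linear-unit convention of splitting channels in half:

```julia
# Sketch only: ResidualBlock3D forwards to ResidualBlock with ndims=3.
RB3 = ResidualBlock3D(4, 16; activation=ReLUlayer())
X = randn(Float32, 16, 16, 16, 4, 2)   # assumed layout (nx, ny, nz, n_in, batch)
Y = forward(X, RB3)   # with fan=false, GaLU gates half the 2*n_in output channels against the other half
```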
@@ -111,17 +120,17 @@ function forward(X1::AbstractArray{T, N}, RB::ResidualBlock; save=false) where {
inds = [i!=(N-1) ? 1 : Colon() for i=1:N]

Y1 = conv(X1, RB.W1.data; stride=RB.strides[1], pad=RB.pad[1]) .+ reshape(RB.b1.data, inds...)
- X2 = ReLU(Y1)
+ X2 = RB.activation.forward(Y1)

Y2 = X2 + conv(X2, RB.W2.data; stride=RB.strides[2], pad=RB.pad[2]) .+ reshape(RB.b2.data, inds...)
- X3 = ReLU(Y2)
+ X3 = RB.activation.forward(Y2)

- cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=RB.strides[1], padding=RB.pad[1])
+ cdims3 = DCDims(X1, RB.W3.data; stride=RB.strides[1], padding=RB.pad[1])
Y3 = ∇conv_data(X3, RB.W3.data, cdims3)
# Return if only recomputing state
save && (return Y1, Y2, Y3)
# Finish forward
- RB.fan == true ? (return ReLU(Y3)) : (return GaLU(Y3))
+ RB.fan == true ? (return RB.activation.forward(Y3)) : (return GaLU(Y3))
end
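The calls `RB.activation.forward` and `RB.activation.backward` suggest that `ActivationFunction` carries the forward map and its pullback as callable fields. A minimal sketch of what such a wrapper could look like under that assumption; the name `MyActivation` is hypothetical and this is not the package's actual definition:

```julia
# Hypothetical sketch of the assumed interface; not the package's definition.
struct MyActivation
    forward::Function    # Y = forward(X)
    backward::Function   # ΔX = backward(ΔY, X), pullback evaluated at the pre-activation X
end

# A ReLU-like instance under that convention: the pullback masks entries where X <= 0,
# matching how ReLUgrad(Δ, Y) is used on pre-activations in the hunks above and below.
relu_like = MyActivation(
    X -> max.(0f0, X),
    (ΔY, X) -> ΔY .* (X .> 0f0),
)
```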

# Backward
@@ -135,21 +144,21 @@ function backward(ΔX4::AbstractArray{T, N}, X1::AbstractArray{T, N},

# Cdims
cdims2 = DenseConvDims(Y2, RB.W2.data; stride=RB.strides[2], padding=RB.pad[2])
- cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=RB.strides[1], padding=RB.pad[1])
+ cdims3 = DCDims(X1, RB.W3.data; stride=RB.strides[1], padding=RB.pad[1])

# Backpropagate residual ΔX4 and compute gradients
- RB.fan == true ? (ΔY3 = ReLUgrad(ΔX4, Y3)) : (ΔY3 = GaLUgrad(ΔX4, Y3))
+ RB.fan == true ? (ΔY3 = RB.activation.backward(ΔX4, Y3)) : (ΔY3 = GaLUgrad(ΔX4, Y3))
ΔX3 = conv(ΔY3, RB.W3.data, cdims3)
- ΔW3 = ∇conv_filter(ΔY3, ReLU(Y2), cdims3)
+ ΔW3 = ∇conv_filter(ΔY3, RB.activation.forward(Y2), cdims3)

- ΔY2 = ReLUgrad(ΔX3, Y2)
+ ΔY2 = RB.activation.backward(ΔX3, Y2)
ΔX2 = ∇conv_data(ΔY2, RB.W2.data, cdims2) + ΔY2
- ΔW2 = ∇conv_filter(ReLU(Y1), ΔY2, cdims2)
+ ΔW2 = ∇conv_filter(RB.activation.forward(Y1), ΔY2, cdims2)
Δb2 = sum(ΔY2, dims=dims)[inds...]

cdims1 = DenseConvDims(X1, RB.W1.data; stride=RB.strides[1], padding=RB.pad[1])

- ΔY1 = ReLUgrad(ΔX2, Y1)
+ ΔY1 = RB.activation.backward(ΔX2, Y1)
ΔX1 = ∇conv_data(ΔY1, RB.W1.data, cdims1)
ΔW1 = ∇conv_filter(X1, ΔY1, cdims1)
Δb1 = sum(ΔY1, dims=dims)[inds...]
@@ -177,22 +186,22 @@ function jacobian(ΔX1::AbstractArray{T, N}, Δθ::Array{Parameter, 1},

Y1 = conv(X1, RB.W1.data, cdims1) .+ reshape(RB.b1.data, inds...)
ΔY1 = conv(ΔX1, RB.W1.data, cdims1) + conv(X1, Δθ[1].data, cdims1) .+ reshape(Δθ[4].data, inds...)
- X2 = ReLU(Y1)
- ΔX2 = ReLUgrad(ΔY1, Y1)
+ X2 = RB.activation.forward(Y1)
+ ΔX2 = RB.activation.backward(ΔY1, Y1)

cdims2 = DenseConvDims(X2, RB.W2.data; stride=RB.strides[2], padding=RB.pad[2])

Y2 = X2 + conv(X2, RB.W2.data, cdims2) .+ reshape(RB.b2.data, inds...)
ΔY2 = ΔX2 + conv(ΔX2, RB.W2.data, cdims2) + conv(X2, Δθ[2].data, cdims2) .+ reshape(Δθ[5].data, inds...)
- X3 = ReLU(Y2)
- ΔX3 = ReLUgrad(ΔY2, Y2)
+ X3 = RB.activation.forward(Y2)
+ ΔX3 = RB.activation.backward(ΔY2, Y2)

cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=RB.strides[1], padding=RB.pad[1])
Y3 = ∇conv_data(X3, RB.W3.data, cdims3)
ΔY3 = ∇conv_data(ΔX3, RB.W3.data, cdims3) + ∇conv_data(X3, Δθ[3].data, cdims3)
if RB.fan == true
- X4 = ReLU(Y3)
- ΔX4 = ReLUgrad(ΔY3, Y3)
+ X4 = RB.activation.forward(Y3)
+ ΔX4 = RB.activation.backward(ΔY3, Y3)
else
ΔX4, X4 = GaLUjacobian(ΔY3, Y3)
end