1
1
using ChainRulesCore
2
- export logdetjac
3
- import ChainRulesCore: frule, rrule
4
-
2
+ export logdetjac, getrrule
3
+ import ChainRulesCore: frule, rrule, @non_differentiable
5
4
5
+ @non_differentiable get_params (:: Invertible )
6
+ @non_differentiable get_params (:: Reversed )
6
7
# # Tape types and utilities
7
8
8
9
"""
@@ -81,7 +82,6 @@ function forward_update!(state::InvertibleOperationsTape, X::AbstractArray{T,N},
81
82
if logdet isa Float32
82
83
state. logdet === nothing ? (state. logdet = logdet) : (state. logdet += logdet)
83
84
end
84
-
85
85
end
86
86
87
87
"""
@@ -97,15 +97,13 @@ function backward_update!(state::InvertibleOperationsTape, X::AbstractArray{T,N}
97
97
state. Y[state. counter_block] = X
98
98
state. counter_layer -= 1
99
99
end
100
-
101
100
state. counter_block == 0 && reset! (state) # reset state when first block/first layer is reached
102
-
103
101
end
104
102
105
103
# # Chain rules for invertible networks
106
104
# General pullback function
107
105
function pullback (net:: Invertible , ΔY:: AbstractArray{T,N} ;
108
- state:: InvertibleOperationsTape = GLOBAL_STATE_INVOPS) where {T, N}
106
+ state:: InvertibleOperationsTape = GLOBAL_STATE_INVOPS) where {T, N}
109
107
110
108
# Check state coherency
111
109
check_coherence (state, net)
@@ -114,19 +112,19 @@ function pullback(net::Invertible, ΔY::AbstractArray{T,N};
114
112
T2 = typeof (current (state))
115
113
ΔY = convert (T2, ΔY)
116
114
# Backward pass
117
- ΔX, X_ = net. backward (ΔY, current (state))
118
-
115
+ ΔX, X_ = net. backward (ΔY, current (state); set_grad = true )
116
+ Δθ = getfield .( get_params (net), :grad )
119
117
# Update state
120
118
backward_update! (state, X_)
121
119
122
- return nothing , ΔX
120
+ return NoTangent (), NoTangent (), ΔX, Δθ
123
121
end
124
122
125
123
126
124
# Reverse-mode AD rule
127
- function ChainRulesCore. rrule (net:: Invertible , X:: AbstractArray{T, N} ;
125
+ function ChainRulesCore. rrule (:: typeof (forward_net), net:: Invertible , X:: AbstractArray{T, N} , θ ... ;
128
126
state:: InvertibleOperationsTape = GLOBAL_STATE_INVOPS) where {T, N}
129
-
127
+
130
128
# Forward pass
131
129
net. logdet ? ((Y, logdet) = net. forward (X)) : (Y = net. forward (X); logdet = nothing )
132
130
142
140
143
141
# # Logdet utilities for Zygote pullback
144
142
145
- logdetjac (; state:: InvertibleOperationsTape = GLOBAL_STATE_INVOPS) = state. logdet
143
+ logdetjac (; state:: InvertibleOperationsTape = GLOBAL_STATE_INVOPS) = state. logdet
144
+
145
+ # # Utility to get the pullback directly for testing
146
+
147
+ getrrule (net:: Invertible , X:: AbstractArray ) = rrule (forward_net, net, X, getfield .(get_params (net), :data ))
0 commit comments