Skip to content

Commit a94c948

Browse files
authored
Merge pull request #62 from ziyiyin97/master
Dispatch pullback rrule for all Invertible
2 parents 192a7fe + 7229577 commit a94c948

File tree

5 files changed

+19
-13
lines changed

5 files changed

+19
-13
lines changed

src/utils/chainrules.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ isa_newblock(state::InvertibleOperationsTape, X) = (state.counter_block == 0) ||
5656
"""
5757
Error if mismatch between state and network
5858
"""
59-
function check_coherence(state::InvertibleOperationsTape, net::Union{NeuralNetLayer,InvertibleNetwork})
59+
function check_coherence(state::InvertibleOperationsTape, net::Invertible)
6060
if state.counter_block != 0 && state.counter_layer != 0 && state.layer_blocks[state.counter_block][state.counter_layer] != net
6161
reset!(state)
6262
throw(ArgumentError("Current state does not correspond to current layer, resetting state..."))
@@ -66,7 +66,7 @@ end
6666
"""
6767
Update state in the forward pass.
6868
"""
69-
function forward_update!(state::InvertibleOperationsTape, X::AbstractArray{T,N}, Y::AbstractArray{T,N}, logdet::Union{Nothing,T}, net::Union{NeuralNetLayer,InvertibleNetwork}) where {T, N}
69+
function forward_update!(state::InvertibleOperationsTape, X::AbstractArray{T,N}, Y::AbstractArray{T,N}, logdet::Union{Nothing,T}, net::Invertible) where {T, N}
7070

7171
if isa_newblock(state, X)
7272
push!(state.Y, Y)
@@ -104,7 +104,7 @@ end
104104

105105
## Chain rules for invertible networks
106106
# General pullback function
107-
function pullback(net::Union{NeuralNetLayer,InvertibleNetwork}, ΔY::AbstractArray{T,N};
107+
function pullback(net::Invertible, ΔY::AbstractArray{T,N};
108108
state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) where {T, N}
109109

110110
# Check state coherency
@@ -124,7 +124,7 @@ end
124124

125125

126126
# Reverse-mode AD rule
127-
function ChainRulesCore.rrule(net::Union{NeuralNetLayer,InvertibleNetwork}, X::AbstractArray{T, N};
127+
function ChainRulesCore.rrule(net::Invertible, X::AbstractArray{T, N};
128128
state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) where {T, N}
129129

130130
# Forward pass

src/utils/invertible_network_sequential.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ export ComposedInvertibleNetwork, Composition
66
import Base.length, Base.∘
77

88
struct ComposedInvertibleNetwork <: InvertibleNetwork
9-
layers::Array{T, 1} where {T <: Union{NeuralNetLayer, InvertibleNetwork}}
9+
layers::Array{T, 1} where {T <: Invertible}
1010
logdet_array::Array{Bool, 1}
1111
logdet::Bool
1212
npars::Array{Int64, 1}
@@ -21,7 +21,7 @@ function Composition(layer...)
2121

2222
# Initializing output
2323
depth = length(layer)
24-
net_array = Array{Union{NeuralNetLayer, InvertibleNetwork}, 1}(undef, depth)
24+
net_array = Array{Invertible, 1}(undef, depth)
2525
logdet_array = Array{Bool, 1}(undef, depth)
2626
logdet = false
2727
npars = Array{Int64, 1}(undef, depth)

src/utils/jacobian.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ struct JacobianInvertibleNetwork{T} <: joAbstractLinearOperator{T, T}
1515
m::Int64
1616
fop::Function
1717
fop_T::Function
18-
N::Union{NeuralNetLayer, InvertibleNetwork}
18+
N::Invertible
1919
X::AbstractArray{T}
2020
Y::Union{Nothing, AbstractArray{T}}
2121
end

src/utils/neuralnet.jl

+1-5
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ reset!(AI::Array{<:Invertible}) = for I ∈ AI reset!(I) end
108108
Resets the gradient of all the parameters in NL
109109
"""
110110
clear_grad!(I::Invertible) = clear_grad!(get_params(I))
111-
clear_grad!(RL::Reversed) = clear_grad!(RL.I)
112111

113112
# Get gradients
114113
"""
@@ -124,12 +123,9 @@ get_grads(RL::Reversed)= get_grads(RL.I)
124123
get_grads(::Nothing) = []
125124

126125
# Set parameters
127-
function set_params!(N::Union{NeuralNetLayer, InvertibleNetwork}, θnew::Array{Parameter, 1})
126+
function set_params!(N::Invertible, θnew::Array{Parameter, 1})
128127
set_params!(get_params(N), θnew)
129128
end
130129

131-
# Set params for reversed layers/networks
132-
set_params!(RL::Reversed, θ::Array{Parameter, 1}) = set_params!(RL.I, θ)
133-
134130
# Make invertible nets callable objects
135131
(N::Invertible)(X::AbstractArray{T,N} where {T, N}) = N.forward(X)

test/test_utils/test_chainrules.jl

+11-1
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,14 @@ N = Chain(N1, N2, N3, N4, N5, N6, N7, N8, N9, N10);
6060

6161
g2 = gradient(X -> loss(X), X)
6262

63-
@test g ≈ g2[1] rtol=1f-6
63+
@test g ≈ g2[1] rtol=1f-6
64+
65+
## test Reverse network AD
66+
67+
Nrev = reverse(N10)
68+
Xrev, ∂rev = rrule(Nrev, X)
69+
grev = ∂rev(Xrev-Y0)
70+
71+
g2rev = gradient(X -> 0.5f0*norm(Nrev(X) - Y0)^2, X)
72+
73+
@test grev[2] ≈ g2rev[1] rtol=1f-6

0 commit comments

Comments
 (0)