Commit df92d43

Merge pull request #79 from slimgroup/flux-fix
fix Flux compat
2 parents: 5c092f9 + b2eeb12

9 files changed: +88 -34 lines

Project.toml (+1 -1)

@@ -1,7 +1,7 @@
 name = "InvertibleNetworks"
 uuid = "b7115f24-5f92-4794-81e8-23b0ddb121d3"
 authors = ["Philipp Witte <[email protected]>", "Ali Siahkoohi <[email protected]>", "Mathias Louboutin <[email protected]>", "Gabrio Rizzuti <[email protected]>", "Rafael Orozco <[email protected]>", "Felix J. herrmann <[email protected]>"]
-version = "2.2.4"
+version = "2.2.5"
 
 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

src/layers/invertible_layer_glow.jl (+1 -1)

@@ -142,7 +142,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingL
         ΔX2 = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), X2) + ΔY2
     else
         ΔX2, Δθrb = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT; ), X2; set_grad=set_grad)
-        _, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(ΔS_, S), 0f0.*ΔT;), X2; set_grad=set_grad)
+        _, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), 0f0.*ΔT;), X2; set_grad=set_grad)
         ΔX2 += ΔY2
     end
     ΔX_ = tensor_cat(ΔX1, ΔX2)

src/layers/invertible_layer_hint.jl (+1 -1)

@@ -299,7 +299,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, H::CouplingL
 end
 
 # Input are two tensors ΔX, X
-function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, H::CouplingLayerHINT; scale=1, permute=nothing) where {T, N}
+function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, H::CouplingLayerHINT; scale=1, permute=nothing, set_grad::Bool=true) where {T, N}
     isnothing(permute) ? permute = H.permute : permute = permute
 
     # Permutation

src/utils/chainrules.jl (+15 -13)

@@ -1,8 +1,9 @@
 using ChainRulesCore
-export logdetjac
-import ChainRulesCore: frule, rrule
-
+export logdetjac, getrrule
+import ChainRulesCore: frule, rrule, @non_differentiable
 
+@non_differentiable get_params(::Invertible)
+@non_differentiable get_params(::Reversed)
 ## Tape types and utilities
 
 """
@@ -81,7 +82,6 @@ function forward_update!(state::InvertibleOperationsTape, X::AbstractArray{T,N},
     if logdet isa Float32
         state.logdet === nothing ? (state.logdet = logdet) : (state.logdet += logdet)
     end
-
 end
 
 """
@@ -97,15 +97,13 @@ function backward_update!(state::InvertibleOperationsTape, X::AbstractArray{T,N}
         state.Y[state.counter_block] = X
         state.counter_layer -= 1
     end
-
     state.counter_block == 0 && reset!(state) # reset state when first block/first layer is reached
-
 end
 
 ## Chain rules for invertible networks
 # General pullback function
 function pullback(net::Invertible, ΔY::AbstractArray{T,N};
-                  state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) where {T, N}
+                  state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) where {T, N}
 
     # Check state coherency
     check_coherence(state, net)
@@ -114,19 +112,19 @@ function pullback(net::Invertible, ΔY::AbstractArray{T,N};
     T2 = typeof(current(state))
     ΔY = convert(T2, ΔY)
     # Backward pass
-    ΔX, X_ = net.backward(ΔY, current(state))
-
+    ΔX, X_ = net.backward(ΔY, current(state); set_grad=true)
+    Δθ = getfield.(get_params(net), :grad)
     # Update state
     backward_update!(state, X_)
 
-    return nothing, ΔX
+    return NoTangent(), NoTangent(), ΔX, Δθ
 end
 
 
 # Reverse-mode AD rule
-function ChainRulesCore.rrule(net::Invertible, X::AbstractArray{T, N};
+function ChainRulesCore.rrule(::typeof(forward_net), net::Invertible, X::AbstractArray{T, N}, θ...;
                               state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) where {T, N}
-
+
     # Forward pass
     net.logdet ? ((Y, logdet) = net.forward(X)) : (Y = net.forward(X); logdet = nothing)
 
@@ -142,4 +140,8 @@ end
 
 ## Logdet utilities for Zygote pullback
 
-logdetjac(; state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) = state.logdet
+logdetjac(; state::InvertibleOperationsTape=GLOBAL_STATE_INVOPS) = state.logdet
+
+## Utility to get the pullback directly for testing
+
+getrrule(net::Invertible, X::AbstractArray) = rrule(forward_net, net, X, getfield.(get_params(net), :data))
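
With this change the reverse-mode rule is defined on forward_net(net, X, θ...) rather than on the network object itself, and the pullback now returns (NoTangent(), NoTangent(), ΔX, Δθ), so the gradient with respect to the input sits at index 3 and the parameter gradients at index 4. The exported getrrule helper wraps this for testing. A minimal sketch of calling it by hand (layer choice and array sizes are arbitrary, for illustration only):

    using InvertibleNetworks

    N = ActNorm(2; logdet=false)     # any invertible layer should do here
    X = randn(Float32, 1, 1, 2, 8)

    Y, back = getrrule(N, X)         # wraps rrule(forward_net, N, X, θ)
    tangents = back(Y)               # (NoTangent(), NoTangent(), ΔX, Δθ)
    ΔX = tangents[3]                 # input gradient
    Δθ = tangents[4]                 # parameter gradients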

src/utils/dimensionality_operations.jl (+1 -1)

@@ -475,7 +475,7 @@ end
 
 # Split and reshape 1D vector Y in latent space back to states Zi
 # where Zi is the split tensor at each multiscale level.
-function split_states(Y::AbstractVector{T}, Z_dims) where {T, N}
+function split_states(Y::AbstractVector{T}, Z_dims) where {T}
     L = length(Z_dims) + 1
     inds = cumsum([1, [prod(Z_dims[j]) for j=1:L-1]...])
     Z_save = [reshape(Y[inds[j]:inds[j+1]-1], xy_dims(Z_dims[j], Val(j==L), Val(length(Z_dims[j])))) for j=1:L-1]

src/utils/neuralnet.jl (+3 -2)

@@ -33,7 +33,7 @@ end
 getproperty(I::Invertible, s::Symbol) = _get_property(I, Val{s}())
 
 _get_property(I::Invertible, ::Val{s}) where {s} = getfield(I, s)
-_get_property(R::Reversed, ::Val{:I}) where s = getfield(R, :I)
+_get_property(R::Reversed, ::Val{:I}) = getfield(R, :I)
 _get_property(R::Reversed, ::Val{s}) where s = _get_property(R.I, Val{s}())
 
 for m ∈ _INet_modes
@@ -128,4 +128,5 @@ function set_params!(N::Invertible, θnew::Array{Parameter, 1})
 end
 
 # Make invertible nets callable objects
-(N::Invertible)(X::AbstractArray{T,N} where {T, N}) = N.forward(X)
+(net::Invertible)(X::AbstractArray{T,N} where {T, N}) = forward_net(net, X, getfield.(get_params(net), :data))
+forward_net(net::Invertible, X::AbstractArray{T,N}, ::Any) where {T, N} = net.forward(X)
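
Routing net(X) through forward_net is what lets Zygote pick up the ChainRulesCore rule in src/utils/chainrules.jl: the parameter data arrays appear as an explicit argument, so gradients reach both the input and the parameters tracked by Flux. A minimal sketch, mirroring the pattern in test/test_utils/test_chainrules.jl (layer choice, sizes, and the squared-error objective are arbitrary, for illustration only):

    using InvertibleNetworks, Flux, LinearAlgebra

    net = ActNorm(2; logdet=false)
    X  = randn(Float32, 1, 1, 2, 8)
    Y0 = randn(Float32, 1, 1, 2, 8)

    # net(X) calls forward_net(net, X, θ), so Zygote differentiates through the rrule
    gX = Flux.gradient(x -> 0.5f0 * norm(net(x) - Y0)^2, X)[1]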

test/runtests.jl (+2 -1)

@@ -17,7 +17,8 @@ basics = ["test_utils/test_objectives.jl",
           "test_utils/test_activations.jl",
           "test_utils/test_squeeze.jl",
           "test_utils/test_jacobian.jl",
-          "test_utils/test_chainrules.jl"]
+          "test_utils/test_chainrules.jl",
+          "test_utils/test_flux.jl"]
 
 # Layers
 layers = ["test_layers/test_residual_block.jl",

test/test_utils/test_chainrules.jl (+14 -14)

@@ -23,22 +23,22 @@ N10 = CouplingLayerHINT(n_ch, n_hidden; logdet=logdet, permute="full")
 
 # Forward pass + gathering pullbacks
 function fw(X)
-    X1, ∂1 = rrule(N1, X)
-    X2, ∂2 = rrule(N2, X1)
-    X3, ∂3 = rrule(N3, X2)
+    X1, ∂1 = getrrule(N1, X)
+    X2, ∂2 = getrrule(N2, X1)
+    X3, ∂3 = getrrule(N3, X2)
     X5, ∂5 = Flux.Zygote.pullback(Chain(N4, N5), X3)
-    X6, ∂6 = rrule(N6, X5)
-    X7, ∂7 = rrule(N7, X6)
+    X6, ∂6 = getrrule(N6, X5)
+    X7, ∂7 = getrrule(N7, X6)
     X9, ∂9 = Flux.Zygote.pullback(Chain(N8, N9), X7)
-    X10, ∂10 = rrule(N10, X9)
-    d1 = x -> ∂1(x)[2]
-    d2 = x -> ∂2(x)[2]
-    d3 = x -> ∂3(x)[2]
+    X10, ∂10 = getrrule(N10, X9)
+    d1 = x -> ∂1(x)[3]
+    d2 = x -> ∂2(x)[3]
+    d3 = x -> ∂3(x)[3]
     d5 = x -> ∂5(x)[1]
-    d6 = x -> ∂6(x)[2]
-    d7 = x -> ∂7(x)[2]
+    d6 = x -> ∂6(x)[3]
+    d7 = x -> ∂7(x)[3]
     d9 = x -> ∂9(x)[1]
-    d10 = x -> ∂10(x)[2]
+    d10 = x -> ∂10(x)[3]
     return X10, ΔY -> d1(d2(d3(d5(d6(d7(d9(d10(ΔY))))))))
 end
 
@@ -65,9 +65,9 @@ g2 = gradient(X -> loss(X), X)
 ## test Reverse network AD
 
 Nrev = reverse(N10)
-Xrev, ∂rev = rrule(Nrev, X)
+Xrev, ∂rev = getrrule(Nrev, X)
 grev = ∂rev(Xrev-Y0)
 
 g2rev = gradient(X -> 0.5f0*norm(Nrev(X) - Y0)^2, X)
 
-@test grev[2] ≈ g2rev[1] rtol=1f-6
+@test grev[3] ≈ g2rev[1] rtol=1f-6

test/test_utils/test_flux.jl (new file, +50 lines)

using InvertibleNetworks, Flux, Test, LinearAlgebra

# Define network
nx = 1
ny = 1
n_in = 2
n_hidden = 10
batchsize = 32

# net
AN = ActNorm(n_in; logdet = false)
C = CouplingLayerGlow(n_in, n_hidden; logdet = false, k1 = 1, k2 = 1, p1 = 0, p2 = 0)
pan, pc = deepcopy(get_params(AN)), deepcopy(get_params(C))
model = Chain(AN, C)

# dummy input & target
X = randn(Float32, nx, ny, n_in, batchsize)
Y = model(X)
X0 = rand(Float32, nx, ny, n_in, batchsize) .+ 1

# loss fn
loss(model, X, Y) = Flux.mse(Y, model(X))

# old, implicit-style Flux
θ = Flux.params(model)
opt = Descent(0.001)

l, grads = Flux.withgradient(θ) do
    loss(model, X0, Y)
end

for θi in θ
    @test θi ∈ keys(grads.grads)
    @test !isnothing(grads.grads[θi])
    @test size(grads.grads[θi]) == size(θi)
end

Flux.update!(opt, θ, grads)

for i = 1:5
    li, grads = Flux.withgradient(θ) do
        loss(model, X, Y)
    end

    @info "Loss: $li"
    @test li != l
    global l = li

    Flux.update!(opt, θ, grads)
end
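
The test exercises the implicit Flux.params API. One side effect worth noting: the pullback in src/utils/chainrules.jl runs backward with set_grad=true, so after a Zygote/Flux gradient call the layers' own Parameter objects also carry gradients in their .grad fields. A hedged sketch of using those directly for a plain descent step (step size, sizes, and layer choices are arbitrary, for illustration only):

    using InvertibleNetworks, Flux

    AN = ActNorm(2; logdet=false)
    C  = CouplingLayerGlow(2, 10; logdet=false, k1=1, k2=1, p1=0, p2=0)
    model = Chain(AN, C)

    X  = randn(Float32, 1, 1, 2, 32)
    Y  = model(X)
    X0 = rand(Float32, 1, 1, 2, 32) .+ 1

    # Populate the .grad fields via the rrule (backward runs with set_grad=true)
    Flux.gradient(() -> Flux.mse(Y, model(X0)), Flux.params(model))

    # Manual descent over the layers' own parameters, as an alternative to Flux.update!
    for layer in (AN, C), p in get_params(layer)
        p.data .-= 1f-3 .* p.grad
    end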
