Skip to content

Commit 8ef803c

Browse files
authored
feat: support cache contexts with Zygote using Buffer (#708)
1 parent 5dfd7ad commit 8ef803c

File tree

3 files changed

+40
-94
lines changed

3 files changed

+40
-94
lines changed

DifferentiationInterface/docs/src/explanation/backends.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ Moreover, each context type is supported by a specific subset of backends:
7272
| `AutoReverseDiff` |||
7373
| `AutoSymbolics` |||
7474
| `AutoTracker` |||
75-
| `AutoZygote` || |
75+
| `AutoZygote` | ✅ | 🔀 |
7676

7777
## Second order
7878

DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl

Lines changed: 38 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@ using ADTypes: AutoForwardDiff, AutoZygote
44
import DifferentiationInterface as DI
55
using ForwardDiff: ForwardDiff
66
using Zygote:
7-
ZygoteRuleConfig, gradient, hessian, jacobian, pullback, withgradient, withjacobian
7+
Buffer,
8+
ZygoteRuleConfig,
9+
gradient,
10+
hessian,
11+
jacobian,
12+
pullback,
13+
withgradient,
14+
withjacobian
815

916
struct ZygoteNothingError <: Exception
1017
f
@@ -27,6 +34,9 @@ check_nothing(::Any, f, x, contexts) = nothing
2734
DI.check_available(::AutoZygote) = true
2835
DI.inplace_support(::AutoZygote) = DI.InPlaceNotSupported()
2936

37+
translate(c::DI.Context) = DI.unwrap(c)
38+
translate(c::DI.Cache) = Buffer(DI.unwrap(c))
39+
3040
## Pullback
3141

3242
struct ZygotePullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep
@@ -35,32 +45,22 @@ struct ZygotePullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep
3545
end
3646

3747
function DI.prepare_pullback(
38-
f, ::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}
48+
f, ::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Context,C}
3949
) where {C}
4050
return DI.NoPullbackPrep()
4151
end
4252

4353
function DI.prepare_pullback_same_point(
44-
f,
45-
::DI.NoPullbackPrep,
46-
::AutoZygote,
47-
x,
48-
ty::NTuple,
49-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
54+
f, ::DI.NoPullbackPrep, ::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Context,C}
5055
) where {C}
51-
y, pb = pullback(f, x, map(DI.unwrap, contexts)...)
56+
y, pb = pullback(f, x, map(translate, contexts)...)
5257
return ZygotePullbackPrepSamePoint(y, pb)
5358
end
5459

5560
function DI.value_and_pullback(
56-
f,
57-
::DI.NoPullbackPrep,
58-
::AutoZygote,
59-
x,
60-
ty::NTuple,
61-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
61+
f, ::DI.NoPullbackPrep, ::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Context,C}
6262
) where {C}
63-
y, pb = pullback(f, x, map(DI.unwrap, contexts)...)
63+
y, pb = pullback(f, x, map(translate, contexts)...)
6464
tx = map(ty) do dy
6565
first(pb(dy))
6666
end
@@ -74,7 +74,7 @@ function DI.value_and_pullback(
7474
::AutoZygote,
7575
x,
7676
ty::NTuple,
77-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
77+
contexts::Vararg{DI.Context,C},
7878
) where {C}
7979
(; y, pb) = prep
8080
tx = map(ty) do dy
@@ -90,7 +90,7 @@ function DI.pullback(
9090
::AutoZygote,
9191
x,
9292
ty::NTuple,
93-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
93+
contexts::Vararg{DI.Context,C},
9494
) where {C}
9595
(; pb) = prep
9696
tx = map(ty) do dy
@@ -102,112 +102,72 @@ end
102102

103103
## Gradient
104104

105-
function DI.prepare_gradient(
106-
f, ::AutoZygote, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}
107-
) where {C}
105+
function DI.prepare_gradient(f, ::AutoZygote, x, contexts::Vararg{DI.Context,C}) where {C}
108106
return DI.NoGradientPrep()
109107
end
110108

111109
function DI.value_and_gradient(
112-
f,
113-
::DI.NoGradientPrep,
114-
::AutoZygote,
115-
x,
116-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
110+
f, ::DI.NoGradientPrep, ::AutoZygote, x, contexts::Vararg{DI.Context,C}
117111
) where {C}
118-
(; val, grad) = withgradient(f, x, map(DI.unwrap, contexts)...)
112+
(; val, grad) = withgradient(f, x, map(translate, contexts)...)
119113
check_nothing(first(grad), f, x, contexts)
120114
return val, first(grad)
121115
end
122116

123117
function DI.gradient(
124-
f,
125-
::DI.NoGradientPrep,
126-
::AutoZygote,
127-
x,
128-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
118+
f, ::DI.NoGradientPrep, ::AutoZygote, x, contexts::Vararg{DI.Context,C}
129119
) where {C}
130-
grad = gradient(f, x, map(DI.unwrap, contexts)...)
120+
grad = gradient(f, x, map(translate, contexts)...)
131121
check_nothing(first(grad), f, x, contexts)
132122
return first(grad)
133123
end
134124

135125
function DI.value_and_gradient!(
136-
f,
137-
grad,
138-
prep::DI.NoGradientPrep,
139-
backend::AutoZygote,
140-
x,
141-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
126+
f, grad, prep::DI.NoGradientPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C}
142127
) where {C}
143128
y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...)
144129
return y, copyto!(grad, new_grad)
145130
end
146131

147132
function DI.gradient!(
148-
f,
149-
grad,
150-
prep::DI.NoGradientPrep,
151-
backend::AutoZygote,
152-
x,
153-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
133+
f, grad, prep::DI.NoGradientPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C}
154134
) where {C}
155135
return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...))
156136
end
157137

158138
## Jacobian
159139

160-
function DI.prepare_jacobian(
161-
f, ::AutoZygote, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}
162-
) where {C}
140+
function DI.prepare_jacobian(f, ::AutoZygote, x, contexts::Vararg{DI.Context,C}) where {C}
163141
return DI.NoJacobianPrep()
164142
end
165143

166144
function DI.value_and_jacobian(
167-
f,
168-
::DI.NoJacobianPrep,
169-
::AutoZygote,
170-
x,
171-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
145+
f, ::DI.NoJacobianPrep, ::AutoZygote, x, contexts::Vararg{DI.Context,C}
172146
) where {C}
173-
y = f(x, map(DI.unwrap, contexts)...)
147+
y = f(x, map(translate, contexts)...)
174148
# https://github.com/FluxML/Zygote.jl/issues/1506
175-
jac = jacobian(f, x, map(DI.unwrap, contexts)...)
149+
jac = jacobian(f, x, map(translate, contexts)...)
176150
check_nothing(first(jac), f, x, contexts)
177151
return y, first(jac)
178152
end
179153

180154
function DI.jacobian(
181-
f,
182-
::DI.NoJacobianPrep,
183-
::AutoZygote,
184-
x,
185-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
155+
f, ::DI.NoJacobianPrep, ::AutoZygote, x, contexts::Vararg{DI.Context,C}
186156
) where {C}
187-
jac = jacobian(f, x, map(DI.unwrap, contexts)...)
157+
jac = jacobian(f, x, map(translate, contexts)...)
188158
check_nothing(first(jac), f, x, contexts)
189159
return first(jac)
190160
end
191161

192162
function DI.value_and_jacobian!(
193-
f,
194-
jac,
195-
prep::DI.NoJacobianPrep,
196-
backend::AutoZygote,
197-
x,
198-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
163+
f, jac, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C}
199164
) where {C}
200165
y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...)
201166
return y, copyto!(jac, new_jac)
202167
end
203168

204169
function DI.jacobian!(
205-
f,
206-
jac,
207-
prep::DI.NoJacobianPrep,
208-
backend::AutoZygote,
209-
x,
210-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
170+
f, jac, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C}
211171
) where {C}
212172
return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...))
213173
end
@@ -217,22 +177,13 @@ end
217177
# Beware, this uses ForwardDiff for the inner differentiation
218178

219179
function DI.prepare_hvp(
220-
f,
221-
backend::AutoZygote,
222-
x,
223-
tx::NTuple,
224-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
180+
f, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C}
225181
) where {C}
226182
return DI.prepare_hvp(f, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...)
227183
end
228184

229185
function DI.hvp(
230-
f,
231-
prep::DI.HVPPrep,
232-
backend::AutoZygote,
233-
x,
234-
tx::NTuple,
235-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
186+
f, prep::DI.HVPPrep, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C}
236187
) where {C}
237188
return DI.hvp(f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...)
238189
end
@@ -244,20 +195,15 @@ function DI.hvp!(
244195
backend::AutoZygote,
245196
x,
246197
tx::NTuple,
247-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
198+
contexts::Vararg{DI.Context,C},
248199
) where {C}
249200
return DI.hvp!(
250201
f, tg, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...
251202
)
252203
end
253204

254205
function DI.gradient_and_hvp(
255-
f,
256-
prep::DI.HVPPrep,
257-
backend::AutoZygote,
258-
x,
259-
tx::NTuple,
260-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
206+
f, prep::DI.HVPPrep, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C}
261207
) where {C}
262208
return DI.gradient_and_hvp(
263209
f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...
@@ -272,7 +218,7 @@ function DI.gradient_and_hvp!(
272218
backend::AutoZygote,
273219
x,
274220
tx::NTuple,
275-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
221+
contexts::Vararg{DI.Context,C},
276222
) where {C}
277223
return DI.gradient_and_hvp!(
278224
f, grad, tg, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...

DifferentiationInterface/test/Back/Zygote/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ end
2727
@testset "Dense" begin
2828
test_differentiation(
2929
backends,
30-
default_scenarios(; include_constantified=true);
30+
default_scenarios(; include_constantified=true, include_cachified=true);
3131
excluded=[:second_derivative],
3232
logging=LOGGING,
3333
)

0 commit comments

Comments
 (0)