@@ -4,7 +4,14 @@ using ADTypes: AutoForwardDiff, AutoZygote
4
4
import DifferentiationInterface as DI
5
5
using ForwardDiff: ForwardDiff
6
6
using Zygote:
7
- ZygoteRuleConfig, gradient, hessian, jacobian, pullback, withgradient, withjacobian
7
+ Buffer,
8
+ ZygoteRuleConfig,
9
+ gradient,
10
+ hessian,
11
+ jacobian,
12
+ pullback,
13
+ withgradient,
14
+ withjacobian
8
15
9
16
struct ZygoteNothingError <: Exception
10
17
f
@@ -27,6 +34,9 @@ check_nothing(::Any, f, x, contexts) = nothing
27
34
DI. check_available (:: AutoZygote ) = true
28
35
DI. inplace_support (:: AutoZygote ) = DI. InPlaceNotSupported ()
29
36
37
+ translate (c:: DI.Context ) = DI. unwrap (c)
38
+ translate (c:: DI.Cache ) = Buffer (DI. unwrap (c))
39
+
30
40
# # Pullback
31
41
32
42
struct ZygotePullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep
@@ -35,32 +45,22 @@ struct ZygotePullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep
35
45
end
36
46
37
47
function DI. prepare_pullback (
38
- f, :: AutoZygote , x, ty:: NTuple , contexts:: Vararg{DI.ConstantOrFunctionOrBackend ,C}
48
+ f, :: AutoZygote , x, ty:: NTuple , contexts:: Vararg{DI.Context ,C}
39
49
) where {C}
40
50
return DI. NoPullbackPrep ()
41
51
end
42
52
43
53
function DI. prepare_pullback_same_point (
44
- f,
45
- :: DI.NoPullbackPrep ,
46
- :: AutoZygote ,
47
- x,
48
- ty:: NTuple ,
49
- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
54
+ f, :: DI.NoPullbackPrep , :: AutoZygote , x, ty:: NTuple , contexts:: Vararg{DI.Context,C}
50
55
) where {C}
51
- y, pb = pullback (f, x, map (DI . unwrap , contexts)... )
56
+ y, pb = pullback (f, x, map (translate , contexts)... )
52
57
return ZygotePullbackPrepSamePoint (y, pb)
53
58
end
54
59
55
60
function DI. value_and_pullback (
56
- f,
57
- :: DI.NoPullbackPrep ,
58
- :: AutoZygote ,
59
- x,
60
- ty:: NTuple ,
61
- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
61
+ f, :: DI.NoPullbackPrep , :: AutoZygote , x, ty:: NTuple , contexts:: Vararg{DI.Context,C}
62
62
) where {C}
63
- y, pb = pullback (f, x, map (DI . unwrap , contexts)... )
63
+ y, pb = pullback (f, x, map (translate , contexts)... )
64
64
tx = map (ty) do dy
65
65
first (pb (dy))
66
66
end
@@ -74,7 +74,7 @@ function DI.value_and_pullback(
74
74
:: AutoZygote ,
75
75
x,
76
76
ty:: NTuple ,
77
- contexts:: Vararg{DI.ConstantOrFunctionOrBackend ,C} ,
77
+ contexts:: Vararg{DI.Context ,C} ,
78
78
) where {C}
79
79
(; y, pb) = prep
80
80
tx = map (ty) do dy
@@ -90,7 +90,7 @@ function DI.pullback(
90
90
:: AutoZygote ,
91
91
x,
92
92
ty:: NTuple ,
93
- contexts:: Vararg{DI.ConstantOrFunctionOrBackend ,C} ,
93
+ contexts:: Vararg{DI.Context ,C} ,
94
94
) where {C}
95
95
(; pb) = prep
96
96
tx = map (ty) do dy
@@ -102,112 +102,72 @@ end
102
102
103
103
# # Gradient
104
104
105
- function DI. prepare_gradient (
106
- f, :: AutoZygote , x, contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C}
107
- ) where {C}
105
+ function DI. prepare_gradient (f, :: AutoZygote , x, contexts:: Vararg{DI.Context,C} ) where {C}
108
106
return DI. NoGradientPrep ()
109
107
end
110
108
111
109
function DI. value_and_gradient (
112
- f,
113
- :: DI.NoGradientPrep ,
114
- :: AutoZygote ,
115
- x,
116
- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
110
+ f, :: DI.NoGradientPrep , :: AutoZygote , x, contexts:: Vararg{DI.Context,C}
117
111
) where {C}
118
- (; val, grad) = withgradient (f, x, map (DI . unwrap , contexts)... )
112
+ (; val, grad) = withgradient (f, x, map (translate , contexts)... )
119
113
check_nothing (first (grad), f, x, contexts)
120
114
return val, first (grad)
121
115
end
122
116
123
117
function DI. gradient (
124
- f,
125
- :: DI.NoGradientPrep ,
126
- :: AutoZygote ,
127
- x,
128
- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
118
+ f, :: DI.NoGradientPrep , :: AutoZygote , x, contexts:: Vararg{DI.Context,C}
129
119
) where {C}
130
- grad = gradient (f, x, map (DI . unwrap , contexts)... )
120
+ grad = gradient (f, x, map (translate , contexts)... )
131
121
check_nothing (first (grad), f, x, contexts)
132
122
return first (grad)
133
123
end
134
124
135
125
function DI. value_and_gradient! (
136
- f,
137
- grad,
138
- prep:: DI.NoGradientPrep ,
139
- backend:: AutoZygote ,
140
- x,
141
- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
126
+ f, grad, prep:: DI.NoGradientPrep , backend:: AutoZygote , x, contexts:: Vararg{DI.Context,C}
142
127
) where {C}
143
128
y, new_grad = DI. value_and_gradient (f, prep, backend, x, contexts... )
144
129
return y, copyto! (grad, new_grad)
145
130
end
146
131
147
132
function DI. gradient! (
148
- f,
149
- grad,
150
- prep:: DI.NoGradientPrep ,
151
- backend:: AutoZygote ,
152
- x,
153
- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
133
+ f, grad, prep:: DI.NoGradientPrep , backend:: AutoZygote , x, contexts:: Vararg{DI.Context,C}
154
134
) where {C}
155
135
return copyto! (grad, DI. gradient (f, prep, backend, x, contexts... ))
156
136
end
157
137
158
138
# # Jacobian
159
139
160
- function DI. prepare_jacobian (
161
- f, :: AutoZygote , x, contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C}
162
- ) where {C}
140
+ function DI. prepare_jacobian (f, :: AutoZygote , x, contexts:: Vararg{DI.Context,C} ) where {C}
163
141
return DI. NoJacobianPrep ()
164
142
end
165
143
166
144
function DI. value_and_jacobian (
167
- f,
168
- :: DI.NoJacobianPrep ,
169
- :: AutoZygote ,
170
- x,
171
- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
145
+ f, :: DI.NoJacobianPrep , :: AutoZygote , x, contexts:: Vararg{DI.Context,C}
172
146
) where {C}
173
- y = f (x, map (DI . unwrap , contexts)... )
147
+ y = f (x, map (translate , contexts)... )
174
148
# https://github.com/FluxML/Zygote.jl/issues/1506
175
- jac = jacobian (f, x, map (DI . unwrap , contexts)... )
149
+ jac = jacobian (f, x, map (translate , contexts)... )
176
150
check_nothing (first (jac), f, x, contexts)
177
151
return y, first (jac)
178
152
end
179
153
180
154
function DI. jacobian (
181
- f,
182
- :: DI.NoJacobianPrep ,
183
- :: AutoZygote ,
184
- x,
185
- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
155
+ f, :: DI.NoJacobianPrep , :: AutoZygote , x, contexts:: Vararg{DI.Context,C}
186
156
) where {C}
187
- jac = jacobian (f, x, map (DI . unwrap , contexts)... )
157
+ jac = jacobian (f, x, map (translate , contexts)... )
188
158
check_nothing (first (jac), f, x, contexts)
189
159
return first (jac)
190
160
end
191
161
192
162
function DI. value_and_jacobian! (
193
- f,
194
- jac,
195
- prep:: DI.NoJacobianPrep ,
196
- backend:: AutoZygote ,
197
- x,
198
- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
163
+ f, jac, prep:: DI.NoJacobianPrep , backend:: AutoZygote , x, contexts:: Vararg{DI.Context,C}
199
164
) where {C}
200
165
y, new_jac = DI. value_and_jacobian (f, prep, backend, x, contexts... )
201
166
return y, copyto! (jac, new_jac)
202
167
end
203
168
204
169
function DI. jacobian! (
205
- f,
206
- jac,
207
- prep:: DI.NoJacobianPrep ,
208
- backend:: AutoZygote ,
209
- x,
210
- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
170
+ f, jac, prep:: DI.NoJacobianPrep , backend:: AutoZygote , x, contexts:: Vararg{DI.Context,C}
211
171
) where {C}
212
172
return copyto! (jac, DI. jacobian (f, prep, backend, x, contexts... ))
213
173
end
@@ -217,22 +177,13 @@ end
217
177
# Beware, this uses ForwardDiff for the inner differentiation
218
178
219
179
function DI. prepare_hvp (
220
- f,
221
- backend:: AutoZygote ,
222
- x,
223
- tx:: NTuple ,
224
- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
180
+ f, backend:: AutoZygote , x, tx:: NTuple , contexts:: Vararg{DI.Context,C}
225
181
) where {C}
226
182
return DI. prepare_hvp (f, DI. SecondOrder (AutoForwardDiff (), backend), x, tx, contexts... )
227
183
end
228
184
229
185
function DI. hvp (
230
- f,
231
- prep:: DI.HVPPrep ,
232
- backend:: AutoZygote ,
233
- x,
234
- tx:: NTuple ,
235
- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
186
+ f, prep:: DI.HVPPrep , backend:: AutoZygote , x, tx:: NTuple , contexts:: Vararg{DI.Context,C}
236
187
) where {C}
237
188
return DI. hvp (f, prep, DI. SecondOrder (AutoForwardDiff (), backend), x, tx, contexts... )
238
189
end
@@ -244,20 +195,15 @@ function DI.hvp!(
244
195
backend:: AutoZygote ,
245
196
x,
246
197
tx:: NTuple ,
247
- contexts:: Vararg{DI.ConstantOrFunctionOrBackend ,C} ,
198
+ contexts:: Vararg{DI.Context ,C} ,
248
199
) where {C}
249
200
return DI. hvp! (
250
201
f, tg, prep, DI. SecondOrder (AutoForwardDiff (), backend), x, tx, contexts...
251
202
)
252
203
end
253
204
254
205
function DI. gradient_and_hvp (
255
- f,
256
- prep:: DI.HVPPrep ,
257
- backend:: AutoZygote ,
258
- x,
259
- tx:: NTuple ,
260
- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
206
+ f, prep:: DI.HVPPrep , backend:: AutoZygote , x, tx:: NTuple , contexts:: Vararg{DI.Context,C}
261
207
) where {C}
262
208
return DI. gradient_and_hvp (
263
209
f, prep, DI. SecondOrder (AutoForwardDiff (), backend), x, tx, contexts...
@@ -272,7 +218,7 @@ function DI.gradient_and_hvp!(
272
218
backend:: AutoZygote ,
273
219
x,
274
220
tx:: NTuple ,
275
- contexts:: Vararg{DI.ConstantOrFunctionOrBackend ,C} ,
221
+ contexts:: Vararg{DI.Context ,C} ,
276
222
) where {C}
277
223
return DI. gradient_and_hvp! (
278
224
f, grad, tg, prep, DI. SecondOrder (AutoForwardDiff (), backend), x, tx, contexts...
0 commit comments