@@ -8,6 +8,11 @@ import BenchmarkTools # load the heuristic chunks cod
8
8
using ComponentArrays: ComponentVector # test with other vector types
9
9
import DifferentiationInterface
10
10
11
+ DIGradient = Base. get_extension (
12
+ LogDensityProblemsAD,
13
+ :LogDensityProblemsADDifferentiationInterfaceExt
14
+ ). DIGradient
15
+
11
16
struct EnzymeTestMode <: Enzyme.Mode{Enzyme.DefaultABI, false, false} end
12
17
13
18
# ###
@@ -92,6 +97,7 @@ ForwardDiff.checktag(::Type{ForwardDiff.Tag{TestTag, V}}, ::Base.Fix1{typeof(log
92
97
93
98
# ADTypes support
94
99
@test typeof (ADgradient (ADTypes. AutoReverseDiff (; compile = Val (true )), ℓ)) === typeof (∇ℓ_compile)
100
+ @test ! isa (ADgradient (ADTypes. AutoReverseDiff (), ℓ), DIGradient)
95
101
96
102
for ∇ℓ in (∇ℓ_default, ∇ℓ_nocompile, ∇ℓ_compile, ∇ℓ_compile_x)
97
103
@test dimension (∇ℓ) == 3
128
134
129
135
# ADTypes support
130
136
@test ADgradient (ADTypes. AutoForwardDiff (), ℓ) === ∇ℓ
137
+ @test ! isa (ADgradient (ADTypes. AutoForwardDiff (), ℓ), DIGradient)
131
138
132
139
for _ in 1 : 100
133
140
x = randn (3 )
212
219
213
220
# ADTypes support
214
221
@test ADgradient (ADTypes. AutoTracker (), ℓ) === ∇ℓ
222
+ @test ! isa (ADgradient (ADTypes. AutoTracker (), ℓ), DIGradient)
215
223
end
216
224
217
225
@testset " AD via Zygote" begin
228
236
229
237
# ADTypes support
230
238
@test ADgradient (ADTypes. AutoZygote (), ℓ) === ∇ℓ
239
+ @test ! isa (ADgradient (ADTypes. AutoZygote (), ℓ), DIGradient)
231
240
end
232
241
233
242
@testset " AD via Enzyme" begin
242
251
243
252
# ADTypes support
244
253
@test ADgradient (ADTypes. AutoEnzyme (), ℓ) === ∇ℓ_reverse
254
+ @test ! isa (ADgradient (ADTypes. AutoEnzyme (), ℓ), DIGradient)
245
255
246
256
∇ℓ_forward = ADgradient (:Enzyme , ℓ; mode= Enzyme. Forward)
247
257
@test ADgradient (ADTypes. AutoEnzyme (;mode= Enzyme. Forward), ℓ) === ∇ℓ_forward
313
323
end
314
324
end
315
325
@testset " $(typeof (∇ℓ)) " for ∇ℓ in ∇ℓ_candidates
326
+ @test ∇ℓ isa DIGradient
316
327
@test dimension (∇ℓ) == 3
317
328
@test capabilities (∇ℓ) ≡ LogDensityOrder (1 )
318
329
for _ in 1 : 100
321
332
@test logdensity_and_gradient (∇ℓ, x) ≅ (test_logdensity1 (x), test_gradient (x)) atol = 1e-5
322
333
end
323
334
end
324
- end
335
+ end ;
0 commit comments