Skip to content

Commit d61bdd8

Browse files
committed
Test that DI stays in its lane
1 parent 6b9d344 commit d61bdd8

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

test/runtests.jl

+12-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ import BenchmarkTools # load the heuristic chunks cod
88
using ComponentArrays: ComponentVector # test with other vector types
99
import DifferentiationInterface
1010

11+
DIGradient = Base.get_extension(
12+
LogDensityProblemsAD,
13+
:LogDensityProblemsADDifferentiationInterfaceExt
14+
).DIGradient
15+
1116
struct EnzymeTestMode <: Enzyme.Mode{Enzyme.DefaultABI, false, false} end
1217

1318
####
@@ -92,6 +97,7 @@ ForwardDiff.checktag(::Type{ForwardDiff.Tag{TestTag, V}}, ::Base.Fix1{typeof(log
9297

9398
# ADTypes support
9499
@test typeof(ADgradient(ADTypes.AutoReverseDiff(; compile = Val(true)), ℓ)) === typeof(∇ℓ_compile)
100+
@test !isa(ADgradient(ADTypes.AutoReverseDiff(), ℓ), DIGradient)
95101

96102
for ∇ℓ in (∇ℓ_default, ∇ℓ_nocompile, ∇ℓ_compile, ∇ℓ_compile_x)
97103
@test dimension(∇ℓ) == 3
@@ -128,6 +134,7 @@ end
128134

129135
# ADTypes support
130136
@test ADgradient(ADTypes.AutoForwardDiff(), ℓ) === ∇ℓ
137+
@test !isa(ADgradient(ADTypes.AutoForwardDiff(), ℓ), DIGradient)
131138

132139
for _ in 1:100
133140
x = randn(3)
@@ -212,6 +219,7 @@ end
212219

213220
# ADTypes support
214221
@test ADgradient(ADTypes.AutoTracker(), ℓ) === ∇ℓ
222+
@test !isa(ADgradient(ADTypes.AutoTracker(), ℓ), DIGradient)
215223
end
216224

217225
@testset "AD via Zygote" begin
@@ -228,6 +236,7 @@ end
228236

229237
# ADTypes support
230238
@test ADgradient(ADTypes.AutoZygote(), ℓ) === ∇ℓ
239+
@test !isa(ADgradient(ADTypes.AutoZygote(), ℓ), DIGradient)
231240
end
232241

233242
@testset "AD via Enzyme" begin
@@ -242,6 +251,7 @@ end
242251

243252
# ADTypes support
244253
@test ADgradient(ADTypes.AutoEnzyme(), ℓ) === ∇ℓ_reverse
254+
@test !isa(ADgradient(ADTypes.AutoEnzyme(), ℓ), DIGradient)
245255

246256
∇ℓ_forward = ADgradient(:Enzyme, ℓ; mode=Enzyme.Forward)
247257
@test ADgradient(ADTypes.AutoEnzyme(;mode=Enzyme.Forward), ℓ) === ∇ℓ_forward
@@ -313,6 +323,7 @@ end
313323
end
314324
end
315325
@testset "$(typeof(∇ℓ))" for ∇ℓ in ∇ℓ_candidates
326+
@test ∇ℓ isa DIGradient
316327
@test dimension(∇ℓ) == 3
317328
@test capabilities(∇ℓ) LogDensityOrder(1)
318329
for _ in 1:100
@@ -321,4 +332,4 @@ end
321332
@test logdensity_and_gradient(∇ℓ, x) (test_logdensity1(x), test_gradient(x)) atol = 1e-5
322333
end
323334
end
324-
end
335+
end;

0 commit comments

Comments
 (0)