Skip to content

Commit c0f93d2

Browse files
authored
Add more tests
1 parent d68afce commit c0f93d2

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

ext/LogDensityProblemsADADTypesExt.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ if isdefined(Base.Experimental, :register_error_hint)
7373
function __init__()
7474
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
7575
if exc.f === LogDensityProblemsAD.ADgradient && length(argtypes) == 2 && first(argtypes) <: ADTypes.AbstractADType
76-
print(io, "\nDon't know how to AD with $(first(argtypes)). Did you forget to load DifferentiationInterface?")
76+
print(io, "\nDon't know how to AD with $(nameof(first(argtypes))). Did you forget to load DifferentiationInterface?")
7777
end
7878
end
7979
end

test/runtests.jl

+14-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import FiniteDifferences, ForwardDiff, Enzyme, Tracker, Zygote, ReverseDiff # ba
66
import ADTypes # load support for AD types with options
77
import BenchmarkTools # load the heuristic chunks code
88
using ComponentArrays: ComponentVector # test with other vector types
9-
import DifferentiationInterface
109

1110
struct EnzymeTestMode <: Enzyme.Mode{Enzyme.DefaultABI, false, false} end
1211

@@ -71,6 +70,15 @@ struct TestTag end
7170
# Allow tag type in gradient etc. calls of the log density function
7271
ForwardDiff.checktag(::Type{ForwardDiff.Tag{TestTag, V}}, ::Base.Fix1{typeof(logdensity),typeof(TestLogDensity())}, ::AbstractArray{V}) where {V} = true
7372

73+
@testset "Missing DI for unsupported ADType" begin
74+
msg = "Don't know how to AD with AutoFiniteDifferences. Did you forget to load DifferentiationInterface?"
75+
adtype = ADTypes.AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(5, 1))
76+
@test_throws msg ADgradient(adtype, TestLogDensity2())
77+
@test_throws msg ADgradient(adtype, TestLogDensity2(); x=zeros(20))
78+
end
79+
80+
import DifferentiationInterface
81+
7482
@testset "AD via ReverseDiff" begin
7583
= TestLogDensity()
7684

@@ -295,7 +303,11 @@ end
295303
end
296304

297305
@testset "ADgradient missing method" begin
298-
@test_throws "Don't know how to AD with Foo, consider `import Foo` if there is such a package." ADgradient(:Foo, TestLogDensity2())
306+
msg = "Don't know how to AD with Foo, consider `import Foo` if there is such a package."
307+
@test_throws msg ADgradient(:Foo, TestLogDensity2())
308+
@test_throws msg ADgradient(:Foo, TestLogDensity2(); x=zeros(20))
309+
@test_throws msg ADgradient(Val(:Foo), TestLogDensity2())
310+
@test_throws msg ADgradient(Val(:Foo), TestLogDensity2(); x=zeros(20))
299311
end
300312

301313
@testset "benchmark ForwardDiff chunk size" begin

0 commit comments

Comments
 (0)