Skip to content

Commit ee6f9fe

Browse files
authored
Merge pull request #42 from tpapp/dw/error_hints
Add error hints
2 parents e3401f2 + ef18c48 commit ee6f9fe

File tree

4 files changed

+36
-6
lines changed

4 files changed

+36
-6
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LogDensityProblemsAD"
22
uuid = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
33
authors = ["Tamás K. Papp <[email protected]>"]
4-
version = "1.12.0"
4+
version = "1.13.0"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"

ext/LogDensityProblemsADADTypesExt.jl

+11
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,15 @@ function LogDensityProblemsAD.ADgradient(::ADTypes.AutoZygote, ℓ; x::Union{Not
6868
return LogDensityProblemsAD.ADgradient(Val(:Zygote), ℓ)
6969
end
7070

71+
# Better error message if users forget to load DifferentiationInterface
72+
if isdefined(Base.Experimental, :register_error_hint)
73+
function __init__()
74+
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
75+
if exc.f === LogDensityProblemsAD.ADgradient && length(argtypes) == 2 && first(argtypes) <: ADTypes.AbstractADType
76+
print(io, "\nDon't know how to AD with $(nameof(first(argtypes))). Did you forget to load DifferentiationInterface?")
77+
end
78+
end
79+
end
80+
end
81+
7182
end # module

src/LogDensityProblemsAD.jl

+11-3
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,17 @@ The function `parent` can be used to retrieve the original argument.
6363
"""
6464
ADgradient(kind::Symbol, P; kwargs...) = ADgradient(Val{kind}(), P; kwargs...)
6565

66-
function ADgradient(v::Val{kind}, P; kwargs...) where kind
67-
@info "Don't know how to AD with $(kind), consider `import $(kind)` if there is such a package."
68-
throw(MethodError(ADgradient, (v, P)))
66+
# Better error message if users forget to load the AD package
67+
if isdefined(Base.Experimental, :register_error_hint)
68+
_unval(::Type{Val{T}}) where {T} = T
69+
function __init__()
70+
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
71+
if exc.f === ADgradient && length(argtypes) == 2 && first(argtypes) <: Val
72+
kind = _unval(first(argtypes))
73+
print(io, "\nDon't know how to AD with $(kind), consider `import $(kind)` if there is such a package.")
74+
end
75+
end
76+
end
6977
end
7078

7179
#####

test/runtests.jl

+13-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import FiniteDifferences, ForwardDiff, Enzyme, Tracker, Zygote, ReverseDiff # ba
66
import ADTypes # load support for AD types with options
77
import BenchmarkTools # load the heuristic chunks code
88
using ComponentArrays: ComponentVector # test with other vector types
9-
import DifferentiationInterface
109

1110
struct EnzymeTestMode <: Enzyme.Mode{Enzyme.DefaultABI, false, false} end
1211

@@ -71,6 +70,15 @@ struct TestTag end
7170
# Allow tag type in gradient etc. calls of the log density function
7271
ForwardDiff.checktag(::Type{ForwardDiff.Tag{TestTag, V}}, ::Base.Fix1{typeof(logdensity),typeof(TestLogDensity())}, ::AbstractArray{V}) where {V} = true
7372

73+
@testset "Missing DI for unsupported ADType" begin
74+
msg = "Don't know how to AD with AutoFiniteDifferences. Did you forget to load DifferentiationInterface?"
75+
adtype = ADTypes.AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(5, 1))
76+
@test_throws msg ADgradient(adtype, TestLogDensity2())
77+
@test_throws msg ADgradient(adtype, TestLogDensity2(); x=zeros(20))
78+
end
79+
80+
import DifferentiationInterface
81+
7482
@testset "AD via ReverseDiff" begin
7583
= TestLogDensity()
7684

@@ -296,7 +304,10 @@ end
296304

297305
@testset "ADgradient missing method" begin
298306
msg = "Don't know how to AD with Foo, consider `import Foo` if there is such a package."
299-
@test_logs((:info, msg), @test_throws(MethodError, ADgradient(:Foo, TestLogDensity2())))
307+
@test_throws msg ADgradient(:Foo, TestLogDensity2())
308+
@test_throws msg ADgradient(:Foo, TestLogDensity2(); x=zeros(20))
309+
@test_throws msg ADgradient(Val(:Foo), TestLogDensity2())
310+
@test_throws msg ADgradient(Val(:Foo), TestLogDensity2(); x=zeros(20))
300311
end
301312

302313
@testset "benchmark ForwardDiff chunk size" begin

0 commit comments

Comments
 (0)