From 59ae4f83b30e02407721a7ab40dcd2209a705bb4 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 2 Aug 2024 08:37:27 +0100 Subject: [PATCH 1/6] Initial attempt --- Project.toml | 2 +- src/dense.jl | 13 ++++++++----- src/legacy.jl | 2 ++ 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index d6b8020..adcecd9 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = [ "Vaibhav Dixit , Guillaume Dalle and contributors", ] -version = "1.6.1" +version = "1.6.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/dense.jl b/src/dense.jl index 6757476..920d8f1 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -331,11 +331,12 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). # Constructors - AutoTapir(; safe_mode=true) + AutoTapir(; debug_mode::Bool) # Fields - - `safe_mode::Bool`: whether to run additional checks to catch errors early. While this is + - `safe_mode::Bool`: (to be renamed to `debug_mode` in the next breaking release) + whether to run additional checks to catch errors early. While this is on by default to ensure that users are aware of this option, you should generally turn it off for actual use, as it has substantial performance implications. If you encounter a problem with using Tapir (it fails to differentiate a function, or @@ -343,15 +344,17 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). on and look at what happens. Often errors are caught earlier and the error messages are more useful. """ -Base.@kwdef struct AutoTapir <: AbstractADType - safe_mode::Bool = true +struct AutoTapir <: AbstractADType + safe_mode::Bool end +AutoTapir(; debug_mode::Bool) = AutoTapir(debug_mode) + mode(::AutoTapir) = ReverseMode() function Base.show(io::IO, backend::AutoTapir) print(io, AutoTapir, "(") - !(backend.safe_mode) && print(io, "safe_mode=false") + !(backend.safe_mode) && print(io, "debug_mode=false") print(io, ")") end diff --git a/src/legacy.jl b/src/legacy.jl index 786299b..a61c685 100644 --- a/src/legacy.jl +++ b/src/legacy.jl @@ -15,6 +15,8 @@ @deprecate AutoReverseDiff(compile) AutoReverseDiff(; compile) +@deprecate AutoTapir() AutoTapir(; debug_mode=true), + function mtk_to_symbolics(obj_sparse::Bool, cons_sparse::Bool) if obj_sparse || cons_sparse return AutoSparse(AutoSymbolics()) From 403a104a6c79d0943666f550933efeedfaa82e94 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 2 Aug 2024 08:41:00 +0100 Subject: [PATCH 2/6] Remove errant comma --- src/legacy.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/legacy.jl b/src/legacy.jl index a61c685..2adc639 100644 --- a/src/legacy.jl +++ b/src/legacy.jl @@ -15,7 +15,7 @@ @deprecate AutoReverseDiff(compile) AutoReverseDiff(; compile) -@deprecate AutoTapir() AutoTapir(; debug_mode=true), +@deprecate AutoTapir() AutoTapir(; debug_mode=true) function mtk_to_symbolics(obj_sparse::Bool, cons_sparse::Bool) if obj_sparse || cons_sparse From 64f4f3010d3aad8019f920639c91aba5dc6d55a1 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 2 Aug 2024 08:50:24 +0100 Subject: [PATCH 3/6] Refactor --- src/dense.jl | 29 ++++++++++++++++++++++++++++- src/legacy.jl | 2 -- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/dense.jl b/src/dense.jl index 920d8f1..d612851 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -348,7 +348,34 @@ struct AutoTapir <: AbstractADType safe_mode::Bool end -AutoTapir(; debug_mode::Bool) = AutoTapir(debug_mode) +# This is a really awkward function to deprecate, because Julia does not dispatch on kwargs. +function AutoTapir(; + debug_mode::Union{Bool, Nothing}=nothing, safe_mode::Union{Bool, Nothing}=nothing, +) + if debug_mode !== nothing && safe_mode !== nothing + throw(ArgumentError( + "Both `debug_mode` and `safe_mode` have been set. Please only set `debug_mode`." + )) + end + + if safe_mode !== nothing + Base.depwarn( + "AutoTapir(; safe_mode) is deprecated, use AutoTapir(; debug_mode) instead.", + ((Base.Core).Typeof(AutoTapir)).name.mt.name, + ) + return AutoTapir(safe_mode) + end + + if debug_mode === nothing + Base.depwarn( + "AutoTapir() is deprecated, use AutoTapir(; debug_mode=true) instead.", + ((Base.Core).Typeof(AutoTapir)).name.mt.name, + ) + return AutoTapir(true) + else + return AutoTapir(debug_mode) + end +end mode(::AutoTapir) = ReverseMode() diff --git a/src/legacy.jl b/src/legacy.jl index 2adc639..786299b 100644 --- a/src/legacy.jl +++ b/src/legacy.jl @@ -15,8 +15,6 @@ @deprecate AutoReverseDiff(compile) AutoReverseDiff(; compile) -@deprecate AutoTapir() AutoTapir(; debug_mode=true) - function mtk_to_symbolics(obj_sparse::Bool, cons_sparse::Bool) if obj_sparse || cons_sparse return AutoSparse(AutoSymbolics()) From c9493541df94c24ca6c46bc14d3697c75aff5bbf Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 2 Aug 2024 08:51:05 +0100 Subject: [PATCH 4/6] Formatting --- src/dense.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dense.jl b/src/dense.jl index d612851..07af01b 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -350,7 +350,7 @@ end # This is a really awkward function to deprecate, because Julia does not dispatch on kwargs. function AutoTapir(; - debug_mode::Union{Bool, Nothing}=nothing, safe_mode::Union{Bool, Nothing}=nothing, + debug_mode::Union{Bool, Nothing}=nothing, safe_mode::Union{Bool, Nothing}=nothing ) if debug_mode !== nothing && safe_mode !== nothing throw(ArgumentError( From 813ad8f0c33125e30a10421c30c9065431d73c15 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 2 Aug 2024 08:56:21 +0100 Subject: [PATCH 5/6] Testing --- test/dense.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/dense.jl b/test/dense.jl index 739cf59..cd68668 100644 --- a/test/dense.jl +++ b/test/dense.jl @@ -155,7 +155,8 @@ end @test mode(ad) isa ReverseMode @test ad.safe_mode - ad = AutoTapir(; safe_mode = false) + @test_warn "" AutoTapir(; safe_mode = false) + ad = AutoTapir(; debug_mode = false) @test !ad.safe_mode end From 94cb4633ff50dd3b43db41fd2caab5d1e53aa492 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 2 Aug 2024 09:03:38 +0100 Subject: [PATCH 6/6] Fix testing --- test/dense.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/dense.jl b/test/dense.jl index cd68668..72ff6d8 100644 --- a/test/dense.jl +++ b/test/dense.jl @@ -155,9 +155,13 @@ end @test mode(ad) isa ReverseMode @test ad.safe_mode - @test_warn "" AutoTapir(; safe_mode = false) - ad = AutoTapir(; debug_mode = false) + ad = AutoTapir(; safe_mode = false) @test !ad.safe_mode + + # Check that new interface works as intended. + @test_throws ArgumentError AutoTapir(; debug_mode=false, safe_mode=true) + @test !AutoTapir(; debug_mode=false).safe_mode + @test AutoTapir(; debug_mode=true).safe_mode end @testset "AutoTracker" begin