From 37e3334dbf0fcca0bf435c5a6d1c4dc74d4d4637 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 26 Jun 2024 15:04:45 +0200 Subject: [PATCH 1/3] Fix pretty printing and ReverseDiff constructor --- Project.toml | 2 +- src/dense.jl | 83 ++++++++++++++++---------------------------------- src/legacy.jl | 2 ++ src/sparse.jl | 10 +++--- test/legacy.jl | 5 +++ test/misc.jl | 1 + 6 files changed, 41 insertions(+), 62 deletions(-) diff --git a/Project.toml b/Project.toml index d9ea062..9557b76 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = [ "Vaibhav Dixit , Guillaume Dalle and contributors", ] -version = "1.5.1" +version = "1.5.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/dense.jl b/src/dense.jl index 7958a67..9951e84 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -20,7 +20,7 @@ end mode(::AutoChainRules) = ForwardOrReverseMode() # specialized in the extension function Base.show(io::IO, backend::AutoChainRules) - print(io, "AutoChainRules(ruleconfig=$(repr(backend.ruleconfig, context=io)))") + print(io, AutoChainRules, "(ruleconfig=", repr(backend.ruleconfig; context = io), ")") end """ @@ -63,11 +63,9 @@ end mode(::AutoEnzyme) = ForwardOrReverseMode() # specialized in the extension function Base.show(io::IO, backend::AutoEnzyme) - if isnothing(backend.mode) - print(io, "AutoEnzyme()") - else - print(io, "AutoEnzyme(mode=$(repr(backend.mode, context=io)))") - end + print(io, AutoEnzyme, "(") + !isnothing(backend.mode) && print(io, "mode=", repr(backend.mode; context = io)) + print(io, ")") end """ @@ -111,21 +109,14 @@ end mode(::AutoFiniteDiff) = ForwardMode() function Base.show(io::IO, backend::AutoFiniteDiff) - s = "AutoFiniteDiff(" - if backend.fdtype != Val(:forward) - s *= "fdtype=$(repr(backend.fdtype, context=io)), " - end - if backend.fdjtype != backend.fdtype - s *= "fdjtype=$(repr(backend.fdjtype, context=io)), " - end - if backend.fdhtype != Val(:hcentral) - s *= "fdhtype=$(repr(backend.fdhtype, context=io)), " - end - if endswith(s, ", ") - s = s[1:(end - 2)] - end - s *= ")" - print(io, s) + print(io, AutoFiniteDiff, "(") + backend.fdtype != Val(:forward) && + print(io, "fdtype=", repr(backend.fdtype; context = io), ",") + backend.fdjtype != backend.fdtype && + print(io, "fdjtype=", repr(backend.fdjtype; context = io), ",") + backend.fdhtype != Val(:hcentral) && + print(io, "fdhtype=", repr(backend.fdhtype; context = io), ",") + print(io, ")") end """ @@ -150,7 +141,7 @@ end mode(::AutoFiniteDifferences) = ForwardMode() function Base.show(io::IO, backend::AutoFiniteDifferences) - print(io, "AutoFiniteDifferences(fdm=$(repr(backend.fdm, context=io)))") + print(io, AutoFiniteDifferences, "(fdm=", repr(backend.fdm; context = io), ")") end """ @@ -183,18 +174,10 @@ end mode(::AutoForwardDiff) = ForwardMode() function Base.show(io::IO, backend::AutoForwardDiff{chunksize}) where {chunksize} - s = "AutoForwardDiff(" - if chunksize !== nothing - s *= "chunksize=$chunksize, " - end - if backend.tag !== nothing - s *= "tag=$(repr(backend.tag, context=io)), " - end - if endswith(s, ", ") - s = s[1:(end - 2)] - end - s *= ")" - print(io, s) + print(io, AutoForwardDiff, "(") + chunksize !== nothing && print(io, "chunksize=", repr(chunksize; context = io), ",") + backend.tag !== nothing && print(io, "tag=", repr(backend.tag; context = io), ",") + print(io, ")") end """ @@ -227,18 +210,10 @@ end mode(::AutoPolyesterForwardDiff) = ForwardMode() function Base.show(io::IO, backend::AutoPolyesterForwardDiff{chunksize}) where {chunksize} - s = "AutoPolyesterForwardDiff(" - if chunksize !== nothing - s *= "chunksize=$chunksize, " - end - if backend.tag !== nothing - s *= "tag=$(repr(backend.tag, context=io)), " - end - if endswith(s, ", ") - s = s[1:(end - 2)] - end - s *= ")" - print(io, s) + print(io, AutoPolyesterForwardDiff, "(") + chunksize !== nothing && print(io, "chunksize=", repr(chunksize; context = io), ",") + backend.tag !== nothing && print(io, "tag=", repr(backend.tag; context = io), ",") + print(io, ")") end """ @@ -277,11 +252,9 @@ end mode(::AutoReverseDiff) = ReverseMode() function Base.show(io::IO, ::AutoReverseDiff{compile}) where {compile} - if !compile - print(io, "AutoReverseDiff()") - else - print(io, "AutoReverseDiff(compile=true)") - end + print(io, AutoReverseDiff, "(") + compile && print(io, "compile=true") + print(io, ")") end """ @@ -321,11 +294,9 @@ end mode(::AutoTapir) = ReverseMode() function Base.show(io::IO, backend::AutoTapir) - if backend.safe_mode - print(io, "AutoTapir()") - else - print(io, "AutoTapir(safe_mode=false)") - end + print(io, AutoReverseDiff, "(") + !(backend.safe_mode) && print(io, "safe_mode=false") + print(io, ")") end """ diff --git a/src/legacy.jl b/src/legacy.jl index 5784399..91d8acc 100644 --- a/src/legacy.jl +++ b/src/legacy.jl @@ -11,6 +11,8 @@ @deprecate AutoSparseZygote() AutoSparse(AutoZygote()) +@deprecate AutoReverseDiff(compile) AutoReverseDiff(; compile) + function mtk_to_symbolics(obj_sparse::Bool, cons_sparse::Bool) if obj_sparse || cons_sparse return AutoSparse(AutoSymbolics()) diff --git a/src/sparse.jl b/src/sparse.jl index 85a8e97..de7c5a9 100644 --- a/src/sparse.jl +++ b/src/sparse.jl @@ -155,15 +155,15 @@ function AutoSparse( end function Base.show(io::IO, backend::AutoSparse) - s = "AutoSparse(dense_ad=$(repr(backend.dense_ad, context=io)), " + print(io, AutoSparse, "(dense_ad=", repr(backend.dense_ad, context = io), ",") if backend.sparsity_detector != NoSparsityDetector() - s *= "sparsity_detector=$(repr(backend.sparsity_detector, context=io)), " + print(io, "sparsity_detector=", repr(backend.sparsity_detector, context = io), ",") end if backend.coloring_algorithm != NoColoringAlgorithm() - s *= "coloring_algorithm=$(repr(backend.coloring_algorithm, context=io))), " + print( + io, "coloring_algorithm=", repr(backend.coloring_algorithm, context = io), ",") end - s = s[1:(end - 2)] * ")" - print(io, s) + print(io, ")") end """ diff --git a/test/legacy.jl b/test/legacy.jl index 08f09d9..d4f2076 100644 --- a/test/legacy.jl +++ b/test/legacy.jl @@ -58,3 +58,8 @@ end @test ad isa AbstractADType @test dense_ad(ad) isa AutoZygote end + +@testset "AutoReverseDiff without kwarg" begin + ad = @test_deprecated AutoReverseDiff(true) + @test ad.compile +end diff --git a/test/misc.jl b/test/misc.jl index a1b2274..5f03a94 100644 --- a/test/misc.jl +++ b/test/misc.jl @@ -7,6 +7,7 @@ end @testset "Printing" begin for ad in every_ad_with_options() @test startswith(string(ad), "Auto") + @test contains(string(ad), "(") @test endswith(string(ad), ")") end From 4341595f018ca879f8762d5e1a5f14c7c9faa5e0 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 26 Jun 2024 15:11:40 +0200 Subject: [PATCH 2/3] Better --- src/dense.jl | 16 +++++++++------- src/sparse.jl | 6 +++--- test/misc.jl | 37 +++++++++++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 10 deletions(-) diff --git a/src/dense.jl b/src/dense.jl index 9951e84..dfcedd9 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -111,11 +111,11 @@ mode(::AutoFiniteDiff) = ForwardMode() function Base.show(io::IO, backend::AutoFiniteDiff) print(io, AutoFiniteDiff, "(") backend.fdtype != Val(:forward) && - print(io, "fdtype=", repr(backend.fdtype; context = io), ",") + print(io, "fdtype=", repr(backend.fdtype; context = io), ", ") backend.fdjtype != backend.fdtype && - print(io, "fdjtype=", repr(backend.fdjtype; context = io), ",") + print(io, "fdjtype=", repr(backend.fdjtype; context = io), ", ") backend.fdhtype != Val(:hcentral) && - print(io, "fdhtype=", repr(backend.fdhtype; context = io), ",") + print(io, "fdhtype=", repr(backend.fdhtype; context = io)) print(io, ")") end @@ -175,8 +175,9 @@ mode(::AutoForwardDiff) = ForwardMode() function Base.show(io::IO, backend::AutoForwardDiff{chunksize}) where {chunksize} print(io, AutoForwardDiff, "(") - chunksize !== nothing && print(io, "chunksize=", repr(chunksize; context = io), ",") - backend.tag !== nothing && print(io, "tag=", repr(backend.tag; context = io), ",") + chunksize !== nothing && print(io, "chunksize=", repr(chunksize; context = io), + (backend.tag !== nothing ? ", " : "")) + backend.tag !== nothing && print(io, "tag=", repr(backend.tag; context = io)) print(io, ")") end @@ -211,8 +212,9 @@ mode(::AutoPolyesterForwardDiff) = ForwardMode() function Base.show(io::IO, backend::AutoPolyesterForwardDiff{chunksize}) where {chunksize} print(io, AutoPolyesterForwardDiff, "(") - chunksize !== nothing && print(io, "chunksize=", repr(chunksize; context = io), ",") - backend.tag !== nothing && print(io, "tag=", repr(backend.tag; context = io), ",") + chunksize !== nothing && print(io, "chunksize=", repr(chunksize; context = io), + (backend.tag !== nothing ? ", " : "")) + backend.tag !== nothing && print(io, "tag=", repr(backend.tag; context = io)) print(io, ")") end diff --git a/src/sparse.jl b/src/sparse.jl index de7c5a9..ffd157a 100644 --- a/src/sparse.jl +++ b/src/sparse.jl @@ -155,13 +155,13 @@ function AutoSparse( end function Base.show(io::IO, backend::AutoSparse) - print(io, AutoSparse, "(dense_ad=", repr(backend.dense_ad, context = io), ",") + print(io, AutoSparse, "(dense_ad=", repr(backend.dense_ad, context = io)) if backend.sparsity_detector != NoSparsityDetector() - print(io, "sparsity_detector=", repr(backend.sparsity_detector, context = io), ",") + print(io, ", sparsity_detector=", repr(backend.sparsity_detector, context = io)) end if backend.coloring_algorithm != NoColoringAlgorithm() print( - io, "coloring_algorithm=", repr(backend.coloring_algorithm, context = io), ",") + io, ", coloring_algorithm=", repr(backend.coloring_algorithm, context = io)) end print(io, ")") end diff --git a/test/misc.jl b/test/misc.jl index 5f03a94..0ca1ddd 100644 --- a/test/misc.jl +++ b/test/misc.jl @@ -20,3 +20,40 @@ end @test contains(string(sparse_backend1), string(AutoForwardDiff())) @test length(string(sparse_backend1)) < length(string(sparse_backend2)) end + +import ADTypes + +struct FakeSparsityDetector <: ADTypes.AbstractSparsityDetector end +struct FakeColoringAlgorithm <: ADTypes.AbstractColoringAlgorithm end + +for backend in [ + # dense + ADTypes.AutoChainRules(; ruleconfig = :rc), + ADTypes.AutoDiffractor(), + ADTypes.AutoEnzyme(), + ADTypes.AutoEnzyme(mode = :forward), + ADTypes.AutoFastDifferentiation(), + ADTypes.AutoFiniteDiff(), + ADTypes.AutoFiniteDiff(fdtype = :fd, fdjtype = :fdj, fdhtype = :fdh), + ADTypes.AutoFiniteDifferences(; fdm = :fdm), + ADTypes.AutoForwardDiff(), + ADTypes.AutoForwardDiff(chunksize = 3, tag = :tag), + ADTypes.AutoPolyesterForwardDiff(), + ADTypes.AutoPolyesterForwardDiff(chunksize = 3, tag = :tag), + ADTypes.AutoReverseDiff(), + ADTypes.AutoReverseDiff(compile = true), + ADTypes.AutoSymbolics(), + ADTypes.AutoTapir(), + ADTypes.AutoTapir(safe_mode = false), + ADTypes.AutoTracker(), + ADTypes.AutoZygote(), + # sparse + ADTypes.AutoSparse(ADTypes.AutoForwardDiff()), + ADTypes.AutoSparse( + ADTypes.AutoForwardDiff(); + sparsity_detector = FakeSparsityDetector(), + coloring_algorithm = FakeColoringAlgorithm() + ) +] + println(backend) +end From 32e0cd45f332889bd65285e1cfcdcbdce3c2f261 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 26 Jun 2024 15:13:10 +0200 Subject: [PATCH 3/3] Tapir bug --- src/dense.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dense.jl b/src/dense.jl index dfcedd9..dd7e8e5 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -296,7 +296,7 @@ end mode(::AutoTapir) = ReverseMode() function Base.show(io::IO, backend::AutoTapir) - print(io, AutoReverseDiff, "(") + print(io, AutoTapir, "(") !(backend.safe_mode) && print(io, "safe_mode=false") print(io, ")") end