Skip to content

Add AutoGTPSA backend #78

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ AutoFiniteDiff
AutoFiniteDifferences
```

Taylor mode:

```@docs
AutoGTPSA
```

### Reverse mode

```@docs
Expand Down
1 change: 1 addition & 0 deletions src/ADTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ export AutoChainRules,
AutoFiniteDiff,
AutoFiniteDifferences,
AutoForwardDiff,
AutoGTPSA,
AutoModelingToolkit,
AutoPolyesterForwardDiff,
AutoReverseDiff,
Expand Down
32 changes: 32 additions & 0 deletions src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,38 @@ function Base.show(io::IO, backend::AutoForwardDiff{chunksize}) where {chunksize
print(io, ")")
end

"""
AutoGTPSA{D}

Struct used to select the [GTPSA.jl](https://github.com/bmad-sim/GTPSA.jl) backend for automatic differentiation.

Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).

# Constructors

AutoGTPSA(; descriptor=nothing)

# Fields

- `descriptor::D`: can be either

+ a GTPSA `Descriptor` specifying the number of variables/parameters, parameter
order, individual variable/parameter truncation orders, and maximum order. See
the [GTPSA.jl documentation](https://bmad-sim.github.io/GTPSA.jl/stable/man/c_descriptor/) for more details.
+ `nothing` to automatically use a `Descriptor` given the context.
"""
Base.@kwdef struct AutoGTPSA{D} <: AbstractADType
descriptor::D = nothing
end

mode(::AutoGTPSA) = ForwardMode()

function Base.show(io::IO, backend::AutoGTPSA{D}) where {D}
print(io, AutoGTPSA, "(")
D != Nothing && print(io, "descriptor=", repr(backend.descriptor; context = io))
print(io, ")")
end

"""
AutoPolyesterForwardDiff{chunksize,T}

Expand Down
14 changes: 14 additions & 0 deletions test/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,20 @@ end
@test ad.tag == CustomTag()
end

@testset "AutoGTPSA" begin
ad = AutoGTPSA(; descriptor = nothing)
@test ad isa AbstractADType
@test ad isa AutoGTPSA{Nothing}
@test mode(ad) isa ForwardMode
@test ad.descriptor === nothing

ad = AutoGTPSA(; descriptor = Val(:descriptor))
@test ad isa AbstractADType
@test ad isa AutoGTPSA{Val{:descriptor}}
@test mode(ad) isa ForwardMode
@test ad.descriptor == Val(:descriptor)
end

@testset "AutoPolyesterForwardDiff" begin
ad = AutoPolyesterForwardDiff()
@test ad isa AbstractADType
Expand Down
2 changes: 2 additions & 0 deletions test/misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ for backend in [
ADTypes.AutoFiniteDifferences(; fdm = :fdm),
ADTypes.AutoForwardDiff(),
ADTypes.AutoForwardDiff(chunksize = 3, tag = :tag),
ADTypes.AutoGTPSA(),
ADTypes.AutoGTPSA(; descriptor = Val(:descriptor)),
ADTypes.AutoPolyesterForwardDiff(),
ADTypes.AutoPolyesterForwardDiff(chunksize = 3, tag = :tag),
ADTypes.AutoReverseDiff(),
Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ function every_ad()
AutoFiniteDiff(),
AutoFiniteDifferences(; fdm = :fdm),
AutoForwardDiff(),
AutoGTPSA(),
AutoPolyesterForwardDiff(),
AutoReverseDiff(),
AutoSymbolics(),
Expand All @@ -64,6 +65,8 @@ function every_ad_with_options()
AutoFiniteDifferences(; fdm = :fdm),
AutoForwardDiff(),
AutoForwardDiff(chunksize = 3, tag = :tag),
AutoGTPSA(),
AutoGTPSA(descriptor = Val(:descriptor)),
AutoPolyesterForwardDiff(),
AutoPolyesterForwardDiff(chunksize = 3, tag = :tag),
AutoReverseDiff(),
Expand Down
Loading