Skip to content

Commit 1f735b3

Browse files
committed
feat: Add forward mode Mooncake
1 parent 623e6af commit 1f735b3

File tree

6 files changed

+41
-3
lines changed

6 files changed

+41
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ADTypes"
22
uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
authors = ["Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors"]
4-
version = "1.14.1"
4+
version = "1.15.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Algorithmic differentiation:
2121
```@docs
2222
AutoForwardDiff
2323
AutoPolyesterForwardDiff
24+
AutoMooncakeForward
2425
```
2526

2627
Finite differences:

src/ADTypes.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ export AutoChainRules,
4343
AutoGTPSA,
4444
AutoModelingToolkit,
4545
AutoMooncake,
46+
AutoMooncakeForward,
4647
AutoPolyesterForwardDiff,
4748
AutoReverseDiff,
4849
AutoSymbolics,

src/dense.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,31 @@ end
294294

295295
mode(::AutoMooncake) = ReverseMode()
296296

297+
"""
298+
AutoMooncakeForward
299+
300+
Struct used to select the [Mooncake.jl](https://github.com/compintell/Mooncake.jl) backend for automatic differentiation in forward mode.
301+
302+
Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
303+
304+
!!! info
305+
306+
This type was introduced when forward mode became available in Mooncake.jl. It was kept separate from [`AutoMooncake`](@ref) in order to avoid requiring a breaking release of ADTypes.jl.
307+
308+
# Constructors
309+
310+
AutoMooncakeForward(; config)
311+
312+
# Fields
313+
314+
- `config`: either `nothing` or an instance of `Mooncake.Config` -- see the docstring of `Mooncake.Config` for more information. `AutoForwardMooncake(; config=nothing)` is equivalent to `AutoForwardMooncake(; config=Mooncake.Config())`, i.e. the default configuration.
315+
"""
316+
Base.@kwdef struct AutoMooncakeForward{Tconfig} <: AbstractADType
317+
config::Tconfig
318+
end
319+
320+
mode(::AutoMooncakeForward) = ForwardMode()
321+
297322
"""
298323
AutoPolyesterForwardDiff{chunksize,T}
299324

test/dense.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ end
7171
@test ad.absstep === nothing
7272
@test ad.dir
7373

74-
ad = AutoFiniteDiff(; fdtype = Val(:central), fdjtype = Val(:forward), relstep = 1e-3, absstep = 1e-4, dir = false)
74+
ad = AutoFiniteDiff(; fdtype = Val(:central), fdjtype = Val(:forward),
75+
relstep = 1e-3, absstep = 1e-4, dir = false)
7576
@test ad isa AbstractADType
7677
@test ad isa AutoFiniteDiff
7778
@test mode(ad) isa ForwardMode
@@ -126,13 +127,21 @@ end
126127
end
127128

128129
@testset "AutoMooncake" begin
129-
ad = AutoMooncake(; config=nothing)
130+
ad = AutoMooncake(; config = nothing)
130131
@test ad isa AbstractADType
131132
@test ad isa AutoMooncake
132133
@test mode(ad) isa ReverseMode
133134
@test ad.config === nothing
134135
end
135136

137+
@testset "AutoMooncakeForward" begin
138+
ad = AutoMooncakeForward(; config = nothing)
139+
@test ad isa AbstractADType
140+
@test ad isa AutoMooncakeForward
141+
@test mode(ad) isa ForwardMode
142+
@test ad.config === nothing
143+
end
144+
136145
@testset "AutoPolyesterForwardDiff" begin
137146
ad = AutoPolyesterForwardDiff()
138147
@test ad isa AbstractADType

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ function every_ad_with_options()
7070
AutoForwardDiff(chunksize = 3, tag = :tag),
7171
AutoGTPSA(),
7272
AutoGTPSA(descriptor = Val(:descriptor)),
73+
AutoMooncake(; config = :config),
74+
AutoMooncakeForward(; config = :config),
7375
AutoPolyesterForwardDiff(),
7476
AutoPolyesterForwardDiff(chunksize = 3, tag = :tag),
7577
AutoReverseDiff(),

0 commit comments

Comments
 (0)