Forward-mode AD #389
This is great. I've left a few comments, but if you're planning to do a bunch of additional stuff, then maybe they're redundant. Either way, don't feel the need to respond to them.
Co-authored-by: Will Tebbutt <[email protected]> Signed-off-by: Guillaume Dalle <[email protected]>
@willtebbutt following our discussion yesterday I scratched my head some more, and I decided that it would be infinitely simpler to enforce the invariant that one line of primal IR maps to one line of dual IR. While this may require additional fallbacks in the Julia code itself, I hope it will make our lives much easier on the IR side. What do you think?
I think this could work. You could just replace the
@inline function call_frule!!(rule::R, fargs::Vararg{Any,N}) where {R,N}
    return rule(map(x -> x isa Dual ? x : zero_dual(x), fargs)...)
end
The optimisation pass will lower this to what we were thinking about writing out in the IR anyway. I think the other important kinds of nodes would be largely straightforward to handle.
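To make the lifting idea concrete, here is a minimal self-contained sketch. The `Dual` struct, `zero_dual`, and `mul_rule` are simplified stand-ins written for this illustration (assumptions, not Mooncake's actual definitions), and the `where {R,N}` signature is likewise assumed.

```julia
# Hypothetical, simplified stand-ins for illustration; not Mooncake's real types.
struct Dual{P,T}
    primal::P
    tangent::T
end

# Lift a plain value to a dual with a zero tangent (only Float64 is handled here).
zero_dual(x::Float64) = Dual(x, 0.0)

# Wrap every argument that is not already a Dual, then call the rule.
@inline function call_frule!!(rule::R, fargs::Vararg{Any,N}) where {R,N}
    return rule(map(x -> x isa Dual ? x : zero_dual(x), fargs)...)
end

# A toy forward rule for multiplication: the product rule on duals.
function mul_rule(a::Dual, b::Dual)
    return Dual(a.primal * b.primal, a.tangent * b.primal + a.primal * b.tangent)
end

# The literal 2.0 is lifted to Dual(2.0, 0.0); x carries the seed tangent 1.0.
x = Dual(3.0, 1.0)
y = call_frule!!(mul_rule, 2.0, x)
```

In this sketch the rule itself only ever sees `Dual` arguments, so the decision about which arguments are constants lives in one place rather than in each rule.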
I think we might need to be slightly more subtle. If an argument to the |
Yes. I think my proposed code handles this though, or am I missing something?
In the spirit of higher-order AD, we may encounter |
Very good point.
Agreed. Specifically, I think we need to distinguish between literals / |
I still need to dig into the different node types we might encounter (and I still don't understand |
I was reviewing the design docs and realised that, sadly, the "one line of primal IR maps to one line of dual IR" won't work for |
I think that's okay, the main trouble is adding new lines which insert new variables because it requires manual renumbering. A GoTo should be much simpler.
Were the difficulties around renumbering etc not resolved by not |
No they weren't. I experimented with |
Ah, right, but we do need to insert a new SSAValue. Suppose that the
GotoIfNot(%5, #3)
i.e. jump to block 3 if not
%new_ssa = Expr(:call, primal, %5)
GotoIfNot(%new_ssa, #3)
Does this not cause the same kind of problems?
Oh yes you're probably right. Although it might be slightly less of a hassle because the new SSA is only used in one spot, right after. I'll take a look.
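The need for an extra statement before the branch can be seen without touching IR at all. The following is a hand-written sketch of a dual version of `f(x) = x > 1 ? 2x : 3 + x`, assuming that the branch condition is itself wrapped in a dual. The `Dual` and `NoTangent` types and the `primal` accessor here are hypothetical stand-ins for illustration, not Mooncake's definitions.

```julia
# Hypothetical stand-in types for illustration; not Mooncake's real definitions.
struct Dual{P,T}
    primal::P
    tangent::T
end
struct NoTangent end

primal(d::Dual) = d.primal

# Hand-written dual version of f(x) = x > 1 ? 2x : 3 + x.
function dual_f(x::Dual{Float64,Float64})
    cond = Dual(primal(x) > 1, NoTangent())  # the comparison result is wrapped too
    c = primal(cond)                         # extra statement: extract the plain Bool
    if c                                     # the branch can only consume a Bool
        return Dual(2 * primal(x), 2 * x.tangent)
    else
        return Dual(3 + primal(x), x.tangent)
    end
end

a = dual_f(Dual(2.0, 1.0))  # takes the 2x branch
b = dual_f(Dual(0.5, 1.0))  # takes the 3 + x branch
```

The line computing `c` plays the role of the `%new_ssa` statement discussed above: it has no counterpart in the primal code, and it is used exactly once, by the branch that follows it.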
Do you know what I should do about expressions of type |
Yup -- I just strip them out of the IR entirely in reverse-mode. See https://github.com/compintell/Mooncake.jl/blob/0f37c079bd1ae064e7b84696eed4a1f7eb763f1f/src/interpreter/s2s_reverse_mode_ad.jl#L728 The way to remove an instruction from an |
I think this works for
MWE (requires this branch of Mooncake):
const CC = Core.Compiler
using Mooncake
using MistyClosures
f(x) = x > 1 ? 2x : 3 + x
ir = Base.code_ircode(f, (Float64,))[1][1]
initial_ir = copy(ir)
get_primal_inst = CC.NewInstruction(Expr(:call, +, 1, 2), Any) # placeholder for get_primal
# Insert the placeholder before statement 3 (the GotoIfNot), then renumber.
CC.insert_node!(ir, CC.SSAValue(3), get_primal_inst, false)
ir = CC.compact!(ir)
# Point each GotoIfNot at the statement inserted immediately before it.
for k in 1:length(ir.stmts)
    inst = ir[CC.SSAValue(k)][:stmt]
    if inst isa Core.GotoIfNot
        Mooncake.replace_call!(ir, CC.SSAValue(k), Core.GotoIfNot(CC.SSAValue(k-1), inst.dest))
    end
end
ir
julia> initial_ir
5 1 ─ %1 = Base.lt_float(1.0, _2)::Bool │╻╷╷ >
│ %2 = Base.or_int(%1, false)::Bool ││╻ <
└── goto #3 if not %2 │
2 ─ %4 = Base.mul_float(2.0, _2)::Float64 ││╻ *
└── return %4 │
3 ─ %6 = Base.add_float(3.0, _2)::Float64 ││╻ +
└── return %6 │
julia> ir
5 1 ─ %1 = Base.lt_float(1.0, _2)::Bool │╻╷╷ >
│ Base.or_int(%1, false)::Bool ││╻ <
│ %3 = (+)(1, 2)::Any │
└── goto #3 if not %3 │
2 ─ %5 = Base.mul_float(2.0, _2)::Float64 ││╻ *
└── return %5 │
3 ─ %7 = Base.add_float(3.0, _2)::Float64 ││╻ +
└── return %7
Pull Request Overview
This pull request implements forward-mode automatic differentiation (AD) for Mooncake.jl, expanding the framework to support both forward and reverse mode AD. The implementation adds dual number support, forward mode rules, and testing infrastructure.
- Adds forward-mode AD support with dual numbers and frule!! implementation
- Updates all existing rules to support both forward and reverse mode differentiation
- Refactors testing infrastructure to handle both AD modes
- Introduces new macros for defining mode-agnostic rules
Reviewed Changes
Copilot reviewed 79 out of 80 changed files in this pull request and generated 4 comments.
File | Description |
---|---|
test/tools_for_rules.jl | Updates macro names and testing infrastructure to support both forward and reverse modes |
test/runtests.jl | Adds test for forward mode AD and disables piracy checks |
test/rrules/*.jl | Updates test function calls to use new unified testing interface |
test/interpreter/*.jl | Adds forward mode testing and updates test signatures |
test/integration_testing/*.jl | Updates test calls to specify reverse mode explicitly |
src/utils.jl | Adds utility functions for forward mode and Julia version compatibility |
src/tools_for_rules.jl | Introduces @from_chainrules macro and forward mode rule construction |
src/test_utils.jl | Major refactoring to support testing both forward and reverse modes |
src/rrules/*.jl | Adds frule!! implementations for all existing rrule!! definitions |
Comments suppressed due to low confidence (2)
test/tools_for_rules.jl:22
- [nitpick] The macro name `@zero_derivative` is inconsistent with the existing `@zero_adjoint` macro pattern. Consider using `@zero_tangent` or maintaining `@zero_adjoint` for consistency.
@zero_derivative MinimalCtx Tuple{typeof(zero_tester),Float64}
test/tools_for_rules.jl:37
- [nitpick] The macro name `@from_chainrules` is unclear about its purpose. Consider `@from_chain_rules` or `@chainrules_to_mooncake` for better clarity.
@from_chainrules DefaultCtx Tuple{typeof(bleh),Float64,Int} false
Okay that was completely useless 🤣 good to know though
Do we want to export |
Co-authored-by: Guillaume Dalle <[email protected]> Signed-off-by: Will Tebbutt <[email protected]>
I've addressed most of the comments @gdalle, but there are a few more to deal with before merging. Thanks for finding a pretty wide selection of things which weren't quite correct!
"""
    prepare_derivative_cache(f, x...)
Returns a cache used with [`value_and_derivative!!`](@ref). See that function for more info.
This is a very rough backbone of forward mode AD, based on #386 and the existing reverse mode implementation.
Will's edits (apologies for editing your thing @gdalle -- I just want to make sure that the todo list is at the top of the PR):
Todo:
- make FunctionWrappers work correctly (not going to do this in this PR)
- `is_primitive` separately for forwards and reverse pass.
- `@from_rule`: `@from_chainrules` or `@from_chain_rule`, see comments below.
- `UpsilonNode`s and `PhiCNode`s.
- bump to version 0.5 (actually not needed)

Once the above are complete, I'll request reviews.