Skip to content

Forward-mode AD #389

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 197 commits into
base: main
Choose a base branch
from
Open

Forward-mode AD #389

wants to merge 197 commits into from

Conversation

gdalle
Copy link
Collaborator

@gdalle gdalle commented Nov 24, 2024

This is a very rough backbone of forward mode AD, based on #386 and the existing reverse mode implementation.

Will's edits (apologies for editing your thing @gdalle -- I just want to make sure that the todo list is at the top of the PR):

Todo:

  • make FunctionWrappers work correctly not going to do this in this PR
  • add support for MistyClosures
  • add tests for Hessian vector products
  • define is_primitive separately for forwards and reverse pass.
  • do a complete pass to review design -- are there any high-level things we ought to modify?
  • improve DRY-ness of code, particularly in testing infrastructure in particular.
  • check GPU compatibility, make sure no major design issues prevent future GPU compatibility, and be explicit about what needs to be done in the future.
  • what name should we use for @from_rule: @from_chainrules or @from_chain_rule, see comments below.
  • add support for UpsilonNodes and PhiCNodes.
  • get all tests passing
  • bump to version 0.5 actually not needed

Once the above are complete, I'll request reviews.

Copy link

codecov bot commented Nov 24, 2024

Codecov Report

Attention: Patch coverage is 94.04070% with 82 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/interpreter/s2s_forward_mode_ad.jl 88.77% 22 Missing ⚠️
src/test_utils.jl 86.66% 16 Missing ⚠️
src/rrules/foreigncall.jl 75.75% 8 Missing ⚠️
src/rrules/memory.jl 87.69% 8 Missing ⚠️
src/utils.jl 76.92% 6 Missing ⚠️
src/rrules/tasks.jl 64.28% 5 Missing ⚠️
src/dual.jl 85.71% 3 Missing ⚠️
src/rrules/builtins.jl 97.82% 3 Missing ⚠️
src/developer_tools.jl 0.00% 2 Missing ⚠️
src/interpreter/s2s_reverse_mode_ad.jl 71.42% 2 Missing ⚠️
... and 5 more
Files with missing lines Coverage Δ
src/Mooncake.jl 100.00% <ø> (ø)
src/interpreter/ir_utils.jl 89.68% <100.00%> (+2.81%) ⬆️
src/rrules/array_legacy.jl 100.00% <100.00%> (ø)
src/rrules/avoiding_non_differentiable_code.jl 100.00% <100.00%> (ø)
src/rrules/blas.jl 99.64% <100.00%> (+0.84%) ⬆️
src/rrules/fastmath.jl 100.00% <100.00%> (ø)
src/rrules/lapack.jl 100.00% <100.00%> (+0.56%) ⬆️
src/rrules/linear_algebra.jl 100.00% <100.00%> (ø)
src/rrules/low_level_maths.jl 100.00% <100.00%> (ø)
src/rrules/new.jl 91.30% <100.00%> (+2.84%) ⬆️
... and 20 more

... and 2 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Collaborator

@willtebbutt willtebbutt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great. I've left a few comments, but if you're planning to do a bunch of additional stuff, then maybe they're redundant. Either way, don't feel the need to respond to them.

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 26, 2024

@willtebbutt following our discussion yesterday I scratched my head some more, and I decided that it would be infinitely simpler to enforce the invariant that one line of primal IR maps to one line of dual IR. While this may require additional fallbacks in the Julia code itself, I hope it will make our lives much easier on the IR side. What do you think?

@willtebbutt
Copy link
Collaborator

I think this could work.

You could just replace the frule!! calls with a call to a function call_frule!! which would be something like

@inline function call_frule!!(rule::R, fargs::Vararg{Any, N}) where {N}
    return rule(map(x -> x isa Dual ? x : zero_dual(x), fargs)...)
end

The optimisation pass will lower this to the what we were thinking about writing out in the IR anyway.

I think the other important kinds of nodes would be largely straightforward to handle.

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 26, 2024

I think we might need to be slightly more subtle. If an argument to the :call or :invoke expression is a CC.Argument or a CC.SSAValue, we don't wrap it in a Dual because we assume it will already be one, right?

@willtebbutt
Copy link
Collaborator

willtebbutt commented Nov 26, 2024

Yes. I think my propose code handles this though, or am I missing something?

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 26, 2024

In the spirit of higher-order AD, we may encounter Dual inputs that we want to wrap with a second Dual, and Dual inputs that we want to leave as-is. So I think this wrapping needs to be decided from the type of each argument in the IR?

@willtebbutt
Copy link
Collaborator

Very good point.

So I think this wrapping needs to be decided from the type of each argument in the IR?

Agreed. Specifically, I think we need to distinguish between literals / QuoteNodes / GlobalRefs, and Argument / SSAValues?

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 26, 2024

I still need to dig into the different node types we might encounter (and I still don't understand QuoteNodes) but yeah, Argument and SSAValue don't need to be wrapped.

@gdalle gdalle mentioned this pull request Nov 27, 2024
@willtebbutt
Copy link
Collaborator

I was reviewing the design docs and realised that, sadly, the "one line of primal IR maps to one line of dual IR" won't work for Core.GotoIfNot nodes. See https://compintell.github.io/Mooncake.jl/previews/PR386/developer_documentation/forwards_mode_design/#Statement-Transformation .

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 27, 2024

I think that's okay, the main trouble is adding new lines which insert new variables because it requires manual renumbering. A GoTo should be much simpler.

@willtebbutt
Copy link
Collaborator

Were the difficulties around renumbering etc not resolved by not compact!ing until the end? I feel like I might be missing something.

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 27, 2024

No they weren't. I experimented with compact! in various places and I was struggling a lot, so I asked Frames for advice. She agreed that insertion should usually be avoided.
If we have to insert something for GoTo, I think it will still be easier because we're not defining a new SSAValue so we don't have to adapt future statements that refer to it.

@willtebbutt
Copy link
Collaborator

willtebbutt commented Nov 27, 2024

Ah, right, but we do need to insert a new SSAValue. Suppose that the GotoIfNot of interest is

GotoIfNot(%5, #3)

i.e. jump to block 3 if not %5. In the forwards-mode IR this would become

%new_ssa = Expr(:call, primal, %5)
GotoIfNot(%new_ssa, #3)

Does this not cause the same kind of problems?

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 27, 2024

Oh yes you're probably right. Although it might be slightly less of a hassle because the new SSA is only used in one spot, right after. I'll take a look

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 27, 2024

Do you know what I should do about expressions of type :code_coverage_effect? I assume they're inserted automatically and they're alone on their lines?

@willtebbutt
Copy link
Collaborator

willtebbutt commented Nov 27, 2024

Yup -- I just strip them out of the IR entirely in reverse-mode. See https://github.com/compintell/Mooncake.jl/blob/0f37c079bd1ae064e7b84696eed4a1f7eb763f1f/src/interpreter/s2s_reverse_mode_ad.jl#L728

The way to remove an instruction from an IRCode is just to replace the instruction with nothing.

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 27, 2024

I think this works for GotoIfNot:

  1. make all the insertions necessary
  2. compact! once to make sure they applied
  3. shift the conditions of all GotoIfNot nodes to refer to the node right before them (where we get the primal value of the condition)

MWE (requires this branch of Mooncake):

const CC = Core.Compiler
using Mooncake
using MistyClosures

f(x) = x > 1 ? 2x : 3 + x
ir = Base.code_ircode(f, (Float64,))[1][1]
initial_ir = copy(ir)
get_primal_inst = CC.NewInstruction(Expr(:call, +, 1, 2), Any)  # placeholder for get_primal
CC.insert_node!(ir, CC.SSAValue(3), get_primal_inst, false)
ir = CC.compact!(ir)
for k in 1:length(ir.stmts)
    inst = ir[CC.SSAValue(k)][:stmt]
    if inst isa Core.GotoIfNot
        Mooncake.replace_call!(ir,CC.SSAValue(k), Core.GotoIfNot(CC.SSAValue(k-1), inst.dest))
    end
end
ir
julia> initial_ir
5 1%1 = Base.lt_float(1.0, _2)::Bool                                                                                 │╻╷╷ >%2 = Base.or_int(%1, false)::Bool                                                                                 ││╻   <
  └──      goto #3 if not %2                                                                                            │   
  2%4 = Base.mul_float(2.0, _2)::Float64                                                                             ││╻   *
  └──      return %43%6 = Base.add_float(3.0, _2)::Float64                                                                             ││╻   +
  └──      return %6                                                                                                    │   
                                                                                                                            

julia> ir
5 1%1 = Base.lt_float(1.0, _2)::Bool                                                                                 │╻╷╷ >
  │        Base.or_int(%1, false)::Bool                                                                                 ││╻   <%3 = (+)(1, 2)::Any                                                                                               │   
  └──      goto #3 if not %3                                                                                            │   
  2%5 = Base.mul_float(2.0, _2)::Float64                                                                             ││╻   *
  └──      return %53%7 = Base.add_float(3.0, _2)::Float64                                                                             ││╻   +
  └──      return %7      

Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This pull request implements forward-mode automatic differentiation (AD) for Mooncake.jl, expanding the framework to support both forward and reverse mode AD. The implementation adds dual number support, forward mode rules, and testing infrastructure.

  • Adds forward-mode AD support with dual numbers and frule!! implementation
  • Updates all existing rules to support both forward and reverse mode differentiation
  • Refactors testing infrastructure to handle both AD modes
  • Introduces new macros for defining mode-agnostic rules

Reviewed Changes

Copilot reviewed 79 out of 80 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
test/tools_for_rules.jl Updates macro names and testing infrastructure to support both forward and reverse modes
test/runtests.jl Adds test for forward mode AD and disables piracy checks
test/rrules/*.jl Updates test function calls to use new unified testing interface
test/interpreter/*.jl Adds forward mode testing and updates test signatures
test/integration_testing/*.jl Updates test calls to specify reverse mode explicitly
src/utils.jl Adds utility functions for forward mode and Julia version compatibility
src/tools_for_rules.jl Introduces @from_chainrules macro and forward mode rule construction
src/test_utils.jl Major refactoring to support testing both forward and reverse modes
src/rrules/*.jl Adds frule!! implementations for all existing rrule!! definitions
Comments suppressed due to low confidence (2)

test/tools_for_rules.jl:22

  • [nitpick] The macro name @zero_derivative is inconsistent with the existing @zero_adjoint macro pattern. Consider using @zero_tangent or maintaining @zero_adjoint for consistency.
@zero_derivative MinimalCtx Tuple{typeof(zero_tester),Float64}

test/tools_for_rules.jl:37

  • [nitpick] The macro name @from_chainrules is unclear about its purpose. Consider @from_chain_rules or @chainrules_to_mooncake for better clarity.
@from_chainrules DefaultCtx Tuple{typeof(bleh),Float64,Int} false

@gdalle
Copy link
Collaborator Author

gdalle commented Jul 25, 2025

Okay that was completely useless 🤣 good to know though

@willtebbutt
Copy link
Collaborator

Do we want to export Dual, in addition to making it part of the public interface? My concern is that lots of things define Dual types, and this might cause annoying name clashes. @gdalle what are your thoughts on this?

This was referenced Aug 3, 2025
@willtebbutt
Copy link
Collaborator

I've addressed most of the comments @gdalle , but there are a few more to deal with before merging. Thanks for finding a pretty wide selection of things which weren't quite correct!

"""
prepare_derivative_cache(f, x...)
Returns a cache used with [`value_and_derivative!!`](@ref). See that function for more info.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Returns a cache used with [`value_and_derivative!!`](@ref). See that function for more info.
Returns a cache used with [`value_and_derivative!!`](@ref). See that function for more info.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants