Skip to content

Commit a2e58ee

Browse files
authored
Make ChainRulesCore a weak dependency (#445)
1 parent 10a710f commit a2e58ee

File tree

3 files changed

+21
-11
lines changed

3 files changed

+21
-11
lines changed

Project.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Roots"
22
uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
3-
version = "2.1.8"
3+
version = "2.2.0"
44

55
[deps]
66
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -9,12 +9,14 @@ CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
99
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1010

1111
[weakdeps]
12+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1213
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1314
IntervalRootFinding = "d2bf35a9-74e0-55ec-b149-d360ff49b807"
1415
SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"
1516
SymPyPythonCall = "bc8888f7-b21e-4b7c-a06a-5d9c9496438c"
1617

1718
[extensions]
19+
RootsChainRulesCoreExt = "ChainRulesCore"
1820
RootsForwardDiffExt = "ForwardDiff"
1921
RootsIntervalRootFindingExt = "IntervalRootFinding"
2022
RootsSymPyExt = "SymPy"

src/chain_rules.jl renamed to ext/RootsChainRulesCoreExt.jl

+14-8
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
module RootsChainRulesCoreExt
2+
3+
using Roots
4+
import ChainRulesCore
5+
16
# View find_zero as solving `f(x, p) = 0` for `xᵅ(p)`.
27
# This is implicitly defined. By the implicit function theorem, we have:
38
# ∇f = 0 => ∂/∂ₓ f(xᵅ, p) ⋅ ∂xᵅ/∂ₚ + ∂/∂ₚf(x\^α, p) ⋅ I = 0
@@ -15,7 +20,6 @@
1520
# that is fixable.)
1621

1722
# this assumes a function and a parameter `p` passed in
18-
import ChainRulesCore: Tangent, NoTangent, frule, rrule
1923
function ChainRulesCore.frule(
2024
config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasForwardsMode},
2125
(_, _, _, Δp),
@@ -42,17 +46,17 @@ ChainRulesCore.frule(
4246
config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasForwardsMode},
4347
xdots,
4448
::typeof(solve),
45-
ZP::Roots.ZeroProblem,
49+
ZP::ZeroProblem,
4650
M::Roots.AbstractUnivariateZeroMethod,
4751
::Nothing;
4852
kwargs...,
49-
) = frule(config, xdots, solve, ZP, M; kwargs...)
53+
) = ChainRulesCore.frule(config, xdots, solve, ZP, M; kwargs...)
5054

5155
function ChainRulesCore.frule(
5256
config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasForwardsMode},
5357
(_, Δq, _),
5458
::typeof(solve),
55-
ZP::Roots.ZeroProblem,
59+
ZP::ZeroProblem,
5660
M::Roots.AbstractUnivariateZeroMethod;
5761
kwargs...,
5862
)
@@ -61,12 +65,12 @@ function ChainRulesCore.frule(
6165
zprob2 = ZeroProblem(|>, ZP.x₀)
6266
nms = fieldnames(typeof(foo))
6367
nt = NamedTuple{nms}(getfield(foo, n) for n in nms)
64-
dfoo = Tangent{typeof(foo)}(; nt...)
68+
dfoo = ChainRulesCore.Tangent{typeof(foo)}(; nt...)
6569

66-
return frule(
70+
return ChainRulesCore.frule(
6771
config,
68-
(NoTangent(), NoTangent(), NoTangent(), dfoo),
69-
Roots.solve,
72+
(ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), dfoo),
73+
solve,
7074
zprob2,
7175
M,
7276
foo,
@@ -146,3 +150,5 @@ function ChainRulesCore.rrule(
146150

147151
return xᵅ, pullback_solve_ZeroProblem
148152
end
153+
154+
end # module

src/Roots.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ using Printf
2222
import CommonSolve
2323
import CommonSolve: solve, solve!, init
2424
using Accessors
25-
import ChainRulesCore
2625

2726
export fzero, fzeros, secant_method
2827

@@ -53,7 +52,6 @@ include("functions.jl")
5352
include("trace.jl")
5453
include("find_zero.jl")
5554
include("hybrid.jl")
56-
include("chain_rules.jl")
5755

5856
include("Bracketing/bracketing.jl")
5957
include("Bracketing/bisection.jl")
@@ -83,4 +81,8 @@ include("find_zeros.jl")
8381
include("simple.jl")
8482
include("alternative_interfaces.jl")
8583

84+
if !isdefined(Base, :get_extension)
85+
include("../ext/RootsChainRulesCoreExt.jl")
86+
end
87+
8688
end

0 commit comments

Comments
 (0)