Skip to content

Commit 6fdb76f

Browse files
authored
Fix AD (#244)
* Fix mean_vector * Fix tests * Bump patch * Fix examples
1 parent bd15654 commit 6fdb76f

File tree

10 files changed

+38
-10
lines changed

10 files changed

+38
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Stheno"
22
uuid = "8188c328-b5d6-583d-959b-9690869a5511"
3-
version = "0.8.1"
3+
version = "0.8.2"
44

55
[deps]
66
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"

examples/extended_mauna_loa/script.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ ml_df, Ttr_df, Tte_df = let
4848
mauna_loa_co2 = let
4949
mauna_loa_data_raw = CSV.read(
5050
joinpath(datadep"mauna_loa", "monthly_in_situ_co2_mlo.csv"), DataFrame;
51-
skipto=58, header=false,
51+
skipto=61, header=false,
5252
)
5353

5454
data = DataFrame(

src/Stheno.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ module Stheno
1919

2020
const AV{T} = AbstractVector{T}
2121

22+
import ChainRulesCore: rrule
23+
2224
# A couple of AbstractVector subtypes useful for expressing structure in inputs
2325
# regularly found in GPPPs.
2426
include("input_collection_types.jl")

src/affine_transformations/addition.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ end
1919

2020
const add_args = Tuple{typeof(+), AbstractGP, AbstractGP}
2121

22+
@opt_out rrule(::RuleConfig{>:HasReverseMode}, ::typeof(mean), ::add_args, ::AV)
23+
@opt_out rrule(::RuleConfig{>:HasReverseMode}, ::typeof(cov), ::add_args, ::AV)
24+
@opt_out rrule(::RuleConfig{>:HasReverseMode}, ::typeof(var), ::add_args, ::AV)
25+
2226
mean((_, fa, fb)::add_args, x::AV) = mean(fa, x) .+ mean(fb, x)
2327

2428
function cov((_, fa, fb)::add_args, x::AV)
@@ -62,6 +66,10 @@ end
6266

6367
const add_known{T} = Tuple{typeof(+), T, AbstractGP}
6468

69+
@opt_out rrule(::RuleConfig{>:HasReverseMode}, ::typeof(mean), ::add_known, ::AV)
70+
@opt_out rrule(::RuleConfig{>:HasReverseMode}, ::typeof(cov), ::add_known, ::AV)
71+
@opt_out rrule(::RuleConfig{>:HasReverseMode}, ::typeof(var), ::add_known, ::AV)
72+
6573
mean((_, b, f)::add_known, x::AV) = b.(x) .+ mean(f, x)
6674
mean((_, b, f)::add_known{<:Real}, x::AV) = b .+ mean(f, x)
6775

src/affine_transformations/compose.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@ Constructs the DerivedGP f′ given by f′(x) := f(g(x))
77
"""
88
(f::AbstractGP, g) = DerivedGP((, f, g), f.gpc)
99

10-
1110
const comp_args = Tuple{typeof(), AbstractGP, Any}
1211

12+
@opt_out rrule(::RuleConfig{>:HasReverseMode}, ::typeof(mean), ::comp_args, ::AV)
13+
@opt_out rrule(::RuleConfig{>:HasReverseMode}, ::typeof(cov), ::comp_args, ::AV)
14+
@opt_out rrule(::RuleConfig{>:HasReverseMode}, ::typeof(var), ::comp_args, ::AV)
15+
1316
mean((_, f, g)::comp_args, x::AV) = mean(f, g.(x))
1417

1518
cov((_, f, g)::comp_args, x::AV) = cov(f, g.(x))

src/affine_transformations/cross.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ end
1414
_collect(X::BlockArray) = Array(X)
1515

1616
function ChainRulesCore.rrule(::typeof(_collect), X::BlockArray)
17-
function Array_pullback::Array)
18-
ΔX = Tangent{Any}(blocks=BlockArray(Δ, axes(X)).blocks, axes=NoTangent())
17+
function Array_pullback::AbstractArray)
18+
ΔX = Tangent{Any}(blocks=BlockArray(collect(Δ), axes(X)).blocks, axes=NoTangent())
1919
return (NoTangent(), ΔX)
2020
end
2121
return Array(X), Array_pullback
@@ -45,9 +45,12 @@ function consistency_checks(fs)
4545
end
4646
ChainRulesCore.@non_differentiable consistency_checks(::Any)
4747

48-
4948
const cross_args{T<:AbstractVector{<:AbstractGP}} = Tuple{typeof(cross), T}
5049

50+
@opt_out rrule(::RuleConfig{>:HasReverseMode}, ::typeof(mean), ::cross_args, ::AV)
51+
@opt_out rrule(::RuleConfig{>:HasReverseMode}, ::typeof(cov), ::cross_args, ::AV)
52+
@opt_out rrule(::RuleConfig{>:HasReverseMode}, ::typeof(var), ::cross_args, ::AV)
53+
5154
function mean((_, fs)::cross_args, x::BlockData)
5255
blks = map((f, blk)->mean(f, blk), fs, blocks(x))
5356
return _collect(_mortar(blks))

src/affine_transformations/product.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,14 @@ If `f isa Real`, then `h(x) = f * g(x)`.
1010
"""
1111
*(f, g::AbstractGP) = DerivedGP((*, f, g), g.gpc)
1212
*(f::AbstractGP, g) = DerivedGP((*, g, f), f.gpc)
13-
*(f::AbstractGP, g::AbstractGP) = throw(ArgumentError("Cannot multiply two GPs together."))
13+
*(::AbstractGP, ::AbstractGP) = throw(ArgumentError("Cannot multiply two GPs together."))
1414

1515
const prod_args{Tf} = Tuple{typeof(*), Tf, <:AbstractGP}
1616

17+
@opt_out rrule(::RuleConfig{>:HasReverseMode}, ::typeof(mean), ::prod_args, ::AV)
18+
@opt_out rrule(::RuleConfig{>:HasReverseMode}, ::typeof(cov), ::prod_args, ::AV)
19+
@opt_out rrule(::RuleConfig{>:HasReverseMode}, ::typeof(var), ::prod_args, ::AV)
20+
1721
#
1822
# Scale by a function
1923
#

src/gp/atomic_gp.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ end
2121

2222
atomic(gp::Tgp, gpc::GPC) where {Tgp<:AbstractGP} = AtomicGP{Tgp}(gp, gpc)
2323

24+
@opt_out rrule(::typeof(mean), ::AtomicGP, ::AbstractVector)
25+
@opt_out rrule(::typeof(cov), ::AtomicGP, ::AbstractVector)
26+
@opt_out rrule(::typeof(var), ::AtomicGP, ::AbstractVector)
27+
2428
AbstractGPs.mean(f::AtomicGP, x::AbstractVector) = mean(f.gp, x)
2529

2630
AbstractGPs.cov(f::AtomicGP, x::AbstractVector) = cov(f.gp, x)

src/gp/derived_gp.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ struct DerivedGP{Targs} <: SthenoAbstractGP
1616
end
1717
DerivedGP(args::Targs, gpc::GPC) where {Targs} = DerivedGP{Targs}(args, gpc)
1818

19+
@opt_out rrule(::typeof(mean), ::DerivedGP, ::AbstractVector)
20+
@opt_out rrule(::typeof(cov), ::DerivedGP, ::AbstractVector)
21+
@opt_out rrule(::typeof(var), ::DerivedGP, ::AbstractVector)
22+
1923
AbstractGPs.mean(f::DerivedGP, x::AbstractVector) = mean(f.args, x)
2024

2125
AbstractGPs.cov(f::DerivedGP, x::AbstractVector) = cov(f.args, x)

test/gp/atomic_gp.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ struct ToyAbstractGP <: AbstractGP end
1111
x = collect(range(-1.0, 1.0; length=N))
1212
x′ = collect(range(-1.0, 1.0; length=N′))
1313

14-
@test mean(f, x) == AbstractGPs._map_meanfunction(m, x)
14+
@test mean(f, x) == mean_vector(m, x)
1515
@test cov(f, x) == kernelmatrix(k, x)
1616
AbstractGPs.TestUtils.test_internal_abstractgps_interface(rng, f, x, x′)
1717
end
@@ -27,8 +27,8 @@ struct ToyAbstractGP <: AbstractGP end
2727
k1, k2 = SqExponentialKernel(), SqExponentialKernel()
2828
f1, f2 = atomic(GP(m1, k1), gpc), atomic(GP(m2, k2), gpc)
2929

30-
@test mean(f1, x) == AbstractGPs._map_meanfunction(m1, x)
31-
@test mean(f2, x) == AbstractGPs._map_meanfunction(m2, x)
30+
@test mean(f1, x) == mean_vector(m1, x)
31+
@test mean(f2, x) == mean_vector(m2, x)
3232

3333
@test cov(f1, f2, x, x′) == zeros(N, N′)
3434
@test var(f1, x) == ones(N)

0 commit comments

Comments
 (0)