Skip to content

Commit c822896

Browse files
authored
Allow AbstractGPs in WrappedGP (#217)
* Bump patch * Allow AbstractGP in WrappedGP * Test nested GPPP * Check WrappedGP with fresh AbstractGP * Remove redundant code * Remove redudant code
1 parent 576df0b commit c822896

File tree

8 files changed

+27
-540
lines changed

8 files changed

+27
-540
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Stheno"
22
uuid = "8188c328-b5d6-583d-959b-9690869a5511"
3-
version = "0.7.15"
3+
version = "0.7.16"
44

55
[deps]
66
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"

src/Stheno.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ module Stheno
3737
include(joinpath("util", "zygote_rules.jl"))
3838
include(joinpath("util", "covariance_matrices.jl"))
3939
include(joinpath("util", "block_arrays", "dense.jl"))
40-
include(joinpath("util", "block_arrays", "diagonal.jl"))
4140
include(joinpath("util", "abstract_data_set.jl"))
4241
include(joinpath("util", "proper_type_piracy.jl"))
4342

src/gp/gp.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@ struct WrappedGP{Tgp<:AbstractGP} <: SthenoAbstractGP
1212
gp::Tgp
1313
n::Int
1414
gpc::GPC
15-
function WrappedGP{Tgp}(gp::Tgp, gpc::GPC) where {Tgp<:GP}
15+
function WrappedGP{Tgp}(gp::Tgp, gpc::GPC) where {Tgp<:AbstractGP}
1616
wgp = new{Tgp}(gp, next_index(gpc), gpc)
1717
gpc.n += 1
1818
return wgp
1919
end
2020
end
2121

22-
wrap(gp::Tgp, gpc::GPC) where {Tgp<:GP} = WrappedGP{Tgp}(gp, gpc)
22+
wrap(gp::Tgp, gpc::GPC) where {Tgp<:AbstractGP} = WrappedGP{Tgp}(gp, gpc)
2323

2424
mean(f::WrappedGP, x::AbstractVector) = mean(f.gp, x)
2525

src/util/block_arrays/diagonal.jl

Lines changed: 0 additions & 231 deletions
This file was deleted.

test/gaussian_process_probabilistic_programme.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,7 @@
8282
GPPPInput(:f1, randn(4)),
8383
),
8484
]
85-
rng = MersenneTwister(123456)
86-
AbstractGPs.TestUtils.test_internal_abstractgps_interface(rng, f, x0, x1)
85+
test_internal_abstractgps_interface(MersenneTwister(123456), f, x0, x1)
8786
end
8887

8988
@timedtestset "gppp macro" begin
@@ -103,4 +102,20 @@
103102
y = rand(f(x, s))
104103
Zygote.gradient((x, y, f, s) -> logpdf(f(x, s), y), x, y, f, s)
105104
end
105+
106+
# Check that we can use one GPPP inside another.
107+
@timedtestset "nested gppp" begin
108+
109+
gpc_outer = GPC()
110+
f1_outer = Stheno.wrap(f, gpc_outer)
111+
f2_outer = 5 * f1_outer
112+
f_outer = Stheno.GPPP((f1=f1_outer, f2=f2_outer), gpc_outer)
113+
114+
x0 = GPPPInput(:f1, randn(5))
115+
x1 = GPPPInput(:f2, randn(4))
116+
x0_outer = GPPPInput(:f1, x0)
117+
x1_outer = GPPPInput(:f2, x1)
118+
rng = MersenneTwister(123456)
119+
test_internal_abstractgps_interface(rng, f_outer, x0_outer, x1_outer)
120+
end
106121
end

test/gp/gp.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
struct ToyAbstractGP <: AbstractGP end
2+
13
@timedtestset "gp" begin
24

35
# Ensure that basic functionality works as expected.
@@ -33,4 +35,8 @@
3335

3436
@test cov(f1, f1, x′, x) cov(f1, f1, x, x′)'
3537
end
38+
39+
@timedtestset "wrapped AbstractGP" begin
40+
wrap(ToyAbstractGP(), GPC())
41+
end
3642
end

test/runtests.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ using Stheno:
2424
GPC,
2525
AV,
2626
FiniteGP,
27-
block_diagonal,
2827
AbstractGP,
2928
BlockData,
3029
blocks,
@@ -36,10 +35,9 @@ using Stheno:
3635
diag_At_B,
3736
diag_Xt_invA_X,
3837
diag_Xt_invA_Y,
39-
block_diagonal,
40-
BlockDiagonal,
4138
blocksizes
4239

40+
using Stheno.AbstractGPs.TestUtils: test_internal_abstractgps_interface
4341
using Stheno.AbstractGPs.Distributions: MvNormal
4442
using FiniteDifferences: j′vp
4543

@@ -60,7 +58,6 @@ include("test_util.jl")
6058
@testset "block_arrays" begin
6159
include(joinpath("util", "block_arrays", "test_util.jl"))
6260
include(joinpath("util", "block_arrays", "dense.jl"))
63-
include(joinpath("util", "block_arrays", "diagonal.jl"))
6461
end
6562
include(joinpath("util", "abstract_data_set.jl"))
6663
end

0 commit comments

Comments
 (0)