Skip to content

Commit b1b213f

Browse files
authored
Additive GP affine transformation (#229)
* Add additive_gp transformation * Bump patch
1 parent 485d8d6 commit b1b213f

File tree

6 files changed

+77
-2
lines changed

6 files changed

+77
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Stheno"
22
uuid = "8188c328-b5d6-583d-959b-9690869a5511"
3-
version = "0.8.0"
3+
version = "0.8.1"
44

55
[deps]
66
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,5 @@ stretch
2525
periodic
2626
shift
2727
select
28+
additive_gp
2829
```

src/Stheno.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,13 @@ module Stheno
3636
include(joinpath("affine_transformations", "addition.jl"))
3737
include(joinpath("affine_transformations", "compose.jl"))
3838
include(joinpath("affine_transformations", "product.jl"))
39+
include(joinpath("affine_transformations", "additive_gp.jl"))
3940

4041
# AbstractGP subtype which groups together other AbstractGP subtypes.
4142
include("gaussian_process_probabilistic_programme.jl")
4243

4344
export atomic, BlockData, GPC, GPPPInput, @gppp
44-
export , select, stretch, periodic, shift
45+
export , select, stretch, periodic, shift, additive_gp
4546
export SparseFiniteGP
4647

4748
end # module
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""
2+
additive_gp(fs)
3+
4+
Produces the GP given by
5+
```julia
6+
sum(fs[1](x[1]) + fs[2](x[2]) + ... + fs[D](x[D]))
7+
```
8+
Requires that `length(fs)` is the same as the dimension of the inputs to be used.
9+
"""
10+
additive_gp(fs) = additive_gp(fs, 1:length(fs))
11+
12+
"""
13+
additive_gp(fs, indices)
14+
15+
`fs` should be a collection of GPs, and `indices` a collection of collections of integer
16+
indices.
17+
For example, `indices` might be something like `[1:2, 3, 4:6]`, in which case `fs` would
18+
need to comprise exactly three elements. In general, this functions requires that
19+
`length(fs) == length(indices)`.
20+
21+
Produces the GP given by
22+
```julia
23+
sum(fs[1](x[indices[1]]) + fs[2](x[indices[2]]) + ... + fs[D](x[indices[D]]))
24+
```
25+
"""
26+
function additive_gp(fs, indices)
27+
fs_projected = map((f, idx) -> f Select(idx), fs, indices)
28+
return sum(fs_projected)
29+
end
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
@testset "additive_gp" begin
2+
@testset "arbitrary indices" begin
3+
f = @gppp let
4+
f1 = GP(SEKernel())
5+
f2 = GP(SEKernel())
6+
f3 = additive_gp((f1, f2), [2:3, 1])
7+
end
8+
x_raw = ColVecs(randn(3, 7))
9+
x1_raw = ColVecs(x_raw.X[2:3, :])
10+
x2_raw = x_raw.X[1, :]
11+
x = vcat(
12+
GPPPInput(:f1, x1_raw),
13+
GPPPInput(:f2, x2_raw),
14+
GPPPInput(:f3, x_raw),
15+
)
16+
y = rand(f(x, 1e-9))
17+
18+
@test y[1:7] + y[8:14] y[15:21] rtol=1e-3
19+
20+
z = GPPPInput(:f3, ColVecs(3 * randn(3, 4)))
21+
test_internal_abstractgps_interface(MersenneTwister(123456), f, x, z; jitter=1e-15)
22+
end
23+
@testset "regular indices" begin
24+
f = @gppp let
25+
f1 = GP(SEKernel())
26+
f2 = GP(Matern52Kernel())
27+
f3 = additive_gp((f1, f2))
28+
end
29+
x_raw = ColVecs(randn(2, 7))
30+
x1_raw = x_raw.X[1, :]
31+
x2_raw = x_raw.X[2, :]
32+
x = vcat(
33+
GPPPInput(:f1, x1_raw),
34+
GPPPInput(:f2, x2_raw),
35+
GPPPInput(:f3, x_raw),
36+
)
37+
y = rand(f(x, 1e-9))
38+
@test y[1:7] + y[8:14] y[15:21] rtol=1e-3
39+
40+
z = GPPPInput(:f3, ColVecs(3 * randn(2, 4)))
41+
test_internal_abstractgps_interface(MersenneTwister(123456), f, x, z; jitter=1e-15)
42+
end
43+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ include("test_util.jl")
6161
include(joinpath("affine_transformations", "addition.jl"))
6262
include(joinpath("affine_transformations", "compose.jl"))
6363
include(joinpath("affine_transformations", "product.jl"))
64+
include(joinpath("affine_transformations", "additive_gp.jl"))
6465
end
6566

6667
include("gaussian_process_probabilistic_programme.jl")

0 commit comments

Comments
 (0)