File tree Expand file tree Collapse file tree 6 files changed +29
-55
lines changed Expand file tree Collapse file tree 6 files changed +29
-55
lines changed Original file line number Diff line number Diff line change 1
1
name = " Stheno"
2
2
uuid = " 8188c328-b5d6-583d-959b-9690869a5511"
3
- version = " 0.7.17 "
3
+ version = " 0.7.18 "
4
4
5
5
[deps ]
6
6
AbstractGPs = " 99985d1d-32ba-4be9-9821-2ec096f28918"
Original file line number Diff line number Diff line change @@ -35,7 +35,7 @@ module Stheno
35
35
# Various bits of utility that aren't inherently GP-related. Often very type-piratic.
36
36
include (joinpath (" util" , " zygote_rules.jl" ))
37
37
include (joinpath (" util" , " covariance_matrices.jl" ))
38
- include (joinpath (" util" , " dense .jl" ))
38
+ include (joinpath (" util" , " block_arrays .jl" ))
39
39
include (joinpath (" util" , " abstract_data_set.jl" ))
40
40
include (joinpath (" util" , " proper_type_piracy.jl" ))
41
41
Original file line number Diff line number Diff line change
1
+ # This file contains a number of additions to BlockArrays.jl. These are completely
2
+ # independent of Stheno.jl, and will (hopefully) move over to BlockArrays.jl at some point.
3
+
4
+ function ChainRulesCore. rrule (:: typeof (BlockArrays. mortar), _blocks:: AbstractArray )
5
+ y = BlockArrays. mortar (_blocks)
6
+ Ty = typeof (y)
7
+ function mortar_pullback (Δ:: Tangent )
8
+ return (NoTangent (), Δ. blocks)
9
+ end
10
+ function mortar_pullback (Δ:: BlockArray )
11
+ return mortar_pullback (Tangent {Ty} (; blocks = Δ. blocks, axes= NoTangent ()))
12
+ end
13
+ return y, mortar_pullback
14
+ end
15
+
16
+ # A hook to which I can attach an rrule without commiting type-piracy against BlockArrays.
17
+ _collect (X:: BlockArray ) = Array (X)
18
+
19
+ function ChainRulesCore. rrule (:: typeof (_collect), X:: BlockArray )
20
+ function Array_pullback (Δ:: Array )
21
+ ΔX = Tangent {Any} (blocks= BlockArray (Δ, axes (X)). blocks, axes= NoTangent ())
22
+ return (NoTangent (), ΔX)
23
+ end
24
+ return Array (X), Array_pullback
25
+ end
Load Diff This file was deleted.
Original file line number Diff line number Diff line change @@ -34,8 +34,7 @@ using Stheno:
34
34
diag_At_A,
35
35
diag_At_B,
36
36
diag_Xt_invA_X,
37
- diag_Xt_invA_Y,
38
- blocksizes
37
+ diag_Xt_invA_Y
39
38
40
39
using Stheno. AbstractGPs. TestUtils: test_internal_abstractgps_interface
41
40
using Stheno. AbstractGPs. Distributions: MvNormal
@@ -55,7 +54,7 @@ include("test_util.jl")
55
54
@timedtestset " util" begin
56
55
include (joinpath (" util" , " zygote_rules.jl" ))
57
56
include (joinpath (" util" , " covariance_matrices.jl" ))
58
- include (joinpath (" util" , " dense .jl" ))
57
+ include (joinpath (" util" , " block_arrays .jl" ))
59
58
include (joinpath (" util" , " abstract_data_set.jl" ))
60
59
end
61
60
File renamed without changes.
You can’t perform that action at this time.
0 commit comments