Skip to content

Commit c507b44

Browse files
committed
Handle prediction at tree split points separately. Always use QR
factorization when solving the local system to avoid an error when the system is singular. Also fix an off-by-one error in the median calculation in the KDTree implementation.
1 parent cb38f5e commit c507b44

File tree

4 files changed

+45
-14
lines changed

4 files changed

+45
-14
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@ version = "0.5.2"
44

55
[deps]
66
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
7+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
78
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
89

910
[compat]
10-
julia = "0.7, 1"
1111
Distances = "0.7, 0.8, 0.9, 0.10"
12+
julia = "0.7, 1"
1213

1314
[extras]
1415
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/Loess.jl

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module Loess
22

33
import Distances.euclidean
44

5-
using Statistics
5+
using Statistics, LinearAlgebra
66

77
export loess, predict
88

@@ -44,6 +44,9 @@ function loess(xs::AbstractMatrix{T}, ys::AbstractVector{T};
4444

4545
n, m = size(xs)
4646
q = ceil(Int, (span * n))
47+
if q < degree + 1
48+
throw(ArgumentError("neighborhood size must be larger than degree+1=$(degree + 1) but was $q. Try increasing the value of span."))
49+
end
4750

4851
# TODO: We need to keep track of how we are normalizing so we can
4952
# correctly apply predict to unnormalized data. We should have a normalize
@@ -53,7 +56,6 @@ function loess(xs::AbstractMatrix{T}, ys::AbstractVector{T};
5356
end
5457

5558
kdtree = KDTree(xs, 0.05 * span)
56-
verts = Array{T}(undef, length(kdtree.verts), m)
5759

5860
# map verticies to their index in the bs coefficient matrix
5961
verts = Dict{Vector{T}, Int}()
@@ -69,6 +71,8 @@ function loess(xs::AbstractMatrix{T}, ys::AbstractVector{T};
6971
# TODO: higher degree fitting
7072
us = Array{T}(undef, q, 1 + degree * m)
7173
vs = Array{T}(undef, q)
74+
75+
has_warned = false
7276
for (vert, k) in verts
7377
# reset perm
7478
for i in 1:n
@@ -85,20 +89,22 @@ function loess(xs::AbstractMatrix{T}, ys::AbstractVector{T};
8589
dmax = maximum([ds[perm[i]] for i = 1:q])
8690

8791
for i in 1:q
88-
pi = perm[i]
89-
w = tricubic(ds[pi] / dmax)
92+
pᵢ = perm[i]
93+
w = tricubic(ds[pᵢ] / dmax)
9094
us[i,1] = w
9195
for j in 1:m
92-
x = xs[pi, j]
96+
x = xs[pᵢ, j]
9397
wxl = w
9498
for l in 1:degree
9599
wxl *= x
96-
us[i, 1 + (j-1)*degree + l] = wxl # w*x^l
100+
us[i, 1 + (j - 1)*degree + l] = wxl # w*x^l
97101
end
98102
end
99-
vs[i] = ys[pi] * w
103+
vs[i] = ys[pᵢ] * w
100104
end
101-
bs[k,:] = us \ vs
105+
106+
F = qr(us, Val(true))
107+
bs[k,:] = F\vs
102108
end
103109

104110
LoessModel{T}(xs, ys, bs, verts, kdtree)
@@ -149,11 +155,16 @@ function predict(model::LoessModel{T}, zs::AbstractVector{T}) where T <: Abstrac
149155
if m == 1
150156
@assert(length(adjacent_verts) == 2)
151157
z = zs[1]
152-
u = (z - adjacent_verts[1][1]) /
153-
(adjacent_verts[2][1] - adjacent_verts[1][1])
158+
v₁, v₂ = adjacent_verts[1][1], adjacent_verts[2][1]
159+
160+
if z == v₁ || z == v₂
161+
return evalpoly(zs, model.bs[model.verts[[z]],:])
162+
end
163+
164+
u = (z - v₁)/(v₂ - v₁)
154165

155-
y1 = evalpoly(zs, model.bs[model.verts[[adjacent_verts[1][1]]],:])
156-
y2 = evalpoly(zs, model.bs[model.verts[[adjacent_verts[2][1]]],:])
166+
y1 = evalpoly(zs, model.bs[model.verts[[v₁]],:])
167+
y2 = evalpoly(zs, model.bs[model.verts[[v₂]],:])
157168
return (1.0 - u) * y1 + u * y2
158169
else
159170
error("Multivariate blending not yet implemented")

src/kd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ function build_kdtree(xs::AbstractMatrix{T},
143143

144144
# find the median and partition
145145
if isodd(length(perm))
146-
mid = length(perm) ÷ 2
146+
mid = (length(perm) + 1) ÷ 2
147147
partialsort!(perm, mid, by=i -> xs[i, j])
148148
med = xs[perm[mid], j]
149149
mid1 = mid

test/runtests.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,22 @@ let x = 1:10, y = sin.(1:10)
3232
end
3333

3434
@test_throws DimensionMismatch loess([1.0 2.0; 3.0 4.0], [1.0])
35+
36+
@testset "Issue 28" begin
37+
@testset "Example 1" begin
38+
x = [1.0, 2.0, 3.0, 4.0]
39+
y = [1.0, 2.0, 3.0, 4.0]
40+
@test_throws ArgumentError("neighborhood size must be larger than degree+1=3 but was 1. Try increasing the value of span.") loess(x, y, span = 0.25)
41+
@test_throws ArgumentError("neighborhood size must be larger than degree+1=3 but was 2. Try increasing the value of span.") loess(x, y, span = 0.33)
42+
@test predict(loess(x, y), x) ≈ x
43+
end
44+
45+
@testset "Example 2" begin
46+
x = [1.0, 1.0, 2.0, 3.0, 4.0, 4.0]
47+
y = [1.0, 1.0, 2.0, 3.0, 4.0, 4.0]
48+
@test_throws ArgumentError("neighborhood size must be larger than degree+1=3 but was 2. Try increasing the value of span.") loess(x, y, span = 0.33)
49+
@test_broken predict(loess(x, y, span = 0.4), x) ≈ x
50+
@test_broken predict(loess(x, y, span = 0.5), x) ≈ x
51+
@test predict(loess(x, y, span = 0.6), x) ≈ x
52+
end
53+
end

0 commit comments

Comments
 (0)