Skip to content

Commit df1bf95

Browse files
committed
Handle prediction at tree split points separately. Always use QR
factorization when solving the local system to avoid an error when the system is singular. Also fix an off-by-one error in the median calculation in the KDTree implementation.
1 parent ca17cf5 commit df1bf95

File tree

4 files changed

+46
-14
lines changed

4 files changed

+46
-14
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@ version = "0.5.2"
44

55
[deps]
66
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
7+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
78
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
89

910
[compat]
10-
julia = "0.7, 1"
1111
Distances = "0.7, 0.8, 0.9, 0.10"
12+
julia = "0.7, 1"
1213

1314
[extras]
1415
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/Loess.jl

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module Loess
22

33
import Distances.euclidean
44

5-
using Statistics
5+
using Statistics, LinearAlgebra
66

77
export loess, predict
88

@@ -44,6 +44,9 @@ function loess(xs::AbstractMatrix{T}, ys::AbstractVector{T};
4444

4545
n, m = size(xs)
4646
q = ceil(Int, (span * n))
47+
if q < degree + 1
48+
throw(ArgumentError("neighborhood size must be larger than degree+1=$(degree + 1) but was $q. Try increasing the value of span."))
49+
end
4750

4851
# TODO: We need to keep track of how we are normalizing so we can
4952
# correctly apply predict to unnormalized data. We should have a normalize
@@ -53,7 +56,6 @@ function loess(xs::AbstractMatrix{T}, ys::AbstractVector{T};
5356
end
5457

5558
kdtree = KDTree(xs, 0.05 * span)
56-
verts = Array{T}(undef, length(kdtree.verts), m)
5759

5860
# map vertices to their index in the bs coefficient matrix
5961
verts = Dict{Vector{T}, Int}()
@@ -69,6 +71,7 @@ function loess(xs::AbstractMatrix{T}, ys::AbstractVector{T};
6971
# TODO: higher degree fitting
7072
us = Array{T}(undef, q, 1 + degree * m)
7173
vs = Array{T}(undef, q)
74+
7275
for (vert, k) in verts
7376
# reset perm
7477
for i in 1:n
@@ -85,20 +88,22 @@ function loess(xs::AbstractMatrix{T}, ys::AbstractVector{T};
8588
dmax = maximum([ds[perm[i]] for i = 1:q])
8689

8790
for i in 1:q
88-
pi = perm[i]
89-
w = tricubic(ds[pi] / dmax)
91+
pᵢ = perm[i]
92+
w = tricubic(ds[pᵢ] / dmax)
9093
us[i,1] = w
9194
for j in 1:m
92-
x = xs[pi, j]
95+
x = xs[pᵢ, j]
9396
wxl = w
9497
for l in 1:degree
9598
wxl *= x
96-
us[i, 1 + (j-1)*degree + l] = wxl # w*x^l
99+
us[i, 1 + (j - 1)*degree + l] = wxl # w*x^l
97100
end
98101
end
99-
vs[i] = ys[pi] * w
102+
vs[i] = ys[pᵢ] * w
100103
end
101-
bs[k,:] = us \ vs
104+
105+
F = qr(us, Val(true))
106+
bs[k,:] = F\vs
102107
end
103108

104109
LoessModel{T}(xs, ys, bs, verts, kdtree)
@@ -149,11 +154,16 @@ function predict(model::LoessModel{T}, zs::AbstractVector{T}) where T <: Abstrac
149154
if m == 1
150155
@assert(length(adjacent_verts) == 2)
151156
z = zs[1]
152-
u = (z - adjacent_verts[1][1]) /
153-
(adjacent_verts[2][1] - adjacent_verts[1][1])
157+
v₁, v₂ = adjacent_verts[1][1], adjacent_verts[2][1]
158+
159+
if z == v₁ || z == v₂
160+
return evalpoly(zs, model.bs[model.verts[[z]],:])
161+
end
162+
163+
u = (z - v₁)/(v₂ - v₁)
154164

155-
y1 = evalpoly(zs, model.bs[model.verts[[adjacent_verts[1][1]]],:])
156-
y2 = evalpoly(zs, model.bs[model.verts[[adjacent_verts[2][1]]],:])
165+
y1 = evalpoly(zs, model.bs[model.verts[[v₁]],:])
166+
y2 = evalpoly(zs, model.bs[model.verts[[v₂]],:])
157167
return (1.0 - u) * y1 + u * y2
158168
else
159169
error("Multivariate blending not yet implemented")

src/kd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ function build_kdtree(xs::AbstractMatrix{T},
143143

144144
# find the median and partition
145145
if isodd(length(perm))
146-
mid = length(perm) ÷ 2
146+
mid = (length(perm) + 1) ÷ 2
147147
partialsort!(perm, mid, by=i -> xs[i, j])
148148
med = xs[perm[mid], j]
149149
mid1 = mid

test/runtests.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,24 @@ let x = 1:10, y = sin.(1:10)
3232
end
3333

3434
@test_throws DimensionMismatch loess([1.0 2.0; 3.0 4.0], [1.0])
35+
36+
@testset "Issue 28" begin
37+
@testset "Example 1" begin
38+
x = [1.0, 2.0, 3.0, 4.0]
39+
y = [1.0, 2.0, 3.0, 4.0]
40+
@test_throws ArgumentError("neighborhood size must be larger than degree+1=3 but was 1. Try increasing the value of span.") loess(x, y, span = 0.25)
41+
@test_throws ArgumentError("neighborhood size must be larger than degree+1=3 but was 2. Try increasing the value of span.") loess(x, y, span = 0.33)
42+
@test predict(loess(x, y), x) ≈ x
43+
end
44+
45+
@testset "Example 2" begin
46+
x = [1.0, 1.0, 2.0, 3.0, 4.0, 4.0]
47+
y = [1.0, 1.0, 2.0, 3.0, 4.0, 4.0]
48+
@test_throws ArgumentError("neighborhood size must be larger than degree+1=3 but was 2. Try increasing the value of span.") loess(x, y, span = 0.33)
49+
# For 0.4 and 0.5 these currently don't hit the middle values. I suspect
50+
# the issue is related to the ties in x.
51+
@test_broken predict(loess(x, y, span = 0.4), x) ≈ x
52+
@test_broken predict(loess(x, y, span = 0.5), x) ≈ x
53+
@test predict(loess(x, y, span = 0.6), x) ≈ x
54+
end
55+
end

0 commit comments

Comments
 (0)