Skip to content

Commit 33a4cd3

Browse files
authored
improve performance with ties (#74)
* check for all equal in small chunks * patch bump * actually split on ties * tests * pre sort * clean up a little * cleanup * add comment about performance linking to PR * consistency and correctness in debug logging
1 parent 5126a74 commit 33a4cd3

File tree

3 files changed

+61
-32
lines changed

3 files changed

+61
-32
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Loess"
22
uuid = "4345ca2d-374a-55d4-8d30-97f9976e7612"
3-
version = "0.6.1"
3+
version = "0.6.2"
44

55
[deps]
66
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"

src/kd.jl

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ function KDTree(
4141
) where T <: AbstractFloat
4242

4343
n, m = size(xs)
44-
perm = collect(1:n)
4544

4645
bounds = Array{T}(undef, 2, m)
4746
for j in 1:m
@@ -63,14 +62,13 @@ function KDTree(
6362
push!(verts, T[vert...])
6463
end
6564

65+
perm = collect(1:n)
6666
root = build_kdtree(xs, perm, bounds, leaf_size_cutoff, leaf_diameter_cutoff, verts)
6767

68-
KDTree(convert(Matrix{T}, xs), collect(1:n), root, verts, bounds)
68+
KDTree(convert(Matrix{T}, xs), perm, root, verts, bounds)
6969
end
7070

7171

72-
73-
7472
"""
7573
diameter(bounds)
7674
@@ -88,6 +86,32 @@ function diameter(bounds::Matrix)
8886
euclidean(vec(bounds[1,:]), vec(bounds[2,:]))
8987
end
9088

89+
"""
90+
_select_j(xs::AbstractMatrix{T})
91+
92+
Select the column for sorting the rows xs based on the column with the largest spread.
93+
"""
94+
function _select_j(xs::AbstractMatrix{T}) where {T <: AbstractFloat}
95+
size(xs, 2) == 1 && return 1
96+
97+
# split on the dimension with the largest spread
98+
# maxspread, j = findmax(maximum(xs[perm, k]) - minimum(xs[perm, k]) for k in 1:m)
99+
j = 1
100+
maxspread = 0
101+
@inbounds for k in axes(xs, 2)
102+
xmin = Inf
103+
xmax = -Inf
104+
@inbounds for i in axes(xs, 1)
105+
xmin = min(xmin, xs[i, k])
106+
xmax = max(xmax, xs[i, k])
107+
end
108+
if xmax - xmin > maxspread
109+
maxspread = xmax - xmin
110+
j = k
111+
end
112+
end
113+
return j
114+
end
91115

92116
"""
93117
build_kdtree(xs, perm, bounds, leaf_size_cutoff, leaf_diameter_cutoff, verts)
@@ -121,30 +145,22 @@ function build_kdtree(xs::AbstractMatrix{T},
121145
Base.require_one_based_indexing(xs)
122146
Base.require_one_based_indexing(perm)
123147

148+
j = _select_j(xs)
124149
n, m = size(xs)
150+
# performance testing showed that sorting everything at once was dramatically faster
151+
# than repeated partial sorting with partialsort! when there are ties:
152+
# https://github.com/JuliaStats/Loess.jl/pull/74
153+
if !issorted(view(xs, perm, j))
154+
@debug "received unsorted data, sorting"
155+
sortperm!(perm, view(xs, :, j))
156+
end
157+
xjs = view(xs, perm, j)
125158

126159
if length(perm) <= leaf_size_cutoff || diameter(bounds) <= leaf_diameter_cutoff
127160
@debug "Creating leaf node" length(perm) leaf_size_cutoff diameter(bounds) leaf_diameter_cutoff
128161
return nothing
129162
end
130163

131-
# split on the dimension with the largest spread
132-
# maxspread, j = findmax(maximum(xs[perm, k]) - minimum(xs[perm, k]) for k in 1:m)
133-
j = 1
134-
maxspread = 0
135-
for k in 1:m
136-
xmin = Inf
137-
xmax = -Inf
138-
for i in perm
139-
xmin = min(xmin, xs[i, k])
140-
xmax = max(xmax, xs[i, k])
141-
end
142-
if xmax - xmin > maxspread
143-
maxspread = xmax - xmin
144-
j = k
145-
end
146-
end
147-
148164
# Find the "median" and partition
149165
#
150166
# The aim of the algorithm is to split the data recursively in two roughly equally sized
@@ -165,37 +181,36 @@ function build_kdtree(xs::AbstractMatrix{T},
165181
#
166182
# The details here are reversed engineered from the C/Fortran implementation wrapped
167183
# by R and also distribtued on NETLIB.
168-
mid = (length(perm) + 1) ÷ 2
169-
@debug "Candidate median index and median value" mid xs[perm[mid], j]
184+
mid = (length(xjs) + 1) ÷ 2
185+
@debug "Candidate median index and median value" mid xjs[mid]
170186

171187
offset = 0
172188
local mid1, mid2
173189
while true
174190
mid1 = mid + offset
175191
mid2 = mid1 + 1
176192
if mid1 < 1
177-
@debug "mid1 is zero. All elements are identical. Creating vertex and then two leaves" mid1 length(perm) xs[perm[mid], j]
193+
@debug "mid1 is zero. All elements are identical. Creating vertex and then two leaves" mid1 length(xjs) xjs[mid]
178194
offset = mid1 = 0
179-
mid2 = length(perm) + 1
195+
mid2 = length(xjs) + 1
180196
break
181197
end
182-
if mid2 > length(perm)
183-
@debug "mid2 is out of bounds. Continuing with negative offset" mid2 length(perm) offset
198+
if mid2 > length(xjs)
199+
@debug "mid2 is out of bounds. Continuing with negative offset" mid2 length(xjs) offset
184200
# This makes the offset 0, 1, -1, 2, -2, ...
185201
offset = -offset + (offset <= 0)
186202
continue
187203
end
188-
p12 = partialsort!(perm, mid1:mid2, by = i -> xs[i, j])
189-
if xs[p12[1], j] == xs[p12[2], j]
190-
@debug "tie! Adjusting offset" xs[p12[1], j] xs[p12[2], j] offset
204+
if xjs[mid1] == xjs[mid2]
205+
# @debug "tie! Adjusting offset" xs[p12[1], j] xs[p12[2], j] offset
191206
# This makes the offset 0, 1, -1, 2, -2, ...
192207
offset = -offset + (offset <= 0)
193208
else
194209
break
195210
end
196211
end
197212
mid += offset
198-
med = xs[perm[mid], j]
213+
med = xjs[mid]
199214
@debug "Accepted median index and median value" mid med
200215

201216
leftbounds = copy(bounds)

test/runtests.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,20 @@ end
4747
@test_broken predict(model, x)[end] pred[end] atol=1e-5
4848
end
4949

50+
@testset "lots of ties" begin
51+
# adapted from https://github.com/JuliaStats/Loess.jl/pull/74#discussion_r1294303522
52+
x = repeat([π/4*i for i in -20:20], inner=101)
53+
y = sin.(x)
54+
55+
model = loess(x,y; span=0.2)
56+
for i in -3:3
57+
@test predict(model, i * π) 0 atol=1e-12
58+
# not great tolerance but loess also struggles to capture the sine peaks
59+
@test abs(predict(model, i * π + π / 2 )) 0.9 atol=0.1
60+
end
61+
62+
end
63+
5064
@test_throws DimensionMismatch loess([1.0 2.0; 3.0 4.0], [1.0])
5165

5266
@testset "Issue 28" begin

0 commit comments

Comments
 (0)