
Commit add0103

Merge pull request #68 from JuliaStats/an/optimize
Various optimizations
2 parents ec042b9 + 2569c9b commit add0103

File tree: 2 files changed, +63 −62 lines

src/Loess.jl

Lines changed: 38 additions & 36 deletions
@@ -10,9 +10,9 @@ export loess, predict
 include("kd.jl")


-mutable struct LoessModel{T <: AbstractFloat}
-    xs::AbstractMatrix{T} # An n by m predictor matrix containing n observations from m predictors
-    ys::AbstractVector{T} # A length n response vector
+struct LoessModel{T <: AbstractFloat}
+    xs::Matrix{T} # An n by m predictor matrix containing n observations from m predictors
+    ys::Vector{T} # A length n response vector
     predictions_and_gradients::Dict{Vector{T}, Vector{T}} # kd-tree vertexes mapped to prediction and gradient at each vertex
     kdtree::KDTree{T}
 end
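The hunk above makes `LoessModel` immutable and gives it concrete `Matrix{T}`/`Vector{T}` fields. As a side note, here is a minimal standalone sketch (hypothetical struct and function names, not part of the commit) of why concrete field types matter: with an abstractly typed field, Julia cannot infer the container type at access sites.

# Hypothetical illustration, not from this commit.
struct AbstractField
    xs::AbstractMatrix{Float64}   # abstract field type: accesses are not fully inferred
end

struct ConcreteField
    xs::Matrix{Float64}           # concrete field type: accesses infer to Matrix{Float64}
end

colsum(m) = sum(view(m.xs, :, 1))

# To inspect the difference in inference:
#   @code_warntype colsum(AbstractField(rand(4, 2)))   # field access is only AbstractMatrix{Float64}
#   @code_warntype colsum(ConcreteField(rand(4, 2)))   # fully concrete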
@@ -44,6 +44,10 @@ function loess(
     degree::Integer = 2,
     cell::AbstractFloat = 0.2
 ) where T<:AbstractFloat
+
+    Base.require_one_based_indexing(xs)
+    Base.require_one_based_indexing(ys)
+
     if size(xs, 1) != size(ys, 1)
         throw(DimensionMismatch("Predictor and response arrays must of the same length"))
     end
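The added guards reject offset-indexed inputs before any 1-based loop runs over them. A small illustration of what they catch (OffsetArrays is not a dependency of this package; the data are made up):

using OffsetArrays

xs = OffsetArray(rand(10), -4:5)       # indices run -4:5 instead of 1:10

# Code that assumes xs[1] is the first element would silently read the wrong
# entry here; the guard turns that into an immediate, clear error instead.
Base.require_one_based_indexing(xs)    # throws ArgumentError for this array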
@@ -80,8 +84,12 @@ function loess(
     end

     # distance to each point
-    for i in 1:n
-        ds[i] = euclidean(vec(vert), vec(xs[i,:]))
+    @inbounds for i in 1:n
+        s = zero(T)
+        for j in 1:m
+            s += (xs[i, j] - vert[j])^2
+        end
+        ds[i] = sqrt(s)
     end

     # find the q closest points
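The rewritten loop computes each Euclidean distance in place instead of materialising `vec(vert)` and the row copy `xs[i,:]` on every iteration. A standalone sketch of the same pattern (hypothetical function names and made-up sizes; the allocating version mirrors the removed code):

using Distances   # provides `euclidean`, used only by the allocating baseline

# Mirrors the removed code: every iteration allocates a fresh row copy.
dists_alloc(xs, vert) = [euclidean(vert, xs[i, :]) for i in 1:size(xs, 1)]

# Mirrors the added code: one pass, no temporaries.
function dists_loop(xs::Matrix{T}, vert::Vector{T}) where {T<:AbstractFloat}
    n, m = size(xs)
    ds = Vector{T}(undef, n)
    @inbounds for i in 1:n
        s = zero(T)
        for j in 1:m
            s += (xs[i, j] - vert[j])^2
        end
        ds[i] = sqrt(s)
    end
    return ds
end

# xs = rand(1_000, 2); v = rand(2)
# dists_alloc(xs, v) ≈ dists_loop(xs, v)    # same result, far fewer allocations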
@@ -128,7 +136,7 @@ function loess(
         ]
     end

-    LoessModel{T}(xs, ys, predictions_and_gradients, kdtree)
+    LoessModel(xs, ys, predictions_and_gradients, kdtree)
 end

 loess(xs::AbstractVector{T}, ys::AbstractVector{T}; kwargs...) where {T<:AbstractFloat} =
@@ -153,50 +161,44 @@ end
 # Returns:
 #   A length n' vector of predicted response values.
 #
-function predict(model::LoessModel, z::Real)
-    predict(model, [z])
-end
-
-function predict(model::LoessModel, zs::AbstractVector)
-
-    Base.require_one_based_indexing(zs)
+function predict(model::LoessModel{T}, z::Number) where T
+    adjacent_verts = traverse(model.kdtree, (T(z),))

-    m = size(model.xs, 2)
+    @assert(length(adjacent_verts) == 2)
+    v₁, v₂ = adjacent_verts[1][1], adjacent_verts[2][1]

-    # in the univariate case, interpret a non-singleton zs as vector of
-    # ponits, not one point
-    if m == 1 && length(zs) > 1
-        return predict(model, reshape(zs, (length(zs), 1)))
+    if z == v₁ || z == v₂
+        return first(model.predictions_and_gradients[[z]])
     end

-    if length(zs) != m
-        error("$(m)-dimensional model applied to length $(length(zs)) vector")
-    end
+    y₁, dy₁ = model.predictions_and_gradients[[v₁]]
+    y₂, dy₂ = model.predictions_and_gradients[[v₂]]

-    adjacent_verts = traverse(model.kdtree, zs)
+    b_int = cubic_interpolation(v₁, y₁, dy₁, v₂, y₂, dy₂)

-    if m == 1
-        @assert(length(adjacent_verts) == 2)
-        z = zs[1]
-        v₁, v₂ = adjacent_verts[1][1], adjacent_verts[2][1]
+    return evalpoly(z, b_int)
+end

-        if z == v₁ || z == v₂
-            return first(model.predictions_and_gradients[[z]])
-        end
+function predict(model::LoessModel, zs::AbstractVector)
+    if size(model.xs, 2) > 1
+        throw(ArgumentError("multivariate blending not yet implemented"))
+    end

-        y₁, dy₁ = model.predictions_and_gradients[[v₁]]
-        y₂, dy₂ = model.predictions_and_gradients[[v₂]]
+    return [predict(model, z) for z in zs]
+end

-        b_int = cubic_interpolation(v₁, y₁, dy₁, v₂, y₂, dy₂)
+function predict(model::LoessModel, zs::AbstractMatrix)
+    if size(model.xs, 2) != size(zs, 2)
+        throw(DimensionMismatch("number of columns in input matrix must match the number of columns in the model matrix"))
+    end

-        return evalpoly(z, b_int)
+    if size(zs, 2) == 1
+        return predict(model, vec(zs))
     else
-        error("Multivariate blending not yet implemented")
+        return [predict(model, row) for row in eachrow(zs)]
     end
 end

-predict(model::LoessModel, zs::AbstractMatrix) = map(Base.Fix1(predict, model), eachrow(zs))
-
 """
     tricubic(u)

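After this restructuring, `predict` has three methods: a scalar method that performs the kd-tree lookup and cubic interpolation, a vector method that maps the scalar one over a univariate model, and a matrix method that checks column counts and falls back to row-wise prediction. A short usage sketch (made-up data, univariate model; query points must lie inside the fitted range):

using Loess

xs = sort(10 .* rand(200))
ys = sin.(xs) .+ 0.2 .* randn(200)

model = loess(xs, ys; span = 0.4)

predict(model, 5.0)                      # scalar method: one interpolated value
predict(model, [1.0, 2.5, 7.5])          # vector method: one prediction per element
predict(model, reshape(xs, :, 1))        # one-column matrix: forwarded to the vector method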
src/kd.jl

Lines changed: 25 additions & 26 deletions
@@ -1,22 +1,17 @@
 # Simple static kd-trees.

-abstract type KDNode end
-
-struct KDLeafNode <: KDNode
-end
-
-struct KDInternalNode{T <: AbstractFloat} <: KDNode
+struct KDNode{T <: AbstractFloat}
     j::Int # dimension on which the data is split
     med::T # median value where the split occours
-    leftnode::KDNode
-    rightnode::KDNode
+    leftnode::Union{Nothing, KDNode{T}}
+    rightnode::Union{Nothing, KDNode{T}}
 end


 struct KDTree{T <: AbstractFloat}
-    xs::AbstractMatrix{T} # A matrix of n, m-dimensional observations
+    xs::Matrix{T} # A matrix of n, m-dimensional observations
     perm::Vector{Int} # permutation of data to avoid modifying xs
-    root::KDNode # root node
+    root::KDNode{T} # root node
     verts::Set{Vector{T}}
     bounds::Matrix{T} # Top-level bounding box
 end
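Collapsing the `KDLeafNode`/`KDInternalNode` hierarchy into a single concrete `KDNode{T}` with `Union{Nothing, KDNode{T}}` children lets the compiler union-split child accesses instead of dispatching on an abstract node type. A generic sketch of the pattern (toy `Node`/`depth` names, not from this file):

struct Node{T}
    value::T
    left::Union{Nothing, Node{T}}
    right::Union{Nothing, Node{T}}
end

# Dispatch on Nothing vs Node, the same shape as `_traverse!` later in this diff.
depth(::Nothing) = 0
depth(n::Node)   = 1 + max(depth(n.left), depth(n.right))

tree = Node(1.0, Node(2.0, nothing, nothing), nothing)
depth(tree)   # 2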
@@ -114,7 +109,7 @@ Modifies:
   `perm`, `verts`

 Returns:
-  Either a `KDLeafNode` or a `KDInternalNode`
+  Either a `nothing` or a `KDNode`
 """
 function build_kdtree(xs::AbstractMatrix{T},
                       perm::AbstractVector,
@@ -130,7 +125,7 @@ function build_kdtree(xs::AbstractMatrix{T},

     if length(perm) <= leaf_size_cutoff || diameter(bounds) <= leaf_diameter_cutoff
         @debug "Creating leaf node" length(perm) leaf_size_cutoff diameter(bounds) leaf_diameter_cutoff
-        return KDLeafNode()
+        return nothing
     end

     # split on the dimension with the largest spread
@@ -226,7 +221,7 @@ function build_kdtree(xs::AbstractMatrix{T},
         push!(verts, T[vert...])
     end

-    KDInternalNode{T}(j, med, leftnode, rightnode)
+    KDNode(j, med, leftnode, rightnode)
 end


@@ -246,14 +241,15 @@ end
 Traverse the tree `kdtree` to the bottom and return the verticies of
 the bounding hypercube of the leaf node containing the point `x`.
 """
-function traverse(kdtree::KDTree, x::AbstractVector)
+function traverse(kdtree::KDTree{T}, x::NTuple{N,T}) where {N,T}
+
     m = size(kdtree.bounds, 2)

-    if length(x) != m
+    if N != m
         throw(DimensionMismatch("$(m)-dimensional kd-tree searched with a length $(length(x)) vector."))
     end

-    for j in 1:m
+    for j in 1:N
         if x[j] < kdtree.bounds[1, j] || x[j] > kdtree.bounds[2, j]
             error(
                 """
@@ -266,15 +262,18 @@ function traverse(kdtree::KDTree, x::AbstractVector)

     bounds = copy(kdtree.bounds)
     node = kdtree.root
-    while !isa(node, KDLeafNode)
-        if x[node.j] <= node.med
-            bounds[2, node.j] = node.med
-            node = node.leftnode
-        else
-            bounds[1, node.j] = node.med
-            node = node.rightnode
-        end
-    end

-    bounds_verts(bounds)
+    return _traverse!(bounds, node, x)
+end
+
+_traverse!(bounds, node::Nothing, x) = bounds
+function _traverse!(bounds, node::KDNode, x)
+    if x[node.j] <= node.med
+        bounds[2, node.j] = node.med
+        return _traverse!(bounds, node.leftnode, x)
+    else
+        bounds[1, node.j] = node.med
+        return _traverse!(bounds, node.rightnode, x)
+    end
 end
+

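To see the effect of these changes end to end, one might time a fit and a batch of predictions before and after this commit. This is only a sketch (BenchmarkTools is not a dependency of the package; the data are made up):

using Loess, BenchmarkTools

xs = sort(rand(5_000))
ys = cos.(6 .* xs) .+ 0.1 .* randn(5_000)

@btime loess($xs, $ys; span = 0.3)       # fit: dominated by the kd-tree build and local fits
model = loess(xs, ys; span = 0.3)
@btime predict($model, $xs)              # batch prediction via the vector method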