Skip to content

Commit ae3ca4f

Browse files
authored
Merge pull request #12 from JuliaAI/dev
For a 0.2.2 release
2 parents fce2fef + 579a12f commit ae3ca4f

14 files changed

+299
-107
lines changed

Project.toml

+16-2
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
name = "LearnTestAPI"
22
uuid = "3111ed91-c4f2-40e7-bb19-7f6c618409b8"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.2.1"
4+
version = "0.2.2"
55

66
[deps]
7+
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
8+
CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e"
79
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
810
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
911
IsURL = "ceb4388c-583f-448d-bb30-00b11e8c5682"
1012
LearnAPI = "92ad9a40-7767-427a-9ee6-6e577f1266cb"
13+
LearnDataFrontEnds = "5cca22a3-9356-470e-ba1b-8268d0135a4b"
1114
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1215
MLCore = "c2834f40-e789-41da-a90e-33b280584a8c"
1316
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
@@ -22,6 +25,8 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2225
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2326

2427
[compat]
28+
CategoricalArrays = "0.10.8"
29+
CategoricalDistributions = "0.1.15"
2530
Distributions = "0.25"
2631
InteractiveUtils = "<0.0.1, 1"
2732
IsURL = "0.2.0"
@@ -46,7 +51,16 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
4651
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
4752
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
4853
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
54+
StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
4955
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
5056

5157
[targets]
52-
test = ["DataFrames", "Distributions", "Random", "LinearAlgebra", "Statistics", "Tables"]
58+
test = [
59+
"DataFrames",
60+
"Distributions",
61+
"Random",
62+
"LinearAlgebra",
63+
"Statistics",
64+
"StatsModels",
65+
"Tables",
66+
]

src/LearnTestAPI.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
LearnTestAPI
33
4-
Module for testing implementations of the interfacde defined in
4+
Module for testing implementations of the interface defined in
55
[LearnAPI.jl](https://juliaai.github.io/LearnAPI.jl/dev/).
66
77
If your package defines an object `learner` implementing the interface, then put something
@@ -46,12 +46,14 @@ using LinearAlgebra
4646
using Random
4747
using Statistics
4848
using UnPack
49+
import LearnDataFrontEnds
4950

5051
include("tools.jl")
5152
include("logging.jl")
5253
include("testapi.jl")
5354
include("learners/static_algorithms.jl")
5455
include("learners/regression.jl")
56+
include("learners/classification.jl")
5557
include("learners/ensembling.jl")
5658
# next learner excluded because of heavy dependencies:
5759
# include("learners/gradient_descent.jl")

src/learners/classification.jl

+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# This file defines `ConstantClassifier()`
2+
3+
using LearnAPI
4+
import LearnDataFrontEnds as FrontEnds
5+
import MLCore
6+
import CategoricalArrays
7+
import CategoricalDistributions
8+
import CategoricalDistributions.OrderedCollections.OrderedDict
9+
import CategoricalDistributions.Distributions.StatsBase.proportionmap
10+
11+
# The implementation of a constant classifier below is not the simplest, but it
12+
# demonstrates some patterns that apply more generally in classification, including
13+
# inclusion of the canned data front end, `Sage`.
14+
"""
    ConstantClassifier()

Instantiate a constant (dummy) classifier. Can predict `Point` or `Distribution` targets.

"""
struct ConstantClassifier end
# Output of `fit`: everything learned in training, plus what is needed to decode
# integer class codes back into categorical values at prediction time.
struct ConstantClassifierFitted
    learner::ConstantClassifier
    probabilities          # training class proportions, ordered by class code
    names::Vector{Symbol}  # feature names, as reported by the data front end
    classes_seen           # target classes present in the training data
    codes_seen             # sorted integer codes present in the training target
    decoder                # maps an integer code back to a categorical value
end
LearnAPI.learner(model::ConstantClassifierFitted) = model.learner

# Attach the canned `Sage` data front end, so that `obs` returns objects of type
# `FrontEnds.Obs`:
const front_end = FrontEnds.Sage(code_type=:small)
LearnAPI.obs(learner::ConstantClassifier, data) =
    FrontEnds.fitobs(learner, data, front_end)
LearnAPI.obs(model::ConstantClassifierFitted, data) =
    obs(model, data, front_end)
# Training-data deconstructors, delegated to the front end:
LearnAPI.features(learner::ConstantClassifier, data) =
    LearnAPI.features(learner, data, front_end)
LearnAPI.target(learner::ConstantClassifier, data) =
    LearnAPI.target(learner, data, front_end)
# Training records the empirical class distribution of the target, together with the
# bookkeeping needed to decode predictions later on.
function LearnAPI.fit(learner::ConstantClassifier, observations::FrontEnds.Obs; verbosity=1)
    codes = observations.target # the target, as integer class "codes"

    # empirical proportions keyed by code, sorted by code so that the ordering agrees
    # with `sort(unique(codes))`:
    proportion_given_code = sort!(OrderedDict(proportionmap(codes)))
    probabilities = collect(values(proportion_given_code))

    return ConstantClassifierFitted(
        learner,
        probabilities,
        observations.names,
        observations.classes_seen,
        sort(unique(codes)),
        observations.decoder,
    )
end
LearnAPI.fit(learner::ConstantClassifier, data; kwargs...) =
    fit(learner, obs(learner, data); kwargs...)
# `Point` predictions: the modal training class, repeated once per observation.
function LearnAPI.predict(
    model::ConstantClassifierFitted,
    ::Point,
    observations::FrontEnds.Obs,
)
    nobs = MLCore.numobs(observations)
    # code of the most frequent class seen in training:
    modal_code = model.codes_seen[argmax(model.probabilities)]
    return model.decoder.(fill(modal_code, nobs))
end
LearnAPI.predict(model::ConstantClassifierFitted, ::Point, data) =
    predict(model, Point(), obs(model, data))
# `Distribution` predictions: the same empirical distribution for every observation.
function LearnAPI.predict(
    model::ConstantClassifierFitted,
    ::Distribution,
    observations::FrontEnds.Obs,
)
    nobs = MLCore.numobs(observations)
    probs = model.probabilities
    # stack `nobs` copies of `probs` to get one matrix row per observation:
    probs_matrix = reshape(repeat(probs, nobs), (length(probs), nobs))'
    return CategoricalDistributions.UnivariateFinite(model.classes_seen, probs_matrix)
end
LearnAPI.predict(model::ConstantClassifierFitted, ::Distribution, data) =
    predict(model, Distribution(), obs(model, data))
# accessor function:
LearnAPI.feature_names(model::ConstantClassifierFitted) = model.names

# declare the parts of the LearnAPI.jl contract implemented by `ConstantClassifier`:
@trait(
    ConstantClassifier,
    constructor = ConstantClassifier,
    kinds_of_proxy = (Point(), Distribution()),
    tags = ("classification",),
    functions = (
        :(LearnAPI.fit),
        :(LearnAPI.learner),
        :(LearnAPI.clone),
        :(LearnAPI.strip),
        :(LearnAPI.obs),
        :(LearnAPI.features),
        :(LearnAPI.target),
        :(LearnAPI.predict),
        :(LearnAPI.feature_names),
    )
)

true

src/learners/dimension_reduction.jl

+29-7
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
# This file defines `TruncatedSVD(; codim=1)`
22

33
using LearnAPI
4-
using LinearAlgebra
4+
using LinearAlgebra
5+
import LearnDataFrontEnds as FrontEnds
56

67

78
# # DIMENSION REDUCTION USING TRUNCATED SVD DECOMPOSITION
89

910
# Recall that truncated SVD reduction is the same as PCA reduction, but without
10-
# centering. We suppose observations are presented as the columns of a `Real` matrix.
11+
# centering.
1112

1213
# Some struct fields are left abstract for simplicity.
1314

@@ -23,6 +24,11 @@ end
2324
Instantiate a truncated singular value decomposition algorithm for reducing the dimension
2425
of observations by `codim`.
2526
27+
Data can be provided to `fit` or `transform` in any form supported by the `Tarragon` data
28+
front end at LearnDataFrontEnds.jl. However, the outputs of `transform` and
29+
`inverse_transform` are always matrices.
30+
31+
2632
```julia
2733
learner = TruncatedSVD()
2834
X = rand(3, 100) # 100 observations in 3-space
@@ -49,10 +55,21 @@ end
4955

5056
LearnAPI.learner(model::TruncatedSVDFitted) = model.learner
5157

52-
function LearnAPI.fit(learner::TruncatedSVD, X; verbosity=1)
58+
# add a canned data front end; `obs` will return objects of type `FrontEnds.Obs`:
59+
LearnAPI.obs(learner::TruncatedSVD, data) =
60+
FrontEnds.fitobs(learner, data, FrontEnds.Tarragon())
61+
LearnAPI.obs(model::TruncatedSVDFitted, data) =
62+
obs(model, data, FrontEnds.Tarragon())
63+
64+
# training data deconstructor:
65+
LearnAPI.features(learner::TruncatedSVD, data) =
66+
LearnAPI.features(learner, data, FrontEnds.Tarragon())
67+
68+
function LearnAPI.fit(learner::TruncatedSVD, observations::FrontEnds.Obs; verbosity=1)
5369

5470
# unpack hyperparameters:
5571
codim = learner.codim
72+
X = observations.features
5673
p, n = size(X)
5774
n ≥ p || error("Insufficient number observations. ")
5875
outdim = p - codim
@@ -70,14 +87,19 @@ function LearnAPI.fit(learner::TruncatedSVD, X; verbosity=1)
7087
return TruncatedSVDFitted(learner, U, Ut, singular_values)
7188

7289
end
90+
LearnAPI.fit(learner::TruncatedSVD, data; kwargs...) =
91+
LearnAPI.fit(learner, LearnAPI.obs(learner, data); kwargs...)
7392

74-
LearnAPI.transform(model::TruncatedSVDFitted, X) = model.Ut*X
93+
LearnAPI.transform(model::TruncatedSVDFitted, observations::FrontEnds.Obs) =
94+
model.Ut*(observations.features)
95+
LearnAPI.transform(model::TruncatedSVDFitted, data) =
96+
LearnAPI.transform(model, obs(model, data))
7597

7698
# convenience fit-transform:
77-
LearnAPI.transform(learner::TruncatedSVD, X; kwargs...) =
78-
transform(fit(learner, X; kwargs...), X)
99+
LearnAPI.transform(learner::TruncatedSVD, data; kwargs...) =
100+
transform(fit(learner, data; kwargs...), data)
79101

80-
LearnAPI.inverse_transform(model::TruncatedSVDFitted, W) = model.U*W
102+
LearnAPI.inverse_transform(model::TruncatedSVDFitted, W::AbstractMatrix) = model.U*W
81103

82104
# accessor function:
83105
function LearnAPI.extras(model::TruncatedSVDFitted)

src/learners/ensembling.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ LearnAPI.components(model::EnsembleFitted) = [:atom => model.models,]
211211
# - `out_of_sample_losses`
212212

213213
# For simplicity, this implementation is restricted to univariate features. The simplistic
214-
# algorithm is explained in the docstring. of the data presented.
214+
# algorithm is explained in the docstring.
215215

216216

217217
# ## HELPERS
@@ -276,6 +276,7 @@ function update!(
276276
stump = Stump(ξ, left, right)
277277
push!(forest, stump)
278278
new_predictions = _predict(stump, x)
279+
279280
# efficient in-place update of `predictions`:
280281
predictions .= (k*predictions .+ new_predictions)/(k + 1)
281282
push!(training_losses, (predictions[training_indices] .- ytrain).^2 |> sum)

0 commit comments

Comments
 (0)