Skip to content

Commit e6a02cc

Browse files
authored
Merge pull request #21 from JuliaAI/measure
Address breaking changes in MLJBase 1.0
2 parents bed1dfc + 76ad34f commit e6a02cc

File tree

6 files changed

+59
-9
lines changed

6 files changed

+59
-9
lines changed

Project.toml

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJFlow"
22
uuid = "7b7b8358-b45c-48ea-a8ef-7ca328ad328f"
33
authors = ["Jose Esparza <[email protected]>"]
4-
version = "0.1.1"
4+
version = "0.2.0"
55

66
[deps]
77
MLFlowClient = "64a0f543-368b-4a9a-827a-e71edb2a0b83"
@@ -10,15 +10,16 @@ MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
1010

1111
[compat]
1212
MLFlowClient = "0.4.4"
13-
MLJBase = "0.21.14"
13+
MLJBase = "1"
1414
MLJModelInterface = "1.9.1"
1515
julia = "1.6"
1616

1717
[extras]
1818
MLFlowClient = "64a0f543-368b-4a9a-827a-e71edb2a0b83"
1919
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
2020
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
21+
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
2122
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2223

2324
[targets]
24-
test = ["Test", "MLFlowClient", "MLJModels", "MLJDecisionTreeInterface"]
25+
test = ["Test", "MLFlowClient", "MLJModels", "MLJDecisionTreeInterface", "StatisticalMeasures"]

src/MLJFlow.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module MLJFlow
22

3-
using MLJBase: info, name, Model,
4-
Machine
3+
using MLJBase: Model, Machine, name
54
using MLJModelInterface: flat_params
65
using MLFlowClient: MLFlow, logparam, logmetric,
76
createrun, MLFlowRun, updaterun,

src/base.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ function log_evaluation(logger::MLFlowLogger, performance_evaluation)
33
artifact_location=logger.artifact_location)
44
run = createrun(logger.service, experiment;
55
tags=[
6-
Dict("key" => "resampling", "value" => string(performance_evaluation.resampling)),
6+
Dict(
7+
"key" => "resampling",
8+
"value" => string(performance_evaluation.resampling)
9+
),
710
Dict("key" => "repeats", "value" => string(performance_evaluation.repeats)),
811
Dict("key" => "model type", "value" => name(performance_evaluation.model)),
912
]

src/service.jl

+40-3
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,57 @@ function logmodelparams(service::MLFlow, run::MLFlowRun, model::Model)
1818
end
1919
end
2020

21+
const MLFLOW_CHAR_SET =
22+
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_-. /"
23+
24+
"""
25+
good_name(measure)
26+
27+
**Private method.**
28+
29+
Returns a string representation of `measure` that can be used as a valid name in
30+
MLflow. Includes the value of the first hyperparameter, if there is one.
31+
32+
```julia
33+
julia> good_name(macro_f1score)
34+
"MulticlassFScore-beta_1.0"
35+
36+
"""
37+
function good_name(measure)
38+
name = string(measure)
39+
name = replace(name, ", …" => "")
40+
name = replace(name, " = " => "_")
41+
name = replace(name, "()" => "")
42+
name = replace(name, ")" => "")
43+
map(collect(name)) do char
44+
char in ['(', ','] && return '-'
45+
char == '=' && return '_'
46+
char in MLFLOW_CHAR_SET && return char
47+
" "
48+
end |> join
49+
end
50+
2151
"""
2252
logmachinemeasures(service::MLFlow, run::MLFlowRun, model::Model)
2353
2454
Extracts the parameters of a model and logs them to the MLFlow server.
2555
2656
# Arguments
27-
- `service::MLFlow`: An MLFlow service. See [MLFlowClient.jl](https://juliaai.github.io/MLFlowClient.jl/dev/reference/#MLFlowClient.MLFlow)
28-
- `run::MLFlowRun`: An MLFlow run. See [MLFlowClient.jl](https://juliaai.github.io/MLFlowClient.jl/dev/reference/#MLFlowClient.MLFlowRun)
57+
58+
- `service::MLFlow`: An MLFlow service. See
59+
[MLFlowClient.jl](https://juliaai.github.io/MLFlowClient.jl/dev/reference/#MLFlowClient.MLFlow)
60+
61+
- `run::MLFlowRun`: An MLFlow run. See
62+
[MLFlowClient.jl](https://juliaai.github.io/MLFlowClient.jl/dev/reference/#MLFlowClient.MLFlowRun)
63+
2964
- `measures`: A vector of measures.
65+
3066
- `measurements`: A vector of measurements.
67+
3168
"""
3269
function logmachinemeasures(service::MLFlow, run::MLFlowRun, measures,
3370
measurements)
34-
measure_names = measures .|> info .|> x -> x.name
71+
measure_names = measures .|> good_name
3572
for (name, value) in zip(measure_names, measurements)
3673
logmetric(service, run, name, value)
3774
end

test/runtests.jl

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ using MLJBase
66
using MLJModels
77
using MLFlowClient
88
using MLJModelInterface
9+
using StatisticalMeasures
910

1011
include("base.jl")
1112
include("types.jl")
13+
include("service.jl")
14+

test/service.jl

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
@testset "good_name" begin
2+
@test MLJFlow.good_name(rms) == "RootMeanSquaredError"
3+
@test MLJFlow.good_name(macro_f1score) == "MulticlassFScore-beta_1.0"
4+
@test MLJFlow.good_name(log_score) == "LogScore-tol_2.22045e-16"
5+
end
6+
7+
true

0 commit comments

Comments
 (0)