-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbase.jl
44 lines (36 loc) · 1.47 KB
/
base.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""
    log_evaluation(logger::MLFlowLogger, performance_evaluation)

Record `performance_evaluation` as a new run on the service held by `logger`.

The run is created in the experiment named `logger.experiment_name` (obtained via
`getorcreateexperiment`, passing `logger.artifact_location`), is named
`"<model name> run"`, and carries the tags `"resampling"` and `"repeats"` holding the
string forms of `performance_evaluation.resampling` and `performance_evaluation.repeats`.
The model and the `measure`/`measurement` fields of the evaluation are then handed to
`logmodelparams` and `logmachinemeasures`, and the run is updated to `"FINISHED"`.

If anything fails after the run has been created, the run is updated to `"FAILED"`
(best effort) and the original exception is rethrown, so the run is not left open.

Returns whatever `updaterun` returns for the `"FINISHED"` update.
"""
function log_evaluation(logger::MLFlowLogger, performance_evaluation)
    experiment = getorcreateexperiment(logger.service, logger.experiment_name;
        artifact_location=logger.artifact_location)
    model_name = name(performance_evaluation.model)
    run = createrun(logger.service, experiment;
        run_name="$(model_name) run",
        tags=[
            Dict("key" => "resampling",
                "value" => string(performance_evaluation.resampling)),
            Dict("key" => "repeats",
                "value" => string(performance_evaluation.repeats)),
        ]
    )
    try
        logmodelparams(logger.service, run, performance_evaluation.model)
        logmachinemeasures(logger.service, run, performance_evaluation.measure,
            performance_evaluation.measurement)
        return updaterun(logger.service, run, "FINISHED")
    catch
        # The run already exists on the server; close it so it does not stay open
        # forever. A failure of this cleanup must not hide the original error.
        try
            updaterun(logger.service, run, "FAILED")
        catch cleanup_err
            @warn "Could not mark the MLflow run as FAILED" exception = cleanup_err
        end
        rethrow()
    end
end
"""
    save(logger::MLFlowLogger, mach::Machine)

Serialize `mach` into an in-memory buffer with `save(io, mach)` and upload that buffer
as the artifact `"<model name>.jls"` of a new run on the service held by `logger`.

The run is created in the experiment named `logger.experiment_name` (obtained via
`getorcreateexperiment`, passing `logger.artifact_location`) and is named
`"<model name> run"`. `mach.model` is handed to `logmodelparams` before the artifact is
logged, and the run is finally updated to `"FINISHED"`.

If anything fails after the run has been created, the run is updated to `"FAILED"`
(best effort) and the original exception is rethrown, so the run is not left open.

Returns whatever `updaterun` returns for the `"FINISHED"` update.
"""
function save(logger::MLFlowLogger, mach::Machine)
    # Serialize first: if this throws, no run has been created on the server yet.
    io = IOBuffer()
    save(io, mach)
    seekstart(io)  # rewind so the upload reads the buffer from its beginning

    model = mach.model
    model_name = name(model)
    experiment = getorcreateexperiment(logger.service, logger.experiment_name;
        artifact_location=logger.artifact_location)
    run = createrun(logger.service, experiment;
        run_name="$(model_name) run")
    try
        logmodelparams(logger.service, run, model)
        logartifact(logger.service, run, "$(model_name).jls", io)
        return updaterun(logger.service, run, "FINISHED")
    catch
        # The run already exists on the server; close it so it does not stay open
        # forever. A failure of this cleanup must not hide the original error.
        try
            updaterun(logger.service, run, "FAILED")
        catch cleanup_err
            @warn "Could not mark the MLflow run as FAILED" exception = cleanup_err
        end
        rethrow()
    end
end
"""
service(logger::MLFlowLogger)
Returns the MLFlow service of a logger.
"""
service(logger::MLFlowLogger) = logger.service