Skip to content

Commit c3f9914

Browse files
authored
[Auto3DSeg] Add mlflow support in autorunner. (#7176)
Add MLflow support in AutoRunner Class. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: dongy <[email protected]>
1 parent cf886e7 commit c3f9914

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

monai/apps/auto3dseg/auto_runner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ class AutoRunner:
8383
zip url will be downloaded and extracted into the work_dir.
8484
allow_skip: a switch passed to BundleGen process which determines if some Algo in the default templates
8585
can be skipped based on the analysis on the dataset from Auto3DSeg DataAnalyzer.
86+
mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of the remote
87+
tracking Server; MLflow runs will be recorded locally in algorithms' model folder if the value is None.
8688
kwargs: image writing parameters for the ensemble inference. The kwargs format follows the SaveImage
8789
transform. For more information, check https://docs.monai.io/en/stable/transforms.html#saveimage.
8890
@@ -209,6 +211,7 @@ def __init__(
209211
not_use_cache: bool = False,
210212
templates_path_or_url: str | None = None,
211213
allow_skip: bool = True,
214+
mlflow_tracking_uri: str | None = None,
212215
**kwargs: Any,
213216
):
214217
logger.info(f"AutoRunner using work directory {work_dir}")
@@ -220,6 +223,7 @@ def __init__(
220223
self.algos = algos
221224
self.templates_path_or_url = templates_path_or_url
222225
self.allow_skip = allow_skip
226+
self.mlflow_tracking_uri = mlflow_tracking_uri
223227
self.kwargs = deepcopy(kwargs)
224228

225229
if input is None and os.path.isfile(self.data_src_cfg_name):
@@ -783,6 +787,7 @@ def run(self):
783787
templates_path_or_url=self.templates_path_or_url,
784788
data_stats_filename=self.datastats_filename,
785789
data_src_cfg_name=self.data_src_cfg_name,
790+
mlflow_tracking_uri=self.mlflow_tracking_uri,
786791
)
787792

788793
if self.gpu_customization:

monai/apps/auto3dseg/bundle_gen.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def __init__(self, template_path: PathLike):
8585
self.template_path = template_path
8686
self.data_stats_files = ""
8787
self.data_list_file = ""
88+
self.mlflow_tracking_uri = None
8889
self.output_path = ""
8990
self.name = ""
9091
self.best_metric = None
@@ -129,6 +130,17 @@ def set_data_source(self, data_src_cfg: str) -> None:
129130
"""
130131
self.data_list_file = data_src_cfg
131132

133+
def set_mlflow_tracking_uri(self, mlflow_tracking_uri: str | None) -> None:
134+
"""
135+
Set the tracking URI for MLflow server
136+
137+
Args:
138+
mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of
139+
the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if
140+
the value is None.
141+
"""
142+
self.mlflow_tracking_uri = mlflow_tracking_uri # type: ignore
143+
132144
def fill_template_config(self, data_stats_filename: str, algo_path: str, **kwargs: Any) -> dict:
133145
"""
134146
The configuration files defined when constructing this Algo instance might not have a complete training
@@ -432,6 +444,9 @@ class BundleGen(AlgoGen):
432444
data_stats_filename: the path to the data stats file (generated by DataAnalyzer).
433445
data_src_cfg_name: the path to the data source config YAML file. The config will be in a form of
434446
{"modality": "ct", "datalist": "path_to_json_datalist", "dataroot": "path_dir_data"}.
447+
mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of
448+
the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if
449+
the value is None.
435450
.. code-block:: bash
436451
437452
python -m monai.apps.auto3dseg BundleGen generate --data_stats_filename="../algorithms/datastats.yaml"
@@ -444,6 +459,7 @@ def __init__(
444459
templates_path_or_url: str | None = None,
445460
data_stats_filename: str | None = None,
446461
data_src_cfg_name: str | None = None,
462+
mlflow_tracking_uri: str | None = None,
447463
):
448464
if algos is None or isinstance(algos, (list, tuple, str)):
449465
if templates_path_or_url is None:
@@ -496,6 +512,7 @@ def __init__(
496512

497513
self.data_stats_filename = data_stats_filename
498514
self.data_src_cfg_name = data_src_cfg_name
515+
self.mlflow_tracking_uri = mlflow_tracking_uri
499516
self.history: list[dict] = []
500517

501518
def set_data_stats(self, data_stats_filename: str) -> None:
@@ -524,6 +541,21 @@ def get_data_src(self):
524541
"""Get the data source filename"""
525542
return self.data_src_cfg_name
526543

544+
def set_mlflow_tracking_uri(self, mlflow_tracking_uri):
545+
"""
546+
Set the tracking URI for MLflow server
547+
548+
Args:
549+
mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of
550+
the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if
551+
the value is None.
552+
"""
553+
self.mlflow_tracking_uri = mlflow_tracking_uri
554+
555+
def get_mlflow_tracking_uri(self):
556+
"""Get the tracking URI for MLflow server"""
557+
return self.mlflow_tracking_uri
558+
527559
def get_history(self) -> list:
528560
"""Get the history of the bundleAlgo object with their names/identifiers"""
529561
return self.history
@@ -575,9 +607,11 @@ def generate(
575607
for f_id in ensure_tuple(fold_idx):
576608
data_stats = self.get_data_stats()
577609
data_src_cfg = self.get_data_src()
610+
mlflow_tracking_uri = self.get_mlflow_tracking_uri()
578611
gen_algo = deepcopy(algo)
579612
gen_algo.set_data_stats(data_stats)
580613
gen_algo.set_data_source(data_src_cfg)
614+
gen_algo.set_mlflow_tracking_uri(mlflow_tracking_uri)
581615
name = f"{gen_algo.name}_{f_id}"
582616

583617
if allow_skip:

0 commit comments

Comments
 (0)