Skip to content

Commit ce88483

Browse files
Alexander Jipaazzhipa
andauthored
Add synchronous parameter to MLflowLogger (#19639)
Co-authored-by: Alexander Jipa <[email protected]>
1 parent 8947d13 commit ce88483

File tree

2 files changed

+67
-4
lines changed

2 files changed

+67
-4
lines changed

src/lightning/pytorch/loggers/mlflow.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
log = logging.getLogger(__name__)
4343
LOCAL_FILE_URI_PREFIX = "file:"
4444
_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0", "mlflow")
45+
_MLFLOW_SYNCHRONOUS_AVAILABLE = RequirementCache("mlflow>=2.8.0", "mlflow")
4546

4647

4748
class MLFlowLogger(Logger):
@@ -100,6 +101,8 @@ def any_lightning_module_function_or_hook(self):
100101
artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
101102
default.
102103
run_id: The run identifier of the experiment. If not provided, a new run is started.
104+
synchronous: Hints mlflow whether to block the execution for every logging call until complete where
105+
applicable. Requires mlflow >= 2.8.0
103106
104107
Raises:
105108
ModuleNotFoundError:
@@ -120,9 +123,12 @@ def __init__(
120123
prefix: str = "",
121124
artifact_location: Optional[str] = None,
122125
run_id: Optional[str] = None,
126+
synchronous: Optional[bool] = None,
123127
):
124128
if not _MLFLOW_AVAILABLE:
125129
raise ModuleNotFoundError(str(_MLFLOW_AVAILABLE))
130+
if synchronous is not None and not _MLFLOW_SYNCHRONOUS_AVAILABLE:
131+
raise ModuleNotFoundError("`synchronous` requires mlflow>=2.8.0")
126132
super().__init__()
127133
if not tracking_uri:
128134
tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}"
@@ -138,7 +144,7 @@ def __init__(
138144
self._checkpoint_callback: Optional[ModelCheckpoint] = None
139145
self._prefix = prefix
140146
self._artifact_location = artifact_location
141-
147+
self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
142148
self._initialized = False
143149

144150
from mlflow.tracking import MlflowClient
@@ -233,7 +239,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
233239

234240
# Log in chunks of 100 parameters (the maximum allowed by MLflow).
235241
for idx in range(0, len(params_list), 100):
236-
self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100])
242+
self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100], **self._log_batch_kwargs)
237243

238244
@override
239245
@rank_zero_only
@@ -261,7 +267,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
261267
k = new_k
262268
metrics_list.append(Metric(key=k, value=v, timestamp=timestamp_ms, step=step or 0))
263269

264-
self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list)
270+
self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list, **self._log_batch_kwargs)
265271

266272
@override
267273
@rank_zero_only

tests/tests_pytorch/loggers/test_mlflow.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818
import pytest
1919
from lightning.pytorch import Trainer
2020
from lightning.pytorch.demos.boring_classes import BoringModel
21-
from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE, MLFlowLogger, _get_resolve_tags
21+
from lightning.pytorch.loggers.mlflow import (
22+
_MLFLOW_AVAILABLE,
23+
_MLFLOW_SYNCHRONOUS_AVAILABLE,
24+
MLFlowLogger,
25+
_get_resolve_tags,
26+
)
2227

2328

2429
def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, run_id=None):
@@ -260,6 +265,58 @@ def test_mlflow_logger_experiment_calls(mlflow_mock, tmp_path):
260265
)
261266

262267

268+
@pytest.mark.parametrize("synchronous", [False, True])
269+
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
270+
def test_mlflow_logger_experiment_calls_with_synchronous(mlflow_mock, tmp_path, synchronous):
271+
"""Test that the logger calls methods on the mlflow experiment with the specified synchronous flag."""
272+
if not _MLFLOW_SYNCHRONOUS_AVAILABLE:
273+
pytest.skip("this test requires mlflow>=2.8.0")
274+
275+
time = mlflow_mock.entities.time
276+
metric = mlflow_mock.entities.Metric
277+
param = mlflow_mock.entities.Param
278+
time.return_value = 1
279+
280+
mlflow_client = mlflow_mock.tracking.MlflowClient.return_value
281+
mlflow_client.get_experiment_by_name.return_value = None
282+
logger = MLFlowLogger(
283+
"test", save_dir=str(tmp_path), artifact_location="my_artifact_location", synchronous=synchronous
284+
)
285+
286+
params = {"test": "test_param"}
287+
logger.log_hyperparams(params)
288+
289+
mlflow_client.log_batch.assert_called_once_with(
290+
run_id=logger.run_id, params=[param(key="test", value="test_param")], synchronous=synchronous
291+
)
292+
param.assert_called_with(key="test", value="test_param")
293+
294+
metrics = {"some_metric": 10}
295+
logger.log_metrics(metrics)
296+
297+
mlflow_client.log_batch.assert_called_with(
298+
run_id=logger.run_id,
299+
metrics=[metric(key="some_metric", value=10, timestamp=1000, step=0)],
300+
synchronous=synchronous,
301+
)
302+
metric.assert_called_with(key="some_metric", value=10, timestamp=1000, step=0)
303+
304+
mlflow_client.create_experiment.assert_called_once_with(name="test", artifact_location="my_artifact_location")
305+
306+
307+
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
308+
@mock.patch.dict("lightning.pytorch.loggers.mlflow.__dict__", {"_MLFLOW_SYNCHRONOUS_AVAILABLE": False})
309+
def test_mlflow_logger_no_synchronous_support(mlflow_mock, tmp_path):
310+
"""Test that the logger does not support synchronous flag."""
311+
time = mlflow_mock.entities.time
312+
time.return_value = 1
313+
314+
mlflow_client = mlflow_mock.tracking.MlflowClient.return_value
315+
mlflow_client.get_experiment_by_name.return_value = None
316+
with pytest.raises(ModuleNotFoundError):
317+
MLFlowLogger("test", save_dir=str(tmp_path), artifact_location="my_artifact_location", synchronous=True)
318+
319+
263320
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
264321
def test_mlflow_logger_with_long_param_value(mlflow_mock, tmp_path):
265322
"""Test that long parameter values are truncated to 250 characters."""

0 commit comments

Comments
 (0)