|
18 | 18 | import pytest
|
19 | 19 | from lightning.pytorch import Trainer
|
20 | 20 | from lightning.pytorch.demos.boring_classes import BoringModel
|
21 |
| -from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE, MLFlowLogger, _get_resolve_tags |
| 21 | +from lightning.pytorch.loggers.mlflow import ( |
| 22 | + _MLFLOW_AVAILABLE, |
| 23 | + _MLFLOW_SYNCHRONOUS_AVAILABLE, |
| 24 | + MLFlowLogger, |
| 25 | + _get_resolve_tags, |
| 26 | +) |
22 | 27 |
|
23 | 28 |
|
24 | 29 | def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, run_id=None):
|
@@ -260,6 +265,58 @@ def test_mlflow_logger_experiment_calls(mlflow_mock, tmp_path):
|
260 | 265 | )
|
261 | 266 |
|
262 | 267 |
|
| 268 | +@pytest.mark.parametrize("synchronous", [False, True]) |
| 269 | +@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()) |
| 270 | +def test_mlflow_logger_experiment_calls_with_synchronous(mlflow_mock, tmp_path, synchronous): |
| 271 | + """Test that the logger calls methods on the mlflow experiment with the specified synchronous flag.""" |
| 272 | + if not _MLFLOW_SYNCHRONOUS_AVAILABLE: |
| 273 | + pytest.skip("this test requires mlflow>=2.8.0") |
| 274 | + |
| 275 | + time = mlflow_mock.entities.time |
| 276 | + metric = mlflow_mock.entities.Metric |
| 277 | + param = mlflow_mock.entities.Param |
| 278 | + time.return_value = 1 |
| 279 | + |
| 280 | + mlflow_client = mlflow_mock.tracking.MlflowClient.return_value |
| 281 | + mlflow_client.get_experiment_by_name.return_value = None |
| 282 | + logger = MLFlowLogger( |
| 283 | + "test", save_dir=str(tmp_path), artifact_location="my_artifact_location", synchronous=synchronous |
| 284 | + ) |
| 285 | + |
| 286 | + params = {"test": "test_param"} |
| 287 | + logger.log_hyperparams(params) |
| 288 | + |
| 289 | + mlflow_client.log_batch.assert_called_once_with( |
| 290 | + run_id=logger.run_id, params=[param(key="test", value="test_param")], synchronous=synchronous |
| 291 | + ) |
| 292 | + param.assert_called_with(key="test", value="test_param") |
| 293 | + |
| 294 | + metrics = {"some_metric": 10} |
| 295 | + logger.log_metrics(metrics) |
| 296 | + |
| 297 | + mlflow_client.log_batch.assert_called_with( |
| 298 | + run_id=logger.run_id, |
| 299 | + metrics=[metric(key="some_metric", value=10, timestamp=1000, step=0)], |
| 300 | + synchronous=synchronous, |
| 301 | + ) |
| 302 | + metric.assert_called_with(key="some_metric", value=10, timestamp=1000, step=0) |
| 303 | + |
| 304 | + mlflow_client.create_experiment.assert_called_once_with(name="test", artifact_location="my_artifact_location") |
| 305 | + |
| 306 | + |
| 307 | +@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()) |
| 308 | +@mock.patch.dict("lightning.pytorch.loggers.mlflow.__dict__", {"_MLFLOW_SYNCHRONOUS_AVAILABLE": False}) |
| 309 | +def test_mlflow_logger_no_synchronous_support(mlflow_mock, tmp_path): |
| 310 | + """Test that the logger does not support synchronous flag.""" |
| 311 | + time = mlflow_mock.entities.time |
| 312 | + time.return_value = 1 |
| 313 | + |
| 314 | + mlflow_client = mlflow_mock.tracking.MlflowClient.return_value |
| 315 | + mlflow_client.get_experiment_by_name.return_value = None |
| 316 | + with pytest.raises(ModuleNotFoundError): |
| 317 | + MLFlowLogger("test", save_dir=str(tmp_path), artifact_location="my_artifact_location", synchronous=True) |
| 318 | + |
| 319 | + |
263 | 320 | @mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
|
264 | 321 | def test_mlflow_logger_with_long_param_value(mlflow_mock, tmp_path):
|
265 | 322 | """Test that long parameter values are truncated to 250 characters."""
|
|
0 commit comments