Skip to content

Commit fbc3399

Browse files
authored
fix: Allow for downstream tests to provide checkpoint mocks (#211)
1 parent edd176f commit fbc3399

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

src/anemoi/inference/testing/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def fake_checkpoints(func: Callable[..., Any]) -> Callable[..., Any]:
3131

3232
@functools.wraps(func)
3333
def wrapper(*args: Any, **kwargs: Any) -> Any:
34+
from unittest.mock import MagicMock
3435
from unittest.mock import patch
3536

3637
from .mock_checkpoint import MockRunConfiguration
@@ -39,6 +40,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
3940

4041
with (
4142
patch("anemoi.inference.checkpoint.load_metadata", mock_load_metadata),
43+
patch("anemoi.inference.provenance.validate_environment", MagicMock()),
4244
patch("torch.load", mock_torch_load),
4345
patch("anemoi.inference.metadata.USE_LEGACY", True),
4446
patch("anemoi.inference.tasks.runner.RunConfiguration", MockRunConfiguration),

src/anemoi/inference/testing/mock_checkpoint.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ def mock_load_metadata(path: Optional[str], *, supporting_arrays: bool = True) -
8282
if path is None:
8383
metadata = SIMPLE_METADATA
8484
else:
85-
path = files_for_tests(os.path.join("checkpoints", path))
85+
if not os.path.isabs(path):
86+
path = files_for_tests(os.path.join("checkpoints", path))
8687
name, _ = os.path.splitext(path)
8788
for ext in (".yaml", ".json"):
8889
path = f"{name}{ext}"

0 commit comments

Comments
 (0)