Skip to content

Commit bfd890f

Browse files
authored
feat: Add version validation if model fails to load (#194)
1 parent a5177d3 commit bfd890f

File tree

4 files changed

+64
-21
lines changed

4 files changed

+64
-21
lines changed

src/anemoi/inference/checkpoint.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -258,24 +258,25 @@ def validate_environment(
258258
self,
259259
*,
260260
all_packages: bool = False,
261-
on_difference: Literal["warn", "error", "ignore"] = "warn",
262-
exempt_packages: Optional[List[str]] = None,
263-
) -> bool:
261+
on_difference: Literal["warn", "error", "ignore", "return"] = "warn",
262+
exempt_packages: Optional[list[str]] = None,
263+
) -> Union[bool, str]:
264264
"""Validate the environment.
265265
266266
Parameters
267267
----------
268268
all_packages : bool, optional
269-
Whether to validate all packages, by default False.
270-
on_difference : str, optional
271-
Action to take on difference, by default "warn".
272-
exempt_packages : Optional[List[str]], optional
273-
List of packages to exempt, by default None.
269+
Check all packages in the environment (True) or just anemoi's (False), by default False.
270+
on_difference : Literal['warn', 'error', 'ignore', 'return'], optional
271+
What to do on difference, by default "warn"
272+
exempt_packages : list[str], optional
273+
List of packages to exempt from the check, by default EXEMPT_PACKAGES
274274
275275
Returns
276276
-------
277-
bool
278-
True if the environment is valid, False otherwise.
277+
Union[bool, str]
278+
boolean if `on_difference` is not 'return', otherwise formatted text of the differences
279+
True if environment is valid, False otherwise
279280
"""
280281
return self._metadata.validate_environment(
281282
all_packages=all_packages, on_difference=on_difference, exempt_packages=exempt_packages

src/anemoi/inference/metadata.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -631,23 +631,24 @@ def validate_environment(
631631
self,
632632
*,
633633
all_packages: bool = False,
634-
on_difference: Literal["warn", "error", "ignore"] = "warn",
634+
on_difference: Literal["warn", "error", "ignore", "return"] = "warn",
635635
exempt_packages: Optional[list[str]] = None,
636-
) -> bool:
636+
) -> Union[bool, str]:
637637
"""Validate environment of the checkpoint against the current environment.
638638
639639
Parameters
640640
----------
641641
all_packages : bool, optional
642-
Check all packages in environment or just `anemoi`'s, by default False
643-
on_difference : Literal['warn', 'error', 'ignore'], optional
642+
Check all packages in the environment (True) or just anemoi's (False), by default False.
643+
on_difference : Literal['warn', 'error', 'ignore', 'return'], optional
644644
What to do on difference, by default "warn"
645645
exempt_packages : list[str], optional
646646
List of packages to exempt from the check, by default EXEMPT_PACKAGES
647647
648648
Returns
649649
-------
650-
bool
650+
Union[bool, str]
651+
boolean if `on_difference` is not 'return', otherwise formatted text of the differences
651652
True if environment is valid, False otherwise
652653
653654
Raises

src/anemoi/inference/provenance.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from typing import List
1515
from typing import Literal
1616
from typing import Optional
17+
from typing import Union
18+
from typing import overload
1719

1820
from anemoi.utils.provenance import gather_provenance_info
1921
from packaging.version import Version
@@ -25,8 +27,10 @@
2527
# Complete package name to be exempt
2628
EXEMPT_PACKAGES = [
2729
"anemoi.training",
30+
"anemoi.inference",
2831
"hydra",
2932
"hydra_plugins",
33+
"hydra_plugins.anemoi_searchpath",
3034
"lightning",
3135
"pytorch_lightning",
3236
"lightning_fabric",
@@ -41,13 +45,33 @@
4145
LOG = logging.getLogger(__name__)
4246

4347

48+
@overload
4449
def validate_environment(
4550
metadata: "Metadata",
4651
*,
4752
all_packages: bool = False,
4853
on_difference: Literal["warn", "error", "ignore"] = "warn",
49-
exempt_packages: Optional[list[str]] = None,
50-
) -> bool:
54+
exempt_packages: Optional[List[str]] = None,
55+
) -> bool: ...
56+
57+
58+
@overload
59+
def validate_environment(
60+
metadata: "Metadata",
61+
*,
62+
all_packages: bool = False,
63+
on_difference: Literal["return"] = "return",
64+
exempt_packages: Optional[List[str]] = None,
65+
) -> str: ...
66+
67+
68+
def validate_environment(
69+
metadata: "Metadata",
70+
*,
71+
all_packages: bool = False,
72+
on_difference: Literal["warn", "error", "ignore", "return"] = "warn",
73+
exempt_packages: Optional[List[str]] = None,
74+
) -> Union[bool, str]:
5175
"""Validate environment of the checkpoint against the current environment.
5276
5377
Parameters
@@ -58,12 +82,13 @@ def validate_environment(
5882
Check all packages in environment or just `anemoi`'s, by default False
5983
on_difference : Literal['warn', 'error', 'ignore'], optional
6084
What to do on difference, by default "warn"
61-
exempt_packages : list[str], optional
85+
exempt_packages : List[str], optional
6286
List of packages to exempt from the check, by default EXEMPT_PACKAGES
6387
6488
Returns
6589
-------
66-
bool
90+
Union[bool, str]
91+
boolean if `on_difference` is not 'return', otherwise formatted text of the differences
6792
True if environment is valid, False otherwise
6893
6994
Raises
@@ -105,6 +130,7 @@ def validate_environment(
105130
for module in train_environment["module_versions"].keys():
106131
inference_module_name = module # Due to package name differences between retrieval methods this may change
107132

133+
train_module_version_str = train_environment["module_versions"][module]
108134
if not all_packages and "anemoi" not in module:
109135
continue
110136
elif module in exempt_packages or module.split(".")[0] in EXEMPT_NAMESPACES:
@@ -122,7 +148,9 @@ def validate_environment(
122148
continue
123149
except (ModuleNotFoundError, ImportError):
124150
pass
125-
invalid_messages["missing"].append(f"Missing module in inference environment: {module}")
151+
invalid_messages["missing"].append(
152+
f"Missing module in inference environment: {module}=={train_module_version_str}"
153+
)
126154
continue
127155

128156
train_environment_version = Version(train_environment["module_versions"][module])
@@ -142,6 +170,9 @@ def validate_environment(
142170
if file_record["modified_files"] == 0 and file_record["untracked_files"] == 0:
143171
continue
144172

173+
if git_record in exempt_packages:
174+
continue
175+
145176
if git_record not in inference_environment["git_versions"]:
146177
invalid_messages["uncommitted"].append(
147178
f"Training environment contained uncommitted change missing in inference environment: {git_record}"
@@ -159,6 +190,9 @@ def validate_environment(
159190
if file_record["modified_files"] == 0 and file_record["untracked_files"] == 0:
160191
continue
161192

193+
if git_record in exempt_packages:
194+
continue
195+
162196
if git_record not in train_environment["git_versions"]:
163197
invalid_messages["uncommitted"].append(
164198
f"Inference environment contains uncommited changes missing in training: {git_record}"
@@ -174,6 +208,8 @@ def validate_environment(
174208
raise RuntimeError(text)
175209
elif on_difference == "ignore":
176210
pass
211+
elif on_difference == "return":
212+
return text
177213
else:
178214
raise ValueError(f"Invalid value for `on_difference`: {on_difference}")
179215
return False

src/anemoi/inference/runner.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,12 @@ def model(self) -> torch.nn.Module:
448448
The loaded model.
449449
"""
450450
with Timer(f"Loading {self.checkpoint}"):
451-
model = torch.load(self.checkpoint.path, map_location=self.device, weights_only=False).to(self.device)
451+
try:
452+
model = torch.load(self.checkpoint.path, map_location=self.device, weights_only=False).to(self.device)
453+
except Exception as e: # Wildcard exception to catch all errors
454+
validation_result = self.checkpoint.validate_environment(on_difference="return")
455+
error_msg = f"Error loading model - {validation_result}"
456+
raise RuntimeError(error_msg) from e
452457
# model.set_inference_options(**self.inference_options)
453458
assert getattr(model, "runner", None) is None, model.runner
454459
model.runner = self

0 commit comments

Comments
 (0)