14
14
from typing import List
15
15
from typing import Literal
16
16
from typing import Optional
17
+ from typing import Union
18
+ from typing import overload
17
19
18
20
from anemoi .utils .provenance import gather_provenance_info
19
21
from packaging .version import Version
25
27
# Complete package name to be exempt
26
28
EXEMPT_PACKAGES = [
27
29
"anemoi.training" ,
30
+ "anemoi.inference" ,
28
31
"hydra" ,
29
32
"hydra_plugins" ,
33
+ "hydra_plugins.anemoi_searchpath" ,
30
34
"lightning" ,
31
35
"pytorch_lightning" ,
32
36
"lightning_fabric" ,
41
45
LOG = logging .getLogger (__name__ )
42
46
43
47
48
+ @overload
44
49
def validate_environment (
45
50
metadata : "Metadata" ,
46
51
* ,
47
52
all_packages : bool = False ,
48
53
on_difference : Literal ["warn" , "error" , "ignore" ] = "warn" ,
49
- exempt_packages : Optional [list [str ]] = None ,
50
- ) -> bool :
54
+ exempt_packages : Optional [List [str ]] = None ,
55
+ ) -> bool : ...
56
+
57
+
58
+ @overload
59
+ def validate_environment (
60
+ metadata : "Metadata" ,
61
+ * ,
62
+ all_packages : bool = False ,
63
+ on_difference : Literal ["return" ] = "return" ,
64
+ exempt_packages : Optional [List [str ]] = None ,
65
+ ) -> str : ...
66
+
67
+
68
+ def validate_environment (
69
+ metadata : "Metadata" ,
70
+ * ,
71
+ all_packages : bool = False ,
72
+ on_difference : Literal ["warn" , "error" , "ignore" , "return" ] = "warn" ,
73
+ exempt_packages : Optional [List [str ]] = None ,
74
+ ) -> Union [bool , str ]:
51
75
"""Validate environment of the checkpoint against the current environment.
52
76
53
77
Parameters
@@ -58,12 +82,13 @@ def validate_environment(
58
82
Check all packages in environment or just `anemoi`'s, by default False
59
83
on_difference : Literal['warn', 'error', 'ignore'], optional
60
84
What to do on difference, by default "warn"
61
- exempt_packages : list [str], optional
85
+ exempt_packages : List [str], optional
62
86
List of packages to exempt from the check, by default EXEMPT_PACKAGES
63
87
64
88
Returns
65
89
-------
66
- bool
90
+ Union[bool, str]
91
+ boolean if `on_difference` is not 'return', otherwise formatted text of the differences
67
92
True if environment is valid, False otherwise
68
93
69
94
Raises
@@ -105,6 +130,7 @@ def validate_environment(
105
130
for module in train_environment ["module_versions" ].keys ():
106
131
inference_module_name = module # Due to package name differences between retrieval methods this may change
107
132
133
+ train_module_version_str = train_environment ["module_versions" ][module ]
108
134
if not all_packages and "anemoi" not in module :
109
135
continue
110
136
elif module in exempt_packages or module .split ("." )[0 ] in EXEMPT_NAMESPACES :
@@ -122,7 +148,9 @@ def validate_environment(
122
148
continue
123
149
except (ModuleNotFoundError , ImportError ):
124
150
pass
125
- invalid_messages ["missing" ].append (f"Missing module in inference environment: { module } " )
151
+ invalid_messages ["missing" ].append (
152
+ f"Missing module in inference environment: { module } =={ train_module_version_str } "
153
+ )
126
154
continue
127
155
128
156
train_environment_version = Version (train_environment ["module_versions" ][module ])
@@ -142,6 +170,9 @@ def validate_environment(
142
170
if file_record ["modified_files" ] == 0 and file_record ["untracked_files" ] == 0 :
143
171
continue
144
172
173
+ if git_record in exempt_packages :
174
+ continue
175
+
145
176
if git_record not in inference_environment ["git_versions" ]:
146
177
invalid_messages ["uncommitted" ].append (
147
178
f"Training environment contained uncommitted change missing in inference environment: { git_record } "
@@ -159,6 +190,9 @@ def validate_environment(
159
190
if file_record ["modified_files" ] == 0 and file_record ["untracked_files" ] == 0 :
160
191
continue
161
192
193
+ if git_record in exempt_packages :
194
+ continue
195
+
162
196
if git_record not in train_environment ["git_versions" ]:
163
197
invalid_messages ["uncommitted" ].append (
164
198
f"Inference environment contains uncommited changes missing in training: { git_record } "
@@ -174,6 +208,8 @@ def validate_environment(
174
208
raise RuntimeError (text )
175
209
elif on_difference == "ignore" :
176
210
pass
211
+ elif on_difference == "return" :
212
+ return text
177
213
else :
178
214
raise ValueError (f"Invalid value for `on_difference`: { on_difference } " )
179
215
return False
0 commit comments