Skip to content

Commit cd27049

Browse files
committed
raise for all missing models
1 parent 84a5bbd commit cd27049

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

matbench_discovery/data.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def load_df_wbm_with_preds(
214214
id_col: str = Key.mat_id,
215215
subset: pd.Index | Sequence[str] | TestSubset | None = None,
216216
max_error_threshold: float | None = 5.0,
217+
raise_on_missing_models: bool = True,
217218
**kwargs: Any,
218219
) -> pd.DataFrame:
219220
"""Load WBM summary dataframe with model predictions from disk.
@@ -234,6 +235,8 @@ def load_df_wbm_with_preds(
234235
a practitioner doing a prospective discovery effort. Predictions exceeding
235236
this threshold will be ignored in all downstream calculations of metrics.
236237
Defaults to 5 eV/atom.
238+
raise_on_missing_models (bool, optional): Whether to raise an exception if any
239+
models are not found. Defaults to True.
237240
**kwargs: Keyword arguments passed to glob_to_df().
238241
239242
Raises:
@@ -256,9 +259,10 @@ def load_df_wbm_with_preds(
256259

257260
df_out = df_wbm.copy()
258261

259-
try:
260-
prog_bar = tqdm(models, disable=not pbar, desc="Loading preds")
261-
for model_name in prog_bar:
262+
prog_bar = tqdm(models, disable=not pbar, desc="Loading preds")
263+
missing_model_exceptions = []
264+
for model_name in prog_bar:
265+
try:
262266
prog_bar.set_postfix_str(model_name)
263267

264268
# use getattr(name) in case model_name is already a Model enum
@@ -297,9 +301,14 @@ def load_df_wbm_with_preds(
297301
print(
298302
f"{n_bad:,} of {n_preds:,} unrealistic preds for {model_name}"
299303
)
300-
except Exception as exc:
301-
exc.add_note(f"Failed to load {model_name=}")
302-
raise
304+
except Exception as exc:
305+
exc.add_note(f"Failed to load {model_name=}")
306+
missing_model_exceptions.append(exc)
307+
308+
if missing_model_exceptions and raise_on_missing_models:
309+
raise ExceptionGroup(
310+
"Failed to load some models", missing_model_exceptions
311+
) from missing_model_exceptions[0]
303312

304313
if subset == TestSubset.uniq_protos:
305314
df_out = df_out.query(MbdKey.uniq_proto)

0 commit comments

Comments
 (0)