Commit 71d77f4 (parent: 6f06b4c)

fix plot function cumulative_metrics() when 'RMSE' in metrics

simplify test_glob_to_df; upload_to_figshare.py: remove CHUNK_SIZE global

File tree

6 files changed (+32, -35 lines)

.pre-commit-config.yaml (+1, -1)

@@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg]
 
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.0.289
+    rev: v0.0.290
     hooks:
       - id: ruff
         args: [--fix]

matbench_discovery/plots.py (+5, -3)

@@ -667,8 +667,9 @@ def cumulative_metrics(
         df_preds (pd.DataFrame): Distance to convex hull predicted by models, one column
             per model (in eV / atom). Same as true energy to convex hull plus predicted
             minus true formation energy.
-        metrics (Sequence[str], optional): Which metrics to plot. Defaults to
-            ('Precision', 'Recall'). Also accepts 'F1'.
+        metrics (Sequence[str], optional): Which metrics to plot. Any subset of
+            ("Precision", "Recall", "F1", "MAE", "RMSE").
+            Defaults to ('Precision', 'Recall').
         stability_threshold (float, optional): Max distance above convex hull before
             material is considered unstable. Defaults to 0.
         project_end_point ('x' | 'y' | 'xy' | '', optional): Whether to project end
@@ -735,9 +736,9 @@ def cumulative_metrics(
             f1_interp = cubic_interpolate(model_range, f1_cum[:n_pred_stable])
             dfs["F1"][model_name] = dict(zip(xs_model, f1_interp(xs_model)))
 
+        cum_counts = np.arange(1, len(each_true) + 1)
         if "MAE" in metrics:
             cum_errors = (each_true - each_pred).abs().cumsum()
-            cum_counts = np.arange(1, len(each_true) + 1)
             mae_cum = cum_errors / cum_counts
             mae_interp = cubic_interpolate(model_range, mae_cum[:n_pred_stable])
             dfs["MAE"][model_name] = dict(zip(xs_model, mae_interp(xs_model)))
@@ -848,6 +849,7 @@ def cumulative_metrics(
             text=optimal_recall,
             showarrow=False,
             # rotate text parallel to line
+            # angle not quite right, could be improved
             textangle=math.degrees(math.cos(n_stable)),
             **grid_pos,
         )
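The middle hunk hoists `cum_counts` out of the `"MAE"` branch. Before this change, requesting `'RMSE'` without `'MAE'` presumably failed with a `NameError`, since the (unshown) RMSE branch also divides by `cum_counts`, but the variable was only defined when `"MAE"` was in `metrics`. A minimal standalone sketch of the pattern with dummy data (variable names mirror the diff; this is illustrative, not the library code):

```python
import numpy as np
import pandas as pd

each_true = pd.Series([0.10, -0.20, 0.05, 0.30])  # dummy hull distances (eV/atom)
each_pred = pd.Series([0.12, -0.15, 0.00, 0.25])
metrics = ("RMSE",)  # requesting RMSE alone used to hit the NameError

# computed unconditionally after the fix, so every error branch can rely on it
cum_counts = np.arange(1, len(each_true) + 1)

if "MAE" in metrics:
    mae_cum = (each_true - each_pred).abs().cumsum() / cum_counts
if "RMSE" in metrics:
    rmse_cum = (((each_true - each_pred) ** 2).cumsum() / cum_counts) ** 0.5
    print(rmse_cum.round(4).tolist())
```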

scripts/upload_to_figshare.py (+7, -4)

@@ -28,7 +28,6 @@
     TOKEN = file.read().split("figshare_token=")[1].split("\n")[0]
 
 BASE_URL = "https://api.figshare.com/v2"
-CHUNK_SIZE = 10_000_000  # ~10MB
 
 with open(f"{ROOT}/pyproject.toml", "rb") as file:
     pyproject = tomllib.load(file)["project"]
@@ -96,12 +95,16 @@ def create_article(metadata: dict[str, str | int | float]) -> int:
     return result["id"]
 
 
-def get_file_hash_and_size(file_name: str) -> tuple[str, int]:
-    """Get the md5 hash and size of a file."""
+def get_file_hash_and_size(
+    file_name: str, chunk_size: int = 10_000_000
+) -> tuple[str, int]:
+    """Get the md5 hash and size of a file. File is read in chunks of chunk_size bytes.
+    Default chunk size is 10_000_000 ~= 10MB.
+    """
     md5 = hashlib.md5()
     size = 0
     with open(file_name, "rb") as file:
-        while data := file.read(CHUNK_SIZE):
+        while data := file.read(chunk_size):
             size += len(data)
             md5.update(data)
     return md5.hexdigest(), size
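Chunked reads keep memory use flat regardless of file size, and the md5 digest is independent of the chunk size chosen. A quick self-contained check of that property (uses a throwaway temp file, not part of the repo):

```python
import hashlib
import os
import tempfile

def md5_and_size(path: str, chunk_size: int) -> tuple[str, int]:
    """Same chunked-hash pattern as the diff above, in a local helper."""
    md5, size = hashlib.md5(), 0
    with open(path, "rb") as file:
        while data := file.read(chunk_size):
            size += len(data)
            md5.update(data)
    return md5.hexdigest(), size

with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(os.urandom(1_000_003))  # odd size exercises a partial final chunk

try:
    # digests agree whether read in ~10MB or 4KiB chunks
    assert md5_and_size(tmp.name, 10_000_000) == md5_and_size(tmp.name, 4096)
finally:
    os.unlink(tmp.name)
```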

tests/test_data.py (+14, -18)

@@ -52,12 +52,12 @@ def test_load(
 ) -> None:
     filepath = DATA_FILES[data_key]
     # intercept HTTP requests and write dummy df to disk instead
-    with patch("urllib.request.urlretrieve") as urlretrieve:
+    with patch("urllib.request.urlretrieve") as url_retrieve:
         # dummy df with random floats and material_id column
         df_csv = pd._testing.makeDataFrame().reset_index(names="material_id")
 
         writer = dummy_df_serialized.to_json if ".json" in filepath else df_csv.to_csv
-        urlretrieve.side_effect = lambda url, path: writer(path)
+        url_retrieve.side_effect = lambda _url, path: writer(path)
         out = load(
             data_key,
             hydrate=hydrate,
@@ -70,7 +70,7 @@ def test_load(
     assert f"Downloading {data_key!r} from {figshare_urls[data_key][0]}" in stdout
 
     # check we called read_csv/read_json once for each data_name
-    assert urlretrieve.call_count == 1
+    assert url_retrieve.call_count == 1
 
     assert isinstance(out, pd.DataFrame), f"{data_key} not a DataFrame"
 
@@ -201,21 +201,17 @@ def test_df_wbm() -> None:
     assert set(df_wbm) > {"bandgap_pbe", "formula", "material_id"}
 
 
-@pytest.mark.parametrize("pattern", ["tmp/*df.csv", "tmp/*df.json"])
-def test_glob_to_df(pattern: str) -> None:
-    try:
-        df = pd._testing.makeMixedDataFrame()
+@pytest.mark.parametrize("pattern", ["*df.csv", "*df.json"])
+def test_glob_to_df(pattern: str, tmp_path: Path) -> None:
+    df = pd._testing.makeMixedDataFrame()
 
-        os.makedirs(f"{ROOT}/tmp", exist_ok=True)
-        df.to_csv(f"{ROOT}/tmp/dummy_df.csv", index=False)
-        df.to_json(f"{ROOT}/tmp/dummy_df.json")
+    os.makedirs(f"{tmp_path}", exist_ok=True)
+    df.to_csv(f"{tmp_path}/dummy_df.csv", index=False)
+    df.to_json(f"{tmp_path}/dummy_df.json")
 
-        df_out = glob_to_df(pattern)
-        assert df_out.shape == df.shape
-        assert list(df_out) == list(df)
+    df_out = glob_to_df(f"{tmp_path}/{pattern}")
+    assert df_out.shape == df.shape
+    assert list(df_out) == list(df)
 
-        with pytest.raises(FileNotFoundError):
-            glob_to_df("foo")
-    finally:
-        os.remove(f"{ROOT}/tmp/dummy_df.csv")
-        os.remove(f"{ROOT}/tmp/dummy_df.json")
+    with pytest.raises(FileNotFoundError):
+        glob_to_df("foo")
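The rewrite swaps the hand-rolled `try`/`finally` cleanup of files under `{ROOT}/tmp` for pytest's built-in `tmp_path` fixture, which provides a unique, automatically pruned directory per test (the retained `os.makedirs` call is redundant since `tmp_path` already exists, but harmless). A minimal illustration of the fixture on its own (hypothetical test, not from the repo):

```python
from pathlib import Path

def test_tmp_path_isolation(tmp_path: Path) -> None:
    file = tmp_path / "dummy.csv"  # lives in a unique per-test directory
    file.write_text("a,b\n1,2\n")
    assert file.read_text().startswith("a,b")
    # no cleanup needed: pytest prunes old tmp_path directories automatically
```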

tests/test_models.py (+2, -1)

@@ -34,9 +34,10 @@ def test_model_dirs_have_metadata() -> None:
         err_msg = f"Invalid {key=}, expected {required[key]} in {model_dir}"
         assert isinstance(metadata[key], required[key]), err_msg  # type: ignore
 
-    authors, date_added, mbd_version, model_name, model_version, repo = (
+    authors, date_added, mbd_version, yml_model_name, model_version, repo = (
         metadata[key] for key in list(required)[:-1]
     )
+    assert model_name == yml_model_name, f"{model_name=} != {yml_model_name=}"
 
     # make sure all keys are valid
     for name in model_name if isinstance(model_name, list) else [model_name]:
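Renaming the unpacked variable to `yml_model_name` frees `model_name` (derived elsewhere in the test) for a consistency check against the metadata file. The positional unpacking works because Python dicts preserve insertion order, so iterating `list(required)[:-1]` always yields values in the same order. A hedged sketch of the idiom with made-up keys and values (the actual `required` mapping is not shown in this diff):

```python
# hypothetical stand-ins for the test's required/metadata dicts
required = {"authors": list, "date_added": str, "mbd_version": str,
            "model_name": str, "model_version": str, "repo": str, "extra": dict}
metadata = {"authors": ["A. Author"], "date_added": "2023-09-15", "mbd_version": "1.0",
            "model_name": "demo-model", "model_version": "0.1",
            "repo": "https://example.org/demo", "extra": {}}

# dict order is insertion order, so positional unpacking over list(required)[:-1] is stable
authors, date_added, mbd_version, yml_model_name, model_version, repo = (
    metadata[key] for key in list(required)[:-1]
)
assert yml_model_name == "demo-model"
```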

tests/test_plots.py (+3, -8)

@@ -38,11 +38,7 @@
 @pytest.mark.parametrize("backend", ["matplotlib", "plotly"])
 @pytest.mark.parametrize(
     "metrics",
-    [
-        ("Recall",),
-        ("Recall", "MAE"),
-        ("Recall", "Precision", "F1"),
-    ],
+    [("Recall",), ("Recall", "MAE"), ("Recall", "Precision", "RMSE")],
 )
 def test_cumulative_metrics(
     project_end_point: AxLine,
@@ -68,9 +64,8 @@ def test_cumulative_metrics(
         assert {ax.get_ylabel() for ax in fig.axes} >= {*metrics}
     elif backend == "plotly":
         assert isinstance(fig, go.Figure)
-        # TODO fix AssertionError {'Recall', 'metric=F1'} == {'F1', 'Recall'}
-        # subplot_titles = [anno.text for anno in fig.layout.annotations][:len(metrics)]
-        # assert set(subplot_titles) == set(metrics)
+        subplot_titles = {anno.text.split("=")[-1] for anno in fig.layout.annotations}
+        assert subplot_titles >= set(metrics)
 
 
 def test_cumulative_metrics_raises() -> None:
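The re-enabled assertion works because plotly express prefixes facet subplot titles with the facet column name (the old TODO's `'metric=F1'`); splitting on `"="` and keeping the last piece recovers the bare metric name, and comparing with `>=` tolerates extra annotations. A standalone sketch of the same normalization (dummy figure, not the one under test):

```python
import pandas as pd
import plotly.express as px

df = pd.DataFrame(
    {"x": [1, 2, 1, 2], "y": [3, 4, 5, 6], "metric": ["F1", "F1", "Recall", "Recall"]}
)
fig = px.line(df, x="x", y="y", facet_col="metric")

# facet titles render as "metric=F1", "metric=Recall", ...
subplot_titles = {anno.text.split("=")[-1] for anno in fig.layout.annotations}
assert subplot_titles >= {"F1", "Recall"}
```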
