Skip to content

Commit b10e608

Browse files
committed
refactor eda_mp_trj.py using pymatviz.plot_histogram
ruff unignore and fix PD901
1 parent 604cb04 commit b10e608

25 files changed

+186
-218
lines changed

data/mp/build_phase_diagram.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
1515
from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry
1616
from pymatgen.ext.matproj import MPRester
17+
from pymatviz.io import save_fig
1718
from tqdm import tqdm
1819

1920
from matbench_discovery import MP_DIR, ROOT, today
@@ -29,21 +30,21 @@
2930

3031
# save all ComputedStructureEntries to disk
3132
# mp-15590 appears twice so we drop_duplicates()
32-
df = pd.DataFrame(all_mp_computed_structure_entries, columns=["entry"])
33-
df.index.name = Key.mat_id
34-
df.index = [e.entry_id for e in df.entry]
35-
df.reset_index().to_json(
33+
df_mp_cse = pd.DataFrame(all_mp_computed_structure_entries, columns=["entry"])
34+
df_mp_cse.index.name = Key.mat_id
35+
df_mp_cse.index = [e.entry_id for e in df_mp_cse.entry]
36+
df_mp_cse.reset_index().to_json(
3637
f"{module_dir}/{today}-mp-computed-structure-entries.json.gz",
3738
default_handler=lambda x: x.as_dict(),
3839
)
3940

4041

4142
# %%
4243
data_path = f"{module_dir}/2023-02-07-mp-computed-structure-entries.json.gz"
43-
df = pd.read_json(data_path).set_index(Key.mat_id)
44+
df_mp_cse = pd.read_json(data_path).set_index(Key.mat_id)
4445

4546
# drop the structure, just load ComputedEntry, makes the PPD faster to build and load
46-
mp_computed_entries = [ComputedEntry.from_dict(dct) for dct in tqdm(df.entry)]
47+
mp_computed_entries = [ComputedEntry.from_dict(dct) for dct in tqdm(df_mp_cse.entry)]
4748

4849
print(f"{len(mp_computed_entries)=:,} on {today}")
4950
# len(mp_computed_entries) = 146,323 on 2022-09-16
@@ -118,4 +119,4 @@
118119
xlabel="MP Formation Energy (eV/atom)",
119120
ylabel="Our Formation Energy (eV/atom)",
120121
)
121-
ax.figure.savefig(f"{ROOT}/tmp/{today}-our-vs-mp-formation-energies.webp", dpi=300)
122+
save_fig(ax, f"{ROOT}/tmp/{today}-our-vs-mp-formation-energies.webp", dpi=300)

data/mp/eda_mp_trj.py

+36-71
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,13 @@
1515
import plotly.express as px
1616
from matplotlib.colors import SymLogNorm
1717
from pymatgen.core import Composition, Element
18-
from pymatviz import count_elements, ptable_heatmap, ptable_heatmap_ratio, ptable_hists
18+
from pymatviz import (
19+
count_elements,
20+
plot_histogram,
21+
ptable_heatmap,
22+
ptable_heatmap_ratio,
23+
ptable_hists,
24+
)
1925
from pymatviz.io import save_fig
2026
from pymatviz.utils import si_fmt
2127
from tqdm import tqdm
@@ -321,91 +327,50 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
321327

322328

323329
# %% plot formation energy per atom distribution
330+
# pdf_kwds defined to use the same figure size for all plots
331+
fig = plot_histogram(df_mp_trj[Key.e_form], bins=300)
332+
# fig.update_yaxes(type="log")
333+
fig.layout.xaxis.title = "E<sub>form</sub> (eV/atom)"
324334
count_col = "Number of Structures"
325-
axes_kwds = dict(linewidth=1, ticks="outside")
326-
pdf_kwds = dict(width=500, height=300)
327-
328-
x_col, y_col = "E<sub>form</sub> (eV/atom)", count_col
329-
df_e_form = locals().get("df_e_form")
330-
331-
if df_e_form is None: # only compute once for speed
332-
e_form_hist = np.histogram(df_mp_trj[Key.e_form], bins=300)
333-
df_e_form = pd.DataFrame(e_form_hist, index=[y_col, x_col]).T.round(3)
334-
335-
fig = px.bar(df_e_form, x=x_col, y=count_col, log_y=True)
336-
337-
bin_width = df_e_form[x_col].diff().iloc[-1] * 1.2
338-
fig.update_traces(width=bin_width, marker_line_width=0)
339-
fig.layout.xaxis.update(**axes_kwds)
340-
fig.layout.yaxis.update(**axes_kwds)
341-
fig.layout.margin = dict(l=5, r=5, b=5, t=5)
335+
fig.layout.yaxis.title = count_col
342336
fig.show()
343-
save_fig(fig, f"{PDF_FIGS}/mp-trj-e-form-hist.pdf", **pdf_kwds)
344-
save_fig(fig, f"{SITE_FIGS}/mp-trj-e-form-hist.svelte")
345-
346-
347-
# %% plot forces distribution
348-
# use numpy to pre-compute histogram
349-
x_col, y_col = "|Forces| (eV/Å)", count_col
350-
df_forces = locals().get("df_forces")
351337

352-
if df_forces is None: # only compute once for speed
353-
forces_hist = np.histogram(
354-
df_mp_trj[Key.forces].explode().explode().abs(), bins=300
355-
)
356-
df_forces = pd.DataFrame(forces_hist, index=[y_col, x_col]).T.round(3)
338+
pdf_kwds = dict(width=500, height=300)
339+
# save_fig(fig, f"{PDF_FIGS}/mp-trj-e-form-hist.pdf", **pdf_kwds)
340+
# save_fig(fig, f"{SITE_FIGS}/mp-trj-e-form-hist.svelte")
357341

358-
fig = px.bar(df_forces, x=x_col, y=count_col, log_y=True)
359342

360-
bin_width = df_forces[x_col].diff().iloc[-1] * 1.2
361-
fig.update_traces(width=bin_width, marker_line_width=0)
362-
fig.layout.xaxis.update(**axes_kwds)
363-
fig.layout.yaxis.update(**axes_kwds)
364-
fig.layout.margin = dict(l=5, r=5, b=5, t=5)
343+
# %% plot forces distribution
344+
fig = plot_histogram(df_mp_trj[Key.forces].explode().explode().abs(), bins=300)
345+
fig.layout.xaxis.title = "|Forces| (eV/Å)"
346+
fig.layout.yaxis.title = count_col
347+
fig.update_yaxes(type="log")
365348
fig.show()
366-
save_fig(fig, f"{PDF_FIGS}/mp-trj-forces-hist.pdf", **pdf_kwds)
367-
save_fig(fig, f"{SITE_FIGS}/mp-trj-forces-hist.svelte")
368349

350+
# save_fig(fig, f"{PDF_FIGS}/mp-trj-forces-hist.pdf", **pdf_kwds)
351+
# save_fig(fig, f"{SITE_FIGS}/mp-trj-forces-hist.svelte")
369352

370-
# %% plot hydrostatic stress distribution
371-
x_col, y_col = "1/3 Tr(σ) (eV/ų)", count_col # noqa: RUF001
372-
df_stresses = locals().get("df_stresses")
373-
374-
if df_stresses is None: # only compute once for speed
375-
stresses_hist = np.histogram(df_mp_trj[Key.stress_trace], bins=300)
376-
df_stresses = pd.DataFrame(stresses_hist, index=[y_col, x_col]).T.round(3)
377353

378-
fig = px.bar(df_stresses, x=x_col, y=y_col, log_y=True)
379-
380-
bin_width = (df_stresses[x_col].diff().mean()) * 1.2
381-
fig.update_traces(width=bin_width, marker_line_width=0)
382-
fig.layout.xaxis.update(**axes_kwds)
383-
fig.layout.yaxis.update(**axes_kwds)
384-
fig.layout.margin = dict(l=5, r=5, b=5, t=5)
354+
# %% plot hydrostatic stress distribution
355+
fig = plot_histogram(df_mp_trj[Key.stress_trace], bins=300)
356+
fig.layout.xaxis.title = "1/3 Tr(σ) (eV/ų)" # noqa: RUF001
357+
fig.layout.yaxis.title = count_col
358+
fig.update_yaxes(type="log")
385359
fig.show()
386360

387-
save_fig(fig, f"{PDF_FIGS}/mp-trj-stresses-hist.pdf", **pdf_kwds)
388-
save_fig(fig, f"{SITE_FIGS}/mp-trj-stresses-hist.svelte")
361+
# save_fig(fig, f"{PDF_FIGS}/mp-trj-stresses-hist.pdf", **pdf_kwds)
362+
# save_fig(fig, f"{SITE_FIGS}/mp-trj-stresses-hist.svelte")
389363

390364

391365
# %% plot magmoms distribution
392-
x_col, y_col = "Magmoms (μ<sub>B</sub>)", count_col
393-
df_magmoms = locals().get("df_magmoms")
394-
395-
if df_magmoms is None: # only compute once for speed
396-
magmoms_hist = np.histogram(df_mp_trj[Key.magmoms].dropna().explode(), bins=300)
397-
df_magmoms = pd.DataFrame(magmoms_hist, index=[y_col, x_col]).T.round(3)
398-
399-
fig = px.bar(df_magmoms, x=x_col, y=y_col, log_y=True)
400-
401-
bin_width = df_magmoms[x_col].diff().iloc[-1] * 1.2
402-
fig.update_traces(width=bin_width, marker_line_width=0)
403-
fig.layout.xaxis.update(**axes_kwds)
404-
fig.layout.yaxis.update(**axes_kwds)
405-
fig.layout.margin = dict(l=5, r=5, b=5, t=5)
366+
fig = plot_histogram(df_mp_trj[Key.magmoms].dropna().explode(), bins=300)
367+
fig.layout.xaxis.title = "Magmoms (μB)"
368+
fig.layout.yaxis.title = count_col
369+
fig.update_yaxes(type="log")
406370
fig.show()
407-
save_fig(fig, f"{PDF_FIGS}/mp-trj-magmoms-hist.pdf", **pdf_kwds)
408-
save_fig(fig, f"{SITE_FIGS}/mp-trj-magmoms-hist.svelte")
371+
372+
# save_fig(fig, f"{PDF_FIGS}/mp-trj-magmoms-hist.pdf", **pdf_kwds)
373+
# save_fig(fig, f"{SITE_FIGS}/mp-trj-magmoms-hist.svelte")
409374

410375

411376
# %%

data/mp/get_mp_energies.py

+18-17
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from aviary.wren.utils import get_aflow_label_from_spglib
1212
from mp_api.client import MPRester
1313
from pymatgen.core import Structure
14+
from pymatviz.io import save_fig # noqa: F401
1415
from pymatviz.powerups import annotate_metrics
1516
from tqdm import tqdm
1617

@@ -47,13 +48,13 @@
4748

4849

4950
# %%
50-
df = pd.DataFrame(docs).set_index(Key.mat_id)
51-
df = df.rename(columns={"formula_pretty": Key.formula, "nsites": Key.n_sites})
51+
df_mp = pd.DataFrame(docs).set_index(Key.mat_id)
52+
df_mp = df_mp.rename(columns={"formula_pretty": Key.formula, "nsites": Key.n_sites})
5253

53-
df_spg = pd.json_normalize(df.pop("symmetry"))[["number", "symbol"]]
54-
df["spacegroup_symbol"] = df_spg.symbol.to_numpy()
54+
df_spg = pd.json_normalize(df_mp.pop("symmetry"))[["number", "symbol"]]
55+
df_mp["spacegroup_symbol"] = df_spg.symbol.to_numpy()
5556

56-
df.energy_type.value_counts().plot.pie(backend="plotly", autopct="%1.1f%%")
57+
df_mp.energy_type.value_counts().plot.pie(backend="plotly", autopct="%1.1f%%")
5758
# GGA: 72.2%, GGA+U: 27.8%
5859

5960

@@ -69,39 +70,39 @@
6970
]
7071
# make sure symmetry detection succeeded for all structures
7172
assert df_cse[Key.wyckoff].str.startswith("invalid").sum() == 0
72-
df[Key.wyckoff] = df_cse[Key.wyckoff]
73+
df_mp[Key.wyckoff] = df_cse[Key.wyckoff]
7374

74-
spg_nums = df[Key.wyckoff].str.split("_").str[2].astype(int)
75+
spg_nums = df_mp[Key.wyckoff].str.split("_").str[2].astype(int)
7576
# make sure all our spacegroup numbers match MP's
7677
assert (spg_nums.sort_index() == df_spg["number"].sort_index()).all()
7778

78-
df.to_csv(DATA_FILES.mp_energies)
79+
df_mp.to_csv(DATA_FILES.mp_energies)
7980
# df = pd.read_csv(DATA_FILES.mp_energies, na_filter=False).set_index(Key.mat_id)
8081

8182

8283
# %% reproduce fig. 1b from https://arxiv.org/abs/2001.10591 (as data consistency check)
83-
ax = df.plot.scatter(
84+
ax = df_mp.plot.scatter(
8485
x=Key.form_energy,
8586
y="decomposition_enthalpy",
8687
alpha=0.1,
8788
xlim=[-5, 1],
8889
ylim=[-1, 1],
89-
color=(df.decomposition_enthalpy > STABILITY_THRESHOLD).map(
90+
color=(df_mp.decomposition_enthalpy > STABILITY_THRESHOLD).map(
9091
{True: "red", False: "blue"}
9192
),
92-
title=f"{today} - {len(df):,} MP entries",
93+
title=f"{today} - {len(df_mp):,} MP entries",
9394
)
9495

95-
annotate_metrics(df.formation_energy_per_atom, df.decomposition_enthalpy)
96+
annotate_metrics(df_mp.formation_energy_per_atom, df_mp.decomposition_enthalpy)
9697
# result on 2023-01-10: plots match. no correlation between formation energy and
9798
# decomposition enthalpy. R^2 = -1.571, MAE = 1.604
98-
# ax.figure.savefig(f"{module_dir}/mp-decomp-enth-vs-e-form.webp", dpi=300)
99+
# save_fig(ax, f"{module_dir}/mp-decomp-enth-vs-e-form.webp", dpi=300)
99100

100101

101102
# %% scatter plot energy above convex hull vs decomposition enthalpy
102103
# https://berkeleytheory.slack.com/archives/C16RE1TUN/p1673887564955539
103-
mask_above_line = df.energy_above_hull - df.decomposition_enthalpy.clip(0) > 0.1
104-
ax = df.plot.scatter(
104+
mask_above_line = df_mp.energy_above_hull - df_mp.decomposition_enthalpy.clip(0) > 0.1
105+
ax = df_mp.plot.scatter(
105106
x="decomposition_enthalpy",
106107
y="energy_above_hull",
107108
color=mask_above_line.map({True: "red", False: "blue"}),
@@ -110,7 +111,7 @@
110111
# most points lie on line y=x for x > 0 and y = 0 for x < 0.
111112
n_above_line = sum(mask_above_line)
112113
ax.set(
113-
title=f"{n_above_line:,} / {len(df):,} = {n_above_line / len(df):.1%} "
114+
title=f"{n_above_line:,} / {len(df_mp):,} = {n_above_line / len(df_mp):.1%} "
114115
"MP materials with\nenergy_above_hull - decomposition_enthalpy.clip(0) > 0.1"
115116
)
116-
# ax.figure.savefig(f"{module_dir}/mp-e-above-hull-vs-decomp-enth.webp", dpi=300)
117+
# save_fig(ax, f"{module_dir}/mp-e-above-hull-vs-decomp-enth.webp", dpi=300)

data/wbm/compare_cse_vs_ce_mp_2020_corrections.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
MaterialsProjectCompatibility,
1414
)
1515
from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry
16+
from pymatviz.io import save_fig
1617
from tqdm import tqdm
1718

1819
from matbench_discovery import ROOT, today
@@ -93,7 +94,7 @@
9394

9495
ax.axline((0, 0), slope=1, color="gray", linestyle="dashed", zorder=-1)
9596

96-
ax.figure.savefig(f"{ROOT}/tmp/{today}-ce-vs-cse-corrections-outliers.pdf")
97+
save_fig(ax, f"{ROOT}/tmp/{today}-ce-vs-cse-corrections-outliers.pdf")
9798

9899

99100
# %%
@@ -114,7 +115,7 @@
114115
# insight: all materials for which ComputedEntry and ComputedStructureEntry give
115116
# different formation energies are oxides or sulfides for which MP 2020 compat takes
116117
# into account structural information to make more accurate corrections.
117-
ax.figure.savefig(f"{ROOT}/tmp/{today}-ce-vs-cse-e-form-outliers.pdf")
118+
save_fig(ax, f"{ROOT}/tmp/{today}-ce-vs-cse-e-form-outliers.pdf")
118119

119120

120121
# %% below code resulted in

data/wbm/compile_wbm_test_set.py

+14-15
Original file line numberDiff line numberDiff line change
@@ -102,27 +102,31 @@
102102
continue
103103

104104
print(f"{step=}")
105-
df = pd.read_json(json_path).T
105+
df_wbm_step = pd.read_json(json_path).T
106106

107107
# we hash index only for speed
108108
# could use joblib.hash(df) to hash whole df but it's slow
109-
checksum = pd.util.hash_pandas_object(df.index).sum()
109+
checksum = pd.util.hash_pandas_object(df_wbm_step.index).sum()
110110
expected = wbm_structs_index_checksums[step - 1]
111111
assert checksum == expected, (
112112
f"bad df.index checksum for {step=}, {expected=}, got {checksum=}\n"
113113
f"\n{json_path=}"
114114
)
115115

116116
if step == 3:
117-
df = df.drop(index=[f"step_3_{wbm_id}" for wbm_id in bad_struct_ids])
117+
df_wbm_step = df_wbm_step.drop(
118+
index=[f"step_3_{wbm_id}" for wbm_id in bad_struct_ids]
119+
)
118120
# re-index after dropping bad structures to get same indices as summary file
119121
# where IDs are consecutive, i.e. step_3_70801 is followed by step_3_70802,
120122
# not step_3_70804, etc.
121123
# df.index = [f"step_3_{idx + 1}" for idx in range(len(df))]
122124

123125
step_len = step_lens[step - 1]
124-
assert len(df) == step_len, f"bad len for {step=}: {len(df)} != {step_len}"
125-
dfs_wbm_structs[step] = df
126+
assert (
127+
len(df_wbm_step) == step_len
128+
), f"bad len for {step=}: {len(df_wbm_step)} != {step_len}"
129+
dfs_wbm_structs[step] = df_wbm_step
126130

127131

128132
# NOTE step 5 is missing 2 initial structures, see nan_init_structs_ids below
@@ -212,11 +216,11 @@ def increment_wbm_material_id(wbm_id: str) -> str:
212216
print(f"{json_path=} already loaded.")
213217
continue
214218

215-
df = pd.read_json(json_path)
219+
df_wbm_step = pd.read_json(json_path)
216220

217221
step_len = step_lens[step - 1]
218-
dfs_wbm_cses[step] = df
219-
assert len(df) == step_len, f"{step=}: {len(df)} != {step_len}"
222+
dfs_wbm_cses[step] = df_wbm_step
223+
assert len(df_wbm_step) == step_len, f"{step=}: {len(df_wbm_step)} != {step_len}"
220224

221225

222226
# %%
@@ -589,14 +593,9 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
589593
try:
590594
from aviary.wren.utils import get_aflow_label_from_spglib
591595

592-
# add Aflow-style Wyckoff labels for initial and relaxed structures
593-
for key in (Key.init_wyckoff, Key.wyckoff):
594-
if key not in df_wbm:
595-
df_summary[key] = None
596-
597596
# from initial structures
598597
for idx in tqdm(df_wbm.index):
599-
if not pd.isna(df_summary.loc[idx, Key.init_wyckoff]):
598+
if not pd.isna(df_summary.loc[idx].get(Key.init_wyckoff)):
600599
continue # Aflow label already computed
601600
try:
602601
struct = Structure.from_dict(df_wbm.loc[idx, Key.init_struct])
@@ -606,7 +605,7 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
606605

607606
# from relaxed structures
608607
for idx in tqdm(df_wbm.index):
609-
if not pd.isna(df_summary.loc[idx, Key.wyckoff]):
608+
if not pd.isna(df_summary.loc[idx].get(Key.wyckoff)):
610609
continue
611610

612611
try:

data/wbm/eda_wbm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from matbench_discovery import PDF_FIGS, ROOT, SITE_FIGS, STABILITY_THRESHOLD
2525
from matbench_discovery import plots as plots
2626
from matbench_discovery.data import DATA_FILES, df_wbm
27-
from matbench_discovery.energy import mp_elem_reference_entries
27+
from matbench_discovery.energy import mp_elem_ref_entries
2828
from matbench_discovery.enums import Key, Model
2929
from matbench_discovery.preds import df_each_err
3030

@@ -237,7 +237,7 @@
237237
"Name": entry.composition.elements[0].long_name,
238238
"Material ID": entry.entry_id.replace("-GGA", ""),
239239
}
240-
for key, entry in mp_elem_reference_entries.items()
240+
for key, entry in mp_elem_ref_entries.items()
241241
]
242242
df_ref = pd.DataFrame(mp_ref_data).sort_values(atom_num_col)
243243

0 commit comments

Comments
 (0)