Skip to content

Commit 62a6458

Browse files
committed
eda_mp_trj.py add code for mp-trj-forces-ptable-hists.pdf showing the distribution of forces for each element in the periodic table
per_element_errors.py add ptable-element-wise-each-error-hists-{model}.pdf
1 parent 46366d1 commit 62a6458

File tree

3 files changed

+97
-21
lines changed

3 files changed

+97
-21
lines changed

data/mp/eda_mp_trj.py

+81-19
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
# %%
55
import io
66
import os
7+
from typing import Any
78
from zipfile import ZipFile
89

910
import ase
1011
import ase.io.extxyz
12+
import matplotlib.colors
13+
import matplotlib.pyplot as plt
1114
import numpy as np
1215
import pandas as pd
1316
import plotly.express as px
@@ -37,6 +40,7 @@
3740
e_form_per_atom_col = "ef_per_atom"
3841
magmoms_col = "magmoms"
3942
forces_col = "forces"
43+
elems_col = "symbols"
4044

4145

4246
# %% load MP element counts by occurrence to compute ratio with MPtrj
@@ -46,7 +50,7 @@
4650
df_mp = pd.read_csv(DATA_FILES.mp_energies, na_filter=False).set_index(id_col)
4751

4852

49-
# %% --- load preprocessed MPtrj summary data ---
53+
# %% --- load preprocessed MPtrj summary data if available ---
5054
mp_trj_summary_path = f"{DATA_DIR}/mp/mp-trj-2022-09-summary.json.bz2"
5155
if os.path.isfile(mp_trj_summary_path):
5256
df_mp_trj = pd.read_json(mp_trj_summary_path)
@@ -84,9 +88,9 @@
8488

8589
df_mp_trj = pd.DataFrame(
8690
{
87-
info_to_id(atoms.info): {"formula": str(atoms.symbols)}
91+
info_to_id(atoms.info): atoms.info
8892
| {key: atoms.arrays.get(key) for key in ("forces", "magmoms")}
89-
| atoms.info
93+
| {"formula": str(atoms.symbols), elems_col: atoms.symbols}
9094
for atoms_list in tqdm(mp_trj_atoms.values(), total=len(mp_trj_atoms))
9195
for atoms in atoms_list
9296
}
@@ -106,41 +110,97 @@
106110
df_mp_trj.to_json(mp_trj_summary_path)
107111

108112

113+
# %%
114+
def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
    """Annotate each periodic table tile with the number of values in its histogram.

    Args:
        hist_vals (list[Any]): Values that make up one element's histogram.

    Returns:
        dict[str, Any]: Annotation settings with two keys. "text" is the value
            count formatted by si_fmt. "bbox" is a box style whose face color is
            derived from the count via the module-level cmap and norm, or "none"
            when there are no values.
    """
    n_vals = len(hist_vals)
    # cmap and norm are module-level names resolved at call time, so whichever pair
    # was assigned last before plotting determines the tile color.
    # Comparing the length instead of relying on truthiness also works for
    # array-like inputs whose truth value is ambiguous.
    facecolor = cmap(norm(n_vals)) if n_vals > 0 else "none"
    bbox = dict(facecolor=facecolor, alpha=0.4, pad=2, edgecolor="none")
    return dict(text=si_fmt(n_vals, ".0f"), bbox=bbox)
119+
120+
109121
# %% plot per-element magmom histograms
110122
magmom_hist_path = f"{DATA_DIR}/mp/mp-trj-2022-09-elem-magmoms.json.bz2"
111123

112124
if os.path.isfile(magmom_hist_path):
113125
mp_trj_elem_magmoms = pd.read_json(magmom_hist_path, typ="series")
114126
elif "mp_trj_elem_magmoms" not in locals():
115-
df_mp_trj_magmom = pd.DataFrame(
116-
{
117-
info_to_id(atoms.info): (
118-
dict(zip(atoms.symbols, atoms.arrays["magmoms"], strict=True))
119-
if magmoms_col in atoms.arrays
120-
else None
121-
)
122-
for frame_id in tqdm(mp_trj_atoms)
123-
for atoms in mp_trj_atoms[frame_id]
124-
}
125-
).T.dropna(axis=0, how="all")
127+
# project magmoms onto symbols in dict
128+
df_mp_trj_elem_magmom = pd.DataFrame(
129+
[
130+
dict(zip(elems, magmoms))
131+
for elems, magmoms in df_mp_trj.set_index(elems_col)[magmoms_col]
132+
.dropna()
133+
.items()
134+
]
135+
)
126136

127137
mp_trj_elem_magmoms = {
128-
col: list(df_mp_trj_magmom[col].dropna()) for col in df_mp_trj_magmom
138+
col: list(df_mp_trj_elem_magmom[col].dropna()) for col in df_mp_trj_elem_magmom
129139
}
130140
pd.Series(mp_trj_elem_magmoms).to_json(magmom_hist_path)
131141

142+
cmap = plt.cm.get_cmap("viridis")
143+
norm = matplotlib.colors.LogNorm(vmin=1, vmax=150_000)
144+
132145
ax = ptable_hists(
133146
mp_trj_elem_magmoms,
134147
symbol_pos=(0.2, 0.8),
135148
log=True,
136149
cbar_title="Magmoms ($μ_B$)",
150+
cbar_title_kwds=dict(fontsize=16),
151+
cbar_coords=(0.18, 0.85, 0.42, 0.02),
137152
# annotate each element with its number of magmoms in MPtrj
138-
anno_kwds=dict(text=lambda hist_vals: si_fmt(len(hist_vals), ".0f")),
153+
anno_kwds=tile_count_anno,
139154
)
140155

156+
cbar_ax = ax.figure.add_axes([0.26, 0.78, 0.25, 0.015])
157+
cbar = matplotlib.colorbar.ColorbarBase(
158+
cbar_ax, cmap=cmap, norm=norm, orientation="horizontal"
159+
)
141160
save_fig(ax, f"{PDF_FIGS}/mp-trj-magmoms-ptable-hists.pdf")
142161

143162

163+
# %% plot per-element force histograms
force_hist_path = f"{DATA_DIR}/mp/mp-trj-2022-09-elem-forces.json.bz2"

# reuse the cached per-element forces if the file exists, else rebuild them from
# df_mp_trj (unless an earlier run of this cell already left them in memory)
if os.path.isfile(force_hist_path):
    mp_trj_elem_forces = pd.read_json(force_hist_path, typ="series")
elif "mp_trj_elem_forces" not in locals():
    df_mp_trj_elem_forces = pd.DataFrame(
        [
            # mean of |force| along axis 1, presumably averaging the x/y/z
            # components of each site - TODO confirm forces has shape (n_sites, 3)
            # NOTE(review): if a symbol repeats within elems, dict() keeps only the
            # last value paired with it - confirm that is intended
            dict(zip(elems, np.abs(forces).mean(axis=1)))
            for elems, forces in df_mp_trj.set_index(elems_col)[forces_col].items()
        ]
    )
    # collect the non-NaN values of each column into one list per column
    mp_trj_elem_forces = {
        col: list(df_mp_trj_elem_forces[col].dropna()) for col in df_mp_trj_elem_forces
    }
    mp_trj_elem_forces = pd.Series(mp_trj_elem_forces)
    mp_trj_elem_forces.to_json(force_hist_path)

# cmap and norm are read as module-level names by tile_count_anno.
# plt.cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9;
# plt.get_cmap is the supported equivalent.
cmap = plt.get_cmap("viridis")
norm = matplotlib.colors.LogNorm(vmin=1, vmax=1_000_000)

max_force = 10  # eV/Å
ax = ptable_hists(
    # drop values at or above max_force before plotting; Series.map already
    # returns a new Series so no defensive copy is needed
    mp_trj_elem_forces.map(lambda x: [val for val in x if val < max_force]),
    symbol_pos=(0.3, 0.8),
    log=True,
    cbar_title="1/3 Σ|Forces| (eV/Å)",
    cbar_title_kwds=dict(fontsize=16),
    cbar_coords=(0.18, 0.85, 0.42, 0.02),
    x_range=(0, max_force),
    anno_kwds=tile_count_anno,
)

# second, horizontal colorbar built from the same cmap/norm that color the
# per-tile count annotations
cbar_ax = ax.figure.add_axes([0.26, 0.78, 0.25, 0.015])
cbar = matplotlib.colorbar.ColorbarBase(
    cbar_ax, cmap=cmap, norm=norm, orientation="horizontal"
)

save_fig(ax, f"{PDF_FIGS}/mp-trj-forces-ptable-hists.pdf")
202+
203+
144204
# %%
145205
elem_counts: dict[str, dict[str, int]] = {}
146206
for count_mode in ("composition", "occurrence"):
@@ -153,9 +213,11 @@
153213

154214

155215
# %%
216+
count_mode = "composition"
156217
if "trj_elem_counts" not in locals():
157218
trj_elem_counts = pd.read_json(
158-
f"{data_page}/mp-trj-element-counts-by-occurrence.json", typ="series"
219+
f"{data_page}/mp-trj-element-counts-by-{count_mode}.json",
220+
typ="series",
159221
)
160222

161223
excl_elems = "He Ne Ar Kr Xe".split() if (excl_noble := False) else ()
@@ -167,12 +229,12 @@
167229
zero_color="#efefef",
168230
log=(log := True),
169231
exclude_elements=excl_elems, # drop noble gases
170-
cbar_range=None if excl_noble else (2000, None),
232+
cbar_range=None if excl_noble else (10_000, None),
171233
label_font_size=17,
172234
value_font_size=14,
173235
)
174236

175-
img_name = f"mp-trj-element-counts-by-occurrence{'-log' if log else ''}"
237+
img_name = f"mp-trj-element-counts-by-{count_mode}{'-log' if log else ''}"
176238
if excl_noble:
177239
img_name += "-excl-noble"
178240
save_fig(ax_ptable, f"{PDF_FIGS}/{img_name}.pdf")

matbench_discovery/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@
111111
global_layout = dict(
112112
# colorway=px.colors.qualitative.Pastel,
113113
# colorway=colorway,
114-
margin=dict(l=30, r=20, t=60, b=20),
114+
# margin=dict(l=30, r=20, t=60, b=20),
115115
paper_bgcolor="rgba(0,0,0,0)",
116116
# plot_bgcolor="rgba(0,0,0,0)",
117117
font_size=13,

scripts/model_figs/per_element_errors.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pandas as pd
99
import plotly.express as px
1010
from pymatgen.core import Composition, Element
11-
from pymatviz import ptable_heatmap_plotly
11+
from pymatviz import ptable_heatmap_plotly, ptable_hists
1212
from pymatviz.io import save_fig
1313
from pymatviz.utils import bin_df_cols, df_ptable
1414
from tqdm import tqdm
@@ -256,3 +256,17 @@
256256

257257
fig.show()
258258
save_fig(fig, f"{SITE_FIGS}/each-error-vs-least-prevalent-element-in-struct.svelte")
259+
260+
261+
# %% periodic table of per-element histograms of one model's errors
model = "MACE"

# reshape the model's error column to (n, 1) so it broadcasts across every
# column of df_frac_comp (presumably one column per element - verify upstream)
each_err_col_vec = df_each_err[model].to_numpy()[:, None]
elem_wise_each_err = df_frac_comp * each_err_col_vec

ax = ptable_hists(
    elem_wise_each_err,
    log=True,
    cbar_title=f"{model} convex hull distance errors (eV/atom)",
    x_range=(-0.5, 0.5),
    symbol_pos=(0.1, 0.8),
)

# file name uses the lower-cased model name with spaces turned into dashes
model_slug = model.lower().replace(" ", "-")
img_name = f"ptable-each-error-hists-{model_slug}"
save_fig(ax, f"{PDF_FIGS}/{img_name}.pdf")

0 commit comments

Comments
 (0)