|
4 | 4 | # %%
|
5 | 5 | import io
|
6 | 6 | import os
|
| 7 | +from typing import Any |
7 | 8 | from zipfile import ZipFile
|
8 | 9 |
|
9 | 10 | import ase
|
10 | 11 | import ase.io.extxyz
|
| 12 | +import matplotlib.colors |
| 13 | +import matplotlib.pyplot as plt |
11 | 14 | import numpy as np
|
12 | 15 | import pandas as pd
|
13 | 16 | import plotly.express as px
|
|
37 | 40 | e_form_per_atom_col = "ef_per_atom"
|
38 | 41 | magmoms_col = "magmoms"
|
39 | 42 | forces_col = "forces"
|
| 43 | +elems_col = "symbols" |
40 | 44 |
|
41 | 45 |
|
42 | 46 | # %% load MP element counts by occurrence to compute ratio with MPtrj
|
|
46 | 50 | df_mp = pd.read_csv(DATA_FILES.mp_energies, na_filter=False).set_index(id_col)
|
47 | 51 |
|
48 | 52 |
|
49 |
| -# %% --- load preprocessed MPtrj summary data --- |
| 53 | +# %% --- load preprocessed MPtrj summary data if available --- |
50 | 54 | mp_trj_summary_path = f"{DATA_DIR}/mp/mp-trj-2022-09-summary.json.bz2"
|
51 | 55 | if os.path.isfile(mp_trj_summary_path):
|
52 | 56 | df_mp_trj = pd.read_json(mp_trj_summary_path)
|
|
84 | 88 |
|
85 | 89 | df_mp_trj = pd.DataFrame(
|
86 | 90 | {
|
87 |
| - info_to_id(atoms.info): {"formula": str(atoms.symbols)} |
| 91 | + info_to_id(atoms.info): atoms.info |
88 | 92 | | {key: atoms.arrays.get(key) for key in ("forces", "magmoms")}
|
89 |
| - | atoms.info |
| 93 | + | {"formula": str(atoms.symbols), elems_col: atoms.symbols} |
90 | 94 | for atoms_list in tqdm(mp_trj_atoms.values(), total=len(mp_trj_atoms))
|
91 | 95 | for atoms in atoms_list
|
92 | 96 | }
|
|
106 | 110 | df_mp_trj.to_json(mp_trj_summary_path)
|
107 | 111 |
|
108 | 112 |
|
| 113 | +# %% |
| 114 | +def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]: |
| 115 | + """Annotate each periodic table tile with the number of values in its histogram.""" |
| 116 | + facecolor = cmap(norm(np.sum(len(hist_vals)))) if hist_vals else "none" |
| 117 | + bbox = dict(facecolor=facecolor, alpha=0.4, pad=2, edgecolor="none") |
| 118 | + return dict(text=si_fmt(len(hist_vals), ".0f"), bbox=bbox) |
| 119 | + |
| 120 | + |
109 | 121 | # %% plot per-element magmom histograms
|
110 | 122 | magmom_hist_path = f"{DATA_DIR}/mp/mp-trj-2022-09-elem-magmoms.json.bz2"
|
111 | 123 |
|
112 | 124 | if os.path.isfile(magmom_hist_path):
|
113 | 125 | mp_trj_elem_magmoms = pd.read_json(magmom_hist_path, typ="series")
|
114 | 126 | elif "mp_trj_elem_magmoms" not in locals():
|
115 |
| - df_mp_trj_magmom = pd.DataFrame( |
116 |
| - { |
117 |
| - info_to_id(atoms.info): ( |
118 |
| - dict(zip(atoms.symbols, atoms.arrays["magmoms"], strict=True)) |
119 |
| - if magmoms_col in atoms.arrays |
120 |
| - else None |
121 |
| - ) |
122 |
| - for frame_id in tqdm(mp_trj_atoms) |
123 |
| - for atoms in mp_trj_atoms[frame_id] |
124 |
| - } |
125 |
| - ).T.dropna(axis=0, how="all") |
| 127 | + # project magmoms onto symbols in dict |
| 128 | + df_mp_trj_elem_magmom = pd.DataFrame( |
| 129 | + [ |
| 130 | + dict(zip(elems, magmoms)) |
| 131 | + for elems, magmoms in df_mp_trj.set_index(elems_col)[magmoms_col] |
| 132 | + .dropna() |
| 133 | + .items() |
| 134 | + ] |
| 135 | + ) |
126 | 136 |
|
127 | 137 | mp_trj_elem_magmoms = {
|
128 |
| - col: list(df_mp_trj_magmom[col].dropna()) for col in df_mp_trj_magmom |
| 138 | + col: list(df_mp_trj_elem_magmom[col].dropna()) for col in df_mp_trj_elem_magmom |
129 | 139 | }
|
130 | 140 | pd.Series(mp_trj_elem_magmoms).to_json(magmom_hist_path)
|
131 | 141 |
|
| 142 | +cmap = plt.cm.get_cmap("viridis") |
| 143 | +norm = matplotlib.colors.LogNorm(vmin=1, vmax=150_000) |
| 144 | + |
132 | 145 | ax = ptable_hists(
|
133 | 146 | mp_trj_elem_magmoms,
|
134 | 147 | symbol_pos=(0.2, 0.8),
|
135 | 148 | log=True,
|
136 | 149 | cbar_title="Magmoms ($μ_B$)",
|
| 150 | + cbar_title_kwds=dict(fontsize=16), |
| 151 | + cbar_coords=(0.18, 0.85, 0.42, 0.02), |
137 | 152 | # annotate each element with its number of magmoms in MPtrj
|
138 |
| - anno_kwds=dict(text=lambda hist_vals: si_fmt(len(hist_vals), ".0f")), |
| 153 | + anno_kwds=tile_count_anno, |
139 | 154 | )
|
140 | 155 |
|
| 156 | +cbar_ax = ax.figure.add_axes([0.26, 0.78, 0.25, 0.015]) |
| 157 | +cbar = matplotlib.colorbar.ColorbarBase( |
| 158 | + cbar_ax, cmap=cmap, norm=norm, orientation="horizontal" |
| 159 | +) |
141 | 160 | save_fig(ax, f"{PDF_FIGS}/mp-trj-magmoms-ptable-hists.pdf")
|
142 | 161 |
|
143 | 162 |
|
| 163 | +# %% plot per-element force histograms |
| 164 | +force_hist_path = f"{DATA_DIR}/mp/mp-trj-2022-09-elem-forces.json.bz2" |
| 165 | + |
| 166 | +if os.path.isfile(force_hist_path): |
| 167 | + mp_trj_elem_forces = pd.read_json(force_hist_path, typ="series") |
| 168 | +elif "mp_trj_elem_forces" not in locals(): |
| 169 | + df_mp_trj_elem_forces = pd.DataFrame( |
| 170 | + [ |
| 171 | + dict(zip(elems, np.abs(forces).mean(axis=1))) |
| 172 | + for elems, forces in df_mp_trj.set_index(elems_col)[forces_col].items() |
| 173 | + ] |
| 174 | + ) |
| 175 | + mp_trj_elem_forces = { |
| 176 | + col: list(df_mp_trj_elem_forces[col].dropna()) for col in df_mp_trj_elem_forces |
| 177 | + } |
| 178 | + mp_trj_elem_forces = pd.Series(mp_trj_elem_forces) |
| 179 | + mp_trj_elem_forces.to_json(force_hist_path) |
| 180 | + |
| 181 | +cmap = plt.cm.get_cmap("viridis") |
| 182 | +norm = matplotlib.colors.LogNorm(vmin=1, vmax=1_000_000) |
| 183 | + |
| 184 | +max_force = 10 # eV/Å |
| 185 | +ax = ptable_hists( |
| 186 | + mp_trj_elem_forces.copy().map(lambda x: [val for val in x if val < max_force]), |
| 187 | + symbol_pos=(0.3, 0.8), |
| 188 | + log=True, |
| 189 | + cbar_title="1/3 Σ|Forces| (eV/Å)", |
| 190 | + cbar_title_kwds=dict(fontsize=16), |
| 191 | + cbar_coords=(0.18, 0.85, 0.42, 0.02), |
| 192 | + x_range=(0, max_force), |
| 193 | + anno_kwds=tile_count_anno, |
| 194 | +) |
| 195 | + |
| 196 | +cbar_ax = ax.figure.add_axes([0.26, 0.78, 0.25, 0.015]) |
| 197 | +cbar = matplotlib.colorbar.ColorbarBase( |
| 198 | + cbar_ax, cmap=cmap, norm=norm, orientation="horizontal" |
| 199 | +) |
| 200 | + |
| 201 | +save_fig(ax, f"{PDF_FIGS}/mp-trj-forces-ptable-hists.pdf") |
| 202 | + |
| 203 | + |
144 | 204 | # %%
|
145 | 205 | elem_counts: dict[str, dict[str, int]] = {}
|
146 | 206 | for count_mode in ("composition", "occurrence"):
|
|
153 | 213 |
|
154 | 214 |
|
155 | 215 | # %%
|
| 216 | +count_mode = "composition" |
156 | 217 | if "trj_elem_counts" not in locals():
|
157 | 218 | trj_elem_counts = pd.read_json(
|
158 |
| - f"{data_page}/mp-trj-element-counts-by-occurrence.json", typ="series" |
| 219 | + f"{data_page}/mp-trj-element-counts-by-{count_mode}.json", |
| 220 | + typ="series", |
159 | 221 | )
|
160 | 222 |
|
161 | 223 | excl_elems = "He Ne Ar Kr Xe".split() if (excl_noble := False) else ()
|
|
167 | 229 | zero_color="#efefef",
|
168 | 230 | log=(log := True),
|
169 | 231 | exclude_elements=excl_elems, # drop noble gases
|
170 |
| - cbar_range=None if excl_noble else (2000, None), |
| 232 | + cbar_range=None if excl_noble else (10_000, None), |
171 | 233 | label_font_size=17,
|
172 | 234 | value_font_size=14,
|
173 | 235 | )
|
174 | 236 |
|
175 |
| -img_name = f"mp-trj-element-counts-by-occurrence{'-log' if log else ''}" |
| 237 | +img_name = f"mp-trj-element-counts-by-{count_mode}{'-log' if log else ''}" |
176 | 238 | if excl_noble:
|
177 | 239 | img_name += "-excl-noble"
|
178 | 240 | save_fig(ax_ptable, f"{PDF_FIGS}/{img_name}.pdf")
|
|
0 commit comments