Skip to content

Commit db20338

Browse files
committed
access pymatviz.io.save_fig from namespace
1 parent 37ed2c7 commit db20338

27 files changed

+119
-147
lines changed

data/mp/build_phase_diagram.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010

1111
import pandas as pd
1212
import pymatviz
13+
import pymatviz as pmv
1314
from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram
1415
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
1516
from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry
1617
from pymatgen.ext.matproj import MPRester
1718
from pymatviz.enums import Key
18-
from pymatviz.io import save_fig
1919
from tqdm import tqdm
2020

2121
from matbench_discovery import MP_DIR, ROOT, today
@@ -121,4 +121,4 @@
121121
xlabel="MP Formation Energy (eV/atom)",
122122
ylabel="Our Formation Energy (eV/atom)",
123123
)
124-
save_fig(ax, f"{ROOT}/tmp/{today}-our-vs-mp-formation-energies.webp", dpi=300)
124+
pmv.save_fig(ax, f"{ROOT}/tmp/{today}-our-vs-mp-formation-energies.webp", dpi=300)

data/mp/eda_mp_trj.py

+19-21
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
from matplotlib.colors import SymLogNorm
1818
from pymatgen.core import Composition, Element
1919
from pymatviz.enums import Key
20-
from pymatviz.io import save_fig
21-
from pymatviz.utils import si_fmt
2220
from tqdm import tqdm
2321

2422
from matbench_discovery import MP_DIR, PDF_FIGS, ROOT, SITE_FIGS
@@ -108,7 +106,7 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
108106
"""Annotate each periodic table tile with the number of values in its histogram."""
109107
face_color = cmap(norm(np.sum(len(hist_vals)))) if hist_vals else "none"
110108
bbox = dict(facecolor=face_color, alpha=0.4, pad=2, edgecolor="none")
111-
return dict(text=si_fmt(len(hist_vals), ".0f"), bbox=bbox)
109+
return dict(text=pmv.si_fmt(len(hist_vals), ".0f"), bbox=bbox)
112110

113111

114112
# %% plot per-element magmom histograms
@@ -152,7 +150,7 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
152150
cbar = matplotlib.colorbar.ColorbarBase(
153151
cbar_ax, cmap=cmap, norm=norm, orientation="horizontal"
154152
)
155-
save_fig(fig_ptable_magmoms, f"{PDF_FIGS}/mp-trj-magmoms-ptable-hists.pdf")
153+
pmv.save_fig(fig_ptable_magmoms, f"{PDF_FIGS}/mp-trj-magmoms-ptable-hists.pdf")
156154

157155

158156
# %% plot per-element force histograms
@@ -195,7 +193,7 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
195193
cbar_ax, cmap=cmap, norm=norm, orientation="horizontal"
196194
)
197195

198-
save_fig(fig_ptable_forces, f"{PDF_FIGS}/mp-trj-forces-ptable-hists.pdf")
196+
pmv.save_fig(fig_ptable_forces, f"{PDF_FIGS}/mp-trj-forces-ptable-hists.pdf")
199197

200198

201199
# %% plot histogram of number of sites per element
@@ -238,7 +236,7 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
238236
cbar_title_kwds=dict(fontsize=16),
239237
cbar_coords=(0.18, 0.85, 0.42, 0.02),
240238
anno_kwds=lambda hist_vals: dict(
241-
text=si_fmt(len(hist_vals), ".0f"),
239+
text=pmv.si_fmt(len(hist_vals), ".0f"),
242240
xy=(0.8, 0.6),
243241
bbox=dict(pad=2, edgecolor="none", facecolor="none"),
244242
),
@@ -262,7 +260,7 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
262260
cbar.set_label("Number of atoms in MPtrj structures", fontsize=16)
263261
cbar.ax.xaxis.set_label_position("top")
264262

265-
save_fig(fig_ptable_sites, f"{PDF_FIGS}/mp-trj-n-sites-ptable-hists.pdf")
263+
pmv.save_fig(fig_ptable_sites, f"{PDF_FIGS}/mp-trj-n-sites-ptable-hists.pdf")
266264

267265

268266
# %%
@@ -301,7 +299,7 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
301299
img_name += "-symlog" if isinstance(log, SymLogNorm) else "-log"
302300
if excl_noble:
303301
img_name += "-excl-noble"
304-
save_fig(ax_ptable, f"{PDF_FIGS}/{img_name}.pdf")
302+
pmv.save_fig(ax_ptable, f"{PDF_FIGS}/{img_name}.pdf")
305303

306304

307305
# %%
@@ -319,7 +317,7 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
319317
img_name = "mp-trj-mp-ratio-element-counts-by-occurrence"
320318
if normalized:
321319
img_name += "-normalized"
322-
save_fig(ax_ptable, f"{PDF_FIGS}/{img_name}.pdf")
320+
pmv.save_fig(ax_ptable, f"{PDF_FIGS}/{img_name}.pdf")
323321

324322

325323
# %% plot formation energy per atom distribution
@@ -332,8 +330,8 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
332330
fig.show()
333331

334332
pdf_kwds = dict(width=500, height=300)
335-
# save_fig(fig, f"{PDF_FIGS}/mp-trj-e-form-hist.pdf", **pdf_kwds)
336-
# save_fig(fig, f"{SITE_FIGS}/mp-trj-e-form-hist.svelte")
333+
# pmv.save_fig(fig, f"{PDF_FIGS}/mp-trj-e-form-hist.pdf", **pdf_kwds)
334+
# pmv.save_fig(fig, f"{SITE_FIGS}/mp-trj-e-form-hist.svelte")
337335

338336

339337
# %% plot forces distribution
@@ -343,8 +341,8 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
343341
fig.update_yaxes(type="log")
344342
fig.show()
345343

346-
# save_fig(fig, f"{PDF_FIGS}/mp-trj-forces-hist.pdf", **pdf_kwds)
347-
# save_fig(fig, f"{SITE_FIGS}/mp-trj-forces-hist.svelte")
344+
# pmv.save_fig(fig, f"{PDF_FIGS}/mp-trj-forces-hist.pdf", **pdf_kwds)
345+
# pmv.save_fig(fig, f"{SITE_FIGS}/mp-trj-forces-hist.svelte")
348346

349347

350348
# %% plot hydrostatic stress distribution
@@ -354,8 +352,8 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
354352
fig.update_yaxes(type="log")
355353
fig.show()
356354

357-
# save_fig(fig, f"{PDF_FIGS}/mp-trj-stresses-hist.pdf", **pdf_kwds)
358-
# save_fig(fig, f"{SITE_FIGS}/mp-trj-stresses-hist.svelte")
355+
# pmv.save_fig(fig, f"{PDF_FIGS}/mp-trj-stresses-hist.pdf", **pdf_kwds)
356+
# pmv.save_fig(fig, f"{SITE_FIGS}/mp-trj-stresses-hist.svelte")
359357

360358

361359
# %% plot magmoms distribution
@@ -365,8 +363,8 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
365363
fig.update_yaxes(type="log")
366364
fig.show()
367365

368-
# save_fig(fig, f"{PDF_FIGS}/mp-trj-magmoms-hist.pdf", **pdf_kwds)
369-
# save_fig(fig, f"{SITE_FIGS}/mp-trj-magmoms-hist.svelte")
366+
# pmv.save_fig(fig, f"{PDF_FIGS}/mp-trj-magmoms-hist.pdf", **pdf_kwds)
367+
# pmv.save_fig(fig, f"{SITE_FIGS}/mp-trj-magmoms-hist.svelte")
370368

371369

372370
# %%
@@ -393,8 +391,8 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
393391

394392
fig.show()
395393
img_name = "mp-vs-mp-trj-vs-wbm-arity-hist"
396-
save_fig(fig, f"{SITE_FIGS}/{img_name}.svelte")
397-
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=450, height=280)
394+
pmv.save_fig(fig, f"{SITE_FIGS}/{img_name}.svelte")
395+
pmv.save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=450, height=280)
398396

399397

400398
# %% calc n_sites from per-site atomic numbers
@@ -472,5 +470,5 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
472470
img_name = "mp-trj-n-sites-hist"
473471
if log_y:
474472
img_name += "-log"
475-
save_fig(fig, f"{SITE_FIGS}/{img_name}.svelte")
476-
# save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=450, height=300)
473+
pmv.save_fig(fig, f"{SITE_FIGS}/{img_name}.svelte")
474+
# pmv.save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=450, height=300)

data/mp/get_mp_energies.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,11 @@
88
import os
99

1010
import pandas as pd
11+
import pymatviz as pmv
1112
from aviary.wren.utils import get_protostructure_label_from_spglib
1213
from mp_api.client import MPRester
1314
from pymatgen.core import Structure
1415
from pymatviz.enums import Key
15-
from pymatviz.io import save_fig # noqa: F401
16-
from pymatviz.powerups import annotate_metrics
17-
from pymatviz.utils import PLOTLY
1816
from tqdm import tqdm
1917

2018
from matbench_discovery import STABILITY_THRESHOLD, today
@@ -55,7 +53,7 @@
5553
df_spg = pd.json_normalize(df_mp.pop("symmetry"))[["number", "symbol"]]
5654
df_mp["spacegroup_symbol"] = df_spg.symbol.to_numpy()
5755

58-
df_mp.energy_type.value_counts().plot.pie(backend=PLOTLY, autopct="%1.1f%%")
56+
df_mp.energy_type.value_counts().plot.pie(backend=pmv.utils.PLOTLY, autopct="%1.1f%%")
5957
# GGA: 72.2%, GGA+U: 27.8%
6058

6159

@@ -96,10 +94,12 @@
9694
title=f"{today} - {len(df_mp):,} MP entries",
9795
)
9896

99-
annotate_metrics(df_mp.formation_energy_per_atom, df_mp.decomposition_enthalpy)
97+
pmv.powerups.annotate_metrics(
98+
df_mp.formation_energy_per_atom, df_mp.decomposition_enthalpy
99+
)
100100
# result on 2023-01-10: plots match. no correlation between formation energy and
101101
# decomposition enthalpy. R^2 = -1.571, MAE = 1.604
102-
# save_fig(ax, f"{module_dir}/mp-decomp-enth-vs-e-form.webp", dpi=300)
102+
# pmv.save_fig(ax, f"{module_dir}/mp-decomp-enth-vs-e-form.webp", dpi=300)
103103

104104

105105
# %% scatter plot energy above convex hull vs decomposition enthalpy
@@ -117,4 +117,4 @@
117117
title=f"{n_above_line:,} / {len(df_mp):,} = {n_above_line / len(df_mp):.1%} "
118118
"MP materials with\nenergy_above_hull - decomposition_enthalpy.clip(0) > 0.1"
119119
)
120-
# save_fig(ax, f"{module_dir}/mp-e-above-hull-vs-decomp-enth.webp", dpi=300)
120+
# pmv.save_fig(ax, f"{module_dir}/mp-e-above-hull-vs-decomp-enth.webp", dpi=300)

data/wbm/compare_cse_vs_ce_mp_2020_corrections.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
import json
99

1010
import pandas as pd
11+
import pymatviz as pmv
1112
from pymatgen.entries.compatibility import (
1213
MaterialsProject2020Compatibility,
1314
MaterialsProjectCompatibility,
1415
)
1516
from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry
1617
from pymatviz.enums import Key
17-
from pymatviz.io import save_fig
1818
from tqdm import tqdm
1919

2020
from matbench_discovery import ROOT, today
@@ -96,7 +96,7 @@
9696

9797
ax.axline((0, 0), slope=1, color="gray", linestyle="dashed", zorder=-1)
9898

99-
save_fig(ax, f"{ROOT}/tmp/{today}-ce-vs-cse-corrections-outliers.pdf")
99+
pmv.save_fig(ax, f"{ROOT}/tmp/{today}-ce-vs-cse-corrections-outliers.pdf")
100100

101101

102102
# %%
@@ -117,7 +117,7 @@
117117
# insight: all materials for which ComputedEntry and ComputedStructureEntry give
118118
# different formation energies are oxides or sulfides for which MP 2020 compat takes
119119
# into account structural information to make more accurate corrections.
120-
save_fig(ax, f"{ROOT}/tmp/{today}-ce-vs-cse-e-form-outliers.pdf")
120+
pmv.save_fig(ax, f"{ROOT}/tmp/{today}-ce-vs-cse-e-form-outliers.pdf")
121121

122122

123123
# %% below code resulted in

data/wbm/compile_wbm_test_set.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
)
2323
from pymatgen.entries.computed_entries import ComputedStructureEntry
2424
from pymatviz.enums import Key
25-
from pymatviz.io import save_fig
2625
from tqdm import tqdm
2726

2827
from matbench_discovery import PDF_FIGS, SITE_FIGS, WBM_DIR, today
@@ -478,13 +477,13 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
478477

479478
# %%
480479
img_name = "hist-wbm-e-form-per-atom"
481-
save_fig(fig, f"{SITE_FIGS}/{img_name}.svelte")
480+
pmv.save_fig(fig, f"{SITE_FIGS}/{img_name}.svelte")
482481
# recommended to upload SVG to vecta.io/nano for compression
483-
# save_fig(fig, f"{img_name}.svg", width=800, height=300)
482+
# pmv.save_fig(fig, f"{img_name}.svg", width=800, height=300)
484483

485484
# make full data range visible in PDF
486485
# fig.layout.xaxis.range = [-12, 82]
487-
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf")
486+
pmv.save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf")
488487

489488

490489
# %%

data/wbm/eda_wbm.py

+15-16
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from matplotlib.colors import SymLogNorm
1313
from pymatgen.core import Composition, Structure
1414
from pymatviz.enums import Key
15-
from pymatviz.io import save_fig
1615
from pymatviz.utils import PLOTLY, si_fmt, si_fmt_int
1716

1817
from matbench_discovery import PDF_FIGS, ROOT, SITE_FIGS, STABILITY_THRESHOLD
@@ -91,7 +90,7 @@
9190
)
9291
if log:
9392
filename += "-symlog" if isinstance(log, SymLogNorm) else "-log"
94-
save_fig(ax_mp_cnt, f"{PDF_FIGS}/{filename}.pdf")
93+
pmv.save_fig(ax_mp_cnt, f"{PDF_FIGS}/{filename}.pdf")
9594

9695

9796
# %% ratio of WBM to MP counts
@@ -105,7 +104,7 @@
105104
img_name = "wbm-mp-ratio-element-counts-by-occurrence"
106105
if normalized:
107106
img_name += "-normalized"
108-
save_fig(ax_ptable, f"{PDF_FIGS}/{img_name}.pdf")
107+
pmv.save_fig(ax_ptable, f"{PDF_FIGS}/{img_name}.pdf")
109108

110109

111110
# %% export element counts by WBM step to JSON
@@ -139,8 +138,8 @@
139138
fig.layout.margin = dict(l=0, r=0, b=0, t=0)
140139
fig.show()
141140
svg_path = f"{module_dir}/figs/wbm-elements.svg"
142-
# save_fig(fig, svg_path, width=1000, height=500)
143-
save_fig(fig, f"{PDF_FIGS}/{dataset}-element-{count_mode}-counts.pdf")
141+
# pmv.save_fig(fig, svg_path, width=1000, height=500)
142+
pmv.save_fig(fig, f"{PDF_FIGS}/{dataset}-element-{count_mode}-counts.pdf")
144143

145144

146145
# %% histogram of energy distance to MP convex hull for WBM
@@ -216,9 +215,9 @@
216215
MbdKey.e_form_raw: "e-form-uncorrected",
217216
}[e_col]
218217
img_name = f"hist-wbm-{suffix}"
219-
save_fig(fig, f"{SITE_FIGS}/{img_name}.svelte")
220-
# save_fig(fig, f"./figs/{img_name}.svg", width=800, height=500)
221-
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=600, height=300)
218+
pmv.save_fig(fig, f"{SITE_FIGS}/{img_name}.svelte")
219+
# pmv.save_fig(fig, f"./figs/{img_name}.svg", width=800, height=500)
220+
pmv.save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=600, height=300)
222221

223222

224223
# %%
@@ -252,8 +251,8 @@
252251

253252
fig.show()
254253

255-
save_fig(fig, f"{SITE_FIGS}/mp-elemental-ref-energies.svelte")
256-
save_fig(fig, f"{PDF_FIGS}/mp-elemental-ref-energies.pdf")
254+
pmv.save_fig(fig, f"{SITE_FIGS}/mp-elemental-ref-energies.svelte")
255+
pmv.save_fig(fig, f"{PDF_FIGS}/mp-elemental-ref-energies.pdf")
257256

258257

259258
# %% plot 2d and 3d t-SNE projections of one-hot encoded element vectors summed by
@@ -326,8 +325,8 @@
326325
fig.layout.title.update(text="WBM Spacegroup Sunburst", x=0.5, font_size=14)
327326
fig.layout.margin = dict(l=0, r=0, t=30, b=0)
328327
fig.show()
329-
save_fig(fig, f"{SITE_FIGS}/spacegroup-sunburst-wbm.svelte")
330-
save_fig(fig, f"{PDF_FIGS}/spacegroup-sunburst-wbm.pdf")
328+
pmv.save_fig(fig, f"{SITE_FIGS}/spacegroup-sunburst-wbm.svelte")
329+
pmv.save_fig(fig, f"{PDF_FIGS}/spacegroup-sunburst-wbm.pdf")
331330

332331

333332
# %%
@@ -337,8 +336,8 @@
337336
fig.layout.title.update(text="MP Spacegroup Sunburst", x=0.5, font_size=14)
338337
fig.layout.margin = dict(l=0, r=0, t=30, b=0)
339338
fig.show()
340-
save_fig(fig, f"{SITE_FIGS}/spacegroup-sunburst-mp.svelte")
341-
save_fig(fig, f"{PDF_FIGS}/spacegroup-sunburst-mp.pdf")
339+
pmv.save_fig(fig, f"{SITE_FIGS}/spacegroup-sunburst-mp.svelte")
340+
pmv.save_fig(fig, f"{PDF_FIGS}/spacegroup-sunburst-mp.pdf")
342341
# would be good to have consistent order of crystal systems between sunbursts but not
343342
# controllable yet
344343
# https://github.com/plotly/plotly.py/issues/4115
@@ -366,8 +365,8 @@
366365

367366
fig.show()
368367
img_name = "mp-vs-wbm-arity-hist"
369-
save_fig(fig, f"{SITE_FIGS}/{img_name}.svelte")
370-
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=450, height=280)
368+
pmv.save_fig(fig, f"{SITE_FIGS}/{img_name}.svelte")
369+
pmv.save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=450, height=280)
371370

372371

373372
# %% find large structures that changed symmetry during relaxation

matbench_discovery/enums.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from enum import StrEnum, unique
44
from typing import Self
55

6-
from pymatviz.utils import html_tag
6+
import pymatviz as pmv
77

88

99
class LabelEnum(StrEnum):
@@ -136,7 +136,7 @@ class TestSubset(LabelEnum):
136136
full = "full", "Full Test Set"
137137

138138

139-
eV_per_atom = html_tag( # noqa: N816
139+
eV_per_atom = pmv.html_tag( # noqa: N816
140140
"(eV/atom)", tag="span", style="font-size: 0.8em; font-weight: lighter;"
141141
)
142142

models/chgnet/analyze_chgnet.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
import pymatviz as pmv
99
from pymatgen.core import Structure
1010
from pymatviz.enums import Key
11-
from pymatviz.io import save_fig
12-
from pymatviz.utils import PLOTLY
1311

1412
from matbench_discovery import PDF_FIGS
1513
from matbench_discovery import plots as plots
@@ -54,7 +52,7 @@
5452
y=e_form_2000,
5553
hover_name=Key.mat_id,
5654
hover_data=[Key.formula],
57-
backend=PLOTLY,
55+
backend=pmv.utils.PLOTLY,
5856
title=f"{len(df_diff)} structures have > {min_e_diff} eV/atom energy diff after "
5957
"longer relaxation",
6058
)
@@ -88,7 +86,7 @@
8886
formula = struct.composition.reduced_formula
8987
ax.set_title(f"{idx}. {formula} (spg={spg_num})\n{row.Index}", fontweight="bold")
9088

91-
save_fig(fig, f"{PDF_FIGS}/chgnet-bad-relax-structures.pdf")
89+
pmv.save_fig(fig, f"{PDF_FIGS}/chgnet-bad-relax-structures.pdf")
9290

9391

9492
# %% ensure all CHGNet static predictions (direct energy without any structure

0 commit comments

Comments
 (0)