Skip to content

Add flags to CHGNet and MACE test scripts to run in static mode (no relaxation) #70

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Dec 11, 2023
4 changes: 2 additions & 2 deletions .github/workflows/slow-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ jobs:

steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: 3.9

Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test-scripts.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ jobs:
- scripts/model_figs/rolling_mae_vs_hull_dist_models.py
steps:
- name: Check out repository
uses: actions/checkout@v3
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: 3.9

Expand Down
128 changes: 98 additions & 30 deletions data/mp/eda_mp_trj.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pandas as pd
import plotly.express as px
from matplotlib.colors import SymLogNorm
from pymatgen.core import Composition
from pymatgen.core import Composition, Element
from pymatviz import count_elements, ptable_heatmap, ptable_heatmap_ratio, ptable_hists
from pymatviz.io import save_fig
from pymatviz.utils import si_fmt
Expand All @@ -41,7 +41,7 @@
e_form_per_atom_col = "ef_per_atom"
magmoms_col = "magmoms"
forces_col = "forces"
elems_col = "symbols"
site_nums_col = "site_nums"


# %% load MP element counts by occurrence to compute ratio with MPtrj
Expand Down Expand Up @@ -91,7 +91,7 @@
{
info_to_id(atoms.info): atoms.info
| {key: atoms.arrays.get(key) for key in ("forces", "magmoms")}
| {"formula": str(atoms.symbols), elems_col: atoms.symbols}
| {"formula": str(atoms.symbols), site_nums_col: atoms.symbols}
for atoms_list in tqdm(mp_trj_atoms.values(), total=len(mp_trj_atoms))
for atoms in atoms_list
}
Expand Down Expand Up @@ -120,86 +120,154 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:


# %% plot per-element magmom histograms
magmom_hist_path = f"{DATA_DIR}/mp/mp-trj-2022-09-elem-magmoms.json.bz2"
ptable_magmom_hist_path = f"{DATA_DIR}/mp/mp-trj-2022-09-elem-magmoms.json.bz2"

if os.path.isfile(magmom_hist_path):
mp_trj_elem_magmoms = pd.read_json(magmom_hist_path, typ="series")
elif "mp_trj_elem_magmoms" not in locals():
if os.path.isfile(ptable_magmom_hist_path):
srs_mp_trj_elem_magmoms = pd.read_json(ptable_magmom_hist_path, typ="series")
elif "srs_mp_trj_elem_magmoms" not in locals():
# project magmoms onto symbols in dict
df_mp_trj_elem_magmom = pd.DataFrame(
[
dict(zip(elems, magmoms))
for elems, magmoms in df_mp_trj.set_index(elems_col)[magmoms_col]
for elems, magmoms in df_mp_trj.set_index(site_nums_col)[magmoms_col]
.dropna()
.items()
]
)

mp_trj_elem_magmoms = {
srs_mp_trj_elem_magmoms = {
col: list(df_mp_trj_elem_magmom[col].dropna()) for col in df_mp_trj_elem_magmom
}
pd.Series(mp_trj_elem_magmoms).to_json(magmom_hist_path)
pd.Series(srs_mp_trj_elem_magmoms).to_json(ptable_magmom_hist_path)

cmap = plt.cm.get_cmap("viridis")
cmap = plt.get_cmap(color_map := "viridis")
norm = matplotlib.colors.LogNorm(vmin=1, vmax=150_000)

ax = ptable_hists(
mp_trj_elem_magmoms,
fig_ptable_magmoms = ptable_hists(
srs_mp_trj_elem_magmoms,
symbol_pos=(0.2, 0.8),
log=True,
cbar_title="Magmoms ($μ_B$)",
cbar_title_kwds=dict(fontsize=16),
cbar_coords=(0.18, 0.85, 0.42, 0.02),
# annotate each element with its number of magmoms in MPtrj
anno_kwds=tile_count_anno,
colormap=color_map,
)

cbar_ax = ax.figure.add_axes([0.26, 0.78, 0.25, 0.015])
cbar_ax = fig_ptable_magmoms.figure.add_axes([0.26, 0.78, 0.25, 0.015])
cbar = matplotlib.colorbar.ColorbarBase(
cbar_ax, cmap=cmap, norm=norm, orientation="horizontal"
)
save_fig(ax, f"{PDF_FIGS}/mp-trj-magmoms-ptable-hists.pdf")
save_fig(fig_ptable_magmoms, f"{PDF_FIGS}/mp-trj-magmoms-ptable-hists.pdf")


# %% plot per-element force histograms
force_hist_path = f"{DATA_DIR}/mp/mp-trj-2022-09-elem-forces.json.bz2"
ptable_force_hist_path = f"{DATA_DIR}/mp/mp-trj-2022-09-elem-forces.json.bz2"

if os.path.isfile(force_hist_path):
mp_trj_elem_forces = pd.read_json(force_hist_path, typ="series")
elif "mp_trj_elem_forces" not in locals():
if os.path.isfile(ptable_force_hist_path):
srs_mp_trj_elem_forces = pd.read_json(ptable_force_hist_path, typ="series")
elif "srs_mp_trj_elem_forces" not in locals():
df_mp_trj_elem_forces = pd.DataFrame(
[
dict(zip(elems, np.abs(forces).mean(axis=1)))
for elems, forces in df_mp_trj.set_index(elems_col)[forces_col].items()
for elems, forces in df_mp_trj.set_index(site_nums_col)[forces_col].items()
]
)
mp_trj_elem_forces = {
col: list(df_mp_trj_elem_forces[col].dropna()) for col in df_mp_trj_elem_forces
}
mp_trj_elem_forces = pd.Series(mp_trj_elem_forces)
mp_trj_elem_forces.to_json(force_hist_path)
srs_mp_trj_elem_forces = pd.Series(mp_trj_elem_forces)
srs_mp_trj_elem_forces.to_json(ptable_force_hist_path)

cmap = plt.cm.get_cmap("viridis")
cmap = plt.get_cmap(color_map := "viridis")
norm = matplotlib.colors.LogNorm(vmin=1, vmax=1_000_000)

max_force = 10 # eV/Å
ax = ptable_hists(
mp_trj_elem_forces.copy().map(lambda x: [val for val in x if val < max_force]),
fig_ptable_forces = ptable_hists(
srs_mp_trj_elem_forces.copy().map(lambda x: [val for val in x if val < max_force]),
symbol_pos=(0.3, 0.8),
log=True,
cbar_title="1/3 Σ|Forces| (eV/Å)",
cbar_title_kwds=dict(fontsize=16),
cbar_coords=(0.18, 0.85, 0.42, 0.02),
x_range=(0, max_force),
anno_kwds=tile_count_anno,
colormap=color_map,
)

cbar_ax = ax.figure.add_axes([0.26, 0.78, 0.25, 0.015])
cbar_ax = fig_ptable_forces.figure.add_axes([0.26, 0.78, 0.25, 0.015])
cbar = matplotlib.colorbar.ColorbarBase(
cbar_ax, cmap=cmap, norm=norm, orientation="horizontal"
)

save_fig(ax, f"{PDF_FIGS}/mp-trj-forces-ptable-hists.pdf")
save_fig(fig_ptable_forces, f"{PDF_FIGS}/mp-trj-forces-ptable-hists.pdf")


# %% plot histogram of number of sites per element
ptable_n_sites_hist_path = f"{DATA_DIR}/mp/mp-trj-2022-09-elem-n-sites.json.bz2"

if os.path.isfile(ptable_n_sites_hist_path):
srs_mp_trj_elem_n_sites = pd.read_json(ptable_n_sites_hist_path, typ="series")
elif "mp_trj_elem_n_sites" not in locals():
# construct a series of lists of site counts per element (i.e. for every structure
# that contains a given element, record that structure's total number of sites)
# create all df cols as int dtype
df_mp_trj_elem_n_sites = pd.DataFrame(
[
dict.fromkeys(set(site_nums), len(site_nums))
for site_nums in df_mp_trj[site_nums_col]
]
).astype(int)
mp_trj_elem_n_sites = {
col: list(df_mp_trj_elem_n_sites[col].dropna())
for col in df_mp_trj_elem_n_sites
}
srs_mp_trj_elem_n_sites = pd.Series(mp_trj_elem_n_sites).sort_index()

srs_mp_trj_elem_n_sites.index = srs_mp_trj_elem_n_sites.index.map(
Element.from_Z
).map(str)
srs_mp_trj_elem_n_sites.to_json(ptable_n_sites_hist_path)


cmap = plt.get_cmap("Blues")
cbar_ticks = (100, 1_000, 10_000, 100_000, 1_000_000)
norm = matplotlib.colors.LogNorm(vmin=min(cbar_ticks), vmax=max(cbar_ticks))

fig_ptable_sites = ptable_hists(
srs_mp_trj_elem_n_sites,
symbol_pos=(0.8, 0.9),
log=True,
cbar_title="Number of Sites",
cbar_title_kwds=dict(fontsize=16),
cbar_coords=(0.18, 0.85, 0.42, 0.02),
anno_kwds=lambda hist_vals: dict(
text=si_fmt(len(hist_vals), ".0f"),
xy=(0.8, 0.6),
bbox=dict(pad=2, edgecolor="none", facecolor="none"),
),
x_range=(1, 300),
hist_kwds=lambda hist_vals: dict(
color=cmap(norm(len(hist_vals))), edgecolor="none"
),
)

# turn off y axis for helium (why is it even there?)
fig_ptable_sites.axes[17].get_yaxis().set_visible(False)

cbar_ax = fig_ptable_sites.figure.add_axes([0.23, 0.8, 0.31, 0.025])
cbar = matplotlib.colorbar.ColorbarBase(
cbar_ax,
cmap=cmap,
norm=norm,
orientation="horizontal",
ticks=cbar_ticks,
)
cbar.set_label("Number of atoms in MPtrj structures", fontsize=16)
cbar.ax.xaxis.set_label_position("top")

save_fig(fig_ptable_sites, f"{PDF_FIGS}/mp-trj-n-sites-ptable-hists.pdf")


# %%
Expand Down Expand Up @@ -371,15 +439,15 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=450, height=280)


# %% calc n_sites from forces len
df_mp_trj[n_sites_col] = df_mp_trj[forces_col].map(len)
log_y = False
# %% calc n_sites from per-site atomic numbers
df_mp_trj[n_sites_col] = df_mp_trj[site_nums_col].map(len)
n_sites_hist, n_sites_bins = np.histogram(
df_mp_trj[n_sites_col], bins=range(1, df_mp_trj[n_sites_col].max() + 1)
)

n_struct_col = "Number of Structures"
df_n_sites = pd.DataFrame({n_sites_col: n_sites_bins[:-1], n_struct_col: n_sites_hist})
log_y = False


# %% plot n_sites distribution
Expand Down
7 changes: 3 additions & 4 deletions matbench_discovery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,9 @@
# ComputedStructureEntries or using PatchedPhaseDiagram to get e_above_hull
# warnings are:
# > No electronegativity for Ne. Setting to NaN. This has no physical meaning
for lineno in (120, 221, 1043):
warnings.filterwarnings(
action="ignore", category=UserWarning, module="pymatgen", lineno=lineno
)
# and MaterialsProject2020Compatibility to get formation energies
# > Failed to guess oxidation states for Entry
warnings.filterwarnings(action="ignore", category=UserWarning, module="pymatgen")

id_col = "material_id"
init_struct_col = "initial_structure"
Expand Down
10 changes: 5 additions & 5 deletions models/bowsr/test_bowsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@

# %%
slurm_array_task_id = int(os.getenv("SLURM_ARRAY_TASK_ID", "0"))
slurm_job_id = os.getenv("SLURM_JOB_ID", "debug")
out_path = f"{out_dir}/{slurm_job_id}-{slurm_array_task_id:>03}.json.gz"
slurm_array_job_id = os.getenv("SLURM_ARRAY_JOB_ID", "debug")
out_path = f"{out_dir}/{slurm_array_job_id}-{slurm_array_task_id:>03}.json.gz"

# stop early if this task's output file is already on disk
if os.path.isfile(out_path):
    raise SystemExit(f"{out_path=} already exists, exiting early")
Expand All @@ -73,9 +73,9 @@
print(f"{data_path = }")
print(f"{out_path=}")

df_in: pd.DataFrame = np.array_split(
pd.read_json(data_path).set_index(id_col), slurm_array_task_count
)[slurm_array_task_id - 1]
df_in = pd.read_json(data_path).set_index(id_col)
if slurm_array_task_count > 1:
df_in = np.array_split(df_in, slurm_array_task_count)[slurm_array_task_id - 1]


# %%
Expand Down
4 changes: 2 additions & 2 deletions models/cgcnn/test_cgcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@
data_loader=data_loader,
)

slurm_job_id = os.getenv("SLURM_JOB_ID", "debug")
df.round(4).to_csv(f"{out_dir}/{job_name}-preds-{slurm_job_id}.csv.gz")
slurm_array_job_id = os.getenv("SLURM_ARRAY_JOB_ID", "debug")
df.round(4).to_csv(f"{out_dir}/{job_name}-preds-{slurm_array_job_id}.csv.gz")
pred_col = f"{target_col}_pred_ens"
assert pred_col in df, f"{pred_col=} not in {list(df)}"
table = wandb.Table(dataframe=df[[target_col, pred_col]].reset_index())
Expand Down
4 changes: 2 additions & 2 deletions models/chgnet/ctk_structure_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
__date__ = "2023-03-07"

"""
This script runs a Crystal Toolkit app that shows a scatter plot of CHGNet energies
and lets you click on points to view the corresponding structures. Run with:
This script runs a Crystal Toolkit app that shows a parity plot of CHGNet vs PBE
energies and lets you click on points to view the corresponding structures. Run with:
python models/chgnet/ctk_structure_viewer.py
Then open http://localhost:8000 in your browser.
"""
Expand Down
34 changes: 22 additions & 12 deletions models/chgnet/test_chgnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@

# %%
slurm_array_task_id = int(os.getenv("SLURM_ARRAY_TASK_ID", "0"))
slurm_job_id = os.getenv("SLURM_JOB_ID", "debug")
out_path = f"{out_dir}/{slurm_job_id}-{slurm_array_task_id:>03}.json.gz"
slurm_array_job_id = os.getenv("SLURM_ARRAY_JOB_ID", "debug")
out_path = f"{out_dir}/{slurm_array_job_id}-{slurm_array_task_id:>03}.json.gz"

# stop early if this task's output file is already on disk
if os.path.isfile(out_path):
    raise SystemExit(f"{out_path=} already exists, exiting early")
Expand All @@ -66,12 +66,12 @@
print(f"\nJob started running {timestamp}")
print(f"{data_path=}")
e_pred_col = "chgnet_energy"
max_steps = 500
max_steps = 0
fmax = 0.05

df_in: pd.DataFrame = np.array_split(
pd.read_json(data_path).set_index(id_col), slurm_array_task_count
)[slurm_array_task_id - 1]
df_in = pd.read_json(data_path).set_index(id_col)
if slurm_array_task_count > 1:
df_in = np.array_split(df_in, slurm_array_task_count)[slurm_array_task_id - 1]

run_params = dict(
data_path=data_path,
Expand All @@ -98,18 +98,25 @@

structures = df_in[input_col].map(Structure.from_dict).to_dict()

for material_id in tqdm(structures, desc="Relaxing", disable=None):
for material_id in tqdm(structures, desc="Relaxing"):
if material_id in relax_results:
continue
try:
relax_result = chgnet.relax(
structures[material_id], verbose=False, steps=max_steps, fmax=fmax
structures[material_id],
verbose=False,
steps=max_steps,
fmax=fmax,
relax_cell=max_steps > 0,
)
relax_results[material_id] = {
"chgnet_structure": relax_result["final_structure"],
"chgnet_trajectory": relax_result["trajectory"].__dict__,
e_pred_col: relax_result["trajectory"].energies[-1],
e_pred_col: relax_result["trajectory"].energies[-1]
}
if max_steps > 0:
relax_struct = relax_result["final_structure"]
relax_results[material_id]["chgnet_structure"] = relax_struct
traj = relax_result["trajectory"]
relax_results[material_id]["chgnet_trajectory"] = traj.__dict__
except Exception as exc:
print(f"Failed to relax {material_id}: {exc!r}")

Expand All @@ -118,7 +125,10 @@
df_out = pd.DataFrame(relax_results).T
df_out.index.name = id_col

df_out.reset_index().to_json(out_path, default_handler=as_dict_handler)
if max_steps == 0:
df_out.add_suffix("_no_relax").to_csv(out_path.replace(".json.gz", ".csv.gz"))
else:
df_out.reset_index().to_json(out_path, default_handler=as_dict_handler)


# %%
Expand Down
4 changes: 2 additions & 2 deletions models/m3gnet/pre_vs_post_m3gnet_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
df_wbm.query("initial_wbm_volume.isna()").index.tolist()


# %% scatter plot of M3GNet/initial volumes vs DFT-relaxed volumes
# %% parity plot of M3GNet/initial volumes vs DFT-relaxed volumes
ax = density_scatter(
df=df_wbm.query("m3gnet_volume < 2000"),
x="final_wbm_volume",
Expand Down Expand Up @@ -211,7 +211,7 @@
# notebook server
fig.show(renderer="png", scale=2)
fig.write_image(
f"{SITE_FIGS}/m3gnet-energy-per-atom-scatter-is2re-vs-rs2re.webp", scale=2
f"{SITE_FIGS}/m3gnet-energy-per-atom-parity-is2re-vs-rs2re.webp", scale=2
)


Expand Down
Loading