Skip to content

Commit 0f2410d

Browse files
committed
rename data/wbm/(analysis->eda).py
rename scripts/(analyze_failure_cases->analyze_model_failure_cases).py
1 parent 9793034 commit 0f2410d

File tree

7 files changed

+86
-16
lines changed

7 files changed

+86
-16
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ __pycache__
1111
*.csv.bz2
1212
*.pkl.gz
1313
data/**/raw
14+
data/**/tsne
1415
data/2022-*
1516
data/m3gnet-*
1617

data/wbm/analysis.py data/wbm/eda.py

+66-3
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,26 @@
33

44
import numpy as np
55
import pandas as pd
6+
import plotly.express as px
67
from pymatgen.core import Composition
78
from pymatviz import count_elements, ptable_heatmap_plotly
89
from pymatviz.utils import save_fig
910

1011
from matbench_discovery import FIGS, ROOT, today
12+
from matbench_discovery import plots as plots
1113
from matbench_discovery.data import DATA_FILES, df_wbm
1214
from matbench_discovery.energy import mp_elem_reference_entries
13-
from matbench_discovery.plots import pio
15+
from matbench_discovery.preds import df_each_err, each_true_col
16+
17+
__author__ = "Janosh Riebesell"
18+
__date__ = "2023-03-30"
1419

1520
"""
16-
Compare MP and WBM elemental prevalence. Starting with WBM, MP below.
21+
WBM exploratory data analysis.
22+
Start with comparing MP and WBM elemental prevalence.
1723
"""
1824

1925
module_dir = os.path.dirname(__file__)
20-
print(f"{pio.templates.default=}")
2126
about_data_page = f"{ROOT}/site/src/routes/about-the-data"
2227

2328

@@ -170,3 +175,61 @@
170175
fig.show()
171176

172177
save_fig(fig, f"{FIGS}/mp-elemental-ref-energies.svelte")
178+
179+
180+
# %% plot 2d and 3d t-SNE projections of one-hot encoded element vectors summed by
181+
# weight in each WBM composition. TLDR: no obvious structure in the data.
182+
# I was hoping to find certain clusters with higher or lower errors after seeing
183+
# many models struggle on the halogens in per-element error periodic table heatmaps
184+
# https://matbench-discovery.janosh.dev/models
185+
df_2d_tsne = pd.read_csv(f"{module_dir}/tsne/one-hot-112-composition-2d.csv.gz")
186+
df_2d_tsne = df_2d_tsne.set_index("material_id")
187+
188+
df_3d_tsne = pd.read_csv(f"{module_dir}/tsne/one-hot-112-composition-3d.csv.gz")
189+
model = "Wrenformer"
190+
df_3d_tsne = pd.read_csv(
191+
f"{module_dir}/tsne/one-hot-112-composition+{model}-each-err-3d-metric=eucl.csv.gz"
192+
)
193+
df_3d_tsne = df_3d_tsne.set_index("material_id")
194+
195+
df_wbm[list(df_2d_tsne)] = df_2d_tsne
196+
df_wbm[list(df_3d_tsne)] = df_3d_tsne
197+
df_wbm[list(df_each_err.add_suffix(" abs EACH error"))] = df_each_err.abs()
198+
199+
200+
# %%
201+
color_col = f"{model} abs EACH error"
202+
clr_range_max = df_wbm[color_col].mean() + df_wbm[color_col].std()
203+
204+
205+
# %%
206+
fig = px.scatter(
207+
df_wbm,
208+
x="2d t-SNE 1",
209+
y="2d t-SNE 2",
210+
color=color_col,
211+
hover_name="material_id",
212+
hover_data=("formula", each_true_col),
213+
range_color=(0, clr_range_max),
214+
)
215+
fig.show()
216+
217+
218+
# %%
219+
fig = px.scatter_3d(
220+
df_wbm,
221+
x="3d t-SNE 1",
222+
y="3d t-SNE 2",
223+
z="3d t-SNE 3",
224+
color=color_col,
225+
custom_data=["material_id", "formula", each_true_col, color_col],
226+
range_color=(0, clr_range_max),
227+
)
228+
fig.data[0].hovertemplate = (
229+
"<b>material_id: %{customdata[0]}</b><br><br>"
230+
"t-SNE: (%{x:.2f}, %{y:.2f}, %{z:.2f})<br>"
231+
"Formula: %{customdata[1]}<br>"
232+
"E<sub>above hull</sub>: %{customdata[2]:.2f}<br>"
233+
f"{color_col}: %{{customdata[3]:.2f}}<br>"
234+
)
235+
fig.show()
File renamed without changes.

scripts/compute_projections.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
# %%
55
import os
6+
from datetime import datetime
67
from typing import Any, Literal
78

89
import numpy as np
@@ -38,6 +39,8 @@
3839
print(f"{data_path=}")
3940
print(f"{out_dim=}")
4041
print(f"{projection_type=}")
42+
start_time = datetime.now()
43+
print(f"job started at {start_time:%Y-%m-%d %H:%M:%S}")
4144
df_in = pd.read_csv(data_path, na_filter=False).set_index("material_id")
4245

4346

@@ -61,13 +64,13 @@ def metric(
6164
projector = TSNE(
6265
n_components=out_dim, random_state=0, n_iter=250, n_iter_without_progress=50
6366
)
64-
out_cols = [f"t-SNE {idx}" for idx in range(out_dim)]
67+
out_cols = [f"{out_dim}d t-SNE {idx + 1}" for idx in range(out_dim)]
6568
elif projection_type == "umap":
6669
from umap import UMAP
6770

6871
# TODO: this execution path is untested (it has never been run)
6972
projector = UMAP(n_components=out_dim, random_state=0, metric=metric)
70-
out_cols = [f"t-SNE {idx+1}" for idx in range(out_dim)]
73+
out_cols = [f"{out_dim}d UMAP {idx + 1}" for idx in range(out_dim)]
7174

7275
identity = np.eye(one_hot_dim)
7376

@@ -78,17 +81,20 @@ def sum_one_hot_elem(formula: str) -> np.ndarray[Any, np.int64]:
7881

7982

8083
in_col = {"wbm": "formula", "mp": "formula_pretty"}[data_name]
81-
df_in[f"one_hot_{one_hot_dim}"] = [
82-
sum_one_hot_elem(formula) for formula in tqdm(df_in[in_col])
83-
]
84-
84+
one_hot_encoding = np.array(
85+
[sum_one_hot_elem(formula) for formula in tqdm(df_in[in_col])]
86+
)
8587

86-
one_hot_encoding = np.array(df_in[f"one_hot_{one_hot_dim}"].to_list())
8788
projections = projector.fit_transform(one_hot_encoding)
8889

8990
df_in[out_cols] = projections
9091

91-
out_path = f"{out_dir}/one-hot-{one_hot_dim}-composition-{out_dim}d.csv"
92+
out_path = f"{out_dir}/one-hot-{one_hot_dim}-composition-{out_dim}d.csv.gz"
9293
df_in[out_cols].to_csv(out_path)
9394

9495
print(f"Wrote projections to {out_path!r}")
96+
end_time = datetime.now()
97+
print(
98+
f"Job finished at {end_time:%Y-%m-%d %H:%M:%S} and took "
99+
f"{(end_time - start_time).seconds} sec"
100+
)

scripts/make_api_docs.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66

77
from lazydocs import generate_docs
88

9-
from matbench_discovery import ROOT
9+
SITE = f"{os.path.dirname(__file__)}/../site"
1010

11-
with open(f"{ROOT}/site/package.json") as file:
11+
with open(f"{SITE}/package.json") as file:
1212
pkg = json.load(file) # get repo URL from package.json
1313

14-
out_path = f"{ROOT}/site/src/routes/api"
14+
out_path = f"{SITE}/src/routes/api"
1515

1616
for path in glob(f"{out_path}/*.md"):
1717
os.remove(path)

site/package.json

+1-2
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
"build": "vite build",
1414
"preview": "vite preview",
1515
"serve": "vite build && vite preview",
16-
"check": "svelte-check",
17-
"make-api-docs": "cd .. && python scripts/make_api_docs.py"
16+
"check": "svelte-check"
1817
},
1918
"devDependencies": {
2019
"@iconify/svelte": "^3.1.0",

site/src/app.css

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
--toc-active-border-width: 0 0 0 1pt;
1616
--toc-active-bg: none;
1717
--toc-active-border-radius: 0;
18+
--toc-max-height: 85vh;
1819

1920
--zoo-github-corner-color: var(--night);
2021
--zoo-github-corner-bg: white;

0 commit comments

Comments
 (0)