Skip to content

Commit 2f795f7

Browse files
committed
fix legend/subplot titles in scripts/scatter_e_above_hull_models.py
add paragraph on chemical diversity of train and test set
1 parent 8079ca1 commit 2f795f7

File tree

10 files changed

+155
-50
lines changed

10 files changed

+155
-50
lines changed

data/wbm/readme.md

+5-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ The first integer in each material ID ranging from 1 to 5 and coming right after
1414

1515
Each iteration has varying numbers of materials which are counted by the 2nd integer. Note this 2nd number is not always consecutive. A small number of materials (~0.2%) were removed by the data-cleaning steps detailed below. Don't be surprised to find an ID like `wbm-3-70804` followed by `wbm-3-70807`.
1616

17-
## 🪓   Data processing steps
17+
## 🪓   Data Processing Steps
1818

1919
The full set of processing steps used to curate the WBM test set from the raw data files (downloaded from URLs listed below) can be found in [`data/wbm/fetch_process_wbm_dataset.py`](https://github.com/janosh/matbench-discovery/blob/site/data/wbm/fetch_process_wbm_dataset.py). Processing involved
2020

@@ -45,7 +45,7 @@ The number of materials in each step before and after processing are:
4545
| before | 61,848 | 52,800 | 79,205 | 40,328 | 23,308 | 257,487 |
4646
| after | 61,466 | 52,755 | 79,160 | 40,314 | 23,268 | 256,963 |
4747

48-
## 🔗   Links to raw WBM data files
48+
## 🔗   Links to Raw WBM Data Files
4949

5050
Links to WBM data files have proliferated. This is an attempt to keep track of them.
5151

@@ -72,7 +72,9 @@ materialscloud:2021.68 includes a readme file with a description of the dataset,
7272

7373
[wbm paper]: https://nature.com/articles/s41524-020-00481-6
7474

75-
## 📊   Plots
75+
## 📊   Chemical Diversity
76+
77+
Both the WBM test set and, even more so, the MP training set are heavily oxide-dominated. The WBM test set is about 75% larger than the MP training set and also more chemically diverse, containing a higher fraction of transition metals, post-transition metals and metalloids. Our goal in picking such a large, diverse test set is future-proofing. Ideally, this data will provide a challenging materials discovery test bed even for large foundational ML models in the future.
7678

7779
<slot name="wbm-elements-heatmap">
7880
<img src="./figs/2023-01-08-wbm-elements.svg" alt="Periodic table log heatmap of WBM elements">

scripts/scatter_e_above_hull_models.py

+59-28
Original file line numberDiff line numberDiff line change
@@ -45,26 +45,25 @@
4545

4646

4747
# %%
48-
def _metric_str(model_name: str) -> str:
49-
model_pred = df_wbm[e_above_hull_col] - (df_wbm[e_form_col] - df_wbm[model_name])
50-
MAE = (df_wbm[e_above_hull_col] - model_pred).abs().mean()
51-
isna = df_wbm[e_above_hull_col].isna() | model_pred.isna()
52-
R2 = r2_score(df_wbm[e_above_hull_col][~isna], model_pred[~isna])
53-
return f"{model_name} · {MAE=:.2f} · R<sup>2</sup>={R2:.2f}"
48+
def _metric_str(xs: list[float], ys: list[float]) -> str:
    """Format MAE and R^2 of paired (x, y) values as a legend/title suffix.

    Args:
        xs: x-values; passed to r2_score as its first argument.
        ys: y-values paired element-wise with xs.

    Returns:
        A string of the form " · MAE=0.12 · R<sup>2</sup>=0.34" (note the
        leading separator, so callers can append it directly to a name).
    """
    # coerce to float arrays so the boolean masking below also works for the
    # plain lists/tuples the annotations promise, not only for numpy arrays
    xs, ys = np.asarray(xs, dtype=float), np.asarray(ys, dtype=float)
    # drop every pair in which either value is NaN before computing metrics
    isna = np.isnan(xs) | np.isnan(ys)
    xs, ys = xs[~isna], ys[~isna]
    MAE = np.abs(xs - ys).mean()
    R2 = r2_score(xs, ys)
    return f" · MAE={MAE:.2f} · R<sup>2</sup>={R2:.2f}"
5455

5556

5657
def _add_metrics_to_legend(fig: go.Figure) -> None:
5758
for trace in fig.data:
5859
# initially hide all traces, let users select which models to compare
5960
trace.visible = "legendonly"
60-
# add MAE and R2 to legend
61-
model = trace.name
62-
trace.name = _metric_str(model)
61+
trace.name = f"{trace.name}{_metric_str(trace.x, trace.y)}"
6362

6463

6564
# %% scatter plot of actual vs predicted e_form_per_atom
6665
fig = px.scatter(
67-
df_melt.iloc[::10],
66+
df_melt.iloc[::5],
6867
x=e_form_col,
6968
y=e_form_preds,
7069
color=var_name,
@@ -80,13 +79,12 @@ def _add_metrics_to_legend(fig: go.Figure) -> None:
8079

8180
# %%
8281
img_path = f"{FIGS}/{today}-e-form-scatter-models"
83-
# fig.write_image(f"{img_path}.pdf")
84-
save_fig(fig, f"{img_path}.svelte")
82+
# save_fig(fig, f"{img_path}.svelte")
8583

8684

8785
# %% scatter plot of actual vs predicted e_above_hull
8886
fig = px.scatter(
89-
df_melt.iloc[::10],
87+
df_melt.iloc[::5],
9088
x=e_above_hull_col,
9189
y=e_above_hull_preds,
9290
color=var_name,
@@ -102,8 +100,7 @@ def _add_metrics_to_legend(fig: go.Figure) -> None:
102100

103101
# %%
104102
img_path = f"{FIGS}/{today}-e-above-hull-scatter-models"
105-
# fig.write_image(f"{img_path}.pdf")
106-
save_fig(fig, f"{img_path}.svelte")
103+
# save_fig(fig, f"{img_path}.svelte")
107104

108105

109106
# %% plot all models in separate subplots
@@ -116,41 +113,75 @@ def _add_metrics_to_legend(fig: go.Figure) -> None:
116113
)[true_pos * 0 + false_neg * 1 + false_pos * 2 + true_neg * 3]
117114

118115
fig = px.scatter(
119-
df_melt.iloc[::10],
116+
df_melt.iloc[::50],
120117
x=e_above_hull_col,
121118
y=e_above_hull_preds,
122119
facet_col=var_name,
123120
facet_col_wrap=3,
121+
facet_col_spacing=0.04,
122+
facet_row_spacing=0.15,
124123
hover_data=hover_cols,
125124
hover_name=id_col,
126125
color="clf",
127126
color_discrete_map=dict(zip(classes, ("green", "yellow", "red", "blue"))),
128-
opacity=0.4,
127+
# opacity=0.4,
128+
range_x=[-2, 2],
129+
range_y=[-2, 2],
129130
)
130131

132+
x_title = fig.layout.xaxis.title.text
133+
y_title = fig.layout.yaxis.title.text
134+
131135
# iterate over subplots and set new title
132-
for idx, model in enumerate(models, 1):
133-
# find index of annotation belonging to model
134-
anno_idx = [a.text for a in fig.layout.annotations].index(f"Model={model}")
135-
assert anno_idx >= 0, f"could not find annotation for {model}"
136+
for idx, anno in enumerate(fig.layout.annotations, 1):
137+
traces = [t for t in fig.data if t.xaxis == f"x{idx if idx > 1 else ''}"]
138+
xs = np.concatenate([t.x for t in traces])
139+
ys = np.concatenate([t.y for t in traces])
136140

141+
model = anno.text.split("=")[1]
137142
# set new subplot titles (adding MAE and R2)
138-
fig.layout.annotations[anno_idx].text = _metric_str(model)
143+
# _metric_str() already starts with " · ", so don't add another space
# (keeps subplot titles formatted the same way as the legend entries)
fig.layout.annotations[idx - 1].text = f"{model}{_metric_str(xs, ys)}"
139144

140145
# remove x and y axis titles if not on center row or center column
141-
if idx != 2:
142-
fig.layout[f"xaxis{idx}"].title.text = ""
143-
if idx > 1:
144-
fig.layout[f"yaxis{idx}"].title.text = ""
146+
fig.layout[f"xaxis{idx}"].title.text = ""
147+
fig.layout[f"yaxis{idx}"].title.text = ""
145148

146149
# add vertical and horizontal lines at 0
147150
fig.add_vline(x=0, line=dict(width=1, dash="dash", color="gray"))
148151
fig.add_hline(y=0, line=dict(width=1, dash="dash", color="gray"))
149152

150-
fig.update_layout(showlegend=False)
153+
151154
fig.update_xaxes(nticks=5)
152155
fig.update_yaxes(nticks=5)
153156

157+
legend = dict(
158+
title="", # remove legend title
159+
itemsizing="constant", # increase legend marker size
160+
orientation="h",
161+
x=0.5, # place legend centered above subplots
162+
xanchor="center",
163+
y=1.2,
164+
yanchor="top",
165+
)
166+
fig.layout.legend.update(legend)
167+
168+
axis_titles = dict(xref="paper", yref="paper", showarrow=False)
169+
fig.add_annotation(
170+
x=0.5,
171+
y=-0.16,
172+
text=x_title,
173+
**axis_titles,
174+
)
175+
# add y-axis title
176+
fig.add_annotation(
177+
x=-0.06,
178+
y=0.5,
179+
text=y_title,
180+
textangle=-90,
181+
**axis_titles,
182+
)
183+
184+
154185
fig.show()
155186
img_path = f"{STATIC}/{today}-each-scatter-models.png"
156-
# save_fig(fig, img_path, scale=4, width=1000, height=500)
187+
save_fig(fig, img_path, scale=4, width=1000, height=500)

site/src/app.d.ts

+12-3
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,18 @@
22
/// <reference types="mdsvex/globals" />
33

44
declare module '*.md'
5-
declare module '*package.json'
5+
6+
declare module '*package.json' {
7+
const pkg: Record<string, unknown>
8+
export default pkg
9+
}
610

711
declare module '*metadata.yml' {
8-
const content: import('$lib/types').ModelMetadata
9-
export default content
12+
const data: import('$lib').ModelMetadata
13+
export default data
14+
}
15+
16+
declare module '*element-counts.json' {
17+
const map: Record<string, number>
18+
export default map
1019
}

site/src/lib/ModelCard.svelte

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import { repository } from '$site/package.json'
33
import Icon from '@iconify/svelte'
44
import { pretty_num } from 'sveriodic-table/labels'
5-
import type { ModelData, ModelStat } from './types'
5+
import type { ModelData, ModelStat } from '.'
66
77
export let key: string
88
export let data: ModelData

site/src/lib/References.svelte

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
<script lang="ts">
2-
import type { Reference } from './types'
2+
import type { Reference } from '.'
33
44
export let references: Reference[]
55
</script>

site/src/lib/index.ts

+55
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,58 @@ export { default as Footer } from './Footer.svelte'
22
export { default as ModelCard } from './ModelCard.svelte'
33
export { default as Nav } from './Nav.svelte'
44
export { default as References } from './References.svelte'
5+
6+
export type ModelData = ModelMetadata & ModelStats
7+
8+
export type ModelMetadata = {
9+
model_name: string
10+
model_version: string
11+
matbench_discovery_version: string
12+
date_added: string
13+
authors: Author[]
14+
repo: string
15+
url?: string
16+
doi?: string
17+
preprint?: string
18+
requirements?: Record<string, string>
19+
trained_on_benchmark: boolean
20+
}
21+
22+
export type ModelStats = {
23+
MAE: number
24+
RMSE: number
25+
R2: number
26+
Precision: number
27+
Recall: number
28+
F1: number
29+
missing_preds: number
30+
missing_percent: number
31+
Accuracy: number
32+
run_time: number
33+
run_time_h: string
34+
GPUs: number
35+
CPUs: number
36+
slurm_jobs: number
37+
}
38+
39+
export type Author = {
40+
name: string
41+
email?: string
42+
affiliation?: string
43+
orcid?: string
44+
url?: string
45+
twitter?: string
46+
}
47+
48+
export type Reference = {
49+
title: string
50+
id: string
51+
author: { family: string; given: string }[]
52+
DOI: string
53+
URL?: string
54+
issued: { year: number; month: number; day: number }[]
55+
accessed: { year: number; month: number; day: number }[]
56+
page: string
57+
type: string
58+
ISSN?: string
59+
}

site/src/routes/about-the-test-set/+page.svelte

+14-6
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,18 @@
1010
let log = false // log color scale
1111
const wbm_heat_vals: number[] = Object.values(wbm_elem_counts)
1212
const mp_heat_vals: number[] = Object.values(mp_elem_counts)
13-
const color_map = {
13+
const [mp_max, wbm_max] = [Math.max(...mp_heat_vals), Math.max(...wbm_heat_vals)]
14+
const mp_color_map = {
1415
200: `blue`,
15-
35_000: `green`,
16-
80_000: `yellow`,
17-
150_000: `red`,
16+
[mp_max / 4]: `green`,
17+
[mp_max / 2]: `yellow`,
18+
// computed key: a bare `mp_max` would create the literal string key "mp_max"
[mp_max]: `red`,
19+
}
20+
const wbm_color_map = {
21+
200: `blue`,
22+
[wbm_max / 4]: `green`,
23+
[wbm_max / 2]: `yellow`,
24+
// computed key: a bare `wbm_max` would create the literal string key "wbm_max"
[wbm_max]: `red`,
1825
}
1926
let active_mp_elem: ChemicalElement
2027
let active_wbm_elem: ChemicalElement
@@ -30,7 +37,7 @@
3037
<span>Log color scale <Toggle bind:checked={log} /></span>
3138
<PeriodicTable
3239
heatmap_values={wbm_heat_vals}
33-
{color_map}
40+
color_map={wbm_color_map}
3441
{log}
3542
bind:active_element={active_wbm_elem}
3643
>
@@ -51,9 +58,10 @@
5158
</PeriodicTable>
5259
</svelte:fragment>
5360
<svelte:fragment slot="mp-elements-heatmap">
61+
<span>Log color scale <Toggle bind:checked={log} /></span>
5462
<PeriodicTable
5563
heatmap_values={mp_heat_vals}
56-
{color_map}
64+
color_map={mp_color_map}
5765
{log}
5866
bind:active_element={active_mp_elem}
5967
>

site/src/routes/models/+page.server.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
import type { ModelData, ModelMetadata } from '$lib/types'
1+
import type { ModelData, ModelMetadata } from '$lib'
22
import { dirname } from 'path'
33
import type { PageServerLoad } from './$types'
44
import model_stats from './2023-01-23-model-stats.json'
55

6-
export const load: PageServerLoad = async () => {
6+
export const load: PageServerLoad = () => {
77
const yml = import.meta.glob(`$root/models/**/metadata.yml`, {
88
eager: true,
99
})

site/src/routes/models/+page.svelte

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
<script lang="ts">
2+
import type { ModelStats } from '$lib'
23
import { ModelCard } from '$lib'
3-
import type { ModelStats } from '$lib/types'
44
import { RadioButtons } from 'svelte-zoo'
55
import { flip } from 'svelte/animate'
66
import { fade } from 'svelte/transition'

site/src/routes/paper/+page.svx

+5-5
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ Our benchmark is designed to make [adding future models easy](/how-to-contribute
167167

168168
1. [CGCNN](https://arxiv.org/abs/1710.10324) @xie_crystal_2018
169169
1. [BOWSR](https://sciencedirect.com/science/article/pii/S1369702121002984) @zuo_accelerating_2021
170-
1. [Wren](https://arxiv.org/abs/2106.11132) @goodall_rapid_2022
170+
1. [Wrenformer](https://arxiv.org/abs/2106.11132) @goodall_rapid_2022
171171
1. [M3GNet](https://arxiv.org/abs/2202.02450) @chen_universal_2022
172172
1. [MEGNet](https://arxiv.org/abs/1812.05055) @chen_graph_2019
173173
1. [Voronoi Random Forest](https://journals.aps.org/prb/abstract/10.1103/PhysRevB.96.024104) @goodall_rapid_2022
@@ -181,7 +181,7 @@ Our benchmark is designed to make [adding future models easy](/how-to-contribute
181181
Classification performance for all models
182182
</caption>
183183

184-
![Parity plot for each model's energy above hull predictions (based on their formation energy preds) vs DFT ground truth](./figs/2023-01-18-each-scatter-models.png)
184+
![Parity plot for each model's energy above hull predictions (based on their formation energy preds) vs DFT ground truth](./figs/2023-01-24-each-scatter-models.png)
185185

186186
<figcaption>@label:fig:each-scatter-models Parity plot for each model's energy above hull predictions (based on their formation energy preds) vs DFT ground truth</figcaption>
187187

@@ -190,9 +190,9 @@ Our benchmark is designed to make [adding future models easy](/how-to-contribute
190190
<figcaption>@label:fig:wbm-hull-dist-hist-models Histograms and rolling accuracy of using predicted formation energies for stability classification</figcaption>
191191

192192
<div>
193-
{#if typeof document !== `undefined`}
194-
<CumulativeClfMetrics class="pull-left" />
195-
{/if}
193+
{#if typeof document !== `undefined`}
194+
<CumulativeClfMetrics class="pull-left" />
195+
{/if}
196196
</div>
197197

198198
## Analysis

0 commit comments

Comments
 (0)