Skip to content

Commit 6e91792

Browse files
committed
add new plot ptable_elemental_ratio
1 parent 1495701 commit 6e91792

File tree

7 files changed

+94
-28
lines changed

7 files changed

+94
-28
lines changed

assets/ptable_elemental_ratio.svg

+1
Loading

assets/ptable_elemental_ratio_log.svg

+1
Loading

mlmatrics/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
count_elements,
44
hist_elemental_prevalence,
55
ptable_elemental_prevalence,
6+
ptable_elemental_ratio,
67
)
78
from .histograms import residual_hist
89
from .metrics import regression_metrics

mlmatrics/elements.py

+69-28
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@ def count_elements(formulas: list) -> pd.Series:
2121
Returns:
2222
pd.Series: Total number of appearances of each element in `formulas`.
2323
"""
24-
srs = pd.Series(formulas).apply(lambda x: pd.Series(Composition(x).as_dict())).sum()
24+
formula2dict = lambda str: pd.Series(
25+
Composition(str).fractional_composition.as_dict()
26+
)
27+
28+
srs = pd.Series(formulas).apply(formula2dict).sum()
2529

2630
# ensure all elements are present in returned Series (with count zero if they
2731
# weren't in formulas)
@@ -32,7 +36,10 @@ def count_elements(formulas: list) -> pd.Series:
3236

3337

3438
def ptable_elemental_prevalence(
35-
formulas: List[str] = None, elem_counts: pd.Series = None, log_scale: bool = False
39+
formulas: List[str] = None,
40+
elem_counts: pd.Series = None,
41+
log_scale: bool = False,
42+
cbar_title: str = None,
3643
) -> None:
3744
"""Display the prevalence of each element in a materials dataset plotted as a
3845
heatmap over the periodic table. `formulas` xor `elem_counts` must be passed.
@@ -54,34 +61,43 @@ def ptable_elemental_prevalence(
5461

5562
ptable = pd.read_csv(ROOT + "/data/periodic_table.csv")
5663

57-
n_row = ptable.row.max()
58-
n_column = ptable.column.max()
64+
n_rows = ptable.row.max()
65+
n_columns = ptable.column.max()
5966

60-
plt.figure(figsize=(n_column, n_row))
67+
plt.figure(figsize=(n_columns, n_rows))
6168

6269
rw = rh = 0.9 # rectangle width/height
63-
count_min = elem_counts.min()
64-
count_max = elem_counts.max()
70+
min_count = elem_counts.min()
71+
max_count = elem_counts.replace([np.inf, -np.inf], np.nan).dropna().max()
6572

6673
norm = Normalize(
67-
vmin=0 if log_scale else count_min,
68-
vmax=np.log(count_max) if log_scale else count_max,
74+
vmin=0 if log_scale else min_count,
75+
vmax=np.log(max_count) if log_scale else max_count,
6976
)
7077

7178
text_style = dict(
7279
horizontalalignment="center",
7380
verticalalignment="center",
7481
fontsize=20,
7582
fontweight="semibold",
76-
color="black",
7783
)
7884

7985
for symbol, row, column, _ in ptable.values:
80-
row = n_row - row
86+
row = n_rows - row
8187
count = elem_counts[symbol]
88+
8289
if log_scale and count != 0:
8390
count = np.log(count)
84-
color = YlGn(norm(count)) if count != 0 else "silver"
91+
92+
# inf or NaN are expected when passing in elem_counts from ptable_elemental_ratio
93+
if count == 0: # not in formulas_a
94+
color = "yellow"
95+
elif count == np.inf:
96+
color = "orange" # not in formulas_b
97+
elif pd.isna(count):
98+
color = "gray" # not in either formulas_a nor formulas_b
99+
else:
100+
color = YlGn(norm(count)) if count != 0 else "silver"
85101

86102
if row < 3:
87103
row += 0.5
@@ -95,43 +111,68 @@ def ptable_elemental_prevalence(
95111
x_offset = 3.5
96112
y_offset = 7.8
97113
length = 9
98-
for i in range(granularity):
99-
value = int(round((i) * count_max / (granularity - 1)))
114+
for idx in range(granularity):
115+
value = int(round(idx * max_count / (granularity - 1)))
100116
if log_scale and value != 0:
101117
value = np.log(value)
102118
color = YlGn(norm(value)) if value != 0 else "silver"
103-
x_loc = i / (granularity) * length + x_offset
119+
x_loc = idx / (granularity) * length + x_offset
104120
width = length / granularity
105121
height = 0.35
106122
rect = Rectangle(
107123
(x_loc, y_offset), width, height, edgecolor="gray", facecolor=color
108124
)
109125

110-
if i in [0, 4, 9, 14, 19]:
126+
if idx in [0, 4, 9, 14, 19]:
111127
text = f"{value:g}"
112128
if log_scale:
113129
text = f"{np.exp(value):g}".replace("e+0", "e")
114130
plt.text(x_loc + width / 2, y_offset - 0.4, text, **text_style)
115131

116132
plt.gca().add_patch(rect)
117133

118-
plt.text(
119-
x_offset + length / 2,
120-
y_offset + 0.7,
121-
"log(Element Count)" if log_scale else "Element Count",
122-
horizontalalignment="center",
123-
verticalalignment="center",
124-
fontweight="semibold",
125-
fontsize=20,
126-
color="k",
127-
)
134+
if cbar_title is None:
135+
cbar_title = "log(Element Count)" if log_scale else "Element Count"
136+
137+
plt.text(x_offset + length / 2, y_offset + 0.7, cbar_title, **text_style)
128138

129-
plt.ylim(-0.15, n_row + 0.1)
130-
plt.xlim(0.85, n_column + 1.1)
139+
plt.ylim(-0.15, n_rows + 0.1)
140+
plt.xlim(0.85, n_columns + 1.1)
131141

132142
plt.axis("off")
133143

134144

145+
def ptable_elemental_ratio(
146+
formulas_a: List[str], formulas_b: List[str], log_scale: bool = False
147+
) -> None:
148+
"""Display the prevalence of each element in a materials dataset plotted as a
149+
heatmap over the periodic table. `formulas` xor `elem_counts` must be passed.
150+
151+
Adapted from https://github.com/kaaiian/ML_figures.
152+
153+
Args:
154+
formulas (list[str]): compositional strings, e.g. ["Fe2O3", "Bi2Te3"]
155+
elem_counts (pd.Series): Map from element symbol to prevalence count
156+
log_scale (bool, optional): Whether color map scale is log or linear.
157+
"""
158+
elem_counts_a = count_elements(formulas_a)
159+
elem_counts_b = count_elements(formulas_b)
160+
161+
elem_counts = elem_counts_a / elem_counts_b
162+
163+
cbar_title = "log(Element Ratio)" if log_scale else "Element Ratio"
164+
165+
ptable_elemental_prevalence(
166+
elem_counts=elem_counts, log_scale=log_scale, cbar_title=cbar_title
167+
)
168+
169+
text_style = dict(fontsize=14, fontweight="semibold")
170+
171+
plt.text(0.2, 2, "yellow: not in first list", **text_style)
172+
plt.text(0.2, 1.5, "orange: not in second list", **text_style)
173+
plt.text(0.2, 1, "gray: not in either", **text_style)
174+
175+
135176
def hist_elemental_prevalence(
136177
formulas: list,
137178
log_scale: bool = False,

readme.md

+2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ See [`mlmatrics/elements.py`](mlmatrics/elements.py).
6060
| ![ptable_elemental_prevalence](assets/ptable_elemental_prevalence.svg) | ![ptable_elemental_prevalence_log](assets/ptable_elemental_prevalence_log.svg) |
6161
| [`hist_elemental_prevalence(compositions)`](mlmatrics/elements.py) | [`hist_elemental_prevalence(compositions, log_scale=True, bar_values='count')`](mlmatrics/elements.py) |
6262
| ![hist_elemental_prevalence](assets/hist_elemental_prevalence.svg) | ![hist_elemental_prevalence_log_count](assets/hist_elemental_prevalence_log_count.svg) |
63+
| [`ptable_elemental_ratio(comps_a, comps_b)`](mlmatrics/elements.py) | [`ptable_elemental_ratio(comps_a, comps_b, log_scale=True)`](mlmatrics/elements.py) |
64+
| ![ptable_elemental_ratio](assets/ptable_elemental_ratio.svg) | ![ptable_elemental_ratio_log](assets/ptable_elemental_ratio_log.svg) |
6365

6466
## Uncertainty Calibration
6567

scripts/plot_all.py

+10
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
hist_elemental_prevalence,
1616
precision_recall_curve,
1717
ptable_elemental_prevalence,
18+
ptable_elemental_ratio,
1819
qq_gaussian,
1920
residual_hist,
2021
residual_vs_actual,
@@ -93,6 +94,7 @@ def savefig(filename: str) -> None:
9394

9495
# %% Elemental Plots
9596
mp_formulas = pd.read_csv(f"{ROOT}/data/mp-n_elements<2.csv").formula
97+
roost_formulas = pd.read_csv(f"{ROOT}/data/ex-ensemble-roost.csv").composition
9698

9799

98100
ptable_elemental_prevalence(mp_formulas)
@@ -103,6 +105,14 @@ def savefig(filename: str) -> None:
103105
savefig("ptable_elemental_prevalence_log")
104106

105107

108+
ptable_elemental_ratio(mp_formulas, roost_formulas)
109+
savefig("ptable_elemental_ratio")
110+
111+
112+
ptable_elemental_ratio(mp_formulas, roost_formulas, log_scale=True)
113+
savefig("ptable_elemental_ratio_log")
114+
115+
106116
hist_elemental_prevalence(mp_formulas, keep_top=15)
107117
savefig("hist_elemental_prevalence")
108118

tests/test_elements.py

+10
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
count_elements,
55
hist_elemental_prevalence,
66
ptable_elemental_prevalence,
7+
ptable_elemental_ratio,
78
)
89

910
compositions = pd.read_csv("data/mp-n_elements<2.csv").formula
11+
compositions_b = pd.read_csv("data/ex-ensemble-roost.csv").composition
1012

1113

1214
def test_ptable_elemental_prevalence():
@@ -36,3 +38,11 @@ def test_hist_elemental_prevalence_with_keep_top():
3638

3739
def test_hist_elemental_prevalence_with_bar_values_count():
3840
hist_elemental_prevalence(compositions, keep_top=10, bar_values="count")
41+
42+
43+
def test_ptable_elemental_ratio():
44+
ptable_elemental_ratio(compositions, compositions_b)
45+
46+
47+
def test_ptable_elemental_ratio_log_scale():
48+
ptable_elemental_ratio(compositions, compositions_b, log_scale=True)

0 commit comments

Comments
 (0)