Skip to content

Commit 0913416

Browse files
committed
fix ptable_elemental_prevalence log scale colorbar
1 parent 6a9ee75 commit 0913416

5 files changed

+44
-76
lines changed

assets/ptable_elemental_prevalence.svg

+1-1
Loading

assets/ptable_elemental_prevalence_log.svg

+1-1
Loading

assets/ptable_elemental_ratio.svg

+1-1
Loading

assets/ptable_elemental_ratio_log.svg

+1-1
Loading

mlmatrics/elements.py

+40-72
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import pandas as pd
66
from matplotlib.axes import Axes
77
from matplotlib.cm import get_cmap
8-
from matplotlib.colors import Normalize
8+
from matplotlib.colors import LogNorm, Normalize
99
from matplotlib.patches import Rectangle
1010
from pymatgen import Composition
1111

@@ -29,7 +29,7 @@ def count_elements(formulas: list) -> pd.Series:
2929

3030
# ensure all elements are present in returned Series (with count zero if they
3131
# weren't in formulas)
32-
ptable = pd.read_csv(ROOT + "/data/periodic_table.csv")
32+
ptable = pd.read_csv(f"{ROOT}/data/periodic_table.csv")
3333
# fill_value=0 required as max(NaN, any int) = NaN
3434
srs = srs.combine(pd.Series(0, index=ptable.symbol), max, fill_value=0)
3535
return srs
@@ -39,6 +39,7 @@ def ptable_elemental_prevalence(
3939
formulas: List[str] = None,
4040
elem_counts: pd.Series = None,
4141
log: bool = False,
42+
ax: Axes = None,
4243
cbar_title: str = None,
4344
cmap: str = "YlGn",
4445
) -> None:
@@ -51,6 +52,7 @@ def ptable_elemental_prevalence(
5152
formulas (list[str]): compositional strings, e.g. ["Fe2O3", "Bi2Te3"]
5253
elem_counts (pd.Series): Map from element symbol to prevalence count
5354
log (bool, optional): Whether color map scale is log or linear.
55+
ax (Axes, optional): plt axes. Defaults to None.
5456
cbar_title (str, optional): Optional Title for colorbar. Defaults to None.
5557
cmap (str, optional): Matplotlib colormap name to use. Defaults to "YlGn".
5658
@@ -65,38 +67,40 @@ def ptable_elemental_prevalence(
6567
if formulas is not None:
6668
elem_counts = count_elements(formulas)
6769

68-
ptable = pd.read_csv(ROOT + "/data/periodic_table.csv")
70+
ptable = pd.read_csv(f"{ROOT}/data/periodic_table.csv")
6971
cmap = get_cmap(cmap)
7072

7173
n_rows = ptable.row.max()
7274
n_columns = ptable.column.max()
7375

74-
# TODO can we pass as as a kwarg and still ensure aspect ratio respected?
75-
plt.figure(figsize=(n_columns, n_rows))
76+
# TODO can we pass as a kwarg and still ensure aspect ratio respected?
77+
fig = plt.figure(figsize=(0.75 * n_columns, 0.7 * n_rows))
78+
79+
if ax is None:
80+
ax = plt.gca()
7681

7782
rw = rh = 0.9 # rectangle width/height
7883
min_count = elem_counts.min()
84+
# replace([np.inf, -np.inf], np.nan) deals with missing or zero-values when
85+
# plotting ptable_elemental_ratio
7986
max_count = elem_counts.replace([np.inf, -np.inf], np.nan).dropna().max()
8087

81-
norm = Normalize(
82-
vmin=0 if log else min_count,
83-
vmax=np.log10(max_count) if log else max_count,
84-
)
88+
if log:
89+
norm = LogNorm(max(min_count, 1), max_count)
90+
else:
91+
norm = Normalize(min_count, max_count)
8592

8693
text_style = dict(
8794
horizontalalignment="center",
8895
verticalalignment="center",
89-
fontsize=20,
96+
fontsize=15,
9097
fontweight="semibold",
9198
)
9299

93100
for symbol, row, column, _ in ptable.values:
94101
row = n_rows - row
95102
count = elem_counts[symbol]
96103

97-
if log and count > 0:
98-
count = np.log10(count)
99-
100104
# inf or NaN are expected when passing in elem_counts from ptable_elemental_ratio
101105
if count == 0: # not in formulas_a
102106
color = "silver"
@@ -105,54 +109,30 @@ def ptable_elemental_prevalence(
105109
elif pd.isna(count):
106110
color = "white" # not in either formulas_a nor formulas_b
107111
else:
108-
color = cmap(norm(count)) if count != 0 else "silver"
112+
color = cmap(norm(count)) if count > 0 else "silver"
109113

110114
if row < 3:
111115
row += 0.5
112116
rect = Rectangle((column, row), rw, rh, edgecolor="gray", facecolor=color)
113117

114118
plt.text(column + rw / 2, row + rw / 2, symbol, **text_style)
115119

116-
plt.gca().add_patch(rect)
117-
118-
# color bar
119-
granularity = 20 # number of cells in the color bar
120-
bar_xpos, bar_ypos = 3.5, 7.8 # bar position
121-
bar_width, bar_height = 9, 0.35
122-
cell_width = bar_width / granularity
123-
124-
for idx in np.arange(granularity) + (1 if log else 0):
125-
value = idx * max_count / (granularity - 1)
126-
if log and value > 0:
127-
value = np.log10(value)
128-
129-
color = cmap(norm(value)) if value != 0 else "silver"
130-
x_loc = (idx - (1 if log else 0)) / granularity * bar_width + bar_xpos
131-
rect = Rectangle(
132-
(x_loc, bar_ypos), cell_width, bar_height, edgecolor="gray", facecolor=color
133-
)
120+
ax.add_patch(rect)
134121

135-
if idx in np.linspace(0, granularity, granularity // 4) + (
136-
1 if log else 0
137-
) or idx == (granularity - (0 if log else 1)):
138-
text = f"{value:.1f}" if log else f"{value:.0f}"
139-
plt.text(x_loc + cell_width / 2, bar_ypos - 0.4, text, **text_style)
122+
# colorbar position and size: [bar_xpos, bar_ypos, bar_width, bar_height]
123+
# anchored at lower left corner
124+
cb_ax = ax.inset_axes([0.18, 0.8, 0.42, 0.05], transform=ax.transAxes)
125+
# format major and minor ticks
126+
cb_ax.tick_params(which="both", labelsize=14, width=1)
140127

141-
plt.gca().add_patch(rect)
142-
143-
if log:
144-
plt.text(
145-
bar_xpos + cell_width / 2, bar_ypos + 0.6, int(min_count), **text_style
146-
)
147-
plt.text(x_loc + cell_width / 2, bar_ypos + 0.6, int(max_count), **text_style)
148-
149-
if cbar_title is None:
150-
cbar_title = "log(Element Count)" if log else "Element Count"
151-
152-
plt.text(bar_xpos + bar_width / 2, bar_ypos + 0.7, cbar_title, **text_style)
128+
cbar = fig.colorbar(
129+
plt.cm.ScalarMappable(norm=norm, cmap=cmap), orientation="horizontal", cax=cb_ax
130+
)
131+
cbar.outline.set_linewidth(1)
132+
cb_ax.set_title(cbar_title or "Element Count", pad=15, **text_style)
153133

154-
plt.ylim(-0.15, n_rows + 0.1)
155-
plt.xlim(0.85, n_columns + 1.1)
134+
plt.ylim(0.3, n_rows + 0.1)
135+
plt.xlim(0.9, n_columns + 1)
156136

157137
plt.axis("off")
158138

@@ -204,30 +184,18 @@ def ptable_elemental_ratio(
204184

205185
elem_counts = elem_counts_a / elem_counts_b
206186

207-
cbar_title = "log(Element Ratio)" if log else "Element Ratio"
208-
209187
ptable_elemental_prevalence(
210-
elem_counts=elem_counts, log=log, cbar_title=cbar_title, **kwargs
188+
elem_counts=elem_counts, log=log, cbar_title="Element Ratio", **kwargs
211189
)
212190

213-
text_style = {"fontsize": 14, "fontweight": "semibold"}
214-
215-
# add key for the colours
216-
plt.text(
217-
0.8,
218-
2,
219-
"gray: not in st list",
220-
**text_style,
221-
bbox={"facecolor": "silver", "linewidth": 0},
222-
)
223-
plt.text(
224-
0.8,
225-
1.5,
226-
"blue: not in 2nd list",
227-
**text_style,
228-
bbox={"facecolor": "lightskyblue", "linewidth": 0},
229-
)
230-
plt.text(0.8, 1, "white: not in either", **text_style)
191+
# add legend for the colours
192+
for y_pos, label, color, txt in [
193+
[0.4, "white", "white", "not in either"],
194+
[1.1, "blue", "lightskyblue", "not in 2nd list"],
195+
[1.8, "gray", "silver", "not in 1st list"],
196+
]:
197+
bbox = {"facecolor": color, "edgecolor": "gray"}
198+
plt.text(0.8, y_pos, f"{label}: {txt}", fontsize=12, bbox=bbox)
231199

232200

233201
def hist_elemental_prevalence(

0 commit comments

Comments
 (0)