Skip to content

Periodic table UX improvements #95

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Oct 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 71 additions & 34 deletions pymatviz/ptable.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import itertools
import sys
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Literal, get_args

Expand All @@ -15,7 +16,7 @@
from pandas.api.types import is_numeric_dtype, is_string_dtype
from pymatgen.core import Composition

from pymatviz.utils import df_ptable
from pymatviz.utils import df_ptable, pick_bw_for_contrast


if TYPE_CHECKING:
Expand Down Expand Up @@ -122,6 +123,10 @@ def count_elements(
srs = srs.reindex(df_ptable.index, fill_value=fill_value).rename("count")

if len(exclude_elements) > 0:
if isinstance(exclude_elements, str):
exclude_elements = [exclude_elements]
if isinstance(exclude_elements, tuple):
exclude_elements = list(exclude_elements)
try:
srs = srs.drop(exclude_elements)
except KeyError as exc:
Expand All @@ -140,16 +145,20 @@ def ptable_heatmap(
count_mode: CountMode = "composition",
cbar_title: str = "Element Count",
cbar_max: float | None = None,
cmap: str = "summer_r",
zero_color: str = "#DDD", # light gray
colorscale: str = "viridis",
infty_color: str = "lightskyblue",
na_color: str = "white",
heat_mode: Literal["value", "fraction", "percent"] | None = "value",
fmt: str | None = None,
cbar_fmt: str | None = None,
text_color: str | tuple[str, str] = "auto",
exclude_elements: Sequence[str] = (),
zero_color: str = "#eff", # light gray
zero_symbol: str | float = "-",
label_font_size: int = 16,
value_font_size: int = 12,
tile_size: float | tuple[float, float] = 0.9,
**kwargs: Any,
) -> plt.Axes:
"""Plot a heatmap across the periodic table of elements.

Expand All @@ -165,9 +174,9 @@ def ptable_heatmap(
cbar_max (float, optional): Maximum value of the colorbar range. Will be ignored
if smaller than the largest plotted value. For creating multiple plots with
identical color bars for visual comparison. Defaults to 0.
cmap (str, optional): Matplotlib colormap name to use. Defaults to "YlGn".
zero_color (str): Color to use for elements with value zero. Defaults to "#DDD"
(light gray).
colorscale (str, optional): Matplotlib colormap name to use. Defaults to
"viridis". See https://matplotlib.org/stable/users/explain/colors/colormaps
for available options.
infty_color: Color to use for elements with value infinity. Defaults to
"lightskyblue".
na_color: Color to use for elements with value infinity. Defaults to "white".
Expand All @@ -177,19 +186,28 @@ def ptable_heatmap(
"fraction" and "percent" can be used to make the colors in different
ptable_heatmap() (and ptable_heatmap_ratio()) plots comparable.
fmt (str): f-string format option for tile values. Defaults to ".1%"
(1 decimal place) if heat_mode="percent" else ".3g".
(1 decimal place) if heat_mode="percent" else ".3g". Use e.g. ",.0f" to
format values with thousands separators and no decimal places.
cbar_fmt (str): f-string format option to set a different colorbar tick
label format. Defaults to the above fmt.
text_color (str | tuple[str, str]): What color to use for element symbols and
heat labels. Must be a valid color name, or a 2-tuple of names, one to use
for the upper half of the color scale, one for the lower half. The special
value 'auto' applies 'black' on the lower and 'white' on the upper half of
the color scale. Defaults to "auto".
value "auto" applies "black" on the lower and "white" on the upper half of
the color scale. "auto_reverse" does the opposite. Defaults to "auto".
exclude_elements (list[str]): Elements to exclude from the heatmap. E.g. if
oxygen overpowers everything, you can try log=True or
exclude_elements=['O']. Defaults to ().
exclude_elements=["O"]. Defaults to ().
zero_color (str): Color to use for elements with value zero. Defaults to "#eff"
(light gray).
zero_symbol (str | float): Symbol to use for elements with value zero.
Defaults to "-".
label_font_size (int): Font size for element symbols. Defaults to 16.
value_font_size (int): Font size for heat values. Defaults to 12.
tile_size (float | tuple[float, float]): Size of each tile in the periodic
table as a fraction of available space before touching neighboring tiles.
1 or (1, 1) means no gaps between tiles. Defaults to 0.9.
**kwargs: Additional keyword arguments passed to plt.figure().

Returns:
ax: matplotlib Axes with the heatmap.
Expand All @@ -198,6 +216,9 @@ def ptable_heatmap(
raise ValueError(
"Combining log color scale and heat_mode='fraction'/'percent' unsupported"
)
if "cmap" in kwargs:
colorscale = kwargs.pop("cmap")
print("cmap argument is deprecated, use colorscale instead.", file=sys.stderr)

values = count_elements(values, count_mode, exclude_elements)

Expand All @@ -209,25 +230,30 @@ def ptable_heatmap(
values /= clean_vals.sum()
clean_vals /= clean_vals.sum() # normalize as well for norm.autoscale() below

color_map = get_cmap(cmap)
color_map = get_cmap(colorscale)

n_rows = df_ptable.row.max()
n_columns = df_ptable.column.max()

# TODO can we pass as a kwarg and still ensure aspect ratio respected?
fig = plt.figure(figsize=(0.75 * n_columns, 0.7 * n_rows))
fig = plt.figure(figsize=(0.75 * n_columns, 0.7 * n_rows), **kwargs)

ax = ax or plt.gca()

rw = rh = 0.9 # rectangle width/height
if isinstance(tile_size, (float, int)):
tile_width = tile_height = tile_size
else:
tile_width, tile_height = tile_size

norm = LogNorm() if log else Normalize()

norm.autoscale(clean_vals.to_numpy())
if cbar_max is not None:
norm.vmax = cbar_max

text_style = dict(horizontalalignment="center", fontsize=16, fontweight="semibold")
text_style = dict(
horizontalalignment="center", fontsize=label_font_size, fontweight="semibold"
)

for symbol, row, column, *_ in df_ptable.itertuples():
row = n_rows - row # invert row count to make periodic table right side up
Expand Down Expand Up @@ -259,7 +285,9 @@ def ptable_heatmap(
label = label.replace("e+0", "e")
if row < 3: # vertical offset for lanthanide + actinide series
row += 0.5
rect = Rectangle((column, row), rw, rh, edgecolor="gray", facecolor=color)
rect = Rectangle(
(column, row), tile_width, tile_height, edgecolor="gray", facecolor=color
)

if heat_mode is None:
# no value to display below in colored rectangle so center element symbol
Expand All @@ -268,22 +296,30 @@ def ptable_heatmap(
if symbol in exclude_elements:
text_clr = "black"
elif text_color == "auto":
text_clr = "white" if norm(tile_value) > 0.5 else "black"
if isinstance(color, (tuple, list)) and len(color) >= 3:
# treat color as RGB tuple and choose black or white text for contrast
text_clr = pick_bw_for_contrast(color)
else:
text_clr = "black"
elif isinstance(text_color, (tuple, list)):
text_clr = text_color[0] if norm(tile_value) > 0.5 else text_color[1]
else:
text_clr = text_color

plt.text(
column + 0.5 * rw, row + 0.5 * rh, symbol, color=text_clr, **text_style
column + 0.5 * tile_width,
row + 0.5 * tile_height,
symbol,
color=text_clr,
**text_style,
)

if heat_mode is not None:
plt.text(
column + 0.5 * rw,
row + 0.1 * rh,
column + 0.5 * tile_width,
row + 0.1 * tile_height,
label,
fontsize=10,
fontsize=value_font_size,
horizontalalignment="center",
color=text_clr,
)
Expand All @@ -297,7 +333,7 @@ def ptable_heatmap(
# format major and minor ticks
cb_ax.tick_params(which="both", labelsize=14, width=1)

mappable = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
mappable = plt.cm.ScalarMappable(norm=norm, cmap=colorscale)

def tick_fmt(val: float, _pos: int) -> str:
# val: value at color axis tick (e.g. 10.0, 20.0, ...)
Expand Down Expand Up @@ -327,7 +363,7 @@ def ptable_heatmap_ratio(
count_mode: CountMode = "composition",
normalize: bool = False,
cbar_title: str = "Element Ratio",
not_in_numerator: tuple[str, str] = ("#DDD", "gray: not in 1st list"),
not_in_numerator: tuple[str, str] = ("#eff", "gray: not in 1st list"),
not_in_denominator: tuple[str, str] = ("lightskyblue", "blue: not in 2nd list"),
not_in_either: tuple[str, str] = ("white", "white: not in either"),
**kwargs: Any,
Expand All @@ -350,7 +386,7 @@ def ptable_heatmap_ratio(
cbar_title (str): Title for the color bar. Defaults to "Element Ratio".
not_in_numerator (tuple[str, str]): Color and legend description used for
elements missing from numerator. Defaults to
('#DDD', 'gray: not in 1st list').
('#eff', 'gray: not in 1st list').
not_in_denominator (tuple[str, str]): See not_in_numerator. Defaults to
('lightskyblue', 'blue: not in 2nd list').
not_in_either (tuple[str, str]): See not_in_numerator. Defaults to
Expand Down Expand Up @@ -396,7 +432,7 @@ def ptable_heatmap_plotly(
precision: str | None = None,
hover_props: Sequence[str] | dict[str, str] | None = None,
hover_data: dict[str, str | int | float] | pd.Series | None = None,
font_colors: Sequence[str] = ("#eee", "black"),
font_colors: Sequence[str] = ("#eff", "black"),
gap: float = 5,
font_size: int | None = None,
bg_color: str | None = None,
Expand Down Expand Up @@ -448,7 +484,7 @@ def ptable_heatmap_plotly(
the hover tooltip on a new line below the element name"). Defaults to None.
font_colors (list[str]): One color name or two for [min_color, max_color].
min_color is applied to annotations with heatmap values less than
(max_val - min_val) / 2. Defaults to ("#eee", "black") meaning light text
(max_val - min_val) / 2. Defaults to ("#eff", "black") meaning light text
for low values and dark text for high values. May need to be manually
swapped depending on the colorscale.
gap (float): Gap in pixels between tiles of the periodic table. Defaults to 5.
Expand Down Expand Up @@ -492,7 +528,7 @@ def ptable_heatmap_plotly(
raise ValueError(f"{cscale_range=} should have length 2")

if isinstance(colorscale, (str, type(None))):
colorscale = px.colors.get_colorscale(colorscale or "Pinkyl")
colorscale = px.colors.get_colorscale(colorscale or "viridis")
elif isinstance(colorscale, Sequence) and isinstance(
colorscale[0], (str, list, tuple)
):
Expand All @@ -516,7 +552,7 @@ def ptable_heatmap_plotly(
raise ValueError(
"Log color scale requires all heat map values to be > 1 since values <= 1 "
f"map to negative log values which throws off the color scale. Got "
f"{smaller_1.size} values <= 1: {list(smaller_1)}"
f"{smaller_1.size} values <= 1: {dict(smaller_1)}"
)

if heat_mode in ("fraction", "percent"):
Expand Down Expand Up @@ -601,13 +637,14 @@ def ptable_heatmap_plotly(
# colors on empty tiles of the periodic table
heatmap_values[row][col] = color_val

# TODO: see if this ugly code can be handed off to plotly, looks like not atm
# https://github.com/janosh/pymatviz/issues/52
# https://github.com/plotly/documentation/issues/1611
log_cbar = dict(
tickvals=np.arange(int(np.log10(values.max())) + 1),
ticktext=10 ** np.arange(int(np.log10(values.max())) + 1),
)
if log:
# TODO: see if this ugly code can be handed off to plotly, looks like not atm
# https://github.com/janosh/pymatviz/issues/52
# https://github.com/plotly/documentation/issues/1611
log_cbar = dict(
tickvals=np.arange(int(np.log10(values.max())) + 1),
ticktext=10 ** np.arange(int(np.log10(values.max())) + 1),
)
if isinstance(font_colors, str):
font_colors = [font_colors]
if cscale_range == (None, None):
Expand Down
31 changes: 31 additions & 0 deletions pymatviz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,3 +468,34 @@ def patch_dict(
patched.update(updates)

yield patched


def luminance(color: tuple[float, float, float]) -> float:
"""Compute the luminance of a color as in https://stackoverflow.com/a/596243.

Args:
color (tuple[float, float, float]): RGB color tuple with values in [0, 1].

Returns:
float: Luminance of the color.
"""
red, green, blue, *_ = color # alpha = 1 - transparency
return 0.299 * red + 0.587 * green + 0.114 * blue


def pick_bw_for_contrast(
color: tuple[float, float, float], text_color_threshold: float = 0.7
) -> str:
"""Choose black or white text color for a given background color based on
luminance.

Args:
color (tuple[float, float, float]): RGB color tuple with values in [0, 1].
text_color_threshold (float, optional): Luminance threshold for choosing
black or white text color. Defaults to 0.7.

Returns:
str: "black" or "white" depending on the luminance of the background color.
"""
light_bg = luminance(color) > text_color_threshold
return "black" if light_bg else "white"
6 changes: 5 additions & 1 deletion tests/test_ptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_ptable_heatmap(
ptable_heatmap(glass_formulas, log=True)

# custom color map
ptable_heatmap(glass_formulas, log=True, cmap="summer")
ptable_heatmap(glass_formulas, log=True, colorscale="summer")

# heat_mode normalized to total count
ptable_heatmap(glass_formulas, heat_mode="fraction")
Expand Down Expand Up @@ -166,6 +166,10 @@ def test_ptable_heatmap(
cbar_1st_label = ax.child_axes[0].get_xticklabels()[0].get_text()
assert cbar_1st_label == "0.000%"

# test tile_size
ptable_heatmap(df_ptable.atomic_mass, tile_size=1)
ptable_heatmap(df_ptable.atomic_mass, tile_size=(0.9, 1))


def test_ptable_heatmap_ratio(
steel_formulas: list[str],
Expand Down
39 changes: 38 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from copy import deepcopy
from datetime import datetime
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal
from unittest.mock import patch

import matplotlib.pyplot as plt
Expand All @@ -19,7 +19,9 @@
bin_df_cols,
df_to_arrays,
get_crystal_sys,
luminance,
patch_dict,
pick_bw_for_contrast,
)
from tests.conftest import y_pred, y_true

Expand Down Expand Up @@ -331,3 +333,38 @@ def test_annotate_bars(
ImportError, match=err_msg
):
annotate_bars(ax, adjust_test_pos=True)


@pytest.mark.parametrize(
"color,expected",
[
((0, 0, 0), 0), # Black
((1, 1, 1), 1), # White
((0.5, 0.5, 0.5), 0.5), # Gray
((1, 0, 0), 0.299), # Red
((0, 1, 0), 0.587), # Green
((0, 0, 1, 0.3), 0.114), # Blue with alpha (should be ignored)
],
)
def test_luminance(color: tuple[float, float, float], expected: float) -> None:
assert luminance(color) == pytest.approx(expected, 0.001)


@pytest.mark.parametrize(
"color,text_color_threshold,expected",
[
((1.0, 1.0, 1.0), 0.7, "black"), # White
((0, 0, 0), 0.7, "white"), # Black
((0.5, 0.5, 0.5), 0.7, "white"), # Gray
((0.5, 0.5, 0.5), 0, "black"), # Gray with low threshold
((1, 0, 0, 0.3), 0.7, "white"), # Red with alpha (should be ignored)
((0, 1, 0), 0.7, "white"), # Green
((0, 0, 1.0), 0.4, "white"), # Blue with low threshold
],
)
def test_pick_bw_for_contrast(
color: tuple[float, float, float],
text_color_threshold: float,
expected: Literal["black", "white"],
) -> None:
assert pick_bw_for_contrast(color, text_color_threshold) == expected