-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathptable_heatmap_splits_plotly.py
157 lines (134 loc) · 5.48 KB
/
ptable_heatmap_splits_plotly.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# %%
import itertools
from collections.abc import Callable, Sequence
import numpy as np
from pymatgen.core import Element
import pymatviz as pmv
import pymatviz.colors as pmv_colors
from pymatviz.typing import RgbColorType
np_rng = np.random.default_rng(seed=0)  # fixed seed keeps the example data reproducible


# %% Examples of ptable_heatmap_splits_plotly with different numbers of splits
split_counts = range(2, 5)
orientations = ("diagonal", "horizontal", "vertical", "grid")

for idx, (n_splits, orientation) in enumerate(
    itertools.product(split_counts, orientations)
):
    if idx > 5:  # running all n_split/orientation combos takes long
        break
    # "grid" is only demonstrated with exactly 4 splits; with the cap above, no
    # grid combination is reached in this loop
    if orientation == "grid" and n_splits != 4:
        continue

    # Example 1: Single colorscale with single colorbar
    # One array of n_splits random integers in [10, 20) per element
    data_dict = {
        elem.symbol: np_rng.integers(10, 20, size=n_splits) for elem in Element
    }

    cbar_title = f"Periodic Table Heatmap with {n_splits}-fold split"
    fig = pmv.ptable_heatmap_splits_plotly(
        data=data_dict,
        orientation=orientation,  # type: ignore[arg-type]
        colorscale="RdYlBu",  # Single colorscale will be used for all splits
        colorbar={"title": cbar_title},
    )
    fig.show()

    # Example 2: Multiple colorscales with vertical colorbars
    colorscales = ["Viridis", "Plasma", "Inferno", "Magma"][:n_splits]
    colorbars = [
        {"title": f"Metric {metric_num}", "orientation": "v"}
        for metric_num in range(1, n_splits + 1)
    ]
    fig = pmv.ptable_heatmap_splits_plotly(
        data=data_dict,
        orientation=orientation,  # type: ignore[arg-type]
        colorscale=colorscales,
        colorbar=colorbars,
    )
    fig.show()

    # Example 3: Multiple colorscales with horizontal colorbars
    # Use sequential colors from the same family
    sequential_colors = [
        [(0, "rgb(255,220,220)"), (1, "rgb(255,0,0)")],  # Red scale
        [(0, "rgb(220,220,255)"), (1, "rgb(0,0,255)")],  # Blue scale
        [(0, "rgb(220,255,220)"), (1, "rgb(0,255,0)")],  # Green scale
        [(0, "rgb(255,220,255)"), (1, "rgb(128,0,128)")],  # Purple scale
    ][:n_splits]
    colorbars = [
        {"title": f"Metric {metric_num}", "orientation": "h"}
        for metric_num in range(1, n_splits + 1)
    ]
    fig = pmv.ptable_heatmap_splits_plotly(
        data=data_dict,
        orientation=orientation,  # type: ignore[arg-type]
        colorscale=sequential_colors,
        colorbar=colorbars,
    )
    fig.show()

    # if orientation == "diagonal":
    #     pmv.io.save_and_compress_svg(fig, f"ptable-heatmap-splits-plotly-{n_splits}")
# %% Example 4: Custom color schemes with multiple colorbars
def make_color_scale(
    color_schemes: Sequence[dict[str, RgbColorType]],
) -> Callable[[str, float, int], str]:
    """Return a callable that colors each element by the palette of its split.

    Args:
        color_schemes: One palette per split, indexed by split. Each palette maps
            an element symbol to a color, given either as a sequence of three
            channel values or as a string.

    Returns:
        A function ``(element, value, split_idx) -> str`` that returns a color
        string. The value argument is ignored. Elements missing from the selected
        palette are colored gray.
    """
    fallback_rgb = (128, 128, 128)  # gray for elements without a defined color

    def elem_color_scale(element: str, _val: float, split_idx: int) -> str:
        color = color_schemes[split_idx].get(element, fallback_rgb)
        if isinstance(color, str):
            # A bare "(r, g, b)" triple still needs the "rgb" prefix; any other
            # string (e.g. "#ff0000" or "rgb(...)") is already a complete color.
            return f"rgb{color}" if color.startswith("(") else color
        # Format each channel with str() instead of relying on the container's
        # repr, which yields "rgb[...]" for lists and "np.float64(...)" channels
        # for NumPy >= 2 scalars.
        # NOTE(review): channel values are passed through unscaled - confirm the
        # palettes already use the range the plotting backend expects.
        return f"rgb({', '.join(str(channel) for channel in color)})"

    return elem_color_scale
# One palette per split; the color scale picks palettes_3[split_idx]
palettes_3 = (
    pmv_colors.ELEM_COLORS_ALLOY,
    pmv_colors.ELEM_COLORS_JMOL,
    pmv_colors.ELEM_COLORS_VESTA,
)
scheme_names = ("ALLOY", "JMOL", "VESTA")
hover_text = "top left: JMOL<br>top right: VESTA, bottom: ALLOY"

# Example with vertical colorbars
fig = pmv.ptable_heatmap_splits_plotly(
    # Use dummy values for all elements (one placeholder per palette; the color
    # scale ignores the value and only uses the split index)
    {str(elem): list(range(len(palettes_3))) for elem in Element},
    orientation="diagonal",  # could also use "grid"
    colorscale=make_color_scale(palettes_3),
    colorbar=[
        {"title": f"{name} Colors", "orientation": "v"} for name in scheme_names
    ],
    # Same hover text for every element
    hover_data={str(elem): hover_text for elem in Element},
)

title = (
    "<b>Element color schemes</b><br>"
    "top left: JMOL, top right: VESTA, bottom: ALLOY"
)
fig.layout.title.update(text=title, x=0.4, y=0.8)
fig.show()
# pmv.io.save_and_compress_svg(fig, "ptable-heatmap-splits-plotly-3-color-schemes")
# %% Example 5: Two color schemes with horizontal colorbars
# The color scale selects palettes_2[split_idx], so the palette order decides which
# split shows which scheme. List VESTA first so the palettes line up with the
# colorbar titles, hover text and figure title below, which all name VESTA first.
# NOTE(review): assumes split 0 is drawn on the left for orientation="vertical" and
# that colorbar i belongs to split i - confirm against the rendered figure.
palettes_2 = (pmv_colors.ELEM_COLORS_VESTA, pmv_colors.ELEM_COLORS_ALLOY)

fig = pmv.ptable_heatmap_splits_plotly(
    # Use dummy values for all elements (one placeholder per palette; the color
    # scale ignores the value and only uses the split index)
    {str(elem): list(range(len(palettes_2))) for elem in Element},
    orientation="vertical",
    colorscale=make_color_scale(palettes_2),
    colorbar=[
        dict(title="VESTA Colors", orientation="h"),
        dict(title="ALLOY Colors", orientation="h"),
    ],
    # Same hover text for every element
    hover_data=dict.fromkeys(map(str, Element), "left: VESTA<br>right: ALLOY"),
)

title = "<b>Element color schemes</b><br>left: VESTA, right: ALLOY"
fig.layout.title.update(text=title, x=0.4, y=0.8)
fig.show()
# pmv.io.save_and_compress_svg(fig, "ptable-heatmap-splits-plotly-2-color-schemes")
# %% Example 6: Mixed colorbar orientations
# Create data with 4 splits: one array of 4 random integers in [0, 100) per element
data_dict = {
    elem.symbol: np_rng.integers(0, 100, size=4) for elem in Element
}

# Use colorscale names directly, one per split
grid_colorscales = ["Viridis", "Plasma", "Inferno", "Magma"]

# Two vertical colorbars with explicit position and length, then two horizontal
# colorbars left at their default placement
grid_colorbars = [
    {"title": "Top Left", "orientation": "v", "x": -0.05, "y": 0, "len": 0.4},
    {"title": "Top Right", "orientation": "v", "x": 0.05, "y": 0, "len": 0.4},
    {"title": "Bottom Left", "orientation": "h"},
    {"title": "Bottom Right", "orientation": "h"},
]

# Use grid orientation with 4 different colorscales and mixed colorbar orientations
fig = pmv.ptable_heatmap_splits_plotly(
    data=data_dict,
    orientation="grid",
    colorscale=grid_colorscales,
    colorbar=grid_colorbars,
)

title = "<b>Mixed Colorbar Orientations</b><br>Grid Layout Example"
fig.layout.title.update(text=title, x=0.4, y=0.9)
fig.show()
# pmv.io.save_and_compress_svg(fig, "ptable-heatmap-splits-plotly-mixed-colorbars")