Skip to content

Commit 1d88e82

Browse files
committed
fix tests
1 parent a89bae6 commit 1d88e82

File tree

6 files changed

+24
-30
lines changed

6 files changed

+24
-30
lines changed

assets/scripts/ptable_plotly/ptable_heatmap_splits_plotly.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def elem_color_scale(element: str, _val: float, split_idx: int) -> str:
7474
)
7575
fig.layout.title.update(text=title, x=0.4, y=0.8)
7676
fig.show()
77+
pmv.io.save_and_compress_svg(fig, "ptable-heatmap-splits-plotly-3-color-schemes")
7778

7879

7980
# %% Visualize multiple element color schemes on a split periodic table heatmap
@@ -89,3 +90,4 @@ def elem_color_scale(element: str, _val: float, split_idx: int) -> str:
8990
title = "<b>Element color schemes</b><br>left: VESTA, right: ALLOY"
9091
fig.layout.title.update(text=title, x=0.4, y=0.8)
9192
fig.show()
93+
pmv.io.save_and_compress_svg(fig, "ptable-heatmap-splits-plotly-2-color-schemes")

pymatviz/coordination/plotly.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from plotly.subplots import make_subplots
1515
from pymatgen.analysis.local_env import NearNeighbors
1616

17-
from pymatviz.colors import ELEM_COLORS_JMOL, ELEM_COLORS_VESTA
17+
from pymatviz.colors import ELEM_COLORS_JMOL
1818
from pymatviz.coordination.helpers import (
1919
CnSplitMode,
2020
calculate_average_cn,
@@ -145,14 +145,10 @@ def coordination_hist(
145145
if isinstance(element_color_scheme, dict):
146146
# Merge custom colors with default Jmol colors to get a complete color scheme
147147
element_colors = ELEM_COLORS_JMOL | element_color_scheme
148-
elif element_color_scheme == ElemColorScheme.jmol:
149-
element_colors = ELEM_COLORS_JMOL
150-
elif element_color_scheme == ElemColorScheme.vesta:
151-
element_colors = ELEM_COLORS_VESTA
152-
elif isinstance(element_color_scheme, dict):
153-
element_colors = element_color_scheme
148+
elif isinstance(element_color_scheme, ElemColorScheme):
149+
element_colors = element_color_scheme.color_map
154150
else:
155-
raise ValueError(
151+
raise TypeError(
156152
f"Invalid {element_color_scheme=}. Must be {', '.join(ElemColorScheme)} "
157153
f"or a custom dict."
158154
)
@@ -403,12 +399,10 @@ def coordination_vs_cutoff_line(
403399

404400
if isinstance(element_color_scheme, dict):
405401
element_colors = ELEM_COLORS_JMOL | element_color_scheme
406-
elif element_color_scheme == ElemColorScheme.jmol:
407-
element_colors = ELEM_COLORS_JMOL
408-
elif element_color_scheme == ElemColorScheme.vesta:
409-
element_colors = ELEM_COLORS_VESTA
402+
elif isinstance(element_color_scheme, ElemColorScheme):
403+
element_colors = element_color_scheme.color_map
410404
else:
411-
raise ValueError(
405+
raise TypeError(
412406
f"Invalid {element_color_scheme=}. Must be {', '.join(ElemColorScheme)} "
413407
"or a custom dict."
414408
)

pymatviz/enums.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from typing_extensions import Self
1616

17+
from pymatviz.typing import RgbColorType
18+
1719
# TODO: remove following definition of StrEnum once Python 3.11+
1820
if sys.version_info >= (3, 11):
1921
from enum import StrEnum
@@ -798,6 +800,13 @@ class ElemColorScheme(LabelEnum):
798800
# custom made for pymatviz
799801
alloy = "alloy", "Alloy", "High-contrast color scheme optimized for metal alloys"
800802

803+
@property
804+
def color_map(self) -> dict[str, RgbColorType]:
805+
"""Return map from element symbol to color."""
806+
import pymatviz.colors as pmv_colors
807+
808+
return getattr(pmv_colors, f"ELEM_COLORS_{self.value.upper()}")
809+
801810

802811
@unique
803812
class SiteCoords(LabelEnum):

pymatviz/structure_viz/mpl.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from pymatgen.core import Structure
2121
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
2222

23-
from pymatviz.colors import ELEM_COLORS_JMOL, ELEM_COLORS_VESTA
2423
from pymatviz.enums import ElemColorScheme, Key
2524
from pymatviz.structure_viz.helpers import (
2625
NO_SYM_MSG,
@@ -202,13 +201,11 @@ class used to plot chemical bonds. Allowed are edgecolor, facecolor, color,
202201
warnings.warn(NO_SYM_MSG, UserWarning, stacklevel=2)
203202

204203
# Get default colors
205-
if str(elem_colors) == str(ElemColorScheme.jmol):
206-
elem_colors = ELEM_COLORS_JMOL
207-
elif str(elem_colors) == str(ElemColorScheme.vesta):
208-
elem_colors = ELEM_COLORS_VESTA
204+
if isinstance(elem_colors, ElemColorScheme):
205+
elem_colors = elem_colors.color_map
209206
elif not isinstance(elem_colors, dict):
210207
valid_color_schemes = "', '".join(ElemColorScheme)
211-
raise ValueError(
208+
raise TypeError(
212209
f"colors must be a dict or one of ('{valid_color_schemes}')"
213210
)
214211

tests/coordination/test_plotly.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def test_coordination_hist_color_schemes(structures: Sequence[Structure]) -> Non
305305

306306
def test_coordination_hist_invalid_elem_colors(structures: Sequence[Structure]) -> None:
307307
"""Test invalid color scheme handling."""
308-
with pytest.raises(ValueError, match="Invalid.*element_color_scheme"):
308+
with pytest.raises(TypeError, match="Invalid.*element_color_scheme"):
309309
coordination_hist(structures[0], element_color_scheme="invalid") # type: ignore[arg-type]
310310

311311

tests/structure_viz/test_structure_viz_mpl.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from pymatgen.core import Structure
1111

1212
import pymatviz as pmv
13-
from pymatviz.colors import ELEM_COLORS_JMOL, ELEM_COLORS_VESTA
1413
from pymatviz.enums import ElemColorScheme, Key
1514

1615

@@ -163,23 +162,16 @@ def subplot_title(struct: Structure, key: str | int) -> str:
163162

164163
def test_structure_2d_color_warning() -> None:
165164
# Copernicium is not in the default color scheme
166-
elem_symbol = "Cn"
165+
elem_symbol = "Fl"
167166
struct = Structure(np.eye(3) * 5, [elem_symbol] * 2, coords=COORDS)
168167
fallback_color = "gray"
169168

170169
for elem_colors in ElemColorScheme:
171-
if elem_colors == ElemColorScheme.jmol:
172-
elem_color_symbols = ", ".join(ELEM_COLORS_JMOL)
173-
elif elem_colors == ElemColorScheme.vesta:
174-
elem_color_symbols = ", ".join(ELEM_COLORS_VESTA)
175-
else:
176-
raise ValueError(f"Unexpected {elem_colors=}")
177-
178170
with pytest.warns(
179171
UserWarning,
180172
match=f"{elem_symbol=} not in elem_colors, using "
181173
f"{fallback_color=}\nelement color palette specifies the "
182-
f"following elements: {elem_color_symbols}",
174+
f"following elements: {', '.join(elem_colors.color_map)}",
183175
):
184176
pmv.structure_2d(struct, elem_colors=elem_colors)
185177

0 commit comments

Comments
 (0)