Skip to content

Commit f73ba99

Browse files
committed
address coordination_hist TODO 'get the right y_max when bar_mode="stack"'
plus more coordination_hist unit tests
1 parent 56face7 commit f73ba99

File tree

3 files changed

+127
-35
lines changed

3 files changed

+127
-35
lines changed

.pre-commit-config.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ default_install_hook_types: [pre-commit, commit-msg]
88

99
repos:
1010
- repo: https://github.com/astral-sh/ruff-pre-commit
11-
rev: v0.7.0
11+
rev: v0.7.1
1212
hooks:
1313
- id: ruff
1414
args: [--fix]
@@ -17,7 +17,7 @@ repos:
1717
types_or: [python, jupyter]
1818

1919
- repo: https://github.com/pre-commit/mirrors-mypy
20-
rev: v1.12.0
20+
rev: v1.13.0
2121
hooks:
2222
- id: mypy
2323
additional_dependencies: [types-requests, types-PyYAML]
@@ -73,7 +73,7 @@ repos:
7373
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/(routes|figs).+\.(yaml|json)|changelog.md)$
7474

7575
- repo: https://github.com/pre-commit/mirrors-eslint
76-
rev: v9.12.0
76+
rev: v9.13.0
7777
hooks:
7878
- id: eslint
7979
stages: [manual] # TODO: skip eslint for now
@@ -88,6 +88,6 @@ repos:
8888
- typescript-eslint
8989

9090
- repo: https://github.com/RobertCraigie/pyright-python
91-
rev: v1.1.385
91+
rev: v1.1.386
9292
hooks:
9393
- id: pyright

pymatviz/coordination.py

+5-13
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ def coordination_hist(
161161
elif split_mode in (SplitMode.by_structure, SplitMode.by_structure_and_element):
162162
n_subplots = len(coord_data)
163163
else:
164+
if split_mode != SplitMode.none:
165+
raise ValueError(f"Invalid {split_mode=}")
164166
n_subplots = 1
165167

166168
n_cols = min(3, n_subplots)
@@ -196,7 +198,6 @@ def coordination_hist(
196198
f"or a custom dict."
197199
)
198200

199-
max_count = 0
200201
row, col = 1, 1
201202
is_single_structure = len(structures) == 1
202203
if annotate_bars is True:
@@ -211,7 +212,6 @@ def coordination_hist(
211212
data = struct_data[elem_symbol]
212213
counts = Counter(data["cn"])
213214
y = [counts.get(i, 0) for i in x_range]
214-
max_count = max(max_count, *y)
215215

216216
hover_text = [
217217
create_hover_text(
@@ -259,7 +259,6 @@ def coordination_hist(
259259
]
260260
counts = Counter(all_cn)
261261
y = [counts.get(i, 0) for i in x_range]
262-
max_count = max(max_count, *y)
263262

264263
hover_text = [
265264
create_hover_text(
@@ -283,7 +282,6 @@ def coordination_hist(
283282
for elem_symbol, data in struct_data.items():
284283
counts = Counter(data["cn"])
285284
y = [counts.get(i, 0) for i in x_range]
286-
max_count = max(max_count, *y)
287285

288286
hover_text = [
289287
create_hover_text(
@@ -322,7 +320,6 @@ def coordination_hist(
322320
for elem_symbol, data in struct_data.items():
323321
counts = Counter(data["cn"])
324322
y = [counts.get(i, 0) for i in x_range]
325-
max_count = max(max_count, *y)
326323

327324
hover_text = [
328325
create_hover_text(
@@ -358,21 +355,16 @@ def coordination_hist(
358355
col = 1
359356
row += 1
360357

361-
fig.update_layout(
362-
barmode=bar_mode,
363-
bargap=0.15,
364-
bargroupgap=0.1,
365-
)
358+
fig.update_layout(barmode=bar_mode, bargap=0.15, bargroupgap=0.1)
366359

367360
# start x-axis just below the smallest observed CN
368361
fig.update_xaxes(
369362
tick0=int(min_cn),
370363
dtick=1,
371364
range=[min_cn - 0.5, max_cn + 0.5],
372365
)
373-
# Ensure y-axis starts at 0 and extends 10% higher than the max count
374-
# TODO needs to a fix to get the right y_max when bar_mode="stack"
375-
y_max = max_count * 1.1
366+
y_max = fig.full_figure_for_development(warn=False).layout.yaxis.range[1]
367+
# Ensure y-axis starts at 0
376368
fig.update_yaxes(title="Count", range=[0, y_max])
377369

378370
# Add title "Coordination Number" to x axes of the last n_cols subplots

tests/test_coordination.py

+118-18
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,28 @@
1818
def test_coordination_hist_single_structure(structures: Sequence[Structure]) -> None:
1919
"""Test coordination_hist with a single structure."""
2020
fig = coordination_hist(structures[0])
21-
assert fig.data
2221
assert len(fig.data) == len({site.specie.symbol for site in structures[0]})
2322

23+
# Test y-axis range
24+
expected_y_max = max(max(trace.y) for trace in fig.data) # get max CN count
25+
y_min, y_max = fig.layout.yaxis.range
26+
assert y_min == 0
27+
assert y_max == pytest.approx(expected_y_max, rel=0.1)
28+
29+
# Test x-axis properties
30+
assert fig.layout.xaxis.tick0 is not None
31+
assert fig.layout.xaxis.dtick == 1
32+
assert fig.layout.xaxis.range[0] < min(trace.x[0] for trace in fig.data)
33+
34+
# Test y-axis properties
35+
assert fig.layout.yaxis.range[0] == 0
36+
assert fig.layout.yaxis.title.text == "Count"
37+
2438

2539
def test_coordination_hist_multiple_structures(structures: Sequence[Structure]) -> None:
2640
"""Test coordination_hist with multiple structures."""
2741
struct_dict = {f"Structure_{i}": struct for i, struct in enumerate(structures)}
2842
fig = coordination_hist(struct_dict)
29-
assert fig.data
3043
expected_traces = sum(
3144
len({site.specie.symbol for site in struct}) for struct in structures
3245
)
@@ -39,7 +52,6 @@ def test_coordination_hist_split_modes(
3952
) -> None:
4053
"""Test coordination_hist with different split modes."""
4154
fig = coordination_hist(structures[0], split_mode=split_mode)
42-
assert fig.data
4355

4456
if split_mode in (SplitMode.none, SplitMode.by_element):
4557
assert len(fig.data) == len({site.specie.symbol for site in structures[0]})
@@ -58,7 +70,6 @@ def test_coordination_hist_custom_strategy(
5870
) -> None:
5971
"""Test coordination_hist with a custom strategy."""
6072
fig = coordination_hist(structures[1], strategy=strategy)
61-
assert fig.data
6273
assert len(fig.data) == 3
6374
expected_max_x = {
6475
3.0: 9,
@@ -95,7 +106,6 @@ def test_coordination_hist_hover_data(structures: Sequence[Structure]) -> None:
95106
"""Test coordination_hist with custom hover data."""
96107
structures[0].add_site_property("test_property", list(range(len(structures[0]))))
97108
fig = coordination_hist(structures[0], hover_data=["test_property"])
98-
assert fig.data
99109
assert "test_property" in fig.data[0].hovertext[0]
100110

101111

@@ -107,15 +117,13 @@ def test_coordination_hist_element_color_scheme(
107117
colors = ("red", "blue", "green", "yellow", "purple", "orange", "pink", "brown")
108118
custom_colors = dict(zip(elements, colors, strict=False))
109119
fig = coordination_hist(structures[0], element_color_scheme=custom_colors)
110-
assert fig.data
111120
for trace in fig.data:
112121
assert trace.marker.color == custom_colors[trace.name.split(" - ")[1]]
113122

114123

115124
def test_coordination_hist_annotate_bars(structures: Sequence[Structure]) -> None:
116125
"""Test coordination_hist with bar annotations."""
117126
fig = coordination_hist(structures[0], annotate_bars=True)
118-
assert fig.data
119127
elements = {site.specie.symbol for site in structures[0]} | {""}
120128
for trace in fig.data:
121129
assert {trace.text} <= elements, f"Invalid text: {trace.text}"
@@ -125,21 +133,11 @@ def test_coordination_hist_bar_kwargs(structures: Sequence[Structure]) -> None:
125133
"""Test coordination_hist with custom bar kwargs."""
126134
bar_kwargs = {"opacity": 0.5, "width": 0.5}
127135
fig = coordination_hist(structures[0], bar_kwargs=bar_kwargs)
128-
assert fig.data
129136
for trace in fig.data:
130137
assert trace.opacity == 0.5
131138
assert trace.width == 0.5
132139

133140

134-
def test_coordination_hist_y_axis_range(structures: Sequence[Structure]) -> None:
135-
"""Test if y-axis range is 10% higher than the max count."""
136-
fig = coordination_hist(structures[0])
137-
assert fig.data
138-
max_count = max(max(trace.y) for trace in fig.data)
139-
expected_y_max = max_count * 1.1
140-
assert fig.layout.yaxis.range[1] == pytest.approx(expected_y_max, rel=1e-6)
141-
142-
143141
def test_coordination_hist_invalid_input() -> None:
144142
"""Test coordination_hist with invalid input."""
145143
with pytest.raises(TypeError):
@@ -169,7 +167,6 @@ def test_coordination_vs_cutoff_line(
169167
"""Test coordination_vs_cutoff_line function with different strategies."""
170168
# Test with a single structure
171169
fig = coordination_vs_cutoff_line(structures[0], strategy=strategy)
172-
assert fig.data
173170
assert len(fig.data) == len({site.specie.symbol for site in structures[0]})
174171

175172
# Test with multiple structures
@@ -232,3 +229,106 @@ def test_coordination_vs_cutoff_line_invalid_strategy() -> None:
232229
structure = Structure(Lattice.cubic(5), ["Si"], [[0, 0, 0]])
233230
with pytest.raises(TypeError, match="Invalid strategy="):
234231
coordination_vs_cutoff_line(structure, strategy="invalid")
232+
233+
234+
def test_coordination_hist_hover_text_formatting(
235+
structures: Sequence[Structure],
236+
) -> None:
237+
"""Test hover text formatting in coordination_hist."""
238+
# Add test property
239+
structures[0].add_site_property("test_prop", list(range(len(structures[0]))))
240+
241+
# Test with single structure
242+
fig = coordination_hist(structures[0], hover_data=["test_prop"])
243+
hover_text = fig.data[0].hovertext[0]
244+
assert "Element:" in hover_text
245+
assert "Coordination number:" in hover_text
246+
assert "test_prop:" in hover_text
247+
248+
# Test with multiple structures
249+
struct_dict = {"struct1": structures[0], "struct2": structures[1]}
250+
fig_multi = coordination_hist(struct_dict, hover_data=["test_prop"])
251+
hover_text_multi = fig_multi.data[0].hovertext[0]
252+
assert "Formula:" in hover_text_multi
253+
254+
255+
def test_coordination_hist_subplot_layout(structures: Sequence[Structure]) -> None:
256+
"""Test subplot layout in coordination_hist."""
257+
struct_dict = {f"s{i}": struct for i, struct in enumerate(structures[:3])}
258+
259+
# Test by_structure layout
260+
fig = coordination_hist(struct_dict, split_mode=SplitMode.by_structure)
261+
assert len(fig.layout.annotations) == len(struct_dict) # subplot titles
262+
263+
# Test by_element layout
264+
elements = {
265+
site.specie.symbol for struct in struct_dict.values() for site in struct
266+
}
267+
fig_elem = coordination_hist(struct_dict, split_mode=SplitMode.by_element)
268+
assert len(fig_elem.layout.annotations) == len(elements)
269+
270+
271+
def test_coordination_hist_bar_customization(structures: Sequence[Structure]) -> None:
272+
"""Test bar customization options in coordination_hist."""
273+
# Test bar width
274+
bar_kwargs = {"width": 0.5}
275+
fig = coordination_hist(structures[0], bar_kwargs=bar_kwargs)
276+
assert all(trace.width == 0.5 for trace in fig.data)
277+
278+
# Test bar opacity
279+
bar_kwargs = {"opacity": 0.7}
280+
fig = coordination_hist(structures[0], bar_kwargs=bar_kwargs)
281+
assert all(trace.opacity == 0.7 for trace in fig.data)
282+
283+
284+
def test_coordination_hist_color_schemes(structures: Sequence[Structure]) -> None:
285+
"""Test different color schemes in coordination_hist."""
286+
# Test JMOL colors
287+
fig_jmol = coordination_hist(
288+
structures[0], element_color_scheme=ElemColorScheme.jmol
289+
)
290+
291+
# Test VESTA colors
292+
fig_vesta = coordination_hist(
293+
structures[0], element_color_scheme=ElemColorScheme.vesta
294+
)
295+
296+
# Colors should be different between schemes
297+
assert any(
298+
t1.marker.color != t2.marker.color
299+
for t1, t2 in zip(fig_jmol.data, fig_vesta.data, strict=True)
300+
)
301+
302+
303+
def test_coordination_hist_invalid_elem_colors(structures: Sequence[Structure]) -> None:
304+
"""Test invalid color scheme handling."""
305+
with pytest.raises(ValueError, match="Invalid.*element_color_scheme"):
306+
coordination_hist(structures[0], element_color_scheme="invalid") # type: ignore[arg-type]
307+
308+
309+
def test_coordination_hist_invalid_hover_data(structures: Sequence[Structure]) -> None:
310+
"""Test invalid hover_data handling."""
311+
with pytest.raises(TypeError, match="Invalid hover_data"):
312+
coordination_hist(structures[0], hover_data=123) # type: ignore[arg-type]
313+
314+
315+
def test_coordination_hist_invalid_split_mode(structures: Sequence[Structure]) -> None:
316+
"""Test invalid split_mode handling."""
317+
split_mode = "invalid_mode"
318+
with pytest.raises(ValueError, match=f"Invalid {split_mode=}"):
319+
coordination_hist(structures[0], split_mode=split_mode)
320+
321+
322+
def test_coordination_hist_bar_annotations(structures: Sequence[Structure]) -> None:
323+
"""Test bar annotation functionality."""
324+
# Test default annotation settings
325+
fig = coordination_hist(structures[0], annotate_bars=True)
326+
assert all(trace.text is not None for trace in fig.data)
327+
328+
# Test custom annotation settings
329+
custom_annotations = {"size": 14, "color": "red"}
330+
fig = coordination_hist(structures[0], annotate_bars=custom_annotations)
331+
assert all(
332+
trace.textfont.size == 14 and trace.textfont.color == "red"
333+
for trace in fig.data
334+
)

0 commit comments

Comments
 (0)