Skip to content

Commit e9697fc

Browse files
committed
when passed a series, plot_histogram now use series name as x-axis title
fix add_ecdf_line for go.Bar() and throw clear error message when passing unsupported plotly traces (only Histogram, Scatter, and Bar are supported) fix matplotlib text assertion in test_add_best_fit_line, needed to use AnchoredText instead of Annotation
1 parent 890970b commit e9697fc

File tree

8 files changed

+105
-23
lines changed

8 files changed

+105
-23
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ default_install_hook_types: [pre-commit, commit-msg]
88

99
repos:
1010
- repo: https://github.com/astral-sh/ruff-pre-commit
11-
rev: v0.4.8
11+
rev: v0.4.9
1212
hooks:
1313
- id: ruff
1414
args: [--fix]
@@ -73,7 +73,7 @@ repos:
7373
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/(routes|figs).+\.(yaml|json)|changelog.md)$
7474

7575
- repo: https://github.com/pre-commit/mirrors-eslint
76-
rev: v9.4.0
76+
rev: v9.5.0
7777
hooks:
7878
- id: eslint
7979
types: [file]
@@ -87,6 +87,6 @@ repos:
8787
- typescript-eslint
8888

8989
- repo: https://github.com/RobertCraigie/pyright-python
90-
rev: v1.1.366
90+
rev: v1.1.367
9191
hooks:
9292
- id: pyright

pymatviz/enums.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,12 @@ class Key(LabelEnum):
121121
element = "element", "Element"
122122
energy = "energy", f"Energy {eV}"
123123
energy_per_atom = "energy_per_atom", f"Energy {eV_per_atom}"
124+
# PBE, PBEsol, PBE+U, r2SCAN, etc.
125+
dft_functional = "dft_functional", "DFT Functional"
126+
uncorrected_energy_per_atom = (
127+
"uncorrected_energy_per_atom",
128+
f"Uncorrected Energy {eV_per_atom}",
129+
)
124130
cohesive_energy_per_atom = (
125131
"cohesive_energy_per_atom",
126132
f"Cohesive Energy {eV_per_atom}",
@@ -131,18 +137,22 @@ class Key(LabelEnum):
131137
formula = "formula", "Formula"
132138
formula_pretty = "formula_pretty", "Pretty Formula"
133139
heat_val = "heat_val", "Heatmap Value" # used by PTableProjector for ptable data
140+
id = "id", "ID"
134141
init_struct = "initial_structure", "Initial Structure"
135142
magmoms = "magmoms", "Magnetic Moments"
136143
mat_id = "material_id", "Material ID"
137144
n_sites = "n_sites", "Number of Sites"
138145
oxi_state_guesses = "oxidation_state_guesses", "Oxidation State Guesses"
139146
spacegroup = "spacegroup", "Spacegroup Number"
140147
spacegroup_symbol = "spacegroup_symbol", "Spacegroup Symbol"
148+
step = "step", "Step"
141149
stress = "stress", "Stress"
142150
structure = "structure", "Structure"
151+
task = "task", "Task"
143152
task_id = "task_id", "Task ID" # unique identifier for a compute task
144153
task_type = "task_type", "Task Type"
145154
volume = "volume", "Volume (ų)"
155+
vol_per_atom = "volume_per_atom", f"Volume per Atom {cubic_angstrom}"
146156
wyckoff = "wyckoff", "Aflow-style Wyckoff Label" # crystallographic site symmetry
147157
phonon_bandstructure = "phonon_bandstructure", "Phonon Band Structure"
148158
phonon_dos = "phonon_dos", "Phonon Density of States"

pymatviz/histograms.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,12 +397,22 @@ def plot_histogram(
397397
fig = plt.figure(**fig_kwargs)
398398
plt.bar(bin_edges[:-1], hist_vals, **kwargs)
399399
plt.yscale("log" if log_y else "linear")
400+
401+
if isinstance(values, pd.Series):
402+
plt.xlabel(values.name)
403+
plt.ylabel("Density" if density else "Count")
404+
400405
elif backend == PLOTLY_BACKEND:
401406
fig = go.Figure(**fig_kwargs)
407+
kwargs = {"showlegend": False, **kwargs}
402408
fig.add_bar(x=bin_edges, y=hist_vals, **kwargs)
403409
_bin_width = (bin_edges[1] - bin_edges[0]) * bin_width
404410
fig.update_traces(width=_bin_width, marker_line_width=0)
405411
fig.update_yaxes(type="log" if log_y else "linear")
412+
413+
if isinstance(values, pd.Series):
414+
fig.layout.xaxis.title = values.name
415+
fig.layout.yaxis.title = "Density" if density else "Count"
406416
else:
407417
raise ValueError(f"Unsupported {backend=}. Must be one of {get_args(Backend)}")
408418

pymatviz/powerups.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from matplotlib.offsetbox import AnchoredText
3434
from matplotlib.text import Annotation
3535
from numpy.typing import ArrayLike
36+
from plotly.basedatatypes import BaseTraceType
3637

3738

3839
def with_marginal_hist(
@@ -342,7 +343,13 @@ def add_best_fit_line(
342343
raise TypeError(f"{fig=} must be instance of {type_names}")
343344

344345
backend = PLOTLY_BACKEND if isinstance(fig, go.Figure) else MPL_BACKEND
345-
kwargs.setdefault("color", "blue")
346+
# default to navy color but let annotate_params override
347+
kwargs.setdefault(
348+
"color",
349+
annotate_params.get("color", "navy")
350+
if isinstance(annotate_params, dict)
351+
else "navy",
352+
)
346353

347354
if 0 in {len(xs), len(ys)}:
348355
if isinstance(fig, go.Figure):
@@ -431,12 +438,30 @@ def add_ecdf_line(
431438
type_names = " | ".join(f"{t.__module__}.{t.__qualname__}" for t in valid_types)
432439
raise TypeError(f"{fig=} must be instance of {type_names}")
433440

434-
ecdf = px.ecdf(values if len(values) else fig.data[trace_idx].x).data[0]
441+
if values == ():
442+
target_trace: BaseTraceType = fig.data[trace_idx]
443+
if isinstance(target_trace, (go.Histogram, go.Scatter)):
444+
values = target_trace.x
445+
elif isinstance(target_trace, go.Bar):
446+
xs, ys = target_trace.x, target_trace.y
447+
values = np.repeat(xs[:-1], ys)
448+
449+
else:
450+
cls = type(target_trace)
451+
qual_name = cls.__module__ + "." + cls.__qualname__
452+
raise ValueError(
453+
f"Cannot auto-determine x-values for ECDF from {qual_name}, "
454+
"pass values explicitly. Currently only Histogram, Scatter, Box, "
455+
"and Violin traces are supported and may well need more testing. "
456+
"Please report issues at https://github.com/janosh/pymatviz/issues."
457+
)
458+
459+
ecdf_trace = px.ecdf(values).data[0]
435460

436461
# if fig has facets, add ECDF to all subplots
437462
add_trace_defaults = {} if fig._grid_ref is None else dict(row="all", col="all") # noqa: SLF001
438463

439-
fig.add_trace(ecdf, **add_trace_defaults | kwargs)
464+
fig.add_trace(ecdf_trace, **add_trace_defaults | kwargs)
440465
# move ECDF line to secondary y-axis
441466
# set color to darkened version of primary y-axis color
442467
trace_defaults = dict(yaxis="y2", name="Cumulative", line=dict(color="gray"))
@@ -454,6 +479,7 @@ def add_ecdf_line(
454479
color=color,
455480
linecolor=color,
456481
)
482+
# make secondary ECDF y-axis inherit primary y-axis styles
457483
fig.layout.yaxis2 = yaxis_defaults | getattr(fig.layout, "yaxis2", {})
458484

459485
return fig

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ urls = { Homepage = "https://github.com/janosh/pymatviz" }
2626
requires-python = ">=3.9"
2727
dependencies = [
2828
"matplotlib>=3.6.2",
29-
"numpy>=1.21.0",
29+
"numpy>=1.21.0,<2",
3030
"pandas>=2.0.0",
3131
"plotly",
3232
"pymatgen",

readme.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ See [`pymatviz/histograms.py`](pymatviz/histograms.py).
129129
| ![spg-num-hist-matplotlib] | ![spg-symbol-hist-matplotlib] |
130130
| [`spacegroup_hist([65, 134, 225, ...], backend="plotly")`](pymatviz/histograms.py) | [`spacegroup_hist(["C2/m", "P-43m", "Fm-3m", ...], backend="plotly")`](pymatviz/histograms.py) |
131131
| ![spg-num-hist-plotly] | ![spg-symbol-hist-plotly] |
132-
| [`elements_hist(compositions, log=True, bar_values='count')`](pymatviz/histograms.py) | [`plot_histogram(df_expt_gap["gap expt"], log_y=True)`](pymatviz/histograms.py) |
132+
| [`elements_hist(compositions, log=True, bar_values='count')`](pymatviz/histograms.py) | [`plot_histogram(df_matbench["gap expt"], log_y=True)`](pymatviz/histograms.py) |
133133
| ![elements-hist] | ![matbench-expt-gap-hist] |
134134

135135
[spg-symbol-hist-plotly]: https://github.com/janosh/pymatviz/raw/main/assets/spg-symbol-hist-plotly.svg

tests/test_histograms.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,17 @@
33
from typing import TYPE_CHECKING, Literal
44

55
import matplotlib.pyplot as plt
6+
import pandas as pd
67
import plotly.graph_objects as go
78
import pytest
89

910
from pymatviz import elements_hist, plot_histogram, spacegroup_hist
1011
from pymatviz.utils import MPL_BACKEND, PLOTLY_BACKEND, VALID_BACKENDS
11-
from tests.conftest import y_pred, y_true
12+
from tests.conftest import df_regr, y_pred, y_true
1213

1314

1415
if TYPE_CHECKING:
16+
import numpy as np
1517
from pymatgen.core import Structure
1618

1719
from pymatviz.utils import Backend
@@ -92,8 +94,11 @@ def test_hist_elemental_prevalence(glass_formulas: list[str]) -> None:
9294
@pytest.mark.parametrize("log_y", [True, False])
9395
@pytest.mark.parametrize("backend", VALID_BACKENDS)
9496
@pytest.mark.parametrize("bins", [20, 100])
95-
def test_plot_histogram(log_y: bool, backend: Backend, bins: int) -> None:
96-
fig = plot_histogram(y_true, backend=backend, log_y=log_y, bins=bins)
97+
@pytest.mark.parametrize("values", [y_true, df_regr.y_true])
98+
def test_plot_histogram(
99+
values: np.ndarray | pd.Series, log_y: bool, backend: Backend, bins: int
100+
) -> None:
101+
fig = plot_histogram(values, backend=backend, log_y=log_y, bins=bins)
97102
if backend == MPL_BACKEND:
98103
assert isinstance(fig, plt.Figure)
99104
y_min, y_max = fig.axes[0].get_ylim()
@@ -105,6 +110,10 @@ def test_plot_histogram(log_y: bool, backend: Backend, bins: int) -> None:
105110
}[(log_y, bins)]
106111
assert y_min == pytest.approx(y_min_exp)
107112
assert y_max == pytest.approx(y_max_exp)
113+
114+
if isinstance(values, pd.Series):
115+
assert fig.axes[0].get_xlabel() == values.name
116+
assert fig.axes[0].get_ylabel() == "Count"
108117
else:
109118
assert isinstance(fig, go.Figure)
110119
dev_fig = fig.full_figure_for_development()
@@ -117,3 +126,7 @@ def test_plot_histogram(log_y: bool, backend: Backend, bins: int) -> None:
117126
}[(log_y, bins)]
118127
assert y_min == pytest.approx(y_min_exp)
119128
assert y_max == pytest.approx(y_max_exp)
129+
130+
if isinstance(values, pd.Series):
131+
assert fig.layout.xaxis.title.text == values.name
132+
assert fig.layout.yaxis.title.text == "Count"

tests/test_powerups.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
from unittest.mock import patch
77

88
import matplotlib.pyplot as plt
9+
import plotly.express as px
910
import plotly.graph_objects as go
1011
import pytest
11-
from matplotlib.text import Annotation
12+
from matplotlib.offsetbox import AnchoredText
1213

1314
from pymatviz.powerups import (
1415
add_best_fit_line,
@@ -276,25 +277,38 @@ def test_add_ecdf_line(
276277

277278
trace_kwargs = trace_kwargs or {}
278279

279-
ecdf = fig.data[-1] # retrieve ecdf line
280+
ecdf_trace = fig.data[-1] # retrieve ecdf line
280281
expected_name = trace_kwargs.get("name", "Cumulative")
281282
expected_color = trace_kwargs.get("line_color", "gray")
282-
assert ecdf.name == expected_name
283-
assert ecdf.line.color == expected_color
284-
assert ecdf.yaxis == "y2"
283+
assert ecdf_trace.name == expected_name
284+
assert ecdf_trace.line.color == expected_color
285+
assert ecdf_trace.yaxis == "y2"
285286
assert fig.layout.yaxis2.range == (0, 1)
286287
assert fig.layout.yaxis2.title.text == expected_name
287288
assert fig.layout.yaxis2.color == expected_color
288289

289290

290291
def test_add_ecdf_line_raises() -> None:
292+
# check TypeError when passing invalid fig
291293
for fig in (None, "foo", 42.0):
292294
with pytest.raises(
293295
TypeError,
294296
match=f"{fig=} must be instance of plotly.graph_objs._figure.Figure",
295297
):
296298
add_ecdf_line(fig)
297299

300+
# check ValueError when x-values cannot be auto-determined
301+
fig_violin = px.violin(x=[1, 2, 3], y=[4, 5, 6])
302+
violin_trace = type(fig_violin.data[0])
303+
qual_name = f"{violin_trace.__module__}.{violin_trace.__qualname__}"
304+
with pytest.raises(
305+
ValueError, match=f"Cannot auto-determine x-values for ECDF from {qual_name}"
306+
):
307+
add_ecdf_line(fig_violin)
308+
309+
# check ValueError disappears when passing x-values explicitly
310+
add_ecdf_line(fig_violin, values=[1, 2, 3])
311+
298312

299313
def test_with_marginal_hist() -> None:
300314
fig, ax = plt.subplots()
@@ -315,31 +329,40 @@ def test_add_best_fit_line(
315329
matplotlib_scatter: plt.Figure,
316330
annotate_params: bool | dict[str, Any],
317331
) -> None:
332+
# test plotly
318333
fig_plotly = add_best_fit_line(plotly_scatter, annotate_params=annotate_params)
319334
assert isinstance(fig_plotly, go.Figure)
320-
assert fig_plotly.layout.shapes[-1].type == "line"
335+
best_fit_line = fig_plotly.layout.shapes[-1] # retrieve best fit line
336+
assert best_fit_line.type == "line"
337+
expected_color = (
338+
annotate_params.get("color") if isinstance(annotate_params, dict) else "navy"
339+
)
340+
assert best_fit_line.line.color == expected_color
321341

322342
if annotate_params:
323343
assert fig_plotly.layout.annotations[-1].text.startswith("LS fit: ")
344+
assert fig_plotly.layout.annotations[-1].font.color == expected_color
324345
else:
325346
assert len(fig_plotly.layout.annotations) == 0
326347

348+
# test matplotlib
327349
fig_mpl = add_best_fit_line(matplotlib_scatter, annotate_params=annotate_params)
328350
assert isinstance(fig_mpl, plt.Figure)
329351
with pytest.raises(IndexError):
330352
fig_mpl.axes[1]
331353
ax = fig_mpl.axes[0]
332354
assert ax.lines[-1].get_linestyle() == "--"
355+
assert ax.lines[-1].get_color() == expected_color
333356

334-
anno = next( # TODO figure out why this always gives None
335-
(child for child in ax.get_children() if isinstance(child, Annotation)),
357+
anno: AnchoredText = next( # TODO figure out why this always gives None
358+
(child for child in ax.get_children() if isinstance(child, AnchoredText)),
336359
None,
337360
)
338361

339-
# if annotate_params:
340-
# assert anno.get_text().startswith("LS fit: ")
341-
# else:
342-
assert anno is None
362+
if annotate_params:
363+
assert anno.txt.get_text().startswith("LS fit: ")
364+
else:
365+
assert anno is None
343366

344367

345368
def test_add_best_fit_line_invalid_fig() -> None:

0 commit comments

Comments
 (0)