Skip to content

Commit 73fd600

Browse files
committed
remove save_reference_img() from conftest.py
pre-commit autoupdate
1 parent 13df583 commit 73fd600

14 files changed

+33
-96
lines changed

.pre-commit-config.yaml

+7-7
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,24 @@ repos:
1212
- id: isort
1313

1414
- repo: https://github.com/psf/black
15-
rev: 22.6.0
15+
rev: 22.8.0
1616
hooks:
1717
- id: black-jupyter
1818

1919
- repo: https://github.com/PyCQA/flake8
20-
rev: 4.0.1
20+
rev: 5.0.4
2121
hooks:
2222
- id: flake8
2323
additional_dependencies: [flake8-bugbear]
2424

2525
- repo: https://github.com/asottile/pyupgrade
26-
rev: v2.34.0
26+
rev: v2.37.3
2727
hooks:
2828
- id: pyupgrade
2929
args: [--py38-plus]
3030

3131
- repo: https://github.com/pre-commit/mirrors-mypy
32-
rev: v0.961
32+
rev: v0.971
3333
hooks:
3434
- id: mypy
3535
additional_dependencies: [types-requests]
@@ -58,19 +58,19 @@ repos:
5858
exclude: tests
5959

6060
- repo: https://github.com/codespell-project/codespell
61-
rev: v2.1.0
61+
rev: v2.2.1
6262
hooks:
6363
- id: codespell
6464
stages: [commit, commit-msg]
6565
exclude_types: [csv, svg, html, yaml, jupyter]
6666

6767
- repo: https://github.com/PyCQA/autoflake
68-
rev: v1.4
68+
rev: v1.5.3
6969
hooks:
7070
- id: autoflake
7171

7272
- repo: https://github.com/nbQA-dev/nbQA
73-
rev: 1.3.1
73+
rev: 1.4.0
7474
hooks:
7575
- id: nbqa-pyupgrade
7676
args: [--py38-plus]

assets/_generate_assets.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,9 @@
265265
ax = plot_structure_2d(struct)
266266
formula = struct.composition.reduced_formula
267267
_, spacegroup = struct.get_space_group_info()
268-
ax.set_title(
269-
f"{formula} (disordered {mp_id} with {spacegroup = })", fontweight="bold"
270-
)
268+
269+
anno_text = f"{formula}\ndisordered {mp_id} with {spacegroup = }"
270+
ax.text(0.5, 1, anno_text, url=href, ha="center", transform=ax.transAxes)
271271

272272
ax.figure.set_size_inches(8, 8)
273273

pymatviz/correlation.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ def marchenko_pastur(
6262
Returns:
6363
ax: The plot's matplotlib Axes.
6464
"""
65-
if ax is None:
66-
ax = plt.gca()
65+
ax = ax or plt.gca()
6766

6867
# use eigvalsh for speed since correlation matrix is symmetric
6968
evals = np.linalg.eigvalsh(matrix)

pymatviz/cumulative.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ def cumulative_residual(
2020
Returns:
2121
ax: The plot's matplotlib Axes.
2222
"""
23-
if ax is None:
24-
ax = plt.gca()
23+
ax = ax or plt.gca()
2524

2625
res = np.sort(preds - targets)
2726

@@ -74,8 +73,7 @@ def cumulative_error(
7473
Returns:
7574
ax: The plot's matplotlib Axes.
7675
"""
77-
if ax is None:
78-
ax = plt.gca()
76+
ax = ax or plt.gca()
7977

8078
errors = np.sort(np.abs(preds - targets))
8179
n_data = len(errors)

pymatviz/histograms.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ def residual_hist(
4242
Returns:
4343
ax: The plot's matplotlib Axes.
4444
"""
45-
if ax is None:
46-
ax = plt.gca()
45+
ax = ax or plt.gca()
4746

4847
y_res = y_pred - y_true
4948

@@ -74,8 +73,8 @@ def true_pred_hist(
7473
truth_color: str = "blue",
7574
**kwargs: Any,
7675
) -> plt.Axes:
77-
"""Plot a histogram of model predictions with bars colored by the mean uncertainty of
78-
predictions in that bin. Overlaid by a more transparent histogram of ground truth
76+
"""Plot a histogram of model predictions with bars colored by the mean uncertainty
77+
of predictions in that bin. Overlaid by a more transparent histogram of ground truth
7978
values.
8079
8180
Args:
@@ -91,8 +90,7 @@ def true_pred_hist(
9190
Returns:
9291
ax: The plot's matplotlib Axes.
9392
"""
94-
if ax is None:
95-
ax = plt.gca()
93+
ax = ax or plt.gca()
9694

9795
color_map = getattr(plt.cm, cmap)
9896
y_true, y_pred, y_std = np.array([y_true, y_pred, y_std])
@@ -159,8 +157,7 @@ def spacegroup_hist(
159157
Returns:
160158
ax: The plot's matplotlib Axes.
161159
"""
162-
if ax is None:
163-
ax = plt.gca()
160+
ax = ax or plt.gca()
164161

165162
if isinstance(next(iter(data)), Structure):
166163
# if 1st sequence item is structure, assume all are
@@ -316,8 +313,7 @@ def hist_elemental_prevalence(
316313
Returns:
317314
ax: The plot's matplotlib Axes.
318315
"""
319-
if ax is None:
320-
ax = plt.gca()
316+
ax = ax or plt.gca()
321317

322318
elem_counts = count_elements(formulas, count_mode)
323319
non_zero = elem_counts[elem_counts > 0].sort_values(ascending=False)

pymatviz/parity.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ def density_scatter(
8080
Returns:
8181
ax: The plot's matplotlib Axes.
8282
"""
83-
if ax is None:
84-
ax = plt.gca()
83+
ax = ax or plt.gca()
8584

8685
xs, ys, cs = hist_density(xs, ys, sort=sort, bins=density_bins)
8786

@@ -130,8 +129,7 @@ def scatter_with_err_bar(
130129
Returns:
131130
ax: The plot's matplotlib Axes.
132131
"""
133-
if ax is None:
134-
ax = plt.gca()
132+
ax = ax or plt.gca()
135133

136134
styles = dict(markersize=6, fmt="o", ecolor="g", capthick=2, elinewidth=2)
137135
ax.errorbar(xs, ys, yerr=yerr, xerr=xerr, **kwargs, **styles)
@@ -171,8 +169,7 @@ def density_hexbin(
171169
Returns:
172170
ax: The plot's matplotlib Axes.
173171
"""
174-
if ax is None:
175-
ax = plt.gca()
172+
ax = ax or plt.gca()
176173

177174
# the scatter plot
178175
hexbin = ax.hexbin(xs, yx, gridsize=75, mincnt=1, bins="log", C=weights, **kwargs)
@@ -246,8 +243,7 @@ def residual_vs_actual(
246243
Returns:
247244
ax: The plot's matplotlib Axes.
248245
"""
249-
if ax is None:
250-
ax = plt.gca()
246+
ax = ax or plt.gca()
251247

252248
y_err = y_true - y_pred
253249

pymatviz/ptable.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,7 @@ def ptable_heatmap(
188188
# TODO can we pass as a kwarg and still ensure aspect ratio respected?
189189
fig = plt.figure(figsize=(0.75 * n_columns, 0.7 * n_rows))
190190

191-
if ax is None:
192-
ax = plt.gca()
191+
ax = ax or plt.gca()
193192

194193
rw = rh = 0.9 # rectangle width/height
195194

@@ -273,7 +272,7 @@ def ptable_heatmap(
273272

274273
mappable = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
275274

276-
def tick_fmt(val: float, pos: int) -> str:
275+
def tick_fmt(val: float, _pos: int) -> str:
277276
# val: value at color axis tick (e.g. 10.0, 20.0, ...)
278277
# pos: zero-based tick counter (e.g. 0, 1, 2, ...)
279278
if heat_mode == "percent":

pymatviz/relevance.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ def roc_curve(
2121
Returns:
2222
tuple[float, ax]: The classifier's ROC-AUC and the plot's matplotlib Axes.
2323
"""
24-
if ax is None:
25-
ax = plt.gca()
24+
ax = ax or plt.gca()
2625

2726
# get the metrics
2827
false_pos_rate, true_pos_rate, _ = skm.roc_curve(targets, proba_pos)
@@ -51,8 +50,7 @@ def precision_recall_curve(
5150
Returns:
5251
tuple[float, ax]: The classifier's precision score and the matplotlib Axes.
5352
"""
54-
if ax is None:
55-
ax = plt.gca()
53+
ax = ax or plt.gca()
5654

5755
# get the metrics
5856
precision, recall, _ = skm.precision_recall_curve(targets, proba_pos)

pymatviz/structure_viz.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,7 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
177177
Returns:
178178
plt.Axes: matplotlib Axes instance with plotted structure.
179179
"""
180-
if ax is None:
181-
ax = plt.gca()
180+
ax = ax or plt.gca()
182181

183182
if isinstance(site_labels, (list, tuple)):
184183
if len(site_labels) != len(struct):

pymatviz/uncertainty.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ def qq_gaussian(
3636
Returns:
3737
ax: The plot's matplotlib Axes.
3838
"""
39-
if ax is None:
40-
ax = plt.gca()
39+
ax = ax or plt.gca()
4140

4241
if isinstance(y_std, np.ndarray):
4342
y_std = {"std": y_std}
@@ -206,8 +205,7 @@ def error_decay_with_uncert(
206205
ax: matplotlib Axes object with plotted model error drop curve based on
207206
excluding data points by order of large to small model uncertainties.
208207
"""
209-
if ax is None:
210-
ax = plt.gca()
208+
ax = ax or plt.gca()
211209

212210
xs = range(100 if percentiles else len(y_true), 0, -1)
213211

pymatviz/utils.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,7 @@ def annotate_bars(
112112
**kwargs: Additional arguments (rotation, arrowprops, etc.) are passed to
113113
ax.annotate().
114114
"""
115-
if ax is None:
116-
ax = plt.gca()
115+
ax = ax or plt.gca()
117116

118117
if labels is None:
119118
labels = [int(patch.get_height()) for patch in ax.patches]
@@ -170,8 +169,7 @@ def add_mae_r2_box(
170169
Returns:
171170
AnchoredText: Instance containing the metrics.
172171
"""
173-
if ax is None:
174-
ax = plt.gca()
172+
ax = ax or plt.gca()
175173

176174
mask = ~np.isnan(xs) & ~np.isnan(ys)
177175
xs, ys = xs[mask], ys[mask]

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ author = Janosh Riebesell
99
author_email = [email protected]
1010
license = MIT
1111
license_files = license
12-
keywords = materials informatics, materials discovery, data visualization, plotly, matplotlib
12+
keywords = science, materials informatics, materials discovery, chemistry, data visualization, plotly, matplotlib
1313
classifiers =
1414
Programming Language :: Python :: 3
1515
Programming Language :: Python :: 3 :: Only

tests/conftest.py

-40
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
11
from __future__ import annotations
22

3-
import subprocess
4-
from shutil import which
5-
63
import matplotlib.pyplot as plt
74
import numpy as np
85
import plotly.express as px
96
import pytest
107
from pymatgen.core import Lattice, Structure
118

12-
from pymatviz import ROOT
13-
149

1510
# random regression data
1611
np.random.seed(42)
@@ -67,38 +62,3 @@ def plotly_scatter():
6762
y2 = xs**0.5
6863
fig = px.scatter(x=xs, y=[y1, y2])
6964
return fig
70-
71-
72-
def save_reference_img(save_to: str) -> None:
73-
"""Save a matplotlib figure to a specified fixture path.
74-
75-
Raises:
76-
ValueError: save_to is not inside 'tests/fixtures/' directory.
77-
"""
78-
if not save_to.startswith((f"{ROOT}/tests/fixtures/", "tests/fixtures/")):
79-
raise ValueError(f"{save_to=} must point at 'tests/fixtures/'")
80-
81-
pngquant, zopflipng = which("pngquant"), which("zopflipng")
82-
83-
print(
84-
f"created new fixture {save_to=}, image comparison will run for real on "
85-
"subsequent test runs unless fixture is deleted"
86-
)
87-
plt.savefig(save_to)
88-
plt.close()
89-
90-
if not pngquant:
91-
return print("Warning: pngquant not installed. Cannot compress new fixture.")
92-
if not zopflipng:
93-
return print("Warning: zopflipng not installed. Cannot compress new fixture.")
94-
95-
subprocess.run(
96-
f"{pngquant} 32 --skip-if-larger --ext .png --force".split() + [save_to],
97-
check=False,
98-
capture_output=True,
99-
)
100-
subprocess.run(
101-
[zopflipng, "-y", save_to, save_to],
102-
check=True,
103-
capture_output=True,
104-
)

tests/test_structure_viz.py

-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from __future__ import annotations
22

3-
import os
4-
53
import matplotlib.pyplot as plt
64
import pandas as pd
75
import pytest
@@ -11,8 +9,6 @@
119
from pymatviz.structure_viz import plot_structure_2d
1210

1311

14-
os.makedirs(fixture_dir := "tests/fixtures/structure_viz", exist_ok=True)
15-
1612
lattice = Lattice.cubic(5)
1713
disordered_struct = Structure(
1814
lattice, [{"Fe": 0.75, "C": 0.25}, "O"], [[0, 0, 0], [0.5, 0.5, 0.5]]

0 commit comments

Comments
 (0)