Skip to content

Commit 0b9b5da

Browse files
janoshDanielYang59
andauthored
Bump min supported Python to 3.10 (#195)
* drop numpy<2 pin, pin scipy>=1.14, matplotlib>=3.9, pandas>=2.2 * bump min supported python to 3.10 * ruff auto-upgrade code to 3.10 14 of 20 mypy errors remaining * fix CI error Version 3.1 was not found in the local cache * explicitly mark typealias * use | in isinstance * test all make_asset scripts in CI * fix missing packages matminer kaleido in test-scripts SI * relax numpy pin >=1.26 * add env var secrets.MP_API_KEY in test-scripts CI, skip phonon assets on ffonons ImportError * make_assets/phonons.py change exit code to SystemExit(0) * add min version pins for optional deps * add py.typed mark to ship type information --------- Co-authored-by: Haoyu (Daniel) <[email protected]>
1 parent 1de350a commit 0b9b5da

32 files changed

+180
-125
lines changed

.github/workflows/test.yml

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,43 @@ jobs:
1717
uses: janosh/workflows/.github/workflows/pytest-release.yml@main
1818
with:
1919
os: ${{ matrix.os }}
20-
python-version: 3.9
20+
python-version: "3.10"
2121
secrets: inherit
22+
23+
find-scripts:
24+
runs-on: ubuntu-latest
25+
outputs:
26+
script_list: ${{ steps.set-matrix.outputs.script_list }}
27+
steps:
28+
- name: Check out repository
29+
uses: actions/checkout@v4
30+
31+
- name: Find Python scripts
32+
id: set-matrix
33+
run: |
34+
SCRIPTS=$(find examples/make_assets -name "*.py" | jq -R -s -c 'split("\n")[:-1]')
35+
echo "script_list=$SCRIPTS" >> $GITHUB_OUTPUT
36+
37+
test-scripts:
38+
needs: find-scripts
39+
runs-on: ubuntu-latest
40+
strategy:
41+
fail-fast: false
42+
matrix:
43+
script: ${{fromJson(needs.find-scripts.outputs.script_list)}}
44+
steps:
45+
- name: Check out repository
46+
uses: actions/checkout@v4
47+
48+
- name: Set up Python
49+
uses: actions/setup-python@v5
50+
with:
51+
python-version: "3.10"
52+
53+
- name: Install package and dependencies
54+
run: pip install -e .[make-assets]
55+
56+
- name: Run script
57+
run: python ${{ matrix.script }}
58+
env:
59+
MP_API_KEY: ${{ secrets.MP_API_KEY }}

.pre-commit-config.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ default_install_hook_types: [pre-commit, commit-msg]
88

99
repos:
1010
- repo: https://github.com/astral-sh/ruff-pre-commit
11-
rev: v0.6.1
11+
rev: v0.6.3
1212
hooks:
1313
- id: ruff
1414
args: [--fix]
@@ -17,7 +17,7 @@ repos:
1717
types_or: [python, jupyter]
1818

1919
- repo: https://github.com/pre-commit/mirrors-mypy
20-
rev: v1.11.1
20+
rev: v1.11.2
2121
hooks:
2222
- id: mypy
2323
additional_dependencies: [types-requests]
@@ -73,7 +73,7 @@ repos:
7373
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/(routes|figs).+\.(yaml|json)|changelog.md)$
7474

7575
- repo: https://github.com/pre-commit/mirrors-eslint
76-
rev: v9.9.0
76+
rev: v9.9.1
7777
hooks:
7878
- id: eslint
7979
types: [file]
@@ -87,6 +87,6 @@ repos:
8787
- typescript-eslint
8888

8989
- repo: https://github.com/RobertCraigie/pyright-python
90-
rev: v1.1.376
90+
rev: v1.1.378
9191
hooks:
9292
- id: pyright

examples/dataset_exploration/matpes/eda.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,13 +225,21 @@
225225
df_pbe[Key.forces] = df_pbe[Key.forces].map(np.abs)
226226

227227
df_r2scan_elem_forces = pd.DataFrame(
228-
{site.specie.symbol: np.linalg.norm(force) for site, force in zip(struct, forces)}
229-
for struct, forces in zip(df_r2scan[Key.structure], df_r2scan[Key.forces])
228+
{
229+
site.specie.symbol: np.linalg.norm(force)
230+
for site, force in zip(struct, forces, strict=True)
231+
}
232+
for struct, forces in zip(
233+
df_r2scan[Key.structure], df_r2scan[Key.forces], strict=True
234+
)
230235
).mean()
231236

232237
df_pbe_elem_forces = pd.DataFrame(
233-
{site.specie.symbol: np.linalg.norm(force) for site, force in zip(struct, forces)}
234-
for struct, forces in zip(df_pbe[Key.structure], df_pbe[Key.forces])
238+
{
239+
site.specie.symbol: np.linalg.norm(force)
240+
for site, force in zip(struct, forces, strict=True)
241+
}
242+
for struct, forces in zip(df_pbe[Key.structure], df_pbe[Key.forces], strict=True)
235243
).mean()
236244

237245

examples/make_assets/phonons.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@
1111
from pymatviz.enums import Key
1212

1313

14+
try:
15+
import ffonons # noqa: F401
16+
except ImportError:
17+
raise SystemExit(0) from None # install ffonons to run this script
18+
19+
1420
# %% Plot phonon bands and DOS
1521
for mp_id, formula in (
1622
("mp-2758", "Sr4Se4"),

examples/make_assets/scatter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
np_rng = np.random.default_rng(seed=0)
2929
y_true = np_rng.normal(5, 4, rand_regression_size)
3030
y_pred = 1.2 * y_true - 2 * np_rng.normal(0, 1, rand_regression_size)
31-
y_std = (y_true - y_pred) * 10 * np_rng.normal(0, 0.1, rand_regression_size)
31+
y_std = abs((y_true - y_pred) * 10 * np_rng.normal(0, 0.1, rand_regression_size))
3232

3333

3434
# %% density scatter plotly
@@ -42,7 +42,7 @@
4242
xs, ys = make_blobs(n_samples=100_000, centers=3, n_features=2, random_state=42)
4343

4444
x_col, y_col, target_col = "feature1", "feature2", "target"
45-
df_blobs = pd.DataFrame(dict(zip([x_col, y_col], xs.T)) | {target_col: ys})
45+
df_blobs = pd.DataFrame(dict(zip([x_col, y_col], xs.T, strict=True)) | {target_col: ys})
4646

4747
fig = pmv.density_scatter_plotly(df=df_blobs, x=x_col, y=y_col)
4848
fig.show()

examples/make_assets/uncertainty.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@
3737

3838

3939
# %% Cumulative Plots
40-
ax = pmv.cumulative_error(y_pred, y_true)
40+
ax = pmv.cumulative_error(y_pred - y_true)
4141
pmv.io.save_and_compress_svg(ax, "cumulative-error")
4242

4343

44-
ax = pmv.cumulative_residual(y_pred, y_true)
44+
ax = pmv.cumulative_residual(y_pred - y_true)
4545
pmv.io.save_and_compress_svg(ax, "cumulative-residual")

pymatviz/bar.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ def spacegroup_bar(
100100

101101
# sort df by crystal system going from smallest to largest spacegroup numbers
102102
# e.g. triclinic (1-2) comes first, cubic (195-230) last
103-
sys_order = dict(zip(crystal_sys_colors, range(len(crystal_sys_colors))))
103+
sys_order = dict(
104+
zip(crystal_sys_colors, range(len(crystal_sys_colors)), strict=True)
105+
)
104106
df_data = df_data.loc[
105107
df_data[Key.crystal_system].map(sys_order).sort_values().index
106108
]

pymatviz/io.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pathlib import Path
99
from shutil import which
1010
from time import sleep
11-
from typing import TYPE_CHECKING, Any, Callable, Final, Literal
11+
from typing import TYPE_CHECKING, Any, Final, Literal
1212

1313
import matplotlib.pyplot as plt
1414
import numpy as np
@@ -24,7 +24,7 @@
2424

2525

2626
if TYPE_CHECKING:
27-
from collections.abc import Sequence
27+
from collections.abc import Callable, Sequence
2828
from pathlib import Path
2929

3030
import pandas as pd
@@ -130,7 +130,7 @@ def save_fig(
130130
if any(var in os.environ for var in env_disable):
131131
return
132132
# handle matplotlib figures
133-
if isinstance(fig, (plt.Figure, plt.Axes)):
133+
if isinstance(fig, plt.Figure | plt.Axes):
134134
if hasattr(fig, "figure"):
135135
fig = fig.figure # unwrap Axes
136136
fig.savefig(path, **kwargs, transparent=True)
@@ -566,7 +566,7 @@ def print_table(
566566
x0 = (1 - total_width) / 2
567567
y_i = 1
568568

569-
for idx, (yd, row) in enumerate(zip(row_locs, rows)):
569+
for idx, (yd, row) in enumerate(zip(row_locs, rows, strict=True)):
570570
x_i = x0
571571
y_i -= yd
572572
# table zebra stripes
@@ -581,7 +581,7 @@ def print_table(
581581
)
582582
fig.add_artist(rect)
583583

584-
for xd, val in zip(col_widths, row):
584+
for xd, val in zip(col_widths, row, strict=True):
585585
text, weight, ha, bg_color, fg_color = val[:5]
586586

587587
if bg_color != row_colors[1]:

pymatviz/phonons.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import sys
66
from dataclasses import dataclass
7-
from typing import TYPE_CHECKING, Any, Literal, Union, get_args, no_type_check
7+
from typing import TYPE_CHECKING, Any, Literal, get_args, no_type_check
88

99
import plotly.express as px
1010
import plotly.graph_objects as go
@@ -23,7 +23,7 @@
2323
from pymatgen.core import Structure
2424
from typing_extensions import Self
2525

26-
AnyBandStructure = Union[BandStructureSymmLine, PhononBands]
26+
AnyBandStructure = BandStructureSymmLine | PhononBands
2727

2828

2929
@dataclass
@@ -121,8 +121,8 @@ def get_band_xaxis_ticks(
121121
return ticks_x_pos, tick_labels
122122

123123

124-
YMin = Union[float, Literal["y_min"]]
125-
YMax = Union[float, Literal["y_max"]]
124+
YMin = float | Literal["y_min"]
125+
YMax = float | Literal["y_max"]
126126

127127

128128
@no_type_check
@@ -133,7 +133,7 @@ def _shaded_range(
133133
return fig
134134

135135
shade_defaults = dict(layer="below", row="all", col="all")
136-
y_lim = dict(zip(("y_min", "y_max"), fig.layout.yaxis.range))
136+
y_lim = dict(zip(("y_min", "y_max"), fig.layout.yaxis.range, strict=True))
137137

138138
shaded_ys = shaded_ys or {(0, "y_min"): dict(fillcolor="gray", opacity=0.07)}
139139
for (y0, y1), kwds in shaded_ys.items():

pymatviz/powerups/both.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def annotate_metrics(
7070
"""
7171
if isinstance(metrics, str):
7272
metrics = [metrics]
73-
if not isinstance(metrics, (dict, list, tuple, set)):
73+
if not isinstance(metrics, dict | list | tuple | set):
7474
raise TypeError(
7575
f"metrics must be dict|list|tuple|set, not {type(metrics).__name__}"
7676
)
@@ -166,7 +166,7 @@ def add_identity_line(
166166
"""
167167
(x_min, x_max), (y_min, y_max) = get_fig_xy_range(fig=fig, trace_idx=trace_idx)
168168

169-
if isinstance(fig, (plt.Figure, plt.Axes)): # handle matplotlib
169+
if isinstance(fig, plt.Figure | plt.Axes): # handle matplotlib
170170
ax = fig if isinstance(fig, plt.Axes) else fig.gca()
171171

172172
line_defaults = dict(alpha=0.5, zorder=0, linestyle="dashed", color="black")

0 commit comments

Comments
 (0)