Skip to content

Commit 8e50218

Browse files
committed
adhere to PEP 484 (no implicit optional)
drop jupyter-dash from examples/mprester_ptable.ipynb
1 parent 71f4bbd commit 8e50218

16 files changed

+102
-95
lines changed

.pre-commit-config.yaml

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg]
77

88
repos:
99
- repo: https://github.com/charliermarsh/ruff-pre-commit
10-
rev: v0.0.272
10+
rev: v0.0.275
1111
hooks:
1212
- id: ruff
1313
args: [--fix]
@@ -18,7 +18,7 @@ repos:
1818
- id: black-jupyter
1919

2020
- repo: https://github.com/pre-commit/mirrors-mypy
21-
rev: v1.3.0
21+
rev: v1.4.0
2222
hooks:
2323
- id: mypy
2424
additional_dependencies: [types-requests]
@@ -41,7 +41,7 @@ repos:
4141
- id: trailing-whitespace
4242

4343
- repo: https://github.com/codespell-project/codespell
44-
rev: v2.2.4
44+
rev: v2.2.5
4545
hooks:
4646
- id: codespell
4747
stages: [commit, commit-msg]

dataset_exploration/wbm/readme.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
Refer to [Figshare description](https://figshare.com/s/ff0ad14505f9624f0c05).
1+
Refer to <https://matbench-discovery.materialsproject.org/about-the-data>.

examples/mprester_ptable.ipynb

+2-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"outputs": [],
1616
"source": [
1717
"# dash needed for interactive plots\n",
18-
"!pip install pymatviz dash jupyter-dash"
18+
"!pip install pymatviz dash"
1919
]
2020
},
2121
{
@@ -29,7 +29,6 @@
2929
"import plotly.graph_objects as go\n",
3030
"import plotly.io as pio\n",
3131
"from dash.dependencies import Input, Output\n",
32-
"from jupyter_dash import JupyterDash\n",
3332
"from pymatgen.ext.matproj import MPRester\n",
3433
"\n",
3534
"from pymatviz import count_elements, ptable_heatmap, ptable_heatmap_plotly\n",
@@ -211,7 +210,7 @@
211210
}
212211
],
213212
"source": [
214-
"app = JupyterDash(prevent_initial_callbacks=True)\n",
213+
"app = dash.Dash(prevent_initial_callbacks=True)\n",
215214
"\n",
216215
"graph = dash.dcc.Graph(figure=fig, id=\"ptable-heatmap\", responsive=True)\n",
217216
"dropdown = dash.dcc.Dropdown(\n",

pymatviz/correlation.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import numpy as np
24
from matplotlib import pyplot as plt
35
from numpy.typing import ArrayLike
@@ -38,7 +40,7 @@ def marchenko_pastur(
3840
gamma: float,
3941
sigma: float = 1,
4042
filter_high_evals: bool = False,
41-
ax: plt.Axes = None,
43+
ax: Optional[plt.Axes] = None,
4244
) -> plt.Axes:
4345
"""Plot the eigenvalue distribution of a symmetric matrix (usually a correlation
4446
matrix) against the Marchenko Pastur distribution.

pymatviz/cumulative.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
from typing import Any
1+
from typing import Any, Optional
22

33
import matplotlib.pyplot as plt
44
import numpy as np
55
from numpy.typing import ArrayLike
66

77

8-
def cumulative_residual(res: ArrayLike, ax: plt.Axes = None, **kwargs: Any) -> plt.Axes:
8+
def cumulative_residual(
9+
res: ArrayLike, ax: Optional[plt.Axes] = None, **kwargs: Any
10+
) -> plt.Axes:
911
"""Plot the empirical cumulative distribution for the residuals (y - mu).
1012
1113
Args:
@@ -55,7 +57,7 @@ def cumulative_residual(res: ArrayLike, ax: plt.Axes = None, **kwargs: Any) -> p
5557

5658

5759
def cumulative_error(
58-
abs_err: ArrayLike, ax: plt.Axes = None, **kwargs: Any
60+
abs_err: ArrayLike, ax: Optional[plt.Axes] = None, **kwargs: Any
5961
) -> plt.Axes:
6062
"""Plot the empirical cumulative distribution of the absolute errors.
6163

pymatviz/histograms.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
def residual_hist(
2525
y_res: ArrayLike,
26-
ax: plt.Axes = None,
26+
ax: plt.Axes | None = None,
2727
xlabel: str | None = r"Residual ($y_\mathrm{true} - y_\mathrm{pred}$)",
2828
**kwargs: Any,
2929
) -> plt.Axes:
@@ -66,8 +66,8 @@ def true_pred_hist(
6666
y_true: ArrayLike | str,
6767
y_pred: ArrayLike | str,
6868
y_std: ArrayLike | str,
69-
df: pd.DataFrame = None,
70-
ax: plt.Axes = None,
69+
df: pd.DataFrame | None = None,
70+
ax: plt.Axes | None = None,
7171
cmap: str = "hot",
7272
truth_color: str = "blue",
7373
true_label: str = r"$y_\mathrm{true}$",
@@ -140,7 +140,7 @@ def spacegroup_hist(
140140
show_counts: bool = True,
141141
xticks: Literal["all", "crys_sys_edges"] | int = 20,
142142
include_missing: bool = False,
143-
ax: plt.Axes = None,
143+
ax: plt.Axes | None = None,
144144
**kwargs: Any,
145145
) -> plt.Axes:
146146
"""Plot a histogram of spacegroups shaded by crystal system.
@@ -285,8 +285,8 @@ def hist_elemental_prevalence(
285285
formulas: ElemValues,
286286
count_mode: CountMode = "composition",
287287
log: bool = False,
288-
keep_top: int = None,
289-
ax: plt.Axes = None,
288+
keep_top: int | None = None,
289+
ax: plt.Axes | None = None,
290290
bar_values: Literal["percent", "count"] | None = "percent",
291291
h_offset: int = 0,
292292
v_offset: int = 10,

pymatviz/parity.py

+19-19
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
def hist_density(
2020
x: ArrayLike | str,
2121
y: ArrayLike | str,
22-
df: pd.DataFrame = None,
22+
df: pd.DataFrame | None = None,
2323
sort: bool = True,
2424
bins: int = 100,
2525
) -> tuple[ArrayLike, ArrayLike, ArrayLike]:
@@ -59,13 +59,13 @@ def hist_density(
5959
def density_scatter(
6060
x: ArrayLike | str,
6161
y: ArrayLike | str,
62-
df: pd.DataFrame = None,
63-
ax: plt.Axes = None,
62+
df: pd.DataFrame | None = None,
63+
ax: plt.Axes | None = None,
6464
sort: bool = True,
6565
log_cmap: bool = True,
6666
density_bins: int = 100,
67-
xlabel: str = None,
68-
ylabel: str = None,
67+
xlabel: str | None = None,
68+
ylabel: str | None = None,
6969
identity: bool = True,
7070
stats: bool | dict[str, Any] = True,
7171
**kwargs: Any,
@@ -132,13 +132,13 @@ def density_scatter(
132132
def scatter_with_err_bar(
133133
x: ArrayLike | str,
134134
y: ArrayLike | str,
135-
df: pd.DataFrame = None,
136-
xerr: ArrayLike = None,
137-
yerr: ArrayLike = None,
138-
ax: plt.Axes = None,
135+
df: pd.DataFrame | None = None,
136+
xerr: ArrayLike | None = None,
137+
yerr: ArrayLike | None = None,
138+
ax: plt.Axes | None = None,
139139
xlabel: str = "Actual",
140140
ylabel: str = "Predicted",
141-
title: str = None,
141+
title: str | None = None,
142142
**kwargs: Any,
143143
) -> plt.Axes:
144144
"""Scatter plot with optional x- and/or y-error bars. Useful when passing model
@@ -179,9 +179,9 @@ def scatter_with_err_bar(
179179
def density_hexbin(
180180
x: ArrayLike | str,
181181
y: ArrayLike | str,
182-
df: pd.DataFrame = None,
183-
ax: plt.Axes = None,
184-
weights: ArrayLike = None,
182+
df: pd.DataFrame | None = None,
183+
ax: plt.Axes | None = None,
184+
weights: ArrayLike | None = None,
185185
xlabel: str = "Actual",
186186
ylabel: str = "Predicted",
187187
**kwargs: Any,
@@ -227,8 +227,8 @@ def density_hexbin(
227227
def density_scatter_with_hist(
228228
x: ArrayLike | str,
229229
y: ArrayLike | str,
230-
df: pd.DataFrame = None,
231-
cell: GridSpec = None,
230+
df: pd.DataFrame | None = None,
231+
cell: GridSpec | None = None,
232232
bins: int = 100,
233233
**kwargs: Any,
234234
) -> plt.Axes:
@@ -243,8 +243,8 @@ def density_scatter_with_hist(
243243
def density_hexbin_with_hist(
244244
x: ArrayLike | str,
245245
y: ArrayLike | str,
246-
df: pd.DataFrame = None,
247-
cell: GridSpec = None,
246+
df: pd.DataFrame | None = None,
247+
cell: GridSpec | None = None,
248248
bins: int = 100,
249249
**kwargs: Any,
250250
) -> plt.Axes:
@@ -259,8 +259,8 @@ def density_hexbin_with_hist(
259259
def residual_vs_actual(
260260
y_true: ArrayLike | str,
261261
y_pred: ArrayLike | str,
262-
df: pd.DataFrame = None,
263-
ax: plt.Axes = None,
262+
df: pd.DataFrame | None = None,
263+
ax: plt.Axes | None = None,
264264
xlabel: str = r"Actual value",
265265
ylabel: str = r"Residual ($y_\mathrm{true} - y_\mathrm{pred}$)",
266266
**kwargs: Any,

pymatviz/ptable.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def count_elements(
135135
def ptable_heatmap(
136136
elem_values: ElemValues,
137137
log: bool = False,
138-
ax: plt.Axes = None,
138+
ax: plt.Axes | None = None,
139139
count_mode: CountMode = "composition",
140140
cbar_title: str = "Element Count",
141141
cbar_max: float | int | None = None,
@@ -144,8 +144,8 @@ def ptable_heatmap(
144144
infty_color: str = "lightskyblue",
145145
na_color: str = "white",
146146
heat_mode: Literal["value", "fraction", "percent"] | None = "value",
147-
precision: str = None,
148-
cbar_precision: str = None,
147+
precision: str | None = None,
148+
cbar_precision: str | None = None,
149149
text_color: str | tuple[str, str] = "auto",
150150
exclude_elements: Sequence[str] = (),
151151
zero_symbol: str | float = "-",
@@ -392,14 +392,14 @@ def ptable_heatmap_plotly(
392392
colorscale: str | Sequence[str] | Sequence[tuple[float, str]] = "viridis",
393393
showscale: bool = True,
394394
heat_mode: Literal["value", "fraction", "percent"] | None = "value",
395-
precision: str = None,
395+
precision: str | None = None,
396396
hover_props: Sequence[str] | dict[str, str] | None = None,
397397
hover_data: dict[str, str | int | float] | pd.Series | None = None,
398398
font_colors: Sequence[str] = ("#eee", "black"),
399399
gap: float = 5,
400-
font_size: int = None,
401-
bg_color: str = None,
402-
color_bar: dict[str, Any] = None,
400+
font_size: int | None = None,
401+
bg_color: str | None = None,
402+
color_bar: dict[str, Any] | None = None,
403403
cscale_range: tuple[float | None, float | None] = (None, None),
404404
exclude_elements: Sequence[str] = (),
405405
log: bool = False,

pymatviz/relevance.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
def roc_curve(
1717
targets: ArrayLike | str,
1818
proba_pos: ArrayLike | str,
19-
df: pd.DataFrame = None,
20-
ax: plt.Axes = None,
19+
df: pd.DataFrame | None = None,
20+
ax: plt.Axes | None = None,
2121
) -> tuple[float, plt.Axes]:
2222
"""Plot the receiver operating characteristic curve of a binary classifier given
2323
target labels and predicted probabilities for the positive class.
@@ -51,8 +51,8 @@ def roc_curve(
5151
def precision_recall_curve(
5252
targets: ArrayLike | str,
5353
proba_pos: ArrayLike | str,
54-
df: pd.DataFrame = None,
55-
ax: plt.Axes = None,
54+
df: pd.DataFrame | None = None,
55+
ax: plt.Axes | None = None,
5656
) -> tuple[float, plt.Axes]:
5757
"""Plot the precision recall curve of a binary classifier.
5858

pymatviz/structure_viz.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ class ExperimentalWarning(Warning):
3030
# inspired by ASE https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib
3131

3232

33-
def _angles_to_rotation_matrix(angles: str, rotation: ArrayLike = None) -> ArrayLike:
33+
def _angles_to_rotation_matrix(
34+
angles: str, rotation: ArrayLike | None = None
35+
) -> ArrayLike:
3436
"""Convert Euler angles to a rotation matrix.
3537
3638
Note the order of angles matters. 50x,40z != 40z,50x.
@@ -106,16 +108,16 @@ def unit_cell_to_lines(cell: ArrayLike) -> tuple[ArrayLike, ArrayLike, ArrayLike
106108

107109
def plot_structure_2d(
108110
struct: Structure,
109-
ax: plt.Axes = None,
111+
ax: plt.Axes | None = None,
110112
rotation: str = "10x,10y,0z",
111113
atomic_radii: float | dict[str, float] | None = None,
112-
colors: dict[str, str | list[float]] = None,
114+
colors: dict[str, str | list[float]] | None = None,
113115
scale: float = 1,
114116
show_unit_cell: bool = True,
115117
show_bonds: bool | NearNeighbors = False,
116118
site_labels: bool | dict[str, str | float] | list[str | float] = True,
117-
label_kwargs: dict[str, Any] = None,
118-
bond_kwargs: dict[str, Any] = None,
119+
label_kwargs: dict[str, Any] | None = None,
120+
bond_kwargs: dict[str, Any] | None = None,
119121
standardize_struct: bool | None = None,
120122
axis: bool | str = "off",
121123
) -> plt.Axes:

pymatviz/uncertainty.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ def qq_gaussian(
1818
y_true: ArrayLike | str,
1919
y_pred: ArrayLike | str,
2020
y_std: ArrayLike | dict[str, ArrayLike] | str | Sequence[str],
21-
df: pd.DataFrame = None,
22-
ax: plt.Axes = None,
21+
df: pd.DataFrame | None = None,
22+
ax: plt.Axes | None = None,
2323
) -> plt.Axes:
2424
"""Plot the Gaussian quantile-quantile (Q-Q) plot of one (passed as array) or
2525
multiple (passed as dict) sets of uncertainty estimates for a single pair of ground
@@ -189,10 +189,10 @@ def error_decay_with_uncert(
189189
y_true: ArrayLike | str,
190190
y_pred: ArrayLike | str,
191191
y_std: ArrayLike | dict[str, ArrayLike] | str | Sequence[str],
192-
df: pd.DataFrame = None,
192+
df: pd.DataFrame | None = None,
193193
n_rand: int = 100,
194194
percentiles: bool = True,
195-
ax: plt.Axes = None,
195+
ax: plt.Axes | None = None,
196196
) -> plt.Axes:
197197
"""Plot for assessing the quality of uncertainty estimates. If a model's uncertainty
198198
is well calibrated, i.e. strongly correlated with its error, removing the most

pymatviz/utils.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
def with_hist(
4949
xs: ArrayLike,
5050
ys: ArrayLike,
51-
cell: GridSpec = None,
51+
cell: GridSpec | None = None,
5252
bins: int = 100,
5353
) -> plt.Axes:
5454
"""Call before creating a plot and use the returned `ax_main` for all
@@ -88,10 +88,10 @@ def with_hist(
8888

8989

9090
def annotate_bars(
91-
ax: plt.Axes = None,
91+
ax: plt.Axes | None = None,
9292
v_offset: int | float = 10,
9393
h_offset: int | float = 0,
94-
labels: Sequence[str | int | float] = None,
94+
labels: Sequence[str | int | float] | None = None,
9595
fontsize: int = 14,
9696
y_max_headroom: float = 1.2,
9797
**kwargs: Any,
@@ -140,7 +140,7 @@ def annotate_bars(
140140
def annotate_metrics(
141141
xs: ArrayLike,
142142
ys: ArrayLike,
143-
ax: plt.Axes = None,
143+
ax: plt.Axes | None = None,
144144
metrics: dict[str, float] | Sequence[str] = ("MAE", "$R^2$"),
145145
prefix: str = "",
146146
suffix: str = "",
@@ -248,7 +248,7 @@ def get_crystal_sys(spg: int) -> CrystalSystem:
248248

249249

250250
def add_identity_line(
251-
fig: go.Figure, line_kwds: dict[str, Any] = None, trace_idx: int = 0
251+
fig: go.Figure, line_kwds: dict[str, Any] | None = None, trace_idx: int = 0
252252
) -> go.Figure:
253253
"""Add a line shape to the background layer of a plotly figure spanning
254254
from smallest to largest x/y values in the trace specified by trace_idx.
@@ -296,7 +296,7 @@ def add_identity_line(
296296
def save_fig(
297297
fig: go.Figure | plt.Figure | plt.Axes,
298298
path: str,
299-
plotly_config: dict[str, Any] = None,
299+
plotly_config: dict[str, Any] | None = None,
300300
env_disable: Sequence[str] = ("CI",),
301301
pdf_sleep: float = 0.6,
302302
**kwargs: Any,

0 commit comments

Comments
 (0)