Skip to content

Commit 20e6bff

Browse files
committed
More type fixes.
1 parent dae7d43 commit 20e6bff

File tree

7 files changed

+28
-31
lines changed

7 files changed

+28
-31
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ ignore_missing_imports = true
302302
namespace_packages = true
303303
no_implicit_optional = false
304304
disable_error_code = ["annotation-unchecked", "override", "operator", "attr-defined", "union-attr", "misc"] #, "operator", "arg-type", "index", "call-arg", "return-value", "assignment", "attr-defined"]
305-
exclude = ['src/pymatgen/analysis', 'src/pymatgen/io', 'src/pymatgen/cli', 'src/pymatgen/electronic_structure', 'src/pymatgen/phonon', 'src/pymatgen/vis', "src/pymatgen/alchemy"]
305+
exclude = ['src/pymatgen/analysis', 'src/pymatgen/io', 'src/pymatgen/cli', 'src/pymatgen/electronic_structure', 'src/pymatgen/phonon', "src/pymatgen/alchemy"]
306306
plugins = ["numpy.typing.mypy_plugin"]
307307

308308
[[tool.mypy.overrides]]

src/pymatgen/core/spectrum.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from collections.abc import Callable
1818
from typing import Literal
1919

20-
from numpy.typing import NDArray
20+
from numpy.typing import ArrayLike, NDArray
2121
from typing_extensions import Self
2222

2323

@@ -50,7 +50,7 @@ class Spectrum(MSONable):
5050
XLABEL = "x"
5151
YLABEL = "y"
5252

53-
def __init__(self, x: NDArray, y: NDArray, *args, **kwargs) -> None:
53+
def __init__(self, x: ArrayLike, y: ArrayLike, *args, **kwargs) -> None:
5454
"""
5555
Args:
5656
x (ndarray): A ndarray of N values.
@@ -62,8 +62,8 @@ def __init__(self, x: NDArray, y: NDArray, *args, **kwargs) -> None:
6262
etc. operators work properly.
6363
**kwargs: Same as that for *args.
6464
"""
65-
self.x = np.array(x)
66-
self.y = np.array(y)
65+
self.x = np.asarray(x)
66+
self.y = np.asarray(y)
6767
self.ydim = self.y.shape
6868
if self.x.shape[0] != self.ydim[0]:
6969
raise ValueError("x and y values have different first dimension!")

src/pymatgen/electronic_structure/core.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
if TYPE_CHECKING:
1414
from collections.abc import Sequence
1515

16-
from numpy.typing import NDArray
16+
from numpy.typing import ArrayLike, NDArray
1717
from typing_extensions import Self
1818

1919
from pymatgen.core import Lattice
@@ -127,7 +127,7 @@ class Magmom(MSONable):
127127
def __init__(
128128
self,
129129
moment: MagMomentLike,
130-
saxis: tuple[float, float, float] = (0, 0, 1),
130+
saxis: ArrayLike = (0, 0, 1),
131131
) -> None:
132132
"""
133133
Args:
@@ -137,10 +137,10 @@ def __init__(
137137
"""
138138
# Init from another Magmom instance
139139
if isinstance(moment, type(self)):
140-
saxis = moment.saxis # type: ignore[has-type]
141-
moment = moment.moment # type: ignore[has-type]
140+
saxis = moment.saxis
141+
moment = moment.moment
142142

143-
magmom: NDArray = np.array(moment, dtype="d")
143+
magmom = np.array(moment, dtype="d")
144144
if magmom.ndim == 0:
145145
magmom = magmom * (0, 0, 1) # (ruff-preview) noqa: PLR6104
146146

src/pymatgen/electronic_structure/dos.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,9 @@
2424
np.trapezoid = np.trapz # noqa: NPY201
2525

2626
if TYPE_CHECKING:
27-
from collections.abc import Sequence
2827
from typing import Any, Literal
2928

30-
from numpy.typing import NDArray
29+
from numpy.typing import ArrayLike, NDArray
3130
from typing_extensions import Self
3231

3332
from pymatgen.core.sites import PeriodicSite
@@ -48,7 +47,7 @@ class DOS(Spectrum):
4847
XLABEL = "Energy"
4948
YLABEL = "Density"
5049

51-
def __init__(self, energies: Sequence[float], densities: NDArray, efermi: float) -> None:
50+
def __init__(self, energies: ArrayLike, densities: ArrayLike, efermi: float) -> None:
5251
"""
5352
Args:
5453
energies (Sequence[float]): The Energies.
@@ -181,8 +180,8 @@ class Dos(MSONable):
181180
def __init__(
182181
self,
183182
efermi: float,
184-
energies: Sequence[float],
185-
densities: dict[Spin, NDArray],
183+
energies: ArrayLike,
184+
densities: dict[Spin, ArrayLike],
186185
norm_vol: float | None = None,
187186
) -> None:
188187
"""
@@ -196,10 +195,10 @@ def __init__(
196195
otherwise will be in states/eV/Angstrom^3.
197196
"""
198197
self.efermi = efermi
199-
self.energies = np.array(energies)
198+
self.energies = np.asarray(energies)
200199
self.norm_vol = norm_vol
201200
vol = norm_vol or 1
202-
self.densities = {k: np.array(d) / vol for k, d in densities.items()}
201+
self.densities = {k: np.asarray(d) / vol for k, d in densities.items()}
203202

204203
def __add__(self, other):
205204
"""Add two Dos.

src/pymatgen/phonon/dos.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@
1919
np.trapezoid = np.trapz # noqa: NPY201
2020

2121
if TYPE_CHECKING:
22-
from collections.abc import Sequence
23-
24-
from numpy.typing import NDArray
22+
from numpy.typing import ArrayLike, NDArray
2523
from typing_extensions import Self
2624

2725
BOLTZ_THZ_PER_K = const.value("Boltzmann constant in Hz/K") / const.tera # Boltzmann constant in THz/K
@@ -31,7 +29,7 @@
3129
class PhononDos(MSONable):
3230
"""Basic DOS object. All other DOS objects are extended versions of this object."""
3331

34-
def __init__(self, frequencies: Sequence, densities: Sequence) -> None:
32+
def __init__(self, frequencies: ArrayLike, densities: ArrayLike) -> None:
3533
"""
3634
Args:
3735
frequencies: A sequence of frequencies in THz
@@ -141,7 +139,7 @@ def __str__(self) -> str:
141139
return "\n".join(str_arr)
142140

143141
@classmethod
144-
def from_dict(cls, dct: dict[str, Sequence]) -> Self:
142+
def from_dict(cls, dct: dict[str, ArrayLike]) -> Self:
145143
"""Get PhononDos object from dict representation of PhononDos."""
146144
return cls(dct["frequencies"], dct["densities"])
147145

src/pymatgen/phonon/thermal_displacements.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ class ThermalDisplacementMatrices(MSONable):
5151

5252
def __init__(
5353
self,
54-
thermal_displacement_matrix_cart: ArrayLike[ArrayLike],
54+
thermal_displacement_matrix_cart: ArrayLike,
5555
structure: Structure,
5656
temperature: float | None,
57-
thermal_displacement_matrix_cif: ArrayLike[ArrayLike] = None,
57+
thermal_displacement_matrix_cif: ArrayLike | None = None,
5858
) -> None:
5959
"""
6060
Args:
@@ -89,8 +89,8 @@ def __init__(
8989

9090
@staticmethod
9191
def get_full_matrix(
92-
thermal_displacement: ArrayLike[ArrayLike],
93-
) -> np.ndarray[np.ndarray]:
92+
thermal_displacement: ArrayLike,
93+
) -> np.ndarray:
9494
"""Transfers the reduced matrix to the full matrix (order of reduced matrix U11, U22, U33, U23, U13, U12).
9595
9696
Args:
@@ -115,8 +115,8 @@ def get_full_matrix(
115115

116116
@staticmethod
117117
def get_reduced_matrix(
118-
thermal_displacement: ArrayLike[ArrayLike],
119-
) -> np.ndarray[np.ndarray]:
118+
thermal_displacement: ArrayLike,
119+
) -> np.ndarray:
120120
"""Transfers the full matrix to reduced matrix (order of reduced matrix U11, U22, U33, U23, U13, U12).
121121
122122
Args:
@@ -420,7 +420,7 @@ def ratio_prolate(self) -> np.ndarray:
420420
@classmethod
421421
def from_Ucif(
422422
cls,
423-
thermal_displacement_matrix_cif: ArrayLike[ArrayLike],
423+
thermal_displacement_matrix_cif: ArrayLike,
424424
structure: Structure,
425425
temperature: float | None = None,
426426
) -> Self:

src/pymatgen/vis/structure_vtk.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def contains_anion(site):
256256
if sp.symbol in self.excluded_bonding_elements or sp == anion:
257257
exclude = True
258258
break
259-
max_radius = max(max_radius, sp.average_ionic_radius)
259+
max_radius = max(max_radius, sp.average_ionic_radius) # type:ignore[type-var,assignment]
260260
color += occu * np.array(self.el_color_mapping.get(sp.symbol, [0, 0, 0]))
261261

262262
if not exclude:
@@ -982,7 +982,7 @@ def set_structures(self, structures: Sequence[Structure], tags=None):
982982
struct_radii = []
983983
struct_vis_radii = []
984984
for site in struct:
985-
radius = 0
985+
radius = 0.0
986986
vis_radius = 0.2
987987
for species, occu in site.species.items():
988988
radius += occu * (

0 commit comments

Comments
 (0)