Skip to content

Commit b68a4ee

Browse files
committed
Fix all typing errors within pymatgen core for now.
1 parent cd6e274 commit b68a4ee

16 files changed

+77
-65
lines changed

.github/workflows/lint.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
uv run ruff format --check .
3232
3333
- name: mypy
34-
run: uv run mypy -p pymatgen
34+
run: uv run mypy -p pymatgen.core
3535

3636
- name: pyright
3737
run: uv run pyright src

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ exclude_also = [
301301
ignore_missing_imports = true
302302
namespace_packages = true
303303
no_implicit_optional = false
304-
disable_error_code = ["annotation-unchecked", "override", "operator", "attr-defined", "union-attr"] #, "operator", "arg-type", "index", "call-arg", "return-value", "assignment", "attr-defined"]
304+
disable_error_code = ["annotation-unchecked", "override", "operator", "attr-defined", "union-attr", "misc"] #, "operator", "arg-type", "index", "call-arg", "return-value", "assignment", "attr-defined"]
305305
exclude = ['src/pymatgen/analysis', 'src/pymatgen/io', 'src/pymatgen/cli', 'src/pymatgen/electronic_structure', 'src/pymatgen/phonon']
306306
plugins = ["numpy.typing.mypy_plugin"]
307307

src/pymatgen/command_line/enumlib_caller.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
if TYPE_CHECKING:
4848
from typing import ClassVar
4949

50+
from pymatgen.core.structure import IStructure
51+
5052
logger = logging.getLogger(__name__)
5153

5254
# Favor the use of the newer "enum.x" by Gus Hart over "multienum.x"
@@ -73,7 +75,7 @@ class EnumlibAdaptor:
7375

7476
def __init__(
7577
self,
76-
structure: Structure,
78+
structure: Structure | IStructure,
7779
min_cell_size: int = 1,
7880
max_cell_size: int = 1,
7981
symm_prec: float = 0.1,

src/pymatgen/command_line/mcsqs_caller.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class Sqs(NamedTuple):
3535
"run_mcsqs requires first installing AT-AT, see https://www.brown.edu/Departments/Engineering/Labs/avdw/atat/",
3636
)
3737
def run_mcsqs(
38-
structure: Structure,
38+
structure: Structure | IStructure,
3939
clusters: dict[int, float],
4040
scaling: int | list[int] = 1,
4141
search_time: float = 60,

src/pymatgen/core/bonds.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def obtain_all_bond_lengths(
145145
sp2 = sp2.symbol
146146
syms = tuple(sorted([sp1, sp2]))
147147
if syms in bond_lengths:
148-
return bond_lengths[syms].copy()
148+
return bond_lengths[syms].copy() # type:ignore[index]
149149
if default_bl is not None:
150150
return {1.0: default_bl}
151151
raise ValueError(f"No bond data for elements {syms[0]} - {syms[1]}")

src/pymatgen/core/interface.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from collections.abc import Callable, Sequence
2929
from typing import Any
3030

31-
from numpy.typing import ArrayLike, NDArray
31+
from numpy.typing import NDArray
3232
from typing_extensions import Self
3333

3434
from pymatgen.core import Element, Species
@@ -63,7 +63,7 @@ def __init__(
6363
self,
6464
lattice: np.ndarray | Lattice,
6565
species: Sequence[CompositionLike],
66-
coords: Sequence[ArrayLike],
66+
coords: Sequence[NDArray] | NDArray,
6767
rotation_axis: tuple[int, ...],
6868
rotation_angle: float,
6969
gb_plane: tuple[int, int, int],

src/pymatgen/core/lattice.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1696,12 +1696,12 @@ def get_miller_index_from_coords(
16961696
coords_are_cartesian: bool = True,
16971697
round_dp: int = 4,
16981698
verbose: bool = True,
1699-
) -> tuple[int, ...]:
1699+
) -> tuple[int, int, int]:
17001700
"""Get the Miller index of a plane from a list of site coordinates.
17011701
17021702
A minimum of 3 sets of coordinates are required. If more than 3 sets of
1703-
coordinates are given, the best plane that minimises the distance to all
1704-
points will be calculated.
1703+
coordinates are given, the best plane that minimizes the distance to all
1704+
Points will be calculated.
17051705
17061706
Args:
17071707
coords (iterable): A list or numpy array of coordinates. Can be
@@ -1729,7 +1729,7 @@ def get_miller_index_from_coords(
17291729

17301730
# Get unitary normal vector
17311731
u_norm = vh[2, :]
1732-
return get_integer_index(u_norm, round_dp=round_dp, verbose=verbose)
1732+
return get_integer_index(u_norm, round_dp=round_dp, verbose=verbose) # type: ignore[return-value]
17331733

17341734
def get_recp_symmetry_operation(self, symprec: float = 0.01) -> list[SymmOp]:
17351735
"""Find the symmetric operations of the reciprocal lattice,
@@ -1749,7 +1749,7 @@ def get_recp_symmetry_operation(self, symprec: float = 0.01) -> list[SymmOp]:
17491749
from pymatgen.core.structure import Structure
17501750
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
17511751

1752-
recp = Structure(recp_lattice, ["H"], [[0, 0, 0]])
1752+
recp = Structure(recp_lattice, ["H"], [[0, 0, 0]]) # type:ignore[list-item]
17531753
# Create a function that uses the symmetry operations in the
17541754
# structure to find Miller indices that might give repetitive slabs
17551755
analyzer = SpacegroupAnalyzer(recp, symprec=symprec)

src/pymatgen/core/operations.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def from_axis_angle_and_translation(
262262
axis: NDArray,
263263
angle: float,
264264
angle_in_radians: bool = False,
265-
translation_vec: Sequence[float] = (0, 0, 0),
265+
translation_vec: Sequence[float] | NDArray = (0, 0, 0),
266266
) -> SymmOp:
267267
"""Generate a SymmOp for a rotation about a given axis plus translation.
268268
@@ -301,8 +301,8 @@ def from_axis_angle_and_translation(
301301

302302
@staticmethod
303303
def from_origin_axis_angle(
304-
origin: Sequence[float],
305-
axis: Sequence[float],
304+
origin: Sequence[float] | NDArray,
305+
axis: Sequence[float] | NDArray,
306306
angle: float,
307307
angle_in_radians: bool = False,
308308
) -> SymmOp:

src/pymatgen/core/sites.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pymatgen.util.misc import is_np_dict_equal
1717

1818
if TYPE_CHECKING:
19+
from collections.abc import Sequence
1920
from typing import Any
2021

2122
from numpy.typing import NDArray
@@ -200,7 +201,7 @@ def distance(self, other: Site) -> float:
200201
"""
201202
return float(np.linalg.norm(other.coords - self.coords))
202203

203-
def distance_from_point(self, pt: tuple[float, float, float]) -> float:
204+
def distance_from_point(self, pt: Sequence[float] | NDArray) -> float:
204205
"""Get distance between the site and a point in space.
205206
206207
Args:

src/pymatgen/core/structure.py

+29-29
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def __iter__(self) -> Iterator[PeriodicSite]:
230230
return iter(self.sites)
231231

232232
# TODO return type needs fixing (can be Sequence[PeriodicSite] but raises lots of mypy errors)
233-
def __getitem__(self, ind: int | slice) -> PeriodicSite:
233+
def __getitem__(self, ind: int | slice):
234234
return self.sites[ind] # type: ignore[return-value]
235235

236236
def __len__(self) -> int:
@@ -999,7 +999,7 @@ def from_ase_atoms(cls, atoms: Atoms, **kwargs) -> Self:
999999
"""
10001000
from pymatgen.io.ase import AseAtomsAdaptor
10011001

1002-
return AseAtomsAdaptor.get_structure(atoms, cls=cls, **kwargs)
1002+
return AseAtomsAdaptor.get_structure(atoms, cls=cls, **kwargs) # type:ignore[type-var]
10031003

10041004

10051005
class IStructure(SiteCollection, MSONable):
@@ -2764,7 +2764,7 @@ def get_orderings(
27642764
self,
27652765
mode: Literal["enum", "sqs"] = "enum",
27662766
**kwargs,
2767-
) -> list[Structure | IStructure]:
2767+
) -> list[Structure]:
27682768
"""Get list of orderings for a disordered structure. If structure
27692769
does not contain disorder, the default structure is returned.
27702770
@@ -2784,7 +2784,7 @@ def get_orderings(
27842784
List[Structure]
27852785
"""
27862786
if self.is_ordered:
2787-
return [self]
2787+
return [self] # type:ignore[list-item]
27882788
if mode.startswith("enum"):
27892789
from pymatgen.command_line.enumlib_caller import EnumlibAdaptor
27902790

@@ -2903,7 +2903,7 @@ def from_dict(
29032903
if fmt == "abivars":
29042904
from pymatgen.io.abinit.abiobjects import structure_from_abivars
29052905

2906-
return structure_from_abivars(cls=cls, **dct)
2906+
return structure_from_abivars(cls=cls, **dct) # type:ignore[return-value]
29072907

29082908
lattice = Lattice.from_dict(dct["lattice"])
29092909
sites = [PeriodicSite.from_dict(sd, lattice) for sd in dct["sites"]]
@@ -3117,7 +3117,7 @@ def from_str( # type:ignore[override]
31173117
elif fmt_low == "xsf":
31183118
from pymatgen.io.xcrysden import XSF
31193119

3120-
struct = XSF.from_str(input_string, **kwargs).structure
3120+
struct = XSF.from_str(input_string, **kwargs).structure # type:ignore[assignment]
31213121
elif fmt_low == "mcsqs":
31223122
from pymatgen.io.atat import Mcsqs
31233123

@@ -3257,7 +3257,7 @@ def from_file( # type:ignore[override]
32573257
elif fnmatch(fname, "input*.xml"):
32583258
from pymatgen.io.exciting import ExcitingInput
32593259

3260-
return ExcitingInput.from_file(fname, **kwargs).structure # type:ignore[assignment]
3260+
return ExcitingInput.from_file(fname, **kwargs).structure # type:ignore[assignment, return-value]
32613261
elif fnmatch(fname, "*rndstr.in*") or fnmatch(fname, "*lat.in*") or fnmatch(fname, "*bestsqs*"):
32623262
return cls.from_str(
32633263
contents,
@@ -3270,7 +3270,7 @@ def from_file( # type:ignore[override]
32703270
elif fnmatch(fname, "CTRL*"):
32713271
from pymatgen.io.lmto import LMTOCtrl
32723272

3273-
return LMTOCtrl.from_file(filename=filename, **kwargs).structure # type:ignore[assignment]
3273+
return LMTOCtrl.from_file(filename=filename, **kwargs).structure # type:ignore[assignment,return-value]
32743274
elif fnmatch(fname, "geometry.in*"):
32753275
return cls.from_str(
32763276
contents,
@@ -3496,7 +3496,7 @@ def __init__(
34963496
label = labels[idx] if labels else None
34973497
sites.append(Site(species[idx], coords[idx], properties=prop, label=label))
34983498

3499-
self._sites = tuple(sites)
3499+
self._sites = tuple(sites) # type:ignore[arg-type]
35003500
if validate_proximity and not self.is_valid():
35013501
raise StructureError("Molecule contains sites that are less than 0.01 Angstrom apart!")
35023502

@@ -3785,7 +3785,7 @@ def get_distance(self, i: int, j: int) -> float:
37853785
"""
37863786
return self[i].distance(self[j])
37873787

3788-
def get_sites_in_sphere(self, pt: ArrayLike, r: float) -> list[Neighbor]:
3788+
def get_sites_in_sphere(self, pt: NDArray, r: float) -> list[Neighbor]:
37893789
"""Find all sites within a sphere from a point.
37903790
37913791
Args:
@@ -3825,7 +3825,7 @@ def get_neighbors(self, site: Site, r: float) -> list[Neighbor]:
38253825
nns = self.get_sites_in_sphere(site.coords, r)
38263826
return [nn for nn in nns if nn != site]
38273827

3828-
def get_neighbors_in_shell(self, origin: ArrayLike, r: float, dr: float) -> list[Neighbor]:
3828+
def get_neighbors_in_shell(self, origin: NDArray, r: float, dr: float) -> list[Neighbor]:
38293829
"""Get all sites in a shell centered on origin (coords) between radii
38303830
r-dr and r+dr.
38313831
@@ -3925,7 +3925,7 @@ def get_boxed_structure(
39253925
break
39263926
distances = lattice.get_all_distances(
39273927
lattice.get_fractional_coords(new_coords),
3928-
lattice.get_fractional_coords(all_coords),
3928+
lattice.get_fractional_coords(all_coords), # type:ignore[arg-type]
39293929
)
39303930
if np.amin(distances) > min_dist:
39313931
break
@@ -4104,7 +4104,7 @@ def from_file(cls, filename: PathLike) -> IMolecule | Molecule: # type:ignore[o
41044104
filename = str(filename)
41054105

41064106
with zopen(filename, mode="rt", encoding="utf-8") as file:
4107-
contents: str = file.read()
4107+
contents: str = file.read() # type:ignore[assignment]
41084108
fname = filename.lower()
41094109
if fnmatch(fname, "*.xyz*"):
41104110
return cls.from_str(contents, fmt="xyz")
@@ -4134,9 +4134,9 @@ class Structure(IStructure, collections.abc.MutableSequence):
41344134

41354135
def __init__(
41364136
self,
4137-
lattice: ArrayLike | Lattice,
4137+
lattice: NDArray | Lattice,
41384138
species: Sequence[CompositionLike],
4139-
coords: Sequence[ArrayLike] | np.ndarray,
4139+
coords: Sequence[NDArray] | NDArray,
41404140
charge: float | None = None,
41414141
validate_proximity: bool = False,
41424142
to_unit_cell: bool = False,
@@ -4264,7 +4264,7 @@ def __setitem__(
42644264
else:
42654265
self._sites[ii].species = site[0] # type: ignore[assignment, index]
42664266
if len(site) > 1:
4267-
self._sites[ii].frac_coords = site[1] # type: ignore[index]
4267+
self._sites[ii].frac_coords = site[1] # type: ignore[index,assignment]
42684268
if len(site) > 2:
42694269
self._sites[ii].properties = site[2] # type: ignore[assignment, index]
42704270

@@ -4288,7 +4288,7 @@ def lattice(self, lattice: ArrayLike | Lattice) -> None:
42884288
def append( # type:ignore[override]
42894289
self,
42904290
species: CompositionLike,
4291-
coords: ArrayLike,
4291+
coords: NDArray,
42924292
coords_are_cartesian: bool = False,
42934293
validate_proximity: bool = False,
42944294
properties: dict | None = None,
@@ -4320,7 +4320,7 @@ def insert( # type:ignore[override]
43204320
self,
43214321
idx: int,
43224322
species: CompositionLike,
4323-
coords: ArrayLike,
4323+
coords: NDArray,
43244324
coords_are_cartesian: bool = False,
43254325
validate_proximity: bool = False,
43264326
properties: dict | None = None,
@@ -4358,7 +4358,7 @@ def replace(
43584358
self,
43594359
idx: int,
43604360
species: CompositionLike,
4361-
coords: ArrayLike | None = None,
4361+
coords: NDArray | None = None,
43624362
coords_are_cartesian: bool = False,
43634363
properties: dict | None = None,
43644364
label: str | None = None,
@@ -4757,7 +4757,7 @@ def make_supercell(
47574757
scaling_matrix: ArrayLike,
47584758
to_unit_cell: bool = True,
47594759
in_place: bool = True,
4760-
) -> Self:
4760+
) -> Structure:
47614761
"""Create a supercell.
47624762
47634763
Args:
@@ -4793,7 +4793,7 @@ def make_supercell(
47934793

47944794
return struct
47954795

4796-
def scale_lattice(self, volume: float) -> Self:
4796+
def scale_lattice(self, volume: float) -> Structure:
47974797
"""Perform scaling of the lattice vectors so that length proportions
47984798
and angles are preserved.
47994799
@@ -4920,7 +4920,7 @@ def relax(
49204920
Structure | tuple[Structure, Trajectory]: Relaxed structure or if return_trajectory=True,
49214921
2-tuple of Structure and matgl TrajectoryObserver.
49224922
"""
4923-
return self._relax(
4923+
return self._relax( # type:ignore[return-value]
49244924
calculator,
49254925
relax_cell=relax_cell,
49264926
optimizer=optimizer,
@@ -5111,7 +5111,7 @@ def __init__(
51115111
charge_spin_check=charge_spin_check,
51125112
properties=properties,
51135113
)
5114-
self._sites: list[Site] = list(self._sites)
5114+
self._sites: list[Site] = list(self._sites) # type:ignore[assignment]
51155115

51165116
def __setitem__(
51175117
self,
@@ -5161,7 +5161,7 @@ def __delitem__(self, idx: SupportsIndex | slice) -> None:
51615161
def append( # type:ignore[override]
51625162
self,
51635163
species: CompositionLike,
5164-
coords: ArrayLike,
5164+
coords: NDArray,
51655165
validate_proximity: bool = False,
51665166
properties: dict | None = None,
51675167
) -> Self:
@@ -5250,9 +5250,9 @@ def insert( # type:ignore[override]
52505250
new_site = Site(species, coords, properties=properties, label=label)
52515251
if validate_proximity:
52525252
for site in self:
5253-
if site.distance(new_site) < self.DISTANCE_TOLERANCE:
5253+
if site.distance(new_site) < self.DISTANCE_TOLERANCE: # type:ignore[arg-type]
52545254
raise ValueError("New site is too close to an existing site!")
5255-
cast("list[PeriodicSite]", self.sites).insert(idx, new_site)
5255+
cast("list[PeriodicSite]", self.sites).insert(idx, new_site) # type:ignore[arg-type]
52565256

52575257
return self
52585258

@@ -5278,7 +5278,7 @@ def remove_species(self, species: Sequence[SpeciesLike]) -> Self:
52785278
label=site.label,
52795279
)
52805280
)
5281-
self.sites = new_sites
5281+
self.sites = new_sites # type:ignore[assignment]
52825282
return self
52835283

52845284
def remove_sites(self, indices: Sequence[int]) -> Self:
@@ -5466,7 +5466,7 @@ def substitute(
54665466
# Check whether the functional group is in database.
54675467
if func_group not in FunctionalGroups:
54685468
raise RuntimeError("Can't find functional group in list. Provide explicit coordinate instead")
5469-
functional_group = FunctionalGroups[func_group]
5469+
functional_group = FunctionalGroups[func_group] # type:ignore[assignment]
54705470

54715471
# If a bond length can be found, modify func_grp so that the X-group
54725472
# bond length is equal to the bond length.
@@ -5536,7 +5536,7 @@ def relax(
55365536
Molecule | tuple[Molecule, Trajectory]: Relaxed Molecule or if return_trajectory=True,
55375537
2-tuple of Molecule and ASE TrajectoryObserver.
55385538
"""
5539-
return self._relax(
5539+
return self._relax( # type:ignore[return-value]
55405540
calculator,
55415541
relax_cell=False,
55425542
optimizer=optimizer,

0 commit comments

Comments
 (0)