Skip to content

Commit 3ec8f43

Browse files
committed
More stringent type checking.
1 parent b68a4ee commit 3ec8f43

File tree

2 files changed

+81
-75
lines changed

2 files changed

+81
-75
lines changed

src/pymatgen/core/lattice.py

+60-53
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,9 @@ def __init__(
6767
mat = np.array(matrix, dtype=np.float64).reshape((3, 3))
6868
mat.setflags(write=False)
6969
self._matrix: NDArray[np.float64] = mat
70-
self._inv_matrix: np.ndarray | None = None
70+
self._inv_matrix: NDArray[np.float64] | None = None
7171
self._diags = None
72-
self._lll_matrix_mappings: dict[float, tuple[np.ndarray, np.ndarray]] = {}
72+
self._lll_matrix_mappings: dict[float, tuple[NDArray[np.float64], NDArray[np.float64]]] = {}
7373
self._lll_inverse = None
7474

7575
self.pbc = pbc
@@ -178,7 +178,7 @@ def is_orthogonal(self) -> bool:
178178
return all(abs(a - 90) < 1e-5 for a in self.angles)
179179

180180
@property
181-
def matrix(self) -> np.ndarray:
181+
def matrix(self) -> NDArray[np.float64]:
182182
"""Copy of matrix representing the Lattice."""
183183
return self._matrix
184184

@@ -188,23 +188,23 @@ def is_3d_periodic(self) -> bool:
188188
return all(self.pbc)
189189

190190
@property
191-
def inv_matrix(self) -> np.ndarray:
191+
def inv_matrix(self) -> NDArray[np.float64]:
192192
"""Inverse of lattice matrix."""
193193
if self._inv_matrix is None:
194-
self._inv_matrix = np.linalg.inv(self._matrix)
194+
self._inv_matrix = np.linalg.inv(self._matrix) # type: ignore[assignment]
195195
self._inv_matrix.setflags(write=False)
196-
return self._inv_matrix
196+
return self._inv_matrix # type: ignore[return-value]
197197

198198
@property
199-
def metric_tensor(self) -> np.ndarray:
199+
def metric_tensor(self) -> NDArray[np.float64]:
200200
"""The metric tensor of the lattice."""
201201
return np.dot(self._matrix, self._matrix.T)
202202

203203
def copy(self) -> Self:
204204
"""Make a copy of this lattice."""
205205
return type(self)(self.matrix.copy(), pbc=self.pbc)
206206

207-
def get_cartesian_coords(self, fractional_coords: ArrayLike) -> np.ndarray:
207+
def get_cartesian_coords(self, fractional_coords: ArrayLike) -> NDArray[np.float64]:
208208
"""Get the Cartesian coordinates given fractional coordinates.
209209
210210
Args:
@@ -215,7 +215,7 @@ def get_cartesian_coords(self, fractional_coords: ArrayLike) -> np.ndarray:
215215
"""
216216
return np.dot(fractional_coords, self._matrix)
217217

218-
def get_fractional_coords(self, cart_coords: ArrayLike) -> np.ndarray:
218+
def get_fractional_coords(self, cart_coords: ArrayLike) -> NDArray[np.float64]:
219219
"""Get the fractional coordinates given Cartesian coordinates.
220220
221221
Args:
@@ -229,7 +229,7 @@ def get_fractional_coords(self, cart_coords: ArrayLike) -> np.ndarray:
229229
def get_vector_along_lattice_directions(
230230
self,
231231
cart_coords: ArrayLike,
232-
) -> np.ndarray:
232+
) -> NDArray[np.float64]:
233233
"""Get the coordinates along lattice directions given Cartesian coordinates.
234234
235235
Note, this is different than a projection of the Cartesian vector along the
@@ -540,23 +540,23 @@ def reciprocal_lattice_crystallographic(self) -> Self:
540540
return type(self)(self.reciprocal_lattice.matrix / (2 * np.pi))
541541

542542
@property
543-
def lll_matrix(self) -> np.ndarray:
543+
def lll_matrix(self) -> NDArray[np.float64]:
544544
"""The matrix for LLL reduction."""
545545
if 0.75 not in self._lll_matrix_mappings:
546546
self._lll_matrix_mappings[0.75] = self._calculate_lll()
547547
return self._lll_matrix_mappings[0.75][0]
548548

549549
@property
550-
def lll_mapping(self) -> np.ndarray:
550+
def lll_mapping(self) -> NDArray[np.float64]:
551551
"""The mapping between the LLL reduced lattice and the original lattice."""
552552
if 0.75 not in self._lll_matrix_mappings:
553553
self._lll_matrix_mappings[0.75] = self._calculate_lll()
554554
return self._lll_matrix_mappings[0.75][1]
555555

556556
@property
557-
def lll_inverse(self) -> np.ndarray:
557+
def lll_inverse(self) -> NDArray[np.float64]:
558558
"""Inverse of self.lll_mapping."""
559-
return np.linalg.inv(self.lll_mapping)
559+
return np.linalg.inv(self.lll_mapping) # type: ignore[return-value]
560560

561561
@property
562562
def selling_vector(self) -> NDArray[np.float64]:
@@ -922,7 +922,7 @@ def find_all_mappings(
922922
ltol: float = 1e-5,
923923
atol: float = 1,
924924
skip_rotation_matrix: bool = False,
925-
) -> Iterator[tuple[Lattice, np.ndarray | None, np.ndarray]]:
925+
) -> Iterator[tuple[Lattice, NDArray[np.float64] | None, NDArray[np.float64]]]:
926926
"""Find all mappings between current lattice and another lattice.
927927
928928
Args:
@@ -991,7 +991,7 @@ def find_mapping(
991991
ltol: float = 1e-5,
992992
atol: float = 1,
993993
skip_rotation_matrix: bool = False,
994-
) -> tuple[Lattice, np.ndarray | None, np.ndarray] | None:
994+
) -> tuple[Lattice, NDArray[np.float64] | None, NDArray[np.float64]] | None:
995995
"""Find a mapping between current lattice and another lattice. There
996996
are an infinite number of choices of basis vectors for two entirely
997997
equivalent lattices. This method returns a mapping that maps
@@ -1006,7 +1006,7 @@ def find_mapping(
10061006
Defaults to False.
10071007
10081008
Returns:
1009-
tuple[Lattice, np.ndarray, np.ndarray]: (aligned_lattice, rotation_matrix, scale_matrix)
1009+
tuple[Lattice, NDArray[np.float_], NDArray[np.float_]]: (aligned_lattice, rotation_matrix, scale_matrix)
10101010
if a mapping is found. aligned_lattice is a rotated version of other_lattice that
10111011
has the same lattice parameters, but which is aligned in the
10121012
coordinate system of this lattice so that translational points
@@ -1039,7 +1039,7 @@ def get_lll_reduced_lattice(self, delta: float = 0.75) -> Self:
10391039
self._lll_matrix_mappings[delta] = self._calculate_lll()
10401040
return type(self)(self._lll_matrix_mappings[delta][0])
10411041

1042-
def _calculate_lll(self, delta: float = 0.75) -> tuple[np.ndarray, np.ndarray]:
1042+
def _calculate_lll(self, delta: float = 0.75) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
10431043
"""Perform a Lenstra-Lenstra-Lovasz lattice basis reduction to obtain a
10441044
c-reduced basis. This method returns a basis which is as "good" as
10451045
possible, with "good" defined by orthogonality of the lattice vectors.
@@ -1116,15 +1116,15 @@ def _calculate_lll(self, delta: float = 0.75) -> tuple[np.ndarray, np.ndarray]:
11161116
result = np.linalg.lstsq(q.T, p.T, rcond=None)[0].T
11171117
u[k:3, (k - 2) : k] = result
11181118

1119-
return a.T, mapping.T
1119+
return a.T, mapping.T # type: ignore[return-value]
11201120

1121-
def get_lll_frac_coords(self, frac_coords: ArrayLike) -> np.ndarray:
1121+
def get_lll_frac_coords(self, frac_coords: ArrayLike) -> NDArray[np.float64]:
11221122
"""Given fractional coordinates in the lattice basis, returns corresponding
11231123
fractional coordinates in the lll basis.
11241124
"""
11251125
return np.dot(frac_coords, self.lll_inverse)
11261126

1127-
def get_frac_coords_from_lll(self, lll_frac_coords: ArrayLike) -> np.ndarray:
1127+
def get_frac_coords_from_lll(self, lll_frac_coords: ArrayLike) -> NDArray[np.float64]:
11281128
"""Given fractional coordinates in the lll basis, returns corresponding
11291129
fractional coordinates in the lattice basis.
11301130
"""
@@ -1292,7 +1292,7 @@ def scale(self, new_volume: float) -> Self:
12921292

12931293
return type(self)(versors * (new_c * ratios), pbc=self.pbc)
12941294

1295-
def get_wigner_seitz_cell(self) -> list[list[np.ndarray]]:
1295+
def get_wigner_seitz_cell(self) -> list[list[NDArray[np.float64]]]:
12961296
"""Get the Wigner-Seitz cell for the given lattice.
12971297
12981298
Returns:
@@ -1316,7 +1316,7 @@ def get_wigner_seitz_cell(self) -> list[list[np.ndarray]]:
13161316

13171317
return out
13181318

1319-
def get_brillouin_zone(self) -> list[list[np.ndarray]]:
1319+
def get_brillouin_zone(self) -> list[list[NDArray[np.float64]]]:
13201320
"""Get the Wigner-Seitz cell for the reciprocal lattice, aka the
13211321
Brillouin Zone.
13221322
@@ -1333,7 +1333,7 @@ def dot(
13331333
coords_a: ArrayLike,
13341334
coords_b: ArrayLike,
13351335
frac_coords: bool = False,
1336-
) -> np.ndarray:
1336+
) -> NDArray[np.float64]:
13371337
"""Compute the scalar product of vector(s).
13381338
13391339
Args:
@@ -1361,7 +1361,7 @@ def dot(
13611361

13621362
return np.array(list(itertools.starmap(np.dot, zip(cart_a, cart_b, strict=True))))
13631363

1364-
def norm(self, coords: ArrayLike, frac_coords: bool = True) -> np.ndarray:
1364+
def norm(self, coords: ArrayLike, frac_coords: bool = True) -> NDArray[np.float64]:
13651365
"""Compute the norm of vector(s).
13661366
13671367
Args:
@@ -1448,7 +1448,7 @@ def get_points_in_sphere_py(
14481448
center: ArrayLike,
14491449
r: float,
14501450
zip_results: bool = True,
1451-
) -> list[tuple[np.ndarray, float, int, np.ndarray]] | list[np.ndarray]:
1451+
) -> list[tuple[NDArray[np.float64], float, int, NDArray[np.float64]]] | list[NDArray[np.float64]]:
14521452
"""Find all points within a sphere from the point taking into account
14531453
periodic boundary conditions. This includes sites in other periodic
14541454
images.
@@ -1503,8 +1503,8 @@ def get_points_in_sphere_old(
15031503
r: float,
15041504
zip_results=True,
15051505
) -> (
1506-
list[tuple[np.ndarray, float, int, np.ndarray]]
1507-
| tuple[list[np.ndarray], list[float], list[int], list[np.ndarray]]
1506+
list[tuple[NDArray[np.float64], float, int, NDArray[np.float64]]]
1507+
| tuple[list[NDArray[np.float64]], list[float], list[int], list[NDArray[np.float64]]]
15081508
):
15091509
"""Find all points within a sphere from the point taking into account
15101510
periodic boundary conditions. This includes sites in other periodic
@@ -1603,7 +1603,7 @@ def get_all_distances(
16031603
self,
16041604
frac_coords1: ArrayLike,
16051605
frac_coords2: ArrayLike,
1606-
) -> np.ndarray:
1606+
) -> NDArray[np.float64]:
16071607
"""Get the distances between two lists of coordinates taking into
16081608
account periodic boundary conditions and the lattice. Note that this
16091609
computes an MxN array of distances (i.e. the distance between each
@@ -1657,7 +1657,7 @@ def get_distance_and_image(
16571657
frac_coords1: ArrayLike,
16581658
frac_coords2: ArrayLike,
16591659
jimage: ArrayLike | None = None,
1660-
) -> tuple[float, np.ndarray]:
1660+
) -> tuple[float, NDArray[np.int_]]:
16611661
"""Get distance between two frac_coords assuming periodic boundary
16621662
conditions. If the index jimage is not specified it selects the j
16631663
image nearest to the i atom and returns the distance and jimage
@@ -1675,7 +1675,7 @@ def get_distance_and_image(
16751675
the image that is nearest to the site is found.
16761676
16771677
Returns:
1678-
tuple[float, np.ndarray]: distance and periodic lattice translations (jimage)
1678+
tuple[float, NDArray[np.int_]]: distance and periodic lattice translations (jimage)
16791679
of the other site for which the distance applies. This means that
16801680
the distance between frac_coords1 and (jimage + frac_coords2) is
16811681
equal to distance.
@@ -1819,7 +1819,7 @@ def get_points_in_spheres(
18191819
numerical_tol: float = 1e-8,
18201820
lattice: Lattice | None = None,
18211821
return_fcoords: bool = False,
1822-
) -> list[list[tuple[NDArray[np.float64], float, int, NDArray[np.int_]]]]:
1822+
) -> list[list[tuple[NDArray[np.float64], float, int, NDArray[np.float64]]]]:
18231823
"""For each point in `center_coords`, get all the neighboring points
18241824
in `all_coords` that are within the cutoff radius `r`.
18251825
@@ -1893,12 +1893,19 @@ def get_points_in_spheres(
18931893
if not valid_coords:
18941894
return [[]] * len(center_coords)
18951895
valid_coords = np.concatenate(valid_coords, axis=0)
1896-
valid_images = np.concatenate(valid_images, axis=0)
1896+
return get_points_in_spheres(
1897+
valid_coords, # type:ignore[arg-type]
1898+
center_coords,
1899+
r,
1900+
pbc,
1901+
numerical_tol,
1902+
lattice,
1903+
return_fcoords=return_fcoords,
1904+
)
18971905

1898-
else:
1899-
valid_coords = all_coords # type: ignore[assignment]
1900-
valid_images = [[0, 0, 0]] * len(valid_coords)
1901-
valid_indices = np.arange(len(valid_coords)) # type: ignore[assignment]
1906+
valid_coords = all_coords # type: ignore[assignment]
1907+
valid_images = [[0, 0, 0]] * len(valid_coords) # type: ignore[list-item]
1908+
valid_indices = np.arange(len(valid_coords)) # type: ignore[assignment]
19021909

19031910
# Divide the valid 3D space into cubes and compute the cube ids
19041911
all_cube_index = _compute_cube_index(valid_coords, global_min, r) # type: ignore[arg-type]
@@ -1910,16 +1917,16 @@ def get_points_in_spheres(
19101917
cube_to_coords: dict[int, list] = defaultdict(list)
19111918
cube_to_images: dict[int, list] = defaultdict(list)
19121919
cube_to_indices: dict[int, list] = defaultdict(list)
1913-
for ii, jj, kk, ll in zip(all_cube_index.ravel(), valid_coords, valid_images, valid_indices, strict=True):
1914-
cube_to_coords[ii].append(jj)
1915-
cube_to_images[ii].append(kk)
1916-
cube_to_indices[ii].append(ll)
1920+
for ii, jj, kk, ll in zip(all_cube_index.ravel(), valid_coords, valid_images, valid_indices, strict=True): # type: ignore[assignment]
1921+
cube_to_coords[ii].append(jj) # type: ignore[index]
1922+
cube_to_images[ii].append(kk) # type: ignore[index]
1923+
cube_to_indices[ii].append(ll) # type: ignore[index]
19171924

19181925
# Find all neighboring cubes for each atom in the lattice cell
19191926
site_neighbors = find_neighbors(site_cube_index, nx, ny, nz)
1920-
neighbors: list[list[tuple[np.ndarray, float, int, np.ndarray]]] = []
1927+
neighbors: list[list[tuple[NDArray[np.float64], float, int, NDArray[np.float64]]]] = []
19211928

1922-
for ii, jj in zip(center_coords, site_neighbors, strict=True):
1929+
for cc, jj in zip(center_coords, site_neighbors, strict=True):
19231930
l1 = np.array(_three_to_one(jj, ny, nz), dtype=np.int64).ravel()
19241931
# Use the cube index map to find the all the neighboring
19251932
# coords, images, and indices
@@ -1930,8 +1937,8 @@ def get_points_in_spheres(
19301937
nn_coords = np.concatenate([cube_to_coords[k] for k in ks], axis=0) # type:ignore[index]
19311938
nn_images = itertools.chain(*(cube_to_images[k] for k in ks)) # type:ignore[index]
19321939
nn_indices = itertools.chain(*(cube_to_indices[k] for k in ks)) # type:ignore[index]
1933-
distances = np.linalg.norm(nn_coords - ii[None, :], axis=1)
1934-
nns: list[tuple[np.ndarray, float, int, np.ndarray]] = []
1940+
distances = np.linalg.norm(nn_coords - cc[None, :], axis=1) # type:ignore[index]
1941+
nns: list[tuple[NDArray[np.float64], float, int, NDArray[np.float64]]] = []
19351942
for coord, index, image, dist in zip(nn_coords, nn_indices, nn_images, distances, strict=True):
19361943
# Filtering out all sites that are beyond the cutoff
19371944
# Here there is no filtering of overlapping sites
@@ -1946,23 +1953,23 @@ def get_points_in_spheres(
19461953

19471954
# The following internal functions are used in the get_points_in_sphere method
19481955
def _compute_cube_index(
1949-
coords: np.ndarray,
1956+
coords: NDArray[np.float64],
19501957
global_min: float,
19511958
radius: float,
1952-
) -> np.ndarray:
1959+
) -> NDArray[np.int_]:
19531960
"""Compute the cube index from coordinates
19541961
Args:
19551962
coords: (nx3 array) atom coordinates
19561963
global_min: (float) lower boundary of coordinates
19571964
radius: (float) cutoff radius.
19581965
19591966
Returns:
1960-
np.ndarray: nx3 array int indices
1967+
NDArray[np.float_]: nx3 array int indices
19611968
"""
1962-
return np.array(np.floor((coords - global_min) / radius), dtype=np.int64)
1969+
return np.array(np.floor((coords - global_min) / radius), dtype=np.int_)
19631970

19641971

1965-
def _one_to_three(label1d: np.ndarray, ny: int, nz: int) -> np.ndarray:
1972+
def _one_to_three(label1d: NDArray[np.int_], ny: int, nz: int) -> NDArray[np.int_]:
19661973
"""Convert a 1D index array to 3D index array.
19671974
19681975
Args:
@@ -1971,20 +1978,20 @@ def _one_to_three(label1d: np.ndarray, ny: int, nz: int) -> np.ndarray:
19711978
nz: (int) number of cells in z direction
19721979
19731980
Returns:
1974-
np.ndarray: nx3 array int indices
1981+
NDArray[np.float_]: nx3 array int indices
19751982
"""
19761983
last = np.mod(label1d, nz)
19771984
second = np.mod((label1d - last) / nz, ny)
19781985
first = (label1d - last - second * nz) / (ny * nz)
19791986
return np.concatenate([first, second, last], axis=1)
19801987

19811988

1982-
def _three_to_one(label3d: np.ndarray, ny: int, nz: int) -> np.ndarray:
1989+
def _three_to_one(label3d: NDArray[np.int_], ny: int, nz: int) -> NDArray[np.int_]:
19831990
"""The reverse of _one_to_three."""
19841991
return np.array(label3d[:, 0] * ny * nz + label3d[:, 1] * nz + label3d[:, 2]).reshape((-1, 1))
19851992

19861993

1987-
def find_neighbors(label: np.ndarray, nx: int, ny: int, nz: int) -> list[np.ndarray]:
1994+
def find_neighbors(label: NDArray[np.int_], nx: int, ny: int, nz: int) -> list[NDArray[np.int_]]:
19881995
"""Given a cube index, find the neighbor cube indices.
19891996
19901997
Args:

0 commit comments

Comments
 (0)