Skip to content

Commit 1020d88

Browse files
committed
tests: fix all pyright complaints in tests
1 parent 183ed42 commit 1020d88

13 files changed

+127
-104
lines changed

.pylintrc

+2
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,8 @@ disable=too-many-lines,
438438
# Throws false positives that prevent abstract class hierarchies; The check is
439439
# handled at instantiation time by other linting rules.
440440
abstract-method,
441+
# Ruff's line-too-long handling is better - no need to duplicate it here
442+
line-too-long,
441443

442444
# Enable the message, report, category or checker with the given id(s). You can
443445
# either give multiple identifier separated by comma (,) or put this option

tests/coverage/compare.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def get_most_recent_coverage_total(reference_directory: Path) -> float:
121121
print("**WARNING: No historic coverage data found.**")
122122
return 0
123123

124-
most_recent_file = max(files.keys(), key=files.get)
124+
most_recent_file = max(files.keys(), key=files.__getitem__)
125125

126126
with open(most_recent_file, "r", encoding="utf8") as f:
127127
coverage_dict = json.load(f)

tests/performance/cases/basic_coresets.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def _setup_dataset():
5656
random_state=random_seed,
5757
return_centers=True,
5858
)
59+
x = jnp.asarray(x)
5960

6061
# Setup the original data object
6162
data = Data(x)
@@ -92,7 +93,7 @@ def setup_stein():
9293
data, length_scale = _setup_dataset()
9394
# We use kernel density matching rather than sliced score matching as it's much
9495
# faster than the sliced score matching used in the original unit test
95-
matcher = KernelDensityMatching(length_scale=length_scale)
96+
matcher = KernelDensityMatching(length_scale=length_scale.item())
9697
stein_kernel = SteinKernel(
9798
PCIMQKernel(length_scale=length_scale),
9899
matcher.match(jnp.asarray(data)),

tests/performance/compare.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def get_most_recent_historic_data(
161161
# fine since there aren't any results present to normalise
162162
return {"results": {}, "normalisation": {"compilation": 0.0, "execution": 0.0}}
163163

164-
most_recent_file = max(files.keys(), key=files.get)
164+
most_recent_file = max(files.keys(), key=files.__getitem__)
165165

166166
with open(most_recent_file, "r", encoding="utf8") as f:
167167
return json.load(f)

tests/unit/test_approximation.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import jax.random as jr
2727
import numpy as np
2828
import pytest
29-
from jax.typing import ArrayLike
29+
from jax import Array
3030

3131
from coreax.approximation import (
3232
ANNchorApproximateKernel,
@@ -42,11 +42,11 @@
4242

4343
class _Problem(NamedTuple):
4444
random_key: KeyArrayLike
45-
data: ArrayLike
45+
data: Array
4646
kernel: ScalarValuedKernel
4747
num_kernel_points: int
4848
num_train_points: int
49-
true_distances: ArrayLike
49+
true_distances: Array
5050

5151

5252
@pytest.mark.parametrize(
@@ -113,7 +113,7 @@ def problem(self) -> _Problem:
113113

114114
# We can repeat the above, but changing the point with which we are comparing
115115
# to get:
116-
true_distances = np.array(
116+
true_distances = jnp.array(
117117
[
118118
0.5855723855138795,
119119
0.5737865795122914,

tests/unit/test_benchmark.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import pytest
2525
import torch
2626
from jax import random
27-
from torch.utils.data import Dataset
27+
from torchvision.datasets import VisionDataset
2828

2929
from benchmark.mnist_benchmark import (
3030
MLP,
@@ -43,10 +43,11 @@
4343
)
4444

4545

46-
class MockDataset(Dataset):
46+
class MockDataset(VisionDataset):
4747
"""Mock dataset class for testing purposes."""
4848

49-
def __init__(self, data: torch.Tensor, labels: torch.Tensor) -> None:
49+
# We deliberately don't call super().__init__(), as this is a mock class
50+
def __init__(self, data: torch.Tensor, labels: torch.Tensor) -> None: # pylint: disable=super-init-not-called
5051
"""
5152
Initialise the MockDataset.
5253

tests/unit/test_coreset.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,12 @@ def test_deprecated_nodes(
109109
with pytest.warns(DeprecationWarning):
110110
nodes = coreset.nodes
111111

112-
if setup.coreset_type is PseudoCoreset:
112+
if isinstance(coreset, PseudoCoreset):
113113
assert nodes == coreset.points
114-
elif setup.coreset_type is Coresubset:
114+
elif isinstance(coreset, Coresubset):
115115
assert nodes == coreset.indices
116116
else:
117-
raise ValueError(setup.coreset_type)
117+
raise TypeError(type(coreset))
118118

119119
def test_deprecated_coreset(
120120
self,
@@ -189,7 +189,8 @@ def test_build_array_conversion(
189189
)
190190
with ctx:
191191
coreset_from_arrays = setup.coreset_type(
192-
coreset_input_final, pre_coreset_data_final
192+
coreset_input_final, # pyright: ignore[reportArgumentType]
193+
pre_coreset_data_final, # pyright: ignore[reportArgumentType]
193194
)
194195
coreset_from_data = setup.coreset_type.build(
195196
setup.coreset_input, setup.pre_coreset_data
@@ -267,7 +268,7 @@ def test_invalid_pre_coreset_data_type(
267268
with pytest.raises(TypeError, match="pre_coreset_data"):
268269
with warnings.catch_warnings():
269270
warnings.simplefilter(action="ignore", category=DeprecationWarning)
270-
coreset_type.build(indices_or_nodes, object())
271+
coreset_type.build(indices_or_nodes, object()) # pyright: ignore[reportArgumentType, reportCallIssue]
271272

272273
@pytest.mark.parametrize("coreset_type", [Coresubset, PseudoCoreset])
273274
def test_invalid_indices_or_points_type(
@@ -280,7 +281,7 @@ def test_invalid_indices_or_points_type(
280281
):
281282
with warnings.catch_warnings():
282283
warnings.simplefilter(action="ignore", category=DeprecationWarning)
283-
coreset_type(object(), pre_coreset_data)
284+
coreset_type(object(), pre_coreset_data) # pyright: ignore[reportArgumentType, reportCallIssue]
284285

285286

286287
class TestCoresubset:

tests/unit/test_kernels.py

+17-27
Original file line numberDiff line numberDiff line change
@@ -220,19 +220,19 @@ def test_gradients(
220220

221221
@abstractmethod
222222
def expected_grad_x(
223-
self, x: ArrayLike, y: ArrayLike, kernel: _ScalarValuedKernel
223+
self, x: Array, y: Array, kernel: _ScalarValuedKernel
224224
) -> Union[Array, np.ndarray]:
225225
"""Compute expected gradient of the kernel w.r.t ``x``."""
226226

227227
@abstractmethod
228228
def expected_grad_y(
229-
self, x: ArrayLike, y: ArrayLike, kernel: _ScalarValuedKernel
229+
self, x: Array, y: Array, kernel: _ScalarValuedKernel
230230
) -> Union[Array, np.ndarray]:
231231
"""Compute expected gradient of the kernel w.r.t ``y``."""
232232

233233
@abstractmethod
234234
def expected_divergence_x_grad_y(
235-
self, x: ArrayLike, y: ArrayLike, kernel: _ScalarValuedKernel
235+
self, x: Array, y: Array, kernel: _ScalarValuedKernel
236236
) -> Union[Array, np.ndarray]:
237237
"""Compute expected divergence of the kernel w.r.t ``x`` gradient ``y``."""
238238

@@ -402,9 +402,7 @@ def problem(self, request: pytest.FixtureRequest, kernel: PowerKernel) -> _Probl
402402
)
403403
return _Problem(x, y, expected_distances, modified_kernel)
404404

405-
def expected_grad_x(
406-
self, x: ArrayLike, y: ArrayLike, kernel: PowerKernel
407-
) -> np.ndarray:
405+
def expected_grad_x(self, x: Array, y: Array, kernel: PowerKernel) -> np.ndarray:
408406
num_points, dimension = np.atleast_2d(x).shape
409407

410408
expected_grad = (
@@ -419,9 +417,7 @@ def expected_grad_x(
419417

420418
return np.array(expected_grad)
421419

422-
def expected_grad_y(
423-
self, x: ArrayLike, y: ArrayLike, kernel: PowerKernel
424-
) -> np.ndarray:
420+
def expected_grad_y(self, x: Array, y: Array, kernel: PowerKernel) -> np.ndarray:
425421
num_points, dimension = np.atleast_2d(x).shape
426422

427423
expected_grad = (
@@ -437,7 +433,7 @@ def expected_grad_y(
437433
return np.array(expected_grad)
438434

439435
def expected_divergence_x_grad_y(
440-
self, x: ArrayLike, y: ArrayLike, kernel: PowerKernel
436+
self, x: Array, y: Array, kernel: PowerKernel
441437
) -> np.ndarray:
442438
divergence = self.power * (
443439
(
@@ -579,9 +575,7 @@ def problem(
579575
)
580576
return _Problem(x, y, expected_distances, modified_kernel)
581577

582-
def expected_grad_x(
583-
self, x: ArrayLike, y: ArrayLike, kernel: AdditiveKernel
584-
) -> np.ndarray:
578+
def expected_grad_x(self, x: Array, y: Array, kernel: AdditiveKernel) -> np.ndarray:
585579
num_points, dimension = np.atleast_2d(x).shape
586580

587581
# Variable rename allows for nicer automatic formatting
@@ -595,9 +589,7 @@ def expected_grad_x(
595589

596590
return np.array(expected_grad)
597591

598-
def expected_grad_y(
599-
self, x: ArrayLike, y: ArrayLike, kernel: AdditiveKernel
600-
) -> np.ndarray:
592+
def expected_grad_y(self, x: Array, y: Array, kernel: AdditiveKernel) -> np.ndarray:
601593
num_points, dimension = np.atleast_2d(x).shape
602594

603595
# Variable rename allows for nicer automatic formatting
@@ -612,7 +604,7 @@ def expected_grad_y(
612604
return np.array(expected_grad)
613605

614606
def expected_divergence_x_grad_y(
615-
self, x: ArrayLike, y: ArrayLike, kernel: AdditiveKernel
607+
self, x: Array, y: Array, kernel: AdditiveKernel
616608
) -> np.ndarray:
617609
num_points, _ = np.atleast_2d(x).shape
618610

@@ -691,9 +683,7 @@ def problem(
691683
)
692684
return _Problem(x, y, expected_distances, modified_kernel)
693685

694-
def expected_grad_x(
695-
self, x: ArrayLike, y: ArrayLike, kernel: ProductKernel
696-
) -> np.ndarray:
686+
def expected_grad_x(self, x: Array, y: Array, kernel: ProductKernel) -> np.ndarray:
697687
num_points, dimension = np.atleast_2d(x).shape
698688

699689
# Variable rename allows for nicer automatic formatting
@@ -709,9 +699,7 @@ def expected_grad_x(
709699

710700
return np.array(expected_grad)
711701

712-
def expected_grad_y(
713-
self, x: ArrayLike, y: ArrayLike, kernel: ProductKernel
714-
) -> np.ndarray:
702+
def expected_grad_y(self, x: Array, y: Array, kernel: ProductKernel) -> np.ndarray:
715703
num_points, dimension = np.atleast_2d(x).shape
716704

717705
# Variable rename allows for nicer automatic formatting
@@ -728,7 +716,7 @@ def expected_grad_y(
728716
return np.array(expected_grad)
729717

730718
def expected_divergence_x_grad_y(
731-
self, x: ArrayLike, y: ArrayLike, kernel: ProductKernel
719+
self, x: Array, y: Array, kernel: ProductKernel
732720
) -> np.ndarray:
733721
# Variable rename allows for nicer automatic formatting
734722
k1, k2 = kernel.first_kernel, kernel.second_kernel
@@ -748,7 +736,7 @@ def test_symmetric_product_kernel(self):
748736
We consider a product kernel with equal input kernels and check that
749737
the second kernel is never called.
750738
"""
751-
x = np.array([1])
739+
x = jnp.array([1])
752740

753741
# Form two simple mocked kernels and force any == operation to return True
754742
first_kernel = MagicMock(spec=ScalarValuedKernel)
@@ -1362,8 +1350,10 @@ def problem( # noqa: C901
13621350
expected_distances = np.zeros((num_points, num_points))
13631351
for x_idx, x_ in enumerate(x):
13641352
for y_idx, y_ in enumerate(y):
1365-
expected_distances[x_idx, y_idx] = scipy_norm(y_, length_scale).pdf(
1366-
x_
1353+
expected_distances[x_idx, y_idx] = (
1354+
scipy_norm(y_, length_scale)
1355+
# Ignore Pyright here - the .pdf() function definitely exists!
1356+
.pdf(x_) # pyright: ignore[reportAttributeAccessIssue]
13671357
)
13681358
x, y = x.reshape(-1, 1), y.reshape(-1, 1)
13691359
elif mode == "negative_length_scale":

tests/unit/test_metrics.py

+3
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,8 @@ def test_mmd_random_data(
206206
kernel_mm = kernel.compute(y.data, y.data)
207207
kernel_nm = kernel.compute(x.data, y.data)
208208
if mode == "weighted":
209+
assert isinstance(x.weights, Array)
210+
assert isinstance(y.weights, Array)
209211
weights_nn = x.weights[..., None] * x.weights[None, ...]
210212
weights_mm = y.weights[..., None] * y.weights[None, ...]
211213
weights_nm = x.weights[..., None] * y.weights[None, ...]
@@ -355,6 +357,7 @@ def test_ksd_random_data(
355357
# Compute each term in the KSD formula to obtain an expected KSD.
356358
kernel_mm = kernel.compute(y.data, y.data)
357359
if mode == "weighted":
360+
assert isinstance(y.weights, Array)
358361
weights_mm = y.weights[..., None] * y.weights[None, ...]
359362
expected_ksd = jnp.sqrt(jnp.average(kernel_mm, weights=weights_mm))
360363
output = metric.compute(x, y, laplace_correct=False, regularise=False)

0 commit comments

Comments
 (0)