Skip to content

Commit f020146

Browse files
authored
Merge pull request #64 from Balandat/catch_lincg_test_warnings
Catch some warnings in linear_cg tests
2 parents 46b08fc + 66dc5ec commit f020146

File tree

1 file changed

+24
-18
lines changed

1 file changed

+24
-18
lines changed

test/utils/test_linear_cg.py

+24-18
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
import os
44
import random
55
import unittest
6+
import warnings
67

78
import torch
89

910
from linear_operator.utils.linear_cg import linear_cg
11+
from linear_operator.utils.warnings import NumericalWarning
1012

1113

1214
class TestLinearCG(unittest.TestCase):
@@ -69,15 +71,17 @@ def test_cg_with_tridiag(self):
6971
matrix.add_(torch.eye(matrix.size(-1), dtype=torch.float64).mul_(1e-1))
7072

7173
rhs = torch.randn(size, 50, dtype=torch.float64)
72-
solves, t_mats = linear_cg(
73-
matrix.matmul,
74-
rhs=rhs,
75-
n_tridiag=5,
76-
max_tridiag_iter=10,
77-
max_iter=size,
78-
tolerance=0,
79-
eps=1e-15,
80-
)
74+
with warnings.catch_warnings(record=True) as ws:
75+
solves, t_mats = linear_cg(
76+
matrix.matmul,
77+
rhs=rhs,
78+
n_tridiag=5,
79+
max_tridiag_iter=10,
80+
max_iter=size,
81+
tolerance=0,
82+
eps=1e-15,
83+
)
84+
self.assertTrue(any(issubclass(w.category, NumericalWarning) for w in ws))
8185

8286
# Check cg
8387
matrix_chol = torch.linalg.cholesky(matrix)
@@ -115,15 +119,17 @@ def test_batch_cg_with_tridiag(self):
115119
matrix.add_(torch.eye(matrix.size(-1), dtype=torch.float64).mul_(1e-1))
116120

117121
rhs = torch.randn(batch, size, 10, dtype=torch.float64)
118-
solves, t_mats = linear_cg(
119-
matrix.matmul,
120-
rhs=rhs,
121-
n_tridiag=8,
122-
max_iter=size,
123-
max_tridiag_iter=10,
124-
tolerance=0,
125-
eps=1e-30,
126-
)
122+
with warnings.catch_warnings(record=True) as ws:
123+
solves, t_mats = linear_cg(
124+
matrix.matmul,
125+
rhs=rhs,
126+
n_tridiag=8,
127+
max_iter=size,
128+
max_tridiag_iter=10,
129+
tolerance=0,
130+
eps=1e-30,
131+
)
132+
self.assertTrue(any(issubclass(w.category, NumericalWarning) for w in ws))
127133

128134
# Check cg
129135
matrix_chol = torch.linalg.cholesky(matrix)

0 commit comments

Comments
 (0)