|
3 | 3 | import os
|
4 | 4 | import random
|
5 | 5 | import unittest
|
| 6 | +import warnings |
6 | 7 |
|
7 | 8 | import torch
|
8 | 9 |
|
9 | 10 | from linear_operator.utils.linear_cg import linear_cg
|
| 11 | +from linear_operator.utils.warnings import NumericalWarning |
10 | 12 |
|
11 | 13 |
|
12 | 14 | class TestLinearCG(unittest.TestCase):
|
@@ -69,15 +71,17 @@ def test_cg_with_tridiag(self):
|
69 | 71 | matrix.add_(torch.eye(matrix.size(-1), dtype=torch.float64).mul_(1e-1))
|
70 | 72 |
|
71 | 73 | rhs = torch.randn(size, 50, dtype=torch.float64)
|
72 |
| - solves, t_mats = linear_cg( |
73 |
| - matrix.matmul, |
74 |
| - rhs=rhs, |
75 |
| - n_tridiag=5, |
76 |
| - max_tridiag_iter=10, |
77 |
| - max_iter=size, |
78 |
| - tolerance=0, |
79 |
| - eps=1e-15, |
80 |
| - ) |
| 74 | + with warnings.catch_warnings(record=True) as ws: |
| 75 | + solves, t_mats = linear_cg( |
| 76 | + matrix.matmul, |
| 77 | + rhs=rhs, |
| 78 | + n_tridiag=5, |
| 79 | + max_tridiag_iter=10, |
| 80 | + max_iter=size, |
| 81 | + tolerance=0, |
| 82 | + eps=1e-15, |
| 83 | + ) |
| 84 | + self.assertTrue(any(issubclass(w.category, NumericalWarning) for w in ws)) |
81 | 85 |
|
82 | 86 | # Check cg
|
83 | 87 | matrix_chol = torch.linalg.cholesky(matrix)
|
@@ -115,15 +119,17 @@ def test_batch_cg_with_tridiag(self):
|
115 | 119 | matrix.add_(torch.eye(matrix.size(-1), dtype=torch.float64).mul_(1e-1))
|
116 | 120 |
|
117 | 121 | rhs = torch.randn(batch, size, 10, dtype=torch.float64)
|
118 |
| - solves, t_mats = linear_cg( |
119 |
| - matrix.matmul, |
120 |
| - rhs=rhs, |
121 |
| - n_tridiag=8, |
122 |
| - max_iter=size, |
123 |
| - max_tridiag_iter=10, |
124 |
| - tolerance=0, |
125 |
| - eps=1e-30, |
126 |
| - ) |
| 122 | + with warnings.catch_warnings(record=True) as ws: |
| 123 | + solves, t_mats = linear_cg( |
| 124 | + matrix.matmul, |
| 125 | + rhs=rhs, |
| 126 | + n_tridiag=8, |
| 127 | + max_iter=size, |
| 128 | + max_tridiag_iter=10, |
| 129 | + tolerance=0, |
| 130 | + eps=1e-30, |
| 131 | + ) |
| 132 | + self.assertTrue(any(issubclass(w.category, NumericalWarning) for w in ws)) |
127 | 133 |
|
128 | 134 | # Check cg
|
129 | 135 | matrix_chol = torch.linalg.cholesky(matrix)
|
|
0 commit comments