Commit 4e4ec54

Add MaskedLinearOperator (#69)
* Add MaskedLinearOperator
* Add to index
* Fix import
* Linting
* Add _get_indices()
1 parent 7c5aabd commit 4e4ec54

7 files changed: +212 -14 lines

docs/source/composition_decoration_operators.rst (+6 lines)

@@ -40,6 +40,12 @@ Composition/Decoration LinearOperators
 .. autoclass:: linear_operator.operators.KroneckerProductDiagLinearOperator
     :members:
 
+:hidden:`MaskedLinearOperator`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: linear_operator.operators.MaskedLinearOperator
+    :members:
+
 :hidden:`MatmulLinearOperator`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
linear_operator/operators/__init__.py (+2 lines)

@@ -23,6 +23,7 @@
 )
 from .low_rank_root_added_diag_linear_operator import LowRankRootAddedDiagLinearOperator
 from .low_rank_root_linear_operator import LowRankRootLinearOperator
+from .masked_linear_operator import MaskedLinearOperator
 from .matmul_linear_operator import MatmulLinearOperator
 from .mul_linear_operator import MulLinearOperator
 from .permutation_linear_operator import PermutationLinearOperator, TransposePermutationLinearOperator
@@ -62,6 +63,7 @@
     "SumKroneckerLinearOperator",
     "LowRankRootAddedDiagLinearOperator",
     "LowRankRootLinearOperator",
+    "MaskedLinearOperator",
     "MatmulLinearOperator",
     "MulLinearOperator",
     "PermutationLinearOperator",

linear_operator/operators/_linear_operator.py (+10 -4 lines)

@@ -2645,22 +2645,28 @@ def type(self: LinearOperator, dtype: torch.dtype) -> LinearOperator:
         """
         attr_flag = _TYPES_DICT[dtype]
 
+        def _type_helper(arg):
+            if arg.dtype.is_floating_point:
+                return arg.to(dtype)
+            else:
+                return arg
+
         new_args = []
         new_kwargs = {}
         for arg in self._args:
             if hasattr(arg, attr_flag):
                 try:
-                    new_args.append(arg.clone().to(dtype))
+                    new_args.append(_type_helper(arg.clone()))
                 except AttributeError:
-                    new_args.append(deepcopy(arg).to(dtype))
+                    new_args.append(_type_helper(deepcopy(arg)))
             else:
                 new_args.append(arg)
         for name, val in self._kwargs.items():
             if hasattr(val, attr_flag):
                 try:
-                    new_kwargs[name] = val.clone().to(dtype)
+                    new_kwargs[name] = _type_helper(val.clone())
                 except AttributeError:
-                    new_kwargs[name] = deepcopy(val).to(dtype)
+                    new_kwargs[name] = _type_helper(deepcopy(val))
             else:
                 new_kwargs[name] = val
         return self.__class__(*new_args, **new_kwargs)
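
The new `_type_helper` casts only floating-point arguments to the target dtype and leaves integer tensors (for example, interpolation indices) untouched, which is why the `InterpolatedLinearOperator.double()` override becomes redundant and is removed in the next file. A small standalone sketch of the rule, not part of the commit and using a hypothetical function name that mirrors the helper above:

# Illustration only: floating-point tensors are cast, integer tensors pass through unchanged.
import torch

def _cast_float_only(arg, dtype=torch.float64):
    return arg.to(dtype) if arg.dtype.is_floating_point else arg

values = torch.randn(3)              # float32 -> cast to float64
indices = torch.tensor([0, 2, 1])    # int64   -> left as int64
assert _cast_float_only(values).dtype == torch.float64
assert _cast_float_only(indices).dtype == torch.int64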

linear_operator/operators/interpolated_linear_operator.py (-9 lines)

@@ -405,15 +405,6 @@ def _sum_batch(self, dim: int) -> LinearOperator:
             block_diag, left_interp_indices, left_interp_values, right_interp_indices, right_interp_values
         )
 
-    def double(
-        self: Float[LinearOperator, "*batch M N"], device_id: Optional[str] = None
-    ) -> Float[LinearOperator, "*batch M N"]:
-        # We need to ensure that the indices remain integers.
-        new_lt = super().double(device_id=device_id)
-        new_lt.left_interp_indices = new_lt.left_interp_indices.type(torch.int64)
-        new_lt.right_interp_indices = new_lt.right_interp_indices.type(torch.int64)
-        return new_lt
-
     def matmul(
         self: Float[LinearOperator, "*batch M N"],
         other: Union[Float[Tensor, "*batch2 N P"], Float[Tensor, "*batch2 N"], Float[LinearOperator, "*batch2 N P"]],

linear_operator/operators/masked_linear_operator.py (new file, +113 lines)

from typing import List, Optional, Tuple, Union

import torch
from jaxtyping import Bool, Float
from torch import Tensor

from ._linear_operator import _is_noop_index, IndexType, LinearOperator


class MaskedLinearOperator(LinearOperator):
    r"""
    A :obj:`~linear_operator.operators.LinearOperator` that applies a mask to the rows and columns of a base
    :obj:`~linear_operator.operators.LinearOperator`.
    """

    def __init__(
        self,
        base: Float[LinearOperator, "*batch M0 N0"],
        row_mask: Bool[Tensor, "M0"],
        col_mask: Bool[Tensor, "N0"],
    ):
        r"""
        Create a new :obj:`~linear_operator.operators.MaskedLinearOperator` that applies a mask to the rows and columns
        of a base :obj:`~linear_operator.operators.LinearOperator`.

        :param base: The base :obj:`~linear_operator.operators.LinearOperator`.
        :param row_mask: A :obj:`torch.BoolTensor` containing the mask to apply to the rows.
        :param col_mask: A :obj:`torch.BoolTensor` containing the mask to apply to the columns.
        """
        super().__init__(base, row_mask, col_mask)
        self.base = base
        self.row_mask = row_mask
        self.col_mask = col_mask
        self.row_eq_col_mask = torch.equal(row_mask, col_mask)

    @staticmethod
    def _expand(tensor: Float[Tensor, "*batch N C"], mask: Bool[Tensor, "N0"]) -> Float[Tensor, "*batch N0 C"]:
        res = torch.zeros(
            *tensor.shape[:-2],
            mask.size(-1),
            tensor.size(-1),
            device=tensor.device,
            dtype=tensor.dtype,
        )
        res[..., mask, :] = tensor
        return res

    def _matmul(
        self: Float[LinearOperator, "*batch M N"],
        rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]],
    ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]:
        rhs_expanded = self._expand(rhs, self.col_mask)
        res_expanded = self.base.matmul(rhs_expanded)
        res = res_expanded[..., self.row_mask, :]

        return res

    def _t_matmul(
        self: Float[LinearOperator, "*batch M N"],
        rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]],
    ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]:
        rhs_expanded = self._expand(rhs, self.row_mask)
        res_expanded = self.base.t_matmul(rhs_expanded)
        res = res_expanded[..., self.col_mask, :]
        return res

    def _size(self) -> torch.Size:
        return torch.Size(
            (*self.base.size()[:-2], torch.count_nonzero(self.row_mask), torch.count_nonzero(self.col_mask))
        )

    def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]:
        return self.__class__(self.base.mT, self.col_mask, self.row_mask)

    def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
        if not self.row_eq_col_mask:
            raise NotImplementedError()
        diag = self.base.diagonal()
        return diag[..., self.row_mask]

    def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]:
        full_dense = self.base.to_dense()
        return full_dense[..., self.row_mask, :][..., :, self.col_mask]

    def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]:
        left_vecs = self._expand(left_vecs, self.row_mask)
        right_vecs = self._expand(right_vecs, self.col_mask)
        return self.base._bilinear_derivative(left_vecs, right_vecs) + (None, None)

    def _expand_batch(
        self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]]
    ) -> Float[LinearOperator, "... M N"]:
        return self.__class__(self.base._expand_batch(batch_shape), self.row_mask, self.col_mask)

    def _unsqueeze_batch(self, dim: int) -> LinearOperator:
        return self.__class__(self.base._unsqueeze_batch(dim), self.row_mask, self.col_mask)

    def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> LinearOperator:
        if _is_noop_index(row_index) and _is_noop_index(col_index):
            if len(batch_indices):
                return self.__class__(self.base[batch_indices], self.row_mask, self.col_mask)
            else:
                return self
        else:
            return super()._getitem(row_index, col_index, *batch_indices)

    def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor:
        row_mapping = torch.arange(self.base.size(-2), device=self.base.device)[self.row_mask]
        col_mapping = torch.arange(self.base.size(-1), device=self.base.device)[self.col_mask]
        return self.base._get_indices(row_mapping[row_index], col_mapping[col_index], *batch_indices)

    def _permute_batch(self, *dims: int) -> LinearOperator:
        return self.__class__(self.base._permute_batch(*dims), self.row_mask, self.col_mask)
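
Not part of the commit: a minimal usage sketch of the new operator, based on the constructor above and the tests below. `to_linear_operator` and the `MaskedLinearOperator(base, row_mask, col_mask)` signature come from this diff; the concrete shapes and mask values are illustrative.

# Usage sketch (illustrative only): mask out rows/columns of a base operator.
import torch

from linear_operator import to_linear_operator
from linear_operator.operators import MaskedLinearOperator

base = to_linear_operator(torch.randn(5, 6))                    # 5 x 6 base operator
row_mask = torch.tensor([1, 1, 0, 1, 0], dtype=torch.bool)      # keep 3 of 5 rows
col_mask = torch.tensor([1, 0, 1, 0, 1, 1], dtype=torch.bool)   # keep 4 of 6 columns

masked = MaskedLinearOperator(base, row_mask, col_mask)
print(masked.shape)  # torch.Size([3, 4])

# Matmul against the masked operator agrees with slicing the dense base.
rhs = torch.randn(4, 2)
expected = base.to_dense()[row_mask][:, col_mask] @ rhs
assert torch.allclose(masked @ rhs, expected)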

linear_operator/test/linear_operator_test_case.py (+4 -1 lines)

@@ -684,7 +684,10 @@ def test_bilinear_derivative(self):
         )
 
         for dc, da in zip(deriv_custom, deriv_auto):
-            self.assertAllClose(dc, da)
+            if dc is None:
+                assert da is None
+            else:
+                self.assertAllClose(dc, da)
 
     def test_cat_rows(self):
         linear_op = self.create_linear_op()
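
The relaxed assertion is needed because `MaskedLinearOperator._bilinear_derivative` returns `None` for its two boolean mask arguments (they carry no gradient), so the custom and autograd derivative tuples can now contain `None` entries. A toy illustration of the comparison logic in isolation, with made-up values:

# Illustration only: derivative tuples may contain None for non-differentiable args (the masks).
import torch

deriv_custom = (torch.ones(2, 2), None, None)
deriv_auto = (torch.ones(2, 2), None, None)

for dc, da in zip(deriv_custom, deriv_auto):
    if dc is None:
        assert da is None
    else:
        assert torch.allclose(dc, da)
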
New test file for MaskedLinearOperator (+77 lines)

#!/usr/bin/env python3

import unittest

import torch

from linear_operator import to_linear_operator
from linear_operator.operators.masked_linear_operator import MaskedLinearOperator
from linear_operator.test.linear_operator_test_case import LinearOperatorTestCase, RectangularLinearOperatorTestCase


class TestMaskedLinearOperator(LinearOperatorTestCase, unittest.TestCase):
    seed = 2023

    def create_linear_op(self):
        base = torch.randn(5, 5)
        base = base.mT @ base
        base.requires_grad_(True)
        base = to_linear_operator(base)
        mask = torch.tensor([1, 1, 0, 1, 0], dtype=torch.bool)
        covar = MaskedLinearOperator(base, mask, mask)
        return covar

    def evaluate_linear_op(self, linear_op):
        base = linear_op.base.to_dense()
        return base[..., linear_op.row_mask, :][..., linear_op.col_mask]


class TestMaskedLinearOperatorBatch(LinearOperatorTestCase, unittest.TestCase):
    seed = 2023

    def create_linear_op(self):
        base = torch.randn(2, 5, 5)
        base = base.mT @ base
        base.requires_grad_(True)
        base = to_linear_operator(base)
        mask = torch.tensor([1, 1, 0, 1, 0], dtype=torch.bool)
        covar = MaskedLinearOperator(base, mask, mask)
        return covar

    def evaluate_linear_op(self, linear_op):
        base = linear_op.base.to_dense()
        return base[..., linear_op.row_mask, :][..., linear_op.col_mask]


class TestMaskedLinearOperatorRectangular(RectangularLinearOperatorTestCase, unittest.TestCase):
    seed = 2023

    def create_linear_op(self):
        base = to_linear_operator(torch.randn(5, 6, requires_grad=True))
        row_mask = torch.tensor([1, 1, 0, 1, 0], dtype=torch.bool)
        col_mask = torch.tensor([1, 0, 1, 0, 1, 1], dtype=torch.bool)
        covar = MaskedLinearOperator(base, row_mask, col_mask)
        return covar

    def evaluate_linear_op(self, linear_op):
        base = linear_op.base.to_dense()
        return base[..., linear_op.row_mask, :][..., linear_op.col_mask]


class TestMaskedLinearOperatorRectangularMultiBatch(RectangularLinearOperatorTestCase, unittest.TestCase):
    seed = 2023

    def create_linear_op(self):
        base = to_linear_operator(torch.randn(2, 3, 5, 6, requires_grad=True))
        row_mask = torch.tensor([1, 1, 0, 1, 0], dtype=torch.bool)
        col_mask = torch.tensor([1, 0, 1, 0, 1, 1], dtype=torch.bool)
        covar = MaskedLinearOperator(base, row_mask, col_mask)
        return covar

    def evaluate_linear_op(self, linear_op):
        base = linear_op.base.to_dense()
        return base[..., linear_op.row_mask, :][..., linear_op.col_mask]


if __name__ == "__main__":
    unittest.main()
