Add MaskedLinearOperator #69

Merged
merged 5 commits on Jul 6, 2023

6 changes: 6 additions & 0 deletions docs/source/composition_decoration_operators.rst
@@ -40,6 +40,12 @@ Composition/Decoration LinearOperators
.. autoclass:: linear_operator.operators.KroneckerProductDiagLinearOperator
   :members:

:hidden:`MaskedLinearOperator`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: linear_operator.operators.MaskedLinearOperator
   :members:

:hidden:`MatmulLinearOperator`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

2 changes: 2 additions & 0 deletions linear_operator/operators/__init__.py
@@ -23,6 +23,7 @@
)
from .low_rank_root_added_diag_linear_operator import LowRankRootAddedDiagLinearOperator
from .low_rank_root_linear_operator import LowRankRootLinearOperator
from .masked_linear_operator import MaskedLinearOperator
from .matmul_linear_operator import MatmulLinearOperator
from .mul_linear_operator import MulLinearOperator
from .permutation_linear_operator import PermutationLinearOperator, TransposePermutationLinearOperator
@@ -62,6 +63,7 @@
"SumKroneckerLinearOperator",
"LowRankRootAddedDiagLinearOperator",
"LowRankRootLinearOperator",
"MaskedLinearOperator",
"MatmulLinearOperator",
"MulLinearOperator",
"PermutationLinearOperator",
14 changes: 10 additions & 4 deletions linear_operator/operators/_linear_operator.py
@@ -2645,22 +2645,28 @@ def type(self: LinearOperator, dtype: torch.dtype) -> LinearOperator:
        """
        attr_flag = _TYPES_DICT[dtype]

+        def _type_helper(arg):
+            if arg.dtype.is_floating_point:
+                return arg.to(dtype)
+            else:
+                return arg
+
        new_args = []
        new_kwargs = {}
        for arg in self._args:
            if hasattr(arg, attr_flag):
                try:
-                    new_args.append(arg.clone().to(dtype))
+                    new_args.append(_type_helper(arg.clone()))
                except AttributeError:
-                    new_args.append(deepcopy(arg).to(dtype))
+                    new_args.append(_type_helper(deepcopy(arg)))
            else:
                new_args.append(arg)
        for name, val in self._kwargs.items():
            if hasattr(val, attr_flag):
                try:
-                    new_kwargs[name] = val.clone().to(dtype)
+                    new_kwargs[name] = _type_helper(val.clone())
                except AttributeError:
-                    new_kwargs[name] = deepcopy(val).to(dtype)
+                    new_kwargs[name] = _type_helper(deepcopy(val))
            else:
                new_kwargs[name] = val
        return self.__class__(*new_args, **new_kwargs)
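The `type()` change above routes every cast through a local `_type_helper`, so only floating-point tensors are converted to the target dtype while integer tensors (for example interpolation indices) keep their dtype; that is what makes the `InterpolatedLinearOperator.double` override removed below unnecessary. A minimal sketch of the behavior, using a hypothetical standalone helper rather than the method-local one from the diff:

```python
import torch


def cast_like_type_helper(arg: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Mirrors the diff's _type_helper: cast floating-point tensors, pass others through.
    return arg.to(dtype) if arg.dtype.is_floating_point else arg


values = torch.randn(3)            # float32 representation tensor
indices = torch.tensor([0, 2, 1])  # int64 index tensor

assert cast_like_type_helper(values, torch.float64).dtype == torch.float64
assert cast_like_type_helper(indices, torch.float64).dtype == torch.int64  # left unchanged
```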
9 changes: 0 additions & 9 deletions linear_operator/operators/interpolated_linear_operator.py
@@ -405,15 +405,6 @@ def _sum_batch(self, dim: int) -> LinearOperator:
            block_diag, left_interp_indices, left_interp_values, right_interp_indices, right_interp_values
        )

-    def double(
-        self: Float[LinearOperator, "*batch M N"], device_id: Optional[str] = None
-    ) -> Float[LinearOperator, "*batch M N"]:
-        # We need to ensure that the indices remain integers.
-        new_lt = super().double(device_id=device_id)
-        new_lt.left_interp_indices = new_lt.left_interp_indices.type(torch.int64)
-        new_lt.right_interp_indices = new_lt.right_interp_indices.type(torch.int64)
-        return new_lt
-
    def matmul(
        self: Float[LinearOperator, "*batch M N"],
        other: Union[Float[Tensor, "*batch2 N P"], Float[Tensor, "*batch2 N"], Float[LinearOperator, "*batch2 N P"]],
113 changes: 113 additions & 0 deletions linear_operator/operators/masked_linear_operator.py
@@ -0,0 +1,113 @@
from typing import List, Optional, Tuple, Union

import torch
from jaxtyping import Bool, Float
from torch import Tensor

from ._linear_operator import _is_noop_index, IndexType, LinearOperator


class MaskedLinearOperator(LinearOperator):
    r"""
    A :obj:`~linear_operator.operators.LinearOperator` that applies a mask to the rows and columns of a base
    :obj:`~linear_operator.operators.LinearOperator`.
    """

    def __init__(
        self,
        base: Float[LinearOperator, "*batch M0 N0"],
        row_mask: Bool[Tensor, "M0"],
        col_mask: Bool[Tensor, "N0"],
    ):
        r"""
        Create a new :obj:`~linear_operator.operators.MaskedLinearOperator` that applies a mask to the rows and columns
        of a base :obj:`~linear_operator.operators.LinearOperator`.

        :param base: The base :obj:`~linear_operator.operators.LinearOperator`.
        :param row_mask: A :obj:`torch.BoolTensor` containing the mask to apply to the rows.
        :param col_mask: A :obj:`torch.BoolTensor` containing the mask to apply to the columns.
        """
        super().__init__(base, row_mask, col_mask)
        self.base = base
        self.row_mask = row_mask
        self.col_mask = col_mask
        self.row_eq_col_mask = torch.equal(row_mask, col_mask)

    @staticmethod
    def _expand(tensor: Float[Tensor, "*batch N C"], mask: Bool[Tensor, "N0"]) -> Float[Tensor, "*batch N0 C"]:
        res = torch.zeros(
            *tensor.shape[:-2],
            mask.size(-1),
            tensor.size(-1),
            device=tensor.device,
            dtype=tensor.dtype,
        )
        res[..., mask, :] = tensor
        return res

    def _matmul(
        self: Float[LinearOperator, "*batch M N"],
        rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]],
    ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]:
        rhs_expanded = self._expand(rhs, self.col_mask)
        res_expanded = self.base.matmul(rhs_expanded)
        res = res_expanded[..., self.row_mask, :]

        return res

    def _t_matmul(
        self: Float[LinearOperator, "*batch M N"],
        rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]],
    ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]:
        rhs_expanded = self._expand(rhs, self.row_mask)
        res_expanded = self.base.t_matmul(rhs_expanded)
        res = res_expanded[..., self.col_mask, :]
        return res

    def _size(self) -> torch.Size:
        return torch.Size(
            (*self.base.size()[:-2], torch.count_nonzero(self.row_mask), torch.count_nonzero(self.col_mask))
        )

    def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]:
        return self.__class__(self.base.mT, self.col_mask, self.row_mask)

    def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
        if not self.row_eq_col_mask:
            raise NotImplementedError()
        diag = self.base.diagonal()
        return diag[..., self.row_mask]

    def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]:
        full_dense = self.base.to_dense()
        return full_dense[..., self.row_mask, :][..., :, self.col_mask]

    def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]:
        left_vecs = self._expand(left_vecs, self.row_mask)
        right_vecs = self._expand(right_vecs, self.col_mask)
        return self.base._bilinear_derivative(left_vecs, right_vecs) + (None, None)

    def _expand_batch(
        self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]]
    ) -> Float[LinearOperator, "... M N"]:
        return self.__class__(self.base._expand_batch(batch_shape), self.row_mask, self.col_mask)

    def _unsqueeze_batch(self, dim: int) -> LinearOperator:
        return self.__class__(self.base._unsqueeze_batch(dim), self.row_mask, self.col_mask)

    def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> LinearOperator:
        if _is_noop_index(row_index) and _is_noop_index(col_index):
            if len(batch_indices):
                return self.__class__(self.base[batch_indices], self.row_mask, self.col_mask)
            else:
                return self
        else:
            return super()._getitem(row_index, col_index, *batch_indices)

    def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor:
        row_mapping = torch.arange(self.base.size(-2), device=self.base.device)[self.row_mask]
        col_mapping = torch.arange(self.base.size(-1), device=self.base.device)[self.col_mask]
        return self.base._get_indices(row_mapping[row_index], col_mapping[col_index], *batch_indices)

    def _permute_batch(self, *dims: int) -> LinearOperator:
        return self.__class__(self.base._permute_batch(*dims), self.row_mask, self.col_mask)
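A short usage sketch (not part of the PR) showing the new operator in action; it relies only on the constructor, `_size`, `_matmul`, and `to_dense` defined above:

```python
import torch

from linear_operator import to_linear_operator
from linear_operator.operators.masked_linear_operator import MaskedLinearOperator

base = to_linear_operator(torch.randn(5, 5))
row_mask = torch.tensor([1, 1, 0, 1, 0], dtype=torch.bool)
col_mask = torch.tensor([1, 0, 1, 0, 1], dtype=torch.bool)

masked = MaskedLinearOperator(base, row_mask, col_mask)
print(masked.shape)  # torch.Size([3, 3]): only the unmasked rows/columns remain

# Products with the masked operator agree with masking the dense matrix directly.
rhs = torch.randn(3, 2)
expected = base.to_dense()[row_mask][:, col_mask] @ rhs
assert torch.allclose(masked @ rhs, expected, atol=1e-5)
```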
5 changes: 4 additions & 1 deletion linear_operator/test/linear_operator_test_case.py
@@ -684,7 +684,10 @@ def test_bilinear_derivative(self):
        )

        for dc, da in zip(deriv_custom, deriv_auto):
-            self.assertAllClose(dc, da)
+            if dc is None:
+                assert da is None
+            else:
+                self.assertAllClose(dc, da)

    def test_cat_rows(self):
        linear_op = self.create_linear_op()
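The relaxed check above is needed because `MaskedLinearOperator._bilinear_derivative` returns `None` for its two mask arguments, and autograd likewise produces no gradient for them, since boolean tensors cannot take part in differentiation. A quick illustration (not part of the PR):

```python
import torch

mask = torch.tensor([True, False, True])
try:
    mask.requires_grad_(True)  # boolean tensors cannot require gradients
except RuntimeError as err:
    print(err)  # e.g. "only Tensors of floating point and complex dtype can require gradients"
```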
77 changes: 77 additions & 0 deletions test/operators/test_masked_linear_operator.py
@@ -0,0 +1,77 @@
#!/usr/bin/env python3

import unittest

import torch

from linear_operator import to_linear_operator
from linear_operator.operators.masked_linear_operator import MaskedLinearOperator
from linear_operator.test.linear_operator_test_case import LinearOperatorTestCase, RectangularLinearOperatorTestCase


class TestMaskedLinearOperator(LinearOperatorTestCase, unittest.TestCase):
    seed = 2023

    def create_linear_op(self):
        base = torch.randn(5, 5)
        base = base.mT @ base
        base.requires_grad_(True)
        base = to_linear_operator(base)
        mask = torch.tensor([1, 1, 0, 1, 0], dtype=torch.bool)
        covar = MaskedLinearOperator(base, mask, mask)
        return covar

    def evaluate_linear_op(self, linear_op):
        base = linear_op.base.to_dense()
        return base[..., linear_op.row_mask, :][..., linear_op.col_mask]


class TestMaskedLinearOperatorBatch(LinearOperatorTestCase, unittest.TestCase):
    seed = 2023

    def create_linear_op(self):
        base = torch.randn(2, 5, 5)
        base = base.mT @ base
        base.requires_grad_(True)
        base = to_linear_operator(base)
        mask = torch.tensor([1, 1, 0, 1, 0], dtype=torch.bool)
        covar = MaskedLinearOperator(base, mask, mask)
        return covar

    def evaluate_linear_op(self, linear_op):
        base = linear_op.base.to_dense()
        return base[..., linear_op.row_mask, :][..., linear_op.col_mask]


class TestMaskedLinearOperatorRectangular(RectangularLinearOperatorTestCase, unittest.TestCase):
    seed = 2023

    def create_linear_op(self):
        base = to_linear_operator(torch.randn(5, 6, requires_grad=True))
        row_mask = torch.tensor([1, 1, 0, 1, 0], dtype=torch.bool)
        col_mask = torch.tensor([1, 0, 1, 0, 1, 1], dtype=torch.bool)
        covar = MaskedLinearOperator(base, row_mask, col_mask)
        return covar

    def evaluate_linear_op(self, linear_op):
        base = linear_op.base.to_dense()
        return base[..., linear_op.row_mask, :][..., linear_op.col_mask]


class TestMaskedLinearOperatorRectangularMultiBatch(RectangularLinearOperatorTestCase, unittest.TestCase):
    seed = 2023

    def create_linear_op(self):
        base = to_linear_operator(torch.randn(2, 3, 5, 6, requires_grad=True))
        row_mask = torch.tensor([1, 1, 0, 1, 0], dtype=torch.bool)
        col_mask = torch.tensor([1, 0, 1, 0, 1, 1], dtype=torch.bool)
        covar = MaskedLinearOperator(base, row_mask, col_mask)
        return covar

    def evaluate_linear_op(self, linear_op):
        base = linear_op.base.to_dense()
        return base[..., linear_op.row_mask, :][..., linear_op.col_mask]


if __name__ == "__main__":
    unittest.main()