From ca1ac5ba6d192cd141a12eb6b16bb22efefd5722 Mon Sep 17 00:00:00 2001
From: Turakar
Date: Wed, 7 Jun 2023 14:26:33 +0200
Subject: [PATCH 1/5] Add MaskedLinearOperator

---
 .../composition_decoration_operators.rst      |   6 +
 linear_operator/operators/_linear_operator.py |  14 ++-
 .../operators/interpolated_linear_operator.py |   9 --
 .../operators/masked_linear_operator.py       | 111 ++++++++++++++++++
 .../test/linear_operator_test_case.py         |   5 +-
 test/operators/test_masked_linear_operator.py |  77 ++++++++++++
 6 files changed, 208 insertions(+), 14 deletions(-)
 create mode 100644 linear_operator/operators/masked_linear_operator.py
 create mode 100644 test/operators/test_masked_linear_operator.py

diff --git a/docs/source/composition_decoration_operators.rst b/docs/source/composition_decoration_operators.rst
index 1373d18f..628c79f3 100644
--- a/docs/source/composition_decoration_operators.rst
+++ b/docs/source/composition_decoration_operators.rst
@@ -40,6 +40,12 @@ Composition/Decoration LinearOperators
 .. autoclass:: linear_operator.operators.KroneckerProductDiagLinearOperator
    :members:
 
+:hidden:`MaskedLinearOperator`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: linear_operator.operators.MaskedLinearOperator
+   :members:
+
 :hidden:`MatmulLinearOperator`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/linear_operator/operators/_linear_operator.py b/linear_operator/operators/_linear_operator.py
index f4c01fd2..f655c25b 100644
--- a/linear_operator/operators/_linear_operator.py
+++ b/linear_operator/operators/_linear_operator.py
@@ -2645,22 +2645,28 @@ def type(self: LinearOperator, dtype: torch.dtype) -> LinearOperator:
         """
         attr_flag = _TYPES_DICT[dtype]
 
+        def _type_helper(arg):
+            if arg.dtype.is_floating_point:
+                return arg.to(dtype)
+            else:
+                return arg
+
         new_args = []
         new_kwargs = {}
         for arg in self._args:
             if hasattr(arg, attr_flag):
                 try:
-                    new_args.append(arg.clone().to(dtype))
+                    new_args.append(_type_helper(arg.clone()))
                 except AttributeError:
-                    new_args.append(deepcopy(arg).to(dtype))
+                    new_args.append(_type_helper(deepcopy(arg)))
             else:
                 new_args.append(arg)
         for name, val in self._kwargs.items():
             if hasattr(val, attr_flag):
                 try:
-                    new_kwargs[name] = val.clone().to(dtype)
+                    new_kwargs[name] = _type_helper(val.clone())
                 except AttributeError:
-                    new_kwargs[name] = deepcopy(val).to(dtype)
+                    new_kwargs[name] = _type_helper(deepcopy(val))
             else:
                 new_kwargs[name] = val
         return self.__class__(*new_args, **new_kwargs)
diff --git a/linear_operator/operators/interpolated_linear_operator.py b/linear_operator/operators/interpolated_linear_operator.py
index 868283fd..c92eb84d 100644
--- a/linear_operator/operators/interpolated_linear_operator.py
+++ b/linear_operator/operators/interpolated_linear_operator.py
@@ -405,15 +405,6 @@ def _sum_batch(self, dim: int) -> LinearOperator:
             block_diag, left_interp_indices, left_interp_values, right_interp_indices, right_interp_values
         )
 
-    def double(
-        self: Float[LinearOperator, "*batch M N"], device_id: Optional[str] = None
-    ) -> Float[LinearOperator, "*batch M N"]:
-        # We need to ensure that the indices remain integers.
-        new_lt = super().double(device_id=device_id)
-        new_lt.left_interp_indices = new_lt.left_interp_indices.type(torch.int64)
-        new_lt.right_interp_indices = new_lt.right_interp_indices.type(torch.int64)
-        return new_lt
-
     def matmul(
         self: Float[LinearOperator, "*batch M N"],
         other: Union[Float[Tensor, "*batch2 N P"], Float[Tensor, "*batch2 N"], Float[LinearOperator, "*batch2 N P"]],
diff --git a/linear_operator/operators/masked_linear_operator.py b/linear_operator/operators/masked_linear_operator.py
new file mode 100644
index 00000000..634b2e02
--- /dev/null
+++ b/linear_operator/operators/masked_linear_operator.py
@@ -0,0 +1,111 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+from jaxtyping import Bool, Float
+from torch import Tensor
+
+from linear_operator import LinearOperator
+
+from linear_operator.utils.getitem import _is_noop_index, IndexType
+
+
+class MaskedLinearOperator(LinearOperator):
+    r"""
+    A :obj:`~linear_operator.operators.LinearOperator` that applies a mask to the rows and columns of a base
+    :obj:`~linear_operator.operators.LinearOperator`.
+    """
+
+    def __init__(
+        self,
+        base: Float[LinearOperator, "*batch M0 N0"],
+        row_mask: Bool[Tensor, "M0"],
+        col_mask: Bool[Tensor, "N0"],
+    ):
+        r"""
+        Create a new :obj:`~linear_operator.operators.MaskedLinearOperator` that applies a mask to the rows
+        and columns of a base :obj:`~linear_operator.operators.LinearOperator`.
+
+        :param base: The base :obj:`~linear_operator.operators.LinearOperator`.
+        :param row_mask: A :obj:`torch.BoolTensor` containing the mask to apply to the rows.
+        :param col_mask: A :obj:`torch.BoolTensor` containing the mask to apply to the columns.
+        """
+        _args_memo = None
+        super().__init__(base, row_mask, col_mask)
+        self.base = base
+        self.row_mask = row_mask
+        self.col_mask = col_mask
+        self.row_eq_col_mask = torch.equal(row_mask, col_mask)
+
+    @staticmethod
+    def _expand(tensor: Float[Tensor, "*batch N C"], mask: Bool[Tensor, "N0"]) -> Float[Tensor, "*batch N0 C"]:
+        res = torch.zeros(
+            *tensor.shape[:-2],
+            mask.size(-1),
+            tensor.size(-1),
+            device=tensor.device,
+            dtype=tensor.dtype,
+        )
+        res[..., mask, :] = tensor
+        return res
+
+    def _matmul(
+        self: Float[LinearOperator, "*batch M N"],
+        rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]],
+    ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]:
+        rhs_expanded = self._expand(rhs, self.col_mask)
+        res_expanded = self.base.matmul(rhs_expanded)
+        res = res_expanded[..., self.row_mask, :]
+
+        return res
+
+    def _t_matmul(
+        self: Float[LinearOperator, "*batch M N"],
+        rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]],
+    ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]:
+        rhs_expanded = self._expand(rhs, self.row_mask)
+        res_expanded = self.base.t_matmul(rhs_expanded)
+        res = res_expanded[..., self.col_mask, :]
+        return res
+
+    def _size(self) -> torch.Size:
+        return torch.Size(
+            (*self.base.size()[:-2], torch.count_nonzero(self.row_mask), torch.count_nonzero(self.col_mask))
+        )
+
+    def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]:
+        return self.__class__(self.base.mT, self.col_mask, self.row_mask)
+
+    def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
N"]: + if not self.row_eq_col_mask: + raise NotImplementedError() + diag = self.base.diagonal() + return diag[..., self.row_mask] + + def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: + full_dense = self.base.to_dense() + return full_dense[..., self.row_mask, :][..., :, self.col_mask] + + def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]: + left_vecs = self._expand(left_vecs, self.row_mask) + right_vecs = self._expand(right_vecs, self.col_mask) + return self.base._bilinear_derivative(left_vecs, right_vecs) + (None, None) + + def _expand_batch( + self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]] + ) -> Float[LinearOperator, "... M N"]: + return self.__class__(self.base._expand_batch(batch_shape), self.row_mask, self.col_mask) + + def _unsqueeze_batch(self, dim: int) -> LinearOperator: + return self.__class__(self.base._unsqueeze_batch(dim), self.row_mask, self.col_mask) + + def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> LinearOperator: + if _is_noop_index(row_index) and _is_noop_index(col_index): + if len(batch_indices): + return self.__class__(self.base[batch_indices], self.row_mask, self.col_mask) + else: + return self + else: + return super()._getitem(row_index, col_index, *batch_indices) + + def _permute_batch(self, *dims: int) -> LinearOperator: + return self.__class__(self.base._permute_batch(*dims), self.row_mask, self.col_mask) diff --git a/linear_operator/test/linear_operator_test_case.py b/linear_operator/test/linear_operator_test_case.py index 49e6ce74..6df70bb6 100644 --- a/linear_operator/test/linear_operator_test_case.py +++ b/linear_operator/test/linear_operator_test_case.py @@ -684,7 +684,10 @@ def test_bilinear_derivative(self): ) for dc, da in zip(deriv_custom, deriv_auto): - self.assertAllClose(dc, da) + if dc is None: + assert da is None + else: + self.assertAllClose(dc, da) def test_cat_rows(self): linear_op = self.create_linear_op() diff --git a/test/operators/test_masked_linear_operator.py b/test/operators/test_masked_linear_operator.py new file mode 100644 index 00000000..4df5bf4e --- /dev/null +++ b/test/operators/test_masked_linear_operator.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 + +import unittest + +import torch + +from linear_operator import to_linear_operator +from linear_operator.operators.masked_linear_operator import MaskedLinearOperator +from linear_operator.test.linear_operator_test_case import LinearOperatorTestCase, RectangularLinearOperatorTestCase + + +class TestMaskedLinearOperator(LinearOperatorTestCase, unittest.TestCase): + seed = 2023 + + def create_linear_op(self): + base = torch.randn(5, 5) + base = base.mT @ base + base.requires_grad_(True) + base = to_linear_operator(base) + mask = torch.tensor([1, 1, 0, 1, 0], dtype=torch.bool) + covar = MaskedLinearOperator(base, mask, mask) + return covar + + def evaluate_linear_op(self, linear_op): + base = linear_op.base.to_dense() + return base[..., linear_op.row_mask, :][..., linear_op.col_mask] + + +class TestMaskedLinearOperatorBatch(LinearOperatorTestCase, unittest.TestCase): + seed = 2023 + + def create_linear_op(self): + base = torch.randn(2, 5, 5) + base = base.mT @ base + base.requires_grad_(True) + base = to_linear_operator(base) + mask = torch.tensor([1, 1, 0, 1, 0], dtype=torch.bool) + covar = MaskedLinearOperator(base, mask, mask) + return covar + + def evaluate_linear_op(self, linear_op): + base = linear_op.base.to_dense() 
+        return base[..., linear_op.row_mask, :][..., linear_op.col_mask]
+
+
+class TestMaskedLinearOperatorRectangular(RectangularLinearOperatorTestCase, unittest.TestCase):
+    seed = 2023
+
+    def create_linear_op(self):
+        base = to_linear_operator(torch.randn(5, 6, requires_grad=True))
+        row_mask = torch.tensor([1, 1, 0, 1, 0], dtype=torch.bool)
+        col_mask = torch.tensor([1, 0, 1, 0, 1, 1], dtype=torch.bool)
+        covar = MaskedLinearOperator(base, row_mask, col_mask)
+        return covar
+
+    def evaluate_linear_op(self, linear_op):
+        base = linear_op.base.to_dense()
+        return base[..., linear_op.row_mask, :][..., linear_op.col_mask]
+
+
+class TestMaskedLinearOperatorRectangularMultiBatch(RectangularLinearOperatorTestCase, unittest.TestCase):
+    seed = 2023
+
+    def create_linear_op(self):
+        base = to_linear_operator(torch.randn(2, 3, 5, 6, requires_grad=True))
+        row_mask = torch.tensor([1, 1, 0, 1, 0], dtype=torch.bool)
+        col_mask = torch.tensor([1, 0, 1, 0, 1, 1], dtype=torch.bool)
+        covar = MaskedLinearOperator(base, row_mask, col_mask)
+        return covar
+
+    def evaluate_linear_op(self, linear_op):
+        base = linear_op.base.to_dense()
+        return base[..., linear_op.row_mask, :][..., linear_op.col_mask]
+
+
+if __name__ == "__main__":
+    unittest.main()

From c9e992a9092a3ba6c10452ca60030a1307a040c7 Mon Sep 17 00:00:00 2001
From: Turakar
Date: Wed, 7 Jun 2023 14:31:03 +0200
Subject: [PATCH 2/5] Add to index

---
 linear_operator/operators/__init__.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/linear_operator/operators/__init__.py b/linear_operator/operators/__init__.py
index ec95aa65..06fbc79f 100644
--- a/linear_operator/operators/__init__.py
+++ b/linear_operator/operators/__init__.py
@@ -23,6 +23,7 @@
 )
 from .low_rank_root_added_diag_linear_operator import LowRankRootAddedDiagLinearOperator
 from .low_rank_root_linear_operator import LowRankRootLinearOperator
+from .masked_linear_operator import MaskedLinearOperator
 from .matmul_linear_operator import MatmulLinearOperator
 from .mul_linear_operator import MulLinearOperator
 from .permutation_linear_operator import PermutationLinearOperator, TransposePermutationLinearOperator
@@ -62,6 +63,7 @@
     "SumKroneckerLinearOperator",
     "LowRankRootAddedDiagLinearOperator",
     "LowRankRootLinearOperator",
+    "MaskedLinearOperator",
     "MatmulLinearOperator",
     "MulLinearOperator",
     "PermutationLinearOperator",

From 5d631574648fdfd5d263ea6de0272be1beb08ed9 Mon Sep 17 00:00:00 2001
From: Turakar
Date: Wed, 7 Jun 2023 14:35:10 +0200
Subject: [PATCH 3/5] Fix import

---
 linear_operator/operators/masked_linear_operator.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/linear_operator/operators/masked_linear_operator.py b/linear_operator/operators/masked_linear_operator.py
index 634b2e02..863e07fc 100644
--- a/linear_operator/operators/masked_linear_operator.py
+++ b/linear_operator/operators/masked_linear_operator.py
@@ -4,9 +4,7 @@
 from jaxtyping import Bool, Float
 from torch import Tensor
 
-from linear_operator import LinearOperator
-
-from linear_operator.utils.getitem import _is_noop_index, IndexType
+from ._linear_operator import IndexType, LinearOperator, _is_noop_index
 
 
 class MaskedLinearOperator(LinearOperator):

From 13af5e9d6f34526f6583b10b3dcc930dcf5fdf66 Mon Sep 17 00:00:00 2001
From: Turakar
Date: Wed, 7 Jun 2023 14:38:31 +0200
Subject: [PATCH 4/5] Linting

---
 linear_operator/operators/masked_linear_operator.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/linear_operator/operators/masked_linear_operator.py b/linear_operator/operators/masked_linear_operator.py
index 863e07fc..0269fb60 100644
--- a/linear_operator/operators/masked_linear_operator.py
+++ b/linear_operator/operators/masked_linear_operator.py
@@ -4,7 +4,7 @@
 from jaxtyping import Bool, Float
 from torch import Tensor
 
-from ._linear_operator import IndexType, LinearOperator, _is_noop_index
+from ._linear_operator import _is_noop_index, IndexType, LinearOperator
 
 
 class MaskedLinearOperator(LinearOperator):
@@ -27,7 +27,6 @@ def __init__(
         :param row_mask: A :obj:`torch.BoolTensor` containing the mask to apply to the rows.
         :param col_mask: A :obj:`torch.BoolTensor` containing the mask to apply to the columns.
         """
-        _args_memo = None
         super().__init__(base, row_mask, col_mask)
         self.base = base
         self.row_mask = row_mask

From 1b9987beda537e0313e5ddd7392f1fa8d3e83b56 Mon Sep 17 00:00:00 2001
From: Turakar
Date: Fri, 23 Jun 2023 18:11:22 +0200
Subject: [PATCH 5/5] Add _get_indices()

---
 linear_operator/operators/masked_linear_operator.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/linear_operator/operators/masked_linear_operator.py b/linear_operator/operators/masked_linear_operator.py
index 0269fb60..9968b1c3 100644
--- a/linear_operator/operators/masked_linear_operator.py
+++ b/linear_operator/operators/masked_linear_operator.py
@@ -104,5 +104,10 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> LinearOperator:
         else:
             return super()._getitem(row_index, col_index, *batch_indices)
 
+    def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor:
+        row_mapping = torch.arange(self.base.size(-2), device=self.base.device)[self.row_mask]
+        col_mapping = torch.arange(self.base.size(-1), device=self.base.device)[self.col_mask]
+        return self.base._get_indices(row_mapping[row_index], col_mapping[col_index], *batch_indices)
+
     def _permute_batch(self, *dims: int) -> LinearOperator:
         return self.__class__(self.base._permute_batch(*dims), self.row_mask, self.col_mask)
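
A minimal usage sketch of the operator this series adds, assuming all five patches are applied. The base matrix, mask values, and shapes below are illustrative only, not taken from the patches:

    import torch

    from linear_operator import to_linear_operator
    from linear_operator.operators import MaskedLinearOperator

    # Illustrative 4x4 positive semi-definite base matrix.
    base = torch.randn(4, 4)
    base_op = to_linear_operator(base.mT @ base)

    # Keep rows/columns 0, 2, 3; drop row/column 1.
    mask = torch.tensor([True, False, True, True])
    masked = MaskedLinearOperator(base_op, mask, mask)
    print(masked.shape)  # torch.Size([3, 3])

    # Matmul operates on the masked shape: the rhs is zero-padded back to the
    # base shape, multiplied through the base operator, and the kept rows are
    # then selected, so the base operator is never densified.
    rhs = torch.randn(3, 2)
    out = masked @ rhs  # shape: (3, 2)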