
Add KernelLinearOperator, deprecate KeOpsLinearOperator #62


Merged · 25 commits · Jun 2, 2023
Commits
b29ba6e  Add KernelLinearOperator, deprecate KeOpsLinearOperator (gpleiss, May 5, 2023)
b0bf6dd  Fix KeOpsLinearOperator deprecation (gpleiss, May 24, 2023)
cab7af0  Allow for kernels with reduced batches and multiple outputs per input (gpleiss, May 24, 2023)
83df1b9  LinearOperator kwargs can be differentiated through (gpleiss, May 25, 2023)
235197d  Hyperparameters for KernelLinearOperator must be kwargs (gpleiss, May 25, 2023)
6e42cb5  LO._bilinear_derivative only computes derivatives for args that requi… (gpleiss, May 25, 2023)
0dde888  Expand upon closure variables warning for KernelLinearOperator (gpleiss, May 25, 2023)
6255472  LO._bilinear_derivative exits early if no parameters require gradients (gpleiss, May 25, 2023)
6da16e4  Refactor KernelLinearOperator._getitem (gpleiss, May 25, 2023)
5762c62  Allow for optional number of nonbatch parameter dimensions (gpleiss, May 25, 2023)
da7dad9  Fix LO._bilinear_derivative (gpleiss, May 26, 2023)
bfde1fe  Update linear_operator/operators/_linear_operator.py (gpleiss, May 27, 2023)
a1fb466  Update linear_operator/operators/kernel_linear_operator.py (gpleiss, May 27, 2023)
90f676c  Update linear_operator/operators/linear_operator_representation_tree.py (gpleiss, May 27, 2023)
bfa03c0  Update linear_operator/operators/kernel_linear_operator.py (gpleiss, May 27, 2023)
83136f4  Update linear_operator/operators/kernel_linear_operator.py (gpleiss, May 27, 2023)
e27cfec  Update linear_operator/operators/kernel_linear_operator.py (gpleiss, May 27, 2023)
5a6a3f4  Update linear_operator/operators/kernel_linear_operator.py (gpleiss, May 27, 2023)
b280037  Update linear_operator/operators/kernel_linear_operator.py (gpleiss, May 27, 2023)
dfd45f3  Update linear_operator/operators/kernel_linear_operator.py (gpleiss, May 27, 2023)
3aa7178  Update linear_operator/operators/kernel_linear_operator.py (gpleiss, May 27, 2023)
cb9f16f  Update linear_operator/operators/kernel_linear_operator.py (gpleiss, May 27, 2023)
aefc73d  Fix errors, address comments (gpleiss, May 27, 2023)
cc92b70  KroneckerProductLinearOperator broadcasts (gpleiss, Jun 2, 2023)
f97e9d5  Test cases and fixes for multitask KernelLinearOperator (gpleiss, Jun 2, 2023)
Files changed
24 changes: 18 additions & 6 deletions docs/source/conf.py
@@ -116,14 +116,20 @@ def _dim_to_str(dim):
     if isinstance(dim, jaxtyping.array_types._NamedVariadicDim):
         return "..."
     elif isinstance(dim, jaxtyping.array_types._FixedDim):
-        return str(dim.size)
+        res = str(dim.size)
+        if dim.broadcastable:
+            res = "#" + res
+        return res
     elif isinstance(dim, jaxtyping.array_types._SymbolicDim):
         expr = code_deparse(dim.expr).text.strip().split("return ")[1]
         return f"({expr})"
     elif "jaxtyping" not in str(dim.__class__):  # Probably the case that we have an ellipsis
         return "..."
     else:
-        return str(dim.name)
+        res = str(dim.name)
+        if dim.broadcastable:
+            res = "#" + res
+        return res


# Function to format type hints
@@ -152,9 +158,15 @@ def _process(annotation, config):
     elif hasattr(annotation, "__name__"):
         res = _convert_internal_and_external_class_to_strings(annotation)

+    elif str(annotation).startswith("typing.Callable"):
+        if len(annotation.__args__) == 2:
+            res = f"Callable[{_process(annotation.__args__[0], config)} -> {_process(annotation.__args__[1], config)}]"
+        else:
+            res = "Callable"
+
     # Convert any Union[*A*, *B*, *C*] into "*A* or *B* or *C*"
     # Also, convert any Optional[*A*] into "*A*, optional"
-    elif "typing.Union" in str(annotation):
+    elif str(annotation).startswith("typing.Union"):
         is_optional_str = ""
         args = list(annotation.__args__)
         # Hack: Optional[*A*] are represented internally as Union[*A*, Nonetype]
@@ -166,13 +178,13 @@
         res = " or ".join(processed_args) + is_optional_str

     # Convert any Tuple[*A*, *B*] into "(*A*, *B*)"
-    elif "typing.Tuple" in str(annotation):
+    elif str(annotation).startswith("typing.Tuple"):
         args = list(annotation.__args__)
         res = "(" + ", ".join(_process(arg, config) for arg in args) + ")"

     # Callable typing annotation
-    elif "typing." in str(annotation):
-        return str(annotation)
+    elif str(annotation).startswith("typing."):
+        return str(annotation)[7:]

     # Special cases for forward references.
     # This is brittle, as it only contains case for a select few forward refs
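For context on what these two hunks change: `_dim_to_str` now prefixes broadcastable jaxtyping dimensions with `#`, and `_process` gains a dedicated branch for `Callable` annotations ahead of the generic `typing.` fallback. A minimal sketch of the `Callable` logic (standalone and hypothetical, not code from this PR):

```python
# Callable[[X], Y].__args__ flattens to (X, Y), so len(__args__) == 2
# identifies exactly the one-argument case handled by the new branch.
import typing

one_arg = typing.Callable[[int], str]
print(one_arg.__args__)         # (<class 'int'>, <class 'str'>) -> rendered "Callable[int -> str]"

multi_arg = typing.Callable[[int, float], str]
print(len(multi_arg.__args__))  # 3 -> falls back to the plain "Callable" label
```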
6 changes: 6 additions & 0 deletions docs/source/data_sparse_operators.rst
@@ -36,6 +36,12 @@ Data-Sparse LinearOperators
 .. autoclass:: linear_operator.operators.IdentityLinearOperator
    :members:

+:hidden:`KernelLinearOperator`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: linear_operator.operators.KernelLinearOperator
+   :members:
+
 :hidden:`RootLinearOperator`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

2 changes: 2 additions & 0 deletions linear_operator/operators/__init__.py
@@ -14,6 +14,7 @@
 from .identity_linear_operator import IdentityLinearOperator
 from .interpolated_linear_operator import InterpolatedLinearOperator
 from .keops_linear_operator import KeOpsLinearOperator
+from .kernel_linear_operator import KernelLinearOperator
 from .kronecker_product_added_diag_linear_operator import KroneckerProductAddedDiagLinearOperator
 from .kronecker_product_linear_operator import (
     KroneckerProductDiagLinearOperator,
@@ -53,6 +54,7 @@
     "IdentityLinearOperator",
     "InterpolatedLinearOperator",
     "KeOpsLinearOperator",
+    "KernelLinearOperator",
     "KroneckerProductLinearOperator",
     "KroneckerProductAddedDiagLinearOperator",
     "KroneckerProductDiagLinearOperator",
46 changes: 35 additions & 11 deletions linear_operator/operators/_linear_operator.py
@@ -3,10 +3,12 @@
 from __future__ import annotations

 import functools
+import itertools
 import math
 import numbers
 import warnings
 from abc import abstractmethod
+from collections import OrderedDict
 from copy import deepcopy
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -150,7 +152,14 @@ def __init__(self, *args, **kwargs):
             raise ValueError(err)

         self._args = args
-        self._kwargs = kwargs
+        self._differentiable_kwargs = OrderedDict()
+        self._nondifferentiable_kwargs = dict()
+        for name, val in sorted(kwargs.items()):
+            # Sorting is necessary so that the flattening in the representation tree is deterministic
+            if torch.is_tensor(val) or isinstance(val, LinearOperator):
+                self._differentiable_kwargs[name] = val
+            else:
+                self._nondifferentiable_kwargs[name] = val

     ####
     # The following methods need to be defined by the LinearOperator
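The effect of the constructor change above, as a standalone sketch (hypothetical example, not from the diff): tensor-valued kwargs are collected into an ordered, deterministically sorted dict so they can be flattened into the representation tree, while everything else is stored separately.

```python
import torch
from collections import OrderedDict

kwargs = {"metric": "rbf", "lengthscale": torch.tensor(2.0, requires_grad=True)}
differentiable, nondifferentiable = OrderedDict(), dict()
for name, val in sorted(kwargs.items()):  # sorted => deterministic flattening order
    if torch.is_tensor(val):  # the real check also accepts LinearOperators
        differentiable[name] = val
    else:
        nondifferentiable[name] = val

print(list(differentiable))     # ['lengthscale'] -- will participate in autograd
print(list(nondifferentiable))  # ['metric']      -- stored but never flattened
```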
@@ -350,17 +359,24 @@ def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[O
         """
         from collections import deque

-        args = tuple(self.representation())
-        args_with_grads = tuple(arg for arg in args if arg.requires_grad)
+        # Construct a detached version of each argument in the linear operator
+        args = []
+        for arg in self.representation():
+            # All arguments here are guaranteed to be tensors
+            if arg.dtype.is_floating_point and arg.requires_grad:
+                args.append(arg.detach().requires_grad_(True))
+            else:
+                args.append(arg.detach())

-        # Easy case: if we don't require any gradients, then just return!
-        if not len(args_with_grads):
-            return tuple(None for _ in args)
+        # If no arguments require gradients, then we're done!
+        if not any(arg.requires_grad for arg in args):
+            return (None,) * len(args)

-        # Normal case: we'll use the autograd to get us a derivative
+        # We'll use the autograd to get us a derivative
         with torch.autograd.enable_grad():
-            loss = (left_vecs * self._matmul(right_vecs)).sum()
-            loss.requires_grad_(True)
+            lin_op = self.representation_tree()(*args)
+            loss = (left_vecs * lin_op._matmul(right_vecs)).sum()
+            args_with_grads = [arg for arg in args if arg.requires_grad]
             actual_grads = deque(torch.autograd.grad(loss, args_with_grads, allow_unused=True))

         # Now make sure that the object we return has one entry for every item in args
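For intuition (not part of the diff): for every tensor in the representation, `_bilinear_derivative` computes the derivative of `sum(left_vecs * (K @ right_vecs))`. Rebuilding the operator from detached copies via `representation_tree()` keeps the autograd graph local to this single matmul. A dense-matrix sketch of the quantity being computed (hypothetical example):

```python
# d/dA of sum(left * (A @ right)) is left @ right^T -- the bilinear-form
# derivative that the method generalizes to arbitrary LinearOperators.
import torch

A = torch.randn(3, 3, requires_grad=True)
left_vecs, right_vecs = torch.randn(3, 2), torch.randn(3, 2)
loss = (left_vecs * (A @ right_vecs)).sum()
(grad,) = torch.autograd.grad(loss, [A])
assert torch.allclose(grad, left_vecs @ right_vecs.T)
```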
@@ -457,6 +473,10 @@ def _args(self) -> Tuple[Union[torch.Tensor, "LinearOperator", int], ...]:
     def _args(self, args: Tuple[Union[torch.Tensor, "LinearOperator", int], ...]) -> None:
         self._args_memo = args

+    @property
+    def _kwargs(self) -> Dict[str, Any]:
+        return {**self._differentiable_kwargs, **self._nondifferentiable_kwargs}
+
     def _approx_diagonal(self: Float[LinearOperator, "*batch N N"]) -> Float[torch.Tensor, "*batch N"]:
         """
         (Optional) returns an (approximate) diagonal of the matrix
@@ -1344,7 +1364,11 @@ def detach(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "
         (In practice, this function removes all Tensors that make up the
         :obj:`~linear_operator.opeators.LinearOperator` from the computation graph.)
         """
-        return self.clone().detach_()
+        detached_args = [arg.detach() if hasattr(arg, "detach") else arg for arg in self._args]
+        detached_kwargs = dict(
+            (key, val.detach() if hasattr(val, "detach") else val) for key, val in self._kwargs.items()
+        )
+        return self.__class__(*detached_args, **detached_kwargs)

     def detach_(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]:
         """
@@ -2013,7 +2037,7 @@ def representation(self) -> Tuple[torch.Tensor, ...]:
         Returns the Tensors that are used to define the LinearOperator
         """
         representation = []
-        for arg in self._args:
+        for arg in itertools.chain(self._args, self._differentiable_kwargs.values()):
             if torch.is_tensor(arg):
                 representation.append(arg)
             elif hasattr(arg, "representation") and callable(arg.representation):  # Is it a LinearOperator?
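In effect (standalone sketch with stand-in values, not from the diff): differentiable kwargs now ride along in `representation()` right after the positional args, which is what lets `_bilinear_derivative` above differentiate with respect to them.

```python
import itertools
from collections import OrderedDict
import torch

_args = (torch.randn(5, 2), torch.randn(5, 2))                      # e.g. x1, x2
_differentiable_kwargs = OrderedDict(lengthscale=torch.ones(1, 2))  # tensor kwarg

representation = [
    arg for arg in itertools.chain(_args, _differentiable_kwargs.values())
    if torch.is_tensor(arg)
]
print(len(representation))  # 3 -- x1, x2, and lengthscale
```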
6 changes: 6 additions & 0 deletions linear_operator/operators/keops_linear_operator.py
@@ -1,5 +1,7 @@
 from __future__ import annotations

+import warnings
+
 from typing import Optional, Tuple, Union

 import torch
@@ -13,6 +15,10 @@

 class KeOpsLinearOperator(LinearOperator):
     def __init__(self, x1, x2, covar_func, **params):
+        warnings.warn(
+            "KeOpsLinearOperator is deprecated. Please use KernelLinearOperator instead.",
+            DeprecationWarning,
+        )
         super().__init__(x1, x2, covar_func=covar_func, **params)

         self.x1 = x1.contiguous()
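Finally, a migration sketch from `KeOpsLinearOperator` to `KernelLinearOperator` (based on the API this PR introduces; the RBF `covar_func` below is illustrative, not code from the PR). Hyperparameters are passed as kwargs, so they become differentiable arguments of the operator:

```python
import torch
from linear_operator.operators import KernelLinearOperator

def rbf_covar_func(x1, x2, lengthscale=None, **kwargs):
    # x1: ... x M x D,  x2: ... x N x D,  lengthscale: ... x 1 x D
    x1_ = x1.div(lengthscale)
    x2_ = x2.div(lengthscale)
    sq_dist = (x1_.unsqueeze(-2) - x2_.unsqueeze(-3)).pow(2).sum(dim=-1)
    return sq_dist.div(-2.0).exp()

x1, x2 = torch.randn(100, 3), torch.randn(80, 3)
lengthscale = torch.ones(1, 3, requires_grad=True)

# Hyperparameters must be kwargs (see commit 235197d); tensor-valued kwargs
# are picked up as differentiable arguments of the resulting operator.
kernel = KernelLinearOperator(x1, x2, lengthscale=lengthscale, covar_func=rbf_covar_func)
```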