
Commit 7affaf3

gpleiss and Balandat authored
Add KernelLinearOperator, deprecate KeOpsLinearOperator (#62)
* Add KernelLinearOperator, deprecate KeOpsLinearOperator

  KeOpsLinearOperator does not correctly backpropagate gradients if the covar_func closes over parameters. KernelLinearOperator corrects for this, and is set up to replace LazyEvaluatedKernelTensor in GPyTorch down the line.

* Fix KeOpsLinearOperator deprecation
* Allow for kernels with reduced batches and multiple outputs per input
* LinearOperator kwargs can be differentiated through

  Previously, only positional args were added to the LinearOperator representation, and so only positional args would receive gradients from _bilinear_derivative. This commit also adds Tensor/LinearOperator kwargs to the representation, and so kwarg Tensor/LinearOperators will also receive gradients.

* Hyperparameters for KernelLinearOperator must be kwargs
* LO._bilinear_derivative only computes derivatives for args that require gradients
* Expand upon closure variables warning for KernelLinearOperator
* LO._bilinear_derivative exits early if no parameters require gradients
* Refactor KernelLinearOperator._getitem
* Allow for optional number of nonbatch parameter dimensions
* Fix LO._bilinear_derivative
* Update linear_operator/operators/_linear_operator.py (Co-authored-by: Max Balandat <[email protected]>)
* Update linear_operator/operators/kernel_linear_operator.py (Co-authored-by: Max Balandat <[email protected]>)
* Update linear_operator/operators/linear_operator_representation_tree.py (Co-authored-by: Max Balandat <[email protected]>)
* Update linear_operator/operators/kernel_linear_operator.py (Co-authored-by: Max Balandat <[email protected]>)
* Update linear_operator/operators/kernel_linear_operator.py (Co-authored-by: Max Balandat <[email protected]>)
* Update linear_operator/operators/kernel_linear_operator.py (Co-authored-by: Max Balandat <[email protected]>)
* Update linear_operator/operators/kernel_linear_operator.py (Co-authored-by: Max Balandat <[email protected]>)
* Update linear_operator/operators/kernel_linear_operator.py (Co-authored-by: Max Balandat <[email protected]>)
* Update linear_operator/operators/kernel_linear_operator.py (Co-authored-by: Max Balandat <[email protected]>)
* Update linear_operator/operators/kernel_linear_operator.py (Co-authored-by: Max Balandat <[email protected]>)
* Update linear_operator/operators/kernel_linear_operator.py (Co-authored-by: Max Balandat <[email protected]>)
* Fix errors, address comments
* KroneckerProductLinearOperator broadcasts
* Test cases and fixes for multitask KernelLinearOperator

---------

Co-authored-by: Max Balandat <[email protected]>
1 parent f020146 commit 7affaf3

10 files changed, +699 -35 lines changed
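Before the per-file diffs, a minimal sketch of the usage pattern the commit message describes: hyperparameters are passed to KernelLinearOperator as keyword arguments (rather than being closed over by covar_func), so they join the operator's representation and receive gradients. The kernel body and the names sq_exp_covar and lengthscale are illustrative assumptions, not taken from this commit.

import torch
from linear_operator.operators import KernelLinearOperator

def sq_exp_covar(x1, x2, lengthscale=None, **kwargs):
    # Toy squared-exponential kernel; `lengthscale` arrives as a kwarg,
    # not as a closed-over variable, so it is visible to autograd machinery.
    sq_dist = torch.cdist(x1 / lengthscale, x2 / lengthscale).pow(2)
    return sq_dist.div(-2.0).exp()

x1, x2 = torch.randn(100, 3), torch.randn(80, 3)
lengthscale = torch.ones(1, 3, requires_grad=True)

# Hyperparameters must be passed as kwargs so that backpropagation reaches them.
covar = KernelLinearOperator(x1, x2, lengthscale=lengthscale, covar_func=sq_exp_covar)
(covar @ torch.randn(80, 1)).sum().backward()  # populates lengthscale.grad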

docs/source/conf.py

+18 -6
@@ -116,14 +116,20 @@ def _dim_to_str(dim):
     if isinstance(dim, jaxtyping.array_types._NamedVariadicDim):
         return "..."
     elif isinstance(dim, jaxtyping.array_types._FixedDim):
-        return str(dim.size)
+        res = str(dim.size)
+        if dim.broadcastable:
+            res = "#" + res
+        return res
     elif isinstance(dim, jaxtyping.array_types._SymbolicDim):
         expr = code_deparse(dim.expr).text.strip().split("return ")[1]
         return f"({expr})"
     elif "jaxtyping" not in str(dim.__class__):  # Probably the case that we have an ellipsis
         return "..."
     else:
-        return str(dim.name)
+        res = str(dim.name)
+        if dim.broadcastable:
+            res = "#" + res
+        return res
 
 
 # Function to format type hints
@@ -152,9 +158,15 @@ def _process(annotation, config):
     elif hasattr(annotation, "__name__"):
         res = _convert_internal_and_external_class_to_strings(annotation)
 
+    elif str(annotation).startswith("typing.Callable"):
+        if len(annotation.__args__) == 2:
+            res = f"Callable[{_process(annotation.__args__[0], config)} -> {_process(annotation.__args__[1], config)}]"
+        else:
+            res = "Callable"
+
     # Convert any Union[*A*, *B*, *C*] into "*A* or *B* or *C*"
     # Also, convert any Optional[*A*] into "*A*, optional"
-    elif "typing.Union" in str(annotation):
+    elif str(annotation).startswith("typing.Union"):
         is_optional_str = ""
         args = list(annotation.__args__)
         # Hack: Optional[*A*] are represented internally as Union[*A*, Nonetype]
@@ -166,13 +178,13 @@ def _process(annotation, config):
         res = " or ".join(processed_args) + is_optional_str
 
     # Convert any Tuple[*A*, *B*] into "(*A*, *B*)"
-    elif "typing.Tuple" in str(annotation):
+    elif str(annotation).startswith("typing.Tuple"):
         args = list(annotation.__args__)
         res = "(" + ", ".join(_process(arg, config) for arg in args) + ")"
 
     # Callable typing annotation
-    elif "typing." in str(annotation):
-        return str(annotation)
+    elif str(annotation).startswith("typing."):
+        return str(annotation)[7:]
 
     # Special cases for forward references.
     # This is brittle, as it only contains case for a select few forward refs
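A standalone illustration of why the substring checks above were tightened to startswith (assumed-typical annotations; nothing here depends on the conf.py helpers):

import typing

ann = typing.List[typing.Union[int, str]]

# Old check: matches even when the Union is nested inside another annotation,
# so List[...] would be routed down the Union branch.
print("typing.Union" in str(ann))            # True
# New check: only fires when the outermost annotation is a Union.
print(str(ann).startswith("typing.Union"))   # False

# The fallback branch now strips the "typing." prefix (len("typing.") == 7).
print(str(typing.Any)[7:])                   # "Any"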

docs/source/data_sparse_operators.rst

+6
@@ -36,6 +36,12 @@ Data-Sparse LinearOperators
 .. autoclass:: linear_operator.operators.IdentityLinearOperator
    :members:
 
+:hidden:`KernelLinearOperator`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: linear_operator.operators.KernelLinearOperator
+   :members:
+
 :hidden:`RootLinearOperator`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

linear_operator/operators/__init__.py

+2
@@ -14,6 +14,7 @@
 from .identity_linear_operator import IdentityLinearOperator
 from .interpolated_linear_operator import InterpolatedLinearOperator
 from .keops_linear_operator import KeOpsLinearOperator
+from .kernel_linear_operator import KernelLinearOperator
 from .kronecker_product_added_diag_linear_operator import KroneckerProductAddedDiagLinearOperator
 from .kronecker_product_linear_operator import (
     KroneckerProductDiagLinearOperator,
@@ -53,6 +54,7 @@
     "IdentityLinearOperator",
     "InterpolatedLinearOperator",
     "KeOpsLinearOperator",
+    "KernelLinearOperator",
     "KroneckerProductLinearOperator",
     "KroneckerProductAddedDiagLinearOperator",
     "KroneckerProductDiagLinearOperator",

linear_operator/operators/_linear_operator.py

+35 -11
@@ -3,10 +3,12 @@
 from __future__ import annotations
 
 import functools
+import itertools
 import math
 import numbers
 import warnings
 from abc import abstractmethod
+from collections import OrderedDict
 from copy import deepcopy
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -150,7 +152,14 @@ def __init__(self, *args, **kwargs):
                 raise ValueError(err)
 
         self._args = args
-        self._kwargs = kwargs
+        self._differentiable_kwargs = OrderedDict()
+        self._nondifferentiable_kwargs = dict()
+        for name, val in sorted(kwargs.items()):
+            # Sorting is necessary so that the flattening in the representation tree is deterministic
+            if torch.is_tensor(val) or isinstance(val, LinearOperator):
+                self._differentiable_kwargs[name] = val
+            else:
+                self._nondifferentiable_kwargs[name] = val
 
     ####
     # The following methods need to be defined by the LinearOperator
@@ -350,17 +359,24 @@ def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[O
         """
         from collections import deque
 
-        args = tuple(self.representation())
-        args_with_grads = tuple(arg for arg in args if arg.requires_grad)
+        # Construct a detached version of each argument in the linear operator
+        args = []
+        for arg in self.representation():
+            # All arguments here are guaranteed to be tensors
+            if arg.dtype.is_floating_point and arg.requires_grad:
+                args.append(arg.detach().requires_grad_(True))
+            else:
+                args.append(arg.detach())
 
-        # Easy case: if we don't require any gradients, then just return!
-        if not len(args_with_grads):
-            return tuple(None for _ in args)
+        # If no arguments require gradients, then we're done!
+        if not any(arg.requires_grad for arg in args):
+            return (None,) * len(args)
 
-        # Normal case: we'll use the autograd to get us a derivative
+        # We'll use the autograd to get us a derivative
         with torch.autograd.enable_grad():
-            loss = (left_vecs * self._matmul(right_vecs)).sum()
-            loss.requires_grad_(True)
+            lin_op = self.representation_tree()(*args)
+            loss = (left_vecs * lin_op._matmul(right_vecs)).sum()
+            args_with_grads = [arg for arg in args if arg.requires_grad]
             actual_grads = deque(torch.autograd.grad(loss, args_with_grads, allow_unused=True))
 
         # Now make sure that the object we return has one entry for every item in args
@@ -457,6 +473,10 @@ def _args(self) -> Tuple[Union[torch.Tensor, "LinearOperator", int], ...]:
     def _args(self, args: Tuple[Union[torch.Tensor, "LinearOperator", int], ...]) -> None:
         self._args_memo = args
 
+    @property
+    def _kwargs(self) -> Dict[str, Any]:
+        return {**self._differentiable_kwargs, **self._nondifferentiable_kwargs}
+
     def _approx_diagonal(self: Float[LinearOperator, "*batch N N"]) -> Float[torch.Tensor, "*batch N"]:
         """
         (Optional) returns an (approximate) diagonal of the matrix
@@ -1344,7 +1364,11 @@ def detach(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "
         (In practice, this function removes all Tensors that make up the
         :obj:`~linear_operator.opeators.LinearOperator` from the computation graph.)
         """
-        return self.clone().detach_()
+        detached_args = [arg.detach() if hasattr(arg, "detach") else arg for arg in self._args]
+        detached_kwargs = dict(
+            (key, val.detach() if hasattr(val, "detach") else val) for key, val in self._kwargs.items()
+        )
+        return self.__class__(*detached_args, **detached_kwargs)
 
     def detach_(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]:
         """
@@ -2013,7 +2037,7 @@ def representation(self) -> Tuple[torch.Tensor, ...]:
         Returns the Tensors that are used to define the LinearOperator
         """
         representation = []
-        for arg in self._args:
+        for arg in itertools.chain(self._args, self._differentiable_kwargs.values()):
             if torch.is_tensor(arg):
                 representation.append(arg)
             elif hasattr(arg, "representation") and callable(arg.representation):  # Is it a LinearOperator?
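A toy sketch (not part of the library) of what the kwarg handling above enables: a subclass that forwards a tensor hyperparameter as a kwarg to LinearOperator.__init__ sees that tensor land in _differentiable_kwargs and in representation(), which makes it eligible for gradients from _bilinear_derivative.

import torch
from linear_operator.operators import LinearOperator

class ScaledDenseOperator(LinearOperator):
    """Illustrative operator representing scale * base, with `scale` as a tensor kwarg."""

    def __init__(self, base, scale=None):
        super().__init__(base, scale=scale)
        self.base = base
        self.scale = scale

    def _matmul(self, rhs):
        return self.scale * (self.base @ rhs)

    def _size(self):
        return self.base.shape

    def _transpose_nonbatch(self):
        return ScaledDenseOperator(self.base.mT, scale=self.scale)

base = torch.randn(4, 4)
scale = torch.tensor(2.0, requires_grad=True)
op = ScaledDenseOperator(base, scale=scale)

# The tensor kwarg is classified as differentiable and joins the representation.
assert list(op._differentiable_kwargs) == ["scale"]
assert len(op.representation()) == 2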

linear_operator/operators/keops_linear_operator.py

+6
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import warnings
+
 from typing import Optional, Tuple, Union
 
 import torch
@@ -13,6 +15,10 @@
 
 class KeOpsLinearOperator(LinearOperator):
     def __init__(self, x1, x2, covar_func, **params):
+        warnings.warn(
+            "KeOpsLinearOperator is deprecated. Please use KernelLinearOperator instead.",
+            DeprecationWarning,
+        )
         super().__init__(x1, x2, covar_func=covar_func, **params)
 
         self.x1 = x1.contiguous()
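Finally, a quick sketch of the deprecation as seen by callers; the lambda stands in for a real KeOps kernel and is never evaluated here, since the warning fires in __init__:

import warnings
import torch
from linear_operator.operators import KeOpsLinearOperator

x1, x2 = torch.randn(10, 2), torch.randn(10, 2)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    KeOpsLinearOperator(x1, x2, covar_func=lambda a, b: a @ b.mT)  # placeholder covar_func

assert any(issubclass(w.category, DeprecationWarning) for w in caught)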
