@@ -3,10 +3,12 @@
 from __future__ import annotations
 
 import functools
+import itertools
 import math
 import numbers
 import warnings
 from abc import abstractmethod
+from collections import OrderedDict
 from copy import deepcopy
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -150,7 +152,14 @@ def __init__(self, *args, **kwargs):
                 raise ValueError(err)
 
         self._args = args
-        self._kwargs = kwargs
+        self._differentiable_kwargs = OrderedDict()
+        self._nondifferentiable_kwargs = dict()
+        for name, val in sorted(kwargs.items()):
+            # Sorting is necessary so that the flattening in the representation tree is deterministic
+            if torch.is_tensor(val) or isinstance(val, LinearOperator):
+                self._differentiable_kwargs[name] = val
+            else:
+                self._nondifferentiable_kwargs[name] = val
 
     ####
     # The following methods need to be defined by the LinearOperator
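The hunk above partitions the constructor's keyword arguments: tensor- and LinearOperator-valued kwargs go into an OrderedDict (sorted by name so the representation tree flattens deterministically), everything else into a plain dict. A minimal standalone sketch of that partitioning with plain tensors only, using a hypothetical split_kwargs helper that is not part of the library:

from collections import OrderedDict

import torch


def split_kwargs(**kwargs):
    # Partition kwargs into tensor-valued (differentiable) entries and the rest,
    # iterating in sorted-key order so the split is deterministic.
    differentiable, nondifferentiable = OrderedDict(), {}
    for name, val in sorted(kwargs.items()):
        if torch.is_tensor(val):
            differentiable[name] = val
        else:
            nondifferentiable[name] = val
    return differentiable, nondifferentiable


diff_kwargs, nondiff_kwargs = split_kwargs(jitter=1e-4, root=torch.randn(3, 2), validate=False)
print(list(diff_kwargs))     # ['root']
print(list(nondiff_kwargs))  # ['jitter', 'validate']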
@@ -350,17 +359,24 @@ def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[O
         """
         from collections import deque
 
-        args = tuple(self.representation())
-        args_with_grads = tuple(arg for arg in args if arg.requires_grad)
+        # Construct a detached version of each argument in the linear operator
+        args = []
+        for arg in self.representation():
+            # All arguments here are guaranteed to be tensors
+            if arg.dtype.is_floating_point and arg.requires_grad:
+                args.append(arg.detach().requires_grad_(True))
+            else:
+                args.append(arg.detach())
 
-        # Easy case: if we don't require any gradients, then just return!
-        if not len(args_with_grads):
-            return tuple(None for _ in args)
+        # If no arguments require gradients, then we're done!
+        if not any(arg.requires_grad for arg in args):
+            return (None,) * len(args)
 
-        # Normal case: we'll use the autograd to get us a derivative
+        # We'll use the autograd to get us a derivative
         with torch.autograd.enable_grad():
-            loss = (left_vecs * self._matmul(right_vecs)).sum()
-            loss.requires_grad_(True)
+            lin_op = self.representation_tree()(*args)
+            loss = (left_vecs * lin_op._matmul(right_vecs)).sum()
+            args_with_grads = [arg for arg in args if arg.requires_grad]
             actual_grads = deque(torch.autograd.grad(loss, args_with_grads, allow_unused=True))
 
         # Now make sure that the object we return has one entry for every item in args
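The rewritten _bilinear_derivative detaches every tensor in the representation, rebuilds the operator from those detached copies via the representation tree, and lets autograd differentiate the bilinear form sum(left_vecs * (A @ right_vecs)). A plain-tensor sketch of the same detach-and-rebuild pattern, with a dense matrix standing in for the operator's representation (an illustration, not the library code):

import torch

A = torch.randn(4, 4, requires_grad=True)
left_vecs = torch.randn(4, 2)
right_vecs = torch.randn(4, 2)

# Detach the "representation", re-enable gradients on the detached copy, and
# differentiate the bilinear form with respect to it, leaving A's graph alone.
A_detached = A.detach().requires_grad_(True)
with torch.autograd.enable_grad():
    loss = (left_vecs * (A_detached @ right_vecs)).sum()
    (grad,) = torch.autograd.grad(loss, [A_detached])

# For a dense matrix the analytic derivative is the outer product left @ right^T.
assert torch.allclose(grad, left_vecs @ right_vecs.transpose(-1, -2))
assert A.grad is None  # the original tensor never accumulated a gradient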
@@ -457,6 +473,10 @@ def _args(self) -> Tuple[Union[torch.Tensor, "LinearOperator", int], ...]:
     def _args(self, args: Tuple[Union[torch.Tensor, "LinearOperator", int], ...]) -> None:
         self._args_memo = args
 
+    @property
+    def _kwargs(self) -> Dict[str, Any]:
+        return {**self._differentiable_kwargs, **self._nondifferentiable_kwargs}
+
     def _approx_diagonal(self: Float[LinearOperator, "*batch N N"]) -> Float[torch.Tensor, "*batch N"]:
         """
         (Optional) returns an (approximate) diagonal of the matrix
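The new read-only _kwargs property re-merges the two groups, so existing code that reads self._kwargs still sees every keyword argument in one mapping. A tiny illustration of the {**a, **b} merge (the names here are made up):

from collections import OrderedDict

differentiable = OrderedDict(root="tensor-valued entry")
nondifferentiable = {"jitter": 1e-4}

# {**a, **b} builds one combined dict; neither original mapping is modified.
merged = {**differentiable, **nondifferentiable}
assert merged == {"root": "tensor-valued entry", "jitter": 1e-4}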
@@ -1344,7 +1364,11 @@ def detach(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "
         (In practice, this function removes all Tensors that make up the
         :obj:`~linear_operator.opeators.LinearOperator` from the computation graph.)
         """
-        return self.clone().detach_()
+        detached_args = [arg.detach() if hasattr(arg, "detach") else arg for arg in self._args]
+        detached_kwargs = dict(
+            (key, val.detach() if hasattr(val, "detach") else val) for key, val in self._kwargs.items()
+        )
+        return self.__class__(*detached_args, **detached_kwargs)
 
     def detach_(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch M N"]:
         """
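detach now builds a fresh operator from detached constituents instead of cloning and detaching in place, so the original operator's tensors keep their autograd state. A plain-Python sketch of that reconstruction pattern, where Wrapper is a hypothetical stand-in for a LinearOperator subclass:

import torch


class Wrapper:
    # Hypothetical stand-in: stores its constructor arguments so it can be rebuilt from them.
    def __init__(self, *args, **kwargs):
        self._args, self._kwargs = args, kwargs


x = torch.randn(3, 3, requires_grad=True)
op = Wrapper(x, scale=2.0)

# Detach every argument that supports it, leave the rest untouched,
# then rebuild the same class from the detached pieces.
detached_args = [a.detach() if hasattr(a, "detach") else a for a in op._args]
detached_kwargs = {k: v.detach() if hasattr(v, "detach") else v for k, v in op._kwargs.items()}
detached_op = Wrapper(*detached_args, **detached_kwargs)

assert not detached_op._args[0].requires_grad
assert op._args[0].requires_grad  # the original operator is left untouched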
@@ -2013,7 +2037,7 @@ def representation(self) -> Tuple[torch.Tensor, ...]:
         Returns the Tensors that are used to define the LinearOperator
         """
         representation = []
-        for arg in self._args:
+        for arg in itertools.chain(self._args, self._differentiable_kwargs.values()):
             if torch.is_tensor(arg):
                 representation.append(arg)
             elif hasattr(arg, "representation") and callable(arg.representation):  # Is it a LinearOperator?
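With itertools.chain, tensor-valued keyword arguments now feed into the flattened representation alongside the positional arguments. A standalone sketch of the chaining with plain tensors (collect_tensors is a hypothetical helper, not the library method):

import itertools

import torch


def collect_tensors(args, differentiable_kwargs):
    # Flatten positional args and tensor-valued kwargs into one representation list.
    representation = []
    for arg in itertools.chain(args, differentiable_kwargs.values()):
        if torch.is_tensor(arg):
            representation.append(arg)
    return representation


args = (torch.eye(2),)
kwargs = {"root": torch.randn(2, 2)}
print(len(collect_tensors(args, kwargs)))  # 2: the positional tensor plus the tensor-valued kwarg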