Commit f6e19a3

LO._bilinear_derivative only computes derivatives for args that require gradients

1 parent: 7dc7fb1
2 files changed: +13 -7 lines changed

linear_operator/operators/_linear_operator.py (+7, -4)
@@ -362,7 +362,7 @@ def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[O
         # Construct a detached version of each argument in the linear operator
         args = []
         for arg in self.representation():
-            if torch.is_tensor(arg) and arg.dtype.is_floating_point:
+            if torch.is_tensor(arg) and arg.dtype.is_floating_point and arg.requires_grad:
                 args.append(arg.detach().requires_grad_(True))
             else:
                 args.append(arg.detach())
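
To make the effect of the new requires_grad check concrete, here is a minimal, self-contained sketch of the bilinear-derivative pattern this hunk modifies. It is not the library's implementation: bilinear_derivative_sketch and matmul_fn are hypothetical names, and the real method operates on self.representation() rather than an explicit args tuple. The idea is the same: re-run the forward pass on detached copies and ask autograd for gradients only with respect to the copies that actually required them.

    import torch

    def bilinear_derivative_sketch(matmul_fn, args, left_vecs, right_vecs):
        detached = []
        for arg in args:
            if torch.is_tensor(arg) and arg.dtype.is_floating_point and arg.requires_grad:
                # Re-enable autograd only on the detached copies that need it.
                detached.append(arg.detach().requires_grad_(True))
            else:
                detached.append(arg.detach() if torch.is_tensor(arg) else arg)

        with torch.enable_grad():
            loss = (left_vecs * matmul_fn(detached, right_vecs)).sum()

        diff_args = [a for a in detached if torch.is_tensor(a) and a.requires_grad]
        grads = iter(torch.autograd.grad(loss, diff_args, allow_unused=True))
        # Args that were skipped simply get no gradient.
        return tuple(
            next(grads) if torch.is_tensor(a) and a.requires_grad else None
            for a in detached
        )

    # Example: K = A @ A^T + shift, where only A requires gradients.
    def matmul_fn(args, rhs):
        A, shift = args
        return (A @ A.transpose(-1, -2) + shift) @ rhs

    A = torch.randn(5, 2, requires_grad=True)
    shift = torch.eye(5)  # requires_grad=False, so it is skipped entirely
    left, right = torch.randn(5, 3), torch.randn(5, 3)
    dA, dshift = bilinear_derivative_sketch(matmul_fn, (A, shift), left, right)
    assert dshift is None and dA.shape == A.shape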
@@ -467,11 +467,14 @@ def _args(self) -> Tuple[Union[torch.Tensor, "LinearOperator", int], ...]:
     def _args(self, args: Tuple[Union[torch.Tensor, "LinearOperator", int], ...]) -> None:
         self._args_memo = args
 
+    @property
+    def _differentiable_kwargs(self) -> Dict[str, Union[Tensor, "LinearOperator"]]:
+        return dict(zip(self._differentiable_kwarg_names, self._differentiable_kwarg_vals))
+
     @property
     def _kwargs(self) -> Dict[str, Any]:
-        kwargs = dict(
-            zip(self._differentiable_kwarg_names, self._differentiable_kwarg_vals), **self._nondifferentiable_kwargs
-        )
+        kwargs = self._differentiable_kwargs
+        kwargs.update(self._nondifferentiable_kwargs)
         return kwargs
 
     def _approx_diagonal(self: Float[LinearOperator, "*batch N N"]) -> Float[torch.Tensor, "*batch N"]:
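
The second hunk factors the tensor-valued ("differentiable") keyword arguments out into their own property, so callers can retrieve them without the non-differentiable ones. A rough stand-alone sketch of the resulting behavior is below; KwargSplitSketch and its constructor are hypothetical, only the attribute and property names come from the diff.

    from typing import Any, Dict

    import torch
    from torch import Tensor

    class KwargSplitSketch:
        # Hypothetical container mimicking how kwargs might be stored split
        # into differentiable (tensor-valued) and non-differentiable parts.
        def __init__(self, **kwargs: Any) -> None:
            diff = {name: val for name, val in kwargs.items() if torch.is_tensor(val)}
            self._differentiable_kwarg_names = tuple(diff.keys())
            self._differentiable_kwarg_vals = tuple(diff.values())
            self._nondifferentiable_kwargs = {
                name: val for name, val in kwargs.items() if not torch.is_tensor(val)
            }

        @property
        def _differentiable_kwargs(self) -> Dict[str, Tensor]:
            # The new property from the diff: re-pair names with values.
            return dict(zip(self._differentiable_kwarg_names, self._differentiable_kwarg_vals))

        @property
        def _kwargs(self) -> Dict[str, Any]:
            # Merge the two groups, as the rewritten _kwargs does.
            kwargs = self._differentiable_kwargs
            kwargs.update(self._nondifferentiable_kwargs)
            return kwargs

    op = KwargSplitSketch(lengthscale=torch.rand(3, requires_grad=True), num_tasks=2)
    print(op._differentiable_kwargs)  # only the tensor-valued kwarg
    print(op._kwargs)                 # both kwargs merged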

linear_operator/operators/kernel_linear_operator.py (+6, -3)
@@ -141,10 +141,13 @@ def __init__(
         )
 
         # Create a version of each argument that is expanded to the broadcast batch shape
-        x1 = x1.expand(*batch_broadcast_shape, *x1.shape[-2:]).contiguous()
-        x2 = x2.expand(*batch_broadcast_shape, *x2.shape[-2:]).contiguous()
+        #
+        # NOTE: we must explicitly call requires_grad on each of these arguments
+        # for the automatic _bilinear_derivative to work in torch.autograd.Functions
+        x1 = x1.expand(*batch_broadcast_shape, *x1.shape[-2:]).contiguous().requires_grad_(x1.requires_grad)
+        x2 = x2.expand(*batch_broadcast_shape, *x2.shape[-2:]).contiguous().requires_grad_(x2.requires_grad)
         tensor_params = dict(
-            (name, val.expand(*batch_broadcast_shape, *param_nonbatch_shapes[name]))
+            (name, val.expand(*batch_broadcast_shape, *param_nonbatch_shapes[name]).requires_grad_(val.requires_grad))
             for name, val in tensor_params.items()
         )
         new_param_batch_shapes = dict((name, batch_broadcast_shape) for name in param_batch_shapes.keys())
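
The NOTE in this hunk refers to the fact that a torch.autograd.Function runs its forward with grad mode disabled, so expand().contiguous() would return copies with requires_grad=False even when the inputs required gradients; the requires_grad check added to _bilinear_derivative above would then skip those arguments. A small illustrative snippet of that behavior, using a plain no_grad block as a stand-in for the autograd.Function context:

    import torch

    x1 = torch.randn(4, 3, requires_grad=True)
    batch_broadcast_shape = (2,)

    # With grad mode disabled, the expanded copy silently drops the flag ...
    with torch.no_grad():
        expanded = x1.expand(*batch_broadcast_shape, *x1.shape[-2:]).contiguous()
    print(expanded.requires_grad)  # False

    # ... unless it is restored explicitly, as the commit now does. The copy has
    # no autograd history here, so requires_grad_ simply marks it as a tensor
    # that _bilinear_derivative should differentiate with respect to.
    with torch.no_grad():
        expanded = (
            x1.expand(*batch_broadcast_shape, *x1.shape[-2:])
            .contiguous()
            .requires_grad_(x1.requires_grad)
        )
    print(expanded.requires_grad)  # True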
