diff --git a/gpytorch/optim/ngd.py b/gpytorch/optim/ngd.py
index caf807b91..e11cb1af7 100644
--- a/gpytorch/optim/ngd.py
+++ b/gpytorch/optim/ngd.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3

-from typing import Iterable, Union
+from typing import Callable, Iterable, Optional, Union

 import torch

@@ -28,8 +28,13 @@ def __init__(self, params: Iterable[Union[torch.nn.Parameter, dict]], num_data:
         super().__init__(params, defaults=dict(lr=lr))

     @torch.no_grad()
-    def step(self) -> None:
-        """Performs a single optimization step."""
+    def step(self, closure: Optional[Callable] = None) -> None:
+        """
+        Performs a single optimization step.
+
+        (Note that the :attr:`closure` argument is not used by this optimizer; it is simply included to be
+        compatible with the PyTorch optimizer API.)
+        """
         for group in self.param_groups:
             for p in group["params"]:
                 if p.grad is None:
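
For context, a minimal sketch of what this change enables: `step()` can now be invoked the way generic training loops invoke any `torch.optim` optimizer, closure and all, while the closure itself is simply ignored. The parameter and hand-set gradient below are stand-ins for illustration, not part of this diff:

```python
import torch
from gpytorch.optim import NGD

# Stand-in parameter; in practice this would be a model's natural
# variational parameters (e.g. model.variational_parameters()).
param = torch.nn.Parameter(torch.zeros(5))
optimizer = NGD([param], num_data=100, lr=0.1)

# Stand-in for a real loss.backward() pass.
param.grad = torch.ones(5)

# Both call styles now work; the closure is accepted but unused.
optimizer.step()
optimizer.step(closure=lambda: None)
```

This matters for wrapper code (schedulers, generic training utilities) that unconditionally calls `optimizer.step(closure)`; before this patch such code would raise a `TypeError` with `NGD`.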