From 6f7ee0151f7ad724c2204bd49dae4e4f1956878d Mon Sep 17 00:00:00 2001 From: Danny Friar Date: Tue, 6 Sep 2022 09:32:31 +0100 Subject: [PATCH 1/4] Accept closure argument in NGD optimizer `step` Make this consistent with the base optimizer class --- gpytorch/optim/ngd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpytorch/optim/ngd.py b/gpytorch/optim/ngd.py index caf807b91..6650bac94 100644 --- a/gpytorch/optim/ngd.py +++ b/gpytorch/optim/ngd.py @@ -28,7 +28,7 @@ def __init__(self, params: Iterable[Union[torch.nn.Parameter, dict]], num_data: super().__init__(params, defaults=dict(lr=lr)) @torch.no_grad() - def step(self) -> None: + def step(self, closure) -> None: """Performs a single optimization step.""" for group in self.param_groups: for p in group["params"]: From 4337a0029d69a3b1c2111bee7200844d461e69ed Mon Sep 17 00:00:00 2001 From: Geoff Pleiss Date: Tue, 6 Sep 2022 12:38:21 -0400 Subject: [PATCH 2/4] Update NGD docs. --- gpytorch/optim/ngd.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/gpytorch/optim/ngd.py b/gpytorch/optim/ngd.py index 6650bac94..a74b89702 100644 --- a/gpytorch/optim/ngd.py +++ b/gpytorch/optim/ngd.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -from typing import Iterable, Union +from typing import Iterable, Union, Callable import torch @@ -28,8 +28,13 @@ def __init__(self, params: Iterable[Union[torch.nn.Parameter, dict]], num_data: super().__init__(params, defaults=dict(lr=lr)) @torch.no_grad() - def step(self, closure) -> None: - """Performs a single optimization step.""" + def step(self, closure: Optional[Callable] = None) -> None: + """ + Performs a single optimization step. + + (Note that the :attr:`closure` argument is not used by this optimizer; it is simply included to be + compatible with the PyTorch optimizer API.) + """ for group in self.param_groups: for p in group["params"]: if p.grad is None: From a0c7af5cd53d5f44526a41bea7541ca3e4682bb5 Mon Sep 17 00:00:00 2001 From: Geoff Pleiss Date: Tue, 6 Sep 2022 12:40:53 -0400 Subject: [PATCH 3/4] Update ngd.py --- gpytorch/optim/ngd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpytorch/optim/ngd.py b/gpytorch/optim/ngd.py index a74b89702..ed520988b 100644 --- a/gpytorch/optim/ngd.py +++ b/gpytorch/optim/ngd.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -from typing import Iterable, Union, Callable +from typing import Callable, Iterable, Optional, Union import torch From 4732aa4d38a22e11290467d02e35d8429ddcecd0 Mon Sep 17 00:00:00 2001 From: Geoff Pleiss Date: Tue, 6 Sep 2022 12:42:16 -0400 Subject: [PATCH 4/4] Update ngd.py --- gpytorch/optim/ngd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpytorch/optim/ngd.py b/gpytorch/optim/ngd.py index ed520988b..e11cb1af7 100644 --- a/gpytorch/optim/ngd.py +++ b/gpytorch/optim/ngd.py @@ -31,7 +31,7 @@ def __init__(self, params: Iterable[Union[torch.nn.Parameter, dict]], num_data: def step(self, closure: Optional[Callable] = None) -> None: """ Performs a single optimization step. - + (Note that the :attr:`closure` argument is not used by this optimizer; it is simply included to be compatible with the PyTorch optimizer API.) """