keops periodic and keops kernels unit tests #2296

Merged: 31 commits merged into cornellius-gp:master on Jun 2, 2023

Conversation

m-julian
Contributor

  1. Implemented keops periodic kernel
  2. Added unit tests for the keops periodic, RBF, and Matern kernels (both keops and non-keops versions). The tensors in the keops unit tests are smaller than gpytorch.settings.max_cholesky_size.value(), so the pykeops implementations would not otherwise be exercised; in the unit tests I set max_cholesky_size so that both the keops and non-keops versions are tested (see the sketch after this list).
  3. Updated the docstrings of the keops kernels, since they also work with batch dimensions (the batch unit tests pass).
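For reference, a minimal sketch of the idea behind point 2 (assuming only the public gpytorch.settings API; the kernel choice and sizes here are illustrative, not the actual test code):

import torch
import gpytorch
from gpytorch.kernels.keops import RBFKernel

x = torch.randn(100, 3)
kern = RBFKernel()

# With the default max_cholesky_size (800), a 100x100 kernel matrix would be
# evaluated by the non-KeOps fallback. Lowering the threshold forces the
# KeOps code path for the same small tensor.
with gpytorch.settings.max_cholesky_size(0):
    covar_keops = kern(x, x).to_dense()

# Outside the context manager, the same call exercises the non-KeOps branch.
covar_nonkeops = kern(x, x).to_dense()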

I saw there is a check if not torch.cuda.is_available(), so I have also added it to the keops periodic kernel tests. Are these pykeops tests going to be run with GitHub Actions, or are they skipped? The tests ran fine locally with CUDA.

@m-julian m-julian marked this pull request as draft March 13, 2023 14:21
@m-julian
Contributor Author

Seems like there is an issue with the gradients: the lengthscale hyperparameter is not optimized (period_length is optimized). I set up an example like the one in the docs and just swapped the base kernel for gpytorch.kernels.keops.PeriodicKernel(), with train_x = torch.linspace(0, 1, 1500), which should use the keops implementation.
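For context, a minimal sketch of the setup (assuming the standard ExactGP regression example from the docs; the data and training loop are illustrative):

import math
import torch
import gpytorch

train_x = torch.linspace(0, 1, 1500)  # above max_cholesky_size, so KeOps is used
train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.04

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        # The only change from the docs example: swap in the KeOps periodic kernel
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.keops.PeriodicKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)
model.train()
likelihood.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
for _ in range(50):
    optimizer.zero_grad()
    loss = -mll(model(train_x), train_y)
    loss.backward()
    optimizer.step()

Training with this setup gives: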

before training
lengthscale: tensor([[0.6931]], grad_fn=<SoftplusBackward0>)
raw_lengthscale grad: None
period length: tensor([[0.6931]], grad_fn=<SoftplusBackward0>)
raw_period length: None
after training
lengthscale: tensor([[0.6931]], grad_fn=<SoftplusBackward0>)
raw_lengthscale grad: None
period length: tensor([[1.5373]], grad_fn=<SoftplusBackward0>)
raw_period length: tensor([[-0.0059]])

Interestingly, lengthscale does not change even for train_x = torch.linspace(0, 1, 200) (below max_cholesky_size):

before training
lengthscale: tensor([[0.6931]], grad_fn=<SoftplusBackward0>)
raw_lengthscale grad: None
period length: tensor([[0.6931]], grad_fn=<SoftplusBackward0>)
raw_period length: None
after training
lengthscale: tensor([[0.6931]], grad_fn=<SoftplusBackward0>)
raw_lengthscale grad: None
period length: tensor([[1.5973]], grad_fn=<SoftplusBackward0>)
raw_period length: tensor([[-0.0017]])

I was expecting this to work, since _nonkeops_covar_func is used in that case.

Switching to the gpytorch periodic kernel gpytorch.kernels.PeriodicKernel() gives:

before training
lengthscale: tensor([[0.6931]], grad_fn=<SoftplusBackward0>)
raw_lengthscale grad: None
period length: tensor([[0.6931]], grad_fn=<SoftplusBackward0>)
raw_period length: None
after training
lengthscale: tensor([[1.0248]], grad_fn=<SoftplusBackward0>)
raw_lengthscale grad: tensor([[0.0019]])
period length: tensor([[1.6203]], grad_fn=<SoftplusBackward0>)
raw_period length: tensor([[-0.0118]])

Not sure what is causing the problem, so any help is appreciated.

# then don't apply KeOps
# enable gradients to ensure that test time caches on small predictions are still
# backprop-able
with torch.autograd.enable_grad():
Member

Why is this context manager being used?

Contributor Author

Not sure why it is there; I saw it in the KeOps RBF kernel, so I left it as is in the periodic one. I'm still new to the KeOps library, so I assumed it did something. Commenting it out didn't change the gradients (i.e. the lengthscale gradient was still missing, but the period_length gradient was calculated).

@gpleiss
Member

gpleiss commented Mar 16, 2023

@m-julian at first glance, I'm not sure why the lengthscale isn't getting gradients...
As a sanity check, if you backpropagate through the KeOps PeriodicKernel with diag=True (which should call a non-KeOps function), do you get gradients?

@m-julian
Contributor Author

@gpleiss sorry for the slow reply. I think the code below shows the issue:

import torch
from gpytorch.kernels.keops import PeriodicKernel

torch.manual_seed(7)

x = torch.randn(1000, 2)
y = torch.randn(1000, 2)
k = PeriodicKernel(ard_num_dims=2)
covar = k(x, y, diag=True)
loss = covar.sum()
print("lengthscale grad before backward:", k.raw_lengthscale.grad)
print("period_length grad before backward:", k.raw_period_length.grad)
loss.backward()
print("lengthscale grad after backward:", k.raw_lengthscale.grad)
print("period_length grad after backward:", k.raw_period_length.grad)

which gives:

lengthscale grad before backward: None
period_length grad before backward: None
lengthscale grad after backward: tensor([[57.9246, 63.7257]])
period_length grad after backward: tensor([[ -76.7319, -104.1367]])

Setting diag=False gives:

lengthscale grad before backward: None
period_length grad before backward: None
lengthscale grad after backward: None
period_length grad after backward: tensor([[-2824.9148,  1428.2969]])

Since I am using 1000 points, it does use KeOps when diag=False. So it seems the issue comes from lengthscale being used in the KeOps part of the code: period_length is used before that, so its gradient is calculated. Similarly, for the KeOps RBFKernel the lengthscale is used before the KeOps part, and the gradients are fine.
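To illustrate, a sketch of the forward dispatch (illustrative, not the exact PR code; the KeOps formula matches the one used later in this thread):

import math
from pykeops.torch import LazyTensor as KEOLazyTensor
from linear_operator.operators.keops_linear_operator import KeOpsLinearOperator

def forward(self, x1, x2, diag=False, **params):
    # period_length is applied to the inputs *before* covar_func is built,
    # so it enters the autograd graph through x1_ and x2_ ...
    x1_ = x1.div(self.period_length / math.pi)
    x2_ = x2.div(self.period_length / math.pi)

    def covar_func(x1, x2):
        x1_ = KEOLazyTensor(x1[..., :, None, :])
        x2_ = KEOLazyTensor(x2[..., None, :, :])
        # ... while lengthscale is only used here, inside the closure that
        # KeOpsLinearOperator stores, so it never reaches the graph
        return ((((x1_ - x2_).abs().sin()) ** 2) * (-2.0 / self.lengthscale)).sum(-1).exp()

    return KeOpsLinearOperator(x1_, x2_, covar_func)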

@m-julian
Contributor Author

Seems like the raw_lengthscale parameter is missing from the computation graph when KeOpsLinearOperator is used.

Just as a test, I changed the forward method of the keops periodic kernel to always use _nonkeops_covar_func, like so:

def forward(self, x1, x2, diag=False, **params):

    x1_ = x1.div(self.period_length / math.pi)
    x2_ = x2.div(self.period_length / math.pi)

    covar_func = lambda x1, x2, diag=diag: self._nonkeops_covar_func(x1, x2, diag)

    return covar_func(x1_, x2_)

Then running the code below gives gradients for both raw_period_length and raw_lengthscale:

import torch
import gpytorch
from gpytorch.kernels.keops import PeriodicKernel as KeOpsPeriodicKernel

torch.manual_seed(7)
M, N, D = 1000, 900, 3
x = torch.randn(M, D).double()
y = torch.randn(N, D).double()

k = KeOpsPeriodicKernel(ard_num_dims=3).double()
k.lengthscale = torch.tensor(1.0).double()
k.period_length = torch.tensor(1.0).double()
with gpytorch.settings.lazily_evaluate_kernels(False):
    covar = k(x, y)
    print(type(covar))
    res2 = covar.sum(dim=1)
    res2 = res2.sum()
    print(res2)
    g_x = torch.autograd.grad(res2, [k.raw_lengthscale, k.raw_period_length])
    print(g_x)
which gives:

<class 'linear_operator.operators.dense_linear_operator.DenseLinearOperator'>
tensor(90887.1133, dtype=torch.float64, grad_fn=<SumBackward0>)
(tensor([[31787.0246, 31835.8736, 31787.8815]], dtype=torch.float64), tensor([[ 536.2740, -637.6759, -213.9906]], dtype=torch.float64))

When returning a KeOpsLinearOperator instead

def forward(self, x1, x2, diag=False, **params):

    x1_ = x1.div(self.period_length / math.pi)
    x2_ = x2.div(self.period_length / math.pi)

    covar_func = lambda x1, x2, diag=diag: self._nonkeops_covar_func(x1, x2, diag)

    return KeOpsLinearOperator(x1_, x2_, covar_func)

gives the following output, where k.raw_lengthscale causes the RuntimeError:

<class 'linear_operator.operators.keops_linear_operator.KeOpsLinearOperator'>
tensor(90887.1133, dtype=torch.float64, grad_fn=<SumBackward0>)
Traceback (most recent call last):
  File "/home/julian/Desktop/test/keops_periodic_low_level/keops_implementation.py", line 41, in <module>
    g_x = torch.autograd.grad(res2, [k.raw_lengthscale, k.raw_period_length])
  File "/home/julian/.venv/ichor/lib/python3.10/site-packages/torch/autograd/__init__.py", line 300, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

Just to confirm that the gradient calculations in KeOps work, I implemented the kernel outside of gpytorch, like so:

import math
import torch
from pykeops.torch import LazyTensor

torch.manual_seed(7)
M, N, D = 1000, 900, 3
x = torch.randn(M, D).double()
y = torch.randn(N, D).double()
period_length = torch.nn.Parameter(torch.ones(1, 1, D).double())
lengthscale = torch.nn.Parameter(torch.ones(1, 1, D).double())

x_i = LazyTensor(x.view(M, 1, D))
y_j = LazyTensor(y.view(1, N, D))
x_i = x_i / (period_length / math.pi)
y_j = y_j / (period_length / math.pi)

K = ((((x_i - y_j).abs().sin()) ** 2) * (-2.0 / lengthscale)).sum(-1).exp()
res1 = K.sum(dim=1)
res1 = res1.sum()
g_x = torch.autograd.grad(res1, [lengthscale, period_length])
print(res1)
print(g_x)
which gives:

tensor(90887.1133, dtype=torch.float64, grad_fn=<SumBackward0>)
(tensor([[[50286.3324, 50363.6105, 50287.6881]]], dtype=torch.float64), tensor([[[  848.3730, -1008.7884,  -338.5282]]], dtype=torch.float64))

Gradients are calculated here (the values differ because they are not for the raw parameters that gpytorch uses), so I don't think it is an issue with KeOps.

@gpleiss
Member

gpleiss commented May 5, 2023

@m-julian if you install LinearOperator locally and use the linops_keops branch, you can use KernelLinearOperator instead of KeOpsLinearOperator. (Docs are here: https://linear-operator--62.org.readthedocs.build/en/62/data_sparse_operators.html#kernellinearoperator)

Unlike with KeOpsLinearOperator, covar_func should NOT close over any learnable parameters. (With our linear operator autograd hacks, closed-over variables never make it into the computation graph, which is why you didn't see gradients for the lengthscale parameter.)

I'm hoping this branch will get merged in soon (and we'll put out a new LinearOperator release), so the local installation is only a temporary measure.
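For illustration, a minimal sketch of the suggested pattern (assuming the KernelLinearOperator API from that branch, where keyword-argument hyperparameters are forwarded to covar_func; the names here are illustrative):

import math
from pykeops.torch import LazyTensor as KEOLazyTensor
from linear_operator.operators import KernelLinearOperator

def covar_func(x1, x2, lengthscale):
    # lengthscale arrives as an explicit argument, not a closed-over
    # attribute, so it stays in the computation graph
    x1_ = KEOLazyTensor(x1[..., :, None, :])
    x2_ = KEOLazyTensor(x2[..., None, :, :])
    return ((((x1_ - x2_).abs().sin()) ** 2) * (-2.0 / lengthscale)).sum(-1).exp()

def forward(self, x1, x2, **params):
    x1_ = x1.div(self.period_length / math.pi)
    x2_ = x2.div(self.period_length / math.pi)
    # hyperparameters are passed as keyword arguments to KernelLinearOperator,
    # which forwards them to covar_func
    return KernelLinearOperator(x1_, x2_, covar_func=covar_func, lengthscale=self.lengthscale)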

@m-julian
Contributor Author

Thanks! I have updated the keops kernels to use KernelLinearOperator, so they will work once the main branch gets these changes. I have left the kernel parameters that were not previously closed over outside of KernelLinearOperator (e.g. the lengthscale for RBF, since the inputs are divided by it before covar_func). Gradients for both parameters in the periodic kernel are calculated now.

Currently the gpytorch diagonal tests are failing for the keops kernels at the assert here. The shape of diag_mat is (..., N, 1) instead of (..., N, 1, 1).

@gpleiss
Member

gpleiss commented May 12, 2023

Taking a look at your code, I think you should only construct the KernelLinearOperator if diag=False. I would move the if diag logic that is currently in covar_func into the forward method.

(The diag mode in the Kernel forward functions is a bit clunky right now - I'm thinking we can eventually simplify things with KernelLinearOperator.)
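A sketch of that suggestion, reusing the covar_func from the sketch above (illustrative, not the merged code):

def forward(self, x1, x2, diag=False, **params):
    x1_ = x1.div(self.period_length / math.pi)
    x2_ = x2.div(self.period_length / math.pi)
    if diag:
        # small elementwise computation: bypass KeOps/KernelLinearOperator entirely
        return self._nonkeops_covar_func(x1_, x2_, diag=True)
    return KernelLinearOperator(x1_, x2_, covar_func=covar_func, lengthscale=self.lengthscale)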

@m-julian
Contributor Author

I moved the diag logic into forward and the diagonals get computed fine now. I've also added some unit tests for the gradients (sketched below) and removed the with torch.autograd.enable_grad() context manager, since it doesn't seem to change anything.
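A hypothetical sketch of such a gradient test (the names, sizes, and structure are illustrative, not the PR's actual tests):

import unittest
import torch
import gpytorch
from gpytorch.kernels.keops import PeriodicKernel

class TestKeOpsPeriodicGradients(unittest.TestCase):
    def test_gradients(self):
        kern = PeriodicKernel()
        x = torch.randn(20, 2)
        # force the KeOps branch even for a small input
        with gpytorch.settings.max_cholesky_size(0):
            kern(x, x).to_dense().sum().backward()
        self.assertIsNotNone(kern.raw_lengthscale.grad)
        self.assertIsNotNone(kern.raw_period_length.grad)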

I think this PR should be ready to merge once linear_operator is updated.

@gpleiss gpleiss force-pushed the pykeops_periodic branch from 8a20024 to 638e3b8 on June 2, 2023 19:28
@gpleiss gpleiss force-pushed the pykeops_periodic branch from 638e3b8 to a3b6040 on June 2, 2023 19:48
@gpleiss gpleiss marked this pull request as ready for review June 2, 2023 20:36
@gpleiss gpleiss enabled auto-merge (squash) June 2, 2023 20:37
@gpleiss gpleiss merged commit deb90c2 into cornellius-gp:master Jun 2, 2023
@gpleiss
Member

gpleiss commented Jun 2, 2023

Finally got it in! Thanks @m-julian !

gpleiss added a commit that referenced this pull request Sep 21, 2023
KernelLinearOperator was throwing errors when computing the diagonal of a KeOps kernel.
(This computation happens during preconditioning, which requires the
diagonal of the already-formed kernel LinearOperator object.)
This error was because KeopsLinearOperator.diagonal calls to_dense on
the output of a batch kernel operation. However, to_dense is not defined
for KeOps LazyTensors.

This PR is in some sense a hack fix to this bug (a less hacky fix would
require changes to KernelLinearOperator), but it is also a generally
nice and helpful refactor that will improve KeOps kernels in general.

The fixes:
- KeOpsKernels now only define a forward function, which is used both
when we want to use KeOps and when we want to bypass it.
- KeOpsKernels now use a `_lazify_inputs` helper method, which either
wraps the inputs as KeOps LazyTensors or leaves them as plain torch
Tensors (see the sketch after this message).
- The KeOps wrapping happens unless we want to bypass KeOps, which
occurs either (1) when the matrix is small (below the Cholesky size) or
(2) when the user has turned off the `gpytorch.settings.use_keops`
option (*NEW IN THIS PR*).

Why this is beneficial:
- KeOps kernels now follow the same API as non-KeOps kernels (define a
forward method)
- The user now only has to define one forward method that works in both
the keops and non-keops cases.
- The `diagonal` call in KeopsLinearOperator constructs a batch 1x1
matrix, which is small enough to bypass KeOps and thus avoid the current
bug. (Hence this solution is currently a hack, but it could become less
hacky with a small modification to KernelLinearOperator and/or the
to_dense method in LinearOperator.)

Other changes:
- Fix stability issues with the keops MaternKernel (there were some NaN
issues).
- Introduce a `gpytorch.settings.use_keops` feature flag.
- Clean up the KeOps notebook.

[Fixes #2363]
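A hypothetical sketch of the `_lazify_inputs` idea described above (`_bypass_keops` is an illustrative name for the dispatch check, not the actual helper):

from pykeops.torch import LazyTensor as KEOLazyTensor

def _lazify_inputs(self, x1, x2):
    if self._bypass_keops(x1, x2):  # small matrix, or use_keops turned off
        # leave the inputs as plain torch Tensors; forward runs without KeOps
        return x1, x2
    # wrap as KeOps LazyTensors so the same forward builds a KeOps formula
    return (
        KEOLazyTensor(x1[..., :, None, :]),
        KEOLazyTensor(x2[..., None, :, :]),
    )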
gpleiss added a commit that referenced this pull request Sep 21, 2023
gpleiss added a commit that referenced this pull request Sep 21, 2023
gpleiss added a commit that referenced this pull request Nov 13, 2023