Fix training status of noise model of HeteroskedasticNoise after exceptions #2382

Merged
merged 2 commits into from
Jul 25, 2023

Conversation

fjzzq2002
Contributor

In the current implementation of HeteroskedasticNoise.forward, self.noise_model.train(training) is called only after self.noise_model has returned its output. If self.noise_model() raises an exception, that call is skipped and self.noise_model is left in evaluation mode. This patch fixes the problem by wrapping the call in a try-finally block, so the original mode is restored whether or not the call succeeds.
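The difference between the two control flows can be sketched without GPyTorch. This is a minimal illustration, not the library's code: ToyModule, forward_unsafe, forward_safe, and mode_after_failure are hypothetical names, and ToyModule only mimics the train()/eval() flag of a torch.nn.Module.

```python
class ToyModule:
    """Hypothetical stand-in for a torch.nn.Module: it only tracks a train/eval flag."""

    def __init__(self):
        self.training = True

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)

    def __call__(self, x):
        # Simulate a numerical failure for negative inputs.
        if x < 0:
            raise ValueError("numerical error in noise model")
        return 2 * x


def forward_unsafe(noise_model, x):
    # Pre-patch shape: the mode is restored only if the call succeeds.
    training = noise_model.training
    noise_model.eval()
    output = noise_model(x)
    noise_model.train(training)
    return output


def forward_safe(noise_model, x):
    # Patched shape: `finally` restores the mode on both success and failure.
    training = noise_model.training
    try:
        noise_model.eval()
        return noise_model(x)
    finally:
        noise_model.train(training)


def mode_after_failure(forward):
    # Trigger the failure path and report which mode the module ends up in.
    model = ToyModule()
    try:
        forward(model, -1.0)
    except ValueError:
        pass
    return model.training
```

With these stand-ins, mode_after_failure(forward_unsafe) reports that the module is stuck in evaluation mode, while mode_after_failure(forward_safe) reports that it is back in training mode.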

The following is a typical error example:

import gpytorch
import torch

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

train_x = torch.tensor([[1.0], [2.0]])
train_y = torch.tensor([0.0, 0.0])
test_x = torch.tensor([[3.0]])
likelihood = gpytorch.likelihoods.GaussianLikelihood()
noise_model = ExactGPModel(train_x, train_y, likelihood).to(torch.double)
noise_model(train_x)
final_likelihood = gpytorch.likelihoods.HeteroskedasticNoise(noise_model)
assert noise_model.training and final_likelihood.training

# under a normal lengthscale, our likelihood works as expected
noise_model.covar_module.base_kernel.raw_lengthscale.data[[0]] = 0
print(final_likelihood(test_x).to_dense())

# now assume that, due to an imperfect optimizer, the lengthscale became extremely small
noise_model.covar_module.base_kernel.raw_lengthscale.data[[0]] = -720
assert 0 < noise_model.covar_module.base_kernel.lengthscale < 1e-310

# as a result, we get a numerical error whenever we evaluate noise_model
noise_model.eval()
try:
    print(noise_model(test_x))
except Exception as e:
    print("Error:", e)
noise_model.train()

# now we run final_likelihood, which ends in another error
try:
    print(final_likelihood(test_x).to_dense())
except Exception as e:
    print("Error:", e)

# after the call, noise_model is still in evaluation mode, so the cache is not cleared
assert final_likelihood.training and not noise_model.training

# even if we reset the lengthscale back to normal, it still cannot give the correct likelihood
noise_model.covar_module.base_kernel.raw_lengthscale.data[[0]] = 0
try:
    print(final_likelihood(test_x).to_dense())
except Exception as e:
    print("Error:", e)

# works after calling train() to clear the cache
noise_model.train()
print(final_likelihood(test_x).to_dense())

We also believe this patch resolves pytorch/botorch#1386: we replicated the failure from pytorch/botorch#1386 (comment), and our patch fixed it.

Collaborator

@Balandat Balandat left a comment

Thanks for the PR. This makes sense to me! Though I think the unit test needs updating.



class TestNoiseModels(unittest.TestCase):
    def test_heteroskedasticnoise_error(self):
Collaborator

Looks like this is just testing the NumericallyUnstableModelExample but not the actual code of the noise model?

Contributor Author

@fjzzq2002 fjzzq2002 Jul 24, 2023

We are testing the HeteroskedasticNoise noise model here by wrapping it around NumericallyUnstableModelExample. Or should I change the name of this class?

Collaborator

The issue was that the previous code didn't actually test that the training mode was reset; that seems to be fixed now.
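For illustration, a test of this kind needs two assertions: that the error propagates, and that the mode is restored afterwards. The sketch below uses only the standard library; FailingModel, evaluate_in_eval_mode, TestModeRestored, and run_tests are hypothetical stand-ins, not the test that was merged in this PR.

```python
import unittest


class FailingModel:
    """Hypothetical stand-in for a noise model whose evaluation always raises."""

    def __init__(self):
        self.training = True

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)

    def __call__(self, x):
        raise RuntimeError("simulated numerical failure")


def evaluate_in_eval_mode(model, x):
    # Stand-in for the wrapper under test: switch to eval, always restore.
    training = model.training
    try:
        model.eval()
        return model(x)
    finally:
        model.train(training)


class TestModeRestored(unittest.TestCase):
    def test_mode_restored_after_error(self):
        model = FailingModel()
        with self.assertRaises(RuntimeError):
            evaluate_in_eval_mode(model, 1.0)
        # The key assertion: the failure must not leave the model in eval mode.
        self.assertTrue(model.training)


def run_tests():
    # Run the test case programmatically and return the unittest result object.
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestModeRestored)
    return unittest.TextTestRunner(verbosity=0).run(suite)
```

Without the final assertTrue, the test would pass even if the wrapper left the model in evaluation mode, which is the gap described in the comment above.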



Successfully merging this pull request may close these issues.

[Bug] "RuntimeError: Trying to backward through the graph a second time" error in fit_gpytorch_mll
2 participants