
[Bug] PiecewisePolynomialKernel fails to put all tensors on the same GPU device #2199


Closed

c-lyu opened this issue Nov 21, 2022 · 4 comments · Fixed by #2217

@c-lyu

c-lyu commented Nov 21, 2022

🐛 Bug

I was experimenting with the Exact GP with multiple GPUs tutorial here. However, when the base kernel was changed from the RBF kernel to the piecewise polynomial kernel, an error showed up saying that not all tensors are on the same device.

To reproduce

**Code snippet to reproduce**

import torch
import gpytorch
from LBFGS import FullBatchLBFGS

import os
import numpy as np
import urllib.request
from scipy.io import loadmat
dataset = 'protein'
if not os.path.isfile(f'../../datasets/UCI/{dataset}.mat'):
    print(f'Downloading \'{dataset}\' UCI dataset...')
    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1nRb8e7qooozXkNghC5eQS0JeywSXGX2S',
                               f'../../datasets/UCI/{dataset}.mat')

data = torch.Tensor(loadmat(f'../../datasets/UCI/{dataset}.mat')['data'])

n_train = 4000
train_x, train_y = data[:n_train, :-1], data[:n_train, -1]

n_devices = torch.cuda.device_count()
output_device = torch.device('cuda:0')
train_x, train_y = train_x.contiguous().to(output_device), train_y.contiguous().to(output_device)

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, n_devices):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        # change kernel here ----------------------------------------------------|
        base_covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.PiecewisePolynomialKernel())

        self.covar_module = gpytorch.kernels.MultiDeviceKernel(
            base_covar_module, device_ids=range(n_devices),
            output_device=output_device
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

def train(train_x,
          train_y,
          n_devices,
          output_device,
          checkpoint_size,
          preconditioner_size,
):
    likelihood = gpytorch.likelihoods.GaussianLikelihood().to(output_device)
    model = ExactGPModel(train_x, train_y, likelihood, n_devices).to(output_device)
    model.train()
    likelihood.train()

    optimizer = FullBatchLBFGS(model.parameters(), lr=0.1)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    with gpytorch.beta_features.checkpoint_kernel(checkpoint_size), \
         gpytorch.settings.max_preconditioner_size(preconditioner_size):

        def closure():
            optimizer.zero_grad()
            output = model(train_x)
            loss = -mll(output, train_y)
            return loss

        loss = closure()
        loss.backward()

        options = {'closure': closure, 'current_loss': loss, 'max_ls': 10}
        loss, _, _, _, _, _, _, fail = optimizer.step(options)

        print(loss.item())
        
    return model, likelihood

_, _ = train(train_x, train_y,
             n_devices=n_devices, output_device=output_device,
             checkpoint_size=0, preconditioner_size=100)

**Stack trace/error message**

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [4], in <cell line: 1>()
----> 1 _, _ = train(train_x, train_y,
      2              n_devices=n_devices, output_device=output_device,
      3              checkpoint_size=0, preconditioner_size=100)

Input In [3], in train(train_x, train_y, n_devices, output_device, checkpoint_size, preconditioner_size)
     41     return loss
     43 loss = closure()
---> 44 loss.backward()
     46 options = {'closure': closure, 'current_loss': loss, 'max_ls': 10}
     47 loss, _, _, _, _, _, _, fail = optimizer.step(options)

File ~/anaconda3/envs/pyg/lib/python3.8/site-packages/torch/_tensor.py:396, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    387 if has_torch_function_unary(self):
    388     return handle_torch_function(
    389         Tensor.backward,
    390         (self,),
   (...)
    394         create_graph=create_graph,
    395         inputs=inputs)
--> 396 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

File ~/anaconda3/envs/pyg/lib/python3.8/site-packages/torch/autograd/__init__.py:173, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    168     retain_graph = create_graph
    170 # The reason we repeat same the comment below is that
    171 # some Python versions print out the first line of a multi-line function
    172 # calls in the traceback and some print out the last line
--> 173 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    175     allow_unreachable=True, accumulate_grad=True)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cpu!

System information

Please complete the following information:

  • GPyTorch Version: 1.9.0
  • PyTorch Version: 1.12.1
  • Computer OS: Ubuntu 16.04.5 LTS
  • CUDA version: 11.3
  • CUDA devices: two NVIDIA 3090

Additional context

I further experimented with the training size, and a similar issue showed up when n_train = 100 using the RBF kernel. Please see the error message below.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [4], in <cell line: 1>()
----> 1 _, _ = train(train_x, train_y,
      2              n_devices=n_devices, output_device=output_device,
      3              checkpoint_size=0, preconditioner_size=100)

Input In [3], in train(train_x, train_y, n_devices, output_device, checkpoint_size, preconditioner_size)
     40     loss = -mll(output, train_y)
     41     return loss
---> 43 loss = closure()
     44 loss.backward()
     46 options = {'closure': closure, 'current_loss': loss, 'max_ls': 10}

Input In [3], in train.<locals>.closure()
     38 optimizer.zero_grad()
     39 output = model(train_x)
---> 40 loss = -mll(output, train_y)
     41 return loss

File ~/anaconda3/envs/pyg/lib/python3.8/site-packages/gpytorch/module.py:30, in Module.__call__(self, *inputs, **kwargs)
     29 def __call__(self, *inputs, **kwargs):
---> 30     outputs = self.forward(*inputs, **kwargs)
     31     if isinstance(outputs, list):
     32         return [_validate_module_outputs(output) for output in outputs]

File ~/anaconda3/envs/pyg/lib/python3.8/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py:64, in ExactMarginalLogLikelihood.forward(self, function_dist, target, *params)
     62 # Get the log prob of the marginal distribution
     63 output = self.likelihood(function_dist, *params)
---> 64 res = output.log_prob(target)
     65 res = self._add_other_terms(res, params)
     67 # Scale by the amount of data we have

File ~/anaconda3/envs/pyg/lib/python3.8/site-packages/gpytorch/distributions/multivariate_normal.py:169, in MultivariateNormal.log_prob(self, value)
    167 # Get log determininant and first part of quadratic form
    168 covar = covar.evaluate_kernel()
--> 169 inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)
    171 res = -0.5 * sum([inv_quad, logdet, diff.size(-1) * math.log(2 * math.pi)])
    172 return res

File ~/anaconda3/envs/pyg/lib/python3.8/site-packages/linear_operator/operators/_linear_operator.py:1594, in LinearOperator.inv_quad_logdet(self, inv_quad_rhs, logdet, reduce_inv_quad)
   1592             will_need_cholesky = False
   1593     if will_need_cholesky:
-> 1594         cholesky = CholLinearOperator(TriangularLinearOperator(self.cholesky()))
   1595     return cholesky.inv_quad_logdet(
   1596         inv_quad_rhs=inv_quad_rhs,
   1597         logdet=logdet,
   1598         reduce_inv_quad=reduce_inv_quad,
   1599     )
   1601 # Short circuit to inv_quad function if we're not computing logdet

File ~/anaconda3/envs/pyg/lib/python3.8/site-packages/linear_operator/operators/_linear_operator.py:1229, in LinearOperator.cholesky(self, upper)
   1221 @_implements(torch.linalg.cholesky)
   1222 def cholesky(self, upper: bool = False) -> "TriangularLinearOperator":  # noqa F811
   1223     """
   1224     Cholesky-factorizes the LinearOperator.
   1225 
   1226     :param upper: Upper triangular or lower triangular factor (default: False).
   1227     :return: Cholesky factor (lower or upper triangular)
   1228     """
-> 1229     chol = self._cholesky(upper=False)
   1230     if upper:
   1231         chol = chol._transpose_nonbatch()

File ~/anaconda3/envs/pyg/lib/python3.8/site-packages/linear_operator/utils/memoize.py:59, in _cached.<locals>.g(self, *args, **kwargs)
     57 kwargs_pkl = pickle.dumps(kwargs)
     58 if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59     return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
     60 return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)

File ~/anaconda3/envs/pyg/lib/python3.8/site-packages/linear_operator/operators/_linear_operator.py:483, in LinearOperator._cholesky(self, upper)
    480 if any(isinstance(sub_mat, KeOpsLinearOperator) for sub_mat in evaluated_kern_mat._args):
    481     raise RuntimeError("Cannot run Cholesky with KeOps: it will either be really slow or not work.")
--> 483 evaluated_mat = evaluated_kern_mat.to_dense()
    485 # if the tensor is a scalar, we can just take the square root
    486 if evaluated_mat.size(-1) == 1:

File ~/anaconda3/envs/pyg/lib/python3.8/site-packages/linear_operator/utils/memoize.py:59, in _cached.<locals>.g(self, *args, **kwargs)
     57 kwargs_pkl = pickle.dumps(kwargs)
     58 if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59     return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
     60 return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)

File ~/anaconda3/envs/pyg/lib/python3.8/site-packages/linear_operator/operators/sum_linear_operator.py:68, in SumLinearOperator.to_dense(self)
     66 @cached
     67 def to_dense(self):
---> 68     return (sum(linear_op.to_dense() for linear_op in self.linear_ops)).contiguous()

File ~/anaconda3/envs/pyg/lib/python3.8/site-packages/linear_operator/operators/sum_linear_operator.py:68, in <genexpr>(.0)
     66 @cached
     67 def to_dense(self):
---> 68     return (sum(linear_op.to_dense() for linear_op in self.linear_ops)).contiguous()

File ~/anaconda3/envs/pyg/lib/python3.8/site-packages/linear_operator/operators/cat_linear_operator.py:378, in CatLinearOperator.to_dense(self)
    377 def to_dense(self):
--> 378     return torch.cat([to_dense(L) for L in self.linear_ops], dim=self.cat_dim)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument tensors in method wrapper_cat)
@c-lyu c-lyu added the bug label Nov 21, 2022
@gpleiss
Member

gpleiss commented Nov 28, 2022

Taking a glance at PiecewisePolynomialKernel, I'm pretty sure the issue is here:

return torch.max(torch.tensor(0.0), 1 - r).pow(j + q)

We are instantiating a torch.tensor(0.0) without setting it to the same dtype or device as r. I think a fix should be as simple as using torch.tensor(0.0, dtype=r.dtype, device=r.device)

Can you try making that change to the source code and, if it works, submit a bug fix PR?
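
For illustration, here is a minimal sketch of that change, assuming a helper shaped like the line quoted above (the actual function in gpytorch's piecewise_polynomial_kernel.py may differ in name and signature):

import torch

# Hypothetical helper mirroring the quoted line; the name `fmax` and the
# parameters `j` and `q` are assumptions for the sketch.
def fmax(r: torch.Tensor, j: int, q: int) -> torch.Tensor:
    # Original: torch.max(torch.tensor(0.0), 1 - r) creates the zero constant on the
    # CPU with the default dtype, so it can land on a different device than r.
    zero = torch.tensor(0.0, dtype=r.dtype, device=r.device)
    return torch.max(zero, 1 - r).pow(j + q)
    # An equivalent alternative that avoids the explicit constant entirely:
    # return torch.clamp(1 - r, min=0).pow(j + q)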

@c-lyu
Author

c-lyu commented Nov 30, 2022

Thank you for the answer, but unfortunately this fix doesn't work.
I just verified that the issue is unrelated to MultiDeviceKernel and lies only with PiecewisePolynomialKernel, because the same error arises when using only a single GPU (the minimal code example is shown below).

According to the error message, the device error happens during back propagation rather than the forward pass. Strangely, the inputs and outputs of the model, as well as the loss, are all on the correct device, as can be seen in the logs.
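
As a quick sanity check (a sketch, not part of the original report; it assumes the model and output_device defined in the code snippet below), one can also confirm that every parameter and buffer registered on the module lives on the GPU, which points to the stray CPU tensor being created on the fly inside the kernel computation:

# Sanity-check sketch: verify that no parameter or buffer of the model sits on the CPU.
for name, param in model.named_parameters():
    assert param.device == output_device, f"parameter {name} is on {param.device}"
for name, buf in model.named_buffers():
    assert buf.device == output_device, f"buffer {name} is on {buf.device}"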

Code snippet

import torch
import gpytorch

import os
import numpy as np
import urllib.request
from scipy.io import loadmat
dataset = 'protein'
if not os.path.isfile(f'../../datasets/UCI/{dataset}.mat'):
    print(f'Downloading \'{dataset}\' UCI dataset...')
    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1nRb8e7qooozXkNghC5eQS0JeywSXGX2S',
                               f'../../datasets/UCI/{dataset}.mat')

data = torch.Tensor(loadmat(f'../../datasets/UCI/{dataset}.mat')['data'])

n_train = 4000
train_x, train_y = data[:n_train, :-1], data[:n_train, -1]

output_device = torch.device('cuda:0')
train_x, train_y = train_x.contiguous().to(output_device), train_y.contiguous().to(output_device)

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.PiecewisePolynomialKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        print(f"mean_x.device: {mean_x.device} - {mean_x.size()}")
        print(f"covar_x.device: {covar_x.device} - {covar_x.size()}")
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.GaussianLikelihood().to(output_device)
model = ExactGPModel(train_x, train_y, likelihood).to(output_device)
model.train()
likelihood.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
print(f"train device: x: {train_x.device}, y: {train_y.device}")

optimizer.zero_grad()
output = model(train_x)
loss = -mll(output, train_y)
print(f"loss.device: {loss.device}")
loss.backward()
optimizer.step()

Log output

train device: x: cuda:0, y: cuda:0
mean_x.device: cuda:0 - torch.Size([4000])
covar_x.device: cuda:0 - torch.Size([4000, 4000])
loss.device: cuda:0

Error message

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [7], in <cell line: 4>()
      2 output = model(train_x)
      3 loss = -mll(output, train_y)
----> 4 loss.backward()
      5 optimizer.step()
      7 print(loss.item())

File ~/anaconda3/envs/pyg/lib/python3.8/site-packages/torch/_tensor.py:396, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    387 if has_torch_function_unary(self):
    388     return handle_torch_function(
    389         Tensor.backward,
    390         (self,),
   (...)
    394         create_graph=create_graph,
    395         inputs=inputs)
--> 396 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

File ~/anaconda3/envs/pyg/lib/python3.8/site-packages/torch/autograd/__init__.py:173, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    168     retain_graph = create_graph
    170 # The reason we repeat same the comment below is that
    171 # some Python versions print out the first line of a multi-line function
    172 # calls in the traceback and some print out the last line
--> 173 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    175     allow_unreachable=True, accumulate_grad=True)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

@gpleiss
Member

gpleiss commented Dec 6, 2022

Hmm, I can take a look later. Mind renaming the issue to reflect that this isn't a MultiDeviceKernel issue?

gpleiss added a commit that referenced this issue Dec 6, 2022
@c-lyu c-lyu changed the title from "[Bug] MultiDeviceKernel fails to put tensors on the same device" to "[Bug] PiecewisePolynomialKernel fails to put all tensors on the same GPU device" Dec 7, 2022
@c-lyu
Author

c-lyu commented Dec 7, 2022

Sure, I have renamed the issue to refer to PiecewisePolynomialKernel.

gpleiss added a commit that referenced this issue Dec 9, 2022
gpleiss added a commit that referenced this issue Dec 9, 2022