[Bug] GPytorch Kernel Partitioning Increasing memory usage #2352

Closed
Felice27 opened this issue May 25, 2023 · 8 comments
Felice27 commented May 25, 2023

🐛 Bug

I am attempting to fit an exact GP regression on a dataset of ~1 million points; my current code works with 1/10th of the full dataset. When following the steps in the example 02 "Simple_MultiGPU_GP_Regression" notebook, I ran into an issue where the attempted memory allocation increased after introducing a kernel partition.

To reproduce

import math
import torch
import gpytorch
from datetime import datetime
import sys
from matplotlib import pyplot as plt
sys.path.append('../')
from LBFGS import FullBatchLBFGS # LBFGS.py file from GitHub
import numpy as np
import pandas as pd
import os

data = np.genfromtxt(filename, delimiter=',', skip_header=1, dtype=None)  # filename: path to a CSV of 13 parameters, 4 of which are relevant to my algorithm
intensity = data[:, 3]
thickness = data[:, 8]
focal_distance = data[:, 2]
max_energy = data[:, 4]
data = np.dstack((intensity, thickness, focal_distance, max_energy)).reshape(-1, 4)  # shape: (num_points, 4)
data = torch.tensor(data)
N = data.shape[0]
# make train/val/test
n_train = int(0.8 * N)
train_x, train_y = data[:n_train, :-1], data[:n_train, -1]
test_x, test_y = data[n_train:, :-1], data[n_train:, -1]

# normalize features
mean = train_x.mean(dim=-2, keepdim=True)
std = train_x.std(dim=-2, keepdim=True) + 1e-6 # prevent dividing by 0
train_x = (train_x - mean) / std
test_x = (test_x - mean) / std

# normalize labels
mean, std = train_y.mean(),train_y.std()
train_y = (train_y - mean) / std
test_y = (test_y - mean) / std

# make contiguous
train_x, train_y = train_x.contiguous(), train_y.contiguous()
test_x, test_y = test_x.contiguous(), test_y.contiguous()

output_device = torch.device('cuda:0')

train_x, train_y = train_x.to(output_device), train_y.to(output_device)
test_x, test_y = test_x.to(output_device), test_y.to(output_device)
n_devices = torch.cuda.device_count()
print('Planning to run on {} GPUs.'.format(n_devices))
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, n_devices):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        base_covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        
        self.covar_module = gpytorch.kernels.MultiDeviceKernel(
            base_covar_module, device_ids=range(n_devices),
            output_device=output_device
        )
    
    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

def train(train_x,
          train_y,
          n_devices,
          output_device,
          checkpoint_size,
          preconditioner_size,
          n_training_iter,
):
    likelihood = gpytorch.likelihoods.GaussianLikelihood().to(output_device)
    model = ExactGPModel(train_x, train_y, likelihood, n_devices).to(output_device)
    model = model.double() # Necessary to match float64 type of dataset
    model.train()
    likelihood.train()
    
    optimizer = FullBatchLBFGS(model.parameters(), lr=0.1)
    # "Loss" for GPs - the marginal log likelihood
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    
    with gpytorch.beta_features.checkpoint_kernel(checkpoint_size), \
         gpytorch.settings.max_preconditioner_size(preconditioner_size):

        def closure():
            optimizer.zero_grad()
            output = model(train_x)
            loss = -mll(output, train_y)
            return loss

        loss = closure()
        loss.backward()

        for i in range(n_training_iter):
            options = {'closure': closure, 'current_loss': loss, 'max_ls': 10}
            loss, _, _, _, _, _, _, fail = optimizer.step(options)
            
            print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
                i + 1, n_training_iter, loss.item(),
                model.covar_module.module.base_kernel.lengthscale.item(),
                model.likelihood.noise.item()
            ))
            
            if fail:
                print('Convergence reached!')
                break
    
    print(f"Finished training on {train_x.size(0)} data points using {n_devices} GPUs.")
    return model, likelihood
import gc

def find_best_gpu_setting(train_x,
                          train_y,
                          n_devices,
                          output_device,
                          preconditioner_size
):
    N = train_x.size(0)

    # Find the optimum partition/checkpoint size by decreasing in powers of 2
    # Start with no partitioning (size = 0)
    settings = [0] + [int(n) for n in np.ceil(N / 2**np.arange(1, np.floor(np.log2(N))))]

    for checkpoint_size in settings:
        print('Number of devices: {} -- Kernel partition size: {}'.format(n_devices, checkpoint_size))
        try:
            # Try a full forward and backward pass with this setting to check memory usage
            _, _ = train(train_x, train_y,
                         n_devices=n_devices, output_device=output_device,
                         checkpoint_size=checkpoint_size,
                         preconditioner_size=preconditioner_size, n_training_iter=1)

            # when successful, break out of for-loop and jump to finally block
            break
        except RuntimeError as e:
            print('RuntimeError: {}'.format(e))
        except AttributeError as e:
            print('AttributeError: {}'.format(e))
        finally:
            # handle CUDA OOM error
            gc.collect()
            torch.cuda.empty_cache()
    return checkpoint_size
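
For completeness, the search above is driven roughly as in the notebook; the preconditioner_size value below is an assumption:

preconditioner_size = 100  # assumed value, as in the multi-GPU example notebook
checkpoint_size = find_best_gpu_setting(train_x, train_y,
                                        n_devices=n_devices,
                                        output_device=output_device,
                                        preconditioner_size=preconditioner_size)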

Error message

Number of devices: 1 -- Kernel partition size: 0
RuntimeError: CUDA out of memory. Tried to allocate 63.64 GiB (GPU 0; 31.74 GiB total capacity; 93.80 MiB already allocated; 30.03 GiB free; 104.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
Number of devices: 1 -- Kernel partition size: 46209
RuntimeError: CUDA out of memory. Tried to allocate 190.91 GiB (GPU 0; 31.74 GiB total capacity; 153.03 MiB already allocated; 29.96 GiB free; 178.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
Number of devices: 1 -- Kernel partition size: 23105
RuntimeError: CUDA out of memory. Tried to allocate 190.91 GiB (GPU 0; 31.74 GiB total capacity; 153.03 MiB already allocated; 29.96 GiB free; 178.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
Number of devices: 1 -- Kernel partition size: 11553
RuntimeError: CUDA out of memory. Tried to allocate 190.91 GiB (GPU 0; 31.74 GiB total capacity; 153.03 MiB already allocated; 29.96 GiB free; 178.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
(output continues as the partition size decreases further; every attempt tries to allocate 190.91 GiB)

Expected Behavior

I expect the attempted memory allocation to decrease by roughly a factor of 2 every time the partition size halves, but instead the attempted allocation increases when a kernel partition is introduced and then stays constant for all nonzero partition sizes. Is there something I'm missing to get the memory usage to decrease properly?
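
For scale, a quick back-of-the-envelope check (assuming n_train ≈ 2 × 46,209, as implied by the first nonzero partition size above) shows that the unpartitioned 63.64 GiB request matches a dense float64 kernel matrix:

n_train = 2 * 46209              # implied by the first nonzero partition size
kernel_bytes = n_train ** 2 * 8  # dense n x n kernel matrix in float64
print(kernel_bytes / 2 ** 30)    # ~63.64 GiB, matching the unpartitioned error above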

System information

  • GPyTorch Version 1.10
  • PyTorch Version 2.0.1
  • Computer OS Linux

Additional context

Question also posted to https://stackoverflow.com/questions/76335780/why-is-gpytorch-kernel-partition-size-not-reducing-cuda-memory-usage

Felice27 added the bug label May 25, 2023

gpleiss commented May 26, 2023

@Felice27 we're going to deprecate multi-GPU / kernel partitioning. You should instead try the KeOps integration (https://docs.gpytorch.ai/en/stable/examples/02_Scalable_Exact_GPs/KeOps_GP_Regression.html).

KeOps essentially does our kernel partitioning (on a single GPU), but it is far better than our own code in the Simple_MultiGPU_GP_Regression notebook.
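
For reference, a minimal sketch of that swap (assuming pykeops is installed; the model matches the one above except that the base kernel comes from gpytorch.kernels.keops, and KeOpsGPModel is just an illustrative name):

import gpytorch

class KeOpsGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        # The KeOps kernel evaluates the covariance lazily in GPU tiles,
        # so the full N x N matrix is never materialized in memory.
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.keops.RBFKernel()
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)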


Felice27 commented May 26, 2023


Thank you for the assistance! I'll start updating my code, but I have two questions:

  1. Is there any way to integrate KeOps with a multi-GPU regression? As part of my work, I can request access to up to 4 NVIDIA Volta V100 32GB GPUs in parallel, which would greatly speed up training on the original million-point dataset.
  2. Would KeOps also work with a vector-output regression (from R^3 to R^3)? Right now I'm predicting only one label of the data to make sure I understand the training process, but I would eventually like to predict all three.


gpleiss commented May 26, 2023

@Felice27 at the moment KeOps isn't compatible with multi-GPU regression. Once cornellius-gp/linear_operator#62 is merged into the LinearOperator repo, though, it shouldn't be too difficult to write a multi-GPU KeOps kernel.

I'm stretched very thin at the moment, so if you'd be up for writing a multi-GPU KeOps kernel and putting up a PR, that'd be great. (Wait until cornellius-gp/linear_operator#62 is merged in, though.)

Felice27 commented

Alright, thank you for all the assistance! Once that PR is merged, I'll look into accomplishing that.

Felice27 commented

Is there any reason the time per iteration increases significantly as the program continues to run? When I run training for 50 iterations, each iteration takes about 25 seconds until roughly the 20th iteration, after which the per-iteration time keeps growing, up to hundreds of seconds. I don't think there's an obvious memory leak, as adding a gc.collect(); torch.cuda.empty_cache() every 10 iterations has no effect on the increase in training time. Is there some optimization to make with KeOps to prevent this?


gpleiss commented Jun 1, 2023

@Felice27 the large-scale GPs use conjugate gradients under the hood (Cholesky won't fit into memory). CG is an iterative algorithm, and the number of iterations required to reach convergence depends on the conditioning of the kernel matrix. Changes to the GP hyperparameters change the conditioning of the kernel matrix, which may cause CG to require more iterations before convergence.
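
If the per-iteration CG cost becomes a problem, it can be bounded explicitly with GPyTorch's settings context managers; a sketch is below (the specific values are assumptions to tune for your problem, and capping iterations or loosening the tolerance trades solve accuracy for time):

with gpytorch.settings.max_cg_iterations(1000), \
     gpytorch.settings.cg_tolerance(0.01), \
     gpytorch.settings.max_preconditioner_size(100):
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()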


gpleiss commented Jun 2, 2023

Closing because checkpointing is now deprecated (as of v1.11)

gpleiss closed this as completed Jun 2, 2023

gpleiss commented Jun 2, 2023

Also @Felice27, the cornellius-gp/linear_operator#62 PR is now merged in (as of v1.11).
