# 🐛 Bug
By default, if any model parameters use a `constraints.Interval` constraint, fetching these parameters forces a GPU synchronization on every forward pass, due to this line:

gpytorch/gpytorch/constraints/constraints.py, line 118 at commit 2e1ccec

True, this check can be disabled by calling `gpytorch.settings.debug._set_state(False)`, but even with debug enabled it ought to run only once, not on every `transform`.
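For context, the synchronization comes from branching on a comparison between a CUDA tensor and a Python float: the comparison produces a CUDA bool tensor, and evaluating it in an `if` forces a device-to-host copy. A minimal standalone sketch of that pattern (plain PyTorch, not GPyTorch code; the variable names only mirror the traceback below):

```python
import math

import torch

device = torch.device("cuda")
upper_bound = torch.tensor([1.0], device=device)

# `.max()` returns a 0-dim CUDA tensor; comparing it to a Python float
# yields a CUDA bool tensor rather than a Python bool.
max_bound = upper_bound.max()

torch.cuda.set_sync_debug_mode(2)

# Branching on that bool tensor copies it to the host, which is the
# synchronizing operation flagged by the sync debug mode.
if max_bound == math.inf:
    pass
```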
## To reproduce
```python
import gpytorch
import torch

device = torch.device("cuda")
interval = gpytorch.constraints.Interval(0, 1).to(device)
raw = torch.tensor([-0.42], device=device)

# Arguably it's okay for this check to occur on the first transform.
# If you choose that fix, then uncomment this line to test your fix.
# constrained = interval.transform(raw)

torch.cuda.set_sync_debug_mode(2)
constrained = interval.transform(raw)
```
```
Traceback (most recent call last):
  File "bug.py", line 17, in <module>
    constrained = interval.transform(raw)
  File "/opt/conda/envs/py39/lib/python3.9/site-packages/gpytorch/constraints/constraints.py", line 118, in transform
    if max_bound == math.inf or min_bound == -math.inf:
RuntimeError: called a synchronizing CUDA operation
```
## Expected Behavior
This sync should happen at most once, since `lower_bound` and `upper_bound` are constant.
- Easy fix: add a flag on `self` to denote whether the check has occurred, and only perform the check on the first call to `transform` or `untransform` (sketched after this list).
- More elegant fix: perform this check during `__init__` (but this would require some refactoring).
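A rough sketch of what the easy fix could look like, using a cut-down stand-in for `Interval` (the `_bounds_checked` flag and the simplified `transform` body are illustrative, not GPyTorch's actual implementation):

```python
import math

import torch


class Interval(torch.nn.Module):
    """Minimal stand-in for gpytorch.constraints.Interval, showing only the caching idea."""

    def __init__(self, lower_bound, upper_bound):
        super().__init__()
        self.register_buffer("lower_bound", torch.as_tensor(lower_bound, dtype=torch.float))
        self.register_buffer("upper_bound", torch.as_tensor(upper_bound, dtype=torch.float))
        self._bounds_checked = False  # hypothetical flag; the real class has no such attribute

    def _check_bounds_once(self):
        # The comparisons below synchronize on CUDA tensors, so run them at most once.
        if not self._bounds_checked:
            if self.upper_bound.max() == math.inf or self.lower_bound.min() == -math.inf:
                raise RuntimeError("Interval bounds must be finite.")
            self._bounds_checked = True

    def transform(self, tensor):
        self._check_bounds_once()
        # Simplified transform: map an unconstrained tensor into (lower, upper).
        return self.lower_bound + (self.upper_bound - self.lower_bound) * torch.sigmoid(tensor)
```

With a flag like this, only the first call to `transform` pays for the synchronizing comparison; the reproduction script above would then pass as long as the commented-out warm-up call is uncommented.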
## System information

- gpytorch: 1.9.1
- torch: 1.13.1+cu116
- OS: Linux