Add `device` property to Kernels, add unit tests #2234
Conversation
This is kinda useful at times.
gpytorch/kernels/kernel.py
Outdated
```python
if self.has_lengthscale:
    return self.lengthscale.device
else:
    for param in self.parameters():
        return param.device
```
- What if `self.lengthscale` and the parameters are on different devices? Can they be? Can other tensor attributes of a `Kernel` be on different devices?
- How about `self.active_dims` (when not `None`), since all `Kernel`s should have that?
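(For concreteness, nothing in PyTorch itself prevents this; a minimal, hypothetical illustration, not code from this PR:)

```python
import torch

# Nothing in torch.nn.Module forces all parameters onto one device:
m = torch.nn.Module()
m.a = torch.nn.Parameter(torch.zeros(1))  # stays on CPU
m.b = torch.nn.Parameter(torch.zeros(1))
if torch.cuda.is_available():
    m.b.data = m.b.data.cuda()  # move only one of the two parameters
print({p.device for p in m.parameters()})  # may contain more than one device
```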
Good point - this is modeled after the code in the `dtype` property below. I guess the standing assumption here has been that all the parameters (of which `lengthscale` is one) are on the same device / have the same dtype. I can simplify this and raise an error if this assumption is violated.
I retained the behavior when `has_lengthscale = True` for now so as not to cause any unforeseen BC issues. But I'm now raising an error if the devices / dtypes are heterogeneous across the parameters. A sketch of what that looks like is below.
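(A minimal sketch of the revised property; the toy class, error message, and empty-parameter fallback are illustrative assumptions, not the actual gpytorch code:)

```python
import torch


class ToyKernel(torch.nn.Module):
    """Stand-in for gpytorch's Kernel, just to illustrate the check."""

    has_lengthscale = False  # assumed attribute, mirroring the snippet above

    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(2))

    @property
    def device(self):
        if self.has_lengthscale:
            # Retained behavior for backwards compatibility.
            return self.lengthscale.device
        devices = {p.device for p in self.parameters()}
        if len(devices) > 1:
            # New: fail loudly instead of silently returning "a" device.
            raise RuntimeError(f"Kernel parameters are on multiple devices: {devices}")
        return next(iter(devices), None)


print(ToyKernel().device)  # device(type='cpu')
```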
What's the use case for this? It seems potentially misleading without checks that all of the kernel's attribute tensors are on the same device. Maybe that's why there's no `device` property on the base `Module`?
The use case is basically if we need to conveniently get "the" device/dtype of a kernel object (e.g. for creating new tensors in computations in other places).
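For example (a hedged usage sketch assuming the `device` / `dtype` properties from this PR; `RBFKernel` is just an example kernel with a lengthscale):

```python
import torch
from gpytorch.kernels import RBFKernel

kernel = RBFKernel()
# Allocate a tensor "where the kernel lives", e.g. for an intermediate
# buffer in a downstream computation:
buf = torch.zeros(10, 1, device=kernel.device, dtype=kernel.dtype)
```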
That is true, I can raise a clear error if this is called on something that has heterogeneous devices / dtypes.
So since modules can nest other modules, it's true that generally we can't be sure (nor should we assume) that everything lives on the same device or has the same dtype. I think so far this has largely been the implicit assumption though, so let me raise the errors as mentioned above. I think the discussion as to whether this should live on the general module level is one we can have separately from this specific PR.
Summary:

## Motivation

As of cornellius-gp/gpytorch#2234, the parent class of BoTorch kernels now has a property `device`. This means that if a subclass tries to set `self.device`, it will error. This is why the BoTorch CI is currently breaking: https://github.com/pytorch/botorch/actions/runs/3841992968/jobs/6542850176

Pull Request resolved: #1611

Test Plan: Tests should pass

Reviewed By: saitcakmak, Balandat
Differential Revision: D42354199
Pulled By: esantorella
fbshipit-source-id: c53e5b508dd75f4116870cd30ab90d11cd3eb573
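The underlying failure is plain Python, independent of BoTorch: assigning to an instance attribute whose name is claimed by a setter-less class property raises an `AttributeError`. A minimal reproduction with hypothetical class names:

```python
class Parent:
    @property
    def device(self):  # read-only property, no setter defined
        return "cpu"


class Child(Parent):
    def __init__(self):
        # This mirrors what the affected BoTorch kernels effectively did:
        self.device = "cpu"  # AttributeError: can't set attribute


Child()
```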
This is kinda useful. At times. Maybe.