
Commit 48da011

Make PiecewisePolynomialKernel GPU compatible
[Fixes #2199]
1 parent 703dfbd commit 48da011

File tree

1 file changed: +25 −21 lines


gpytorch/kernels/piecewise_polynomial_kernel.py

@@ -1,10 +1,32 @@
+import math
 from typing import Optional
 
 import torch
+from torch import Tensor
 
 from .kernel import Kernel
 
 
+def _fmax(r: Tensor, j: int, q: int) -> Tensor:
+    return torch.max(torch.tensor(0.0, dtype=r.dtype, device=r.device), 1 - r).pow(j + q)
+
+
+def _get_cov(r: Tensor, j: int, q: int) -> Tensor:
+    if q == 0:
+        return 1
+    if q == 1:
+        return (j + 1) * r + 1
+    if q == 2:
+        return 1 + (j + 2) * r + ((j**2 + 4 * j + 3) / 3.0) * (r**2)
+    if q == 3:
+        return (
+            1
+            + (j + 3) * r
+            + ((6 * j**2 + 36 * j + 45) / 15.0) * r.square()
+            + ((j**3 + 9 * j**2 + 23 * j + 15) / 15.0) * (r**3)
+        )
+
+
 class PiecewisePolynomialKernel(Kernel):
     r"""
     Computes a covariance matrix based on the Piecewise Polynomial kernel
@@ -79,32 +101,14 @@ def __init__(self, q: Optional[int] = 2, **kwargs):
             raise ValueError("q expected to be 0, 1, 2 or 3")
         self.q = q
 
-    def fmax(self, r, j, q):
-        return torch.max(torch.tensor(0.0), 1 - r).pow(j + q)
-
-    def get_cov(self, r, j, q):
-        if q == 0:
-            return 1
-        if q == 1:
-            return (j + 1) * r + 1
-        if q == 2:
-            return 1 + (j + 2) * r + ((j**2 + 4 * j + 3) / 3.0) * r**2
-        if q == 3:
-            return (
-                1
-                + (j + 3) * r
-                + ((6 * j**2 + 36 * j + 45) / 15.0) * r**2
-                + ((j**3 + 9 * j**2 + 23 * j + 15) / 15.0) * r**3
-            )
-
-    def forward(self, x1, x2, last_dim_is_batch=False, diag=False, **params):
+    def forward(self, x1: Tensor, x2: Tensor, last_dim_is_batch: bool = False, diag: bool = False, **params) -> Tensor:
         x1_ = x1.div(self.lengthscale)
         x2_ = x2.div(self.lengthscale)
         if last_dim_is_batch is True:
             D = x1.shape[1]
         else:
             D = x1.shape[-1]
-        j = torch.floor(torch.tensor(D / 2.0)) + self.q + 1
+        j = math.floor(D / 2.0) + self.q + 1
         if last_dim_is_batch and diag:
             r = self.covar_dist(x1_, x2_, last_dim_is_batch=True, diag=True)
         elif diag:
@@ -113,5 +117,5 @@ def forward(self, x1, x2, last_dim_is_batch=False, diag=False, **params):
             r = self.covar_dist(x1_, x2_, last_dim_is_batch=True)
         else:
             r = self.covar_dist(x1_, x2_)
-        cov_matrix = self.fmax(r, j, self.q) * self.get_cov(r, j, self.q)
+        cov_matrix = _fmax(r, j, self.q) * _get_cov(r, j, self.q)
         return cov_matrix
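
For context, a minimal usage sketch (not part of the commit) of the code path the change affects. The tensor shapes, the q=2 setting, and the direct call to forward() are illustrative assumptions, not taken from the repository.

import torch
import gpytorch

# Illustrative setup: a small (n, d) input; falls back to CPU if no GPU is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn(10, 3, device=device)

kernel = gpytorch.kernels.PiecewisePolynomialKernel(q=2).to(device)

# forward() is the method touched by this commit: _fmax now builds its zero
# scalar with r's dtype and device, and j is a plain Python int via math.floor,
# so no stray CPU tensor enters the computation.
cov = kernel.forward(x, x)
assert cov.device == x.device

# The removed implementation did roughly:  torch.max(torch.tensor(0.0), 1 - r)
# With r on a CUDA device, that compares a CPU scalar against a CUDA tensor and
# raises a device-mismatch RuntimeError, which appears to be the GPU
# incompatibility this commit addresses.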
