1
+ import math
1
2
from typing import Optional
2
3
3
4
import torch
5
+ from torch import Tensor
4
6
5
7
from .kernel import Kernel
6
8
7
9
10
+ def _fmax (r : Tensor , j : int , q : int ) -> Tensor :
11
+ return torch .max (torch .tensor (0.0 , dtype = r .dtype , device = r .device ), 1 - r ).pow (j + q )
12
+
13
+
14
+ def _get_cov (r : Tensor , j : int , q : int ) -> Tensor :
15
+ if q == 0 :
16
+ return 1
17
+ if q == 1 :
18
+ return (j + 1 ) * r + 1
19
+ if q == 2 :
20
+ return 1 + (j + 2 ) * r + ((j + 4 * j + 3 ) / 3.0 ) * (r ** 2 )
21
+ if q == 3 :
22
+ return (
23
+ 1
24
+ + (j + 3 ) * r
25
+ + ((6 * j ** 2 + 36 * j + 45 ) / 15.0 ) * r .square ()
26
+ + ((j ** 3 + 9 * j ** 2 + 23 * j + 15 ) / 15.0 ) * (r ** 3 )
27
+ )
28
+
29
+
8
30
class PiecewisePolynomialKernel (Kernel ):
9
31
r"""
10
32
Computes a covariance matrix based on the Piecewise Polynomial kernel
@@ -79,32 +101,14 @@ def __init__(self, q: Optional[int] = 2, **kwargs):
79
101
raise ValueError ("q expected to be 0, 1, 2 or 3" )
80
102
self .q = q
81
103
82
- def fmax (self , r , j , q ):
83
- return torch .max (torch .tensor (0.0 ), 1 - r ).pow (j + q )
84
-
85
- def get_cov (self , r , j , q ):
86
- if q == 0 :
87
- return 1
88
- if q == 1 :
89
- return (j + 1 ) * r + 1
90
- if q == 2 :
91
- return 1 + (j + 2 ) * r + ((j ** 2 + 4 * j + 3 ) / 3.0 ) * r ** 2
92
- if q == 3 :
93
- return (
94
- 1
95
- + (j + 3 ) * r
96
- + ((6 * j ** 2 + 36 * j + 45 ) / 15.0 ) * r ** 2
97
- + ((j ** 3 + 9 * j ** 2 + 23 * j + 15 ) / 15.0 ) * r ** 3
98
- )
99
-
100
- def forward (self , x1 , x2 , last_dim_is_batch = False , diag = False , ** params ):
104
+ def forward (self , x1 : Tensor , x2 : Tensor , last_dim_is_batch : bool = False , diag : bool = False , ** params ) -> Tensor :
101
105
x1_ = x1 .div (self .lengthscale )
102
106
x2_ = x2 .div (self .lengthscale )
103
107
if last_dim_is_batch is True :
104
108
D = x1 .shape [1 ]
105
109
else :
106
110
D = x1 .shape [- 1 ]
107
- j = torch .floor (torch . tensor ( D / 2.0 ) ) + self .q + 1
111
+ j = math .floor (D / 2.0 ) + self .q + 1
108
112
if last_dim_is_batch and diag :
109
113
r = self .covar_dist (x1_ , x2_ , last_dim_is_batch = True , diag = True )
110
114
elif diag :
@@ -113,5 +117,5 @@ def forward(self, x1, x2, last_dim_is_batch=False, diag=False, **params):
113
117
r = self .covar_dist (x1_ , x2_ , last_dim_is_batch = True )
114
118
else :
115
119
r = self .covar_dist (x1_ , x2_ )
116
- cov_matrix = self . fmax (r , j , self .q ) * self . get_cov (r , j , self .q )
120
+ cov_matrix = _fmax (r , j , self .q ) * _get_cov (r , j , self .q )
117
121
return cov_matrix
0 commit comments