@@ -138,7 +138,7 @@ def __call__(self, x: Tensor, prior: bool = False, **kwargs: Any) -> Multivariat
         if x is not None:
             assert self.inducing_points.shape[:-2] == x.shape[:-2], (
                 f"x batch shape must matches inducing points batch shape, "
-                f"but got train data batch shape = {x.shape[:-2]}, "
+                f"but got x batch shape = {x.shape[:-2]}, "
                 f"inducing points batch shape = {self.inducing_points.shape[:-2]}."
             )
 
@@ -211,6 +211,7 @@ def forward(
         nn_indices = self.nn_util.find_nn_idx(x.float())
 
         x_batch_shape = x.shape[:-2]
+        batch_shape = torch.broadcast_shapes(self._model_batch_shape, x_batch_shape)
         x_bsz = x.shape[-2]
         assert nn_indices.shape == (*x_batch_shape, x_bsz, self.k), nn_indices.shape
 
@@ -221,7 +222,6 @@ def forward(
         assert inducing_points.shape == (*x_batch_shape, x_bsz, self.k, self.D)
 
         # get variational mean and covar for nearest neighbors
-        batch_shape = torch.broadcast_shapes(self._model_batch_shape, x_batch_shape)
         inducing_values = self._variational_distribution.variational_mean
         expanded_inducing_values = inducing_values.unsqueeze(-1).expand(*batch_shape, self.M, self.k)
         expanded_nn_indices = nn_indices.expand(*batch_shape, x_bsz, self.k)
@@ -239,16 +239,20 @@ def forward(
         # Make everything batch mode
         x = x.unsqueeze(-2)
         assert x.shape == (*x_batch_shape, x_bsz, 1, self.D)
+        x = x.expand(*batch_shape, x_bsz, 1, self.D)
 
         # Compute forward mode in the standard way
-        _x_batch_dims = tuple(range(len(x_batch_shape)))
-        _x = x.permute((-3,) + _x_batch_dims + (-2, -1))
-        _inducing_points = inducing_points.permute((-3,) + _x_batch_dims + (-2, -1))
-        _inducing_values = inducing_values.permute((-2,) + _x_batch_dims + (-1,))
-        _variational_inducing_covar = variational_inducing_covar.permute((-3,) + _x_batch_dims + (-2, -1))
+        _batch_dims = tuple(range(len(batch_shape)))
+        _x = x.permute((-3,) + _batch_dims + (-2, -1))  # (x_bsz, *batch_shape, 1, D)
+
+        # inducing_points.shape (*x_batch_shape, x_bsz, self.k, self.D)
+        inducing_points = inducing_points.expand(*batch_shape, x_bsz, self.k, self.D)
+        _inducing_points = inducing_points.permute((-3,) + _batch_dims + (-2, -1))  # (x_bsz, *batch_shape, k, D)
+        _inducing_values = inducing_values.permute((-2,) + _batch_dims + (-1,))
+        _variational_inducing_covar = variational_inducing_covar.permute((-3,) + _batch_dims + (-2, -1))
         dist = super().forward(_x, _inducing_points, _inducing_values, _variational_inducing_covar, **kwargs)
 
-        _x_batch_dims = tuple(range(1, 1 + len(x_batch_shape)))
+        _x_batch_dims = tuple(range(1, 1 + len(batch_shape)))
         predictive_mean = dist.mean  # (x_bsz, *x_batch_shape, 1)
         predictive_covar = dist.covariance_matrix  # (x_bsz, *x_batch_shape, 1, 1)
         predictive_mean = predictive_mean.permute(_x_batch_dims + (0, -1))
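Not part of the commit, just a minimal sketch of the broadcasting behaviour the new `batch_shape` line relies on: `torch.broadcast_shapes` merges the model batch shape with the (possibly un-batched) test-input batch shape, and `expand` then brings `x` and the nearest-neighbor `inducing_points` up to that common shape before the permutes. All shapes below are hypothetical.

import torch

# Hypothetical shapes: a batched model queried with un-batched test inputs.
model_batch_shape = torch.Size([3])   # stands in for self._model_batch_shape
x_batch_shape = torch.Size([])        # x.shape[:-2] for un-batched x
x_bsz, k, D = 5, 4, 2

# Common batch shape, as computed at the top of forward() in this change.
batch_shape = torch.broadcast_shapes(model_batch_shape, x_batch_shape)
print(batch_shape)  # torch.Size([3])

# Expand test points and nearest-neighbor inducing points to that shape,
# so the subsequent permutes see consistent leading batch dimensions.
x = torch.randn(*x_batch_shape, x_bsz, 1, D).expand(*batch_shape, x_bsz, 1, D)
inducing_points = torch.randn(*x_batch_shape, x_bsz, k, D).expand(*batch_shape, x_bsz, k, D)
print(x.shape, inducing_points.shape)  # torch.Size([3, 5, 1, 2]) torch.Size([3, 5, 4, 2])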