10
10
AddedDiagLinearOperator ,
11
11
BatchRepeatLinearOperator ,
12
12
ConstantMulLinearOperator ,
13
- DenseLinearOperator ,
14
13
InterpolatedLinearOperator ,
15
14
LinearOperator ,
16
15
LowRankRootAddedDiagLinearOperator ,
@@ -211,8 +210,8 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
211
210
212
211
# now update the root and root inverse
213
212
new_lt = self .lik_train_train_covar .cat_rows (fant_train_covar , fant_fant_covar )
214
- new_root = new_lt .root_decomposition ().root . to_dense ()
215
- new_covar_cache = new_lt .root_inv_decomposition ().root . to_dense ()
213
+ new_root = new_lt .root_decomposition ().root
214
+ new_covar_cache = new_lt .root_inv_decomposition ().root
216
215
217
216
# Expand inputs accordingly if necessary (for fantasies at the same points)
218
217
if full_inputs [0 ].dim () <= full_targets .dim ():
@@ -222,7 +221,7 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
222
221
full_inputs = [fi .expand (fant_batch_shape + fi .shape ) for fi in full_inputs ]
223
222
full_mean = full_mean .expand (fant_batch_shape + full_mean .shape )
224
223
full_covar = BatchRepeatLinearOperator (full_covar , repeat_shape )
225
- new_root = BatchRepeatLinearOperator (DenseLinearOperator ( new_root ) , repeat_shape )
224
+ new_root = BatchRepeatLinearOperator (new_root , repeat_shape )
226
225
# no need to repeat the covar cache, broadcasting will do the right thing
227
226
228
227
if isinstance (full_output , MultitaskMultivariateNormal ):
@@ -238,7 +237,7 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
238
237
inv_root = new_covar_cache ,
239
238
)
240
239
add_to_cache (fant_strat , "mean_cache" , fant_mean_cache )
241
- add_to_cache (fant_strat , "covar_cache" , new_covar_cache )
240
+ add_to_cache (fant_strat , "covar_cache" , new_covar_cache . to_dense () )
242
241
return fant_strat
243
242
244
243
@property
0 commit comments