Skip to content

Commit 2e7959d

Browse files
Merge pull request #2494 from naefjo/feature/online-learning-improvements
Bug: Exploit Structure in get_fantasy_strategy
2 parents 9551eba + e09674d commit 2e7959d

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

gpytorch/models/exact_prediction_strategies.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
AddedDiagLinearOperator,
1111
BatchRepeatLinearOperator,
1212
ConstantMulLinearOperator,
13-
DenseLinearOperator,
1413
InterpolatedLinearOperator,
1514
LinearOperator,
1615
LowRankRootAddedDiagLinearOperator,
@@ -211,8 +210,8 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
211210

212211
# now update the root and root inverse
213212
new_lt = self.lik_train_train_covar.cat_rows(fant_train_covar, fant_fant_covar)
214-
new_root = new_lt.root_decomposition().root.to_dense()
215-
new_covar_cache = new_lt.root_inv_decomposition().root.to_dense()
213+
new_root = new_lt.root_decomposition().root
214+
new_covar_cache = new_lt.root_inv_decomposition().root
216215

217216
# Expand inputs accordingly if necessary (for fantasies at the same points)
218217
if full_inputs[0].dim() <= full_targets.dim():
@@ -222,7 +221,7 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
222221
full_inputs = [fi.expand(fant_batch_shape + fi.shape) for fi in full_inputs]
223222
full_mean = full_mean.expand(fant_batch_shape + full_mean.shape)
224223
full_covar = BatchRepeatLinearOperator(full_covar, repeat_shape)
225-
new_root = BatchRepeatLinearOperator(DenseLinearOperator(new_root), repeat_shape)
224+
new_root = BatchRepeatLinearOperator(new_root, repeat_shape)
226225
# no need to repeat the covar cache, broadcasting will do the right thing
227226

228227
if isinstance(full_output, MultitaskMultivariateNormal):
@@ -238,7 +237,7 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
238237
inv_root=new_covar_cache,
239238
)
240239
add_to_cache(fant_strat, "mean_cache", fant_mean_cache)
241-
add_to_cache(fant_strat, "covar_cache", new_covar_cache)
240+
add_to_cache(fant_strat, "covar_cache", new_covar_cache.to_dense())
242241
return fant_strat
243242

244243
@property

0 commit comments

Comments
 (0)