Skip to content

Commit 0b71aa0

Browse files
committed
Likelihood passes in args/kwargs to expected_log_prob
1 parent 75779a9 commit 0b71aa0

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

gpytorch/likelihoods/likelihood.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def expected_log_prob(
4545
self, observations: Tensor, function_dist: MultivariateNormal, *args: Any, **kwargs: Any
4646
) -> Tensor:
4747
likelihood_samples = self._draw_likelihood_samples(function_dist, *args, **kwargs)
48-
res = likelihood_samples.log_prob(observations).mean(dim=0)
48+
res = likelihood_samples.log_prob(observations, *args, **kwargs).mean(dim=0)
4949
return res
5050

5151
@abstractmethod
@@ -410,7 +410,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
410410
def expected_log_prob(
411411
self, observations: Tensor, function_dist: MultivariateNormal, *args: Any, **kwargs: Any
412412
) -> Tensor:
413-
log_prob_lambda = lambda function_samples: self.forward(function_samples).log_prob(observations)
413+
log_prob_lambda = lambda function_samples: self.forward(function_samples, *args, **kwargs).log_prob(
414+
observations
415+
)
414416
log_prob = self.quadrature(log_prob_lambda, function_dist)
415417
return log_prob
416418

0 commit comments

Comments
 (0)