diff --git a/gpytorch/priors/torch_priors.py b/gpytorch/priors/torch_priors.py index 9df86d8d0..b497dc4e2 100644 --- a/gpytorch/priors/torch_priors.py +++ b/gpytorch/priors/torch_priors.py @@ -71,7 +71,7 @@ def __init__(self, scale, validate_args=None, transform=None): self._transform = transform def expand(self, batch_shape): - return HalfCauchy(self.loc.expand(batch_shape), self.scale.expand(batch_shape)) + return HalfCauchyPrior(self.scale.expand(batch_shape)) class GammaPrior(Prior, Gamma):