
Likelihood docs update #2292


Merged: 6 commits, Mar 8, 2023
7 changes: 6 additions & 1 deletion .pyre_configuration
@@ -1,7 +1,12 @@
{
"site_package_search_strategy": "pep561",
"source_directories": [
{"import_root": ".", "source": "gpytorch/"}
"gpytorch/"
],
"ignore_all_errors": [
"gpytorch/functions/*.py",
"gpytorch/lazy/*.py",
"gpytorch/test/*.py"
],
"search_path": [
".",
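The new ignore_all_errors block exempts three legacy subpackages from Pyre type checking. A hedged illustration of the intended matching, using Python's fnmatch as a stand-in (Pyre's own glob handling may differ in detail; the file paths are examples):

```python
from fnmatch import fnmatch

patterns = ["gpytorch/functions/*.py", "gpytorch/lazy/*.py", "gpytorch/test/*.py"]

def ignored(path: str) -> bool:
    # A file is exempt if any ignore pattern matches its repo-relative path.
    return any(fnmatch(path, pattern) for pattern in patterns)

print(ignored("gpytorch/lazy/lazy_tensor.py"))        # True  -> errors suppressed
print(ignored("gpytorch/likelihoods/likelihood.py"))  # False -> still checked
```
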
20 changes: 19 additions & 1 deletion docs/source/conf.py
@@ -51,14 +51,28 @@ def find_version(*file_paths):
shutil.rmtree(examples_dest)
os.mkdir(examples_dest)

# Include examples in documentation
# This adds a lot of time to the doc build; to bypass it, set the environment variable SKIP_EXAMPLES=true
for root, dirs, files in os.walk(examples_source):
for dr in dirs:
os.mkdir(os.path.join(root.replace(examples_source, examples_dest), dr))
for fil in files:
if os.path.splitext(fil)[1] in [".ipynb", ".md", ".rst"]:
source_filename = os.path.join(root, fil)
dest_filename = source_filename.replace(examples_source, examples_dest)
shutil.copyfile(source_filename, dest_filename)

# If we're skipping examples, put a dummy file in place
if os.getenv("SKIP_EXAMPLES"):
if dest_filename.endswith("index.rst"):
shutil.copyfile(source_filename, dest_filename)
else:
with open(os.path.splitext(dest_filename)[0] + ".rst", "w") as f:
basename = os.path.splitext(os.path.basename(dest_filename))[0]
f.write(f"{basename}\n" + "=" * 80)

# Otherwise, copy over the real example files
else:
shutil.copyfile(source_filename, dest_filename)
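
A hedged sketch of what the stub branch above produces (the file name is hypothetical): when SKIP_EXAMPLES is set, every non-index example is replaced by an .rst file holding just a title and an underline, so toctree references still resolve without executing any notebooks:

```python
import os

# Hypothetical example notebook; any copied .ipynb/.md file behaves the same.
dest_filename = "examples/01_Exact_GPs/Simple_GP_Regression.ipynb"

basename = os.path.splitext(os.path.basename(dest_filename))[0]
stub = f"{basename}\n" + "=" * 80  # RST title plus section underline
print(stub)
# Simple_GP_Regression
# ======================================================== (80 "=" characters)
```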

# -- Project information -----------------------------------------------------

@@ -282,6 +296,10 @@ def _process(annotation, config):
arg = annotation.__args__[0]
res = "list(" + _process(arg, config) + ")"

# Keep any Dict[*A*, *B*] annotation as its string representation
elif str(annotation).startswith("typing.Dict"):
res = str(annotation)

# Convert any Iterable[*A*] into "iterable(*A*)"
elif str(annotation).startswith("typing.Iterable"):
arg = annotation.__args__[0]
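A hedged note on the new Dict branch in _process: unlike List or Iterable, which the neighboring branches rewrite, a typing.Dict annotation's string form is already readable in rendered docs, so it is passed through unchanged:

```python
from typing import Dict, List

# The elif guard in _process keys off this prefix...
print(str(Dict[str, int]).startswith("typing.Dict"))  # True
# ...and the string representation needs no rewriting:
print(str(Dict[str, int]))  # typing.Dict[str, int]
# List, by contrast, is converted to "list(...)" by an earlier branch.
print(str(List[int]))  # typing.List[int]
```
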
1 change: 1 addition & 0 deletions docs/source/likelihoods.rst
@@ -12,6 +12,7 @@ Likelihood
--------------------

.. autoclass:: Likelihood
:special-members: __call__
:members:


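Documenting __call__ is worthwhile because it is the primary user-facing entry point for likelihoods; a hedged usage sketch (shapes and values arbitrary):

```python
import torch
import gpytorch

likelihood = gpytorch.likelihoods.BernoulliLikelihood()

# Called on a Tensor of function samples, a likelihood returns the
# conditional distribution p(y | f) via forward() ...
conditional = likelihood(torch.randn(3))

# ... and called on a MultivariateNormal, it returns the marginal p(y).
f_dist = gpytorch.distributions.MultivariateNormal(torch.zeros(3), torch.eye(3))
marginal = likelihood(f_dist)
```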

Large diffs are not rendered by default.

28 changes: 23 additions & 5 deletions gpytorch/likelihoods/bernoulli_likelihood.py
@@ -1,10 +1,13 @@
#!/usr/bin/env python3

import warnings
from typing import Any

import torch
from torch import Tensor
from torch.distributions import Bernoulli

from ..distributions import base_distributions
from ..distributions import base_distributions, MultivariateNormal
from ..functions import log_normal_cdf
from .likelihood import _OneDimensionalLikelihood

@@ -21,26 +24,41 @@ class BernoulliLikelihood(_OneDimensionalLikelihood):
p(Y=y|f)=\Phi((2y - 1)f)
\end{equation*}

.. note::
BernoulliLikelihood has an analytic marginal distribution.

.. note::
The labels should take values in {0, 1}.
"""

def forward(self, function_samples, **kwargs):
has_analytic_marginal: bool = True

def __init__(self) -> None:
super().__init__()

def forward(self, function_samples: Tensor, *args: Any, **kwargs: Any) -> Bernoulli:
output_probs = base_distributions.Normal(0, 1).cdf(function_samples)
return base_distributions.Bernoulli(probs=output_probs)

def log_marginal(self, observations, function_dist, *args, **kwargs):
def log_marginal(
self, observations: Tensor, function_dist: MultivariateNormal, *args: Any, **kwargs: Any
) -> Tensor:
marginal = self.marginal(function_dist, *args, **kwargs)
return marginal.log_prob(observations)

def marginal(self, function_dist, **kwargs):
def marginal(self, function_dist: MultivariateNormal, *args: Any, **kwargs: Any) -> Bernoulli:
"""
:return: Analytic marginal :math:`p(\mathbf y)`.
"""
mean = function_dist.mean
var = function_dist.variance
link = mean.div(torch.sqrt(1 + var))
output_probs = base_distributions.Normal(0, 1).cdf(link)
return base_distributions.Bernoulli(probs=output_probs)

def expected_log_prob(self, observations, function_dist, *params, **kwargs):
def expected_log_prob(
self, observations: Tensor, function_dist: MultivariateNormal, *params: Any, **kwargs: Any
) -> Tensor:
if torch.any(observations.eq(-1)):
# Remove after 1.0
warnings.warn(
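The analytic marginal above relies on the standard probit-Gaussian identity — the expectation of Phi(f) under f ~ N(mu, var) equals Phi(mu / sqrt(1 + var)) — which is what justifies has_analytic_marginal = True. A hedged Monte Carlo sanity check (values arbitrary):

```python
import torch
from torch.distributions import Normal

mu, var = torch.tensor(0.3), torch.tensor(2.0)

# Analytic form, matching link = mean / sqrt(1 + var) in marginal().
analytic = Normal(0.0, 1.0).cdf(mu / torch.sqrt(1 + var))

# Monte Carlo estimate: average Phi(f) over draws f ~ N(mu, var).
f = Normal(mu, var.sqrt()).sample((200_000,))
estimate = Normal(0.0, 1.0).cdf(f).mean()

print(analytic.item(), estimate.item())  # agree to roughly 3 decimals
```
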
27 changes: 17 additions & 10 deletions gpytorch/likelihoods/beta_likelihood.py
@@ -1,9 +1,14 @@
#!/usr/bin/env python3

from typing import Any, Optional

import torch
from torch import Tensor
from torch.distributions import Beta

from ..constraints import Positive
from ..constraints import Interval, Positive
from ..distributions import base_distributions
from ..priors import Prior
from .likelihood import _OneDimensionalLikelihood


@@ -27,16 +32,18 @@ class BetaLikelihood(_OneDimensionalLikelihood):
p(y \mid f) = \text{Beta} \left( \sigma(f) s , (1 - \sigma(f)) s\right)

:param batch_shape: The batch shape of the learned noise parameter (default: []).
:type batch_shape: torch.Size, optional
:param scale_prior: Prior for scale parameter :math:`s`.
:type scale_prior: ~gpytorch.priors.Prior, optional
:param scale_constraint: Constraint for scale parameter :math:`s`.
:type scale_constraint: ~gpytorch.constraints.Interval, optional

:var torch.Tensor scale: :math:`s` parameter (scale)
:ivar torch.Tensor scale: :math:`s` parameter (scale)
"""

def __init__(self, batch_shape=torch.Size([]), scale_prior=None, scale_constraint=None):
def __init__(
self,
batch_shape: torch.Size = torch.Size([]),
scale_prior: Optional[Prior] = None,
scale_constraint: Optional[Interval] = None,
) -> None:
super().__init__()

if scale_constraint is None:
@@ -49,19 +56,19 @@ def __init__(self, batch_shape=torch.Size([]), scale_prior=None, scale_constraint=None):
self.register_constraint("raw_scale", scale_constraint)

@property
def scale(self):
def scale(self) -> Tensor:
return self.raw_scale_constraint.transform(self.raw_scale)

@scale.setter
def scale(self, value):
def scale(self, value: Tensor) -> None:
self._set_scale(value)

def _set_scale(self, value):
def _set_scale(self, value: Tensor) -> None:
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_scale)
self.initialize(raw_scale=self.raw_scale_constraint.inverse_transform(value))

def forward(self, function_samples, **kwargs):
def forward(self, function_samples: Tensor, *args: Any, **kwargs: Any) -> Beta:
mixture = torch.sigmoid(function_samples)
scale = self.scale
alpha = mixture * scale + 1
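A hedged usage sketch of the typed constructor and scale setter (values arbitrary; GreaterThan is just one admissible Interval constraint):

```python
import torch
import gpytorch

likelihood = gpytorch.likelihoods.BetaLikelihood(
    batch_shape=torch.Size([]),
    scale_constraint=gpytorch.constraints.GreaterThan(1e-4),
)
likelihood.scale = torch.tensor(15.0)  # exercises the Tensor-typed setter

# Called on function samples, the likelihood returns the conditional Beta
# distribution built in forward() (alpha = sigmoid(f) * scale + 1, per the diff).
beta_dist = likelihood(torch.randn(5))
print(beta_dist.concentration1)
```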