Skip to content

Commit 981edd8

Browse files
authored
Better support for missing labels (#2288)
* Fix prediction with NaN values in training labels
* Missing observation support for multitask and allow MultivariateMultitaskNormal indexing
* Fix error in MultitaskMultivariateNormal indexing on '...'
* Fix indexing with negative values
* Add tests
  - Indexing MultitaskMultivariateNormal
  - Missing data in single-task and multitask models
* Render docs for MultitaskMultivariateNormal indexing and missing observations
* Fix docs warning
* Fix docstring
* Finally fix docstring
* Change missing data handling to option flag
* Revamp missing value implementation
  - Enable via gpytorch.settings
  - Two modes: 'mask' and 'fill'
  - Makes GaussianLikelihoodWithMissingObs obsolete
  - Supports approximate GPs
* Fix Python version incompatibility
* Increase atol on variational tests
* Add ExactMarginalLogLikelihoodWithMissingObs back with deprecation warning
* Add warning if kernel matrix is made dense
* Fix docs
* Add quick path for noop slice indices
* Add test for noop slice indexing
* Fix docs
* Switch to MaskedLinearOperator
* Switch to MaskedLinearOperator from linear-operator 0.5.1
* Disable test_t_matmul_matrix() for LazyEvaluatedKernelTensor
  The test fails because LazyEvaluatedKernelTensor only supports _matmul() with checkpointing, but checkpointing is deprecated.
* Fix merge conflict
1 parent 5e93d2c commit 981edd8

12 files changed

+769
-49
lines changed

docs/source/distributions.rst

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ MultitaskMultivariateNormal
4444

4545
.. autoclass:: MultitaskMultivariateNormal
4646
:members:
47+
:special-members: __getitem__
4748

4849

4950
Delta

docs/source/likelihoods.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ reduce the variance when computing approximate GP objective functions.
3434
:members:
3535

3636
:hidden:`GaussianLikelihoodWithMissingObs`
37-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
37+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
3838

3939
.. autoclass:: GaussianLikelihoodWithMissingObs
4040
:members:

gpytorch/distributions/multitask_multivariate_normal.py

+147-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22

33
import torch
44
from linear_operator import LinearOperator, to_linear_operator
5-
from linear_operator.operators import BlockDiagLinearOperator, BlockInterleavedLinearOperator, CatLinearOperator
5+
from linear_operator.operators import (
6+
BlockDiagLinearOperator,
7+
BlockInterleavedLinearOperator,
8+
CatLinearOperator,
9+
DiagLinearOperator,
10+
)
611

712
from .multivariate_normal import MultivariateNormal
813

@@ -18,7 +23,7 @@ class MultitaskMultivariateNormal(MultivariateNormal):
1823
:param torch.Tensor mean: An `n x t` or batch `b x n x t` matrix of means for the MVN distribution.
1924
:param ~linear_operator.operators.LinearOperator covar: An `... x NT x NT` (batch) matrix.
2025
covariance matrix of MVN distribution.
21-
:param bool validate_args: (default=False) If True, validate `mean` anad `covariance_matrix` arguments.
26+
:param bool validate_args: (default=False) If True, validate `mean` and `covariance_matrix` arguments.
2227
:param bool interleaved: (default=True) If True, covariance matrix is interpreted as block-diagonal w.r.t.
2328
inter-task covariances for each observation. If False, it is interpreted as block-diagonal
2429
w.r.t. inter-observation covariance for each task.
@@ -276,5 +281,145 @@ def variance(self):
276281
return var.view(new_shape).transpose(-1, -2).contiguous()
277282
return var.view(self._output_shape)
278283

284+
def __getitem__(self, idx) -> MultivariateNormal:
285+
"""
286+
Constructs a new MultivariateNormal that represents a random variable
287+
modified by an indexing operation.
288+
289+
The mean and covariance matrix arguments are indexed accordingly.
290+
291+
:param Any idx: Index to apply to the mean. The covariance matrix is indexed accordingly.
292+
:returns: If indices specify a slice for samples and tasks, returns a
293+
MultitaskMultivariateNormal, else returns a MultivariateNormal.
294+
"""
295+
296+
# Normalize index to a tuple
297+
if not isinstance(idx, tuple):
298+
idx = (idx,)
299+
300+
if ... in idx:
301+
# Replace ellipsis '...' with explicit indices
302+
ellipsis_location = idx.index(...)
303+
if ... in idx[ellipsis_location + 1 :]:
304+
raise IndexError("Only one ellipsis '...' is supported!")
305+
prefix = idx[:ellipsis_location]
306+
suffix = idx[ellipsis_location + 1 :]
307+
infix_length = self.mean.dim() - len(prefix) - len(suffix)
308+
if infix_length < 0:
309+
raise IndexError(f"Index {idx} has too many dimensions")
310+
idx = prefix + (slice(None),) * infix_length + suffix
311+
elif len(idx) == self.mean.dim() - 1:
312+
# Normalize indices ignoring the task-index to include it
313+
idx = idx + (slice(None),)
314+
315+
new_mean = self.mean[idx]
316+
317+
# We now create a covariance matrix appropriate for new_mean
318+
if len(idx) <= self.mean.dim() - 2:
319+
# We are only indexing the batch dimensions in this case
320+
return MultitaskMultivariateNormal(
321+
mean=new_mean,
322+
covariance_matrix=self.lazy_covariance_matrix[idx],
323+
interleaved=self._interleaved,
324+
)
325+
elif len(idx) > self.mean.dim():
326+
raise IndexError(f"Index {idx} has too many dimensions")
327+
else:
328+
# We have an index that extends over all dimensions
329+
batch_idx = idx[:-2]
330+
if self._interleaved:
331+
row_idx = idx[-2]
332+
col_idx = idx[-1]
333+
num_rows = self._output_shape[-2]
334+
num_cols = self._output_shape[-1]
335+
else:
336+
row_idx = idx[-1]
337+
col_idx = idx[-2]
338+
num_rows = self._output_shape[-1]
339+
num_cols = self._output_shape[-2]
340+
341+
if isinstance(row_idx, int) and isinstance(col_idx, int):
342+
# Single sample with single task
343+
row_idx = _normalize_index(row_idx, num_rows)
344+
col_idx = _normalize_index(col_idx, num_cols)
345+
new_cov = DiagLinearOperator(
346+
self.lazy_covariance_matrix.diagonal()[batch_idx + (row_idx * num_cols + col_idx,)]
347+
)
348+
return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov)
349+
elif isinstance(row_idx, int) and isinstance(col_idx, slice):
350+
# A block of the covariance matrix
351+
row_idx = _normalize_index(row_idx, num_rows)
352+
col_idx = _normalize_slice(col_idx, num_cols)
353+
new_slice = slice(
354+
col_idx.start + row_idx * num_cols,
355+
col_idx.stop + row_idx * num_cols,
356+
col_idx.step,
357+
)
358+
new_cov = self.lazy_covariance_matrix[batch_idx + (new_slice, new_slice)]
359+
return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov)
360+
elif isinstance(row_idx, slice) and isinstance(col_idx, int):
361+
# A block of the reversely interleaved covariance matrix
362+
row_idx = _normalize_slice(row_idx, num_rows)
363+
col_idx = _normalize_index(col_idx, num_cols)
364+
new_slice = slice(row_idx.start + col_idx, row_idx.stop * num_cols + col_idx, row_idx.step * num_cols)
365+
new_cov = self.lazy_covariance_matrix[batch_idx + (new_slice, new_slice)]
366+
return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov)
367+
elif (
368+
isinstance(row_idx, slice)
369+
and isinstance(col_idx, slice)
370+
and row_idx == col_idx == slice(None, None, None)
371+
):
372+
new_cov = self.lazy_covariance_matrix[batch_idx]
373+
return MultitaskMultivariateNormal(
374+
mean=new_mean,
375+
covariance_matrix=new_cov,
376+
interleaved=self._interleaved,
377+
validate_args=False,
378+
)
379+
elif isinstance(row_idx, slice) or isinstance(col_idx, slice):
380+
# slice x slice or indices x slice or slice x indices
381+
if isinstance(row_idx, slice):
382+
row_idx = torch.arange(num_rows)[row_idx]
383+
if isinstance(col_idx, slice):
384+
col_idx = torch.arange(num_cols)[col_idx]
385+
row_grid, col_grid = torch.meshgrid(row_idx, col_idx, indexing="ij")
386+
indices = (row_grid * num_cols + col_grid).reshape(-1)
387+
new_cov = self.lazy_covariance_matrix[batch_idx + (indices,)][..., indices]
388+
return MultitaskMultivariateNormal(
389+
mean=new_mean, covariance_matrix=new_cov, interleaved=self._interleaved, validate_args=False
390+
)
391+
else:
392+
# row_idx and col_idx have pairs of indices
393+
indices = row_idx * num_cols + col_idx
394+
new_cov = self.lazy_covariance_matrix[batch_idx + (indices,)][..., indices]
395+
return MultivariateNormal(
396+
mean=new_mean,
397+
covariance_matrix=new_cov,
398+
)
399+
279400
def __repr__(self) -> str:
280401
return f"MultitaskMultivariateNormal(mean shape: {self._output_shape})"
402+
403+
404+
def _normalize_index(i: int, dim_size: int) -> int:
405+
if i < 0:
406+
return dim_size + i
407+
else:
408+
return i
409+
410+
411+
def _normalize_slice(s: slice, dim_size: int) -> slice:
412+
start = s.start
413+
if start is None:
414+
start = 0
415+
elif start < 0:
416+
start = dim_size + start
417+
stop = s.stop
418+
if stop is None:
419+
stop = dim_size
420+
elif stop < 0:
421+
stop = dim_size + stop
422+
step = s.step
423+
if step is None:
424+
step = 1
425+
return slice(start, stop, step)

gpytorch/distributions/multivariate_normal.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -343,16 +343,22 @@ def __getitem__(self, idx) -> MultivariateNormal:
343343
344344
The mean and covariance matrix arguments are indexed accordingly.
345345
346-
:param idx: Index to apply.
346+
:param idx: Index to apply to the mean. The covariance matrix is indexed accordingly.
347347
"""
348348

349349
if not isinstance(idx, tuple):
350350
idx = (idx,)
351+
if len(idx) > self.mean.dim() and Ellipsis in idx:
352+
idx = tuple(i for i in idx if i != Ellipsis)
353+
if len(idx) < self.mean.dim():
354+
raise IndexError("Multiple ambiguous ellipsis in index!")
355+
351356
rest_idx = idx[:-1]
352357
last_idx = idx[-1]
353358
new_mean = self.mean[idx]
354359

355360
if len(idx) <= self.mean.dim() - 1 and (Ellipsis not in rest_idx):
361+
# We are only indexing the batch dimensions in this case
356362
new_cov = self.lazy_covariance_matrix[idx]
357363
elif len(idx) > self.mean.dim():
358364
raise IndexError(f"Index {idx} has too many dimensions")

gpytorch/likelihoods/gaussian_likelihood.py

+55-8
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
#!/usr/bin/env python3
2-
32
import math
43
import warnings
54
from copy import deepcopy
65
from typing import Any, Optional, Tuple, Union
76

87
import torch
9-
from linear_operator.operators import LinearOperator, ZeroLinearOperator
8+
from linear_operator.operators import LinearOperator, MaskedLinearOperator, ZeroLinearOperator
109
from torch import Tensor
1110
from torch.distributions import Distribution, Normal
1211

12+
from .. import settings
1313
from ..constraints import Interval
1414
from ..distributions import base_distributions, MultivariateNormal
1515
from ..priors import Prior
@@ -39,17 +39,39 @@ def _shaped_noise_covar(self, base_shape: torch.Size, *params: Any, **kwargs: An
3939
return self.noise_covar(*params, shape=base_shape, **kwargs)
4040

4141
    def expected_log_prob(self, target: Tensor, input: MultivariateNormal, *params: Any, **kwargs: Any) -> Tensor:
        """
        Compute E_q[log p(y | f)] for a Gaussian likelihood: the expected log
        probability of ``target`` under the (approximate) posterior ``input``,
        with observation noise taken from ``self._shaped_noise_covar``.

        NaN entries in ``target`` are handled according to
        ``gpytorch.settings.observation_nan_policy``: 'mask' drops the missing
        entries from the distribution and target; 'fill' substitutes a fill
        value and zeroes out the corresponding terms afterwards.

        :param target: Observed values ``y`` (may contain NaNs if a NaN policy is active).
        :param input: Latent function distribution ``q(f)``.
        :return: Per-(batch-)element expected log probability; task dimensions are summed out.
        """

        # Diagonal of the (possibly heteroskedastic) noise covariance.
        noise = self._shaped_noise_covar(input.mean.shape, *params, **kwargs).diagonal(dim1=-1, dim2=-2)
        # Potentially reshape the noise to deal with the multitask case
        noise = noise.view(*noise.shape[:-1], *input.event_shape)

        # Handle NaN values if enabled
        nan_policy = settings.observation_nan_policy.value()
        if nan_policy == "mask":
            # Restrict distribution, noise, and target to the observed entries only.
            observed = settings.observation_nan_policy._get_observed(target, input.event_shape)
            input = MultivariateNormal(
                mean=input.mean[..., observed],
                covariance_matrix=MaskedLinearOperator(
                    input.lazy_covariance_matrix, observed.reshape(-1), observed.reshape(-1)
                ),
            )
            noise = noise[..., observed]
            target = target[..., observed]
        elif nan_policy == "fill":
            # Remember which entries were missing, then replace NaNs with a fill value.
            missing = torch.isnan(target)
            target = settings.observation_nan_policy._fill_tensor(target)

        # Gaussian expected log-likelihood: -0.5 * [((y - mu)^2 + var) / noise + log noise + log 2*pi]
        mean, variance = input.mean, input.variance
        res = ((target - mean).square() + variance) / noise + noise.log() + math.log(2 * math.pi)
        res = res.mul(-0.5)

        if nan_policy == "fill":
            # Zero out contributions from the filled (originally missing) entries.
            res = res * ~missing

        # Do appropriate summation for multitask Gaussian likelihoods
        num_event_dim = len(input.event_shape)
        if num_event_dim > 1:
            res = res.sum(list(range(-1, -num_event_dim, -1)))

        return res
5476

5577
def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> Normal:
    def log_marginal(
        self, observations: Tensor, function_dist: MultivariateNormal, *params: Any, **kwargs: Any
    ) -> Tensor:
        """
        Compute log p(y) of ``observations`` under the marginal distribution
        ``self.marginal(function_dist)``, treating outputs as conditionally
        independent Gaussians.

        NaN entries in ``observations`` are handled according to
        ``gpytorch.settings.observation_nan_policy``: 'mask' drops missing
        entries from the marginal and observations; 'fill' substitutes a fill
        value and zeroes out the corresponding log-prob terms afterwards.

        :param observations: Observed values ``y`` (may contain NaNs if a NaN policy is active).
        :param function_dist: Latent function distribution ``p(f)``.
        :return: Per-(batch-)element log marginal probability; task dimensions are summed out.
        """
        marginal = self.marginal(function_dist, *params, **kwargs)

        # Handle NaN values if enabled
        nan_policy = settings.observation_nan_policy.value()
        if nan_policy == "mask":
            # Restrict the marginal and the observations to the observed entries only.
            observed = settings.observation_nan_policy._get_observed(observations, marginal.event_shape)
            marginal = MultivariateNormal(
                mean=marginal.mean[..., observed],
                covariance_matrix=MaskedLinearOperator(
                    marginal.lazy_covariance_matrix, observed.reshape(-1), observed.reshape(-1)
                ),
            )
            observations = observations[..., observed]
        elif nan_policy == "fill":
            # Remember which entries were missing, then replace NaNs with a fill value.
            missing = torch.isnan(observations)
            observations = settings.observation_nan_policy._fill_tensor(observations)

        # We're making everything conditionally independent
        # (clamp the variance to keep the log-prob numerically stable)
        indep_dist = base_distributions.Normal(marginal.mean, marginal.variance.clamp_min(1e-8).sqrt())
        res = indep_dist.log_prob(observations)

        if nan_policy == "fill":
            # Zero out contributions from the filled (originally missing) entries.
            res = res * ~missing

        # Do appropriate summation for multitask Gaussian likelihoods
        # (use the possibly masked marginal's event shape, not function_dist's)
        num_event_dim = len(marginal.event_shape)
        if num_event_dim > 1:
            res = res.sum(list(range(-1, -num_event_dim, -1)))
        return res
@@ -150,13 +191,15 @@ class GaussianLikelihoodWithMissingObs(GaussianLikelihood):
150191
.. note::
151192
This likelihood can be used for exact or approximate inference.
152193
194+
.. warning::
195+
This likelihood is deprecated in favor of :class:`gpytorch.settings.observation_nan_policy`.
196+
153197
:param noise_prior: Prior for noise parameter :math:`\sigma^2`.
154198
:type noise_prior: ~gpytorch.priors.Prior, optional
155199
:param noise_constraint: Constraint for noise parameter :math:`\sigma^2`.
156200
:type noise_constraint: ~gpytorch.constraints.Interval, optional
157201
:param batch_shape: The batch shape of the learned noise parameter (default: []).
158202
:type batch_shape: torch.Size, optional
159-
160203
:var torch.Tensor noise: :math:`\sigma^2` parameter (noise)
161204
162205
.. note::
@@ -166,6 +209,10 @@ class GaussianLikelihoodWithMissingObs(GaussianLikelihood):
166209
MISSING_VALUE_FILL: float = -999.0
167210

168211
def __init__(self, **kwargs: Any) -> None:
212+
warnings.warn(
213+
"GaussianLikelihoodWithMissingObs is replaced by gpytorch.settings.observation_nan_policy('fill').",
214+
DeprecationWarning,
215+
)
169216
super().__init__(**kwargs)
170217

171218
def _get_masked_obs(self, x: Tensor) -> Tuple[Tensor, Tensor]:

gpytorch/mlls/exact_marginal_log_likelihood.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#!/usr/bin/env python3
22

3+
from linear_operator.operators import MaskedLinearOperator
4+
5+
from .. import settings
36
from ..distributions import MultivariateNormal
47
from ..likelihoods import _GaussianLikelihoodBase
58
from .marginal_log_likelihood import MarginalLogLikelihood
@@ -59,8 +62,23 @@ def forward(self, function_dist, target, *params):
5962
if not isinstance(function_dist, MultivariateNormal):
6063
raise RuntimeError("ExactMarginalLogLikelihood can only operate on Gaussian random variables")
6164

62-
# Get the log prob of the marginal distribution
65+
# Determine output likelihood
6366
output = self.likelihood(function_dist, *params)
67+
68+
# Remove NaN values if enabled
69+
if settings.observation_nan_policy.value() == "mask":
70+
observed = settings.observation_nan_policy._get_observed(target, output.event_shape)
71+
output = MultivariateNormal(
72+
mean=output.mean[..., observed],
73+
covariance_matrix=MaskedLinearOperator(
74+
output.lazy_covariance_matrix, observed.reshape(-1), observed.reshape(-1)
75+
),
76+
)
77+
target = target[..., observed]
78+
elif settings.observation_nan_policy.value() == "fill":
79+
raise ValueError("NaN observation policy 'fill' is not supported by ExactMarginalLogLikelihood!")
80+
81+
# Get the log prob of the marginal distribution
6482
res = output.log_prob(target)
6583
res = self._add_other_terms(res, params)
6684

0 commit comments

Comments
 (0)