2
2
3
3
import torch
4
4
from linear_operator import LinearOperator , to_linear_operator
5
- from linear_operator .operators import BlockDiagLinearOperator , BlockInterleavedLinearOperator , CatLinearOperator
5
+ from linear_operator .operators import (
6
+ BlockDiagLinearOperator ,
7
+ BlockInterleavedLinearOperator ,
8
+ CatLinearOperator ,
9
+ DiagLinearOperator ,
10
+ )
6
11
7
12
from .multivariate_normal import MultivariateNormal
8
13
@@ -18,7 +23,7 @@ class MultitaskMultivariateNormal(MultivariateNormal):
18
23
:param torch.Tensor mean: An `n x t` or batch `b x n x t` matrix of means for the MVN distribution.
19
24
:param ~linear_operator.operators.LinearOperator covar: An `... x NT x NT` (batch) matrix.
20
25
covariance matrix of MVN distribution.
21
- :param bool validate_args: (default=False) If True, validate `mean` anad `covariance_matrix` arguments.
26
+ :param bool validate_args: (default=False) If True, validate `mean` and `covariance_matrix` arguments.
22
27
:param bool interleaved: (default=True) If True, covariance matrix is interpreted as block-diagonal w.r.t.
23
28
inter-task covariances for each observation. If False, it is interpreted as block-diagonal
24
29
w.r.t. inter-observation covariance for each task.
@@ -276,5 +281,145 @@ def variance(self):
276
281
return var .view (new_shape ).transpose (- 1 , - 2 ).contiguous ()
277
282
return var .view (self ._output_shape )
278
283
284
def __getitem__(self, idx) -> MultivariateNormal:
    """
    Constructs a new MultivariateNormal that represents a random variable
    modified by an indexing operation.

    The mean and covariance matrix arguments are indexed accordingly.

    :param Any idx: Index to apply to the mean. The covariance matrix is indexed accordingly.
    :returns: If indices specify a slice for samples and tasks, returns a
        MultitaskMultivariateNormal, else returns a MultivariateNormal.
    :raises IndexError: If `idx` contains more than one ellipsis or addresses
        more dimensions than the mean has.
    """

    # Normalize index to a tuple
    if not isinstance(idx, tuple):
        idx = (idx,)

    if ... in idx:
        # Replace ellipsis '...' with explicit indices
        ellipsis_location = idx.index(...)
        if ... in idx[ellipsis_location + 1 :]:
            raise IndexError("Only one ellipsis '...' is supported!")
        prefix = idx[:ellipsis_location]
        suffix = idx[ellipsis_location + 1 :]
        infix_length = self.mean.dim() - len(prefix) - len(suffix)
        if infix_length < 0:
            raise IndexError(f"Index {idx} has too many dimensions")
        idx = prefix + (slice(None),) * infix_length + suffix
    elif len(idx) == self.mean.dim() - 1:
        # Normalize indices ignoring the task-index to include it
        idx = idx + (slice(None),)

    new_mean = self.mean[idx]

    # We now create a covariance matrix appropriate for new_mean
    if len(idx) <= self.mean.dim() - 2:
        # We are only indexing the batch dimensions in this case
        return MultitaskMultivariateNormal(
            mean=new_mean,
            covariance_matrix=self.lazy_covariance_matrix[idx],
            interleaved=self._interleaved,
        )
    elif len(idx) > self.mean.dim():
        raise IndexError(f"Index {idx} has too many dimensions")
    else:
        # We have an index that extends over all dimensions
        batch_idx = idx[:-2]

        # Entry (row, col) sits at flat position `row * num_cols + col` along each
        # covariance axis. Which of the last two mean dimensions plays the "row"
        # role depends on the interleaving of the covariance matrix.
        if self._interleaved:
            row_idx = idx[-2]
            col_idx = idx[-1]
            num_rows = self._output_shape[-2]
            num_cols = self._output_shape[-1]
        else:
            row_idx = idx[-1]
            col_idx = idx[-2]
            num_rows = self._output_shape[-1]
            num_cols = self._output_shape[-2]

        if isinstance(row_idx, int) and isinstance(col_idx, int):
            # Single sample with single task
            row_idx = _normalize_index(row_idx, num_rows)
            col_idx = _normalize_index(col_idx, num_cols)
            new_cov = DiagLinearOperator(
                self.lazy_covariance_matrix.diagonal()[batch_idx + (row_idx * num_cols + col_idx,)]
            )
            return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov)
        elif isinstance(row_idx, int) and isinstance(col_idx, slice):
            # A block of the covariance matrix: a contiguous run inside one row
            row_idx = _normalize_index(row_idx, num_rows)
            col_idx = _normalize_slice(col_idx, num_cols)
            new_slice = slice(
                col_idx.start + row_idx * num_cols,
                col_idx.stop + row_idx * num_cols,
                col_idx.step,
            )
            new_cov = self.lazy_covariance_matrix[batch_idx + (new_slice, new_slice)]
            return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov)
        elif isinstance(row_idx, slice) and isinstance(col_idx, int):
            # A block of the reversely interleaved covariance matrix: one fixed
            # column taken from every selected row, i.e. a strided flat slice.
            row_idx = _normalize_slice(row_idx, num_rows)
            col_idx = _normalize_index(col_idx, num_cols)
            # The start must be scaled by num_cols just like stop and step,
            # otherwise any slice that does not begin at row 0 picks wrong entries.
            new_slice = slice(
                row_idx.start * num_cols + col_idx,
                row_idx.stop * num_cols + col_idx,
                row_idx.step * num_cols,
            )
            new_cov = self.lazy_covariance_matrix[batch_idx + (new_slice, new_slice)]
            return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov)
        elif (
            isinstance(row_idx, slice)
            and isinstance(col_idx, slice)
            and row_idx == col_idx == slice(None, None, None)
        ):
            # Both trailing dimensions are kept whole; only batch dims are indexed
            new_cov = self.lazy_covariance_matrix[batch_idx]
            return MultitaskMultivariateNormal(
                mean=new_mean,
                covariance_matrix=new_cov,
                interleaved=self._interleaved,
                validate_args=False,
            )
        elif isinstance(row_idx, slice) or isinstance(col_idx, slice):
            # slice x slice or indices x slice or slice x indices
            # Expand slices to explicit indices, then take the full grid of
            # (row, col) combinations and flatten it to covariance positions.
            if isinstance(row_idx, slice):
                row_idx = torch.arange(num_rows)[row_idx]
            if isinstance(col_idx, slice):
                col_idx = torch.arange(num_cols)[col_idx]
            row_grid, col_grid = torch.meshgrid(row_idx, col_idx, indexing="ij")
            indices = (row_grid * num_cols + col_grid).reshape(-1)
            new_cov = self.lazy_covariance_matrix[batch_idx + (indices,)][..., indices]
            return MultitaskMultivariateNormal(
                mean=new_mean, covariance_matrix=new_cov, interleaved=self._interleaved, validate_args=False
            )
        else:
            # row_idx and col_idx have pairs of indices
            indices = row_idx * num_cols + col_idx
            new_cov = self.lazy_covariance_matrix[batch_idx + (indices,)][..., indices]
            return MultivariateNormal(
                mean=new_mean,
                covariance_matrix=new_cov,
            )
399
+
279
400
def __repr__ (self ) -> str :
280
401
return f"MultitaskMultivariateNormal(mean shape: { self ._output_shape } )"
402
+
403
+
404
+ def _normalize_index (i : int , dim_size : int ) -> int :
405
+ if i < 0 :
406
+ return dim_size + i
407
+ else :
408
+ return i
409
+
410
+
411
+ def _normalize_slice (s : slice , dim_size : int ) -> slice :
412
+ start = s .start
413
+ if start is None :
414
+ start = 0
415
+ elif start < 0 :
416
+ start = dim_size + start
417
+ stop = s .stop
418
+ if stop is None :
419
+ stop = dim_size
420
+ elif stop < 0 :
421
+ stop = dim_size + stop
422
+ step = s .step
423
+ if step is None :
424
+ step = 1
425
+ return slice (start , stop , step )
0 commit comments