|
2 | 2 |
|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
5 |
| -import warnings |
6 |
| -from typing import Any, Optional, Tuple, Union |
7 |
| - |
8 |
| -import torch |
9 |
| -from linear_operator import LinearOperator, operators |
10 |
| -from linear_operator.operators.cat_linear_operator import cat as _cat |
11 |
| - |
12 | 5 | from .lazy_evaluated_kernel_tensor import LazyEvaluatedKernelTensor
|
13 |
| -from .lazy_tensor import delazify, deprecated_lazy_tensor |
14 |
| -from .non_lazy_tensor import lazify |
15 |
| - |
16 |
| -# We will dynamically import LazyTensor/NonLazyTensor to trigger DeprecationWarnings |
17 |
| - |
18 |
| - |
19 |
| -_deprecated_lazy_tensors = { |
20 |
| - "AddedDiagLazyTensor": deprecated_lazy_tensor(operators.AddedDiagLinearOperator), |
21 |
| - "BatchRepeatLazyTensor": deprecated_lazy_tensor(operators.BatchRepeatLinearOperator), |
22 |
| - "BlockDiagLazyTensor": deprecated_lazy_tensor(operators.BlockDiagLinearOperator), |
23 |
| - "BlockInterleavedLazyTensor": deprecated_lazy_tensor(operators.BlockInterleavedLinearOperator), |
24 |
| - "BlockLazyTensor": deprecated_lazy_tensor(operators.BlockLinearOperator), |
25 |
| - "CatLazyTensor": deprecated_lazy_tensor(operators.CatLinearOperator), |
26 |
| - "CholLazyTensor": deprecated_lazy_tensor(operators.CholLinearOperator), |
27 |
| - "ConstantMulLazyTensor": deprecated_lazy_tensor(operators.ConstantMulLinearOperator), |
28 |
| - "ConstantDiagLazyTensor": deprecated_lazy_tensor(operators.ConstantDiagLinearOperator), |
29 |
| - "DiagLazyTensor": deprecated_lazy_tensor(operators.DiagLinearOperator), |
30 |
| - "IdentityLazyTensor": deprecated_lazy_tensor(operators.IdentityLinearOperator), |
31 |
| - "InterpolatedLazyTensor": deprecated_lazy_tensor(operators.InterpolatedLinearOperator), |
32 |
| - "KeOpsLazyTensor": deprecated_lazy_tensor(operators.KeOpsLinearOperator), |
33 |
| - "KroneckerProductAddedDiagLazyTensor": deprecated_lazy_tensor(operators.KroneckerProductAddedDiagLinearOperator), |
34 |
| - "KroneckerProductDiagLazyTensor": deprecated_lazy_tensor(operators.KroneckerProductDiagLinearOperator), |
35 |
| - "KroneckerProductLazyTensor": deprecated_lazy_tensor(operators.KroneckerProductLinearOperator), |
36 |
| - "KroneckerProductTriangularLazyTensor": deprecated_lazy_tensor(operators.KroneckerProductTriangularLinearOperator), |
37 |
| - "LowRankRootAddedDiagLazyTensor": deprecated_lazy_tensor(operators.LowRankRootAddedDiagLinearOperator), |
38 |
| - "LowRankRootLazyTensor": deprecated_lazy_tensor(operators.LowRankRootLinearOperator), |
39 |
| - "MatmulLazyTensor": deprecated_lazy_tensor(operators.MatmulLinearOperator), |
40 |
| - "MulLazyTensor": deprecated_lazy_tensor(operators.MulLinearOperator), |
41 |
| - "PsdSumLazyTensor": deprecated_lazy_tensor(operators.PsdSumLinearOperator), |
42 |
| - "RootLazyTensor": deprecated_lazy_tensor(operators.RootLinearOperator), |
43 |
| - "SumBatchLazyTensor": deprecated_lazy_tensor(operators.SumBatchLinearOperator), |
44 |
| - "SumKroneckerLazyTensor": deprecated_lazy_tensor(operators.SumKroneckerLinearOperator), |
45 |
| - "SumLazyTensor": deprecated_lazy_tensor(operators.SumLinearOperator), |
46 |
| - "ToeplitzLazyTensor": deprecated_lazy_tensor(operators.ToeplitzLinearOperator), |
47 |
| - "TriangularLazyTensor": deprecated_lazy_tensor(operators.TriangularLinearOperator), |
48 |
| - "ZeroLazyTensor": deprecated_lazy_tensor(operators.ZeroLinearOperator), |
49 |
| -} |
50 |
| - |
51 |
| - |
52 |
| -def cat( |
53 |
| - inputs: Tuple[Union[LinearOperator, torch.Tensor], ...], dim: int = 0, output_device: Optional[torch.device] = None |
54 |
| -) -> Union[torch.Tensor, LinearOperator]: |
55 |
| - warnings.warn("gpytorch.lazy.cat is deprecated in favor of linear_operator.cat") |
56 |
| - return _cat(inputs, dim=dim, output_device=output_device) |
57 |
| - |
58 | 6 |
|
59 | 7 | __all__ = [
|
60 |
| - "delazify", |
61 |
| - "lazify", |
62 |
| - "cat", |
63 | 8 | "LazyEvaluatedKernelTensor",
|
64 |
| - "LazyTensor", |
65 | 9 | ]
|
66 |
| - |
67 |
| - |
68 |
| -def __getattr__(name: str) -> Any: |
69 |
| - if not name.startswith("_"): |
70 |
| - warnings.warn( |
71 |
| - "GPyTorch will be replacing all LazyTensor functionality with the linear operator package. " |
72 |
| - "Replace all references to gpytorch.lazy.*LazyTensor with linear_operator.operators.*LinearOperator.", |
73 |
| - DeprecationWarning, |
74 |
| - ) |
75 |
| - if name == "LazyTensor": |
76 |
| - from .lazy_tensor import LazyTensor |
77 |
| - |
78 |
| - return deprecated_lazy_tensor(LazyTensor) |
79 |
| - elif name == "NonLazyTensor": |
80 |
| - from .non_lazy_tensor import NonLazyTensor |
81 |
| - |
82 |
| - return deprecated_lazy_tensor(NonLazyTensor) |
83 |
| - elif name in _deprecated_lazy_tensors: |
84 |
| - return _deprecated_lazy_tensors[name] |
85 |
| - raise AttributeError(f"module gpytorch.lazy has no attribute {name}") |
0 commit comments