|
22 | 22 | """Pre-defined datasets."""
|
23 | 23 |
|
24 | 24 | import types
|
25 |
| -from typing import List, Tuple, Union |
| 25 | +from typing import List, Literal, Optional, Tuple, Union |
26 | 26 |
|
27 | 27 | import numpy as np
|
28 | 28 | import numpy.typing as npt
|
29 | 29 | import torch
|
30 | 30 |
|
31 | 31 | import cebra.data as cebra_data
|
| 32 | +import cebra.helper as cebra_helper |
| 33 | +from cebra.data.datatypes import Offset |
32 | 34 |
|
33 | 35 |
|
34 | 36 | class TensorDataset(cebra_data.SingleSessionDataset):
|
@@ -64,26 +66,52 @@ def __init__(self,
|
64 | 66 | neural: Union[torch.Tensor, npt.NDArray],
|
65 | 67 | continuous: Union[torch.Tensor, npt.NDArray] = None,
|
66 | 68 | discrete: Union[torch.Tensor, npt.NDArray] = None,
|
67 |
| - offset: int = 1, |
| 69 | + offset: Offset = Offset(0, 1), |
68 | 70 | device: str = "cpu"):
|
69 | 71 | super().__init__(device=device)
|
70 |
| - self.neural = self._to_tensor(neural, torch.FloatTensor).float() |
71 |
| - self.continuous = self._to_tensor(continuous, torch.FloatTensor) |
72 |
| - self.discrete = self._to_tensor(discrete, torch.LongTensor) |
| 72 | + self.neural = self._to_tensor(neural, check_dtype="float").float() |
| 73 | + self.continuous = self._to_tensor(continuous, check_dtype="float") |
| 74 | + self.discrete = self._to_tensor(discrete, check_dtype="int") |
73 | 75 | if self.continuous is None and self.discrete is None:
|
74 | 76 | raise ValueError(
|
75 | 77 | "You have to pass at least one of the arguments 'continuous' or 'discrete'."
|
76 | 78 | )
|
77 | 79 | self.offset = offset
|
78 | 80 |
|
79 |
| - def _to_tensor(self, array, check_dtype=None): |
| 81 | + def _to_tensor( |
| 82 | + self, |
| 83 | + array: Union[torch.Tensor, npt.NDArray], |
| 84 | + check_dtype: Optional[Literal["int", |
| 85 | + "float"]] = None) -> torch.Tensor: |
| 86 | + """Convert :py:func:`numpy.array` to :py:class:`torch.Tensor` if necessary and check the dtype. |
| 87 | +
|
| 88 | + Args: |
| 89 | + array: Array to check. |
| 90 | + check_dtype: If not `None`, list of dtypes to which the values in `array` |
| 91 | + must belong to. Defaults to None. |
| 92 | +
|
| 93 | + Returns: |
| 94 | + The `array` as a :py:class:`torch.Tensor`. |
| 95 | + """ |
80 | 96 | if array is None:
|
81 | 97 | return None
|
82 | 98 | if isinstance(array, np.ndarray):
|
83 | 99 | array = torch.from_numpy(array)
|
84 | 100 | if check_dtype is not None:
|
85 |
| - if not isinstance(array, check_dtype): |
86 |
| - raise TypeError(f"{type(array)} instead of {check_dtype}.") |
| 101 | + if check_dtype not in ["int", "float"]: |
| 102 | + raise ValueError( |
| 103 | + f"check_dtype must be 'int' or 'float', got {check_dtype}") |
| 104 | + if (check_dtype == "int" and not cebra_helper._is_integer(array) |
| 105 | + ) or (check_dtype == "float" and |
| 106 | + not cebra_helper._is_floating(array)): |
| 107 | + raise TypeError( |
| 108 | + f"Array has type {array.dtype} instead of {check_dtype}.") |
| 109 | + if cebra_helper._is_floating(array): |
| 110 | + array = array.float() |
| 111 | + if cebra_helper._is_integer(array): |
| 112 | + # NOTE(stes): Required for standardizing number format on |
| 113 | + # windows machines. |
| 114 | + array = array.long() |
87 | 115 | return array
|
88 | 116 |
|
89 | 117 | @property
|
|
0 commit comments