Skip to content

Commit 0eac868

Browse files
committed
Merge remote-tracking branch 'origin/main' into batched-inference-and-padding
2 parents e1b7cc7 + e652b9a commit 0eac868

28 files changed

+354
-153
lines changed

cebra/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,11 @@ def __getattr__(key):
9292

9393
return CEBRA
9494
elif key == "KNNDecoder":
95-
from cebra.integrations.sklearn.decoder import KNNDecoder
95+
from cebra.integrations.sklearn.decoder import KNNDecoder # noqa: F811
9696

9797
return KNNDecoder
9898
elif key == "L1LinearRegressor":
99-
from cebra.integrations.sklearn.decoder import L1LinearRegressor
99+
from cebra.integrations.sklearn.decoder import L1LinearRegressor # noqa: F811
100100

101101
return L1LinearRegressor
102102
elif not key.startswith("_"):

cebra/data/datasets.py

+36-8
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@
2222
"""Pre-defined datasets."""
2323

2424
import types
25-
from typing import List, Tuple, Union
25+
from typing import List, Literal, Optional, Tuple, Union
2626

2727
import numpy as np
2828
import numpy.typing as npt
2929
import torch
3030

3131
import cebra.data as cebra_data
32+
import cebra.helper as cebra_helper
33+
from cebra.data.datatypes import Offset
3234

3335

3436
class TensorDataset(cebra_data.SingleSessionDataset):
@@ -64,26 +66,52 @@ def __init__(self,
6466
neural: Union[torch.Tensor, npt.NDArray],
6567
continuous: Union[torch.Tensor, npt.NDArray] = None,
6668
discrete: Union[torch.Tensor, npt.NDArray] = None,
67-
offset: int = 1,
69+
offset: Offset = Offset(0, 1),
6870
device: str = "cpu"):
6971
super().__init__(device=device)
70-
self.neural = self._to_tensor(neural, torch.FloatTensor).float()
71-
self.continuous = self._to_tensor(continuous, torch.FloatTensor)
72-
self.discrete = self._to_tensor(discrete, torch.LongTensor)
72+
self.neural = self._to_tensor(neural, check_dtype="float").float()
73+
self.continuous = self._to_tensor(continuous, check_dtype="float")
74+
self.discrete = self._to_tensor(discrete, check_dtype="int")
7375
if self.continuous is None and self.discrete is None:
7476
raise ValueError(
7577
"You have to pass at least one of the arguments 'continuous' or 'discrete'."
7678
)
7779
self.offset = offset
7880

79-
def _to_tensor(self, array, check_dtype=None):
81+
def _to_tensor(
82+
self,
83+
array: Union[torch.Tensor, npt.NDArray],
84+
check_dtype: Optional[Literal["int",
85+
"float"]] = None) -> torch.Tensor:
86+
"""Convert :py:func:`numpy.array` to :py:class:`torch.Tensor` if necessary and check the dtype.
87+
88+
Args:
89+
array: Array to check.
90+
check_dtype: If not `None`, list of dtypes to which the values in `array`
91+
must belong to. Defaults to None.
92+
93+
Returns:
94+
The `array` as a :py:class:`torch.Tensor`.
95+
"""
8096
if array is None:
8197
return None
8298
if isinstance(array, np.ndarray):
8399
array = torch.from_numpy(array)
84100
if check_dtype is not None:
85-
if not isinstance(array, check_dtype):
86-
raise TypeError(f"{type(array)} instead of {check_dtype}.")
101+
if check_dtype not in ["int", "float"]:
102+
raise ValueError(
103+
f"check_dtype must be 'int' or 'float', got {check_dtype}")
104+
if (check_dtype == "int" and not cebra_helper._is_integer(array)
105+
) or (check_dtype == "float" and
106+
not cebra_helper._is_floating(array)):
107+
raise TypeError(
108+
f"Array has type {array.dtype} instead of {check_dtype}.")
109+
if cebra_helper._is_floating(array):
110+
array = array.float()
111+
if cebra_helper._is_integer(array):
112+
# NOTE(stes): Required for standardizing number format on
113+
# windows machines.
114+
array = array.long()
87115
return array
88116

89117
@property

cebra/data/helper.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,15 @@ class OrthogonalProcrustesAlignment:
9494
9595
For each dataset, the data and labels to align the data on is provided.
9696
97-
1. The ``top_k`` indexes of the labels to align (``label``) that are the closest to the labels of the reference dataset (``ref_label``) are selected and used to sample from the dataset to align (``data``).
98-
2. ``data`` and ``ref_data`` (the reference dataset) are subsampled to the same number of samples ``subsample``.
99-
3. The orthogonal mapping is computed, using :py:func:`scipy.linalg.orthogonal_procrustes`, on those subsampled datasets.
100-
4. The resulting orthongonal matrix ``_transform`` can be used to map the original ``data`` to the ``ref_data``.
97+
1. The ``top_k`` indexes of the labels to align (``label``) that are the closest to
98+
the labels of the reference dataset (``ref_label``) are selected and used to sample
99+
from the dataset to align (``data``).
100+
2. ``data`` and ``ref_data`` (the reference dataset) are subsampled to the same number
101+
of samples ``subsample``.
102+
3. The orthogonal mapping is computed, using :py:func:`scipy.linalg.orthogonal_procrustes`,
103+
on those subsampled datasets.
104+
4. The resulting orthongonal matrix ``_transform`` can be used to map the original ``data``
105+
to the ``ref_data``.
101106
102107
Note:
103108
``data`` and ``ref_data`` can be of different sample size (axis 0) but **must** have the same number

cebra/data/load.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,8 @@ def load(
663663
- if no key is provided, the first data structure found upon iteration of the collection will be loaded;
664664
- if a key is provided, it needs to correspond to an existing item of the collection;
665665
- if a key is provided, the data value accessed needs to be a data structure;
666-
- the function loads data for only one data structure, even if the file contains more. The function can be called again with the corresponding key to get the other ones.
666+
- the function loads data for only one data structure, even if the file contains more. The function can be
667+
called again with the corresponding key to get the other ones.
667668
668669
Args:
669670
file: The path to the given file to load, in a supported format.

cebra/data/single_session.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,6 @@ def __post_init__(self):
358358
# here might be sub-optimal. The final behavior should be determined after
359359
# e.g. integrating the FAISS dataloader back in.
360360
super().__post_init__()
361-
index = self.index.to(self.device)
362361

363362
if self.conditional != "time_delta":
364363
raise NotImplementedError(
@@ -368,8 +367,7 @@ def __post_init__(self):
368367
self.time_distribution = cebra.distributions.TimeContrastive(
369368
time_offset=self.time_offset,
370369
num_samples=len(self.dataset.neural),
371-
device=self.device,
372-
)
370+
device=self.device)
373371
self.behavior_distribution = cebra.distributions.TimedeltaDistribution(
374372
self.dataset.continuous_index, self.time_offset, device=self.device)
375373

cebra/datasets/__init__.py

-2
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,6 @@ def get_datapath(path: str = None) -> str:
9898
from cebra.datasets.monkey_reaching import *
9999
from cebra.datasets.synthetic_data import *
100100
except ModuleNotFoundError as e:
101-
import warnings
102-
103101
warnings.warn(f"Could not initialize one or more datasets: {e}. "
104102
f"For using the datasets, consider installing the "
105103
f"[datasets] extension via pip.")

cebra/datasets/allen/ca_movie.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,15 @@
2222
"""Allen pseudomouse Ca dataset.
2323
2424
References:
25-
*Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339.
26-
*de Vries, Saskia EJ, et al. "A large-scale standardized physiological survey reveals functional organization of the mouse visual cortex." Nature neuroscience 23.1 (2020): 138-151.
27-
*https://github.com/zivlab/visual_drift
28-
*http://observatory.brain-map.org/visualcoding
29-
25+
* Deitch, Daniel, Alon Rubin, and Yaniv Ziv.
26+
"Representational drift in the mouse visual cortex."
27+
Current biology 31.19 (2021): 4327-4339.
28+
* de Vries, Saskia EJ, et al.
29+
"A large-scale standardized physiological survey reveals functional
30+
organization of the mouse visual cortex."
31+
Nature neuroscience 23.1 (2020): 138-151.
32+
* https://github.com/zivlab/visual_drift
33+
* http://observatory.brain-map.org/visualcoding
3034
"""
3135

3236
import pathlib

cebra/datasets/allen/ca_movie_decoding.py

+9-10
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,15 @@
2222
"""Allen pseudomouse Ca decoding dataset with train/test split.
2323
2424
References:
25-
*Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339.
26-
*de Vries, Saskia EJ, et al. "A large-scale standardized physiological survey reveals functional organization of the mouse visual cortex." Nature neuroscience 23.1 (2020): 138-151.
27-
*https://github.com/zivlab/visual_drift
28-
*http://observatory.brain-map.org/visualcoding
29-
25+
* Deitch, Daniel, Alon Rubin, and Yaniv Ziv.
26+
"Representational drift in the mouse visual cortex."
27+
Current biology 31.19 (2021): 4327-4339.
28+
* de Vries, Saskia EJ, et al.
29+
"A large-scale standardized physiological survey reveals functional
30+
organization of the mouse visual cortex."
31+
Nature neuroscience 23.1 (2020): 138-151.
32+
* https://github.com/zivlab/visual_drift
33+
* http://observatory.brain-map.org/visualcoding
3034
"""
3135

3236
import pathlib
@@ -243,11 +247,6 @@ def _convert_to_nums(string):
243247

244248
return pseudo_mouse
245249

246-
pseudo_mouse = np.vstack(
247-
[get_neural_data(num_movie, mice) for mice in list_mice])
248-
249-
return pseudo_mouse
250-
251250
def __len__(self):
252251
return self.neural.size(0)
253252

cebra/datasets/allen/combined.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,19 @@
2222
"""Joint Allen pseudomouse Ca/Neuropixel datasets.
2323
2424
References:
25-
*Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339.
26-
*de Vries, Saskia EJ, et al. "A large-scale standardized physiological survey reveals functional organization of the mouse visual cortex." Nature neuroscience 23.1 (2020): 138-151.
27-
*https://github.com/zivlab/visual_drift
28-
*http://observatory.brain-map.org/visualcoding
29-
*https://allensdk.readthedocs.io/en/latest/visual_coding_neuropixels.html
30-
*Siegle, Joshua H., et al. "Survey of spiking in the mouse visual system reveals functional hierarchy." Nature 592.7852 (2021): 86-92.
31-
25+
* Deitch, Daniel, Alon Rubin, and Yaniv Ziv.
26+
"Representational drift in the mouse visual cortex."
27+
Current Biology 31.19 (2021): 4327-4339.
28+
* de Vries, Saskia EJ, et al.
29+
"A large-scale standardized physiological survey reveals functional
30+
organization of the mouse visual cortex."
31+
Nature Neuroscience 23.1 (2020): 138-151.
32+
* https://github.com/zivlab/visual_drift
33+
* http://observatory.brain-map.org/visualcoding
34+
* https://allensdk.readthedocs.io/en/latest/visual_coding_neuropixels.html
35+
* Siegle, Joshua H., et al.
36+
"Survey of spiking in the mouse visual system reveals functional hierarchy."
37+
Nature 592.7852 (2021): 86-92.
3238
"""
3339

3440
import cebra.data

cebra/datasets/allen/make_neuropixel.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -192,11 +192,12 @@ def read_neuropixel(
192192
"intervals/natural_movie_one_presentations/start_time"][...]
193193
end_time = d[
194194
"intervals/natural_movie_one_presentations/stop_time"][...]
195-
timeseries = d[
196-
"intervals/natural_movie_one_presentations/timeseries"][...]
197-
timeseries_index = d[
198-
"intervals/natural_movie_one_presentations/timeseries_index"][
199-
...]
195+
# NOTE(stes): never used. leaving here for future reference
196+
#timeseries = d[
197+
# "intervals/natural_movie_one_presentations/timeseries"][...]
198+
#timeseries_index = d[
199+
# "intervals/natural_movie_one_presentations/timeseries_index"][
200+
# ...]
200201
session_no = d["identifier"][...].item()
201202
spike_time_index = d["units/spike_times_index"][...]
202203
spike_times = d["units/spike_times"][...]
@@ -266,14 +267,14 @@ def read_neuropixel(
266267
"neural": sessions_dic,
267268
"frames": session_frames
268269
},
269-
Path(args.save_path) /
270+
pathlib.Path(args.save_path) /
270271
f"neuropixel_sessions_{int(args.sampling_rate)}_filtered.jl",
271272
)
272273
jl.dump(
273274
{
274275
"neural": pseudo_mice,
275276
"frames": pseudo_mice_frames
276277
},
277-
Path(args.save_path) /
278+
pathlib.Path(args.save_path) /
278279
f"neuropixel_pseudomouse_{int(args.sampling_rate)}_filtered.jl",
279280
)

cebra/datasets/allen/single_session_ca.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,21 @@
1919
# See the License for the specific language governing permissions and
2020
# limitations under the License.
2121
#
22-
"""Allen single mouse dataset.
22+
"""
23+
Allen single mouse dataset.
2324
2425
References:
25-
*Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339.
26-
*de Vries, Saskia EJ, et al. "A large-scale standardized physiological survey reveals functional organization of the mouse visual cortex." Nature neuroscience 23.1 (2020): 138-151.
27-
*https://github.com/zivlab/visual_drift
28-
*http://observatory.brain-map.org/visualcoding
26+
* Deitch, Daniel, Alon Rubin, and Yaniv Ziv.
27+
"Representational drift in the mouse visual cortex."
28+
Current Biology 31.19 (2021): 4327-4339.
29+
30+
* de Vries, Saskia EJ, et al.
31+
"A large-scale standardized physiological survey reveals functional
32+
organization of the mouse visual cortex."
33+
Nature Neuroscience 23.1 (2020): 138-151.
2934
35+
* https://github.com/zivlab/visual_drift
36+
* http://observatory.brain-map.org/visualcoding
3037
"""
3138
import pathlib
3239

@@ -113,7 +120,7 @@ def __getitem__(self, index):
113120
"allen-movie1-ca-single-session-corrupt-{session_id}",
114121
session_id=range(len(_SINGLE_SESSION_CA)),
115122
)
116-
class SingleSessionAllenCa(cebra.data.SingleSessionDataset):
123+
class SingleSessionAllenCaCorrupted(cebra.data.SingleSessionDataset):
117124
"""A corrupted single mouse 30Hz calcium events dataset during the allen MOVIE1 stimulus.
118125
119126
A dataset of a single mouse 30Hz calcium events from the excitatory neurons in the primary visual cortex
@@ -352,7 +359,7 @@ def __init__(self, repeat_no, split_flag):
352359
repeat_no=[9],
353360
split_flag=["train", "test"],
354361
)
355-
class SingleSessionAllenCaDecoding(cebra.data.SingleSessionDataset):
362+
class SingleSessionAllenCaDecodingCorrupted(cebra.data.SingleSessionDataset):
356363
"""A corrupted single mouse 30Hz calcium events dataset during the allen MOVIE1 stimulus with train/test splits.
357364
358365
A dataset of a single mouse 30Hz calcium events from the excitatory neurons

cebra/datasets/gaussian_mixture.py

+3
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,12 @@
2727

2828
import cebra.data
2929
import cebra.io
30+
from cebra.datasets import get_datapath
3031
from cebra.datasets import parametrize
3132
from cebra.datasets import register
3233

34+
_DEFAULT_DATADIR = get_datapath()
35+
3336

3437
@register("continuous-gaussian-mixture")
3538
@parametrize(

cebra/datasets/generate_synthetic_data.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import joblib as jl
3131
import keras
3232
import numpy as np
33-
import poisson
33+
import poisson as poisson_utils
3434
import scipy.stats
3535
import tensorflow as tf
3636

@@ -228,7 +228,7 @@ def refractory_poisson(x):
228228
flattened_lam = lam_true.flatten()
229229
x = np.zeros_like(flattened_lam)
230230
for i, rate in enumerate(flattened_lam):
231-
neuron = poisson.PoissonNeuron(
231+
neuron = poisson_utils.PoissonNeuron(
232232
spike_rate=rate * args.scale,
233233
num_repeats=1,
234234
time_interval=args.time_interval,

cebra/datasets/hippocampus.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -160,14 +160,16 @@ def decode(self, x_train, y_train, x_test, y_test):
160160
class SingleRatTrialSplitDataset(SingleRatDataset):
161161
"""A single rat hippocampus tetrode recording while the rat navigates on a linear track with 3-fold splits.
162162
163-
Neural data is spike counts binned into 25ms time window and the behavior is position and the running direction (left, right) of a rat.
164-
The behavior label is structured as 3D array consists of position, right, and left.
165-
The neural and behavior recordings are parsed into trials (a round trip from one end of the track) and the trials are split into a train, valid and test set with k=3 nested cross validation.
163+
Neural data is spike counts binned into 25ms time window and the behavior is position and the running
164+
direction (left, right) of a rat. The behavior label is structured as 3D array consists of position,
165+
right, and left. The neural and behavior recordings are parsed into trials (a round trip from one end
166+
of the track) and the trials are split into a train, valid and test set with k=3 nested cross validation.
166167
167168
Args:
168169
name: The name of a rat to use. Choose among 'achilles', 'buddy', 'cicero' and 'gatsby'.
169170
split_no: The `k` for k-fold split. Choose among 0, 1, 2.
170-
split: The split to use. Choose among 'train', 'valid', 'test', 'all', and 'wo_test'(all trials except test split).
171+
split: The split to use. Choose among 'train', 'valid', 'test', 'all', and 'wo_test'
172+
(all trials except test split).
171173
172174
"""
173175

@@ -281,13 +283,16 @@ class MultipleRatsTrialSplitDataset(cebra.data.DatasetCollection):
281283
"""4 rats hippocampus tetrode recording while the rat navigates on a linear track with 3-fold splits.
282284
283285
Neural and behavior recordings of 4 rats.
284-
For each rat, neural data is spike counts binned into 25ms time window and the behavior is position and the running direction (left, right) of a rat.
286+
For each rat, neural data is spike counts binned into 25ms time window and the behavior is position
287+
and the running direction (left, right) of a rat.
285288
The behavior label is structured as 3D array consists of position, right, and left.
286-
Neural and behavior recordings of each rat are parsed into trials (a round trip from one end of the track) and the trials are split into a train, valid and test set with k=3 nested cross validation.
289+
Neural and behavior recordings of each rat are parsed into trials (a round trip from one end of the track)
290+
and the trials are split into a train, valid and test set with k=3 nested cross validation.
287291
288292
Args:
289293
split_no: The `k` for k-fold split. Choose among 0, 1, and 2.
290-
split: The split to use. Choose among 'train', 'valid', 'test', 'all', and 'wo_test'(all trials except test split).
294+
split: The split to use. Choose among 'train', 'valid', 'test', 'all', and 'wo_test'
295+
(all trials except test split).
291296
292297
"""
293298

0 commit comments

Comments
 (0)