
Commit 7aab282

Fix extra docs errors
1 parent 04a102f commit 7aab282

4 files changed (+7, -5 lines)


cebra/data/multi_session.py

+1 -1

@@ -110,7 +110,7 @@ def configure_for(self, model: "cebra.models.Model"):
         """Configure the dataset offset for the provided model.
 
         Call this function before indexing the dataset. This sets the
-        :py:attr:`offset` attribute of the dataset.
+        :py:attr:`cebra_data.Dataset.offset` attribute of the dataset.
 
         Args:
             model: The model to configure the dataset for.

cebra/data/single_session.py

+1 -1

@@ -76,7 +76,7 @@ def configure_for(self, model: "cebra.models.Model"):
         """Configure the dataset offset for the provided model.
 
         Call this function before indexing the dataset. This sets the
-        :py:attr:`offset` attribute of the dataset.
+        :py:attr:`cebra_data.Dataset.offset` attribute of the dataset.
 
         Args:
             model: The model to configure the dataset for.
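Both docstring hunks only change the cross-reference target; the behavior of configure_for is unchanged. For context, a minimal usage sketch of the attribute the reference points at. Only configure_for() and the offset attribute come from the diffs above; cebra.data.TensorDataset, cebra.models.init, and all shapes are assumptions based on the public CEBRA API.

# Hedged sketch, not part of this commit. Assumes cebra.data.TensorDataset and
# cebra.models.init exist with the signatures used below.
import torch
import cebra.data
import cebra.models

neural = torch.randn(1000, 30)                      # 1000 samples, 30 neurons (invented)
labels = torch.randn(1000, 2)                       # invented continuous labels
dataset = cebra.data.TensorDataset(neural, continuous=labels)

model = cebra.models.init("offset10-model",
                          num_neurons=30, num_units=32, num_output=8)

dataset.configure_for(model)                        # sets the dataset's offset attribute
print(dataset.offset.left, dataset.offset.right)    # receptive-field padding per side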

cebra/solver/base.py

+2 -2

@@ -185,7 +185,7 @@ def _transform(
     model: cebra.models.Model,
     inputs: torch.Tensor,
     pad_before_transform: bool,
-    offset: cebra.data.Offset,
+    offset: cebra.data.datatypes.Offset,
 ) -> torch.Tensor:
     """Compute the embedding.
 
@@ -206,7 +206,7 @@ def _transform(
 
 def _batched_transform(model: cebra.models.Model, inputs: torch.Tensor,
                        batch_size: int, pad_before_transform: bool,
-                       offset: cebra.data.Offset) -> torch.Tensor:
+                       offset: cebra.data.datatypes.Offset) -> torch.Tensor:
     """Compute the embedding on batched inputs.
 
     Args:
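The annotation change spells out the module where Offset is defined, which lets documentation tooling resolve the type reference; at runtime both paths presumably point at the same class. A small hedged sketch of the datatype, assuming an Offset(left, right) constructor; the .left/.right attributes are confirmed by the assertions in tests/test_solver.py below.

# Hedged sketch: assumes cebra.data.datatypes.Offset accepts (left, right)
# padding widths around a sample index.
from cebra.data.datatypes import Offset

offset = Offset(4, 5)              # 4 samples of left context, 5 of right
print(offset.left, offset.right)   # -> 4 5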

tests/test_solver.py

+3 -1

@@ -65,6 +65,7 @@
 
 # multi_session_tests.append((*args, cebra.solver.MultiSessionAuxVariableSolver))
 
+
 def _get_loader(data, loader_initfunc):
     kwargs = dict(num_steps=5, batch_size=32)
     loader = loader_initfunc(data, **kwargs)
@@ -574,6 +575,7 @@ def test_select_model_multi_session(data_name, model_name, session_id,
     assert offset.left == offset_.left and offset.right == offset_.right
     assert model == model_
 
+
 models = [
     "offset1-model",
     "offset10-model",
@@ -683,7 +685,7 @@ def test_batched_transform_multi_session(data_name, model_name, padding,
     n_samples = dataset._datasets[0].neural.shape[0]
     assert all(
         d.neural.shape[0] == n_samples for d in dataset._datasets
-    ), # all sessions need to have same number of samples
+    ), "for this set all of the sessions need to have same number of samples."
 
     smallest_batch_length = n_samples - batch_size
     offset_ = model[0].get_offset()
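The last hunk turns a comment after the assertion's trailing comma into a real failure message. In Python, assert expr, <message> requires an expression after the comma; a comma followed only by a comment is a syntax error, and a comment would not be reported on failure anyway. A minimal standalone sketch of the corrected pattern, with invented session lengths:

# Minimal sketch of the assert-with-message pattern used in the fix above.
session_lengths = [1000, 1000, 1000]   # invented per-session sample counts
n_samples = session_lengths[0]

assert all(
    n == n_samples for n in session_lengths
), "all sessions need to have the same number of samples."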
