Skip to content

Commit 81b964c

Browse files
authored
Concatenate last batches for batched inference (#200)
* Concatenate last two batches for batched inference * Add test case
1 parent 0eac868 commit 81b964c

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

cebra/solver/base.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,19 @@ def __getitem__(self, idx):
231231
index_dataloader = DataLoader(index_dataset, batch_size=batch_size)
232232

233233
output = []
234-
for index_batch in index_dataloader:
234+
for batch_idx, index_batch in enumerate(index_dataloader):
235+
# NOTE(celia): This is to prevent that adding the offset to the
236+
# penultimate batch for larger offsets makes the batch_end_idx larger
237+
# than the input length, while we also don't want to drop the last
238+
# samples that do not fit in a complete batch.
239+
if batch_idx == (len(index_dataloader) - 2):
240+
# penultimate batch, last complete batch
241+
last_batch = index_batch
242+
continue
243+
if batch_idx == (len(index_dataloader) - 1):
244+
# last batch, incomplete
245+
index_batch = torch.cat((last_batch, index_batch), dim=0)
246+
235247
batch_start_idx, batch_end_idx = index_batch[0], index_batch[-1] + 1
236248
batched_data = _get_batch(inputs=inputs,
237249
offset=offset,

tests/test_sklearn.py

+17
Original file line numberDiff line numberDiff line change
@@ -1506,3 +1506,20 @@ def test_new_transform(model_architecture, device):
15061506
embedding2 = cebra_model.transform_deprecated(X, session_id=2)
15071507
assert np.allclose(embedding1, embedding2, rtol=1e-5,
15081508
atol=1e-8), "Arrays are not close enough"
1509+
1510+
1511+
def test_last_incomplete_batch_smaller_than_offset():
1512+
"""
1513+
When offset of the model is larger than the remaining samples in the
1514+
last batch, an error could happen. We merge the penultimate
1515+
and last batches together to avoid this.
1516+
"""
1517+
train = cebra.data.TensorDataset(neural=np.random.rand(20111, 100),
1518+
continuous=np.random.rand(20111, 2))
1519+
1520+
model = cebra.CEBRA(max_iterations=2,
1521+
model_architecture="offset36-model-more-dropout",
1522+
device="cpu")
1523+
model.fit(train.neural, train.continuous)
1524+
1525+
_ = model.transform(train.neural, batch_size=300)

0 commit comments

Comments
 (0)