File tree 2 files changed +30
-1
lines changed
2 files changed +30
-1
lines changed Original file line number Diff line number Diff line change @@ -231,7 +231,19 @@ def __getitem__(self, idx):
231
231
index_dataloader = DataLoader (index_dataset , batch_size = batch_size )
232
232
233
233
output = []
234
- for index_batch in index_dataloader :
234
+ for batch_idx , index_batch in enumerate (index_dataloader ):
235
+ # NOTE(celia): This prevents adding the offset to the penultimate
236
+ # batch, for larger offsets, from making batch_end_idx larger than the
237
+ # input length, while also not dropping the last samples that do not
238
+ # fit in a complete batch.
239
+ if batch_idx == (len (index_dataloader ) - 2 ):
240
+ # penultimate batch, last complete batch
241
+ last_batch = index_batch
242
+ continue
243
+ if batch_idx == (len (index_dataloader ) - 1 ):
244
+ # last batch, incomplete
245
+ index_batch = torch .cat ((last_batch , index_batch ), dim = 0 )
246
+
235
247
batch_start_idx , batch_end_idx = index_batch [0 ], index_batch [- 1 ] + 1
236
248
batched_data = _get_batch (inputs = inputs ,
237
249
offset = offset ,
Original file line number Diff line number Diff line change @@ -1506,3 +1506,20 @@ def test_new_transform(model_architecture, device):
1506
1506
embedding2 = cebra_model .transform_deprecated (X , session_id = 2 )
1507
1507
assert np .allclose (embedding1 , embedding2 , rtol = 1e-5 ,
1508
1508
atol = 1e-8 ), "Arrays are not close enough"
1509
+
1510
+
1511
def test_last_incomplete_batch_smaller_than_offset():
    """Check batched ``transform`` with a short trailing batch.

    When the offset of the model is larger than the number of samples
    remaining in the last batch, an error could happen. The penultimate and
    last batches are merged together to avoid this.

    20111 samples with ``batch_size=300`` leave a trailing batch of only
    11 samples, which is the situation exercised here.
    """
    num_samples = 20111
    # Seeded generator so the input data is reproducible across runs.
    rng = np.random.default_rng(0)
    train = cebra.data.TensorDataset(neural=rng.random((num_samples, 100)),
                                     continuous=rng.random((num_samples, 2)))

    model = cebra.CEBRA(max_iterations=2,
                        model_architecture="offset36-model-more-dropout",
                        device="cpu")
    model.fit(train.neural, train.continuous)

    embedding = model.transform(train.neural, batch_size=300)
    # Merging the last two batches must neither drop nor duplicate samples:
    # the embedding has to contain exactly one row per input sample.
    assert len(embedding) == num_samples
You can’t perform that action at this time.
0 commit comments