
Commit 61d8e44

SunMarc authored and zucchini-nlp committed
Fix tests due to breaking change in accelerate (huggingface#39451)
* update values
* fix
1 parent 1561fec commit 61d8e44

File tree

2 files changed: +6 -6 lines changed


tests/trainer/test_trainer.py

Lines changed: 5 additions & 5 deletions
@@ -3394,7 +3394,7 @@ def test_auto_batch_size_with_deepspeed(self):
         )
         trainer = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()])
         trainer.train()
-        self.assertEqual(trainer._train_batch_size, 8)
+        self.assertEqual(trainer._train_batch_size, 14)
 
     def test_auto_batch_size_with_resume_from_checkpoint(self):
         train_dataset = RegressionDataset(length=128)
@@ -3414,16 +3414,16 @@ def test_auto_batch_size_with_resume_from_checkpoint(self):
         )
         trainer = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()])
         trainer.train()
-        # After `auto_find_batch_size` is ran we should now be at 8
-        self.assertEqual(trainer._train_batch_size, 8)
+        # After `auto_find_batch_size` is ran we should now be at 16*0.9=14
+        self.assertEqual(trainer._train_batch_size, 14)
 
         # We can then make a new Trainer
         trainer = Trainer(model, args, train_dataset=train_dataset)
         # Check we are at 16 to start
         self.assertEqual(trainer._train_batch_size, 16 * max(trainer.args.n_gpu, 1))
         trainer.train(resume_from_checkpoint=True)
-        # We should be back to 8 again, picking up based upon the last ran Trainer
-        self.assertEqual(trainer._train_batch_size, 8)
+        # We should be back to 14 again, picking up based upon the last ran Trainer
+        self.assertEqual(trainer._train_batch_size, 14)
 
     # regression for this issue: https://github.com/huggingface/transformers/issues/12970
     def test_training_with_resume_from_checkpoint_false(self):
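The new expected value follows from the accelerate behaviour this commit adapts to: with `auto_find_batch_size=True`, the batch size now appears to shrink by roughly 10% per OOM retry (truncated to an integer) instead of being halved, so a starting per-device batch size of 16 lands on 14 after one retry, as the updated comment `16*0.9=14` spells out. A minimal sketch of that arithmetic, assuming the 0.9 factor named in the test comment (illustrative only, not accelerate's actual implementation):

```python
# Illustrative sketch of the ~10% per-retry decay these tests assume.
def next_batch_size(batch_size: int) -> int:
    # Assumed rule: multiply by 0.9 and truncate, rather than halving.
    return int(batch_size * 0.9)

print(next_batch_size(16))  # 14, matching the updated `_train_batch_size` assertions
```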

tests/trainer/test_trainer_utils.py

Lines changed: 1 addition & 1 deletion
@@ -464,7 +464,7 @@ def mock_training_loop_function(batch_size):
                 raise RuntimeError("CUDA out of memory.")
 
         mock_training_loop_function()
-        self.assertEqual(batch_sizes, [64, 32, 16])
+        self.assertEqual(batch_sizes, [64, 57, 51, 45, 40, 36, 32, 28, 25, 22, 19, 17, 15])
 
     @require_accelerate
     def test_executable_batch_size_no_search(self):
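The updated list reflects the same assumed decay: the mocked training loop keeps raising "CUDA out of memory." while the batch size is still too large (the old expectation [64, 32, 16] suggests it only succeeds once the size is at or below 16), so the recorded attempts now step down by about 10% per retry rather than halving. A hedged sketch that reproduces the expected sequence under those assumptions:

```python
# Sketch reproducing the expected `batch_sizes` sequence, assuming each OOM
# retry multiplies the size by 0.9 (truncated) and the mock succeeds at <= 16.
batch_sizes, bs = [], 64
while True:
    batch_sizes.append(bs)
    if bs <= 16:            # assumed success threshold: the search stops here
        break
    bs = int(bs * 0.9)      # simulated OOM: shrink by ~10% and retry

print(batch_sizes)
# [64, 57, 51, 45, 40, 36, 32, 28, 25, 22, 19, 17, 15]
```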
