[DeepSpeed] [success] trained t5-11b on 1x 40GB gpu #9996

@stas00

Managed to train t5-11b on 1x 40GB gpu w/ Deepspeed (A100-SXM4-40GB)

Thank you, @PeterAJansen for letting me use your hardware!

Thank you, @jeffra and @samyam, for refusing to believe that it was impossible to train t5-11b on 1x 40GB gpu w/ DeepSpeed, and for the support that led me to find a few bugs in the integration.

Sharing the details for those who need them.

If you want to try this at home, please make sure you use transformers master, as some bug fixes were just merged in.

Well, it's similar to the t5-3b on 24GB success reported here and here.
But this time it's t5-11b on 1x 40GB gpu (or 4x if you want things faster).

As someone asked me about before: you need a huge amount of general RAM to use ZeRO-Offload for a huge model (a rough back-of-envelope follows the list):

  • for t5-3b on 1x 24GB gpu: ~71GB RAM
  • for t5-11b on 1x 40GB gpu: ~234GB RAM
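
A rough back-of-envelope of why (my assumption, not something measured in these runs): ZeRO-Offload keeps the fp32 master weights, the two Adam moments and an fp32 gradient copy in CPU RAM, which comes to roughly 16 bytes per parameter before any framework, dataloader or pinned-buffer overhead:

    # ~16 bytes/param of CPU RAM for fp16 training with ZeRO-Offload (assumption)
    echo "$(( 11 * 16 )) GB"   # ~176 GB for ~11B params; the measured ~234GB includes the overhead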

I was using the /usr/bin/time -v program to get the peak memory measurement - it's the Maximum resident set size entry in the final report.

Question: I don't think /usr/bin/time does the right thing for multi-process runs - I think it only measures the parent process. E.g. with 4x gpus it reported only 102GB RAM, but I clearly saw in top that it was around 240GB. If you have an easy way to measure peak memory that takes forked processes into account, I'm all ears.
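
For what it's worth, one rough way to do it (a sketch, not what produced the numbers above) is to run the launcher under a small wrapper that polls the summed RSS of the whole process tree once a second and keeps the maximum. Summed RSS double-counts pages shared between the forked workers, so treat the result as an upper bound:

    #!/usr/bin/env bash
    # peak_rss.sh - run a command, poll the summed RSS of it and all of its
    # descendants once a second, and report the maximum total seen
    "$@" &
    main=$!

    descendants() {               # print a pid and all of its descendants
        local pid=$1 child
        echo "$pid"
        for child in $(pgrep -P "$pid"); do descendants "$child"; done
    }

    peak=0
    while kill -0 "$main" 2>/dev/null; do
        cur=0
        for pid in $(descendants "$main"); do
            rss=$(ps -o rss= -p "$pid" 2>/dev/null)   # RSS in KiB, empty if the pid is gone
            cur=$(( cur + ${rss:-0} ))
        done
        (( cur > peak )) && peak=$cur
        sleep 1
    done
    wait "$main"
    echo "approx peak RSS of the process tree: $(( peak / 1024 )) MiB" >&2

Usage would be ./peak_rss.sh deepspeed --num_gpus=4 ./finetune_trainer.py ... with the same arguments as in the benchmark commands below.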

Batch sizes on one gpu:

  • with buffers of 5e8 I was able to run BS=2, which might be too small for training,
  • but with 2e8 I managed to squeeze in BS=10 for training, though it OOMed on prediction

The buffer sizes I'm referring to are these entries in ds_config.json:

        "allgather_bucket_size": 2e8,
        "reduce_bucket_size": 2e8,

And I tested 2x and 4x DDP as well: BS=16 OOMed, BS=8 was good, so I used that - but one could probably squeeze in some more.
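
For orientation, here is a minimal sketch of the zero_optimization block those bucket-size keys live in (an illustration, not a copy of the actual examples file; the field names follow DeepSpeed's ZeRO-2 config of that time, and everything other than the two bucket sizes is just a typical choice):

    "zero_optimization": {
        "stage": 2,
        "cpu_offload": true,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "allgather_bucket_size": 2e8,
        "reduce_bucket_size": 2e8
    }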

edit1: later tests show that my test was too short and didn't let the CPU Adam optimizer kick in, as the first 20 or so steps are skipped due to fp16 loss-scale overflow. Once it kicks in, it takes more GPU memory, so the practical BS is much smaller - I think around 2 on this setup. So most likely you will need to use BS=2 for real work, until things get optimized even more.

edit2: things are getting re-shuffled in the tests, so the default ds_config.json file has moved in master to a new, hopefully permanent, home. It's now at examples/tests/deepspeed/ds_config.json, so you will need to adjust the command line to reflect this new location or simply copy it over to where the old one used to be.

here is the full benchmark:

# 1 gpu: 
# only training fits with this BS, eval needs a smaller BS

export BS=8; rm -rf output_dir; PYTHONPATH=../../src USE_TF=0 /usr/bin/time -v deepspeed --num_gpus=1 ./finetune_trainer.py --model_name_or_path t5-11b --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro --do_eval --do_predict --do_train --evaluation_strategy=steps --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS --predict_with_generate --eval_steps 25000  --sortish_sampler --task translation_en_to_ro --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 5 --n_train 60 --n_val 10 --n_test 10 --deepspeed ds_config.json --fp16

{'train_runtime': 31.0897, 'train_samples_per_second': 0.257, 'epoch': 1.0}

# 2 gpus:

export BS=8; rm -rf output_dir; PYTHONPATH=../../src USE_TF=0 /usr/bin/time -v deepspeed --num_gpus=2 ./finetune_trainer.py --model_name_or_path t5-11b --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro --do_eval --do_predict --do_train --evaluation_strategy=steps --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS --predict_with_generate --eval_steps 25000  --sortish_sampler --task translation_en_to_ro --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 5 --n_train 60 --n_val 10 --n_test 10 --deepspeed ds_config.json --fp16

{'train_runtime': 17.9026, 'train_samples_per_second': 0.223, 'epoch': 1.0}

# 4 gpus

export BS=8; rm -rf output_dir; PYTHONPATH=../../src USE_TF=0 /usr/bin/time -v deepspeed --num_gpus=4 ./finetune_trainer.py --model_name_or_path t5-11b --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro --do_eval --do_predict --do_train --evaluation_strategy=steps --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS --predict_with_generate --eval_steps 25000  --sortish_sampler --task translation_en_to_ro --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 5 --n_train 60 --n_val 10 --n_test 10 --deepspeed ds_config.json --fp16

{'train_runtime': 10.4404, 'train_samples_per_second': 0.192, 'epoch': 1.0}

Gradient (activation) checkpointing should allow even bigger batch sizes.
