
TF2 DeBERTaV2 runs super slow on TPUs #18239

Closed
@WissamAntoun

Description


System Info

Latest version of transformers, Colab TPU, TensorFlow 2

Who can help?

@kamalkraj @Rocketknight1 @BigBird01

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

It's currently hard to share the code and access to the Google bucket, but I believe any TF2 DeBERTaV2 code running on TPUs will hit this issue.

Expected behavior

I've been trying to train a DeBERTa v3 model on GPUs and TPUs. I got it to work on multi-node and multi-GPU setups using the NVIDIA Deep Learning Examples libraries https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow2/LanguageModeling/
I basically used the training setup and loop from the BERT code, the dataset utilities from the ELECTRA code, and the model from Hugging Face Transformers, with some changes to share embeddings.

On 6x A40 45GB GPUs I get around 1370 sentences per second during training (lower than what NVIDIA gets for ELECTRA, but it's fine).

OK, now the problem: on TPU I get 20 sentences per second.

I traced the issue back to the tf.gather call here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py#L525
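For context, a common workaround for slow gathers on TPU is to rewrite the gather as a one-hot matmul, since dense matmuls map onto the TPU's MXU while dynamic gathers can lower to much slower ops. This is only a sketch of the general idea (the actual call in modeling_tf_deberta_v2.py is a batched gather over attention indices, so the real fix would need the appropriate batch dimensions); `gather_as_matmul` is a hypothetical helper name, not something in the Transformers codebase:

```python
import tensorflow as tf

def gather_as_matmul(params, indices):
    """One-hot matmul equivalent of tf.gather(params, indices) along axis 0.

    params:  float tensor of shape [V, D]
    indices: int tensor of shape [N]
    returns: tensor of shape [N, D], identical to tf.gather(params, indices)
    """
    # Build a [N, V] one-hot selection matrix in the params dtype,
    # then contract it against params to select rows densely.
    one_hot = tf.one_hot(indices, depth=tf.shape(params)[0], dtype=params.dtype)
    return tf.linalg.matmul(one_hot, params)
```

Whether this actually helps here would need to be confirmed with the TPU profiler, since the one-hot approach trades memory and FLOPs for MXU-friendly execution.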

I ran TPU profiling and this is the output:
[profiler screenshot]

GatherV2 takes most of the time:
[profiler screenshot]

Zoomed-in view of the fast ops:
[profiler screenshot]

Also, I'm not sure this is TPU-specific, since on GPUs training is ~30% slower compared to regular ELECTRA.
