Hi,
I am running the same question-answering model with both PyTorch and Flax, but Flax is dramatically slower:
| Model | Time |
|---|---|
| PyTorch (GPU) | 30 ms |
| Flax (GPU) | 4 sec |
| Flax (TPU) | 15 sec |
I have warmed up the jitted function, and I believe I have placed both the model and the data on the device in each case. What am I doing wrong here?
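For reference, this is the benchmarking pattern I would expect to need with JAX: the first call to a jitted function pays the tracing/XLA-compilation cost, every later call with the same input shapes reuses the compiled executable, and `block_until_ready()` is required so the asynchronous dispatch does not make timings look wrong. A minimal sketch (using a toy function as a stand-in for the model's forward pass; with transformers, the call to `FlaxBertForQuestionAnswering` would be wrapped the same way):

```python
import time
import jax
import jax.numpy as jnp

# Stand-in for the model's forward pass; note that changing the input
# shape (e.g. a different sequence length) triggers a fresh compilation.
@jax.jit
def forward(x):
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((8, 128))

# First call: tracing + XLA compilation (slow, one-off).
forward(x).block_until_ready()

# Subsequent calls reuse the compiled executable. block_until_ready()
# forces the async dispatch to finish, so the timing is honest.
start = time.perf_counter()
forward(x).block_until_ready()
elapsed = time.perf_counter() - start
```

If the slow numbers above were measured with varying sequence lengths, each new length would recompile, which could explain multi-second latencies; padding inputs to a fixed length avoids that.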
Notebook: https://colab.research.google.com/drive/18ouY4S9rYtlEC1dht_w8WU4fFKnP9yhP#scrollTo=FLfV3ymlwRJC
Environment info
- `transformers` version: 4.16.2
- Platform: Linux-5.4.144+-x86_64-with-Ubuntu-18.04-bionic
- Python version: 3.7.12
- PyTorch version (GPU?): 1.10.0+cu111 (True)
- Tensorflow version (GPU?): 2.7.0 (True)
- Flax version (CPU?/GPU?/TPU?): 0.4.0 (gpu)
- Jax version: 0.2.25
- JaxLib version: 0.1.71
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No
Who can help
Information
Model I am using (Bert, XLNet ...): BERT (FlaxBertForQuestionAnswering, BertForQuestionAnswering)
The problem arises when using:
- [ ] the official example scripts: (give details below)
- [x] my own modified scripts: (give details below)
The task I am working on is:
- [ ] an official GLUE/SQuAD task: (give the name)
- [x] my own task or dataset: (give details below)
To reproduce
Steps to reproduce the behavior:
- Run this notebook on Google Colab (GPU runtime)
Expected behavior
Roughly the same inference latency for Flax as for PyTorch.