
Flax model much slower than Pytorch #15581

Closed
@yorickvanzweeden

Description

Hi,

I am running the same QuestionAnswering model with PyTorch and with Flax, but the Flax version performs much worse:

Model          Time
-------------  ------
PyTorch (GPU)  30 ms
Flax (GPU)     4 s
Flax (TPU)     15 s

I have warmed up the jitted function, and as far as I can tell the model and data are on the device in both cases. What am I doing wrong here?

Notebook: https://colab.research.google.com/drive/18ouY4S9rYtlEC1dht_w8WU4fFKnP9yhP#scrollTo=FLfV3ymlwRJC
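For context, the timing pattern I use looks roughly like this (a minimal sketch with a stand-in function rather than the actual BERT forward pass; `forward` here is a placeholder, not the notebook code):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def forward(x):
    # Stand-in for a model forward pass.
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((128, 128))

# Warm-up: the first call triggers tracing + XLA compilation.
forward(x).block_until_ready()

# Timed call: block_until_ready() forces the async computation to finish
# before we stop the clock.
start = time.perf_counter()
result = forward(x).block_until_ready()
elapsed = time.perf_counter() - start
print(f"forward pass: {elapsed * 1e3:.2f} ms")
```

Without the warm-up call, the first measurement includes compilation time; without `block_until_ready()`, JAX's asynchronous dispatch returns before the computation finishes, so the measured time would be meaningless.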

Environment info

  • transformers version: 4.16.2
  • Platform: Linux-5.4.144+-x86_64-with-Ubuntu-18.04-bionic
  • Python version: 3.7.12
  • PyTorch version (GPU?): 1.10.0+cu111 (True)
  • Tensorflow version (GPU?): 2.7.0 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.4.0 (gpu)
  • Jax version: 0.2.25
  • JaxLib version: 0.1.71
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help

@patrickvonplaten

Information

Model I am using (Bert, XLNet ...): BERT (FlaxBertForQuestionAnswering, BertForQuestionAnswering)

The problem arises when using:

  • [ ] the official example scripts: (give details below)
  • [x] my own modified scripts: (give details below)

The task I am working on is:

  • [ ] an official GLUE/SQuAD task: (give the name)
  • [x] my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

  1. Run this notebook on Google Colab (GPU runtime)

Expected behavior

Roughly the same performance for Flax as for PyTorch.
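For what it's worth, one known cause of this kind of gap is `jax.jit` recompiling for every new input shape, so padding all batches to one fixed length lets the compiled function be reused. A minimal sketch of that idea (`pad_to` and the trivial `forward` are hypothetical stand-ins, not the notebook code):

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def forward(input_ids):
    # Stand-in for the model call; what matters is that jit compiles
    # once per distinct input shape.
    return input_ids.sum()

def pad_to(ids, length, pad_id=0):
    # Pad every batch to a single fixed length so jit compiles only once.
    out = np.full((len(ids), length), pad_id, dtype=np.int32)
    for i, row in enumerate(ids):
        out[i, : len(row)] = row
    return jnp.asarray(out)

batch1 = pad_to([[101, 7592, 102]], 16)
batch2 = pad_to([[101, 7592, 2088, 102]], 16)  # same shape -> reuses the compilation
print(forward(batch1), forward(batch2))
```

If instead each batch were tokenized to its natural length, every new sequence length would trigger a fresh trace and XLA compilation, which could easily account for seconds per call.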
