
Add context parallelism #1445


Merged: 1 commit merged into main on May 2, 2025

Conversation

@A9isha (Collaborator) commented on Mar 22, 2025

Description

This PR adds context parallelism to MaxText to support long context lengths when training on TPUs, following PR 1133, which added it for GPUs.
We want to support the common context-parallelism paradigm of all-gathering (AG) the keys and values while sharding on the query sequence dimension.
Previously, MaxText implemented sequence parallelism by sharding the attention operation on heads (with the feed-forward operation sharded by sequence), since the heads dimension acts like a batch dimension. This PR's approach frees the heads dimension to be used for other forms of model parallelism.
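
For illustration, here is a minimal JAX sketch of the "all-gather K/V, shard on query sequence" pattern. This is not the MaxText implementation: the mesh axis name "context", the plain softmax attention, and the omission of the causal mask are all simplifications.

```python
# Minimal sketch (not MaxText code): shard Q/K/V on their sequence axis across
# a hypothetical "context" mesh axis, then all-gather K and V so each device
# computes attention for its local Q shard against the full K/V sequence.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ("context",))

def _attn_shard(q, k, v):
  # q, k, v: [batch, local_seq, heads, head_dim], sharded on axis 1.
  k = jax.lax.all_gather(k, "context", axis=1, tiled=True)
  v = jax.lax.all_gather(v, "context", axis=1, tiled=True)
  scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
  probs = jax.nn.softmax(scores, axis=-1)  # causal mask omitted for brevity
  return jnp.einsum("bhqk,bkhd->bqhd", probs, v)

cp_attention = shard_map(
    _attn_shard,
    mesh=mesh,
    in_specs=(P(None, "context"), P(None, "context"), P(None, "context")),
    out_specs=P(None, "context"),
)
```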

We also employ "load balancing" in the context parallelism: we reorder the tokens and the attention mask so that, under the causal mask, all devices do a roughly equal amount of attention work (and hence take a similar amount of time). Otherwise, with context parallelism of 2, the device that receives the shard with tokens 0,1,2,3 has much less work than the device that receives tokens 4,5,6,7; so we load balance by assigning tokens 0,1,6,7 to one device and tokens 2,3,4,5 to the other.

The load balancing currently uses jnp.take, which we hope to improve; this is tracked in b/413770626.
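
As a concrete illustration of the reordering (a sketch of the idea only; the helper name and chunking layout below are assumptions, not the actual MaxText code):

```python
# Sketch: split the sequence into 2*cp chunks and pair chunk i with chunk
# (2*cp - 1 - i), so each device gets one "early" and one "late" chunk and
# causal-attention work is roughly balanced. For cp=2 and 8 tokens this
# yields the ordering 0,1,6,7,2,3,4,5 described above.
import jax.numpy as jnp

def reorder_for_causal_load_balance(x, cp, seq_axis=1):
  num_chunks = 2 * cp
  chunk = x.shape[seq_axis] // num_chunks
  order = [c for i in range(cp) for c in (i, num_chunks - 1 - i)]
  idx = jnp.concatenate([jnp.arange(c * chunk, (c + 1) * chunk) for c in order])
  return jnp.take(x, idx, axis=seq_axis)
```

After sharding the reordered sequence across the cp devices, each device holds one early and one late chunk, which is what balances the work under the causal mask.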

FIXES: b/377904983

Tests

Tested locally on a v5p-8, and also on a v6e-256 for Llama3.1-70B with a context length of 131072 tokens:

  1. using python3 -m benchmarks.maxtext_xpk_runner
  2. updating the model to llama3_1_70b_131072 from maxtext/benchmarks/maxtext_trillium_model_configs.py

Also added a unit test for context parallelism in MaxText/tests/attention_test.py.

Updated maxtext/end_to_end/tpu/llama2/7b/test_llama2_7b.sh to double as an integration test for context parallelism in train.py, via ici_context_parallelism.
We also use ici_context_parallelism for forward_pass_logit_checker to test the forward pass, since flash attention cannot be used in decode.py.
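
For reference, enabling context parallelism in a training run looks roughly like the following (the run name and parallelism degree are illustrative, and other required flags such as the output directory and dataset path are omitted; ici_context_parallelism is the flag introduced in this PR):

```
python3 MaxText/train.py MaxText/configs/base.yml \
  run_name=cp_example \
  ici_context_parallelism=4
```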

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@A9isha force-pushed the anisha-context-parallelism branch 2 times, most recently from b342894 to 6b9c29d on March 22, 2025 19:37
@A9isha force-pushed the anisha-context-parallelism branch 2 times, most recently from 3f55480 to 3ba3d54 on April 15, 2025 23:01
@A9isha force-pushed the anisha-context-parallelism branch 3 times, most recently from 4c4ef47 to 1b65e2e on April 18, 2025 01:37
@A9isha marked this pull request as ready for review on April 18, 2025 02:27
@A9isha requested a review from rdyro on April 18, 2025 02:27
@gobbleturk (Collaborator) left a comment:

go anisha go

@A9isha force-pushed the anisha-context-parallelism branch 2 times, most recently from f6d6f39 to 4169d41 on April 22, 2025 00:17
@gobbleturk (Collaborator) left a comment:

Approved! Please don't change the _DENSE_VMEM_LIMIT in this PR though, instead we can document using a different VMEM value in our MaxText configs

@A9isha force-pushed the anisha-context-parallelism branch 5 times, most recently from 9f18043 to c6ea0d7 on April 30, 2025 23:56
@A9isha requested review from Obliviour, gobbleturk and shralex on May 1, 2025 00:54
@richjames0 (Collaborator) left a comment:

lgtm

@A9isha force-pushed the anisha-context-parallelism branch from c6ea0d7 to f4926bf on May 1, 2025 01:24
@A9isha requested a review from richjames0 on May 1, 2025 01:28
@richjames0 (Collaborator) left a comment:

thanks for the change!

@A9isha force-pushed the anisha-context-parallelism branch from f4926bf to c4ead2f on May 1, 2025 17:35
@A9isha requested review from richjames0 and gobbleturk on May 2, 2025 20:27
@richjames0 (Collaborator) left a comment:

lgtm again

@A9isha force-pushed the anisha-context-parallelism branch from c4ead2f to fbe8c2d on May 2, 2025 22:14
@A9isha force-pushed the anisha-context-parallelism branch from fbe8c2d to 4e7d847 on May 2, 2025 22:23
@copybara-service bot merged commit 18ccd37 into main on May 2, 2025
17 checks passed
@copybara-service bot deleted the anisha-context-parallelism branch on May 2, 2025 23:16
5 participants