Add context parallelism #1445
Conversation
Force-pushed b342894 to 6b9c29d
Force-pushed 3f55480 to 3ba3d54
Force-pushed 4c4ef47 to 1b65e2e
go anisha go
Force-pushed f6d6f39 to 4169d41
Approved! Please don't change the _DENSE_VMEM_LIMIT in this PR, though; instead, we can document using a different VMEM value in our MaxText configs.
Force-pushed 9f18043 to c6ea0d7
lgtm
Force-pushed c6ea0d7 to f4926bf
thanks for the change!
Force-pushed f4926bf to c4ead2f
lgtm again
Force-pushed c4ead2f to fbe8c2d
Force-pushed fbe8c2d to 4e7d847
Description
This PR adds Context Parallelism to MaxText to support long context lengths in training on TPUs, following PR 1133, which added it for GPUs.
We want to support the common context parallelism paradigm: all-gather (AG) the keys and values, and shard on the query sequence.
Previously, MaxText implemented sequence parallelism by sharding the attention operations on heads (with the FF operation sharded by sequence), because the heads dim acts like a batch dimension. This PR instead shards attention on the query sequence, freeing the heads dim to be used for other model parallelism.
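As a rough sketch of that paradigm (not the code in this PR), the idea can be expressed with jax.experimental.shard_map over a 1-D mesh axis, here hypothetically named "context": each device keeps only its query shard and all-gathers the full keys and values before computing attention. Heads, masking, and the load-balanced reordering described below are omitted for brevity.

```python
# Minimal sketch of "all-gather K/V, shard on query sequence" context
# parallelism. The `cp_attention` name and the "context" mesh axis are
# illustrative, not MaxText's actual implementation.
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("context",))

@partial(
    shard_map,
    mesh=mesh,
    in_specs=(P(None, "context", None),) * 3,  # q, k, v sharded on sequence
    out_specs=P(None, "context", None),        # output stays sharded on query sequence
)
def cp_attention(q, k, v):
  # Per device: q, k, v have shape [batch, seq // cp, d]. Gather the full
  # K and V across the context axis; keep only the local query shard.
  k = jax.lax.all_gather(k, "context", axis=1, tiled=True)
  v = jax.lax.all_gather(v, "context", axis=1, tiled=True)
  scores = jnp.einsum("bqd,bkd->bqk", q, k) / jnp.sqrt(q.shape[-1])
  return jnp.einsum("bqk,bkd->bqd", jax.nn.softmax(scores, axis=-1), v)
```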
We also employ "load balancing" in the context parallelism: we reorder the tokens and the attention mask so that, in attention, all devices have a similar amount of work (and hence take a similar amount of time). Otherwise, with say context parallelism=2, the device that receives the shard with tokens 0,1,2,3 has much less work than the device that receives the shard with tokens 4,5,6,7; so we load balance by dividing the tokens as 0,1,6,7 and 2,3,4,5.
We use jnp.take in the load balancing, which hopefully can be improved and is being tracked in b/413770626.
FIXES: b/377904983
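As a minimal sketch of that reordering (the load_balance_reorder helper and the exact chunking below are illustrative assumptions, not MaxText's code), the sequence can be split into 2 * cp chunks, with shard i taking chunk i together with chunk 2 * cp - 1 - i, so every shard mixes cheap early rows and expensive late rows of the causal mask:

```python
# Illustrative load-balanced reordering via jnp.take. For cp=2 and 8 tokens,
# shard 0 ends up with tokens [0, 1, 6, 7] and shard 1 with [2, 3, 4, 5],
# matching the example above.
import jax.numpy as jnp

def load_balance_reorder(x: jnp.ndarray, cp: int, seq_axis: int = 1) -> jnp.ndarray:
  """Reorders tokens along seq_axis so each of `cp` shards gets balanced work."""
  seq_len = x.shape[seq_axis]
  assert seq_len % (2 * cp) == 0
  chunk = seq_len // (2 * cp)
  order = []
  for i in range(cp):
    order.extend(range(i * chunk, (i + 1) * chunk))                      # chunk i
    order.extend(range((2 * cp - 1 - i) * chunk, (2 * cp - i) * chunk))  # chunk 2*cp-1-i
  return jnp.take(x, jnp.asarray(order), axis=seq_axis)

tokens = jnp.arange(8)[None, :]            # [batch=1, seq=8]
print(load_balance_reorder(tokens, cp=2))  # [[0 1 6 7 2 3 4 5]]
```

The same permutation is applied to the attention mask, and its inverse (e.g. via argsort of the index vector) restores the original token order after attention.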
Tests
Tested locally on v5p-8, and also on v6e-256 for Llama3.1-70b with a context length of 131072 tokens, running python3 -m benchmarks.maxtext_xpk_runner with llama3_1_70b_131072 from maxtext/benchmarks/maxtext_trillium_model_configs.py.
Also added a unit test for context parallelism in MaxText/tests/attention_test.py.
Updated maxtext/end_to_end/tpu/llama2/7b/test_llama2_7b.sh to double as an integration test for context parallelism for training in train.py via ici_context_parallelism. Also used ici_context_parallelism for forward_pass_logit_checker to test the forward pass, since we can't use flash attention in decode.py.
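As a side note, a hypothetical round-trip check in the spirit of that unit test (the helper below is an assumption, not the actual attention_test.py addition) is to verify that the load-balanced order is an invertible permutation, so outputs can be restored to the original token order:

```python
# Hypothetical test sketch: the load-balanced order round-trips under its
# inverse permutation (names and chunking are illustrative assumptions).
import numpy as np
import jax.numpy as jnp

def load_balance_order(seq_len: int, cp: int) -> np.ndarray:
  chunk = seq_len // (2 * cp)
  order = []
  for i in range(cp):
    order.extend(range(i * chunk, (i + 1) * chunk))
    order.extend(range((2 * cp - 1 - i) * chunk, (2 * cp - i) * chunk))
  return np.asarray(order)

def test_load_balance_roundtrip():
  seq_len, cp = 16, 4
  order = load_balance_order(seq_len, cp)
  inverse = np.argsort(order)           # inverse permutation
  tokens = jnp.arange(seq_len)
  reordered = jnp.take(tokens, jnp.asarray(order))
  restored = jnp.take(reordered, jnp.asarray(inverse))
  np.testing.assert_array_equal(np.asarray(restored), np.asarray(tokens))
```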
Checklist
Before submitting this PR, please make sure (put X in square brackets):