
Add support for SPMD PP for deepseek decoder block #1687


Open

gobbleturk wants to merge 1 commit into main from mattdavidow-pipeline-deepseek

Conversation

Collaborator

@gobbleturk gobbleturk commented May 5, 2025

Description

Add support for using PP with DeepSeek, including the new pipeline_parallel_layers feature, which pipelines only a subset of layers. This helps with SPMD pipelining because the PP degree must divide the number of pipelined layers; e.g. DeepSeek has 58 sparse layers, and 58 does not have many friendly divisors.

With this PR we can pipeline just a subset of the sparse layers, e.g. 56 of them with PP=8. The remaining layers are sharded with PP acting as DP.
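As a concrete illustration of the divisibility constraint (a minimal sketch; the helper name is hypothetical and only pipeline_parallel_layers is a real config option):

# Hypothetical helper: pick the largest layer count <= num_sparse_layers
# that the PP degree divides evenly.
def choose_pipelined_layers(num_sparse_layers: int, pp: int) -> int:
  return (num_sparse_layers // pp) * pp

# DeepSeek example from this PR: 58 sparse layers with PP=8 -> pipeline 56 of
# them, and run the remaining 2 outside the pipeline with PP acting as DP.
assert choose_pipelined_layers(58, 8) == 56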

Tests

Ran some tests locally with the smaller DeepSeek v2-16B model and added an AOT test - I will paste xprofs in a bit.

  • DeepSeekV2 16B run locally on v6e-8 with pdb=1, PP=8, pipeline_subset_layers=24: xprof
  • With pure FSDP, the step time is roughly 2x: xprof

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

Collaborator

@bvandermoon bvandermoon left a comment

Quick question to help my understanding while reviewing - why can we only pipeline 56 of the 58 sparse layers for deepseek?

@gobbleturk gobbleturk force-pushed the mattdavidow-pipeline-deepseek branch from 23fd9a9 to 03e1524 on May 5, 2025 22:34
@gobbleturk gobbleturk force-pushed the mattdavidow-pipeline-deepseek branch from 03e1524 to ee35916 on May 5, 2025 22:35
@gobbleturk
Collaborator Author

Quick question to help my understanding while reviewing - why can we only pipeline 56 of the 58 sparse layers for deepseek?

We can pipeline all 58, but the PP degree must divide the number of pipelined layers, and 58 isn't a particularly divisor-friendly number. I updated this in the PR description.

@gobbleturk gobbleturk force-pushed the mattdavidow-pipeline-deepseek branch from ee35916 to 46bbf6d on May 5, 2025 22:43
Collaborator

@bvandermoon bvandermoon left a comment

LGTM, just one general question for my understanding

Comment on lines +482 to +497
    y, _ = self.scan_decoder_layers(cfg, dense_layer, cfg.first_num_dense_layers, "dense_layers", mesh)(
        y,
        decoder_segment_ids,
        decoder_positions,
        deterministic,
        model_mode,
    )
    if num_moe_layers_outside_pp > 0:
      y, _ = self.scan_decoder_layers(cfg, moe_layer, num_moe_layers_outside_pp, "moe_layers", mesh)(
          y,
          decoder_segment_ids,
          decoder_positions,
          deterministic,
          model_mode,
      )
    y = self.pipeline_module(y, decoder_segment_ids, decoder_positions, deterministic, model_mode, partition_spec=partition_spec)
Collaborator

@bvandermoon bvandermoon May 6, 2025

Just for my understanding, how are the 56 layers that need to be pipelined being specified here? Is it just because they aren't being used in self.scan_decoder_layers?

Collaborator Author

56 was an example; the count can be set via pipeline_parallel_layers, which was added in a previous PR.
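For context, a rough sketch of the arithmetic connecting the pipeline_parallel_layers config value to num_moe_layers_outside_pp in the snippet above (illustrative numbers taken from this thread, not the PR's exact code):

# Illustrative only: DS-v3-style layer counts mentioned in this thread.
num_decoder_layers = 61        # total decoder layers
first_num_dense_layers = 3     # dense layers at the start (61 - 58 MoE)
pipeline_parallel_layers = 56  # set via config

num_moe_layers = num_decoder_layers - first_num_dense_layers           # 58
num_moe_layers_outside_pp = num_moe_layers - pipeline_parallel_layers  # 2 layers scanned outside the pipeline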

@gobbleturk gobbleturk force-pushed the mattdavidow-pipeline-deepseek branch from 46bbf6d to 14300d9 on May 7, 2025 17:59
Collaborator

@bvandermoon bvandermoon left a comment

LGTM pending the test failures

Collaborator

@RissyRan RissyRan left a comment

Thanks Matt! One more comment: can we add a unit test in moe_test.py like def test_megablox_expert_parallelism(self), to ensure PP works as expected with assertions?

    remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers
    if remaining_layers > 0:
      logical_axis_rules_pp_as_dp = maxtext_utils.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules)
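For readers unfamiliar with that helper, a rough sketch of what a "PP acts as DP" remapping of logical axis rules could look like (the rule format and the "stage"/"data" axis names here are assumptions about the mesh setup, not the actual maxtext_utils implementation):

# Assumed rule format: (logical_axis_name, (mesh_axis, ...)).
# Replace the pipeline "stage" mesh axis with the "data" axis so layers outside
# the pipeline reuse the PP devices for data parallelism.
def pp_act_as_dp_sketch(logical_axis_rules):
  remapped = []
  for logical_name, mesh_axes in logical_axis_rules:
    mesh_axes = tuple("data" if axis == "stage" else axis for axis in mesh_axes)
    remapped.append((logical_name, mesh_axes))
  return remapped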
Collaborator

Do you think we could add some flexibility here? Instead of acting as DP, could we act based on a passed-in config (not sure if doable)? DP may not be helpful.

For instance, for DS-v3 (61 layers, 58 of them MoE) with pipeline_parallel_layers=56, could we run FSDP or EP via the config below?

pipeline_parallel_layers=56 && ici_fsdp=-1/ici_ep=-1?

Collaborator Author

With pipeline_parallel_layers=56 there are only 2 MoE layers whose PP is replaced with DP. These two layers will still be sharded by the other sharding strategies - e.g. if the config is EP_ICI=16, FSDP_ICI=16, PP_DCN=16, these two layers will be sharded as EP_ICI=16, FSDP_ICI=16, DP_DCN=16, i.e. the weights are still sharded 256 ways.

I do like the idea of flexibility, but perhaps we can save this as potential future work?
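To make that example concrete, a tiny arithmetic sketch of the weight shard count for those two non-pipelined layers (numbers taken from the example above):

# For the 2 MoE layers outside the pipeline, PP over DCN is swapped for DP,
# but EP and FSDP over ICI still shard the weights, so the weight shard count
# is unchanged.
ep_ici, fsdp_ici = 16, 16
weight_shards = ep_ici * fsdp_ici  # 256-way weight sharding, with or without PP on DCN
assert weight_shards == 256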

Collaborator

these two layers will be sharded as EP_ICI=16, FSDP_ICI=16, DP_DCN=16, i.e. the weights are still sharded 256 ways

SG! I thought they would only have DP sharding.
