Skip to content

Commit 46bbf6d

Browse files
committed
Add support for pipelining deepseek
1 parent 6247f44 commit 46bbf6d

File tree

3 files changed

+65
-10
lines changed

3 files changed

+65
-10
lines changed

MaxText/layers/models.py

+41-8
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def setup(self):
208208
self.decoder_layer = self.get_decoder_layers()
209209
self.norm_layer = self.get_norm_layer()
210210
if self.config.using_pipeline_parallelism:
211-
pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer[0])
211+
pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer)
212212
remat_policy = self.get_remat_policy()
213213
self.pipeline_module = pipeline.Pipeline(
214214
config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy
@@ -397,8 +397,15 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metdata_axis_name, mes
397397
)
398398
return scan_fn(config=cfg, mesh=mesh, name=metdata_axis_name, quant=self.quant)
399399

400-
def get_pipeline_stage_module(self, base_stage):
400+
def get_pipeline_stage_module(self, decoder_blocks):
401+
def get_layer_to_pipeline(blocks, cfg):
402+
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
403+
return blocks[1] # return the sparse block
404+
else:
405+
return blocks[0]
401406
cfg = self.config
407+
base_stage = get_layer_to_pipeline(decoder_blocks, cfg)
408+
402409
if cfg.set_remat_policy_on_layers_per_stage:
403410
policy = self.get_remat_policy()
404411
base_stage = self.set_remat_policy([base_stage], policy)[0]
@@ -463,20 +470,46 @@ def __call__(
463470
)
464471
else:
465472
partition_spec = None # This partition spec is only used for the fsdp_ag_once feature.
466-
y = self.pipeline_module(
467-
y, decoder_segment_ids, decoder_positions, deterministic, model_mode, partition_spec=partition_spec
468-
)
469-
remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers
470-
if remaining_layers > 0:
473+
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
474+
assert len(RemattedBlockLayers) == 2, f"Scanned layers must have a length of 2 using deepseek."
475+
dense_layer = RemattedBlockLayers[0]
476+
moe_layer = RemattedBlockLayers[1]
477+
num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
478+
num_moe_layers_outside_pp = num_moe_layers - self.config.pipeline_parallel_layers
471479
logical_axis_rules_pp_as_dp = maxtext_utils.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules)
480+
# We chose not to pipeline the dense layers; only the sparse layers are pipelined for SPMD.
472481
with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp):
473-
y, _ = self.scan_decoder_layers(cfg, RemattedBlockLayers[0], remaining_layers, "layers", mesh)(
482+
y, _ = self.scan_decoder_layers(cfg, dense_layer, cfg.first_num_dense_layers, "dense_layers", mesh)(
474483
y,
475484
decoder_segment_ids,
476485
decoder_positions,
477486
deterministic,
478487
model_mode,
479488
)
489+
if num_moe_layers_outside_pp > 0:
490+
y, _ = self.scan_decoder_layers(cfg, moe_layer, num_moe_layers_outside_pp, "moe_layers", mesh)(
491+
y,
492+
decoder_segment_ids,
493+
decoder_positions,
494+
deterministic,
495+
model_mode,
496+
)
497+
y = self.pipeline_module(y, decoder_segment_ids, decoder_positions, deterministic, model_mode, partition_spec=partition_spec)
498+
else:
499+
y = self.pipeline_module(
500+
y, decoder_segment_ids, decoder_positions, deterministic, model_mode, partition_spec=partition_spec
501+
)
502+
remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers
503+
if remaining_layers > 0:
504+
logical_axis_rules_pp_as_dp = maxtext_utils.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules)
505+
with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp):
506+
y, _ = self.scan_decoder_layers(cfg, RemattedBlockLayers[0], remaining_layers, "layers", mesh)(
507+
y,
508+
decoder_segment_ids,
509+
decoder_positions,
510+
deterministic,
511+
model_mode,
512+
)
480513
else:
481514
if cfg.scan_layers:
482515
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:

MaxText/pyconfig.py

-2
Original file line numberDiff line numberDiff line change
@@ -850,8 +850,6 @@ def pipeline_first_axis(raw_keys):
850850

851851

852852
def validate_deepseek_moe(raw_keys):
853-
if raw_keys["decoder_block"] == "deepseek" and using_pipeline_parallelism(raw_keys):
854-
raise ValueError("Currently we do not support DeepSeek MoE with pipeline parallelism.")
855853
if raw_keys["n_routing_groups"] != -1:
856854
if raw_keys["topk_routing_group"] == -1:
857855
raise ValueError(f'config topk_routing_group: {raw_keys["topk_routing_group"]} is not defined')

MaxText/tests/train_compile_test.py

+24
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,30 @@ def test_moe_deepseek_without_device_limit(self):
559559
)
560560
)
561561

562+
563+
@pytest.mark.tpu_only
564+
def test_moe_deepseek_pipeline_subset(self):
565+
compiled_trainstep_file = "/tmp/test_moe_deepseek_pipeline_subset.pickle"
566+
train_compile_main(
567+
(
568+
None,
569+
os.path.join(PKG_DIR, "configs", "base.yml"),
570+
f"compiled_trainstep_file={compiled_trainstep_file}",
571+
"compile_topology=v6e-256",
572+
"compile_topology_num_slices=8",
573+
"use_iota_embed=true",
574+
"model_name=deepseek3-671b",
575+
"sparse_matmul=False",
576+
"megablox=False",
577+
"capacity_factor=1",
578+
"per_device_batch_size=1",
579+
"max_target_length=2048",
580+
"pipeline_parallel_layers=56",
581+
"ici_expert_parallelism=16",
582+
"dcn_pipeline_parallelism=8"
583+
)
584+
)
585+
562586
@pytest.mark.skip(reason="b/415132665: Enable it once scan is supported in training for shorter compiler time")
563587
@pytest.mark.tpu_only
564588
def test_moe_llama4_17b_16e(self):

0 commit comments

Comments
 (0)