@@ -208,7 +208,7 @@ def setup(self):
208
208
self .decoder_layer = self .get_decoder_layers ()
209
209
self .norm_layer = self .get_norm_layer ()
210
210
if self .config .using_pipeline_parallelism :
211
- pipeline_stage_module = self .get_pipeline_stage_module (self .decoder_layer [ 0 ] )
211
+ pipeline_stage_module = self .get_pipeline_stage_module (self .decoder_layer )
212
212
remat_policy = self .get_remat_policy ()
213
213
self .pipeline_module = pipeline .Pipeline (
214
214
config = self .config , mesh = self .mesh , layers = pipeline_stage_module , remat_policy = remat_policy
@@ -397,8 +397,15 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metdata_axis_name, mes
397
397
)
398
398
return scan_fn (config = cfg , mesh = mesh , name = metdata_axis_name , quant = self .quant )
399
399
400
- def get_pipeline_stage_module (self , base_stage ):
400
+ def get_pipeline_stage_module (self , decoder_blocks ):
401
+ def get_layer_to_pipeline (blocks , cfg ):
402
+ if cfg .decoder_block == DecoderBlockType .DEEPSEEK :
403
+ return blocks [1 ] # return the sparse block
404
+ else :
405
+ return blocks [0 ]
401
406
cfg = self .config
407
+ base_stage = get_layer_to_pipeline (decoder_blocks , cfg )
408
+
402
409
if cfg .set_remat_policy_on_layers_per_stage :
403
410
policy = self .get_remat_policy ()
404
411
base_stage = self .set_remat_policy ([base_stage ], policy )[0 ]
@@ -463,20 +470,46 @@ def __call__(
463
470
)
464
471
else :
465
472
partition_spec = None # This partition spec is only used for the fsdp_ag_once feature.
466
- y = self .pipeline_module (
467
- y , decoder_segment_ids , decoder_positions , deterministic , model_mode , partition_spec = partition_spec
468
- )
469
- remaining_layers = self .config .num_decoder_layers - self .config .pipeline_parallel_layers
470
- if remaining_layers > 0 :
473
+ if cfg .decoder_block == DecoderBlockType .DEEPSEEK :
474
+ assert len (RemattedBlockLayers ) == 2 , f"Scanned layers must have a length of 2 using deepseek."
475
+ dense_layer = RemattedBlockLayers [0 ]
476
+ moe_layer = RemattedBlockLayers [1 ]
477
+ num_moe_layers = cfg .num_decoder_layers - cfg .first_num_dense_layers
478
+ num_moe_layers_outside_pp = num_moe_layers - self .config .pipeline_parallel_layers
471
479
logical_axis_rules_pp_as_dp = maxtext_utils .logical_axis_rules_pp_act_as_dp (self .config .logical_axis_rules )
480
+ # We chose not to pipeline the dense layers, only sparse for SPMD.
472
481
with self .mesh , nn .partitioning .axis_rules (logical_axis_rules_pp_as_dp ):
473
- y , _ = self .scan_decoder_layers (cfg , RemattedBlockLayers [ 0 ], remaining_layers , "layers " , mesh )(
482
+ y , _ = self .scan_decoder_layers (cfg , dense_layer , cfg . first_num_dense_layers , "dense_layers " , mesh )(
474
483
y ,
475
484
decoder_segment_ids ,
476
485
decoder_positions ,
477
486
deterministic ,
478
487
model_mode ,
479
488
)
489
+ if num_moe_layers_outside_pp > 0 :
490
+ y , _ = self .scan_decoder_layers (cfg , moe_layer , num_moe_layers_outside_pp , "moe_layers" , mesh )(
491
+ y ,
492
+ decoder_segment_ids ,
493
+ decoder_positions ,
494
+ deterministic ,
495
+ model_mode ,
496
+ )
497
+ y = self .pipeline_module (y , decoder_segment_ids , decoder_positions , deterministic , model_mode , partition_spec = partition_spec )
498
+ else :
499
+ y = self .pipeline_module (
500
+ y , decoder_segment_ids , decoder_positions , deterministic , model_mode , partition_spec = partition_spec
501
+ )
502
+ remaining_layers = self .config .num_decoder_layers - self .config .pipeline_parallel_layers
503
+ if remaining_layers > 0 :
504
+ logical_axis_rules_pp_as_dp = maxtext_utils .logical_axis_rules_pp_act_as_dp (self .config .logical_axis_rules )
505
+ with self .mesh , nn .partitioning .axis_rules (logical_axis_rules_pp_as_dp ):
506
+ y , _ = self .scan_decoder_layers (cfg , RemattedBlockLayers [0 ], remaining_layers , "layers" , mesh )(
507
+ y ,
508
+ decoder_segment_ids ,
509
+ decoder_positions ,
510
+ deterministic ,
511
+ model_mode ,
512
+ )
480
513
else :
481
514
if cfg .scan_layers :
482
515
if cfg .decoder_block == DecoderBlockType .DEEPSEEK :
0 commit comments