@@ -189,45 +189,6 @@ def _allreduce_embedding_grads(model: List[torch.nn.Module], config: Transformer
     _allreduce_position_embedding_grads(model, config)
 
 
-def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerConfig):
-    """
-    All-reduce layernorm grads (for sequence parallelism).
-    """
-
-    # All-reduce layernorm parameters across model parallel nodes
-    # when sequence parallelism is used
-    if parallel_state.get_tensor_model_parallel_world_size() > 1 and (
-        config.sequence_parallel or config.qk_layernorm
-    ):
-        params = []
-        grads = []
-        for model_chunk in model:
-            ddp_config = model_chunk.ddp_config
-            for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
-                if param.requires_grad and (
-                    getattr(param, 'sequence_parallel', False)
-                    or 'q_layernorm' in name
-                    or 'k_layernorm' in name
-                ):
-                    params.append(param)
-                    grad_attr = _get_main_grad_attr(param, ddp_config.use_custom_fsdp)
-                    grad = getattr(param, grad_attr)
-                    grad = _unshard_if_dtensor(grad)
-                    grads.append(grad.data)
-        if grads:
-            coalesced = _flatten_dense_tensors(grads)
-            torch.distributed.all_reduce(
-                coalesced, group=parallel_state.get_tensor_model_parallel_group()
-            )
-            for param, buf, synced in zip(
-                params, grads, _unflatten_dense_tensors(coalesced, grads)
-            ):
-                buf.copy_(synced)
-                grad_attr = _get_main_grad_attr(param, ddp_config.use_custom_fsdp)
-                orig_grad = getattr(param, grad_attr)
-                setattr(param, grad_attr, _reshard_if_dtensor(buf, orig_grad))
-
-
 def _update_router_expert_bias(model: List[torch.nn.Module], config: TransformerConfig):
     """
     Update the expert bias of the router for a global batch.
@@ -256,6 +217,70 @@ def _update_router_expert_bias(model: List[torch.nn.Module], config: Transformer
             expert_bias.copy_(updated_expert_bias)
 
 
+def _allreduce_non_tensor_model_parallel_grads(
+    model: List[torch.nn.Module], config: TransformerConfig
+):
+    """
+    All-reduce both layernorm grads (for sequence parallelism) and
+    gradients from modules with average_gradients_across_tp_domain=True
+    across tensor-model-parallel ranks.
+    """
+    if parallel_state.get_tensor_model_parallel_world_size() <= 1:
+        return
+
+    params_sum = []
+    grads_sum = []
+    params_avg = []
+    grads_avg = []
+
+    for model_chunk in model:
+        ddp_config = model_chunk.ddp_config
+        for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
+            if param.requires_grad:
+                # Check if this param needs average reduction (average_gradients_across_tp_domain)
+                if getattr(param, "average_gradients_across_tp_domain", False):
+                    params_avg.append(param)
+                    grad_attr = _get_main_grad_attr(param, ddp_config.use_custom_fsdp)
+                    grad = getattr(param, grad_attr)
+                    grad = _unshard_if_dtensor(grad)
+                    grads_avg.append(grad.data)
+                # Check if this param needs sum reduction (sequence parallel or qk_layernorm)
+                elif (config.sequence_parallel and getattr(param, "sequence_parallel", False)) or (
+                    config.qk_layernorm and ("q_layernorm" in name or "k_layernorm" in name)
+                ):
+                    params_sum.append(param)
+                    grad_attr = _get_main_grad_attr(param, ddp_config.use_custom_fsdp)
+                    grad = getattr(param, grad_attr)
+                    grad = _unshard_if_dtensor(grad)
+                    grads_sum.append(grad.data)
+
+    # Loop grads and perform correct all-reduce
+    for params, grads, all_reduce_op in zip(
+        [params_sum, params_avg],
+        [grads_sum, grads_avg],
+        [torch.distributed.ReduceOp.SUM, torch.distributed.ReduceOp.AVG],
+    ):
+        if grads:
+            coalesced = _flatten_dense_tensors(grads)
+            torch.distributed.all_reduce(
+                coalesced, op=all_reduce_op, group=parallel_state.get_tensor_model_parallel_group()
+            )
+            for param, buf, synced in zip(
+                params, grads, _unflatten_dense_tensors(coalesced, grads)
+            ):
+                buf.copy_(synced)
+                grad_attr = _get_main_grad_attr(param, ddp_config.use_custom_fsdp)
+                orig_grad = getattr(param, grad_attr)
+                setattr(param, grad_attr, _reshard_if_dtensor(buf, orig_grad))
+
+
+"""
+This is an alias to _allreduce_non_tensor_model_parallel_grads that we must
+maintain for legacy tests. We can remove this proxy in mcore 0.14.
+"""
+_allreduce_layernorm_grads = _allreduce_non_tensor_model_parallel_grads
+
+
 def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None):
     """
     All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
@@ -282,14 +307,14 @@ def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torc
     if config.timers is not None:
         config.timers('conditional-embedder-grads-all-reduce').stop()
 
-    # All-reduce layer-norm grads (for sequence parallelism).
+    # All-reduce layer-norm grads (for sequence parallelism) and non-tensor parallel modules.
     if config.timers is not None:
-        config.timers('layernorm-grads-all-reduce', log_level=1).start(
+        config.timers('non-tensor-parallel-grads-all-reduce', log_level=1).start(
             barrier=config.barrier_with_L1_time
         )
-    _allreduce_layernorm_grads(model, config)
+    _allreduce_non_tensor_model_parallel_grads(model, config)
     if config.timers is not None:
-        config.timers('layernorm-grads-all-reduce').stop()
+        config.timers('non-tensor-parallel-grads-all-reduce').stop()
 
     # All-reduce embedding grads (for pipeline parallelism).
     if config.timers is not None:
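
Note on the selection logic in _allreduce_non_tensor_model_parallel_grads: parameters that opt in via the average_gradients_across_tp_domain attribute land in the AVG bucket, while sequence-parallel and q_layernorm/k_layernorm parameters land in the SUM bucket. The snippet below is a minimal, hypothetical sketch of how a module could tag its parameters to take the AVG path; the LayerNorm module and its size are stand-ins for illustration, not part of this change.

import torch

# Stand-in for a module whose weights are replicated (not sharded) across
# tensor-model-parallel ranks and whose grads should therefore be averaged.
norm = torch.nn.LayerNorm(128)
for param in norm.parameters():
    # Attribute read by the new helper via
    # getattr(param, "average_gradients_across_tp_domain", False)
    setattr(param, "average_gradients_across_tp_domain", True)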
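Both the removed _allreduce_layernorm_grads and the new helper share the same coalesced reduction pattern: flatten a bucket of grads into one buffer, all-reduce it with a single collective, unflatten, and copy the results back in place. The following is a self-contained sketch of that pattern only, run as a single gloo process for illustration; the real code reduces over the tensor-model-parallel group and uses ReduceOp.AVG (generally NCCL-only) for the averaged bucket.

import os
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Single-process setup purely so the collective call can run.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# Stand-ins for per-parameter gradient buffers of different shapes.
grads = [torch.ones(4), torch.full((2, 3), 2.0)]

coalesced = _flatten_dense_tensors(grads)          # one contiguous buffer
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM)   # one collective instead of N
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)                              # write reduced values back in place

dist.destroy_process_group()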