
Commit a165235

Merge branch 'hf-diverge-fix' into 'main'

Mark weights from vision encoder to be non-tensor-parallelizable to ensure gradients are correctly all-reduced.

See merge request ADLR/megatron-lm!3190

2 parents dc8372b + 1674ce3 · commit a165235

File tree

3 files changed: +85 -45 lines changed


megatron/core/distributed/finalize_model_grads.py

Lines changed: 68 additions & 43 deletions

@@ -189,45 +189,6 @@ def _allreduce_embedding_grads(model: List[torch.nn.Module], config: Transformer
     _allreduce_position_embedding_grads(model, config)


-def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerConfig):
-    """
-    All-reduce layernorm grads (for sequence parallelism).
-    """
-
-    # All-reduce layernorm parameters across model parallel nodes
-    # when sequence parallelism is used
-    if parallel_state.get_tensor_model_parallel_world_size() > 1 and (
-        config.sequence_parallel or config.qk_layernorm
-    ):
-        params = []
-        grads = []
-        for model_chunk in model:
-            ddp_config = model_chunk.ddp_config
-            for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
-                if param.requires_grad and (
-                    getattr(param, 'sequence_parallel', False)
-                    or 'q_layernorm' in name
-                    or 'k_layernorm' in name
-                ):
-                    params.append(param)
-                    grad_attr = _get_main_grad_attr(param, ddp_config.use_custom_fsdp)
-                    grad = getattr(param, grad_attr)
-                    grad = _unshard_if_dtensor(grad)
-                    grads.append(grad.data)
-        if grads:
-            coalesced = _flatten_dense_tensors(grads)
-            torch.distributed.all_reduce(
-                coalesced, group=parallel_state.get_tensor_model_parallel_group()
-            )
-            for param, buf, synced in zip(
-                params, grads, _unflatten_dense_tensors(coalesced, grads)
-            ):
-                buf.copy_(synced)
-                grad_attr = _get_main_grad_attr(param, ddp_config.use_custom_fsdp)
-                orig_grad = getattr(param, grad_attr)
-                setattr(param, grad_attr, _reshard_if_dtensor(buf, orig_grad))
-
-
 def _update_router_expert_bias(model: List[torch.nn.Module], config: TransformerConfig):
     """
     Update the expert bias of the router for a global batch.
@@ -256,6 +217,70 @@ def _update_router_expert_bias(model: List[torch.nn.Module], config: Transformer
         expert_bias.copy_(updated_expert_bias)


+def _allreduce_non_tensor_model_parallel_grads(
+    model: List[torch.nn.Module], config: TransformerConfig
+):
+    """
+    All-reduce both layernorm grads (for sequence parallelism) and
+    gradients from modules with average_gradients_across_tp_domain=True
+    across tensor-model-parallel ranks.
+    """
+    if parallel_state.get_tensor_model_parallel_world_size() <= 1:
+        return
+
+    params_sum = []
+    grads_sum = []
+    params_avg = []
+    grads_avg = []
+
+    for model_chunk in model:
+        ddp_config = model_chunk.ddp_config
+        for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
+            if param.requires_grad:
+                # Check if this param needs average reduction (average_gradients_across_tp_domain)
+                if getattr(param, "average_gradients_across_tp_domain", False):
+                    params_avg.append(param)
+                    grad_attr = _get_main_grad_attr(param, ddp_config.use_custom_fsdp)
+                    grad = getattr(param, grad_attr)
+                    grad = _unshard_if_dtensor(grad)
+                    grads_avg.append(grad.data)
+                # Check if this param needs sum reduction (sequence parallel or qk_layernorm)
+                elif (config.sequence_parallel and getattr(param, "sequence_parallel", False)) or (
+                    config.qk_layernorm and ("q_layernorm" in name or "k_layernorm" in name)
+                ):
+                    params_sum.append(param)
+                    grad_attr = _get_main_grad_attr(param, ddp_config.use_custom_fsdp)
+                    grad = getattr(param, grad_attr)
+                    grad = _unshard_if_dtensor(grad)
+                    grads_sum.append(grad.data)
+
+    # Loop grads and perform correct all-reduce
+    for params, grads, all_reduce_op in zip(
+        [params_sum, params_avg],
+        [grads_sum, grads_avg],
+        [torch.distributed.ReduceOp.SUM, torch.distributed.ReduceOp.AVG],
+    ):
+        if grads:
+            coalesced = _flatten_dense_tensors(grads)
+            torch.distributed.all_reduce(
+                coalesced, op=all_reduce_op, group=parallel_state.get_tensor_model_parallel_group()
+            )
+            for param, buf, synced in zip(
+                params, grads, _unflatten_dense_tensors(coalesced, grads)
+            ):
+                buf.copy_(synced)
+                grad_attr = _get_main_grad_attr(param, ddp_config.use_custom_fsdp)
+                orig_grad = getattr(param, grad_attr)
+                setattr(param, grad_attr, _reshard_if_dtensor(buf, orig_grad))
+
+
+"""
+This is an alias to _allreduce_non_tensor_model_parallel_grads that we must
+maintain for legacy tests. We can remove this proxy in mcore 0.14.
+"""
+_allreduce_layernorm_grads = _allreduce_non_tensor_model_parallel_grads
+
+
 def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None):
     """
     All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
@@ -282,14 +307,14 @@ def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torc
     if config.timers is not None:
         config.timers('conditional-embedder-grads-all-reduce').stop()

-    # All-reduce layer-norm grads (for sequence parallelism).
+    # All-reduce layer-norm grads (for sequence parallelism) and non-tensor parallel modules.
     if config.timers is not None:
-        config.timers('layernorm-grads-all-reduce', log_level=1).start(
+        config.timers('non-tensor-parallel-grads-all-reduce', log_level=1).start(
             barrier=config.barrier_with_L1_time
         )
-    _allreduce_layernorm_grads(model, config)
+    _allreduce_non_tensor_model_parallel_grads(model, config)
     if config.timers is not None:
-        config.timers('layernorm-grads-all-reduce').stop()
+        config.timers('non-tensor-parallel-grads-all-reduce').stop()

     # All-reduce embedding grads (for pipeline parallelism).
     if config.timers is not None:
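For intuition, here is a minimal single-process sketch (not part of the diff; the tensors and the two simulated "ranks" are made up) of why the two buckets use different reduce ops: replicated non-tensor-parallel weights receive nearly identical but not bit-identical gradients on each TP rank, so averaging keeps the replicas in lockstep, while sequence-parallel layernorm gradients are partial contributions that must be summed.

# Hypothetical single-process illustration (plain PyTorch, no distributed setup required).
import torch

# A replicated (non-tensor-parallel) weight receives slightly different grads on two TP ranks
# because of non-determinism in the HF vision encoder.
grad_rank0 = torch.tensor([1.00, 2.00])
grad_rank1 = torch.tensor([1.02, 1.98])

# ReduceOp.AVG semantics: both ranks end up with the mean gradient, so the replicated
# copies of the weight stay identical after the optimizer step.
avg = (grad_rank0 + grad_rank1) / 2        # tensor([1.0100, 1.9900])

# ReduceOp.SUM semantics: sequence-parallel layernorm grads are partial sums over
# sequence shards, so the full gradient is the sum of the per-rank pieces.
partial_rank0 = torch.tensor([0.5, 0.5])
partial_rank1 = torch.tensor([0.5, 0.5])
total = partial_rank0 + partial_rank1      # tensor([1., 1.])

print(avg, total)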

megatron/core/models/huggingface/module.py

Lines changed: 15 additions & 0 deletions

@@ -1,5 +1,6 @@
 # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

+import torch
 from transformers import AutoConfig, AutoModel

 from megatron.core.transformer.module import MegatronModule
@@ -17,6 +18,20 @@ def set_input_tensor(self, input_tensor):
         """Dummy function for set_input_tensor"""
         self.input_tensor = input_tensor

+    def __setattr__(self, name: str, value):
+        """
+        Set average_gradients_across_tp_domain attribute true on all params so that during
+        finalize_model_grads an all-reduce is performed on this module's gradients across
+        tensor parallel ranks. This keeps replicated weights synchronized and prevents drift
+        due to non determinism in HF models producing slightly different grads in replicated
+        models on the same inputs.
+        """
+        super().__setattr__(name, value)
+
+        if isinstance(value, torch.nn.Module):
+            for param in value.parameters(recurse=True):
+                setattr(param, "average_gradients_across_tp_domain", True)
+

 class AutoHuggingFaceModel(HuggingFaceModule):
     """

tests/unit_tests/distributed/test_finalize_model_grads.py

Lines changed: 2 additions & 2 deletions

@@ -9,7 +9,7 @@
 from megatron.core import parallel_state
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.distributed.finalize_model_grads import (
-    _allreduce_layernorm_grads,
+    _allreduce_non_tensor_model_parallel_grads,
     _allreduce_word_embedding_grads,
 )
 from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
@@ -67,7 +67,7 @@ def test_allreduce_layernorm_grads(self, freeze_model, tp_size):
            else:
                param.grad = torch.ones_like(param)

-        _allreduce_layernorm_grads([self.model], self.transformer_config)
+        _allreduce_non_tensor_model_parallel_grads([self.model], self.transformer_config)

     @pytest.mark.parametrize(
         ("freeze_model", "pp_size", "share_embeddings")
