
Commit 0d90675

Fix TP, enable test, silence noisy logs (#2761)
1 parent fa92c96 commit 0d90675

File tree

4 files changed, +12 -6 lines changed


recipes/full_finetune_distributed.py

Lines changed: 5 additions & 1 deletion
@@ -231,7 +231,11 @@ def __init__(self, cfg: DictConfig) -> None:
         self._activation_offloading_use_streams = cfg.get(
             "activation_offloading_use_streams", True
         )
-        if self._activation_offloading_use_streams and self.parallel_dims.tp_enabled:
+        if (
+            self._enable_activation_offloading
+            and self._activation_offloading_use_streams
+            and self.parallel_dims.tp_enabled
+        ):
             warn(
                 message=(
                     "Using activation offloading with streams is not advised in tensor parallel, and may "

tests/recipes/test_full_finetune_distributed.py

Lines changed: 3 additions & 1 deletion
@@ -11,6 +11,7 @@

 import pytest
 import torch
+from packaging import version
 from tests.common import TUNE_PATH

 from tests.recipes.utils import (
@@ -130,7 +131,8 @@ def test_loss(
         )

     @pytest.mark.skipif(
-        torch.__version__ < "2.8.0", reason="2D parallel test requires PyTorch >= 2.8"
+        version.parse(torch.__version__).base_version < "2.8.0",
+        reason="2D parallel test requires PyTorch >= 2.8",
     )
     @pytest.mark.integration_test
     @pytest.mark.parametrize(
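This is the "enable test" part of the commit: under version-aware ordering, a 2.8 pre-release or source build (e.g. `2.8.0a0+git...`) sorts before `2.8.0`, so the old condition could skip the 2D parallel test on builds that support it. Taking `base_version` strips the pre-release/local suffix first. A small illustration of the `packaging.version` semantics involved (the version string is illustrative, not from the commit):

from packaging import version

nightly = version.parse("2.8.0a0+gitabc123")  # illustrative source/nightly build

# PEP 440: pre-releases sort before the final release, so this build looks
# "older" than 2.8.0 and would have matched the old skip condition.
print(nightly < version.parse("2.8.0"))        # True

# base_version drops pre-release and local segments, keeping only the
# release number, so 2.8 nightlies count as 2.8.0.
print(nightly.base_version)                    # "2.8.0"
print(nightly.base_version < "2.8.0")          # False -> test is no longer skipped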

torchtune/models/llama3/_parallelism.py

Lines changed: 2 additions & 2 deletions
@@ -37,8 +37,8 @@ def _get_base_llama_tp_training_plan(
         "norm": SequenceParallel(),
         "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
         "layers.*.attn": layerwise_prepare_module_input_cls(
-            input_layouts=(Shard(1), None),
-            desired_input_layouts=(Replicate(), None),
+            input_layouts=(Shard(1), Shard(1)),
+            desired_input_layouts=(Replicate(), Replicate()),
         ),
         "layers.*.mlp": layerwise_prepare_module_input_cls(
             input_layouts=(Shard(1),),

torchtune/models/llama4/_parallelism.py

Lines changed: 2 additions & 2 deletions
@@ -63,8 +63,8 @@ def decoder_only_tp_training_plan(model: nn.Module) -> dict[str, ParallelStyle]:
        layer_plan = {
            f"decoder.layers.{layer_id}.sa_norm": SequenceParallel(),
            f"decoder.layers.{layer_id}.attn": PrepareModuleInput(
-                input_layouts=(Shard(1), None),
-                desired_input_layouts=(Replicate(), None),
+                input_layouts=(Shard(1), Shard(1)),
+                desired_input_layouts=(Replicate(), Replicate()),
            ),
            f"decoder.layers.{layer_id}.attn.q_proj": ColwiseParallel(),
            f"decoder.layers.{layer_id}.attn.k_proj": ColwiseParallel(),
