Commit 3086e26

Speed up imports and add a CI (#2845)
* Working test
* Timing cleanup
* Add CI
* Fix nits
* Mixup imports
* Clean
* tuna -> tuna-interpreter
* Refactor pippy imports
* Accelerator
* Fin
* Fin
* Keep specific ones for docs
1 parent 5d5d07a commit 3086e26

8 files changed: +203 −34 lines changed
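The recurring technique in this diff is moving imports of heavy, optional dependencies (transformer_engine, pippy, Megatron-LM, torch.distributed.checkpoint) from module scope into the functions that actually need them, so that a plain `import accelerate` stays fast. A toy sketch of the pattern, with the standard-library `argparse` standing in for a heavy dependency:

    # Before: the import cost is paid by everyone who imports this module.
    # import argparse

    def build_parser():
        # After: the cost is paid only on the first call, then cached in sys.modules.
        import argparse

        return argparse.ArgumentParser(description="toy example of a deferred import")

    if __name__ == "__main__":
        print(build_parser().description)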

.github/workflows/test_imports.yml

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+name: Run Import Tests
+
+on:
+  pull_request:
+    paths:
+      - "src/**"
+      - "tests/**"
+      - ".github/**"
+      - "examples/**"
+      - "setup.py"
+    types: [opened, synchronize, reopened]
+
+env:
+  HF_HOME: ~/hf_cache
+  TESTING_MOCKED_DATALOADERS: "1"
+  IS_GITHUB_CI: "1"
+
+jobs:
+  run-tests:
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        pytorch-version: [
+          latest,
+          minimum,
+        ]
+    steps:
+      - uses: actions/[email protected]
+      - name: Set up python 3.8
+        uses: actions/setup-python@v3
+        with:
+          python-version: 3.8
+
+      - name: Install the library
+        run: |
+          pip install -e .
+          pip install pytest-reportlog tabulate setuptools git+https://github.com/muellerzr/import-timer
+
+      - name: Show installed libraries
+        run: |
+          pip freeze
+
+      - name: Run Import Tests
+        env:
+          PYTORCH_VERSION: ${{ matrix.pytorch-version }}
+        run: |
+          pytest -sv tests/test_imports.py
+
+      - name: Generate Report
+        if: always()
+        run: |
+          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

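The final "Generate Report" step redirects the stdout of utils/log_reports.py into `$GITHUB_STEP_SUMMARY`, the file GitHub Actions renders as Markdown on the run page. The same effect can be had from Python by writing to that file directly; a toy stand-in (not the actual log_reports.py):

    import os

    summary_path = os.environ.get("GITHUB_STEP_SUMMARY")  # only set inside GitHub Actions
    if summary_path is not None:
        with open(summary_path, "a") as f:
            # Anything appended here is rendered as Markdown in the job summary.
            f.write("## Import timing results\n\nAll import tests passed.\n")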
src/accelerate/accelerator.py

Lines changed: 4 additions & 6 deletions
@@ -81,7 +81,6 @@
     has_transformer_engine_layers,
     is_bf16_available,
     is_deepspeed_available,
-    is_fp8_available,
     is_ipex_available,
     is_lomo_available,
     is_megatron_lm_available,
@@ -117,11 +116,6 @@
     DummyScheduler,
 )
 
-if is_fp8_available():
-    import transformer_engine.common.recipe as te_recipe
-    from transformer_engine.pytorch import fp8_autocast
-
-
 if is_megatron_lm_available():
     from .utils import (
         MegatronEngine,
@@ -1384,6 +1378,10 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
 
         # We prepare fp8 after, allowing for bf16 autocast to happen first
        if getattr(self.fp8_recipe_handler, "backend", None) == "TE":
+            # Import here to keep base imports fast
+            import transformer_engine.common.recipe as te_recipe
+            from transformer_engine.pytorch import fp8_autocast
+
             if not has_transformer_engine_layers(model):
                 with torch.no_grad():
                     convert_model(model)

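This change removes the last module-level transformer_engine import, so having TE installed no longer slows down every `import accelerate`; the cost is only paid when `prepare_model()` takes the TE fp8 branch. A quick, hedged way to check locally (the exact module list depends on what else is installed in the environment):

    import sys

    import accelerate  # noqa: F401

    # With the deferred import, transformer_engine should only show up in sys.modules
    # once prepare_model() actually runs with a TE fp8 recipe.
    print("transformer_engine" in sys.modules)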
src/accelerate/inference.py

Lines changed: 4 additions & 5 deletions
@@ -28,11 +28,6 @@
 )
 
 
-if is_pippy_available():
-    from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
-    from pippy.PipelineStage import PipelineStage
-
-
 def generate_device_map(model, num_processes: int = 1, no_split_module_classes=None, max_memory: dict = None):
     """
     Calculates the device map for `model` with an offset for PiPPy
@@ -83,6 +78,10 @@ def build_pipeline(model, split_points, args, kwargs, num_chunks):
     Users can pass in custom `num_chunks` as an optional hyper-parameter. By default will use
     `AcceleratorState.num_processes`
     """
+    # Note: We import here to reduce import time from general modules, and isolate outside dependencies
+    from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
+    from pippy.PipelineStage import PipelineStage
+
     # We need to annotate the split points in the model for PiPPy
     state = PartialState()
     annotate_split_points(model, {split_point: PipeSplitWrapper.SplitPoint.BEGINNING for split_point in split_points})

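One practical note on the function-local import used in `build_pipeline`: it only costs something on the first call, because Python caches the loaded module in `sys.modules`. A minimal, runnable illustration with a standard-library module standing in for pippy:

    import time

    def first_use():
        import decimal  # resolved from disk on the first call
        return decimal

    def later_use():
        import decimal  # already in sys.modules, so this is essentially a dict lookup
        return decimal

    t0 = time.perf_counter()
    first_use()
    t1 = time.perf_counter()
    later_use()
    t2 = time.perf_counter()
    print(f"first call: {(t1 - t0) * 1e6:.0f} us, cached call: {(t2 - t1) * 1e6:.0f} us")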
src/accelerate/test_utils/testing.py

Lines changed: 9 additions & 0 deletions
@@ -40,6 +40,7 @@
     is_datasets_available,
     is_deepspeed_available,
     is_dvclive_available,
+    is_import_timer_available,
     is_mlu_available,
     is_mps_available,
     is_npu_available,
@@ -377,6 +378,14 @@ def require_pippy(test_case):
     return unittest.skipUnless(is_pippy_available(), "test requires pippy")(test_case)
 
 
+def require_import_timer(test_case):
+    """
+    Decorator marking a test that requires tuna interpreter installed. These tests are skipped when tuna isn't
+    installed
+    """
+    return unittest.skipUnless(is_import_timer_available(), "test requires tuna interpreter")(test_case)
+
+
 _atleast_one_tracker_available = (
     any([is_wandb_available(), is_tensorboard_available()]) and not is_comet_ml_available()
 )

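`require_import_timer` follows the same `unittest.skipUnless` convention as the other `require_*` helpers in this file. A minimal usage sketch (the test class here is hypothetical, not part of the diff):

    import unittest

    from accelerate.test_utils.testing import require_import_timer

    @require_import_timer
    class HypotheticalTimingTest(unittest.TestCase):
        def test_something(self):
            # Runs only when the import_timer package is installed; skipped otherwise.
            self.assertTrue(True)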
src/accelerate/utils/__init__.py

Lines changed: 17 additions & 9 deletions
@@ -86,6 +86,7 @@
     is_deepspeed_available,
     is_dvclive_available,
     is_fp8_available,
+    is_import_timer_available,
     is_ipex_available,
     is_lomo_available,
     is_megatron_lm_available,
@@ -195,24 +196,31 @@
     prepare_simple_launcher_cmd_env,
     prepare_tpu,
 )
+
+# For docs
 from .megatron_lm import (
     AbstractTrainStep,
     BertTrainStep,
     GPTTrainStep,
-    MegatronEngine,
     MegatronLMDummyDataLoader,
     MegatronLMDummyScheduler,
-    MegatronLMOptimizerWrapper,
-    MegatronLMSchedulerWrapper,
     T5TrainStep,
     avg_losses_across_data_parallel_group,
-    gather_across_data_parallel_groups,
 )
-from .megatron_lm import initialize as megatron_lm_initialize
-from .megatron_lm import prepare_data_loader as megatron_lm_prepare_data_loader
-from .megatron_lm import prepare_model_optimizer_scheduler as megatron_lm_prepare_model_optimizer_scheduler
-from .megatron_lm import prepare_optimizer as megatron_lm_prepare_optimizer
-from .megatron_lm import prepare_scheduler as megatron_lm_prepare_scheduler
+
+
+if is_megatron_lm_available():
+    from .megatron_lm import (
+        MegatronEngine,
+        MegatronLMOptimizerWrapper,
+        MegatronLMSchedulerWrapper,
+        gather_across_data_parallel_groups,
+    )
+    from .megatron_lm import initialize as megatron_lm_initialize
+    from .megatron_lm import prepare_data_loader as megatron_lm_prepare_data_loader
+    from .megatron_lm import prepare_model_optimizer_scheduler as megatron_lm_prepare_model_optimizer_scheduler
+    from .megatron_lm import prepare_optimizer as megatron_lm_prepare_optimizer
+    from .megatron_lm import prepare_scheduler as megatron_lm_prepare_scheduler
 from .memory import find_executable_batch_size, release_memory
 from .other import (
     check_os_kernel,

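Because the Megatron-LM wrappers are now only re-exported when `is_megatron_lm_available()` is true, downstream code should guard its own imports rather than assuming the names exist in `accelerate.utils`. A minimal sketch:

    from accelerate.utils import is_megatron_lm_available

    if is_megatron_lm_available():
        # These names are only present in accelerate.utils when Megatron-LM is installed.
        from accelerate.utils import MegatronEngine, megatron_lm_initialize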
src/accelerate/utils/fsdp_utils.py

Lines changed: 29 additions & 14 deletions
@@ -18,24 +18,12 @@
 import torch
 
 from ..logging import get_logger
-from .constants import FSDP_MODEL_NAME, FSDP_PYTORCH_VERSION, OPTIMIZER_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_NAME
-from .imports import is_torch_distributed_available
+from .constants import FSDP_MODEL_NAME, OPTIMIZER_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from .modeling import is_peft_model
 from .other import save
 from .versions import is_torch_version
 
 
-if is_torch_version(">=", FSDP_PYTORCH_VERSION) and is_torch_distributed_available():
-    import torch.distributed.checkpoint as dist_cp
-    from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner, DefaultSavePlanner
-    from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
-    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
-    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
-    # `dist_cp_format_utils is only available from pt>=2.3.0
-    if is_torch_version(">=", "2.3.0") and is_torch_distributed_available():
-        import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
-
-
 logger = get_logger(__name__)
 
 
@@ -58,8 +46,13 @@ def _set_model_state_dict(model, state_dict, adapter_only=False):
 
 
 def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_index=0, adapter_only=False):
-    os.makedirs(output_dir, exist_ok=True)
+    # Note: We import here to reduce import time from general modules, and isolate outside dependencies
+    import torch.distributed.checkpoint as dist_cp
+    from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
+    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 
+    os.makedirs(output_dir, exist_ok=True)
     if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
         # FSDP raises error when single GPU is used with `offload_to_cpu=True` for FULL_STATE_DICT
         # so, only enable it when num_processes>1
@@ -103,6 +96,12 @@ def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_index=0,
 
 
 def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False):
+    # Note: We import here to reduce import time from general modules, and isolate outside dependencies
+    import torch.distributed.checkpoint as dist_cp
+    from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
+    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
+
     accelerator.wait_for_everyone()
     if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
         # FSDP raises error when single GPU is used with `offload_to_cpu=True` for FULL_STATE_DICT
@@ -156,6 +155,12 @@ def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, a
 
 
 def save_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0):
+    # Note: We import here to reduce import time from general modules, and isolate outside dependencies
+    import torch.distributed.checkpoint as dist_cp
+    from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
+    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
+
     os.makedirs(output_dir, exist_ok=True)
     with FSDP.state_dict_type(
         model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
@@ -183,6 +188,12 @@ def save_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, output_dir,
 
 
 def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, optimizer_index=0, adapter_only=False):
+    # Note: We import here to reduce import time from general modules, and isolate outside dependencies
+    import torch.distributed.checkpoint as dist_cp
+    from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
+    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
+
     accelerator.wait_for_everyone()
     with FSDP.state_dict_type(
         model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
@@ -221,6 +232,10 @@ def _distributed_checkpoint_to_merged_weights(checkpoint_dir: str, save_path: st
 
     Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
     """
+    # Note: We import here to reduce import time from general modules, and isolate outside dependencies
+    import torch.distributed.checkpoint as dist_cp
+    import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
+
     state_dict = {}
     save_path = Path(save_path)
     save_path.mkdir(exist_ok=True)

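With these changes, `torch.distributed.checkpoint` and the FSDP wrappers are only imported when one of the save/load helpers is actually called. A quick, hedged spot check (the exact set of preloaded torch submodules varies by torch version and environment):

    import sys

    import accelerate  # noqa: F401

    # Expected to print an empty list after this commit, unless something else in the
    # environment pulls the checkpointing stack in at import time.
    print([name for name in sys.modules if name.startswith("torch.distributed.checkpoint")])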
src/accelerate/utils/imports.py

Lines changed: 4 additions & 0 deletions
@@ -81,6 +81,10 @@ def get_ccl_version():
     return importlib.metadata.version("oneccl_bind_pt")
 
 
+def is_import_timer_available():
+    return _is_package_available("import_timer")
+
+
 def is_pynvml_available():
     return _is_package_available("pynvml")

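`_is_package_available` is defined earlier in imports.py and is not part of this diff; a rough, assumed equivalent based on `importlib.util.find_spec` is shown here only for context (the real helper may also check distribution metadata):

    import importlib.util

    def _is_package_available_sketch(pkg_name: str) -> bool:
        # True if the package can be located on the current environment's path.
        return importlib.util.find_spec(pkg_name) is not None

    print(_is_package_available_sketch("import_timer"))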
tests/test_imports.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import subprocess
+
+from accelerate.test_utils.testing import TempDirTestCase, require_import_timer
+from accelerate.utils import is_import_timer_available
+
+
+if is_import_timer_available():
+    from import_timer import calculate_total_time, read_import_profile
+    from import_timer.core import get_paths_above_threshold, sort_nodes_by_total_time
+
+
+def convert_list_to_string(data):
+    end_result = ""
+    arrow_right = "->"
+    for path in data:
+        end_result += f"{arrow_right.join(path[0])} {path[1]:.3f}s\n"
+    return end_result
+
+
+def run_import_time(command: str):
+    output = subprocess.run(["python3", "-X", "importtime", "-c", command], capture_output=True, text=True)
+    return output.stderr
+
+
+@require_import_timer
+class ImportSpeedTester(TempDirTestCase):
+    """
+    Test suite which checks if imports have seen slowdowns
+    based on a particular baseline.
+
+    If the error messages are not clear enough to get a
+    full view of what is slowing things down (or to
+    figure out how deep the initial depth should be),
+    please view the profile with the `tuna` framework:
+    `tuna import.log`.
+    """
+
+    clear_on_setup = False
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        output = run_import_time("import torch")
+        data = read_import_profile(output)
+        total_time = calculate_total_time(data)
+        cls.pytorch_time = total_time
+
+    def test_base_import(self):
+        output = run_import_time("import accelerate")
+        data = read_import_profile(output)
+        total_time = calculate_total_time(data)
+        pct_more = total_time / self.pytorch_time
+        # Base import should never be more than 20% slower than raw torch import
+        err_msg = f"Base import is more than 20% slower than raw torch import ({pct_more * 100:.2f}%), please check the attached `tuna` profile:\n"
+        sorted_data = sort_nodes_by_total_time(data)
+        paths_above_threshold = get_paths_above_threshold(sorted_data, 0.1, max_depth=7)
+        err_msg += f"\n{convert_list_to_string(paths_above_threshold)}"
+        self.assertLess(pct_more, 1.2, err_msg)
+
+    def test_cli_import(self):
+        output = run_import_time("from accelerate.commands.launch import launch_command_parser")
+        data = read_import_profile(output)
+        total_time = calculate_total_time(data)
+        pct_more = total_time / self.pytorch_time
+        # CLI import should never be more than 20% slower than raw torch import
+        err_msg = f"CLI import is more than 20% slower than raw torch import ({pct_more * 100:.2f}%), please check the attached `tuna` profile:\n"
+        sorted_data = sort_nodes_by_total_time(data)
+        paths_above_threshold = get_paths_above_threshold(sorted_data, 0.1, max_depth=7)
+        err_msg += f"\n{convert_list_to_string(paths_above_threshold)}"
+        self.assertLess(pct_more, 1.2, err_msg)

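If one of these tests fails, the class docstring points at `tuna` for an interactive view of the profile. The profile it expects can be produced the same way `run_import_time` does it; a small helper (the file name `import.log` is just the convention from the docstring):

    import subprocess

    profile = subprocess.run(
        ["python3", "-X", "importtime", "-c", "import accelerate"],
        capture_output=True,
        text=True,
    ).stderr  # -X importtime writes its report to stderr

    with open("import.log", "w") as f:
        f.write(profile)

    # Then inspect interactively with: tuna import.log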