
Commit ae3cc1c

SageMaker @Remote function: Added multi-node functionality (#4984)
* implemented multi-node distribution with @Remote function
* completed unit tests
* added distributed training with CPU and torchrun
* backwards compatibility nproc_per_node
* fixing code: permissions for non-root users, integration tests
* fixed docstyle
* refactor nproc_per_node for backwards compatibility
* refactor nproc_per_node for backwards compatibility
* pylint fix, newlines
* added unit tests for bootstrap_environment remote
1 parent a58654e commit ae3cc1c
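
For context, a minimal usage sketch of what this commit enables: running a @remote-decorated function on more than one instance with torchrun. The instance type, the function body, and the use of torch.distributed below are illustrative assumptions, not part of this diff.

    from sagemaker.remote_function import remote

    # Sketch only: instance type and the training body are hypothetical.
    @remote(
        instance_type="ml.g5.12xlarge",
        instance_count=2,      # > 1 is now accepted when use_torchrun=True
        use_torchrun=True,     # launch the function under torchrun on every node
        nproc_per_node=None,   # None -> derived from the instance type at bootstrap
    )
    def train():
        import torch.distributed as dist

        dist.init_process_group(backend="nccl")
        rank = dist.get_rank()
        # per-rank training loop would go here
        dist.destroy_process_group()
        return rank

    train()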

File tree

6 files changed

+908 −91 lines changed


src/sagemaker/remote_function/client.py

+20 −10

@@ -91,7 +91,7 @@ def remote(
     use_spot_instances=False,
     max_wait_time_in_seconds=None,
     use_torchrun=False,
-    nproc_per_node=1,
+    nproc_per_node: Optional[int] = None,
 ):
     """Decorator for running the annotated function as a SageMaker training job.

@@ -284,8 +284,9 @@ def remote(
         use_torchrun (bool): Specifies whether to use torchrun for distributed training.
             Defaults to ``False``.

-        nproc_per_node (int): Specifies the number of processes per node for distributed training.
-            Defaults to ``1``.
+        nproc_per_node (Optional int): Specifies the number of processes per node for
+            distributed training. Defaults to ``None``.
+            This is defined automatically configured on the instance type.
     """

     def _remote(func):

@@ -325,9 +326,13 @@ def _remote(func):
         @functools.wraps(func)
         def wrapper(*args, **kwargs):

-            if instance_count > 1 and not spark_config:
+            if instance_count > 1 and not (
+                (spark_config is not None and not use_torchrun)
+                or (spark_config is None and use_torchrun)
+            ):
                 raise ValueError(
-                    "Remote function do not support training on multi instances. "
+                    "Remote function do not support training on multi instances "
+                    + "without spark_config or use_torchrun. "
                     + "Please provide instance_count = 1"
                 )

@@ -532,7 +537,7 @@ def __init__(
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
         use_torchrun=False,
-        nproc_per_node=1,
+        nproc_per_node: Optional[int] = None,
     ):
         """Constructor for RemoteExecutor

@@ -725,17 +730,22 @@ def __init__(
             use_torchrun (bool): Specifies whether to use torchrun for distributed training.
                 Defaults to ``False``.

-            nproc_per_node (int): Specifies the number of processes per node.
-                Defaults to ``1``.
+            nproc_per_node (Optional int): Specifies the number of processes per node for
+                distributed training. Defaults to ``None``.
+                This is defined automatically configured on the instance type.
         """
         self.max_parallel_jobs = max_parallel_jobs

         if self.max_parallel_jobs <= 0:
             raise ValueError("max_parallel_jobs must be greater than 0.")

-        if instance_count > 1 and not spark_config:
+        if instance_count > 1 and not (
+            (spark_config is not None and not use_torchrun)
+            or (spark_config is None and use_torchrun)
+        ):
             raise ValueError(
-                "Remote function do not support training on multi instances. "
+                "Remote function do not support training on multi instances "
+                + "without spark_config or use_torchrun. "
                 + "Please provide instance_count = 1"
             )
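
The guard added in both places above allows instance_count > 1 only when exactly one of spark_config or use_torchrun is set. A standalone sketch of that condition (not library code), for quick reference:

    def multi_instance_allowed(spark_config, use_torchrun):
        # Mirrors the condition in the diff: exactly one of the two options.
        return (spark_config is not None and not use_torchrun) or (
            spark_config is None and use_torchrun
        )

    assert multi_instance_allowed(spark_config=None, use_torchrun=True)          # torchrun multi-node
    assert multi_instance_allowed(spark_config=object(), use_torchrun=False)     # Spark multi-node
    assert not multi_instance_allowed(spark_config=None, use_torchrun=False)     # plain remote: needs instance_count = 1
    assert not multi_instance_allowed(spark_config=object(), use_torchrun=True)  # both set: rejected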

src/sagemaker/remote_function/core/stored_function.py

−6

@@ -55,8 +55,6 @@ def __init__(
         hmac_key: str,
         s3_kms_key: str = None,
         context: Context = Context(),
-        use_torchrun: bool = False,
-        nproc_per_node: int = 1,
     ):
         """Construct a StoredFunction object.

@@ -67,16 +65,12 @@ def __init__(
             s3_kms_key: KMS key used to encrypt artifacts uploaded to S3.
             hmac_key: Key used to encrypt serialized and deserialized function and arguments.
             context: Build or run context of a pipeline step.
-            use_torchrun: Whether to use torchrun for distributed training.
-            nproc_per_node: Number of processes per node for distributed training.
         """
         self.sagemaker_session = sagemaker_session
         self.s3_base_uri = s3_base_uri
         self.s3_kms_key = s3_kms_key
         self.hmac_key = hmac_key
         self.context = context
-        self.use_torchrun = use_torchrun
-        self.nproc_per_node = nproc_per_node

         self.func_upload_path = s3_path_join(
             s3_base_uri, context.step_name, context.func_step_s3_dir

src/sagemaker/remote_function/job.py

+64 −11

@@ -130,9 +130,12 @@
 export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
 printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"

+printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n"
+cat /opt/ml/input/config/resourceconfig.json

 printf "INFO: Bootstraping runtime environment.\\n"
 python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
+source /opt/ml/input/sm_training.env

 if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
 then

@@ -155,9 +158,11 @@
     fi

     printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n"
+    printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function \\n"
     $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function "$@"
 else
     printf "INFO: No conda env provided. Invoking remote function\\n"
+    printf "INFO: python -m sagemaker.remote_function.invoke_function \\n"
     python -m sagemaker.remote_function.invoke_function "$@"
 fi
 """

@@ -175,9 +180,12 @@
 export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
 printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"

+printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n"
+cat /opt/ml/input/config/resourceconfig.json

 printf "INFO: Bootstraping runtime environment.\\n"
 python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
+source /opt/ml/input/sm_training.env

 if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
 then

@@ -200,11 +208,18 @@
     fi

     printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n"
-    $conda_exe run -n $conda_env torchrun --nproc_per_node $NPROC_PER_NODE \
+    printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
+    --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
+    -m sagemaker.remote_function.invoke_function \\n"
+    $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
+    --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
         -m sagemaker.remote_function.invoke_function "$@"
 else
     printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
-    torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@"
+    printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
+    --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function \\n"
+    torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
+    --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function "$@"
 fi
 """
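
The torchrun entrypoint above now sources /opt/ml/input/sm_training.env and relies on SM_HOST_COUNT, SM_MASTER_ADDR, SM_MASTER_PORT, SM_CURRENT_HOST_RANK, and SM_NPROC_PER_NODE. The bootstrap code that writes that file is not part of this diff; the sketch below only illustrates how such values could be derived from /opt/ml/input/config/resourceconfig.json. The function name, output layout, and port value are assumptions, not the library's actual implementation.

    import json

    def write_sm_training_env(
        resourceconfig_path="/opt/ml/input/config/resourceconfig.json",
        output_path="/opt/ml/input/sm_training.env",
        master_port=7777,  # assumed value; the real bootstrap may choose differently
    ):
        # resourceconfig.json lists all hosts in the training cluster and the current host.
        with open(resourceconfig_path) as f:
            config = json.load(f)

        hosts = sorted(config["hosts"])        # e.g. ["algo-1", "algo-2"]
        current_host = config["current_host"]  # e.g. "algo-2"

        env = {
            "SM_HOST_COUNT": len(hosts),
            "SM_MASTER_ADDR": hosts[0],
            "SM_MASTER_PORT": master_port,
            "SM_CURRENT_HOST_RANK": hosts.index(current_host),
        }
        # Written as shell exports so the entrypoint can `source` the file.
        with open(output_path, "w") as f:
            for key, value in env.items():
                f.write(f"export {key}={value}\n")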

@@ -262,8 +277,8 @@ def __init__(
         spark_config: SparkConfig = None,
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
-        use_torchrun=False,
-        nproc_per_node=1,
+        use_torchrun: bool = False,
+        nproc_per_node: Optional[int] = None,
     ):
         """Initialize a _JobSettings instance which configures the remote job.

@@ -445,6 +460,13 @@ def __init__(
             max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
                 After this amount of time Amazon SageMaker will stop waiting for managed spot
                 training job to complete. Defaults to ``None``.
+
+            use_torchrun (bool): Specifies whether to use torchrun for distributed training.
+                Defaults to ``False``.
+
+            nproc_per_node (Optional int): Specifies the number of processes per node for
+                distributed training. Defaults to ``None``.
+                This is defined automatically configured on the instance type.
         """
         self.sagemaker_session = sagemaker_session or Session()
         self.environment_variables = resolve_value_from_config(

@@ -732,6 +754,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
         )

         logger.info("Creating job: %s", job_name)
+
         job_settings.sagemaker_session.sagemaker_client.create_training_job(**training_job_request)

         return _Job(

@@ -776,8 +799,6 @@ def compile(
                 s3_base_uri=s3_base_uri,
                 hmac_key=hmac_key,
                 s3_kms_key=job_settings.s3_kms_key,
-                use_torchrun=job_settings.use_torchrun,
-                nproc_per_node=job_settings.nproc_per_node,
             )
             stored_function.save(func, *func_args, **func_kwargs)
         else:

@@ -790,8 +811,6 @@ def compile(
                     step_name=step_compilation_context.step_name,
                     func_step_s3_dir=step_compilation_context.pipeline_build_time,
                 ),
-                use_torchrun=job_settings.use_torchrun,
-                nproc_per_node=job_settings.nproc_per_node,
             )

             stored_function.save_pipeline_step_function(serialized_data)

@@ -931,6 +950,7 @@ def compile(
         request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key})

         extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri)
+        extended_request = _extend_torchrun_to_request(extended_request, job_settings)

         return extended_request

@@ -1011,7 +1031,7 @@ def _prepare_and_upload_runtime_scripts(
     s3_kms_key: str,
     sagemaker_session: Session,
     use_torchrun: bool = False,
-    nproc_per_node: int = 1,
+    nproc_per_node: Optional[int] = None,
 ):
     """Copy runtime scripts to a folder and upload to S3.

@@ -1030,7 +1050,7 @@ def _prepare_and_upload_runtime_scripts(

         use_torchrun (bool): Whether to use torchrun or not.

-        nproc_per_node (int): Number of processes per node.
+        nproc_per_node (Optional[int]): Number of processes per node
     """

     from sagemaker.workflow.utilities import load_step_compilation_context

@@ -1054,7 +1074,11 @@ def _prepare_and_upload_runtime_scripts(

     if use_torchrun:
         entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT
-        entry_point_script = entry_point_script.replace("$NPROC_PER_NODE", str(nproc_per_node))
+
+        if nproc_per_node is not None and nproc_per_node > 0:
+            entry_point_script = entry_point_script.replace(
+                "$SM_NPROC_PER_NODE", str(nproc_per_node)
+            )

     with open(entrypoint_script_path, "w", newline="\n") as file:
         file.writelines(entry_point_script)

@@ -1435,6 +1459,35 @@ def _upload_serialized_spark_configuration(
     return config_file_s3_uri


+def _extend_torchrun_to_request(
+    request_dict: Dict,
+    job_settings: _JobSettings,
+) -> Dict:
+    """Extend the create training job request with torchrun configuration.
+
+    Args:
+        request_dict (Dict): create training job request dict.
+        job_settings (_JobSettings): the job settings.
+    """
+    use_torchrun = job_settings.use_torchrun
+    instance_count = job_settings.instance_count
+
+    if not use_torchrun:
+        return request_dict
+
+    if instance_count == 1:
+        return request_dict
+
+    extended_request = request_dict.copy()
+
+    for input_channel in extended_request["InputDataConfig"]:
+        s3_data_source = input_channel["DataSource"].get("S3DataSource", None)
+        if s3_data_source:
+            s3_data_source["S3DataDistributionType"] = "FullyReplicated"
+
+    return extended_request
+
+
 def _extend_spark_config_to_request(
     request_dict: Dict,
     job_settings: _JobSettings,
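
To make the effect of _extend_torchrun_to_request concrete, here is a small illustration of the request mutation it performs when use_torchrun is set and instance_count > 1. The channel name and S3 URI are placeholder values.

    request = {
        "InputDataConfig": [
            {
                "ChannelName": "sagemaker_remote_function_workspace",  # placeholder
                "DataSource": {
                    "S3DataSource": {
                        "S3Uri": "s3://my-bucket/my-job/workspace",    # placeholder
                        "S3DataType": "S3Prefix",
                    }
                },
            }
        ]
    }

    # Every S3 input channel is marked FullyReplicated so each node receives the
    # full workspace and dependencies rather than a shard.
    for channel in request["InputDataConfig"]:
        s3_source = channel["DataSource"].get("S3DataSource")
        if s3_source:
            s3_source["S3DataDistributionType"] = "FullyReplicated"

    assert (
        request["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3DataDistributionType"]
        == "FullyReplicated"
    )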
