 export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
 printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"

+printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n"
+cat /opt/ml/input/config/resourceconfig.json

 printf "INFO: Bootstrapping runtime environment.\\n"
 python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
+source /opt/ml/input/sm_training.env

 if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
 then
...
 fi

 printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n"
+printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function \\n"
 $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function "$@"
 else
 printf "INFO: No conda env provided. Invoking remote function\\n"
+printf "INFO: python -m sagemaker.remote_function.invoke_function \\n"
 python -m sagemaker.remote_function.invoke_function "$@"
 fi
 """
 export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
 printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"

+printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n"
+cat /opt/ml/input/config/resourceconfig.json

 printf "INFO: Bootstrapping runtime environment.\\n"
 python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
+source /opt/ml/input/sm_training.env

 if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
 then
...
 fi

 printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n"
-$conda_exe run -n $conda_env torchrun --nproc_per_node $NPROC_PER_NODE \
+printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
+    --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
+    -m sagemaker.remote_function.invoke_function \\n"
+$conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
+    --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
     -m sagemaker.remote_function.invoke_function "$@"
 else
 printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
+printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
+    --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function \\n"
+torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
+    --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function "$@"
 fi
 """

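The torchrun entry point above depends on `source /opt/ml/input/sm_training.env` having exported `SM_HOST_COUNT`, `SM_MASTER_ADDR`, `SM_MASTER_PORT`, `SM_CURRENT_HOST_RANK`, and `SM_NPROC_PER_NODE` before the launcher runs. A hedged sketch of how the bootstrap phase could derive those values from the resource config; the bootstrap script in this PR is the source of truth, and the port below is illustrative:

```python
import json

with open("/opt/ml/input/config/resourceconfig.json") as f:
    cfg = json.load(f)

hosts = sorted(cfg["hosts"])
env = {
    "SM_HOST_COUNT": len(hosts),
    "SM_MASTER_ADDR": hosts[0],  # first host acts as the rendezvous master
    "SM_MASTER_PORT": 7777,      # illustrative port, not necessarily the real default
    "SM_CURRENT_HOST_RANK": hosts.index(cfg["current_host"]),
}

# Write the variables in a form the entry point script can `source`.
with open("/opt/ml/input/sm_training.env", "w") as f:
    f.writelines(f"export {key}={value}\n" for key, value in env.items())
```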
@@ -262,8 +277,8 @@ def __init__(
         spark_config: SparkConfig = None,
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
-        use_torchrun=False,
-        nproc_per_node=1,
+        use_torchrun: bool = False,
+        nproc_per_node: Optional[int] = None,
     ):
         """Initialize a _JobSettings instance which configures the remote job.

@@ -445,6 +460,13 @@ def __init__(
             max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
                 After this amount of time Amazon SageMaker will stop waiting for managed spot
                 training job to complete. Defaults to ``None``.
+
+            use_torchrun (bool): Specifies whether to use torchrun for distributed training.
+                Defaults to ``False``.
+
+            nproc_per_node (Optional[int]): Specifies the number of processes per node for
+                distributed training. Defaults to ``None``, in which case the value is
+                configured automatically based on the instance type.
         """
         self.sagemaker_session = sagemaker_session or Session()
         self.environment_variables = resolve_value_from_config(
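For context, these settings reach `_JobSettings` through the public `@remote` decorator; a hedged usage sketch, assuming the decorator forwards both new parameters unchanged (instance type and count are illustrative):

```python
from sagemaker.remote_function import remote

@remote(
    instance_type="ml.g5.12xlarge",
    instance_count=2,
    use_torchrun=True,    # launch the function via torchrun instead of plain python
    nproc_per_node=None,  # None: process count resolved from the instance type at run time
)
def train(epochs: int):
    ...  # training body executes on the SageMaker job
```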
@@ -732,6 +754,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None):
     )

     logger.info("Creating job: %s", job_name)
+
     job_settings.sagemaker_session.sagemaker_client.create_training_job(**training_job_request)

     return _Job(
@@ -776,8 +799,6 @@ def compile(
             s3_base_uri=s3_base_uri,
             hmac_key=hmac_key,
             s3_kms_key=job_settings.s3_kms_key,
-            use_torchrun=job_settings.use_torchrun,
-            nproc_per_node=job_settings.nproc_per_node,
         )
         stored_function.save(func, *func_args, **func_kwargs)
     else:
@@ -790,8 +811,6 @@ def compile(
                 step_name=step_compilation_context.step_name,
                 func_step_s3_dir=step_compilation_context.pipeline_build_time,
             ),
-            use_torchrun=job_settings.use_torchrun,
-            nproc_per_node=job_settings.nproc_per_node,
         )

         stored_function.save_pipeline_step_function(serialized_data)
@@ -931,6 +950,7 @@ def compile(
        request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key})

    extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri)
+   extended_request = _extend_torchrun_to_request(extended_request, job_settings)

    return extended_request

@@ -1011,7 +1031,7 @@ def _prepare_and_upload_runtime_scripts(
    s3_kms_key: str,
    sagemaker_session: Session,
    use_torchrun: bool = False,
-   nproc_per_node: int = 1,
+   nproc_per_node: Optional[int] = None,
):
    """Copy runtime scripts to a folder and upload to S3.

@@ -1030,7 +1050,7 @@ def _prepare_and_upload_runtime_scripts(

        use_torchrun (bool): Whether to use torchrun or not.

-       nproc_per_node (int): Number of processes per node.
+       nproc_per_node (Optional[int]): Number of processes per node.
    """

    from sagemaker.workflow.utilities import load_step_compilation_context
@@ -1054,7 +1074,11 @@ def _prepare_and_upload_runtime_scripts(

    if use_torchrun:
        entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT
-       entry_point_script = entry_point_script.replace("$NPROC_PER_NODE", str(nproc_per_node))
+
+       if nproc_per_node is not None and nproc_per_node > 0:
+           entry_point_script = entry_point_script.replace(
+               "$SM_NPROC_PER_NODE", str(nproc_per_node)
+           )

    with open(entrypoint_script_path, "w", newline="\n") as file:
        file.writelines(entry_point_script)
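The substitution is now deliberately conditional: when `nproc_per_node` is `None`, the `$SM_NPROC_PER_NODE` token survives in the uploaded script and bash resolves it from the sourced environment at run time, while an explicit positive value is baked in at upload time. A minimal sketch of that contract (the template string is illustrative):

```python
from typing import Optional

TEMPLATE = "torchrun --nproc_per_node $SM_NPROC_PER_NODE -m sagemaker.remote_function.invoke_function"

def render(nproc_per_node: Optional[int]) -> str:
    script = TEMPLATE
    if nproc_per_node is not None and nproc_per_node > 0:
        # Explicit user override is hardcoded into the uploaded script.
        script = script.replace("$SM_NPROC_PER_NODE", str(nproc_per_node))
    # Otherwise bash expands $SM_NPROC_PER_NODE from sm_training.env at run time.
    return script

assert "--nproc_per_node 8" in render(8)
assert "$SM_NPROC_PER_NODE" in render(None)
```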
@@ -1435,6 +1459,35 @@ def _upload_serialized_spark_configuration(
    return config_file_s3_uri


+def _extend_torchrun_to_request(
+    request_dict: Dict,
+    job_settings: _JobSettings,
+) -> Dict:
+    """Extend the create training job request with torchrun configuration.
+
+    Args:
+        request_dict (Dict): create training job request dict.
+        job_settings (_JobSettings): the job settings.
+    """
+    use_torchrun = job_settings.use_torchrun
+    instance_count = job_settings.instance_count
+
+    if not use_torchrun:
+        return request_dict
+
+    if instance_count == 1:
+        return request_dict
+
+    extended_request = request_dict.copy()
+
+    for input_channel in extended_request["InputDataConfig"]:
+        s3_data_source = input_channel["DataSource"].get("S3DataSource", None)
+        if s3_data_source:
+            s3_data_source["S3DataDistributionType"] = "FullyReplicated"
+
+    return extended_request
+
+
 def _extend_spark_config_to_request(
     request_dict: Dict,
     job_settings: _JobSettings,
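Forcing `S3DataDistributionType` to `FullyReplicated` ensures that every node in a multi-node torchrun job downloads the complete workspace and runtime-script channels instead of a sharded subset, which each rank needs in order to bootstrap identically. A hedged sketch of the transformation on a trimmed-down request payload, with `SimpleNamespace` standing in for `_JobSettings`:

```python
from types import SimpleNamespace

# Stand-in for _JobSettings: only the attributes the helper reads.
settings = SimpleNamespace(use_torchrun=True, instance_count=2)

request = {
    "InputDataConfig": [
        {
            "ChannelName": "sagemaker_remote_function_bootstrap",  # illustrative channel
            "DataSource": {
                "S3DataSource": {
                    "S3Uri": "s3://bucket/prefix",  # illustrative URI
                    "S3DataType": "S3Prefix",
                    "S3DataDistributionType": "ShardedByS3Key",
                }
            },
        }
    ]
}

extended = _extend_torchrun_to_request(request, settings)
s3_source = extended["InputDataConfig"][0]["DataSource"]["S3DataSource"]
assert s3_source["S3DataDistributionType"] == "FullyReplicated"
```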