Skip to content

Commit fb22b91

Browse files
authored
Fix issue #4856 by copying environment variables (#5115)
* Fix issue #4856 by copying environment variables
1 parent 1782329 commit fb22b91

File tree

2 files changed

+79
-34
lines changed

2 files changed

+79
-34
lines changed

src/sagemaker/workflow/notebook_job_step.py

+17-33
Original file line numberDiff line numberDiff line change
@@ -13,49 +13,33 @@
1313
"""The notebook job step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16+
import os
1617
import re
1718
import shutil
18-
import os
19+
from typing import Dict, List, Optional, Union
1920

20-
from typing import (
21-
List,
22-
Optional,
23-
Union,
24-
Dict,
21+
from sagemaker import vpc_utils
22+
from sagemaker.config.config_schema import (
23+
NOTEBOOK_JOB_ROLE_ARN,
24+
NOTEBOOK_JOB_S3_KMS_KEY_ID,
25+
NOTEBOOK_JOB_S3_ROOT_URI,
26+
NOTEBOOK_JOB_VOLUME_KMS_KEY_ID,
27+
NOTEBOOK_JOB_VPC_CONFIG_SECURITY_GROUP_IDS,
28+
NOTEBOOK_JOB_VPC_CONFIG_SUBNETS,
2529
)
26-
30+
from sagemaker.s3 import S3Uploader
31+
from sagemaker.s3_utils import s3_path_join
32+
from sagemaker.session import get_execution_role
33+
from sagemaker.utils import Tags, _tmpdir, format_tags, name_from_base, resolve_value_from_config
34+
from sagemaker.workflow.entities import PipelineVariable, RequestType
2735
from sagemaker.workflow.execution_variables import ExecutionVariables
2836
from sagemaker.workflow.functions import Join
2937
from sagemaker.workflow.properties import Properties
3038
from sagemaker.workflow.retry import RetryPolicy
31-
from sagemaker.workflow.steps import (
32-
Step,
33-
ConfigurableRetryStep,
34-
StepTypeEnum,
35-
)
3639
from sagemaker.workflow.step_collections import StepCollection
3740
from sagemaker.workflow.step_outputs import StepOutput
38-
39-
from sagemaker.workflow.entities import (
40-
RequestType,
41-
PipelineVariable,
42-
)
41+
from sagemaker.workflow.steps import ConfigurableRetryStep, Step, StepTypeEnum
4342
from sagemaker.workflow.utilities import _collect_parameters, load_step_compilation_context
44-
from sagemaker.session import get_execution_role
45-
46-
from sagemaker.s3_utils import s3_path_join
47-
from sagemaker.s3 import S3Uploader
48-
from sagemaker.utils import _tmpdir, name_from_base, resolve_value_from_config, format_tags, Tags
49-
from sagemaker import vpc_utils
50-
51-
from sagemaker.config.config_schema import (
52-
NOTEBOOK_JOB_ROLE_ARN,
53-
NOTEBOOK_JOB_S3_ROOT_URI,
54-
NOTEBOOK_JOB_S3_KMS_KEY_ID,
55-
NOTEBOOK_JOB_VOLUME_KMS_KEY_ID,
56-
NOTEBOOK_JOB_VPC_CONFIG_SUBNETS,
57-
NOTEBOOK_JOB_VPC_CONFIG_SECURITY_GROUP_IDS,
58-
)
5943

6044

6145
# disable E1101 as collect_parameters decorator sets the attributes
@@ -374,7 +358,7 @@ def _prepare_env_variables(self):
374358
execution mechanism.
375359
"""
376360

377-
job_envs = self.environment_variables if self.environment_variables else {}
361+
job_envs = dict(self.environment_variables or {})
378362
system_envs = {
379363
"AWS_DEFAULT_REGION": self._region_from_session,
380364
"SM_JOB_DEF_VERSION": "1.0",

tests/unit/sagemaker/workflow/test_notebook_job_step.py

+62-1
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import os
1516
import unittest
17+
1618
from mock import Mock, patch
1719

18-
from sagemaker.workflow.notebook_job_step import NotebookJobStep
1920
from sagemaker.workflow.functions import Join
21+
from sagemaker.workflow.notebook_job_step import NotebookJobStep
2022

2123
REGION = "us-west-2"
2224
PIPELINE_NAME = "test-pipeline-name"
@@ -573,3 +575,62 @@ def _create_step_with_required_fields(self):
573575
image_uri=IMAGE_URI,
574576
kernel_name=KERNEL_NAME,
575577
)
578+
579+
def test_environment_variables_not_shared(self):
580+
"""Test that environment variables are not shared between NotebookJob steps"""
581+
# Setup shared environment variables
582+
shared_env_vars = {"test": "test"}
583+
584+
# Create two steps with the same environment variables dictionary
585+
step1 = NotebookJobStep(
586+
name="step1",
587+
input_notebook=INPUT_NOTEBOOK,
588+
image_uri=IMAGE_URI,
589+
kernel_name=KERNEL_NAME,
590+
environment_variables=shared_env_vars,
591+
)
592+
593+
step2 = NotebookJobStep(
594+
name="step2",
595+
input_notebook=INPUT_NOTEBOOK,
596+
image_uri=IMAGE_URI,
597+
kernel_name=KERNEL_NAME,
598+
environment_variables=shared_env_vars,
599+
)
600+
601+
# Get the arguments for both steps
602+
step1_args = step1.arguments
603+
step2_args = step2.arguments
604+
605+
# Verify that the environment variables are different objects
606+
self.assertIsNot(
607+
step1_args["Environment"],
608+
step2_args["Environment"],
609+
"Environment dictionaries should be different objects",
610+
)
611+
612+
# Verify that modifying one step's environment doesn't affect the other
613+
step1_env = step1_args["Environment"]
614+
step2_env = step2_args["Environment"]
615+
616+
# Both should have the original test value
617+
self.assertEqual(step1_env["test"], "test")
618+
self.assertEqual(step2_env["test"], "test")
619+
620+
# Modify step1's environment
621+
step1_env["test"] = "modified"
622+
623+
# Verify step2's environment remains unchanged
624+
self.assertEqual(step2_env["test"], "test")
625+
626+
# Verify notebook names are correct for each step
627+
self.assertEqual(
628+
step1_env["SM_INPUT_NOTEBOOK_NAME"],
629+
os.path.basename(INPUT_NOTEBOOK),
630+
"Step 1 should have its own notebook name",
631+
)
632+
self.assertEqual(
633+
step2_env["SM_INPUT_NOTEBOOK_NAME"],
634+
os.path.basename(INPUT_NOTEBOOK),
635+
"Step 2 should have its own notebook name",
636+
)

0 commit comments

Comments
 (0)