Skip to content

bugfix: strip tags when falling back to update in SageMakerEndpointOperator #33487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions airflow/providers/amazon/aws/operators/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,16 +494,32 @@ def execute(self, context: Context) -> dict:
try:
response = sagemaker_operation(
endpoint_info,
wait_for_completion=False,
)
# waiting for completion is handled here in the operator
except ClientError:
self.operation = "update"
sagemaker_operation = self.hook.update_endpoint
response = sagemaker_operation(
endpoint_info,
wait_for_completion=False,
wait_for_completion=False, # waiting for completion is handled here in the operator
)
except ClientError as ce:
if self.operation == "create" and ce.response["Error"]["Message"].startswith(
"Cannot create already existing endpoint"
):
# if we get an error because the endpoint already exists, we try to update it instead
self.operation = "update"
sagemaker_operation = self.hook.update_endpoint
self.log.warning(
"cannot create already existing endpoint %s, "
"updating it with the given config instead",
endpoint_info["EndpointName"],
)
if "Tags" in endpoint_info:
self.log.warning(
"Provided tags will be ignored in the update operation "
"(tags on the existing endpoint will be unchanged)"
)
endpoint_info.pop("Tags")
response = sagemaker_operation(
endpoint_info,
wait_for_completion=False,
)
else:
raise

if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"Sagemaker endpoint creation failed: {response}")
Expand Down
37 changes: 37 additions & 0 deletions tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,43 @@ def test_execute_with_duplicate_endpoint_creation(
}
self.sagemaker.execute(None)

@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(SageMakerHook, "create_model")
@mock.patch.object(SageMakerHook, "create_endpoint_config")
@mock.patch.object(SageMakerHook, "create_endpoint")
@mock.patch.object(SageMakerHook, "update_endpoint")
@mock.patch.object(sagemaker, "serialize", return_value="")
def test_execute_with_duplicate_endpoint_removes_tags(
self,
serialize,
mock_endpoint_update,
mock_endpoint_create,
mock_endpoint_config,
mock_model,
mock_client,
):
mock_endpoint_create.side_effect = ClientError(
error_response={
"Error": {
"Code": "ValidationException",
"Message": "Cannot create already existing endpoint.",
}
},
operation_name="CreateEndpoint",
)

def _check_no_tags(config, wait_for_completion):
assert "Tags" not in config
return {
"EndpointArn": "test_arn",
"ResponseMetadata": {"HTTPStatusCode": 200},
}

mock_endpoint_update.side_effect = _check_no_tags

self.sagemaker.config["Endpoint"]["Tags"] = {"Key": "k", "Value": "v"}
self.sagemaker.execute(None)

@mock.patch.object(SageMakerHook, "create_model")
@mock.patch.object(SageMakerHook, "create_endpoint_config")
@mock.patch.object(SageMakerHook, "create_endpoint")
Expand Down