
Remove non-public interface usage in EcsRunTaskOperator #29447


Merged
merged 2 commits, Aug 24, 2023
68 changes: 24 additions & 44 deletions airflow/providers/amazon/aws/operators/ecs.py
@@ -28,7 +28,7 @@

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator, XCom
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook, should_retry_eni
@@ -38,11 +38,12 @@
ClusterInactiveTrigger,
TaskDoneTrigger,
)
from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
from airflow.utils.helpers import prune_dict
from airflow.utils.session import provide_session

if TYPE_CHECKING:
from airflow.models import TaskInstance
from airflow.utils.context import Context

DEFAULT_CONN_ID = "aws_default"
@@ -450,8 +451,6 @@ class EcsRunTaskOperator(EcsBaseOperator):
"network_configuration": "json",
"tags": "json",
}
REATTACH_XCOM_KEY = "ecs_task_arn"
REATTACH_XCOM_TASK_ID_TEMPLATE = "{task_id}_task_arn"

def __init__(
self,
Expand Down Expand Up @@ -507,6 +506,8 @@ def __init__(
self.awslogs_region = self.region

self.arn: str | None = None
self._started_by: str | None = None

self.retry_args = quota_retry
self.task_log_fetcher: AwsTaskLogFetcher | None = None
self.wait_for_completion = wait_for_completion
@@ -525,19 +526,22 @@ def _get_ecs_task_id(task_arn: str | None) -> str | None:
return None
return task_arn.split("/")[-1]

@provide_session
def execute(self, context, session=None):
def execute(self, context):
self.log.info(
"Running ECS Task - Task definition: %s - on cluster %s", self.task_definition, self.cluster
)
self.log.info("EcsOperator overrides: %s", self.overrides)

if self.reattach:
self._try_reattach_task(context)
# Generate deterministic UUID which refers to unique TaskInstanceKey
ti: TaskInstance = context["ti"]
self._started_by = generate_uuid(*map(str, ti.key.primary))
self.log.info("Try to find run with startedBy=%r", self._started_by)
self._try_reattach_task(started_by=self._started_by)

if not self.arn:
# start the task except if we reattached to an existing one just before.
self._start_task(context)
self._start_task()

if self.deferrable:
self.defer(
@@ -574,7 +578,7 @@ def execute(self, context, session=None):
else:
self._wait_for_task_ended()

self._after_execution(session)
self._after_execution()

if self.do_xcom_push and self.task_log_fetcher:
return self.task_log_fetcher.get_last_log_message()
@@ -598,27 +602,15 @@ def execute_complete(self, context, event=None):
if len(one_log["events"]) > 0:
return one_log["events"][0]["message"]

@provide_session
def _after_execution(self, session=None):
def _after_execution(self):
self._check_success_task()

self.log.info("ECS Task has been successfully executed")

if self.reattach:
# Clear the XCom value storing the ECS task ARN if the task has completed
# as we can't reattach it anymore
self._xcom_del(session, self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id))

def _xcom_del(self, session, task_id):
session.query(XCom).filter(XCom.dag_id == self.dag_id, XCom.task_id == task_id).delete()

@AwsBaseHook.retry(should_retry_eni)
def _start_task(self, context):
def _start_task(self):
run_opts = {
"cluster": self.cluster,
"taskDefinition": self.task_definition,
"overrides": self.overrides,
"startedBy": self.owner,
"startedBy": self._started_by or self.owner,
}

if self.capacity_provider_strategy:
@@ -650,27 +642,17 @@ def _start_task(self, context):
self.arn = response["tasks"][0]["taskArn"]
self.log.info("ECS task ID is: %s", self._get_ecs_task_id(self.arn))

if self.reattach:
# Save the task ARN in XCom to be able to reattach it if needed
self.xcom_push(context, key=self.REATTACH_XCOM_KEY, value=self.arn)
Comment on lines -653 to -655
@vandonr-amz (Contributor), Aug 24, 2023:

this is a somewhat-breaking change, as the example code

# You must set `reattach=True` in order to get ecs_task_arn if you plan to use a Sensor.
reattach=True,

was recommending setting reattach to true to get the ARN.

I think this sucked, but removing the arn entirely from the xcom values is not good either.
What we could do is set it all the time now that we don't rely on this anymore to know if we need to reattach.

Contributor Author:

I guess this part never worked as expected.

Contributor:

well, pushing the ARN to xcom at least was working.
I opened a PR to restore that specific thing.
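
For illustration, the reviewer's "set it all the time" idea could look roughly like the hypothetical subclass below. It reuses the old ecs_task_arn key from the removed REATTACH_XCOM_KEY constant and ignores the deferrable path, so it is a sketch of the suggestion rather than the actual follow-up PR.

from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator


class EcsRunTaskWithArnXcom(EcsRunTaskOperator):
    """Hypothetical variant: always expose the ECS task ARN via XCom."""

    def execute(self, context):
        result = super().execute(context)
        # Push the ARN unconditionally so a downstream sensor can read it even
        # when reattach=False; the key mirrors the removed REATTACH_XCOM_KEY value.
        self.xcom_push(context, key="ecs_task_arn", value=self.arn)
        return result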


def _try_reattach_task(self, context):
task_def_resp = self.client.describe_task_definition(taskDefinition=self.task_definition)
ecs_task_family = task_def_resp["taskDefinition"]["family"]

def _try_reattach_task(self, started_by: str):
if not started_by:
raise AirflowException("`started_by` should not be empty or None")
list_tasks_resp = self.client.list_tasks(
cluster=self.cluster, desiredStatus="RUNNING", family=ecs_task_family
cluster=self.cluster, desiredStatus="RUNNING", startedBy=started_by
)
running_tasks = list_tasks_resp["taskArns"]

# Check if the ECS task previously launched is already running
previous_task_arn = self.xcom_pull(
context,
task_ids=self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id),
key=self.REATTACH_XCOM_KEY,
)
if previous_task_arn in running_tasks:
self.arn = previous_task_arn
if running_tasks:
if len(running_tasks) > 1:
self.log.warning("Found more than one previously launched task: %s", running_tasks)
self.arn = running_tasks[0]
self.log.info("Reattaching previously launched task: %s", self.arn)
else:
self.log.info("No active previously launched task found to reattach")
@@ -690,8 +672,6 @@ def _wait_for_task_ended(self) -> None:
},
)

return

def _aws_logs_enabled(self):
return self.awslogs_group and self.awslogs_stream_prefix

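Taken together, the operator changes above drop the XCom round-trip entirely: reattachment is now keyed on a deterministic startedBy value derived from the TaskInstance key. The standalone sketch below is an approximation of that flow using a plain boto3 ECS client and hypothetical helper and parameter names; the operator itself spreads this logic across execute, _try_reattach_task, and _start_task.

import boto3

from airflow.providers.amazon.aws.utils.identifiers import generate_uuid


def find_or_start_task(ti_key_primary, cluster, task_definition, overrides):
    """Reattach to a running ECS task started for this TaskInstance key, or start a new one."""
    ecs = boto3.client("ecs")

    # Deterministic startedBy marker derived from the TaskInstance key
    # (dag_id, task_id, run_id, map_index): the same key always yields the same UUID.
    started_by = generate_uuid(*map(str, ti_key_primary))

    # Look for a task that a previous try of this TaskInstance already launched.
    running = ecs.list_tasks(cluster=cluster, desiredStatus="RUNNING", startedBy=started_by)["taskArns"]
    if running:
        return running[0]  # reattach to the first (normally only) match

    # Nothing to reattach to: start a fresh task carrying the same marker.
    response = ecs.run_task(
        cluster=cluster,
        taskDefinition=task_definition,
        overrides=overrides,
        startedBy=started_by,
    )
    return response["tasks"][0]["taskArn"]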
51 changes: 51 additions & 0 deletions airflow/providers/amazon/aws/utils/identifiers.py
@@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from uuid import NAMESPACE_OID, UUID, uuid5

NIL_UUID = UUID(int=0)


def generate_uuid(*values: str | None, namespace: UUID = NAMESPACE_OID) -> str:
"""
Convert input values to deterministic UUID string representation.

This function is only intended to generate a hash which is used as an identifier, not for any security use.

Generates a UUID v5 (SHA-1 + Namespace) for each value provided,
and this UUID is used as the Namespace for the next element.

If only one non-None value is provided to the function, then the result of the function
would be the same as the result of ``uuid.uuid5``.

All ``None`` values are replaced by the NIL UUID. If the only value provided is ``None``,
then the NIL UUID is returned.

:param namespace: Initial namespace value to pass into the ``uuid.uuid5`` function.
"""
if not values:
raise ValueError("Expected at least 1 argument")

if len(values) == 1 and values[0] is None:
return str(NIL_UUID)

result = namespace
for item in values:
result = uuid5(result, item if item is not None else str(NIL_UUID))

return str(result)
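
A short usage sketch for the new helper; the names come from the diff above, and the assertions only illustrate the documented behaviour (deterministic output, uuid5 equivalence for a single value, NIL UUID for a lone None).

from uuid import NAMESPACE_OID, uuid5

from airflow.providers.amazon.aws.utils.identifiers import generate_uuid

# The same inputs always produce the same UUID, so the value works as a stable identifier.
assert generate_uuid("dag", "task", "run", "-1") == generate_uuid("dag", "task", "run", "-1")

# With a single non-None value the result collapses to a plain uuid5 in the default namespace.
assert generate_uuid("dag") == str(uuid5(NAMESPACE_OID, "dag"))

# A lone None maps to the NIL UUID.
assert generate_uuid(None) == "00000000-0000-0000-0000-000000000000"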
97 changes: 47 additions & 50 deletions tests/providers/amazon/aws/operators/test_ecs.py
@@ -521,39 +521,47 @@ def test_check_success_task_not_raises(self, client_mock):
["", {"testTagKey": "testTagValue"}],
],
)
@mock.patch.object(EcsRunTaskOperator, "_xcom_del")
@mock.patch.object(
EcsRunTaskOperator,
"xcom_pull",
return_value=f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
@pytest.mark.parametrize(
"arns, expected_arn",
[
pytest.param(
[
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
"arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b54",
],
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
id="multiple-arns",
),
pytest.param(
[
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
],
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
id="single-arn",
),
],
)
@mock.patch("airflow.providers.amazon.aws.operators.ecs.generate_uuid")
@mock.patch.object(EcsRunTaskOperator, "_wait_for_task_ended")
@mock.patch.object(EcsRunTaskOperator, "_check_success_task")
@mock.patch.object(EcsRunTaskOperator, "_start_task")
@mock.patch.object(EcsBaseOperator, "client")
def test_reattach_successful(
self,
client_mock,
start_mock,
check_mock,
wait_mock,
xcom_pull_mock,
xcom_del_mock,
launch_type,
tags,
self, client_mock, start_mock, check_mock, wait_mock, uuid_mock, launch_type, tags, arns, expected_arn
):
"""Test reattach on first running Task ARN."""
mock_ti = mock.MagicMock(name="MockedTaskInstance")
mock_ti.key.primary = ("mock_dag", "mock_ti", "mock_runid", 42)
fake_uuid = "01-02-03-04"
uuid_mock.return_value = fake_uuid

self.set_up_operator(launch_type=launch_type, tags=tags)
client_mock.describe_task_definition.return_value = {"taskDefinition": {"family": "f"}}
client_mock.list_tasks.return_value = {
"taskArns": [
"arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b54",
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
]
}
client_mock.list_tasks.return_value = {"taskArns": arns}

self.ecs.reattach = True
self.ecs.execute(self.mock_context)
self.ecs.execute({"ti": mock_ti})

uuid_mock.assert_called_once_with("mock_dag", "mock_ti", "mock_runid", "42")

extend_args = {}
if launch_type:
@@ -563,20 +571,14 @@ def test_reattach_successful(
if tags:
extend_args["tags"] = [{"key": k, "value": v} for (k, v) in tags.items()]

client_mock.describe_task_definition.assert_called_once_with(taskDefinition="t")

client_mock.list_tasks.assert_called_once_with(cluster="c", desiredStatus="RUNNING", family="f")
client_mock.list_tasks.assert_called_once_with(
cluster="c", desiredStatus="RUNNING", startedBy=fake_uuid
)

start_mock.assert_not_called()
xcom_pull_mock.assert_called_once_with(
self.mock_context,
key=self.ecs.REATTACH_XCOM_KEY,
task_ids=self.ecs.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.ecs.task_id),
)
wait_mock.assert_called_once_with()
check_mock.assert_called_once_with()
xcom_del_mock.assert_called_once()
assert self.ecs.arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
assert self.ecs.arn == expected_arn

@pytest.mark.parametrize(
"launch_type, tags",
@@ -587,29 +589,25 @@ def test_reattach_successful(
["", {"testTagKey": "testTagValue"}],
],
)
@mock.patch.object(EcsRunTaskOperator, "_xcom_del")
@mock.patch.object(EcsRunTaskOperator, "_try_reattach_task")
@mock.patch("airflow.providers.amazon.aws.operators.ecs.generate_uuid")
@mock.patch.object(EcsRunTaskOperator, "_wait_for_task_ended")
@mock.patch.object(EcsRunTaskOperator, "_check_success_task")
@mock.patch.object(EcsBaseOperator, "client")
def test_reattach_save_task_arn_xcom(
self,
client_mock,
check_mock,
wait_mock,
reattach_mock,
xcom_del_mock,
launch_type,
tags,
self, client_mock, check_mock, wait_mock, uuid_mock, launch_type, tags, caplog
):
"""Test that no reattach happens when there is no running task started by this Task ID."""
mock_ti = mock.MagicMock(name="MockedTaskInstance")
mock_ti.key.primary = ("mock_dag", "mock_ti", "mock_runid", 42)
fake_uuid = "01-02-03-04"
uuid_mock.return_value = fake_uuid

self.set_up_operator(launch_type=launch_type, tags=tags)
client_mock.describe_task_definition.return_value = {"taskDefinition": {"family": "f"}}
client_mock.list_tasks.return_value = {"taskArns": []}
client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES

self.ecs.reattach = True
self.ecs.execute(self.mock_context)
self.ecs.execute({"ti": mock_ti})

extend_args = {}
if launch_type:
@@ -619,12 +617,14 @@ def test_reattach_save_task_arn_xcom(
if tags:
extend_args["tags"] = [{"key": k, "value": v} for (k, v) in tags.items()]

reattach_mock.assert_called_once()
client_mock.list_tasks.assert_called_once_with(
cluster="c", desiredStatus="RUNNING", startedBy=fake_uuid
)
client_mock.run_task.assert_called_once()
wait_mock.assert_called_once_with()
check_mock.assert_called_once_with()
xcom_del_mock.assert_called_once()
assert self.ecs.arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
assert "No active previously launched task found to reattach" in caplog.messages

@mock.patch.object(EcsBaseOperator, "client")
@mock.patch("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher")
@@ -670,17 +670,14 @@ def test_with_defer(self, client_mock):
assert deferred.value.trigger.task_arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"

@mock.patch.object(EcsRunTaskOperator, "client", new_callable=PropertyMock)
@mock.patch.object(EcsRunTaskOperator, "_xcom_del")
def test_execute_complete(self, xcom_del_mock: MagicMock, client_mock):
def test_execute_complete(self, client_mock):
event = {"status": "success", "task_arn": "my_arn"}
self.ecs.reattach = True

self.ecs.execute_complete(None, event)

# task gets described to assert its success
client_mock().describe_tasks.assert_called_once_with(cluster="c", tasks=["my_arn"])
# if reattach mode, xcom value is deleted on success
xcom_del_mock.assert_called_once()


class TestEcsCreateClusterOperator(EcsBaseTestCase):