Skip to content

Commit 5079617

Browse files
Taragolis and ferruzzi committed
Remove non-public interface usage in EcsOperator
Co-authored-by: D. Ferruzzi <[email protected]>
1 parent abef61f commit 5079617

File tree

4 files changed

+194
-94
lines changed

4 files changed

+194
-94
lines changed

airflow/providers/amazon/aws/operators/ecs.py

Lines changed: 22 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
from airflow.configuration import conf
3030
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
31-
from airflow.models import BaseOperator, XCom
31+
from airflow.models import BaseOperator
3232
from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
3333
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
3434
from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook, should_retry_eni
@@ -38,11 +38,12 @@
3838
ClusterInactiveTrigger,
3939
TaskDoneTrigger,
4040
)
41+
from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
4142
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
4243
from airflow.utils.helpers import prune_dict
43-
from airflow.utils.session import provide_session
4444

4545
if TYPE_CHECKING:
46+
from airflow.models import TaskInstance
4647
from airflow.utils.context import Context
4748

4849
DEFAULT_CONN_ID = "aws_default"
@@ -450,8 +451,6 @@ class EcsRunTaskOperator(EcsBaseOperator):
450451
"network_configuration": "json",
451452
"tags": "json",
452453
}
453-
REATTACH_XCOM_KEY = "ecs_task_arn"
454-
REATTACH_XCOM_TASK_ID_TEMPLATE = "{task_id}_task_arn"
455454

456455
def __init__(
457456
self,
@@ -507,6 +506,8 @@ def __init__(
507506
self.awslogs_region = self.region
508507

509508
self.arn: str | None = None
509+
self.started_by: str | None = None
510+
510511
self.retry_args = quota_retry
511512
self.task_log_fetcher: AwsTaskLogFetcher | None = None
512513
self.wait_for_completion = wait_for_completion
@@ -525,19 +526,22 @@ def _get_ecs_task_id(task_arn: str | None) -> str | None:
525526
return None
526527
return task_arn.split("/")[-1]
527528

528-
@provide_session
529-
def execute(self, context, session=None):
529+
def execute(self, context):
530530
self.log.info(
531531
"Running ECS Task - Task definition: %s - on cluster %s", self.task_definition, self.cluster
532532
)
533533
self.log.info("EcsOperator overrides: %s", self.overrides)
534534

535535
if self.reattach:
536-
self._try_reattach_task(context)
536+
# Generate deterministic UUID which refers to unique TaskInstanceKey
537+
ti: TaskInstance = context["ti"]
538+
self.started_by = generate_uuid(*map(str, ti.key.primary))
539+
self.log.info("Try to find run with startedBy=%r", self.started_by)
540+
self._try_reattach_task()
537541

538542
if not self.arn:
539543
# start the task except if we reattached to an existing one just before.
540-
self._start_task(context)
544+
self._start_task()
541545

542546
if self.deferrable:
543547
self.defer(
@@ -574,7 +578,7 @@ def execute(self, context, session=None):
574578
else:
575579
self._wait_for_task_ended()
576580

577-
self._after_execution(session)
581+
self._after_execution()
578582

579583
if self.do_xcom_push and self.task_log_fetcher:
580584
return self.task_log_fetcher.get_last_log_message()
@@ -598,27 +602,15 @@ def execute_complete(self, context, event=None):
598602
if len(one_log["events"]) > 0:
599603
return one_log["events"][0]["message"]
600604

601-
@provide_session
602-
def _after_execution(self, session=None):
605+
def _after_execution(self):
603606
self._check_success_task()
604607

605-
self.log.info("ECS Task has been successfully executed")
606-
607-
if self.reattach:
608-
# Clear the XCom value storing the ECS task ARN if the task has completed
609-
# as we can't reattach it anymore
610-
self._xcom_del(session, self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id))
611-
612-
def _xcom_del(self, session, task_id):
613-
session.query(XCom).filter(XCom.dag_id == self.dag_id, XCom.task_id == task_id).delete()
614-
615-
@AwsBaseHook.retry(should_retry_eni)
616-
def _start_task(self, context):
608+
def _start_task(self):
617609
run_opts = {
618610
"cluster": self.cluster,
619611
"taskDefinition": self.task_definition,
620612
"overrides": self.overrides,
621-
"startedBy": self.owner,
613+
"startedBy": self.started_by or self.owner,
622614
}
623615

624616
if self.capacity_provider_strategy:
@@ -650,27 +642,15 @@ def _start_task(self, context):
650642
self.arn = response["tasks"][0]["taskArn"]
651643
self.log.info("ECS task ID is: %s", self._get_ecs_task_id(self.arn))
652644

653-
if self.reattach:
654-
# Save the task ARN in XCom to be able to reattach it if needed
655-
self.xcom_push(context, key=self.REATTACH_XCOM_KEY, value=self.arn)
656-
657-
def _try_reattach_task(self, context):
658-
task_def_resp = self.client.describe_task_definition(taskDefinition=self.task_definition)
659-
ecs_task_family = task_def_resp["taskDefinition"]["family"]
660-
645+
def _try_reattach_task(self):
661646
list_tasks_resp = self.client.list_tasks(
662-
cluster=self.cluster, desiredStatus="RUNNING", family=ecs_task_family
647+
cluster=self.cluster, desiredStatus="RUNNING", startedBy=self.started_by
663648
)
664649
running_tasks = list_tasks_resp["taskArns"]
665-
666-
# Check if the ECS task previously launched is already running
667-
previous_task_arn = self.xcom_pull(
668-
context,
669-
task_ids=self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id),
670-
key=self.REATTACH_XCOM_KEY,
671-
)
672-
if previous_task_arn in running_tasks:
673-
self.arn = previous_task_arn
650+
if running_tasks:
651+
if len(running_tasks) > 1:
652+
self.log.warning("Found more then one previously launched tasks: %s", running_tasks)
653+
self.arn = running_tasks[0]
674654
self.log.info("Reattaching previously launched task: %s", self.arn)
675655
else:
676656
self.log.info("No active previously launched task found to reattach")
@@ -690,8 +670,6 @@ def _wait_for_task_ended(self) -> None:
690670
},
691671
)
692672

693-
return
694-
695673
def _aws_logs_enabled(self):
696674
return self.awslogs_group and self.awslogs_stream_prefix
697675

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
from uuid import NAMESPACE_OID, UUID, uuid5
21+
22+
NIL_UUID = UUID(int=0)
23+
24+
25+
def generate_uuid(*values: str | None, namespace: UUID = NAMESPACE_OID) -> str:
    """
    Build a deterministic UUID string out of one or more input values.

    The result is meant to be used purely as an identifier; it is not suitable for any security purpose.

    The values are folded left to right through ``uuid.uuid5`` (UUID v5, SHA-1 + namespace):
    the UUID computed for one value becomes the namespace for the next value,
    so the order of the values affects the result.

    When exactly one non-``None`` value is given, the result equals
    ``str(uuid.uuid5(namespace, value))``.

    A ``None`` value is substituted by the string form of the NIL UUID before hashing.
    As a special case, a single ``None`` value returns the NIL UUID string itself
    and ``namespace`` is not used.

    :param values: One or more strings (or ``None``) to combine into the identifier.
    :param namespace: Initial namespace value to pass into the ``uuid.uuid5`` function.
    :return: String representation of the resulting UUID.
    :raises ValueError: If no values are provided.
    """
    if len(values) == 0:
        raise ValueError("Expected at least 1 argument")

    first, *remaining = values
    if first is None and not remaining:
        # A lone ``None`` short-circuits to the NIL UUID without hashing anything.
        return str(NIL_UUID)

    current = namespace
    for value in values:
        # ``None`` cannot be hashed by uuid5, so the NIL UUID string stands in for it.
        name = str(NIL_UUID) if value is None else value
        current = uuid5(current, name)

    return str(current)

tests/providers/amazon/aws/operators/test_ecs.py

Lines changed: 47 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -521,39 +521,47 @@ def test_check_success_task_not_raises(self, client_mock):
521521
["", {"testTagKey": "testTagValue"}],
522522
],
523523
)
524-
@mock.patch.object(EcsRunTaskOperator, "_xcom_del")
525-
@mock.patch.object(
526-
EcsRunTaskOperator,
527-
"xcom_pull",
528-
return_value=f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
524+
@pytest.mark.parametrize(
525+
"arns, expected_arn",
526+
[
527+
pytest.param(
528+
[
529+
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
530+
"arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b54",
531+
],
532+
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
533+
id="multiple-arns",
534+
),
535+
pytest.param(
536+
[
537+
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
538+
],
539+
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
540+
id="simgle-arn",
541+
),
542+
],
529543
)
544+
@mock.patch("airflow.providers.amazon.aws.operators.ecs.generate_uuid")
530545
@mock.patch.object(EcsRunTaskOperator, "_wait_for_task_ended")
531546
@mock.patch.object(EcsRunTaskOperator, "_check_success_task")
532547
@mock.patch.object(EcsRunTaskOperator, "_start_task")
533548
@mock.patch.object(EcsBaseOperator, "client")
534549
def test_reattach_successful(
535-
self,
536-
client_mock,
537-
start_mock,
538-
check_mock,
539-
wait_mock,
540-
xcom_pull_mock,
541-
xcom_del_mock,
542-
launch_type,
543-
tags,
550+
self, client_mock, start_mock, check_mock, wait_mock, uuid_mock, launch_type, tags, arns, expected_arn
544551
):
552+
"""Test reattach on first running Task ARN."""
553+
mock_ti = mock.MagicMock(name="MockedTaskInstance")
554+
mock_ti.key.primary = ("mock_dag", "mock_ti", "mock_runid", 42)
555+
fake_uuid = "01-02-03-04"
556+
uuid_mock.return_value = fake_uuid
545557

546558
self.set_up_operator(launch_type=launch_type, tags=tags)
547-
client_mock.describe_task_definition.return_value = {"taskDefinition": {"family": "f"}}
548-
client_mock.list_tasks.return_value = {
549-
"taskArns": [
550-
"arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b54",
551-
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
552-
]
553-
}
559+
client_mock.list_tasks.return_value = {"taskArns": arns}
554560

555561
self.ecs.reattach = True
556-
self.ecs.execute(self.mock_context)
562+
self.ecs.execute({"ti": mock_ti})
563+
564+
uuid_mock.assert_called_once_with("mock_dag", "mock_ti", "mock_runid", "42")
557565

558566
extend_args = {}
559567
if launch_type:
@@ -563,20 +571,14 @@ def test_reattach_successful(
563571
if tags:
564572
extend_args["tags"] = [{"key": k, "value": v} for (k, v) in tags.items()]
565573

566-
client_mock.describe_task_definition.assert_called_once_with(taskDefinition="t")
567-
568-
client_mock.list_tasks.assert_called_once_with(cluster="c", desiredStatus="RUNNING", family="f")
574+
client_mock.list_tasks.assert_called_once_with(
575+
cluster="c", desiredStatus="RUNNING", startedBy=fake_uuid
576+
)
569577

570578
start_mock.assert_not_called()
571-
xcom_pull_mock.assert_called_once_with(
572-
self.mock_context,
573-
key=self.ecs.REATTACH_XCOM_KEY,
574-
task_ids=self.ecs.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.ecs.task_id),
575-
)
576579
wait_mock.assert_called_once_with()
577580
check_mock.assert_called_once_with()
578-
xcom_del_mock.assert_called_once()
579-
assert self.ecs.arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
581+
assert self.ecs.arn == expected_arn
580582

581583
@pytest.mark.parametrize(
582584
"launch_type, tags",
@@ -587,29 +589,25 @@ def test_reattach_successful(
587589
["", {"testTagKey": "testTagValue"}],
588590
],
589591
)
590-
@mock.patch.object(EcsRunTaskOperator, "_xcom_del")
591-
@mock.patch.object(EcsRunTaskOperator, "_try_reattach_task")
592+
@mock.patch("airflow.providers.amazon.aws.operators.ecs.generate_uuid")
592593
@mock.patch.object(EcsRunTaskOperator, "_wait_for_task_ended")
593594
@mock.patch.object(EcsRunTaskOperator, "_check_success_task")
594595
@mock.patch.object(EcsBaseOperator, "client")
595596
def test_reattach_save_task_arn_xcom(
596-
self,
597-
client_mock,
598-
check_mock,
599-
wait_mock,
600-
reattach_mock,
601-
xcom_del_mock,
602-
launch_type,
603-
tags,
597+
self, client_mock, check_mock, wait_mock, uuid_mock, launch_type, tags, caplog
604598
):
599+
"""Test no reattach if no running Task was started by this Task ID."""
600+
mock_ti = mock.MagicMock(name="MockedTaskInstance")
601+
mock_ti.key.primary = ("mock_dag", "mock_ti", "mock_runid", 42)
602+
fake_uuid = "01-02-03-04"
603+
uuid_mock.return_value = fake_uuid
605604

606605
self.set_up_operator(launch_type=launch_type, tags=tags)
607-
client_mock.describe_task_definition.return_value = {"taskDefinition": {"family": "f"}}
608606
client_mock.list_tasks.return_value = {"taskArns": []}
609607
client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES
610608

611609
self.ecs.reattach = True
612-
self.ecs.execute(self.mock_context)
610+
self.ecs.execute({"ti": mock_ti})
613611

614612
extend_args = {}
615613
if launch_type:
@@ -619,12 +617,14 @@ def test_reattach_save_task_arn_xcom(
619617
if tags:
620618
extend_args["tags"] = [{"key": k, "value": v} for (k, v) in tags.items()]
621619

622-
reattach_mock.assert_called_once()
620+
client_mock.list_tasks.assert_called_once_with(
621+
cluster="c", desiredStatus="RUNNING", startedBy=fake_uuid
622+
)
623623
client_mock.run_task.assert_called_once()
624624
wait_mock.assert_called_once_with()
625625
check_mock.assert_called_once_with()
626-
xcom_del_mock.assert_called_once()
627626
assert self.ecs.arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
627+
assert "No active previously launched task found to reattach" in caplog.messages
628628

629629
@mock.patch.object(EcsBaseOperator, "client")
630630
@mock.patch("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher")
@@ -670,17 +670,14 @@ def test_with_defer(self, client_mock):
670670
assert deferred.value.trigger.task_arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
671671

672672
@mock.patch.object(EcsRunTaskOperator, "client", new_callable=PropertyMock)
673-
@mock.patch.object(EcsRunTaskOperator, "_xcom_del")
674-
def test_execute_complete(self, xcom_del_mock: MagicMock, client_mock):
673+
def test_execute_complete(self, client_mock):
675674
event = {"status": "success", "task_arn": "my_arn"}
676675
self.ecs.reattach = True
677676

678677
self.ecs.execute_complete(None, event)
679678

680679
# task gets described to assert its success
681680
client_mock().describe_tasks.assert_called_once_with(cluster="c", tasks=["my_arn"])
682-
# if reattach mode, xcom value is deleted on success
683-
xcom_del_mock.assert_called_once()
684681

685682

686683
class TestEcsCreateClusterOperator(EcsBaseTestCase):

0 commit comments

Comments
 (0)