Skip to content

Commit fc0250f

Browse files
authored
Allow attaching to previously launched task in ECSOperator (#16685)
This PR adds a parameter reattach_prev_task to ECSOperator. **Before:** Until now, we could use 'reattach', which reattached to a running ECS Task of the same 'family' (if there was one) instead of creating a new one. The problem was that if several tasks in our workflows used the same ECS Task Definition, the operator didn't know which one to reattach to, so in some pipelines we could only use `concurrency=1` (for example, when launching the same ECS task in parallel from Airflow with different configurations). **Now:** With reattach_prev_task, when we launch a new ECS task, its ECS Task ARN is temporarily stored in XCom. If there is an issue during the run (typically a connection problem between Airflow and ECS for long-running tasks, or an Airflow worker restart, after which the ECS tasks kept running in the background without Airflow being aware of them): - self._start_task stores the ECS task ARN in XCom (under a 'fake' task_id equal to f"{self.task_id}_task_arn") - on the next execution, the operator checks whether this task ARN is still running; if so, it reattaches to it, otherwise it creates a new one - when the operator runs successfully, it deletes the XCom value. I didn't change the logic of 'reattach' to do that directly because I didn't know whether it had been designed for other use cases. **Update 2021-07-01:** After discussing with @darwinyip, I made the change to 'reattach' directly instead of creating a new flag.
1 parent 16c55f1 commit fc0250f

File tree

3 files changed

+118
-19
lines changed

3 files changed

+118
-19
lines changed

airflow/providers/amazon/aws/hooks/base_aws.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ def retry_decorator(fun: Callable):
530530
def decorator_f(self, *args, **kwargs):
531531
retry_args = getattr(self, 'retry_args', None)
532532
if retry_args is None:
533-
return fun(self)
533+
return fun(self, *args, **kwargs)
534534
multiplier = retry_args.get('multiplier', 1)
535535
min_limit = retry_args.get('min', 1)
536536
max_limit = retry_args.get('max', 1)
@@ -543,7 +543,7 @@ def decorator_f(self, *args, **kwargs):
543543
'before': tenacity_logger,
544544
'after': tenacity_logger,
545545
}
546-
return tenacity.retry(**default_kwargs)(fun)(self)
546+
return tenacity.retry(**default_kwargs)(fun)(self, *args, **kwargs)
547547

548548
return decorator_f
549549

airflow/providers/amazon/aws/operators/ecs.py

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,12 @@
2424
from botocore.waiter import Waiter
2525

2626
from airflow.exceptions import AirflowException
27-
from airflow.models import BaseOperator
27+
from airflow.models import BaseOperator, XCom
2828
from airflow.providers.amazon.aws.exceptions import ECSOperatorError
2929
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
3030
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
3131
from airflow.typing_compat import Protocol, runtime_checkable
32+
from airflow.utils.session import provide_session
3233

3334

3435
def should_retry(exception: Exception):
@@ -135,8 +136,10 @@ class ECSOperator(BaseOperator):
135136
Only required if you want logs to be shown in the Airflow UI after your job has
136137
finished.
137138
:type awslogs_stream_prefix: str
138-
:param reattach: If set to True, will check if a task from the same family is already running.
139-
If so, the operator will attach to it instead of starting a new task.
139+
:param reattach: If set to True, will check if the task previously launched by the task_instance
140+
is already running. If so, the operator will attach to it instead of starting a new task.
141+
This is to avoid relaunching a new task when the connection drops between Airflow and ECS while
142+
the task is running (when the Airflow worker is restarted for example).
140143
:type reattach: bool
141144
:param quota_retry: Config if and how to retry _start_task() for transient errors.
142145
:type quota_retry: dict
@@ -145,6 +148,8 @@ class ECSOperator(BaseOperator):
145148
ui_color = '#f0ede4'
146149
template_fields = ('overrides',)
147150
template_fields_renderers = {"overrides": "json"}
151+
REATTACH_XCOM_KEY = "ecs_task_arn"
152+
REATTACH_XCOM_TASK_ID_TEMPLATE = "{task_id}_task_arn"
148153

149154
def __init__(
150155
self,
@@ -200,7 +205,8 @@ def __init__(
200205
self.arn: Optional[str] = None
201206
self.retry_args = quota_retry
202207

203-
def execute(self, context):
208+
@provide_session
209+
def execute(self, context, session=None):
204210
self.log.info(
205211
'Running ECS Task - Task definition: %s - on cluster %s', self.task_definition, self.cluster
206212
)
@@ -212,20 +218,28 @@ def execute(self, context):
212218
self._try_reattach_task()
213219

214220
if not self.arn:
215-
self._start_task()
221+
self._start_task(context)
216222

217223
self._wait_for_task_ended()
218224

219225
self._check_success_task()
220226

221227
self.log.info('ECS Task has been successfully executed')
222228

229+
if self.reattach:
230+
# Clear the XCom value storing the ECS task ARN if the task has completed
231+
# as we can't reattach it anymore
232+
self._xcom_del(session, self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id))
233+
223234
if self.do_xcom_push:
224235
return self._last_log_message()
225236

226237
return None
227238

228-
def _start_task(self):
239+
def _xcom_del(self, session, task_id):
240+
session.query(XCom).filter(XCom.dag_id == self.dag_id, XCom.task_id == task_id).delete()
241+
242+
def _start_task(self, context):
229243
run_opts = {
230244
'cluster': self.cluster,
231245
'taskDefinition': self.task_definition,
@@ -261,6 +275,26 @@ def _start_task(self):
261275
self.log.info('ECS Task started: %s', response)
262276

263277
self.arn = response['tasks'][0]['taskArn']
278+
ecs_task_id = self.arn.split("/")[-1]
279+
self.log.info(f"ECS task ID is: {ecs_task_id}")
280+
281+
if self.reattach:
282+
# Save the task ARN in XCom to be able to reattach it if needed
283+
self._xcom_set(
284+
context,
285+
key=self.REATTACH_XCOM_KEY,
286+
value=self.arn,
287+
task_id=self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id),
288+
)
289+
290+
def _xcom_set(self, context, key, value, task_id):
291+
XCom.set(
292+
key=key,
293+
value=value,
294+
task_id=task_id,
295+
dag_id=self.dag_id,
296+
execution_date=context["ti"].execution_date,
297+
)
264298

265299
def _try_reattach_task(self):
266300
task_def_resp = self.client.describe_task_definition(taskDefinition=self.task_definition)
@@ -271,15 +305,16 @@ def _try_reattach_task(self):
271305
)
272306
running_tasks = list_tasks_resp['taskArns']
273307

274-
running_tasks_count = len(running_tasks)
275-
if running_tasks_count > 1:
276-
self.arn = running_tasks[0]
277-
self.log.warning('More than 1 ECS Task found. Reattaching to %s', self.arn)
278-
elif running_tasks_count == 1:
279-
self.arn = running_tasks[0]
280-
self.log.info('Reattaching task: %s', self.arn)
308+
# Check if the ECS task previously launched is already running
309+
previous_task_arn = self.xcom_pull(
310+
task_ids=self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id),
311+
key=self.REATTACH_XCOM_KEY,
312+
)
313+
if previous_task_arn in running_tasks:
314+
self.arn = previous_task_arn
315+
self.log.info("Reattaching previously launched task: %s", self.arn)
281316
else:
282-
self.log.info('No active tasks found to reattach')
317+
self.log.info("No active previously launched task found to reattach")
283318

284319
def _wait_for_task_ended(self) -> None:
285320
if not self.client or not self.arn:

tests/providers/amazon/aws/operators/test_ecs.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -316,16 +316,27 @@ def test_check_success_task_not_raises(self):
316316
['', {'testTagKey': 'testTagValue'}],
317317
]
318318
)
319+
@mock.patch.object(ECSOperator, "_xcom_del")
320+
@mock.patch.object(
321+
ECSOperator,
322+
"xcom_pull",
323+
return_value="arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55",
324+
)
319325
@mock.patch.object(ECSOperator, '_wait_for_task_ended')
320326
@mock.patch.object(ECSOperator, '_check_success_task')
321327
@mock.patch.object(ECSOperator, '_start_task')
322-
def test_reattach_successful(self, launch_type, tags, start_mock, check_mock, wait_mock):
328+
def test_reattach_successful(
329+
self, launch_type, tags, start_mock, check_mock, wait_mock, xcom_pull_mock, xcom_del_mock
330+
):
323331

324-
self.set_up_operator(launch_type=launch_type, tags=tags)
332+
self.set_up_operator(launch_type=launch_type, tags=tags) # pylint: disable=no-value-for-parameter
325333
client_mock = self.aws_hook_mock.return_value.get_conn.return_value
326334
client_mock.describe_task_definition.return_value = {'taskDefinition': {'family': 'f'}}
327335
client_mock.list_tasks.return_value = {
328-
'taskArns': ['arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55']
336+
'taskArns': [
337+
'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b54',
338+
'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55',
339+
]
329340
}
330341

331342
self.ecs.reattach = True
@@ -345,8 +356,61 @@ def test_reattach_successful(self, launch_type, tags, start_mock, check_mock, wa
345356
client_mock.list_tasks.assert_called_once_with(cluster='c', desiredStatus='RUNNING', family='f')
346357

347358
start_mock.assert_not_called()
359+
xcom_pull_mock.assert_called_once_with(
360+
key=self.ecs.REATTACH_XCOM_KEY,
361+
task_ids=self.ecs.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.ecs.task_id),
362+
)
363+
wait_mock.assert_called_once_with()
364+
check_mock.assert_called_once_with()
365+
xcom_del_mock.assert_called_once()
366+
assert self.ecs.arn == 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55'
367+
368+
@parameterized.expand(
369+
[
370+
['EC2', None],
371+
['FARGATE', None],
372+
['EC2', {'testTagKey': 'testTagValue'}],
373+
['', {'testTagKey': 'testTagValue'}],
374+
]
375+
)
376+
@mock.patch.object(ECSOperator, '_xcom_del')
377+
@mock.patch.object(ECSOperator, '_xcom_set')
378+
@mock.patch.object(ECSOperator, '_try_reattach_task')
379+
@mock.patch.object(ECSOperator, '_wait_for_task_ended')
380+
@mock.patch.object(ECSOperator, '_check_success_task')
381+
def test_reattach_save_task_arn_xcom(
382+
self, launch_type, tags, check_mock, wait_mock, reattach_mock, xcom_set_mock, xcom_del_mock
383+
):
384+
385+
self.set_up_operator(launch_type=launch_type, tags=tags) # pylint: disable=no-value-for-parameter
386+
client_mock = self.aws_hook_mock.return_value.get_conn.return_value
387+
client_mock.describe_task_definition.return_value = {'taskDefinition': {'family': 'f'}}
388+
client_mock.list_tasks.return_value = {'taskArns': []}
389+
client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES
390+
391+
self.ecs.reattach = True
392+
self.ecs.execute(None)
393+
394+
self.aws_hook_mock.return_value.get_conn.assert_called_once()
395+
extend_args = {}
396+
if launch_type:
397+
extend_args['launchType'] = launch_type
398+
if launch_type == 'FARGATE':
399+
extend_args['platformVersion'] = 'LATEST'
400+
if tags:
401+
extend_args['tags'] = [{'key': k, 'value': v} for (k, v) in tags.items()]
402+
403+
reattach_mock.assert_called_once()
404+
client_mock.run_task.assert_called_once()
405+
xcom_set_mock.assert_called_once_with(
406+
None,
407+
key=self.ecs.REATTACH_XCOM_KEY,
408+
task_id=self.ecs.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.ecs.task_id),
409+
value="arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55",
410+
)
348411
wait_mock.assert_called_once_with()
349412
check_mock.assert_called_once_with()
413+
xcom_del_mock.assert_called_once()
350414
assert self.ecs.arn == 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55'
351415

352416
@mock.patch.object(ECSOperator, '_last_log_message', return_value="Log output")

0 commit comments

Comments
 (0)