@@ -521,39 +521,47 @@ def test_check_success_task_not_raises(self, client_mock):
521
521
["" , {"testTagKey" : "testTagValue" }],
522
522
],
523
523
)
524
- @mock .patch .object (EcsRunTaskOperator , "_xcom_del" )
525
- @mock .patch .object (
526
- EcsRunTaskOperator ,
527
- "xcom_pull" ,
528
- return_value = f"arn:aws:ecs:us-east-1:012345678910:task/{ TASK_ID } " ,
524
+ @pytest .mark .parametrize (
525
+ "arns, expected_arn" ,
526
+ [
527
+ pytest .param (
528
+ [
529
+ f"arn:aws:ecs:us-east-1:012345678910:task/{ TASK_ID } " ,
530
+ "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b54" ,
531
+ ],
532
+ f"arn:aws:ecs:us-east-1:012345678910:task/{ TASK_ID } " ,
533
+ id = "multiple-arns" ,
534
+ ),
535
+ pytest .param (
536
+ [
537
+ f"arn:aws:ecs:us-east-1:012345678910:task/{ TASK_ID } " ,
538
+ ],
539
+ f"arn:aws:ecs:us-east-1:012345678910:task/{ TASK_ID } " ,
540
+ id = "simgle-arn" ,
541
+ ),
542
+ ],
529
543
)
544
+ @mock .patch ("airflow.providers.amazon.aws.operators.ecs.generate_uuid" )
530
545
@mock .patch .object (EcsRunTaskOperator , "_wait_for_task_ended" )
531
546
@mock .patch .object (EcsRunTaskOperator , "_check_success_task" )
532
547
@mock .patch .object (EcsRunTaskOperator , "_start_task" )
533
548
@mock .patch .object (EcsBaseOperator , "client" )
534
549
def test_reattach_successful (
535
- self ,
536
- client_mock ,
537
- start_mock ,
538
- check_mock ,
539
- wait_mock ,
540
- xcom_pull_mock ,
541
- xcom_del_mock ,
542
- launch_type ,
543
- tags ,
550
+ self , client_mock , start_mock , check_mock , wait_mock , uuid_mock , launch_type , tags , arns , expected_arn
544
551
):
552
+ """Test reattach on first running Task ARN."""
553
+ mock_ti = mock .MagicMock (name = "MockedTaskInstance" )
554
+ mock_ti .key .primary = ("mock_dag" , "mock_ti" , "mock_runid" , 42 )
555
+ fake_uuid = "01-02-03-04"
556
+ uuid_mock .return_value = fake_uuid
545
557
546
558
self .set_up_operator (launch_type = launch_type , tags = tags )
547
- client_mock .describe_task_definition .return_value = {"taskDefinition" : {"family" : "f" }}
548
- client_mock .list_tasks .return_value = {
549
- "taskArns" : [
550
- "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b54" ,
551
- f"arn:aws:ecs:us-east-1:012345678910:task/{ TASK_ID } " ,
552
- ]
553
- }
559
+ client_mock .list_tasks .return_value = {"taskArns" : arns }
554
560
555
561
self .ecs .reattach = True
556
- self .ecs .execute (self .mock_context )
562
+ self .ecs .execute ({"ti" : mock_ti })
563
+
564
+ uuid_mock .assert_called_once_with ("mock_dag" , "mock_ti" , "mock_runid" , "42" )
557
565
558
566
extend_args = {}
559
567
if launch_type :
@@ -563,20 +571,14 @@ def test_reattach_successful(
563
571
if tags :
564
572
extend_args ["tags" ] = [{"key" : k , "value" : v } for (k , v ) in tags .items ()]
565
573
566
- client_mock .describe_task_definition .assert_called_once_with (taskDefinition = "t" )
567
-
568
- client_mock . list_tasks . assert_called_once_with ( cluster = "c" , desiredStatus = "RUNNING" , family = "f" )
574
+ client_mock .list_tasks .assert_called_once_with (
575
+ cluster = "c" , desiredStatus = "RUNNING" , startedBy = fake_uuid
576
+ )
569
577
570
578
start_mock .assert_not_called ()
571
- xcom_pull_mock .assert_called_once_with (
572
- self .mock_context ,
573
- key = self .ecs .REATTACH_XCOM_KEY ,
574
- task_ids = self .ecs .REATTACH_XCOM_TASK_ID_TEMPLATE .format (task_id = self .ecs .task_id ),
575
- )
576
579
wait_mock .assert_called_once_with ()
577
580
check_mock .assert_called_once_with ()
578
- xcom_del_mock .assert_called_once ()
579
- assert self .ecs .arn == f"arn:aws:ecs:us-east-1:012345678910:task/{ TASK_ID } "
581
+ assert self .ecs .arn == expected_arn
580
582
581
583
@pytest .mark .parametrize (
582
584
"launch_type, tags" ,
@@ -587,29 +589,25 @@ def test_reattach_successful(
587
589
["" , {"testTagKey" : "testTagValue" }],
588
590
],
589
591
)
590
- @mock .patch .object (EcsRunTaskOperator , "_xcom_del" )
591
- @mock .patch .object (EcsRunTaskOperator , "_try_reattach_task" )
592
+ @mock .patch ("airflow.providers.amazon.aws.operators.ecs.generate_uuid" )
592
593
@mock .patch .object (EcsRunTaskOperator , "_wait_for_task_ended" )
593
594
@mock .patch .object (EcsRunTaskOperator , "_check_success_task" )
594
595
@mock .patch .object (EcsBaseOperator , "client" )
595
596
def test_reattach_save_task_arn_xcom (
596
- self ,
597
- client_mock ,
598
- check_mock ,
599
- wait_mock ,
600
- reattach_mock ,
601
- xcom_del_mock ,
602
- launch_type ,
603
- tags ,
597
+ self , client_mock , check_mock , wait_mock , uuid_mock , launch_type , tags , caplog
604
598
):
599
+ """Test no reattach in no running Task started by this Task ID."""
600
+ mock_ti = mock .MagicMock (name = "MockedTaskInstance" )
601
+ mock_ti .key .primary = ("mock_dag" , "mock_ti" , "mock_runid" , 42 )
602
+ fake_uuid = "01-02-03-04"
603
+ uuid_mock .return_value = fake_uuid
605
604
606
605
self .set_up_operator (launch_type = launch_type , tags = tags )
607
- client_mock .describe_task_definition .return_value = {"taskDefinition" : {"family" : "f" }}
608
606
client_mock .list_tasks .return_value = {"taskArns" : []}
609
607
client_mock .run_task .return_value = RESPONSE_WITHOUT_FAILURES
610
608
611
609
self .ecs .reattach = True
612
- self .ecs .execute (self . mock_context )
610
+ self .ecs .execute ({ "ti" : mock_ti } )
613
611
614
612
extend_args = {}
615
613
if launch_type :
@@ -619,12 +617,14 @@ def test_reattach_save_task_arn_xcom(
619
617
if tags :
620
618
extend_args ["tags" ] = [{"key" : k , "value" : v } for (k , v ) in tags .items ()]
621
619
622
- reattach_mock .assert_called_once ()
620
+ client_mock .list_tasks .assert_called_once_with (
621
+ cluster = "c" , desiredStatus = "RUNNING" , startedBy = fake_uuid
622
+ )
623
623
client_mock .run_task .assert_called_once ()
624
624
wait_mock .assert_called_once_with ()
625
625
check_mock .assert_called_once_with ()
626
- xcom_del_mock .assert_called_once ()
627
626
assert self .ecs .arn == f"arn:aws:ecs:us-east-1:012345678910:task/{ TASK_ID } "
627
+ assert "No active previously launched task found to reattach" in caplog .messages
628
628
629
629
@mock .patch .object (EcsBaseOperator , "client" )
630
630
@mock .patch ("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher" )
@@ -670,17 +670,14 @@ def test_with_defer(self, client_mock):
670
670
assert deferred .value .trigger .task_arn == f"arn:aws:ecs:us-east-1:012345678910:task/{ TASK_ID } "
671
671
672
672
@mock .patch .object (EcsRunTaskOperator , "client" , new_callable = PropertyMock )
673
- @mock .patch .object (EcsRunTaskOperator , "_xcom_del" )
674
- def test_execute_complete (self , xcom_del_mock : MagicMock , client_mock ):
673
+ def test_execute_complete (self , client_mock ):
675
674
event = {"status" : "success" , "task_arn" : "my_arn" }
676
675
self .ecs .reattach = True
677
676
678
677
self .ecs .execute_complete (None , event )
679
678
680
679
# task gets described to assert its success
681
680
client_mock ().describe_tasks .assert_called_once_with (cluster = "c" , tasks = ["my_arn" ])
682
- # if reattach mode, xcom value is deleted on success
683
- xcom_del_mock .assert_called_once ()
684
681
685
682
686
683
class TestEcsCreateClusterOperator (EcsBaseTestCase ):
0 commit comments