1
1
import asyncio
2
+ from abc import ABC
2
3
from typing import Any , AsyncIterator , Dict , Iterable , List , Optional , Tuple
3
4
4
5
from airflow .triggers .base import BaseTrigger , TriggerEvent
10
11
)
11
12
12
13
13
- class EmrContainerSensorTrigger (BaseTrigger ):
14
+ class EmrContainerBaseTrigger (BaseTrigger , ABC ):
14
15
"""
15
- The EmrContainerSensorTrigger is triggered when EMR container is created, it polls for the AWS EMR EKS Virtual
16
- Cluster Job status. It is fired as deferred class with params to run the task in trigger worker
16
+ Poll for the status of the EMR container until it reaches a terminal state
17
17
18
18
:param virtual_cluster_id: Reference Emr cluster id
19
19
:param job_id: job_id to check the state
20
- :param max_retries : maximum retry for poll for the status
20
+ :param max_tries: maximum number of attempts to poll for the status
21
21
:param aws_conn_id: Reference to AWS connection id
22
22
:param poll_interval: polling period in seconds to check for the status
23
23
"""
@@ -26,16 +26,21 @@ def __init__(
26
26
self ,
27
27
virtual_cluster_id : str ,
28
28
job_id : str ,
29
- max_retries : Optional [int ] = None ,
30
29
aws_conn_id : str = "aws_default" ,
31
30
poll_interval : int = 10 ,
31
+ max_tries : Optional [int ] = None ,
32
+ ** kwargs : Any ,
32
33
):
33
- super ().__init__ ()
34
34
self .virtual_cluster_id = virtual_cluster_id
35
35
self .job_id = job_id
36
- self .max_retries = max_retries
37
36
self .aws_conn_id = aws_conn_id
38
37
self .poll_interval = poll_interval
38
+ self .max_tries = max_tries
39
+ super ().__init__ (** kwargs )
40
+
41
+
42
+ class EmrContainerSensorTrigger (EmrContainerBaseTrigger ):
43
+ """Poll for the status of the EMR container until it reaches a terminal state"""
39
44
40
45
def serialize (self ) -> Tuple [str , Dict [str , Any ]]:
41
46
"""Serializes EmrContainerSensorTrigger arguments and classpath."""
@@ -44,9 +49,9 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
44
49
{
45
50
"virtual_cluster_id" : self .virtual_cluster_id ,
46
51
"job_id" : self .job_id ,
47
- "max_retries" : self .max_retries ,
48
- "poll_interval" : self .poll_interval ,
49
52
"aws_conn_id" : self .aws_conn_id ,
53
+ "max_tries" : self .max_tries ,
54
+ "poll_interval" : self .poll_interval ,
50
55
},
51
56
)
52
57
@@ -66,7 +71,7 @@ async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
66
71
msg = "EMR Containers sensors completed"
67
72
yield TriggerEvent ({"status" : "success" , "message" : msg })
68
73
69
- if self .max_retries and try_number >= self .max_retries :
74
+ if self .max_tries and try_number >= self .max_tries :
70
75
yield TriggerEvent (
71
76
{
72
77
"status" : "error" ,
@@ -80,53 +85,20 @@ async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
80
85
yield TriggerEvent ({"status" : "error" , "message" : str (e )})
81
86
82
87
83
- class EmrContainerOperatorTrigger (BaseTrigger ):
84
- """
85
- The EmrContainerSensorTrigger is triggered when EMR container is created, it polls for the AWS EMR EKS Virtual
86
- Cluster Job status. It is fired as deferred class with params to run the task in trigger worker
87
-
88
- :param virtual_cluster_id: The EMR on EKS virtual cluster ID.
89
- :param name: The name of the job run.
90
- :param execution_role_arn: The IAM role ARN associated with the job run.
91
- :param release_label: The Amazon EMR release version to use for the job run.
92
- :param job_driver: Job configuration details, e.g. the Spark job parameters.
93
- :param configuration_overrides: The configuration overrides for the job run,
94
- specifically either application configuration or monitoring configuration.
95
- :param client_request_token: The client idempotency token of the job run request.
96
- :param aws_conn_id: Reference to AWS connection id.
97
- :param poll_interval: polling period in seconds to check for the status.
98
- :param max_retries: maximum retry for poll for the status.
99
- """
88
+ class EmrContainerOperatorTrigger (EmrContainerBaseTrigger ):
89
+ """Poll for the status of the EMR container until it reaches a terminal state"""
100
90
101
91
INTERMEDIATE_STATES : List [str ] = ["PENDING" , "SUBMITTED" , "RUNNING" ]
102
92
FAILURE_STATES : List [str ] = ["FAILED" , "CANCELLED" , "CANCEL_PENDING" ]
103
93
SUCCESS_STATES : List [str ] = ["COMPLETED" ]
104
94
TERMINAL_STATES : List [str ] = ["COMPLETED" , "FAILED" , "CANCELLED" , "CANCEL_PENDING" ]
105
95
106
- def __init__ (
107
- self ,
108
- virtual_cluster_id : str ,
109
- name : str ,
110
- job_id : str ,
111
- aws_conn_id : str = "aws_default" ,
112
- poll_interval : int = 30 ,
113
- max_tries : Optional [int ] = None ,
114
- ):
115
- super ().__init__ ()
116
- self .virtual_cluster_id = virtual_cluster_id
117
- self .name = name
118
- self .job_id = job_id
119
- self .aws_conn_id = aws_conn_id
120
- self .poll_interval = poll_interval
121
- self .max_tries = max_tries
122
-
123
96
def serialize (self ) -> Tuple [str , Dict [str , Any ]]:
124
97
"""Serializes EmrContainerOperatorTrigger arguments and classpath."""
125
98
return (
126
99
"astronomer.providers.amazon.aws.triggers.emr.EmrContainerOperatorTrigger" ,
127
100
{
128
101
"virtual_cluster_id" : self .virtual_cluster_id ,
129
- "name" : self .name ,
130
102
"job_id" : self .job_id ,
131
103
"aws_conn_id" : self .aws_conn_id ,
132
104
"max_tries" : self .max_tries ,
0 commit comments