Skip to content

Commit 3f297f8

Browse files
authored
Add EMR Container Base Trigger (#488)
* Add emr base container
* Fix mypy
* Change variable retries to tries
* Move class variable at top and docstring fix
1 parent 6101f75 commit 3f297f8

File tree

4 files changed

+26
-56
lines changed

4 files changed

+26
-56
lines changed

astronomer/providers/amazon/aws/operators/emr.py

-1
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,6 @@ def execute(self, context: "Context") -> None:
4444
timeout=self.execution_timeout,
4545
trigger=EmrContainerOperatorTrigger(
4646
virtual_cluster_id=self.virtual_cluster_id,
47-
name=self.name,
4847
job_id=job_id,
4948
aws_conn_id=self.aws_conn_id,
5049
poll_interval=self.poll_interval,

astronomer/providers/amazon/aws/sensors/emr.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ def execute(self, context: Context) -> None:
3737
trigger=EmrContainerSensorTrigger(
3838
virtual_cluster_id=self.virtual_cluster_id,
3939
job_id=self.job_id,
40-
max_retries=self.max_retries,
40+
max_tries=self.max_retries,
4141
aws_conn_id=self.aws_conn_id,
4242
poll_interval=self.poll_interval,
4343
),

astronomer/providers/amazon/aws/triggers/emr.py

+17-45
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
from abc import ABC
23
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Tuple
34

45
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -10,14 +11,13 @@
1011
)
1112

1213

13-
class EmrContainerSensorTrigger(BaseTrigger):
14+
class EmrContainerBaseTrigger(BaseTrigger, ABC):
1415
"""
15-
The EmrContainerSensorTrigger is triggered when EMR container is created, it polls for the AWS EMR EKS Virtual
16-
Cluster Job status. It is fired as deferred class with params to run the task in trigger worker
16+
Poll for the status of EMR container until reaches terminal state
1717
1818
:param virtual_cluster_id: Reference Emr cluster id
1919
:param job_id: job_id to check the state
20-
:param max_retries: maximum retry for poll for the status
20+
:param max_tries: maximum try attempts for polling the status
2121
:param aws_conn_id: Reference to AWS connection id
2222
:param poll_interval: polling period in seconds to check for the status
2323
"""
@@ -26,16 +26,21 @@ def __init__(
2626
self,
2727
virtual_cluster_id: str,
2828
job_id: str,
29-
max_retries: Optional[int] = None,
3029
aws_conn_id: str = "aws_default",
3130
poll_interval: int = 10,
31+
max_tries: Optional[int] = None,
32+
**kwargs: Any,
3233
):
33-
super().__init__()
3434
self.virtual_cluster_id = virtual_cluster_id
3535
self.job_id = job_id
36-
self.max_retries = max_retries
3736
self.aws_conn_id = aws_conn_id
3837
self.poll_interval = poll_interval
38+
self.max_tries = max_tries
39+
super().__init__(**kwargs)
40+
41+
42+
class EmrContainerSensorTrigger(EmrContainerBaseTrigger):
43+
"""Poll for the status of EMR container until reaches terminal state"""
3944

4045
def serialize(self) -> Tuple[str, Dict[str, Any]]:
4146
"""Serializes EmrContainerSensorTrigger arguments and classpath."""
@@ -44,9 +49,9 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
4449
{
4550
"virtual_cluster_id": self.virtual_cluster_id,
4651
"job_id": self.job_id,
47-
"max_retries": self.max_retries,
48-
"poll_interval": self.poll_interval,
4952
"aws_conn_id": self.aws_conn_id,
53+
"max_tries": self.max_tries,
54+
"poll_interval": self.poll_interval,
5055
},
5156
)
5257

@@ -66,7 +71,7 @@ async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
6671
msg = "EMR Containers sensors completed"
6772
yield TriggerEvent({"status": "success", "message": msg})
6873

69-
if self.max_retries and try_number >= self.max_retries:
74+
if self.max_tries and try_number >= self.max_tries:
7075
yield TriggerEvent(
7176
{
7277
"status": "error",
@@ -80,53 +85,20 @@ async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
8085
yield TriggerEvent({"status": "error", "message": str(e)})
8186

8287

83-
class EmrContainerOperatorTrigger(BaseTrigger):
84-
"""
85-
The EmrContainerSensorTrigger is triggered when EMR container is created, it polls for the AWS EMR EKS Virtual
86-
Cluster Job status. It is fired as deferred class with params to run the task in trigger worker
87-
88-
:param virtual_cluster_id: The EMR on EKS virtual cluster ID.
89-
:param name: The name of the job run.
90-
:param execution_role_arn: The IAM role ARN associated with the job run.
91-
:param release_label: The Amazon EMR release version to use for the job run.
92-
:param job_driver: Job configuration details, e.g. the Spark job parameters.
93-
:param configuration_overrides: The configuration overrides for the job run,
94-
specifically either application configuration or monitoring configuration.
95-
:param client_request_token: The client idempotency token of the job run request.
96-
:param aws_conn_id: Reference to AWS connection id.
97-
:param poll_interval: polling period in seconds to check for the status.
98-
:param max_retries: maximum retry for poll for the status.
99-
"""
88+
class EmrContainerOperatorTrigger(EmrContainerBaseTrigger):
89+
"""Poll for the status of EMR container until reaches terminal state"""
10090

10191
INTERMEDIATE_STATES: List[str] = ["PENDING", "SUBMITTED", "RUNNING"]
10292
FAILURE_STATES: List[str] = ["FAILED", "CANCELLED", "CANCEL_PENDING"]
10393
SUCCESS_STATES: List[str] = ["COMPLETED"]
10494
TERMINAL_STATES: List[str] = ["COMPLETED", "FAILED", "CANCELLED", "CANCEL_PENDING"]
10595

106-
def __init__(
107-
self,
108-
virtual_cluster_id: str,
109-
name: str,
110-
job_id: str,
111-
aws_conn_id: str = "aws_default",
112-
poll_interval: int = 30,
113-
max_tries: Optional[int] = None,
114-
):
115-
super().__init__()
116-
self.virtual_cluster_id = virtual_cluster_id
117-
self.name = name
118-
self.job_id = job_id
119-
self.aws_conn_id = aws_conn_id
120-
self.poll_interval = poll_interval
121-
self.max_tries = max_tries
122-
12396
def serialize(self) -> Tuple[str, Dict[str, Any]]:
12497
"""Serializes EmrContainerOperatorTrigger arguments and classpath."""
12598
return (
12699
"astronomer.providers.amazon.aws.triggers.emr.EmrContainerOperatorTrigger",
127100
{
128101
"virtual_cluster_id": self.virtual_cluster_id,
129-
"name": self.name,
130102
"job_id": self.job_id,
131103
"aws_conn_id": self.aws_conn_id,
132104
"max_tries": self.max_tries,

tests/amazon/aws/triggers/test_emr.py

+8-9
Original file line number | Diff line number | Diff line change
@@ -54,13 +54,13 @@
5454

5555
def test_emr_container_sensors_trigger_serialization():
5656
"""
57-
Asserts that the TaskStateTrigger correctly serializes its arguments
57+
Asserts that the EmrContainerSensorTrigger correctly serializes its arguments
5858
and classpath.
5959
"""
6060
trigger = EmrContainerSensorTrigger(
6161
virtual_cluster_id=VIRTUAL_CLUSTER_ID,
6262
job_id=JOB_ID,
63-
max_retries=MAX_RETRIES,
63+
max_tries=MAX_RETRIES,
6464
aws_conn_id=AWS_CONN_ID,
6565
poll_interval=POLL_INTERVAL,
6666
)
@@ -69,7 +69,7 @@ def test_emr_container_sensors_trigger_serialization():
6969
assert kwargs == {
7070
"virtual_cluster_id": VIRTUAL_CLUSTER_ID,
7171
"job_id": JOB_ID,
72-
"max_retries": MAX_RETRIES,
72+
"max_tries": MAX_RETRIES,
7373
"poll_interval": POLL_INTERVAL,
7474
"aws_conn_id": AWS_CONN_ID,
7575
}
@@ -92,7 +92,7 @@ async def test_emr_container_sensors_trigger_run(mock_query_status, mock_status)
9292
trigger = EmrContainerSensorTrigger(
9393
virtual_cluster_id=VIRTUAL_CLUSTER_ID,
9494
job_id=JOB_ID,
95-
max_retries=MAX_RETRIES,
95+
max_tries=MAX_RETRIES,
9696
aws_conn_id=AWS_CONN_ID,
9797
poll_interval=POLL_INTERVAL,
9898
)
@@ -118,7 +118,7 @@ async def test_emr_container_sensors_trigger_completed(mock_query_status, mock_s
118118
trigger = EmrContainerSensorTrigger(
119119
virtual_cluster_id=VIRTUAL_CLUSTER_ID,
120120
job_id=JOB_ID,
121-
max_retries=MAX_RETRIES,
121+
max_tries=MAX_RETRIES,
122122
aws_conn_id=AWS_CONN_ID,
123123
poll_interval=POLL_INTERVAL,
124124
)
@@ -142,7 +142,7 @@ async def test_emr_container_sensors_trigger_failure_status(mock_query_status, m
142142
trigger = EmrContainerSensorTrigger(
143143
virtual_cluster_id=VIRTUAL_CLUSTER_ID,
144144
job_id=JOB_ID,
145-
max_retries=MAX_RETRIES,
145+
max_tries=MAX_RETRIES,
146146
aws_conn_id=AWS_CONN_ID,
147147
poll_interval=POLL_INTERVAL,
148148
)
@@ -162,7 +162,7 @@ async def test_emr_container_sensors_trigger_exception(mock_query_status):
162162
trigger = EmrContainerSensorTrigger(
163163
virtual_cluster_id=VIRTUAL_CLUSTER_ID,
164164
job_id=JOB_ID,
165-
max_retries=MAX_RETRIES,
165+
max_tries=MAX_RETRIES,
166166
aws_conn_id=AWS_CONN_ID,
167167
poll_interval=POLL_INTERVAL,
168168
)
@@ -181,7 +181,7 @@ async def test_emr_container_sensor_trigger_timeout(mock_query_status):
181181
job_id=JOB_ID,
182182
aws_conn_id=AWS_CONN_ID,
183183
poll_interval=1,
184-
max_retries=2,
184+
max_tries=2,
185185
)
186186
generator = trigger.run()
187187
actual = await generator.asend(None)
@@ -433,7 +433,6 @@ def test_emr_container_operator_trigger_serialization():
433433
assert classpath == "astronomer.providers.amazon.aws.triggers.emr.EmrContainerOperatorTrigger"
434434
assert kwargs == {
435435
"virtual_cluster_id": VIRTUAL_CLUSTER_ID,
436-
"name": NAME,
437436
"job_id": JOB_ID,
438437
"aws_conn_id": AWS_CONN_ID,
439438
"poll_interval": POLL_INTERVAL,

0 commit comments

Comments
 (0)