Skip to content

Add EMR Container Base Trigger #488

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 1, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion astronomer/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def execute(self, context: "Context") -> None:
timeout=self.execution_timeout,
trigger=EmrContainerOperatorTrigger(
virtual_cluster_id=self.virtual_cluster_id,
name=self.name,
job_id=job_id,
aws_conn_id=self.aws_conn_id,
poll_interval=self.poll_interval,
Expand Down
2 changes: 1 addition & 1 deletion astronomer/providers/amazon/aws/sensors/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def execute(self, context: Context) -> None:
trigger=EmrContainerSensorTrigger(
virtual_cluster_id=self.virtual_cluster_id,
job_id=self.job_id,
max_retries=self.max_retries,
max_tries=self.max_retries,
aws_conn_id=self.aws_conn_id,
poll_interval=self.poll_interval,
),
Expand Down
72 changes: 22 additions & 50 deletions astronomer/providers/amazon/aws/triggers/emr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from abc import ABC
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Tuple

from airflow.triggers.base import BaseTrigger, TriggerEvent
Expand All @@ -10,14 +11,13 @@
)


class EmrContainerSensorTrigger(BaseTrigger):
class EmrContainerBaseTrigger(BaseTrigger, ABC):
"""
The EmrContainerSensorTrigger is triggered when EMR container is created, it polls for the AWS EMR EKS Virtual
Cluster Job status. It is fired as deferred class with params to run the task in trigger worker
Poll for the status of EMR container until reaches terminal state

:param virtual_cluster_id: Reference Emr cluster id
:param job_id: job_id to check the state
:param max_retries: maximum retry for poll for the status
:param max_tries: maximum retry for poll for the status
:param aws_conn_id: Reference to AWS connection id
:param poll_interval: polling period in seconds to check for the status
"""
Expand All @@ -26,16 +26,21 @@ def __init__(
self,
virtual_cluster_id: str,
job_id: str,
max_retries: Optional[int] = None,
aws_conn_id: str = "aws_default",
poll_interval: int = 10,
max_tries: Optional[int] = None,
**kwargs: Any,
):
super().__init__()
self.virtual_cluster_id = virtual_cluster_id
self.job_id = job_id
self.max_retries = max_retries
self.aws_conn_id = aws_conn_id
self.poll_interval = poll_interval
self.max_tries = max_tries
super().__init__(**kwargs)


class EmrContainerSensorTrigger(EmrContainerBaseTrigger):
"""Poll for the status of EMR container until reaches terminal state"""

def serialize(self) -> Tuple[str, Dict[str, Any]]:
"""Serializes EmrContainerSensorTrigger arguments and classpath."""
Expand All @@ -44,9 +49,9 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
{
"virtual_cluster_id": self.virtual_cluster_id,
"job_id": self.job_id,
"max_retries": self.max_retries,
"poll_interval": self.poll_interval,
"aws_conn_id": self.aws_conn_id,
"max_tries": self.max_tries,
"poll_interval": self.poll_interval,
},
)

Expand All @@ -66,7 +71,7 @@ async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
msg = "EMR Containers sensors completed"
yield TriggerEvent({"status": "success", "message": msg})

if self.max_retries and try_number >= self.max_retries:
if self.max_tries and try_number >= self.max_tries:
yield TriggerEvent(
{
"status": "error",
Expand All @@ -80,60 +85,27 @@ async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
yield TriggerEvent({"status": "error", "message": str(e)})


class EmrContainerOperatorTrigger(BaseTrigger):
"""
The EmrContainerSensorTrigger is triggered when EMR container is created, it polls for the AWS EMR EKS Virtual
Cluster Job status. It is fired as deferred class with params to run the task in trigger worker

:param virtual_cluster_id: The EMR on EKS virtual cluster ID.
:param name: The name of the job run.
:param execution_role_arn: The IAM role ARN associated with the job run.
:param release_label: The Amazon EMR release version to use for the job run.
:param job_driver: Job configuration details, e.g. the Spark job parameters.
:param configuration_overrides: The configuration overrides for the job run,
specifically either application configuration or monitoring configuration.
:param client_request_token: The client idempotency token of the job run request.
:param aws_conn_id: Reference to AWS connection id.
:param poll_interval: polling period in seconds to check for the status.
:param max_retries: maximum retry for poll for the status.
"""

INTERMEDIATE_STATES: List[str] = ["PENDING", "SUBMITTED", "RUNNING"]
FAILURE_STATES: List[str] = ["FAILED", "CANCELLED", "CANCEL_PENDING"]
SUCCESS_STATES: List[str] = ["COMPLETED"]
TERMINAL_STATES: List[str] = ["COMPLETED", "FAILED", "CANCELLED", "CANCEL_PENDING"]

def __init__(
self,
virtual_cluster_id: str,
name: str,
job_id: str,
aws_conn_id: str = "aws_default",
poll_interval: int = 30,
max_tries: Optional[int] = None,
):
super().__init__()
self.virtual_cluster_id = virtual_cluster_id
self.name = name
self.job_id = job_id
self.aws_conn_id = aws_conn_id
self.poll_interval = poll_interval
self.max_tries = max_tries
class EmrContainerOperatorTrigger(EmrContainerBaseTrigger):
"""Poll for the status of EMR container until reaches terminal state"""

def serialize(self) -> Tuple[str, Dict[str, Any]]:
"""Serializes EmrContainerOperatorTrigger arguments and classpath."""
return (
"astronomer.providers.amazon.aws.triggers.emr.EmrContainerOperatorTrigger",
{
"virtual_cluster_id": self.virtual_cluster_id,
"name": self.name,
"job_id": self.job_id,
"aws_conn_id": self.aws_conn_id,
"max_tries": self.max_tries,
"poll_interval": self.poll_interval,
},
)

INTERMEDIATE_STATES: List[str] = ["PENDING", "SUBMITTED", "RUNNING"]
FAILURE_STATES: List[str] = ["FAILED", "CANCELLED", "CANCEL_PENDING"]
SUCCESS_STATES: List[str] = ["COMPLETED"]
TERMINAL_STATES: List[str] = ["COMPLETED", "FAILED", "CANCELLED", "CANCEL_PENDING"]

async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
"""Run until EMR container reaches the desire state"""
hook = EmrContainerHookAsync(aws_conn_id=self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)
Expand Down
17 changes: 8 additions & 9 deletions tests/amazon/aws/triggers/test_emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@

def test_emr_container_sensors_trigger_serialization():
"""
Asserts that the TaskStateTrigger correctly serializes its arguments
Asserts that the EmrContainerSensorTrigger correctly serializes its arguments
and classpath.
"""
trigger = EmrContainerSensorTrigger(
virtual_cluster_id=VIRTUAL_CLUSTER_ID,
job_id=JOB_ID,
max_retries=MAX_RETRIES,
max_tries=MAX_RETRIES,
aws_conn_id=AWS_CONN_ID,
poll_interval=POLL_INTERVAL,
)
Expand All @@ -69,7 +69,7 @@ def test_emr_container_sensors_trigger_serialization():
assert kwargs == {
"virtual_cluster_id": VIRTUAL_CLUSTER_ID,
"job_id": JOB_ID,
"max_retries": MAX_RETRIES,
"max_tries": MAX_RETRIES,
"poll_interval": POLL_INTERVAL,
"aws_conn_id": AWS_CONN_ID,
}
Expand All @@ -92,7 +92,7 @@ async def test_emr_container_sensors_trigger_run(mock_query_status, mock_status)
trigger = EmrContainerSensorTrigger(
virtual_cluster_id=VIRTUAL_CLUSTER_ID,
job_id=JOB_ID,
max_retries=MAX_RETRIES,
max_tries=MAX_RETRIES,
aws_conn_id=AWS_CONN_ID,
poll_interval=POLL_INTERVAL,
)
Expand All @@ -118,7 +118,7 @@ async def test_emr_container_sensors_trigger_completed(mock_query_status, mock_s
trigger = EmrContainerSensorTrigger(
virtual_cluster_id=VIRTUAL_CLUSTER_ID,
job_id=JOB_ID,
max_retries=MAX_RETRIES,
max_tries=MAX_RETRIES,
aws_conn_id=AWS_CONN_ID,
poll_interval=POLL_INTERVAL,
)
Expand All @@ -142,7 +142,7 @@ async def test_emr_container_sensors_trigger_failure_status(mock_query_status, m
trigger = EmrContainerSensorTrigger(
virtual_cluster_id=VIRTUAL_CLUSTER_ID,
job_id=JOB_ID,
max_retries=MAX_RETRIES,
max_tries=MAX_RETRIES,
aws_conn_id=AWS_CONN_ID,
poll_interval=POLL_INTERVAL,
)
Expand All @@ -162,7 +162,7 @@ async def test_emr_container_sensors_trigger_exception(mock_query_status):
trigger = EmrContainerSensorTrigger(
virtual_cluster_id=VIRTUAL_CLUSTER_ID,
job_id=JOB_ID,
max_retries=MAX_RETRIES,
max_tries=MAX_RETRIES,
aws_conn_id=AWS_CONN_ID,
poll_interval=POLL_INTERVAL,
)
Expand All @@ -181,7 +181,7 @@ async def test_emr_container_sensor_trigger_timeout(mock_query_status):
job_id=JOB_ID,
aws_conn_id=AWS_CONN_ID,
poll_interval=1,
max_retries=2,
max_tries=2,
)
generator = trigger.run()
actual = await generator.asend(None)
Expand Down Expand Up @@ -433,7 +433,6 @@ def test_emr_container_operator_trigger_serialization():
assert classpath == "astronomer.providers.amazon.aws.triggers.emr.EmrContainerOperatorTrigger"
assert kwargs == {
"virtual_cluster_id": VIRTUAL_CLUSTER_ID,
"name": NAME,
"job_id": JOB_ID,
"aws_conn_id": AWS_CONN_ID,
"poll_interval": POLL_INTERVAL,
Expand Down