Skip to content

Handle soft fail for s3keysensor #1161

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 5, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions astronomer/providers/amazon/aws/sensors/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from datetime import timedelta
from typing import Any, Callable, List, Sequence, cast

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor, S3KeysUnchangedSensor
from airflow.sensors.base import BaseSensorOperator

Expand Down Expand Up @@ -81,7 +81,14 @@ def __init__(

def execute(self, context: Context) -> None:
"""Check for a keys in s3 and defers using the trigger"""
if not self.poke(context):
try:
poke = self.poke(context)
except Exception as e:
if self.soft_fail:
raise AirflowSkipException(f"{self.task_id} failed")
else:
raise e
if not poke:
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=S3KeyTrigger(
Expand All @@ -92,6 +99,7 @@ def execute(self, context: Context) -> None:
aws_conn_id=self.aws_conn_id,
verify=self.verify,
poke_interval=self.poke_interval,
soft_fail=self.soft_fail,
),
method_name="execute_complete",
)
Expand All @@ -103,6 +111,8 @@ def execute_complete(self, context: Context, event: Any = None) -> bool | None:
successful.
"""
if event["status"] == "error":
if event["soft_fail"]:
raise AirflowSkipException(event["message"])
raise AirflowException(event["message"])
elif event["status"] == "success" and "s3_objects" in event:
files = typing.cast(List[str], event["s3_objects"])
Expand Down
6 changes: 5 additions & 1 deletion astronomer/providers/amazon/aws/triggers/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class S3KeyTrigger(BaseTrigger):
Unix wildcard pattern
:param aws_conn_id: reference to the s3 connection
:param hook_params: params for the hook; optional
:param soft_fail: Set to true to mark the task as SKIPPED on failure
:param check_fn: Function that receives the list of the S3 objects,
and returns a boolean
"""
Expand All @@ -34,6 +35,7 @@ def __init__(
check_fn: Callable[..., bool] | None = None,
aws_conn_id: str = "aws_default",
poke_interval: float = 5.0,
soft_fail: bool = False,
**hook_params: Any,
):
super().__init__()
Expand All @@ -44,6 +46,7 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.hook_params = hook_params
self.poke_interval = poke_interval
self.soft_fail = soft_fail

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize S3KeyTrigger arguments and classpath."""
Expand All @@ -57,6 +60,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"aws_conn_id": self.aws_conn_id,
"hook_params": self.hook_params,
"poke_interval": self.poke_interval,
"soft_fail": self.soft_fail,
},
)

Expand All @@ -77,7 +81,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
await asyncio.sleep(self.poke_interval)

except Exception as e:
yield TriggerEvent({"status": "error", "message": str(e)})
yield TriggerEvent({"status": "error", "message": str(e), "soft_fail": self.soft_fail})

def _get_async_hook(self) -> S3HookAsync:
return S3HookAsync(aws_conn_id=self.aws_conn_id, verify=self.hook_params.get("verify"))
Expand Down
32 changes: 30 additions & 2 deletions tests/amazon/aws/sensors/test_s3_sensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from unittest import mock

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred
from airflow.models import DAG, DagRun, TaskInstance
from airflow.models.variable import Variable
from airflow.utils import timezone
Expand Down Expand Up @@ -165,7 +165,9 @@ def test_s3_key_sensor_execute_complete_error(self, key, bucket, mock_hook, mock
bucket_name=bucket,
)
with pytest.raises(AirflowException):
sensor.execute_complete(context={}, event={"status": "error", "message": "mocked error"})
sensor.execute_complete(
context={}, event={"status": "error", "message": "mocked error", "soft_fail": False}
)

@parameterized.expand(
[
Expand Down Expand Up @@ -243,6 +245,32 @@ def test_s3_key_sensor_with_wildcard_async(self, mock_hook, mock_poke, context):

assert isinstance(exc.value.trigger, S3KeyTrigger), "Trigger is not a S3KeyTrigger"

def test_soft_fail(self):
"""Raise AirflowSkipException in case soft_fail is true"""
sensor = S3KeySensorAsync(
task_id="s3_key_sensor_async", bucket_key="key", bucket_name="bucket", soft_fail=True
)
with pytest.raises(AirflowSkipException):
sensor.execute_complete(
context={}, event={"status": "error", "message": "mocked error", "soft_fail": True}
)

@pytest.mark.parametrize(
"soft_fail,exec",
[
(True, AirflowSkipException),
(False, Exception),
],
)
@mock.patch(f"{MODULE}.S3KeySensorAsync.poke")
def test_execute_handle_exception(self, mock_poke, soft_fail, exec):
mock_poke.side_effect = Exception()
sensor = S3KeySensorAsync(
task_id="s3_key_sensor_async", bucket_key="key", bucket_name="bucket", soft_fail=soft_fail
)
with pytest.raises(exec):
sensor.execute(context={})


class TestS3KeysUnchangedSensorAsync:
@mock.patch(f"{MODULE}.S3KeysUnchangedSensorAsync.defer")
Expand Down
2 changes: 2 additions & 0 deletions tests/amazon/aws/triggers/test_s3_triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_serialization(self):
"aws_conn_id": "aws_default",
"hook_params": {},
"check_fn": None,
"soft_fail": False,
"poke_interval": 5.0,
}

Expand Down Expand Up @@ -76,6 +77,7 @@ async def test_run_exception(self, mock_client):
{
"message": "Unable to locate credentials",
"status": "error",
"soft_fail": False,
}
)
== actual
Expand Down