Skip to content

respect soft_fail argument when exception is raised for sql sensors #34199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions airflow/providers/common/sql/sensors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from typing import Any, Sequence

from airflow import AirflowException
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.hooks.base import BaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.sensors.base import BaseSensorOperator
Expand Down Expand Up @@ -96,19 +96,37 @@ def poke(self, context: Any):
records = hook.get_records(self.sql, self.parameters)
if not records:
if self.fail_on_empty:
raise AirflowException("No rows returned, raising as per fail_on_empty flag")
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = "No rows returned, raising as per fail_on_empty flag"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
else:
return False

first_cell = records[0][0]
if self.failure is not None:
if callable(self.failure):
if self.failure(first_cell):
raise AirflowException(f"Failure criteria met. self.failure({first_cell}) returned True")
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = f"Failure criteria met. self.failure({first_cell}) returned True"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
else:
raise AirflowException(f"self.failure is present, but not callable -> {self.failure}")
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = f"self.failure is present, but not callable -> {self.failure}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)

if self.success is not None:
if callable(self.success):
return self.success(first_cell)
else:
raise AirflowException(f"self.success is present, but not callable -> {self.success}")
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = f"self.success is present, but not callable -> {self.success}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
return bool(first_cell)
72 changes: 57 additions & 15 deletions tests/providers/common/sql/sensors/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import pytest

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.models.dag import DAG
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.common.sql.sensors.sql import SqlSensor
Expand Down Expand Up @@ -117,17 +117,26 @@ def test_sql_sensor_postgres_poke(self, mock_hook):
mock_get_records.return_value = [["1"]]
assert op.poke(None)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
def test_sql_sensor_postgres_poke_fail_on_empty(self, mock_hook):
def test_sql_sensor_postgres_poke_fail_on_empty(
self, mock_hook, soft_fail: bool, expected_exception: AirflowException
):
op = SqlSensor(
task_id="sql_sensor_check", conn_id="postgres_default", sql="SELECT 1", fail_on_empty=True
task_id="sql_sensor_check",
conn_id="postgres_default",
sql="SELECT 1",
fail_on_empty=True,
soft_fail=soft_fail,
)

mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records

mock_get_records.return_value = []
with pytest.raises(AirflowException):
with pytest.raises(expected_exception):
op.poke(None)

@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
Expand All @@ -148,10 +157,19 @@ def test_sql_sensor_postgres_poke_success(self, mock_hook):
mock_get_records.return_value = [["1"]]
assert not op.poke(None)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
def test_sql_sensor_postgres_poke_failure(self, mock_hook):
def test_sql_sensor_postgres_poke_failure(
self, mock_hook, soft_fail: bool, expected_exception: AirflowException
):
op = SqlSensor(
task_id="sql_sensor_check", conn_id="postgres_default", sql="SELECT 1", failure=lambda x: x in [1]
task_id="sql_sensor_check",
conn_id="postgres_default",
sql="SELECT 1",
failure=lambda x: x in [1],
soft_fail=soft_fail,
)

mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
Expand All @@ -161,17 +179,23 @@ def test_sql_sensor_postgres_poke_failure(self, mock_hook):
assert not op.poke(None)

mock_get_records.return_value = [[1]]
with pytest.raises(AirflowException):
with pytest.raises(expected_exception):
op.poke(None)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
def test_sql_sensor_postgres_poke_failure_success(self, mock_hook):
def test_sql_sensor_postgres_poke_failure_success(
self, mock_hook, soft_fail: bool, expected_exception: AirflowException
):
op = SqlSensor(
task_id="sql_sensor_check",
conn_id="postgres_default",
sql="SELECT 1",
failure=lambda x: x in [1],
success=lambda x: x in [2],
soft_fail=soft_fail,
)

mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
Expand All @@ -181,20 +205,26 @@ def test_sql_sensor_postgres_poke_failure_success(self, mock_hook):
assert not op.poke(None)

mock_get_records.return_value = [[1]]
with pytest.raises(AirflowException):
with pytest.raises(expected_exception):
op.poke(None)

mock_get_records.return_value = [[2]]
assert op.poke(None)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
def test_sql_sensor_postgres_poke_failure_success_same(self, mock_hook):
def test_sql_sensor_postgres_poke_failure_success_same(
self, mock_hook, soft_fail: bool, expected_exception: AirflowException
):
op = SqlSensor(
task_id="sql_sensor_check",
conn_id="postgres_default",
sql="SELECT 1",
failure=lambda x: x in [1],
success=lambda x: x in [1],
soft_fail=soft_fail,
)

mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
Expand All @@ -204,40 +234,52 @@ def test_sql_sensor_postgres_poke_failure_success_same(self, mock_hook):
assert not op.poke(None)

mock_get_records.return_value = [[1]]
with pytest.raises(AirflowException):
with pytest.raises(expected_exception):
op.poke(None)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
def test_sql_sensor_postgres_poke_invalid_failure(self, mock_hook):
def test_sql_sensor_postgres_poke_invalid_failure(
self, mock_hook, soft_fail: bool, expected_exception: AirflowException
):
op = SqlSensor(
task_id="sql_sensor_check",
conn_id="postgres_default",
sql="SELECT 1",
failure=[1],
soft_fail=soft_fail,
)

mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records

mock_get_records.return_value = [[1]]
with pytest.raises(AirflowException) as ctx:
with pytest.raises(expected_exception) as ctx:
op.poke(None)
assert "self.failure is present, but not callable -> [1]" == str(ctx.value)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
def test_sql_sensor_postgres_poke_invalid_success(self, mock_hook):
def test_sql_sensor_postgres_poke_invalid_success(
self, mock_hook, soft_fail: bool, expected_exception: AirflowException
):
op = SqlSensor(
task_id="sql_sensor_check",
conn_id="postgres_default",
sql="SELECT 1",
success=[1],
soft_fail=soft_fail,
)

mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records

mock_get_records.return_value = [[1]]
with pytest.raises(AirflowException) as ctx:
with pytest.raises(expected_exception) as ctx:
op.poke(None)
assert "self.success is present, but not callable -> [1]" == str(ctx.value)

Expand Down