Skip to content

Commit f5c2748

Browse files
authored
fix(providers/sql): respect soft_fail argument when exception is raised (#34199)
1 parent c5016f7 commit f5c2748

File tree

2 files changed

+80
-20
lines changed
  • airflow/providers/common/sql/sensors
  • tests/providers/common/sql/sensors

2 files changed

+80
-20
lines changed

airflow/providers/common/sql/sensors/sql.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from typing import Any, Sequence
2020

21-
from airflow import AirflowException
21+
from airflow.exceptions import AirflowException, AirflowSkipException
2222
from airflow.hooks.base import BaseHook
2323
from airflow.providers.common.sql.hooks.sql import DbApiHook
2424
from airflow.sensors.base import BaseSensorOperator
@@ -96,19 +96,37 @@ def poke(self, context: Any):
9696
records = hook.get_records(self.sql, self.parameters)
9797
if not records:
9898
if self.fail_on_empty:
99-
raise AirflowException("No rows returned, raising as per fail_on_empty flag")
99+
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
100+
message = "No rows returned, raising as per fail_on_empty flag"
101+
if self.soft_fail:
102+
raise AirflowSkipException(message)
103+
raise AirflowException(message)
100104
else:
101105
return False
106+
102107
first_cell = records[0][0]
103108
if self.failure is not None:
104109
if callable(self.failure):
105110
if self.failure(first_cell):
106-
raise AirflowException(f"Failure criteria met. self.failure({first_cell}) returned True")
111+
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
112+
message = f"Failure criteria met. self.failure({first_cell}) returned True"
113+
if self.soft_fail:
114+
raise AirflowSkipException(message)
115+
raise AirflowException(message)
107116
else:
108-
raise AirflowException(f"self.failure is present, but not callable -> {self.failure}")
117+
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
118+
message = f"self.failure is present, but not callable -> {self.failure}"
119+
if self.soft_fail:
120+
raise AirflowSkipException(message)
121+
raise AirflowException(message)
122+
109123
if self.success is not None:
110124
if callable(self.success):
111125
return self.success(first_cell)
112126
else:
113-
raise AirflowException(f"self.success is present, but not callable -> {self.success}")
127+
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
128+
message = f"self.success is present, but not callable -> {self.success}"
129+
if self.soft_fail:
130+
raise AirflowSkipException(message)
131+
raise AirflowException(message)
114132
return bool(first_cell)

tests/providers/common/sql/sensors/test_sql.py

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import pytest
2323

24-
from airflow.exceptions import AirflowException
24+
from airflow.exceptions import AirflowException, AirflowSkipException
2525
from airflow.models.dag import DAG
2626
from airflow.providers.common.sql.hooks.sql import DbApiHook
2727
from airflow.providers.common.sql.sensors.sql import SqlSensor
@@ -117,17 +117,26 @@ def test_sql_sensor_postgres_poke(self, mock_hook):
117117
mock_get_records.return_value = [["1"]]
118118
assert op.poke(None)
119119

120+
@pytest.mark.parametrize(
121+
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
122+
)
120123
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
121-
def test_sql_sensor_postgres_poke_fail_on_empty(self, mock_hook):
124+
def test_sql_sensor_postgres_poke_fail_on_empty(
125+
self, mock_hook, soft_fail: bool, expected_exception: AirflowException
126+
):
122127
op = SqlSensor(
123-
task_id="sql_sensor_check", conn_id="postgres_default", sql="SELECT 1", fail_on_empty=True
128+
task_id="sql_sensor_check",
129+
conn_id="postgres_default",
130+
sql="SELECT 1",
131+
fail_on_empty=True,
132+
soft_fail=soft_fail,
124133
)
125134

126135
mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
127136
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
128137

129138
mock_get_records.return_value = []
130-
with pytest.raises(AirflowException):
139+
with pytest.raises(expected_exception):
131140
op.poke(None)
132141

133142
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
@@ -148,10 +157,19 @@ def test_sql_sensor_postgres_poke_success(self, mock_hook):
148157
mock_get_records.return_value = [["1"]]
149158
assert not op.poke(None)
150159

160+
@pytest.mark.parametrize(
161+
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
162+
)
151163
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
152-
def test_sql_sensor_postgres_poke_failure(self, mock_hook):
164+
def test_sql_sensor_postgres_poke_failure(
165+
self, mock_hook, soft_fail: bool, expected_exception: AirflowException
166+
):
153167
op = SqlSensor(
154-
task_id="sql_sensor_check", conn_id="postgres_default", sql="SELECT 1", failure=lambda x: x in [1]
168+
task_id="sql_sensor_check",
169+
conn_id="postgres_default",
170+
sql="SELECT 1",
171+
failure=lambda x: x in [1],
172+
soft_fail=soft_fail,
155173
)
156174

157175
mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
@@ -161,17 +179,23 @@ def test_sql_sensor_postgres_poke_failure(self, mock_hook):
161179
assert not op.poke(None)
162180

163181
mock_get_records.return_value = [[1]]
164-
with pytest.raises(AirflowException):
182+
with pytest.raises(expected_exception):
165183
op.poke(None)
166184

185+
@pytest.mark.parametrize(
186+
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
187+
)
167188
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
168-
def test_sql_sensor_postgres_poke_failure_success(self, mock_hook):
189+
def test_sql_sensor_postgres_poke_failure_success(
190+
self, mock_hook, soft_fail: bool, expected_exception: AirflowException
191+
):
169192
op = SqlSensor(
170193
task_id="sql_sensor_check",
171194
conn_id="postgres_default",
172195
sql="SELECT 1",
173196
failure=lambda x: x in [1],
174197
success=lambda x: x in [2],
198+
soft_fail=soft_fail,
175199
)
176200

177201
mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
@@ -181,20 +205,26 @@ def test_sql_sensor_postgres_poke_failure_success(self, mock_hook):
181205
assert not op.poke(None)
182206

183207
mock_get_records.return_value = [[1]]
184-
with pytest.raises(AirflowException):
208+
with pytest.raises(expected_exception):
185209
op.poke(None)
186210

187211
mock_get_records.return_value = [[2]]
188212
assert op.poke(None)
189213

214+
@pytest.mark.parametrize(
215+
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
216+
)
190217
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
191-
def test_sql_sensor_postgres_poke_failure_success_same(self, mock_hook):
218+
def test_sql_sensor_postgres_poke_failure_success_same(
219+
self, mock_hook, soft_fail: bool, expected_exception: AirflowException
220+
):
192221
op = SqlSensor(
193222
task_id="sql_sensor_check",
194223
conn_id="postgres_default",
195224
sql="SELECT 1",
196225
failure=lambda x: x in [1],
197226
success=lambda x: x in [1],
227+
soft_fail=soft_fail,
198228
)
199229

200230
mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
@@ -204,40 +234,52 @@ def test_sql_sensor_postgres_poke_failure_success_same(self, mock_hook):
204234
assert not op.poke(None)
205235

206236
mock_get_records.return_value = [[1]]
207-
with pytest.raises(AirflowException):
237+
with pytest.raises(expected_exception):
208238
op.poke(None)
209239

240+
@pytest.mark.parametrize(
241+
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
242+
)
210243
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
211-
def test_sql_sensor_postgres_poke_invalid_failure(self, mock_hook):
244+
def test_sql_sensor_postgres_poke_invalid_failure(
245+
self, mock_hook, soft_fail: bool, expected_exception: AirflowException
246+
):
212247
op = SqlSensor(
213248
task_id="sql_sensor_check",
214249
conn_id="postgres_default",
215250
sql="SELECT 1",
216251
failure=[1],
252+
soft_fail=soft_fail,
217253
)
218254

219255
mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
220256
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
221257

222258
mock_get_records.return_value = [[1]]
223-
with pytest.raises(AirflowException) as ctx:
259+
with pytest.raises(expected_exception) as ctx:
224260
op.poke(None)
225261
assert "self.failure is present, but not callable -> [1]" == str(ctx.value)
226262

263+
@pytest.mark.parametrize(
264+
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
265+
)
227266
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
228-
def test_sql_sensor_postgres_poke_invalid_success(self, mock_hook):
267+
def test_sql_sensor_postgres_poke_invalid_success(
268+
self, mock_hook, soft_fail: bool, expected_exception: AirflowException
269+
):
229270
op = SqlSensor(
230271
task_id="sql_sensor_check",
231272
conn_id="postgres_default",
232273
sql="SELECT 1",
233274
success=[1],
275+
soft_fail=soft_fail,
234276
)
235277

236278
mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
237279
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
238280

239281
mock_get_records.return_value = [[1]]
240-
with pytest.raises(AirflowException) as ctx:
282+
with pytest.raises(expected_exception) as ctx:
241283
op.poke(None)
242284
assert "self.success is present, but not callable -> [1]" == str(ctx.value)
243285

0 commit comments

Comments
 (0)