
Commit 581e2e4

Change AirflowTaskTimeout to inherit BaseException (#35653)
Code that normally catches Exception should not implicitly ignore interrupts from AirflowTaskTimeout. Fixes #35644 #35474
1 parent 69d48ed commit 581e2e4
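For context, a hedged sketch (not part of this commit) of the failure mode the change addresses; fetch_with_retries and the sleep are illustrative stand-ins for a task callable running under an execution_timeout:

import time


def fetch_with_retries():
    # Illustrative task callable; names are hypothetical.
    for _ in range(3):
        try:
            time.sleep(60)  # stand-in for slow work the timeout must interrupt
        except Exception:
            # Before this commit, AirflowTaskTimeout (then an AirflowException) landed
            # here and the loop simply retried. With AirflowTaskTimeout inheriting
            # BaseException, the interrupt now propagates past this handler.
            continue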

File tree

7 files changed: +47 −15 lines


airflow/exceptions.py

Lines changed: 4 additions & 1 deletion
@@ -79,7 +79,10 @@ class InvalidStatsNameException(AirflowException):
     """Raise when name of the stats is invalid."""
 
 
-class AirflowTaskTimeout(AirflowException):
+# Important to inherit BaseException instead of AirflowException->Exception, since this Exception is used
+# to explicitly interrupt an ongoing task. Code that does normal error handling should not treat
+# such an interrupt as an error that can be handled normally. (Compare with KeyboardInterrupt.)
+class AirflowTaskTimeout(BaseException):
     """Raise when the task execution times-out."""
 
 
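As a consequence of the new base class, a plain `except Exception` no longer matches the timeout; only `BaseException` or the class itself does. A minimal sketch against Airflow versions that contain this change:

from airflow.exceptions import AirflowException, AirflowTaskTimeout

assert issubclass(AirflowTaskTimeout, BaseException)
assert not issubclass(AirflowTaskTimeout, Exception)
assert not issubclass(AirflowTaskTimeout, AirflowException)

try:
    raise AirflowTaskTimeout("deadline exceeded")
except Exception:
    raise AssertionError("unreachable: Exception no longer matches AirflowTaskTimeout")
except AirflowTaskTimeout as err:
    print(f"caught explicitly: {err}")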
airflow/models/taskinstance.py

Lines changed: 8 additions & 8 deletions
@@ -812,7 +812,7 @@ def _is_eligible_to_retry(*, task_instance: TaskInstance | TaskInstancePydantic)
 def _handle_failure(
     *,
     task_instance: TaskInstance | TaskInstancePydantic,
-    error: None | str | Exception | KeyboardInterrupt,
+    error: None | str | BaseException,
     session: Session,
     test_mode: bool | None = None,
     context: Context | None = None,
@@ -2411,7 +2411,7 @@ def _run_raw_task(
             self.handle_failure(e, test_mode, context, force_fail=True, session=session)
             session.commit()
             raise
-        except AirflowException as e:
+        except (AirflowTaskTimeout, AirflowException) as e:
             if not test_mode:
                 self.refresh_from_db(lock_for_update=True, session=session)
             # for case when task is marked as success/failed externally
@@ -2426,17 +2426,17 @@ def _run_raw_task(
             self.handle_failure(e, test_mode, context, session=session)
             session.commit()
             raise
-        except (Exception, KeyboardInterrupt) as e:
-            self.handle_failure(e, test_mode, context, session=session)
-            session.commit()
-            raise
         except SystemExit as e:
             # We have already handled SystemExit with success codes (0 and None) in the `_execute_task`.
             # Therefore, here we must handle only error codes.
             msg = f"Task failed due to SystemExit({e.code})"
             self.handle_failure(msg, test_mode, context, session=session)
             session.commit()
             raise Exception(msg)
+        except BaseException as e:
+            self.handle_failure(e, test_mode, context, session=session)
+            session.commit()
+            raise
         finally:
             Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}", tags=self.stats_tags)
             # Same metric with tagging
@@ -2743,7 +2743,7 @@ def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) -
     def fetch_handle_failure_context(
         cls,
         ti: TaskInstance | TaskInstancePydantic,
-        error: None | str | Exception | KeyboardInterrupt,
+        error: None | str | BaseException,
         test_mode: bool | None = None,
         context: Context | None = None,
         force_fail: bool = False,
@@ -2838,7 +2838,7 @@ def save_to_db(ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_S
     @provide_session
     def handle_failure(
         self,
-        error: None | str | Exception | KeyboardInterrupt,
+        error: None | str | BaseException,
         test_mode: bool | None = None,
         context: Context | None = None,
         force_fail: bool = False,
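The reworked handler relies on ordinary except-clause ordering: known Airflow errors (now including the timeout) first, SystemExit next, then a BaseException catch-all that still records the failure and re-raises. A simplified sketch of that shape, not the actual method:

from airflow.exceptions import AirflowException, AirflowTaskTimeout


def run_with_failure_handling(execute, handle_failure):
    # Simplified shape of the new _run_raw_task error handling; the real method
    # also refreshes the task instance, commits the session, and emits metrics.
    try:
        execute()
    except (AirflowTaskTimeout, AirflowException) as e:
        handle_failure(e)  # known Airflow errors, including the timeout interrupt
        raise
    except SystemExit as e:
        msg = f"Task failed due to SystemExit({e.code})"
        handle_failure(msg)
        raise Exception(msg)
    except BaseException as e:
        handle_failure(e)  # everything else, e.g. KeyboardInterrupt
        raise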

airflow/providers/celery/executors/celery_executor_utils.py

Lines changed: 3 additions & 3 deletions
@@ -41,7 +41,7 @@
 
 import airflow.settings as settings
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowTaskTimeout
 from airflow.executors.base_executor import BaseExecutor
 from airflow.stats import Stats
 from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
@@ -198,7 +198,7 @@ class ExceptionWithTraceback:
     :param exception_traceback: The stacktrace to wrap
     """
 
-    def __init__(self, exception: Exception, exception_traceback: str):
+    def __init__(self, exception: BaseException, exception_traceback: str):
         self.exception = exception
         self.traceback = exception_traceback
 
@@ -211,7 +211,7 @@ def send_task_to_executor(
     try:
         with timeout(seconds=OPERATION_TIMEOUT):
             result = task_to_run.apply_async(args=[command], queue=queue)
-    except Exception as e:
+    except (Exception, AirflowTaskTimeout) as e:
        exception_traceback = f"Celery Task ID: {key}\n{traceback.format_exc()}"
        result = ExceptionWithTraceback(e, exception_traceback)
 

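Downstream code with the same shape — work wrapped in airflow.utils.timeout and a broad error handler — now has to name AirflowTaskTimeout explicitly if it wants to convert the timeout into a value instead of letting it propagate. A hedged sketch; call_with_deadline and do_request are illustrative placeholders:

from airflow.exceptions import AirflowTaskTimeout
from airflow.utils.timeout import timeout


def call_with_deadline(do_request, seconds=10):
    # Mirror of the send_task_to_executor pattern above: convert both ordinary
    # errors and the timeout interrupt into a returned value instead of raising.
    try:
        with timeout(seconds=seconds):
            return do_request()
    except (Exception, AirflowTaskTimeout) as e:
        return e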
airflow/utils/context.pyi

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ class Context(TypedDict, total=False):
     data_interval_start: DateTime
     ds: str
     ds_nodash: str
-    exception: KeyboardInterrupt | Exception | str | None
+    exception: BaseException | str | None
     execution_date: DateTime
     expanded_ti_count: int | None
     inlets: list
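For DAG authors, the widened type mainly shows up in failure callbacks: context["exception"] may now be any BaseException (or a string), not just Exception or KeyboardInterrupt. A hedged sketch of such a callback; notify_on_failure is a hypothetical on_failure_callback:

from airflow.exceptions import AirflowTaskTimeout


def notify_on_failure(context):
    # Only the type handling is the point here.
    err = context.get("exception")
    if isinstance(err, AirflowTaskTimeout):
        print("task was interrupted by its execution_timeout")
    elif isinstance(err, BaseException):
        print(f"task failed with {type(err).__name__}: {err}")
    else:
        print(f"task failed: {err}")  # a plain string or None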

newsfragments/35653.significant.rst

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+``AirflowTaskTimeout`` is no longer implicitly caught by ``except Exception``
+
+``AirflowTaskTimeout`` now inherits from ``BaseException`` instead of
+``AirflowException`` -> ``Exception``.
+See https://docs.python.org/3/library/exceptions.html#exception-hierarchy
+
+This prevents code that catches ``Exception`` from accidentally
+catching ``AirflowTaskTimeout`` and continuing to run.
+``AirflowTaskTimeout`` is an explicit intent to cancel the task, and should not
+be caught in an attempt to handle the error and return some default value.
+
+Catching ``AirflowTaskTimeout`` is still possible by explicitly ``except``-ing
+``AirflowTaskTimeout`` or ``BaseException``.
+This is discouraged, as it may allow the code to continue running even after
+such a cancellation request.
+Code that previously relied on performing strict cleanup in every situation
+after catching ``Exception`` is advised to use ``finally`` blocks or
+context managers to perform only the cleanup and then automatically
+re-raise the exception.
+See the similar considerations about catching ``KeyboardInterrupt`` in
+https://docs.python.org/3/library/exceptions.html#KeyboardInterrupt
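A minimal sketch of the cleanup advice above, assuming a task callable that holds a resource: the finally block runs even when AirflowTaskTimeout interrupts the work, and the interrupt still propagates so the task is marked as timed out. do_work and the file path are illustrative only:

import time


def do_work(handle):
    # Stand-in for long-running work that an execution_timeout may interrupt.
    time.sleep(60)
    handle.write(b"done")


def task_callable():
    handle = open("/tmp/scratch.dat", "wb")  # illustrative resource only
    try:
        do_work(handle)
    finally:
        handle.close()  # runs whether do_work finishes or is interrupted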

tests/core/test_core.py

Lines changed: 8 additions & 1 deletion
@@ -71,11 +71,18 @@ class InvalidTemplateFieldOperator(BaseOperator):
        op.dry_run()
 
     def test_timeout(self, dag_maker):
+        def sleep_and_catch_other_exceptions():
+            try:
+                sleep(5)
+                # Catching Exception should NOT catch AirflowTaskTimeout
+            except Exception:
+                pass
+
         with dag_maker():
             op = PythonOperator(
                 task_id="test_timeout",
                 execution_timeout=timedelta(seconds=1),
-                python_callable=lambda: sleep(5),
+                python_callable=sleep_and_catch_other_exceptions,
             )
         dag_maker.create_dagrun()
         with pytest.raises(AirflowTaskTimeout):

tests/providers/microsoft/azure/hooks/test_synapse.py

Lines changed: 2 additions & 1 deletion
@@ -21,6 +21,7 @@
 import pytest
 from azure.synapse.spark import SparkClient
 
+from airflow.exceptions import AirflowTaskTimeout
 from airflow.models.connection import Connection
 from airflow.providers.microsoft.azure.hooks.synapse import AzureSynapseHook, AzureSynapseSparkBatchRunStatus
 
@@ -172,7 +173,7 @@ def test_wait_for_job_run_status(hook, job_run_status, expected_status, expected
     if expected_output != "timeout":
         assert hook.wait_for_job_run_status(**config) == expected_output
     else:
-        with pytest.raises(Exception):
+        with pytest.raises(AirflowTaskTimeout):
             hook.wait_for_job_run_status(**config)
 
