Commit 6a3946a

Trying something

1 parent 1ab2474 commit 6a3946a

2 files changed: +69 -158 lines changed

airflow-core/src/airflow/models/taskinstance.py

Lines changed: 22 additions & 154 deletions
@@ -78,11 +78,8 @@
 from airflow.configuration import conf
 from airflow.exceptions import (
     AirflowException,
-    AirflowFailException,
     AirflowInactiveAssetInInletOrOutletException,
     AirflowRescheduleException,
-    AirflowSensorTimeout,
-    AirflowSkipException,
     AirflowTaskTerminated,
     AirflowTaskTimeout,
     TaskDeferralError,

@@ -118,7 +115,6 @@
 from airflow.utils.span_status import SpanStatus
 from airflow.utils.sqlalchemy import ExecutorConfigType, ExtendedJSON, UtcDateTime
 from airflow.utils.state import DagRunState, State, TaskInstanceState
-from airflow.utils.task_instance_session import set_current_task_instance_session
 from airflow.utils.timeout import timeout
 from airflow.utils.xcom import XCOM_RETURN_KEY

@@ -1692,161 +1688,33 @@ def _run_raw_task(
         :param pool: specifies the pool to use to run the task instance
         :param session: SQLAlchemy ORM Session
         """
-        if TYPE_CHECKING:
-            assert self.task
-
-        if TYPE_CHECKING:
-            assert isinstance(self.task, BaseOperator)
-
-        self.test_mode = test_mode
-        self.refresh_from_task(self.task, pool_override=pool)
-        self.refresh_from_db(session=session)
-        self.hostname = get_hostname()
-        self.pid = os.getpid()
-        if not test_mode:
-            TaskInstance.save_to_db(ti=self, session=session)
-        actual_start_date = timezone.utcnow()
-        Stats.incr(f"ti.start.{self.task.dag_id}.{self.task.task_id}", tags=self.stats_tags)
-        # Same metric with tagging
-        Stats.incr("ti.start", tags=self.stats_tags)
-        # Initialize final state counters at zero
-        for state in State.task_states:
-            Stats.incr(
-                f"ti.finish.{self.task.dag_id}.{self.task.task_id}.{state}",
-                count=0,
-                tags=self.stats_tags,
-            )
-            # Same metric with tagging
-            Stats.incr(
-                "ti.finish",
-                count=0,
-                tags={**self.stats_tags, "state": str(state)},
-            )
-        with set_current_task_instance_session(session=session):
-            self.task = self.task.prepare_for_execution()
-            context = self.get_template_context(ignore_param_exceptions=False, session=session)
-
-            try:
-                if self.task:
-                    from airflow.sdk.definitions.asset import Asset
-
-                    inlets = [asset.asprofile() for asset in self.task.inlets if isinstance(asset, Asset)]
-                    outlets = [asset.asprofile() for asset in self.task.outlets if isinstance(asset, Asset)]
-                    TaskInstance.validate_inlet_outlet_assets_activeness(inlets, outlets, session=session)
-                if not mark_success:
-                    TaskInstance._execute_task_with_callbacks(
-                        self=self,  # type: ignore[arg-type]
-                        context=context,
-                        test_mode=test_mode,
-                        session=session,
-                    )
-                if not test_mode:
-                    self.refresh_from_db(lock_for_update=True, session=session, keep_local_changes=True)
-                self.state = TaskInstanceState.SUCCESS
-            except TaskDeferred as defer:
-                # The task has signalled it wants to defer execution based on
-                # a trigger.
-                if raise_on_defer:
-                    raise
-                self.defer_task(exception=defer, session=session)
-                self.log.info(
-                    "Pausing task as DEFERRED. dag_id=%s, task_id=%s, run_id=%s, logical_date=%s, start_date=%s",
-                    self.dag_id,
-                    self.task_id,
-                    self.run_id,
-                    _date_or_empty(task_instance=self, attr="logical_date"),
-                    _date_or_empty(task_instance=self, attr="start_date"),
-                )
-                return TaskReturnCode.DEFERRED
-            except AirflowSkipException as e:
-                # Recording SKIP
-                # log only if exception has any arguments to prevent log flooding
-                if e.args:
-                    self.log.info(e)
-                if not test_mode:
-                    self.refresh_from_db(lock_for_update=True, session=session, keep_local_changes=True)
-                self.state = TaskInstanceState.SKIPPED
-                _run_finished_callback(callbacks=self.task.on_skipped_callback, context=context)
-                TaskInstance.save_to_db(ti=self, session=session)
-            except AirflowRescheduleException as reschedule_exception:
-                self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session)
-                self.log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
-                return None
-            except (AirflowFailException, AirflowSensorTimeout) as e:
-                # If AirflowFailException is raised, task should not retry.
-                # If a sensor in reschedule mode reaches timeout, task should not retry.
-                self.handle_failure(
-                    e, test_mode, context, force_fail=True, session=session
-                )  # already saves to db
-                raise
-            except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated) as e:
-                if not test_mode:
-                    self.refresh_from_db(lock_for_update=True, session=session)
-                # for case when task is marked as success/failed externally
-                # or dagrun timed out and task is marked as skipped
-                # current behavior doesn't hit the callbacks
-                if self.state in State.finished:
-                    self.clear_next_method_args()
-                    TaskInstance.save_to_db(ti=self, session=session)
-                    return None
-                self.handle_failure(e, test_mode, context, session=session)
-                raise
-            except SystemExit as e:
-                # We have already handled SystemExit with success codes (0 and None) in the `_execute_task`.
-                # Therefore, here we must handle only error codes.
-                msg = f"Task failed due to SystemExit({e.code})"
-                self.handle_failure(msg, test_mode, context, session=session)
-                raise AirflowException(msg)
-            except BaseException as e:
-                self.handle_failure(e, test_mode, context, session=session)
-                raise
-            finally:
-                # Print a marker post execution for internals of post task processing
-                log.info("::group::Post task execution logs")
-
-            Stats.incr(
-                f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}",
-                tags=self.stats_tags,
-            )
-            # Same metric with tagging
-            Stats.incr("ti.finish", tags={**self.stats_tags, "state": str(self.state)})
+        from airflow.sdk.api.datamodels._generated import TaskInstance as TaskInstanceSDK
+        from airflow.sdk.execution_time.supervisor import run_task_in_process

-            # Recording SKIPPED or SUCCESS
-            self.clear_next_method_args()
-            self.end_date = timezone.utcnow()
-            _log_state(task_instance=self)
-            self.set_duration()
+        self.set_state(TaskInstanceState.QUEUED)

-            # run on_success_callback before db committing
-            # otherwise, the LocalTaskJob sees the state is changed to `success`,
-            # but the task_runner is still running, LocalTaskJob then treats the state is set externally!
-            if self.state == TaskInstanceState.SUCCESS:
-                _run_finished_callback(callbacks=self.task.on_success_callback, context=context)
+        if mark_success:
+            self.set_state(TaskInstanceState.SUCCESS)
+            log.info("[DAG TEST] Marking success for %s ", self.task_id)
+            return

-            if not test_mode:
-                _add_log(event=self.state, task_instance=self, session=session)
-                if self.state == TaskInstanceState.SUCCESS:
-                    from airflow.sdk.execution_time.task_runner import (
-                        _build_asset_profiles,
-                        _serialize_outlet_events,
-                    )
+        taskrun_result = run_task_in_process(
+            ti=TaskInstanceSDK(
+                id=self.id,
+                task_id=self.task_id,
+                dag_id=self.task.dag_id,
+                run_id=self.run_id,
+                try_number=self.try_number,
+                map_index=self.map_index,
+            ),
+            task=self.task,
+        )

-                    TaskInstance.register_asset_changes_in_db(
-                        self,
-                        list(_build_asset_profiles(self.task.outlets)),
-                        list(_serialize_outlet_events(context["outlet_events"])),
-                        session=session,
-                    )
+        if taskrun_result.state != TaskInstanceState.QUEUED:
+            self.set_state(taskrun_result.state)

-                    TaskInstance.save_to_db(ti=self, session=session)
-            if self.state == TaskInstanceState.SUCCESS:
-                try:
-                    get_listener_manager().hook.on_task_instance_success(
-                        previous_state=TaskInstanceState.RUNNING, task_instance=self
-                    )
-                except Exception:
-                    log.exception("error calling listener")
-            return None
+        if taskrun_result.error:
+            raise taskrun_result.error

     @staticmethod
     @provide_session
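
Note on the new control flow: after this change, `_run_raw_task` hands execution to the Task SDK supervisor and only reads two attributes from whatever `run_task_in_process` returns, `state` and `error`. A minimal sketch of that implied contract follows; `TaskRunResult` and `apply_taskrun_result` are hypothetical names chosen for illustration, not the supervisor's real API.

from dataclasses import dataclass
from typing import Optional

from airflow.utils.state import TaskInstanceState


@dataclass
class TaskRunResult:  # hypothetical name; the diff never names the return type
    state: TaskInstanceState
    error: Optional[BaseException] = None


def apply_taskrun_result(ti, taskrun_result: TaskRunResult) -> None:
    # Mirrors the tail of the new _run_raw_task: adopt any state change
    # reported by the in-process run, then surface the task's error, if any.
    if taskrun_result.state != TaskInstanceState.QUEUED:
        ti.set_state(taskrun_result.state)
    if taskrun_result.error:
        raise taskrun_result.error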

airflow-core/tests/unit/api_fastapi/execution_api/conftest.py

Lines changed: 47 additions & 4 deletions
@@ -20,10 +20,10 @@

 import pytest
 from fastapi.testclient import TestClient
+from svcs import Registry

 from airflow.api_fastapi.app import cached_app
 from airflow.api_fastapi.auth.tokens import JWTValidator
-from airflow.api_fastapi.execution_api.app import lifespan


 @pytest.fixture

@@ -32,8 +32,51 @@ def client(request: pytest.FixtureRequest):

     with TestClient(app, headers={"Authorization": "Bearer fake"}) as client:
         auth = AsyncMock(spec=JWTValidator)
-        auth.avalidated_claims.return_value = {"sub": "edb09971-4e0e-4221-ad3f-800852d38085"}

-        # Inject our fake JWTValidator object. Can be over-ridden by tests if they want
-        lifespan.registry.register_value(JWTValidator, auth)
+        # Create a side_effect function that dynamically extracts the task instance ID from validators
+        def smart_validated_claims(cred, validators=None):
+            # Extract task instance ID from validators if present
+            # This handles the JWTBearerTIPathDep case where the validator contains the task ID from the path
+            if (
+                validators
+                and "sub" in validators
+                and isinstance(validators["sub"], dict)
+                and "value" in validators["sub"]
+            ):
+                return {
+                    "sub": validators["sub"]["value"],
+                    "exp": 9999999999,  # Far future expiration
+                    "iat": 1000000000,  # Past issuance time
+                    "aud": "test-audience",
+                }
+
+            # For other cases (like JWTBearerDep) where no specific validators are provided
+            # Return a default UUID with all required claims
+            return {
+                "sub": "00000000-0000-0000-0000-000000000000",
+                "exp": 9999999999,  # Far future expiration
+                "iat": 1000000000,  # Past issuance time
+                "aud": "test-audience",
+            }
+
+        # Set the side_effect for avalidated_claims
+        auth.avalidated_claims.side_effect = smart_validated_claims
+
+        # Get the execution API app from the mounted app
+        execution_app = next(route.app for route in app.routes if route.path == "/execution")
+
+        # Create a new registry
+        registry = Registry()
+        registry.register_value(JWTValidator, auth)
+
+        # Set up the lifespan context
+        async def setup_lifespan():
+            execution_app.state.svcs_registry = registry
+            execution_app.state.lifespan_called = True
+
+        # Run the lifespan setup
+        import asyncio
+
+        asyncio.run(setup_lifespan())
+
         yield client
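
As a standalone illustration of the claim-resolution logic this fixture installs, the snippet below exercises its two branches. The path-bound `validators` shape is assumed from the fixture's own dict checks, not from what JWTBearerTIPathDep actually passes at runtime.

# Demonstration of the fixture's side_effect behavior (pure Python, runnable as-is).
def smart_validated_claims(cred, validators=None):
    base = {"exp": 9999999999, "iat": 1000000000, "aud": "test-audience"}
    if validators and isinstance(validators.get("sub"), dict) and "value" in validators["sub"]:
        # Path-bound case: echo back the task instance ID taken from the URL path.
        return {"sub": validators["sub"]["value"], **base}
    # Generic bearer case: fall back to a fixed default subject.
    return {"sub": "00000000-0000-0000-0000-000000000000", **base}


# Validator carries a specific task instance ID (assumed shape).
claims = smart_validated_claims("fake-token", validators={"sub": {"value": "edb09971-4e0e-4221-ad3f-800852d38085"}})
assert claims["sub"] == "edb09971-4e0e-4221-ad3f-800852d38085"

# No validators supplied: default UUID with all required claims.
assert smart_validated_claims("fake-token")["sub"] == "00000000-0000-0000-0000-000000000000"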
