Skip to content

Port ti.run to Task SDK execution path #50141

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/cli/commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def task_test(args, dag: DAG | None = None) -> None:
)
try:
with redirect_stdout(RedactedIO()):
_run_task(ti=ti)
_run_task(ti=ti, run_triggerer=True)
if ti.state == State.FAILED and args.post_mortem:
debugger = _guess_debugger()
debugger.set_trace()
Expand Down
719 changes: 19 additions & 700 deletions airflow-core/src/airflow/models/taskinstance.py

Large diffs are not rendered by default.

31 changes: 29 additions & 2 deletions airflow-core/tests/unit/api_fastapi/execution_api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,35 @@ def client(request: pytest.FixtureRequest):

with TestClient(app, headers={"Authorization": "Bearer fake"}) as client:
auth = AsyncMock(spec=JWTValidator)
auth.avalidated_claims.return_value = {"sub": "edb09971-4e0e-4221-ad3f-800852d38085"}

# Inject our fake JWTValidator object. Can be over-ridden by tests if they want
# Create a side_effect function that dynamically extracts the task instance ID from validators
def smart_validated_claims(cred, validators=None):
# Extract task instance ID from validators if present
# This handles the JWTBearerTIPathDep case where the validator contains the task ID from the path
if (
validators
and "sub" in validators
and isinstance(validators["sub"], dict)
and "value" in validators["sub"]
):
return {
"sub": validators["sub"]["value"],
"exp": 9999999999, # Far future expiration
"iat": 1000000000, # Past issuance time
"aud": "test-audience",
}

# For other cases (like JWTBearerDep) where no specific validators are provided
# Return a default UUID with all required claims
return {
"sub": "00000000-0000-0000-0000-000000000000",
"exp": 9999999999, # Far future expiration
"iat": 1000000000, # Past issuance time
"aud": "test-audience",
}

# Set the side_effect for avalidated_claims
auth.avalidated_claims.side_effect = smart_validated_claims
lifespan.registry.register_value(JWTValidator, auth)

yield client
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from __future__ import annotations

import operator
from datetime import datetime
from unittest import mock
from uuid import uuid4
Expand Down Expand Up @@ -1084,22 +1083,18 @@ def test_ti_skip_downstream(self, client, session, create_task_instance, dag_mak
t1 = EmptyOperator(task_id="t1")
t0 >> t1
dr = dag_maker.create_dagrun(run_id="run")
decision = dr.task_instance_scheduling_decisions(session=session)
for ti in sorted(decision.schedulable_tis, key=operator.attrgetter("task_id")):
# TODO: TaskSDK #45549
ti.task = dag_maker.dag.get_task(ti.task_id)
ti.run(session=session)

t0 = dr.get_task_instance("t0")
ti0 = dr.get_task_instance("t0")
ti0.set_state(State.SUCCESS)

response = client.patch(
f"/execution/task-instances/{t0.id}/skip-downstream",
f"/execution/task-instances/{ti0.id}/skip-downstream",
json=_json,
)
t1 = dr.get_task_instance("t1")
ti1 = dr.get_task_instance("t1")

assert response.status_code == 204
assert decision.schedulable_tis[0].state == State.SUCCESS
assert t1.state == State.SKIPPED
assert ti1.state == State.SKIPPED


class TestTIHealthEndpoint:
Expand Down
20 changes: 7 additions & 13 deletions airflow-core/tests/unit/listeners/test_listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,13 @@ def test_listener_gets_only_subscribed_calls(create_task_instance, session=None)


@provide_session
def test_listener_suppresses_exceptions(create_task_instance, session, caplog):
def test_listener_suppresses_exceptions(create_task_instance, session, cap_structlog):
lm = get_listener_manager()
lm.add_listener(throwing_listener)

ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED)
with caplog.at_level(logging.ERROR):
ti._run_raw_task()
assert "error calling listener" in caplog.messages
ti.run()
assert "error calling listener" in cap_structlog


@provide_session
Expand All @@ -139,7 +138,7 @@ def test_listener_captures_failed_taskinstances(create_task_instance_of_operator
BashOperator, dag_id=DAG_ID, logical_date=LOGICAL_DATE, task_id=TASK_ID, bash_command="exit 1"
)
with pytest.raises(AirflowException):
ti._run_raw_task()
ti.run()

assert full_listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.FAILED]
assert len(full_listener.state) == 2
Expand All @@ -153,7 +152,7 @@ def test_listener_captures_longrunning_taskinstances(create_task_instance_of_ope
ti = create_task_instance_of_operator(
BashOperator, dag_id=DAG_ID, logical_date=LOGICAL_DATE, task_id=TASK_ID, bash_command="sleep 5"
)
ti._run_raw_task()
ti.run()

assert full_listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS]
assert len(full_listener.state) == 2
Expand All @@ -166,13 +165,9 @@ def test_class_based_listener(create_task_instance, session=None):
lm.add_listener(listener)

ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED)
# Using ti.run() instead of ti._run_raw_task() to capture state change to RUNNING
# that only happens on `check_and_change_state_before_execution()` that is called before
# `run()` calls `_run_raw_task()`
ti.run()

assert len(listener.state) == 2
assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS]
assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS, DagRunState.SUCCESS]


def test_listener_logs_call(caplog, create_task_instance, session):
Expand All @@ -181,10 +176,9 @@ def test_listener_logs_call(caplog, create_task_instance, session):
lm.add_listener(full_listener)

ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED)
ti._run_raw_task()
ti.run()

listener_logs = [r for r in caplog.record_tuples if r[0] == "airflow.listeners.listener"]
assert len(listener_logs) == 6
assert all(r[:-1] == ("airflow.listeners.listener", logging.DEBUG) for r in listener_logs)
assert listener_logs[0][-1].startswith("Calling 'on_task_instance_running' with {'")
assert listener_logs[1][-1].startswith("Hook impls: [<HookImpl plugin")
Expand Down
Loading
Loading