Skip to content

Commit 128b311

Browse files
committed
feat(task-sdk): check invalid inlets or outlets before running tasks
1 parent 850ba8e commit 128b311

File tree

8 files changed

+138
-0
lines changed

8 files changed

+138
-0
lines changed

airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,3 +345,9 @@ class TaskStatesResponse(BaseModel):
345345
"""Response for task states with run_id, task and state."""
346346

347347
task_states: dict[str, Any]
348+
349+
350+
class InvalidAssetsResponse(BaseModel):
351+
"""Response for invalid assets."""
352+
353+
# Defaults to an empty list, so a response with nothing invalid needs no explicit value.
invalid_assets: Annotated[list[AssetProfile], Field(default_factory=list)]

airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717

1818
from __future__ import annotations
1919

20+
import contextlib
21+
import itertools
2022
import json
2123
from collections import defaultdict
2224
from collections.abc import Iterator
2325
from typing import TYPE_CHECKING, Annotated, Any
2426
from uuid import UUID
2527

28+
import attrs
2629
import structlog
2730
from cadwyn import VersionedAPIRouter
2831
from fastapi import Body, HTTPException, Query, status
@@ -37,6 +40,7 @@
3740
from airflow.api_fastapi.common.db.common import SessionDep
3841
from airflow.api_fastapi.common.types import UtcDateTime
3942
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
43+
InvalidAssetsResponse,
4044
PrevSuccessfulDagRunResponse,
4145
TaskStatesResponse,
4246
TIDeferredStatePayload,
@@ -51,13 +55,16 @@
5155
TITerminalStatePayload,
5256
)
5357
from airflow.api_fastapi.execution_api.deps import JWTBearerTIPathDep
58+
from airflow.exceptions import TaskNotFound
59+
from airflow.models.asset import AssetActive
5460
from airflow.models.dagbag import DagBag
5561
from airflow.models.dagrun import DagRun as DR
5662
from airflow.models.taskinstance import TaskInstance as TI, _stop_remaining_tasks
5763
from airflow.models.taskreschedule import TaskReschedule
5864
from airflow.models.trigger import Trigger
5965
from airflow.models.xcom import XComModel
6066
from airflow.sdk.definitions._internal.expandinput import NotFullyPopulated
67+
from airflow.sdk.definitions.asset import Asset, AssetUniqueKey
6168
from airflow.sdk.definitions.taskgroup import MappedTaskGroup
6269
from airflow.utils import timezone
6370
from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
@@ -840,5 +847,67 @@ def _get_group_tasks(dag_id: str, task_group_id: str, session: SessionDep, logic
840847
return group_tasks
841848

842849

850+
@ti_id_router.get(
    "/{task_instance_id}/validate-inlets-and-outlets",
    status_code=status.HTTP_200_OK,
    responses={
        status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
    },
)
def validate_inlets_and_outlets(
    task_instance_id: UUID,
    session: SessionDep,
    dag_bag: DagBagDep,
) -> InvalidAssetsResponse:
    """
    Validate whether there are inactive assets in the inlets and outlets of a given task instance.

    :param task_instance_id: ID of the task instance whose task is inspected.
    :param session: Database session used to load the task instance and the active assets.
    :param dag_bag: DAG bag used to resolve the task when it is not already attached to the
        task instance.
    :return: The ``Asset`` inlets/outlets that have no matching ``AssetActive`` row. The list is
        empty when every asset is active or the task declares no ``Asset`` inlets or outlets.
    :raises HTTPException: 404 if the task instance does not exist, has no logical date, or its
        task cannot be resolved from the DAG bag.
    """
    ti_id_str = str(task_instance_id)
    bind_contextvars(ti_id=ti_id_str)

    ti = session.scalar(select(TI).where(TI.id == ti_id_str))
    # NOTE(review): logical_date is not used anywhere below; confirm that task instances
    # without one are really meant to be reported as "not found".
    if not ti or not ti.logical_date:
        log.error("Task Instance not found")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={
                "reason": "not_found",
                "message": "Task Instance not found",
            },
        )

    if not ti.task:
        dag = dag_bag.get_dag(ti.dag_id)
        if dag:
            with contextlib.suppress(TaskNotFound):
                ti.task = dag.get_task(ti.task_id)

    # The lookup above is best-effort: the DAG may be missing from the bag, or the task may no
    # longer exist in it. Fail with an explicit 404 instead of an AttributeError on ``None``.
    if not ti.task:
        log.error("Task not found for Task Instance")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={
                "reason": "not_found",
                "message": "Task not found for Task Instance",
            },
        )

    # Only concrete Asset objects are validated; other inlet/outlet entries are ignored.
    inlets = [asset.asprofile() for asset in ti.task.inlets if isinstance(asset, Asset)]
    outlets = [asset.asprofile() for asset in ti.task.outlets if isinstance(asset, Asset)]
    if not (inlets or outlets):
        # Nothing to check, so skip the database round trip.
        return InvalidAssetsResponse(invalid_assets=[])

    all_asset_unique_keys: set[AssetUniqueKey] = {
        AssetUniqueKey.from_asset(inlet_or_outlet)  # type: ignore
        for inlet_or_outlet in itertools.chain(inlets, outlets)
    }
    # in_() is documented to take a list of values, so hand it a concrete list of
    # (name, uri) tuples rather than a one-shot generator.
    key_tuples = [attrs.astuple(key) for key in all_asset_unique_keys]
    active_asset_unique_keys = {
        AssetUniqueKey(name, uri)
        for name, uri in session.execute(
            select(AssetActive.name, AssetActive.uri).where(
                tuple_(AssetActive.name, AssetActive.uri).in_(key_tuples)
            )
        )
    }
    # Whatever was declared on the task but has no AssetActive row is reported as invalid.
    different = all_asset_unique_keys - active_asset_unique_keys

    return InvalidAssetsResponse(
        invalid_assets=[
            asset_unique_key.to_asset().asprofile()  # type: ignore
            for asset_unique_key in different
        ]
    )
910+
911+
843912
# This line should be at the end of the file to ensure all routes are registered
844913
router.include_router(ti_id_router)

airflow-core/src/airflow/models/dagrun.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1894,6 +1894,7 @@ def schedule_tis(
18941894
and not ti.task.on_execute_callback
18951895
and not ti.task.on_success_callback
18961896
and not ti.task.outlets
1897+
and not ti.task.inlets
18971898
):
18981899
empty_ti_ids.append(ti.id)
18991900
# check "start_trigger_args" to see whether the operator supports start execution from triggerer

task-sdk/src/airflow/sdk/api/client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
ConnectionResponse,
4141
DagRunStateResponse,
4242
DagRunType,
43+
InvalidAssetsResponse,
4344
PrevSuccessfulDagRunResponse,
4445
TaskInstanceState,
4546
TaskStatesResponse,
@@ -272,6 +273,11 @@ def get_task_states(
272273
resp = self.client.get("task-instances/states", params=params)
273274
return TaskStatesResponse.model_validate_json(resp.read())
274275

276+
def validate_inlets_and_outlets(self, id: uuid.UUID) -> InvalidAssetsResponse:
    """
    Validate whether there are inactive assets in the inlets and outlets of a given task instance.

    :param id: ID of the task instance to validate.
    :return: The parsed response body of the validation endpoint.
    """
    endpoint = f"task-instances/{id}/validate-inlets-and-outlets"
    raw_body = self.client.get(endpoint).read()
    return InvalidAssetsResponse.model_validate_json(raw_body)
280+
275281

276282
class ConnectionOperations:
277283
__slots__ = ("client",)

task-sdk/src/airflow/sdk/api/datamodels/_generated.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,14 @@ class IntermediateTIState(str, Enum):
168168
DEFERRED = "deferred"
169169

170170

171+
class InvalidAssetsResponse(BaseModel):
172+
"""
173+
Response for invalid assets.
174+
"""
175+
176+
# Optional here: the field stays None when the payload does not include it.
invalid_assets: Annotated[list[AssetProfile] | None, Field(title="Invalid Assets")] = None
177+
178+
171179
class PrevSuccessfulDagRunResponse(BaseModel):
172180
"""
173181
Schema for response with previous successful DagRun information for Task Template Context.

task-sdk/src/airflow/sdk/execution_time/comms.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
BundleInfo,
6262
ConnectionResponse,
6363
DagRunStateResponse,
64+
InvalidAssetsResponse,
6465
PrevSuccessfulDagRunResponse,
6566
TaskInstance,
6667
TaskInstanceState,
@@ -206,6 +207,24 @@ def source_task_instance(self) -> AssetEventSourceTaskInstance | None:
206207
)
207208

208209

210+
class InvalidAssetsResult(InvalidAssetsResponse):
    """Response of InvalidAssets requests."""

    type: Literal["InvalidAssetsResult"] = "InvalidAssetsResult"

    @classmethod
    def from_invalid_assets_response(
        cls, invalid_assets_response: InvalidAssetsResponse
    ) -> InvalidAssetsResult:
        """
        Build an InvalidAssetsResult from an InvalidAssetsResponse.

        InvalidAssetsResponse is autogenerated from the API schema, so we need to convert it to
        InvalidAssetsResult for communication between the Supervisor and the task process.

        :param invalid_assets_response: Response object returned by the API client.
        :return: A new InvalidAssetsResult carrying the non-default fields of the response.
        """
        # Fields still at their default are left out of the dump and fall back to this class's
        # own defaults. ``type`` is passed explicitly, presumably so that it counts as a set
        # field when the result is later dumped with ``exclude_unset``.
        return cls(**invalid_assets_response.model_dump(exclude_defaults=True), type="InvalidAssetsResult")
226+
227+
209228
class XComResult(XComResponse):
210229
"""Response to ReadXCom request."""
211230

@@ -354,6 +373,7 @@ class OKResponse(BaseModel):
354373
VariableResult,
355374
XComResult,
356375
XComCountResponse,
376+
InvalidAssetsResult,
357377
OKResponse,
358378
],
359379
Field(discriminator="type"),
@@ -557,6 +577,11 @@ class GetAssetEventByAssetAlias(BaseModel):
557577
type: Literal["GetAssetEventByAssetAlias"] = "GetAssetEventByAssetAlias"
558578

559579

580+
class ValidateInletsAndOutlets(BaseModel):
# Request message asking for the inlets and outlets of the task instance `ti_id` to be validated.
581+
ti_id: UUID
582+
type: Literal["ValidateInletsAndOutlets"] = "ValidateInletsAndOutlets"
583+
584+
560585
class GetPrevSuccessfulDagRun(BaseModel):
561586
ti_id: UUID
562587
type: Literal["GetPrevSuccessfulDagRun"] = "GetPrevSuccessfulDagRun"
@@ -623,6 +648,7 @@ class GetDRCount(BaseModel):
623648
SetXCom,
624649
SkipDownstreamTasks,
625650
SucceedTask,
651+
ValidateInletsAndOutlets,
626652
TaskState,
627653
TriggerDagRun,
628654
DeleteVariable,

task-sdk/src/airflow/sdk/execution_time/supervisor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
GetXCom,
8888
GetXComCount,
8989
GetXComSequenceItem,
90+
InvalidAssetsResult,
9091
PrevSuccessfulDagRunResult,
9192
PutVariable,
9293
RescheduleTask,
@@ -100,6 +101,7 @@
100101
TaskStatesResult,
101102
ToSupervisor,
102103
TriggerDagRun,
104+
ValidateInletsAndOutlets,
103105
VariableResult,
104106
XComCountResponse,
105107
XComResult,
@@ -1154,6 +1156,10 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
11541156
)
11551157
elif isinstance(msg, DeleteVariable):
11561158
resp = self.client.variables.delete(msg.key)
1159+
elif isinstance(msg, ValidateInletsAndOutlets):
1160+
invalid_assets_resp = self.client.task_instances.validate_inlets_and_outlets(msg.ti_id)
1161+
resp = InvalidAssetsResult.from_invalid_assets_response(invalid_assets_resp)
1162+
dump_opts = {"exclude_unset": True}
11571163
else:
11581164
log.error("Unhandled request", msg=msg)
11591165
return

task-sdk/src/airflow/sdk/execution_time/task_runner.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040

4141
from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock
4242
from airflow.dag_processing.bundles.manager import DagBundlesManager
43+
from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException
4344
from airflow.listeners.listener import get_listener_manager
4445
from airflow.sdk.api.datamodels._generated import (
4546
AssetProfile,
@@ -66,6 +67,7 @@
6667
GetTaskRescheduleStartDate,
6768
GetTaskStates,
6869
GetTICount,
70+
InvalidAssetsResult,
6971
RescheduleTask,
7072
RetryTask,
7173
SetRenderedFields,
@@ -79,6 +81,7 @@
7981
ToSupervisor,
8082
ToTask,
8183
TriggerDagRun,
84+
ValidateInletsAndOutlets,
8285
)
8386
from airflow.sdk.execution_time.context import (
8487
ConnectionAccessor,
@@ -778,6 +781,8 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv
778781
# so that we do not call the API unnecessarily
779782
SUPERVISOR_COMMS.send_request(log=log, msg=SetRenderedFields(rendered_fields=rendered_fields))
780783

784+
_validate_task_inlets_and_outlets(ti=ti, log=log)
785+
781786
try:
782787
# TODO: Call pre execute etc.
783788
get_listener_manager().hook.on_task_instance_running(
@@ -790,6 +795,17 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv
790795
return None
791796

792797

798+
def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) -> None:
    """
    Ask the supervisor to validate the inlets and outlets of the given task instance.

    :param ti: Runtime task instance whose id is sent with the validation request.
    :param log: Logger passed along to the supervisor comms channel.
    :raises AirflowInactiveAssetInInletOrOutletException: If the reply lists any invalid assets.
    """
    request = ValidateInletsAndOutlets(ti_id=ti.id)
    SUPERVISOR_COMMS.send_request(msg=request, log=log)
    reply = SUPERVISOR_COMMS.get_message()
    if TYPE_CHECKING:
        assert isinstance(reply, InvalidAssetsResult)

    flagged_assets = reply.invalid_assets
    if not flagged_assets:
        # Both None and an empty list mean there is nothing to report.
        return

    raise AirflowInactiveAssetInInletOrOutletException(
        inactive_asset_keys=[AssetUniqueKey.from_asset(asset) for asset in flagged_assets]
    )
807+
808+
793809
def _defer_task(
794810
defer: TaskDeferred, ti: RuntimeTaskInstance, log: Logger
795811
) -> tuple[ToSupervisor, TaskInstanceState]:

0 commit comments

Comments
 (0)