Skip to content

Add back invalid inlet and outlet check before running tasks #50773

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,9 @@ class TaskStatesResponse(BaseModel):
"""Response for task states with run_id, task and state."""

task_states: dict[str, Any]


class InactiveAssetsResponse(BaseModel):
"""Response for inactive assets."""

inactive_assets: Annotated[list[AssetProfile], Field(default_factory=list)]
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@

from __future__ import annotations

import contextlib
import itertools
import json
from collections import defaultdict
from collections.abc import Iterator
from typing import TYPE_CHECKING, Annotated, Any
from uuid import UUID

import attrs
import structlog
from cadwyn import VersionedAPIRouter
from fastapi import Body, HTTPException, Query, status
Expand All @@ -37,6 +40,7 @@
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
InactiveAssetsResponse,
PrevSuccessfulDagRunResponse,
TaskStatesResponse,
TIDeferredStatePayload,
Expand All @@ -51,13 +55,16 @@
TITerminalStatePayload,
)
from airflow.api_fastapi.execution_api.deps import JWTBearerTIPathDep
from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException, TaskNotFound
from airflow.models.asset import AssetActive
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun as DR
from airflow.models.taskinstance import TaskInstance as TI, _stop_remaining_tasks
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.trigger import Trigger
from airflow.models.xcom import XComModel
from airflow.sdk.definitions._internal.expandinput import NotFullyPopulated
from airflow.sdk.definitions.asset import Asset, AssetUniqueKey
from airflow.sdk.definitions.taskgroup import MappedTaskGroup
from airflow.utils import timezone
from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
Expand Down Expand Up @@ -400,12 +407,16 @@ def ti_update_state(
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
updated_state = ti_patch_payload.state
task_instance = session.get(TI, ti_id_str)
TI.register_asset_changes_in_db(
task_instance,
ti_patch_payload.task_outlets, # type: ignore
ti_patch_payload.outlet_events,
session,
)
try:
TI.register_asset_changes_in_db(
task_instance,
ti_patch_payload.task_outlets, # type: ignore
ti_patch_payload.outlet_events,
session,
)
except AirflowInactiveAssetInInletOrOutletException as err:
log.error("Asset registration failed due to conflicting asset: %s", err)

query = query.values(state=updated_state)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
Expand Down Expand Up @@ -840,5 +851,67 @@ def _get_group_tasks(dag_id: str, task_group_id: str, session: SessionDep, logic
return group_tasks


@ti_id_router.get(
"/{task_instance_id}/validate-inlets-and-outlets",
status_code=status.HTTP_200_OK,
responses={
status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
},
)
def validate_inlets_and_outlets(
task_instance_id: UUID,
session: SessionDep,
dag_bag: DagBagDep,
) -> InactiveAssetsResponse:
"""Validate whether there're inactive assets in inlets and outlets of a given task instance."""
ti_id_str = str(task_instance_id)
bind_contextvars(ti_id=ti_id_str)

ti = session.scalar(select(TI).where(TI.id == ti_id_str))
if not ti or not ti.logical_date:
log.error("Task Instance not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"reason": "not_found",
"message": "Task Instance not found",
},
)

if not ti.task:
dag = dag_bag.get_dag(ti.dag_id)
if dag:
with contextlib.suppress(TaskNotFound):
ti.task = dag.get_task(ti.task_id)

inlets = [asset.asprofile() for asset in ti.task.inlets if isinstance(asset, Asset)]
outlets = [asset.asprofile() for asset in ti.task.outlets if isinstance(asset, Asset)]
if not (inlets or outlets):
return InactiveAssetsResponse(inactive_assets=[])

all_asset_unique_keys: set[AssetUniqueKey] = {
AssetUniqueKey.from_asset(inlet_or_outlet) # type: ignore
for inlet_or_outlet in itertools.chain(inlets, outlets)
}
active_asset_unique_keys = {
AssetUniqueKey(name, uri)
for name, uri in session.execute(
select(AssetActive.name, AssetActive.uri).where(
tuple_(AssetActive.name, AssetActive.uri).in_(
attrs.astuple(key) for key in all_asset_unique_keys
)
)
)
}
different = all_asset_unique_keys - active_asset_unique_keys

return InactiveAssetsResponse(
inactive_assets=[
asset_unique_key.to_asset().asprofile() # type: ignore
for asset_unique_key in different
]
)


# This line should be at the end of the file to ensure all routes are registered
router.include_router(ti_id_router)
1 change: 1 addition & 0 deletions airflow-core/src/airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -1894,6 +1894,7 @@ def schedule_tis(
and not ti.task.on_execute_callback
and not ti.task.on_success_callback
and not ti.task.outlets
and not ti.task.inlets
):
empty_ti_ids.append(ti.id)
# check "start_trigger_args" to see whether the operator supports start execution from triggerer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk import TaskGroup, task, task_group
from airflow.sdk import Asset, TaskGroup, task, task_group
from airflow.utils import timezone
from airflow.utils.state import State, TaskInstanceState, TerminalTIState

Expand Down Expand Up @@ -2139,3 +2139,54 @@ def add_one(x):
response = client.get("/execution/task-instances/states", params={"dag_id": dr.dag_id, **params})
assert response.status_code == 200
assert response.json() == {"task_states": {dr.run_id: expected}}


class TestInvactiveInletsAndOutlets:
def test_ti_inactive_inlets_and_outlets(self, client, dag_maker):
"""Test the inactive assets in inlets and outlets can be found."""
with dag_maker("test_inlets_and_outlets"):
EmptyOperator(
task_id="task1",
inlets=[Asset(name="inlet-name"), Asset(name="inlet-name", uri="but-different-uri")],
outlets=[
Asset(name="outlet-name", uri="uri"),
Asset(name="outlet-name", uri="second-different-uri"),
],
)

dr = dag_maker.create_dagrun()

task1_ti = dr.get_task_instance("task1")
response = client.get(f"/execution/task-instances/{task1_ti.id}/validate-inlets-and-outlets")
assert response.status_code == 200
inactive_assets = response.json()["inactive_assets"]
expected_inactive_assets = (
{
"name": "inlet-name",
"type": "Asset",
"uri": "but-different-uri",
},
{
"name": "outlet-name",
"type": "Asset",
"uri": "second-different-uri",
},
)
for asset in expected_inactive_assets:
assert asset in inactive_assets

def test_ti_inactive_inlets_and_outlets_without_inactive_assets(self, client, dag_maker):
"""Test the task without inactive assets in its inlets or outlets returns empty list."""
with dag_maker("test_inlets_and_outlets_inactive"):
EmptyOperator(
task_id="inactive_task1",
inlets=[Asset(name="inlet-name")],
outlets=[Asset(name="outlet-name", uri="uri")],
)

dr = dag_maker.create_dagrun()

task1_ti = dr.get_task_instance("inactive_task1")
response = client.get(f"/execution/task-instances/{task1_ti.id}/validate-inlets-and-outlets")
assert response.status_code == 200
assert response.json() == {"inactive_assets": []}
6 changes: 6 additions & 0 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
ConnectionResponse,
DagRunStateResponse,
DagRunType,
InactiveAssetsResponse,
PrevSuccessfulDagRunResponse,
TaskInstanceState,
TaskStatesResponse,
Expand Down Expand Up @@ -273,6 +274,11 @@ def get_task_states(
resp = self.client.get("task-instances/states", params=params)
return TaskStatesResponse.model_validate_json(resp.read())

def validate_inlets_and_outlets(self, id: uuid.UUID) -> InactiveAssetsResponse:
"""Validate whether there're inactive assets in inlets and outlets of a given task instance."""
resp = self.client.get(f"task-instances/{id}/validate-inlets-and-outlets")
return InactiveAssetsResponse.model_validate_json(resp.read())


class ConnectionOperations:
__slots__ = ("client",)
Expand Down
8 changes: 8 additions & 0 deletions task-sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,14 @@ class DagRunType(str, Enum):
ASSET_TRIGGERED = "asset_triggered"


class InactiveAssetsResponse(BaseModel):
"""
Response for inactive assets.
"""

inactive_assets: Annotated[list[AssetProfile] | None, Field(title="Inactive Assets")] = None


class IntermediateTIState(str, Enum):
"""
States that a Task Instance can be in that indicate it is not yet in a terminal or running state.
Expand Down
12 changes: 12 additions & 0 deletions task-sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ def from_str(key: str) -> AssetUniqueKey:
def to_str(self) -> str:
return json.dumps(attrs.asdict(self))

@staticmethod
def from_profile(profile: AssetProfile) -> AssetUniqueKey:
if profile.name and profile.uri:
return AssetUniqueKey(name=profile.name, uri=profile.uri)

if name := profile.name:
return AssetUniqueKey(name=name, uri=name)
if uri := profile.uri:
return AssetUniqueKey(name=uri, uri=uri)

raise ValueError("name and uri cannot both be empty")


@attrs.define(frozen=True)
class AssetAliasUniqueKey:
Expand Down
26 changes: 26 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
BundleInfo,
ConnectionResponse,
DagRunStateResponse,
InactiveAssetsResponse,
PrevSuccessfulDagRunResponse,
TaskInstance,
TaskInstanceState,
Expand Down Expand Up @@ -208,6 +209,24 @@ def source_task_instance(self) -> AssetEventSourceTaskInstance | None:
)


class InactiveAssetsResult(InactiveAssetsResponse):
"""Response of InactiveAssets requests."""

type: Literal["InactiveAssetsResult"] = "InactiveAssetsResult"

@classmethod
def from_inactive_assets_response(
cls, inactive_assets_response: InactiveAssetsResponse
) -> InactiveAssetsResult:
"""
Get InactiveAssetsResponse from InactiveAssetsResult.

InactiveAssetsResponse is autogenerated from the API schema, so we need to convert it to InactiveAssetsResult
for communication between the Supervisor and the task process.
"""
return cls(**inactive_assets_response.model_dump(exclude_defaults=True), type="InactiveAssetsResult")


class XComResult(XComResponse):
"""Response to ReadXCom request."""

Expand Down Expand Up @@ -376,6 +395,7 @@ class OKResponse(BaseModel):
XComResult,
XComSequenceIndexResult,
XComSequenceSliceResult,
InactiveAssetsResult,
OKResponse,
],
Field(discriminator="type"),
Expand Down Expand Up @@ -590,6 +610,11 @@ class GetAssetEventByAssetAlias(BaseModel):
type: Literal["GetAssetEventByAssetAlias"] = "GetAssetEventByAssetAlias"


class ValidateInletsAndOutlets(BaseModel):
ti_id: UUID
type: Literal["ValidateInletsAndOutlets"] = "ValidateInletsAndOutlets"


class GetPrevSuccessfulDagRun(BaseModel):
ti_id: UUID
type: Literal["GetPrevSuccessfulDagRun"] = "GetPrevSuccessfulDagRun"
Expand Down Expand Up @@ -657,6 +682,7 @@ class GetDRCount(BaseModel):
SetXCom,
SkipDownstreamTasks,
SucceedTask,
ValidateInletsAndOutlets,
TaskState,
TriggerDagRun,
DeleteVariable,
Expand Down
6 changes: 6 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
GetXComCount,
GetXComSequenceItem,
GetXComSequenceSlice,
InactiveAssetsResult,
PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
Expand All @@ -101,6 +102,7 @@
TaskStatesResult,
ToSupervisor,
TriggerDagRun,
ValidateInletsAndOutlets,
VariableResult,
XComCountResponse,
XComResult,
Expand Down Expand Up @@ -1215,6 +1217,10 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
)
elif isinstance(msg, DeleteVariable):
resp = self.client.variables.delete(msg.key)
elif isinstance(msg, ValidateInletsAndOutlets):
inactive_assets_resp = self.client.task_instances.validate_inlets_and_outlets(msg.ti_id)
resp = InactiveAssetsResult.from_inactive_assets_response(inactive_assets_resp)
dump_opts = {"exclude_unset": True}
else:
log.error("Unhandled request", msg=msg)
return
Expand Down
Loading
Loading