Skip to content

Commit 4305cb6

Browse files
committed
WIP 01 - Add deadline to DAG model
Example usage: ``` @task def hello(): log.info('hello world') with DAG( dag_id='dag_with_deadline', deadline=DeadlineAlert( trigger=DeadlineTrigger.DAGRUN_EXECUTION_DATE, interval=timedelta(hours=1), callback=hello, ) ): hello() ```
1 parent 319cf30 commit 4305cb6

File tree

15 files changed

+1969
-1775
lines changed

15 files changed

+1969
-1775
lines changed

airflow/dag_processing/collection.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
# under the License.
1818

1919
"""
20-
Utility code that write DAGs in bulk into the database.
20+
Utility code that writes DAGs in bulk into the database.
2121
2222
This should generally only be called by internal methods such as
2323
``DagBag._sync_to_db``, ``DAG.bulk_write_to_db``.
@@ -453,6 +453,9 @@ def update_dags(
453453
"core", "max_consecutive_failed_dag_runs_per_dag"
454454
)
455455

456+
if dag.deadline is not None:
457+
dm.deadline = dag.deadline
458+
456459
if hasattr(dag, "has_task_concurrency_limits"):
457460
dm.has_task_concurrency_limits = dag.has_task_concurrency_limits
458461
else:
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
19+
"""
20+
Add deadline column to the dag table.
21+
22+
Revision ID: dfee8bd5d574
23+
Revises: 6a9e7a527a88
24+
Create Date: 2024-12-18 19:10:26.962464
25+
"""
26+
27+
from __future__ import annotations
28+
29+
import sqlalchemy as sa
30+
import sqlalchemy_jsonfield
31+
from alembic import op
32+
33+
from airflow.settings import json
34+
35+
revision = "dfee8bd5d574"
36+
down_revision = "6a9e7a527a88"
37+
branch_labels = None
38+
depends_on = None
39+
airflow_version = "3.1.0"
40+
41+
42+
def upgrade():
43+
op.add_column(
44+
"dag",
45+
sa.Column("deadline", sqlalchemy_jsonfield.JSONField(json=json), nullable=True),
46+
)
47+
48+
49+
def downgrade():
50+
op.drop_column("dag", "deadline")

airflow/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def import_all_models():
6060
import airflow.models.dag_version
6161
import airflow.models.dagbundle
6262
import airflow.models.dagwarning
63-
import airflow.models.deadline
6463
import airflow.models.errors
6564
import airflow.models.serialized_dag
6665
import airflow.models.taskinstancehistory
@@ -96,6 +95,7 @@ def __getattr__(name):
9695
"DagTag": "airflow.models.dag",
9796
"DagWarning": "airflow.models.dagwarning",
9897
"DbCallbackRequest": "airflow.models.db_callback_request",
98+
"Deadline": "airflow.models.deadline",
9999
"Log": "airflow.models.log",
100100
"MappedOperator": "airflow.models.mappedoperator",
101101
"Operator": "airflow.models.operator",

airflow/models/dag.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ class DAG(TaskSDKDag, LoggingMixin):
381381
:param dagrun_timeout: Specify the duration a DagRun should be allowed to run before it times out or
382382
fails. Task instances that are running when a DagRun is timed out will be marked as skipped.
383383
:param sla_miss_callback: DEPRECATED - The SLA feature is removed in Airflow 3.0, to be replaced with a new implementation in 3.1
384+
:param deadline: Optional Deadline Alert for the DAG.
384385
:param default_view: Specify DAG default view (grid, graph, duration,
385386
gantt, landing_times), default grid
386387
:param orientation: Specify DAG orientation in graph view (LR, TB, RL, BT), default LR
@@ -2058,7 +2059,7 @@ class DagModel(Base):
20582059

20592060
__tablename__ = "dag"
20602061
"""
2061-
These items are stored in the database for state related information
2062+
These items are stored in the database for state related information.
20622063
"""
20632064
dag_id = Column(StringID(), primary_key=True)
20642065
# A DAG can be paused from the UI / DB
@@ -2095,6 +2096,8 @@ class DagModel(Base):
20952096
timetable_description = Column(String(1000), nullable=True)
20962097
# Asset expression based on asset triggers
20972098
asset_expression = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
2099+
# DAG deadline information
2100+
deadline = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
20982101
# Tags for view filter
20992102
tags = relationship("DagTag", cascade="all, delete, delete-orphan", backref=backref("dag"))
21002103
# Dag owner links for DAGs view

airflow/models/deadline.py

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19-
from datetime import datetime
20-
from typing import TYPE_CHECKING
19+
import logging
20+
import sys
21+
from datetime import datetime, timedelta
22+
from typing import TYPE_CHECKING, Callable
2123

2224
import sqlalchemy_jsonfield
2325
import uuid6
@@ -32,6 +34,8 @@
3234
if TYPE_CHECKING:
3335
from sqlalchemy.orm import Session
3436

37+
log = logging.getLogger(__name__)
38+
3539

3640
class Deadline(Base, LoggingMixin):
3741
"""A Deadline is a 'need-by' date which triggers a callback if the provided time has passed."""
@@ -90,3 +94,86 @@ def _determine_resource() -> tuple[str, str]:
9094
def add_deadline(cls, deadline: Deadline, session: Session = NEW_SESSION):
9195
"""Add the provided deadline to the table."""
9296
session.add(deadline)
97+
98+
99+
class DeadlineTrigger:
100+
"""
101+
Store the calculation methods for the various SDeadline Alert triggers.
102+
103+
TODO: Embetter this docstring muchly.
104+
105+
usage:
106+
107+
In the DAG define a deadline as
108+
109+
deadline=DeadlineAlert(
110+
trigger=DeadlineTrigger.DAGRUN_EXECUTION_DATE,
111+
interval=timedelta(hours=1),
112+
callback=hello,
113+
)
114+
115+
to parse the deadline trigger use DeadlineTrigger.evaluate(dag.deadline.trigger)
116+
"""
117+
118+
DAGRUN_EXECUTION_DATE = "dagrun_execution_date"
119+
120+
@staticmethod
121+
def evaluate(trigger: str):
122+
return eval(f"DeadlineTrigger().{trigger}()")
123+
124+
@staticmethod
125+
def get_from_db(table_name, column_name):
126+
# TODO:
127+
# fetch appropriate timestamp from db
128+
# cast to datetime
129+
# return
130+
log.info("MOCKED Getting %s :: %s", table_name, column_name)
131+
return datetime(2024, 1, 1)
132+
133+
def dagrun_execution_date(self) -> datetime:
134+
return self.get_from_db("dagrun", "execution_date")
135+
136+
137+
class DeadlineAlert(LoggingMixin):
138+
"""Store Deadline values needed to calculate the need-by timestamp and the callback information."""
139+
140+
def __init__(
141+
self,
142+
trigger: type[DeadlineTrigger] | datetime,
143+
interval: timedelta,
144+
callback: Callable | str,
145+
callback_kwargs: dict | None = None,
146+
):
147+
super().__init__()
148+
self.trigger = trigger
149+
self.interval = interval
150+
self.callback_kwargs = callback_kwargs
151+
self.callback = self.get_callback_path(callback)
152+
153+
@staticmethod
154+
def get_callback_path(_callback: str | Callable) -> str:
155+
if callable(_callback):
156+
# Get the reference path to the callable in the form `airflow.models.deadline.get_from_db`
157+
return f"{_callback.__module__}.{_callback.__qualname__}"
158+
159+
# Check if the dotpath can resolve to a callable; store it or raise a ValueError
160+
try:
161+
_callback_module, _callback_name = _callback.rsplit(".", 1)
162+
getattr(sys.modules[_callback_module], _callback_name)
163+
return _callback
164+
except (KeyError, AttributeError):
165+
# KeyError if the path is not valid
166+
# AttributeError if the provided value can't be rsplit
167+
raise ValueError("callback is not a path to a callable")
168+
169+
def serialize_deadline_alert(self):
170+
from airflow.serialization.serialized_objects import BaseSerialization
171+
172+
return BaseSerialization.serialize(
173+
{
174+
"trigger": self.trigger,
175+
"interval": self.interval,
176+
"callback": self.callback,
177+
"callback_kwargs": self.callback_kwargs,
178+
}
179+
)

airflow/serialization/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,4 @@ class DagAttributeTypes(str, Enum):
7575
DAG_CALLBACK_REQUEST = "dag_callback_request"
7676
TASK_INSTANCE_KEY = "task_instance_key"
7777
TRIGGER = "trigger"
78+
DEADLINE_ALERT = "deadline_alert"

airflow/serialization/schema.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,12 @@
186186
},
187187
"dag_display_name": { "type" : "string"},
188188
"description": { "type" : "string"},
189+
"deadline": {
190+
"anyOf": [
191+
{ "$ref": "#/definitions/dict" },
192+
{ "type": "null" }
193+
]
194+
},
189195
"_concurrency": { "type" : "number"},
190196
"max_active_tasks": { "type" : "number"},
191197
"max_active_runs": { "type" : "number"},

airflow/serialization/serialized_objects.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from airflow.models.baseoperatorlink import BaseOperatorLink, XComOperatorLink
4444
from airflow.models.connection import Connection
4545
from airflow.models.dag import DAG, _get_model_data_interval
46+
from airflow.models.deadline import DeadlineAlert
4647
from airflow.models.expandinput import (
4748
EXPAND_INPUT_EMPTY,
4849
create_expand_input,
@@ -720,6 +721,8 @@ def serialize(
720721
)
721722
elif isinstance(var, DAG):
722723
return cls._encode(SerializedDAG.serialize_dag(var), type_=DAT.DAG)
724+
elif isinstance(var, DeadlineAlert):
725+
return cls._encode(DeadlineAlert.serialize_deadline_alert(var), type_=DAT.DEADLINE_ALERT)
723726
elif isinstance(var, Resources):
724727
return var.to_dict()
725728
elif isinstance(var, MappedOperator):
@@ -1666,6 +1669,8 @@ def serialize_dag(cls, dag: DAG) -> dict:
16661669
serialized_dag["dag_dependencies"] = [x.__dict__ for x in sorted(dag_deps)]
16671670
serialized_dag["task_group"] = TaskGroupSerialization.serialize_task_group(dag.task_group)
16681671

1672+
serialized_dag["deadline"] = dag.deadline.serialize_deadline_alert() if dag.deadline else None
1673+
16691674
# Edge info in the JSON exactly matches our internal structure
16701675
serialized_dag["edge_info"] = dag.edge_info
16711676
serialized_dag["params"] = cls._serialize_params_dict(dag.params)

airflow/utils/db.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class MappedClassProtocol(Protocol):
9595
"2.10.0": "22ed7efa9da2",
9696
"2.10.3": "5f2621c13b39",
9797
"3.0.0": "6a9e7a527a88",
98+
"3.1.0": "dfee8bd5d574",
9899
}
99100

100101

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
617116c74735faa69a297fe665664b691f341487bc8dcbef3e3ec7e76cdea799
1+
b4c271a3e9913d4d0a00ae7ec0744a5071d593dff92447872a5cf2bed89000c9

0 commit comments

Comments
 (0)