Skip to content

Commit 04d3f45

Browse files
hussein-awala, uranusjr, eladkal
authored
Order triggers by - TI priority_weight when assign unassigned triggers (#32318)
* Order triggers by - TI priority_weight when assign unassigned triggers
  Signed-off-by: Hussein Awala <[email protected]>
* Update airflow/models/trigger.py
  Co-authored-by: Tzu-ping Chung <[email protected]>
* Replace outer join by inner join and use coalesce to handle None values
* fix unit tests
---------
Signed-off-by: Hussein Awala <[email protected]>
Co-authored-by: Tzu-ping Chung <[email protected]>
Co-authored-by: eladkal <[email protected]>
1 parent a2a0d05 commit 04d3f45

File tree

3 files changed

+130
-16
lines changed

3 files changed

+130
-16
lines changed

airflow/models/trigger.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from sqlalchemy import Column, Integer, String, delete, func, or_, select, update
2424
from sqlalchemy.orm import Session, joinedload, relationship
25+
from sqlalchemy.sql.functions import coalesce
2526

2627
from airflow.api_internal.internal_api_call import internal_api_call
2728
from airflow.models.base import Base
@@ -244,8 +245,9 @@ def assign_unassigned(
244245
def get_sorted_triggers(cls, capacity, alive_triggerer_ids, session):
245246
query = with_row_locks(
246247
select(cls.id)
248+
.join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=False)
247249
.where(or_(cls.triggerer_id.is_(None), cls.triggerer_id.not_in(alive_triggerer_ids)))
248-
.order_by(cls.created_date)
250+
.order_by(coalesce(TaskInstance.priority_weight, 0).desc(), cls.created_date)
249251
.limit(capacity),
250252
session,
251253
skip_locked=True,

tests/jobs/test_triggerer_job.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ def handle_events(self):
414414
assert len(instances) == 1
415415

416416

417-
def test_trigger_from_dead_triggerer(session):
417+
def test_trigger_from_dead_triggerer(session, create_task_instance):
418418
"""
419419
Checks that the triggerer will correctly claim a Trigger that is assigned to a
420420
triggerer that does not exist.
@@ -425,6 +425,13 @@ def test_trigger_from_dead_triggerer(session):
425425
trigger_orm.id = 1
426426
trigger_orm.triggerer_id = 999 # Non-existent triggerer
427427
session.add(trigger_orm)
428+
ti_orm = create_task_instance(
429+
task_id="ti_orm",
430+
execution_date=datetime.datetime.utcnow(),
431+
run_id="orm_run_id",
432+
)
433+
ti_orm.trigger_id = trigger_orm.id
434+
session.add(trigger_orm)
428435
session.commit()
429436
# Make a TriggererJobRunner and have it retrieve DB tasks
430437
job = Job()
@@ -434,7 +441,7 @@ def test_trigger_from_dead_triggerer(session):
434441
assert [x for x, y in job_runner.trigger_runner.to_create] == [1]
435442

436443

437-
def test_trigger_from_expired_triggerer(session):
444+
def test_trigger_from_expired_triggerer(session, create_task_instance):
438445
"""
439446
Checks that the triggerer will correctly claim a Trigger that is assigned to a
440447
triggerer that has an expired heartbeat.
@@ -445,6 +452,13 @@ def test_trigger_from_expired_triggerer(session):
445452
trigger_orm.id = 1
446453
trigger_orm.triggerer_id = 42
447454
session.add(trigger_orm)
455+
ti_orm = create_task_instance(
456+
task_id="ti_orm",
457+
execution_date=datetime.datetime.utcnow(),
458+
run_id="orm_run_id",
459+
)
460+
ti_orm.trigger_id = trigger_orm.id
461+
session.add(trigger_orm)
448462
# Use a TriggererJobRunner with an expired heartbeat
449463
triggerer_job_orm = Job(TriggererJobRunner.job_type)
450464
triggerer_job_orm.id = 42

tests/models/test_trigger.py

Lines changed: 111 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -171,19 +171,47 @@ def test_assign_unassigned(session, create_task_instance):
171171
trigger_on_healthy_triggerer = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
172172
trigger_on_healthy_triggerer.id = 1
173173
trigger_on_healthy_triggerer.triggerer_id = healthy_triggerer.id
174+
session.add(trigger_on_healthy_triggerer)
175+
ti_trigger_on_healthy_triggerer = create_task_instance(
176+
task_id="ti_trigger_on_healthy_triggerer",
177+
execution_date=time_now,
178+
run_id="trigger_on_healthy_triggerer_run_id",
179+
)
180+
ti_trigger_on_healthy_triggerer.trigger_id = trigger_on_healthy_triggerer.id
181+
session.add(ti_trigger_on_healthy_triggerer)
174182
trigger_on_unhealthy_triggerer = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
175183
trigger_on_unhealthy_triggerer.id = 2
176184
trigger_on_unhealthy_triggerer.triggerer_id = unhealthy_triggerer.id
185+
session.add(trigger_on_unhealthy_triggerer)
186+
ti_trigger_on_unhealthy_triggerer = create_task_instance(
187+
task_id="ti_trigger_on_unhealthy_triggerer",
188+
execution_date=time_now + datetime.timedelta(hours=1),
189+
run_id="trigger_on_unhealthy_triggerer_run_id",
190+
)
191+
ti_trigger_on_unhealthy_triggerer.trigger_id = trigger_on_unhealthy_triggerer.id
192+
session.add(ti_trigger_on_unhealthy_triggerer)
177193
trigger_on_killed_triggerer = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
178194
trigger_on_killed_triggerer.id = 3
179195
trigger_on_killed_triggerer.triggerer_id = finished_triggerer.id
196+
session.add(trigger_on_killed_triggerer)
197+
ti_trigger_on_killed_triggerer = create_task_instance(
198+
task_id="ti_trigger_on_killed_triggerer",
199+
execution_date=time_now + datetime.timedelta(hours=2),
200+
run_id="trigger_on_killed_triggerer_run_id",
201+
)
202+
ti_trigger_on_killed_triggerer.trigger_id = trigger_on_killed_triggerer.id
203+
session.add(ti_trigger_on_killed_triggerer)
180204
trigger_unassigned_to_triggerer = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
181205
trigger_unassigned_to_triggerer.id = 4
182-
assert trigger_unassigned_to_triggerer.triggerer_id is None
183-
session.add(trigger_on_healthy_triggerer)
184-
session.add(trigger_on_unhealthy_triggerer)
185-
session.add(trigger_on_killed_triggerer)
186206
session.add(trigger_unassigned_to_triggerer)
207+
ti_trigger_unassigned_to_triggerer = create_task_instance(
208+
task_id="ti_trigger_unassigned_to_triggerer",
209+
execution_date=time_now + datetime.timedelta(hours=3),
210+
run_id="trigger_unassigned_to_triggerer_run_id",
211+
)
212+
ti_trigger_unassigned_to_triggerer.trigger_id = trigger_unassigned_to_triggerer.id
213+
session.add(ti_trigger_unassigned_to_triggerer)
214+
assert trigger_unassigned_to_triggerer.triggerer_id is None
187215
session.commit()
188216
assert session.query(Trigger).count() == 4
189217
Trigger.assign_unassigned(new_triggerer.id, 100, health_check_threshold=30)
@@ -209,31 +237,101 @@ def test_assign_unassigned(session, create_task_instance):
209237
)
210238

211239

212-
def test_get_sorted_triggers(session, create_task_instance):
240+
def test_get_sorted_triggers_same_priority_weight(session, create_task_instance):
213241
"""
214-
Tests that triggers are sorted by the creation_date.
242+
Tests that triggers are sorted by the creation_date if they have the same priority.
215243
"""
244+
old_execution_date = datetime.datetime(
245+
2023, 5, 9, 12, 16, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan")
246+
)
216247
trigger_old = Trigger(
217248
classpath="airflow.triggers.testing.SuccessTrigger",
218249
kwargs={},
219-
created_date=datetime.datetime(
220-
2023, 5, 9, 12, 16, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan")
221-
),
250+
created_date=old_execution_date + datetime.timedelta(seconds=30),
222251
)
223252
trigger_old.id = 1
253+
session.add(trigger_old)
254+
TI_old = create_task_instance(
255+
task_id="old",
256+
execution_date=old_execution_date,
257+
run_id="old_run_id",
258+
)
259+
TI_old.priority_weight = 1
260+
TI_old.trigger_id = trigger_old.id
261+
session.add(TI_old)
262+
263+
new_execution_date = datetime.datetime(
264+
2023, 5, 9, 12, 17, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan")
265+
)
224266
trigger_new = Trigger(
225267
classpath="airflow.triggers.testing.SuccessTrigger",
226268
kwargs={},
227-
created_date=datetime.datetime(
228-
2023, 5, 9, 12, 17, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan")
229-
),
269+
created_date=new_execution_date + datetime.timedelta(seconds=30),
230270
)
231271
trigger_new.id = 2
232-
session.add(trigger_old)
233272
session.add(trigger_new)
273+
TI_new = create_task_instance(
274+
task_id="new",
275+
execution_date=new_execution_date,
276+
run_id="new_run_id",
277+
)
278+
TI_new.priority_weight = 1
279+
TI_new.trigger_id = trigger_new.id
280+
session.add(TI_new)
281+
234282
session.commit()
235283
assert session.query(Trigger).count() == 2
236284

237285
trigger_ids_query = Trigger.get_sorted_triggers(capacity=100, alive_triggerer_ids=[], session=session)
238286

239287
assert trigger_ids_query == [(1,), (2,)]
288+
289+
290+
def test_get_sorted_triggers_different_priority_weights(session, create_task_instance):
291+
"""
292+
Tests that triggers are sorted by the priority_weight.
293+
"""
294+
old_execution_date = datetime.datetime(
295+
2023, 5, 9, 12, 16, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan")
296+
)
297+
trigger_old = Trigger(
298+
classpath="airflow.triggers.testing.SuccessTrigger",
299+
kwargs={},
300+
created_date=old_execution_date + datetime.timedelta(seconds=30),
301+
)
302+
trigger_old.id = 1
303+
session.add(trigger_old)
304+
TI_old = create_task_instance(
305+
task_id="old",
306+
execution_date=old_execution_date,
307+
run_id="old_run_id",
308+
)
309+
TI_old.priority_weight = 1
310+
TI_old.trigger_id = trigger_old.id
311+
session.add(TI_old)
312+
313+
new_execution_date = datetime.datetime(
314+
2023, 5, 9, 12, 17, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan")
315+
)
316+
trigger_new = Trigger(
317+
classpath="airflow.triggers.testing.SuccessTrigger",
318+
kwargs={},
319+
created_date=new_execution_date + datetime.timedelta(seconds=30),
320+
)
321+
trigger_new.id = 2
322+
session.add(trigger_new)
323+
TI_new = create_task_instance(
324+
task_id="new",
325+
execution_date=new_execution_date,
326+
run_id="new_run_id",
327+
)
328+
TI_new.priority_weight = 2
329+
TI_new.trigger_id = trigger_new.id
330+
session.add(TI_new)
331+
332+
session.commit()
333+
assert session.query(Trigger).count() == 2
334+
335+
trigger_ids_query = Trigger.get_sorted_triggers(capacity=100, alive_triggerer_ids=[], session=session)
336+
337+
assert trigger_ids_query == [(2,), (1,)]

0 commit comments

Comments
 (0)