Skip to content

Commit cb321eb

Browse files
Improve dismissed requests (#109)
* track killed worker log event from nanny * add configurability to requeue * request not found * rename requeue env var * change env variable * check all qos status * qa * manage dismissed requests * fix * improve reset_qos_rules * fix reset_qos_rules * remove warning to restart the broker. it is not needed anymore * qa * tests
1 parent ec1eaf3 commit cb321eb

File tree

5 files changed

+70
-92
lines changed

5 files changed

+70
-92
lines changed

cads_broker/database.py

+39-19
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import sqlalchemy_utils
1515
import structlog
1616
from sqlalchemy.dialects.postgresql import JSONB
17+
from typing_extensions import Iterable
1718

1819
import alembic.command
1920
import alembic.config
@@ -27,6 +28,9 @@
2728
status_enum = sa.Enum(
2829
"accepted", "running", "failed", "successful", "dismissed", name="status"
2930
)
31+
DISMISSED_MESSAGE = os.getenv(
32+
"DISMISSED_MESSAGE", "The request has been dismissed by the system."
33+
)
3034

3135

3236
class NoResultFound(Exception):
@@ -421,6 +425,28 @@ def count_users(status: str, entry_point: str, session: sa.orm.Session) -> int:
421425
)
422426

423427

428+
def update_dismissed_requests(session: sa.orm.Session) -> Iterable[str]:
429+
stmt_dismissed = (
430+
sa.update(SystemRequest)
431+
.where(SystemRequest.status == "dismissed")
432+
.returning(SystemRequest.request_uid)
433+
.values(status="failed", response_error={"reason": "dismissed request"})
434+
)
435+
dismissed_uids = session.scalars(stmt_dismissed).fetchall()
436+
session.execute( # type: ignore
437+
sa.insert(Events),
438+
map(
439+
lambda x: {
440+
"request_uid": x,
441+
"message": DISMISSED_MESSAGE,
442+
"event_type": "user_visible_error",
443+
},
444+
dismissed_uids,
445+
),
446+
)
447+
return dismissed_uids
448+
449+
424450
def get_events_from_request(
425451
request_uid: str,
426452
session: sa.orm.Session,
@@ -439,11 +465,23 @@ def get_events_from_request(
439465
return events
440466

441467

442-
def reset_qos_rules(session: sa.orm.Session):
468+
def reset_qos_rules(session: sa.orm.Session, qos):
443469
"""Delete all QoS rules."""
444470
for rule in session.scalars(sa.select(QoSRule)):
445471
rule.system_requests = []
446472
session.delete(rule)
473+
cached_rules: dict[str, Any] = {}
474+
for request in get_running_requests(session):
475+
# Recompute the limits
476+
limits = qos.limits_for(request, session)
477+
cached_rules.update(
478+
delete_request_qos_status(
479+
request_uid=request.request_uid,
480+
rules=limits,
481+
session=session,
482+
rules_in_db=cached_rules,
483+
)
484+
)
447485
session.commit()
448486

449487

@@ -482,24 +520,6 @@ def add_qos_rule(
482520
return qos_rule
483521

484522

485-
def increment_qos_rule_running(
486-
rules: list, session: sa.orm.Session, rules_in_db: dict[str, QoSRule] = {}, **kwargs
487-
):
488-
"""Increment the running counter of a QoS rule."""
489-
created_rules: dict = {}
490-
for rule in rules:
491-
if (rule_uid := str(rule.__hash__())) in rules_in_db:
492-
qos_rule = rules_in_db[rule_uid]
493-
else:
494-
try:
495-
qos_rule = get_qos_rule(rule_uid, session)
496-
except sqlalchemy.orm.exc.NoResultFound:
497-
qos_rule = add_qos_rule(rule=rule, session=session)
498-
created_rules[qos_rule.uid] = qos_rule
499-
qos_rule.running += 1
500-
return created_rules
501-
502-
503523
def decrement_qos_rule_running(
504524
rules: list, session: sa.orm.Session, rules_in_db: dict[str, QoSRule] = {}, **kwargs
505525
):

cads_broker/dispatcher.py

+19-17
Original file line numberDiff line numberDiff line change
@@ -212,19 +212,20 @@ def from_address(
212212
factory.register_functions()
213213
session_maker_read = db.ensure_session_obj(session_maker_read, mode="r")
214214
session_maker_write = db.ensure_session_obj(session_maker_write, mode="w")
215-
with session_maker_write() as session:
216-
db.reset_qos_rules(session)
217215
rules_hash = get_rules_hash(qos_config.rules_path)
216+
qos = QoS.QoS(
217+
qos_config.rules,
218+
qos_config.environment,
219+
rules_hash=rules_hash,
220+
)
221+
with session_maker_write() as session:
222+
db.reset_qos_rules(session, qos)
218223
self = cls(
219224
client=client,
220225
session_maker_read=session_maker_read,
221226
session_maker_write=session_maker_write,
222227
environment=qos_config.environment,
223-
qos=QoS.QoS(
224-
qos_config.rules,
225-
qos_config.environment,
226-
rules_hash=rules_hash,
227-
),
228+
qos=qos,
228229
address=address,
229230
)
230231
return self
@@ -244,21 +245,22 @@ def sync_database(self, session: sa.orm.Session) -> None:
244245
245246
If the task is not in the dask scheduler, it is re-queued.
246247
"""
248+
# the retrieve API sets the status to "dismissed", here the broker deletes the request
249+
# this is to better control the status of the QoS
250+
dismissed_uids = db.update_dismissed_requests(session)
251+
for uid in dismissed_uids:
252+
if future := self.futures.pop(uid, None):
253+
future.cancel()
254+
if dismissed_uids:
255+
self.queue.reset()
256+
self.qos.reload_rules(session)
257+
db.reset_qos_rules(session, self.qos)
258+
247259
statement = sa.select(db.SystemRequest).where(
248260
db.SystemRequest.status.in_(("running", "dismissed"))
249261
)
250262
dask_tasks = get_tasks(self.client)
251263
for request in session.scalars(statement):
252-
# the retrieve API set the status to "dismissed", here the broker deletes the request
253-
# this is to better control the status of the QoS
254-
if request.status == "dismissed":
255-
db.delete_request(request=request, session=session)
256-
self.qos.notify_end_of_request(
257-
request, session, scheduler=self.internal_scheduler
258-
)
259-
if future := self.futures.get(request.request_uid):
260-
future.cancel()
261-
continue
262264
# if request is in futures, go on
263265
if request.request_uid in self.futures:
264266
continue

cads_broker/entry_points.py

-1
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,6 @@ def delete_requests(
159159
typer.echo(
160160
f"Status set to 'dismissed' for {number_of_requests} requests in the broker database."
161161
)
162-
typer.echo("Please restart the broker.")
163162

164163

165164
@app.command()

tests/test_02_database.py

-43
Original file line numberDiff line numberDiff line change
@@ -516,25 +516,6 @@ def test_add_qos_rule(session_obj: sa.orm.sessionmaker) -> None:
516516
assert db.get_qos_rule(str(rule.__hash__()), session=session).name == rule.name
517517

518518

519-
def test_increment_qos_rule_running(session_obj: sa.orm.sessionmaker) -> None:
520-
rule1 = MockRule("name1", "conclusion1", "info1", "condition1")
521-
rule2 = MockRule("name2", "conclusion2", "info2", "condition2")
522-
with session_obj() as session:
523-
db.add_qos_rule(rule1, session=session)
524-
with session_obj() as session:
525-
db.increment_qos_rule_running([rule1, rule2], rules_in_db={}, session=session)
526-
session.commit()
527-
with session_obj() as session:
528-
assert db.get_qos_rule(str(rule1.__hash__()), session=session).running == 1
529-
rule = db.get_qos_rule(str(rule2.__hash__()), session=session)
530-
assert rule.running == 1
531-
rules_in_db = {rule.uid: rule}
532-
db.increment_qos_rule_running([rule2], rules_in_db=rules_in_db, session=session)
533-
session.commit()
534-
with session_obj() as session:
535-
assert db.get_qos_rule(str(rule2.__hash__()), session=session).running == 2
536-
537-
538519
def test_add_request_qos_status(session_obj: sa.orm.sessionmaker) -> None:
539520
rule1 = MockRule("name1", "conclusion1", "info1", "condition1")
540521
rule2 = MockRule("name2", "conclusion2", "info2", "condition2")
@@ -629,30 +610,6 @@ def test_decrement_qos_rule_running(session_obj: sa.orm.sessionmaker) -> None:
629610
)
630611

631612

632-
def test_reset_qos_rules(session_obj: sa.orm.sessionmaker) -> None:
633-
rule1 = MockRule("name1", "conclusion1", "info1", "condition1")
634-
rule2 = MockRule("name2", "conclusion2", "info2", "condition2")
635-
adaptor_properties = mock_config()
636-
request = mock_system_request(adaptor_properties_hash=adaptor_properties.hash)
637-
request_uid = request.request_uid
638-
with session_obj() as session:
639-
rule1_db = db.add_qos_rule(rule1, session=session)
640-
rule2_db = db.add_qos_rule(rule2, session=session)
641-
rules_in_db = {rule1_db.uid: rule1_db, rule2_db.uid: rule2_db}
642-
session.add(adaptor_properties)
643-
session.add(request)
644-
session.commit()
645-
646-
db.add_request_qos_status(
647-
request, [rule1, rule2], session=session, rules_in_db=rules_in_db
648-
)
649-
session.commit()
650-
651-
db.reset_qos_rules(session=session)
652-
request = db.get_request(request_uid, session=session)
653-
assert db.get_qos_status_from_request(request) == {}
654-
655-
656613
def test_get_events_from_request(session_obj: sa.orm.sessionmaker) -> None:
657614
adaptor_properties = mock_config()
658615
request = mock_system_request(adaptor_properties_hash=adaptor_properties.hash)

tests/test_20_dispatcher.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_broker_sync_database(
5959
in_futures_request_uid = str(uuid.uuid4())
6060
in_dask_request_uid = str(uuid.uuid4())
6161
lost_request_uid = str(uuid.uuid4())
62-
dismissed_request_uid = str(uuid.uuid4())
62+
# dismissed_request_uid = str(uuid.uuid4())
6363
adaptor_metadata = mock_config()
6464
in_futures_request = mock_system_request(
6565
request_uid=in_futures_request_uid,
@@ -76,17 +76,17 @@ def test_broker_sync_database(
7676
status="running",
7777
adaptor_properties_hash=adaptor_metadata.hash,
7878
)
79-
dismissed_request = mock_system_request(
80-
request_uid=dismissed_request_uid,
81-
status="dismissed",
82-
adaptor_properties_hash=adaptor_metadata.hash,
83-
)
79+
# dismissed_request = mock_system_request(
80+
# request_uid=dismissed_request_uid,
81+
# status="dismissed",
82+
# adaptor_properties_hash=adaptor_metadata.hash,
83+
# )
8484
with session_obj() as session:
8585
session.add(adaptor_metadata)
8686
session.add(in_futures_request)
8787
session.add(in_dask_request)
8888
session.add(lost_request)
89-
session.add(dismissed_request)
89+
# session.add(dismissed_request)
9090
session.commit()
9191

9292
def mock_get_tasks() -> dict[str, str]:
@@ -115,9 +115,9 @@ def mock_get_tasks() -> dict[str, str]:
115115
db.SystemRequest.request_uid == lost_request_uid
116116
)
117117
output_request = session.scalars(statement).first()
118-
assert output_request.status == "accepted"
119-
assert output_request.request_metadata.get("resubmit_number") == 1
118+
assert output_request.status == "failed"
119+
assert output_request.request_metadata.get("resubmit_number") is None
120120

121-
with pytest.raises(db.NoResultFound):
122-
with session_obj() as session:
123-
db.get_request(dismissed_request_uid, session=session)
121+
# with pytest.raises(db.NoResultFound):
122+
# with session_obj() as session:
123+
# db.get_request(dismissed_request_uid, session=session)

0 commit comments

Comments
 (0)