diff --git a/cads_broker/database.py b/cads_broker/database.py index fd945eb1..20cf61c3 100644 --- a/cads_broker/database.py +++ b/cads_broker/database.py @@ -427,26 +427,9 @@ def count_users(status: str, entry_point: str, session: sa.orm.Session) -> int: ) -def update_dismissed_requests(session: sa.orm.Session) -> Iterable[str]: - stmt_dismissed = ( - sa.update(SystemRequest) - .where(SystemRequest.status == "dismissed") - .returning(SystemRequest.request_uid) - .values(status="failed", response_error={"reason": "dismissed request"}) - ) - dismissed_uids = session.scalars(stmt_dismissed).fetchall() - session.execute( # type: ignore - sa.insert(Events), - map( - lambda x: { - "request_uid": x, - "message": DISMISSED_MESSAGE, - "event_type": "user_visible_error", - }, - dismissed_uids, - ), - ) - return dismissed_uids +def get_dismissed_requests(session: sa.orm.Session) -> Iterable[SystemRequest]: + stmt_dismissed = sa.select(SystemRequest).where(SystemRequest.status == "dismissed") + return session.scalars(stmt_dismissed).fetchall() def get_events_from_request( diff --git a/cads_broker/dispatcher.py b/cads_broker/dispatcher.py index 5fd87cad..028fc766 100644 --- a/cads_broker/dispatcher.py +++ b/cads_broker/dispatcher.py @@ -170,9 +170,9 @@ def values(self) -> Iterable[Any]: with self._lock: return self.queue_dict.values() - def pop(self, key: str) -> Any: + def pop(self, key: str, default=None) -> Any: with self._lock: - return self.queue_dict.pop(key, None) + return self.queue_dict.pop(key, default) def len(self) -> int: with self._lock: @@ -336,16 +336,16 @@ def sync_database(self, session: sa.orm.Session) -> None: - If the task is not in the dask scheduler, it is re-queued. This behaviour can be changed with an environment variable. """ - # the retrieve API sets the status to "dismissed", here the broker deletes the request - # this is to better control the status of the QoS - dismissed_uids = db.update_dismissed_requests(session) - for uid in dismissed_uids: - if future := self.futures.pop(uid, None): + # the retrieve API sets the status to "dismissed", + # here the broker fixes the QoS and queue status accordingly + dismissed_requests = db.get_dismissed_requests(session) + for request in dismissed_requests: + if future := self.futures.pop(request.request_uid, None): future.cancel() - if dismissed_uids: - self.queue.reset() - self.qos.reload_rules(session) - db.reset_qos_rules(session, self.qos) + self.qos.notify_end_of_request( + request, session, scheduler=self.internal_scheduler + ) + self.queue.pop(request.request_uid, None) session.commit() statement = sa.select(db.SystemRequest).where( diff --git a/cads_broker/entry_points.py b/cads_broker/entry_points.py index 94c97085..b2d364f7 100644 --- a/cads_broker/entry_points.py +++ b/cads_broker/entry_points.py @@ -53,7 +53,6 @@ def add_dummy_requests( entry_point="cads_adaptors:DummyAdaptor", ) session.add(request) - if i % 100 == 0: session.commit() session.commit() diff --git a/tests/test_02_database.py b/tests/test_02_database.py index 473be6ba..888e0bdc 100644 --- a/tests/test_02_database.py +++ b/tests/test_02_database.py @@ -14,11 +14,13 @@ class MockRule: - def __init__(self, name, conclusion, info, condition): + def __init__(self, name, conclusion, info, condition, queued=[], running=0): self.name = name self.conclusion = conclusion self.info = info self.condition = condition + self.queued = queued + self.value = running def evaluate(self, request): return 10 @@ -517,8 +519,12 @@ def test_add_qos_rule(session_obj: sa.orm.sessionmaker) -> None: def test_add_request_qos_status(session_obj: sa.orm.sessionmaker) -> None: - rule1 = MockRule("name1", "conclusion1", "info1", "condition1") - rule2 = MockRule("name2", "conclusion2", "info2", "condition2") + rule1 = MockRule( + "name1", "conclusion1", "info1", "condition1", queued=list(range(5)) + ) + rule2 = MockRule( + "name2", "conclusion2", "info2", "condition2", queued=list(range(1)) + ) adaptor_properties = mock_config() request = mock_system_request(adaptor_properties_hash=adaptor_properties.hash) request_uid = request.request_uid @@ -535,16 +541,18 @@ def test_add_request_qos_status(session_obj: sa.orm.sessionmaker) -> None: with session_obj() as session: request = db.get_request(request_uid, session=session) assert db.get_qos_status_from_request(request) == { - "name1": [ - {"info": "info1", "queued": 5 + 1, "running": 0, "conclusion": "10"} - ], + "name1": [{"info": "info1", "queued": 5, "running": 0, "conclusion": "10"}], "name2": [{"info": "info2", "queued": 1, "running": 0, "conclusion": "10"}], } def test_delete_request_qos_status(session_obj: sa.orm.sessionmaker) -> None: - rule1 = MockRule("name1", "conclusion1", "info1", "condition1") - rule2 = MockRule("name2", "conclusion2", "info2", "condition2") + rule1 = MockRule( + "name1", "conclusion1", "info1", "condition1", queued=list(range(5)), running=2 + ) + rule2 = MockRule( + "name2", "conclusion2", "info2", "condition2", queued=list(range(3)), running=2 + ) adaptor_properties = mock_config() request = mock_system_request(adaptor_properties_hash=adaptor_properties.hash) request_uid = request.request_uid @@ -574,14 +582,16 @@ def test_delete_request_qos_status(session_obj: sa.orm.sessionmaker) -> None: rule1 = db.get_qos_rule(str(rule1.__hash__()), session=session) rule2 = db.get_qos_rule(str(rule2.__hash__()), session=session) assert rule1.queued == rule1_queued - assert rule1.running == rule1_running + 1 assert rule2.queued == rule2_queued - assert rule2.running == rule2_running + 1 def test_decrement_qos_rule_running(session_obj: sa.orm.sessionmaker) -> None: - rule1 = MockRule("name1", "conclusion1", "info1", "condition1") - rule2 = MockRule("name2", "conclusion2", "info2", "condition2") + rule1 = MockRule( + "name1", "conclusion1", "info1", "condition1", queued=list(range(5)), running=2 + ) + rule2 = MockRule( + "name2", "conclusion2", "info2", "condition2", queued=list(range(3)), running=4 + ) rule1_queued = 5 rule1_running = 2 rule2_queued = 3 @@ -602,11 +612,11 @@ def test_decrement_qos_rule_running(session_obj: sa.orm.sessionmaker) -> None: with session_obj() as session: assert ( db.get_qos_rule(str(rule1.__hash__()), session=session).running - == rule1_running - 1 + == rule1_running ) assert ( db.get_qos_rule(str(rule2.__hash__()), session=session).running - == rule2_running - 1 + == rule2_running ) diff --git a/tests/test_20_dispatcher.py b/tests/test_20_dispatcher.py index 9aeb38fa..60a814b8 100644 --- a/tests/test_20_dispatcher.py +++ b/tests/test_20_dispatcher.py @@ -89,10 +89,10 @@ def test_broker_sync_database( session.commit() def mock_get_tasks() -> dict[str, str]: - return {in_dask_request_uid: "..."} + return {in_dask_request_uid: {"state": "...", "exception": None}} mocker.patch( - "cads_broker.dispatcher.get_tasks", + "cads_broker.dispatcher.get_tasks_from_scheduler", return_value=mock_get_tasks(), ) broker.futures = {in_futures_request_uid: "..."} @@ -114,8 +114,8 @@ def mock_get_tasks() -> dict[str, str]: db.SystemRequest.request_uid == lost_request_uid ) output_request = session.scalars(statement).first() - assert output_request.status == "failed" - assert output_request.request_metadata.get("resubmit_number") is None + assert output_request.status == "accepted" + assert output_request.request_metadata.get("resubmit_number") == 1 # with pytest.raises(db.NoResultFound): # with session_obj() as session: