diff --git a/cads_broker/database.py b/cads_broker/database.py index fd945eb1..20cf61c3 100644 --- a/cads_broker/database.py +++ b/cads_broker/database.py @@ -427,26 +427,9 @@ def count_users(status: str, entry_point: str, session: sa.orm.Session) -> int: ) -def update_dismissed_requests(session: sa.orm.Session) -> Iterable[str]: - stmt_dismissed = ( - sa.update(SystemRequest) - .where(SystemRequest.status == "dismissed") - .returning(SystemRequest.request_uid) - .values(status="failed", response_error={"reason": "dismissed request"}) - ) - dismissed_uids = session.scalars(stmt_dismissed).fetchall() - session.execute( # type: ignore - sa.insert(Events), - map( - lambda x: { - "request_uid": x, - "message": DISMISSED_MESSAGE, - "event_type": "user_visible_error", - }, - dismissed_uids, - ), - ) - return dismissed_uids +def get_dismissed_requests(session: sa.orm.Session) -> Iterable[SystemRequest]: + stmt_dismissed = sa.select(SystemRequest).where(SystemRequest.status == "dismissed") + return session.scalars(stmt_dismissed).fetchall() def get_events_from_request( diff --git a/cads_broker/dispatcher.py b/cads_broker/dispatcher.py index 5fd87cad..69e398cd 100644 --- a/cads_broker/dispatcher.py +++ b/cads_broker/dispatcher.py @@ -170,9 +170,9 @@ def values(self) -> Iterable[Any]: with self._lock: return self.queue_dict.values() - def pop(self, key: str) -> Any: + def pop(self, key: str, default=None) -> Any: with self._lock: - return self.queue_dict.pop(key, None) + return self.queue_dict.pop(key, default) def len(self) -> int: with self._lock: @@ -192,9 +192,12 @@ def __init__(self, number_of_workers) -> None: if os.path.exists(self.rules_path): self.rules = self.rules_path else: - parser = QoS.RulesParser(io.StringIO(os.getenv("DEFAULT_RULES", ""))) + logger.info("rules file not found", rules_path=self.rules_path) + parser = QoS.RulesParser( + io.StringIO(os.getenv("DEFAULT_RULES", "")), logger=logger + ) self.rules = QoS.RuleSet() - parser.parse_rules(self.rules, self.environment) + parser.parse_rules(self.rules, self.environment, raise_exception=False) @attrs.define @@ -235,6 +238,7 @@ def from_address( qos_config.rules, qos_config.environment, rules_hash=rules_hash, + logger=logger, ) with session_maker_write() as session: qos.environment.set_session(session) @@ -336,16 +340,16 @@ def sync_database(self, session: sa.orm.Session) -> None: - If the task is not in the dask scheduler, it is re-queued. This behaviour can be changed with an environment variable. """ - # the retrieve API sets the status to "dismissed", here the broker deletes the request - # this is to better control the status of the QoS - dismissed_uids = db.update_dismissed_requests(session) - for uid in dismissed_uids: - if future := self.futures.pop(uid, None): + # the retrieve API sets the status to "dismissed", + # here the broker fixes the QoS and queue status accordingly + dismissed_requests = db.get_dismissed_requests(session) + for request in dismissed_requests: + if future := self.futures.pop(request.request_uid, None): future.cancel() - if dismissed_uids: - self.queue.reset() - self.qos.reload_rules(session) - db.reset_qos_rules(session, self.qos) + self.qos.notify_end_of_request( + request, session, scheduler=self.internal_scheduler + ) + self.queue.pop(request.request_uid, None) session.commit() statement = sa.select(db.SystemRequest).where( diff --git a/cads_broker/entry_points.py b/cads_broker/entry_points.py index 94c97085..b2d364f7 100644 --- a/cads_broker/entry_points.py +++ b/cads_broker/entry_points.py @@ -53,7 +53,6 @@ def add_dummy_requests( entry_point="cads_adaptors:DummyAdaptor", ) session.add(request) - if i % 100 == 0: session.commit() session.commit() diff --git a/cads_broker/expressions/Parser.py b/cads_broker/expressions/Parser.py index f64cbbae..f8e2e6f5 100644 --- a/cads_broker/expressions/Parser.py +++ b/cads_broker/expressions/Parser.py @@ -41,7 +41,7 @@ class Parser: This class must be sub-classed. """ - def __init__(self, path, comments=True): + def __init__(self, path, logger, comments=True): if isinstance(path, str): self.reader = Reader(open(path)) else: @@ -49,6 +49,7 @@ def __init__(self, path, comments=True): self.comments = comments self.eof = False self.line = 0 + self.logger = logger def read(self): if self.eof: diff --git a/cads_broker/expressions/RulesParser.py b/cads_broker/expressions/RulesParser.py index 31578e22..46a664a0 100644 --- a/cads_broker/expressions/RulesParser.py +++ b/cads_broker/expressions/RulesParser.py @@ -327,7 +327,7 @@ def parse(self): return result - def parse_rules(self, rules, environment): + def parse_rules(self, rules, environment, raise_exception=True): """Parse the text provided in the constructor. Args: @@ -340,22 +340,29 @@ def parse_rules(self, rules, environment): ParserError: [description] """ while self.peek(): - ident = self.parse_ident() - - if ident == "limit": - self.parse_global_limit(rules, environment) - continue - - if ident == "priority": - self.parse_priority(rules, environment) - continue - - if ident == "permission": - self.parse_permission(rules, environment) - continue - - if ident == "user": - self.parse_user_limit(rules, environment) - continue - - raise ParserError(f"Unknown rule: '{ident}'", self.line + 1) + try: + ident = self.parse_ident() + + if ident == "limit": + self.parse_global_limit(rules, environment) + continue + + if ident == "priority": + self.parse_priority(rules, environment) + continue + + if ident == "permission": + self.parse_permission(rules, environment) + continue + + if ident == "user": + self.parse_user_limit(rules, environment) + continue + + raise ParserError(f"Unknown rule: '{ident}'", self.line + 1) + except ParserError as e: + if raise_exception: + raise e + else: + self.logger.info(e) + return diff --git a/cads_broker/qos/QoS.py b/cads_broker/qos/QoS.py index ece2e5fc..1b9862bb 100644 --- a/cads_broker/qos/QoS.py +++ b/cads_broker/qos/QoS.py @@ -26,12 +26,13 @@ def wrapped(self, *args, **kwargs): class QoS: - def __init__(self, rules, environment, rules_hash): + def __init__(self, rules, environment, rules_hash, logger): self.lock = threading.RLock() self.rules_hash = rules_hash self.environment = environment + self.logger = logger # The list of active requests # Cache associating Request and their Properties @@ -53,13 +54,13 @@ def __init__(self, rules, environment, rules_hash): def read_rules(self): """Read the rule files and populate the rule_set.""" # Create a parser to parse the rules file - parser = RulesParser(self.path) + parser = RulesParser(self.path, logger=self.logger) # The rules will be stored in self.rules self.rules = RuleSet() # Parse the rules - parser.parse_rules(self.rules, self.environment) + parser.parse_rules(self.rules, self.environment, raise_exception=False) # Print the rules self.rules.dump() @@ -156,9 +157,9 @@ def _properties(self, request, session): properties.limits.append(rule) # Add per-user limits - limit = self.user_limit(request) - if limit is not None: - properties.limits.append(limit) + limits = self.user_limit(request) + if limits != []: + properties.limits.extend(limits) # Add priorities and compute starting priority priority = 0 @@ -251,10 +252,10 @@ def user_limit(self, request): """Return the per-user limit for the user associated with the request.""" user = request.user_uid - limit = self.per_user_limits.get(user) - if limit is not None: - print(user, limit) - return limit + limits = self.per_user_limits.get(user, []) + if limits != []: + print(user, limits) + return limits for limit in self.rules.user_limits: if limit.match(request): @@ -263,10 +264,9 @@ def user_limit(self, request): user otherwise all users will share that limit """ limit = limit.clone() - self.per_user_limits[user] = limit - return limit - return None - # raise Exception(f"Not rules matching user '{user}'") + limits.append(limit) + self.per_user_limits[user] = limits + return limits @locked def pick(self, queue, session): diff --git a/tests/test_01_expressions.py b/tests/test_01_expressions.py index 21846a50..415a5cb5 100644 --- a/tests/test_01_expressions.py +++ b/tests/test_01_expressions.py @@ -1,4 +1,5 @@ import io +import logging from cads_broker import Environment from cads_broker.expressions import FunctionFactory @@ -14,6 +15,8 @@ lambda context, *args: context.request.adaptor, ) +logger = logging.getLogger("test") + class TestRequest: user_uid = "david" @@ -30,7 +33,7 @@ class TestRequest: def compile(text): - parser = RulesParser(io.StringIO(text)) + parser = RulesParser(io.StringIO(text), logger=logger) return parser.parse() diff --git a/tests/test_02_database.py b/tests/test_02_database.py index 473be6ba..888e0bdc 100644 --- a/tests/test_02_database.py +++ b/tests/test_02_database.py @@ -14,11 +14,13 @@ class MockRule: - def __init__(self, name, conclusion, info, condition): + def __init__(self, name, conclusion, info, condition, queued=[], running=0): self.name = name self.conclusion = conclusion self.info = info self.condition = condition + self.queued = queued + self.value = running def evaluate(self, request): return 10 @@ -517,8 +519,12 @@ def test_add_qos_rule(session_obj: sa.orm.sessionmaker) -> None: def test_add_request_qos_status(session_obj: sa.orm.sessionmaker) -> None: - rule1 = MockRule("name1", "conclusion1", "info1", "condition1") - rule2 = MockRule("name2", "conclusion2", "info2", "condition2") + rule1 = MockRule( + "name1", "conclusion1", "info1", "condition1", queued=list(range(5)) + ) + rule2 = MockRule( + "name2", "conclusion2", "info2", "condition2", queued=list(range(1)) + ) adaptor_properties = mock_config() request = mock_system_request(adaptor_properties_hash=adaptor_properties.hash) request_uid = request.request_uid @@ -535,16 +541,18 @@ def test_add_request_qos_status(session_obj: sa.orm.sessionmaker) -> None: with session_obj() as session: request = db.get_request(request_uid, session=session) assert db.get_qos_status_from_request(request) == { - "name1": [ - {"info": "info1", "queued": 5 + 1, "running": 0, "conclusion": "10"} - ], + "name1": [{"info": "info1", "queued": 5, "running": 0, "conclusion": "10"}], "name2": [{"info": "info2", "queued": 1, "running": 0, "conclusion": "10"}], } def test_delete_request_qos_status(session_obj: sa.orm.sessionmaker) -> None: - rule1 = MockRule("name1", "conclusion1", "info1", "condition1") - rule2 = MockRule("name2", "conclusion2", "info2", "condition2") + rule1 = MockRule( + "name1", "conclusion1", "info1", "condition1", queued=list(range(5)), running=2 + ) + rule2 = MockRule( + "name2", "conclusion2", "info2", "condition2", queued=list(range(3)), running=2 + ) adaptor_properties = mock_config() request = mock_system_request(adaptor_properties_hash=adaptor_properties.hash) request_uid = request.request_uid @@ -574,14 +582,16 @@ def test_delete_request_qos_status(session_obj: sa.orm.sessionmaker) -> None: rule1 = db.get_qos_rule(str(rule1.__hash__()), session=session) rule2 = db.get_qos_rule(str(rule2.__hash__()), session=session) assert rule1.queued == rule1_queued - assert rule1.running == rule1_running + 1 assert rule2.queued == rule2_queued - assert rule2.running == rule2_running + 1 def test_decrement_qos_rule_running(session_obj: sa.orm.sessionmaker) -> None: - rule1 = MockRule("name1", "conclusion1", "info1", "condition1") - rule2 = MockRule("name2", "conclusion2", "info2", "condition2") + rule1 = MockRule( + "name1", "conclusion1", "info1", "condition1", queued=list(range(5)), running=2 + ) + rule2 = MockRule( + "name2", "conclusion2", "info2", "condition2", queued=list(range(3)), running=4 + ) rule1_queued = 5 rule1_running = 2 rule2_queued = 3 @@ -602,11 +612,11 @@ def test_decrement_qos_rule_running(session_obj: sa.orm.sessionmaker) -> None: with session_obj() as session: assert ( db.get_qos_rule(str(rule1.__hash__()), session=session).running - == rule1_running - 1 + == rule1_running ) assert ( db.get_qos_rule(str(rule2.__hash__()), session=session).running - == rule2_running - 1 + == rule2_running ) diff --git a/tests/test_03_qos.py b/tests/test_03_qos.py index fd4c2f74..5ff301cd 100644 --- a/tests/test_03_qos.py +++ b/tests/test_03_qos.py @@ -1,4 +1,5 @@ import io +import logging from cads_broker import Environment from cads_broker.expressions import FunctionFactory @@ -13,6 +14,7 @@ "adaptor", lambda context, *args: context.request.adaptor, ) +logger = logging.getLogger("test") class TestRequest: @@ -30,7 +32,7 @@ class TestRequest: def compile(text): - parser = RulesParser(io.StringIO(text)) + parser = RulesParser(io.StringIO(text), logger=logger) rules = RuleSet() parser.parse_rules(rules, environment) return rules diff --git a/tests/test_20_dispatcher.py b/tests/test_20_dispatcher.py index 9aeb38fa..72175fae 100644 --- a/tests/test_20_dispatcher.py +++ b/tests/test_20_dispatcher.py @@ -1,4 +1,5 @@ import datetime +import logging import uuid from typing import Any @@ -13,6 +14,8 @@ # create client object and connect to local cluster CLIENT = distributed.Client(distributed.LocalCluster()) +logger = logging.getLogger("test") + def mock_config(hash: str = "", config: dict[str, Any] = {}, form: dict[str, Any] = {}): adaptor_metadata = db.AdaptorProperties( @@ -45,7 +48,9 @@ def test_broker_sync_database( mocker: pytest_mock.plugin.MockerFixture, session_obj: sa.orm.sessionmaker ) -> None: environment = Environment.Environment() - qos = QoS.QoS(rules=Rule.RuleSet(), environment=environment, rules_hash="") + qos = QoS.QoS( + rules=Rule.RuleSet(), environment=environment, rules_hash="", logger=logger + ) broker = dispatcher.Broker( client=CLIENT, environment=environment, @@ -89,10 +94,10 @@ def test_broker_sync_database( session.commit() def mock_get_tasks() -> dict[str, str]: - return {in_dask_request_uid: "..."} + return {in_dask_request_uid: {"state": "...", "exception": None}} mocker.patch( - "cads_broker.dispatcher.get_tasks", + "cads_broker.dispatcher.get_tasks_from_scheduler", return_value=mock_get_tasks(), ) broker.futures = {in_futures_request_uid: "..."} @@ -114,8 +119,8 @@ def mock_get_tasks() -> dict[str, str]: db.SystemRequest.request_uid == lost_request_uid ) output_request = session.scalars(statement).first() - assert output_request.status == "failed" - assert output_request.request_metadata.get("resubmit_number") is None + assert output_request.status == "accepted" + assert output_request.request_metadata.get("resubmit_number") == 1 # with pytest.raises(db.NoResultFound): # with session_obj() as session: