diff --git a/cads_broker/config.py b/cads_broker/config.py
index 597fac6..ccb79c5 100644
--- a/cads_broker/config.py
+++ b/cads_broker/config.py
@@ -28,7 +28,7 @@ class BrokerConfig(pydantic_settings.BaseSettings):
     broker_priority_algorithm: str = "legacy"
     broker_priority_interval_hours: int = 24
-    broker_get_number_of_workers_cache_time: int = 10
+    broker_get_number_of_workers_cache_time: int = 60
     broker_qos_rules_cache_time: int = 10
     broker_get_tasks_from_scheduler_cache_time: int = 1
     broker_rules_path: str = "/src/rules.qos"
@@ -42,7 +42,12 @@ class BrokerConfig(pydantic_settings.BaseSettings):
     broker_max_dismissed_requests: int = 100
     broker_cancel_stuck_requests_cache_ttl: int = 60
     broker_stuck_requests_limit_minutes: int = 15
-    broker_memory_error_user_visible_log: str = "Worker has been killed due to memory usage."
+    broker_memory_error_user_visible_log: str = (
+        "Worker has been killed due to memory usage."
+    )
+    broker_workers_gap: int = (
+        10  # max discrepancy of workers number before qos rules are reloaded
+    )
 
 
 class SqlalchemySettings(pydantic_settings.BaseSettings):
diff --git a/cads_broker/database.py b/cads_broker/database.py
index c1960af..b6e40de 100644
--- a/cads_broker/database.py
+++ b/cads_broker/database.py
@@ -652,21 +652,18 @@ def get_qos_status_from_request(
 def requeue_request(
     request: SystemRequest,
     session: sa.orm.Session,
-) -> SystemRequest | None:
-    if request.status == "running":
-        # ugly implementation because sqlalchemy doesn't allow to directly update JSONB
-        # FIXME: use a specific column for resubmit_number
-        metadata = dict(request.request_metadata)
-        metadata.update(
-            {"resubmit_number": request.request_metadata.get("resubmit_number", 0) + 1}
-        )
-        request.request_metadata = metadata
-        request.status = "accepted"
-        session.commit()
-        logger.info("requeueing request", **logger_kwargs(request=request))
-        return request
-    else:
-        return None
+) -> SystemRequest:
+    # ugly implementation because sqlalchemy doesn't allow to directly update JSONB
+    # FIXME: use a specific column for resubmit_number
+    metadata = dict(request.request_metadata)
+    metadata.update(
+        {"resubmit_number": request.request_metadata.get("resubmit_number", 0) + 1}
+    )
+    request.request_metadata = metadata
+    request.status = "accepted"
+    session.commit()
+    logger.info("requeueing request", **logger_kwargs(request=request))
+    return request
 
 
 def set_request_cache_id(request_uid: str, cache_id: int, session: sa.orm.Session):
@@ -680,11 +677,9 @@ def set_request_cache_id(request_uid: str, cache_id: int, session: sa.orm.Sessio
 def set_successful_request(
     request_uid: str,
     session: sa.orm.Session,
-) -> SystemRequest | None:
+) -> SystemRequest:
     statement = sa.select(SystemRequest).where(SystemRequest.request_uid == request_uid)
     request = session.scalars(statement).one()
-    if request.status == "successful":
-        return None
     request.status = "successful"
     request.finished_at = sa.func.now()
     session.commit()
diff --git a/cads_broker/dispatcher.py b/cads_broker/dispatcher.py
index d07cc1e..b7d5759 100644
--- a/cads_broker/dispatcher.py
+++ b/cads_broker/dispatcher.py
@@ -283,6 +283,87 @@ def __init__(self, number_of_workers) -> None:
         parser.parse_rules(self.rules, self.environment, raise_exception=False)
 
 
+def set_running_request(
+    request: db.SystemRequest,
+    priority: int | None,
+    qos: QoS.QoS,
+    queue: Queue,
+    internal_scheduler: Scheduler,
+    session: sa.orm.Session,
+) -> db.SystemRequest:
+    """Set the status of the request to running and notify the qos rules."""
+    request = db.set_request_status(
+        request_uid=request.request_uid,
+        status="running",
+        priority=priority,
+        session=session,
+    )
+    qos.notify_start_of_request(request, scheduler=internal_scheduler)
+    queue.pop(request.request_uid)
+    return request
+
+
+def set_successful_request(
+    request: db.SystemRequest,
+    qos: QoS.QoS,
+    internal_scheduler: Scheduler,
+    session: sa.orm.Session,
+) -> db.SystemRequest:
+    """Set the status of the request to successful and notify the qos rules."""
+    if request.status == "successful":
+        return request
+    request = db.set_successful_request(
+        request_uid=request.request_uid,
+        session=session,
+    )
+    qos.notify_end_of_request(request, scheduler=internal_scheduler)
+    logger.info(
+        "job has finished",
+        **db.logger_kwargs(request=request),
+    )
+    return request
+
+
+def set_failed_request(
+    request: db.SystemRequest,
+    error_message: str,
+    error_reason: str,
+    qos: QoS.QoS,
+    internal_scheduler: Scheduler,
+    session: sa.orm.Session,
+) -> db.SystemRequest:
+    """Set the status of the request to failed and notify the qos rules."""
+    request = db.set_request_status(
+        request_uid=request.request_uid,
+        status="failed",
+        error_message=error_message,
+        error_reason=error_reason,
+        session=session,
+    )
+    qos.notify_end_of_request(request, scheduler=internal_scheduler)
+    logger.info(
+        "job has finished",
+        **db.logger_kwargs(request=request),
+    )
+    return request
+
+
+def requeue_request(
+    request: db.SystemRequest,
+    qos: QoS.QoS,
+    queue: Queue,
+    internal_scheduler: Scheduler,
+    session: sa.orm.Session,
+) -> db.SystemRequest:
+    """Re-queue the request and notify the qos rules."""
+    if request.status == "running":
+        queued_request = db.requeue_request(request=request, session=session)
+        qos.notify_end_of_request(queued_request, scheduler=internal_scheduler)
+        queue.add(queued_request.request_uid, request)
+        return queued_request
+    return request
+
+
 @attrs.define
 class Broker:
     client: distributed.Client
@@ -328,8 +409,7 @@ def from_address(
         )
         return self
 
-    @property
-    def number_of_workers(self):
+    def set_number_of_workers(self):
         if self.client.scheduler is None:
             logger.info("Reconnecting to dask scheduler...")
             self.client = distributed.Client(self.address)
@@ -337,9 +417,20 @@ def number_of_workers(self):
         self.environment.number_of_workers = number_of_workers
         return number_of_workers
 
+    def update_number_of_workers(self, session_write):
+        """Reload qos rules if number of workers has changed by a number greater than BROKER_WORKERS_GAP."""
+        if (
+            abs(self.environment.number_of_workers - get_number_of_workers(self.client))
+            > CONFIG.broker_workers_gap
+        ):
+            self.set_number_of_workers()
+            reload_qos_rules(session_write, self.qos)
+            self.internal_scheduler.refresh()
+            self.queue.reset()
+
     def set_request_error_status(
         self, exception, request_uid, session
-    ) -> db.SystemRequest | None:
+    ) -> db.SystemRequest:
         """Set the status of the request to failed and write the error message and reason.
 
         If the error reason is "KilledWorker":
@@ -351,7 +442,7 @@ def set_request_error_status(
         error_reason = exception.__class__.__name__
         request = db.get_request(request_uid, session=session)
         if request.status != "running":
-            return None
+            return request
         requeue = CONFIG.broker_requeue_on_killed_worker_requests
         if error_reason == "KilledWorker":
             worker_restart_events = self.client.get_events("worker-restart-memory")
@@ -381,11 +472,12 @@ def set_request_error_status(
                     message=CONFIG.broker_memory_error_user_visible_log,
                     session=session,
                 )
-                request = db.set_request_status(
-                    request_uid,
-                    "failed",
+                request = set_failed_request(
+                    request=request,
                     error_message=error_message,
                     error_reason=error_reason,
+                    qos=self.qos,
+                    internal_scheduler=self.internal_scheduler,
                     session=session,
                 )
                 requeue = False
@@ -395,18 +487,20 @@ def set_request_error_status(
                 < CONFIG.broker_requeue_limit
             ):
                 logger.info("worker killed: re-queueing", job_id=request_uid)
-                queued_request = db.requeue_request(request=request, session=session)
-                if queued_request:
-                    self.queue.add(request_uid, request)
-                    self.qos.notify_end_of_request(
-                        request, scheduler=self.internal_scheduler
-                    )
+                request = requeue_request(
+                    request=request,
+                    qos=self.qos,
+                    queue=self.queue,
+                    internal_scheduler=self.internal_scheduler,
+                    session=session,
+                )
             else:
-                request = db.set_request_status(
-                    request_uid,
-                    "failed",
+                request = set_failed_request(
+                    request=request,
                     error_message=error_message,
                     error_reason=error_reason,
+                    qos=self.qos,
+                    internal_scheduler=self.internal_scheduler,
                     session=session,
                 )
         return request
@@ -481,73 +575,53 @@ def sync_database(self, session: sa.orm.Session) -> None:
                 self.qos.notify_start_of_request(
                     request, scheduler=self.internal_scheduler
                 )
-                continue
             elif task := scheduler_tasks.get(request.request_uid, None):
-                if (state := task["state"]) in ("memory", "erred"):
-                    if state == "memory":
-                        # if the task is in memory and it is not in the futures
-                        # it means that the task has been lost by the broker (broker has been restarted)
-                        # the task is successful. If the "set_successful_request" function returns None
-                        # it means that the request has already been set to successful
-                        finished_request = db.set_successful_request(
-                            request_uid=request.request_uid,
-                            session=session,
-                        )
-                    elif state == "erred":
-                        exception = pickle.loads(task["exception"])
-                        finished_request = self.set_request_error_status(
-                            exception=exception,
-                            request_uid=request.request_uid,
-                            session=session,
-                        )
-                    # notify the qos only if the request has been set to successful or failed here.
-                    if finished_request:
-                        self.qos.notify_end_of_request(
-                            finished_request, scheduler=self.internal_scheduler
-                        )
-                        logger.info(
-                            "job has finished",
-                            dask_status=task["state"],
-                            **db.logger_kwargs(request=finished_request),
-                        )
+                if (state := task["state"]) == "memory":
+                    # if the task is in memory and it is not in the futures
+                    # it means that the task has been lost by the broker (broker has been restarted)
+                    # the task is successful. If the "set_successful_request" function returns None
+                    # it means that the request has already been set to successful
+                    set_successful_request(
+                        request=request,
+                        qos=self.qos,
+                        internal_scheduler=self.internal_scheduler,
+                        session=session,
+                    )
+                elif state == "erred":
+                    exception = pickle.loads(task["exception"])
+                    self.set_request_error_status(
+                        exception=exception,
+                        request_uid=request.request_uid,
+                        session=session,
+                    )
                 # if the task is in processing, it means that the task is still running
                 elif state == "processing":
                     # notify start of request if it is not already notified
                     self.qos.notify_start_of_request(
                         request, scheduler=self.internal_scheduler
                     )
-                    continue
                 elif state == "released":
                     # notify start of request if it is not already notified
-                    queued_request = db.requeue_request(
-                        request=request, session=session
+                    requeue_request(
+                        request=request,
+                        qos=self.qos,
+                        queue=self.queue,
+                        internal_scheduler=self.internal_scheduler,
+                        session=session,
                     )
-                    if queued_request:
-                        self.queue.add(queued_request.request_uid, request)
-                        self.qos.notify_end_of_request(
-                            request, scheduler=self.internal_scheduler
-                        )
-                    continue
             # if it doesn't find the request: re-queue it
             else:
                 request = db.get_request(request.request_uid, session=session)
                 # if the broker finds the cache_id it means that the job has finished
                 if request.cache_id:
-                    successful_request = db.set_successful_request(
-                        request_uid=request.request_uid,
+                    set_successful_request(
+                        request=request,
+                        qos=self.qos,
+                        internal_scheduler=self.internal_scheduler,
                         session=session,
                     )
-                    if successful_request:
-                        self.qos.notify_end_of_request(
-                            request, scheduler=self.internal_scheduler
-                        )
-                        logger.info(
-                            "job has finished",
-                            **db.logger_kwargs(request=successful_request),
-                        )
-                    continue
-                # FIXME: check if request status has changed
-                if (
+                # check how many times the request has been re-queued
+                elif (
                     CONFIG.broker_requeue_on_lost_requests
                     and request.request_metadata.get("resubmit_number", 0)
                     < CONFIG.broker_requeue_limit
@@ -555,26 +629,22 @@ def sync_database(self, session: sa.orm.Session) -> None:
                     logger.info(
                         "request not found: re-queueing", job_id={request.request_uid}
                     )
-                    queued_request = db.requeue_request(
-                        request=request, session=session
+                    requeue_request(
+                        request=request,
+                        qos=self.qos,
+                        queue=self.queue,
+                        internal_scheduler=self.internal_scheduler,
+                        session=session,
                     )
-                    if queued_request:
-                        self.queue.add(queued_request.request_uid, request)
-                        self.qos.notify_end_of_request(
-                            request, scheduler=self.internal_scheduler
-                        )
                 else:
-                    db.set_request_status(
-                        request_uid=request.request_uid,
-                        status="failed",
+                    set_failed_request(
+                        request=request,
                         error_message="Request not found in dask scheduler",
                         error_reason="not_found",
+                        qos=self.qos,
+                        internal_scheduler=self.internal_scheduler,
                         session=session,
                     )
-                    self.qos.notify_end_of_request(
-                        request, scheduler=self.internal_scheduler
-                    )
-                    logger.info("job has finished", **db.logger_kwargs(request=request))
 
     @perf_logger
     def sync_qos_rules(self, session_write) -> None:
@@ -624,7 +694,11 @@ def sync_futures(self) -> None:
         for key in finished_futures:
             self.futures.pop(key, None)
 
-    def on_future_done(self, future: distributed.Future) -> str:
+    def on_future_done(self, future: distributed.Future) -> str | None:
+        """Update the database status of the request according to the status of the future.
+
+        If the status of the request in the database is not "running", it does nothing and returns None.
+        """
         with self.session_maker_write() as session:
             try:
                 request = db.get_request(future.key, session=session)
@@ -636,45 +710,32 @@ def on_future_done(self, future: distributed.Future) -> str:
                 )
                 return future.key
             if request.status != "running":
-                return
+                return None
             if future.status == "finished":
                 # the result is updated in the database by the worker
-                request = db.set_successful_request(
-                    request_uid=future.key,
+                set_successful_request(
+                    request=request,
+                    qos=self.qos,
+                    internal_scheduler=self.internal_scheduler,
                     session=session,
                 )
             elif future.status == "error":
                 exception = future.exception()
-                request = self.set_request_error_status(
+                self.set_request_error_status(
                     exception=exception, request_uid=future.key, session=session
                 )
             elif future.status != "cancelled":
                 # if the dask status is unknown, re-queue it
-                request = db.set_request_status(
-                    future.key,
-                    "accepted",
+                requeue_request(
+                    request=request,
+                    qos=self.qos,
+                    queue=self.queue,
+                    internal_scheduler=self.internal_scheduler,
                     session=session,
-                    resubmit=True,
-                )
-                self.queue.add(future.key, request)
-                logger.warning(
-                    "unknown dask status, re-queing",
-                    job_status={future.status},
-                    job_id=request.request_uid,
                 )
             else:
                 # if the dask status is cancelled, the qos has already been reset by sync_database
-                return
-            # self.futures.pop(future.key, None)
-            if request:
-                self.qos.notify_end_of_request(
-                    request, scheduler=self.internal_scheduler
-                )
-                logger.info(
-                    "job has finished",
-                    dask_status=future.status,
-                    **db.logger_kwargs(request=request),
-                )
+                return None
         future.release()
         return future.key
 
@@ -682,6 +743,7 @@ def on_future_done(self, future: distributed.Future) -> str:
     def cache_requests_qos_properties(self, requests, session: sa.orm.Session) -> None:
         """Cache the qos properties of the requests."""
         # copy list of requests to avoid RuntimeError: dictionary changed size during iteration
+        logger.info(f"caching qos properties for {len(requests)} requests")
        for request in list(requests):
             try:
                 self.qos._properties(request, check_permissions=True)
@@ -737,14 +799,14 @@ def submit_request(
         priority: int | None = None,
     ) -> None:
         """Submit the request to the dask scheduler and update the qos rules accordingly."""
-        request = db.set_request_status(
-            request_uid=request.request_uid,
-            status="running",
+        request = set_running_request(
+            request=request,
             priority=priority,
+            queue=self.queue,
+            qos=self.qos,
+            internal_scheduler=self.internal_scheduler,
             session=session,
         )
-        self.qos.notify_start_of_request(request, scheduler=self.internal_scheduler)
-        self.queue.pop(request.request_uid)
         future = self.client.submit(
             worker.submit_workflow,
             key=request.request_uid,
@@ -775,14 +837,19 @@ def run(self) -> None:
             with self.session_maker_read() as session_read:
                 if get_rules_hash(self.qos.path) != self.qos.rules_hash:
                     logger.info("reloading qos rules")
-                    self.qos = instantiate_qos(session_read, self.number_of_workers)
+                    self.qos = instantiate_qos(
+                        session_read, self.environment.number_of_workers
+                    )
                     with self.session_maker_write() as session_write:
                         reload_qos_rules(session_write, self.qos)
                         self.internal_scheduler.refresh()
+                        self.queue.reset()
                 self.qos.environment.set_session(session_read)
             # expire_on_commit=False is used to detach the accepted requests without an error
             # this is not a problem because accepted requests cannot be modified in this loop
             with self.session_maker_write(expire_on_commit=False) as session_write:
+                # reload qos rules if the number of workers has changed
+                self.update_number_of_workers(session_write)
                 self.queue.add_accepted_requests(
                     db.get_accepted_requests(
                         session=session_write,
@@ -814,13 +881,15 @@ def run(self) -> None:
                 cancel_stuck_requests(client=self.client, session=session_read)
                 running_requests = len(db.get_running_requests(session=session_read))
                 queue_length = self.queue.len()
-                available_workers = self.number_of_workers - running_requests
+                available_workers = (
+                    self.environment.number_of_workers - running_requests
+                )
                 if queue_length > 0:
                     logger.info(
                         "broker info",
                         available_workers=available_workers,
                         running_requests=running_requests,
-                        number_of_workers=self.number_of_workers,
+                        number_of_workers=self.environment.number_of_workers,
                         futures=len(self.futures),
                     )
                 if available_workers > 0: