Commit 96db0e3

use pydantic in broker configuration
1 parent 1c5f894 commit 96db0e3

File tree

2 files changed: +45 -30 lines changed


cads_broker/config.py

Lines changed: 17 additions & 0 deletions
@@ -25,6 +25,23 @@
 dbsettings = None
 
 
+class BrokerConfig(pydantic_settings.BaseSettings):
+
+    high_priority_user_uid: str = "8d8ee054-6a09-4da8-a5be-d5dff52bbc5f"
+    broker_priority_algorithm: str = "legacy"
+    broker_priority_interval_hours: int = 24
+    get_number_of_workers_cache_time: int = 10
+    qos_rules_cache_time: int = 10
+    get_tasks_from_scheduler_cache_time: int = 1
+    rules_path: str = "/src/rules.qos"
+    wait_time: float = 2.
+    sync_database_cache_time: int = 10
+    broker_requeue_on_killed_worker_requests: bool = False
+    broker_requeue_on_lost_requests: bool = True
+    broker_requeue_limit: int = 3
+    broker_max_internal_scheduler_tasks: int = 500
+
+
 class SqlalchemySettings(pydantic_settings.BaseSettings):
     """Postgres-specific API settings.
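
The new BrokerConfig gathers the options that dispatcher.py previously read one by one with os.getenv. With pydantic-settings each field can still be overridden from the environment: by default (no env_prefix, case-insensitive matching) a field such as broker_requeue_limit is filled from BROKER_REQUEUE_LIMIT and coerced to the annotated type. A minimal sketch of that behaviour, assuming the default pydantic-settings configuration; the override values below are illustrative, not taken from the commit:

# A minimal sketch, assuming the default pydantic-settings behaviour
# (no env_prefix, case-insensitive matching of field names to env vars).
import os

import pydantic_settings


class BrokerConfig(pydantic_settings.BaseSettings):
    broker_priority_algorithm: str = "legacy"
    broker_requeue_limit: int = 3
    wait_time: float = 2.0


os.environ["BROKER_PRIORITY_ALGORITHM"] = "processing_time"
os.environ["BROKER_REQUEUE_LIMIT"] = "5"

settings = BrokerConfig()
print(settings.broker_priority_algorithm)  # "processing_time"
print(settings.broker_requeue_limit)       # 5 (coerced from the env string to int)
print(settings.wait_time)                  # 2.0 (default, no override set)

Compared with raw os.getenv calls, a malformed value now fails at construction time with a pydantic ValidationError instead of surviving as an unchecked string.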

cads_broker/dispatcher.py

Lines changed: 28 additions & 30 deletions
@@ -35,17 +35,15 @@
     "finished": "successful",
 }
 
-WORKERS_MULTIPLIER = float(os.getenv("WORKERS_MULTIPLIER", 1))
 ONE_SECOND = datetime.timedelta(seconds=1)
-HIGH_PRIORITY_USER_UID = os.getenv(
-    "HIGH_PRIORITY_USER_UID", "8d8ee054-6a09-4da8-a5be-d5dff52bbc5f"
-)
-BROKER_PRIORITY_ALGORITHM = os.getenv("BROKER_PRIORITY_ALGORITHM", "legacy")
+ONE_MINUTE = ONE_SECOND * 60
+ONE_HOUR = ONE_MINUTE * 60
+CONFIG = config.BrokerConfig()
 
 
 @cachetools.cached(  # type: ignore
     cache=cachetools.TTLCache(
-        maxsize=1024, ttl=float(os.getenv("GET_NUMBER_OF_WORKERS_CACHE_TIME", 10))
+        maxsize=1024, ttl=CONFIG.get_number_of_workers_cache_time
     ),
     info=True,
 )
@@ -58,14 +56,12 @@ def get_number_of_workers(client: distributed.Client) -> int:
 
 
 @cachetools.cached(  # type: ignore
-    cache=cachetools.TTLCache(
-        maxsize=1024, ttl=int(os.getenv("QOS_RULES_CACHE_TIME", 10))
-    ),
+    cache=cachetools.TTLCache(maxsize=1024, ttl=CONFIG.qos_rules_cache_time),
     info=True,
 )
 def get_rules_hash(rules_path: str):
     if rules_path is None or not os.path.exists(rules_path):
-        rules = os.getenv("DEFAULT_RULES", "")
+        rules = ""
     else:
         with open(rules_path) as f:
             rules = f.read()
@@ -74,7 +70,7 @@ def get_rules_hash(rules_path: str):
 
 @cachetools.cached(  # type: ignore
     cache=cachetools.TTLCache(
-        maxsize=1024, ttl=int(os.getenv("GET_TASKS_FROM_SCHEDULER_CACHE_TIME", 1))
+        maxsize=1024, ttl=CONFIG.get_tasks_from_scheduler_cache_time
     ),
     info=True,
 )
@@ -192,14 +188,12 @@ def reset(self) -> None:
 class QoSRules:
     def __init__(self, number_of_workers) -> None:
         self.environment = Environment.Environment(number_of_workers=number_of_workers)
-        self.rules_path = os.getenv("RULES_PATH", "/src/rules.qos")
+        self.rules_path = CONFIG.rules_path
         if os.path.exists(self.rules_path):
             self.rules = self.rules_path
         else:
             logger.info("rules file not found", rules_path=self.rules_path)
-            parser = QoS.RulesParser(
-                io.StringIO(os.getenv("DEFAULT_RULES", "")), logger=logger
-            )
+            parser = QoS.RulesParser(io.StringIO(""), logger=logger)
             self.rules = QoS.RuleSet()
             parser.parse_rules(self.rules, self.environment, raise_exception=False)
 
@@ -212,9 +206,9 @@ class Broker:
     address: str
     session_maker_read: sa.orm.sessionmaker
     session_maker_write: sa.orm.sessionmaker
-    wait_time: float = float(os.getenv("BROKER_WAIT_TIME", 2))
+    wait_time: float = CONFIG.wait_time
     ttl_cache = cachetools.TTLCache(
-        maxsize=1024, ttl=int(os.getenv("SYNC_DATABASE_CACHE_TIME", 10))
+        maxsize=1024, ttl=CONFIG.sync_database_cache_time
     )
 
     futures: dict[str, distributed.Future] = attrs.field(
@@ -282,7 +276,7 @@ def set_request_error_status(
         request = db.get_request(request_uid, session=session)
         if request.status != "running":
             return None
-        requeue = os.getenv("BROKER_REQUEUE_ON_KILLED_WORKER_REQUESTS", False)
+        requeue = CONFIG.broker_requeue_on_killed_worker_requests
         if error_reason == "KilledWorker":
             worker_restart_events = self.client.get_events("worker-restart-memory")
             # get info on worker and pid of the killed request
@@ -315,9 +309,11 @@ def set_request_error_status(
                         session=session,
                     )
                     requeue = False
-            if requeue and request.request_metadata.get(
-                "resubmit_number", 0
-            ) < os.getenv("BROKER_REQUEUE_LIMIT", 3):
+            if (
+                requeue
+                and request.request_metadata.get("resubmit_number", 0)
+                < CONFIG.broker_requeue_limit
+            ):
                 logger.info("worker killed: re-queueing", job_id=request_uid)
                 db.requeue_request(request=request, session=session)
                 self.queue.add(request_uid, request)
@@ -444,10 +440,10 @@ def sync_database(self, session: sa.orm.Session) -> None:
                 )
                 continue
             # FIXME: check if request status has changed
-            if os.getenv(
-                "BROKER_REQUEUE_ON_LOST_REQUESTS", True
-            ) and request.request_metadata.get("resubmit_number", 0) < os.getenv(
-                "BROKER_REQUEUE_LIMIT", 3
+            if (
+                CONFIG.broker_requeue_on_lost_requests
+                and request.request_metadata.get("resubmit_number", 0)
+                < CONFIG.broker_requeue_limit
             ):
                 logger.info(
                     "request not found: re-queueing", job_id={request.request_uid}
@@ -488,7 +484,7 @@ def sync_qos_rules(self, session_write) -> None:
         if tasks_number := len(self.internal_scheduler.queue):
             logger.info("performance", tasks_number=tasks_number)
         for task in list(self.internal_scheduler.queue)[
-            : int(os.getenv("BROKER_MAX_INTERNAL_SCHEDULER_TASKS", 500))
+            : CONFIG.broker_max_internal_scheduler_tasks
        ]:
             # the internal scheduler is used to asynchronously add qos rules to database
             # it returns a new qos rule if a new qos rule is added to database
@@ -575,16 +571,18 @@ def submit_requests(
         candidates: Iterable[db.SystemRequest],
     ) -> None:
         """Check the qos rules and submit the requests to the dask scheduler."""
-        if BROKER_PRIORITY_ALGORITHM == "processing_time":
+        if CONFIG.broker_priority_algorithm == "processing_time":
             user_requests: dict[str, list[db.SystemRequest]] = {}
             for request in candidates:
                 user_requests.setdefault(request.user_uid, []).append(request)
             # FIXME: this is a temporary solution to prioritize subrequests from the high priority user
             interval_stop = datetime.datetime.now()
             users_queue = {
-                HIGH_PRIORITY_USER_UID: 0
+                CONFIG.high_priority_user_uid: 0
             } | db.get_users_queue_from_processing_time(
-                interval_stop=interval_stop, session=session_write
+                interval_stop=interval_stop,
+                session=session_write,
+                interval=ONE_HOUR * CONFIG.broker_priority_interval_hours,
             )
             requests_counter = 0
             for user_uid in users_queue:
@@ -613,7 +611,7 @@ def submit_requests(
                 if self.qos.can_run(
                     request, session=session_write, scheduler=self.internal_scheduler
                 ):
-                    if requests_counter <= int(number_of_requests * WORKERS_MULTIPLIER):
+                    if requests_counter <= int(number_of_requests):
                         self.submit_request(request, session=session_write)
                         requests_counter += 1
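
Because CONFIG = config.BrokerConfig() is built once at module import, every former os.getenv call site now reads an already-typed value. A small, hedged illustration of the new requeue gate and the priority window; request_metadata here is a plain dict standing in for request.request_metadata, not a real db.SystemRequest:

# Illustrative only: a stand-in dict plays the role of request.request_metadata.
import datetime

from cads_broker import config

CONFIG = config.BrokerConfig()
ONE_HOUR = datetime.timedelta(hours=1)

request_metadata = {"resubmit_number": 2}

# Same shape as the checks in set_request_error_status / sync_database:
if (
    CONFIG.broker_requeue_on_lost_requests
    and request_metadata.get("resubmit_number", 0) < CONFIG.broker_requeue_limit
):
    print("re-queueing")  # 2 < 3 with the defaults, so this request is re-queued

# The priority window passed to db.get_users_queue_from_processing_time:
print(ONE_HOUR * CONFIG.broker_priority_interval_hours)  # 1 day, 0:00:00 by default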

0 commit comments
