Skip to content

Commit 8f34be8

Browse files
[WIP] use read and write connections to db (#102)
* use read and write connections to db * fix * Refactor request submission logic in Broker class * add new functions and variables to qos rules * add metadata to config object * tests * qa
1 parent d32eb11 commit 8f34be8

File tree

4 files changed

+46
-25
lines changed

4 files changed

+46
-25
lines changed

cads_broker/database.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,9 @@ def cost(self):
123123
return (0, 0)
124124

125125

126-
def ensure_session_obj(session_obj: sa.orm.sessionmaker | None) -> sa.orm.sessionmaker:
126+
def ensure_session_obj(
127+
session_obj: sa.orm.sessionmaker | None, mode="w"
128+
) -> sa.orm.sessionmaker:
127129
"""If `session_obj` is None, create a new session object.
128130
129131
Parameters
@@ -138,11 +140,15 @@ def ensure_session_obj(session_obj: sa.orm.sessionmaker | None) -> sa.orm.sessio
138140
if session_obj:
139141
return session_obj
140142
settings = config.ensure_settings(config.dbsettings)
143+
if mode == "r":
144+
connection_string = settings.connection_string_read
145+
elif mode == "w":
146+
connection_string = settings.connection_string
141147
if settings.pool_size == -1:
142-
engine = sa.create_engine(settings.connection_string, poolclass=sa.pool.NullPool)
148+
engine = sa.create_engine(connection_string, poolclass=sa.pool.NullPool)
143149
else:
144150
engine = sa.create_engine(
145-
settings.connection_string,
151+
connection_string,
146152
pool_recycle=settings.pool_recycle,
147153
pool_size=settings.pool_size,
148154
pool_timeout=settings.pool_timeout,

cads_broker/dispatcher.py

+34-20
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ class Broker:
106106
environment: Environment.Environment
107107
qos: QoS.QoS
108108
address: str
109-
session_maker: sa.orm.sessionmaker
109+
session_maker_read: sa.orm.sessionmaker
110+
session_maker_write: sa.orm.sessionmaker
110111
wait_time: float = float(os.getenv("BROKER_WAIT_TIME", 2))
111112
cache = cachetools.TTLCache(
112113
maxsize=1024, ttl=int(os.getenv("SYNC_DATABASE_CACHE_TIME", 10))
@@ -122,16 +123,19 @@ class Broker:
122123
def from_address(
123124
cls,
124125
address="scheduler:8786",
125-
session_maker: sa.orm.sessionmaker | None = None,
126+
session_maker_read: sa.orm.sessionmaker | None = None,
127+
session_maker_write: sa.orm.sessionmaker | None = None,
126128
):
127129
client = distributed.Client(address)
128130
qos_config = QoSRules()
129131
factory.register_functions()
130-
session_maker = db.ensure_session_obj(session_maker)
132+
session_maker_read = db.ensure_session_obj(session_maker_read, mode="r")
133+
session_maker_write = db.ensure_session_obj(session_maker_write, mode="w")
131134
rules_hash = get_rules_hash(qos_config.rules_path)
132135
self = cls(
133136
client=client,
134-
session_maker=session_maker,
137+
session_maker_read=session_maker_read,
138+
session_maker_write=session_maker_write,
135139
environment=qos_config.environment,
136140
qos=QoS.QoS(
137141
qos_config.rules,
@@ -190,7 +194,7 @@ def on_future_done(self, future: distributed.Future) -> None:
190194
user_visible_log = list(
191195
self.client.get_events(f"{future.key}/user_visible_log")
192196
)
193-
with self.session_maker() as session:
197+
with self.session_maker_write() as session:
194198
if future.status == "finished":
195199
result = future.result()
196200
request = db.set_request_status(
@@ -243,20 +247,23 @@ def on_future_done(self, future: distributed.Future) -> None:
243247
**logger_kwargs,
244248
)
245249

246-
def submit_requests(self, session: sa.orm.Session, number_of_requests: int) -> None:
247-
candidates = db.get_accepted_requests(session=session)
250+
def submit_requests(
251+
self, session_read: sa.orm.Session, number_of_requests: int
252+
) -> None:
253+
candidates = db.get_accepted_requests(session=session_read)
248254
queue = sorted(
249255
candidates,
250-
key=lambda candidate: self.qos.priority(candidate, session),
256+
key=lambda candidate: self.qos.priority(candidate, session_read),
251257
reverse=True,
252258
)
253259
requests_counter = 0
254260
for request in queue:
255-
if self.qos.can_run(request, session=session):
256-
self.submit_request(request, session=session)
257-
requests_counter += 1
258-
if requests_counter == int(number_of_requests * WORKERS_MULTIPLIER):
259-
break
261+
with self.session_maker_write() as session_write:
262+
if self.qos.can_run(request, session=session_write):
263+
self.submit_request(request, session=session_write)
264+
requests_counter += 1
265+
if requests_counter == int(number_of_requests * WORKERS_MULTIPLIER):
266+
break
260267

261268
def submit_request(
262269
self, request: db.SystemRequest, session: sa.orm.Session
@@ -270,7 +277,12 @@ def submit_request(
270277
key=request.request_uid,
271278
setup_code=request.request_body.get("setup_code", ""),
272279
entry_point=request.entry_point,
273-
config=request.adaptor_properties.config,
280+
config=dict(
281+
request_uid=request.request_uid,
282+
user_uid=request.user_uid,
283+
hostname=os.getenv("CDS_PROJECT_URL"),
284+
**request.adaptor_properties.config,
285+
),
274286
form=request.adaptor_properties.form,
275287
request=request.request_body.get("request", {}),
276288
resources=request.request_metadata.get("resources", {}),
@@ -285,13 +297,14 @@ def submit_request(
285297

286298
def run(self) -> None:
287299
while True:
288-
with self.session_maker() as session:
300+
with self.session_maker_read() as session_read:
289301
if (rules_hash := get_rules_hash(self.qos.path)) != self.qos.rules_hash:
290302
logger.info("reloading qos rules")
291-
self.qos.reload_rules(session=session)
303+
self.qos.reload_rules(session=session_read)
292304
self.qos.rules_hash = rules_hash
293-
self.qos.environment.set_session(session)
294-
self.sync_database(session=session)
305+
self.qos.environment.set_session(session_read)
306+
with self.session_maker_write() as session_write:
307+
self.sync_database(session=session_write)
295308
self.running_requests = len(
296309
[
297310
future
@@ -301,7 +314,7 @@ def run(self) -> None:
301314
]
302315
)
303316
number_accepted_requests = db.count_requests(
304-
session=session, status="accepted"
317+
session=session_read, status="accepted"
305318
)
306319
available_workers = self.number_of_workers - self.running_requests
307320
if number_accepted_requests > 0:
@@ -315,6 +328,7 @@ def run(self) -> None:
315328
if available_workers > 0:
316329
logger.info("broker info", queued_jobs=number_accepted_requests)
317330
self.submit_requests(
318-
session=session, number_of_requests=available_workers
331+
session_read=session_read,
332+
number_of_requests=available_workers,
319333
)
320334
time.sleep(self.wait_time)

cads_broker/factory.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def register_functions():
4343
)
4444
expressions.FunctionFactory.FunctionFactory.register_function(
4545
"user_finished_request_count",
46-
lambda context, seconds: database.count_finished_requests_per_user_in_session(
46+
lambda context, seconds: database.count_finished_requests_per_user(
4747
user_uid=context.request.user_uid,
4848
seconds=seconds,
4949
session=context.environment.session,

tests/test_20_dispatcher.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def test_broker_sync_database(
5454
environment=environment,
5555
qos=qos,
5656
address="scheduler-address",
57-
session_maker=session_obj,
57+
session_maker_read=session_obj,
58+
session_maker_write=session_obj,
5859
)
5960

6061
in_futures_request_uid = str(uuid.uuid4())

0 commit comments

Comments
 (0)