@@ -106,7 +106,8 @@ class Broker:
106
106
environment : Environment .Environment
107
107
qos : QoS .QoS
108
108
address : str
109
- session_maker : sa .orm .sessionmaker
109
+ session_maker_read : sa .orm .sessionmaker
110
+ session_maker_write : sa .orm .sessionmaker
110
111
wait_time : float = float (os .getenv ("BROKER_WAIT_TIME" , 2 ))
111
112
cache = cachetools .TTLCache (
112
113
maxsize = 1024 , ttl = int (os .getenv ("SYNC_DATABASE_CACHE_TIME" , 10 ))
@@ -122,16 +123,19 @@ class Broker:
122
123
def from_address (
123
124
cls ,
124
125
address = "scheduler:8786" ,
125
- session_maker : sa .orm .sessionmaker | None = None ,
126
+ session_maker_read : sa .orm .sessionmaker | None = None ,
127
+ session_maker_write : sa .orm .sessionmaker | None = None ,
126
128
):
127
129
client = distributed .Client (address )
128
130
qos_config = QoSRules ()
129
131
factory .register_functions ()
130
- session_maker = db .ensure_session_obj (session_maker )
132
+ session_maker_read = db .ensure_session_obj (session_maker_read , mode = "r" )
133
+ session_maker_write = db .ensure_session_obj (session_maker_write , mode = "w" )
131
134
rules_hash = get_rules_hash (qos_config .rules_path )
132
135
self = cls (
133
136
client = client ,
134
- session_maker = session_maker ,
137
+ session_maker_read = session_maker_read ,
138
+ session_maker_write = session_maker_write ,
135
139
environment = qos_config .environment ,
136
140
qos = QoS .QoS (
137
141
qos_config .rules ,
@@ -190,7 +194,7 @@ def on_future_done(self, future: distributed.Future) -> None:
190
194
user_visible_log = list (
191
195
self .client .get_events (f"{ future .key } /user_visible_log" )
192
196
)
193
- with self .session_maker () as session :
197
+ with self .session_maker_write () as session :
194
198
if future .status == "finished" :
195
199
result = future .result ()
196
200
request = db .set_request_status (
@@ -243,20 +247,23 @@ def on_future_done(self, future: distributed.Future) -> None:
243
247
** logger_kwargs ,
244
248
)
245
249
246
- def submit_requests (self , session : sa .orm .Session , number_of_requests : int ) -> None :
247
- candidates = db .get_accepted_requests (session = session )
250
+ def submit_requests (
251
+ self , session_read : sa .orm .Session , number_of_requests : int
252
+ ) -> None :
253
+ candidates = db .get_accepted_requests (session = session_read )
248
254
queue = sorted (
249
255
candidates ,
250
- key = lambda candidate : self .qos .priority (candidate , session ),
256
+ key = lambda candidate : self .qos .priority (candidate , session_read ),
251
257
reverse = True ,
252
258
)
253
259
requests_counter = 0
254
260
for request in queue :
255
- if self .qos .can_run (request , session = session ):
256
- self .submit_request (request , session = session )
257
- requests_counter += 1
258
- if requests_counter == int (number_of_requests * WORKERS_MULTIPLIER ):
259
- break
261
+ with self .session_maker_write () as session_write :
262
+ if self .qos .can_run (request , session = session_write ):
263
+ self .submit_request (request , session = session_write )
264
+ requests_counter += 1
265
+ if requests_counter == int (number_of_requests * WORKERS_MULTIPLIER ):
266
+ break
260
267
261
268
def submit_request (
262
269
self , request : db .SystemRequest , session : sa .orm .Session
@@ -270,7 +277,12 @@ def submit_request(
270
277
key = request .request_uid ,
271
278
setup_code = request .request_body .get ("setup_code" , "" ),
272
279
entry_point = request .entry_point ,
273
- config = request .adaptor_properties .config ,
280
+ config = dict (
281
+ request_uid = request .request_uid ,
282
+ user_uid = request .user_uid ,
283
+ hostname = os .getenv ("CDS_PROJECT_URL" ),
284
+ ** request .adaptor_properties .config ,
285
+ ),
274
286
form = request .adaptor_properties .form ,
275
287
request = request .request_body .get ("request" , {}),
276
288
resources = request .request_metadata .get ("resources" , {}),
@@ -285,13 +297,14 @@ def submit_request(
285
297
286
298
def run (self ) -> None :
287
299
while True :
288
- with self .session_maker () as session :
300
+ with self .session_maker_read () as session_read :
289
301
if (rules_hash := get_rules_hash (self .qos .path )) != self .qos .rules_hash :
290
302
logger .info ("reloading qos rules" )
291
- self .qos .reload_rules (session = session )
303
+ self .qos .reload_rules (session = session_read )
292
304
self .qos .rules_hash = rules_hash
293
- self .qos .environment .set_session (session )
294
- self .sync_database (session = session )
305
+ self .qos .environment .set_session (session_read )
306
+ with self .session_maker_write () as session_write :
307
+ self .sync_database (session = session_write )
295
308
self .running_requests = len (
296
309
[
297
310
future
@@ -301,7 +314,7 @@ def run(self) -> None:
301
314
]
302
315
)
303
316
number_accepted_requests = db .count_requests (
304
- session = session , status = "accepted"
317
+ session = session_read , status = "accepted"
305
318
)
306
319
available_workers = self .number_of_workers - self .running_requests
307
320
if number_accepted_requests > 0 :
@@ -315,6 +328,7 @@ def run(self) -> None:
315
328
if available_workers > 0 :
316
329
logger .info ("broker info" , queued_jobs = number_accepted_requests )
317
330
self .submit_requests (
318
- session = session , number_of_requests = available_workers
331
+ session_read = session_read ,
332
+ number_of_requests = available_workers ,
319
333
)
320
334
time .sleep (self .wait_time )
0 commit comments