@@ -179,6 +179,11 @@ def remove(self, item: Any) -> None:
179
179
self .queue .remove (item )
180
180
self .index [item ["function" ].__name__ ].remove (item ["kwargs" ]["request_uid" ])
181
181
182
+ def refresh (self ) -> None :
183
+ with self ._lock :
184
+ self .queue = list ()
185
+ self .index = dict ()
186
+
182
187
183
188
def perf_logger (func ):
184
189
def wrapper (* args , ** kwargs ):
@@ -192,6 +197,25 @@ def wrapper(*args, **kwargs):
192
197
return wrapper
193
198
194
199
200
+ def instantiate_qos (session_read : sa .orm .Session , number_of_workers : int ) -> QoS .QoS :
201
+ qos_config = QoSRules (number_of_workers = number_of_workers )
202
+ factory .register_functions ()
203
+ rules_hash = get_rules_hash (qos_config .rules_path )
204
+ qos = QoS .QoS (
205
+ qos_config .rules ,
206
+ qos_config .environment ,
207
+ rules_hash = rules_hash ,
208
+ logger = logger ,
209
+ )
210
+ qos .environment .set_session (session_read )
211
+ return qos
212
+
213
+
214
+ def reload_qos_rules (session : sa .orm .sessionmaker , qos : QoS .QoS ) -> None :
215
+ perf_logger (qos .reload_rules )(session = session )
216
+ perf_logger (db .reset_qos_rules )(session , qos )
217
+
218
+
195
219
class Queue :
196
220
"""A simple queue to store the requests that have been accepted by the broker.
197
221
@@ -288,27 +312,17 @@ def from_address(
288
312
session_maker_write : sa .orm .sessionmaker | None = None ,
289
313
):
290
314
client = distributed .Client (address )
291
- qos_config = QoSRules (get_number_of_workers (client ))
292
- factory .register_functions ()
293
315
session_maker_read = db .ensure_session_obj (session_maker_read , mode = "r" )
294
316
session_maker_write = db .ensure_session_obj (session_maker_write , mode = "w" )
295
- rules_hash = get_rules_hash (qos_config .rules_path )
296
- qos = QoS .QoS (
297
- qos_config .rules ,
298
- qos_config .environment ,
299
- rules_hash = rules_hash ,
300
- logger = logger ,
301
- )
302
317
with session_maker_read () as session_read :
303
- qos . environment . set_session (session_read )
318
+ qos = instantiate_qos (session_read , get_number_of_workers ( client ) )
304
319
with session_maker_write () as session :
305
- perf_logger (qos .reload_rules )(session = session )
306
- perf_logger (db .reset_qos_rules )(session , qos )
320
+ reload_qos_rules (session , qos )
307
321
self = cls (
308
322
client = client ,
309
323
session_maker_read = session_maker_read ,
310
324
session_maker_write = session_maker_write ,
311
- environment = qos_config .environment ,
325
+ environment = qos .environment ,
312
326
qos = qos ,
313
327
address = address ,
314
328
)
@@ -740,12 +754,12 @@ def run(self) -> None:
740
754
# reset the cache of the qos functions
741
755
db .QOS_FUNCTIONS_CACHE .clear ()
742
756
with self .session_maker_read () as session_read :
743
- if ( rules_hash := get_rules_hash (self .qos .path ) ) != self .qos .rules_hash :
757
+ if get_rules_hash (self .qos .path ) != self .qos .rules_hash :
744
758
logger .info ("reloading qos rules" )
759
+ self .qos = instantiate_qos (session_read , self .number_of_workers )
745
760
with self .session_maker_write () as session_write :
746
- self .qos .reload_rules (session = session_write )
747
- db .reset_qos_rules (session_write , self .qos )
748
- self .qos .rules_hash = rules_hash
761
+ reload_qos_rules (session_write , self .qos )
762
+ self .internal_scheduler .refresh ()
749
763
self .qos .environment .set_session (session_read )
750
764
# expire_on_commit=False is used to detach the accepted requests without an error
751
765
# this is not a problem because accepted requests cannot be modified in this loop
0 commit comments