 # limitations under the License.
 
 import logging
+import os
 import time
 from typing import Any, Dict, List, Optional, Tuple
+from unittest.mock import patch
 
 import ray
 from ray.experimental.state.api import get_actor
 from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy
 
 from verl.single_controller.base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup
+from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch
 
 __all__ = ["Worker"]
@@ -300,17 +303,23 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d
                         elapsed = int(time.time() - start_time)
                         if elapsed % 30 == 0:
                             logging.warning(
-                                f"Waiting for register center actor {actor_name} to be ready. "
-                                f"Elapsed time: {elapsed} seconds out of {self._ray_wait_register_center_timeout} seconds."
+                                "Waiting for register center actor %s to be ready. "
+                                "Elapsed time: %s seconds out of %s seconds.",
+                                actor_name,
+                                elapsed,
+                                self._ray_wait_register_center_timeout,
                             )
                         time.sleep(1)
 
                     if register_center_actor is None:
                         raise TimeoutError(
-                            f"Failed to get register_center_actor {actor_name} in {list_named_actors(all_namespaces=True)} "
+                            f"Failed to get register_center_actor {actor_name} "
+                            f"in {list_named_actors(all_namespaces=True)} "
                             f"for {self._ray_wait_register_center_timeout} seconds. "
-                            "Ensure that any lingering Ray resources from previous runs are cleaned up (e.g., by restarting the Ray cluster), "
-                            "or adjust the waiting time by modifying the config `trainer.ray_wait_register_center_timeout`."
+                            "Ensure that any lingering Ray resources from previous "
+                            "runs are cleaned up (e.g., by restarting the Ray cluster), "
+                            "or adjust the waiting time by modifying the config "
+                            "`trainer.ray_wait_register_center_timeout`."
                         )
 
                     rank_zero_info = ray.get(register_center_actor.get_rank_zero_info.remote())
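Note: switching from f-strings to `%s` placeholders hands the interpolation off to the logging framework, so the message is only built if the record is actually emitted. A minimal standalone sketch of the two styles (not verl code):

```python
import logging

actor_name, elapsed, timeout = "wg_register_center", 30, 300

# Eager: the f-string is formatted even when WARNING records are filtered out.
logging.warning(f"Waiting for register center actor {actor_name}: {elapsed}/{timeout}s elapsed")

# Lazy: arguments are stored on the LogRecord and interpolated only on emit.
logging.warning("Waiting for register center actor %s: %s/%ss elapsed", actor_name, elapsed, timeout)
```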
@@ -329,10 +338,9 @@ def from_detached(
         worker_names=None,
         ray_cls_with_init=None,
     ):
-        worker_group = cls(resource_pool=None,
-                           ray_cls_with_init=ray_cls_with_init,
-                           name_prefix=name_prefix,
-                           worker_names=worker_names)
+        worker_group = cls(
+            resource_pool=None, ray_cls_with_init=ray_cls_with_init, name_prefix=name_prefix, worker_names=worker_names
+        )
         return worker_group
 
     def spawn(self, prefix_set):
@@ -382,8 +390,9 @@ def execute_all_sync(self, method_name: str, *args, **kwargs):
         return ray.get(self.execute_all_async(method_name, *args, **kwargs))
 
     def execute_all_async(self, method_name: str, *args, **kwargs):
-        # Here, we assume that if all arguments in args and kwargs are lists, and their lengths match len(self._workers),
-        # we'll distribute each element in these lists to the corresponding worker
+        # Here, we assume that if all arguments in args and kwargs are lists,
+        # and their lengths match len(self._workers), we'll distribute each
+        # element in these lists to the corresponding worker
         # print(f"execute_all_async: method {method_name}({args}, {kwargs})")
         length = len(self._workers)
         if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()):
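The comment above states the dispatch rule for `execute_all_async`; a rough standalone sketch of that rule (hypothetical helper, not the verl implementation):

```python
from typing import Any, Dict, List, Tuple

def sliced_calls(num_workers: int, args: tuple, kwargs: Dict[str, Any]) -> List[Tuple[tuple, dict]]:
    # If every positional and keyword argument is a list of length num_workers,
    # element i goes to worker i; otherwise the arguments are broadcast as-is.
    splittable = all(isinstance(a, list) and len(a) == num_workers for a in args) and all(
        isinstance(v, list) and len(v) == num_workers for v in kwargs.values()
    )
    calls = []
    for i in range(num_workers):
        if splittable:
            calls.append((tuple(a[i] for a in args), {k: v[i] for k, v in kwargs.items()}))
        else:
            calls.append((args, dict(kwargs)))
    return calls

print(sliced_calls(2, ([0, 1],), {}))  # [((0,), {}), ((1,), {})] -- one element per worker
print(sliced_calls(2, (42,), {}))      # [((42,), {}), ((42,), {})] -- broadcast
```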
@@ -421,11 +430,6 @@ def world_size(self):
 with code written in separate ray.Actors.
 """
 
-import os
-from unittest.mock import patch
-
-from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch
-
 
 def _bind_workers_method_to_parent(cls, key, user_defined_cls):
     """
@@ -443,12 +447,12 @@ def _bind_workers_method_to_parent(cls, key, user_defined_cls):
 
         if hasattr(method, MAGIC_ATTR):
 
-            def generate_function(name):
+            def generate_function(name, key=key):
                 def func(self, *args, **kwargs):
                     # dispatch to the actual worker
                     return getattr(self.worker_dict[key], name)(*args, **kwargs)
 
-                return func
+                return func  # noqa: B023
 
             func = generate_function(method_name)
             # pass MAGIC_ATTR for outer worker group
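Binding `key=key` as a default argument freezes its value when `generate_function` is defined; a plain closure over the enclosing loop variable would see whatever `key` holds at call time, which is the late-binding pitfall flake8-bugbear reports as B023. A small illustration, independent of verl:

```python
def make_getters_late(keys):
    # Late binding: every lambda closes over the same loop variable.
    return [lambda d: d[key] for key in keys]

def make_getters_bound(keys):
    # Default argument captures the current value of key at definition time.
    return [lambda d, key=key: d[key] for key in keys]

data = {"actor": 1, "critic": 2}
print([f(data) for f in make_getters_late(["actor", "critic"])])   # [2, 2]
print([f(data) for f in make_getters_bound(["actor", "critic"])])  # [1, 2]
```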
@@ -457,15 +461,16 @@ def func(self, *args, **kwargs):
             try:
                 # bind direct rollout method to class without prefix
                 if attrs["dispatch_mode"] == Dispatch.DIRECT_ROLLOUT_METHOD and "rollout" in key:
-                    assert not hasattr(cls, method_name), \
+                    assert not hasattr(cls, method_name), (
                         f"conflict direct rollout method {method_name} with role {key}"
+                    )
                     setattr(cls, method_name, func)
                     print(f"bind role {key} method {method_name} to class {cls}")
                 else:
-                    method_name_with_prefix = key + '_' + method_name
+                    method_name_with_prefix = key + "_" + method_name
                     setattr(cls, method_name_with_prefix, func)
             except Exception as e:
-                raise ValueError(f"Fail to set method_name {method_name}")
+                raise ValueError(f"Fail to set method_name {method_name}") from e
 
 
 def _unwrap_ray_remote(cls):
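`raise ... from e` attaches the original error as `__cause__`, so the traceback shows the underlying `setattr` failure instead of only the wrapping ValueError. A tiny self-contained example:

```python
try:
    try:
        setattr(object(), "anything", 1)  # plain object() has no __dict__, so this raises AttributeError
    except Exception as e:
        raise ValueError("Fail to set method_name anything") from e
except ValueError as err:
    print(type(err.__cause__).__name__)  # AttributeError -- preserved by the `from e` chain
```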
@@ -474,32 +479,31 @@ def _unwrap_ray_remote(cls):
     return cls
 
 
-def _nearest_common_base(mros: List):
-    last_common = object
-    min_len = min([len(mro) for mro in mros]) - 1  # exclude final derived class
-
-    for i in range(min_len):
-        mro = mros[0][i]
-        for j in range(1, len(mros)):
-            if mro != mros[j][i]:
-                return last_common
-        last_common = mro
-
-    return last_common
+def _determine_fsdp_megatron_base_class(mros: List):
+    """
+    - megatron: base class should be MegatronWorker
+    - fsdp: base class should be Worker
+    """
+    for cls in mros[0]:
+        if cls.__name__ == "MegatronWorker":
+            return cls
+        if cls.__name__ == "Worker":
+            return cls
+    raise ValueError(f"Cannot determine base class for {mros}")
 
 
-def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs], worker_cls: type = None):
+def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]):
     """
     This function should return a class instance that delegates the calls to every
     cls in cls_dict
     """
     cls_dict = {}
     init_args_dict = {}
-    if worker_cls is None:
-        worker_cls = _nearest_common_base(
-            [list(reversed(cls.cls.__ray_actor_class__.__mro__)) for cls in class_dict.values()])
+    worker_cls = _determine_fsdp_megatron_base_class(
+        [cls.cls.__ray_actor_class__.__mro__ for cls in class_dict.values()]
+    )
     assert issubclass(worker_cls, Worker), f"worker_cls {worker_cls} should be a subclass of Worker"
-    print(f"find nearest common base class {worker_cls}")
+    print(f"colocated worker base class {worker_cls}")
 
     for key, cls in class_dict.items():
         cls_dict[key] = cls.cls
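The new helper simply scans the MRO of the first colocated class and returns `MegatronWorker` if present, otherwise `Worker`. A toy reproduction of that selection rule (stub classes for illustration only, not the verl classes):

```python
class Worker: ...
class MegatronWorker(Worker): ...
class ActorWorker(MegatronWorker): ...
class CriticWorker(MegatronWorker): ...

def pick_base(mros):
    # Same rule as the diff: prefer MegatronWorker, fall back to Worker.
    for cls in mros[0]:
        if cls.__name__ == "MegatronWorker":
            return cls
        if cls.__name__ == "Worker":
            return cls
    raise ValueError(f"Cannot determine base class for {mros}")

mros = [ActorWorker.__mro__, CriticWorker.__mro__]
print(pick_base(mros).__name__)  # MegatronWorker
```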
@@ -515,7 +519,8 @@ def __init__(self):
             for key, user_defined_cls in cls_dict.items():
                 user_defined_cls = _unwrap_ray_remote(user_defined_cls)
                 # directly instantiate the class without remote
-                # in worker class, e.g. <verl.single_controller.base.worker.Worker> when DISABLE_WORKER_INIT == 1 it will return immediately
+                # in worker class, e.g. <verl.single_controller.base.worker.Worker>
+                # when DISABLE_WORKER_INIT == 1 it will return immediately
                 with patch.dict(os.environ, {"DISABLE_WORKER_INIT": "1"}):
                     self.worker_dict[key] = user_defined_cls(
                         *init_args_dict[key].get("args", ()), **init_args_dict[key].get("kwargs", {})
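`patch.dict(os.environ, {...})` sets the flag only inside the `with` block and restores the previous environment afterwards, so instantiating the colocated workers with `DISABLE_WORKER_INIT=1` does not leak the setting into later code. A standalone sketch:

```python
import os
from unittest.mock import patch

with patch.dict(os.environ, {"DISABLE_WORKER_INIT": "1"}):
    # Visible to any constructor that checks the flag inside the block.
    print(os.environ["DISABLE_WORKER_INIT"])            # "1"
print(os.environ.get("DISABLE_WORKER_INIT", "unset"))    # restored, typically "unset"
```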