Commit 6357b88

move FSDP param load/offload into sharding manager
1 parent: 89ff526

File tree: 11 files changed (+148, -133 lines)
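The commit title describes moving FSDP parameter load/offload out of the workers and into the sharding manager. That change itself is not visible in the hunks below (the sharding-manager file is among the changed files not shown here), so the following is only a minimal sketch of the pattern, assuming verl's usual context-manager style for sharding managers; the class and method bodies are illustrative, not the committed code:

    import torch
    from torch import nn


    class FSDPShardingManagerSketch:
        """Hypothetical sketch: load params to GPU on entry, offload them on exit."""

        def __init__(self, module: nn.Module):
            # in verl this would be the FSDP-wrapped actor; plain nn.Module here for brevity
            self.module = module

        def __enter__(self):
            # bring the (sharded) parameters back onto the GPU before rollout generation
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            for param in self.module.parameters():
                param.data = param.data.to(device, non_blocking=True)
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            # offload parameters to CPU so the inference engine can use the freed GPU memory
            for param in self.module.parameters():
                param.data = param.data.to("cpu", non_blocking=True)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

With this shape, callers wrap generation in `with sharding_manager: ...` instead of invoking load/offload helpers on the worker, which is what lets the test below enable `param_offload` and still round-trip through `sleep()`/`wake_up()`.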

tests/rollout/test_vllm_multi_turn.py

Lines changed: 29 additions & 30 deletions

@@ -20,7 +20,7 @@
 from openai.types.chat.chat_completion import ChatCompletion

 from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
-from verl.single_controller.ray.base import Worker, create_colocated_worker_cls
+from verl.single_controller.ray.base import create_colocated_worker_cls
 from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
 from verl.workers.fsdp_async_workers import AsyncActorRolloutRefWorker, AsyncLLMManager
 from verl.workers.rollout.chat_scheduler import ChatCompletionScheduler
@@ -35,20 +35,25 @@ async def test_vllm_multi_turn():
     config.actor_rollout_ref.rollout.prompt_length = 4096
     config.actor_rollout_ref.rollout.response_length = 4096

+    # test sleep/wake_up with fsdp offload
+    config.actor_rollout_ref.actor.fsdp_config.param_offload = True
+    config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True
+
     # =========================== 1. Create hybrid ActorRollout workers ===========================
     ray.init(
         runtime_env={
-            'env_vars': {
-                'TOKENIZERS_PARALLELISM': 'true',
-                'NCCL_DEBUG': 'WARN',
-                'VLLM_LOGGING_LEVEL': 'WARN',
-                'VLLM_USE_V1': '1',
+            "env_vars": {
+                "TOKENIZERS_PARALLELISM": "true",
+                "NCCL_DEBUG": "WARN",
+                "VLLM_LOGGING_LEVEL": "WARN",
+                "VLLM_USE_V1": "1",
             }
-        })
+        }
+    )
     role_worker_mapping = {
         Role.ActorRollout: ray.remote(AsyncActorRolloutRefWorker),
     }
-    global_pool_id = 'global_pool'
+    global_pool_id = "global_pool"
     resource_pool_spec = {
         global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
     }
@@ -61,20 +66,20 @@ async def test_vllm_multi_turn():

     # create actor and rollout
     resource_pool = resource_pool_manager.get_resource_pool(Role.ActorRollout)
-    actor_rollout_cls = RayClassWithInitArgs(cls=role_worker_mapping[Role.ActorRollout],
-                                             config=config.actor_rollout_ref,
-                                             role='actor_rollout')
-    resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls
+    actor_rollout_cls = RayClassWithInitArgs(
+        cls=role_worker_mapping[Role.ActorRollout], config=config.actor_rollout_ref, role="actor_rollout"
+    )
+    resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls

     all_wg = {}
     wg_dicts = []
     for resource_pool, class_dict in resource_pool_to_cls.items():
-        worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict, worker_cls=Worker)
+        worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
         wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
         spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
         all_wg.update(spawn_wg)
         wg_dicts.append(wg_dict)
-    actor_rollout_wg = all_wg['actor_rollout']
+    actor_rollout_wg = all_wg["actor_rollout"]
     actor_rollout_wg.init_model()

     # =========================== 2. Create AsyncLLMManager&ChatScheduler ===========================
@@ -89,6 +94,10 @@ async def test_vllm_multi_turn():
         server_addresses=async_rollout_manager.server_addresses,
     )

+    # test sleep and wake_up
+    async_rollout_manager.sleep()
+    async_rollout_manager.wake_up()
+
     # =========================== 3. Multi turn rollout ===========================
     async def callback(completions: ChatCompletion, info: Dict[str, Any]):
         messages, round = info["messages"], info["round"]
@@ -101,10 +110,7 @@ async def callback(completions: ChatCompletion, info: Dict[str, Any]):
             messages.append({"role": "user", "content": "What is your name?"})
             await async_chat_scheduler.submit_chat_completions(
                 callback=callback,
-                callback_additional_info={
-                    "messages": messages,
-                    "round": 1
-                },
+                callback_additional_info={"messages": messages, "round": 1},
                 model=model_name,
                 messages=messages,
                 extra_headers=extra_headers,
@@ -113,27 +119,20 @@ async def callback(completions: ChatCompletion, info: Dict[str, Any]):
             messages.append({"role": "user", "content": "What is your favorite color?"})
             await async_chat_scheduler.submit_chat_completions(
                 callback=callback,
-                callback_additional_info={
-                    "messages": messages,
-                    "round": 2
-                },
+                callback_additional_info={"messages": messages, "round": 2},
                 model=model_name,
                 messages=messages,
                 extra_headers=extra_headers,
             )
         else:
             print("Done!")

-    messages = [{
-        "role": "user",
-        "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."
-    }]
+    messages = [
+        {"role": "user", "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."}
+    ]
     await async_chat_scheduler.submit_chat_completions(
         callback=callback,
-        callback_additional_info={
-            "messages": messages,
-            "round": 0
-        },
+        callback_additional_info={"messages": messages, "round": 0},
         model=model_name,
         messages=messages,
     )
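The new `sleep()`/`wake_up()` round-trip above is the user-visible side of the offload move. A small sketch of the contract being exercised, where only the two method names and the `AsyncLLMManager` import come from the diff and the rest is assumed:

    from verl.workers.fsdp_async_workers import AsyncLLMManager  # same import as the test above


    def rollout_step(async_rollout_manager: AsyncLLMManager) -> None:
        """Hypothetical helper showing the intended sleep/wake_up lifecycle."""
        async_rollout_manager.wake_up()  # reload weights into the inference engine before generating
        try:
            ...  # submit chat completions / collect rollouts here
        finally:
            async_rollout_manager.sleep()  # release GPU memory once generation is done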

verl/single_controller/base/register_center/ray.py

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Dict, Tuple
+from typing import Dict

 import ray

verl/single_controller/ray/base.py

Lines changed: 44 additions & 39 deletions

@@ -13,8 +13,10 @@
 # limitations under the License.

 import logging
+import os
 import time
 from typing import Any, Dict, List, Optional, Tuple
+from unittest.mock import patch

 import ray
 from ray.experimental.state.api import get_actor
@@ -23,6 +25,7 @@
 from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy

 from verl.single_controller.base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup
+from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch

 __all__ = ["Worker"]

@@ -300,17 +303,23 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d
                 elapsed = int(time.time() - start_time)
                 if elapsed % 30 == 0:
                     logging.warning(
-                        f"Waiting for register center actor {actor_name} to be ready. "
-                        f"Elapsed time: {elapsed} seconds out of {self._ray_wait_register_center_timeout} seconds."
+                        "Waiting for register center actor %s to be ready. "
+                        "Elapsed time: %s seconds out of %s seconds.",
+                        actor_name,
+                        elapsed,
+                        self._ray_wait_register_center_timeout,
                     )
                 time.sleep(1)

             if register_center_actor is None:
                 raise TimeoutError(
-                    f"Failed to get register_center_actor {actor_name} in {list_named_actors(all_namespaces=True)} "
+                    f"Failed to get register_center_actor {actor_name} "
+                    f"in {list_named_actors(all_namespaces=True)} "
                     f"for {self._ray_wait_register_center_timeout} seconds. "
-                    "Ensure that any lingering Ray resources from previous runs are cleaned up (e.g., by restarting the Ray cluster), "
-                    "or adjust the waiting time by modifying the config `trainer.ray_wait_register_center_timeout`."
+                    "Ensure that any lingering Ray resources from previous "
+                    "runs are cleaned up (e.g., by restarting the Ray cluster), "
+                    "or adjust the waiting time by modifying the config "
+                    "`trainer.ray_wait_register_center_timeout`."
                 )

             rank_zero_info = ray.get(register_center_actor.get_rank_zero_info.remote())
@@ -329,10 +338,9 @@ def from_detached(
         worker_names=None,
         ray_cls_with_init=None,
     ):
-        worker_group = cls(resource_pool=None,
-                           ray_cls_with_init=ray_cls_with_init,
-                           name_prefix=name_prefix,
-                           worker_names=worker_names)
+        worker_group = cls(
+            resource_pool=None, ray_cls_with_init=ray_cls_with_init, name_prefix=name_prefix, worker_names=worker_names
+        )
         return worker_group

     def spawn(self, prefix_set):
@@ -382,8 +390,9 @@ def execute_all_sync(self, method_name: str, *args, **kwargs):
         return ray.get(self.execute_all_async(method_name, *args, **kwargs))

     def execute_all_async(self, method_name: str, *args, **kwargs):
-        # Here, we assume that if all arguments in args and kwargs are lists, and their lengths match len(self._workers),
-        # we'll distribute each element in these lists to the corresponding worker
+        # Here, we assume that if all arguments in args and kwargs are lists,
+        # and their lengths match len(self._workers), we'll distribute each
+        # element in these lists to the corresponding worker
         # print(f"execute_all_async: method {method_name}({args}, {kwargs})")
         length = len(self._workers)
         if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()):
@@ -421,11 +430,6 @@ def world_size(self):
 with code written in separate ray.Actors.
 """

-import os
-from unittest.mock import patch
-
-from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch
-

 def _bind_workers_method_to_parent(cls, key, user_defined_cls):
     """
@@ -443,12 +447,12 @@ def _bind_workers_method_to_parent(cls, key, user_defined_cls):

         if hasattr(method, MAGIC_ATTR):

-            def generate_function(name):
+            def generate_function(name, key=key):
                 def func(self, *args, **kwargs):
                     # dispatch to the actual worker
                     return getattr(self.worker_dict[key], name)(*args, **kwargs)

-                return func
+                return func  # noqa: B023

             func = generate_function(method_name)
             # pass MAGIC_ATTR for outer worker group
@@ -457,15 +461,16 @@ def func(self, *args, **kwargs):
             try:
                 # bind direct rollout method to class without prefix
                 if attrs["dispatch_mode"] == Dispatch.DIRECT_ROLLOUT_METHOD and "rollout" in key:
-                    assert not hasattr(cls, method_name), \
+                    assert not hasattr(cls, method_name), (
                         f"conflict direct rollout method {method_name} with role {key}"
+                    )
                     setattr(cls, method_name, func)
                     print(f"bind role {key} method {method_name} to class {cls}")
                 else:
-                    method_name_with_prefix = key + '_' + method_name
+                    method_name_with_prefix = key + "_" + method_name
                     setattr(cls, method_name_with_prefix, func)
             except Exception as e:
-                raise ValueError(f"Fail to set method_name {method_name}")
+                raise ValueError(f"Fail to set method_name {method_name}") from e


 def _unwrap_ray_remote(cls):
@@ -474,32 +479,31 @@ def _unwrap_ray_remote(cls):
     return cls


-def _nearest_common_base(mros: List):
-    last_common = object
-    min_len = min([len(mro) for mro in mros]) - 1  # exclude final derived class
-
-    for i in range(min_len):
-        mro = mros[0][i]
-        for j in range(1, len(mros)):
-            if mro != mros[j][i]:
-                return last_common
-        last_common = mro
-
-    return last_common
+def _determine_fsdp_megatron_base_class(mros: List):
+    """
+    - megatron: base class should be MegatronWorker
+    - fsdp: base class should be Worker
+    """
+    for cls in mros[0]:
+        if cls.__name__ == "MegatronWorker":
+            return cls
+        if cls.__name__ == "Worker":
+            return cls
+    raise ValueError(f"Cannot determine base class for {mros}")


-def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs], worker_cls: type = None):
+def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]):
     """
     This function should return a class instance that delegates the calls to every
     cls in cls_dict
     """
     cls_dict = {}
     init_args_dict = {}
-    if worker_cls is None:
-        worker_cls = _nearest_common_base(
-            [list(reversed(cls.cls.__ray_actor_class__.__mro__)) for cls in class_dict.values()])
+    worker_cls = _determine_fsdp_megatron_base_class(
+        [cls.cls.__ray_actor_class__.__mro__ for cls in class_dict.values()]
+    )
     assert issubclass(worker_cls, Worker), f"worker_cls {worker_cls} should be a subclass of Worker"
-    print(f"find nearest common base class {worker_cls}")
+    print(f"colocated worker base class {worker_cls}")

     for key, cls in class_dict.items():
         cls_dict[key] = cls.cls
@@ -515,7 +519,8 @@ def __init__(self):
         for key, user_defined_cls in cls_dict.items():
             user_defined_cls = _unwrap_ray_remote(user_defined_cls)
             # directly instantiate the class without remote
-            # in worker class, e.g. <verl.single_controller.base.worker.Worker> when DISABLE_WORKER_INIT == 1 it will return immediately
+            # in worker class, e.g. <verl.single_controller.base.worker.Worker>
+            # when DISABLE_WORKER_INIT == 1 it will return immediately
             with patch.dict(os.environ, {"DISABLE_WORKER_INIT": "1"}):
                 self.worker_dict[key] = user_defined_cls(
                     *init_args_dict[key].get("args", ()), **init_args_dict[key].get("kwargs", {})
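One fix in this file is easy to miss: `generate_function(name)` becomes `generate_function(name, key=key)`. Without the `key=key` default, every `func` generated in the loop over the colocated roles would close over the same loop variable and dispatch to whichever role was bound last (the bug flake8's B023 flags). A self-contained illustration of the late-binding pitfall and the default-argument fix:

    # Late-binding closures: all three lambdas see the loop variable's final value.
    funcs_buggy = [lambda: key for key in ("actor", "critic", "rollout")]
    print([f() for f in funcs_buggy])  # ['rollout', 'rollout', 'rollout']

    # Binding the loop variable as a default argument freezes it per iteration,
    # which is exactly what `generate_function(name, key=key)` does above.
    funcs_fixed = [lambda key=key: key for key in ("actor", "critic", "rollout")]
    print([f() for f in funcs_fixed])  # ['actor', 'critic', 'rollout']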

verl/trainer/config/generation.yaml

Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@ model:
   external_lib: null
 rollout:
   name: vllm
+  mode: "sync" # sync: LLM, async: AsyncLLM
   temperature: 1.0
   top_k: 50 # 0 for hf rollout, -1 for vllm rollout
   top_p: 0.7

verl/trainer/config/ppo_megatron_trainer.yaml

Lines changed: 1 addition & 0 deletions

@@ -94,6 +94,7 @@ actor_rollout_ref:
     log_prob_micro_batch_size_per_gpu: null
   rollout:
     name: vllm
+    mode: "sync" # sync: LLM, async: AsyncLLM
     temperature: 1.0
     top_k: -1 # 0 for hf rollout, -1 for vllm rollout
     top_p: 1

verl/trainer/ppo/ray_trainer.py

Lines changed: 1 addition & 1 deletion

@@ -736,7 +736,7 @@ def init_workers(self):

         # create async rollout manager and request scheduler
         self.async_rollout_mode = False
-        if self.config.actor_rollout_ref.rollout.get("mode", "sync") == 'async':
+        if self.config.actor_rollout_ref.rollout.mode == "async":
             from verl.workers.fsdp_async_workers import AsyncLLMManager

             self.async_rollout_mode = True
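This simplification works because the two YAML files above now declare `mode: "sync"` as a default, so the defensive `.get` is no longer needed. A minimal sketch of the before/after lookup, assuming an OmegaConf-style config as verl uses:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"actor_rollout_ref": {"rollout": {"name": "vllm", "mode": "sync"}}})

    # before: defensive lookup, since older configs might omit `mode`
    legacy = cfg.actor_rollout_ref.rollout.get("mode", "sync") == "async"
    # after: plain attribute access, safe now that every trainer config declares `mode`
    current = cfg.actor_rollout_ref.rollout.mode == "async"
    assert legacy == current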
