
Commit 736ca1f

add multi turn rollout test
1 parent 084553e commit 736ca1f

File tree

7 files changed: +174 / -15 lines


recipe/dapo/src/dapo_ray_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ class RayDAPOTrainer(RayPPOTrainer):
     Note that this trainer runs on the driver process on a single CPU/GPU node.
     """

-    def fit(self):
+    async def fit(self):
         """
         The training loop of PPO.
         The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow.

recipe/dapo/src/main_dapo.py

Lines changed: 2 additions & 2 deletions
@@ -75,7 +75,7 @@ def run_ppo(config) -> None:
 @ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
 class TaskRunner:

-    def run(self, config):
+    async def run(self, config):
         from verl.utils.fs import copy_to_local
         # print initial config
         from pprint import pprint
@@ -186,7 +186,7 @@ def run(self, config):
                              reward_fn=reward_fn,
                              val_reward_fn=val_reward_fn)
         trainer.init_workers()
-        trainer.fit()
+        await trainer.fit()


 if __name__ == '__main__':
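
TaskRunner.run becoming a coroutine works because Ray actors support async def methods natively: Ray drives them on an event loop inside the actor process, and the driver keeps calling .remote() and ray.get() unchanged. A minimal sketch of that pattern, with toy names in place of the verl classes:

    import asyncio

    import ray


    @ray.remote(num_cpus=1)
    class ToyRunner:
        # Ray detects the coroutine method and runs it on the actor's event loop.
        async def run(self, steps: int) -> int:
            done = 0
            for _ in range(steps):
                await asyncio.sleep(0)  # stand-in for `await trainer.fit()` yielding control
                done += 1
            return done


    if __name__ == "__main__":
        ray.init()
        runner = ToyRunner.remote()
        assert ray.get(runner.run.remote(3)) == 3  # caller side is unchanged
        ray.shutdown()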

recipe/prime/main_prime.py

Lines changed: 6 additions & 1 deletion
@@ -30,6 +30,7 @@
 """
 from .prime_ray_trainer import RayPRIMETrainer

+import asyncio
 import ray
 import hydra

@@ -54,6 +55,10 @@ def run_prime(config, compute_score=None):

 @ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
 def main_task(config, compute_score=None):
+    asyncio.run(_main_task(config, compute_score))
+
+
+async def _main_task(config, compute_score=None):
     from verl.utils.fs import copy_local_path_from_hdfs
     # print initial config
     from pprint import pprint
@@ -132,7 +137,7 @@ def main_task(config, compute_score=None):
                          reward_fn=reward_fn,
                          val_reward_fn=val_reward_fn)
     trainer.init_workers()
-    trainer.fit()
+    await trainer.fit()


 if __name__ == '__main__':
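
PRIME's entry point is a Ray task rather than an actor, so the diff keeps main_task synchronous and hands off to an async helper through asyncio.run, which starts an event loop and runs the coroutine to completion. A small sketch of the same wrapper shape, with illustrative names only:

    import asyncio

    import ray


    async def _toy_main(x: int) -> int:
        await asyncio.sleep(0)  # stand-in for `await trainer.fit()`
        return x * 2


    @ray.remote(num_cpus=1)
    def toy_main(x: int) -> int:
        # synchronous wrapper: run the async body to completion inside the task
        return asyncio.run(_toy_main(x))


    if __name__ == "__main__":
        ray.init()
        assert ray.get(toy_main.remote(21)) == 42
        ray.shutdown()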

recipe/prime/prime_ray_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -308,7 +308,7 @@ def _load_checkpoint(self):
         if isinstance(self.train_dataloader.dataset, RLHFDataset):
             self.train_dataloader.dataset.resume_dataset_state()

-    def fit(self):
+    async def fit(self):
         """
         The training loop of PPO.
         The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow.

tests/rollout/test_vllm_multi_turn.py

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+from typing import Any, Dict
+
+import ray
+from omegaconf import OmegaConf
+from openai.types.chat.chat_completion import ChatCompletion
+
+from verl.workers.rollout.chat_scheduler import ChatCompletionScheduler
+from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
+from verl.workers.fsdp_async_workers import AsyncActorRolloutRefWorker, AsyncLLMManager
+from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs
+from verl.single_controller.ray.base import Worker, create_colocated_worker_cls
+
+
+async def test_vllm_multi_turn():
+    config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml")
+    model_path = "Qwen/Qwen2-7B-Instruct"
+    model_name = "/".join(model_path.split("/")[-2:])
+    config.actor_rollout_ref.model.path = model_path
+    config.actor_rollout_ref.rollout.mode = "async"
+    config.actor_rollout_ref.rollout.prompt_length = 4096
+    config.actor_rollout_ref.rollout.response_length = 4096
+
+    # =========================== 1. Create hybrid ActorRollout workers ===========================
+    ray.init(
+        runtime_env={
+            'env_vars': {
+                'TOKENIZERS_PARALLELISM': 'true',
+                'NCCL_DEBUG': 'WARN',
+                'VLLM_LOGGING_LEVEL': 'WARN',
+                'VLLM_USE_V1': '1',
+            }
+        })
+    role_worker_mapping = {
+        Role.ActorRollout: ray.remote(AsyncActorRolloutRefWorker),
+    }
+    global_pool_id = 'global_pool'
+    resource_pool_spec = {
+        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
+    }
+    mapping = {
+        Role.ActorRollout: global_pool_id,
+    }
+    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
+    resource_pool_manager.create_resource_pool()
+    resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()}
+
+    # create actor and rollout
+    resource_pool = resource_pool_manager.get_resource_pool(Role.ActorRollout)
+    actor_rollout_cls = RayClassWithInitArgs(cls=role_worker_mapping[Role.ActorRollout],
+                                             config=config.actor_rollout_ref,
+                                             role='actor_rollout')
+    resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls
+
+    all_wg = {}
+    wg_dicts = []
+    for resource_pool, class_dict in resource_pool_to_cls.items():
+        worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict, worker_cls=Worker)
+        wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
+        spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
+        all_wg.update(spawn_wg)
+        wg_dicts.append(wg_dict)
+    actor_rollout_wg = all_wg['actor_rollout']
+    actor_rollout_wg.init_model()
+
+    # =========================== 2. Create AsyncLLMManager&ChatScheduler ===========================
+    async_rollout_manager = AsyncLLMManager(
+        config=config.actor_rollout_ref,
+        worker_group=actor_rollout_wg,
+    )
+
+    async_chat_scheduler = ChatCompletionScheduler(
+        config=config.actor_rollout_ref.rollout,
+        model_path=config.actor_rollout_ref.model.path,
+        server_addresses=async_rollout_manager.server_addresses,
+    )
+
+    # =========================== 3. Multi turn rollout ===========================
+    async def callback(completions: ChatCompletion, info: Dict[str, Any]):
+        messages, round = info["messages"], info["round"]
+        message = completions.choices[0].message
+        messages.append({"role": message.role, "content": message.content})
+        print(f"[round={round}] role: {message.role}, content: {message.content}")
+
+        extra_headers = {"x-request-id": completions.id}
+        if round == 0:
+            messages.append({"role": "user", "content": "What is your name?"})
+            await async_chat_scheduler.submit_chat_completions(
+                callback=callback,
+                callback_additional_info={
+                    "messages": messages,
+                    "round": 1
+                },
+                model=model_name,
+                messages=messages,
+                extra_headers=extra_headers,
+            )
+        elif round == 1:
+            messages.append({"role": "user", "content": "What is your favorite color?"})
+            await async_chat_scheduler.submit_chat_completions(
+                callback=callback,
+                callback_additional_info={
+                    "messages": messages,
+                    "round": 2
+                },
+                model=model_name,
+                messages=messages,
+                extra_headers=extra_headers,
+            )
+        else:
+            print("Done!")
+
+    messages = [{
+        "role": "user",
+        "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."
+    }]
+    await async_chat_scheduler.submit_chat_completions(
+        callback=callback,
+        callback_additional_info={
+            "messages": messages,
+            "round": 0
+        },
+        model=model_name,
+        messages=messages,
+    )
+    assert len(messages) == 6
+    for round, message in enumerate(messages):
+        if round % 2 == 0:
+            assert message["role"] == "user"
+        else:
+            assert message["role"] == "assistant"
+
+
+if __name__ == "__main__":
+    asyncio.run(test_vllm_multi_turn())
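
The test drives the whole dialogue through callbacks: each completed turn appends the assistant message, and the callback itself submits the next request until the final round. A toy sketch of that control flow (not the verl API; an echo reply stands in for the model):

    import asyncio
    from typing import Any, Callable, Dict, List


    class ToyScheduler:
        async def submit(self, callback: Callable, info: Dict[str, Any], messages: List[dict]):
            # pretend the server answered, then hand control back to the caller's callback
            reply = {"role": "assistant", "content": f"echo: {messages[-1]['content']}"}
            await callback(reply, info)


    async def main():
        scheduler = ToyScheduler()
        messages = [{"role": "user", "content": "hello"}]

        async def callback(reply: dict, info: Dict[str, Any]):
            messages.append(reply)
            if info["round"] == 0:
                # chain the next turn from inside the callback, as the test above does
                messages.append({"role": "user", "content": "bye"})
                await scheduler.submit(callback, {"round": 1}, messages)

        await scheduler.submit(callback, {"round": 0}, messages)
        assert [m["role"] for m in messages] == ["user", "assistant", "user", "assistant"]


    asyncio.run(main())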

verl/single_controller/ray/base.py

Lines changed: 5 additions & 4 deletions
@@ -475,7 +475,7 @@ def _unwrap_ray_remote(cls):

 def _nearest_common_base(mros: List):
     last_common = object
-    min_len = min([len(mro) for mro in mros])
+    min_len = min([len(mro) for mro in mros]) - 1  # exclude final derived class

     for i in range(min_len):
         mro = mros[0][i]
@@ -487,15 +487,16 @@ def _nearest_common_base(mros: List):
     return last_common


-def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]):
+def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs], worker_cls: type = None):
     """
     This function should return a class instance that delegates the calls to every
     cls in cls_dict
     """
     cls_dict = {}
     init_args_dict = {}
-    worker_cls = _nearest_common_base(
-        [list(reversed(cls.cls.__ray_actor_class__.__mro__)) for cls in class_dict.values()])
+    if worker_cls is None:
+        worker_cls = _nearest_common_base(
+            [list(reversed(cls.cls.__ray_actor_class__.__mro__)) for cls in class_dict.values()])
     assert issubclass(worker_cls, Worker), f"worker_cls {worker_cls} should be a subclass of Worker"
     print(f"find nearest common base class {worker_cls}")
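
The "- 1" stops the lockstep walk over the reversed MROs one entry short of the shortest MRO, so an input class is never returned as its own common base (a group holding a single worker class now resolves to one of its base classes instead); the new test sidesteps the heuristic entirely by passing worker_cls=Worker. A hedged, standalone restatement of the lookup with toy classes:

    # Toy classes, not verl's actual Worker hierarchy.
    class Worker: ...
    class ActorWorker(Worker): ...
    class RolloutWorker(Worker): ...


    def nearest_common_base(mros):
        last_common = object
        min_len = min(len(mro) for mro in mros) - 1  # exclude final derived class
        for i in range(min_len):
            candidate = mros[0][i]
            if any(mro[i] is not candidate for mro in mros):
                break
            last_common = candidate
        return last_common


    # reversed MROs run object -> ... -> most derived class
    mros = [list(reversed(c.__mro__)) for c in (ActorWorker, RolloutWorker)]
    assert nearest_common_base(mros) is Worker

    # single-class group: the `- 1` keeps the class from being its own answer
    assert nearest_common_base([list(reversed(ActorWorker.__mro__))]) is Worker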

verl/workers/rollout/chat_scheduler.py

Lines changed: 10 additions & 6 deletions
@@ -1,5 +1,4 @@
 import heapq
-from abc import ABC, abstractmethod
 from uuid import uuid4
 from typing import Any, Callable, Dict, List

@@ -12,7 +11,7 @@
 from verl.protocol import DataProto


-class ChatCompletionScheduler(ABC):
+class ChatCompletionScheduler:

     def __init__(self, config: DictConfig, model_path: str, server_addresses: List[str], max_cache_size: int = 10000):
         """
@@ -52,8 +51,16 @@ async def submit_chat_completions(
             **chat_complete_request: dict, request parameters same as OpenAI AsyncCompletions.create.
             OpenAI API reference: https://platform.openai.com/docs/api-reference/chat/create
         """
-        request_id = chat_complete_request.get("extra_headers", {}).get("x-request-id", None)
+        if "extra_headers" not in chat_complete_request:
+            chat_complete_request["extra_headers"] = {}
+
+        extra_headers = chat_complete_request["extra_headers"]
+        request_id = extra_headers.get("x-request-id", None)
         if request_id:
+            if request_id.startswith("chatcmpl-"):
+                request_id = request_id[len("chatcmpl-"):]
+                extra_headers["x-request-id"] = request_id
+
             address = self.request_id_to_address[request_id]
         else:
             address = self.weighted_addresses[0][1]
@@ -62,8 +69,6 @@ async def submit_chat_completions(

             request_id = uuid4().hex
             self.request_id_to_address[request_id] = address
-            if "extra_headers" not in chat_complete_request:
-                chat_complete_request["extra_headers"] = {}
             chat_complete_request["extra_headers"]["x-request-id"] = request_id

         # TODO: OpenAI client uses httpx, seems to have performance issue in high concurrency requests.
@@ -91,6 +96,5 @@ async def _chat_completions_aiohttp(self, address: str, **chat_complete_request)
         finally:
             await session.close()

-    @abstractmethod
     async def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto:
         raise NotImplementedError
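
The reworked header handling is what makes multi-turn requests sticky: a follow-up turn sends the previous completion's id back as x-request-id (as the test does with completions.id), and OpenAI-style completion ids carry a "chatcmpl-" prefix, so stripping it recovers the id the scheduler originally issued and the next turn is routed to the same server. A hedged, standalone sketch of that routing decision (hypothetical helper, not the class method):

    from uuid import uuid4


    def resolve_address(extra_headers: dict, request_id_to_address: dict, default_address: str) -> str:
        request_id = extra_headers.get("x-request-id")
        if request_id:
            # follow-up turn: strip the OpenAI-style prefix to recover the issued id
            if request_id.startswith("chatcmpl-"):
                request_id = request_id[len("chatcmpl-"):]
                extra_headers["x-request-id"] = request_id
            return request_id_to_address[request_id]
        # first turn: pick a server and mint a fresh id for later turns to reuse
        request_id = uuid4().hex
        request_id_to_address[request_id] = default_address
        extra_headers["x-request-id"] = request_id
        return default_address


    table = {}
    first = resolve_address({}, table, "server-0")
    rid = next(iter(table))
    second = resolve_address({"x-request-id": f"chatcmpl-{rid}"}, table, "server-1")
    assert first == second == "server-0"  # the follow-up turn sticks to the same server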
