Commit 3b97241

wuxibin89 authored and PeterSH6 committed
[rollout] feat: introduce vLLM AsyncLLM to support multi-turn rollout (volcengine#1138)
### Summary

Introduce vLLM AsyncLLM to support multi-turn rollout; see also volcengine#385, volcengine#398, volcengine#710.

### Architecture

![async_llm_arch](https://github.com/user-attachments/assets/e8cd974c-0c26-4d96-9a9e-b71fd85dd32d)

**New components**:
- AsyncLLMWorker: standalone vLLM server instance
  - FastAPI: provides an OpenAI-compatible HTTP server
  - AsyncLLM: async LLMEngine for online serving; for more details see [AsyncLLM](vllm-project/vllm#9826) and [LLMEngine](https://docs.vllm.ai/en/latest/design/arch_overview.html#llmengine)
  - ExternalRayDistributedExecutor: custom executor backend that manages the workers in the worker group, grabbing the corresponding workers by actor name
- AsyncLLMManager: manages a group of vLLM server instances (AsyncLLMWorker)
  - AsyncLLM lifecycle: initialization, wake_up, sleep
  - FastAPI service discovery
- ChatScheduler: schedules multiple chat completion requests across multiple server instances (routing policy sketched below)
  - Least-requests load balancing
  - Sticky sessions with prefix caching
  - Chat completion callback: tool calling

### TODO
- [x] AsyncLLM: initialization/wake_up/sleep
- [x] OpenAI API: support `/v1/chat/completions`
- [x] RayPPOTrainer integration: replace `generate_sequences` with an HTTP call to `/v1/chat/completions`
- [x] GSM8K e2e training
- [ ] Add documentation

---------

Co-authored-by: shengguangming <[email protected]>
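The ChatScheduler routing policy named above (least-requests load balancing plus sticky sessions for prefix caching) can be pictured with a minimal sketch. The class and method names below are hypothetical, for illustration only; they are not the actual verl API:

```python
from typing import Dict, List


class StickyLeastRequestsRouter:
    """Hypothetical sketch of the ChatScheduler routing policy: route each
    conversation back to its original server (sticky session, so vLLM's
    prefix cache is reused across turns); otherwise pick the server with
    the fewest in-flight requests."""

    def __init__(self, server_addresses: List[str]):
        self.inflight: Dict[str, int] = {addr: 0 for addr in server_addresses}
        self.sessions: Dict[str, str] = {}  # conversation id -> server address

    def choose_server(self, conversation_id: str) -> str:
        if conversation_id in self.sessions:  # sticky session
            return self.sessions[conversation_id]
        addr = min(self.inflight, key=self.inflight.get)  # least requests
        self.sessions[conversation_id] = addr
        return addr

    def request_started(self, addr: str) -> None:
        self.inflight[addr] += 1

    def request_finished(self, addr: str) -> None:
        self.inflight[addr] -= 1
```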
1 parent 5983ce7 commit 3b97241

28 files changed (+1207, −121 lines)

.github/workflows/vllm.yml

Lines changed: 5 additions & 1 deletion
```diff
@@ -84,4 +84,8 @@ jobs:
           cd tests/generation
           export OUTPUT_PATH="${HOME}/data/gen/qwen_05_gen_test.parquet"
           MODEL_ID=Qwen/Qwen2.5-0.5B-Instruct NGPUS_PER_NODE=1 GEN_TP=1 bash ./run_gen_qwen05.sh
-          rm -rf "${OUTPUT_PATH}"
+          rm -rf "${OUTPUT_PATH}"
+      - name: Running multi-turn rollout tests on 8 L20 GPUs
+        run: |
+          pip3 install --upgrade vllm==0.8.3 tensordict==0.7.2
+          python3 tests/rollout/test_vllm_multi_turn.py
```

examples/grpo_trainer/run_qwen2-7b_seq_balance.sh

Lines changed: 10 additions & 0 deletions
```diff
@@ -3,10 +3,18 @@ set -x
 # If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:
 # export VLLM_ATTENTION_BACKEND=XFORMERS
 
+# For async rollout mode, dataset should return raw chat.
+rollout_mode="sync"
+if [ "$rollout_mode" = "async" ]; then
+    return_raw_chat="True"
+    chat_scheduler=examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler
+fi
+
 python3 -m verl.trainer.main_ppo \
     algorithm.adv_estimator=grpo \
     data.train_files=$HOME/data/gsm8k/train.parquet \
     data.val_files=$HOME/data/gsm8k/test.parquet \
+    data.return_raw_chat=$return_raw_chat \
     data.train_batch_size=1024 \
     data.max_prompt_length=512 \
     data.max_response_length=1024 \
@@ -27,6 +35,8 @@ python3 -m verl.trainer.main_ppo \
     actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
     actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
     actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.mode=$rollout_mode \
+    actor_rollout_ref.rollout.chat_scheduler=$chat_scheduler \
     actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
     actor_rollout_ref.rollout.n=5 \
     actor_rollout_ref.ref.fsdp_config.param_offload=True \
```
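For reference, `data.return_raw_chat=True` makes each dataset sample carry the chat messages themselves, which the async scheduler later reads from `batch.non_tensor_batch["raw_prompt"]` (see `naive_chat_scheduler.py` below). A minimal illustration of the expected shape; the content is made up, but the role/content structure follows the OpenAI chat format used throughout this PR:

```python
# One sample's raw chat, as consumed by the async ChatCompletionScheduler.
raw_prompt = [
    {"role": "user", "content": "Natalia sold clips to 48 of her friends in April..."},
]
```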
examples/ppo_trainer/naive_chat_scheduler.py (new file)

Lines changed: 151 additions & 0 deletions
```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import Any, Dict, List

import torch
from omegaconf import DictConfig
from openai.types.chat.chat_completion import ChatCompletion
from tensordict import TensorDict

from verl.protocol import DataProto
from verl.workers.rollout.async_server import ChatCompletionScheduler


class NaiveChatCompletionScheduler(ChatCompletionScheduler):
    """
    A very naive implementation of ChatCompletionScheduler for demo purposes;
    only does single-turn chat completion.
    """

    def __init__(
        self,
        config: DictConfig,
        model_path: str,
        server_addresses: List[str],
        max_cache_size: int = 10000,
    ):
        super().__init__(config, model_path, server_addresses, max_cache_size)

    async def generate_sequences(self, batch: DataProto, **sampling_params) -> DataProto:
        kwargs = dict(
            n=self.config.n,
            max_completion_tokens=self.config.response_length,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
        )

        do_sample = batch.meta_info.get("do_sample", True)
        is_validate = batch.meta_info.get("validate", False)
        if not do_sample or is_validate:
            kwargs["n"] = 1
            kwargs["temperature"] = 0

        kwargs.update(sampling_params)
        print(f"[NaiveChatCompletionScheduler] generate_sequences sampling params: {kwargs}")

        async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception):
            conversation, batch_conversations, batch_index = (
                info["conversation"],
                info["batch_conversations"],
                info["batch_index"],
            )

            conversations = []
            for choice in completions.choices:
                chat = conversation.copy()
                chat.append({"role": choice.message.role, "content": choice.message.content})
                conversations.append(chat)
            batch_conversations[batch_index] = conversations

            # NOTE: we can call tools and resubmit chat completions here.
            # call_tools(completions, info)
            # await self.submit_chat_completions(callback2, ...)

        tasks, batch_conversations = [], [None] * len(batch)
        for batch_index, conversation in enumerate(batch.non_tensor_batch["raw_prompt"]):
            # raw_prompt: [{"role": "user", "content": ""}, {"role": "assistant", "content": ""}, ...]
            tasks.append(
                asyncio.create_task(
                    self.submit_chat_completions(
                        callback=callback,
                        callback_additional_info={
                            "batch_conversations": batch_conversations,
                            "batch_index": batch_index,
                            "conversation": list(conversation),
                        },
                        model=self.model_name,
                        messages=conversation,
                        **kwargs,
                    )
                )
            )
        await asyncio.gather(*tasks)
        print("[NaiveChatCompletionScheduler] generate_sequences done")

        return self._postprocess(batch, batch_conversations, kwargs["n"])

    def _postprocess(
        self, batch: DataProto, batch_conversations: List[List[List[Dict[str, str]]]], n: int
    ) -> DataProto:
        # NOTE: consistent with the batch version of generate_sequences in vllm_rollout_spmd.py
        # prompts: left pad
        # responses: right pad
        # input_ids: prompt + response
        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]

        # prompts: [prompt] from input dataset
        prompts = [
            self.tokenizer.apply_chat_template(prompt, add_generation_prompt=True, tokenize=False)
            for prompt in batch.non_tensor_batch["raw_prompt"]
        ]

        # flatten batch_conversations if n > 1
        assert len(batch_conversations) == len(prompts)
        batch_conversations = [conversation for conversations in batch_conversations for conversation in conversations]
        assert len(batch_conversations) == len(prompts) * n

        # sequences: [prompt + response]
        sequences = [
            self.tokenizer.apply_chat_template(conversation, add_generation_prompt=False, tokenize=False)
            for conversation in batch_conversations
        ]

        # responses: [response]
        # TODO: mask out tool-calling tokens?
        responses = [sequence[len(prompts[i // n]) :] for i, sequence in enumerate(sequences)]

        prompts = self.tokenizer(prompts, return_tensors="pt", padding="longest", padding_side="left")
        responses = self.tokenizer(responses, return_tensors="pt", padding="longest", padding_side="right")
        if n > 1:
            prompts["input_ids"] = prompts["input_ids"].repeat_interleave(n, dim=0)
            prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave(n, dim=0)

        input_ids = torch.cat([prompts["input_ids"], responses["input_ids"]], dim=1)
        attention_mask = torch.cat([prompts["attention_mask"], responses["attention_mask"]], dim=1)
        position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask

        batch = TensorDict(
            {
                "prompts": prompts["input_ids"],
                "responses": responses["input_ids"],
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
            },
            batch_size=len(input_ids),
        )

        return DataProto(batch=batch)
```
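As a side note, the `position_ids` one-liner in `_postprocess` can be checked in isolation. A toy example (standalone, not verl code) showing what the formula computes for a left-padded prompt followed by a right-padded response: each attended token gets its 0-based position among attended tokens, and padding positions are zeroed out by the final multiplication.

```python
import torch

# attention_mask: 4 pad tokens, 4 prompt tokens | 3 response tokens, 5 pad tokens
attention_mask = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])

position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask
print(position_ids)
# tensor([[0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0]])
```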

examples/ppo_trainer/run_qwen2-7b_seq_balance.sh

Lines changed: 10 additions & 0 deletions
```diff
@@ -8,10 +8,18 @@ math_test_path=$HOME/data/math/test.parquet
 train_files="['$gsm8k_train_path', '$math_train_path']"
 test_files="['$gsm8k_test_path', '$math_test_path']"
 
+# For async rollout mode, dataset should return raw chat.
+rollout_mode="sync"
+if [ "$rollout_mode" = "async" ]; then
+    return_raw_chat="True"
+    chat_scheduler=examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler
+fi
+
 python3 -m verl.trainer.main_ppo \
     algorithm.adv_estimator=gae \
     data.train_files="$train_files" \
     data.val_files="$test_files" \
+    data.return_raw_chat=$return_raw_chat \
     data.train_batch_size=4096 \
     data.max_prompt_length=4096 \
     data.max_response_length=4096 \
@@ -29,6 +37,8 @@ python3 -m verl.trainer.main_ppo \
     actor_rollout_ref.actor.use_kl_loss=False \
     actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
     actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.mode=$rollout_mode \
+    actor_rollout_ref.rollout.chat_scheduler=$chat_scheduler \
     actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
     actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \
     critic.optim.lr=1e-5 \
```

recipe/dapo/src/dapo_ray_trainer.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -40,7 +40,7 @@ class RayDAPOTrainer(RayPPOTrainer):
     Note that this trainer runs on the driver process on a single CPU/GPU node.
     """
 
-    def fit(self):
+    async def fit(self):
         """
         The training loop of PPO.
         The driver process only need to call the compute functions of the worker group through RPC
```

recipe/dapo/src/main_dapo.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -76,7 +76,8 @@ def run_ppo(config) -> None:
 
 @ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
 class TaskRunner:
-    def run(self, config):
+
+    async def run(self, config):
         # print initial config
         from pprint import pprint
 
@@ -201,7 +202,7 @@ def run(self, config):
             val_reward_fn=val_reward_fn,
         )
         trainer.init_workers()
-        trainer.fit()
+        await trainer.fit()
 
 
 if __name__ == "__main__":
```
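This change leans on Ray's native support for async actors: an `async def` method on a `@ray.remote` class runs on the actor's asyncio event loop and can await other coroutines, while callers still invoke it with `.remote()`. A minimal standalone sketch of the pattern (names are illustrative, not from this PR):

```python
import asyncio

import ray


@ray.remote
class Runner:
    # Ray executes async actor methods on the actor's event loop, which is
    # what lets TaskRunner.run above simply `await trainer.fit()`.
    async def run(self) -> str:
        await asyncio.sleep(0.1)  # stand-in for awaiting a training loop
        return "done"


if __name__ == "__main__":
    ray.init()
    runner = Runner.remote()
    print(ray.get(runner.run.remote()))  # blocks until the coroutine completes
```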

recipe/prime/main_prime.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -29,6 +29,8 @@
 Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
 """
 
+import asyncio
+
 import hydra
 import ray
 
@@ -53,6 +55,10 @@ def run_prime(config, compute_score=None):
 
 @ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
 def main_task(config, compute_score=None):
+    asyncio.run(_main_task(config, compute_score))
+
+
+async def _main_task(config, compute_score=None):
     # print initial config
     from pprint import pprint
 
@@ -142,7 +148,7 @@ def main_task(config, compute_score=None):
             val_reward_fn=val_reward_fn,
         )
         trainer.init_workers()
-        trainer.fit()
+        await trainer.fit()
 
 
 if __name__ == "__main__":
```
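The prime recipe takes the other common route: `main_task` stays a synchronous Ray task and drives the coroutine itself with `asyncio.run`, which creates a fresh event loop, runs the coroutine to completion, and closes the loop, so callers of the remote task see no API change. A minimal sketch of that wrapper pattern (names are illustrative):

```python
import asyncio


async def _train() -> str:
    # stand-in for the awaited `trainer.fit()` in _main_task above
    await asyncio.sleep(0.1)
    return "done"


def main_task() -> str:
    # Synchronous entry point that owns the event loop, mirroring
    # main_task -> asyncio.run(_main_task(...)) above.
    return asyncio.run(_train())


print(main_task())
```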

recipe/prime/prime_ray_trainer.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -331,7 +331,7 @@ def _load_checkpoint(self):
         if isinstance(self.train_dataloader.dataset, RLHFDataset):
             self.train_dataloader.dataset.resume_dataset_state()
 
-    def fit(self):
+    async def fit(self):
         """
         The training loop of PPO.
         The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow.
```

tests/ray/test_worker_group_basics.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -68,7 +68,9 @@ def foo_custom(self, x, y):
 @ray.remote(num_gpus=0.1)
 def remote_call_wg(worker_names):
     class_with_args = RayClassWithInitArgs(cls=TestActor, x=2)
-    worker_group = RayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=class_with_args)
+    worker_group = RayWorkerGroup.from_detached(
+        worker_names=worker_names, ray_cls_with_init=class_with_args, name_prefix=None
+    )
     print(worker_group.worker_names)
 
     output_ref = worker_group.foo_custom(x=[1, 2], y=[5, 6])
```
