Commit 903db09

[infer] vllm video/audio inference (#7566)

1 parent aaf2e6b · commit 903db09

File tree: 10 files changed, +317 -273 lines changed

scripts/vllm_infer.py

+14 -2
@@ -92,8 +92,20 @@ def vllm_infer(
             multi_modal_data = {
                 "image": template_obj.mm_plugin._regularize_images(
                     sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
-                )
+                )["images"]
             }
+        elif sample["videos"]:
+            multi_modal_data = {
+                "video": template_obj.mm_plugin._regularize_videos(
+                    sample["videos"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+                )["videos"]
+            }
+        elif sample["audios"]:
+            audio_data = template_obj.mm_plugin._regularize_audios(
+                sample["audios"],
+                sampling_rate=16000,
+            )
+            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
         else:
             multi_modal_data = None

@@ -131,7 +143,7 @@ def vllm_infer(
         "enable_lora": model_args.adapter_name_or_path is not None,
     }
     if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
-        engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
+        engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}

     if isinstance(model_args.vllm_config, dict):
         engine_args.update(model_args.vllm_config)
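The audio branch deserves a note: _regularize_audios returns both waveforms and their sampling rates, and the commit pairs them with zip, matching the (waveform, sampling_rate) pairs vLLM consumes for audio. A minimal sketch of the resulting dispatch, with the regularizer parameter lists trimmed and the plugin passed in as a placeholder (not part of the commit):

from typing import Any, Optional

def build_multi_modal_data(sample: dict[str, Any], plugin: Any) -> Optional[dict[str, Any]]:
    # Mirrors the per-sample dispatch above: exactly one modality is attached,
    # preferring images over videos over audios (the elif chain).
    if sample.get("images"):
        return {"image": plugin._regularize_images(sample["images"])["images"]}
    if sample.get("videos"):
        return {"video": plugin._regularize_videos(sample["videos"])["videos"]}
    if sample.get("audios"):
        audio_data = plugin._regularize_audios(sample["audios"], sampling_rate=16000)
        # The commit passes a bare zip object, which is a single-use iterator;
        # list() here just makes the (waveform, sampling_rate) pairing explicit.
        return {"audio": list(zip(audio_data["audios"], audio_data["sampling_rates"]))}
    return None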

src/llamafactory/api/chat.py

+41 -8
@@ -23,7 +23,7 @@

 from ..data import Role as DataRole
 from ..extras import logging
-from ..extras.constants import IMAGE_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.misc import is_env_enabled
 from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
 from .common import dictify, jsonify

@@ -56,7 +56,7 @@

 if TYPE_CHECKING:
     from ..chat import ChatModel
-    from ..data.mm_plugin import ImageInput
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
     from .protocol import ChatCompletionRequest, ScoreEvaluationRequest

@@ -72,7 +72,14 @@

 def _process_request(
     request: "ChatCompletionRequest",
-) -> tuple[list[dict[str, str]], Optional[str], Optional[str], Optional[list["ImageInput"]]]:
+) -> tuple[
+    list[dict[str, str]],
+    Optional[str],
+    Optional[str],
+    Optional[list["ImageInput"]],
+    Optional[list["VideoInput"]],
+    Optional[list["AudioInput"]],
+]:
     if is_env_enabled("API_VERBOSE", "1"):
         logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")

@@ -88,7 +95,7 @@ def _process_request(
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")

     input_messages = []
-    images = []
+    images, videos, audios = [], [], []
     for i, message in enumerate(request.messages):
         if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")

@@ -107,7 +114,7 @@ def _process_request(
             for input_item in message.content:
                 if input_item.type == "text":
                     text_content += input_item.text
-                else:
+                elif input_item.type == "image_url":
                     text_content += IMAGE_PLACEHOLDER
                     image_url = input_item.image_url.url
                     if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image

@@ -118,6 +125,28 @@ def _process_request(
                         image_stream = requests.get(image_url, stream=True).raw

                     images.append(Image.open(image_stream).convert("RGB"))
+                elif input_item.type == "video_url":
+                    text_content += VIDEO_PLACEHOLDER
+                    video_url = input_item.video_url.url
+                    if os.path.isfile(video_url):  # local file
+                        video_stream = open(video_url, "rb")
+                    else:  # web uri
+                        video_stream = requests.get(video_url, stream=True).raw
+
+                    videos.append(video_stream)
+                elif input_item.type == "audio_url":
+                    text_content += AUDIO_PLACEHOLDER
+                    audio_url = input_item.audio_url.url
+                    if os.path.isfile(audio_url):  # local file
+                        audio_stream = open(audio_url, "rb")
+                    else:  # web uri
+                        audio_stream = requests.get(audio_url, stream=True).raw
+
+                    audios.append(audio_stream)
+                else:
+                    raise HTTPException(
+                        status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid input type {input_item.type}."
+                    )

             input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content})
         else:

@@ -132,7 +161,7 @@ def _process_request(
     else:
         tools = None

-    return input_messages, system, tools, images or None
+    return input_messages, system, tools, images or None, videos or None, audios or None


 def _create_stream_chat_completion_chunk(

@@ -151,12 +180,14 @@ async def create_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> "ChatCompletionResponse":
     completion_id = f"chatcmpl-{uuid.uuid4().hex}"
-    input_messages, system, tools, images = _process_request(request)
+    input_messages, system, tools, images, videos, audios = _process_request(request)
     responses = await chat_model.achat(
         input_messages,
         system,
         tools,
         images,
+        videos,
+        audios,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,

@@ -202,7 +233,7 @@ async def create_stream_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> AsyncGenerator[str, None]:
     completion_id = f"chatcmpl-{uuid.uuid4().hex}"
-    input_messages, system, tools, images = _process_request(request)
+    input_messages, system, tools, images, videos, audios = _process_request(request)
     if tools:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")

@@ -217,6 +248,8 @@ async def create_stream_chat_completion_response(
         system,
         tools,
         images,
+        videos,
+        audios,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
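With these branches in place, a user content array can mix text, image, video, and audio items: each media URL is read either from a local path or by streaming the web URI, and the matching placeholder token is appended to the text so the template knows where the media belongs. A hypothetical request body exercising the new item types (the model name and URLs are placeholders; field names follow the Pydantic models in src/llamafactory/api/protocol.py below):

request_body = {
    "model": "test-model",  # placeholder model name
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What happens in this clip, and what is said?"},
                {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
                {"type": "audio_url", "audio_url": {"url": "/data/narration.wav"}},
            ],
        }
    ],
}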

src/llamafactory/api/protocol.py

+6 -3
@@ -70,14 +70,17 @@ class FunctionCall(BaseModel):
     function: Function


-class ImageURL(BaseModel):
+class URL(BaseModel):
     url: str
+    detail: Literal["auto", "low", "high"] = "auto"


 class MultimodalInputItem(BaseModel):
-    type: Literal["text", "image_url"]
+    type: Literal["text", "image_url", "video_url", "audio_url"]
     text: Optional[str] = None
-    image_url: Optional[ImageURL] = None
+    image_url: Optional[URL] = None
+    video_url: Optional[URL] = None
+    audio_url: Optional[URL] = None


 class ChatMessage(BaseModel):
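Renaming ImageURL to URL lets a single model back all three *_url fields, and the new detail field mirrors the auto/low/high image-detail option in OpenAI's API. A quick validation sketch, assuming Pydantic v2 is installed:

from llamafactory.api.protocol import URL, MultimodalInputItem

item = MultimodalInputItem(
    type="audio_url",
    audio_url=URL(url="https://example.com/sample.wav"),
)
print(item.model_dump())  # detail defaults to "auto" on the nested URL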

src/llamafactory/chat/sglang_engine.py

+10 -17
@@ -33,7 +33,7 @@


 if is_sglang_available():
-    from sglang.utils import launch_server_cmd, terminate_process, wait_for_server
+    from sglang.utils import launch_server_cmd, terminate_process, wait_for_server  # type: ignore


 if TYPE_CHECKING:

@@ -134,24 +134,17 @@ async def _generate(
         audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncIterator[dict[str, Any]]:
-        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
-        if images is not None:
-            mm_input_dict.update({"images": images, "imglens": [len(images)]})
-            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
-                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
-
-        if videos is not None:
-            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
-            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
-                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
-
-        if audios is not None:
-            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
-            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
-                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
+        if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+            messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
+
+        if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+            messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
+
+        if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
+            messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]

         messages = self.template.mm_plugin.process_messages(
-            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], self.processor
+            messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]
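Both engines now share this flatter placeholder rule: the mm_input_dict bookkeeping is gone, and process_messages receives images or [] directly. The rule itself, extracted as a runnable illustration (the placeholder string is assumed here; the real value lives in extras.constants):

IMAGE_PLACEHOLDER = "<image>"  # assumed value, for illustration only

messages = [{"role": "user", "content": "Compare the two pictures."}]
images = ["a.png", "b.png"]

# If no placeholder appears anywhere in the conversation, prepend one per image
# to the first message; VIDEO_PLACEHOLDER and AUDIO_PLACEHOLDER work the same way.
if images and not any(IMAGE_PLACEHOLDER in m["content"] for m in messages):
    messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

print(messages[0]["content"])  # <image><image>Compare the two pictures.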

src/llamafactory/chat/vllm_engine.py

+27 -18
@@ -83,7 +83,7 @@ def __init__(
             "max_lora_rank": model_args.vllm_max_lora_rank,
         }
         if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
-            engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
+            engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}

         if isinstance(model_args.vllm_config, dict):
             engine_args.update(model_args.vllm_config)

@@ -111,24 +111,17 @@ async def _generate(
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = f"chatcmpl-{uuid.uuid4().hex}"
-        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
-        if images is not None:
-            mm_input_dict.update({"images": images, "imglens": [len(images)]})
-            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
-                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
-
-        if videos is not None:
-            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
-            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
-                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
-
-        if audios is not None:
-            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
-            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
-                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
+        if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+            messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
+
+        if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+            messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
+
+        if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
+            messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]

         messages = self.template.mm_plugin.process_messages(
-            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], self.processor
+            messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]

@@ -186,8 +179,24 @@ async def _generate(
                     images,
                     image_max_pixels=self.model_args.image_max_pixels,
                     image_min_pixels=self.model_args.image_min_pixels,
-                )
+                )["images"]
             }
+        elif videos is not None:
+            multi_modal_data = {
+                "video": self.template.mm_plugin._regularize_videos(
+                    videos,
+                    image_max_pixels=self.model_args.video_max_pixels,
+                    image_min_pixels=self.model_args.video_min_pixels,
+                    video_fps=self.model_args.video_fps,
+                    video_maxlen=self.model_args.video_maxlen,
+                )["videos"]
+            }
+        elif audios is not None:
+            audio_data = self.template.mm_plugin._regularize_audios(
+                audios,
+                sampling_rate=self.model_args.audio_sampling_rate,
+            )
+            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
         else:
            multi_modal_data = None
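Note the elif chain: one request feeds vLLM at most one modality, even though the API layer can collect all three. Video frames are budgeted by the video-specific pixel, fps, and max-length arguments, while audio is resampled at model_args.audio_sampling_rate. A hypothetical end-to-end driver (chat_model and clip.mp4 are placeholders; arguments are passed positionally in the order api/chat.py uses above):

import asyncio

async def demo(chat_model) -> None:
    # chat_model is an initialized llamafactory ChatModel; a raw binary stream
    # is what api/chat.py itself collects for video_url items.
    with open("clip.mp4", "rb") as video_stream:
        responses = await chat_model.achat(
            [{"role": "user", "content": "Describe the video."}],
            None,            # system
            None,            # tools
            None,            # images
            [video_stream],  # videos
            None,            # audios
        )
    print(responses)

# asyncio.run(demo(my_chat_model))  # given a constructed ChatModel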

src/llamafactory/data/converter.py

+9 -3
@@ -26,8 +26,12 @@
     from transformers import Seq2SeqTrainingArguments

     from ..hparams import DataArguments
+    from .mm_plugin import AudioInput, ImageInput, VideoInput
     from .parser import DatasetAttr

+    MediaType = Union[ImageInput, VideoInput, AudioInput]
+
+
 logger = logging.get_logger(__name__)

@@ -36,10 +40,12 @@ class DatasetConverter:
     dataset_attr: "DatasetAttr"
     data_args: "DataArguments"

-    def _find_medias(self, medias: Union[Any, list[Any]]) -> Optional[list[Any]]:
+    def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]:
         r"""Optionally concatenate media path to media dir when loading from local disk."""
-        if not isinstance(medias, list):
-            medias = [medias] if medias is not None else []
+        if medias is None:
+            return None
+        elif not isinstance(medias, list):
+            medias = [medias]
         elif len(medias) == 0:
             return None
         else:
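Beyond the tighter MediaType annotation, this fixes an edge case: a missing media field now maps to None instead of an empty list. The case analysis, restated as a self-contained sketch (the real method additionally joins local paths with the dataset's media directory):

from typing import Optional, TypeVar, Union

T = TypeVar("T")

def find_medias(medias: Union[T, list[T], None]) -> Optional[list[T]]:
    # Same normalization as the rewritten _find_medias, minus media-dir handling.
    if medias is None:
        return None          # absent field: no media
    elif not isinstance(medias, list):
        medias = [medias]    # single item: wrap in a list
    elif len(medias) == 0:
        return None          # explicit empty list: treated as no media
    return medias

assert find_medias(None) is None
assert find_medias("cat.png") == ["cat.png"]
assert find_medias([]) is None
assert find_medias(["a.png", "b.png"]) == ["a.png", "b.png"]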
