23
23
24
24
from ..data import Role as DataRole
25
25
from ..extras import logging
26
- from ..extras .constants import IMAGE_PLACEHOLDER
26
+ from ..extras .constants import AUDIO_PLACEHOLDER , IMAGE_PLACEHOLDER , VIDEO_PLACEHOLDER
27
27
from ..extras .misc import is_env_enabled
28
28
from ..extras .packages import is_fastapi_available , is_pillow_available , is_requests_available
29
29
from .common import dictify , jsonify
56
56
57
57
if TYPE_CHECKING :
58
58
from ..chat import ChatModel
59
- from ..data .mm_plugin import ImageInput
59
+ from ..data .mm_plugin import AudioInput , ImageInput , VideoInput
60
60
from .protocol import ChatCompletionRequest , ScoreEvaluationRequest
61
61
62
62
72
72
73
73
def _process_request (
74
74
request : "ChatCompletionRequest" ,
75
- ) -> tuple [list [dict [str , str ]], Optional [str ], Optional [str ], Optional [list ["ImageInput" ]]]:
75
+ ) -> tuple [
76
+ list [dict [str , str ]],
77
+ Optional [str ],
78
+ Optional [str ],
79
+ Optional [list ["ImageInput" ]],
80
+ Optional [list ["VideoInput" ]],
81
+ Optional [list ["AudioInput" ]],
82
+ ]:
76
83
if is_env_enabled ("API_VERBOSE" , "1" ):
77
84
logger .info_rank0 (f"==== request ====\n { json .dumps (dictify (request ), indent = 2 , ensure_ascii = False )} " )
78
85
@@ -88,7 +95,7 @@ def _process_request(
88
95
raise HTTPException (status_code = status .HTTP_400_BAD_REQUEST , detail = "Only supports u/a/u/a/u..." )
89
96
90
97
input_messages = []
91
- images = []
98
+ images , videos , audios = [], [], []
92
99
for i , message in enumerate (request .messages ):
93
100
if i % 2 == 0 and message .role not in [Role .USER , Role .TOOL ]:
94
101
raise HTTPException (status_code = status .HTTP_400_BAD_REQUEST , detail = "Invalid role" )
@@ -107,7 +114,7 @@ def _process_request(
107
114
for input_item in message .content :
108
115
if input_item .type == "text" :
109
116
text_content += input_item .text
110
- else :
117
+ elif input_item . type == "image_url" :
111
118
text_content += IMAGE_PLACEHOLDER
112
119
image_url = input_item .image_url .url
113
120
if re .match (r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$" , image_url ): # base64 image
@@ -118,6 +125,28 @@ def _process_request(
118
125
image_stream = requests .get (image_url , stream = True ).raw
119
126
120
127
images .append (Image .open (image_stream ).convert ("RGB" ))
128
+ elif input_item .type == "video_url" :
129
+ text_content += VIDEO_PLACEHOLDER
130
+ video_url = input_item .video_url .url
131
+ if os .path .isfile (video_url ): # local file
132
+ video_stream = open (video_url , "rb" )
133
+ else : # web uri
134
+ video_stream = requests .get (video_url , stream = True ).raw
135
+
136
+ videos .append (video_stream )
137
+ elif input_item .type == "audio_url" :
138
+ text_content += AUDIO_PLACEHOLDER
139
+ audio_url = input_item .audio_url .url
140
+ if os .path .isfile (audio_url ): # local file
141
+ audio_stream = open (audio_url , "rb" )
142
+ else : # web uri
143
+ audio_stream = requests .get (audio_url , stream = True ).raw
144
+
145
+ audios .append (audio_stream )
146
+ else :
147
+ raise HTTPException (
148
+ status_code = status .HTTP_400_BAD_REQUEST , detail = f"Invalid input type { input_item .type } ."
149
+ )
121
150
122
151
input_messages .append ({"role" : ROLE_MAPPING [message .role ], "content" : text_content })
123
152
else :
@@ -132,7 +161,7 @@ def _process_request(
132
161
else :
133
162
tools = None
134
163
135
- return input_messages , system , tools , images or None
164
+ return input_messages , system , tools , images or None , videos or None , audios or None
136
165
137
166
138
167
def _create_stream_chat_completion_chunk (
@@ -151,12 +180,14 @@ async def create_chat_completion_response(
151
180
request : "ChatCompletionRequest" , chat_model : "ChatModel"
152
181
) -> "ChatCompletionResponse" :
153
182
completion_id = f"chatcmpl-{ uuid .uuid4 ().hex } "
154
- input_messages , system , tools , images = _process_request (request )
183
+ input_messages , system , tools , images , videos , audios = _process_request (request )
155
184
responses = await chat_model .achat (
156
185
input_messages ,
157
186
system ,
158
187
tools ,
159
188
images ,
189
+ videos ,
190
+ audios ,
160
191
do_sample = request .do_sample ,
161
192
temperature = request .temperature ,
162
193
top_p = request .top_p ,
@@ -202,7 +233,7 @@ async def create_stream_chat_completion_response(
202
233
request : "ChatCompletionRequest" , chat_model : "ChatModel"
203
234
) -> AsyncGenerator [str , None ]:
204
235
completion_id = f"chatcmpl-{ uuid .uuid4 ().hex } "
205
- input_messages , system , tools , images = _process_request (request )
236
+ input_messages , system , tools , images , videos , audios = _process_request (request )
206
237
if tools :
207
238
raise HTTPException (status_code = status .HTTP_400_BAD_REQUEST , detail = "Cannot stream function calls." )
208
239
@@ -217,6 +248,8 @@ async def create_stream_chat_completion_response(
217
248
system ,
218
249
tools ,
219
250
images ,
251
+ videos ,
252
+ audios ,
220
253
do_sample = request .do_sample ,
221
254
temperature = request .temperature ,
222
255
top_p = request .top_p ,
0 commit comments