@@ -1179,7 +1179,9 @@ def _get_mm_inputs(
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt"))
-            mm_inputs["fps_per_video"] = video_data["fps_per_video"]
+            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
+            if "second_per_grid_ts" in processor.model_input_names:
+                mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]]
 
         return mm_inputs
 
@@ -1238,28 +1240,6 @@ def process_messages(
 
         return messages
 
-    @override
-    def get_mm_inputs(
-        self,
-        images: list["ImageInput"],
-        videos: list["VideoInput"],
-        audios: list["AudioInput"],
-        imglens: list[int],
-        vidlens: list[int],
-        audlens: list[int],
-        batch_ids: list[list[int]],
-        processor: Optional["MMProcessor"],
-    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
-        self._validate_input(processor, images, videos, audios)
-        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
-        fps_per_video = mm_inputs.pop("fps_per_video", [])
-        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
-        temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
-        if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
-            mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in fps_per_video]
-
-        return mm_inputs
-
 
 class Qwen2OmniPlugin(Qwen2VLPlugin):
     @override
@@ -1290,7 +1270,10 @@ def _get_mm_inputs(
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt"))
-            mm_inputs["fps_per_video"] = video_dict["fps_per_video"]
+            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
+            mm_inputs["video_second_per_grid"] = torch.tensor(
+                [temporal_patch_size / fps for fps in video_dict["fps_per_video"]]
+            )
 
         if len(audios) != 0:
             audios = self._regularize_audios(
@@ -1405,7 +1388,7 @@ def process_messages(
                             video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
                         )
                         .flatten()
-                        * mm_inputs["second_per_grid_ts"][num_video_tokens]
+                        * mm_inputs["video_second_per_grid"][num_video_tokens]
                         * 25  # FIXME hardcode of position_id_per_seconds=25
                     ).long()
                     t_ntoken_per_chunk = 50  # FIXME hardcode: [25 * 2]
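
Context for the change: all three hunks compute the same quantity, the number of seconds of source video covered by one temporal grid step. Qwen2-VL-style image processors merge `temporal_patch_size` consecutive sampled frames into a single temporal patch, so a video sampled at `fps` frames per second spans `temporal_patch_size / fps` seconds per grid step. Qwen2-VL passes this as the `second_per_grid_ts` processor input, while Qwen2-Omni stores it as a `video_second_per_grid` tensor and later multiplies by its 25 position ids per second. A minimal standalone sketch of the conversion (hypothetical helper, not part of this patch):

    def seconds_per_grid_step(temporal_patch_size: int, fps_per_video: list[float]) -> list[float]:
        # temporal_patch_size sampled frames are merged into one temporal grid
        # step, so each step covers temporal_patch_size / fps seconds of video
        return [temporal_patch_size / fps for fps in fps_per_video]

    print(seconds_per_grid_step(2, [2.0, 4.0]))  # [1.0, 0.5]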