Commit 32cb086

Kuangdd01 and hiyouga authored

[data] fix qwen2.5 omni plugin (#7578)

* specific entry
* Update mm_plugin.py
* fix fps cal

Co-authored-by: hoshi-hiyouga <[email protected]>

1 parent 80f8d03 commit 32cb086
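In short, as read from the diff below: Qwen2VLPlugin._get_mm_inputs now converts each video's sampling rate into second_per_grid_ts (seconds covered by one temporal grid) at feature-extraction time, Qwen2OmniPlugin._get_mm_inputs emits the same ratio as a video_second_per_grid tensor, the collator forwards that tensor to the rope-index computation as second_per_grids, and the get_mm_inputs override that previously performed the conversion is removed.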

2 files changed: +10 -25 lines changed


src/llamafactory/data/collator.py

+2

@@ -192,6 +192,8 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
             }
             if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
                 rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
+            if "video_second_per_grid" in mm_inputs:  # for qwen2omni
+                rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")

             if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2omni
                 feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
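A minimal, runnable sketch of the key mapping added here (the toy tensor is an assumption; in training it comes from Qwen2OmniPlugin._get_mm_inputs shown below):

    import torch

    # toy multimodal inputs, standing in for what the plugin returns for two videos
    mm_inputs = {"video_second_per_grid": torch.tensor([1.0, 0.5])}

    rope_index_kwargs = {}
    if "second_per_grid_ts" in mm_inputs:  # qwen2vl path: a plain list of floats
        rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
    if "video_second_per_grid" in mm_inputs:  # qwen2omni path: a 1-D float tensor
        rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")

    print(rope_index_kwargs)  # {'second_per_grids': tensor([1.0000, 0.5000])}

The rename matters because rope_index_kwargs is later unpacked into the model's rope-index computation, which expects the Omni value under the second_per_grids keyword rather than the processor's video_second_per_grid key.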

src/llamafactory/data/mm_plugin.py

+8 -25

@@ -1179,7 +1179,9 @@ def _get_mm_inputs(
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt"))
-            mm_inputs["fps_per_video"] = video_data["fps_per_video"]
+            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
+            if "second_per_grid_ts" in processor.model_input_names:
+                mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]]

         return mm_inputs
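Why temporal_patch_size / fps: the vision tower stacks temporal_patch_size consecutive sampled frames into one temporal grid, so at a sampling rate of fps frames per second one grid spans temporal_patch_size / fps seconds. A worked sketch with assumed values:

    temporal_patch_size = 2     # the fallback used above when the processor defines none
    fps_per_video = [2.0, 4.0]  # assumed sampling rates for two videos
    second_per_grid_ts = [temporal_patch_size / fps for fps in fps_per_video]
    print(second_per_grid_ts)   # [1.0, 0.5] -> one grid covers 1.0 s and 0.5 s respectively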

@@ -1238,28 +1240,6 @@ def process_messages(
         return messages


-    @override
-    def get_mm_inputs(
-        self,
-        images: list["ImageInput"],
-        videos: list["VideoInput"],
-        audios: list["AudioInput"],
-        imglens: list[int],
-        vidlens: list[int],
-        audlens: list[int],
-        batch_ids: list[list[int]],
-        processor: Optional["MMProcessor"],
-    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
-        self._validate_input(processor, images, videos, audios)
-        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
-        fps_per_video = mm_inputs.pop("fps_per_video", [])
-        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
-        temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
-        if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
-            mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in fps_per_video]
-
-        return mm_inputs
-

 class Qwen2OmniPlugin(Qwen2VLPlugin):
     @override
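With the conversion moved into each plugin's _get_mm_inputs, this override became redundant: the base get_mm_inputs now sees the final keys directly, and the intermediate fps_per_video entry no longer has to be popped out of (or accidentally leak into) the batch.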
@@ -1290,7 +1270,10 @@ def _get_mm_inputs(
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt"))
-            mm_inputs["fps_per_video"] = video_dict["fps_per_video"]
+            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
+            mm_inputs["video_second_per_grid"] = torch.tensor(
+                [temporal_patch_size / fps for fps in video_dict["fps_per_video"]]
+            )

         if len(audios) != 0:
             audios = self._regularize_audios(
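Note the asymmetry with the Qwen2-VL hunk above: the VL path stores the ratios as a plain list under second_per_grid_ts (guarded by processor.model_input_names), while the Omni path always emits a torch tensor under video_second_per_grid; the collator change in collator.py bridges that key onto second_per_grids.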
@@ -1405,7 +1388,7 @@ def process_messages(
                         video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
                     )
                     .flatten()
-                    * mm_inputs["second_per_grid_ts"][num_video_tokens]
+                    * mm_inputs["video_second_per_grid"][num_video_tokens]
                     * 25  # FIXME hardcode of position_id_per_seconds=25
                 ).long()
                 t_ntoken_per_chunk = 50  # FIXME hardcode: [25 * 2]
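This product turns per-grid temporal indices into time-based position ids: grid index x seconds per grid x 25 position ids per second (the hardcoded position_id_per_seconds). A small sketch with assumed values:

    import torch

    video_second_per_grid = 1.0   # assumed: temporal_patch_size=2 sampled at 2 fps
    position_id_per_seconds = 25  # hardcoded above (see the FIXME)

    t_index = (torch.arange(3) * video_second_per_grid * position_id_per_seconds).long()
    print(t_index)  # tensor([ 0, 25, 50]): consecutive grids sit 25 position ids apart

The chunk size t_ntoken_per_chunk = 50 matches the 25 * 2 noted in the comment: at 25 position ids per second, 50 ids correspond to two seconds of video.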
