Skip to content

Commit 80f8d03

Browse files
authored
[data] fix qwen2.5 omni plugin (#7573)
* align mm_inputs key with qwen2vl
* minor cleanups and script changes
1 parent 1199759 commit 80f8d03

File tree

4 files changed

+47
-6
lines changed

4 files changed

+47
-6
lines changed

scripts/lora_part_merge.py → scripts/qwen_omni_merge.py (renamed)

+45-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import fire
2121
from peft import PeftModel
22-
from transformers import AutoModel, AutoProcessor, AutoTokenizer
22+
from transformers import AutoModel, AutoProcessor, AutoTokenizer, Qwen2_5OmniThinkerForConditionalGeneration
2323

2424

2525
def merge_lora(
@@ -31,7 +31,7 @@ def merge_lora(
3131
):
3232
"""Load the original model, tokenizer, and processor configuration, merge the LoRA weights.
3333
34-
for a specified submodule, and save the final merged model along with its configurations.
34+
For a specified submodule, and save the final merged model along with its configurations.
3535
3636
Args:
3737
base_model_path (str): Path to the original model directory.
@@ -86,5 +86,47 @@ def merge_lora(
8686
print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")
8787

8888

89+
def save_full_model(
    saved_thinker_path: str,
    base_model_path: str,
    save_path: str,
    extra_file: str = "spk_dict.pt",
):
    """Rebuild a complete Omni model from separately saved thinker weights.

    Loads the thinker submodule from ``saved_thinker_path``, plugs it into the
    original model loaded from ``base_model_path``, and writes the resulting
    complete model together with its tokenizer and processor configuration to
    ``save_path``.

    Args:
        saved_thinker_path (str): Path to the saved thinker weights.
        base_model_path (str): Directory path of the original model.
        save_path (str): Directory where the final complete model will be saved.
        extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
    """
    # Load the original model on CPU, then swap in the fine-tuned thinker.
    base_model = AutoModel.from_pretrained(base_model_path, device_map="cpu")
    base_model.thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
        saved_thinker_path, device_map="cpu"
    )

    # Tokenizer and processor configuration come straight from the base checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)

    # Persist the merged model plus its tokenizer/processor configuration.
    base_model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    processor.save_pretrained(save_path)
    print(f"Complete model, tokenizer, and processor configuration have been saved to {save_path}.")

    # The extra file (presumably the speaker dictionary) is not handled by
    # save_pretrained, so copy it over manually when present.
    source_file = os.path.join(base_model_path, extra_file)
    target_file = os.path.join(save_path, extra_file)
    if os.path.exists(source_file):
        shutil.copy(source_file, target_file)
        print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
    else:
        print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")
129+
130+
89131
if __name__ == "__main__":
    # Expose both entry points as fire subcommands (key order preserved).
    commands = {"save_full": save_full_model, "merge_lora": merge_lora}
    fire.Fire(commands)

src/llamafactory/data/collator.py

-1
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,6 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
203203

204204
delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)
205205
# avoid conflict
206-
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid", None)
207206
new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
208207
features["position_ids"], features["rope_deltas"] = (
209208
new_position_ids.clone(),

src/llamafactory/data/mm_plugin.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1405,7 +1405,7 @@ def process_messages(
14051405
video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
14061406
)
14071407
.flatten()
1408-
* mm_inputs["video_second_per_grid"][num_video_tokens]
1408+
* mm_inputs["second_per_grid_ts"][num_video_tokens]
14091409
* 25 # FIXME hardcode of position_id_per_seconds=25
14101410
).long()
14111411
t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2]

src/llamafactory/model/loader.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def load_model(
157157
model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
158158
else:
159159
model = load_class.from_pretrained(**init_kwargs)
160-
if load_class is AutoModelForTextToWaveform:
160+
if getattr(model.config, "model_type", None) == "qwen2_5_omni":
161161
model = model.thinker # use part of Omni model
162162

163163
if model_args.mixture_of_depths == "convert":

0 commit comments

Comments
 (0)