19
19
20
20
import fire
21
21
from peft import PeftModel
22
- from transformers import AutoModel , AutoProcessor , AutoTokenizer
22
+ from transformers import AutoModel , AutoProcessor , AutoTokenizer , Qwen2_5OmniThinkerForConditionalGeneration
23
23
24
24
25
25
def merge_lora (
@@ -31,7 +31,7 @@ def merge_lora(
31
31
):
32
32
"""Load the original model, tokenizer, and processor configuration, merge the LoRA weights.
33
33
34
- for a specified submodule, and save the final merged model along with its configurations.
34
+ For a specified submodule, and save the final merged model along with its configurations.
35
35
36
36
Args:
37
37
base_model_path (str): Path to the original model directory.
@@ -86,5 +86,47 @@ def merge_lora(
86
86
print (f"File '{ extra_file } ' not found in { base_model_path } , skipping copy." )
87
87
88
88
89
def save_full_model(
    saved_thinker_path: str,
    base_model_path: str,
    save_path: str,
    extra_file: str = "spk_dict.pt",
):
    """Assemble a complete model from a fine-tuned thinker and save it.

    Loads the saved thinker weights and the original full model (both on CPU),
    replaces the original thinker submodule with the fine-tuned one, then
    writes the resulting model plus its tokenizer and processor configuration
    to ``save_path``. An auxiliary file (e.g. the speaker dictionary) is
    copied over when present in the base model directory.

    Args:
        saved_thinker_path (str): Path to the saved thinker weights.
        base_model_path (str): Directory path of the original model.
        save_path (str): Directory where the final complete model will be saved.
        extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
    """
    # Keep everything on CPU so the merge does not require GPU memory.
    fine_tuned_thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
        saved_thinker_path, device_map="cpu"
    )
    full_model = AutoModel.from_pretrained(base_model_path, device_map="cpu")

    # Swap the fine-tuned thinker into the original model.
    full_model.thinker = fine_tuned_thinker

    # Tokenizer/processor configuration comes from the base model directory.
    tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)

    # Persist model first, then tokenizer, then processor (same order as before).
    for artifact in (full_model, tokenizer, processor):
        artifact.save_pretrained(save_path)
    print(f"Complete model, tokenizer, and processor configuration have been saved to {save_path}.")

    # Best-effort copy of the auxiliary file; absence is not an error.
    source_file = os.path.join(base_model_path, extra_file)
    target_file = os.path.join(save_path, extra_file)
    if os.path.exists(source_file):
        shutil.copy(source_file, target_file)
        print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
    else:
        print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")
89
131
if __name__ == "__main__":
    # Expose both entry points as fire sub-commands:
    #   python <script>.py merge_lora ...   /   python <script>.py save_full ...
    # Fix: the command-dict key previously had a stray leading space
    # (" merge_lora"), which would have registered the CLI sub-command with a
    # leading space in its name and made it effectively unreachable.
    fire.Fire({"save_full": save_full_model, "merge_lora": merge_lora})
0 commit comments