Is there a way to merge the model after using P-Tuning? #2538

Open
1 of 4 tasks
lixd opened this issue May 9, 2025 · 3 comments

lixd commented May 9, 2025

System Info

peft 0.13.2
transformers 4.46.3

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

Here is my code:

# File: peft_p_tuning.py
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments
)
from peft import PromptEncoderConfig, get_peft_model, PeftModel

'''
!pip install datasets==3.1.0 torch==2.4.1 transformers==4.46.3 peft==0.13.2
'''

# Fixed parameters (edit them directly here)
CONFIG = {
    # Model parameters
    "model_path": "/mnt/e015a2b7cb4b49f18419022d3fb045ec/models/Qwen2-1___5B-Instruct",  # local model path
    "data_path": "/mnt/e015a2b7cb4b49f18419022d3fb045ec/dataset/identity.json",  # local dataset path
    "output_dir": "/mnt/e015a2b7cb4b49f18419022d3fb045ec/demo/peft/prompt_output",  # adapter output directory

    # Training parameters
    "max_length": 256,  # maximum text length
    "batch_size": 1,  # adjust to your GPU memory (2 works with 24 GB)
    "grad_accum_steps": 8,  # gradient accumulation steps
    "learning_rate": 2e-3,  # learning rate
    "num_epochs": 5,  # number of training epochs

    # P-Tuning parameters (the key part)
    "num_virtual_tokens": 100,        # number of virtual tokens (can stay the same)
    "encoder_num_layers": 4,         # number of encoder layers
    "encoder_hidden_size": 1024,     # encoder hidden size
    "encoder_reparam_type": "MLP",   # encoder type: MLP or LSTM
}


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'device: {device}')

    # 1. Load the model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        CONFIG["model_path"],
        torch_dtype=torch.bfloat16,
        #device_map="auto"
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_path"])
    tokenizer.pad_token = tokenizer.eos_token  # set the padding token

    # 2. Configure P-Tuning (the key part)
    peft_config = PromptEncoderConfig(
        task_type="CAUSAL_LM",
        num_virtual_tokens=CONFIG["num_virtual_tokens"],
        # encoder_num_layers=CONFIG["encoder_num_layers"],
        encoder_reparameterization_type=CONFIG["encoder_reparam_type"],
        encoder_hidden_size=CONFIG["encoder_hidden_size"]
    )

    model = get_peft_model(model, peft_config)
    # show the trainable parameters
    model.print_trainable_parameters()

    # 3. Prepare the dataset
    def format_data(sample):
        # Build the conversation format
        prompt = f"Human: {sample['instruction']}\nInput: {sample['input']}\nAssistant: "
        response = f"{sample['output']}{tokenizer.eos_token}"

        # Tokenize prompt and response together
        full_text = prompt + response
        tokenized = tokenizer(
            full_text,
            max_length=CONFIG["max_length"],
            truncation=True,
            padding="max_length"
        )

        # Build labels (compute the loss only on the response)
        prompt_len = len(tokenizer(prompt).input_ids)
        tokenized["labels"] = [-100] * prompt_len + tokenized.input_ids[prompt_len:]
        return tokenized

    dataset = load_dataset("json", data_files=CONFIG["data_path"], split="train")
    dataset = dataset.map(format_data, remove_columns=dataset.column_names)

    # 4. Configure the trainer
    training_args = TrainingArguments(
        output_dir=CONFIG["output_dir"],
        per_device_train_batch_size=CONFIG["batch_size"],
        gradient_accumulation_steps=CONFIG["grad_accum_steps"],
        learning_rate=CONFIG["learning_rate"],
        num_train_epochs=CONFIG["num_epochs"],
        logging_steps=10,
        save_strategy="no",  # disable checkpoint saving for simplicity
        fp16=True,  # enable mixed precision
        report_to="none"  # disable wandb and other reporting

    )

    # 5. Train
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=DataCollatorForSeq2Seq(tokenizer, padding=True)
    )
    print("===== 开始训练 =====")
    trainer.train()
    print("===== 训练完成 =====")

    # 6. Save the final model (adapter)
    model.save_pretrained(CONFIG['output_dir'])
    print(f"Model saved to {CONFIG['output_dir']}")


    # 7. Reload the model
    base_model = AutoModelForCausalLM.from_pretrained(
        CONFIG["model_path"],
        torch_dtype=torch.bfloat16,
        #device_map="auto"
    ).to(device)
    loaded_model = PeftModel.from_pretrained(base_model, CONFIG['output_dir'])
    # Prepare the input
    tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_path"])
    tokenizer.pad_token = tokenizer.eos_token  # set the padding token

    # 8. Inference
    prompt = "who are you ?"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    loaded_model.eval()
    # outputs = loaded_model.generate(**inputs, max_length=1024)
    outputs = loaded_model.generate(
        **inputs,
        max_new_tokens=256,          # limit generation length
        temperature=0.9,             # increase diversity
        repetition_penalty=1.2,      # discourage repetition
        top_p=0.9,                   # nucleus sampling
        do_sample=True
    )
    print(f'prompt: {prompt} outputs: {tokenizer.decode(outputs[0], skip_special_tokens=True)}')


if __name__ == "__main__":
    main()

When I use the P-Tuning model for inference, I need to run it like this:

    base_model = AutoModelForCausalLM.from_pretrained(
        CONFIG["model_path"],
        torch_dtype=torch.bfloat16,
        #device_map="auto"
    ).to(device)
    loaded_model = PeftModel.from_pretrained(base_model, CONFIG['output_dir'])
    # Prepare the input
    tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_path"])
    tokenizer.pad_token = tokenizer.eos_token  # set the padding token

I have to load the base model first and then call PeftModel.from_pretrained(base_model, CONFIG['output_dir']) to load the P-Tuning output.
Is there a way to merge the P-Tuning output into the base model to produce a new, standalone model?

Expected behavior

Merge the model, the way LoRA adapters can be merged.
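
(For reference, the LoRA-style merge being referred to looks like this in PEFT; merge_and_unload is the standard LoRA API, and the adapter/output paths below are only placeholders:)

from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(CONFIG["model_path"])
# "lora_output" is a placeholder for a directory containing a trained LoRA adapter
lora_model = PeftModel.from_pretrained(base, "lora_output")
merged = lora_model.merge_and_unload()  # folds the LoRA deltas into the base weights
merged.save_pretrained("merged_model")  # plain transformers model, no PEFT needed to load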

githubnemo (Collaborator) commented:

As far as I know there is no way to merge P-Tuning virtual tokens into the model, and I don't see how that could work either. The nature of this method is that new virtual tokens are prepended to the embedded input, i.e. we are dealing with activations, not weights.
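
To make this concrete, here is a rough sketch of what a prompt-learning PeftModel effectively does at inference time (simplified, not the literal PEFT internals; get_prompt, get_input_embeddings and the inputs_embeds argument are existing PEFT/transformers APIs):

import torch

def forward_with_soft_prompt(peft_model, base_model, input_ids):
    # Trained virtual-token embeddings, shape (batch, num_virtual_tokens, hidden)
    virtual_embeds = peft_model.get_prompt(batch_size=input_ids.shape[0])
    # Ordinary embeddings of the real input tokens
    token_embeds = base_model.get_input_embeddings()(input_ids)
    # The virtual tokens are prepended to the activations; unlike LoRA, there is
    # no weight delta that could be folded back into the base model.
    inputs_embeds = torch.cat([virtual_embeds.to(token_embeds.dtype), token_embeds], dim=1)
    return base_model(inputs_embeds=inputs_embeds)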

What kind of problem are you trying to solve with merging?

lixd (Author) commented May 13, 2025

@githubnemo thanks for your reply. I'm looking for an easy way to save the model and run inference with other components, like vLLM.

githubnemo (Collaborator) commented:

From a quick search it seems that vLLM does not support soft prompting directly.

However, I found that it supports passing embeddings when generating. In theory you could use that to pass the learned embeddings to the vLLM instance for inference.

A quick sketch based on your initial code (no guarantees):

loaded_model = PeftModel.from_pretrained(base_model, CONFIG['output_dir'])
tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_path"])

prompt = "who are you?"
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# (sketch) pull the trained soft-prompt embeddings out of the prompt encoder
soft_prompt_embeds = loaded_model.prompt_encoder['default'](inputs)

# ... setup vllm llm ...

# (sketch) pass the embeddings, rather than token ids, to vLLM
llm.generate([{"prompt_embeds": soft_prompt_embeds}], sampling_params)
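
A slightly fuller (still unverified) variant of the same idea, reusing CONFIG from the script above: PEFT's get_prompt() returns the trained virtual-token embeddings, which would then be concatenated with the embedded real prompt before being handed to vLLM. Everything vLLM-specific here (enable_prompt_embeds, the prompt_embeds input format) is an assumption based on the sketch above and may differ between vLLM versions:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from vllm import LLM, SamplingParams  # assumed vLLM API, check your version

base_model = AutoModelForCausalLM.from_pretrained(CONFIG["model_path"], torch_dtype=torch.bfloat16)
loaded_model = PeftModel.from_pretrained(base_model, CONFIG["output_dir"])
tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_path"])

prompt = "who are you?"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

with torch.no_grad():
    # Learned soft prompt, shape (1, num_virtual_tokens, hidden_size)
    virtual_embeds = loaded_model.get_prompt(batch_size=1)
    # Embeddings of the real prompt tokens
    token_embeds = base_model.get_input_embeddings()(input_ids)
    full_embeds = torch.cat([virtual_embeds.to(token_embeds.dtype), token_embeds], dim=1)

# Hypothetical vLLM usage mirroring the sketch above; enable_prompt_embeds and the
# prompt_embeds input are assumptions about the vLLM version in use.
llm = LLM(model=CONFIG["model_path"], enable_prompt_embeds=True)
sampling_params = SamplingParams(max_tokens=256, temperature=0.9, top_p=0.9)
outputs = llm.generate([{"prompt_embeds": full_embeds.squeeze(0)}], sampling_params)
print(outputs[0].outputs[0].text)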
