Transformer 学习之路 - Prefix-Tuning 实战
Transformer 学习之路 - Prefix-Tuning 实战
在深度学习领域,Transformer 模型因其强大的表现力而广受欢迎。然而,随着模型规模的增大,如何高效地微调这些模型以适应特定任务成为了一个挑战。本文将深入探讨一种高效的微调方法——Prefix-Tuning,并展示其在实际应用中的操作步骤。
什么是 Prefix-Tuning?
Prefix-Tuning 是一种针对大型预训练语言模型的高效微调方法。它通过优化特定任务的“前缀”嵌入而不是调整整个模型的参数来实现模型的自适应。在 Prefix-Tuning 中,前缀并不是传统意义上的文本提示词,而是可学习的嵌入向量序列。这些前缀嵌入会与输入文本拼接在一起,引导模型在生成或预测时倾向于某种任务特性。
Prefix-Tuning 的核心思想
与 Prompt Tuning 和 P-tuning 类似,Prefix-Tuning 通过在模型的输入前加入可训练的前缀嵌入来引导模型完成特定任务。然而,Prefix-Tuning 的设计更关注条件生成(例如对话生成、摘要生成等)任务,通过仅优化前缀来保留模型的原始能力,同时减少训练成本。
Prefix-Tuning 的工作流程
-
初始化前缀嵌入:
- Prefix-Tuning 会初始化一组前缀嵌入向量,这些向量的维度与模型的输入嵌入维度相同。
-
前缀嵌入与输入拼接:
- 这些前缀嵌入与输入文本的嵌入表示拼接在一起,构成模型的输入序列。例如,对于输入 “I love this product!” 以及前缀嵌入 [Prefix1]、[Prefix2] 等,最终的输入会是:
[Prefix1] [Prefix2] ... [PrefixN] I love this product!
- 这些前缀嵌入与输入文本的嵌入表示拼接在一起,构成模型的输入序列。例如,对于输入 “I love this product!” 以及前缀嵌入 [Prefix1]、[Prefix2] 等,最终的输入会是:
-
冻结模型参数:
- 为了减少训练开销,Prefix-Tuning 通常会冻结模型原本的参数,仅优化前缀嵌入向量的权重。
-
训练与优化前缀嵌入:
- 在训练过程中,前缀嵌入会逐步更新,以便适应特定的任务需求。通过这种方式,Prefix-Tuning 能引导模型生成或预测时,更加符合目标任务的上下文和期望输出。
Prefix-Tuning 的优势
- 参数高效:Prefix-Tuning 仅优化前缀嵌入而不修改模型的主体参数,这极大地降低了存储和计算成本。
- 任务迁移方便:只需为每个任务训练和保存相应的前缀嵌入,因此适合需要多任务切换的场景。
- 保留模型原始能力:由于不微调整个模型的权重,Prefix-Tuning 在执行特定任务时能保留模型原始的生成能力。
应用场景
Prefix-Tuning 特别适合条件生成任务,如:
- 对话生成:通过前缀嵌入向量引导模型生成具有特定风格或语境的对话。
- 摘要生成:在摘要生成任务中使用前缀嵌入来帮助模型生成简洁准确的摘要。
- 机器翻译:帮助模型适应特定的语言或语境转换需求。
示例
假设有一个对话生成任务,用户输入 “What is the weather like today?”。Prefix-Tuning 的输入可能会是:
[Prefix1] [Prefix2] ... [PrefixN] What is the weather like today?
在训练过程中,这些前缀嵌入会被调整成引导模型产生合理回答的方式,比如提供当前天气情况的句子。
Prefix-Tuning 与其他提示方法的对比
| 特性 | Prefix-Tuning | Prompt Tuning / P-tuning | Fine-Tuning |
|---|---|---|---|
| 参数量 | 少量(只优化前缀嵌入) | 少量(只优化提示词嵌入) | 较大(优化整个模型) |
| 适用场景 | 条件生成任务 | 分类任务、生成任务 | 所有任务 |
| 计算开销 | 较低 | 较低 | 较高 |
| 保留模型能力 | 是 | 是 | 否(调整模型整体参数) |
总结
Prefix-Tuning 是一种专注于条件生成任务的高效微调方法。它通过优化少量前缀嵌入引导模型生成期望输出,而无需调整模型的所有参数。这种方法能够在保持模型生成能力的同时,快速适应特定任务,是对大型预训练模型微调的一种高效替代方案。
实战演练
Step1 导入相关包
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer
Step2 加载数据集
ds = Dataset.load_from_disk("../data/alpaca_data_zh/")
ds
ds[:3]
Step3 数据集预处理
tokenizer = AutoTokenizer.from_pretrained("Langboat/bloom-1b4-zh")
tokenizer
def process_func(example):
MAX_LENGTH = 256
input_ids, attention_mask, labels = [], [], []
instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
response = tokenizer(example["output"] + tokenizer.eos_token)
input_ids = instruction["input_ids"] + response["input_ids"]
attention_mask = instruction["attention_mask"] + response["attention_mask"]
labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
if len(input_ids) > MAX_LENGTH:
input_ids = input_ids[:MAX_LENGTH]
attention_mask = attention_mask[:MAX_LENGTH]
labels = labels[:MAX_LENGTH]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels
}
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
tokenized_ds
tokenizer.decode(tokenized_ds[1]["input_ids"])
tokenizer.decode(list(filter(lambda x: x != -100, tokenized_ds[1]["labels"])))
Step4 创建模型
model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-1b4-zh")
Prefix-tuning
PEFT Step1 配置文件
from peft import PrefixTuningConfig, get_peft_model, TaskType
config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10, prefix_projection=True)
config
PEFT Step2 创建模型
model = get_peft_model(model, config)
model.prompt_encoder
model.print_trainable_parameters()
Step5 配置训练参数
args = TrainingArguments(
output_dir="./chatbot",
per_device_train_batch_size=1,
gradient_accumulation_steps=8,
logging_steps=10,
num_train_epochs=1
)
Step6 创建训练器
trainer = Trainer(
model=model,
args=args,
tokenizer=tokenizer,
train_dataset=tokenized_ds,
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
Step7 模型训练
trainer.train()
Step8 模型推理
model = model.cuda()
ipt = tokenizer("Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: ", return_tensors="pt").to(model.device)
tokenizer.decode(model.generate(**ipt, max_length=128, do_sample=True)[0], skip_special_tokens=True)
通过以上步骤,我们成功地应用了 Prefix-Tuning 技术来微调一个大型预训练语言模型,并展示了其在对话生成任务中的应用。希望这篇文章能帮助你更好地理解并应用 Prefix-Tuning 技术。