清水泥沙

Transformer 学习之路 - 掩码语言模型训练与实战

Transformer 学习之路 - 掩码语言模型训练与实战

掩码语言模型(Masked Language Model,MLM)是自然语言处理(NLP)领域中的一项重要技术,广泛应用于文本生成、文本分类、问答系统等任务。本文将从原理、应用及实战训练三个方面,深入解析 MLM 技术,并结合代码示例帮助读者理解并应用。


掩码语言模型的工作原理

1. 输入处理

MLM 的核心思想是通过随机掩盖句子中的部分单词,训练模型来预测这些被掩盖的单词。例如,给定句子 “The cat sat on the mat.",模型可能会将其处理为 “The cat [MASK] on the mat.",然后尝试预测被掩盖的单词 “sat”。

2. 模型训练

模型接收包含掩码的句子作为输入,并输出预测的单词。通过这种方式,模型能够学习到单词之间的上下文关系和语义。

3. 损失计算

模型的预测结果与实际被掩盖的单词进行比较,计算损失函数,并通过反向传播优化模型参数。


掩码语言模型的应用

MLM 在许多 NLP 任务中都有广泛应用:

  • 文本生成:生成缺失的文本部分。
  • 文本分类:为文本赋予标签。
  • 问答系统:根据上下文回答问题。

最著名的 MLM 模型是 BERT(Bidirectional Encoder Representations from Transformers),它通过 MLM 技术进行训练,能够有效捕捉语言的上下文信息。


实战训练:从数据到模型

接下来,我们将通过代码示例,一步步实现一个 MLM 模型的训练过程。

Step1 导入相关包

from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling, TrainingArguments, Trainer

Step2 加载数据集

ds = Dataset.load_from_disk("/content/drive/MyDrive/ai-learning/2.NLP Task/07-language_model/wiki_cn_filtered/")

Step3 数据预处理

tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-macbert-base")

def process_func(examples):
    return tokenizer(examples["completion"], max_length=384, truncation=True)

tokenized_ds = ds.map(process_func, batched=True, remove_columns=ds.column_names)

Step4 创建模型

model = AutoModelForMaskedLM.from_pretrained("hfl/chinese-macbert-base")

Step5 配置训练参数

import os
os.environ["WANDB_DISABLED"] = "true"

args = TrainingArguments(
    output_dir="./output",
    logging_dir=None,
    per_device_train_batch_size=16,
    num_train_epochs=3,
)

Step6 创建 Trainer

trainer = Trainer(
    args=args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=tokenized_ds,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15)
)

Step7 训练模型

for param in model.parameters():
    if not param.is_contiguous():
        param.data = param.contiguous()

trainer.train()

Step8 评估模型

from transformers import pipeline

pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, device=0)

print(pipe("西安交通[MASK][MASK]博物馆(Xi'an Jiaotong University Museum)是一座位于西安交通大学的博物馆"))
print(pipe("下面是一则[MASK][MASK]新闻。小编报道,近日,游戏产业发展的非常好!"))

Step9 保存模型

model_save_path = "/content/drive/MyDrive/ai-learning/2.NLP Task/07-language_model/masked"
tokenizer.save_pretrained(model_save_path)
model.save_pretrained(model_save_path)

总结

通过本文的学习,我们深入了解了掩码语言模型的工作原理及其在 NLP 任务中的应用。结合代码示例,我们还实现了一个完整的 MLM 模型训练流程。希望本文能帮助读者更好地理解 Transformer 技术,并将其应用到实际项目中。

如果你对 Transformer 和 MLM 技术有更多兴趣,不妨动手实践,探索更多可能性!