清水泥沙

Transformer 学习之路 - 数据集加载与预处理

Transformer 学习之路 - 数据集加载与预处理

在 Transformer 模型的训练过程中,数据集的加载和预处理是至关重要的一步。本文将详细介绍如何使用 datasets 库来加载、处理和管理数据集,确保数据能够高效地输入到模型中。

1. 安装与导入 datasets

首先,我们需要安装并导入 datasets 库。这个库提供了丰富的功能,可以轻松地加载和处理各种数据集。

!pip install datasets
from datasets import *

2. 加载数据集

datasets 库支持从线上和线下加载数据集。无论是公开的数据集还是自定义的数据集,都可以通过简单的代码实现加载。

2.1 加载线上数据集

datasets = load_dataset("madao33/new-title-chinese")
datasets

2.2 加载特定任务的数据集

boolq_dataset = load_dataset("super_glue", "boolq")
boolq_dataset

3. 数据集切分

在实际应用中,我们通常需要将数据集划分为训练集、验证集和测试集。datasets 库提供了多种切分方式,满足不同的需求。

3.1 加载部分数据

# 只加载训练数据
dataset = load_dataset("madao33/new-title-chinese", split="train")
dataset
# 只加载训练数据的10到100下标的数据
dataset = load_dataset("madao33/new-title-chinese", split="train[10:100]")
dataset
# 取训练集中后50%数据
dataset = load_dataset("madao33/new-title-chinese", split="train[:50%]")
dataset
# 先取后50%再取前50%
dataset = load_dataset("madao33/new-title-chinese", split=["train[:50%]", "train[50%:]"])
dataset

4. 查看数据集

在加载数据集后,我们通常需要查看数据的结构和内容,以便更好地理解数据。

4.1 查看数据集中的具体样本

# 查看训练集第0个样本
datasets["train"][0]
# 查看训练集前两个样本
datasets["train"][:2]
# 获取前五个标题字段
datasets["train"]["title"][:5]

4.2 查看数据集的列信息

# 查看列名
datasets["train"].column_names
# 查看列详情
datasets["train"].features

5. 数据集划分

为了评估模型的性能,我们通常需要将数据集划分为训练集和测试集。

5.1 随机划分数据集

# 划分数据集
dataset = datasets['train']
dataset.train_test_split(test_size=0.1)

5.2 按字段均衡划分数据集

# 按数据中的字段来均衡划分
dataset = boolq_dataset["train"]
dataset.train_test_split(test_size=0.1, stratify_by_column="label")

6. 数据选取和过滤

在处理数据集时,我们可能需要根据特定条件选取或过滤数据。

6.1 选取数据

# 选取
datasets["train"].select([0, 1])

6.2 过滤数据

# 过滤
filter_dataset = datasets["train"].filter(lambda example: "中国" in example["title"])
filter_dataset["title"][:5]

7. 数据映射

数据映射是指对数据集中的每个样本进行某种操作,例如添加前缀或进行编码。

7.1 添加前缀

def add_prefix(example):
    example["title"] = 'Prefix: ' + example["title"]
    return example

# 经过映射方法处理的数据集
prefix_dataset = datasets.map(add_prefix)
prefix_dataset["train"][:10]["title"]

7.2 使用 Tokenizer 进行编码

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

def preprocess_function(example, tokenizer=tokenizer):
    model_inputs = tokenizer(example["content"], max_length=512, truncation=True)
    labels = tokenizer(example["title"], max_length=32, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

processed_datasets = datasets.map(preprocess_function)
processed_datasets

7.3 多进程处理数据

# 开启多进程处理数据,提高处理速度
processed_datasets = datasets.map(preprocess_function, num_proc=4)
processed_datasets

7.4 批量处理数据

# 批量处理数据,去除多余的列
processed_datasets = datasets.map(preprocess_function, batched=True, remove_columns=datasets["train"].column_names)
processed_datasets

8. 保存和加载数据集

处理后的数据集可以保存到磁盘,以便后续使用。

8.1 保存数据集

processed_datasets.save_to_disk("./processed_data")

8.2 加载数据集

processed_datasets = load_from_disk("./processed_data")
processed_datasets

9. 加载本地文件作为数据集

datasets 库还支持加载本地文件作为数据集,例如 CSV、JSON 等格式。

9.1 加载 CSV 文件

csv_path = "/content/drive/MyDrive/Colab Notebooks/transformers-code/01-Getting Started/05-datasets/ChnSentiCorp_htl_all.csv"
dataset = load_dataset("csv", data_files=csv_path, split="train")
dataset

9.2 加载文件夹内全部文件

dataset = load_dataset("csv", data_files=["/content/drive/MyDrive/Colab Notebooks/transformers-code/01-Getting Started/05-datasets/all_data/ChnSentiCorp_htl_all copy 2.csv", "/content/drive/MyDrive/Colab Notebooks/transformers-code/01-Getting Started/05-datasets/all_data/ChnSentiCorp_htl_all copy 2.csv"], split='train')
dataset

10. 通过自定义脚本加载数据集

对于特殊格式的数据集,我们可以通过自定义脚本来加载和处理。

dir_path = "/content/drive/MyDrive/Colab Notebooks/transformers-code/01-Getting Started/05-datasets/"
script_path = dir_path+"load_script.py"
dataset = load_dataset(script_path, split="train")
dataset

11. 使用 DataCollator 进行数据批处理

在训练模型时,我们通常需要将数据批处理为固定大小的张量。DataCollator 可以帮助我们实现这一目标。

11.1 使用 DataCollatorWithPadding

from transformers import DataCollatorWithPadding

dataset = load_dataset("csv", data_files=csv_path, split='train')
dataset = dataset.filter(lambda x: x["review"] is not None)

def process_function(examples):
    tokenized_examples = tokenizer(examples["review"], max_length=128, truncation=True)
    tokenized_examples["labels"] = examples["label"]
    return tokenized_examples

tokenized_dataset = dataset.map(process_function, batched=True, remove_columns=dataset.column_names)
collator = DataCollatorWithPadding(tokenizer=tokenizer)

from torch.utils.data import DataLoader
dl = DataLoader(tokenized_dataset, batch_size=4, collate_fn=collator, shuffle=True)

num = 0
for batch in dl:
    print(batch["input_ids"].size())
    num += 1
    if num > 10:
        break

12. 总结

通过 datasets 库,我们可以轻松地加载、处理和