数据集拼接#

简介#

对于大型语言模型(LLM)的输入而言,“数据集拼接” 这一概念指的是将多个 token 序列拼接成一个单独的输入。大量的数据集都存在一个特点,即其长度分布严重偏向较短的序列,而 Transformers 模型接收固定长度的输入。因此,在模型训练过程中,通常需要将每条数据 “Pad” 至当前 batch 最长序列的长度,而 “Pad Token” 往往是某个特定的无意义的 token。

将多条数据打包在一起可以不再需要使用 “Pad Token” 进行无意义的填充,减少计算资源的浪费,同时还可以保持模型作为具有固定大小输入的静态图表示的优点。

下表展示了 InternLM2 7B 模型在 Alpaca 数据集上使用不同数据集拼接策略进行训练的速度对比,如表所示,“数据集拼接”会大幅度提升训练效率:

拼接策略

每秒处理 token 数

加速比

不使用

362.9

拼接至 2k

2677.1

7.38x

拼接至 4k

3124.3

8.61x

拼接至 8k

3173.9

8.76x

拼接至 16k

2864.4

7.89x

拼接至 32k

2965.4

8.17x

在 XTuner 中使用数据拼接#

XTuner 中提供的 config 文件中默认使用了“数据集拼接”这一功能,可以通过设置 max_length 字段来调整数据拼接长度。例如可通过以下方式将拼接长度调整为 32k :

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
- max_length = 2048
+ max_length = 32768
pack_to_max_length = True

#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
train_dataset = dict(
    max_length=max_length,
    pack_to_max_length=pack_to_max_length,
    ...)

若不想使用数据拼接,在 config 中将 pack_to_max_length 设为 False 即可, 此时 config 中的 max_length 字段表示单条数据最长的 token 数,整个 batch 会被 pad 成当前 batch 内最长的一条数据的长度。 同时,XTuner 支持一种数据集采样策略 (LengthGroupedSampler),在不使用数据拼接策略时可以保证在一个 batch 中的数据长度尽可能接近, 以减少 Pad 对计算资源的浪费。详细用法请参考 LengthGroupedSampler 文档