数据集拼接#
简介#
对于大型语言模型(LLM)的输入而言,“数据集拼接” 这一概念指的是将多个 token 序列拼接成一个单独的输入。大量的数据集都存在一个特点,即其长度分布严重偏向较短的序列,而 Transformers 模型接收固定长度的输入。因此,在模型训练过程中,通常需要将每条数据 “Pad” 至当前 batch 最长序列的长度,而 “Pad Token” 往往是某个特定的无意义的 token。
将多条数据打包在一起可以不再需要使用 “Pad Token” 进行无意义的填充,减少计算资源的浪费,同时还可以保持模型作为具有固定大小输入的静态图表示的优点。
下表展示了 InternLM2 7B 模型在 Alpaca 数据集上使用不同数据集拼接策略进行训练的速度对比,如表所示,“数据集拼接”会大幅度提升训练效率:
拼接策略 |
每秒处理 token 数 |
加速比 |
---|---|---|
不使用 |
362.9 |
|
拼接至 2k |
2677.1 |
7.38x |
拼接至 4k |
3124.3 |
8.61x |
拼接至 8k |
3173.9 |
8.76x |
拼接至 16k |
2864.4 |
7.89x |
拼接至 32k |
2965.4 |
8.17x |
在 XTuner 中使用数据拼接#
XTuner 中提供的 config 文件中默认使用了“数据集拼接”这一功能,可以通过设置 max_length
字段来调整数据拼接长度。例如可通过以下方式将拼接长度调整为 32k :
#######################################################################
# PART 1 Settings #
#######################################################################
- max_length = 2048
+ max_length = 32768
pack_to_max_length = True
#######################################################################
# PART 3 Dataset & Dataloader #
#######################################################################
train_dataset = dict(
max_length=max_length,
pack_to_max_length=pack_to_max_length,
...)
若不想使用数据拼接,在 config 中将 pack_to_max_length
设为 False 即可,
此时 config 中的 max_length
字段表示单条数据最长的 token 数,整个 batch 会被 pad 成当前 batch 内最长的一条数据的长度。
同时,XTuner 支持一种数据集采样策略 (LengthGroupedSampler
),在不使用数据拼接策略时可以保证在一个 batch 中的数据长度尽可能接近,
以减少 Pad 对计算资源的浪费。详细用法请参考
LengthGroupedSampler 文档 。