使用 Flash Attention 加速训练#
Flash Attention (Flash Attention 2) 是一种用于加速 Transformer 模型中 Attention 计算,并减少其显存消耗的算法。XTuner 中 Flash Attention (Flash Attention 2) 的支持情况如下表所示:
模型 |
Flash Attention 支持情况 |
---|---|
baichuan 1/2 |
❌ |
chatglm 2/3 |
❌ |
deepseek |
✅ |
gemma |
❌ |
internlm 1/2 |
✅ |
llama 2 |
✅ |
mistral |
✅ |
qwen 1/1.5 |
✅ |
starcoder |
✅ |
yi |
✅ |
zephyr |
✅ |
备注
XTuner 会根据运行环境自动控制 Flash Attention 的使用情况 (见 dispatch_modules):
环境 |
Flash Attention 使用情况 |
---|---|
安装 flash attn |
Flash Attention 2 |
未安装 flash attn 且 PyTorch Version <= 1.13 |
No Flash Attention |
未安装 flash attn 且 2.0 <= PyTorch Version <= 2.1 |
Flash Attention 1 |
未安装 flash attn 且 PyTorch Version >= 2.2 |
Flash Attention 2 |
备注
使用 XTuner 训练 QWen1/1.5 时若想使用 Flash Attention 加速,需要先安装 flash attn (参考 flash attn 安装,需要 cuda )