损失函数#

动机#

预训练和微调任务的损失函数往往使用的是 CE Loss 。CE Loss 大家想必都不陌生,可为什么 XTuner 还要设计自己的 CE Loss 呢?

  1. 节约显存

当今大语言模型的词表普遍较大,同时,我们希望增加输入序列的长度来充分利用算力,导致 lm_head 计算 logits 再计算 loss 进而 backward 这一过程将会耗费大量显存。如下所示,使用 XTuner 提供的 chunk loss 可以节约 4/5 左右的显存:

import torch
import torch.nn as nn
import torch.nn.functional as F
from xtuner.v1.loss.ce_loss import CELossConfig, CELossContextInputItem, CELossContext
import time


hidden_states = torch.randn(32768, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
lm_head = nn.Linear(4096, 151936, bias=False).to(device="cuda", dtype=torch.bfloat16)
torch.cuda.reset_peak_memory_stats()
t1 = time.time()
logits = lm_head(hidden_states)
shifted_labels = torch.randint(0, 151936, (32768, ), device="cuda")
loss = F.cross_entropy(logits, shifted_labels)
loss.backward()
max_memory = torch.cuda.max_memory_allocated()
reserved_memory = torch.cuda.max_memory_reserved()
print(f"Eager mode Loss: {loss.item()}")
print(f"Eager mode hidden_states grad norm: {hidden_states.grad.norm().item()}")
print(f"Eager mode lm_head weight grad norm: {lm_head.weight.grad.norm().item()}")
print(f"Eager mode Max memory allocated: {max_memory / 1024**3:.2f} GB")
print(f"Eager mode Max memory reserved: {reserved_memory / 1024**3:.2f} GB")
print(f"Eager mode Time taken: {time.time() - t1:.2f} seconds")

del logits
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

shifted_labels = shifted_labels.unsqueeze(0)
hidden_states = hidden_states.unsqueeze(0)
hidden_states = hidden_states.clone().detach().requires_grad_(True)
lm_head.weight.grad = None
t1 = time.time()
loss_ctx_input_list = [CELossContextInputItem(shifted_labels=shifted_labels)]
loss_cfg = CELossConfig(mode='chunk', chunk_size=1024, loss_reduction="token")
batches_loss_kwargs = CELossContext.build_batches_loss_kwargs(loss_ctx_input_list, loss_cfg)
loss_ctx = CELossContext(loss_cfg, batches_loss_kwargs[0])
loss, _ = loss_ctx.forward(hidden_states, lm_head.weight)
loss.backward()
max_memory = torch.cuda.max_memory_allocated()
reserved_memory = torch.cuda.max_memory_reserved()
print(f"Chunk mode Loss: {loss.item()}")
print(f"Chunk mode hidden_states grad norm: {hidden_states.grad.norm().item()}")
print(f"Chunk mode lm_head weight grad norm: {lm_head.weight.grad.norm().item()}")
print(f"Chunk mode Max memory allocated: {max_memory / 1024**3:.2f} GB")
print(f"Chunk mode Max memory reserved: {reserved_memory / 1024**3:.2f} GB")
print(f"Chunk mode Time taken: {time.time() - t1:.2f} seconds")
Eager mode Loss: 12.125
Eager mode hidden_states grad norm: 0.0031890869140625
Eager mode lm_head weight grad norm: 0.353515625
Eager mode Max memory allocated: 38.57 GB
Eager mode Max memory reserved: 47.81 GB
Eager mode Time taken: 0.42 seconds
Chunk mode Loss: 12.094674110412598
Chunk mode hidden_states grad norm: 0.0031890869140625
Chunk mode lm_head weight grad norm: 0.353515625
Chunk mode Max memory allocated: 6.87 GB
Chunk mode Max memory reserved: 9.56 GB
Chunk mode Time taken: 0.26 seconds
  1. 实现 loss 的全局校准

什么是 loss 全局校准?

loss 全局校准是指,无论使用多少张显卡,无论使用什么并行策略和梯度累积策略,其训练的效果都等价于在一张显卡上不使用任何并行策略时的效果(不考虑是否会 OOM)。

为什么要做 loss 全局校准?

我们希望模型的训练不受显卡数量、并行策略、梯度累积策略的变化而变化。

如果不进行 loss 全局校准,那么对于同样一批数据,使用 8 卡梯度累积 2 和使用 16 卡梯度累积 1 (global batch size 相同)的训练行为是不同的。换言之,当显卡数量、并行策略、梯度累积策略的变化时,如果不进行 loss 全局校准,则训练行为是不可复现的,如下所示。

import torch
import torch.nn as nn
import torch.nn.functional as F
from xtuner.v1.loss.ce_loss import CELossConfig, CELossContext
from mmengine.dist import infer_launcher, init_dist
import torch.distributed as dist


dist_launcher = infer_launcher()
init_dist(dist_launcher)
rank = dist.get_rank()
world_size = dist.get_world_size()

torch.manual_seed(0)
lm_head = nn.Linear(2, 10, bias=False).to(device="cuda", dtype=torch.bfloat16)
hidden_states_gt = torch.randn(8, 2, device="cuda", dtype=torch.bfloat16, requires_grad=True)
shifted_labels_gt = torch.tensor([-100, 0, 1, -100, 0, 1, 2, 3], device="cuda")

# 1 gpu
logits = lm_head(hidden_states_gt)
loss = F.cross_entropy(logits, shifted_labels_gt)
loss.backward()
grad_1_gpu = lm_head.weight.grad.clone()

# 2 gpu without global average
hidden_states = hidden_states_gt.clone().detach().requires_grad_(True)
lm_head.weight.grad = None
hidden_states = torch.chunk(hidden_states, world_size, dim=0)[rank]
shifted_labels = torch.chunk(shifted_labels_gt, world_size, dim=0)[rank]
logits = lm_head(hidden_states)
loss = F.cross_entropy(logits, shifted_labels)
loss.backward()
dist.all_reduce(lm_head.weight.grad, op=dist.ReduceOp.AVG)
grad_2_gpu = lm_head.weight.grad.clone()
print(f'Without global average, torch.allclose(grad_1_gpu, grad_2_gpu, atol=1e-2, rtol=1e-2) = {torch.allclose(grad_1_gpu, grad_2_gpu, atol=1e-2, rtol=1e-2)}')

# 2 gpu without global average
hidden_states = hidden_states_gt.clone().detach().requires_grad_(True)
lm_head.weight.grad = None
hidden_states = torch.chunk(hidden_states, world_size, dim=0)[rank]
shifted_labels = torch.chunk(shifted_labels_gt, world_size, dim=0)[rank]
hidden_states = hidden_states.unsqueeze(0)
shifted_labels = shifted_labels.unsqueeze(0)
loss_cfg = CELossConfig(mode='chunk', chunk_size=1024, loss_reduction="token")
loss_ctx = loss_cfg.build(shifted_labels)
loss_ctx_list = CELossContext.build_batches([loss_ctx])
loss_ctx = loss_ctx_list[0]
loss, _ = loss_ctx.forward(hidden_states, lm_head.weight)
loss.backward()

dist.all_reduce(lm_head.weight.grad, op=dist.ReduceOp.AVG)
grad_2_gpu = lm_head.weight.grad.clone()
print(f'With global average, torch.allclose(grad_1_gpu, grad_2_gpu, atol=1e-2, rtol=1e-2) = {torch.allclose(grad_1_gpu, grad_2_gpu, atol=1e-2, rtol=1e-2)}')
Without global average, torch.allclose(grad_1_gpu, grad_2_gpu, atol=1e-2, rtol=1e-2) = False
Without global average, torch.allclose(grad_1_gpu, grad_2_gpu, atol=1e-2, rtol=1e-2) = False
With global average, torch.allclose(grad_1_gpu, grad_2_gpu, atol=1e-2, rtol=1e-2) = True
With global average, torch.allclose(grad_1_gpu, grad_2_gpu, atol=1e-2, rtol=1e-2) = True

如何做 loss 全局校准?

假设我们有两张显卡,序列并行度为 2,梯度累积 2 次。

                            rank0         rank1
iter0 loss                 l00, l01      l02, l03
      loss weight          w00, w01      w02, w03
      loss mask (0 or 1)   m00, m01      m02, m03
iter1 loss                 l10, l11      l12, l13
      loss weight          w10, w11      w12, w13
      loss mask (0 or 1)   m10, m11      m12, m13

那么,loss 校准的方式如下:

  1. 计算所有显卡在梯度累积范围内的 loss mask 的和:

global_loss_mask_sum = all_reduce(sum([loss_mask.sum() for loss_mask in loss_masks_grad_acc]), op=dist.ReduceOp.SUM, group=world)
                     = (m00 + m01 + m02 + m03 + m10 + m11 + m12 + m13)
  1. 计算当前 iter 的 loss,以 rank0 iter0 为例:

loss_rank0iter0 = (l00 * w00 * m00 + l01 * w01 * m01)
loss_rank0iter0 = loss_rank0iter0 / global_loss_mask_sum
                = (l00 * w00 * m00 + l01 * w01 * m01) / (m00 + m01 + m02 + m03 + m10 + m11 + m12 + m13)
loss_rank0iter0 = all_reduce_autograd(loss_rank0iter0, op=dist.ReduceOp.SUM, group=world)
                = (l00 * w00 * m00 + l01 * w01 * m01 + l02 * w02 * m02 + l03 * w03 * m03) / (m00 + m01 + m02 + m03 + m10 + m11 + m12 + m13)
  1. 计算梯度累积范围内的 step_loss,与一张显卡不使用梯度累积时的效果相同:

step_loss = loss_rank0iter0 + loss_rank0iter1
          = (l00 * w00 * m00 + l01 * w01 * m01 + l02 * w02 * m02 + l03 * w03 * m03 + l10 * w10 * m10 + l11 * w11 * m11 + l12 * w12 * m12 + l13 * w13 * m13) / (m00 + m01 + m02 + m03 + m10 + m11 + m12 + m13)

CE Loss#

XTuner 中所有的 loss 计算均涉及两个核心组件 LossConfigLossContext 。CE Loss 则对应 CELossConfigCELossContext。下面是一个简单的 CE Loss 的使用示例:

import torch
import torch.nn as nn
from xtuner.v1.loss.ce_loss import CELossConfig, CELossContext

emb = nn.Embedding(4, 2)
head = nn.Linear(2, 4, bias=False)

input_ids = torch.randint(0, 10, (1, 5))
shifted_labels = input_ids[:, 1:]
input_ids = input_ids[:, :-1]
hidden_states = emb(input_ids)

loss_ctx_input_list = [CELossContextInputItem(shifted_labels=shifted_labels)]
loss_cfg = CELossConfig(mode='chunk', chunk_size=1024, loss_reduction="token")
loss_ctx = loss_cfg.build(shifted_labels=data["shifted_labels"])
loss_ctx_list = CELossContext.build_batches([loss_ctx])
loss_ctx = loss_ctx_list[0]
loss, _ = loss_ctx.forward(hidden_states, head.weight)
loss.backward()

CELossConfig#

CELossConfig 包含了 CE Loss 计算所需的所有可配置项。由三个通用配置项:ignore_idx, modechunk_size,以及一个 CE Loss 特有的 loss_reduction 组成。

class CELossConfig:
    ignore_idx: Annotated[int, Parameter(help="ignore index for loss calculation")] = -100
    mode: Annotated[Literal["eager", "chunk"], Parameter(help="loss calculation mode")] = "eager"
    chunk_size: Annotated[int | None, Parameter(help="chunk size when mode is chunk")] = 1024
    loss_reduction: Annotated[Literal["token", "sample", "square"], Parameter(help="loss reduction mode")] = "token"
  • ignore_idx 表示在 loss 计算中被忽略的 label ids ,通常为 -100 ,用户无需额外设置。

  • mode 共有 “eager” 和 “chunk” 两种可选,推荐设置为 “chunk” 模式来节省显存。

  • chunk_size 只有 mode 是 “chunk” 是才会生效。

  • loss_reduction 有 “token”, “sample”, “square” 三种可选,我们通常选择 “token” 模式,即 token 之间的 CE Loss 计算互不影响。

CELossContext#

CELossContext 中我们引入了额外的一个数据结构:CELossKwargs

  • CELossKwargs 表示 CE Loss 实际计算的时候需要用到哪些参数,即:shifted_labelsloss_weight 两项,注意此时的 loss_weight 已经经历过全局校准的处理了,详细实现请参考 xtuner/v1/loss/ce_loss.py

我们在 CELossContext 中只需要实现两个接口:

  1. 为了做 loss 全局校准,staticmethod build_batches 计算全局校准对应的loss weight。

  2. loss_fn 根据 CELossKwargs 计算出当前 iter 的 loss。

对于其他功能(如:chunk loss),不同 loss 都是通用的,我们统一放到 BaseLossContext 里实现。

Custom Loss#

如需自定义 loss 形式,需要重新实现 CustomLossConfigCustomLossContext 两个数据结构。

CustomLossConfig#

继承 BaseLossConfig 并拓展所需字段:

from xtuner.v1.loss import BaseLossConfig

class CustomLossConfig(BaseLossConfig):
    arg1: Any
    ...

    @property
    def loss_ctx_cls(self) -> type[CustomLossContext]:
        return CustomLossContext

CustomLossContext#

第一步,定义 custom loss 实际计算的时候需要用到哪些参数:

from xtuner.v1.loss import BaseLossContext, BaseLossKwargs

class CustomLossKwargs(BaseLossKwargs):
    shifted_labels: torch.Tensor
    loss_weight: torch.Tensor
    arg1: Any
    ...

第二步,继承 BaseLossContext 并实现 CustomLossContext 中的 classmethod build_batchesloss_fn

from xtuner.v1.loss import BaseLossContext, BaseLossKwargs

class CustomLossContext(BaseLossContext):
    loss_cfg: CustomLossConfig
    loss_kwargs: CustomLossKwargs

    @staticmethod
    def build_batches(
        loss_ctx_list: list["CELossContext"],
        # 为了提高计算效率,XTuner 会将多条短数据 pack 成一条长数据进行训练
        # 若在计算 CustomLossKwargs 的过程中需要解 pack 成若干短数据,则需要传入 cu_seq_lens_list
        # 默认为 None 即可。
        cu_seq_lens_list: Sequence[torch.IntTensor] | None = None,
        # 若开启了序列并行 (sp) 且计算 CustomLossKwargs 的过程中需要 sp 切分前的数据,则需要传入 cu_seq_lens_list
        # 默认为 None 即可。
        sp_mesh: DeviceMesh | None = None,
    ) -> list[CustomLossKwargs]:
        ...
    
    def loss_fn(
        self,
        hidden_states: torch.Tensor,
        head_weight: torch.Tensor,
        head_bias: torch.Tensor | None,
        loss_kwargs: CustomLossKwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        ...