# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Literal, cast
import torch
import torch.distributed as dist
import torch.nn.functional as F
from ..utils import gather_logprobs
from .base_loss import (
BaseRLLossConfig,
BaseRLLossContext,
BaseRLLossKwargs,
compute_kl_loss_weight,
)
from .loss_fn import get_policy_loss_fn, kl_penalty, sft_loss_fn
class OrealLossConfig(BaseRLLossConfig):
    """OREAL loss configuration for XTuner RL training.

    Extends the base RL loss configuration with separate calibration for
    tokens whose advantage is positive (``advantage > 0``) and tokens whose
    advantage is not. The resulting objective combines a policy loss, an
    SFT-style loss on positive tokens, and an optional KL penalty.

    Args:
        policy_loss_cfg (dict[str, Any]): Options for the main policy loss.
            The ``"loss_type"`` key selects the policy loss function
            (``"vanilla"`` when the key is absent), and the whole dict is
            passed on to that function.
        use_kl_loss (bool): Whether to add a KL divergence penalty to the
            loss. Defaults to False.
        kl_loss_coef (float): Coefficient of the KL divergence penalty.
            Defaults to 0.001.
        kl_loss_type (Literal["kl", "k1", "abs", "mse", "k2", "low_var_kl", "k3"] | None):
            Method used to compute the KL penalty. Defaults to None.
        rollout_is (RolloutImportanceSampling): Rollout importance sampling
            configuration. Defaults to ``RolloutImportanceSampling()``.
        positive_loss_factor (float): Overall multiplier for the losses of
            positive-advantage tokens. Defaults to 1.0.
        pos_sft_loss_weight (float): Weight of the SFT-style loss on
            positive-advantage tokens. Defaults to 1.0.
        pos_policy_loss_weight (float): Weight of the policy loss on
            positive-advantage tokens. Defaults to 1.0.
        negative_loss_factor (float): Overall multiplier for the losses of
            non-positive-advantage tokens. Defaults to 1.0.

    **Examples:**

    Example OREAL loss configuration::

        config = OrealLossConfig(
            policy_loss_cfg={"loss_type": "vanilla"},
            positive_loss_factor=1.0,
            pos_sft_loss_weight=1.0,
            pos_policy_loss_weight=1.0,
            negative_loss_factor=1.0,
        )
    """

    # Options for the policy loss; "loss_type" picks the loss function.
    policy_loss_cfg: dict[str, Any]

    # Optional KL penalty against reference log-probabilities.
    use_kl_loss: bool = False
    kl_loss_coef: float = 0.001
    kl_loss_type: Literal["kl", "k1", "abs", "mse", "k2", "low_var_kl", "k3"] | None = None

    # Calibration knobs for positive- vs. non-positive-advantage tokens.
    positive_loss_factor: float = 1.0
    pos_sft_loss_weight: float = 1.0
    pos_policy_loss_weight: float = 1.0
    negative_loss_factor: float = 1.0

    @property
    def loss_ctx_cls(self) -> type["OrealLossContext"]:
        # The loss context class that consumes this configuration.
        return OrealLossContext

    @property
    def _loss_kwargs_cls(self) -> type["OrealLossKwargs"]:
        # The kwargs container class paired with this configuration.
        return OrealLossKwargs
class OrealLossKwargs(BaseRLLossKwargs):
    """Per-batch inputs of the OREAL loss.

    Extends the base RL loss kwargs with the per-token weights of the
    SFT-style loss term. The field defaults to ``None`` and is filled in by
    ``OrealLossContext.build_batches``.
    """

    # Per-token SFT loss weights; None until build_batches computes them.
    sft_loss_weight: torch.Tensor | None = None
class OrealLossContext(BaseRLLossContext):
    """Loss context implementing the OREAL objective.

    ``build_batches`` precomputes per-token weights for three loss terms:

    - an SFT-style loss on tokens with positive advantage;
    - a policy loss whose positive and non-positive tokens are calibrated
      separately;
    - an optional KL penalty.

    The weights are normalized by token counts summed over all contexts on
    this rank and, when ``torch.distributed`` is initialized, over all ranks.
    ``loss_fn`` then evaluates those terms for one chunk of hidden states.
    """

    loss_cfg: OrealLossConfig
    loss_kwargs: OrealLossKwargs

    def __init__(self, loss_cfg: OrealLossConfig, loss_kwargs: OrealLossKwargs):
        super().__init__(loss_cfg, loss_kwargs)
        # ``policy_loss_cfg["loss_type"]`` selects the policy loss; "vanilla" when absent.
        self.policy_loss_fn = get_policy_loss_fn(self.loss_cfg.policy_loss_cfg.get("loss_type", "vanilla"))

    @staticmethod
    def build_batches(loss_ctx_list: list["OrealLossContext"]) -> list["OrealLossContext"]:  # type: ignore[override]
        """Compute globally calibrated loss weights for every context.

        Two token counts are summed over ``loss_ctx_list``: non-ignored tokens,
        and the subset whose advantage is positive. When a process group is
        initialized, both counts are all-reduced across ranks. Each context's
        ``loss_kwargs`` is then updated in place with ``sft_loss_weight``,
        ``policy_loss_weight``, ``kl_loss_weight`` and ``global_grad_tokens``.

        Args:
            loss_ctx_list: Non-empty list of contexts. The config of the
                first element is used for all of them.

        Returns:
            The same list, whose contexts' ``loss_kwargs`` have been mutated.

        Raises:
            AssertionError: If ``loss_ctx_list`` is empty, if a context has no
                ``old_logprobs``, or if ``use_kl_loss`` is set and a context
                has no ``ref_logprobs``.
        """
        assert len(loss_ctx_list) > 0, "loss_ctx_list can not be empty"
        loss_cfg = loss_ctx_list[0].loss_cfg
        ignore_idx = loss_cfg.ignore_idx

        # Denominators used in the global calibration of the loss.
        # ``sum`` starts from 0, so the results are fresh tensors that are safe
        # to all-reduce in place.
        global_grad_tokens = cast(
            torch.Tensor,
            sum((ctx.loss_kwargs.shifted_labels != ignore_idx).sum() for ctx in loss_ctx_list),
        )
        global_positive_tokens = cast(
            torch.Tensor,
            sum(
                ((ctx.loss_kwargs.shifted_labels != ignore_idx) & (ctx.loss_kwargs.advantages > 0)).sum()
                for ctx in loss_ctx_list
            ),
        )
        if dist.is_initialized():
            dist.all_reduce(global_grad_tokens, op=dist.ReduceOp.SUM)
            dist.all_reduce(global_positive_tokens, op=dist.ReduceOp.SUM)
        global_negative_tokens = global_grad_tokens - global_positive_tokens

        # Clamp to 1 so that a class with no tokens anywhere in the global batch
        # (e.g. no negative-advantage tokens) gives a finite scale instead of inf.
        # In that case the matching mask is empty, so the clamp never changes a
        # weight that is actually used.
        positive_denom = global_positive_tokens.clamp(min=1)
        negative_denom = global_negative_tokens.clamp(min=1)

        # TODO: in the official OREAL implementation the SFT loss weights are
        # multiplied by two loss factors; this needs to be double-checked.
        sft_scale = loss_cfg.pos_sft_loss_weight * loss_cfg.positive_loss_factor / positive_denom
        positive_policy_scale = loss_cfg.pos_policy_loss_weight * loss_cfg.positive_loss_factor / positive_denom
        negative_policy_scale = loss_cfg.negative_loss_factor / negative_denom

        for loss_ctx in loss_ctx_list:
            loss_kwargs = loss_ctx.loss_kwargs
            shifted_labels = loss_kwargs.shifted_labels
            advantages = loss_kwargs.advantages
            assert loss_kwargs.old_logprobs is not None, "old_logprobs can not be None"

            # "Negative" means "not positive", matching global_negative_tokens
            # (= non-ignored tokens minus positive tokens).
            valid_mask = shifted_labels != ignore_idx
            is_positive = advantages > 0
            positive_mask = valid_mask & is_positive
            negative_mask = valid_mask & ~is_positive
            zeros = torch.zeros_like(shifted_labels, dtype=torch.float32)

            # SFT loss: only positive-advantage tokens contribute.
            sft_loss_weights = torch.where(positive_mask, sft_scale, zeros)

            # Policy loss: positive and negative tokens are calibrated separately,
            # and ignored tokens get weight 0. torch.where selects values rather
            # than multiplying them, so masked-out positions can never become NaN.
            policy_loss_weights = torch.where(
                positive_mask,
                positive_policy_scale,
                torch.where(negative_mask, negative_policy_scale, zeros),
            )
            if loss_kwargs.is_weights is not None:
                policy_loss_weights = policy_loss_weights * loss_kwargs.is_weights

            # KL loss weights (only when the KL penalty is enabled).
            if loss_cfg.use_kl_loss:
                assert loss_kwargs.ref_logprobs is not None, "ref_logprobs can not be None"
                kl_loss_weight = compute_kl_loss_weight(
                    shifted_labels,
                    global_grad_tokens,
                    loss_cfg.kl_loss_coef,
                    ignore_idx,
                )
            else:
                kl_loss_weight = None

            loss_kwargs.sft_loss_weight = sft_loss_weights
            loss_kwargs.policy_loss_weight = policy_loss_weights
            loss_kwargs.kl_loss_weight = kl_loss_weight
            loss_kwargs.global_grad_tokens = global_grad_tokens
        return loss_ctx_list

    def loss_fn(
        self,
        hidden_states: torch.Tensor,
        head_weight: torch.Tensor,
        head_bias: torch.Tensor | None,
        loss_kwargs: OrealLossKwargs,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]:
        """Step 2.a and 2.b in the loss calculation in
        xtuner/v1/loss/base_loss_ctx.py.

        Projects ``hidden_states`` through the output head to get logits. It
        then adds up three terms, using the weights prepared by
        ``build_batches``:

        - the SFT loss;
        - the policy loss;
        - the KL penalty, when ``use_kl_loss`` is set.

        Args:
            hidden_states: Inputs to the output projection.
            head_weight: Weight of the output projection.
            head_bias: Optional bias of the output projection.
            loss_kwargs: Per-batch inputs, including the precomputed weights.

        Returns:
            ``(loss, (logits, {}))``, where ``logits`` are the float32
            projected logits.

        Raises:
            AssertionError: If ``sft_loss_weight`` is missing, or if
                ``use_kl_loss`` is set and ``ref_logprobs`` or
                ``kl_loss_weight`` is missing.
        """
        # The linear projection is done here to simplify chunked loss computation (saves memory).
        logits = F.linear(hidden_states, head_weight, head_bias)
        logits = logits.float()

        shifted_labels = loss_kwargs.shifted_labels
        old_logprobs = loss_kwargs.old_logprobs
        advantages = loss_kwargs.advantages
        policy_loss_weight = loss_kwargs.policy_loss_weight
        sft_loss_weight = loss_kwargs.sft_loss_weight
        assert sft_loss_weight is not None, "sft_loss_weight can not be None"

        sft_loss = sft_loss_fn(logits, shifted_labels, sft_loss_weight, ignore_idx=self.loss_cfg.ignore_idx)

        logprobs = gather_logprobs(logits, shifted_labels)
        policy_loss = self.policy_loss_fn(
            logprobs,
            old_logprobs,
            advantages,
            policy_loss_weight,
            self.loss_cfg.policy_loss_cfg,
        )
        loss = sft_loss + policy_loss

        if self.loss_cfg.use_kl_loss:
            ref_logprobs = loss_kwargs.ref_logprobs
            kl_loss_weight = loss_kwargs.kl_loss_weight
            assert ref_logprobs is not None and kl_loss_weight is not None, (
                "loss_kwargs.ref_logprobs and loss_kwargs.kl_loss_weight can not be None when use_kl_loss=True"
            )
            kl_loss = kl_penalty(logprobs, ref_logprobs, kl_loss_weight, self.loss_cfg.kl_loss_type)
            loss = loss + kl_loss

        return loss, (logits, {})