from typing import Any, Literal
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from typing_extensions import Self
from xtuner.v1.loss.ce_loss import CELossConfig, CELossContext, CELossKwargs
from xtuner.v1.loss.utils import sp_gather, sp_split
# from ..utils import sp_split
from xtuner.v1.rl.rollout_is import RolloutImportanceSampling
from xtuner.v1.utils.device import get_device
# Device that RL loss kwargs are moved to; resolved once, when this module is imported.
DEVICE = get_device()
def compute_kl_loss_weight(
    shifted_labels: torch.Tensor, global_grad_tokens: torch.Tensor, kl_loss_coef: float, ignore_idx: int = -100
) -> torch.Tensor:
    """Build per-token weights for the KL penalty term.

    Each position receives ``1 / global_grad_tokens * kl_loss_coef``; positions
    whose label equals ``ignore_idx`` receive weight 0.

    Args:
        shifted_labels (torch.Tensor): Labels; the output has the same shape
            and a float32 base dtype.
        global_grad_tokens (torch.Tensor): Normaliser for the weights
            (presumably a scalar token count — TODO confirm against callers).
        kl_loss_coef (float): KL penalty coefficient.
        ignore_idx (int): Label value marking positions excluded from the KL
            loss. Defaults to -100.

    Returns:
        torch.Tensor: The per-token KL loss weights.
    """
    weights = torch.ones_like(shifted_labels, dtype=torch.float32)
    weights = weights / global_grad_tokens
    weights = weights * kl_loss_coef
    ignored_positions = shifted_labels == ignore_idx
    weights[ignored_positions] = 0.0
    return weights
class BaseRLLossConfig(CELossConfig):
    """Base configuration for reinforcement learning losses in XTuner RL.

    Holds the settings shared by RL policy-optimization objectives (PPO, GRPO
    and custom algorithms): the policy-loss parameters, an optional KL penalty
    against a reference policy, and rollout importance sampling. Concrete
    algorithms subclass this and supply their loss-context and loss-kwargs
    classes.

    Args:
        policy_loss_cfg (dict[str, Any]): Algorithm-specific parameters for
            the main policy loss.
        use_kl_loss (bool): Whether to add a KL-divergence penalty to the
            loss. Enabling it requires a reference model to provide reference
            log probabilities. Defaults to False.
        kl_loss_coef (float): Weight of the KL penalty; controls the strength
            of regularization against the reference policy. Defaults to 0.001.
        kl_loss_type (Literal["kl", "k1", "abs", "mse", "k2", "low_var_kl", "k3"] | None):
            KL estimator used for the penalty. Defaults to None.
        rollout_is (RolloutImportanceSampling): Rollout importance-sampling
            settings. Defaults to ``RolloutImportanceSampling()``.

    **Abstract properties:**
        loss_ctx_cls: Subclasses must return the loss-context class of their
            RL algorithm.
        _loss_kwargs_cls: Subclasses must return the loss-kwargs class that
            :meth:`build` instantiates.

    **Examples:**
        Basic RL loss configuration::

            config = GRPOLossConfig(
                policy_loss_cfg=dict(
                    cliprange_high=0.2,
                    cliprange_low=0.2,
                    loss_type='vanilla',
                ),
                use_kl_loss=False
            )

        RL loss configuration with KL regularization::

            config = GRPOLossConfig(
                policy_loss_cfg=dict(
                    cliprange_high=0.2,
                    cliprange_low=0.2,
                    loss_type='vanilla',
                ),
                use_kl_loss=True,
                kl_loss_coef=0.001,
                kl_loss_type="low_var_kl"
            )

    .. note::
        When ``use_kl_loss=True``, make sure the training worker is configured
        with a reference model for KL-divergence computation.
    """

    policy_loss_cfg: dict[str, Any]
    use_kl_loss: bool = False
    kl_loss_coef: float = 0.001
    kl_loss_type: Literal["kl", "k1", "abs", "mse", "k2", "low_var_kl", "k3"] | None = None
    rollout_is: RolloutImportanceSampling = RolloutImportanceSampling()

    @property
    def loss_ctx_cls(self) -> type["BaseRLLossContext"]:
        """Loss-context class for this algorithm; subclasses must override."""
        raise NotImplementedError

    @property
    def _loss_kwargs_cls(self) -> type["BaseRLLossKwargs"]:
        """Loss-kwargs class instantiated by :meth:`build`; subclasses must override."""
        raise NotImplementedError

    def build(
        self,
        data: dict,
        sp_mesh: DeviceMesh | None = None,
    ) -> "BaseRLLossContext | None":
        """Create the RL loss context for one batch of data.

        Args:
            data (dict): Batch fields. Required keys:
                - shifted_labels (torch.Tensor): The shifted labels
                - advantages (torch.Tensor): Advantage estimates
                Optional keys (missing ones become None):
                - rollout_logprobs (torch.Tensor | None): Rollout log probabilities
                - old_logprobs (torch.Tensor | None): Old-policy log probabilities
                  (may be absent here and filled in later)
                - rollout_is_weights (torch.Tensor | None): Importance-sampling
                  weights, stored as ``is_weights`` on the kwargs object
                - ref_logprobs (torch.Tensor | None): Reference-model log probabilities
            sp_mesh (DeviceMesh | None): Sequence-parallel device mesh. When it
                spans more than one rank, the kwargs are sp-split after being
                moved to ``DEVICE``.

        Returns:
            BaseRLLossContext | None: The loss context, or None when a required
            key is missing from ``data``.
        """
        if not ("shifted_labels" in data and "advantages" in data):
            return None

        # Optional tensors keyed by their BaseRLLossKwargs field name; note the
        # rename rollout_is_weights -> is_weights.
        optional_fields = {
            "rollout_logprobs": data.get("rollout_logprobs"),
            "old_logprobs": data.get("old_logprobs"),
            "is_weights": data.get("rollout_is_weights"),
            "ref_logprobs": data.get("ref_logprobs"),
        }
        kwargs_cls = self._loss_kwargs_cls
        loss_kwargs = kwargs_cls(
            shifted_labels=data["shifted_labels"],
            advantages=data["advantages"],
            **optional_fields,
        ).to(DEVICE)

        if sp_mesh is not None and sp_mesh.size() > 1:
            loss_kwargs = loss_kwargs.sp_split(sp_mesh)

        context_cls = self.loss_ctx_cls
        return context_cls(self, loss_kwargs)
class BaseRLLossKwargs(CELossKwargs):
    """Per-batch tensors consumed by RL loss computation.

    ``shifted_labels`` is inherited from ``CELossKwargs``; the fields below are
    the RL-specific additions.

    Args:
        rollout_logprobs (torch.Tensor | None): Log probabilities reported by
            the rollout/inference engine, used for importance sampling.
        advantages (torch.Tensor): Advantage estimates for the sampled tokens.
        old_logprobs (torch.Tensor | None): Log probabilities from the old policy.
        policy_loss_weight (torch.Tensor | None): Per-token weights for the
            policy loss.
        ref_logprobs (torch.Tensor | None): Reference-model log probabilities
            for the KL penalty, if used.
        kl_loss_weight (torch.Tensor | None): Per-token weights for the KL loss,
            if used.
        global_grad_tokens (torch.Tensor | None): Global gradient-token count
            (a scalar tensor).
        is_weights (torch.Tensor | None): Importance-sampling weights. If None,
            importance sampling is not used.
    """

    rollout_logprobs: torch.Tensor | None = None
    advantages: torch.Tensor
    old_logprobs: torch.Tensor | None = None
    policy_loss_weight: torch.Tensor | None = None
    ref_logprobs: torch.Tensor | None = None
    kl_loss_weight: torch.Tensor | None = None
    global_grad_tokens: torch.Tensor | None = None
    is_weights: torch.Tensor | None = None

    def sp_split(self, sp_mesh: DeviceMesh) -> Self:
        """Split sequence tensors along dim 1 for sequence parallelism, in place.

        Split here:
            * ``shifted_labels`` (by the parent class),
            * ``advantages`` (padded with 0.0),
            * ``rollout_logprobs`` when present (padded with 0.0),
            * ``is_weights`` when present (padded with 1.0).

        Deliberately not split here:
            * ``old_logprobs`` / ``ref_logprobs``: they are produced by a model
              forward pass, and ``seq_ctx`` is always sp-split before the
              forward, so they are already local to this rank.
            * ``global_grad_tokens``: a scalar tensor.
            * ``policy_loss_weight`` / ``kl_loss_weight``: generated already
              sp-split in ``LossContext.build_batches()``.

        Returns:
            Self: ``self``, for chaining.
        """
        super().sp_split(sp_mesh)

        self.advantages = sp_split(self.advantages, sp_mesh=sp_mesh, split_dim=1, padding_value=0.0)
        # (field name, padding value) for the optional per-token tensors.
        for field_name, pad_value in (("rollout_logprobs", 0.0), ("is_weights", 1.0)):
            tensor = getattr(self, field_name)
            if tensor is None:
                continue
            setattr(
                self,
                field_name,
                sp_split(tensor, sp_mesh=sp_mesh, split_dim=1, padding_value=pad_value),
            )
        return self

    def to(self, device: torch.device | str) -> Self:
        """Move ``shifted_labels`` (via the parent) and every non-None RL tensor to ``device``.

        Fields are reassigned on ``self``; ``self`` is returned for chaining.
        """
        super().to(device)

        self.advantages = self.advantages.to(device)
        optional_tensor_fields = (
            "old_logprobs",
            "ref_logprobs",
            "rollout_logprobs",
            "is_weights",
            "global_grad_tokens",
            "policy_loss_weight",
            "kl_loss_weight",
        )
        for field_name in optional_tensor_fields:
            tensor = getattr(self, field_name)
            if tensor is not None:
                setattr(self, field_name, tensor.to(device))
        return self
class BaseRLLossContext(CELossContext):
    """Base loss context for RL objectives.

    Extends ``CELossContext`` with rollout importance-sampling (IS) support.
    """

    loss_cfg: BaseRLLossConfig  # type: ignore[assignment]
    loss_kwargs: BaseRLLossKwargs  # type: ignore[assignment]

    def compute_rollout_is(
        self, sp_mesh: DeviceMesh | None, num_tokens: torch.Tensor
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Compute rollout importance-sampling weights and apply the IS mask.

        Passes the training-side ``old_logprobs`` and the inference-side
        ``rollout_logprobs`` to
        ``loss_cfg.rollout_is.compute_rollout_importance_weights_and_metrics``.
        Under sequence parallelism (``sp_mesh.size() > 1``) the inputs are
        gathered into temporary full-sequence copies for that call. The
        returned mask and weights are then split back to this rank's shard.

        Side effects:
            * Positions rejected by the IS mask are set to
              ``loss_cfg.ignore_idx`` in ``loss_kwargs.shifted_labels``
              (in place), which removes them from the loss mask.
            * ``loss_kwargs.is_weights`` is replaced by the computed weights
              (which may be None).

        Args:
            sp_mesh (DeviceMesh | None): Sequence-parallel mesh, or None.
            num_tokens (torch.Tensor): Forwarded unchanged as ``num_tokens``
                to the rollout-IS computation.

        Returns:
            tuple[dict[str, Any], dict[str, Any]]: ``(mismatch_metrics,
            rollout_is_metrics)`` as returned by the rollout-IS computation.
        """
        ignore_idx = self.loss_cfg.ignore_idx
        shifted_labels = self.loss_kwargs.shifted_labels
        rollout_logprobs = self.loss_kwargs.rollout_logprobs
        old_logprobs = self.loss_kwargs.old_logprobs
        mask = shifted_labels != ignore_idx
        assert rollout_logprobs is not None
        assert old_logprobs is not None

        if sp_mesh is not None and sp_mesh.size() > 1:
            # Gather into local variables only; loss_kwargs.old_logprobs stays in its sp-split state.
            rollout_logprobs = sp_gather(rollout_logprobs, sp_mesh, dim=1)
            full_len = rollout_logprobs.size(1)  # type: ignore
            old_logprobs = sp_gather(old_logprobs, sp_mesh, dim=1)[:, :full_len]  # type: ignore
            mask = sp_gather(mask, sp_mesh, dim=1)[:, :full_len]  # type: ignore

        rollout_is_weights, rollout_is_mask, mismatch_metrics, rollout_is_metrics = (
            self.loss_cfg.rollout_is.compute_rollout_importance_weights_and_metrics(
                old_log_prob=old_logprobs,
                rollout_log_prob=rollout_logprobs,
                num_tokens=num_tokens,
                response_mask=mask,
            )
        )

        if sp_mesh is not None and sp_mesh.size() > 1:
            # Bring the full-sequence results back to this rank's shard.
            rollout_is_mask = sp_split(rollout_is_mask, sp_mesh, 1, 0)
            assert rollout_is_mask.size(1) == shifted_labels.size(1), (
                f"rollout_is_mask {rollout_is_mask.size(1)} vs shifted_labels {shifted_labels.size(1)}"
            )
            if rollout_is_weights is not None:
                rollout_is_weights = sp_split(rollout_is_weights, sp_mesh, 1, 0)
                assert rollout_is_weights.size(1) == shifted_labels.size(1), (
                    f"rollout_is_weights {rollout_is_weights.size(1)} vs shifted_labels {shifted_labels.size(1)}"
                )

        # Update the loss mask. Use the configured ignore index (the same one the
        # mask above was built from) rather than a hard-coded -100; otherwise
        # IS-rejected tokens would not be ignored when ignore_idx != -100.
        shifted_labels[~rollout_is_mask.bool()] = ignore_idx
        self.loss_kwargs.is_weights = rollout_is_weights
        return mismatch_metrics, rollout_is_metrics
def finalize_train_policy_metrics(extra_info_dict: dict[str, Any], device: str | torch.device) -> dict[str, Any]:
if "reduced_train_policy_valid_count" not in extra_info_dict:
return extra_info_dict
has_clip_count = (
"reduced_train_policy_clip_low_count" in extra_info_dict
and "reduced_train_policy_clip_high_count" in extra_info_dict
)
sum_keys = [
"reduced_train_policy_ratio_abs_dev_sum",
"reduced_train_policy_kl1_sum",
"reduced_train_policy_kl3_sum",
"reduced_train_policy_valid_count",
]
if has_clip_count:
sum_keys.extend(["reduced_train_policy_clip_low_count", "reduced_train_policy_clip_high_count"])
max_keys = ("reduced_train_policy_ratio_max",)
min_keys = ("reduced_train_policy_ratio_min",)
def reduce_values(keys, op):
values = torch.tensor([extra_info_dict.pop(key, 0.0) for key in keys], dtype=torch.float32, device=device)
if dist.is_initialized():
dist.all_reduce(values, op=op)
return dict(zip(keys, values.tolist()))
train_policy_values = {}
train_policy_values.update(reduce_values(sum_keys, dist.ReduceOp.SUM))
train_policy_values.update(reduce_values(max_keys, dist.ReduceOp.MAX))
train_policy_values.update(reduce_values(min_keys, dist.ReduceOp.MIN))
valid_count = train_policy_values["reduced_train_policy_valid_count"]
output_keys = {
"reduced_train_policy_ratio_abs_dev_mean": "reduced_train_policy_ratio_abs_dev_sum",
"reduced_train_policy_kl1": "reduced_train_policy_kl1_sum",
"reduced_train_policy_kl3": "reduced_train_policy_kl3_sum",
}
if has_clip_count:
output_keys.update(
{
"reduced_train_policy_clip_frac_low": "reduced_train_policy_clip_low_count",
"reduced_train_policy_clip_frac_high": "reduced_train_policy_clip_high_count",
}
)
for output_key, sum_key in output_keys.items():
extra_info_dict[output_key] = train_policy_values[sum_key] / valid_count if valid_count > 0 else 0.0
ratio_max = train_policy_values["reduced_train_policy_ratio_max"] if valid_count > 0 else 0.0
ratio_min = train_policy_values["reduced_train_policy_ratio_min"] if valid_count > 0 else 0.0
extra_info_dict["reduced_train_policy_ratio_max"] = ratio_max
extra_info_dict["reduced_train_policy_ratio_min"] = ratio_min
# legacy metric,keep here
extra_info_dict["max_ratio"] = ratio_max
return extra_info_dict