Source code for xtuner.v1.rl.trainer.worker

import contextlib
import json
import math
import os
import time
from contextlib import contextmanager
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Dict,
    Iterable,
    List,
    Sequence,
    TypeAlias,
    TypedDict,
    cast,
)


if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

import numpy as np
import ray
import torch
import torch.distributed as dist
from mmengine.runner import set_random_seed
from pydantic import BaseModel, ConfigDict
from ray.actor import ActorClass, ActorProxy
from torch.distributed.device_mesh import init_device_mesh
from typing_extensions import NotRequired

from transformers import AutoTokenizer
from xtuner.v1.config.fsdp import FSDPConfig
from xtuner.v1.config.optim import LRConfig, OptimConfig
from xtuner.v1.data_proto.sequence_context import SequenceContext
from xtuner.v1.datasets.config import DataloaderConfig
from xtuner.v1.datasets.dataloader import Dataloader
from xtuner.v1.engine.train_engine import TrainEngine, TrainStepInfo
from xtuner.v1.float8.float8_handler import Float8Handler
from xtuner.v1.loss import BaseLossContext, CELossConfig, LogProbConfig
from xtuner.v1.loss.ce_loss import CELossContext, LMHeadLossContext
from xtuner.v1.loss.mtp_loss import MTPLossConfig, MTPLossContext
from xtuner.v1.model.base import BaseModel as XtunerBaseModel
from xtuner.v1.model.base import ModelItem, TransformerConfig
from xtuner.v1.model.compose.base import BaseComposeConfig, BaseComposeModel
from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo
from xtuner.v1.profiler import profiling_memory, profiling_time
from xtuner.v1.rl.loss import BaseRLLossConfig, BaseRLLossContext, finalize_train_policy_metrics, kl_penalty
from xtuner.v1.rl.utils import SingleAcceleratorWorker
from xtuner.v1.train.trainer import LoadCheckpointConfig
from xtuner.v1.utils import (
    XTUNER_DETERMINISTIC,
    ParallelConfigException,
    get_device,
    get_logger,
    get_torch_device_module,
    ray_method,
    set_deterministic,
)

from ..rollout_is import merge_rollout_is_metrics
from .update_weighter import UpdateWeighter


# Nested integer lists describing a device mesh (one inner list per mesh row of device indices).
DeviceMeshRaw: TypeAlias = List[List[int]]  # A list of lists representing device mesh indices
# Maps an integer key to a service URL string.
# NOTE(review): the key type is ``int`` (presumably a rank or worker index rather than a
# service *name*) — confirm against the callers that populate this mapping.
ServiceUrlMap: TypeAlias = Dict[int, str]
# Both values are resolved once, at import time, by the xtuner.v1.utils helpers.
DEVICE = get_device()
DEVICE_MODULE = get_torch_device_module()


def calculate_entropy(
    shifted_labels_list: Sequence[torch.Tensor],
    old_logprobs_list: Sequence[torch.Tensor | None],
    global_grad_tokens: torch.Tensor,
) -> torch.Tensor | None:
    if len(old_logprobs_list) == 0 or old_logprobs_list[0] is None:
        return None
    sum_entropy: torch.Tensor | None = None
    for i, shifted_labels in enumerate(shifted_labels_list):
        mask = shifted_labels != -100
        assert old_logprobs_list[i] is not None
        entropy = -(cast(torch.Tensor, old_logprobs_list[i]) * mask).sum()
        sum_entropy = entropy if sum_entropy is None else sum_entropy + entropy
    sum_entropy = cast(torch.Tensor, sum_entropy)
    dist.all_reduce(sum_entropy, op=dist.ReduceOp.SUM)
    avg_sum_entropy = sum_entropy / global_grad_tokens if global_grad_tokens > 0 else torch.tensor(0.0)
    return avg_sum_entropy


class WorkerConfig(BaseModel):
    """Training worker configuration for XTuner RL.

    Configuration for RL training workers managing model training, optimization,
    and distributed computing in reinforcement learning workflows.

    Args:
        model_cfg (TransformerConfig | BaseComposeConfig): Model architecture configuration.
        optim_cfg (OptimConfig): Optimizer configuration for training.
        loss_cfg (BaseRLLossConfig): Loss function configuration for RL training.
        lr_cfg (LRConfig): Learning rate scheduler configuration.
        fsdp_cfg (FSDPConfig): Fully Sharded Data Parallel configuration.
        load_from (str | Path): Path to load the main model from.
        optimizer_steps (int): Number of optimizer steps per training iteration. Defaults to 1.
        sp_size (int): Sequence parallel size for distributed training. Defaults to 1.
        pack_max_length (int): Maximum sequence length for data packing.
        ref_load_from (str | Path | None): Path to load reference model from. If None, uses
            same as load_from. Defaults to None.
        ref_model_fsdp_cfg (FSDPConfig | None): FSDP configuration for reference model.
            Defaults to None.
        log_dir (str | Path | None): Directory for training logs. Defaults to None.
        update_weight_bucket_size_in_gb (float): Bucket size used when syncing updated weights
            to rollout workers. Defaults to 0.5.
        seed (int | None): Training worker random seed. When None, the RL trainer seed is used.
            Must be set when ``sft_dataloader_cfg`` is provided. Defaults to None.
        profile_step (list[int] | int | None): 1-based global RL train_step id(s) to profile.
            Defaults to None.
        profile_time (bool): Time-profiling switch stored by the worker (presumably applied on
            the steps listed in ``profile_step``). Defaults to True.
        profile_memory (bool): Memory-profiling switch stored by the worker (presumably applied
            on the steps listed in ``profile_step``). Defaults to False.
        sft_dataloader_cfg (DataloaderConfig | None): Dataloader configuration for SFT data.
            When None, no SFT dataloader is built. Defaults to None.
        sft_global_batch_size (int): Global batch size passed to the SFT dataloader. Must be
            greater than 0 when ``sft_dataloader_cfg`` is provided. Defaults to -1.
        rollout_steps_per_sft (int): Presumably the number of rollout steps between SFT
            steps (confirm against the training loop). Defaults to 1.
        sft_loss_cfg (CELossConfig): Cross-entropy loss configuration used for SFT.
            Defaults to ``CELossConfig()``.

    **Examples:**

    Example configuration for Basic worker::

        config = WorkerConfig(
            model_cfg=TransformerConfig(model_name="llama2-7b"),
            optim_cfg=OptimConfig(optimizer="adamw"),
            loss_cfg=GRPOLossConfig(policy_loss_cfg={"loss_type": "vanilla"}),
            lr_cfg=LRConfig(lr=1e-5),
            fsdp_cfg=FSDPConfig(),
            load_from="meta-llama/Llama-2-7b-hf",
            pack_max_length=2048,
        )

    .. note::
       When ``use_kl_loss=True`` in loss_cfg, a reference model will be loaded
       for KL divergence computation during training.
    """

    model_config = ConfigDict(title="Worker config", extra="forbid", arbitrary_types_allowed=True)
    model_cfg: TransformerConfig | BaseComposeConfig
    optim_cfg: OptimConfig
    loss_cfg: BaseRLLossConfig
    lr_cfg: LRConfig
    fsdp_cfg: FSDPConfig
    load_from: str | Path  # TODO: separate the actor and ref configurations
    optimizer_steps: int = 1
    sp_size: int = 1
    pack_max_length: int
    ref_load_from: str | Path | None = None
    ref_model_fsdp_cfg: FSDPConfig | None = None
    log_dir: str | Path | None = None
    update_weight_bucket_size_in_gb: float = 0.5  # 512MB
    seed: None | int = None  # if None, use RLTrainer seed
    profile_step: list[int] | int | None = None  # 1-based global RL train_step ids to profile.
    profile_time: bool = True
    profile_memory: bool = False
    # sft config
    sft_dataloader_cfg: DataloaderConfig | None = None
    sft_global_batch_size: int = -1
    rollout_steps_per_sft: int = 1
    sft_loss_cfg: CELossConfig = CELossConfig()

    def build(self, placement_group: "PlacementGroup"):
        """Build training workers and controller from this config and placement group.

        Args:
            placement_group: Ray placement group the training workers are created in.

        Returns:
            A ``TrainingController`` wrapping the created workers. The call returns only
            after every worker's ``ready`` task has finished.
        """
        # import here to avoid circular import
        from xtuner.v1.rl.trainer.controller import TrainingController
        from xtuner.v1.rl.utils import AutoAcceleratorWorkers

        TrainingWorkerCls = ray.remote(
            runtime_env={
                "env_vars": {
                    "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1",
                    "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1",
                    "HCCL_NPU_SOCKET_PORT_RANGE": "auto",
                }
            }
        )(TrainingWorker)
        train_workers, _ = AutoAcceleratorWorkers.from_placement_group(TrainingWorkerCls, self, placement_group)
        ready_refs = [w.ready.remote() for w in train_workers]
        # ray.wait defaults to num_returns=1 and would return as soon as a single worker
        # is ready; ask for all refs so the controller is only built once every worker is up.
        ray.wait(ready_refs, num_returns=len(ready_refs))
        return TrainingController(workers=train_workers)
class WorkerInputItem(TypedDict):
    """One packed batch handed to ``TrainingWorker.fit``."""

    seq_ctx: SequenceContext
    shifted_labels: torch.LongTensor
    advantages: torch.Tensor
    rollout_logprobs: torch.Tensor | None


class WorkerTrainLogItem(TypedDict, total=False):
    """Metrics logged for a single optimizer step (all keys optional)."""

    step_consumed_tokens: int
    efficient_attn_ratio: float
    grad_norm: float


class WorkerLogItem(TypedDict):
    """Metrics returned by ``TrainingWorker.fit`` for one rollout."""

    train_entropy: float
    rollout_entropy: NotRequired[float]
    mismatch_metrics: NotRequired[dict[str, float]]
    rollout_is_metrics: NotRequired[dict[str, float]]
    train_metrics: List[WorkerTrainLogItem]
    sft_train_metrics: NotRequired[dict[str, float]]


class TrainingWorker(SingleAcceleratorWorker, UpdateWeighter):
    """RL training worker.

    Owns the train engine, an optional reference model (built when
    ``loss_cfg.use_kl_loss`` is set) and an optional SFT dataloader that is
    trained on every ``rollout_steps_per_sft`` rollout steps.
    """

    # Names of the entries written below the checkpoint directory by `save`.
    _SAVE_WEIGHTS_DIR = "weights"
    # NOTE(review): despite the `_DIR` suffix, `save` passes this path straight
    # to `torch.save`, i.e. it is written as a single file.
    _SAVE_SFT_DATALOADER_DIR = "sft_dataloader"
    _SAVE_SFT_TRAIN_STATE_PATH = "sft_train_state.json"

    def __init__(
        self,
        worker_cfg: WorkerConfig,
        rank: int,
        master_addr: str,
        master_port: int,
        world_size: int,
        accelerator: str = "GPU",
    ):
        """Initialise logger, seeds, device mesh, engine and optional ref model.

        Args:
            worker_cfg (WorkerConfig): Worker configuration. Note that
                ``model_cfg.compile_cfg`` and ``ref_load_from`` may be
                overwritten on this object (see below).
            rank (int): Rank of this worker.
            master_addr (str): Forwarded to the base worker.
            master_port (int): Forwarded to the base worker.
            world_size (int): Forwarded to the base worker.
            accelerator (str): Forwarded to the base worker. Defaults to "GPU".
        """
        super().__init__(worker_cfg, rank, master_addr, master_port, world_size, accelerator)
        self.config = cast(WorkerConfig, self.config)
        torch.accelerator.set_device_index(int(os.environ["LOCAL_RANK"]))
        self.rank = rank
        # TODO: add lr scheduler
        log_dir = worker_cfg.log_dir
        self.log_dir = None
        if log_dir is not None:
            self.log_dir = Path(log_dir) if isinstance(log_dir, str) else log_dir
            self.logger = get_logger(log_dir=self.log_dir, tag="TrainingWorker")
        else:
            self.logger = get_logger()
        self._set_deterministic()
        self._set_random_seed(worker_cfg.seed)
        self.data_mesh = self._init_data_mesh(sp_size=worker_cfg.sp_size)
        self.sp_mesh = self.data_mesh["sp"]
        # Needs `self.data_mesh` (dp sub-mesh) and `self.logger`, so it runs after both.
        self._init_sft(worker_cfg)

        # Disable model compilation when FSDP-level torch.compile is off.
        if not worker_cfg.fsdp_cfg.torch_compile:
            worker_cfg.model_cfg.compile_cfg = False
        self._engine = self._build_engine(worker_cfg)

        self._has_ref = False
        if worker_cfg.loss_cfg.use_kl_loss:
            self._has_ref = True
            # The reference model defaults to the same checkpoint as the actor.
            if worker_cfg.ref_load_from is None:
                worker_cfg.ref_load_from = worker_cfg.load_from
            self._ref_model = self._build_ref_model(
                worker_cfg.model_cfg, worker_cfg.ref_load_from, worker_cfg.ref_model_fsdp_cfg
            )

        self._optimizer_steps = worker_cfg.optimizer_steps
        # Normalise `profile_step` (int | list | None) into a set for O(1) lookup.
        profile_step = worker_cfg.profile_step
        if isinstance(profile_step, int):
            profile_step = [profile_step]
        self._profile_step = set(profile_step or [])
        self._profile_time = worker_cfg.profile_time
        self._profile_memory = worker_cfg.profile_memory
        self._global_train_step = 0

        # Log-prob computation follows the loss config: chunked when a chunk size is given.
        if worker_cfg.loss_cfg.chunk_size is not None:
            mode = "chunk"
        else:
            mode = "eager"
        self.logprob_cfg = LogProbConfig(chunk_size=worker_cfg.loss_cfg.chunk_size, mode=mode)

        # MTP config is only looked up on the text config of compose models.
        self.mtp_config = None
        if isinstance(worker_cfg.model_cfg, BaseComposeConfig):
            if hasattr(worker_cfg.model_cfg.text_config, "mtp_config"):
                self.mtp_config = worker_cfg.model_cfg.text_config.mtp_config

        self._init_update_weighter()

    def _init_sft(self, worker_cfg: WorkerConfig):
        """Initialise SFT bookkeeping and, if configured, the SFT dataloader.

        The dataloader is only built when ``sft_dataloader_cfg`` is set; in that
        case ``sft_global_batch_size`` must be positive and ``seed`` must be set.
        """
        self._sft_dataloader_config = worker_cfg.sft_dataloader_cfg
        self._sft_dataloader: Dataloader | None = None
        self._sft_dataloader_iter: Iterable | None = None
        self._sft_loss_cfg: CELossConfig | None = None
        self._rollout_steps_per_sft = worker_cfg.rollout_steps_per_sft
        self._rollout_step = 0
        self._sft_cur_epoch = 0
        self._sft_total_consumed_tokens = 0
        if self._sft_dataloader_config is not None:
            assert worker_cfg.sft_global_batch_size > 0, "sft_global_batch_size must be greater than 0"
            assert worker_cfg.seed is not None, "seed must be set when sft_dataloader_config is not None"
            tokenizer = AutoTokenizer.from_pretrained(worker_cfg.load_from, trust_remote_code=True)
            self._sft_dataloader = self._sft_dataloader_config.build(
                tokenizer=tokenizer,
                dp_mesh=self.data_mesh["dp"],
                global_batch_size=worker_cfg.sft_global_batch_size,
                micro_batch_size=1,
                seed=worker_cfg.seed,
            )
            self.logger.info(f"Sft Dataloader len: {len(self._sft_dataloader)}")
            sft_loss_cfg = worker_cfg.sft_loss_cfg
            if worker_cfg.sft_loss_cfg is None:
                sft_loss_cfg = CELossConfig()
            self._sft_loss_cfg = sft_loss_cfg

    def _set_deterministic(self):
        """Enable deterministic algorithms when ``XTUNER_DETERMINISTIC`` is truthy."""
        if XTUNER_DETERMINISTIC:
            self.logger.info("Setting deterministic algorithms of TrainingWorker.")
            set_deterministic()

    def _set_random_seed(self, seed: None | int):
        """Forward ``seed`` (possibly None) to ``set_random_seed``."""
        set_random_seed(seed)

    def _build_engine(self, worker_cfg: WorkerConfig) -> TrainEngine:
        """Create the train engine and load HF weights from ``load_from`` if given."""
        engine = TrainEngine(  # type: ignore
            optim_cfg=worker_cfg.optim_cfg,
            fsdp_cfg=worker_cfg.fsdp_cfg,
            model_cfg=worker_cfg.model_cfg,
        )
        if worker_cfg.load_from is not None:
            engine.from_hf(worker_cfg.load_from)
        # Only rank 0 logs the compile config to avoid duplicated output.
        if engine.model.compile_cfg is not None and self.rank == 0:
            self.logger.info(f"The `compile_cfg` of model is {json.dumps(engine.model.compile_cfg, indent=4)}")
        return engine

    def _build_ref_model(
        self,
        ref_model_cfg: TransformerConfig | BaseComposeConfig,
        load_from: str | Path,
        ref_model_fsdp_cfg: FSDPConfig | None = None,
    ):
        """Build the reference model, load its weights and park it on CPU.

        Args:
            ref_model_cfg: Model config used to build the reference model.
            load_from: HF checkpoint path the weights are loaded from.
            ref_model_fsdp_cfg: FSDP config for the reference model. When None,
                a config with ``recompute_ratio=0``, ``cpu_offload=False`` and
                ``requires_grad=False`` is used.

        Returns:
            The sharded reference model in eval mode, moved to CPU.
        """
        # TODO: needs refactoring so that MLLM is supported more elegantly
        model: BaseComposeModel | XtunerBaseModel
        # Build on the meta device; real weights come from `from_hf` after sharding.
        with torch.device("meta"):
            model = ref_model_cfg.build()
        if isinstance(ref_model_cfg, BaseComposeConfig):
            assert ref_model_cfg.text_config.float8_cfg is None, "BaseComposeConfig does not support float8"
            if ref_model_fsdp_cfg is None:
                ref_model_fsdp_cfg = FSDPConfig(recompute_ratio=0, cpu_offload=False, requires_grad=False)
            model = model.fully_shard(ref_model_fsdp_cfg)
            model.from_hf(hf_path=load_from)
            model.eval()  # type: ignore
        else:
            ref_model_cfg = cast(TransformerConfig, ref_model_cfg)
            if ref_model_cfg.float8_cfg is not None and ref_model_cfg.float8_cfg.enable_float8:
                float8_handler = Float8Handler(
                    scaling_granularity_gemm=ref_model_cfg.float8_cfg.scaling_granularity_gemm,
                    scaling_granularity_grouped_gemm=ref_model_cfg.float8_cfg.scaling_granularity_grouped_gemm,
                )
            else:
                float8_handler = None
            if ref_model_fsdp_cfg is None:
                ref_model_fsdp_cfg = FSDPConfig(recompute_ratio=0, cpu_offload=False, requires_grad=False)
            model = model.fully_shard(ref_model_fsdp_cfg)  # type: ignore
            model.from_hf(hf_path=load_from)
            model.eval()  # type: ignore
            if float8_handler is not None:
                # As the ref model is not updated, we only compute params' scales once
                float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)  # type: ignore
        # Keep the reference model off the accelerator until it is needed.
        model.to_device("cpu")  # type: ignore
        DEVICE_MODULE.empty_cache()
        return model

    def _init_data_mesh(
        self,
        sp_size: int,
    ):
        """Create the 2-D ``("dp", "sp")`` device mesh.

        Args:
            sp_size (int): Sequence parallel size; must divide the world size.

        Raises:
            ParallelConfigException: If the world size is not divisible by ``sp_size``.
        """
        world_size = dist.get_world_size()
        if world_size % sp_size != 0:
            # NOTE(review): the two adjacent literals are concatenated without a
            # separating space, so the message reads "...{world_size}.sequence ...".
            raise ParallelConfigException(
                f"Found sp_size {sp_size}, world_size {world_size}."
                "sequence parallel size must be a divisor of world size."
            )
        dp_size = world_size // sp_size

        # TODO: fsdp_config could be None
        device = str(DEVICE) if not self.config.fsdp_cfg.cpu_offload else "cpu"

        data_mesh = init_device_mesh(
            device,
            (dp_size, sp_size),
            mesh_dim_names=("dp", "sp"),
        )
        return data_mesh

    def compute_actor_logprobs(
        self,
        seq_ctx_list: list[SequenceContext],
        shifted_labels_list: list[torch.Tensor],
    ) -> list[torch.Tensor]:
        """Run forward-only passes of the actor and collect per-batch log-probs.

        Args:
            seq_ctx_list: Sequence contexts, one per batch.
            shifted_labels_list: Shifted labels aligned with ``seq_ctx_list``.

        Returns:
            The ``"loss"`` entry of each forward output, in input order.
        """
        # precompute float8 dynamic scale only once
        self._engine._maybe_precompute_float8_dynamic_scale_for_fsdp()
        old_logprobs_list: list[torch.Tensor] = []
        for seq_ctx, shifted_labels in zip(seq_ctx_list, shifted_labels_list):
            loss_ctx = self.logprob_cfg.build(data={"shifted_labels": shifted_labels})
            assert loss_ctx is not None
            output = self._engine.forward_only(seq_ctx=seq_ctx, loss_ctx=loss_ctx)
            old_logprobs_list.append(output["loss"])
        return old_logprobs_list

    def compute_ref_logprobs(
        self, seq_ctx_list: list[SequenceContext], shifted_labels_list: list[torch.Tensor]
    ) -> list[torch.Tensor]:
        """Compute per-batch log-probs with the reference model.

        The reference model is moved to ``DEVICE`` for the computation and back
        to CPU afterwards. Requires that a reference model was built.
        """
        assert self._has_ref
        self._ref_model.to_device(DEVICE)
        ref_logprobs_list: list[torch.Tensor] = []
        for seq_ctx, shifted_labels in zip(seq_ctx_list, shifted_labels_list):
            with torch.no_grad():
                loss_ctx = self.logprob_cfg.build(data={"shifted_labels": shifted_labels})
                assert loss_ctx is not None
                ref_output = self._ref_model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})
            ref_logprobs_list.append(ref_output["loss"])
        self._ref_model.to_device("cpu")
        return ref_logprobs_list

    def _add_rollout_routed_experts(
        self, seq_ctx: SequenceContext, rollout_routed_experts: torch.Tensor | list[torch.Tensor | ray.ObjectRef]
    ):
        """Materialise ``seq_ctx.rollout_routed_experts`` as a single tensor.

        Args:
            seq_ctx: Sequence context updated in place.
            rollout_routed_experts: Either a list whose entries are tensors or
                Ray object refs, or a single tensor. Tensor inputs are treated as
                placeholders and replaced by random expert ids; object refs are
                fetched from the object store and then freed.
        """
        language_cfg = (
            self.config.model_cfg.text_config
            if isinstance(self.config.model_cfg, BaseComposeConfig)
            else self.config.model_cfg
        )
        to_free_routed_expert_refs: list[ray.ObjectRef] = []
        if isinstance(rollout_routed_experts, list):
            # list[n,l,e]
            out_rollout_routed_expert = []
            for rollout_routed_expert in rollout_routed_experts:
                if isinstance(rollout_routed_expert, torch.Tensor):
                    # Placeholder tensor: only its first dimension is used, the
                    # content is replaced by random expert ids.
                    rollout_routed_experts_tensor = torch.randint(
                        low=0,
                        high=language_cfg.n_routed_experts,
                        size=(
                            rollout_routed_expert.size(0),
                            language_cfg.num_hidden_layers,
                            language_cfg.num_experts_per_tok,
                        ),
                    )
                    out_rollout_routed_expert.append(rollout_routed_experts_tensor)
                else:
                    rollout_routed_expert_refs = rollout_routed_expert
                    rollout_routed_expert = ray.get(rollout_routed_expert_refs)
                    # free obj store explicitly
                    if self.sp_mesh is None or self.sp_mesh.size() == 1:
                        ray.internal.free(rollout_routed_expert_refs, local_only=False)
                    else:
                        if self.sp_mesh.get_local_rank() == 0:
                            # only free once of sp mesh
                            # (deferred until after the barrier below)
                            to_free_routed_expert_refs.append(rollout_routed_expert_refs)
                    rollout_routed_expert = torch.as_tensor(rollout_routed_expert, dtype=torch.long)
                    rollout_routed_expert = rollout_routed_expert.reshape(
                        -1, language_cfg.num_hidden_layers, language_cfg.num_experts_per_tok
                    )
                    out_rollout_routed_expert.append(rollout_routed_expert)
            seq_ctx.rollout_routed_experts = torch.cat(out_rollout_routed_expert, dim=0)  # max_len,l,e
        else:
            assert isinstance(rollout_routed_experts, torch.Tensor), (
                f"padding experts should be a dummy tensor, bug got {type(rollout_routed_experts)}"
            )
            rollout_routed_experts_tensor = torch.randint(
                low=0,
                high=language_cfg.n_routed_experts,
                size=(
                    self.config.pack_max_length,
                    language_cfg.num_hidden_layers,
                    language_cfg.num_experts_per_tok,
                ),
            )
            seq_ctx.rollout_routed_experts = rollout_routed_experts_tensor

        assert seq_ctx.input_ids is not None, "input_ids is None"
        assert seq_ctx.rollout_routed_experts.size(0) == seq_ctx.input_ids.size(1)

        if self.sp_mesh is not None and self.sp_mesh.size() > 1:
            # Barrier before the deferred frees; presumably so every rank has
            # finished its `ray.get` first -- TODO confirm.
            dist.barrier()
            for free_routed_expert_refs in to_free_routed_expert_refs:
                ray.internal.free(free_routed_expert_refs, local_only=False)
        del to_free_routed_expert_refs

    @contextmanager
    def _maybe_profiling(self, global_train_step: int, phase: str):
        """Context manager that profiles the wrapped code on selected steps.

        Does nothing unless ``global_train_step`` is in the configured profile
        steps. Output goes below ``<log_dir parent>`` (or ``$WORK_DIR`` / ``.``
        when no log dir is set) in ``profiling_time/`` and ``profiling_memory/``
        sub-directories keyed by ``phase`` and step.
        """
        if global_train_step not in self._profile_step:
            yield
            return

        if self.log_dir is not None:
            profile_home = self.log_dir.parent
        else:
            profile_home = Path(os.environ.get("WORK_DIR", "."))
        # ExitStack lets time and memory profiling be enabled independently.
        with contextlib.ExitStack() as stack:
            if self._profile_time:
                time_dir = profile_home / "profiling_time" / phase / f"global-step-{global_train_step}"
                stack.enter_context(profiling_time(time_dir))
            if self._profile_memory:
                memory_dir = profile_home / "profiling_memory" / phase / f"global-step-{global_train_step}"
                stack.enter_context(profiling_memory(memory_dir))
            yield

    @ray_method
    def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLogItem:
        """Train on the batches of one rollout and return the collected metrics.

        Steps, in order: prepare sequence/loss contexts, compute actor log-probs
        and entropy, compute rollout importance-sampling metrics, optionally
        compute reference log-probs and KL, run the optimizer steps, and
        optionally run one SFT step.

        Args:
            data_batches: Packed batches for this rollout.
            rollout_idx: Rollout index, used for logging only.

        Returns:
            WorkerLogItem: Entropy, mismatch / importance-sampling metrics and
            per-optimizer-step training metrics. When the environment variable
            ``ONLY_CALC_MISMATCH_RATIO`` is ``"1"`` the method returns before
            any training happens.
        """
        # NOTE: sglang clears the logger handlers, so the logger is recreated here
        self.logger = get_logger(log_dir=self.log_dir, tag="TrainingWorker")
        loss_cfg: BaseRLLossConfig = self.config.loss_cfg
        num_batches = len(data_batches)
        # Number of batches accumulated per optimizer step.
        iters_per_step = math.ceil(num_batches / self._optimizer_steps)
        if num_batches < self._optimizer_steps:
            self.logger.info(
                f"Optimizer only step once because num_batches {num_batches} < optimizer_steps {self._optimizer_steps}."
            )

        # Update seq_ctx: pixel_values, rollout_routed_experts
        # Init loss_ctx: shifted_labels, advantages, rollout_logprobs
        seq_ctx_list: list[SequenceContext] = []
        loss_ctx_list: list[BaseRLLossContext] = []
        mtp_loss_ctx_list: list[list[MTPLossContext]] = []
        prepare_inputs_begin = time.perf_counter()
        for data in data_batches:
            # update seq_ctx
            seq_ctx = data["seq_ctx"]
            pixel_values = seq_ctx.pixel_values
            if pixel_values is not None:
                if not isinstance(pixel_values, np.ndarray):
                    assert isinstance(pixel_values, list), (
                        f"pixel_values should be list of tensor, got {type(pixel_values)}"
                    )
                    # Entries are resolved through the Ray object store, then concatenated.
                    pixel_values = ray.get(list(pixel_values))
                    pixel_values = [torch.as_tensor(pixel_value) for pixel_value in pixel_values]
                    pixel_values = torch.cat(pixel_values, dim=0)
                    seq_ctx.pixel_values = pixel_values
                else:
                    raise NotImplementedError("The case where pixel_values is a numpy array is not implemented yet.")

            rollout_routed_experts = seq_ctx.rollout_routed_experts
            if rollout_routed_experts is not None:
                self._add_rollout_routed_experts(seq_ctx, rollout_routed_experts)

            seq_ctx = data["seq_ctx"].to(DEVICE)
            if self.sp_mesh.size() > 1:
                seq_ctx = seq_ctx.split(self.sp_mesh)

            # init loss_ctx
            shifted_labels = data["shifted_labels"].to(DEVICE)
            advantages = data["advantages"].to(DEVICE)
            rollout_logprobs = data.get("rollout_logprobs", None)
            rollout_logprobs = rollout_logprobs.to(DEVICE) if rollout_logprobs is not None else None
            loss_ctx = loss_cfg.build(
                data={
                    "shifted_labels": shifted_labels,
                    "advantages": advantages,
                    "rollout_logprobs": rollout_logprobs,
                },
                sp_mesh=self.sp_mesh,
            )
            seq_ctx_list.append(seq_ctx)
            assert loss_ctx is not None
            loss_ctx_list.append(loss_ctx)

            if self.mtp_config is not None:
                # One MTP loss context per MTP layer (depth is 1-based).
                mtp_loss_ctxs_per_batch: list[MTPLossContext] = []
                for mtp_idx in range(self.mtp_config.num_layers):
                    mtp_loss_cfg = MTPLossConfig(
                        **loss_cfg.model_dump(include={"mode", "chunk_size"}),
                        mtp_depth=mtp_idx + 1,
                        detach_mtp_lm_head_weight=self.mtp_config.detach_mtp_lm_head_weight,
                    )
                    mtp_ctx = mtp_loss_cfg.build(
                        data={
                            "shifted_labels": shifted_labels,
                            "seq_ctx": seq_ctx,
                            "logprobs": rollout_logprobs,
                        },
                        sp_mesh=self.sp_mesh,
                    )
                    if mtp_ctx is not None:
                        mtp_loss_ctxs_per_batch.append(mtp_ctx)
                mtp_loss_ctx_list.append(mtp_loss_ctxs_per_batch)
        self.logger.debug(
            f"Rank{self.rank} Rollout {rollout_idx} prepare_inputs elapsed="
            f"{time.perf_counter() - prepare_inputs_begin:.4f}s"
        )

        # Drop the local reference to the raw inputs; only the contexts are used below.
        del data_batches

        # When sp_mesh.size() > 1, get the sp_split shifted_labels and rollout_logprobs
        shifted_labels_list = [loss_ctx.loss_kwargs.shifted_labels for loss_ctx in loss_ctx_list]
        rollout_logprobs_list = [loss_ctx.loss_kwargs.rollout_logprobs for loss_ctx in loss_ctx_list]

        # compute old logprobs
        old_logprobs_list = self.compute_actor_logprobs(seq_ctx_list, shifted_labels_list)
        for old_logprobs, loss_ctx in zip(old_logprobs_list, loss_ctx_list):
            loss_ctx.loss_kwargs.old_logprobs = old_logprobs

        worker_log_item: WorkerLogItem = {"train_entropy": 0.0, "train_metrics": [], "sft_train_metrics": {}}
        logger_msg = f"Rollout {rollout_idx}: "

        # compute entropy
        # Count tokens whose label is not the ignore index (-100) on this rank.
        rank_grad_tokens: torch.Tensor | None = None
        for shifted_labels in shifted_labels_list:
            mask = shifted_labels != -100
            grad_tokens = mask.sum()
            rank_grad_tokens = grad_tokens if rank_grad_tokens is None else rank_grad_tokens + grad_tokens
        # NOTE(review): with an empty `data_batches` this is still None and the
        # all_reduce below would fail -- callers presumably never pass an empty list.
        rank_grad_tokens = cast(torch.Tensor, rank_grad_tokens)
        # Same tensor object: the in-place all_reduce also overwrites `rank_grad_tokens`.
        global_grad_tokens = rank_grad_tokens
        dist.all_reduce(global_grad_tokens, op=dist.ReduceOp.SUM)

        avg_sum_entropy = calculate_entropy(shifted_labels_list, old_logprobs_list, global_grad_tokens)
        avg_rollout_entropy = calculate_entropy(shifted_labels_list, rollout_logprobs_list, global_grad_tokens)
        assert avg_sum_entropy is not None
        worker_log_item["train_entropy"] = avg_sum_entropy.item()
        logger_msg += f"avg entropy: {avg_sum_entropy:.4f}"
        if avg_rollout_entropy is not None:
            worker_log_item["rollout_entropy"] = avg_rollout_entropy.item()
            logger_msg += f", avg rollout entropy: {avg_rollout_entropy:.4f}"

        # compute rollout importance sampling metrics
        all_rollout_is_metrics = []
        all_mismatch_metrics = []
        for i, loss_ctx in enumerate(loss_ctx_list):
            if loss_ctx.loss_kwargs.rollout_logprobs is not None:
                # calculate importance sampling weights
                num_tokens = seq_ctx_list[i].seq_lens_q
                mismatch_metrics, rollout_is_metrics = loss_ctx.compute_rollout_is(self.sp_mesh, num_tokens)
                all_rollout_is_metrics.append(rollout_is_metrics)
                all_mismatch_metrics.append(mismatch_metrics)

        if len(all_mismatch_metrics) > 0:
            mismatch_metrics = merge_rollout_is_metrics(all_mismatch_metrics, DEVICE)
            if len(mismatch_metrics) > 0:
                worker_log_item["mismatch_metrics"] = mismatch_metrics
                logger_msg += f"\n rollout mismatch metrics:\n{json.dumps(mismatch_metrics, indent=4)}"

        if len(all_rollout_is_metrics) > 0:
            rollout_is_metrics = merge_rollout_is_metrics(all_rollout_is_metrics, DEVICE)
            if len(rollout_is_metrics) > 0:
                worker_log_item["rollout_is_metrics"] = rollout_is_metrics
                logger_msg += f"\n rollout importance sampling metrics:\n{json.dumps(rollout_is_metrics, indent=4)}"

        if self.rank == 0:
            self.logger.info(logger_msg)

        # Debug switch: return the metrics gathered so far without training.
        only_calc_mismatch_ratio = os.environ.get("ONLY_CALC_MISMATCH_RATIO", "0") == "1"
        if only_calc_mismatch_ratio:
            return worker_log_item

        # compute reference logprobs
        ref_logprobs_list: list[torch.Tensor] | None = None
        if self._has_ref:
            ref_logprobs_list = self.compute_ref_logprobs(seq_ctx_list, shifted_labels_list)
            for i, loss_ctx in enumerate(loss_ctx_list):
                loss_ctx.loss_kwargs.ref_logprobs = ref_logprobs_list[i]

            # KL between actor and reference log-probs; computed for logging here.
            kl_div_sum: torch.Tensor | None = None
            for i, shifted_labels in enumerate(shifted_labels_list):
                mask = shifted_labels != -100
                kl_div = kl_penalty(
                    cast(torch.Tensor, old_logprobs_list[i]),
                    cast(torch.Tensor, ref_logprobs_list[i]),
                    loss_weights=mask,
                    kl_penalty="low_var_kl",
                )
                kl_div_sum = kl_div if kl_div_sum is None else kl_div_sum + kl_div

            kl_div_sum = cast(torch.Tensor, kl_div_sum)
            dist.all_reduce(kl_div_sum, op=dist.ReduceOp.SUM)
            avg_kl_div = kl_div_sum / global_grad_tokens if global_grad_tokens > 0 else 0
            self.logger.info(f"Rollout {rollout_idx}: avg KL divergence: {avg_kl_div:.4f}")

        # compute batched loss context
        batched_loss_ctx_list: list[BaseRLLossContext] = []
        batched_mtp_loss_ctx_list: list[list[MTPLossContext]] = []
        LossContext = loss_cfg.loss_ctx_cls
        # Group contexts by optimizer step: each slice of `iters_per_step`
        # batches is batched together.
        for i in range(0, len(loss_ctx_list), iters_per_step):
            batches_loss_ctx = loss_ctx_list[i : i + iters_per_step]
            batched_loss_ctx_list.extend(
                LossContext.build_batches(batches_loss_ctx)  # type: ignore[arg-type]
            )
            if self.mtp_config is not None:
                batches_seq_ctx = seq_ctx_list[i : i + iters_per_step]
                cu_seq_lens_list = [seq_ctx.cu_seq_lens_q for seq_ctx in batches_seq_ctx]
                # mtp_loss_ctx_list: list[list[MTPLossContext]], outer=batch, inner=mtp_depth
                num_mtp_depths = len(mtp_loss_ctx_list[0]) if mtp_loss_ctx_list else 0
                for mtp_idx in range(num_mtp_depths):
                    depth_mtp_loss_ctxs: list[LMHeadLossContext] = [
                        mtp_loss_ctx_list[j][mtp_idx]
                        for j in range(i, min(i + iters_per_step, len(mtp_loss_ctx_list)))
                    ]
                    batched_mtp_depth_ctxs = cast(
                        list[MTPLossContext],
                        MTPLossContext.build_batches(
                            depth_mtp_loss_ctxs,
                            cu_seq_lens_list=cu_seq_lens_list,
                            sp_mesh=self.sp_mesh,
                        ),
                    )
                    # Append each depth's batched ctx to the corresponding batch index
                    for batch_offset, mtp_ctx in enumerate(batched_mtp_depth_ctxs):
                        global_batch_idx = i + batch_offset
                        if global_batch_idx >= len(batched_mtp_loss_ctx_list):
                            batched_mtp_loss_ctx_list.append([mtp_ctx])
                        else:
                            batched_mtp_loss_ctx_list[global_batch_idx].append(mtp_ctx)

        # train optimizer steps
        for i in range(0, len(seq_ctx_list), iters_per_step):
            global_train_step = self._global_train_step + 1
            batches_seq_ctx = seq_ctx_list[i : i + iters_per_step]
            batches_loss_ctx = batched_loss_ctx_list[i : i + iters_per_step]
            engine_input = [
                ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})
                for seq_ctx, loss_ctx in zip(batches_seq_ctx, batches_loss_ctx)
            ]
            if self.mtp_config is not None:
                # Rebuild the engine input so that each item also carries its MTP contexts.
                batches_mtp_loss_ctxs = batched_mtp_loss_ctx_list[i : i + iters_per_step]
                engine_input = [
                    ModelItem(
                        seq_ctx=seq_ctx,
                        loss_ctx=cast(
                            dict[str, BaseLossContext],
                            {"mtp": mtp_loss_ctx_depths, "lm": loss_ctx},
                        ),
                    )
                    for seq_ctx, loss_ctx, mtp_loss_ctx_depths in zip(
                        batches_seq_ctx, batches_loss_ctx, batches_mtp_loss_ctxs
                    )
                ]

            train_step_begin = time.perf_counter()
            with self._maybe_profiling(global_train_step, "train_step"):
                train_step_info = self._engine.train_step(
                    data_batches=engine_input,
                )
            self.logger.debug(
                f"Rank{self.rank} Rollout {rollout_idx} GlobalStep {global_train_step} "
                f"train_step[{i}].engine_train_step elapsed={time.perf_counter() - train_step_begin:.4f}s"
            )
            grad_norm = self._engine.clip_grad_norm()
            self._engine.step_optimizer(grad_norm)
            # Pop the nested entries so the remaining keys can be splatted into the log item.
            engine_logs_info = cast(dict[str, float], train_step_info.pop("logs_info"))  # type: ignore[misc]
            engine_extra_info = train_step_info.pop("extra_info")  # type: ignore[misc]
            if isinstance(engine_extra_info, ModelForwardExtraLogInfo):
                extra_info_dict = engine_extra_info.get()
            else:
                extra_info_dict = cast(dict, engine_extra_info)
            # Keep only scalar-like values and convert tensors to Python numbers.
            extra_info_dict = {
                k: v.item() if isinstance(v, torch.Tensor) else v
                for k, v in extra_info_dict.items()
                if isinstance(v, (torch.Tensor, int, float))
            }
            extra_info_dict = finalize_train_policy_metrics(extra_info_dict, DEVICE)
            train_step_info.pop("total_loss")  # type: ignore[misc]
            train_log_item = WorkerTrainLogItem(
                **engine_logs_info,  # type: ignore[typeddict-item]
                **train_step_info,
                **extra_info_dict,
                grad_norm=grad_norm.item(),
            )
            worker_log_item["train_metrics"].append(train_log_item)

            # Extract logs_info for logging
            log_str = ", ".join(
                f"{key}={value:.4f}" if isinstance(value, float) else f"{key}={value}"
                for key, value in train_log_item.items()
                if not key.startswith("reduced_train_policy_") and key != "max_ratio"
            )
            log_str = f"Rank{self.rank} Rollout {rollout_idx} Step {i}: " + log_str
            self.logger.info(log_str)
            self._global_train_step = global_train_step

        self._rollout_step += 1
        # Interleave one SFT step every `rollout_steps_per_sft` rollouts.
        if self._sft_dataloader is not None and self._rollout_step % self._rollout_steps_per_sft == 0:
            train_step_info = self._fit_sft()
            engine_logs_info = train_step_info["logs_info"]
            worker_log_item["sft_train_metrics"] = {
                **engine_logs_info,
                **train_step_info["extra_info"].get(),
                "efficient_attn_ratio": train_step_info["efficient_attn_ratio"],
            }
        return worker_log_item

    def _fit_sft(self):
        """Run one SFT optimizer step, log it and return the engine step info."""
        self.logger.info(f"Train SFT after {self._rollout_step} RL steps")
        # The iterator is created lazily on the first SFT step.
        if self._sft_dataloader_iter is None:
            self._sft_dataloader_iter = iter(self._sft_dataloader)
        time_before_get_data = time.time()
        data_batch = self._next_sft_data_batch()
        time_before_train_step = time.time()
        data_time = time_before_train_step - time_before_get_data
        DEVICE_MODULE.reset_peak_memory_stats()
        train_step_info, grad_norm = self._train_one_step_sft(data_batch)
        time_after_train_step = time.time()
        step_time = time_after_train_step - time_before_train_step
        step_consumed_tokens = train_step_info["step_consumed_tokens"]
        reduced_step_consumed_tokens = self._reduce_number_across_rank(step_consumed_tokens)
        self._sft_total_consumed_tokens += reduced_step_consumed_tokens
        self._sft_log_step(
            train_step_info=train_step_info,
            local_step_consumed_tokens=step_consumed_tokens,
            step_consumed_tokens=reduced_step_consumed_tokens,
            total_consumed_tokens=self._sft_total_consumed_tokens,
            data_time=data_time,
            step_time=step_time,
            grad_norm=grad_norm,
        )
        return train_step_info

    def _next_sft_data_batch(self):
        """Return the next SFT batch, starting a new epoch when exhausted."""
        try:
            data = next(self._sft_dataloader_iter)  # type: ignore[assignment]
        except StopIteration:
            # End of epoch: bump the epoch counter and restart the iterator.
            self._sft_cur_epoch += 1
            self._sft_dataloader.set_epoch(self._sft_cur_epoch)
            self._sft_dataloader_iter = iter(self._sft_dataloader)
            data = next(self._sft_dataloader_iter)
        return data

    def _train_one_step_sft(self, data_batch):
        """Run one SFT train step (forward/backward, grad clipping, optimizer).

        Args:
            data_batch: Iterable of items with ``"seq_ctx"`` and
                ``"shifted_labels"`` entries.

        Returns:
            Tuple of the engine train-step info and the gradient norm.
        """
        seq_ctx_list: list[SequenceContext] = []
        loss_cfg: CELossConfig = self._sft_loss_cfg
        loss_ctx_list: list[CELossContext] = []
        for data in data_batch:
            seq_ctx = data["seq_ctx"].to(DEVICE)
            if self.sp_mesh.size() > 1:
                seq_ctx = seq_ctx.split(sequence_parallel_mesh=self.sp_mesh)
            seq_ctx_list.append(seq_ctx)
            loss_ctx = loss_cfg.build(data={"shifted_labels": data["shifted_labels"]}, sp_mesh=self.sp_mesh)
            loss_ctx_list.append(loss_ctx)
        del data_batch

        cu_seq_lens_list = [seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list]
        loss_ctx_list = CELossContext.build_batches(
            loss_ctx_list, cu_seq_lens_list=cu_seq_lens_list, sp_mesh=self.sp_mesh
        )
        engine_input = [
            ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})
            for seq_ctx, loss_ctx in zip(seq_ctx_list, loss_ctx_list)
        ]
        train_step_info = self._engine.train_step(engine_input)
        grad_norm = self._engine.clip_grad_norm()
        self._engine.step_optimizer(grad_norm)
        return train_step_info, grad_norm

    def _sft_log_step(
        self,
        train_step_info: TrainStepInfo,
        local_step_consumed_tokens: int,
        step_consumed_tokens: int,
        total_consumed_tokens: int,
        data_time: float,
        step_time: float,
        grad_norm: torch.Tensor,
    ):
        """Log timing, token counts, losses, grad norm and memory of an SFT step.

        Args:
            train_step_info: Engine step info; ``logs_info``, ``total_loss`` and
                ``efficient_attn_ratio`` are read from it.
            local_step_consumed_tokens: Tokens consumed by this rank in the step.
            step_consumed_tokens: Tokens consumed by all ranks in the step.
            total_consumed_tokens: Accumulated tokens consumed by all ranks.
            data_time: Seconds spent fetching the batch.
            step_time: Seconds spent in the train step.
            grad_norm: Gradient norm of the step.
        """
        # Throughput of this rank in tokens per second.
        tgs = local_step_consumed_tokens / step_time
        logs_info = train_step_info.get("logs_info", {})
        log_items = [f"{k}: {v:.8f}" for k, v in logs_info.items() if "loss" in k]
        log_items.append(f"total_loss: {train_step_info['total_loss']:.8f}")
        loss_log_str = ", ".join(log_items)
        max_memory = DEVICE_MODULE.max_memory_allocated()  # type: ignore[attr-defined]
        reserved_memory = DEVICE_MODULE.max_memory_reserved()  # type: ignore[attr-defined]
        self.logger.info(
            f"Rank{self.rank} Step {self._rollout_step}: data_time: {data_time:.4f} time: {step_time:.4f} "
            f"text_tokens: {local_step_consumed_tokens} "
            f"step_consumed_tokens: {step_consumed_tokens} "
            f"total_consumed_tokens: {total_consumed_tokens} "
            f"efficient_attn_ratio: {train_step_info['efficient_attn_ratio']:.4f} "
            f"{loss_log_str} "
            f"grad_norm: {grad_norm:.8f} "
            f"max_memory: {max_memory / (1024**3):.2f} GB "
            f"reserved_memory: {reserved_memory / (1024**3):.2f} GB "
            f"tgs: {tgs:.4f}"
        )

    def _reduce_number_across_rank(self, rank_number: int) -> int:
        """Return the sum of ``rank_number`` over all ranks (collective call)."""
        _gathered_list = [None for _ in range(dist.get_world_size())]
        dist.all_gather_object(_gathered_list, rank_number)
        reduced_number = sum(_gathered_list)  # type: ignore[arg-type]
        return reduced_number

    @ray_method
    def save_hf(self, hf_dir: str, save_dtype: torch.dtype = torch.bfloat16):
        """Save the engine's model in HF format to ``hf_dir`` using ``save_dtype``."""
        self._engine.save_hf(hf_dir, save_dtype)

    @ray_method
    def get_data_replicate_size(self) -> int:
        """Get the data replicate size for the training worker."""
        # tp and pp will affect the data replicate size in engine
        # sp will affect the data replicate size in worker
        return self._engine.data_replicate_size * self.sp_mesh.size()

    @ray_method
    def get_model_cfg(self):
        """Return the model config held by the engine."""
        model_cfg = self._engine.model_cfg
        return model_cfg

    @ray_method
    def offload_model(self):
        """Offload the model of the training worker to CPU and free cached memory."""
        self._engine.put_model_to_device("cpu")
        DEVICE_MODULE.empty_cache()
        self.logger.info(
            f"Offloaded model to CPU. Current allocate {DEVICE_MODULE.memory_allocated() / (1024**2)} MB, reserved: {DEVICE_MODULE.memory_reserved() / (1024**2)} MB"
        )

    @ray_method
    def offload_optimizer(self):
        """Offload the optimizer of the training worker."""
        self._engine.put_optimizer_to_device("cpu")
        DEVICE_MODULE.empty_cache()
        self.logger.info(
            f"Offloaded optimizer to CPU. Current allocate {DEVICE_MODULE.memory_allocated() / (1024**2)} MB, "
            f"reserved: {DEVICE_MODULE.memory_reserved() / (1024**2)} MB"
        )

    @ray_method
    def onload_model(self):
        """Move the model of the training worker back to ``DEVICE``."""
        self._engine.put_model_to_device(DEVICE)

    @ray_method
    def onload_optimizer(self):
        """Move the optimizer of the training worker back to ``DEVICE``."""
        self._engine.put_optimizer_to_device(DEVICE)

    @ray_method
    def save(self, checkpoint_path: Path | str, no_save_optimizer: bool = False):
        """Save the DCP checkpoint of the training worker.

        Args:
            checkpoint_path: Checkpoint root; weights go to its ``weights``
                sub-directory.
            no_save_optimizer: When True, optimizer states are not saved.
        """
        if not isinstance(checkpoint_path, Path):
            checkpoint_path = Path(checkpoint_path)
        weights_path = checkpoint_path / self._SAVE_WEIGHTS_DIR

        # Save model and optimizer
        self._engine.save_dcp(
            weights_dir=weights_path,
            save_optimizer=not no_save_optimizer,
        )

        # Save sft dataloader
        if self._sft_dataloader is not None:
            sft_dataloader_path = checkpoint_path / self._SAVE_SFT_DATALOADER_DIR
            # Every rank calls `get_state_dict`; only rank 0 writes to disk below.
            dataloader_state = self._sft_dataloader.get_state_dict()
            total_consumed_samples = dataloader_state["total_consumed_samples"]
            if self.rank != 0:
                return
            torch.save(dataloader_state, sft_dataloader_path)

            train_state_path = checkpoint_path / self._SAVE_SFT_TRAIN_STATE_PATH
            with train_state_path.open("w") as f:
                f.write(
                    json.dumps(
                        {
                            "cur_step": self._rollout_step,
                            "cur_epoch": self._sft_cur_epoch,
                            "total_consumed_samples": total_consumed_samples,
                            "total_consumed_tokens": self._sft_total_consumed_tokens,
                        }
                    )
                )

    @ray_method
    def resume(self, load_checkpoint_cfg: LoadCheckpointConfig):
        """Resume the training worker from the checkpoint.

        Does nothing when ``load_checkpoint_cfg.checkpoint_path`` is None.

        Raises:
            FileNotFoundError: If the checkpoint directory, its weights
                directory, or (when SFT is enabled) the SFT train-state or
                dataloader state is missing.
        """
        resume_from = load_checkpoint_cfg.checkpoint_path
        if resume_from is None:
            return
        if isinstance(resume_from, str):
            resume_from = Path(resume_from)
        self.logger.info(f"Resume from checkpoint: {resume_from}")

        if not resume_from.exists():
            raise FileNotFoundError(f"Checkpoint path {resume_from} does not exist.")

        weights_path = resume_from / self._SAVE_WEIGHTS_DIR
        if not weights_path.exists():
            raise FileNotFoundError(f"Checkpoint at {resume_from} has no '{self._SAVE_WEIGHTS_DIR}/' directory.")

        self._engine.load_dcp(
            weights_dir=weights_path,
            load_states=load_checkpoint_cfg.load_optimizer_states,
            load_args=load_checkpoint_cfg.load_optimizer_args,
        )

        # Resume sft dataloader
        if self._sft_dataloader is not None:
            train_state_path = resume_from / self._SAVE_SFT_TRAIN_STATE_PATH
            if not train_state_path.exists():
                raise FileNotFoundError(f"Train state path {train_state_path} does not exist.")
            with train_state_path.open("r") as f:
                train_state = json.loads(f.read())
            # Restore the counters written by `save`.
            self._rollout_step = train_state["cur_step"]
            self._sft_cur_epoch = train_state["cur_epoch"]
            self._sft_total_consumed_tokens = train_state["total_consumed_tokens"]
            self.logger.info(f"Resume sft train state from {train_state_path}")

            sft_dataloader_path = resume_from / self._SAVE_SFT_DATALOADER_DIR
            if not sft_dataloader_path.exists():
                raise FileNotFoundError(f"Dataloader path {sft_dataloader_path} does not exist.")
            dataloader_state = torch.load(sft_dataloader_path, map_location=DEVICE)
            self._sft_dataloader.load_state_dict(dataloader_state)
            self.logger.info(f"Resume sft dataloader from {sft_dataloader_path}")

    @ray_method
    def ready(self) -> bool:
        """Return True; used by callers to check that the worker is initialised."""
        return True


# Typed handles for the Ray actor class / actor proxy of `TrainingWorker`.
TrainingWorkerClass = ActorClass[TrainingWorker]
TrainingWorkerProxy = ActorProxy[TrainingWorker]