# Source code for xtuner.v1.train.rl_trainer

import asyncio
import json
import math
import os
import random
import re
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from shutil import rmtree
from typing import Any, List, cast

import numpy as np
import ray
import torch
from mmengine.dist import get_rank
from mmengine.runner import set_random_seed
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Literal, TypedDict

from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from xtuner.v1._writer import get_writer
from xtuner.v1.data_proto.rl_data import RolloutState, Status
from xtuner.v1.data_proto.sequence_context import SequenceContext
from xtuner.v1.patch import patch_default_save_plan
from xtuner.v1.rl.advantage import BaseAdvantageConfig, GRPOAdvantageConfig
from xtuner.v1.rl.agent_loop_manager import (
    AgentLoopManagerConfig,
    ProduceBatchResult,
    ProduceBatchStatus,
)
from xtuner.v1.rl.agent_loop_manager.producer import default_should_continue_fn
from xtuner.v1.rl.evaluator import EvaluatorConfig
from xtuner.v1.rl.gateway.config import GatewayConfig
from xtuner.v1.rl.replay_buffer import (
    AsyncReplayBufferConfig,
    SyncReplayBufferConfig,
    _restore_nested_objectrefs,
    _snapshot_nested_objectrefs,
)
from xtuner.v1.rl.rollout.controller import RolloutControllerProxy
from xtuner.v1.rl.rollout.worker import RolloutConfig
from xtuner.v1.rl.trainer.controller import TrainingController
from xtuner.v1.rl.trainer.worker import WorkerConfig, WorkerLogItem
from xtuner.v1.rl.utils import (
    AcceleratorResourcesConfig,
    AutoAcceleratorWorkers,
    CPUResourceManager,
    asyncio_run,
    create_task,
    set_cpu_resource_manager,
    sort_rollout_state_for_deterministic,
)
from xtuner.v1.train.trainer import LoadCheckpointConfig, XTunerMeta
from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger, is_hf_model_path, set_deterministic, timer
from xtuner.v1.utils.device import get_device, get_torch_device_module
from xtuner.v1.utils.env_check import get_rollout_engine_version


# NOTE(review): presumably a Ray placement-group readiness timeout in seconds; its consumer
# is not visible in this chunk — confirm units there.
PG_READY_TIMEOUT = 30
# TODO: Move DEVICE to `xtuner.utils.device`
# Both values are resolved once, at module import time, from xtuner's device helpers.
DEVICE = get_device()
DEVICE_MODULE = get_torch_device_module()


def _to_cpu_tensor(value: np.ndarray | None, *, dtype: torch.dtype | None = None) -> torch.Tensor | None:
    if value is None:
        return None
    assert isinstance(value, np.ndarray), f"Expected np.ndarray, got {type(value)}"
    return torch.as_tensor(value, dtype=dtype, device="cpu")


def check_fa3():
    """Fail fast when flash-attention v3 is requested but cannot be used.

    Does nothing unless the ``XTUNER_USE_FA3`` environment variable is exactly ``"1"``.
    When it is, the varlen flash-attention kernel is resolved through
    ``xtuner.v1.ops.flash_attn.get_flash_attn_varlen``.

    Raises:
        RuntimeError: If importing or resolving the kernel raises ``RuntimeError``. The
            original error is chained as ``__cause__`` so its traceback is preserved.
    """
    if os.environ.get("XTUNER_USE_FA3", "0") != "1":
        return

    try:
        # Imported lazily so the FA3 code path is only touched when explicitly requested.
        from xtuner.v1.ops.flash_attn import get_flash_attn_varlen

        get_flash_attn_varlen()
    except RuntimeError as e:
        # Chain explicitly: the FA3 failure is the direct cause, not an unrelated error.
        raise RuntimeError(
            f"Flash attention v3 runtime error {e}, Please install it first or set XTUNER_USE_FA3=0."
        ) from e


def bind_train_rollout(
    train_controller: TrainingController,
    rollout_controller: RolloutControllerProxy,
) -> None:
    """Pass rollout-side metadata to the training controller ahead of weight updates.

    Calls ``get_rollout_metadata`` remotely on ``rollout_controller`` and blocks on
    ``ray.get`` for the result. The result is then handed to
    ``train_controller.update_rollout_info``.
    """
    metadata_ref = rollout_controller.get_rollout_metadata.remote()  # type: ignore[attr-defined]
    train_controller.update_rollout_info(ray.get(metadata_ref))


def _parse_debug_rollout_step(path: Path) -> int:
    """Extract the step number from a ``debug_rollout_<step>.pt`` file name.

    Only ``path.name`` is inspected, so any parent directories are ignored.

    Raises:
        ValueError: If the file name does not match the pattern exactly.
    """
    found = re.fullmatch(r"debug_rollout_(\d+)\.pt", path.name)
    if not found:
        raise ValueError(f"Unexpected debug rollout file name: {path}")
    (step_text,) = found.groups()
    return int(step_text)


# Typed dict with two optional keys (``total=False``):
#   data_info        - mapping from names to float values
#   workers_log_item - list of ``WorkerLogItem`` entries
TrainInfo = TypedDict(
    "TrainInfo",
    {
        "data_info": dict[str, float],
        "workers_log_item": list[WorkerLogItem],
    },
    total=False,
)


@dataclass(frozen=True)
class RLThroughputBenchmark:
    """Concise, user-facing throughput numbers reported by the RL trainer.

    Only the headline rates live here. Bulky counters and intermediate rates that are
    needed just to derive them belong as locals in ``_compute_benchmark_metrics``.

    Naming: ``sgs`` = samples per GPU per second, ``tgs`` = tokens per GPU per second.

    Attributes:
        e2e_effective_sgs: Run-level samples per train worker/GPU per second. This is the
            cumulative number of training-consumed samples since RL training started,
            divided by the wall time from that start to the current step's log time and
            by the train worker count.
        e2e_effective_tgs: Like ``e2e_effective_sgs``, but counting cumulative
            training-consumed tokens.
        effective_sgs: Current-step samples per train worker/GPU per second. This is the
            number of samples consumed by this training step, divided by the full step
            wall time and by the train worker count. The step wall time covers
            rollout/get, prepare, training, and any sync/save/eval phases inside the
            step timer.
        effective_tgs: Like ``effective_sgs``, but counting tokens consumed by this step.
        training_tgs: Training-only tokens per train worker/GPU per second. This is the
            number of tokens consumed by this step, divided by ``train_controller.fit``
            time and by the train worker count.
        rollout_sgs: Samples per rollout worker/GPU per second. This is the number of
            samples produced by the current rollout window, divided by the producer's
            ``produce_batch`` wall time and by the rollout worker count.
        rollout_tgs: Response tokens per rollout worker/GPU per second. This is the number
            of response tokens produced by the current rollout window, divided by
            ``produce_batch`` wall time and by the rollout worker count.
    """

    e2e_effective_sgs: float
    e2e_effective_tgs: float
    effective_sgs: float
    effective_tgs: float
    training_tgs: float
    rollout_sgs: float
    rollout_tgs: float

    def to_scalars(self) -> dict[str, float]:
        """Return every field keyed as ``throughput/<field name>``, in declaration order."""
        prefixed: dict[str, float] = {}
        for metric_name, metric_value in asdict(self).items():
            prefixed["throughput/" + metric_name] = metric_value
        return prefixed


def get_train_seq_ctx(
    input_ids: torch.LongTensor,
    position_ids: np.ndarray | None = None,
    multimodal_train_info: dict | None = None,
    len_response_ids: int = 0,
):
    """Build a CPU ``SequenceContext`` for a single training sequence.

    Args:
        input_ids: Token ids of the full sequence.
        position_ids: Optional numpy position ids for the prompt. Only 3-D ids are used.
            They are extended by ``len_response_ids`` positions that continue from each
            row's maximum, then installed on the context. Ids of any other rank are
            ignored and the context keeps the ids from ``SequenceContext.from_input_ids``.
        multimodal_train_info: Optional dict. When truthy, its ``"pixel_values"`` and
            ``"image_grid_thw"`` entries are copied onto the context; ``image_grid_thw``
            is converted to a long CPU tensor.
        len_response_ids: Number of response tokens that follow the prompt.
    """
    seq_ctx = SequenceContext.from_input_ids((input_ids,), device="cpu")

    prompt_pos = _to_cpu_tensor(position_ids, dtype=torch.long)
    if prompt_pos is not None and prompt_pos.dim() == 3:
        # Only qwen3vl needs special handling here; other models need no extra processing.
        last_pos = prompt_pos.max(dim=-1).values  # (3,1)
        offsets = torch.arange(1, len_response_ids + 1, device=last_pos.device)
        tail_pos = last_pos.unsqueeze(-1).expand(-1, -1, len_response_ids) + offsets
        full_pos = torch.cat([prompt_pos, tail_pos], dim=-1)
        seq_ctx.position_ids = full_pos  # type: ignore[assignment]
        assert full_pos.size(-1) == input_ids.size(-1)

    if multimodal_train_info:
        seq_ctx.pixel_values = multimodal_train_info.get("pixel_values")
        grid_thw = multimodal_train_info.get("image_grid_thw")
        seq_ctx.image_grid_thw = _to_cpu_tensor(grid_thw, dtype=torch.long)
    return seq_ctx


def is_valid_for_training(group_data_items: list[RolloutState], logger) -> bool:
    """Checks if a group of rollout states is valid for a training step.

    Args:
        group_data_items: A list of RolloutState objects forming one group.
        logger: Logger whose ``warning`` method is called with the reason a group is rejected.

    Returns:
        True if the group is valid, False otherwise. A single bad item rejects the whole
        group. An empty group is reported as valid.

    NOTE: Why this check is needed:
    - For system fault tolerance, this check is performed at rollout / dataflow
      time, but we still do it here to ensure training data integrity.
    - 'filtered'/'failed': These items are fundamentally broken or incomplete and
      should not be used for training.
    - 'aborted': These items represent rollouts that were stopped
      prematurely. Using such partial data could lead the model to learn
      undesirable behaviors (e.g., stopping generation too early).
    - Empty response/response_ids: The model's generated response is the core
      of the training data for RL algorithms like PPO. If the response is
      missing, there is nothing to compute rewards on or to train the model with.
    """
    is_abort = any(item.status == Status.ABORTED for item in group_data_items)
    is_filtered = any(item.status == Status.FILTERED for item in group_data_items)
    is_failed = any(item.status == Status.FAILED for item in group_data_items)
    if is_filtered or is_failed or is_abort:
        logger.warning(
            f"Invalid dataflow group found during training, rollout state filtered: {is_filtered}, failed: {is_failed}, aborted: {is_abort}."
        )
        return False

    for item in group_data_items:
        # `response_ids` is the critical field for token-in-token-out mode, so it is checked first.
        if item.response_ids is None or len(item.response_ids) == 0:
            logger.warning("Invalid dataflow item found during training: empty response_ids and skip this item.")
            return False
        # The response string is needed as judger input.
        if item.response is None or len(item.response) == 0:
            logger.warning("Invalid dataflow item found during training: empty response string and skip this item.")
            return False
    return True


def _validate_sync_intervals(
    sync_weights_interval: int,
    checkpoint_interval: int | None,
    hf_interval: int | None,
    evaluate_step: int | None = None,
    enable_evaluate: bool = False,
) -> None:
    if sync_weights_interval <= 0:
        raise ValueError(f"sync_weights_interval must be positive, got {sync_weights_interval}.")

    for name, interval in (
        ("checkpoint_interval", checkpoint_interval),
        ("hf_interval", hf_interval),
    ):
        if interval is None or interval == -1:
            continue
        if interval <= 0:
            raise ValueError(f"{name} must be positive or -1/None to disable it, got {interval}.")
        if interval % sync_weights_interval != 0:
            raise ValueError(
                f"{name}={interval} must be a multiple of sync_weights_interval={sync_weights_interval}, "
                "because checkpoint/HF saves only run on weight-sync steps."
            )

    if enable_evaluate:
        if evaluate_step is None or evaluate_step <= 0:
            raise ValueError(f"evaluate_step must be positive when evaluation is enabled, got {evaluate_step}.")
        if evaluate_step % sync_weights_interval != 0:
            raise ValueError(
                f"evaluate_step={evaluate_step} must be a multiple of "
                f"sync_weights_interval={sync_weights_interval}, because evaluation only runs on weight-sync steps."
            )


class BaseRLTrainerConfig(BaseModel):
    """Shared configuration for XTuner RL trainers.

    Holds the training, rollout, evaluation, checkpoint and logging fields common to the
    colocated and disaggregated trainer configs. Subclasses only add resource-layout
    fields. Unknown fields are rejected (``extra="forbid"``), and arbitrary (non-pydantic)
    field types are allowed.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

    # Workers, rollout, and data production.
    train_worker_cfg: WorkerConfig
    rollout_config: RolloutConfig
    tokenizer_path: str | Path
    replay_buffer_config: SyncReplayBufferConfig | AsyncReplayBufferConfig = SyncReplayBufferConfig()
    agent_loop_manager_cfg: AgentLoopManagerConfig
    eval_agent_loop_manager_cfg: AgentLoopManagerConfig | None = None
    evaluator_config: EvaluatorConfig | None = None
    load_from: str | Path
    # Training schedule.
    total_train_steps: int | None = None
    total_epochs: int | None = None
    train_batch_size: int
    advantage_estimator_config: BaseAdvantageConfig = Field(default_factory=GRPOAdvantageConfig)
    sync_weights_interval: int = 1
    gateway_config: GatewayConfig | None = None

    # Evaluation, checkpointing, logging, and debugging.
    enable_evaluate: bool = True
    enable_initial_evaluate: bool = False
    evaluate_step: int = 1
    work_dir: Path | str | None = None
    auto_resume: bool = False
    load_checkpoint_cfg: LoadCheckpointConfig = LoadCheckpointConfig()
    checkpoint_interval: int | None = -1
    checkpoint_maxkeep: int | None = -1
    hf_interval: int | None = -1
    hf_max_keep: int | None = -1
    checkpoint_no_save_optimizer: bool = False
    log_dir: Path | str | None = None
    seed: int = 42
    debug_rollout: bool = False
    debug_rollout_dir: Path | str | None = None
    debug_train: bool = False
    skip_checkpoint_validation: bool = False
    exp_tracker: Literal["tensorboard", "jsonl"] = "tensorboard"

    @model_validator(mode="after")
    def _validate_sync_intervals(self):
        """Run cross-field checks after pydantic has validated each field individually."""
        # Debug modes are mutually exclusive, and both require a directory.
        if self.debug_rollout and self.debug_train:
            raise ValueError("debug_rollout and debug_train cannot be enabled at the same time.")
        if self.debug_rollout_dir is None:
            if self.debug_rollout:
                raise ValueError("debug_rollout_dir must be provided when debug_rollout=True.")
            if self.debug_train:
                raise ValueError("debug_rollout_dir must be provided when debug_train=True.")

        # Training length: at least one bound is required (except in debug_train mode),
        # and any bound that is given must be positive.
        has_length = self.total_train_steps is not None or self.total_epochs is not None
        if not (self.debug_train or has_length):
            raise ValueError("Either total_train_steps or total_epochs must be provided.")
        for field_name, bound in (("total_train_steps", self.total_train_steps), ("total_epochs", self.total_epochs)):
            if bound is not None and bound <= 0:
                raise ValueError(f"{field_name} must be positive, got {bound}.")

        # Inside a method body this name resolves to the module-level helper, not this method.
        _validate_sync_intervals(
            sync_weights_interval=self.sync_weights_interval,
            checkpoint_interval=self.checkpoint_interval,
            hf_interval=self.hf_interval,
            evaluate_step=self.evaluate_step,
            enable_evaluate=self.enable_evaluate,
        )
        return self


class RLColocateTrainerConfig(BaseRLTrainerConfig):
    """Configuration for the colocated RL trainer.

    ``RLColocateTrainerConfig`` runs training workers and rollout workers on a shared
    accelerator resource pool. It is typically used when rollout and training alternate
    on the same set of devices.

    Args:
        train_worker_cfg (WorkerConfig): Training worker configuration, including model,
            optimizer, loss, and FSDP settings.
        rollout_config (RolloutConfig): Rollout backend configuration.
        tokenizer_path (str | Path): Tokenizer path used by the agent loop sampler and
            rollout processing.
        replay_buffer_config (SyncReplayBufferConfig | AsyncReplayBufferConfig): Replay
            buffer configuration. Defaults to ``SyncReplayBufferConfig()``.
        agent_loop_manager_cfg (AgentLoopManagerConfig): Agent loop manager configuration
            used for training rollout production.
        eval_agent_loop_manager_cfg (AgentLoopManagerConfig | None): Optional agent loop
            manager for evaluation. Defaults to None.
        evaluator_config (EvaluatorConfig | None): Optional evaluator configuration.
            Defaults to None.
        load_from (str | Path): Initial model path to load. Currently must be a Hugging
            Face model path.
        total_train_steps (int | None): Total number of training steps; must be positive
            if set. Defaults to None.
        total_epochs (int | None): Total number of dataset epochs; must be positive if set.
            At least one of ``total_train_steps``/``total_epochs`` is required unless
            ``debug_train`` is True. Defaults to None.
        train_batch_size (int): Number of rollout samples consumed per training step.
        advantage_estimator_config (BaseAdvantageConfig): Advantage estimator
            configuration. Defaults to ``GRPOAdvantageConfig``.
        sync_weights_interval (int): Interval, in train steps, for syncing weights from
            training to rollout. Must be positive. Defaults to 1.
        gateway_config (GatewayConfig | None): Optional gateway configuration.
            Defaults to None.
        enable_evaluate (bool): Whether to run evaluation. Defaults to True.
        enable_initial_evaluate (bool): Whether to evaluate before training.
            Defaults to False.
        evaluate_step (int): Evaluation interval in train steps. When evaluation is
            enabled it must be a positive multiple of ``sync_weights_interval``.
            Defaults to 1.
        work_dir (Path | str | None): Directory for checkpoints and runtime state.
            ``None`` means ``./work_dirs``. Defaults to None.
        auto_resume (bool): Whether to resume automatically from ``work_dir``.
            Defaults to False.
        load_checkpoint_cfg (LoadCheckpointConfig): Checkpoint loading policy.
            Defaults to ``LoadCheckpointConfig()``.
        checkpoint_interval (int | None): Native checkpoint interval. ``-1`` or ``None``
            disables it; otherwise it must be a positive multiple of
            ``sync_weights_interval``. Defaults to -1.
        checkpoint_maxkeep (int | None): Maximum number of native checkpoints to keep.
            Defaults to -1.
        hf_interval (int | None): Hugging Face checkpoint export interval. ``-1`` or
            ``None`` disables it; otherwise it must be a positive multiple of
            ``sync_weights_interval``. Defaults to -1.
        hf_max_keep (int | None): Maximum number of Hugging Face checkpoints to keep.
            Defaults to -1.
        checkpoint_no_save_optimizer (bool): Whether to skip optimizer states when saving
            checkpoints. Defaults to False.
        log_dir (Path | str | None): Directory for logs. Defaults to None.
        seed (int): Global random seed. Defaults to 42.
        debug_rollout (bool): Whether to enable rollout debugging. Cannot be combined
            with ``debug_train``. Defaults to False.
        debug_rollout_dir (Path | str | None): Directory used by the debug modes; required
            when ``debug_rollout`` or ``debug_train`` is True. Defaults to None.
        debug_train (bool): Whether to enable train debugging. Cannot be combined with
            ``debug_rollout``. Defaults to False.
        skip_checkpoint_validation (bool): Whether to skip checkpoint validation.
            Defaults to False.
        exp_tracker (Literal["tensorboard", "jsonl"]): Experiment tracker type.
            Defaults to "tensorboard".
        resources (AcceleratorResourcesConfig): Shared accelerator resources used by both
            training and rollout workers.

    **Examples:**

    Example colocated trainer configuration::

        config = RLColocateTrainerConfig(
            train_worker_cfg=train_worker_cfg,
            rollout_config=rollout_config,
            tokenizer_path="Qwen/Qwen3-8B",
            agent_loop_manager_cfg=agent_loop_manager_cfg,
            load_from="Qwen/Qwen3-8B",
            total_train_steps=1000,
            train_batch_size=128,
            resources=AcceleratorResourcesConfig(num_workers=8),
        )
    """

    resources: AcceleratorResourcesConfig

    def build(self) -> "RLColocateTrainer":
        """Instantiate the colocated trainer from this configuration."""
        return RLColocateTrainer(self)
class RLDisaggregatedTrainerConfig(BaseRLTrainerConfig):
    """Configuration for the disaggregated RL trainer.

    ``RLDisaggregatedTrainerConfig`` uses separate accelerator resource pools for training
    and rollout. It is typically used when rollout production runs concurrently with
    training on dedicated devices.

    Args:
        train_worker_cfg (WorkerConfig): Training worker configuration, including model,
            optimizer, loss, and FSDP settings.
        rollout_config (RolloutConfig): Rollout backend configuration.
        tokenizer_path (str | Path): Tokenizer path used by the agent loop sampler and
            rollout processing.
        replay_buffer_config (SyncReplayBufferConfig | AsyncReplayBufferConfig): Replay
            buffer configuration. Defaults to ``SyncReplayBufferConfig()``.
        agent_loop_manager_cfg (AgentLoopManagerConfig): Agent loop manager configuration
            used for training rollout production.
        eval_agent_loop_manager_cfg (AgentLoopManagerConfig | None): Optional agent loop
            manager for evaluation. Defaults to None.
        evaluator_config (EvaluatorConfig | None): Optional evaluator configuration.
            Defaults to None.
        load_from (str | Path): Initial model path to load. Currently must be a Hugging
            Face model path.
        total_train_steps (int | None): Total number of training steps; must be positive
            if set. Defaults to None.
        total_epochs (int | None): Total number of dataset epochs; must be positive if set.
            At least one of ``total_train_steps``/``total_epochs`` is required unless
            ``debug_train`` is True. Defaults to None.
        train_batch_size (int): Number of rollout samples consumed per training step.
        advantage_estimator_config (BaseAdvantageConfig): Advantage estimator
            configuration. Defaults to ``GRPOAdvantageConfig``.
        sync_weights_interval (int): Interval, in train steps, for syncing weights from
            training to rollout. Must be positive. Defaults to 1.
        gateway_config (GatewayConfig | None): Optional gateway configuration.
            Defaults to None.
        enable_evaluate (bool): Whether to run evaluation. Defaults to True.
        enable_initial_evaluate (bool): Whether to evaluate before training.
            Defaults to False.
        evaluate_step (int): Evaluation interval in train steps. When evaluation is
            enabled it must be a positive multiple of ``sync_weights_interval``.
            Defaults to 1.
        work_dir (Path | str | None): Directory for checkpoints and runtime state.
            ``None`` means ``./work_dirs``. Defaults to None.
        auto_resume (bool): Whether to resume automatically from ``work_dir``.
            Defaults to False.
        load_checkpoint_cfg (LoadCheckpointConfig): Checkpoint loading policy.
            Defaults to ``LoadCheckpointConfig()``.
        checkpoint_interval (int | None): Native checkpoint interval. ``-1`` or ``None``
            disables it; otherwise it must be a positive multiple of
            ``sync_weights_interval``. Defaults to -1.
        checkpoint_maxkeep (int | None): Maximum number of native checkpoints to keep.
            Defaults to -1.
        hf_interval (int | None): Hugging Face checkpoint export interval. ``-1`` or
            ``None`` disables it; otherwise it must be a positive multiple of
            ``sync_weights_interval``. Defaults to -1.
        hf_max_keep (int | None): Maximum number of Hugging Face checkpoints to keep.
            Defaults to -1.
        checkpoint_no_save_optimizer (bool): Whether to skip optimizer states when saving
            checkpoints. Defaults to False.
        log_dir (Path | str | None): Directory for logs. Defaults to None.
        seed (int): Global random seed. Defaults to 42.
        debug_rollout (bool): Whether to enable rollout debugging. Cannot be combined
            with ``debug_train``. Defaults to False.
        debug_rollout_dir (Path | str | None): Directory used by the debug modes; required
            when ``debug_rollout`` or ``debug_train`` is True. Defaults to None.
        debug_train (bool): Whether to enable train debugging. Cannot be combined with
            ``debug_rollout``. Defaults to False.
        skip_checkpoint_validation (bool): Whether to skip checkpoint validation.
            Defaults to False.
        exp_tracker (Literal["tensorboard", "jsonl"]): Experiment tracker type.
            Defaults to "tensorboard".
        train_resources (AcceleratorResourcesConfig): Accelerator resources for training
            workers.
        rollout_resources (AcceleratorResourcesConfig): Accelerator resources for rollout
            workers.

    **Examples:**

    Example disaggregated trainer configuration::

        config = RLDisaggregatedTrainerConfig(
            train_worker_cfg=train_worker_cfg,
            rollout_config=rollout_config,
            tokenizer_path="Qwen/Qwen3-8B",
            agent_loop_manager_cfg=agent_loop_manager_cfg,
            load_from="Qwen/Qwen3-8B",
            total_train_steps=1000,
            train_batch_size=128,
            train_resources=AcceleratorResourcesConfig(num_workers=4),
            rollout_resources=AcceleratorResourcesConfig(num_workers=4),
        )
    """

    train_resources: AcceleratorResourcesConfig
    rollout_resources: AcceleratorResourcesConfig

    def build(self) -> "RLDisaggregatedTrainer":
        """Instantiate the disaggregated trainer from this configuration."""
        return RLDisaggregatedTrainer(self)
class BaseRLTrainer:
    """Shared machinery for the RL trainers.

    Provides the common initialization steps (work dir / meta, logging, train state,
    train-worker and rollout configs), checkpoint and Hugging Face export, evaluation,
    training-data preparation and per-step logging. Subclasses build the controllers
    and own their main loop and weight-sync flow.
    """

    _EXP_TRACKING_PATH = "exp_tracking"
    _CHECKPOINT_DIR = "checkpoints"
    _HF_DIR = "hf"
    _SAVE_TRAIN_STATE_PATH = "train_state.json"

    train_controller: TrainingController
    rollout_controller: RolloutControllerProxy
    _debug_train_files: dict[int, Path]

    def _init_common(self, cfg: BaseRLTrainerConfig, *, meta_path: str, logger_tag: str) -> None:
        """Run the initialization steps shared by every RL trainer.

        Args:
            cfg: Trainer configuration. ``cfg.train_worker_cfg`` and ``cfg.rollout_config``
                are updated in place (seed, load path, log dirs, concurrency ratio).
            meta_path: File name of the XTuner meta file inside the work dir.
            logger_tag: Tag attached to the trainer logger.
        """
        check_fa3()
        self._init_work_dir_and_meta(cfg, meta_path)
        self._init_load_source(cfg)
        self._init_save_config(cfg)
        log_dir = self._init_logger(cfg, logger_tag)
        self._save_runtime_environment(log_dir)
        self._init_train_state(cfg)
        self._init_train_worker_config(cfg, log_dir)
        self._init_rollout_config(cfg, log_dir)
        self._ensure_rollout_http_concurrency(cfg)
        self._init_runtime_flags(cfg)
        self._advantage_estimator = cfg.advantage_estimator_config.build()
        self._cpu_resource_manager: CPUResourceManager | None = None
        # Worker counts used to normalize throughput metrics; subclasses overwrite them.
        self._num_workers = 1.0
        self._rollout_num_workers = 1.0
        self._benchmark_start_time_s: float | None = None
        self._benchmark_training_samples: int = 0
        self._benchmark_training_tokens: int = 0
        self._exp_tracker = get_writer(writer_type=cfg.exp_tracker, log_dir=log_dir / self._EXP_TRACKING_PATH)
        self._display_all_workers_log = False

    def _init_work_dir_and_meta(self, cfg: BaseRLTrainerConfig, meta_path: str) -> None:
        """Create the work dir (rank 0 only) and load or create the XTuner meta."""
        work_dir = Path(cfg.work_dir) if cfg.work_dir else Path.cwd() / "work_dirs"
        if get_rank() == 0:
            work_dir.mkdir(parents=True, exist_ok=True)
        self._meta = XTunerMeta.build(work_dir, meta_path, cfg.auto_resume)
        self._meta_path = meta_path

    def _init_load_source(self, cfg: BaseRLTrainerConfig) -> None:
        """Record ``load_from`` and require it to be a Hugging Face model path.

        Raises:
            NotImplementedError: If ``load_from`` is unset or is not a Hugging Face
                model path (only HF sources are supported).
        """
        self._load_from = Path(cfg.load_from) if isinstance(cfg.load_from, str) else cfg.load_from
        if cfg.load_from is None:
            # Previously this raised NotImplementedError("") with no hint at the cause.
            is_hf_path, error_info = False, "`load_from` must be set to a Huggingface model path."
        else:
            is_hf_path, error_info = is_hf_model_path(cfg.load_from)
        self._load_from_hf = is_hf_path
        if not self._load_from_hf:
            raise NotImplementedError(error_info)

    def _init_save_config(self, cfg: BaseRLTrainerConfig) -> None:
        """Copy the checkpoint / HF-export policy and resolve the checkpoint to load."""
        self._hf_max_keep = cfg.hf_max_keep
        self._hf_interval = cfg.hf_interval
        self._checkpoint_interval = cfg.checkpoint_interval
        self._checkpoint_maxkeep = cfg.checkpoint_maxkeep
        self._checkpoint_no_save_optimizer = cfg.checkpoint_no_save_optimizer
        self._load_checkpoint_cfg = self._resolve_load_checkpoint_cfg(cfg.auto_resume, cfg.load_checkpoint_cfg)

    def _init_logger(self, cfg: BaseRLTrainerConfig, logger_tag: str) -> Path:
        """Create ``<exp_dir>/logs``, set up ``self.logger`` and return the log dir."""
        log_dir = self.exp_dir / "logs"
        log_dir.mkdir(parents=True, exist_ok=True)
        self.logger = get_logger(log_dir=log_dir, tag=logger_tag)
        if cfg.skip_checkpoint_validation:
            patch_default_save_plan()
        return log_dir

    def _save_runtime_environment(self, log_dir: Path) -> None:
        """On rank 0, dump environment variables plus rollout engine versions to env.json."""
        if get_rank() != 0:
            return
        env_path = log_dir / "env.json"
        environment_variables = dict(os.environ)
        infer_engine_version = get_rollout_engine_version()
        environment_variables.update(infer_engine_version)
        with env_path.open("w") as f:
            json.dump(environment_variables, f, indent=2)

    def _init_train_state(self, cfg: BaseRLTrainerConfig) -> None:
        """Initialize step counters and batch settings, and seed the RNGs."""
        self._total_train_steps = cfg.total_train_steps or 0
        self._total_epochs = cfg.total_epochs
        self._cur_step = 0
        self._global_train_step = 0
        self._seed = cfg.seed
        self.train_batch_size = cfg.train_batch_size
        self._sync_weights_interval = cfg.sync_weights_interval
        if XTUNER_DETERMINISTIC:
            set_deterministic()
        set_random_seed(cfg.seed)

    def _init_train_worker_config(self, cfg: BaseRLTrainerConfig, log_dir: Path) -> None:
        """Fill seed, load path and log dir into the train worker config."""
        if cfg.train_worker_cfg.seed is None:
            self.logger.warning(f"RLTrainer seed {cfg.seed} is used as train worker seed.")
            cfg.train_worker_cfg.seed = cfg.seed
        cfg.train_worker_cfg.load_from = cfg.load_from
        cfg.train_worker_cfg.log_dir = log_dir
        self._train_worker_cfg = cfg.train_worker_cfg

    def _init_rollout_config(self, cfg: BaseRLTrainerConfig, log_dir: Path) -> None:
        """Set the rollout log dir.

        Rollout weight loading is skipped when resuming, since weights then come from
        the checkpoint.
        """
        cfg.rollout_config.worker_log_dir = log_dir
        if self._load_checkpoint_cfg.checkpoint_path is not None:
            cfg.rollout_config.skip_load_weights = True
            self.logger.info(
                f"Skip load rollout weights due to resume from checkpoint {self._load_checkpoint_cfg.checkpoint_path}"
            )
        self._rollout_config = cfg.rollout_config

    def _ensure_rollout_http_concurrency(self, cfg: BaseRLTrainerConfig) -> None:
        """Raise ``allow_over_concurrency_ratio`` if the per-worker budget is too small.

        The expected in-flight request count is the sum over tasks of
        ``train_batch_size * task_share * prompt_repeat_k * (1 + over_sample_threshold)``,
        spread over the active rollout workers. If
        ``ceil(rollout_max_batch_size_per_instance * allow_over_concurrency_ratio)`` is
        below that per-worker requirement, the ratio is raised to cover it. Per the
        warning text, this value bounds httpx ``max_connections``.
        """
        rollout_max_batch_size = cfg.rollout_config.rollout_max_batch_size_per_instance
        if rollout_max_batch_size is None or rollout_max_batch_size <= 0:
            return

        if isinstance(cfg, RLDisaggregatedTrainerConfig):
            rollout_worker_count = cfg.rollout_resources.num_workers
        elif isinstance(cfg, RLColocateTrainerConfig):
            rollout_worker_count = cfg.resources.num_workers
        else:
            rollout_worker_count = 1
        active_rollout_worker_count, _ = cfg.rollout_config.get_active_servers_count(rollout_worker_count)
        if active_rollout_worker_count <= 0:
            return

        tasks = cfg.agent_loop_manager_cfg.tasks
        task_cfgs = tasks if isinstance(tasks, list) else [tasks]
        total_weight = sum(task.weight for task in task_cfgs)
        if total_weight <= 0:
            return

        scheduled_http_requests = 0.0
        for task in task_cfgs:
            task_batch_size = cfg.train_batch_size * task.weight / total_weight
            over_sample_threshold = float(getattr(task.produce_strategy_config, "over_sample_threshold", 0.0))
            scheduled_http_requests += (
                task_batch_size * task.sampler_config.prompt_repeat_k * (1 + over_sample_threshold)
            )

        required_http_concurrency = math.ceil(scheduled_http_requests / active_rollout_worker_count)
        current_http_concurrency = math.ceil(rollout_max_batch_size * cfg.rollout_config.allow_over_concurrency_ratio)
        if current_http_concurrency >= required_http_concurrency:
            return

        new_ratio = required_http_concurrency / rollout_max_batch_size
        cfg.rollout_config.allow_over_concurrency_ratio = new_ratio
        self.logger.warning(
            "Increasing rollout_config.allow_over_concurrency_ratio because httpx max_connections is smaller "
            "than the expected per-worker rollout request concurrency: "
            f"max_connections={current_http_concurrency}, "
            f"required_connections={required_http_concurrency}"
        )

    def _init_runtime_flags(self, cfg: BaseRLTrainerConfig) -> None:
        """Copy the evaluation and debug switches from the config."""
        self._enable_evaluate = cfg.enable_evaluate
        # Initial evaluation only makes sense when evaluation is enabled at all.
        self._enable_initial_evaluate = cfg.enable_initial_evaluate and cfg.enable_evaluate
        self._evaluate_step = cfg.evaluate_step
        self._debug_rollout = cfg.debug_rollout
        self._debug_rollout_dir = Path(cfg.debug_rollout_dir) if cfg.debug_rollout_dir is not None else None
        self._debug_train = cfg.debug_train
        self._debug_train_files: dict[int, Path] = {}

    def _maybe_start_gateway(self, cfg: BaseRLTrainerConfig) -> None:
        """Start the gateway when configured with ``auto_start``."""
        if cfg.gateway_config is None or not cfg.gateway_config.auto_start:
            return
        # The gateway depends on the rollout controller, so it is started only after
        # the rollout controller has been built.
        ray.get(self.rollout_controller.start_gateway.remote(cfg.gateway_config))

    def _build_agent_loop_components(self, cfg: BaseRLTrainerConfig, replay_buffer) -> None:
        """Build the tokenizer, the train/eval agent loop managers and the evaluator.

        Then resolve ``total_train_steps``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_path, trust_remote_code=True)
        self.agent_loop_manager = cfg.agent_loop_manager_cfg.build(
            rollout_controller=self.rollout_controller,
            tokenizer=self.tokenizer,
            replay_buffer=replay_buffer,
            logger=self.logger,
            sync_weights_interval=cfg.sync_weights_interval,
        )
        if self._enable_evaluate:
            assert cfg.eval_agent_loop_manager_cfg is not None
            self.eval_agent_loop_manager = cfg.eval_agent_loop_manager_cfg.build(
                rollout_controller=self.rollout_controller,
                tokenizer=self.tokenizer,
                replay_buffer=replay_buffer,
                logger=self.logger,
                sync_weights_interval=cfg.sync_weights_interval,
            )
            total_eval_samples = len(self.eval_agent_loop_manager.data_sampler)
            assert cfg.evaluator_config is not None
            self.evaluator = cfg.evaluator_config.build(total_eval_samples=total_eval_samples)
        self._resolve_total_train_steps(cfg)

    def _resolve_total_train_steps(self, cfg: BaseRLTrainerConfig) -> None:
        """Set ``_total_train_steps`` from the config.

        When only ``total_epochs`` is given, use
        ``(dataset_size // train_batch_size) * total_epochs``.
        """
        if cfg.total_train_steps is not None:
            self._total_train_steps = cfg.total_train_steps
            return

        assert cfg.total_epochs is not None
        dataset_size = len(self.agent_loop_manager.data_sampler)
        self._total_train_steps = dataset_size // cfg.train_batch_size * cfg.total_epochs
        self.logger.info(
            "Resolved total_train_steps from total_epochs: "
            f"dataset_size={dataset_size}, train_batch_size={cfg.train_batch_size}, "
            f"total_epochs={cfg.total_epochs}, total_train_steps={self._total_train_steps}"
        )

    @property
    def exp_dir(self) -> Path:
        """Directory of the latest experiment recorded in the meta."""
        return Path(self._meta.latest_exp.exp_dir)

    def _resolve_load_checkpoint_cfg(
        self, auto_resume: bool, load_checkpoint_cfg: LoadCheckpointConfig
    ) -> LoadCheckpointConfig:
        """Resolve checkpoint path for auto-resume.

        With ``auto_resume``, point ``load_checkpoint_cfg`` at the latest checkpoint
        (mutated in place).
        """
        latest_checkpoint = self._meta.latest_exp.latest_checkpoint
        if latest_checkpoint is not None and auto_resume:
            load_checkpoint_cfg.checkpoint_path = Path(latest_checkpoint)
        return load_checkpoint_cfg

    def _resume_train_controller_and_state(self, checkpoint_path: Path | str) -> Path:
        """Restore the train workers and ``cur_step`` from ``checkpoint_path``."""
        # Subclasses share only train-worker and train_state restoration here; each
        # maintains its own weight-sync flow.
        self.logger.info(f"Resume train controller and state from {checkpoint_path}")
        checkpoint_path = Path(checkpoint_path)
        self.train_controller.resume(self._load_checkpoint_cfg)
        train_state_path = checkpoint_path / self._SAVE_TRAIN_STATE_PATH
        with train_state_path.open("r") as f:
            train_state = json.load(f)
        self._cur_step = train_state["cur_step"]
        return checkpoint_path

    async def _resume_agent_loop_manager(self, checkpoint_path: Path | str) -> int:
        """Restore the agent loop manager and return the saved model step."""
        self.logger.info(f"Resume agent_loop_manager from {checkpoint_path}")
        checkpoint_path = Path(checkpoint_path)
        # asyncio_run may appear only at the trainer's synchronous boundaries:
        # - colocate's __init__/fit/_sync_weights_and_save are still synchronous entry
        #   points and may wrap calls explicitly;
        # - disaggregated _fit already runs inside the loop started by asyncio_run, so
        #   everything inside it must be awaited.
        # Therefore agent_loop_manager / replay_buffer save/resume must stay async; if
        # they called asyncio_run internally, save/resume would hit a nested
        # asyncio_run failure inside the disaggregated training loop.
        saved_model_step = await self.agent_loop_manager.resume(checkpoint_path)
        return saved_model_step

    async def _maybe_save_checkpoint(self, cur_step: int) -> None:
        """Save checkpoint if interval condition is met.

        Writes sampler state, DCP model/optimizer state and train_state.json. It then
        records the checkpoint in the meta, prunes beyond ``checkpoint_maxkeep`` and
        persists the meta. A missing or non-positive interval disables saving.
        """
        ckp_interval = self._checkpoint_interval
        # Non-positive intervals (including the -1 default) mean "disabled"; 0 used to
        # raise ZeroDivisionError.
        if ckp_interval is None or ckp_interval <= 0:
            return
        if cur_step % ckp_interval != 0:
            return

        checkpoint_path = self.exp_dir / self._CHECKPOINT_DIR / f"ckpt-step-{cur_step}"
        checkpoint_path.mkdir(parents=True, exist_ok=True)

        # 1. Save sampler (dataloader) state
        self.logger.info(f"Saving sampler state to {checkpoint_path}")
        # Keep the manager checkpoint call chain async; only the trainer's outermost
        # synchronous entry point decides whether to use asyncio_run.
        await self.agent_loop_manager.save(checkpoint_path, model_step=cur_step)

        # 2. Save DCP checkpoint (model + optimizer)
        self.logger.info(f"Saving DCP checkpoint to {checkpoint_path}")
        self.train_controller.save(str(checkpoint_path), self._checkpoint_no_save_optimizer)

        # 3. Save train state JSON
        train_state_path = checkpoint_path / self._SAVE_TRAIN_STATE_PATH
        with train_state_path.open("w") as f:
            json.dump({"cur_step": cur_step}, f)

        # 4. Update meta
        current_exp = self._meta.latest_exp
        current_exp.checkpoint_list.append(str(checkpoint_path))

        # 5. Prune old checkpoints
        ckp_maxkeep = self._checkpoint_maxkeep
        ckp_list = current_exp.checkpoint_list
        if ckp_maxkeep is not None and ckp_maxkeep > 0 and len(ckp_list) > ckp_maxkeep:
            for deleted in ckp_list[:-ckp_maxkeep]:
                if Path(deleted).exists():
                    rmtree(deleted, ignore_errors=True)
            current_exp.checkpoint_list = ckp_list[-ckp_maxkeep:]

        # 6. Persist meta to disk
        meta_path = self.exp_dir.parent / self._meta_path
        with meta_path.open("w") as f:
            f.write(self._meta.model_dump_json(indent=2))

    def _maybe_save_hf(self, cur_step: int):
        """Export a Hugging Face checkpoint (plus tokenizer).

        Exports every ``hf_interval`` steps and on the final step. Older exports
        beyond ``hf_max_keep`` are pruned. The meta's ``hf_checkpoint_list`` is updated
        in memory only; this method does not write the meta file. A missing or
        non-positive interval disables export.

        Raises:
            RuntimeError: If the trainer was not loaded from a Hugging Face path.
        """
        # Non-positive intervals (including the -1 default) mean "disabled"; 0 used to
        # raise ZeroDivisionError.
        if self._hf_interval is None or self._hf_interval <= 0:
            return

        if not self._load_from_hf:
            raise RuntimeError(
                "Only support saving to Huggingface format when loading from Huggingface! "
                "You meet this error means `load_from` of trainer is not a Huggingface model path."
            )

        if cur_step % self._hf_interval != 0 and cur_step != self._total_train_steps:
            return

        save_hf_path = self.exp_dir / self._HF_DIR / f"hf-step-{cur_step}"
        save_hf_path.mkdir(parents=True, exist_ok=True)

        # update meta
        current_exp = self._meta.latest_exp
        current_exp.hf_checkpoint_list.append(str(save_hf_path))

        # save hf
        self.logger.info(f"Saving Huggingface checkpoint to {save_hf_path}")
        hf_list = self._meta.latest_exp.hf_checkpoint_list
        if self._hf_max_keep is not None and self._hf_max_keep > 0 and len(hf_list) > self._hf_max_keep:
            for deleted in hf_list[: -self._hf_max_keep]:
                if Path(deleted).exists():
                    rmtree(deleted, ignore_errors=True)
            current_exp.hf_checkpoint_list = hf_list[-self._hf_max_keep :]
        self.train_controller.save_hf(str(save_hf_path))

        # save tokenizer
        if isinstance(self.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
            self.tokenizer.save_pretrained(str(save_hf_path))

    async def _produce_eval_rollout_states(self) -> list[list[RolloutState]]:
        """Produce one evaluation batch; sort it when deterministic mode is enabled."""
        # NOTE(review): periodic evaluation also passes train_step=1 / model_step=0
        # (copied from the initial evaluation) — confirm the eval manager ignores these.
        eval_produce_result = await self.eval_agent_loop_manager.produce_batch(
            self.evaluator.eval_batch_size,
            train_step=1,
            model_step=0,
        )
        if XTUNER_DETERMINISTIC:
            eval_produce_result.rollout_states = sort_rollout_state_for_deterministic(
                eval_produce_result.rollout_states
            )
        return eval_produce_result.rollout_states

    async def _run_initial_evaluate(self) -> None:
        """Evaluate before training and log the scores at global step 0."""
        eval_batch = await self._produce_eval_rollout_states()
        eval_metrics = self.evaluator.run(eval_batch)
        self.logger.info(f"Initial rollout evaluate scores {eval_metrics} and start training")
        tb_scores = {f"eval/{k}": v for k, v in eval_metrics.items()}
        self._exp_tracker.add_scalars(tag_scalar_dict=tb_scores, global_step=0)

    def _train_one_batch(
        self,
        train_batch: list[list[RolloutState]],
        train_step: int,
        step_timer_dict: dict,
        *,
        offload_rollout_before_train: bool = False,
        onload_train_before_train: bool = False,
        raw_rewards_sum: float = 0.0,
        raw_rewards_count: int = 0,
    ) -> TrainInfo:
        """Dump the batch's trajectories, prepare training data and run one fit step.

        Returns:
            ``{"data_info": ..., "workers_log_item": ...}`` for ``_log_step``.
        """
        train_sample_count = sum(len(group) for group in train_batch)
        self.logger.info(f"generate {train_sample_count} samples for training")
        train_trajectory_dir = self.exp_dir / "train_rollout"
        train_trajectory_dir.mkdir(parents=True, exist_ok=True)
        train_trajectory_path = train_trajectory_dir / f"train_rollout_{train_step}.jsonl"
        self._save_trajectories(train_batch, train_trajectory_path)
        self.logger.info(f"Train step {train_step} train trajectories saved to {train_trajectory_path}")

        # Colocated mode must release rollout first and then onload the train workers;
        # the non-colocated path performs neither action.
        if offload_rollout_before_train:
            ray.get(self.rollout_controller.offload.remote())
        if onload_train_before_train:
            with timer("onload", step_timer_dict):
                self.train_controller.onload(target="all")
                self.logger.info("Training controller loaded")

        with timer("prepare_data", step_timer_dict):
            data_batches, data_info = self._prepare_train_data(
                train_batch,
                self._train_worker_cfg.pack_max_length,
                raw_rewards_sum=raw_rewards_sum,
                raw_rewards_count=raw_rewards_count,
            )
        self.logger.info(f"Prepared {len(data_batches)} training data batches")

        with timer("training", step_timer_dict):
            workers_log_item: list[WorkerLogItem] = self.train_controller.fit(
                data_batches,
                pack_max_length=self._train_worker_cfg.pack_max_length,
                rollout_idx=train_step,
            )
        return {
            "data_info": data_info,
            "workers_log_item": workers_log_item,
        }

    async def _run_evaluation(self, train_step: int) -> dict[str, float]:
        """Run one evaluation pass, dump its trajectories and return the metrics."""
        eval_batch = await self._produce_eval_rollout_states()
        eval_metrics = self.evaluator.run(eval_batch)
        eval_trajectory_dir = self.exp_dir / "eval_rollout"
        eval_trajectory_dir.mkdir(parents=True, exist_ok=True)
        eval_trajectory_path = eval_trajectory_dir / f"eval_rollout_{train_step}.jsonl"
        self._save_trajectories(eval_batch, eval_trajectory_path)
        self.logger.info(f"Train step {train_step} eval trajectories saved to {eval_trajectory_path}")
        return eval_metrics

    def _save_debug_rollout_batch(self, train_batch: list[list[RolloutState]], train_step: int) -> None:
        """Save ``train_batch`` with ObjectRefs snapshotted, for later debug training."""
        assert self._debug_rollout_dir is not None
        self._debug_rollout_dir.mkdir(parents=True, exist_ok=True)
        save_path = self._debug_rollout_dir / f"debug_rollout_{train_step}.pt"
        serializable_batch = [
            [cast(RolloutState, _snapshot_nested_objectrefs(rollout_state)) for rollout_state in group]
            for group in train_batch
        ]
        torch.save(serializable_batch, save_path)
        self.logger.info(f"Debug rollout batch for step {train_step} saved to {save_path}")

    def _list_debug_rollout_files(self, debug_rollout_dir: Path) -> dict[int, Path]:
        """Map train step -> ``debug_rollout_<step>.pt`` file found in ``debug_rollout_dir``.

        Raises:
            FileNotFoundError: If no debug rollout file exists.
        """
        debug_files = {
            _parse_debug_rollout_step(path): path
            for path in sorted(debug_rollout_dir.glob("debug_rollout_*.pt"), key=_parse_debug_rollout_step)
        }
        if not debug_files:
            raise FileNotFoundError(f"No debug rollout files found in {debug_rollout_dir}")
        return debug_files

    def _load_debug_rollout_batch(self, train_step: int) -> list[list[RolloutState]]:
        """Load the debug rollout batch saved for ``train_step``.

        Raises:
            FileNotFoundError: If no file was registered for ``train_step``.
        """
        debug_file = self._debug_train_files.get(train_step)
        if debug_file is None:
            raise FileNotFoundError(f"No debug rollout file found for train step {train_step}")
        # weights_only=False unpickles arbitrary objects; only load debug dirs you trust
        # (normally the ones written by _save_debug_rollout_batch).
        train_batch = torch.load(debug_file, map_location="cpu", weights_only=False)
        train_batch = [
            [cast(RolloutState, _restore_nested_objectrefs(rollout_state)) for rollout_state in group]
            for group in train_batch
        ]
        self.logger.info(f"Loaded debug rollout batch for step {train_step} from {debug_file}")
        return cast(list[list[RolloutState]], train_batch)

    # TODO: simplify with Packer.pack_pad_dispatch()
    def _prepare_train_data(
        self,
        data_groups: list[list[RolloutState]],
        pack_max_length: int,
        raw_rewards_sum: float = 0.0,
        raw_rewards_count: int = 0,
    ):
        """Turn rollout groups into per-sample training dicts plus summary statistics.

        For each valid group, advantages are computed from the group's rewards. Each
        sample then becomes ``{"seq_ctx", "shifted_labels", "advantage",
        "rollout_logprobs"}``, where the model input is ``prompt_ids + response_ids[:-1]``
        and prompt / masked-out response positions are labelled -100.

        Returns:
            ``(data_batches, info_dict)``. ``data_batches`` is shuffled unless
            ``XTUNER_DETERMINISTIC`` is set.
        """
        rewards_list = []
        advantages_list = []
        prompt_len_list = []
        response_len_list = []
        training_tokens = 0
        data_batches = []
        for group in data_groups:
            if not is_valid_for_training(group, self.logger):
                self.logger.error(f"Skip one data group {group} due to rollout failed or empty response.")
                continue

            is_vlm_model = "train_prompt_ids" in group[0].extra_fields
            if is_vlm_model:
                # TODO(hha): VLM special case; a poor design that should be removed later.
                prompt_ids = group[0].extra_fields["train_prompt_ids"]
            else:
                prompt_ids = group[0].prompt_ids
            assert prompt_ids is not None and len(prompt_ids) > 0, (
                f"Prompt ids cannot be None or empty in data: {group[0]}"
            )

            rewards = []
            for data in group:
                assert data.reward is not None and "score" in data.reward, (
                    f"Reward is missing or does not contain 'score' key in data: {data}"
                )
                rewards.append(data.reward["score"])
            rewards_list.extend(rewards)
            rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
            advantages = self._advantage_estimator.compute(rewards_tensor, group)

            for i, sample in enumerate(group):
                logprobs: list[float] | None = None
                response_ids: List[int] = []
                if sample.response_ids is not None:
                    resp_ids_raw = sample.response_ids
                    if isinstance(resp_ids_raw, torch.Tensor):
                        response_ids = resp_ids_raw.flatten().tolist()
                    else:
                        response_ids = cast(List[int], resp_ids_raw)
                    logprobs = sample.logprobs
                    if logprobs is not None:
                        assert len(logprobs) == len(response_ids), (
                            f"{len(logprobs)} vs {len(response_ids)}, data: {sample}"
                        )
                        # Only response tokens carry logprobs, so pad the prompt part in
                        # front (minus one position for the label shift).
                        logprobs = [0.0] * (len(prompt_ids) - 1) + logprobs  # type: ignore[arg-type]
                else:
                    assert sample.response is not None, "response item cannot be None"
                    response_ids = self.tokenizer(sample.response, return_tensors="pt")["input_ids"].flatten().tolist()

                # The returned routed_experts has no entry for eos (which is not needed
                # anyway), so the last response token is dropped from the input.
                # TODO: is this also needed for the verl tool agent loop?
                input_ids = prompt_ids + response_ids[:-1]

                prompt_len_list.append(len(prompt_ids))
                response_len_list.append(len(response_ids))

                # Build the response part of shifted_labels from response_mask.
                if not sample.response_mask:
                    response_mask = [1] * len(response_ids)
                    response_labels = response_ids
                else:
                    assert len(sample.response_mask) == len(response_ids), (  # type: ignore[arg-type]
                        f"{len(sample.response_mask)} vs {len(response_ids)}"  # type: ignore[arg-type]
                    )
                    response_mask = cast(list[int], sample.response_mask)
                    response_labels = [
                        response_id if mask_id != 0 else -100
                        for response_id, mask_id in zip(response_ids, response_mask)
                    ]

                shifted_labels = [-100] * (len(prompt_ids) - 1) + response_labels
                shifted_labels_t = torch.tensor(shifted_labels, dtype=torch.int64).unsqueeze(0)

                # Per-token advantages: masked-out response tokens get 0.
                advantage_val = advantages[i].item()
                actual_advantages = [advantage_val] * len(prompt_ids) + [
                    0.0 if mask == 0 else advantage_val for mask in response_mask
                ]
                advantages_list.extend(actual_advantages[:-1])

                assert len(input_ids) <= pack_max_length, f"{len(input_ids)} vs {pack_max_length}"
                training_tokens += len(input_ids)
                input_ids_t = torch.tensor(input_ids, dtype=torch.int64).unsqueeze(0)

                if logprobs is not None:
                    rollout_logprobs = torch.tensor(logprobs, dtype=torch.float32).unsqueeze(0)
                    assert rollout_logprobs.size() == shifted_labels_t.size(), (
                        f"{rollout_logprobs.size()} vs {shifted_labels_t.size()}"
                    )
                else:
                    rollout_logprobs = None

                position_ids = sample.position_ids
                multimodal_train_info = sample.mm_info
                multi_info_cast = cast(dict | None, multimodal_train_info)
                seq_ctx = get_train_seq_ctx(input_ids_t, position_ids, multi_info_cast, len(response_ids) - 1)  # type: ignore[arg-type]
                data_dict = {
                    "seq_ctx": seq_ctx,
                    "shifted_labels": shifted_labels_t,
                    "advantage": actual_advantages,
                    "rollout_logprobs": rollout_logprobs,
                }
                seq_ctx.rollout_routed_experts = sample.routed_experts  # n,layer*expert
                data_batches.append(data_dict)

        if not XTUNER_DETERMINISTIC:
            random.shuffle(data_batches)

        rewards_t = torch.tensor(rewards_list).float() if rewards_list else torch.tensor([0.0]).float()
        advantages_t = torch.tensor(advantages_list).float() if advantages_list else torch.tensor([0.0]).float()
        prompt_len_t = torch.tensor(prompt_len_list).float() if prompt_len_list else torch.tensor([0.0]).float()
        response_len_t = torch.tensor(response_len_list).float() if response_len_list else torch.tensor([0.0]).float()
        raw_rewards_mean = raw_rewards_sum / raw_rewards_count if raw_rewards_count > 0 else rewards_t.mean().item()

        info_dict = {
            "batch_size": len(rewards_list),
            "training_samples": len(rewards_list),
            "training_tokens": training_tokens,
            "rewards/mean": rewards_t.mean().item(),
            "rewards/min": rewards_t.min().item(),
            "rewards/max": rewards_t.max().item(),
            "raw_rewards/mean": raw_rewards_mean,
            "advantages/mean": advantages_t.mean().item(),
            "advantages/min": advantages_t.min().item(),
            "advantages/max": advantages_t.max().item(),
            "response_len/mean": response_len_t.mean().item(),
            "response_len/min": response_len_t.min().item(),
            "response_len/max": response_len_t.max().item(),
            "response_len/std": response_len_t.std().item(),
            "prompt_len/mean": prompt_len_t.mean().item(),
            "prompt_len/min": prompt_len_t.min().item(),
            "prompt_len/max": prompt_len_t.max().item(),
        }
        return data_batches, info_dict

    def _compute_benchmark_metrics(
        self,
        data_info: dict[str, float],
        produce_result: ProduceBatchResult,
        step_timer_dict: dict,
        benchmark_end_time_s: float,
    ) -> RLThroughputBenchmark | None:
        """Compute per-GPU throughput for the step.

        Returns:
            The metrics, or None when the benchmark has not started or the step /
            training timings are missing or non-positive.
        """
        benchmark_start_time_s = self._benchmark_start_time_s
        if benchmark_start_time_s is None:
            return None

        e2e_s = benchmark_end_time_s - benchmark_start_time_s
        step_s = step_timer_dict.get("step")
        training_s = step_timer_dict.get("training")
        rollout_s = produce_result.produce_time_s or step_timer_dict.get("produce_batch")
        if e2e_s <= 0 or step_s is None or step_s <= 0 or training_s is None or training_s <= 0:
            return None

        training_samples = float(data_info.get("training_samples", data_info.get("batch_size", 0.0)))
        training_tokens = float(data_info.get("training_tokens", 0.0))
        effective_samples = float(self._benchmark_training_samples)
        effective_tokens = float(self._benchmark_training_tokens)
        rollout_samples = float(produce_result.produced_samples)
        rollout_tokens = float(produce_result.produced_tokens)

        train_gpu_count = float(getattr(self, "_num_workers", 1.0))
        if train_gpu_count <= 0:
            train_gpu_count = 1.0
        rollout_gpu_count = float(getattr(self, "_rollout_num_workers", train_gpu_count))
        if rollout_gpu_count <= 0:
            rollout_gpu_count = 1.0

        e2e_effective_tokens_per_s = effective_tokens / e2e_s
        effective_tokens_per_s = training_tokens / step_s
        training_tokens_per_s = training_tokens / training_s
        rollout_samples_per_s = rollout_samples / rollout_s if rollout_s is not None and rollout_s > 0 else 0.0
        rollout_tokens_per_s = rollout_tokens / rollout_s if rollout_s is not None and rollout_s > 0 else 0.0

        return RLThroughputBenchmark(
            e2e_effective_sgs=effective_samples / e2e_s / train_gpu_count,
            e2e_effective_tgs=e2e_effective_tokens_per_s / train_gpu_count,
            effective_sgs=training_samples / step_s / train_gpu_count,
            effective_tgs=effective_tokens_per_s / train_gpu_count,
            training_tgs=training_tokens_per_s / train_gpu_count,
            rollout_sgs=rollout_samples_per_s / rollout_gpu_count,
            rollout_tgs=rollout_tokens_per_s / rollout_gpu_count,
        )

    def _log_step(
        self,
        train_step: int,
        step_timer_dict: dict,
        produce_result: ProduceBatchResult,
        train_info: TrainInfo,
        eval_info: dict[str, float],
    ):
        """Log timing, rollout, data, throughput, train and eval metrics for one step.

        Metrics go to the logger and to the experiment tracker at ``train_step``.
        """
        all_scalars = {}
        log_time_str = ""
        trajectory_str = ""
        throughput_str = ""
        eval_str = ""

        if step_timer_dict:
            all_scalars.update({f"time/{k}": v for k, v in step_timer_dict.items()})
            log_time_str = f"\nTrain step {train_step} finished and timing listed:\n"
            log_time_str += "\n".join([f" - {k:<25}: {v:.2f}s" for k, v in step_timer_dict.items()])

        if produce_result.group_gen_count is not None:
            all_scalars["timing/task_n"] = produce_result.group_gen_count
            all_scalars["timing/task_mean_s"] = produce_result.group_gen_mean_s
            all_scalars["timing/task_p50_s"] = produce_result.group_gen_p50_s
            all_scalars["timing/task_p99_s"] = produce_result.group_gen_p99_s
            all_scalars["timing/task_p99_p50_ratio"] = produce_result.group_gen_p99_p50_ratio
        if produce_result.group_gen_pause_time_s is not None:
            all_scalars["timing/pause_s"] = produce_result.group_gen_pause_time_s

        all_scalars["async/init_samples"] = produce_result.leftover_init
        all_scalars["async/completed_samples"] = produce_result.leftover_completed
        all_scalars["async/aborted_samples"] = produce_result.leftover_aborted
        all_scalars["async/expired_samples"] = produce_result.leftover_expired
        all_scalars["async/failed_samples"] = produce_result.leftover_failed
        all_scalars["async/filtered_samples"] = produce_result.leftover_filtered

        if train_info:
            data_info = train_info.get("data_info", {})
            training_samples = int(data_info.get("training_samples", data_info.get("batch_size", 0)))
            training_tokens = int(data_info.get("training_tokens", 0))
            self._benchmark_training_samples += training_samples
            self._benchmark_training_tokens += training_tokens
            benchmark_end_time_s = float(data_info.get("benchmark_end_time_s", time.perf_counter()))
            benchmark_data_info_keys = {
                "training_samples",
                "training_tokens",
                "benchmark_end_time_s",
            }
            response_data_info = {k: v for k, v in data_info.items() if k not in benchmark_data_info_keys}
            all_scalars.update({f"response/{k}": v for k, v in response_data_info.items()})
            throughput_benchmark = self._compute_benchmark_metrics(
                data_info,
                produce_result,
                step_timer_dict,
                benchmark_end_time_s,
            )
            if throughput_benchmark is not None:
                throughput_metrics = throughput_benchmark.to_scalars()
                all_scalars.update(throughput_metrics)
                throughput_str = f"\nTrain step {train_step} throughput statistics:\n"
                throughput_str += "\n".join(
                    [f"- {k.removeprefix('throughput/'):<25}: {v:.4f}" for k, v in throughput_metrics.items()]
                )
            trajectory_str = f"\nTrain step {train_step} data statistics:\n"
            trajectory_str += "\n".join([f"- {k:<25}: {v:.4f}" for k, v in response_data_info.items()])

            rank0_log_item = train_info["workers_log_item"][0]
            rank0_rollout_is_metrics = rank0_log_item.get("rollout_is_metrics", {})
            rank0_mismatch_metrics = rank0_log_item.get("mismatch_metrics", {})
            rank0_rollout_entropy = rank0_log_item.get("rollout_entropy", 0.0)
            all_scalars.update({f"rollout_is/{k}": v for k, v in rank0_rollout_is_metrics.items()})
            all_scalars.update({f"{k}": v for k, v in rank0_mismatch_metrics.items()})
            all_scalars.update({"entropy/rollout": rank0_rollout_entropy})
            all_scalars.update({"entropy/train": rank0_log_item["train_entropy"]})

            for worker_idx, log_item in enumerate(train_info["workers_log_item"]):
                if not self._display_all_workers_log and worker_idx > 0:
                    break
                mini_batch_metrics: dict[str, List[float]] = {}
                for mini_batch_log in log_item["train_metrics"]:
                    for k, v in mini_batch_log.items():
                        mini_batch_metrics.setdefault(k, []).append(cast(float, v))
                for key, value in mini_batch_metrics.items():
                    avg_value = sum(value) / len(value)
                    all_scalars.update({f"train_metrics/worker_{worker_idx}/step_avg_{key}": avg_value})

                rank_sft_log = log_item["sft_train_metrics"]
                for k, v in rank_sft_log.items():
                    all_scalars.update({f"sft_train_metrics/worker_{worker_idx}/{k}": v})

            self._log_mini_batch_metrics(train_info["workers_log_item"])

        if eval_info:
            all_scalars.update({f"eval/{k}": v for k, v in eval_info.items()})
            eval_str = " ".join([f"{k}: {v:.4f}" for k, v in eval_info.items()])

        self.logger.info(
            f"Train step {train_step}/{self._total_train_steps}{log_time_str} {trajectory_str} {throughput_str}"
        )
        if eval_str:
            self.logger.info(f"Eval: {eval_str}")

        self._exp_tracker.add_scalars(tag_scalar_dict=all_scalars, global_step=train_step)

    def _count_response_tokens(self, data: RolloutState) -> int | None:
        """Return the response length in tokens, or None when the sample has no response."""
        if data.response_ids is not None:
            if isinstance(data.response_ids, torch.Tensor):
                return int(data.response_ids.numel())
            return len(data.response_ids)
        if data.response is not None:
            return len(self.tokenizer.encode(data.response, add_special_tokens=False))
        return None

    def _save_trajectories(self, data_groups: list[list[RolloutState]], save_path: Path) -> None:
        """Dump trajectories: a summary JSON object followed by one JSON object per sample.

        Only groups accepted by ``is_valid_for_training`` are written. Samples that
        have neither ``response_ids`` nor ``response`` get ``"response_len": null``
        and are left out of the response-length statistics.
        """
        # Validate each group once (it is given the logger, so calling it twice could
        # duplicate its log output).
        valid_groups = [group for group in data_groups if is_valid_for_training(group, self.logger)]

        rewards = []
        # One entry per written sample so lengths stay aligned with the items below;
        # previously samples without a response shifted every later length.
        item_response_lens: list[int | None] = []
        for group in valid_groups:
            for data in group:
                assert data.reward is not None
                rewards.append(data.reward["score"])
                item_response_lens.append(self._count_response_tokens(data))

        known_response_lens = [length for length in item_response_lens if length is not None]
        rewards_tensor = torch.tensor(rewards).float() if rewards else torch.tensor([0.0]).float()
        response_lens = (
            torch.tensor(known_response_lens).float() if known_response_lens else torch.tensor([0.0]).float()
        )
        summary = {
            "reward_mean": rewards_tensor.mean().item(),
            "reward_std": rewards_tensor.std().item(),
            "reward_max": rewards_tensor.max().item(),
            "reward_min": rewards_tensor.min().item(),
            "response_len_mean": response_lens.mean().item(),
            "response_len_std": response_lens.std().item(),
            "response_len_max": response_lens.max().item(),
            "response_len_min": response_lens.min().item(),
            "total_len": len(rewards),
        }

        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)
            f.write("\n")
            samples = (data for group in valid_groups for data in group)
            for data, response_len in zip(samples, item_response_lens):
                ground_truth = None
                if data.reward_model is not None:
                    ground_truth = data.reward_model.get("ground_truth")
                item = {
                    "prompt": data.message,
                    "raw_prompt": data.extra_fields.get("raw_prompt", None),
                    "response": data.response,
                    "response_len": response_len,
                    "label": ground_truth,
                    "reward": data.reward["score"],
                    "finish_reason": data.finish_reason,
                }
                json.dump(item, f, ensure_ascii=False, indent=2)
                f.write("\n")

    def _log_mini_batch_metrics(self, workers_log_item: List[WorkerLogItem]):
        """Log each mini-batch's train metrics at its own global mini-batch step.

        Only worker 0 is logged unless ``_display_all_workers_log`` is set. The global
        mini-batch counter advances by worker 0's mini-batch count.
        """
        if not workers_log_item:
            return
        train_start_step = self._global_train_step + 1
        for worker_idx, log_item in enumerate(workers_log_item):
            # Filter workers once at the outer loop instead of per mini-batch.
            if not self._display_all_workers_log and worker_idx > 0:
                break
            for step_idx, mini_batch_log in enumerate(log_item["train_metrics"]):
                metrics: dict[str, Any] = dict(mini_batch_log)
                self._exp_tracker.add_scalars(
                    tag_scalar_dict={f"train_metrics/worker_{worker_idx}/{k}": float(v) for k, v in metrics.items()},
                    global_step=train_start_step + step_idx,
                )
        self._global_train_step += len(workers_log_item[0]["train_metrics"])
class RLColocateTrainer(BaseRLTrainer):
    """RL trainer whose train and rollout workers share one placement group.

    Training and rollout take turns on the same accelerators. In each step the
    rollout workers produce a batch, the train workers consume it, and then
    resources (and, every ``_sync_weights_interval`` steps, weights) are handed
    back to the rollout workers.

    This colocated trainer keeps its own resource orchestration, resume logic,
    main loop and weight synchronization. Generic checkpoint saving and logging
    still live in ``BaseRLTrainer``.

    Debug modes:

    * ``_debug_rollout``: only rollout workers are started. Each produced batch
      is passed to ``_save_debug_rollout_batch`` instead of being trained on.
    * ``_debug_train``: only train workers are started. Batches are loaded from
      ``_debug_rollout_dir`` and rollout weights are never synchronized.
    """

    _META_PATH = ".xtuner_rl_colocate_trainer"

    def __init__(self, cfg: RLColocateTrainerConfig):
        """Build the shared placement group, controllers and agent-loop components.

        Args:
            cfg: Colocated trainer configuration. ``cfg.resources`` sizes the
                single placement group used by both train and rollout workers.
        """
        self._init_common(cfg, meta_path=self._META_PATH, logger_tag="RLTrainer")
        # Train and rollout share the same resources in colocated mode.
        self._num_workers = float(cfg.resources.num_workers)
        self._rollout_num_workers = float(cfg.resources.num_workers)
        self._pg = AutoAcceleratorWorkers.build_placement_group(cfg.resources)
        self._cpu_resource_manager = CPUResourceManager(self._pg)
        self._cpu_resource_manager.log_initial_snapshot()
        set_cpu_resource_manager(self._cpu_resource_manager)

        if self._debug_rollout:
            # No train workers are started in this mode, so weights cannot be
            # synced from them; force the rollout workers to load weights.
            if self._rollout_config.skip_load_weights:
                self.logger.info(
                    "debug_rollout cannot be used with rollout_config.skip_load_weights=True. force set skip_load_weights to False"
                )
                self._rollout_config.skip_load_weights = False
            self.rollout_controller = self._rollout_config.build(self._pg)
            self._maybe_start_gateway(cfg)
            replay_buffer = cfg.replay_buffer_config.build()
            self._build_agent_loop_components(cfg, replay_buffer)
            self._cpu_resource_manager.log_registered_summary()
            self.logger.warning("Debug rollout mode is enabled. Only rollout workers will be started.")
            return

        self.train_controller = self._train_worker_cfg.build(self._pg)
        checkpoint_path = self._load_checkpoint_cfg.checkpoint_path
        if checkpoint_path is not None:
            checkpoint_path = self._resume_train_controller_and_state(checkpoint_path)

        if self._debug_train:
            assert self._debug_rollout_dir is not None
            self.tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_path, trust_remote_code=True)
            self._debug_train_files = self._list_debug_rollout_files(self._debug_rollout_dir)
            if cfg.total_train_steps is None:
                # Train over every recorded rollout step when no explicit limit is set.
                self._total_train_steps = max(self._debug_train_files)
            self.logger.warning(
                "Debug train mode is enabled. Only training workers will be started and rollout weights will not be synchronized."
            )
            return

        # Free trainer-side GPU memory before bringing up colocated rollout workers.
        # Backends like sglang may size KV cache against their own target utilization
        # instead of the trainer's transient footprint, which can cause init-time OOM.
        self.train_controller.offload(target="all")
        self.rollout_controller = self._rollout_config.build(self._pg)
        self._maybe_start_gateway(cfg)
        bind_train_rollout(train_controller=self.train_controller, rollout_controller=self.rollout_controller)
        replay_buffer = cfg.replay_buffer_config.build()
        self._build_agent_loop_components(cfg, replay_buffer)
        if checkpoint_path is not None:
            asyncio_run(self._resume_agent_loop_manager(checkpoint_path))
        self.train_controller.set_train_rollout_mode("colocate")
        self._cpu_resource_manager.log_registered_summary()
        # NOTE(review): on resume, weights are only pushed from the restored train
        # workers when skip_load_weights is set. Otherwise the rollout workers keep
        # whatever `_rollout_config.build` loaded. RLDisaggregatedTrainer's resume
        # always calls update_weights(). Confirm that the rollout weights match the
        # restored step here.
        if self._rollout_config.skip_load_weights:
            self._sync_weights_from_train_workers()

    def _sync_weights_from_train_workers(self) -> None:
        """Initialize rollout weights from the train workers.

        The rollout side is offloaded, the train model is onloaded and its
        weights are pushed. The train model is then offloaded again and the
        rollout KV cache is onloaded.
        """
        self.logger.info("Rollout workers skip load weights, update weights from train workers.")
        ray.get(self.rollout_controller.offload.remote())
        self.train_controller.onload(target="model")
        ray.get(self.rollout_controller.onload_weights.remote())
        self.train_controller.update_weights()
        self.train_controller.offload(target="model")
        ray.get(self.rollout_controller.onload_kvcache.remote())
        self.logger.info("Rollout workers updated weights from train workers.")

    def fit(self):
        """Run the colocated produce -> train -> sync loop until ``_total_train_steps``.

        Each iteration produces a batch with the rollout model at ``model_step``,
        then trains on it. It then saves state and switches resources back to
        rollout, syncing weights every ``_sync_weights_interval`` steps.
        Evaluation runs only on steps where weights were synced and
        ``train_step % _evaluate_step == 0``.
        """
        self.logger.info("Start RL training")
        if self._cur_step >= self._total_train_steps:
            self.logger.info(f"Train steps {self._total_train_steps} reached, stop training")
            return
        if self._debug_train:
            self._fit_debug_train()
            return
        if self._enable_initial_evaluate and not self._debug_rollout:
            asyncio_run(self._run_initial_evaluate())

        self._benchmark_start_time_s = time.perf_counter()
        self._benchmark_training_samples = 0
        self._benchmark_training_tokens = 0
        init_train_step = self._cur_step + 1
        # Model version currently held by the rollout workers.
        model_step = self._get_colocate_rollout_model_step(init_train_step)
        for train_step in range(init_train_step, self._total_train_steps + 1):
            self.logger.info(f"Train step {train_step}/{self._total_train_steps} start")
            step_timer_dict: dict[str, float] = {}
            with timer("step", step_timer_dict):
                # In the colocated path, rollout production and replay-buffer
                # consumption both complete within a single produce_batch call.
                self.logger.info(
                    f"[Step {train_step}] start to generate rollout experience for train step {train_step} with model step {model_step}"
                )
                with timer("produce_batch", step_timer_dict):
                    produce_result: ProduceBatchResult = asyncio_run(
                        self.agent_loop_manager.produce_batch(
                            self.train_batch_size,
                            train_step=train_step,
                            model_step=model_step,
                        )
                    )
                if XTUNER_DETERMINISTIC:
                    produce_result.rollout_states = sort_rollout_state_for_deterministic(produce_result.rollout_states)
                train_batch = produce_result.rollout_states
                assert train_batch, (
                    "RLColocateTrainer expects agent_loop_manager.produce_batch() to return non-empty rollout_states."
                )

                train_log_info: dict[str, Any]
                eval_log_info: dict[str, Any] = {}
                if self._debug_rollout:
                    # Rollout-only debugging: persist the batch instead of training.
                    self._save_debug_rollout_batch(train_batch, train_step)
                    train_log_info = {}
                else:
                    train_log_info = self._train_one_batch(
                        train_batch,
                        train_step,
                        step_timer_dict,
                        offload_rollout_before_train=True,
                        onload_train_before_train=True,
                        raw_rewards_sum=produce_result.raw_rewards_sum,
                        raw_rewards_count=produce_result.raw_rewards_count,
                    )
                    weights_synced = self._sync_weights_and_save(train_step, step_timer_dict)
                    if weights_synced:
                        model_step = train_step
                    # Only evaluate when rollout workers actually hold this step's weights.
                    if weights_synced and self._enable_evaluate and train_step % self._evaluate_step == 0:
                        with timer("evaluation", step_timer_dict):
                            eval_log_info.update(asyncio_run(self._run_evaluation(train_step)))
            self._log_step(train_step, step_timer_dict, produce_result, train_log_info, eval_log_info)
            self._cur_step = train_step

    def _fit_debug_train(self) -> None:
        """Train on recorded rollout batches only; rollout weights are never synced."""
        self._benchmark_start_time_s = time.perf_counter()
        self._benchmark_training_samples = 0
        self._benchmark_training_tokens = 0
        init_train_step = self._cur_step + 1
        for train_step in range(init_train_step, self._total_train_steps + 1):
            self.logger.info(f"Debug train step {train_step}/{self._total_train_steps} start")
            step_timer_dict: dict[str, float] = {}
            with timer("step", step_timer_dict):
                with timer("load_debug_rollout", step_timer_dict):
                    train_batch = self._load_debug_rollout_batch(train_step)
                # No rollout workers exist, so there is nothing to offload/onload around training.
                train_log_info = self._train_one_batch(
                    train_batch,
                    train_step,
                    step_timer_dict,
                    offload_rollout_before_train=False,
                    onload_train_before_train=False,
                )
                eval_log_info: dict[str, float] = {}
            # Wrap the loaded batch so the shared logging path can consume it.
            produce_result = ProduceBatchResult(rollout_states=train_batch)
            self._log_step(train_step, step_timer_dict, produce_result, train_log_info, eval_log_info)
            self._cur_step = train_step

    def _get_colocate_rollout_model_step(self, train_step: int) -> int:
        """Return the model step the rollout workers hold before training ``train_step``.

        This is the largest multiple of ``_sync_weights_interval`` that does not
        exceed ``train_step - 1``. For example, with an interval of 3,
        ``train_step=8`` gives 6.

        NOTE(review): this assumes the rollout weights were last refreshed at a
        sync-interval boundary. If they were refreshed from a resumed trainer at a
        non-boundary step, the label may lag. Confirm this against the resume path.
        """
        previous_step = train_step - 1
        return previous_step - (previous_step % self._sync_weights_interval)

    def _sync_weights_and_save(self, train_step: int, step_timer_dict: dict) -> bool:
        """Save state and switch colocated resources back to rollout workers.

        Args:
            train_step: The step that has just finished training.
            step_timer_dict: Timer dict that receives ``save_ckpt`` plus either
                ``sync_weight`` or ``switch_to_rollout``.

        Returns:
            True if the rollout workers received this step's weights, which
            happens when ``train_step % _sync_weights_interval == 0``.
        """
        should_sync_weights = train_step % self._sync_weights_interval == 0
        with timer("save_ckpt", step_timer_dict):
            self.train_controller.offload(target="optimizer")
            asyncio_run(self._maybe_save_checkpoint(train_step))
            self._maybe_save_hf(train_step)
        ray.get(self.rollout_controller.recover_failed_workers.remote())
        timer_name = "sync_weight" if should_sync_weights else "switch_to_rollout"
        with timer(timer_name, step_timer_dict):
            if should_sync_weights:
                # Push the new weights while the train model is still onloaded.
                bind_train_rollout(train_controller=self.train_controller, rollout_controller=self.rollout_controller)
                ray.get(self.rollout_controller.onload_weights.remote())
                self.train_controller.update_weights()
                self.logger.info("Rollout workers update weights successfully in colocate mode")
                self.train_controller.offload(target="model")
            else:
                # Free the train model first, then restore the existing rollout weights.
                self.train_controller.offload(target="model")
                ray.get(self.rollout_controller.onload_weights.remote())
            ray.get(self.rollout_controller.onload_kvcache.remote())
        return should_sync_weights
class RLDisaggregatedTrainer(BaseRLTrainer):
    """RL trainer with separate placement groups for training and rollout.

    A background producer keeps filling the replay buffer while the foreground
    trainer consumes batches via ``get_batch``. Around checkpoint saving, weight
    synchronization and evaluation, the producer is paused. It is then resumed
    with the newly synced model step.
    """

    _META_PATH = ".xtuner_rl_disaggregated_trainer"

    def __init__(self, cfg: RLDisaggregatedTrainerConfig):
        """Build the train/rollout placement groups, controllers and agent-loop components.

        Raises:
            ValueError: If any task runner uses a non-default ``should_continue_fn``.
            RuntimeError: If train and rollout resolve to the same placement group.
        """
        self._init_common(cfg, meta_path=self._META_PATH, logger_tag="RLDisaggTrainer")
        self._num_workers = float(cfg.train_resources.num_workers)
        self._rollout_num_workers = float(cfg.rollout_resources.num_workers)
        self._train_pg, self._rollout_pg = self._build_disaggregated_placement_groups(
            train_resources=cfg.train_resources,
            rollout_resources=cfg.rollout_resources,
        )
        self._cpu_resource_manager = CPUResourceManager([self._train_pg, self._rollout_pg])
        self._cpu_resource_manager.log_initial_snapshot()
        set_cpu_resource_manager(self._cpu_resource_manager)
        self.train_controller = self._train_worker_cfg.build(self._train_pg)
        self.rollout_controller = self._rollout_config.build(self._rollout_pg)
        self._maybe_start_gateway(cfg)
        replay_buffer = cfg.replay_buffer_config.build()
        self._build_agent_loop_components(cfg, replay_buffer)
        # In disaggregated mode the producer and the consumer run concurrently.
        # To keep them cooperating, production must not stop early. Otherwise too
        # little is produced and the consumer blocks. Hence should_continue_fn
        # must be default_should_continue_fn.
        for task_runner in self.agent_loop_manager.task_runners:
            if task_runner.produce_strategy.should_continue_fn is not default_should_continue_fn:
                raise ValueError(
                    "In disaggregated mode, should_continue_fn must be default, "
                    "because it does not allow early stopping in production."
                )
        bind_train_rollout(train_controller=self.train_controller, rollout_controller=self.rollout_controller)
        self.train_controller.set_train_rollout_mode("disaggregated")
        if self._load_checkpoint_cfg.checkpoint_path is not None:
            self._resume_from_checkpoint(self._load_checkpoint_cfg.checkpoint_path)
        self._cpu_resource_manager.log_registered_summary()

    def _build_disaggregated_placement_groups(
        self,
        train_resources: AcceleratorResourcesConfig,
        rollout_resources: AcceleratorResourcesConfig,
    ):
        """Create distinctly named placement groups for training and rollout.

        The names are derived from ``self.exp_dir.name``.

        Returns:
            A ``(train_pg, rollout_pg)`` tuple.

        Raises:
            RuntimeError: If both groups resolve to the same placement group id.
        """
        pg_name_prefix = f"xtuner_rl_disagg_{self.exp_dir.name}"
        train_pg_name = f"{pg_name_prefix}_train"
        rollout_pg_name = f"{pg_name_prefix}_rollout"
        train_pg = AutoAcceleratorWorkers.build_placement_group(train_resources, name=train_pg_name)
        rollout_pg = AutoAcceleratorWorkers.build_placement_group(rollout_resources, name=rollout_pg_name)
        if train_pg.id == rollout_pg.id:
            raise RuntimeError(
                "RLDisaggregatedTrainer requires distinct placement groups for train and rollout, "
                f"but both resolved to the same placement group id={train_pg.id}. "
                "Please check placement-group naming and stale Ray cluster state."
            )
        self.logger.info(
            "Created disaggregated placement groups: "
            f"train={train_pg_name}(id={train_pg.id}), "
            f"rollout={rollout_pg_name}(id={rollout_pg.id})"
        )
        return train_pg, rollout_pg

    def _resume_from_checkpoint(self, checkpoint_path: Path | str) -> None:
        """Restore trainer and agent-loop state, push weights, then resume production."""
        checkpoint_path = self._resume_train_controller_and_state(checkpoint_path)
        saved_model_step = asyncio_run(self._resume_agent_loop_manager(checkpoint_path))
        # The agent loop must resume from the same model step the trainer restored.
        assert self._cur_step == saved_model_step, (
            f"Checkpoint mismatch: trainer resumed at step {self._cur_step} "
            f"but agent loop manager saved model_step {saved_model_step}."
        )
        self.update_weights()
        asyncio_run(self.agent_loop_manager.continue_produce(model_step=saved_model_step))

    def fit(self):
        """Run disaggregated training until ``_total_train_steps`` (synchronous entry point)."""
        # Keep a synchronous fit() externally; internally an async loop drives
        # the producer and the consumer.
        return asyncio_run(self._fit())

    async def _fit(self):
        """Run the background producer and consume batches until training finishes."""
        self.logger.info("Start RL disaggregated training")
        if self._cur_step >= self._total_train_steps:
            self.logger.info(f"Train steps {self._total_train_steps} reached, stop training")
            return
        if self._enable_initial_evaluate:
            await self._run_initial_evaluate()
        self._benchmark_start_time_s = time.perf_counter()
        self._benchmark_training_samples = 0
        self._benchmark_training_tokens = 0
        # The background producer only keeps writing into the replay buffer; the
        # foreground trainer consumes it through get_batch.
        producer_task = create_task(
            self.agent_loop_manager.produce_loop(
                batch_size=self.train_batch_size,
            )
        )
        try:
            # train_step is the next step still to be trained. An empty EXPIRED
            # batch does not complete it, so a while loop is used to retry the
            # same step.
            train_step = self._cur_step + 1
            while train_step <= self._total_train_steps:
                self.logger.info(f"Train step {train_step}/{self._total_train_steps} start")
                step_timer_dict: dict[str, float] = {}
                train_log_info = {}
                eval_log_info = {}
                with timer("step", step_timer_dict):
                    with timer("get_batch", step_timer_dict):
                        produce_result = await self.agent_loop_manager.get_batch(
                            self.train_batch_size, train_step=train_step
                        )
                    if XTUNER_DETERMINISTIC:
                        produce_result.rollout_states = sort_rollout_state_for_deterministic(
                            produce_result.rollout_states
                        )
                    train_batch = produce_result.rollout_states
                    # EXPIRED_BATCH comes in two flavours. An empty batch is a
                    # control-plane sync signal. A non-empty batch is still
                    # trainable data.
                    is_expired = produce_result.status == ProduceBatchStatus.EXPIRED_BATCH
                    empty_expired_batch = is_expired and not train_batch
                    if empty_expired_batch:
                        # Nothing was trained, so only the previously completed
                        # model step can be synced.
                        sync_model_step = train_step - 1
                        self.logger.info(
                            "Skip train step because rollout model is expired and a newer model already exists; "
                            f"sync completed model_step={sync_model_step} first."
                        )
                    else:
                        # A non-empty expired batch must produce this step's new
                        # model version. Otherwise the producer has no updated
                        # weights to resume with.
                        assert train_batch, (
                            "RLDisaggregatedTrainer expects get_batch() to return non-empty rollout_states "
                            "unless status is empty EXPIRED_BATCH."
                        )
                        # The background producer must keep running while this
                        # batch trains. Run the synchronous training path in a
                        # thread so ray.get / file writes do not block the event loop.
                        train_log_info = await asyncio.to_thread(
                            self._train_one_batch,
                            train_batch,
                            train_step,
                            step_timer_dict,
                            raw_rewards_sum=produce_result.raw_rewards_sum,
                            raw_rewards_count=produce_result.raw_rewards_count,
                        )
                        sync_model_step = train_step
                    # From here on, save / sync / evaluation / producer resume all
                    # key off the completed model step. Any expired batch (empty
                    # or not) forces a sync so the producer can catch up.
                    need_sync = (
                        is_expired
                        or sync_model_step % self._sync_weights_interval == 0
                        or sync_model_step == self._total_train_steps
                    )
                    if need_sync:
                        # Pause the background producer first so pending rollouts
                        # do not keep writing to the buffer during save/sync.
                        with timer("pause_produce", step_timer_dict):
                            await self.agent_loop_manager.pause_produce(use_global_progress=True)
                        await self._sync_weights_and_save(sync_model_step, step_timer_dict)
                        # NOTE(review): on an empty expired batch this may evaluate
                        # model_step=train_step-1 again, and its eval_log_info is
                        # discarded by the `continue` below. Confirm this is intended.
                        if (
                            self._enable_evaluate
                            and sync_model_step > 0
                            and sync_model_step % self._evaluate_step == 0
                        ):
                            # Evaluate before resuming the producer so background
                            # production does not compete for rollout resources.
                            with timer("evaluation", step_timer_dict):
                                eval_log_info.update(await self._run_evaluation(sync_model_step))
                        await self.agent_loop_manager.continue_produce(model_step=sync_model_step)
                    if empty_expired_batch:
                        # Nothing was trained: do not log this as a completed step
                        # and do not advance _cur_step; retry the same train_step.
                        continue
                self._log_step(train_step, step_timer_dict, produce_result, train_log_info, eval_log_info)
                self._cur_step = train_step
                train_step = self._cur_step + 1
        finally:
            self.agent_loop_manager.shutdown()
            await producer_task

    async def _sync_weights_and_save(self, model_step: int, step_timer_dict: dict):
        """Save checkpoints for ``model_step``, then push weights to the rollout workers."""
        # The producer is already paused by _fit. Keep the save -> bind -> update
        # order while everything is quiescent.
        with timer("save_ckpt", step_timer_dict):
            await self._maybe_save_checkpoint(model_step)
            self._maybe_save_hf(model_step)
        ray.get(self.rollout_controller.recover_failed_workers.remote())
        with timer("sync_weight", step_timer_dict):
            bind_train_rollout(train_controller=self.train_controller, rollout_controller=self.rollout_controller)
            self.update_weights()

    def update_weights(self):
        """Push the current train weights to the rollout workers."""
        # Producer pause/continue is owned by AgentLoopManager. Do not resume
        # rollout here, which would disturb the evaluation ordering.
        self.train_controller.update_weights()
        self.logger.info("Rollout workers update weights successfully in disaggregated mode")