# Source code for xtuner.v1.rl.agent_loop_manager.agent_loop_manager

import asyncio
import json
import math
import time
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path

from pydantic import BaseModel, ConfigDict, Field

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from xtuner.v1.data_proto.rl_data import RolloutState, Status
from xtuner.v1.rl.agent_loop import AgentLoopConfig, AgentLoopSpec, get_agent_loop_rollout_ctl
from xtuner.v1.rl.judger import ComposedJudgerConfig, JudgerConfig, build_judger
from xtuner.v1.rl.replay_buffer import ReplayBuffer
from xtuner.v1.rl.rollout import RolloutController
from xtuner.v1.utils import get_logger

from .producer import (
    GROUP_GENERATE_TIME_KEY,
    ProduceBatchStatus,
    ProduceContext,
    ProduceProgress,
    ProduceStrategy,
    ProduceStrategyConfig,
    SyncProduceStrategyConfig,
    default_is_valid_sample_fn,
)
from .sampler import Sampler, SamplerConfig


@dataclass
class ProduceBatchResult:
    """Result of a single ``produce_batch`` / ``get_batch`` call.

    Args:
        rollout_states (list[list[RolloutState]]): Completed rollout groups retrieved from the replay buffer for training.
        status (ProduceBatchStatus): Outcome status attached to this result. Defaults to
            ``ProduceBatchStatus.NORMAL``.
        group_gen_count (int | None): Number of generate-group calls finished in this batch (None if no generations ran).
        group_gen_mean_s (float | None): Mean wall-clock time per generate-group call, in seconds.
        group_gen_p50_s (float | None): Median (p50) generate-group time, in seconds.
        group_gen_p99_s (float | None): 99th percentile generate-group time, in seconds.
        group_gen_p99_p50_ratio (float | None): Ratio of p99 to p50, indicating tail-latency skew.
        group_gen_pause_time_s (float | None): Time spent in pause/cleanup phase (async strategy only), in seconds.
        leftover_init (int): Number of init groups remaining in the replay buffer after this batch.
        leftover_completed (int): Number of completed groups remaining in the replay buffer after this batch.
        leftover_aborted (int): Number of aborted groups remaining in the replay buffer.
        leftover_expired (int): Number of expired groups remaining in the replay buffer.
        leftover_failed (int): Number of failed groups remaining in the replay buffer.
        leftover_filtered (int): Number of filtered groups remaining in the replay buffer.
        raw_rewards_sum (float): Sum of rewards produced before replay-buffer insertion for the current window.
        raw_rewards_count (int): Number of reward-bearing samples included in ``raw_rewards_sum``.
        produced_samples (int): Number of rollout samples produced in the current produce window.
        produced_tokens (int): Number of response tokens produced in the current produce window.
        produce_time_s (float): Wall-clock production time consumed by the current produce window.
        task_batch_sizes (dict[str, int] | None): Requested batch size per task name. Only populated on
            aggregated multi-task results; None otherwise.
        task_results (dict[str, ProduceBatchResult] | None): Per-task results keyed by task name. Only
            populated on aggregated multi-task results; None otherwise.
    """

    rollout_states: list[list[RolloutState]]
    status: ProduceBatchStatus = ProduceBatchStatus.NORMAL
    # Per-group generation timing stats. They stay None when no generation times were recorded;
    # group_gen_pause_time_s alone may still be set when a positive pause time was measured.
    group_gen_count: int | None = None
    group_gen_mean_s: float | None = None
    group_gen_p50_s: float | None = None
    group_gen_p99_s: float | None = None
    group_gen_p99_p50_ratio: float | None = None
    group_gen_pause_time_s: float | None = None
    # Leftover groups remaining in the replay buffer after batch retrieval, by status.
    leftover_init: int = 0
    leftover_completed: int = 0
    leftover_aborted: int = 0
    leftover_expired: int = 0
    leftover_failed: int = 0
    leftover_filtered: int = 0
    # Rewards produced during the current produce window, including completed and filtered groups.
    raw_rewards_sum: float = 0.0
    raw_rewards_count: int = 0
    produced_samples: int = 0
    produced_tokens: int = 0
    produce_time_s: float = 0.0
    # Multi-task bookkeeping: per-task requested sizes and per-task sub-results.
    task_batch_sizes: dict[str, int] | None = None
    task_results: dict[str, "ProduceBatchResult"] | None = None


@dataclass(frozen=True)
class _TaskRunner:
    """Immutable bundle of the runtime objects built for one configured task.

    Attributes:
        task_name: Unique task name; used as the key for batch sizes, progress and checkpoints.
        agent_loop: Agent loop used to generate rollouts for this task.
        produce_strategy: Strategy object that drives production and pausing for this task.
        sampler: Dataset sampler for this task.
        weight: Relative weight used when splitting a global batch size across tasks.
        order: Position of the task in the configured task list; used as a tie-breaker
            when distributing the remainder of a batch split.
    """

    task_name: str
    agent_loop: AgentLoopSpec
    produce_strategy: ProduceStrategy
    sampler: Sampler
    weight: float = 1.0
    order: int = 0


class _TaskSamplerView:
    """Length-only view over several task samplers.

    ``len(view)`` is the sum of the lengths of all wrapped samplers.
    """

    def __init__(self, samplers: list[Sampler]):
        self._samplers = samplers

    def __len__(self) -> int:
        total = 0
        for task_sampler in self._samplers:
            total += len(task_sampler)
        return total


class AgentLoopManagerStatus(Enum):
    """Global status of the AgentLoopManager.

    Transitions follow these paths:
    - The initial status is NORMAL.
    - NORMAL -> UPDATE_WEIGHT_AND_ABORT
      - Triggered before the trainer starts weight synchronization.
    - UPDATE_WEIGHT_AND_ABORT -> NORMAL
      - continue_produce() is called once weight synchronization has completed.
    - NORMAL -> EXPIRED_BATCH
      - The current rollout model has become too stale.
    - EXPIRED_BATCH -> UPDATE_WEIGHT_AND_ABORT
      - After detecting the expiry, the trainer enters the weight-sync phase.
    - any status -> FINISH
      - Training has finished.

    There is one important distinction:
    - AgentLoopManagerStatus is the global running state of the background producer.
    - ProduceBatchStatus is the local result of a single scheduling call.
    """

    NORMAL = auto()
    UPDATE_WEIGHT_AND_ABORT = auto()
    EXPIRED_BATCH = auto()
    FINISH = auto()


def _fill_produce_timing_stats(
    result: ProduceBatchResult, generate_times_s: list[float], pause_time_s: float = 0.0
) -> None:
    """Write group-generation timing statistics onto ``result`` in place.

    Args:
        result: Result object whose ``group_gen_*`` fields are updated.
        generate_times_s: Per-group generation times in seconds.
        pause_time_s: Pause time in seconds to record alongside the stats.

    With no generation times, only ``group_gen_pause_time_s`` is written, and
    only when ``pause_time_s`` is positive. Otherwise count, mean, p50, p99,
    the p99/p50 ratio and the pause time are all written. Percentiles are
    picked by index from the sorted times (``n // 2`` and ``int(n * 0.99)``);
    the ratio is ``inf`` when p50 is not positive.
    """
    if generate_times_s:
        ordered = sorted(generate_times_s)
        count = len(ordered)
        median = ordered[count // 2]
        tail = ordered[int(count * 0.99)]
        # Compute every value before touching ``result`` so a failure leaves it unmodified.
        stats = {
            "group_gen_count": count,
            "group_gen_mean_s": sum(ordered) / count,
            "group_gen_p50_s": median,
            "group_gen_p99_s": tail,
            "group_gen_p99_p50_ratio": tail / median if median > 0 else float("inf"),
            "group_gen_pause_time_s": pause_time_s,
        }
        for field_name, value in stats.items():
            setattr(result, field_name, value)
    elif pause_time_s > 0:
        result.group_gen_pause_time_s = pause_time_s


def _fill_group_timing_stats(
    result: ProduceBatchResult, rollout_states: list[list[RolloutState]], pause_time_s: float = 0.0
) -> None:
    """Collect per-group generation times from ``rollout_states`` and fill ``result``.

    For every non-empty group, the time is read from the first element's
    ``extra_fields`` under ``GROUP_GENERATE_TIME_KEY``; groups without a
    recorded time are skipped. The collected times and ``pause_time_s`` are
    then handed to ``_fill_produce_timing_stats``.
    """
    first_states = (group[0] for group in rollout_states if group)
    recorded_times = (
        getattr(state, "extra_fields", {}).get(GROUP_GENERATE_TIME_KEY) for state in first_states
    )
    generate_times: list[float] = [seconds for seconds in recorded_times if seconds is not None]

    _fill_produce_timing_stats(result, generate_times, pause_time_s=pause_time_s)


def _aggregate_status(statuses: list[ProduceBatchStatus]) -> ProduceBatchStatus:
    """Collapse per-task statuses into one, by priority.

    ``UPDATE_WEIGHT_AND_ABORT`` wins over ``EXPIRED_BATCH``, which wins over
    ``NORMAL``. An empty list yields ``NORMAL``.
    """
    by_priority = (
        ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT,
        ProduceBatchStatus.EXPIRED_BATCH,
    )
    for escalated in by_priority:
        if any(status == escalated for status in statuses):
            return escalated
    return ProduceBatchStatus.NORMAL


# Replay-buffer statuses that are counted after a batch is taken. The counts are
# used for logging and for the ``leftover_*`` fields of ``ProduceBatchResult``.
_LEFTOVER_STATUSES = [
    Status.INIT,
    Status.COMPLETED,
    Status.ABORTED,
    Status.EXPIRED,
    Status.FAILED,
    Status.FILTERED,
]


def _fill_leftover_counts(result: ProduceBatchResult, status_counts: dict[Status, int]) -> None:
    """Copy per-status counts into the ``leftover_*`` fields of ``result``.

    Statuses missing from ``status_counts`` are recorded as 0.
    """
    field_for_status = (
        ("leftover_init", Status.INIT),
        ("leftover_completed", Status.COMPLETED),
        ("leftover_aborted", Status.ABORTED),
        ("leftover_expired", Status.EXPIRED),
        ("leftover_failed", Status.FAILED),
        ("leftover_filtered", Status.FILTERED),
    )
    for field_name, status in field_for_status:
        setattr(result, field_name, status_counts.get(status, 0))


def _build_produce_context(
    task_runner: _TaskRunner,
    replay_buffer: ReplayBuffer,
    batch_size: int,
    train_step: int,
    model_step: int,
    update_event: asyncio.Event,
    progress: ProduceProgress,
) -> ProduceContext:
    """Assemble the ``ProduceContext`` handed to a task's produce strategy.

    ``is_valid_sample_fn`` and ``stale_threshold`` are optional attributes of
    the strategy; when absent they fall back to ``default_is_valid_sample_fn``
    and ``None`` respectively.
    """
    strategy = task_runner.produce_strategy
    sample_validator = getattr(strategy, "is_valid_sample_fn", default_is_valid_sample_fn)
    staleness_limit = getattr(strategy, "stale_threshold", None)
    return ProduceContext(
        agent_loop=task_runner.agent_loop,
        sampler=task_runner.sampler,
        replay_buffer=replay_buffer,
        task_batch_size=batch_size,
        task_name=task_runner.task_name,
        train_step=train_step,
        update_event=update_event,
        model_step=model_step,
        progress=progress,
        is_valid_sample_fn=sample_validator,
        stale_threshold=staleness_limit,
    )


class TaskSpecConfig(BaseModel):
    """Configuration for one task managed by ``AgentLoopManager``.

    A task spec binds together the dataset sampler, agent loop, optional judger,
    production strategy, and sampling weight for one RL data source. Multi-task
    training is represented as a list of ``TaskSpecConfig`` objects.

    Args:
        task_name (str): Unique task name used for logging, replay-buffer routing,
            and checkpoint state.
        weight (float): Relative batch allocation weight for this task in multi-task
            training. Must be non-negative. Defaults to 1.0.
        agent_loop_config (AgentLoopConfig): Agent loop configuration used to generate
            rollout samples for this task.
        judger_config (JudgerConfig | ComposedJudgerConfig | None): Optional judger
            configuration used to score generated samples. Defaults to None.
        produce_strategy_config (ProduceStrategyConfig): Strategy used to produce rollout
            samples. Defaults to ``SyncProduceStrategyConfig``.
        sampler_config (SamplerConfig): Dataset sampler configuration for this task.

    **Examples:**

    Example configuration for one task::

        task = TaskSpecConfig(
            task_name="gsm8k",
            weight=1.0,
            agent_loop_config=SingleTurnAgentLoopConfig(
                hf_checkpoint="Qwen/Qwen3-8B",
                sample_params=SampleParams(max_tokens=1024),
            ),
            judger_config=GSM8KJudgerConfig(),
            sampler_config=SamplerConfig(dataloader_cfg=dataloader_cfg, prompt_repeat_k=8),
        )
    """

    # Unknown fields are rejected; non-pydantic types are allowed as field values.
    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    task_name: str
    weight: float = Field(default=1.0, ge=0.0)
    agent_loop_config: AgentLoopConfig
    judger_config: JudgerConfig | ComposedJudgerConfig | None = None
    produce_strategy_config: ProduceStrategyConfig = SyncProduceStrategyConfig()
    sampler_config: SamplerConfig
class AgentLoopManagerConfig(BaseModel):
    """Configuration for the agent loop manager.

    ``AgentLoopManagerConfig`` defines the rollout-producing side of RL training.
    It may manage a single task or a weighted list of tasks, and each task owns
    its sampler, agent loop, judger, and production strategy.

    Args:
        tasks (list[TaskSpecConfig] | TaskSpecConfig): One task config or a list of task
            configs. Task names must be unique when a list is provided.

    **Examples:**

    Example configuration for a single-task manager::

        config = AgentLoopManagerConfig(
            tasks=TaskSpecConfig(
                task_name="gsm8k",
                agent_loop_config=SingleTurnAgentLoopConfig(
                    hf_checkpoint="Qwen/Qwen3-8B",
                    sample_params=SampleParams(max_tokens=1024),
                ),
                judger_config=GSM8KJudgerConfig(),
                sampler_config=SamplerConfig(dataloader_cfg=dataloader_cfg, prompt_repeat_k=8),
            )
        )
    """

    # Unknown fields are rejected; non-pydantic types are allowed as field values.
    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    tasks: list[TaskSpecConfig] | TaskSpecConfig

    def build(
        self,
        rollout_controller: RolloutController,
        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
        replay_buffer: ReplayBuffer,
        logger=None,
        sync_weights_interval: int = 1,
    ) -> "AgentLoopManager":
        """Build an ``AgentLoopManager`` with one task runner per configured task.

        Args:
            rollout_controller: Passed to every agent loop and produce strategy build.
            tokenizer: Passed to every sampler build.
            replay_buffer: Shared by all samplers and by the manager itself.
            logger: Optional logger forwarded to the agent loops and the manager.
            sync_weights_interval: Forwarded to every produce strategy build.

        Returns:
            AgentLoopManager: Manager owning the built task runners.

        Raises:
            ValueError: If no task is configured or a ``task_name`` is duplicated.
        """
        # Normalize the single-task form to a list.
        tasks = self.tasks if isinstance(self.tasks, list) else [self.tasks]
        if not tasks:
            raise ValueError("AgentLoopManagerConfig requires at least one task config.")

        seen_task_names: set[str] = set()
        task_runners: list[_TaskRunner] = []
        for order, task_cfg in enumerate(tasks):
            if task_cfg.task_name in seen_task_names:
                raise ValueError(f"Duplicate task_name found in AgentLoopManagerConfig: {task_cfg.task_name}")
            seen_task_names.add(task_cfg.task_name)

            agent_loop = task_cfg.agent_loop_config.build(
                rollout_controller=rollout_controller,
                judger=build_judger(task_cfg.judger_config) if task_cfg.judger_config is not None else None,
                logger=logger,
            )
            produce_strategy = task_cfg.produce_strategy_config.build(
                sync_weights_interval=sync_weights_interval,
                rollout_controller=rollout_controller,
            )
            sampler = task_cfg.sampler_config.build(tokenizer=tokenizer, replay_buffer=replay_buffer)
            task_runners.append(
                _TaskRunner(
                    task_name=task_cfg.task_name,
                    agent_loop=agent_loop,
                    produce_strategy=produce_strategy,
                    sampler=sampler,
                    weight=task_cfg.weight,
                    # Config position; later used as a deterministic tie-breaker.
                    order=order,
                )
            )
        return AgentLoopManager(
            task_runners=task_runners,
            replay_buffer=replay_buffer,
            logger=logger,
        )
class AgentLoopManager:
    """Coordinate rollout production for one or more tasks over a shared replay buffer.

    Two usage modes are visible in this class:
    - colocated: ``produce_batch()`` produces, pauses and fetches a batch in one call;
    - disaggregated: ``produce_loop()`` keeps feeding the replay buffer in the background
      while ``get_batch()`` consumes from it, with ``pause_produce()`` /
      ``continue_produce()`` / ``shutdown()`` driving the shared status.
    """

    # Sub-directory (under the checkpoint path) holding per-task sampler state.
    _TASK_CHECKPOINT_DIR = "tasks"
    # File name (under the checkpoint path) of the manager's own JSON state.
    _MANAGER_STATE_PATH = "agent_loop_manager_state.json"
    # Sleep interval, in seconds, used by the status / readiness polling loops.
    _STATUS_POLL_INTERVAL_S = 1.0

    def __init__(
        self,
        task_runners: list[_TaskRunner],
        replay_buffer: ReplayBuffer,
        logger=None,
    ):
        """Create a manager over ``task_runners`` sharing ``replay_buffer``.

        Raises:
            ValueError: If ``task_runners`` is empty or the task weights do not sum
                to a positive value.
        """
        if not task_runners:
            raise ValueError("AgentLoopManager requires at least one task runner.")
        if sum(task.weight for task in task_runners) <= 0:
            raise ValueError("At least one task weight must be positive for AgentLoopManager.")
        self.task_runners = task_runners
        self.replay_buffer = replay_buffer
        # Single task: expose its sampler directly; multi-task: expose a combined length-only view.
        self.data_sampler = (
            task_runners[0].sampler
            if len(task_runners) == 1
            else _TaskSamplerView([task.sampler for task in task_runners])
        )
        self.name = task_runners[0].task_name if len(task_runners) == 1 else "multi_task"
        if logger is None:
            self.logger = get_logger()
        else:
            self.logger = logger
        self.task_names = [task.task_name for task in self.task_runners]
        # Concurrency-control signal for the disaggregated (non-colocated) path: the consumer
        # sets it before weight sync; producer / strategy should observe the event state
        # directly and stop issuing new rollouts as soon as possible. Do not replace this
        # event with an extra boolean snapshot.
        self._update_event = asyncio.Event()
        self._finish_event = asyncio.Event()
        # model_step read by the disaggregated producer: which train_step's synced model the
        # rollout side is currently using. Pending tasks must be paused and drained before a
        # weight update, so one pending lifecycle corresponds to exactly one model_step.
        self._model_step = 0
        # Control state shared by the disaggregated producer / consumer. produce_loop /
        # get_batch should read self._status directly and must not cache a local snapshot
        # across awaits, to avoid missing sync, expiry or finish transitions.
        self._status = AgentLoopManagerStatus.NORMAL
        # Elapsed-time metric written by pause_produce and read-and-reset by the next batch
        # fetch. Only used for consumer-side logging / metrics; production correctness does
        # not depend on these reads and writes.
        self._pause_time_s = 0.0
        # Absolute cumulative progress shared by the disaggregated producer / consumer. The
        # object reference must stay stable; the consumer updates fields in place, and
        # producer / strategy should read progress.xxx directly when a value is needed
        # instead of copying field values into local snapshots used across awaits.
        self._produce_progress = ProduceProgress.build(self.task_names)

    def get_task_batch_sizes(self, global_batch_size: int, train_step: int) -> dict[str, int]:
        """Return the per-task batch sizes for the current train step.

        The default policy splits ``global_batch_size`` proportionally to task weights,
        floors each share, then hands the remaining samples out one each to the tasks with
        the largest fractional part (ties broken by task ``order``). ``train_step`` is not
        used by this default implementation.

        Subclasses may override this method to implement custom dynamic batch allocation
        policies. Returning 0 for a task effectively disables that task for the current
        produce_batch call.

        Raises:
            ValueError: If ``global_batch_size`` is negative or the task weights do not sum
                to a positive value.
        """
        if global_batch_size < 0:
            raise ValueError(f"global_batch_size must be non-negative, got {global_batch_size}")
        total_weight = sum(task.weight for task in self.task_runners)
        if total_weight <= 0:
            raise ValueError("Sum of task weights must be positive.")
        if global_batch_size == 0:
            return {task.task_name: 0 for task in self.task_runners}

        raw_allocations = [global_batch_size * task.weight / total_weight for task in self.task_runners]
        floor_allocations = [math.floor(raw) for raw in raw_allocations]
        remaining = global_batch_size - sum(floor_allocations)
        task_batch_sizes = {task.task_name: floor_allocations[idx] for idx, task in enumerate(self.task_runners)}
        if remaining <= 0:
            return task_batch_sizes

        # Largest fractional remainder first; config order breaks ties deterministically.
        ranked_tasks = sorted(
            enumerate(self.task_runners),
            key=lambda item: (
                -(raw_allocations[item[0]] - floor_allocations[item[0]]),
                item[1].order,
            ),
        )
        for idx, task in ranked_tasks[:remaining]:
            task_batch_sizes[task.task_name] += 1
        return task_batch_sizes

    def _validate_task_batch_sizes(self, task_batch_sizes: dict[str, int], global_batch_size: int) -> None:
        """Check that ``task_batch_sizes`` covers exactly the known tasks and sums correctly.

        Raises:
            ValueError: If task names are missing or unexpected, any size is negative, or
                the sizes do not sum to ``global_batch_size``.
        """
        expected_task_names = {task.task_name for task in self.task_runners}
        actual_task_names = set(task_batch_sizes.keys())
        if actual_task_names != expected_task_names:
            missing_task_names = expected_task_names - actual_task_names
            extra_task_names = actual_task_names - expected_task_names
            raise ValueError(
                "Invalid task batch sizes returned by get_task_batch_sizes: "
                f"missing={sorted(missing_task_names)}, extra={sorted(extra_task_names)}"
            )

        negative_batch_sizes = {
            task_name: task_batch_size for task_name, task_batch_size in task_batch_sizes.items() if task_batch_size < 0
        }
        if negative_batch_sizes:
            raise ValueError(f"Task batch sizes must be non-negative, got {negative_batch_sizes}")

        total_batch_size = sum(task_batch_sizes.values())
        if total_batch_size != global_batch_size:
            raise ValueError(
                "Task batch sizes must sum to the requested global batch size, "
                f"got total={total_batch_size}, expected={global_batch_size}"
            )

    async def _refresh_for_all_tasks(self, train_step: int, statuses: list[Status]) -> None:
        """Ask the replay buffer to refresh staleness for every task and log expired counts."""
        task_stale_thresholds: dict[str, int] = {}
        for task in self.task_runners:
            # Both colocate and disagg paths refresh staleness the same way; a sync strategy
            # without a stale_threshold attribute falls back to 1.
            stale_threshold = getattr(task.produce_strategy, "stale_threshold", 1)
            task_stale_thresholds[task.task_name] = stale_threshold
        expired_counts = await self.replay_buffer.refresh_staleness(
            task_stale_thresholds=task_stale_thresholds,
            current_train_step=train_step,
            statuses=statuses,
        )
        for task_name, expired_count in expired_counts.items():
            self.logger.info(
                f"[AgentLoopManager][{self.name}] Refresh staleness for task {task_name}: expired_count={expired_count}"
            )

    def _get_task_batch_sizes_for_step(self, batch_size: int, train_step: int) -> dict[str, int]:
        """Return validated per-task batch sizes; a single task gets the whole batch."""
        if len(self.task_runners) == 1:
            return {self.task_runners[0].task_name: batch_size}
        task_batch_sizes = self.get_task_batch_sizes(batch_size, train_step)
        self._validate_task_batch_sizes(task_batch_sizes, batch_size)
        return task_batch_sizes

    @staticmethod
    def _aggregate_task_results(
        ordered_tasks: list[_TaskRunner], task_results: dict[str, ProduceBatchResult]
    ) -> ProduceBatchResult:
        """Merge per-task results into one result, visiting tasks in ``ordered_tasks`` order.

        Rollout groups are concatenated and counters are summed. Timing statistics are
        combined as averages weighted by each task's ``group_gen_count``; note that this
        makes the aggregated p50 / p99 weighted means of per-task percentiles, not
        percentiles of the pooled samples. The per-task results are attached as
        ``task_results``.
        """
        rollout_states: list[list[RolloutState]] = []
        leftover_init = 0
        leftover_completed = 0
        leftover_aborted = 0
        leftover_expired = 0
        leftover_failed = 0
        leftover_filtered = 0
        total_group_count = 0
        weighted_group_mean_sum = 0.0
        weighted_group_p50_sum = 0.0
        weighted_group_p99_sum = 0.0
        weighted_group_ratio_sum = 0.0
        total_pause_time_s = 0.0
        raw_rewards_sum = 0.0
        raw_rewards_count = 0
        produced_samples = 0
        produced_tokens = 0
        produce_time_s = 0.0

        for task in ordered_tasks:
            result = task_results[task.task_name]
            rollout_states.extend(result.rollout_states)
            leftover_init += result.leftover_init
            leftover_completed += result.leftover_completed
            leftover_aborted += result.leftover_aborted
            leftover_expired += result.leftover_expired
            leftover_failed += result.leftover_failed
            leftover_filtered += result.leftover_filtered
            raw_rewards_sum += result.raw_rewards_sum
            raw_rewards_count += result.raw_rewards_count
            produced_samples += result.produced_samples
            produced_tokens += result.produced_tokens
            produce_time_s += result.produce_time_s
            # Only tasks that actually recorded timing stats contribute to the weighted sums.
            if result.group_gen_count is not None and result.group_gen_mean_s is not None:
                total_group_count += result.group_gen_count
                weighted_group_mean_sum += result.group_gen_count * result.group_gen_mean_s
                weighted_group_p50_sum += result.group_gen_count * (result.group_gen_p50_s or 0.0)
                weighted_group_p99_sum += result.group_gen_count * (result.group_gen_p99_s or 0.0)
                weighted_group_ratio_sum += result.group_gen_count * (result.group_gen_p99_p50_ratio or 0.0)
                # NOTE(review): nesting of this line was reconstructed from a flattened
                # source; confirm whether it belongs inside this `if` or at loop level.
                total_pause_time_s += result.group_gen_pause_time_s or 0.0

        aggregated = ProduceBatchResult(
            rollout_states=rollout_states,
            leftover_init=leftover_init,
            leftover_completed=leftover_completed,
            leftover_aborted=leftover_aborted,
            leftover_expired=leftover_expired,
            leftover_failed=leftover_failed,
            leftover_filtered=leftover_filtered,
            raw_rewards_sum=raw_rewards_sum,
            raw_rewards_count=raw_rewards_count,
            produced_samples=produced_samples,
            produced_tokens=produced_tokens,
            produce_time_s=produce_time_s,
            task_results={task.task_name: task_results[task.task_name] for task in ordered_tasks},
        )
        if total_group_count > 0:
            aggregated.group_gen_count = total_group_count
            aggregated.group_gen_mean_s = weighted_group_mean_sum / total_group_count
            aggregated.group_gen_p50_s = weighted_group_p50_sum / total_group_count
            aggregated.group_gen_p99_s = weighted_group_p99_sum / total_group_count
            aggregated.group_gen_p99_p50_ratio = weighted_group_ratio_sum / total_group_count
            # NOTE(review): nesting reconstructed from a flattened source; confirm whether
            # the pause time is assigned only when timing stats exist.
            aggregated.group_gen_pause_time_s = total_pause_time_s
        return aggregated

    async def _produce_batch_to_buffer(
        self,
        task_batch_sizes: dict[str, int],
        progress: ProduceProgress,
    ) -> ProduceBatchStatus:
        """Run one production round for all active tasks, writing into the replay buffer.

        Returns ``EXPIRED_BATCH`` without producing when any task's strategy reports the
        current model as expired; otherwise runs every active task's strategy concurrently
        and returns the aggregated status. The time spent producing is added to
        ``progress`` even when production raises.
        """
        current_future_step = progress.producer_future_step
        model_step = self._model_step
        expired_tasks = [
            task.task_name
            for task in self.task_runners
            if task.produce_strategy.is_model_expired(current_future_step, model_step)
        ]
        if expired_tasks:
            self.logger.info(
                f"[AgentLoopManager][{self.name}] EXPIRED_BATCH: "
                f"future_step={current_future_step}, tasks={expired_tasks}"
            )
            return ProduceBatchStatus.EXPIRED_BATCH

        # A task is active when its target sample count in `progress` is positive.
        active_tasks = [task for task in self.task_runners if progress.target_samples[task.task_name] > 0]
        assert active_tasks, "No active tasks found"

        produce_start = time.perf_counter()
        try:
            statuses = await asyncio.gather(
                *[
                    task.produce_strategy.produce_batch(
                        _build_produce_context(
                            task,
                            self.replay_buffer,
                            task_batch_sizes[task.task_name],
                            current_future_step,
                            model_step,
                            self._update_event,
                            progress,
                        )
                    )
                    for task in active_tasks
                ]
            )
        finally:
            progress.add_produce_time(time.perf_counter() - produce_start)
        return _aggregate_status(statuses)

    async def pause_produce(
        self,
        *,
        use_global_progress: bool,
        progress: ProduceProgress | None = None,
    ) -> float:
        """Pause production and return the summed pause time reported by all strategies.

        Args:
            use_global_progress: True to pause against the manager's shared progress;
                False to pause against the caller-supplied ``progress``.
            progress: Required when ``use_global_progress`` is False; must be omitted
                when it is True.

        Raises:
            ValueError: If ``progress`` does not match ``use_global_progress``.
        """
        # This is the producer's explicit "brake" interface.
        #
        # Design motivation:
        # - under the old colocate semantics, a produce_batch() call wrapped up naturally
        #   once it finished;
        # - in disaggregated mode the producer may keep running in the background, so when
        #   it stops must be explicitly controlled by the trainer.
        #
        # The caller therefore has to state explicitly whether the global progress is used:
        # - use_global_progress=True: the disaggregated background production loop pauses
        #   before the weight-sync point;
        # - use_global_progress=False: wrap-up of the current colocated synchronous
        #   produce_batch call, using the local progress.
        # The returned `pause_time_s` carries no business semantics; it is logging /
        # diagnostic information for the training side to report the next time it
        # consumes a batch.
        # In use_global_progress=False mode, production resumes via continue_produce at the
        # entry of the next produce_batch; in use_global_progress=True mode the trainer
        # resumes it explicitly after weight sync and evaluation have finished.
        if use_global_progress:
            if progress is not None:
                raise ValueError("progress must not be provided when use_global_progress=True.")
            pause_progress = self._produce_progress
        else:
            if progress is None:
                raise ValueError("progress must be provided when use_global_progress=False.")
            pause_progress = progress

        # Once the arguments are validated, raise the manager-level pause signal so that any
        # produce_batch still running stops scheduling new rollouts.
        self._update_event.set()
        self._status = AgentLoopManagerStatus.UPDATE_WEIGHT_AND_ABORT

        # The producer / strategy must see the paused state before the rollout controller is
        # paused, so that no new requests are scheduled while pausing.
        rollout_ctl = await get_agent_loop_rollout_ctl(self.task_runners[0].agent_loop)
        await rollout_ctl.pause_generation.remote()  # type: ignore[attr-defined]

        pause_time_s = 0.0
        for task in self.task_runners:
            ctx = _build_produce_context(
                task,
                self.replay_buffer,
                0,
                pause_progress.producer_future_step,
                self._model_step,
                self._update_event,
                pause_progress,
            )
            pause_time_s += await task.produce_strategy.pause_produce(
                ctx,
            )
        self._pause_time_s = pause_time_s
        return pause_time_s

    def _log_buffer_counts(
        self,
        task_batch_sizes: dict[str, int],
        batch_by_task: dict[str, list[list[RolloutState]]],
        leftover_counts: dict[str, dict[Status, int]],
    ) -> None:
        """Log, per task, the requested / retrieved group counts and leftover status counts."""
        for task in self.task_runners:
            task_name = task.task_name
            task_counts = leftover_counts.get(task_name, {})
            self.logger.info(
                f"[AgentLoopManager][{self.name}] get_batch from buffer for task {task_name}: "
                f"requested={task_batch_sizes[task_name]}, retrieved={len(batch_by_task.get(task_name, []))}, "
                f"leftover_init={task_counts.get(Status.INIT, 0)}, "
                f"leftover_completed={task_counts.get(Status.COMPLETED, 0)}, "
                f"leftover_aborted={task_counts.get(Status.ABORTED, 0)}, "
                f"leftover_expired={task_counts.get(Status.EXPIRED, 0)}, "
                f"leftover_failed={task_counts.get(Status.FAILED, 0)}, "
                f"leftover_filtered={task_counts.get(Status.FILTERED, 0)}"
            )

    def _build_result_from_batch(
        self,
        task_batch_sizes: dict[str, int],
        batch_by_task: dict[str, list[list[RolloutState]]],
        leftover_counts: dict[str, dict[Status, int]],
        *,
        progress: ProduceProgress,
        pause_time_s: float,
    ) -> ProduceBatchResult:
        """Build the ``ProduceBatchResult`` for a batch already taken from the buffer.

        Reward / produced / produce-time counters are consumed from ``progress``. With a
        single task the result is built directly; with several tasks, per-task results are
        built and aggregated in ``(task_name, order)`` order, and the aggregate carries
        ``task_batch_sizes``.
        """
        if len(self.task_runners) == 1:
            task = self.task_runners[0]
            raw_rewards_sum, raw_rewards_count = progress.consume_raw_rewards(task.task_name)
            produced_samples, produced_tokens = progress.consume_produced(task.task_name)
            produce_time_s = progress.consume_produce_time()
            result = ProduceBatchResult(
                rollout_states=batch_by_task.get(task.task_name, []),
                raw_rewards_sum=raw_rewards_sum,
                raw_rewards_count=raw_rewards_count,
                produced_samples=produced_samples,
                produced_tokens=produced_tokens,
                produce_time_s=produce_time_s,
            )
            _fill_leftover_counts(result, leftover_counts.get(task.task_name, {}))
            _fill_group_timing_stats(result, result.rollout_states, pause_time_s=pause_time_s)
            return result

        task_results: dict[str, ProduceBatchResult] = {}
        # Produce time is consumed once for the whole call and set on the aggregate below;
        # per-task results keep the default produce_time_s.
        produce_time_s = progress.consume_produce_time()
        for task in self.task_runners:
            raw_rewards_sum, raw_rewards_count = progress.consume_raw_rewards(task.task_name)
            produced_samples, produced_tokens = progress.consume_produced(task.task_name)
            result = ProduceBatchResult(
                rollout_states=batch_by_task.get(task.task_name, []),
                raw_rewards_sum=raw_rewards_sum,
                raw_rewards_count=raw_rewards_count,
                produced_samples=produced_samples,
                produced_tokens=produced_tokens,
            )
            _fill_leftover_counts(result, leftover_counts.get(task.task_name, {}))
            task_results[task.task_name] = result

        ordered_tasks = sorted(self.task_runners, key=lambda task: (task.task_name, task.order))
        aggregated = self._aggregate_task_results(ordered_tasks, task_results)
        aggregated.produce_time_s = produce_time_s
        aggregated.task_batch_sizes = {task.task_name: task_batch_sizes[task.task_name] for task in ordered_tasks}
        # Timing stats are computed over the merged rollout groups of all tasks.
        _fill_group_timing_stats(aggregated, aggregated.rollout_states, pause_time_s=pause_time_s)
        return aggregated

    async def _get_batch_from_buffer(
        self,
        *,
        batch_size: int,
        task_batch_sizes: dict[str, int],
        consume_progress: ProduceProgress,
    ) -> ProduceBatchResult:
        """Take a batch from the replay buffer and wrap it in a ``ProduceBatchResult``.

        Reads and resets the stored pause time, validates the sizes, takes the batch,
        records the consumed counts on ``consume_progress`` and logs leftover counts.
        """
        # Read-and-reset so a pause time is reported at most once.
        pause_time_s = self._pause_time_s
        self._pause_time_s = 0.0
        self._validate_task_batch_sizes(task_batch_sizes, batch_size)
        batch_by_task, consumed_counts = await self.replay_buffer.take_batch(task_batch_sizes)
        consume_progress.mark_consumed(consumed_counts)
        leftover_counts = await self.replay_buffer.count_statuses(self.task_names, _LEFTOVER_STATUSES)
        self._log_buffer_counts(task_batch_sizes, batch_by_task, leftover_counts)
        return self._build_result_from_batch(
            task_batch_sizes,
            batch_by_task,
            leftover_counts,
            progress=consume_progress,
            pause_time_s=pause_time_s,
        )

    async def continue_produce(self, model_step: int) -> None:
        """Record ``model_step``, resume the rollout controller and return to NORMAL."""
        # This is the counterpart of pause_produce(use_global_progress=True):
        # - pause_produce(...) makes the producer stop;
        # - continue_produce(...) lifts the pause once sync / evaluation has finished.
        #
        # `_model_step` is updated here so that samples generated by the rollout side from
        # now on are recorded as using this version of the weights.
        self._model_step = model_step
        rollout_ctl = await get_agent_loop_rollout_ctl(self.task_runners[0].agent_loop)
        await rollout_ctl.continue_generation.remote()  # type: ignore[attr-defined]
        # Only after the rollout controller has actually resumed is the manager exposed as
        # NORMAL, so that produce_loop can continue producing.
        self._status = AgentLoopManagerStatus.NORMAL
        self._update_event.clear()

    def shutdown(self) -> None:
        """Mark the manager as finished and set both the update and finish events."""
        # Public entry point for the background producer's exit signal, so that the trainer
        # does not write the manager's private state directly.
        self._status = AgentLoopManagerStatus.FINISH
        self._update_event.set()
        self._finish_event.set()

    async def _wait_for_status_exit(self, blocked_status: AgentLoopManagerStatus) -> None:
        """Poll until the status leaves ``blocked_status`` or the finish event is set."""
        while not self._finish_event.is_set() and self._status == blocked_status:
            await asyncio.sleep(self._STATUS_POLL_INTERVAL_S)

    async def produce_batch(
        self,
        batch_size: int,
        train_step: int,
        *,
        model_step: int,
    ) -> ProduceBatchResult:
        """Produce and return one non-empty training batch (colocated path).

        Raises:
            ValueError: If ``batch_size`` is not positive.
            AssertionError: If no task has a positive batch size, or the fetched batch
                is empty.
        """
        # `produce_batch()` is the synchronous entry point kept for the colocate path.
        #
        # Its name is unchanged, but internally it now has three stages:
        # 1. `_produce_batch_to_buffer()` only produces and writes results to the replay buffer
        # 2. `pause_produce()` explicitly wraps up pending rollouts
        # 3. `_get_batch_from_buffer()` then takes the training batch out
        #
        # This is also why a non-empty batch is required here:
        # - under colocate semantics, the call exists to obtain a batch of trainable
        #   completed groups
        # - if an empty batch plus a special status must be returned legitimately, the
        #   disagg `get_batch()` should be used instead
        if batch_size <= 0:
            raise ValueError(f"produce_batch expects batch_size > 0, got {batch_size}")

        start = time.perf_counter()
        self.logger.info(
            f"[AgentLoopManager][{self.name}] Start produce_batch: train_step={train_step} model_step={model_step} batch_size={batch_size}"
        )
        current_sizes = self._get_task_batch_sizes_for_step(batch_size, train_step)
        active_tasks = [task for task in self.task_runners if current_sizes[task.task_name] > 0]
        assert active_tasks, "No active tasks found"

        # The colocated path does not reuse the disaggregated paused-producer state machine.
        # Even if the manager was restored by resume() and is still in
        # UPDATE_WEIGHT_AND_ABORT, produce_batch() is treated as an independent synchronous
        # production run that starts from a clean state.
        #
        # On the colocated path, produce_batch() corresponds to the weight version currently
        # held by the rollout worker.
        await self.continue_produce(model_step=model_step)
        local_progress = ProduceProgress.build_local(self.task_names, current_sizes, train_step)
        status = ProduceBatchStatus.NORMAL
        try:
            # The colocated produce_batch is also a consumption entry point; refresh the
            # completed / aborted entries already in the buffer before producing.
            await self._refresh_for_all_tasks(train_step, [Status.COMPLETED, Status.ABORTED])
            status = await self._produce_batch_to_buffer(
                task_batch_sizes=current_sizes,
                progress=local_progress,
            )
        finally:
            # Always wrap up pending rollouts, even if production raised.
            await self.pause_produce(
                use_global_progress=False,
                progress=local_progress,
            )

        result = await self._get_batch_from_buffer(
            batch_size=batch_size,
            task_batch_sizes=current_sizes,
            consume_progress=local_progress,
        )
        result.status = status
        assert result.rollout_states, (
            "AgentLoopManager.produce_batch() must return non-empty rollout_states for colocated training. "
            "Use get_batch() for disaggregated empty/expired reads."
        )
        self.logger.info(
            f"[AgentLoopManager][{self.name}] produce_batch done "
            f"elapsed={time.perf_counter() - start:.3f}, completed_groups={len(result.rollout_states)}"
        )
        return result

    async def produce_loop(self, batch_size: int) -> None:
        """Background production loop (disaggregated path); runs until finished."""
        # `produce_loop()` is the background production loop added for the disaggregated path.
        # batch_size is the target production size for each future train_step; the producer
        # needs it to advance the cumulative target, so the parameter stays on the background
        # production entry point instead of being inferred from get_batch() requests.
        #
        # The biggest differences from colocate are:
        # - it does not return a batch to the trainer directly
        # - it only keeps "feeding" samples into the replay buffer
        # - the trainer consumes asynchronously in the foreground through `get_batch()`
        #
        # So the core responsibility here is not to "assemble a training batch" but to decide,
        # from the manager's global state machine, when to keep producing, when to pause and
        # wait, and when to exit for good.
        while not self._finish_event.is_set():
            if self._status == AgentLoopManagerStatus.FINISH:
                break

            if self._status in (AgentLoopManagerStatus.UPDATE_WEIGHT_AND_ABORT, AgentLoopManagerStatus.EXPIRED_BATCH):
                # Both a proactive pre-sync pause and a model expiry can only be resumed by
                # the trainer calling continue_produce().
                await self._wait_for_status_exit(self._status)
                continue

            task_batch_sizes = self._produce_progress.ensure_target_upto(
                batch_size=batch_size,
                future_step=self._produce_progress.producer_future_step,
                allocate_batch_sizes=self._get_task_batch_sizes_for_step,
            )
            produce_status = await self._produce_batch_to_buffer(
                task_batch_sizes=task_batch_sizes,
                progress=self._produce_progress,
            )

            if produce_status == ProduceBatchStatus.EXPIRED_BATCH:
                # Note:
                # - EXPIRED_BATCH is a "stop immediately" signal that the producer detects
                #   by itself during production
                # - UPDATE_WEIGHT_AND_ABORT is set proactively by the trainer through
                #   pause_produce() before syncing
                self._status = AgentLoopManagerStatus.EXPIRED_BATCH
            elif produce_status == ProduceBatchStatus.NORMAL:
                # The producer's own train_step only advances after a round of production
                # completes normally.
                self._produce_progress.advance_future_step()

            # Explicitly yield to the event loop so that fake strategies / very fast paths in
            # tests do not busy-spin.
            # NOTE(review): placement at loop level was reconstructed from a flattened
            # source; confirm it is not nested under the `elif` above.
            await asyncio.sleep(0)

    async def get_batch(self, batch_size: int, train_step: int) -> ProduceBatchResult:
        """Wait for and return the batch for ``train_step`` (disaggregated path).

        Returns an empty result with status ``EXPIRED_BATCH`` when the manager is expired
        and the training side holds a newer model step, and an empty result with the
        default status when the finish event is set while waiting.

        Raises:
            RuntimeError: If the manager is expired, no newer model step exists and the
                replay buffer has no ready batch.
        """
        # `get_batch()` is the consumption interface the disaggregated path offers the trainer.
        #
        # By design it is clearly separated from `produce_batch()`:
        # - `produce_batch()`: colocate; one call does "produce + wrap up + fetch"
        # - `get_batch()`: disagg; waits until the replay buffer has the batch needed for the
        #   current train step, then fetches it
        #
        # The only legitimate scenario for returning an empty batch here is therefore:
        # - the manager has entered EXPIRED_BATCH
        # - the training side already has a Model Step newer than the rollout side and can
        #   sync to it directly
        # Without a newer model version, either the currently prepared batch is consumed or
        # the call fails fast to expose the broken invariant.
        progress = self._produce_progress
        progress.begin_consume(train_step)
        await self._refresh_for_all_tasks(train_step, [Status.COMPLETED, Status.ABORTED])
        task_batch_sizes = self._get_task_batch_sizes_for_step(batch_size, train_step)
        current_model_step = train_step - 1

        while not self._finish_event.is_set():
            if self._status == AgentLoopManagerStatus.EXPIRED_BATCH:
                # An empty expired result may skip training and sync directly only when the
                # training side already has a newer Model Step.
                if current_model_step > self._model_step:
                    pause_time_s = self._pause_time_s
                    self._pause_time_s = 0.0
                    result = ProduceBatchResult(
                        rollout_states=[],
                        status=ProduceBatchStatus.EXPIRED_BATCH,
                    )
                    if pause_time_s > 0:
                        result.group_gen_pause_time_s = pause_time_s
                    return result

                # With no newer model and no ready batch, the producer has stopped and cannot
                # be recovered by syncing, so the broken invariant must be exposed immediately.
                if not await self.replay_buffer.is_ready(task_batch_sizes):
                    leftover_counts = await self.replay_buffer.count_statuses(self.task_names, _LEFTOVER_STATUSES)
                    raise RuntimeError(
                        "AgentLoopManager reached EXPIRED_BATCH without a newer model or a ready batch: "
                        f"train_step={train_step}, current_model_step={current_model_step}, "
                        f"rollout_model_step={self._model_step}, manager_status={self._status.name}, "
                        f"producer_future_step={progress.producer_future_step}, "
                        f"next_consumer_step={progress.next_consumer_step}, "
                        f"target_upto_future_step={progress.target_upto_future_step}, "
                        f"target_samples={progress.target_samples}, "
                        f"consumed_samples={progress.consumed_samples}, "
                        f"task_batch_sizes={task_batch_sizes}, "
                        f"leftover_status_counts={leftover_counts}"
                    )

            if await self.replay_buffer.is_ready(task_batch_sizes):
                result = await self._get_batch_from_buffer(
                    batch_size=batch_size,
                    task_batch_sizes=task_batch_sizes,
                    consume_progress=progress,
                )
                if self._status == AgentLoopManagerStatus.EXPIRED_BATCH:
                    # Expired but carrying data means the trainer still has to finish this
                    # step, then resume the producer with the new Model Step.
                    result.status = ProduceBatchStatus.EXPIRED_BATCH
                if result.rollout_states:
                    progress.finish_consume(train_step)
                    # NOTE(review): nesting of this refresh under `if result.rollout_states`
                    # was reconstructed from a flattened source; confirm against upstream.
                    await self._refresh_for_all_tasks(train_step + 1, [Status.COMPLETED, Status.ABORTED])
                return result

            await asyncio.sleep(self._STATUS_POLL_INTERVAL_S)

        # The finish event was set while waiting: return an empty result.
        return ProduceBatchResult(rollout_states=[])

    def _task_checkpoint_path(self, checkpoint_path: Path | str, task_name: str) -> Path:
        """Return ``<checkpoint_path>/tasks/<task_name>``."""
        checkpoint_path = Path(checkpoint_path)
        return checkpoint_path / self._TASK_CHECKPOINT_DIR / task_name

    def _manager_state_path(self, checkpoint_path: Path | str) -> Path:
        """Return the path of the manager's JSON state file under ``checkpoint_path``."""
        checkpoint_path = Path(checkpoint_path)
        return checkpoint_path / self._MANAGER_STATE_PATH

    def _get_pending_task_counts(self) -> dict[str, int]:
        """Return the pending rollout task count per task, omitting tasks with none."""
        pending_task_counts: dict[str, int] = {}
        for task in self.task_runners:
            pending_count = task.produce_strategy.pending_task_count()
            if pending_count > 0:
                pending_task_counts[task.task_name] = pending_count
        return pending_task_counts

    async def save(self, checkpoint_path: Path | str, model_step: int) -> None:
        """Save all task sampler states and the shared replay buffer.

        Also writes the manager status, ``model_step`` and produce progress to a JSON
        state file. ``self._model_step`` is overwritten with ``model_step``.

        Raises:
            RuntimeError: If any task still has pending rollout tasks.
        """
        checkpoint_path = Path(checkpoint_path)
        checkpoint_path.mkdir(parents=True, exist_ok=True)

        pending_task_counts = self._get_pending_task_counts()
        if pending_task_counts:
            raise RuntimeError(
                "Cannot save AgentLoopManager while pending rollout tasks still exist: "
                f"{pending_task_counts}. Call pause_produce() first."
            )

        # Explicitly record the model step this checkpoint corresponds to before saving;
        # resume restores exactly this state.
        self._model_step = model_step

        for task in self.task_runners:
            task_checkpoint_path = self._task_checkpoint_path(checkpoint_path, task.task_name)
            task_checkpoint_path.mkdir(parents=True, exist_ok=True)
            task.sampler.save(task_checkpoint_path)

        # The manager layer keeps async semantics; a synchronous entry point may only wrap
        # this with asyncio_run at the trainer boundary.
        await self.replay_buffer.save(checkpoint_path)

        manager_state_path = self._manager_state_path(checkpoint_path)
        progress_state = self._produce_progress.state_dict()
        with manager_state_path.open("w") as f:
            json.dump(
                {
                    "status": self._status.name,
                    "model_step": self._model_step,
                    **progress_state,
                },
                f,
            )

    async def resume(self, checkpoint_path: Path | str) -> int:
        """Resume all task sampler states and the shared replay buffer.

        The manager comes back paused: both events are recreated, the update event is set
        and the status is ``UPDATE_WEIGHT_AND_ABORT``. The ``status`` value stored in the
        state file is not read back here.

        Returns:
            int: The ``model_step`` stored in the checkpoint.
        """
        checkpoint_path = Path(checkpoint_path)
        for task in self.task_runners:
            task.sampler.resume(self._task_checkpoint_path(checkpoint_path, task.task_name))

        # Restoring the replay buffer is async I/O; asyncio_run must not be nested again
        # inside an already running event loop.
        await self.replay_buffer.resume(checkpoint_path)

        manager_state_path = self._manager_state_path(checkpoint_path)
        with manager_state_path.open("r") as f:
            manager_state = json.load(f)
        saved_model_step = manager_state["model_step"]
        self._produce_progress.load_state_dict(manager_state)

        self._update_event = asyncio.Event()
        self._finish_event = asyncio.Event()
        self._update_event.set()
        self._status = AgentLoopManagerStatus.UPDATE_WEIGHT_AND_ABORT
        self._pause_time_s = 0.0
        self._model_step = saved_model_step
        return saved_model_step