import asyncio
import json
import math
import time
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from pydantic import BaseModel, ConfigDict, Field
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from xtuner.v1.data_proto.rl_data import RolloutState, Status
from xtuner.v1.rl.agent_loop import AgentLoopConfig, AgentLoopSpec, get_agent_loop_rollout_ctl
from xtuner.v1.rl.judger import ComposedJudgerConfig, JudgerConfig, build_judger
from xtuner.v1.rl.replay_buffer import ReplayBuffer
from xtuner.v1.rl.rollout import RolloutController
from xtuner.v1.utils import get_logger
from .producer import (
GROUP_GENERATE_TIME_KEY,
ProduceBatchStatus,
ProduceContext,
ProduceProgress,
ProduceStrategy,
ProduceStrategyConfig,
SyncProduceStrategyConfig,
default_is_valid_sample_fn,
)
from .sampler import Sampler, SamplerConfig
@dataclass
class ProduceBatchResult:
    """Result of a single ``produce_batch`` / ``get_batch`` call.

    Only ``rollout_states`` is required; every other field has a default so
    that callers can fill in statistics incrementally after construction.

    Args:
        rollout_states (list[list[RolloutState]]): Completed rollout groups retrieved from the replay buffer for training.
        status (ProduceBatchStatus): Outcome of the call that built this result. Defaults to ``ProduceBatchStatus.NORMAL``.
        group_gen_count (int | None): Number of generate-group calls finished in this batch (None if no generations ran).
        group_gen_mean_s (float | None): Mean wall-clock time per generate-group call, in seconds.
        group_gen_p50_s (float | None): Median (p50) generate-group time, in seconds.
        group_gen_p99_s (float | None): 99th percentile generate-group time, in seconds.
        group_gen_p99_p50_ratio (float | None): Ratio of p99 to p50, indicating tail-latency skew.
        group_gen_pause_time_s (float | None): Time spent in pause/cleanup phase (async strategy only), in seconds.
        leftover_init (int): Number of init groups remaining in the replay buffer after this batch.
        leftover_completed (int): Number of completed groups remaining in the replay buffer after this batch.
        leftover_aborted (int): Number of aborted groups remaining in the replay buffer.
        leftover_expired (int): Number of expired groups remaining in the replay buffer.
        leftover_failed (int): Number of failed groups remaining in the replay buffer.
        leftover_filtered (int): Number of filtered groups remaining in the replay buffer.
        raw_rewards_sum (float): Sum of rewards produced before replay-buffer insertion for the current window.
        raw_rewards_count (int): Number of reward-bearing samples included in ``raw_rewards_sum``.
        produced_samples (int): Number of rollout samples produced in the current produce window.
        produced_tokens (int): Number of response tokens produced in the current produce window.
        produce_time_s (float): Wall-clock production time consumed by the current produce window.
        task_batch_sizes (dict[str, int] | None): Requested batch size per task name; populated on the
            aggregated result of a multi-task manager, otherwise None.
        task_results (dict[str, ProduceBatchResult] | None): Per-task results keyed by task name; populated
            on an aggregated multi-task result, otherwise None.
    """

    rollout_states: list[list[RolloutState]]
    status: ProduceBatchStatus = ProduceBatchStatus.NORMAL
    # per-group generation timing stats (all None if no generations occurred)
    group_gen_count: int | None = None
    group_gen_mean_s: float | None = None
    group_gen_p50_s: float | None = None
    group_gen_p99_s: float | None = None
    group_gen_p99_p50_ratio: float | None = None
    group_gen_pause_time_s: float | None = None
    # leftover samples remaining in replay buffer after batch retrieval
    leftover_init: int = 0
    leftover_completed: int = 0
    leftover_aborted: int = 0
    leftover_expired: int = 0
    leftover_failed: int = 0
    leftover_filtered: int = 0
    # rewards produced during the current produce window, including completed and filtered groups.
    raw_rewards_sum: float = 0.0
    raw_rewards_count: int = 0
    produced_samples: int = 0
    produced_tokens: int = 0
    produce_time_s: float = 0.0
    # multi-task bookkeeping; both stay None for a single-task result
    task_batch_sizes: dict[str, int] | None = None
    task_results: dict[str, "ProduceBatchResult"] | None = None
@dataclass(frozen=True)
class _TaskRunner:
    """Immutable bundle of the runtime objects that belong to one task.

    Instances are created by ``AgentLoopManagerConfig.build`` and consumed by
    ``AgentLoopManager``.
    """

    # Unique task identifier; used as the key for per-task batch sizes and results.
    task_name: str
    # Agent loop that generates rollouts for this task.
    agent_loop: AgentLoopSpec
    # Strategy object whose ``produce_batch`` / ``pause_produce`` drive production.
    produce_strategy: ProduceStrategy
    # Dataset sampler feeding prompts for this task.
    sampler: Sampler
    # Relative share of the global batch assigned to this task.
    weight: float = 1.0
    # Position of the task in the config list; breaks ties during batch allocation.
    order: int = 0
class _TaskSamplerView:
def __init__(self, samplers: list[Sampler]):
self._samplers = samplers
def __len__(self) -> int:
return sum(len(sampler) for sampler in self._samplers)
class AgentLoopManagerStatus(Enum):
    """Global status of the AgentLoopManager.

    The status moves along the following paths:

    - The initial state is NORMAL
    - NORMAL -> UPDATE_WEIGHT_AND_ABORT
        - triggered before the trainer starts a weight sync
    - UPDATE_WEIGHT_AND_ABORT -> NORMAL
        - ``continue_produce()`` is called once the weight sync has finished
    - NORMAL -> EXPIRED_BATCH
        - the current rollout model has become too stale
    - EXPIRED_BATCH -> UPDATE_WEIGHT_AND_ABORT
        - the trainer detects the expiry and enters the weight-sync phase
    - any state -> FINISH
        - training is over

    One important distinction:

    - AgentLoopManagerStatus is the *global running state of the background producer*
    - ProduceBatchStatus is the *local result of a single scheduling call*
    """

    NORMAL = auto()
    UPDATE_WEIGHT_AND_ABORT = auto()
    EXPIRED_BATCH = auto()
    FINISH = auto()
def _fill_produce_timing_stats(
result: ProduceBatchResult, generate_times_s: list[float], pause_time_s: float = 0.0
) -> None:
if not generate_times_s:
if pause_time_s > 0:
result.group_gen_pause_time_s = pause_time_s
return
sorted_times = sorted(generate_times_s)
n = len(sorted_times)
mean_s = sum(sorted_times) / n
p50_s = sorted_times[n // 2]
p99_s = sorted_times[int(n * 0.99)]
ratio = p99_s / p50_s if p50_s > 0 else float("inf")
result.group_gen_count = n
result.group_gen_mean_s = mean_s
result.group_gen_p50_s = p50_s
result.group_gen_p99_s = p99_s
result.group_gen_p99_p50_ratio = ratio
result.group_gen_pause_time_s = pause_time_s
def _fill_group_timing_stats(
    result: ProduceBatchResult, rollout_states: list[list[RolloutState]], pause_time_s: float = 0.0
) -> None:
    """Collect per-group generation times from rollout groups and fill ``result``.

    The time of a group is read from the ``extra_fields`` mapping of its first
    rollout state under ``GROUP_GENERATE_TIME_KEY``. Empty groups and groups
    without a recorded time are ignored.

    Args:
        result (ProduceBatchResult): Result object updated in place.
        rollout_states (list[list[RolloutState]]): Rollout groups to inspect.
        pause_time_s (float): Pause duration forwarded to the stats helper. Defaults to 0.0.
    """
    group_heads = (group[0] for group in rollout_states if group)
    recorded_times = (
        getattr(head, "extra_fields", {}).get(GROUP_GENERATE_TIME_KEY) for head in group_heads
    )
    generate_times: list[float] = [elapsed for elapsed in recorded_times if elapsed is not None]
    _fill_produce_timing_stats(result, generate_times, pause_time_s=pause_time_s)
def _aggregate_status(statuses: list[ProduceBatchStatus]) -> ProduceBatchStatus:
    """Collapse per-task statuses into one status.

    ``UPDATE_WEIGHT_AND_ABORT`` wins over ``EXPIRED_BATCH``, which wins over
    ``NORMAL``. An empty list yields ``NORMAL``.
    """
    # Listed from highest to lowest precedence.
    precedence = (
        ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT,
        ProduceBatchStatus.EXPIRED_BATCH,
    )
    for dominant in precedence:
        for status in statuses:
            if status == dominant:
                return dominant
    return ProduceBatchStatus.NORMAL
# Statuses passed to ``ReplayBuffer.count_statuses`` when reporting what is
# still in the replay buffer; one entry per ``leftover_*`` field of
# ``ProduceBatchResult``.
_LEFTOVER_STATUSES = [
    Status.INIT,
    Status.COMPLETED,
    Status.ABORTED,
    Status.EXPIRED,
    Status.FAILED,
    Status.FILTERED,
]
def _fill_leftover_counts(result: ProduceBatchResult, status_counts: dict[Status, int]) -> None:
    """Copy per-status counts into the ``leftover_*`` fields of ``result``.

    Statuses absent from ``status_counts`` are recorded as 0.
    """
    field_for_status = (
        ("leftover_init", Status.INIT),
        ("leftover_completed", Status.COMPLETED),
        ("leftover_aborted", Status.ABORTED),
        ("leftover_expired", Status.EXPIRED),
        ("leftover_failed", Status.FAILED),
        ("leftover_filtered", Status.FILTERED),
    )
    for field_name, status in field_for_status:
        setattr(result, field_name, status_counts.get(status, 0))
def _build_produce_context(
    task_runner: _TaskRunner,
    replay_buffer: ReplayBuffer,
    batch_size: int,
    train_step: int,
    model_step: int,
    update_event: asyncio.Event,
    progress: ProduceProgress,
) -> ProduceContext:
    """Assemble the ``ProduceContext`` handed to a task's produce strategy.

    ``is_valid_sample_fn`` and ``stale_threshold`` are taken from the task's
    produce strategy when it defines them; otherwise
    ``default_is_valid_sample_fn`` and ``None`` are used.
    """
    strategy = task_runner.produce_strategy
    # Both attributes are optional on a strategy, hence the getattr fallbacks.
    sample_validator = getattr(strategy, "is_valid_sample_fn", default_is_valid_sample_fn)
    staleness_limit = getattr(strategy, "stale_threshold", None)
    return ProduceContext(
        task_name=task_runner.task_name,
        agent_loop=task_runner.agent_loop,
        sampler=task_runner.sampler,
        replay_buffer=replay_buffer,
        task_batch_size=batch_size,
        train_step=train_step,
        model_step=model_step,
        update_event=update_event,
        progress=progress,
        is_valid_sample_fn=sample_validator,
        stale_threshold=staleness_limit,
    )
class TaskSpecConfig(BaseModel):
    """Configuration for one task managed by ``AgentLoopManager``.

    A task spec binds together the dataset sampler, agent loop, optional judger,
    production strategy, and sampling weight for one RL data source. Multi-task
    training is represented as a list of ``TaskSpecConfig`` objects.

    Args:
        task_name (str): Unique task name used for logging, replay-buffer
            routing, and checkpoint state.
        weight (float): Relative batch allocation weight for this task in
            multi-task training. Must be non-negative. Defaults to 1.0.
        agent_loop_config (AgentLoopConfig): Agent loop configuration used to
            generate rollout samples for this task.
        judger_config (JudgerConfig | ComposedJudgerConfig | None): Optional
            judger configuration used to score generated samples. Defaults to
            None.
        produce_strategy_config (ProduceStrategyConfig): Strategy used to
            produce rollout samples. Defaults to ``SyncProduceStrategyConfig``.
        sampler_config (SamplerConfig): Dataset sampler configuration for this
            task.

    **Examples:**

    Example configuration for one task::

        task = TaskSpecConfig(
            task_name="gsm8k",
            weight=1.0,
            agent_loop_config=SingleTurnAgentLoopConfig(
                hf_checkpoint="Qwen/Qwen3-8B",
                sample_params=SampleParams(max_tokens=1024),
            ),
            judger_config=GSM8KJudgerConfig(),
            sampler_config=SamplerConfig(dataloader_cfg=dataloader_cfg, prompt_repeat_k=8),
        )
    """

    # Unknown keys are rejected; non-pydantic field types are permitted.
    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    task_name: str
    weight: float = Field(default=1.0, ge=0.0)
    agent_loop_config: AgentLoopConfig
    judger_config: JudgerConfig | ComposedJudgerConfig | None = None
    produce_strategy_config: ProduceStrategyConfig = SyncProduceStrategyConfig()
    sampler_config: SamplerConfig
class AgentLoopManagerConfig(BaseModel):
    """Configuration for the agent loop manager.

    ``AgentLoopManagerConfig`` defines the rollout-producing side of RL
    training. It may manage a single task or a weighted list of tasks, and each
    task owns its sampler, agent loop, judger, and production strategy.

    Args:
        tasks (list[TaskSpecConfig] | TaskSpecConfig): One task config or a
            list of task configs. Task names must be unique when a list is
            provided.

    **Examples:**

    Example configuration for a single-task manager::

        config = AgentLoopManagerConfig(
            tasks=TaskSpecConfig(
                task_name="gsm8k",
                agent_loop_config=SingleTurnAgentLoopConfig(
                    hf_checkpoint="Qwen/Qwen3-8B",
                    sample_params=SampleParams(max_tokens=1024),
                ),
                judger_config=GSM8KJudgerConfig(),
                sampler_config=SamplerConfig(dataloader_cfg=dataloader_cfg, prompt_repeat_k=8),
            )
        )
    """

    # Unknown keys are rejected; non-pydantic field types are permitted.
    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    tasks: list[TaskSpecConfig] | TaskSpecConfig

    def build(
        self,
        rollout_controller: RolloutController,
        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
        replay_buffer: ReplayBuffer,
        logger=None,
        sync_weights_interval: int = 1,
    ) -> "AgentLoopManager":
        """Build an ``AgentLoopManager`` with one task runner per configured task.

        Args:
            rollout_controller (RolloutController): Passed to every agent loop
                and produce strategy that is built.
            tokenizer (PreTrainedTokenizer | PreTrainedTokenizerFast): Passed to
                every sampler that is built.
            replay_buffer (ReplayBuffer): Shared by all samplers and the manager.
            logger: Optional logger forwarded to the agent loops and the manager.
            sync_weights_interval (int): Forwarded to each produce strategy
                config's ``build``. Defaults to 1.

        Returns:
            AgentLoopManager: Manager owning the built task runners.

        Raises:
            ValueError: If no task is configured or two tasks share a name.
        """
        # Normalize the single-task form to a list so both forms share one code path.
        tasks = self.tasks if isinstance(self.tasks, list) else [self.tasks]
        if not tasks:
            raise ValueError("AgentLoopManagerConfig requires at least one task config.")

        seen_task_names: set[str] = set()
        task_runners: list[_TaskRunner] = []
        for order, task_cfg in enumerate(tasks):
            if task_cfg.task_name in seen_task_names:
                raise ValueError(f"Duplicate task_name found in AgentLoopManagerConfig: {task_cfg.task_name}")
            seen_task_names.add(task_cfg.task_name)

            agent_loop = task_cfg.agent_loop_config.build(
                rollout_controller=rollout_controller,
                judger=build_judger(task_cfg.judger_config) if task_cfg.judger_config is not None else None,
                logger=logger,
            )
            produce_strategy = task_cfg.produce_strategy_config.build(
                sync_weights_interval=sync_weights_interval,
                rollout_controller=rollout_controller,
            )
            sampler = task_cfg.sampler_config.build(tokenizer=tokenizer, replay_buffer=replay_buffer)
            task_runners.append(
                _TaskRunner(
                    task_name=task_cfg.task_name,
                    agent_loop=agent_loop,
                    produce_strategy=produce_strategy,
                    sampler=sampler,
                    weight=task_cfg.weight,
                    # The config position is kept as the allocation tie-breaker.
                    order=order,
                )
            )

        return AgentLoopManager(
            task_runners=task_runners,
            replay_buffer=replay_buffer,
            logger=logger,
        )
class AgentLoopManager:
    """Coordinate rollout production and batch retrieval for one or more tasks.

    All task runners share a single replay buffer. Two usage modes are offered:

    - colocated: ``produce_batch()`` produces, wraps up pending rollouts, and
      returns a training batch within a single call;
    - disaggregated: ``produce_loop()`` keeps feeding the replay buffer in the
      background while the trainer consumes through ``get_batch()`` and steers
      the producer with ``pause_produce()`` / ``continue_produce()``.

    ``save()`` / ``resume()`` persist sampler states, the replay buffer, and
    the manager's own progress state.
    """

    # Sub-directory of a checkpoint that holds one folder per task sampler.
    _TASK_CHECKPOINT_DIR = "tasks"
    # File name of the manager's JSON state inside a checkpoint.
    _MANAGER_STATE_PATH = "agent_loop_manager_state.json"
    # Seconds slept between two status / readiness polls.
    _STATUS_POLL_INTERVAL_S = 1.0

    def __init__(
        self,
        task_runners: list[_TaskRunner],
        replay_buffer: ReplayBuffer,
        logger=None,
    ):
        """Initialize the manager.

        Args:
            task_runners (list[_TaskRunner]): Non-empty list of task runners.
            replay_buffer (ReplayBuffer): Buffer shared by all tasks.
            logger: Optional logger; ``get_logger()`` is used when None.

        Raises:
            ValueError: If ``task_runners`` is empty or the task weights do
                not sum to a positive value.
        """
        if not task_runners:
            raise ValueError("AgentLoopManager requires at least one task runner.")
        if sum(task.weight for task in task_runners) <= 0:
            raise ValueError("At least one task weight must be positive for AgentLoopManager.")
        self.task_runners = task_runners
        self.replay_buffer = replay_buffer
        # A single task exposes its sampler directly; several tasks expose a length-only view.
        self.data_sampler = (
            task_runners[0].sampler
            if len(task_runners) == 1
            else _TaskSamplerView([task.sampler for task in task_runners])
        )
        self.name = task_runners[0].task_name if len(task_runners) == 1 else "multi_task"
        if logger is None:
            self.logger = get_logger()
        else:
            self.logger = logger
        self.task_names = [task.task_name for task in self.task_runners]
        # Concurrency-control signal for the disaggregated path: the consumer sets it before syncing
        # weights; producer / strategy should observe the event state directly and stop issuing new
        # rollouts as soon as possible. Do not replace this event with an extra boolean snapshot.
        self._update_event = asyncio.Event()
        self._finish_event = asyncio.Event()
        # model_step read by the disaggregated producer: which train_step's synced model the rollout
        # side is currently using. Before a weight update, production must be paused and pending tasks
        # drained, so one pending lifecycle corresponds to exactly one model_step.
        self._model_step = 0
        # Control state shared by the disaggregated producer / consumer. produce_loop / get_batch should
        # read self._status directly instead of caching a local snapshot across awaits, so that sync,
        # expiry, or finish transitions are never missed.
        self._status = AgentLoopManagerStatus.NORMAL
        # Duration metric written by pause_produce, then read and reset by the next batch retrieval.
        # Only used for consumer-side logs/metrics; production correctness does not depend on it.
        self._pause_time_s = 0.0
        # Absolute cumulative progress shared by the disaggregated producer / consumer. The object
        # reference must stay stable; the consumer updates fields in place, and producer / strategy should
        # read progress.xxx directly whenever a value is needed rather than copying field values into
        # local snapshots that live across awaits.
        self._produce_progress = ProduceProgress.build(self.task_names)

    def get_task_batch_sizes(self, global_batch_size: int, train_step: int) -> dict[str, int]:
        """Return the per-task batch sizes for the current train step.

        Subclasses may override this method to implement custom dynamic batch allocation policies. Returning 0 for a
        task effectively disables that task for the current produce_batch call.

        The default policy splits ``global_batch_size`` proportionally to the task weights; ``train_step`` is not
        used by it.

        Raises:
            ValueError: If ``global_batch_size`` is negative or the task weights do not sum to a positive value.
        """
        if global_batch_size < 0:
            raise ValueError(f"global_batch_size must be non-negative, got {global_batch_size}")
        total_weight = sum(task.weight for task in self.task_runners)
        if total_weight <= 0:
            raise ValueError("Sum of task weights must be positive.")
        if global_batch_size == 0:
            return {task.task_name: 0 for task in self.task_runners}
        # Floor every proportional share first, then hand out the units lost to rounding.
        raw_allocations = [global_batch_size * task.weight / total_weight for task in self.task_runners]
        floor_allocations = [math.floor(raw) for raw in raw_allocations]
        remaining = global_batch_size - sum(floor_allocations)
        task_batch_sizes = {task.task_name: floor_allocations[idx] for idx, task in enumerate(self.task_runners)}
        if remaining <= 0:
            return task_batch_sizes
        # Tasks with the largest fractional part come first; ties go to the lower ``order``.
        ranked_tasks = sorted(
            enumerate(self.task_runners),
            key=lambda item: (
                -(raw_allocations[item[0]] - floor_allocations[item[0]]),
                item[1].order,
            ),
        )
        for idx, task in ranked_tasks[:remaining]:
            task_batch_sizes[task.task_name] += 1
        return task_batch_sizes

    def _validate_task_batch_sizes(self, task_batch_sizes: dict[str, int], global_batch_size: int) -> None:
        """Check that ``task_batch_sizes`` is a valid split of ``global_batch_size``.

        Raises:
            ValueError: If the keys differ from the known task names, any size
                is negative, or the sizes do not sum to ``global_batch_size``.
        """
        expected_task_names = {task.task_name for task in self.task_runners}
        actual_task_names = set(task_batch_sizes.keys())
        if actual_task_names != expected_task_names:
            missing_task_names = expected_task_names - actual_task_names
            extra_task_names = actual_task_names - expected_task_names
            raise ValueError(
                "Invalid task batch sizes returned by get_task_batch_sizes: "
                f"missing={sorted(missing_task_names)}, extra={sorted(extra_task_names)}"
            )
        negative_batch_sizes = {
            task_name: task_batch_size
            for task_name, task_batch_size in task_batch_sizes.items()
            if task_batch_size < 0
        }
        if negative_batch_sizes:
            raise ValueError(f"Task batch sizes must be non-negative, got {negative_batch_sizes}")
        total_batch_size = sum(task_batch_sizes.values())
        if total_batch_size != global_batch_size:
            raise ValueError(
                "Task batch sizes must sum to the requested global batch size, "
                f"got total={total_batch_size}, expected={global_batch_size}"
            )

    async def _refresh_for_all_tasks(self, train_step: int, statuses: list[Status]) -> None:
        """Refresh replay-buffer staleness for every task and log the expired counts."""
        task_stale_thresholds: dict[str, int] = {}
        for task in self.task_runners:
            # Colocated and disaggregated paths refresh staleness the same way; a strategy that
            # has no stale_threshold (the sync strategy) uses 1.
            stale_threshold = getattr(task.produce_strategy, "stale_threshold", 1)
            task_stale_thresholds[task.task_name] = stale_threshold
        expired_counts = await self.replay_buffer.refresh_staleness(
            task_stale_thresholds=task_stale_thresholds,
            current_train_step=train_step,
            statuses=statuses,
        )
        for task_name, expired_count in expired_counts.items():
            self.logger.info(
                f"[AgentLoopManager][{self.name}] Refresh staleness for task {task_name}: expired_count={expired_count}"
            )

    def _get_task_batch_sizes_for_step(self, batch_size: int, train_step: int) -> dict[str, int]:
        """Return validated per-task batch sizes for one step.

        With a single task the whole ``batch_size`` goes to that task and
        ``get_task_batch_sizes`` is not consulted.
        """
        if len(self.task_runners) == 1:
            return {self.task_runners[0].task_name: batch_size}
        task_batch_sizes = self.get_task_batch_sizes(batch_size, train_step)
        self._validate_task_batch_sizes(task_batch_sizes, batch_size)
        return task_batch_sizes

    @staticmethod
    def _aggregate_task_results(
        ordered_tasks: list[_TaskRunner], task_results: dict[str, ProduceBatchResult]
    ) -> ProduceBatchResult:
        """Merge per-task results into one result, following ``ordered_tasks``.

        Rollout groups are concatenated and counters are summed. Timing fields
        are averaged, weighted by each task's ``group_gen_count``; tasks
        without timing stats do not contribute to these averages.
        """
        rollout_states: list[list[RolloutState]] = []
        leftover_init = 0
        leftover_completed = 0
        leftover_aborted = 0
        leftover_expired = 0
        leftover_failed = 0
        leftover_filtered = 0
        total_group_count = 0
        weighted_group_mean_sum = 0.0
        weighted_group_p50_sum = 0.0
        weighted_group_p99_sum = 0.0
        weighted_group_ratio_sum = 0.0
        total_pause_time_s = 0.0
        raw_rewards_sum = 0.0
        raw_rewards_count = 0
        produced_samples = 0
        produced_tokens = 0
        produce_time_s = 0.0
        for task in ordered_tasks:
            result = task_results[task.task_name]
            rollout_states.extend(result.rollout_states)
            leftover_init += result.leftover_init
            leftover_completed += result.leftover_completed
            leftover_aborted += result.leftover_aborted
            leftover_expired += result.leftover_expired
            leftover_failed += result.leftover_failed
            leftover_filtered += result.leftover_filtered
            raw_rewards_sum += result.raw_rewards_sum
            raw_rewards_count += result.raw_rewards_count
            produced_samples += result.produced_samples
            produced_tokens += result.produced_tokens
            produce_time_s += result.produce_time_s
            if result.group_gen_count is not None and result.group_gen_mean_s is not None:
                total_group_count += result.group_gen_count
                weighted_group_mean_sum += result.group_gen_count * result.group_gen_mean_s
                weighted_group_p50_sum += result.group_gen_count * (result.group_gen_p50_s or 0.0)
                weighted_group_p99_sum += result.group_gen_count * (result.group_gen_p99_s or 0.0)
                weighted_group_ratio_sum += result.group_gen_count * (result.group_gen_p99_p50_ratio or 0.0)
            # NOTE(review): nesting reconstructed from flattened source — pause time is assumed to be
            # summed for every task, whether or not it has timing stats; confirm.
            total_pause_time_s += result.group_gen_pause_time_s or 0.0
        aggregated = ProduceBatchResult(
            rollout_states=rollout_states,
            leftover_init=leftover_init,
            leftover_completed=leftover_completed,
            leftover_aborted=leftover_aborted,
            leftover_expired=leftover_expired,
            leftover_failed=leftover_failed,
            leftover_filtered=leftover_filtered,
            raw_rewards_sum=raw_rewards_sum,
            raw_rewards_count=raw_rewards_count,
            produced_samples=produced_samples,
            produced_tokens=produced_tokens,
            produce_time_s=produce_time_s,
            task_results={task.task_name: task_results[task.task_name] for task in ordered_tasks},
        )
        if total_group_count > 0:
            aggregated.group_gen_count = total_group_count
            aggregated.group_gen_mean_s = weighted_group_mean_sum / total_group_count
            aggregated.group_gen_p50_s = weighted_group_p50_sum / total_group_count
            aggregated.group_gen_p99_s = weighted_group_p99_sum / total_group_count
            aggregated.group_gen_p99_p50_ratio = weighted_group_ratio_sum / total_group_count
        # NOTE(review): nesting reconstructed — assumed to be set even without timing stats; confirm.
        aggregated.group_gen_pause_time_s = total_pause_time_s
        return aggregated

    async def _produce_batch_to_buffer(
        self,
        task_batch_sizes: dict[str, int],
        progress: ProduceProgress,
    ) -> ProduceBatchStatus:
        """Run one production round for all active tasks, writing into the replay buffer.

        Returns ``EXPIRED_BATCH`` without producing anything when any task's
        strategy reports the model as expired; otherwise returns the
        aggregated status of the per-task ``produce_batch`` calls.
        """
        current_future_step = progress.producer_future_step
        model_step = self._model_step
        expired_tasks = [
            task.task_name
            for task in self.task_runners
            if task.produce_strategy.is_model_expired(current_future_step, model_step)
        ]
        if expired_tasks:
            self.logger.info(
                f"[AgentLoopManager][{self.name}] EXPIRED_BATCH: "
                f"future_step={current_future_step}, tasks={expired_tasks}"
            )
            return ProduceBatchStatus.EXPIRED_BATCH
        # Only tasks that have a positive production target take part in this round.
        active_tasks = [task for task in self.task_runners if progress.target_samples[task.task_name] > 0]
        assert active_tasks, "No active tasks found"
        produce_start = time.perf_counter()
        try:
            statuses = await asyncio.gather(
                *[
                    task.produce_strategy.produce_batch(
                        _build_produce_context(
                            task,
                            self.replay_buffer,
                            task_batch_sizes[task.task_name],
                            current_future_step,
                            model_step,
                            self._update_event,
                            progress,
                        )
                    )
                    for task in active_tasks
                ]
            )
        finally:
            # Account for the elapsed time even when production raised.
            progress.add_produce_time(time.perf_counter() - produce_start)
        return _aggregate_status(statuses)

    async def pause_produce(
        self,
        *,
        use_global_progress: bool,
        progress: ProduceProgress | None = None,
    ) -> float:
        """Pause production and let every strategy wrap up its pending work.

        Args:
            use_global_progress (bool): Use the manager's global progress
                (True) or the caller-supplied local ``progress`` (False).
            progress (ProduceProgress | None): Required when
                ``use_global_progress`` is False and forbidden otherwise.

        Returns:
            float: Sum of the pause durations reported by the strategies, in seconds.

        Raises:
            ValueError: If ``progress`` does not match ``use_global_progress``.
        """
        # This is the producer's explicit "brake" interface.
        #
        # Design motivation:
        # - under the old colocated semantics, a produce_batch() call wrapped up naturally when it ended;
        # - in the disaggregated setup the producer may keep running in the background, so the moment it
        #   stops has to be controlled explicitly by the trainer.
        #
        # The caller must therefore state explicitly whether the global progress is used:
        # - use_global_progress=True: the disaggregated background loop pauses before a weight-sync point;
        # - use_global_progress=False: wrap-up of the current colocated produce_batch call, using local progress.
        # The returned `pause_time_s` carries no business semantics; it is log/diagnostic information
        # that the training side reports the next time it consumes a batch.
        # The use_global_progress=False mode is resumed through continue_produce at the entry of the next
        # produce_batch; the use_global_progress=True mode is resumed explicitly by the trainer once
        # weight sync and evaluation have finished.
        if use_global_progress:
            if progress is not None:
                raise ValueError("progress must not be provided when use_global_progress=True.")
            pause_progress = self._produce_progress
        else:
            if progress is None:
                raise ValueError("progress must be provided when use_global_progress=False.")
            pause_progress = progress
        # Once the arguments are validated, raise the manager-level pause signal so that any
        # produce_batch still running stops scheduling new rollouts.
        self._update_event.set()
        self._status = AgentLoopManagerStatus.UPDATE_WEIGHT_AND_ABORT
        # The producer / strategy must see the paused state before the rollout controller is paused,
        # otherwise new requests could still be scheduled while pausing.
        rollout_ctl = await get_agent_loop_rollout_ctl(self.task_runners[0].agent_loop)
        await rollout_ctl.pause_generation.remote()  # type: ignore[attr-defined]
        pause_time_s = 0.0
        for task in self.task_runners:
            ctx = _build_produce_context(
                task,
                self.replay_buffer,
                0,
                pause_progress.producer_future_step,
                self._model_step,
                self._update_event,
                pause_progress,
            )
            pause_time_s += await task.produce_strategy.pause_produce(
                ctx,
            )
        self._pause_time_s = pause_time_s
        return pause_time_s

    def _log_buffer_counts(
        self,
        task_batch_sizes: dict[str, int],
        batch_by_task: dict[str, list[list[RolloutState]]],
        leftover_counts: dict[str, dict[Status, int]],
    ) -> None:
        """Log, per task, the requested / retrieved group counts and the leftover status counts."""
        for task in self.task_runners:
            task_name = task.task_name
            task_counts = leftover_counts.get(task_name, {})
            self.logger.info(
                f"[AgentLoopManager][{self.name}] get_batch from buffer for task {task_name}: "
                f"requested={task_batch_sizes[task_name]}, retrieved={len(batch_by_task.get(task_name, []))}, "
                f"leftover_init={task_counts.get(Status.INIT, 0)}, "
                f"leftover_completed={task_counts.get(Status.COMPLETED, 0)}, "
                f"leftover_aborted={task_counts.get(Status.ABORTED, 0)}, "
                f"leftover_expired={task_counts.get(Status.EXPIRED, 0)}, "
                f"leftover_failed={task_counts.get(Status.FAILED, 0)}, "
                f"leftover_filtered={task_counts.get(Status.FILTERED, 0)}"
            )

    def _build_result_from_batch(
        self,
        task_batch_sizes: dict[str, int],
        batch_by_task: dict[str, list[list[RolloutState]]],
        leftover_counts: dict[str, dict[Status, int]],
        *,
        progress: ProduceProgress,
        pause_time_s: float,
    ) -> ProduceBatchResult:
        """Turn a batch taken from the replay buffer into a ``ProduceBatchResult``.

        Reward, sample, token, and time counters are consumed from
        ``progress``. A single-task manager returns that task's result
        directly; a multi-task manager returns the aggregate, with per-task
        results attached in task-name order.
        """
        if len(self.task_runners) == 1:
            task = self.task_runners[0]
            raw_rewards_sum, raw_rewards_count = progress.consume_raw_rewards(task.task_name)
            produced_samples, produced_tokens = progress.consume_produced(task.task_name)
            produce_time_s = progress.consume_produce_time()
            result = ProduceBatchResult(
                rollout_states=batch_by_task.get(task.task_name, []),
                raw_rewards_sum=raw_rewards_sum,
                raw_rewards_count=raw_rewards_count,
                produced_samples=produced_samples,
                produced_tokens=produced_tokens,
                produce_time_s=produce_time_s,
            )
            _fill_leftover_counts(result, leftover_counts.get(task.task_name, {}))
            _fill_group_timing_stats(result, result.rollout_states, pause_time_s=pause_time_s)
            return result
        task_results: dict[str, ProduceBatchResult] = {}
        # The produce time is consumed once and attached to the aggregate only.
        produce_time_s = progress.consume_produce_time()
        for task in self.task_runners:
            raw_rewards_sum, raw_rewards_count = progress.consume_raw_rewards(task.task_name)
            produced_samples, produced_tokens = progress.consume_produced(task.task_name)
            result = ProduceBatchResult(
                rollout_states=batch_by_task.get(task.task_name, []),
                raw_rewards_sum=raw_rewards_sum,
                raw_rewards_count=raw_rewards_count,
                produced_samples=produced_samples,
                produced_tokens=produced_tokens,
            )
            _fill_leftover_counts(result, leftover_counts.get(task.task_name, {}))
            task_results[task.task_name] = result
        ordered_tasks = sorted(self.task_runners, key=lambda task: (task.task_name, task.order))
        aggregated = self._aggregate_task_results(ordered_tasks, task_results)
        aggregated.produce_time_s = produce_time_s
        aggregated.task_batch_sizes = {task.task_name: task_batch_sizes[task.task_name] for task in ordered_tasks}
        # Timing stats of the aggregate are computed over the merged rollout groups.
        _fill_group_timing_stats(aggregated, aggregated.rollout_states, pause_time_s=pause_time_s)
        return aggregated

    async def _get_batch_from_buffer(
        self,
        *,
        batch_size: int,
        task_batch_sizes: dict[str, int],
        consume_progress: ProduceProgress,
    ) -> ProduceBatchResult:
        """Take a batch from the replay buffer and build its result.

        The stored pause duration is read and reset, consumed counts are
        recorded on ``consume_progress``, and leftover status counts are
        queried and logged.
        """
        # Read-and-reset: the pause duration is reported with exactly one batch.
        pause_time_s = self._pause_time_s
        self._pause_time_s = 0.0
        self._validate_task_batch_sizes(task_batch_sizes, batch_size)
        batch_by_task, consumed_counts = await self.replay_buffer.take_batch(task_batch_sizes)
        consume_progress.mark_consumed(consumed_counts)
        leftover_counts = await self.replay_buffer.count_statuses(self.task_names, _LEFTOVER_STATUSES)
        self._log_buffer_counts(task_batch_sizes, batch_by_task, leftover_counts)
        return self._build_result_from_batch(
            task_batch_sizes,
            batch_by_task,
            leftover_counts,
            progress=consume_progress,
            pause_time_s=pause_time_s,
        )

    async def continue_produce(self, model_step: int) -> None:
        """Resume production and record ``model_step`` as the rollout-side model version."""
        #
        # This is the counterpart of pause_produce(use_global_progress=True):
        # - pause_produce(...) makes the producer stop;
        # - continue_produce(...) lifts the pause once sync / evaluation has finished.
        #
        # `_model_step` is updated here as well: samples generated by the rollout side from now on
        # should record this version number as "the weights currently in use".
        self._model_step = model_step
        rollout_ctl = await get_agent_loop_rollout_ctl(self.task_runners[0].agent_loop)
        await rollout_ctl.continue_generation.remote()  # type: ignore[attr-defined]
        # Only after the rollout controller has really resumed is the manager exposed as NORMAL,
        # which is what lets produce_loop continue producing.
        self._status = AgentLoopManagerStatus.NORMAL
        self._update_event.clear()

    def shutdown(self) -> None:
        """Signal the background producer to exit."""
        # Public way to send the exit signal to the background producer, so that the trainer
        # never writes the manager's private state directly.
        self._status = AgentLoopManagerStatus.FINISH
        self._update_event.set()
        self._finish_event.set()

    async def _wait_for_status_exit(self, blocked_status: AgentLoopManagerStatus) -> None:
        """Poll until the status leaves ``blocked_status`` or the finish event is set."""
        while not self._finish_event.is_set() and self._status == blocked_status:
            await asyncio.sleep(self._STATUS_POLL_INTERVAL_S)

    async def produce_batch(
        self,
        batch_size: int,
        train_step: int,
        *,
        model_step: int,
    ) -> ProduceBatchResult:
        """Produce and return one training batch (colocated entry point).

        Args:
            batch_size (int): Number of groups requested; must be positive.
            train_step (int): Current train step.
            model_step (int): Model version held by the rollout side.

        Returns:
            ProduceBatchResult: Result whose ``rollout_states`` is non-empty.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        # `produce_batch()` is the synchronous entry point kept for the colocated path.
        #
        # The name has not changed, but internally it now has three stages:
        # 1. `_produce_batch_to_buffer()` only produces and writes results into the replay buffer
        # 2. `pause_produce()` explicitly wraps up pending rollouts
        # 3. `_get_batch_from_buffer()` then takes the training batch out
        #
        # This is also why a non-empty batch is required here:
        # - under colocated semantics the call exists to obtain a trainable batch of completed groups
        # - if a legitimately empty batch plus a special status is needed, use disagg's `get_batch()`
        if batch_size <= 0:
            raise ValueError(f"produce_batch expects batch_size > 0, got {batch_size}")
        start = time.perf_counter()
        self.logger.info(
            f"[AgentLoopManager][{self.name}] Start produce_batch: train_step={train_step} model_step={model_step} batch_size={batch_size}"
        )
        current_sizes = self._get_task_batch_sizes_for_step(batch_size, train_step)
        active_tasks = [task for task in self.task_runners if current_sizes[task.task_name] > 0]
        assert active_tasks, "No active tasks found"
        # The colocated path does not reuse the disaggregated paused-producer state machine.
        # Even if the manager was restored through resume() and is still in UPDATE_WEIGHT_AND_ABORT,
        # produce_batch() is treated as an independent synchronous production pass that starts clean.
        #
        # On the colocated path, produce_batch() corresponds to the weight version the rollout
        # worker currently holds.
        await self.continue_produce(model_step=model_step)
        local_progress = ProduceProgress.build_local(self.task_names, current_sizes, train_step)
        status = ProduceBatchStatus.NORMAL
        try:
            # Colocated produce_batch is also a consumption entry point; refresh the completed /
            # aborted entries already in the buffer before producing.
            await self._refresh_for_all_tasks(train_step, [Status.COMPLETED, Status.ABORTED])
            status = await self._produce_batch_to_buffer(
                task_batch_sizes=current_sizes,
                progress=local_progress,
            )
        finally:
            # Always wrap up pending rollouts, even when production raised.
            await self.pause_produce(
                use_global_progress=False,
                progress=local_progress,
            )
        result = await self._get_batch_from_buffer(
            batch_size=batch_size,
            task_batch_sizes=current_sizes,
            consume_progress=local_progress,
        )
        result.status = status
        assert result.rollout_states, (
            "AgentLoopManager.produce_batch() must return non-empty rollout_states for colocated training. "
            "Use get_batch() for disaggregated empty/expired reads."
        )
        self.logger.info(
            f"[AgentLoopManager][{self.name}] produce_batch done "
            f"elapsed={time.perf_counter() - start:.3f}, completed_groups={len(result.rollout_states)}"
        )
        return result

    async def produce_loop(self, batch_size: int) -> None:
        """Keep producing into the replay buffer until finished (disaggregated path).

        Args:
            batch_size (int): Target production size for each future train step.
        """
        # `produce_loop()` is the background production loop added for the disaggregated path.
        # batch_size is the target production size of each future train_step; the producer needs it to
        # advance the cumulative target, so the parameter stays on this background entry point instead
        # of being inferred from get_batch() consume requests.
        #
        # The biggest differences from the colocated path:
        # - it does not return a batch to the trainer
        # - it only keeps "feeding" samples into the replay buffer
        # - the trainer consumes asynchronously in the foreground through `get_batch()`
        #
        # The core responsibility here is therefore not "assemble a training batch" but deciding, from
        # the manager's global state machine, when to keep producing, when to pause and wait, and when
        # to exit for good.
        while not self._finish_event.is_set():
            if self._status == AgentLoopManagerStatus.FINISH:
                break
            if self._status in (AgentLoopManagerStatus.UPDATE_WEIGHT_AND_ABORT, AgentLoopManagerStatus.EXPIRED_BATCH):
                # Both a proactive pre-sync pause and a model expiry can only be lifted by the
                # trainer calling continue_produce().
                await self._wait_for_status_exit(self._status)
                continue
            task_batch_sizes = self._produce_progress.ensure_target_upto(
                batch_size=batch_size,
                future_step=self._produce_progress.producer_future_step,
                allocate_batch_sizes=self._get_task_batch_sizes_for_step,
            )
            produce_status = await self._produce_batch_to_buffer(
                task_batch_sizes=task_batch_sizes,
                progress=self._produce_progress,
            )
            if produce_status == ProduceBatchStatus.EXPIRED_BATCH:
                # Note:
                # - EXPIRED_BATCH is a "stop right now" signal the producer detects by itself while producing
                # - UPDATE_WEIGHT_AND_ABORT is set proactively by the trainer via pause_produce() before syncing
                self._status = AgentLoopManagerStatus.EXPIRED_BATCH
            elif produce_status == ProduceBatchStatus.NORMAL:
                # The producer's own train_step advances only when a round completed normally.
                self._produce_progress.advance_future_step()
            # Yield to the event loop explicitly so that fake strategies / extremely fast paths do not
            # busy-spin in tests.
            # NOTE(review): nesting reconstructed — assumed to run on every iteration; confirm.
            await asyncio.sleep(0)

    async def get_batch(self, batch_size: int, train_step: int) -> ProduceBatchResult:
        """Wait for and return the batch for ``train_step`` (disaggregated path).

        Returns:
            ProduceBatchResult: The consumed batch; an empty result with
            status ``EXPIRED_BATCH`` when the manager is expired and the
            trainer holds a newer model step; or an empty result when the
            finish event is set before a batch is ready.

        Raises:
            RuntimeError: If the manager is expired, no newer model step is
                available, and the batch is not ready.
        """
        # `get_batch()` is the consumption interface the disaggregated path offers the trainer.
        #
        # By design it is clearly separated from `produce_batch()`:
        # - `produce_batch()`: colocate; one call performs "produce + wrap up + fetch"
        # - `get_batch()`: disagg; waits until the replay buffer holds the batch required by the
        #   current train step, then fetches it
        #
        # The only legitimate case for returning an empty batch here is therefore:
        # - the manager has entered EXPIRED_BATCH
        # - the training side already has a Model Step newer than the rollout side's and can sync to it
        # Without a newer model version, either the batch that is already ready is consumed, or the
        # call fails fast to expose the broken invariant.
        progress = self._produce_progress
        progress.begin_consume(train_step)
        await self._refresh_for_all_tasks(train_step, [Status.COMPLETED, Status.ABORTED])
        task_batch_sizes = self._get_task_batch_sizes_for_step(batch_size, train_step)
        current_model_step = train_step - 1
        while not self._finish_event.is_set():
            if self._status == AgentLoopManagerStatus.EXPIRED_BATCH:
                # An empty expired result may skip training and sync directly only when the
                # training side already has a newer Model Step.
                if current_model_step > self._model_step:
                    pause_time_s = self._pause_time_s
                    self._pause_time_s = 0.0
                    result = ProduceBatchResult(
                        rollout_states=[],
                        status=ProduceBatchStatus.EXPIRED_BATCH,
                    )
                    if pause_time_s > 0:
                        result.group_gen_pause_time_s = pause_time_s
                    return result
                # With no newer model and the current batch not ready, the producer has stopped and
                # cannot be revived by a sync; the invariant violation must surface immediately.
                if not await self.replay_buffer.is_ready(task_batch_sizes):
                    leftover_counts = await self.replay_buffer.count_statuses(self.task_names, _LEFTOVER_STATUSES)
                    raise RuntimeError(
                        "AgentLoopManager reached EXPIRED_BATCH without a newer model or a ready batch: "
                        f"train_step={train_step}, current_model_step={current_model_step}, "
                        f"rollout_model_step={self._model_step}, manager_status={self._status.name}, "
                        f"producer_future_step={progress.producer_future_step}, "
                        f"next_consumer_step={progress.next_consumer_step}, "
                        f"target_upto_future_step={progress.target_upto_future_step}, "
                        f"target_samples={progress.target_samples}, "
                        f"consumed_samples={progress.consumed_samples}, "
                        f"task_batch_sizes={task_batch_sizes}, "
                        f"leftover_status_counts={leftover_counts}"
                    )
            if await self.replay_buffer.is_ready(task_batch_sizes):
                result = await self._get_batch_from_buffer(
                    batch_size=batch_size,
                    task_batch_sizes=task_batch_sizes,
                    consume_progress=progress,
                )
                if self._status == AgentLoopManagerStatus.EXPIRED_BATCH:
                    # Expired but carrying data: the trainer still has to finish this step and then
                    # revive the producer with the new Model Step.
                    result.status = ProduceBatchStatus.EXPIRED_BATCH
                if result.rollout_states:
                    progress.finish_consume(train_step)
                    # NOTE(review): nesting reconstructed — the refresh for the next step is assumed
                    # to run only when data was actually consumed; confirm.
                    await self._refresh_for_all_tasks(train_step + 1, [Status.COMPLETED, Status.ABORTED])
                return result
            await asyncio.sleep(self._STATUS_POLL_INTERVAL_S)
        # The finish event was set before a batch became ready.
        return ProduceBatchResult(rollout_states=[])

    def _task_checkpoint_path(self, checkpoint_path: Path | str, task_name: str) -> Path:
        """Return the checkpoint directory of one task's sampler."""
        checkpoint_path = Path(checkpoint_path)
        return checkpoint_path / self._TASK_CHECKPOINT_DIR / task_name

    def _manager_state_path(self, checkpoint_path: Path | str) -> Path:
        """Return the path of the manager's JSON state file."""
        checkpoint_path = Path(checkpoint_path)
        return checkpoint_path / self._MANAGER_STATE_PATH

    def _get_pending_task_counts(self) -> dict[str, int]:
        """Return ``{task_name: pending_count}`` for tasks whose strategy reports pending work."""
        pending_task_counts: dict[str, int] = {}
        for task in self.task_runners:
            pending_count = task.produce_strategy.pending_task_count()
            if pending_count > 0:
                pending_task_counts[task.task_name] = pending_count
        return pending_task_counts

    async def save(self, checkpoint_path: Path | str, model_step: int) -> None:
        """Save all task sampler states and the shared replay buffer.

        The manager's status, ``model_step``, and produce progress are also
        written as JSON.

        Raises:
            RuntimeError: If any produce strategy still reports pending rollout tasks.
        """
        checkpoint_path = Path(checkpoint_path)
        checkpoint_path.mkdir(parents=True, exist_ok=True)
        pending_task_counts = self._get_pending_task_counts()
        if pending_task_counts:
            raise RuntimeError(
                "Cannot save AgentLoopManager while pending rollout tasks still exist: "
                f"{pending_task_counts}. Call pause_produce() first."
            )
        # Record the model step this checkpoint corresponds to before saving; resume restores
        # exactly this state.
        self._model_step = model_step
        for task in self.task_runners:
            task_checkpoint_path = self._task_checkpoint_path(checkpoint_path, task.task_name)
            task_checkpoint_path.mkdir(parents=True, exist_ok=True)
            task.sampler.save(task_checkpoint_path)
        # The manager layer keeps async semantics; a synchronous entry point may only wrap this
        # with asyncio_run at the trainer boundary.
        await self.replay_buffer.save(checkpoint_path)
        manager_state_path = self._manager_state_path(checkpoint_path)
        progress_state = self._produce_progress.state_dict()
        with manager_state_path.open("w") as f:
            json.dump(
                {
                    "status": self._status.name,
                    "model_step": self._model_step,
                    **progress_state,
                },
                f,
            )

    async def resume(self, checkpoint_path: Path | str) -> int:
        """Resume all task sampler states and the shared replay buffer.

        After resuming, the manager is in ``UPDATE_WEIGHT_AND_ABORT`` with the
        update event set; the saved ``status`` value is not restored.

        Returns:
            int: The model step stored in the checkpoint.
        """
        checkpoint_path = Path(checkpoint_path)
        for task in self.task_runners:
            task.sampler.resume(self._task_checkpoint_path(checkpoint_path, task.task_name))
        # Restoring the replay buffer is async I/O; asyncio_run must not be nested inside an
        # event loop that is already running.
        await self.replay_buffer.resume(checkpoint_path)
        manager_state_path = self._manager_state_path(checkpoint_path)
        with manager_state_path.open("r") as f:
            manager_state = json.load(f)
        saved_model_step = manager_state["model_step"]
        self._produce_progress.load_state_dict(manager_state)
        # Fresh events are created and the manager starts out paused.
        self._update_event = asyncio.Event()
        self._finish_event = asyncio.Event()
        self._update_event.set()
        self._status = AgentLoopManagerStatus.UPDATE_WEIGHT_AND_ABORT
        self._pause_time_s = 0.0
        self._model_step = saved_model_step
        return saved_model_step