import asyncio
import math
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Protocol, runtime_checkable
if TYPE_CHECKING:
from xtuner.v1.rl.rollout.controller import RolloutControllerProxy
import ray
import tqdm
from mmengine.dist import get_rank
from pydantic import BaseModel, ConfigDict, Field
from xtuner.v1.data_proto.rl_data import (
RolloutState,
Status,
get_group_status,
reset_rollout_response,
)
from xtuner.v1.rl.agent_loop import AgentLoopSpec
from xtuner.v1.rl.replay_buffer import ReplayBuffer
from xtuner.v1.rl.utils import (
AGENT_LOOP_PAUSE_REQUEST_TIMEOUT_S,
PRODUCER_PAUSE_PENDING_TASK_TIMEOUT_S,
calculate_seq_staleness,
cancel_and_drain,
create_task,
)
from xtuner.v1.utils import get_logger
from .sampler import Sampler
# Module-level logger obtained from xtuner's logging helper.
logger = get_logger()
# Key under which ProduceContext.generate_group stores, in each returned item's
# ``extra_fields``, the wall time in seconds measured for generating its group.
GROUP_GENERATE_TIME_KEY = "group_generate_time_s"
class _ProgressDisplayer:
def __init__(self, progress_bar: Any | None) -> None:
self._tqdm = progress_bar
@classmethod
def create(cls, *, strategy_name: str, task_name: str, total: int, initial: int) -> "_ProgressDisplayer":
total = max(0, total)
initial = min(total, max(0, initial))
if total <= 0 or get_rank() != 0:
return cls(None)
return cls(
tqdm.tqdm(
total=total,
initial=initial,
desc=f"{strategy_name} {task_name}",
unit="sample",
dynamic_ncols=True,
mininterval=30,
leave=False,
)
)
def update(self, value: int) -> None:
if self._tqdm is None:
return
total = max(0, int(self._tqdm.total or 0))
value = min(total, max(0, value))
delta = value - self._tqdm.n
if delta > 0:
self._tqdm.update(delta)
self._tqdm.n = value
def close(self) -> None:
if self._tqdm is not None:
self._tqdm.close()
self._tqdm = None
@dataclass
class ProduceProgress:
    """Live progress object shared by the producer and the consumer.

    Design goals:
    - The manager / caller initializes this object and updates it in place; a
      strategy only receives the reference and reads the latest progress.
    - ``target`` / ``consumed`` use global absolute cumulative counts, so that
      after the consumer takes completed groups out of the buffer the producer
      does not mistake already-consumed samples for a shortfall and re-issue them.
    - The same semantics serve both the global progress of the non-colocated
      path and the local progress of a colocated ``produce_batch``.

    Usage notes:
    - Do not add missing keys inside a strategy or fall back with
      ``dict.get(..., 0)``; a missing task key should fail fast.
    - Unless the semantics explicitly require freezing the target /
      scheduled_target of the current ``produce_batch``, do not copy field
      values into local snapshots that are used across an ``await``; read
      ``progress.xxx`` directly so concurrent updates to ``next_consumer_step``
      / ``consumed_samples`` take effect as early as possible.
    - Do not replace the whole ``ProduceProgress`` object while running; on
      resume, update the fields in place so existing references stay valid.

    Field meanings:
    - next_consumer_step: training step that newly written samples should
      target. Set to ``i`` on entry to ``get_batch(i)`` and to ``i + 1`` once a
      non-empty batch has been taken.
    - producer_future_step: future step the producer is currently preparing.
    - consumed_samples: per task, absolute cumulative number of groups the
      consumer has taken out of the replay buffer.
    - target_samples: per task, absolute cumulative number of groups that
      should have been produced up to ``target_upto_future_step``.
    - target_upto_future_step: largest future step covered by ``target_samples``.
    - raw_rewards_sum / raw_rewards_count: per task, reward statistics of the
      completed groups generated since the consumer last took a batch. Groups
      that end up filtered are still counted, as recorded before filtering.
    - produced_samples / produced_tokens: per task, number of samples and
      response tokens returned by the producer since the consumer last took a
      batch, including filtered / aborted / completed-but-unconsumed samples.
    - produce_time_s: cumulative wall time spent inside ``produce_batch`` since
      the consumer last took a batch.
    """

    next_consumer_step: int = 1
    producer_future_step: int = 1
    consumed_samples: dict[str, int] = field(default_factory=dict)
    target_samples: dict[str, int] = field(default_factory=dict)
    target_upto_future_step: int = 0
    raw_rewards_sum: dict[str, float] = field(default_factory=dict)
    raw_rewards_count: dict[str, int] = field(default_factory=dict)
    produced_samples: dict[str, int] = field(default_factory=dict)
    produced_tokens: dict[str, int] = field(default_factory=dict)
    produce_time_s: float = 0.0

    @classmethod
    def build(cls, task_names: list[str]) -> "ProduceProgress":
        """Create a progress object with every per-task counter zeroed."""
        return cls(
            consumed_samples=dict.fromkeys(task_names, 0),
            target_samples=dict.fromkeys(task_names, 0),
            raw_rewards_sum=dict.fromkeys(task_names, 0.0),
            raw_rewards_count=dict.fromkeys(task_names, 0),
            produced_samples=dict.fromkeys(task_names, 0),
            produced_tokens=dict.fromkeys(task_names, 0),
        )

    @classmethod
    def build_local(
        cls,
        task_names: list[str],
        task_batch_sizes: dict[str, int],
        train_step: int,
    ) -> "ProduceProgress":
        """Create a progress object scoped to a single ``produce_batch`` call.

        The colocated path uses this local progress: it only expresses the
        target of the current call and does not touch the cumulative window of
        the non-colocated path. The targets are a copy of ``task_batch_sizes``
        and all step fields are set to ``train_step``.
        """
        return cls(
            next_consumer_step=train_step,
            producer_future_step=train_step,
            consumed_samples=dict.fromkeys(task_names, 0),
            target_samples={**task_batch_sizes},
            target_upto_future_step=train_step,
            raw_rewards_sum=dict.fromkeys(task_names, 0.0),
            raw_rewards_count=dict.fromkeys(task_names, 0),
            produced_samples=dict.fromkeys(task_names, 0),
            produced_tokens=dict.fromkeys(task_names, 0),
        )

    def ensure_target_upto(
        self,
        *,
        batch_size: int,
        future_step: int,
        allocate_batch_sizes: Callable[[int, int], dict[str, int]],
    ) -> dict[str, int]:
        """Advance the cumulative target to ``future_step`` and return that step's sizes.

        Every step not yet covered is allocated once and added to
        ``target_samples``; a task name missing from ``target_samples`` raises
        ``KeyError``. The allocator is then called once more for
        ``future_step`` to produce the return value.
        """
        first_uncovered = self.target_upto_future_step + 1
        if first_uncovered <= future_step:
            for step in range(first_uncovered, future_step + 1):
                for name, size in allocate_batch_sizes(batch_size, step).items():
                    self.target_samples[name] += size
            self.target_upto_future_step = future_step
        return allocate_batch_sizes(batch_size, future_step)

    def begin_consume(self, train_step: int) -> None:
        """Record that the consumer started fetching the batch for ``train_step``."""
        self.next_consumer_step = train_step

    def mark_consumed(self, consumed_counts: dict[str, int]) -> None:
        """Add the number of groups the consumer actually took, per task.

        Targets are never rolled back, so the producer cannot mistake consumed
        samples for a shortfall.
        """
        for name in consumed_counts:
            self.consumed_samples[name] += consumed_counts[name]

    def add_raw_rewards(self, task_name: str, rewards_sum: float, rewards_count: int) -> None:
        """Accumulate reward statistics for ``task_name``."""
        self.raw_rewards_sum[task_name] += rewards_sum
        self.raw_rewards_count[task_name] += rewards_count

    def add_produced(self, task_name: str, samples: int, tokens: int) -> None:
        """Accumulate produced sample and token counts for ``task_name``."""
        self.produced_samples[task_name] += samples
        self.produced_tokens[task_name] += tokens

    def add_produce_time(self, elapsed_s: float) -> None:
        """Accumulate time spent producing."""
        self.produce_time_s += elapsed_s

    def consume_produced(self, task_name: str) -> tuple[int, int]:
        """Return ``(samples, tokens)`` produced for ``task_name`` and reset both to 0."""
        snapshot = (self.produced_samples[task_name], self.produced_tokens[task_name])
        self.produced_samples[task_name] = 0
        self.produced_tokens[task_name] = 0
        return snapshot

    def consume_produce_time(self) -> float:
        """Return the accumulated produce time and reset it to 0.0."""
        elapsed, self.produce_time_s = self.produce_time_s, 0.0
        return elapsed

    def consume_raw_rewards(self, task_name: str) -> tuple[float, int]:
        """Return ``(rewards_sum, rewards_count)`` for ``task_name`` and reset both."""
        snapshot = (self.raw_rewards_sum[task_name], self.raw_rewards_count[task_name])
        self.raw_rewards_sum[task_name] = 0.0
        self.raw_rewards_count[task_name] = 0
        return snapshot

    def finish_consume(self, train_step: int) -> None:
        """Record that the batch for ``train_step`` was taken; the next step is targeted."""
        self.next_consumer_step = train_step + 1

    def advance_future_step(self) -> None:
        """Move the producer on to the next future step."""
        self.producer_future_step += 1

    def state_dict(self) -> dict[str, Any]:
        """Return a snapshot of all fields; per-task dicts are shallow copies."""
        return {
            "next_consumer_step": self.next_consumer_step,
            "producer_future_step": self.producer_future_step,
            "consumed_samples": {**self.consumed_samples},
            "target_samples": {**self.target_samples},
            "target_upto_future_step": self.target_upto_future_step,
            "raw_rewards_sum": {**self.raw_rewards_sum},
            "raw_rewards_count": {**self.raw_rewards_count},
            "produced_samples": {**self.produced_samples},
            "produced_tokens": {**self.produced_tokens},
            "produce_time_s": self.produce_time_s,
        }

    def load_state_dict(self, state: dict[str, Any]) -> None:
        """Restore fields from ``state``, mutating the existing dicts in place.

        In-place updates keep references held by strategies / contexts valid.
        The step fields, ``consumed_samples`` and ``target_samples`` are
        required; the remaining statistics are optional and default to zero
        for every task found in the two required dicts.
        """
        self.next_consumer_step = state["next_consumer_step"]
        self.producer_future_step = state["producer_future_step"]
        self.target_upto_future_step = state["target_upto_future_step"]

        for live, key in (
            (self.consumed_samples, "consumed_samples"),
            (self.target_samples, "target_samples"),
        ):
            live.clear()
            live.update(state[key])

        known_tasks = set(self.consumed_samples) | set(self.target_samples)
        optional_stats = (
            (self.raw_rewards_sum, "raw_rewards_sum", float),
            (self.raw_rewards_count, "raw_rewards_count", int),
            (self.produced_samples, "produced_samples", int),
            (self.produced_tokens, "produced_tokens", int),
        )
        for live, key, cast in optional_stats:
            saved = state.get(key, {})
            live.clear()
            live.update({name: cast(saved.get(name, 0)) for name in known_tasks})

        self.produce_time_s = float(state.get("produce_time_s", 0.0))
class ProduceBatchStatus(Enum):
    """Outcome reported by a strategy's ``produce_batch`` call."""

    # The call finished without being interrupted by a weight update or an
    # expired model.
    NORMAL = auto()
    # The context's update event was set, so production stopped early.
    UPDATE_WEIGHT_AND_ABORT = auto()
    # The strategy's ``is_model_expired`` check reported the model step as
    # too stale for the train step.
    EXPIRED_BATCH = auto()
def default_is_valid_sample_fn(samples: list[RolloutState]) -> bool:
    """Default validity filter: accept every rollout group without inspecting it."""
    return True
def default_should_continue_fn(completed_count: int, batch_size: int, **kwargs) -> bool:
    """Default stop rule: keep producing while fewer than ``batch_size`` groups are done.

    Extra keyword arguments are accepted and ignored.
    """
    return batch_size > completed_count
def calculate_stale_threshold(max_staleness: int, sync_weights_interval: int) -> int:
    """Convert a staleness budget counted in sync periods into a step threshold.

    Args:
        max_staleness: Number of extra sync periods tolerated; must be >= 0.
        sync_weights_interval: Length of one sync period; must be > 0.

    Returns:
        ``(max_staleness + 1) * sync_weights_interval``.

    Raises:
        ValueError: If either argument is out of range.
    """
    if max_staleness < 0:
        raise ValueError(f"max_staleness must be non-negative, got {max_staleness}.")
    if sync_weights_interval <= 0:
        raise ValueError(f"sync_weights_interval must be positive, got {sync_weights_interval}.")
    # max_staleness counts sync periods; the extra +1 is the lag of the current
    # sync period, which training always has to accept.
    tolerated_periods = max_staleness + 1
    return tolerated_periods * sync_weights_interval
@runtime_checkable
class IsValidSampleFn(Protocol):
    """Structural type of the callback that judges a generated rollout group."""

    def __call__(self, samples: list[RolloutState]) -> bool:
        """Return True when the group ``samples`` should be kept, False to filter it."""
        ...
@runtime_checkable
class ShouldContinueFn(Protocol):
    """Structural type of the callback that decides whether production goes on."""

    def __call__(self, completed_count: int, batch_size: int, **kwargs) -> bool:
        """Return True to keep producing given the current count and the target size."""
        ...
@dataclass
class ProduceContext:
    """Production context for a single task.

    It gathers in one place the runtime contracts of AsyncProduceStrategy that
    are easiest to get wrong:
    - a strategy only accepts a ProduceContext; loose-argument entry points are
      no longer supported;
    - target / consumed are both read as absolute cumulative counts;
    - pausing only looks at the ``update_event`` passed in by the manager;
    - the ray / local difference of rollout generation and the writing of the
      timing field are handled here;
    - generated results are first filtered by business validity and then handed
      to the ReplayBuffer, which writes versions, refreshes staleness and
      applies expiry.
    """

    agent_loop: AgentLoopSpec
    sampler: Sampler
    replay_buffer: ReplayBuffer
    task_batch_size: int
    task_name: str
    train_step: int
    update_event: asyncio.Event
    model_step: int
    progress: ProduceProgress
    is_valid_sample_fn: IsValidSampleFn = default_is_valid_sample_fn
    stale_threshold: int | None = None

    @property
    def consumer_step(self) -> int:
        """Current ``next_consumer_step`` of the shared progress object."""
        return self.progress.next_consumer_step

    @property
    def target_abs(self) -> int:
        """Absolute cumulative target of this task, read live from the progress."""
        return self.progress.target_samples[self.task_name]

    def should_abort(self) -> bool:
        """Whether the update event is set."""
        return self.update_event.is_set()

    async def expired_count(self) -> int:
        """Number of EXPIRED groups of this task in the replay buffer."""
        return await self.replay_buffer.count(task_name=self.task_name, group_status=Status.EXPIRED)

    async def available_count(self) -> int:
        """Groups consumed so far plus COMPLETED groups currently in the buffer."""
        in_buffer = await self.replay_buffer.count(task_name=self.task_name, group_status=Status.COMPLETED)
        # The consumed counter is read after the await so the latest value is used.
        already_consumed = self.progress.consumed_samples[self.task_name]
        return already_consumed + in_buffer

    async def sample_group(self, *, from_expired_pool: bool) -> list[RolloutState]:
        """Ask the sampler for a group, optionally also allowing EXPIRED groups."""
        statuses = [Status.ABORTED]
        if from_expired_pool:
            statuses.insert(0, Status.EXPIRED)
        return await self.sampler.sample(task_name=self.task_name, group_status=statuses)

    async def generate_group(
        self,
        rollout_state: list[RolloutState],
        *,
        enable_partial_rollout: bool = False,
    ) -> list[RolloutState]:
        """Generate one group through the agent loop and stamp the elapsed time.

        The strategy only says "generate"; whether the agent loop is a ray
        actor or a local object is decided here. Every returned item gets the
        measured wall time written to ``extra_fields[GROUP_GENERATE_TIME_KEY]``;
        a missing or ``None`` ``extra_fields`` is replaced by a new dict first.
        """
        started_at = time.perf_counter()
        if isinstance(self.agent_loop, ray.actor.ActorHandle):
            generate = self.agent_loop.generate_group.remote
        else:
            generate = self.agent_loop.generate_group
        generated = await generate(rollout_state, enable_partial_rollout=enable_partial_rollout)
        elapsed = time.perf_counter() - started_at

        for item in generated:
            timing_fields = getattr(item, "extra_fields", None)
            if timing_fields is None:
                timing_fields = {}
                setattr(item, "extra_fields", timing_fields)
            timing_fields[GROUP_GENERATE_TIME_KEY] = elapsed
        return generated

    async def put_generated_group(self, group: list[RolloutState]) -> bool:
        """Record statistics for ``group``, filter it, and store it in the buffer.

        Only fully generated (COMPLETED) groups go through reward accounting
        and the business validity filter; ABORTED / EXPIRED groups keep their
        status so they can be retried or counted.

        Returns:
            True when the group is still COMPLETED after it has been put into
            the replay buffer.
        """
        completed_before_put = get_group_status(group) == Status.COMPLETED
        # Tokens are counted before filtering, which resets the responses.
        produced_tokens = 0
        for item in group:
            if item.response_ids is not None:
                produced_tokens += len(item.response_ids)

        if completed_before_put:
            rewards_sum = 0.0
            rewards_count = 0
            for item in group:
                has_score = item.reward is not None and "score" in item.reward
                if has_score:
                    rewards_sum += float(item.reward["score"])  # type: ignore[index]
                    rewards_count += 1
                else:
                    logger.warning(
                        f"Missing reward score in item (uid: {item.uid}) of completed group for task {self.task_name}. This item will be skipped in reward statistics."
                    )
            self.progress.add_raw_rewards(self.task_name, rewards_sum, rewards_count)

            if not self.is_valid_sample_fn(group):
                for item in group:
                    item.status = Status.FILTERED
                    reset_rollout_response(item)

        await self.replay_buffer.put(
            group,
            self.task_name,
            model_step=self.model_step,
            current_train_step=self.consumer_step,
            stale_threshold=self.stale_threshold,
        )
        self.progress.add_produced(self.task_name, samples=len(group), tokens=produced_tokens)
        # replay_buffer.put may have turned a stale group into EXPIRED, so the
        # status is evaluated again before reporting whether it is trainable.
        return get_group_status(group) == Status.COMPLETED
class ProduceStrategyConfig(ABC, BaseModel):
    """Common configuration shared by all rollout production strategies.

    A production strategy decides how the agent loop fills the replay buffer
    and when it stops producing samples for the current training step.

    Args:
        is_valid_sample_fn (IsValidSampleFn): Callback deciding whether a
            generated rollout group is trainable. Defaults to
            ``default_is_valid_sample_fn``.
        should_continue_fn (ShouldContinueFn): Callback deciding whether
            production continues after a group has been processed. Defaults to
            ``default_should_continue_fn``.
    """

    # Unknown fields are rejected; arbitrary (non-pydantic) field types are allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

    is_valid_sample_fn: IsValidSampleFn = default_is_valid_sample_fn
    should_continue_fn: ShouldContinueFn = default_should_continue_fn

    @abstractmethod
    def build(
        self,
        *,
        sync_weights_interval: int = 1,
        rollout_controller: "Optional[RolloutControllerProxy]" = None,
    ) -> "ProduceStrategy":
        """Create the strategy described by this configuration.

        Args:
            sync_weights_interval: Passed through to strategies that need it.
            rollout_controller: Optional rollout controller proxy that a
                subclass may configure while building.
        """
class SyncProduceStrategyConfig(ProduceStrategyConfig):
    """Configuration for synchronous rollout production.

    The synchronous strategy produces samples on demand for the current training
    step. It is simpler and is the default choice when rollout and training run
    in a colocated or tightly synchronized workflow.

    Args:
        is_valid_sample_fn (IsValidSampleFn): Function used to decide whether a
            generated rollout group is trainable. Defaults to
            ``default_is_valid_sample_fn``.
        should_continue_fn (ShouldContinueFn): Function used to decide whether
            production should continue after a group is processed. Defaults to
            ``default_should_continue_fn``.

    **Examples:**

    Example synchronous strategy::

        config = SyncProduceStrategyConfig()
    """

    def build(
        self,
        *,
        sync_weights_interval: int = 1,
        rollout_controller: "Optional[RolloutControllerProxy]" = None,
    ) -> "SyncProduceStrategy":
        """Create a ``SyncProduceStrategy`` from the two configured callbacks.

        ``sync_weights_interval`` and ``rollout_controller`` are accepted for
        signature compatibility with the base class and are not used here.
        """
        return SyncProduceStrategy(
            is_valid_sample_fn=self.is_valid_sample_fn,
            should_continue_fn=self.should_continue_fn,
        )
class AsyncProduceStrategyConfig(ProduceStrategyConfig):
    """Configuration for asynchronous rollout production.

    The asynchronous strategy keeps producing rollout samples in the background
    and stores them in the replay buffer. It can oversample, allow partial
    rollout continuation, and discard samples that are too stale relative to the
    current training step.

    Args:
        is_valid_sample_fn (IsValidSampleFn): Function used to decide whether a
            generated rollout group is trainable. Defaults to
            ``default_is_valid_sample_fn``.
        should_continue_fn (ShouldContinueFn): Function used to decide whether
            production should continue after a group is processed. Defaults to
            ``default_should_continue_fn``.
        over_sample_threshold (float): Oversampling ratio. In normal mode the
            strategy schedules ``ceil(over_sample_threshold * task_batch_size)``
            groups on top of the target. Defaults to 0.0.
        enable_partial_rollout (bool): Whether unfinished rollouts can be
            continued after a weight sync. Defaults to False.
        max_staleness (int): Staleness budget counted in weight-sync periods;
            the effective step threshold is
            ``(max_staleness + 1) * sync_weights_interval``. Must be >= 0.
            Defaults to 0.
        tail_batch_trigger_size (int): Number of EXPIRED groups in the replay
            buffer at which the strategy switches to tail-batch mode (new tasks
            may be drawn from the expired pool and no oversampling budget is
            added). A value of 0 disables tail-batch mode. Defaults to 0.

    **Examples:**

    Example asynchronous strategy::

        config = AsyncProduceStrategyConfig(
            over_sample_threshold=0.2,
            enable_partial_rollout=True,
            max_staleness=1,
        )
    """

    over_sample_threshold: float = 0.0
    enable_partial_rollout: bool = False
    max_staleness: int = Field(default=0, ge=0)
    tail_batch_trigger_size: int = 0

    def build(
        self,
        *,
        sync_weights_interval: int = 1,
        rollout_controller: "Optional[RolloutControllerProxy]" = None,
    ) -> "AsyncProduceStrategy":
        """Create an ``AsyncProduceStrategy`` from this configuration.

        Args:
            sync_weights_interval: Forwarded to the strategy, which uses it
                together with ``max_staleness`` to derive its stale threshold.
            rollout_controller: When given, its ``enable_partial_rollout``
                setting is synchronised (blocking on the remote call) before
                the strategy is created.
        """
        if rollout_controller is not None:
            # ``ray`` is already imported at module level; block until the
            # controller has applied the setting.
            ray.get(rollout_controller.set_enable_partial_rollout.remote(self.enable_partial_rollout))
        return AsyncProduceStrategy(
            over_sample_threshold=self.over_sample_threshold,
            enable_partial_rollout=self.enable_partial_rollout,
            max_staleness=self.max_staleness,
            sync_weights_interval=sync_weights_interval,
            tail_batch_trigger_size=self.tail_batch_trigger_size,
            is_valid_sample_fn=self.is_valid_sample_fn,
            should_continue_fn=self.should_continue_fn,
        )
class ProduceStrategy(ABC):
    """Abstract base of rollout production strategies.

    Subclasses must implement ``produce_batch``; the remaining hooks have
    neutral defaults.
    """

    def __init__(
        self,
        is_valid_sample_fn: IsValidSampleFn,
        should_continue_fn: ShouldContinueFn,
    ):
        """Store the validity filter and the continue-producing predicate."""
        self.is_valid_sample_fn = is_valid_sample_fn
        self.should_continue_fn = should_continue_fn

    @abstractmethod
    async def produce_batch(self, ctx: ProduceContext) -> ProduceBatchStatus:
        """Produce samples for the task described by ``ctx`` and report the outcome."""

    async def pause_produce(self, ctx: ProduceContext) -> float:
        """Pause production; the default has nothing to pause and returns 0.0."""
        return 0.0

    def is_model_expired(self, train_step: int, model_step: int) -> bool:
        """Report whether the model is too stale; the default never is.

        The default (synchronous) strategy keeps no background samples across
        weight versions, so only the asynchronous strategy needs a real check.
        """
        return False

    def pending_task_count(self) -> int:
        """Number of in-flight tasks; the default tracks none."""
        return 0
class _PendingTasks:
"""AsyncProduceStrategy 的并发 pending task 集合。
这里只封装 pending set 的并发协议,不理解 sampler / rollout / replay buffer:
- wait 使用快照,随后必须二次 claim,避免 produce 和 pause 重复处理同一个 done task。
- cancel 前先原子 claim 并清空集合,避免 cancel 后又被其他路径 claim。
- schedule one 在锁内同时检查 abort 和 pending 数,避免 pause 已触发后继续新增 task。
"""
def __init__(self) -> None:
self._tasks: set[asyncio.Task] = set()
self._lock = asyncio.Lock()
def count(self) -> int:
# 只暴露已经纳入 pending 集合的 task 数量。
return len(self._tasks)
async def claim_ready(self) -> set[asyncio.Task]:
async with self._lock:
ready = {task for task in self._tasks if task.done()}
self._tasks.difference_update(ready)
return ready
async def wait_and_claim(self, *, timeout_s: float) -> set[asyncio.Task]:
async with self._lock:
snapshot = set(self._tasks)
if not snapshot:
return set()
done, _ = await asyncio.wait(snapshot, timeout=timeout_s, return_when=asyncio.FIRST_COMPLETED)
async with self._lock:
claimed = done & self._tasks
self._tasks.difference_update(claimed)
return claimed
async def schedule_one(
self,
*,
max_pending: int,
should_abort: Callable[[], bool],
spawn_one: Callable[[], Awaitable[asyncio.Task]],
) -> bool:
async with self._lock:
if should_abort() or len(self._tasks) >= max_pending:
return False
# 保持“检查 abort / pending 数 / 新增 task”这一组操作原子化。
self._tasks.add(await spawn_one())
return True
async def _claim_all(self) -> set[asyncio.Task]:
async with self._lock:
claimed = set(self._tasks)
self._tasks.clear()
return claimed
async def cancel_all(self) -> int:
tasks = await self._claim_all()
if not tasks:
return 0
logger.warning(f"Cancelling {len(tasks)} pending rollout tasks.")
await cancel_and_drain(list(tasks))
return len(tasks)
class SyncProduceStrategy(ProduceStrategy):
    """Strategy that generates groups on demand inside a single ``produce_batch`` call."""

    async def produce_batch(self, ctx: ProduceContext) -> ProduceBatchStatus:
        """Generate groups until ``should_continue_fn`` says stop or no task is left.

        ``ctx.task_batch_size`` generation tasks are started up front. Finished
        groups are handed to ``ctx.put_generated_group``; groups that do not
        end up COMPLETED are replaced by newly sampled ones.

        Returns:
            Always ``ProduceBatchStatus.NORMAL``.
        """

        async def launch() -> asyncio.Task:
            # Sample one group and start generating it in the background.
            sampled = await ctx.sampler.sample(task_name=ctx.task_name)
            return create_task(ctx.generate_group(sampled))

        in_flight: set = set()
        done_count = await ctx.replay_buffer.count(task_name=ctx.task_name, group_status=Status.COMPLETED)
        # TODO: should SyncProduceStrategy be supported in the non-colocated
        # setup? If so, keep the assertion below disabled.
        # assert completed_sample_count == 0, "SyncProduceStrategy assumes no completed samples at the start."
        for _ in range(ctx.task_batch_size):
            in_flight.add(await launch())
        logger.info(f"[SyncProduceStrategy] Started {len(in_flight)} initial tasks.")

        bar = _ProgressDisplayer.create(
            strategy_name=self.__class__.__name__,
            task_name=ctx.task_name,
            total=ctx.target_abs,
            initial=done_count,
        )
        try:
            while self.should_continue_fn(done_count, ctx.task_batch_size):
                if len(in_flight) == 0:
                    logger.warning("[SyncProduceStrategy] All tasks are done but not enough samples collected.")
                    break

                finished, in_flight = await asyncio.wait(
                    in_flight, timeout=1, return_when=asyncio.FIRST_COMPLETED
                )

                # Filtering happens inside put_generated_group, before the
                # group is written to the replay buffer.
                for task in finished:
                    if await ctx.put_generated_group(task.result()):
                        done_count += 1
                        bar.update(done_count)

                # Top up so that in-flight plus completed reaches the batch size again.
                while len(in_flight) + done_count < ctx.task_batch_size and self.should_continue_fn(
                    done_count, ctx.task_batch_size
                ):
                    in_flight.add(await launch())
        finally:
            bar.close()

        return ProduceBatchStatus.NORMAL
class AsyncProduceStrategy(ProduceStrategy):
    """Strategy that keeps generation tasks alive across ``produce_batch`` calls.

    Pending tasks are tracked in a ``_PendingTasks`` set owned by the strategy,
    so a later ``produce_batch`` or ``pause_produce`` call can collect results
    of tasks started earlier.
    """

    # Local retry interval for re-sending pause/abort while pending tasks drain.
    PERIODIC_ABORT_INTERVAL_S = 5.0

    def __init__(
        self,
        over_sample_threshold: float,
        enable_partial_rollout: bool,
        tail_batch_trigger_size: int,
        max_staleness: int,
        sync_weights_interval: int,
        is_valid_sample_fn: IsValidSampleFn,
        should_continue_fn: ShouldContinueFn,
    ):
        """Store the settings and derive the stale threshold.

        Raises:
            ValueError: Propagated from ``calculate_stale_threshold`` when
                ``max_staleness`` or ``sync_weights_interval`` is out of range.
        """
        super().__init__(is_valid_sample_fn, should_continue_fn)
        # TODO: add tail_batch_max_tries.
        # Purpose: when a sample has been retried too many times, put it in a
        # special MAX_TRIES state; such samples would trigger the tail-batch
        # logic together with expired samples.
        # This depends on RolloutState gaining a maintained ``num_tries``
        # attribute that is incremented on every interruption and switches the
        # sample to MAX_TRIES once max_tries is reached.
        # With enable_partial_rollout=True this logic is never triggered.
        # With enable_partial_rollout=False there are two cases:
        # 1) staleness = 0: expired samples are not allowed, and the
        #    expiry-triggered tail batch already covers this case;
        # 2) staleness > 0: a retry-based tail batch is needed, otherwise
        #    repeatedly retried samples hurt rollout efficiency.
        if not enable_partial_rollout and max_staleness > 0:
            logger.warning(
                "max_staleness > 0, enable_partial_rollout is False, this will affect rollout efficiency because not support tail_batch_max_tries logic now"
            )
        self.over_sample_threshold = over_sample_threshold
        self.enable_partial_rollout = enable_partial_rollout
        self.max_staleness = max_staleness
        self.sync_weights_interval = sync_weights_interval
        self.stale_threshold = calculate_stale_threshold(max_staleness, sync_weights_interval)
        self.tail_batch_trigger_size = tail_batch_trigger_size
        self._pending_tasks = _PendingTasks()

    def is_model_expired(self, train_step: int, model_step: int) -> bool:
        """Return True when the staleness of ``model_step`` reaches the stale threshold."""
        staleness = calculate_seq_staleness(model_step, train_step)
        return staleness >= self.stale_threshold

    def pending_task_count(self) -> int:
        """Number of generation tasks currently tracked as pending."""
        return self._pending_tasks.count()

    async def _put_claimed(
        self,
        claimed_tasks: set[asyncio.Task],
        ctx: ProduceContext,
        available_base: int | None = None,
        progress_displayer: _ProgressDisplayer | None = None,
    ) -> None:
        """Store the results of already-claimed tasks through ``ctx``.

        When both ``available_base`` and ``progress_displayer`` are given, the
        bar is moved to ``available_base`` plus the number of groups that came
        back COMPLETED so far.
        """
        completed_count = 0
        for task in claimed_tasks:
            items = task.result()
            is_completed = await ctx.put_generated_group(items)
            if is_completed:
                completed_count += 1
            if is_completed and available_base is not None and progress_displayer is not None:
                progress_displayer.update(available_base + completed_count)

    async def _pause_agent_loop(self, ctx: ProduceContext) -> None:
        """Send one pause request to the agent loop; timeouts and errors are only logged."""
        pause_request_start = time.perf_counter()
        if isinstance(ctx.agent_loop, ray.actor.ActorHandle):
            pause_future = ctx.agent_loop.pause.remote()
        else:
            pause_future = ctx.agent_loop.pause()
        try:
            await asyncio.wait_for(pause_future, timeout=AGENT_LOOP_PAUSE_REQUEST_TIMEOUT_S)
        except asyncio.TimeoutError:
            logger.warning(
                f"Agent loop pause timed out: task={ctx.task_name}, timeout_s={AGENT_LOOP_PAUSE_REQUEST_TIMEOUT_S}, "
                f"elapsed={time.perf_counter() - pause_request_start:.2f}s, "
                f"pending={self._pending_tasks.count()}"
            )
        except Exception:
            # Best effort: a failed pause request must not break the pause loop,
            # which re-sends the request periodically.
            logger.exception(
                f"Agent loop pause failed: task={ctx.task_name}, "
                f"elapsed={time.perf_counter() - pause_request_start:.2f}s, "
                f"pending={self._pending_tasks.count()}"
            )

    async def pause_produce(self, ctx: ProduceContext) -> float:
        """Pause the agent loop and drain pending generation tasks.

        A pause request is sent immediately and re-sent every
        ``PERIODIC_ABORT_INTERVAL_S`` seconds while tasks are still pending.
        Finished tasks are stored through ``ctx.put_generated_group``. Once
        ``PRODUCER_PAUSE_PENDING_TASK_TIMEOUT_S`` has elapsed, the remaining
        tasks are cancelled.

        Returns:
            Seconds spent pausing, or 0.0 when nothing was pending.
        """
        pause_start = time.perf_counter()
        if self._pending_tasks.count() == 0:
            return 0.0

        pending_pause_tasks = {create_task(self._pause_agent_loop(ctx))}
        initial_pending_count = self._pending_tasks.count()
        logger.info(
            f"Pause signal loop started for task {ctx.task_name}. "
            f"Waiting for {initial_pending_count} pending tasks to complete. "
            f"periodic_abort_interval_s={self.PERIODIC_ABORT_INTERVAL_S}, "
            f"producer_pause_pending_task_timeout_s={PRODUCER_PAUSE_PENDING_TASK_TIMEOUT_S}"
        )

        cleanup_start_time = time.perf_counter()
        next_periodic_abort_time = cleanup_start_time + self.PERIODIC_ABORT_INTERVAL_S
        try:
            while True:
                elapsed_time = time.perf_counter() - cleanup_start_time
                if elapsed_time > PRODUCER_PAUSE_PENDING_TASK_TIMEOUT_S:
                    # Timed out: forcefully cancel every task still pending.
                    cancelled_count = await self._pending_tasks.cancel_all()
                    logger.warning(
                        f"Cleanup timeout of {PRODUCER_PAUSE_PENDING_TASK_TIMEOUT_S}s reached. "
                        f"Forcefully cancelling {cancelled_count} remaining tasks. "
                        f"task={ctx.task_name}"
                    )
                    break
                if self._pending_tasks.count() == 0:
                    break

                current_time = time.perf_counter()
                # Forget pause requests that have already finished.
                pending_pause_tasks = {task for task in pending_pause_tasks if not task.done()}
                # Re-send the pause signal periodically.
                if self.PERIODIC_ABORT_INTERVAL_S > 0 and current_time >= next_periodic_abort_time:
                    pending_pause_tasks.add(create_task(self._pause_agent_loop(ctx)))
                    next_periodic_abort_time += self.PERIODIC_ABORT_INTERVAL_S

                claimed_done = await self._pending_tasks.wait_and_claim(timeout_s=1)
                for task in claimed_done:
                    paused_items = task.result()
                    await ctx.put_generated_group(paused_items)
        finally:
            # Always cancel and drain outstanding pause requests, including
            # when collecting a result above raised, so they are not left
            # running in the background.
            await cancel_and_drain(list(pending_pause_tasks))

        pause_time = time.perf_counter() - pause_start
        logger.info(f"pause_produce completed for task {ctx.task_name} within {pause_time}s.")
        return pause_time

    async def produce_batch(self, ctx: ProduceContext) -> ProduceBatchStatus:
        """Schedule and collect generation tasks until the target is available.

        Returns:
            ``UPDATE_WEIGHT_AND_ABORT`` when ``ctx.should_abort()`` is true,
            ``EXPIRED_BATCH`` when the model is expired for ``ctx.train_step``,
            and ``NORMAL`` when the target is not positive, when
            ``should_continue_fn`` says stop, or when no pending task is left.

        Raises:
            KeyError: If ``ctx.task_name`` is missing from the progress
                ``consumed_samples`` or ``target_samples``.
        """
        if ctx.task_name not in ctx.progress.consumed_samples:
            raise KeyError(f"ProduceProgress.consumed_samples missing task_name={ctx.task_name!r}")
        if ctx.task_name not in ctx.progress.target_samples:
            raise KeyError(f"ProduceProgress.target_samples missing task_name={ctx.task_name!r}")
        if ctx.target_abs <= 0:
            return ProduceBatchStatus.NORMAL

        # TODO: place this check just before while loop
        if ctx.should_abort():
            return ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT
        if self.is_model_expired(ctx.train_step, ctx.model_step):
            return ProduceBatchStatus.EXPIRED_BATCH

        # First collect tasks that finished between produce_batch calls, so
        # done tasks do not linger in the pending set.
        claimed_done = await self._pending_tasks.claim_ready()
        await self._put_claimed(claimed_done, ctx)

        # TODO: remove this check
        if ctx.should_abort():
            return ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT
        if self.is_model_expired(ctx.train_step, ctx.model_step):
            return ProduceBatchStatus.EXPIRED_BATCH

        expired_count = await ctx.expired_count()
        sample_from_expired = self.tail_batch_trigger_size > 0 and expired_count >= self.tail_batch_trigger_size
        if sample_from_expired:
            logger.info(
                f"Tail batch trigger condition met: {expired_count} expired samples "
                f"(threshold: {self.tail_batch_trigger_size}). Enabling tail batch mode."
            )

        # The required cumulative target is frozen for this produce_batch call.
        # Normal mode adds a fixed oversampling budget based on the task batch
        # size. Tail-batch mode only fills the required gap: new tasks are
        # requested with the expired pool enabled and the oversampling window
        # is not enlarged.
        target_abs = ctx.target_abs
        oversample_budget = 0 if sample_from_expired else math.ceil(self.over_sample_threshold * ctx.task_batch_size)
        scheduled_target = target_abs + oversample_budget
        logger.info(
            f"Starting produce_batch for task {ctx.task_name} with target_abs={target_abs}, "
            f"oversample_budget={oversample_budget}, scheduled_target={scheduled_target}."
        )

        async def spawn_one() -> asyncio.Task:
            # Sample one group and start generating it in the background.
            rollout_state = await ctx.sample_group(from_expired_pool=sample_from_expired)
            return create_task(
                ctx.generate_group(
                    rollout_state,
                    enable_partial_rollout=self.enable_partial_rollout,
                )
            )

        initial_available = await ctx.available_count()
        progress_displayer = _ProgressDisplayer.create(
            strategy_name=self.__class__.__name__,
            task_name=ctx.task_name,
            total=ctx.target_abs,
            initial=initial_available,
        )
        try:
            while True:
                if ctx.should_abort():
                    return ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT
                if self.is_model_expired(ctx.train_step, ctx.model_step):
                    return ProduceBatchStatus.EXPIRED_BATCH

                available = await ctx.available_count()
                progress_displayer.update(available)
                if not self.should_continue_fn(available, target_abs):
                    return ProduceBatchStatus.NORMAL

                pending_count = self._pending_tasks.count()
                desired_pending = max(0, scheduled_target - available)
                if available + pending_count < scheduled_target:
                    # Keep adding tasks until the pending set reaches the
                    # desired size or an abort is requested.
                    while await self._pending_tasks.schedule_one(
                        max_pending=desired_pending,
                        should_abort=ctx.should_abort,
                        spawn_one=spawn_one,
                    ):
                        pass

                # One check covers both the scheduling branch and the
                # fall-through case.
                if ctx.should_abort():
                    return ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT

                if self._pending_tasks.count() == 0:
                    logger.warning("All tasks are done but not enough samples collected.")
                    return ProduceBatchStatus.NORMAL

                claimed_done = await self._pending_tasks.wait_and_claim(timeout_s=1)
                await self._put_claimed(
                    claimed_done,
                    ctx,
                    available_base=available,
                    progress_displayer=progress_displayer,
                )
        finally:
            progress_displayer.close()