AgentLoop#
AgentLoop 是 XTuner RL 中最常需要按任务自定义的模块。它定义“一组样本如何完成 rollout”:如何组织推理输入、调用几次推理引擎、是否插入工具或环境反馈、哪些 token 参与训练,以及什么时候调用 Judger 写入 reward。
在默认训练链路中,AgentLoop 位于 sampler 和 replay buffer 之间:
Sampler -> list[RolloutState]
-> AgentLoop.generate_group()
-> RolloutController.generate()
-> Judger.judge() / Judger.batch_judge()
-> ReplayBuffer
-> RLTrainer._prepare_train_data()
如果你的任务只是单轮问答,通常直接使用预置的 SingleTurnAgentLoop。如果任务包含多轮交互、工具调用、环境 step、特殊终止条件、非模型 token 插入或自定义 response mask,就应该自定义 AgentLoop。
类型与构建#
xtuner/v1/rl/agent_loop/agent_loop.py 中有两个核心抽象:
AgentLoopConfig:配置对象,负责构建本地或 Ray actor 形式的 AgentLoop。AgentLoop:运行时对象,负责实现generate_sample()和generate_group()。
整体关系如下:
┌─────────────────────────────┐
│ AgentLoopConfig │
│ hf_checkpoint, sample_params │
│ cpu_resources │
└──────────────┬──────────────┘
│ build(...)
┌────────────────────────────┼────────────────────────────┐
│ │ │
cpu_resources = None cpu_resources.num_workers = 1 cpu_resources.num_workers > 1
│ │ │
▼ ▼ ▼
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ AgentLoop │ │ AgentLoopActor │ │ RouterAgentLoop │
│ 本地执行 │ │ Ray actor │ │ 多 actor 路由 │
└─────────────────┘ └────────┬────────┘ └────────┬────────┘
│ │
▼ ▼
┌─────────────────┐ ┌─────────────────┐
│ AgentLoop │ │ AgentLoopActor │
│ actor 内部实例 │ │ ... │
└─────────────────┘ └─────────────────┘
AgentLoop 统一暴露两个异步接口:
async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState:
...
async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]:
...
generate_sample() 处理单条样本;默认 generate_group() 会并发调用多次 generate_sample(),并在调用前把 self.sample_params 写到组内每条样本上。若配置了 enable_batch_judge=True,默认 generate_group() 会在组内样本生成完成后调用一次 self.run_judger(group_samples)。需要其他组级逻辑时,例如组内共享环境、组内排序过滤,可以覆盖 generate_group()。
输入输出约定#
AgentLoop 输入和输出都是 RolloutState。如果后续使用预置 RLTrainer._prepare_train_data(),自定义 AgentLoop 必须维护好训练所需字段。
输入#
Sampler 传给 AgentLoop 的 RolloutState 通常包含:
message:原始对话消息。prompt_ids:tokenized prompt,通常由 RL tokenize function 写入。reward_model:标签信息,例如{"ground_truth": ...},供 Judger 使用。sample_params:会在generate_group()中被 AgentLoop 的默认采样参数覆盖。task_name、uid、session_uid等调度字段。
生成前,AgentLoop 需要确保:
rollout_state.tokens是实际传给RolloutController.generate()的输入 token。单轮任务通常设为prompt_ids;多轮任务通常设为历史上下文拼接后的 token。rollout_state.sample_params是本次推理使用的参数。多轮任务里每一轮可能需要更新max_tokens。
输出#
AgentLoop 返回的 RolloutState 如果要进入训练,至少需要满足:
status == Status.COMPLETED。预置 trainer 会跳过ABORTED、FILTERED、FAILED的样本组。response_ids非空。_prepare_train_data()用它构造训练 token。response非空。Judger 和轨迹保存依赖文本 response。reward["score"]存在。_prepare_train_data()会直接读取它计算 advantage。
如果提供以下字段,还需要满足长度约定:
logprobs:长度必须等于len(response_ids)。response_mask:长度必须等于len(response_ids)。mask 为0的 token 会被转成训练 label-100,对应 advantage 也会置为0.0。
这也是自定义 AgentLoop 最容易出错的地方:工具返回、环境反馈、系统插入内容等不是模型生成的 token,应该在 response_mask 中置为 0,并给对应 logprobs 填 0.0。
SingleTurnAgentLoop#
SingleTurnAgentLoop 是默认单轮问答实现,适用于“给定 prompt,模型生成一次 response,然后打分”的任务。
单条样本流程如下:
generate_sample(state)
-> PartialRolloutHandler.preprocess(state, enable_partial_rollout)
-> 如果 state.tokens 为空,则 state.tokens = state.prompt_ids
-> await rollout_ctl.generate.remote(state)
-> PartialRolloutHandler.postprocess(state)
-> 如果 state.status != COMPLETED,直接返回,不触发 Judger
-> 如果配置了 judger,调用 self.run_judger(state)
典型配置:
from xtuner.v1.data_proto.rl_data import SampleParams
from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig
agent_loop_config = SingleTurnAgentLoopConfig(
hf_checkpoint=model_path,
sample_params=SampleParams(
max_tokens=1024,
top_k=0,
top_p=1.0,
temperature=1.0,
min_tokens=0,
),
)
AgentLoopConfig 还支持批量打分:
agent_loop_config = SingleTurnAgentLoopConfig(
hf_checkpoint=model_path,
sample_params=training_sample_params,
enable_batch_judge=True,
)
开启后,generate_sample() 不会逐条调用 Judger;generate_group() 会在组内样本全部生成完成后调用一次 self.run_judger(group_samples),内部会转到 judger.batch_judge(group_samples)。只有当前 Judger 明确实现 batch_judge() 时才应开启。
自定义 AgentLoop#
自定义 AgentLoop 通常需要做四件事:
继承
AgentLoop,实现generate_sample()。在
generate_sample()中维护tokens、sample_params、response_ids、response、logprobs、response_mask、status。需要打分时,在 response 可用后调用
self.run_judger(...)。继承
AgentLoopConfig,实现build_local(),这样才能接入TaskSpecConfig.agent_loop_config,并复用 Ray actor 构建逻辑。
self.run_judger(...) 会根据输入形态调用 judger.judge() 或 judger.batch_judge(),并统一处理 pause/cancel。若自定义 AgentLoop 支持 enable_batch_judge=True,generate_sample() 中的单条打分需要用 not self.enable_batch_judge 保护,避免默认 generate_group() 再次批量打分。若自定义 AgentLoop 需要覆盖 pause(),应在实现中调用 await super().pause(),否则正在执行的 Judger 任务无法复用基类的中断逻辑。
最小单轮实现#
下面是一个最小可用版本,行为接近 SingleTurnAgentLoop:
from xtuner.v1.data_proto.rl_data import RolloutState, Status
from xtuner.v1.rl.agent_loop import AgentLoop, AgentLoopConfig
from xtuner.v1.rl.judger import Judger
class CustomAgentLoop(AgentLoop):
async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState:
if not rollout_state.tokens:
rollout_state.tokens = rollout_state.prompt_ids
rollout_state.sample_params = rollout_state.sample_params or self.sample_params
rollout_state = await self.rollout_ctl.generate.remote(rollout_state)
if rollout_state.status != Status.COMPLETED:
return rollout_state
if self.judger is not None and not self.enable_batch_judge:
rollout_state = await self.run_judger(rollout_state)
return rollout_state
class CustomAgentLoopConfig(AgentLoopConfig):
def build_local(self, rollout_controller, judger: Judger | None = None, logger=None) -> CustomAgentLoop:
return CustomAgentLoop(
rollout_ctl=rollout_controller,
sample_params=self.sample_params,
hf_checkpoint=self.hf_checkpoint,
judger=judger,
logger=logger,
)
这个版本适合没有工具、没有环境反馈、没有多轮上下文拼接的任务。只要 RolloutController.generate() 能正确写入 response_ids、response、logprobs 和 status,后续 Judger 与训练数据准备就能复用默认链路。
多轮或工具调用实现#
多轮任务通常需要循环调用 rollout_ctl.generate(),每轮把上轮输出和工具或环境结果追加到下一轮输入。可以参考 GSM8KToolAgentLoop 的模式:
class ToolAgentLoop(AgentLoop):
async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState:
final_response_ids: list[int] = []
final_logprobs: list[float] = []
final_response_mask: list[int] = []
cur_tokens = list(rollout_state.tokens or rollout_state.prompt_ids or [])
remaining_tokens = self.sample_params.max_tokens
for _ in range(self.max_turns):
rollout_state.tokens = cur_tokens
rollout_state.sample_params = self.sample_params.model_copy(
update={"max_tokens": remaining_tokens}
)
rollout_state = await self.rollout_ctl.generate.remote(rollout_state)
if rollout_state.status != Status.COMPLETED:
break
response_ids = list(rollout_state.response_ids or [])
logprobs = list(rollout_state.logprobs or [])
assert len(response_ids) == len(logprobs)
final_response_ids.extend(response_ids)
final_logprobs.extend(logprobs)
final_response_mask.extend([1] * len(response_ids))
cur_tokens.extend(response_ids)
tool_tokens = self._run_tool_and_encode_result(rollout_state)
if not tool_tokens:
break
final_response_ids.extend(tool_tokens)
final_logprobs.extend([0.0] * len(tool_tokens))
final_response_mask.extend([0] * len(tool_tokens))
cur_tokens.extend(tool_tokens)
remaining_tokens = self.sample_params.max_tokens - len(final_response_ids)
if remaining_tokens <= 0:
break
rollout_state.response_ids = final_response_ids[: self.sample_params.max_tokens]
rollout_state.logprobs = final_logprobs[: self.sample_params.max_tokens]
rollout_state.response_mask = final_response_mask[: self.sample_params.max_tokens]
rollout_state.response = self.tokenizer.decode(rollout_state.response_ids)
assert len(rollout_state.response_ids) == len(rollout_state.logprobs)
assert len(rollout_state.response_ids) == len(rollout_state.response_mask)
if rollout_state.status == Status.COMPLETED and self.judger is not None and not self.enable_batch_judge:
rollout_state = await self.run_judger(rollout_state)
return rollout_state
这个例子强调两个约定:
模型生成 token 的
response_mask为1。工具或环境插入 token 的
response_mask为0,logprobs填0.0。
覆盖 generate_group#
默认 generate_group() 会并发处理组内样本。如果你的任务需要组级逻辑,可以覆盖它:
async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]:
samples = await super().generate_group(rollout_state, **kwargs)
# 例:在默认并发生成和可选 batch judge 之后,继续执行组级过滤或排序。
samples = self.filter_or_sort_group(samples)
return samples
常见需要覆盖 generate_group() 的场景:
Judger 需要一次性处理同一个 prompt 的多条 response。
组内样本共享同一个外部环境或缓存。
需要在组内做过滤、排序或重采样。
不希望组内样本并发执行。
支持 partial rollout#
如果训练使用 AsyncProduceStrategyConfig(enable_partial_rollout=True),producer 会把 enable_partial_rollout 作为运行时参数传给 generate_group(),再传入 generate_sample()。
自定义 AgentLoop 可以直接复用 PartialRolloutHandler:
from xtuner.v1.rl.agent_loop.utils import PartialRolloutHandler
class CustomAgentLoop(AgentLoop):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.partial_rollout_handler = PartialRolloutHandler(
max_tokens=self.sample_params.max_tokens
)
async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState:
enable_partial_rollout = kwargs.get("enable_partial_rollout", False)
rollout_state = self.partial_rollout_handler.preprocess(
rollout_state,
enable_partial_rollout,
)
...
rollout_state = self.partial_rollout_handler.postprocess(rollout_state)
return rollout_state
如果任务的多轮上下文、工具结果或环境状态不能用默认 handler 合并,需要自己定义续跑逻辑。核心原则是:续跑后的 response_ids、response、logprobs、response_mask 必须仍然表示完整 response,而不是只有本轮新增部分。
在训练配置中使用#
训练配置中通常不手动实例化 AgentLoop,而是把 config 挂到 TaskSpecConfig.agent_loop_config:
from xtuner.v1.rl.agent_loop_manager import (
AgentLoopManagerConfig,
SamplerConfig,
SyncProduceStrategyConfig,
TaskSpecConfig,
)
agent_loop_config = CustomAgentLoopConfig(
hf_checkpoint=model_path,
sample_params=training_sample_params,
)
agent_loop_manager_cfg = AgentLoopManagerConfig(
tasks=TaskSpecConfig(
task_name="train_task",
agent_loop_config=agent_loop_config,
judger_config=judger_config,
produce_strategy_config=SyncProduceStrategyConfig(),
sampler_config=sampler_config,
),
)
AgentLoopManagerConfig.build() 会根据 agent_loop_config 构建 AgentLoop,根据 judger_config 构建 Judger,再把它们和 sampler、producer strategy 组装成 task runner。多任务训练时,每个 TaskSpecConfig 都可以使用不同的 AgentLoop。
自定义 Checklist#
实现自定义 AgentLoop 时,建议逐项检查:
generate_sample()是否只处理单条RolloutState。推理前是否设置了
rollout_state.tokens。每次调用
rollout_ctl.generate.remote()前是否设置了本轮sample_params。返回训练前,
response_ids、response、logprobs、response_mask是否完整且长度一致。非模型生成 token 是否在
response_mask中置为0。需要 Judger 时,是否通过
self.run_judger(...)调用打分,以复用 pause/cancel 处理。若使用
_prepare_train_data(),是否保证最终有reward["score"]。若使用 async partial rollout,是否正确处理
enable_partial_rollout和历史 response 合并。