AgentLoop#

AgentLoop 是 XTuner RL 中最常需要按任务自定义的模块。它定义“一组样本如何完成 rollout”:如何组织推理输入、调用几次推理引擎、是否插入工具或环境反馈、哪些 token 参与训练,以及什么时候调用 Judger 写入 reward。

在默认训练链路中,AgentLoop 位于 sampler 和 replay buffer 之间:

Sampler -> list[RolloutState]
  -> AgentLoop.generate_group()
  -> RolloutController.generate()
  -> Judger.judge() / Judger.batch_judge()
  -> ReplayBuffer
  -> RLTrainer._prepare_train_data()

如果你的任务只是单轮问答,通常直接使用预置的 SingleTurnAgentLoop。如果任务包含多轮交互、工具调用、环境 step、特殊终止条件、非模型 token 插入或自定义 response mask,就应该自定义 AgentLoop。

类型与构建#

xtuner/v1/rl/agent_loop/agent_loop.py 中有两个核心抽象:

  • AgentLoopConfig:配置对象,负责构建本地或 Ray actor 形式的 AgentLoop。

  • AgentLoop:运行时对象,负责实现 generate_sample()generate_group()

整体关系如下:

	                           ┌─────────────────────────────┐
	                           │       AgentLoopConfig        │
	                           │ hf_checkpoint, sample_params │
	                           │ cpu_resources                 │
	                           └──────────────┬──────────────┘
	                                          │ build(...)
	             ┌────────────────────────────┼────────────────────────────┐
	             │                            │                            │
	    cpu_resources = None       cpu_resources.num_workers = 1  cpu_resources.num_workers > 1
             │                            │                            │
             ▼                            ▼                            ▼
   ┌─────────────────┐          ┌─────────────────┐          ┌─────────────────┐
   │   AgentLoop     │          │ AgentLoopActor  │          │ RouterAgentLoop │
   │ 本地执行         │          │ Ray actor        │          │ 多 actor 路由     │
   └─────────────────┘          └────────┬────────┘          └────────┬────────┘
                                          │                            │
                                          ▼                            ▼
                                  ┌─────────────────┐        ┌─────────────────┐
                                  │   AgentLoop     │        │ AgentLoopActor  │
                                  │ actor 内部实例   │        │ ...             │
                                  └─────────────────┘        └─────────────────┘

AgentLoop 统一暴露两个异步接口:

async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState:
    ...

async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]:
    ...

generate_sample() 处理单条样本;默认 generate_group() 会并发调用多次 generate_sample(),并在调用前把 self.sample_params 写到组内每条样本上。若配置了 enable_batch_judge=True,默认 generate_group() 会在组内样本生成完成后调用一次 self.run_judger(group_samples)。需要其他组级逻辑时,例如组内共享环境、组内排序过滤,可以覆盖 generate_group()

输入输出约定#

AgentLoop 输入和输出都是 RolloutState。如果后续使用预置 RLTrainer._prepare_train_data(),自定义 AgentLoop 必须维护好训练所需字段。

输入#

Sampler 传给 AgentLoop 的 RolloutState 通常包含:

  • message:原始对话消息。

  • prompt_ids:tokenized prompt,通常由 RL tokenize function 写入。

  • reward_model:标签信息,例如 {"ground_truth": ...},供 Judger 使用。

  • sample_params:会在 generate_group() 中被 AgentLoop 的默认采样参数覆盖。

  • task_nameuidsession_uid 等调度字段。

生成前,AgentLoop 需要确保:

  • rollout_state.tokens 是实际传给 RolloutController.generate() 的输入 token。单轮任务通常设为 prompt_ids;多轮任务通常设为历史上下文拼接后的 token。

  • rollout_state.sample_params 是本次推理使用的参数。多轮任务里每一轮可能需要更新 max_tokens

输出#

AgentLoop 返回的 RolloutState 如果要进入训练,至少需要满足:

  • status == Status.COMPLETED。预置 trainer 会跳过 ABORTEDFILTEREDFAILED 的样本组。

  • response_ids 非空。_prepare_train_data() 用它构造训练 token。

  • response 非空。Judger 和轨迹保存依赖文本 response。

  • reward["score"] 存在。_prepare_train_data() 会直接读取它计算 advantage。

如果提供以下字段,还需要满足长度约定:

  • logprobs:长度必须等于 len(response_ids)

  • response_mask:长度必须等于 len(response_ids)。mask 为 0 的 token 会被转成训练 label -100,对应 advantage 也会置为 0.0

这也是自定义 AgentLoop 最容易出错的地方:工具返回、环境反馈、系统插入内容等不是模型生成的 token,应该在 response_mask 中置为 0,并给对应 logprobs0.0

SingleTurnAgentLoop#

SingleTurnAgentLoop 是默认单轮问答实现,适用于“给定 prompt,模型生成一次 response,然后打分”的任务。

单条样本流程如下:

generate_sample(state)
  -> PartialRolloutHandler.preprocess(state, enable_partial_rollout)
  -> 如果 state.tokens 为空,则 state.tokens = state.prompt_ids
  -> await rollout_ctl.generate.remote(state)
  -> PartialRolloutHandler.postprocess(state)
  -> 如果 state.status != COMPLETED,直接返回,不触发 Judger
  -> 如果配置了 judger,调用 self.run_judger(state)

典型配置:

from xtuner.v1.data_proto.rl_data import SampleParams
from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig

agent_loop_config = SingleTurnAgentLoopConfig(
    hf_checkpoint=model_path,
    sample_params=SampleParams(
        max_tokens=1024,
        top_k=0,
        top_p=1.0,
        temperature=1.0,
        min_tokens=0,
    ),
)

AgentLoopConfig 还支持批量打分:

agent_loop_config = SingleTurnAgentLoopConfig(
    hf_checkpoint=model_path,
    sample_params=training_sample_params,
    enable_batch_judge=True,
)

开启后,generate_sample() 不会逐条调用 Judger;generate_group() 会在组内样本全部生成完成后调用一次 self.run_judger(group_samples),内部会转到 judger.batch_judge(group_samples)。只有当前 Judger 明确实现 batch_judge() 时才应开启。

自定义 AgentLoop#

自定义 AgentLoop 通常需要做四件事:

  1. 继承 AgentLoop,实现 generate_sample()

  2. generate_sample() 中维护 tokenssample_paramsresponse_idsresponselogprobsresponse_maskstatus

  3. 需要打分时,在 response 可用后调用 self.run_judger(...)

  4. 继承 AgentLoopConfig,实现 build_local(),这样才能接入 TaskSpecConfig.agent_loop_config,并复用 Ray actor 构建逻辑。

self.run_judger(...) 会根据输入形态调用 judger.judge()judger.batch_judge(),并统一处理 pause/cancel。若自定义 AgentLoop 支持 enable_batch_judge=Truegenerate_sample() 中的单条打分需要用 not self.enable_batch_judge 保护,避免默认 generate_group() 再次批量打分。若自定义 AgentLoop 需要覆盖 pause(),应在实现中调用 await super().pause(),否则正在执行的 Judger 任务无法复用基类的中断逻辑。

最小单轮实现#

下面是一个最小可用版本,行为接近 SingleTurnAgentLoop

from xtuner.v1.data_proto.rl_data import RolloutState, Status
from xtuner.v1.rl.agent_loop import AgentLoop, AgentLoopConfig
from xtuner.v1.rl.judger import Judger

class CustomAgentLoop(AgentLoop):
    async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState:
        if not rollout_state.tokens:
            rollout_state.tokens = rollout_state.prompt_ids

        rollout_state.sample_params = rollout_state.sample_params or self.sample_params
        rollout_state = await self.rollout_ctl.generate.remote(rollout_state)

        if rollout_state.status != Status.COMPLETED:
            return rollout_state

        if self.judger is not None and not self.enable_batch_judge:
            rollout_state = await self.run_judger(rollout_state)
        return rollout_state


class CustomAgentLoopConfig(AgentLoopConfig):
    def build_local(self, rollout_controller, judger: Judger | None = None, logger=None) -> CustomAgentLoop:
        return CustomAgentLoop(
            rollout_ctl=rollout_controller,
            sample_params=self.sample_params,
            hf_checkpoint=self.hf_checkpoint,
            judger=judger,
            logger=logger,
        )

这个版本适合没有工具、没有环境反馈、没有多轮上下文拼接的任务。只要 RolloutController.generate() 能正确写入 response_idsresponselogprobsstatus,后续 Judger 与训练数据准备就能复用默认链路。

多轮或工具调用实现#

多轮任务通常需要循环调用 rollout_ctl.generate(),每轮把上轮输出和工具或环境结果追加到下一轮输入。可以参考 GSM8KToolAgentLoop 的模式:

class ToolAgentLoop(AgentLoop):
    async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState:
        final_response_ids: list[int] = []
        final_logprobs: list[float] = []
        final_response_mask: list[int] = []

        cur_tokens = list(rollout_state.tokens or rollout_state.prompt_ids or [])
        remaining_tokens = self.sample_params.max_tokens

        for _ in range(self.max_turns):
            rollout_state.tokens = cur_tokens
            rollout_state.sample_params = self.sample_params.model_copy(
                update={"max_tokens": remaining_tokens}
            )

            rollout_state = await self.rollout_ctl.generate.remote(rollout_state)
            if rollout_state.status != Status.COMPLETED:
                break

            response_ids = list(rollout_state.response_ids or [])
            logprobs = list(rollout_state.logprobs or [])
            assert len(response_ids) == len(logprobs)

            final_response_ids.extend(response_ids)
            final_logprobs.extend(logprobs)
            final_response_mask.extend([1] * len(response_ids))
            cur_tokens.extend(response_ids)

            tool_tokens = self._run_tool_and_encode_result(rollout_state)
            if not tool_tokens:
                break

            final_response_ids.extend(tool_tokens)
            final_logprobs.extend([0.0] * len(tool_tokens))
            final_response_mask.extend([0] * len(tool_tokens))
            cur_tokens.extend(tool_tokens)

            remaining_tokens = self.sample_params.max_tokens - len(final_response_ids)
            if remaining_tokens <= 0:
                break

        rollout_state.response_ids = final_response_ids[: self.sample_params.max_tokens]
        rollout_state.logprobs = final_logprobs[: self.sample_params.max_tokens]
        rollout_state.response_mask = final_response_mask[: self.sample_params.max_tokens]
        rollout_state.response = self.tokenizer.decode(rollout_state.response_ids)

        assert len(rollout_state.response_ids) == len(rollout_state.logprobs)
        assert len(rollout_state.response_ids) == len(rollout_state.response_mask)

        if rollout_state.status == Status.COMPLETED and self.judger is not None and not self.enable_batch_judge:
            rollout_state = await self.run_judger(rollout_state)
        return rollout_state

这个例子强调两个约定:

  • 模型生成 token 的 response_mask1

  • 工具或环境插入 token 的 response_mask0logprobs0.0

覆盖 generate_group#

默认 generate_group() 会并发处理组内样本。如果你的任务需要组级逻辑,可以覆盖它:

async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]:
    samples = await super().generate_group(rollout_state, **kwargs)

    # 例:在默认并发生成和可选 batch judge 之后,继续执行组级过滤或排序。
    samples = self.filter_or_sort_group(samples)
    return samples

常见需要覆盖 generate_group() 的场景:

  • Judger 需要一次性处理同一个 prompt 的多条 response。

  • 组内样本共享同一个外部环境或缓存。

  • 需要在组内做过滤、排序或重采样。

  • 不希望组内样本并发执行。

支持 partial rollout#

如果训练使用 AsyncProduceStrategyConfig(enable_partial_rollout=True),producer 会把 enable_partial_rollout 作为运行时参数传给 generate_group(),再传入 generate_sample()

自定义 AgentLoop 可以直接复用 PartialRolloutHandler

from xtuner.v1.rl.agent_loop.utils import PartialRolloutHandler

class CustomAgentLoop(AgentLoop):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.partial_rollout_handler = PartialRolloutHandler(
            max_tokens=self.sample_params.max_tokens
        )

    async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState:
        enable_partial_rollout = kwargs.get("enable_partial_rollout", False)
        rollout_state = self.partial_rollout_handler.preprocess(
            rollout_state,
            enable_partial_rollout,
        )
        ...
        rollout_state = self.partial_rollout_handler.postprocess(rollout_state)
        return rollout_state

如果任务的多轮上下文、工具结果或环境状态不能用默认 handler 合并,需要自己定义续跑逻辑。核心原则是:续跑后的 response_idsresponselogprobsresponse_mask 必须仍然表示完整 response,而不是只有本轮新增部分。

在训练配置中使用#

训练配置中通常不手动实例化 AgentLoop,而是把 config 挂到 TaskSpecConfig.agent_loop_config

from xtuner.v1.rl.agent_loop_manager import (
    AgentLoopManagerConfig,
    SamplerConfig,
    SyncProduceStrategyConfig,
    TaskSpecConfig,
)

agent_loop_config = CustomAgentLoopConfig(
    hf_checkpoint=model_path,
    sample_params=training_sample_params,
)

agent_loop_manager_cfg = AgentLoopManagerConfig(
    tasks=TaskSpecConfig(
        task_name="train_task",
        agent_loop_config=agent_loop_config,
        judger_config=judger_config,
        produce_strategy_config=SyncProduceStrategyConfig(),
        sampler_config=sampler_config,
    ),
)

AgentLoopManagerConfig.build() 会根据 agent_loop_config 构建 AgentLoop,根据 judger_config 构建 Judger,再把它们和 sampler、producer strategy 组装成 task runner。多任务训练时,每个 TaskSpecConfig 都可以使用不同的 AgentLoop。

自定义 Checklist#

实现自定义 AgentLoop 时,建议逐项检查:

  • generate_sample() 是否只处理单条 RolloutState

  • 推理前是否设置了 rollout_state.tokens

  • 每次调用 rollout_ctl.generate.remote() 前是否设置了本轮 sample_params

  • 返回训练前,response_idsresponselogprobsresponse_mask 是否完整且长度一致。

  • 非模型生成 token 是否在 response_mask 中置为 0

  • 需要 Judger 时,是否通过 self.run_judger(...) 调用打分,以复用 pause/cancel 处理。

  • 若使用 _prepare_train_data(),是否保证最终有 reward["score"]

  • 若使用 async partial rollout,是否正确处理 enable_partial_rollout 和历史 response 合并。