Source code for xtuner.v1.rl.agent_loop.single_turn_agent_loop
from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status
from xtuner.v1.rl.judger import Judger
from xtuner.v1.rl.rollout import RolloutController
from .agent_loop import AgentLoop, AgentLoopConfig
[docs]class SingleTurnAgentLoopConfig(AgentLoopConfig):
"""Configuration for the built-in single-turn agent loop.
``SingleTurnAgentLoopConfig`` runs one model generation for each input
``RolloutState`` and optionally sends the completed output to a judger. It
is the default choice for math, QA, and other single-response RL tasks.
Args:
sample_params (SampleParams): Sampling parameters used by the rollout
backend, such as temperature and maximum generation length.
hf_checkpoint (str): Hugging Face checkpoint path used to identify the
policy checkpoint for the agent loop.
cpu_resources (CPUResourcesConfig | None): PG-external CPU resources
used to run this agent loop as Ray actors. ``None`` runs the loop
in local mode. Defaults to None.
enable_batch_judge (bool): Whether to judge a generated group in one
batch in ``generate_group``. Defaults to False.
**Examples:**
Example configuration for a single-turn task::
config = SingleTurnAgentLoopConfig(
sample_params=SampleParams(max_tokens=1024, temperature=1.0),
hf_checkpoint="Qwen/Qwen3-8B",
enable_batch_judge=True,
)
"""
def build_local(self, rollout_controller, judger: Judger | None = None, logger=None) -> "SingleTurnAgentLoop":
return SingleTurnAgentLoop(
rollout_ctl=rollout_controller,
sample_params=self.sample_params,
hf_checkpoint=self.hf_checkpoint,
judger=judger,
logger=logger,
enable_batch_judge=self.enable_batch_judge,
)
class SingleTurnAgentLoop(AgentLoop):
def __init__(
self,
rollout_ctl: RolloutController,
sample_params: SampleParams,
hf_checkpoint: str,
judger: Judger | None = None,
logger=None,
enable_batch_judge: bool = False,
):
super().__init__(
rollout_ctl=rollout_ctl,
sample_params=sample_params,
hf_checkpoint=hf_checkpoint,
judger=judger,
logger=logger,
enable_batch_judge=enable_batch_judge,
)
async def generate_sample(
self,
rollout_state: RolloutState,
**kwargs,
) -> RolloutState:
if not rollout_state.tokens:
rollout_state.tokens = rollout_state.prompt_ids
# 推理引擎generate, 生成的结果会覆盖到 rollout_state.response_ids 上
rollout_state = await self.rollout_ctl.generate.remote(rollout_state) # type: ignore[attr-defined]
# 非 COMPLETED 状态(如被截断、放弃等)直接早退,不触发打分
if rollout_state.status != Status.COMPLETED:
return rollout_state
if self.judger is not None and not self.enable_batch_judge:
# 如果开启了批量打分,则在 generate_group 里统一打分,不在这里逐条打分
rollout_state = await self.run_judger(rollout_state)
return rollout_state