Source code for xtuner.v1.rl.agent_loop.single_turn_agent_loop

from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status
from xtuner.v1.rl.judger import Judger
from xtuner.v1.rl.rollout import RolloutController

from .agent_loop import AgentLoop, AgentLoopConfig


class SingleTurnAgentLoopConfig(AgentLoopConfig):
    """Configuration for the built-in single-turn agent loop.

    ``SingleTurnAgentLoopConfig`` runs one model generation for each input
    ``RolloutState`` and optionally sends the completed output to a judger.
    It is the default choice for math, QA, and other single-response RL tasks.

    Args:
        sample_params (SampleParams): Sampling parameters used by the rollout
            backend, such as temperature and maximum generation length.
        hf_checkpoint (str): Hugging Face checkpoint path used to identify the
            policy checkpoint for the agent loop.
        cpu_resources (CPUResourcesConfig | None): PG-external CPU resources
            used to run this agent loop as Ray actors. ``None`` runs the loop
            in local mode. Defaults to None.
        enable_batch_judge (bool): Whether to judge a generated group in one
            batch in ``generate_group``. Defaults to False.

    **Examples:**

    Example configuration for a single-turn task::

        config = SingleTurnAgentLoopConfig(
            sample_params=SampleParams(max_tokens=1024, temperature=1.0),
            hf_checkpoint="Qwen/Qwen3-8B",
            enable_batch_judge=True,
        )
    """

    def build_local(self, rollout_controller, judger: Judger | None = None, logger=None) -> "SingleTurnAgentLoop":
        """Construct a ``SingleTurnAgentLoop`` directly from this config.

        Args:
            rollout_controller: Handle passed to the loop as ``rollout_ctl``.
                NOTE(review): the loop calls ``.generate.remote(...)`` on it,
                so this is presumably a Ray actor handle; confirm with callers.
            judger (Judger | None): Optional judger used to score completed
                samples. Defaults to None.
            logger: Forwarded unchanged to the loop. Defaults to None.

        Returns:
            SingleTurnAgentLoop: A loop built with this config's
            ``sample_params``, ``hf_checkpoint`` and ``enable_batch_judge``.
        """
        # Settings that come from the config itself.
        loop_kwargs = dict(
            rollout_ctl=rollout_controller,
            sample_params=self.sample_params,
            hf_checkpoint=self.hf_checkpoint,
            enable_batch_judge=self.enable_batch_judge,
        )
        # Runtime collaborators supplied by the caller.
        return SingleTurnAgentLoop(judger=judger, logger=logger, **loop_kwargs)
class SingleTurnAgentLoop(AgentLoop):
    """Agent loop that performs exactly one generation per ``RolloutState``.

    Completed samples are scored right away by the judger when one is
    configured and batch judging is disabled. When batch judging is enabled,
    per-sample scoring is skipped here; the design intent is that
    ``generate_group`` scores the whole group instead.
    """

    def __init__(
        self,
        rollout_ctl: RolloutController,
        sample_params: SampleParams,
        hf_checkpoint: str,
        judger: Judger | None = None,
        logger=None,
        enable_batch_judge: bool = False,
    ):
        """Pass every argument unchanged to ``AgentLoop.__init__``.

        Args:
            rollout_ctl (RolloutController): Rollout backend used for generation.
            sample_params (SampleParams): Sampling parameters for generation.
            hf_checkpoint (str): Hugging Face checkpoint path of the policy.
            judger (Judger | None): Optional judger. Defaults to None.
            logger: Optional logger. Defaults to None.
            enable_batch_judge (bool): If True, skip per-sample judging in
                ``generate_sample``. Defaults to False.
        """
        base_kwargs = {
            "rollout_ctl": rollout_ctl,
            "sample_params": sample_params,
            "hf_checkpoint": hf_checkpoint,
            "judger": judger,
            "logger": logger,
            "enable_batch_judge": enable_batch_judge,
        }
        super().__init__(**base_kwargs)

    async def generate_sample(
        self,
        rollout_state: RolloutState,
        **kwargs,
    ) -> RolloutState:
        """Generate a single response for ``rollout_state`` and optionally judge it.

        Args:
            rollout_state (RolloutState): Input sample. If its ``tokens`` are
                empty, they are seeded from ``prompt_ids`` before generation.
            **kwargs: Ignored by this implementation.

        Returns:
            RolloutState: The state returned by the rollout controller. If its
            status is ``Status.COMPLETED``, a judger is set, and batch judging
            is off, this is the state after ``run_judger``.
        """
        # Seed the token sequence from the prompt when it is still empty.
        if not rollout_state.tokens:
            rollout_state.tokens = rollout_state.prompt_ids

        # Call the inference engine. The generated output is written to
        # response_ids on the returned state.
        generated = await self.rollout_ctl.generate.remote(rollout_state)  # type: ignore[attr-defined]

        # Non-COMPLETED states (e.g. truncated or aborted) are returned without
        # scoring. With batch judging on, scoring is deferred to
        # generate_group instead of being done per sample here.
        judge_now = (
            generated.status == Status.COMPLETED
            and self.judger is not None
            and not self.enable_batch_judge
        )
        if judge_now:
            generated = await self.run_judger(generated)
        return generated