# Source code for xtuner.v1.rl.judger.native

"""
Judger 体系关系图
=================

                        ┌─────────────────┐
                        │     Judger      │  ← 所有 judger 的统一接口
                        │   judge(state)  │
                        │ batch_judge(list)│
                        └────────┬────────┘
                                 │ 继承
              ┌──────────────────┼───────────────────┐
              │                  │                   │
     ┌────────▼───────┐  ┌───────▼──────┐  ┌────────▼────────┐
     │  NativeJudger  │  │ RemoteJudger │  │   JudgerPool    │
     │                │  │              │  │                 │
     │ 本地执行        │  │ Ray Actor 代理│  │ 多副本负载均衡   │
     │ 调用 reward_fn │  │ 调用.remote() │  │ round-robin 分发│
     └────────────────┘  └───────┬──────┘  └────────┬────────┘
                                 │ 包含              │ 包含多个
                         ┌───────▼──────┐   ┌───────▼──────┐
                         │  JudgerActor │   │ RemoteJudger │
                         │  (Ray Actor) │   │  (同左)       │
                         │ 包装NativeJ  │   └──────────────┘
                         └───────┬──────┘
                                 │ 内部调用
                         ┌───────▼──────┐
                         │ NativeJudger │
                         └──────────────┘

     ┌──────────────────────────────────────┐
     │         ComposedJudger               │
     │                                      │
     │  data_source → 选 branch → judge     │
     │  merge_fn  → 合并多个 branch 的结果  │
     │                                      │
     │  branches: dict[str, Judger]         │
     │  (每个 branch 可以是上面任意一种)      │
     └──────────────────────────────────────┘

构建模式
--------
未配置 external CPU resources → NativeJudger                     (纯本地,无 Ray)

配置 external CPU resources 且 num_workers = 1 → RemoteJudger
                                          └─► JudgerActor (Ray Worker)
                                                  └─► NativeJudger

配置 external CPU resources 且 num_workers > 1 → JudgerPool
                                          ├─► RemoteJudger → JudgerActor → NativeJudger
                                          ├─► RemoteJudger → JudgerActor → NativeJudger
                                          └─► RemoteJudger → JudgerActor → NativeJudger

调用链示例(remote 模式,单条打分)
------------------------------------
AgentLoop
  └─► RemoteJudger.judge(state)
        ├─► preprocess(state)                 ← driver 侧提取轻量 payload
        └─► JudgerActor.judge_payload.remote(payload)
              └─► NativeJudger.judge_payload(payload)
                    └─► reward_handler(response, label, extra_info)

批量打分语义
------------
Judger.judge 只处理单条 RolloutState。需要批量打分时调用
Judger.batch_judge(list[RolloutState])。不是所有具体 judger 都支持 batch;
比如 NativeJudger 和 CompassVerifierV2 会在 batch_judge 入口直接报错。
"""

from __future__ import annotations

import asyncio
import inspect
from typing import Any, Callable, TypeAlias, cast

import httpx
from pydantic import BaseModel, ConfigDict, Field
from ray.actor import ActorClass, ActorProxy

from xtuner.v1.data_proto.rl_data import RolloutState
from xtuner.v1.rl.utils import CPUActorLauncher, CPUResourcesConfig
from xtuner.v1.utils.logger import get_logger
from xtuner.v1.utils.type_helper import ray_method


# Module-level logger obtained from xtuner's logging helper.
logger = get_logger()

# Lightweight dict passed to ``judge_payload``. ``Judger.preprocess`` in this
# module fills it with the keys "response", "label", "message", "status",
# "data_source" and "task_name".
JudgerPayload: TypeAlias = dict[str, Any]
# Either one payload (single-sample scoring) or a list of payloads (batch).
JudgerPayloadBatch: TypeAlias = JudgerPayload | list[JudgerPayload]
# Dict returned by a judger for one sample; ``Judger.postprocess`` stores it
# on ``RolloutState.reward``.
JudgerOutput: TypeAlias = dict[str, Any]
# Either one output or a list of outputs, mirroring ``JudgerPayloadBatch``.
JudgerOutputBatch: TypeAlias = JudgerOutput | list[JudgerOutput]


class Judger:
    """Common interface implemented by every judger.

    Scoring is split into three steps: ``preprocess`` turns a
    ``RolloutState`` into a lightweight payload, ``judge_payload`` scores the
    payload, and ``postprocess`` writes the score back onto the state.
    The base class provides the first and last step; ``judge_payload`` must
    be supplied by subclasses.
    """

    def __init__(self, judger_name: str | None = None):
        """Remember the logical judger name.

        Args:
            judger_name: Name reported by ``get_judger_name``. Any falsy
                value (``None`` or an empty string) falls back to the name
                of the concrete class.
        """
        self._judger_name = judger_name if judger_name else type(self).__name__

    def preprocess(self, rollout_state: RolloutState) -> JudgerPayload:
        """Extract the fields needed for scoring from ``rollout_state``.

        Returns:
            A dict with the keys ``response``, ``label``, ``message``,
            ``status``, ``data_source`` and ``task_name``. ``label`` is the
            ``"ground_truth"`` entry of ``reward_model``, or ``None`` when
            ``reward_model`` is missing or empty.
        """
        reward_model = rollout_state.reward_model
        if reward_model:
            label = reward_model.get("ground_truth")
        else:
            label = None
        return dict(
            response=rollout_state.response,
            label=label,
            message=rollout_state.message,
            status=rollout_state.status,
            data_source=rollout_state.data_source,
            task_name=rollout_state.task_name,
        )

    def postprocess(self, rollout_state: RolloutState, output: JudgerOutput) -> RolloutState:
        """Store ``output`` as the reward of ``rollout_state`` and return the same state."""
        rollout_state.reward = output
        return rollout_state

    async def judge(self, rollout_state: RolloutState) -> RolloutState:
        """Score a single rollout state.

        Raises:
            TypeError: If ``judge_payload`` answers with a list although a
                single payload was submitted.
        """
        result = await self.judge_payload(self.preprocess(rollout_state))
        if isinstance(result, list):
            raise TypeError("Judger returned a list output for a single rollout state.")
        return self.postprocess(rollout_state, result)

    async def batch_judge(self, rollout_states: list[RolloutState]) -> list[RolloutState]:
        """Score several rollout states with one ``judge_payload`` call.

        Raises:
            TypeError: If ``judge_payload`` answers with a single output.
            ValueError: If the number of outputs differs from the number of
                rollout states.
        """
        batch = list(map(self.preprocess, rollout_states))
        results = await self.judge_payload(batch)
        if not isinstance(results, list):
            raise TypeError(f"Judger returned a single output for {len(rollout_states)} rollout states.")
        if len(results) != len(rollout_states):
            raise ValueError(f"Judger returned {len(results)} outputs for {len(rollout_states)} rollout states.")
        # Outputs are matched to states by position.
        return [self.postprocess(state, result) for state, result in zip(rollout_states, results)]

    async def judge_payload(self, payload: JudgerPayloadBatch) -> JudgerOutputBatch:
        """Score a payload (or list of payloads); subclasses must override this."""
        raise NotImplementedError(f"{self.__class__.__name__}.judge_payload() is not implemented.")

    def get_judger_name(self) -> str:
        """Return the logical name given at construction time."""
        return self._judger_name


class NativeJudger(Judger):
    """Local judger implementation backed by a Python callable or HTTP
    endpoint.

    ``NativeJudger`` calls one reward handler for one rollout sample. It does
    not support ``batch_judge(list[RolloutState])``; callers that need grouped
    routing should use ``ComposedJudger`` or a judger implementation that
    explicitly supports batch payloads.

    Args:
        judger_name (str): Logical judger name. Defaults to "native_judger".
        reward_handler (Callable | str | None): A string is used as the URL
            of an HTTP endpoint that receives the inputs as a JSON POST
            body; a callable is invoked with the keyword arguments
            ``response``, ``label`` and ``extra_info``. Defaults to None.
        extra_info (dict | None): Static data forwarded to the handler as
            ``extra_info``. Defaults to an empty dict.
        request_timeout (float): Timeout in seconds handed to the HTTP
            client. Defaults to 30.0.
    """

    def __init__(
        self,
        judger_name: str = "native_judger",
        reward_handler: Callable | str | None = None,
        extra_info: dict | None = None,
        request_timeout: float = 30.0,
    ):
        super().__init__(judger_name=judger_name)
        self.extra_info = extra_info or {}
        self.reward_handler = reward_handler
        self.request_timeout = request_timeout

    async def batch_judge(self, rollout_states: list[RolloutState]) -> list[RolloutState]:
        """Always raise: this judger scores one sample per call."""
        raise NotImplementedError("NativeJudger does not support batch_judge.")

    async def judge_payload(self, payload: JudgerPayloadBatch) -> JudgerOutputBatch:
        """Score a single payload with the configured reward handler.

        Args:
            payload: One payload dict holding at least the keys ``response``
                and ``label``, both non-None.

        Returns:
            The dict produced by the reward handler (or decoded from the
            JSON body of the HTTP response).

        Raises:
            AssertionError: If ``payload`` is a list, if ``response`` or
                ``label`` is None, or if the handler yields nothing or a
                non-dict value. This also covers a ``reward_handler`` that
                is neither a string nor a callable.

        Errors raised by the HTTP client (including ``raise_for_status``)
        and by the reward handler itself are not caught here.
        """
        # Validation is raised explicitly instead of using ``assert`` so the
        # checks are still enforced when Python runs with ``-O``; the
        # exception type and messages are unchanged.
        if isinstance(payload, list):
            raise AssertionError("NativeJudger does not support batch payloads.")
        if payload["response"] is None:
            raise AssertionError(
                "RolloutState must have a response for judging. You should detokenize the response_ids in AgentLoop"
            )
        if payload["label"] is None:
            raise AssertionError(
                "RolloutState must have reward_model with 'ground_truth' for judging. You should set reward_model in "
                "AgentLoop"
            )
        input_kwargs = {
            "response": payload["response"],
            "label": payload["label"],
            # Shallow copy so a handler cannot replace keys of the shared dict.
            "extra_info": {**self.extra_info},
        }

        judger_response = None
        handler = self.reward_handler
        if isinstance(handler, str):
            # TODO: What should happen on a timeout or an error status code?
            # TODO: Wrapping this in try/except is awkward because it is
            # unclear whether a failure should score -1 or 0.
            async with httpx.AsyncClient(timeout=self.request_timeout) as client:
                response = await client.post(handler, json=input_kwargs)
                response.raise_for_status()
                judger_response = response.json()
        elif callable(handler):
            judger_response = handler(**input_kwargs)
            # Await whatever awaitable the handler produced. Checking the
            # result rather than the function covers ``async def`` handlers
            # as well as callable objects with an async ``__call__`` and
            # sync wrappers that return a coroutine.
            if inspect.isawaitable(judger_response):
                judger_response = await judger_response

        if judger_response is None:
            raise AssertionError("Reward handler did not return a response.")
        if not isinstance(judger_response, dict):
            raise AssertionError(f"Reward handler must return a dict, but got {type(judger_response)}.")
        return cast(JudgerOutput, judger_response)


class RemoteJudger(Judger):
    """Driver-side stand-in for a judger hosted in a Ray actor.

    The proxy exposes the regular ``Judger`` interface, so callers keep
    passing ``RolloutState`` objects to ``judge``. ``preprocess`` runs on the
    driver, using the same implementation as the actor-side judger, and only
    the resulting lightweight payload is sent to ``JudgerActor``. The actor
    lives in the Ray worker process and owns the real local judger that runs
    ``judge_payload``; whether batches are supported depends on that judger.

    Args:
        actor: Handle of the Ray actor whose ``judge_payload`` is invoked.
        judger_name: Logical judger name.
        preprocess_judger: Optional judger whose ``preprocess`` is used
            instead of the base implementation.
    """

    def __init__(self, actor: RayJudgerProxy, judger_name: str, preprocess_judger: Judger | None = None):
        super().__init__(judger_name=judger_name)
        self.preprocess_judger = preprocess_judger
        self.actor = actor

    def preprocess(self, rollout_state: RolloutState) -> JudgerPayload:
        """Build the payload on the driver before the Ray call.

        Doing this remotely would serialize the whole ``RolloutState`` to
        the actor, and remote branches with custom preprocess logic would
        lose the payload fields they require.
        """
        delegate = self.preprocess_judger
        if delegate is not None:
            return delegate.preprocess(rollout_state)
        return super().preprocess(rollout_state)

    async def judge_payload(self, payload: JudgerPayloadBatch) -> JudgerOutputBatch:
        """Forward ``payload`` to the actor and wait for its result."""
        pending = self.actor.judge_payload.remote(payload)
        return await pending


class JudgerPool(Judger):
    """Spread ``judge_payload`` calls over replicas in round-robin order.

    Besides the rotation cursor, the pool keeps a per-replica counter of
    calls currently in flight, which ``get_worker_status`` reports.

    Args:
        replicas: Non-empty list of judgers to dispatch to.
        judger_name: Logical judger name, also used as the prefix of the
            status keys.

    Raises:
        ValueError: If ``replicas`` is empty.
    """

    def __init__(self, replicas: list[Judger], judger_name: str):
        super().__init__(judger_name=judger_name)
        if len(replicas) == 0:
            raise ValueError("JudgerPool requires at least one replica.")
        self.replicas = replicas
        self._rr_index = 0
        self._lock = asyncio.Lock()
        # In-flight call count per replica index, all starting at zero.
        self._worker_loads = {idx: 0 for idx, _ in enumerate(replicas)}

    async def _pick_replica(self) -> tuple[int, Judger]:
        """Select the next replica in rotation and count one more call on it."""
        async with self._lock:
            pool_size = len(self.replicas)
            chosen = self._rr_index % pool_size
            self._rr_index = (self._rr_index + 1) % pool_size
            self._worker_loads[chosen] += 1
            return chosen, self.replicas[chosen]

    async def _release_replica(self, replica_idx: int) -> None:
        """Count one call less on the replica at ``replica_idx``."""
        async with self._lock:
            self._worker_loads[replica_idx] -= 1

    async def judge_payload(self, payload: JudgerPayloadBatch) -> JudgerOutputBatch:
        """Score ``payload`` on the next replica in rotation."""
        slot, judger = await self._pick_replica()
        try:
            result = await judger.judge_payload(payload)
            return result
        finally:
            # Runs on success and on failure so the counter cannot leak.
            await self._release_replica(slot)

    def get_worker_status(self) -> dict[str, int]:
        """Return the in-flight count per replica, keyed as ``"<name>[<index>]"``."""
        status: dict[str, int] = {}
        for idx, load in self._worker_loads.items():
            status[f"{self._judger_name}[{idx}]"] = load
        return status


class JudgerConfig(BaseModel):
    """Configuration for a native judger.

    ``JudgerConfig`` describes the reward logic and optionally names the
    external CPU resources used to run the judger as Ray actors. CPU
    quantities are declared by ``CPUResourcesConfig`` and validated by
    ``CPUResourceManager``.

    Args:
        judger_name (str): Logical judger name used in logs and reward output.
        reward_handler (Callable | str | None): Reward function or HTTP
            endpoint used to score a rollout. Defaults to None.
        request_timeout (float): Timeout in seconds for HTTP reward handlers.
            Defaults to 30.0.
        extra_info (dict): Extra static information passed to the reward
            handler. Defaults to an empty dict.
        cpu_resources (CPUResourcesConfig | None): External CPU resources
            for running the judger remotely. Defaults to None.

    **Examples:**

    Example local judger::

        config = JudgerConfig(
            judger_name="custom/math",
            reward_handler=compute_reward,
            extra_info={"score": 1.0},
        )

    Remote actor judgers are enabled by setting ``cpu_resources``.
    """

    # Allow non-pydantic field types and reject unknown keys.
    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

    judger_name: str
    # ``exclude=True`` leaves the field out of pydantic's serialized output.
    reward_handler: Callable | str | None = Field(default=None, exclude=True)
    request_timeout: float = 30.0
    extra_info: dict = Field(default_factory=dict, exclude=True)
    cpu_resources: CPUResourcesConfig | None = None

    def build_local(self) -> Judger:
        """Create a ``NativeJudger`` in the current process from this config."""
        return NativeJudger(
            judger_name=self.judger_name,
            reward_handler=self.reward_handler,
            request_timeout=self.request_timeout,
            extra_info=self.extra_info,
        )

    def build(self) -> Judger:
        """Create the judger through the factory's ``build_judger``."""
        # Imported at call time rather than module level; presumably to avoid
        # a circular import with the factory module -- TODO confirm.
        from .factory import build_judger

        return build_judger(self)
class JudgerActor:
    """Wrapper that owns a locally built judger and forwards payloads to it.

    This class is turned into a Ray actor class below (``RayJudger``).
    """

    def __init__(self, judger_config: JudgerConfig):
        """Build the wrapped judger via ``judger_config.build_local()``."""
        self.judger = judger_config.build_local()

    @ray_method
    async def judge_payload(self, payload: JudgerPayloadBatch) -> JudgerOutputBatch:
        """Score ``payload`` with the wrapped judger and return its output."""
        return await self.judger.judge_payload(payload)


# Actor class produced from ``JudgerActor`` by ``CPUActorLauncher``; the cast
# only informs static type checkers.
RayJudger = cast(ActorClass[JudgerActor], CPUActorLauncher.to_actor_class(JudgerActor))
# Type of a handle to a running ``JudgerActor``.
RayJudgerProxy: TypeAlias = ActorProxy[JudgerActor]