# Source code for xtuner.v1.rl.judger.composed

from __future__ import annotations

import asyncio
from typing import Callable, TypeAlias

from pydantic import BaseModel, ConfigDict, Field

from xtuner.v1.data_proto.rl_data import RolloutState

from .native import Judger, JudgerConfig, JudgerOutput


# Merge function contract for multi-branch composed judging.
#
# Arguments:
# - First: the original rollout state (when called from ``ComposedJudger.judge``) or the
#   original list of rollout states (when called from ``ComposedJudger.batch_judge``).
# - Second: a dict mapping each selected branch key to that branch's raw judger output:
#   a single ``JudgerOutput`` per key for ``judge``, or a list with one ``JudgerOutput``
#   per rollout state for ``batch_judge``.
#
# Return value:
# - Must have the same shape as the first argument (single state in, single state out;
#   list in, list out). ``ComposedJudger`` raises ``TypeError`` when the shape differs.
# - The returned state(s) are expected to have ``reward`` populated; ``ComposedJudger``
#   returns the merge result as-is and does not verify this.
JudgerMergeFn: TypeAlias = Callable[
    [RolloutState | list[RolloutState], dict[str, JudgerOutput | list[JudgerOutput]]],
    RolloutState | list[RolloutState],
]


class ComposedJudger(Judger):
    """Judger that routes rollout states to named branch judgers.

    The branch (or branches) used for a rollout state is chosen from
    ``rollout_state.data_source``:

    - a ``str`` selects exactly one branch by name;
    - a ``dict`` selects every branch named by one of its keys.

    When exactly one branch is selected, that branch's ``postprocess`` result is
    returned directly. When several branches are selected, their raw outputs are
    collected into a dict keyed by branch name and handed to ``merge_fn``, which
    produces the returned rollout state(s).

    Args:
        branches (dict[str, Judger]): Mapping from branch name to the judger that
            handles it. Must not be empty.
        merge_fn (JudgerMergeFn | None): Function that merges the outputs of
            several branches. May be ``None`` only if no rollout state ever
            selects more than one branch.

    Raises:
        ValueError: If ``branches`` is empty.
    """

    def __init__(
        self,
        branches: dict[str, Judger],
        merge_fn: JudgerMergeFn | None = None,
    ):
        super().__init__()
        if not branches:
            raise ValueError("ComposedJudger requires at least one branch.")
        self.branches = branches
        # ``merge_fn=None`` is only valid for routing mode, where each sample
        # selects exactly one branch and that branch's reward is passed through.
        # If ``data_source`` selects multiple branches, callers must provide an
        # explicit merge function because reward aggregation is task-specific.
        self.merge_fn = merge_fn

    def _select_keys_from_data_source(self, rollout_state: RolloutState) -> list[str]:
        """Return the branch names selected by ``rollout_state.data_source``.

        Args:
            rollout_state (RolloutState): State whose ``data_source`` is inspected.

        Returns:
            list[str]: A one-element list for a string ``data_source``; the dict's
            keys, in iteration order, for a dict ``data_source``.

        Raises:
            ValueError: If ``data_source`` is ``None`` or an empty dict.
            KeyError: If a selected name is not a key of ``self.branches``.
            TypeError: If ``data_source`` is neither ``str`` nor ``dict``, or a
                dict key is not a string.
        """
        data_source = rollout_state.data_source
        if data_source is None:
            raise ValueError(
                "ComposedJudger requires rollout_state.data_source to route judger branches. "
                f"task_name={rollout_state.task_name!r}, available={sorted(self.branches)}"
            )
        if isinstance(data_source, str):
            if data_source not in self.branches:
                raise KeyError(
                    f"Unknown judger branch from data_source: {data_source!r}, available={sorted(self.branches)}"
                )
            return [data_source]
        if isinstance(data_source, dict):
            if not data_source:
                raise ValueError("ComposedJudger data_source dict must contain at least one judger branch.")
            selected_keys = []
            # Only the keys take part in routing; the dict values are not read here.
            for key in data_source:
                if not isinstance(key, str):
                    raise TypeError(f"ComposedJudger data_source dict keys must be strings, got {key!r}.")
                if key not in self.branches:
                    raise KeyError(
                        f"Unknown judger branch from data_source: {key!r}, available={sorted(self.branches)}"
                    )
                selected_keys.append(key)
            return selected_keys

        raise TypeError(
            "ComposedJudger data_source must be a branch name string or a dict of branch names, "
            f"got {type(data_source).__name__}: {data_source!r}. "
            f"task_name={rollout_state.task_name!r}, available={sorted(self.branches)}"
        )

    def _require_merge_fn(self, selected_keys: list[str]) -> JudgerMergeFn:
        """Return ``self.merge_fn``, failing when multi-branch merging is impossible.

        Args:
            selected_keys (list[str]): Branch names that were selected; only used
                in the error message.

        Raises:
            ValueError: If ``self.merge_fn`` is ``None``.
        """
        if self.merge_fn is None:
            raise ValueError(
                "ComposedJudger selected multiple branches but merge_fn is not provided. "
                f"selected_keys={selected_keys!r}"
            )
        return self.merge_fn

    async def _judge_branch(
        self,
        key: str,
        rollout_state: RolloutState,
    ) -> tuple[str, JudgerOutput]:
        """Run one branch on one rollout state and return ``(key, raw_output)``.

        The state is passed through the branch's ``preprocess`` and the resulting
        payload through its ``judge_payload``; ``postprocess`` is not applied here.

        Raises:
            TypeError: If the branch returns a list for a single payload.
        """
        branch = self.branches[key]
        payload = branch.preprocess(rollout_state)
        output = await branch.judge_payload(payload)
        if isinstance(output, list):
            raise TypeError(f"Branch {key!r} returned a list output for one RolloutState.")
        return key, output

    async def _batch_judge_branch(
        self,
        key: str,
        rollout_states: list[RolloutState],
    ) -> tuple[str, list[JudgerOutput]]:
        """Run one branch on a list of rollout states and return ``(key, raw_outputs)``.

        Every state is preprocessed individually, then the whole payload list is
        sent to the branch's ``judge_payload`` in a single call.

        Raises:
            TypeError: If the branch does not return a list.
            ValueError: If the number of outputs differs from the number of states.
        """
        branch = self.branches[key]
        payloads = [branch.preprocess(state) for state in rollout_states]
        outputs = await branch.judge_payload(payloads)
        if not isinstance(outputs, list):
            raise TypeError(f"Branch {key!r} returned a single output for a rollout state list.")
        if len(outputs) != len(rollout_states):
            raise ValueError(f"Branch {key!r} returned {len(outputs)} outputs for {len(rollout_states)} states.")
        return key, outputs

    def _postprocess_branch_batch(
        self,
        branch: Judger,
        rollout_states: list[RolloutState],
        outputs: list[JudgerOutput],
    ) -> list[RolloutState]:
        """Apply ``branch.postprocess`` to each ``(state, output)`` pair, in order."""
        return [branch.postprocess(state, output) for state, output in zip(rollout_states, outputs)]

    async def judge(self, rollout_state: RolloutState) -> RolloutState:
        """Judge a single rollout state.

        Args:
            rollout_state (RolloutState): State to judge; its ``data_source``
                selects the branch(es).

        Returns:
            RolloutState: The selected branch's ``postprocess`` result when one
            branch is selected, otherwise the value returned by ``merge_fn``.

        Raises:
            ValueError: If several branches are selected and ``merge_fn`` is ``None``.
            TypeError: If ``merge_fn`` returns a list.
        """
        selected_keys = self._select_keys_from_data_source(rollout_state)

        if len(selected_keys) == 1:
            # Routing mode: pass the single branch's result straight through.
            key = selected_keys[0]
            _, output = await self._judge_branch(key, rollout_state)
            return self.branches[key].postprocess(rollout_state, output)

        # Checked before any branch runs so a misconfiguration fails fast.
        merge_fn = self._require_merge_fn(selected_keys)

        # ``asyncio.gather`` returns results in argument order, so the dict is
        # keyed in the same order as ``selected_keys``.
        judged: dict[str, JudgerOutput | list[JudgerOutput]] = dict(
            await asyncio.gather(*(self._judge_branch(key, rollout_state) for key in selected_keys))
        )
        merged = merge_fn(rollout_state, judged)
        if isinstance(merged, list):
            raise TypeError("ComposedJudger merge_fn returned a list for judge.")
        return merged

    async def batch_judge(self, rollout_states: list[RolloutState]) -> list[RolloutState]:
        """Judge a list of rollout states with one shared branch selection.

        The branch selection is computed from ``rollout_states[0]`` only and then
        applied to every state in the list; the ``data_source`` of the remaining
        states is not inspected.

        Args:
            rollout_states (list[RolloutState]): Non-empty list of states to judge.

        Returns:
            list[RolloutState]: Per-state ``postprocess`` results when one branch
            is selected, otherwise the list returned by ``merge_fn``.

        Raises:
            ValueError: If ``rollout_states`` is empty, or several branches are
                selected and ``merge_fn`` is ``None``.
            TypeError: If ``merge_fn`` does not return a list.
        """
        if not rollout_states:
            raise ValueError("ComposedJudger requires at least one RolloutState when input is a list.")
        # NOTE(review): presumably callers only batch states that share a
        # data_source; a mixed batch would be routed by its first element.
        selected_keys = self._select_keys_from_data_source(rollout_states[0])

        if len(selected_keys) == 1:
            key = selected_keys[0]
            _, outputs = await self._batch_judge_branch(key, rollout_states)
            return self._postprocess_branch_batch(self.branches[key], rollout_states, outputs)

        # Checked before any branch runs so a misconfiguration fails fast.
        merge_fn = self._require_merge_fn(selected_keys)

        judged: dict[str, JudgerOutput | list[JudgerOutput]] = dict(
            await asyncio.gather(*(self._batch_judge_branch(key, rollout_states) for key in selected_keys))
        )
        merged = merge_fn(rollout_states, judged)
        if not isinstance(merged, list):
            raise TypeError("ComposedJudger merge_fn returned a single RolloutState for batch_judge.")
        return merged


class ComposedJudgerConfig(BaseModel):
    """Configuration for composing multiple judgers.

    ``ComposedJudgerConfig`` routes rollout states through
    ``RolloutState.data_source``. A string value selects one branch and passes
    that branch output through as ``RolloutState.reward``. A dict value selects
    multiple branches by key and requires ``merge_fn`` to define the final
    reward shape.

    Args:
        branches (dict[str, JudgerConfig | ComposedJudgerConfig]): Mapping from
            branch name to judger configuration. Branch names must match
            ``RolloutState.data_source`` string values or dict keys.
        merge_fn (JudgerMergeFn | None): Function that merges multiple branch
            outputs into the returned rollout state. Required when
            ``data_source`` may select more than one branch. Leave as ``None``
            only when every sample selects exactly one branch.

    **Examples:**

    Example composed judger with two branches::

        config = ComposedJudgerConfig(
            branches={
                "math": GSM8KJudgerConfig(),
                "format": JudgerConfig(judger_name="format", reward_handler=format_reward),
            },
            merge_fn=merge_rewards,
        )
    """

    # ``arbitrary_types_allowed`` is set alongside the callable ``merge_fn`` field;
    # ``extra="forbid"`` rejects unknown configuration keys.
    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

    branches: dict[str, JudgerConfigLike]
    # Branch routing is fixed to ``RolloutState.data_source``:
    # - str selects one branch.
    # - dict keys select multiple branches.
    # ``merge_fn=None`` means single-branch pass-through only. If data_source
    # may select multiple branches, this must be set explicitly.
    # ``exclude=True`` keeps the callable out of serialized model dumps.
    merge_fn: JudgerMergeFn | None = Field(default=None, exclude=True)

    def build(self) -> Judger:
        """Build the judger described by this config via ``factory.build_judger``."""
        # Imported at call time rather than at module top; presumably this avoids
        # a circular import with the factory module (confirm against factory.py).
        from .factory import build_judger

        return build_judger(self)
# Any configuration accepted as a ``ComposedJudgerConfig`` branch: a plain judger
# config or another composed config (which allows nested composition).
JudgerConfigLike: TypeAlias = JudgerConfig | ComposedJudgerConfig

# ``ComposedJudgerConfig.branches`` is annotated with ``JudgerConfigLike``, which is
# only defined at this point, so the model is rebuilt to resolve that forward reference.
ComposedJudgerConfig.model_rebuild()