# Source code for xtuner.v1.rl.evaluator

from collections.abc import Mapping
from typing import Annotated, Protocol, cast, runtime_checkable

from cyclopts import Parameter
from pydantic import BaseModel, ConfigDict, Field

from xtuner.v1.data_proto.rl_data import RolloutState


@runtime_checkable
class ComputeMetricProtocol(Protocol):
    """Structural type for metric callables accepted by the evaluator.

    Any callable that takes a list of ``RolloutState`` samples and returns a
    dictionary mapping metric names to float values satisfies this protocol.
    Because the protocol is runtime-checkable, ``isinstance`` can be used
    against it; that check only verifies that ``__call__`` exists, not its
    signature.
    """

    def __call__(self, samples: list[RolloutState]) -> dict[str, float]:
        """Compute metrics for ``samples`` and return them keyed by metric name."""


def default_compute_metric_func(samples: list[RolloutState]) -> dict[str, float]:
    """Compute accuracy as the share of samples with a strictly positive score.

    Each sample's ``reward`` must be a mapping that contains a ``"score"``
    entry; a sample counts as correct when that score is greater than zero.

    Args:
        samples: Rollout states to summarize.

    Returns:
        A single-entry dictionary ``{"accuracy": value}``. ``value`` is ``0.0``
        for an empty input, otherwise the number of positive-score samples
        divided by the total number of samples.

    Raises:
        AssertionError: If a sample's ``reward`` is not a ``Mapping``.
    """
    if not samples:
        # Avoid dividing by zero: an empty evaluation set reports zero accuracy.
        return {"accuracy": 0.0}

    def _has_positive_score(sample) -> bool:
        # The assert doubles as type narrowing for the reward field.
        reward_info = sample.reward
        assert isinstance(reward_info, Mapping)
        return reward_info["score"] > 0

    hits = sum(1 for sample in samples if _has_positive_score(sample))
    return {"accuracy": hits / len(samples)}


class Evaluator:
    """Summarize rollout samples into metrics with a pluggable metric function."""

    def __init__(
        self,
        compute_metric_func: ComputeMetricProtocol | None = None,
        eval_batch_size: int = 0,
    ):
        """Store the metric function and the evaluation batch size.

        Args:
            compute_metric_func: Callable that turns a flat list of rollout
                states into a metrics dictionary. When omitted (or falsy),
                ``default_compute_metric_func`` is used.
            eval_batch_size: Number of samples to evaluate. It is only stored
                on the instance here; ``run`` does not read it.
        """
        self.eval_batch_size = eval_batch_size
        self.compute_metric_func = (
            compute_metric_func if compute_metric_func else default_compute_metric_func
        )

    def run(self, samples: list[RolloutState] | list[list[RolloutState]]) -> dict[str, float]:
        """Compute metrics over ``samples``, flattening one level of nesting.

        Args:
            samples: Either a flat list of rollout states or a list of batches
                (lists) of rollout states. Only the first element is inspected
                to decide which of the two shapes was passed.

        Returns:
            The dictionary produced by ``compute_metric_func`` for the flat
            list of samples.
        """
        is_nested = bool(samples) and isinstance(samples[0], list)
        if not is_nested:
            return self.compute_metric_func(cast(list[RolloutState], samples))

        # Convert list[list[RolloutState]] into list[RolloutState].
        flattened: list[RolloutState] = []
        for batch in cast(list[list[RolloutState]], samples):
            flattened.extend(batch)
        return self.compute_metric_func(flattened)


class EvaluatorConfig(BaseModel):
    """Configuration for rollout evaluation.

    ``EvaluatorConfig`` controls how many generated samples are selected for
    evaluation and which metric function is used to summarize them. It is used
    by RL trainers when evaluation is enabled.

    Args:
        eval_sample_ratio (float): Ratio of generated samples to evaluate when
            ``eval_sample_num`` is not set. Defaults to 0.
        eval_sample_num (int): Fixed number of samples to evaluate. A positive
            value takes precedence over ``eval_sample_ratio``. Defaults to 0.
        compute_metric_func (ComputeMetricProtocol | None): Optional function
            that receives evaluated rollout states and returns metrics.
            Defaults to None.

    **Examples:**

    Example evaluator using a fixed sample count::

        config = EvaluatorConfig(
            eval_sample_num=128,
            compute_metric_func=compute_metrics,
        )
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
    eval_sample_ratio: Annotated[
        float,
        Parameter(help="Ratio of samples to evaluate from the generated samples."),
    ] = 0
    eval_sample_num: Annotated[
        int,
        Parameter(help="Number of samples to evaluate from the generated samples."),
    ] = 0
    compute_metric_func: Annotated[
        ComputeMetricProtocol | None,
        Field(exclude=True),
        Parameter(help="An optional metric computation function."),
    ] = None

    def build(self, total_eval_samples: int = 0) -> "Evaluator":
        """Create an ``Evaluator`` with the evaluation batch size resolved.

        The batch size is chosen in this order:

        1. ``eval_sample_num`` when it is positive.
        2. ``int(total_eval_samples * eval_sample_ratio)`` when the ratio is
           positive (the product is truncated toward zero).
        3. ``total_eval_samples`` otherwise.

        Args:
            total_eval_samples: Total number of samples available for
                evaluation. Must be positive whenever ``eval_sample_num`` is
                not positive.

        Returns:
            An ``Evaluator`` configured with ``compute_metric_func`` and the
            resolved ``eval_batch_size``.

        Raises:
            AssertionError: If ``eval_sample_num`` is not positive and
                ``total_eval_samples`` is not positive either.
        """
        if self.eval_sample_num > 0:
            # A fixed sample count wins over any ratio.
            eval_batch_size = self.eval_sample_num
        else:
            assert total_eval_samples > 0, (
                "Total eval samples must be greater than 0 if eval sample num is not provided"
            )
            if self.eval_sample_ratio > 0:
                eval_batch_size = int(total_eval_samples * self.eval_sample_ratio)
            else:
                # Neither a count nor a ratio was given: evaluate everything.
                eval_batch_size = total_eval_samples
        return Evaluator(
            compute_metric_func=self.compute_metric_func,
            eval_batch_size=eval_batch_size,
        )