# Source code for xtuner.v1.rl.evaluator
from collections.abc import Mapping
from typing import Annotated, Protocol, cast, runtime_checkable
from cyclopts import Parameter
from pydantic import BaseModel, ConfigDict, Field
from xtuner.v1.data_proto.rl_data import RolloutState
@runtime_checkable
class ComputeMetricProtocol(Protocol):
    """Structural type for callables that summarize rollout states as metrics.

    Any callable that accepts ``samples: list[RolloutState]`` and returns a
    ``dict[str, float]`` mapping metric names to values fits this protocol.
    Because the class is ``runtime_checkable``, ``isinstance(obj,
    ComputeMetricProtocol)`` works at runtime; note that such a check only
    verifies that a ``__call__`` member exists, not its signature.
    """

    def __call__(self, samples: list[RolloutState]) -> dict[str, float]:
        """Return a mapping from metric name to metric value for ``samples``."""
def default_compute_metric_func(samples: list[RolloutState]) -> dict[str, float]:
    """Return the fraction of samples whose reward score is strictly positive.

    Each sample's ``reward`` is expected to be a mapping with a ``"score"``
    entry. An empty ``samples`` list yields an accuracy of ``0.0``.

    Args:
        samples (list[RolloutState]): Rollout states to score.

    Returns:
        dict[str, float]: ``{"accuracy": positive_count / len(samples)}``.

    Raises:
        AssertionError: If a sample's ``reward`` is not a ``Mapping`` (only
            when assertions are enabled).
        KeyError: If a reward mapping has no ``"score"`` entry.
    """
    if not samples:
        return {"accuracy": 0.0}

    def _scored_positive(state):
        reward = state.reward
        # The isinstance assert also narrows ``reward`` for type checkers.
        assert isinstance(reward, Mapping)
        return reward["score"] > 0

    hits = sum(1 for state in samples if _scored_positive(state))
    return {"accuracy": hits / len(samples)}
class Evaluator:
    """Compute evaluation metrics over a collection of rollout states.

    Args:
        compute_metric_func (ComputeMetricProtocol | None): Callable that maps a
            flat list of rollout states to a ``dict[str, float]`` of metrics.
            When ``None``, ``default_compute_metric_func`` is used.
        eval_batch_size (int): Evaluation batch size, stored on the instance
            for callers to read; ``run`` itself does not use it. Defaults to 0.
    """

    def __init__(
        self,
        compute_metric_func: ComputeMetricProtocol | None = None,
        eval_batch_size: int = 0,
    ):
        # Compare against None explicitly: ``or`` would also discard a
        # user-supplied callable object that happens to be falsy (e.g. one that
        # defines ``__bool__`` or ``__len__``).
        self.compute_metric_func = (
            compute_metric_func if compute_metric_func is not None else default_compute_metric_func
        )
        self.eval_batch_size = eval_batch_size

    def run(self, samples: list[RolloutState] | list[list[RolloutState]]) -> dict[str, float]:
        """Compute metrics over ``samples``.

        Args:
            samples (list[RolloutState] | list[list[RolloutState]]): Either a
                flat list of rollout states, or a list of batches (lists) of
                rollout states. Batches are flattened one level before the
                metric function is called.

        Returns:
            dict[str, float]: The metrics returned by ``compute_metric_func``.
        """
        if not any(isinstance(item, list) for item in samples):
            # Already flat (or empty): pass the caller's list through unchanged.
            return self.compute_metric_func(cast(list[RolloutState], samples))

        # Flatten list[list[RolloutState]] into list[RolloutState]. Every element
        # is checked rather than only the first, so a mixed input never hands a
        # raw batch list to the metric function or iterates a non-list sample.
        flat_samples: list[RolloutState] = []
        for item in samples:
            if isinstance(item, list):
                flat_samples.extend(item)
            else:
                flat_samples.append(item)
        return self.compute_metric_func(flat_samples)
class EvaluatorConfig(BaseModel):
    """Configuration for rollout evaluation.

    ``EvaluatorConfig`` controls how many generated samples are selected for
    evaluation and which metric function is used to summarize them. It is used
    by RL trainers when evaluation is enabled.

    :meth:`build` resolves the evaluation size in this order:

    1. ``eval_sample_num > 0``: exactly ``eval_sample_num`` samples.
    2. ``eval_sample_ratio > 0``: ``int(total_eval_samples * eval_sample_ratio)``
       samples, but never fewer than one.
    3. Otherwise: all ``total_eval_samples`` samples.

    Args:
        eval_sample_ratio (float): Ratio of generated samples to evaluate when
            ``eval_sample_num`` is not set. A non-positive value means "evaluate
            all samples". Defaults to 0.
        eval_sample_num (int): Fixed number of samples to evaluate. A positive
            value takes precedence over ``eval_sample_ratio``. Defaults to 0.
        compute_metric_func (ComputeMetricProtocol | None): Optional function
            that receives evaluated rollout states and returns metrics. It is
            excluded from serialization. Defaults to None.

    **Examples:**

    Example evaluator using a fixed sample count::

        config = EvaluatorConfig(
            eval_sample_num=128,
            compute_metric_func=compute_metrics,
        )
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

    eval_sample_ratio: Annotated[
        float,
        Parameter(help="Ratio of samples to evaluate from the generated samples."),
    ] = 0.0
    eval_sample_num: Annotated[
        int,
        Parameter(help="Number of samples to evaluate from the generated samples."),
    ] = 0
    compute_metric_func: Annotated[
        ComputeMetricProtocol | None,
        Field(exclude=True),
        Parameter(help="An optional metric computation function."),
    ] = None

    def build(self, total_eval_samples: int = 0) -> "Evaluator":
        """Create an :class:`Evaluator` with the resolved evaluation size.

        Args:
            total_eval_samples (int): Total number of generated samples
                available for evaluation. Required (and must be positive) only
                when ``eval_sample_num`` is not set. Defaults to 0.

        Returns:
            Evaluator: An evaluator using ``compute_metric_func`` and the
            resolved ``eval_batch_size``.

        Raises:
            ValueError: If ``eval_sample_num`` is not positive and
                ``total_eval_samples`` is not positive either.
        """
        if self.eval_sample_num > 0:
            eval_batch_size = self.eval_sample_num
        else:
            # Raise rather than ``assert`` so the check survives ``python -O``;
            # otherwise a zero or negative size would be passed on silently.
            if total_eval_samples <= 0:
                raise ValueError(
                    "Total eval samples must be greater than 0 if eval sample num is not provided"
                )
            if self.eval_sample_ratio > 0:
                # ``int()`` truncates, so a small ratio on a small pool could
                # round down to zero samples; always evaluate at least one.
                eval_batch_size = max(1, int(total_eval_samples * self.eval_sample_ratio))
            else:
                eval_batch_size = total_eval_samples
        return Evaluator(
            compute_metric_func=self.compute_metric_func,
            eval_batch_size=eval_batch_size,
        )