Source code for xtuner.v1.rl.judger.gsm8k

import re
from typing import Callable

from .native import JudgerConfig


# Number of trailing characters of a response that extract_solution searches.
# The final answer is expected near the end, and regex matching over very long
# responses is slow, so longer inputs are clipped to this tail (characters, not tokens).
_SOLUTION_CLIP_CHARS: int = 300


def extract_solution(solution_str, method="strict"):
    """Extract the final numerical answer from a model response.

    Only the last ``_SOLUTION_CLIP_CHARS`` characters of ``solution_str`` are
    searched: the answer is expected near the end, and regex matching over very
    long responses is slow.

    Args:
        solution_str (str): The string containing the solution.
        method (str): The extraction method, either "strict" or "flexible".
            "strict" requires the answer in the format "#### <number>" and
            returns the last such number with "," separators removed.
            "flexible" returns the last run of digits, ".", "," (optionally
            preceded by "-") that contains at least one digit, unmodified.
            Defaults to "strict".

    Returns:
        str or None: The extracted answer as a string, or None if no
            candidate is found.

    Raises:
        AssertionError: If ``method`` is neither "strict" nor "flexible".
    """
    assert method in ["strict", "flexible"]

    # Clip to the tail of the response. Note this is a character budget, which
    # corresponds to noticeably fewer than _SOLUTION_CLIP_CHARS tokens.
    if len(solution_str) > _SOLUTION_CLIP_CHARS:
        solution_str = solution_str[-_SOLUTION_CLIP_CHARS:]

    if method == "strict":
        # Requiring the "#### " marker also tests the formatting of the model.
        solutions = re.findall(r"#### (-?[0-9.,]+)", solution_str)
        if not solutions:
            return None
        # Take the last marked answer and drop "," thousands separators.
        return solutions[-1].replace(",", "")

    # "flexible": the pattern also matches separator-only runs such as ".",
    # "," or "-." (e.g. sentence punctuation). Those are not numbers, so skip
    # them and return None when nothing numeric remains, instead of leaking
    # the last invalid candidate.
    for candidate in reversed(re.findall(r"-?[0-9.,]+", solution_str)):
        if any(ch.isdigit() for ch in candidate):
            return candidate
    return None


def compute_reward(response, label, extra_info):
    """Compute the GSM8K reward for a response.

    The answer is extracted from ``response`` with ``extract_solution`` in its
    default "strict" mode, so it must follow the "#### <number>" format.

    Args:
        response (str): The model's generated response.
        label (str): The ground-truth answer. It is converted with ``str``,
            stripped of surrounding whitespace and has "," removed, mirroring
            the normalization applied to the extracted answer.
        extra_info (dict): Reward values. ``"score"`` is awarded for a correct
            answer and ``"format_score"`` when an answer is extracted but is
            wrong, e.g. ``{"score": 1, "format_score": 0}``.

    Returns:
        dict: ``{"score": value}``, where value is ``0`` if no answer could be
            extracted, ``extra_info["score"]`` if the answer matches the label,
            and ``extra_info["format_score"]`` otherwise.

    Raises:
        KeyError: If ``extra_info`` is a dict missing the key needed for the
            outcome.
    """
    answer = extract_solution(response)
    if answer is None:
        # No "#### <number>" answer found: no reward at all.
        return {"score": 0}

    # The extracted answer has "," removed, so normalize the label the same
    # way. Otherwise a label like "1,000" or "72\n" could never match.
    ground_truth = str(label).strip().replace(",", "")
    if answer == ground_truth:
        return {"score": extra_info["score"]}
    return {"score": extra_info["format_score"]}


class GSM8KJudgerConfig(JudgerConfig):
    """Preset ``JudgerConfig`` for the ``openai/gsm8k`` task.

    By default, responses are scored with ``compute_reward``. It extracts the
    final ``#### <number>`` answer from a response and compares it with the
    ground-truth answer.

    Args:
        judger_name (str): Logical name of the judger.
            Defaults to ``"openai/gsm8k"``.
        extra_info (dict): Reward values used by the GSM8K reward function:
            ``"score"`` for a correct answer and ``"format_score"`` for an
            extracted but incorrect one.
            Defaults to ``{"score": 1, "format_score": 0}``.
        reward_handler (Callable | str): Handler used to compute the score.
            Defaults to ``compute_reward``.

    Options not declared in this class (such as ``request_timeout``) are
    inherited from ``JudgerConfig``.

    **Examples:**

        Build the default GSM8K judger::

            config = GSM8KJudgerConfig()
    """

    # Preset identity of this judger.
    judger_name: str = "openai/gsm8k"
    # Full marks for a correct answer, nothing for a formatted-but-wrong one.
    extra_info: dict = dict(score=1, format_score=0)
    # Module-level GSM8K scoring function.
    reward_handler: Callable | str = compute_reward