# Source code for xtuner.v1.rl.replay_buffer

import asyncio
import operator
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields, is_dataclass, replace
from itertools import count
from pathlib import Path
from typing import Any, Callable, List, TypeAlias, Union

import pandas as pd
import ray
import torch
from pydantic import BaseModel, ConfigDict

from xtuner.v1.data_proto.rl_data import (
    RolloutState,
    Status,
    get_group_status,
    refresh_seq_staleness,
    reset_rollout_response,
    update_sample_version,
)
from xtuner.v1.rl.utils import (
    BetweenNode,
    ConditionNode,
    LogicNode,
    LogicOperator,
    Operators,
    QueryNode,
    ScalarNode,
    SetNode,
    parse_query,
)
from xtuner.v1.utils import get_logger


logger = get_logger(__name__)


def maybe_expire_group(group: list[RolloutState], stale_threshold: int) -> None:
    """Flip a whole rollout group to EXPIRED when it has become too stale.

    Only groups whose aggregate status (``get_group_status``) is COMPLETED or ABORTED are
    candidates. Such a group is expired as soon as any sample's ``seq_staleness``
    (treated as 0 when the attribute is missing) reaches ``stale_threshold``. The
    samples are modified in place; nothing is returned.

    Raises:
        ValueError: If ``stale_threshold`` is not positive.
    """
    if stale_threshold <= 0:
        raise ValueError(f"stale_threshold must be positive, got {stale_threshold}.")

    eligible_statuses = (Status.COMPLETED, Status.ABORTED)
    if get_group_status(group) not in eligible_statuses:
        return

    too_stale = any(getattr(sample, "seq_staleness", 0) >= stale_threshold for sample in group)
    if not too_stale:
        return

    # Expire the group before it is stored, so the storage logic only has to classify
    # groups by their final group status.
    for sample in group:
        sample.status = Status.EXPIRED


@dataclass
class StorageItem:
    """One stored rollout group plus the metadata used to query and order it.

    ``ReplayBuffer.put`` builds items with ``uid=0`` and ``timestamp_id=0`` as
    placeholders; the storage backends in this module assign the real values in ``put``.
    """

    # Stored payload: the rollout group itself.
    item: List[RolloutState]
    # Backend-assigned id; used to address the record in delete/update.
    uid: int
    # Backend-assigned insertion counter (from an increasing counter); FIFO ordering key.
    timestamp_id: int
    # Task the group belongs to; a query key.
    task_name: str
    # Group-level status (from get_group_status); a query key.
    status: Status
    # Group staleness as computed by ReplayBuffer (max seq_staleness over the samples).
    staleness: int


@dataclass
class SerializedRayObjectRef:
    """Checkpoint-time stand-in for a ``ray.ObjectRef``.

    ``_snapshot_nested_objectrefs`` replaces each ObjectRef with this wrapper around the
    value fetched via ``ray.get``; ``_restore_nested_objectrefs`` turns it back into a new
    ObjectRef with ``ray.put``.
    """

    # The (recursively snapshotted) value the original ObjectRef pointed to.
    value: Any


def _map_nested_children(obj: Any, transform: Callable[[Any], Any]) -> Any:
    """Return a copy of ``obj`` with ``transform`` applied to each direct child value.

    Supported containers:
    - pydantic models (declared fields)
    - dataclass instances (fields accepted by ``__init__``)
    - lists
    - tuples (namedtuples keep their type)
    - sets
    - dict values (keys are left untouched)

    Any other object is returned unchanged. Containers are copied; leaf objects are shared.
    """
    if isinstance(obj, BaseModel):
        # model_copy(update=...) assigns without running validation or __setattr__, so frozen
        # models and models with validate_assignment can still hold the transformed values.
        updates = {name: transform(getattr(obj, name)) for name in type(obj).model_fields}
        return obj.model_copy(update=updates)
    if is_dataclass(obj) and not isinstance(obj, type):
        # dataclasses.replace() rejects init=False fields, so only pass the ones it accepts;
        # init=False fields are re-created by the dataclass itself and are not transformed.
        updates = {f.name: transform(getattr(obj, f.name)) for f in fields(obj) if f.init}
        return replace(obj, **updates)
    if isinstance(obj, list):
        return [transform(value) for value in obj]
    if isinstance(obj, tuple):
        values = [transform(value) for value in obj]
        # namedtuples are constructed positionally; plain tuples from a single iterable.
        return type(obj)(*values) if hasattr(obj, "_fields") else tuple(values)
    if isinstance(obj, set):
        return {transform(value) for value in obj}
    if isinstance(obj, dict):
        return {key: transform(value) for key, value in obj.items()}
    return obj


def _snapshot_nested_objectrefs(obj: Any) -> Any:
    """Recursively replace every ``ray.ObjectRef`` in ``obj`` with a ``SerializedRayObjectRef``.

    Each ref is resolved with ``ray.get``, which blocks. The fetched value is snapshotted
    as well, so refs nested inside resolved values are also captured. ``obj`` itself is not
    mutated.
    """
    if isinstance(obj, ray.ObjectRef):
        return SerializedRayObjectRef(_snapshot_nested_objectrefs(ray.get(obj)))
    return _map_nested_children(obj, _snapshot_nested_objectrefs)


def _restore_nested_objectrefs(obj: Any) -> Any:
    """Inverse of ``_snapshot_nested_objectrefs``.

    ``ray.put`` each ``SerializedRayObjectRef`` back into the object store. ``obj`` itself
    is not mutated.
    """
    if isinstance(obj, SerializedRayObjectRef):
        return ray.put(_restore_nested_objectrefs(obj.value))
    return _map_nested_children(obj, _restore_nested_objectrefs)


# StorageItem field names that query conditions are allowed to reference.
QUERY_KEYS = [f.name for f in fields(StorageItem)]
QueryKey = Union[str, LogicOperator]  # str is a StorageItem field name; LogicOperator is a logical operator such as "$and", "$or"

# Query types:
QueryDict: TypeAlias = dict[
    QueryKey,
    Union[
        Any,  # direct value match, e.g.: {"task_name": "math"}
        dict[Operators, Any],  # operator match, e.g.: {"uid": {"$gt": 10}}
        List["QueryDict"],  # logical combination, e.g.: {"$and": [{"a": 1}, {"b": 2}]}
    ],
]
# A query is either the dict DSL above or a QueryNode tree; the backends here pass it to
# parse_query, which presumably accepts both forms — confirm in xtuner.v1.rl.utils.
QueryType = Union[QueryDict, QueryNode]


class StorageBackend(ABC):
    """Abstract record store used by ``ReplayBuffer`` and the replay policies.

    Records are ``StorageItem`` instances addressed by ``uid``. Queries use the
    ``QueryType`` DSL, and filterable fields are listed in ``QUERY_KEYS``.
    """

    @abstractmethod
    async def put(self, item: StorageItem) -> int:
        """Store ``item`` and return its uid.

        The implementations in this module ignore the incoming ``uid``/``timestamp_id``
        and assign fresh values from increasing counters.
        """
        ...

    @abstractmethod
    async def get(self, query: QueryType) -> List[StorageItem]:
        """Return all stored records matching ``query`` without removing them."""
        ...

    @abstractmethod
    async def count(self, query: QueryType) -> int:
        """Return the number of stored records matching ``query``."""
        ...

    @abstractmethod
    async def delete(self, uids: list[int]) -> None:
        """Remove the records with the given uids (implementations here ignore unknown uids)."""
        ...

    @abstractmethod
    async def update(self, items: list[StorageItem]) -> None:
        """Write modified records back, matched by uid.

        The implementations here skip uids that are not stored and keep the stored
        ``timestamp_id``, so replay order is unaffected.
        """
        ...

    @abstractmethod
    def __len__(self) -> int:
        """Number of records currently held by the backend."""
        ...

    @abstractmethod
    def state_dict(self) -> dict[str, Any]:
        """Return a snapshot of the backend; ``ReplayBuffer.save`` serializes it with ``torch.save``."""
        ...

    @abstractmethod
    def load_state_dict(self, state: dict[str, Any]) -> None:
        """Replace the backend's contents with a snapshot produced by ``state_dict``."""
        ...


class ReplayPolicy(ABC):
    """Strategy deciding how groups enter storage (``put``) and which ones leave it (``get``)."""

    @abstractmethod
    async def put(self, item: StorageItem, storage_backend: StorageBackend) -> None:
        """Admit ``item`` into ``storage_backend``.

        The policies in this module drop items whose group is empty.
        """
        ...

    @abstractmethod
    async def get(self, count: int, query: QueryType, storage_backend: StorageBackend) -> list[list[RolloutState]]:
        """Return up to ``count`` groups matching ``query`` in policy order.

        The policies in this module also delete the returned records from storage.
        """
        ...

    async def count(self, query: QueryType, storage_backend: StorageBackend) -> int:
        """Return how many stored records match ``query``; delegates to the backend."""
        return await storage_backend.count(query)


class NaiveStorage(StorageBackend):
    """In-memory storage backend keeping records in a plain dict keyed by uid.

    Queries are answered by evaluating the parsed query tree against every stored record,
    so each ``get``/``count`` is a linear scan over all records.
    """

    # Operator tables for ConditionNode evaluation, keyed by DSL operator name.
    _SCALAR_OPS = {
        "$eq": operator.eq,
        "$ne": operator.ne,
        "$gt": operator.gt,
        "$gte": operator.ge,
        "$lt": operator.lt,
        "$lte": operator.le,
    }
    _SET_OPS = {
        "$in": lambda val, values: val in values,
        "$not_in": lambda val, values: val not in values,
    }

    def __init__(self):
        self._uid_gen = count(1)
        self._timestamp_id_gen = count(1)
        self._items: dict[int, StorageItem] = {}

    async def put(self, item: StorageItem) -> int:
        """Store a copy of ``item`` under a fresh uid/timestamp_id and return the uid."""
        uid = next(self._uid_gen)
        stored = replace(item, uid=uid, timestamp_id=next(self._timestamp_id_gen))
        self._items[uid] = stored
        return uid

    def _evaluate(self, item: StorageItem, query_node: QueryNode) -> bool:
        """Evaluate ``query_node`` against one stored item by walking the filter tree.

        Raises:
            ValueError: If a condition references a field outside ``QUERY_KEYS``, or the
                tree contains an unsupported node type or operator (previously these
                silently evaluated to False).
        """
        if isinstance(query_node, LogicNode):
            if not query_node.conditions:
                # An empty $and matches everything; an empty $or matches nothing.
                return query_node.relation == "$and"
            child_results = (self._evaluate(item, child) for child in query_node.conditions)
            if query_node.relation == "$and":
                return all(child_results)
            return any(child_results)

        if not isinstance(query_node, ConditionNode):
            raise ValueError(f"不支持的查询节点类型: {type(query_node)}")
        if query_node.field not in QUERY_KEYS:
            raise ValueError(f"查询字段错误: 找不到属性 '{query_node.field}'。可用属性为: {QUERY_KEYS}")
        val = getattr(item, query_node.field)

        if isinstance(query_node, ScalarNode):
            op_fn = self._SCALAR_OPS.get(query_node.op)
        elif isinstance(query_node, SetNode):
            op_fn = self._SET_OPS.get(query_node.op)
        elif isinstance(query_node, BetweenNode):
            return query_node.lower <= val <= query_node.upper
        else:
            raise ValueError(f"不支持的查询节点类型: {type(query_node)}")

        if op_fn is None:
            raise ValueError(f"不支持的查询操作符: {query_node.op!r} ({type(query_node).__name__})")
        return op_fn(val, query_node.value)

    async def get(self, query: QueryType) -> list[StorageItem]:
        """Return all records matching ``query``, in insertion order of the dict."""
        ast = parse_query(query)
        return [item for item in self._items.values() if self._evaluate(item, ast)]

    async def count(self, query: QueryType) -> int:
        """Return the number of records matching ``query``."""
        ast = parse_query(query)
        return sum(1 for item in self._items.values() if self._evaluate(item, ast))

    async def delete(self, uids: list[int]) -> None:
        """Remove the records with the given uids; unknown uids are ignored."""
        for uid in uids:
            self._items.pop(uid, None)

    async def update(self, items: list[StorageItem]) -> None:
        """Overwrite stored records matched by uid; records whose uid is not stored are skipped."""
        for item in items:
            old_item = self._items.get(item.uid)
            if old_item is None:
                continue
            # Update in place while keeping uid/timestamp_id, so refreshing staleness does
            # not change the replay order.
            self._items[item.uid] = replace(item, uid=old_item.uid, timestamp_id=old_item.timestamp_id)

    def __len__(self) -> int:
        return len(self._items)

    def state_dict(self) -> dict[str, Any]:
        """Snapshot the stored records, with nested Ray ObjectRefs resolved.

        NOTE(review): next_uid / next_timestamp_id are derived from the records still
        stored, so ids of already-deleted records may be reissued after a resume. That is
        presumably fine because uids only address live records — confirm.
        """
        max_uid = max(self._items, default=0)
        max_timestamp_id = max((item.timestamp_id for item in self._items.values()), default=0)
        return {
            "items": [_snapshot_nested_objectrefs(item) for item in self._items.values()],
            "next_uid": max_uid + 1,
            "next_timestamp_id": max_timestamp_id + 1,
        }

    def load_state_dict(self, state: dict[str, Any]) -> None:
        """Replace all records with those from ``state`` and restore the id counters."""
        items: list[StorageItem] = [_restore_nested_objectrefs(item) for item in state["items"]]
        self._items = {item.uid: item for item in items}
        self._uid_gen = count(state["next_uid"])
        self._timestamp_id_gen = count(state["next_timestamp_id"])


class PandasStorage(StorageBackend):
    """Storage backend keeping records as rows of a pandas DataFrame.

    ``put`` only appends a row dict to ``_buffer``. ``_flush_buffer`` folds the buffer
    into ``_df`` before every read or mutation, so a burst of puts costs a single
    ``pd.concat``. Queries are evaluated as vectorized boolean masks.
    """

    _COLUMNS = ["uid", "timestamp_id", "task_name", "status", "staleness", "item"]
    # Comparison operators for ScalarNode; applied element-wise to a pandas Series.
    _SCALAR_OPS = {
        "$eq": operator.eq,
        "$ne": operator.ne,
        "$gt": operator.gt,
        "$gte": operator.ge,
        "$lt": operator.lt,
        "$lte": operator.le,
    }

    def __init__(self):
        self._uid_gen = count(1)
        self._timestamp_id_gen = count(1)
        self._df = pd.DataFrame(columns=self._COLUMNS)
        self._buffer: list[dict] = []

    def _flush_buffer(self):
        """Move rows staged in ``_buffer`` into ``_df``."""
        if not self._buffer:
            return
        new_df = pd.DataFrame(self._buffer)
        # When _df is empty, take new_df as-is rather than concatenating with the
        # column-only placeholder frame.
        self._df = new_df if self._df.empty else pd.concat([self._df, new_df], ignore_index=True)
        self._buffer.clear()

    async def put(self, item: StorageItem) -> int:
        """Stage ``item`` as a new row under a fresh uid/timestamp_id and return the uid."""
        uid = next(self._uid_gen)
        row = {
            "uid": uid,
            "timestamp_id": next(self._timestamp_id_gen),
            "task_name": item.task_name,
            "status": item.status,
            "staleness": item.staleness,
            "item": item.item,
        }
        self._buffer.append(row)
        return uid

    def _evaluate(self, query_node: QueryNode, df: pd.DataFrame) -> pd.Series:
        """Turn ``query_node`` into a boolean mask over ``df`` (vectorized filter-tree traversal).

        Raises:
            ValueError: If a condition references a field outside ``QUERY_KEYS``, or the
                tree contains an unsupported node type or operator. Previously an unknown
                operator returned ``None`` and failed later with an obscure indexing error.
        """
        if isinstance(query_node, LogicNode):
            if not query_node.conditions:
                # An empty $and matches every row; an empty $or matches none.
                return pd.Series(query_node.relation == "$and", index=df.index)

            is_and = query_node.relation == "$and"
            mask = self._evaluate(query_node.conditions[0], df)
            for child in query_node.conditions[1:]:
                child_mask = self._evaluate(child, df)
                mask = (mask & child_mask) if is_and else (mask | child_mask)
            return mask

        if not isinstance(query_node, ConditionNode):
            raise ValueError(f"不支持的查询节点类型: {type(query_node)}")
        if query_node.field not in QUERY_KEYS:
            raise ValueError(f"查询字段错误: 找不到属性 '{query_node.field}'。可用属性为: {QUERY_KEYS}")
        series = df[query_node.field]

        if isinstance(query_node, ScalarNode):
            op_fn = self._SCALAR_OPS.get(query_node.op)
            if op_fn is not None:
                return op_fn(series, query_node.value)
        elif isinstance(query_node, SetNode):
            if query_node.op == "$in":
                return series.isin(query_node.value)
            if query_node.op == "$not_in":
                return ~series.isin(query_node.value)
        elif isinstance(query_node, BetweenNode):
            return series.between(query_node.lower, query_node.upper)
        else:
            raise ValueError(f"不支持的查询节点类型: {type(query_node)}")
        raise ValueError(f"不支持的查询操作符: {query_node.op!r} ({type(query_node).__name__})")

    async def get(self, query: QueryType) -> list[StorageItem]:
        """Return all records matching ``query`` as ``StorageItem`` objects (rows are not removed)."""
        self._flush_buffer()
        if self._df.empty:
            return []

        ast = parse_query(query)
        filtered_df = self._df[self._evaluate(ast, self._df)]
        # itertuples is much cheaper than iterrows and does not build a Series per row.
        return [
            StorageItem(
                item=row.item,
                uid=int(row.uid),
                timestamp_id=int(row.timestamp_id),
                task_name=row.task_name,
                status=row.status,
                staleness=row.staleness,
            )
            for row in filtered_df.itertuples(index=False)
        ]

    async def count(self, query: QueryType) -> int:
        """Return the number of rows matching ``query``."""
        self._flush_buffer()
        if self._df.empty:
            return 0
        ast = parse_query(query)
        return int(self._evaluate(ast, self._df).sum())

    async def delete(self, uids: list[int]) -> None:
        """Drop the rows with the given uids; unknown uids are ignored."""
        self._flush_buffer()
        if not uids or self._df.empty:
            return
        self._df = self._df[~self._df["uid"].isin(uids)]

    async def update(self, items: list[StorageItem]) -> None:
        """Overwrite status/staleness/item of the rows matched by uid; unknown uids are skipped.

        uid, timestamp_id and task_name of a row are left unchanged.
        """
        self._flush_buffer()
        if not items or self._df.empty:
            return
        # put() issues each uid once, so map uid -> row label once instead of scanning
        # the whole uid column for every item.
        row_by_uid = {uid: row_idx for row_idx, uid in self._df["uid"].items()}
        for item in items:
            row_idx = row_by_uid.get(item.uid)
            if row_idx is None:
                continue
            self._df.at[row_idx, "status"] = item.status
            self._df.at[row_idx, "staleness"] = item.staleness
            self._df.at[row_idx, "item"] = item.item

    def __len__(self) -> int:
        # Includes rows still staged in the buffer.
        return len(self._df) + len(self._buffer)

    def state_dict(self) -> dict[str, Any]:
        """Snapshot the DataFrame, with nested Ray ObjectRefs in the item column resolved."""
        self._flush_buffer()
        max_uid = int(self._df["uid"].max()) if not self._df.empty else 0
        max_timestamp_id = int(self._df["timestamp_id"].max()) if not self._df.empty else 0
        df = self._df.copy(deep=True)
        if not df.empty:
            df["item"] = df["item"].map(_snapshot_nested_objectrefs)
        return {
            "df": df,
            "next_uid": max_uid + 1,
            "next_timestamp_id": max_timestamp_id + 1,
        }

    def load_state_dict(self, state: dict[str, Any]) -> None:
        """Replace all rows with those from ``state`` and restore the id counters."""
        self._df = state["df"].copy(deep=True)
        if not self._df.empty:
            self._df["item"] = self._df["item"].map(_restore_nested_objectrefs)
        self._buffer = []
        self._uid_gen = count(state["next_uid"])
        self._timestamp_id_gen = count(state["next_timestamp_id"])


class FIFOReplayPolicy(ReplayPolicy):
    """Replay policy handing out groups in insertion order (lowest ``timestamp_id`` first)."""

    async def put(self, item: StorageItem, storage_backend: StorageBackend) -> None:
        """Store ``item`` unless its rollout group is empty."""
        if item.item:
            await storage_backend.put(item)

    async def get(self, count: int, query: QueryType, storage_backend: StorageBackend) -> list[list[RolloutState]]:
        """Remove and return up to ``count`` matching groups, oldest first."""
        if count <= 0:
            return []
        matches = await storage_backend.get(query)
        oldest_first = sorted(matches, key=lambda record: record.timestamp_id)
        chosen = oldest_first[:count]
        if not chosen:
            return []
        await storage_backend.delete([record.uid for record in chosen])
        return [record.item for record in chosen]


class StalenessReplayPolicy(ReplayPolicy):
    """Replay policy handing out the stalest groups first.

    Records are ordered by descending ``staleness``, with ties broken by insertion order
    (ascending ``timestamp_id``). Presumably the stalest groups are consumed first so they
    are used before they cross the expiry threshold — confirm with callers.

    ``count`` is inherited from ``ReplayPolicy``; the previous override was identical.
    """

    async def put(self, item: StorageItem, storage_backend: StorageBackend) -> None:
        """Store ``item`` unless its rollout group is empty."""
        if not item.item:
            return
        await storage_backend.put(item)

    async def get(self, count: int, query: QueryType, storage_backend: StorageBackend) -> list[list[RolloutState]]:
        """Remove and return up to ``count`` matching groups, stalest first."""
        if count <= 0:
            return []

        records = await storage_backend.get(query)
        records.sort(key=lambda r: (-r.staleness, r.timestamp_id))
        selected = records[:count]
        if selected:
            await storage_backend.delete([record.uid for record in selected])
        return [record.item for record in selected]


class ReplayBuffer:
    """Replay buffer holding rollout groups per task, with pluggable ordering and storage.

    ``policy`` decides which stored groups ``get``/``take_batch`` hand out and in what
    order. ``storage_backend`` holds the records and evaluates queries.

    The async methods acquire ``self._lock`` around their storage operations, so
    concurrent coroutines do not interleave backend calls. Two reads skip the lock:
    ``__len__`` and the fast-fail check at the top of ``resume``.
    """

    _SAVE_PATH = "replay_buffer.pth"

    def __init__(self, policy: ReplayPolicy, storage_backend: StorageBackend):
        self._policy = policy
        self._storage = storage_backend
        self._lock = asyncio.Lock()

    @staticmethod
    def _task_status_query(task_name: str, status: Status) -> QueryDict:
        """Build the DSL query matching groups of ``task_name`` whose group status is ``status``."""
        return {"$and": [{"task_name": task_name}, {"status": status}]}

    async def put(
        self,
        items: list[RolloutState],
        task_name: str,
        *,
        model_step: int | None = None,
        current_train_step: int | None = None,
        stale_threshold: int | None = None,
    ) -> None:
        """Store one rollout group for ``task_name``.

        Args:
            items: The rollout group. An empty list is ignored.
            task_name: Task the group belongs to.
            model_step: If given, passed to ``update_sample_version`` for every sample.
            current_train_step: If given, used to refresh the samples' ``seq_staleness``.
            stale_threshold: If given, a completed/aborted group with any sample whose
                ``seq_staleness`` reaches this threshold is flipped to EXPIRED first.
        """
        if not items:
            return
        if model_step is not None:
            for item in items:
                update_sample_version(item, model_step)
        if current_train_step is not None:
            refresh_seq_staleness(items, current_train_step)
        if stale_threshold is not None:
            maybe_expire_group(items, stale_threshold)

        status = get_group_status(items)
        if status == Status.EXPIRED:
            # NOTE(review): refresh_staleness() re-assigns item.status = Status.EXPIRED
            # after reset_rollout_response(). If that helper resets per-sample status,
            # the same is probably needed here — confirm.
            for item in items:
                reset_rollout_response(item)
        storage_item = StorageItem(
            item=items,
            uid=0,  # placeholder; the backend assigns the real uid
            timestamp_id=0,  # placeholder; the backend assigns the real timestamp_id
            task_name=task_name,
            status=status,
            staleness=max(item.seq_staleness for item in items),
        )
        async with self._lock:
            await self._policy.put(storage_item, self._storage)

    async def get(self, batch_size: int, task_name: str, group_status: Status) -> list[list[RolloutState]]:
        """Return up to ``batch_size`` groups of ``task_name`` with ``group_status``, in policy order.

        The policies in this module delete what they return.
        """
        query_dsl = self._task_status_query(task_name, group_status)
        async with self._lock:
            return await self._policy.get(batch_size, query_dsl, self._storage)

    async def count(self, task_name: str, group_status: Status) -> int:
        """Return how many groups of ``task_name`` currently have ``group_status``."""
        query_dsl = self._task_status_query(task_name, group_status)
        async with self._lock:
            return await self._policy.count(query_dsl, self._storage)

    async def refresh_staleness(
        self,
        *,
        task_stale_thresholds: dict[str, int],
        current_train_step: int,
        statuses: list[Status] | None = None,
    ) -> dict[str, int]:
        """Refresh the staleness of reusable stored groups and expire the ones that got too old.

        Completed and aborted groups may both come from old weights, so they are evicted
        by train step. A group is flipped to EXPIRED as a whole, with its responses reset,
        when any of its samples reaches the task's threshold.

        Args:
            task_stale_thresholds: Per-task positive staleness threshold.
            current_train_step: Train step used to recompute ``seq_staleness``.
            statuses: Group statuses to refresh; defaults to COMPLETED and ABORTED.

        Returns:
            Number of groups flipped to EXPIRED, per task.

        Raises:
            ValueError: If any threshold is not positive.
        """
        for stale_threshold in task_stale_thresholds.values():
            if stale_threshold <= 0:
                raise ValueError(f"stale_threshold must be positive, got {stale_threshold}.")
        if statuses is None:
            statuses = [Status.COMPLETED, Status.ABORTED]
        expired_counts: dict[str, int] = {}
        async with self._lock:
            updated_records: list[StorageItem] = []
            for task_name, stale_threshold in task_stale_thresholds.items():
                query_dsl: QueryDict = {
                    "$and": [
                        {"task_name": task_name},
                        {"status": {"$in": statuses}},
                    ]
                }
                records = await self._storage.get(query_dsl)
                expired_count = 0
                for record in records:
                    refresh_seq_staleness(record.item, current_train_step)
                    staleness = max((getattr(item, "seq_staleness", 0) for item in record.item), default=0)
                    should_expire = any(getattr(item, "seq_staleness", 0) >= stale_threshold for item in record.item)
                    if should_expire:
                        # Flip the whole group once it crosses the step-level threshold;
                        # the sampler can later re-sample it as EXPIRED.
                        for item in record.item:
                            reset_rollout_response(item)
                            item.status = Status.EXPIRED
                        status = Status.EXPIRED
                        expired_count += 1
                    else:
                        status = get_group_status(record.item)
                    updated_records.append(replace(record, status=status, staleness=staleness))
                expired_counts[task_name] = expired_count
            await self._storage.update(updated_records)
        return expired_counts

    async def is_ready(
        self,
        task_batch_sizes: dict[str, int],
        *,
        group_status: Status = Status.COMPLETED,
    ) -> bool:
        """Return True if every task with a positive batch size has at least that many groups with ``group_status``."""
        async with self._lock:
            for task_name, batch_size in task_batch_sizes.items():
                if batch_size <= 0:
                    continue
                query_dsl = self._task_status_query(task_name, group_status)
                if await self._policy.count(query_dsl, self._storage) < batch_size:
                    return False
        return True

    async def take_batch(
        self,
        task_batch_sizes: dict[str, int],
        *,
        group_status: Status = Status.COMPLETED,
    ) -> tuple[dict[str, list[list[RolloutState]]], dict[str, int]]:
        """Take up to the requested number of groups per task under a single lock acquisition.

        A task with fewer available groups yields a shorter batch; the call does not
        wait. Returns ``(batch_by_task, consumed_counts)``, where ``consumed_counts``
        holds the actual number of groups taken per task.
        """
        batch_by_task: dict[str, list[list[RolloutState]]] = {}
        consumed_counts: dict[str, int] = {}
        async with self._lock:
            for task_name, batch_size in task_batch_sizes.items():
                if batch_size <= 0:
                    batch_by_task[task_name] = []
                    consumed_counts[task_name] = 0
                    continue
                query_dsl = self._task_status_query(task_name, group_status)
                task_batch = await self._policy.get(batch_size, query_dsl, self._storage)
                batch_by_task[task_name] = task_batch
                consumed_counts[task_name] = len(task_batch)
        return batch_by_task, consumed_counts

    async def count_statuses(
        self,
        task_names: list[str],
        statuses: list[Status],
    ) -> dict[str, dict[Status, int]]:
        """Return ``{task_name: {status: group_count}}`` for every requested task/status pair."""
        counts: dict[str, dict[Status, int]] = {task_name: {} for task_name in task_names}
        async with self._lock:
            for task_name in task_names:
                for status in statuses:
                    query_dsl = self._task_status_query(task_name, status)
                    counts[task_name][status] = await self._policy.count(query_dsl, self._storage)
        return counts

    def __len__(self) -> int:
        return len(self._storage)

    async def save(self, path: str | Path) -> None:
        """Write the buffer state to ``<path>/replay_buffer.pth``, creating ``path`` if needed.

        NOTE: ``state_dict()`` runs on the event loop while the lock is held. For the
        backends in this module it resolves nested Ray ObjectRefs with ``ray.get``,
        which may block.
        """
        file_path = Path(path)
        file_path.mkdir(parents=True, exist_ok=True)
        replay_buffer_path = file_path / self._SAVE_PATH
        async with self._lock:
            state = {
                "policy": type(self._policy).__name__,
                "storage": type(self._storage).__name__,
                "storage_state": self._storage.state_dict(),
            }
        await asyncio.to_thread(torch.save, state, replay_buffer_path)
        logger.info(f"Replay buffer saved to {replay_buffer_path}")

    async def resume(self, path: str | Path) -> None:
        """Load ``<path>/replay_buffer.pth`` into this (empty) buffer.

        The checkpoint is loaded with ``weights_only=False``, i.e. arbitrary pickled
        objects. Only resume from trusted checkpoints.

        Raises:
            RuntimeError: If the buffer is not empty, either before loading or when the
                state is about to be applied.
            ValueError: If the checkpoint was written by a different policy or storage
                backend class.
        """
        # Fail fast before paying for torch.load.
        if len(self._storage) > 0:
            raise RuntimeError("Cannot resume into a non-empty buffer")

        file_path = Path(path)
        replay_buffer_path = file_path / self._SAVE_PATH
        state = await asyncio.to_thread(torch.load, replay_buffer_path, map_location="cpu", weights_only=False)
        if state["policy"] != type(self._policy).__name__:
            raise ValueError(f"Replay policy mismatch: expected {type(self._policy).__name__}, got {state['policy']}")

        if state["storage"] != type(self._storage).__name__:
            raise ValueError(
                f"Storage backend mismatch: expected {type(self._storage).__name__}, got {state['storage']}"
            )

        async with self._lock:
            # Re-check under the lock: a concurrent put() can land while torch.load runs in
            # a worker thread, and load_state_dict() would silently discard that data.
            if len(self._storage) > 0:
                raise RuntimeError("Cannot resume into a non-empty buffer")
            self._storage.load_state_dict(state["storage_state"])
        logger.info(f"Replay buffer resumed from {replay_buffer_path}")


class SyncReplayBufferConfig(BaseModel):
    """Configuration for the synchronous replay buffer.

    The synchronous replay buffer uses a FIFO replay policy and in-memory native Python
    storage. It is intended for on-demand rollout production, where samples are consumed
    by the current training step.

    Args:
        No user-configurable fields.

    **Examples:**

    Example synchronous replay buffer::

        config = SyncReplayBufferConfig()
        replay_buffer = config.build()
    """

    model_config = ConfigDict(extra="forbid")

    def build(self) -> ReplayBuffer:
        """Create a ``ReplayBuffer`` backed by ``FIFOReplayPolicy`` and ``NaiveStorage``."""
        return ReplayBuffer(policy=FIFOReplayPolicy(), storage_backend=NaiveStorage())
class AsyncReplayBufferConfig(BaseModel):
    """Configuration for the asynchronous replay buffer.

    The asynchronous replay buffer uses a staleness-aware replay policy and in-memory
    native Python storage. It is intended for background rollout production, where
    samples may have been generated by an earlier model step.

    Args:
        No user-configurable fields.

    **Examples:**

    Example asynchronous replay buffer::

        config = AsyncReplayBufferConfig()
        replay_buffer = config.build()
    """

    model_config = ConfigDict(extra="forbid")

    def build(self) -> ReplayBuffer:
        """Create a ``ReplayBuffer`` backed by ``StalenessReplayPolicy`` and ``NaiveStorage``."""
        policy = StalenessReplayPolicy()
        return ReplayBuffer(policy=policy, storage_backend=NaiveStorage())