import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields, is_dataclass, replace
from itertools import count
from pathlib import Path
from typing import Any, List, TypeAlias, Union
import pandas as pd
import ray
import torch
from pydantic import BaseModel, ConfigDict
from xtuner.v1.data_proto.rl_data import (
RolloutState,
Status,
get_group_status,
refresh_seq_staleness,
reset_rollout_response,
update_sample_version,
)
from xtuner.v1.rl.utils import (
BetweenNode,
ConditionNode,
LogicNode,
LogicOperator,
Operators,
QueryNode,
ScalarNode,
SetNode,
parse_query,
)
from xtuner.v1.utils import get_logger
# Module-level logger; used below to report replay buffer save/resume events.
logger = get_logger(__name__)
def maybe_expire_group(group: list[RolloutState], stale_threshold: int) -> None:
    """Mark every sample of a finished group as ``Status.EXPIRED`` if any one is too stale.

    Only groups whose aggregated status (``get_group_status``) is ``COMPLETED``
    or ``ABORTED`` are considered; other groups are left untouched.

    Args:
        group: Samples that belong to one rollout group. Mutated in place.
        stale_threshold: A sample whose ``seq_staleness`` (treated as 0 when
            the attribute is missing) is greater than or equal to this value
            expires the whole group.

    Raises:
        ValueError: If ``stale_threshold`` is not strictly positive.
    """
    if stale_threshold <= 0:
        raise ValueError(f"stale_threshold must be positive, got {stale_threshold}.")

    if get_group_status(group) not in (Status.COMPLETED, Status.ABORTED):
        return

    for sample in group:
        if getattr(sample, "seq_staleness", 0) >= stale_threshold:
            break
    else:
        # No sample reached the threshold: nothing to expire.
        return

    # Flip expiry uniformly before the generated results are stored, so that
    # later storage logic only has to classify by the final group status.
    for sample in group:
        sample.status = Status.EXPIRED
@dataclass
class StorageItem:
    """One record kept by a storage backend: a rollout group plus its query metadata.

    The field names of this dataclass are the only names accepted as query
    fields (see ``QUERY_KEYS``).
    """

    # Stored payload: the rollout states that make up one group.
    item: List[RolloutState]
    # Identifier of the record; the backends in this file assign it on ``put``.
    uid: int
    # Insertion counter assigned by the backend on ``put``; replay policies sort on it.
    timestamp_id: int
    # Name of the task the group belongs to.
    task_name: str
    # Status of the group as a whole.
    status: Status
    # Staleness of the group (``ReplayBuffer`` stores the max ``seq_staleness`` of its samples).
    staleness: int
@dataclass
class SerializedRayObjectRef:
    """Checkpoint stand-in for a ``ray.ObjectRef``: carries the fetched value instead of the ref.

    NOTE(review): as a plain ``@dataclass`` (eq=True, not frozen) instances are
    unhashable, so they cannot be placed in a ``set``.
    """

    # The object obtained with ``ray.get`` (already snapshotted recursively by the caller).
    value: Any
def _snapshot_nested_objectrefs(obj: Any) -> Any:
    """Return a copy of ``obj`` in which every nested ``ray.ObjectRef`` is materialized.

    Each ``ray.ObjectRef`` is fetched with ``ray.get`` and wrapped in a
    ``SerializedRayObjectRef``. Pydantic models, dataclass instances, lists,
    tuples, sets and dict values are traversed recursively; anything else is
    returned unchanged (same object).

    Args:
        obj: Arbitrary object graph to snapshot.

    Returns:
        The snapshot; containers on the traversed path are new objects.
    """
    recurse = _snapshot_nested_objectrefs

    if isinstance(obj, ray.ObjectRef):
        fetched = ray.get(obj)
        return SerializedRayObjectRef(recurse(fetched))

    if isinstance(obj, BaseModel):
        # Shallow copy first, then overwrite each declared field with its snapshot.
        clone = obj.model_copy(deep=False)
        for name in type(obj).model_fields:
            setattr(clone, name, recurse(getattr(obj, name)))
        return clone

    # ``is_dataclass`` is also true for dataclass *types*; only instances are rebuilt.
    if is_dataclass(obj) and not isinstance(obj, type):
        converted = {}
        for spec in fields(obj):
            converted[spec.name] = recurse(getattr(obj, spec.name))
        return replace(obj, **converted)

    if isinstance(obj, list):
        return list(map(recurse, obj))
    if isinstance(obj, tuple):
        return tuple(map(recurse, obj))
    if isinstance(obj, set):
        return set(map(recurse, obj))
    if isinstance(obj, dict):
        # Only values are traversed; keys are kept as they are.
        return {key: recurse(value) for key, value in obj.items()}
    return obj
def _restore_nested_objectrefs(obj: Any) -> Any:
    """Inverse of the snapshot step: turn ``SerializedRayObjectRef`` markers back into refs.

    Each ``SerializedRayObjectRef`` is replaced by the result of ``ray.put`` on
    its (recursively restored) value. Pydantic models, dataclass instances,
    lists, tuples, sets and dict values are traversed recursively; anything
    else is returned unchanged (same object).

    Args:
        obj: Object graph previously produced by the snapshot step.

    Returns:
        The restored object graph.
    """
    recurse = _restore_nested_objectrefs

    if isinstance(obj, SerializedRayObjectRef):
        payload = recurse(obj.value)
        return ray.put(payload)

    if isinstance(obj, BaseModel):
        # Shallow copy first, then overwrite each declared field with its restored value.
        clone = obj.model_copy(deep=False)
        for name in type(obj).model_fields:
            setattr(clone, name, recurse(getattr(obj, name)))
        return clone

    # ``is_dataclass`` is also true for dataclass *types*; only instances are rebuilt.
    if is_dataclass(obj) and not isinstance(obj, type):
        converted = {}
        for spec in fields(obj):
            converted[spec.name] = recurse(getattr(obj, spec.name))
        return replace(obj, **converted)

    if isinstance(obj, list):
        return list(map(recurse, obj))
    if isinstance(obj, tuple):
        return tuple(map(recurse, obj))
    if isinstance(obj, set):
        return set(map(recurse, obj))
    if isinstance(obj, dict):
        # Only values are traversed; keys are kept as they are.
        return {key: recurse(value) for key, value in obj.items()}
    return obj
# Names of the StorageItem fields; the storage backends reject any other query field.
QUERY_KEYS = [f.name for f in fields(StorageItem)]
# A ``str`` key is a StorageItem field name; a ``LogicOperator`` key is a
# logical operator such as "$and" or "$or".
QueryKey = Union[str, LogicOperator]
# Query type:
QueryDict: TypeAlias = dict[
    QueryKey,
    Union[
        Any,  # direct value match, e.g. {"task_name": "math"}
        dict[Operators, Any],  # operator match, e.g. {"uid": {"$gt": 10}}
        List["QueryDict"],  # logical combination, e.g. {"$and": [{"a": 1}, {"b": 2}]}
    ],
]
# A query is either the raw DSL dictionary or an already parsed query tree.
QueryType = Union[QueryDict, QueryNode]
class StorageBackend(ABC):
    """Abstract interface of a record store queried with ``QueryType`` filters."""

    @abstractmethod
    async def put(self, item: StorageItem) -> int:
        """Store ``item`` and return an integer (the backends in this file return the assigned uid)."""

    @abstractmethod
    async def get(self, query: QueryType) -> List[StorageItem]:
        """Return the stored records that match ``query``."""

    @abstractmethod
    async def count(self, query: QueryType) -> int:
        """Return how many stored records match ``query``."""

    @abstractmethod
    async def delete(self, uids: list[int]) -> None:
        """Remove the records whose uid is listed in ``uids``."""

    @abstractmethod
    async def update(self, items: list[StorageItem]) -> None:
        """Overwrite already stored records with the given ``items`` (matched by uid)."""

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of stored records."""

    @abstractmethod
    def state_dict(self) -> dict[str, Any]:
        """Return a checkpointable description of the backend state."""

    @abstractmethod
    def load_state_dict(self, state: dict[str, Any]) -> None:
        """Restore the backend from a mapping produced by ``state_dict``."""
class ReplayPolicy(ABC):
    """Abstract strategy deciding how records enter and leave a ``StorageBackend``."""

    @abstractmethod
    async def put(self, item: StorageItem, storage_backend: StorageBackend) -> None:
        """Insert ``item`` into ``storage_backend`` according to the policy."""

    @abstractmethod
    async def get(self, count: int, query: QueryType, storage_backend: StorageBackend) -> list[list[RolloutState]]:
        """Select up to ``count`` groups matching ``query`` from ``storage_backend``."""

    async def count(self, query: QueryType, storage_backend: StorageBackend) -> int:
        """Return the number of records matching ``query``; delegates to the backend."""
        matched = await storage_backend.count(query)
        return matched
class NaiveStorage(StorageBackend):
def __init__(self):
self._uid_gen = count(1)
self._timestamp_id_gen = count(1)
self._items: dict[int, StorageItem] = {}
async def put(self, item: StorageItem) -> int:
uid = next(self._uid_gen)
stored = replace(item, uid=uid, timestamp_id=next(self._timestamp_id_gen))
self._items[uid] = stored
return uid
def _evaluate(self, item: StorageItem, query_node: QueryNode) -> bool:
"""NaiveStorage 实现的原生 Python 对象过滤树遍历."""
if isinstance(query_node, LogicNode):
if not query_node.conditions:
return query_node.relation == "$and"
if query_node.relation == "$and":
return all(self._evaluate(item, child) for child in query_node.conditions)
else:
return any(self._evaluate(item, child) for child in query_node.conditions)
elif isinstance(query_node, ConditionNode):
if query_node.field not in QUERY_KEYS:
raise ValueError(f"查询字段错误: 找不到属性 '{query_node.field}'。可用属性为: {QUERY_KEYS}")
val = getattr(item, query_node.field)
if isinstance(query_node, ScalarNode):
if query_node.op == "$eq":
return val == query_node.value
if query_node.op == "$ne":
return val != query_node.value
if query_node.op == "$gt":
return val > query_node.value
if query_node.op == "$gte":
return val >= query_node.value
if query_node.op == "$lt":
return val < query_node.value
if query_node.op == "$lte":
return val <= query_node.value
elif isinstance(query_node, SetNode):
if query_node.op == "$in":
return val in query_node.value
if query_node.op == "$not_in":
return val not in query_node.value
elif isinstance(query_node, BetweenNode):
return query_node.lower <= val <= query_node.upper
return False
async def get(self, query: QueryType) -> list[StorageItem]:
ast = parse_query(query)
return [item for item in self._items.values() if self._evaluate(item, ast)]
async def count(self, query: QueryType) -> int:
ast = parse_query(query)
return sum(1 for item in self._items.values() if self._evaluate(item, ast))
async def delete(self, uids: list[int]) -> None:
if not uids:
return
for uid in uids:
self._items.pop(uid, None)
async def update(self, items: list[StorageItem]) -> None:
for item in items:
old_item = self._items.get(item.uid)
if old_item is None:
continue
# 原地更新保留 uid/timestamp,避免刷新 staleness 改变 replay 顺序。
self._items[item.uid] = replace(item, uid=old_item.uid, timestamp_id=old_item.timestamp_id)
def __len__(self) -> int:
return len(self._items)
def state_dict(self) -> dict[str, Any]:
max_uid = max(self._items, default=0)
max_timestamp_id = max((item.timestamp_id for item in self._items.values()), default=0)
return {
"items": [_snapshot_nested_objectrefs(item) for item in self._items.values()],
"next_uid": max_uid + 1,
"next_timestamp_id": max_timestamp_id + 1,
}
def load_state_dict(self, state: dict[str, Any]) -> None:
items: list[StorageItem] = [_restore_nested_objectrefs(item) for item in state["items"]]
self._items = {item.uid: item for item in items}
self._uid_gen = count(state["next_uid"])
self._timestamp_id_gen = count(state["next_timestamp_id"])
class PandasStorage(StorageBackend):
    """In-memory backend that keeps records as rows of a pandas ``DataFrame``.

    New rows are appended to a Python list first and merged into the frame
    lazily (``_flush_buffer``) before any operation that reads or edits it.
    """

    def __init__(self):
        self._uid_gen = count(1)
        self._timestamp_id_gen = count(1)
        self._df = pd.DataFrame(columns=["uid", "timestamp_id", "task_name", "status", "staleness", "item"])
        # Rows accepted by ``put`` that have not been merged into ``_df`` yet.
        self._buffer: list[dict] = []

    def _flush_buffer(self):
        """Merge buffered rows into the frame and empty the buffer."""
        if not self._buffer:
            return
        new_df = pd.DataFrame(self._buffer)
        # Avoid concatenating with an empty frame; just adopt the new rows.
        self._df = new_df if self._df.empty else pd.concat([self._df, new_df], ignore_index=True)
        self._buffer.clear()

    async def put(self, item: StorageItem) -> int:
        """Buffer ``item`` under a fresh uid/timestamp and return the uid."""
        uid = next(self._uid_gen)
        row = {
            "uid": uid,
            "timestamp_id": next(self._timestamp_id_gen),
            "task_name": item.task_name,
            "status": item.status,
            "staleness": item.staleness,
            "item": item.item,
        }
        self._buffer.append(row)
        return uid

    def _evaluate(self, query_node: QueryNode, df: pd.DataFrame) -> pd.Series:
        """Vectorized DataFrame filter-tree traversal, as implemented by PandasStorage.

        Args:
            query_node: Parsed query tree (or sub-tree).
            df: Frame the boolean mask is computed for.

        Returns:
            A boolean mask aligned with ``df.index``.

        Raises:
            ValueError: If a condition refers to a field outside ``QUERY_KEYS``,
                or the node type / operator is not supported. Previously an
                unsupported operator silently produced ``None``, which only
                failed later with an unrelated error.
        """
        if isinstance(query_node, LogicNode):
            is_and = query_node.relation == "$and"
            if not query_node.conditions:
                # An empty "$and" matches every row, any other empty relation matches none.
                return pd.Series(True, index=df.index) if is_and else pd.Series(False, index=df.index)
            mask = self._evaluate(query_node.conditions[0], df)
            for child in query_node.conditions[1:]:
                child_mask = self._evaluate(child, df)
                mask = (mask & child_mask) if is_and else (mask | child_mask)
            return mask

        if isinstance(query_node, ConditionNode):
            if query_node.field not in QUERY_KEYS:
                raise ValueError(f"查询字段错误: 找不到属性 '{query_node.field}'。可用属性为: {QUERY_KEYS}")
            series = df[query_node.field]

            if isinstance(query_node, ScalarNode):
                if query_node.op == "$eq":
                    return series == query_node.value
                if query_node.op == "$ne":
                    return series != query_node.value
                if query_node.op == "$gt":
                    return series > query_node.value
                if query_node.op == "$gte":
                    return series >= query_node.value
                if query_node.op == "$lt":
                    return series < query_node.value
                if query_node.op == "$lte":
                    return series <= query_node.value
                raise ValueError(f"不支持的查询操作符: {query_node.op}")

            if isinstance(query_node, SetNode):
                if query_node.op == "$in":
                    return series.isin(query_node.value)
                if query_node.op == "$not_in":
                    return ~series.isin(query_node.value)
                raise ValueError(f"不支持的查询操作符: {query_node.op}")

            if isinstance(query_node, BetweenNode):
                return series.between(query_node.lower, query_node.upper)

        raise ValueError(f"不支持的查询节点类型: {type(query_node)}")

    async def get(self, query: QueryType) -> list[StorageItem]:
        """Return all stored records matching ``query``, in frame order."""
        self._flush_buffer()
        if self._df.empty:
            return []
        ast = parse_query(query)
        filtered_df = self._df[self._evaluate(ast, self._df)]
        # Read each column once instead of building one Series per row (``iterrows``).
        column_names = ("item", "uid", "timestamp_id", "task_name", "status", "staleness")
        columns = [filtered_df[name].tolist() for name in column_names]
        return [
            StorageItem(
                item=item,
                uid=uid,
                timestamp_id=timestamp_id,
                task_name=task_name,
                status=status,
                staleness=staleness,
            )
            for item, uid, timestamp_id, task_name, status, staleness in zip(*columns)
        ]

    async def count(self, query: QueryType) -> int:
        """Return the number of stored records matching ``query``."""
        self._flush_buffer()
        if self._df.empty:
            return 0
        ast = parse_query(query)
        return int(self._evaluate(ast, self._df).sum())

    async def delete(self, uids: list[int]) -> None:
        """Drop the rows with the given uids; unknown uids are ignored."""
        self._flush_buffer()
        if not uids or self._df.empty:
            return
        self._df = self._df[~self._df["uid"].isin(uids)]

    async def update(self, items: list[StorageItem]) -> None:
        """Overwrite status, staleness and payload of stored rows, matched by uid.

        ``uid`` and ``timestamp_id`` of the stored rows are left untouched;
        items whose uid is not stored are skipped.
        """
        self._flush_buffer()
        if not items or self._df.empty:
            return
        for item in items:
            mask = self._df["uid"] == item.uid
            if not mask.any():
                continue
            # ``.at`` writes one cell at a time, so the list payload is stored
            # as a single object instead of being broadcast.
            for row_idx in self._df.index[mask]:
                self._df.at[row_idx, "status"] = item.status
                self._df.at[row_idx, "staleness"] = item.staleness
                self._df.at[row_idx, "item"] = item.item

    def __len__(self) -> int:
        # Buffered rows count as stored even though they are not in the frame yet.
        return len(self._df) + len(self._buffer)

    def state_dict(self) -> dict[str, Any]:
        """Return a deep copy of the frame (ObjectRefs materialized) and the next counter values.

        The counters are derived from the largest uid/timestamp currently stored.
        """
        self._flush_buffer()
        has_rows = not self._df.empty
        max_uid = int(self._df["uid"].max()) if has_rows else 0
        max_timestamp_id = int(self._df["timestamp_id"].max()) if has_rows else 0
        df = self._df.copy(deep=True)
        if not df.empty:
            df["item"] = df["item"].map(_snapshot_nested_objectrefs)
        return {
            "df": df,
            "next_uid": max_uid + 1,
            "next_timestamp_id": max_timestamp_id + 1,
        }

    def load_state_dict(self, state: dict[str, Any]) -> None:
        """Replace the whole content with ``state`` as produced by ``state_dict``."""
        self._df = state["df"].copy(deep=True)
        if not self._df.empty:
            self._df["item"] = self._df["item"].map(_restore_nested_objectrefs)
        self._buffer = []
        self._uid_gen = count(state["next_uid"])
        self._timestamp_id_gen = count(state["next_timestamp_id"])
class FIFOReplayPolicy(ReplayPolicy):
    """Replay policy that hands out the oldest matching groups first."""

    async def put(self, item: StorageItem, storage_backend: StorageBackend) -> None:
        """Store ``item`` unless its payload is empty."""
        if item.item:
            await storage_backend.put(item)

    async def get(self, count: int, query: QueryType, storage_backend: StorageBackend) -> list[list[RolloutState]]:
        """Remove and return up to ``count`` matching groups, lowest ``timestamp_id`` first."""
        if count <= 0:
            return []

        def insertion_order(record: StorageItem) -> int:
            return record.timestamp_id

        candidates = await storage_backend.get(query)
        candidates.sort(key=insertion_order)
        picked = candidates[:count]
        if not picked:
            return []
        # Returned groups are consumed: remove them from the backend.
        await storage_backend.delete([record.uid for record in picked])
        return [record.item for record in picked]
class StalenessReplayPolicy(ReplayPolicy):
    """Replay policy that hands out the stalest matching groups first."""

    async def put(self, item: StorageItem, storage_backend: StorageBackend) -> None:
        """Store ``item`` unless its payload is empty."""
        if item.item:
            await storage_backend.put(item)

    async def get(self, count: int, query: QueryType, storage_backend: StorageBackend) -> list[list[RolloutState]]:
        """Remove and return up to ``count`` matching groups.

        Groups are ordered by descending ``staleness``; ties are broken by
        ascending ``timestamp_id``.
        """
        if count <= 0:
            return []

        def stalest_then_oldest(record: StorageItem) -> tuple[int, int]:
            return (-record.staleness, record.timestamp_id)

        candidates = await storage_backend.get(query)
        candidates.sort(key=stalest_then_oldest)
        picked = candidates[:count]
        if not picked:
            return []
        # Returned groups are consumed: remove them from the backend.
        await storage_backend.delete([record.uid for record in picked])
        return [record.item for record in picked]

    async def count(self, query: QueryType, storage_backend: StorageBackend) -> int:
        """Return the number of records matching ``query``; delegates to the backend."""
        matched = await storage_backend.count(query)
        return matched
class ReplayBuffer:
    """Replay buffer combining a ``ReplayPolicy`` with a ``StorageBackend``.

    All coroutine methods that read or modify the storage do so while holding
    one ``asyncio.Lock``; ``__len__`` reads the storage without the lock.
    """

    # File name used inside the directory passed to ``save`` / ``resume``.
    _SAVE_PATH = "replay_buffer.pth"

    def __init__(self, policy: ReplayPolicy, storage_backend: StorageBackend):
        self._policy = policy
        self._storage = storage_backend
        self._lock = asyncio.Lock()

    @staticmethod
    def _task_status_query(task_name: str, group_status: Status) -> QueryDict:
        """Build the DSL query matching one task name and one group status."""
        return {"$and": [{"task_name": task_name}, {"status": group_status}]}

    async def put(
        self,
        items: list[RolloutState],
        task_name: str,
        *,
        model_step: int | None = None,
        current_train_step: int | None = None,
        stale_threshold: int | None = None,
    ) -> None:
        """Store one rollout group under ``task_name``.

        Args:
            items: Samples of the group; mutated in place. Nothing is stored
                when the list is empty.
            task_name: Task the group belongs to.
            model_step: When given, passed with every sample to
                ``update_sample_version``.
            current_train_step: When given, passed with the group to
                ``refresh_seq_staleness``.
            stale_threshold: When given, the group is run through
                ``maybe_expire_group`` before its status is computed.

        Raises:
            ValueError: If ``stale_threshold`` is given and not strictly positive.
        """
        if not items:
            return
        if model_step is not None:
            for item in items:
                update_sample_version(item, model_step)
        if current_train_step is not None:
            refresh_seq_staleness(items, current_train_step)
        if stale_threshold is not None:
            maybe_expire_group(items, stale_threshold)

        status = get_group_status(items)
        if status == Status.EXPIRED:
            for item in items:
                reset_rollout_response(item)

        storage_item = StorageItem(
            item=items,
            # Placeholders: the storage backends in this file assign both on ``put``.
            uid=0,
            timestamp_id=0,
            task_name=task_name,
            status=status,
            # Same tolerant lookup as ``maybe_expire_group`` / ``refresh_staleness``:
            # a sample without ``seq_staleness`` counts as staleness 0.
            staleness=max(getattr(item, "seq_staleness", 0) for item in items),
        )
        async with self._lock:
            await self._policy.put(storage_item, self._storage)

    async def get(self, batch_size: int, task_name: str, group_status: Status) -> list[list[RolloutState]]:
        """Take up to ``batch_size`` groups of ``task_name`` with ``group_status`` via the policy."""
        # Query using the DSL dictionary.
        query_dsl = self._task_status_query(task_name, group_status)
        async with self._lock:
            return await self._policy.get(batch_size, query_dsl, self._storage)

    async def count(self, task_name: str, group_status: Status) -> int:
        """Return how many stored groups of ``task_name`` have ``group_status``."""
        # Query using the DSL dictionary.
        query_dsl = self._task_status_query(task_name, group_status)
        async with self._lock:
            return await self._policy.count(query_dsl, self._storage)

    async def refresh_staleness(
        self,
        *,
        task_stale_thresholds: dict[str, int],
        current_train_step: int,
        statuses: list[Status] | None = None,
    ) -> dict[str, int]:
        """Recompute staleness of stored groups and expire those past their threshold.

        Args:
            task_stale_thresholds: Staleness threshold per task name.
            current_train_step: Passed with each group to ``refresh_seq_staleness``.
            statuses: Group statuses to refresh; defaults to ``COMPLETED`` and
                ``ABORTED``.

        Returns:
            Number of groups flipped to ``EXPIRED`` per task name.

        Raises:
            ValueError: If any threshold is not strictly positive (checked
                before anything is modified).
        """
        # Refresh the staleness of reusable samples; completed / aborted ones may
        # both come from old weights and need to be evicted by train_step.
        for stale_threshold in task_stale_thresholds.values():
            if stale_threshold <= 0:
                raise ValueError(f"stale_threshold must be positive, got {stale_threshold}.")
        if statuses is None:
            statuses = [Status.COMPLETED, Status.ABORTED]

        expired_counts: dict[str, int] = {}
        async with self._lock:
            updated_records: list[StorageItem] = []
            for task_name, stale_threshold in task_stale_thresholds.items():
                query_dsl: QueryDict = {
                    "$and": [
                        {"task_name": task_name},
                        {"status": {"$in": statuses}},
                    ]
                }
                records = await self._storage.get(query_dsl)
                expired_count = 0
                for record in records:
                    refresh_seq_staleness(record.item, current_train_step)
                    # Read the staleness values before any reset below touches the samples.
                    staleness_values = [getattr(item, "seq_staleness", 0) for item in record.item]
                    staleness = max(staleness_values, default=0)
                    if any(value >= stale_threshold for value in staleness_values):
                        # When completed / aborted samples exceed the step-level
                        # threshold, flip the whole group so that a later sampler
                        # can re-sample it as EXPIRED.
                        for item in record.item:
                            reset_rollout_response(item)
                            item.status = Status.EXPIRED
                        status = Status.EXPIRED
                        expired_count += 1
                    else:
                        status = get_group_status(record.item)
                    updated_records.append(replace(record, status=status, staleness=staleness))
                expired_counts[task_name] = expired_count
            # One write-back for all tasks, still under the lock.
            await self._storage.update(updated_records)
        return expired_counts

    async def is_ready(
        self,
        task_batch_sizes: dict[str, int],
        *,
        group_status: Status = Status.COMPLETED,
    ) -> bool:
        """Return True when every task has at least its requested number of groups.

        Tasks with a non-positive batch size are ignored.
        """
        async with self._lock:
            for task_name, batch_size in task_batch_sizes.items():
                if batch_size <= 0:
                    continue
                query_dsl = self._task_status_query(task_name, group_status)
                if await self._policy.count(query_dsl, self._storage) < batch_size:
                    return False
        return True

    async def take_batch(
        self,
        task_batch_sizes: dict[str, int],
        *,
        group_status: Status = Status.COMPLETED,
    ) -> tuple[dict[str, list[list[RolloutState]]], dict[str, int]]:
        """Take up to the requested number of groups for every task under one lock.

        Returns:
            ``(batch_by_task, consumed_counts)``: the groups taken per task and
            how many were taken. A task with a non-positive batch size gets an
            empty list and a count of 0.
        """
        batch_by_task: dict[str, list[list[RolloutState]]] = {}
        consumed_counts: dict[str, int] = {}
        async with self._lock:
            for task_name, batch_size in task_batch_sizes.items():
                if batch_size <= 0:
                    batch_by_task[task_name] = []
                    consumed_counts[task_name] = 0
                    continue
                query_dsl = self._task_status_query(task_name, group_status)
                task_batch = await self._policy.get(batch_size, query_dsl, self._storage)
                batch_by_task[task_name] = task_batch
                consumed_counts[task_name] = len(task_batch)
        return batch_by_task, consumed_counts

    async def count_statuses(
        self,
        task_names: list[str],
        statuses: list[Status],
    ) -> dict[str, dict[Status, int]]:
        """Return, per task name, the number of stored groups for each requested status."""
        counts: dict[str, dict[Status, int]] = {task_name: {} for task_name in task_names}
        async with self._lock:
            for task_name in task_names:
                for status in statuses:
                    query_dsl = self._task_status_query(task_name, status)
                    counts[task_name][status] = await self._policy.count(query_dsl, self._storage)
        return counts

    def __len__(self) -> int:
        return len(self._storage)

    async def save(self, path: str | Path) -> None:
        """Write the buffer state to ``<path>/replay_buffer.pth`` (directory is created).

        The policy and storage class names are saved alongside the storage
        state so that ``resume`` can reject a mismatching configuration.
        """
        file_path = Path(path)
        file_path.mkdir(parents=True, exist_ok=True)
        replay_buffer_path = file_path / self._SAVE_PATH
        async with self._lock:
            state = {
                "policy": type(self._policy).__name__,
                "storage": type(self._storage).__name__,
                "storage_state": self._storage.state_dict(),
            }
            # Serialize in a worker thread; the lock stays held so the stored
            # records cannot be modified while they are being written.
            await asyncio.to_thread(torch.save, state, replay_buffer_path)
        logger.info(f"Replay buffer saved to {replay_buffer_path}")

    async def resume(self, path: str | Path) -> None:
        """Load the state written by ``save`` from ``<path>/replay_buffer.pth``.

        Raises:
            RuntimeError: If the buffer already contains records.
            ValueError: If the saved policy or storage class name differs from
                the one this buffer was built with.
        """
        if len(self._storage) > 0:
            raise RuntimeError("Cannot resume into a non-empty buffer")
        file_path = Path(path)
        replay_buffer_path = file_path / self._SAVE_PATH
        # SECURITY NOTE: ``weights_only=False`` performs full unpickling, which can
        # execute arbitrary code. Only resume from checkpoints you trust.
        state = await asyncio.to_thread(torch.load, replay_buffer_path, map_location="cpu", weights_only=False)
        if state["policy"] != type(self._policy).__name__:
            raise ValueError(f"Replay policy mismatch: expected {type(self._policy).__name__}, got {state['policy']}")
        if state["storage"] != type(self._storage).__name__:
            raise ValueError(
                f"Storage backend mismatch: expected {type(self._storage).__name__}, got {state['storage']}"
            )
        async with self._lock:
            self._storage.load_state_dict(state["storage_state"])
        logger.info(f"Replay buffer resumed from {replay_buffer_path}")
class SyncReplayBufferConfig(BaseModel):
    """Configuration for the synchronous replay buffer.

    The synchronous replay buffer uses FIFO replay policy and in-memory native
    Python storage. It is intended for on-demand rollout production where
    samples are consumed by the current training step.

    Args:
        No user-configurable fields.

    **Examples:**

    Example synchronous replay buffer::

        config = SyncReplayBufferConfig()
    """

    # Reject unknown keys instead of silently ignoring them.
    model_config = ConfigDict(extra="forbid")

    def build(self):
        """Return a ``ReplayBuffer`` using ``FIFOReplayPolicy`` and ``NaiveStorage``."""
        return ReplayBuffer(policy=FIFOReplayPolicy(), storage_backend=NaiveStorage())
class AsyncReplayBufferConfig(BaseModel):
    """Configuration for the asynchronous replay buffer.

    The asynchronous replay buffer uses a staleness-aware replay policy and
    in-memory native Python storage. It is intended for background rollout
    production where samples may be generated by an earlier model step.

    Args:
        No user-configurable fields.

    **Examples:**

    Example asynchronous replay buffer::

        config = AsyncReplayBufferConfig()
    """

    # Reject unknown keys instead of silently ignoring them.
    model_config = ConfigDict(extra="forbid")

    def build(self):
        """Return a ``ReplayBuffer`` using ``StalenessReplayPolicy`` and ``NaiveStorage``."""
        policy = StalenessReplayPolicy()
        return ReplayBuffer(policy=policy, storage_backend=NaiveStorage())