# Source code for xtuner.v1.rl.agent_loop_manager.sampler

import copy
from pathlib import Path
from typing import Iterator, Optional, cast
from uuid import uuid4

import ray
import torch
from pydantic import BaseModel, ConfigDict, Field

from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from xtuner.v1.data_proto.rl_data import RolloutState, Status
from xtuner.v1.datasets.config import DataloaderConfig
from xtuner.v1.datasets.dataloader import Dataloader
from xtuner.v1.rl.replay_buffer import ReplayBuffer
from xtuner.v1.utils import XTUNER_DETERMINISTIC
from xtuner.v1.utils.logger import get_logger


# Module-level logger; Sampler.resume uses it to warn when no saved dataloader state is found.
logger = get_logger(__name__)


class SamplerConfig(BaseModel):
    """Configuration for sampling prompts into rollout groups.

    ``SamplerConfig`` wraps a dataloader configuration and controls how many
    rollout samples are generated from the same prompt. The sampler first
    tries to reuse eligible replay-buffer samples and falls back to the
    dataloader when no reusable sample is available.

    Args:
        dataloader_cfg (DataloaderConfig): Dataset dataloader configuration
            that yields ``RolloutState`` prompts.
        prompt_repeat_k (int): Number of rollout samples to create for each
            prompt. This is commonly the GRPO group size. Must be at least 1;
            smaller values would consume prompts while producing empty groups.
            Defaults to 1.

    **Examples:**

    Example sampler for an 8-response group::

        config = SamplerConfig(
            dataloader_cfg=dataloader_cfg,
            prompt_repeat_k=8,
        )
    """

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    dataloader_cfg: DataloaderConfig
    # Constrained to >= 1: a repeat count of 0 (or negative) yields empty rollout groups.
    prompt_repeat_k: int = Field(default=1, ge=1)

    def build(
        self, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str, replay_buffer: ReplayBuffer
    ) -> "Sampler":
        """Build a :class:`Sampler` from this configuration.

        Args:
            tokenizer: A tokenizer instance, or a name/path string that is loaded
                with ``AutoTokenizer.from_pretrained(..., trust_remote_code=True)``.
            replay_buffer: Replay buffer the sampler queries before falling back
                to the dataloader.

        Returns:
            Sampler: A sampler over a dataloader built from ``dataloader_cfg``.
        """
        if isinstance(tokenizer, str):
            tokenizer_obj = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
        else:
            tokenizer_obj = tokenizer
        # Unit batch sizes with no dp mesh: the sampler takes one prompt per
        # ``next()`` (element [0] of each batch). The seed is fixed at 1.
        dataloader = self.dataloader_cfg.build(
            tokenizer=tokenizer_obj, dp_mesh=None, global_batch_size=1, micro_batch_size=1, seed=1
        )
        return Sampler(dataloader=dataloader, prompt_repeat_k=self.prompt_repeat_k, replay_buffer=replay_buffer)
# TODO: The best solution is to put it in the fake_collator,
# but it will cause a deadlock problem, so it is temporarily placed here.
# The best solution should be to start the dataloader using spawn.
def put_to_ray(data: RolloutState) -> RolloutState:
    """Move multimodal ``pixel_values`` of ``data`` into the Ray object store.

    If ``data`` has a non-``None`` ``mm_info`` mapping with a non-``None``
    ``"pixel_values"`` entry, that entry is replaced in place by the reference
    returned from ``ray.put``. Anything else is left untouched.

    Returns:
        RolloutState: The same ``data`` object, possibly mutated.
    """
    if getattr(data, "mm_info", None) is not None:
        pixel_values = data.mm_info.get("pixel_values", None)
        if pixel_values is not None:
            data.mm_info["pixel_values"] = ray.put(pixel_values)
    return data


class _DatasetSampler:
    """Draw prompts from a dataloader and expand each into a rollout group.

    The dataloader is iterated lazily. When the iterator is exhausted, the
    epoch counter is incremented, ``set_epoch`` is called, and iteration
    restarts. Sampling therefore continues indefinitely as long as each epoch
    yields at least one item.
    """

    def __init__(self, dataloader: Dataloader, prompt_repeat_k: int):
        if prompt_repeat_k < 1:
            raise ValueError(f"prompt_repeat_k must be >= 1, got {prompt_repeat_k}")
        self.dataloader = dataloader
        # Created lazily on the first sample (or rebuilt on resume).
        self.dataloader_iter: Optional[Iterator] = None
        self.cur_epoch = 0
        self.prompt_repeat_k = prompt_repeat_k
        # Number of prompts (not rollout samples) drawn so far; seeds deterministic uids.
        self._consumed_samples: int = 0

    def __len__(self) -> int:
        return len(self.dataloader)

    def _next_prompt(self) -> RolloutState:
        """Return the next prompt, rolling over to a new epoch when exhausted.

        Raises:
            RuntimeError: If the dataloader yields nothing even after restarting.
        """
        if self.dataloader_iter is None:
            self.dataloader_iter = iter(self.dataloader)
        try:
            batch = next(self.dataloader_iter)
        except StopIteration:
            self.cur_epoch += 1
            self.dataloader.set_epoch(self.cur_epoch)
            self.dataloader_iter = iter(self.dataloader)
            try:
                batch = next(self.dataloader_iter)
            except StopIteration:
                # A bare StopIteration escaping into the async ``sample`` coroutine
                # would surface as an opaque "coroutine raised StopIteration".
                raise RuntimeError(
                    f"Dataloader yielded no data in epoch {self.cur_epoch}; cannot sample prompts."
                ) from None
        return put_to_ray(cast(RolloutState, batch[0]))

    def sample_from_dataloader(self) -> list[RolloutState]:
        """Draw one prompt and expand it into ``prompt_repeat_k`` copies.

        Each copy is an independent deep copy with its own ``uid``.
        When ``XTUNER_DETERMINISTIC`` is set, prompt number ``i`` gets
        ``message_uid = i``. Its rollouts get ``uid``/``session_uid`` values
        ``i * k`` … ``i * k + k - 1``. Otherwise each copy gets a random
        ``uuid4().int`` uid.

        Returns:
            list[RolloutState]: The rollout group for a single prompt.
        """
        data = self._next_prompt()

        if XTUNER_DETERMINISTIC:
            message_uid = self._consumed_samples
            uid_base = self._consumed_samples * self.prompt_repeat_k

        group_data: list[RolloutState] = []
        for item_idx in range(self.prompt_repeat_k):
            # Deep copy so each rollout in the group can be mutated independently.
            new_data = copy.deepcopy(data)
            if XTUNER_DETERMINISTIC:
                new_data.message_uid = message_uid
                new_data.uid = uid_base + item_idx
                new_data.session_uid = new_data.uid
            else:
                new_data.uid = uuid4().int
            group_data.append(new_data)
        self._consumed_samples += 1
        return group_data


class Sampler(_DatasetSampler):
    """Prompt sampler that prefers replay-buffer groups over fresh dataloader prompts."""

    _DATALOADER_FILE = "dataloader"

    def __init__(
        self,
        dataloader: Dataloader,
        prompt_repeat_k: int,
        replay_buffer: ReplayBuffer,
    ):
        super().__init__(dataloader, prompt_repeat_k)
        self.replay_buffer = replay_buffer

    async def sample(self, task_name: str, group_status: list[Status] | None = None) -> list[RolloutState]:
        """Return one rollout group, reusing the replay buffer when possible.

        For each status in ``group_status``, in order, asks the replay buffer
        for one entry of ``task_name`` with that status. It returns the first
        element of the first non-empty result. If ``group_status`` is ``None``
        or empty, or nothing is found, a fresh group is drawn from the
        dataloader.
        """
        for status in group_status or []:
            buffer_data = await self.replay_buffer.get(1, task_name=task_name, group_status=status)
            if buffer_data:
                return buffer_data[0]
        return self.sample_from_dataloader()

    def save(self, checkpoint_path: Path | str) -> None:
        """Save the sampler's dataloader state to ``<checkpoint_path>/dataloader``.

        The checkpoint directory is created if it does not exist yet.
        """
        checkpoint_path = Path(checkpoint_path)
        checkpoint_path.mkdir(parents=True, exist_ok=True)
        dataloader_state = self.dataloader.get_state_dict()
        torch.save(dataloader_state, checkpoint_path / self._DATALOADER_FILE)

    def resume(self, checkpoint_path: Path | str) -> None:
        """Resume the sampler's dataloader state from checkpoint.

        Logs a warning and returns without changes if
        ``<checkpoint_path>/dataloader`` is missing. Otherwise it does three
        things:

        * loads the state into the dataloader;
        * rebuilds the iterator;
        * restores the prompt counter and epoch from
          ``state["sampler"]["step"]`` and ``state["sampler"]["epoch"]``.
        """
        checkpoint_path = Path(checkpoint_path)
        dataloader_path = checkpoint_path / self._DATALOADER_FILE
        if not dataloader_path.exists():
            logger.warning(f"Dataloader state {dataloader_path} not found, skipping resume.")
            return
        state = torch.load(dataloader_path, map_location="cpu")
        self.dataloader.load_state_dict(state)
        # Rebuild the iterator so sampling continues from the restored position
        # (presumably encoded in the loaded state — confirm against Dataloader).
        self.dataloader_iter = iter(self.dataloader)
        self._consumed_samples = state["sampler"]["step"]
        self.cur_epoch = state["sampler"]["epoch"]