# Source code for xtuner.v1.train.toy_tokenizer

from pathlib import Path
from typing import Literal, Union

import numpy as np
import torch

from transformers.tokenization_utils_base import BatchEncoding


def _replace_token_ids(text, ids, special_str, special_id):
    if special_str in text:
        _positions = []
        start = 0
        while True:
            pos = text.find(special_str, start)
            if pos == -1:
                break
            _positions.append(pos)
            start = pos + len(special_str)

        new_ids = []
        current_pos = 0

        for _pos in _positions:
            _byte = len(text[:_pos].encode("utf-8"))
            new_ids.extend(ids[current_pos:_byte])
            new_ids.append(special_id)
            current_pos = _byte + len(special_str.encode("utf-8"))

        new_ids.extend(ids[current_pos:])
        ids = new_ids
    return ids


[docs]class UTF8ByteTokenizer: """A simple byte-level tokenizer that encodes text as UTF-8 bytes with special token handling. This tokenizer converts text into a sequence of byte values (0-255) and supports special tokens for text boundaries and image-related placeholders. It provides basic encoding/decoding functionality compatible with transformers' tokenizer interface. Args: bos_token_id (int, optional): Beginning of sequence token ID. Defaults to 256. eos_token_id (int, optional): End of sequence token ID. Defaults to 257. pad_token_id (int, optional): Padding token ID. Defaults to 258. image_start_id_map (dict[str, int], optional): Mapping from image start string to token ID. Defaults to {"<img>": 259}. image_context_id_map (dict[str, int], optional): Mapping from image context string to token ID. Defaults to {"<IMG_CONTEXT>": 260}. image_end_id_map (dict[str, int], optional): Mapping from image end string to token ID. Defaults to {"</img>": 261}. """ def __init__( self, bos_token_id: int | None = 256, eos_token_id: int | None = 257, pad_token_id: int | None = 258, image_start_id_map: dict[str, int] | None = None, image_context_id_map: dict[str, int] | None = None, image_end_id_map: dict[str, int] | None = None, ): if image_start_id_map is None: image_start_id_map = {"<img>": 259} if image_context_id_map is None: image_context_id_map = {"<IMG_CONTEXT>": 260} if image_end_id_map is None: image_end_id_map = {"</img>": 261} self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id self.pad_token_id = pad_token_id self.eos_token = "<end>" self.image_start_str = list(image_start_id_map.keys())[0] self.image_context_str = list(image_context_id_map.keys())[0] self.image_end_str = list(image_end_id_map.keys())[0] self.image_start_id = list(image_start_id_map.values())[0] self.image_context_id = list(image_context_id_map.values())[0] self.image_end_id = list(image_end_id_map.values())[0] self.special_ids = {x for x in [bos_token_id, eos_token_id, pad_token_id] 
if x is not None} self.vocab_size = 256 + len(self.special_ids) + 3
[docs] def encode(self, text: str, add_special_tokens: bool = False) -> list[int]: """Encode text into a sequence of token IDs. Converts text to UTF-8 bytes and replaces special image tokens with their corresponding IDs. Optionally adds BOS/EOS tokens. Args: text (str): The text to encode. add_special_tokens (bool, optional): Whether to add BOS/EOS tokens. Defaults to False. Returns: list[int]: List of token IDs representing the encoded text. """ # 严格 UTF-8 编码,遇到非法字符会抛错;如需忽略可用 errors='ignore' b = text.encode("utf-8", errors="strict") ids = list(b) # 每个字节 0..255 ids = _replace_token_ids(text, ids, self.image_start_str, self.image_start_id) ids = _replace_token_ids(text, ids, self.image_context_str, self.image_context_id) ids = _replace_token_ids(text, ids, self.image_end_str, self.image_end_id) if add_special_tokens and self.bos_token_id is not None: ids = [self.bos_token_id] + ids if add_special_tokens and self.eos_token_id is not None: ids = ids + [self.eos_token_id] return ids
[docs] def decode(self, ids: list[int], skip_special_tokens: bool = True) -> str: """Decode a sequence of token IDs back to text. Converts byte-level token IDs back to UTF-8 text. Special tokens are either skipped or kept as-is based on skip_special_tokens parameter. Args: ids (list[int]): List of token IDs to decode. skip_special_tokens (bool, optional): Whether to skip special tokens during decoding. Defaults to True. Returns: str: The decoded text. Raises: ValueError: If any token ID is negative. """ bytes_list = [] for t in ids: if t < 0: raise ValueError(f"invalid token id: {t}") if t < 256: bytes_list.append(t) else: if not skip_special_tokens and t in self.special_ids: # 特殊符号默认不还原为字符;如有需要可自定义占位符 pass # 其余 >=256 的非特殊 id 直接忽略 return bytes(bytes_list).decode("utf-8", errors="strict")
[docs] def convert_tokens_to_ids(self, token: str) -> Union[int, list[int]]: ids = self.encode(token) if len(ids) == 1: return ids[0] return ids
[docs] def save_pretrained(self, save_directory: str | Path): pass
def __call__(self, text: list[str], return_tensors: Literal["np", "pt"] = "np", **kwargs): """Encode a batch of texts into token IDs with tensor format. Processes multiple texts simultaneously and returns them in the specified tensor format compatible with transformers library. Args: text (list[str]): List of texts to encode. return_tensors (Literal["np", "pt"], optional): Format of returned tensors. "np" for numpy arrays, "pt" for PyTorch tensors. Defaults to "np". **kwargs: Additional arguments passed to encode method, including add_special_tokens (bool). Returns: BatchEncoding: Object containing encoded token IDs with the specified tensor type. """ batch_ids = [] for _text in text: ids = self.encode(_text, add_special_tokens=kwargs.get("add_special_tokens", False)) if return_tensors == "np": ids = np.array(ids, dtype=np.int32) elif return_tensors == "pt": ids = torch.tensor(ids, dtype=torch.int32) # type: ignore batch_ids.append(ids) return BatchEncoding({"input_ids": batch_ids}, tensor_type=return_tensors)