Source code for xtuner.v1.config.fsdp

from typing import Any, Optional

import torch
from cyclopts import Parameter
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
from typing_extensions import Annotated


[docs]class FSDPConfig(BaseModel): model_config = ConfigDict( arbitrary_types_allowed=True, protected_namespaces=(), extra="forbid", ) tp_size: Annotated[int, Parameter(help="Tensor parallel size")] = 1 ep_size: Annotated[int, Parameter(help="Expert parallel size")] = 1 reshard_after_forward: Annotated[bool, Parameter(help="Reshard model parameters after forward pass")] = True recompute_ratio: Annotated[float, Parameter(help="Gradient checkpointing ratio for memory optimization")] = 1.0 vision_recompute_ratio: Annotated[float, Parameter(help="Recompute ratio for vision modules")] = 1.0 checkpoint_preserve_rng_state: Annotated[bool, Parameter(help="Preserve RNG state during checkpointing")] = True cpu_offload: Annotated[bool, Parameter(help="Enable CPU offloading for memory optimization")] = False # TODO: (caoweihan) Convert `torch.dtype` to `Annotated` for compatibility with cyclopts param_dtype: Annotated[torch.dtype, Parameter(help="Data type for model parameters")] = torch.bfloat16 reduce_dtype: Annotated[torch.dtype, Parameter(help="Data type for reduction operations")] = torch.bfloat16 fp32_lm_head: Annotated[bool, Parameter(help="Use float32 for language model head")] = False # TODO: deprecate `torch_compile` in favor of `compile_cfg` in XTunerBaseModelConfig torch_compile: Annotated[bool, Parameter(help="Enable model compilation for faster inference")] = True mesh_prefix: Annotated[str, Parameter(help="Prefix for device mesh configuration in distributed training")] = ( "default" ) requires_grad: Annotated[bool, Parameter(help="Enable gradient computation for model parameters")] = True hsdp_sharding_size: Annotated[ Optional[int], Parameter(help="Sharding size for HSDP (Hybrid Sharding Data Parallel)") ] = None def model_post_init(self, __context: Any) -> None: if self.hsdp_sharding_size is not None: assert self.ep_size == 1, "Currently, HSDP requires expert parallel size to be 1" @field_serializer("param_dtype", "reduce_dtype") def serialize_param_dtype(self, value: torch.dtype) -> str: return str(value) @field_validator("param_dtype", "reduce_dtype", mode="before") @classmethod def deserialize_param_dtype(cls, value: str | torch.dtype) -> torch.dtype: if isinstance(value, torch.dtype): return value elif isinstance(value, str): if "bfloat16" in value: return torch.bfloat16 elif "float16" in value or "half" in value: return torch.float16 elif "float32" in value or "float" in value: return torch.float32 else: raise ValueError() else: return value