import contextlib
import gc
import inspect
import json
import os
import pickle
import sys
import time
from concurrent.futures import Future, TimeoutError
from contextlib import contextmanager
from datetime import datetime, timedelta
from pathlib import Path
from shutil import rmtree
from typing import (
Annotated,
Callable,
Literal,
Protocol,
Sequence,
Sized,
cast,
overload,
runtime_checkable,
)
import torch
import torch.distributed as dist
import torch.nn as nn
from cyclopts import Parameter
from mmengine import load
from mmengine.dist import get_rank, get_world_size
from mmengine.runner import set_random_seed
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator, model_serializer, model_validator
from torch.distributed import init_process_group
from torch.distributed.device_mesh import init_device_mesh
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, LinearLR, SequentialLR
from typing_extensions import NotRequired, Self, TypedDict
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from xtuner.v1._writer import get_writer
from xtuner.v1.config import FSDPConfig, LRConfig, OptimConfig
from xtuner.v1.data_proto.sequence_context import SequenceContext
from xtuner.v1.datasets.config import BaseDataloaderConfig, DataloaderConfig, DatasetConfigList
from xtuner.v1.engine import TrainEngine
from xtuner.v1.engine.train_engine import TrainStepInfo
from xtuner.v1.loss import CELossConfig
from xtuner.v1.model.base import AsyncHFSaveHandle, ModelItem, XTunerBaseModelConfig
from xtuner.v1.model.moe.moe import MoEConfig
from xtuner.v1.patch import patch_dcp_save_state_dict, patch_dcp_save_with_cache_storage, patch_default_save_plan
from xtuner.v1.profiler import profiling_memory, profiling_time
from xtuner.v1.profiler.prober import ProberList
from xtuner.v1.profiler.prober_utils import setup_prober_list
from xtuner.v1.utils import (
XTUNER_DETERMINISTIC,
ParallelConfigException,
StrEnum,
get_logger,
is_hf_model_path,
log_format,
log_rank0,
profile_time_and_memory,
record_git_info,
set_deterministic,
)
from xtuner.v1.utils.check_health import check_health
from xtuner.v1.utils.device import get_device, get_torch_device_module
from xtuner.v1.utils.internal_metrics import (
InternalMetrics,
InternalMetricsConfig,
InternalMetricsRecorder,
flatten_internal_metrics_for_logs,
)
from .toy_tokenizer import UTF8ByteTokenizer
# TODO: Move DEVICE to `xtuner.utils.device`
# Module-level device handles, resolved once at import time through the
# `xtuner.v1.utils.device` helpers. `DEVICE` is used as a `.to(...)` target and
# `DEVICE_MODULE` for device-side utilities such as `reset_peak_memory_stats()`.
DEVICE = get_device()
DEVICE_MODULE = get_torch_device_module()
# Shared logger obtained from `xtuner.v1.utils.get_logger`.
logger = get_logger()
class GitInfo(TypedDict):
    """Git state captured when a run (or a resumed segment of it) starts."""

    # Value returned by `record_git_info`; may be None.
    commit: str | None
    # Path (as a string) of the file that receives the staged diff.
    staged: str
    # Path (as a string) of the file that receives the unstaged diff.
    unstaged: str
class ExpHistory(TypedDict):
    """One entry of an experiment's run history.

    An entry is appended whenever an experiment is created or resumed.
    """

    # Step this segment starts from (0 for a freshly created experiment).
    begin: int
    # Start time formatted as "%Y%m%d%H%M%S".
    timestamp: str
    git_info: GitInfo
    # Presumably the step at which the segment stopped — confirm against the code that writes it.
    end: NotRequired[int]
    comment: NotRequired[str]
class PerformanceStatistics(TypedDict):
    """Throughput / progress figures reported for a training step.

    Naming convention used by the trainer: a `local_` prefix refers to the
    current rank only, `step_` to the current step and `total_` to the value
    accumulated so far.
    """

    local_step_consumed_tokens: int
    local_step_consumed_img_tokens: int | None
    local_total_consumed_tokens: int
    approximate_total_consumed_tokens: int
    # NOTE(review): presumably tokens per GPU per second — confirm with the producer.
    tgs: float
    # NOTE(review): meaning of `exp_tgs` is not visible here — confirm.
    exp_tgs: float
    eta_seconds: float
    # String rendering of the ETA (format not visible here).
    eta_hms: str
    e2e_train_time: float
class ExpInfo(BaseModel):
    """Bookkeeping for a single experiment directory.

    Tracks the run history, the saved checkpoints (regular, snapshot and
    HuggingFace-format) and the training progress counters.
    """

    model_config = ConfigDict(extra="forbid")
    history: list[ExpHistory]
    exp_dir: str
    hf_checkpoint_list: list[str] = []
    checkpoint_list: list[str] = []
    snap_checkpoint_list: list[str] = []
    cur_step: int = 0
    cur_epoch: int = 0
    consumed_tokens: int = 0
    consumed_samples: int = 0

    @property
    def latest_checkpoint(self) -> str | None:
        """Return the newest of the last regular and last snapshot checkpoint.

        Returns:
            str | None: The chosen checkpoint path, or None when neither list
            has an entry.
        """
        newest_regular = self.checkpoint_list[-1] if self.checkpoint_list else None
        if not self.snap_checkpoint_list:
            return newest_regular
        return self._get_latest_checkpoint(newest_regular, self.snap_checkpoint_list[-1])

    def _get_latest_checkpoint(self, ckp1: str | None, ckp2: str | None) -> str | None:
        """Pick the checkpoint with the larger trailing step number.

        Paths look like ``checkpoints/epoch-1-step-20`` or
        ``checkpoints/snapshot-epoch-3-step-50``; the integer after the last
        ``-`` is compared. On a tie ``ckp2`` wins.
        """
        if ckp1 is None:
            return ckp2
        if ckp2 is None:
            return ckp1

        def trailing_step(path: str) -> int:
            return int(path.rsplit("-", 1)[-1])

        first_step = trailing_step(ckp1)
        second_step = trailing_step(ckp2)
        if first_step > second_step:
            return ckp1
        return ckp2
class XTunerMeta(BaseModel):
    """Persistent record of every experiment launched in a work directory.

    Experiments are stored in creation order: new ones are appended, so the
    most recent experiment is ``exps[-1]``.
    """

    model_config = ConfigDict(extra="forbid")
    exps: list[ExpInfo]

    @property
    def latest_checkpoint(self) -> str | None:
        """Return the latest checkpoint of the most recent experiment that has one."""
        # Walk newest-first; iterating forward would return a stale checkpoint
        # from the oldest experiment.
        for exp in reversed(self.exps):
            checkpoint = exp.latest_checkpoint
            if checkpoint is not None:
                return checkpoint
        return None

    @property
    def latest_hf_checkpoint(self) -> str | None:
        """Return the last HF checkpoint of the most recent experiment that has one."""
        for exp in reversed(self.exps):
            if exp.hf_checkpoint_list:
                return exp.hf_checkpoint_list[-1]
        return None

    @property
    def latest_exp(self) -> ExpInfo:
        """Return the most recently added experiment (raises IndexError if there is none)."""
        return self.exps[-1]

    def get_exp_by_checkpoint(self, checkpoint: str) -> ExpInfo | None:
        """Return the experiment whose ``checkpoint_list`` contains ``checkpoint``.

        Only regular checkpoints are searched; returns None when no experiment
        lists the given path.
        """
        for exp in self.exps:
            if checkpoint in exp.checkpoint_list:
                return exp
        return None

    @classmethod
    def build(cls, work_dir: Path, meta_filename: str, resume: bool) -> "XTunerMeta":
        """Create or load meta from work_dir and optionally start a new exp or
        resume.

        Single-process helper (e.g. for rl_trainer). For distributed training use the trainer's _init_xtuner_meta.

        Args:
            work_dir (Path): Directory holding the meta file; created if missing.
            meta_filename (str): Name of the meta file inside ``work_dir``.
            resume (bool): Continue the latest experiment when one exists;
                otherwise (or when False) a new experiment is started.

        Returns:
            XTunerMeta: The loaded meta with the new history entry or new
            experiment added. The updated meta is not written back to disk here.
        """
        if not work_dir.exists():
            work_dir.mkdir(parents=True, exist_ok=True)
        meta_path = work_dir / meta_filename
        if not meta_path.exists():
            # Seed an empty meta file so the load below always succeeds.
            meta = cls(exps=[])
            with open(meta_path, "w") as f:
                f.write(meta.model_dump_json(indent=2))
        meta = cast(XTunerMeta, cls.model_validate(load(meta_path, file_format="json")))

        # Resuming is only possible when a previous experiment exists.
        resume = resume and bool(meta.exps)
        if resume:
            latest_exp = meta.exps[-1]
            latest_exp_history = latest_exp.history[-1]
            # Continue from where the last segment ended, falling back to where it began.
            begin = cast(int, latest_exp_history.get("end") or latest_exp_history["begin"])
            exp_dir = Path(latest_exp.exp_dir)
            git_dir = exp_dir / f"git-info-begin-{begin}"
            if not git_dir.exists():
                git_dir.mkdir(parents=True, exist_ok=True)
            staged_path, unstaged_path = git_dir / "staged.diff", git_dir / "unstaged.diff"
            commit = record_git_info(staged_path, unstaged_path)
            git_info = GitInfo(
                commit=commit,
                staged=str(staged_path),
                unstaged=str(unstaged_path),
            )
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
            new_exp_history = ExpHistory(
                begin=begin,
                timestamp=timestamp,
                git_info=git_info,
            )
            latest_exp.history.append(new_exp_history)
        else:
            # Fresh experiment: its directory is named after the start timestamp.
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
            exp_dir = work_dir / timestamp
            git_dir = Path(f"{exp_dir}/git-info-begin-0")
            if not git_dir.exists():
                git_dir.mkdir(parents=True, exist_ok=True)
            staged_path, unstaged_path = git_dir / "staged.diff", git_dir / "unstaged.diff"
            commit = record_git_info(staged_path, unstaged_path)
            git_info = GitInfo(
                commit=commit,
                staged=str(staged_path),
                unstaged=str(unstaged_path),
            )
            new_history = ExpHistory(
                begin=0,
                timestamp=timestamp,
                git_info=git_info,
            )
            new_exp = ExpInfo(history=[new_history], exp_dir=str(exp_dir))
            meta.exps.append(new_exp)
        return meta
class ResumeConfig(BaseModel):
    """Legacy resume options.

    The trainer warns that `resume_cfg` is deprecated in favour of
    `auto_resume` and `load_checkpoint_cfg`.
    """

    model_config = ConfigDict(extra="forbid")
    # Checkpoint to resume from, if any.
    resume_from: str | Path | None = None
    auto_resume: bool = False
    # The `load_*` switches presumably select which parts of the checkpoint are restored — confirm in the loader.
    load_optimizer_states: bool = True
    load_optimizer_args: bool = True
    load_dataset: bool = True
    load_scheduler: bool = True
@runtime_checkable
class CheckpointHookBase(Protocol):
    """Structural type for a callable invoked with a checkpoint path and progress info."""

    def __call__(
        self,
        checkpoint: Path,
        step: int,
        epoch: int | None,
        total_step: int,
        total_epoch: int | None,
    ) -> None: ...
@runtime_checkable
class CheckpointHook(CheckpointHookBase, Protocol):
    """Checkpoint hook that can additionally be given a reference to the trainer."""

    def connect_trainer(self, trainer: "Trainer"): ...
@runtime_checkable
class TrainStepHookBase(Protocol):
    """Structural type for a callable invoked with the step's `TrainStepInfo` and progress info."""

    def __call__(
        self,
        train_step_info: TrainStepInfo,
        step: int,
        epoch: int | None,
        total_step: int,
        total_epoch: int | None,
    ) -> None: ...
@runtime_checkable
class TrainStepHook(TrainStepHookBase, Protocol):
    """Train-step hook that can additionally be given a reference to the trainer."""

    def connect_trainer(self, trainer: "Trainer"): ...
# Union aliases: a hook may be a plain callable matching the `*Base` protocol
# or the richer variant that also exposes `connect_trainer`.
TrainStepHookProtocol = TrainStepHookBase | TrainStepHook
CheckpointHookProtocol = CheckpointHookBase | CheckpointHook
# Any hook accepted by `HooksConfig`.
HookProtocol = TrainStepHookProtocol | CheckpointHookProtocol
class HookStage(StrEnum):
    """Stages at which hooks can be registered.

    Each value equals the name of the matching `HooksConfig` field, which is
    how `HooksConfig.get_hooks` looks the hooks up.
    """

    AFTER_SAVE_DCP = "after_save_dcp"
    AFTER_SAVE_HF = "after_save_hf"
    AFTER_SAVE_SNAPSHOT = "after_save_snapshot"
    AFTER_TRAIN_STEP = "after_train_step"
class HooksConfig(BaseModel):
    """User hooks grouped by the stage at which they run.

    Every field accepts a single hook, a list of hooks or None; after
    validation a non-None value is always a list.
    """

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
    after_save_dcp: list[CheckpointHookProtocol] | CheckpointHookProtocol | None = None
    after_save_hf: list[CheckpointHookProtocol] | CheckpointHookProtocol | None = None
    after_save_snapshot: list[CheckpointHookProtocol] | CheckpointHookProtocol | None = None
    after_train_step: list[TrainStepHookBase] | TrainStepHookBase | None = None

    @field_validator("after_train_step", "after_save_dcp", "after_save_hf", "after_save_snapshot", mode="after")
    @classmethod
    def _validate_hooks(
        cls,
        value: list[HookProtocol] | HookProtocol | None,
    ) -> list[Callable] | None:
        """Normalise a hook field so that a single hook becomes a one-element list."""
        if value is None:
            return None
        if not isinstance(value, list):
            value = [value]
        return value

    def _get_hook_name(self, hook: HookProtocol) -> str:
        """Return the function name for plain functions, else the class name."""
        if inspect.isfunction(hook):
            return hook.__name__
        return hook.__class__.__name__

    @model_serializer
    def serialize_hooks(self) -> dict[str, list[str] | None]:
        """Serialise every stage as the list of its hook names (hooks themselves are not serialisable)."""

        def serialize_hook_list(hook_list: Sequence[HookProtocol]) -> list[str]:
            return [self._get_hook_name(hook) for hook in hook_list]

        return {
            "after_save_dcp": serialize_hook_list(self.get_hooks(HookStage.AFTER_SAVE_DCP)),
            "after_save_hf": serialize_hook_list(self.get_hooks(HookStage.AFTER_SAVE_HF)),
            "after_save_snapshot": serialize_hook_list(self.get_hooks(HookStage.AFTER_SAVE_SNAPSHOT)),
            "after_train_step": serialize_hook_list(self.get_hooks(HookStage.AFTER_TRAIN_STEP)),
        }

    @overload
    def get_hooks(self, stage: Literal[HookStage.AFTER_TRAIN_STEP]) -> list[TrainStepHookProtocol]: ...

    @overload
    def get_hooks(
        self,
        stage: Literal[HookStage.AFTER_SAVE_DCP, HookStage.AFTER_SAVE_HF, HookStage.AFTER_SAVE_SNAPSHOT],
    ) -> list[CheckpointHookProtocol]: ...

    @overload
    def get_hooks(self, stage: HookStage) -> list[HookProtocol]: ...

    def get_hooks(
        self,
        stage: HookStage,
    ) -> list:
        """Return the hooks registered for ``stage`` as a list (empty when none)."""
        # The stage value doubles as the attribute name of the field.
        hooks = getattr(self, stage)
        if hooks is None:
            return []
        if not isinstance(hooks, list):
            hooks = [hooks]
        return hooks

    def __getstate__(self):
        """Return a picklable state, replacing unpicklable values by a marker string."""
        state = {}
        for k, v in self.__dict__.items():
            try:
                pickle.dumps(v)
            # Some <local> function could raise AttributeError; objects such as
            # locks or open handles raise TypeError ("cannot pickle ... object").
            except (pickle.PicklingError, AttributeError, TypeError):
                state[k] = f"<unpicklable: {type(v)}>"
            else:
                state[k] = v
        return state

    def __setstate__(self, state):
        """Restore the state; values that could not be pickled come back as None."""
        valid_state = {
            k: None if isinstance(v, str) and v.startswith("<unpicklable:") else v for k, v in state.items()
        }
        self.__dict__.update(valid_state)
class LoadCheckpointConfig(BaseModel):
    """Options describing which checkpoint to load and what to restore from it."""

    model_config = ConfigDict(extra="forbid")
    # Checkpoint to load; the trainer only loads a checkpoint when this is not None.
    checkpoint_path: str | Path | None = None
    # The `load_*` switches presumably select which parts of the checkpoint are restored — confirm in the loader.
    load_optimizer_states: bool = True
    load_optimizer_args: bool = True
    load_dataset: bool = True
    load_scheduler: bool = True
class TrainerConfig(BaseModel):
    """Declarative configuration from which `Trainer.from_config` builds a Trainer."""

    model_config = ConfigDict(
        title="Trainer config",
        extra="forbid",
        arbitrary_types_allowed=True,
        protected_namespaces=(),
    )
    model_cfg: XTunerBaseModelConfig
    load_from: str | Path | None = None
    tokenizer_path: str | Path | None = None
    dataset_cfg: Annotated[DatasetConfigList | None, Parameter(show_default=False)] = (
        None  # TODO: Removed in version 1.1.0
    )
    dataloader_cfg: BaseDataloaderConfig
    optim_cfg: OptimConfig
    lr_cfg: LRConfig
    loss_cfg: CELossConfig = CELossConfig()
    fsdp_cfg: FSDPConfig | None = None
    global_batch_size: int | None
    work_dir: Path | str | None = None
    log_dir: Path | str | None = None
    sp_size: int = 1
    total_step: int | None = None
    total_epoch: int | None = None
    resume_cfg: ResumeConfig | None = None  # TODO: Removed in version 1.1.0
    auto_resume: bool = False
    load_checkpoint_cfg: LoadCheckpointConfig = LoadCheckpointConfig()
    strict_load: bool = True
    checkpoint_interval: int | None = -1
    checkpoint_maxkeep: int | None = -1
    async_hf_export: bool = False
    skip_checkpoint_validation: bool = False  # Suggest enabled if fsdp_size is larger than 512
    patch_for_dcp_finish: bool = False
    async_checkpoint: bool = False
    snapshot_interval: int | None = None
    check_health_interval: int | None = None
    hf_interval: int | None = None
    hf_max_keep: int | None = None
    exp_tracker: Literal["tensorboard", "jsonl"] = "jsonl"
    profile_step: list[int] | int | None = None
    profile_time: bool = True
    profile_memory: bool = False
    intra_layer_micro_batch: int = 1
    seed: int = 42
    dist_backend: str | None = None
    debug: bool = False
    debug_skip_save: bool = False
    prober_list: list[str] = []
    do_clip: bool = True
    grad_norm_dtype: torch.dtype = torch.float32
    hooks_config: HooksConfig = HooksConfig()
    internal_metrics_cfg: InternalMetricsConfig | None = None

    @model_validator(mode="after")
    def _convert_work_dir(self):
        """Coerce `work_dir` to a Path, defaulting to the current working directory."""
        if isinstance(self.work_dir, str):
            self.work_dir = Path(self.work_dir)
        elif self.work_dir is None:
            self.work_dir = Path.cwd()
        return self

    @field_serializer("grad_norm_dtype")
    def serialize_dtype(self, value: torch.dtype) -> str:
        """Serialise the dtype as its string form, e.g. ``"torch.float32"``."""
        return str(value)

    @field_validator("grad_norm_dtype", mode="before")
    @classmethod
    def deserialize_dtype(cls, value: str | torch.dtype) -> torch.dtype:
        """Convert a dtype given as a string or as a `torch.dtype` to a `torch.dtype`.

        Raises:
            ValueError: If the value names neither float32 nor float64.
        """
        # A "before" validator also receives real `torch.dtype` objects, on
        # which the `in` operator would raise TypeError; normalise to text first.
        text = str(value)
        if "float32" in text:
            return torch.float32
        elif "float64" in text:
            return torch.float64
        else:
            raise ValueError(f"grad_norm_dtype {value} is not supported, must be 'float32' or 'float64'")
[docs]class Trainer:
"""Trainer class for fine-tuning transformer models with FSDP support.
This class provides a high-level interface for training transformer models
with configurable distributed training, optimization, and checkpointing.
It supports various training configurations including sequence parallelism,
tensor parallelism, and data parallelism.
Args:
load_from (str | Path | None): Path to Huggingface model or saved trainer checkpoint.
model_cfg (TransformerConfig | InternS1BaseConfig): Configuration for the transformer model architecture.
optim_cfg (OptimConfig): Configuration for the optimizer.
fsdp_cfg (FSDPConfig | None): Configuration for Fully Sharded Data Parallel (FSDP).
dataset_cfg (DatasetConfigList): Configuration for training datasets.
dataloader_cfg (DataloaderConfig): Configuration for the data loader.
loss_cfg (CELossConfig | None): Config for the cross-entropy loss function.
lr_cfg (LRConfig): Configuration for the learning rate scheduler.
tokenizer_path (str | Path | None): Path to the tokenizer.
global_batch_size (int | None): Global batch size for training.
work_dir (Path | str | None): Directory for saving experiment outputs.
log_dir (Path | str | None): Directory for log files.
sp_size (int): Sequence parallel size.
total_step (int | None): Total training steps.
total_epoch (int | None): Number of training epochs.
resume_cfg (ResumeConfig | None): Configuration for resuming training.
auto_resume (bool): Whether to automatically resume training. Defaults to False.
load_checkpoint_cfg (LoadCheckpointConfig): Configuration for loading checkpoints.
strict_load (bool): Whether to strictly load model weights.
checkpoint_interval (int | None): Interval for saving checkpoints.
checkpoint_maxkeep (int | None): Maximum number of checkpoints to keep.
patch_for_dcp_finish (bool): If True, skip returning finish_checkpoint result.
hf_interval (int | None): Interval for saving Huggingface format checkpoints.
hf_max_keep (int | None): Maximum number of Huggingface checkpoints to keep.
profile_step (list[int] | int | None): Step to perform profiling.
profile_time (bool): Whether to profile training time.
profile_memory (bool): Whether to profile memory usage.
intra_layer_micro_batch (int): Intra-layer micro batch size.
seed (int): Random seed for reproducibility.
debug (bool): Whether to enable debug mode.
backend (str): Backend for distributed training.
"""
_config: TrainerConfig | None
_META_PATH = ".xtuner"
_PROFILE_TIME_PATH = "profiling_time"
_PROFILE_MEMORY_PATH = "profiling_memory"
_EXP_TRACKING_PATH = "exp_tracking"
_CHECKPOINT_DIR = "checkpoints"
_SAVE_WEIGHTS_DIR = "weights"
_SAVE_DATALOADER_DIR = "dataloader"
_SAVE_SCHEDULER_DIR = "lr_scheduler"
_SAVE_TRAIN_STATE_PATH = "train_state.json"
_DEFAULT_LOG_DIR = "logs"
def __init__(
    self,
    *,
    load_from: str | Path | None = None,  # Huggingface model path or saved trainer_path
    model_cfg: XTunerBaseModelConfig,
    optim_cfg: OptimConfig,
    fsdp_cfg: FSDPConfig | None = FSDPConfig(),
    dataset_cfg: DatasetConfigList | None = None,  # TODO: Removed in version 1.1.0
    dataloader_cfg: DataloaderConfig,
    loss_cfg: CELossConfig | None = CELossConfig(),
    lr_cfg: LRConfig,
    tokenizer_path: str | Path | None = None,
    global_batch_size: int | None,
    work_dir: Path | str | None = None,
    log_dir: Path | str | None = None,
    sp_size: int = 1,
    total_step: int | None = None,
    total_epoch: int | None = None,
    resume_cfg: ResumeConfig | None = ResumeConfig(),  # TODO: Removed in version 1.1.0
    auto_resume: bool = False,
    load_checkpoint_cfg: LoadCheckpointConfig = LoadCheckpointConfig(),
    strict_load: bool = True,
    checkpoint_interval: int | None = -1,
    checkpoint_maxkeep: int | None = -1,
    async_hf_export: bool = False,
    skip_checkpoint_validation: bool = False,  # Suggest enabled if fsdp_size is larger than 512
    patch_for_dcp_finish: bool = False,
    async_checkpoint: bool = False,
    snapshot_interval: int | None = None,
    check_health_interval: int | None = None,
    hf_interval: int | None = None,
    hf_max_keep: int | None = None,
    exp_tracker: Literal["tensorboard", "jsonl"] = "jsonl",
    profile_step: list[int] | int | None = None,
    profile_time: bool = True,
    profile_memory: bool = False,
    intra_layer_micro_batch: int = 1,
    seed: int = 42,
    debug: bool = False,
    backend: str | None = None,
    debug_skip_save: bool = False,
    prober_list: list[str] = [],
    do_clip: bool = True,
    grad_norm_dtype: torch.dtype = torch.float32,
    trainer_cfg: TrainerConfig | None = None,
    hooks_config: HooksConfig = HooksConfig(),
    internal_metrics_cfg: InternalMetricsConfig | None = None,
):
    """Set up distributed state, logging, data, engine and scheduler.

    The order of the steps matters: the process group is initialised first,
    the experiment meta and logger next, and only then the tokenizer, data
    mesh, dataloader, engine and LR scheduler. A checkpoint is loaded at the
    end when `load_checkpoint_cfg` resolves to a checkpoint path.

    Raises:
        AssertionError: If neither or both of `total_epoch` / `total_step`
            are given, or if HF export options are set although the model
            cannot be saved in HF format.
    """
    self._do_clip = do_clip
    self._grad_norm_dtype = grad_norm_dtype
    self._dataloader_config = dataloader_cfg
    self._total_step = total_step
    self._total_epoch = total_epoch
    self._cur_epoch = 1
    self._cur_step = 0
    self._trainer_cfg = trainer_cfg
    # Keep the `config` property usable for directly constructed trainers;
    # `from_config` passes the same object as `trainer_cfg`.
    self._config = trainer_cfg
    self._micro_batch_size: int | None = None

    if skip_checkpoint_validation:
        patch_default_save_plan()
    if patch_for_dcp_finish:
        # NOTE(review): nesting reconstructed from a whitespace-flattened
        # source — confirm that only the first patch is version-gated.
        if torch.__version__.startswith("2.7."):
            patch_dcp_save_state_dict()
        patch_dcp_save_with_cache_storage()

    if isinstance(profile_step, int):
        profile_step = [profile_step]
    self._profile_step = profile_step
    self._profile_time = profile_time
    self._profile_memory = profile_memory

    self._load_from = Path(load_from) if isinstance(load_from, str) else load_from
    # `is_hf_model_path` is unpacked as (is_hf, error). Written as a single
    # conditional expression followed by `, None`, the comma would bind last
    # and `is_hf_path` would receive the whole (always truthy) tuple.
    if load_from is not None:
        is_hf_path, error_info = is_hf_model_path(load_from)
    else:
        is_hf_path, error_info = False, None
    self._load_from_hf = is_hf_path

    self._checkpoint_interval = checkpoint_interval
    self._checkpoint_maxkeep = checkpoint_maxkeep
    self._async_hf_export = async_hf_export
    self._pending_async_hf_handle: AsyncHFSaveHandle | None = None
    self._pending_async_hf_step: int | None = None
    self._pending_async_hf_epoch: int | None = None
    self._async_checkpoint = async_checkpoint
    self._pending_checkpoint: Future | None = None
    self._snapshot_interval = snapshot_interval
    self._check_health_interval = check_health_interval
    self._hf_max_keep = hf_max_keep
    self._hf_interval = hf_interval

    if fsdp_cfg is None:
        fsdp_cfg = FSDPConfig()
    self._fsdp_config = fsdp_cfg
    self._optim_config = optim_cfg
    self._sp_size = sp_size
    self._debug = debug
    self._seed = seed

    # Naming convention for logged variables:
    #   spatially, `local_` means the current rank; no prefix means reduced across ranks
    #   temporally, `step_` means the current step and `total_` means accumulated
    # `self._local_total_consumed_tokens` is the current rank's sum accumulated so far;
    # after a resume it only counts the steps since the resume point.
    self._local_total_consumed_tokens = 0
    self._init_total_tokens = 0
    self._train_time = 0
    self._train_time_offset = 0

    self._init_dist(backend)

    if resume_cfg is None:
        resume_cfg = ResumeConfig()
    self._work_dir = self._resolve_work_dir(work_dir)
    self._auto_resume = auto_resume
    self._auto_resume = self._resolve_deprecated_resume_cfg(
        resume_cfg, self._auto_resume
    )  # TODO: Removed in version 1.1.0
    self._meta = self._init_xtuner_meta(self.work_dir, auto_resume=self._auto_resume)
    self._log_dir = self._resolve_log_dir(log_dir)  # depends on exp_dir(work_dir and meta)
    self.logger, log_dir = self._init_logger(self._log_dir)  # depends on log_dir and init_dist(get_rank)
    # After init logger
    log_rank0.warning("`resume_cfg` is deprecated, please use `auto_resume` and `load_checkpoint_cfg` instead")

    self._try_bind_numa()
    self._set_deterministic()
    self._set_random_seed(seed)
    self._setup_env()

    if tokenizer_path is not None:
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
    else:
        self.tokenizer = UTF8ByteTokenizer()
        log_rank0.info(f"Using toy tokenizer: {self.tokenizer}!")

    self._load_checkpoint_cfg = self._resolve_load_checkpoint_cfg(self._auto_resume, load_checkpoint_cfg)
    self._exp_tracker = self._init_tracker(
        exp_tracker, self._log_dir / f"{self._EXP_TRACKING_PATH}/rank{self.rank}"
    )
    self.data_mesh = self._init_data_mesh(
        fsdp_cfg.tp_size,
        sp_size,
    )
    self.sp_mesh = self.data_mesh["sp"]

    if global_batch_size is None:
        global_batch_size = self.data_mesh["dp"].size() * intra_layer_micro_batch
    self._global_batch_size = global_batch_size

    self._resolve_model_loss_cfg(model_cfg, loss_cfg)
    if loss_cfg is None:
        loss_cfg = CELossConfig()
    self._resolve_config_conflicts(self.tokenizer, model_cfg, dataloader_cfg, fsdp_cfg)

    if dataset_cfg is not None:  # TODO: Removed in version 1.1.0
        log_rank0.warning("`dataset_cfg` is deprecated, please use `dataloader_cfg.dataset_config_list` instead")
        # For backward compatibility, reserve the dataset_cfg interface, remove it later
        if dataloader_cfg.dataset_config_list is not None:
            log_rank0.warning("Outside dataset_cfg will override inner dataset_config_list")
        dataloader_cfg.dataset_config_list = dataset_cfg

    self._dataloader = dataloader_cfg.build(
        tokenizer=self.tokenizer,
        dp_mesh=self.data_mesh["dp"],
        global_batch_size=self.global_batch_size,
        micro_batch_size=self.micro_batch_size,
        seed=seed,
        total_step=total_step,
    )
    # streaming dataloader may override `total_step`, so we may move this check after `build_dataloader` later.
    assert total_epoch is not None or total_step is not None, "`total_epoch` or `total_step` should be set"
    assert total_epoch is None or total_step is None, (
        f"`total_epoch`: {total_epoch}, `total_step`: {total_step} should not be set at the same time"
    )

    if isinstance(load_from, str):
        load_from = Path(load_from)

    # HF export needs either an HF config on the model or an HF source path.
    self._can_save_hf = model_cfg.hf_config is not None or self._load_from_hf
    if not self._can_save_hf:
        assert_info = (
            f"`hf_interval`: {hf_interval}, `hf_max_keep`: {hf_max_keep} and "
            f"`async_hf_export`: {async_hf_export} "
            f"should be None when `load_from` is not a Huggingface model path, "
        )
        if is_hf_path is False and error_info is not None:
            assert_info += f", HF path load error Info: {error_info}"
        assert hf_interval is None and hf_max_keep is None and async_hf_export is False, assert_info

    self._engine = self.build_engine(
        model_path=load_from,
        model_config=model_cfg,
        optim_config=optim_cfg,
        fsdp_config=fsdp_cfg,
        load_checkpoint_path=self._load_checkpoint_cfg.checkpoint_path,
        strict=strict_load,
        intra_layer_micro_batch=intra_layer_micro_batch,
    )
    self._lr_cfg = lr_cfg
    self._lr_scheduler = self.build_lr_scheduler(lr_cfg, self.total_step)
    self.loss_cfg = loss_cfg

    if debug:
        self._register_debug_hook()

    # Default: export an HF checkpoint once, at the final step.
    if self._can_save_hf and self._hf_interval is None:
        self._hf_interval = self.total_step

    if debug_skip_save:
        self._hf_interval = None
        self._checkpoint_interval = None
        self._snapshot_interval = None

    if self._load_checkpoint_cfg.checkpoint_path is not None:
        self._load_checkpoint()

    self.hooks_config = self._setup_hooks(hooks_config=hooks_config)
    setup_prober_list(self.exp_dir, self._profile_step, self._engine.model, prober_list)
    self._metrics_recorder = self._maybe_init_model_metrics_recorder(internal_metrics_cfg)
@classmethod
def from_config(cls, config: TrainerConfig) -> Self:
    """Create a Trainer instance from a TrainerConfig.

    Every config field is forwarded to ``__init__`` (``dist_backend`` is
    passed as ``backend``); the config is then stored on the instance and the
    training configuration is printed.

    Args:
        config (TrainerConfig): TrainerConfig instance containing all configuration parameters.

    Returns:
        Self: Trainer instance initialized with the provided config.
    """
    trainer = cls(
        load_from=config.load_from,
        model_cfg=config.model_cfg,
        optim_cfg=config.optim_cfg,
        fsdp_cfg=config.fsdp_cfg,
        dataset_cfg=config.dataset_cfg,
        dataloader_cfg=config.dataloader_cfg,
        loss_cfg=config.loss_cfg,
        lr_cfg=config.lr_cfg,
        tokenizer_path=config.tokenizer_path,
        global_batch_size=config.global_batch_size,
        work_dir=config.work_dir,
        log_dir=config.log_dir,
        sp_size=config.sp_size,
        total_step=config.total_step,
        total_epoch=config.total_epoch,
        resume_cfg=config.resume_cfg,
        auto_resume=config.auto_resume,
        load_checkpoint_cfg=config.load_checkpoint_cfg,
        strict_load=config.strict_load,
        checkpoint_interval=config.checkpoint_interval,
        checkpoint_maxkeep=config.checkpoint_maxkeep,
        async_hf_export=config.async_hf_export,
        skip_checkpoint_validation=config.skip_checkpoint_validation,
        patch_for_dcp_finish=config.patch_for_dcp_finish,
        async_checkpoint=config.async_checkpoint,
        snapshot_interval=config.snapshot_interval,
        check_health_interval=config.check_health_interval,
        hf_interval=config.hf_interval,
        hf_max_keep=config.hf_max_keep,
        exp_tracker=config.exp_tracker,
        profile_step=config.profile_step,
        profile_time=config.profile_time,
        profile_memory=config.profile_memory,
        intra_layer_micro_batch=config.intra_layer_micro_batch,
        seed=config.seed,
        backend=config.dist_backend,
        debug=config.debug,
        debug_skip_save=config.debug_skip_save,
        prober_list=config.prober_list,
        do_clip=config.do_clip,
        grad_norm_dtype=config.grad_norm_dtype,
        hooks_config=config.hooks_config,
        trainer_cfg=config,
        internal_metrics_cfg=config.internal_metrics_cfg,
    )
    trainer._config = config
    trainer._print_training_config()
    return trainer
def fit(self):
    """Run the training loop.

    This method executes the main training loop, iterating through the dataset and performing training steps. It
    handles data loading, forward pass, backward pass, optimization, logging, and checkpointing.

    After the loop it waits for pending asynchronous HF exports and
    checkpoints, closes the experiment tracker (and the metrics recorder when
    present) and synchronises all ranks with a barrier.
    """
    train_begin = time.time()
    time_before_get_data = time.time()
    for data_batch in self._data_iter():
        time_before_train_step = time.time()
        ProberList.set_step(self._cur_step + 1)
        DEVICE_MODULE.reset_peak_memory_stats()

        # NOTE(review): the extent of this `with` block was reconstructed from
        # a whitespace-flattened source — confirm which statements are profiled.
        with self._maybe_profiling():
            engine_input = self._prepare_model_input(data_batch)
            train_step_info = self._engine.train_step(engine_input)

            # Hooks run before the step counter is advanced.
            hooks = self.hooks_config.get_hooks(HookStage.AFTER_TRAIN_STEP)
            for hook in hooks:
                hook(
                    train_step_info=train_step_info,
                    step=self.cur_step,
                    epoch=self._cur_epoch,
                    total_step=self.total_step,
                    total_epoch=self.total_epoch,
                )

            grad_norm = self._engine.clip_grad_norm(do_clip=self._do_clip, dtype=self._grad_norm_dtype)
            self._engine.step_optimizer(grad_norm)

        time_after_train_step = time.time()
        ProberList.after_step()

        data_time = time_before_train_step - time_before_get_data
        step_time = time_after_train_step - time_before_train_step

        internal_metrics = self._maybe_pop_model_internal_metrics(engine_input)

        self._cur_step += 1
        step_tokens = train_step_info["step_consumed_tokens"]
        self._local_total_consumed_tokens += step_tokens
        self._train_time = time_after_train_step - train_begin

        # Compute training metrics
        training_metrics = self._compute_performance_metrics(
            local_step_consumed_tokens=step_tokens,
            local_step_consumed_img_tokens=train_step_info.get("step_consumed_img_tokens"),
            step_time=step_time,
        )

        # TODO: This log should be move before lr_scheduler.step, but for CI BC, keep it temporarily
        self._log_step(
            train_step_info=train_step_info,
            training_metrics=training_metrics,
            grad_norm=grad_norm.item(),
            data_time=data_time,
            step_time=step_time,
            internal_metrics=internal_metrics,
        )

        self._lr_scheduler.step()
        self._maybe_check_health()
        self._maybe_save_hf()

        # A snapshot is only attempted when no regular checkpoint was written.
        ckpt_saved = self._maybe_save(is_snapshot=False)
        if not ckpt_saved:
            _ = self._maybe_save(is_snapshot=True)

        time_before_get_data = time.time()

        if self.cur_step % 50 == 0:
            gc.collect()

    self._wait_for_pending_async_hf()
    # TODO: Should use flush rather than close
    self._wait_for_pending_checkpoint()
    self._engine.destroy_async_checkpoint_pg()
    self._exp_tracker.close()
    if self._metrics_recorder:
        self._metrics_recorder.close()
    log_rank0.info(f"Training finished in {time.time() - train_begin:.2f} seconds")
    dist.barrier()
def _prepare_model_input(self, data_batch) -> list[ModelItem]:
    """Turn a raw data batch into the list of `ModelItem`s fed to the engine.

    Each item's ``seq_ctx`` is moved to `DEVICE` and, when the sequence
    parallel mesh has more than one member, split across it. The loss
    contexts for the whole batch are then built by the model in one call.
    """
    # 1. Extract seq_ctx
    seq_ctx_list: list[SequenceContext] = []
    for item in data_batch:
        ctx = item["seq_ctx"].to(DEVICE)
        if self.sp_mesh.size() > 1:
            ctx = ctx.split(sequence_parallel_mesh=self.sp_mesh)
        seq_ctx_list.append(ctx)

    # 2. Compute cu_seq_lens_list (for calibration)
    # 3. Call model's interface to build and calibrate all loss_ctx (done in one shot)
    loss_ctx_dict_list = self._engine.model.build_loss_ctx_batch(data_batch, sp_mesh=self.sp_mesh)

    # TODO: Consider moving data_batch deletion to the caller for better memory management.
    del data_batch

    # 4. Return ModelItem
    return [
        ModelItem(seq_ctx=ctx, loss_ctx=loss_ctx)
        for ctx, loss_ctx in zip(seq_ctx_list, loss_ctx_dict_list)
    ]
def _reduce_number_across_rank(self, rank_number: int | float) -> int:
    """Gather ``rank_number`` from every rank and return the sum."""
    per_rank_values = [None] * self.world_size
    dist.all_gather_object(per_rank_values, rank_number)
    return sum(per_rank_values)  # type: ignore[arg-type]
def _maybe_init_model_metrics_recorder(
    self,
    internal_metrics_cfg: InternalMetricsConfig | None,
) -> InternalMetricsRecorder | None:
    """Create the internal-metrics recorder when an interval is configured.

    Returns None when no config is given or its interval is unset/zero.
    """
    if not (internal_metrics_cfg and internal_metrics_cfg.internal_metrics_interval):
        return None

    self._internal_metrics_interval = internal_metrics_cfg.internal_metrics_interval
    assert self._internal_metrics_interval > 0, (
        "internal_metrics_interval must be greater than zero (or set to `None`)"
    )
    # otherwise the hook will be ignored for compiled modules
    torch._dynamo.config.skip_nnmodule_hook_guards = False
    return InternalMetricsRecorder(internal_metrics_cfg, self._engine.model)
def _maybe_pop_model_internal_metrics(self, data_batches: list[ModelItem]) -> InternalMetrics | None:
    """Pop internal metrics from the recorder when the current step is due.

    Metrics are collected when ``cur_step`` is a multiple of the configured
    interval or equals ``total_step``; otherwise None is returned.
    """
    if not self._metrics_recorder or self._internal_metrics_interval is None:
        return None

    on_interval = self.cur_step % self._internal_metrics_interval == 0
    if not on_interval and self.cur_step != self.total_step:
        return None

    with profile_time_and_memory("[Check Model Internal Metrics]"):
        popped = self._metrics_recorder.pop_metrics(data_batches)
    return popped
@property
def world_size(self) -> int:
    """Get the total number of processes in the distributed training group.

    Thin wrapper around `mmengine.dist.get_world_size`.

    Returns:
        int: Total number of processes.
    """
    return get_world_size()
@property
def rank(self) -> int:
    """Get the rank of the current process in the distributed training
    group.

    Thin wrapper around `mmengine.dist.get_rank`.

    Returns:
        int: Rank of the current process.
    """
    return get_rank()
@property
def micro_batch_size(self) -> int:
    """Return the per-data-parallel-rank batch size, computing it on first use.

    Returns:
        int: ``global_batch_size`` divided by the size of the ``dp`` mesh.

    Raises:
        ParallelConfigException: If the global batch size is not divisible
            by the data parallel size.
    """
    if self._micro_batch_size is not None:
        return self._micro_batch_size

    per_rank = self.global_batch_size / self.data_mesh["dp"].size()
    if not per_rank.is_integer():
        dp_size = self.data_mesh["dp"].size()
        raise ParallelConfigException(
            f"Global batch size {self.global_batch_size} must be divisible by "
            f"data parallel size {dp_size}. "
            "Please adjust the global batch size."
        )
    # Cache so that the division is done only once.
    self._micro_batch_size = int(per_rank)
    return self._micro_batch_size
@property
def global_batch_size(self) -> int:
    """Get the global batch size across all data parallel ranks.

    Returns:
        int: Global batch size as stored during initialisation.
    """
    return self._global_batch_size
@property
def total_step(self) -> int:
    """Return the total number of training steps.

    When no explicit step count was given, it is derived once as
    ``len(dataloader) * total_epoch`` and cached; this requires a dataloader
    that supports ``len()``.

    Returns:
        int: Total training steps.
    """
    if self._total_step is not None:
        return self._total_step

    assert isinstance(self._dataloader, Sized), (
        f"`total_epoch` should be set for a Mapped dataset, but got {self._dataloader.dataset}"
    )
    steps_per_epoch = len(self._dataloader)
    self._total_step = steps_per_epoch * cast(int, self._total_epoch)
    return self._total_step
@property
def total_epoch(self) -> int | None:
    """Return the configured number of epochs, or None when none was given."""
    return self._total_epoch
@property
def cur_step(self) -> int:
    """Get the current training step.

    Starts at 0 and is incremented once per iteration of `fit`.

    Returns:
        int: Current step number.
    """
    return self._cur_step
@property
def cur_epoch(self) -> int | None:
    """Get the current training epoch.

    Initialised to 1 in `__init__`.

    Returns:
        int | None: Current epoch number or None if not applicable.
    """
    return self._cur_epoch
@property
def config(self) -> TrainerConfig | None:
    """Return the `TrainerConfig` stored on the instance (assigned in `from_config`)."""
    return self._config
def _init_logger(self, log_dir: Path):
    """Reconfigure the shared logger with a per-rank file sink and a stderr sink.

    Args:
        log_dir (Path): Directory that receives the ``rank{N}.log`` file.

    Returns:
        tuple: ``(logger, log_dir)``.
    """
    # Logging system maybe need better design
    stderr_level = os.environ.get("XTUNER_LOG_LEVEL", "INFO").upper()

    logger = get_logger()
    # Drop the handlers that are already installed before adding ours.
    logger.remove()

    # The file sink always records at DEBUG level.
    rank_log_file = log_dir / f"rank{get_rank()}.log"
    logger.add(rank_log_file, format=log_format(), backtrace=True, catch=True, level="DEBUG")

    # The stderr sink honours XTUNER_LOG_LEVEL so debug output can be hidden.
    logger.add(sys.stderr, format=log_format(rank=get_rank()), level=stderr_level)
    return logger, log_dir
def _init_tracker(self, exp_tracker: Literal["tensorboard", "jsonl"], log_dir: Path):
    """Create the experiment writer of the requested type.

    Args:
        exp_tracker (Literal["tensorboard", "jsonl"]): Writer type passed to ``get_writer``.
        log_dir (Path): Directory passed to ``get_writer``.

    Returns:
        The object returned by ``get_writer``.
    """
    return get_writer(writer_type=exp_tracker, log_dir=log_dir)
def _init_data_mesh(
    self,
    tp_size: int,
    sp_size: int,
):
    """Build the ``(dp, sp, tp)`` device mesh used for data loading.

    The data parallel size is derived as ``world_size // (tp_size * sp_size)``.

    Args:
        tp_size (int): Tensor parallel size.
        sp_size (int): Sequence parallel size.

    Returns:
        The mesh returned by ``init_device_mesh`` with dimension names
        ``("dp", "sp", "tp")``.

    Raises:
        ParallelConfigException: If ``tp_size``, ``sp_size`` or their product
            does not divide the world size.
    """
    if self.world_size % tp_size != 0:
        raise ParallelConfigException(
            f"Found tp_size {tp_size}, world_size {self.world_size}."
            "tensor parallel size must be a divisor of world size."
        )

    if self.world_size % sp_size != 0:
        raise ParallelConfigException(
            f"Found sp_size {sp_size}, world_size {self.world_size}."
            "sequence parallel size must be a divisor of world size."
        )

    if self.world_size % (tp_size * sp_size) != 0:
        raise ParallelConfigException(
            f"Found tp_size {tp_size}, sp_size {sp_size}, world_size {self.world_size}."
            "`tp_size * sp_size` size must be a divisor of world size."
        )

    dp_size = self.world_size // (tp_size * sp_size)

    # TODO: fsdp_config could be None
    # Fix: the condition was inverted. The mesh lives on the accelerator unless
    # CPU offload is enabled, in which case it must be created on "cpu".
    device = "cpu" if self._fsdp_config.cpu_offload else str(DEVICE)

    data_mesh = init_device_mesh(
        device,
        (dp_size, sp_size, tp_size),
        mesh_dim_names=("dp", "sp", "tp"),
    )
    return data_mesh
def build_engine(
    self,
    model_path: Path | None,
    model_config: XTunerBaseModelConfig,
    optim_config: OptimConfig,
    fsdp_config: FSDPConfig,
    load_checkpoint_path: str | Path | None,
    intra_layer_micro_batch: int = 1,
    strict: bool = True,
):
    """Build the training engine and initialize its model weights.

    Weight initialization follows these rules:

    * ``model_path`` is given and either ``model_config.dcp_ignore_frozen_params``
      is truthy or no ``load_checkpoint_path`` is given: weights are loaded
      from ``model_path`` via ``engine.from_hf``.
    * otherwise, when no ``load_checkpoint_path`` is given: weights are
      initialized via ``engine.init_model_weights``.
    * otherwise nothing is loaded here.

    Args:
        model_path (Path | None): Hugging Face model path, or None.
        model_config (XTunerBaseModelConfig): Model configuration.
        optim_config (OptimConfig): Optimizer configuration.
        fsdp_config (FSDPConfig): FSDP configuration for distributed training.
        load_checkpoint_path (str | Path | None): Checkpoint path to resume from, or None.
            It is only inspected here to decide how the weights are initialized.
        intra_layer_micro_batch (int): Forwarded to ``TrainEngine``.
        strict (bool): Whether to strictly load model weights from ``model_path``.

    Returns:
        TrainEngine: Initialized training engine.
    """
    engine = TrainEngine(  # type: ignore
        optim_cfg=optim_config,
        fsdp_cfg=fsdp_config,
        model_cfg=model_config,
        intra_layer_micro_batch=intra_layer_micro_batch,
        async_hf_export=self._async_hf_export,
    )

    resuming = load_checkpoint_path is not None
    if model_path is not None and (model_config.dcp_ignore_frozen_params or not resuming):
        engine.from_hf(hf_path=model_path, strict=strict)
    elif not resuming:
        engine.init_model_weights()

    # Register the HF path on the model even when the weights were not loaded from it.
    if model_path is not None:
        engine.model.set_hf(model_path)

    if engine.model.compile_cfg is not None:
        log_rank0.info(f"The `compile_cfg` of model is {json.dumps(engine.model.compile_cfg, indent=4)}")

    return engine
def build_lr_scheduler(self, lr_cfg: LRConfig, scheduler_step: int) -> torch.optim.lr_scheduler.LRScheduler:
    """Build the learning rate scheduler: linear warmup followed by a main schedule.

    Args:
        lr_cfg (LRConfig): Configuration for the learning rate scheduler.
            ``warmup_ratio < 1`` is interpreted as a fraction of
            ``scheduler_step``; any other value is used as an absolute number
            of warmup steps. ``lr_type`` selects the schedule after warmup
            (``"linear"``, ``"cosine"`` or ``"constant"``).
        scheduler_step (int): Total number of steps covered by the schedule,
            warmup included.

    Returns:
        torch.optim.lr_scheduler.LRScheduler: A ``SequentialLR`` that switches
        from the warmup scheduler to the main scheduler after the warmup steps.

    Raises:
        ValueError: If ``lr_cfg.lr_type`` is not supported.
    """
    optimizer = self._engine.optimizer

    if lr_cfg.warmup_ratio < 1:
        warmup_steps = int(lr_cfg.warmup_ratio * scheduler_step)
    else:
        warmup_steps = int(lr_cfg.warmup_ratio)

    def warmup_fn(x):
        # Linear ramp from 0 to 1; the comparison also guards warmup_steps == 0.
        return x / warmup_steps if x < warmup_steps else 1

    warmup_scheduler = LambdaLR(optimizer, warmup_fn)

    decay_steps = scheduler_step - warmup_steps
    scheduler: torch.optim.lr_scheduler.LRScheduler
    if lr_cfg.lr_type == "linear":
        scheduler = LinearLR(
            optimizer,
            start_factor=1.0,
            end_factor=lr_cfg.lr_min / optimizer.defaults["lr"],
            total_iters=decay_steps,
        )
    elif lr_cfg.lr_type == "cosine":
        scheduler = CosineAnnealingLR(optimizer, T_max=decay_steps, eta_min=lr_cfg.lr_min)
    elif lr_cfg.lr_type == "constant":
        scheduler = LambdaLR(optimizer, lambda x: 1.0)
    else:
        raise ValueError(f"Unsupported lr type: {lr_cfg.lr_type}")

    lr_scheduler = SequentialLR(
        optimizer=optimizer,
        schedulers=[warmup_scheduler, scheduler],
        milestones=[warmup_steps],
    )
    return lr_scheduler
def _maybe_check_health(self):
    """Run a health check on steps that hit any configured interval.

    A check runs when the current step is a multiple of the health-check
    interval, the checkpoint interval, or the snapshot interval. Each
    interval is optional and evaluated independently. A missing health-check
    interval therefore no longer disables the checks tied to checkpoint and
    snapshot steps (the previous early return made those branches unreachable).

    Raises:
        RuntimeError: If ``check_health()`` reports a failure.
    """
    step = self.cur_step
    intervals = (
        self._check_health_interval,
        self._checkpoint_interval,
        self._snapshot_interval,
    )
    # NOTE(review): a checkpoint interval of -1 ("save only at the end" in `_maybe_save`)
    # satisfies `step % -1 == 0` on every step, so it triggers a check each step.
    # That behaviour is unchanged here; confirm whether it is intended.
    if not any(interval is not None and step % interval == 0 for interval in intervals):
        return

    if not check_health():
        raise RuntimeError("Health check failed, exit training")
    log_rank0.info(f"Health check passed at step {self.cur_step}")
def _wait_for_pending_checkpoint(self, timeout: int = 3000) -> None:
    """Block until the in-flight async checkpoint, if any, has finished.

    The pending future is cleared from ``self._pending_checkpoint`` before
    waiting, so it is not waited on again even if the wait fails.

    Args:
        timeout (int): Maximum number of seconds to wait.

    Raises:
        TimeoutError: If the checkpoint does not finish within ``timeout``
            seconds; the future is cancelled first.
    """
    pending = self._pending_checkpoint
    if pending is None:
        return

    self._pending_checkpoint = None
    try:
        pending.result(timeout=timeout)
    except TimeoutError:
        pending.cancel()
        raise TimeoutError(f"Async checkpoint timed out after {timeout}s")
def _maybe_save(self, is_snapshot: bool = False) -> bool:
    """Save a checkpoint or snapshot if the current step calls for one.

    A regular checkpoint is written when the step hits the checkpoint
    interval or is the final step; an interval of ``-1`` means "only at the
    final step". A snapshot is written only when the step hits the snapshot
    interval. Nothing is written when the relevant interval is ``None``.

    Args:
        is_snapshot (bool): Save a snapshot instead of a regular checkpoint.

    Returns:
        bool: ``True`` if something was saved, ``False`` otherwise.
    """
    interval = self._snapshot_interval if is_snapshot else self._checkpoint_interval
    if interval is None:
        return False

    if interval == -1:
        # only save at the end of training
        if self._cur_step != self.total_step:
            return False
    elif self.cur_step % interval != 0 and (is_snapshot or self._cur_step != self.total_step):
        # Snapshots are saved at the interval only; regular checkpoints are
        # saved at the interval or at the end of training.
        return False

    checkpoint_path = self._get_checkpoint_path(epoch=self._cur_epoch, step=self.cur_step, is_snapshot=is_snapshot)
    checkpoint_path.mkdir(parents=True, exist_ok=True)

    # Ensure at most one async checkpoint is in flight.
    self._wait_for_pending_checkpoint()

    meta_path = self.work_dir / self._META_PATH
    weights_path = checkpoint_path / self._SAVE_WEIGHTS_DIR
    dataloader_path = checkpoint_path / self._SAVE_DATALOADER_DIR
    scheduler_path = checkpoint_path / self._SAVE_SCHEDULER_DIR
    train_state_path = checkpoint_path / self._SAVE_TRAIN_STATE_PATH

    reduced_tokens = self._reduce_number_across_rank(self._local_total_consumed_tokens)
    total_consumed_tokens = reduced_tokens + self._init_total_tokens

    if self.cur_step % interval == 0:
        DEVICE_MODULE.empty_cache()

    # Save model and optimizer; only regular checkpoints may be saved asynchronously.
    future: Future | None = None
    if self._async_checkpoint and not is_snapshot:
        future = self._engine.async_save_dcp(weights_dir=weights_path)
    else:
        self._engine.save_dcp(weights_dir=weights_path)

    # Save dataloader
    self._save_dataloader(dataloader_path)

    DEVICE_MODULE.empty_cache()

    # Save scheduler
    if self.rank == 0:
        torch.save(self._lr_scheduler.state_dict(), scheduler_path)

    # Save trainer config
    if self._trainer_cfg is not None and self.rank == 0:
        # TODO: Maybe we need a better way to serialize and deserialize config, rather than using pickle
        with (checkpoint_path / "trainer_config.json").open("w") as f:
            f.write(self._trainer_cfg.model_dump_json(indent=2))
        with (checkpoint_path / "trainer_config.bin").open("wb") as f:
            pickle.dump(self._trainer_cfg, f)

    if future is not None:
        self._pending_checkpoint = future

    dist.barrier()

    # Save train state
    if self.rank == 0:
        train_state = {
            "cur_step": self.cur_step,
            "cur_epoch": self._cur_epoch,
            "total_consumed_tokens": total_consumed_tokens,
            "train_time_offset": self._train_time + self._train_time_offset,
        }
        with train_state_path.open("w") as f:
            f.write(json.dumps(train_state))

    # Update meta
    current_exp = self.meta.latest_exp
    ckp_list = current_exp.snap_checkpoint_list if is_snapshot else current_exp.checkpoint_list
    ckp_list.append(str(checkpoint_path))
    current_exp.cur_step = self.cur_step
    current_exp.cur_epoch = self._cur_epoch
    current_exp.consumed_tokens = int(total_consumed_tokens)
    current_exp.history[-1]["end"] = self.cur_step

    # Delete the oldest checkpoints and update meta's checkpoint_list; only one snapshot is kept.
    ckp_maxkeep = 1 if is_snapshot else self._checkpoint_maxkeep
    if ckp_maxkeep is not None and ckp_maxkeep > 0:
        while len(ckp_list) > ckp_maxkeep:
            deleted_ckp = ckp_list.pop(0)
            if self.rank == 0 and Path(deleted_ckp).exists():
                rmtree(deleted_ckp)

    # Save meta, must after deleting checkpoints to ensure the checkpoint_list is updated in the meta file
    if self.rank == 0:
        with meta_path.open("w") as f:
            f.write(self.meta.model_dump_json(indent=2))

    dist.barrier()

    stage = HookStage.AFTER_SAVE_SNAPSHOT if is_snapshot else HookStage.AFTER_SAVE_DCP
    for hook in self.hooks_config.get_hooks(stage):
        hook(
            checkpoint=checkpoint_path,
            step=self.cur_step,
            epoch=self._cur_epoch,
            total_step=self.total_step,
            total_epoch=self.total_epoch,
        )
    return True
def _save_dataloader(self, dataloader_path: Path | str):
    """Write the dataloader state to ``dataloader_path`` from rank 0.

    ``get_state_dict()`` is called on every rank; only rank 0 writes the file.

    Args:
        dataloader_path (Path | str): Destination passed to ``torch.save``.
    """
    state = self._dataloader.get_state_dict()
    if self.rank != 0:
        return
    torch.save(state, dataloader_path)
@property
def work_dir(self) -> Path:
    """Working directory of the trainer.

    Returns:
        Path: The stored ``_work_dir``.
    """
    directory = self._work_dir
    return directory
@property
def exp_dir(self) -> Path:
    """Directory of the latest experiment recorded in the metadata.

    Returns:
        Path: ``latest_exp.exp_dir`` of the stored meta, converted to a ``Path``.
    """
    latest_exp = self._meta.latest_exp
    return Path(latest_exp.exp_dir)
@property
def log_dir(self) -> Path:
    """Directory used for log files, as stored in ``_log_dir``."""
    directory = self._log_dir
    return directory
@property
def checkpoint_dir(self) -> Path:
    """Directory under the experiment directory that holds checkpoints.

    Returns:
        Path: ``exp_dir / _CHECKPOINT_DIR``.
    """
    root = self.exp_dir
    return root / self._CHECKPOINT_DIR
@property
def meta(self) -> XTunerMeta:
    """Experiment metadata of this trainer.

    Returns:
        XTunerMeta: The stored ``_meta``.
    """
    metadata = self._meta
    return metadata
def _data_iter(self):
    """Yield batches until ``_cur_step`` reaches ``total_step``.

    When the dataloader is exhausted, the epoch counter is incremented, the
    new epoch is announced through ``set_epoch`` and iteration restarts.
    This generator does not advance ``_cur_step`` itself.

    Yields:
        The next item produced by the dataloader.
    """
    exhausted = object()  # unique marker for "iterator ran dry"
    stream = iter(self._dataloader)
    while self._cur_step < self.total_step:
        batch = next(stream, exhausted)
        if batch is exhausted:
            # Roll over to the next epoch and restart the dataloader.
            self._cur_epoch += 1
            self._dataloader.set_epoch(self._cur_epoch)
            stream = iter(self._dataloader)
            batch = next(stream)
        yield batch
def _get_checkpoint_path(self, epoch: int, step: int, is_snapshot: bool = False) -> Path:
    """Return the directory of the checkpoint or snapshot for ``step``.

    Args:
        epoch (int): Currently unused, see the TODO below.
        step (int): Training step that names the directory.
        is_snapshot (bool): Use the ``snapshot-`` prefix instead of ``ckpt-``.

    Returns:
        Path: ``checkpoint_dir / "<prefix>step-<step>"``.
    """
    # TODO: the epoch may differ between ranks, so naming checkpoints by epoch
    # would be wrong; this is still to be resolved. Until then only the step is
    # used, instead of f"{prefix}epoch-{epoch}-step-{step}".
    kind = "snapshot-" if is_snapshot else "ckpt-"
    dir_name = f"{kind}step-{step}"
    return self.checkpoint_dir / dir_name
def _set_deterministic(self):
    """Enable deterministic algorithms when ``XTUNER_DETERMINISTIC`` is truthy."""
    if not XTUNER_DETERMINISTIC:
        return
    log_rank0.info("Setting deterministic algorithms")
    set_deterministic()
def _set_random_seed(self, seed: int):
    """Seed the random number generators through ``set_random_seed``.

    Args:
        seed (int): Seed value forwarded unchanged.
    """
    set_random_seed(seed)
def _try_bind_numa(self):
    """Best-effort binding of the current process to a NUMA node.

    Binding is skipped when the device is not cuda, when the device count is
    not larger than the NUMA node count or is not a multiple of it, or when
    the world size is smaller than the device count. Any exception is logged
    and swallowed so that this method never raises.
    """
    if str(DEVICE) != "cuda":
        log_rank0.info("Current device is not cuda, skip numa binding.")
        return

    try:
        import numa
        from numa import memory, schedule

        numa_node_num = numa.info.get_max_node() + 1
        devices_per_host = DEVICE_MODULE.device_count()

        # Bind only when the devices split evenly into more than one device per NUMA node.
        if devices_per_host <= numa_node_num or devices_per_host % numa_node_num != 0:
            return
        # Skip when the job runs fewer processes than one host has devices.
        if self.world_size < devices_per_host:
            return

        local_rank = self.rank % devices_per_host
        # Consecutive local ranks share a NUMA node.
        devices_per_numa = devices_per_host // numa_node_num
        numa_id = local_rank // devices_per_numa

        # Bind both scheduling and memory allocation to that node.
        schedule.run_on_nodes(numa_id)
        memory.set_membind_nodes(numa_id)
    except Exception:
        logger.info(f"Rank: {self.rank} failed to bind process to numa node.")
        return  # try_bind_numa should not raise exception
    else:
        logger.info(f"Rank: {self.rank} success bind process to numa node: {numa_id}")
def _init_dist(self, backend: str | None = None):
    """Initialize the process group, select the local device and warm up communication.

    Args:
        backend (str | None): Backend string for ``init_process_group``. When
            ``None`` it is chosen from the current accelerator type
            (``cuda`` or ``npu``).

    Raises:
        NotImplementedError: If ``backend`` is ``None`` and the accelerator
            type is neither ``cuda`` nor ``npu``.
    """
    if backend is None:
        default_backends = {
            "cuda": "cpu:gloo,cuda:nccl",
            "npu": "cpu:gloo,npu:hccl",
        }
        accelerator_type = torch.accelerator.current_accelerator().type
        if accelerator_type not in default_backends:
            raise NotImplementedError
        backend = default_backends[accelerator_type]

    if not dist.is_initialized():
        init_process_group(backend=backend)

    local_rank = int(os.environ["LOCAL_RANK"])
    torch.accelerator.set_device_index(local_rank)

    # In some cases, the datasets can perform massive numpy loading before the first communication.
    # After build dataset, massive numpy loading causing a lot of anonymous mmap allocation.
    # THP(transparent huge page) kernel thread would continuously scan and merge these anonymous mmap to huge page.
    # At the same time, if we perform communication for the first time, backend (e.g., NCCL) may register
    # and lock address that might have been changed by THP, which causes a crash. So we should warmup first.
    warmup_tensor = torch.ones(4, 4, device=torch.accelerator.current_accelerator())
    dist.all_reduce(warmup_tensor)
def _init_xtuner_meta(self, work_dir: Path, auto_resume: bool) -> XTunerMeta:
    """Load (or create) the meta file and register the run that is starting.

    With ``auto_resume`` and at least one recorded experiment, a new history
    entry is appended to the latest experiment. Otherwise a new experiment
    with a timestamped directory under ``work_dir`` is appended. The commit
    and the timestamp are broadcast from rank 0. The returned meta is only
    updated in memory here; this method writes the meta file only when it
    does not exist yet.

    Args:
        work_dir (Path): Working directory that holds the meta file.
        auto_resume (bool): Whether to continue the latest recorded experiment.

    Returns:
        XTunerMeta: The updated metadata.
    """
    # TODO: simplify with XTunerMeta.build() of dist version
    is_main = self.rank == 0

    if is_main and not work_dir.exists():
        work_dir.mkdir(parents=True, exist_ok=True)

    meta_path = work_dir / self._META_PATH
    if is_main and not meta_path.exists():
        empty_meta = XTunerMeta(exps=[])
        with open(meta_path, "w") as f:
            f.write(empty_meta.model_dump_json(indent=2))

    # Every rank reads the meta file only after rank 0 had the chance to create it.
    dist.barrier()
    meta = cast(XTunerMeta, XTunerMeta.model_validate(load(meta_path, file_format="json")))

    if auto_resume and meta.exps:
        latest_exp = meta.exps[-1]
        last_history = latest_exp.history[-1]

        # Continue from where the previous history entry ended, else from where it began.
        begin = cast(int, last_history.get("end") or last_history["begin"])
        exp_dir = Path(latest_exp.exp_dir)
        git_dir = exp_dir / f"git-info-begin-{begin}"
        staged_path = git_dir / "staged.diff"
        unstaged_path = git_dir / "unstaged.diff"

        # Git info is recorded only by rank 0, and only if this directory does not exist yet.
        if not git_dir.exists() and is_main:
            git_dir.mkdir(parents=True, exist_ok=True)
            _commit_tmp = [record_git_info(staged_path, unstaged_path)]
        else:
            _commit_tmp = [None]  # type: ignore[list-item]
        dist.broadcast_object_list(_commit_tmp, src=0)
        commit = cast(str, _commit_tmp[0])
        dist.barrier()

        git_info = GitInfo(
            commit=commit,
            staged=str(staged_path),
            unstaged=str(unstaged_path),
        )

        timestamp_list = [datetime.now().strftime("%Y%m%d%H%M%S")]
        dist.broadcast_object_list(timestamp_list, src=0)
        timestamp = timestamp_list[0]

        latest_exp.history.append(
            ExpHistory(
                begin=begin,
                timestamp=timestamp,
                git_info=git_info,
            )
        )
    else:
        # All ranks adopt rank 0's timestamp so they agree on the experiment directory.
        timestamp_list = [datetime.now().strftime("%Y%m%d%H%M%S")]
        dist.broadcast_object_list(timestamp_list, src=0)
        timestamp = timestamp_list[0]

        exp_dir = work_dir / timestamp
        git_dir = Path(f"{exp_dir}/git-info-begin-{0}")
        if not git_dir.exists() and is_main:
            git_dir.mkdir(parents=True, exist_ok=True)
        dist.barrier()

        staged_path = git_dir / "staged.diff"
        unstaged_path = git_dir / "unstaged.diff"

        if is_main:
            _commit_tmp = [record_git_info(staged_path, unstaged_path)]
        else:
            _commit_tmp = [None]  # type: ignore[list-item]
        dist.broadcast_object_list(_commit_tmp, src=0)
        commit = cast(str, _commit_tmp[0])

        git_info = GitInfo(
            commit=commit,
            staged=str(staged_path),
            unstaged=str(unstaged_path),
        )
        first_history = ExpHistory(
            begin=0,
            timestamp=timestamp,
            git_info=git_info,
        )
        meta.exps.append(
            ExpInfo(
                history=[first_history],
                exp_dir=str(exp_dir),
            )
        )
    return meta
@contextmanager
def _maybe_profiling(self):
    """Context manager that profiles the wrapped code on the configured steps.

    Time and/or memory profiling is entered only when the current step is
    listed in ``_profile_step``; otherwise the wrapped code runs without
    profiling. Output goes to a ``step-<N>`` directory under the
    corresponding profile path of the experiment directory.
    """
    profile_now = self._profile_step is not None and self._cur_step in self._profile_step
    step_name = f"step-{self._cur_step}"

    # An ExitStack with nothing entered is a no-op, which covers the unprofiled case.
    with contextlib.ExitStack() as stack:
        if profile_now and self._profile_time:
            time_dir = self.exp_dir / self._PROFILE_TIME_PATH / step_name
            stack.enter_context(profiling_time(time_dir))
        if profile_now and self._profile_memory:
            memory_dir = self.exp_dir / self._PROFILE_MEMORY_PATH / step_name
            stack.enter_context(profiling_memory(memory_dir))
        yield
def _compute_performance_metrics(
    self,
    local_step_consumed_tokens: int,
    local_step_consumed_img_tokens: int | None,
    step_time: float,
) -> PerformanceStatistics:
    """Compute token counters, throughput and an ETA for the current step.

    Args:
        local_step_consumed_tokens (int): Tokens consumed in current step on current rank.
        local_step_consumed_img_tokens (int | None): Image tokens consumed in current step on current rank.
        step_time (float): Time spent on current training step in seconds.

    Returns:
        PerformanceStatistics: The computed statistics.
    """
    world_size = self.world_size

    # Throughput of this step and of the whole run so far, per rank.
    tgs = local_step_consumed_tokens / step_time
    if self._train_time > 0:
        exp_tgs = self._local_total_consumed_tokens / self._train_time
    else:
        exp_tgs = 0.0

    # The global total is approximated by assuming every rank consumed as much as this one.
    approximate_total_consumed_tokens = self._init_total_tokens + self._local_total_consumed_tokens * world_size

    # TODO: approximate_total_consumed_tokens_per_rank could be incorrect if world_size changed.
    # So calculate `eta_seconds = step_time * remaining_steps` instead?
    approximate_total_consumed_tokens_per_rank = approximate_total_consumed_tokens / world_size

    # ETA: remaining steps times average tokens per step, divided by the current throughput.
    remaining_steps = self.total_step - self.cur_step
    avg_tokens_per_step = approximate_total_consumed_tokens_per_rank / self.cur_step
    remaining_tokens = remaining_steps * avg_tokens_per_step
    eta_seconds = remaining_tokens / max(tgs, 1)

    return PerformanceStatistics(
        local_step_consumed_tokens=local_step_consumed_tokens,
        local_step_consumed_img_tokens=local_step_consumed_img_tokens,
        local_total_consumed_tokens=self._local_total_consumed_tokens,
        approximate_total_consumed_tokens=approximate_total_consumed_tokens,
        tgs=tgs,
        exp_tgs=exp_tgs,
        eta_seconds=eta_seconds,
        eta_hms=str(timedelta(seconds=int(eta_seconds))),
        e2e_train_time=self._train_time + self._train_time_offset,
    )
def _log_step(
    self,
    train_step_info: TrainStepInfo,
    training_metrics: PerformanceStatistics,
    grad_norm: float,
    data_time: float,
    step_time: float,
    internal_metrics: InternalMetrics | None = None,
):
    """Write one line of step information to the logger and push scalars to the tracker.

    Peak memory statistics are reset at the end so that the next step
    reports its own peak.

    Args:
        train_step_info (TrainStepInfo): Info returned per engine train_step.
        training_metrics (PerformanceStatistics): Computed token and throughput statistics.
        grad_norm (float): Gradient norm value.
        data_time (float): Time spent loading data in seconds.
        step_time (float): Time spent on training step in seconds.
        internal_metrics (InternalMetrics | None): Internal metrics from the model.
    """
    # Work on a copy because entries are popped below.
    step_info = train_step_info.copy()
    lr = self._lr_scheduler.get_last_lr()[0]

    logs_info = step_info.pop("logs_info")  # type: ignore[misc]
    loss_logs_info = logs_info | {"local_loss": step_info.pop("total_loss")}  # type: ignore[misc]
    loss_log_str = ", ".join(f"{k}: {v:.8f}" for k, v in loss_logs_info.items())

    extra_info = step_info.pop("extra_info")  # type: ignore[misc]
    extra_info_str = ", ".join(f"{k}: {v:.4f}" for k, v in extra_info.get().items())

    # Whatever is left in the step info is logged as data information.
    data_info_str = ", ".join(f"{k}: {v:.8f}" for k, v in step_info.items())

    max_memory = DEVICE_MODULE.max_memory_allocated()  # type: ignore[attr-defined]
    reserved_memory = DEVICE_MODULE.max_memory_reserved()  # type: ignore[attr-defined]
    max_memory_gb = max_memory / (1024**3)
    reserved_memory_gb = reserved_memory / (1024**3)

    flattened_internal_metrics = flatten_internal_metrics_for_logs(internal_metrics) if internal_metrics else {}

    img_tokens = training_metrics["local_step_consumed_img_tokens"]
    img_tokens_str = f"img_tokens: {img_tokens} " if img_tokens is not None else ""

    self.logger.info(
        f"Epoch {self._cur_epoch} Step {self.cur_step}/{self.total_step} "
        f"data_time: {data_time:.4f} lr: {lr:.6e} time: {step_time:.4f} "
        f"text_tokens: {training_metrics['local_step_consumed_tokens']} {img_tokens_str}"
        f"approximate_total_consumed_tokens: {training_metrics['approximate_total_consumed_tokens']} "
        f"{loss_log_str} "
        f"{data_info_str} "
        f"{extra_info_str} "
        f"grad_norm: {grad_norm:.8f} "
        f"max_memory: {max_memory_gb:.2f} GB "
        f"reserved_memory: {reserved_memory_gb:.2f} GB "
        f"tgs: {training_metrics['tgs']:.1f} "
        f"exp_tgs: {training_metrics['exp_tgs']:.1f} "
        f"eta: {training_metrics['eta_hms']} "
    )

    log_scalars = {
        "lr": lr,
        "time/data_time": round(data_time, 4),
        "time/step_time": round(step_time, 4),
        "time/train_time": round(self._train_time, 4),
        "time/eta_seconds": round(training_metrics["eta_seconds"], 1),
        "runtime_info/text_tokens": training_metrics["local_step_consumed_tokens"],
        "runtime_info/approximate_total_consumed_tokens": training_metrics["approximate_total_consumed_tokens"],
        "runtime_info/tgs": training_metrics["tgs"],
        "runtime_info/exp_tgs": training_metrics["exp_tgs"],
        "runtime_info/efficient_attn_ratio": step_info["efficient_attn_ratio"],
        "runtime_info/img_efficient_attn_ratio": step_info["img_efficient_attn_ratio"],
        "memory/max_memory_GB": round(max_memory_gb, 3),
        "memory/reserved_memory_GB": round(reserved_memory_gb, 3),
        "grad_norm": grad_norm,
        **flattened_internal_metrics,
    }
    for loss_name, loss_value in loss_logs_info.items():
        log_scalars[f"loss/{loss_name}"] = loss_value
    self._exp_tracker.add_scalars(tag_scalar_dict=log_scalars, global_step=self.cur_step)

    DEVICE_MODULE.reset_peak_memory_stats()  # type: ignore[attr-defined]
def _maybe_save_hf(self):
    """Export the model in Hugging Face format if the current step calls for it.

    An export happens when the step hits ``_hf_interval`` or is the final
    step. With async export enabled, the previous pending export is finalized
    first and the new one is only started; it is finalized later. Otherwise
    the export is written and finalized right away.
    """
    interval = self._hf_interval
    if interval is None:
        return

    assert self._can_save_hf, "Model does not support saving in Huggingface format."

    on_interval = self.cur_step % interval == 0
    if not (on_interval or self.cur_step == self.total_step):
        return

    save_hf_path = self.exp_dir / f"hf-{self.cur_step}"

    if not self._async_hf_export:
        self._engine.save_hf(str(save_hf_path))
        self._finalize_hf_save(
            save_hf_path,
            step=self.cur_step,
            epoch=self._cur_epoch,
            delete_hf_dirs=True,
        )
        return

    # Async path: at most one export is in flight, so finish the previous one first.
    self._wait_for_pending_async_hf()
    self._pending_async_hf_handle = self._engine.async_save_hf(str(save_hf_path))
    self._pending_async_hf_step = self.cur_step
    self._pending_async_hf_epoch = self._cur_epoch
def _wait_for_pending_async_hf(self) -> None:
    """Wait for the pending async Hugging Face export, if any, and finalize it.

    The pending handle, step and epoch are cleared before waiting, so the
    same export is not finalized again.
    """
    handle = self._pending_async_hf_handle
    if handle is None:
        return

    step, epoch = self._pending_async_hf_step, self._pending_async_hf_epoch
    self._pending_async_hf_handle = None
    self._pending_async_hf_step = None
    self._pending_async_hf_epoch = None

    finalized_hf_path = self._engine.wait_async_hf(handle)
    assert finalized_hf_path is not None
    assert step is not None
    assert epoch is not None

    self._finalize_hf_save(
        finalized_hf_path,
        step=step,
        epoch=epoch,
        delete_hf_dirs=True,
    )
def _finalize_hf_save(self, finalized_hf_path: Path, step: int, epoch: int, delete_hf_dirs: bool) -> None:
    """Do the bookkeeping that follows a finished Hugging Face export.

    Records the export in the meta, prunes the list to ``_hf_max_keep``
    entries, saves the tokenizer next to the weights, repoints the
    ``hf-latest`` symlink, rewrites the meta file and runs the
    ``AFTER_SAVE_HF`` hooks. File system changes are made by rank 0 only.

    Args:
        finalized_hf_path (Path): Directory of the finished export.
        step (int): Training step the export belongs to.
        epoch (int): Training epoch the export belongs to.
        delete_hf_dirs (bool): Whether pruned export directories are removed from disk.
    """
    is_main = self.rank == 0
    latest_hf_link = self.exp_dir / "hf-latest"
    latest_exp = self.meta.latest_exp

    latest_exp.hf_checkpoint_list.append(str(finalized_hf_path))

    max_keep = self._hf_max_keep
    if max_keep is not None and len(latest_exp.hf_checkpoint_list) > max_keep:
        # Keep the newest `max_keep` entries, drop everything before them.
        expired = latest_exp.hf_checkpoint_list[:-max_keep]
        latest_exp.hf_checkpoint_list = latest_exp.hf_checkpoint_list[-max_keep:]
        for hf_dir in expired:
            if delete_hf_dirs and is_main and Path(hf_dir).exists():
                rmtree(hf_dir)

    if is_main:
        if isinstance(self.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
            self.tokenizer.save_pretrained(str(finalized_hf_path))

        # Point latest_hf_link at the export that was just finalized.
        latest_hf_link.unlink(missing_ok=True)
        latest_hf_link.symlink_to(finalized_hf_path.absolute(), target_is_directory=True)

    meta_path = self.work_dir / self._META_PATH
    if is_main:
        with meta_path.open("w") as f:
            f.write(self.meta.model_dump_json(indent=2))

    for hook in self.hooks_config.get_hooks(HookStage.AFTER_SAVE_HF):
        hook(
            checkpoint=finalized_hf_path,
            step=step,
            epoch=epoch,
            total_step=self.total_step,
            total_epoch=self.total_epoch,
        )
def _register_debug_hook(self):
    """Register a forward hook on every module of the model that warns about NaN outputs."""

    def _detect_nan(module: nn.Module, output):
        # Tensors are checked directly; tuples, lists and dict values are searched recursively.
        if isinstance(output, torch.Tensor):
            if output.isnan().any():
                logger.warning(f"Detect NaN in output of module {module.__class__.__name__}")
            return

        if isinstance(output, (tuple, list)):
            nested = output
        elif isinstance(output, dict):
            nested = output.values()
        else:
            return

        for item in nested:
            _detect_nan(module, item)

    def module_debug_forward_hook(module, input, output):
        """Forward hook that scans the module output for NaN values."""
        _detect_nan(module, output)

    for submodule in self._engine.model.modules():
        if isinstance(submodule, nn.Module):
            submodule.register_forward_hook(module_debug_forward_hook)
def _resolve_work_dir(self, work_dir: Path | str | None) -> Path:
    """Normalize ``work_dir`` to a ``Path`` and create it from rank 0.

    Args:
        work_dir (Path | str | None): Requested working directory; ``None``
            selects ``<cwd>/work_dir``.

    Returns:
        Path: The resolved working directory.
    """
    resolved = Path.cwd() / "work_dir" if work_dir is None else work_dir
    if isinstance(resolved, str):
        resolved = Path(resolved)

    if get_rank() == 0:
        resolved.mkdir(parents=True, exist_ok=True)
    return resolved
def _resolve_log_dir(self, log_dir: Path | str | None) -> Path:
if log_dir is None:
log_dir = self.exp_dir / self._DEFAULT_LOG_DIR
if isinstance(log_dir, str):
log_dir = Path(log_dir)
return log_dir
def _resolve_config_conflicts(
self,
tokenizer: PreTrainedTokenizer,
model_cfg: XTunerBaseModelConfig,
dataloader_cfg: DataloaderConfig,
fsdp_cfg: FSDPConfig,
):
if hasattr(tokenizer, "pad_token_id"):
pad_token_id = tokenizer.pad_token_id
else:
pad_token_id = tokenizer.eos_token_id
if not isinstance(pad_token_id, int):
log_rank0.warning(
f"Tokenizer pad_token_id is {pad_token_id}, which is not an integer. Setting pad_token_id to 0."
)
if isinstance(pad_token_id, list):
pad_token_id = pad_token_id[0]
assert isinstance(pad_token_id, int), f"pad_token_id should be an integer, but got {pad_token_id}"
# Model's pad_token_id only affects the embedding module which acts specially for pad token.
# Model's pad_token_id may be different from tokenizer's pad_token_id.
# Note: Qwen3 Model's pad_token_id is None, which is different from Qwen tokenizer's pad_token_id.
# if isinstance(model_cfg, BaseComposeConfig):
# if model_cfg.text_config.pad_token_id != pad_token_id:
# logger.warning(
# f"Model pad_token_id {model_cfg.text_config.pad_token_id} is different from tokenizer "
# f"pad_token_id {pad_token_id}. Using tokenizer pad_token_id {pad_token_id}."
# )
# model_cfg.text_config.pad_token_id = pad_token_id
# elif model_cfg.pad_token_id != pad_token_id:
# logger.warning(
# f"Model pad_token_id {model_cfg.pad_token_id} is different from tokenizer pad_token_id "
# f"{pad_token_id}. Using tokenizer pad_token_id {pad_token_id}."
# )
# model_cfg.pad_token_id = pad_token_id
if dataloader_cfg.pad_token_id is None:
dataloader_cfg.pad_token_id = pad_token_id
elif dataloader_cfg.pad_token_id != pad_token_id:
log_rank0.warning(
f"Dataloader pad_token_id {dataloader_cfg.pad_token_id} is different from tokenizer "
f"pad_token_id {pad_token_id}. Using tokenizer pad_token_id {pad_token_id}."
)
dataloader_cfg.pad_token_id = pad_token_id
if self._sp_size > 1:
if dataloader_cfg.pack_to_max_length is False:
log_rank0.warning(
"pack_to_max_length must be True when using sequence parallel. Setting pack_to_max_length to True."
)
dataloader_cfg.pack_to_max_length = True
# Resolve parallel config conlicts between model and fsdp configs
self._resolve_deprecate_compile_cfg(model_cfg=model_cfg, fsdp_cfg=fsdp_cfg) # TODO: Remove in version 1.1.0
match model_cfg, fsdp_cfg:
case (MoEConfig(ep_size=1), FSDPConfig(ep_size=1)):
...
case (MoEConfig(ep_size=1), _):
model_cfg.ep_size = fsdp_cfg.ep_size
log_rank0.warning(f"Found model ep_size 1, using fsdp ep_size {fsdp_cfg.ep_size}.")
case (MoEConfig(), FSDPConfig(ep_size=1)):
fsdp_cfg.ep_size = model_cfg.ep_size
log_rank0.warning(f"Found fsdp ep_size 1, using fsdp ep_size {fsdp_cfg.ep_size}.")
match dataloader_cfg, model_cfg:
case DataloaderConfig(pack_to_max_length=False), XTunerBaseModelConfig(compile_cfg=value) if (
value is not False and value != {}
):
raise RuntimeError(
"`model_cfg.compile_cfg` and `fsdp_cfg.torch_compile` must be `False` if "
"`dataloader_cfg.pack_to_max_length` is `False`., but got:\n"
f"dataloader_cfg.pack_to_max_length: {dataloader_cfg.pack_to_max_length}\n"
f"model_cfg.compile_cfg: {model_cfg.compile_cfg}\n"
f"fsdp_cfg.torch_compile: {fsdp_cfg.torch_compile}" # TODO: removed in version 1.1.0 (FSDPConfig.torch_compile is deprecated)
)
def _resolve_deprecated_resume_cfg(self, resume_cfg: ResumeConfig, auto_resume: bool) -> bool:
    """Merge the ``auto_resume`` flag of ``resume_cfg`` with the trainer's flag.

    Args:
        resume_cfg (ResumeConfig): Resume configuration whose ``auto_resume`` is consulted.
        auto_resume (bool): The trainer-level flag.

    Returns:
        bool: ``True`` if ``resume_cfg.auto_resume`` is truthy, otherwise ``auto_resume``.
    """
    return True if resume_cfg.auto_resume else auto_resume
def _resolve_model_loss_cfg(self, model_cfg: XTunerBaseModelConfig, loss_cfg: CELossConfig | None):
    """Backward compatibility: copy the Trainer's ``loss_cfg`` into the model config.

    When ``loss_cfg`` is given, it is written to ``model_cfg.text_config.lm_loss_cfg``
    if the model config has a ``text_config``, otherwise to ``model_cfg.lm_loss_cfg``,
    and a warning is logged. Nothing happens when ``loss_cfg`` is ``None``.

    Args:
        model_cfg (XTunerBaseModelConfig): Model configuration
        loss_cfg (CELossConfig | None): Loss configuration from Trainer
    """
    if loss_cfg is None:
        return

    target_cfg = model_cfg.text_config if hasattr(model_cfg, "text_config") else model_cfg
    target_cfg.lm_loss_cfg = loss_cfg

    log_rank0.warning(
        "Setting model_cfg.lm_loss_cfg from Trainer's loss_cfg for backward compatibility. "
        "In the future, please set lm_loss_cfg directly in model_cfg instead of Trainer."
    )
def _resolve_load_checkpoint_cfg(
    self, auto_resume: bool, load_checkpoint_cfg: LoadCheckpointConfig
) -> LoadCheckpointConfig:
    """Let auto-resume override the checkpoint loading configuration.

    Args:
        auto_resume (bool): Whether auto-resume is enabled.
        load_checkpoint_cfg (LoadCheckpointConfig): Configuration to adjust in place.

    Returns:
        LoadCheckpointConfig: The same ``load_checkpoint_cfg`` object.
    """
    # auto_resume takes priority: if a latest checkpoint exists, the auto-resume path is taken.
    # In that case the checkpoint path is overridden and the optimizer states, optimizer args,
    # dataset and scheduler are all loaded.
    latest_checkpoint = self.meta.latest_exp.latest_checkpoint
    if auto_resume and latest_checkpoint is not None:
        load_checkpoint_cfg.checkpoint_path = Path(latest_checkpoint)
        for flag in ("load_optimizer_states", "load_optimizer_args", "load_dataset", "load_scheduler"):
            setattr(load_checkpoint_cfg, flag, True)
    return load_checkpoint_cfg
def _load_checkpoint(self):
    """Restore training state from the checkpoint named in ``_load_checkpoint_cfg``.

    Model weights (and optionally optimizer states/args) are loaded through
    the engine, then the step and epoch counters are restored. The dataloader
    state and token/time counters are restored when ``load_dataset`` is set.
    The scheduler state is restored when ``load_scheduler`` is set; otherwise
    a new scheduler is built for the remaining steps.

    Raises:
        FileNotFoundError: If the checkpoint directory, its weights directory
            or the scheduler file (when requested) does not exist.
    """
    cfg: LoadCheckpointConfig = self._load_checkpoint_cfg

    resume_from = cfg.checkpoint_path
    if resume_from is None:
        log_rank0.info("No checkpoint to resume from.")
        return
    if isinstance(resume_from, str):
        resume_from = Path(resume_from)
    log_rank0.info(f"Resume from checkpoint: {resume_from}")

    if not resume_from.exists():
        raise FileNotFoundError(f"Checkpoint path {resume_from} does not exist.")

    weights_path = resume_from / self._SAVE_WEIGHTS_DIR
    if not weights_path.exists():
        raise FileNotFoundError(f"Checkpoint at {resume_from} has no '{self._SAVE_WEIGHTS_DIR}/' directory.")

    self._engine.load_dcp(
        weights_dir=weights_path,
        load_states=cfg.load_optimizer_states,
        load_args=cfg.load_optimizer_args,
    )

    with (resume_from / self._SAVE_TRAIN_STATE_PATH).open("r") as f:
        train_state = json.load(f)

    self._cur_step = train_state["cur_step"]
    self._cur_epoch = train_state["cur_epoch"]

    if cfg.load_dataset:
        self._train_time_offset = train_state["train_time_offset"]
        self._init_total_tokens = train_state.get("total_consumed_tokens", 0)  # default 0 for BC
        self._resume_dataloader(resume_from / self._SAVE_DATALOADER_DIR)

    if not cfg.load_scheduler:
        # Start a fresh schedule that only spans the steps that are still to be trained.
        assert self.total_step > self._cur_step
        remaining_steps = self.total_step - self._cur_step
        self._lr_scheduler = self.build_lr_scheduler(self._lr_cfg, remaining_steps)
        return

    scheduler_path = resume_from / self._SAVE_SCHEDULER_DIR
    if not scheduler_path.exists():
        raise FileNotFoundError(f"Scheduler path {scheduler_path} does not exist.")
    self._lr_scheduler.load_state_dict(torch.load(scheduler_path, map_location=DEVICE))
def _resume_dataloader(self, dataloader_path: Path):
    """Load the dataloader state saved at ``dataloader_path`` into the dataloader.

    Args:
        dataloader_path (Path): File previously written with ``torch.save``.

    Raises:
        FileNotFoundError: If ``dataloader_path`` does not exist.
    """
    if not dataloader_path.exists():
        raise FileNotFoundError(f"Dataloader path {dataloader_path} does not exist.")
    saved_state = torch.load(dataloader_path, map_location=DEVICE)
    self._dataloader.load_state_dict(saved_state)
def _setup_hooks(self, hooks_config: HooksConfig) -> HooksConfig:
    """Connect this trainer to every hook object that supports it.

    Hooks that are instances of ``TrainStepHook`` or ``CheckpointHook`` get
    ``connect_trainer(self)`` called; other hooks are left untouched.

    Args:
        hooks_config (HooksConfig): Hook configuration to walk through.

    Returns:
        HooksConfig: The same ``hooks_config`` object.
    """
    connectable = (TrainStepHook, CheckpointHook)
    for stage in HookStage:
        for hook in hooks_config.get_hooks(stage):
            if not isinstance(hook, connectable):
                continue
            hook.connect_trainer(self)
    return hooks_config
def _setup_env(self):
    """Apply process-level environment settings and log the XTuner env vars.

    * Disables Python's garbage collector unless ``XTUNER_GC_ENABLE`` is set to
      something other than ``"0"`` (unset counts as ``"0"``).
    * Sets ``TOKENIZERS_PARALLELISM`` to ``"true"`` in ``os.environ``.
    * Logs, via ``log_rank0``, the current value of each environment variable
      listed below (``None`` when unset).
    """
    gc_requested = os.getenv("XTUNER_GC_ENABLE", "0") != "0"
    if not gc_requested:
        gc.disable()

    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    # Variables reported in the startup banner, in display order.
    reported_vars = (
        "XTUNER_DETERMINISTIC",
        "XTUNER_GC_ENABLE",
        "XTUNER_FILE_OPEN_CONCURRENCY",
        "XTUNER_TOKENIZE_CHUNK_SIZE",
        "XTUNER_TOKENIZE_WORKERS",
        "XTUNER_ACTIVATION_OFFLOAD",
        "XTUNER_USE_FA3",
        "XTUNER_DISPATCHER_DEBUG",
        "XTUNER_ROUTER_DEBUG",
        "XTUNER_DECORD_VIDEO_THREADS",
        "XTUNER_USE_CUTLASS_GROUP_GEMM",
        "GROUPED_GEMM_USE_CUTLASS",
        "XTUNER_USE_NATIVE_RMSNORM",
        "XTUNER_SM_MARGIN",
    )

    header = "\n============XTuner Training Environment============\n"
    footer = "=================================================="
    entries = "".join(f"{name}: {os.getenv(name)}\n" for name in reported_vars)
    log_rank0.info(header + entries + footer)
def _print_training_config(self):
if self._config is not None and self.rank == 0:
config_str = self._config.model_dump_json(
indent=2,
# Printing `dataset_cfg` and `dataloader_cfg` would take up a huge amount of space and make
# the logs unreadable, so the trainer only prints the model-related configuration.
exclude={"dataset_cfg", "dataloader_cfg"},
serialize_as_any=True,
)
logger.info(f"Training config: {config_str}")
def _resolve_deprecate_compile_cfg(self, model_cfg: XTunerBaseModelConfig, fsdp_cfg: FSDPConfig):
    """Carry the deprecated ``FSDPConfig.torch_compile`` flag over to the model config.

    A deprecation warning is logged on every call. If ``fsdp_cfg.torch_compile``
    is falsy, ``model_cfg.compile_cfg`` is overwritten with ``False`` (mutating
    ``model_cfg`` in place) and a second warning is logged.

    NOTE(review): presumably the caller only invokes this when the deprecated
    flag was explicitly provided — confirm at the call site.

    Args:
        model_cfg (XTunerBaseModelConfig): Model config whose ``compile_cfg`` may be overwritten.
        fsdp_cfg (FSDPConfig): FSDP config carrying the deprecated ``torch_compile`` flag.
    """
    # Unconditional deprecation notice.
    log_rank0.warning(
        "FSDPConfig.torch_compile is deprecated, and will be removed in version 1.1.0. "
        "Please use XTunerBaseModelConfig.compile_cfg to control whether to use torch.compile for the model"
    )
    if not fsdp_cfg.torch_compile:
        # The deprecated flag being off takes precedence over whatever compile_cfg held.
        log_rank0.warning("FSDPConfig.torch_compile is set to False, setting model_cfg.compile_cfg to False.")
        model_cfg.compile_cfg = False