Commit ca26436

Merge branch 'dev' into fix_moe

2 parents 87ce96d + 6d80d9c
File tree: 3 files changed, +353 -38 lines changed

src/twinkle/model/megatron/megatron.py

Lines changed: 325 additions & 17 deletions
@@ -1,12 +1,16 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
 import inspect
 import json
+import logging
 import os
+import random
 import re
+from argparse import Namespace
 from dataclasses import dataclass, field
 from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Type, Union, Callable
 import asyncio
 import threading
+import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -794,13 +798,22 @@ def clip_grad_and_step(self, max_grad_norm: float=1.0, norm_type=2, **kwargs):
         self.lr_step(**kwargs)
 
     @remote_function(dispatch='all', sync=True)
-    def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, interval: int = 1, **kwargs):
+    def save(self, name: Optional[str] = None, output_dir: Optional[str] = None,
+             interval: int = 1, save_optimizer: bool = False, **kwargs):
         """Save model checkpoint.
 
+        Always saves HF-format model weights. When ``save_optimizer`` is True,
+        additionally saves optimizer / lr_scheduler / RNG state in mcore
+        distributed-checkpoint format so that training can be resumed later.
+
         Args:
-            output_dir: Output directory.
-            interval: Save each interval steps.
-            **kwargs: Additional arguments.
+            name: Checkpoint name. Defaults to ``'checkpoint-step-{cur_step}'``.
+            output_dir: Output directory. Defaults to ``'output'``.
+            interval: Save each *interval* steps.
+            save_optimizer: If True, save optimizer + lr_scheduler + RNG state
+                alongside the HF weights for checkpoint resumption.
+            **kwargs: Additional arguments forwarded to the underlying save
+                methods (e.g. ``adapter_name``).
         """
         adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
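
For reference, a minimal usage sketch of the new ``save`` signature described in the docstring above. The ``trainer_model`` variable and the concrete checkpoint name / output path are hypothetical, and the calls are shown as plain method calls without the ``remote_function`` dispatch machinery:

    # HF-format weights only (inference / deployment).
    ckpt_dir = trainer_model.save(name='checkpoint-step-100', output_dir='output')

    # HF weights plus mcore optimizer / lr_scheduler / RNG state,
    # so training can be resumed from this checkpoint later.
    ckpt_dir = trainer_model.save(
        name='checkpoint-step-100',
        output_dir='output',
        save_optimizer=True,
    )
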
@@ -812,38 +825,333 @@ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, int
         if output_dir is None:
             output_dir = 'output'
         checkpoint_dir = os.path.join(output_dir, name)
-        save_format = kwargs.pop('save_format', 'hf') # 'hf' or 'megatron'
 
-        if save_format == 'hf':
-            self._save_hf_format(checkpoint_dir, optimizer_config.adapter_name)
-        else:
-            self._save_megatron_format(checkpoint_dir, optimizer_config.adapter_name)
+        # Always save HF-format weights (for inference / deployment).
+        self._save_hf_format(checkpoint_dir, optimizer_config.adapter_name)
+
+        # Optionally save mcore optimizer state (for training resumption).
+        if save_optimizer:
+            self._save_mcore_optimizer(
+                checkpoint_dir,
+                optimizer_config=optimizer_config,
+                **kwargs,
+            )
 
         self._save_tokenizer(checkpoint_dir, adapter_name=adapter_name)
-
-        # Final synchronization to ensure all ranks complete save
+
+        # Final synchronization to ensure all ranks complete save.
         if dist.is_initialized():
             dist.barrier()
 
         return checkpoint_dir
 
-
     @remote_function(dispatch='all')
     def load(self, name: str, output_dir: Optional[str] = None, **kwargs):
-        if output_dir is None:
-            # load from hub
+        """Load model weights, and optionally optimizer / scheduler / RNG state.
+
+        Args:
+            name: Checkpoint name or HuggingFace Hub model id.
+            output_dir: Parent directory that contains the checkpoint folder.
+                If None **and** ``resume`` is False, downloads from Hub.
+            resume: If True, restore optimizer, lr_scheduler and RNG state
+                from the mcore sub-checkpoint for training resumption.
+            **kwargs: Additional arguments (``adapter_name``, ``no_load_optim``,
+                ``no_load_rng``, etc.).
+        """
+        resume = kwargs.pop('resume', False)
+        if output_dir is None and not resume:
+            # Load from hub
             token = kwargs.pop('token', None)
             checkpoint_dir = HubOperation.download_model(name, token=token)
         else:
+            if output_dir is None:
+                output_dir = 'output'
             checkpoint_dir = os.path.join(output_dir, name)
+
         adapter_name = kwargs.get('adapter_name', self._get_default_group())
-        bridge = self._bridge
-        for _model in self.strategy.unwrap_model(self.model):
-            bridge.load_weights(_model, checkpoint_dir, is_peft_format = (adapter_name != _default_adapter_name))
+
+        if resume:
+            self._load_mcore_optimizer(
+                checkpoint_dir,
+                adapter_name=adapter_name,
+                **kwargs,
+            )
+        else:
+            bridge = self._bridge
+            for _model in self.strategy.unwrap_model(self.model):
+                bridge.load_weights(
+                    _model, checkpoint_dir,
+                    is_peft_format=(adapter_name != _default_adapter_name),
+                )
 
         if dist.is_initialized():
             dist.barrier()
 
+    @staticmethod
+    def _get_rng_state() -> 'ShardedObject':
+        from megatron.core import parallel_state as mpu, tensor_parallel
+        from megatron.core.dist_checkpointing.mapping import ShardedObject
+
+        rng_state = {
+            'random_rng_state': random.getstate(),
+            'np_rng_state': np.random.get_state(),
+            'torch_rng_state': torch.get_rng_state(),
+            'cuda_rng_state': torch.cuda.get_rng_state(),
+            'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states(),
+        }
+        rng_state_list = [rng_state]
+
+        pp_rank = mpu.get_pipeline_model_parallel_rank()
+        pp_size = mpu.get_pipeline_model_parallel_world_size()
+        tp_rank = mpu.get_tensor_model_parallel_rank()
+        tp_size = mpu.get_tensor_model_parallel_world_size()
+
+        return ShardedObject(
+            'rng_state', rng_state_list,
+            (pp_size, tp_size), (pp_rank, tp_rank),
+            replica_id=mpu.get_data_parallel_rank(with_context_parallel=True),
+        )
+
+    @staticmethod
+    def _generate_state_dict(
+        model: list,
+        optimizer=None,
+        opt_param_scheduler=None,
+        rng_state=None,
+        iteration: Optional[int] = None,
+        model_sd_kwargs: Optional[dict] = None,
+        optim_sd_kwargs: Optional[dict] = None,
+        save_optim: bool = True,
+        save_rng: bool = True,
+    ) -> dict:
+        model_sd_kwargs = model_sd_kwargs or {}
+        optim_sd_kwargs = optim_sd_kwargs or {}
+
+        state_dict: dict = {
+            'checkpoint_version': 3.0,
+        }
+        if iteration is not None:
+            state_dict['iteration'] = iteration
+
+        # Model sharded state dict
+        for i, m in enumerate(model):
+            key = 'model' if len(model) == 1 else f'model{i}'
+            state_dict[key] = m.sharded_state_dict(**model_sd_kwargs)
+
+        # Optimizer + scheduler
+        if save_optim and optimizer is not None:
+            state_dict['optimizer'] = optimizer.sharded_state_dict(
+                state_dict, **optim_sd_kwargs,
+            )
+            if opt_param_scheduler is not None:
+                state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
+
+        # RNG
+        if save_rng and rng_state is not None:
+            state_dict['rng_state'] = rng_state
+
+        return state_dict
+
+    def _save_mcore_optimizer(
+        self,
+        checkpoint_dir: str,
+        optimizer_config: 'MegatronOptimizerGroup',
+        **kwargs,
+    ):
+        from megatron.core import dist_checkpointing, parallel_state as mpu
+        from megatron.core.dist_checkpointing.serialization import (
+            get_default_save_sharded_strategy,
+        )
+        from megatron.core.dist_checkpointing.strategies.fully_parallel import (
+            FullyParallelSaveStrategyWrapper,
+        )
+
+        iteration = optimizer_config.cur_step
+        iter_dir = os.path.join(checkpoint_dir, f'iter_{iteration:07d}')
+        os.makedirs(iter_dir, exist_ok=True)
+
+        sharded_sd_metadata = {
+            'distrib_optim_sharding_type': 'dp_reshardable',
+            'singleton_local_shards': False,
+            'chained_optim_avoid_prefix': True,
+        }
+
+        rng_state = self._get_rng_state()
+        model = self.model
+
+        state_dict = self._generate_state_dict(
+            model=model,
+            optimizer=optimizer_config.optimizer,
+            opt_param_scheduler=optimizer_config.lr_scheduler,
+            rng_state=rng_state,
+            iteration=iteration,
+            model_sd_kwargs={'metadata': sharded_sd_metadata},
+            optim_sd_kwargs={'metadata': sharded_sd_metadata},
+        )
+
+        save_strategy = get_default_save_sharded_strategy()
+        if mpu.get_data_parallel_world_size(with_context_parallel=True) > 1:
+            save_strategy = FullyParallelSaveStrategyWrapper(
+                save_strategy,
+                mpu.get_data_parallel_group(with_context_parallel=True),
+            )
+
+        dist_checkpointing.save(
+            state_dict, iter_dir, save_strategy,
+            async_sharded_save=False,
+            validate_access_integrity=True,
+            content_metadata=sharded_sd_metadata,
+        )
+
+        if dist.is_initialized():
+            dist.barrier()
+
+        # Write tracker file (rank 0 only).
+        rank = dist.get_rank() if dist.is_initialized() else 0
+        if rank == 0:
+            tracker_path = os.path.join(
+                checkpoint_dir, 'latest_checkpointed_iteration.txt',
+            )
+            with open(tracker_path, 'w') as f:
+                f.write(str(iteration))
+
+        logging.getLogger(__name__).info(
+            f'Saved mcore optimizer state at iteration {iteration} '
+            f'to {checkpoint_dir}'
+        )
+
+    def _load_mcore_optimizer(
+        self,
+        checkpoint_dir: str,
+        adapter_name: str = '',
+        **kwargs,
+    ):
+        from megatron.core import (
+            dist_checkpointing, parallel_state as mpu, tensor_parallel,
+        )
+        from megatron.core.dist_checkpointing.serialization import (
+            get_default_load_sharded_strategy,
+        )
+        from megatron.core.dist_checkpointing.strategies.fully_parallel import (
+            FullyParallelLoadStrategyWrapper,
+        )
+
+        no_load_optim = kwargs.pop('no_load_optim', False)
+        no_load_rng = kwargs.pop('no_load_rng', False)
+
+        optimizer_config = self.optimizer_group.get(
+            adapter_name or self._get_default_group(),
+        )
+
+        # Read iteration from tracker file.
+        tracker_path = os.path.join(
+            checkpoint_dir, 'latest_checkpointed_iteration.txt',
+        )
+        iteration = self._read_iteration(tracker_path)
+        if iteration == 0:
+            logging.getLogger(__name__).warning(
+                f'No checkpoint found in {checkpoint_dir}'
+            )
+            return
+
+        iter_dir = os.path.join(checkpoint_dir, f'iter_{iteration:07d}')
+
+        # Load common (non-sharded) state to inspect content metadata.
+        common_state = dist_checkpointing.load_common_state_dict(iter_dir)
+        sharded_sd_metadata = dist_checkpointing.load_content_metadata(
+            preloaded_state_dict=common_state,
+        )
+
+        # Build optimizer / scheduler references for the sharded state dict.
+        optimizer = optimizer_config.optimizer if not no_load_optim else None
+        opt_param_scheduler = (
+            optimizer_config.lr_scheduler if not no_load_optim else None
+        )
+        rng_state = self._get_rng_state() if not no_load_rng else None
+
+        optim_sd_kwargs = dict(metadata=sharded_sd_metadata, is_loading=True)
+        model_sd_kwargs = dict(metadata=sharded_sd_metadata)
+
+        sharded_state_dict = self._generate_state_dict(
+            model=self.model,
+            optimizer=optimizer,
+            opt_param_scheduler=opt_param_scheduler,
+            rng_state=rng_state,
+            iteration=iteration,
+            model_sd_kwargs=model_sd_kwargs,
+            optim_sd_kwargs=optim_sd_kwargs,
+        )
+
+        # Load using fully-parallel strategy for speed.
+        load_strategy = get_default_load_sharded_strategy(iter_dir)
+        if mpu.get_data_parallel_world_size(with_context_parallel=True) > 1:
+            load_strategy = FullyParallelLoadStrategyWrapper(
+                load_strategy,
+                mpu.get_data_parallel_group(with_context_parallel=True),
+            )
+        state_dict = dist_checkpointing.load(
+            sharded_state_dict, iter_dir, load_strategy,
+        )
+
+        # Restore model weights.
+        if len(self.model) == 1:
+            self.model[0].load_state_dict(state_dict['model'], strict=False)
+        else:
+            for i, m in enumerate(self.model):
+                key = f'model{i}'
+                if key in state_dict:
+                    m.load_state_dict(state_dict[key], strict=False)
+
+        # Restore optimizer + LR scheduler.
+        if not no_load_optim and optimizer is not None and 'optimizer' in state_dict:
+            optimizer.load_state_dict(state_dict['optimizer'])
+        if (
+            opt_param_scheduler is not None
+            and 'opt_param_scheduler' in state_dict
+        ):
+            opt_param_scheduler.load_state_dict(
+                state_dict['opt_param_scheduler'],
+            )
+
+        if not no_load_rng and 'rng_state' in state_dict:
+            rng = state_dict['rng_state']
+            rng = rng[0]
+            random.setstate(rng['random_rng_state'])
+            np.random.set_state(rng['np_rng_state'])
+            torch.set_rng_state(rng['torch_rng_state'])
+            torch.cuda.set_rng_state(rng['cuda_rng_state'])
+            tensor_parallel.get_cuda_rng_tracker().set_states(
+                rng['rng_tracker_states'],
+            )
+
+        # Restore iteration counter.
+        if optimizer_config is not None and 'iteration' in state_dict:
+            optimizer_config.cur_step = state_dict['iteration']
+
+        if dist.is_initialized():
+            dist.barrier()
+
+        logging.getLogger(__name__).info(
+            f'Resumed from mcore checkpoint at iteration {iteration} '
+            f'from {checkpoint_dir}'
+        )
+
+    @staticmethod
+    def _read_iteration(tracker_path: str) -> int:
+        if not os.path.exists(tracker_path):
+            return 0
+        with open(tracker_path, 'r') as f:
+            iteration = int(f.read().strip())
+        if torch.distributed.is_initialized():
+            iters_cuda = torch.tensor(
+                [iteration], dtype=torch.long, device='cuda',
+            )
+            torch.distributed.all_reduce(
+                iters_cuda, op=torch.distributed.ReduceOp.MAX,
+            )
+            iteration = iters_cuda[0].item()
+        return iteration
+
     def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter = None):
         """Save in HuggingFace format using bridge adapter.
 
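A checkpoint written with ``save_optimizer=True`` keeps the HF-format weights at the checkpoint root and the mcore training state in an ``iter_{iteration:07d}`` subdirectory, next to a ``latest_checkpointed_iteration.txt`` tracker file. A minimal resumption sketch, under the same assumptions as the save example above (hypothetical ``trainer_model`` object, plain method calls):

    # Resume training: restores weights, optimizer, lr_scheduler,
    # RNG state and the iteration counter from the mcore sub-checkpoint.
    trainer_model.load('checkpoint-step-100', output_dir='output', resume=True)

    # Resume but skip optimizer and RNG restoration.
    trainer_model.load('checkpoint-step-100', output_dir='output',
                       resume=True, no_load_optim=True, no_load_rng=True)

    # Plain weight loading through the bridge (no training state restored).
    trainer_model.load('checkpoint-step-100', output_dir='output')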