 # Copyright (c) ModelScope Contributors. All rights reserved.
 import inspect
 import json
+import logging
 import os
+import random
 import re
+from argparse import Namespace
 from dataclasses import dataclass, field
 from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Type, Union, Callable
 import asyncio
 import threading
+import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -794,13 +798,22 @@ def clip_grad_and_step(self, max_grad_norm: float=1.0, norm_type=2, **kwargs):
         self.lr_step(**kwargs)
 
     @remote_function(dispatch='all', sync=True)
-    def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, interval: int = 1, **kwargs):
+    def save(self, name: Optional[str] = None, output_dir: Optional[str] = None,
+             interval: int = 1, save_optimizer: bool = False, **kwargs):
         """Save model checkpoint.
 
+        Always saves HF-format model weights. When ``save_optimizer`` is True,
+        additionally saves optimizer / lr_scheduler / RNG state in mcore
+        distributed-checkpoint format so that training can be resumed later.
+
         Args:
-            output_dir: Output directory.
-            interval: Save each interval steps.
-            **kwargs: Additional arguments.
+            name: Checkpoint name. Defaults to ``'checkpoint-step-{cur_step}'``.
+            output_dir: Output directory. Defaults to ``'output'``.
+            interval: Save a checkpoint every ``interval`` steps.
+            save_optimizer: If True, save optimizer + lr_scheduler + RNG state
+                alongside the HF weights for checkpoint resumption.
+            **kwargs: Additional arguments forwarded to the underlying save
+                methods (e.g. ``adapter_name``).
         """
         adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
@@ -812,38 +825,333 @@ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, int
         if output_dir is None:
             output_dir = 'output'
         checkpoint_dir = os.path.join(output_dir, name)
-        save_format = kwargs.pop('save_format', 'hf')  # 'hf' or 'megatron'
 
-        if save_format == 'hf':
-            self._save_hf_format(checkpoint_dir, optimizer_config.adapter_name)
-        else:
-            self._save_megatron_format(checkpoint_dir, optimizer_config.adapter_name)
+        # Always save HF-format weights (for inference / deployment).
+        self._save_hf_format(checkpoint_dir, optimizer_config.adapter_name)
+
+        # Optionally save mcore optimizer state (for training resumption).
+        if save_optimizer:
+            self._save_mcore_optimizer(
+                checkpoint_dir,
+                optimizer_config=optimizer_config,
+                **kwargs,
+            )
 
         self._save_tokenizer(checkpoint_dir, adapter_name=adapter_name)
-
-        # Final synchronization to ensure all ranks complete save
+
+        # Final synchronization to ensure all ranks complete save.
         if dist.is_initialized():
             dist.barrier()
 
         return checkpoint_dir
 
-
     @remote_function(dispatch='all')
     def load(self, name: str, output_dir: Optional[str] = None, **kwargs):
-        if output_dir is None:
-            # load from hub
+        """Load model weights and, optionally, optimizer / scheduler / RNG state.
+
+        Args:
+            name: Checkpoint name or HuggingFace Hub model id.
+            output_dir: Parent directory that contains the checkpoint folder.
+                If None **and** ``resume`` is False, downloads from the Hub.
+            resume: If True, restore optimizer, lr_scheduler and RNG state
+                from the mcore sub-checkpoint for training resumption.
+            **kwargs: Additional arguments (``adapter_name``, ``no_load_optim``,
+                ``no_load_rng``, etc.).
+        """
+        resume = kwargs.pop('resume', False)
+        if output_dir is None and not resume:
+            # Load from hub
             token = kwargs.pop('token', None)
             checkpoint_dir = HubOperation.download_model(name, token=token)
         else:
+            if output_dir is None:
+                output_dir = 'output'
             checkpoint_dir = os.path.join(output_dir, name)
+
         adapter_name = kwargs.get('adapter_name', self._get_default_group())
-        bridge = self._bridge
-        for _model in self.strategy.unwrap_model(self.model):
-            bridge.load_weights(_model, checkpoint_dir, is_peft_format=(adapter_name != _default_adapter_name))
+
+        if resume:
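+            # The mcore checkpoint already contains the model weights, so this
+            # branch restores model, optimizer, lr_scheduler and RNG in one pass.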
+            self._load_mcore_optimizer(
+                checkpoint_dir,
+                adapter_name=adapter_name,
+                **kwargs,
+            )
+        else:
+            bridge = self._bridge
+            for _model in self.strategy.unwrap_model(self.model):
+                bridge.load_weights(
+                    _model, checkpoint_dir,
+                    is_peft_format=(adapter_name != _default_adapter_name),
+                )
 
         if dist.is_initialized():
             dist.barrier()
 
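+    # Typical checkpoint / resume flow (illustrative sketch; ``engine`` stands in
+    # for an instance of this class):
+    #   ckpt_dir = engine.save(save_optimizer=True)        # HF weights + mcore optimizer state
+    #   engine.load('checkpoint-step-100', resume=True)    # restore weights, optimizer, RNG
+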
+    @staticmethod
+    def _get_rng_state() -> 'ShardedObject':
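+        """Collect this rank's RNG states as a ShardedObject for distributed checkpointing."""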
+        from megatron.core import parallel_state as mpu, tensor_parallel
+        from megatron.core.dist_checkpointing.mapping import ShardedObject
+
+        rng_state = {
+            'random_rng_state': random.getstate(),
+            'np_rng_state': np.random.get_state(),
+            'torch_rng_state': torch.get_rng_state(),
+            'cuda_rng_state': torch.cuda.get_rng_state(),
+            'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states(),
+        }
+        rng_state_list = [rng_state]
+
+        pp_rank = mpu.get_pipeline_model_parallel_rank()
+        pp_size = mpu.get_pipeline_model_parallel_world_size()
+        tp_rank = mpu.get_tensor_model_parallel_rank()
+        tp_size = mpu.get_tensor_model_parallel_world_size()
+
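+        # One RNG state per (pipeline, tensor) parallel rank; ranks that differ
+        # only in the data-parallel dimension hold replicas of the same shard.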
+        return ShardedObject(
+            'rng_state', rng_state_list,
+            (pp_size, tp_size), (pp_rank, tp_rank),
+            replica_id=mpu.get_data_parallel_rank(with_context_parallel=True),
+        )
+
+    @staticmethod
+    def _generate_state_dict(
+        model: list,
+        optimizer=None,
+        opt_param_scheduler=None,
+        rng_state=None,
+        iteration: Optional[int] = None,
+        model_sd_kwargs: Optional[dict] = None,
+        optim_sd_kwargs: Optional[dict] = None,
+        save_optim: bool = True,
+        save_rng: bool = True,
+    ) -> dict:
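+        """Assemble the sharded state dict that is saved to / loaded from an mcore checkpoint."""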
+        model_sd_kwargs = model_sd_kwargs or {}
+        optim_sd_kwargs = optim_sd_kwargs or {}
+
+        state_dict: dict = {
+            'checkpoint_version': 3.0,
+        }
+        if iteration is not None:
+            state_dict['iteration'] = iteration
+
+        # Model sharded state dict
+        for i, m in enumerate(model):
+            key = 'model' if len(model) == 1 else f'model{i}'
+            state_dict[key] = m.sharded_state_dict(**model_sd_kwargs)
+
+        # Optimizer + scheduler
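+        # ``state_dict`` already holds the model's sharded entries at this point;
+        # the optimizer uses them to map its own shards onto the model parameters.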
+        if save_optim and optimizer is not None:
+            state_dict['optimizer'] = optimizer.sharded_state_dict(
+                state_dict, **optim_sd_kwargs,
+            )
+            if opt_param_scheduler is not None:
+                state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
+
+        # RNG
+        if save_rng and rng_state is not None:
+            state_dict['rng_state'] = rng_state
+
+        return state_dict
+
+    def _save_mcore_optimizer(
+        self,
+        checkpoint_dir: str,
+        optimizer_config: 'MegatronOptimizerGroup',
+        **kwargs,
+    ):
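+        """Save optimizer / lr_scheduler / RNG state as an mcore distributed checkpoint."""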
+        from megatron.core import dist_checkpointing, parallel_state as mpu
+        from megatron.core.dist_checkpointing.serialization import (
+            get_default_save_sharded_strategy,
+        )
+        from megatron.core.dist_checkpointing.strategies.fully_parallel import (
+            FullyParallelSaveStrategyWrapper,
+        )
+
+        iteration = optimizer_config.cur_step
+        iter_dir = os.path.join(checkpoint_dir, f'iter_{iteration:07d}')
+        os.makedirs(iter_dir, exist_ok=True)
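+        # Layout follows Megatron-LM: <checkpoint_dir>/iter_XXXXXXX/ holds the
+        # distributed checkpoint and latest_checkpointed_iteration.txt records
+        # the most recent iteration.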
+
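+        # Sharding metadata written into the checkpoint (and read back at load
+        # time via ``load_content_metadata``) describing how the distributed
+        # optimizer state is laid out.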
+        sharded_sd_metadata = {
+            'distrib_optim_sharding_type': 'dp_reshardable',
+            'singleton_local_shards': False,
+            'chained_optim_avoid_prefix': True,
+        }
+
+        rng_state = self._get_rng_state()
+        model = self.model
+
+        state_dict = self._generate_state_dict(
+            model=model,
+            optimizer=optimizer_config.optimizer,
+            opt_param_scheduler=optimizer_config.lr_scheduler,
+            rng_state=rng_state,
+            iteration=iteration,
+            model_sd_kwargs={'metadata': sharded_sd_metadata},
+            optim_sd_kwargs={'metadata': sharded_sd_metadata},
+        )
+
+        save_strategy = get_default_save_sharded_strategy()
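+        # With more than one data-parallel replica, distribute the shards to be
+        # written across the DP group instead of funnelling all I/O through one rank.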
+        if mpu.get_data_parallel_world_size(with_context_parallel=True) > 1:
+            save_strategy = FullyParallelSaveStrategyWrapper(
+                save_strategy,
+                mpu.get_data_parallel_group(with_context_parallel=True),
+            )
+
+        dist_checkpointing.save(
+            state_dict, iter_dir, save_strategy,
+            async_sharded_save=False,
+            validate_access_integrity=True,
+            content_metadata=sharded_sd_metadata,
+        )
+
+        if dist.is_initialized():
+            dist.barrier()
+
+        # Write tracker file (rank 0 only).
+        rank = dist.get_rank() if dist.is_initialized() else 0
+        if rank == 0:
+            tracker_path = os.path.join(
+                checkpoint_dir, 'latest_checkpointed_iteration.txt',
+            )
+            with open(tracker_path, 'w') as f:
+                f.write(str(iteration))
+
+        logging.getLogger(__name__).info(
+            f'Saved mcore optimizer state at iteration {iteration} '
+            f'to {checkpoint_dir}'
+        )
+
+    def _load_mcore_optimizer(
+        self,
+        checkpoint_dir: str,
+        adapter_name: str = '',
+        **kwargs,
+    ):
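+        """Restore model, optimizer, lr_scheduler and RNG state from an mcore checkpoint."""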
+        from megatron.core import (
+            dist_checkpointing, parallel_state as mpu, tensor_parallel,
+        )
+        from megatron.core.dist_checkpointing.serialization import (
+            get_default_load_sharded_strategy,
+        )
+        from megatron.core.dist_checkpointing.strategies.fully_parallel import (
+            FullyParallelLoadStrategyWrapper,
+        )
+
+        no_load_optim = kwargs.pop('no_load_optim', False)
+        no_load_rng = kwargs.pop('no_load_rng', False)
+
+        optimizer_config = self.optimizer_group.get(
+            adapter_name or self._get_default_group(),
+        )
+
+        # Read iteration from tracker file.
+        tracker_path = os.path.join(
+            checkpoint_dir, 'latest_checkpointed_iteration.txt',
+        )
+        iteration = self._read_iteration(tracker_path)
+        if iteration == 0:
+            logging.getLogger(__name__).warning(
+                f'No checkpoint found in {checkpoint_dir}'
+            )
+            return
+
+        iter_dir = os.path.join(checkpoint_dir, f'iter_{iteration:07d}')
+
+        # Load common (non-sharded) state to inspect content metadata.
+        common_state = dist_checkpointing.load_common_state_dict(iter_dir)
+        sharded_sd_metadata = dist_checkpointing.load_content_metadata(
+            preloaded_state_dict=common_state,
+        )
+
+        # Build optimizer / scheduler references for the sharded state dict.
+        optimizer = optimizer_config.optimizer if not no_load_optim else None
+        opt_param_scheduler = (
+            optimizer_config.lr_scheduler if not no_load_optim else None
+        )
+        rng_state = self._get_rng_state() if not no_load_rng else None
+
+        optim_sd_kwargs = dict(metadata=sharded_sd_metadata, is_loading=True)
+        model_sd_kwargs = dict(metadata=sharded_sd_metadata)
+
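+        # Build the same sharded-state-dict layout used at save time; it acts as
+        # a template telling dist_checkpointing what each rank expects to receive.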
+        sharded_state_dict = self._generate_state_dict(
+            model=self.model,
+            optimizer=optimizer,
+            opt_param_scheduler=opt_param_scheduler,
+            rng_state=rng_state,
+            iteration=iteration,
+            model_sd_kwargs=model_sd_kwargs,
+            optim_sd_kwargs=optim_sd_kwargs,
+        )
+
+        # Load using fully-parallel strategy for speed.
+        load_strategy = get_default_load_sharded_strategy(iter_dir)
+        if mpu.get_data_parallel_world_size(with_context_parallel=True) > 1:
+            load_strategy = FullyParallelLoadStrategyWrapper(
+                load_strategy,
+                mpu.get_data_parallel_group(with_context_parallel=True),
+            )
+        state_dict = dist_checkpointing.load(
+            sharded_state_dict, iter_dir, load_strategy,
+        )
+
+        # Restore model weights.
+        if len(self.model) == 1:
+            self.model[0].load_state_dict(state_dict['model'], strict=False)
+        else:
+            for i, m in enumerate(self.model):
+                key = f'model{i}'
+                if key in state_dict:
+                    m.load_state_dict(state_dict[key], strict=False)
+
+        # Restore optimizer + LR scheduler.
+        if not no_load_optim and optimizer is not None and 'optimizer' in state_dict:
+            optimizer.load_state_dict(state_dict['optimizer'])
+        if (
+            opt_param_scheduler is not None
+            and 'opt_param_scheduler' in state_dict
+        ):
+            opt_param_scheduler.load_state_dict(
+                state_dict['opt_param_scheduler'],
+            )
+
+        # Restore RNG state (the ShardedObject wraps a single-element list).
+        if not no_load_rng and 'rng_state' in state_dict:
+            rng = state_dict['rng_state'][0]
+            random.setstate(rng['random_rng_state'])
+            np.random.set_state(rng['np_rng_state'])
+            torch.set_rng_state(rng['torch_rng_state'])
+            torch.cuda.set_rng_state(rng['cuda_rng_state'])
+            tensor_parallel.get_cuda_rng_tracker().set_states(
+                rng['rng_tracker_states'],
+            )
+
+        # Restore iteration counter.
+        if optimizer_config is not None and 'iteration' in state_dict:
+            optimizer_config.cur_step = state_dict['iteration']
+
+        if dist.is_initialized():
+            dist.barrier()
+
+        logging.getLogger(__name__).info(
+            f'Resumed mcore checkpoint at iteration {iteration} '
+            f'from {checkpoint_dir}'
+        )
+
+    @staticmethod
+    def _read_iteration(tracker_path: str) -> int:
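+        """Read the latest checkpointed iteration from the tracker file (0 if absent)."""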
+        if not os.path.exists(tracker_path):
+            return 0
+        with open(tracker_path, 'r') as f:
+            iteration = int(f.read().strip())
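+        # All-reduce with MAX so every rank ends up with the same iteration number.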
+        if dist.is_initialized():
+            iters_cuda = torch.tensor(
+                [iteration], dtype=torch.long, device='cuda',
+            )
+            dist.all_reduce(
+                iters_cuda, op=dist.ReduceOp.MAX,
+            )
+            iteration = iters_cuda[0].item()
+        return iteration
+
     def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=None):
         """Save in HuggingFace format using bridge adapter.
 