diff --git a/megatron/core/optimizer/optimizer.py b/megatron/core/optimizer/optimizer.py
index df8ec8ef613..f5d66b8db4f 100644
--- a/megatron/core/optimizer/optimizer.py
+++ b/megatron/core/optimizer/optimizer.py
@@ -1161,20 +1161,26 @@ def _split_state_dict(self, state_dict):
         state_dicts = [None] * len(self.chained_optimizers)
         if state_dict is not None:
             if len(self.model_chunks) == 1:
-                state_dicts[0] = state_dict
+                # When there is only one global model chunk, all sub-optimizers
+                # (e.g., dense and MoE parts) use the same model state dict.
+                state_dicts = [state_dict] * len(self.chained_optimizers)
             else:
-                # Split state_dict if needed
+                # Split state_dict by model chunk object.
                 prefix = "model" if "model0" in state_dict.keys() else "model_"
-                offset = 0
+                chunk_to_global_idx = {chunk: idx for idx, chunk in enumerate(self.model_chunks)}
                 for optimizer_idx, optimizer in enumerate(self.chained_optimizers):
                     if hasattr(optimizer, "model_chunks"):
                         d = {}
-                        for chunk_idx in range(len(optimizer.model_chunks)):
+                        for chunk_idx, model_chunk in enumerate(optimizer.model_chunks):
+                            assert model_chunk in chunk_to_global_idx, (
+                                "Sub-optimizer model chunk was not found in "
+                                "chained optimizer model chunks"
+                            )
+                            global_idx = chunk_to_global_idx[model_chunk]
                             assert (
-                                f"{prefix}{offset}" in state_dict
-                            ), f"Wrong state_dict format, cannot find '{prefix}{offset}'"
-                            d[f"{prefix}{chunk_idx}"] = state_dict[f"{prefix}{offset}"]
-                            offset += 1
+                                f"{prefix}{global_idx}" in state_dict
+                            ), f"Wrong state_dict format, cannot find '{prefix}{global_idx}'"
+                            d[f"{prefix}{chunk_idx}"] = state_dict[f"{prefix}{global_idx}"]
                         if len(d) > 0:
                             state_dicts[optimizer_idx] = d
         return state_dicts