From 910fa29539d589f1ca6d4bf2d973024ebe42c5fe Mon Sep 17 00:00:00 2001 From: eternally-z Date: Tue, 3 Mar 2026 11:39:03 +0800 Subject: [PATCH 1/3] fix split_state_dict function --- megatron/core/optimizer/optimizer.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/megatron/core/optimizer/optimizer.py b/megatron/core/optimizer/optimizer.py index df8ec8ef613..2d7f402d64e 100644 --- a/megatron/core/optimizer/optimizer.py +++ b/megatron/core/optimizer/optimizer.py @@ -1161,20 +1161,25 @@ def _split_state_dict(self, state_dict): state_dicts = [None] * len(self.chained_optimizers) if state_dict is not None: if len(self.model_chunks) == 1: - state_dicts[0] = state_dict + # When there is only one global model chunk, all sub-optimizers + # (e.g., dense and MoE parts) use the same model state dict. + state_dicts = [state_dict] * len(self.chained_optimizers) else: - # Split state_dict if needed + # Split state_dict by chunk identity prefix = "model" if "model0" in state_dict.keys() else "model_" - offset = 0 + chunk_id_to_global_idx = {id(chunk): idx for idx, chunk in enumerate(self.model_chunks)} for optimizer_idx, optimizer in enumerate(self.chained_optimizers): if hasattr(optimizer, "model_chunks"): d = {} - for chunk_idx in range(len(optimizer.model_chunks)): + for chunk_idx, model_chunk in enumerate(optimizer.model_chunks): + assert id(model_chunk) in chunk_id_to_global_idx, ( + "Sub-optimizer model chunk was not found in chained optimizer model chunks" + ) + global_idx = chunk_id_to_global_idx[id(model_chunk)] assert ( - f"{prefix}{offset}" in state_dict - ), f"Wrong state_dict format, cannot find '{prefix}{offset}'" - d[f"{prefix}{chunk_idx}"] = state_dict[f"{prefix}{offset}"] - offset += 1 + f"{prefix}{global_idx}" in state_dict + ), f"Wrong state_dict format, cannot find '{prefix}{global_idx}'" + d[f"{prefix}{chunk_idx}"] = state_dict[f"{prefix}{global_idx}"] if len(d) > 0: state_dicts[optimizer_idx] = d return state_dicts From c65fd5c735f8a9a23a649a86e37a17068058b3ef Mon Sep 17 00:00:00 2001 From: eternally-z Date: Tue, 3 Mar 2026 16:41:39 +0800 Subject: [PATCH 2/3] minor fix --- megatron/core/optimizer/optimizer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/megatron/core/optimizer/optimizer.py b/megatron/core/optimizer/optimizer.py index 2d7f402d64e..6c430c9c491 100644 --- a/megatron/core/optimizer/optimizer.py +++ b/megatron/core/optimizer/optimizer.py @@ -1165,17 +1165,17 @@ def _split_state_dict(self, state_dict): # (e.g., dense and MoE parts) use the same model state dict. state_dicts = [state_dict] * len(self.chained_optimizers) else: - # Split state_dict by chunk identity + # Split state_dict by model chunk object. prefix = "model" if "model0" in state_dict.keys() else "model_" - chunk_id_to_global_idx = {id(chunk): idx for idx, chunk in enumerate(self.model_chunks)} + chunk_to_global_idx = {chunk: idx for idx, chunk in enumerate(self.model_chunks)} for optimizer_idx, optimizer in enumerate(self.chained_optimizers): if hasattr(optimizer, "model_chunks"): d = {} for chunk_idx, model_chunk in enumerate(optimizer.model_chunks): - assert id(model_chunk) in chunk_id_to_global_idx, ( + assert model_chunk in chunk_to_global_idx, ( "Sub-optimizer model chunk was not found in chained optimizer model chunks" ) - global_idx = chunk_id_to_global_idx[id(model_chunk)] + global_idx = chunk_to_global_idx[model_chunk] assert ( f"{prefix}{global_idx}" in state_dict ), f"Wrong state_dict format, cannot find '{prefix}{global_idx}'" From e867137124b90f82b87bd9e2d704dda6aff39f20 Mon Sep 17 00:00:00 2001 From: kunlunl Date: Tue, 3 Mar 2026 22:33:29 +0800 Subject: [PATCH 3/3] Fix lint error --- megatron/core/optimizer/optimizer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/megatron/core/optimizer/optimizer.py b/megatron/core/optimizer/optimizer.py index 6c430c9c491..f5d66b8db4f 100644 --- a/megatron/core/optimizer/optimizer.py +++ b/megatron/core/optimizer/optimizer.py @@ -1173,7 +1173,8 @@ def _split_state_dict(self, state_dict): d = {} for chunk_idx, model_chunk in enumerate(optimizer.model_chunks): assert model_chunk in chunk_to_global_idx, ( - "Sub-optimizer model chunk was not found in chained optimizer model chunks" + "Sub-optimizer model chunk was not found in " + "chained optimizer model chunks" ) global_idx = chunk_to_global_idx[model_chunk] assert (