 import torch
 import os
 import torch.distributed as dist
-from torch.distributed.checkpoint.state_dict_saver import save_state_dict
+from safetensors.torch import save_file as save_safetensors_file, load_file as load_safetensors_file
 from torch.distributed.checkpoint.state_dict_loader import load_state_dict
-from torch.distributed.checkpoint.default_planner import DefaultSavePlanner, DefaultLoadPlanner
-from torch.distributed.checkpoint.filesystem import FileSystemWriter, FileSystemReader
+from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
+from torch.distributed.checkpoint.filesystem import FileSystemReader
 from typing import Optional, Union, Dict, Any, TYPE_CHECKING
-from safetensors.torch import save_file as save_safetensors_file, load_file as load_safetensors_file
 import logging
+import shutil
+from pathlib import Path
 
 # Import for type hinting, moved outside TYPE_CHECKING for runtime availability
 from clt.training.wandb_logger import WandBLogger, DummyWandBLogger
@@ -67,12 +68,15 @@ def _save_checkpoint(
         # For example: trainer_state_to_save["step"] should be === step passed to this function.
         # We will save this entire dictionary.
 
-        if not self.distributed:  # Non-distributed save
-            os.makedirs(self.log_dir, exist_ok=True)
+        # Ensure log_dir exists
+        os.makedirs(self.log_dir, exist_ok=True)
+
+        # Non-distributed: save model, trainer state, and store state directly
+        if not self.distributed:
             model_checkpoint_path = os.path.join(self.log_dir, f"clt_checkpoint_{step}.safetensors")
             latest_model_path = os.path.join(self.log_dir, "clt_checkpoint_latest.safetensors")
-            store_checkpoint_path = os.path.join(self.log_dir, f"activation_store_checkpoint_{step}.pt")
-            latest_store_path = os.path.join(self.log_dir, "activation_store_checkpoint_latest.pt")
+            store_checkpoint_path = os.path.join(self.log_dir, f"activation_store_{step}.pt")
+            latest_store_path = os.path.join(self.log_dir, "activation_store_latest.pt")
             trainer_state_path = os.path.join(self.log_dir, f"trainer_state_{step}.pt")
             latest_trainer_state_path = os.path.join(self.log_dir, "trainer_state_latest.pt")
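+            # Resulting files under log_dir (sketch of the layout defined above):
+            #   clt_checkpoint_{step}.safetensors / clt_checkpoint_latest.safetensors
+            #   activation_store_{step}.pt        / activation_store_latest.pt
+            #   trainer_state_{step}.pt           / trainer_state_latest.pt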
 
@@ -97,34 +101,36 @@ def _save_checkpoint(
         # --- Distributed Save ---
         checkpoint_dir = os.path.join(self.log_dir, f"step_{step}")
         latest_checkpoint_dir = os.path.join(self.log_dir, "latest")
-
+
         # Create directories if they don't exist (all ranks should do this)
         os.makedirs(checkpoint_dir, exist_ok=True)
         os.makedirs(latest_checkpoint_dir, exist_ok=True)
-
+
         # Ensure all ranks see the directories before proceeding
         if self.distributed:
             dist.barrier()
 
         # Save model state dict using distributed checkpointing
         model_state_dict_for_dist_save = self.model.state_dict()
-
-        # Option 1: Save per-rank checkpoints separately for debugging
+
+        # Save per-rank checkpoints separately (workaround for PyTorch distributed checkpoint bug)
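+        # NOTE: state_dict() here returns each rank's *local* shards of the
+        # tensor-parallel parameters, so the per-rank files differ by rank and
+        # are merged on rank 0 below (see _merge_tensor_parallel_weights).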
         rank_checkpoint_path = os.path.join(checkpoint_dir, f"rank_{self.rank}_model.pt")
         latest_rank_checkpoint_path = os.path.join(latest_checkpoint_dir, f"rank_{self.rank}_model.pt")
-
+
         try:
-            # Save individual rank files
+            # Save individual rank files (workaround for PyTorch distributed checkpoint bug)
+            # CRITICAL: Each rank must get its OWN model's state dict to avoid the weight duplication bug
+            # where all ranks would save rank 0's weights. See scripts/debug/distributed_checkpoint_bug_analysis.md
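+            # Sketch of the failure mode (hypothetical 2-rank run): without this,
+            # rank_0_model.pt and rank_1_model.pt would both contain rank 0's
+            # shard, and the merged model would silently duplicate those weights.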
             torch.save(model_state_dict_for_dist_save, rank_checkpoint_path)
             torch.save(model_state_dict_for_dist_save, latest_rank_checkpoint_path)
             logger.info(f"Rank {self.rank}: Saved individual checkpoint to {rank_checkpoint_path}")
-
+
             # Debug: Check what we saved
             enc_key = "encoder_module.encoders.0.weight"
             if enc_key in model_state_dict_for_dist_save:
                 checksum = torch.sum(torch.abs(model_state_dict_for_dist_save[enc_key])).item()
                 logger.info(f"Rank {self.rank}: Saved {enc_key} with checksum {checksum:.6f}")
-
+
             # Skip saving distributed checkpoint (.distcp files) to save space
             # We're using individual rank files instead due to PyTorch bug
             pass
@@ -136,7 +142,7 @@ def _save_checkpoint(
         # Wait for all ranks to save their individual checkpoints
         if self.distributed:
             dist.barrier()
-
+
         if self.rank == 0:
             # Save activation store
             store_checkpoint_path = os.path.join(checkpoint_dir, "activation_store.pt")
@@ -146,12 +152,12 @@ def _save_checkpoint(
                 torch.save(self.activation_store.state_dict(), latest_store_path)
             except Exception as e:
                 logger.warning(f"Rank 0: Warning: Failed to save activation store state at step {step}: {e}")
-
+
             # Merge individual rank checkpoints into consolidated model
             # This is a workaround for the PyTorch distributed checkpoint bug
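+            # Flow (sketch): rank 0 loads rank_{i}_model.pt for each rank i,
+            # concatenates the sharded tensors, and writes a single model.safetensors.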
             try:
                 logger.info(f"Rank 0: Merging {self.world_size} rank checkpoints...")
-
+
                 # Load all rank state dicts
                 state_dicts = []
                 for rank in range(self.world_size):
@@ -161,21 +167,21 @@ def _save_checkpoint(
                         state_dicts.append(state_dict)
                     else:
                         logger.error(f"Rank 0: Missing rank checkpoint: {rank_path}")
-                        state_dicts = None
+                        state_dicts = []  # Re-initialize as empty list to break and fail gracefully
                         break
-
+
                 if state_dicts and len(state_dicts) == self.world_size:
                     # Merge the state dicts
                     merged_state = self._merge_tensor_parallel_weights(state_dicts)
-
+
                     # Save as safetensors
                     model_safetensors_path = os.path.join(checkpoint_dir, "model.safetensors")
                     latest_model_safetensors_path = os.path.join(latest_checkpoint_dir, "model.safetensors")
                     save_safetensors_file(merged_state, model_safetensors_path)
                     save_safetensors_file(merged_state, latest_model_safetensors_path)
                     logger.info(f"Rank 0: Saved merged model to {model_safetensors_path}")
                 else:
-                    logger.error(f"Rank 0: Failed to merge rank checkpoints - missing files")
+                    logger.error("Rank 0: Failed to merge rank checkpoints - missing files")
                     # Fall back to single rank save
                     model_safetensors_path = os.path.join(checkpoint_dir, "model.safetensors")
                     latest_model_safetensors_path = os.path.join(latest_checkpoint_dir, "model.safetensors")
@@ -216,7 +222,7 @@ def _save_checkpoint(
 
         if self.distributed:
             dist.barrier()
-
+
         # Clean up old checkpoints to save space
         if self.rank == 0 and self.keep_n_checkpoints > 0:
             self._cleanup_old_checkpoints()
@@ -256,11 +262,11 @@ def load_checkpoint(self, checkpoint_path: str) -> Dict[str, Any]:
         trainer_state_fname = ""
 
         if base_name == "clt_checkpoint_latest.safetensors":
-            store_checkpoint_fname = "activation_store_checkpoint_latest.pt"
+            store_checkpoint_fname = "activation_store_latest.pt"
             trainer_state_fname = "trainer_state_latest.pt"
         elif base_name.startswith("clt_checkpoint_") and base_name.endswith(".safetensors"):
             step_str = base_name.replace("clt_checkpoint_", "").replace(".safetensors", "")
-            store_checkpoint_fname = f"activation_store_checkpoint_{step_str}.pt"
+            store_checkpoint_fname = f"activation_store_{step_str}.pt"
             trainer_state_fname = f"trainer_state_{step_str}.pt"
 
         if not store_checkpoint_fname or not trainer_state_fname:
@@ -430,17 +436,17 @@ def _load_non_distributed_checkpoint(self, checkpoint_path: str, store_checkpoin
             else:  # for older checkpoints that might not have extension in prefix string
                 store_basename_prefix = store_basename_prefix + ".pt"
 
-            # Ensure it correctly forms activation_store_checkpoint_{step}.pt
+            # Ensure it correctly forms activation_store_{step}.pt
             if "latest" in basename:
-                store_basename = "activation_store_checkpoint_latest.pt"
+                store_basename = "activation_store_latest.pt"
             else:
                 # Extract step from basename like clt_checkpoint_100.safetensors -> 100
                 step_str = basename.split("_")[-1].split(".")[0]
-                store_basename = f"activation_store_checkpoint_{step_str}.pt"
+                store_basename = f"activation_store_{step_str}.pt"
             store_checkpoint_path = os.path.join(dirname, store_basename)
         # No change for clt_checkpoint_latest.pt because it's specific enough
         elif basename == "clt_checkpoint_latest.pt" or basename == "clt_checkpoint_latest.safetensors":
-            store_checkpoint_path = os.path.join(dirname, "activation_store_checkpoint_latest.pt")
+            store_checkpoint_path = os.path.join(dirname, "activation_store_latest.pt")
         else:
             store_checkpoint_path = None
 
@@ -458,21 +464,20 @@ def _load_non_distributed_checkpoint(self, checkpoint_path: str, store_checkpoin
             logger.warning(
                 f"Warning: Activation store checkpoint path not found or specified: {store_checkpoint_path}. Store state not loaded."
             )
-
+
     def _merge_tensor_parallel_weights(self, state_dicts: list) -> Dict[str, torch.Tensor]:
         """
        Merge tensor-parallel weights from multiple ranks into a single state dict.
        This is a workaround for the PyTorch distributed checkpoint bug.
        """
         merged_state = {}
-        world_size = len(state_dicts)
-
+
         # Get all parameter names from first rank
         param_names = list(state_dicts[0].keys())
-
+
         for name in param_names:
             tensors = [sd[name] for sd in state_dicts]
-
+
             # Check if this is a tensor-parallel weight that needs concatenation
             if "encoder_module.encoders" in name:
                 if "weight" in name:
@@ -484,28 +489,25 @@ def _merge_tensor_parallel_weights(self, state_dicts: list) -> Dict[str, torch.T
                 else:
                     # Other encoder parameters
                     merged_state[name] = tensors[0]
-
+
             elif "decoder_module.decoders" in name and "weight" in name:
                 # Decoder weights are sharded along dim 1 (input features)
                 merged_state[name] = torch.cat(tensors, dim=1)
-
+
             elif "log_threshold" in name:
                 # For BatchTopK threshold, concatenate the per-layer thresholds
                 merged_state[name] = torch.cat(tensors, dim=1)
-
+
             else:
                 # For replicated parameters (biases, layer norms, etc.), use rank 0's version
                 merged_state[name] = tensors[0]
-
+
         return merged_state
-
+
     def _cleanup_old_checkpoints(self):
         """Remove old checkpoints to save disk space, keeping only the last N."""
-        import shutil
-        from pathlib import Path
-
         log_path = Path(self.log_dir)
-
+
         # Find all step directories
         step_dirs = []
         for item in log_path.iterdir():
@@ -515,14 +517,14 @@ def _cleanup_old_checkpoints(self):
                     step_dirs.append((step_num, item))
                 except ValueError:
                     continue
-
+
         # Sort by step number
         step_dirs.sort(key=lambda x: x[0])
-
+
         # Keep only the last N checkpoints
         if len(step_dirs) > self.keep_n_checkpoints:
-            dirs_to_remove = step_dirs[:-self.keep_n_checkpoints]
-
+            dirs_to_remove = step_dirs[: -self.keep_n_checkpoints]
+
             for step_num, dir_path in dirs_to_remove:
                 try:
                     shutil.rmtree(dir_path)