
Commit 6fa8a62

Author: Curt Tigges (committed)
Commit message: added extensive set of more tests
1 parent 80ca43a · commit 6fa8a62

28 files changed: 1,870 additions & 224 deletions

clt/training/checkpointing.py

Lines changed: 50 additions & 48 deletions
@@ -2,13 +2,14 @@
 import torch
 import os
 import torch.distributed as dist
-from torch.distributed.checkpoint.state_dict_saver import save_state_dict
+from safetensors.torch import save_file as save_safetensors_file, load_file as load_safetensors_file
 from torch.distributed.checkpoint.state_dict_loader import load_state_dict
-from torch.distributed.checkpoint.default_planner import DefaultSavePlanner, DefaultLoadPlanner
-from torch.distributed.checkpoint.filesystem import FileSystemWriter, FileSystemReader
+from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
+from torch.distributed.checkpoint.filesystem import FileSystemReader
 from typing import Optional, Union, Dict, Any, TYPE_CHECKING
-from safetensors.torch import save_file as save_safetensors_file, load_file as load_safetensors_file
 import logging
+import shutil
+from pathlib import Path

 # Import for type hinting, moved outside TYPE_CHECKING for runtime availability
 from clt.training.wandb_logger import WandBLogger, DummyWandBLogger
@@ -67,12 +68,15 @@ def _save_checkpoint(
         # For example: trainer_state_to_save["step"] should be === step passed to this function.
         # We will save this entire dictionary.

-        if not self.distributed:  # Non-distributed save
-            os.makedirs(self.log_dir, exist_ok=True)
+        # Ensure log_dir exists
+        os.makedirs(self.log_dir, exist_ok=True)
+
+        # Non-distributed: save model, trainer state, and store state directly
+        if not self.distributed:
             model_checkpoint_path = os.path.join(self.log_dir, f"clt_checkpoint_{step}.safetensors")
             latest_model_path = os.path.join(self.log_dir, "clt_checkpoint_latest.safetensors")
-            store_checkpoint_path = os.path.join(self.log_dir, f"activation_store_checkpoint_{step}.pt")
-            latest_store_path = os.path.join(self.log_dir, "activation_store_checkpoint_latest.pt")
+            store_checkpoint_path = os.path.join(self.log_dir, f"activation_store_{step}.pt")
+            latest_store_path = os.path.join(self.log_dir, "activation_store_latest.pt")
             trainer_state_path = os.path.join(self.log_dir, f"trainer_state_{step}.pt")
             latest_trainer_state_path = os.path.join(self.log_dir, "trainer_state_latest.pt")

@@ -97,34 +101,36 @@ def _save_checkpoint(
         # --- Distributed Save ---
         checkpoint_dir = os.path.join(self.log_dir, f"step_{step}")
         latest_checkpoint_dir = os.path.join(self.log_dir, "latest")
-
+
         # Create directories if they don't exist (all ranks should do this)
         os.makedirs(checkpoint_dir, exist_ok=True)
         os.makedirs(latest_checkpoint_dir, exist_ok=True)
-
+
         # Ensure all ranks see the directories before proceeding
         if self.distributed:
             dist.barrier()

         # Save model state dict using distributed checkpointing
         model_state_dict_for_dist_save = self.model.state_dict()
-
-        # Option 1: Save per-rank checkpoints separately for debugging
+
+        # Save per-rank checkpoints separately (workaround for PyTorch distributed checkpoint bug)
         rank_checkpoint_path = os.path.join(checkpoint_dir, f"rank_{self.rank}_model.pt")
         latest_rank_checkpoint_path = os.path.join(latest_checkpoint_dir, f"rank_{self.rank}_model.pt")
-
+
         try:
-            # Save individual rank files
+            # Save individual rank files (workaround for PyTorch distributed checkpoint bug)
+            # CRITICAL: Each rank must get its OWN model's state dict to avoid the weight duplication bug
+            # where all ranks would save rank 0's weights. See scripts/debug/distributed_checkpoint_bug_analysis.md
             torch.save(model_state_dict_for_dist_save, rank_checkpoint_path)
             torch.save(model_state_dict_for_dist_save, latest_rank_checkpoint_path)
             logger.info(f"Rank {self.rank}: Saved individual checkpoint to {rank_checkpoint_path}")
-
+
             # Debug: Check what we saved
             enc_key = "encoder_module.encoders.0.weight"
             if enc_key in model_state_dict_for_dist_save:
                 checksum = torch.sum(torch.abs(model_state_dict_for_dist_save[enc_key])).item()
                 logger.info(f"Rank {self.rank}: Saved {enc_key} with checksum {checksum:.6f}")
-
+
             # Skip saving distributed checkpoint (.distcp files) to save space
             # We're using individual rank files instead due to PyTorch bug
             pass
@@ -136,7 +142,7 @@ def _save_checkpoint(
         # Wait for all ranks to save their individual checkpoints
         if self.distributed:
             dist.barrier()
-
+
         if self.rank == 0:
             # Save activation store
             store_checkpoint_path = os.path.join(checkpoint_dir, "activation_store.pt")
@@ -146,12 +152,12 @@ def _save_checkpoint(
                 torch.save(self.activation_store.state_dict(), latest_store_path)
             except Exception as e:
                 logger.warning(f"Rank 0: Warning: Failed to save activation store state at step {step}: {e}")
-
+
             # Merge individual rank checkpoints into consolidated model
             # This is a workaround for the PyTorch distributed checkpoint bug
             try:
                 logger.info(f"Rank 0: Merging {self.world_size} rank checkpoints...")
-
+
                 # Load all rank state dicts
                 state_dicts = []
                 for rank in range(self.world_size):
@@ -161,21 +167,21 @@ def _save_checkpoint(
                         state_dicts.append(state_dict)
                     else:
                         logger.error(f"Rank 0: Missing rank checkpoint: {rank_path}")
-                        state_dicts = None
+                        state_dicts = []  # Re-initialize as empty list to break and fail gracefully
                         break
-
+
                 if state_dicts and len(state_dicts) == self.world_size:
                     # Merge the state dicts
                     merged_state = self._merge_tensor_parallel_weights(state_dicts)
-
+
                     # Save as safetensors
                     model_safetensors_path = os.path.join(checkpoint_dir, "model.safetensors")
                     latest_model_safetensors_path = os.path.join(latest_checkpoint_dir, "model.safetensors")
                     save_safetensors_file(merged_state, model_safetensors_path)
                     save_safetensors_file(merged_state, latest_model_safetensors_path)
                     logger.info(f"Rank 0: Saved merged model to {model_safetensors_path}")
                 else:
-                    logger.error(f"Rank 0: Failed to merge rank checkpoints - missing files")
+                    logger.error("Rank 0: Failed to merge rank checkpoints - missing files")
                     # Fall back to single rank save
                     model_safetensors_path = os.path.join(checkpoint_dir, "model.safetensors")
                     latest_model_safetensors_path = os.path.join(latest_checkpoint_dir, "model.safetensors")
@@ -216,7 +222,7 @@ def _save_checkpoint(

         if self.distributed:
             dist.barrier()
-
+
         # Clean up old checkpoints to save space
         if self.rank == 0 and self.keep_n_checkpoints > 0:
             self._cleanup_old_checkpoints()
@@ -256,11 +262,11 @@ def load_checkpoint(self, checkpoint_path: str) -> Dict[str, Any]:
         trainer_state_fname = ""

         if base_name == "clt_checkpoint_latest.safetensors":
-            store_checkpoint_fname = "activation_store_checkpoint_latest.pt"
+            store_checkpoint_fname = "activation_store_latest.pt"
             trainer_state_fname = "trainer_state_latest.pt"
         elif base_name.startswith("clt_checkpoint_") and base_name.endswith(".safetensors"):
             step_str = base_name.replace("clt_checkpoint_", "").replace(".safetensors", "")
-            store_checkpoint_fname = f"activation_store_checkpoint_{step_str}.pt"
+            store_checkpoint_fname = f"activation_store_{step_str}.pt"
             trainer_state_fname = f"trainer_state_{step_str}.pt"

         if not store_checkpoint_fname or not trainer_state_fname:
@@ -430,17 +436,17 @@ def _load_non_distributed_checkpoint(self, checkpoint_path: str, store_checkpoin
             else:  # for older checkpoints that might not have extension in prefix string
                 store_basename_prefix = store_basename_prefix + ".pt"

-            # Ensure it correctly forms activation_store_checkpoint_{step}.pt
+            # Ensure it correctly forms activation_store_{step}.pt
             if "latest" in basename:
-                store_basename = "activation_store_checkpoint_latest.pt"
+                store_basename = "activation_store_latest.pt"
             else:
                 # Extract step from basename like clt_checkpoint_100.safetensors -> 100
                 step_str = basename.split("_")[-1].split(".")[0]
-                store_basename = f"activation_store_checkpoint_{step_str}.pt"
+                store_basename = f"activation_store_{step_str}.pt"
             store_checkpoint_path = os.path.join(dirname, store_basename)
             # No change for clt_checkpoint_latest.pt because it's specific enough
         elif basename == "clt_checkpoint_latest.pt" or basename == "clt_checkpoint_latest.safetensors":
-            store_checkpoint_path = os.path.join(dirname, "activation_store_checkpoint_latest.pt")
+            store_checkpoint_path = os.path.join(dirname, "activation_store_latest.pt")
         else:
             store_checkpoint_path = None

@@ -458,21 +464,20 @@ def _load_non_distributed_checkpoint(self, checkpoint_path: str, store_checkpoin
             logger.warning(
                 f"Warning: Activation store checkpoint path not found or specified: {store_checkpoint_path}. Store state not loaded."
             )
-
+
     def _merge_tensor_parallel_weights(self, state_dicts: list) -> Dict[str, torch.Tensor]:
         """
         Merge tensor-parallel weights from multiple ranks into a single state dict.
         This is a workaround for the PyTorch distributed checkpoint bug.
         """
         merged_state = {}
-        world_size = len(state_dicts)
-
+
         # Get all parameter names from first rank
         param_names = list(state_dicts[0].keys())
-
+
         for name in param_names:
             tensors = [sd[name] for sd in state_dicts]
-
+
             # Check if this is a tensor-parallel weight that needs concatenation
             if "encoder_module.encoders" in name:
                 if "weight" in name:
@@ -484,28 +489,25 @@ def _merge_tensor_parallel_weights(self, state_dicts: list) -> Dict[str, torch.T
                 else:
                     # Other encoder parameters
                     merged_state[name] = tensors[0]
-
+
             elif "decoder_module.decoders" in name and "weight" in name:
                 # Decoder weights are sharded along dim 1 (input features)
                 merged_state[name] = torch.cat(tensors, dim=1)
-
+
             elif "log_threshold" in name:
                 # For BatchTopK threshold, concatenate the per-layer thresholds
                 merged_state[name] = torch.cat(tensors, dim=1)
-
+
             else:
                 # For replicated parameters (biases, layer norms, etc.), use rank 0's version
                 merged_state[name] = tensors[0]
-
+
         return merged_state
-
+
     def _cleanup_old_checkpoints(self):
         """Remove old checkpoints to save disk space, keeping only the last N."""
-        import shutil
-        from pathlib import Path
-
         log_path = Path(self.log_dir)
-
+
         # Find all step directories
         step_dirs = []
         for item in log_path.iterdir():
@@ -515,14 +517,14 @@ def _cleanup_old_checkpoints(self):
                     step_dirs.append((step_num, item))
                 except ValueError:
                     continue
-
+
         # Sort by step number
         step_dirs.sort(key=lambda x: x[0])
-
+
         # Keep only the last N checkpoints
         if len(step_dirs) > self.keep_n_checkpoints:
-            dirs_to_remove = step_dirs[:-self.keep_n_checkpoints]
-
+            dirs_to_remove = step_dirs[: -self.keep_n_checkpoints]
+
             for step_num, dir_path in dirs_to_remove:
                 try:
                     shutil.rmtree(dir_path)

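Note: the merge logic in this file concatenates sharded parameters and keeps a single copy of replicated ones. Below is a minimal, self-contained sketch of that idea with hypothetical names and shapes; it is not the repo's `_merge_tensor_parallel_weights`, and it only assumes what the diff comments state (encoder weights sharded along dim 0, decoder weights along dim 1, everything else replicated).

import torch
from typing import Dict, List

def merge_tp_state_dicts(state_dicts: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Merge per-rank state dicts from a tensor-parallel run (illustrative sketch)."""
    merged: Dict[str, torch.Tensor] = {}
    for name in state_dicts[0]:
        shards = [sd[name] for sd in state_dicts]
        if "encoder" in name and "weight" in name:
            merged[name] = torch.cat(shards, dim=0)   # output features sharded across ranks
        elif "decoder" in name and "weight" in name:
            merged[name] = torch.cat(shards, dim=1)   # input features sharded across ranks
        else:
            merged[name] = shards[0]                  # replicated params: take rank 0's copy
    return merged

# Example with two fake ranks: encoder rows are concatenated, the bias is taken from rank 0.
rank0 = {"encoder.weight": torch.randn(4, 8), "encoder.bias": torch.zeros(8)}
rank1 = {"encoder.weight": torch.randn(4, 8), "encoder.bias": torch.zeros(8)}
print(merge_tp_state_dicts([rank0, rank1])["encoder.weight"].shape)  # torch.Size([8, 8])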
clt/training/data/local_activation_store.py

Lines changed: 25 additions & 3 deletions
@@ -15,6 +15,13 @@

 logger = logging.getLogger(__name__)

+# Helper to map torch dtypes to numpy dtypes for conversion
+TORCH_TO_NUMPY_DTYPE_MAP = {
+    torch.float32: np.float32,
+    torch.float16: np.float16,
+    torch.bfloat16: np.float32,  # NumPy doesn't have bfloat16, so we use float32 as an intermediate
+}
+

 class LocalActivationStore(ManifestActivationStore):
     """
@@ -182,7 +189,7 @@ def _load_norm_stats(self) -> Optional[Dict[str, Any]]:
             return None

     @lru_cache(maxsize=64)
-    def _load_chunk(self, chunk_path: str, layer_key: str, data_type: str):
+    def _load_chunk(self, chunk_path: str, layer_key: str, data_type: str) -> np.ndarray:
         """Loads entire HDF5 chunk from disk and caches"""

         logger.debug(f"Fetching chunk {chunk_path} / {layer_key} / {data_type}")
@@ -240,11 +247,26 @@ def _layer_sort_key(name: str) -> int:
             else:
                 row_indices_h5 = row_indices

+            # PyTorch doesn't support from_buffer for bfloat16, so we can't create bfloat16 bytes.
+            # The conversion to bfloat16 must happen after the tensor is created.
+            # If the requested dtype is bfloat16, we'll load the data as float32 bytes.
+            if self.dtype == torch.bfloat16:
+                target_np_dtype = np.float32
+            else:
+                # e.g., str(torch.float32) -> "torch.float32" -> "float32"
+                dtype_str = str(self.dtype).split(".")[-1]
+                target_np_dtype = np.dtype(dtype_str)
+
             for i, lk in enumerate(layer_keys):
                 input_data = self._load_chunk(chunk_path, lk, "inputs")[row_indices_h5, :]
                 target_data = self._load_chunk(chunk_path, lk, "targets")[row_indices_h5, :]
-                bufs.append(input_data.tobytes())
-                bufs.append(target_data.tobytes())
+
+                # Convert to the target numpy dtype before getting bytes
+                input_data_converted = input_data.astype(target_np_dtype)
+                target_data_converted = target_data.astype(target_np_dtype)
+
+                bufs.append(input_data_converted.tobytes())
+                bufs.append(target_data_converted.tobytes())
             return b"".join(bufs)
         except KeyError as e:
             logger.error(f"Error accessing data within chunk {chunk_id} at {chunk_path}: Missing key {e}")

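Note: as the new comments say, NumPy has no bfloat16 dtype, so bfloat16 activations have to travel as float32 bytes and are cast only after a torch tensor exists. A small sketch of that round trip, with made-up shapes and not the store's actual code path:

import numpy as np
import torch

requested_dtype = torch.bfloat16

# Write side: bfloat16 has no NumPy equivalent, so materialize float32 bytes instead.
chunk = np.random.rand(16, 64).astype(np.float32)
buf = chunk.tobytes()

# Read side: rebuild a float32 tensor from the raw bytes...
tensor = torch.frombuffer(bytearray(buf), dtype=torch.float32).reshape(16, 64)

# ...and only then cast to the requested dtype.
if tensor.dtype != requested_dtype:
    tensor = tensor.to(requested_dtype)

print(tensor.dtype)  # torch.bfloat16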
clt/training/data/manifest_activation_store.py

Lines changed: 25 additions & 15 deletions
@@ -149,12 +149,9 @@ def _reset_generator_internal_state(self):
         self.batches_yielded_this_epoch = 0

     def __iter__(self):
-        # Reset iteration state based on strategy
-        if self.sampling_strategy == "sequential":
-            self.current_chunk_idx_in_order = 0
-            self.current_row_offset_in_chunk = 0
-        elif self.sampling_strategy == "random_chunk":
-            self.batches_yielded_this_epoch = 0
+        # The state (e.g., current position in iteration) is managed by the
+        # attributes set during __init__ or load_state_dict.
+        # This method just needs to return self so that the object is an iterator.
         return self

     def __next__(self):
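Note: the hunk above makes `__iter__` a no-op so that state restored via `load_state_dict` is not clobbered when iteration (re)starts. A toy illustration of the pattern, using a hypothetical class rather than the actual store:

class ResumableCounter:
    """Iterator whose position lives in an attribute, so it survives re-iteration."""

    def __init__(self, limit: int):
        self.limit = limit
        self.position = 0  # set here or restored by load_state_dict

    def load_state_dict(self, state: dict) -> None:
        self.position = state["position"]

    def __iter__(self):
        # Deliberately do NOT reset self.position: the state set in __init__
        # or load_state_dict is the single source of truth.
        return self

    def __next__(self) -> int:
        if self.position >= self.limit:
            raise StopIteration
        self.position += 1
        return self.position

counter = ResumableCounter(limit=5)
counter.load_state_dict({"position": 3})
print(list(counter))  # [4, 5] -- iteration resumes where the restored state left off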
@@ -889,24 +886,29 @@ def _fetch_and_parse_batch(self, idxs: np.ndarray) -> ActivationBatch:
         final_batch_inputs: Dict[int, torch.Tensor] = {li: torch.cat(tensors) for li, tensors in layer_inputs.items()}
         final_batch_targets: Dict[int, torch.Tensor] = {li: torch.cat(tensors) for li, tensors in layer_targets.items()}

-        # 5. Apply Normalization (if enabled)
+        # 5. Apply Normalization (if enabled) and Final Dtype Conversion
         if self.apply_normalization:
             log_stats_this_batch = {}
             for li in self.layer_indices:
-                if li == 0 and final_batch_inputs[li].numel() > 0:
-                    inp_before = final_batch_inputs[li]
-                    log_stats_this_batch["inp_mean_before"] = inp_before.float().mean().item()
-                    log_stats_this_batch["inp_std_before"] = inp_before.float().std().item()
+                # Always convert to float32 for normalization arithmetic
+                inputs_li = final_batch_inputs[li].float()
+                targets_li = final_batch_targets[li].float()
+
+                if li == 0 and inputs_li.numel() > 0:
+                    log_stats_this_batch["inp_mean_before"] = inputs_li.mean().item()
+                    log_stats_this_batch["inp_std_before"] = inputs_li.std().item()
                     if li in self.mean_in:
                         log_stats_this_batch["target_mean_in"] = self.mean_in[li].mean().item()
                         log_stats_this_batch["target_std_in"] = self.std_in[li].mean().item()

                 if li in self.mean_in and li in self.std_in:
-                    final_batch_inputs[li] = (final_batch_inputs[li].float() - self.mean_in[li]) / self.std_in[li]
-                    final_batch_inputs[li] = final_batch_inputs[li].to(self.dtype)
+                    inputs_li = (inputs_li - self.mean_in[li]) / self.std_in[li]
                 if li in self.mean_tg and li in self.std_tg:
-                    final_batch_targets[li] = (final_batch_targets[li].float() - self.mean_tg[li]) / self.std_tg[li]
-                    final_batch_targets[li] = final_batch_targets[li].to(self.dtype)
+                    targets_li = (targets_li - self.mean_tg[li]) / self.std_tg[li]
+
+                # Convert to final target dtype *after* normalization
+                final_batch_inputs[li] = inputs_li.to(self.dtype)
+                final_batch_targets[li] = targets_li.to(self.dtype)

                 if li == 0 and final_batch_inputs[li].numel() > 0:
                     inp_after = final_batch_inputs[li]
@@ -915,6 +917,14 @@ def _fetch_and_parse_batch(self, idxs: np.ndarray) -> ActivationBatch:

             if log_stats_this_batch:
                 logger.debug(f"Normalization Stats (Layer 0): {log_stats_this_batch}")
+        else:
+            # If no normalization, just ensure final dtype is correct.
+            # This is where the bfloat16 conversion happens safely.
+            for li in self.layer_indices:
+                if final_batch_inputs[li].dtype != self.dtype:
+                    final_batch_inputs[li] = final_batch_inputs[li].to(self.dtype)
+                if final_batch_targets[li].dtype != self.dtype:
+                    final_batch_targets[li] = final_batch_targets[li].to(self.dtype)

         parse_duration = time.monotonic() - parse_start_time
         total_duration = time.monotonic() - fetch_start_time

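Note: the reworked branch above normalizes in float32 and casts to the store's dtype only afterwards, so low-precision dtypes such as bfloat16 never take part in the mean/std arithmetic. A hedged sketch of that order of operations, with made-up tensors and stats standing in for the store's internals:

import torch

target_dtype = torch.bfloat16
inputs = torch.randn(32, 768)               # raw activations as loaded
mean, std = inputs.mean(0), inputs.std(0)   # per-feature stats, stand-ins for mean_in/std_in

# Do the arithmetic in float32...
normalized = (inputs.float() - mean) / std

# ...and cast to the final dtype only once normalization is done.
normalized = normalized.to(target_dtype)
print(normalized.dtype, normalized.shape)  # torch.bfloat16 torch.Size([32, 768])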