Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
a523e46
new script for merging distributed files
Jun 6, 2025
edca2b3
updated merge scripts
Jun 6, 2025
c3912eb
updated merger with logging
Jun 6, 2025
d9da866
updated merger to handle theta
Jun 6, 2025
75fa91d
dist eval script
Jun 6, 2025
09e2886
actually added script
Jun 6, 2025
703a41e
updated cuda mapping
Jun 6, 2025
a13bf15
updated model sharding
Jun 6, 2025
d8e5526
fixed data sharding in eval script
Jun 6, 2025
172f31a
fixed data sharding option
Jun 6, 2025
273cf91
added autocast to eval script
Jun 6, 2025
9ede9c3
new debug script for distributed saving
Jun 6, 2025
d565704
added new script to test gather
Jun 6, 2025
49080bd
updated error checking scripts
Jun 6, 2025
d6f675b
float16 version of eval
Jun 6, 2025
bc9cb41
script to test dtype
Jun 6, 2025
3a4c3d7
new debug scripts
Jun 6, 2025
68a6727
new debug script
Jun 6, 2025
ab3f563
corrected filename
Jun 6, 2025
546845a
updated arg
Jun 6, 2025
f376a09
fixed norm stats indexing
Jun 6, 2025
6cf6d86
added new debug scripts
Jun 6, 2025
36543a1
fixed smoke check
Jun 6, 2025
c9b21f7
batchtopk debug script
Jun 6, 2025
53b2b76
fixed issue in scripts
Jun 6, 2025
03d091b
new scripts
Jun 7, 2025
684143d
output debugger
Jun 7, 2025
f0c6e9f
output debugger fix
Jun 7, 2025
6067df6
rescaling test
Jun 7, 2025
a25a574
debugging for weight corruption
Jun 7, 2025
500bf0f
fixed device issue
Jun 7, 2025
bf89b17
debugging save load mismatch
Jun 7, 2025
ac26b29
fixes for save reload script
Jun 7, 2025
56da8cd
changed location of acts
Jun 7, 2025
f149e8c
manual training step
Jun 7, 2025
1a4dbd2
correct input size
Jun 7, 2025
464be2e
simpler test
Jun 7, 2025
85b183d
device assignment
Jun 7, 2025
12897dc
delegate to clttrainer
Jun 7, 2025
44fb72c
setting correct hyperparams
Jun 7, 2025
bb089a7
restore checkpointing
Jun 7, 2025
ea22a6c
fixed barrier
Jun 7, 2025
93c6cc6
simplified save load test
Jun 7, 2025
26c6077
fixed eval call
Jun 7, 2025
f05cf7f
fixed activation store call
Jun 7, 2025
8ad763c
fixed shapes
Jun 7, 2025
39c7e75
new script
Jun 7, 2025
8d738d9
fixed script
Jun 7, 2025
f8f2fe9
new debugging scripts and findings
curt-tigges Jun 11, 2025
f770cbc
fix for checkpoint file weight duplication
Jun 11, 2025
9c7a154
new checkpointing technique
curt-tigges Jun 12, 2025
2bbbb72
started perf optimization for dist training
curt-tigges Jun 14, 2025
5eecf9a
script cleanup
Jun 16, 2025
b0ee682
further script cleanup
Jun 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ clt_test_pythia_70m_jumprelu/
clt_smoke_output_local_wandb_batchtopk/
clt_smoke_output_remote_wandb/
wandb/
scripts/debug
scripts/optimization

# models
*.pt
Expand Down
3 changes: 3 additions & 0 deletions clt/config/clt_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ class TrainingConfig:

# Optional diagnostic metrics (can be slow)
compute_sparsity_diagnostics: bool = False # Whether to compute detailed sparsity diagnostics during eval

# Performance profiling
enable_profiling: bool = False # Whether to enable detailed performance profiling

# Dead feature tracking
dead_feature_window: int = 1000 # Steps until a feature is considered dead
Expand Down
100 changes: 78 additions & 22 deletions clt/models/activations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from typing import Optional, Tuple, Dict, List
from typing import Optional, Tuple, Dict, List, Any
import logging
from clt.config import CLTConfig
from torch.distributed import ProcessGroup
Expand All @@ -26,9 +26,10 @@ def _compute_mask(x: torch.Tensor, k_per_token: int, x_for_ranking: Optional[tor

if k_total_batch > 0:
_, flat_indices = torch.topk(ranking_flat, k_total_batch, sorted=False)
mask_flat = torch.zeros_like(x_flat, dtype=torch.bool)
mask_flat[flat_indices] = True
mask = mask_flat.view_as(x)
# Optimized mask creation - avoid individual indexing
mask = torch.zeros(x_flat.numel(), dtype=torch.bool, device=x.device)
mask[flat_indices] = True
mask = mask.view_as(x)
else:
mask = torch.zeros_like(x, dtype=torch.bool)

Expand Down Expand Up @@ -118,6 +119,7 @@ def _compute_mask(x: torch.Tensor, k_float: float, x_for_ranking: Optional[torch

if k_per_token > 0:
_, topk_indices_per_row = torch.topk(ranking_tensor_to_use, k_per_token, dim=-1, sorted=False)
# Use scatter_ for efficient mask creation
mask = torch.zeros_like(x, dtype=torch.bool)
mask.scatter_(-1, topk_indices_per_row, True)
else:
Expand Down Expand Up @@ -231,6 +233,7 @@ def _apply_batch_topk_helper(
dtype: torch.dtype,
rank: int,
process_group: Optional[ProcessGroup],
profiler: Optional[Any] = None,
) -> Dict[int, torch.Tensor]:
"""Helper to apply BatchTopK globally across concatenated layer pre-activations."""

Expand Down Expand Up @@ -304,17 +307,42 @@ def _apply_batch_topk_helper(

if world_size > 1:
if rank == 0:
local_mask = BatchTopK._compute_mask(
concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
)
if profiler:
with profiler.timer("batchtopk_compute_mask") as timer:
local_mask = BatchTopK._compute_mask(
concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
)
if hasattr(timer, 'elapsed'):
profiler.record("batchtopk_compute_mask", timer.elapsed)
else:
local_mask = BatchTopK._compute_mask(
concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
)
mask.copy_(local_mask)
dist_ops.broadcast(mask, src=0, group=process_group)

if hasattr(profiler, 'dist_profiler') and profiler.dist_profiler:
with profiler.dist_profiler.profile_op("batchtopk_broadcast"):
dist_ops.broadcast(mask, src=0, group=process_group)
else:
dist_ops.broadcast(mask, src=0, group=process_group)
else:
dist_ops.broadcast(mask, src=0, group=process_group)
if hasattr(profiler, 'dist_profiler') and profiler.dist_profiler:
with profiler.dist_profiler.profile_op("batchtopk_broadcast"):
dist_ops.broadcast(mask, src=0, group=process_group)
else:
dist_ops.broadcast(mask, src=0, group=process_group)
else:
mask = BatchTopK._compute_mask(
concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
)
if profiler:
with profiler.timer("batchtopk_compute_mask") as timer:
mask = BatchTopK._compute_mask(
concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
)
if hasattr(timer, 'elapsed'):
profiler.record("batchtopk_compute_mask", timer.elapsed)
else:
mask = BatchTopK._compute_mask(
concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
)

activated_concatenated = concatenated_preactivations_original * mask.to(dtype)

Expand All @@ -336,6 +364,7 @@ def _apply_token_topk_helper(
dtype: torch.dtype,
rank: int,
process_group: Optional[ProcessGroup],
profiler: Optional[Any] = None,
) -> Dict[int, torch.Tensor]:
"""Helper to apply TokenTopK globally across concatenated layer pre-activations."""
world_size = dist_ops.get_world_size(process_group)
Expand Down Expand Up @@ -408,19 +437,46 @@ def _apply_token_topk_helper(

if world_size > 1:
if rank == 0:
local_mask = TokenTopK._compute_mask(
concatenated_preactivations_original,
k_val_float,
concatenated_preactivations_normalized,
)
if profiler:
with profiler.timer("topk_compute_mask") as timer:
local_mask = TokenTopK._compute_mask(
concatenated_preactivations_original,
k_val_float,
concatenated_preactivations_normalized,
)
if hasattr(timer, 'elapsed'):
profiler.record("topk_compute_mask", timer.elapsed)
else:
local_mask = TokenTopK._compute_mask(
concatenated_preactivations_original,
k_val_float,
concatenated_preactivations_normalized,
)
mask.copy_(local_mask)
dist_ops.broadcast(mask, src=0, group=process_group)

if hasattr(profiler, 'dist_profiler') and profiler.dist_profiler:
with profiler.dist_profiler.profile_op("topk_broadcast"):
dist_ops.broadcast(mask, src=0, group=process_group)
else:
dist_ops.broadcast(mask, src=0, group=process_group)
else:
dist_ops.broadcast(mask, src=0, group=process_group)
if hasattr(profiler, 'dist_profiler') and profiler.dist_profiler:
with profiler.dist_profiler.profile_op("topk_broadcast"):
dist_ops.broadcast(mask, src=0, group=process_group)
else:
dist_ops.broadcast(mask, src=0, group=process_group)
else:
mask = TokenTopK._compute_mask(
concatenated_preactivations_original, k_val_float, concatenated_preactivations_normalized
)
if profiler:
with profiler.timer("topk_compute_mask") as timer:
mask = TokenTopK._compute_mask(
concatenated_preactivations_original, k_val_float, concatenated_preactivations_normalized
)
if hasattr(timer, 'elapsed'):
profiler.record("topk_compute_mask", timer.elapsed)
else:
mask = TokenTopK._compute_mask(
concatenated_preactivations_original, k_val_float, concatenated_preactivations_normalized
)

activated_concatenated = concatenated_preactivations_original * mask.to(dtype)

Expand Down
Loading
Loading