> [!WARNING]
> This is highly experimental; prepare for the worst and hope for the best.
Selective loss computation for Transformer training. Only hard tokens contribute to the loss, providing real backward-pass savings.
```bash
pip install cggr
```

For CUDA acceleration with Triton kernels (Linux/Windows):

```bash
pip install cggr[cuda]
```

| Platform | Triton Kernels | PyTorch Fallback | Status |
|---|---|---|---|
| CUDA (Linux) | ✓ | ✓ | Full Support |
| CUDA (Windows) | ✓ | ✓ | Full Support |
| ROCm (AMD) | ✗ | ✓ | Supported |
| MPS (Apple Silicon) | ✗ | ✓ | Supported |
| CPU | ✗ | ✓ | Supported |

| Architecture | Auto-Detect | Notes |
|---|---|---|
| Llama/Mistral/Qwen/Gemma/Phi-3 | ✓ | model.layers style |
| GPT-2/GPT-J/Falcon/GPT-NeoX | ✓ | transformer.h style |
| BERT/RoBERTa | ✓ | encoder.layer style |
| Mamba/SSM | ✓ | backbone.layers style |
| Other | Passthrough | Uses full model as router |
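The router construction is the same regardless of architecture; the layer container (`model.layers`, `transformer.h`, etc.) is detected automatically. A minimal sketch, assuming a GPT-2 checkpoint:

```python
from transformers import AutoModelForCausalLM
from cggr import create_truncated_router

# Same call works for Llama-style and GPT-2-style models;
# the layer container is auto-detected from the architecture.
model = AutoModelForCausalLM.from_pretrained("gpt2")
router = create_truncated_router(model, num_layers=4)  # first 4 blocks, shared weights
```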
CGGR supports Flash Attention for memory-efficient attention computation:
```bash
# Install with Flash Attention support
pip install cggr[flash]
```

```python
from cggr_flash import load_model_with_flash_attention, enable_flash_attention
from transformers import AutoModelForCausalLM

# Option 1: Load model with Flash Attention
model = load_model_with_flash_attention("microsoft/phi-2")

# Option 2: Enable on existing model
model = AutoModelForCausalLM.from_pretrained("...")
model = enable_flash_attention(model)  # Auto-selects best backend
```

| Backend | Requirements | Speed |
|---|---|---|
| `flash_attention_2` | `flash-attn` library | Fastest |
| `sdpa` | PyTorch 2.0+ | Fast |
| `eager` | None | Baseline |
| Metric | Standard Training | CGGR (Batch Split) | Benefit |
|---|---|---|---|
| Backward Pass | 100% of tokens | 25% of tokens | 4x cheaper backward pass |
| Forward Pass | 1.0x cost | ~1.1x cost (Pass 1 + 2) | Negligible overhead (~9ms) |
| Total Speed | 1.0x (Baseline) | 1.4x - 2.0x faster | Significant training acceleration |
| Data Efficiency | Learns from all tokens | Prioritizes hard tokens | Learns faster from hard examples |
| Memory | High (full graph) | Lower (sparse graph) | Can increase batch size |
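As a rough back-of-the-envelope check on the headline speedup (assuming the backward pass costs roughly 2x the forward pass, a common rule of thumb rather than a measurement from this repo):

```python
# Rough estimate, not a benchmark: relative step cost in "forward-pass units".
forward_standard, backward_standard = 1.0, 2.0       # backward ~ 2x forward
forward_cggr, backward_cggr = 1.1, 2.0 * 0.25        # Pass 1 + 2, 25% of tokens

standard_cost = forward_standard + backward_standard  # 3.0
cggr_cost = forward_cggr + backward_cggr              # 1.6
print(f"Estimated speedup: {standard_cost / cggr_cost:.2f}x")  # ~1.9x
```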
- Model: HuggingFaceTB/SmolLM-135M
- Dataset: AI-MO/NuminaMath-1.5
- Evaluation: GSM8K (Math Reasoning)
> [!IMPORTANT]
> **Key Result:** CGGR achieved 4x higher sample throughput and +1.5% accuracy by utilizing idle compute cycles.
| Metric | Standard (BS=1) | CGGR (BS=4) | Improvement |
|---|---|---|---|
| Final Accuracy (GSM8K) | 8.00% | 9.50% | +1.50% |
| Final Loss | 0.3610 | 0.0980 | -73% |
| Total Samples Processed | ~14,368 | ~58,716 | 4.08x |
| Wall Clock Time | 6 Hours | 6 Hours | Same |
In standard training (Batch Size = 1), high-end GPUs are often latency-bound, spending more time waiting for memory transfers than doing math.
CGGR exploits this by quadrupling the batch size (Batch Size = 4) without increasing the step time.
- Step Latency: Unchanged (1.02x throughput ratio in steps/sec).
- Data Throughput: 4.08x higher (samples/sec).
By processing 4x more data in the same timeframe, CGGR converges significantly faster and to a much lower final loss (0.098 vs 0.361).
The most efficient way to use CGGR is via CGGRModel. It uses a lightweight router to score difficulty and only computes gradients for hard tokens.
```python
from cggr import CGGRModel, create_truncated_router
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("...").cuda()

# Create lightweight router (shares weights, 0 extra memory)
router = create_truncated_router(model, num_layers=4)

# Wrap model
cggr_model = CGGRModel(
    model,
    router=router,
    min_tokens_ratio=0.25
)

# Train
loss = cggr_model(input_ids, labels=labels)
loss.backward()
```

If you cannot use `CGGRModel` (e.g. specialized architectures), you can use `CGGRLoss` manually.
```python
from cggr import CGGRLoss

criterion = CGGRLoss(
    scoring='combined',       # 'entropy', 'margin', 'loss', 'combined'
    selection='stratified',   # 'topk', 'stratified', 'sequence_aware'
    min_tokens_ratio=0.25,
    warmup_steps=1000,
)

for input_ids, targets in dataloader:
    optimizer.zero_grad()
    logits = model(input_ids)
    loss = criterion(logits, targets)  # Only hard tokens contribute
    loss.backward()
    optimizer.step()
    criterion.step()
```

For architectures like SRDE that require static tensor shapes, you can use `CGGRScorer` to generate a routing mask instead of splitting the batch.
```python
from cggr import CGGRScorer

# 1. Initialize Scorer
self.scorer = CGGRScorer(router, min_tokens_ratio=0.5)

# 2. Get Mask
difficulty, mask, info = self.scorer(input_ids)

# 3. Apply Mask (Null Routing)
# mask is Boolean: True=Hard (Route to Expert), False=Easy (Skip/Null)
expert_output = expert_layer(x) * mask.unsqueeze(-1)
```

| Strategy | Description | Best For |
|---|---|---|
| `entropy` | High entropy = hard | General training |
| `margin` | Small top-2 margin = hard | Classification |
| `loss` | High loss = hard | Direct optimization |
| `combined` | All signals combined | Best overall |
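As a rough illustration of what these signals measure (this is not the library's internal implementation), entropy-based scoring can be sketched as:

```python
import torch.nn.functional as F

# Illustrative only: entropy of the predictive distribution per token.
# logits: (batch, seq_len, vocab_size); returns (batch, seq_len) difficulty scores.
def entropy_difficulty(logits):
    log_probs = F.log_softmax(logits, dim=-1)
    return -(log_probs.exp() * log_probs).sum(dim=-1)  # high entropy = hard
```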
| Strategy | Description | Benefit |
|---|---|---|
| `topk` | Top-k hardest tokens | Simple, fast |
| `stratified` | Sample from difficulty buckets | Prevents forgetting |
| `sequence_aware` | Ensure coverage per sequence | Preserves structure |
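For intuition (again, a sketch rather than the library's code), top-k selection simply keeps the hardest fraction of tokens:

```python
import torch

# Illustrative only: boolean mask over the hardest `ratio` fraction of tokens.
def topk_mask(difficulty, ratio=0.25):
    flat = difficulty.flatten()
    k = max(1, int(ratio * flat.numel()))
    mask = torch.zeros_like(flat, dtype=torch.bool)
    mask[flat.topk(k).indices] = True
    return mask.view_as(difficulty)
```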
Dynamic thresholding automatically adjusts the token ratio based on batch confidence (see the sketch after this list):
- Low confidence → more tokens (model is learning)
- High confidence → fewer tokens (model has converged)
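Roughly, the idea can be sketched like this (illustrative only; the exact formula the library uses may differ):

```python
# Illustrative only -- not the library's exact formula.
# probs: softmax probabilities of shape (batch, seq_len, vocab_size).
def adaptive_ratio(probs, min_ratio=0.25, max_ratio=1.0, sensitivity=0.5):
    confidence = probs.max(dim=-1).values.mean()  # mean top-1 probability
    # Low confidence -> ratio near max_ratio; high confidence -> ratio near min_ratio
    ratio = max_ratio - sensitivity * confidence * (max_ratio - min_ratio)
    return float(ratio.clamp(min_ratio, max_ratio))
```

In the library itself, this behaviour is enabled with the `dynamic_threshold` flag: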
```python
CGGRLoss(dynamic_threshold=True, threshold_sensitivity=0.5)
```

```python
CGGRLoss(
    # Scoring
    scoring='combined',

    # Selection
    selection='topk',
    num_strata=4,                 # For stratified
    min_tokens_per_sequence=1,    # For sequence_aware

    # Thresholding
    dynamic_threshold=True,
    threshold_sensitivity=0.5,

    # Curriculum
    min_tokens_ratio=0.25,
    warmup_steps=1000,
)
```

| Config | Backward FLOPs | Overhead |
|---|---|---|
| Standard Loss | 100% | 0% |
| CGGR (25% tokens) | ~25% | ~0% |
For Token-Routed MLP and Mixture-of-Experts architectures, CGGR provides advanced persistent kernels with:
- Persistent Kernels - Keep SM threads active across expert batches
- Cooperative Thread Groups - Better SM utilization through cooperative scheduling
- Warp-Specialized Streaming - Overlap memory loads with computation
- Software Pipelining - Multi-stage async prefetching for latency hiding
```python
from cggr import PersistentTRMLP, create_persistent_tr_mlp

# Create MoE layer with persistent optimizations
moe_layer = create_persistent_tr_mlp(
    hidden_dim=1024,
    intermediate_dim=4096,
    num_experts=8,
    top_k=2,
)

# Forward pass (routes tokens to experts)
output, aux_loss = moe_layer(hidden_states)
total_loss = language_modeling_loss + 0.01 * aux_loss
```

Expected performance improvement: ~10-15% faster than standard grouped GEMM.
| Metric | Standard | Persistent | Improvement |
|---|---|---|---|
| Kernel Launch Overhead | ~5μs/batch | ~0.5μs/batch | ~10x |
| SM Utilization | ~70% | ~85% | ~21% |
| End-to-End Throughput | Baseline | +10-15% | ✓ |