Summary
Benchmarking reveals that logits_processor and stopping_criteria in bytes_decoder.generate() account for 63.5% of total generation time (warm cache). Optimizing these components could provide up to a 2.74x speedup.
Location
welt/model.py:503-509 in _generate_word_bytes():
```python
return self.bytes_decoder.generate(
    inputs_embeds=inputs_embeds,
    generation_config=bytes_generation_config,
    tokenizer=tokenizer,
    logits_processor=[self.logits_processor],  # ← Costs 31.6% of runtime
    stopping_criteria=stopping_criteria,       # ← Costs 27.9% of runtime
)
```

Benchmark Results
Warm Cache (After torch.compile):
| Configuration | Time (s) | Speedup vs Baseline | % of Baseline |
|---|---|---|---|
| Baseline (both parameters) | 3.1276 | — | 100% |
| Without logits_processor | 2.1392 | 31.6% faster | 68.4% |
| Without stopping_criteria | 2.2548 | 27.9% faster | 72.1% |
| Without both | 1.1408 | 63.5% faster | 36.5% |
Cold Cache (First Run with torch.compile):
| Configuration | Time (s) | Speedup vs Baseline |
|---|---|---|
| Baseline (both parameters) | 20.1456 | — |
| Without logits_processor | 17.2984 | 14.1% faster |
| Without stopping_criteria | 18.4905 | 8.2% faster |
| Without both | 16.9498 | 15.9% faster |
Note: Cold cache results show lower relative impact because compilation overhead dominates (6.4x slower than warm cache).
Analysis
Components:
- logits_processor: UTF8ValidationLogitsProcessor (compiled at line 419)
  - Ensures valid UTF-8 byte sequences during generation
  - Costs ~0.99s per benchmark (31.6% of runtime)
- stopping_criteria: WordStoppingCriteria
  - Stops generation at word boundaries
  - Costs ~0.87s per benchmark (27.9% of runtime)
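For context, here is a minimal sketch of the per-step hook shape these components plug into under the Hugging Face generate() API. The class below is illustrative only and is not the actual UTF8ValidationLogitsProcessor; WordStoppingCriteria plugs into the analogous StoppingCriteria.__call__ hook.

```python
import torch
from transformers import LogitsProcessor

class PerStepByteMaskProcessor(LogitsProcessor):
    """Illustrative only; not the real UTF8ValidationLogitsProcessor.

    Shows why the hook is costly: it runs once per generated token, for the
    whole batch, outside whatever torch.compile captured for the decoder
    forward pass.
    """

    def __init__(self, banned_byte_ids: torch.Tensor):
        # e.g. byte ids that would produce an invalid UTF-8 continuation
        self.banned_byte_ids = banned_byte_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Any Python-level work here is paid max_new_tokens times per call
        # to bytes_decoder.generate(); CPU-side branching also stalls the GPU.
        scores[:, self.banned_byte_ids] = float("-inf")
        return scores
```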
Why This Matters:
Together, these two components cost roughly 1.75x as much time as the actual model forward passes, word encoding, and tokenization combined (1.99s of overhead vs 1.14s for everything else).
Reproduction
```bash
# Run benchmark with warmup
python -m welt_training.sample
# The script now includes:
# 1. Warmup run to compile everything
# 2. Timed benchmark run with warm cache
```

Proposed Solutions
Option 1: Optimize Existing Implementations
- Profile UTF8ValidationLogitsProcessor to identify bottlenecks
- Profile WordStoppingCriteria for optimization opportunities
- Consider vectorization or JIT compilation improvements
- Investigate whether torch.compile is effectively optimizing these components
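One way to act on the profiling bullets above, sketched with torch.profiler; run_generation() is a placeholder for whatever entry point welt_training.sample actually times:

```python
import torch
from torch.profiler import profile, ProfilerActivity

def run_generation():
    # Placeholder: invoke the same path the benchmark uses,
    # i.e. the bytes_decoder.generate() call in _generate_word_bytes().
    ...

# Warm up first so torch.compile and cudnn autotuning are excluded from the trace.
run_generation()
torch.cuda.synchronize()

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    with_stack=True,  # attributes time back to the Python lines of the processor/criteria
) as prof:
    run_generation()
    torch.cuda.synchronize()

# Per-token Python overhead in the logits processor and stopping criteria
# shows up under CPU time rather than CUDA kernel time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30))
```

If most of the time lands in the __call__ frames of the processor/criteria as self CPU time, vectorizing or moving the checks on-device is likely more promising than further torch.compile tuning.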
Option 2: Alternative Implementations
- Implement validation logic directly in CUDA/Triton for GPU acceleration
- Move stopping criteria checks to a more efficient location in the generation loop
- Consider caching or batching validation checks
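To make the caching/batching bullet concrete, here is a sketch of a stopping criterion that precomputes a boundary-byte lookup once and performs a single on-device check per step. The class name and boundary set are assumptions, not the existing WordStoppingCriteria:

```python
import torch
from transformers import StoppingCriteria

class CachedBoundaryStopping(StoppingCriteria):
    """Sketch: one precomputed lookup table, one fused check per generation step."""

    def __init__(self, boundary_byte_ids, vocab_size: int, device: torch.device):
        # Build the membership table once instead of re-deriving boundaries
        # (or decoding bytes to Python strings) on every step.
        table = torch.zeros(vocab_size, dtype=torch.bool, device=device)
        table[torch.as_tensor(list(boundary_byte_ids), device=device)] = True
        self.is_boundary = table

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # Single gather over the batch, no per-sequence Python loop; recent
        # transformers versions accept a per-sequence bool tensor here.
        return self.is_boundary[input_ids[:, -1]]
```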
Option 3: Make Optional
- Add flags to disable these checks for inference when validation isn't critical
- Document the trade-offs (performance vs correctness guarantees)
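A minimal sketch of the flag idea at the existing call site; the parameter names validate_utf8 and enforce_word_stop are hypothetical, not existing arguments of _generate_word_bytes():

```python
def _generate_word_bytes(self, inputs_embeds, bytes_generation_config, tokenizer,
                         stopping_criteria, validate_utf8: bool = True,
                         enforce_word_stop: bool = True):
    # Hypothetical flags: skip the per-token hooks when the caller accepts
    # raw byte output (e.g. trusted inputs plus post-hoc validation).
    return self.bytes_decoder.generate(
        inputs_embeds=inputs_embeds,
        generation_config=bytes_generation_config,
        tokenizer=tokenizer,
        logits_processor=[self.logits_processor] if validate_utf8 else None,
        stopping_criteria=stopping_criteria if enforce_word_stop else None,
    )
```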
Questions
- Are these components already torch.compiled effectively? (They are compiled at line 419/576)
- Could validation be moved to post-processing to avoid per-token overhead? (a sketch follows after this list)
- Is there redundancy in the checks that could be eliminated?
- What's the actual implementation complexity of these components?
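On the post-processing question, a sketch of validating UTF-8 once per generated word instead of once per token; it assumes generated ids map directly to raw byte values, which may not hold if the byte vocabulary carries a special-token offset:

```python
from typing import List, Optional

def decode_word_bytes(byte_ids: List[int]) -> Optional[str]:
    """Sketch: validate a whole generated byte sequence after generation.

    Returns the decoded word, or None if the bytes are not valid UTF-8,
    in which case the caller could resample or truncate to the longest
    valid prefix.
    """
    try:
        return bytes(byte_ids).decode("utf-8")
    except (UnicodeDecodeError, ValueError):
        return None
```

The trade-off is that generation can now produce invalid sequences that must be handled after the fact, which is exactly the correctness/performance trade-off Option 3 asks to document.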
Additional Context
- Model: sign/WeLT-string-repetition
- Hardware: NVIDIA GB10 (CUDA capability 12.1)
- PyTorch optimizations enabled: cudnn benchmark, TF32, Flash Attention
- Generation config: max_generated_words=32
- Batch size: 3 samples
Expected Outcome
Ideally, we should be able to:
- Keep the correctness guarantees of UTF-8 validation and word stopping
- Reduce their combined overhead from ~2s to <0.5s (75% reduction)
- Achieve close to the 1.14s generation time while maintaining safety
This would recover most of the available 2.74x overall speedup without compromising functionality.