
Performance Issue: logits_processor and stopping_criteria Account for 63.5% of Generation Time #51

@AmitMY

Description

Summary

Benchmarking reveals that the logits_processor and stopping_criteria passed to bytes_decoder.generate() account for 63.5% of total generation time (warm cache). Optimizing these components could yield up to a 2.74x speedup.

Location

welt/model.py:503-509 in _generate_word_bytes():

return self.bytes_decoder.generate(
    inputs_embeds=inputs_embeds,
    generation_config=bytes_generation_config,
    tokenizer=tokenizer,
    logits_processor=[self.logits_processor],  # ← Costs 31.6% of runtime
    stopping_criteria=stopping_criteria,        # ← Costs 27.9% of runtime
)

Benchmark Results

Warm Cache (After torch.compile):

Configuration              | Time (s) | Speedup vs Baseline | % of Baseline
Baseline (both parameters) | 3.1276   | -                   | 100%
Without logits_processor   | 2.1392   | 31.6% faster        | 68.4%
Without stopping_criteria  | 2.2548   | 27.9% faster        | 72.1%
Without both               | 1.1408   | 63.5% faster        | 36.5%

Cold Cache (First Run with torch.compile):

Configuration              | Time (s) | Speedup vs Baseline
Baseline (both parameters) | 20.1456  | -
Without logits_processor   | 17.2984  | 14.1% faster
Without stopping_criteria  | 18.4905  | 8.2% faster
Without both               | 16.9498  | 15.9% faster

Note: Cold cache results show lower relative impact because compilation overhead dominates (6.4x slower than warm cache).

Analysis

Components:

  1. logits_processor: UTF8ValidationLogitsProcessor (compiled at line 419)

    • Ensures valid UTF-8 byte sequences during generation
    • Costs ~0.99s per benchmark (31.6% of runtime)
  2. stopping_criteria: WordStoppingCriteria

    • Stops generation at word boundaries
    • Costs ~0.87s per benchmark (27.9% of runtime)

Why This Matters:

These two components together take about 1.7x as long as the actual model forward passes, word encoding, and tokenization combined (1.99 s vs 1.14 s).

Reproduction

# Run benchmark with warmup
python -m welt_training.sample

# The script now includes:
# 1. Warmup run to compile everything
# 2. Timed benchmark run with warm cache
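
For finer-grained ablation timing than the full sample script, the generate call can be timed directly. A minimal sketch, assuming inputs_embeds, bytes_generation_config, tokenizer, and stopping_criteria are prepared exactly as in _generate_word_bytes(); the harness itself is hypothetical and not part of the repo:

import time
import torch

def timed_generate(model, inputs_embeds, bytes_generation_config, tokenizer,
                   logits_processor=None, stopping_criteria=None):
    # Pass None for either argument to reproduce the ablation rows above.
    torch.cuda.synchronize()                # flush pending GPU work
    start = time.perf_counter()
    out = model.bytes_decoder.generate(
        inputs_embeds=inputs_embeds,
        generation_config=bytes_generation_config,
        tokenizer=tokenizer,
        logits_processor=logits_processor,
        stopping_criteria=stopping_criteria,
    )
    torch.cuda.synchronize()                # wait for generation to finish
    return out, time.perf_counter() - start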

Proposed Solutions

Option 1: Optimize Existing Implementations

  • Profile UTF8ValidationLogitsProcessor to identify bottlenecks (see the profiling sketch after this list)
  • Profile WordStoppingCriteria for optimization opportunities
  • Consider vectorization or JIT compilation improvements
  • Investigate if torch.compile is effectively optimizing these components
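
A minimal profiling sketch for the first two bullets, using torch.profiler; run_generation() is a hypothetical wrapper around the bytes_decoder.generate() call shown above:

import torch
from torch.profiler import profile, ProfilerActivity

def run_generation():
    # Hypothetical wrapper: call model.bytes_decoder.generate() here with the
    # same arguments as in welt/model.py:503-509.
    ...

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             record_shapes=True) as prof:
    run_generation()

# If self CPU time dominates, the per-token Python work in the processor and
# stopping criteria is the bottleneck rather than the GPU kernels themselves.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=20))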

Option 2: Alternative Implementations

  • Implement validation logic directly in CUDA/Triton for GPU acceleration
  • Move stopping criteria checks to a more efficient location in the generation loop
  • Consider caching or batching validation checks (a vectorized example is sketched below)
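
As an illustration of the batching idea, a stopping criterion can stay entirely on-device so the generation loop never pays a GPU-to-CPU sync per token. This is only a sketch: it assumes word boundaries are detected from a fixed set of delimiter byte IDs (hypothetical) and is not the actual WordStoppingCriteria:

import torch
from transformers import StoppingCriteria

class VectorizedDelimiterStopping(StoppingCriteria):
    # Illustrative only: assumes a word ends when the last generated byte
    # token is one of a fixed set of delimiter ids (hypothetical).
    def __init__(self, delimiter_token_ids, device):
        self.delimiters = torch.tensor(delimiter_token_ids, device=device)

    def __call__(self, input_ids, scores, **kwargs):
        # One batched tensor op per step, returning a per-sequence bool mask
        # without forcing a GPU->CPU synchronization.
        return torch.isin(input_ids[:, -1], self.delimiters)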

Option 3: Make Optional

  • Add flags to disable these checks for inference when validation isn't critical (see the sketch after this list)
  • Document the trade-offs (performance vs correctness guarantees)
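
A minimal sketch of what such flags could look like; the parameter names are hypothetical and the real _generate_word_bytes() signature may differ:

def _generate_word_bytes(self, inputs_embeds, bytes_generation_config,
                         tokenizer, stopping_criteria,
                         validate_utf8=True, enforce_word_stop=True):
    # Hypothetical opt-out flags; not part of the current API.
    return self.bytes_decoder.generate(
        inputs_embeds=inputs_embeds,
        generation_config=bytes_generation_config,
        tokenizer=tokenizer,
        logits_processor=[self.logits_processor] if validate_utf8 else None,
        stopping_criteria=stopping_criteria if enforce_word_stop else None,
    )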

Questions

  1. Is torch.compile actually optimizing these components effectively? (They are compiled at lines 419 and 576.)
  2. Could validation be moved to post-processing to avoid per-token overhead? (See the sketch after this list.)
  3. Is there redundancy in the checks that could be eliminated?
  4. What's the actual implementation complexity of these components?
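
On question 2, a minimal example of handling UTF-8 after generation instead of masking logits per token. Invalid sequences get replaced rather than prevented, which is a weaker guarantee than what UTF8ValidationLogitsProcessor enforces during sampling; whether that is acceptable depends on how downstream code consumes the bytes:

def decode_generated_bytes(byte_values):
    # byte_values: list[int] of generated byte token values in 0-255.
    # Invalid UTF-8 is replaced with U+FFFD after the fact, at zero
    # per-token cost during generation.
    return bytes(byte_values).decode("utf-8", errors="replace")

# A truncated multi-byte sequence decodes with a replacement character:
print(decode_generated_bytes([72, 105, 0xE2, 0x82]))  # "Hi" plus U+FFFD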

Additional Context

  • Model: sign/WeLT-string-repetition
  • Hardware: NVIDIA GB10 (CUDA capability 12.1)
  • PyTorch optimizations enabled: cudnn benchmark, TF32, Flash Attention
  • Generation config: max_generated_words=32
  • Batch size: 3 samples

Expected Outcome

Ideally, we should be able to:

  1. Keep the correctness guarantees of UTF-8 validation and word stopping
  2. Reduce their combined overhead from ~2s to <0.5s (75% reduction)
  3. Achieve close to the 1.14s generation time while maintaining safety

This would bring the overall speedup close to the 2.74x upper bound without compromising functionality.
