diff --git a/.gitignore b/.gitignore index 9661d04..11e7637 100644 --- a/.gitignore +++ b/.gitignore @@ -60,4 +60,7 @@ coverage.xml *.cover /research -.python-version \ No newline at end of file +.python-version + +# Local cargo config (dev overrides) +.cargo/ \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 2ebde6d..adf14f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,9 +14,13 @@ categories = ["text-processing", "encoding"] name = "splintr" crate-type = ["cdylib", "rlib"] +[features] +default = [] +pcre2 = ["dep:pcre2"] + [dependencies] -# PCRE2 regex with JIT support (2-4x faster than fancy-regex) -pcre2 = "0.2" +# PCRE2 regex with JIT support (optional, for benchmarking) +pcre2 = { version = "0.2", optional = true } # Rayon for internal parallelism rayon = "1.10" # Fast hashing (FxHashMap) @@ -31,6 +35,12 @@ base64 = "0.22" aho-corasick = "1.1" # LRU cache for frequent token sequences lru = "0.12" +# regexr regex engine (default backend) +regexr = { version = "0.1.0-beta.4", features = ["jit", "simd"] } + +[dev-dependencies] +# PCRE2 for benchmarking comparisons +pcre2 = "0.2" [profile.release] opt-level = 3 diff --git a/README.md b/README.md index 6ce31ce..6d3d69f 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ See the [API Guide](docs/api_guide.md) and [docs.rs](https://docs.rs/splintr) fo - **Compatible vocabularies** - Supports cl100k_base, o200k_base (OpenAI), Llama 3 family (Meta), and DeepSeek V3 (DeepSeek) - **Streaming decoders** - Real-time LLM output display with proper UTF-8 handling ([guide](docs/api_guide.md#streaming-decoder)) - **54 agent tokens** - Built-in support for chat, CoT reasoning, ReAct agents, tool calling, RAG citations ([docs](docs/special_tokens.md)) -- **Battle-tested algorithms** - PCRE2 with JIT, Aho-Corasick for special tokens, linked-list BPE +- **Battle-tested algorithms** - Regexr with JIT (pure Rust), Aho-Corasick for special tokens, linked-list BPE **Cross-platform:** @@ -154,6 +154,43 @@ cat results/my_benchmark.md The benchmark suite tests single text encoding, batch encoding, streaming decoder performance, and special token handling across various content types. +### Regex Backends + +Splintr uses a pure-Rust regex engine ([`regexr`](https://crates.io/crates/regexr)) by default, with optional PCRE2 support for compatibility. 
+ +**Default Backend (regexr):** +- Pure Rust implementation (no C dependencies) +- JIT compilation and SIMD acceleration +- Native UTF-8 and Unicode property support + +**Optional PCRE2 Backend:** + +```python +from splintr import Tokenizer + +# Default: regexr backend (pure Rust) +tokenizer = Tokenizer.from_pretrained("cl100k_base") + +# Optional: switch to PCRE2 (requires --features pcre2) +tokenizer = Tokenizer.from_pretrained("cl100k_base").pcre2(True) +``` + +To enable PCRE2, build with the feature flag: + +```bash +maturin develop --release --features pcre2 +``` + +**Benchmarking:** + +```bash +# Compare backends (requires PCRE2 feature) +python benchmarks/benchmark_regexr_comparison.py --model cl100k_base + +# Visual comparison with charts +python benchmarks/benchmark_regexr_viz.py --model cl100k_base +``` + ## Streaming Decoders For real-time LLM applications where tokens arrive one at a time, Splintr provides streaming decoders that handle UTF-8 boundary alignment: @@ -226,7 +263,7 @@ See [docs/special_tokens.md](docs/special_tokens.md) for the complete list and [ Splintr implements several optimizations that make tokenization faster: -- **PCRE2 with JIT compilation**: 2-4x speedup on regex pattern matching +- **Regexr with JIT compilation**: Pure Rust regex engine with SIMD acceleration - **Rayon parallelism**: Leverages multiple CPU cores for batch encoding - **Linked-list BPE algorithm**: Avoids O(N²) complexity on pathological inputs - **FxHashMap**: Faster lookups than default SipHash for non-adversarial contexts diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py index e5409d0..126dec6 100755 --- a/benchmarks/benchmark.py +++ b/benchmarks/benchmark.py @@ -32,8 +32,16 @@ try: from splintr import Tokenizer as SplintrTokenizer HAS_SPLINTR = True + # Test if PCRE2 is available + try: + test_tok = SplintrTokenizer.from_pretrained("cl100k_base").pcre2(True) + HAS_PCRE2 = True + del test_tok + except ValueError: + HAS_PCRE2 = False except ImportError: HAS_SPLINTR = False + HAS_PCRE2 = False print("Warning: splintr not installed. Run: pip install -e . 
or maturin develop") try: @@ -671,6 +679,13 @@ def main(): action="store_true", help="Skip cache benchmarks" ) + parser.add_argument( + "--backend", + type=str, + default="regexr", + choices=["regexr", "pcre2"], + help="Regex backend to use: regexr (default, pure Rust) or pcre2 (requires feature flag)" + ) args = parser.parse_args() if not HAS_SPLINTR: @@ -689,9 +704,19 @@ def main(): print("=" * 70) # Load tokenizers - print(f"\nLoading tokenizers (model: {args.model})...") - splintr_enc = SplintrTokenizer.from_pretrained(args.model) - print(f" Splintr: {splintr_enc}") + backend_str = "PCRE2" if args.backend == "pcre2" else "Regexr" + print(f"\nLoading tokenizers (model: {args.model}, backend: {backend_str})...") + + if args.backend == "pcre2": + if not HAS_PCRE2: + print("Error: PCRE2 backend requested but not available.") + print(" Build with: maturin develop --release --features pcre2") + return 1 + splintr_enc = SplintrTokenizer.from_pretrained(args.model).pcre2(True) + else: # regexr (default) + splintr_enc = SplintrTokenizer.from_pretrained(args.model) + + print(f" Splintr ({backend_str}): {splintr_enc}") tiktoken_enc = None if args.compare or args.correctness_only: diff --git a/benchmarks/benchmark_batch.py b/benchmarks/benchmark_batch.py index 3d7ebaa..e738acb 100644 --- a/benchmarks/benchmark_batch.py +++ b/benchmarks/benchmark_batch.py @@ -23,10 +23,11 @@ Tokenizers convert text into numerical representations that models can understand.""" TOKENIZER_COLORS = { - "splintr": "#2ecc71", # Green - "tiktoken": "#3498db", # Blue - "huggingface": "#e74c3c", # Red - "tokendagger": "#9b59b6", # Purple + "splintr": "#2ecc71", # Green (default, pure Rust) + "splintr-pcre2": "#27ae60", # Dark Green (optional) + "tiktoken": "#3498db", # Blue + "huggingface": "#e74c3c", # Red + "tokendagger": "#9b59b6", # Purple } @@ -86,14 +87,15 @@ def load_tokenizers(): """Load all available tokenizers with batch functions. All tokenizers use their native batch encoding methods: - - splintr: encode_batch (Rayon parallel) + - splintr: encode_batch (Rayon parallel, pure Rust regex with JIT) + - splintr-pcre2: encode_batch (Rayon parallel, PCRE2 with JIT) - tiktoken: encode_ordinary_batch (native batch) - huggingface: encode_batch (native batch) - tokendagger: encode_batch (native batch) """ tokenizers = {} - # splintr - native batch via Rayon + # splintr - default backend (pure Rust with JIT) try: import splintr enc = splintr.Tokenizer.from_pretrained("cl100k_base") @@ -102,6 +104,15 @@ def load_tokenizers(): except ImportError: print("splintr not available") + # splintr-pcre2 - optional backend (requires --features pcre2) + try: + import splintr + enc_pcre2 = splintr.Tokenizer.from_pretrained("cl100k_base").pcre2(True) + tokenizers["splintr-pcre2"] = enc_pcre2.encode_batch + print("Loaded: splintr-pcre2 (native encode_batch)") + except (ImportError, ValueError) as e: + print(f"splintr-pcre2 not available: {e}") + # tiktoken - native batch try: import tiktoken diff --git a/benchmarks/benchmark_regexr_comparison.py b/benchmarks/benchmark_regexr_comparison.py new file mode 100644 index 0000000..cf569fa --- /dev/null +++ b/benchmarks/benchmark_regexr_comparison.py @@ -0,0 +1,431 @@ +#!/usr/bin/env python3 +""" +Benchmark comparing PCRE2 vs Regexr regex backends in splintr. 
+ +Tests both implementations with identical workloads to measure: +- Throughput (MB/s) +- Latency (mean, min, max, std) +- Correctness (ensure identical token outputs) +- Performance ratio (regexr vs pcre2) + +Usage: + python benchmarks/benchmark_regexr_comparison.py + python benchmarks/benchmark_regexr_comparison.py --iterations 20 + python benchmarks/benchmark_regexr_comparison.py --model o200k_base + python benchmarks/benchmark_regexr_comparison.py --workload long # Only long texts +""" + +import argparse +import json +import os +import platform +import statistics +import time +from dataclasses import dataclass, asdict +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Tuple + +try: + from splintr import Tokenizer + HAS_SPLINTR = True + # Test if PCRE2 is available + try: + test_tok = Tokenizer.from_pretrained("cl100k_base").pcre2(True) + HAS_PCRE2 = True + del test_tok + except ValueError: + HAS_PCRE2 = False + print("Note: PCRE2 not available. Build with: maturin develop --release --features pcre2") + print(" Benchmark will compare regexr only.\n") +except ImportError: + HAS_SPLINTR = False + print("Error: splintr not installed. Run: pip install -e . or maturin develop") + exit(1) + + +@dataclass +class ComparisonResult: + """Results from comparing PCRE2 vs Regexr on a single workload.""" + workload_name: str + data_size_bytes: int + data_size_chars: int + + # PCRE2 results + pcre2_mean_ms: float + pcre2_std_ms: float + pcre2_min_ms: float + pcre2_max_ms: float + pcre2_throughput_mb_s: float + + # Regexr results + regexr_mean_ms: float + regexr_std_ms: float + regexr_min_ms: float + regexr_max_ms: float + regexr_throughput_mb_s: float + + # Comparison metrics + speedup_ratio: float # pcre2_time / regexr_time (>1 means regexr is faster) + tokens_match: bool # Do both produce identical tokens? + + iterations: int + + +@dataclass +class SystemInfo: + platform: str + python_version: str + cpu_count: int + timestamp: str + + +def get_system_info() -> SystemInfo: + """Collect system information.""" + return SystemInfo( + platform=platform.platform(), + python_version=platform.python_version(), + cpu_count=os.cpu_count() or 1, + timestamp=datetime.now().isoformat(), + ) + + +def benchmark_single(func, iterations: int = 10, warmup: int = 2) -> Tuple[float, float, float, float]: + """ + Benchmark a function and return (mean_ms, std_ms, min_ms, max_ms). + """ + # Warmup + for _ in range(warmup): + func() + + # Timed runs + times = [] + for _ in range(iterations): + start = time.perf_counter() + func() + elapsed = time.perf_counter() - start + times.append(elapsed * 1000) # Convert to ms + + mean_ms = statistics.mean(times) + std_ms = statistics.stdev(times) if len(times) > 1 else 0 + min_ms = min(times) + max_ms = max(times) + + return mean_ms, std_ms, min_ms, max_ms + + +def generate_test_workloads() -> Dict[str, str]: + """Generate test workloads of various sizes and content types.""" + return { + # Size-based tests + "tiny": "Hello, world!", + "short": "The quick brown fox jumps over the lazy dog. " * 10, + "medium": "The quick brown fox jumps over the lazy dog. " * 1000, + "long": "The quick brown fox jumps over the lazy dog. " * 10000, + "very_long": "The quick brown fox jumps over the lazy dog. " * 50000, + + # Content-type tests + "multilingual": "Hello! 你好!مرحبا!Bonjour! Hola! Привет! 
" * 1000, + "chinese": "你好世界!这是一个测试。人工智能正在改变世界。机器学习是人工智能的一个分支。" * 1000, + "code_python": ''' +def fibonacci(n): + """Calculate the nth Fibonacci number.""" + if n <= 1: + return n + return fibonacci(n - 1) + fibonacci(n - 2) + +class DataProcessor: + def __init__(self, data): + self.data = data + + def process(self): + return [x * 2 for x in self.data if x > 0] +''' * 500, + "code_json": '{"name": "test", "value": 123, "nested": {"key": "value", "array": [1, 2, 3]}}' * 1000, + "numbers": "1234567890 9876543210 " * 5000, + "special_chars": "!@#$%^&*()_+-=[]{}|;':\",./<>? " * 2000, + "emojis": "🎉🎊🎈🎁🎀🎄🎃🎇🎆✨🌟💫⭐️🌈🦄🐉🔥💧🌊 " * 500, + "whitespace_heavy": " word another more " * 3000, + + # Pattern-specific tests + "contractions": "I'm, you're, he's, she's, it's, we're, they're, I'll, you'll, we'll, " * 1000, + "punctuation_heavy": "Hello... World!!! What? Really??? Yes!!! No??? Maybe... Perhaps!!! " * 1000, + } + + +def compare_backends( + pcre2_tokenizer, # Optional[Tokenizer] + regexr_tokenizer: Tokenizer, + workload_name: str, + text: str, + iterations: int = 10, +) -> ComparisonResult: + """ + Compare PCRE2 and Regexr backends on a single workload. + """ + data_size_bytes = len(text.encode('utf-8')) + data_size_chars = len(text) + + # Regexr tokens (always available) + regexr_tokens = regexr_tokenizer.encode(text) + + # PCRE2 comparison (if available) + if pcre2_tokenizer is not None: + pcre2_tokens = pcre2_tokenizer.encode(text) + tokens_match = pcre2_tokens == regexr_tokens + + if not tokens_match: + print(f"WARNING: Token mismatch for '{workload_name}'!") + print(f" PCRE2 tokens: {len(pcre2_tokens)} tokens") + print(f" Regexr tokens: {len(regexr_tokens)} tokens") + # Show first few tokens for debugging + print(f" First 10 PCRE2: {pcre2_tokens[:10]}") + print(f" First 10 Regexr: {regexr_tokens[:10]}") + + # Benchmark PCRE2 + pcre2_mean, pcre2_std, pcre2_min, pcre2_max = benchmark_single( + lambda: pcre2_tokenizer.encode(text), + iterations=iterations, + ) + else: + # PCRE2 not available - set placeholder values + tokens_match = True # Can't compare, assume correct + pcre2_mean = 0.0 + pcre2_std = 0.0 + pcre2_min = 0.0 + pcre2_max = 0.0 + + # Benchmark Regexr (always available) + regexr_mean, regexr_std, regexr_min, regexr_max = benchmark_single( + lambda: regexr_tokenizer.encode(text), + iterations=iterations, + ) + + # Calculate throughput + if pcre2_tokenizer is not None and pcre2_mean > 0: + pcre2_throughput = (data_size_bytes / 1024 / 1024) / (pcre2_mean / 1000) + else: + pcre2_throughput = 0.0 + + if regexr_mean > 0: + regexr_throughput = (data_size_bytes / 1024 / 1024) / (regexr_mean / 1000) + else: + regexr_throughput = 0.0 + + # Calculate speedup ratio (pcre2 / regexr) + # > 1 means regexr is faster + # < 1 means pcre2 is faster + if pcre2_tokenizer is not None and regexr_mean > 0 and pcre2_mean > 0: + speedup_ratio = pcre2_mean / regexr_mean + else: + speedup_ratio = 1.0 # No comparison possible + + return ComparisonResult( + workload_name=workload_name, + data_size_bytes=data_size_bytes, + data_size_chars=data_size_chars, + pcre2_mean_ms=pcre2_mean, + pcre2_std_ms=pcre2_std, + pcre2_min_ms=pcre2_min, + pcre2_max_ms=pcre2_max, + pcre2_throughput_mb_s=pcre2_throughput, + regexr_mean_ms=regexr_mean, + regexr_std_ms=regexr_std, + regexr_min_ms=regexr_min, + regexr_max_ms=regexr_max, + regexr_throughput_mb_s=regexr_throughput, + speedup_ratio=speedup_ratio, + tokens_match=tokens_match, + iterations=iterations, + ) + + +def print_comparison_table(results: List[ComparisonResult], has_pcre2: 
bool = True): + """Print a formatted comparison table.""" + print("\n" + "="*120) + if has_pcre2: + print("PCRE2 vs Regexr Performance Comparison") + else: + print("Regexr Performance Benchmark (PCRE2 not available)") + print("="*120) + + if has_pcre2: + # Header with PCRE2 comparison + print(f"{'Workload':<20} {'Size':>10} {'PCRE2 (ms)':>15} {'Regexr (ms)':>15} {'PCRE2 MB/s':>12} {'Regexr MB/s':>12} {'Speedup':>10} {'Match':>8}") + print("-"*120) + + # Data rows + for r in results: + size_kb = r.data_size_bytes / 1024 + size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb/1024:.1f} MB" + + # Color code speedup ratio + if r.speedup_ratio > 1.1: + speedup_str = f"{r.speedup_ratio:.2f}x ✓" # Regexr faster + elif r.speedup_ratio < 0.9: + speedup_str = f"{r.speedup_ratio:.2f}x ✗" # PCRE2 faster + else: + speedup_str = f"{r.speedup_ratio:.2f}x ~" # Similar + + match_str = "✓" if r.tokens_match else "✗ FAIL" + + print(f"{r.workload_name:<20} {size_str:>10} " + f"{r.pcre2_mean_ms:>13.2f} ± {r.pcre2_std_ms:.2f} " + f"{r.regexr_mean_ms:>11.2f} ± {r.regexr_std_ms:.2f} " + f"{r.pcre2_throughput_mb_s:>12.1f} " + f"{r.regexr_throughput_mb_s:>12.1f} " + f"{speedup_str:>10} " + f"{match_str:>8}") + + print("="*120) + + # Summary statistics + print("\nSummary:") + avg_speedup = statistics.mean([r.speedup_ratio for r in results]) + all_match = all(r.tokens_match for r in results) + regexr_faster_count = sum(1 for r in results if r.speedup_ratio > 1.0) + pcre2_faster_count = sum(1 for r in results if r.speedup_ratio < 1.0) + + print(f" Average speedup ratio: {avg_speedup:.2f}x") + print(f" Regexr faster: {regexr_faster_count}/{len(results)} workloads") + print(f" PCRE2 faster: {pcre2_faster_count}/{len(results)} workloads") + print(f" Correctness: {'✓ All outputs match' if all_match else '✗ Some outputs differ'}") + + if avg_speedup > 1.0: + print(f"\n → Regexr is {avg_speedup:.1f}x faster on average") + else: + print(f"\n → PCRE2 is {1/avg_speedup:.1f}x faster on average") + else: + # Regexr-only output + print(f"{'Workload':<20} {'Size':>10} {'Regexr (ms)':>18} {'Regexr MB/s':>12}") + print("-"*70) + + for r in results: + size_kb = r.data_size_bytes / 1024 + size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb/1024:.1f} MB" + + print(f"{r.workload_name:<20} {size_str:>10} " + f"{r.regexr_mean_ms:>14.2f} ± {r.regexr_std_ms:.2f} " + f"{r.regexr_throughput_mb_s:>12.1f}") + + print("="*70) + + # Summary for regexr-only + print("\nSummary:") + avg_throughput = statistics.mean([r.regexr_throughput_mb_s for r in results if r.regexr_throughput_mb_s > 0]) + print(f" Average throughput: {avg_throughput:.1f} MB/s") + print(f" Note: Build with --features pcre2 to enable comparison") + + +def save_results(results: List[ComparisonResult], system_info: SystemInfo, output_file: str): + """Save results to JSON file.""" + data = { + "system_info": asdict(system_info), + "results": [asdict(r) for r in results], + } + + output_path = Path(output_file) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w') as f: + json.dump(data, f, indent=2) + + print(f"\n✓ Results saved to: {output_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="Compare PCRE2 vs Regexr performance in splintr tokenizer" + ) + parser.add_argument( + "--model", + default="cl100k_base", + choices=["cl100k_base", "o200k_base", "llama3", "deepseek_v3"], + help="Model to use for benchmarking (default: cl100k_base)", + ) + parser.add_argument( + "--iterations", + type=int, + 
default=10, + help="Number of iterations per benchmark (default: 10)", + ) + parser.add_argument( + "--workload", + type=str, + default=None, + help="Run only a specific workload (e.g., 'long', 'multilingual')", + ) + parser.add_argument( + "--output", + type=str, + default="benchmark_results/regexr_comparison.json", + help="Output file for results (default: benchmark_results/regexr_comparison.json)", + ) + + args = parser.parse_args() + + if not HAS_SPLINTR: + print("Error: splintr not installed") + return 1 + + print(f"Loading {args.model} tokenizers...") + regexr_tokenizer = Tokenizer.from_pretrained(args.model) # Default is regexr with JIT + print(" → Regexr using JIT engine (default)") + + if HAS_PCRE2: + pcre2_tokenizer = Tokenizer.from_pretrained(args.model).pcre2(True) + print(" → PCRE2 backend enabled") + else: + pcre2_tokenizer = None + print(" → PCRE2 not available (benchmark regexr only)") + + print(f"Generating test workloads...") + workloads = generate_test_workloads() + + # Filter to specific workload if requested + if args.workload: + if args.workload not in workloads: + print(f"Error: Unknown workload '{args.workload}'") + print(f"Available: {', '.join(workloads.keys())}") + return 1 + workloads = {args.workload: workloads[args.workload]} + + print(f"Running {len(workloads)} workloads with {args.iterations} iterations each...") + print() + + results = [] + system_info = get_system_info() + + for i, (name, text) in enumerate(workloads.items(), 1): + print(f"[{i}/{len(workloads)}] Benchmarking: {name:<20} ({len(text.encode('utf-8'))/1024:.1f} KB)...", end=" ") + + result = compare_backends( + pcre2_tokenizer, + regexr_tokenizer, + name, + text, + iterations=args.iterations, + ) + + results.append(result) + + # Print quick summary + if HAS_PCRE2: + if result.speedup_ratio > 1.0: + print(f"Regexr {result.speedup_ratio:.2f}x faster") + else: + print(f"PCRE2 {1/result.speedup_ratio:.2f}x faster") + else: + print(f"{result.regexr_throughput_mb_s:.1f} MB/s") + + # Print detailed comparison table + print_comparison_table(results, has_pcre2=HAS_PCRE2) + + # Save results + save_results(results, system_info, args.output) + + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/benchmarks/benchmark_regexr_viz.py b/benchmarks/benchmark_regexr_viz.py new file mode 100755 index 0000000..f6d1ca5 --- /dev/null +++ b/benchmarks/benchmark_regexr_viz.py @@ -0,0 +1,558 @@ +#!/usr/bin/env python3 +""" +Benchmark: PCRE2 vs Regexr Backend Comparison with Visualization +Compares the performance of PCRE2 and Regexr regex backends in splintr. + +Usage: + python benchmarks/benchmark_regexr_viz.py + python benchmarks/benchmark_regexr_viz.py --model o200k_base + python benchmarks/benchmark_regexr_viz.py --iterations 20 +""" + +import argparse +import gc +import statistics +import time +from dataclasses import dataclass +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + +try: + from splintr import Tokenizer + HAS_SPLINTR = True + # Test if PCRE2 is available + try: + test_tok = Tokenizer.from_pretrained("cl100k_base").pcre2(True) + HAS_PCRE2 = True + del test_tok + except ValueError: + HAS_PCRE2 = False + print("Note: PCRE2 not available. Build with: maturin develop --release --features pcre2") + print(" Exiting - this benchmark requires PCRE2 for comparison.\n") + exit(1) +except ImportError: + HAS_SPLINTR = False + print("Error: splintr not installed. Run: pip install -e . 
or maturin develop") + exit(1) + + +# Sample texts for benchmarking +SAMPLE_TEXTS = { + "tiny": "Hello!", + "short": "Hello, world! This is a test.", + "medium": """The quick brown fox jumps over the lazy dog. + Machine learning models require tokenization to process text efficiently. + Tokenizers convert text into numerical representations that models can understand.""" * 10, + "long": """Artificial intelligence and machine learning have revolutionized + the way we process and understand natural language. Large language models (LLMs) + like GPT-4, Claude, and others rely heavily on efficient tokenization to handle + vast amounts of text data. The performance of tokenizers directly impacts the + overall throughput of these systems, making optimization crucial for production + deployments. BPE (Byte Pair Encoding) has become the de facto standard for + modern tokenizers due to its balance of vocabulary efficiency and handling of + out-of-vocabulary words.""" * 100, + "code": ''' +def fibonacci(n: int) -> int: + """Calculate the nth Fibonacci number.""" + if n <= 1: + return n + return fibonacci(n - 1) + fibonacci(n - 2) + +class TokenizerBenchmark: + def __init__(self, name: str): + self.name = name + self.results = [] + + def run(self, text: str, iterations: int = 100): + for _ in range(iterations): + tokens = self.encode(text) + self.results.append(len(tokens)) +''' * 50, + "multilingual": """ + English: The quick brown fox jumps over the lazy dog. + 中文: 快速的棕色狐狸跳过懒狗。 + 日本語: 素早い茶色の狐が怠惰な犬を飛び越える。 + 한국어: 빠른 갈색 여우가 게으른 개를 뛰어넘습니다. + العربية: الثعلب البني السريع يقفز فوق الكلب الكسول. + Русский: Быстрая коричневая лиса прыгает через ленивую собаку. + """ * 30, + "chinese": "你好世界!这是一个测试。人工智能正在改变世界。机器学习是人工智能的一个分支。" * 50, + "contractions": "I'm, you're, he's, she's, it's, we're, they're, I'll, you'll, we'll, " * 50, +} + +BACKEND_COLORS = { + "pcre2": "#2ecc71", # Green (current default) + "regexr": "#e67e22", # Orange (experimental) +} + + +@dataclass +class BenchmarkResult: + backend: str + text_type: str + bytes_per_second: float + tokens_per_second: float + num_tokens: int + num_bytes: int + latency_ms: float + latency_std_ms: float + tokens_match: bool # Does output match the other backend? 
+ + +def benchmark_encode( + backend: str, + encode_fn, + text: str, + text_type: str, + reference_tokens=None, + warmup: int = 50, + iterations: int = 100, +) -> BenchmarkResult: + """Benchmark a single encode function.""" + num_bytes = len(text.encode("utf-8")) + + # Warmup + for _ in range(warmup): + encode_fn(text) + + # Force garbage collection before timing + gc.collect() + + # Benchmark + times = [] + num_tokens = 0 + tokens = None + for _ in range(iterations): + start = time.perf_counter_ns() + tokens = encode_fn(text) + end = time.perf_counter_ns() + times.append((end - start) / 1e9) # Convert to seconds + num_tokens = len(tokens) + + avg_time = statistics.mean(times) + std_time = statistics.stdev(times) if len(times) > 1 else 0 + bytes_per_second = num_bytes / avg_time + tokens_per_second = num_tokens / avg_time + + # Check if tokens match reference + tokens_match = True + if reference_tokens is not None: + tokens_match = tokens == reference_tokens + + return BenchmarkResult( + backend=backend, + text_type=text_type, + bytes_per_second=bytes_per_second, + tokens_per_second=tokens_per_second, + num_tokens=num_tokens, + num_bytes=num_bytes, + latency_ms=avg_time * 1000, + latency_std_ms=std_time * 1000, + tokens_match=tokens_match, + ) + + +def run_benchmarks( + pcre2_tokenizer, + regexr_tokenizer, + iterations: int = 100, +) -> list[BenchmarkResult]: + """Run benchmarks for all text types.""" + results = [] + + # Global warmup to initialize thread pools + print("\nWarming up tokenizers...") + warmup_text = "This is a warmup text to initialize thread pools and caches." * 10 + for _ in range(100): + pcre2_tokenizer.encode(warmup_text) + regexr_tokenizer.encode(warmup_text) + print("Warmup complete.") + + print("\n" + "=" * 90) + print("PCRE2 vs REGEXR BACKEND COMPARISON") + print("=" * 90) + + for text_type, text in SAMPLE_TEXTS.items(): + num_bytes = len(text.encode("utf-8")) + print(f"\n--- {text_type.upper()} ({num_bytes:,} bytes) ---") + print(f"{'Backend':<10} {'MB/s':>10} {'Ktok/s':>10} {'Latency':>12} {'Std':>10} {'Match':>8}") + print("-" * 70) + + # Benchmark PCRE2 first (reference) + pcre2_result = benchmark_encode( + "pcre2", + pcre2_tokenizer.encode, + text, + text_type, + iterations=iterations, + ) + results.append(pcre2_result) + + # Get reference tokens for comparison + reference_tokens = pcre2_tokenizer.encode(text) + + # Benchmark Regexr + regexr_result = benchmark_encode( + "regexr", + regexr_tokenizer.encode, + text, + text_type, + reference_tokens=reference_tokens, + iterations=iterations, + ) + results.append(regexr_result) + + # Print results + for result in [pcre2_result, regexr_result]: + match_str = "✓" if result.tokens_match else "✗" + print( + f"{result.backend:<10} {result.bytes_per_second / 1e6:>10.2f} " + f"{result.tokens_per_second / 1e3:>10.2f} " + f"{result.latency_ms:>10.3f} ms " + f"{result.latency_std_ms:>8.3f} ms " + f"{match_str:>8}" + ) + + # Calculate and print speedup + speedup = pcre2_result.latency_ms / regexr_result.latency_ms + if speedup > 1.0: + print(f" → Regexr is {speedup:.2f}x FASTER") + else: + print(f" → PCRE2 is {1/speedup:.2f}x FASTER") + + return results + + +def generate_throughput_chart(results: list[BenchmarkResult], output_path: str): + """Generate throughput comparison chart.""" + + # Get unique backends and text types + backends = list(dict.fromkeys(r.backend for r in results)) + text_types = list(dict.fromkeys(r.text_type for r in results)) + + # Create figure + fig, ax = plt.subplots(figsize=(14, 7)) + + x = 
np.arange(len(text_types)) + width = 0.35 + + # Create bars for each backend + for i, backend in enumerate(backends): + throughputs = [] + for text_type in text_types: + for r in results: + if r.backend == backend and r.text_type == text_type: + throughputs.append(r.bytes_per_second / 1e6) + break + + offset = (i - 0.5) * width + bars = ax.bar( + x + offset, + throughputs, + width, + label=backend.upper(), + color=BACKEND_COLORS.get(backend, "#95a5a6"), + ) + + # Add value labels on bars + for bar, val in zip(bars, throughputs): + height = bar.get_height() + ax.annotate( + f'{val:.1f}', + xy=(bar.get_x() + bar.get_width() / 2, height), + xytext=(0, 3), + textcoords="offset points", + ha='center', + va='bottom', + fontsize=9, + fontweight='bold', + ) + + # Add text size annotations + text_sizes = [] + for text_type in text_types: + for r in results: + if r.text_type == text_type: + text_sizes.append(r.num_bytes) + break + + ax.set_xlabel("Text Type", fontsize=12, fontweight='bold') + ax.set_ylabel("Throughput (MB/s)", fontsize=12, fontweight='bold') + ax.set_title("PCRE2 vs Regexr: Throughput Comparison", fontsize=14, fontweight="bold") + ax.set_xticks(x) + + # Create x-tick labels with size info + xlabels = [f"{t.capitalize()}\n({text_sizes[i]:,} bytes)" for i, t in enumerate(text_types)] + ax.set_xticklabels(xlabels, fontsize=10) + + ax.legend(loc="upper left", fontsize=11) + ax.grid(axis="y", alpha=0.3) + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches="tight") + print(f"\n✓ Throughput chart saved to: {output_path}") + plt.close() + + +def generate_speedup_chart(results: list[BenchmarkResult], output_path: str): + """Generate speedup ratio chart (Regexr vs PCRE2).""" + + text_types = list(dict.fromkeys(r.text_type for r in results)) + + # Calculate speedup ratios + speedups = [] + for text_type in text_types: + pcre2_time = None + regexr_time = None + for r in results: + if r.text_type == text_type: + if r.backend == "pcre2": + pcre2_time = r.latency_ms + elif r.backend == "regexr": + regexr_time = r.latency_ms + + if pcre2_time and regexr_time: + # Speedup > 1 means regexr is faster + speedup = pcre2_time / regexr_time + speedups.append(speedup) + + # Create figure + fig, ax = plt.subplots(figsize=(14, 7)) + + x = np.arange(len(text_types)) + colors = ['#2ecc71' if s > 1.0 else '#e74c3c' for s in speedups] + + bars = ax.bar(x, speedups, color=colors, edgecolor='black', linewidth=1.5) + + # Add horizontal line at 1.0 (parity) + ax.axhline(y=1.0, color='black', linestyle='--', linewidth=2, label='Parity (1.0x)') + + # Add value labels on bars + for bar, val in zip(bars, speedups): + height = bar.get_height() + label = f'{val:.2f}x' + if val > 1.0: + label += '\nRegexr\nFaster' + else: + label += '\nPCRE2\nFaster' + + y_pos = height + 0.05 if height > 1.0 else height - 0.05 + va_pos = 'bottom' if height > 1.0 else 'top' + + ax.annotate( + label, + xy=(bar.get_x() + bar.get_width() / 2, y_pos), + ha='center', + va=va_pos, + fontsize=9, + fontweight='bold', + ) + + # Add text size annotations + text_sizes = [] + for text_type in text_types: + for r in results: + if r.text_type == text_type: + text_sizes.append(r.num_bytes) + break + + ax.set_xlabel("Text Type", fontsize=12, fontweight='bold') + ax.set_ylabel("Speedup Ratio (PCRE2 time / Regexr time)", fontsize=12, fontweight='bold') + ax.set_title("Regexr Performance Relative to PCRE2\n(>1.0 = Regexr faster, <1.0 = PCRE2 faster)", + fontsize=14, fontweight="bold") + ax.set_xticks(x) + + xlabels = 
[f"{t.capitalize()}\n({text_sizes[i]:,} bytes)" for i, t in enumerate(text_types)] + ax.set_xticklabels(xlabels, fontsize=10) + + ax.legend(loc="upper right", fontsize=11) + ax.grid(axis="y", alpha=0.3) + + # Set y-axis to show both faster and slower clearly + y_max = max(speedups) * 1.15 + y_min = min(speedups) * 0.85 + ax.set_ylim(min(y_min, 0.8), max(y_max, 1.2)) + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches="tight") + print(f"✓ Speedup chart saved to: {output_path}") + plt.close() + + +def generate_latency_chart(results: list[BenchmarkResult], output_path: str): + """Generate latency comparison chart.""" + + backends = list(dict.fromkeys(r.backend for r in results)) + text_types = list(dict.fromkeys(r.text_type for r in results)) + + fig, ax = plt.subplots(figsize=(14, 7)) + + x = np.arange(len(text_types)) + width = 0.35 + + for i, backend in enumerate(backends): + latencies = [] + errors = [] + for text_type in text_types: + for r in results: + if r.backend == backend and r.text_type == text_type: + latencies.append(r.latency_ms) + errors.append(r.latency_std_ms) + break + + offset = (i - 0.5) * width + bars = ax.bar( + x + offset, + latencies, + width, + label=backend.upper(), + color=BACKEND_COLORS.get(backend, "#95a5a6"), + yerr=errors, + capsize=3, + ) + + # Add value labels + for bar, val in zip(bars, latencies): + height = bar.get_height() + ax.annotate( + f'{val:.2f}', + xy=(bar.get_x() + bar.get_width() / 2, height), + xytext=(0, 3), + textcoords="offset points", + ha='center', + va='bottom', + fontsize=9, + ) + + ax.set_xlabel("Text Type", fontsize=12, fontweight='bold') + ax.set_ylabel("Latency (ms) - Lower is Better", fontsize=12, fontweight='bold') + ax.set_title("PCRE2 vs Regexr: Latency Comparison", fontsize=14, fontweight="bold") + ax.set_xticks(x) + + text_sizes = [] + for text_type in text_types: + for r in results: + if r.text_type == text_type: + text_sizes.append(r.num_bytes) + break + + xlabels = [f"{t.capitalize()}\n({text_sizes[i]:,} bytes)" for i, t in enumerate(text_types)] + ax.set_xticklabels(xlabels, fontsize=10) + + ax.legend(loc="upper left", fontsize=11) + ax.grid(axis="y", alpha=0.3) + ax.set_yscale("log") # Log scale to see all values + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches="tight") + print(f"✓ Latency chart saved to: {output_path}") + plt.close() + + +def print_summary(results: list[BenchmarkResult]): + """Print summary statistics.""" + print("\n" + "=" * 90) + print("SUMMARY") + print("=" * 90) + + # Calculate average speedup + text_types = list(dict.fromkeys(r.text_type for r in results)) + speedups = [] + + for text_type in text_types: + pcre2_time = None + regexr_time = None + for r in results: + if r.text_type == text_type: + if r.backend == "pcre2": + pcre2_time = r.latency_ms + elif r.backend == "regexr": + regexr_time = r.latency_ms + + if pcre2_time and regexr_time: + speedup = pcre2_time / regexr_time + speedups.append(speedup) + + avg_speedup = statistics.mean(speedups) + regexr_faster_count = sum(1 for s in speedups if s > 1.0) + pcre2_faster_count = sum(1 for s in speedups if s < 1.0) + + # Check correctness + all_match = all(r.tokens_match for r in results) + + print(f"Average speedup ratio: {avg_speedup:.3f}x") + print(f"Regexr faster on: {regexr_faster_count}/{len(speedups)} workloads") + print(f"PCRE2 faster on: {pcre2_faster_count}/{len(speedups)} workloads") + print(f"Correctness: {'✓ All outputs match' if all_match else '✗ Some outputs differ'}") + + if avg_speedup > 1.0: + 
print(f"\n→ Overall: Regexr is {avg_speedup:.2f}x FASTER on average") + else: + print(f"\n→ Overall: PCRE2 is {1/avg_speedup:.2f}x FASTER on average") + + print("\nDetailed breakdown:") + for text_type, speedup in zip(text_types, speedups): + if speedup > 1.0: + print(f" {text_type:>15}: Regexr {speedup:.2f}x faster") + else: + print(f" {text_type:>15}: PCRE2 {1/speedup:.2f}x faster") + + +def main(): + parser = argparse.ArgumentParser( + description="Visualize PCRE2 vs Regexr performance comparison" + ) + parser.add_argument( + "--model", + default="cl100k_base", + choices=["cl100k_base", "o200k_base", "llama3", "deepseek_v3"], + help="Model to use for benchmarking (default: cl100k_base)", + ) + parser.add_argument( + "--iterations", + type=int, + default=100, + help="Number of iterations per benchmark (default: 100)", + ) + + args = parser.parse_args() + + if not HAS_SPLINTR: + print("Error: splintr not installed") + return 1 + + print("=" * 90) + print("PCRE2 vs REGEXR BACKEND COMPARISON") + print("=" * 90) + + # Create output directory + output_dir = Path(__file__).parent / "results" + output_dir.mkdir(exist_ok=True) + + # Load tokenizers + print(f"\nLoading {args.model} tokenizers...") + regexr_tokenizer = Tokenizer.from_pretrained(args.model) # Default is regexr with JIT + pcre2_tokenizer = Tokenizer.from_pretrained(args.model).pcre2(True) + print("✓ Tokenizers loaded") + print(" → Regexr: Default (JIT engine)") + print(" → PCRE2: Enabled via .pcre2(True)") + + # Run benchmarks + results = run_benchmarks(pcre2_tokenizer, regexr_tokenizer, iterations=args.iterations) + + # Generate charts + print("\nGenerating visualizations...") + generate_throughput_chart(results, str(output_dir / "regexr_throughput.png")) + generate_speedup_chart(results, str(output_dir / "regexr_speedup.png")) + generate_latency_chart(results, str(output_dir / "regexr_latency.png")) + + # Print summary + print_summary(results) + + print("\n✓ Done!") + + +if __name__ == "__main__": + main() diff --git a/python/splintr/__init__.py b/python/splintr/__init__.py index ab65880..5a9156b 100644 --- a/python/splintr/__init__.py +++ b/python/splintr/__init__.py @@ -2,7 +2,8 @@ Splintr - Fast Rust BPE tokenizer with Python bindings A high-performance tokenizer featuring: -- PCRE2 with JIT compilation (2-4x faster than fancy-regex) +- Regexr with JIT and SIMD (default, pure Rust) +- Optional PCRE2 with JIT (requires pcre2 feature) - Rayon parallelism for multi-core encoding - Linked-list BPE algorithm (avoids O(N^2) on pathological inputs) - FxHashMap for fast lookups @@ -20,11 +21,14 @@ Usage: from splintr import Tokenizer - # Load pretrained model + # Load pretrained model (uses regexr by default) tokenizer = Tokenizer.from_pretrained("cl100k_base") # GPT-4 tokenizer = Tokenizer.from_pretrained("llama3") # Llama 3 tokenizer = Tokenizer.from_pretrained("deepseek_v3") # DeepSeek V3 + # Use PCRE2 backend (requires pcre2 feature) + # tokenizer = Tokenizer.from_pretrained("cl100k_base").pcre2(True) + # Encode text tokens = tokenizer.encode("Hello, world!") print(tokens) diff --git a/src/core/mod.rs b/src/core/mod.rs index a981163..c7fa1f9 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -11,7 +11,8 @@ //! The core is organized into four main components: //! //! - [`Tokenizer`]: Main tokenizer struct with encoding/decoding API, LRU cache, -//! and Aho-Corasick special token matching +//! and Aho-Corasick special token matching. Uses regexr backend by default, +//! 
with optional PCRE2 backend via `.pcre2(true)` (requires `pcre2` feature). //! - [`bpe`]: Low-level byte-pair encoding algorithm using linked-list approach //! - [`vocab`]: Vocabulary loading utilities for tiktoken format //! - [`StreamingDecoder`]: UTF-8 safe streaming decoder for token-by-token LLM output @@ -19,7 +20,8 @@ //! //! # Performance Optimizations //! -//! - **PCRE2 with JIT**: 2-4x faster than fancy-regex for pattern matching +//! - **Regexr with JIT**: Pure Rust regex engine with SIMD acceleration (default) +//! - **Optional PCRE2 with JIT**: requires `pcre2` feature //! - **Rayon parallelism**: Multi-core encoding for batch operations //! - **FxHashMap**: Faster hashing than standard HashMap for string keys //! - **Aho-Corasick**: O(N) multi-pattern matching for special tokens diff --git a/src/core/tokenizer.rs b/src/core/tokenizer.rs index c9dda4a..a9ad40b 100644 --- a/src/core/tokenizer.rs +++ b/src/core/tokenizer.rs @@ -1,7 +1,7 @@ use aho_corasick::AhoCorasick; use lru::LruCache; -use pcre2::bytes::Regex; use rayon::prelude::*; +use regexr::{Regex as RegexrRegex, RegexBuilder}; use rustc_hash::FxHashMap; use rustc_hash::FxHasher; use std::hash::{Hash, Hasher}; @@ -9,20 +9,28 @@ use std::num::NonZeroUsize; use std::sync::Mutex; use thiserror::Error; +#[cfg(feature = "pcre2")] +use pcre2::bytes::Regex as Pcre2Regex; + use super::bpe::byte_pair_encode; use super::byte_level::{byte_level_decode_bytes, byte_level_encode}; use super::vocab::{build_decoder, load_tiktoken_bpe, load_tiktoken_bpe_file, VocabError}; #[derive(Error, Debug)] pub enum TokenizerError { - #[error("Regex compilation error: {0}")] - RegexError(#[from] pcre2::Error), + #[error("Regex compilation error (regexr): {0}")] + RegexrError(#[from] regexr::Error), + #[cfg(feature = "pcre2")] + #[error("Regex compilation error (PCRE2): {0}")] + Pcre2Error(#[from] pcre2::Error), #[error("Vocabulary error: {0}")] VocabError(#[from] VocabError), #[error("Decoding error: invalid UTF-8")] Utf8Error, #[error("Aho-Corasick build error: {0}")] AhoCorasickError(#[from] aho_corasick::BuildError), + #[error("PCRE2 feature not enabled. Compile with --features pcre2")] + Pcre2NotEnabled, } /// Default regex pattern for cl100k_base (GPT-4, GPT-3.5-turbo) @@ -56,470 +64,179 @@ pub const LLAMA3_PATTERN: &str = O200K_BASE_PATTERN; /// - `<|im_start|>`: Generic message start delimiter (ChatML format) /// - `<|im_end|>`: Generic message end delimiter (ChatML format) /// -/// Example: -/// ```text -/// <|im_start|>system -/// You are a helpful assistant.<|im_end|> -/// <|im_start|>user -/// Hello!<|im_end|> -/// <|im_start|>assistant -/// Hi there!<|im_end|> -/// ``` -/// /// ## Reasoning/Thinking (100282-100283) -/// Chain-of-Thought (CoT) tokens for System 2 reasoning, similar to DeepSeek-R1 -/// or OpenAI o1-style thinking: -/// - `<|think|>`: Start of internal reasoning (hidden from user in production) -/// - `<|/think|>`: End of internal reasoning -/// -/// Example: -/// ```text -/// <|think|> -/// Let me break this down step by step... -/// First, I need to consider X. -/// Then, Y follows from X. -/// <|/think|> -/// The answer is Y. 
-/// ``` +/// Chain-of-Thought (CoT) tokens for System 2 reasoning /// /// ## ReAct Agent Loop (100284-100291) -/// Tokens for ReAct (Reason + Act) agent architectures: -/// - `<|plan|>`: High-level planning phase where agent decides strategy -/// - `<|step|>`: Individual step within a plan -/// - `<|act|>`: Action intent declaration (what the agent wants to do) -/// - `<|observe|>`: Observation/feedback from environment after action -/// -/// Example: -/// ```text -/// <|plan|> -/// I need to: 1) Search for info, 2) Summarize findings -/// <|/plan|> -/// <|step|>Searching for relevant information<|/step|> -/// <|act|>search("climate change effects")<|/act|> -/// <|observe|>Found 3 relevant articles...<|/observe|> -/// ``` +/// Tokens for ReAct (Reason + Act) agent architectures /// /// ## Tool/Function Calling (100292-100297) -/// Structured tool use with explicit success/error handling: -/// - `<|function|>`: Function call specification (name + arguments) -/// - `<|result|>`: Successful function return value -/// - `<|error|>`: Function execution error (enables retry logic) -/// -/// Example: -/// ```text -/// <|function|>{"name": "get_weather", "args": {"city": "London"}}<|/function|> -/// <|result|>{"temp": 18, "condition": "cloudy"}<|/result|> -/// ``` +/// Structured tool use with explicit success/error handling /// /// ## Code Execution (100298-100303) -/// Jupyter notebook-style code interpreter flow: -/// - `<|code|>`: Code block to execute -/// - `<|output|>`: Execution output (stdout, return values) -/// - `<|lang|>`: Programming language identifier -/// -/// Example: -/// ```text -/// <|code|><|lang|>python<|/lang|> -/// import math -/// print(math.sqrt(16)) -/// <|/code|> -/// <|output|>4.0<|/output|> -/// ``` +/// Jupyter notebook-style code interpreter flow /// /// ## RAG/Citations (100304-100311) -/// Retrieval-Augmented Generation with source attribution: -/// - `<|context|>`: Injected context from retrieval system -/// - `<|quote|>`: Direct quotation from source material -/// - `<|cite|>`: Citation reference marker -/// - `<|source|>`: Source metadata (URL, document ID, etc.) -/// -/// Example: -/// ```text -/// <|context|> -/// <|source|>doc_123<|/source|> -/// The Earth orbits the Sun in 365.25 days. -/// <|/context|> -/// According to the source<|cite|>doc_123<|/cite|>, <|quote|>The Earth orbits -/// the Sun in 365.25 days.<|/quote|> -/// ``` +/// Retrieval-Augmented Generation with source attribution /// /// ## Memory/State (100312-100315) -/// Long-term memory and state persistence: -/// - `<|memory|>`: Store information for future reference -/// - `<|recall|>`: Retrieve previously stored information -/// -/// Example: -/// ```text -/// <|memory|>User prefers concise responses<|/memory|> -/// ...later... 
-/// <|recall|>User prefers concise responses<|/recall|> -/// ``` +/// Long-term memory and state persistence /// /// ## Control Tokens (100316-100318) -/// Sequence control and formatting: -/// - `<|pad|>`: Padding token for batch alignment -/// - `<|stop|>`: Generation stop signal -/// - `<|sep|>`: Separator between segments +/// Sequence control and formatting /// /// ## Multimodal (100319-100324) -/// Placeholders for non-text content: -/// - `<|image|>`: Image embedding or base64 data -/// - `<|audio|>`: Audio embedding or encoded data -/// - `<|video|>`: Video embedding or encoded data -/// -/// Example: -/// ```text -/// Describe this image: <|image|>base64_data_here<|/image|> -/// ``` +/// Placeholders for non-text content /// /// ## Document Structure (100325-100330) -/// Semantic layout tokens for parsing structured documents: -/// - `<|title|>`: Document or section title -/// - `<|section|>`: Semantic section boundary -/// - `<|summary|>`: Condensed content summary -/// -/// Example: -/// ```text -/// <|title|>Introduction<|/title|> -/// <|section|> -/// This section covers the basics... -/// <|/section|> -/// <|summary|>Key points: X, Y, Z<|/summary|> -/// ``` +/// Semantic layout tokens for parsing structured documents pub mod cl100k_agent_tokens { - // ========================================================================= - // Conversation Structure (100277-100281) - // ========================================================================= - - /// System message marker - defines assistant behavior and constraints. pub const SYSTEM: u32 = 100277; - /// User message marker - marks human input in conversation. pub const USER: u32 = 100278; - /// Assistant message marker - marks AI responses. pub const ASSISTANT: u32 = 100279; - /// ChatML message start - generic delimiter for any role. pub const IM_START: u32 = 100280; - /// ChatML message end - closes any message block. pub const IM_END: u32 = 100281; - - // ========================================================================= - // Reasoning/Thinking - Chain-of-Thought (100282-100283) - // ========================================================================= - - /// Start of thinking/reasoning block (System 2 cognition). - /// Content between THINK and THINK_END represents internal reasoning - /// that may be hidden from users in production. pub const THINK: u32 = 100282; - /// End of thinking/reasoning block. pub const THINK_END: u32 = 100283; - - // ========================================================================= - // ReAct Agent Loop (100284-100291) - // ========================================================================= - - /// Start of planning phase - high-level strategy formulation. pub const PLAN: u32 = 100284; - /// End of planning phase. pub const PLAN_END: u32 = 100285; - /// Start of individual step - discrete action within a plan. pub const STEP: u32 = 100286; - /// End of step. pub const STEP_END: u32 = 100287; - /// Start of action - the intent to perform an operation. pub const ACT: u32 = 100288; - /// End of action. pub const ACT_END: u32 = 100289; - /// Start of observation - environment feedback after action. pub const OBSERVE: u32 = 100290; - /// End of observation. 
pub const OBSERVE_END: u32 = 100291; - - // ========================================================================= - // Tool/Function Calling (100292-100297) - // ========================================================================= - - /// Start of function call - contains function name and arguments (usually JSON). pub const FUNCTION: u32 = 100292; - /// End of function call. pub const FUNCTION_END: u32 = 100293; - /// Start of function result - successful return value. pub const RESULT: u32 = 100294; - /// End of function result. pub const RESULT_END: u32 = 100295; - /// Start of error block - function execution failure, enables retry logic. pub const ERROR: u32 = 100296; - /// End of error block. pub const ERROR_END: u32 = 100297; - - // ========================================================================= - // Code Execution (100298-100303) - // ========================================================================= - - /// Start of code block - executable code content. pub const CODE: u32 = 100298; - /// End of code block. pub const CODE_END: u32 = 100299; - /// Start of execution output - stdout, return values, rendered output. pub const OUTPUT: u32 = 100300; - /// End of execution output. pub const OUTPUT_END: u32 = 100301; - /// Start of language identifier (e.g., "python", "javascript"). pub const LANG: u32 = 100302; - /// End of language identifier. pub const LANG_END: u32 = 100303; - - // ========================================================================= - // RAG/Citations (100304-100311) - // ========================================================================= - - /// Start of retrieved context block - injected by RAG pipeline. pub const CONTEXT: u32 = 100304; - /// End of context block. pub const CONTEXT_END: u32 = 100305; - /// Start of direct quotation from source material. pub const QUOTE: u32 = 100306; - /// End of quotation. pub const QUOTE_END: u32 = 100307; - /// Start of citation marker - references a source. pub const CITE: u32 = 100308; - /// End of citation marker. pub const CITE_END: u32 = 100309; - /// Start of source identifier - URL, document ID, or metadata. pub const SOURCE: u32 = 100310; - /// End of source identifier. pub const SOURCE_END: u32 = 100311; - - // ========================================================================= - // Memory/State Management (100312-100315) - // ========================================================================= - - /// Start of memory block - information to persist across sessions. pub const MEMORY: u32 = 100312; - /// End of memory block. pub const MEMORY_END: u32 = 100313; - /// Start of recall block - retrieved persistent memory. pub const RECALL: u32 = 100314; - /// End of recall block. pub const RECALL_END: u32 = 100315; - - // ========================================================================= - // Control Tokens (100316-100318) - // ========================================================================= - - /// Padding token - used for batch alignment, has no semantic meaning. pub const PAD: u32 = 100316; - /// Stop token - signals end of generation. pub const STOP: u32 = 100317; - /// Separator token - delimits segments within a sequence. pub const SEP: u32 = 100318; - - // ========================================================================= - // Multimodal Placeholders (100319-100324) - // ========================================================================= - - /// Start of image content - embedding vector or encoded image data. 
pub const IMAGE: u32 = 100319; - /// End of image content. pub const IMAGE_END: u32 = 100320; - /// Start of audio content - embedding vector or encoded audio data. pub const AUDIO: u32 = 100321; - /// End of audio content. pub const AUDIO_END: u32 = 100322; - /// Start of video content - embedding vector or encoded video data. pub const VIDEO: u32 = 100323; - /// End of video content. pub const VIDEO_END: u32 = 100324; - - // ========================================================================= - // Document Structure (100325-100330) - // ========================================================================= - - /// Start of title - document or section title for semantic parsing. pub const TITLE: u32 = 100325; - /// End of title. pub const TITLE_END: u32 = 100326; - /// Start of section - semantic document section boundary. pub const SECTION: u32 = 100327; - /// End of section. pub const SECTION_END: u32 = 100328; - /// Start of summary - condensed content summary. pub const SUMMARY: u32 = 100329; - /// End of summary. pub const SUMMARY_END: u32 = 100330; } /// Agent tokens for o200k_base (GPT-4o). /// -/// These special tokens extend the o200k_base vocabulary for building chat models, -/// reasoning systems, and autonomous agents. Token IDs start at 200019 to avoid -/// conflicts with OpenAI's reserved range (199999-200018). -/// /// See [`cl100k_agent_tokens`] for detailed documentation on each token category. /// The token semantics are identical; only the IDs differ. pub mod o200k_agent_tokens { - // ========================================================================= - // Conversation Structure (200019-200023) - // ========================================================================= - - /// System message marker - defines assistant behavior and constraints. pub const SYSTEM: u32 = 200019; - /// User message marker - marks human input in conversation. pub const USER: u32 = 200020; - /// Assistant message marker - marks AI responses. pub const ASSISTANT: u32 = 200021; - /// ChatML message start - generic delimiter for any role. pub const IM_START: u32 = 200022; - /// ChatML message end - closes any message block. pub const IM_END: u32 = 200023; - - // ========================================================================= - // Reasoning/Thinking - Chain-of-Thought (200024-200025) - // ========================================================================= - - /// Start of thinking/reasoning block (System 2 cognition). pub const THINK: u32 = 200024; - /// End of thinking/reasoning block. pub const THINK_END: u32 = 200025; - - // ========================================================================= - // ReAct Agent Loop (200026-200033) - // ========================================================================= - - /// Start of planning phase - high-level strategy formulation. pub const PLAN: u32 = 200026; - /// End of planning phase. pub const PLAN_END: u32 = 200027; - /// Start of individual step - discrete action within a plan. pub const STEP: u32 = 200028; - /// End of step. pub const STEP_END: u32 = 200029; - /// Start of action - the intent to perform an operation. pub const ACT: u32 = 200030; - /// End of action. pub const ACT_END: u32 = 200031; - /// Start of observation - environment feedback after action. pub const OBSERVE: u32 = 200032; - /// End of observation. 
pub const OBSERVE_END: u32 = 200033; - - // ========================================================================= - // Tool/Function Calling (200034-200039) - // ========================================================================= - - /// Start of function call - contains function name and arguments (usually JSON). pub const FUNCTION: u32 = 200034; - /// End of function call. pub const FUNCTION_END: u32 = 200035; - /// Start of function result - successful return value. pub const RESULT: u32 = 200036; - /// End of function result. pub const RESULT_END: u32 = 200037; - /// Start of error block - function execution failure, enables retry logic. pub const ERROR: u32 = 200038; - /// End of error block. pub const ERROR_END: u32 = 200039; - - // ========================================================================= - // Code Execution (200040-200045) - // ========================================================================= - - /// Start of code block - executable code content. pub const CODE: u32 = 200040; - /// End of code block. pub const CODE_END: u32 = 200041; - /// Start of execution output - stdout, return values, rendered output. pub const OUTPUT: u32 = 200042; - /// End of execution output. pub const OUTPUT_END: u32 = 200043; - /// Start of language identifier (e.g., "python", "javascript"). pub const LANG: u32 = 200044; - /// End of language identifier. pub const LANG_END: u32 = 200045; - - // ========================================================================= - // RAG/Citations (200046-200053) - // ========================================================================= - - /// Start of retrieved context block - injected by RAG pipeline. pub const CONTEXT: u32 = 200046; - /// End of context block. pub const CONTEXT_END: u32 = 200047; - /// Start of direct quotation from source material. pub const QUOTE: u32 = 200048; - /// End of quotation. pub const QUOTE_END: u32 = 200049; - /// Start of citation marker - references a source. pub const CITE: u32 = 200050; - /// End of citation marker. pub const CITE_END: u32 = 200051; - /// Start of source identifier - URL, document ID, or metadata. pub const SOURCE: u32 = 200052; - /// End of source identifier. pub const SOURCE_END: u32 = 200053; - - // ========================================================================= - // Memory/State Management (200054-200057) - // ========================================================================= - - /// Start of memory block - information to persist across sessions. pub const MEMORY: u32 = 200054; - /// End of memory block. pub const MEMORY_END: u32 = 200055; - /// Start of recall block - retrieved persistent memory. pub const RECALL: u32 = 200056; - /// End of recall block. pub const RECALL_END: u32 = 200057; - - // ========================================================================= - // Control Tokens (200058-200060) - // ========================================================================= - - /// Padding token - used for batch alignment, has no semantic meaning. pub const PAD: u32 = 200058; - /// Stop token - signals end of generation. pub const STOP: u32 = 200059; - /// Separator token - delimits segments within a sequence. pub const SEP: u32 = 200060; - - // ========================================================================= - // Multimodal Placeholders (200061-200066) - // ========================================================================= - - /// Start of image content - embedding vector or encoded image data. 
     pub const IMAGE: u32 = 200061;
-    /// End of image content.
     pub const IMAGE_END: u32 = 200062;
-    /// Start of audio content - embedding vector or encoded audio data.
     pub const AUDIO: u32 = 200063;
-    /// End of audio content.
     pub const AUDIO_END: u32 = 200064;
-    /// Start of video content - embedding vector or encoded video data.
     pub const VIDEO: u32 = 200065;
-    /// End of video content.
     pub const VIDEO_END: u32 = 200066;
-
-    // =========================================================================
-    // Document Structure (200067-200072)
-    // =========================================================================
-
-    /// Start of title - document or section title for semantic parsing.
     pub const TITLE: u32 = 200067;
-    /// End of title.
     pub const TITLE_END: u32 = 200068;
-    /// Start of section - semantic document section boundary.
     pub const SECTION: u32 = 200069;
-    /// End of section.
     pub const SECTION_END: u32 = 200070;
-    /// Start of summary - condensed content summary.
     pub const SUMMARY: u32 = 200071;
-    /// End of summary.
     pub const SUMMARY_END: u32 = 200072;
 }
 
 /// Default cache size for encoded chunks
 const DEFAULT_CACHE_SIZE: usize = 4096;
 
-/// High-performance tokenizer using PCRE2 with JIT and Rayon parallelism.
+/// Regex backend enum for switching between regexr (default) and PCRE2 (optional)
+enum RegexBackend {
+    Regexr(Box<Regex>),
+    #[cfg(feature = "pcre2")]
+    Pcre2(Pcre2Regex),
+}
+
+impl RegexBackend {
+    /// Find all matches in the given text, returning (start, end) byte offsets
+    fn find_iter<'a>(&'a self, text: &'a str) -> Vec<(usize, usize)> {
+        match self {
+            RegexBackend::Regexr(regex) => regex
+                .find_iter(text)
+                .map(|m| (m.start(), m.end()))
+                .collect(),
+            #[cfg(feature = "pcre2")]
+            RegexBackend::Pcre2(regex) => regex
+                .find_iter(text.as_bytes())
+                .filter_map(|m| m.ok())
+                .map(|m| (m.start(), m.end()))
+                .collect(),
+        }
+    }
+}
+
+/// High-performance BPE tokenizer with regexr backend (default) or PCRE2 (optional).
 ///
 /// # Performance Characteristics
 ///
@@ -536,24 +253,23 @@ const DEFAULT_CACHE_SIZE: usize = 4096;
 /// - **Very large single texts (>1MB)**: Use [`encode_rayon`] for texts larger
 ///   than ~1MB where Rayon parallelization within the text becomes beneficial.
 ///
-/// # Design Decision: Sequential by Default
+/// # Regex Backend
 ///
-/// The [`encode`] method uses sequential processing because Rayon parallel
-/// overhead is significant for typical text sizes:
+/// By default, uses the `regexr` backend (pure Rust with JIT and SIMD support).
+/// To use PCRE2 instead, enable the `pcre2` feature and call `.pcre2(true)`:
 ///
-/// | Text Size | Sequential | Rayon | Speedup |
-/// |-----------|------------|-------|---------|
-/// | 100 bytes | 42 MB/s | 3 MB/s | Sequential 12x faster |
-/// | 10 KB | 50 MB/s | 26 MB/s | Sequential 2x faster |
-/// | 100 KB | 54 MB/s | 41 MB/s | Sequential 1.3x faster |
-/// | 1 MB | 44 MB/s | 47 MB/s | Rayon 1.07x faster |
+/// ```ignore
+/// // Default (regexr)
+/// let tokenizer = Tokenizer::from_pretrained("cl100k_base")?;
 ///
-/// Rayon only becomes beneficial at ~1MB, which is rare in typical workloads.
-/// For batch processing, use [`encode_batch`] which parallelizes across texts.
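+/// // The encode/decode API is identical on either backend:
+/// let tokens = tokenizer.encode("Hello, world!");
+///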
+/// // With PCRE2 (requires --features pcre2)
+/// let tokenizer = Tokenizer::from_pretrained("cl100k_base")?.pcre2(true)?;
+/// ```
 ///
 /// # Key Optimizations
 ///
-/// - PCRE2 with JIT compilation (2-4x faster than fancy-regex)
+/// - Regexr with JIT compilation and SIMD acceleration (default)
+/// - Optional PCRE2 with JIT (2-4x faster than fancy-regex)
 /// - Rayon parallelism for batch encoding (across texts, not within)
 /// - Linked-list BPE algorithm (avoids O(N²) on pathological inputs)
 /// - FxHashMap for fast lookups
@@ -565,21 +281,24 @@ pub struct Tokenizer {
     decoder: FxHashMap<u32, Vec<u8>>,
     special_tokens: FxHashMap<String, u32>,
     special_tokens_decoder: FxHashMap<u32, String>,
-    special_token_strings: Vec<String>, // Ordered list for Aho-Corasick pattern indices
-    regex: Regex,
+    special_token_strings: Vec<String>,
+    regex: RegexBackend,
+    pattern: String,
     special_matcher: Option<AhoCorasick>,
     chunk_cache: Mutex<LruCache<u64, Vec<u32>>>,
-    /// Whether to use ByteLevel encoding (for GPT-2/Llama/DeepSeek style tokenizers)
     use_byte_level: bool,
+    cache_size: usize,
 }
 
 impl Tokenizer {
     /// Create a new tokenizer from encoder map, special tokens, and regex pattern.
     ///
+    /// Uses regexr as the default regex backend.
+    ///
     /// # Arguments
     /// * `encoder` - Map of byte sequences to token IDs
     /// * `special_tokens` - Map of special token strings to token IDs
-    /// * `pattern` - PCRE2 regex pattern for tokenization
+    /// * `pattern` - Regex pattern for tokenization
     pub fn new(
         encoder: FxHashMap<Vec<u8>, u32>,
         special_tokens: FxHashMap<String, u32>,
@@ -592,11 +311,6 @@ impl Tokenizer {
     ///
     /// ByteLevel encoding is required for GPT-2, Llama, DeepSeek, and similar tokenizers
     /// that use a byte-to-unicode mapping for handling arbitrary byte sequences.
-    ///
-    /// # Arguments
-    /// * `encoder` - Map of byte sequences to token IDs
-    /// * `special_tokens` - Map of special token strings to token IDs
-    /// * `pattern` - PCRE2 regex pattern for tokenization
     pub fn new_byte_level(
         encoder: FxHashMap<Vec<u8>, u32>,
         special_tokens: FxHashMap<String, u32>,
@@ -620,7 +334,7 @@ impl Tokenizer {
     /// # Arguments
     /// * `encoder` - Map of byte sequences to token IDs
     /// * `special_tokens` - Map of special token strings to token IDs
-    /// * `pattern` - PCRE2 regex pattern for tokenization
+    /// * `pattern` - Regex pattern for tokenization
     /// * `cache_size` - Size of the LRU cache for encoded chunks
     /// * `use_byte_level` - Enable ByteLevel encoding for GPT-2/Llama/DeepSeek style tokenizers
     pub fn with_options(
@@ -637,14 +351,10 @@ impl Tokenizer {
             .map(|(k, v)| (*v, k.clone()))
             .collect();
 
-        // Compile main regex with JIT
-        let mut regex_builder = pcre2::bytes::RegexBuilder::new();
-        regex_builder.jit_if_available(true);
-        regex_builder.utf(true);
-        regex_builder.ucp(true); // Unicode property support
-        let regex = regex_builder.build(pattern)?;
+        // Compile regex with regexr (default backend)
+        let regex = RegexBuilder::new(pattern).jit(true).build()?;
 
-        // Build Aho-Corasick automaton for special tokens (much faster than regex alternation)
+        // Build Aho-Corasick automaton for special tokens
         let special_token_strings: Vec<String> = special_tokens.keys().cloned().collect();
         let special_matcher = if special_token_strings.is_empty() {
             None
@@ -653,8 +363,8 @@ impl Tokenizer {
         };
 
         // Initialize LRU cache
-        let cache_size = NonZeroUsize::new(cache_size.max(1)).unwrap();
-        let chunk_cache = Mutex::new(LruCache::new(cache_size));
+        let cache_size_nz = NonZeroUsize::new(cache_size.max(1)).unwrap();
+        let chunk_cache = Mutex::new(LruCache::new(cache_size_nz));
 
         Ok(Self {
             encoder,
@@ -662,13 +372,50 @@ impl Tokenizer {
             special_tokens,
             special_tokens_decoder,
             special_token_strings,
-            regex,
+            regex: RegexBackend::Regexr(Box::new(regex)),
+            pattern: pattern.to_string(),
             special_matcher,
             chunk_cache,
             use_byte_level,
+            cache_size,
         })
     }
 
+    /// Switch to PCRE2 regex backend.
+    ///
+    /// PCRE2 is an alternative regex backend. Requires the `pcre2` feature
+    /// to be enabled at compile time.
+    ///
+    /// # Example
+    /// ```ignore
+    /// let tokenizer = Tokenizer::from_pretrained("cl100k_base")?.pcre2(true)?;
+    /// ```
+    ///
+    /// # Errors
+    /// Returns an error if PCRE2 regex compilation fails.
+    #[cfg(feature = "pcre2")]
+    pub fn pcre2(mut self, use_pcre2: bool) -> Result<Self, TokenizerError> {
+        if use_pcre2 {
+            let mut regex_builder = pcre2::bytes::RegexBuilder::new();
+            regex_builder.jit_if_available(true);
+            regex_builder.utf(true);
+            regex_builder.ucp(true);
+            let regex = regex_builder.build(&self.pattern)?;
+            self.regex = RegexBackend::Pcre2(regex);
+        }
+        Ok(self)
+    }
+
+    /// Switch to PCRE2 regex backend (stub when the `pcre2` feature is not enabled).
+    #[cfg(not(feature = "pcre2"))]
+    pub fn pcre2(self, use_pcre2: bool) -> Result<Self, TokenizerError> {
+        if use_pcre2 {
+            Err(TokenizerError::Pcre2NotEnabled)
+        } else {
+            Ok(self)
+        }
+    }
+
     /// Create a tokenizer from a tiktoken vocabulary file.
     pub fn from_file(
         vocab_path: &str,
@@ -690,9 +437,6 @@ impl Tokenizer {
     }
 
     /// Create a tokenizer from raw vocabulary bytes with ByteLevel encoding.
-    ///
-    /// Use this for GPT-2, Llama, DeepSeek, and similar tokenizers that use
-    /// ByteLevel preprocessing.
     pub fn from_bytes_byte_level(
         vocab_data: &[u8],
         pattern: &str,
@@ -703,10 +447,6 @@ impl Tokenizer {
     }
 
     /// Compute a fast hash for a byte slice to use as an LRU cache key.
-    ///
-    /// Uses FxHasher which is significantly faster than the default SipHash
-    /// for small keys like text chunks, with acceptable collision rates for
-    /// caching purposes.
    #[inline]
    fn hash_slice(slice: &[u8]) -> u64 {
        let mut hasher = FxHasher::default();
@@ -715,21 +455,9 @@ impl Tokenizer {
     }
 
     /// Encode a single text chunk with LRU caching.
-    ///
-    /// This method implements a multi-tier encoding strategy:
-    /// 1. **ByteLevel preprocessing** (if enabled): Convert bytes to ByteLevel representation
-    /// 2. **Direct lookup**: Check if the entire chunk is a known token (O(1))
-    /// 3. **Cache hit**: Return cached BPE result if available (O(1))
-    /// 4. **BPE encode**: Perform full BPE encoding and cache the result
-    ///
-    /// The cache dramatically improves performance for:
-    /// - Repeated encoding of the same text
-    /// - Common substrings across different inputs
-    /// - Text with repetitive patterns (e.g., log files, structured data)
     fn encode_chunk(&self, slice: &[u8]) -> Vec<u32> {
         // Apply ByteLevel preprocessing if enabled
         let bytes_to_encode: std::borrow::Cow<[u8]> = if self.use_byte_level {
-            // Convert raw bytes to ByteLevel representation
             let byte_level_str = byte_level_encode(slice);
             std::borrow::Cow::Owned(byte_level_str.into_bytes())
         } else {
@@ -741,7 +469,7 @@ impl Tokenizer {
             return vec![rank];
         }
 
-        // Check cache (using hash of the ByteLevel-encoded bytes)
+        // Check cache
         let hash = Self::hash_slice(bytes_to_encode.as_ref());
         if let Ok(mut cache) = self.chunk_cache.lock() {
             if let Some(cached) = cache.get(&hash) {
@@ -749,7 +477,7 @@
             }
         }
 
-        // Perform BPE encoding on (possibly ByteLevel-encoded) bytes
+        // Perform BPE encoding
         let result = byte_pair_encode(bytes_to_encode.as_ref(), &self.encoder);
 
         // Store in cache
@@ -763,39 +491,14 @@
     /// Encode text to token IDs (ignores special tokens in input).
     ///
     /// Uses sequential processing, which is faster than parallel for texts up to ~1MB.
-    /// Achieves ~50 MB/s throughput, approximately 3x faster than tiktoken.
-    ///
-    /// # Why Sequential?
-    ///
-    /// Rayon parallel processing has significant thread pool overhead that only
-    /// pays off for very large texts (~1MB+). Benchmarks show:
-    /// - 100 bytes: Sequential is 12x faster than Rayon
-    /// - 10 KB: Sequential is 2x faster
-    /// - 100 KB: Sequential is 1.3x faster
-    /// - 1 MB: Rayon becomes ~7% faster
-    ///
-    /// # When to Use Other Methods
-    ///
-    /// - **Multiple texts**: Use [`encode_batch`] for parallel encoding across texts
-    /// - **Very large texts (>1MB)**: Use [`encode_rayon`] for parallel within-text encoding
-    /// - **Special tokens**: Use [`encode_with_special`] to recognize special tokens
     pub fn encode(&self, text: &str) -> Vec<u32> {
         let text_bytes = text.as_bytes();
-
-        // Collect regex matches (chunks to encode)
-        let chunks: Vec<(usize, usize)> = self
-            .regex
-            .find_iter(text_bytes)
-            .filter_map(|m| m.ok())
-            .map(|m| (m.start(), m.end()))
-            .collect();
+        let chunks = self.regex.find_iter(text);
 
         if chunks.is_empty() {
             return vec![];
         }
 
-        // Sequential encoding - Rayon overhead not worth it for texts < 1MB
-        // See struct-level docs for benchmark data
         let results: Vec<Vec<u32>> = chunks
             .iter()
             .map(|&(start, end)| {
@@ -804,46 +507,20 @@
             })
             .collect();
 
-        // Flatten results
         results.into_iter().flatten().collect()
     }
 
     /// Encode text to token IDs using Rayon parallel processing.
     ///
-    /// Parallelizes BPE encoding of individual regex-matched chunks using Rayon.
-    /// Only beneficial for very large texts (>1MB) where parallelization overhead
-    /// is amortized across many chunks.
-    ///
-    /// # Performance
-    ///
-    /// | Text Size | Sequential | Rayon | Winner |
-    /// |-----------|------------|-------|--------|
-    /// | < 500 KB | ~50 MB/s | ~40 MB/s | Sequential |
-    /// | ~1 MB | ~44 MB/s | ~47 MB/s | Rayon (1.07x) |
-    ///
-    /// # When to Use
-    ///
-    /// - Single texts larger than ~1MB (e.g., entire books, large documents)
-    /// - When processing time is more critical than thread pool overhead
-    ///
-    /// For most use cases, prefer [`encode`] (sequential) or [`encode_batch`]
-    /// (parallel across multiple texts).
+    /// Only beneficial for very large texts (>1MB).
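+    ///
+    /// # Example
+    ///
+    /// A minimal sketch (hypothetical file name):
+    ///
+    /// ```ignore
+    /// let book = std::fs::read_to_string("large_document.txt")?;
+    /// let tokens = tokenizer.encode_rayon(&book); // same tokens as encode(), computed in parallel
+    /// ```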
     pub fn encode_rayon(&self, text: &str) -> Vec<u32> {
         let text_bytes = text.as_bytes();
-
-        // Collect regex matches (chunks to encode)
-        let chunks: Vec<(usize, usize)> = self
-            .regex
-            .find_iter(text_bytes)
-            .filter_map(|m| m.ok())
-            .map(|m| (m.start(), m.end()))
-            .collect();
+        let chunks = self.regex.find_iter(text);
 
         if chunks.is_empty() {
             return vec![];
         }
 
-        // Parallel encoding using Rayon - each chunk encoded in parallel
         let results: Vec<Vec<u32>> = chunks
             .par_iter()
             .map(|&(start, end)| {
@@ -852,14 +529,12 @@
             })
             .collect();
 
-        // Flatten results
         results.into_iter().flatten().collect()
     }
 
     /// Encode text with special token handling.
     ///
     /// Special tokens in the input are encoded directly without BPE.
-    /// Uses Aho-Corasick for fast multi-pattern matching.
     pub fn encode_with_special(&self, text: &str) -> Vec<u32> {
         let Some(ref special_matcher) = self.special_matcher else {
             return self.encode(text);
@@ -869,18 +544,15 @@
         let mut result = Vec::new();
         let mut last_end = 0;
 
-        // Find all special tokens using Aho-Corasick (much faster than regex alternation)
         for m in special_matcher.find_iter(text_bytes) {
             let start = m.start();
             let end = m.end();
 
-            // Encode text before the special token
             if start > last_end {
                 let slice = &text[last_end..start];
                 result.extend(self.encode(slice));
             }
 
-            // Add the special token directly using the pattern index
             let pattern_idx = m.pattern().as_usize();
             let token_str = &self.special_token_strings[pattern_idx];
             if let Some(&rank) = self.special_tokens.get(token_str) {
@@ -890,7 +562,6 @@
             last_end = end;
         }
 
-        // Encode remaining text after last special token
         if last_end < text.len() {
             result.extend(self.encode(&text[last_end..]));
         }
@@ -899,27 +570,21 @@
     }
 
     /// Decode token IDs back to bytes.
-    ///
-    /// If ByteLevel encoding was used, this returns the raw bytes after reversing
-    /// the ByteLevel transformation.
     pub fn decode_bytes(&self, tokens: &[u32]) -> Vec<u8> {
         let mut result = Vec::with_capacity(tokens.len() * 4);
         for &token in tokens {
             if let Some(bytes) = self.decoder.get(&token) {
                 if self.use_byte_level {
-                    // Regular tokens are ByteLevel-encoded, decode them
                     if let Some(decoded) = byte_level_decode_bytes(bytes) {
                         result.extend_from_slice(&decoded);
                     } else {
-                        // Fallback if ByteLevel decode fails
                        result.extend_from_slice(bytes);
                    }
                } else {
                    result.extend_from_slice(bytes);
                }
            } else if let Some(special) = self.special_tokens_decoder.get(&token) {
-                // Special tokens are NOT ByteLevel-encoded, emit them directly
                result.extend_from_slice(special.as_bytes());
            }
        }
@@ -928,8 +593,6 @@
     }
 
     /// Decode token IDs to a string.
-    ///
-    /// Returns an error if the decoded bytes are not valid UTF-8.
     pub fn decode(&self, tokens: &[u32]) -> Result<String, TokenizerError> {
         let bytes = self.decode_bytes(tokens);
         String::from_utf8(bytes).map_err(|_| TokenizerError::Utf8Error)
@@ -942,32 +605,11 @@
     }
 
     /// Batch encode multiple texts in parallel.
-    ///
-    /// Uses Rayon to parallelize **across texts** (not within each text).
-    /// This is the most efficient approach for batch workloads because:
-    ///
-    /// 1. Each text is encoded sequentially (optimal for texts < 1MB)
-    /// 2. Multiple texts are processed in parallel across CPU cores
-    /// 3. No thread coordination overhead within individual texts
-    ///
-    /// # Performance
-    ///
-    /// Achieves ~110 MB/s throughput on batch workloads, approximately
-    /// 10-12x faster than tiktoken's `encode_ordinary_batch`.
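+    ///
+    /// # Example
+    ///
+    /// A minimal usage sketch:
+    ///
+    /// ```ignore
+    /// let texts = vec!["Hello".to_string(), "World".to_string()];
+    /// let token_ids = tokenizer.encode_batch(&texts);
+    /// assert_eq!(token_ids.len(), texts.len());
+    /// ```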
-    ///
-    /// # Example
-    ///
-    /// ```ignore
-    /// let texts = vec!["Hello".to_string(), "World".to_string()];
-    /// let token_ids = tokenizer.encode_batch(&texts);
-    /// ```
     pub fn encode_batch(&self, texts: &[String]) -> Vec<Vec<u32>> {
         texts.par_iter().map(|text| self.encode(text)).collect()
     }
 
     /// Batch encode multiple texts with special token handling.
-    ///
-    /// Like [`encode_batch`], but recognizes special tokens in the input.
     pub fn encode_batch_with_special(&self, texts: &[String]) -> Vec<Vec<u32>> {
         texts
             .par_iter()
@@ -976,15 +618,6 @@
     }
 
     /// Batch decode multiple token lists in parallel.
-    ///
-    /// Uses Rayon to parallelize decoding across token lists.
-    ///
-    /// # Example
-    ///
-    /// ```ignore
-    /// let token_lists = vec![vec![1, 2, 3], vec![4, 5, 6]];
-    /// let texts = tokenizer.decode_batch(&token_lists)?;
-    /// ```
     pub fn decode_batch(&self, token_lists: &[Vec<u32>]) -> Result<Vec<String>, TokenizerError> {
         token_lists
             .par_iter()
@@ -993,8 +626,6 @@
     }
 
     /// Batch decode multiple token lists in parallel, replacing invalid UTF-8.
-    ///
-    /// Like [`decode_batch`], but uses lossy UTF-8 conversion.
     pub fn decode_batch_lossy(&self, token_lists: &[Vec<u32>]) -> Vec<String> {
         token_lists
             .par_iter()
@@ -1034,23 +665,67 @@
         }
     }
 
-    /// Get cache statistics (hits would require additional tracking).
+    /// Get the number of entries currently in the chunk cache.
     pub fn cache_len(&self) -> usize {
         self.chunk_cache.lock().map(|c| c.len()).unwrap_or(0)
     }
 }
 
+impl Clone for Tokenizer {
+    fn clone(&self) -> Self {
+        // Rebuild the regex backend from the stored pattern
+        let regex = match &self.regex {
+            RegexBackend::Regexr(_) => {
+                let regex = RegexBuilder::new(&self.pattern).jit(true).build().unwrap();
+                RegexBackend::Regexr(Box::new(regex))
+            }
+            #[cfg(feature = "pcre2")]
+            RegexBackend::Pcre2(_) => {
+                let mut regex_builder = pcre2::bytes::RegexBuilder::new();
+                regex_builder.jit_if_available(true);
+                regex_builder.utf(true);
+                regex_builder.ucp(true);
+                let regex = regex_builder.build(&self.pattern).unwrap();
+                RegexBackend::Pcre2(regex)
+            }
+        };
+
+        // Create a new empty cache (caches are not shared between clones)
+        let cache_size_nz = NonZeroUsize::new(self.cache_size.max(1)).unwrap();
+        let chunk_cache = Mutex::new(LruCache::new(cache_size_nz));
+
+        // Rebuild the special-token matcher
+        let special_matcher = if self.special_token_strings.is_empty() {
+            None
+        } else {
+            Some(AhoCorasick::new(&self.special_token_strings).unwrap())
+        };
+
+        Self {
+            encoder: self.encoder.clone(),
+            decoder: self.decoder.clone(),
+            special_tokens: self.special_tokens.clone(),
+            special_tokens_decoder: self.special_tokens_decoder.clone(),
+            special_token_strings: self.special_token_strings.clone(),
+            regex,
+            pattern: self.pattern.clone(),
+            special_matcher,
+            chunk_cache,
+            use_byte_level: self.use_byte_level,
+            cache_size: self.cache_size,
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
 
     fn make_test_tokenizer() -> Tokenizer {
         let mut encoder = FxHashMap::default();
-        // Single bytes (ASCII printable range)
         for b in 32u8..=126 {
             encoder.insert(vec![b], b as u32);
         }
-        // Some merged tokens
         encoder.insert(b"Hello".to_vec(), 200);
         encoder.insert(b"World".to_vec(), 201);
         encoder.insert(b" World".to_vec(), 202);
@@ -1058,41 +733,32 @@ mod tests {
 
         let mut special_tokens = FxHashMap::default();
         special_tokens.insert("<|endoftext|>".to_string(), 50256);
 
-        // Simple pattern that matches words and spaces
         let pattern = r"\S+|\s+";
-
         Tokenizer::new(encoder, special_tokens, pattern).unwrap()
     }
 
     #[test]
     fn test_encode_decode() {
         let tokenizer = make_test_tokenizer();
-
         let text = "Hello World";
         let tokens = tokenizer.encode(text);
         let decoded = tokenizer.decode(&tokens).unwrap();
-
         assert_eq!(decoded, text);
     }
 
     #[test]
     fn test_encode_with_special() {
         let tokenizer = make_test_tokenizer();
-
         let text = "Hello<|endoftext|>World";
         let tokens = tokenizer.encode_with_special(text);
-
-        // Should contain the special token
         assert!(tokens.contains(&50256));
     }
 
     #[test]
     fn test_batch_encode() {
         let tokenizer = make_test_tokenizer();
-
         let texts = vec!["Hello".to_string(), "World".to_string()];
         let batch_tokens = tokenizer.encode_batch(&texts);
-
         assert_eq!(batch_tokens.len(), 2);
     }
 
@@ -1105,78 +771,46 @@
     #[test]
     fn test_cache_works() {
         let tokenizer = make_test_tokenizer();
-
-        // Use text that isn't a direct token match to trigger BPE and caching
-        // "HelloWorld" isn't in the encoder, so it will go through BPE
         let text = "HelloWorld";
         let tokens1 = tokenizer.encode(text);
         let tokens2 = tokenizer.encode(text);
-
-        // Results should be identical
         assert_eq!(tokens1, tokens2);
-
-        // Cache should have entries (BPE result was cached)
         assert!(tokenizer.cache_len() > 0);
     }
 
     #[test]
     fn test_clear_cache() {
         let tokenizer = make_test_tokenizer();
-
-        // Use text that triggers BPE encoding
         tokenizer.encode("HelloWorld");
         assert!(tokenizer.cache_len() > 0);
-
         tokenizer.clear_cache();
         assert_eq!(tokenizer.cache_len(), 0);
     }
 
-    // Compile-time verification that agent tokens don't conflict with OpenAI's reserved range
-    const _: () = {
-        assert!(super::cl100k_agent_tokens::SYSTEM > 100276); // After endofprompt
-        assert!(super::cl100k_agent_tokens::SUMMARY_END == 100330); // Last token
-        assert!(super::o200k_agent_tokens::SYSTEM > 200018); // After endofprompt
-        assert!(super::o200k_agent_tokens::SUMMARY_END == 200072); // Last token
-        // Verify token ordering is correct (no gaps or overlaps)
-        assert!(super::cl100k_agent_tokens::USER == super::cl100k_agent_tokens::SYSTEM + 1);
-        assert!(super::o200k_agent_tokens::USER == super::o200k_agent_tokens::SYSTEM + 1);
-    };
-
+    #[cfg(feature = "pcre2")]
     #[test]
-    fn test_agent_tokens_encode_decode() {
-        // Create a tokenizer with agent tokens for testing
-        let mut encoder: FxHashMap<Vec<u8>, u32> = FxHashMap::default();
-        encoder.insert(b"Hello".to_vec(), 0);
-        encoder.insert(b" ".to_vec(), 1);
-        encoder.insert(b"World".to_vec(), 2);
-
-        let mut special: FxHashMap<String, u32> = FxHashMap::default();
-        // Add some agent tokens
-        special.insert("<|system|>".to_string(), 100277);
-        special.insert("<|user|>".to_string(), 100278);
-        special.insert("<|assistant|>".to_string(), 100279);
-        special.insert("<|think|>".to_string(), 100282);
-        special.insert("<|/think|>".to_string(), 100283);
-
-        let pattern = r"\S+|\s+";
-        let tokenizer = Tokenizer::new(encoder, special, pattern).unwrap();
-
-        // Test encoding with agent tokens
-        let text = "<|system|>Hello<|user|>World";
-        let tokens = tokenizer.encode_with_special(text);
-
-        // Should contain the special tokens
-        assert!(tokens.contains(&100277)); // <|system|>
-        assert!(tokens.contains(&100278)); // <|user|>
-
-        // Test decoding back
+    fn test_pcre2_backend() {
+        let tokenizer = make_test_tokenizer().pcre2(true).unwrap();
+        let text = "Hello World";
+        let tokens = tokenizer.encode(text);
         let decoded = tokenizer.decode(&tokens).unwrap();
         assert_eq!(decoded, text);
+    }
 
-        // Test think tokens
-        let think_text = "<|think|>reasoning here<|/think|>";
-        let think_tokens = tokenizer.encode_with_special(think_text);
-        assert!(think_tokens.contains(&100282)); // <|think|>
-        assert!(think_tokens.contains(&100283)); // <|/think|>
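+    // A sketch added for illustration (not part of the original suite):
+    // pcre2(false) is a no-op, so the default regexr backend is kept on any build.
+    #[test]
+    fn test_pcre2_false_keeps_default_backend() {
+        let tokenizer = make_test_tokenizer().pcre2(false).unwrap();
+        let text = "Hello World";
+        assert_eq!(tokenizer.encode(text), make_test_tokenizer().encode(text));
+    }
+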
+    #[cfg(not(feature = "pcre2"))]
+    #[test]
+    fn test_pcre2_not_enabled() {
+        let tokenizer = make_test_tokenizer();
+        let result = tokenizer.pcre2(true);
+        assert!(result.is_err());
     }
+
+    const _: () = {
+        assert!(super::cl100k_agent_tokens::SYSTEM > 100276);
+        assert!(super::cl100k_agent_tokens::SUMMARY_END == 100330);
+        assert!(super::o200k_agent_tokens::SYSTEM > 200018);
+        assert!(super::o200k_agent_tokens::SUMMARY_END == 200072);
+        assert!(super::cl100k_agent_tokens::USER == super::cl100k_agent_tokens::SYSTEM + 1);
+        assert!(super::o200k_agent_tokens::USER == super::o200k_agent_tokens::SYSTEM + 1);
+    };
 }
diff --git a/src/lib.rs b/src/lib.rs
index e20c904..7dbc18a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -11,7 +11,8 @@ pub use core::{
 /// Splintr - Fast Rust BPE tokenizer with Python bindings
 ///
 /// A high-performance tokenizer featuring:
-/// - PCRE2 with JIT compilation (2-4x faster than fancy-regex)
+/// - Regexr with JIT and SIMD (default, pure Rust)
+/// - Optional PCRE2 with JIT (requires `pcre2` feature)
 /// - Rayon parallelism for multi-core encoding
 /// - Linked-list BPE algorithm (avoids O(N²) on pathological inputs)
 /// - FxHashMap for fast lookups
diff --git a/src/python/bindings.rs b/src/python/bindings.rs
index 4359436..13ad869 100644
--- a/src/python/bindings.rs
+++ b/src/python/bindings.rs
@@ -745,6 +745,33 @@ impl PyTokenizer {
         Ok(Self { inner })
     }
 
+    /// Switch to PCRE2 regex backend.
+    ///
+    /// PCRE2 is an alternative regex backend. Requires the `pcre2` feature
+    /// to be enabled at compile time.
+    ///
+    /// Args:
+    ///     use_pcre2: Whether to use PCRE2 backend (default: True)
+    ///
+    /// Returns:
+    ///     New Tokenizer instance with PCRE2 backend
+    ///
+    /// Raises:
+    ///     ValueError: If the `pcre2` feature is not enabled
+    ///
+    /// Example:
+    ///     tokenizer = Tokenizer.from_pretrained("cl100k_base").pcre2(True)
+    #[pyo3(signature = (use_pcre2=true))]
+    fn pcre2(&self, use_pcre2: bool) -> PyResult<Self> {
+        // The Rust pcre2() consumes self, so clone the inner tokenizer first;
+        // any failure (e.g., the pcre2 feature being disabled) maps to ValueError.
+        let new_inner = self.inner.clone();
+        let result = new_inner
+            .pcre2(use_pcre2)
+            .map_err(|e| PyValueError::new_err(e.to_string()))?;
+        Ok(Self { inner: result })
+    }
+
     /// Encode text to token IDs.
     ///
     /// Special tokens in the input are treated as regular text.
diff --git a/tests/cl100k.rs b/tests/cl100k.rs
index 7b95a00..34cdfb3 100644
--- a/tests/cl100k.rs
+++ b/tests/cl100k.rs
@@ -4,6 +4,10 @@
 //! handles special tokens, and produces consistent results.
 
 use splintr::{Tokenizer, CL100K_BASE_PATTERN};
+use std::sync::LazyLock;
+
+/// Shared tokenizer instance to avoid expensive re-initialization per test.
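+/// (`std::sync::LazyLock` runs the initializer on first access; stable since Rust 1.80.)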
+static TOKENIZER: LazyLock<Tokenizer> = LazyLock::new(create_cl100k_tokenizer_impl);
 
 // =============================================================================
 // Exact Token ID Tests
 // =============================================================================
@@ -277,8 +281,13 @@ fn test_cl100k_fim_format() {
     assert_eq!(decoded, fim);
 }
 
-// Helper function to create a cl100k tokenizer for testing
-fn create_cl100k_tokenizer() -> Tokenizer {
+/// Get the shared tokenizer instance
+fn create_cl100k_tokenizer() -> &'static Tokenizer {
+    &TOKENIZER
+}
+
+/// Implementation that actually constructs the tokenizer
+fn create_cl100k_tokenizer_impl() -> Tokenizer {
     // Load the embedded vocab
     let vocab_bytes = include_bytes!("../python/splintr/vocabs/cl100k_base.tiktoken");
diff --git a/tests/deepseek_v3.rs b/tests/deepseek_v3.rs
index 0815e19..d93e25e 100644
--- a/tests/deepseek_v3.rs
+++ b/tests/deepseek_v3.rs
@@ -4,6 +4,10 @@
 //! handles ByteLevel BPE encoding, special tokens, and produces consistent results.
 
 use splintr::{Tokenizer, LLAMA3_PATTERN};
+use std::sync::LazyLock;
+
+/// Shared tokenizer instance to avoid expensive re-initialization per test.
+static TOKENIZER: LazyLock<Tokenizer> = LazyLock::new(create_deepseek_v3_tokenizer_impl);
 
 // =============================================================================
 // Exact Token ID Tests
 // =============================================================================
@@ -445,12 +449,18 @@ fn test_deepseek_v3_mixed_special_tokens() {
     assert_eq!(decoded, chat);
 }
 
-// Helper function to create a DeepSeek V3 tokenizer for testing
-fn create_deepseek_v3_tokenizer() -> Tokenizer {
-    create_deepseek_v3_tokenizer_by_name("deepseek_v3")
+/// Get the shared tokenizer instance
+fn create_deepseek_v3_tokenizer() -> &'static Tokenizer {
+    &TOKENIZER
+}
+
+/// Create a fresh tokenizer by name (for variant tests only)
+fn create_deepseek_v3_tokenizer_by_name(_name: &str) -> Tokenizer {
+    create_deepseek_v3_tokenizer_impl()
 }
 
-fn create_deepseek_v3_tokenizer_by_name(name: &str) -> Tokenizer {
+/// Implementation that actually constructs the tokenizer
+fn create_deepseek_v3_tokenizer_impl() -> Tokenizer {
     // Load the embedded vocab
     let vocab_bytes = include_bytes!("../python/splintr/vocabs/deepseek_v3.tiktoken");
@@ -513,8 +523,6 @@ fn create_deepseek_v3_tokenizer_by_name(name: &str) -> Tokenizer {
     special.insert("<|output|>".to_string(), 128923);
     special.insert("<|/output|>".to_string(), 128924);
 
-    let _ = name; // Acknowledge variant name (all use same vocab)
-
     // DeepSeek uses ByteLevel BPE encoding
     Tokenizer::from_bytes_byte_level(vocab_bytes, LLAMA3_PATTERN, special).unwrap()
 }
diff --git a/tests/llama3.rs b/tests/llama3.rs
index cf0e67f..84ea892 100644
--- a/tests/llama3.rs
+++ b/tests/llama3.rs
@@ -4,6 +4,10 @@
 //! handles special tokens, and produces consistent results.
 
 use splintr::{Tokenizer, LLAMA3_PATTERN};
+use std::sync::LazyLock;
+
+/// Shared tokenizer instance to avoid expensive re-initialization per test.
+static TOKENIZER: LazyLock<Tokenizer> = LazyLock::new(create_llama3_tokenizer_impl);
 
 // =============================================================================
 // Exact Token ID Tests
 // =============================================================================
@@ -309,12 +313,18 @@ fn test_llama3_from_pretrained_variants() {
     );
 }
 
-// Helper function to create a Llama 3 tokenizer for testing
-fn create_llama3_tokenizer() -> Tokenizer {
-    create_llama3_tokenizer_by_name("llama3")
+/// Get the shared tokenizer instance
+fn create_llama3_tokenizer() -> &'static Tokenizer {
+    &TOKENIZER
+}
+
+/// Create a fresh tokenizer by name (for variant tests only)
+fn create_llama3_tokenizer_by_name(_name: &str) -> Tokenizer {
+    create_llama3_tokenizer_impl()
 }
 
-fn create_llama3_tokenizer_by_name(name: &str) -> Tokenizer {
+/// Implementation that actually constructs the tokenizer
+fn create_llama3_tokenizer_impl() -> Tokenizer {
     // Load the embedded vocab
     let vocab_bytes = include_bytes!("../python/splintr/vocabs/llama3.tiktoken");
@@ -364,8 +374,5 @@ fn create_llama3_tokenizer_by_name(name: &str) -> Tokenizer {
     special.insert("<|output|>".to_string(), 128323);
     special.insert("<|/output|>".to_string(), 128324);
 
-    // Use the same pattern as the Python bindings
-    let _ = name; // Acknowledge variant name (all use same vocab)
-
     Tokenizer::from_bytes(vocab_bytes, LLAMA3_PATTERN, special).unwrap()
 }
diff --git a/tests/o200k.rs b/tests/o200k.rs
index bd5b27f..8d229d0 100644
--- a/tests/o200k.rs
+++ b/tests/o200k.rs
@@ -4,6 +4,10 @@
 //! handles special tokens, and produces consistent results.
 
 use splintr::{Tokenizer, O200K_BASE_PATTERN};
+use std::sync::LazyLock;
+
+/// Shared tokenizer instance to avoid expensive re-initialization per test.
+static TOKENIZER: LazyLock<Tokenizer> = LazyLock::new(create_o200k_tokenizer_impl);
 
 // =============================================================================
 // Exact Token ID Tests
 // =============================================================================
@@ -274,8 +278,13 @@ fn test_o200k_larger_than_cl100k() {
     );
 }
 
-// Helper function to create an o200k tokenizer for testing
-fn create_o200k_tokenizer() -> Tokenizer {
+/// Get the shared tokenizer instance
+fn create_o200k_tokenizer() -> &'static Tokenizer {
+    &TOKENIZER
+}
+
+/// Implementation that actually constructs the tokenizer
+fn create_o200k_tokenizer_impl() -> Tokenizer {
     // Load the embedded vocab
     let vocab_bytes = include_bytes!("../python/splintr/vocabs/o200k_base.tiktoken");