Merged
248 changes: 142 additions & 106 deletions ARCHITECTURE.md

Large diffs are not rendered by default.

35 changes: 19 additions & 16 deletions README.md
@@ -6,11 +6,11 @@
[![CI](https://github.com/dahlem/torchcachex/actions/workflows/ci.yml/badge.svg)](https://github.com/dahlem/torchcachex/actions)
[![codecov](https://codecov.io/gh/dahlem/torchcachex/branch/main/graph/badge.svg)](https://codecov.io/gh/dahlem/torchcachex)

**Drop-in PyTorch module caching with Arrow IPC + SQLite backend**
**Drop-in PyTorch module caching with Arrow IPC + in-memory index backend**

`torchcachex` provides transparent, per-sample caching for non-trainable PyTorch modules with:
- ✅ **O(1) append-only writes** via incremental Arrow IPC segments
- ✅ **O(1) batched lookups** via SQLite index + Arrow memory-mapping
- ✅ **O(1) batched lookups** via in-memory index + Arrow memory-mapping
- ✅ **Native tensor storage** with automatic dtype preservation
- ✅ **LRU hot cache** for in-process hits
- ✅ **Async writes** (non-blocking forward pass)
@@ -422,7 +422,7 @@ Wraps a PyTorch module to add transparent per-sample caching.

### `ArrowIPCCacheBackend`

Persistent cache using Arrow IPC segments with SQLite index for O(1) operations.
Persistent cache using Arrow IPC segments with in-memory index for O(1) operations.

**Storage Format:**
```
@@ -431,7 +431,7 @@ cache_dir/module_id/
segment_000000.arrow # Incremental Arrow IPC files
segment_000001.arrow
...
index.db # SQLite with WAL mode
index.pkl # Pickled dict: key → (segment_id, row_offset)
schema.json # Auto-inferred Arrow schema
```
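The `index.pkl` file is the only lookup structure, and it is recoverable from the segments themselves. A minimal sketch of load-with-fallback — the function names and the assumption that each segment carries a `key` column are illustrative, not the library's actual internals:

```python
import os
import pickle

def load_index(cache_dir: str) -> dict:
    """Load the pickled key -> (segment_id, row_offset) index,
    rebuilding from the Arrow segments if index.pkl is missing
    or unreadable (hypothetical helper names)."""
    index_path = os.path.join(cache_dir, "index.pkl")
    try:
        with open(index_path, "rb") as f:
            return pickle.load(f)
    except (OSError, pickle.UnpicklingError, EOFError):
        return rebuild_index_from_segments(cache_dir)

def rebuild_index_from_segments(cache_dir: str) -> dict:
    """Scan segment_*.arrow files in order and re-derive the index."""
    import pyarrow as pa  # only needed on the rebuild path

    index = {}
    for name in sorted(os.listdir(cache_dir)):
        if not (name.startswith("segment_") and name.endswith(".arrow")):
            continue
        segment_id = int(name[len("segment_"):-len(".arrow")])
        with pa.memory_map(os.path.join(cache_dir, name)) as source:
            table = pa.ipc.open_file(source).read_all()
        # Later segments win, matching append-only write order.
        for row_offset, key in enumerate(table.column("key").to_pylist()):
            index[key] = (segment_id, row_offset)
    return index
```

Because segments are append-only, replaying them in order reproduces the same key → location mapping that the live index held.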

@@ -446,22 +446,23 @@
- `current_rank` (Optional[int]): Current process rank (default: None)

**Methods:**
- `get_batch(keys, map_location="cpu")`: O(1) batch lookup via SQLite index + memory-mapped Arrow
- `get_batch(keys, map_location="cpu")`: O(1) batch lookup via in-memory index + memory-mapped Arrow
- `put_batch(items)`: O(1) append-only write to pending buffer
- `flush()`: Force flush pending writes to new Arrow segment
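The three methods above form a small contract: buffered writes become visible only after `flush()`. A toy in-memory stand-in (not the real backend) makes the buffering semantics concrete:

```python
class DictCacheBackend:
    """Toy backend mirroring the get_batch/put_batch/flush contract."""

    def __init__(self):
        self._store = {}    # committed entries
        self._pending = []  # buffered (key, value) writes

    def put_batch(self, items):
        # O(1) amortized: just append to the pending buffer.
        self._pending.extend(items)

    def flush(self):
        # Commit pending writes; the real backend writes an Arrow
        # segment here and persists index.pkl atomically.
        for key, value in self._pending:
            self._store[key] = value
        self._pending.clear()

    def get_batch(self, keys, map_location="cpu"):
        # Return values in key order; None marks a cache miss.
        return [self._store.get(k) for k in keys]
```

Note that reads of un-flushed keys miss, which is why training loops in the examples call `backend.flush()` at epoch boundaries.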

**Features:**
- **O(1) writes**: New data appended to incremental segments, no rewrites
- **O(1) reads**: SQLite index points directly to (segment_id, row_offset)
- **O(1) reads**: In-memory dict index points directly to (segment_id, row_offset)
- **Native tensors**: Automatic dtype preservation via Arrow's type system
- **Schema inference**: Automatically detects structure on first write
- **Crash safety**: Atomic commits via SQLite WAL + temp file approach
- **Crash safety**: Automatic index rebuild from segments on corruption
- **No database dependencies**: Simple pickle-based index persistence

## Architecture

### Storage Design

torchcachex uses a hybrid Arrow IPC + SQLite architecture optimized for billion-scale caching:
torchcachex uses a hybrid Arrow IPC + in-memory index architecture optimized for billion-scale caching:

**Components:**

@@ -471,11 +472,12 @@ torchcachex uses a hybrid Arrow IPC + SQLite architecture optimized for billion-
- Memory-mapped for zero-copy reads
- Each segment contains a batch of cached samples

2. **SQLite Index** (`index.db`)
- WAL (Write-Ahead Logging) mode for concurrent reads
2. **Pickle Index** (`index.pkl`)
- In-memory Python dict backed by pickle persistence
- Maps cache keys to (segment_id, row_offset)
- O(1) lookups via primary key index
- Tracks segment metadata (file paths, row counts)
- O(1) lookups via dict access
- Atomic persistence with temp file swap
- Auto-rebuilds from segments on corruption

3. **Schema File** (`schema.json`)
- Auto-inferred from first forward pass
@@ -488,8 +490,9 @@ torchcachex uses a hybrid Arrow IPC + in-memory index architecture optimized for billion-
put_batch() → pending buffer → flush() → {
1. Create Arrow RecordBatch
2. Write to temp segment file
3. Update SQLite index (atomic transaction)
3. Update in-memory index dict
4. Atomic rename temp → final
5. Persist index.pkl (atomic)
}
```
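Steps 2, 4, and 5 of this flow all lean on the temp-file-plus-atomic-rename pattern. A sketch of the index-persistence step (the helper name is illustrative): an interrupted write leaves the previous `index.pkl` untouched, because the rename either happens completely or not at all.

```python
import os
import pickle
import tempfile

def persist_atomically(path: str, index: dict) -> None:
    """Write `index` to a temp file in the same directory, then
    atomically swap it into place at `path`."""
    directory = os.path.dirname(path) or "."
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(index, f)
            f.flush()
            os.fsync(f.fileno())    # durable before the rename
        os.replace(tmp_path, path)  # atomic on POSIX and Windows
    except BaseException:
        os.unlink(tmp_path)         # clean up the orphaned temp file
        raise
```

The temp file lives in the target directory (not `/tmp`) so the rename never crosses filesystems, which would break atomicity.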

@@ -498,7 +501,7 @@ put_batch() → pending buffer → flush() → {
```
get_batch() → {
1. Check LRU cache (in-memory)
2. Query SQLite for (segment_id, row_offset)
2. Query in-memory index for (segment_id, row_offset)
3. Memory-map Arrow segment
4. Extract rows (zero-copy)
5. Reconstruct tensors with correct dtype
}
```
@@ -508,10 +511,10 @@
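The same flow in miniature, with the memory-mapped extraction abstracted behind a callable — all names here are illustrative, not the library's API:

```python
from collections import OrderedDict

def get_batch(keys, index, read_row, lru, lru_capacity=1024):
    """index: key -> (segment_id, row_offset).
    read_row: callable standing in for the memory-mapped Arrow
    row extraction.  lru: OrderedDict hot cache, mutated in place."""
    results = []
    for key in keys:
        if key in lru:                 # 1. LRU hit: no disk touched
            lru.move_to_end(key)
            results.append(lru[key])
            continue
        location = index.get(key)      # 2. O(1) dict lookup
        if location is None:
            results.append(None)       # cache miss
            continue
        value = read_row(*location)    # 3-5. mmap segment, extract row
        lru[key] = value               # warm the hot cache
        if len(lru) > lru_capacity:
            lru.popitem(last=False)    # evict least recently used
        results.append(value)
    return results
```

A second lookup of the same key is served from the LRU without touching `read_row` at all, which is where the in-process hit speedup comes from.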
**Scalability Properties:**

- **Writes**: O(1) - append new segment, update index
- **Reads**: O(1) - direct index lookup + memory-map
- **Reads**: O(1) - direct dict lookup + memory-map
- **Memory**: O(working set) - only LRU + current segment in memory
- **Disk**: O(N) - one entry per sample across segments
- **Crash Recovery**: Atomic - incomplete segments ignored, SQLite WAL ensures consistency
- **Crash Recovery**: Atomic - incomplete segments ignored, index auto-rebuilds from segments if corrupted

### Schema Inference
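The section body is collapsed in this diff view, but the idea stated above — the schema is detected from the structure of the first write — can be sketched without pyarrow. The dtype table and helper below are purely illustrative assumptions, not the library's mapping code:

```python
# Illustrative torch-dtype -> Arrow-type table consulted when the
# schema is inferred from the first cached batch (a sketch only).
TORCH_TO_ARROW = {
    "torch.float32": "float32",
    "torch.float16": "float16",
    "torch.int64": "int64",
    "torch.bool": "bool_",
}

def infer_column(name: str, dtype_str: str, batch_shape: tuple) -> dict:
    """Describe one tensor column: its Arrow type plus the fixed
    per-sample shape (the leading batch dimension is dropped)."""
    return {
        "name": name,
        "arrow_type": TORCH_TO_ARROW[dtype_str],
        "shape": list(batch_shape[1:]),
    }
```

Recording the per-sample shape alongside the Arrow type is what lets later reads reconstruct tensors with the original dtype and dimensions.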

30 changes: 13 additions & 17 deletions benchmark.py
@@ -12,11 +12,9 @@

import argparse
import os
import shutil
import tempfile
import time
from dataclasses import dataclass
from typing import List

import torch
import torch.nn as nn
@@ -81,7 +79,7 @@ def forward(self, x):
return self.fc(x)


def benchmark_write_scaling(tmpdir: str) -> List[BenchmarkResult]:
def benchmark_write_scaling(tmpdir: str) -> list[BenchmarkResult]:
"""Verify O(1) write scaling: flush time independent of cache size."""
print("\n[Benchmark] Write Scaling (O(1) Verification)")
print("=" * 60)
@@ -133,7 +131,7 @@ def benchmark_write_scaling(tmpdir: str) -> List[BenchmarkResult]:
return results


def benchmark_read_performance(tmpdir: str) -> List[BenchmarkResult]:
def benchmark_read_performance(tmpdir: str) -> list[BenchmarkResult]:
"""Measure read performance at different cache sizes."""
print("\n[Benchmark] Read Performance")
print("=" * 60)
@@ -156,7 +154,7 @@ def benchmark_read_performance(tmpdir: str) -> List[BenchmarkResult]:
backend.flush()

# Benchmark random reads
print(f" Benchmarking 1000 random reads...")
print(" Benchmarking 1000 random reads...")
import random

random.seed(42)
@@ -184,15 +182,13 @@
return results


def benchmark_memory_usage(tmpdir: str) -> List[BenchmarkResult]:
def benchmark_memory_usage(tmpdir: str) -> list[BenchmarkResult]:
"""Measure memory usage at different cache sizes."""
print("\n[Benchmark] Memory Usage")
print("=" * 60)

try:
import psutil

HAS_PSUTIL = True
except ImportError:
print(" [Skip] psutil not installed")
return []
@@ -237,7 +233,7 @@ def benchmark_memory_usage(tmpdir: str) -> List[BenchmarkResult]:
return results


def benchmark_cache_speedup(tmpdir: str) -> List[BenchmarkResult]:
def benchmark_cache_speedup(tmpdir: str) -> list[BenchmarkResult]:
"""Compare cached vs uncached performance."""
print("\n[Benchmark] Cache Speedup")
print("=" * 60)
@@ -250,7 +246,7 @@ def benchmark_cache_speedup(tmpdir: str) -> List[BenchmarkResult]:
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Benchmark WITHOUT caching
print(f" Running WITHOUT cache...")
print(" Running WITHOUT cache...")
module_nocache = BenchmarkModule()

start = time.time()
@@ -261,7 +257,7 @@ def benchmark_cache_speedup(tmpdir: str) -> List[BenchmarkResult]:
print(f" Time: {time_nocache:.3f}s, Calls: {module_nocache.call_count}")

# Benchmark WITH caching (first epoch - populate cache)
print(f" Running WITH cache (epoch 1 - populate)...")
print(" Running WITH cache (epoch 1 - populate)...")
backend = ArrowIPCCacheBackend(
cache_dir=tmpdir,
module_id="speedup_test",
@@ -279,7 +275,7 @@ def benchmark_cache_speedup(tmpdir: str) -> List[BenchmarkResult]:
print(f" Time: {time_epoch1:.3f}s, Module calls: {module_cached.call_count}")

# Benchmark WITH caching (second epoch - cache hits)
print(f" Running WITH cache (epoch 2 - cache hits)...")
print(" Running WITH cache (epoch 2 - cache hits)...")
module_cached.call_count = 0

start = time.time()
@@ -321,7 +317,7 @@ def benchmark_cache_speedup(tmpdir: str) -> List[BenchmarkResult]:
return results


def benchmark_async_write(tmpdir: str) -> List[BenchmarkResult]:
def benchmark_async_write(tmpdir: str) -> list[BenchmarkResult]:
"""Compare async vs sync write performance."""
print("\n[Benchmark] Async Write Performance")
print("=" * 60)
@@ -334,7 +330,7 @@ def benchmark_async_write(tmpdir: str) -> List[BenchmarkResult]:
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Benchmark SYNC writes
print(f" Running with SYNC writes...")
print(" Running with SYNC writes...")
backend_sync = ArrowIPCCacheBackend(
cache_dir=tmpdir,
module_id="async_test_sync",
@@ -355,7 +351,7 @@ def benchmark_async_write(tmpdir: str) -> List[BenchmarkResult]:
print(f" Time: {time_sync:.3f}s")

# Benchmark ASYNC writes
print(f" Running with ASYNC writes...")
print(" Running with ASYNC writes...")
backend_async = ArrowIPCCacheBackend(
cache_dir=tmpdir,
module_id="async_test_async",
@@ -398,7 +394,7 @@ def benchmark_async_write(tmpdir: str) -> List[BenchmarkResult]:
return results


def benchmark_dtype_preservation(tmpdir: str) -> List[BenchmarkResult]:
def benchmark_dtype_preservation(tmpdir: str) -> list[BenchmarkResult]:
"""Verify dtype preservation across different tensor types."""
print("\n[Benchmark] Dtype Preservation")
print("=" * 60)
@@ -450,7 +446,7 @@ def benchmark_dtype_preservation(tmpdir: str) -> List[BenchmarkResult]:
return results


def generate_markdown_report(all_results: List[BenchmarkResult], output_file: str):
def generate_markdown_report(all_results: list[BenchmarkResult], output_file: str):
"""Generate markdown report from benchmark results."""
print(f"\n[Report] Generating markdown report: {output_file}")

10 changes: 5 additions & 5 deletions examples/advanced_usage.py
@@ -104,14 +104,14 @@ def example_kfold_cv():

# Train on fold (features cached progressively)
for batch in train_loader:
features = cached_extractor(batch["image"], cache_ids=batch["cache_ids"])
_features = cached_extractor(batch["image"], cache_ids=batch["cache_ids"])
# ... train classifier ...

backend.flush()

# Validate (reuses cached features from overlapping samples)
for batch in val_loader:
features = cached_extractor(batch["image"], cache_ids=batch["cache_ids"])
_features = cached_extractor(batch["image"], cache_ids=batch["cache_ids"])
# ... evaluate ...

print(f" Fold {fold + 1} complete\n")
@@ -147,7 +147,7 @@ def example_ddp_training():
print("Training (all ranks compute, only rank 0 writes cache)...")
for batch in loader:
# All ranks compute features
features = cached_extractor(batch["image"], cache_ids=batch["cache_ids"])
_features = cached_extractor(batch["image"], cache_ids=batch["cache_ids"])
# ... train on features ...

backend.flush()
@@ -201,13 +201,13 @@ def forward(self, x, cache_ids):

print("Training Model A (populates cache)...")
for batch in loader:
logits = model_a(batch["image"], cache_ids=batch["cache_ids"])
_logits = model_a(batch["image"], cache_ids=batch["cache_ids"])
# ... train model A ...
backend.flush()

print("Training Model B (reuses Model A's cache)...")
for batch in loader:
logits = model_b(batch["image"], cache_ids=batch["cache_ids"])
_logits = model_b(batch["image"], cache_ids=batch["cache_ids"])
# ... train model B ...

print("Model B reused all features from Model A's cache!\n")
2 changes: 1 addition & 1 deletion examples/cli_integration.py
@@ -13,6 +13,7 @@
from hydra.utils import instantiate
from omegaconf import DictConfig
from rich.console import Console
from shade_io.feature_sets.filters import RemoveConstantFeaturesFilter
from sklearn.decomposition import PCA

# Import shade-io components
@@ -23,7 +24,6 @@
FilteredFeatureSet,
SimpleFeatureSet,
)
from shade_io.feature_sets.filters import RemoveConstantFeaturesFilter

logger = logging.getLogger(__name__)
console = Console()
2 changes: 1 addition & 1 deletion examples/minimal_example.py
@@ -44,7 +44,7 @@
print("Training...")
for epoch in range(3):
print(f" Epoch {epoch + 1}/3")
for batch_idx, (batch_images, batch_labels) in enumerate(loader):
for batch_idx, (batch_images, _batch_labels) in enumerate(loader):
# Get cache IDs for this batch
start_idx = batch_idx * 10
batch_cache_ids = cache_ids[start_idx : start_idx + 10]
2 changes: 1 addition & 1 deletion src/torchcachex/__init__.py
@@ -1,4 +1,4 @@
"""torchcachex: Drop-in PyTorch module caching with Arrow IPC + SQLite backend.
"""torchcachex: Drop-in PyTorch module caching with Arrow IPC + in-memory index backend.

This library provides transparent, per-sample caching for non-trainable PyTorch modules
with O(1) append-only writes, native tensor storage, batched lookups, LRU hot cache,