ZeroBatch v2 — Batch-Native Format (LLMBATCH)

Overview

ZeroBatch v2 is a binary file format and zero-copy loader designed for serving pre-tokenized LLM training data as PyTorch tensors. It stores data in pre-formed batches — contiguous blocks of [batch_size, seq_len] uint32 tokens, page-aligned on disk — so that serving a batch to the training loop requires no collation, no per-sample Python work, and no memory allocation.

A single batch read consists of:

offset = 4096 + batch_index * padded_batch_bytes
mmap[offset:end] → np.frombuffer → reshape → torch.from_numpy

Five operations. ~18 microseconds. No Python loops proportional to batch size.


Why This Exists

Standard data loading pipelines (PyTorch DataLoader, HuggingFace datasets) store data per-sample and assemble batches at runtime. For each batch of 32 samples, the typical pipeline executes:

  1. 32 __getitem__ calls (Python function overhead per sample)
  2. 32 tensor slice or deserialization operations
  3. torch.stack() to collate 32 tensors into one batch tensor
  4. Shuffle index computation per sample

This per-sample overhead dominates when data is already in memory or on fast storage and compute is cheap. On an Apple M2 with an SSD (CPU-only), per-sample collation accounts for 80% of data loading time.

ZeroBatch v2 eliminates all of this by pre-forming batches at write time and storing them in a layout that the loader can serve with a single contiguous read.

Honest caveat: On GPU hardware (A100), the model forward/backward pass takes ~130ms per step. Data loading takes <0.1ms regardless of backend. The per-sample overhead that ZeroBatch eliminates is real, but it's invisible when the GPU compute dominates by 1000x. ZeroBatch's data loading advantage translates to training speedup only in CPU-bound or I/O-bound scenarios.


File Format Specification

Magic and Identification

Property         Value
───────────────  ──────────────────────────────────────
Magic bytes      LLMBATCH (8 bytes, no null terminator)
File extension   .batch
Byte order       Little-endian throughout
Token dtype      uint32 (4 bytes per token)

Layout

┌──────────────────────────────────────────┐
│ Header (4096 bytes, page-aligned)        │
├──────────────────────────────────────────┤
│ Batch 0 data + page padding              │
├──────────────────────────────────────────┤
│ Batch 1 data + page padding              │
├──────────────────────────────────────────┤
│ ...                                      │
├──────────────────────────────────────────┤
│ Batch N-1 data + page padding            │
└──────────────────────────────────────────┘

Header (4096 bytes)

Offset  Size   Type    Field           Description
──────  ─────  ──────  ──────────────  ─────────────────────────────────────
0       8      bytes   magic           "LLMBATCH" — format identifier
8       4      u32     version         Format version (currently 1)
12      4      u32     batch_size      Samples per batch (e.g. 32)
16      4      u32     seq_len         Tokens per sample (e.g. 512)
20      8      u64     num_batches     Total number of batches in file
28      4      u32     dtype           Token dtype code (0 = uint32)
32      4      u32     seed            Shuffle seed used during writing
36      4      u32     total_records   Total source records before batching
40      4056   zeros   reserved        Padding to 4096-byte page boundary

The header occupies exactly one OS memory page (4096 bytes). This guarantees that batch data starts on a page boundary, enabling the OS to page in individual batches without touching header memory.
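Given these fixed offsets, the whole header can be parsed with a single struct.unpack_from call. The sketch below is illustrative (parse_header is a hypothetical helper, not the library's API); the format string mirrors the table field-for-field:

```python
import struct

HEADER_SIZE = 4096
# magic(8s) version(u32) batch_size(u32) seq_len(u32) num_batches(u64)
# dtype(u32) seed(u32) total_records(u32) — little-endian, no padding
HEADER_FMT = "<8sIIIQIII"

def parse_header(header: bytes) -> dict:
    """Parse the first 40 bytes of a 4096-byte LLMBATCH header (sketch)."""
    (magic, version, batch_size, seq_len,
     num_batches, dtype, seed, total_records) = struct.unpack_from(HEADER_FMT, header, 0)
    assert magic == b"LLMBATCH", "not an LLMBATCH file"
    return {
        "version": version,
        "batch_size": batch_size,
        "seq_len": seq_len,
        "num_batches": num_batches,
        "dtype": dtype,
        "seed": seed,
        "total_records": total_records,
    }
```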

Batch Data

Each batch occupies a fixed-size slot:

raw_batch_bytes    = batch_size × seq_len × 4
padded_batch_bytes = ceil(raw_batch_bytes / 4096) × 4096

For the default configuration (batch_size=32, seq_len=512):

raw    = 32 × 512 × 4 = 65,536 bytes (64 KiB) — already page-aligned
padded = 65,536 bytes (no padding needed)

The byte offset of batch i is:

offset = 4096 + i × padded_batch_bytes

This is O(1) random access. No offset table. No index lookup. One integer multiply and one addition.
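The padding and offset arithmetic, written out as plain Python (helper names are illustrative):

```python
HEADER_SIZE = 4096
PAGE = 4096

def padded_batch_bytes(batch_size: int, seq_len: int) -> int:
    """Round the raw batch size up to the next 4096-byte page boundary."""
    raw = batch_size * seq_len * 4  # uint32 tokens
    return (raw + PAGE - 1) // PAGE * PAGE

def batch_offset(i: int, batch_size: int, seq_len: int) -> int:
    """Byte offset of batch i: one integer multiply and one add."""
    return HEADER_SIZE + i * padded_batch_bytes(batch_size, seq_len)

# Default config: 32 × 512 × 4 = 64 KiB, already page-aligned
assert padded_batch_bytes(32, 512) == 65536
assert batch_offset(0, 32, 512) == 4096
assert batch_offset(10, 32, 512) == 4096 + 10 * 65536
```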

Token Layout Within a Batch

Tokens are stored as a flat array of uint32 values in row-major order:

[sample_0_token_0] [sample_0_token_1] ... [sample_0_token_511]
[sample_1_token_0] [sample_1_token_1] ... [sample_1_token_511]
...
[sample_31_token_0] [sample_31_token_1] ... [sample_31_token_511]

This maps directly to np.frombuffer(..., dtype=uint32).reshape(batch_size, seq_len) — a zero-copy view of the memory-mapped file.
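That "view, not copy" behavior can be checked directly with numpy, using an in-memory buffer as a stand-in for an mmap slice:

```python
import numpy as np

# Stand-in for one 64 KiB mmap slice of uint32 tokens
buf = np.arange(32 * 512, dtype=np.uint32).tobytes()

arr = np.frombuffer(buf, dtype=np.uint32).reshape(32, 512)
assert arr.base is not None            # frombuffer + reshape are views, no copy
as_i64 = arr.astype(np.int64)          # the single per-batch allocation
assert as_i64.flags["OWNDATA"]         # astype produced fresh memory
```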

File Size Formula

file_size = 4096 + num_batches × padded_batch_bytes

For 104,829 records at batch_size=32, seq_len=512:

num_batches = 104829 // 32 = 3275
file_size   = 4096 + 3275 × 65536 = 214,634,496 bytes ≈ 205 MB

Storage Efficiency

ZeroBatch v2 stores tokens as uint32 (4 bytes). PyTorch .pt files store int64 (8 bytes) — double the size for the same data. The v2 format converts uint32 → int64 at read time via arr.astype(np.int64), which is a single vectorized operation.

Format                 Storage per token      100K records × 512 tokens
─────────────────────  ─────────────────────  ─────────────────────────
ZeroBatch v2 (uint32)  4 bytes                ~205 MB
PyTorch .pt (int64)    8 bytes                ~409 MB
HF Arrow               ~4-8 bytes + overhead  ~210 MB

Writing

API

from zerobatch.batch_writer import write_batch_file

num_batches = write_batch_file(
    output_path="dataset.batch",
    tokens_array=tokens,   # np.ndarray, shape (N, seq_len), dtype uint32
    batch_size=32,
    seq_len=512,
    seed=42,
)

What the Writer Does

  1. Shuffle — Generates a random permutation of record indices using the given seed. Records from different parts of the source dataset are mixed into the same batches.

  2. Batch — Groups shuffled records into batches of batch_size. Remainder records are dropped (consistent with drop_last=True in PyTorch DataLoader).

  3. Write header — Writes the 4096-byte header with all metadata.

  4. Write batches — For each batch, packs batch_size records into a contiguous uint32 array and writes it, followed by zero-padding to the next page boundary.
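A minimal sketch of steps 1, 2, and 4, assuming an in-memory token array (write_batches_sketch is a hypothetical name; the real write_batch_file also fills in the full header from step 3, which is left as zero padding here for brevity):

```python
import numpy as np

PAGE = 4096

def write_batches_sketch(path, tokens, batch_size, seq_len, seed):
    """Shuffle records, group into fixed batches, write page-padded slots."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(tokens))        # 1. shuffle record indices
    num_batches = len(tokens) // batch_size     # 2. drop the remainder (drop_last)
    raw = batch_size * seq_len * 4
    padded = (raw + PAGE - 1) // PAGE * PAGE
    with open(path, "wb") as f:
        f.write(b"\0" * PAGE)                   # 3. header placeholder (4096 bytes)
        for b in range(num_batches):            # 4. contiguous batch + zero padding
            idx = order[b * batch_size:(b + 1) * batch_size]
            block = np.ascontiguousarray(tokens[idx], dtype=np.uint32)
            f.write(block.tobytes())
            f.write(b"\0" * (padded - raw))
    return num_batches
```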

Streaming Writer

For datasets too large to hold in RAM as a single numpy array:

from zerobatch.batch_writer import write_batch_dataset

num_batches = write_batch_dataset(
    output_path="dataset.batch",
    tokens_iter=token_generator(),  # yields 1D arrays of shape (seq_len,)
    total_records=1000000,
    batch_size=32,
    seq_len=512,
    seed=42,
)

The streaming writer buffers records into a numpy array, shuffles once, then writes batches sequentially. It requires total_records × seq_len × 4 bytes of RAM for the buffer — the same as the data itself.

For truly streaming writes that never hold the full dataset (used in the scaling benchmark generator), the synthetic generator writes batches directly without global shuffling, since the data is random anyway.


Reading

Architecture

                    ┌─────────────────────┐
                    │   BatchNativeLoader  │  ← user-facing iterator
                    │   (block-shuffle)    │
                    └──────────┬──────────┘
                               │
                    ┌──────────▼──────────┐
                    │   BatchFileReader    │  ← mmap + header parsing
                    │   (zero-copy reads)  │
                    └──────────┬──────────┘
                               │
                    ┌──────────▼──────────┐
                    │    OS mmap / VFS     │  ← kernel manages paging
                    └──────────┬──────────┘
                               │
                    ┌──────────▼──────────┐
                    │    SSD / NVMe        │  ← physical I/O
                    └─────────────────────┘

BatchFileReader — Low-Level Reader

Opens the file, parses the header, and provides O(1) batch access:

from zerobatch.batch_loader import BatchFileReader

reader = BatchFileReader("dataset.batch")
print(reader.num_batches)  # 3275
print(reader.batch_size)   # 32
print(reader.seq_len)      # 512

# Zero-copy numpy view into mmap
arr = reader.read_batch(0)     # shape (32, 512), dtype uint32

# PyTorch tensor (requires uint32 → int64 copy)
tensor = reader.read_batch_torch(0)  # shape (32, 512), dtype int64

How read_batch works (the first four of the five operations; read_batch_torch adds the fifth):

def read_batch(self, batch_index):
    # 1. Compute byte offset (integer multiply + add)
    offset = HEADER_SIZE + batch_index * self.padded_batch_bytes
    end = offset + self.batch_size * self.seq_len * 4

    # 2. Slice mmap (OS handles paging — may trigger page fault)
    buf = self.mm[offset:end]

    # 3. Interpret bytes as uint32 array (zero-copy view)
    arr = np.frombuffer(buf, dtype=np.uint32)

    # 4. Reshape to 2D (view, no copy)
    return arr.reshape(self.batch_size, self.seq_len)

The read_batch_torch method adds one step — arr.astype(np.int64) — which is the only memory allocation per batch (converting from uint32 to int64 for PyTorch).

BatchNativeLoader — Training Loop Iterator

Wraps BatchFileReader with shuffling and epoch support:

from zerobatch.batch_loader import BatchNativeLoader

loader = BatchNativeLoader(
    path="dataset.batch",
    block_size=256,    # batches per super-block
    shuffle=True,
    seed=42,
    epoch=0,
)

for batch in loader:
    # batch: torch.LongTensor, shape (32, 512)
    loss = model(batch)
    loss.backward()

Epoch support:

for epoch in range(num_epochs):
    loader.set_epoch(epoch)  # changes shuffle order
    for batch in loader:
        train(batch)

Block-Shuffle Strategy

Standard random shuffling of mmap'd data causes random page faults across the entire file — every batch access touches a different page, defeating the OS prefetcher. Sequential access, on the other hand, gives no randomness.

ZeroBatch v2 uses block-shuffle — a middle ground that preserves statistical randomness while maintaining mmap locality:

How It Works

  1. Divide the file's batches into super-blocks of block_size consecutive batches (default: 256).

  2. Shuffle the block order — each epoch uses seed XOR epoch as the random seed, producing a different block permutation per epoch.

  3. Iterate sequentially within each block — accessing 256 consecutive batches triggers OS readahead and sequential prefetching.

File on disk:     [Block0: B0-B255] [Block1: B256-B511] [Block2: B512-B767] ...

Epoch 0 order:    Block2 → Block0 → Block3 → Block1 → ...
                  (within each: sequential B512,B513,...,B767, then B0,B1,...,B255, ...)

Epoch 1 order:    Block1 → Block3 → Block0 → Block2 → ...
                  (different block order, same sequential access within)
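The ordering above can be sketched in a few lines (epoch_batch_order is a hypothetical helper; the loader's actual RNG construction may differ, but the seed XOR epoch scheme follows the text):

```python
import numpy as np

def epoch_batch_order(num_batches, block_size, seed, epoch):
    """Block-shuffle: permute the block order per epoch, then iterate
    sequentially inside each block."""
    num_blocks = (num_batches + block_size - 1) // block_size
    rng = np.random.default_rng(seed ^ epoch)   # different permutation each epoch
    order = []
    for blk in rng.permutation(num_blocks):
        start = blk * block_size
        order.extend(range(start, min(start + block_size, num_batches)))
    return order

order = epoch_batch_order(num_batches=1000, block_size=256, seed=42, epoch=0)
assert sorted(order) == list(range(1000))   # every batch exactly once per epoch
assert order[1] == order[0] + 1             # sequential within the first block
```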

Why This Works

  • Statistical randomness: With 256 batches per block and thousands of blocks, the training loop sees data in a different order each epoch. The model never sees the same batch sequence twice.

  • mmap locality: Within each block, 256 × 64 KiB = 16 MiB of sequential reads. The OS prefetcher recognizes this pattern and pre-pages the data, eliminating stalls.

  • Epoch variation: seed XOR epoch produces a completely different block permutation each epoch, so the model doesn't memorize block boundaries.

Block Size Choice

Block size     Sequential run  Blocks for 100K records  Trade-off
─────────────  ──────────────  ───────────────────────  ────────────────────────────────────
1              64 KiB          3,275                    Max randomness, worst locality
64             4 MiB           ~51                      Good balance
256 (default)  16 MiB          ~13                      Good locality, sufficient randomness
1024           64 MiB          ~3                       Near-sequential, minimal shuffling

The default of 256 gives 16 MiB sequential runs — large enough for the OS prefetcher to work effectively, small enough to provide meaningful epoch-to-epoch variation.


Memory Model

Why mmap

ZeroBatch v2 uses mmap(2) to map the batch file into virtual memory. The key advantage: the OS manages physical memory automatically.

Virtual address space (unlimited):
┌──────────────────────────────────────────┐
│ Header  │ Batch 0 │ Batch 1 │ ... │ Bn  │  ← mapped to file
└──────────────────────────────────────────┘

Physical RAM (limited — e.g. 8 GB):
┌───────────────────────────┐
│ Active batches only       │  ← OS pages in/out on demand
└───────────────────────────┘

  • Small datasets (fits in RAM): After warm-up, all pages are resident. Performance matches in-RAM access.

  • Large datasets (exceeds RAM): The OS evicts old pages and pages in new ones as needed. Only the working set (current block of 256 batches ≈ 16 MiB) needs to be resident.

  • Huge datasets (many times RAM): Still works. The OS pages in one block at a time. Throughput degrades gracefully based on SSD bandwidth, not catastrophically like torch.load() which requires the entire dataset in RAM.

Comparison with In-RAM Loading

Property       PyTorch torch.load()             ZeroBatch v2 mmap
─────────────  ───────────────────────────────  ───────────────────────────────
Init time      Loads entire file into RAM       Opens file descriptor (instant)
Memory usage   file_size bytes always resident  Only active pages (~16 MiB)
Dataset > RAM  Swap thrashing / OOM             Works (OS manages paging)
First batch    Fast (data already in RAM)       May trigger page fault (~1ms)
Steady state   Fast                             Fast (pages become resident)

Measured: 2 GB Dataset on 8 GB RAM with 5 GB RAM Pressure

Backend              tokens/sec     Init time  Notes
───────────────────  ─────────────  ─────────  ────────────────────────────────
ZeroBatch v2 (mmap)  1,076,348,619  3 ms       Only pages active block
PyTorch DataLoader   108,579,759    2,611 ms   torch.load() into 2 GB free RAM
HF Arrow             3,025,756      4,131 ms   Per-sample dict overhead + paging

ZeroBatch v2 is 10x faster than PyTorch and 356x faster than HuggingFace when the dataset exceeds available RAM.


Comparison with ZeroBatch v1

ZeroBatch v1 (LLMCHNK format) stores individual records with an offset table for random access. Batching happens at read time in Python.

Property                v1 (LLMCHNK)                          v2 (LLMBATCH)
──────────────────────  ────────────────────────────────────  ──────────────────────────
Storage unit            Individual records                    Pre-formed batches
Access pattern          Offset table lookup per record        Direct byte seek per batch
Collation               Python-side np.stack() of 32 records  None (pre-batched on disk)
Python calls per batch  32 read_record + 1 np.stack           1 read_batch
Shuffle granularity     Per-sample (full random)              Per-block (256 batches)
Header                  32 bytes (not aligned)                4096 bytes (page-aligned)
Batch alignment         N/A                                   Page-aligned (4096 bytes)
Flexible batch size     Yes (any size at runtime)             No (fixed at write time)
Throughput              174M tok/s                            914M tok/s (5.26x)

When to use v1: You need to change batch size at runtime, or need per-record random access (e.g., for evaluation on specific samples).

When to use v2: Batch size is fixed at data preparation time, and you want maximum throughput during training.


Performance Characteristics

Per-Batch Cost Breakdown

Operation            Time       Allocates?
───────────────────  ─────────  ───────────────────────────
Compute byte offset  ~1 ns      No
mmap slice (warm)    ~2 us      No (kernel page table walk)
np.frombuffer        ~0.5 us    No (view into mmap)
.reshape(32, 512)    ~0.1 us    No (view, no copy)
.astype(np.int64)    ~10 us     Yes (the only allocation)
torch.from_numpy     ~0.5 us    No (shares numpy memory)
Total                ~15-18 us  One 128 KiB allocation

The only memory allocation per batch is the uint32 → int64 type conversion. Everything else is a view into existing memory.
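A rough way to reproduce the view-pipeline numbers yourself, using an in-memory buffer as a stand-in for a warm mmap (absolute timings vary by machine; this measures the frombuffer/reshape/astype path only, not the torch wrap):

```python
import time
import numpy as np

def bench_view_pipeline(iters=1000):
    """Time the per-batch numpy pipeline; returns microseconds per batch."""
    buf = np.zeros(32 * 512, dtype=np.uint32).tobytes()  # one 64 KiB "batch"
    t0 = time.perf_counter()
    for _ in range(iters):
        arr = np.frombuffer(buf, dtype=np.uint32).reshape(32, 512)
        _ = arr.astype(np.int64)  # the one allocation per batch
    return (time.perf_counter() - t0) / iters * 1e6

us_per_batch = bench_view_pipeline()
```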

Throughput

Measured on Apple M2, 8 GB RAM, macOS, CPU-only (data loading in isolation, no model):

Dataset size             Throughput    Notes
───────────────────────  ────────────  ─────────────────────────────────────────
205 MB (fits in RAM)     914M tok/s    All pages resident after warmup
2 GB (exceeds free RAM)  1,076M tok/s  OS paging works well for sequential blocks

The 2 GB result is slightly faster than the 205 MB result because the larger dataset contains more batches per trial, so cold-start effects are better amortized in the median-trial measurement.

GPU context: On an A100 (65 GB dataset, 167 GB RAM), both ZeroBatch and PyTorch memmap deliver data fast enough to keep the GPU fully saturated. The raw throughput advantage does not translate to faster training steps — see the training benchmark report for details.

Latency Percentiles (205 MB, warm)

Percentile  Latency
──────────  ────────
avg         0.018 ms
p50         0.017 ms
p95         0.019 ms
p99         0.024 ms

The tight p50-p99 spread (17-24 us) shows that batch-level access has predictable, low-variance latency — unlike per-sample loaders where collation time varies with cache state.


Integration with PyTorch Training

Basic Training Loop

from zerobatch.batch_loader import BatchNativeLoader

loader = BatchNativeLoader("data/dataset.batch", seed=42)

model = MyModel()
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(num_epochs):
    loader.set_epoch(epoch)
    for batch in loader:
        # batch: torch.LongTensor, shape (batch_size, seq_len)
        logits = model(batch[:, :-1])
        loss = F.cross_entropy(logits.view(-1, vocab_size), batch[:, 1:].reshape(-1))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

Creating a Dataset

import numpy as np
from zerobatch.batch_writer import write_batch_file

# tokens: numpy array of shape (num_records, seq_len), dtype uint32
# e.g., from tokenizing a text corpus
tokens = np.load("tokenized_data.npy").astype(np.uint32)

write_batch_file(
    output_path="dataset.batch",
    tokens_array=tokens,
    batch_size=32,
    seq_len=512,
    seed=42,
)

Limitations and Trade-offs

  1. No GPU training speedup — On modern GPUs (A100, H100), the forward/backward pass dominates step time. ZeroBatch's data loading optimization is invisible. Both loaders deliver ~130ms/step on an A100. The throughput advantage only matters in CPU-bound, I/O-bound, or memory-constrained scenarios.

  2. Fixed batch size — Batch size is baked into the file at write time. Changing batch size requires re-writing the file. Use v1 if you need runtime-configurable batch sizes.

  3. Fixed sequence length — All records must be the same length (padded/truncated during tokenization). Variable-length sequences must be padded to the maximum.

  4. No per-record access — The file is organized by batch, not by record. You cannot efficiently read record #42 without knowing which batch contains it.

  5. uint32 → int64 copy — PyTorch requires int64 for token IDs (embedding lookups). The type conversion is the only per-batch memory allocation. If PyTorch adds uint32 embedding support, this copy can be eliminated.

  6. Block-shuffle is not full shuffle — Records within the same batch are always served together. True per-sample shuffling requires v1 or re-writing the file with a different seed each epoch. In practice, block-shuffle provides sufficient randomness for LLM training — the model sees different batch orderings each epoch, and the initial write already shuffled records into batches.

  7. Single-file format — The entire dataset is one file. For multi-terabyte datasets, a sharded version (multiple .batch files with a manifest) would be needed.

  8. ~2x storage overhead — The page-aligned batch layout uses approximately 2x the disk space of a raw memmap file for the same data.


Source Files

File                               Description
─────────────────────────────────  ──────────────────────────────────────────────────────────────
python/zerobatch/batch_writer.py   Writer — converts token arrays to LLMBATCH format
python/zerobatch/batch_loader.py   Reader — BatchFileReader (mmap) + BatchNativeLoader (iterator)
benchmarks/zerobatch_v2_loader.py  Benchmark adapter