cache_on_device=True causes silent GPU OOM stall with large observations #22

@cweniger

Description

Problem

When cache_on_device: true is set in the estimator loop config, CachedDataLoader moves the entire training buffer to GPU as torch tensors. For large observation vectors (e.g., 1M bins), this silently exhausts GPU memory.

PyTorch/CUDA does not raise an OOM error in this case — instead the process stalls indefinitely as CUDA's memory allocator retries internally. The driver shows no error; training appears hung at "Initializing model...".

Observed behavior

  • 05_linear_regression with 1024 samples × 1M bins, cache_on_device: true
  • nvidia-smi shows 40209/40960 MiB used (A100-40GB), 0% GPU util
  • Process stalls for 15+ minutes with no error or progress

Expected behavior

Either:

  1. Detect insufficient GPU memory before attempting the transfer and fall back to CPU caching with a warning (see the pre-check sketch after this list)
  2. Only cache small tensors (theta, logprob) on GPU; keep large observation vectors on CPU
  3. Raise an explicit error if the buffer exceeds a configurable GPU memory budget
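A minimal sketch of the first option, assuming the cached buffer is a dict of CPU torch tensors; `buffer_fits_on_gpu` and the 0.8 safety margin are illustrative, not existing CachedDataLoader API:

```python
import warnings

import torch


def buffer_fits_on_gpu(buffer: dict, device: torch.device, margin: float = 0.8) -> bool:
    """Return True if the whole buffer fits within a fraction of the free GPU memory."""
    needed = sum(t.element_size() * t.nelement() for t in buffer.values())
    free_bytes, _total_bytes = torch.cuda.mem_get_info(device)
    return needed <= margin * free_bytes


# Hypothetical call site during loader setup:
# if cache_on_device and not buffer_fits_on_gpu(buffer, device):
#     warnings.warn("Training buffer does not fit in free GPU memory; caching on CPU instead.")
#     device = torch.device("cpu")
```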

Suggested approach

  • Add a max_gpu_cache_bytes threshold (or per-key size check) so large observation arrays stay on CPU
  • Wrap the .to(device) in a try/except for torch.cuda.OutOfMemoryError and fall back to CPU with a warning (combined with the size check in the sketch below)
  • Consider making cache_on_device accept a list of key patterns rather than a blanket bool
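A rough sketch combining the first two bullets; `max_gpu_cache_bytes` and `_cache_tensor` are hypothetical names, not existing CachedDataLoader attributes:

```python
import logging

import torch

log = logging.getLogger(__name__)


def _cache_tensor(t: torch.Tensor, device: torch.device,
                  max_gpu_cache_bytes: int = 2 * 1024**3) -> torch.Tensor:
    """Move a tensor to the GPU cache only if it is small enough; otherwise keep it on CPU."""
    nbytes = t.element_size() * t.nelement()
    if nbytes > max_gpu_cache_bytes:
        log.warning("Keeping %.1f MiB tensor on CPU (exceeds max_gpu_cache_bytes).",
                    nbytes / 2**20)
        return t
    try:
        return t.to(device)
    except torch.cuda.OutOfMemoryError:
        # Allocation failed: release cached blocks and fall back to CPU caching.
        torch.cuda.empty_cache()
        log.warning("GPU OOM while caching a %.1f MiB tensor; falling back to CPU.",
                    nbytes / 2**20)
        return t
```

This keeps the CPU-side torch cache as the fallback path, so a misconfigured run degrades to the ~5x-over-numpy behaviour instead of stalling.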

Context

The torch tensor cache (CachedDataLoader) was introduced to speed up batch sampling. CPU-side torch tensors already provide ~5x speedup over numpy. The GPU cache is an optional optimization that only makes sense when the buffer fits comfortably in GPU memory alongside the model.
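For reference, a small diagnostic sketch (assuming the cache is a dict of torch tensors keyed by name, e.g. theta, x, logprob) that reports the per-key footprint before any transfer is attempted:

```python
import torch


def report_buffer_footprint(buffer: dict) -> int:
    """Print the per-key size of a tensor buffer and return the total number of bytes."""
    total = 0
    for key, t in buffer.items():
        nbytes = t.element_size() * t.nelement()
        total += nbytes
        print(f"{key:>12s}  {str(tuple(t.shape)):>20s}  {nbytes / 2**20:10.1f} MiB")
    print(f"{'total':>12s}  {'':>20s}  {total / 2**20:10.1f} MiB")
    return total


# For the case above: 1024 samples x 1M bins in float32 is already ~3.8 GiB for
# the observation key alone, before the model, gradients, or temporary copies.
```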
