
Optimize dataloader for faster training #2

Draft

gregorkrz wants to merge 1 commit into main from cursor/optimize-dataloader-84fb

Conversation

@gregorkrz
Owner

Summary

Nine targeted optimizations to the data loading and preprocessing pipeline that reduce per-batch overhead and improve GPU utilization during training.

Optimizations

| # | Change | File | Impact |
|---|--------|------|--------|
| 1 | float64 → float32 tensors | functions_data.py | 2× less memory bandwidth; faster GPU math (especially on consumer GPUs with slow FP64) |
| 2 | Vectorize batch_idx construction | functions_data.py | Replace Python for-loop over events with torch.repeat_interleave |
| 3 | Vectorize PID one-hot encoding | functions_data.py | Replace O(n_particles × n_pids) Python loops with dict lookup + scatter |
| 4 | Vectorize renumber_clusters | functions_data.py | Use torch.unique(return_inverse=True): one kernel instead of a Python loop |
| 5 | Replace debug print() with logging | functions_data.py | Eliminates per-batch stdout flush overhead |
| 6 | Binary search in EventDatasetCollection.get_idx | dataset.py | O(log n) np.searchsorted vs O(n) linear scan |
| 7 | Skip torch.cat for single-element batches | functions_data.py | Avoids unnecessary tensor copy in concat_event_collection |
| 8 | Vectorize add_batch_number | functions_data.py | torch.cumsum + repeat_interleave instead of Python loop + list append |
| 9 | Improved DataLoader config | train_utils.py, parser_args.py | shuffle=True for training, adaptive prefetch_factor, default num_workers raised from 1 to 4 |

Why these matter

The training loop spends significant time in CPU-side data preprocessing (get_batch, concat_events, tensor construction) between GPU forward/backward passes. These optimizations target the hottest code paths:

  • get_batch runs on every training step — vectorizing batch_idx, PID encoding, and removing print statements directly reduces wall-clock time per step
  • float32 halves memory traffic for all particle tensors through the entire pipeline (CPU→GPU transfer, collation, feature construction)
  • renumber_clusters is called on every batch during filtering — the vectorized version avoids Python-level iteration and temporary tensor allocation
  • DataLoader workers + prefetch overlap CPU preprocessing with GPU compute

Testing

  • All 88 unit tests pass (including updated tests for the float32 change)
  • Gradio demo inference verified end-to-end (same results: 2 model jets, 9 AK jets on QCD event #15)

Performance optimizations to the data loading pipeline:

1. Switch to_tensor from float64 to float32
   - Halves memory bandwidth for all particle tensors
   - Faster GPU operations (especially on consumer GPUs with slow FP64)
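   A minimal sketch of the dtype change (the `to_tensor` name is from the PR; the exact signature is an assumption):

   ```python
   import numpy as np
   import torch

   def to_tensor(arr: np.ndarray) -> torch.Tensor:
       # float32 halves memory traffic versus the previous float64
       # default, and consumer GPUs run FP32 far faster than FP64.
       return torch.as_tensor(arr, dtype=torch.float32)

   feats = to_tensor(np.random.rand(4, 3))  # np.random.rand yields float64
   ```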

2. Vectorize batch_idx construction in get_batch
   - Replace Python for-loop with torch.repeat_interleave
   - O(1) tensor ops instead of O(n_events) Python iterations
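   The vectorized construction can be sketched like this, assuming per-event particle counts are available as a tensor (the helper name is illustrative):

   ```python
   import torch

   def build_batch_idx(counts: torch.Tensor) -> torch.Tensor:
       # One kernel call: event i contributes counts[i] copies of the
       # value i, replacing a Python loop that appended per event.
       return torch.repeat_interleave(torch.arange(len(counts)), counts)

   batch_idx = build_batch_idx(torch.tensor([2, 3, 1]))  # tensor([0, 0, 1, 1, 1, 2])
   ```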

3. Vectorize PID one-hot encoding
   - Replace O(n_particles * n_pids) Python loops with dict lookup + scatter
   - Single-pass vectorized assignment
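   A sketch of the dict-lookup-plus-scatter pattern; the function name and PID list are hypothetical placeholders, not the PR's actual code:

   ```python
   import torch

   def one_hot_pids(pids: torch.Tensor, pid_list: list) -> torch.Tensor:
       # Map each raw PID to its class index via a dict lookup, then
       # scatter ones into the one-hot matrix in a single pass instead
       # of nesting loops over particles and PID classes.
       pid_to_idx = {pid: i for i, pid in enumerate(pid_list)}
       idx = torch.tensor([pid_to_idx[int(p)] for p in pids])
       out = torch.zeros(len(pids), len(pid_list))
       out.scatter_(1, idx.unsqueeze(1), 1.0)
       return out

   enc = one_hot_pids(torch.tensor([211, 22, 211]), [22, 211, 11])
   ```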

4. Vectorize renumber_clusters
   - Use torch.unique(return_inverse=True) instead of Python loop + index table
   - Eliminates temporary mapping tensor allocation
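   The core of the vectorized relabeling, as a self-contained sketch:

   ```python
   import torch

   def renumber_clusters(labels: torch.Tensor) -> torch.Tensor:
       # return_inverse gives each element's index into the sorted
       # unique values: a dense 0..K-1 relabeling in one kernel,
       # with no Python loop or explicit mapping table.
       _, inverse = torch.unique(labels, return_inverse=True)
       return inverse

   new = renumber_clusters(torch.tensor([7, 7, 3, 9, 3]))  # tensor([1, 1, 0, 2, 0])
   ```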

5. Replace debug prints with logging in get_batch
   - Removes per-batch print() calls that flush stdout on every iteration
   - Switched to Python logging module (debug/warning levels)
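   The pattern, sketched with an illustrative logger and function name:

   ```python
   import logging

   logger = logging.getLogger("functions_data")

   def log_batch_stats(n_events: int) -> None:
       # A DEBUG record is skipped entirely unless the level is enabled,
       # unlike print(), which formats and flushes stdout every batch.
       logger.debug("batch assembled: %d events", n_events)
   ```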

6. Optimize EventDatasetCollection.get_idx with np.searchsorted
   - O(log n) binary search vs O(n) linear scan over dataset thresholds
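   A sketch of the lookup, assuming the collection maps a global event index to a (dataset, local index) pair via cumulative dataset lengths (the internals here are an assumption; only `get_idx` is named in the PR):

   ```python
   import numpy as np

   class EventDatasetCollection:
       def __init__(self, lengths):
           # Cumulative end offsets, e.g. lengths [5, 3, 4] -> [5, 8, 12].
           self.cum = np.cumsum(lengths)

       def get_idx(self, i):
           # O(log n) binary search over the cumulative thresholds
           # instead of an O(n) linear scan.
           d = int(np.searchsorted(self.cum, i, side="right"))
           local = i if d == 0 else i - int(self.cum[d - 1])
           return d, local

   coll = EventDatasetCollection([5, 3, 4])
   ```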

7. Optimize concat_event_collection
   - Skip torch.cat for single-element batches (avoids unnecessary copy)
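   The single-element fast path is the whole trick (sketched with an illustrative name):

   ```python
   import torch

   def concat_tensors(tensors: list) -> torch.Tensor:
       # torch.cat copies its input even when given one tensor;
       # returning it directly avoids that allocation for
       # single-element batches.
       if len(tensors) == 1:
           return tensors[0]
       return torch.cat(tensors)
   ```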

8. Vectorize add_batch_number
   - Use torch.cumsum + torch.repeat_interleave instead of Python loop
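   One plausible reading of the cumsum + repeat_interleave pattern, assuming the function offsets each event's local cluster indices so they are globally unique across the batch (the semantics here are an assumption; the PR only names the ops used):

   ```python
   import torch

   def add_batch_number(cluster_idx: list) -> torch.Tensor:
       # Per-event offsets come from an exclusive cumsum of cluster
       # counts, expanded per-particle with repeat_interleave, instead
       # of a Python loop appending offset tensors to a list.
       n_clusters = torch.tensor([int(c.max()) + 1 if len(c) else 0 for c in cluster_idx])
       offsets = torch.cumsum(n_clusters, 0) - n_clusters  # exclusive cumsum
       sizes = torch.tensor([len(c) for c in cluster_idx])
       return torch.cat(cluster_idx) + torch.repeat_interleave(offsets, sizes)
   ```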

9. Improve DataLoader configuration
   - Enable shuffle=True for training (better convergence)
   - Add adaptive prefetch_factor based on batch_size/num_workers
   - Increase default num_workers from 1 to 4
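   A sketch of the configuration; the adaptive rule shown is a hypothetical stand-in, since the PR does not spell out its formula:

   ```python
   from torch.utils.data import DataLoader

   def make_loader(dataset, batch_size: int, num_workers: int = 4, train: bool = True):
       # prefetch_factor is only valid with worker processes; this
       # illustrative rule keeps at least 2 batches queued per worker.
       kwargs = {}
       if num_workers > 0:
           kwargs["prefetch_factor"] = max(2, batch_size // (num_workers * 8))
       return DataLoader(
           dataset,
           batch_size=batch_size,
           shuffle=train,  # shuffle only during training
           num_workers=num_workers,
           **kwargs,
       )
   ```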

Also includes test updates for the float32 change.

Co-authored-by: Gregor Kržmanc <gregor.krzmanc@cern.ch>