This document summarizes the model families used by src/scripts/train.py.
The training script supports:
- PointGlobalMixedViT (Transformer1-style runs)
- OmniLearned PET2 (--use-omnilearned small|medium|large)
- CondOnlyMLP baseline (--cond_only)
Tasks are either:
- Regression (e.g., -E-available-no-muon)
- Classification (e.g., -npi2)
Implementation: src/models/vit.py (PointGlobalMixedViT).
Per event:
- point features (continuous + optional PID embedding)
- 2D coordinates for positional encoding
- optional global/event-level features
A CLS token is used for readout, and an event token can be added for global conditioning.
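As a sketch of this readout mechanism, the sequence fed to the encoder is the per-particle tokens with a CLS token (and optionally an event token carrying global conditioning) prepended. The NumPy stand-in below is illustrative only; the names are not taken from src/models/vit.py:

```python
import numpy as np

def add_readout_tokens(point_tokens, cls_token, event_token=None):
    """Prepend a CLS token (and optionally an event token carrying
    global conditioning) to the per-particle token sequence."""
    tokens = [cls_token[None, :], point_tokens]
    if event_token is not None:
        tokens.insert(1, event_token[None, :])
    return np.concatenate(tokens, axis=0)

d_model = 128
points = np.random.randn(50, d_model)   # 50 particles in one event
cls = np.zeros(d_model)                 # a learned parameter in practice
event = np.random.randn(d_model)        # projected global/event features

seq = add_readout_tokens(points, cls, event)
print(seq.shape)  # (52, 128): CLS + event token + 50 particle tokens
```

After the encoder runs, the representation at the CLS position is used as the event-level readout.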
The model uses multi-head self-attention via PyTorch scaled dot-product attention.
For one head:
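Assuming the standard definitions (this is the form computed by torch.nn.functional.scaled_dot_product_attention):

$$
\mathrm{head}_i(X) = \operatorname{softmax}\!\left(\frac{Q_i K_i^{\top}}{\sqrt{d_k}}\right) V_i,
\qquad Q_i = X W_i^{Q},\; K_i = X W_i^{K},\; V_i = X W_i^{V},
$$

with $d_k = d_{\text{model}} / n_{\text{heads}}$.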
Multi-head output is concatenated and projected:
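In the standard form, with $h = n_{\text{heads}}$:

$$
\mathrm{MultiHead}(X) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^{O}.
$$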
Each encoder block is pre-norm residual attention + MLP:
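In the usual pre-norm formulation this reads:

$$
x \leftarrow x + \mathrm{Attn}(\mathrm{LN}(x)), \qquad
x \leftarrow x + \mathrm{MLP}(\mathrm{LN}(x)),
$$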
with MLP hidden size approximately mlp_ratio * d_model.
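A minimal NumPy sketch of one attention head under these definitions (illustrative only; the actual model uses PyTorch's fused scaled dot-product attention):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def one_head(X, Wq, Wk, Wv):
    """Scaled dot-product attention for a single head."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    d_k = Q.shape[-1]
    weights = softmax(Q @ K.T / np.sqrt(d_k))  # (seq, seq), rows sum to 1
    return weights @ V

rng = np.random.default_rng(0)
d_model, d_k, seq_len = 128, 16, 52
X = rng.normal(size=(seq_len, d_model))
Wq, Wk, Wv = (rng.normal(size=(d_model, d_k)) * 0.05 for _ in range(3))
out = one_head(X, Wq, Wk, Wv)
print(out.shape)  # (52, 16)
```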
src/jobs/submit_train_jobs.py uses (for Transformer1):
- d_model = 128
- depth = 4
- n_heads = 8
- dropout = 0.0
- attn_dropout = 0.0
Implementation: src/models/omnilearned/network.py, created in create_omnilearned_model.
Enabled via:
- --use-omnilearned small
- --use-omnilearned medium
- --use-omnilearned large

Parameter presets are defined in src/models/omnilearned/utils.py:
- small: base_dim 128, num_heads 8, num_transformers 8
- medium: base_dim 512, num_heads 16, num_transformers 12
- large: base_dim 1024, num_heads 32, num_transformers 28
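For reference, the presets above restated as a plain mapping (a summary of this document, not the actual contents of src/models/omnilearned/utils.py):

```python
# OmniLearned PET2 size presets (values as documented above)
OMNILEARNED_PRESETS = {
    "small":  {"base_dim": 128,  "num_heads": 8,  "num_transformers": 8},
    "medium": {"base_dim": 512,  "num_heads": 16, "num_transformers": 12},
    "large":  {"base_dim": 1024, "num_heads": 32, "num_transformers": 28},
}
print(OMNILEARNED_PRESETS["medium"]["base_dim"])  # 512
```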
In src/jobs/submit_train_jobs.py, the model identifiers in the for model in [...] loop map to CLI flags as:
- OLS -> OmniLearned small + pretrained (--use-pretrained pretrain_s)
- OLS_RW -> OmniLearned small from random initialization (--use-omnilearned small with no --use-pretrained)
- OLM -> OmniLearned medium + pretrained (--use-pretrained pretrain_m)
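This mapping can be summarized as a dict (illustrative restatement; the actual loop lives in src/jobs/submit_train_jobs.py):

```python
# Model identifier -> extra CLI flags passed to src/scripts/train.py
MODEL_FLAGS = {
    "OLS":    ["--use-omnilearned", "small", "--use-pretrained", "pretrain_s"],
    "OLS_RW": ["--use-omnilearned", "small"],  # random init: no --use-pretrained
    "OLM":    ["--use-omnilearned", "medium", "--use-pretrained", "pretrain_m"],
}
```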
Implementation: CondOnlyMLP in src/scripts/train.py.
Used with --cond_only, this model ignores particle tokens and uses only global conditioning features. It is an MLP built from residual blocks, followed by a final prediction head.
For regression, the output can be constrained positive using softplus.
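A NumPy sketch of the idea (illustrative shapes, names, and block structure; not the actual CondOnlyMLP code):

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def residual_block(h, W1, W2):
    """h <- h + W2(relu(W1 h)): one residual MLP block (illustrative form)."""
    return h + relu(h @ W1) @ W2

def softplus(x):
    return np.log1p(np.exp(x))  # smooth and strictly positive

rng = np.random.default_rng(0)
d = 64
cond = rng.normal(size=d)                  # global conditioning features only
W1 = rng.normal(size=(d, d)) * 0.05
W2 = rng.normal(size=(d, d)) * 0.05
w_out = rng.normal(size=d)

h = residual_block(cond, W1, W2)
y = softplus(h @ w_out)                    # positivity-constrained regression output
assert y > 0
```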
The final head is task-dependent:
- Regression: linear head to one scalar
- Classification: linear head to N_classes logits
Conceptually, the head computes y = W h + b, where h is the readout (CLS-token) representation and W projects it to a single scalar for regression or to N_classes logits for classification.