Project
Three-class classification of gravitational microlensing events using a CNN-GRU architecture optimized for the Nancy Grace Roman Space Telescope:
- Class 0: Flat (no lensing)
- Class 1: PSPL (Point Source Point Lens)
- Class 2: Binary lens (planetary or stellar companion)
The model uses depthwise separable convolutions, flash attention pooling, hierarchical classification with separate BCE losses, and processes variable-length sequences with causal masking. Temporal encoding uses observation intervals (Δt) rather than absolute timestamps to prevent data leakage.
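The Δt encoding can be illustrated with a short NumPy sketch (the timestamp values below are made up for illustration, not the actual Roman cadence):

```python
import numpy as np

# Illustrative timestamps in days (not the real Roman sampling)
t = np.array([0.0, 0.0104, 0.0208, 0.0417, 0.0521])

# Observation intervals: delta_t[i] = t[i] - t[i-1], with delta_t[0] = 0.
# The model only sees these relative gaps, so the absolute position of an
# event within the season cannot leak into the features.
delta_t = np.diff(t, prepend=t[0])
```

Because the intervals are translation-invariant, shifting an entire light curve in time leaves the model input unchanged.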
```bash
# Clone repository
git clone https://github.com/kunalb541/Thesis.git
cd Thesis

# Create conda environment (PyTorch CUDA 12.1 included)
conda env create -f environment.yml
conda activate microlens

# Install Flash Attention (optional, 2-3x attention speedup)
pip install flash-attn --no-build-isolation

# Verify installation
python -c "import torch; print(f'PyTorch: {torch.__version__}'); print(f'CUDA: {torch.cuda.is_available()}')"
python -c "import VBBinaryLensing; print('VBBinaryLensing: OK')"
python -c "from flash_attn import flash_attn_func; print('Flash Attention: OK')"
```

Note: Flash Attention is optional. The model automatically falls back to PyTorch's `F.scaled_dot_product_attention` when Flash Attention is unavailable. Performance impact: only attention pooling is affected (~2-3x slower), with minimal overall impact.
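The fallback is essentially an import guard; a minimal sketch of the pattern (not the repository's exact code — `pooled_attention` is a hypothetical name):

```python
# Import guard: prefer Flash Attention, fall back to PyTorch SDPA.
try:
    from flash_attn import flash_attn_func
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False

def pooled_attention(q, k, v):
    """Dispatch to whichever attention kernel is available.

    Caveat: flash_attn_func expects [batch, seqlen, heads, dim] tensors,
    while F.scaled_dot_product_attention expects [batch, heads, seqlen, dim],
    hence the transposes on the fallback path.
    """
    if HAS_FLASH_ATTN:
        return flash_attn_func(q, k, v)
    import torch.nn.functional as F
    return F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    ).transpose(1, 2)
```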
Flash Attention for attention pooling requires:
- NVIDIA GPU with compute capability >= 7.0 (V100, A100, H100, RTX 3090/4090)
- CUDA 11.6+
- PyTorch 2.0+
Standard Installation:

```bash
pip install packaging ninja
pip install flash-attn --no-build-isolation
```

If standard installation fails, try building from source:

```bash
pip install packaging ninja
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
python setup.py install
```

On HPC clusters (e.g., bwForCluster):

```bash
module load devel/cuda/12.1
module load compiler/gnu/12.1
pip install flash-attn --no-build-isolation
```

The `environment.yml` includes PyTorch CUDA 12.1 by default.
For different hardware, edit `environment.yml` before creating the environment.

AMD GPUs (ROCm 6.0) - MI300 series:

```bash
# After creating environment:
pip install torch==2.2.0 --index-url https://download.pytorch.org/whl/rocm6.0
```

NVIDIA GPUs (CUDA 11.8) - older cards:

```yaml
# In environment.yml, replace PyTorch lines with:
- pytorch::pytorch-cuda=11.8
```

CPU only:

```yaml
# In environment.yml, replace PyTorch lines with:
- pytorch::cpuonly
```

Validate the complete pipeline:
```bash
cd code

# 1. Generate test dataset (300 events)
python simulate.py \
    --n_flat 100 \
    --n_pspl 100 \
    --n_binary 100 \
    --binary_preset baseline \
    --output ../data/raw/test.h5

# 2. Train model (5 epochs for validation)
python train.py \
    --data ../data/raw/test.h5 \
    --epochs 5 \
    --batch-size 32 \
    --hierarchical \
    --use-aux-head

# 3. Evaluate (experiment name is auto-generated by train.py)
ls ../results/   # Find your experiment name
python evaluate.py --experiment-name d64_l4_hier --data ../data/raw/test.h5
```

Note: the `--experiment-name` argument supports partial matching, so you don't need the exact directory name.
```bash
python simulate.py \
    --n_flat 100000 \
    --n_pspl 100000 \
    --n_binary 100000 \
    --binary_preset baseline \
    --output ../data/raw/baseline.h5 \
    --num_workers 32 \
    --seed 42 \
    --oversample 1.3
```

Binary Lens Presets:
| Preset | Mass Ratio (q) | Separation (s) | Impact (u₀) | Caustic Required | Description |
|---|---|---|---|---|---|
| `distinct` | 0.1 - 1.0 | 0.8 - 1.2 | 0.001 - 0.3 | Yes | Forced caustics |
| `planetary` | 10⁻⁴ - 10⁻² | 0.6 - 1.6 | 0.001 - 0.3 | No | Exoplanet regime |
| `stellar` | 0.1 - 1.0 | 0.3 - 3.0 | 0.001 - 0.5 | No | Binary star systems |
| `baseline` | 10⁻⁴ - 1.0 | 0.3 - 3.0 | 0.001 - 1.0 | No | Full parameter space |
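For intuition, sampling from a preset might look like the sketch below. This is an illustration only: `PRESETS`, `sample_binary_params`, and the log-uniform draw for q are assumptions transcribed from the table above, not `simulate.py`'s actual code.

```python
import numpy as np

# Ranges transcribed from the preset table (illustrative only)
PRESETS = {
    "planetary": {"q": (1e-4, 1e-2), "s": (0.6, 1.6), "u0": (0.001, 0.3)},
    "baseline":  {"q": (1e-4, 1.0),  "s": (0.3, 3.0), "u0": (0.001, 1.0)},
}

def sample_binary_params(preset, rng):
    p = PRESETS[preset]
    # q spans orders of magnitude, so sample it log-uniformly (assumption)
    q = 10.0 ** rng.uniform(np.log10(p["q"][0]), np.log10(p["q"][1]))
    s = rng.uniform(*p["s"])
    u0 = rng.uniform(*p["u0"])
    return q, s, u0

q, s, u0 = sample_binary_params("planetary", np.random.default_rng(42))
```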
Output: `<output>.h5`

- Core datasets: `flux`, `delta_t`, `labels`, `timestamps`
- Parameters: `params_flat`, `params_pspl`, `params_binary` (structured arrays)
- Metadata: mission duration, cadence, seed, version, etc.

Data Format Note:

- `flux` contains magnifications (A), not physical flux in janskys
- A = 1.0 is baseline (no magnification)
- A > 1.0 means brighter (magnified)
- A = 0.0 indicates masked/missing observations
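A quick sanity check of this convention on a toy array — because A = 0.0 marks missing points, any statistic must be computed over the valid-observation mask:

```python
import numpy as np

# Toy magnification curve with two masked/missing observations (A = 0.0)
flux = np.array([1.00, 1.02, 1.35, 0.0, 1.01, 0.0])

mask = flux > 0.0              # valid observations only
length = int(mask.sum())       # effective sequence length
peak = flux[mask].max()        # peak magnification, ignoring padding
```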
```bash
python train.py \
    --data ../data/raw/baseline.h5 \
    --output ../results \
    --epochs 50 \
    --batch-size 64 \
    --lr 5e-4 \
    --weight-decay 1e-4 \
    --warmup-epochs 3 \
    --d-model 64 \
    --n-layers 4 \
    --dropout 0.3 \
    --window-size 5 \
    --hierarchical \
    --use-aux-head \
    --attention-pooling \
    --stage1-weight 1.0 \
    --stage2-weight 1.0 \
    --aux-weight 0.5
```

Using train.sbatch (Recommended):
```bash
# Submit the complete pipeline
sbatch train.sbatch
```

The train.sbatch script handles:
- Data copying to /dev/shm for fast RAM-backed reads
- Multi-node distributed training with torchrun (40 GPUs)
- Automatic checkpoint resumption on timeout
- Auto-submission of continuation jobs
- Automatic evaluation job submission after training completes
Manual SLURM Execution:
```bash
salloc --partition=gpu_a100_short --nodes=10 --gres=gpu:4 --exclusive --time=00:30:00
cd ~/Thesis/code
conda activate microlens

export MASTER_ADDR=$(scontrol show hostnames "$SLURM_NODELIST" | head -n 1)
export MASTER_PORT=29500
export LOCAL_WORLD_SIZE=4

srun torchrun \
    --nnodes=$SLURM_NNODES \
    --nproc-per-node=4 \
    --rdzv-backend=c10d \
    --rdzv-endpoint="${MASTER_ADDR}:${MASTER_PORT}" \
    --rdzv-id="train-$(date +%s)" \
    train.py \
    --data /tmp/train.h5 \
    --output ../results \
    --epochs 300 \
    --batch-size 512 \
    --accumulation-steps 2 \
    --lr 5e-4 \
    --hierarchical \
    --use-aux-head \
    --attention-pooling \
    --save-every 3
```

The hierarchical mode uses a two-stage classification approach:
```
Stage 1: Flat vs Non-Flat (Binary Cross-Entropy)
        ↓
Stage 2: PSPL vs Binary (Binary Cross-Entropy, only for Non-Flat)
        ↓
Auxiliary: 3-class Cross-Entropy (gradient stability)
```

Loss Function:

```
L_total = λ₁ × L_stage1 + λ₂ × L_stage2 + λ_aux × L_auxiliary
```
Key Arguments:
- `--hierarchical`: Enable hierarchical classification
- `--use-aux-head`: Add auxiliary 3-class head for gradient stability
- `--stage1-weight`: Weight for Stage 1 loss (default: 1.0)
- `--stage2-weight`: Weight for Stage 2 loss (default: 1.0)
- `--aux-weight`: Weight for auxiliary loss (default: 0.5)
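A minimal NumPy sketch of the combined loss for a single event (illustrative only; the actual implementation lives in `train.py` and operates on batched PyTorch tensors):

```python
import numpy as np

def bce(logit, target):
    # Binary cross-entropy on a single logit
    p = 1.0 / (1.0 + np.exp(-logit))
    return -(target * np.log(p) + (1 - target) * np.log(1 - p))

def hierarchical_loss(s1_logit, s2_logit, aux_logits, label,
                      w1=1.0, w2=1.0, w_aux=0.5):
    # Stage 1: flat (0) vs non-flat (1 or 2)
    l1 = bce(s1_logit, float(label != 0))
    # Stage 2: PSPL (1) vs binary (2), computed only for non-flat events
    l2 = bce(s2_logit, float(label == 2)) if label != 0 else 0.0
    # Auxiliary 3-class cross-entropy for gradient stability
    z = aux_logits - aux_logits.max()
    log_softmax = z - np.log(np.exp(z).sum())
    return w1 * l1 + w2 * l2 - w_aux * log_softmax[label]

loss = hierarchical_loss(0.0, 0.0, np.zeros(3), label=1)
```

Note how the Stage 2 term is skipped for flat events, so its gradients never push on light curves that carry no lensing signal.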
Training automatically resumes from checkpoints:
```bash
# Resume from latest checkpoint
python train.py \
    --data baseline.h5 \
    --epochs 300 \
    --resume ../results/checkpoints/d64_l4_hier_*/checkpoints/checkpoint_latest.pt
```

Training Outputs (saved in `results/<experiment>/`):
| File | Description |
|---|---|
| `best.pt` | Best checkpoint (highest validation accuracy) |
| `epoch_NNN.pt` | Periodic checkpoints |
| `checkpoints/checkpoint_latest.pt` | Most recent checkpoint (for resumption) |
| `config.json` | Full configuration and hyperparameters |
```bash
python evaluate.py \
    --experiment-name d64_l4_hier \
    --data ../data/test/test.h5 \
    --batch-size 512 \
    --early-detection \
    --n-evolution-per-type 5 \
    --colorblind-safe
```

Note: train.py auto-generates names like `d64_l4_hier_20250101_143022`. Partial matching works:

```bash
python evaluate.py --experiment-name d64_l4 --data test.h5
```

Evaluation Outputs (saved in `results/<experiment>/eval_<dataset>_<timestamp>/`):
| File | Description |
|---|---|
| `evaluation_summary.json` | Overall metrics and configuration |
| `confusion_matrix.png` | Normalized confusion matrix heatmap |
| `roc_curves.png` | One-vs-rest ROC curves with AUC scores |
| `calibration.png` | Reliability diagram with confidence histograms |
| `u0_dependency.png` | Binary accuracy vs. impact parameter |
| `temporal_bias_check.png` | t₀ distribution comparison (KS-test) |
| `early_detection_curve.png` | Accuracy vs. observation completeness |
| `evolution_<class>_<idx>.png` | Probability evolution (3-panel) |
| `example_light_curves.png` | Grid of example classifications |
| `per_class_metrics.png` | Precision/recall/F1 bar chart |
Metrics Computed:
- Overall: accuracy, precision, recall, F1-score
- Per-class: precision, recall, F1 for Flat, PSPL, Binary
- AUROC: macro and weighted averages
- Calibration reliability (ECE)
- Bootstrap confidence intervals
- Early detection at [10%, 20%, 30%, 50%, 70%, 100%] completeness
- Impact parameter (u₀) dependency for binary events
- Temporal bias check (Kolmogorov-Smirnov test)
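As one concrete example of these metrics, expected calibration error (ECE) is the bin-weighted gap between confidence and accuracy. The sketch below uses the standard equal-width binning definition, which is not necessarily `evaluate.py`'s exact binning scheme:

```python
import numpy as np

def expected_calibration_error(confidence, correct, n_bins=10):
    """Weighted average |accuracy - confidence| over equal-width confidence bins."""
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidence > lo) & (confidence <= hi)
        if in_bin.any():
            gap = abs(correct[in_bin].mean() - confidence[in_bin].mean())
            ece += in_bin.mean() * gap
    return ece

# Perfectly calibrated toy predictions: 75% confidence, 75% accuracy
ece = expected_calibration_error(np.full(4, 0.75), np.array([1, 1, 1, 0]))
```

A well-calibrated classifier yields ECE near 0; an overconfident one (e.g., 90% confidence at 50% accuracy) yields a large value.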
```
Input: Flux [B, N], Time Intervals Δt [B, N]
        │
        ├─ Input Projection: Linear(2 → d_model)
        │
Feature Extraction: Depthwise Separable Conv (2 blocks, causal)
        │     ├─ Block 1: kernel=5, dilation=1
        │     └─ Block 2: kernel=5, dilation=2 (multi-scale)
        │
Recurrent Core: Stacked GRU (CuDNN-fused)
        │     └─ n_layers with dropout between layers
        │
Layer Normalization + Residual Connection
        │
Temporal Pooling:
        │     ├─ Attention Pooling (multi-head, learnable query)
        │     └─ OR Mean Pooling (masked by sequence length)
        │
┌─────────────────────────────────────────────────────┐
│ Hierarchical Classification (v4.0.0)                │
│                                                     │
│ Shared Trunk: Linear + LayerNorm + SiLU + Dropout   │
│        │                                            │
│        ├─ Stage 1 Head: Flat vs Non-Flat (BCE)      │
│        ├─ Stage 2 Head: PSPL vs Binary (BCE)        │
│        └─ Auxiliary Head: 3-class (CE)              │
└─────────────────────────────────────────────────────┘
        │
Output: Log-Probabilities [B, 3]
```
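The masked mean-pooling branch can be sketched in NumPy (the repository uses PyTorch, but the masking logic is identical): pad positions are zeroed out and the sum is divided by each sequence's true length rather than by N.

```python
import numpy as np

rng = np.random.default_rng(0)
h = rng.standard_normal((2, 5, 3))   # hidden states [B, N, D]
lengths = np.array([5, 3])           # valid observations per sequence

# Build a [B, N] validity mask from lengths, then average only valid steps
mask = np.arange(h.shape[1])[None, :] < lengths[:, None]
pooled = (h * mask[:, :, None]).sum(axis=1) / lengths[:, None]
```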
Default Configuration:

| Parameter | Value | Description |
|---|---|---|
| `d_model` | 64 | Hidden dimension |
| `n_layers` | 4 | GRU layers |
| `dropout` | 0.3 | Dropout rate |
| `window_size` | 5 | Conv kernel size |
| `max_seq_len` | 7000 | Maximum sequence length |
| `n_classes` | 3 | Output classes |
| `hierarchical` | True | Two-stage classification |
| `attention_pooling` | True | Multi-head attention pooling |
| `head_init_std` | 0.15 | Classification head initialization |
Parameter Count: ~100-500K (varies with d_model and n_layers)
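The GRU core dominates that count. A back-of-the-envelope check using PyTorch's GRU parameter layout (per layer: `weight_ih` of shape 3H×I, `weight_hh` of shape 3H×H, and two 3H bias vectors) — a rough estimate, not an exact count for this model's full architecture:

```python
def gru_param_count(input_size, hidden_size, n_layers):
    # Per layer: weight_ih (3H x I), weight_hh (3H x H), bias_ih + bias_hh (3H each)
    total = 0
    for layer in range(n_layers):
        i = input_size if layer == 0 else hidden_size
        total += 3 * hidden_size * (i + hidden_size + 2)
    return total

# Default config: d_model=64 inputs, 64 hidden units, 4 layers
n = gru_param_count(64, 64, 4)   # ~100K parameters for the GRU alone
```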
```
Thesis/
├── code/
│   ├── simulate.py      # v4.0.0 - Data generation
│   ├── train.py         # v4.1.0 - Distributed training
│   ├── model.py         # v4.0.0 - Neural network architecture
│   ├── evaluate.py      # v4.1.0 - Evaluation suite
│   ├── train.sbatch     # v4.1.0 - Multi-GPU training job
│   ├── data.sbatch      # v4.0.0 - Data generation job
│   └── logs/
│
├── data/
│   ├── raw/             # Training datasets
│   └── test/            # Test datasets
│
├── results/
│   └── checkpoints/     # Model checkpoints and evaluations
│
├── .gitignore
├── environment.yml
├── README.md
└── LICENSE
```
train.py & evaluate.py OPTIMIZATION:
- Fixed /dev/shm cleanup race condition (barrier before delete)
- Broadcast optimization (rank0 computes stats/indices, broadcasts to others)
- Subsampling now returns file indices for correct params extraction
- torch.serialization import guarded to prevent crashes
- torch.load() wrapped for weights_only compatibility
- O(N²) params extraction replaced with O(N) precomputed index mapping
All Components Synchronized:
- `train.py` v4.0.0 → v4.1.0
- `model.py` v4.0.0
- `simulate.py` v4.0.0
- `evaluate.py` v4.0.0 → v4.1.0
model.py CRITICAL FIXES:
- `@torch.compiler.disable` → `@torch._dynamo.disable`
- Split Q/KV projections in FlashAttentionPooling
- Fixed attention mask shape for flash attention
- Fixed frozen dataclass mutation in ModelConfig
- Fixed residual causality violation (moved before conv)
simulate.py FIXES:
- Time grid uses np.arange for exact 15-minute cadence
- u0 caustic sampling bounds check for planetary regime
- VBBinaryLensing return type normalized immediately
- MIN_MAGNIFICATION_CLIP increased from 0.1 to 0.5
train.py FIXES:
- Deterministic /dev/shm path (SLURM_JOB_ID, not PID)
- Model calls use keyword arguments (lengths=lengths)
- lengths tensor moved to device
- torch.load compatibility wrapper
- Realistic Roman 72-day observing season
- Global m_base array saved in HDF5
- SharedRAMLensingDataset eliminates train/val double loading
- Hierarchical classification with separate BCE losses
- Auxiliary 3-class head for gradient stability
- Per-class recall logging during training
- All magic numbers replaced with named constants
Observational Configuration (Roman-like):
| Parameter | Value | Description |
|---|---|---|
| Temporal sampling | 15 min | Roman Galactic Bulge cadence |
| Missing observations | 5% | Uniform random gaps |
| Photometric error | Realistic | Roman F146 detector model |
| Season duration | 72 days | Single Roman observing season |
| Max sequence length | ~6900 | Observations per season |
| Source magnitude | 18-24 AB | Baseline brightness |
Microlensing Parameters:
| Parameter | Range | Description |
|---|---|---|
| Einstein timescale (t_E) | 5-30 days | Event duration |
| Peak time (t₀) | 10%-90% of season | Central time |
| Impact parameter (u₀) | 0.001-1.0 | Closest approach |
| Binary separation (s) | 0.1-3.0 | Einstein radii |
| Mass ratio (q) | 10⁻⁴ - 1.0 | Secondary/primary mass |
| Source radius (ρ) | 10⁻³ - 0.05 | Einstein radii |
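These parameters feed the standard Paczyński point-lens magnification for the PSPL class (the binary-lens class is computed with VBBinaryLensing instead). A sketch of the PSPL light-curve model, using the well-known formula A(u) = (u² + 2) / (u √(u² + 4)):

```python
import numpy as np

def pspl_magnification(t, t0, tE, u0):
    """Paczynski point-source point-lens magnification A(t)."""
    u = np.sqrt(u0**2 + ((t - t0) / tE) ** 2)   # lens-source separation
    return (u**2 + 2.0) / (u * np.sqrt(u**2 + 4.0))

t = np.linspace(0.0, 72.0, 6900)                # one 72-day season
A = pspl_magnification(t, t0=36.0, tE=15.0, u0=0.1)
```

The curve is symmetric about t₀, peaks when the separation equals u₀, and decays toward A = 1 in the wings — exactly the baseline convention used in the HDF5 `flux` dataset.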
MIT License - See LICENSE for details.
Survey Resources:
- OGLE: http://ogle.astrouw.edu.pl/
- MOA: https://www.massey.ac.nz/~iabond/moa/
- Nancy Grace Roman: https://roman.gsfc.nasa.gov/