Gravitational Microlensing Event Classification

Python 3.10+ PyTorch 2.2+ Version 4.0.0 License: MIT


Overview

Three-class classification of gravitational microlensing events using a CNN-GRU architecture optimized for the Nancy Grace Roman Space Telescope:

  • Class 0: Flat (no lensing)
  • Class 1: PSPL (Point Source Point Lens)
  • Class 2: Binary lens (planetary or stellar companion)

The model uses depthwise separable convolutions, flash attention pooling, hierarchical classification with separate BCE losses, and processes variable-length sequences with causal masking. Temporal encoding uses observation intervals (Δt) rather than absolute timestamps to prevent data leakage.
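The Δt encoding can be sketched in a few lines — a minimal illustration (not the actual dataset code) of converting absolute timestamps into the observation intervals the model consumes, so absolute event timing never reaches the network:

```python
import numpy as np

# Hypothetical sketch: timestamps in days for five observations.
timestamps = np.array([0.0, 0.25, 0.50, 1.00, 1.25])

# Interval since the previous observation; the first interval is 0
# because there is no preceding observation.
delta_t = np.diff(timestamps, prepend=timestamps[0])
print(delta_t)  # [0.   0.25 0.25 0.5  0.25]
```

Because only the gaps between observations are kept, two events with identical shapes but different peak times produce identical inputs.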


Installation

Quick Start

# Clone repository
git clone https://github.com/kunalb541/Thesis.git
cd Thesis

# Create conda environment (PyTorch CUDA 12.1 included)
conda env create -f environment.yml
conda activate microlens

# Install Flash Attention (optional, 2-3x attention speedup)
pip install flash-attn --no-build-isolation

# Verify installation
python -c "import torch; print(f'PyTorch: {torch.__version__}'); print(f'CUDA: {torch.cuda.is_available()}')"
python -c "import VBBinaryLensing; print('VBBinaryLensing: OK')"
python -c "from flash_attn import flash_attn_func; print('Flash Attention: OK')"

Flash Attention Installation

Note: Flash Attention is optional. The model automatically falls back to PyTorch's F.scaled_dot_product_attention when Flash Attention is unavailable; only attention pooling is affected (~2-3x slower), with minimal overall impact.

Flash Attention for attention pooling requires:

  • NVIDIA GPU with compute capability >= 7.0 (V100, A100, H100, RTX 3090/4090)
  • CUDA 11.6+
  • PyTorch 2.0+
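The optional-dependency pattern described above can be sketched as a guarded import — a minimal illustration (the actual guard in model.py may differ):

```python
# Try the optional Flash Attention fast path; fall back gracefully
# to PyTorch's built-in F.scaled_dot_product_attention if absent.
try:
    from flash_attn import flash_attn_func  # optional fast path
    HAS_FLASH_ATTN = True
except ImportError:
    flash_attn_func = None  # model uses the PyTorch fallback instead
    HAS_FLASH_ATTN = False

print(f"Flash Attention available: {HAS_FLASH_ATTN}")
```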

Standard Installation:

pip install packaging ninja
pip install flash-attn --no-build-isolation

If standard installation fails, try building from source:

pip install packaging ninja
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
python setup.py install

On HPC clusters (e.g., bwForCluster):

module load devel/cuda/12.1
module load compiler/gnu/12.1
pip install flash-attn --no-build-isolation

GPU Configuration

The environment.yml includes PyTorch CUDA 12.1 by default.

For different hardware, edit environment.yml before creating the environment:

AMD GPUs (ROCm 6.0) - MI300 series:

# After creating environment:
pip install torch==2.2.0 --index-url https://download.pytorch.org/whl/rocm6.0

NVIDIA GPUs (CUDA 11.8) - older cards:

# In environment.yml, replace PyTorch lines with:
- pytorch::pytorch-cuda=11.8

CPU only:

# In environment.yml, replace PyTorch lines with:
- pytorch::cpuonly

Pipeline Validation

Validate the complete pipeline:

cd code

# 1. Generate test dataset (300 events)
python simulate.py \
    --n_flat 100 \
    --n_pspl 100 \
    --n_binary 100 \
    --binary_preset baseline \
    --output ../data/raw/test.h5

# 2. Train model (5 epochs for validation)
python train.py \
    --data ../data/raw/test.h5 \
    --epochs 5 \
    --batch-size 32 \
    --hierarchical \
    --use-aux-head

# 3. Evaluate (experiment name is auto-generated by train.py)
ls ../results/   # Find your experiment name
python evaluate.py --experiment-name d64_l4_hier --data ../data/raw/test.h5

Note: The --experiment-name argument supports partial matching, so you don't need the exact directory name.
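Partial matching amounts to a substring search over the result directories — a hypothetical sketch (function name and error handling are illustrative, not the actual evaluate.py code):

```python
# Illustrative partial experiment-name matching: pick the unique
# results directory whose name contains the query string.
def match_experiment(query, experiment_dirs):
    matches = [d for d in experiment_dirs if query in d]
    if len(matches) != 1:
        raise ValueError(f"{len(matches)} experiments match {query!r}: {matches}")
    return matches[0]

dirs = ["d64_l4_hier_20250101_143022", "d128_l6_hier_20250102_090000"]
print(match_experiment("d64_l4", dirs))  # d64_l4_hier_20250101_143022
```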


Data Generation

Simulation Command

python simulate.py \
    --n_flat 100000 \
    --n_pspl 100000 \
    --n_binary 100000 \
    --binary_preset baseline \
    --output ../data/raw/baseline.h5 \
    --num_workers 32 \
    --seed 42 \
    --oversample 1.3

Binary Lens Presets:

| Preset | Mass Ratio (q) | Separation (s) | Impact (u₀) | Caustic Required | Description |
|---|---|---|---|---|---|
| distinct | 0.1 - 1.0 | 0.8 - 1.2 | 0.001 - 0.3 | Yes | Forced caustics |
| planetary | 10⁻⁴ - 10⁻² | 0.6 - 1.6 | 0.001 - 0.3 | — | Exoplanet regime |
| stellar | 0.1 - 1.0 | 0.3 - 3.0 | 0.001 - 0.5 | — | Binary star systems |
| baseline | 10⁻⁴ - 1.0 | 0.3 - 3.0 | 0.001 - 1.0 | — | Full parameter space |

Output: <output>.h5

  • Core datasets: flux, delta_t, labels, timestamps
  • Parameters: params_flat, params_pspl, params_binary (structured arrays)
  • Metadata: Mission duration, cadence, seed, version, etc.

Data Format Note:

  • flux contains magnifications (A), not Jansky flux
  • A = 1.0 is baseline (no magnification)
  • A > 1.0 means brighter (magnified)
  • A = 0.0 indicates masked/missing observations
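For a concrete sense of these magnification values, the standard point-source point-lens (Paczyński) formula relates magnification A to the lens-source separation u in Einstein radii — an illustrative helper, not the simulate.py implementation:

```python
import numpy as np

def pspl_magnification(u):
    """Standard PSPL magnification: A(u) = (u^2 + 2) / (u * sqrt(u^2 + 4))."""
    u = np.asarray(u, dtype=float)
    return (u**2 + 2) / (u * np.sqrt(u**2 + 4))

print(pspl_magnification(1.0))    # ~1.342 (moderately magnified)
print(pspl_magnification(100.0))  # ~1.0 (far from the lens: baseline)
```

As u grows, A approaches 1.0 from above, which is why A = 1.0 is the unmagnified baseline stored in flux.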

Training

Single GPU

python train.py \
    --data ../data/raw/baseline.h5 \
    --output ../results \
    --epochs 50 \
    --batch-size 64 \
    --lr 5e-4 \
    --weight-decay 1e-4 \
    --warmup-epochs 3 \
    --d-model 64 \
    --n-layers 4 \
    --dropout 0.3 \
    --window-size 5 \
    --hierarchical \
    --use-aux-head \
    --attention-pooling \
    --stage1-weight 1.0 \
    --stage2-weight 1.0 \
    --aux-weight 0.5

Distributed Training (Multi-Node Multi-GPU)

Using train.sbatch (Recommended):

# Submit the complete pipeline
sbatch train.sbatch

The train.sbatch script handles:

  1. Data copying to /dev/shm for fast RAM-backed reads
  2. Multi-node distributed training with torchrun (40 GPUs)
  3. Automatic checkpoint resumption on timeout
  4. Auto-submission of continuation jobs
  5. Automatic evaluation job submission after training completes

Manual SLURM Execution:

salloc --partition=gpu_a100_short --nodes=10 --gres=gpu:4 --exclusive --time=00:30:00

cd ~/Thesis/code
conda activate microlens

export MASTER_ADDR=$(scontrol show hostnames "$SLURM_NODELIST" | head -n 1)
export MASTER_PORT=29500
export LOCAL_WORLD_SIZE=4

srun torchrun \
    --nnodes=$SLURM_NNODES \
    --nproc-per-node=4 \
    --rdzv-backend=c10d \
    --rdzv-endpoint="${MASTER_ADDR}:${MASTER_PORT}" \
    --rdzv-id="train-$(date +%s)" \
    train.py \
    --data /tmp/train.h5 \
    --output ../results \
    --epochs 300 \
    --batch-size 512 \
    --accumulation-steps 2 \
    --lr 5e-4 \
    --hierarchical \
    --use-aux-head \
    --attention-pooling \
    --save-every 3

Hierarchical Classification

The hierarchical mode uses a two-stage classification approach:

Stage 1: Flat vs Non-Flat (Binary Cross-Entropy)
    ↓
Stage 2: PSPL vs Binary (Binary Cross-Entropy, only for Non-Flat)
    ↓
Auxiliary: 3-class Cross-Entropy (gradient stability)

Loss Function:

L_total = λ₁ × L_stage1 + λ₂ × L_stage2 + λ_aux × L_auxiliary

Key Arguments:

  • --hierarchical: Enable hierarchical classification
  • --use-aux-head: Add auxiliary 3-class head for gradient stability
  • --stage1-weight: Weight for Stage 1 loss (default: 1.0)
  • --stage2-weight: Weight for Stage 2 loss (default: 1.0)
  • --aux-weight: Weight for auxiliary loss (default: 0.5)
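One plausible way the two sigmoid stages compose into three class probabilities is shown below — a hypothetical numpy sketch (the exact composition in model.py may differ), where Stage 1 predicts P(non-flat) and Stage 2 predicts P(binary | non-flat):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def hierarchical_probs(z_stage1, z_stage2):
    """Compose two binary-stage logits into [P(flat), P(PSPL), P(binary)]."""
    p_nonflat = sigmoid(z_stage1)             # Stage 1: flat vs non-flat
    p_bin_given_nonflat = sigmoid(z_stage2)   # Stage 2: PSPL vs binary
    p_flat = 1.0 - p_nonflat
    p_pspl = p_nonflat * (1.0 - p_bin_given_nonflat)
    p_binary = p_nonflat * p_bin_given_nonflat
    return np.array([p_flat, p_pspl, p_binary])

probs = hierarchical_probs(2.0, -1.0)
print(probs.sum())  # the three probabilities always sum to 1
```

By construction the three outputs form a valid distribution, while each stage can still be trained with its own BCE loss.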

Checkpoint Resumption

Training automatically resumes from checkpoints:

# Resume from latest checkpoint
python train.py \
    --data baseline.h5 \
    --epochs 300 \
    --resume ../results/checkpoints/d64_l4_hier_*/checkpoints/checkpoint_latest.pt

Training Outputs (saved in results/<experiment>/):

| File | Description |
|---|---|
| best.pt | Best checkpoint (highest validation accuracy) |
| epoch_NNN.pt | Periodic checkpoints |
| checkpoints/checkpoint_latest.pt | Most recent checkpoint (for resumption) |
| config.json | Full configuration and hyperparameters |

Evaluation

Standard Evaluation

python evaluate.py \
    --experiment-name d64_l4_hier \
    --data ../data/test/test.h5 \
    --batch-size 512 \
    --early-detection \
    --n-evolution-per-type 5 \
    --colorblind-safe

Note: train.py auto-generates names like d64_l4_hier_20250101_143022. Partial matching works:

python evaluate.py --experiment-name d64_l4 --data test.h5

Evaluation Outputs (saved in results/<experiment>/eval_<dataset>_<timestamp>/):

| File | Description |
|---|---|
| evaluation_summary.json | Overall metrics and configuration |
| confusion_matrix.png | Normalized confusion matrix heatmap |
| roc_curves.png | One-vs-rest ROC curves with AUC scores |
| calibration.png | Reliability diagram with confidence histograms |
| u0_dependency.png | Binary accuracy vs. impact parameter |
| temporal_bias_check.png | t₀ distribution comparison (KS-test) |
| early_detection_curve.png | Accuracy vs. observation completeness |
| evolution_<class>_<idx>.png | Probability evolution (3-panel) |
| example_light_curves.png | Grid of example classifications |
| per_class_metrics.png | Precision/recall/F1 bar chart |

Metrics Computed:

  • Overall: accuracy, precision, recall, F1-score
  • Per-class: precision, recall, F1 for Flat, PSPL, Binary
  • AUROC: macro and weighted averages
  • Calibration reliability (ECE)
  • Bootstrap confidence intervals
  • Early detection at [10%, 20%, 30%, 50%, 70%, 100%] completeness
  • Impact parameter (u₀) dependency for binary events
  • Temporal bias check (Kolmogorov-Smirnov test)
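Of these metrics, Expected Calibration Error (ECE) is the least standard; a minimal numpy sketch of the usual binned definition (bin predictions by confidence, then average |accuracy − confidence| weighted by bin occupancy — the binning details in evaluate.py may differ):

```python
import numpy as np

def expected_calibration_error(confidences, correct, n_bins=10):
    """Binned ECE over predictions with confidences in (0, 1]."""
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (confidences > lo) & (confidences <= hi)
        if mask.any():
            # bin weight * |bin accuracy - bin mean confidence|
            ece += mask.mean() * abs(correct[mask].mean() - confidences[mask].mean())
    return ece

# Perfectly calibrated toy case: 80% confidence, 80% accuracy -> ECE = 0
conf = np.full(10, 0.8)
corr = np.array([1, 1, 1, 1, 1, 1, 1, 1, 0, 0])
print(expected_calibration_error(conf, corr))
```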

Architecture

Model: RomanMicrolensingClassifier (v4.0.0)

Input: Flux [B, N], Time Intervals Δt [B, N]
  │
  ├─ Input Projection: Linear(2 → d_model)
  │
Feature Extraction: Depthwise Separable Conv (2 blocks, causal)
  │  ├─ Block 1: kernel=5, dilation=1
  │  └─ Block 2: kernel=5, dilation=2 (multi-scale)
  │
Recurrent Core: Stacked GRU (CuDNN-fused)
  │  └─ n_layers with dropout between layers
  │
Layer Normalization + Residual Connection
  │
Temporal Pooling:
  │  ├─ Attention Pooling (multi-head, learnable query)
  │  └─ OR Mean Pooling (masked by sequence length)
  │
┌─────────────────────────────────────────────────────┐
│ Hierarchical Classification (v4.0.0)                │
│                                                     │
│  Shared Trunk: Linear + LayerNorm + SiLU + Dropout  │
│        │                                            │
│        ├─ Stage 1 Head: Flat vs Non-Flat (BCE)      │
│        ├─ Stage 2 Head: PSPL vs Binary (BCE)        │
│        └─ Auxiliary Head: 3-class (CE)              │
└─────────────────────────────────────────────────────┘
  │
Output: Log-Probabilities [B, 3]
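The causality of the conv blocks above comes from left-padding by (kernel − 1) × dilation, so each output depends only on current and past inputs. A toy numpy sketch of a single-channel causal convolution (illustrative only; the model uses PyTorch depthwise separable convs):

```python
import numpy as np

def causal_conv1d(x, kernel, dilation=1):
    """Causal 1-D conv: output[t] uses only inputs at times <= t."""
    k = len(kernel)
    pad = (k - 1) * dilation            # left padding preserves causality
    xp = np.concatenate([np.zeros(pad), x])
    return np.array([
        sum(kernel[j] * xp[t + pad - j * dilation] for j in range(k))
        for t in range(len(x))
    ])

x = np.array([1.0, 2.0, 3.0, 4.0])
# A kernel of [0, 1] just echoes the previous sample: nothing leaks
# from the future.
print(causal_conv1d(x, kernel=[0.0, 1.0]))  # [0. 1. 2. 3.]
```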

Default Configuration:

| Parameter | Value | Description |
|---|---|---|
| d_model | 64 | Hidden dimension |
| n_layers | 4 | GRU layers |
| dropout | 0.3 | Dropout rate |
| window_size | 5 | Conv kernel size |
| max_seq_len | 7000 | Maximum sequence length |
| n_classes | 3 | Output classes |
| hierarchical | True | Two-stage classification |
| attention_pooling | True | Multi-head attention pooling |
| head_init_std | 0.15 | Classification head initialization |

Parameter Count: ~100-500K (varies with d_model and n_layers)


Project Structure

Thesis/
├── code/
│   ├── simulate.py      # v4.0.0 - Data generation
│   ├── train.py         # v4.1.0 - Distributed training
│   ├── model.py         # v4.0.0 - Neural network architecture
│   ├── evaluate.py      # v4.1.0 - Evaluation suite
│   ├── train.sbatch     # v4.1.0 - Multi-GPU training job
│   ├── data.sbatch      # v4.0.0 - Data generation job
│   └── logs/            
│
├── data/
│   ├── raw/             # Training datasets
│   └── test/            # Test datasets
│
├── results/
│   └── checkpoints/     # Model checkpoints and evaluations
│ 
├── .gitignore
├── environment.yml
├── README.md
└── LICENSE

Version History

v4.1.0 (Current) - January 2025

train.py & evaluate.py OPTIMIZATION:

  • Fixed /dev/shm cleanup race condition (barrier before delete)
  • Broadcast optimization (rank0 computes stats/indices, broadcasts to others)
  • Subsampling now returns file indices for correct params extraction
  • torch.serialization import guarded to prevent crashes
  • torch.load() wrapped for weights_only compatibility
  • O(N²) params extraction replaced with O(N) precomputed index mapping

v4.0.0 - January 2025

All Components Synchronized:

  • train.py v4.0.0 → v4.1.0
  • model.py v4.0.0
  • simulate.py v4.0.0
  • evaluate.py v4.0.0 → v4.1.0

model.py CRITICAL FIXES:

  • @torch.compiler.disable → @torch._dynamo.disable
  • Split Q/KV projections in FlashAttentionPooling
  • Fixed attention mask shape for flash attention
  • Fixed frozen dataclass mutation in ModelConfig
  • Fixed residual causality violation (moved before conv)

simulate.py FIXES:

  • Time grid uses np.arange for exact 15-minute cadence
  • u0 caustic sampling bounds check for planetary regime
  • VBBinaryLensing return type normalized immediately
  • MIN_MAGNIFICATION_CLIP increased from 0.1 to 0.5
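The exact-cadence time grid fixed here can be sketched as follows — an illustrative construction (simulate.py's actual grid code may differ), with the 15-minute Roman cadence over the 72-day season:

```python
import numpy as np

CADENCE_DAYS = 15.0 / (24 * 60)  # 15 minutes expressed in days
SEASON_DAYS = 72.0

# Integer-count arange avoids floating-point drift in the step,
# giving an exact 15-minute cadence.
n_obs = int(SEASON_DAYS * 24 * 60 / 15)
t = np.arange(n_obs) * CADENCE_DAYS

print(len(t))  # 6912 samples, consistent with ~6900 observations/season
```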

train.py FIXES:

  • Deterministic /dev/shm path (SLURM_JOB_ID, not PID)
  • Model calls use keyword arguments (lengths=lengths)
  • lengths tensor moved to device
  • torch.load compatibility wrapper
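The deterministic /dev/shm fix can be illustrated with a small helper — hypothetical names and path layout, shown only to make the SLURM_JOB_ID-vs-PID point concrete: every rank of a job derives the same path from the job ID, whereas a PID-based path differs per process.

```python
import os

def shm_data_path(filename, default_job_id="local"):
    """Staging path keyed on the SLURM job ID (stable across ranks)."""
    job_id = os.environ.get("SLURM_JOB_ID", default_job_id)
    return f"/dev/shm/microlens_{job_id}/{filename}"

os.environ["SLURM_JOB_ID"] = "123456"  # set by SLURM on the cluster
print(shm_data_path("train.h5"))  # /dev/shm/microlens_123456/train.h5
```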

v3.1.0 - December 2024

  • Realistic Roman 72-day observing season
  • Global m_base array saved in HDF5
  • SharedRAMLensingDataset eliminates train/val double loading

v3.0.0 - December 2024

  • Hierarchical classification with separate BCE losses
  • Auxiliary 3-class head for gradient stability
  • Per-class recall logging during training
  • All magic numbers replaced with named constants

Physical Parameter Ranges

Observational Configuration (Roman-like):

| Parameter | Value | Description |
|---|---|---|
| Temporal sampling | 15 min | Roman Galactic Bulge cadence |
| Missing observations | 5% | Uniform random gaps |
| Photometric error | Realistic | Roman F146 detector model |
| Season duration | 72 days | Single Roman observing season |
| Max sequence length | ~6900 | Observations per season |
| Source magnitude | 18-24 AB | Baseline brightness |

Microlensing Parameters:

| Parameter | Range | Description |
|---|---|---|
| Einstein timescale (t_E) | 5-30 days | Event duration |
| Peak time (t₀) | 10%-90% of season | Central time |
| Impact parameter (u₀) | 0.001-1.0 | Closest approach |
| Binary separation (s) | 0.1-3.0 | Einstein radii |
| Mass ratio (q) | 10⁻⁴ - 1.0 | Secondary/primary mass |
| Source radius (ρ) | 10⁻³ - 0.05 | Einstein radii |

License

MIT License - See LICENSE for details.

