1D Convolutional Neural Network for Antimicrobial Resistance Prediction from Mass Spectrometry Data
A ResNet-inspired deep learning architecture that predicts antimicrobial resistance patterns directly from MALDI-TOF mass spectrometry signals.
Input (6000×1)
↓
Conv1D (kernel=7, stride=2) → BatchNorm → ReLU → MaxPool
↓
[ResidualBlock × 2] (64 channels, stride=1)
↓
[ResidualBlock × 2] (128 channels, stride=2)
↓
[ResidualBlock × 2] (256 channels, stride=2)
↓
[ResidualBlock × 2] (512 channels, stride=2)
↓
Global Average Pooling → Dropout (0.5) → FC → Sigmoid
↓
Output (10 antibiotics)
Input
├─→ Conv1D → BatchNorm → ReLU → Conv1D → BatchNorm
│ ↓
└────────────────────────────────────────────→ Add → ReLU
↓
Output
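The residual block diagrammed above can be sketched in PyTorch as follows; the class name matches the repository's `ResidualBlock`, but kernel sizes and the projection shortcut are illustrative assumptions, not necessarily the exact implementation:

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Conv1D -> BN -> ReLU -> Conv1D -> BN, plus an identity/projection skip."""
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        # 1x1 projection when the shape changes, identity otherwise
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm1d(out_channels),
            )

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + self.shortcut(x)  # skip connection from the diagram
        return torch.relu(out)
```

The projection shortcut is what lets the stride-2 stages (64 → 128 → 256 → 512 channels) add tensors of matching shape.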
Core model implementation:
- ResidualBlock: convolutional block with skip connections
- DeepG2P: main ResNet-1D architecture
- create_deepg2p_model(): factory function for different model sizes
Key Features:
- Flexible input dimensions (default: 6000×1)
- Multi-label classification (10 antibiotics)
- He/Xavier weight initialization
- Feature map extraction for interpretability
Model Sizes:
# Small: 32 base channels, 2-2-2-2 blocks, 0.3 dropout
model = create_deepg2p_model(model_size='small') # ~500K params
# Medium: 64 base channels, 2-2-2-2 blocks, 0.5 dropout
model = create_deepg2p_model(model_size='medium') # ~2M params
# Large: 64 base channels, 3-4-6-3 blocks, 0.5 dropout
model = create_deepg2p_model(model_size='large') # ~5M params

Comprehensive training pipeline:
Components:
- DRIAMSDataset: Custom PyTorch Dataset for .npy files
- Loss: BCEWithLogitsLoss with automatic pos_weight (handles class imbalance)
- Optimizer: AdamW (lr=1e-4, weight_decay=1e-5)
- Scheduler: ReduceLROnPlateau (patience=3, factor=0.5)
- Metrics: AUPRC, AUROC, Loss
- Logging: TensorBoard + console output
Features:
- Automatic class imbalance handling
- Best model checkpointing (models/best_model.pth)
- Periodic checkpoints every 5 epochs
- Training configuration saved to JSON
- Progress bars with live metrics
# Basic training with default parameters
python src/train.py
# Custom training
python src/train.py \
  --train-features data/processed/X_train.npy \
  --train-labels data/processed/y_train.npy \
  --val-features data/processed/X_val.npy \
  --val-labels data/processed/y_val.npy \
  --epochs 20 \
  --batch-size 32 \
  --lr 1e-4 \
  --model-size medium \
  --num-antibiotics 10

# Launch TensorBoard
tensorboard --logdir results/logs
# View at http://localhost:6006

import torch
from model import create_deepg2p_model
# Load model
model = create_deepg2p_model(num_antibiotics=10)
checkpoint = torch.load('models/best_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# Predict
x = torch.randn(1, 1, 6000) # Single spectrum
with torch.no_grad():
    logits = model(x)
    probs = torch.sigmoid(logits)
print(f"Resistance probabilities: {probs[0].numpy()}")

.npy files (X, y)
↓
DRIAMSDataset (PyTorch Dataset)
↓
DataLoader (batch_size=32, shuffle=True)
↓
DeepG2P Model
↓
BCEWithLogitsLoss (pos_weight for imbalance)
↓
AdamW Optimizer (lr=1e-4)
↓
Metrics: AUPRC, AUROC
↓
Best Model → models/best_model.pth
The training pipeline automatically calculates pos_weight for BCEWithLogitsLoss:
pos_weight = (# negative samples) / (# positive samples)

This penalizes false negatives more heavily for rare resistance cases.
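The per-antibiotic weight can be computed directly from the label matrix. A sketch, assuming a (n_samples, 10) binary label array (the helper name is illustrative):

```python
import numpy as np
import torch
import torch.nn as nn

def compute_pos_weight(y):
    """Per-class pos_weight = (# negatives) / (# positives) for BCEWithLogitsLoss."""
    y = np.asarray(y, dtype=np.float64)
    pos = y.sum(axis=0)
    neg = y.shape[0] - pos
    # Guard against antibiotics with no positive samples in the training set
    return torch.tensor(neg / np.clip(pos, 1, None), dtype=torch.float32)

# criterion = nn.BCEWithLogitsLoss(pos_weight=compute_pos_weight(y_train))
```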
======================================================================
DeepG2P Training Pipeline - Antimicrobial Resistance Prediction
======================================================================
🖥️ Device: cuda
GPU: NVIDIA GeForce RTX 3080
Memory: 10.00 GB
📂 Loading datasets...
Training samples: 8000
Validation samples: 2000
⚖️ Calculating class weights for imbalanced data...
Class imbalance ratio: 15.23:1 (negative:positive)
🏗️ Building DeepG2P model (size: medium)...
Total parameters: 2,147,850
Trainable parameters: 2,147,850
🚀 Starting training for 20 epochs...
Batch size: 32
Learning rate: 0.0001
Optimizer: AdamW
Loss: BCEWithLogitsLoss (pos_weight=15.23)
======================================================================
Epoch 1/20
======================================================================
Epoch 1 [Train]: 100%|████████| 250/250 [00:45<00:00, 5.5it/s, loss=0.2134]
Epoch 1 [Val]: 100%|████████| 63/63 [00:08<00:00, 7.8it/s, loss=0.1892]
📊 Epoch 1 Summary:
Train Loss: 0.2247 | Train AUPRC: 0.7234
Val Loss: 0.1965 | Val AUPRC: 0.7812 | Val AUROC: 0.8456
✅ New best model saved! (Val Loss: 0.1965)
...
Input:
- Shape: (batch_size, 1, 6000)
- Type: torch.FloatTensor
- Data: MALDI-TOF mass spectrometry intensities (6000 m/z bins)

Output:
- Shape: (batch_size, 10)
- Type: torch.FloatTensor (logits)
- Range: [0, 1] after sigmoid
- Interpretation: probability of resistance for each antibiotic
- Ceftriaxone
- Ciprofloxacin
- Cefixime
- Ampicillin
- Gentamicin
- Trimethoprim
- Tetracycline
- Chloramphenicol
- Azithromycin
- Meropenem
- Primary metric for imbalanced multi-label classification
- More informative than AUROC for rare events
- Target: >0.80 per antibiotic
- Secondary metric for overall discrimination
- Target: >0.85 per antibiotic
- Combines sigmoid + BCE for numerical stability
- Uses pos_weight to handle class imbalance
- Target: <0.15 validation loss
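Per-antibiotic AUPRC and AUROC can be computed with scikit-learn. A sketch, assuming sigmoid probabilities y_prob and binary labels y_true of shape (n_samples, 10); the helper name is illustrative:

```python
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

def per_antibiotic_metrics(y_true, y_prob):
    """Macro-averaged AUPRC/AUROC, skipping antibiotics with only one class present."""
    auprc, auroc = [], []
    for j in range(y_true.shape[1]):
        col = y_true[:, j]
        if len(np.unique(col)) < 2:  # both metrics are undefined for a single class
            continue
        auprc.append(average_precision_score(col, y_prob[:, j]))
        auroc.append(roc_auc_score(col, y_prob[:, j]))
    return float(np.mean(auprc)), float(np.mean(auroc))
```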
models/
├── best_model.pth # Best model (lowest val loss)
├── final_model.pth # Final epoch model
├── checkpoint_epoch_5.pth # Checkpoint at epoch 5
├── checkpoint_epoch_10.pth # Checkpoint at epoch 10
├── checkpoint_epoch_15.pth # Checkpoint at epoch 15
└── checkpoint_epoch_20.pth # Checkpoint at epoch 20
results/
├── logs/ # TensorBoard logs
│ └── events.out.tfevents.*
└── training_config.json # Training hyperparameters
checkpoint = {
    'epoch': int,
    'model_state_dict': OrderedDict,
    'optimizer_state_dict': OrderedDict,
    'loss': float,
    'timestamp': str  # ISO-format string
}

- Skip connections prevent vanishing gradients in deep networks
- 1D convolutions naturally handle sequential mass spec data
- Global average pooling reduces overfitting vs fully connected layers
- Proven architecture adapted from image classification (ResNet-18)
- Numerically stable (log-sum-exp trick)
- Multi-label friendly (independent binary predictions)
- pos_weight handles class imbalance without resampling
- Adaptive learning rates per parameter
- Weight decay decoupled from gradient updates
- Better generalization than Adam
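Wiring up the optimizer and scheduler with the values listed earlier is a one-liner each; in this sketch a plain nn.Linear stands in for a DeepG2P instance:

```python
import torch
import torch.nn as nn

model = nn.Linear(6000, 10)  # stand-in for a DeepG2P instance

# AdamW: decoupled weight decay, as used in the training pipeline
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Halve the learning rate when validation loss stops improving for 3 epochs
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3)

# In the training loop, after each validation pass:
# scheduler.step(val_loss)
```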
- ResNet: He et al. (2016) - "Deep Residual Learning for Image Recognition"
- DRIAMS: Weis et al. (2020) - "Direct Antimicrobial Resistance Prediction from MALDI-TOF Mass Spectra"
- BCEWithLogits: PyTorch Documentation - Numerically stable binary cross-entropy
- Update num_antibiotics parameter in model creation
- Prepare labels with correct dimensions
- Retrain model

model = create_deepg2p_model(num_antibiotics=15)  # 5 new antibiotics

# Add more residual blocks
model = DeepG2P(
    num_blocks=[3, 4, 6, 3],  # ResNet-34 configuration
    base_channels=64
)

# Focal loss for extreme imbalance
import torch
from torch import nn

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        bce_loss = nn.BCEWithLogitsLoss(reduction='none')(inputs, targets)
        pt = torch.exp(-bce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
        return focal_loss.mean()

- Monitor class balance: Check pos_weight during training
- Use validation set: Never evaluate on training data
- Track AUPRC: More informative than accuracy for imbalanced data
- Save checkpoints: Enables recovery from crashes
- Log hyperparameters: Save training_config.json for reproducibility
- Use TensorBoard: Visualize training curves in real-time
# Reduce batch size
python src/train.py --batch-size 16
# Use smaller model
python src/train.py --model-size small

- Check class balance (pos_weight should be >1 for rare events)
- Increase training epochs
- Add data augmentation (noise, shifts)
- Use larger model
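Augmentation of the kind suggested above (additive noise plus small m/z shifts) could be sketched as follows; the function name and default magnitudes are illustrative, not part of the current pipeline:

```python
import numpy as np

def augment_spectrum(x, noise_std=0.01, max_shift=5, rng=None):
    """Add Gaussian noise and a small random shift along the m/z axis.

    x: 1-D intensity array of length 6000; defaults are illustrative.
    """
    if rng is None:
        rng = np.random.default_rng()
    shift = rng.integers(-max_shift, max_shift + 1)
    shifted = np.roll(x, shift)
    # Zero out the wrapped-around edge introduced by the roll
    if shift > 0:
        shifted[:shift] = 0.0
    elif shift < 0:
        shifted[shift:] = 0.0
    return shifted + rng.normal(0.0, noise_std, size=x.shape)
```

Applied inside the Dataset's __getitem__ (training split only), this leaves the validation data untouched.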
- Use GPU (device='cuda')
- Increase batch size
- Reduce num_workers if CPU-bound
- Use mixed precision training (future feature)
MIT License - See repository root for details.
Author: Vihaan Kulkarni
Contact: GitHub
Date: January 2026