neomond/mnist-cnn-classifier

MNIST Handwritten Digit Classifier

A convolutional neural network built with PyTorch that reaches 99.08% test accuracy on MNIST. The pipeline is production-style: seeded for reproducibility, configurable from the CLI, with structured logging and error analysis.

Architecture

Input [1, 28, 28]
  → Conv2d(1→32, 3×3) → BatchNorm → ReLU → MaxPool(2×2)    [32, 14, 14]
  → Conv2d(32→64, 3×3) → BatchNorm → ReLU → MaxPool(2×2)   [64, 7, 7]
  → Conv2d(64→64, 3×3) → BatchNorm → ReLU                   [64, 7, 7]
  → Flatten → Linear(3136→128) → ReLU → Dropout(0.5)
  → Linear(128→10)
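
As a rough illustration, the layer stack above can be written as a PyTorch module. This is a minimal sketch, not the repository's actual code: the class name `MnistCnn` and the helper `count_parameters` are hypothetical, and `padding=1` is an assumption (the listed shapes only work out if the 3×3 convolutions preserve spatial size). Under those assumptions the trainable parameter count comes to 458,890, matching the figure reported in Results.

```python
import torch
from torch import nn


class MnistCnn(nn.Module):
    """Sketch of the listed layer stack (class name is hypothetical)."""

    def __init__(self, num_classes: int = 10, dropout: float = 0.5) -> None:
        super().__init__()
        # padding=1 is an assumption: it keeps each 3x3 conv size-preserving,
        # which the listed output shapes require.
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # -> [32, 14, 14]
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # -> [64, 7, 7]
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),  # -> [64, 7, 7]
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),  # 64 * 7 * 7 = 3136 features
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),  # raw logits, one per digit
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))


def count_parameters(model: nn.Module) -> int:
    """Number of trainable parameters (BatchNorm running stats are buffers, not counted)."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
```

For example, `count_parameters(MnistCnn())` can be compared against the Parameters row in Results, and `MnistCnn()(torch.zeros(2, 1, 28, 28))` returns a `[2, 10]` tensor of logits.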

Results

Metric          Value
Test Accuracy   99.08%
Parameters      458,890
Training Time   35s (5 epochs, Apple MPS)
Total Errors    92 / 10,000

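Training Curves

[Figure: loss and accuracy curves, saved as outputs/03_training_curves.png]

Confusion Matrix

[Figure: confusion matrix, saved as outputs/04_confusion_matrix.png]

Most Confident Wrong Predictions

[Figure: most confident wrong predictions, saved as outputs/05_wrong_predictions.png]

Learned Conv Filters (Layer 1)

[Figure: learned first-layer conv filters, saved as outputs/07_conv_filters.png]

Feature Map Activations

[Figure: layer 1 feature map activations, saved as outputs/08_feature_maps.png]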
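
The confusion matrix and the "most confident wrong predictions" view can both be derived from the model's per-class probabilities on the test set. The sketch below shows one way to do that with NumPy. It assumes the softmax outputs have already been collected into an array of shape `[N, num_classes]`; the function names are hypothetical and the repository's own analysis code may differ.

```python
import numpy as np


def confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, num_classes: int = 10) -> np.ndarray:
    """Count matrix with rows = true label, columns = predicted label."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when index pairs repeat.
    np.add.at(cm, (y_true, y_pred), 1)
    return cm


def most_confident_errors(probs: np.ndarray, y_true: np.ndarray, k: int = 10) -> np.ndarray:
    """Indices of misclassified samples, sorted by predicted-class probability (highest first)."""
    y_pred = probs.argmax(axis=1)
    confidence = probs.max(axis=1)
    wrong = np.flatnonzero(y_pred != y_true)
    # Negate so argsort (ascending) yields descending confidence.
    order = wrong[np.argsort(-confidence[wrong])]
    return order[:k]
```

The total error count is the off-diagonal mass of the matrix, `cm.sum() - np.trace(cm)`, which is the quantity reported as Total Errors in Results.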

Quick Start

pip install torch torchvision matplotlib numpy

# Train with defaults
python mnist_cnn_classifier.py

# Custom hyperparameters
python mnist_cnn_classifier.py --epochs 10 --batch_size 256 --lr 0.0005
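
The three flags used above could be parsed with `argparse` roughly as follows. This is a sketch only: the flag names come from the commands above, but the default values shown are illustrative placeholders, not the script's real defaults, and `parse_args` is a hypothetical helper name.

```python
import argparse


def parse_args(argv=None) -> argparse.Namespace:
    """Parse training hyperparameters from the command line."""
    parser = argparse.ArgumentParser(description="Train a CNN on MNIST")
    # NOTE: defaults are placeholders chosen for illustration only.
    parser.add_argument("--epochs", type=int, default=5, help="number of training epochs")
    parser.add_argument("--batch_size", type=int, default=128, help="mini-batch size")
    parser.add_argument("--lr", type=float, default=1e-3, help="initial learning rate")
    return parser.parse_args(argv)
```

Passing `argv=None` makes `argparse` read `sys.argv`, while passing an explicit list keeps the function easy to exercise in tests.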

Techniques

  • Reproducibility — fixed seeds across all random generators for deterministic results
  • Data augmentation — random rotation (±10°) and translation (±10%)
  • Batch normalization — after each conv layer for stable training
  • Dropout (0.5) — in the classifier head to reduce overfitting
  • LR scheduling — ReduceLROnPlateau for adaptive learning rate decay
  • Auto device selection — CUDA → MPS → CPU detection
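
Three of the techniques above (seeding, device selection, and LR scheduling) fit in a few lines each. The sketch below is one plausible way to write them, not the repository's code: the helper names, the seed value, and the scheduler's `mode`, `factor`, and `patience` settings are all assumptions. Note that fixed seeds make runs repeatable on the same setup, but bit-identical results across different devices are not guaranteed.

```python
import random

import numpy as np
import torch


def seed_everything(seed: int = 42) -> None:
    """Seed Python, NumPy, and PyTorch generators (seed value is arbitrary)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds PyTorch's default generators


def pick_device() -> torch.device:
    """Prefer CUDA, then Apple MPS, then fall back to CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Guard with getattr in case the build has no MPS backend module.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def make_scheduler(optimizer: torch.optim.Optimizer):
    """Halve the LR when the monitored accuracy stops improving.

    mode="max" assumes the scheduler is stepped with an accuracy value;
    factor and patience are illustrative choices.
    """
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, patience=1
    )
```

In a training loop, `scheduler.step(val_accuracy)` would be called once per epoch after evaluation. If validation loss is monitored instead, `mode="min"` is the matching setting.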

Project Structure

├── mnist_cnn_classifier.py        Training & evaluation pipeline
├── outputs/
│   ├── best_model.pth             Best model weights
│   ├── checkpoint.pth             Full training checkpoint
│   ├── 01_samples.png             Dataset samples
│   ├── 02_class_distribution.png  Class balance
│   ├── 03_training_curves.png     Loss & accuracy curves
│   ├── 04_confusion_matrix.png    Prediction analysis
│   ├── 05_wrong_predictions.png   Most confident errors
│   ├── 06_correct_predictions.png Correct prediction samples
│   ├── 07_conv_filters.png        Learned conv filters
│   └── 08_feature_maps.png        Layer 1 activations
└── README.md

Requirements

  • Python 3.10+
  • PyTorch 2.x
  • torchvision
  • matplotlib, numpy

License

MIT
