This project implements a simple neural network classifier for the MNIST dataset using Python. The model is trained using the MLPClassifier from scikit-learn, and a Flask API is provided to make predictions on handwritten digit images.
- Dataset: MNIST handwritten digits dataset.
- Model: Neural network with one hidden layer, implemented using MLPClassifier.
- Data Preprocessing:
- Standardizes pixel values (zero mean, unit variance) to help the network train.
- Splits the dataset into training and testing sets.
- Evaluation:
- Includes cross-validation to assess model generalization.
- Monitors training loss during training.
- Deployment:
- Flask API to accept digit images as input and return predictions.
- Clone the repository:
git clone https://github.com/Shahad-irl/mnist-classifier.git
cd mnist-classifier
- Install the required Python packages:
pip install numpy scikit-learn flask matplotlib
- Dataset Loading:
- The code attempts to load the dataset from three sources, in order: fetch_openml from scikit-learn; load_digits from scikit-learn as a fallback (this is scikit-learn's smaller 8x8 digits dataset, not the 28x28 MNIST images); and the Keras MNIST dataset if the first two fail.
- Loading is wrapped in error handling, so when one source fails the program falls back to the next instead of stopping.
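The fallback chain can be sketched as follows. This is a minimal illustration, not the project's actual code: the function names (`load_from_openml`, `load_from_sklearn_digits`, `load_from_keras`, `load_with_fallback`) are hypothetical, and the Keras source is assumed to be `tensorflow.keras.datasets.mnist`. Each library is imported inside its loader so that a missing package simply counts as a failed source.

```python
from typing import Callable, List, Sequence, Tuple

import numpy as np

# A loader returns (features, labels); hypothetical type alias for this sketch.
Loader = Callable[[], Tuple[np.ndarray, np.ndarray]]


def load_from_openml() -> Tuple[np.ndarray, np.ndarray]:
    # Downloads the 70,000-image "mnist_784" dataset from OpenML (needs network access).
    from sklearn.datasets import fetch_openml
    X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
    return X.astype(np.float64), y.astype(np.int64)  # OpenML labels arrive as strings


def load_from_sklearn_digits() -> Tuple[np.ndarray, np.ndarray]:
    # Bundled with scikit-learn, but the images are 8x8 (64 features), not 28x28.
    from sklearn.datasets import load_digits
    X, y = load_digits(return_X_y=True)
    return X.astype(np.float64), y.astype(np.int64)


def load_from_keras() -> Tuple[np.ndarray, np.ndarray]:
    # Assumes TensorFlow is installed; images come as (n, 28, 28) and are flattened here.
    from tensorflow.keras.datasets import mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    X = np.concatenate([x_train, x_test]).reshape(-1, 28 * 28)
    y = np.concatenate([y_train, y_test])
    return X.astype(np.float64), y.astype(np.int64)


def load_with_fallback(loaders: Sequence[Loader]) -> Tuple[np.ndarray, np.ndarray, str]:
    """Try each loader in order; return the first dataset that loads and its source name."""
    errors: List[str] = []
    for loader in loaders:
        try:
            X, y = loader()
            return X, y, loader.__name__
        except Exception as exc:  # ImportError, network errors, etc.
            errors.append(f"{loader.__name__}: {exc}")
    raise RuntimeError("All dataset sources failed:\n" + "\n".join(errors))
```

Usage would be `X, y, source = load_with_fallback([load_from_openml, load_from_sklearn_digits, load_from_keras])`. Returning the source name (or checking `X.shape[1]`) lets later code tell whether it received 784-feature MNIST data or 64-feature digits data, which matters for an API that expects 28x28 inputs.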
- Data Preprocessing:
- The pixel values are standardized using StandardScaler from scikit-learn, which rescales each pixel to zero mean and unit variance.
- The dataset is split into training and testing sets using train_test_split.
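A minimal preprocessing sketch follows. It splits before scaling and fits `StandardScaler` on the training split only, so the test set does not influence the learned mean and standard deviation; this ordering, the 80/20 split, the stratification, and `random_state=42` are assumptions of the sketch rather than details confirmed by the project. For pixels that are constant across the training set (such as MNIST's always-black borders), `StandardScaler` uses a scale of 1, so no division by zero occurs.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


def preprocess(X, y, test_size=0.2, seed=42):
    """Split the data, then standardize using statistics from the training split only."""
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=seed, stratify=y
    )
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)  # learn per-pixel mean/std on training data
    X_test = scaler.transform(X_test)        # reuse the training statistics unchanged
    return X_train, X_test, y_train, y_test, scaler
```

Returning the fitted `scaler` makes it possible to apply the same transformation to inputs that arrive later, for example through the API.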
- Model Training:
- A neural network model (MLPClassifier) is defined with one hidden layer of 128 neurons, ReLU activation, and the Adam optimizer.
- 5-fold cross-validation is run on the training data to estimate generalization, and the model is then trained on the training set.
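The training step described above might look roughly like this sketch, which assumes the data has already been preprocessed. `cross_val_score` fits a fresh clone of the estimator on each fold and leaves the original unfitted, so the model is fit once more on the full training split afterwards. The function names, `max_iter`, and `random_state` are illustrative choices, not taken from the project.

```python
from sklearn.model_selection import cross_val_score
from sklearn.neural_network import MLPClassifier


def build_model(max_iter=200, seed=0):
    # One hidden layer of 128 ReLU units, trained with the Adam optimizer.
    return MLPClassifier(
        hidden_layer_sizes=(128,),
        activation="relu",
        solver="adam",
        max_iter=max_iter,
        random_state=seed,
    )


def cross_validate_then_fit(X_train, y_train, max_iter=200):
    """Report 5-fold CV accuracy, then fit a final model on all training data."""
    model = build_model(max_iter=max_iter)
    # Each fold trains a clone, so `model` itself is still unfitted after this call.
    scores = cross_val_score(model, X_train, y_train, cv=5, scoring="accuracy")
    print(f"5-fold CV accuracy: {scores.mean():.3f} +/- {scores.std():.3f}")
    model.fit(X_train, y_train)
    return model, scores
```

If the scaler was fit on the whole training split before cross-validation, each validation fold has slightly influenced the scaling; wrapping the scaler and classifier in a scikit-learn pipeline avoids that, at the cost of a little extra code.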
- Model Evaluation:
- The model is evaluated on the test set to calculate accuracy.
- The training loss curve is plotted to visualize the model's performance during training.
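Evaluation and the loss plot can be sketched as below. `MLPClassifier` exposes `loss_curve_`, a list with one training-loss value per iteration (an epoch, for the Adam solver). Writing the figure to a PNG with the non-interactive `Agg` backend, and the file name `loss_curve.png`, are assumptions of this sketch; the project may display the plot instead.

```python
import matplotlib

matplotlib.use("Agg")  # render to files; no display is required
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score


def evaluate(model, X_test, y_test):
    """Return the fraction of test images classified correctly."""
    return accuracy_score(y_test, model.predict(X_test))


def plot_loss_curve(model, path="loss_curve.png"):
    """Save the per-iteration training loss of a fitted MLPClassifier to `path`."""
    losses = model.loss_curve_
    fig, ax = plt.subplots()
    ax.plot(range(1, len(losses) + 1), losses)
    ax.set_xlabel("Iteration (epoch)")
    ax.set_ylabel("Training loss")
    ax.set_title("MLPClassifier training loss")
    fig.savefig(path)
    plt.close(fig)  # free the figure's memory
    return path
```

A loss curve that flattens early suggests training has converged; one that is still falling when `max_iter` is reached suggests more iterations could help.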
- Flask API:
- A simple Flask API is set up with a /predict endpoint that accepts a POST request with a JSON object containing the input data (a flattened 28x28 image, i.e. 784 values).
- The API returns the digit predicted by the trained model.
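A `/predict` endpoint of this kind could be sketched as follows. The helper `parse_pixels`, the app factory `create_app`, the optional `scaler` argument, and the 400 error response are hypothetical details of this sketch. Input validation lives in a plain function, and Flask is imported inside the factory so that the validation logic can be reused without a running server.

```python
import numpy as np

N_PIXELS = 28 * 28  # the API expects a flattened 28x28 image


def parse_pixels(payload):
    """Validate a payload like {"data": [...784 numbers...]}; return a (1, 784) array."""
    if not isinstance(payload, dict) or "data" not in payload:
        raise ValueError('expected a JSON object with a "data" field')
    x = np.asarray(payload["data"], dtype=np.float64).reshape(1, -1)
    if x.shape[1] != N_PIXELS:
        raise ValueError(f"expected {N_PIXELS} pixel values, got {x.shape[1]}")
    return x


def create_app(model, scaler=None):
    """Build a Flask app that serves predictions from an already-trained model."""
    from flask import Flask, jsonify, request

    app = Flask(__name__)

    @app.route("/predict", methods=["POST"])
    def predict():
        try:
            # silent=True returns None for a missing or invalid JSON body.
            x = parse_pixels(request.get_json(silent=True))
        except (TypeError, ValueError) as exc:
            return jsonify({"error": str(exc)}), 400
        if scaler is not None:
            x = scaler.transform(x)  # apply the same scaling used during training
        return jsonify({"prediction": int(model.predict(x)[0])})

    return app
```

With a trained `model` in hand, `create_app(model).run(port=5000)` would start the development server. Returning a 400 error for malformed input gives clients a clear message instead of an unhandled server error.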
- Run the script to train the model:
python mnist-classifier.py
- During training, the script will:
- Perform 5-fold cross-validation.
- Train the model on the MNIST dataset.
- Output training loss and accuracy.
- Start the Flask server:
python mnist-classifier.py
- Use an API testing tool (e.g., Postman) or a Python script to send POST requests to the /predict endpoint.
- Example input (a flattened 28x28 array of pixel values):
{ "data": [0, 0, 0, ..., 255] }
- Example response:
{ "prediction": 5 }
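A Python client for this endpoint can be sketched with only the standard library and NumPy. The URL assumes Flask's default development address (`127.0.0.1:5000`), and dividing by 255 follows the project's note that the API expects pixel values normalized between 0 and 1; the helper names `build_predict_request` and `predict_remote` are hypothetical.

```python
import json
import urllib.request

import numpy as np

API_URL = "http://127.0.0.1:5000/predict"  # assumption: Flask's default host and port


def build_predict_request(image, url=API_URL):
    """Build a POST request for a 28x28 grayscale image with 0-255 pixel values."""
    image = np.asarray(image, dtype=np.float64)
    if image.shape != (28, 28):
        raise ValueError(f"expected a 28x28 image, got shape {image.shape}")
    pixels = (image / 255.0).ravel().tolist()  # scale to [0, 1] and flatten to 784 values
    body = json.dumps({"data": pixels}).encode("utf-8")
    return urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"}, method="POST"
    )


def predict_remote(image, url=API_URL, timeout=10.0):
    """Send the image to a running server and return the predicted digit."""
    with urllib.request.urlopen(build_predict_request(image, url), timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))["prediction"]
```

Building the request separately from sending it keeps the payload format easy to inspect; `predict_remote(image)` only works while the Flask server is running.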
mnist-classifier/
├── mnist_classifier.py # Main script
├── README.md # Project documentation
- Cross-Validation Accuracy: Displayed during training.
- Test Accuracy: Computed after training on the held-out test set.
- Training Loss Curve: Plots the loss during training to monitor convergence.
- The model is designed for educational purposes and might not handle real-world images without preprocessing.
- Flask API expects input as a flattened 28x28 array of pixel values, normalized between 0 and 1.
- Enhance the neural network by adding more layers or using a deep learning framework (e.g., TensorFlow or PyTorch).
- Integrate image preprocessing for real-world handwritten digit recognition.
- Dockerize the Flask API for easier deployment.
Shahad-irl
A computer engineering graduate passionate about artificial intelligence and machine learning.
