50 changes: 43 additions & 7 deletions mnist/README.md
@@ -1,29 +1,65 @@
# MNIST

This example shows how to run some simple models on MNIST.
This example shows how to train and run inference on MNIST models.

## Setup

Install the dependencies:

```
pip install -r requirements.txt
```

Run the example with:
## Training

Train the model with:

```
python main.py
python main.py --gpu
```

By default, the example runs on the CPU. To run on the GPU, use:
This will train a simple 2-layer MLP for 10 epochs and save the trained model to `model.safetensors`.
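
The weights file is a standard `safetensors` archive. As a quick sanity check (a minimal sketch, not part of the example scripts), you can inspect it with `mx.load`:

```python
import mlx.core as mx

# mx.load reads a .safetensors file into a dict mapping parameter names to arrays.
weights = mx.load("model.safetensors")
for name, value in weights.items():
    print(name, value.shape)
```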

By default, the example runs on the CPU. To run on the GPU, use `--gpu`.

For a full list of options:

```
python main.py --gpu
python main.py --help
```

For a full list of options run:
## Inference

After training, you can run inference on the test set:

```bash
# Show predictions for 5 random samples
python infer.py --num-samples 5

# Interactive mode - visualize and predict specific samples
python infer.py --interactive

# Use a custom model path
python infer.py --model my_model.safetensors
```
python main.py --help

The inference script provides:
- Predictions with confidence scores
- ASCII art visualization (no matplotlib required)
- Interactive mode to test specific samples

Example output:
```
✓ Sample 1234:
True: 8
Predicted: 8
Confidence: 89.1%
Top 3:
8: 89.1%
3: 5.2%
9: 2.1%
```
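
The trained model can also be used programmatically. The snippet below is a minimal sketch that mirrors what `infer.py` does, assuming the default architecture from `main.py` and a `model.safetensors` file in the current directory:

```python
import mlx.core as mx

import mnist
from main import MLP

# Rebuild the architecture used in main.py and load the saved weights.
model = MLP(num_layers=2, input_dim=784, hidden_dim=32, output_dim=10)
model.load_weights("model.safetensors")

# Classify the first test image.
_, _, test_images, test_labels = map(mx.array, mnist.mnist())
logits = model(test_images[:1])
print("predicted:", mx.argmax(logits, axis=1).item(), "true:", test_labels[0].item())
```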

## Other Frameworks

To run the PyTorch or JAX examples, install the respective framework.
214 changes: 214 additions & 0 deletions mnist/infer.py
@@ -0,0 +1,214 @@
# Copyright © 2025 Apple Inc.

"""
CLI-based inference script for MNIST.

This script loads a trained MNIST model and provides:
1. Random sample predictions with confidence scores
2. ASCII art visualization of digits (no matplotlib required)
3. Interactive mode to test specific samples

Usage:
# First, train the model
python main.py

# Then run inference
python infer.py

# Or with custom model path
python infer.py --model model.safetensors
"""

import argparse
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import numpy as np

import mnist
from main import MLP


def load_model(model_path: str):
"""Load a trained MNIST model from safetensors file."""
if not Path(model_path).exists():
raise FileNotFoundError(
f"Model file not found: {model_path}\n"
"Please run 'python main.py' first to train a model."
)

# Create model with same architecture as main.py
model = MLP(num_layers=2, input_dim=784, hidden_dim=32, output_dim=10)
model.load_weights(model_path)
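    # Evaluate the parameters up front so the lazily loaded weights are materialized.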
mx.eval(model.parameters())

return model


def show_mnist_digit(image, label=None):
"""Visualize MNIST image as ASCII art (28x28)."""
image_2d = np.array(image.reshape(28, 28))

# Normalize to 0-1 range
if image_2d.max() > image_2d.min():
image_2d = (image_2d - image_2d.min()) / (image_2d.max() - image_2d.min())

# Convert to ASCII
chars = " ·:-=+*#%@"
result = []
for row in image_2d:
line = ""
for pixel in row:
char_idx = min(int(pixel * (len(chars) - 1)), len(chars) - 1)
line += chars[char_idx] * 2 # Double width for square appearance
result.append(line)

if label is not None:
print(f"\nTrue Label: {label}")
print("\n".join(result))


def predict_samples(model, images, labels, num_samples=5):
"""Show predictions for random samples."""
print("\n" + "=" * 60)
print("Random Sample Predictions")
print("=" * 60)

indices = np.random.choice(len(labels), num_samples, replace=False)

correct = 0
for idx in indices:
idx = int(idx)
image = images[idx:idx+1]
true_label = int(labels[idx].item())

# Predict
logits = model(image)
predicted_label = int(mx.argmax(logits, axis=1).item())
probabilities = mx.softmax(logits, axis=-1)[0]
confidence = float(probabilities[predicted_label].item())

# Check if correct
is_correct = predicted_label == true_label
if is_correct:
correct += 1

# Display result
status = "✓" if is_correct else "✗"
print(f"\n{status} Sample {idx}:")
print(f" True: {true_label}")
print(f" Predicted: {predicted_label}")
print(f" Confidence: {confidence*100:.1f}%")

# Top 3 predictions
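        # argsort is ascending, so take the last three indices and reverse for highest-probability first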
top3_indices = mx.argsort(probabilities)[-3:][::-1]
print(f" Top 3:")
for i in top3_indices:
i_val = int(i.item())
prob = float(probabilities[i_val].item())
print(f" {i_val}: {prob*100:.1f}%")

print(f"\nAccuracy: {correct}/{num_samples} ({correct/num_samples*100:.0f}%)")


def interactive_mode(model, images, labels):
"""Interactive prediction mode."""
print("\n" + "=" * 60)
print("Interactive Mode")
print(f"Enter an index (0-{len(labels)-1}) to visualize and predict")
print("Enter 'r' for random sample, 'q' to quit")
print("=" * 60)

while True:
try:
user_input = input("\nIndex: ").strip().lower()

if user_input == 'q':
print("Exiting...")
break

if user_input == 'r':
idx = int(np.random.randint(0, len(labels)))
else:
idx = int(user_input)

if idx < 0 or idx >= len(labels):
print(f"Please enter a number between 0 and {len(labels)-1}")
continue

# Show image
image = images[idx]
true_label = int(labels[idx].item())
show_mnist_digit(image, true_label)

# Predict
logits = model(image.reshape(1, -1))
predicted_label = int(mx.argmax(logits, axis=1).item())
probabilities = mx.softmax(logits, axis=-1)[0]
confidence = float(probabilities[predicted_label].item())

print(f"\nPrediction: {predicted_label} (Confidence: {confidence*100:.1f}%)")
if predicted_label == true_label:
print("✓ Correct!")
else:
print("✗ Wrong!")

except ValueError:
print("Invalid input. Please enter a number, 'r', or 'q'.")
except KeyboardInterrupt:
print("\n\nExiting...")
break


def main(args):
# Load model
print(f"Loading model from {args.model}...")
model = load_model(args.model)
print("Model loaded successfully!")

# Load test data
print("Loading MNIST test data...")
_, _, test_images, test_labels = map(mx.array, getattr(mnist, args.dataset)())
print(f"Loaded {len(test_labels)} test samples")

# Show random predictions
if args.num_samples > 0:
predict_samples(model, test_images, test_labels, args.num_samples)

# Interactive mode
if args.interactive:
interactive_mode(model, test_images, test_labels)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run inference on trained MNIST model"
)
parser.add_argument(
"--model",
type=str,
default="model.safetensors",
help="Path to trained model (default: model.safetensors)",
)
parser.add_argument(
"--dataset",
type=str,
default="mnist",
choices=["mnist", "fashion_mnist"],
help="The dataset to use (default: mnist)",
)
parser.add_argument(
"--num-samples",
type=int,
default=5,
help="Number of random samples to predict (default: 5, 0 to skip)",
)
parser.add_argument(
"--interactive",
action="store_true",
help="Enable interactive prediction mode",
)
args = parser.parse_args()

main(args)
11 changes: 11 additions & 0 deletions mnist/main.py
@@ -87,6 +87,11 @@ def eval_fn(X, y):
f" Time {toc - tic:.3f} (s)"
)

# Save the trained model
if args.save_model:
model.save_weights(args.save_model)
print(f"\nModel saved to {args.save_model}")


if __name__ == "__main__":
parser = argparse.ArgumentParser("Train a simple MLP on MNIST with MLX.")
@@ -98,6 +103,12 @@ def eval_fn(X, y):
choices=["mnist", "fashion_mnist"],
help="The dataset to use.",
)
parser.add_argument(
"--save-model",
type=str,
default="model.safetensors",
help="Path to save the trained model (default: model.safetensors).",
)
args = parser.parse_args()
if not args.gpu:
mx.set_default_device(mx.cpu)