From 24269f3c7eb1b14541f9981feddddd5fde604604 Mon Sep 17 00:00:00 2001
From: Hoontheduck
Date: Sat, 22 Nov 2025 20:12:28 +0900
Subject: [PATCH] Add inference script for MNIST

This adds a CLI-based inference script to demonstrate model loading and
prediction after training.

- New file: mnist/infer.py with ASCII visualization
- Modified: mnist/main.py to save trained weights
- Modified: mnist/README.md with inference documentation
---
 mnist/README.md |  50 +++++++++--
 mnist/infer.py  | 214 ++++++++++++++++++++++++++++++++++++++++++++++++
 mnist/main.py   |  11 +++
 3 files changed, 268 insertions(+), 7 deletions(-)
 create mode 100644 mnist/infer.py

diff --git a/mnist/README.md b/mnist/README.md
index 7103360f3..2002ef466 100644
--- a/mnist/README.md
+++ b/mnist/README.md
@@ -1,6 +1,8 @@
 # MNIST
 
-This example shows how to run some simple models on MNIST.
+This example shows how to train and run inference on MNIST models.
+
+## Setup
 
 Install the dependencies:
 
@@ -8,22 +10,56 @@ Install the dependencies:
 pip install -r requirements.txt
 ```
 
-Run the example with:
+## Training
+
+Train the model with:
 
 ```
-python main.py
+python main.py --gpu
 ```
 
-By default, the example runs on the CPU. To run on the GPU, use:
+This will train a simple 2-layer MLP for 10 epochs and save the trained model to `model.safetensors`.
+
+By default, the example runs on the CPU; the `--gpu` flag shown above runs it on the GPU.
+
+For a full list of options, run:
 
 ```
-python main.py --gpu
+python main.py --help
 ```
 
-For a full list of options run:
 
+## Inference
+
+After training, you can run inference on the test set:
+```bash
+# Show predictions for 5 random samples
+python infer.py --num-samples 5
+
+# Interactive mode - visualize and predict specific samples
+python infer.py --interactive
+
+# Use a custom model path
+python infer.py --model my_model.safetensors
 ```
-python main.py --help
+
+The inference script provides:
+- Predictions with confidence scores
+- ASCII art visualization (no matplotlib required)
+- Interactive mode to test specific samples
+
+Example output:
+```
+✓ Sample 1234:
+  True: 8
+  Predicted: 8
+  Confidence: 89.1%
+  Top 3:
+    8: 89.1%
+    3: 5.2%
+    9: 2.1%
 ```
 
+## Other Frameworks
+
 To run the PyTorch or JAX examples install the respective framework.
diff --git a/mnist/infer.py b/mnist/infer.py
new file mode 100644
index 000000000..6f7933825
--- /dev/null
+++ b/mnist/infer.py
@@ -0,0 +1,214 @@
+# Copyright © 2025 Apple Inc.
+
+"""
+CLI-based inference script for MNIST.
+
+This script loads a trained MNIST model and provides:
+1. Random sample predictions with confidence scores
+2. ASCII art visualization of digits (no matplotlib required)
+3. Interactive mode to test specific samples
+
+Usage:
+    # First, train the model
+    python main.py
+
+    # Then run inference
+    python infer.py
+
+    # Or with custom model path
+    python infer.py --model model.safetensors
+"""
+
+import argparse
+from pathlib import Path
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+import mnist
+from main import MLP
+
+
+def load_model(model_path: str):
+    """Load a trained MNIST model from safetensors file."""
+    if not Path(model_path).exists():
+        raise FileNotFoundError(
+            f"Model file not found: {model_path}\n"
+            "Please run 'python main.py' first to train a model."
+        )
+
+    # Create model with same architecture as main.py
+    model = MLP(num_layers=2, input_dim=784, hidden_dim=32, output_dim=10)
+    model.load_weights(model_path)
+    mx.eval(model.parameters())
+
+    return model
+
+
+def show_mnist_digit(image, label=None):
+    """Visualize MNIST image as ASCII art (28x28)."""
+    image_2d = np.array(image.reshape(28, 28))
+
+    # Normalize to 0-1 range
+    if image_2d.max() > image_2d.min():
+        image_2d = (image_2d - image_2d.min()) / (image_2d.max() - image_2d.min())
+
+    # Convert to ASCII
+    chars = " ·:-=+*#%@"
+    result = []
+    for row in image_2d:
+        line = ""
+        for pixel in row:
+            char_idx = min(int(pixel * (len(chars) - 1)), len(chars) - 1)
+            line += chars[char_idx] * 2  # Double width for square appearance
+        result.append(line)
+
+    if label is not None:
+        print(f"\nTrue Label: {label}")
+    print("\n".join(result))
+
+
+def predict_samples(model, images, labels, num_samples=5):
+    """Show predictions for random samples."""
+    print("\n" + "=" * 60)
+    print("Random Sample Predictions")
+    print("=" * 60)
+
+    indices = np.random.choice(len(labels), num_samples, replace=False)
+
+    correct = 0
+    for idx in indices:
+        idx = int(idx)
+        image = images[idx:idx+1]
+        true_label = int(labels[idx].item())
+
+        # Predict
+        logits = model(image)
+        predicted_label = int(mx.argmax(logits, axis=1).item())
+        probabilities = mx.softmax(logits, axis=-1)[0]
+        confidence = float(probabilities[predicted_label].item())
+
+        # Check if correct
+        is_correct = predicted_label == true_label
+        if is_correct:
+            correct += 1
+
+        # Display result
+        status = "✓" if is_correct else "✗"
+        print(f"\n{status} Sample {idx}:")
+        print(f"  True: {true_label}")
+        print(f"  Predicted: {predicted_label}")
+        print(f"  Confidence: {confidence*100:.1f}%")
+
+        # Top 3 predictions
+        top3_indices = mx.argsort(probabilities)[-3:][::-1]
+        print(f"  Top 3:")
+        for i in top3_indices:
+            i_val = int(i.item())
+            prob = float(probabilities[i_val].item())
+            print(f"    {i_val}: {prob*100:.1f}%")
+
+    print(f"\nAccuracy: {correct}/{num_samples} ({correct/num_samples*100:.0f}%)")
+
+
+def interactive_mode(model, images, labels):
+    """Interactive prediction mode."""
+    print("\n" + "=" * 60)
+    print("Interactive Mode")
+    print(f"Enter an index (0-{len(labels)-1}) to visualize and predict")
+    print("Enter 'r' for random sample, 'q' to quit")
+    print("=" * 60)
+
+    while True:
+        try:
+            user_input = input("\nIndex: ").strip().lower()
+
+            if user_input == 'q':
+                print("Exiting...")
+                break
+
+            if user_input == 'r':
+                idx = int(np.random.randint(0, len(labels)))
+            else:
+                idx = int(user_input)
+
+            if idx < 0 or idx >= len(labels):
+                print(f"Please enter a number between 0 and {len(labels)-1}")
+                continue
+
+            # Show image
+            image = images[idx]
+            true_label = int(labels[idx].item())
+            show_mnist_digit(image, true_label)
+
+            # Predict
+            logits = model(image.reshape(1, -1))
+            predicted_label = int(mx.argmax(logits, axis=1).item())
+            probabilities = mx.softmax(logits, axis=-1)[0]
+            confidence = float(probabilities[predicted_label].item())
+
+            print(f"\nPrediction: {predicted_label} (Confidence: {confidence*100:.1f}%)")
+            if predicted_label == true_label:
+                print("✓ Correct!")
+            else:
+                print("✗ Wrong!")
+
+        except ValueError:
+            print("Invalid input. Please enter a number, 'r', or 'q'.")
+        except KeyboardInterrupt:
+            print("\n\nExiting...")
+            break
+
+
+def main(args):
+    # Load model
+    print(f"Loading model from {args.model}...")
+    model = load_model(args.model)
+    print("Model loaded successfully!")
+
+    # Load test data
+    print("Loading MNIST test data...")
+    _, _, test_images, test_labels = map(mx.array, getattr(mnist, args.dataset)())
+    print(f"Loaded {len(test_labels)} test samples")
+
+    # Show random predictions
+    if args.num_samples > 0:
+        predict_samples(model, test_images, test_labels, args.num_samples)
+
+    # Interactive mode
+    if args.interactive:
+        interactive_mode(model, test_images, test_labels)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Run inference on trained MNIST model"
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="model.safetensors",
+        help="Path to trained model (default: model.safetensors)",
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="mnist",
+        choices=["mnist", "fashion_mnist"],
+        help="The dataset to use (default: mnist)",
+    )
+    parser.add_argument(
+        "--num-samples",
+        type=int,
+        default=5,
+        help="Number of random samples to predict (default: 5, 0 to skip)",
+    )
+    parser.add_argument(
+        "--interactive",
+        action="store_true",
+        help="Enable interactive prediction mode",
+    )
+    args = parser.parse_args()
+
+    main(args)
diff --git a/mnist/main.py b/mnist/main.py
index 5ee7c5d91..543645fce 100644
--- a/mnist/main.py
+++ b/mnist/main.py
@@ -87,6 +87,11 @@ def eval_fn(X, y):
             f" Time {toc - tic:.3f} (s)"
         )
 
+    # Save the trained model
+    if args.save_model:
+        model.save_weights(args.save_model)
+        print(f"\nModel saved to {args.save_model}")
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("Train a simple MLP on MNIST with MLX.")
@@ -98,6 +103,12 @@ def eval_fn(X, y):
         choices=["mnist", "fashion_mnist"],
         help="The dataset to use.",
     )
+    parser.add_argument(
+        "--save-model",
+        type=str,
+        default="model.safetensors",
+        help="Path to save the trained model (default: model.safetensors).",
+    )
     args = parser.parse_args()
     if not args.gpu:
         mx.set_default_device(mx.cpu)
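
A quick way to sanity-check the save/load round trip this patch introduces, outside the CLI. This is a minimal sketch, not part of the patch: it assumes `python main.py` has already written `model.safetensors` with the default settings, and it reuses the `MLP` hyperparameters hard-coded in `infer.py` (2 layers, 784 -> 32 -> 10).

```python
# Sketch: reload the weights saved by main.py and score a single test digit.
# Assumes model.safetensors exists in the current directory (default save path).
import mlx.core as mx

import mnist
from main import MLP

model = MLP(num_layers=2, input_dim=784, hidden_dim=32, output_dim=10)
model.load_weights("model.safetensors")
mx.eval(model.parameters())

# mnist.mnist() returns (train_images, train_labels, test_images, test_labels)
_, _, test_images, test_labels = map(mx.array, mnist.mnist())
logits = model(test_images[:1])
predicted = int(mx.argmax(logits, axis=1).item())
print(f"predicted={predicted} true={int(test_labels[0].item())}")
```

Running `python infer.py --num-samples 1` exercises the same path through the CLI.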