
PyTorch CLI flower classifier (transfer learning): train + predict (top-K).



andigles/flower-image-classifier


Flower Image Classifier (Udacity AIPND)


A command-line (CLI) app that uses transfer learning to train an image classifier on the Udacity Flowers dataset and to predict the flower name for a new image.

TL;DR (try this first):

  • No-training quickstart (recommended): download a pretrained checkpoint from GitHub Releases and run predict.py.
  • Train your own checkpoint (slow on CPU-only machines): python train.py flowers --epochs 2
  • Predict a label from an image: python predict.py flowers/test/3/image_06634.jpg save_directory/checkpoint.pth --top_k 5
  • Uses a pretrained backbone (default: ResNet-50) + a new classifier head.
  • Checkpoint integrity: compare the file's SHA-256 hash with the value listed in the Release notes.
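The integrity check above can be scripted. This is a minimal sketch using only Python's standard hashlib module; the function names sha256_of and verify_checkpoint are hypothetical and are not part of this repo. It streams the file in chunks so a large checkpoint is never loaded into memory at once.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as fh:
        # iter() with a sentinel keeps reading until read() returns b"" (EOF).
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_checkpoint(path, expected_hex):
    """True if the file's digest matches the hash copied from the Release notes."""
    # Normalise the pasted value: strip whitespace, compare in lowercase hex.
    return sha256_of(path) == expected_hex.strip().lower()
```

Usage: verify_checkpoint("save_directory/checkpoint.pth", "<sha256 from the Release notes>"). On macOS/Linux, shasum -a 256 save_directory/checkpoint.pth (or sha256sum on Linux) prints the same digest.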

Demo (what you should see)

After training (or after downloading the pretrained checkpoint), run:

python predict.py flowers/test/3/image_06634.jpg save_directory/checkpoint.pth --top_k 5

Example output (format may vary slightly):

Path to image: flowers/test/3/image_06634.jpg
Path to checkpoint: save_directory/checkpoint.pth
Number of top K classes: 5
Path to category names file: cat_to_name.json
GPU: False

Prediction (name): cape flower
Probability: 0.10262521356344223

Top classes (names): ['cape flower', 'cyclamen', 'lotus lotus', 'magnolia', 'columbine']
Top probabilities: [0.10262521356344223, 0.07787298411130905, 0.05228663235902786, 0.048569660633802414, 0.0458136685192585]
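The "Top classes" and "Top probabilities" lines come from turning the model's output scores into probabilities, keeping the K largest, and mapping output indices back to flower names. The sketch below shows that decoding step in plain Python. It assumes, hypothetically, that the model outputs log-probabilities, that the checkpoint stores a class_to_idx dict like the one torchvision's ImageFolder builds, and that cat_to_name.json maps class-id strings to names; predict.py may implement this differently (for example with torch.topk).

```python
import math


def top_k_names(log_probs, class_to_idx, cat_to_name, k=5):
    """Turn one image's log-probabilities into (top-k names, top-k probabilities)."""
    # Exponentiate log-probabilities back into probabilities.
    probs = [math.exp(lp) for lp in log_probs]
    # Invert class_to_idx: model output index -> class id string (e.g. "3").
    idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
    # Indices of the k largest probabilities, most likely first.
    top_idx = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    names = [cat_to_name[idx_to_class[i]] for i in top_idx]
    return names, [probs[i] for i in top_idx]
```

The inverse lookup matters because ImageFolder sorts class folder names as strings ("1", "10", "100", "101", "102", "11", ...), so output index 2 is generally not class "3"; using the raw index as the class id would print the wrong names.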

Smoke test

python scripts/smoke_test.py

What’s in this repo

  • train.py — trains a classifier and saves a checkpoint
  • predict.py — loads a checkpoint and predicts top-k classes for an input image
  • helper.py — training / preprocessing / checkpoint helpers
  • get_input_args.py — CLI argument definitions
  • cat_to_name.json — mapping from class id → flower name
  • assets/ — screenshots / example images (optional)
  • notebooks/ — project notebook (reference / exploration)
  • scripts/ — utility scripts (e.g. the smoke test)

Not included: the dataset folder flowers/ (it is ignored by git).


Approach

This project uses transfer learning:

  1. Load a pretrained convolutional neural network (CNN) backbone (default: ResNet-50).
  2. Replace the final classification layer with a new fully-connected classifier head for 102 flower classes.
  3. Freeze (or mostly freeze) backbone parameters and train the classifier head on the flowers dataset.
  4. Save a checkpoint so you can run fast predictions later without retraining.
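Steps 2 and 3 mean that only the new head's weights are updated during training. The sketch below works out the head's Linear-layer sizes from a --hidden_units list and counts its trainable parameters. It assumes the head is a plain stack of fully-connected layers ending in 102 outputs, fed by ResNet-50's 2048-dimensional pooled features (the in_features of torchvision's resnet50 fc layer); activations, dropout, and the exact layout used in helper.py are not modelled and may differ. The names head_plan and trainable_params are hypothetical.

```python
RESNET50_FEATURES = 2048  # in_features of torchvision resnet50's final fc layer
NUM_CLASSES = 102         # number of flower categories


def head_plan(hidden_units, in_features=RESNET50_FEATURES, num_classes=NUM_CLASSES):
    """Return the (in, out) size of every Linear layer in the classifier head."""
    sizes = [in_features, *hidden_units, num_classes]
    # Pair each size with the next one: 2048->h1, h1->h2, ..., hN->102.
    return list(zip(sizes[:-1], sizes[1:]))


def trainable_params(plan):
    """Count weights plus biases across all Linear layers in the head."""
    return sum(n_in * n_out + n_out for n_in, n_out in plan)
```

For --hidden_units 512 256, head_plan([512, 256]) gives [(2048, 512), (512, 256), (256, 102)], which is 1,206,630 trainable parameters: a small fraction of ResNet-50's roughly 25 million, all of which stay frozen.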

Quickstart — clone → predict (no training)

1) Create an environment + install

Option A (recommended): pip + virtual environment

python -m venv .venv

# Windows (Git Bash)
source .venv/Scripts/activate

# macOS / Linux
# source .venv/bin/activate

pip install --upgrade pip
pip install -r requirements.txt

Option B: Conda

conda create -n flower_image_classifier python=3.11 -y
conda activate flower_image_classifier
pip install --upgrade pip
pip install -r requirements.txt

2) Dataset layout (expected)

Place the Udacity Flowers dataset in the repo root. Each split holds one subfolder per class id (e.g. flowers/test/3/):

flowers/
  train/
  valid/
  test/

3) Download the pretrained checkpoint (recommended)

Download checkpoint.pth from the repo’s GitHub Release (e.g. v1.0.0) and save it to:

save_directory/checkpoint.pth

Now you can run predictions without training:

python predict.py flowers/test/3/image_06634.jpg save_directory/checkpoint.pth --top_k 5

If you don’t have the dataset (flowers/) locally, pass the path to any local image instead of flowers/test/...

4) Train (optional — creates your own checkpoint)

Minimal training run (2 epochs):

python train.py flowers --epochs 2

Defaults:

  • checkpoint folder: save_directory/
  • checkpoint file: save_directory/checkpoint.pth
  • architecture: resnet50

A “more realistic” training example:

python train.py flowers --arch resnet50 --learning_rate 0.003 --hidden_units 512 256 --dropout 0.2 --epochs 3

5) Predict (top-5)

python predict.py flowers/test/3/image_06634.jpg save_directory/checkpoint.pth --top_k 5

GPU usage (optional)

If you have a CUDA-capable GPU and a compatible PyTorch install, add --gpu:

python train.py flowers --epochs 3 --gpu
python predict.py flowers/test/3/image_06634.jpg save_directory/checkpoint.pth --top_k 5 --gpu

If no GPU is available, the code runs on CPU.
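The GPU fallback described above can be expressed as a small decision function. This is a sketch under assumptions: pick_device is a hypothetical name, and the warning is an illustrative addition (the README does not say whether train.py or predict.py warn). CUDA availability is passed in as a plain boolean so the logic does not depend on a GPU.

```python
import warnings


def pick_device(gpu_requested: bool, cuda_available: bool) -> str:
    """Use "cuda" only when --gpu was given and CUDA is usable; otherwise "cpu"."""
    if gpu_requested and not cuda_available:
        # Illustrative: tell the user why the run is on CPU despite --gpu.
        warnings.warn("--gpu was given but CUDA is not available; using CPU")
    return "cuda" if gpu_requested and cuda_available else "cpu"
```

In PyTorch code this would typically be wrapped as torch.device(pick_device(args.gpu, torch.cuda.is_available())), where args.gpu is the parsed --gpu flag.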


Results

Setting               Value
--------------------  ---------
Backbone              ResNet-50
Epochs                5
Learning rate         0.0005
Hidden units          512
Dropout               0.2
Validation accuracy   0.920
Test accuracy         0.878

CLI reference

Training help

python train.py -h

Common arguments:

  • data_dir (positional): dataset folder (e.g. flowers)
  • --save_dir: folder where checkpoints are saved (default: save_directory/)
  • --arch: pretrained architecture (default: resnet50)
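  • --learning_rate: learning rate used when training the classifier head
  • --hidden_units: sizes of the classifier head's hidden layers, as a space-separated list (e.g. --hidden_units 512 256)
  • --dropout: dropout probability in the classifier head
  • --epochs: number of training epochs
  • --gpu: train on a CUDA GPU if one is available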
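How a space-separated --hidden_units list becomes a Python list can be shown with the standard argparse module. This is a sketch that mirrors the arguments listed above; it is not the repo's get_input_args.py. Only the defaults stated in this README (save_directory/ and resnet50) are filled in; the other options default to None here because their real defaults are not documented, and treating --gpu as a store_true flag is an assumption.

```python
import argparse


def build_train_parser():
    """Hypothetical parser mirroring train.py's documented arguments."""
    parser = argparse.ArgumentParser(description="Train a flower classifier")
    parser.add_argument("data_dir", help="dataset folder, e.g. flowers")
    parser.add_argument("--save_dir", default="save_directory/",
                        help="folder where checkpoints are saved")
    parser.add_argument("--arch", default="resnet50", help="pretrained architecture")
    parser.add_argument("--learning_rate", type=float)
    # nargs="+" gathers one or more space-separated values into a list of ints.
    parser.add_argument("--hidden_units", type=int, nargs="+")
    parser.add_argument("--dropout", type=float)
    parser.add_argument("--epochs", type=int)
    # store_true: the flag takes no value and is False unless present.
    parser.add_argument("--gpu", action="store_true")
    return parser
```

For example, build_train_parser().parse_args(["flowers", "--hidden_units", "512", "256", "--gpu"]) yields hidden_units == [512, 256] and gpu == True.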

Prediction help

python predict.py -h

Common arguments:

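  • path_to_image (positional): image to classify
  • path_to_checkpoint (positional): checkpoint file (e.g. save_directory/checkpoint.pth)
  • --top_k: number of most likely classes to report
  • --category_names: JSON file mapping class ids to flower names (default: cat_to_name.json)
  • --gpu: run inference on a CUDA GPU if one is available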

License

This project is licensed under the MIT License — see the LICENSE file for details.