Command-line (CLI) app to train an image classifier on the Udacity Flowers dataset and predict a flower name from a new image using transfer learning.
TL;DR (try this first):
- No-training quickstart (recommended): download a pretrained checkpoint from GitHub Releases and run predict.py.
- Train your own checkpoint (slow on CPU-only machines):
  python train.py flowers --epochs 2
- Predict a label from an image:
  python predict.py flowers/test/3/image_06634.jpg save_directory/checkpoint.pth --top_k 5
- Uses a pretrained backbone (default: ResNet-50) plus a new classifier head.
- Checkpoint integrity: compare the file's SHA-256 hash with the value in the Release notes.
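Any SHA-256 tool works for the integrity check. The sketch below is one cross-platform option using Python's standard hashlib; sha256_of_file and checkpoint_matches are hypothetical helpers, not part of this repo, and the expected hash still has to be copied from the Release notes.

```python
import hashlib
from pathlib import Path


def sha256_of_file(path, chunk_size=1 << 20):
    """Hex SHA-256 digest of a file, read in 1 MiB chunks so memory use stays bounded."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a b"" sentinel keeps calling f.read() until end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def checkpoint_matches(path, expected_hex):
    """Hypothetical helper: compare with the hash copied from the Release notes."""
    return sha256_of_file(path) == expected_hex.strip().lower()


if __name__ == "__main__":
    ckpt = Path("save_directory/checkpoint.pth")
    if ckpt.exists():
        print(sha256_of_file(ckpt))  # compare by eye or pass to checkpoint_matches()
```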
After training, run:
python predict.py flowers/test/3/image_06634.jpg save_directory/checkpoint.pth --top_k 5

Example output (format may vary slightly):
Path to image: flowers/test/3/image_06634.jpg
Path to checkpoint: save_directory/checkpoint.pth
Number of top K classes: 5
Path to category names file: cat_to_name.json
GPU: False
Prediction (name): cape flower
Probability: 0.10262521356344223
Top classes (names): ['cape flower', 'cyclamen', 'lotus lotus', 'magnolia', 'columbine']
Top probabilities: [0.10262521356344223, 0.07787298411130905, 0.05228663235902786, 0.048569660633802414, 0.0458136685192585]
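The top-k names and probabilities above can be derived from the model's raw scores roughly as follows. This is a stdlib-only sketch, not the code in predict.py: it assumes the model returns logits and that the checkpoint keeps a class_to_idx mapping like the one torchvision's ImageFolder builds; top_k_names is a hypothetical helper. Inverting class_to_idx matters because ImageFolder sorts class folder names as strings, so output index 1 is not necessarily class "2".

```python
import json
import math


def load_category_names(path="cat_to_name.json"):
    """Read the class id -> flower name mapping."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def softmax(logits):
    """Numerically stable softmax: subtract the max score before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_names(logits, class_to_idx, cat_to_name, k=5):
    """Hypothetical helper: names and probabilities of the k most likely classes.

    Assumes class_to_idx maps class folder names to output indices (as built by
    torchvision's ImageFolder) and that the model returns raw logits.
    """
    probs = softmax(logits)
    idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
    # Output indices ordered by probability, highest first, truncated to k.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    names = [cat_to_name[idx_to_class[i]] for i in top]
    return names, [probs[i] for i in top]


if __name__ == "__main__":
    # Toy 3-class example; ids and names are made up, not read from cat_to_name.json.
    class_to_idx = {"1": 0, "10": 1, "2": 2}  # string-sorted, as ImageFolder orders folders
    cat_to_name = {"1": "cape flower", "2": "cyclamen", "10": "magnolia"}
    names, probs = top_k_names([2.0, 0.5, 1.0], class_to_idx, cat_to_name, k=2)
    print(names)  # ['cape flower', 'cyclamen']
    print(probs)
```

In a PyTorch script the same steps would usually be torch.softmax followed by Tensor.topk on the model output.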
Run the smoke test:
python scripts/smoke_test.py

Project structure:
- train.py: trains a classifier and saves a checkpoint
- predict.py: loads a checkpoint and predicts the top-k classes for an input image
- helper.py: training, preprocessing, and checkpoint helpers
- get_input_args.py: CLI argument definitions
- cat_to_name.json: mapping from class id to flower name
- assets/: screenshots and example images (optional)
- notebooks/: project notebook (reference and exploration)
- scripts/: utility scripts (smoke test and others)
Not included: the dataset folder flowers/ (it is ignored by git).
This project uses transfer learning:
- Load a pretrained convolutional neural network (CNN) backbone (default: ResNet-50).
- Replace the final classification layer with a new fully-connected classifier head for 102 flower classes.
- Freeze (or mostly freeze) backbone parameters and train the classifier head on the flowers dataset.
- Save a checkpoint so you can run fast predictions later without retraining.
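The size of the new head can be worked out with simple arithmetic. This sketch assumes the head is a stack of fully connected layers going from the backbone's 2048 output features (the input size of ResNet-50's final layer) through the --hidden_units sizes to 102 classes; the function names are hypothetical, and in PyTorch each (in, out) pair would become one nn.Linear layer.

```python
RESNET50_FEATURES = 2048  # in_features of torchvision's ResNet-50 final fc layer
NUM_CLASSES = 102         # flower categories in the dataset


def classifier_head_shapes(in_features, hidden_units, num_classes=NUM_CLASSES):
    """Hypothetical helper: (in, out) sizes of each fully connected layer in the head."""
    sizes = [in_features, *hidden_units, num_classes]
    return list(zip(sizes[:-1], sizes[1:]))


def trainable_params(shapes):
    """Weights plus biases of the head; backbone parameters are not counted."""
    return sum(n_in * n_out + n_out for n_in, n_out in shapes)


if __name__ == "__main__":
    shapes = classifier_head_shapes(RESNET50_FEATURES, [512, 256])
    print(shapes)                    # [(2048, 512), (512, 256), (256, 102)]
    print(trainable_params(shapes))  # 1206630
```

With --hidden_units 512 256 the head has about 1.2 million trainable parameters, a small fraction of ResNet-50's roughly 25 million backbone parameters, which is why training only the head is much cheaper than fine-tuning the whole network.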
Setup with venv:
python -m venv .venv
# Windows (Git Bash)
source .venv/Scripts/activate
# macOS / Linux
# source .venv/bin/activate
pip install --upgrade pip
pip install -r requirements.txt

Or with conda:
conda create -n flower_image_classifier python=3.11 -y
conda activate flower_image_classifier
pip install --upgrade pip
pip install -r requirements.txt

Place the Udacity Flowers dataset in the repo root:
flowers/
  train/
  valid/
  test/
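A quick layout check can catch a misplaced dataset before a long training run. This sketch assumes each split holds one subfolder per class id (as in flowers/test/3/image_06634.jpg) and counts only .jpg, .jpeg and .png files; summarize_dataset is a hypothetical helper, not a repo script.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}  # assumed image formats
SPLITS = ("train", "valid", "test")


def summarize_dataset(root="flowers"):
    """Hypothetical helper: (class folders, images) per split; raises if a split is missing."""
    root = Path(root)
    summary = {}
    for split in SPLITS:
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split folder: {split_dir}")
        # Assumed layout: one subfolder per class id inside each split.
        class_dirs = [d for d in split_dir.iterdir() if d.is_dir()]
        n_images = sum(
            1
            for d in class_dirs
            for f in d.iterdir()
            if f.suffix.lower() in IMAGE_SUFFIXES
        )
        summary[split] = (len(class_dirs), n_images)
    return summary


if __name__ == "__main__":
    if Path("flowers").is_dir():
        for split, (n_classes, n_images) in summarize_dataset().items():
            print(f"{split}: {n_classes} classes, {n_images} images")
```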
Download checkpoint.pth from the repo’s GitHub Release (e.g. v1.0.0) and save it to:
save_directory/checkpoint.pth
Now you can run predictions without training:
python predict.py flowers/test/3/image_06634.jpg save_directory/checkpoint.pth --top_k 5

If you don’t have the dataset locally (flowers/), use any local image path instead of flowers/test/....
Minimal training run (2 epochs):
python train.py flowers --epochs 2

Defaults:
- checkpoint folder: save_directory/
- checkpoint file: save_directory/checkpoint.pth
- architecture: resnet50
A “more realistic” training example:
python train.py flowers --arch resnet50 --learning_rate 0.003 --hidden_units 512 256 --dropout 0.2 --epochs 3
python predict.py flowers/test/3/image_06634.jpg save_directory/checkpoint.pth --top_k 5

If you have a CUDA-capable GPU and a compatible PyTorch install, add --gpu:
python train.py flowers --epochs 3 --gpu
python predict.py flowers/test/3/image_06634.jpg save_directory/checkpoint.pth --top_k 5 --gpu

If no GPU is available, the code runs on CPU.
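One way to get the CPU fallback is to use CUDA only when --gpu is passed and PyTorch reports a usable device. This is a sketch, not the repo's code; pick_device is a hypothetical name, the warning message is an assumed behavior, and the returned string can be passed to torch.device(...) or model.to(...).

```python
def pick_device(gpu_requested):
    """Hypothetical helper: "cuda" only if --gpu was passed and CUDA is usable."""
    try:
        import torch
        cuda_ok = torch.cuda.is_available()
    except ImportError:
        cuda_ok = False  # no PyTorch at all, so no GPU path either
    if gpu_requested and not cuda_ok:
        # Assumed behavior: warn and continue on CPU instead of failing.
        print("--gpu requested but CUDA is not available; using CPU")
    return "cuda" if gpu_requested and cuda_ok else "cpu"


if __name__ == "__main__":
    print(pick_device(gpu_requested=False))  # cpu
```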
Example run (settings and accuracy):

| Setting | Value |
|---|---|
| Backbone | ResNet-50 |
| Epochs | 5 |
| Learning rate | 0.0005 |
| Hidden units | 512 |
| Dropout | 0.2 |
| Validation accuracy | 0.920 |
| Test accuracy | 0.878 |
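To see all training options:
python train.py -h

Common arguments:
- data_dir (positional): dataset folder (e.g. flowers)
- --save_dir: folder where checkpoints are saved (default: save_directory/)
- --arch: pretrained architecture (default: resnet50)
- --learning_rate: optimizer learning rate
- --hidden_units: hidden layer sizes of the classifier head (space-separated list, e.g. --hidden_units 512 256)
- --dropout: dropout probability in the classifier head
- --epochs: number of training epochs
- --gpu: train on a CUDA GPU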
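To see all prediction options:
python predict.py -h

Common arguments:
- path_to_image (positional): image to classify
- path_to_checkpoint (positional): saved checkpoint to load
- --top_k: number of most likely classes to return
- --category_names: JSON file mapping class ids to flower names (default: cat_to_name.json)
- --gpu: run inference on a CUDA GPU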
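Both argument lists map directly onto Python's argparse. The sketch below mirrors only the documented flags and defaults; it is not get_input_args.py, the builder function names are hypothetical, and defaults the README does not state are left as None rather than guessed.

```python
import argparse


def build_train_parser():
    """Hypothetical sketch of the train.py interface documented in this README."""
    p = argparse.ArgumentParser(prog="train.py")
    p.add_argument("data_dir", help="dataset folder, e.g. flowers")
    p.add_argument("--save_dir", default="save_directory", help="checkpoint folder")
    p.add_argument("--arch", default="resnet50", help="pretrained architecture")
    # The README lists no defaults for these four, so this sketch leaves them as None.
    p.add_argument("--learning_rate", type=float)
    p.add_argument("--hidden_units", type=int, nargs="+")  # e.g. --hidden_units 512 256
    p.add_argument("--dropout", type=float)
    p.add_argument("--epochs", type=int)
    p.add_argument("--gpu", action="store_true")
    return p


def build_predict_parser():
    """Hypothetical sketch of the predict.py interface documented in this README."""
    p = argparse.ArgumentParser(prog="predict.py")
    p.add_argument("path_to_image")
    p.add_argument("path_to_checkpoint")
    p.add_argument("--top_k", type=int)  # default not documented in the README
    p.add_argument("--category_names", default="cat_to_name.json")
    p.add_argument("--gpu", action="store_true")
    return p


if __name__ == "__main__":
    train_args = build_train_parser().parse_args(
        ["flowers", "--hidden_units", "512", "256", "--epochs", "3"]
    )
    print(train_args.hidden_units, train_args.arch)  # [512, 256] resnet50
    predict_args = build_predict_parser().parse_args(
        ["flowers/test/3/image_06634.jpg", "save_directory/checkpoint.pth", "--top_k", "5"]
    )
    print(predict_args.top_k, predict_args.category_names)  # 5 cat_to_name.json
```

nargs="+" combined with type=int is what lets --hidden_units accept a space-separated list such as 512 256 and stop at the next flag.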
This project is licensed under the MIT License — see the LICENSE file for details.