A PyTorch-based CNN project for classifying fish doodles vs. not-fish doodles, using ResNet18 and QuickDraw data. Includes preprocessing that handles transparent images, early stopping, ONNX export, and test scripts.
This was made in conjunction with my fishes project.
- train_fish_doodle_classifier.py: Main training script. Handles dataset loading, preprocessing, model training (with early stopping), and ONNX export. Also includes utilities for generating datasets from Google QuickDraw.
- test_fish_classifier.py: Script for evaluating the trained model (PyTorch or ONNX) on your dataset or external images. Saves transformed images for inspection.
- requirements.txt: Python dependencies for training and testing (PyTorch, torchvision, onnx, onnxruntime, scikit-learn, tqdm, Pillow, numpy, opencv-python).
- dataset/: Directory for your training images, with subfolders fish/ and not_fish/.
- quickdraw/: Contains downloaded QuickDraw .ndjson files for generating additional training data.
- fish_doodle_classifier.pth: Saved PyTorch model weights after training.
- fish_doodle_classifier.onnx: Exported ONNX model for cross-platform inference.
- README.md: Here :)
- Install dependencies:
  pip install -r requirements.txt
- Prepare your dataset:
  - Place fish images in dataset/fish/ and not-fish images in dataset/not_fish/.
  - Or, use QuickDraw data by running:
    python train_fish_doodle_classifier.py --pretrain
- Train the model:
  python train_fish_doodle_classifier.py
  - Early stopping is enabled to reduce overfitting.
  - The best model is exported as both .pth and .onnx.
- Test the model:
  python test_fish_classifier.py
  - Evaluates on your dataset and prints classification metrics.
  - Also supports ONNX inference for deployment.
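The exact metrics the test script prints aren't spelled out here; since scikit-learn is in requirements.txt, it may use something like `sklearn.metrics.classification_report`. As a dependency-free sketch of what per-class precision, recall, and F1 mean for this task, assuming labels 0 = fish and 1 = not_fish (the alphabetical order torchvision's `ImageFolder` would assign if the `dataset/` subfolders were loaded with it). The function name is hypothetical.

```python
from collections import Counter
from typing import Dict, List

# Assumed label order: alphabetical folder names, as ImageFolder would assign them.
CLASS_NAMES = ["fish", "not_fish"]


def per_class_metrics(y_true: List[int], y_pred: List[int]) -> Dict[str, Dict[str, float]]:
    """Precision, recall and F1 for each class, plus overall accuracy."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    pairs = Counter(zip(y_true, y_pred))  # (true, predicted) -> count
    report: Dict[str, Dict[str, float]] = {}
    for label, name in enumerate(CLASS_NAMES):
        tp = pairs[(label, label)]
        fp = sum(n for (t, p), n in pairs.items() if p == label and t != label)
        fn = sum(n for (t, p), n in pairs.items() if t == label and p != label)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[name] = {"precision": precision, "recall": recall, "f1": f1}
    correct = sum(n for (t, p), n in pairs.items() if t == p)
    report["overall"] = {"accuracy": correct / len(y_true) if y_true else 0.0}
    return report
```

For example, with `y_true = [0, 0, 1, 1, 1]` and `y_pred = [0, 1, 1, 1, 0]`, fish gets precision and recall of 0.5, not_fish gets 2/3 for both, and overall accuracy is 0.6.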
Transparency Handling: All images are composited onto a white background before preprocessing, so transparent pixels become white. This means a fish drawn entirely in white blends into the background and won't be recognized.
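A minimal NumPy sketch of the compositing step, assuming an H x W x 4 uint8 RGBA input. The function name is hypothetical and the training script may do this with Pillow instead; the arithmetic is standard "over" compositing onto opaque white, and it also shows why white strokes disappear.

```python
import numpy as np


def composite_on_white(rgba: np.ndarray) -> np.ndarray:
    """Blend an H x W x 4 uint8 RGBA image onto opaque white; returns H x W x 3 uint8 RGB."""
    if rgba.ndim != 3 or rgba.shape[2] != 4:
        raise ValueError("expected an H x W x 4 RGBA array")
    rgb = rgba[..., :3].astype(np.float32)
    alpha = rgba[..., 3:4].astype(np.float32) / 255.0  # keep last axis so it broadcasts over RGB
    white = np.full_like(rgb, 255.0)
    blended = alpha * rgb + (1.0 - alpha) * white  # transparent pixels -> white
    return np.round(blended).astype(np.uint8)
```

An opaque black stroke stays black, a fully transparent pixel becomes (255, 255, 255), and an opaque white stroke also ends up as (255, 255, 255), identical to the background.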
Early Stopping: Training halts if validation loss does not improve for 5 epochs, reducing overfitting.
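A minimal sketch of patience-based early stopping with the patience of 5 described above. The class and attribute names, and the optional `min_delta` threshold, are hypothetical; the script's own implementation may treat ties or small improvements differently.

```python
class EarlyStopping:
    """Stop training once validation loss fails to improve for `patience` epochs in a row."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0) -> None:
        self.patience = patience    # 5 matches the README
        self.min_delta = min_delta  # hypothetical: minimum decrease that counts as an improvement
        self.best_loss = float("inf")
        self.bad_epochs = 0         # consecutive epochs without improvement
        self.improved = False       # True when the latest epoch set a new best loss

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        self.improved = val_loss < self.best_loss - self.min_delta
        if self.improved:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

In a training loop you would call `step()` once per epoch after validation, save the weights whenever `improved` is True (so the saved model is the best one seen), and break out of the loop when `step()` returns True.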
ONNX Export: The model is exported to ONNX for compatibility with non-PyTorch environments. I use this in the fishes frontend.
QuickDraw Integration: Scripts can auto-download and convert QuickDraw doodles for both fish and not-fish classes. I ended up not using this, but have left it in the repo in case someone else wants to.
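A stdlib-plus-NumPy sketch of turning QuickDraw doodles into images, assuming the "simplified" .ndjson format: one JSON object per line, whose `drawing` field is a list of strokes, each a pair `[xs, ys]` of coordinates in 0 to 255. Function names are hypothetical, strokes are drawn one pixel wide with Bresenham's line algorithm, and the repo's own conversion may differ in stroke width and scaling.

```python
import json
from typing import Iterator, List

import numpy as np


def draw_line(canvas: np.ndarray, x0: int, y0: int, x1: int, y1: int) -> None:
    """Draw a 1-pixel black line (ink = 0) on a white uint8 canvas using Bresenham's algorithm."""
    dx, dy = abs(x1 - x0), -abs(y1 - y0)
    sx = 1 if x0 < x1 else -1
    sy = 1 if y0 < y1 else -1
    err = dx + dy
    h, w = canvas.shape
    while True:
        if 0 <= x0 < w and 0 <= y0 < h:
            canvas[y0, x0] = 0
        if x0 == x1 and y0 == y1:
            break
        e2 = 2 * err
        if e2 >= dy:
            err += dy
            x0 += sx
        if e2 <= dx:
            err += dx
            y0 += sy


def render_drawing(drawing: List[List[List[int]]], size: int = 256) -> np.ndarray:
    """Rasterise simplified QuickDraw strokes (coords 0..255) to a size x size uint8 image."""
    canvas = np.full((size, size), 255, dtype=np.uint8)  # white background
    scale = (size - 1) / 255.0
    for xs, ys in drawing:
        points = [(round(x * scale), round(y * scale)) for x, y in zip(xs, ys)]
        if len(points) == 1:  # a single-point stroke is a dot
            draw_line(canvas, *points[0], *points[0])
        for (x0, y0), (x1, y1) in zip(points, points[1:]):
            draw_line(canvas, x0, y0, x1, y1)
    return canvas


def parse_ndjson_line(line: str, size: int = 256) -> np.ndarray:
    """Render one .ndjson record to an image."""
    return render_drawing(json.loads(line)["drawing"], size)


def iter_ndjson(path: str, size: int = 256) -> Iterator[np.ndarray]:
    """Yield one rendered image per non-empty line of a QuickDraw .ndjson file."""
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                yield parse_ndjson_line(line, size)
```

Images from a file such as a hypothetical `quickdraw/fish.ndjson` could then be saved into `dataset/fish/` (and other categories into `dataset/not_fish/`) before training.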
Consistent Preprocessing: The same preprocessing pipeline is used for both training and inference, including in the test script.
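One way to keep preprocessing consistent is a single function imported by both training and testing code. A NumPy-only sketch, assuming an RGB input already composited on white and targeting the 224x224, 3-channel grayscale input the model expects. Names are hypothetical; the grayscale weights are the ITU-R 601-2 luma weights Pillow uses for `convert("L")`, while the nearest-neighbour resize is a simplification (torchvision's `Resize` defaults to bilinear). Any normalization the real pipeline applies (for example ImageNet mean/std) would need to live in the same shared function.

```python
import numpy as np

IMAGE_SIZE = 224  # the model expects 224x224 inputs


def to_grayscale(rgb: np.ndarray) -> np.ndarray:
    """ITU-R 601-2 luma: L = 0.299 R + 0.587 G + 0.114 B."""
    weights = np.array([0.299, 0.587, 0.114], dtype=np.float32)
    return rgb[..., :3].astype(np.float32) @ weights


def resize_nearest(img: np.ndarray, size: int) -> np.ndarray:
    """Nearest-neighbour resize of a 2-D array to size x size."""
    h, w = img.shape[:2]
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    return img[rows[:, None], cols[None, :]]


def preprocess(rgb: np.ndarray) -> np.ndarray:
    """H x W x 3 uint8 RGB -> 3 x 224 x 224 float32 in [0, 1], grayscale copied to each channel."""
    gray = to_grayscale(rgb) / 255.0
    resized = resize_nearest(gray, IMAGE_SIZE)
    return np.stack([resized] * 3, axis=0).astype(np.float32)
```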
Class Imbalance: Weighted sampling and loss are used to address class imbalance between fish and not-fish.
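A sketch of one common weighting scheme, inverse class frequency `n_samples / (n_classes * count_c)` (the same "balanced" heuristic scikit-learn's `compute_class_weight` uses). This is an assumption; the script's exact weights may differ. The function name is hypothetical.

```python
from collections import Counter
from typing import Dict, List, Tuple


def balanced_weights(labels: List[int]) -> Tuple[Dict[int, float], List[float]]:
    """Return per-class weights n / (k * count_c) and one weight per training sample."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    class_weights = {c: n_samples / (n_classes * count) for c, count in sorted(counts.items())}
    sample_weights = [class_weights[y] for y in labels]  # rarer class -> larger weight
    return class_weights, sample_weights
```

With 20 fish and 80 not-fish images, fish get weight 2.5 and not-fish 0.625, so each class contributes the same total weight (50). In PyTorch, `sample_weights` could be passed to `torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)`, and the class weights, as a tensor ordered by label, to `torch.nn.CrossEntropyLoss(weight=...)`.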
- The model expects 224x224 grayscale images, replicated across 3 channels to match ResNet18's RGB input.
- All code is designed for clarity and reproducibility.
- For best results, inspect the saved transformed images to verify preprocessing.
Feel free to modify the scripts for your own dataset or use case!