
🐱 Cat Breed Identification

(Image: Drama cat)

A deep learning project to identify cat breeds from images using various CNN and Transformer architectures. Includes both training scripts and a web interface for real-time breed identification.

πŸ“ Project Structure

Cat-Breed-Identification/
├── backend/                       # Backend API and web interface
│   └── core_service/
│       ├── main.py                # FastAPI application entry point
│       ├── controllers/           # API endpoints
│       │   └── identify.py        # Cat breed identification endpoint
│       ├── services/              # Business logic
│       │   └── model_service.py   # Model loading and inference
│       └── static/                # Frontend web interface
│           └── index.html         # Web UI for cat breed identification
├── data/
│   ├── raw/                       # Original dataset from Kaggle
│   │   └── cat_breeds_dataset/
│   │       └── images/            # All cat breed images
│   └── processed/                 # Preprocessed data splits
│       ├── train/                 # Training set (70%)
│       ├── val/                   # Validation set (15%)
│       └── test/                  # Test set (15%)
├── docs/                          # Documentation
├── src/                           # Training scripts
│   ├── train_resnet50_cats.py
│   ├── train_resnet50_cats_hyper_param_optmized.py
│   ├── train_mobilenetv3_cats.py
│   ├── train_mobilenetv3_cats_hyper_param_optmized.py
│   └── train_vit_cats.py          # Vision Transformer training
├── scripts/                       # Utility and data processing scripts
│   ├── download_data.py           # Download dataset from Kaggle
│   ├── analyze_dataset.py         # Analyze dataset statistics
│   ├── detect_and_clean_dataset.py    # Detect and clean dataset
│   ├── clean_dataset_by_threshold.py  # Remove breeds below threshold
│   ├── remove_suspect_images.py   # Remove suspect images
│   └── finalize_dataset.py        # Finalize dataset preparation
├── models/                        # Trained model weights (.pth files)
│   ├── mobilenetv3_large_catbreeds_*.pth
│   ├── vit_cats_best.pth
│   └── *_training_metrics.png     # Training visualization plots
├── requirements.txt
└── README.md

🚀 How to Run

1. Setup Environment

Linux/Mac:

python3 -m venv venv
source venv/bin/activate

Windows:

python -m venv venv
venv\Scripts\activate

⚡ Optional: Install PyTorch with CUDA (Recommended for Training)

If you plan to train or fine-tune models, you should install a CUDA-enabled version of PyTorch.
The default torch in requirements.txt installs the CPU-only build, which is very slow for training.

Install CUDA-enabled PyTorch before installing the rest of the dependencies.

To install PyTorch with GPU support, choose the command that matches your CUDA version from the official site:

➡ https://pytorch.org/get-started/locally/

For example, for CUDA 13.0:

pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu130
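After installing, a quick sanity check confirms that PyTorch can actually see the GPU (a minimal sketch; the printed device name will vary with your hardware):

import torch

print(torch.__version__)                  # installed PyTorch version
print(torch.cuda.is_available())          # True if a CUDA device is usable
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # name of GPU 0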

2. Install Dependencies

pip install -r requirements.txt

3. Setup Kaggle API Credentials

Before downloading the dataset, you need to set up Kaggle API credentials:

  1. Go to https://www.kaggle.com/settings and create an API token

  2. Download the kaggle.json file

  3. Place it in:

    • Windows: C:\Users\<YourUsername>\.kaggle\kaggle.json
    • Linux/Mac: ~/.kaggle/kaggle.json

    Or set environment variables:

    • KAGGLE_USERNAME=your_username
    • KAGGLE_KEY=your_api_key

For more details: https://github.com/Kaggle/kaggle-api#api-credentials
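For reference, kaggle.json is a two-field JSON file; the values below are placeholders:

{
  "username": "your_username",
  "key": "your_api_key"
}

On Linux/Mac, the environment-variable alternative can be set in the shell before running the download script:

export KAGGLE_USERNAME=your_username
export KAGGLE_KEY=your_api_key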

4. Download Dataset

You have two options for obtaining the dataset:

Option A: Download Pre-cleaned Dataset (Recommended)

The cleaned and preprocessed dataset used in this project is available for direct download:

📦 Download Cleaned Dataset (data.zip)

This dataset has been:

  • Cleaned and filtered for quality
  • Processed to remove suspect images
  • Ready to use for training

After downloading, extract the zip file and place the contents in data/raw/cat_breeds_dataset/images/

Option B: Download from Kaggle (Original Source)

Run the download script to automatically download and organize the original dataset:

python scripts/download_data.py

This will:

  • Download the dataset from Kaggle (ma7555/cat-breeds-dataset)
  • Extract and organize it to data/raw/cat_breeds_dataset/images/
  • Verify the download was successful

Note: If using the original Kaggle dataset, you may want to run the cleaning scripts (scripts/detect_and_clean_dataset.py, scripts/remove_suspect_images.py, etc.) to match the cleaned dataset quality.
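If you would rather fetch the archive by hand, the equivalent Kaggle CLI command is below; you may still need to move the extracted folder so the images land under data/raw/cat_breeds_dataset/images/:

kaggle datasets download -d ma7555/cat-breeds-dataset -p data/raw --unzip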

5. Analyze Dataset (Optional)

Analyze the dataset to see statistics and distribution:

python scripts/analyze_dataset.py

This will show:

  • Total images and breeds
  • Images per breed
  • Distribution statistics
  • Top/bottom breeds by image count
  • Potential issues and recommendations
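To give a sense of what the script computes, the per-breed counts can be reproduced in a few lines (an illustrative sketch, assuming one subfolder per breed under data/raw/cat_breeds_dataset/images/ and .jpg files):

from collections import Counter
from pathlib import Path

images_dir = Path("data/raw/cat_breeds_dataset/images")

# Count images per breed, assuming a <breed>/<image>.jpg folder layout
counts = Counter(p.parent.name for p in images_dir.rglob("*.jpg"))

print(f"Total breeds: {len(counts)}")
print(f"Total images: {sum(counts.values())}")
for breed, n in counts.most_common(5):  # top breeds by image count
    print(f"{breed}: {n}")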

6. Train Models

Run any training script from the project root:

# ResNet50 models
python src/train_resnet50_cats.py
python src/train_resnet50_cats_hyper_param_optmized.py

# MobileNetV3 models
python src/train_mobilenetv3_cats.py
python src/train_mobilenetv3_cats_hyper_param_optmized.py

# Vision Transformer (ViT)
python src/train_vit_cats.py

Trained models are saved to the models/ directory. Training scripts automatically generate visualization plots showing loss, accuracy, and learning rate schedules.
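Each script configures its own architecture and hyperparameters, but the transfer-learning core looks roughly like this (an illustrative torchvision sketch, not the exact project code; num_breeds is a placeholder):

import torch
import torch.nn as nn
from torchvision import models

num_breeds = 40  # placeholder; the real count comes from the dataset

# Start from ImageNet weights and swap in a new classification head
model = models.mobilenet_v3_large(weights=models.MobileNet_V3_Large_Weights.DEFAULT)
model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_breeds)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)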

7. Run the Web Interface (Frontend/Backend)

The project includes a web interface for real-time cat breed identification. To run it:

Option 1: From the backend directory (Recommended)

# Navigate to backend directory
cd backend/core_service

# Run the FastAPI server
uvicorn main:app --reload --host 0.0.0.0 --port 8000

Option 2: From the project root

# Run from project root (adjust path as needed)
python -m uvicorn backend.core_service.main:app --reload --host 0.0.0.0 --port 8000

Option 3: Using Python module directly

cd backend/core_service
python -m uvicorn main:app --reload --host 0.0.0.0 --port 8000

Once the server starts, open http://localhost:8000 in your browser to use the web interface; FastAPI also serves interactive API docs at http://localhost:8000/docs by default.

Web Interface Features

The web interface allows you to:

  • Upload a cat image (drag & drop or click to browse)
  • Get real-time breed identification predictions
  • See top-k predictions with confidence scores
  • Enjoy a responsive UI with a cat-themed design

(Screenshot of the web interface)

API Endpoints

Identify Cat Breed:

POST /identify/{model_name}

Parameters:

  • model_name: Currently supports mobilenetv3_large_catbreeds
  • image: Image file (multipart/form-data)
  • topk (query parameter): Number of top predictions (default: 3, max: 10)

Example using curl:

curl -X POST "http://localhost:8000/identify/mobilenetv3_large_catbreeds?topk=5" \
  -H "accept: application/json" \
  -H "Content-Type: multipart/form-data" \
  -F "image=@path/to/cat_image.jpg"

📊 Models

The project includes training scripts for:

  • ResNet50 - Classic deep CNN architecture with transfer learning
  • MobileNetV3 - Lightweight model optimized for mobile/edge devices
  • Vision Transformer (ViT) - Transformer-based architecture trained from scratch
  • Hyperparameter Optimization - Automated hyperparameter tuning scripts for ResNet50 and MobileNetV3

πŸ“ Notes

  • Training scripts expect data in data/raw/cat_breeds_dataset/images/
  • Models are saved to the models/ directory
  • The scripts automatically split data into train/val/test (70/15/15) if using raw data
  • For pre-split data, use data/processed/ (requires script modification)
  • The web interface requires at least one trained model in the models/ directory

AI Tools Citation

This project utilized AI tools as learning aids and guides throughout development. All code was written by the authors, with AI assistance used for understanding concepts, debugging, and structural guidance.

ChatGPT 5.1

Plotting and Visualization:

  • The plotting functionality in the training scripts (the plot_training_metrics() function in src/train_mobilenetv3_cats.py, src/train_resnet50_cats.py, and src/train_vit_cats.py) was adapted from ChatGPT 5.1 output: we copied the structure (matplotlib subplots, loss/accuracy curves, learning rate schedules) and then modified it for our own data and metrics.

Function Implementation Guidance:

  • Model Loading: ChatGPT was consulted to understand how to properly load PyTorch models using torch.load() and model.load_state_dict() (implemented in backend/core_service/services/model_service.py, lines 90-103, 113-130, 143-173).
  • Learning Rate Scheduling: Guidance on implementing ReduceLROnPlateau scheduler for adaptive learning rate adjustment (used in training scripts, e.g., src/train_mobilenetv3_cats.py lines 588-589).
  • Data Splitting: Understanding how to use train_test_split with stratification for balanced dataset splits (implemented in training scripts, e.g., src/train_mobilenetv3_cats.py lines 193-209); all three patterns are sketched below.
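For context, the three patterns look roughly like this (an illustrative, self-contained sketch with toy data, not the project's code):

import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split

# Model loading: restore weights saved with torch.save(model.state_dict(), ...)
model = nn.Linear(4, 2)                     # stand-in for the real architecture
torch.save(model.state_dict(), "demo.pth")
model.load_state_dict(torch.load("demo.pth", map_location="cpu"))

# ReduceLROnPlateau: lower the learning rate when validation loss stops improving
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=3)
scheduler.step(0.42)  # called once per epoch with the validation loss

# Stratified 70/15/15 split: each class keeps its share in every split
paths = [f"img_{i}.jpg" for i in range(100)]
labels = [i % 4 for i in range(100)]        # 4 toy class labels
train_p, rest_p, train_y, rest_y = train_test_split(
    paths, labels, test_size=0.30, stratify=labels, random_state=42)
val_p, test_p, val_y, test_y = train_test_split(
    rest_p, rest_y, test_size=0.50, stratify=rest_y, random_state=42)
print(len(train_p), len(val_p), len(test_p))  # 70 15 15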

Backend Development:

  • ChatGPT was used to help with model loading and to explain how to run the FastAPI instance for the backend service (backend/core_service/services/model_service.py).

Hyperparameter Tuning:

  • For the hyperparameter optimization scripts (src/train_mobilenetv3_cats_hyper_param_optmized.py and src/train_resnet50_cats_hyper_param_optmized.py), ChatGPT was used to suggest sensible hyperparameter values for the training process.

Claude Sonnet 4.5

File Structuring and Pattern Recognition:

  • File finding and pattern recognition code was developed with the help of Claude Sonnet 4.5, because our code initially could not save and locate files automatically. Specific implementations include (a hypothetical sketch of the glob/regex pattern follows this list):
    • find_best_model_path() function using glob.glob() and regex pattern matching to locate model files by accuracy (in backend/core_service/services/model_service.py, lines 43-82).
    • Recursive file searching using Path.rglob() for dataset analysis (in scripts/analyze_dataset.py line 63, scripts/download_data.py lines 174-198, scripts/detect_and_clean_dataset.py line 78).
    • Model file pattern matching using glob.glob() and regex in training scripts to check for existing models (e.g., src/train_mobilenetv3_cats.py lines 412-426, src/train_resnet50_cats.py lines 406-426).
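As an illustration of the glob-plus-regex pattern (a hypothetical sketch, not the project's implementation; the accuracy-in-filename convention is assumed):

import glob
import re

def find_best_model_path(pattern="models/mobilenetv3_large_catbreeds_*.pth"):
    """Return the model file whose name encodes the highest accuracy.

    Assumes filenames end in an accuracy figure, e.g. ..._91.25.pth (hypothetical).
    """
    best_path, best_acc = None, -1.0
    for path in glob.glob(pattern):
        match = re.search(r"(\d+(?:\.\d+)?)\.pth$", path)
        if match and float(match.group(1)) > best_acc:
            best_acc, best_path = float(match.group(1)), path
    return best_path

print(find_best_model_path())  # None if no matching files exist yet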

Dataset Analysis Script:

  • For the scripts/analyze_dataset.py script, we gave Claude an outline of what we wanted the program to show (statistics, distribution analysis, top/bottom breeds). Claude produced a skeleton code structure, which we then filled in with our own code, requirements, and data.

Frontend Development:

  • AI was not used to implement the frontend HTML file (backend/core_service/static/index.html); it was used only to adjust styling and CSS to speed up development.

Gemini 3 Pro

Documentation:

  • The grammar and sentence structure of this README were fixed and improved using Gemini 3 Pro.

General Statement

Everything else in this project was done independently by us. Code was never copied blindly from any AI model; AI output served as a guide, a pointer in the right direction, and a learning tool for understanding concepts and best practices. All implementations were adapted, modified, and integrated into our project architecture according to our specific requirements.

Code was not copied from any source.


Made with ❤️
