Skip to content

This repository contains a complete implementation of a plant disease classification system using a CBAM (Convolutional Block Attention Module) augmented ResNet18 architecture. The system is designed to accurately identify various plant diseases from images, leveraging attention mechanisms to focus on the most relevant features for diagnosis.

Notifications You must be signed in to change notification settings

Jeremy-Cleland/PlantDoc

Folders and files

NameName
Last commit message
Last commit date

Latest commit

Β 

History

53 Commits
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 

Repository files navigation

🌿 PlantDoc: Plant Disease Classification with CBAM-Augmented ResNet18

Python 3.8+ PyTorch 2.1+ License: MIT Deep Learning: CBAM Code style: ruff

State-of-the-art plant disease classification using attention-enhanced deep learning

This repository contains a complete implementation of a plant disease classification system using a CBAM (Convolutional Block Attention Module) augmented ResNet18 architecture. The system is designed to accurately identify various plant diseases from images, leveraging attention mechanisms to focus on the most relevant features for diagnosis.

Table of Contents

Overview

Plant diseases cause significant crop losses worldwide. Early and accurate detection is crucial for effective management. This project implements a state-of-the-art deep learning approach that combines ResNet18 with attention mechanisms to improve classification accuracy for plant disease diagnosis.

The CBAM architecture enhances the model's ability to focus on relevant disease features by applying:

  1. Channel attention - Emphasizes important feature channels ("what" to focus on)
  2. Spatial attention - Highlights important regions in the image ("where" to focus on)

✨ Key Features

Model & Architecture

  • 🧠 CBAM-Enhanced ResNet18: Dual attention mechanisms for superior feature focus
  • πŸ”§ Customizable Attention: Configurable reduction ratios and kernel sizes
  • πŸ”„ Transfer Learning: Pre-trained weights with fine-tuning capabilities

Data & Augmentation

  • πŸ” Advanced Preprocessing: Comprehensive pipeline with Albumentations
  • πŸ”€ State-of-the-art Augmentation: RandAugment, CutMix, and MixUp strategies
  • πŸ“Š Data Validation: Automatic integrity checking and analysis

Training & Optimization

  • ⚑ Mixed Precision Training: FP16/BF16 support for faster training
  • πŸ“ˆ Adaptive Optimization: Learning rate scheduling and gradient clipping
  • πŸŽ›οΈ Hyperparameter Tuning: Integrated Optuna-based optimization
  • πŸ”„ Stochastic Weight Averaging: Enhanced generalization capabilities

Interpretability & Visualization

  • πŸ‘οΈ Attention Visualization: Interactive tools to understand model focus
  • πŸ”₯ GradCAM Integration: Class activation mapping for decision explanation
  • πŸ” SHAP Analysis: Feature importance visualization
  • πŸ“ Comprehensive Reporting: Automated HTML reports with interactive plots

Developer Experience

  • πŸ–₯️ Intuitive CLI: Command-line interface for all operations
  • βš™οΈ Configuration System: Flexible YAML-based configuration
  • πŸ“ Experiment Tracking: Automatic versioning and result logging
  • πŸ§ͺ Testing Framework: Comprehensive unit and integration tests

Project Structure

plantdoc/
β”œβ”€β”€ pyproject.toml         # Modern Python packaging config
β”œβ”€β”€ README.md              # Project documentation
β”œβ”€β”€ .gitignore             # Git ignore file
β”œβ”€β”€ cli/                   # Command-line interface
β”‚   └── main.py            # Main CLI entry point
β”œβ”€β”€ configs/               # Configuration files
β”‚   └── config.yaml        # Main configuration file
β”œβ”€β”€ core/                  # Core modules
β”‚   β”œβ”€β”€ data/              # Data processing
β”‚   β”‚   β”œβ”€β”€ datamodule.py  # PyTorch data module
β”‚   β”‚   β”œβ”€β”€ datasets.py    # Dataset implementations
β”‚   β”‚   β”œβ”€β”€ transforms.py  # Data transformations
β”‚   β”‚   └── prepare_data.py # Data preparation utilities
β”‚   β”œβ”€β”€ evaluation/        # Model evaluation
β”‚   β”‚   β”œβ”€β”€ evaluate.py    # Evaluation pipeline
β”‚   β”‚   β”œβ”€β”€ interpretability.py # GradCAM implementation
β”‚   β”‚   β”œβ”€β”€ metrics.py     # Evaluation metrics
β”‚   β”‚   └── shap_evaluation.py # SHAP analysis
β”‚   β”œβ”€β”€ models/            # Model definitions
β”‚   β”‚   β”œβ”€β”€ attention.py   # CBAM implementation
β”‚   β”‚   β”œβ”€β”€ base.py        # Base model class
β”‚   β”‚   β”œβ”€β”€ model_cbam18.py # CBAM-ResNet18 model
β”‚   β”‚   β”œβ”€β”€ registry.py    # Model registry
β”‚   β”‚   β”œβ”€β”€ backbones/     # Model backbones
β”‚   β”‚   └── heads/         # Classification heads
β”‚   β”œβ”€β”€ training/          # Training utilities
β”‚   β”‚   β”œβ”€β”€ callbacks/     # Training callbacks
β”‚   β”‚   β”œβ”€β”€ loss.py        # Loss functions
β”‚   β”‚   β”œβ”€β”€ optimizers.py  # Optimizer configurations
β”‚   β”‚   β”œβ”€β”€ schedulers.py  # LR scheduler implementations
β”‚   β”‚   └── train.py       # Training loop
β”‚   β”œβ”€β”€ tuning/            # Hyperparameter tuning
β”‚   β”‚   β”œβ”€β”€ optuna_runner.py # Optuna integration
β”‚   β”‚   └── search_space.py # Hyperparameter search space
β”‚   └── visualization/     # Visualization tools
β”‚       β”œβ”€β”€ attention_viz.py # Attention visualization
β”‚       └── visualization.py # General visualizations
β”œβ”€β”€ reports/               # Reporting utilities
β”‚   β”œβ”€β”€ generate_plots.py  # Plot generation
β”‚   β”œβ”€β”€ generate_report.py # HTML report generation
β”‚   └── templates/         # Report templates
β”œβ”€β”€ utils/                 # Utility functions
β”‚   β”œβ”€β”€ config_utils.py    # Configuration utilities
β”‚   β”œβ”€β”€ logger.py          # Logging setup
β”‚   β”œβ”€β”€ paths.py           # Path management
β”‚   └── mps_utils.py       # Apple Silicon GPU utilities
β”œβ”€β”€ scripts/               # Utility scripts
└── data/                  # Data directory (managed via config)
    └── raw/               # Example structure for raw data

πŸš€ Quick Start

Get up and running with PlantDoc in minutes:

# Install the package (with visualization extras)
pip install plantdoc[viz]

# Download a sample image (Apple Scab)
curl -L -o apple_scab.jpg "https://raw.githubusercontent.com/spMohanty/PlantVillage-Dataset/master/raw/color/Apple___Apple_scab/0a5e9323-dbad-432d-ac58-d291718345d9___FREC_Scab_3417.JPG"

# Run inference
python -m plantdoc.cli.main predict --image apple_scab.jpg --visualize --top-k 3

This will:

  1. Classify the disease in the image using the default pre-trained model.
  2. Generate a visualization showing the model's attention map.
  3. Display the top 3 predictions with confidence scores.

Example Output Visualization
Example classification output with probabilities

πŸ› οΈ Installation

Prerequisites

  • Python 3.8+ (3.8, 3.9, 3.10, 3.11 supported)
  • PyTorch 2.1+
  • CUDA-capable GPU (recommended for significant speedup) or Apple Silicon (MPS support)

Option 1: Install from PyPI (Recommended)

# Install the base package (inference only)
pip install plantdoc

# Install with visualization dependencies
pip install plantdoc[viz]

# Install with all development tools (for contribution or source modification)
pip install plantdoc[dev]

Option 2: Install from Source

For development or modification:

# Clone the repository
git clone https://github.com/yourusername/plantdoc.git # Replace with your repo URL
cd plantdoc

# Create and activate a virtual environment (recommended)
python -m venv venv
source venv/bin/activate  # On Windows use `venv\\Scripts\\activate`

# Install in editable mode with development dependencies
pip install -e ".[dev]"

Verify Installation

# Check CLI is accessible and show version
python -m plantdoc.cli.main --version

# List available models provided by the package
python -m plantdoc.cli.main models --list

Usage

PlantDoc provides a comprehensive command-line interface (CLI) for managing the entire workflow. All commands follow the pattern python -m plantdoc.cli.main <command> [options].

Use python -m plantdoc.cli.main --help or python -m plantdoc.cli.main <command> --help for detailed help on commands and options.

Data Preparation

Prepare your dataset (validate, analyze, split) before training:

# Prepare data using configuration specified in config.yaml
python -m plantdoc.cli.main prepare --config configs/config.yaml

Key Options:

  • --raw-dir: Specify the directory containing raw images.
  • --output-dir: Specify where processed data and splits should be saved.
  • --dry-run: Perform checks without modifying files.

Training

Train a new model or fine-tune an existing one:

# Train using the default configuration
python -m plantdoc.cli.main train --config configs/config.yaml

Key Options:

  • --config: Path to the main configuration file.
  • --model: Override the model architecture (e.g., --model resnet18).
  • --epochs: Override the number of training epochs.
  • --batch-size: Override the training batch size.
  • --experiment: Specify a custom name for the experiment run.
  • --version: Specify a version number for the experiment run.
  • --resume: Resume training from the latest checkpoint in the experiment directory.

Training automatically logs metrics, saves checkpoints, generates visualizations (like attention maps), and creates a summary report.

Evaluation

Evaluate a trained model on a dataset split (e.g., test set):

# Evaluate the best checkpoint from a specific experiment
python -m plantdoc.cli.main eval --config configs/config.yaml --experiment <your_experiment_name> --version <your_version> --checkpoint best

Key Options:

  • --checkpoint: Path to a specific checkpoint file or best/last.
  • --split: Specify the data split to evaluate on (val or test).
  • --output-dir: Directory to save evaluation results (metrics, plots).
  • --interpret: Generate additional interpretability plots (e.g., GradCAM, SHAP if configured).

Attention Visualization

Generate attention map visualizations for a specific image using a trained model:

# Visualize attention for an image using the best checkpoint from an experiment
python -m plantdoc.cli.main attention --image path/to/your/image.jpg --experiment <your_experiment_name> --checkpoint best

Key Options:

  • --model: Specify the model architecture if not using an experiment checkpoint.
  • --layers: Specify specific layers for visualization.
  • --output-dir: Directory to save the visualization output.

Generating Reports

Generate comprehensive HTML reports summarizing training or evaluation runs:

# Generate a report for a specific experiment run
python -m plantdoc.cli.main report --experiment <your_experiment_name> --version <your_version>

Key Options:

  • --output-dir: Specify where the report should be saved.
  • --template: Use a custom report template.

Hyperparameter Tuning

Perform automated hyperparameter optimization using Optuna:

# Run hyperparameter tuning based on the search space defined in the config
python -m plantdoc.cli.main tune --config configs/config.yaml --trials 100

Key Options:

  • --trials: Number of optimization trials to run.
  • --storage: Database URL for storing Optuna study results.
  • --study-name: Custom name for the Optuna study.

Model Registry

Explore available models and their configurations:

# List all registered models
python -m plantdoc.cli.main models --list

# Get detailed information about a specific model
python -m plantdoc.cli.main models --model cbam_only_resnet18

# Get parameter schema in JSON or YAML format
python -m plantdoc.cli.main models --model cbam_only_resnet18 --format json

Configuration

The project relies heavily on a YAML-based configuration system managed via configs/config.yaml. This central file controls all aspects of the workflow:

  • Data: Paths, class names, image size, normalization stats, train/val/test splits.
  • Model: Architecture choice, backbone specifics, attention module parameters (reduction ratio, kernel size), pre-trained weights, regularization (dropout, stochastic depth).
  • Augmentation: Albumentations pipeline definition, RandAugment, CutMix, MixUp parameters.
  • Training: Number of epochs, batch size, learning rate, weight decay, gradient clipping, mixed precision (FP16/BF16).
  • Optimization: Optimizer type (AdamW, SGD, etc.), learning rate scheduler (Cosine Annealing, ReduceLROnPlateau, etc.), loss function (CrossEntropy, Label Smoothing).
  • Callbacks: Configuration for early stopping, model checkpointing (saving best/last weights), logging, visualization generation during training.
  • Evaluation: Metrics to compute (Accuracy, F1, Precision, Recall), confusion matrix settings, interpretability options (GradCAM layers, SHAP samples).
  • Hardware: Device selection (cuda, mps, cpu), number of workers for data loading.
  • Tuning: Hyperparameter search space definition for Optuna.

You can override any configuration setting via the command line using dot notation, e.g., python -m plantdoc.cli.main train --training.epochs 50 --training.optimizer.lr 0.0005.

🧠 Model Architecture

PlantDoc utilizes a ResNet18 backbone enhanced with the Convolutional Block Attention Module (CBAM) for improved feature representation and focus on relevant image regions.

CBAM: Dual Attention Mechanism

CBAM sequentially infers attention maps along two separate dimensions: channel and spatial. The overall attention process can be summarized as:

F'  = M_c(F) βŠ— F        (Channel Attention)
F'' = M_s(F') βŠ— F'      (Spatial Attention)

where ( F ) is the input feature map, ( M_c ) is the channel attention map, ( M_s ) is the spatial attention map, and ( \otimes ) denotes element-wise multiplication.

CBAM Architecture Detailed
CBAM Architecture: Channel Attention Module (top) and Spatial Attention Module (bottom) applied sequentially.

1. Channel Attention Module

Focuses on "what" is meaningful in the input image channels.

Click to view ChannelAttention implementation
import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    def __init__(self, channels, reduction_ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared MLP
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, channels // reduction_ratio, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction_ratio, channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.mlp(self.avg_pool(x))
        max_out = self.mlp(self.max_pool(x))
        attention = self.sigmoid(avg_out + max_out)
        return x * attention.expand_as(x) # Apply attention
  • Process: Aggregates spatial information using average and max pooling, processes through a shared Multi-Layer Perceptron (MLP), combines the outputs, and applies sigmoid activation to generate channel weights.

2. Spatial Attention Module

Focuses on "where" the informative parts are located in the feature map.

Click to view SpatialAttention implementation
import torch
import torch.nn as nn

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = kernel_size // 2
        # Convolution layer to process pooled features
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Pool across channels
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        # Concatenate pooled features
        pooled = torch.cat([avg_out, max_out], dim=1)
        # Generate spatial attention map
        attention = self.sigmoid(self.conv(pooled))
        return x * attention.expand_as(x) # Apply attention
  • Process: Aggregates channel information using average and max pooling along the channel axis, concatenates them, applies a convolution layer to generate a 2D spatial attention map, and uses sigmoid activation.

3. Integration with ResNet

  • CBAM blocks are typically inserted after each residual block in the ResNet architecture.
  • The reduction_ratio for the channel attention MLP and the kernel_size for the spatial attention convolution are configurable parameters.

Visualization of Attention Maps

Understanding where the model focuses is crucial for trust and debugging. PlantDoc integrates tools to visualize these attention maps.

Attention Map Visualization Example
Example visualization of spatial attention maps overlaid on input images, highlighting regions important for classification (e.g., lesions on a potato leaf).

This dual attention mechanism allows the model to dynamically emphasize salient features in both channel and spatial dimensions, leading to improved performance, especially on images with complex backgrounds or subtle disease symptoms.

πŸ“Š Performance Benchmarks

The CBAM-augmented ResNet18 model demonstrates significant improvements over standard baselines across our plant disease datasets.

Accuracy Comparison

Training History Plot
Training History: Accuracy and Loss Curves for CBAM-ResNet18 v1

Model Top-1 Accuracy F1 Score Precision Recall Training Time
CBAM-ResNet18 v1 97.46% 99.16% 99.21% 99.17% 9h 46m 44s
CBAM-ResNet18 v2 96.71% 99.17% 99.19% 99.16% 3h 14m 43s

Note: CBAM-ResNet18 v2 offers comparable performance with significantly reduced training time (3x faster)

Performance on Challenging Cases

The CBAM attention mechanism significantly improves model performance on difficult cases:

Challenge Category Improvement with CBAM
Early-stage diseases +8.3%
Visually similar diseases +10.8%
Variable lighting conditions +7.9%
Small lesions / symptoms +12.5%

Confusion Matrix
Confusion Matrix showing strong performance across all 39 plant disease classes

Robustness Analysis

  • Data Efficiency: Achieves comparable accuracy with approximately 30% less training data.
  • Generalization: Demonstrates ~18% better performance on out-of-distribution test sets.
  • Calibration: Expected Calibration Error (ECE) reduced by ~45%, indicating more reliable confidence scores.

Confidence Distribution
Model confidence distribution showing well-calibrated predictions

Troubleshooting

Common Issues

Installation Problems

Issue: ERROR: Could not find a version that satisfies the requirement torch>=2.1.0 Solution: PyTorch needs to be installed manually first, as its distribution depends on your OS and CUDA version. Visit pytorch.org and follow the instructions specific to your system. Example (Linux/Windows, CUDA 11.8):

pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118

Example (MacOS, Apple Silicon):

pip install torch torchvision torchaudio

After installing PyTorch, retry pip install plantdoc.

Issue: ModuleNotFoundError: No module named 'plantdoc.cli' (when running from source) Solution: Ensure you have installed the package in editable mode (pip install -e ".[dev]") from the root directory of the cloned repository, and that your virtual environment is activated.

CUDA/MPS Issues

Issue: RuntimeError: CUDA error: device-side assert triggered or similar CUDA errors. Solution:

  1. Verify your NVIDIA driver and CUDA toolkit compatibility with the installed PyTorch version.
  2. Try reducing the training.batch_size in your configuration file or via CLI override (--training.batch_size <smaller_number>).
  3. Check for potential issues in the data loading or augmentation pipeline that might produce invalid inputs. Use prepare command's validation features.

Issue: Poor performance or errors on Apple Silicon (MPS). Solution:

  1. Ensure you have the latest macOS and PyTorch versions. MPS support is rapidly evolving.
  2. Some operations might not be fully supported on MPS yet. Check PyTorch documentation for compatibility. The code includes utilities (utils/mps_utils.py) but might require adjustments based on the specific error.
  3. Try running with CPU (--hardware.device cpu) to isolate the issue.

Memory Errors

Issue: OutOfMemoryError: CUDA out of memory or system running out of RAM. Solution:

  1. Reduce Batch Size: Lower training.batch_size or evaluation.batch_size.
  2. Reduce Image Size: Lower data.image_size if feasible for your task.
  3. Enable Mixed Precision: Use training.precision: 16-mixed (requires compatible GPU).
  4. Reduce Dataloader Workers: Lower hardware.num_workers.
  5. Use Gradient Accumulation: Modify training script to accumulate gradients over multiple smaller batches (requires code change, not currently implemented via config).

About

This repository contains a complete implementation of a plant disease classification system using a CBAM (Convolutional Block Attention Module) augmented ResNet18 architecture. The system is designed to accurately identify various plant diseases from images, leveraging attention mechanisms to focus on the most relevant features for diagnosis.

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published