Skip to content

Jie-Huangi/CornProject

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

6 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

玉米粒分类项目 (CornProject)

基于深度学习的玉米粒完整性分类系统,使用改进的 SqueezeNet + Ghost Module 网络架构,用于识别玉米粒是否完整(intact)或破损(broken)。

项目简介

本项目实现了一个轻量级的图像分类模型,专门用于玉米粒的质量检测。模型基于 SqueezeNet 架构并结合 Ghost Module 进行优化,在保持高准确率的同时减少计算资源消耗。

功能特性

  • 🌽 二分类任务:识别完整(intact)和破损(broken)的玉米粒
  • 🚀 轻量级模型:基于 SqueezeNet + Ghost Module 优化
  • 📊 完整的训练流程:支持训练、验证、测试全流程
  • 📈 详细的评估指标:准确率、精确率、召回率、F1-score、混淆矩阵
  • 🔮 灵活的预测功能:支持单张图片和批量预测
  • 📝 日志记录:自动保存训练日志和模型检查点

项目结构

CornProject/
├── config. py              # 配置文件
├── train.py              # 训练脚本
├── evaluate.py           # 评估脚本
├── predict.py            # 预测脚本
├── requirements.txt      # 依赖包列表
├── models/              
│   └── squeezenet_ghost.py  # 模型架构
├── utils/
│   ├── augmentation. py   # 数据增强
│   ├── data_loader.py    # 数据加载器
│   └── metrics.py        # 评估指标
├── data/                 # 数据目录
│   ├── train/           # 训练集
│   ├── val/             # 验证集
│   └── test/            # 测试集
├── checkpoints/         # 模型保存目录
└── logs/                # 日志保存目录

环境配置

系统要求

  • Python 3.8+
  • CUDA 11.0+ (如果使用 GPU)

安装依赖

# 克隆项目
git clone https://github.com/Jie-Huangi/CornProject.git
cd CornProject

# 安装依赖包
pip install -r requirements.txt

依赖包清单

  • torch >= 2.0.0
  • torchvision >= 0. 15.0
  • numpy >= 1.24.0
  • opencv-python >= 4.7. 0
  • Pillow >= 9.5.0
  • scikit-learn >= 1.2.0
  • matplotlib >= 3.7.0
  • tqdm >= 4.65.0

数据准备

数据集结构

按以下结构组织您的数据集:

data/
├── train/
│   ├── intact/          # 完整玉米粒图片
│   └── broken/          # 破损玉米粒图片
├── val/
│   ├── intact/
│   └── broken/
└── test/
    ├── intact/
    └── broken/

数据格式

  • 支持的图片格式:. jpg, .png
  • 图片会自动调整为 224x224 尺寸
  • 训练时会自动应用数据增强

使用方法

1. 配置参数

编辑 config.py 文件,根据需要调整配置:

# 数据配置
DATA_CONFIG = {
    'train_path': DATA_DIR / 'train',
    'val_path': DATA_DIR / 'val',
    'test_path': DATA_DIR / 'test',
    'image_size': 224,
    'batch_size': 32,
    'num_workers': 4,
}

# 训练配置
TRAIN_CONFIG = {
    'num_epochs': 100,
    'learning_rate': 0.001,
    'weight_decay': 1e-4,
    'dropout': 0.5,
    'num_classes': 2,
    'device': 'cuda',  # 使用 'cpu' 如果没有 GPU
}

2. 训练模型

python train.py

训练过程会:

  • 自动创建 checkpoints/logs/ 目录
  • 显示实时训练进度和指标
  • 保存最佳模型到 checkpoints/best_model.pth
  • 保存最后一次训练检查点到 checkpoints/last_checkpoint.pth
  • 生成训练历史记录 logs/training_history.json
  • 记录详细日志到 logs/training. log

训练输出示例:

Epoch 1/100 [Train]: 100%|██████████| 50/50 [00:30<00:00, loss: 0.6234, acc: 0.7123]
Epoch 1/100 [Val]: 100%|██████████| 10/10 [00:05<00:00, acc: 0.7500]
Train Loss: 0.6234, Train Acc: 0.7123
Val Acc: 0.7500, Precision: 0.7345, Recall: 0.7621, F1: 0.7481
Best model saved with accuracy: 0.7500

3. 评估模型

在测试集上评估训练好的模型:

python evaluate.py

评估会生成:

  • 测试集准确率、精确率、召回率、F1-score
  • 详细的分类报告
  • 混淆矩阵图 logs/confusion_matrix.png
  • 指标可视化图 logs/metrics. png

输出示例:

==================================================
Test Results
==================================================
Accuracy: 0.8945
Precision: 0.8876
Recall: 0.9021
F1-Score: 0.8948

Classification Report:
              precision    recall  f1-score   support
      intact       0.91      0.88      0.89       150
      broken       0.87      0.90      0.88       145
    accuracy                           0.89       295

4. 预测

单张图片预测

from predict import Predictor
import config

# 创建预测器
predictor = Predictor(
    checkpoint_path=config.CHECKPOINT_PATH,
    device='cuda'  # 或 'cpu'
)

# 预测单张图片
class_name, confidence, all_probs = predictor.predict_image('path/to/image.jpg')
print(f"预测类别: {class_name}")
print(f"置信度: {confidence:. 4f}")
print(f"所有概率: {all_probs}")

批量预测

# 预测整个文件夹中的图片
results = predictor.predict_batch('path/to/image/directory', save_results=True)
# 结果会保存到 logs/predictions.json

可视化预测结果

# 在图片上可视化预测结果
predictor.visualize_prediction(
    image_path='path/to/image.jpg',
    save_path='output. jpg'
)

命令行预测

直接运行预测脚本(需要先修改 predict.py 中的示例图片路径):

python predict.py

模型架构

本项目使用改进的 SqueezeNet 架构:

  • 基础架构:SqueezeNet 1.1
  • 改进模块:Ghost Module(减少参数量和计算量)
  • 分类器:全连接层 + Dropout + Softmax
  • 参数量:约 1.2M(相比原始网络减少约 30%)

性能优化

数据增强

训练时自动应用以下增强方法:

  • 随机水平翻转
  • 随机旋转(±15°)
  • 颜色抖动(亮度、对比度、饱和度、色调)
  • 随机调整尺寸和裁剪
  • 归一化(ImageNet 统计值)

学习率调度

使用 ReduceLROnPlateau 策略:

  • 验证准确率 10 个 epoch 不提升时,学习率减半
  • 帮助模型跳出局部最优

早停机制

自动保存验证集上表现最好的模型

常见问题

1. CUDA Out of Memory

降低 batch_sizeconfig.py 中:

DATA_CONFIG = {
    'batch_size': 16,  # 从 32 降低到 16
}

2. 训练速度慢

  • 确保使用 GPU:device='cuda'
  • 增加 num_workers 用于数据加载
  • 减少 image_size 如果精度要求不高

3. 模型过拟合

  • 增加 dropout 值(如 0.6-0.7)
  • 增加 weight_decay(如 1e-3)
  • 收集更多训练数据
  • 增强数据增强策略

4. 找不到数据集

确保数据集路径正确,或在 config.py 中修改:

DATA_DIR = Path("your/custom/data/path")

输出文件说明

checkpoints/

  • best_model.pth: 验证集上表现最好的模型
  • last_checkpoint.pth: 最后一个 epoch 的检查点

logs/

  • training. log: 训练过程日志
  • training_history. json: 训练历史数据(loss、acc 等)
  • confusion_matrix.png: 混淆矩阵可视化
  • metrics.png: 评估指标可视化
  • predictions.json: 批量预测结果

贡献

欢迎提交 Issue 和 Pull Request!

许可证

本项目仅供学习和研究使用。

联系方式

如有问题,请在 GitHub 上提交 Issue。


祝使用愉快!🌽

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages