基于深度学习的玉米粒完整性分类系统,使用改进的 SqueezeNet + Ghost Module 网络架构,用于识别玉米粒是否完整(intact)或破损(broken)。
本项目实现了一个轻量级的图像分类模型,专门用于玉米粒的质量检测。模型基于 SqueezeNet 架构并结合 Ghost Module 进行优化,在保持高准确率的同时减少计算资源消耗。
- 🌽 二分类任务:识别完整(intact)和破损(broken)的玉米粒
- 🚀 轻量级模型:基于 SqueezeNet + Ghost Module 优化
- 📊 完整的训练流程:支持训练、验证、测试全流程
- 📈 详细的评估指标:准确率、精确率、召回率、F1-score、混淆矩阵
- 🔮 灵活的预测功能:支持单张图片和批量预测
- 📝 日志记录:自动保存训练日志和模型检查点
CornProject/
├── config. py # 配置文件
├── train.py # 训练脚本
├── evaluate.py # 评估脚本
├── predict.py # 预测脚本
├── requirements.txt # 依赖包列表
├── models/
│ └── squeezenet_ghost.py # 模型架构
├── utils/
│ ├── augmentation. py # 数据增强
│ ├── data_loader.py # 数据加载器
│ └── metrics.py # 评估指标
├── data/ # 数据目录
│ ├── train/ # 训练集
│ ├── val/ # 验证集
│ └── test/ # 测试集
├── checkpoints/ # 模型保存目录
└── logs/ # 日志保存目录
- Python 3.8+
- CUDA 11.0+ (如果使用 GPU)
# 克隆项目
git clone https://github.com/Jie-Huangi/CornProject.git
cd CornProject
# 安装依赖包
pip install -r requirements.txt- torch >= 2.0.0
- torchvision >= 0. 15.0
- numpy >= 1.24.0
- opencv-python >= 4.7. 0
- Pillow >= 9.5.0
- scikit-learn >= 1.2.0
- matplotlib >= 3.7.0
- tqdm >= 4.65.0
按以下结构组织您的数据集:
data/
├── train/
│ ├── intact/ # 完整玉米粒图片
│ └── broken/ # 破损玉米粒图片
├── val/
│ ├── intact/
│ └── broken/
└── test/
├── intact/
└── broken/
- 支持的图片格式:
. jpg,.png - 图片会自动调整为 224x224 尺寸
- 训练时会自动应用数据增强
编辑 config.py 文件,根据需要调整配置:
# 数据配置
DATA_CONFIG = {
'train_path': DATA_DIR / 'train',
'val_path': DATA_DIR / 'val',
'test_path': DATA_DIR / 'test',
'image_size': 224,
'batch_size': 32,
'num_workers': 4,
}
# 训练配置
TRAIN_CONFIG = {
'num_epochs': 100,
'learning_rate': 0.001,
'weight_decay': 1e-4,
'dropout': 0.5,
'num_classes': 2,
'device': 'cuda', # 使用 'cpu' 如果没有 GPU
}python train.py训练过程会:
- 自动创建
checkpoints/和logs/目录 - 显示实时训练进度和指标
- 保存最佳模型到
checkpoints/best_model.pth - 保存最后一次训练检查点到
checkpoints/last_checkpoint.pth - 生成训练历史记录
logs/training_history.json - 记录详细日志到
logs/training. log
训练输出示例:
Epoch 1/100 [Train]: 100%|██████████| 50/50 [00:30<00:00, loss: 0.6234, acc: 0.7123]
Epoch 1/100 [Val]: 100%|██████████| 10/10 [00:05<00:00, acc: 0.7500]
Train Loss: 0.6234, Train Acc: 0.7123
Val Acc: 0.7500, Precision: 0.7345, Recall: 0.7621, F1: 0.7481
Best model saved with accuracy: 0.7500
在测试集上评估训练好的模型:
python evaluate.py评估会生成:
- 测试集准确率、精确率、召回率、F1-score
- 详细的分类报告
- 混淆矩阵图
logs/confusion_matrix.png - 指标可视化图
logs/metrics. png
输出示例:
==================================================
Test Results
==================================================
Accuracy: 0.8945
Precision: 0.8876
Recall: 0.9021
F1-Score: 0.8948
Classification Report:
precision recall f1-score support
intact 0.91 0.88 0.89 150
broken 0.87 0.90 0.88 145
accuracy 0.89 295
from predict import Predictor
import config
# 创建预测器
predictor = Predictor(
checkpoint_path=config.CHECKPOINT_PATH,
device='cuda' # 或 'cpu'
)
# 预测单张图片
class_name, confidence, all_probs = predictor.predict_image('path/to/image.jpg')
print(f"预测类别: {class_name}")
print(f"置信度: {confidence:. 4f}")
print(f"所有概率: {all_probs}")# 预测整个文件夹中的图片
results = predictor.predict_batch('path/to/image/directory', save_results=True)
# 结果会保存到 logs/predictions.json# 在图片上可视化预测结果
predictor.visualize_prediction(
image_path='path/to/image.jpg',
save_path='output. jpg'
)直接运行预测脚本(需要先修改 predict.py 中的示例图片路径):
python predict.py本项目使用改进的 SqueezeNet 架构:
- 基础架构:SqueezeNet 1.1
- 改进模块:Ghost Module(减少参数量和计算量)
- 分类器:全连接层 + Dropout + Softmax
- 参数量:约 1.2M(相比原始网络减少约 30%)
训练时自动应用以下增强方法:
- 随机水平翻转
- 随机旋转(±15°)
- 颜色抖动(亮度、对比度、饱和度、色调)
- 随机调整尺寸和裁剪
- 归一化(ImageNet 统计值)
使用 ReduceLROnPlateau 策略:
- 验证准确率 10 个 epoch 不提升时,学习率减半
- 帮助模型跳出局部最优
自动保存验证集上表现最好的模型
降低 batch_size 在 config.py 中:
DATA_CONFIG = {
'batch_size': 16, # 从 32 降低到 16
}- 确保使用 GPU:
device='cuda' - 增加
num_workers用于数据加载 - 减少
image_size如果精度要求不高
- 增加
dropout值(如 0.6-0.7) - 增加
weight_decay(如 1e-3) - 收集更多训练数据
- 增强数据增强策略
确保数据集路径正确,或在 config.py 中修改:
DATA_DIR = Path("your/custom/data/path")best_model.pth: 验证集上表现最好的模型last_checkpoint.pth: 最后一个 epoch 的检查点
training. log: 训练过程日志training_history. json: 训练历史数据(loss、acc 等)confusion_matrix.png: 混淆矩阵可视化metrics.png: 评估指标可视化predictions.json: 批量预测结果
欢迎提交 Issue 和 Pull Request!
本项目仅供学习和研究使用。
如有问题,请在 GitHub 上提交 Issue。
祝使用愉快!🌽