本项目是一个神经网络的项目,可以用于Mnist的手写数字识别,也可用于图像分类.
Mnist数据集来自这里
图像分类的数据集来自开源数据集
可以保存每一次训练的时间和对应的损失函数及正确率的值在logs文件夹下的txt文件中.
将网络,数据集加载,训练结束后通知,config,训练,记录数据等模块分开实现
项目主要基于:python=3.13.5 pandas=2.3.3 requests=2.32.5 torch=2.8.0+cu129
详细信息参考env.yml,文件来自conda env export > env.yml.
对于Mnist,采用全连接网络,训练30轮,batch_size=64的情况:
当使用均方误差函数,学习率设置3左右.在测试数据集上最终能达到接近95%的正确率.
当使用交叉熵损失函数时,学习率设置0.01.在测试数据集上最终能达到88.97%的正确率.
对于分类数据集,使用类似LeNet的网络,CNN+MLP的结构:
batch_size设置16时accuracy仅为10%.
batch_size设置32时accuracy可以达到95%.
git clone代码- 新建
.env文件,根据.env.example所示格式写入bark_key - 根据实际修改
train.py中各项参数,运行train.py开始训练
- 当前仅对所提供链接处下载的数据集做过测试,确认代码能够运行