Skip to content

NH-5/NNDL

Repository files navigation

概述

本项目是一个神经网络的项目,可以用于Mnist的手写数字识别,也可用于图像分类.

项目文件结构

Mnist数据集来自这里

图像分类的数据集来自开源数据集

可以保存每一次训练的时间和对应的损失函数及正确率的值在logs文件夹下的txt文件中.

将网络,数据集加载,训练结束后通知,config,训练,记录数据等模块分开实现

依赖

项目主要基于:python=3.13.5 pandas=2.3.3 requests=2.32.5 torch=2.8.0+cu129

详细信息参考env.yml,文件来自conda env export > env.yml.

模型效果

对于Mnist,采用全连接网络,训练30轮,batch_size=64的情况:

当使用均方误差函数,学习率设置3左右.在测试数据集上最终能达到接近95%的正确率.

当使用交叉熵损失函数时,学习率设置0.01.在测试数据集上最终能达到88.97%的正确率.

对于分类数据集,使用类似LeNet的网络,CNN+MLP的结构:

batch_size设置16时accuracy仅为10%.

batch_size设置32时accuracy可以达到95%.

结果复现

  • git clone代码
  • 新建.env文件,根据.env.example所示格式写入bark_key
  • 根据实际修改train.py中各项参数,运行train.py开始训练

注意事项

  • 当前仅对所提供链接处下载的数据集做过测试,确认代码能够运行

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages