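Steps for adding a model to flowvision
The overall process consists of the following steps:
1. Add the model
When adding a new model, pay attention to the following details:
- If the model borrows from or is based on someone else's implementation, declare this at the top of the file: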
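"""
Modified from xxx.py
"""
- Imports must follow this order: Python standard-library packages, then third-party packages, then modules from this repository. (Note: code modules that are reused across models, such as Drop Path and Patch Embedding, must be placed in their own separate files.)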
import math

import oneflow as flow
import oneflow.nn as nn

from .registry import ModelCreator
from .utils import load_state_dict_from_url
- A model_urls variable must be defined:
model_urls = {
    "convnext_tiny_224": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/ConvNeXt/convnext_tiny_1k_224_ema.zip",
}
- Define and register the model, and write its docstring. (Note: docstring style must be consistent, e.g. colons, capitalization, and trailing periods.)
class ConvNeXt(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        # Build the layers here.


def _create_convnext(arch, pretrained=False, progress=True, **model_kwargs):
    model = ConvNeXt(**model_kwargs)
    if pretrained:
        # Download the pretrained weights registered in model_urls.
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model


@ModelCreator.register_model
def convnext_tiny_224(pretrained=False, progress=True, **kwargs):
    """
    Constructs the ConvNeXt-Tiny model trained on ImageNet2012.

    .. note::
        ConvNeXt-Tiny model from `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_.
        The required input size of the model is 224x224.

    Args:
        pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False``
        progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True``

    For example:

    .. code-block:: python

        >>> import flowvision
        >>> convnext_tiny_224 = flowvision.models.convnext_tiny_224(pretrained=False, progress=True)

    """
    model_kwargs = dict(
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        **kwargs
    )
    return _create_convnext(
        "convnext_tiny_224", pretrained=pretrained, progress=progress, **model_kwargs
    )
2. Convert the corresponding model weights
- A simple conversion function is shown below. You can extend it as needed, but conversions generally follow this function:
from flowvision.models import ModelCreator
import oneflow as flow
import torch


def convert_torch_to_flow(model, torch_weight_path, save_path):
    # Load the PyTorch checkpoint on the CPU so that no GPU is required.
    parameters = torch.load(torch_weight_path, map_location="cpu")
    new_parameters = dict()
    for key, value in parameters.items():
        # Skip the BatchNorm "num_batches_tracked" counters.
        if "num_batches_tracked" not in key:
            val = value.detach().cpu().numpy()
            new_parameters[key] = val
    model.load_state_dict(new_parameters)
    flow.save(model.state_dict(), save_path)
    print("Successfully saved model to %s" % (save_path))


model = ModelCreator.create_model("efficientnet_b7")
torch_weight = "/home/rentianhe/code/OneFlow-Models/vision/weights/efficientnet_b7_lukemelas-dcc49843.pth"
convert_torch_to_flow(model, torch_weight, save_path="./weights/efficientnet_b7")
3. Test the converted weights and record the results
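Before running the full evaluation, a quick sanity check can catch parameter names or shapes that did not carry over during conversion. The following is a minimal sketch in plain Python, not part of flowvision: `diff_state_dicts` is a hypothetical helper, and the parameter names and shapes in the example are made up for illustration.

```python
def diff_state_dicts(source_shapes, target_shapes, skip_substrings=("num_batches_tracked",)):
    """Compare two name -> shape mappings.

    Returns three sorted lists: names the target expects but the source lacks,
    names only the source has, and names whose shapes differ.
    """
    # Drop bookkeeping entries, mirroring the "num_batches_tracked" filter
    # used in the conversion function.
    source = {
        name: tuple(shape)
        for name, shape in source_shapes.items()
        if not any(s in name for s in skip_substrings)
    }
    target = {name: tuple(shape) for name, shape in target_shapes.items()}

    missing = sorted(set(target) - set(source))
    unexpected = sorted(set(source) - set(target))
    mismatched = sorted(
        name for name in set(source) & set(target) if source[name] != target[name]
    )
    return missing, unexpected, mismatched


# Illustrative (made-up) shapes standing in for a source and a target checkpoint.
torch_shapes = {
    "conv.weight": (8, 3, 3, 3),
    "bn.weight": (8,),
    "bn.num_batches_tracked": (),
    "head.weight": (10, 8),
}
flow_shapes = {
    "conv.weight": (8, 3, 3, 3),
    "bn.weight": (8,),
    "head.weight": (1000, 8),
    "head.bias": (1000,),
}

missing, unexpected, mismatched = diff_state_dicts(torch_shapes, flow_shapes)
print("missing:", missing)
print("unexpected:", unexpected)
print("mismatched:", mismatched)
```

With real checkpoints, the two mappings could be built from the state dicts, for example as `{k: tuple(v.shape) for k, v in state_dict.items()}`. All three lists should be empty before the weights are evaluated.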
Run the evaluation with https://github.com/Oneflow-Inc/vision/blob/main/projects/benchmark/classification/eval.sh
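The contents of eval.sh are not reproduced here. Assuming the recorded metric is top-k classification accuracy on ImageNet (an assumption, not something stated in this issue), the computation can be sketched in plain Python as follows. `topk_accuracy` is a hypothetical helper, and the scores and labels are made-up toy data.

```python
def topk_accuracy(scores, labels, k=1):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not labels:
        raise ValueError("at least one sample is required")
    correct = 0
    for row, label in zip(scores, labels):
        # Indices of the k largest scores in this row.
        topk = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        if label in topk:
            correct += 1
    return correct / len(labels)


# Toy data: 4 samples, 3 classes.
scores = [
    [0.1, 0.7, 0.2],
    [0.5, 0.3, 0.2],
    [0.2, 0.3, 0.5],
    [0.6, 0.3, 0.1],
]
labels = [1, 1, 2, 2]

top1 = topk_accuracy(scores, labels, k=1)
top2 = topk_accuracy(scores, labels, k=2)
print(top1, top2)
```

Comparing the accuracy of the converted weights against the figure reported for the original PyTorch weights is a simple way to confirm that the conversion preserved the model's behavior.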
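4. Update MODEL_ZOO
Add the results to https://github.com/Oneflow-Inc/vision/blob/main/results/results_imagenet.md
5. Update the docstring
Add the model to https://github.com/Oneflow-Inc/vision/blob/main/docs/source/flowvision.models.rst
6. Update the table in the README
Update the table at https://github.com/Oneflow-Inc/vision#overview-of-flowvision-structure
7. Add a speed comparison with torch. The measured data must be listed in the PR, and the corresponding model must also be added to the benchmark script.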
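The repository's own benchmark script is not shown in this issue. As a rough sketch of how such a speed comparison can be measured, the framework-agnostic timing harness below warms up a callable and then records per-call latency. `benchmark`, `summarize`, and `dummy_forward` are hypothetical names, and `dummy_forward` is only a stand-in for a real model forward pass.

```python
import statistics
import time


def benchmark(fn, warmup=3, iters=10):
    """Call fn() `warmup` times untimed, then return `iters` latencies in milliseconds."""
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        timings.append((time.perf_counter() - start) * 1000.0)
    return timings


def summarize(name, timings):
    """Reduce raw timings to the figures typically reported in a PR."""
    return {
        "name": name,
        "median_ms": statistics.median(timings),
        "min_ms": min(timings),
        "runs": len(timings),
    }


def dummy_forward():
    # Stand-in workload; replace with a call such as model(x) for each framework.
    return sum(i * i for i in range(10000))


result = summarize("dummy_forward", benchmark(dummy_forward))
print(result)
```

When timing on a GPU, kernels are launched asynchronously, so the callable should synchronize the device before returning (in PyTorch, `torch.cuda.synchronize()`; the OneFlow equivalent should be checked against its documentation). Using the same input shape, batch size, and device for both frameworks keeps the comparison fair.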