Skip to content

pt模型转rknn模型过大 #515

@lrm2017

Description

@lrm2017

直接将 pt 模型转为 rknn 模型时,导出的模型非常大(约 300MB),而先转为 onnx 再导出的 rknn 模型只有约 2MB,两者相差很大。是代码中哪里的配置有问题吗?
`#!/usr/bin/env python3
"""
XFeat PyTorch模型直接转换为RKNN
避免GridSample问题,支持灵活配置
"""

import sys
import os
import cv2
import numpy as np
import torch
import glob
from rknn.api import RKNN

def create_xfeat_dataset(input_shape, output_path="dataset_xfeat_pt.txt", num_samples=50):
    """Build a calibration dataset list file for XFeat quantization.

    Images are searched recursively under the ``datasets`` directory. Each
    readable image is converted from BGR to RGB, resized to the requested
    width/height, scaled to ``[0, 1]`` as float32, rearranged to
    ``[1, C, H, W]`` and saved as a ``.npy`` file under
    ``temp_xfeat_pt_data``. The path of every saved array is written to
    ``output_path``, one per line.

    Args:
        input_shape: Four-element ``(batch, channels, height, width)`` shape.
        output_path: Text file that receives the list of ``.npy`` paths.
        num_samples: Maximum number of samples to write.

    Returns:
        ``output_path`` when at least one sample was written, otherwise
        ``None`` (unsupported shape, missing ``datasets`` directory, no image
        files found, or no image could be decoded).
    """
    print(f"创建XFeat数据集: {output_path}")

    # Only 4-D NCHW shapes are supported.
    if len(input_shape) != 4:
        print(f"不支持的输入形状: {input_shape}")
        return None
    _, _, height, width = input_shape

    print(f"输入格式: NCHW")
    print(f"目标尺寸: {width}x{height}")

    assets_dir = "datasets"
    if not os.path.exists(assets_dir):
        print(f"❌ datasets目录不存在: {assets_dir}")
        return None

    # Collect images from the directory and all sub-directories. A set removes
    # the duplicates that the lower/upper-case patterns yield on
    # case-insensitive filesystems; sorting makes the sample choice
    # reproducible between runs.
    found = set()
    for ext in ('.jpg', '.jpeg', '.png', '.bmp'):
        for pattern in (f"**/*{ext}", f"**/*{ext.upper()}"):
            found.update(glob.glob(os.path.join(assets_dir, pattern), recursive=True))
    image_files = sorted(found)

    if not image_files:
        print(f"❌ 在 {assets_dir} 目录中未找到图像文件")
        return None

    print(f"找到 {len(image_files)} 个图像文件")

    temp_dir = "temp_xfeat_pt_data"
    os.makedirs(temp_dir, exist_ok=True)

    written = 0
    with open(output_path, 'w', encoding='utf-8') as f:
        for image_file in image_files:
            # Stop once enough samples exist; unreadable files do not count,
            # so later images can still fill the quota.
            if written >= num_samples:
                break
            print(f"处理图像: {os.path.basename(image_file)}")

            img = cv2.imread(image_file)
            if img is None:
                continue

            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img_resized = cv2.resize(img_rgb, (width, height))

            # Scale to [0, 1], then [H, W, C] -> [C, H, W] -> [1, C, H, W].
            image = img_resized.astype(np.float32) / 255.0
            image = np.expand_dims(np.transpose(image, (2, 0, 1)), axis=0)

            image_path = os.path.join(temp_dir, f"image_{written:03d}.npy")
            np.save(image_path, image)
            f.write(f"{image_path}\n")
            written += 1

    # An empty list file cannot be used for calibration; report failure
    # instead of returning a path to it.
    if written == 0:
        print("❌ 没有成功处理任何图像,无法创建数据集")
        return None

    print(f"数据集创建完成,包含 {written} 个样本")
    return output_path

def _attempt_rknn_step(label, action, intro=None):
    """Run one load/build attempt and print its outcome.

    Args:
        label: Text inserted into the success/failure/exception messages.
        action: Zero-argument callable returning the RKNN status code, or
            ``None`` when the attempt had to be skipped.
        intro: Optional message printed before the attempt.

    Returns:
        ``True`` only when ``action`` returned ``0``.
    """
    try:
        if intro is not None:
            print(intro)
        ret = action()
    except Exception as e:
        # Deliberately broad: any failure just moves on to the next fallback.
        print(f'❌ {label}异常: {e}')
        return False
    if ret is None:
        # Attempt skipped (e.g. its dataset could not be created).
        return False
    if ret == 0:
        print(f'✅ {label}成功')
        return True
    print(f'❌ {label}失败,错误码: {ret}')
    return False


def _load_xfeat_pytorch(rknn, pt_path, input_shape):
    """Try the PyTorch load variants in order; return True on the first success."""
    n, c, h, w = input_shape
    attempts = (
        # 1) Full NCHW shape.
        ('标准格式加载', None,
         lambda: rknn.load_pytorch(model=pt_path, input_size_list=[[n, c, h, w]])),
        # 2) Three-element [H, W, C] shape.
        ('HWC格式加载', '尝试不同的输入格式...',
         lambda: rknn.load_pytorch(model=pt_path, input_size_list=[[h, w, c]])),
        # 3) No explicit input size.
        ('无尺寸限制加载', '尝试不指定输入尺寸...',
         lambda: rknn.load_pytorch(model=pt_path)),
    )
    return any(_attempt_rknn_step(label, action, intro)
               for label, intro, action in attempts)


def _build_xfeat_rknn(rknn, input_shape, dataset_path):
    """Try the build variants in order; return True on the first success."""

    def build_with_small_dataset():
        # Regenerate the calibration list with only 10 samples.
        small_dataset = create_xfeat_dataset(input_shape, num_samples=10)
        if not small_dataset:
            return None
        return rknn.build(do_quantization=True, dataset=small_dataset)

    # NOTE(review): the non-quantized fallback runs before the small-sample
    # quantized one, so a failed quantized build can end in a float model.
    # That is presumably much larger than an 8-bit one — confirm this order
    # is intended.
    attempts = (
        ('量化构建', None,
         lambda: rknn.build(do_quantization=True, dataset=dataset_path,
                            rknn_batch_size=1)),
        ('非量化构建', '尝试不使用量化构建...',
         lambda: rknn.build(do_quantization=False)),
        ('小样本构建', '尝试使用更少的样本构建...', build_with_small_dataset),
    )
    return any(_attempt_rknn_step(label, action, intro)
               for label, intro, action in attempts)


def convert_xfeat_pt_to_rknn(pt_path, input_shape, platform='rk3568', output_path=None,
                             dense=False, top_k=512,
                             optimization_level=1, num_samples=50):
    """Convert an XFeat PyTorch model file to an RKNN model.

    Steps: create the calibration dataset, configure RKNN, load the model
    (with fallbacks), build it (with fallbacks) and export the result.

    Args:
        pt_path: Path of the PyTorch model file passed to ``load_pytorch``.
        input_shape: Four-element ``(batch, channels, height, width)`` shape.
        platform: Value passed as ``target_platform`` to ``rknn.config``.
        output_path: Destination ``.rknn`` file. When ``None`` a name is
            derived from ``dense``, ``top_k`` and the input height/width and
            placed under ``rknn/``.
        dense: Only used for logging and the default file name.
        top_k: Only used for logging and the default file name.
        optimization_level: Passed through to ``rknn.config``.
        num_samples: Maximum number of calibration samples.

    Returns:
        ``True`` when the model was exported, ``False`` otherwise.
    """
    import shutil

    if output_path is None:
        # Same naming scheme as export_xfeat_pt_to_rknn (underscore-separated).
        model_name = f"xfeat_{'dense' if dense else 'sparse'}_{top_k}_{input_shape[2]}x{input_shape[3]}"
        output_path = f"rknn/{model_name}.rknn"

    print(f"转换XFeat模型: {pt_path}")
    print(f"输入形状: {input_shape}")
    print(f"目标平台: {platform}")
    print(f"输出路径: {output_path}")
    print(f"Dense模式: {dense}")
    print(f"关键点数量: {top_k}")

    dataset_path = create_xfeat_dataset(input_shape, num_samples=num_samples)
    if not dataset_path:
        return False

    rknn = None
    try:
        rknn = RKNN(verbose=True)

        # NOTE(review): create_xfeat_dataset stores arrays already scaled to
        # [0, 1], while mean/std here look chosen for 0-255 input — confirm
        # the calibration data matches the preprocessing the model expects.
        print('--> Config model')
        rknn.config(
            mean_values=[[127.5, 127.5, 127.5]],
            std_values=[[128.0, 128.0, 128.0]],
            quant_img_RGB2BGR=False,
            quantized_algorithm='normal',
            quantized_dtype='asymmetric_quantized-8',  # 8-bit quantization
            quantized_method='channel',                # per-channel quantization
            target_platform=platform,
            optimization_level=optimization_level,
            model_pruning=False,
        )
        print('done')

        print('--> Loading model')
        if not _load_xfeat_pytorch(rknn, pt_path, input_shape):
            print('❌ 所有加载方法都失败了')
            return False
        print('done')

        print('--> Building model')
        if not _build_xfeat_rknn(rknn, input_shape, dataset_path):
            print('❌ 所有构建方法都失败了')
            return False
        print('done')

        print('--> Export rknn model')
        # The default "rknn/" directory may not exist yet.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        ret = rknn.export_rknn(output_path)
        if ret != 0:
            print('Export rknn model failed!')
            return False
        print('done')

        print(f"✅ 转换成功: {output_path}")
        return True

    except Exception as e:
        print(f"❌ 转换失败: {e}")
        return False
    finally:
        # Release the RKNN object and remove the temporary calibration
        # arrays on every exit path, not only after a successful export.
        if rknn is not None:
            rknn.release()
        shutil.rmtree("temp_xfeat_pt_data", ignore_errors=True)

def export_xfeat_pt_to_rknn(xfeat_path, output_folder="rknn",
                            input_shape=(1, 3, 480, 640),
                            dynamic=False,
                            dense=False,
                            top_k=256,
                            platform='rk3568',
                            optimization_level=1,
                            num_samples=50):
    """Export an XFeat PyTorch model as an RKNN file inside ``output_folder``.

    Prints a summary of the settings, creates the output folder, derives the
    target file name from ``dense``, ``top_k`` and the input height/width,
    and delegates the conversion to ``convert_xfeat_pt_to_rknn``.

    Args:
        xfeat_path: Path of the PyTorch model file.
        output_folder: Directory that receives the ``.rknn`` file.
        input_shape: ``(batch, channels, height, width)`` shape.
        dynamic: Only echoed in the summary; not used otherwise here.
        dense: Selects the ``dense``/``sparse`` part of the file name.
        top_k: Keypoint count used in the file name.
        platform: Target platform forwarded to the converter.
        optimization_level: Forwarded to the converter.
        num_samples: Forwarded to the converter.

    Returns:
        The value returned by ``convert_xfeat_pt_to_rknn``.
    """
    separator = "=" * 50
    print(separator)
    print("XFeat PyTorch to RKNN 转换器")
    print(separator)
    print(f"模型路径: {xfeat_path}")
    print(f"输出目录: {output_folder}")
    print(f"输入形状: {input_shape}")
    print(f"动态输入: {dynamic}")
    print(f"Dense模式: {dense}")
    print(f"关键点数量: {top_k}")
    print(f"目标平台: {platform}")
    print(f"优化级别: {optimization_level}")
    print(separator)

    os.makedirs(output_folder, exist_ok=True)

    # File name pattern: xfeat_<mode>_<top_k>_<H>x<W>.rknn
    mode = 'dense' if dense else 'sparse'
    model_name = f"xfeat_{mode}_{top_k}_{input_shape[2]}x{input_shape[3]}"
    output_path = os.path.join(output_folder, f"{model_name}.rknn")

    converted = convert_xfeat_pt_to_rknn(
        pt_path=xfeat_path,
        input_shape=input_shape,
        platform=platform,
        output_path=output_path,
        dense=dense,
        top_k=top_k,
        optimization_level=optimization_level,
        num_samples=num_samples,
    )

    if not converted:
        print("\n❌ 转换失败!")
        print("可能的原因:")
        print("1. PyTorch模型包含RK3568不支持的操作")
        print("2. 模型结构过于复杂")
        print("3. 输入形状不匹配")
        print("4. 内存不足")
        return converted

    print("\n🎉 XFeat转换完成!")
    print(f"RKNN模型已保存到: {output_path}")
    print("现在可以使用转换后的RKNN模型进行推理了。")
    return converted

def main():
    """Command-line entry point.

    Reads ``sys.argv``: the first argument is the PyTorch model path, followed
    by optional flags (``--input_shape W,H``, ``--platform``, ``--dense``,
    ``--top_k``, ``--optimization``, ``--samples``). Runs the conversion and
    terminates the process with status 0 on success and 1 on any failure
    (missing model path, unknown flag, missing or invalid flag value,
    non-existent model file, failed conversion).
    """
    argv = sys.argv
    if len(argv) < 2:
        print("Usage: python3 {} xfeat_pt_model_path [options]".format(argv[0]))
        print("\nOptions:")
        print(" --input_shape WIDTH,HEIGHT 输入尺寸 (默认: 640,480)")
        print(" --platform PLATFORM 目标平台 (默认: rk3568)")
        print(" --dense 使用dense模式")
        print(" --top_k NUM 关键点数量 (默认: 256)")
        # The previously listed --layers option was never parsed (passing it
        # aborted with "unknown argument"), so it is no longer advertised.
        print(" --optimization NUM 优化级别 (默认: 1)")
        print(" --samples NUM 数据集样本数 (默认: 50)")
        print("\nExample:")
        print(" python3 xfeat_pt_to_rknn.py weights/xfeat.pt --input_shape 1024,1536 --top_k 512")
        print(" python3 xfeat_pt_to_rknn.py weights/xfeat.pt --dense --platform rk3588")
        sys.exit(1)

    pt_path = argv[1]

    # Defaults.
    input_shape = (1, 3, 480, 640)
    platform = 'rk3568'
    dense = False
    top_k = 256
    optimization_level = 1
    num_samples = 50

    value_flags = ('--input_shape', '--platform', '--top_k', '--optimization', '--samples')
    i = 2
    while i < len(argv):
        arg = argv[i]
        if arg == '--dense':
            dense = True
            i += 1
            continue
        if arg not in value_flags:
            print(f"未知参数: {arg}")
            sys.exit(1)
        if i + 1 >= len(argv):
            # A known flag at the end of the line without its value.
            print(f"参数 {arg} 缺少取值")
            sys.exit(1)

        value = argv[i + 1]
        try:
            if arg == '--input_shape':
                # Given as WIDTH,HEIGHT on the command line, stored as NCHW.
                width, height = map(int, value.split(','))
                input_shape = (1, 3, height, width)
            elif arg == '--platform':
                platform = value
            elif arg == '--top_k':
                top_k = int(value)
            elif arg == '--optimization':
                optimization_level = int(value)
            else:  # '--samples'
                num_samples = int(value)
        except ValueError:
            # Non-integer value or wrong number of comma-separated fields.
            print(f"参数 {arg} 的取值无效: {value}")
            sys.exit(1)
        i += 2

    if not os.path.exists(pt_path):
        print(f"❌ PyTorch文件不存在: {pt_path}")
        sys.exit(1)

    success = export_xfeat_pt_to_rknn(
        xfeat_path=pt_path,
        input_shape=input_shape,
        dynamic=False,
        dense=dense,
        top_k=top_k,
        platform=platform,
        optimization_level=optimization_level,
        num_samples=num_samples
    )

    # sys.exit rather than the site-provided exit(), which is not guaranteed
    # to exist in every interpreter configuration.
    sys.exit(0 if success else 1)

if __name__ == '__main__':
    if len(sys.argv) == 1:
        # No command-line arguments: run the built-in example conversion
        # with default settings.
        ok = export_xfeat_pt_to_rknn(
            xfeat_path="weights/xfeat_dummy.pt",
            output_folder="rknn",
            input_shape=(1, 3, 480, 640),  # N C H W
            dynamic=False,                 # fixed input size
            dense=False,                   # sparse features
            top_k=256,                     # number of keypoints
            platform='rk3568',
            optimization_level=1,
            num_samples=50
        )
        # Report failure through the exit status, as main() does.
        sys.exit(0 if ok else 1)
    else:
        # Arguments given: parse them in main().
        main()
`

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions