-
Notifications
You must be signed in to change notification settings - Fork 187
Description
直接将pt模型转为rknn模型,导出的模型巨大300MB,与转为onnx再导出rknn模型(2MB)相差很大,是代码哪里配置有问题么
`#!/usr/bin/env python3
"""
XFeat PyTorch模型直接转换为RKNN
避免GridSample问题,支持灵活配置
"""
import sys
import os
import cv2
import numpy as np
import torch
import glob
from rknn.api import RKNN
def create_xfeat_dataset(input_shape, output_path="dataset_xfeat_pt.txt", num_samples=50):
"""为XFeat创建数据集"""
print(f"创建XFeat数据集: {output_path}")
# 解析输入形状
if len(input_shape) == 4:
batch, channels, height, width = input_shape
is_nchw = True
else:
print(f"不支持的输入形状: {input_shape}")
return None
print(f"输入格式: NCHW")
print(f"目标尺寸: {width}x{height}")
# 查找datasets目录中的图像
assets_dir = "datasets"
if not os.path.exists(assets_dir):
print(f"❌ datasets目录不存在: {assets_dir}")
return None
# 查找图像文件(包括子目录)
image_files = []
for ext in ['.jpg', '.jpeg', '.png', '.bmp']:
image_files.extend(glob.glob(os.path.join(assets_dir, f"**/*{ext}"), recursive=True))
image_files.extend(glob.glob(os.path.join(assets_dir, f"**/*{ext.upper()}"), recursive=True))
if not image_files:
print(f"❌ 在 {assets_dir} 目录中未找到图像文件")
return None
print(f"找到 {len(image_files)} 个图像文件")
# 创建临时目录
temp_dir = "temp_xfeat_pt_data"
os.makedirs(temp_dir, exist_ok=True)
# 限制样本数量
num_samples = min(num_samples, len(image_files))
with open(output_path, 'w') as f:
for i in range(num_samples):
image_file = image_files[i % len(image_files)]
print(f"处理图像: {os.path.basename(image_file)}")
# 读取并处理图像
img = cv2.imread(image_file)
if img is None:
continue
# 转换为RGB并调整尺寸
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_resized = cv2.resize(img_rgb, (width, height))
# 归一化到[0,1]
image = img_resized.astype(np.float32) / 255.0
# 转换为NCHW格式: [H, W, C] -> [C, H, W]
image = np.transpose(image, (2, 0, 1))
image = np.expand_dims(image, axis=0) # [1, C, H, W]
# 保存为numpy数组
image_path = os.path.join(temp_dir, f"image_{i:03d}.npy")
np.save(image_path, image)
f.write(f"{image_path}\n")
print(f"数据集创建完成,包含 {num_samples} 个样本")
return output_path
def convert_xfeat_pt_to_rknn(pt_path, input_shape, platform='rk3568', output_path=None,
dense=False, top_k=512,
optimization_level=1, num_samples=50):
"""转换XFeat PyTorch模型为RKNN"""
if output_path is None:
model_name = f"xfeat_{'dense' if dense else 'sparse'}{top_k}{input_shape[2]}x{input_shape[3]}"
output_path = f"rknn/{model_name}.rknn"
print(f"转换XFeat模型: {pt_path}")
print(f"输入形状: {input_shape}")
print(f"目标平台: {platform}")
print(f"输出路径: {output_path}")
print(f"Dense模式: {dense}")
print(f"关键点数量: {top_k}")
# 创建数据集
dataset_path = create_xfeat_dataset(input_shape, num_samples=num_samples)
if not dataset_path:
return False
try:
# 创建RKNN对象
rknn = RKNN(verbose=True)
# 配置模型 - 添加完整的量化配置
print('--> Config model')
rknn.config(
mean_values=[[127.5, 127.5, 127.5]],
std_values=[[128.0, 128.0, 128.0]],
quant_img_RGB2BGR=False,
quantized_algorithm='normal',
quantized_dtype='asymmetric_quantized-8', # 8位量化,大幅减小模型
quantized_method='channel', # 通道级量化
target_platform=platform,
optimization_level=optimization_level,
model_pruning=False,
)
print('done')
# 加载PyTorch模型 - 参考官方test.py格式
print('--> Loading model')
# 尝试多种加载方法
load_success = False
# 方法1:标准格式
try:
ret = rknn.load_pytorch(
model=pt_path,
input_size_list=[[input_shape[0], input_shape[1], input_shape[2], input_shape[3]]] # 格式: [[C, H, W]]
)
if ret == 0:
load_success = True
print('✅ 标准格式加载成功')
else:
print(f'❌ 标准格式加载失败,错误码: {ret}')
except Exception as e:
print(f'❌ 标准格式加载异常: {e}')
# 方法2:尝试不同的输入格式
if not load_success:
try:
print('尝试不同的输入格式...')
ret = rknn.load_pytorch(
model=pt_path,
input_size_list=[[input_shape[2], input_shape[3], input_shape[1]]] # 格式: [[H, W, C]]
)
if ret == 0:
load_success = True
print('✅ HWC格式加载成功')
else:
print(f'❌ HWC格式加载失败,错误码: {ret}')
except Exception as e:
print(f'❌ HWC格式加载异常: {e}')
# 方法3:尝试不指定输入尺寸
if not load_success:
try:
print('尝试不指定输入尺寸...')
ret = rknn.load_pytorch(model=pt_path)
if ret == 0:
load_success = True
print('✅ 无尺寸限制加载成功')
else:
print(f'❌ 无尺寸限制加载失败,错误码: {ret}')
except Exception as e:
print(f'❌ 无尺寸限制加载异常: {e}')
if not load_success:
print('❌ 所有加载方法都失败了')
return False
print('done')
# 构建模型
print('--> Building model')
# 尝试不同的构建选项
build_success = False
# 方法1:使用量化 + 优化配置
try:
ret = rknn.build(
do_quantization=True,
dataset=dataset_path,
rknn_batch_size=1, # 减小批处理大小
)
if ret == 0:
build_success = True
print('✅ 量化构建成功')
else:
print(f'❌ 量化构建失败,错误码: {ret}')
except Exception as e:
print(f'❌ 量化构建异常: {e}')
# 方法2:不使用量化
if not build_success:
try:
print('尝试不使用量化构建...')
ret = rknn.build(do_quantization=False)
if ret == 0:
build_success = True
print('✅ 非量化构建成功')
else:
print(f'❌ 非量化构建失败,错误码: {ret}')
except Exception as e:
print(f'❌ 非量化构建异常: {e}')
# 方法3:使用更少的样本
if not build_success:
try:
print('尝试使用更少的样本构建...')
# 创建更小的数据集
small_dataset = create_xfeat_dataset(input_shape, num_samples=10)
if small_dataset:
ret = rknn.build(do_quantization=True, dataset=small_dataset)
if ret == 0:
build_success = True
print('✅ 小样本构建成功')
else:
print(f'❌ 小样本构建失败,错误码: {ret}')
except Exception as e:
print(f'❌ 小样本构建异常: {e}')
if not build_success:
print('❌ 所有构建方法都失败了')
return False
print('done')
# 导出RKNN模型
print('--> Export rknn model')
ret = rknn.export_rknn(output_path)
if ret != 0:
print('Export rknn model failed!')
return False
print('done')
# 释放资源
rknn.release()
# 清理临时文件
if os.path.exists("temp_xfeat_pt_data"):
import shutil
shutil.rmtree("temp_xfeat_pt_data")
print(f"✅ 转换成功: {output_path}")
return True
except Exception as e:
print(f"❌ 转换失败: {e}")
return False
def export_xfeat_pt_to_rknn(xfeat_path, output_folder="rknn",
input_shape=(1, 3, 480, 640),
dynamic=False,
dense=False,
top_k=256,
platform='rk3568',
optimization_level=1,
num_samples=50):
"""导出XFeat PyTorch模型为RKNN格式"""
print("=" * 50)
print("XFeat PyTorch to RKNN 转换器")
print("=" * 50)
print(f"模型路径: {xfeat_path}")
print(f"输出目录: {output_folder}")
print(f"输入形状: {input_shape}")
print(f"动态输入: {dynamic}")
print(f"Dense模式: {dense}")
print(f"关键点数量: {top_k}")
print(f"目标平台: {platform}")
print(f"优化级别: {optimization_level}")
print("=" * 50)
# 创建输出目录
os.makedirs(output_folder, exist_ok=True)
# 生成输出文件名
model_name = f"xfeat_{'dense' if dense else 'sparse'}_{top_k}_{input_shape[2]}x{input_shape[3]}"
output_path = os.path.join(output_folder, f"{model_name}.rknn")
# 转换模型
success = convert_xfeat_pt_to_rknn(
pt_path=xfeat_path,
input_shape=input_shape,
platform=platform,
output_path=output_path,
dense=dense,
top_k=top_k,
optimization_level=optimization_level,
num_samples=num_samples
)
if success:
print("\n🎉 XFeat转换完成!")
print(f"RKNN模型已保存到: {output_path}")
print("现在可以使用转换后的RKNN模型进行推理了。")
else:
print("\n❌ 转换失败!")
print("可能的原因:")
print("1. PyTorch模型包含RK3568不支持的操作")
print("2. 模型结构过于复杂")
print("3. 输入形状不匹配")
print("4. 内存不足")
return success
def main():
"""主函数 - 支持命令行参数"""
if len(sys.argv) < 2:
print("Usage: python3 {} xfeat_pt_model_path [options]".format(sys.argv[0]))
print("\nOptions:")
print(" --input_shape WIDTH,HEIGHT 输入尺寸 (默认: 640,480)")
print(" --platform PLATFORM 目标平台 (默认: rk3568)")
print(" --dense 使用dense模式")
print(" --top_k NUM 关键点数量 (默认: 256)")
print(" --layers NUM LightGlue层数 (默认: 2)")
print(" --optimization NUM 优化级别 (默认: 1)")
print(" --samples NUM 数据集样本数 (默认: 50)")
print("\nExample:")
print(" python3 xfeat_pt_to_rknn.py weights/xfeat.pt --input_shape 1024,1536 --top_k 512")
print(" python3 xfeat_pt_to_rknn.py weights/xfeat.pt --dense --platform rk3588")
exit(1)
pt_path = sys.argv[1]
# 解析命令行参数
input_shape = (1, 3, 480, 640) # 默认
platform = 'rk3568'
dense = False
top_k = 256
optimization_level = 1
num_samples = 50
i = 2
while i < len(sys.argv):
arg = sys.argv[i]
if arg == '--input_shape' and i + 1 < len(sys.argv):
width, height = map(int, sys.argv[i + 1].split(','))
input_shape = (1, 3, height, width)
i += 2
elif arg == '--platform' and i + 1 < len(sys.argv):
platform = sys.argv[i + 1]
i += 2
elif arg == '--dense':
dense = True
i += 1
elif arg == '--top_k' and i + 1 < len(sys.argv):
top_k = int(sys.argv[i + 1])
i += 2
elif arg == '--optimization' and i + 1 < len(sys.argv):
optimization_level = int(sys.argv[i + 1])
i += 2
elif arg == '--samples' and i + 1 < len(sys.argv):
num_samples = int(sys.argv[i + 1])
i += 2
else:
print(f"未知参数: {arg}")
exit(1)
if not os.path.exists(pt_path):
print(f"❌ PyTorch文件不存在: {pt_path}")
exit(1)
# 执行转换
success = export_xfeat_pt_to_rknn(
xfeat_path=pt_path,
input_shape=input_shape,
dynamic=False,
dense=dense,
top_k=top_k,
platform=platform,
optimization_level=optimization_level,
num_samples=num_samples
)
exit(0 if success else 1)
if name == 'main':
# 如果直接运行脚本,使用默认参数
if len(sys.argv) == 1:
# 示例:使用默认参数转换XFeat模型
export_xfeat_pt_to_rknn(
xfeat_path="weights/xfeat_dummy.pt",
output_folder="rknn",
input_shape=(1, 3, 480, 640), # N C H W
dynamic=False, # 固定输入
dense=False, # 使用稀疏特征,避免复杂操作
top_k=256, # 大幅减少关键点数
platform='rk3568',
optimization_level=1,
num_samples=50
)
else:
# 使用命令行参数
main()
`