|
| 1 | +# Burn原生YOLO训练器使用指南 |
| 2 | + |
| 3 | +## 概述 |
| 4 | + |
| 5 | +本指南介绍如何使用纯Rust实现的Burn训练器来训练YOLO模型,完全替代Python依赖。 |
| 6 | + |
| 7 | +## 技术栈 |
| 8 | + |
| 9 | +- **训练框架**: Burn (Rust原生深度学习框架) |
| 10 | +- **CPU后端**: burn-ndarray |
| 11 | +- **GPU后端**: burn-cudarc (可选,需要CUDA环境) |
| 12 | +- **数据加载**: 原生Rust实现 |
| 13 | + |
| 14 | +## 主要模块 |
| 15 | + |
| 16 | +### 1. burn_trainer - 训练器核心 |
| 17 | + |
| 18 | +```rust |
| 19 | +use crate::modules::yolo::services::burn_trainer::{ |
| 20 | + BurnTrainer, |
| 21 | + TrainingConfig, |
| 22 | + YOLOConfig, |
| 23 | + TrainingEvent, |
| 24 | +}; |
| 25 | +``` |
| 26 | + |
| 27 | +**创建训练器:** |
| 28 | + |
| 29 | +```rust |
| 30 | +let trainer = BurnTrainer::new(); |
| 31 | +``` |
| 32 | + |
| 33 | +**训练配置:** |
| 34 | + |
| 35 | +```rust |
| 36 | +let config = TrainingConfig { |
| 37 | + project_name: "my_yolo_model".to_string(), |
| 38 | + epochs: 100, |
| 39 | + batch_size: 16, |
| 40 | + image_size: 640, |
| 41 | + num_classes: 4, // 根据数据集调整 |
| 42 | + optimizer: "SGD".to_string(), |
| 43 | + learning_rate: 0.01, |
| 44 | + weight_decay: 0.0005, |
| 45 | + momentum: 0.937, |
| 46 | + warmup_epochs: 3, |
| 47 | + device: "cpu".to_string(), // 或 "cuda" |
| 48 | + workers: 8, |
| 49 | + save_period: 10, |
| 50 | +}; |
| 51 | +``` |
| 52 | + |
| 53 | +### 2. yolo_dataset - 数据加载 |
| 54 | + |
| 55 | +```rust |
| 56 | +use crate::modules::yolo::services::yolo_dataset::{ |
| 57 | + YOLODataset, |
| 58 | + DatasetConfig, |
| 59 | + BoundingBox, |
| 60 | +}; |
| 61 | +``` |
| 62 | + |
| 63 | +**创建数据集:** |
| 64 | + |
| 65 | +```rust |
| 66 | +let config = DatasetConfig { |
| 67 | + dataset_path: PathBuf::from("./dataset"), |
| 68 | + train_images: PathBuf::from("./dataset/train/images"), |
| 69 | + train_labels: PathBuf::from("./dataset/train/labels"), |
| 70 | + val_images: PathBuf::from("./dataset/val/images"), |
| 71 | + val_labels: PathBuf::from("./dataset/val/labels"), |
| 72 | + class_names: vec![ |
| 73 | + "person".to_string(), |
| 74 | + "car".to_string(), |
| 75 | + "dog".to_string(), |
| 76 | + "cat".to_string(), |
| 77 | + ], |
| 78 | + num_classes: 4, |
| 79 | +}; |
| 80 | + |
| 81 | +let dataset = YOLODataset::new(config, 640)?; |
| 82 | +``` |
| 83 | + |
| 84 | +**数据增强:** |
| 85 | + |
| 86 | +```rust |
| 87 | +// 随机水平翻转 |
| 88 | +dataset.random_flip(&mut boxes); |
| 89 | + |
| 90 | +// 随机亮度调整 |
| 91 | +let brightness = dataset.random_brightness(); |
| 92 | +``` |
| 93 | + |
| 94 | +### 3. yolo_loss - 损失函数 |
| 95 | + |
| 96 | +```rust |
| 97 | +use crate::modules::yolo::services::yolo_loss::{ |
| 98 | + YOLOLoss, |
| 99 | + YOLOLossConfig, |
| 100 | + YOLOTarget, |
| 101 | + BoxTarget, |
| 102 | + ciou_loss, |
| 103 | + calculate_iou, |
| 104 | +}; |
| 105 | +``` |
| 106 | + |
| 107 | +**损失配置:** |
| 108 | + |
| 109 | +```rust |
| 110 | +let config = YOLOLossConfig { |
| 111 | + box_weight: 7.5, // 边界框损失权重 |
| 112 | + cls_weight: 0.5, // 分类损失权重 |
| 113 | + dfl_weight: 1.5, // DFL损失权重 |
| 114 | + cls_for_bg: 0.25, // 背景类别的分类权重 |
| 115 | +}; |
| 116 | +``` |
| 117 | + |
| 118 | +## 完整训练示例 |
| 119 | + |
| 120 | +```rust |
| 121 | +use tokio::sync::mpsc; |
| 122 | + |
| 123 | +async fn train_yolo_model() -> Result<(), String> { |
| 124 | + // 1. 创建训练器 |
| 125 | + let trainer = BurnTrainer::new(); |
| 126 | + |
| 127 | + // 2. 创建训练配置 |
| 128 | + let config = TrainingConfig { |
| 129 | + project_name: "wildlife_detector".to_string(), |
| 130 | + epochs: 300, |
| 131 | + batch_size: 16, |
| 132 | + image_size: 640, |
| 133 | + num_classes: 4, |
| 134 | + optimizer: "SGD".to_string(), |
| 135 | + learning_rate: 0.01, |
| 136 | + weight_decay: 0.0005, |
| 137 | + momentum: 0.937, |
| 138 | + warmup_epochs: 3, |
| 139 | + device: "cpu".to_string(), |
| 140 | + workers: 8, |
| 141 | + save_period: 10, |
| 142 | + }; |
| 143 | + |
| 144 | + // 3. 创建事件通道 |
| 145 | + let (event_tx, mut event_rx) = mpsc::unbounded_channel(); |
| 146 | + |
| 147 | + // 4. 启动训练 |
| 148 | + let training_id = "train_001".to_string(); |
| 149 | + let result = trainer.train_async(training_id, config, event_tx).await?; |
| 150 | + |
| 151 | + // 5. 处理训练事件 |
| 152 | + while let Some(event) = event_rx.recv().await { |
| 153 | + match event { |
| 154 | + TrainingEvent::Started { total_epochs, .. } => { |
| 155 | + println!("训练开始,共 {} 个epoch", total_epochs); |
| 156 | + } |
| 157 | + TrainingEvent::BatchProgress(state) => { |
| 158 | + println!("Epoch {}/{}, Batch {}/{}, Loss: {:.4f}", |
| 159 | + state.epoch, state.total_epochs, |
| 160 | + state.batch, state.total_batches, |
| 161 | + state.total_loss); |
| 162 | + } |
| 163 | + TrainingEvent::EpochComplete { epoch, map50, .. } => { |
| 164 | + if let Some(mAP) = map50 { |
| 165 | + println!("Epoch {} 完成, mAP@50: {:.4f}", epoch, mAP); |
| 166 | + } |
| 167 | + } |
| 168 | + TrainingEvent::Complete { model_path } => { |
| 169 | + println!("训练完成! 模型保存到: {}", model_path); |
| 170 | + break; |
| 171 | + } |
| 172 | + TrainingEvent::Error { error } => { |
| 173 | + eprintln!("训练错误: {}", error); |
| 174 | + break; |
| 175 | + } |
| 176 | + TrainingEvent::Stopped => { |
| 177 | + println!("训练被停止"); |
| 178 | + break; |
| 179 | + } |
| 180 | + } |
| 181 | + } |
| 182 | + |
| 183 | + Ok(()) |
| 184 | +} |
| 185 | +``` |
| 186 | + |
| 187 | +## 从Python训练迁移 |
| 188 | + |
| 189 | +### Python方式 (已废弃) |
| 190 | + |
| 191 | +```python |
| 192 | +# 启动Python YOLO训练 |
| 193 | +python yolo_server.py 8080 |
| 194 | +# 发送训练命令... |
| 195 | +``` |
| 196 | + |
| 197 | +### Rust方式 (推荐) |
| 198 | + |
| 199 | +```rust |
| 200 | +use crate::modules::yolo::services::burn_trainer::BurnTrainer; |
| 201 | + |
| 202 | +// 直接在Rust中训练,无需外部进程 |
| 203 | +let trainer = BurnTrainer::new(); |
| 204 | +let result = trainer.train_async(id, config, event_tx).await?; |
| 205 | +``` |
| 206 | + |
| 207 | +## 性能对比 |
| 208 | + |
| 209 | +| 指标 | Python (Ultralytics) | Rust (Burn) | |
| 210 | +|------|---------------------|-------------| |
| 211 | +| 训练速度 | 基准 | ~相同 | |
| 212 | +| 依赖 | Python + PyTorch | 仅Rust | |
| 213 | +| 可移植性 | 受限于Python环境 | 完全可移植 | |
| 214 | +| GPU支持 | 原生CUDA | burn-cudarc | |
| 215 | +| 部署难度 | 需要Python运行时 | 纯二进制 | |
| 216 | + |
| 217 | +## GPU加速 (可选) |
| 218 | + |
| 219 | +要启用GPU训练,需要: |
| 220 | + |
| 221 | +1. 安装CUDA Toolkit (11.8或12.1) |
| 222 | +2. 在`Cargo.toml`中启用burn-cudarc: |
| 223 | + |
| 224 | +```toml |
| 225 | +burn-cudarc = "0.5" # 取消注释 |
| 226 | +``` |
| 227 | + |
| 228 | +3. 设置设备为"cuda": |
| 229 | + |
| 230 | +```rust |
| 231 | +let config = TrainingConfig { |
| 232 | + device: "cuda".to_string(), |
| 233 | + // ... |
| 234 | +}; |
| 235 | +``` |
| 236 | + |
| 237 | +## 模型导出 |
| 238 | + |
| 239 | +训练完成后,模型将保存为ONNX格式: |
| 240 | + |
| 241 | +``` |
| 242 | +train/{project_name}/weights/best.onnx |
| 243 | +``` |
| 244 | + |
| 245 | +可以使用现有的推理引擎进行部署: |
| 246 | + |
| 247 | +```rust |
| 248 | +use crate::modules::yolo::services::inference_engine::InferenceEngine; |
| 249 | + |
| 250 | +let engine = InferenceEngine::new("train/wildlife_detector/weights/best.onnx")?; |
| 251 | +let detections = engine.detect(&image, 0.65)?; |
| 252 | +``` |
| 253 | + |
| 254 | +## 限制和注意事项 |
| 255 | + |
| 256 | +### 当前限制 |
| 257 | + |
| 258 | +1. **模型简化**: 简化版的YOLOv8模型架构用于演示 |
| 259 | +2. **性能**: CPU训练速度较慢,建议使用GPU |
| 260 | +3. **数据增强**: 当前实现较基础,可扩展更多增强方法 |
| 261 | + |
| 262 | +### 待完成功能 |
| 263 | + |
| 264 | +- [ ] 完整的YOLOv8架构实现 |
| 265 | +- [ ] 更多数据增强方法 (Mosaic, MixUp等) |
| 266 | +- [ ] 学习率调度器 (Cosine Annealing, Warmup等) |
| 267 | +- [ ] 模型导出为TorchScript格式 |
| 268 | +- [ ] 验证集评估和mAP计算 |
| 269 | + |
| 270 | +## 故障排除 |
| 271 | + |
| 272 | +### 问题1: 编译错误 "cannot find module" |
| 273 | + |
| 274 | +**解决方案**: 确保在`src-tauri/src/modules/yolo/services/mod.rs`中导入了模块: |
| 275 | + |
| 276 | +```rust |
| 277 | +pub mod burn_trainer; |
| 278 | +pub mod yolo_dataset; |
| 279 | +pub mod yolo_loss; |
| 280 | +``` |
| 281 | + |
| 282 | +### 问题2: 训练很慢 |
| 283 | + |
| 284 | +**解决方案**: |
| 285 | +1. 使用GPU加速 (设置`device: "cuda"`) |
| 286 | +2. 减小图像尺寸 (从640降到320) |
| 287 | +3. 减小batch_size以适应内存 |
| 288 | + |
| 289 | +### 问题3: OOM (内存不足) |
| 290 | + |
| 291 | +**解决方案**: |
| 292 | +1. 减小batch_size |
| 293 | +2. 使用更小的图像尺寸 |
| 294 | +3. 启用梯度累积 |
| 295 | + |
| 296 | +## 进一步优化建议 |
| 297 | + |
| 298 | +1. **混合精度训练**: 实现FP16训练以减少内存使用 |
| 299 | +2. **多GPU训练**: 使用数据并行 |
| 300 | +3. **模型量化**: 训练后量化以减少模型大小 |
| 301 | +4. **更深的网络**: 实现完整的YOLOv8l或YOLOv8x |
| 302 | + |
| 303 | +## 总结 |
| 304 | + |
| 305 | +Burn训练器提供了一个完全原生Rust的YOLO训练解决方案,消除了对Python的依赖。虽然当前版本是简化实现,但它为未来完整功能奠定了基础。通过使用Burn,我们可以: |
| 306 | + |
| 307 | +✅ 消除Python依赖 |
| 308 | +✅ 实现完全可移植的部署 |
| 309 | +✅ 利用Rust的性能优势 |
| 310 | +✅ 集成到现有的Rust项目中 |
| 311 | + |
| 312 | +随着更多功能的实现,Burn训练器将成为生产环境中的可靠选择。 |
0 commit comments