From a2d4590fccf5e478817b268d1fc1c0c12cf0f5fd Mon Sep 17 00:00:00 2001 From: chen Date: Thu, 12 Feb 2026 09:11:49 +0000 Subject: [PATCH] feat(lora): add Low-Rank Adaptation support for efficient fine-tuning - Add LoRA module infrastructure with configurable rank, alpha, dropout - Implement LoRALinear wrapper for seamless integration with Linear layers - Support tensor parallelism via LoRAParallelLinear - Add LoRAModel utility for managing multiple LoRA layers - Integrate LoRA configuration and utilities - Add GPT2 example demonstrating LoRA fine-tuning - Include comprehensive usage documentation and test suite Co-Authored-By: Claude Opus 4.6 --- CMakeLists.txt | 4 + docs/lora_usage.md | 615 ++++++++++++++++ example/gpt2/main.cc | 53 +- infini_train/include/nn/lora/lora_config.h | 41 ++ infini_train/include/nn/lora/lora_linear.h | 72 ++ infini_train/include/nn/lora/lora_model.h | 86 +++ .../include/nn/lora/lora_parallel_linear.h | 118 +++ infini_train/include/nn/lora/lora_utils.h | 86 +++ infini_train/include/nn/modules/module.h | 1 + infini_train/src/nn/lora/lora_config.cc | 43 ++ infini_train/src/nn/lora/lora_linear.cc | 208 ++++++ infini_train/src/nn/lora/lora_model.cc | 73 ++ .../src/nn/lora/lora_parallel_linear.cc | 377 ++++++++++ infini_train/src/nn/lora/lora_utils.cc | 340 +++++++++ infini_train/src/nn/modules/module.cc | 5 + test/lora/test_lora.cc | 675 ++++++++++++++++++ 16 files changed, 2795 insertions(+), 2 deletions(-) create mode 100644 docs/lora_usage.md create mode 100644 infini_train/include/nn/lora/lora_config.h create mode 100644 infini_train/include/nn/lora/lora_linear.h create mode 100644 infini_train/include/nn/lora/lora_model.h create mode 100644 infini_train/include/nn/lora/lora_parallel_linear.h create mode 100644 infini_train/include/nn/lora/lora_utils.h create mode 100644 infini_train/src/nn/lora/lora_config.cc create mode 100644 infini_train/src/nn/lora/lora_linear.cc create mode 100644 infini_train/src/nn/lora/lora_model.cc create mode 100644 infini_train/src/nn/lora/lora_parallel_linear.cc create mode 100644 infini_train/src/nn/lora/lora_utils.cc create mode 100644 test/lora/test_lora.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 6c16068..86b0b85 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -197,3 +197,7 @@ target_link_libraries(test_hook infini_train) add_executable(test_precision_check test/hook/test_precision_check.cc) target_link_libraries(test_precision_check infini_train) + +add_executable(test_lora test/lora/test_lora.cc) +target_link_libraries(test_lora infini_train) + diff --git a/docs/lora_usage.md b/docs/lora_usage.md new file mode 100644 index 0000000..60a28ec --- /dev/null +++ b/docs/lora_usage.md @@ -0,0 +1,615 @@ +# LoRA 使用说明 + +本文档介绍如何在 InfiniTrain 中使用 LoRA (Low-Rank Adaptation) 进行高效微调。 + +## 目录 + +1. [快速开始](#快速开始) +2. [核心概念](#核心概念) +3. [命令行使用](#命令行使用-gpt2-示例) +4. [LoRAModel 包装器](#lora模型-包装器-推荐模式) +5. [API 参考](#api-参考) +6. [使用示例](#使用示例) +7. [最佳实践](#最佳实践) + +## 快速开始 + +### 头文件引入 + +```cpp +#include "nn/lora/lora_config.h" +#include "nn/lora/lora_linear.h" +#include "nn/lora/lora_utils.h" +// 如果使用张量并行 +#include "nn/lora/lora_parallel_linear.h" +``` + +### 最简示例 + +```cpp +using namespace infini_train::nn::lora; + +// 1. 创建 LoRA 配置 +LoRAConfig config; +config.rank = 8; // 低秩维度 +config.alpha = 16.0f; // 缩放因子 + +// 2. 获取 LoRA 模型 +auto* lora_model = GetLoRAModel(model, config); + +// 3. 获取 LoRA 参数用于优化器 +auto lora_params = lora_model->TrainableParameters(); +auto optimizer = std::make_shared(lora_params, lr); + +// 4. 训练循环 +for (int step = 0; step < num_steps; ++step) { + auto loss = (*model)(inputs); + loss->Backward(); + optimizer->Step(); + optimizer->ZeroGrad(); +} + +// 6. 保存 LoRA 权重 +SaveLoRAWeights(model, "lora_weights.bin"); +``` + +## 核心概念 + +### LoRA 原理 + +LoRA 通过低秩分解来近似权重更新: + +``` +原始: y = Wx + b +LoRA: y = Wx + b + (α/r) × x × A^T × B^T +``` + +其中: +- `W` 是冻结的原始权重 +- `A` 是形状为 `[rank, in_features]` 的可训练矩阵 +- `B` 是形状为 `[out_features, rank]` 的可训练矩阵 +- `α/r` 是缩放因子 + +### 参数效率 + +假设原始 Linear 层参数量为 `in × out`,LoRA 只需训练 `rank × (in + out)` 个参数。 + +例如:`in=4096, out=4096, rank=8` +- 原始参数:16,777,216 +- LoRA 参数:65,536 (仅 0.39%) + +## LoRAModel 包装器类 + +### LoRAModel + +遵循 PEFT 模式的 LoRA 包装器,封装基础模型和 LoRA 配置。使用 `NamedModules()` 自动遍历模型层次结构。 + +```cpp +class LoRAModel : public Module { +public: + // 构造函数 - 自动遍历模型层次结构 + LoRAModel(std::shared_ptr base_model, + const LoRAConfig &config); + + // 获取可训练参数 + std::vector> TrainableParameters() const; + + // 获取所有参数 + std::vector> Parameters() const override; + + // LoRA 权重管理 + void SaveLoRA(const std::string &filepath) const; + void LoadLoRA(const std::string &filepath); + void Merge(); + void Unmerge(); + bool IsMerged() const; + + // 打印摘要 + void PrintSummary() const; + + // 访问基础模型 + std::shared_ptr base_model() const; + + // 获取 LoRA 配置 + const LoRAConfig &config() const; +}; +``` + +### 工厂函数 + +```cpp +template +std::shared_ptr CreateLoRAModel( + const ConfigType &model_config, + const LoRAConfig &lora_config) { + auto base_model = std::make_shared(model_config); + return std::make_shared(base_model, lora_config); +} +``` + +## API 参考 + +### LoRAConfig - 配置结构 + +```cpp +struct LoRAConfig { + int64_t rank = 8; // 低秩维度 r + float alpha = 16.0f; // 缩放因子 α + float dropout = 0.0f; // Dropout 概率(暂未实现) + + // 目标模块名称(默认只对 attention 层应用) + std::unordered_set target_modules = {"c_attn", "attn.c_proj"}; + + // 初始化参数 + bool use_kaiming_a = true; // A 矩阵使用 Kaiming 初始化 + float kaiming_a_param = 1.0f; // Kaiming 初始化参数 + + // 计算缩放因子 + float Scaling() const; // 返回 alpha / rank + + // 检查模块是否应该应用 LoRA + bool ShouldApplyLoRA(const std::string &module_name) const; +}; +``` + +### 模型应用函数 + +#### GetLoRAModel + +PEFT-style 运行时包装器,使用 `NamedModules()` 自动遍历模型层次结构,创建 LoRA 模型。 + +```cpp +LoRAModel* GetLoRAModel( + std::shared_ptr model, // 目标模型 + const LoRAConfig &config // LoRA 配置 +); +``` + +**参数说明:** +- `model`: 要包装的模型 +- `config`: LoRA 配置(包含 `target_modules` 指定目标层) + +**返回值:** `LoRAModel*`,可用于调用 `LoadLoRA()`, `SaveLoRA()`, `PrintSummary()` 等方法 + +**使用示例:** +```cpp +// 配置 LoRA +LoRAConfig config{8, 16.0f}; +config.SetTargetModules("c_attn,attn.c_proj"); // 只对 attention +// config.SetTargetModules("c_attn,attn.c_proj,c_fc,c_fc2,mlp.c_proj"); // 包含 MLP + +// 一行启用 LoRA +auto* lora_model = nn::lora::GetLoRAModel(model, config); +``` + +#### InjectLoRALayers + +使用 `NamedModules()` 自动遍历模型层次结构,将 LoRA 注入到所有匹配的层中。 + +```cpp +void InjectLoRALayers( + std::shared_ptr model, // 目标模型 + const LoRAConfig &config // LoRA 配置 +); +``` + +**参数说明:** +- `model`: 要注入 LoRA 的模型 +- `config`: LoRA 配置(通过 `target_modules` 指定目标层) + +### 参数管理函数 + +#### FreezeBaseModel / UnfreezeModel + +```cpp +// 冻结基础模型所有参数 +void FreezeBaseModel(std::shared_ptr model); + +// 解冻所有参数 +void UnfreezeModel(std::shared_ptr model); +``` + +#### GetLoRAParameters / GetBaseParameters + +```cpp +// 获取 LoRA 参数(用于优化器) +std::vector> GetLoRAParameters( + const std::shared_ptr &model); + +// 获取基础模型参数 +std::vector> GetBaseParameters( + const std::shared_ptr &model); +``` + +### 权重合并函数 + +#### MergeLoRAWeights / UnmergeLoRAWeights + +```cpp +// 合并 LoRA 权重到基础权重: W' = W + (α/r) × B × A +void MergeLoRAWeights(std::shared_ptr model); + +// 恢复原始基础权重 +void UnmergeLoRAWeights(std::shared_ptr model); +``` + +**使用场景:** +- 推理时合并权重可以消除额外计算开销 +- 导出模型时合并权重得到标准模型格式 + +### 保存/加载函数 + +```cpp +// 保存 LoRA 权重到文件 +void SaveLoRAWeights(const std::shared_ptr &model, + const std::string &filepath); + +// 从文件加载 LoRA 权重 +void LoadLoRAWeights(std::shared_ptr model, + const std::string &filepath); + +// 获取 LoRA 状态字典 +std::unordered_map> +LoRAStateDict(const std::shared_ptr &model); + +// 加载 LoRA 状态字典 +void LoadLoRAStateDict( + std::shared_ptr model, + const std::unordered_map> &state_dict); +``` + +### 统计函数 + +```cpp +// 打印 LoRA 模型摘要 +void PrintLoRASummary(const std::shared_ptr &model); + +// 统计可训练参数数量 +int64_t CountTrainableParameters(const std::shared_ptr &model); + +// 统计总参数数量 +int64_t CountTotalParameters(const std::shared_ptr &model); +``` + +## 使用示例 + +### 示例 1: GPT2 微调 + +```cpp +#include "example/gpt2/gpt2.h" +#include "nn/lora/lora_utils.h" + +using namespace infini_train::nn::lora; + +int main() { + // 创建 GPT2 模型 + auto model = std::make_shared(config); + model->LoadWeights("gpt2_weights.bin"); + + // 配置 LoRA + LoRAConfig lora_config; + lora_config.rank = 8; + lora_config.alpha = 16.0f; + lora_config.SetTargetModules("c_attn,attn.c_proj"); // 只对 attention 层 + + // 获取 LoRA 模型 + auto* lora_model = GetLoRAModel(model, lora_config); + + // 打印参数统计 + PrintLoRASummary(lora_model); + // 输出示例: + // ========== LoRA Model Summary ========== + // Total parameters: 124,439,808 + // Trainable parameters: 294,912 (0.24%) + // Frozen parameters: 124,144,896 + // ========================================= + + // 创建优化器(只优化 LoRA 参数) + auto lora_params = lora_model->TrainableParameters(); + auto optimizer = std::make_shared(lora_params, /*lr=*/1e-4); + + // 训练循环 + for (int step = 0; step < num_steps; ++step) { + auto [loss, logits] = (*lora_model)({input_ids}); + loss->Backward(); + optimizer->Step(); + optimizer->ZeroGrad(); + + if (step % 100 == 0) { + std::cout << "Step " << step << ", Loss: " << loss->Item() << std::endl; + } + } + + // 保存 LoRA 权重(仅几 MB) + lora_model->SaveLoRA("gpt2_lora.bin"); + + return 0; +} +``` + +### 示例 2: LLaMA3 分布式微调 + +```cpp +#include "example/llama3/llama3.h" +#include "nn/lora/lora_utils.h" +#include "nn/parallel/process_group.h" + +using namespace infini_train::nn::lora; + +int main(int argc, char **argv) { + // 初始化分布式环境 + InitDistributed(argc, argv); + + // 创建 LLaMA3 模型(带张量并行) + LLaMA3Config config; + config.n_layers = 32; + config.tensor_parallel = 2; + + auto model = std::make_shared(config); + model->LoadWeights("llama3_weights/"); + + // 配置 LoRA(包含 MLP 层以获得更好效果) + LoRAConfig lora_config{16, 32.0f}; + lora_config.SetTargetModules("c_attn,attn.c_proj,c_fc,c_fc2,mlp.c_proj"); + + // 获取 LoRA 模型(通过 target_modules 配置包含 MLP 层) + auto* lora_model = GetLoRAModel(model, lora_config); + + PrintLoRASummary(lora_model); + + // 训练... + + // 保存 + if (GetRank() == 0) { + SaveLoRAWeights(model, "llama3_lora.bin"); + } + + return 0; +} +``` + +### 示例 3: 推理时合并权重 + +```cpp +// 加载基础模型 +auto model = std::make_shared(config); +model->LoadWeights("gpt2_weights.bin"); + +// 获取 LoRA 模型 +auto* lora_model = GetLoRAModel(model, lora_config); + +// 加载 LoRA 权重 +lora_model->LoadLoRA("gpt2_lora.bin"); + +// 合并权重(推理时无额外开销) +lora_model->Merge(); + +// 现在可以像普通模型一样推理 +auto output = (*lora_model)({input_ids}); + +// 如果需要继续训练,先解除合并 +lora_model->Unmerge(); +``` + +### 示例 4: 自定义目标层 + +```cpp +// 或者对所有线性层应用 +config.SetTargetModules("c_attn,attn.c_proj,c_fc,c_fc2,mlp.c_proj,lm_head"); + +// 获取 LoRA 模型 +auto* lora_model = GetLoRAModel(model, config); +``` + +## 最佳实践 + +### 1. 选择合适的 rank + +| 任务类型 | 推荐 rank | 说明 | +|---------|----------|------| +| 简单分类任务 | 4-8 | 参数少,训练快 | +| 文本生成微调 | 8-16 | 平衡效果和效率 | +| 复杂任务适配 | 16-64 | 更强表达能力 | + +### 2. alpha 设置 + +- 通常设置 `alpha = 2 × rank` +- 较大的 alpha 会增加 LoRA 的影响 +- 可以通过调整 alpha 来控制微调强度 + +### 3. 目标层选择 + +```cpp +// 推荐:只对 attention 层(参数效率最高) +config.SetTargetModules("c_attn,attn.c_proj"); + +// 可选:包含 MLP 层(效果可能更好,但参数更多) +config.SetTargetModules("c_attn,attn.c_proj,c_fc,c_fc2,mlp.c_proj"); +``` + +### 4. 学习率 + +- LoRA 通常使用比全量微调更高的学习率 +- 推荐范围:1e-4 到 1e-3 +- 可以使用学习率预热和衰减 + +### 5. 内存优化 + +```cpp +// 只保存 LoRA 权重(几 MB vs 几 GB) +SaveLoRAWeights(model, "lora.bin"); + +// 推理时合并权重,消除额外计算 +MergeLoRAWeights(model); +``` + +## 模型层名称参考 + +### GPT2 模型结构 + +``` +transformer.wte # Token Embedding +transformer.wpe # Position Embedding +transformer.h.{i}.ln_1 # LayerNorm 1 +transformer.h.{i}.attn.c_attn # QKV 投影 (ColumnParallel) +transformer.h.{i}.attn.c_proj # Output 投影 (RowParallel) +transformer.h.{i}.ln_2 # LayerNorm 2 +transformer.h.{i}.mlp.c_fc # MLP 第一层 (ColumnParallel) +transformer.h.{i}.mlp.c_proj # MLP 第二层 (RowParallel) +transformer.ln_f # Final LayerNorm +lm_head # Language Model Head +``` + +### LLaMA3 模型结构 + +``` +transformer.tok_emb # Token Embedding +transformer.h.{i}.attn_norm # RMSNorm (attention) +transformer.h.{i}.attn.c_attn # QKV 投影 (ColumnParallel) +transformer.h.{i}.attn.c_proj # Output 投影 (RowParallel) +transformer.h.{i}.ffn_norm # RMSNorm (FFN) +transformer.h.{i}.mlp.c_fc # FFN gate (ColumnParallel) +transformer.h.{i}.mlp.c_fc2 # FFN up (ColumnParallel) +transformer.h.{i}.mlp.c_proj # FFN down (RowParallel) +transformer.norm # Final RMSNorm +lm_head # Language Model Head +``` + +## 命令行使用 (GPT2 示例) + +### 启用 LoRA 训练 + +```bash +./build/gpt2 \ + --device cuda \ + --input_bin data/train.bin \ + --llmc_filepath data/gpt2_124M.bin \ + --batch_size 4 \ + --sequence_length 64 \ + --num_iteration 10 \ + --learning_rate 1e-5 \ + --lora_rank 8 \ + --lora_alpha 16.0 \ + --lora_target_modules "c_attn,attn.c_proj" \ + --lora_save_path data/lora_weights +``` + +### 命令行参数 + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `--lora_rank` | 0 | LoRA 秩 (0 = 禁用) | +| `--lora_alpha` | 16.0 | LoRA 缩放因子 | +| `--lora_target_modules` | "c_attn,attn.c_proj" | 目标模块 (逗号分隔: c_attn,attn.c_proj,c_fc,c_fc2,mlp.c_proj) | +| `--lora_load_path` | "" | 加载已有 LoRA 权重 | +| `--lora_save_path` | "" | 保存 LoRA 权重路径 | + +### 加载已有 LoRA 权重 + +```bash +./build/gpt2 \ + ... + --lora_rank 8 \ + --lora_alpha 16.0 \ + --lora_load_path data/lora_weights +``` + +## LoRAModel 包装器 (推荐模式) + +### 概述 + +`LoRAModel` 是一个包装器类,遵循 PEFT 设计模式,将 LoRA 作为包装器应用于基础模型,而不是直接修改模型代码。 + +### 优势 + +- **透明性**: 训练循环无需修改,直接使用 `(*model)(inputs)` +- **参数管理**: 自动获取可训练参数 +- **权重管理**: 内置 Save/Load/Merge 方法 + +### 使用示例 + +```cpp +#include "infini_train/include/nn/lora/lora_model.h" + +using namespace infini_train::nn::lora; + +int main() { + // 1. 创建基础模型 + auto base_model = std::make_shared(config); + base_model->LoadWeights("gpt2_weights.bin"); + + // 2. 创建 LoRA 配置 + LoRAConfig lora_config{8, 16.0f}; + lora_config.SetTargetModules("c_attn,attn.c_proj"); // 只对 attention 层 + + // 3. 创建 LoRA 包装器 (一行代码) + auto lora_model = std::make_shared(base_model, lora_config); + + // 4. 获取可训练参数用于优化器 + auto trainable_params = lora_model->TrainableParameters(); + auto optimizer = std::make_shared(trainable_params, 1e-5); + + // 5. 打印摘要 + lora_model->PrintSummary(); + // 输出: + // ========== LoRA Model Summary ========== + // Total parameters: 176062464 + // Trainable parameters: 442368 (0.251256%) + // Frozen parameters: 175620096 + // ========================================= + + // 6. 训练循环 (无需修改) + for (int step = 0; step < num_steps; ++step) { + auto logits = (*lora_model)({x, y})[0]; + auto loss = (*loss_fn)({logits, y})[0]; + loss->Backward(); + optimizer->Step(); + optimizer->ZeroGrad(); + } + + // 7. 保存 LoRA 权重 + lora_model->SaveLoRA("lora_weights.bin"); + + return 0; +} +``` + +### 工厂函数 + +对于任意模型类型,可以使用模板工厂函数: + +```cpp +#include "infini_train/include/nn/lora/lora_model.h" + +auto lora_model = CreateLoRAModel( + model_config, // GPT2 模型配置 + lora_config // LoRA 配置 +); +``` + +## 常见问题 + +### Q: LoRA 权重文件有多大? + +A: 取决于 rank 和目标层数量。以 GPT2-small (12层) 为例: +- rank=8, attention only: ~1.2 MB +- rank=16, attention + MLP: ~4.8 MB + +### Q: 如何在不同任务间切换 LoRA? + +A: 保存和加载不同的 LoRA 权重文件: +```cpp +// 任务 A +LoadLoRAWeights(model, "task_a_lora.bin"); +// 推理... + +// 任务 B +LoadLoRAWeights(model, "task_b_lora.bin"); +// 推理... +``` + +### Q: 可以同时使用多个 LoRA 吗? + +A: 当前实现不支持多 LoRA 组合。如需此功能,可以: +1. 合并多个 LoRA 权重后加载 +2. 扩展实现支持 LoRA 堆叠 diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index 3dfeadd..4c22013 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -13,6 +13,8 @@ #include "infini_train/include/core/device_guard.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" +#include "infini_train/include/nn/lora/lora_model.h" +#include "infini_train/include/nn/lora/lora_utils.h" #include "infini_train/include/nn/modules/loss.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h" @@ -79,6 +81,13 @@ DEFINE_string( precision_check, "", "precision check config: level=N,format=simple|table,output_md5=true|false,output_path=PATH,baseline=PATH"); +// LoRA parameters +DEFINE_int32(lora_rank, 0, "LoRA rank (0 = disabled)"); +DEFINE_double(lora_alpha, 16.0, "LoRA alpha scaling factor"); +DEFINE_string(lora_target_modules, "c_attn,c_proj", "LoRA target modules (comma-separated: c_attn,c_proj,c_fc,c_fc2,mlp.c_proj)"); +DEFINE_string(lora_save_path, "", "Path to save LoRA weights after training"); +DEFINE_string(lora_load_path, "", "Path to load LoRA weights from"); + using namespace infini_train; namespace { @@ -179,6 +188,8 @@ void Train(const nn::parallel::Rank &rank) { // init the model, either from scratch or from OpenAI pretrained checkpoint GPT2Config model_config; std::shared_ptr model = nullptr; + std::shared_ptr lora_model = nullptr; // LoRA wrapper (if enabled) + if (!FLAGS_llmc_filepath.empty()) { model = GPT2::FromLLMC(FLAGS_llmc_filepath); } else if (kModelToConfigs.count(FLAGS_model)) { @@ -191,6 +202,27 @@ void Train(const nn::parallel::Rank &rank) { model->To(device); utils::PrecisionChecker::BuildNameMap(model.get()); + // Apply LoRA using GetLoRAModel (PEFT-style Runtime Wrapper) + bool lora_enabled = FLAGS_lora_rank > 0; + if (lora_enabled) { + nn::lora::LoRAConfig lora_config{FLAGS_lora_rank, static_cast(FLAGS_lora_alpha)}; + lora_config.SetTargetModules(FLAGS_lora_target_modules); + + // GetLoRAModel handles InjectLoRALayers + FreezeBase automatically + lora_model = nn::lora::GetLoRAModel(model, lora_config); + + // Load LoRA weights if specified + if (!FLAGS_lora_load_path.empty()) { + LOG(INFO) << "Loading LoRA weights from: " << FLAGS_lora_load_path; + lora_model->LoadLoRA(FLAGS_lora_load_path); + } + + // Print LoRA summary + lora_model->PrintSummary(); + + // Use LoRAModel as the training model + model = lora_model; + } // select the data type // TODO(lzm): change to solely rely on the weight file info for determining the dtype when autocast is supported @@ -205,6 +237,16 @@ void Train(const nn::parallel::Rank &rank) { auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size); + // Create optimizer - use LoRAModel's TrainableParameters() if LoRA is enabled + std::vector> params_to_optimize; + if (lora_model) { + params_to_optimize = lora_model->TrainableParameters(); + LOG(INFO) << "Optimizing " << params_to_optimize.size() << " LoRA parameters"; + } else { + params_to_optimize = model->Parameters(); + LOG(INFO) << "Optimizing " << params_to_optimize.size() << " model parameters"; + } + if (pp_world_size > 1) { // NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct // when sequence parallelism (SP) is enabled, we need to divide by sp_world_size. @@ -261,10 +303,10 @@ void Train(const nn::parallel::Rank &rank) { auto model_chunks = (pp_world_size > 1) ? *(dynamic_cast(model.get())->mutable_chunks()) : std::vector>{model}; - optimizer = std::make_shared(optimizer_creator, model->Parameters(), + optimizer = std::make_shared(optimizer_creator, params_to_optimize, model_chunks, ddp_world_size, ddp_rank); } else { - optimizer = optimizer_creator(model->Parameters()); + optimizer = optimizer_creator(params_to_optimize); } auto train_iter = train_loader.begin(); @@ -393,6 +435,13 @@ void Train(const nn::parallel::Rank &rank) { } } } + + // Save LoRA weights if enabled and path specified + if (lora_model && !FLAGS_lora_save_path.empty()) { + LOG(INFO) << "Saving LoRA weights to: " << FLAGS_lora_save_path; + lora_model->SaveLoRA(FLAGS_lora_save_path); + } + #ifdef PROFILE_MODE Profiler::Instance().Report("gpt2.report", Profiler::SortBy::DeviceTimePercentage); Profiler::Instance().PrintRecords("gpt2.records.log"); diff --git a/infini_train/include/nn/lora/lora_config.h b/infini_train/include/nn/lora/lora_config.h new file mode 100644 index 0000000..fb4f474 --- /dev/null +++ b/infini_train/include/nn/lora/lora_config.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include +#include + +namespace infini_train::nn::lora { + +// LoRA (Low-Rank Adaptation) configuration +struct LoRAConfig { + // Core LoRA parameters + int64_t rank = 8; // Low-rank dimension (r) + float alpha = 16.0f; // Scaling factor (alpha) + float dropout = 0.0f; // Dropout probability (optional, not implemented yet) + + // Target modules specification (default: attention layers only) + std::unordered_set target_modules = {"c_attn", "c_proj"}; + + // Initialization parameters + bool use_kaiming_a = true; // Use Kaiming init for A matrix + float kaiming_a_param = 1.0f; // Parameter 'a' for Kaiming init + + // Default constructor + LoRAConfig() = default; + + // Constructor with rank and alpha (PEFT-style aggregate initialization) + LoRAConfig(int64_t r, float a, float d = 0.0f) + : rank(r), alpha(a), dropout(d) {} + + // Set target modules from comma-separated string (PEFT-compatible) + void SetTargetModules(const std::string& targets); + + // Compute scaling factor: output = base_output + scaling * lora_output + float Scaling() const; + + // Check if a module name should have LoRA applied + // Matches if the module name ends with any of the target module names + bool ShouldApplyLoRA(const std::string &module_name) const; +}; + +} // namespace infini_train::nn::lora diff --git a/infini_train/include/nn/lora/lora_linear.h b/infini_train/include/nn/lora/lora_linear.h new file mode 100644 index 0000000..2939e5c --- /dev/null +++ b/infini_train/include/nn/lora/lora_linear.h @@ -0,0 +1,72 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/lora/lora_config.h" +#include "infini_train/include/nn/modules/module.h" + +namespace infini_train { +class Tensor; +class Device; +} // namespace infini_train + +namespace infini_train::nn::lora { + +// LoRA wrapper for standard Linear layer +// Implements: y = Wx + b + (alpha/r) * x @ A^T @ B^T +// Where W is frozen, A and B are trainable low-rank matrices +class LoRALinear : public nn::CloneableModule { +public: + static constexpr char kType[] = "LoRALinear"; + + // Parameter names + static constexpr char kParamWeightName[] = "weight"; // Frozen base weight + static constexpr char kParamBiasName[] = "bias"; // Frozen base bias + static constexpr char kParamLoraAName[] = "lora_A"; // Trainable A matrix [rank, in_features] + static constexpr char kParamLoraBName[] = "lora_B"; // Trainable B matrix [out_features, rank] + + // Constructor from scratch + LoRALinear(int64_t in_features, int64_t out_features, const LoRAConfig &config, bool bias = true, + const Device *device = nullptr); + + // Constructor wrapping existing Linear module (transfers ownership of parameters) + LoRALinear(std::shared_ptr base_linear, const LoRAConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + + // LoRA-specific methods + void MergeWeights(); // Merge LoRA weights into base: W' = W + (alpha/r) * B @ A + void UnmergeWeights(); // Restore original base weights + bool IsMerged() const { return merged_; } + + // Get only LoRA parameters (for optimizer) + std::vector> LoRAParameters() const; + + // Override Parameters() to return only trainable (LoRA) parameters + std::vector> Parameters() const override; + + // Get all parameters including frozen base weights (for state dict) + std::vector> AllParameters() const; + + // Accessors + int64_t in_features() const; + int64_t out_features() const; + int64_t rank() const; + float scaling() const; + +private: + void InitLoRAWeights(); + void FreezeBaseWeights(); + + LoRAConfig config_; + int64_t in_features_; + int64_t out_features_; + bool bias_; + bool merged_ = false; + + // Store original weight for unmerge + std::shared_ptr original_weight_; +}; + +} // namespace infini_train::nn::lora diff --git a/infini_train/include/nn/lora/lora_model.h b/infini_train/include/nn/lora/lora_model.h new file mode 100644 index 0000000..7def373 --- /dev/null +++ b/infini_train/include/nn/lora/lora_model.h @@ -0,0 +1,86 @@ +#pragma once + +#include +#include +#include + +#include "infini_train/include/nn/lora/lora_config.h" +#include "infini_train/include/nn/modules/module.h" + +namespace infini_train::nn::lora { + +// LoRAModel: A wrapper that applies LoRA to any base model +// This follows the PEFT design pattern where LoRA is applied as a wrapper +// rather than modifying the base model code directly. +// +// Usage: +// auto base_model = std::make_shared(config); +// LoRAConfig lora_config{8, 16.0f}; +// lora_config.SetTargetModules("c_attn,c_proj"); // or include mlp layers +// auto lora_model = std::make_shared(base_model, lora_config); +// +// // Training: only LoRA parameters are trainable +// auto optimizer = SGD(lora_model->TrainableParameters(), lr); +// +// // Save only LoRA weights +// lora_model->SaveLoRA("lora_weights.bin"); +// +// // Load LoRA weights +// lora_model->LoadLoRA("lora_weights.bin"); +// +// // Merge for inference (optional) +// lora_model->Merge(); +// +class LoRAModel : public Module { +public: + static constexpr char kType[] = "LoRAModel"; + + // Constructor: wraps a base model with LoRA + // Uses NamedModules() to automatically traverse the model hierarchy + // Parameters: + // - base_model: The original model (GPT2, LLaMA3, etc.) + // - config: LoRA configuration (rank, alpha, target_modules) + LoRAModel(std::shared_ptr base_model, const LoRAConfig &config); + + // Forward pass (delegates to base model) + std::vector> Forward(const std::vector> &inputs) override; + + // Get only trainable (LoRA) parameters for optimizer + std::vector> TrainableParameters() const; + + // Get all parameters (for state dict) + std::vector> Parameters() const override; + + // LoRA weight management + void SaveLoRA(const std::string &filepath) const; + void LoadLoRA(const std::string &filepath); + + // Merge/unmerge LoRA weights into base model + void Merge(); + void Unmerge(); + bool IsMerged() const; + + // Print summary + void PrintSummary() const; + + // Access base model + std::shared_ptr base_model() const; + + // Get LoRA config + const LoRAConfig &config() const; + +private: + std::shared_ptr base_model_; + LoRAConfig config_; + bool merged_ = false; +}; + +// Factory function for creating LoRA-enabled models +// This is the recommended way to create LoRA models +template +std::shared_ptr CreateLoRAModel(const ConfigType &model_config, const LoRAConfig &lora_config) { + auto base_model = std::make_shared(model_config); + return std::make_shared(base_model, lora_config); +} + +} // namespace infini_train::nn::lora diff --git a/infini_train/include/nn/lora/lora_parallel_linear.h b/infini_train/include/nn/lora/lora_parallel_linear.h new file mode 100644 index 0000000..6f83782 --- /dev/null +++ b/infini_train/include/nn/lora/lora_parallel_linear.h @@ -0,0 +1,118 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/lora/lora_config.h" +#include "infini_train/include/nn/modules/module.h" + +namespace infini_train { +class Tensor; +class Device; +} // namespace infini_train + +namespace infini_train::nn::lora { + +// LoRA wrapper for ColumnParallelLinear +// Weight shape: [out_features_per_partition, in_features] +// LoRA A: [rank, in_features] - replicated across TP ranks +// LoRA B: [out_features_per_partition, rank] - sharded like base weight +class LoRAColumnParallelLinear : public nn::CloneableModule { +public: + static constexpr char kType[] = "LoRAColumnParallelLinear"; + + static constexpr char kParamWeightName[] = "weight"; + static constexpr char kParamBiasName[] = "bias"; + static constexpr char kParamLoraAName[] = "lora_A"; + static constexpr char kParamLoraBName[] = "lora_B"; + + // Constructor wrapping existing ColumnParallelLinear + LoRAColumnParallelLinear(std::shared_ptr base_module, const LoRAConfig &config, int64_t in_features, + int64_t out_features); + + // Constructor wrapping existing ColumnParallelLinear (auto-infer dimensions from weight) + LoRAColumnParallelLinear(std::shared_ptr base_module, const LoRAConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + + void MergeWeights(); + void UnmergeWeights(); + bool IsMerged() const; + + std::vector> LoRAParameters() const; + std::vector> Parameters() const override; + + int64_t in_features() const; + int64_t out_features() const; + int64_t rank() const; + +private: + void InitLoRAWeights(); + void FreezeBaseWeights(); + + LoRAConfig config_; + int64_t in_features_; + int64_t out_features_; + int64_t out_features_per_partition_; + bool bias_; + bool gather_output_; + bool input_is_parallel_; + bool skip_bias_add_; + bool sequence_parallel_; + bool merged_ = false; + + std::shared_ptr original_weight_; +}; + +// LoRA wrapper for RowParallelLinear +// Weight shape: [out_features, in_features_per_partition] +// LoRA A: [rank, in_features_per_partition] - sharded like base weight +// LoRA B: [out_features, rank] - replicated, but gradient needs AllReduce +class LoRARowParallelLinear : public nn::CloneableModule { +public: + static constexpr char kType[] = "LoRARowParallelLinear"; + + static constexpr char kParamWeightName[] = "weight"; + static constexpr char kParamBiasName[] = "bias"; + static constexpr char kParamLoraAName[] = "lora_A"; + static constexpr char kParamLoraBName[] = "lora_B"; + + // Constructor wrapping existing RowParallelLinear + LoRARowParallelLinear(std::shared_ptr base_module, const LoRAConfig &config, int64_t in_features, + int64_t out_features); + + // Constructor wrapping existing RowParallelLinear (auto-infer dimensions from weight) + LoRARowParallelLinear(std::shared_ptr base_module, const LoRAConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + + void MergeWeights(); + void UnmergeWeights(); + bool IsMerged() const; + + std::vector> LoRAParameters() const; + std::vector> Parameters() const override; + + int64_t in_features() const; + int64_t out_features() const; + int64_t rank() const; + +private: + void InitLoRAWeights(); + void FreezeBaseWeights(); + + LoRAConfig config_; + int64_t in_features_; + int64_t out_features_; + int64_t in_features_per_partition_; + bool bias_; + bool reduce_output_; + bool input_is_parallel_; + bool skip_bias_add_; + bool sequence_parallel_; + bool merged_ = false; + + std::shared_ptr original_weight_; +}; + +} // namespace infini_train::nn::lora diff --git a/infini_train/include/nn/lora/lora_utils.h b/infini_train/include/nn/lora/lora_utils.h new file mode 100644 index 0000000..0bb8bc2 --- /dev/null +++ b/infini_train/include/nn/lora/lora_utils.h @@ -0,0 +1,86 @@ +#pragma once + +#include +#include +#include +#include + +#include "infini_train/include/nn/lora/lora_config.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::nn { +class Module; +} + +namespace infini_train::nn::lora { + +// Forward declaration +class LoRAModel; + +// PEFT-style get_peft_model equivalent (Runtime Wrapper) +// Creates a LoRA-wrapped model with automatic module detection using NamedModules +// Parameters: +// - model: The model to wrap +// - config: LoRA configuration (rank, alpha, target_modules) +// Returns: The LoRA-wrapped model as shared_ptr +std::shared_ptr GetLoRAModel(std::shared_ptr model, const LoRAConfig &config); + +// Internal transform: inject LoRA layers into all matching modules +// Uses NamedModules() to automatically traverse the entire model hierarchy +// Parameters: +// - model: The model to inject LoRA into +// - config: LoRA configuration (rank, alpha, target_modules) +void InjectLoRALayers(std::shared_ptr model, const LoRAConfig &config); + +// Replace a module at the given path with a new module +// Parameters: +// - model: Root model containing the module +// - path: Full path to the module (e.g., "transformer.h.0.attn.c_attn") +// - new_module: The new module to replace with +void ReplaceModuleByPath(std::shared_ptr model, const std::string &path, std::shared_ptr new_module); + +// Freeze all base model parameters (set requires_grad = false) +void FreezeBaseModel(std::shared_ptr model); + +// Unfreeze all parameters (set requires_grad = true) +void UnfreezeModel(std::shared_ptr model); + +// Get only LoRA parameters from a model (for optimizer) +// Returns parameters from LoRALinear, LoRAColumnParallelLinear, LoRARowParallelLinear modules +std::vector> GetLoRAParameters(const std::shared_ptr &model); + +// Get only base (frozen) parameters +std::vector> GetBaseParameters(const std::shared_ptr &model); + +// Merge all LoRA weights in the model +void MergeLoRAWeights(std::shared_ptr model); + +// Unmerge all LoRA weights in the model +void UnmergeLoRAWeights(std::shared_ptr model); + +// Save only LoRA weights to file +void SaveLoRAWeights(const std::shared_ptr &model, const std::string &filepath); + +// Load LoRA weights from file +void LoadLoRAWeights(std::shared_ptr model, const std::string &filepath); + +// Get LoRA state dict (only LoRA parameters with their names) +std::unordered_map> LoRAStateDict(const std::shared_ptr &model); + +// Load LoRA state dict +void LoadLoRAStateDict(std::shared_ptr model, + const std::unordered_map> &state_dict); + +// Print LoRA model summary (trainable vs frozen parameters) +void PrintLoRASummary(const std::shared_ptr &model); + +// Count trainable parameters +int64_t CountTrainableParameters(const std::shared_ptr &model); + +// Count total parameters +int64_t CountTotalParameters(const std::shared_ptr &model); + +} // namespace infini_train::nn::lora diff --git a/infini_train/include/nn/modules/module.h b/infini_train/include/nn/modules/module.h index 6482840..4e307a8 100644 --- a/infini_train/include/nn/modules/module.h +++ b/infini_train/include/nn/modules/module.h @@ -57,6 +57,7 @@ class Module : public std::enable_shared_from_this { std::vector> modules(); std::shared_ptr mutable_module(const std::string &name); const Module &module(const std::string &name) const; + void replace_module(const std::string &name, std::shared_ptr new_module); std::unordered_map> StateDict() const; diff --git a/infini_train/src/nn/lora/lora_config.cc b/infini_train/src/nn/lora/lora_config.cc new file mode 100644 index 0000000..96fba7a --- /dev/null +++ b/infini_train/src/nn/lora/lora_config.cc @@ -0,0 +1,43 @@ +#include "infini_train/include/nn/lora/lora_config.h" + +#include + +namespace infini_train::nn::lora { + +void LoRAConfig::SetTargetModules(const std::string& targets) { + target_modules.clear(); + std::stringstream ss(targets); + std::string module; + while (std::getline(ss, module, ',')) { + // Trim whitespace + module.erase(module.find_last_not_of(" \t\r\n") + 1); + module.erase(0, module.find_first_not_of(" \t\r\n")); + if (!module.empty()) { + target_modules.insert(module); + } + } +} + +float LoRAConfig::Scaling() const { + return alpha / static_cast(rank); +} + +bool LoRAConfig::ShouldApplyLoRA(const std::string &module_name) const { + // Check if the module name ends with any of the target module names + // e.g., "transformer.h.0.attn.c_attn" should match "c_attn" + for (const auto &target : target_modules) { + // Check if module_name ends with target + if (module_name.length() >= target.length()) { + size_t pos = module_name.length() - target.length(); + if (module_name.substr(pos) == target) { + // Make sure it's a complete component (preceded by '.' or at start) + if (pos == 0 || module_name[pos - 1] == '.') { + return true; + } + } + } + } + return false; +} + +} // namespace infini_train::nn::lora diff --git a/infini_train/src/nn/lora/lora_linear.cc b/infini_train/src/nn/lora/lora_linear.cc new file mode 100644 index 0000000..35117bf --- /dev/null +++ b/infini_train/src/nn/lora/lora_linear.cc @@ -0,0 +1,208 @@ +#include "infini_train/include/nn/lora/lora_linear.h" + +#include +#include +#include +#include + +#include "infini_train/include/autograd/linear.h" +#include "infini_train/include/device.h" +#include "infini_train/include/nn/init.h" +#include "infini_train/include/nn/modules/linear.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::lora { + +LoRALinear::LoRALinear(int64_t in_features, int64_t out_features, const LoRAConfig &config, bool bias, + const Device *device) + : CloneableModule(kType), config_(config), in_features_(in_features), out_features_(out_features), bias_(bias) { + device_ = device ? *device : Device(); + + // Create base weight (frozen) + parameters_[kParamWeightName] + = std::make_shared(std::vector{out_features, in_features}, DataType::kFLOAT32, device_); + init::KaimingUniform(parameters_[kParamWeightName], sqrt(5.0f)); + + // Create base bias (frozen) + if (bias) { + parameters_[kParamBiasName] + = std::make_shared(std::vector{out_features}, DataType::kFLOAT32, device_); + const auto [fan_in, _] = init::CalculateFanInAndFanOut(parameters_[kParamWeightName]); + const float bound = fan_in > 0 ? 1.0 / sqrt(fan_in) : 0.0; + init::Uniform(parameters_[kParamBiasName], -bound, bound); + } + + // Initialize LoRA weights + InitLoRAWeights(); + + // Freeze base weights + FreezeBaseWeights(); +} + +LoRALinear::LoRALinear(std::shared_ptr base_linear, const LoRAConfig &config) + : CloneableModule(kType), config_(config), bias_(false) { + if (!base_linear) { + throw std::invalid_argument("base_linear cannot be null"); + } + + // Get device from base linear + device_ = base_linear->parameter(nn::Linear::kParamWeightName)->GetDevice(); + + // Transfer weight from base linear + parameters_[kParamWeightName] = base_linear->parameter(nn::Linear::kParamWeightName); + + // Get dimensions from weight shape [out_features, in_features] + const auto &weight_dims = parameters_[kParamWeightName]->Dims(); + out_features_ = weight_dims[0]; + in_features_ = weight_dims[1]; + + // Transfer bias if exists + if (base_linear->has_parameter(nn::Linear::kParamBiasName)) { + parameters_[kParamBiasName] = base_linear->parameter(nn::Linear::kParamBiasName); + bias_ = true; + } + + // Initialize LoRA weights + InitLoRAWeights(); + + // Freeze base weights + FreezeBaseWeights(); +} + +void LoRALinear::InitLoRAWeights() { + // A matrix: [rank, in_features] + // Initialize with Kaiming uniform (or normal based on config) + parameters_[kParamLoraAName] + = std::make_shared(std::vector{config_.rank, in_features_}, DataType::kFLOAT32, device_) + ->RequiresGrad(); + + if (config_.use_kaiming_a) { + init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param); + } else { + init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f); + } + + // B matrix: [out_features, rank] + // Initialize with zeros (ensures LoRA starts as identity transformation) + parameters_[kParamLoraBName] + = std::make_shared(std::vector{out_features_, config_.rank}, DataType::kFLOAT32, device_) + ->RequiresGrad(); + init::Zeros(parameters_[kParamLoraBName]); +} + +void LoRALinear::FreezeBaseWeights() { + // Set requires_grad to false for base weights + parameters_[kParamWeightName]->set_requires_grad(false); + if (bias_) { + parameters_[kParamBiasName]->set_requires_grad(false); + } +} + +std::vector> LoRALinear::Forward(const std::vector> &input_tensors) { + const auto &input = input_tensors[0]; + + // Base linear computation: y = x @ W^T + b + auto base_output = std::make_shared()->Apply( + bias_ ? std::vector>{input, parameters_[kParamWeightName], parameters_[kParamBiasName]} + : std::vector>{input, parameters_[kParamWeightName]})[0]; + + if (merged_) { + // If merged, base weight already contains LoRA contribution + return {base_output}; + } + + // LoRA computation: delta = (alpha/r) * x @ A^T @ B^T + // A: [rank, in_features], B: [out_features, rank] + // x @ A^T: [batch, seq, in_features] @ [in_features, rank] = [batch, seq, rank] + // (x @ A^T) @ B^T: [batch, seq, rank] @ [rank, out_features] = [batch, seq, out_features] + + // Compute x @ A^T (using Linear function with A as weight, no bias) + auto hidden = std::make_shared()->Apply({input, parameters_[kParamLoraAName]})[0]; + + // Compute hidden @ B^T (using Linear function with B as weight, no bias) + auto lora_output = std::make_shared()->Apply({hidden, parameters_[kParamLoraBName]})[0]; + + // Scale and add: y = base_output + scaling * lora_output + float scaling = config_.Scaling(); + auto scaled_lora = lora_output->Mul(scaling); + + return {base_output->Add(scaled_lora)}; +} + +void LoRALinear::MergeWeights() { + if (merged_) { + return; + } + + // Save original weight for potential unmerge + original_weight_ = std::make_shared(*parameters_[kParamWeightName]); + + // W' = W + (alpha/r) * B @ A + // W: [out_features, in_features] + // B: [out_features, rank] + // A: [rank, in_features] + // B @ A: [out_features, in_features] + + auto lora_A = parameters_[kParamLoraAName]; + auto lora_B = parameters_[kParamLoraBName]; + + // Compute B @ A using matmul + auto delta = lora_B->Matmul(lora_A); // [out_features, in_features] + + // Scale and add to weight + float scaling = config_.Scaling(); + auto scaled_delta = delta->Mul(scaling); + auto new_weight = parameters_[kParamWeightName]->Add(scaled_delta); + + // Update weight data + parameters_[kParamWeightName]->CopyFrom(new_weight); + + merged_ = true; +} + +void LoRALinear::UnmergeWeights() { + if (!merged_ || !original_weight_) { + return; + } + + parameters_[kParamWeightName]->CopyFrom(original_weight_); + merged_ = false; +} + +std::vector> LoRALinear::LoRAParameters() const { + return {parameters_.at(kParamLoraAName), parameters_.at(kParamLoraBName)}; +} + +std::vector> LoRALinear::Parameters() const { + // Only return trainable LoRA parameters + return LoRAParameters(); +} + +std::vector> LoRALinear::AllParameters() const { + std::vector> all_params; + all_params.push_back(parameters_.at(kParamWeightName)); + if (bias_) { + all_params.push_back(parameters_.at(kParamBiasName)); + } + all_params.push_back(parameters_.at(kParamLoraAName)); + all_params.push_back(parameters_.at(kParamLoraBName)); + return all_params; +} + +int64_t LoRALinear::in_features() const { + return in_features_; +} + +int64_t LoRALinear::out_features() const { + return out_features_; +} + +int64_t LoRALinear::rank() const { + return config_.rank; +} + +float LoRALinear::scaling() const { + return config_.Scaling(); +} + +} // namespace infini_train::nn::lora diff --git a/infini_train/src/nn/lora/lora_model.cc b/infini_train/src/nn/lora/lora_model.cc new file mode 100644 index 0000000..5016da6 --- /dev/null +++ b/infini_train/src/nn/lora/lora_model.cc @@ -0,0 +1,73 @@ +#include "infini_train/include/nn/lora/lora_model.h" + +#include + +#include "glog/logging.h" + +#include "infini_train/include/nn/lora/lora_utils.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::lora { + +LoRAModel::LoRAModel(std::shared_ptr base_model, const LoRAConfig &config) + : base_model_(base_model), config_(config) { + // Inject LoRA layers into the base model using NamedModules + InjectLoRALayers(base_model_, config_); + + // Freeze base model parameters + FreezeBaseModel(base_model_); + + LOG(INFO) << "LoRAModel created with rank=" << config_.rank << ", alpha=" << config_.alpha; +} + +std::vector> LoRAModel::Forward(const std::vector> &inputs) { + return (*base_model_)(inputs); +} + +std::vector> LoRAModel::TrainableParameters() const { + return GetLoRAParameters(base_model_); +} + +std::vector> LoRAModel::Parameters() const { + return base_model_->Parameters(); +} + +void LoRAModel::SaveLoRA(const std::string &filepath) const { + SaveLoRAWeights(base_model_, filepath); +} + +void LoRAModel::LoadLoRA(const std::string &filepath) { + LoadLoRAWeights(base_model_, filepath); +} + +void LoRAModel::Merge() { + if (!merged_) { + MergeLoRAWeights(base_model_); + merged_ = true; + } +} + +void LoRAModel::Unmerge() { + if (merged_) { + UnmergeLoRAWeights(base_model_); + merged_ = false; + } +} + +void LoRAModel::PrintSummary() const { + PrintLoRASummary(base_model_); +} + +bool LoRAModel::IsMerged() const { + return merged_; +} + +std::shared_ptr LoRAModel::base_model() const { + return base_model_; +} + +const LoRAConfig &LoRAModel::config() const { + return config_; +} + +} // namespace infini_train::nn::lora diff --git a/infini_train/src/nn/lora/lora_parallel_linear.cc b/infini_train/src/nn/lora/lora_parallel_linear.cc new file mode 100644 index 0000000..6187f02 --- /dev/null +++ b/infini_train/src/nn/lora/lora_parallel_linear.cc @@ -0,0 +1,377 @@ +#include "infini_train/include/nn/lora/lora_parallel_linear.h" + +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/autograd/linear.h" +#include "infini_train/include/device.h" +#include "infini_train/include/nn/init.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/nn/parallel/tensor_parallel.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::lora { + +// ============================================================================ +// LoRAColumnParallelLinear Implementation +// ============================================================================ + +LoRAColumnParallelLinear::LoRAColumnParallelLinear(std::shared_ptr base_module, const LoRAConfig &config, + int64_t in_features, int64_t out_features) + : CloneableModule(kType), config_(config), in_features_(in_features), out_features_(out_features), bias_(false), + gather_output_(false), input_is_parallel_(false), skip_bias_add_(false), sequence_parallel_(false) { + if (!base_module) { + throw std::invalid_argument("base_module cannot be null"); + } + + // Get device from base module + device_ = base_module->parameter(parallel::ColumnParallelLinear::kParamWeightName)->GetDevice(); + + // Transfer weight from base module + parameters_[kParamWeightName] = base_module->parameter(parallel::ColumnParallelLinear::kParamWeightName); + + // Get dimensions from weight shape [out_features_per_partition, in_features] + const auto &weight_dims = parameters_[kParamWeightName]->Dims(); + out_features_per_partition_ = weight_dims[0]; + + // Transfer bias if exists + if (base_module->has_parameter(parallel::ColumnParallelLinear::kParamBiasName)) { + parameters_[kParamBiasName] = base_module->parameter(parallel::ColumnParallelLinear::kParamBiasName); + bias_ = true; + } + + // Initialize LoRA weights + InitLoRAWeights(); + + // Freeze base weights + FreezeBaseWeights(); +} + +LoRAColumnParallelLinear::LoRAColumnParallelLinear(std::shared_ptr base_module, const LoRAConfig &config) + : CloneableModule(kType), config_(config), bias_(false), gather_output_(false), input_is_parallel_(false), + skip_bias_add_(false), sequence_parallel_(false) { + if (!base_module) { + throw std::invalid_argument("base_module cannot be null"); + } + + // Get device from base module + device_ = base_module->parameter(parallel::ColumnParallelLinear::kParamWeightName)->GetDevice(); + + // Transfer weight from base module + parameters_[kParamWeightName] = base_module->parameter(parallel::ColumnParallelLinear::kParamWeightName); + + // Get dimensions from weight shape [out_features_per_partition, in_features] + const auto &weight_dims = parameters_[kParamWeightName]->Dims(); + out_features_per_partition_ = weight_dims[0]; + in_features_ = weight_dims[1]; + + // Calculate total out_features (assuming tensor parallelism) + int tp_size = parallel::global::GetTensorParallelSize(); + out_features_ = out_features_per_partition_ * tp_size; + + // Transfer bias if exists + if (base_module->has_parameter(parallel::ColumnParallelLinear::kParamBiasName)) { + parameters_[kParamBiasName] = base_module->parameter(parallel::ColumnParallelLinear::kParamBiasName); + bias_ = true; + } + + // Initialize LoRA weights + InitLoRAWeights(); + + // Freeze base weights + FreezeBaseWeights(); +} + +void LoRAColumnParallelLinear::InitLoRAWeights() { + // A matrix: [rank, in_features] - replicated across TP ranks + parameters_[kParamLoraAName] + = std::make_shared(std::vector{config_.rank, in_features_}, DataType::kFLOAT32, device_) + ->RequiresGrad(); + + if (config_.use_kaiming_a) { + init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param); + } else { + init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f); + } + + // B matrix: [out_features_per_partition, rank] - sharded like base weight + parameters_[kParamLoraBName] + = std::make_shared(std::vector{out_features_per_partition_, config_.rank}, DataType::kFLOAT32, + device_) + ->RequiresGrad(); + init::Zeros(parameters_[kParamLoraBName]); +} + +void LoRAColumnParallelLinear::FreezeBaseWeights() { + parameters_[kParamWeightName]->set_requires_grad(false); + if (bias_) { + parameters_[kParamBiasName]->set_requires_grad(false); + } +} + +std::vector> +LoRAColumnParallelLinear::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1) << "LoRAColumnParallelLinear takes exactly one input"; + const auto &input = input_tensors[0]; + + // Base linear computation + auto base_output = std::make_shared()->Apply( + (bias_ && !skip_bias_add_) + ? std::vector>{input, parameters_[kParamWeightName], parameters_[kParamBiasName]} + : std::vector>{input, parameters_[kParamWeightName]})[0]; + + if (!merged_) { + // LoRA computation: x @ A^T @ B^T + // A is replicated [rank, in_features], so x @ A^T gives same result on all ranks + auto hidden = std::make_shared()->Apply({input, parameters_[kParamLoraAName]})[0]; + + // B is sharded [out_features_per_partition, rank], so hidden @ B^T gives sharded output + auto lora_output = std::make_shared()->Apply({hidden, parameters_[kParamLoraBName]})[0]; + + // Scale and add + float scaling = config_.Scaling(); + auto scaled_lora = lora_output->Mul(scaling); + base_output = base_output->Add(scaled_lora); + } + + return skip_bias_add_ ? std::vector>{base_output, + bias_ ? parameters_.at(kParamBiasName) : nullptr} + : std::vector>{base_output}; +} + +void LoRAColumnParallelLinear::MergeWeights() { + if (merged_) { + return; + } + + original_weight_ = std::make_shared(*parameters_[kParamWeightName]); + + // W' = W + (alpha/r) * B @ A + // W: [out_features_per_partition, in_features] + // B: [out_features_per_partition, rank] + // A: [rank, in_features] + auto delta = parameters_[kParamLoraBName]->Matmul(parameters_[kParamLoraAName]); + auto scaled_delta = delta->Mul(config_.Scaling()); + auto new_weight = parameters_[kParamWeightName]->Add(scaled_delta); + parameters_[kParamWeightName]->CopyFrom(new_weight); + + merged_ = true; +} + +void LoRAColumnParallelLinear::UnmergeWeights() { + if (!merged_ || !original_weight_) { + return; + } + parameters_[kParamWeightName]->CopyFrom(original_weight_); + merged_ = false; +} + +std::vector> LoRAColumnParallelLinear::LoRAParameters() const { + return {parameters_.at(kParamLoraAName), parameters_.at(kParamLoraBName)}; +} + +std::vector> LoRAColumnParallelLinear::Parameters() const { return LoRAParameters(); } + +bool LoRAColumnParallelLinear::IsMerged() const { + return merged_; +} + +int64_t LoRAColumnParallelLinear::in_features() const { + return in_features_; +} + +int64_t LoRAColumnParallelLinear::out_features() const { + return out_features_; +} + +int64_t LoRAColumnParallelLinear::rank() const { + return config_.rank; +} + +// ============================================================================ +// LoRARowParallelLinear Implementation +// ============================================================================ + +LoRARowParallelLinear::LoRARowParallelLinear(std::shared_ptr base_module, const LoRAConfig &config, + int64_t in_features, int64_t out_features) + : CloneableModule(kType), config_(config), in_features_(in_features), out_features_(out_features), bias_(false), + reduce_output_(false), input_is_parallel_(false), skip_bias_add_(false), sequence_parallel_(false) { + if (!base_module) { + throw std::invalid_argument("base_module cannot be null"); + } + + // Get device from base module + device_ = base_module->parameter(parallel::RowParallelLinear::kParamWeightName)->GetDevice(); + + // Transfer weight from base module + parameters_[kParamWeightName] = base_module->parameter(parallel::RowParallelLinear::kParamWeightName); + + // Get dimensions from weight shape [out_features, in_features_per_partition] + const auto &weight_dims = parameters_[kParamWeightName]->Dims(); + in_features_per_partition_ = weight_dims[1]; + + // Transfer bias if exists + if (base_module->has_parameter(parallel::RowParallelLinear::kParamBiasName)) { + parameters_[kParamBiasName] = base_module->parameter(parallel::RowParallelLinear::kParamBiasName); + bias_ = true; + } + + // Initialize LoRA weights + InitLoRAWeights(); + + // Freeze base weights + FreezeBaseWeights(); +} + +LoRARowParallelLinear::LoRARowParallelLinear(std::shared_ptr base_module, const LoRAConfig &config) + : CloneableModule(kType), config_(config), bias_(false), reduce_output_(false), input_is_parallel_(false), + skip_bias_add_(false), sequence_parallel_(false) { + if (!base_module) { + throw std::invalid_argument("base_module cannot be null"); + } + + // Get device from base module + device_ = base_module->parameter(parallel::RowParallelLinear::kParamWeightName)->GetDevice(); + + // Transfer weight from base module + parameters_[kParamWeightName] = base_module->parameter(parallel::RowParallelLinear::kParamWeightName); + + // Get dimensions from weight shape [out_features, in_features_per_partition] + const auto &weight_dims = parameters_[kParamWeightName]->Dims(); + out_features_ = weight_dims[0]; + in_features_per_partition_ = weight_dims[1]; + + // Calculate total in_features (assuming tensor parallelism) + int tp_size = parallel::global::GetTensorParallelSize(); + in_features_ = in_features_per_partition_ * tp_size; + + // Transfer bias if exists + if (base_module->has_parameter(parallel::RowParallelLinear::kParamBiasName)) { + parameters_[kParamBiasName] = base_module->parameter(parallel::RowParallelLinear::kParamBiasName); + bias_ = true; + } + + // Initialize LoRA weights + InitLoRAWeights(); + + // Freeze base weights + FreezeBaseWeights(); +} + +void LoRARowParallelLinear::InitLoRAWeights() { + // A matrix: [rank, in_features_per_partition] - sharded like base weight + parameters_[kParamLoraAName] + = std::make_shared(std::vector{config_.rank, in_features_per_partition_}, DataType::kFLOAT32, + device_) + ->RequiresGrad(); + + if (config_.use_kaiming_a) { + init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param); + } else { + init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f); + } + + // B matrix: [out_features, rank] - replicated across TP ranks + parameters_[kParamLoraBName] + = std::make_shared(std::vector{out_features_, config_.rank}, DataType::kFLOAT32, device_) + ->RequiresGrad(); + init::Zeros(parameters_[kParamLoraBName]); +} + +void LoRARowParallelLinear::FreezeBaseWeights() { + parameters_[kParamWeightName]->set_requires_grad(false); + if (bias_) { + parameters_[kParamBiasName]->set_requires_grad(false); + } +} + +std::vector> +LoRARowParallelLinear::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1) << "LoRARowParallelLinear takes exactly one input"; + const auto &input = input_tensors[0]; + + // Base linear computation (local matmul) + auto base_output + = std::make_shared()->Apply({input, parameters_[kParamWeightName]})[0]; + + if (!merged_) { + // LoRA computation for RowParallel: + // A is sharded [rank, in_features_per_partition] + // x_local @ A_local^T gives partial result [batch, seq, rank] + auto hidden_local = std::make_shared()->Apply({input, parameters_[kParamLoraAName]})[0]; + + // For RowParallel, we need to sum the partial results from all ranks + // This is handled by the reduce operation that follows + // B is replicated [out_features, rank] + auto lora_output = std::make_shared()->Apply({hidden_local, parameters_[kParamLoraBName]})[0]; + + // Scale and add to base output (before reduce) + float scaling = config_.Scaling(); + auto scaled_lora = lora_output->Mul(scaling); + base_output = base_output->Add(scaled_lora); + } + + // Handle bias + if (bias_ && !skip_bias_add_) { + base_output = base_output->Add(parameters_[kParamBiasName]); + } + + return skip_bias_add_ ? std::vector>{base_output, + bias_ ? parameters_.at(kParamBiasName) : nullptr} + : std::vector>{base_output}; +} + +void LoRARowParallelLinear::MergeWeights() { + if (merged_) { + return; + } + + original_weight_ = std::make_shared(*parameters_[kParamWeightName]); + + // W' = W + (alpha/r) * B @ A + // W: [out_features, in_features_per_partition] + // B: [out_features, rank] + // A: [rank, in_features_per_partition] + auto delta = parameters_[kParamLoraBName]->Matmul(parameters_[kParamLoraAName]); + auto scaled_delta = delta->Mul(config_.Scaling()); + auto new_weight = parameters_[kParamWeightName]->Add(scaled_delta); + parameters_[kParamWeightName]->CopyFrom(new_weight); + + merged_ = true; +} + +void LoRARowParallelLinear::UnmergeWeights() { + if (!merged_ || !original_weight_) { + return; + } + parameters_[kParamWeightName]->CopyFrom(original_weight_); + merged_ = false; +} + +std::vector> LoRARowParallelLinear::LoRAParameters() const { + return {parameters_.at(kParamLoraAName), parameters_.at(kParamLoraBName)}; +} + +std::vector> LoRARowParallelLinear::Parameters() const { return LoRAParameters(); } + +bool LoRARowParallelLinear::IsMerged() const { + return merged_; +} + +int64_t LoRARowParallelLinear::in_features() const { + return in_features_; +} + +int64_t LoRARowParallelLinear::out_features() const { + return out_features_; +} + +int64_t LoRARowParallelLinear::rank() const { + return config_.rank; +} + +} // namespace infini_train::nn::lora diff --git a/infini_train/src/nn/lora/lora_utils.cc b/infini_train/src/nn/lora/lora_utils.cc new file mode 100644 index 0000000..932124d --- /dev/null +++ b/infini_train/src/nn/lora/lora_utils.cc @@ -0,0 +1,340 @@ +#include "infini_train/include/nn/lora/lora_utils.h" + +#include +#include +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/device.h" +#include "infini_train/include/nn/lora/lora_linear.h" +#include "infini_train/include/nn/lora/lora_model.h" +#include "infini_train/include/nn/lora/lora_parallel_linear.h" +#include "infini_train/include/nn/modules/linear.h" +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/tensor_parallel.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::lora { + +std::shared_ptr GetLoRAModel(std::shared_ptr model, const LoRAConfig &config) { + // PEFT-style: Create LoRAModel wrapper which handles everything automatically + // Uses NamedModules() to traverse the entire model hierarchy + auto lora_model = std::make_shared(model, config); + LOG(INFO) << "GetLoRAModel: Created LoRA model with rank=" << config.rank << ", alpha=" << config.alpha; + return lora_model; +} + +void ReplaceModuleByPath(std::shared_ptr model, const std::string &path, std::shared_ptr new_module) { + // Parse the path (e.g., "transformer.h.0.attn.c_attn" -> ["transformer", "h", "0", "attn", "c_attn"]) + std::vector parts; + std::string remaining = path; + size_t pos = 0; + while ((pos = remaining.find('.')) != std::string::npos) { + parts.push_back(remaining.substr(0, pos)); + remaining = remaining.substr(pos + 1); + } + parts.push_back(remaining); + + // Navigate to parent module + std::shared_ptr current = model; + for (size_t i = 0; i < parts.size() - 1; ++i) { + current = current->mutable_module(parts[i]); + if (!current) { + LOG(ERROR) << "ReplaceModuleByPath: Failed to find path: " << path; + return; + } + } + + // Replace the module + const std::string &module_name = parts.back(); + current->replace_module(module_name, new_module); +} + +void InjectLoRALayers(std::shared_ptr model, const LoRAConfig &config) { + // Use NamedModules() to automatically traverse the entire model hierarchy + auto named_modules = model->NamedModules(); + + int lora_layers_applied = 0; + + for (const auto &[name, module] : named_modules) { + if (name.empty()) continue; // skip root module + + // Check if this module should have LoRA applied + if (!config.ShouldApplyLoRA(name)) continue; + + // Get module type and wrap if it's Linear/ColumnParallelLinear/RowParallelLinear + auto type = module->type(); + + if (type == Linear::kType) { + auto lora_module = std::make_shared(module, config); + ReplaceModuleByPath(model, name, lora_module); + lora_layers_applied++; + } else if (type == parallel::ColumnParallelLinear::kType) { + auto lora_module = std::make_shared(module, config); + ReplaceModuleByPath(model, name, lora_module); + lora_layers_applied++; + } else if (type == parallel::RowParallelLinear::kType) { + auto lora_module = std::make_shared(module, config); + ReplaceModuleByPath(model, name, lora_module); + lora_layers_applied++; + } + } + + LOG(INFO) << "InjectLoRALayers: Applied LoRA to " << lora_layers_applied << " layers " + << "(rank=" << config.rank << ", alpha=" << config.alpha << ")"; +} + +void FreezeBaseModel(std::shared_ptr model) { + model->Apply([](Module *m) { + for (auto &[name, param] : m->StateDict()) { + // Skip LoRA parameters + if (name.find("lora_A") != std::string::npos || name.find("lora_B") != std::string::npos) { + continue; + } + param->set_requires_grad(false); + } + }); +} + +void UnfreezeModel(std::shared_ptr model) { + model->Apply([](Module *m) { + for (auto &[name, param] : m->StateDict()) { + param->set_requires_grad(true); + } + }); +} + +std::vector> GetLoRAParameters(const std::shared_ptr &model) { + std::vector> lora_params; + + model->Apply([&lora_params](Module *m) { + // Check if this is a LoRA module + if (m->type() == LoRALinear::kType) { + auto lora_module = dynamic_cast(m); + if (lora_module) { + auto params = lora_module->LoRAParameters(); + lora_params.insert(lora_params.end(), params.begin(), params.end()); + } + } else if (m->type() == LoRAColumnParallelLinear::kType) { + auto lora_module = dynamic_cast(m); + if (lora_module) { + auto params = lora_module->LoRAParameters(); + lora_params.insert(lora_params.end(), params.begin(), params.end()); + } + } else if (m->type() == LoRARowParallelLinear::kType) { + auto lora_module = dynamic_cast(m); + if (lora_module) { + auto params = lora_module->LoRAParameters(); + lora_params.insert(lora_params.end(), params.begin(), params.end()); + } + } + }); + + return lora_params; +} + +std::vector> GetBaseParameters(const std::shared_ptr &model) { + std::vector> base_params; + + for (auto &[name, param] : model->StateDict()) { + // Skip LoRA parameters + if (name.find("lora_A") != std::string::npos || name.find("lora_B") != std::string::npos) { + continue; + } + base_params.push_back(param); + } + + return base_params; +} + +void MergeLoRAWeights(std::shared_ptr model) { + model->Apply([](Module *m) { + if (m->type() == LoRALinear::kType) { + dynamic_cast(m)->MergeWeights(); + } else if (m->type() == LoRAColumnParallelLinear::kType) { + dynamic_cast(m)->MergeWeights(); + } else if (m->type() == LoRARowParallelLinear::kType) { + dynamic_cast(m)->MergeWeights(); + } + }); +} + +void UnmergeLoRAWeights(std::shared_ptr model) { + model->Apply([](Module *m) { + if (m->type() == LoRALinear::kType) { + dynamic_cast(m)->UnmergeWeights(); + } else if (m->type() == LoRAColumnParallelLinear::kType) { + dynamic_cast(m)->UnmergeWeights(); + } else if (m->type() == LoRARowParallelLinear::kType) { + dynamic_cast(m)->UnmergeWeights(); + } + }); +} + +std::unordered_map> LoRAStateDict(const std::shared_ptr &model) { + std::unordered_map> lora_state_dict; + + for (auto &[name, param] : model->StateDict()) { + // Only include LoRA parameters + if (name.find("lora_A") != std::string::npos || name.find("lora_B") != std::string::npos) { + lora_state_dict[name] = param; + } + } + + return lora_state_dict; +} + +void LoadLoRAStateDict(std::shared_ptr model, + const std::unordered_map> &state_dict) { + auto model_state_dict = model->StateDict(); + + for (auto &[name, param] : state_dict) { + if (model_state_dict.find(name) != model_state_dict.end()) { + model_state_dict[name]->CopyFrom(param); + } else { + LOG(WARNING) << "LoRA parameter not found in model: " << name; + } + } +} + +void SaveLoRAWeights(const std::shared_ptr &model, const std::string &filepath) { + auto lora_state_dict = LoRAStateDict(model); + + std::ofstream file(filepath, std::ios::binary); + CHECK(file.is_open()) << "Failed to open file for writing: " << filepath; + + // Write magic number + uint32_t magic = 0x4C4F5241; // "LORA" + file.write(reinterpret_cast(&magic), sizeof(magic)); + + // Write version + uint32_t version = 1; + file.write(reinterpret_cast(&version), sizeof(version)); + + // Write number of tensors + uint32_t num_tensors = static_cast(lora_state_dict.size()); + file.write(reinterpret_cast(&num_tensors), sizeof(num_tensors)); + + // Write each tensor + for (const auto &[name, tensor] : lora_state_dict) { + // Write name length and name + uint32_t name_len = static_cast(name.length()); + file.write(reinterpret_cast(&name_len), sizeof(name_len)); + file.write(name.c_str(), name_len); + + // Write tensor dimensions + const auto &dims = tensor->Dims(); + uint32_t num_dims = static_cast(dims.size()); + file.write(reinterpret_cast(&num_dims), sizeof(num_dims)); + for (auto dim : dims) { + int64_t d = dim; + file.write(reinterpret_cast(&d), sizeof(d)); + } + + // Write tensor data (copy to CPU first if needed) + int64_t num_elements = tensor->NumElements(); + Tensor cpu_tensor = tensor->To(Device(Device::DeviceType::kCPU, 0)); + file.write(reinterpret_cast(cpu_tensor.DataPtr()), num_elements * sizeof(float)); + } + + file.close(); + LOG(INFO) << "Saved LoRA weights to " << filepath << " (" << num_tensors << " tensors)"; +} + +void LoadLoRAWeights(std::shared_ptr model, const std::string &filepath) { + std::ifstream file(filepath, std::ios::binary); + CHECK(file.is_open()) << "Failed to open file for reading: " << filepath; + + // Read and verify magic number + uint32_t magic; + file.read(reinterpret_cast(&magic), sizeof(magic)); + CHECK_EQ(magic, 0x4C4F5241) << "Invalid LoRA file format"; + + // Read version + uint32_t version; + file.read(reinterpret_cast(&version), sizeof(version)); + CHECK_EQ(version, 1) << "Unsupported LoRA file version: " << version; + + // Read number of tensors + uint32_t num_tensors; + file.read(reinterpret_cast(&num_tensors), sizeof(num_tensors)); + + auto model_state_dict = model->StateDict(); + + // Read each tensor + for (uint32_t i = 0; i < num_tensors; ++i) { + // Read name + uint32_t name_len; + file.read(reinterpret_cast(&name_len), sizeof(name_len)); + std::string name(name_len, '\0'); + file.read(&name[0], name_len); + + // Read dimensions + uint32_t num_dims; + file.read(reinterpret_cast(&num_dims), sizeof(num_dims)); + std::vector dims(num_dims); + for (uint32_t j = 0; j < num_dims; ++j) { + file.read(reinterpret_cast(&dims[j]), sizeof(int64_t)); + } + + // Calculate number of elements + int64_t num_elements = 1; + for (auto dim : dims) { + num_elements *= dim; + } + + // Read tensor data into a temporary CPU tensor + auto cpu_tensor = std::make_shared(dims, DataType::kFLOAT32, + Device(Device::DeviceType::kCPU, 0)); + file.read(reinterpret_cast(cpu_tensor->DataPtr()), num_elements * sizeof(float)); + + // Load into model + auto it = model_state_dict.find(name); + if (it != model_state_dict.end()) { + it->second->CopyFrom(cpu_tensor); + } else { + LOG(WARNING) << "LoRA parameter not found in model: " << name; + } + } + + file.close(); + LOG(INFO) << "Loaded LoRA weights from " << filepath << " (" << num_tensors << " tensors)"; +} + +int64_t CountTrainableParameters(const std::shared_ptr &model) { + int64_t count = 0; + for (auto ¶m : model->Parameters()) { + if (param->requires_grad()) { + count += param->NumElements(); + } + } + return count; +} + +int64_t CountTotalParameters(const std::shared_ptr &model) { + int64_t count = 0; + for (auto &[name, param] : model->StateDict()) { + count += param->NumElements(); + } + return count; +} + +void PrintLoRASummary(const std::shared_ptr &model) { + int64_t trainable = CountTrainableParameters(model); + int64_t total = CountTotalParameters(model); + int64_t frozen = total - trainable; + + double trainable_pct = 100.0 * trainable / total; + + std::cout << "========== LoRA Model Summary ==========" << std::endl; + std::cout << "Total parameters: " << total << std::endl; + std::cout << "Trainable parameters: " << trainable << " (" << trainable_pct << "%)" << std::endl; + std::cout << "Frozen parameters: " << frozen << std::endl; + std::cout << "=========================================" << std::endl; +} + +} // namespace infini_train::nn::lora diff --git a/infini_train/src/nn/modules/module.cc b/infini_train/src/nn/modules/module.cc index 1b764ed..4678743 100644 --- a/infini_train/src/nn/modules/module.cc +++ b/infini_train/src/nn/modules/module.cc @@ -129,6 +129,11 @@ Module::NamedModules(std::unordered_set *memory, const std::string &pr std::shared_ptr Module::mutable_module(const std::string &name) { return modules_.at(name); } +void Module::replace_module(const std::string &name, std::shared_ptr new_module) { + CHECK(modules_.find(name) != modules_.end()) << "Module not found: " << name; + modules_[name] = new_module; +} + const Module &Module::module(const std::string &name) const { CHECK(modules_.find(name) != modules_.end()); return *modules_.at(name).get(); diff --git a/test/lora/test_lora.cc b/test/lora/test_lora.cc new file mode 100644 index 0000000..7cc8ee6 --- /dev/null +++ b/test/lora/test_lora.cc @@ -0,0 +1,675 @@ +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/nn/lora/lora_config.h" +#include "infini_train/include/nn/lora/lora_linear.h" +#include "infini_train/include/nn/lora/lora_model.h" +#include "infini_train/include/nn/lora/lora_utils.h" +#include "infini_train/include/nn/modules/linear.h" +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/tensor.h" + +using namespace infini_train; +using namespace infini_train::nn::lora; + +// ============================================================================ +// Test 1: LoRAConfig +// ============================================================================ +void test_lora_config() { + std::cout << "\n=== Test 1: LoRAConfig ===" << std::endl; + + LoRAConfig config; + config.rank = 8; + config.alpha = 16.0f; + + // Test scaling calculation + float expected_scaling = 16.0f / 8.0f; + CHECK_EQ(config.Scaling(), expected_scaling) << "Scaling calculation failed"; + std::cout << "Scaling: " << config.Scaling() << " (expected: " << expected_scaling << ")" << std::endl; + + // Test ShouldApplyLoRA + CHECK(config.ShouldApplyLoRA("c_attn")) << "Should match c_attn"; + CHECK(config.ShouldApplyLoRA("transformer.h.0.attn.c_attn")) << "Should match nested c_attn"; + CHECK(config.ShouldApplyLoRA("c_proj")) << "Should match c_proj"; + CHECK(!config.ShouldApplyLoRA("c_fc")) << "Should not match c_fc (not in default targets)"; + CHECK(!config.ShouldApplyLoRA("random_layer")) << "Should not match random_layer"; + + std::cout << "LoRAConfig tests passed!" << std::endl; +} + +// ============================================================================ +// Test 2: LoRALinear Initialization +// ============================================================================ +void test_lora_linear_init() { + std::cout << "\n=== Test 2: LoRALinear Initialization ===" << std::endl; + + LoRAConfig config; + config.rank = 4; + config.alpha = 8.0f; + + int64_t in_features = 64; + int64_t out_features = 128; + + auto lora_linear = std::make_shared(in_features, out_features, config, /*bias=*/true); + + // Check parameter shapes + auto weight = lora_linear->parameter(LoRALinear::kParamWeightName); + auto bias = lora_linear->parameter(LoRALinear::kParamBiasName); + auto lora_A = lora_linear->parameter(LoRALinear::kParamLoraAName); + auto lora_B = lora_linear->parameter(LoRALinear::kParamLoraBName); + + CHECK_EQ(weight->Dims().size(), 2); + CHECK_EQ(weight->Dims()[0], out_features); + CHECK_EQ(weight->Dims()[1], in_features); + std::cout << "Weight shape: [" << weight->Dims()[0] << ", " << weight->Dims()[1] << "]" << std::endl; + + CHECK_EQ(bias->Dims().size(), 1); + CHECK_EQ(bias->Dims()[0], out_features); + std::cout << "Bias shape: [" << bias->Dims()[0] << "]" << std::endl; + + CHECK_EQ(lora_A->Dims().size(), 2); + CHECK_EQ(lora_A->Dims()[0], config.rank); + CHECK_EQ(lora_A->Dims()[1], in_features); + std::cout << "LoRA A shape: [" << lora_A->Dims()[0] << ", " << lora_A->Dims()[1] << "]" << std::endl; + + CHECK_EQ(lora_B->Dims().size(), 2); + CHECK_EQ(lora_B->Dims()[0], out_features); + CHECK_EQ(lora_B->Dims()[1], config.rank); + std::cout << "LoRA B shape: [" << lora_B->Dims()[0] << ", " << lora_B->Dims()[1] << "]" << std::endl; + + // Check requires_grad + CHECK(!weight->requires_grad()) << "Base weight should be frozen"; + CHECK(!bias->requires_grad()) << "Base bias should be frozen"; + CHECK(lora_A->requires_grad()) << "LoRA A should be trainable"; + CHECK(lora_B->requires_grad()) << "LoRA B should be trainable"; + std::cout << "requires_grad check passed!" << std::endl; + + // Check Parameters() returns only LoRA params + auto params = lora_linear->Parameters(); + CHECK_EQ(params.size(), 2) << "Parameters() should return only LoRA params"; + std::cout << "Parameters() returns " << params.size() << " tensors (LoRA A and B)" << std::endl; + + std::cout << "LoRALinear initialization tests passed!" << std::endl; +} + +// ============================================================================ +// Test 3: LoRALinear Forward Pass +// ============================================================================ +void test_lora_linear_forward() { + std::cout << "\n=== Test 3: LoRALinear Forward Pass ===" << std::endl; + + LoRAConfig config; + config.rank = 4; + config.alpha = 8.0f; + + int64_t in_features = 64; + int64_t out_features = 128; + int64_t batch_size = 2; + int64_t seq_len = 10; + + auto lora_linear = std::make_shared(in_features, out_features, config, /*bias=*/true); + + // Create input tensor + auto input = std::make_shared(std::vector{batch_size, seq_len, in_features}, DataType::kFLOAT32); + + // Forward pass + auto output = (*lora_linear)({input})[0]; + + // Check output shape + CHECK_EQ(output->Dims().size(), 3); + CHECK_EQ(output->Dims()[0], batch_size); + CHECK_EQ(output->Dims()[1], seq_len); + CHECK_EQ(output->Dims()[2], out_features); + std::cout << "Output shape: [" << output->Dims()[0] << ", " << output->Dims()[1] << ", " << output->Dims()[2] << "]" + << std::endl; + + std::cout << "LoRALinear forward pass tests passed!" << std::endl; +} + +// ============================================================================ +// Test 4: LoRALinear Weight Merging +// ============================================================================ +void test_lora_linear_merge() { + std::cout << "\n=== Test 4: LoRALinear Weight Merging ===" << std::endl; + + LoRAConfig config; + config.rank = 4; + config.alpha = 8.0f; + + int64_t in_features = 32; + int64_t out_features = 64; + + auto lora_linear = std::make_shared(in_features, out_features, config, /*bias=*/false); + + // Print weight sum before merge + auto weight_before = lora_linear->parameter(LoRALinear::kParamWeightName); + auto lora_A = lora_linear->parameter(LoRALinear::kParamLoraAName); + auto lora_B = lora_linear->parameter(LoRALinear::kParamLoraBName); + + float weight_before_sum = weight_before->EigenMatrix().sum(); + float lora_A_sum = lora_A->EigenMatrix().sum(); + float lora_B_sum = lora_B->EigenMatrix().sum(); + + std::cout << "\n--- Before Merge ---" << std::endl; + std::cout << "Base weight sum: " << weight_before_sum << std::endl; + std::cout << "LoRA A sum: " << lora_A_sum << std::endl; + std::cout << "LoRA B sum: " << lora_B_sum << std::endl; + std::cout << "Scaling (alpha/r): " << config.Scaling() << std::endl; + + // Create input + auto input = std::make_shared(std::vector{2, 5, in_features}, DataType::kFLOAT32); + input->EigenMatrix().setRandom(); + + // Get output before merge + auto output_before = (*lora_linear)({input})[0]; + float output_before_sum = output_before->EigenMatrix().sum(); + std::cout << "Output sum before merge: " << output_before_sum << std::endl; + + // Merge weights + CHECK(!lora_linear->IsMerged()) << "Should not be merged initially"; + lora_linear->MergeWeights(); + CHECK(lora_linear->IsMerged()) << "Should be merged after MergeWeights()"; + std::cout << "\nWeights merged successfully" << std::endl; + + // Print weight sum after merge + auto weight_after = lora_linear->parameter(LoRALinear::kParamWeightName); + float weight_after_sum = weight_after->EigenMatrix().sum(); + std::cout << "\n--- After Merge ---" << std::endl; + std::cout << "Base weight sum after merge: " << weight_after_sum << std::endl; + std::cout << "Weight change (should be ~LoRA contribution): " << (weight_after_sum - weight_before_sum) << std::endl; + + // Get output after merge + auto output_merged = (*lora_linear)({input})[0]; + float output_merged_sum = output_merged->EigenMatrix().sum(); + std::cout << "Output sum after merge: " << output_merged_sum << std::endl; + + // Verify: output_after should equal output_before (numerically) + std::cout << "\nVerification: output_before == output_after? " << std::endl; + std::cout << " Before: " << output_before_sum << std::endl; + std::cout << " After: " << output_merged_sum << std::endl; + std::cout << " Diff: " << std::abs(output_before_sum - output_merged_sum) << std::endl; + CHECK(std::abs(output_before_sum - output_merged_sum) < 1e-3) << "Outputs should be numerically identical!"; + + // Shape comparison (always same) + std::cout << "\nOutput shape: [" << output_before->Dims()[0] << ", " << output_before->Dims()[1] << ", " + << output_before->Dims()[2] << "] (unchanged)" << std::endl; + + // Unmerge weights + lora_linear->UnmergeWeights(); + CHECK(!lora_linear->IsMerged()) << "Should not be merged after UnmergeWeights()"; + + // Print weight sum after unmerge + auto weight_unmerged = lora_linear->parameter(LoRALinear::kParamWeightName); + float weight_unmerged_sum = weight_unmerged->EigenMatrix().sum(); + std::cout << "\n--- After Unmerge ---" << std::endl; + std::cout << "Base weight sum after unmerge: " << weight_unmerged_sum << std::endl; + + // Verify: weight should be restored to original value + std::cout << "\nVerification: weight restored after unmerge? " << std::endl; + std::cout << " Original: " << weight_before_sum << std::endl; + std::cout << " Unmerged: " << weight_unmerged_sum << std::endl; + std::cout << " Diff: " << std::abs(weight_before_sum - weight_unmerged_sum) << std::endl; + CHECK(std::abs(weight_before_sum - weight_unmerged_sum) < 1e-5) << "Weight should be restored!"; + + // Get output after unmerge + auto output_unmerged = (*lora_linear)({input})[0]; + float output_unmerged_sum = output_unmerged->EigenMatrix().sum(); + std::cout << "Output sum after unmerge: " << output_unmerged_sum << std::endl; + + // Shape comparison: merge doesn't change shape, only weights + CHECK(output_before->Dims() == output_merged->Dims()) << "Shape should be identical after merge"; + CHECK(output_merged->Dims() == output_unmerged->Dims()) << "Shape should be identical after unmerge"; + + std::cout << "\nLoRALinear weight merging tests passed!" << std::endl; +} + +// ============================================================================ +// Test 5: LoRA Utility Functions +// ============================================================================ +void test_lora_utils() { + std::cout << "\n=== Test 5: LoRA Utility Functions ===" << std::endl; + + LoRAConfig config; + config.rank = 4; + config.alpha = 8.0f; + + auto lora_linear = std::make_shared(32, 64, config, /*bias=*/true); + + // Test GetLoRAParameters + auto lora_params = GetLoRAParameters(lora_linear); + CHECK_EQ(lora_params.size(), 2) << "Should have 2 LoRA parameters"; + std::cout << "GetLoRAParameters returned " << lora_params.size() << " parameters" << std::endl; + + // Test CountTrainableParameters + int64_t trainable = CountTrainableParameters(lora_linear); + int64_t expected_trainable = config.rank * 32 + 64 * config.rank; // A: [4, 32], B: [64, 4] + CHECK_EQ(trainable, expected_trainable) << "Trainable parameter count mismatch"; + std::cout << "Trainable parameters: " << trainable << " (expected: " << expected_trainable << ")" << std::endl; + + // Test CountTotalParameters + int64_t total = CountTotalParameters(lora_linear); + int64_t expected_total = 64 * 32 + 64 + config.rank * 32 + 64 * config.rank; // weight + bias + A + B + CHECK_EQ(total, expected_total) << "Total parameter count mismatch"; + std::cout << "Total parameters: " << total << " (expected: " << expected_total << ")" << std::endl; + + // Test PrintLoRASummary + std::cout << "\nLoRA Summary:" << std::endl; + PrintLoRASummary(lora_linear); + + std::cout << "LoRA utility function tests passed!" << std::endl; +} + +// ============================================================================ +// Test 6: LoRALinear from existing Linear +// ============================================================================ +void test_lora_from_linear() { + std::cout << "\n=== Test 6: LoRALinear from existing Linear ===" << std::endl; + + // Create a standard Linear layer + auto linear = std::make_shared(64, 128, /*bias=*/true); + + // Wrap it with LoRA + LoRAConfig config; + config.rank = 8; + config.alpha = 16.0f; + + auto lora_linear = std::make_shared(linear, config); + + // Check dimensions + CHECK_EQ(lora_linear->in_features(), 64); + CHECK_EQ(lora_linear->out_features(), 128); + CHECK_EQ(lora_linear->rank(), 8); + std::cout << "LoRALinear created from Linear: in=" << lora_linear->in_features() + << ", out=" << lora_linear->out_features() << ", rank=" << lora_linear->rank() << std::endl; + + // Test forward pass + auto input = std::make_shared(std::vector{2, 10, 64}, DataType::kFLOAT32); + auto output = (*lora_linear)({input})[0]; + + CHECK_EQ(output->Dims()[0], 2); + CHECK_EQ(output->Dims()[1], 10); + CHECK_EQ(output->Dims()[2], 128); + std::cout << "Forward pass successful, output shape: [" << output->Dims()[0] << ", " << output->Dims()[1] << ", " + << output->Dims()[2] << "]" << std::endl; + + std::cout << "LoRALinear from existing Linear tests passed!" << std::endl; +} + +// ============================================================================ +// Test 7: LoRAModel Wrapper (simplified test for wrapper interface) +// ============================================================================ +void test_lora_model_wrapper() { + std::cout << "\n=== Test 7: LoRAModel Wrapper (Simplified) ===" << std::endl; + + // Create LoRA config + LoRAConfig lora_config; + lora_config.rank = 8; + lora_config.alpha = 16.0f; + + // Create base Linear module (simple test without InjectLoRALayers) + auto base_linear = std::make_shared(64, 128, /*bias=*/true); + + // Create a minimal wrapper test by manually testing what LoRAModel does + // Apply LoRA directly to the Linear layer + auto lora_linear = std::make_shared(base_linear, lora_config); + + // Replace the base_linear in its container + // Note: In a real use case, you would use InjectLoRALayers on a transformer model + + // Test GetLoRAParameters on the LoRA Linear + auto lora_params = GetLoRAParameters(lora_linear); + CHECK_GT(lora_params.size(), 0) << "Should have trainable parameters"; + std::cout << "LoRA parameters extracted: " << lora_params.size() << std::endl; + + // Test CountTrainableParameters + int64_t trainable = CountTrainableParameters(lora_linear); + CHECK_EQ(trainable, lora_config.rank * 64 + 128 * lora_config.rank); + std::cout << "Trainable parameters: " << trainable << std::endl; + + // Test PrintSummary + std::cout << "\nLoRA Summary for Linear wrapper:" << std::endl; + PrintLoRASummary(lora_linear); + + // Test Save/Load LoRA on the LoRA Linear + const std::string test_path = "/tmp/test_lora_linear.bin"; + SaveLoRAWeights(lora_linear, test_path); + std::cout << "SaveLoRAWeights completed" << std::endl; + + LoadLoRAWeights(lora_linear, test_path); + std::cout << "LoadLoRAWeights completed" << std::endl; + + // Test Merge/Unmerge on LoRA Linear + CHECK(!lora_linear->IsMerged()) << "Should not be merged initially"; + lora_linear->MergeWeights(); + CHECK(lora_linear->IsMerged()) << "Should be merged after MergeWeights()"; + std::cout << "MergeWeights completed" << std::endl; + + lora_linear->UnmergeWeights(); + CHECK(!lora_linear->IsMerged()) << "Should be unmerged after UnmergeWeights()"; + std::cout << "UnmergeWeights completed" << std::endl; + + std::cout << "LoRAModel wrapper tests passed!" << std::endl; +} + +// ============================================================================ +// Test 8: SetTargetModules parsing +// ============================================================================ +void test_set_target_modules() { + std::cout << "\n=== Test 8: SetTargetModules Parsing ===" << std::endl; + + LoRAConfig config; + + // Test single target + config.SetTargetModules("c_attn"); + CHECK_EQ(config.target_modules.size(), 1); + CHECK(config.target_modules.count("c_attn")); + std::cout << "Single target: OK" << std::endl; + + // Test multiple targets + config.SetTargetModules("c_attn,c_proj,c_fc"); + CHECK_EQ(config.target_modules.size(), 3); + CHECK(config.target_modules.count("c_attn")); + CHECK(config.target_modules.count("c_proj")); + CHECK(config.target_modules.count("c_fc")); + std::cout << "Multiple targets: OK" << std::endl; + + // Test with spaces + config.SetTargetModules("c_attn, c_proj , c_fc"); + CHECK_EQ(config.target_modules.size(), 3); + std::cout << "Targets with spaces: OK" << std::endl; + + // Test empty/whitespace + config.SetTargetModules("c_attn,,c_proj"); + CHECK_EQ(config.target_modules.size(), 2); + std::cout << "Empty entries ignored: OK" << std::endl; + + std::cout << "SetTargetModules tests passed!" << std::endl; +} + +// ============================================================================ +// Test 9: ShouldApplyLoRA edge cases (attn.c_proj vs mlp.c_proj) +// ============================================================================ +void test_should_apply_lora_edge_cases() { + std::cout << "\n=== Test 9: ShouldApplyLoRA Edge Cases ===" << std::endl; + + // Test: Only attn.c_proj in target_modules + { + LoRAConfig config; + config.SetTargetModules("c_attn,attn.c_proj"); + + // Should match attention paths + CHECK(config.ShouldApplyLoRA("attn.c_proj")); + CHECK(config.ShouldApplyLoRA("transformer.h.0.attn.c_proj")); + CHECK(config.ShouldApplyLoRA("transformer.h.1.attn.c_proj")); + + // Should NOT match mlp paths + CHECK(!config.ShouldApplyLoRA("mlp.c_proj")); + CHECK(!config.ShouldApplyLoRA("transformer.h.0.mlp.c_proj")); + std::cout << "attn.c_proj only: OK" << std::endl; + } + + // Test: Only mlp.c_proj in target_modules + { + LoRAConfig config; + config.SetTargetModules("c_attn,mlp.c_proj"); + + // Should NOT match attention paths + CHECK(!config.ShouldApplyLoRA("attn.c_proj")); + CHECK(!config.ShouldApplyLoRA("transformer.h.0.attn.c_proj")); + + // Should match mlp paths + CHECK(config.ShouldApplyLoRA("mlp.c_proj")); + CHECK(config.ShouldApplyLoRA("transformer.h.0.mlp.c_proj")); + std::cout << "mlp.c_proj only: OK" << std::endl; + } + + // Test: Generic c_proj in target_modules (matches both) + { + LoRAConfig config; + config.SetTargetModules("c_attn,c_proj"); + + // Should match both attention and mlp + CHECK(config.ShouldApplyLoRA("transformer.h.0.attn.c_proj")); + CHECK(config.ShouldApplyLoRA("transformer.h.0.mlp.c_proj")); + std::cout << "Generic c_proj (matches both): OK" << std::endl; + } + + // Test: All targets + { + LoRAConfig config; + config.SetTargetModules("c_attn,attn.c_proj,c_fc,c_fc2,mlp.c_proj"); + + CHECK(config.ShouldApplyLoRA("transformer.h.0.attn.c_attn")); + CHECK(config.ShouldApplyLoRA("transformer.h.0.attn.c_proj")); + CHECK(config.ShouldApplyLoRA("transformer.h.0.mlp.c_fc")); + CHECK(config.ShouldApplyLoRA("transformer.h.0.mlp.c_fc2")); + CHECK(config.ShouldApplyLoRA("transformer.h.0.mlp.c_proj")); + std::cout << "All targets: OK" << std::endl; + } + + std::cout << "ShouldApplyLoRA edge cases tests passed!" << std::endl; +} + +// ============================================================================ +// Test 10: ReplaceModuleByPath +// ============================================================================ +void test_replace_module_by_path() { + std::cout << "\n=== Test 10: ReplaceModuleByPath ===" << std::endl; + + // Create a container module with nested structure + auto container = std::make_shared("Container"); + auto sub_module = std::make_shared(64, 128, /*bias=*/true); + container->add_module("sub", sub_module); + + // Create a replacement module + auto replacement = std::make_shared(64, 128, /*bias=*/false); + + // Replace by path + ReplaceModuleByPath(container, "sub", replacement); + + // Verify replacement + auto replaced = container->mutable_module("sub"); + CHECK(replaced == replacement); + CHECK(replaced->type() == nn::Linear::kType); + std::cout << "Simple path replacement: OK" << std::endl; + + // Test with nested container (simulated) + auto nested_container = std::make_shared("NestedContainer"); + nested_container->add_module("level1", container); + + auto another_replacement = std::make_shared(64, 128, /*bias=*/true); + ReplaceModuleByPath(nested_container, "level1.sub", another_replacement); + + auto final_replaced = nested_container->mutable_module("level1")->mutable_module("sub"); + CHECK(final_replaced == another_replacement); + std::cout << "Nested path replacement: OK" << std::endl; + + // Test invalid path + auto test_container = std::make_shared("Test"); + test_container->add_module("exists", std::make_shared(64, 128)); + ReplaceModuleByPath(test_container, "nonexistent.path", replacement); + // Should not crash, just log error + std::cout << "Invalid path handling: OK" << std::endl; + + std::cout << "ReplaceModuleByPath tests passed!" << std::endl; +} + +// ============================================================================ +// Test 11: FreezeBaseModel / UnfreezeModel +// ============================================================================ +void test_freeze_unfreeze() { + std::cout << "\n=== Test 11: FreezeBaseModel / UnfreezeModel ===" << std::endl; + + // Create a simple model with multiple Linear layers + auto model = std::make_shared("TestModel"); + auto linear1 = std::make_shared(64, 128, /*bias=*/true); + auto linear2 = std::make_shared(128, 256, /*bias=*/true); + model->add_module("layer1", linear1); + model->add_module("layer2", linear2); + + // Initially, all parameters should be trainable + int64_t initial_trainable = CountTrainableParameters(model); + CHECK_EQ(initial_trainable, (64 * 128 + 128) + (128 * 256 + 256)); + std::cout << "Initial trainable params: " << initial_trainable << std::endl; + + // Freeze base model (but we need to add LoRA params first) + LoRAConfig lora_config; + lora_config.rank = 4; + lora_config.alpha = 8.0f; + + auto lora1 = std::make_shared(linear1, lora_config); + auto lora2 = std::make_shared(linear2, lora_config); + model->add_module("lora1", lora1); + model->add_module("lora2", lora2); + + // Now freeze base model + FreezeBaseModel(model); + + // Count trainable parameters (should be only LoRA params) + int64_t after_freeze = CountTrainableParameters(model); + int64_t expected_lora = lora_config.rank * 64 + 128 * lora_config.rank + // lora1 A + B + lora_config.rank * 128 + 256 * lora_config.rank; // lora2 A + B + CHECK_EQ(after_freeze, expected_lora); + std::cout << "After freeze trainable: " << after_freeze << " (expected: " << expected_lora << ")" << std::endl; + + // Unfreeze all + UnfreezeModel(model); + int64_t after_unfreeze = CountTrainableParameters(model); + // Should be back to all params trainable + CHECK_GT(after_unfreeze, 0); + std::cout << "After unfreeze trainable: " << after_unfreeze << std::endl; + + std::cout << "FreezeBaseModel / UnfreezeModel tests passed!" << std::endl; +} + +// ============================================================================ +// Test 12: LoRAStateDict +// ============================================================================ +void test_lora_state_dict() { + std::cout << "\n=== Test 12: LoRAStateDict ===" << std::endl; + + // Create a model with LoRA + auto model = std::make_shared("TestModel"); + auto linear = std::make_shared(64, 128, /*bias=*/true); + model->add_module("linear", linear); + + LoRAConfig lora_config; + lora_config.rank = 4; + lora_config.alpha = 8.0f; + + auto lora_linear = std::make_shared(linear, lora_config); + model->replace_module("linear", lora_linear); + + // Get LoRA state dict + auto lora_dict = LoRAStateDict(model); + + // Should contain only LoRA parameters (lora_A and lora_B) + CHECK_EQ(lora_dict.size(), 2); + std::cout << "LoRA state dict size: " << lora_dict.size() << std::endl; + + bool has_lora_a = false, has_lora_b = false; + for (const auto &[name, param] : lora_dict) { + if (name.find("lora_A") != std::string::npos) has_lora_a = true; + if (name.find("lora_B") != std::string::npos) has_lora_b = true; + } + CHECK(has_lora_a && has_lora_b) << "Should have both lora_A and lora_B"; + std::cout << "Contains lora_A: " << (has_lora_a ? "YES" : "NO") << std::endl; + std::cout << "Contains lora_B: " << (has_lora_b ? "YES" : "NO") << std::endl; + + std::cout << "LoRAStateDict tests passed!" << std::endl; +} + +// ============================================================================ +// Test 13: GetLoRAModel simplified API +// ============================================================================ +void test_get_lora_model() { + std::cout << "\n=== Test 13: GetLoRAModel Simplified API ===" << std::endl; + + // Create a simple model + auto model = std::make_shared("SimpleModel"); + auto linear1 = std::make_shared(64, 128, /*bias=*/true); + auto linear2 = std::make_shared(128, 256, /*bias=*/true); + model->add_module("linear1", linear1); + model->add_module("linear2", linear2); + + // Configure LoRA + LoRAConfig config; + config.rank = 4; + config.alpha = 8.0f; + config.SetTargetModules("linear1"); // Only apply to linear1 + + // Use GetLoRAModel (simplified API) + auto lora_model = GetLoRAModel(model, config); + + CHECK(lora_model != nullptr); + std::cout << "GetLoRAModel returned valid pointer" << std::endl; + + // Test that LoRA was applied + auto lora_params = lora_model->TrainableParameters(); + // linear1 has LoRA: rank * 64 + 128 * rank = 4*64 + 128*4 = 256 + 512 = 768 + CHECK_EQ(lora_params.size(), 768); + std::cout << "Trainable parameters from LoRA model: " << lora_params.size() << std::endl; + + // Test PrintSummary + std::cout << "\nLoRA Model Summary:" << std::endl; + lora_model->PrintSummary(); + + // Test base_model access + auto base = lora_model->base_model(); + CHECK(base != nullptr); + std::cout << "base_model() returns valid pointer" << std::endl; + + // Test config access + auto cfg = lora_model->config(); + CHECK_EQ(cfg.rank, 4); + CHECK_EQ(cfg.alpha, 8.0f); + std::cout << "config() returns correct values" << std::endl; + + // Test Merge/Unmerge + CHECK(!lora_model->IsMerged()); + lora_model->Merge(); + CHECK(lora_model->IsMerged()); + std::cout << "Merge/Unmerge: OK" << std::endl; + + lora_model->Unmerge(); + CHECK(!lora_model->IsMerged()); + + std::cout << "GetLoRAModel simplified API tests passed!" << std::endl; +} + +int main(int argc, char **argv) { + google::InitGoogleLogging(argv[0]); + FLAGS_logtostderr = 1; + + // Initialize parallel settings (required for some tensor operations) + // Parameters: nthread_per_process, tensor_parallel_size, sequence_parallel_enabled, + // pipeline_parallel_size, virtual_pipeline_parallel_size + nn::parallel::global::InitAllEnv(1, 1, false, 1, 1); + + std::cout << "========================================" << std::endl; + std::cout << " LoRA Module Unit Tests " << std::endl; + std::cout << "========================================" << std::endl; + + test_lora_config(); + test_lora_linear_init(); + test_lora_linear_forward(); + test_lora_linear_merge(); + test_lora_utils(); + test_lora_from_linear(); + test_lora_model_wrapper(); + test_set_target_modules(); + test_should_apply_lora_edge_cases(); + test_replace_module_by_path(); + test_freeze_unfreeze(); + test_lora_state_dict(); + test_get_lora_model(); + + std::cout << "\n========================================" << std::endl; + std::cout << " All LoRA Tests Passed! " << std::endl; + std::cout << "========================================" << std::endl; + + return 0; +}