From b97bf824ea76238e5372e7ee3d2af934630911bc Mon Sep 17 00:00:00 2001
From: chen
Date: Wed, 4 Feb 2026 07:00:56 +0000
Subject: [PATCH] refactor: change NamedModules return type to std::vector

- Change NamedModules() to return an ordered std::vector of (name, module) pairs
- Make NamedModules() public, with memory/prefix/remove_duplicate parameters
- Add BuildNameMap() to create a module->name map for precision checking

Co-Authored-By: Claude Sonnet 4.5
---
 example/gpt2/main.cc                        |  2 +
 example/llama3/main.cc                      |  2 +
 infini_train/include/nn/modules/module.h    |  8 +--
 .../include/utils/precision_checker.h       |  4 ++
 infini_train/src/nn/modules/module.cc       | 63 ++++++++++++-------
 infini_train/src/utils/precision_checker.cc | 44 +++++++++++--
 scripts/compare_loss.py                     |  4 +-
 scripts/compare_tps.py                      |  4 +-
 8 files changed, 96 insertions(+), 35 deletions(-)

diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc
index 5bc66a9..3219c1f 100644
--- a/example/gpt2/main.cc
+++ b/example/gpt2/main.cc
@@ -190,6 +190,8 @@ void Train(const nn::parallel::Rank &rank) {
 
     model->To(device);
 
+    utils::PrecisionChecker::BuildNameMap(model.get());
+
     // select the data type
     // TODO(lzm): change to solely rely on the weight file info for determining the dtype when autocast is supported
     DataType dtype;
diff --git a/example/llama3/main.cc b/example/llama3/main.cc
index e9b5090..ff9b666 100644
--- a/example/llama3/main.cc
+++ b/example/llama3/main.cc
@@ -169,6 +169,8 @@ void Train(const nn::parallel::Rank &rank) {
 
     model->To(device);
 
+    utils::PrecisionChecker::BuildNameMap(model.get());
+
     LOG(INFO) << "Rank " << rank.GlobalRank() << ": Model loaded to device.";
 
     DataType dtype;
diff --git a/infini_train/include/nn/modules/module.h b/infini_train/include/nn/modules/module.h
index b36a59d..398166b 100644
--- a/infini_train/include/nn/modules/module.h
+++ b/infini_train/include/nn/modules/module.h
@@ -80,6 +80,10 @@ class Module : public std::enable_shared_from_this<Module> {
 
     virtual std::shared_ptr<Module> ReplicateForDataParallel(int device_idx) const;
 
+    std::vector<std::pair<std::string, std::shared_ptr<Module>>>
+    NamedModules(std::unordered_set<Module *> *memory = nullptr, const std::string &prefix = "",
+                 bool remove_duplicate = true);
+
     // Hook registration methods
     std::shared_ptr<ModulePreHook> RegisterForwardPreHook(ModulePreHook hook);
     std::shared_ptr<ModulePostHook> RegisterForwardPostHook(ModulePostHook hook);
@@ -99,10 +103,6 @@ class Module : public std::enable_shared_from_this<Module> {
     std::vector<std::shared_ptr<ModulePostHook>> backward_post_hooks_;
 
 private:
-    std::unordered_map<std::string, std::shared_ptr<Module>>
-    NamedModules(const std::string &prefix = "", bool remove_duplicate = true,
-                 std::unordered_set<Module *> *memory = nullptr);
-
     friend std::vector<std::shared_ptr<Module>>
     parallel::function::Replicate(const std::shared_ptr<Module> &network, const std::vector<const Device *> &devices);
 };
diff --git a/infini_train/include/utils/precision_checker.h b/infini_train/include/utils/precision_checker.h
index 2d83569..87214fb 100644
--- a/infini_train/include/utils/precision_checker.h
+++ b/infini_train/include/utils/precision_checker.h
@@ -38,6 +38,10 @@ class PrecisionChecker {
     // Called automatically by PrecisionCheckEnv::Init when level >= MODULE
     static void Init(const PrecisionCheckConfig &global_config, const Config &config = DefaultConfig());
 
+    // Build name map from root_model without registering hooks
+    // Called by PrecisionCheckEnv::RegisterWithRootModel
+    static void BuildNameMap(nn::Module *root_model);
+
     static void RegisterForFunction(autograd::Function *func, const std::string &name = "",
                                     const Config &config = DefaultConfig());
 
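
Why the container change above matters, in isolation: std::unordered_map iteration
order is unspecified and can differ across builds and processes, while a std::vector
of pairs replays the traversal order exactly. A self-contained illustration (the
module names here are only examples):

    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    int main() {
        // Same three entries in both containers.
        std::unordered_map<std::string, int> by_name{{"wte", 0}, {"wpe", 1}, {"ln_f", 2}};
        for (const auto &[k, v] : by_name) {
            std::cout << k << " "; // order depends on hashing, not on insertion
        }
        std::cout << "\n";

        std::vector<std::pair<std::string, int>> ordered{{"wte", 0}, {"wpe", 1}, {"ln_f", 2}};
        for (const auto &[k, v] : ordered) {
            std::cout << k << " "; // always: wte wpe ln_f
        }
        std::cout << "\n";
    }
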
diff --git a/infini_train/src/nn/modules/module.cc b/infini_train/src/nn/modules/module.cc
index b86a1be..0ac1165 100644
--- a/infini_train/src/nn/modules/module.cc
+++ b/infini_train/src/nn/modules/module.cc
@@ -1,5 +1,6 @@
 #include "infini_train/include/nn/modules/module.h"
 
+#include <algorithm>
 #include <memory>
 #include <string>
 #include <unordered_map>
@@ -71,40 +72,58 @@ std::vector<std::shared_ptr<Tensor>> Module::Buffers() const {
 std::vector<std::shared_ptr<Module>> Module::modules() {
     std::vector<std::shared_ptr<Module>> modules;
     auto named_modules = NamedModules();
-    for (auto &[_, module] : named_modules) {
-        if (_ != "") {
+
+    std::shared_ptr<Module> root;
+    for (auto &[name, module] : named_modules) {
+        if (name != "") {
             modules.push_back(module);
+        } else {
+            root = module;
         }
     }
-    modules.insert(modules.begin(), named_modules[""]);
+
+    modules.insert(modules.begin(), root);
     return modules;
 }
 
-// FIXME(dcj): can not call this function in constructor
-std::unordered_map<std::string, std::shared_ptr<Module>>
-Module::NamedModules(const std::string &prefix, bool remove_duplicate, std::unordered_set<Module *> *memory) {
+std::vector<std::pair<std::string, std::shared_ptr<Module>>>
+Module::NamedModules(std::unordered_set<Module *> *memory, const std::string &prefix, bool remove_duplicate) {
     std::unordered_set<Module *> local_memory;
     if (memory == nullptr) {
         memory = &local_memory;
     }
-    std::unordered_map<std::string, std::shared_ptr<Module>> named_modules;
-    if (!memory->contains(this)) {
-        if (remove_duplicate) {
-            memory->insert(this);
-        }
-        CHECK(!named_modules.contains(prefix));
-        named_modules.emplace(prefix, shared_from_this());
-        for (auto &[name, module] : modules_) {
-            if (!module) {
-                continue;
-            }
-            auto submodule_prefix = (prefix.empty() ? "" : prefix + ".") + name;
-            for (auto &[sub_name, sub_module] : module->NamedModules(submodule_prefix, remove_duplicate, memory)) {
-                CHECK(!named_modules.contains(sub_name));
-                named_modules.emplace(sub_name, sub_module);
-            }
-        }
-    }
+
+    std::vector<std::pair<std::string, std::shared_ptr<Module>>> named_modules;
+
+    // Only dedup when remove_duplicate=true
+    if (remove_duplicate) {
+        if (memory->contains(this)) {
+            return named_modules; // already visited: don't emit, don't recurse
+        }
+        memory->insert(this);
+    }
+
+    // Emit self first (pre-order)
+    named_modules.emplace_back(prefix, shared_from_this());
+
+    // Collect children then sort by key for stable order
+    std::vector<std::pair<std::string, std::shared_ptr<Module>>> children;
+    children.reserve(modules_.size());
+    for (const auto &[name, module] : modules_) {
+        if (!module) {
+            continue;
+        }
+        children.emplace_back(name, module);
+    }
+    std::sort(children.begin(), children.end(), [](const auto &a, const auto &b) { return a.first < b.first; });
+
+    // Recurse in sorted order
+    for (const auto &[name, module] : children) {
+        const auto submodule_prefix = (prefix.empty() ? "" : prefix + ".") + name;
+        auto sub = module->NamedModules(memory, submodule_prefix, remove_duplicate);
+        named_modules.insert(named_modules.end(), sub.begin(), sub.end());
+    }
+
     return named_modules;
 }
 
@@ -192,7 +211,7 @@ std::vector<std::shared_ptr<Tensor>> Module::operator()(const std::vector<std::
 
         x->grad_fn()->RegisterBackwardPostHook(
             [this](autograd::Function *, const std::vector<std::shared_ptr<Tensor>> &grad_inputs,
                    const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
-                // Registry convention: (grad_outputs, grad_inputs) - PyTorch style
+                // Registry convention: (grad_outputs, grad_inputs)
                 utils::GlobalModuleHookRegistry::Instance().CallModuleFullBackwardHooks(this, grad_outputs, grad_inputs);
             });
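
The rewritten walk is deterministic by construction: emit self first (pre-order), then
recurse into children sorted by key. A standalone toy of the same traversal (Node and
Walk are illustrative stand-ins, not InfiniTrain types):

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <utility>
    #include <vector>

    // Pre-order walk with children sorted by name, building dotted paths.
    struct Node {
        std::vector<std::pair<std::string, Node *>> children;
    };

    void Walk(const Node *node, const std::string &prefix, std::vector<std::string> &out) {
        out.push_back(prefix.empty() ? "<root>" : prefix); // self before children
        auto kids = node->children;                        // copy, then sort by name
        std::sort(kids.begin(), kids.end(), [](const auto &a, const auto &b) { return a.first < b.first; });
        for (const auto &[name, child] : kids) {
            Walk(child, prefix.empty() ? name : prefix + "." + name, out);
        }
    }

    int main() {
        Node wte, wpe, transformer, root;
        transformer.children = {{"wte", &wte}, {"wpe", &wpe}}; // registered out of order
        root.children = {{"transformer", &transformer}};
        std::vector<std::string> paths;
        Walk(&root, "", paths);
        for (const auto &p : paths) {
            std::cout << p << "\n"; // <root>, transformer, transformer.wpe, transformer.wte
        }
    }

Two runs, or two ranks, now produce the same path list, which is what lets the
precision checker line up module names across processes.
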
"" : prefix + ".") + name; + auto sub = module->NamedModules(memory, submodule_prefix, remove_duplicate); + named_modules.insert(named_modules.end(), sub.begin(), sub.end()); + } + return named_modules; } @@ -192,7 +211,7 @@ std::vector> Module::operator()(const std::vectorgrad_fn()->RegisterBackwardPostHook( [this](autograd::Function *, const std::vector> &grad_inputs, const std::vector> &grad_outputs) { - // Registry convention: (grad_outputs, grad_inputs) - PyTorch style + // Registry convention: (grad_outputs, grad_inputs) utils::GlobalModuleHookRegistry::Instance().CallModuleFullBackwardHooks(this, grad_outputs, grad_inputs); }); diff --git a/infini_train/src/utils/precision_checker.cc b/infini_train/src/utils/precision_checker.cc index 08e10a4..3391b9a 100644 --- a/infini_train/src/utils/precision_checker.cc +++ b/infini_train/src/utils/precision_checker.cc @@ -25,6 +25,8 @@ namespace infini_train::utils { +static std::unordered_map g_module_name_map; + namespace { // Simple MD5 implementation @@ -263,7 +265,7 @@ void SaveNpy(const std::shared_ptr &tensor, const std::string &name, int const auto &output_path = PrecisionCheckEnv::Instance().GetOutputPath(); std::string dir = output_path + "/rank_" + std::to_string(rank); std::filesystem::create_directories(dir); - std::string filename = dir + "/" + name + "_" + std::to_string(idx) + "_" + stage + ".npy"; + std::string filename = dir + "/" + name + (idx > 0 ? "_" + std::to_string(idx) : "") + "_" + stage + ".npy"; if (tensor->Dtype() == DataType::kFLOAT32) { tensor->SaveAsNpy(filename); @@ -320,6 +322,9 @@ void PrecisionChecker::CheckTensors(const std::string &stage, const std::string // Output to log auto &log_stream = GetLogStream(); + // Format: name[_idx]_forward/backward (match .npy filename format) + std::string log_name = name + (idx > 0 ? "_" + std::to_string(idx) : "") + "_" + stage_short; + if (global_config.format == "md5") { // MD5 format std::string md5; @@ -338,7 +343,7 @@ void PrecisionChecker::CheckTensors(const std::string &stage, const std::string // Original precision MD5 md5 = ComputeMD5(cpu_tensor->DataPtr(), byte_size); } - log_stream << context_key << " " << name << "_" << idx << "_" << stage << " tensor[" << i << "]: " + log_stream << context_key << " " << log_name << " tensor[" << i << "]: " << "dtype=" << DataTypeToString(cpu_tensor->Dtype()) << " " << "shape=" << FormatShape(cpu_tensor->Dims()) << " " << "md5=" << md5 << std::endl; @@ -350,7 +355,7 @@ void PrecisionChecker::CheckTensors(const std::string &stage, const std::string = (config.check_nan && stats.nan_count > 0) || (config.check_inf && stats.inf_count > 0); const std::string error_marker = has_error ? " <- ERROR" : ""; - log_stream << context_key << " " << name << "_" << idx << "_" << stage << " tensor[" << i << "]: " + log_stream << context_key << " " << log_name << " tensor[" << i << "]: " << "dtype=" << DataTypeToString(cpu_tensor->Dtype()) << " " << "shape=" << FormatShape(cpu_tensor->Dims()) << " " << "min=" << stats.min_val << " " @@ -388,17 +393,46 @@ void PrecisionChecker::Init(const PrecisionCheckConfig &global_config, const Con GlobalModuleHookRegistry::Instance().RegisterModuleForwardHook( [config](nn::Module *module, const std::vector> &inputs, const std::vector> &outputs) { - CheckTensors("Forward Output", module->type(), outputs, config); + auto it = g_module_name_map.find(module); + const std::string &name = (it != g_module_name_map.end()) ? 
diff --git a/scripts/compare_loss.py b/scripts/compare_loss.py
index d8ed871..8b58126 100755
--- a/scripts/compare_loss.py
+++ b/scripts/compare_loss.py
@@ -62,8 +62,8 @@ def main():
         args.threshold_fp32 = args.threshold
         args.threshold_bf16 = args.threshold
 
-    files1 = {f.name: f for f in args.dir1.glob('*.log')}
-    files2 = {f.name: f for f in args.dir2.glob('*.log')}
+    files1 = {f.name: f for f in args.dir1.glob('*.log') if not f.name.startswith('build')}
+    files2 = {f.name: f for f in args.dir2.glob('*.log') if not f.name.startswith('build')}
 
     only_in_1 = set(files1.keys()) - set(files2.keys())
     only_in_2 = set(files2.keys()) - set(files1.keys())
diff --git a/scripts/compare_tps.py b/scripts/compare_tps.py
index 8a3cb80..270b1dd 100755
--- a/scripts/compare_tps.py
+++ b/scripts/compare_tps.py
@@ -55,8 +55,8 @@ def main():
     parser.add_argument('--verbose', action='store_true',
                         help='Print detailed output for all files, including passed ones')
     args = parser.parse_args()
 
-    files1 = {f.name: f for f in args.dir1.glob('*.log')}
-    files2 = {f.name: f for f in args.dir2.glob('*.log')}
+    files1 = {f.name: f for f in args.dir1.glob('*.log') if not f.name.startswith('build')}
+    files2 = {f.name: f for f in args.dir2.glob('*.log') if not f.name.startswith('build')}
 
     only_in_1 = set(files1.keys()) - set(files2.keys())
     only_in_2 = set(files2.keys()) - set(files1.keys())
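
A note on the remove_duplicate=false walk in BuildNameMap: with dedup on, a module
reachable through several parents is emitted only under the first path that reaches
it; with dedup off, it is emitted under every path, and the map assignment in the
loop keeps the last path written. A self-contained toy of the flag's effect (Node and
Walk are illustrative stand-ins, not InfiniTrain types):

    #include <iostream>
    #include <string>
    #include <unordered_set>
    #include <utility>
    #include <vector>

    struct Node {
        std::vector<std::pair<std::string, Node *>> children;
    };

    // Visits the tree, optionally skipping nodes already seen (dedup).
    void Walk(Node *n, const std::string &prefix, bool remove_duplicate,
              std::unordered_set<Node *> &memory, std::vector<std::string> &out) {
        if (remove_duplicate) {
            if (memory.count(n) > 0) {
                return; // already visited: don't emit, don't recurse
            }
            memory.insert(n);
        }
        out.push_back(prefix.empty() ? "<root>" : prefix);
        for (const auto &[name, child] : n->children) {
            Walk(child, prefix.empty() ? name : prefix + "." + name, remove_duplicate, memory, out);
        }
    }

    int main() {
        Node shared, block, root;
        block.children = {{"tied", &shared}};
        root.children = {{"block", &block}, {"lm_head", &shared}}; // &shared reachable twice
        // Expected: dedup on : <root> block block.tied
        //           dedup off: <root> block block.tied lm_head
        for (bool dedup : {true, false}) {
            std::unordered_set<Node *> memory;
            std::vector<std::string> out;
            Walk(&root, "", dedup, memory, out);
            std::cout << (dedup ? "dedup on : " : "dedup off: ");
            for (const auto &p : out) {
                std::cout << p << " ";
            }
            std::cout << "\n";
        }
    }
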