2 changes: 2 additions & 0 deletions example/gpt2/main.cc
@@ -190,6 +190,8 @@ void Train(const nn::parallel::Rank &rank) {

model->To(device);

utils::PrecisionChecker::BuildNameMap(model.get());

// select the data type
// TODO(lzm): change to solely rely on the weight file info for determining the dtype when autocast is supported
DataType dtype;
2 changes: 2 additions & 0 deletions example/llama3/main.cc
@@ -169,6 +169,8 @@ void Train(const nn::parallel::Rank &rank) {

model->To(device);

utils::PrecisionChecker::BuildNameMap(model.get());

LOG(INFO) << "Rank " << rank.GlobalRank() << ": Model loaded to device.";

DataType dtype;
8 changes: 4 additions & 4 deletions infini_train/include/nn/modules/module.h
@@ -80,6 +80,10 @@ class Module : public std::enable_shared_from_this<Module> {

virtual std::shared_ptr<Module> ReplicateForDataParallel(int device_idx) const;

std::vector<std::pair<std::string, std::shared_ptr<Module>>>
NamedModules(std::unordered_set<Module *> *memory = nullptr, const std::string &prefix = "",
bool remove_duplicate = true);

// Hook registration methods
std::shared_ptr<infini_train::HookHandle> RegisterForwardPreHook(ModulePreHook hook);
std::shared_ptr<infini_train::HookHandle> RegisterForwardPostHook(ModulePostHook hook);
@@ -99,10 +103,6 @@
std::vector<ModulePostHook> backward_post_hooks_;

private:
std::unordered_map<std::string, std::shared_ptr<Module>>
NamedModules(const std::string &prefix = "", bool remove_duplicate = true,
std::unordered_set<Module *> *memory = nullptr);

friend std::vector<std::shared_ptr<Module>>
parallel::function::Replicate(const std::shared_ptr<Module> &network, const std::vector<const Device *> &devices);
};
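
The declaration moves out of the private section, making NamedModules part of the public surface with the visited-set, prefix, and deduplication flag exposed as defaulted parameters. A minimal consumer sketch (illustrative only, not part of this PR; it assumes the include path and the streamable type() accessor used elsewhere in this diff):

#include <iostream>
#include <memory>

#include "infini_train/include/nn/modules/module.h"

// Walks the module tree through the now-public API. The root is emitted first
// with an empty prefix; children follow in sorted pre-order, producing dotted
// paths such as "transformer.h.0.attn" (name here is hypothetical).
void DumpModuleTree(const std::shared_ptr<infini_train::nn::Module> &model) {
    for (const auto &[name, submodule] : model->NamedModules()) {
        std::cout << (name.empty() ? "<root>" : name) << " : " << submodule->type() << "\n";
    }
}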
4 changes: 4 additions & 0 deletions infini_train/include/utils/precision_checker.h
@@ -38,6 +38,10 @@ class PrecisionChecker {
// Called automatically by PrecisionCheckEnv::Init when level >= MODULE
static void Init(const PrecisionCheckConfig &global_config, const Config &config = DefaultConfig());

// Build name map from root_model without registering hooks
// Called by PrecisionCheckEnv::RegisterWithRootModel
static void BuildNameMap(nn::Module *root_model);

static void RegisterForFunction(autograd::Function *func, const std::string &name = "",
const Config &config = DefaultConfig());

63 changes: 41 additions & 22 deletions infini_train/src/nn/modules/module.cc
@@ -1,5 +1,6 @@
#include "infini_train/include/nn/modules/module.h"

#include <algorithm>
#include <memory>
#include <unordered_map>
#include <unordered_set>
@@ -71,40 +72,58 @@ std::vector<std::shared_ptr<Tensor>> Module::Buffers() const {
std::vector<std::shared_ptr<Module>> Module::modules() {
std::vector<std::shared_ptr<Module>> modules;
auto named_modules = NamedModules();
for (auto &[_, module] : named_modules) {
if (_ != "") {

std::shared_ptr<Module> root;
for (auto &[name, module] : named_modules) {
if (name != "") {
modules.push_back(module);
} else {
root = module;
}
}
modules.insert(modules.begin(), named_modules[""]);

modules.insert(modules.begin(), root);
return modules;
}

// FIXME(dcj): can not call this function in constructor
std::unordered_map<std::string, std::shared_ptr<Module>>
Module::NamedModules(const std::string &prefix, bool remove_duplicate, std::unordered_set<Module *> *memory) {
std::vector<std::pair<std::string, std::shared_ptr<Module>>>
Module::NamedModules(std::unordered_set<Module *> *memory, const std::string &prefix, bool remove_duplicate) {
std::unordered_set<Module *> local_memory;
if (memory == nullptr) {
memory = &local_memory;
}
std::unordered_map<std::string, std::shared_ptr<Module>> named_modules;
if (!memory->contains(this)) {
if (remove_duplicate) {
memory->insert(this);

std::vector<std::pair<std::string, std::shared_ptr<Module>>> named_modules;

// Only dedup when remove_duplicate=true
if (remove_duplicate) {
if (memory->contains(this)) {
return named_modules; // already visited: don't emit, don't recurse
}
CHECK(!named_modules.contains(prefix));
named_modules.emplace(prefix, shared_from_this());
for (auto &[name, module] : modules_) {
if (!module) {
continue;
}
auto submodule_prefix = (prefix.empty() ? "" : prefix + ".") + name;
for (auto &[sub_name, sub_module] : module->NamedModules(submodule_prefix, remove_duplicate, memory)) {
CHECK(!named_modules.contains(sub_name));
named_modules.emplace(sub_name, sub_module);
}
memory->insert(this);
}

// Emit self first (pre-order)
named_modules.emplace_back(prefix, shared_from_this());

// Collect children then sort by key for stable order
std::vector<std::pair<std::string, std::shared_ptr<Module>>> children;
children.reserve(modules_.size());
for (const auto &[name, module] : modules_) {
if (!module) {
continue;
}
children.emplace_back(name, module);
}
std::sort(children.begin(), children.end(), [](const auto &a, const auto &b) { return a.first < b.first; });

// Recurse in sorted order
for (const auto &[name, module] : children) {
const auto submodule_prefix = (prefix.empty() ? "" : prefix + ".") + name;
auto sub = module->NamedModules(memory, submodule_prefix, remove_duplicate);
named_modules.insert(named_modules.end(), sub.begin(), sub.end());
}

return named_modules;
}
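
The new BuildNameMap in precision_checker.cc calls this traversal with remove_duplicate=false, so the flag's effect is worth spelling out. A standalone toy (plain C++, no InfiniTrain types; the names are made up) that mirrors the logic above: emit self first, recurse into children sorted by name, and optionally skip nodes already visited through another parent.

#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

struct Node {
    std::vector<std::pair<std::string, Node *>> children;
};

void Walk(Node *node, const std::string &prefix, bool remove_duplicate,
          std::unordered_set<Node *> &seen, std::vector<std::string> &out) {
    if (remove_duplicate) {
        if (seen.count(node) > 0) {
            return;  // already visited: don't emit, don't recurse
        }
        seen.insert(node);
    }
    out.push_back(prefix);  // self first (pre-order)
    auto kids = node->children;
    std::sort(kids.begin(), kids.end(), [](const auto &a, const auto &b) { return a.first < b.first; });
    for (const auto &[name, child] : kids) {
        Walk(child, (prefix.empty() ? "" : prefix + ".") + name, remove_duplicate, seen, out);
    }
}

int main() {
    Node shared, block, root;
    block.children = {{"attn", &shared}};
    root.children = {{"h", &block}, {"tied", &shared}};  // `shared` reachable via two paths
    for (bool dedup : {true, false}) {
        std::unordered_set<Node *> seen;
        std::vector<std::string> out;
        Walk(&root, "", dedup, seen, out);
        std::cout << "remove_duplicate=" << std::boolalpha << dedup << ":";
        for (const auto &path : out) {
            std::cout << " \"" << path << "\"";
        }
        std::cout << "\n";
    }
}

With deduplication on this prints "" "h" "h.attn"; with it off, the shared node also appears as "tied", i.e. once per path.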

@@ -192,7 +211,7 @@ std::vector<std::shared_ptr<Tensor>> Module::operator()(const std::vector<std::s
output->grad_fn()->RegisterBackwardPostHook(
[this](autograd::Function *, const std::vector<std::shared_ptr<Tensor>> &grad_inputs,
const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
// Registry convention: (grad_outputs, grad_inputs) - PyTorch style
// Registry convention: (grad_outputs, grad_inputs)
utils::GlobalModuleHookRegistry::Instance().CallModuleFullBackwardHooks(this, grad_outputs,
grad_inputs);
});
44 changes: 39 additions & 5 deletions infini_train/src/utils/precision_checker.cc
@@ -25,6 +25,8 @@

namespace infini_train::utils {

static std::unordered_map<const nn::Module *, std::string> g_module_name_map;

namespace {

// Simple MD5 implementation
@@ -263,7 +265,7 @@ void SaveNpy(const std::shared_ptr<Tensor> &tensor, const std::string &name, int
const auto &output_path = PrecisionCheckEnv::Instance().GetOutputPath();
std::string dir = output_path + "/rank_" + std::to_string(rank);
std::filesystem::create_directories(dir);
std::string filename = dir + "/" + name + "_" + std::to_string(idx) + "_" + stage + ".npy";
std::string filename = dir + "/" + name + (idx > 0 ? "_" + std::to_string(idx) : "") + "_" + stage + ".npy";

if (tensor->Dtype() == DataType::kFLOAT32) {
tensor->SaveAsNpy(filename);
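
A self-contained sketch of the revised naming rule (paths and stage strings are illustrative): the index suffix is only appended when idx > 0, and the same rule is reused for the log names introduced in the next hunk, so dump files and log entries line up.

#include <iostream>
#include <string>

// Mirrors the ternary above: single-output tensors get "<name>_<stage>.npy"
// instead of the previous "<name>_0_<stage>.npy".
std::string NpyName(const std::string &name, int idx, const std::string &stage) {
    return name + (idx > 0 ? "_" + std::to_string(idx) : "") + "_" + stage + ".npy";
}

int main() {
    std::cout << NpyName("transformer.h.0.attn", 0, "forward") << "\n";  // transformer.h.0.attn_forward.npy
    std::cout << NpyName("transformer.h.0.attn", 2, "forward") << "\n";  // transformer.h.0.attn_2_forward.npy
}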
@@ -320,6 +322,9 @@ void PrecisionChecker::CheckTensors(const std::string &stage, const std::string
// Output to log
auto &log_stream = GetLogStream();

// Format: name[_idx]_forward/backward (match .npy filename format)
std::string log_name = name + (idx > 0 ? "_" + std::to_string(idx) : "") + "_" + stage_short;

if (global_config.format == "md5") {
// MD5 format
std::string md5;
@@ -338,7 +343,7 @@
// Original precision MD5
md5 = ComputeMD5(cpu_tensor->DataPtr(), byte_size);
}
log_stream << context_key << " " << name << "_" << idx << "_" << stage << " tensor[" << i << "]: "
log_stream << context_key << " " << log_name << " tensor[" << i << "]: "
<< "dtype=" << DataTypeToString(cpu_tensor->Dtype()) << " "
<< "shape=" << FormatShape(cpu_tensor->Dims()) << " "
<< "md5=" << md5 << std::endl;
@@ -350,7 +355,7 @@
= (config.check_nan && stats.nan_count > 0) || (config.check_inf && stats.inf_count > 0);
const std::string error_marker = has_error ? " <- ERROR" : "";

log_stream << context_key << " " << name << "_" << idx << "_" << stage << " tensor[" << i << "]: "
log_stream << context_key << " " << log_name << " tensor[" << i << "]: "
<< "dtype=" << DataTypeToString(cpu_tensor->Dtype()) << " "
<< "shape=" << FormatShape(cpu_tensor->Dims()) << " "
<< "min=" << stats.min_val << " "
@@ -388,17 +393,46 @@ void PrecisionChecker::Init(const PrecisionCheckConfig &global_config, const Con
GlobalModuleHookRegistry::Instance().RegisterModuleForwardHook(
[config](nn::Module *module, const std::vector<std::shared_ptr<Tensor>> &inputs,
const std::vector<std::shared_ptr<Tensor>> &outputs) {
CheckTensors("Forward Output", module->type(), outputs, config);
auto it = g_module_name_map.find(module);
const std::string &name = (it != g_module_name_map.end()) ? it->second : module->type();
CheckTensors("Forward Output", name, outputs, config);
});

// Register global module full backward hook (checks gradients on every backward)
GlobalModuleHookRegistry::Instance().RegisterModuleFullBackwardHook(
[config](nn::Module *module, const std::vector<std::shared_ptr<Tensor>> &grad_outputs,
const std::vector<std::shared_ptr<Tensor>> &grad_inputs) {
CheckTensors("GradOutputs", module->type(), grad_outputs, config);
auto it = g_module_name_map.find(module);
const std::string &name = (it != g_module_name_map.end()) ? it->second : module->type();
CheckTensors("GradOutputs", name, grad_outputs, config);
});
}

static inline bool ShouldSkipNameMap(std::string_view name) {
return name.rfind("__pp", 0) == 0; // starts_with("__pp")
}

void PrecisionChecker::BuildNameMap(nn::Module *root_model) {
const auto &global_config = PrecisionCheckEnv::Instance().GetConfig();
if (global_config.level == PrecisionCheckLevel::OFF || root_model == nullptr) {
return;
}

auto named = root_model->NamedModules(/*memory=*/nullptr, /*prefix=*/"", /*remove_duplicate=*/false);
g_module_name_map.clear();
g_module_name_map.reserve(named.size());

for (const auto &[name, module] : named) {
if (name.empty()) {
continue; // skip root
}
if (ShouldSkipNameMap(name)) {
continue; // skip PP internal tree
}
g_module_name_map[module.get()] = name; // keep InfiniTrain path directly
}
}
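
A standalone illustration of the filter above (module paths are hypothetical): the unnamed root and anything under the __pp subtree (the PP-internal tree noted in the comment) are dropped, while plain module paths are kept verbatim. Note that because the traversal runs with remove_duplicate=false, a module reachable through several paths is mapped to whichever of its paths the loop visits last.

#include <iostream>
#include <string_view>

// Same predicate as ShouldSkipNameMap, extended to also reject the empty root prefix.
static bool Skipped(std::string_view name) {
    return name.empty() || name.rfind("__pp", 0) == 0;  // starts_with("__pp")
}

int main() {
    for (std::string_view name : {"", "__pp_stage_0.embed", "transformer.h.0.attn", "lm_head"}) {
        std::cout << '"' << name << "\" -> " << (Skipped(name) ? "skipped" : "kept") << "\n";
    }
}

Here "" and "__pp_stage_0.embed" come out skipped; "transformer.h.0.attn" and "lm_head" are kept.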

void PrecisionChecker::RegisterForFunction(autograd::Function *func, const std::string &name, const Config &config) {
const std::string func_name = name.empty() ? "Function" : name;

4 changes: 2 additions & 2 deletions scripts/compare_loss.py
@@ -62,8 +62,8 @@ def main():
args.threshold_fp32 = args.threshold
args.threshold_bf16 = args.threshold

files1 = {f.name: f for f in args.dir1.glob('*.log')}
files2 = {f.name: f for f in args.dir2.glob('*.log')}
files1 = {f.name: f for f in args.dir1.glob('*.log') if not f.name.startswith('build')}
files2 = {f.name: f for f in args.dir2.glob('*.log') if not f.name.startswith('build')}

only_in_1 = set(files1.keys()) - set(files2.keys())
only_in_2 = set(files2.keys()) - set(files1.keys())
4 changes: 2 additions & 2 deletions scripts/compare_tps.py
@@ -55,8 +55,8 @@ def main():
parser.add_argument('--verbose', action='store_true', help='Print detailed output for all files, including passed ones')
args = parser.parse_args()

files1 = {f.name: f for f in args.dir1.glob('*.log')}
files2 = {f.name: f for f in args.dir2.glob('*.log')}
files1 = {f.name: f for f in args.dir1.glob('*.log') if not f.name.startswith('build')}
files2 = {f.name: f for f in args.dir2.glob('*.log') if not f.name.startswith('build')}

only_in_1 = set(files1.keys()) - set(files2.keys())
only_in_2 = set(files2.keys()) - set(files1.keys())