From bb82b527f1e94f9445c90f2fb806f42b12de67a3 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Mon, 22 Sep 2025 18:53:54 -0700 Subject: [PATCH 01/14] feat: add state dict load and verify for module --- src/module/module.cpp | 137 ++++++++++++++++++++++++++++++++------- src/module/module.h | 108 ++++++++++++++++++++---------- src/module/module_list.h | 55 ++++------------ 3 files changed, 201 insertions(+), 99 deletions(-) diff --git a/src/module/module.cpp b/src/module/module.cpp index 45f2546c..94931b24 100644 --- a/src/module/module.cpp +++ b/src/module/module.cpp @@ -2,6 +2,8 @@ #include +#include "model_loader/state_dict.h" + namespace llm { using namespace torch; @@ -49,8 +51,9 @@ OrderedDict Module::named_parameters(bool recurse) const { OrderedDict result; if (!recurse) { for (const auto& parameter : parameters_) { - if (parameter.value().defined()) { - result.insert(parameter.key(), parameter.value()); + const auto& param_tensor = parameter.value().tensor; + if (param_tensor.defined()) { + result.insert(parameter.key(), param_tensor); } } } else { @@ -125,12 +128,20 @@ OrderedDict> Module::named_modules( } std::vector> Module::children() const { - return children_.values(); + std::vector> result; + for (const auto& item : children_) { + result.push_back(item.value().module); + } + return result; } OrderedDict> Module::named_children() const { - return children_; + OrderedDict> result; + for (const auto& item : children_) { + result.insert(item.key(), item.value().module); + } + return result; } void Module::apply(const ModuleApplyFunction& function) { @@ -198,31 +209,51 @@ void Module::to(torch::Device device, bool non_blocking) { to_impl(device, non_blocking); } +Tensor& Module::register_parameter(std::string name, + Tensor tensor, + const TensorLoader& loader) { + TORCH_CHECK(!name.empty(), "Parameter name must not be empty"); + // tensor.set_requires_grad(false); + Parameter param{ + .tensor = std::move(tensor), .loader = loader, .is_loaded = false}; + 
return parameters_.insert(std::move(name), std::move(param)).tensor; +} + Tensor& Module::register_parameter(std::string name, Tensor tensor) { + auto default_loader = [](const StateDict& sd, const std::string& key) { + return sd.get_tensor(key); + }; + return register_parameter(std::move(name), std::move(tensor), default_loader); +} + +Tensor& Module::register_sharded_parameter(std::string name, + int dim, + int rank, + int world_size, + Tensor tensor) { TORCH_CHECK(!name.empty(), "Parameter name must not be empty"); - TORCH_CHECK(name.find('.') == std::string::npos, - "Parameter name must not contain a dot (got '", - name, - "')"); - tensor.set_requires_grad(false); - return parameters_.insert(std::move(name), std::move(tensor)); + // tensor.set_requires_grad(false); + Parameter param{ + .tensor = std::move(tensor), + .loader = + [dim, rank, world_size](const StateDict& sd, const std::string& key) { + return sd.get_sharded_tensor(key, dim, rank, world_size); + }, + .is_loaded = false}; + return parameters_.insert(std::move(name), std::move(param)).tensor; } Tensor& Module::register_buffer(std::string name, Tensor tensor) { TORCH_CHECK(!name.empty(), "Buffer name must not be empty"); - TORCH_CHECK(name.find('.') == std::string::npos, - "Buffer name must not contain a dot (got '", - name, - "')"); return buffers_.insert(std::move(name), std::move(tensor)); } -void Module::unregister_module(const std::string& name) { - TORCH_CHECK(children_.contains(name), - "No Module with name `", - name, - "` is registered"); +bool Module::unregister_module(const std::string& name) { + if (!children_.contains(name)) { + return false; + } children_.erase(name); + return true; } void Module::pretty_print(std::ostream& stream) const { stream << name(); } @@ -235,7 +266,7 @@ void Module::pretty_print_recursive(std::ostream& stream, const std::string next_indentation = indentation + " "; for (const auto& child : children_) { stream << next_indentation << "(" << child.key() << "): "; - 
child.value()->pretty_print_recursive(stream, next_indentation); + child.value().module->pretty_print_recursive(stream, next_indentation); stream << '\n'; } stream << indentation << ")"; @@ -248,8 +279,8 @@ void Module::apply_to_submodules( const std::string& name_prefix) const { for (const auto& child : children_) { auto qualified_name = join_name(name_prefix, child.key()); - function(qualified_name, child.value()); - child.value()->apply_to_submodules(function, qualified_name); + function(qualified_name, child.value().module); + child.value().module->apply_to_submodules(function, qualified_name); } } @@ -276,4 +307,66 @@ std::ostream& operator<<(std::ostream& stream, const Module& module) { module.pretty_print_recursive(stream, ""); return stream; } + +// load weights from the checkpoint, override this method if necessary +// NOLINTNEXTLINE(misc-no-recursion) +void Module::load(const StateDict& state_dict, const std::string& name_prefix) { + // load parameters one by one + for (auto& item : parameters_) { + const auto& key = item.key(); + auto& param = item.value(); + + // clear the load status before loading + param.is_loaded = false; + const auto tensor = param.loader(state_dict, key); + if (!tensor.defined()) { + continue; + } + + const auto& param_tensor = param.tensor; + if (param_tensor.sizes() == tensor.sizes()) { + // copy data to the parameter tensor + param_tensor.copy_(tensor); + // mark as loaded + param.is_loaded = true; + } else { + LOG(ERROR) << "Parameter size mismatch for " + << join_name(name_prefix, key) << ": expected " + << param_tensor.sizes() << ", got " << tensor.sizes(); + } + } + + // don't need to load buffers, since they are initialized in the constructor + + // recursively load children modules + for (const auto& item : children_) { + const auto& key = item.key(); + const auto& child = item.value(); + // select state dict for the child module + const auto child_state_dict = child.selector(state_dict, key); + 
child.module->load(child_state_dict, join_name(name_prefix, key)); + } +} + +// verify whether the weights are loaded, override this method if necessary +// NOLINTNEXTLINE(misc-no-recursion) +bool Module::verify(const std::string& name_prefix) const { + bool all_loaded = true; + for (const auto& item : parameters_) { + const auto& key = item.key(); + const auto& param = item.value(); + if (!param.is_loaded) { + LOG(ERROR) << "Missing parameter: " << join_name(name_prefix, key); + } + all_loaded = all_loaded && param.is_loaded; + } + + for (const auto& item : children_) { + const auto& key = item.key(); + const auto& child = item.value(); + const bool child_loaded = child.module->verify(join_name(name_prefix, key)); + all_loaded = all_loaded && child_loaded; + } + return all_loaded; +} } // namespace llm diff --git a/src/module/module.h b/src/module/module.h index 69a63109..f99d08a5 100644 --- a/src/module/module.h +++ b/src/module/module.h @@ -8,7 +8,9 @@ #include #include #include +#include +#include "model_loader/state_dict.h" #include "module_holder.h" namespace llm { @@ -148,43 +150,60 @@ class Module : public std::enable_shared_from_this { /// `stream` should be returned from the method, to allow easy chaining. virtual void pretty_print(std::ostream& stream) const; + // load weights from the checkpoint, override this method if necessary + virtual void load(const StateDict& state_dict, + const std::string& name_prefix = std::string()); + + // verify whether the weights are loaded, override this method if necessary + virtual bool verify(const std::string& name_prefix = std::string()) const; + /// Registers a parameter with this `Module`. 
+ using TensorLoader = + std::function; + torch::Tensor& register_parameter(std::string name, + torch::Tensor tensor, + const TensorLoader& loader); + torch::Tensor& register_parameter(std::string name, torch::Tensor tensor); + torch::Tensor& register_sharded_parameter(std::string name, + int dim, + int rank, + int world_size, + torch::Tensor tensor); + /// Registers a buffer with this `Module`. torch::Tensor& register_buffer(std::string name, torch::Tensor tensor); /// Registers a submodule with this `Module`. + using StateDictSelector = + std::function; + template std::shared_ptr register_module( std::string name, - std::shared_ptr module); + std::shared_ptr module, + const StateDictSelector& selector); template std::shared_ptr register_module( std::string name, - ModuleHolder module_holder); - - /// Replaces a registered submodule with this `Module`. - template - std::shared_ptr replace_module( - const std::string& name, std::shared_ptr module); template - std::shared_ptr replace_module( - const std::string& name, + std::shared_ptr register_module( + std::string name, ModuleHolder module_holder); - /// Unregisters a submodule from this `Module`. If there is no such module - /// with `name` an exception is thrown. - void unregister_module(const std::string& name); + template + std::shared_ptr register_module( + std::string name, + ModuleHolder module_holder, + const StateDictSelector& selector); - protected: - /// The registered parameters of this `Module`. - /// Inorder to access parameters_ in ParameterDict and ParameterList - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - torch::OrderedDict parameters_; + /// Unregisters a submodule from this `Module`. Returns false if no such + /// submodule was registered. + bool unregister_module(const std::string& name); private: /// Pretty prints the given `Module` into the `ostream`. 
@@ -209,11 +228,25 @@ class Module : public std::enable_shared_from_this { /// Returns a shared_ptr to `this` in a safe (checked) way. std::shared_ptr shared_from_this_checked() const; + struct Parameter { + torch::Tensor tensor; + std::function loader; + bool is_loaded = false; + }; + + struct Child { + std::shared_ptr module; + std::function selector; + }; + + /// The registered parameters of this `Module`. + torch::OrderedDict parameters_; + /// The registered buffers of this `Module`. torch::OrderedDict buffers_; /// The registered (direct) submodules of this `Module`. - torch::OrderedDict> children_; + torch::OrderedDict children_; /// The module's name (e.g. "LSTM"). mutable std::optional name_; @@ -248,43 +281,46 @@ const ModuleType* Module::as() const noexcept { template std::shared_ptr Module::register_module( std::string name, - std::shared_ptr module) { + std::shared_ptr module, + const StateDictSelector& selector) { TORCH_CHECK(!name.empty(), "Submodule name must not be empty"); - TORCH_CHECK(name.find('.') == std::string::npos, - "Submodule name must not contain a dot (got '", - name, - "')"); - auto& base_module = children_.insert(std::move(name), std::move(module)); + Child child{.module = std::move(module), .selector = selector}; + auto& base_module = + children_.insert(std::move(name), std::move(child)).module; return std::dynamic_pointer_cast(base_module); } template std::shared_ptr Module::register_module( std::string name, - ModuleHolder module_holder) { - return register_module(std::move(name), module_holder.ptr()); + std::shared_ptr module) { + auto default_selector = [](const StateDict& sd, const std::string& key) { + const std::string prefix = key + '.'; + return sd.select(prefix); + }; + return register_module(std::move(name), std::move(module), default_selector); } template -std::shared_ptr Module::replace_module( - const std::string& name, - std::shared_ptr module) { - auto& base_module = (children_[name] = std::move(module)); - return 
std::dynamic_pointer_cast(base_module); +std::shared_ptr Module::register_module( + std::string name, + ModuleHolder module_holder) { + return register_module(std::move(name), module_holder.ptr()); } template -std::shared_ptr Module::replace_module( - const std::string& name, - ModuleHolder module_holder) { - return replace_module(name, module_holder.ptr()); +std::shared_ptr Module::register_module( + std::string name, + ModuleHolder module_holder, + const StateDictSelector& selector) { + return register_module(std::move(name), module_holder.ptr(), selector); } template void Module::to_impl(Ts&&... ts) { // First call `to()` on every child module. for (auto& child : children_) { - child.value()->to(ts...); + child.value().module->to(ts...); } // Then move every parameter to the new dtype/device. for (auto& parameter : named_parameters(/*recurse=*/false)) { diff --git a/src/module/module_list.h b/src/module/module_list.h index 3ae3cb45..ce0bfee6 100644 --- a/src/module/module_list.h +++ b/src/module/module_list.h @@ -1,9 +1,5 @@ #pragma once -#include -#include -#include - #include #include @@ -11,6 +7,16 @@ #include "module_holder.h" namespace llm { +namespace detail { +/// A type trait whose `value` member is true if `M` derives from `Module`. +template +using is_module = std::is_base_of>; + +template +using enable_if_module_t = std::enable_if_t::value, T>; + +} // namespace detail + /// A list of `Module`s that registers its elements. class ModuleListImpl : public Module { public: @@ -40,7 +46,7 @@ class ModuleListImpl : public Module { /// Adds a new `Module` to the `ModuleList` container, moving or copying /// it into a `shared_ptr` internally. This method allows passing value types, /// and letting the container deal with the boxing. - template > + template > void push_back(M&& module) { using Type = std::remove_reference_t; push_back(std::make_shared(std::forward(module))); @@ -78,7 +84,7 @@ class ModuleListImpl : public Module { /// match. 
template T& at(size_t index) { - static_assert(torch::detail::is_module::value, + static_assert(detail::is_module::value, "Can only call ModuleList::at with an nn::Module type"); TORCH_CHECK(index < size(), "Index out of range"); auto module = modules_[index]->as(); @@ -95,7 +101,7 @@ class ModuleListImpl : public Module { /// match. template const T& at(size_t index) const { - static_assert(torch::detail::is_module::value, + static_assert(detail::is_module::value, "Can only call ModuleList::at with an nn::Module type"); TORCH_CHECK(index < size(), "Index out of range"); const auto module = modules_[index]->as(); @@ -120,7 +126,7 @@ class ModuleListImpl : public Module { /// match. template std::shared_ptr ptr(size_t index) const { - static_assert(torch::detail::is_module::value, + static_assert(detail::is_module::value, "Can only call ModuleList::ptr with an nn::Module type"); TORCH_CHECK(index < size(), "Index out of range"); return std::dynamic_pointer_cast(modules_[index]); @@ -138,39 +144,6 @@ class ModuleListImpl : public Module { /// True if there are no modules in the `ModuleList`. bool is_empty() const noexcept { return size() == 0; } - void insert(size_t index, std::shared_ptr module) { - TORCH_CHECK(index <= size(), "Index out of range"); - - if (index == size()) - push_back(std::move(module)); - else { - modules_.insert(modules_.begin() + Iterator::difference_type(index), - std::move(module)); - - for (const auto i : c10::irange(index, size() - 1)) { - (void)i; // Suppress unused variable warning - replace_module(std::to_string(index), modules_[index]); - } - register_module(std::to_string(size() - 1), modules_.back()); - } - } - - /// Unwraps the contained module of a `ModuleHolder` and inserts it in the - /// `ModuleList`. 
- template - void insert(size_t index, const ModuleHolder& module_holder) { - insert(index, module_holder.ptr()); - } - - /// inserts a new `Module` to the `ModuleList` container, moving or copying - /// it into a `shared_ptr` internally. This method allows passing value types, - /// and letting the container deal with the boxing. - template > - void insert(size_t index, M&& module) { - using Type = std::remove_reference_t; - insert(index, std::make_shared(std::forward(module))); - } - private: template void push_back_var(Head&& head, Tail&&... tail) { From fd283c5e1d2a00c7e64fe801c61e8c076f5a930b Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Mon, 22 Sep 2025 21:37:30 -0700 Subject: [PATCH 02/14] added unittests --- src/module/CMakeLists.txt | 13 ++++ src/module/module.cpp | 19 ++++-- src/module/module.h | 5 +- src/module/module_test.cpp | 131 +++++++++++++++++++++++++++++++++++++ 4 files changed, 162 insertions(+), 6 deletions(-) create mode 100644 src/module/module_test.cpp diff --git a/src/module/CMakeLists.txt b/src/module/CMakeLists.txt index ceaab617..043f5de7 100644 --- a/src/module/CMakeLists.txt +++ b/src/module/CMakeLists.txt @@ -1,4 +1,5 @@ include(cc_library) +include(cc_test) cc_library( NAME @@ -13,3 +14,15 @@ cc_library( glog::glog torch ) + +cc_test( + NAME + module_test + SRCS + module_test.cpp + DEPS + :module + :state_dict + :gtest_main + torch +) diff --git a/src/module/module.cpp b/src/module/module.cpp index 94931b24..6a72f5eb 100644 --- a/src/module/module.cpp +++ b/src/module/module.cpp @@ -310,11 +310,20 @@ std::ostream& operator<<(std::ostream& stream, const Module& module) { // load weights from the checkpoint, override this method if necessary // NOLINTNEXTLINE(misc-no-recursion) -void Module::load(const StateDict& state_dict, const std::string& name_prefix) { +size_t Module::load(const StateDict& state_dict, + const std::string& name_prefix) { + size_t total_loaded = 0; // load parameters one by one for (auto& item : parameters_) { 
const auto& key = item.key(); auto& param = item.value(); + const auto& param_tensor = param.tensor; + // skip loading for undefined tensor + if (!param_tensor.defined()) { + // mark as loaded to pass verification + param.is_loaded = true; + continue; + } // clear the load status before loading param.is_loaded = false; @@ -323,14 +332,14 @@ void Module::load(const StateDict& state_dict, const std::string& name_prefix) { continue; } - const auto& param_tensor = param.tensor; if (param_tensor.sizes() == tensor.sizes()) { // copy data to the parameter tensor param_tensor.copy_(tensor); // mark as loaded param.is_loaded = true; + ++total_loaded; } else { - LOG(ERROR) << "Parameter size mismatch for " + LOG(ERROR) << "Size mismatch for parameter " << join_name(name_prefix, key) << ": expected " << param_tensor.sizes() << ", got " << tensor.sizes(); } @@ -344,8 +353,10 @@ void Module::load(const StateDict& state_dict, const std::string& name_prefix) { const auto& child = item.value(); // select state dict for the child module const auto child_state_dict = child.selector(state_dict, key); - child.module->load(child_state_dict, join_name(name_prefix, key)); + total_loaded += + child.module->load(child_state_dict, join_name(name_prefix, key)); } + return total_loaded; } // verify whether the weights are loaded, override this method if necessary diff --git a/src/module/module.h b/src/module/module.h index f99d08a5..daab3f90 100644 --- a/src/module/module.h +++ b/src/module/module.h @@ -151,8 +151,9 @@ class Module : public std::enable_shared_from_this { virtual void pretty_print(std::ostream& stream) const; // load weights from the checkpoint, override this method if necessary - virtual void load(const StateDict& state_dict, - const std::string& name_prefix = std::string()); + // returns the number of loaded parameters + virtual size_t load(const StateDict& state_dict, + const std::string& name_prefix = std::string()); // verify whether the weights are loaded, override this 
method if necessary virtual bool verify(const std::string& name_prefix = std::string()) const; diff --git a/src/module/module_test.cpp b/src/module/module_test.cpp new file mode 100644 index 00000000..475f6264 --- /dev/null +++ b/src/module/module_test.cpp @@ -0,0 +1,131 @@ +#include "module.h" + +#include +#include +#include + +namespace llm { + +class ParametersModel : public Module { + public: + ParametersModel(bool bias = true) { + // register some parameters + weight_ = register_parameter("weight", torch::randn({16, 16})); + if (bias) { + bias_ = register_parameter("bias", torch::randn({32, 32})); + } + + sharded_param_ = register_sharded_parameter("sharded_param", + /*dim=*/0, + /*rank=*/0, + /*world_size=*/2, + torch::randn({16, 32})); + undefined_param_ = register_parameter("undefined_param", torch::Tensor()); + } + + torch::Tensor weight_; + torch::Tensor bias_; + + torch::Tensor sharded_param_; + torch::Tensor undefined_param_; +}; + +class SubModel : public Module { + public: + SubModel() { + param1_ = register_parameter("param1", torch::randn({8, 8})); + param2_ = register_parameter("param2", torch::randn({8, 16})); + x_ = register_module("x", std::make_shared(/*bias=*/true)); + y_ = + register_module("y", std::make_shared(/*bias=*/false)); + } + + torch::Tensor param1_; + torch::Tensor param2_; + std::shared_ptr x_; + std::shared_ptr y_; +}; + +class Model : public Module { + public: + Model() { + param1_ = register_parameter("param1", torch::randn({8, 8})); + param2_ = register_parameter("param2", torch::randn({8, 16})); + sub_model_ = register_module("submodel", std::make_shared()); + } + + torch::Tensor param1_; + torch::Tensor param2_; + std::shared_ptr sub_model_; +}; + +TEST(ModuleTest, Parameters) { + std::unordered_map dict = { + {"weight", torch::randn({16, 16})}, + {"bias", torch::randn({32, 32})}, + {"param1", torch::randn({16, 16})}, + {"param2", torch::randn({16, 16})}, + {"sharded_param", torch::randn({32, 32})}, + {"undefined_param", 
torch::randn({32, 32})}}; + + StateDict state_dict(dict); + EXPECT_EQ(state_dict.size(), dict.size()); + + ParametersModel model; + EXPECT_EQ(model.load(state_dict), 3); + + EXPECT_TRUE(torch::equal(model.weight_, dict["weight"])); + EXPECT_TRUE(torch::equal(model.bias_, dict["bias"])); + EXPECT_TRUE(torch::equal(model.sharded_param_, + dict["sharded_param"].chunk(2, /*dim=*/0)[0])); + EXPECT_FALSE(model.undefined_param_.defined()); + + EXPECT_TRUE(model.verify()); +} + +TEST(ModuleTest, SubModules) { + std::unordered_map dict = { + {"param1", torch::randn({8, 8})}, + {"param2", torch::randn({8, 16})}, + {"submodel.param1", torch::randn({8, 8})}, + {"submodel.param2", torch::randn({8, 16})}, + {"submodel.x.weight", torch::randn({16, 16})}, + {"submodel.x.bias", torch::randn({32, 32})}, + {"submodel.x.sharded_param", torch::randn({32, 32})}, + {"submodel.x.undefined_param", torch::randn({32, 32})}, + {"submodel.y.weight", torch::randn({16, 16})}, + {"submodel.y.bias", torch::randn({32, 32})}, + {"submodel.y.sharded_param", torch::randn({32, 32})}, + {"submodel.y.undefined_param", torch::randn({32, 32})}}; + + StateDict state_dict(dict); + EXPECT_EQ(state_dict.size(), dict.size()); + + Model model; + EXPECT_EQ(model.load(state_dict), 9); + + EXPECT_TRUE(torch::equal(model.param1_, dict["param1"])); + EXPECT_TRUE(torch::equal(model.param2_, dict["param2"])); + // submodel + EXPECT_TRUE(torch::equal(model.sub_model_->param1_, dict["submodel.param1"])); + EXPECT_TRUE(torch::equal(model.sub_model_->param2_, dict["submodel.param2"])); + EXPECT_TRUE( + torch::equal(model.sub_model_->x_->weight_, dict["submodel.x.weight"])); + EXPECT_TRUE( + torch::equal(model.sub_model_->x_->bias_, dict["submodel.x.bias"])); + EXPECT_TRUE( + torch::equal(model.sub_model_->x_->sharded_param_, + dict["submodel.x.sharded_param"].chunk(2, /*dim=*/0)[0])); + EXPECT_FALSE(model.sub_model_->x_->undefined_param_.defined()); + EXPECT_TRUE( + torch::equal(model.sub_model_->y_->weight_, 
dict["submodel.y.weight"])); + EXPECT_FALSE(model.sub_model_->y_->bias_.defined()); + EXPECT_TRUE( + torch::equal(model.sub_model_->y_->sharded_param_, + dict["submodel.y.sharded_param"].chunk(2, /*dim=*/0)[0])); + EXPECT_FALSE(model.sub_model_->y_->undefined_param_.defined()); + + EXPECT_TRUE(model.verify()); +} + +} // namespace llm From 0897dbc55ec2e2971aba23c670c17ba7c97f40f3 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 23 Sep 2025 11:34:48 -0700 Subject: [PATCH 03/14] gpt2 load --- src/layers/embedding.h | 6 +- src/layers/fused_linear.cpp | 14 +-- src/layers/fused_linear.h | 4 +- src/layers/linear.cpp | 43 +++++---- src/layers/linear.h | 44 ++++----- src/layers/linear_impl.h | 2 + src/model_loader/state_dict.cpp | 10 ++ src/model_loader/state_dict.h | 4 + src/models/causal_lm.h | 8 +- src/models/gpt2.h | 160 +++++++++----------------------- src/models/models.h | 32 +++---- src/module/module.cpp | 19 +++- src/module/module.h | 3 +- 13 files changed, 157 insertions(+), 192 deletions(-) diff --git a/src/layers/embedding.h b/src/layers/embedding.h index 9abe231d..f54b4766 100644 --- a/src/layers/embedding.h +++ b/src/layers/embedding.h @@ -73,6 +73,7 @@ class ParallelEmbeddingImpl : public Module { const ParallelArgs& parallel_args, const torch::TensorOptions& options) : parallel_args_(parallel_args) { + const auto rank = parallel_args_.rank(); const auto world_size = parallel_args_.world_size(); CHECK(embedding_dim % world_size == 0) << "out_features " << embedding_dim << " not divisible by world_size " @@ -80,8 +81,11 @@ class ParallelEmbeddingImpl : public Module { const int64_t embedding_dim_per_partition = embedding_dim / world_size; // register the weight parameter - weight_ = register_parameter( + weight_ = register_sharded_parameter( "weight", + /*dim=*/1, + rank, + world_size, torch::empty({num_embeddings, embedding_dim_per_partition}, options)); } diff --git a/src/layers/fused_linear.cpp b/src/layers/fused_linear.cpp index 917ab6ec..c529d7bc 
100644 --- a/src/layers/fused_linear.cpp +++ b/src/layers/fused_linear.cpp @@ -24,13 +24,13 @@ FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl( // fused linear layer const int64_t out_features = std::accumulate( out_features_vec.begin(), out_features_vec.end(), int64_t(0)); - fused_linear_ = ColumnParallelLinear(in_features, - out_features, - bias, - gather_output, - quant_args, - parallel_args, - options); + fused_linear_ = LegacyColumnParallelLinear(in_features, + out_features, + bias, + gather_output, + quant_args, + parallel_args, + options); // calculate split sizes split_sizes_.reserve(out_features_vec.size()); const auto world_size = parallel_args.world_size(); diff --git a/src/layers/fused_linear.h b/src/layers/fused_linear.h index 191922e6..ccc8d1a9 100644 --- a/src/layers/fused_linear.h +++ b/src/layers/fused_linear.h @@ -35,10 +35,10 @@ class FusedColumnParallelLinearImpl : public Module { private: // non-fused linear layers - std::vector parallel_linears_; + std::vector parallel_linears_; // fused linear layer - ColumnParallelLinear fused_linear_{nullptr}; + LegacyColumnParallelLinear fused_linear_{nullptr}; // sizes for each split std::vector split_sizes_; diff --git a/src/layers/linear.cpp b/src/layers/linear.cpp index 70a6d6f8..e35fc9b7 100644 --- a/src/layers/linear.cpp +++ b/src/layers/linear.cpp @@ -215,13 +215,14 @@ std::shared_ptr create_row_parallel_linear( // construct a ColumnParallelLinear. // chose right implementation based on the args. 
-ColumnParallelLinear::ColumnParallelLinear(int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) +LegacyColumnParallelLinear::LegacyColumnParallelLinear( + int64_t in_features, + int64_t out_features, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) : ModuleHolder(create_column_parallel_linear(in_features, out_features, bias, @@ -230,12 +231,13 @@ ColumnParallelLinear::ColumnParallelLinear(int64_t in_features, parallel_args, options)) {} -ColumnParallelLinear::ColumnParallelLinear(int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) +LegacyColumnParallelLinear::LegacyColumnParallelLinear( + int64_t in_features, + int64_t out_features, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) : ModuleHolder(create_column_parallel_linear(in_features, out_features, bias, @@ -246,13 +248,14 @@ ColumnParallelLinear::ColumnParallelLinear(int64_t in_features, // construct a rotary positional embedding. // chose right implementation based on the args. 
-RowParallelLinear::RowParallelLinear(int64_t in_features, - int64_t out_features, - bool bias, - bool input_is_parallelized, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) +LegacyRowParallelLinear::LegacyRowParallelLinear( + int64_t in_features, + int64_t out_features, + bool bias, + bool input_is_parallelized, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) : ModuleHolder(create_row_parallel_linear(in_features, out_features, bias, diff --git a/src/layers/linear.h b/src/layers/linear.h index 99da103b..a9145e5d 100644 --- a/src/layers/linear.h +++ b/src/layers/linear.h @@ -39,42 +39,42 @@ class ParallelLinearImpl : public Module { } }; -class ColumnParallelLinear : public ModuleHolder { +class LegacyColumnParallelLinear : public ModuleHolder { public: using ModuleHolder::ModuleHolder; using Impl [[maybe_unused]] = ParallelLinearImpl; // construct a rotary positional embedding. // chose right implementation based on the args. 
- ColumnParallelLinear(int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); + LegacyColumnParallelLinear(int64_t in_features, + int64_t out_features, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); - ColumnParallelLinear(int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); + LegacyColumnParallelLinear(int64_t in_features, + int64_t out_features, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); }; -class RowParallelLinear : public ModuleHolder { +class LegacyRowParallelLinear : public ModuleHolder { public: using ModuleHolder::ModuleHolder; using Impl [[maybe_unused]] = ParallelLinearImpl; // construct a rotary positional embedding. // chose right implementation based on the args. - RowParallelLinear(int64_t in_features, - int64_t out_features, - bool bias, - bool input_is_parallelized, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); + LegacyRowParallelLinear(int64_t in_features, + int64_t out_features, + bool bias, + bool input_is_parallelized, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); }; } // namespace llm diff --git a/src/layers/linear_impl.h b/src/layers/linear_impl.h index fe1f6b7c..62bcb7a9 100644 --- a/src/layers/linear_impl.h +++ b/src/layers/linear_impl.h @@ -62,6 +62,7 @@ class ColumnParallelLinearImpl : public ParallelLinearImpl { // parallel args ParallelArgs parallel_args_; }; +LLM_MODULE(ColumnParallelLinear); // Linear layer with row parallelism. // The linear layer is defined as Y = XA + b. 
A is parallelized along @@ -115,4 +116,5 @@ class RowParallelLinearImpl : public ParallelLinearImpl { // parallel args ParallelArgs parallel_args_; }; +LLM_MODULE(RowParallelLinear); } // namespace llm diff --git a/src/model_loader/state_dict.cpp b/src/model_loader/state_dict.cpp index bdfc96b9..f20f9109 100644 --- a/src/model_loader/state_dict.cpp +++ b/src/model_loader/state_dict.cpp @@ -219,4 +219,14 @@ StateDict StateDict::select_with_transform( return selected; } +StateDict StateDict::select_with_transform( + const std::string& prefix, + std::function transform_func) const { + return select_with_transform( + prefix, + [transform_func](const std::string&, const torch::Tensor& tensor) { + return transform_func(tensor); + }); +} + } // namespace llm diff --git a/src/model_loader/state_dict.h b/src/model_loader/state_dict.h index 58c51039..d130c53e 100644 --- a/src/model_loader/state_dict.h +++ b/src/model_loader/state_dict.h @@ -46,6 +46,10 @@ class StateDict final { StateDict select_with_transform(const std::string& prefix, TensorTransform transform_func) const; + StateDict select_with_transform( + const std::string& prefix, + std::function transform_func) const; + size_t size() const { return dict_.size(); } std::string_view prefix() const { return prefix_; } diff --git a/src/models/causal_lm.h b/src/models/causal_lm.h index 02a54e29..eaed0ffe 100644 --- a/src/models/causal_lm.h +++ b/src/models/causal_lm.h @@ -70,11 +70,15 @@ class CausalLMImpl : public CausalLM { } void load_state_dict(const StateDict& state_dict) override { - model_->load_state_dict(state_dict); + model_->load(state_dict); } void verify_loaded_weights() const override { - return model_->verify_loaded_weights(); + bool success = model_->verify(); + if (!success) { + LOG(FATAL) << "Failed to verify loaded weights for the model." 
+ << " Please check the log for more details."; + } } torch::Device device() const override { return options_.device(); } diff --git a/src/models/gpt2.h b/src/models/gpt2.h index afde1084..98be0cf7 100644 --- a/src/models/gpt2.h +++ b/src/models/gpt2.h @@ -8,6 +8,7 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear.h" +#include "layers/linear_impl.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" @@ -21,6 +22,17 @@ namespace llm::hf { +namespace detail { +// GPT-2 implementation uses Conv1D instead of Linear. As a result, we +// need to transpose the weight for linear layer when loading from checkpoint. +static StateDict transpose_selector(const StateDict& sd, + const std::string& key) { + // transpose the weight + return sd.select_with_transform( + key + ".", [](const torch::Tensor& tensor) { return tensor.t(); }); +} +} // namespace detail + class GPT2MLPImpl : public Module { public: GPT2MLPImpl(const ModelArgs& args, @@ -35,51 +47,32 @@ class GPT2MLPImpl : public Module { // register the weight parameter c_fc_ = register_module("c_fc", - ColumnParallelLinear(hidden_size, - intermediate_size, - /*bias=*/true, - /*gather_output=*/false, - quant_args, - parallel_args, - options)); - c_proj_ = register_module("c_proj", - RowParallelLinear(intermediate_size, + LegacyColumnParallelLinear(hidden_size, + intermediate_size, + /*bias=*/true, + /*gather_output=*/false, + quant_args, + parallel_args, + options), + detail::transpose_selector); + c_proj_ = + register_module("c_proj", + LegacyRowParallelLinear(intermediate_size, hidden_size, /*bias=*/true, /*input_is_parallelized=*/true, quant_args, parallel_args, - options)); + options), + detail::transpose_selector); } torch::Tensor forward(torch::Tensor x) { return c_proj_(act_(c_fc_(x))); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict 
function - // GPT-2 implementation uses Conv1D instead of Linear. As a result, we - // need to transpose the weight. - c_fc_->load_state_dict(state_dict.select_with_transform( - "c_fc.", - [](const std::string_view& /*name*/, const torch::Tensor& tensor) { - return tensor.t(); - })); - c_proj_->load_state_dict(state_dict.select_with_transform( - "c_proj.", - [](const std::string_view& /*name*/, const torch::Tensor& tensor) { - return tensor.t(); - })); - } - - void verify_loaded_weights(const std::string& prefix) const { - c_fc_->verify_loaded_weights(prefix + "c_fc."); - c_proj_->verify_loaded_weights(prefix + "c_proj."); - } - private: // parameter members, must be registered - ColumnParallelLinear c_fc_{nullptr}; - RowParallelLinear c_proj_{nullptr}; + LegacyColumnParallelLinear c_fc_{nullptr}; + LegacyRowParallelLinear c_proj_{nullptr}; ActFunc act_{nullptr}; }; @@ -98,23 +91,27 @@ class GPT2AttentionImpl : public Module { head_dim_ = args.head_dim(); // register submodules - c_attn_ = register_module("c_attn", - ColumnParallelLinear(hidden_size_, + c_attn_ = + register_module("c_attn", + LegacyColumnParallelLinear(hidden_size_, 3 * hidden_size_, /*bias=*/true, /*gather_output=*/false, quant_args, parallel_args, - options)); + options), + detail::transpose_selector); - c_proj_ = register_module("c_proj", - RowParallelLinear(hidden_size_, + c_proj_ = + register_module("c_proj", + LegacyRowParallelLinear(hidden_size_, hidden_size_, /*bias=*/true, /*input_is_parallelized=*/true, quant_args, parallel_args, - options)); + options), + detail::transpose_selector); // initialize attention atten_ = register_module( @@ -138,32 +135,11 @@ class GPT2AttentionImpl : public Module { return c_proj_(output); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // GPT-2 implementation uses Conv1D instead of Linear. As a result, we - // need to transpose the weight. 
- c_attn_->load_state_dict(state_dict.select_with_transform( - "c_attn.", - [](const std::string_view& /*name*/, const torch::Tensor& tensor) { - return tensor.t(); - })); - c_proj_->load_state_dict(state_dict.select_with_transform( - "c_proj.", - [](const std::string_view& /*name*/, const torch::Tensor& tensor) { - return tensor.t(); - })); - } - - void verify_loaded_weights(const std::string& prefix) const { - c_attn_->verify_loaded_weights(prefix + "c_attn."); - c_proj_->verify_loaded_weights(prefix + "c_proj."); - } - private: // parameter members, must be registered - ColumnParallelLinear c_attn_{nullptr}; + LegacyColumnParallelLinear c_attn_{nullptr}; - RowParallelLinear c_proj_{nullptr}; + LegacyRowParallelLinear c_proj_{nullptr}; // module members without parameters Attention atten_{nullptr}; @@ -207,22 +183,6 @@ class GPT2BlockImpl : public Module { return h + mlp_(ln_2_(h)); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - attn_->load_state_dict(state_dict.select("attn.")); - mlp_->load_state_dict(state_dict.select("mlp.")); - ln_1_->load_state_dict(state_dict.select("ln_1.")); - ln_2_->load_state_dict(state_dict.select("ln_2.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - attn_->verify_loaded_weights(prefix + "attn."); - mlp_->verify_loaded_weights(prefix + "mlp."); - ln_1_->verify_loaded_weights(prefix + "ln_1."); - ln_2_->verify_loaded_weights(prefix + "ln_2."); - } - private: // parameter members, must be registered GPT2Attention attn_{nullptr}; @@ -282,27 +242,6 @@ class GPT2ModelImpl : public Module { return ln_f_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - wte_->load_state_dict(state_dict.select("wte.")); - wpe_->load_state_dict(state_dict.select("wpe.")); - // call each layer's load_state_dict function - for (int i = 0; i < layers_.size(); i++) { - 
layers_[i]->load_state_dict( - state_dict.select("h." + std::to_string(i) + ".")); - } - ln_f_->load_state_dict(state_dict.select("ln_f.")); - } - - void verify_loaded_weights() const { - wte_->verify_loaded_weights("wte."); - wpe_->verify_loaded_weights("wpe."); - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->verify_loaded_weights("h." + std::to_string(i) + "."); - } - ln_f_->verify_loaded_weights("ln_f."); - } - private: // parameter members, must be registered ParallelEmbedding wte_{nullptr}; @@ -328,10 +267,13 @@ class GPT2ForCausalLMImpl : public Module { const ParallelArgs& parallel_args, const torch::TensorOptions& options) { // register submodules - model_ = register_module( - "model", GPT2Model(args, quant_args, parallel_args, options)); + model_ = + register_module("model", + GPT2Model(args, quant_args, parallel_args, options), + /*selector=*/nullptr); - lm_head_ = register_module("lm_head", + // load wte.weight + lm_head_ = register_module("wte", ColumnParallelLinear(args.hidden_size(), args.vocab_size(), /*bias=*/false, @@ -363,18 +305,6 @@ class GPT2ForCausalLMImpl : public Module { return lm_head_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - model_->load_state_dict(state_dict); - // TODO: share wte_ weight with lm_head_ to save memory - lm_head_->load_state_dict(state_dict.select("wte.")); - } - - void verify_loaded_weights() const { - model_->verify_loaded_weights(); - lm_head_->verify_loaded_weights("wte."); - } - private: // parameter members, must be registered GPT2Model model_{nullptr}; diff --git a/src/models/models.h b/src/models/models.h index e569821e..9dfabab8 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -1,19 +1,19 @@ #pragma once // list all registered models here -#include "aquila.h" // IWYU pragma: keep -#include "baichuan.h" // IWYU pargma: keep -#include "bloom.h" // IWYU pragma: keep -#include "chatglm.h" // IWYU pragma: keep -#include "gemma.h" // 
IWYU pragma: keep -#include "gemma2.h" // IWYU pragma: keep -#include "gpt2.h" // IWYU pragma: keep -#include "gpt_j.h" // IWYU pragma: keep -#include "gpt_neox.h" // IWYU pragma: keep -#include "internlm.h" // IWYU pragma: keep -#include "llama.h" // IWYU pragma: keep -#include "mistral.h" // IWYU pragma: keep -#include "mpt.h" // IWYU pragma: keep -#include "phi.h" // IWYU pragma: keep -#include "qwen.h" // IWYU pragma: keep -#include "qwen2.h" // IWYU pragma: keep +// #include "aquila.h" // IWYU pragma: keep +// #include "baichuan.h" // IWYU pargma: keep +// #include "bloom.h" // IWYU pragma: keep +// #include "chatglm.h" // IWYU pragma: keep +// #include "gemma.h" // IWYU pragma: keep +// #include "gemma2.h" // IWYU pragma: keep +#include "gpt2.h" // IWYU pragma: keep +// #include "gpt_j.h" // IWYU pragma: keep +// #include "gpt_neox.h" // IWYU pragma: keep +// #include "internlm.h" // IWYU pragma: keep +// #include "llama.h" // IWYU pragma: keep +// #include "mistral.h" // IWYU pragma: keep +// #include "mpt.h" // IWYU pragma: keep +// #include "phi.h" // IWYU pragma: keep +// #include "qwen.h" // IWYU pragma: keep +// #include "qwen2.h" // IWYU pragma: keep diff --git a/src/module/module.cpp b/src/module/module.cpp index 6a72f5eb..cb245114 100644 --- a/src/module/module.cpp +++ b/src/module/module.cpp @@ -329,10 +329,13 @@ size_t Module::load(const StateDict& state_dict, param.is_loaded = false; const auto tensor = param.loader(state_dict, key); if (!tensor.defined()) { + LOG(ERROR) << "Missing parameter: " << join_name(name_prefix, key); continue; } if (param_tensor.sizes() == tensor.sizes()) { + LOG(INFO) << "Loading parameter: " << join_name(name_prefix, key) + << " of size " << tensor.sizes(); // copy data to the parameter tensor param_tensor.copy_(tensor); // mark as loaded @@ -351,10 +354,14 @@ size_t Module::load(const StateDict& state_dict, for (const auto& item : children_) { const auto& key = item.key(); const auto& child = item.value(); - // select 
state dict for the child module - const auto child_state_dict = child.selector(state_dict, key); - total_loaded += - child.module->load(child_state_dict, join_name(name_prefix, key)); + if (child.selector) { + // select state dict for the child module + const auto child_state_dict = child.selector(state_dict, key); + total_loaded += + child.module->load(child_state_dict, join_name(name_prefix, key)); + } else { + total_loaded += child.module->load(state_dict, name_prefix); + } } return total_loaded; } @@ -375,7 +382,9 @@ bool Module::verify(const std::string& name_prefix) const { for (const auto& item : children_) { const auto& key = item.key(); const auto& child = item.value(); - const bool child_loaded = child.module->verify(join_name(name_prefix, key)); + const std::string prefix = + child.selector ? join_name(name_prefix, key) : name_prefix; + const bool child_loaded = child.module->verify(prefix); all_loaded = all_loaded && child_loaded; } return all_loaded; diff --git a/src/module/module.h b/src/module/module.h index daab3f90..6cb96ff6 100644 --- a/src/module/module.h +++ b/src/module/module.h @@ -296,8 +296,7 @@ std::shared_ptr Module::register_module( std::string name, std::shared_ptr module) { auto default_selector = [](const StateDict& sd, const std::string& key) { - const std::string prefix = key + '.'; - return sd.select(prefix); + return sd.select(key + "."); }; return register_module(std::move(name), std::move(module), default_selector); } From 024386554ca440902f4c679d2157fd0c9594c351 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 23 Sep 2025 13:31:37 -0700 Subject: [PATCH 04/14] llama & gemma wip --- src/layers/fused_linear.cpp | 8 +- src/layers/fused_linear.h | 18 ++-- src/layers/qkv_linear.cpp | 17 ++-- src/layers/qkv_linear.h | 2 +- src/models/aquila.h | 4 +- src/models/baichuan.h | 4 +- src/models/gemma.h | 30 +++---- src/models/gemma2.h | 4 +- src/models/internlm.h | 4 +- src/models/llama.h | 160 +++++++++++++----------------------- 
src/models/mistral.h | 4 +- src/models/models.h | 4 +- src/models/qwen2.h | 4 +- src/models/simple_model.h | 4 +- src/module/module.cpp | 8 +- 15 files changed, 117 insertions(+), 158 deletions(-) diff --git a/src/layers/fused_linear.cpp b/src/layers/fused_linear.cpp index c529d7bc..d88cd565 100644 --- a/src/layers/fused_linear.cpp +++ b/src/layers/fused_linear.cpp @@ -10,7 +10,7 @@ namespace llm { -FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl( +LegacyFusedColumnParallelLinearImpl::LegacyFusedColumnParallelLinearImpl( int64_t in_features, const std::vector& out_features_vec, bool bias, @@ -55,7 +55,7 @@ FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl( } } -std::vector FusedColumnParallelLinearImpl::forward( +std::vector LegacyFusedColumnParallelLinearImpl::forward( torch::Tensor input) { if (fused_) { auto fused_output = fused_linear_->forward(input); @@ -72,7 +72,7 @@ std::vector FusedColumnParallelLinearImpl::forward( return outputs; } -void FusedColumnParallelLinearImpl::load_state_dict( +void LegacyFusedColumnParallelLinearImpl::load_state_dict( const StateDict& state_dict, const std::vector& prefixes) { if (fused_) { @@ -85,7 +85,7 @@ void FusedColumnParallelLinearImpl::load_state_dict( } } -void FusedColumnParallelLinearImpl::verify_loaded_weights( +void LegacyFusedColumnParallelLinearImpl::verify_loaded_weights( const std::string& prefix) const { if (fused_) { fused_linear_->verify_loaded_weights(prefix); diff --git a/src/layers/fused_linear.h b/src/layers/fused_linear.h index ccc8d1a9..bbb08fea 100644 --- a/src/layers/fused_linear.h +++ b/src/layers/fused_linear.h @@ -12,15 +12,15 @@ namespace llm { -class FusedColumnParallelLinearImpl : public Module { +class LegacyFusedColumnParallelLinearImpl : public Module { public: - FusedColumnParallelLinearImpl(int64_t in_features, - const std::vector& out_features, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const 
torch::TensorOptions& options); + LegacyFusedColumnParallelLinearImpl(int64_t in_features, + const std::vector& out_features, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); std::vector forward(torch::Tensor input); @@ -46,6 +46,6 @@ class FusedColumnParallelLinearImpl : public Module { // whether the linear layer is fused bool fused_ = false; }; -LLM_MODULE(FusedColumnParallelLinear); +LLM_MODULE(LegacyFusedColumnParallelLinear); } // namespace llm diff --git a/src/layers/qkv_linear.cpp b/src/layers/qkv_linear.cpp index fe4eab14..ab8ba299 100644 --- a/src/layers/qkv_linear.cpp +++ b/src/layers/qkv_linear.cpp @@ -36,14 +36,15 @@ QKVColumnParallelLinearImpl::QKVColumnParallelLinearImpl( effective_kv_heads * head_dim, effective_kv_heads * head_dim}; - parallel_linear_ = register_module("parallel_linear", - FusedColumnParallelLinear(hidden_size, - out_features, - bias, - gather_output, - quant_args, - parallel_args, - options)); + parallel_linear_ = + register_module("parallel_linear", + LegacyFusedColumnParallelLinear(hidden_size, + out_features, + bias, + gather_output, + quant_args, + parallel_args, + options)); } // special load_state_dict for fused cases diff --git a/src/layers/qkv_linear.h b/src/layers/qkv_linear.h index 2e4e2e39..b5e34871 100644 --- a/src/layers/qkv_linear.h +++ b/src/layers/qkv_linear.h @@ -40,7 +40,7 @@ class QKVColumnParallelLinearImpl : public Module { } private: - FusedColumnParallelLinear parallel_linear_{nullptr}; + LegacyFusedColumnParallelLinear parallel_linear_{nullptr}; // replication ratio of kv heads for MQA/GQA cases int64_t kv_replication_ratio_ = 0; diff --git a/src/models/aquila.h b/src/models/aquila.h index 59e7b3e4..6ab4ed6e 100644 --- a/src/models/aquila.h +++ b/src/models/aquila.h @@ -36,7 +36,7 @@ class AquilaMLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - 
FusedColumnParallelLinear( + LegacyFusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, /*bias=*/false, @@ -74,7 +74,7 @@ class AquilaMLPImpl : public Module { private: // parameter members, must be registered - FusedColumnParallelLinear gate_up_proj_{nullptr}; + LegacyFusedColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; // activation function diff --git a/src/models/baichuan.h b/src/models/baichuan.h index db77fcdc..62f122d4 100644 --- a/src/models/baichuan.h +++ b/src/models/baichuan.h @@ -46,7 +46,7 @@ class BaichuanMLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - FusedColumnParallelLinear( + LegacyFusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, /*bias=*/false, @@ -84,7 +84,7 @@ class BaichuanMLPImpl : public Module { private: // parameter members, must be registered - FusedColumnParallelLinear gate_up_proj_{nullptr}; + LegacyFusedColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; // activation function diff --git a/src/models/gemma.h b/src/models/gemma.h index dcb2c692..f8870cd1 100644 --- a/src/models/gemma.h +++ b/src/models/gemma.h @@ -11,6 +11,7 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear.h" +#include "layers/linear_impl.h" #include "layers/normalization.h" #include "layers/qkv_linear.h" #include "memory/kv_cache.h" @@ -39,7 +40,7 @@ class GemmaMLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - FusedColumnParallelLinear( + LegacyFusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, /*bias=*/false, @@ -49,13 +50,13 @@ class GemmaMLPImpl : public Module { options)); down_proj_ = register_module("down_proj", - RowParallelLinear(intermediate_size, - hidden_size, - /*bias=*/false, - 
/*input_is_parallelized=*/true, - quant_args, - parallel_args, - options)); + LegacyRowParallelLinear(intermediate_size, + hidden_size, + /*bias=*/false, + /*input_is_parallelized=*/true, + quant_args, + parallel_args, + options)); } torch::Tensor forward(torch::Tensor x) { @@ -77,8 +78,8 @@ class GemmaMLPImpl : public Module { private: // parameter members, must be registered - FusedColumnParallelLinear gate_up_proj_{nullptr}; - RowParallelLinear down_proj_{nullptr}; + LegacyFusedColumnParallelLinear gate_up_proj_{nullptr}; + LegacyRowParallelLinear down_proj_{nullptr}; // activation function ActFunc act_func_{nullptr}; @@ -113,8 +114,9 @@ class GemmaAttentionImpl : public Module { parallel_args, options)); - o_proj_ = register_module("o_proj", - RowParallelLinear(n_heads * head_dim, + o_proj_ = + register_module("o_proj", + LegacyRowParallelLinear(n_heads * head_dim, hidden_size, /*bias=*/false, /*input_is_parallelized=*/true, @@ -158,7 +160,7 @@ class GemmaAttentionImpl : public Module { // parameter members, must be registered QKVColumnParallelLinear qkv_proj_{nullptr}; - RowParallelLinear o_proj_{nullptr}; + LegacyRowParallelLinear o_proj_{nullptr}; // module members without parameters Attention atten_{nullptr}; @@ -337,7 +339,7 @@ class GemmaForCausalLMImpl : public Module { model_ = register_module( "model", GemmaModel(args, quant_args, parallel_args, options)); - lm_head_ = register_module("lm_head", + lm_head_ = register_module("model.embed_tokens", ColumnParallelLinear(args.hidden_size(), args.vocab_size(), /*bias=*/false, diff --git a/src/models/gemma2.h b/src/models/gemma2.h index 6282af6e..23e35eff 100644 --- a/src/models/gemma2.h +++ b/src/models/gemma2.h @@ -39,7 +39,7 @@ class Gemma2MLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - FusedColumnParallelLinear( + LegacyFusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, /*bias=*/false, @@ -77,7 +77,7 @@ 
class Gemma2MLPImpl : public Module { private: // parameter members, must be registered - FusedColumnParallelLinear gate_up_proj_{nullptr}; + LegacyFusedColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; // activation function diff --git a/src/models/internlm.h b/src/models/internlm.h index 8f100c32..3bea7989 100644 --- a/src/models/internlm.h +++ b/src/models/internlm.h @@ -35,7 +35,7 @@ class InternlmMLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - FusedColumnParallelLinear( + LegacyFusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, /*bias=*/false, @@ -73,7 +73,7 @@ class InternlmMLPImpl : public Module { private: // parameter members, must be registered - FusedColumnParallelLinear gate_up_proj_{nullptr}; + LegacyFusedColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; // calculate act(x) * y diff --git a/src/models/llama.h b/src/models/llama.h index 120ff48d..f8aaf696 100644 --- a/src/models/llama.h +++ b/src/models/llama.h @@ -10,6 +10,7 @@ #include "layers/embedding.h" #include "layers/fused_linear.h" #include "layers/linear.h" +#include "layers/linear_impl.h" #include "layers/normalization.h" #include "layers/qkv_linear.h" #include "memory/kv_cache.h" @@ -35,26 +36,26 @@ class LlamaMLPImpl : public Module { const int64_t intermediate_size = args.intermediate_size(); // register the weight parameter - gate_up_proj_ = register_module( - "gate_up_proj", - FusedColumnParallelLinear( - hidden_size, - std::vector{intermediate_size, intermediate_size}, - /*bias=*/false, - /*gather_output=*/false, - quant_args, - parallel_args, - options)); + // gate_up_proj_ = register_module( + // "gate_up_proj", + // LegacyFusedColumnParallelLinear( + // hidden_size, + // std::vector{intermediate_size, intermediate_size}, + // /*bias=*/false, + // /*gather_output=*/false, + // quant_args, + // parallel_args, + // 
options)); down_proj_ = register_module("down_proj", - RowParallelLinear(intermediate_size, - hidden_size, - /*bias=*/false, - /*input_is_parallelized=*/true, - quant_args, - parallel_args, - options)); + LegacyRowParallelLinear(intermediate_size, + hidden_size, + /*bias=*/false, + /*input_is_parallelized=*/true, + quant_args, + parallel_args, + options)); } torch::Tensor forward(torch::Tensor x) { @@ -62,22 +63,22 @@ class LlamaMLPImpl : public Module { return down_proj_(act_func_(gate_up[0]) * gate_up[1]); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - gate_up_proj_->load_state_dict(state_dict, {"gate_proj.", "up_proj."}); - down_proj_->load_state_dict(state_dict.select("down_proj.")); - } + // // load the weight from the checkpoint + // void load_state_dict(const StateDict& state_dict) { + // // call each submodule's load_state_dict function + // gate_up_proj_->load_state_dict(state_dict, {"gate_proj.", "up_proj."}); + // down_proj_->load_state_dict(state_dict.select("down_proj.")); + // } - void verify_loaded_weights(const std::string& prefix) const { - gate_up_proj_->verify_loaded_weights(prefix + "[gate_proj,up_proj]."); - down_proj_->verify_loaded_weights(prefix + "down_proj."); - } + // void verify_loaded_weights(const std::string& prefix) const { + // gate_up_proj_->verify_loaded_weights(prefix + "[gate_proj,up_proj]."); + // down_proj_->verify_loaded_weights(prefix + "down_proj."); + // } private: // parameter members, must be registered - FusedColumnParallelLinear gate_up_proj_{nullptr}; - RowParallelLinear down_proj_{nullptr}; + LegacyFusedColumnParallelLinear gate_up_proj_{nullptr}; + LegacyRowParallelLinear down_proj_{nullptr}; // activation function ActFunc act_func_{nullptr}; @@ -101,19 +102,20 @@ class LlamaAttentionImpl : public Module { std::max(1, n_kv_heads / world_size); // register submodules - qkv_proj_ = register_module("qkv_proj", - 
QKVColumnParallelLinear(hidden_size, - n_heads, - n_kv_heads, - head_dim, - /*bias=*/false, - /*gather_output=*/false, - quant_args, - parallel_args, - options)); - - o_proj_ = register_module("o_proj", - RowParallelLinear(hidden_size, + // qkv_proj_ = register_module("qkv_proj", + // QKVColumnParallelLinear(hidden_size, + // n_heads, + // n_kv_heads, + // head_dim, + // /*bias=*/false, + // /*gather_output=*/false, + // quant_args, + // parallel_args, + // options)); + + o_proj_ = + register_module("o_proj", + LegacyRowParallelLinear(hidden_size, hidden_size, /*bias=*/false, /*input_is_parallelized=*/true, @@ -140,24 +142,25 @@ class LlamaAttentionImpl : public Module { return o_proj_(output); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - qkv_proj_->load_state_dict( - state_dict, {"q_proj.", "k_proj.", "v_proj."}, {"k_proj.", "v_proj."}); - o_proj_->load_state_dict(state_dict.select("o_proj.")); - } + // // load the weight from the checkpoint + // void load_state_dict(const StateDict& state_dict) { + // // call each submodule's load_state_dict function + // qkv_proj_->load_state_dict( + // state_dict, {"q_proj.", "k_proj.", "v_proj."}, {"k_proj.", + // "v_proj."}); + // o_proj_->load_state_dict(state_dict.select("o_proj.")); + // } - void verify_loaded_weights(const std::string& prefix) const { - qkv_proj_->verify_loaded_weights(prefix + "[q_proj,k_proj,v_proj]."); - o_proj_->verify_loaded_weights(prefix + "o_proj."); - } + // void verify_loaded_weights(const std::string& prefix) const { + // qkv_proj_->verify_loaded_weights(prefix + "[q_proj,k_proj,v_proj]."); + // o_proj_->verify_loaded_weights(prefix + "o_proj."); + // } private: // parameter members, must be registered QKVColumnParallelLinear qkv_proj_{nullptr}; - RowParallelLinear o_proj_{nullptr}; + LegacyRowParallelLinear o_proj_{nullptr}; // module members without parameters Attention atten_{nullptr}; 
@@ -197,24 +200,6 @@ class LlamaDecoderLayerImpl : public Module { return h + mlp_(post_attention_layernorm_(h)); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - self_attn_->load_state_dict(state_dict.select("self_attn.")); - mlp_->load_state_dict(state_dict.select("mlp.")); - input_layernorm_->load_state_dict(state_dict.select("input_layernorm.")); - post_attention_layernorm_->load_state_dict( - state_dict.select("post_attention_layernorm.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - self_attn_->verify_loaded_weights(prefix + "self_attn."); - mlp_->verify_loaded_weights(prefix + "mlp."); - input_layernorm_->verify_loaded_weights(prefix + "input_layernorm."); - post_attention_layernorm_->verify_loaded_weights( - prefix + "post_attention_layernorm."); - } - private: // parameter members, must be registered LlamaAttention self_attn_{nullptr}; @@ -270,26 +255,6 @@ class LlamaModelImpl : public Module { return norm_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - embed_tokens_->load_state_dict(state_dict.select("embed_tokens.")); - // call each layer's load_state_dict function - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->load_state_dict( - state_dict.select("layers." + std::to_string(i) + ".")); - } - norm_->load_state_dict(state_dict.select("norm.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - embed_tokens_->verify_loaded_weights(prefix + "embed_tokens."); - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->verify_loaded_weights(prefix + "layers." 
+ std::to_string(i) + - "."); - } - norm_->verify_loaded_weights(prefix + "norm."); - } - private: // parameter members, must be registered ParallelEmbedding embed_tokens_{nullptr}; @@ -347,17 +312,6 @@ class LlamaForCausalLMImpl : public Module { return lm_head_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - model_->load_state_dict(state_dict.select("model.")); - lm_head_->load_state_dict(state_dict.select("lm_head.")); - } - - void verify_loaded_weights() const { - model_->verify_loaded_weights("model."); - lm_head_->verify_loaded_weights("lm_head."); - } - private: // parameter members, must be registered LlamaModel model_{nullptr}; diff --git a/src/models/mistral.h b/src/models/mistral.h index 9b8e8ad0..a5822482 100644 --- a/src/models/mistral.h +++ b/src/models/mistral.h @@ -34,7 +34,7 @@ class MistralMLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - FusedColumnParallelLinear( + LegacyFusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, /*bias=*/false, @@ -72,7 +72,7 @@ class MistralMLPImpl : public Module { private: // parameter members, must be registered - FusedColumnParallelLinear gate_up_proj_{nullptr}; + LegacyFusedColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; ActFunc act_func_{nullptr}; diff --git a/src/models/models.h b/src/models/models.h index 9dfabab8..8c37049b 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -5,9 +5,9 @@ // #include "baichuan.h" // IWYU pargma: keep // #include "bloom.h" // IWYU pragma: keep // #include "chatglm.h" // IWYU pragma: keep -// #include "gemma.h" // IWYU pragma: keep +#include "gemma.h" // IWYU pragma: keep // #include "gemma2.h" // IWYU pragma: keep -#include "gpt2.h" // IWYU pragma: keep +// #include "gpt2.h" // IWYU pragma: keep // #include "gpt_j.h" // IWYU pragma: keep // #include "gpt_neox.h" // IWYU pragma: keep 
// #include "internlm.h" // IWYU pragma: keep diff --git a/src/models/qwen2.h b/src/models/qwen2.h index d93de88a..ba540730 100644 --- a/src/models/qwen2.h +++ b/src/models/qwen2.h @@ -40,7 +40,7 @@ class QWen2MLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - FusedColumnParallelLinear( + LegacyFusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, /*bias=*/false, @@ -78,7 +78,7 @@ class QWen2MLPImpl : public Module { private: // parameter members, must be registered - FusedColumnParallelLinear gate_up_proj_{nullptr}; + LegacyFusedColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; // activation function diff --git a/src/models/simple_model.h b/src/models/simple_model.h index dc22d3a0..82ee5d26 100644 --- a/src/models/simple_model.h +++ b/src/models/simple_model.h @@ -31,7 +31,7 @@ class SimpleMLPImpl : public Module { gate_up_proj_ = register_module( "gate_up_proj", - FusedColumnParallelLinear( + LegacyFusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, false, @@ -65,7 +65,7 @@ class SimpleMLPImpl : public Module { } private: - FusedColumnParallelLinear gate_up_proj_{nullptr}; + LegacyFusedColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; ActFunc act_func_{nullptr}; diff --git a/src/module/module.cpp b/src/module/module.cpp index cb245114..3e78f3b4 100644 --- a/src/module/module.cpp +++ b/src/module/module.cpp @@ -325,14 +325,16 @@ size_t Module::load(const StateDict& state_dict, continue; } - // clear the load status before loading - param.is_loaded = false; const auto tensor = param.loader(state_dict, key); if (!tensor.defined()) { - LOG(ERROR) << "Missing parameter: " << join_name(name_prefix, key); continue; } + if (param.is_loaded) { + LOG(WARNING) << "Parameter " << join_name(name_prefix, key) + << " is already loaded"; + } + if (param_tensor.sizes() == 
tensor.sizes()) { LOG(INFO) << "Loading parameter: " << join_name(name_prefix, key) << " of size " << tensor.sizes(); From 2ad3ca8bc41c8cb2dca1b3c421a4b7e6d050a1fe Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 23 Sep 2025 14:44:03 -0700 Subject: [PATCH 05/14] fix gemma --- src/layers/fused_linear.cpp | 89 ++++++++++++++++++++++++++++++ src/layers/fused_linear.h | 41 ++++++++++++++ src/layers/qkv_linear.cpp | 70 ++++++++++++------------ src/layers/qkv_linear.h | 20 ++----- src/models/gemma.h | 105 +++++++----------------------------- src/models/models.h | 2 +- 6 files changed, 190 insertions(+), 137 deletions(-) diff --git a/src/layers/fused_linear.cpp b/src/layers/fused_linear.cpp index d88cd565..be8fb399 100644 --- a/src/layers/fused_linear.cpp +++ b/src/layers/fused_linear.cpp @@ -96,4 +96,93 @@ void LegacyFusedColumnParallelLinearImpl::verify_loaded_weights( } } +FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl( + int64_t in_features, + const std::vector& out_features_vec, + const std::vector& prefixes, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + prefixes_ = prefixes; + // check if the linear layers can be fused + fused_ = quant_args.can_be_fused(); + if (fused_) { + // fused linear layer + const int64_t out_features = std::accumulate( + out_features_vec.begin(), out_features_vec.end(), int64_t(0)); + fused_linear_ = LegacyColumnParallelLinear(in_features, + out_features, + bias, + gather_output, + quant_args, + parallel_args, + options); + // calculate split sizes + split_sizes_.reserve(out_features_vec.size()); + const auto world_size = parallel_args.world_size(); + for (const auto& out_features : out_features_vec) { + CHECK(out_features % world_size == 0) + << "out_features " << out_features << " not divisible by world_size " + << world_size; + split_sizes_.push_back(out_features / world_size); + } + } else { + // non-fused linear layers + 
parallel_linears_.reserve(out_features_vec.size()); + for (const auto& out_features : out_features_vec) { + parallel_linears_.emplace_back(in_features, + out_features, + bias, + gather_output, + quant_args, + parallel_args, + options); + } + } +} + +std::vector FusedColumnParallelLinearImpl::forward( + torch::Tensor input) { + if (fused_) { + auto fused_output = fused_linear_->forward(input); + return fused_output.split(split_sizes_, /*dim=*/1); + } + + // otherwise, use the non-fused linear layers + std::vector outputs; + outputs.reserve(parallel_linears_.size()); + for (auto& parallel_linear : parallel_linears_) { + auto output = parallel_linear->forward(input); + outputs.push_back(output); + } + return outputs; +} + +size_t FusedColumnParallelLinearImpl::load(const StateDict& state_dict, + const std::string&) { + if (fused_) { + fused_linear_->load_state_dict(state_dict, prefixes_); + } else { + CHECK_EQ(parallel_linears_.size(), prefixes_.size()); + for (size_t i = 0; i < parallel_linears_.size(); ++i) { + parallel_linears_[i]->load_state_dict(state_dict.select(prefixes_[i])); + } + } + return 0; +} + +bool FusedColumnParallelLinearImpl::verify( + const std::string& name_prefix) const { + if (fused_) { + fused_linear_->verify_loaded_weights(name_prefix); + } else { + for (const auto& parallel_linear : parallel_linears_) { + parallel_linear->verify_loaded_weights(name_prefix); + } + } + return true; +} + } // namespace llm diff --git a/src/layers/fused_linear.h b/src/layers/fused_linear.h index bbb08fea..c2c10935 100644 --- a/src/layers/fused_linear.h +++ b/src/layers/fused_linear.h @@ -48,4 +48,45 @@ class LegacyFusedColumnParallelLinearImpl : public Module { }; LLM_MODULE(LegacyFusedColumnParallelLinear); +class FusedColumnParallelLinearImpl : public Module { + public: + FusedColumnParallelLinearImpl(int64_t in_features, + const std::vector& out_features, + const std::vector& prefixes, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const 
ParallelArgs& parallel_args, + const torch::TensorOptions& options); + + std::vector forward(torch::Tensor input); + + // load weights from the checkpoint, override this method if necessary + // returns the number of loaded parameters + size_t load(const StateDict& state_dict, + const std::string& name_prefix = std::string()) override; + + // verify whether the weights are loaded, override this method if necessary + bool verify(const std::string& name_prefix = std::string()) const override; + + // whether the linear layer is fused + bool fused() const { return fused_; } + + private: + // non-fused linear layers + std::vector parallel_linears_; + + // fused linear layer + LegacyColumnParallelLinear fused_linear_{nullptr}; + + // sizes for each split + std::vector split_sizes_; + + std::vector prefixes_; + + // whether the linear layer is fused + bool fused_ = false; +}; +LLM_MODULE(FusedColumnParallelLinear); + } // namespace llm diff --git a/src/layers/qkv_linear.cpp b/src/layers/qkv_linear.cpp index ab8ba299..054a45e2 100644 --- a/src/layers/qkv_linear.cpp +++ b/src/layers/qkv_linear.cpp @@ -12,13 +12,19 @@ QKVColumnParallelLinearImpl::QKVColumnParallelLinearImpl( int64_t head_dim, bool bias, bool gather_output, + const std::vector& prefixes, const QuantArgs& quant_args, const ParallelArgs& parallel_args, - const torch::TensorOptions& options) - : n_kv_heads_(n_kv_heads), head_dim_(head_dim) { + const torch::TensorOptions& options) { + CHECK_EQ(prefixes.size(), 3) + << "prefixes size must be 3 for q, k, v projections"; + // calculate logical kv heads with support of MQA/GQA const int32_t world_size = parallel_args.world_size(); int64_t effective_kv_heads = n_kv_heads; + // replication ratio of kv heads for MQA/GQA cases + int64_t kv_replication_ratio = 0; + if (n_kv_heads >= world_size) { // partition kv heads evenly across world_size for MHA CHECK_EQ(n_kv_heads % world_size, 0) @@ -27,7 +33,7 @@ QKVColumnParallelLinearImpl::QKVColumnParallelLinearImpl( // 
replicate kv heads evenly across world_size for GQA/MQA CHECK_EQ(world_size % n_kv_heads, 0) << "kv heads can't be replicated evenly across world_size"; - kv_replication_ratio_ = world_size / n_kv_heads; + kv_replication_ratio = world_size / n_kv_heads; effective_kv_heads = world_size; } @@ -36,45 +42,43 @@ QKVColumnParallelLinearImpl::QKVColumnParallelLinearImpl( effective_kv_heads * head_dim, effective_kv_heads * head_dim}; - parallel_linear_ = - register_module("parallel_linear", - LegacyFusedColumnParallelLinear(hidden_size, - out_features, - bias, - gather_output, - quant_args, - parallel_args, - options)); -} - -// special load_state_dict for fused cases -void QKVColumnParallelLinearImpl::load_state_dict( - const StateDict& state_dict, - const std::vector& prefixes, - const std::vector& kv_prefixes) { - if (kv_replication_ratio_ > 1) { - // replicate kv heads - auto kv_replicated_state_dict = state_dict.select_with_transform( - "", [&](const std::string& name, const torch::Tensor& tensor) { - for (const auto& kv_prefix : kv_prefixes) { - if (absl::StartsWith(name, kv_prefix)) { + // create state_dict selector to handle MQA/GQA cases + // for MQA/GQA cases, we need to replicate the weights of kv heads. + auto state_dict_selector = [=](const StateDict& sd, const std::string&) { + if (kv_replication_ratio <= 1) { + return sd.select(""); + } + // replicate kv heads for MQA/GQA cases + return sd.select_with_transform( + "", [=](const std::string& tensor_name, const torch::Tensor& tensor) { + // skip query weights + for (size_t i = 1; i < prefixes.size(); ++i) { + const auto& kv_prefix = prefixes[i]; + if (absl::StartsWith(tensor_name, kv_prefix)) { // reshape to [n_kv_heads, head_dim, ...] 
- auto reshaped_tensor = - tensor.reshape({n_kv_heads_, head_dim_, -1}); + auto reshaped_tensor = tensor.reshape({n_kv_heads, head_dim, -1}); // interleave repeat kv heads along kv_head dim reshaped_tensor = reshaped_tensor.repeat_interleave( - kv_replication_ratio_, /*dim=*/0); + kv_replication_ratio, /*dim=*/0); // reshape to [n_kv_heads * kv_replication_ratio * head_dim, ...] return reshaped_tensor.reshape( - {n_kv_heads_ * kv_replication_ratio_ * head_dim_, -1}); + {n_kv_heads * kv_replication_ratio * head_dim, -1}); } } return tensor; }); - parallel_linear_->load_state_dict(kv_replicated_state_dict, prefixes); - } else { - parallel_linear_->load_state_dict(state_dict, prefixes); - } + }; + + parallel_linear_ = register_module("qkv_parallel_linear", + FusedColumnParallelLinear(hidden_size, + out_features, + prefixes, + bias, + gather_output, + quant_args, + parallel_args, + options), + state_dict_selector); } } // namespace llm diff --git a/src/layers/qkv_linear.h b/src/layers/qkv_linear.h index b5e34871..89dd2c6c 100644 --- a/src/layers/qkv_linear.h +++ b/src/layers/qkv_linear.h @@ -22,6 +22,7 @@ class QKVColumnParallelLinearImpl : public Module { int64_t head_dim, bool bias, bool gather_output, + const std::vector& prefixes, const QuantArgs& quant_args, const ParallelArgs& parallel_args, const torch::TensorOptions& options); @@ -30,24 +31,9 @@ class QKVColumnParallelLinearImpl : public Module { return parallel_linear_->forward(input); } - // special load_state_dict for fused cases - void load_state_dict(const StateDict& state_dict, - const std::vector& prefixes, - const std::vector& kv_prefixes); - - void verify_loaded_weights(const std::string& prefix = "") const { - parallel_linear_->verify_loaded_weights(prefix); - } - private: - LegacyFusedColumnParallelLinear parallel_linear_{nullptr}; - - // replication ratio of kv heads for MQA/GQA cases - int64_t kv_replication_ratio_ = 0; - - int64_t n_kv_heads_ = 0; - - int64_t head_dim_ = 0; + // registered modules 
+ FusedColumnParallelLinear parallel_linear_{nullptr}; }; LLM_MODULE(QKVColumnParallelLinear); diff --git a/src/models/gemma.h b/src/models/gemma.h index f8870cd1..687dd5b2 100644 --- a/src/models/gemma.h +++ b/src/models/gemma.h @@ -40,14 +40,16 @@ class GemmaMLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - LegacyFusedColumnParallelLinear( + FusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, + std::vector{"gate_proj.", "up_proj."}, /*bias=*/false, /*gather_output=*/false, quant_args, parallel_args, - options)); + options), + /*selector=*/nullptr); down_proj_ = register_module("down_proj", LegacyRowParallelLinear(intermediate_size, @@ -64,21 +66,9 @@ class GemmaMLPImpl : public Module { return down_proj_(act_func_(gate_up[0]) * gate_up[1]); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - gate_up_proj_->load_state_dict(state_dict, {"gate_proj.", "up_proj."}); - down_proj_->load_state_dict(state_dict.select("down_proj.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - gate_up_proj_->verify_loaded_weights(prefix + "[gate_proj,up_proj]."); - down_proj_->verify_loaded_weights(prefix + "down_proj."); - } - private: // parameter members, must be registered - LegacyFusedColumnParallelLinear gate_up_proj_{nullptr}; + FusedColumnParallelLinear gate_up_proj_{nullptr}; LegacyRowParallelLinear down_proj_{nullptr}; // activation function @@ -103,16 +93,20 @@ class GemmaAttentionImpl : public Module { std::max(1, n_kv_heads / world_size); // register submodules - qkv_proj_ = register_module("qkv_proj", - QKVColumnParallelLinear(hidden_size, - n_heads, - n_kv_heads, - head_dim, - /*bias=*/false, - /*gather_output=*/false, - quant_args, - parallel_args, - options)); + qkv_proj_ = register_module( + "qkv_proj", + QKVColumnParallelLinear( + hidden_size, + 
n_heads, + n_kv_heads, + head_dim, + /*bias=*/false, + /*gather_output=*/false, + std::vector{"q_proj.", "k_proj.", "v_proj."}, + quant_args, + parallel_args, + options), + /*selector=*/nullptr); o_proj_ = register_module("o_proj", @@ -143,19 +137,6 @@ class GemmaAttentionImpl : public Module { return o_proj_(output); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - qkv_proj_->load_state_dict( - state_dict, {"q_proj.", "k_proj.", "v_proj."}, {"k_proj.", "v_proj."}); - o_proj_->load_state_dict(state_dict.select("o_proj.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - qkv_proj_->verify_loaded_weights(prefix + "[q_proj,k_proj,v_proj]."); - o_proj_->verify_loaded_weights(prefix + "o_proj."); - } - private: // parameter members, must be registered QKVColumnParallelLinear qkv_proj_{nullptr}; @@ -209,21 +190,6 @@ class GemmaDecoderLayerImpl : public Module { hidden_states += residual; return hidden_states; } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - input_layernorm_->load_state_dict((state_dict.select("input_layernorm."))); - mlp_->load_state_dict(state_dict.select("mlp.")); - post_attention_layernorm_->load_state_dict( - (state_dict.select("post_attention_layernorm."))); - self_attn_->load_state_dict(state_dict.select("self_attn.")); - } - void verify_loaded_weights(const std::string& prefix) const { - self_attn_->verify_loaded_weights(prefix + "self_attn."); - mlp_->verify_loaded_weights(prefix + "mlp."); - input_layernorm_->verify_loaded_weights(prefix + "input_layernorm."); - post_attention_layernorm_->verify_loaded_weights( - prefix + "post_attention_layernorm."); - } private: GemmaAttention self_attn_{nullptr}; @@ -288,26 +254,6 @@ class GemmaModelImpl : public Module { return norm_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - 
embed_tokens_->load_state_dict(state_dict.select("embed_tokens.")); - // call each layer's load_state_dict function - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->load_state_dict( - state_dict.select("layers." + std::to_string(i) + ".")); - } - norm_->load_state_dict((state_dict.select("norm."))); - } - - void verify_loaded_weights(const std::string& prefix) const { - embed_tokens_->verify_loaded_weights(prefix + "embed_tokens."); - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) + - "."); - } - norm_->verify_loaded_weights(prefix + "norm."); - } - private: ModelArgs modelArgs_; @@ -372,19 +318,6 @@ class GemmaForCausalLMImpl : public Module { return lm_head_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - model_->load_state_dict(state_dict.select("model.")); - - // Share the embedding weights with the final llm_head layer. - lm_head_->load_state_dict(state_dict.select("model.embed_tokens.")); - } - - void verify_loaded_weights() const { - model_->verify_loaded_weights("model."); - lm_head_->verify_loaded_weights("model.embed_tokens."); - } - private: // parameter members, must be registered GemmaModel model_{nullptr}; diff --git a/src/models/models.h b/src/models/models.h index 8c37049b..f0a5d887 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -7,7 +7,7 @@ // #include "chatglm.h" // IWYU pragma: keep #include "gemma.h" // IWYU pragma: keep // #include "gemma2.h" // IWYU pragma: keep -// #include "gpt2.h" // IWYU pragma: keep +#include "gpt2.h" // IWYU pragma: keep // #include "gpt_j.h" // IWYU pragma: keep // #include "gpt_neox.h" // IWYU pragma: keep // #include "internlm.h" // IWYU pragma: keep From d7f631b2f324b6ee40526842a38497f2a7e621db Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 23 Sep 2025 15:05:44 -0700 Subject: [PATCH 06/14] update --- src/layers/fused_linear.cpp | 28 ++++----- 
src/layers/fused_linear.h | 8 +-- src/layers/linear.cpp | 43 ++++++------- src/layers/linear.h | 44 +++++++------- src/layers/linear_impl.h | 2 - src/models/gemma.h | 23 ++++--- src/models/gemma2.h | 118 ++++++------------------------------ src/models/gpt2.h | 43 ++++++------- src/models/llama.h | 23 ++++--- src/models/models.h | 6 +- 10 files changed, 125 insertions(+), 213 deletions(-) diff --git a/src/layers/fused_linear.cpp b/src/layers/fused_linear.cpp index be8fb399..6d305d8d 100644 --- a/src/layers/fused_linear.cpp +++ b/src/layers/fused_linear.cpp @@ -24,13 +24,13 @@ LegacyFusedColumnParallelLinearImpl::LegacyFusedColumnParallelLinearImpl( // fused linear layer const int64_t out_features = std::accumulate( out_features_vec.begin(), out_features_vec.end(), int64_t(0)); - fused_linear_ = LegacyColumnParallelLinear(in_features, - out_features, - bias, - gather_output, - quant_args, - parallel_args, - options); + fused_linear_ = ColumnParallelLinear(in_features, + out_features, + bias, + gather_output, + quant_args, + parallel_args, + options); // calculate split sizes split_sizes_.reserve(out_features_vec.size()); const auto world_size = parallel_args.world_size(); @@ -112,13 +112,13 @@ FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl( // fused linear layer const int64_t out_features = std::accumulate( out_features_vec.begin(), out_features_vec.end(), int64_t(0)); - fused_linear_ = LegacyColumnParallelLinear(in_features, - out_features, - bias, - gather_output, - quant_args, - parallel_args, - options); + fused_linear_ = ColumnParallelLinear(in_features, + out_features, + bias, + gather_output, + quant_args, + parallel_args, + options); // calculate split sizes split_sizes_.reserve(out_features_vec.size()); const auto world_size = parallel_args.world_size(); diff --git a/src/layers/fused_linear.h b/src/layers/fused_linear.h index c2c10935..6a975163 100644 --- a/src/layers/fused_linear.h +++ b/src/layers/fused_linear.h @@ -35,10 +35,10 @@ class 
LegacyFusedColumnParallelLinearImpl : public Module { private: // non-fused linear layers - std::vector parallel_linears_; + std::vector parallel_linears_; // fused linear layer - LegacyColumnParallelLinear fused_linear_{nullptr}; + ColumnParallelLinear fused_linear_{nullptr}; // sizes for each split std::vector split_sizes_; @@ -74,10 +74,10 @@ class FusedColumnParallelLinearImpl : public Module { private: // non-fused linear layers - std::vector parallel_linears_; + std::vector parallel_linears_; // fused linear layer - LegacyColumnParallelLinear fused_linear_{nullptr}; + ColumnParallelLinear fused_linear_{nullptr}; // sizes for each split std::vector split_sizes_; diff --git a/src/layers/linear.cpp b/src/layers/linear.cpp index e35fc9b7..70a6d6f8 100644 --- a/src/layers/linear.cpp +++ b/src/layers/linear.cpp @@ -215,14 +215,13 @@ std::shared_ptr create_row_parallel_linear( // construct a ColumnParallelLinear. // chose right implementation based on the args. -LegacyColumnParallelLinear::LegacyColumnParallelLinear( - int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) +ColumnParallelLinear::ColumnParallelLinear(int64_t in_features, + int64_t out_features, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) : ModuleHolder(create_column_parallel_linear(in_features, out_features, bias, @@ -231,13 +230,12 @@ LegacyColumnParallelLinear::LegacyColumnParallelLinear( parallel_args, options)) {} -LegacyColumnParallelLinear::LegacyColumnParallelLinear( - int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) +ColumnParallelLinear::ColumnParallelLinear(int64_t in_features, + int64_t out_features, + bool bias, + bool gather_output, + const ParallelArgs& 
parallel_args, + const torch::TensorOptions& options) : ModuleHolder(create_column_parallel_linear(in_features, out_features, bias, @@ -248,14 +246,13 @@ LegacyColumnParallelLinear::LegacyColumnParallelLinear( // construct a rotary positional embedding. // chose right implementation based on the args. -LegacyRowParallelLinear::LegacyRowParallelLinear( - int64_t in_features, - int64_t out_features, - bool bias, - bool input_is_parallelized, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) +RowParallelLinear::RowParallelLinear(int64_t in_features, + int64_t out_features, + bool bias, + bool input_is_parallelized, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) : ModuleHolder(create_row_parallel_linear(in_features, out_features, bias, diff --git a/src/layers/linear.h b/src/layers/linear.h index a9145e5d..99da103b 100644 --- a/src/layers/linear.h +++ b/src/layers/linear.h @@ -39,42 +39,42 @@ class ParallelLinearImpl : public Module { } }; -class LegacyColumnParallelLinear : public ModuleHolder { +class ColumnParallelLinear : public ModuleHolder { public: using ModuleHolder::ModuleHolder; using Impl [[maybe_unused]] = ParallelLinearImpl; // construct a rotary positional embedding. // chose right implementation based on the args. 
- LegacyColumnParallelLinear(int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); + ColumnParallelLinear(int64_t in_features, + int64_t out_features, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); - LegacyColumnParallelLinear(int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); + ColumnParallelLinear(int64_t in_features, + int64_t out_features, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); }; -class LegacyRowParallelLinear : public ModuleHolder { +class RowParallelLinear : public ModuleHolder { public: using ModuleHolder::ModuleHolder; using Impl [[maybe_unused]] = ParallelLinearImpl; // construct a rotary positional embedding. // chose right implementation based on the args. - LegacyRowParallelLinear(int64_t in_features, - int64_t out_features, - bool bias, - bool input_is_parallelized, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); + RowParallelLinear(int64_t in_features, + int64_t out_features, + bool bias, + bool input_is_parallelized, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); }; } // namespace llm diff --git a/src/layers/linear_impl.h b/src/layers/linear_impl.h index 62bcb7a9..fe1f6b7c 100644 --- a/src/layers/linear_impl.h +++ b/src/layers/linear_impl.h @@ -62,7 +62,6 @@ class ColumnParallelLinearImpl : public ParallelLinearImpl { // parallel args ParallelArgs parallel_args_; }; -LLM_MODULE(ColumnParallelLinear); // Linear layer with row parallelism. // The linear layer is defined as Y = XA + b. 
A is parallelized along @@ -116,5 +115,4 @@ class RowParallelLinearImpl : public ParallelLinearImpl { // parallel args ParallelArgs parallel_args_; }; -LLM_MODULE(RowParallelLinear); } // namespace llm diff --git a/src/models/gemma.h b/src/models/gemma.h index 687dd5b2..f994d23b 100644 --- a/src/models/gemma.h +++ b/src/models/gemma.h @@ -52,13 +52,13 @@ class GemmaMLPImpl : public Module { /*selector=*/nullptr); down_proj_ = register_module("down_proj", - LegacyRowParallelLinear(intermediate_size, - hidden_size, - /*bias=*/false, - /*input_is_parallelized=*/true, - quant_args, - parallel_args, - options)); + RowParallelLinear(intermediate_size, + hidden_size, + /*bias=*/false, + /*input_is_parallelized=*/true, + quant_args, + parallel_args, + options)); } torch::Tensor forward(torch::Tensor x) { @@ -69,7 +69,7 @@ class GemmaMLPImpl : public Module { private: // parameter members, must be registered FusedColumnParallelLinear gate_up_proj_{nullptr}; - LegacyRowParallelLinear down_proj_{nullptr}; + RowParallelLinear down_proj_{nullptr}; // activation function ActFunc act_func_{nullptr}; @@ -108,9 +108,8 @@ class GemmaAttentionImpl : public Module { options), /*selector=*/nullptr); - o_proj_ = - register_module("o_proj", - LegacyRowParallelLinear(n_heads * head_dim, + o_proj_ = register_module("o_proj", + RowParallelLinear(n_heads * head_dim, hidden_size, /*bias=*/false, /*input_is_parallelized=*/true, @@ -141,7 +140,7 @@ class GemmaAttentionImpl : public Module { // parameter members, must be registered QKVColumnParallelLinear qkv_proj_{nullptr}; - LegacyRowParallelLinear o_proj_{nullptr}; + RowParallelLinear o_proj_{nullptr}; // module members without parameters Attention atten_{nullptr}; diff --git a/src/models/gemma2.h b/src/models/gemma2.h index 23e35eff..98b0202a 100644 --- a/src/models/gemma2.h +++ b/src/models/gemma2.h @@ -39,14 +39,16 @@ class Gemma2MLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - 
LegacyFusedColumnParallelLinear( + FusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, + std::vector{"gate_proj.", "up_proj."}, /*bias=*/false, /*gather_output=*/false, quant_args, parallel_args, - options)); + options), + /*selector=*/nullptr); down_proj_ = register_module("down_proj", RowParallelLinear(intermediate_size, @@ -63,21 +65,9 @@ class Gemma2MLPImpl : public Module { return down_proj_(act_func_(gate_up[0]) * gate_up[1]); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - gate_up_proj_->load_state_dict(state_dict, {"gate_proj.", "up_proj."}); - down_proj_->load_state_dict(state_dict.select("down_proj.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - gate_up_proj_->verify_loaded_weights(prefix + "[gate_proj,up_proj]."); - down_proj_->verify_loaded_weights(prefix + "down_proj."); - } - private: // parameter members, must be registered - LegacyFusedColumnParallelLinear gate_up_proj_{nullptr}; + FusedColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; // activation function @@ -103,16 +93,20 @@ class Gemma2AttentionImpl : public Module { std::max(1, n_kv_heads / world_size); // register submodules - qkv_proj_ = register_module("qkv_proj", - QKVColumnParallelLinear(hidden_size, - n_heads, - n_kv_heads, - head_dim, - args.attn_bias(), - /*gather_output=*/false, - quant_args, - parallel_args, - options)); + qkv_proj_ = register_module( + "qkv_proj", + QKVColumnParallelLinear( + hidden_size, + n_heads, + n_kv_heads, + head_dim, + args.attn_bias(), + /*gather_output=*/false, + std::vector{"q_proj.", "k_proj.", "v_proj."}, + quant_args, + parallel_args, + options), + /*selector=*/nullptr); o_proj_ = register_module("o_proj", RowParallelLinear(n_heads * head_dim, @@ -146,19 +140,6 @@ class Gemma2AttentionImpl : public Module { return o_proj_(output); } - // load the 
weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - qkv_proj_->load_state_dict( - state_dict, {"q_proj.", "k_proj.", "v_proj."}, {"k_proj.", "v_proj."}); - o_proj_->load_state_dict(state_dict.select("o_proj.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - qkv_proj_->verify_loaded_weights(prefix + "[q_proj,k_proj,v_proj]."); - o_proj_->verify_loaded_weights(prefix + "o_proj."); - } - private: // parameter members, must be registered QKVColumnParallelLinear qkv_proj_{nullptr}; @@ -167,9 +148,6 @@ class Gemma2AttentionImpl : public Module { // module members without parameters Attention atten_{nullptr}; - - // size for q, k, v - std::vector qkv_sizes_; }; LLM_MODULE(Gemma2Attention); @@ -226,29 +204,6 @@ class Gemma2DecoderLayerImpl : public Module { hidden_states += residual; return hidden_states; } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - input_layernorm_->load_state_dict((state_dict.select("input_layernorm."))); - self_attn_->load_state_dict(state_dict.select("self_attn.")); - mlp_->load_state_dict(state_dict.select("mlp.")); - post_attention_layernorm_->load_state_dict( - (state_dict.select("post_attention_layernorm."))); - pre_feedforward_layernorm_->load_state_dict( - state_dict.select("pre_feedforward_layernorm.")); - post_feedforward_layernorm_->load_state_dict( - state_dict.select("post_feedforward_layernorm.")); - } - void verify_loaded_weights(const std::string& prefix) const { - input_layernorm_->verify_loaded_weights(prefix + "input_layernorm."); - self_attn_->verify_loaded_weights(prefix + "self_attn."); - mlp_->verify_loaded_weights(prefix + "mlp."); - post_attention_layernorm_->verify_loaded_weights( - prefix + "post_attention_layernorm."); - pre_feedforward_layernorm_->verify_loaded_weights( - prefix + "pre_feedforward_layernorm."); - post_feedforward_layernorm_->verify_loaded_weights( - 
prefix + "post_feedforward_layernorm."); - } private: Gemma2Attention self_attn_{nullptr}; @@ -322,26 +277,6 @@ class Gemma2ModelImpl : public Module { return norm_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - embed_tokens_->load_state_dict(state_dict.select("embed_tokens.")); - // call each layer's load_state_dict function - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->load_state_dict( - state_dict.select("layers." + std::to_string(i) + ".")); - } - norm_->load_state_dict((state_dict.select("norm."))); - } - - void verify_loaded_weights(const std::string& prefix) const { - embed_tokens_->verify_loaded_weights(prefix + "embed_tokens."); - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) + - "."); - } - norm_->verify_loaded_weights(prefix + "norm."); - } - private: ModelArgs modelArgs_; @@ -375,7 +310,7 @@ class Gemma2ForCausalLMImpl : public Module { model_ = register_module( "model", Gemma2Model(args, quant_args, parallel_args, options)); - lm_head_ = register_module("lm_head", + lm_head_ = register_module("model.embed_tokens", ColumnParallelLinear(args.hidden_size(), args.vocab_size(), /*bias=*/false, @@ -409,19 +344,6 @@ class Gemma2ForCausalLMImpl : public Module { final_logit_soft_cap_; } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - model_->load_state_dict(state_dict.select("model.")); - - // Share the embedding weights with the final llm_head layer. 
- lm_head_->load_state_dict(state_dict.select("model.embed_tokens.")); - } - - void verify_loaded_weights() const { - model_->verify_loaded_weights("model."); - lm_head_->verify_loaded_weights("model.embed_tokens."); - } - private: // parameter members, must be registered Gemma2Model model_{nullptr}; diff --git a/src/models/gpt2.h b/src/models/gpt2.h index 98be0cf7..2644ec48 100644 --- a/src/models/gpt2.h +++ b/src/models/gpt2.h @@ -47,32 +47,31 @@ class GPT2MLPImpl : public Module { // register the weight parameter c_fc_ = register_module("c_fc", - LegacyColumnParallelLinear(hidden_size, - intermediate_size, - /*bias=*/true, - /*gather_output=*/false, - quant_args, - parallel_args, - options), + ColumnParallelLinear(hidden_size, + intermediate_size, + /*bias=*/true, + /*gather_output=*/false, + quant_args, + parallel_args, + options), detail::transpose_selector); - c_proj_ = - register_module("c_proj", - LegacyRowParallelLinear(intermediate_size, + c_proj_ = register_module("c_proj", + RowParallelLinear(intermediate_size, hidden_size, /*bias=*/true, /*input_is_parallelized=*/true, quant_args, parallel_args, options), - detail::transpose_selector); + detail::transpose_selector); } torch::Tensor forward(torch::Tensor x) { return c_proj_(act_(c_fc_(x))); } private: // parameter members, must be registered - LegacyColumnParallelLinear c_fc_{nullptr}; - LegacyRowParallelLinear c_proj_{nullptr}; + ColumnParallelLinear c_fc_{nullptr}; + RowParallelLinear c_proj_{nullptr}; ActFunc act_{nullptr}; }; @@ -91,27 +90,25 @@ class GPT2AttentionImpl : public Module { head_dim_ = args.head_dim(); // register submodules - c_attn_ = - register_module("c_attn", - LegacyColumnParallelLinear(hidden_size_, + c_attn_ = register_module("c_attn", + ColumnParallelLinear(hidden_size_, 3 * hidden_size_, /*bias=*/true, /*gather_output=*/false, quant_args, parallel_args, options), - detail::transpose_selector); + detail::transpose_selector); - c_proj_ = - register_module("c_proj", - 
LegacyRowParallelLinear(hidden_size_, + c_proj_ = register_module("c_proj", + RowParallelLinear(hidden_size_, hidden_size_, /*bias=*/true, /*input_is_parallelized=*/true, quant_args, parallel_args, options), - detail::transpose_selector); + detail::transpose_selector); // initialize attention atten_ = register_module( @@ -137,9 +134,9 @@ class GPT2AttentionImpl : public Module { private: // parameter members, must be registered - LegacyColumnParallelLinear c_attn_{nullptr}; + ColumnParallelLinear c_attn_{nullptr}; - LegacyRowParallelLinear c_proj_{nullptr}; + RowParallelLinear c_proj_{nullptr}; // module members without parameters Attention atten_{nullptr}; diff --git a/src/models/llama.h b/src/models/llama.h index f8aaf696..229da58f 100644 --- a/src/models/llama.h +++ b/src/models/llama.h @@ -49,13 +49,13 @@ class LlamaMLPImpl : public Module { down_proj_ = register_module("down_proj", - LegacyRowParallelLinear(intermediate_size, - hidden_size, - /*bias=*/false, - /*input_is_parallelized=*/true, - quant_args, - parallel_args, - options)); + RowParallelLinear(intermediate_size, + hidden_size, + /*bias=*/false, + /*input_is_parallelized=*/true, + quant_args, + parallel_args, + options)); } torch::Tensor forward(torch::Tensor x) { @@ -78,7 +78,7 @@ class LlamaMLPImpl : public Module { private: // parameter members, must be registered LegacyFusedColumnParallelLinear gate_up_proj_{nullptr}; - LegacyRowParallelLinear down_proj_{nullptr}; + RowParallelLinear down_proj_{nullptr}; // activation function ActFunc act_func_{nullptr}; @@ -113,9 +113,8 @@ class LlamaAttentionImpl : public Module { // parallel_args, // options)); - o_proj_ = - register_module("o_proj", - LegacyRowParallelLinear(hidden_size, + o_proj_ = register_module("o_proj", + RowParallelLinear(hidden_size, hidden_size, /*bias=*/false, /*input_is_parallelized=*/true, @@ -160,7 +159,7 @@ class LlamaAttentionImpl : public Module { // parameter members, must be registered QKVColumnParallelLinear 
qkv_proj_{nullptr}; - LegacyRowParallelLinear o_proj_{nullptr}; + RowParallelLinear o_proj_{nullptr}; // module members without parameters Attention atten_{nullptr}; diff --git a/src/models/models.h b/src/models/models.h index f0a5d887..1d79341d 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -5,9 +5,9 @@ // #include "baichuan.h" // IWYU pargma: keep // #include "bloom.h" // IWYU pragma: keep // #include "chatglm.h" // IWYU pragma: keep -#include "gemma.h" // IWYU pragma: keep -// #include "gemma2.h" // IWYU pragma: keep -#include "gpt2.h" // IWYU pragma: keep +#include "gemma.h" // IWYU pragma: keep +#include "gemma2.h" // IWYU pragma: keep +#include "gpt2.h" // IWYU pragma: keep // #include "gpt_j.h" // IWYU pragma: keep // #include "gpt_neox.h" // IWYU pragma: keep // #include "internlm.h" // IWYU pragma: keep From 13c35ee74eac5ec0c3faa83bd29625745f775c13 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 23 Sep 2025 15:25:08 -0700 Subject: [PATCH 07/14] fix qkv_linear test --- src/layers/qkv_linear_test.cpp | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/layers/qkv_linear_test.cpp b/src/layers/qkv_linear_test.cpp index 359b68e3..bffded22 100644 --- a/src/layers/qkv_linear_test.cpp +++ b/src/layers/qkv_linear_test.cpp @@ -9,7 +9,7 @@ namespace llm { -class QKVLinearTest +class QKVColumnParallelLinearTest : public ::testing::TestWithParam> {}; -TEST_P(QKVLinearTest, LoadFusedWeight) { +TEST_P(QKVColumnParallelLinearTest, LoadFusedWeight) { const auto& [n_tokens, n_heads, n_kv_heads, n_shards, head_dim, hidden_size] = GetParam(); @@ -51,18 +51,19 @@ TEST_P(QKVLinearTest, LoadFusedWeight) { for (int32_t shard_id = 0; shard_id < n_shards; ++shard_id) { QuantArgs quant_args; ParallelArgs parallel_args(shard_id, n_shards, nullptr); - QKVColumnParallelLinearImpl linear(hidden_size, - n_heads, - n_kv_heads, - head_dim, - /*bias=*/false, - /*gather_output=*/false, - quant_args, - parallel_args, - 
options); - linear.load_state_dict(state_dict, - /*prefixes=*/{"query.", "key.", "value."}, - /*kv_prefixes=*/{"key.", "value."}); + QKVColumnParallelLinearImpl linear( + hidden_size, + n_heads, + n_kv_heads, + head_dim, + /*bias=*/false, + /*gather_output=*/false, + std::vector{"query.", "key.", "value."}, + quant_args, + parallel_args, + options); + linear.load(state_dict); + EXPECT_TRUE(linear.verify()); // generate random input and compare with the output auto input = torch::randn({n_tokens, hidden_size}, options); @@ -84,7 +85,7 @@ TEST_P(QKVLinearTest, LoadFusedWeight) { INSTANTIATE_TEST_SUITE_P( QKVLinearTestSuite, - QKVLinearTest, + QKVColumnParallelLinearTest, ::testing::Combine(::testing::Values(10, 32), // n_tokens ::testing::Values(8, 16, 32), // n_heads ::testing::Values(1, 2, 4, 8), // n_kv_heads From 2899884711b8828f3ed6612c4f4d6348da0a9b47 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 23 Sep 2025 15:28:07 -0700 Subject: [PATCH 08/14] rename --- src/layers/fused_linear.cpp | 86 ------------------------------------- src/layers/fused_linear.h | 36 ---------------- src/models/aquila.h | 4 +- src/models/baichuan.h | 4 +- src/models/internlm.h | 4 +- src/models/llama.h | 4 +- src/models/mistral.h | 4 +- src/models/qwen2.h | 4 +- src/models/simple_model.h | 4 +- 9 files changed, 14 insertions(+), 136 deletions(-) diff --git a/src/layers/fused_linear.cpp b/src/layers/fused_linear.cpp index 6d305d8d..86b5ca59 100644 --- a/src/layers/fused_linear.cpp +++ b/src/layers/fused_linear.cpp @@ -10,92 +10,6 @@ namespace llm { -LegacyFusedColumnParallelLinearImpl::LegacyFusedColumnParallelLinearImpl( - int64_t in_features, - const std::vector& out_features_vec, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) { - // check if the linear layers can be fused - fused_ = quant_args.can_be_fused(); - if (fused_) { - // fused linear layer - const int64_t out_features = 
std::accumulate( - out_features_vec.begin(), out_features_vec.end(), int64_t(0)); - fused_linear_ = ColumnParallelLinear(in_features, - out_features, - bias, - gather_output, - quant_args, - parallel_args, - options); - // calculate split sizes - split_sizes_.reserve(out_features_vec.size()); - const auto world_size = parallel_args.world_size(); - for (const auto& out_features : out_features_vec) { - CHECK(out_features % world_size == 0) - << "out_features " << out_features << " not divisible by world_size " - << world_size; - split_sizes_.push_back(out_features / world_size); - } - } else { - // non-fused linear layers - parallel_linears_.reserve(out_features_vec.size()); - for (const auto& out_features : out_features_vec) { - parallel_linears_.emplace_back(in_features, - out_features, - bias, - gather_output, - quant_args, - parallel_args, - options); - } - } -} - -std::vector LegacyFusedColumnParallelLinearImpl::forward( - torch::Tensor input) { - if (fused_) { - auto fused_output = fused_linear_->forward(input); - return fused_output.split(split_sizes_, /*dim=*/1); - } - - // otherwise, use the non-fused linear layers - std::vector outputs; - outputs.reserve(parallel_linears_.size()); - for (auto& parallel_linear : parallel_linears_) { - auto output = parallel_linear->forward(input); - outputs.push_back(output); - } - return outputs; -} - -void LegacyFusedColumnParallelLinearImpl::load_state_dict( - const StateDict& state_dict, - const std::vector& prefixes) { - if (fused_) { - fused_linear_->load_state_dict(state_dict, prefixes); - } else { - CHECK_EQ(parallel_linears_.size(), prefixes.size()); - for (size_t i = 0; i < parallel_linears_.size(); ++i) { - parallel_linears_[i]->load_state_dict(state_dict.select(prefixes[i])); - } - } -} - -void LegacyFusedColumnParallelLinearImpl::verify_loaded_weights( - const std::string& prefix) const { - if (fused_) { - fused_linear_->verify_loaded_weights(prefix); - } else { - for (const auto& parallel_linear : 
parallel_linears_) { - parallel_linear->verify_loaded_weights(prefix); - } - } -} - FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl( int64_t in_features, const std::vector& out_features_vec, diff --git a/src/layers/fused_linear.h b/src/layers/fused_linear.h index 6a975163..73323479 100644 --- a/src/layers/fused_linear.h +++ b/src/layers/fused_linear.h @@ -12,42 +12,6 @@ namespace llm { -class LegacyFusedColumnParallelLinearImpl : public Module { - public: - LegacyFusedColumnParallelLinearImpl(int64_t in_features, - const std::vector& out_features, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); - - std::vector forward(torch::Tensor input); - - // load_state_dict for fused weights - void load_state_dict(const StateDict& state_dict, - const std::vector& prefixes); - - void verify_loaded_weights(const std::string& prefix = "") const; - - // whether the linear layer is fused - bool fused() const { return fused_; } - - private: - // non-fused linear layers - std::vector parallel_linears_; - - // fused linear layer - ColumnParallelLinear fused_linear_{nullptr}; - - // sizes for each split - std::vector split_sizes_; - - // whether the linear layer is fused - bool fused_ = false; -}; -LLM_MODULE(LegacyFusedColumnParallelLinear); - class FusedColumnParallelLinearImpl : public Module { public: FusedColumnParallelLinearImpl(int64_t in_features, diff --git a/src/models/aquila.h b/src/models/aquila.h index 6ab4ed6e..59e7b3e4 100644 --- a/src/models/aquila.h +++ b/src/models/aquila.h @@ -36,7 +36,7 @@ class AquilaMLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - LegacyFusedColumnParallelLinear( + FusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, /*bias=*/false, @@ -74,7 +74,7 @@ class AquilaMLPImpl : public Module { private: // parameter members, must be registered - 
LegacyFusedColumnParallelLinear gate_up_proj_{nullptr}; + FusedColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; // activation function diff --git a/src/models/baichuan.h b/src/models/baichuan.h index 62f122d4..db77fcdc 100644 --- a/src/models/baichuan.h +++ b/src/models/baichuan.h @@ -46,7 +46,7 @@ class BaichuanMLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - LegacyFusedColumnParallelLinear( + FusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, /*bias=*/false, @@ -84,7 +84,7 @@ class BaichuanMLPImpl : public Module { private: // parameter members, must be registered - LegacyFusedColumnParallelLinear gate_up_proj_{nullptr}; + FusedColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; // activation function diff --git a/src/models/internlm.h b/src/models/internlm.h index 3bea7989..8f100c32 100644 --- a/src/models/internlm.h +++ b/src/models/internlm.h @@ -35,7 +35,7 @@ class InternlmMLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - LegacyFusedColumnParallelLinear( + FusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, /*bias=*/false, @@ -73,7 +73,7 @@ class InternlmMLPImpl : public Module { private: // parameter members, must be registered - LegacyFusedColumnParallelLinear gate_up_proj_{nullptr}; + FusedColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; // calculate act(x) * y diff --git a/src/models/llama.h b/src/models/llama.h index 229da58f..f72d0d8c 100644 --- a/src/models/llama.h +++ b/src/models/llama.h @@ -38,7 +38,7 @@ class LlamaMLPImpl : public Module { // register the weight parameter // gate_up_proj_ = register_module( // "gate_up_proj", - // LegacyFusedColumnParallelLinear( + // FusedColumnParallelLinear( // hidden_size, // std::vector{intermediate_size, 
intermediate_size}, // /*bias=*/false, @@ -77,7 +77,7 @@ class LlamaMLPImpl : public Module { private: // parameter members, must be registered - LegacyFusedColumnParallelLinear gate_up_proj_{nullptr}; + FusedColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; // activation function diff --git a/src/models/mistral.h b/src/models/mistral.h index a5822482..9b8e8ad0 100644 --- a/src/models/mistral.h +++ b/src/models/mistral.h @@ -34,7 +34,7 @@ class MistralMLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - LegacyFusedColumnParallelLinear( + FusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, /*bias=*/false, @@ -72,7 +72,7 @@ class MistralMLPImpl : public Module { private: // parameter members, must be registered - LegacyFusedColumnParallelLinear gate_up_proj_{nullptr}; + FusedColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; ActFunc act_func_{nullptr}; diff --git a/src/models/qwen2.h b/src/models/qwen2.h index ba540730..d93de88a 100644 --- a/src/models/qwen2.h +++ b/src/models/qwen2.h @@ -40,7 +40,7 @@ class QWen2MLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - LegacyFusedColumnParallelLinear( + FusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, /*bias=*/false, @@ -78,7 +78,7 @@ class QWen2MLPImpl : public Module { private: // parameter members, must be registered - LegacyFusedColumnParallelLinear gate_up_proj_{nullptr}; + FusedColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; // activation function diff --git a/src/models/simple_model.h b/src/models/simple_model.h index 82ee5d26..dc22d3a0 100644 --- a/src/models/simple_model.h +++ b/src/models/simple_model.h @@ -31,7 +31,7 @@ class SimpleMLPImpl : public Module { gate_up_proj_ = register_module( "gate_up_proj", - 
LegacyFusedColumnParallelLinear( + FusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, false, @@ -65,7 +65,7 @@ class SimpleMLPImpl : public Module { } private: - LegacyFusedColumnParallelLinear gate_up_proj_{nullptr}; + FusedColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; ActFunc act_func_{nullptr}; From 94442076f8f7304d9f0694dd64ae3624a00d7c9c Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 23 Sep 2025 18:35:15 -0700 Subject: [PATCH 09/14] clean up legacy load and verify function for layers. --- src/layers/embedding.h | 69 ++++------------------------ src/layers/linear_impl.cpp | 12 ++++- src/layers/normalization.h | 76 ------------------------------- src/layers/normalization_test.cpp | 8 ++-- 4 files changed, 23 insertions(+), 142 deletions(-) diff --git a/src/layers/embedding.h b/src/layers/embedding.h index f54b4766..c4dbffc3 100644 --- a/src/layers/embedding.h +++ b/src/layers/embedding.h @@ -33,22 +33,6 @@ class EmbeddingImpl : public Module { return F::embedding(input, weight_); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - const auto weight = state_dict.get_tensor("weight"); - if (weight.defined()) { - CHECK_EQ(weight_.sizes(), weight.sizes()) - << "weight size mismatch for " << name(); - weight_.copy_(weight); - is_loaded_ = true; - } - } - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix) const { - CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight"; - } - void pretty_print(std::ostream& stream) const override { stream << name() << " " << weight_.sizes() << " " << weight_.device(); } @@ -100,26 +84,6 @@ class ParallelEmbeddingImpl : public Module { return output; } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - const auto weight = state_dict.get_sharded_tensor( - "weight", - /*dim=*/1, - 
/*rank=*/parallel_args_.rank(), - /*world_size=*/parallel_args_.world_size()); - if (weight.defined()) { - CHECK_EQ(weight_.sizes(), weight.sizes()) - << "weight size mismatch for " << name(); - weight_.copy_(weight); - is_loaded_ = true; - } - } - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix) const { - CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight"; - } - void pretty_print(std::ostream& stream) const override { stream << name() << " " << weight_.sizes() << " " << weight_.device(); } @@ -147,14 +111,19 @@ class VocabParallelEmbeddingImpl : public Module { const ParallelArgs& parallel_args, const torch::TensorOptions& options) : parallel_args_(parallel_args) { - const int64_t num_embeddings_per_partition = - num_embeddings / parallel_args_.world_size(); - start_index_ = num_embeddings_per_partition * parallel_args_.rank(); + const auto rank = parallel_args_.rank(); + const auto world_size = parallel_args_.world_size(); + + const int64_t num_embeddings_per_partition = num_embeddings / world_size; + start_index_ = num_embeddings_per_partition * rank; end_index_ = start_index_ + num_embeddings_per_partition; // register the weight parameter - weight_ = register_parameter( + weight_ = register_sharded_parameter( "weight", + /*dim=*/0, + rank, + world_size, torch::empty({num_embeddings_per_partition, embedding_dim}, options)); } @@ -178,26 +147,6 @@ class VocabParallelEmbeddingImpl : public Module { return reduce_from_model_parallel_region(output, parallel_args_); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - const auto weight = state_dict.get_sharded_tensor( - "weight", - /*dim=*/0, - /*rank=*/parallel_args_.rank(), - /*world_size=*/parallel_args_.world_size()); - if (weight.defined()) { - CHECK_EQ(weight_.sizes(), weight.sizes()) - << "weight size mismatch for " << name(); - weight_.copy_(weight); - is_loaded_ = true; - } - } - - // whether the weight 
is loaded - void verify_loaded_weights(const std::string& prefix = "") const { - CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight"; - } - void pretty_print(std::ostream& stream) const override { stream << name() << " " << weight_.sizes() << " " << weight_.device(); } diff --git a/src/layers/linear_impl.cpp b/src/layers/linear_impl.cpp index 3f301868..ff877e4d 100644 --- a/src/layers/linear_impl.cpp +++ b/src/layers/linear_impl.cpp @@ -18,6 +18,7 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl( const ParallelArgs& parallel_args, const torch::TensorOptions& options) : gather_output_(gather_output), parallel_args_(parallel_args) { + const auto rank = parallel_args_.rank(); const auto world_size = parallel_args_.world_size(); CHECK(out_features % world_size == 0) << "out_features " << out_features << " not divisible by world_size " @@ -26,8 +27,11 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl( // Note: torch.nn.functional.linear performs XA^T + b and as a result // we allocate the transpose. - weight_ = register_parameter( + weight_ = register_sharded_parameter( "weight", + /*dim=*/0, + rank, + world_size, torch::empty({out_features_per_partition, in_features}, options)); if (bias) { @@ -93,14 +97,18 @@ RowParallelLinearImpl::RowParallelLinearImpl( const torch::TensorOptions& options) : input_is_parallelized_(input_is_parallelized), parallel_args_(parallel_args) { + const auto rank = parallel_args_.rank(); const auto world_size = parallel_args_.world_size(); CHECK(in_features % world_size == 0) << "in_features " << in_features << " not divisible by world_size " << world_size; const int64_t in_features_per_partition = in_features / world_size; // Allocate the transpose since linear performs XA^T. 
- weight_ = register_parameter( + weight_ = register_sharded_parameter( "weight", + /*dim=*/1, + rank, + world_size, torch::empty({out_features, in_features_per_partition}, options)); if (bias) { diff --git a/src/layers/normalization.h b/src/layers/normalization.h index 8996ce78..c522f7f7 100644 --- a/src/layers/normalization.h +++ b/src/layers/normalization.h @@ -94,34 +94,6 @@ class LayerNormImpl : public Module { input, normalized_shape_, weight_, bias_, eps_); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - const auto weight = state_dict.get_tensor("weight"); - if (weight.defined()) { - CHECK_EQ(weight_.sizes(), weight.sizes()) - << "weight size mismatch for " << name(); - weight_.copy_(weight); - weight_is_loaded_ = true; - } - if (bias_.defined()) { - const auto bias = state_dict.get_tensor("bias"); - if (bias.defined()) { - CHECK_EQ(bias_.sizes(), bias.sizes()) - << "bias size mismatch for " << name(); - bias_.copy_(bias); - bias_is_loaded_ = true; - } - } - } - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix = "") const { - CHECK(weight_is_loaded_) - << "weight is not loaded for " << prefix + "weight"; - CHECK(!bias_.defined() || bias_is_loaded_) - << "bias is not loaded for " << prefix + "bias"; - } - void pretty_print(std::ostream& stream) const override { stream << name() << " " << weight_.sizes() << " " << weight_.device(); } @@ -159,22 +131,6 @@ class RMSNormImpl : public Module { return detail::rms_norm(input, weight_, eps_); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - const auto weight = state_dict.get_tensor("weight"); - if (weight.defined()) { - CHECK_EQ(weight_.sizes(), weight.sizes()) - << "weight size mismatch for " << name(); - weight_.copy_(weight); - is_loaded_ = true; - } - } - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix = "") const { - 
CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight"; - } - void pretty_print(std::ostream& stream) const override { stream << name() << " " << weight_.sizes() << " " << weight_.device(); } @@ -207,22 +163,6 @@ class GemmaRMSNormImpl : public Module { return detail::gemma_rms_norm(input, weight_, eps_); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - const auto weight = state_dict.get_tensor("weight"); - if (weight.defined()) { - CHECK_EQ(weight_.sizes(), weight.sizes()) - << "weight size mismatch for " << name(); - weight_.copy_(weight); - is_loaded_ = true; - } - } - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix = "") const { - CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight"; - } - void pretty_print(std::ostream& stream) const override { stream << name() << " " << weight_.sizes() << " " << weight_.device(); } @@ -268,22 +208,6 @@ class RMSNormResidualImpl : public Module { return detail::rms_norm(input, weight_, eps_); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - const auto weight = state_dict.get_tensor("weight"); - if (weight.defined()) { - CHECK_EQ(weight_.sizes(), weight.sizes()) - << "weight size mismatch for " << name(); - weight_.copy_(weight); - is_loaded_ = true; - } - } - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix = "") const { - CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight"; - } - void pretty_print(std::ostream& stream) const override { stream << name() << " " << weight_.sizes() << " " << weight_.device(); } diff --git a/src/layers/normalization_test.cpp b/src/layers/normalization_test.cpp index ba25019b..c2aa4031 100644 --- a/src/layers/normalization_test.cpp +++ b/src/layers/normalization_test.cpp @@ -26,8 +26,8 @@ TEST(NormalizationTest, LayerNorm) { LayerNorm norm(dim, eps, /*bias=*/true, 
options); // test load state dict - norm->load_state_dict(state_dict); - norm->verify_loaded_weights(); + norm->load(state_dict); + EXPECT_TRUE(norm->verify()); // verify output const auto input = torch::randn({100, dim}); @@ -88,8 +88,8 @@ TEST(NormalizationTest, RMSNorm) { RMSNorm norm(dim, eps, options); // test load state dict - norm->load_state_dict(state_dict); - norm->verify_loaded_weights(); + norm->load(state_dict); + EXPECT_TRUE(norm->verify()); // verify output const auto input = torch::randn({100, dim}); From f3a9e813f210d779a5b1fb9cabaa92818e7126c1 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 23 Sep 2025 18:38:39 -0700 Subject: [PATCH 10/14] clean up pretty_print --- src/layers/embedding.h | 12 ------------ src/layers/linear_impl.h | 8 -------- src/layers/normalization.h | 16 ---------------- src/quantization/qlinear_gptq_marlin_impl.h | 10 ---------- src/quantization/qlinear_impl.h | 12 ------------ 5 files changed, 58 deletions(-) diff --git a/src/layers/embedding.h b/src/layers/embedding.h index c4dbffc3..9ef07033 100644 --- a/src/layers/embedding.h +++ b/src/layers/embedding.h @@ -33,10 +33,6 @@ class EmbeddingImpl : public Module { return F::embedding(input, weight_); } - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight_.sizes() << " " << weight_.device(); - } - // return the weight (for testing) torch::Tensor weight() const { return weight_; } @@ -84,10 +80,6 @@ class ParallelEmbeddingImpl : public Module { return output; } - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight_.sizes() << " " << weight_.device(); - } - // return the weight (for testing) torch::Tensor weight() const { return weight_; } @@ -147,10 +139,6 @@ class VocabParallelEmbeddingImpl : public Module { return reduce_from_model_parallel_region(output, parallel_args_); } - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight_.sizes() << " " << 
weight_.device(); - } - // return the weight (for testing) torch::Tensor weight() const { return weight_; } diff --git a/src/layers/linear_impl.h b/src/layers/linear_impl.h index fe1f6b7c..ff551649 100644 --- a/src/layers/linear_impl.h +++ b/src/layers/linear_impl.h @@ -42,10 +42,6 @@ class ColumnParallelLinearImpl : public ParallelLinearImpl { << "bias is not loaded for " << prefix + "bias"; } - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight_.sizes() << " " << weight_.device(); - } - // return the weight (for testing) torch::Tensor weight() const { return weight_; } @@ -95,10 +91,6 @@ class RowParallelLinearImpl : public ParallelLinearImpl { << "bias is not loaded for " << prefix + "bias"; } - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight_.sizes() << " " << weight_.device(); - } - // return the weight (for testing) torch::Tensor weight() const { return weight_; } diff --git a/src/layers/normalization.h b/src/layers/normalization.h index c522f7f7..da75a81b 100644 --- a/src/layers/normalization.h +++ b/src/layers/normalization.h @@ -94,10 +94,6 @@ class LayerNormImpl : public Module { input, normalized_shape_, weight_, bias_, eps_); } - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight_.sizes() << " " << weight_.device(); - } - private: // parameter members, must be registered torch::Tensor weight_{nullptr}; @@ -131,10 +127,6 @@ class RMSNormImpl : public Module { return detail::rms_norm(input, weight_, eps_); } - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight_.sizes() << " " << weight_.device(); - } - private: // parameter members, must be registered torch::Tensor weight_{nullptr}; @@ -163,10 +155,6 @@ class GemmaRMSNormImpl : public Module { return detail::gemma_rms_norm(input, weight_, eps_); } - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << 
weight_.sizes() << " " << weight_.device(); - } - private: // parameter members, must be registered torch::Tensor weight_{nullptr}; @@ -208,10 +196,6 @@ class RMSNormResidualImpl : public Module { return detail::rms_norm(input, weight_, eps_); } - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight_.sizes() << " " << weight_.device(); - } - private: // parameter members, must be registered torch::Tensor weight_{nullptr}; diff --git a/src/quantization/qlinear_gptq_marlin_impl.h b/src/quantization/qlinear_gptq_marlin_impl.h index dc6b3f40..e996b479 100644 --- a/src/quantization/qlinear_gptq_marlin_impl.h +++ b/src/quantization/qlinear_gptq_marlin_impl.h @@ -35,11 +35,6 @@ class ColumnParallelQLinearGPTQMarlinImpl : public ParallelLinearImpl { void load_state_dict(const StateDict& state_dict, const std::vector& prefixes) override; - void pretty_print(std::ostream& stream) const override { - stream << name() << " qweight=" << qweight_.sizes() - << " scales=" << scales_.sizes() << " device=" << qweight_.device(); - } - private: // parameter members, must be registered DEFINE_FUSED_WEIGHT(qweight); @@ -91,11 +86,6 @@ class RowParallelQLinearGPTQMarlinImpl : public ParallelLinearImpl { // whether the weight is loaded void verify_loaded_weights(const std::string& prefix = "") const override; - void pretty_print(std::ostream& stream) const override { - stream << name() << " qweight=" << qweight_.sizes() - << " scales=" << scales_.sizes() << " device=" << qweight_.device(); - } - private: // parameter members, must be registered DEFINE_WEIGHT(qweight); diff --git a/src/quantization/qlinear_impl.h b/src/quantization/qlinear_impl.h index c442a8bd..0269246f 100644 --- a/src/quantization/qlinear_impl.h +++ b/src/quantization/qlinear_impl.h @@ -76,12 +76,6 @@ class ColumnParallelQLinearImpl : public ParallelLinearImpl { void load_state_dict(const StateDict& state_dict, const std::vector& prefixes) override; - void 
pretty_print(std::ostream& stream) const override { - stream << name() << " qweight=" << qweight_.sizes() - << " qzeros=" << qzeros_.sizes() << " scales=" << scales_.sizes() - << " device=" << qweight_.device(); - } - private: // parameter members, must be registered DEFINE_FUSED_WEIGHT(qweight); @@ -148,12 +142,6 @@ class RowParallelQLinearImpl : public ParallelLinearImpl { // whether the weight is loaded void verify_loaded_weights(const std::string& prefix = "") const override; - void pretty_print(std::ostream& stream) const override { - stream << name() << " qweight=" << qweight_.sizes() - << " qzeros=" << qzeros_.sizes() << " scales=" << scales_.sizes() - << " device=" << qweight_.device(); - } - private: // parameter members, must be registered DEFINE_WEIGHT(qweight); From 045658c38e0cc5f3a9633e8171b08c6a5f785d3f Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 23 Sep 2025 18:54:39 -0700 Subject: [PATCH 11/14] fix llama --- src/layers/qkv_linear.cpp | 2 +- src/layers/qkv_linear.h | 2 +- src/layers/qkv_linear_test.cpp | 2 +- src/models/gemma.h | 2 +- src/models/gemma2.h | 2 +- src/models/llama.h | 72 ++++++++++++---------------------- src/models/models.h | 2 +- 7 files changed, 32 insertions(+), 52 deletions(-) diff --git a/src/layers/qkv_linear.cpp b/src/layers/qkv_linear.cpp index 054a45e2..8ef9b532 100644 --- a/src/layers/qkv_linear.cpp +++ b/src/layers/qkv_linear.cpp @@ -10,9 +10,9 @@ QKVColumnParallelLinearImpl::QKVColumnParallelLinearImpl( int64_t n_heads, int64_t n_kv_heads, int64_t head_dim, + const std::vector& prefixes, bool bias, bool gather_output, - const std::vector& prefixes, const QuantArgs& quant_args, const ParallelArgs& parallel_args, const torch::TensorOptions& options) { diff --git a/src/layers/qkv_linear.h b/src/layers/qkv_linear.h index 89dd2c6c..64923e27 100644 --- a/src/layers/qkv_linear.h +++ b/src/layers/qkv_linear.h @@ -20,9 +20,9 @@ class QKVColumnParallelLinearImpl : public Module { int64_t n_heads, int64_t n_kv_heads, int64_t 
head_dim, + const std::vector& prefixes, bool bias, bool gather_output, - const std::vector& prefixes, const QuantArgs& quant_args, const ParallelArgs& parallel_args, const torch::TensorOptions& options); diff --git a/src/layers/qkv_linear_test.cpp b/src/layers/qkv_linear_test.cpp index bffded22..86ae2f6b 100644 --- a/src/layers/qkv_linear_test.cpp +++ b/src/layers/qkv_linear_test.cpp @@ -56,9 +56,9 @@ TEST_P(QKVColumnParallelLinearTest, LoadFusedWeight) { n_heads, n_kv_heads, head_dim, + std::vector{"query.", "key.", "value."}, /*bias=*/false, /*gather_output=*/false, - std::vector{"query.", "key.", "value."}, quant_args, parallel_args, options); diff --git a/src/models/gemma.h b/src/models/gemma.h index f994d23b..7934c9b6 100644 --- a/src/models/gemma.h +++ b/src/models/gemma.h @@ -100,9 +100,9 @@ class GemmaAttentionImpl : public Module { n_heads, n_kv_heads, head_dim, + std::vector{"q_proj.", "k_proj.", "v_proj."}, /*bias=*/false, /*gather_output=*/false, - std::vector{"q_proj.", "k_proj.", "v_proj."}, quant_args, parallel_args, options), diff --git a/src/models/gemma2.h b/src/models/gemma2.h index 98b0202a..6b4e75d6 100644 --- a/src/models/gemma2.h +++ b/src/models/gemma2.h @@ -100,9 +100,9 @@ class Gemma2AttentionImpl : public Module { n_heads, n_kv_heads, head_dim, + std::vector{"q_proj.", "k_proj.", "v_proj."}, args.attn_bias(), /*gather_output=*/false, - std::vector{"q_proj.", "k_proj.", "v_proj."}, quant_args, parallel_args, options), diff --git a/src/models/llama.h b/src/models/llama.h index f72d0d8c..b615097f 100644 --- a/src/models/llama.h +++ b/src/models/llama.h @@ -36,16 +36,18 @@ class LlamaMLPImpl : public Module { const int64_t intermediate_size = args.intermediate_size(); // register the weight parameter - // gate_up_proj_ = register_module( - // "gate_up_proj", - // FusedColumnParallelLinear( - // hidden_size, - // std::vector{intermediate_size, intermediate_size}, - // /*bias=*/false, - // /*gather_output=*/false, - // quant_args, - // 
parallel_args, - // options)); + gate_up_proj_ = register_module( + "gate_up_proj", + FusedColumnParallelLinear( + hidden_size, + std::vector{intermediate_size, intermediate_size}, + std::vector{"gate_proj.", "up_proj."}, + /*bias=*/false, + /*gather_output=*/false, + quant_args, + parallel_args, + options), + /*selector=*/nullptr); down_proj_ = register_module("down_proj", @@ -63,18 +65,6 @@ class LlamaMLPImpl : public Module { return down_proj_(act_func_(gate_up[0]) * gate_up[1]); } - // // load the weight from the checkpoint - // void load_state_dict(const StateDict& state_dict) { - // // call each submodule's load_state_dict function - // gate_up_proj_->load_state_dict(state_dict, {"gate_proj.", "up_proj."}); - // down_proj_->load_state_dict(state_dict.select("down_proj.")); - // } - - // void verify_loaded_weights(const std::string& prefix) const { - // gate_up_proj_->verify_loaded_weights(prefix + "[gate_proj,up_proj]."); - // down_proj_->verify_loaded_weights(prefix + "down_proj."); - // } - private: // parameter members, must be registered FusedColumnParallelLinear gate_up_proj_{nullptr}; @@ -102,16 +92,20 @@ class LlamaAttentionImpl : public Module { std::max(1, n_kv_heads / world_size); // register submodules - // qkv_proj_ = register_module("qkv_proj", - // QKVColumnParallelLinear(hidden_size, - // n_heads, - // n_kv_heads, - // head_dim, - // /*bias=*/false, - // /*gather_output=*/false, - // quant_args, - // parallel_args, - // options)); + qkv_proj_ = register_module( + "qkv_proj", + QKVColumnParallelLinear( + hidden_size, + n_heads, + n_kv_heads, + head_dim, + std::vector{"q_proj.", "k_proj.", "v_proj."}, + /*bias=*/false, + /*gather_output=*/false, + quant_args, + parallel_args, + options), + /*selector=*/nullptr); o_proj_ = register_module("o_proj", RowParallelLinear(hidden_size, @@ -141,20 +135,6 @@ class LlamaAttentionImpl : public Module { return o_proj_(output); } - // // load the weight from the checkpoint - // void load_state_dict(const 
StateDict& state_dict) { - // // call each submodule's load_state_dict function - // qkv_proj_->load_state_dict( - // state_dict, {"q_proj.", "k_proj.", "v_proj."}, {"k_proj.", - // "v_proj."}); - // o_proj_->load_state_dict(state_dict.select("o_proj.")); - // } - - // void verify_loaded_weights(const std::string& prefix) const { - // qkv_proj_->verify_loaded_weights(prefix + "[q_proj,k_proj,v_proj]."); - // o_proj_->verify_loaded_weights(prefix + "o_proj."); - // } - private: // parameter members, must be registered QKVColumnParallelLinear qkv_proj_{nullptr}; diff --git a/src/models/models.h b/src/models/models.h index 1d79341d..d70f7329 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -11,7 +11,7 @@ // #include "gpt_j.h" // IWYU pragma: keep // #include "gpt_neox.h" // IWYU pragma: keep // #include "internlm.h" // IWYU pragma: keep -// #include "llama.h" // IWYU pragma: keep +#include "llama.h" // IWYU pragma: keep // #include "mistral.h" // IWYU pragma: keep // #include "mpt.h" // IWYU pragma: keep // #include "phi.h" // IWYU pragma: keep From 4ffe8db98a598014e994f401a89b845ae20a42dd Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 23 Sep 2025 19:55:42 -0700 Subject: [PATCH 12/14] fix qwen and qwen2 --- src/layers/linear_impl.cpp | 8 ++- src/models/llama.h | 1 - src/models/models.h | 4 +- src/models/qwen.h | 103 +++++++------------------------------ src/models/qwen2.h | 102 ++++++------------------------------ 5 files changed, 43 insertions(+), 175 deletions(-) diff --git a/src/layers/linear_impl.cpp b/src/layers/linear_impl.cpp index ff877e4d..7b6d8f04 100644 --- a/src/layers/linear_impl.cpp +++ b/src/layers/linear_impl.cpp @@ -35,8 +35,12 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl( torch::empty({out_features_per_partition, in_features}, options)); if (bias) { - bias_ = register_parameter( - "bias", torch::empty({out_features_per_partition}, options)); + bias_ = register_sharded_parameter( + "bias", + /*dim=*/0, + rank, + world_size, 
+ torch::empty({out_features_per_partition}, options)); } } diff --git a/src/models/llama.h b/src/models/llama.h index b615097f..d5b3a17d 100644 --- a/src/models/llama.h +++ b/src/models/llama.h @@ -10,7 +10,6 @@ #include "layers/embedding.h" #include "layers/fused_linear.h" #include "layers/linear.h" -#include "layers/linear_impl.h" #include "layers/normalization.h" #include "layers/qkv_linear.h" #include "memory/kv_cache.h" diff --git a/src/models/models.h b/src/models/models.h index d70f7329..ce713f08 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -15,5 +15,5 @@ // #include "mistral.h" // IWYU pragma: keep // #include "mpt.h" // IWYU pragma: keep // #include "phi.h" // IWYU pragma: keep -// #include "qwen.h" // IWYU pragma: keep -// #include "qwen2.h" // IWYU pragma: keep +#include "qwen.h" // IWYU pragma: keep +#include "qwen2.h" // IWYU pragma: keep diff --git a/src/models/qwen.h b/src/models/qwen.h index d1867e4e..61cf7490 100644 --- a/src/models/qwen.h +++ b/src/models/qwen.h @@ -10,6 +10,7 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" +#include "layers/fused_linear.h" #include "layers/linear.h" #include "layers/normalization.h" #include "memory/kv_cache.h" @@ -20,7 +21,7 @@ #include "module/module_holder.h" #include "module/module_list.h" // QWen model compatible with huggingface weights -// adopted from https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py +// Adapted from https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py namespace llm::hf { class QWenMLPImpl : public Module { @@ -38,14 +39,18 @@ class QWenMLPImpl : public Module { const int64_t intermediate_size = args.intermediate_size() / 2; // register the weight parameter - w1_w2_proj_ = register_module("gate_up_proj", - ColumnParallelLinear(hidden_size, - intermediate_size * 2, - /*bias=*/false, - /*gather_output=*/false, - quant_args, - parallel_args, - options)); + gate_up_proj_ = register_module( 
+ "gate_up_proj", + FusedColumnParallelLinear( + hidden_size, + std::vector{intermediate_size, intermediate_size}, + std::vector{"w1.", "w2."}, + /*bias=*/false, + /*gather_output=*/false, + quant_args, + parallel_args, + options), + /*selector=*/nullptr); c_proj_ = register_module("c_proj", RowParallelLinear(intermediate_size, hidden_size, @@ -57,26 +62,13 @@ class QWenMLPImpl : public Module { } torch::Tensor forward(torch::Tensor x) { - auto gate_up_proj = w1_w2_proj_(x); - auto chunks = gate_up_proj.chunk(/*chunks=*/2, /*dim=*/-1); - return c_proj_(chunks[0] * act_(chunks[1])); - } - - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - w1_w2_proj_->load_state_dict(state_dict, {"w1.", "w2."}); - c_proj_->load_state_dict(state_dict.select("c_proj.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - w1_w2_proj_->verify_loaded_weights(prefix + "[w1,w2]."); - c_proj_->verify_loaded_weights(prefix + "c_proj."); + const auto gate_up = gate_up_proj_(x); + return c_proj_(gate_up[0] * act_(gate_up[1])); } private: // parameter members, must be registered - ColumnParallelLinear w1_w2_proj_{nullptr}; + FusedColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear c_proj_{nullptr}; ActFunc act_{nullptr}; @@ -133,18 +125,6 @@ class QWenAttentionImpl : public Module { return c_proj_(output); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - c_attn_->load_state_dict(state_dict.select("c_attn.")); - c_proj_->load_state_dict(state_dict.select("c_proj.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - c_attn_->verify_loaded_weights(prefix + "c_attn."); - c_proj_->verify_loaded_weights(prefix + "c_proj."); - } - private: // parameter members, must be registered ColumnParallelLinear c_attn_{nullptr}; @@ -183,22 +163,6 @@ class 
QWenBlockImpl : public Module { return h + mlp_(ln_2_(h)); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - attn_->load_state_dict(state_dict.select("attn.")); - mlp_->load_state_dict(state_dict.select("mlp.")); - ln_1_->load_state_dict(state_dict.select("ln_1.")); - ln_2_->load_state_dict(state_dict.select("ln_2.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - attn_->verify_loaded_weights(prefix + "attn."); - mlp_->verify_loaded_weights(prefix + "mlp."); - ln_1_->verify_loaded_weights(prefix + "ln_1."); - ln_2_->verify_loaded_weights(prefix + "ln_2."); - } - private: // parameter members, must be registered QWenAttention attn_{nullptr}; @@ -226,7 +190,7 @@ class QWenModelImpl : public Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("layers", ModuleList()); + blocks_ = register_module("h", ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = @@ -254,26 +218,6 @@ class QWenModelImpl : public Module { return ln_f_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - wte_->load_state_dict(state_dict.select("wte.")); - // call each layer's load_state_dict function - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->load_state_dict( - state_dict.select("h." + std::to_string(i) + ".")); - } - ln_f_->load_state_dict(state_dict.select("ln_f.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - wte_->verify_loaded_weights(prefix + "wte."); - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->verify_loaded_weights(prefix + "h." 
+ std::to_string(i) + - "."); - } - ln_f_->verify_loaded_weights(prefix + "ln_f."); - } - private: // parameter members, must be registered ParallelEmbedding wte_{nullptr}; @@ -331,17 +275,6 @@ class QWenForCausalLMImpl : public Module { return lm_head_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - transformer_->load_state_dict(state_dict.select("transformer.")); - lm_head_->load_state_dict(state_dict.select("lm_head.")); - } - - void verify_loaded_weights() const { - transformer_->verify_loaded_weights("transformer."); - lm_head_->verify_loaded_weights("lm_head."); - } - private: // parameter members, must be registered QWenModel transformer_{nullptr}; diff --git a/src/models/qwen2.h b/src/models/qwen2.h index d93de88a..84758c2d 100644 --- a/src/models/qwen2.h +++ b/src/models/qwen2.h @@ -43,11 +43,13 @@ class QWen2MLPImpl : public Module { FusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, + std::vector{"gate_proj.", "up_proj."}, /*bias=*/false, /*gather_output=*/false, quant_args, parallel_args, - options)); + options), + /*selector=*/nullptr); down_proj_ = register_module("down_proj", RowParallelLinear(intermediate_size, @@ -64,18 +66,6 @@ class QWen2MLPImpl : public Module { return down_proj_(act_func_(gate_up[0]) * gate_up[1]); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - gate_up_proj_->load_state_dict(state_dict, {"gate_proj.", "up_proj."}); - down_proj_->load_state_dict(state_dict.select("down_proj.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - gate_up_proj_->verify_loaded_weights(prefix + "[gate_proj,up_proj]."); - down_proj_->verify_loaded_weights(prefix + "down_proj."); - } - private: // parameter members, must be registered FusedColumnParallelLinear gate_up_proj_{nullptr}; @@ -104,16 +94,20 @@ class QWen2AttentionImpl : 
public Module { std::max(1, n_kv_heads / world_size); // register submodules - qkv_proj_ = register_module("qkv_proj", - QKVColumnParallelLinear(hidden_size, - n_heads, - n_kv_heads, - head_dim, - /*bias=*/true, - /*gather_output=*/false, - quant_args, - parallel_args, - options)); + qkv_proj_ = register_module( + "qkv_proj", + QKVColumnParallelLinear( + hidden_size, + n_heads, + n_kv_heads, + head_dim, + std::vector{"q_proj.", "k_proj.", "v_proj."}, + /*bias=*/true, + /*gather_output=*/false, + quant_args, + parallel_args, + options), + /*selector=*/nullptr); o_proj_ = register_module("o_proj", RowParallelLinear(hidden_size, @@ -146,19 +140,6 @@ class QWen2AttentionImpl : public Module { return o_proj_(output); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - qkv_proj_->load_state_dict( - state_dict, {"q_proj.", "k_proj.", "v_proj."}, {"k_proj.", "v_proj."}); - o_proj_->load_state_dict(state_dict.select("o_proj.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - qkv_proj_->verify_loaded_weights(prefix + "[q_proj,k_proj,v_proj]."); - o_proj_->verify_loaded_weights(prefix + "o_proj."); - } - private: // parameter members, must be registered QKVColumnParallelLinear qkv_proj_{nullptr}; @@ -208,24 +189,6 @@ class QWen2DecoderLayerImpl : public Module { return hidden_states; } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - self_attn_->load_state_dict(state_dict.select("self_attn.")); - mlp_->load_state_dict(state_dict.select("mlp.")); - input_layernorm_->load_state_dict(state_dict.select("input_layernorm.")); - post_attention_layernorm_->load_state_dict( - state_dict.select("post_attention_layernorm.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - self_attn_->verify_loaded_weights(prefix + "self_attn."); - 
mlp_->verify_loaded_weights(prefix + "mlp."); - input_layernorm_->verify_loaded_weights(prefix + "input_layernorm."); - post_attention_layernorm_->verify_loaded_weights( - prefix + "post_attention_layernorm."); - } - private: // parameter members, must be registered QWen2Attention self_attn_{nullptr}; @@ -291,26 +254,6 @@ class QWen2ModelImpl : public Module { return norm_(h, residual); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - embed_tokens_->load_state_dict(state_dict.select("embed_tokens.")); - // call each layer's load_state_dict function - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->load_state_dict( - state_dict.select("layers." + std::to_string(i) + ".")); - } - norm_->load_state_dict(state_dict.select("norm.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - embed_tokens_->verify_loaded_weights(prefix + "embed_tokens."); - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->verify_loaded_weights(prefix + "layers." 
+ std::to_string(i) + - "."); - } - norm_->verify_loaded_weights(prefix + "norm."); - } - private: // parameter members, must be registered ParallelEmbedding embed_tokens_{nullptr}; @@ -368,17 +311,6 @@ class QWen2ForCausalLMImpl : public Module { return lm_head_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - model_->load_state_dict(state_dict.select("model.")); - lm_head_->load_state_dict(state_dict.select("lm_head.")); - } - - void verify_loaded_weights() const { - model_->verify_loaded_weights("model."); - lm_head_->verify_loaded_weights("lm_head."); - } - private: // parameter members, must be registered QWen2Model model_{nullptr}; From e30f6e259a60675a409ae8870ce4acbde9762750 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 23 Sep 2025 21:39:27 -0700 Subject: [PATCH 13/14] clean up models --- src/models/{ => _deprecated}/aquila.h | 0 src/models/{ => _deprecated}/baichuan.h | 0 src/models/{ => _deprecated}/bloom.h | 0 src/models/{ => _deprecated}/chatglm.h | 0 src/models/{ => _deprecated}/gpt_j.h | 0 src/models/{ => _deprecated}/gpt_neox.h | 0 src/models/{ => _deprecated}/internlm.h | 0 src/models/{ => _deprecated}/mistral.h | 0 src/models/{ => _deprecated}/mpt.h | 0 src/models/{ => _deprecated}/simple_model.h | 0 src/models/{ => alibaba}/qwen.h | 0 src/models/{ => alibaba}/qwen2.h | 0 src/models/deepseek/README.md | 1 + src/models/{ => google}/gemma.h | 0 src/models/{ => google}/gemma2.h | 0 src/models/{ => meta}/llama.h | 0 src/models/{ => microsoft}/phi.h | 80 +-------------------- src/models/model_registry.cpp | 2 +- src/models/models.h | 19 ----- src/models/{ => openai}/gpt2.h | 0 src/models/registered_models.h | 26 +++++++ 21 files changed, 29 insertions(+), 99 deletions(-) rename src/models/{ => _deprecated}/aquila.h (100%) rename src/models/{ => _deprecated}/baichuan.h (100%) rename src/models/{ => _deprecated}/bloom.h (100%) rename src/models/{ => _deprecated}/chatglm.h (100%) rename 
src/models/{ => _deprecated}/gpt_j.h (100%) rename src/models/{ => _deprecated}/gpt_neox.h (100%) rename src/models/{ => _deprecated}/internlm.h (100%) rename src/models/{ => _deprecated}/mistral.h (100%) rename src/models/{ => _deprecated}/mpt.h (100%) rename src/models/{ => _deprecated}/simple_model.h (100%) rename src/models/{ => alibaba}/qwen.h (100%) rename src/models/{ => alibaba}/qwen2.h (100%) create mode 100644 src/models/deepseek/README.md rename src/models/{ => google}/gemma.h (100%) rename src/models/{ => google}/gemma2.h (100%) rename src/models/{ => meta}/llama.h (100%) rename src/models/{ => microsoft}/phi.h (80%) delete mode 100644 src/models/models.h rename src/models/{ => openai}/gpt2.h (100%) create mode 100644 src/models/registered_models.h diff --git a/src/models/aquila.h b/src/models/_deprecated/aquila.h similarity index 100% rename from src/models/aquila.h rename to src/models/_deprecated/aquila.h diff --git a/src/models/baichuan.h b/src/models/_deprecated/baichuan.h similarity index 100% rename from src/models/baichuan.h rename to src/models/_deprecated/baichuan.h diff --git a/src/models/bloom.h b/src/models/_deprecated/bloom.h similarity index 100% rename from src/models/bloom.h rename to src/models/_deprecated/bloom.h diff --git a/src/models/chatglm.h b/src/models/_deprecated/chatglm.h similarity index 100% rename from src/models/chatglm.h rename to src/models/_deprecated/chatglm.h diff --git a/src/models/gpt_j.h b/src/models/_deprecated/gpt_j.h similarity index 100% rename from src/models/gpt_j.h rename to src/models/_deprecated/gpt_j.h diff --git a/src/models/gpt_neox.h b/src/models/_deprecated/gpt_neox.h similarity index 100% rename from src/models/gpt_neox.h rename to src/models/_deprecated/gpt_neox.h diff --git a/src/models/internlm.h b/src/models/_deprecated/internlm.h similarity index 100% rename from src/models/internlm.h rename to src/models/_deprecated/internlm.h diff --git a/src/models/mistral.h 
b/src/models/_deprecated/mistral.h similarity index 100% rename from src/models/mistral.h rename to src/models/_deprecated/mistral.h diff --git a/src/models/mpt.h b/src/models/_deprecated/mpt.h similarity index 100% rename from src/models/mpt.h rename to src/models/_deprecated/mpt.h diff --git a/src/models/simple_model.h b/src/models/_deprecated/simple_model.h similarity index 100% rename from src/models/simple_model.h rename to src/models/_deprecated/simple_model.h diff --git a/src/models/qwen.h b/src/models/alibaba/qwen.h similarity index 100% rename from src/models/qwen.h rename to src/models/alibaba/qwen.h diff --git a/src/models/qwen2.h b/src/models/alibaba/qwen2.h similarity index 100% rename from src/models/qwen2.h rename to src/models/alibaba/qwen2.h diff --git a/src/models/deepseek/README.md b/src/models/deepseek/README.md new file mode 100644 index 00000000..bc60bf4b --- /dev/null +++ b/src/models/deepseek/README.md @@ -0,0 +1 @@ +TODO: diff --git a/src/models/gemma.h b/src/models/google/gemma.h similarity index 100% rename from src/models/gemma.h rename to src/models/google/gemma.h diff --git a/src/models/gemma2.h b/src/models/google/gemma2.h similarity index 100% rename from src/models/gemma2.h rename to src/models/google/gemma2.h diff --git a/src/models/llama.h b/src/models/meta/llama.h similarity index 100% rename from src/models/llama.h rename to src/models/meta/llama.h diff --git a/src/models/phi.h b/src/models/microsoft/phi.h similarity index 80% rename from src/models/phi.h rename to src/models/microsoft/phi.h index 99e8dc07..f81ed123 100644 --- a/src/models/phi.h +++ b/src/models/microsoft/phi.h @@ -52,18 +52,6 @@ class PhiMLPImpl : public Module { torch::Tensor forward(torch::Tensor x) { return fc2_(act_(fc1_(x))); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - fc1_->load_state_dict(state_dict.select("fc1.")); - 
fc2_->load_state_dict(state_dict.select("fc2.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - fc1_->verify_loaded_weights(prefix + "fc1."); - fc2_->verify_loaded_weights(prefix + "fc2."); - } - private: // parameter members, must be registered ColumnParallelLinear fc1_{nullptr}; @@ -134,18 +122,6 @@ class PhiAttentionImpl : public Module { return out_proj_(output); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - Wqkv_->load_state_dict(state_dict.select("Wqkv.")); - out_proj_->load_state_dict(state_dict.select("out_proj.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - Wqkv_->verify_loaded_weights(prefix + "Wqkv."); - out_proj_->verify_loaded_weights(prefix + "out_proj."); - } - private: // parameter members, must be registered ColumnParallelLinear Wqkv_{nullptr}; @@ -191,20 +167,6 @@ class PhiBlockImpl : public Module { return x + attn_output + mlp_output; } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - mixer_->load_state_dict(state_dict.select("mixer.")); - mlp_->load_state_dict(state_dict.select("mlp.")); - ln_->load_state_dict(state_dict.select("ln.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - mixer_->verify_loaded_weights(prefix + "mixer."); - mlp_->verify_loaded_weights(prefix + "mlp."); - ln_->verify_loaded_weights(prefix + "ln."); - } - private: // parameter members, must be registered PhiAttention mixer_{nullptr}; @@ -223,7 +185,7 @@ class PhiModelImpl : public Module { const torch::TensorOptions& options) { // register submodules wte_ = register_module( - "wte", + "embd.wte", ParallelEmbedding( args.vocab_size(), args.hidden_size(), parallel_args, options)); @@ -256,24 +218,6 @@ class PhiModelImpl : public Module { return h; } - // load the weight from the checkpoint - 
void load_state_dict(const StateDict& state_dict) { - wte_->load_state_dict(state_dict.select("embd.wte.")); - // call each layer's load_state_dict function - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->load_state_dict( - state_dict.select("h." + std::to_string(i) + ".")); - } - } - - void verify_loaded_weights(const std::string& prefix) const { - wte_->verify_loaded_weights(prefix + "embd.wte."); - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->verify_loaded_weights(prefix + "h." + std::to_string(i) + - "."); - } - } - private: // parameter members, must be registered ParallelEmbedding wte_{nullptr}; @@ -310,17 +254,6 @@ class PhiLMHeadImpl : public Module { torch::Tensor forward(torch::Tensor x) { return linear_(ln_(x)); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - ln_->load_state_dict(state_dict.select("ln.")); - linear_->load_state_dict(state_dict.select("linear.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - ln_->verify_loaded_weights(prefix + "ln."); - linear_->verify_loaded_weights(prefix + "linear."); - } - private: // parameter members, must be registered LayerNorm ln_{nullptr}; @@ -366,17 +299,6 @@ class PhiForCausalLMImpl : public Module { return lm_head_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - transformer_->load_state_dict(state_dict.select("transformer.")); - lm_head_->load_state_dict(state_dict.select("lm_head.")); - } - - void verify_loaded_weights() const { - transformer_->verify_loaded_weights("transformer."); - lm_head_->verify_loaded_weights("lm_head."); - } - private: // parameter members, must be registered PhiModel transformer_{nullptr}; diff --git a/src/models/model_registry.cpp b/src/models/model_registry.cpp index b4c7c5b9..c41ace43 100644 --- a/src/models/model_registry.cpp +++ b/src/models/model_registry.cpp @@ -2,7 +2,7 @@ #include -#include "models.h" // IWYU pragma: 
keep +#include "registered_models.h"  // IWYU pragma: keep namespace llm { diff --git a/src/models/models.h b/src/models/models.h deleted file mode 100644 index ce713f08..00000000 --- a/src/models/models.h +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -// list all registered models here -// #include "aquila.h" // IWYU pragma: keep -// #include "baichuan.h" // IWYU pargma: keep -// #include "bloom.h" // IWYU pragma: keep -// #include "chatglm.h" // IWYU pragma: keep -#include "gemma.h" // IWYU pragma: keep -#include "gemma2.h" // IWYU pragma: keep -#include "gpt2.h" // IWYU pragma: keep -// #include "gpt_j.h" // IWYU pragma: keep -// #include "gpt_neox.h" // IWYU pragma: keep -// #include "internlm.h" // IWYU pragma: keep -#include "llama.h" // IWYU pragma: keep -// #include "mistral.h" // IWYU pragma: keep -// #include "mpt.h" // IWYU pragma: keep -// #include "phi.h" // IWYU pragma: keep -#include "qwen.h" // IWYU pragma: keep -#include "qwen2.h" // IWYU pragma: keep diff --git a/src/models/gpt2.h b/src/models/openai/gpt2.h similarity index 100% rename from src/models/gpt2.h rename to src/models/openai/gpt2.h diff --git a/src/models/registered_models.h b/src/models/registered_models.h new file mode 100644 index 00000000..fba44c42 --- /dev/null +++ b/src/models/registered_models.h @@ -0,0 +1,26 @@ +#pragma once + +// list all registered models here +// Google +#include "google/gemma.h"  // IWYU pragma: keep +#include "google/gemma2.h"  // IWYU pragma: keep +// OpenAI +#include "openai/gpt2.h"  // IWYU pragma: keep +// Meta +#include "meta/llama.h"  // IWYU pragma: keep +// Microsoft +#include "microsoft/phi.h"  // IWYU pragma: keep +// Alibaba +#include "alibaba/qwen.h"  // IWYU pragma: keep +#include "alibaba/qwen2.h"  // IWYU pragma: keep + +// Deprecated models +// #include "_deprecated/aquila.h" +// #include "_deprecated/baichuan.h" +// #include "_deprecated/bloom.h" +// #include "_deprecated/chatglm.h" +// #include "_deprecated/gpt_j.h" +// #include 
"_deprecated/gpt_neox.h" +// #include "_deprecated/internlm.h" +// #include "_deprecated/mistral.h" +// #include "_deprecated/mpt.h" From e4e626d75396c7dfc3768f251cd1a0c1e5865da4 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 23 Sep 2025 21:50:06 -0700 Subject: [PATCH 14/14] clean up logs and update readme --- README.md | 13 ++----------- src/module/module.cpp | 7 ++++--- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 8d64eef0..14b8faa2 100644 --- a/README.md +++ b/README.md @@ -32,11 +32,12 @@ ScaleLLM
-[ScaleLLM](#) is a cutting-edge inference system engineered for large language models (LLMs), designed to meet the demands of production environments. It extends its support to a wide range of popular open-source models, including [Llama3.1](https://github.com/meta-llama/llama3), [Gemma2](https://github.com/google-deepmind/gemma), Bloom, GPT-NeoX, and more. +[ScaleLLM](#) is a cutting-edge inference system engineered for large language models (LLMs), designed to meet the demands of production environments. It extends its support to a wide range of popular open-source models, including [Llama3.1](https://github.com/meta-llama/llama3), [Gemma2](https://github.com/google-deepmind/gemma), [Phi](https://huggingface.co/microsoft/phi-2), and more. ScaleLLM is currently undergoing active development. We are fully committed to consistently enhancing its efficiency while also incorporating additional features. Feel free to explore our [**_Roadmap_**](https://github.com/vectorch-ai/ScaleLLM/issues/84) for more details. ## News: +* [01/2025] - Optimized in-house attention kernels * [06/2024] - ScaleLLM is now available on [PyPI](https://pypi.org/project/scalellm/). You can install it using `pip install scalellm`. * [03/2024] - [Advanced features](#advanced-features) support for ✅ [CUDA graph](#cuda-graph), ✅ [prefix cache](#prefix-cache), ✅ [chunked prefill](#chunked-prefill) and ✅ [speculative decoding](#speculative-decoding). * [11/2023] - [First release](https://github.com/vectorch-ai/ScaleLLM/releases/tag/v0.0.1) with support for popular [open-source models](#supported-models). @@ -274,21 +275,11 @@ Quantization is a crucial process for reducing the memory footprint of models. 
S | Models | Tensor Parallel | Quantization | Chat API | HF models examples | | :--------: | :-------------: | :----------: | :------: | :---------------------------:| -| Aquila | Yes | Yes | Yes | [BAAI/Aquila-7B](https://huggingface.co/BAAI/Aquila-7B), [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B) | -| Bloom | Yes | Yes | No | [bigscience/bloom](https://huggingface.co/bigscience/bloom) | -| Baichuan | Yes | Yes | Yes | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) | -| ChatGLM4/3 | Yes | Yes | Yes | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b) | | Gemma2 | Yes | Yes | Yes | [google/gemma-2-2b](https://huggingface.co/google/gemma-2-2b) | -| GPT_j | Yes | Yes | No | [EleutherAI/gpt-j-6b](https://huggingface.co/EleutherAI/gpt-j-6b) | -| GPT_NeoX | Yes | Yes | No | [EleutherAI/gpt-neox-20b](https://huggingface.co/EleutherAI/gpt-neox-20b) | | GPT2 | Yes | Yes | No | [gpt2](https://huggingface.co/gpt2)| -| InternLM | Yes | Yes | Yes | [internlm/internlm-7b](https://huggingface.co/internlm/internlm-7b) | | Llama3/2 | Yes | Yes | Yes | [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct), [meta-llama/Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) | -| Mistral | Yes | Yes | Yes | [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) | -| MPT | Yes | Yes | Yes | [mosaicml/mpt-30b](https://huggingface.co/mosaicml/mpt-30b) | | Phi2 | Yes | Yes | No | [microsoft/phi-2](https://huggingface.co/microsoft/phi-2) | | Qwen2 | Yes | Yes | Yes | [Qwen/Qwen-72B-Chat](https://huggingface.co/Qwen/Qwen-72B-Chat) | -| Yi | Yes | Yes | Yes |[01-ai/Yi-6B](https://huggingface.co/01-ai/Yi-6B), [01-ai/Yi-34B-Chat-4bits](https://huggingface.co/01-ai/Yi-34B-Chat-4bits), [01-ai/Yi-6B-200K](https://huggingface.co/01-ai/Yi-6B-200K) | If your model is not included in the supported list, we are more than willing to assist you. 
Please feel free to create a request for adding a new model on [GitHub Issues](https://github.com/vectorch-ai/ScaleLLM/issues). diff --git a/src/module/module.cpp b/src/module/module.cpp index 3e78f3b4..29d730b7 100644 --- a/src/module/module.cpp +++ b/src/module/module.cpp @@ -336,8 +336,8 @@ size_t Module::load(const StateDict& state_dict, } if (param_tensor.sizes() == tensor.sizes()) { - LOG(INFO) << "Loading parameter: " << join_name(name_prefix, key) - << " of size " << tensor.sizes(); + // LOG(INFO) << "Loading parameter: " << join_name(name_prefix, key) + // << " of size " << tensor.sizes(); // copy data to the parameter tensor param_tensor.copy_(tensor); // mark as loaded @@ -376,7 +376,8 @@ bool Module::verify(const std::string& name_prefix) const { const auto& key = item.key(); const auto& param = item.value(); if (!param.is_loaded) { - LOG(ERROR) << "Missing parameter: " << join_name(name_prefix, key); + LOG(ERROR) << "Missing parameter: " << join_name(name_prefix, key) + << ", size: " << param.tensor.sizes(); } all_loaded = all_loaded && param.is_loaded; }