From 2972034b35802d862d3e9da1f630bb6b10b5c2f0 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 23 Sep 2025 23:25:56 -0700 Subject: [PATCH 01/13] refactor: clean up legacy load_state_dict for linear layers --- src/layers/linear.h | 6 ------ src/layers/linear_impl.cpp | 12 ++---------- src/layers/linear_impl.h | 4 ---- src/layers/qkv_linear.h | 7 +++++-- src/layers/qkv_linear_test.cpp | 8 ++++---- src/models/alibaba/qwen2.h | 5 ++--- src/models/google/gemma.h | 5 ++--- src/models/google/gemma2.h | 5 ++--- src/models/meta/llama.h | 5 ++--- 9 files changed, 19 insertions(+), 38 deletions(-) diff --git a/src/layers/linear.h b/src/layers/linear.h index 99da103b..c7947f47 100644 --- a/src/layers/linear.h +++ b/src/layers/linear.h @@ -26,12 +26,6 @@ class ParallelLinearImpl : public Module { virtual void verify_loaded_weights(const std::string& prefix = "") const = 0; - // load state dict with a transform function - virtual void load_state_dict(const StateDict& /*state_dict*/, - TensorTransform /*transform_func*/) { - LOG(FATAL) << "not implemented"; - } - // special load_state_dict for fused cases virtual void load_state_dict(const StateDict& /*state_dict*/, const std::vector& /*prefixes*/) { diff --git a/src/layers/linear_impl.cpp b/src/layers/linear_impl.cpp index 7b6d8f04..ac0176a5 100644 --- a/src/layers/linear_impl.cpp +++ b/src/layers/linear_impl.cpp @@ -55,23 +55,15 @@ torch::Tensor ColumnParallelLinearImpl::forward(torch::Tensor input) { // load the weight from the checkpoint void ColumnParallelLinearImpl::load_state_dict(const StateDict& state_dict) { - // call load_state_dict with identity transform - load_state_dict(state_dict, - [](const torch::Tensor& tensor) { return tensor; }); -} - -void ColumnParallelLinearImpl::load_state_dict(const StateDict& state_dict, - TensorTransform transform_func) { - CHECK(transform_func != nullptr) << "transform_func must be provided"; const auto rank = parallel_args_.rank(); const auto world_size = parallel_args_.world_size(); // load sharded weights on dim 0 - LOAD_SHARDED_WEIGHT_WITH_TRANSFORM(weight, 0); + LOAD_SHARDED_WEIGHT(weight, 0); if (bias_.defined()) { // load sharded bias on dim 0 - LOAD_SHARDED_WEIGHT_WITH_TRANSFORM(bias, 0); + LOAD_SHARDED_WEIGHT(bias, 0); } } diff --git a/src/layers/linear_impl.h b/src/layers/linear_impl.h index ff551649..de8c70c7 100644 --- a/src/layers/linear_impl.h +++ b/src/layers/linear_impl.h @@ -26,10 +26,6 @@ class ColumnParallelLinearImpl : public ParallelLinearImpl { // load the weight from the checkpoint void load_state_dict(const StateDict& state_dict) override; - // load state dict with a transform function - void load_state_dict(const StateDict& state_dict, - TensorTransform transform_func) override; - // special load_state_dict for fused cases void load_state_dict(const StateDict& state_dict, const std::vector& prefixes) override; diff --git a/src/layers/qkv_linear.h b/src/layers/qkv_linear.h index 64923e27..9a5dffff 100644 --- a/src/layers/qkv_linear.h +++ b/src/layers/qkv_linear.h @@ -27,8 +27,11 @@ class QKVColumnParallelLinearImpl : public Module { const ParallelArgs& parallel_args, const torch::TensorOptions& options); - std::vector forward(torch::Tensor input) { - return parallel_linear_->forward(input); + // returns (query, key, value) + std::tuple forward( + torch::Tensor input) { + const auto qkv = parallel_linear_->forward(input); + return {qkv[0], qkv[1], qkv[2]}; } private: diff --git a/src/layers/qkv_linear_test.cpp b/src/layers/qkv_linear_test.cpp index 86ae2f6b..6f1a5e05 100644 --- a/src/layers/qkv_linear_test.cpp +++ b/src/layers/qkv_linear_test.cpp @@ -67,19 +67,19 @@ TEST_P(QKVColumnParallelLinearTest, LoadFusedWeight) { // generate random input and compare with the output auto input = torch::randn({n_tokens, hidden_size}, options); - auto qkv = linear.forward(input); + const auto [q, k, v] = linear.forward(input); const int64_t kv_shard_id = n_kv_heads >= n_shards ? shard_id : n_kv_heads * shard_id / n_shards; auto query = input.matmul(query_chunks[shard_id].t()); - EXPECT_TRUE(torch::allclose(qkv[0], query, /*rtol=*/1e-5, /*atol=*/1e-5)); + EXPECT_TRUE(torch::allclose(q, query, /*rtol=*/1e-5, /*atol=*/1e-5)); auto key = input.matmul(key_chunks[kv_shard_id].t()); - EXPECT_TRUE(torch::allclose(qkv[1], key, /*rtol=*/1e-5, /*atol=*/1e-5)); + EXPECT_TRUE(torch::allclose(k, key, /*rtol=*/1e-5, /*atol=*/1e-5)); auto value = input.matmul(value_chunks[kv_shard_id].t()); - EXPECT_TRUE(torch::allclose(qkv[2], value, /*rtol=*/1e-5, /*atol=*/1e-5)); + EXPECT_TRUE(torch::allclose(v, value, /*rtol=*/1e-5, /*atol=*/1e-5)); } } diff --git a/src/models/alibaba/qwen2.h b/src/models/alibaba/qwen2.h index 84758c2d..9f1103b1 100644 --- a/src/models/alibaba/qwen2.h +++ b/src/models/alibaba/qwen2.h @@ -133,10 +133,9 @@ class QWen2AttentionImpl : public Module { const InputParameters& input_params) { // (num_tokens, dim) x (dim, n_local_heads * head_dim) // => (num_tokens, n_local_heads * head_dim) - const auto qkv = qkv_proj_(x); + const auto [q, k, v] = qkv_proj_(x); // calculate attention, output: (num_tokens, n_local_heads * head_dim) - const auto output = - atten_(qkv[0], qkv[1], qkv[2], positions, kv_cache, input_params); + const auto output = atten_(q, k, v, positions, kv_cache, input_params); return o_proj_(output); } diff --git a/src/models/google/gemma.h b/src/models/google/gemma.h index 7934c9b6..555ed279 100644 --- a/src/models/google/gemma.h +++ b/src/models/google/gemma.h @@ -128,11 +128,10 @@ class GemmaAttentionImpl : public Module { const InputParameters& input_params) { // (num_tokens, dim) x (dim, n_local_heads * head_dim) // => (num_tokens, n_local_heads * head_dim) - const auto qkv = qkv_proj_(x); + const auto [q, k, v] = qkv_proj_(x); // calculate attention, // output: (num_tokens, n_local_heads*head_dim) - const auto output = - atten_(qkv[0], qkv[1], qkv[2], positions, kv_cache, input_params); + const auto output = atten_(q, k, v, positions, kv_cache, input_params); return o_proj_(output); } diff --git a/src/models/google/gemma2.h b/src/models/google/gemma2.h index 6b4e75d6..f93d3ac4 100644 --- a/src/models/google/gemma2.h +++ b/src/models/google/gemma2.h @@ -132,11 +132,10 @@ class Gemma2AttentionImpl : public Module { const InputParameters& input_params) { // (num_tokens, dim) x (dim, n_local_heads * head_dim) // => (num_tokens, n_local_heads * head_dim) - const auto qkv = qkv_proj_(x); + const auto [q, k, v] = qkv_proj_(x); // calculate attention, // output: (num_tokens, n_local_heads*head_dim) - const auto output = - atten_(qkv[0], qkv[1], qkv[2], positions, kv_cache, input_params); + const auto output = atten_(q, k, v, positions, kv_cache, input_params); return o_proj_(output); } diff --git a/src/models/meta/llama.h b/src/models/meta/llama.h index d5b3a17d..382c3f8a 100644 --- a/src/models/meta/llama.h +++ b/src/models/meta/llama.h @@ -127,10 +127,9 @@ class LlamaAttentionImpl : public Module { const InputParameters& input_params) { // (num_tokens, dim) x (dim, n_local_heads * head_dim) // => (num_tokens, n_local_heads * head_dim) - const auto qkv = qkv_proj_(x); + const auto [q, k, v] = qkv_proj_(x); // calculate attention, output: (num_tokens, n_local_heads * head_dim) - const auto output = - atten_(qkv[0], qkv[1], qkv[2], positions, kv_cache, input_params); + const auto output = atten_(q, k, v, positions, kv_cache, input_params); return o_proj_(output); } From b4cc8f2b0cac4b17fa322e6d14bc290f51776f2c Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 23 Sep 2025 23:50:47 -0700 Subject: [PATCH 02/13] update qlinear --- src/quantization/qlinear_impl.cpp | 61 +++++++++++++++++++------- src/quantization/qlinear_impl_test.cpp | 8 ++-- 2 files changed, 50 insertions(+), 19 deletions(-) diff --git a/src/quantization/qlinear_impl.cpp b/src/quantization/qlinear_impl.cpp index 695d0dad..01071296 100644 --- a/src/quantization/qlinear_impl.cpp +++ b/src/quantization/qlinear_impl.cpp @@ -117,7 +117,8 @@ ColumnParallelQLinearImpl::ColumnParallelQLinearImpl( quant_args.group_size() > 0 ? quant_args.group_size() : in_features; CHECK(qweight_pack_dim == 0 || qweight_pack_dim == 1) << "qweight_pack_dim must be 0 or 1"; - const int64_t world_size = parallel_args.world_size(); + const auto rank = parallel_args_.rank(); + const auto world_size = parallel_args_.world_size(); CHECK(out_features % world_size == 0) << "out_features " << out_features << " not divisible by world_size " << world_size; @@ -125,29 +126,46 @@ ColumnParallelQLinearImpl::ColumnParallelQLinearImpl( const int64_t pack_factor = 32 / bits; if (qweight_pack_dim == 0) { - qweight_ = register_parameter( + qweight_ = register_sharded_parameter( "qweight", + /*dim=*/1, + rank, + world_size, torch::empty({in_features / pack_factor, out_features_per_partition}, options.dtype(torch::kInt32))); } else { - qweight_ = register_parameter( + qweight_ = register_sharded_parameter( "qweight", + /*dim=*/1, + rank, + world_size, torch::empty({in_features, out_features_per_partition / pack_factor}, options.dtype(torch::kInt32))); } - qzeros_ = register_parameter( + qzeros_ = register_sharded_parameter( "qzeros", + /*dim=*/1, + rank, + world_size, torch::empty({round_up(in_features, group_size), out_features_per_partition / pack_factor}, options.dtype(torch::kInt32))); - scales_ = register_parameter("scales", - torch::empty({round_up(in_features, group_size), - out_features_per_partition}, - options)); + scales_ = register_sharded_parameter( + "scales", + /*dim=*/1, + rank, + world_size, + torch::empty( + {round_up(in_features, group_size), out_features_per_partition}, + options)); if (bias) { - bias_ = register_parameter( - "bias", torch::empty({out_features_per_partition}, options)); + bias_ = register_sharded_parameter( + "bias", + /*dim=*/0, + rank, + world_size, + torch::empty({out_features_per_partition}, options)); } } @@ -226,7 +244,8 @@ RowParallelQLinearImpl::RowParallelQLinearImpl( const auto bits = quant_args.bits(); CHECK(qweight_pack_dim == 0 || qweight_pack_dim == 1) << "qweight_pack_dim must be 0 or 1"; - const int64_t world_size = parallel_args.world_size(); + const auto rank = parallel_args_.rank(); + const auto world_size = parallel_args_.world_size(); CHECK(in_features % world_size == 0) << "in_features " << in_features << " not divisible by world_size " << world_size; @@ -236,24 +255,36 @@ RowParallelQLinearImpl::RowParallelQLinearImpl( quant_args.group_size() > 0 ? quant_args.group_size() : in_features; if (qweight_pack_dim == 0) { - qweight_ = register_parameter( + qweight_ = register_sharded_parameter( "qweight", + /*dim=*/0, + rank, + world_size, torch::empty({in_features_per_partition / pack_factor, out_features}, options.dtype(torch::kInt32))); } else { - qweight_ = register_parameter( + qweight_ = register_sharded_parameter( "qweight", + /*dim=*/0, + rank, + world_size, torch::empty({in_features_per_partition, out_features / pack_factor}, options.dtype(torch::kInt32))); } - qzeros_ = register_parameter( + qzeros_ = register_sharded_parameter( "qzeros", + /*dim=*/0, + rank, + world_size, torch::empty({round_up(in_features_per_partition, group_size), out_features / pack_factor}, options.dtype(torch::kInt32))); - scales_ = register_parameter( + scales_ = register_sharded_parameter( "scales", + /*dim=*/0, + rank, + world_size, torch::empty( {round_up(in_features_per_partition, group_size), out_features}, options)); diff --git a/src/quantization/qlinear_impl_test.cpp b/src/quantization/qlinear_impl_test.cpp index 26e50044..26291ab4 100644 --- a/src/quantization/qlinear_impl_test.cpp +++ b/src/quantization/qlinear_impl_test.cpp @@ -46,8 +46,8 @@ TEST(QlinearTest, ColumnParallelQuantLinear) { /*bits=*/4); weights = weights.to(torch::kCUDA); - qlinear.load_state_dict(*state_dict); - qlinear.verify_loaded_weights(); + qlinear.load(*state_dict); + EXPECT_TRUE(qlinear.verify()); auto input = torch::rand({40960, in_features}, options); auto output = qlinear.forward(input); @@ -83,8 +83,8 @@ TEST(QlinearTest, RowParallelQuantLinear) { /*bits=*/4); weights = weights.to(torch::kCUDA); - qlinear.load_state_dict(*state_dict); - qlinear.verify_loaded_weights(); + qlinear.load(*state_dict); + EXPECT_TRUE(qlinear.verify()); auto input = torch::rand({40960, in_features}, options); auto output = qlinear.forward(input); From 31afc667e5e0f2dd71b4c2c703f20435821d7e47 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 7 Oct 2025 10:51:44 -0700 Subject: [PATCH 03/13] clean up load_state_dict for linear layers --- src/layers/fused_linear.cpp | 71 ++++++++++----------------- src/layers/fused_linear.h | 12 +---- src/layers/linear.cpp | 85 +++++++++++++++++++++++++------- src/layers/linear.h | 19 ++++++- src/layers/linear_impl.cpp | 98 +++++++++++++++++++++++-------------- src/layers/linear_impl.h | 62 +++++++++++------------ src/module/module.cpp | 32 ++++++------ src/module/module.h | 6 +++ 8 files changed, 230 insertions(+), 155 deletions(-) diff --git a/src/layers/fused_linear.cpp b/src/layers/fused_linear.cpp index 86b5ca59..9b3034dc 100644 --- a/src/layers/fused_linear.cpp +++ b/src/layers/fused_linear.cpp @@ -19,20 +19,21 @@ FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl( const QuantArgs& quant_args, const ParallelArgs& parallel_args, const torch::TensorOptions& options) { - prefixes_ = prefixes; // check if the linear layers can be fused fused_ = quant_args.can_be_fused(); if (fused_) { // fused linear layer - const int64_t out_features = std::accumulate( - out_features_vec.begin(), out_features_vec.end(), int64_t(0)); - fused_linear_ = ColumnParallelLinear(in_features, - out_features, - bias, - gather_output, - quant_args, - parallel_args, - options); + fused_linear_ = register_module("fused_linear", + ColumnParallelLinear(in_features, + out_features_vec, + prefixes, + bias, + gather_output, + quant_args, + parallel_args, + options), + /*selector=*/nullptr); + // TODO: clean up following code for calculating split sizes // calculate split sizes split_sizes_.reserve(out_features_vec.size()); const auto world_size = parallel_args.world_size(); @@ -45,14 +46,22 @@ FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl( } else { // non-fused linear layers parallel_linears_.reserve(out_features_vec.size()); - for (const auto& out_features : out_features_vec) { - parallel_linears_.emplace_back(in_features, - out_features, - bias, - gather_output, - quant_args, - parallel_args, - options); + for (size_t i = 0; i < out_features_vec.size(); ++i) { + const auto& prefix = prefixes[i]; + const auto out_features = out_features_vec[i]; + + const auto linear = register_module("linear", + ColumnParallelLinear(in_features, + out_features, + bias, + gather_output, + quant_args, + parallel_args, + options, + prefix), + /*selector=*/nullptr); + + parallel_linears_.emplace_back(linear); } } } @@ -73,30 +82,4 @@ std::vector FusedColumnParallelLinearImpl::forward( } return outputs; } - -size_t FusedColumnParallelLinearImpl::load(const StateDict& state_dict, - const std::string&) { - if (fused_) { - fused_linear_->load_state_dict(state_dict, prefixes_); - } else { - CHECK_EQ(parallel_linears_.size(), prefixes_.size()); - for (size_t i = 0; i < parallel_linears_.size(); ++i) { - parallel_linears_[i]->load_state_dict(state_dict.select(prefixes_[i])); - } - } - return 0; -} - -bool FusedColumnParallelLinearImpl::verify( - const std::string& name_prefix) const { - if (fused_) { - fused_linear_->verify_loaded_weights(name_prefix); - } else { - for (const auto& parallel_linear : parallel_linears_) { - parallel_linear->verify_loaded_weights(name_prefix); - } - } - return true; -} - } // namespace llm diff --git a/src/layers/fused_linear.h b/src/layers/fused_linear.h index 73323479..86d1f116 100644 --- a/src/layers/fused_linear.h +++ b/src/layers/fused_linear.h @@ -25,14 +25,6 @@ class FusedColumnParallelLinearImpl : public Module { std::vector forward(torch::Tensor input); - // load weights from the checkpoint, override this method if necessary - // returns the number of loaded parameters - size_t load(const StateDict& state_dict, - const std::string& name_prefix = std::string()) override; - - // verify whether the weights are loaded, override this method if necessary - bool verify(const std::string& name_prefix = std::string()) const override; - // whether the linear layer is fused bool fused() const { return fused_; } @@ -43,11 +35,9 @@ class FusedColumnParallelLinearImpl : public Module { // fused linear layer ColumnParallelLinear fused_linear_{nullptr}; - // sizes for each split + // size for each split std::vector split_sizes_; - std::vector prefixes_; - // whether the linear layer is fused bool fused_ = false; }; diff --git a/src/layers/linear.cpp b/src/layers/linear.cpp index 70a6d6f8..c013d1c4 100644 --- a/src/layers/linear.cpp +++ b/src/layers/linear.cpp @@ -38,18 +38,6 @@ namespace { parallel_args, \ options); -#define MAKE_ROW_PARALLEL_LINEAR(LinearlImplClass) \ - std::make_shared(in_features, \ - out_features, \ - bias, \ - input_is_parallelized, \ - parallel_args, \ - options); - -#define MAKE_COLUMN_PARALLEL_LINEAR(LinearlImplClass) \ - std::make_shared( \ - in_features, out_features, bias, gather_output, parallel_args, options); - std::shared_ptr create_column_parallel_qlinear_by_impl( int64_t in_features, int64_t out_features, @@ -139,6 +127,7 @@ std::shared_ptr create_column_parallel_qlinear( } // not supported quant method LOG(FATAL) << "Unsupported quant method: " << quant_args.quant_method(); + return nullptr; } std::shared_ptr create_row_parallel_qlinear( @@ -170,6 +159,7 @@ std::shared_ptr create_row_parallel_qlinear( } // not supported quant method LOG(FATAL) << "Unsupported quant method: " << quant_args.quant_method(); + return nullptr; } std::shared_ptr create_column_parallel_linear( @@ -179,7 +169,8 @@ std::shared_ptr create_column_parallel_linear( bool gather_output, const QuantArgs& quant_args, const ParallelArgs& parallel_args, - const torch::TensorOptions& options) { + const torch::TensorOptions& options, + const std::string& prefix) { if (!quant_args.quant_method().empty()) { return create_column_parallel_qlinear(in_features, out_features, @@ -189,7 +180,40 @@ std::shared_ptr create_column_parallel_linear( parallel_args, options); } - return MAKE_COLUMN_PARALLEL_LINEAR(ColumnParallelLinearImpl); + return std ::make_shared(in_features, + out_features, + bias, + gather_output, + parallel_args, + options, + prefix); +} + +std::shared_ptr create_column_parallel_linear( + int64_t in_features, + const std::vector& out_features, + const std::vector& prefixes, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + // if (!quant_args.quant_method().empty()) { + // return create_column_parallel_qlinear(in_features, + // out_features, + // bias, + // gather_output, + // quant_args, + // parallel_args, + // options); + // } + return std ::make_shared(in_features, + out_features, + prefixes, + bias, + gather_output, + parallel_args, + options); } std::shared_ptr create_row_parallel_linear( @@ -209,7 +233,13 @@ std::shared_ptr create_row_parallel_linear( parallel_args, options); } - return MAKE_ROW_PARALLEL_LINEAR(RowParallelLinearImpl); + return std ::make_shared(in_features, + out_features, + bias, + input_is_parallelized, + parallel_args, + options); + ; } } // namespace @@ -221,9 +251,29 @@ ColumnParallelLinear::ColumnParallelLinear(int64_t in_features, bool gather_output, const QuantArgs& quant_args, const ParallelArgs& parallel_args, - const torch::TensorOptions& options) + const torch::TensorOptions& options, + const std::string& prefix) + : ModuleHolder(create_column_parallel_linear(in_features, + out_features, + bias, + gather_output, + quant_args, + parallel_args, + options, + prefix)) {} + +ColumnParallelLinear::ColumnParallelLinear( + int64_t in_features, + const std::vector& out_features, + const std::vector& prefixes, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) : ModuleHolder(create_column_parallel_linear(in_features, out_features, + prefixes, bias, gather_output, quant_args, @@ -242,7 +292,8 @@ ColumnParallelLinear::ColumnParallelLinear(int64_t in_features, gather_output, {}, /*quant_args*/ parallel_args, - options)) {} + options, + "")) {} // construct a rotary positional embedding. // chose right implementation based on the args. diff --git a/src/layers/linear.h b/src/layers/linear.h index c7947f47..4049b32e 100644 --- a/src/layers/linear.h +++ b/src/layers/linear.h @@ -22,9 +22,14 @@ class ParallelLinearImpl : public Module { virtual torch::Tensor forward(torch::Tensor input) = 0; - virtual void load_state_dict(const StateDict& state_dict) = 0; + // TODO: clean up the interface of load_state_dict + virtual void load_state_dict(const StateDict& state_dict) { + LOG(FATAL) << "not implemented"; + } - virtual void verify_loaded_weights(const std::string& prefix = "") const = 0; + virtual void verify_loaded_weights(const std::string& prefix = "") const { + LOG(FATAL) << "not implemented"; + } // special load_state_dict for fused cases virtual void load_state_dict(const StateDict& /*state_dict*/, @@ -46,6 +51,16 @@ class ColumnParallelLinear : public ModuleHolder { bool gather_output, const QuantArgs& quant_args, const ParallelArgs& parallel_args, + const torch::TensorOptions& options, + const std::string& prefix = ""); + + ColumnParallelLinear(int64_t in_features, + const std::vector& out_features, + const std::vector& prefixes, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, const torch::TensorOptions& options); ColumnParallelLinear(int64_t in_features, diff --git a/src/layers/linear_impl.cpp b/src/layers/linear_impl.cpp index ac0176a5..977c94ef 100644 --- a/src/layers/linear_impl.cpp +++ b/src/layers/linear_impl.cpp @@ -4,8 +4,8 @@ #include #include -#include "model_loader/state_dict.h" #include "model_parallel/model_parallel.h" +#include "module/module.h" namespace llm { @@ -16,7 +16,8 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl( bool bias, bool gather_output, const ParallelArgs& parallel_args, - const torch::TensorOptions& options) + const torch::TensorOptions& options, + const std::string& prefix) : gather_output_(gather_output), parallel_args_(parallel_args) { const auto rank = parallel_args_.rank(); const auto world_size = parallel_args_.world_size(); @@ -28,7 +29,7 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl( // Note: torch.nn.functional.linear performs XA^T + b and as a result // we allocate the transpose. weight_ = register_sharded_parameter( - "weight", + detail::join_name(prefix, "weight"), /*dim=*/0, rank, world_size, @@ -36,7 +37,7 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl( if (bias) { bias_ = register_sharded_parameter( - "bias", + detail::join_name(prefix, "bias"), /*dim=*/0, rank, world_size, @@ -53,34 +54,70 @@ torch::Tensor ColumnParallelLinearImpl::forward(torch::Tensor input) { return output; } -// load the weight from the checkpoint -void ColumnParallelLinearImpl::load_state_dict(const StateDict& state_dict) { +FColumnParallelLinearImpl::FColumnParallelLinearImpl( + int64_t in_features, + const std::vector& out_features_vec, + const std::vector& prefixes, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) + : gather_output_(gather_output), parallel_args_(parallel_args) { const auto rank = parallel_args_.rank(); const auto world_size = parallel_args_.world_size(); - // load sharded weights on dim 0 - LOAD_SHARDED_WEIGHT(weight, 0); - - if (bias_.defined()) { - // load sharded bias on dim 0 - LOAD_SHARDED_WEIGHT(bias, 0); + // calculate split size for each prefix + std::vector split_sizes; + split_sizes.reserve(out_features_vec.size()); + for (const auto& out_features : out_features_vec) { + CHECK(out_features % world_size == 0) + << "out_features " << out_features << " not divisible by world_size " + << world_size; + split_sizes.push_back(out_features / world_size); } -} -// special load_state_dict for fused cases -void ColumnParallelLinearImpl::load_state_dict( - const StateDict& state_dict, - const std::vector& prefixes) { - const auto rank = parallel_args_.rank(); - const auto world_size = parallel_args_.world_size(); + const int64_t fused_out_features = + std::accumulate(split_sizes.begin(), split_sizes.end(), int64_t(0)); + + // allocate fused weight + weight_ = torch::empty({fused_out_features, in_features}, options); + const auto weights = weight_.split(split_sizes, /*dim=*/0); + // register sharded weights for each prefix + for (size_t i = 0; i < prefixes.size(); ++i) { + const auto& prefix = prefixes[i]; + const auto& weight = weights[i]; + // register the weight as a parameter to make sure it is moved to the + register_sharded_parameter(detail::join_name(prefix, "weight"), + /*dim=*/0, + rank, + world_size, + weight); + } - // load and merge the weights on dim 0 - LOAD_FUSED_WEIGHT(weight, 0); + if (bias) { + bias_ = torch::empty({fused_out_features}, options); + const auto biases = bias_.split(split_sizes, /*dim=*/0); + + // register sharded weights for each prefix + for (size_t i = 0; i < prefixes.size(); ++i) { + const auto& prefix = prefixes[i]; + const auto& bias = biases[i]; + register_sharded_parameter(detail::join_name(prefix, "bias"), + /*dim=*/0, + rank, + world_size, + bias); + } + } +} - if (bias_.defined()) { - // load and merge the bias on dim 0 - LOAD_FUSED_WEIGHT(bias, 0); +torch::Tensor FColumnParallelLinearImpl::forward(torch::Tensor input) { + namespace F = torch::nn::functional; + auto output = F::linear(input, weight_, bias_); + if (parallel_args_.world_size() > 1 && gather_output_) { + output = gather_from_model_parallel_region(output, parallel_args_); } + return output; } // Linear layer with row parallelism. @@ -128,17 +165,4 @@ torch::Tensor RowParallelLinearImpl::forward(torch::Tensor input) { return output; } -// load the weight from the checkpoint -void RowParallelLinearImpl::load_state_dict(const StateDict& state_dict) { - const auto rank = parallel_args_.rank(); - const auto world_size = parallel_args_.world_size(); - - // load sharded weights on dim 1 - LOAD_SHARDED_WEIGHT(weight, 1); - - if (bias_.defined()) { - LOAD_WEIGHT(bias); - } -} - } // namespace llm diff --git a/src/layers/linear_impl.h b/src/layers/linear_impl.h index de8c70c7..40cacf45 100644 --- a/src/layers/linear_impl.h +++ b/src/layers/linear_impl.h @@ -19,34 +19,47 @@ class ColumnParallelLinearImpl : public ParallelLinearImpl { bool bias, bool gather_output, const ParallelArgs& parallel_args, - const torch::TensorOptions& options); + const torch::TensorOptions& options, + const std::string& prefix = ""); torch::Tensor forward(torch::Tensor input) override; - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) override; + // return the weight (for testing) + torch::Tensor weight() const { return weight_; } - // special load_state_dict for fused cases - void load_state_dict(const StateDict& state_dict, - const std::vector& prefixes) override; + private: + // parameter members, must be registered + // we allocate the transpose since linear performs XA^T. + // A^T: [out_features_per_partition, in_features] + torch::Tensor weight_; + torch::Tensor bias_; - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix) const override { - CHECK(weight_is_loaded_) - << "weight is not loaded for " << prefix + "weight"; - CHECK(!bias_.defined() || bias_is_loaded_) - << "bias is not loaded for " << prefix + "bias"; - } + // whether to gather the output + bool gather_output_; - // return the weight (for testing) - torch::Tensor weight() const { return weight_; } + // parallel args + ParallelArgs parallel_args_; +}; + +// Fused linear layer with column parallelism. +class FColumnParallelLinearImpl : public ParallelLinearImpl { + public: + FColumnParallelLinearImpl(int64_t in_features, + const std::vector& out_features, + const std::vector& prefixes, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); + + torch::Tensor forward(torch::Tensor input) override; private: // parameter members, must be registered // we allocate the transpose since linear performs XA^T. // A^T: [out_features_per_partition, in_features] - DEFINE_FUSED_WEIGHT(weight); - DEFINE_FUSED_WEIGHT(bias); + torch::Tensor weight_; + torch::Tensor bias_; // whether to gather the output bool gather_output_; @@ -76,17 +89,6 @@ class RowParallelLinearImpl : public ParallelLinearImpl { torch::Tensor forward(torch::Tensor input) override; - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) override; - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix = "") const override { - CHECK(weight_is_loaded_) - << "weight is not loaded for " << prefix + "weight"; - CHECK(!bias_.defined() || bias_is_loaded_) - << "bias is not loaded for " << prefix + "bias"; - } - // return the weight (for testing) torch::Tensor weight() const { return weight_; } @@ -94,8 +96,8 @@ class RowParallelLinearImpl : public ParallelLinearImpl { // parameter members, must be registered // we allocate the transpose since linear performs XA^T. // A^T: [out_features, in_features_per_partition] - DEFINE_WEIGHT(weight); - DEFINE_WEIGHT(bias); + torch::Tensor weight_; + torch::Tensor bias_; // whether the input is already parallelized bool input_is_parallelized_; diff --git a/src/module/module.cpp b/src/module/module.cpp index 29d730b7..3e859c64 100644 --- a/src/module/module.cpp +++ b/src/module/module.cpp @@ -7,7 +7,7 @@ namespace llm { using namespace torch; -namespace { +namespace detail { /// Joins names hierarchically: "name_prefix.name" if `name_prefix` is /// non-empty, else just "name". std::string join_name(const std::string& name_prefix, const std::string& name) { @@ -19,12 +19,15 @@ std::string join_name(const std::string& name_prefix, const std::string& name) { full_name.reserve(total_size); if (!name_prefix.empty()) { full_name += name_prefix; - full_name.push_back('.'); + // insert separator if necessary + if (name_prefix.back() != '.') { + full_name.push_back('.'); + } } full_name += name; return full_name; } -} // namespace +} // namespace detail Module::Module() : parameters_("Parameter"), buffers_("Buffer"), children_("Submodule") {} @@ -60,7 +63,8 @@ OrderedDict Module::named_parameters(bool recurse) const { apply([&result](const std::string& name, const Module& module) { for (const auto& parameter : module.named_parameters(/*recurse=*/false)) { TORCH_INTERNAL_ASSERT(parameter.value().defined()); - result.insert(join_name(name, parameter.key()), parameter.value()); + result.insert(detail::join_name(name, parameter.key()), + parameter.value()); } }); } @@ -83,7 +87,7 @@ OrderedDict Module::named_buffers(bool recurse) const { apply([&result](const std::string& name, const Module& module) { for (const auto& buffer : module.named_buffers(/*recurse=*/false)) { TORCH_INTERNAL_ASSERT(buffer.value().defined()); - result.insert(join_name(name, buffer.key()), buffer.value()); + result.insert(detail::join_name(name, buffer.key()), buffer.value()); } }); } @@ -278,7 +282,7 @@ void Module::apply_to_submodules( const NamedModulePointerApplyFunction& function, const std::string& name_prefix) const { for (const auto& child : children_) { - auto qualified_name = join_name(name_prefix, child.key()); + auto qualified_name = detail::join_name(name_prefix, child.key()); function(qualified_name, child.value().module); child.value().module->apply_to_submodules(function, qualified_name); } @@ -331,13 +335,13 @@ size_t Module::load(const StateDict& state_dict, } if (param.is_loaded) { - LOG(WARNING) << "Parameter " << join_name(name_prefix, key) + LOG(WARNING) << "Parameter " << detail::join_name(name_prefix, key) << " is already loaded"; } if (param_tensor.sizes() == tensor.sizes()) { - // LOG(INFO) << "Loading parameter: " << join_name(name_prefix, key) - // << " of size " << tensor.sizes(); + LOG(INFO) << "Loading parameter: " << detail::join_name(name_prefix, key) + << " of size " << tensor.sizes(); // copy data to the parameter tensor param_tensor.copy_(tensor); // mark as loaded @@ -345,7 +349,7 @@ size_t Module::load(const StateDict& state_dict, ++total_loaded; } else { LOG(ERROR) << "Size mismatch for parameter " - << join_name(name_prefix, key) << ": expected " + << detail::join_name(name_prefix, key) << ": expected " << param_tensor.sizes() << ", got " << tensor.sizes(); } } @@ -359,8 +363,8 @@ size_t Module::load(const StateDict& state_dict, if (child.selector) { // select state dict for the child module const auto child_state_dict = child.selector(state_dict, key); - total_loaded += - child.module->load(child_state_dict, join_name(name_prefix, key)); + total_loaded += child.module->load(child_state_dict, + detail::join_name(name_prefix, key)); } else { total_loaded += child.module->load(state_dict, name_prefix); } @@ -376,7 +380,7 @@ bool Module::verify(const std::string& name_prefix) const { const auto& key = item.key(); const auto& param = item.value(); if (!param.is_loaded) { - LOG(ERROR) << "Missing parameter: " << join_name(name_prefix, key) + LOG(ERROR) << "Missing parameter: " << detail::join_name(name_prefix, key) << ", size: " << param.tensor.sizes(); } all_loaded = all_loaded && param.is_loaded; @@ -386,7 +390,7 @@ bool Module::verify(const std::string& name_prefix) const { const auto& key = item.key(); const auto& child = item.value(); const std::string prefix = - child.selector ? join_name(name_prefix, key) : name_prefix; + child.selector ? detail::join_name(name_prefix, key) : name_prefix; const bool child_loaded = child.module->verify(prefix); all_loaded = all_loaded && child_loaded; } diff --git a/src/module/module.h b/src/module/module.h index 6cb96ff6..2d21610d 100644 --- a/src/module/module.h +++ b/src/module/module.h @@ -15,6 +15,12 @@ namespace llm { +namespace detail { +/// Joins names hierarchically: "name_prefix.name" if `name_prefix` is +/// non-empty, else just "name". +std::string join_name(const std::string& name_prefix, const std::string& name); +} // namespace detail + /// The base class for all modules. /// /// A `Module` is an abstraction over the implementation of some function or From 36449ade49b619c8364dfd26bcc1dd8367111f8a Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 7 Oct 2025 15:25:59 -0700 Subject: [PATCH 04/13] rename. --- src/layers/CMakeLists.txt | 14 +- src/layers/fused_linear.cpp | 85 ---------- src/layers/linear.cpp | 79 ++++------ src/layers/linear.h | 18 +-- src/layers/linear_test.cpp | 148 +++++++++--------- src/layers/multi_parallel_linear.cpp | 56 +++++++ ...fused_linear.h => multi_parallel_linear.h} | 17 +- .../{linear_impl.cpp => parallel_linear.cpp} | 58 +++++-- .../{linear_impl.h => parallel_linear.h} | 47 ++++-- ...qkv_linear.cpp => qkv_parallel_linear.cpp} | 4 +- .../{qkv_linear.h => qkv_parallel_linear.h} | 4 +- ..._test.cpp => qkv_parallel_linear_test.cpp} | 2 +- src/models/alibaba/qwen.h | 6 +- src/models/alibaba/qwen2.h | 6 +- src/models/google/gemma.h | 8 +- src/models/google/gemma2.h | 6 +- src/models/meta/llama.h | 13 +- src/models/openai/gpt2.h | 2 +- src/models/registered_models.h | 16 +- src/module/module.cpp | 5 +- src/quantization/qlinear_impl.cpp | 2 +- src/quantization/qlinear_impl.h | 2 +- 22 files changed, 309 insertions(+), 289 deletions(-) delete mode 100644 src/layers/fused_linear.cpp create mode 100644 src/layers/multi_parallel_linear.cpp rename src/layers/{fused_linear.h => multi_parallel_linear.h} (72%) rename src/layers/{linear_impl.cpp => parallel_linear.cpp} (72%) rename src/layers/{linear_impl.h => parallel_linear.h} (64%) rename src/layers/{qkv_linear.cpp => qkv_parallel_linear.cpp} (97%) rename src/layers/{qkv_linear.h => qkv_parallel_linear.h} (93%) rename src/layers/{qkv_linear_test.cpp => qkv_parallel_linear_test.cpp} (99%) diff --git a/src/layers/CMakeLists.txt b/src/layers/CMakeLists.txt index f178ce47..a96b85cb 100644 --- a/src/layers/CMakeLists.txt +++ b/src/layers/CMakeLists.txt @@ -6,15 +6,15 @@ cc_library( linear HDRS linear.h - qkv_linear.h - linear_impl.h - fused_linear.h + qkv_parallel_linear.h + parallel_linear.h + multi_parallel_linear.h weight_utils.h SRCS linear.cpp - qkv_linear.cpp - linear_impl.cpp - fused_linear.cpp + qkv_parallel_linear.cpp + parallel_linear.cpp + multi_parallel_linear.cpp weight_utils.cpp DEPS :state_dict @@ -74,7 +74,7 @@ cc_test( pos_embedding_test.cpp normalization_test.cpp linear_test.cpp - qkv_linear_test.cpp + qkv_parallel_linear_test.cpp DEPS :layers :state_dict diff --git a/src/layers/fused_linear.cpp b/src/layers/fused_linear.cpp deleted file mode 100644 index 9b3034dc..00000000 --- a/src/layers/fused_linear.cpp +++ /dev/null @@ -1,85 +0,0 @@ -#include "fused_linear.h" - -#include -#include - -#include "linear.h" -#include "model_loader/state_dict.h" -#include "model_parallel/parallel_args.h" -#include "quantization/quant_args.h" - -namespace llm { - -FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl( - int64_t in_features, - const std::vector& out_features_vec, - const std::vector& prefixes, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) { - // check if the linear layers can be fused - fused_ = quant_args.can_be_fused(); - if (fused_) { - // fused linear layer - fused_linear_ = register_module("fused_linear", - ColumnParallelLinear(in_features, - out_features_vec, - prefixes, - bias, - gather_output, - quant_args, - parallel_args, - options), - /*selector=*/nullptr); - // TODO: clean up following code for calculating split sizes - // calculate split sizes - split_sizes_.reserve(out_features_vec.size()); - const auto world_size = parallel_args.world_size(); - for (const auto& out_features : out_features_vec) { - CHECK(out_features % world_size == 0) - << "out_features " << out_features << " not divisible by world_size " - << world_size; - split_sizes_.push_back(out_features / world_size); - } - } else { - // non-fused linear layers - parallel_linears_.reserve(out_features_vec.size()); - for (size_t i = 0; i < out_features_vec.size(); ++i) { - const auto& prefix = prefixes[i]; - const auto out_features = out_features_vec[i]; - - const auto linear = register_module("linear", - ColumnParallelLinear(in_features, - out_features, - bias, - gather_output, - quant_args, - parallel_args, - options, - prefix), - /*selector=*/nullptr); - - parallel_linears_.emplace_back(linear); - } - } -} - -std::vector FusedColumnParallelLinearImpl::forward( - torch::Tensor input) { - if (fused_) { - auto fused_output = fused_linear_->forward(input); - return fused_output.split(split_sizes_, /*dim=*/1); - } - - // otherwise, use the non-fused linear layers - std::vector outputs; - outputs.reserve(parallel_linears_.size()); - for (auto& parallel_linear : parallel_linears_) { - auto output = parallel_linear->forward(input); - outputs.push_back(output); - } - return outputs; -} -} // namespace llm diff --git a/src/layers/linear.cpp b/src/layers/linear.cpp index c013d1c4..8c1526c8 100644 --- a/src/layers/linear.cpp +++ b/src/layers/linear.cpp @@ -6,7 +6,7 @@ #include #include -#include "linear_impl.h" +#include "parallel_linear.h" #include "quantization/qlinear_awq_impl.h" #include "quantization/qlinear_awq_marlin_impl.h" #include "quantization/qlinear_exllamav2_impl.h" @@ -189,33 +189,6 @@ std::shared_ptr create_column_parallel_linear( prefix); } -std::shared_ptr create_column_parallel_linear( - int64_t in_features, - const std::vector& out_features, - const std::vector& prefixes, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) { - // if (!quant_args.quant_method().empty()) { - // return create_column_parallel_qlinear(in_features, - // out_features, - // bias, - // gather_output, - // quant_args, - // parallel_args, - // options); - // } - return std ::make_shared(in_features, - out_features, - prefixes, - bias, - gather_output, - parallel_args, - options); -} - std::shared_ptr create_row_parallel_linear( int64_t in_features, int64_t out_features, @@ -239,8 +212,38 @@ std::shared_ptr create_row_parallel_linear( input_is_parallelized, parallel_args, options); - ; } + +// std::shared_ptr create_multi_column_parallel_linear( +// int64_t in_features, +// const std::vector& out_features, +// const std::vector& prefixes, +// bool bias, +// bool gather_output, +// const QuantArgs& quant_args, +// const ParallelArgs& parallel_args, +// const torch::TensorOptions& options) { +// // check if the linear layers can be fused +// const bool fused = quant_args.can_be_fused(); +// std::shared_ptr impl; +// if (fused) { +// return std::make_shared(in_features, +// out_features, +// prefixes, +// bias, +// gather_output, +// parallel_args, +// options); +// } + +// return std::make_shared(in_features, +// out_features, +// prefixes, +// bias, +// gather_output, +// parallel_args, +// options); +// } } // namespace // construct a ColumnParallelLinear. @@ -262,24 +265,6 @@ ColumnParallelLinear::ColumnParallelLinear(int64_t in_features, options, prefix)) {} -ColumnParallelLinear::ColumnParallelLinear( - int64_t in_features, - const std::vector& out_features, - const std::vector& prefixes, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) - : ModuleHolder(create_column_parallel_linear(in_features, - out_features, - prefixes, - bias, - gather_output, - quant_args, - parallel_args, - options)) {} - ColumnParallelLinear::ColumnParallelLinear(int64_t in_features, int64_t out_features, bool bias, diff --git a/src/layers/linear.h b/src/layers/linear.h index 4049b32e..606cdfdd 100644 --- a/src/layers/linear.h +++ b/src/layers/linear.h @@ -37,6 +37,15 @@ class ParallelLinearImpl : public Module { LOG(FATAL) << "not implemented"; } }; +LLM_MODULE(ParallelLinear); + +class MultiParallelLinearImpl : public Module { + public: + ~MultiParallelLinearImpl() override = default; + + virtual std::vector forward(torch::Tensor input) = 0; +}; +LLM_MODULE(MultiParallelLinear); class ColumnParallelLinear : public ModuleHolder { public: @@ -54,15 +63,6 @@ class ColumnParallelLinear : public ModuleHolder { const torch::TensorOptions& options, const std::string& prefix = ""); - ColumnParallelLinear(int64_t in_features, - const std::vector& out_features, - const std::vector& prefixes, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); - ColumnParallelLinear(int64_t in_features, int64_t out_features, bool bias, diff --git a/src/layers/linear_test.cpp b/src/layers/linear_test.cpp index 3306ebf2..442d0bd1 100644 --- a/src/layers/linear_test.cpp +++ b/src/layers/linear_test.cpp @@ -8,8 +8,8 @@ #include #include -#include "linear_impl.h" #include "model_loader/state_dict.h" +#include "parallel_linear.h" namespace llm { @@ -40,7 +40,7 @@ TEST(LinearTest, RowParallelLoadWeight) { parallel_args, options); // test load state dict for transformer - linear.load_state_dict(state_dict); + EXPECT_EQ(linear.load(state_dict), 2); auto named_parameters = linear.named_parameters(/*recurse=*/false); EXPECT_TRUE(torch::equal(state_dict.get_tensor("weight"), named_parameters["weight"])); @@ -58,7 +58,7 @@ TEST(LinearTest, RowParallelLoadWeight) { /*input_is_parallelized=*/false, parallel_args, options); - linear.load_state_dict(state_dict); + EXPECT_EQ(linear.load(state_dict), 2); auto named_parameters = linear.named_parameters(/*recurse=*/false); @@ -78,30 +78,6 @@ TEST(LinearTest, RowParallelLoadWeight) { } } -TEST(LinearTest, RowParallelLoadFusedWeight) { - // test load state dict for row parallel linear - const int64_t in_features = 10; - const int64_t out_features = 20; - - torch::Device device(torch::kCPU); - torch::ScalarType dtype(torch::kFloat); - const auto options = torch::dtype(dtype).device(device); - StateDict state_dict({}); - - // test load weight - ParallelArgs parallel_args(0, 1, nullptr); - RowParallelLinearImpl linear(in_features, - out_features * 3, - /*bias=*/false, - /*input_is_parallelized=*/true, - parallel_args, - options); - // test load fused weight - EXPECT_DEATH( - linear.ParallelLinearImpl::load_state_dict(state_dict, {"query."}), - "not implemented"); -} - TEST(LinearTest, ColumnParallelLoadWeight) { // test load state dict for linear const int64_t in_features = 10; @@ -128,7 +104,7 @@ TEST(LinearTest, ColumnParallelLoadWeight) { parallel_args, options); // test load state dict for transformer - linear.load_state_dict(state_dict); + EXPECT_EQ(linear.load(state_dict), 1); auto named_parameters = linear.named_parameters(/*recurse=*/false); EXPECT_TRUE(torch::equal(state_dict.get_tensor("weight"), named_parameters["weight"])); @@ -143,7 +119,7 @@ TEST(LinearTest, ColumnParallelLoadWeight) { /*gather_output=*/false, parallel_args, options); - linear.load_state_dict(state_dict); + EXPECT_EQ(linear.load(state_dict), 1); auto named_parameters = linear.named_parameters(/*recurse=*/false); @@ -162,11 +138,16 @@ TEST(LinearTest, ColumnParallelLoadWeight) { TEST(LinearTest, ColumnParallelLoadFusedWeight) { // test load state dict for linear const int64_t in_features = 10; - const int64_t out_features = 20; + const int64_t out_features = 40; torch::Device device(torch::kCPU); torch::ScalarType dtype(torch::kFloat); const auto options = torch::dtype(dtype).device(device); + + std::vector out_features_vec = { + out_features, out_features, out_features}; + std::vector prefixes = {"query.", "key.", "value."}; + std::unordered_map state_dict_data; // Allocate transposed weight matrix state_dict_data["query.weight"] = torch::randn({out_features, in_features}); @@ -179,59 +160,80 @@ TEST(LinearTest, ColumnParallelLoadFusedWeight) { // test load weight { ParallelArgs parallel_args(0, 1, nullptr); - ColumnParallelLinearImpl linear(in_features, - out_features * 3, - /*bias=*/false, - /*gather_output=*/false, - parallel_args, - options); + FusedColumnParallelLinearImpl linear(in_features, + out_features_vec, + prefixes, + /*bias=*/false, + /*gather_output=*/false, + parallel_args, + options); // test load fused weight - linear.load_state_dict(state_dict, {"query.", "key.", "value."}); - - auto named_parameters = linear.named_parameters(/*recurse=*/false); - ASSERT_TRUE(named_parameters.contains("weight")); - - const auto loaded_weight = named_parameters["weight"]; - EXPECT_EQ(loaded_weight.sizes(), - torch::IntArrayRef({3 * out_features, in_features})); - - auto desired_weight = torch::cat({state_dict_data["query.weight"], - state_dict_data["key.weight"], - state_dict_data["value.weight"]}, - /*dim=*/0); - EXPECT_TRUE(torch::equal(loaded_weight, desired_weight)); + EXPECT_EQ(linear.load(state_dict), 3); + + for (const auto& prefix : prefixes) { + auto named_parameters = linear.named_parameters(/*recurse=*/false); + const auto key = detail::join_name(prefix, "weight"); + ASSERT_TRUE(named_parameters.contains(key)); + + const auto& loaded_weight = named_parameters[key]; + EXPECT_EQ(loaded_weight.sizes(), + torch::IntArrayRef({out_features, in_features})); + EXPECT_TRUE(torch::equal(loaded_weight, state_dict_data[key])); + } + + // verify the fused weight + const auto loaded_fused_weight = linear.weight(); + const auto desired_fused_weight = + torch::cat({state_dict_data["query.weight"], + state_dict_data["key.weight"], + state_dict_data["value.weight"]}, + /*dim=*/0); + EXPECT_TRUE(torch::equal(loaded_fused_weight, desired_fused_weight)); } - // test load weight with 2 shards - const int32_t num_shards = 2; + // test load weight with 4 shards + const int32_t num_shards = 4; for (int32_t shard_id = 0; shard_id < num_shards; ++shard_id) { ParallelArgs parallel_args_0(shard_id, num_shards, nullptr); - ColumnParallelLinearImpl linear(in_features, - out_features * 3, - /*bias=*/false, - /*gather_output=*/false, - parallel_args_0, - options); - linear.load_state_dict(state_dict, {"query.", "key.", "value."}); + FusedColumnParallelLinearImpl linear(in_features, + out_features_vec, + prefixes, + /*bias=*/false, + /*gather_output=*/false, + parallel_args_0, + options); + EXPECT_EQ(linear.load(state_dict), 3); auto named_parameters = linear.named_parameters(/*recurse=*/false); - const auto loaded_weight = named_parameters["weight"]; - EXPECT_EQ(loaded_weight.sizes(), - torch::IntArrayRef({3 * out_features / num_shards, in_features})); + // check size for each prefix + for (const auto& prefix : prefixes) { + auto named_parameters = linear.named_parameters(/*recurse=*/false); + const auto key = detail::join_name(prefix, "weight"); + ASSERT_TRUE(named_parameters.contains(key)); - // shard weight then cat - auto query_weight = state_dict_data["query.weight"].chunk( - /*chunks=*/num_shards, /*dim=*/0)[shard_id]; - auto key_weight = state_dict_data["key.weight"].chunk(/*chunks=*/num_shards, - /*dim=*/0)[shard_id]; - auto value_weight = state_dict_data["value.weight"].chunk( - /*chunks=*/num_shards, /*dim=*/0)[shard_id]; - - auto desired_weight = torch::cat({query_weight, key_weight, value_weight}, - /*dim=*/0); + const auto& loaded_weight = named_parameters[key]; + EXPECT_EQ(loaded_weight.sizes(), + torch::IntArrayRef({out_features / num_shards, in_features})); + EXPECT_TRUE(torch::equal( + loaded_weight, state_dict_data[key].chunk(num_shards, 0)[shard_id])); + } - EXPECT_TRUE(torch::equal(loaded_weight, desired_weight)); + // shard weight then cat + auto sharded_query_weight = + state_dict_data["query.weight"].chunk(num_shards, 0)[shard_id]; + auto sharded_key_weight = + state_dict_data["key.weight"].chunk(num_shards, 0)[shard_id]; + auto sharded_value_weight = + state_dict_data["value.weight"].chunk(num_shards, 0)[shard_id]; + + // verify the fused weight + const auto loaded_fused_weight = linear.weight(); + auto desired_fused_weight = torch::cat( + {sharded_query_weight, sharded_key_weight, sharded_value_weight}, + /*dim=*/0); + + EXPECT_TRUE(torch::equal(loaded_fused_weight, desired_fused_weight)); } } diff --git a/src/layers/multi_parallel_linear.cpp b/src/layers/multi_parallel_linear.cpp new file mode 100644 index 00000000..413aa48e --- /dev/null +++ b/src/layers/multi_parallel_linear.cpp @@ -0,0 +1,56 @@ +#include "multi_parallel_linear.h" + +#include +#include + +#include "model_parallel/parallel_args.h" +#include "parallel_linear.h" +#include "quantization/quant_args.h" + +namespace llm { + +MultiColumnParallelLinearImpl::MultiColumnParallelLinearImpl( + int64_t in_features, + const std::vector& out_features_vec, + const std::vector& prefixes, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + // check if the linear layers can be fused + fused_ = quant_args.can_be_fused(); + if (fused_) { + // fused linear layer + fused_linear_ = register_module("fused_linear", + FusedColumnParallelLinear(in_features, + out_features_vec, + prefixes, + bias, + gather_output, + parallel_args, + options), + /*selector=*/nullptr); + } else { + // non-fused linear layers + grouped_linear_ = + register_module("grouped_linear", + GroupedColumnParallelLinear(in_features, + out_features_vec, + prefixes, + bias, + gather_output, + parallel_args, + options), + /*selector=*/nullptr); + } +} + +std::vector MultiColumnParallelLinearImpl::forward( + torch::Tensor input) { + if (fused_) { + return fused_linear_(input); + } + return grouped_linear_(input); +} +} // namespace llm diff --git a/src/layers/fused_linear.h b/src/layers/multi_parallel_linear.h similarity index 72% rename from src/layers/fused_linear.h rename to src/layers/multi_parallel_linear.h index 86d1f116..6a78308b 100644 --- a/src/layers/fused_linear.h +++ b/src/layers/multi_parallel_linear.h @@ -3,18 +3,18 @@ #include #include -#include "linear.h" -#include "model_loader/state_dict.h" +// #include "linear.h" #include "model_parallel/parallel_args.h" #include "module/module.h" #include "module/module_holder.h" +#include "parallel_linear.h" #include "quantization/quant_args.h" namespace llm { -class FusedColumnParallelLinearImpl : public Module { +class MultiColumnParallelLinearImpl : public Module { public: - FusedColumnParallelLinearImpl(int64_t in_features, + MultiColumnParallelLinearImpl(int64_t in_features, const std::vector& out_features, const std::vector& prefixes, bool bias, @@ -30,17 +30,14 @@ class FusedColumnParallelLinearImpl : public Module { private: // non-fused linear layers - std::vector parallel_linears_; + GroupedColumnParallelLinear grouped_linear_{nullptr}; // fused linear layer - ColumnParallelLinear fused_linear_{nullptr}; - - // size for each split - std::vector split_sizes_; + FusedColumnParallelLinear fused_linear_{nullptr}; // whether the linear layer is fused bool fused_ = false; }; -LLM_MODULE(FusedColumnParallelLinear); +LLM_MODULE(MultiColumnParallelLinear); } // namespace llm diff --git a/src/layers/linear_impl.cpp b/src/layers/parallel_linear.cpp similarity index 72% rename from src/layers/linear_impl.cpp rename to src/layers/parallel_linear.cpp index 977c94ef..f0117c0f 100644 --- a/src/layers/linear_impl.cpp +++ b/src/layers/parallel_linear.cpp @@ -1,4 +1,4 @@ -#include "linear_impl.h" +#include "parallel_linear.h" #include #include @@ -54,7 +54,7 @@ torch::Tensor ColumnParallelLinearImpl::forward(torch::Tensor input) { return output; } -FColumnParallelLinearImpl::FColumnParallelLinearImpl( +FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl( int64_t in_features, const std::vector& out_features_vec, const std::vector& prefixes, @@ -67,21 +67,20 @@ FColumnParallelLinearImpl::FColumnParallelLinearImpl( const auto world_size = parallel_args_.world_size(); // calculate split size for each prefix - std::vector split_sizes; - split_sizes.reserve(out_features_vec.size()); + split_sizes_.reserve(out_features_vec.size()); for (const auto& out_features : out_features_vec) { CHECK(out_features % world_size == 0) << "out_features " << out_features << " not divisible by world_size " << world_size; - split_sizes.push_back(out_features / world_size); + split_sizes_.push_back(out_features / world_size); } const int64_t fused_out_features = - std::accumulate(split_sizes.begin(), split_sizes.end(), int64_t(0)); + std::accumulate(split_sizes_.begin(), split_sizes_.end(), int64_t(0)); // allocate fused weight weight_ = torch::empty({fused_out_features, in_features}, options); - const auto weights = weight_.split(split_sizes, /*dim=*/0); + const auto weights = weight_.split(split_sizes_, /*dim=*/0); // register sharded weights for each prefix for (size_t i = 0; i < prefixes.size(); ++i) { const auto& prefix = prefixes[i]; @@ -96,7 +95,7 @@ FColumnParallelLinearImpl::FColumnParallelLinearImpl( if (bias) { bias_ = torch::empty({fused_out_features}, options); - const auto biases = bias_.split(split_sizes, /*dim=*/0); + const auto biases = bias_.split(split_sizes_, /*dim=*/0); // register sharded weights for each prefix for (size_t i = 0; i < prefixes.size(); ++i) { @@ -111,13 +110,52 @@ FColumnParallelLinearImpl::FColumnParallelLinearImpl( } } -torch::Tensor FColumnParallelLinearImpl::forward(torch::Tensor input) { +std::vector FusedColumnParallelLinearImpl::forward( + torch::Tensor input) { namespace F = torch::nn::functional; auto output = F::linear(input, weight_, bias_); if (parallel_args_.world_size() > 1 && gather_output_) { output = gather_from_model_parallel_region(output, parallel_args_); } - return output; + return output.split(split_sizes_, /*dim=*/1); +} + +GroupedColumnParallelLinearImpl::GroupedColumnParallelLinearImpl( + int64_t in_features, + const std::vector& out_features_vec, + const std::vector& prefixes, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + // register linear layers one by one + parallel_linears_.reserve(out_features_vec.size()); + for (size_t i = 0; i < out_features_vec.size(); ++i) { + const auto& prefix = prefixes[i]; + const auto out_features = out_features_vec[i]; + const auto linear = register_module( + "linear_" + std::to_string(i), + std::make_shared(in_features, + out_features, + bias, + gather_output, + parallel_args, + options, + prefix), + /*selector=*/nullptr); + + parallel_linears_.emplace_back(linear); + } +} + +std::vector GroupedColumnParallelLinearImpl::forward( + torch::Tensor input) { + std::vector outputs; + outputs.reserve(parallel_linears_.size()); + for (auto& parallel_linear : parallel_linears_) { + outputs.push_back(parallel_linear->forward(input)); + } + return outputs; } // Linear layer with row parallelism. diff --git a/src/layers/linear_impl.h b/src/layers/parallel_linear.h similarity index 64% rename from src/layers/linear_impl.h rename to src/layers/parallel_linear.h index 40cacf45..d868bbb2 100644 --- a/src/layers/linear_impl.h +++ b/src/layers/parallel_linear.h @@ -3,9 +3,10 @@ #include #include +#include + #include "linear.h" -#include "model_loader/state_dict.h" -#include "weight_utils.h" +#include "module/module_holder.h" namespace llm { @@ -42,17 +43,20 @@ class ColumnParallelLinearImpl : public ParallelLinearImpl { }; // Fused linear layer with column parallelism. -class FColumnParallelLinearImpl : public ParallelLinearImpl { +class FusedColumnParallelLinearImpl : public MultiParallelLinearImpl { public: - FColumnParallelLinearImpl(int64_t in_features, - const std::vector& out_features, - const std::vector& prefixes, - bool bias, - bool gather_output, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); + FusedColumnParallelLinearImpl(int64_t in_features, + const std::vector& out_features, + const std::vector& prefixes, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); - torch::Tensor forward(torch::Tensor input) override; + std::vector forward(torch::Tensor input) override; + + // return the weight (for testing) + torch::Tensor weight() const { return weight_; } private: // parameter members, must be registered @@ -61,12 +65,33 @@ class FColumnParallelLinearImpl : public ParallelLinearImpl { torch::Tensor weight_; torch::Tensor bias_; + std::vector split_sizes_; + // whether to gather the output bool gather_output_; // parallel args ParallelArgs parallel_args_; }; +LLM_MODULE(FusedColumnParallelLinear); + +class GroupedColumnParallelLinearImpl : public MultiParallelLinearImpl { + public: + GroupedColumnParallelLinearImpl(int64_t in_features, + const std::vector& out_features, + const std::vector& prefixes, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); + + std::vector forward(torch::Tensor input) override; + + private: + // parameter members, must be registered + std::vector> parallel_linears_; +}; +LLM_MODULE(GroupedColumnParallelLinear); // Linear layer with row parallelism. // The linear layer is defined as Y = XA + b. A is parallelized along diff --git a/src/layers/qkv_linear.cpp b/src/layers/qkv_parallel_linear.cpp similarity index 97% rename from src/layers/qkv_linear.cpp rename to src/layers/qkv_parallel_linear.cpp index 8ef9b532..055c8ec0 100644 --- a/src/layers/qkv_linear.cpp +++ b/src/layers/qkv_parallel_linear.cpp @@ -1,4 +1,4 @@ -#include "qkv_linear.h" +#include "qkv_parallel_linear.h" #include #include @@ -70,7 +70,7 @@ QKVColumnParallelLinearImpl::QKVColumnParallelLinearImpl( }; parallel_linear_ = register_module("qkv_parallel_linear", - FusedColumnParallelLinear(hidden_size, + MultiColumnParallelLinear(hidden_size, out_features, prefixes, bias, diff --git a/src/layers/qkv_linear.h b/src/layers/qkv_parallel_linear.h similarity index 93% rename from src/layers/qkv_linear.h rename to src/layers/qkv_parallel_linear.h index 9a5dffff..29a648b8 100644 --- a/src/layers/qkv_linear.h +++ b/src/layers/qkv_parallel_linear.h @@ -3,11 +3,11 @@ #include #include -#include "fused_linear.h" #include "model_loader/state_dict.h" #include "model_parallel/parallel_args.h" #include "module/module.h" #include "module/module_holder.h" +#include "multi_parallel_linear.h" #include "quantization/quant_args.h" namespace llm { @@ -36,7 +36,7 @@ class QKVColumnParallelLinearImpl : public Module { private: // registered modules - FusedColumnParallelLinear parallel_linear_{nullptr}; + MultiColumnParallelLinear parallel_linear_{nullptr}; }; LLM_MODULE(QKVColumnParallelLinear); diff --git a/src/layers/qkv_linear_test.cpp b/src/layers/qkv_parallel_linear_test.cpp similarity index 99% rename from src/layers/qkv_linear_test.cpp rename to src/layers/qkv_parallel_linear_test.cpp index 6f1a5e05..a45b1eb5 100644 --- a/src/layers/qkv_linear_test.cpp +++ b/src/layers/qkv_parallel_linear_test.cpp @@ -1,4 +1,4 @@ -#include "qkv_linear.h" +#include "qkv_parallel_linear.h" #include #include diff --git a/src/models/alibaba/qwen.h b/src/models/alibaba/qwen.h index 61cf7490..5861278a 100644 --- a/src/models/alibaba/qwen.h +++ b/src/models/alibaba/qwen.h @@ -10,8 +10,8 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/fused_linear.h" #include "layers/linear.h" +#include "layers/multi_parallel_linear.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" @@ -41,7 +41,7 @@ class QWenMLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - FusedColumnParallelLinear( + MultiColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, std::vector{"w1.", "w2."}, @@ -68,7 +68,7 @@ class QWenMLPImpl : public Module { private: // parameter members, must be registered - FusedColumnParallelLinear gate_up_proj_{nullptr}; + MultiColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear c_proj_{nullptr}; ActFunc act_{nullptr}; diff --git a/src/models/alibaba/qwen2.h b/src/models/alibaba/qwen2.h index 9f1103b1..272eec70 100644 --- a/src/models/alibaba/qwen2.h +++ b/src/models/alibaba/qwen2.h @@ -12,7 +12,7 @@ #include "layers/embedding.h" #include "layers/linear.h" #include "layers/normalization.h" -#include "layers/qkv_linear.h" +#include "layers/qkv_parallel_linear.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" @@ -40,7 +40,7 @@ class QWen2MLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - FusedColumnParallelLinear( + MultiColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, std::vector{"gate_proj.", "up_proj."}, @@ -68,7 +68,7 @@ class QWen2MLPImpl : public Module { private: // parameter members, must be registered - FusedColumnParallelLinear gate_up_proj_{nullptr}; + MultiColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; // activation function diff --git a/src/models/google/gemma.h b/src/models/google/gemma.h index 555ed279..3f9b6180 100644 --- a/src/models/google/gemma.h +++ b/src/models/google/gemma.h @@ -11,9 +11,9 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear.h" -#include "layers/linear_impl.h" #include "layers/normalization.h" -#include "layers/qkv_linear.h" +#include "layers/parallel_linear.h" +#include "layers/qkv_parallel_linear.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" @@ -40,7 +40,7 @@ class GemmaMLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - FusedColumnParallelLinear( + MultiColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, std::vector{"gate_proj.", "up_proj."}, @@ -68,7 +68,7 @@ class GemmaMLPImpl : public Module { private: // parameter members, must be registered - FusedColumnParallelLinear gate_up_proj_{nullptr}; + MultiColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; // activation function diff --git a/src/models/google/gemma2.h b/src/models/google/gemma2.h index f93d3ac4..f0f77253 100644 --- a/src/models/google/gemma2.h +++ b/src/models/google/gemma2.h @@ -12,7 +12,7 @@ #include "layers/embedding.h" #include "layers/linear.h" #include "layers/normalization.h" -#include "layers/qkv_linear.h" +#include "layers/qkv_parallel_linear.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" @@ -39,7 +39,7 @@ class Gemma2MLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - FusedColumnParallelLinear( + MultiColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, std::vector{"gate_proj.", "up_proj."}, @@ -67,7 +67,7 @@ class Gemma2MLPImpl : public Module { private: // parameter members, must be registered - FusedColumnParallelLinear gate_up_proj_{nullptr}; + MultiColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; // activation function diff --git a/src/models/meta/llama.h b/src/models/meta/llama.h index 382c3f8a..d61e391f 100644 --- a/src/models/meta/llama.h +++ b/src/models/meta/llama.h @@ -8,10 +8,10 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/fused_linear.h" #include "layers/linear.h" +#include "layers/multi_parallel_linear.h" #include "layers/normalization.h" -#include "layers/qkv_linear.h" +#include "layers/qkv_parallel_linear.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" @@ -37,7 +37,7 @@ class LlamaMLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - FusedColumnParallelLinear( + MultiColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, std::vector{"gate_proj.", "up_proj."}, @@ -60,13 +60,14 @@ class LlamaMLPImpl : public Module { } torch::Tensor forward(torch::Tensor x) { - const auto gate_up = gate_up_proj_(x); - return down_proj_(act_func_(gate_up[0]) * gate_up[1]); + // const auto gate_up = gate_up_proj_(x); + // return down_proj_(act_func_(gate_up[0]) * gate_up[1]); + return {}; } private: // parameter members, must be registered - FusedColumnParallelLinear gate_up_proj_{nullptr}; + MultiColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; // activation function diff --git a/src/models/openai/gpt2.h b/src/models/openai/gpt2.h index 2644ec48..ef5e5e37 100644 --- a/src/models/openai/gpt2.h +++ b/src/models/openai/gpt2.h @@ -8,8 +8,8 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear.h" -#include "layers/linear_impl.h" #include "layers/normalization.h" +#include "layers/parallel_linear.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" diff --git a/src/models/registered_models.h b/src/models/registered_models.h index fba44c42..12b78f06 100644 --- a/src/models/registered_models.h +++ b/src/models/registered_models.h @@ -2,17 +2,17 @@ // list all registered models here // Google -#include "google/gemma.h" // IWYU pragma: keep -#include "google/gemma2.h" // IWYU pragma: keep -// OpenAI -#include "openai/gpt2.h" // IWYU pragma: keep +// #include "google/gemma.h" // IWYU pragma: keep +// #include "google/gemma2.h" // IWYU pragma: keep +// // OpenAI +// #include "openai/gpt2.h" // IWYU pragma: keep // Meta #include "meta/llama.h" // IWYU pragma: keep // Microsoft -#include "microsoft/phi.h" // IWYU pragma: keep -// Alibaba -#include "alibaba/qwen.h" // IWYU pragma: keep -#include "alibaba/qwen2.h" // IWYU pragma: keep +// #include "microsoft/phi.h" // IWYU pragma: keep +// // Alibaba +// #include "alibaba/qwen.h" // IWYU pragma: keep +// #include "alibaba/qwen2.h" // IWYU pragma: keep // Deprecated models // #include "deprecated/aquila.h" diff --git a/src/module/module.cpp b/src/module/module.cpp index 3e859c64..a9082e10 100644 --- a/src/module/module.cpp +++ b/src/module/module.cpp @@ -340,8 +340,9 @@ size_t Module::load(const StateDict& state_dict, } if (param_tensor.sizes() == tensor.sizes()) { - LOG(INFO) << "Loading parameter: " << detail::join_name(name_prefix, key) - << " of size " << tensor.sizes(); + // LOG(INFO) << "Loading parameter: " << detail::join_name(name_prefix, + // key) + // << " of size " << tensor.sizes(); // copy data to the parameter tensor param_tensor.copy_(tensor); // mark as loaded diff --git a/src/quantization/qlinear_impl.cpp b/src/quantization/qlinear_impl.cpp index 01071296..c16b5cf3 100644 --- a/src/quantization/qlinear_impl.cpp +++ b/src/quantization/qlinear_impl.cpp @@ -4,7 +4,7 @@ #include #include -#include "layers/linear_impl.h" +#include "layers/linear.h" #include "model_loader/state_dict.h" namespace llm { diff --git a/src/quantization/qlinear_impl.h b/src/quantization/qlinear_impl.h index 0269246f..129330cf 100644 --- a/src/quantization/qlinear_impl.h +++ b/src/quantization/qlinear_impl.h @@ -3,7 +3,7 @@ #include #include -#include "layers/linear_impl.h" +#include "layers/linear.h" #include "layers/weight_utils.h" #include "model_loader/state_dict.h" #include "model_parallel/model_parallel.h" From 93c4409f47c255dfc97f92277b348151ce454c8b Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 7 Oct 2025 15:32:57 -0700 Subject: [PATCH 05/13] revert --- src/models/meta/llama.h | 5 ++--- src/models/registered_models.h | 16 ++++++++-------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/models/meta/llama.h b/src/models/meta/llama.h index d61e391f..75fac872 100644 --- a/src/models/meta/llama.h +++ b/src/models/meta/llama.h @@ -60,9 +60,8 @@ class LlamaMLPImpl : public Module { } torch::Tensor forward(torch::Tensor x) { - // const auto gate_up = gate_up_proj_(x); - // return down_proj_(act_func_(gate_up[0]) * gate_up[1]); - return {}; + const auto gate_up = gate_up_proj_(x); + return down_proj_(act_func_(gate_up[0]) * gate_up[1]); } private: diff --git a/src/models/registered_models.h b/src/models/registered_models.h index 12b78f06..fba44c42 100644 --- a/src/models/registered_models.h +++ b/src/models/registered_models.h @@ -2,17 +2,17 @@ // list all registered models here // Google -// #include "google/gemma.h" // IWYU pragma: keep -// #include "google/gemma2.h" // IWYU pragma: keep -// // OpenAI -// #include "openai/gpt2.h" // IWYU pragma: keep +#include "google/gemma.h" // IWYU pragma: keep +#include "google/gemma2.h" // IWYU pragma: keep +// OpenAI +#include "openai/gpt2.h" // IWYU pragma: keep // Meta #include "meta/llama.h" // IWYU pragma: keep // Microsoft -// #include "microsoft/phi.h" // IWYU pragma: keep -// // Alibaba -// #include "alibaba/qwen.h" // IWYU pragma: keep -// #include "alibaba/qwen2.h" // IWYU pragma: keep +#include "microsoft/phi.h" // IWYU pragma: keep +// Alibaba +#include "alibaba/qwen.h" // IWYU pragma: keep +#include "alibaba/qwen2.h" // IWYU pragma: keep // Deprecated models // #include "deprecated/aquila.h" From 8a102cd957d5973b5c2d0e0a0ba318df4bbbd61a Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 7 Oct 2025 15:41:26 -0700 Subject: [PATCH 06/13] move linear into linear folder --- src/layers/CMakeLists.txt | 29 +------------ src/layers/linear/CMakeLists.txt | 41 +++++++++++++++++++ src/layers/{ => linear}/linear.cpp | 0 src/layers/{ => linear}/linear.h | 0 src/layers/{ => linear}/linear_test.cpp | 0 .../{ => linear}/multi_parallel_linear.cpp | 0 .../{ => linear}/multi_parallel_linear.h | 0 src/layers/{ => linear}/parallel_linear.cpp | 0 src/layers/{ => linear}/parallel_linear.h | 0 .../{ => linear}/qkv_parallel_linear.cpp | 0 src/layers/{ => linear}/qkv_parallel_linear.h | 0 .../{ => linear}/qkv_parallel_linear_test.cpp | 0 src/layers/{ => linear}/weight_utils.cpp | 2 +- src/layers/{ => linear}/weight_utils.h | 2 +- src/models/alibaba/qwen.h | 4 +- src/models/alibaba/qwen2.h | 4 +- src/models/google/gemma.h | 6 +-- src/models/google/gemma2.h | 4 +- src/models/meta/llama.h | 6 +-- src/models/microsoft/phi.h | 2 +- src/models/openai/gpt2.h | 4 +- src/quantization/qlinear_awq_marlin_impl.cpp | 2 +- src/quantization/qlinear_awq_marlin_impl.h | 4 +- src/quantization/qlinear_gptq_marlin_impl.cpp | 2 +- src/quantization/qlinear_gptq_marlin_impl.h | 4 +- src/quantization/qlinear_impl.cpp | 2 +- src/quantization/qlinear_impl.h | 4 +- 27 files changed, 68 insertions(+), 54 deletions(-) create mode 100644 src/layers/linear/CMakeLists.txt rename src/layers/{ => linear}/linear.cpp (100%) rename src/layers/{ => linear}/linear.h (100%) rename src/layers/{ => linear}/linear_test.cpp (100%) rename src/layers/{ => linear}/multi_parallel_linear.cpp (100%) rename src/layers/{ => linear}/multi_parallel_linear.h (100%) rename src/layers/{ => linear}/parallel_linear.cpp (100%) rename src/layers/{ => linear}/parallel_linear.h (100%) rename src/layers/{ => linear}/qkv_parallel_linear.cpp (100%) rename src/layers/{ => linear}/qkv_parallel_linear.h (100%) rename src/layers/{ => linear}/qkv_parallel_linear_test.cpp (100%) rename src/layers/{ => linear}/weight_utils.cpp (99%) rename src/layers/{ => linear}/weight_utils.h (99%) diff --git a/src/layers/CMakeLists.txt b/src/layers/CMakeLists.txt index a96b85cb..17426c74 100644 --- a/src/layers/CMakeLists.txt +++ b/src/layers/CMakeLists.txt @@ -1,32 +1,6 @@ include(cc_library) include(cc_test) -cc_library( - NAME - linear - HDRS - linear.h - qkv_parallel_linear.h - parallel_linear.h - multi_parallel_linear.h - weight_utils.h - SRCS - linear.cpp - qkv_parallel_linear.cpp - parallel_linear.cpp - multi_parallel_linear.cpp - weight_utils.cpp - DEPS - :state_dict - :model_parallel - :quantization - :kernels - :module - glog::glog - gflags::gflags - torch -) - cc_library( NAME pos_embedding @@ -73,8 +47,6 @@ cc_test( activation_test.cpp pos_embedding_test.cpp normalization_test.cpp - linear_test.cpp - qkv_parallel_linear_test.cpp DEPS :layers :state_dict @@ -82,5 +54,6 @@ cc_test( :gtest_main ) +add_subdirectory(linear) add_subdirectory(attention) add_subdirectory(moe) diff --git a/src/layers/linear/CMakeLists.txt b/src/layers/linear/CMakeLists.txt new file mode 100644 index 00000000..b5cd255f --- /dev/null +++ b/src/layers/linear/CMakeLists.txt @@ -0,0 +1,41 @@ +include(cc_library) +include(cc_test) + +cc_library( + NAME + linear + HDRS + linear.h + parallel_linear.h + multi_parallel_linear.h + qkv_parallel_linear.h + weight_utils.h + SRCS + linear.cpp + parallel_linear.cpp + qkv_parallel_linear.cpp + multi_parallel_linear.cpp + weight_utils.cpp + DEPS + :state_dict + :model_parallel + :quantization + :kernels + :module + glog::glog + gflags::gflags + torch +) + +cc_test( + NAME + linear_test + SRCS + linear_test.cpp + qkv_parallel_linear_test.cpp + DEPS + :linear + :state_dict + absl::random_random + :gtest_main +) diff --git a/src/layers/linear.cpp b/src/layers/linear/linear.cpp similarity index 100% rename from src/layers/linear.cpp rename to src/layers/linear/linear.cpp diff --git a/src/layers/linear.h b/src/layers/linear/linear.h similarity index 100% rename from src/layers/linear.h rename to src/layers/linear/linear.h diff --git a/src/layers/linear_test.cpp b/src/layers/linear/linear_test.cpp similarity index 100% rename from src/layers/linear_test.cpp rename to src/layers/linear/linear_test.cpp diff --git a/src/layers/multi_parallel_linear.cpp b/src/layers/linear/multi_parallel_linear.cpp similarity index 100% rename from src/layers/multi_parallel_linear.cpp rename to src/layers/linear/multi_parallel_linear.cpp diff --git a/src/layers/multi_parallel_linear.h b/src/layers/linear/multi_parallel_linear.h similarity index 100% rename from src/layers/multi_parallel_linear.h rename to src/layers/linear/multi_parallel_linear.h diff --git a/src/layers/parallel_linear.cpp b/src/layers/linear/parallel_linear.cpp similarity index 100% rename from src/layers/parallel_linear.cpp rename to src/layers/linear/parallel_linear.cpp diff --git a/src/layers/parallel_linear.h b/src/layers/linear/parallel_linear.h similarity index 100% rename from src/layers/parallel_linear.h rename to src/layers/linear/parallel_linear.h diff --git a/src/layers/qkv_parallel_linear.cpp b/src/layers/linear/qkv_parallel_linear.cpp similarity index 100% rename from src/layers/qkv_parallel_linear.cpp rename to src/layers/linear/qkv_parallel_linear.cpp diff --git a/src/layers/qkv_parallel_linear.h b/src/layers/linear/qkv_parallel_linear.h similarity index 100% rename from src/layers/qkv_parallel_linear.h rename to src/layers/linear/qkv_parallel_linear.h diff --git a/src/layers/qkv_parallel_linear_test.cpp b/src/layers/linear/qkv_parallel_linear_test.cpp similarity index 100% rename from src/layers/qkv_parallel_linear_test.cpp rename to src/layers/linear/qkv_parallel_linear_test.cpp diff --git a/src/layers/weight_utils.cpp b/src/layers/linear/weight_utils.cpp similarity index 99% rename from src/layers/weight_utils.cpp rename to src/layers/linear/weight_utils.cpp index 174db65a..a037dd4d 100644 --- a/src/layers/weight_utils.cpp +++ b/src/layers/linear/weight_utils.cpp @@ -111,4 +111,4 @@ void WeightUtils::load_fused_weight( } } -} // namespace llm \ No newline at end of file +} // namespace llm diff --git a/src/layers/weight_utils.h b/src/layers/linear/weight_utils.h similarity index 99% rename from src/layers/weight_utils.h rename to src/layers/linear/weight_utils.h index 147506ad..dbe71a8e 100644 --- a/src/layers/weight_utils.h +++ b/src/layers/linear/weight_utils.h @@ -82,4 +82,4 @@ class WeightUtils { #define LOAD_WEIGHT(name) \ WeightUtils::load_weight(state_dict, #name, name##_, name##_is_loaded_); -} // namespace llm \ No newline at end of file +} // namespace llm diff --git a/src/models/alibaba/qwen.h b/src/models/alibaba/qwen.h index 5861278a..fce16e5f 100644 --- a/src/models/alibaba/qwen.h +++ b/src/models/alibaba/qwen.h @@ -10,8 +10,8 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/linear.h" -#include "layers/multi_parallel_linear.h" +#include "layers/linear/linear.h" +#include "layers/linear/multi_parallel_linear.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" diff --git a/src/models/alibaba/qwen2.h b/src/models/alibaba/qwen2.h index 272eec70..093f43a9 100644 --- a/src/models/alibaba/qwen2.h +++ b/src/models/alibaba/qwen2.h @@ -10,9 +10,9 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/linear.h" +#include "layers/linear/linear.h" +#include "layers/linear/qkv_parallel_linear.h" #include "layers/normalization.h" -#include "layers/qkv_parallel_linear.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" diff --git a/src/models/google/gemma.h b/src/models/google/gemma.h index 3f9b6180..cd3f40d2 100644 --- a/src/models/google/gemma.h +++ b/src/models/google/gemma.h @@ -10,10 +10,10 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/linear.h" +#include "layers/linear/linear.h" +#include "layers/linear/parallel_linear.h" +#include "layers/linear/qkv_parallel_linear.h" #include "layers/normalization.h" -#include "layers/parallel_linear.h" -#include "layers/qkv_parallel_linear.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" diff --git a/src/models/google/gemma2.h b/src/models/google/gemma2.h index f0f77253..acfe98f6 100644 --- a/src/models/google/gemma2.h +++ b/src/models/google/gemma2.h @@ -10,9 +10,9 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/linear.h" +#include "layers/linear/linear.h" +#include "layers/linear/qkv_parallel_linear.h" #include "layers/normalization.h" -#include "layers/qkv_parallel_linear.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" diff --git a/src/models/meta/llama.h b/src/models/meta/llama.h index 75fac872..1f184240 100644 --- a/src/models/meta/llama.h +++ b/src/models/meta/llama.h @@ -8,10 +8,10 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/linear.h" -#include "layers/multi_parallel_linear.h" +#include "layers/linear/linear.h" +#include "layers/linear/multi_parallel_linear.h" +#include "layers/linear/qkv_parallel_linear.h" #include "layers/normalization.h" -#include "layers/qkv_parallel_linear.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" diff --git a/src/models/microsoft/phi.h b/src/models/microsoft/phi.h index f81ed123..3f361048 100644 --- a/src/models/microsoft/phi.h +++ b/src/models/microsoft/phi.h @@ -6,7 +6,7 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/linear.h" +#include "layers/linear/linear.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" diff --git a/src/models/openai/gpt2.h b/src/models/openai/gpt2.h index ef5e5e37..88e03e75 100644 --- a/src/models/openai/gpt2.h +++ b/src/models/openai/gpt2.h @@ -7,9 +7,9 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/linear.h" +#include "layers/linear/linear.h" +#include "layers/linear/parallel_linear.h" #include "layers/normalization.h" -#include "layers/parallel_linear.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" diff --git a/src/quantization/qlinear_awq_marlin_impl.cpp b/src/quantization/qlinear_awq_marlin_impl.cpp index 5cf8c139..a3c9922f 100644 --- a/src/quantization/qlinear_awq_marlin_impl.cpp +++ b/src/quantization/qlinear_awq_marlin_impl.cpp @@ -7,7 +7,7 @@ #include #include "kernels/quantization/marlin.h" -#include "layers/weight_utils.h" +#include "layers/linear/weight_utils.h" #include "model_loader/state_dict.h" #include "model_parallel/model_parallel.h" #include "pack_utils.h" diff --git a/src/quantization/qlinear_awq_marlin_impl.h b/src/quantization/qlinear_awq_marlin_impl.h index 9d2d7fc3..879793c5 100644 --- a/src/quantization/qlinear_awq_marlin_impl.h +++ b/src/quantization/qlinear_awq_marlin_impl.h @@ -3,8 +3,8 @@ #include #include -#include "layers/linear.h" -#include "layers/weight_utils.h" +#include "layers/linear/linear.h" +#include "layers/linear/weight_utils.h" #include "model_loader/state_dict.h" #include "model_parallel/parallel_args.h" diff --git a/src/quantization/qlinear_gptq_marlin_impl.cpp b/src/quantization/qlinear_gptq_marlin_impl.cpp index a19b7302..1fb11a72 100644 --- a/src/quantization/qlinear_gptq_marlin_impl.cpp +++ b/src/quantization/qlinear_gptq_marlin_impl.cpp @@ -5,7 +5,7 @@ #include #include "kernels/quantization/marlin.h" -#include "layers/weight_utils.h" +#include "layers/linear/weight_utils.h" #include "model_loader/state_dict.h" #include "model_parallel/model_parallel.h" diff --git a/src/quantization/qlinear_gptq_marlin_impl.h b/src/quantization/qlinear_gptq_marlin_impl.h index e996b479..62a5fad7 100644 --- a/src/quantization/qlinear_gptq_marlin_impl.h +++ b/src/quantization/qlinear_gptq_marlin_impl.h @@ -3,8 +3,8 @@ #include #include -#include "layers/linear.h" -#include "layers/weight_utils.h" +#include "layers/linear/linear.h" +#include "layers/linear/weight_utils.h" #include "model_loader/state_dict.h" #include "model_parallel/parallel_args.h" diff --git a/src/quantization/qlinear_impl.cpp b/src/quantization/qlinear_impl.cpp index c16b5cf3..2fe0517b 100644 --- a/src/quantization/qlinear_impl.cpp +++ b/src/quantization/qlinear_impl.cpp @@ -4,7 +4,7 @@ #include #include -#include "layers/linear.h" +#include "layers/linear/linear.h" #include "model_loader/state_dict.h" namespace llm { diff --git a/src/quantization/qlinear_impl.h b/src/quantization/qlinear_impl.h index 129330cf..eee26afc 100644 --- a/src/quantization/qlinear_impl.h +++ b/src/quantization/qlinear_impl.h @@ -3,8 +3,8 @@ #include #include -#include "layers/linear.h" -#include "layers/weight_utils.h" +#include "layers/linear/linear.h" +#include "layers/linear/weight_utils.h" #include "model_loader/state_dict.h" #include "model_parallel/model_parallel.h" #include "models/model_args.h" From 0ca92b5e1c9133d373e04fb64a4c49c36269cd1f Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 7 Oct 2025 15:52:16 -0700 Subject: [PATCH 07/13] refactor --- src/layers/linear/multi_parallel_linear.cpp | 48 ++++++++++----------- src/layers/linear/multi_parallel_linear.h | 15 +------ 2 files changed, 25 insertions(+), 38 deletions(-) diff --git a/src/layers/linear/multi_parallel_linear.cpp b/src/layers/linear/multi_parallel_linear.cpp index 413aa48e..1f2a2819 100644 --- a/src/layers/linear/multi_parallel_linear.cpp +++ b/src/layers/linear/multi_parallel_linear.cpp @@ -3,6 +3,7 @@ #include #include +#include "layers/linear/linear.h" #include "model_parallel/parallel_args.h" #include "parallel_linear.h" #include "quantization/quant_args.h" @@ -19,38 +20,35 @@ MultiColumnParallelLinearImpl::MultiColumnParallelLinearImpl( const ParallelArgs& parallel_args, const torch::TensorOptions& options) { // check if the linear layers can be fused - fused_ = quant_args.can_be_fused(); - if (fused_) { + std::shared_ptr linear; + if (quant_args.can_be_fused()) { // fused linear layer - fused_linear_ = register_module("fused_linear", - FusedColumnParallelLinear(in_features, - out_features_vec, - prefixes, - bias, - gather_output, - parallel_args, - options), - /*selector=*/nullptr); + linear = register_module("fused_linear", + FusedColumnParallelLinear(in_features, + out_features_vec, + prefixes, + bias, + gather_output, + parallel_args, + options), + /*selector=*/nullptr); } else { // non-fused linear layers - grouped_linear_ = - register_module("grouped_linear", - GroupedColumnParallelLinear(in_features, - out_features_vec, - prefixes, - bias, - gather_output, - parallel_args, - options), - /*selector=*/nullptr); + linear = register_module("grouped_linear", + GroupedColumnParallelLinear(in_features, + out_features_vec, + prefixes, + bias, + gather_output, + parallel_args, + options), + /*selector=*/nullptr); } + linear_ = linear; } std::vector MultiColumnParallelLinearImpl::forward( torch::Tensor input) { - if (fused_) { - return fused_linear_(input); - } - return grouped_linear_(input); + return linear_(input); } } // namespace llm diff --git a/src/layers/linear/multi_parallel_linear.h b/src/layers/linear/multi_parallel_linear.h index 6a78308b..827420f0 100644 --- a/src/layers/linear/multi_parallel_linear.h +++ b/src/layers/linear/multi_parallel_linear.h @@ -3,11 +3,10 @@ #include #include -// #include "linear.h" +#include "linear.h" #include "model_parallel/parallel_args.h" #include "module/module.h" #include "module/module_holder.h" -#include "parallel_linear.h" #include "quantization/quant_args.h" namespace llm { @@ -25,18 +24,8 @@ class MultiColumnParallelLinearImpl : public Module { std::vector forward(torch::Tensor input); - // whether the linear layer is fused - bool fused() const { return fused_; } - private: - // non-fused linear layers - GroupedColumnParallelLinear grouped_linear_{nullptr}; - - // fused linear layer - FusedColumnParallelLinear fused_linear_{nullptr}; - - // whether the linear layer is fused - bool fused_ = false; + MultiParallelLinear linear_{nullptr}; }; LLM_MODULE(MultiColumnParallelLinear); From 8e4bacb12c97cf80f4c2dd30ebc154fed882d02a Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 7 Oct 2025 16:11:35 -0700 Subject: [PATCH 08/13] remove linear.h/cpp --- src/layers/linear/CMakeLists.txt | 2 - src/layers/linear/linear.cpp | 299 --------------- src/layers/linear/linear.h | 89 ----- src/layers/linear/linear_test.cpp | 1 + src/layers/linear/multi_parallel_linear.cpp | 106 +++++- src/layers/linear/multi_parallel_linear.h | 62 ++- src/layers/linear/parallel_linear.cpp | 393 ++++++++++++++------ src/layers/linear/parallel_linear.h | 125 ++++--- src/models/alibaba/qwen.h | 1 - src/models/alibaba/qwen2.h | 1 - src/models/google/gemma.h | 1 - src/models/google/gemma2.h | 1 - src/models/meta/llama.h | 1 - src/models/microsoft/phi.h | 1 - src/models/openai/gpt2.h | 1 - src/quantization/qlinear_awq_marlin_impl.h | 2 +- src/quantization/qlinear_gptq_marlin_impl.h | 2 +- src/quantization/qlinear_impl.cpp | 1 - src/quantization/qlinear_impl.h | 3 +- 19 files changed, 530 insertions(+), 562 deletions(-) delete mode 100644 src/layers/linear/linear.cpp delete mode 100644 src/layers/linear/linear.h diff --git a/src/layers/linear/CMakeLists.txt b/src/layers/linear/CMakeLists.txt index b5cd255f..6a17000c 100644 --- a/src/layers/linear/CMakeLists.txt +++ b/src/layers/linear/CMakeLists.txt @@ -5,13 +5,11 @@ cc_library( NAME linear HDRS - linear.h parallel_linear.h multi_parallel_linear.h qkv_parallel_linear.h weight_utils.h SRCS - linear.cpp parallel_linear.cpp qkv_parallel_linear.cpp multi_parallel_linear.cpp diff --git a/src/layers/linear/linear.cpp b/src/layers/linear/linear.cpp deleted file mode 100644 index 8c1526c8..00000000 --- a/src/layers/linear/linear.cpp +++ /dev/null @@ -1,299 +0,0 @@ -#include "linear.h" - -#include -#include - -#include -#include - -#include "parallel_linear.h" -#include "quantization/qlinear_awq_impl.h" -#include "quantization/qlinear_awq_marlin_impl.h" -#include "quantization/qlinear_exllamav2_impl.h" -#include "quantization/qlinear_gptq_impl.h" -#include "quantization/qlinear_gptq_marlin_impl.h" - -DEFINE_string( - qlinear_gptq_impl, - "auto", - "type of qlinear gptq impl: slow, cuda, exllamav2, marlin or auto"); - -namespace llm { -namespace { -#define MAKE_ROW_PARALLEL_QLINEAR(QLinearlImplClass) \ - std::make_shared(in_features, \ - out_features, \ - bias, \ - quant_args, \ - input_is_parallelized, \ - parallel_args, \ - options); - -#define MAKE_COLUMN_PARALLEL_QLINEAR(QLinearlImplClass) \ - std::make_shared(in_features, \ - out_features, \ - bias, \ - quant_args, \ - gather_output, \ - parallel_args, \ - options); - -std::shared_ptr create_column_parallel_qlinear_by_impl( - int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) { - if (boost::iequals(FLAGS_qlinear_gptq_impl, "slow")) { - return std::make_shared(in_features, - out_features, - bias, - quant_args, - /*qweight_pack_dim=*/0, - gather_output, - parallel_args, - options); - } - if (boost::iequals(FLAGS_qlinear_gptq_impl, "cuda")) { - return MAKE_COLUMN_PARALLEL_QLINEAR(ColumnParallelQLinearGPTQImpl); - } - if (boost::iequals(FLAGS_qlinear_gptq_impl, "exllamav2")) { - return MAKE_COLUMN_PARALLEL_QLINEAR(ColumnParallelQLinearExllamav2Impl); - } - if (boost::iequals(FLAGS_qlinear_gptq_impl, "marlin")) { - return MAKE_COLUMN_PARALLEL_QLINEAR(ColumnParallelQLinearGPTQMarlinImpl); - } - return nullptr; -} - -std::shared_ptr create_row_parallel_qlinear_by_impl( - int64_t in_features, - int64_t out_features, - bool bias, - bool input_is_parallelized, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) { - if (boost::iequals(FLAGS_qlinear_gptq_impl, "slow")) { - return std::make_shared(in_features, - out_features, - bias, - quant_args, - /*qweight_pack_dim=*/0, - input_is_parallelized, - parallel_args, - options); - } - if (boost::iequals(FLAGS_qlinear_gptq_impl, "cuda")) { - return MAKE_ROW_PARALLEL_QLINEAR(RowParallelQLinearGPTQImpl); - } - if (boost::iequals(FLAGS_qlinear_gptq_impl, "exllamav2")) { - return MAKE_ROW_PARALLEL_QLINEAR(RowParallelQLinearExllamav2Impl); - } - if (boost::iequals(FLAGS_qlinear_gptq_impl, "marlin")) { - return MAKE_ROW_PARALLEL_QLINEAR(RowParallelQLinearGPTQMarlinImpl); - } - return nullptr; -} - -std::shared_ptr create_column_parallel_qlinear( - int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) { - if (auto qlinear = create_column_parallel_qlinear_by_impl(in_features, - out_features, - bias, - gather_output, - quant_args, - parallel_args, - options)) { - return qlinear; - } - if (boost::iequals(quant_args.quant_method(), "gptq")) { - // default to use marlin implementation for gptq - return MAKE_COLUMN_PARALLEL_QLINEAR(ColumnParallelQLinearGPTQMarlinImpl); - } - if (boost::iequals(quant_args.quant_method(), "awq") || - boost::iequals(quant_args.quant_method(), "GEMM")) { - // default to use awq implementation for gemm - // return MAKE_COLUMN_PARALLEL_QLINEAR(ColumnParallelQLinearAWQImpl); - return MAKE_COLUMN_PARALLEL_QLINEAR(ColumnParallelQLinearAWQMarlinImpl); - } - // not supported quant method - LOG(FATAL) << "Unsupported quant method: " << quant_args.quant_method(); - return nullptr; -} - -std::shared_ptr create_row_parallel_qlinear( - int64_t in_features, - int64_t out_features, - bool bias, - bool input_is_parallelized, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) { - if (auto qlinear = create_row_parallel_qlinear_by_impl(in_features, - out_features, - bias, - input_is_parallelized, - quant_args, - parallel_args, - options)) { - return qlinear; - } - if (boost::iequals(quant_args.quant_method(), "gptq")) { - // default to use marlin implementation for gptq - return MAKE_ROW_PARALLEL_QLINEAR(RowParallelQLinearGPTQMarlinImpl); - } - if (boost::iequals(quant_args.quant_method(), "awq") || - boost::iequals(quant_args.quant_method(), "GEMM")) { - // default to use awq implementation for gemm - // return MAKE_ROW_PARALLEL_QLINEAR(RowParallelQLinearAWQImpl); - return MAKE_ROW_PARALLEL_QLINEAR(RowParallelQLinearAWQMarlinImpl); - } - // not supported quant method - LOG(FATAL) << "Unsupported quant method: " << quant_args.quant_method(); - return nullptr; -} - -std::shared_ptr create_column_parallel_linear( - int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options, - const std::string& prefix) { - if (!quant_args.quant_method().empty()) { - return create_column_parallel_qlinear(in_features, - out_features, - bias, - gather_output, - quant_args, - parallel_args, - options); - } - return std ::make_shared(in_features, - out_features, - bias, - gather_output, - parallel_args, - options, - prefix); -} - -std::shared_ptr create_row_parallel_linear( - int64_t in_features, - int64_t out_features, - bool bias, - bool input_is_parallelized, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) { - if (!quant_args.quant_method().empty()) { - return create_row_parallel_qlinear(in_features, - out_features, - bias, - input_is_parallelized, - quant_args, - parallel_args, - options); - } - return std ::make_shared(in_features, - out_features, - bias, - input_is_parallelized, - parallel_args, - options); -} - -// std::shared_ptr create_multi_column_parallel_linear( -// int64_t in_features, -// const std::vector& out_features, -// const std::vector& prefixes, -// bool bias, -// bool gather_output, -// const QuantArgs& quant_args, -// const ParallelArgs& parallel_args, -// const torch::TensorOptions& options) { -// // check if the linear layers can be fused -// const bool fused = quant_args.can_be_fused(); -// std::shared_ptr impl; -// if (fused) { -// return std::make_shared(in_features, -// out_features, -// prefixes, -// bias, -// gather_output, -// parallel_args, -// options); -// } - -// return std::make_shared(in_features, -// out_features, -// prefixes, -// bias, -// gather_output, -// parallel_args, -// options); -// } -} // namespace - -// construct a ColumnParallelLinear. -// chose right implementation based on the args. -ColumnParallelLinear::ColumnParallelLinear(int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options, - const std::string& prefix) - : ModuleHolder(create_column_parallel_linear(in_features, - out_features, - bias, - gather_output, - quant_args, - parallel_args, - options, - prefix)) {} - -ColumnParallelLinear::ColumnParallelLinear(int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) - : ModuleHolder(create_column_parallel_linear(in_features, - out_features, - bias, - gather_output, - {}, /*quant_args*/ - parallel_args, - options, - "")) {} - -// construct a rotary positional embedding. -// chose right implementation based on the args. -RowParallelLinear::RowParallelLinear(int64_t in_features, - int64_t out_features, - bool bias, - bool input_is_parallelized, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) - : ModuleHolder(create_row_parallel_linear(in_features, - out_features, - bias, - input_is_parallelized, - quant_args, - parallel_args, - options)) {} -} // namespace llm diff --git a/src/layers/linear/linear.h b/src/layers/linear/linear.h deleted file mode 100644 index 606cdfdd..00000000 --- a/src/layers/linear/linear.h +++ /dev/null @@ -1,89 +0,0 @@ -#pragma once - -#include -#include - -#include "model_loader/state_dict.h" -#include "model_parallel/parallel_args.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "quantization/quant_args.h" - -namespace llm { - -using TensorTransform = std::function; - -// an interface for parallel linear layer. -// all linear classes should inherit from this class and implement the forward -// function. -class ParallelLinearImpl : public Module { - public: - ~ParallelLinearImpl() override = default; - - virtual torch::Tensor forward(torch::Tensor input) = 0; - - // TODO: clean up the interface of load_state_dict - virtual void load_state_dict(const StateDict& state_dict) { - LOG(FATAL) << "not implemented"; - } - - virtual void verify_loaded_weights(const std::string& prefix = "") const { - LOG(FATAL) << "not implemented"; - } - - // special load_state_dict for fused cases - virtual void load_state_dict(const StateDict& /*state_dict*/, - const std::vector& /*prefixes*/) { - LOG(FATAL) << "not implemented"; - } -}; -LLM_MODULE(ParallelLinear); - -class MultiParallelLinearImpl : public Module { - public: - ~MultiParallelLinearImpl() override = default; - - virtual std::vector forward(torch::Tensor input) = 0; -}; -LLM_MODULE(MultiParallelLinear); - -class ColumnParallelLinear : public ModuleHolder { - public: - using ModuleHolder::ModuleHolder; - using Impl [[maybe_unused]] = ParallelLinearImpl; - - // construct a rotary positional embedding. - // chose right implementation based on the args. - ColumnParallelLinear(int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options, - const std::string& prefix = ""); - - ColumnParallelLinear(int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); -}; - -class RowParallelLinear : public ModuleHolder { - public: - using ModuleHolder::ModuleHolder; - using Impl [[maybe_unused]] = ParallelLinearImpl; - - // construct a rotary positional embedding. - // chose right implementation based on the args. - RowParallelLinear(int64_t in_features, - int64_t out_features, - bool bias, - bool input_is_parallelized, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); -}; -} // namespace llm diff --git a/src/layers/linear/linear_test.cpp b/src/layers/linear/linear_test.cpp index 442d0bd1..6b72063f 100644 --- a/src/layers/linear/linear_test.cpp +++ b/src/layers/linear/linear_test.cpp @@ -9,6 +9,7 @@ #include #include "model_loader/state_dict.h" +#include "multi_parallel_linear.h" #include "parallel_linear.h" namespace llm { diff --git a/src/layers/linear/multi_parallel_linear.cpp b/src/layers/linear/multi_parallel_linear.cpp index 1f2a2819..146ec5d0 100644 --- a/src/layers/linear/multi_parallel_linear.cpp +++ b/src/layers/linear/multi_parallel_linear.cpp @@ -3,13 +3,117 @@ #include #include -#include "layers/linear/linear.h" +#include "model_parallel/model_parallel.h" #include "model_parallel/parallel_args.h" #include "parallel_linear.h" #include "quantization/quant_args.h" namespace llm { +FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl( + int64_t in_features, + const std::vector& out_features_vec, + const std::vector& prefixes, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) + : gather_output_(gather_output), parallel_args_(parallel_args) { + const auto rank = parallel_args_.rank(); + const auto world_size = parallel_args_.world_size(); + + // calculate split size for each prefix + split_sizes_.reserve(out_features_vec.size()); + for (const auto& out_features : out_features_vec) { + CHECK(out_features % world_size == 0) + << "out_features " << out_features << " not divisible by world_size " + << world_size; + split_sizes_.push_back(out_features / world_size); + } + + const int64_t fused_out_features = + std::accumulate(split_sizes_.begin(), split_sizes_.end(), int64_t(0)); + + // allocate fused weight + weight_ = torch::empty({fused_out_features, in_features}, options); + const auto weights = weight_.split(split_sizes_, /*dim=*/0); + // register sharded weights for each prefix + for (size_t i = 0; i < prefixes.size(); ++i) { + const auto& prefix = prefixes[i]; + const auto& weight = weights[i]; + // register the weight as a parameter to make sure it is moved to the + register_sharded_parameter(detail::join_name(prefix, "weight"), + /*dim=*/0, + rank, + world_size, + weight); + } + + if (bias) { + bias_ = torch::empty({fused_out_features}, options); + const auto biases = bias_.split(split_sizes_, /*dim=*/0); + + // register sharded weights for each prefix + for (size_t i = 0; i < prefixes.size(); ++i) { + const auto& prefix = prefixes[i]; + const auto& bias = biases[i]; + register_sharded_parameter(detail::join_name(prefix, "bias"), + /*dim=*/0, + rank, + world_size, + bias); + } + } +} + +std::vector FusedColumnParallelLinearImpl::forward( + torch::Tensor input) { + namespace F = torch::nn::functional; + auto output = F::linear(input, weight_, bias_); + if (parallel_args_.world_size() > 1 && gather_output_) { + output = gather_from_model_parallel_region(output, parallel_args_); + } + return output.split(split_sizes_, /*dim=*/1); +} + +GroupedColumnParallelLinearImpl::GroupedColumnParallelLinearImpl( + int64_t in_features, + const std::vector& out_features_vec, + const std::vector& prefixes, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + // register linear layers one by one + parallel_linears_.reserve(out_features_vec.size()); + for (size_t i = 0; i < out_features_vec.size(); ++i) { + const auto& prefix = prefixes[i]; + const auto out_features = out_features_vec[i]; + const auto linear = register_module( + "linear_" + std::to_string(i), + std::make_shared(in_features, + out_features, + bias, + gather_output, + parallel_args, + options, + prefix), + /*selector=*/nullptr); + + parallel_linears_.emplace_back(linear); + } +} + +std::vector GroupedColumnParallelLinearImpl::forward( + torch::Tensor input) { + std::vector outputs; + outputs.reserve(parallel_linears_.size()); + for (auto& parallel_linear : parallel_linears_) { + outputs.push_back(parallel_linear->forward(input)); + } + return outputs; +} + MultiColumnParallelLinearImpl::MultiColumnParallelLinearImpl( int64_t in_features, const std::vector& out_features_vec, diff --git a/src/layers/linear/multi_parallel_linear.h b/src/layers/linear/multi_parallel_linear.h index 827420f0..b36c5181 100644 --- a/src/layers/linear/multi_parallel_linear.h +++ b/src/layers/linear/multi_parallel_linear.h @@ -3,14 +3,74 @@ #include #include -#include "linear.h" +// #include "linear.h" #include "model_parallel/parallel_args.h" #include "module/module.h" #include "module/module_holder.h" +#include "parallel_linear.h" #include "quantization/quant_args.h" namespace llm { +class MultiParallelLinearImpl : public Module { + public: + ~MultiParallelLinearImpl() override = default; + + virtual std::vector forward(torch::Tensor input) = 0; +}; +LLM_MODULE(MultiParallelLinear); + +// Fused linear layer with column parallelism. +class FusedColumnParallelLinearImpl : public MultiParallelLinearImpl { + public: + FusedColumnParallelLinearImpl(int64_t in_features, + const std::vector& out_features, + const std::vector& prefixes, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); + + std::vector forward(torch::Tensor input) override; + + // return the weight (for testing) + torch::Tensor weight() const { return weight_; } + + private: + // parameter members, must be registered + // we allocate the transpose since linear performs XA^T. + // A^T: [out_features_per_partition, in_features] + torch::Tensor weight_; + torch::Tensor bias_; + + std::vector split_sizes_; + + // whether to gather the output + bool gather_output_; + + // parallel args + ParallelArgs parallel_args_; +}; +LLM_MODULE(FusedColumnParallelLinear); + +class GroupedColumnParallelLinearImpl : public MultiParallelLinearImpl { + public: + GroupedColumnParallelLinearImpl(int64_t in_features, + const std::vector& out_features, + const std::vector& prefixes, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); + + std::vector forward(torch::Tensor input) override; + + private: + // parameter members, must be registered + std::vector> parallel_linears_; +}; +LLM_MODULE(GroupedColumnParallelLinear); + class MultiColumnParallelLinearImpl : public Module { public: MultiColumnParallelLinearImpl(int64_t in_features, diff --git a/src/layers/linear/parallel_linear.cpp b/src/layers/linear/parallel_linear.cpp index f0117c0f..40219588 100644 --- a/src/layers/linear/parallel_linear.cpp +++ b/src/layers/linear/parallel_linear.cpp @@ -4,10 +4,249 @@ #include #include +#include +#include + #include "model_parallel/model_parallel.h" #include "module/module.h" +#include "quantization/qlinear_awq_impl.h" +#include "quantization/qlinear_awq_marlin_impl.h" +#include "quantization/qlinear_exllamav2_impl.h" +#include "quantization/qlinear_gptq_impl.h" +#include "quantization/qlinear_gptq_marlin_impl.h" + +DEFINE_string( + qlinear_gptq_impl, + "auto", + "type of qlinear gptq impl: slow, cuda, exllamav2, marlin or auto"); namespace llm { +namespace { +#define MAKE_ROW_PARALLEL_QLINEAR(QLinearlImplClass) \ + std::make_shared(in_features, \ + out_features, \ + bias, \ + quant_args, \ + input_is_parallelized, \ + parallel_args, \ + options); + +#define MAKE_COLUMN_PARALLEL_QLINEAR(QLinearlImplClass) \ + std::make_shared(in_features, \ + out_features, \ + bias, \ + quant_args, \ + gather_output, \ + parallel_args, \ + options); + +std::shared_ptr create_column_parallel_qlinear_by_impl( + int64_t in_features, + int64_t out_features, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + if (boost::iequals(FLAGS_qlinear_gptq_impl, "slow")) { + return std::make_shared(in_features, + out_features, + bias, + quant_args, + /*qweight_pack_dim=*/0, + gather_output, + parallel_args, + options); + } + if (boost::iequals(FLAGS_qlinear_gptq_impl, "cuda")) { + return MAKE_COLUMN_PARALLEL_QLINEAR(ColumnParallelQLinearGPTQImpl); + } + if (boost::iequals(FLAGS_qlinear_gptq_impl, "exllamav2")) { + return MAKE_COLUMN_PARALLEL_QLINEAR(ColumnParallelQLinearExllamav2Impl); + } + if (boost::iequals(FLAGS_qlinear_gptq_impl, "marlin")) { + return MAKE_COLUMN_PARALLEL_QLINEAR(ColumnParallelQLinearGPTQMarlinImpl); + } + return nullptr; +} + +std::shared_ptr create_row_parallel_qlinear_by_impl( + int64_t in_features, + int64_t out_features, + bool bias, + bool input_is_parallelized, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + if (boost::iequals(FLAGS_qlinear_gptq_impl, "slow")) { + return std::make_shared(in_features, + out_features, + bias, + quant_args, + /*qweight_pack_dim=*/0, + input_is_parallelized, + parallel_args, + options); + } + if (boost::iequals(FLAGS_qlinear_gptq_impl, "cuda")) { + return MAKE_ROW_PARALLEL_QLINEAR(RowParallelQLinearGPTQImpl); + } + if (boost::iequals(FLAGS_qlinear_gptq_impl, "exllamav2")) { + return MAKE_ROW_PARALLEL_QLINEAR(RowParallelQLinearExllamav2Impl); + } + if (boost::iequals(FLAGS_qlinear_gptq_impl, "marlin")) { + return MAKE_ROW_PARALLEL_QLINEAR(RowParallelQLinearGPTQMarlinImpl); + } + return nullptr; +} + +std::shared_ptr create_column_parallel_qlinear( + int64_t in_features, + int64_t out_features, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + if (auto qlinear = create_column_parallel_qlinear_by_impl(in_features, + out_features, + bias, + gather_output, + quant_args, + parallel_args, + options)) { + return qlinear; + } + if (boost::iequals(quant_args.quant_method(), "gptq")) { + // default to use marlin implementation for gptq + return MAKE_COLUMN_PARALLEL_QLINEAR(ColumnParallelQLinearGPTQMarlinImpl); + } + if (boost::iequals(quant_args.quant_method(), "awq") || + boost::iequals(quant_args.quant_method(), "GEMM")) { + // default to use awq implementation for gemm + // return MAKE_COLUMN_PARALLEL_QLINEAR(ColumnParallelQLinearAWQImpl); + return MAKE_COLUMN_PARALLEL_QLINEAR(ColumnParallelQLinearAWQMarlinImpl); + } + // not supported quant method + LOG(FATAL) << "Unsupported quant method: " << quant_args.quant_method(); + return nullptr; +} + +std::shared_ptr create_row_parallel_qlinear( + int64_t in_features, + int64_t out_features, + bool bias, + bool input_is_parallelized, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + if (auto qlinear = create_row_parallel_qlinear_by_impl(in_features, + out_features, + bias, + input_is_parallelized, + quant_args, + parallel_args, + options)) { + return qlinear; + } + if (boost::iequals(quant_args.quant_method(), "gptq")) { + // default to use marlin implementation for gptq + return MAKE_ROW_PARALLEL_QLINEAR(RowParallelQLinearGPTQMarlinImpl); + } + if (boost::iequals(quant_args.quant_method(), "awq") || + boost::iequals(quant_args.quant_method(), "GEMM")) { + // default to use awq implementation for gemm + // return MAKE_ROW_PARALLEL_QLINEAR(RowParallelQLinearAWQImpl); + return MAKE_ROW_PARALLEL_QLINEAR(RowParallelQLinearAWQMarlinImpl); + } + // not supported quant method + LOG(FATAL) << "Unsupported quant method: " << quant_args.quant_method(); + return nullptr; +} + +std::shared_ptr create_column_parallel_linear( + int64_t in_features, + int64_t out_features, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options, + const std::string& prefix) { + if (!quant_args.quant_method().empty()) { + return create_column_parallel_qlinear(in_features, + out_features, + bias, + gather_output, + quant_args, + parallel_args, + options); + } + return std ::make_shared(in_features, + out_features, + bias, + gather_output, + parallel_args, + options, + prefix); +} + +std::shared_ptr create_row_parallel_linear( + int64_t in_features, + int64_t out_features, + bool bias, + bool input_is_parallelized, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + if (!quant_args.quant_method().empty()) { + return create_row_parallel_qlinear(in_features, + out_features, + bias, + input_is_parallelized, + quant_args, + parallel_args, + options); + } + return std ::make_shared(in_features, + out_features, + bias, + input_is_parallelized, + parallel_args, + options); +} + +// std::shared_ptr create_multi_column_parallel_linear( +// int64_t in_features, +// const std::vector& out_features, +// const std::vector& prefixes, +// bool bias, +// bool gather_output, +// const QuantArgs& quant_args, +// const ParallelArgs& parallel_args, +// const torch::TensorOptions& options) { +// // check if the linear layers can be fused +// const bool fused = quant_args.can_be_fused(); +// std::shared_ptr impl; +// if (fused) { +// return std::make_shared(in_features, +// out_features, +// prefixes, +// bias, +// gather_output, +// parallel_args, +// options); +// } + +// return std::make_shared(in_features, +// out_features, +// prefixes, +// bias, +// gather_output, +// parallel_args, +// options); +// } +} // namespace // Linear layer with column parallelism. ColumnParallelLinearImpl::ColumnParallelLinearImpl( @@ -54,110 +293,6 @@ torch::Tensor ColumnParallelLinearImpl::forward(torch::Tensor input) { return output; } -FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl( - int64_t in_features, - const std::vector& out_features_vec, - const std::vector& prefixes, - bool bias, - bool gather_output, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) - : gather_output_(gather_output), parallel_args_(parallel_args) { - const auto rank = parallel_args_.rank(); - const auto world_size = parallel_args_.world_size(); - - // calculate split size for each prefix - split_sizes_.reserve(out_features_vec.size()); - for (const auto& out_features : out_features_vec) { - CHECK(out_features % world_size == 0) - << "out_features " << out_features << " not divisible by world_size " - << world_size; - split_sizes_.push_back(out_features / world_size); - } - - const int64_t fused_out_features = - std::accumulate(split_sizes_.begin(), split_sizes_.end(), int64_t(0)); - - // allocate fused weight - weight_ = torch::empty({fused_out_features, in_features}, options); - const auto weights = weight_.split(split_sizes_, /*dim=*/0); - // register sharded weights for each prefix - for (size_t i = 0; i < prefixes.size(); ++i) { - const auto& prefix = prefixes[i]; - const auto& weight = weights[i]; - // register the weight as a parameter to make sure it is moved to the - register_sharded_parameter(detail::join_name(prefix, "weight"), - /*dim=*/0, - rank, - world_size, - weight); - } - - if (bias) { - bias_ = torch::empty({fused_out_features}, options); - const auto biases = bias_.split(split_sizes_, /*dim=*/0); - - // register sharded weights for each prefix - for (size_t i = 0; i < prefixes.size(); ++i) { - const auto& prefix = prefixes[i]; - const auto& bias = biases[i]; - register_sharded_parameter(detail::join_name(prefix, "bias"), - /*dim=*/0, - rank, - world_size, - bias); - } - } -} - -std::vector FusedColumnParallelLinearImpl::forward( - torch::Tensor input) { - namespace F = torch::nn::functional; - auto output = F::linear(input, weight_, bias_); - if (parallel_args_.world_size() > 1 && gather_output_) { - output = gather_from_model_parallel_region(output, parallel_args_); - } - return output.split(split_sizes_, /*dim=*/1); -} - -GroupedColumnParallelLinearImpl::GroupedColumnParallelLinearImpl( - int64_t in_features, - const std::vector& out_features_vec, - const std::vector& prefixes, - bool bias, - bool gather_output, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) { - // register linear layers one by one - parallel_linears_.reserve(out_features_vec.size()); - for (size_t i = 0; i < out_features_vec.size(); ++i) { - const auto& prefix = prefixes[i]; - const auto out_features = out_features_vec[i]; - const auto linear = register_module( - "linear_" + std::to_string(i), - std::make_shared(in_features, - out_features, - bias, - gather_output, - parallel_args, - options, - prefix), - /*selector=*/nullptr); - - parallel_linears_.emplace_back(linear); - } -} - -std::vector GroupedColumnParallelLinearImpl::forward( - torch::Tensor input) { - std::vector outputs; - outputs.reserve(parallel_linears_.size()); - for (auto& parallel_linear : parallel_linears_) { - outputs.push_back(parallel_linear->forward(input)); - } - return outputs; -} - // Linear layer with row parallelism. RowParallelLinearImpl::RowParallelLinearImpl( int64_t in_features, @@ -203,4 +338,54 @@ torch::Tensor RowParallelLinearImpl::forward(torch::Tensor input) { return output; } +// construct a ColumnParallelLinear. +// chose right implementation based on the args. +ColumnParallelLinear::ColumnParallelLinear(int64_t in_features, + int64_t out_features, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options, + const std::string& prefix) + : ModuleHolder(create_column_parallel_linear(in_features, + out_features, + bias, + gather_output, + quant_args, + parallel_args, + options, + prefix)) {} + +ColumnParallelLinear::ColumnParallelLinear(int64_t in_features, + int64_t out_features, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) + : ModuleHolder(create_column_parallel_linear(in_features, + out_features, + bias, + gather_output, + {}, /*quant_args*/ + parallel_args, + options, + "")) {} + +// construct a rotary positional embedding. +// chose right implementation based on the args. +RowParallelLinear::RowParallelLinear(int64_t in_features, + int64_t out_features, + bool bias, + bool input_is_parallelized, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) + : ModuleHolder(create_row_parallel_linear(in_features, + out_features, + bias, + input_is_parallelized, + quant_args, + parallel_args, + options)) {} } // namespace llm diff --git a/src/layers/linear/parallel_linear.h b/src/layers/linear/parallel_linear.h index d868bbb2..e70a4ffb 100644 --- a/src/layers/linear/parallel_linear.h +++ b/src/layers/linear/parallel_linear.h @@ -3,13 +3,40 @@ #include #include -#include - -#include "linear.h" +#include "model_loader/state_dict.h" +#include "model_parallel/parallel_args.h" +#include "module/module.h" #include "module/module_holder.h" +#include "quantization/quant_args.h" namespace llm { +// an interface for parallel linear layer. +// all linear classes should inherit from this class and implement the forward +// function. +class ParallelLinearImpl : public Module { + public: + ~ParallelLinearImpl() override = default; + + virtual torch::Tensor forward(torch::Tensor input) = 0; + + // TODO: clean up the interface of load_state_dict + virtual void load_state_dict(const StateDict& state_dict) { + LOG(FATAL) << "not implemented"; + } + + virtual void verify_loaded_weights(const std::string& prefix = "") const { + LOG(FATAL) << "not implemented"; + } + + // special load_state_dict for fused cases + virtual void load_state_dict(const StateDict& /*state_dict*/, + const std::vector& /*prefixes*/) { + LOG(FATAL) << "not implemented"; + } +}; +LLM_MODULE(ParallelLinear); + // Linear layer with column parallelism. // The linear layer is defined as Y = XA + b. A is parallelized along // its second dimension as A = [A_1, ..., A_p]. @@ -42,57 +69,6 @@ class ColumnParallelLinearImpl : public ParallelLinearImpl { ParallelArgs parallel_args_; }; -// Fused linear layer with column parallelism. -class FusedColumnParallelLinearImpl : public MultiParallelLinearImpl { - public: - FusedColumnParallelLinearImpl(int64_t in_features, - const std::vector& out_features, - const std::vector& prefixes, - bool bias, - bool gather_output, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); - - std::vector forward(torch::Tensor input) override; - - // return the weight (for testing) - torch::Tensor weight() const { return weight_; } - - private: - // parameter members, must be registered - // we allocate the transpose since linear performs XA^T. - // A^T: [out_features_per_partition, in_features] - torch::Tensor weight_; - torch::Tensor bias_; - - std::vector split_sizes_; - - // whether to gather the output - bool gather_output_; - - // parallel args - ParallelArgs parallel_args_; -}; -LLM_MODULE(FusedColumnParallelLinear); - -class GroupedColumnParallelLinearImpl : public MultiParallelLinearImpl { - public: - GroupedColumnParallelLinearImpl(int64_t in_features, - const std::vector& out_features, - const std::vector& prefixes, - bool bias, - bool gather_output, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); - - std::vector forward(torch::Tensor input) override; - - private: - // parameter members, must be registered - std::vector> parallel_linears_; -}; -LLM_MODULE(GroupedColumnParallelLinear); - // Linear layer with row parallelism. // The linear layer is defined as Y = XA + b. A is parallelized along // its first dimension and X along its second dimension as: @@ -130,4 +106,45 @@ class RowParallelLinearImpl : public ParallelLinearImpl { // parallel args ParallelArgs parallel_args_; }; + +class ColumnParallelLinear : public ModuleHolder { + public: + using ModuleHolder::ModuleHolder; + using Impl [[maybe_unused]] = ParallelLinearImpl; + + // construct a rotary positional embedding. + // chose right implementation based on the args. + ColumnParallelLinear(int64_t in_features, + int64_t out_features, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options, + const std::string& prefix = ""); + + ColumnParallelLinear(int64_t in_features, + int64_t out_features, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); +}; + +class RowParallelLinear : public ModuleHolder { + public: + using ModuleHolder::ModuleHolder; + using Impl [[maybe_unused]] = ParallelLinearImpl; + + // construct a rotary positional embedding. + // chose right implementation based on the args. + RowParallelLinear(int64_t in_features, + int64_t out_features, + bool bias, + bool input_is_parallelized, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); +}; + } // namespace llm diff --git a/src/models/alibaba/qwen.h b/src/models/alibaba/qwen.h index fce16e5f..942a8a56 100644 --- a/src/models/alibaba/qwen.h +++ b/src/models/alibaba/qwen.h @@ -10,7 +10,6 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/linear/linear.h" #include "layers/linear/multi_parallel_linear.h" #include "layers/normalization.h" #include "memory/kv_cache.h" diff --git a/src/models/alibaba/qwen2.h b/src/models/alibaba/qwen2.h index 093f43a9..75fef6d5 100644 --- a/src/models/alibaba/qwen2.h +++ b/src/models/alibaba/qwen2.h @@ -10,7 +10,6 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/linear/linear.h" #include "layers/linear/qkv_parallel_linear.h" #include "layers/normalization.h" #include "memory/kv_cache.h" diff --git a/src/models/google/gemma.h b/src/models/google/gemma.h index cd3f40d2..16b88549 100644 --- a/src/models/google/gemma.h +++ b/src/models/google/gemma.h @@ -10,7 +10,6 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/linear/linear.h" #include "layers/linear/parallel_linear.h" #include "layers/linear/qkv_parallel_linear.h" #include "layers/normalization.h" diff --git a/src/models/google/gemma2.h b/src/models/google/gemma2.h index acfe98f6..37d580b3 100644 --- a/src/models/google/gemma2.h +++ b/src/models/google/gemma2.h @@ -10,7 +10,6 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/linear/linear.h" #include "layers/linear/qkv_parallel_linear.h" #include "layers/normalization.h" #include "memory/kv_cache.h" diff --git a/src/models/meta/llama.h b/src/models/meta/llama.h index 1f184240..1e1bce35 100644 --- a/src/models/meta/llama.h +++ b/src/models/meta/llama.h @@ -8,7 +8,6 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/linear/linear.h" #include "layers/linear/multi_parallel_linear.h" #include "layers/linear/qkv_parallel_linear.h" #include "layers/normalization.h" diff --git a/src/models/microsoft/phi.h b/src/models/microsoft/phi.h index 3f361048..bfdce9b2 100644 --- a/src/models/microsoft/phi.h +++ b/src/models/microsoft/phi.h @@ -6,7 +6,6 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/linear/linear.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" diff --git a/src/models/openai/gpt2.h b/src/models/openai/gpt2.h index 88e03e75..0d2df8aa 100644 --- a/src/models/openai/gpt2.h +++ b/src/models/openai/gpt2.h @@ -7,7 +7,6 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/linear/linear.h" #include "layers/linear/parallel_linear.h" #include "layers/normalization.h" #include "memory/kv_cache.h" diff --git a/src/quantization/qlinear_awq_marlin_impl.h b/src/quantization/qlinear_awq_marlin_impl.h index 879793c5..fcdb5d11 100644 --- a/src/quantization/qlinear_awq_marlin_impl.h +++ b/src/quantization/qlinear_awq_marlin_impl.h @@ -3,7 +3,7 @@ #include #include -#include "layers/linear/linear.h" +#include "layers/linear/parallel_linear.h" #include "layers/linear/weight_utils.h" #include "model_loader/state_dict.h" #include "model_parallel/parallel_args.h" diff --git a/src/quantization/qlinear_gptq_marlin_impl.h b/src/quantization/qlinear_gptq_marlin_impl.h index 62a5fad7..f44fb9a6 100644 --- a/src/quantization/qlinear_gptq_marlin_impl.h +++ b/src/quantization/qlinear_gptq_marlin_impl.h @@ -3,7 +3,7 @@ #include #include -#include "layers/linear/linear.h" +#include "layers/linear/parallel_linear.h" #include "layers/linear/weight_utils.h" #include "model_loader/state_dict.h" #include "model_parallel/parallel_args.h" diff --git a/src/quantization/qlinear_impl.cpp b/src/quantization/qlinear_impl.cpp index 2fe0517b..bb41ec04 100644 --- a/src/quantization/qlinear_impl.cpp +++ b/src/quantization/qlinear_impl.cpp @@ -4,7 +4,6 @@ #include #include -#include "layers/linear/linear.h" #include "model_loader/state_dict.h" namespace llm { diff --git a/src/quantization/qlinear_impl.h b/src/quantization/qlinear_impl.h index eee26afc..d02e2186 100644 --- a/src/quantization/qlinear_impl.h +++ b/src/quantization/qlinear_impl.h @@ -3,11 +3,10 @@ #include #include -#include "layers/linear/linear.h" +#include "layers/linear/parallel_linear.h" #include "layers/linear/weight_utils.h" #include "model_loader/state_dict.h" #include "model_parallel/model_parallel.h" -#include "models/model_args.h" namespace llm { From 16c26746780b239ed009f6cf7413a4249862da74 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 7 Oct 2025 16:15:05 -0700 Subject: [PATCH 09/13] move quantization into layers folder --- src/CMakeLists.txt | 1 - src/engine/llm_engine.h | 2 +- src/engine/worker.h | 2 +- src/engine/worker_test.cpp | 2 +- src/layers/CMakeLists.txt | 1 + src/layers/linear/multi_parallel_linear.cpp | 2 +- src/layers/linear/multi_parallel_linear.h | 2 +- src/layers/linear/parallel_linear.cpp | 10 +++++----- src/layers/linear/parallel_linear.h | 2 +- src/layers/linear/qkv_parallel_linear.h | 2 +- src/{ => layers}/quantization/CMakeLists.txt | 8 ++++---- src/{ => layers}/quantization/data/gptq.safetensors | Bin .../quantization/data/gptq_small.safetensors | Bin src/{ => layers}/quantization/pack_utils.cpp | 2 +- src/{ => layers}/quantization/pack_utils.h | 2 +- src/{ => layers}/quantization/pack_utils_test.cpp | 2 +- src/{ => layers}/quantization/qlinear_awq_impl.cpp | 0 src/{ => layers}/quantization/qlinear_awq_impl.h | 0 .../quantization/qlinear_awq_marlin_impl.cpp | 0 .../quantization/qlinear_awq_marlin_impl.h | 0 .../quantization/qlinear_exllamav2_impl.cpp | 0 .../quantization/qlinear_exllamav2_impl.h | 0 src/{ => layers}/quantization/qlinear_gptq_impl.cpp | 0 src/{ => layers}/quantization/qlinear_gptq_impl.h | 0 .../quantization/qlinear_gptq_marlin_impl.cpp | 0 .../quantization/qlinear_gptq_marlin_impl.h | 0 src/{ => layers}/quantization/qlinear_impl.cpp | 0 src/{ => layers}/quantization/qlinear_impl.h | 0 src/{ => layers}/quantization/qlinear_impl_test.cpp | 0 src/{ => layers}/quantization/quant_args.h | 0 src/model_loader/args_overrider.h | 4 ++-- src/model_loader/model_loader.h | 2 +- src/models/causal_lm.h | 2 +- src/models/model_registry.h | 2 +- 34 files changed, 25 insertions(+), 25 deletions(-) rename src/{ => layers}/quantization/CMakeLists.txt (96%) rename src/{ => layers}/quantization/data/gptq.safetensors (100%) rename src/{ => layers}/quantization/data/gptq_small.safetensors (100%) rename src/{ => layers}/quantization/pack_utils.cpp (98%) rename src/{ => layers}/quantization/pack_utils.h (93%) rename src/{ => layers}/quantization/pack_utils_test.cpp (97%) rename src/{ => layers}/quantization/qlinear_awq_impl.cpp (100%) rename src/{ => layers}/quantization/qlinear_awq_impl.h (100%) rename src/{ => layers}/quantization/qlinear_awq_marlin_impl.cpp (100%) rename src/{ => layers}/quantization/qlinear_awq_marlin_impl.h (100%) rename src/{ => layers}/quantization/qlinear_exllamav2_impl.cpp (100%) rename src/{ => layers}/quantization/qlinear_exllamav2_impl.h (100%) rename src/{ => layers}/quantization/qlinear_gptq_impl.cpp (100%) rename src/{ => layers}/quantization/qlinear_gptq_impl.h (100%) rename src/{ => layers}/quantization/qlinear_gptq_marlin_impl.cpp (100%) rename src/{ => layers}/quantization/qlinear_gptq_marlin_impl.h (100%) rename src/{ => layers}/quantization/qlinear_impl.cpp (100%) rename src/{ => layers}/quantization/qlinear_impl.h (100%) rename src/{ => layers}/quantization/qlinear_impl_test.cpp (100%) rename src/{ => layers}/quantization/quant_args.h (100%) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e7548482..7f25e844 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -5,7 +5,6 @@ add_subdirectory(kernels) add_subdirectory(tokenizer) add_subdirectory(module) add_subdirectory(layers) -add_subdirectory(quantization) add_subdirectory(models) add_subdirectory(model_loader) add_subdirectory(model_parallel) diff --git a/src/engine/llm_engine.h b/src/engine/llm_engine.h index 4f47fede..b59213df 100644 --- a/src/engine/llm_engine.h +++ b/src/engine/llm_engine.h @@ -5,8 +5,8 @@ #include "batch.h" #include "common/macros.h" #include "engine.h" +#include "layers/quantization/quant_args.h" #include "memory/block_manager.h" -#include "quantization/quant_args.h" #include "tokenizer/tokenizer.h" #include "tokenizer/tokenizer_args.h" #include "worker.h" diff --git a/src/engine/worker.h b/src/engine/worker.h index 1d25a6ed..9a929b2b 100644 --- a/src/engine/worker.h +++ b/src/engine/worker.h @@ -4,6 +4,7 @@ #include #include "common/threadpool.h" +#include "layers/quantization/quant_args.h" #include "model_loader/state_dict.h" #include "model_parallel/parallel_args.h" #include "model_runner.h" @@ -11,7 +12,6 @@ #include "models/model_args.h" #include "models/parameters.h" #include "parameters.h" -#include "quantization/quant_args.h" namespace llm { diff --git a/src/engine/worker_test.cpp b/src/engine/worker_test.cpp index ae4d4a4f..4df260cb 100644 --- a/src/engine/worker_test.cpp +++ b/src/engine/worker_test.cpp @@ -4,9 +4,9 @@ #include "engine/batch.h" #include "engine/utils.h" +#include "layers/quantization/quant_args.h" #include "memory/block_manager.h" #include "models/simple_model.h" -#include "quantization/quant_args.h" namespace llm { class TestableWorker { diff --git a/src/layers/CMakeLists.txt b/src/layers/CMakeLists.txt index 17426c74..e638f619 100644 --- a/src/layers/CMakeLists.txt +++ b/src/layers/CMakeLists.txt @@ -55,5 +55,6 @@ cc_test( ) add_subdirectory(linear) +add_subdirectory(quantization) add_subdirectory(attention) add_subdirectory(moe) diff --git a/src/layers/linear/multi_parallel_linear.cpp b/src/layers/linear/multi_parallel_linear.cpp index 146ec5d0..5a93935a 100644 --- a/src/layers/linear/multi_parallel_linear.cpp +++ b/src/layers/linear/multi_parallel_linear.cpp @@ -3,10 +3,10 @@ #include #include +#include "layers/quantization/quant_args.h" #include "model_parallel/model_parallel.h" #include "model_parallel/parallel_args.h" #include "parallel_linear.h" -#include "quantization/quant_args.h" namespace llm { diff --git a/src/layers/linear/multi_parallel_linear.h b/src/layers/linear/multi_parallel_linear.h index b36c5181..54636b0c 100644 --- a/src/layers/linear/multi_parallel_linear.h +++ b/src/layers/linear/multi_parallel_linear.h @@ -4,11 +4,11 @@ #include // #include "linear.h" +#include "layers/quantization/quant_args.h" #include "model_parallel/parallel_args.h" #include "module/module.h" #include "module/module_holder.h" #include "parallel_linear.h" -#include "quantization/quant_args.h" namespace llm { diff --git a/src/layers/linear/parallel_linear.cpp b/src/layers/linear/parallel_linear.cpp index 40219588..8cddf1cf 100644 --- a/src/layers/linear/parallel_linear.cpp +++ b/src/layers/linear/parallel_linear.cpp @@ -7,13 +7,13 @@ #include #include +#include "layers/quantization/qlinear_awq_impl.h" +#include "layers/quantization/qlinear_awq_marlin_impl.h" +#include "layers/quantization/qlinear_exllamav2_impl.h" +#include "layers/quantization/qlinear_gptq_impl.h" +#include "layers/quantization/qlinear_gptq_marlin_impl.h" #include "model_parallel/model_parallel.h" #include "module/module.h" -#include "quantization/qlinear_awq_impl.h" -#include "quantization/qlinear_awq_marlin_impl.h" -#include "quantization/qlinear_exllamav2_impl.h" -#include "quantization/qlinear_gptq_impl.h" -#include "quantization/qlinear_gptq_marlin_impl.h" DEFINE_string( qlinear_gptq_impl, diff --git a/src/layers/linear/parallel_linear.h b/src/layers/linear/parallel_linear.h index e70a4ffb..055766b5 100644 --- a/src/layers/linear/parallel_linear.h +++ b/src/layers/linear/parallel_linear.h @@ -3,11 +3,11 @@ #include #include +#include "layers/quantization/quant_args.h" #include "model_loader/state_dict.h" #include "model_parallel/parallel_args.h" #include "module/module.h" #include "module/module_holder.h" -#include "quantization/quant_args.h" namespace llm { diff --git a/src/layers/linear/qkv_parallel_linear.h b/src/layers/linear/qkv_parallel_linear.h index 29a648b8..5e502c06 100644 --- a/src/layers/linear/qkv_parallel_linear.h +++ b/src/layers/linear/qkv_parallel_linear.h @@ -3,12 +3,12 @@ #include #include +#include "layers/quantization/quant_args.h" #include "model_loader/state_dict.h" #include "model_parallel/parallel_args.h" #include "module/module.h" #include "module/module_holder.h" #include "multi_parallel_linear.h" -#include "quantization/quant_args.h" namespace llm { diff --git a/src/quantization/CMakeLists.txt b/src/layers/quantization/CMakeLists.txt similarity index 96% rename from src/quantization/CMakeLists.txt rename to src/layers/quantization/CMakeLists.txt index 325b0d02..94761865 100644 --- a/src/quantization/CMakeLists.txt +++ b/src/layers/quantization/CMakeLists.txt @@ -2,9 +2,9 @@ include(cc_library) include(cc_test) cc_library( - NAME + NAME quantization - HDRS + HDRS pack_utils.h qlinear_impl.h qlinear_gptq_impl.h @@ -12,7 +12,7 @@ cc_library( qlinear_awq_impl.h qlinear_gptq_marlin_impl.h qlinear_awq_marlin_impl.h - SRCS + SRCS pack_utils.cpp qlinear_impl.cpp qlinear_gptq_impl.cpp @@ -42,7 +42,7 @@ cc_test( :quantization :state_dict :gtest_main - DATA + DATA data/gptq_small.safetensors data/gptq.safetensors ) diff --git a/src/quantization/data/gptq.safetensors b/src/layers/quantization/data/gptq.safetensors similarity index 100% rename from src/quantization/data/gptq.safetensors rename to src/layers/quantization/data/gptq.safetensors diff --git a/src/quantization/data/gptq_small.safetensors b/src/layers/quantization/data/gptq_small.safetensors similarity index 100% rename from src/quantization/data/gptq_small.safetensors rename to src/layers/quantization/data/gptq_small.safetensors diff --git a/src/quantization/pack_utils.cpp b/src/layers/quantization/pack_utils.cpp similarity index 98% rename from src/quantization/pack_utils.cpp rename to src/layers/quantization/pack_utils.cpp index 882823ef..f872efb7 100644 --- a/src/quantization/pack_utils.cpp +++ b/src/layers/quantization/pack_utils.cpp @@ -75,4 +75,4 @@ torch::Tensor unpack_cols( return qweight.to(packed_qweight); } -} // namespace llm::pack_utils \ No newline at end of file +} // namespace llm::pack_utils diff --git a/src/quantization/pack_utils.h b/src/layers/quantization/pack_utils.h similarity index 93% rename from src/quantization/pack_utils.h rename to src/layers/quantization/pack_utils.h index c98355c8..e8db0eea 100644 --- a/src/quantization/pack_utils.h +++ b/src/layers/quantization/pack_utils.h @@ -14,4 +14,4 @@ torch::Tensor unpack_cols( const torch::Tensor& packed_qweight, // (m, n/pack_factor) int64_t num_bits); -} // namespace llm::pack_utils \ No newline at end of file +} // namespace llm::pack_utils diff --git a/src/quantization/pack_utils_test.cpp b/src/layers/quantization/pack_utils_test.cpp similarity index 97% rename from src/quantization/pack_utils_test.cpp rename to src/layers/quantization/pack_utils_test.cpp index 52125db9..ed13132f 100644 --- a/src/quantization/pack_utils_test.cpp +++ b/src/layers/quantization/pack_utils_test.cpp @@ -28,4 +28,4 @@ TEST(PackUtilsTest, Basic) { EXPECT_TRUE(unpacked_qweight.equal(qweight)); } -} // namespace llm \ No newline at end of file +} // namespace llm diff --git a/src/quantization/qlinear_awq_impl.cpp b/src/layers/quantization/qlinear_awq_impl.cpp similarity index 100% rename from src/quantization/qlinear_awq_impl.cpp rename to src/layers/quantization/qlinear_awq_impl.cpp diff --git a/src/quantization/qlinear_awq_impl.h b/src/layers/quantization/qlinear_awq_impl.h similarity index 100% rename from src/quantization/qlinear_awq_impl.h rename to src/layers/quantization/qlinear_awq_impl.h diff --git a/src/quantization/qlinear_awq_marlin_impl.cpp b/src/layers/quantization/qlinear_awq_marlin_impl.cpp similarity index 100% rename from src/quantization/qlinear_awq_marlin_impl.cpp rename to src/layers/quantization/qlinear_awq_marlin_impl.cpp diff --git a/src/quantization/qlinear_awq_marlin_impl.h b/src/layers/quantization/qlinear_awq_marlin_impl.h similarity index 100% rename from src/quantization/qlinear_awq_marlin_impl.h rename to src/layers/quantization/qlinear_awq_marlin_impl.h diff --git a/src/quantization/qlinear_exllamav2_impl.cpp b/src/layers/quantization/qlinear_exllamav2_impl.cpp similarity index 100% rename from src/quantization/qlinear_exllamav2_impl.cpp rename to src/layers/quantization/qlinear_exllamav2_impl.cpp diff --git a/src/quantization/qlinear_exllamav2_impl.h b/src/layers/quantization/qlinear_exllamav2_impl.h similarity index 100% rename from src/quantization/qlinear_exllamav2_impl.h rename to src/layers/quantization/qlinear_exllamav2_impl.h diff --git a/src/quantization/qlinear_gptq_impl.cpp b/src/layers/quantization/qlinear_gptq_impl.cpp similarity index 100% rename from src/quantization/qlinear_gptq_impl.cpp rename to src/layers/quantization/qlinear_gptq_impl.cpp diff --git a/src/quantization/qlinear_gptq_impl.h b/src/layers/quantization/qlinear_gptq_impl.h similarity index 100% rename from src/quantization/qlinear_gptq_impl.h rename to src/layers/quantization/qlinear_gptq_impl.h diff --git a/src/quantization/qlinear_gptq_marlin_impl.cpp b/src/layers/quantization/qlinear_gptq_marlin_impl.cpp similarity index 100% rename from src/quantization/qlinear_gptq_marlin_impl.cpp rename to src/layers/quantization/qlinear_gptq_marlin_impl.cpp diff --git a/src/quantization/qlinear_gptq_marlin_impl.h b/src/layers/quantization/qlinear_gptq_marlin_impl.h similarity index 100% rename from src/quantization/qlinear_gptq_marlin_impl.h rename to src/layers/quantization/qlinear_gptq_marlin_impl.h diff --git a/src/quantization/qlinear_impl.cpp b/src/layers/quantization/qlinear_impl.cpp similarity index 100% rename from src/quantization/qlinear_impl.cpp rename to src/layers/quantization/qlinear_impl.cpp diff --git a/src/quantization/qlinear_impl.h b/src/layers/quantization/qlinear_impl.h similarity index 100% rename from src/quantization/qlinear_impl.h rename to src/layers/quantization/qlinear_impl.h diff --git a/src/quantization/qlinear_impl_test.cpp b/src/layers/quantization/qlinear_impl_test.cpp similarity index 100% rename from src/quantization/qlinear_impl_test.cpp rename to src/layers/quantization/qlinear_impl_test.cpp diff --git a/src/quantization/quant_args.h b/src/layers/quantization/quant_args.h similarity index 100% rename from src/quantization/quant_args.h rename to src/layers/quantization/quant_args.h diff --git a/src/model_loader/args_overrider.h b/src/model_loader/args_overrider.h index 99b38fc6..83132873 100644 --- a/src/model_loader/args_overrider.h +++ b/src/model_loader/args_overrider.h @@ -2,8 +2,8 @@ #include +#include "layers/quantization/quant_args.h" #include "models/model_args.h" -#include "quantization/quant_args.h" #include "tokenizer/tokenizer_args.h" // Model args flags @@ -56,4 +56,4 @@ void override_args_from_gflag(ModelArgs& args, QuantArgs& quant_args, TokenizerArgs& tokenizer_args); -} // namespace llm \ No newline at end of file +} // namespace llm diff --git a/src/model_loader/model_loader.h b/src/model_loader/model_loader.h index dbac26cf..b1ac0782 100644 --- a/src/model_loader/model_loader.h +++ b/src/model_loader/model_loader.h @@ -4,9 +4,9 @@ #include +#include "layers/quantization/quant_args.h" #include "model_loader/state_dict.h" #include "models/model_args.h" -#include "quantization/quant_args.h" #include "tokenizer/tokenizer.h" #include "tokenizer/tokenizer_args.h" diff --git a/src/models/causal_lm.h b/src/models/causal_lm.h index eaed0ffe..e65aa368 100644 --- a/src/models/causal_lm.h +++ b/src/models/causal_lm.h @@ -5,12 +5,12 @@ #include +#include "layers/quantization/quant_args.h" #include "memory/kv_cache.h" #include "model_args.h" #include "model_loader/state_dict.h" #include "model_parallel/parallel_args.h" #include "parameters.h" -#include "quantization/quant_args.h" namespace llm { diff --git a/src/models/model_registry.h b/src/models/model_registry.h index a15f7138..93590ad3 100644 --- a/src/models/model_registry.h +++ b/src/models/model_registry.h @@ -8,9 +8,9 @@ #include "chat_template/chat_template.h" #include "common/json_reader.h" #include "common/type_traits.h" // IWYU pragma: keep +#include "layers/quantization/quant_args.h" #include "model_args.h" #include "model_parallel/parallel_args.h" -#include "quantization/quant_args.h" #include "tokenizer/tokenizer_args.h" namespace llm { From 8d493721c4b520c77472548b693abd9924860ab2 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 7 Oct 2025 16:41:30 -0700 Subject: [PATCH 10/13] refactor ModuleHolder --- src/layers/linear/multi_parallel_linear.cpp | 76 ++++++++++++--------- src/layers/linear/multi_parallel_linear.h | 30 ++++---- src/layers/linear/parallel_linear.cpp | 14 +++- src/layers/linear/parallel_linear.h | 12 ++-- src/module/module.h | 2 + 5 files changed, 81 insertions(+), 53 deletions(-) diff --git a/src/layers/linear/multi_parallel_linear.cpp b/src/layers/linear/multi_parallel_linear.cpp index 5a93935a..fb2f2a45 100644 --- a/src/layers/linear/multi_parallel_linear.cpp +++ b/src/layers/linear/multi_parallel_linear.cpp @@ -9,6 +9,37 @@ #include "parallel_linear.h" namespace llm { +namespace { + +std::shared_ptr create_multi_column_parallel_linear( + int64_t in_features, + const std::vector& out_features, + const std::vector& prefixes, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + // check if the linear layers can be fused + if (quant_args.can_be_fused()) { + return std::make_shared(in_features, + out_features, + prefixes, + bias, + gather_output, + parallel_args, + options); + } + + return std::make_shared(in_features, + out_features, + prefixes, + bias, + gather_output, + parallel_args, + options); +} +} // namespace FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl( int64_t in_features, @@ -114,45 +145,28 @@ std::vector GroupedColumnParallelLinearImpl::forward( return outputs; } -MultiColumnParallelLinearImpl::MultiColumnParallelLinearImpl( +MultiColumnParallelLinear::MultiColumnParallelLinear(std::nullptr_t) + : ModuleHolder(nullptr) {} + +MultiColumnParallelLinear::MultiColumnParallelLinear( + std::shared_ptr module) + : ModuleHolder(std::move(module)) {} + +MultiColumnParallelLinear::MultiColumnParallelLinear( int64_t in_features, - const std::vector& out_features_vec, + const std::vector& out_features, const std::vector& prefixes, bool bias, bool gather_output, const QuantArgs& quant_args, const ParallelArgs& parallel_args, - const torch::TensorOptions& options) { - // check if the linear layers can be fused - std::shared_ptr linear; - if (quant_args.can_be_fused()) { - // fused linear layer - linear = register_module("fused_linear", - FusedColumnParallelLinear(in_features, - out_features_vec, + const torch::TensorOptions& options) + : ModuleHolder(create_multi_column_parallel_linear(in_features, + out_features, prefixes, bias, gather_output, + quant_args, parallel_args, - options), - /*selector=*/nullptr); - } else { - // non-fused linear layers - linear = register_module("grouped_linear", - GroupedColumnParallelLinear(in_features, - out_features_vec, - prefixes, - bias, - gather_output, - parallel_args, - options), - /*selector=*/nullptr); - } - linear_ = linear; -} - -std::vector MultiColumnParallelLinearImpl::forward( - torch::Tensor input) { - return linear_(input); -} + options)) {} } // namespace llm diff --git a/src/layers/linear/multi_parallel_linear.h b/src/layers/linear/multi_parallel_linear.h index 54636b0c..c5fa2c30 100644 --- a/src/layers/linear/multi_parallel_linear.h +++ b/src/layers/linear/multi_parallel_linear.h @@ -18,7 +18,6 @@ class MultiParallelLinearImpl : public Module { virtual std::vector forward(torch::Tensor input) = 0; }; -LLM_MODULE(MultiParallelLinear); // Fused linear layer with column parallelism. class FusedColumnParallelLinearImpl : public MultiParallelLinearImpl { @@ -71,22 +70,21 @@ class GroupedColumnParallelLinearImpl : public MultiParallelLinearImpl { }; LLM_MODULE(GroupedColumnParallelLinear); -class MultiColumnParallelLinearImpl : public Module { +class MultiColumnParallelLinear : public ModuleHolder { public: - MultiColumnParallelLinearImpl(int64_t in_features, - const std::vector& out_features, - const std::vector& prefixes, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); - - std::vector forward(torch::Tensor input); - - private: - MultiParallelLinear linear_{nullptr}; + /* implicit */ MultiColumnParallelLinear(std::nullptr_t); + + /* implicit */ MultiColumnParallelLinear( + std::shared_ptr module); + + MultiColumnParallelLinear(int64_t in_features, + const std::vector& out_features, + const std::vector& prefixes, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); }; -LLM_MODULE(MultiColumnParallelLinear); } // namespace llm diff --git a/src/layers/linear/parallel_linear.cpp b/src/layers/linear/parallel_linear.cpp index 8cddf1cf..cfe6a8a4 100644 --- a/src/layers/linear/parallel_linear.cpp +++ b/src/layers/linear/parallel_linear.cpp @@ -340,6 +340,13 @@ torch::Tensor RowParallelLinearImpl::forward(torch::Tensor input) { // construct a ColumnParallelLinear. // chose right implementation based on the args. +ColumnParallelLinear::ColumnParallelLinear(std::nullptr_t) + : ModuleHolder(nullptr) {} + +ColumnParallelLinear::ColumnParallelLinear( + std::shared_ptr module) + : ModuleHolder(std::move(module)) {} + ColumnParallelLinear::ColumnParallelLinear(int64_t in_features, int64_t out_features, bool bias, @@ -372,8 +379,13 @@ ColumnParallelLinear::ColumnParallelLinear(int64_t in_features, options, "")) {} -// construct a rotary positional embedding. +// construct a row parallel linear. // chose right implementation based on the args. +RowParallelLinear::RowParallelLinear(std::nullptr_t) : ModuleHolder(nullptr) {} + +RowParallelLinear::RowParallelLinear(std::shared_ptr module) + : ModuleHolder(std::move(module)) {} + RowParallelLinear::RowParallelLinear(int64_t in_features, int64_t out_features, bool bias, diff --git a/src/layers/linear/parallel_linear.h b/src/layers/linear/parallel_linear.h index 055766b5..1a74d505 100644 --- a/src/layers/linear/parallel_linear.h +++ b/src/layers/linear/parallel_linear.h @@ -35,7 +35,6 @@ class ParallelLinearImpl : public Module { LOG(FATAL) << "not implemented"; } }; -LLM_MODULE(ParallelLinear); // Linear layer with column parallelism. // The linear layer is defined as Y = XA + b. A is parallelized along @@ -109,8 +108,10 @@ class RowParallelLinearImpl : public ParallelLinearImpl { class ColumnParallelLinear : public ModuleHolder { public: - using ModuleHolder::ModuleHolder; - using Impl [[maybe_unused]] = ParallelLinearImpl; + /* implicit */ ColumnParallelLinear(std::nullptr_t); + + /* implicit */ ColumnParallelLinear( + std::shared_ptr module); // construct a rotary positional embedding. // chose right implementation based on the args. @@ -133,8 +134,9 @@ class ColumnParallelLinear : public ModuleHolder { class RowParallelLinear : public ModuleHolder { public: - using ModuleHolder::ModuleHolder; - using Impl [[maybe_unused]] = ParallelLinearImpl; + /* implicit */ RowParallelLinear(std::nullptr_t); + + /* implicit */ RowParallelLinear(std::shared_ptr module); // construct a rotary positional embedding. // chose right implementation based on the args. diff --git a/src/module/module.h b/src/module/module.h index 2d21610d..664f4d9c 100644 --- a/src/module/module.h +++ b/src/module/module.h @@ -1,3 +1,5 @@ +// Adapted from +// https://github.com/pytorch/pytorch/blob/main/torch/csrc/api/include/torch/nn/module.h #pragma once #include From 2500606cb4d997fee1ba5ce00c12105ab8f446b5 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 7 Oct 2025 16:44:59 -0700 Subject: [PATCH 11/13] add ref link --- src/module/module_holder.h | 2 ++ src/module/module_list.h | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/module/module_holder.h b/src/module/module_holder.h index abe131a5..c17aac06 100644 --- a/src/module/module_holder.h +++ b/src/module/module_holder.h @@ -1,3 +1,5 @@ +// Adapted from +// https://github.com/pytorch/pytorch/blob/main/torch/csrc/api/include/torch/nn/pimpl.h #pragma once #include diff --git a/src/module/module_list.h b/src/module/module_list.h index ce0bfee6..56cf833c 100644 --- a/src/module/module_list.h +++ b/src/module/module_list.h @@ -1,3 +1,5 @@ +// Adapted from +// https://github.com/pytorch/pytorch/blob/main/torch/csrc/api/include/torch/nn/modules/container/modulelist.h #pragma once #include From 8cbcd1cd771b8791dca58f749faffb5c76180c92 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 7 Oct 2025 17:10:04 -0700 Subject: [PATCH 12/13] add unittests --- src/layers/linear/CMakeLists.txt | 3 +- .../linear/multi_parallel_linear_test.cpp | 197 ++++++++++++++++++ src/layers/linear/parallel_linear.cpp | 31 --- ...near_test.cpp => parallel_linear_test.cpp} | 111 +--------- 4 files changed, 204 insertions(+), 138 deletions(-) create mode 100644 src/layers/linear/multi_parallel_linear_test.cpp rename src/layers/linear/{linear_test.cpp => parallel_linear_test.cpp} (55%) diff --git a/src/layers/linear/CMakeLists.txt b/src/layers/linear/CMakeLists.txt index 6a17000c..8c0a8d8a 100644 --- a/src/layers/linear/CMakeLists.txt +++ b/src/layers/linear/CMakeLists.txt @@ -29,7 +29,8 @@ cc_test( NAME linear_test SRCS - linear_test.cpp + parallel_linear_test.cpp + multi_parallel_linear_test.cpp qkv_parallel_linear_test.cpp DEPS :linear diff --git a/src/layers/linear/multi_parallel_linear_test.cpp b/src/layers/linear/multi_parallel_linear_test.cpp new file mode 100644 index 00000000..241c32c9 --- /dev/null +++ b/src/layers/linear/multi_parallel_linear_test.cpp @@ -0,0 +1,197 @@ +#include "multi_parallel_linear.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "model_loader/state_dict.h" + +namespace llm { + +TEST(MultiParallelLinearTest, FusedColumnParallelLinear) { + // test load state dict for linear + const int64_t in_features = 10; + const int64_t out_features = 40; + + torch::Device device(torch::kCPU); + torch::ScalarType dtype(torch::kFloat); + const auto options = torch::dtype(dtype).device(device); + + std::vector out_features_vec = { + out_features, out_features, out_features}; + std::vector prefixes = {"query.", "key.", "value."}; + + std::unordered_map state_dict_data; + // Allocate transposed weight matrix + state_dict_data["query.weight"] = torch::randn({out_features, in_features}); + state_dict_data["key.weight"] = torch::randn({out_features, in_features}); + state_dict_data["value.weight"] = torch::randn({out_features, in_features}); + + // weight is not sharded + StateDict state_dict(state_dict_data); + + // test load weight + { + ParallelArgs parallel_args(0, 1, nullptr); + FusedColumnParallelLinearImpl linear(in_features, + out_features_vec, + prefixes, + /*bias=*/false, + /*gather_output=*/false, + parallel_args, + options); + // test load fused weight + EXPECT_EQ(linear.load(state_dict), 3); + + for (const auto& prefix : prefixes) { + auto named_parameters = linear.named_parameters(/*recurse=*/false); + const auto key = detail::join_name(prefix, "weight"); + ASSERT_TRUE(named_parameters.contains(key)); + + const auto& loaded_weight = named_parameters[key]; + EXPECT_EQ(loaded_weight.sizes(), + torch::IntArrayRef({out_features, in_features})); + EXPECT_TRUE(torch::equal(loaded_weight, state_dict_data[key])); + } + + // verify the fused weight + const auto loaded_fused_weight = linear.weight(); + const auto desired_fused_weight = + torch::cat({state_dict_data["query.weight"], + state_dict_data["key.weight"], + state_dict_data["value.weight"]}, + /*dim=*/0); + EXPECT_TRUE(torch::equal(loaded_fused_weight, desired_fused_weight)); + } + + // test load weight with 4 shards + const int32_t num_shards = 4; + for (int32_t shard_id = 0; shard_id < num_shards; ++shard_id) { + ParallelArgs parallel_args_0(shard_id, num_shards, nullptr); + FusedColumnParallelLinearImpl linear(in_features, + out_features_vec, + prefixes, + /*bias=*/false, + /*gather_output=*/false, + parallel_args_0, + options); + EXPECT_EQ(linear.load(state_dict), 3); + + auto named_parameters = linear.named_parameters(/*recurse=*/false); + + // check size for each prefix + for (const auto& prefix : prefixes) { + const auto key = detail::join_name(prefix, "weight"); + ASSERT_TRUE(named_parameters.contains(key)); + + const auto& loaded_weight = named_parameters[key]; + EXPECT_EQ(loaded_weight.sizes(), + torch::IntArrayRef({out_features / num_shards, in_features})); + EXPECT_TRUE(torch::equal( + loaded_weight, state_dict_data[key].chunk(num_shards, 0)[shard_id])); + } + + // shard weight then cat + auto sharded_query_weight = + state_dict_data["query.weight"].chunk(num_shards, 0)[shard_id]; + auto sharded_key_weight = + state_dict_data["key.weight"].chunk(num_shards, 0)[shard_id]; + auto sharded_value_weight = + state_dict_data["value.weight"].chunk(num_shards, 0)[shard_id]; + + // verify the fused weight + const auto loaded_fused_weight = linear.weight(); + auto desired_fused_weight = torch::cat( + {sharded_query_weight, sharded_key_weight, sharded_value_weight}, + /*dim=*/0); + + EXPECT_TRUE(torch::equal(loaded_fused_weight, desired_fused_weight)); + } +} + +TEST(MultiParallelLinearTest, GroupedColumnParallelLinear) { + const int64_t in_features = 10; + const int64_t out_features = 40; + std::vector out_features_vec = { + out_features, out_features, out_features}; + std::vector prefixes = {"query.", "key.", "value."}; + + torch::Device device(torch::kCPU); + torch::ScalarType dtype(torch::kFloat); + const auto options = torch::dtype(dtype).device(device); + + std::unordered_map state_dict_data; + // Allocate transposed weight matrix + state_dict_data["query.weight"] = torch::randn({out_features, in_features}); + state_dict_data["key.weight"] = torch::randn({out_features, in_features}); + state_dict_data["value.weight"] = torch::randn({out_features, in_features}); + // weight is not sharded + StateDict state_dict(state_dict_data); + + // test load weight + { + ParallelArgs parallel_args(0, 1, nullptr); + GroupedColumnParallelLinearImpl linear(in_features, + out_features_vec, + prefixes, + /*bias=*/false, + /*gather_output=*/false, + parallel_args, + options); + // test load grouped weight + EXPECT_EQ(linear.load(state_dict), 3); + + auto named_parameters = linear.named_parameters(/*recurse=*/true); + for (size_t i = 0; i < prefixes.size(); ++i) { + const auto prefix = "linear_" + std::to_string(i) + "." + prefixes[i]; + const auto key = detail::join_name(prefix, "weight"); + ASSERT_TRUE(named_parameters.contains(key)); + + const auto& loaded_weight = named_parameters[key]; + + const auto sd_key = detail::join_name(prefixes[i], "weight"); + + EXPECT_EQ(loaded_weight.sizes(), + torch::IntArrayRef({out_features, in_features})); + EXPECT_TRUE(torch::equal(loaded_weight, state_dict_data[sd_key])); + } + } + + // test load weight with 4 shards + const int32_t num_shards = 4; + for (int32_t shard_id = 0; shard_id < num_shards; ++shard_id) { + ParallelArgs parallel_args(shard_id, num_shards, nullptr); + GroupedColumnParallelLinearImpl linear(in_features, + out_features_vec, + prefixes, + /*bias=*/false, + /*gather_output=*/false, + parallel_args, + options); + EXPECT_EQ(linear.load(state_dict), 3); + auto named_parameters = linear.named_parameters(/*recurse=*/true); + // check size for each prefix + for (size_t i = 0; i < prefixes.size(); ++i) { + const auto prefix = "linear_" + std::to_string(i) + "." + prefixes[i]; + const auto key = detail::join_name(prefix, "weight"); + ASSERT_TRUE(named_parameters.contains(key)); + + const auto& loaded_weight = named_parameters[key]; + EXPECT_EQ(loaded_weight.sizes(), + torch::IntArrayRef({out_features / num_shards, in_features})); + const auto sd_key = detail::join_name(prefixes[i], "weight"); + EXPECT_TRUE( + torch::equal(loaded_weight, + state_dict_data[sd_key].chunk(num_shards, 0)[shard_id])); + } + } +} + +} // namespace llm diff --git a/src/layers/linear/parallel_linear.cpp b/src/layers/linear/parallel_linear.cpp index cfe6a8a4..9bc46e73 100644 --- a/src/layers/linear/parallel_linear.cpp +++ b/src/layers/linear/parallel_linear.cpp @@ -215,37 +215,6 @@ std::shared_ptr create_row_parallel_linear( parallel_args, options); } - -// std::shared_ptr create_multi_column_parallel_linear( -// int64_t in_features, -// const std::vector& out_features, -// const std::vector& prefixes, -// bool bias, -// bool gather_output, -// const QuantArgs& quant_args, -// const ParallelArgs& parallel_args, -// const torch::TensorOptions& options) { -// // check if the linear layers can be fused -// const bool fused = quant_args.can_be_fused(); -// std::shared_ptr impl; -// if (fused) { -// return std::make_shared(in_features, -// out_features, -// prefixes, -// bias, -// gather_output, -// parallel_args, -// options); -// } - -// return std::make_shared(in_features, -// out_features, -// prefixes, -// bias, -// gather_output, -// parallel_args, -// options); -// } } // namespace // Linear layer with column parallelism. diff --git a/src/layers/linear/linear_test.cpp b/src/layers/linear/parallel_linear_test.cpp similarity index 55% rename from src/layers/linear/linear_test.cpp rename to src/layers/linear/parallel_linear_test.cpp index 6b72063f..236b58d1 100644 --- a/src/layers/linear/linear_test.cpp +++ b/src/layers/linear/parallel_linear_test.cpp @@ -1,20 +1,21 @@ +#include "parallel_linear.h" + #include #include #include #include #include +#include #include #include #include #include "model_loader/state_dict.h" -#include "multi_parallel_linear.h" -#include "parallel_linear.h" namespace llm { -TEST(LinearTest, RowParallelLoadWeight) { +TEST(ParallelLinearTest, RowParallelLinear) { // test load state dict for row parallel linear const int64_t in_features = 10; const int64_t out_features = 20; @@ -79,7 +80,7 @@ TEST(LinearTest, RowParallelLoadWeight) { } } -TEST(LinearTest, ColumnParallelLoadWeight) { +TEST(ParallelLinearTest, ColumnParallelLinear) { // test load state dict for linear const int64_t in_features = 10; const int64_t out_features = 20; @@ -136,106 +137,4 @@ TEST(LinearTest, ColumnParallelLoadWeight) { } } -TEST(LinearTest, ColumnParallelLoadFusedWeight) { - // test load state dict for linear - const int64_t in_features = 10; - const int64_t out_features = 40; - - torch::Device device(torch::kCPU); - torch::ScalarType dtype(torch::kFloat); - const auto options = torch::dtype(dtype).device(device); - - std::vector out_features_vec = { - out_features, out_features, out_features}; - std::vector prefixes = {"query.", "key.", "value."}; - - std::unordered_map state_dict_data; - // Allocate transposed weight matrix - state_dict_data["query.weight"] = torch::randn({out_features, in_features}); - state_dict_data["key.weight"] = torch::randn({out_features, in_features}); - state_dict_data["value.weight"] = torch::randn({out_features, in_features}); - - // weight is not sharded - StateDict state_dict(state_dict_data); - - // test load weight - { - ParallelArgs parallel_args(0, 1, nullptr); - FusedColumnParallelLinearImpl linear(in_features, - out_features_vec, - prefixes, - /*bias=*/false, - /*gather_output=*/false, - parallel_args, - options); - // test load fused weight - EXPECT_EQ(linear.load(state_dict), 3); - - for (const auto& prefix : prefixes) { - auto named_parameters = linear.named_parameters(/*recurse=*/false); - const auto key = detail::join_name(prefix, "weight"); - ASSERT_TRUE(named_parameters.contains(key)); - - const auto& loaded_weight = named_parameters[key]; - EXPECT_EQ(loaded_weight.sizes(), - torch::IntArrayRef({out_features, in_features})); - EXPECT_TRUE(torch::equal(loaded_weight, state_dict_data[key])); - } - - // verify the fused weight - const auto loaded_fused_weight = linear.weight(); - const auto desired_fused_weight = - torch::cat({state_dict_data["query.weight"], - state_dict_data["key.weight"], - state_dict_data["value.weight"]}, - /*dim=*/0); - EXPECT_TRUE(torch::equal(loaded_fused_weight, desired_fused_weight)); - } - - // test load weight with 4 shards - const int32_t num_shards = 4; - for (int32_t shard_id = 0; shard_id < num_shards; ++shard_id) { - ParallelArgs parallel_args_0(shard_id, num_shards, nullptr); - FusedColumnParallelLinearImpl linear(in_features, - out_features_vec, - prefixes, - /*bias=*/false, - /*gather_output=*/false, - parallel_args_0, - options); - EXPECT_EQ(linear.load(state_dict), 3); - - auto named_parameters = linear.named_parameters(/*recurse=*/false); - - // check size for each prefix - for (const auto& prefix : prefixes) { - auto named_parameters = linear.named_parameters(/*recurse=*/false); - const auto key = detail::join_name(prefix, "weight"); - ASSERT_TRUE(named_parameters.contains(key)); - - const auto& loaded_weight = named_parameters[key]; - EXPECT_EQ(loaded_weight.sizes(), - torch::IntArrayRef({out_features / num_shards, in_features})); - EXPECT_TRUE(torch::equal( - loaded_weight, state_dict_data[key].chunk(num_shards, 0)[shard_id])); - } - - // shard weight then cat - auto sharded_query_weight = - state_dict_data["query.weight"].chunk(num_shards, 0)[shard_id]; - auto sharded_key_weight = - state_dict_data["key.weight"].chunk(num_shards, 0)[shard_id]; - auto sharded_value_weight = - state_dict_data["value.weight"].chunk(num_shards, 0)[shard_id]; - - // verify the fused weight - const auto loaded_fused_weight = linear.weight(); - auto desired_fused_weight = torch::cat( - {sharded_query_weight, sharded_key_weight, sharded_value_weight}, - /*dim=*/0); - - EXPECT_TRUE(torch::equal(loaded_fused_weight, desired_fused_weight)); - } -} - } // namespace llm From 3e0d016b591b11d51ee8200451c9d06b7bec2ebd Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 7 Oct 2025 17:13:44 -0700 Subject: [PATCH 13/13] move '"module' into layers folder --- src/CMakeLists.txt | 1 - src/layers/CMakeLists.txt | 1 + src/layers/attention/attention.h | 4 ++-- src/layers/embedding.h | 4 ++-- src/layers/linear/multi_parallel_linear.h | 4 ++-- src/layers/linear/parallel_linear.cpp | 2 +- src/layers/linear/parallel_linear.h | 4 ++-- src/layers/linear/qkv_parallel_linear.h | 4 ++-- src/{ => layers}/module/CMakeLists.txt | 0 src/{ => layers}/module/module.cpp | 0 src/{ => layers}/module/module.h | 0 src/{ => layers}/module/module_holder.h | 0 src/{ => layers}/module/module_list.h | 0 src/{ => layers}/module/module_test.cpp | 0 src/layers/normalization.h | 4 ++-- src/models/_deprecated/aquila.h | 6 +++--- src/models/_deprecated/baichuan.h | 6 +++--- src/models/_deprecated/bloom.h | 6 +++--- src/models/_deprecated/chatglm.h | 6 +++--- src/models/_deprecated/gpt_j.h | 6 +++--- src/models/_deprecated/gpt_neox.h | 6 +++--- src/models/_deprecated/internlm.h | 6 +++--- src/models/_deprecated/mistral.h | 6 +++--- src/models/_deprecated/mpt.h | 6 +++--- src/models/_deprecated/simple_model.h | 6 +++--- src/models/alibaba/qwen.h | 6 +++--- src/models/alibaba/qwen2.h | 6 +++--- src/models/google/gemma.h | 6 +++--- src/models/google/gemma2.h | 6 +++--- src/models/meta/llama.h | 6 +++--- src/models/microsoft/phi.h | 6 +++--- src/models/openai/gpt2.h | 6 +++--- 32 files changed, 65 insertions(+), 65 deletions(-) rename src/{ => layers}/module/CMakeLists.txt (100%) rename src/{ => layers}/module/module.cpp (100%) rename src/{ => layers}/module/module.h (100%) rename src/{ => layers}/module/module_holder.h (100%) rename src/{ => layers}/module/module_list.h (100%) rename src/{ => layers}/module/module_test.cpp (100%) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7f25e844..1b441180 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -3,7 +3,6 @@ add_subdirectory(common) add_subdirectory(handlers) add_subdirectory(kernels) add_subdirectory(tokenizer) -add_subdirectory(module) add_subdirectory(layers) add_subdirectory(models) add_subdirectory(model_loader) diff --git a/src/layers/CMakeLists.txt b/src/layers/CMakeLists.txt index e638f619..abd44bbd 100644 --- a/src/layers/CMakeLists.txt +++ b/src/layers/CMakeLists.txt @@ -54,6 +54,7 @@ cc_test( :gtest_main ) +add_subdirectory(module) add_subdirectory(linear) add_subdirectory(quantization) add_subdirectory(attention) diff --git a/src/layers/attention/attention.h b/src/layers/attention/attention.h index 64c6bb5e..663fea71 100644 --- a/src/layers/attention/attention.h +++ b/src/layers/attention/attention.h @@ -3,10 +3,10 @@ #include #include "layers/attention/handler.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" #include "memory/kv_cache.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" namespace llm { diff --git a/src/layers/embedding.h b/src/layers/embedding.h index 9ef07033..fbe1ddf0 100644 --- a/src/layers/embedding.h +++ b/src/layers/embedding.h @@ -5,10 +5,10 @@ #include +#include "layers/module/module.h" +#include "layers/module/module_holder.h" #include "model_loader/state_dict.h" #include "model_parallel/model_parallel.h" -#include "module/module.h" -#include "module/module_holder.h" namespace llm { diff --git a/src/layers/linear/multi_parallel_linear.h b/src/layers/linear/multi_parallel_linear.h index c5fa2c30..42a2295e 100644 --- a/src/layers/linear/multi_parallel_linear.h +++ b/src/layers/linear/multi_parallel_linear.h @@ -4,10 +4,10 @@ #include // #include "linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" #include "layers/quantization/quant_args.h" #include "model_parallel/parallel_args.h" -#include "module/module.h" -#include "module/module_holder.h" #include "parallel_linear.h" namespace llm { diff --git a/src/layers/linear/parallel_linear.cpp b/src/layers/linear/parallel_linear.cpp index 9bc46e73..ed591457 100644 --- a/src/layers/linear/parallel_linear.cpp +++ b/src/layers/linear/parallel_linear.cpp @@ -7,13 +7,13 @@ #include #include +#include "layers/module/module.h" #include "layers/quantization/qlinear_awq_impl.h" #include "layers/quantization/qlinear_awq_marlin_impl.h" #include "layers/quantization/qlinear_exllamav2_impl.h" #include "layers/quantization/qlinear_gptq_impl.h" #include "layers/quantization/qlinear_gptq_marlin_impl.h" #include "model_parallel/model_parallel.h" -#include "module/module.h" DEFINE_string( qlinear_gptq_impl, diff --git a/src/layers/linear/parallel_linear.h b/src/layers/linear/parallel_linear.h index 1a74d505..d1ebee68 100644 --- a/src/layers/linear/parallel_linear.h +++ b/src/layers/linear/parallel_linear.h @@ -3,11 +3,11 @@ #include #include +#include "layers/module/module.h" +#include "layers/module/module_holder.h" #include "layers/quantization/quant_args.h" #include "model_loader/state_dict.h" #include "model_parallel/parallel_args.h" -#include "module/module.h" -#include "module/module_holder.h" namespace llm { diff --git a/src/layers/linear/qkv_parallel_linear.h b/src/layers/linear/qkv_parallel_linear.h index 5e502c06..43bff377 100644 --- a/src/layers/linear/qkv_parallel_linear.h +++ b/src/layers/linear/qkv_parallel_linear.h @@ -3,11 +3,11 @@ #include #include +#include "layers/module/module.h" +#include "layers/module/module_holder.h" #include "layers/quantization/quant_args.h" #include "model_loader/state_dict.h" #include "model_parallel/parallel_args.h" -#include "module/module.h" -#include "module/module_holder.h" #include "multi_parallel_linear.h" namespace llm { diff --git a/src/module/CMakeLists.txt b/src/layers/module/CMakeLists.txt similarity index 100% rename from src/module/CMakeLists.txt rename to src/layers/module/CMakeLists.txt diff --git a/src/module/module.cpp b/src/layers/module/module.cpp similarity index 100% rename from src/module/module.cpp rename to src/layers/module/module.cpp diff --git a/src/module/module.h b/src/layers/module/module.h similarity index 100% rename from src/module/module.h rename to src/layers/module/module.h diff --git a/src/module/module_holder.h b/src/layers/module/module_holder.h similarity index 100% rename from src/module/module_holder.h rename to src/layers/module/module_holder.h diff --git a/src/module/module_list.h b/src/layers/module/module_list.h similarity index 100% rename from src/module/module_list.h rename to src/layers/module/module_list.h diff --git a/src/module/module_test.cpp b/src/layers/module/module_test.cpp similarity index 100% rename from src/module/module_test.cpp rename to src/layers/module/module_test.cpp diff --git a/src/layers/normalization.h b/src/layers/normalization.h index da75a81b..6011589c 100644 --- a/src/layers/normalization.h +++ b/src/layers/normalization.h @@ -6,9 +6,9 @@ #include #include "kernels/layernorm_kernels.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" #include "model_loader/state_dict.h" -#include "module/module.h" -#include "module/module_holder.h" DECLARE_bool(disable_custom_kernels); namespace llm { diff --git a/src/models/_deprecated/aquila.h b/src/models/_deprecated/aquila.h index 59e7b3e4..c76343a2 100644 --- a/src/models/_deprecated/aquila.h +++ b/src/models/_deprecated/aquila.h @@ -9,14 +9,14 @@ #include "layers/embedding.h" #include "layers/fused_linear.h" #include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // Aquila model compatible with huggingface weights namespace llm::hf { diff --git a/src/models/_deprecated/baichuan.h b/src/models/_deprecated/baichuan.h index db77fcdc..8fa2b56d 100644 --- a/src/models/_deprecated/baichuan.h +++ b/src/models/_deprecated/baichuan.h @@ -11,14 +11,14 @@ #include "layers/embedding.h" #include "layers/fused_linear.h" #include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // Baichuan model compatible with huggingface weights diff --git a/src/models/_deprecated/bloom.h b/src/models/_deprecated/bloom.h index 0499993b..3cb07b7b 100644 --- a/src/models/_deprecated/bloom.h +++ b/src/models/_deprecated/bloom.h @@ -8,14 +8,14 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // bloom model compatible with huggingface weights diff --git a/src/models/_deprecated/chatglm.h b/src/models/_deprecated/chatglm.h index cd8c71b3..d9729f96 100644 --- a/src/models/_deprecated/chatglm.h +++ b/src/models/_deprecated/chatglm.h @@ -9,14 +9,14 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" #include "tokenizer/tokenizer_args.h" // ChatGLM model compatible with huggingface weights diff --git a/src/models/_deprecated/gpt_j.h b/src/models/_deprecated/gpt_j.h index 3905860b..1485b1b7 100644 --- a/src/models/_deprecated/gpt_j.h +++ b/src/models/_deprecated/gpt_j.h @@ -7,14 +7,14 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // GPTJ model compatible with huggingface weights namespace llm::hf { diff --git a/src/models/_deprecated/gpt_neox.h b/src/models/_deprecated/gpt_neox.h index 65a2a1a3..e4ec964b 100644 --- a/src/models/_deprecated/gpt_neox.h +++ b/src/models/_deprecated/gpt_neox.h @@ -7,14 +7,14 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // gpt-neox model compatible with huggingface weights namespace llm::hf { diff --git a/src/models/_deprecated/internlm.h b/src/models/_deprecated/internlm.h index 8f100c32..df145482 100644 --- a/src/models/_deprecated/internlm.h +++ b/src/models/_deprecated/internlm.h @@ -9,14 +9,14 @@ #include "layers/embedding.h" #include "layers/fused_linear.h" #include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // Internlm model compatible with huggingface weights namespace llm::hf { diff --git a/src/models/_deprecated/mistral.h b/src/models/_deprecated/mistral.h index 9b8e8ad0..f7d340d7 100644 --- a/src/models/_deprecated/mistral.h +++ b/src/models/_deprecated/mistral.h @@ -7,15 +7,15 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "layers/qkv_linear.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // Mistral model compatible with huggingface weights namespace llm::hf { diff --git a/src/models/_deprecated/mpt.h b/src/models/_deprecated/mpt.h index b435b828..9e06c11b 100644 --- a/src/models/_deprecated/mpt.h +++ b/src/models/_deprecated/mpt.h @@ -10,14 +10,14 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // mpt model compatible with huggingface weights namespace llm::hf { diff --git a/src/models/_deprecated/simple_model.h b/src/models/_deprecated/simple_model.h index dc22d3a0..ed7000bb 100644 --- a/src/models/_deprecated/simple_model.h +++ b/src/models/_deprecated/simple_model.h @@ -7,13 +7,13 @@ #include "layers/embedding.h" #include "layers/fused_linear.h" #include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // simple model for test namespace llm { diff --git a/src/models/alibaba/qwen.h b/src/models/alibaba/qwen.h index 942a8a56..5f03c103 100644 --- a/src/models/alibaba/qwen.h +++ b/src/models/alibaba/qwen.h @@ -11,14 +11,14 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear/multi_parallel_linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // QWen model compatible with huggingface weights // Adapted from https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py namespace llm::hf { diff --git a/src/models/alibaba/qwen2.h b/src/models/alibaba/qwen2.h index 75fef6d5..41dcd301 100644 --- a/src/models/alibaba/qwen2.h +++ b/src/models/alibaba/qwen2.h @@ -11,14 +11,14 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear/qkv_parallel_linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // QWen2 model compatible with huggingface weights // ref to: // https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/qwen2/modeling_qwen2.py diff --git a/src/models/google/gemma.h b/src/models/google/gemma.h index 16b88549..d07bb248 100644 --- a/src/models/google/gemma.h +++ b/src/models/google/gemma.h @@ -12,14 +12,14 @@ #include "layers/embedding.h" #include "layers/linear/parallel_linear.h" #include "layers/linear/qkv_parallel_linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // Gemma model compatible with huggingface weight namespace llm::hf { diff --git a/src/models/google/gemma2.h b/src/models/google/gemma2.h index 37d580b3..200da7c9 100644 --- a/src/models/google/gemma2.h +++ b/src/models/google/gemma2.h @@ -11,14 +11,14 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear/qkv_parallel_linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // Gemma2 model compatible with huggingface weight namespace llm::hf { diff --git a/src/models/meta/llama.h b/src/models/meta/llama.h index 1e1bce35..22a0cd39 100644 --- a/src/models/meta/llama.h +++ b/src/models/meta/llama.h @@ -10,14 +10,14 @@ #include "layers/embedding.h" #include "layers/linear/multi_parallel_linear.h" #include "layers/linear/qkv_parallel_linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // llama2 model compatible with huggingface weights namespace llm::hf { diff --git a/src/models/microsoft/phi.h b/src/models/microsoft/phi.h index bfdce9b2..1aa97bb8 100644 --- a/src/models/microsoft/phi.h +++ b/src/models/microsoft/phi.h @@ -6,14 +6,14 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // Phi model compatible with huggingface weights namespace llm::hf { diff --git a/src/models/openai/gpt2.h b/src/models/openai/gpt2.h index 0d2df8aa..3cd6a167 100644 --- a/src/models/openai/gpt2.h +++ b/src/models/openai/gpt2.h @@ -8,14 +8,14 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear/parallel_linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // gpt2 model compatible with huggingface weights