diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e7548482..1b441180 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -3,9 +3,7 @@ add_subdirectory(common) add_subdirectory(handlers) add_subdirectory(kernels) add_subdirectory(tokenizer) -add_subdirectory(module) add_subdirectory(layers) -add_subdirectory(quantization) add_subdirectory(models) add_subdirectory(model_loader) add_subdirectory(model_parallel) diff --git a/src/engine/llm_engine.h b/src/engine/llm_engine.h index 4f47fede..b59213df 100644 --- a/src/engine/llm_engine.h +++ b/src/engine/llm_engine.h @@ -5,8 +5,8 @@ #include "batch.h" #include "common/macros.h" #include "engine.h" +#include "layers/quantization/quant_args.h" #include "memory/block_manager.h" -#include "quantization/quant_args.h" #include "tokenizer/tokenizer.h" #include "tokenizer/tokenizer_args.h" #include "worker.h" diff --git a/src/engine/worker.h b/src/engine/worker.h index 1d25a6ed..9a929b2b 100644 --- a/src/engine/worker.h +++ b/src/engine/worker.h @@ -4,6 +4,7 @@ #include #include "common/threadpool.h" +#include "layers/quantization/quant_args.h" #include "model_loader/state_dict.h" #include "model_parallel/parallel_args.h" #include "model_runner.h" @@ -11,7 +12,6 @@ #include "models/model_args.h" #include "models/parameters.h" #include "parameters.h" -#include "quantization/quant_args.h" namespace llm { diff --git a/src/engine/worker_test.cpp b/src/engine/worker_test.cpp index ae4d4a4f..4df260cb 100644 --- a/src/engine/worker_test.cpp +++ b/src/engine/worker_test.cpp @@ -4,9 +4,9 @@ #include "engine/batch.h" #include "engine/utils.h" +#include "layers/quantization/quant_args.h" #include "memory/block_manager.h" #include "models/simple_model.h" -#include "quantization/quant_args.h" namespace llm { class TestableWorker { diff --git a/src/layers/CMakeLists.txt b/src/layers/CMakeLists.txt index f178ce47..abd44bbd 100644 --- a/src/layers/CMakeLists.txt +++ b/src/layers/CMakeLists.txt @@ -1,32 +1,6 @@ include(cc_library) include(cc_test) -cc_library( - NAME - linear - HDRS - linear.h - qkv_linear.h - linear_impl.h - fused_linear.h - weight_utils.h - SRCS - linear.cpp - qkv_linear.cpp - linear_impl.cpp - fused_linear.cpp - weight_utils.cpp - DEPS - :state_dict - :model_parallel - :quantization - :kernels - :module - glog::glog - gflags::gflags - torch -) - cc_library( NAME pos_embedding @@ -73,8 +47,6 @@ cc_test( activation_test.cpp pos_embedding_test.cpp normalization_test.cpp - linear_test.cpp - qkv_linear_test.cpp DEPS :layers :state_dict @@ -82,5 +54,8 @@ cc_test( :gtest_main ) +add_subdirectory(module) +add_subdirectory(linear) +add_subdirectory(quantization) add_subdirectory(attention) add_subdirectory(moe) diff --git a/src/layers/attention/attention.h b/src/layers/attention/attention.h index 64c6bb5e..663fea71 100644 --- a/src/layers/attention/attention.h +++ b/src/layers/attention/attention.h @@ -3,10 +3,10 @@ #include #include "layers/attention/handler.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" #include "memory/kv_cache.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" namespace llm { diff --git a/src/layers/embedding.h b/src/layers/embedding.h index 9ef07033..fbe1ddf0 100644 --- a/src/layers/embedding.h +++ b/src/layers/embedding.h @@ -5,10 +5,10 @@ #include +#include "layers/module/module.h" +#include "layers/module/module_holder.h" #include "model_loader/state_dict.h" #include "model_parallel/model_parallel.h" -#include "module/module.h" -#include 
"module/module_holder.h" namespace llm { diff --git a/src/layers/fused_linear.cpp b/src/layers/fused_linear.cpp deleted file mode 100644 index 86b5ca59..00000000 --- a/src/layers/fused_linear.cpp +++ /dev/null @@ -1,102 +0,0 @@ -#include "fused_linear.h" - -#include -#include - -#include "linear.h" -#include "model_loader/state_dict.h" -#include "model_parallel/parallel_args.h" -#include "quantization/quant_args.h" - -namespace llm { - -FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl( - int64_t in_features, - const std::vector& out_features_vec, - const std::vector& prefixes, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) { - prefixes_ = prefixes; - // check if the linear layers can be fused - fused_ = quant_args.can_be_fused(); - if (fused_) { - // fused linear layer - const int64_t out_features = std::accumulate( - out_features_vec.begin(), out_features_vec.end(), int64_t(0)); - fused_linear_ = ColumnParallelLinear(in_features, - out_features, - bias, - gather_output, - quant_args, - parallel_args, - options); - // calculate split sizes - split_sizes_.reserve(out_features_vec.size()); - const auto world_size = parallel_args.world_size(); - for (const auto& out_features : out_features_vec) { - CHECK(out_features % world_size == 0) - << "out_features " << out_features << " not divisible by world_size " - << world_size; - split_sizes_.push_back(out_features / world_size); - } - } else { - // non-fused linear layers - parallel_linears_.reserve(out_features_vec.size()); - for (const auto& out_features : out_features_vec) { - parallel_linears_.emplace_back(in_features, - out_features, - bias, - gather_output, - quant_args, - parallel_args, - options); - } - } -} - -std::vector FusedColumnParallelLinearImpl::forward( - torch::Tensor input) { - if (fused_) { - auto fused_output = fused_linear_->forward(input); - return fused_output.split(split_sizes_, /*dim=*/1); - } - - // otherwise, use the non-fused linear layers - std::vector outputs; - outputs.reserve(parallel_linears_.size()); - for (auto& parallel_linear : parallel_linears_) { - auto output = parallel_linear->forward(input); - outputs.push_back(output); - } - return outputs; -} - -size_t FusedColumnParallelLinearImpl::load(const StateDict& state_dict, - const std::string&) { - if (fused_) { - fused_linear_->load_state_dict(state_dict, prefixes_); - } else { - CHECK_EQ(parallel_linears_.size(), prefixes_.size()); - for (size_t i = 0; i < parallel_linears_.size(); ++i) { - parallel_linears_[i]->load_state_dict(state_dict.select(prefixes_[i])); - } - } - return 0; -} - -bool FusedColumnParallelLinearImpl::verify( - const std::string& name_prefix) const { - if (fused_) { - fused_linear_->verify_loaded_weights(name_prefix); - } else { - for (const auto& parallel_linear : parallel_linears_) { - parallel_linear->verify_loaded_weights(name_prefix); - } - } - return true; -} - -} // namespace llm diff --git a/src/layers/fused_linear.h b/src/layers/fused_linear.h deleted file mode 100644 index 73323479..00000000 --- a/src/layers/fused_linear.h +++ /dev/null @@ -1,56 +0,0 @@ -#pragma once - -#include -#include - -#include "linear.h" -#include "model_loader/state_dict.h" -#include "model_parallel/parallel_args.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "quantization/quant_args.h" - -namespace llm { - -class FusedColumnParallelLinearImpl : public Module { - public: - FusedColumnParallelLinearImpl(int64_t 
in_features, - const std::vector& out_features, - const std::vector& prefixes, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); - - std::vector forward(torch::Tensor input); - - // load weights from the checkpoint, override this method if necessary - // returns the number of loaded parameters - size_t load(const StateDict& state_dict, - const std::string& name_prefix = std::string()) override; - - // verify whether the weights are loaded, override this method if necessary - bool verify(const std::string& name_prefix = std::string()) const override; - - // whether the linear layer is fused - bool fused() const { return fused_; } - - private: - // non-fused linear layers - std::vector parallel_linears_; - - // fused linear layer - ColumnParallelLinear fused_linear_{nullptr}; - - // sizes for each split - std::vector split_sizes_; - - std::vector prefixes_; - - // whether the linear layer is fused - bool fused_ = false; -}; -LLM_MODULE(FusedColumnParallelLinear); - -} // namespace llm diff --git a/src/layers/linear.h b/src/layers/linear.h deleted file mode 100644 index 99da103b..00000000 --- a/src/layers/linear.h +++ /dev/null @@ -1,80 +0,0 @@ -#pragma once - -#include -#include - -#include "model_loader/state_dict.h" -#include "model_parallel/parallel_args.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "quantization/quant_args.h" - -namespace llm { - -using TensorTransform = std::function; - -// an interface for parallel linear layer. -// all linear classes should inherit from this class and implement the forward -// function. -class ParallelLinearImpl : public Module { - public: - ~ParallelLinearImpl() override = default; - - virtual torch::Tensor forward(torch::Tensor input) = 0; - - virtual void load_state_dict(const StateDict& state_dict) = 0; - - virtual void verify_loaded_weights(const std::string& prefix = "") const = 0; - - // load state dict with a transform function - virtual void load_state_dict(const StateDict& /*state_dict*/, - TensorTransform /*transform_func*/) { - LOG(FATAL) << "not implemented"; - } - - // special load_state_dict for fused cases - virtual void load_state_dict(const StateDict& /*state_dict*/, - const std::vector& /*prefixes*/) { - LOG(FATAL) << "not implemented"; - } -}; - -class ColumnParallelLinear : public ModuleHolder { - public: - using ModuleHolder::ModuleHolder; - using Impl [[maybe_unused]] = ParallelLinearImpl; - - // construct a rotary positional embedding. - // chose right implementation based on the args. - ColumnParallelLinear(int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); - - ColumnParallelLinear(int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); -}; - -class RowParallelLinear : public ModuleHolder { - public: - using ModuleHolder::ModuleHolder; - using Impl [[maybe_unused]] = ParallelLinearImpl; - - // construct a rotary positional embedding. - // chose right implementation based on the args. 
- RowParallelLinear(int64_t in_features, - int64_t out_features, - bool bias, - bool input_is_parallelized, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); -}; -} // namespace llm diff --git a/src/layers/linear/CMakeLists.txt b/src/layers/linear/CMakeLists.txt new file mode 100644 index 00000000..8c0a8d8a --- /dev/null +++ b/src/layers/linear/CMakeLists.txt @@ -0,0 +1,40 @@ +include(cc_library) +include(cc_test) + +cc_library( + NAME + linear + HDRS + parallel_linear.h + multi_parallel_linear.h + qkv_parallel_linear.h + weight_utils.h + SRCS + parallel_linear.cpp + qkv_parallel_linear.cpp + multi_parallel_linear.cpp + weight_utils.cpp + DEPS + :state_dict + :model_parallel + :quantization + :kernels + :module + glog::glog + gflags::gflags + torch +) + +cc_test( + NAME + linear_test + SRCS + parallel_linear_test.cpp + multi_parallel_linear_test.cpp + qkv_parallel_linear_test.cpp + DEPS + :linear + :state_dict + absl::random_random + :gtest_main +) diff --git a/src/layers/linear/multi_parallel_linear.cpp b/src/layers/linear/multi_parallel_linear.cpp new file mode 100644 index 00000000..fb2f2a45 --- /dev/null +++ b/src/layers/linear/multi_parallel_linear.cpp @@ -0,0 +1,172 @@ +#include "multi_parallel_linear.h" + +#include +#include + +#include "layers/quantization/quant_args.h" +#include "model_parallel/model_parallel.h" +#include "model_parallel/parallel_args.h" +#include "parallel_linear.h" + +namespace llm { +namespace { + +std::shared_ptr create_multi_column_parallel_linear( + int64_t in_features, + const std::vector& out_features, + const std::vector& prefixes, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + // check if the linear layers can be fused + if (quant_args.can_be_fused()) { + return std::make_shared(in_features, + out_features, + prefixes, + bias, + gather_output, + parallel_args, + options); + } + + return std::make_shared(in_features, + out_features, + prefixes, + bias, + gather_output, + parallel_args, + options); +} +} // namespace + +FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl( + int64_t in_features, + const std::vector& out_features_vec, + const std::vector& prefixes, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) + : gather_output_(gather_output), parallel_args_(parallel_args) { + const auto rank = parallel_args_.rank(); + const auto world_size = parallel_args_.world_size(); + + // calculate split size for each prefix + split_sizes_.reserve(out_features_vec.size()); + for (const auto& out_features : out_features_vec) { + CHECK(out_features % world_size == 0) + << "out_features " << out_features << " not divisible by world_size " + << world_size; + split_sizes_.push_back(out_features / world_size); + } + + const int64_t fused_out_features = + std::accumulate(split_sizes_.begin(), split_sizes_.end(), int64_t(0)); + + // allocate fused weight + weight_ = torch::empty({fused_out_features, in_features}, options); + const auto weights = weight_.split(split_sizes_, /*dim=*/0); + // register sharded weights for each prefix + for (size_t i = 0; i < prefixes.size(); ++i) { + const auto& prefix = prefixes[i]; + const auto& weight = weights[i]; + // register the weight as a parameter to make sure it is moved to the + register_sharded_parameter(detail::join_name(prefix, "weight"), + /*dim=*/0, + rank, + world_size, + weight); + } + + if (bias) { + 
bias_ = torch::empty({fused_out_features}, options); + const auto biases = bias_.split(split_sizes_, /*dim=*/0); + + // register sharded weights for each prefix + for (size_t i = 0; i < prefixes.size(); ++i) { + const auto& prefix = prefixes[i]; + const auto& bias = biases[i]; + register_sharded_parameter(detail::join_name(prefix, "bias"), + /*dim=*/0, + rank, + world_size, + bias); + } + } +} + +std::vector FusedColumnParallelLinearImpl::forward( + torch::Tensor input) { + namespace F = torch::nn::functional; + auto output = F::linear(input, weight_, bias_); + if (parallel_args_.world_size() > 1 && gather_output_) { + output = gather_from_model_parallel_region(output, parallel_args_); + } + return output.split(split_sizes_, /*dim=*/1); +} + +GroupedColumnParallelLinearImpl::GroupedColumnParallelLinearImpl( + int64_t in_features, + const std::vector& out_features_vec, + const std::vector& prefixes, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + // register linear layers one by one + parallel_linears_.reserve(out_features_vec.size()); + for (size_t i = 0; i < out_features_vec.size(); ++i) { + const auto& prefix = prefixes[i]; + const auto out_features = out_features_vec[i]; + const auto linear = register_module( + "linear_" + std::to_string(i), + std::make_shared(in_features, + out_features, + bias, + gather_output, + parallel_args, + options, + prefix), + /*selector=*/nullptr); + + parallel_linears_.emplace_back(linear); + } +} + +std::vector GroupedColumnParallelLinearImpl::forward( + torch::Tensor input) { + std::vector outputs; + outputs.reserve(parallel_linears_.size()); + for (auto& parallel_linear : parallel_linears_) { + outputs.push_back(parallel_linear->forward(input)); + } + return outputs; +} + +MultiColumnParallelLinear::MultiColumnParallelLinear(std::nullptr_t) + : ModuleHolder(nullptr) {} + +MultiColumnParallelLinear::MultiColumnParallelLinear( + std::shared_ptr module) + : ModuleHolder(std::move(module)) {} + +MultiColumnParallelLinear::MultiColumnParallelLinear( + int64_t in_features, + const std::vector& out_features, + const std::vector& prefixes, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) + : ModuleHolder(create_multi_column_parallel_linear(in_features, + out_features, + prefixes, + bias, + gather_output, + quant_args, + parallel_args, + options)) {} +} // namespace llm diff --git a/src/layers/linear/multi_parallel_linear.h b/src/layers/linear/multi_parallel_linear.h new file mode 100644 index 00000000..42a2295e --- /dev/null +++ b/src/layers/linear/multi_parallel_linear.h @@ -0,0 +1,90 @@ +#pragma once + +#include +#include + +// #include "linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/quantization/quant_args.h" +#include "model_parallel/parallel_args.h" +#include "parallel_linear.h" + +namespace llm { + +class MultiParallelLinearImpl : public Module { + public: + ~MultiParallelLinearImpl() override = default; + + virtual std::vector forward(torch::Tensor input) = 0; +}; + +// Fused linear layer with column parallelism. 
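+// Sketch of the fused design (per the .cpp above): rather than one
+// ColumnParallelLinear per prefix, a single fused weight of shape
+// [sum(out_features_i) / world_size, in_features] is allocated; its per-prefix
+// views (weight_.split(split_sizes_, /*dim=*/0)) are registered as sharded
+// parameters, so one matmul serves all prefixes and forward() splits the
+// result back into one output per prefix.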
+class FusedColumnParallelLinearImpl : public MultiParallelLinearImpl { + public: + FusedColumnParallelLinearImpl(int64_t in_features, + const std::vector& out_features, + const std::vector& prefixes, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); + + std::vector forward(torch::Tensor input) override; + + // return the weight (for testing) + torch::Tensor weight() const { return weight_; } + + private: + // parameter members, must be registered + // we allocate the transpose since linear performs XA^T. + // A^T: [out_features_per_partition, in_features] + torch::Tensor weight_; + torch::Tensor bias_; + + std::vector split_sizes_; + + // whether to gather the output + bool gather_output_; + + // parallel args + ParallelArgs parallel_args_; +}; +LLM_MODULE(FusedColumnParallelLinear); + +class GroupedColumnParallelLinearImpl : public MultiParallelLinearImpl { + public: + GroupedColumnParallelLinearImpl(int64_t in_features, + const std::vector& out_features, + const std::vector& prefixes, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); + + std::vector forward(torch::Tensor input) override; + + private: + // parameter members, must be registered + std::vector> parallel_linears_; +}; +LLM_MODULE(GroupedColumnParallelLinear); + +class MultiColumnParallelLinear : public ModuleHolder { + public: + /* implicit */ MultiColumnParallelLinear(std::nullptr_t); + + /* implicit */ MultiColumnParallelLinear( + std::shared_ptr module); + + MultiColumnParallelLinear(int64_t in_features, + const std::vector& out_features, + const std::vector& prefixes, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); +}; + +} // namespace llm diff --git a/src/layers/linear/multi_parallel_linear_test.cpp b/src/layers/linear/multi_parallel_linear_test.cpp new file mode 100644 index 00000000..241c32c9 --- /dev/null +++ b/src/layers/linear/multi_parallel_linear_test.cpp @@ -0,0 +1,197 @@ +#include "multi_parallel_linear.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "model_loader/state_dict.h" + +namespace llm { + +TEST(MultiParallelLinearTest, FusedColumnParallelLinear) { + // test load state dict for linear + const int64_t in_features = 10; + const int64_t out_features = 40; + + torch::Device device(torch::kCPU); + torch::ScalarType dtype(torch::kFloat); + const auto options = torch::dtype(dtype).device(device); + + std::vector out_features_vec = { + out_features, out_features, out_features}; + std::vector prefixes = {"query.", "key.", "value."}; + + std::unordered_map state_dict_data; + // Allocate transposed weight matrix + state_dict_data["query.weight"] = torch::randn({out_features, in_features}); + state_dict_data["key.weight"] = torch::randn({out_features, in_features}); + state_dict_data["value.weight"] = torch::randn({out_features, in_features}); + + // weight is not sharded + StateDict state_dict(state_dict_data); + + // test load weight + { + ParallelArgs parallel_args(0, 1, nullptr); + FusedColumnParallelLinearImpl linear(in_features, + out_features_vec, + prefixes, + /*bias=*/false, + /*gather_output=*/false, + parallel_args, + options); + // test load fused weight + EXPECT_EQ(linear.load(state_dict), 3); + + for (const auto& prefix : prefixes) { + auto named_parameters = linear.named_parameters(/*recurse=*/false); + const auto key = 
detail::join_name(prefix, "weight"); + ASSERT_TRUE(named_parameters.contains(key)); + + const auto& loaded_weight = named_parameters[key]; + EXPECT_EQ(loaded_weight.sizes(), + torch::IntArrayRef({out_features, in_features})); + EXPECT_TRUE(torch::equal(loaded_weight, state_dict_data[key])); + } + + // verify the fused weight + const auto loaded_fused_weight = linear.weight(); + const auto desired_fused_weight = + torch::cat({state_dict_data["query.weight"], + state_dict_data["key.weight"], + state_dict_data["value.weight"]}, + /*dim=*/0); + EXPECT_TRUE(torch::equal(loaded_fused_weight, desired_fused_weight)); + } + + // test load weight with 4 shards + const int32_t num_shards = 4; + for (int32_t shard_id = 0; shard_id < num_shards; ++shard_id) { + ParallelArgs parallel_args_0(shard_id, num_shards, nullptr); + FusedColumnParallelLinearImpl linear(in_features, + out_features_vec, + prefixes, + /*bias=*/false, + /*gather_output=*/false, + parallel_args_0, + options); + EXPECT_EQ(linear.load(state_dict), 3); + + auto named_parameters = linear.named_parameters(/*recurse=*/false); + + // check size for each prefix + for (const auto& prefix : prefixes) { + const auto key = detail::join_name(prefix, "weight"); + ASSERT_TRUE(named_parameters.contains(key)); + + const auto& loaded_weight = named_parameters[key]; + EXPECT_EQ(loaded_weight.sizes(), + torch::IntArrayRef({out_features / num_shards, in_features})); + EXPECT_TRUE(torch::equal( + loaded_weight, state_dict_data[key].chunk(num_shards, 0)[shard_id])); + } + + // shard weight then cat + auto sharded_query_weight = + state_dict_data["query.weight"].chunk(num_shards, 0)[shard_id]; + auto sharded_key_weight = + state_dict_data["key.weight"].chunk(num_shards, 0)[shard_id]; + auto sharded_value_weight = + state_dict_data["value.weight"].chunk(num_shards, 0)[shard_id]; + + // verify the fused weight + const auto loaded_fused_weight = linear.weight(); + auto desired_fused_weight = torch::cat( + {sharded_query_weight, sharded_key_weight, sharded_value_weight}, + /*dim=*/0); + + EXPECT_TRUE(torch::equal(loaded_fused_weight, desired_fused_weight)); + } +} + +TEST(MultiParallelLinearTest, GroupedColumnParallelLinear) { + const int64_t in_features = 10; + const int64_t out_features = 40; + std::vector out_features_vec = { + out_features, out_features, out_features}; + std::vector prefixes = {"query.", "key.", "value."}; + + torch::Device device(torch::kCPU); + torch::ScalarType dtype(torch::kFloat); + const auto options = torch::dtype(dtype).device(device); + + std::unordered_map state_dict_data; + // Allocate transposed weight matrix + state_dict_data["query.weight"] = torch::randn({out_features, in_features}); + state_dict_data["key.weight"] = torch::randn({out_features, in_features}); + state_dict_data["value.weight"] = torch::randn({out_features, in_features}); + // weight is not sharded + StateDict state_dict(state_dict_data); + + // test load weight + { + ParallelArgs parallel_args(0, 1, nullptr); + GroupedColumnParallelLinearImpl linear(in_features, + out_features_vec, + prefixes, + /*bias=*/false, + /*gather_output=*/false, + parallel_args, + options); + // test load grouped weight + EXPECT_EQ(linear.load(state_dict), 3); + + auto named_parameters = linear.named_parameters(/*recurse=*/true); + for (size_t i = 0; i < prefixes.size(); ++i) { + const auto prefix = "linear_" + std::to_string(i) + "." 
+ prefixes[i]; + const auto key = detail::join_name(prefix, "weight"); + ASSERT_TRUE(named_parameters.contains(key)); + + const auto& loaded_weight = named_parameters[key]; + + const auto sd_key = detail::join_name(prefixes[i], "weight"); + + EXPECT_EQ(loaded_weight.sizes(), + torch::IntArrayRef({out_features, in_features})); + EXPECT_TRUE(torch::equal(loaded_weight, state_dict_data[sd_key])); + } + } + + // test load weight with 4 shards + const int32_t num_shards = 4; + for (int32_t shard_id = 0; shard_id < num_shards; ++shard_id) { + ParallelArgs parallel_args(shard_id, num_shards, nullptr); + GroupedColumnParallelLinearImpl linear(in_features, + out_features_vec, + prefixes, + /*bias=*/false, + /*gather_output=*/false, + parallel_args, + options); + EXPECT_EQ(linear.load(state_dict), 3); + auto named_parameters = linear.named_parameters(/*recurse=*/true); + // check size for each prefix + for (size_t i = 0; i < prefixes.size(); ++i) { + const auto prefix = "linear_" + std::to_string(i) + "." + prefixes[i]; + const auto key = detail::join_name(prefix, "weight"); + ASSERT_TRUE(named_parameters.contains(key)); + + const auto& loaded_weight = named_parameters[key]; + EXPECT_EQ(loaded_weight.sizes(), + torch::IntArrayRef({out_features / num_shards, in_features})); + const auto sd_key = detail::join_name(prefixes[i], "weight"); + EXPECT_TRUE( + torch::equal(loaded_weight, + state_dict_data[sd_key].chunk(num_shards, 0)[shard_id])); + } + } +} + +} // namespace llm diff --git a/src/layers/linear.cpp b/src/layers/linear/parallel_linear.cpp similarity index 67% rename from src/layers/linear.cpp rename to src/layers/linear/parallel_linear.cpp index 70a6d6f8..ed591457 100644 --- a/src/layers/linear.cpp +++ b/src/layers/linear/parallel_linear.cpp @@ -1,17 +1,19 @@ -#include "linear.h" +#include "parallel_linear.h" +#include #include #include #include #include -#include "linear_impl.h" -#include "quantization/qlinear_awq_impl.h" -#include "quantization/qlinear_awq_marlin_impl.h" -#include "quantization/qlinear_exllamav2_impl.h" -#include "quantization/qlinear_gptq_impl.h" -#include "quantization/qlinear_gptq_marlin_impl.h" +#include "layers/module/module.h" +#include "layers/quantization/qlinear_awq_impl.h" +#include "layers/quantization/qlinear_awq_marlin_impl.h" +#include "layers/quantization/qlinear_exllamav2_impl.h" +#include "layers/quantization/qlinear_gptq_impl.h" +#include "layers/quantization/qlinear_gptq_marlin_impl.h" +#include "model_parallel/model_parallel.h" DEFINE_string( qlinear_gptq_impl, @@ -38,18 +40,6 @@ namespace { parallel_args, \ options); -#define MAKE_ROW_PARALLEL_LINEAR(LinearlImplClass) \ - std::make_shared(in_features, \ - out_features, \ - bias, \ - input_is_parallelized, \ - parallel_args, \ - options); - -#define MAKE_COLUMN_PARALLEL_LINEAR(LinearlImplClass) \ - std::make_shared( \ - in_features, out_features, bias, gather_output, parallel_args, options); - std::shared_ptr create_column_parallel_qlinear_by_impl( int64_t in_features, int64_t out_features, @@ -139,6 +129,7 @@ std::shared_ptr create_column_parallel_qlinear( } // not supported quant method LOG(FATAL) << "Unsupported quant method: " << quant_args.quant_method(); + return nullptr; } std::shared_ptr create_row_parallel_qlinear( @@ -170,6 +161,7 @@ std::shared_ptr create_row_parallel_qlinear( } // not supported quant method LOG(FATAL) << "Unsupported quant method: " << quant_args.quant_method(); + return nullptr; } std::shared_ptr create_column_parallel_linear( @@ -179,7 +171,8 @@ std::shared_ptr 
create_column_parallel_linear( bool gather_output, const QuantArgs& quant_args, const ParallelArgs& parallel_args, - const torch::TensorOptions& options) { + const torch::TensorOptions& options, + const std::string& prefix) { if (!quant_args.quant_method().empty()) { return create_column_parallel_qlinear(in_features, out_features, @@ -189,7 +182,13 @@ std::shared_ptr create_column_parallel_linear( parallel_args, options); } - return MAKE_COLUMN_PARALLEL_LINEAR(ColumnParallelLinearImpl); + return std ::make_shared(in_features, + out_features, + bias, + gather_output, + parallel_args, + options, + prefix); } std::shared_ptr create_row_parallel_linear( @@ -209,26 +208,130 @@ std::shared_ptr create_row_parallel_linear( parallel_args, options); } - return MAKE_ROW_PARALLEL_LINEAR(RowParallelLinearImpl); + return std ::make_shared(in_features, + out_features, + bias, + input_is_parallelized, + parallel_args, + options); } } // namespace +// Linear layer with column parallelism. +ColumnParallelLinearImpl::ColumnParallelLinearImpl( + int64_t in_features, + int64_t out_features, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options, + const std::string& prefix) + : gather_output_(gather_output), parallel_args_(parallel_args) { + const auto rank = parallel_args_.rank(); + const auto world_size = parallel_args_.world_size(); + CHECK(out_features % world_size == 0) + << "out_features " << out_features << " not divisible by world_size " + << world_size; + const int64_t out_features_per_partition = out_features / world_size; + + // Note: torch.nn.functional.linear performs XA^T + b and as a result + // we allocate the transpose. + weight_ = register_sharded_parameter( + detail::join_name(prefix, "weight"), + /*dim=*/0, + rank, + world_size, + torch::empty({out_features_per_partition, in_features}, options)); + + if (bias) { + bias_ = register_sharded_parameter( + detail::join_name(prefix, "bias"), + /*dim=*/0, + rank, + world_size, + torch::empty({out_features_per_partition}, options)); + } +} + +torch::Tensor ColumnParallelLinearImpl::forward(torch::Tensor input) { + namespace F = torch::nn::functional; + auto output = F::linear(input, weight_, bias_); + if (parallel_args_.world_size() > 1 && gather_output_) { + output = gather_from_model_parallel_region(output, parallel_args_); + } + return output; +} + +// Linear layer with row parallelism. +RowParallelLinearImpl::RowParallelLinearImpl( + int64_t in_features, + int64_t out_features, + bool bias, + bool input_is_parallelized, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) + : input_is_parallelized_(input_is_parallelized), + parallel_args_(parallel_args) { + const auto rank = parallel_args_.rank(); + const auto world_size = parallel_args_.world_size(); + CHECK(in_features % world_size == 0) + << "in_features " << in_features << " not divisible by world_size " + << world_size; + const int64_t in_features_per_partition = in_features / world_size; + // Allocate the transpose since linear performs XA^T. 
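+ // A^T has shape [out_features, in_features_per_partition]: each rank keeps its
+ // slice of A's rows, so the weight below is registered as sharded along dim 1
+ // of the transposed matrix.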
+ weight_ = register_sharded_parameter( + "weight", + /*dim=*/1, + rank, + world_size, + torch::empty({out_features, in_features_per_partition}, options)); + + if (bias) { + bias_ = register_parameter("bias", torch::empty({out_features}, options)); + } +} + +torch::Tensor RowParallelLinearImpl::forward(torch::Tensor input) { + namespace F = torch::nn::functional; + if (!input_is_parallelized_) { + input = scatter_to_model_parallel_region(input, parallel_args_); + } + auto output = F::linear(input, weight_); + if (parallel_args_.world_size() > 1) { + output = reduce_from_model_parallel_region(output, parallel_args_); + } + // N.B. need to apply bias after the reduce + if (bias_.defined()) { + output.add_(bias_); + } + return output; +} + // construct a ColumnParallelLinear. // chose right implementation based on the args. +ColumnParallelLinear::ColumnParallelLinear(std::nullptr_t) + : ModuleHolder(nullptr) {} + +ColumnParallelLinear::ColumnParallelLinear( + std::shared_ptr module) + : ModuleHolder(std::move(module)) {} + ColumnParallelLinear::ColumnParallelLinear(int64_t in_features, int64_t out_features, bool bias, bool gather_output, const QuantArgs& quant_args, const ParallelArgs& parallel_args, - const torch::TensorOptions& options) + const torch::TensorOptions& options, + const std::string& prefix) : ModuleHolder(create_column_parallel_linear(in_features, out_features, bias, gather_output, quant_args, parallel_args, - options)) {} + options, + prefix)) {} ColumnParallelLinear::ColumnParallelLinear(int64_t in_features, int64_t out_features, @@ -242,10 +345,16 @@ ColumnParallelLinear::ColumnParallelLinear(int64_t in_features, gather_output, {}, /*quant_args*/ parallel_args, - options)) {} + options, + "")) {} -// construct a rotary positional embedding. +// construct a row parallel linear. // chose right implementation based on the args. +RowParallelLinear::RowParallelLinear(std::nullptr_t) : ModuleHolder(nullptr) {} + +RowParallelLinear::RowParallelLinear(std::shared_ptr module) + : ModuleHolder(std::move(module)) {} + RowParallelLinear::RowParallelLinear(int64_t in_features, int64_t out_features, bool bias, diff --git a/src/layers/linear/parallel_linear.h b/src/layers/linear/parallel_linear.h new file mode 100644 index 00000000..d1ebee68 --- /dev/null +++ b/src/layers/linear/parallel_linear.h @@ -0,0 +1,152 @@ +#pragma once + +#include +#include + +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/quantization/quant_args.h" +#include "model_loader/state_dict.h" +#include "model_parallel/parallel_args.h" + +namespace llm { + +// an interface for parallel linear layer. +// all linear classes should inherit from this class and implement the forward +// function. +class ParallelLinearImpl : public Module { + public: + ~ParallelLinearImpl() override = default; + + virtual torch::Tensor forward(torch::Tensor input) = 0; + + // TODO: clean up the interface of load_state_dict + virtual void load_state_dict(const StateDict& state_dict) { + LOG(FATAL) << "not implemented"; + } + + virtual void verify_loaded_weights(const std::string& prefix = "") const { + LOG(FATAL) << "not implemented"; + } + + // special load_state_dict for fused cases + virtual void load_state_dict(const StateDict& /*state_dict*/, + const std::vector& /*prefixes*/) { + LOG(FATAL) << "not implemented"; + } +}; + +// Linear layer with column parallelism. +// The linear layer is defined as Y = XA + b. A is parallelized along +// its second dimension as A = [A_1, ..., A_p]. 
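+// For example, with world_size = 2 each rank stores A^T as a weight of shape
+// [out_features / 2, in_features] and computes its half of Y; when
+// gather_output is set, the partial results are concatenated along the last
+// dimension to recover the full Y.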
+class ColumnParallelLinearImpl : public ParallelLinearImpl { + public: + ColumnParallelLinearImpl(int64_t in_features, + int64_t out_features, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options, + const std::string& prefix = ""); + + torch::Tensor forward(torch::Tensor input) override; + + // return the weight (for testing) + torch::Tensor weight() const { return weight_; } + + private: + // parameter members, must be registered + // we allocate the transpose since linear performs XA^T. + // A^T: [out_features_per_partition, in_features] + torch::Tensor weight_; + torch::Tensor bias_; + + // whether to gather the output + bool gather_output_; + + // parallel args + ParallelArgs parallel_args_; +}; + +// Linear layer with row parallelism. +// The linear layer is defined as Y = XA + b. A is parallelized along +// its first dimension and X along its second dimension as: +// - - +// | A_1 | +// | . | +// A = | . | X = [X_1, ..., X_p] +// | . | +// | A_p | +// - - +class RowParallelLinearImpl : public ParallelLinearImpl { + public: + RowParallelLinearImpl(int64_t in_features, + int64_t out_features, + bool bias, + bool input_is_parallelized, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); + + torch::Tensor forward(torch::Tensor input) override; + + // return the weight (for testing) + torch::Tensor weight() const { return weight_; } + + private: + // parameter members, must be registered + // we allocate the transpose since linear performs XA^T. + // A^T: [out_features, in_features_per_partition] + torch::Tensor weight_; + torch::Tensor bias_; + + // whether the input is already parallelized + bool input_is_parallelized_; + + // parallel args + ParallelArgs parallel_args_; +}; + +class ColumnParallelLinear : public ModuleHolder { + public: + /* implicit */ ColumnParallelLinear(std::nullptr_t); + + /* implicit */ ColumnParallelLinear( + std::shared_ptr module); + + // construct a column parallel linear. + // choose the right implementation based on the args. + ColumnParallelLinear(int64_t in_features, + int64_t out_features, + bool bias, + bool gather_output, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options, + const std::string& prefix = ""); + + ColumnParallelLinear(int64_t in_features, + int64_t out_features, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); +}; + +class RowParallelLinear : public ModuleHolder { + public: + /* implicit */ RowParallelLinear(std::nullptr_t); + + /* implicit */ RowParallelLinear(std::shared_ptr module); + + // construct a row parallel linear. + // choose the right implementation based on the args. 
+ RowParallelLinear(int64_t in_features, + int64_t out_features, + bool bias, + bool input_is_parallelized, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); +}; + +} // namespace llm diff --git a/src/layers/linear_test.cpp b/src/layers/linear/parallel_linear_test.cpp similarity index 53% rename from src/layers/linear_test.cpp rename to src/layers/linear/parallel_linear_test.cpp index 3306ebf2..236b58d1 100644 --- a/src/layers/linear_test.cpp +++ b/src/layers/linear/parallel_linear_test.cpp @@ -1,19 +1,21 @@ +#include "parallel_linear.h" + #include #include #include #include #include +#include #include #include #include -#include "linear_impl.h" #include "model_loader/state_dict.h" namespace llm { -TEST(LinearTest, RowParallelLoadWeight) { +TEST(ParallelLinearTest, RowParallelLinear) { // test load state dict for row parallel linear const int64_t in_features = 10; const int64_t out_features = 20; @@ -40,7 +42,7 @@ TEST(LinearTest, RowParallelLoadWeight) { parallel_args, options); // test load state dict for transformer - linear.load_state_dict(state_dict); + EXPECT_EQ(linear.load(state_dict), 2); auto named_parameters = linear.named_parameters(/*recurse=*/false); EXPECT_TRUE(torch::equal(state_dict.get_tensor("weight"), named_parameters["weight"])); @@ -58,7 +60,7 @@ TEST(LinearTest, RowParallelLoadWeight) { /*input_is_parallelized=*/false, parallel_args, options); - linear.load_state_dict(state_dict); + EXPECT_EQ(linear.load(state_dict), 2); auto named_parameters = linear.named_parameters(/*recurse=*/false); @@ -78,31 +80,7 @@ TEST(LinearTest, RowParallelLoadWeight) { } } -TEST(LinearTest, RowParallelLoadFusedWeight) { - // test load state dict for row parallel linear - const int64_t in_features = 10; - const int64_t out_features = 20; - - torch::Device device(torch::kCPU); - torch::ScalarType dtype(torch::kFloat); - const auto options = torch::dtype(dtype).device(device); - StateDict state_dict({}); - - // test load weight - ParallelArgs parallel_args(0, 1, nullptr); - RowParallelLinearImpl linear(in_features, - out_features * 3, - /*bias=*/false, - /*input_is_parallelized=*/true, - parallel_args, - options); - // test load fused weight - EXPECT_DEATH( - linear.ParallelLinearImpl::load_state_dict(state_dict, {"query."}), - "not implemented"); -} - -TEST(LinearTest, ColumnParallelLoadWeight) { +TEST(ParallelLinearTest, ColumnParallelLinear) { // test load state dict for linear const int64_t in_features = 10; const int64_t out_features = 20; @@ -128,7 +106,7 @@ TEST(LinearTest, ColumnParallelLoadWeight) { parallel_args, options); // test load state dict for transformer - linear.load_state_dict(state_dict); + EXPECT_EQ(linear.load(state_dict), 1); auto named_parameters = linear.named_parameters(/*recurse=*/false); EXPECT_TRUE(torch::equal(state_dict.get_tensor("weight"), named_parameters["weight"])); @@ -143,7 +121,7 @@ TEST(LinearTest, ColumnParallelLoadWeight) { /*gather_output=*/false, parallel_args, options); - linear.load_state_dict(state_dict); + EXPECT_EQ(linear.load(state_dict), 1); auto named_parameters = linear.named_parameters(/*recurse=*/false); @@ -159,80 +137,4 @@ TEST(LinearTest, ColumnParallelLoadWeight) { } } -TEST(LinearTest, ColumnParallelLoadFusedWeight) { - // test load state dict for linear - const int64_t in_features = 10; - const int64_t out_features = 20; - - torch::Device device(torch::kCPU); - torch::ScalarType dtype(torch::kFloat); - const auto options = torch::dtype(dtype).device(device); - 
std::unordered_map state_dict_data; - // Allocate transposed weight matrix - state_dict_data["query.weight"] = torch::randn({out_features, in_features}); - state_dict_data["key.weight"] = torch::randn({out_features, in_features}); - state_dict_data["value.weight"] = torch::randn({out_features, in_features}); - - // weight is not sharded - StateDict state_dict(state_dict_data); - - // test load weight - { - ParallelArgs parallel_args(0, 1, nullptr); - ColumnParallelLinearImpl linear(in_features, - out_features * 3, - /*bias=*/false, - /*gather_output=*/false, - parallel_args, - options); - // test load fused weight - linear.load_state_dict(state_dict, {"query.", "key.", "value."}); - - auto named_parameters = linear.named_parameters(/*recurse=*/false); - ASSERT_TRUE(named_parameters.contains("weight")); - - const auto loaded_weight = named_parameters["weight"]; - EXPECT_EQ(loaded_weight.sizes(), - torch::IntArrayRef({3 * out_features, in_features})); - - auto desired_weight = torch::cat({state_dict_data["query.weight"], - state_dict_data["key.weight"], - state_dict_data["value.weight"]}, - /*dim=*/0); - EXPECT_TRUE(torch::equal(loaded_weight, desired_weight)); - } - - // test load weight with 2 shards - const int32_t num_shards = 2; - for (int32_t shard_id = 0; shard_id < num_shards; ++shard_id) { - ParallelArgs parallel_args_0(shard_id, num_shards, nullptr); - ColumnParallelLinearImpl linear(in_features, - out_features * 3, - /*bias=*/false, - /*gather_output=*/false, - parallel_args_0, - options); - linear.load_state_dict(state_dict, {"query.", "key.", "value."}); - - auto named_parameters = linear.named_parameters(/*recurse=*/false); - const auto loaded_weight = named_parameters["weight"]; - - EXPECT_EQ(loaded_weight.sizes(), - torch::IntArrayRef({3 * out_features / num_shards, in_features})); - - // shard weight then cat - auto query_weight = state_dict_data["query.weight"].chunk( - /*chunks=*/num_shards, /*dim=*/0)[shard_id]; - auto key_weight = state_dict_data["key.weight"].chunk(/*chunks=*/num_shards, - /*dim=*/0)[shard_id]; - auto value_weight = state_dict_data["value.weight"].chunk( - /*chunks=*/num_shards, /*dim=*/0)[shard_id]; - - auto desired_weight = torch::cat({query_weight, key_weight, value_weight}, - /*dim=*/0); - - EXPECT_TRUE(torch::equal(loaded_weight, desired_weight)); - } -} - } // namespace llm diff --git a/src/layers/qkv_linear.cpp b/src/layers/linear/qkv_parallel_linear.cpp similarity index 97% rename from src/layers/qkv_linear.cpp rename to src/layers/linear/qkv_parallel_linear.cpp index 8ef9b532..055c8ec0 100644 --- a/src/layers/qkv_linear.cpp +++ b/src/layers/linear/qkv_parallel_linear.cpp @@ -1,4 +1,4 @@ -#include "qkv_linear.h" +#include "qkv_parallel_linear.h" #include #include @@ -70,7 +70,7 @@ QKVColumnParallelLinearImpl::QKVColumnParallelLinearImpl( }; parallel_linear_ = register_module("qkv_parallel_linear", - FusedColumnParallelLinear(hidden_size, + MultiColumnParallelLinear(hidden_size, out_features, prefixes, bias, diff --git a/src/layers/qkv_linear.h b/src/layers/linear/qkv_parallel_linear.h similarity index 69% rename from src/layers/qkv_linear.h rename to src/layers/linear/qkv_parallel_linear.h index 64923e27..43bff377 100644 --- a/src/layers/qkv_linear.h +++ b/src/layers/linear/qkv_parallel_linear.h @@ -3,12 +3,12 @@ #include #include -#include "fused_linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/quantization/quant_args.h" #include "model_loader/state_dict.h" #include 
"model_parallel/parallel_args.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "quantization/quant_args.h" +#include "multi_parallel_linear.h" namespace llm { @@ -27,13 +27,16 @@ class QKVColumnParallelLinearImpl : public Module { const ParallelArgs& parallel_args, const torch::TensorOptions& options); - std::vector forward(torch::Tensor input) { - return parallel_linear_->forward(input); + // returns (query, key, value) + std::tuple forward( + torch::Tensor input) { + const auto qkv = parallel_linear_->forward(input); + return {qkv[0], qkv[1], qkv[2]}; } private: // registered modules - FusedColumnParallelLinear parallel_linear_{nullptr}; + MultiColumnParallelLinear parallel_linear_{nullptr}; }; LLM_MODULE(QKVColumnParallelLinear); diff --git a/src/layers/qkv_linear_test.cpp b/src/layers/linear/qkv_parallel_linear_test.cpp similarity index 91% rename from src/layers/qkv_linear_test.cpp rename to src/layers/linear/qkv_parallel_linear_test.cpp index 86ae2f6b..a45b1eb5 100644 --- a/src/layers/qkv_linear_test.cpp +++ b/src/layers/linear/qkv_parallel_linear_test.cpp @@ -1,4 +1,4 @@ -#include "qkv_linear.h" +#include "qkv_parallel_linear.h" #include #include @@ -67,19 +67,19 @@ TEST_P(QKVColumnParallelLinearTest, LoadFusedWeight) { // generate random input and compare with the output auto input = torch::randn({n_tokens, hidden_size}, options); - auto qkv = linear.forward(input); + const auto [q, k, v] = linear.forward(input); const int64_t kv_shard_id = n_kv_heads >= n_shards ? shard_id : n_kv_heads * shard_id / n_shards; auto query = input.matmul(query_chunks[shard_id].t()); - EXPECT_TRUE(torch::allclose(qkv[0], query, /*rtol=*/1e-5, /*atol=*/1e-5)); + EXPECT_TRUE(torch::allclose(q, query, /*rtol=*/1e-5, /*atol=*/1e-5)); auto key = input.matmul(key_chunks[kv_shard_id].t()); - EXPECT_TRUE(torch::allclose(qkv[1], key, /*rtol=*/1e-5, /*atol=*/1e-5)); + EXPECT_TRUE(torch::allclose(k, key, /*rtol=*/1e-5, /*atol=*/1e-5)); auto value = input.matmul(value_chunks[kv_shard_id].t()); - EXPECT_TRUE(torch::allclose(qkv[2], value, /*rtol=*/1e-5, /*atol=*/1e-5)); + EXPECT_TRUE(torch::allclose(v, value, /*rtol=*/1e-5, /*atol=*/1e-5)); } } diff --git a/src/layers/weight_utils.cpp b/src/layers/linear/weight_utils.cpp similarity index 99% rename from src/layers/weight_utils.cpp rename to src/layers/linear/weight_utils.cpp index 174db65a..a037dd4d 100644 --- a/src/layers/weight_utils.cpp +++ b/src/layers/linear/weight_utils.cpp @@ -111,4 +111,4 @@ void WeightUtils::load_fused_weight( } } -} // namespace llm \ No newline at end of file +} // namespace llm diff --git a/src/layers/weight_utils.h b/src/layers/linear/weight_utils.h similarity index 99% rename from src/layers/weight_utils.h rename to src/layers/linear/weight_utils.h index 147506ad..dbe71a8e 100644 --- a/src/layers/weight_utils.h +++ b/src/layers/linear/weight_utils.h @@ -82,4 +82,4 @@ class WeightUtils { #define LOAD_WEIGHT(name) \ WeightUtils::load_weight(state_dict, #name, name##_, name##_is_loaded_); -} // namespace llm \ No newline at end of file +} // namespace llm diff --git a/src/layers/linear_impl.cpp b/src/layers/linear_impl.cpp deleted file mode 100644 index 7b6d8f04..00000000 --- a/src/layers/linear_impl.cpp +++ /dev/null @@ -1,152 +0,0 @@ -#include "linear_impl.h" - -#include -#include -#include - -#include "model_loader/state_dict.h" -#include "model_parallel/model_parallel.h" - -namespace llm { - -// Linear layer with column parallelism. 
-ColumnParallelLinearImpl::ColumnParallelLinearImpl( - int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) - : gather_output_(gather_output), parallel_args_(parallel_args) { - const auto rank = parallel_args_.rank(); - const auto world_size = parallel_args_.world_size(); - CHECK(out_features % world_size == 0) - << "out_features " << out_features << " not divisible by world_size " - << world_size; - const int64_t out_features_per_partition = out_features / world_size; - - // Note: torch.nn.functional.linear performs XA^T + b and as a result - // we allocate the transpose. - weight_ = register_sharded_parameter( - "weight", - /*dim=*/0, - rank, - world_size, - torch::empty({out_features_per_partition, in_features}, options)); - - if (bias) { - bias_ = register_sharded_parameter( - "bias", - /*dim=*/0, - rank, - world_size, - torch::empty({out_features_per_partition}, options)); - } -} - -torch::Tensor ColumnParallelLinearImpl::forward(torch::Tensor input) { - namespace F = torch::nn::functional; - auto output = F::linear(input, weight_, bias_); - if (parallel_args_.world_size() > 1 && gather_output_) { - output = gather_from_model_parallel_region(output, parallel_args_); - } - return output; -} - -// load the weight from the checkpoint -void ColumnParallelLinearImpl::load_state_dict(const StateDict& state_dict) { - // call load_state_dict with identity transform - load_state_dict(state_dict, - [](const torch::Tensor& tensor) { return tensor; }); -} - -void ColumnParallelLinearImpl::load_state_dict(const StateDict& state_dict, - TensorTransform transform_func) { - CHECK(transform_func != nullptr) << "transform_func must be provided"; - const auto rank = parallel_args_.rank(); - const auto world_size = parallel_args_.world_size(); - - // load sharded weights on dim 0 - LOAD_SHARDED_WEIGHT_WITH_TRANSFORM(weight, 0); - - if (bias_.defined()) { - // load sharded bias on dim 0 - LOAD_SHARDED_WEIGHT_WITH_TRANSFORM(bias, 0); - } -} - -// special load_state_dict for fused cases -void ColumnParallelLinearImpl::load_state_dict( - const StateDict& state_dict, - const std::vector& prefixes) { - const auto rank = parallel_args_.rank(); - const auto world_size = parallel_args_.world_size(); - - // load and merge the weights on dim 0 - LOAD_FUSED_WEIGHT(weight, 0); - - if (bias_.defined()) { - // load and merge the bias on dim 0 - LOAD_FUSED_WEIGHT(bias, 0); - } -} - -// Linear layer with row parallelism. -RowParallelLinearImpl::RowParallelLinearImpl( - int64_t in_features, - int64_t out_features, - bool bias, - bool input_is_parallelized, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) - : input_is_parallelized_(input_is_parallelized), - parallel_args_(parallel_args) { - const auto rank = parallel_args_.rank(); - const auto world_size = parallel_args_.world_size(); - CHECK(in_features % world_size == 0) - << "in_features " << in_features << " not divisible by world_size " - << world_size; - const int64_t in_features_per_partition = in_features / world_size; - // Allocate the transpose since linear performs XA^T. 
- weight_ = register_sharded_parameter( - "weight", - /*dim=*/1, - rank, - world_size, - torch::empty({out_features, in_features_per_partition}, options)); - - if (bias) { - bias_ = register_parameter("bias", torch::empty({out_features}, options)); - } -} - -torch::Tensor RowParallelLinearImpl::forward(torch::Tensor input) { - namespace F = torch::nn::functional; - if (!input_is_parallelized_) { - input = scatter_to_model_parallel_region(input, parallel_args_); - } - auto output = F::linear(input, weight_); - if (parallel_args_.world_size() > 1) { - output = reduce_from_model_parallel_region(output, parallel_args_); - } - // N.B. need to apply bias after the reduce - if (bias_.defined()) { - output.add_(bias_); - } - return output; -} - -// load the weight from the checkpoint -void RowParallelLinearImpl::load_state_dict(const StateDict& state_dict) { - const auto rank = parallel_args_.rank(); - const auto world_size = parallel_args_.world_size(); - - // load sharded weights on dim 1 - LOAD_SHARDED_WEIGHT(weight, 1); - - if (bias_.defined()) { - LOAD_WEIGHT(bias); - } -} - -} // namespace llm diff --git a/src/layers/linear_impl.h b/src/layers/linear_impl.h deleted file mode 100644 index ff551649..00000000 --- a/src/layers/linear_impl.h +++ /dev/null @@ -1,110 +0,0 @@ -#pragma once - -#include -#include - -#include "linear.h" -#include "model_loader/state_dict.h" -#include "weight_utils.h" - -namespace llm { - -// Linear layer with column parallelism. -// The linear layer is defined as Y = XA + b. A is parallelized along -// its second dimension as A = [A_1, ..., A_p]. -class ColumnParallelLinearImpl : public ParallelLinearImpl { - public: - ColumnParallelLinearImpl(int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); - - torch::Tensor forward(torch::Tensor input) override; - - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) override; - - // load state dict with a transform function - void load_state_dict(const StateDict& state_dict, - TensorTransform transform_func) override; - - // special load_state_dict for fused cases - void load_state_dict(const StateDict& state_dict, - const std::vector& prefixes) override; - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix) const override { - CHECK(weight_is_loaded_) - << "weight is not loaded for " << prefix + "weight"; - CHECK(!bias_.defined() || bias_is_loaded_) - << "bias is not loaded for " << prefix + "bias"; - } - - // return the weight (for testing) - torch::Tensor weight() const { return weight_; } - - private: - // parameter members, must be registered - // we allocate the transpose since linear performs XA^T. - // A^T: [out_features_per_partition, in_features] - DEFINE_FUSED_WEIGHT(weight); - DEFINE_FUSED_WEIGHT(bias); - - // whether to gather the output - bool gather_output_; - - // parallel args - ParallelArgs parallel_args_; -}; - -// Linear layer with row parallelism. -// The linear layer is defined as Y = XA + b. A is parallelized along -// its first dimension and X along its second dimension as: -// - - -// | A_1 | -// | . | -// A = | . | X = [X_1, ..., X_p] -// | . 
| -// | A_p | -// - - -class RowParallelLinearImpl : public ParallelLinearImpl { - public: - RowParallelLinearImpl(int64_t in_features, - int64_t out_features, - bool bias, - bool input_is_parallelized, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); - - torch::Tensor forward(torch::Tensor input) override; - - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) override; - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix = "") const override { - CHECK(weight_is_loaded_) - << "weight is not loaded for " << prefix + "weight"; - CHECK(!bias_.defined() || bias_is_loaded_) - << "bias is not loaded for " << prefix + "bias"; - } - - // return the weight (for testing) - torch::Tensor weight() const { return weight_; } - - private: - // parameter members, must be registered - // we allocate the transpose since linear performs XA^T. - // A^T: [out_features, in_features_per_partition] - DEFINE_WEIGHT(weight); - DEFINE_WEIGHT(bias); - - // whether the input is already parallelized - bool input_is_parallelized_; - - // parallel args - ParallelArgs parallel_args_; -}; -} // namespace llm diff --git a/src/module/CMakeLists.txt b/src/layers/module/CMakeLists.txt similarity index 100% rename from src/module/CMakeLists.txt rename to src/layers/module/CMakeLists.txt diff --git a/src/module/module.cpp b/src/layers/module/module.cpp similarity index 92% rename from src/module/module.cpp rename to src/layers/module/module.cpp index 29d730b7..a9082e10 100644 --- a/src/module/module.cpp +++ b/src/layers/module/module.cpp @@ -7,7 +7,7 @@ namespace llm { using namespace torch; -namespace { +namespace detail { /// Joins names hierarchically: "name_prefix.name" if `name_prefix` is /// non-empty, else just "name". 
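+/// For example, join_name("", "weight") == "weight" and, with the separator
+/// check below, both join_name("query", "weight") and join_name("query.", "weight")
+/// yield "query.weight" (no doubled '.').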
std::string join_name(const std::string& name_prefix, const std::string& name) { @@ -19,12 +19,15 @@ std::string join_name(const std::string& name_prefix, const std::string& name) { full_name.reserve(total_size); if (!name_prefix.empty()) { full_name += name_prefix; - full_name.push_back('.'); + // insert separator if necessary + if (name_prefix.back() != '.') { + full_name.push_back('.'); + } } full_name += name; return full_name; } -} // namespace +} // namespace detail Module::Module() : parameters_("Parameter"), buffers_("Buffer"), children_("Submodule") {} @@ -60,7 +63,8 @@ OrderedDict Module::named_parameters(bool recurse) const { apply([&result](const std::string& name, const Module& module) { for (const auto& parameter : module.named_parameters(/*recurse=*/false)) { TORCH_INTERNAL_ASSERT(parameter.value().defined()); - result.insert(join_name(name, parameter.key()), parameter.value()); + result.insert(detail::join_name(name, parameter.key()), + parameter.value()); } }); } @@ -83,7 +87,7 @@ OrderedDict Module::named_buffers(bool recurse) const { apply([&result](const std::string& name, const Module& module) { for (const auto& buffer : module.named_buffers(/*recurse=*/false)) { TORCH_INTERNAL_ASSERT(buffer.value().defined()); - result.insert(join_name(name, buffer.key()), buffer.value()); + result.insert(detail::join_name(name, buffer.key()), buffer.value()); } }); } @@ -278,7 +282,7 @@ void Module::apply_to_submodules( const NamedModulePointerApplyFunction& function, const std::string& name_prefix) const { for (const auto& child : children_) { - auto qualified_name = join_name(name_prefix, child.key()); + auto qualified_name = detail::join_name(name_prefix, child.key()); function(qualified_name, child.value().module); child.value().module->apply_to_submodules(function, qualified_name); } @@ -331,12 +335,13 @@ size_t Module::load(const StateDict& state_dict, } if (param.is_loaded) { - LOG(WARNING) << "Parameter " << join_name(name_prefix, key) + LOG(WARNING) << "Parameter " << detail::join_name(name_prefix, key) << " is already loaded"; } if (param_tensor.sizes() == tensor.sizes()) { - // LOG(INFO) << "Loading parameter: " << join_name(name_prefix, key) + // LOG(INFO) << "Loading parameter: " << detail::join_name(name_prefix, + // key) // << " of size " << tensor.sizes(); // copy data to the parameter tensor param_tensor.copy_(tensor); @@ -345,7 +350,7 @@ size_t Module::load(const StateDict& state_dict, ++total_loaded; } else { LOG(ERROR) << "Size mismatch for parameter " - << join_name(name_prefix, key) << ": expected " + << detail::join_name(name_prefix, key) << ": expected " << param_tensor.sizes() << ", got " << tensor.sizes(); } } @@ -359,8 +364,8 @@ size_t Module::load(const StateDict& state_dict, if (child.selector) { // select state dict for the child module const auto child_state_dict = child.selector(state_dict, key); - total_loaded += - child.module->load(child_state_dict, join_name(name_prefix, key)); + total_loaded += child.module->load(child_state_dict, + detail::join_name(name_prefix, key)); } else { total_loaded += child.module->load(state_dict, name_prefix); } @@ -376,7 +381,7 @@ bool Module::verify(const std::string& name_prefix) const { const auto& key = item.key(); const auto& param = item.value(); if (!param.is_loaded) { - LOG(ERROR) << "Missing parameter: " << join_name(name_prefix, key) + LOG(ERROR) << "Missing parameter: " << detail::join_name(name_prefix, key) << ", size: " << param.tensor.sizes(); } all_loaded = all_loaded && param.is_loaded; @@ -386,7 
+391,7 @@ bool Module::verify(const std::string& name_prefix) const { const auto& key = item.key(); const auto& child = item.value(); const std::string prefix = - child.selector ? join_name(name_prefix, key) : name_prefix; + child.selector ? detail::join_name(name_prefix, key) : name_prefix; const bool child_loaded = child.module->verify(prefix); all_loaded = all_loaded && child_loaded; } diff --git a/src/module/module.h b/src/layers/module/module.h similarity index 97% rename from src/module/module.h rename to src/layers/module/module.h index 6cb96ff6..664f4d9c 100644 --- a/src/module/module.h +++ b/src/layers/module/module.h @@ -1,3 +1,5 @@ +// Adapted from +// https://github.com/pytorch/pytorch/blob/main/torch/csrc/api/include/torch/nn/module.h #pragma once #include @@ -15,6 +17,12 @@ namespace llm { +namespace detail { +/// Joins names hierarchically: "name_prefix.name" if `name_prefix` is +/// non-empty, else just "name". +std::string join_name(const std::string& name_prefix, const std::string& name); +} // namespace detail + /// The base class for all modules. /// /// A `Module` is an abstraction over the implementation of some function or diff --git a/src/module/module_holder.h b/src/layers/module/module_holder.h similarity index 98% rename from src/module/module_holder.h rename to src/layers/module/module_holder.h index abe131a5..c17aac06 100644 --- a/src/module/module_holder.h +++ b/src/layers/module/module_holder.h @@ -1,3 +1,5 @@ +// Adapted from +// https://github.com/pytorch/pytorch/blob/main/torch/csrc/api/include/torch/nn/pimpl.h #pragma once #include diff --git a/src/module/module_list.h b/src/layers/module/module_list.h similarity index 97% rename from src/module/module_list.h rename to src/layers/module/module_list.h index ce0bfee6..56cf833c 100644 --- a/src/module/module_list.h +++ b/src/layers/module/module_list.h @@ -1,3 +1,5 @@ +// Adapted from +// https://github.com/pytorch/pytorch/blob/main/torch/csrc/api/include/torch/nn/modules/container/modulelist.h #pragma once #include diff --git a/src/module/module_test.cpp b/src/layers/module/module_test.cpp similarity index 100% rename from src/module/module_test.cpp rename to src/layers/module/module_test.cpp diff --git a/src/layers/normalization.h b/src/layers/normalization.h index da75a81b..6011589c 100644 --- a/src/layers/normalization.h +++ b/src/layers/normalization.h @@ -6,9 +6,9 @@ #include #include "kernels/layernorm_kernels.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" #include "model_loader/state_dict.h" -#include "module/module.h" -#include "module/module_holder.h" DECLARE_bool(disable_custom_kernels); namespace llm { diff --git a/src/quantization/CMakeLists.txt b/src/layers/quantization/CMakeLists.txt similarity index 96% rename from src/quantization/CMakeLists.txt rename to src/layers/quantization/CMakeLists.txt index 325b0d02..94761865 100644 --- a/src/quantization/CMakeLists.txt +++ b/src/layers/quantization/CMakeLists.txt @@ -2,9 +2,9 @@ include(cc_library) include(cc_test) cc_library( - NAME + NAME quantization - HDRS + HDRS pack_utils.h qlinear_impl.h qlinear_gptq_impl.h @@ -12,7 +12,7 @@ cc_library( qlinear_awq_impl.h qlinear_gptq_marlin_impl.h qlinear_awq_marlin_impl.h - SRCS + SRCS pack_utils.cpp qlinear_impl.cpp qlinear_gptq_impl.cpp @@ -42,7 +42,7 @@ cc_test( :quantization :state_dict :gtest_main - DATA + DATA data/gptq_small.safetensors data/gptq.safetensors ) diff --git a/src/quantization/data/gptq.safetensors b/src/layers/quantization/data/gptq.safetensors 
similarity index 100% rename from src/quantization/data/gptq.safetensors rename to src/layers/quantization/data/gptq.safetensors diff --git a/src/quantization/data/gptq_small.safetensors b/src/layers/quantization/data/gptq_small.safetensors similarity index 100% rename from src/quantization/data/gptq_small.safetensors rename to src/layers/quantization/data/gptq_small.safetensors diff --git a/src/quantization/pack_utils.cpp b/src/layers/quantization/pack_utils.cpp similarity index 98% rename from src/quantization/pack_utils.cpp rename to src/layers/quantization/pack_utils.cpp index 882823ef..f872efb7 100644 --- a/src/quantization/pack_utils.cpp +++ b/src/layers/quantization/pack_utils.cpp @@ -75,4 +75,4 @@ torch::Tensor unpack_cols( return qweight.to(packed_qweight); } -} // namespace llm::pack_utils \ No newline at end of file +} // namespace llm::pack_utils diff --git a/src/quantization/pack_utils.h b/src/layers/quantization/pack_utils.h similarity index 93% rename from src/quantization/pack_utils.h rename to src/layers/quantization/pack_utils.h index c98355c8..e8db0eea 100644 --- a/src/quantization/pack_utils.h +++ b/src/layers/quantization/pack_utils.h @@ -14,4 +14,4 @@ torch::Tensor unpack_cols( const torch::Tensor& packed_qweight, // (m, n/pack_factor) int64_t num_bits); -} // namespace llm::pack_utils \ No newline at end of file +} // namespace llm::pack_utils diff --git a/src/quantization/pack_utils_test.cpp b/src/layers/quantization/pack_utils_test.cpp similarity index 97% rename from src/quantization/pack_utils_test.cpp rename to src/layers/quantization/pack_utils_test.cpp index 52125db9..ed13132f 100644 --- a/src/quantization/pack_utils_test.cpp +++ b/src/layers/quantization/pack_utils_test.cpp @@ -28,4 +28,4 @@ TEST(PackUtilsTest, Basic) { EXPECT_TRUE(unpacked_qweight.equal(qweight)); } -} // namespace llm \ No newline at end of file +} // namespace llm diff --git a/src/quantization/qlinear_awq_impl.cpp b/src/layers/quantization/qlinear_awq_impl.cpp similarity index 100% rename from src/quantization/qlinear_awq_impl.cpp rename to src/layers/quantization/qlinear_awq_impl.cpp diff --git a/src/quantization/qlinear_awq_impl.h b/src/layers/quantization/qlinear_awq_impl.h similarity index 100% rename from src/quantization/qlinear_awq_impl.h rename to src/layers/quantization/qlinear_awq_impl.h diff --git a/src/quantization/qlinear_awq_marlin_impl.cpp b/src/layers/quantization/qlinear_awq_marlin_impl.cpp similarity index 99% rename from src/quantization/qlinear_awq_marlin_impl.cpp rename to src/layers/quantization/qlinear_awq_marlin_impl.cpp index 5cf8c139..a3c9922f 100644 --- a/src/quantization/qlinear_awq_marlin_impl.cpp +++ b/src/layers/quantization/qlinear_awq_marlin_impl.cpp @@ -7,7 +7,7 @@ #include #include "kernels/quantization/marlin.h" -#include "layers/weight_utils.h" +#include "layers/linear/weight_utils.h" #include "model_loader/state_dict.h" #include "model_parallel/model_parallel.h" #include "pack_utils.h" diff --git a/src/quantization/qlinear_awq_marlin_impl.h b/src/layers/quantization/qlinear_awq_marlin_impl.h similarity index 97% rename from src/quantization/qlinear_awq_marlin_impl.h rename to src/layers/quantization/qlinear_awq_marlin_impl.h index 9d2d7fc3..fcdb5d11 100644 --- a/src/quantization/qlinear_awq_marlin_impl.h +++ b/src/layers/quantization/qlinear_awq_marlin_impl.h @@ -3,8 +3,8 @@ #include #include -#include "layers/linear.h" -#include "layers/weight_utils.h" +#include "layers/linear/parallel_linear.h" +#include "layers/linear/weight_utils.h" #include 
"model_loader/state_dict.h" #include "model_parallel/parallel_args.h" diff --git a/src/quantization/qlinear_exllamav2_impl.cpp b/src/layers/quantization/qlinear_exllamav2_impl.cpp similarity index 100% rename from src/quantization/qlinear_exllamav2_impl.cpp rename to src/layers/quantization/qlinear_exllamav2_impl.cpp diff --git a/src/quantization/qlinear_exllamav2_impl.h b/src/layers/quantization/qlinear_exllamav2_impl.h similarity index 100% rename from src/quantization/qlinear_exllamav2_impl.h rename to src/layers/quantization/qlinear_exllamav2_impl.h diff --git a/src/quantization/qlinear_gptq_impl.cpp b/src/layers/quantization/qlinear_gptq_impl.cpp similarity index 100% rename from src/quantization/qlinear_gptq_impl.cpp rename to src/layers/quantization/qlinear_gptq_impl.cpp diff --git a/src/quantization/qlinear_gptq_impl.h b/src/layers/quantization/qlinear_gptq_impl.h similarity index 100% rename from src/quantization/qlinear_gptq_impl.h rename to src/layers/quantization/qlinear_gptq_impl.h diff --git a/src/quantization/qlinear_gptq_marlin_impl.cpp b/src/layers/quantization/qlinear_gptq_marlin_impl.cpp similarity index 99% rename from src/quantization/qlinear_gptq_marlin_impl.cpp rename to src/layers/quantization/qlinear_gptq_marlin_impl.cpp index a19b7302..1fb11a72 100644 --- a/src/quantization/qlinear_gptq_marlin_impl.cpp +++ b/src/layers/quantization/qlinear_gptq_marlin_impl.cpp @@ -5,7 +5,7 @@ #include #include "kernels/quantization/marlin.h" -#include "layers/weight_utils.h" +#include "layers/linear/weight_utils.h" #include "model_loader/state_dict.h" #include "model_parallel/model_parallel.h" diff --git a/src/quantization/qlinear_gptq_marlin_impl.h b/src/layers/quantization/qlinear_gptq_marlin_impl.h similarity index 97% rename from src/quantization/qlinear_gptq_marlin_impl.h rename to src/layers/quantization/qlinear_gptq_marlin_impl.h index e996b479..f44fb9a6 100644 --- a/src/quantization/qlinear_gptq_marlin_impl.h +++ b/src/layers/quantization/qlinear_gptq_marlin_impl.h @@ -3,8 +3,8 @@ #include #include -#include "layers/linear.h" -#include "layers/weight_utils.h" +#include "layers/linear/parallel_linear.h" +#include "layers/linear/weight_utils.h" #include "model_loader/state_dict.h" #include "model_parallel/parallel_args.h" diff --git a/src/quantization/qlinear_impl.cpp b/src/layers/quantization/qlinear_impl.cpp similarity index 89% rename from src/quantization/qlinear_impl.cpp rename to src/layers/quantization/qlinear_impl.cpp index 695d0dad..bb41ec04 100644 --- a/src/quantization/qlinear_impl.cpp +++ b/src/layers/quantization/qlinear_impl.cpp @@ -4,7 +4,6 @@ #include #include -#include "layers/linear_impl.h" #include "model_loader/state_dict.h" namespace llm { @@ -117,7 +116,8 @@ ColumnParallelQLinearImpl::ColumnParallelQLinearImpl( quant_args.group_size() > 0 ? 
quant_args.group_size() : in_features; CHECK(qweight_pack_dim == 0 || qweight_pack_dim == 1) << "qweight_pack_dim must be 0 or 1"; - const int64_t world_size = parallel_args.world_size(); + const auto rank = parallel_args_.rank(); + const auto world_size = parallel_args_.world_size(); CHECK(out_features % world_size == 0) << "out_features " << out_features << " not divisible by world_size " << world_size; @@ -125,29 +125,46 @@ ColumnParallelQLinearImpl::ColumnParallelQLinearImpl( const int64_t pack_factor = 32 / bits; if (qweight_pack_dim == 0) { - qweight_ = register_parameter( + qweight_ = register_sharded_parameter( "qweight", + /*dim=*/1, + rank, + world_size, torch::empty({in_features / pack_factor, out_features_per_partition}, options.dtype(torch::kInt32))); } else { - qweight_ = register_parameter( + qweight_ = register_sharded_parameter( "qweight", + /*dim=*/1, + rank, + world_size, torch::empty({in_features, out_features_per_partition / pack_factor}, options.dtype(torch::kInt32))); } - qzeros_ = register_parameter( + qzeros_ = register_sharded_parameter( "qzeros", + /*dim=*/1, + rank, + world_size, torch::empty({round_up(in_features, group_size), out_features_per_partition / pack_factor}, options.dtype(torch::kInt32))); - scales_ = register_parameter("scales", - torch::empty({round_up(in_features, group_size), - out_features_per_partition}, - options)); + scales_ = register_sharded_parameter( + "scales", + /*dim=*/1, + rank, + world_size, + torch::empty( + {round_up(in_features, group_size), out_features_per_partition}, + options)); if (bias) { - bias_ = register_parameter( - "bias", torch::empty({out_features_per_partition}, options)); + bias_ = register_sharded_parameter( + "bias", + /*dim=*/0, + rank, + world_size, + torch::empty({out_features_per_partition}, options)); } } @@ -226,7 +243,8 @@ RowParallelQLinearImpl::RowParallelQLinearImpl( const auto bits = quant_args.bits(); CHECK(qweight_pack_dim == 0 || qweight_pack_dim == 1) << "qweight_pack_dim must be 0 or 1"; - const int64_t world_size = parallel_args.world_size(); + const auto rank = parallel_args_.rank(); + const auto world_size = parallel_args_.world_size(); CHECK(in_features % world_size == 0) << "in_features " << in_features << " not divisible by world_size " << world_size; @@ -236,24 +254,36 @@ RowParallelQLinearImpl::RowParallelQLinearImpl( quant_args.group_size() > 0 ? 
quant_args.group_size() : in_features; if (qweight_pack_dim == 0) { - qweight_ = register_parameter( + qweight_ = register_sharded_parameter( "qweight", + /*dim=*/0, + rank, + world_size, torch::empty({in_features_per_partition / pack_factor, out_features}, options.dtype(torch::kInt32))); } else { - qweight_ = register_parameter( + qweight_ = register_sharded_parameter( "qweight", + /*dim=*/0, + rank, + world_size, torch::empty({in_features_per_partition, out_features / pack_factor}, options.dtype(torch::kInt32))); } - qzeros_ = register_parameter( + qzeros_ = register_sharded_parameter( "qzeros", + /*dim=*/0, + rank, + world_size, torch::empty({round_up(in_features_per_partition, group_size), out_features / pack_factor}, options.dtype(torch::kInt32))); - scales_ = register_parameter( + scales_ = register_sharded_parameter( "scales", + /*dim=*/0, + rank, + world_size, torch::empty( {round_up(in_features_per_partition, group_size), out_features}, options)); diff --git a/src/quantization/qlinear_impl.h b/src/layers/quantization/qlinear_impl.h similarity index 98% rename from src/quantization/qlinear_impl.h rename to src/layers/quantization/qlinear_impl.h index 0269246f..d02e2186 100644 --- a/src/quantization/qlinear_impl.h +++ b/src/layers/quantization/qlinear_impl.h @@ -3,11 +3,10 @@ #include #include -#include "layers/linear_impl.h" -#include "layers/weight_utils.h" +#include "layers/linear/parallel_linear.h" +#include "layers/linear/weight_utils.h" #include "model_loader/state_dict.h" #include "model_parallel/model_parallel.h" -#include "models/model_args.h" namespace llm { diff --git a/src/quantization/qlinear_impl_test.cpp b/src/layers/quantization/qlinear_impl_test.cpp similarity index 96% rename from src/quantization/qlinear_impl_test.cpp rename to src/layers/quantization/qlinear_impl_test.cpp index 26e50044..26291ab4 100644 --- a/src/quantization/qlinear_impl_test.cpp +++ b/src/layers/quantization/qlinear_impl_test.cpp @@ -46,8 +46,8 @@ TEST(QlinearTest, ColumnParallelQuantLinear) { /*bits=*/4); weights = weights.to(torch::kCUDA); - qlinear.load_state_dict(*state_dict); - qlinear.verify_loaded_weights(); + qlinear.load(*state_dict); + EXPECT_TRUE(qlinear.verify()); auto input = torch::rand({40960, in_features}, options); auto output = qlinear.forward(input); @@ -83,8 +83,8 @@ TEST(QlinearTest, RowParallelQuantLinear) { /*bits=*/4); weights = weights.to(torch::kCUDA); - qlinear.load_state_dict(*state_dict); - qlinear.verify_loaded_weights(); + qlinear.load(*state_dict); + EXPECT_TRUE(qlinear.verify()); auto input = torch::rand({40960, in_features}, options); auto output = qlinear.forward(input); diff --git a/src/quantization/quant_args.h b/src/layers/quantization/quant_args.h similarity index 100% rename from src/quantization/quant_args.h rename to src/layers/quantization/quant_args.h diff --git a/src/model_loader/args_overrider.h b/src/model_loader/args_overrider.h index 99b38fc6..83132873 100644 --- a/src/model_loader/args_overrider.h +++ b/src/model_loader/args_overrider.h @@ -2,8 +2,8 @@ #include +#include "layers/quantization/quant_args.h" #include "models/model_args.h" -#include "quantization/quant_args.h" #include "tokenizer/tokenizer_args.h" // Model args flags @@ -56,4 +56,4 @@ void override_args_from_gflag(ModelArgs& args, QuantArgs& quant_args, TokenizerArgs& tokenizer_args); -} // namespace llm \ No newline at end of file +} // namespace llm diff --git a/src/model_loader/model_loader.h b/src/model_loader/model_loader.h index dbac26cf..b1ac0782 100644 --- 
a/src/model_loader/model_loader.h +++ b/src/model_loader/model_loader.h @@ -4,9 +4,9 @@ #include +#include "layers/quantization/quant_args.h" #include "model_loader/state_dict.h" #include "models/model_args.h" -#include "quantization/quant_args.h" #include "tokenizer/tokenizer.h" #include "tokenizer/tokenizer_args.h" diff --git a/src/models/_deprecated/aquila.h b/src/models/_deprecated/aquila.h index 59e7b3e4..c76343a2 100644 --- a/src/models/_deprecated/aquila.h +++ b/src/models/_deprecated/aquila.h @@ -9,14 +9,14 @@ #include "layers/embedding.h" #include "layers/fused_linear.h" #include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // Aquila model compatible with huggingface weights namespace llm::hf { diff --git a/src/models/_deprecated/baichuan.h b/src/models/_deprecated/baichuan.h index db77fcdc..8fa2b56d 100644 --- a/src/models/_deprecated/baichuan.h +++ b/src/models/_deprecated/baichuan.h @@ -11,14 +11,14 @@ #include "layers/embedding.h" #include "layers/fused_linear.h" #include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // Baichuan model compatible with huggingface weights diff --git a/src/models/_deprecated/bloom.h b/src/models/_deprecated/bloom.h index 0499993b..3cb07b7b 100644 --- a/src/models/_deprecated/bloom.h +++ b/src/models/_deprecated/bloom.h @@ -8,14 +8,14 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // bloom model compatible with huggingface weights diff --git a/src/models/_deprecated/chatglm.h b/src/models/_deprecated/chatglm.h index cd8c71b3..d9729f96 100644 --- a/src/models/_deprecated/chatglm.h +++ b/src/models/_deprecated/chatglm.h @@ -9,14 +9,14 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" #include "tokenizer/tokenizer_args.h" // ChatGLM model compatible with huggingface weights diff --git a/src/models/_deprecated/gpt_j.h b/src/models/_deprecated/gpt_j.h index 3905860b..1485b1b7 100644 --- a/src/models/_deprecated/gpt_j.h +++ b/src/models/_deprecated/gpt_j.h @@ -7,14 +7,14 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" 
#include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // GPTJ model compatible with huggingface weights namespace llm::hf { diff --git a/src/models/_deprecated/gpt_neox.h b/src/models/_deprecated/gpt_neox.h index 65a2a1a3..e4ec964b 100644 --- a/src/models/_deprecated/gpt_neox.h +++ b/src/models/_deprecated/gpt_neox.h @@ -7,14 +7,14 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // gpt-neox model compatible with huggingface weights namespace llm::hf { diff --git a/src/models/_deprecated/internlm.h b/src/models/_deprecated/internlm.h index 8f100c32..df145482 100644 --- a/src/models/_deprecated/internlm.h +++ b/src/models/_deprecated/internlm.h @@ -9,14 +9,14 @@ #include "layers/embedding.h" #include "layers/fused_linear.h" #include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // Internlm model compatible with huggingface weights namespace llm::hf { diff --git a/src/models/_deprecated/mistral.h b/src/models/_deprecated/mistral.h index 9b8e8ad0..f7d340d7 100644 --- a/src/models/_deprecated/mistral.h +++ b/src/models/_deprecated/mistral.h @@ -7,15 +7,15 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "layers/qkv_linear.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // Mistral model compatible with huggingface weights namespace llm::hf { diff --git a/src/models/_deprecated/mpt.h b/src/models/_deprecated/mpt.h index b435b828..9e06c11b 100644 --- a/src/models/_deprecated/mpt.h +++ b/src/models/_deprecated/mpt.h @@ -10,14 +10,14 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // mpt model compatible with huggingface weights namespace llm::hf { diff --git a/src/models/_deprecated/simple_model.h 
b/src/models/_deprecated/simple_model.h index dc22d3a0..ed7000bb 100644 --- a/src/models/_deprecated/simple_model.h +++ b/src/models/_deprecated/simple_model.h @@ -7,13 +7,13 @@ #include "layers/embedding.h" #include "layers/fused_linear.h" #include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // simple model for test namespace llm { diff --git a/src/models/alibaba/qwen.h b/src/models/alibaba/qwen.h index 61cf7490..5f03c103 100644 --- a/src/models/alibaba/qwen.h +++ b/src/models/alibaba/qwen.h @@ -10,16 +10,15 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/fused_linear.h" -#include "layers/linear.h" +#include "layers/linear/multi_parallel_linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // QWen model compatible with huggingface weights // Adapted from https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py namespace llm::hf { @@ -41,7 +40,7 @@ class QWenMLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - FusedColumnParallelLinear( + MultiColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, std::vector{"w1.", "w2."}, @@ -68,7 +67,7 @@ class QWenMLPImpl : public Module { private: // parameter members, must be registered - FusedColumnParallelLinear gate_up_proj_{nullptr}; + MultiColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear c_proj_{nullptr}; ActFunc act_{nullptr}; diff --git a/src/models/alibaba/qwen2.h b/src/models/alibaba/qwen2.h index 84758c2d..41dcd301 100644 --- a/src/models/alibaba/qwen2.h +++ b/src/models/alibaba/qwen2.h @@ -10,16 +10,15 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/linear.h" +#include "layers/linear/qkv_parallel_linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" -#include "layers/qkv_linear.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // QWen2 model compatible with huggingface weights // ref to: // https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/qwen2/modeling_qwen2.py @@ -40,7 +39,7 @@ class QWen2MLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - FusedColumnParallelLinear( + MultiColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, std::vector{"gate_proj.", "up_proj."}, @@ -68,7 +67,7 @@ class QWen2MLPImpl : public Module { private: // parameter members, must be registered - FusedColumnParallelLinear gate_up_proj_{nullptr}; + MultiColumnParallelLinear 
gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; // activation function @@ -133,10 +132,9 @@ class QWen2AttentionImpl : public Module { const InputParameters& input_params) { // (num_tokens, dim) x (dim, n_local_heads * head_dim) // => (num_tokens, n_local_heads * head_dim) - const auto qkv = qkv_proj_(x); + const auto [q, k, v] = qkv_proj_(x); // calculate attention, output: (num_tokens, n_local_heads * head_dim) - const auto output = - atten_(qkv[0], qkv[1], qkv[2], positions, kv_cache, input_params); + const auto output = atten_(q, k, v, positions, kv_cache, input_params); return o_proj_(output); } diff --git a/src/models/causal_lm.h b/src/models/causal_lm.h index eaed0ffe..e65aa368 100644 --- a/src/models/causal_lm.h +++ b/src/models/causal_lm.h @@ -5,12 +5,12 @@ #include +#include "layers/quantization/quant_args.h" #include "memory/kv_cache.h" #include "model_args.h" #include "model_loader/state_dict.h" #include "model_parallel/parallel_args.h" #include "parameters.h" -#include "quantization/quant_args.h" namespace llm { diff --git a/src/models/google/gemma.h b/src/models/google/gemma.h index 7934c9b6..d07bb248 100644 --- a/src/models/google/gemma.h +++ b/src/models/google/gemma.h @@ -10,17 +10,16 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/linear.h" -#include "layers/linear_impl.h" +#include "layers/linear/parallel_linear.h" +#include "layers/linear/qkv_parallel_linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" -#include "layers/qkv_linear.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // Gemma model compatible with huggingface weight namespace llm::hf { @@ -40,7 +39,7 @@ class GemmaMLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - FusedColumnParallelLinear( + MultiColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, std::vector{"gate_proj.", "up_proj."}, @@ -68,7 +67,7 @@ class GemmaMLPImpl : public Module { private: // parameter members, must be registered - FusedColumnParallelLinear gate_up_proj_{nullptr}; + MultiColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; // activation function @@ -128,11 +127,10 @@ class GemmaAttentionImpl : public Module { const InputParameters& input_params) { // (num_tokens, dim) x (dim, n_local_heads * head_dim) // => (num_tokens, n_local_heads * head_dim) - const auto qkv = qkv_proj_(x); + const auto [q, k, v] = qkv_proj_(x); // calculate attention, // output: (num_tokens, n_local_heads*head_dim) - const auto output = - atten_(qkv[0], qkv[1], qkv[2], positions, kv_cache, input_params); + const auto output = atten_(q, k, v, positions, kv_cache, input_params); return o_proj_(output); } diff --git a/src/models/google/gemma2.h b/src/models/google/gemma2.h index 6b4e75d6..200da7c9 100644 --- a/src/models/google/gemma2.h +++ b/src/models/google/gemma2.h @@ -10,16 +10,15 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/linear.h" +#include "layers/linear/qkv_parallel_linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include 
"layers/module/module_list.h" #include "layers/normalization.h" -#include "layers/qkv_linear.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // Gemma2 model compatible with huggingface weight namespace llm::hf { @@ -39,7 +38,7 @@ class Gemma2MLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - FusedColumnParallelLinear( + MultiColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, std::vector{"gate_proj.", "up_proj."}, @@ -67,7 +66,7 @@ class Gemma2MLPImpl : public Module { private: // parameter members, must be registered - FusedColumnParallelLinear gate_up_proj_{nullptr}; + MultiColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; // activation function @@ -132,11 +131,10 @@ class Gemma2AttentionImpl : public Module { const InputParameters& input_params) { // (num_tokens, dim) x (dim, n_local_heads * head_dim) // => (num_tokens, n_local_heads * head_dim) - const auto qkv = qkv_proj_(x); + const auto [q, k, v] = qkv_proj_(x); // calculate attention, // output: (num_tokens, n_local_heads*head_dim) - const auto output = - atten_(qkv[0], qkv[1], qkv[2], positions, kv_cache, input_params); + const auto output = atten_(q, k, v, positions, kv_cache, input_params); return o_proj_(output); } diff --git a/src/models/meta/llama.h b/src/models/meta/llama.h index d5b3a17d..22a0cd39 100644 --- a/src/models/meta/llama.h +++ b/src/models/meta/llama.h @@ -8,17 +8,16 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/fused_linear.h" -#include "layers/linear.h" +#include "layers/linear/multi_parallel_linear.h" +#include "layers/linear/qkv_parallel_linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" -#include "layers/qkv_linear.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // llama2 model compatible with huggingface weights namespace llm::hf { @@ -37,7 +36,7 @@ class LlamaMLPImpl : public Module { // register the weight parameter gate_up_proj_ = register_module( "gate_up_proj", - FusedColumnParallelLinear( + MultiColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, std::vector{"gate_proj.", "up_proj."}, @@ -66,7 +65,7 @@ class LlamaMLPImpl : public Module { private: // parameter members, must be registered - FusedColumnParallelLinear gate_up_proj_{nullptr}; + MultiColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear down_proj_{nullptr}; // activation function @@ -127,10 +126,9 @@ class LlamaAttentionImpl : public Module { const InputParameters& input_params) { // (num_tokens, dim) x (dim, n_local_heads * head_dim) // => (num_tokens, n_local_heads * head_dim) - const auto qkv = qkv_proj_(x); + const auto [q, k, v] = qkv_proj_(x); // calculate attention, output: (num_tokens, n_local_heads * head_dim) - const auto output = - atten_(qkv[0], qkv[1], qkv[2], positions, kv_cache, input_params); + const auto output = atten_(q, k, v, positions, kv_cache, input_params); return o_proj_(output); } diff --git 
a/src/models/microsoft/phi.h b/src/models/microsoft/phi.h index f81ed123..1aa97bb8 100644 --- a/src/models/microsoft/phi.h +++ b/src/models/microsoft/phi.h @@ -6,15 +6,14 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // Phi model compatible with huggingface weights namespace llm::hf { diff --git a/src/models/model_registry.h b/src/models/model_registry.h index a15f7138..93590ad3 100644 --- a/src/models/model_registry.h +++ b/src/models/model_registry.h @@ -8,9 +8,9 @@ #include "chat_template/chat_template.h" #include "common/json_reader.h" #include "common/type_traits.h" // IWYU pragma: keep +#include "layers/quantization/quant_args.h" #include "model_args.h" #include "model_parallel/parallel_args.h" -#include "quantization/quant_args.h" #include "tokenizer/tokenizer_args.h" namespace llm { diff --git a/src/models/openai/gpt2.h b/src/models/openai/gpt2.h index 2644ec48..3cd6a167 100644 --- a/src/models/openai/gpt2.h +++ b/src/models/openai/gpt2.h @@ -7,16 +7,15 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" -#include "layers/linear.h" -#include "layers/linear_impl.h" +#include "layers/linear/parallel_linear.h" +#include "layers/module/module.h" +#include "layers/module/module_holder.h" +#include "layers/module/module_list.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" #include "models/model_registry.h" #include "models/parameters.h" -#include "module/module.h" -#include "module/module_holder.h" -#include "module/module_list.h" // gpt2 model compatible with huggingface weights