diff --git a/README.md b/README.md index 8d64eef0..14b8faa2 100644 --- a/README.md +++ b/README.md @@ -32,11 +32,12 @@ ScaleLLM
-[ScaleLLM](#) is a cutting-edge inference system engineered for large language models (LLMs), designed to meet the demands of production environments. It extends its support to a wide range of popular open-source models, including [Llama3.1](https://github.com/meta-llama/llama3), [Gemma2](https://github.com/google-deepmind/gemma), Bloom, GPT-NeoX, and more. +[ScaleLLM](#) is a cutting-edge inference system engineered for large language models (LLMs), designed to meet the demands of production environments. It extends its support to a wide range of popular open-source models, including [Llama3.1](https://github.com/meta-llama/llama3), [Gemma2](https://github.com/google-deepmind/gemma), [Phi](https://huggingface.co/microsoft/phi-2), and more. ScaleLLM is currently undergoing active development. We are fully committed to consistently enhancing its efficiency while also incorporating additional features. Feel free to explore our [**_Roadmap_**](https://github.com/vectorch-ai/ScaleLLM/issues/84) for more details. ## News: +* [01/2025] - Optimized in-house Attention kernels * [06/2024] - ScaleLLM is now available on [PyPI](https://pypi.org/project/scalellm/). You can install it using `pip install scalellm`. * [03/2024] - [Advanced features](#advanced-features) support for ✅ [CUDA graph](#cuda-graph), ✅ [prefix cache](#prefix-cache), ✅ [chunked prefill](#chunked-prefill) and ✅ [speculative decoding](#speculative-decoding). * [11/2023] - [First release](https://github.com/vectorch-ai/ScaleLLM/releases/tag/v0.0.1) with support for popular [open-source models](#supported-models). @@ -274,21 +275,11 @@ Quantization is a crucial process for reducing the memory footprint of models. 
S | Models | Tensor Parallel | Quantization | Chat API | HF models examples | | :--------: | :-------------: | :----------: | :------: | :---------------------------:| -| Aquila | Yes | Yes | Yes | [BAAI/Aquila-7B](https://huggingface.co/BAAI/Aquila-7B), [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B) | -| Bloom | Yes | Yes | No | [bigscience/bloom](https://huggingface.co/bigscience/bloom) | -| Baichuan | Yes | Yes | Yes | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) | -| ChatGLM4/3 | Yes | Yes | Yes | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b) | | Gemma2 | Yes | Yes | Yes | [google/gemma-2-2b](https://huggingface.co/google/gemma-2-2b) | -| GPT_j | Yes | Yes | No | [EleutherAI/gpt-j-6b](https://huggingface.co/EleutherAI/gpt-j-6b) | -| GPT_NeoX | Yes | Yes | No | [EleutherAI/gpt-neox-20b](https://huggingface.co/EleutherAI/gpt-neox-20b) | | GPT2 | Yes | Yes | No | [gpt2](https://huggingface.co/gpt2)| -| InternLM | Yes | Yes | Yes | [internlm/internlm-7b](https://huggingface.co/internlm/internlm-7b) | | Llama3/2 | Yes | Yes | Yes | [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct), [meta-llama/Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) | -| Mistral | Yes | Yes | Yes | [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) | -| MPT | Yes | Yes | Yes | [mosaicml/mpt-30b](https://huggingface.co/mosaicml/mpt-30b) | | Phi2 | Yes | Yes | No | [microsoft/phi-2](https://huggingface.co/microsoft/phi-2) | | Qwen2 | Yes | Yes | Yes | [Qwen/Qwen-72B-Chat](https://huggingface.co/Qwen/Qwen-72B-Chat) | -| Yi | Yes | Yes | Yes |[01-ai/Yi-6B](https://huggingface.co/01-ai/Yi-6B), [01-ai/Yi-34B-Chat-4bits](https://huggingface.co/01-ai/Yi-34B-Chat-4bits), [01-ai/Yi-6B-200K](https://huggingface.co/01-ai/Yi-6B-200K) | If your model is not included in the supported list, we are more than willing to assist you. 
Please feel free to create a request for adding a new model on [GitHub Issues](https://github.com/vectorch-ai/ScaleLLM/issues). diff --git a/src/layers/embedding.h b/src/layers/embedding.h index 9abe231d..9ef07033 100644 --- a/src/layers/embedding.h +++ b/src/layers/embedding.h @@ -33,26 +33,6 @@ class EmbeddingImpl : public Module { return F::embedding(input, weight_); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - const auto weight = state_dict.get_tensor("weight"); - if (weight.defined()) { - CHECK_EQ(weight_.sizes(), weight.sizes()) - << "weight size mismatch for " << name(); - weight_.copy_(weight); - is_loaded_ = true; - } - } - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix) const { - CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight"; - } - - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight_.sizes() << " " << weight_.device(); - } - // return the weight (for testing) torch::Tensor weight() const { return weight_; } @@ -73,6 +53,7 @@ class ParallelEmbeddingImpl : public Module { const ParallelArgs& parallel_args, const torch::TensorOptions& options) : parallel_args_(parallel_args) { + const auto rank = parallel_args_.rank(); const auto world_size = parallel_args_.world_size(); CHECK(embedding_dim % world_size == 0) << "out_features " << embedding_dim << " not divisible by world_size " @@ -80,8 +61,11 @@ class ParallelEmbeddingImpl : public Module { const int64_t embedding_dim_per_partition = embedding_dim / world_size; // register the weight parameter - weight_ = register_parameter( + weight_ = register_sharded_parameter( "weight", + /*dim=*/1, + rank, + world_size, torch::empty({num_embeddings, embedding_dim_per_partition}, options)); } @@ -96,30 +80,6 @@ class ParallelEmbeddingImpl : public Module { return output; } - // load the weight from the checkpoint - void load_state_dict(const StateDict& 
state_dict) { - const auto weight = state_dict.get_sharded_tensor( - "weight", - /*dim=*/1, - /*rank=*/parallel_args_.rank(), - /*world_size=*/parallel_args_.world_size()); - if (weight.defined()) { - CHECK_EQ(weight_.sizes(), weight.sizes()) - << "weight size mismatch for " << name(); - weight_.copy_(weight); - is_loaded_ = true; - } - } - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix) const { - CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight"; - } - - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight_.sizes() << " " << weight_.device(); - } - // return the weight (for testing) torch::Tensor weight() const { return weight_; } @@ -143,14 +103,19 @@ class VocabParallelEmbeddingImpl : public Module { const ParallelArgs& parallel_args, const torch::TensorOptions& options) : parallel_args_(parallel_args) { - const int64_t num_embeddings_per_partition = - num_embeddings / parallel_args_.world_size(); - start_index_ = num_embeddings_per_partition * parallel_args_.rank(); + const auto rank = parallel_args_.rank(); + const auto world_size = parallel_args_.world_size(); + + const int64_t num_embeddings_per_partition = num_embeddings / world_size; + start_index_ = num_embeddings_per_partition * rank; end_index_ = start_index_ + num_embeddings_per_partition; // register the weight parameter - weight_ = register_parameter( + weight_ = register_sharded_parameter( "weight", + /*dim=*/0, + rank, + world_size, torch::empty({num_embeddings_per_partition, embedding_dim}, options)); } @@ -174,30 +139,6 @@ class VocabParallelEmbeddingImpl : public Module { return reduce_from_model_parallel_region(output, parallel_args_); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - const auto weight = state_dict.get_sharded_tensor( - "weight", - /*dim=*/0, - /*rank=*/parallel_args_.rank(), - /*world_size=*/parallel_args_.world_size()); - if 
(weight.defined()) { - CHECK_EQ(weight_.sizes(), weight.sizes()) - << "weight size mismatch for " << name(); - weight_.copy_(weight); - is_loaded_ = true; - } - } - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix = "") const { - CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight"; - } - - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight_.sizes() << " " << weight_.device(); - } - // return the weight (for testing) torch::Tensor weight() const { return weight_; } diff --git a/src/layers/fused_linear.cpp b/src/layers/fused_linear.cpp index 917ab6ec..86b5ca59 100644 --- a/src/layers/fused_linear.cpp +++ b/src/layers/fused_linear.cpp @@ -13,11 +13,13 @@ namespace llm { FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl( int64_t in_features, const std::vector& out_features_vec, + const std::vector& prefixes, bool bias, bool gather_output, const QuantArgs& quant_args, const ParallelArgs& parallel_args, const torch::TensorOptions& options) { + prefixes_ = prefixes; // check if the linear layers can be fused fused_ = quant_args.can_be_fused(); if (fused_) { @@ -72,28 +74,29 @@ std::vector FusedColumnParallelLinearImpl::forward( return outputs; } -void FusedColumnParallelLinearImpl::load_state_dict( - const StateDict& state_dict, - const std::vector& prefixes) { +size_t FusedColumnParallelLinearImpl::load(const StateDict& state_dict, + const std::string&) { if (fused_) { - fused_linear_->load_state_dict(state_dict, prefixes); + fused_linear_->load_state_dict(state_dict, prefixes_); } else { - CHECK_EQ(parallel_linears_.size(), prefixes.size()); + CHECK_EQ(parallel_linears_.size(), prefixes_.size()); for (size_t i = 0; i < parallel_linears_.size(); ++i) { - parallel_linears_[i]->load_state_dict(state_dict.select(prefixes[i])); + parallel_linears_[i]->load_state_dict(state_dict.select(prefixes_[i])); } } + return 0; } -void 
FusedColumnParallelLinearImpl::verify_loaded_weights( - const std::string& prefix) const { +bool FusedColumnParallelLinearImpl::verify( + const std::string& name_prefix) const { if (fused_) { - fused_linear_->verify_loaded_weights(prefix); + fused_linear_->verify_loaded_weights(name_prefix); } else { for (const auto& parallel_linear : parallel_linears_) { - parallel_linear->verify_loaded_weights(prefix); + parallel_linear->verify_loaded_weights(name_prefix); } } + return true; } } // namespace llm diff --git a/src/layers/fused_linear.h b/src/layers/fused_linear.h index 191922e6..73323479 100644 --- a/src/layers/fused_linear.h +++ b/src/layers/fused_linear.h @@ -16,6 +16,7 @@ class FusedColumnParallelLinearImpl : public Module { public: FusedColumnParallelLinearImpl(int64_t in_features, const std::vector& out_features, + const std::vector& prefixes, bool bias, bool gather_output, const QuantArgs& quant_args, @@ -24,11 +25,13 @@ class FusedColumnParallelLinearImpl : public Module { std::vector forward(torch::Tensor input); - // load_state_dict for fused weights - void load_state_dict(const StateDict& state_dict, - const std::vector& prefixes); + // load weights from the checkpoint, override this method if necessary + // returns the number of loaded parameters + size_t load(const StateDict& state_dict, + const std::string& name_prefix = std::string()) override; - void verify_loaded_weights(const std::string& prefix = "") const; + // verify whether the weights are loaded, override this method if necessary + bool verify(const std::string& name_prefix = std::string()) const override; // whether the linear layer is fused bool fused() const { return fused_; } @@ -43,6 +46,8 @@ class FusedColumnParallelLinearImpl : public Module { // sizes for each split std::vector split_sizes_; + std::vector prefixes_; + // whether the linear layer is fused bool fused_ = false; }; diff --git a/src/layers/linear_impl.cpp b/src/layers/linear_impl.cpp index 3f301868..7b6d8f04 100644 --- 
a/src/layers/linear_impl.cpp +++ b/src/layers/linear_impl.cpp @@ -18,6 +18,7 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl( const ParallelArgs& parallel_args, const torch::TensorOptions& options) : gather_output_(gather_output), parallel_args_(parallel_args) { + const auto rank = parallel_args_.rank(); const auto world_size = parallel_args_.world_size(); CHECK(out_features % world_size == 0) << "out_features " << out_features << " not divisible by world_size " @@ -26,13 +27,20 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl( // Note: torch.nn.functional.linear performs XA^T + b and as a result // we allocate the transpose. - weight_ = register_parameter( + weight_ = register_sharded_parameter( "weight", + /*dim=*/0, + rank, + world_size, torch::empty({out_features_per_partition, in_features}, options)); if (bias) { - bias_ = register_parameter( - "bias", torch::empty({out_features_per_partition}, options)); + bias_ = register_sharded_parameter( + "bias", + /*dim=*/0, + rank, + world_size, + torch::empty({out_features_per_partition}, options)); } } @@ -93,14 +101,18 @@ RowParallelLinearImpl::RowParallelLinearImpl( const torch::TensorOptions& options) : input_is_parallelized_(input_is_parallelized), parallel_args_(parallel_args) { + const auto rank = parallel_args_.rank(); const auto world_size = parallel_args_.world_size(); CHECK(in_features % world_size == 0) << "in_features " << in_features << " not divisible by world_size " << world_size; const int64_t in_features_per_partition = in_features / world_size; // Allocate the transpose since linear performs XA^T. 
- weight_ = register_parameter( + weight_ = register_sharded_parameter( "weight", + /*dim=*/1, + rank, + world_size, torch::empty({out_features, in_features_per_partition}, options)); if (bias) { diff --git a/src/layers/linear_impl.h b/src/layers/linear_impl.h index fe1f6b7c..ff551649 100644 --- a/src/layers/linear_impl.h +++ b/src/layers/linear_impl.h @@ -42,10 +42,6 @@ class ColumnParallelLinearImpl : public ParallelLinearImpl { << "bias is not loaded for " << prefix + "bias"; } - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight_.sizes() << " " << weight_.device(); - } - // return the weight (for testing) torch::Tensor weight() const { return weight_; } @@ -95,10 +91,6 @@ class RowParallelLinearImpl : public ParallelLinearImpl { << "bias is not loaded for " << prefix + "bias"; } - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight_.sizes() << " " << weight_.device(); - } - // return the weight (for testing) torch::Tensor weight() const { return weight_; } diff --git a/src/layers/normalization.h b/src/layers/normalization.h index 8996ce78..da75a81b 100644 --- a/src/layers/normalization.h +++ b/src/layers/normalization.h @@ -94,38 +94,6 @@ class LayerNormImpl : public Module { input, normalized_shape_, weight_, bias_, eps_); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - const auto weight = state_dict.get_tensor("weight"); - if (weight.defined()) { - CHECK_EQ(weight_.sizes(), weight.sizes()) - << "weight size mismatch for " << name(); - weight_.copy_(weight); - weight_is_loaded_ = true; - } - if (bias_.defined()) { - const auto bias = state_dict.get_tensor("bias"); - if (bias.defined()) { - CHECK_EQ(bias_.sizes(), bias.sizes()) - << "bias size mismatch for " << name(); - bias_.copy_(bias); - bias_is_loaded_ = true; - } - } - } - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix = "") const { 
- CHECK(weight_is_loaded_) - << "weight is not loaded for " << prefix + "weight"; - CHECK(!bias_.defined() || bias_is_loaded_) - << "bias is not loaded for " << prefix + "bias"; - } - - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight_.sizes() << " " << weight_.device(); - } - private: // parameter members, must be registered torch::Tensor weight_{nullptr}; @@ -159,26 +127,6 @@ class RMSNormImpl : public Module { return detail::rms_norm(input, weight_, eps_); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - const auto weight = state_dict.get_tensor("weight"); - if (weight.defined()) { - CHECK_EQ(weight_.sizes(), weight.sizes()) - << "weight size mismatch for " << name(); - weight_.copy_(weight); - is_loaded_ = true; - } - } - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix = "") const { - CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight"; - } - - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight_.sizes() << " " << weight_.device(); - } - private: // parameter members, must be registered torch::Tensor weight_{nullptr}; @@ -207,26 +155,6 @@ class GemmaRMSNormImpl : public Module { return detail::gemma_rms_norm(input, weight_, eps_); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - const auto weight = state_dict.get_tensor("weight"); - if (weight.defined()) { - CHECK_EQ(weight_.sizes(), weight.sizes()) - << "weight size mismatch for " << name(); - weight_.copy_(weight); - is_loaded_ = true; - } - } - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix = "") const { - CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight"; - } - - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight_.sizes() << " " << weight_.device(); - } - private: 
// parameter members, must be registered torch::Tensor weight_{nullptr}; @@ -268,26 +196,6 @@ class RMSNormResidualImpl : public Module { return detail::rms_norm(input, weight_, eps_); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - const auto weight = state_dict.get_tensor("weight"); - if (weight.defined()) { - CHECK_EQ(weight_.sizes(), weight.sizes()) - << "weight size mismatch for " << name(); - weight_.copy_(weight); - is_loaded_ = true; - } - } - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix = "") const { - CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight"; - } - - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight_.sizes() << " " << weight_.device(); - } - private: // parameter members, must be registered torch::Tensor weight_{nullptr}; diff --git a/src/layers/normalization_test.cpp b/src/layers/normalization_test.cpp index ba25019b..c2aa4031 100644 --- a/src/layers/normalization_test.cpp +++ b/src/layers/normalization_test.cpp @@ -26,8 +26,8 @@ TEST(NormalizationTest, LayerNorm) { LayerNorm norm(dim, eps, /*bias=*/true, options); // test load state dict - norm->load_state_dict(state_dict); - norm->verify_loaded_weights(); + norm->load(state_dict); + EXPECT_TRUE(norm->verify()); // verify output const auto input = torch::randn({100, dim}); @@ -88,8 +88,8 @@ TEST(NormalizationTest, RMSNorm) { RMSNorm norm(dim, eps, options); // test load state dict - norm->load_state_dict(state_dict); - norm->verify_loaded_weights(); + norm->load(state_dict); + EXPECT_TRUE(norm->verify()); // verify output const auto input = torch::randn({100, dim}); diff --git a/src/layers/qkv_linear.cpp b/src/layers/qkv_linear.cpp index fe4eab14..8ef9b532 100644 --- a/src/layers/qkv_linear.cpp +++ b/src/layers/qkv_linear.cpp @@ -10,15 +10,21 @@ QKVColumnParallelLinearImpl::QKVColumnParallelLinearImpl( int64_t n_heads, int64_t n_kv_heads, 
int64_t head_dim, + const std::vector& prefixes, bool bias, bool gather_output, const QuantArgs& quant_args, const ParallelArgs& parallel_args, - const torch::TensorOptions& options) - : n_kv_heads_(n_kv_heads), head_dim_(head_dim) { + const torch::TensorOptions& options) { + CHECK_EQ(prefixes.size(), 3) + << "prefixes size must be 3 for q, k, v projections"; + // calculate logical kv heads with support of MQA/GQA const int32_t world_size = parallel_args.world_size(); int64_t effective_kv_heads = n_kv_heads; + // replication ratio of kv heads for MQA/GQA cases + int64_t kv_replication_ratio = 0; + if (n_kv_heads >= world_size) { // partition kv heads evenly across world_size for MHA CHECK_EQ(n_kv_heads % world_size, 0) @@ -27,7 +33,7 @@ QKVColumnParallelLinearImpl::QKVColumnParallelLinearImpl( // replicate kv heads evenly across world_size for GQA/MQA CHECK_EQ(world_size % n_kv_heads, 0) << "kv heads can't be replicated evenly across world_size"; - kv_replication_ratio_ = world_size / n_kv_heads; + kv_replication_ratio = world_size / n_kv_heads; effective_kv_heads = world_size; } @@ -36,44 +42,43 @@ QKVColumnParallelLinearImpl::QKVColumnParallelLinearImpl( effective_kv_heads * head_dim, effective_kv_heads * head_dim}; - parallel_linear_ = register_module("parallel_linear", - FusedColumnParallelLinear(hidden_size, - out_features, - bias, - gather_output, - quant_args, - parallel_args, - options)); -} - -// special load_state_dict for fused cases -void QKVColumnParallelLinearImpl::load_state_dict( - const StateDict& state_dict, - const std::vector& prefixes, - const std::vector& kv_prefixes) { - if (kv_replication_ratio_ > 1) { - // replicate kv heads - auto kv_replicated_state_dict = state_dict.select_with_transform( - "", [&](const std::string& name, const torch::Tensor& tensor) { - for (const auto& kv_prefix : kv_prefixes) { - if (absl::StartsWith(name, kv_prefix)) { + // create state_dict selector to handle MQA/GQA cases + // for MQA/GQA cases, we need to 
replicate the weights of kv heads. + auto state_dict_selector = [=](const StateDict& sd, const std::string&) { + if (kv_replication_ratio <= 1) { + return sd.select(""); + } + // replicate kv heads for MQA/GQA cases + return sd.select_with_transform( + "", [=](const std::string& tensor_name, const torch::Tensor& tensor) { + // skip query weights + for (size_t i = 1; i < prefixes.size(); ++i) { + const auto& kv_prefix = prefixes[i]; + if (absl::StartsWith(tensor_name, kv_prefix)) { // reshape to [n_kv_heads, head_dim, ...] - auto reshaped_tensor = - tensor.reshape({n_kv_heads_, head_dim_, -1}); + auto reshaped_tensor = tensor.reshape({n_kv_heads, head_dim, -1}); // interleave repeat kv heads along kv_head dim reshaped_tensor = reshaped_tensor.repeat_interleave( - kv_replication_ratio_, /*dim=*/0); + kv_replication_ratio, /*dim=*/0); // reshape to [n_kv_heads * kv_replication_ratio * head_dim, ...] return reshaped_tensor.reshape( - {n_kv_heads_ * kv_replication_ratio_ * head_dim_, -1}); + {n_kv_heads * kv_replication_ratio * head_dim, -1}); } } return tensor; }); - parallel_linear_->load_state_dict(kv_replicated_state_dict, prefixes); - } else { - parallel_linear_->load_state_dict(state_dict, prefixes); - } + }; + + parallel_linear_ = register_module("qkv_parallel_linear", + FusedColumnParallelLinear(hidden_size, + out_features, + prefixes, + bias, + gather_output, + quant_args, + parallel_args, + options), + state_dict_selector); } } // namespace llm diff --git a/src/layers/qkv_linear.h b/src/layers/qkv_linear.h index 2e4e2e39..64923e27 100644 --- a/src/layers/qkv_linear.h +++ b/src/layers/qkv_linear.h @@ -20,6 +20,7 @@ class QKVColumnParallelLinearImpl : public Module { int64_t n_heads, int64_t n_kv_heads, int64_t head_dim, + const std::vector& prefixes, bool bias, bool gather_output, const QuantArgs& quant_args, @@ -30,24 +31,9 @@ class QKVColumnParallelLinearImpl : public Module { return parallel_linear_->forward(input); } - // special load_state_dict for fused 
cases - void load_state_dict(const StateDict& state_dict, - const std::vector& prefixes, - const std::vector& kv_prefixes); - - void verify_loaded_weights(const std::string& prefix = "") const { - parallel_linear_->verify_loaded_weights(prefix); - } - private: + // registered modules FusedColumnParallelLinear parallel_linear_{nullptr}; - - // replication ratio of kv heads for MQA/GQA cases - int64_t kv_replication_ratio_ = 0; - - int64_t n_kv_heads_ = 0; - - int64_t head_dim_ = 0; }; LLM_MODULE(QKVColumnParallelLinear); diff --git a/src/layers/qkv_linear_test.cpp b/src/layers/qkv_linear_test.cpp index 359b68e3..86ae2f6b 100644 --- a/src/layers/qkv_linear_test.cpp +++ b/src/layers/qkv_linear_test.cpp @@ -9,7 +9,7 @@ namespace llm { -class QKVLinearTest +class QKVColumnParallelLinearTest : public ::testing::TestWithParam> {}; -TEST_P(QKVLinearTest, LoadFusedWeight) { +TEST_P(QKVColumnParallelLinearTest, LoadFusedWeight) { const auto& [n_tokens, n_heads, n_kv_heads, n_shards, head_dim, hidden_size] = GetParam(); @@ -51,18 +51,19 @@ TEST_P(QKVLinearTest, LoadFusedWeight) { for (int32_t shard_id = 0; shard_id < n_shards; ++shard_id) { QuantArgs quant_args; ParallelArgs parallel_args(shard_id, n_shards, nullptr); - QKVColumnParallelLinearImpl linear(hidden_size, - n_heads, - n_kv_heads, - head_dim, - /*bias=*/false, - /*gather_output=*/false, - quant_args, - parallel_args, - options); - linear.load_state_dict(state_dict, - /*prefixes=*/{"query.", "key.", "value."}, - /*kv_prefixes=*/{"key.", "value."}); + QKVColumnParallelLinearImpl linear( + hidden_size, + n_heads, + n_kv_heads, + head_dim, + std::vector{"query.", "key.", "value."}, + /*bias=*/false, + /*gather_output=*/false, + quant_args, + parallel_args, + options); + linear.load(state_dict); + EXPECT_TRUE(linear.verify()); // generate random input and compare with the output auto input = torch::randn({n_tokens, hidden_size}, options); @@ -84,7 +85,7 @@ TEST_P(QKVLinearTest, LoadFusedWeight) { 
INSTANTIATE_TEST_SUITE_P( QKVLinearTestSuite, - QKVLinearTest, + QKVColumnParallelLinearTest, ::testing::Combine(::testing::Values(10, 32), // n_tokens ::testing::Values(8, 16, 32), // n_heads ::testing::Values(1, 2, 4, 8), // n_kv_heads diff --git a/src/model_loader/state_dict.cpp b/src/model_loader/state_dict.cpp index bdfc96b9..f20f9109 100644 --- a/src/model_loader/state_dict.cpp +++ b/src/model_loader/state_dict.cpp @@ -219,4 +219,14 @@ StateDict StateDict::select_with_transform( return selected; } +StateDict StateDict::select_with_transform( + const std::string& prefix, + std::function transform_func) const { + return select_with_transform( + prefix, + [transform_func](const std::string&, const torch::Tensor& tensor) { + return transform_func(tensor); + }); +} + } // namespace llm diff --git a/src/model_loader/state_dict.h b/src/model_loader/state_dict.h index 58c51039..d130c53e 100644 --- a/src/model_loader/state_dict.h +++ b/src/model_loader/state_dict.h @@ -46,6 +46,10 @@ class StateDict final { StateDict select_with_transform(const std::string& prefix, TensorTransform transform_func) const; + StateDict select_with_transform( + const std::string& prefix, + std::function transform_func) const; + size_t size() const { return dict_.size(); } std::string_view prefix() const { return prefix_; } diff --git a/src/models/aquila.h b/src/models/_deprecated/aquila.h similarity index 100% rename from src/models/aquila.h rename to src/models/_deprecated/aquila.h diff --git a/src/models/baichuan.h b/src/models/_deprecated/baichuan.h similarity index 100% rename from src/models/baichuan.h rename to src/models/_deprecated/baichuan.h diff --git a/src/models/bloom.h b/src/models/_deprecated/bloom.h similarity index 100% rename from src/models/bloom.h rename to src/models/_deprecated/bloom.h diff --git a/src/models/chatglm.h b/src/models/_deprecated/chatglm.h similarity index 100% rename from src/models/chatglm.h rename to src/models/_deprecated/chatglm.h diff --git 
a/src/models/gpt_j.h b/src/models/_deprecated/gpt_j.h similarity index 100% rename from src/models/gpt_j.h rename to src/models/_deprecated/gpt_j.h diff --git a/src/models/gpt_neox.h b/src/models/_deprecated/gpt_neox.h similarity index 100% rename from src/models/gpt_neox.h rename to src/models/_deprecated/gpt_neox.h diff --git a/src/models/internlm.h b/src/models/_deprecated/internlm.h similarity index 100% rename from src/models/internlm.h rename to src/models/_deprecated/internlm.h diff --git a/src/models/mistral.h b/src/models/_deprecated/mistral.h similarity index 100% rename from src/models/mistral.h rename to src/models/_deprecated/mistral.h diff --git a/src/models/mpt.h b/src/models/_deprecated/mpt.h similarity index 100% rename from src/models/mpt.h rename to src/models/_deprecated/mpt.h diff --git a/src/models/simple_model.h b/src/models/_deprecated/simple_model.h similarity index 100% rename from src/models/simple_model.h rename to src/models/_deprecated/simple_model.h diff --git a/src/models/qwen.h b/src/models/alibaba/qwen.h similarity index 77% rename from src/models/qwen.h rename to src/models/alibaba/qwen.h index d1867e4e..61cf7490 100644 --- a/src/models/qwen.h +++ b/src/models/alibaba/qwen.h @@ -10,6 +10,7 @@ #include "layers/attention/attention.h" #include "layers/attention/handler.h" #include "layers/embedding.h" +#include "layers/fused_linear.h" #include "layers/linear.h" #include "layers/normalization.h" #include "memory/kv_cache.h" @@ -20,7 +21,7 @@ #include "module/module_holder.h" #include "module/module_list.h" // QWen model compatible with huggingface weights -// adopted from https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py +// Adapted from https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py namespace llm::hf { class QWenMLPImpl : public Module { @@ -38,14 +39,18 @@ class QWenMLPImpl : public Module { const int64_t intermediate_size = args.intermediate_size() / 2; // register the weight parameter - w1_w2_proj_ = 
register_module("gate_up_proj", - ColumnParallelLinear(hidden_size, - intermediate_size * 2, - /*bias=*/false, - /*gather_output=*/false, - quant_args, - parallel_args, - options)); + gate_up_proj_ = register_module( + "gate_up_proj", + FusedColumnParallelLinear( + hidden_size, + std::vector{intermediate_size, intermediate_size}, + std::vector{"w1.", "w2."}, + /*bias=*/false, + /*gather_output=*/false, + quant_args, + parallel_args, + options), + /*selector=*/nullptr); c_proj_ = register_module("c_proj", RowParallelLinear(intermediate_size, hidden_size, @@ -57,26 +62,13 @@ class QWenMLPImpl : public Module { } torch::Tensor forward(torch::Tensor x) { - auto gate_up_proj = w1_w2_proj_(x); - auto chunks = gate_up_proj.chunk(/*chunks=*/2, /*dim=*/-1); - return c_proj_(chunks[0] * act_(chunks[1])); - } - - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - w1_w2_proj_->load_state_dict(state_dict, {"w1.", "w2."}); - c_proj_->load_state_dict(state_dict.select("c_proj.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - w1_w2_proj_->verify_loaded_weights(prefix + "[w1,w2]."); - c_proj_->verify_loaded_weights(prefix + "c_proj."); + const auto gate_up = gate_up_proj_(x); + return c_proj_(gate_up[0] * act_(gate_up[1])); } private: // parameter members, must be registered - ColumnParallelLinear w1_w2_proj_{nullptr}; + FusedColumnParallelLinear gate_up_proj_{nullptr}; RowParallelLinear c_proj_{nullptr}; ActFunc act_{nullptr}; @@ -133,18 +125,6 @@ class QWenAttentionImpl : public Module { return c_proj_(output); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - c_attn_->load_state_dict(state_dict.select("c_attn.")); - c_proj_->load_state_dict(state_dict.select("c_proj.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - 
c_attn_->verify_loaded_weights(prefix + "c_attn."); - c_proj_->verify_loaded_weights(prefix + "c_proj."); - } - private: // parameter members, must be registered ColumnParallelLinear c_attn_{nullptr}; @@ -183,22 +163,6 @@ class QWenBlockImpl : public Module { return h + mlp_(ln_2_(h)); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - attn_->load_state_dict(state_dict.select("attn.")); - mlp_->load_state_dict(state_dict.select("mlp.")); - ln_1_->load_state_dict(state_dict.select("ln_1.")); - ln_2_->load_state_dict(state_dict.select("ln_2.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - attn_->verify_loaded_weights(prefix + "attn."); - mlp_->verify_loaded_weights(prefix + "mlp."); - ln_1_->verify_loaded_weights(prefix + "ln_1."); - ln_2_->verify_loaded_weights(prefix + "ln_2."); - } - private: // parameter members, must be registered QWenAttention attn_{nullptr}; @@ -226,7 +190,7 @@ class QWenModelImpl : public Module { handler_ = AttentionHandler::create_handler_with_rope( args, /*interleaved=*/false, options); - blocks_ = register_module("layers", ModuleList()); + blocks_ = register_module("h", ModuleList()); layers_.reserve(args.n_layers()); for (int32_t i = 0; i < args.n_layers(); i++) { auto block = @@ -254,26 +218,6 @@ class QWenModelImpl : public Module { return ln_f_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - wte_->load_state_dict(state_dict.select("wte.")); - // call each layer's load_state_dict function - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->load_state_dict( - state_dict.select("h." 
+ std::to_string(i) + ".")); - } - ln_f_->load_state_dict(state_dict.select("ln_f.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - wte_->verify_loaded_weights(prefix + "wte."); - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->verify_loaded_weights(prefix + "h." + std::to_string(i) + - "."); - } - ln_f_->verify_loaded_weights(prefix + "ln_f."); - } - private: // parameter members, must be registered ParallelEmbedding wte_{nullptr}; @@ -331,17 +275,6 @@ class QWenForCausalLMImpl : public Module { return lm_head_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - transformer_->load_state_dict(state_dict.select("transformer.")); - lm_head_->load_state_dict(state_dict.select("lm_head.")); - } - - void verify_loaded_weights() const { - transformer_->verify_loaded_weights("transformer."); - lm_head_->verify_loaded_weights("lm_head."); - } - private: // parameter members, must be registered QWenModel transformer_{nullptr}; diff --git a/src/models/qwen2.h b/src/models/alibaba/qwen2.h similarity index 78% rename from src/models/qwen2.h rename to src/models/alibaba/qwen2.h index d93de88a..84758c2d 100644 --- a/src/models/qwen2.h +++ b/src/models/alibaba/qwen2.h @@ -43,11 +43,13 @@ class QWen2MLPImpl : public Module { FusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, + std::vector{"gate_proj.", "up_proj."}, /*bias=*/false, /*gather_output=*/false, quant_args, parallel_args, - options)); + options), + /*selector=*/nullptr); down_proj_ = register_module("down_proj", RowParallelLinear(intermediate_size, @@ -64,18 +66,6 @@ class QWen2MLPImpl : public Module { return down_proj_(act_func_(gate_up[0]) * gate_up[1]); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - gate_up_proj_->load_state_dict(state_dict, {"gate_proj.", "up_proj."}); - 
down_proj_->load_state_dict(state_dict.select("down_proj.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - gate_up_proj_->verify_loaded_weights(prefix + "[gate_proj,up_proj]."); - down_proj_->verify_loaded_weights(prefix + "down_proj."); - } - private: // parameter members, must be registered FusedColumnParallelLinear gate_up_proj_{nullptr}; @@ -104,16 +94,20 @@ class QWen2AttentionImpl : public Module { std::max(1, n_kv_heads / world_size); // register submodules - qkv_proj_ = register_module("qkv_proj", - QKVColumnParallelLinear(hidden_size, - n_heads, - n_kv_heads, - head_dim, - /*bias=*/true, - /*gather_output=*/false, - quant_args, - parallel_args, - options)); + qkv_proj_ = register_module( + "qkv_proj", + QKVColumnParallelLinear( + hidden_size, + n_heads, + n_kv_heads, + head_dim, + std::vector{"q_proj.", "k_proj.", "v_proj."}, + /*bias=*/true, + /*gather_output=*/false, + quant_args, + parallel_args, + options), + /*selector=*/nullptr); o_proj_ = register_module("o_proj", RowParallelLinear(hidden_size, @@ -146,19 +140,6 @@ class QWen2AttentionImpl : public Module { return o_proj_(output); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - qkv_proj_->load_state_dict( - state_dict, {"q_proj.", "k_proj.", "v_proj."}, {"k_proj.", "v_proj."}); - o_proj_->load_state_dict(state_dict.select("o_proj.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - qkv_proj_->verify_loaded_weights(prefix + "[q_proj,k_proj,v_proj]."); - o_proj_->verify_loaded_weights(prefix + "o_proj."); - } - private: // parameter members, must be registered QKVColumnParallelLinear qkv_proj_{nullptr}; @@ -208,24 +189,6 @@ class QWen2DecoderLayerImpl : public Module { return hidden_states; } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - 
self_attn_->load_state_dict(state_dict.select("self_attn.")); - mlp_->load_state_dict(state_dict.select("mlp.")); - input_layernorm_->load_state_dict(state_dict.select("input_layernorm.")); - post_attention_layernorm_->load_state_dict( - state_dict.select("post_attention_layernorm.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - self_attn_->verify_loaded_weights(prefix + "self_attn."); - mlp_->verify_loaded_weights(prefix + "mlp."); - input_layernorm_->verify_loaded_weights(prefix + "input_layernorm."); - post_attention_layernorm_->verify_loaded_weights( - prefix + "post_attention_layernorm."); - } - private: // parameter members, must be registered QWen2Attention self_attn_{nullptr}; @@ -291,26 +254,6 @@ class QWen2ModelImpl : public Module { return norm_(h, residual); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - embed_tokens_->load_state_dict(state_dict.select("embed_tokens.")); - // call each layer's load_state_dict function - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->load_state_dict( - state_dict.select("layers." + std::to_string(i) + ".")); - } - norm_->load_state_dict(state_dict.select("norm.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - embed_tokens_->verify_loaded_weights(prefix + "embed_tokens."); - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->verify_loaded_weights(prefix + "layers." 
+ std::to_string(i) + - "."); - } - norm_->verify_loaded_weights(prefix + "norm."); - } - private: // parameter members, must be registered ParallelEmbedding embed_tokens_{nullptr}; @@ -368,17 +311,6 @@ class QWen2ForCausalLMImpl : public Module { return lm_head_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - model_->load_state_dict(state_dict.select("model.")); - lm_head_->load_state_dict(state_dict.select("lm_head.")); - } - - void verify_loaded_weights() const { - model_->verify_loaded_weights("model."); - lm_head_->verify_loaded_weights("lm_head."); - } - private: // parameter members, must be registered QWen2Model model_{nullptr}; diff --git a/src/models/causal_lm.h b/src/models/causal_lm.h index 02a54e29..eaed0ffe 100644 --- a/src/models/causal_lm.h +++ b/src/models/causal_lm.h @@ -70,11 +70,15 @@ class CausalLMImpl : public CausalLM { } void load_state_dict(const StateDict& state_dict) override { - model_->load_state_dict(state_dict); + model_->load(state_dict); } void verify_loaded_weights() const override { - return model_->verify_loaded_weights(); + bool success = model_->verify(); + if (!success) { + LOG(FATAL) << "Failed to verify loaded weights for the model." 
+ << " Please check the log for more details."; + } } torch::Device device() const override { return options_.device(); } diff --git a/src/models/deepseek/README.md b/src/models/deepseek/README.md new file mode 100644 index 00000000..bc60bf4b --- /dev/null +++ b/src/models/deepseek/README.md @@ -0,0 +1 @@ +TODO: diff --git a/src/models/gemma.h b/src/models/google/gemma.h similarity index 77% rename from src/models/gemma.h rename to src/models/google/gemma.h index dcb2c692..7934c9b6 100644 --- a/src/models/gemma.h +++ b/src/models/google/gemma.h @@ -11,6 +11,7 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear.h" +#include "layers/linear_impl.h" #include "layers/normalization.h" #include "layers/qkv_linear.h" #include "memory/kv_cache.h" @@ -42,11 +43,13 @@ class GemmaMLPImpl : public Module { FusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, + std::vector{"gate_proj.", "up_proj."}, /*bias=*/false, /*gather_output=*/false, quant_args, parallel_args, - options)); + options), + /*selector=*/nullptr); down_proj_ = register_module("down_proj", RowParallelLinear(intermediate_size, @@ -63,18 +66,6 @@ class GemmaMLPImpl : public Module { return down_proj_(act_func_(gate_up[0]) * gate_up[1]); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - gate_up_proj_->load_state_dict(state_dict, {"gate_proj.", "up_proj."}); - down_proj_->load_state_dict(state_dict.select("down_proj.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - gate_up_proj_->verify_loaded_weights(prefix + "[gate_proj,up_proj]."); - down_proj_->verify_loaded_weights(prefix + "down_proj."); - } - private: // parameter members, must be registered FusedColumnParallelLinear gate_up_proj_{nullptr}; @@ -102,16 +93,20 @@ class GemmaAttentionImpl : public Module { std::max(1, n_kv_heads / world_size); // register 
submodules - qkv_proj_ = register_module("qkv_proj", - QKVColumnParallelLinear(hidden_size, - n_heads, - n_kv_heads, - head_dim, - /*bias=*/false, - /*gather_output=*/false, - quant_args, - parallel_args, - options)); + qkv_proj_ = register_module( + "qkv_proj", + QKVColumnParallelLinear( + hidden_size, + n_heads, + n_kv_heads, + head_dim, + std::vector{"q_proj.", "k_proj.", "v_proj."}, + /*bias=*/false, + /*gather_output=*/false, + quant_args, + parallel_args, + options), + /*selector=*/nullptr); o_proj_ = register_module("o_proj", RowParallelLinear(n_heads * head_dim, @@ -141,19 +136,6 @@ class GemmaAttentionImpl : public Module { return o_proj_(output); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - qkv_proj_->load_state_dict( - state_dict, {"q_proj.", "k_proj.", "v_proj."}, {"k_proj.", "v_proj."}); - o_proj_->load_state_dict(state_dict.select("o_proj.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - qkv_proj_->verify_loaded_weights(prefix + "[q_proj,k_proj,v_proj]."); - o_proj_->verify_loaded_weights(prefix + "o_proj."); - } - private: // parameter members, must be registered QKVColumnParallelLinear qkv_proj_{nullptr}; @@ -207,21 +189,6 @@ class GemmaDecoderLayerImpl : public Module { hidden_states += residual; return hidden_states; } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - input_layernorm_->load_state_dict((state_dict.select("input_layernorm."))); - mlp_->load_state_dict(state_dict.select("mlp.")); - post_attention_layernorm_->load_state_dict( - (state_dict.select("post_attention_layernorm."))); - self_attn_->load_state_dict(state_dict.select("self_attn.")); - } - void verify_loaded_weights(const std::string& prefix) const { - self_attn_->verify_loaded_weights(prefix + "self_attn."); - mlp_->verify_loaded_weights(prefix + "mlp."); - 
input_layernorm_->verify_loaded_weights(prefix + "input_layernorm."); - post_attention_layernorm_->verify_loaded_weights( - prefix + "post_attention_layernorm."); - } private: GemmaAttention self_attn_{nullptr}; @@ -286,26 +253,6 @@ class GemmaModelImpl : public Module { return norm_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - embed_tokens_->load_state_dict(state_dict.select("embed_tokens.")); - // call each layer's load_state_dict function - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->load_state_dict( - state_dict.select("layers." + std::to_string(i) + ".")); - } - norm_->load_state_dict((state_dict.select("norm."))); - } - - void verify_loaded_weights(const std::string& prefix) const { - embed_tokens_->verify_loaded_weights(prefix + "embed_tokens."); - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) + - "."); - } - norm_->verify_loaded_weights(prefix + "norm."); - } - private: ModelArgs modelArgs_; @@ -337,7 +284,7 @@ class GemmaForCausalLMImpl : public Module { model_ = register_module( "model", GemmaModel(args, quant_args, parallel_args, options)); - lm_head_ = register_module("lm_head", + lm_head_ = register_module("model.embed_tokens", ColumnParallelLinear(args.hidden_size(), args.vocab_size(), /*bias=*/false, @@ -370,19 +317,6 @@ class GemmaForCausalLMImpl : public Module { return lm_head_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - model_->load_state_dict(state_dict.select("model.")); - - // Share the embedding weights with the final llm_head layer. 
- lm_head_->load_state_dict(state_dict.select("model.embed_tokens.")); - } - - void verify_loaded_weights() const { - model_->verify_loaded_weights("model."); - lm_head_->verify_loaded_weights("model.embed_tokens."); - } - private: // parameter members, must be registered GemmaModel model_{nullptr}; diff --git a/src/models/gemma2.h b/src/models/google/gemma2.h similarity index 76% rename from src/models/gemma2.h rename to src/models/google/gemma2.h index 6282af6e..6b4e75d6 100644 --- a/src/models/gemma2.h +++ b/src/models/google/gemma2.h @@ -42,11 +42,13 @@ class Gemma2MLPImpl : public Module { FusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, + std::vector{"gate_proj.", "up_proj."}, /*bias=*/false, /*gather_output=*/false, quant_args, parallel_args, - options)); + options), + /*selector=*/nullptr); down_proj_ = register_module("down_proj", RowParallelLinear(intermediate_size, @@ -63,18 +65,6 @@ class Gemma2MLPImpl : public Module { return down_proj_(act_func_(gate_up[0]) * gate_up[1]); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - gate_up_proj_->load_state_dict(state_dict, {"gate_proj.", "up_proj."}); - down_proj_->load_state_dict(state_dict.select("down_proj.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - gate_up_proj_->verify_loaded_weights(prefix + "[gate_proj,up_proj]."); - down_proj_->verify_loaded_weights(prefix + "down_proj."); - } - private: // parameter members, must be registered FusedColumnParallelLinear gate_up_proj_{nullptr}; @@ -103,16 +93,20 @@ class Gemma2AttentionImpl : public Module { std::max(1, n_kv_heads / world_size); // register submodules - qkv_proj_ = register_module("qkv_proj", - QKVColumnParallelLinear(hidden_size, - n_heads, - n_kv_heads, - head_dim, - args.attn_bias(), - /*gather_output=*/false, - quant_args, - parallel_args, - options)); + qkv_proj_ = 
register_module( + "qkv_proj", + QKVColumnParallelLinear( + hidden_size, + n_heads, + n_kv_heads, + head_dim, + std::vector{"q_proj.", "k_proj.", "v_proj."}, + args.attn_bias(), + /*gather_output=*/false, + quant_args, + parallel_args, + options), + /*selector=*/nullptr); o_proj_ = register_module("o_proj", RowParallelLinear(n_heads * head_dim, @@ -146,19 +140,6 @@ class Gemma2AttentionImpl : public Module { return o_proj_(output); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - qkv_proj_->load_state_dict( - state_dict, {"q_proj.", "k_proj.", "v_proj."}, {"k_proj.", "v_proj."}); - o_proj_->load_state_dict(state_dict.select("o_proj.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - qkv_proj_->verify_loaded_weights(prefix + "[q_proj,k_proj,v_proj]."); - o_proj_->verify_loaded_weights(prefix + "o_proj."); - } - private: // parameter members, must be registered QKVColumnParallelLinear qkv_proj_{nullptr}; @@ -167,9 +148,6 @@ class Gemma2AttentionImpl : public Module { // module members without parameters Attention atten_{nullptr}; - - // size for q, k, v - std::vector qkv_sizes_; }; LLM_MODULE(Gemma2Attention); @@ -226,29 +204,6 @@ class Gemma2DecoderLayerImpl : public Module { hidden_states += residual; return hidden_states; } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - input_layernorm_->load_state_dict((state_dict.select("input_layernorm."))); - self_attn_->load_state_dict(state_dict.select("self_attn.")); - mlp_->load_state_dict(state_dict.select("mlp.")); - post_attention_layernorm_->load_state_dict( - (state_dict.select("post_attention_layernorm."))); - pre_feedforward_layernorm_->load_state_dict( - state_dict.select("pre_feedforward_layernorm.")); - post_feedforward_layernorm_->load_state_dict( - state_dict.select("post_feedforward_layernorm.")); - } - void verify_loaded_weights(const 
std::string& prefix) const { - input_layernorm_->verify_loaded_weights(prefix + "input_layernorm."); - self_attn_->verify_loaded_weights(prefix + "self_attn."); - mlp_->verify_loaded_weights(prefix + "mlp."); - post_attention_layernorm_->verify_loaded_weights( - prefix + "post_attention_layernorm."); - pre_feedforward_layernorm_->verify_loaded_weights( - prefix + "pre_feedforward_layernorm."); - post_feedforward_layernorm_->verify_loaded_weights( - prefix + "post_feedforward_layernorm."); - } private: Gemma2Attention self_attn_{nullptr}; @@ -322,26 +277,6 @@ class Gemma2ModelImpl : public Module { return norm_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - embed_tokens_->load_state_dict(state_dict.select("embed_tokens.")); - // call each layer's load_state_dict function - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->load_state_dict( - state_dict.select("layers." + std::to_string(i) + ".")); - } - norm_->load_state_dict((state_dict.select("norm."))); - } - - void verify_loaded_weights(const std::string& prefix) const { - embed_tokens_->verify_loaded_weights(prefix + "embed_tokens."); - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->verify_loaded_weights(prefix + "layers." 
+ std::to_string(i) + - "."); - } - norm_->verify_loaded_weights(prefix + "norm."); - } - private: ModelArgs modelArgs_; @@ -375,7 +310,7 @@ class Gemma2ForCausalLMImpl : public Module { model_ = register_module( "model", Gemma2Model(args, quant_args, parallel_args, options)); - lm_head_ = register_module("lm_head", + lm_head_ = register_module("model.embed_tokens", ColumnParallelLinear(args.hidden_size(), args.vocab_size(), /*bias=*/false, @@ -409,19 +344,6 @@ class Gemma2ForCausalLMImpl : public Module { final_logit_soft_cap_; } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - model_->load_state_dict(state_dict.select("model.")); - - // Share the embedding weights with the final llm_head layer. - lm_head_->load_state_dict(state_dict.select("model.embed_tokens.")); - } - - void verify_loaded_weights() const { - model_->verify_loaded_weights("model."); - lm_head_->verify_loaded_weights("model.embed_tokens."); - } - private: // parameter members, must be registered Gemma2Model model_{nullptr}; diff --git a/src/models/llama.h b/src/models/meta/llama.h similarity index 80% rename from src/models/llama.h rename to src/models/meta/llama.h index 120ff48d..d5b3a17d 100644 --- a/src/models/llama.h +++ b/src/models/meta/llama.h @@ -40,11 +40,13 @@ class LlamaMLPImpl : public Module { FusedColumnParallelLinear( hidden_size, std::vector{intermediate_size, intermediate_size}, + std::vector{"gate_proj.", "up_proj."}, /*bias=*/false, /*gather_output=*/false, quant_args, parallel_args, - options)); + options), + /*selector=*/nullptr); down_proj_ = register_module("down_proj", @@ -62,18 +64,6 @@ class LlamaMLPImpl : public Module { return down_proj_(act_func_(gate_up[0]) * gate_up[1]); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - gate_up_proj_->load_state_dict(state_dict, {"gate_proj.", "up_proj."}); - 
down_proj_->load_state_dict(state_dict.select("down_proj.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - gate_up_proj_->verify_loaded_weights(prefix + "[gate_proj,up_proj]."); - down_proj_->verify_loaded_weights(prefix + "down_proj."); - } - private: // parameter members, must be registered FusedColumnParallelLinear gate_up_proj_{nullptr}; @@ -101,16 +91,20 @@ class LlamaAttentionImpl : public Module { std::max(1, n_kv_heads / world_size); // register submodules - qkv_proj_ = register_module("qkv_proj", - QKVColumnParallelLinear(hidden_size, - n_heads, - n_kv_heads, - head_dim, - /*bias=*/false, - /*gather_output=*/false, - quant_args, - parallel_args, - options)); + qkv_proj_ = register_module( + "qkv_proj", + QKVColumnParallelLinear( + hidden_size, + n_heads, + n_kv_heads, + head_dim, + std::vector{"q_proj.", "k_proj.", "v_proj."}, + /*bias=*/false, + /*gather_output=*/false, + quant_args, + parallel_args, + options), + /*selector=*/nullptr); o_proj_ = register_module("o_proj", RowParallelLinear(hidden_size, @@ -140,19 +134,6 @@ class LlamaAttentionImpl : public Module { return o_proj_(output); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - qkv_proj_->load_state_dict( - state_dict, {"q_proj.", "k_proj.", "v_proj."}, {"k_proj.", "v_proj."}); - o_proj_->load_state_dict(state_dict.select("o_proj.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - qkv_proj_->verify_loaded_weights(prefix + "[q_proj,k_proj,v_proj]."); - o_proj_->verify_loaded_weights(prefix + "o_proj."); - } - private: // parameter members, must be registered QKVColumnParallelLinear qkv_proj_{nullptr}; @@ -197,24 +178,6 @@ class LlamaDecoderLayerImpl : public Module { return h + mlp_(post_attention_layernorm_(h)); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's 
load_state_dict function - self_attn_->load_state_dict(state_dict.select("self_attn.")); - mlp_->load_state_dict(state_dict.select("mlp.")); - input_layernorm_->load_state_dict(state_dict.select("input_layernorm.")); - post_attention_layernorm_->load_state_dict( - state_dict.select("post_attention_layernorm.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - self_attn_->verify_loaded_weights(prefix + "self_attn."); - mlp_->verify_loaded_weights(prefix + "mlp."); - input_layernorm_->verify_loaded_weights(prefix + "input_layernorm."); - post_attention_layernorm_->verify_loaded_weights( - prefix + "post_attention_layernorm."); - } - private: // parameter members, must be registered LlamaAttention self_attn_{nullptr}; @@ -270,26 +233,6 @@ class LlamaModelImpl : public Module { return norm_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - embed_tokens_->load_state_dict(state_dict.select("embed_tokens.")); - // call each layer's load_state_dict function - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->load_state_dict( - state_dict.select("layers." + std::to_string(i) + ".")); - } - norm_->load_state_dict(state_dict.select("norm.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - embed_tokens_->verify_loaded_weights(prefix + "embed_tokens."); - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->verify_loaded_weights(prefix + "layers." 
+ std::to_string(i) + - "."); - } - norm_->verify_loaded_weights(prefix + "norm."); - } - private: // parameter members, must be registered ParallelEmbedding embed_tokens_{nullptr}; @@ -347,17 +290,6 @@ class LlamaForCausalLMImpl : public Module { return lm_head_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - model_->load_state_dict(state_dict.select("model.")); - lm_head_->load_state_dict(state_dict.select("lm_head.")); - } - - void verify_loaded_weights() const { - model_->verify_loaded_weights("model."); - lm_head_->verify_loaded_weights("lm_head."); - } - private: // parameter members, must be registered LlamaModel model_{nullptr}; diff --git a/src/models/phi.h b/src/models/microsoft/phi.h similarity index 80% rename from src/models/phi.h rename to src/models/microsoft/phi.h index 99e8dc07..f81ed123 100644 --- a/src/models/phi.h +++ b/src/models/microsoft/phi.h @@ -52,18 +52,6 @@ class PhiMLPImpl : public Module { torch::Tensor forward(torch::Tensor x) { return fc2_(act_(fc1_(x))); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - fc1_->load_state_dict(state_dict.select("fc1.")); - fc2_->load_state_dict(state_dict.select("fc2.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - fc1_->verify_loaded_weights(prefix + "fc1."); - fc2_->verify_loaded_weights(prefix + "fc2."); - } - private: // parameter members, must be registered ColumnParallelLinear fc1_{nullptr}; @@ -134,18 +122,6 @@ class PhiAttentionImpl : public Module { return out_proj_(output); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - Wqkv_->load_state_dict(state_dict.select("Wqkv.")); - out_proj_->load_state_dict(state_dict.select("out_proj.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - 
Wqkv_->verify_loaded_weights(prefix + "Wqkv."); - out_proj_->verify_loaded_weights(prefix + "out_proj."); - } - private: // parameter members, must be registered ColumnParallelLinear Wqkv_{nullptr}; @@ -191,20 +167,6 @@ class PhiBlockImpl : public Module { return x + attn_output + mlp_output; } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - mixer_->load_state_dict(state_dict.select("mixer.")); - mlp_->load_state_dict(state_dict.select("mlp.")); - ln_->load_state_dict(state_dict.select("ln.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - mixer_->verify_loaded_weights(prefix + "mixer."); - mlp_->verify_loaded_weights(prefix + "mlp."); - ln_->verify_loaded_weights(prefix + "ln."); - } - private: // parameter members, must be registered PhiAttention mixer_{nullptr}; @@ -223,7 +185,7 @@ class PhiModelImpl : public Module { const torch::TensorOptions& options) { // register submodules wte_ = register_module( - "wte", + "embd.wte", ParallelEmbedding( args.vocab_size(), args.hidden_size(), parallel_args, options)); @@ -256,24 +218,6 @@ class PhiModelImpl : public Module { return h; } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - wte_->load_state_dict(state_dict.select("embd.wte.")); - // call each layer's load_state_dict function - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->load_state_dict( - state_dict.select("h." + std::to_string(i) + ".")); - } - } - - void verify_loaded_weights(const std::string& prefix) const { - wte_->verify_loaded_weights(prefix + "embd.wte."); - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->verify_loaded_weights(prefix + "h." 
+ std::to_string(i) + - "."); - } - } - private: // parameter members, must be registered ParallelEmbedding wte_{nullptr}; @@ -310,17 +254,6 @@ class PhiLMHeadImpl : public Module { torch::Tensor forward(torch::Tensor x) { return linear_(ln_(x)); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - ln_->load_state_dict(state_dict.select("ln.")); - linear_->load_state_dict(state_dict.select("linear.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - ln_->verify_loaded_weights(prefix + "ln."); - linear_->verify_loaded_weights(prefix + "linear."); - } - private: // parameter members, must be registered LayerNorm ln_{nullptr}; @@ -366,17 +299,6 @@ class PhiForCausalLMImpl : public Module { return lm_head_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - transformer_->load_state_dict(state_dict.select("transformer.")); - lm_head_->load_state_dict(state_dict.select("lm_head.")); - } - - void verify_loaded_weights() const { - transformer_->verify_loaded_weights("transformer."); - lm_head_->verify_loaded_weights("lm_head."); - } - private: // parameter members, must be registered PhiModel transformer_{nullptr}; diff --git a/src/models/model_registry.cpp b/src/models/model_registry.cpp index b4c7c5b9..c41ace43 100644 --- a/src/models/model_registry.cpp +++ b/src/models/model_registry.cpp @@ -2,7 +2,7 @@ #include -#include "models.h" // IWYU pragma: keep +#include "registered_models.h" // IWYU pragma: keep namespace llm { diff --git a/src/models/models.h b/src/models/models.h deleted file mode 100644 index e569821e..00000000 --- a/src/models/models.h +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -// list all registered models here -#include "aquila.h" // IWYU pragma: keep -#include "baichuan.h" // IWYU pargma: keep -#include "bloom.h" // IWYU pragma: keep -#include "chatglm.h" // IWYU pragma: keep -#include "gemma.h" // IWYU pragma: keep -#include 
"gemma2.h" // IWYU pragma: keep -#include "gpt2.h" // IWYU pragma: keep -#include "gpt_j.h" // IWYU pragma: keep -#include "gpt_neox.h" // IWYU pragma: keep -#include "internlm.h" // IWYU pragma: keep -#include "llama.h" // IWYU pragma: keep -#include "mistral.h" // IWYU pragma: keep -#include "mpt.h" // IWYU pragma: keep -#include "phi.h" // IWYU pragma: keep -#include "qwen.h" // IWYU pragma: keep -#include "qwen2.h" // IWYU pragma: keep diff --git a/src/models/gpt2.h b/src/models/openai/gpt2.h similarity index 74% rename from src/models/gpt2.h rename to src/models/openai/gpt2.h index afde1084..2644ec48 100644 --- a/src/models/gpt2.h +++ b/src/models/openai/gpt2.h @@ -8,6 +8,7 @@ #include "layers/attention/handler.h" #include "layers/embedding.h" #include "layers/linear.h" +#include "layers/linear_impl.h" #include "layers/normalization.h" #include "memory/kv_cache.h" #include "models/model_args.h" @@ -21,6 +22,17 @@ namespace llm::hf { +namespace detail { +// GPT-2 implementation uses Conv1D instead of Linear. As a result, we +// need to transpose the weight for linear layer when loading from checkpoint. 
+static StateDict transpose_selector(const StateDict& sd, + const std::string& key) { + // transpose the weight + return sd.select_with_transform( + key + ".", [](const torch::Tensor& tensor) { return tensor.t(); }); +} +} // namespace detail + class GPT2MLPImpl : public Module { public: GPT2MLPImpl(const ModelArgs& args, @@ -41,7 +53,8 @@ class GPT2MLPImpl : public Module { /*gather_output=*/false, quant_args, parallel_args, - options)); + options), + detail::transpose_selector); c_proj_ = register_module("c_proj", RowParallelLinear(intermediate_size, hidden_size, @@ -49,33 +62,12 @@ class GPT2MLPImpl : public Module { /*input_is_parallelized=*/true, quant_args, parallel_args, - options)); + options), + detail::transpose_selector); } torch::Tensor forward(torch::Tensor x) { return c_proj_(act_(c_fc_(x))); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - // GPT-2 implementation uses Conv1D instead of Linear. As a result, we - // need to transpose the weight. 
- c_fc_->load_state_dict(state_dict.select_with_transform( - "c_fc.", - [](const std::string_view& /*name*/, const torch::Tensor& tensor) { - return tensor.t(); - })); - c_proj_->load_state_dict(state_dict.select_with_transform( - "c_proj.", - [](const std::string_view& /*name*/, const torch::Tensor& tensor) { - return tensor.t(); - })); - } - - void verify_loaded_weights(const std::string& prefix) const { - c_fc_->verify_loaded_weights(prefix + "c_fc."); - c_proj_->verify_loaded_weights(prefix + "c_proj."); - } - private: // parameter members, must be registered ColumnParallelLinear c_fc_{nullptr}; @@ -105,7 +97,8 @@ class GPT2AttentionImpl : public Module { /*gather_output=*/false, quant_args, parallel_args, - options)); + options), + detail::transpose_selector); c_proj_ = register_module("c_proj", RowParallelLinear(hidden_size_, @@ -114,7 +107,8 @@ class GPT2AttentionImpl : public Module { /*input_is_parallelized=*/true, quant_args, parallel_args, - options)); + options), + detail::transpose_selector); // initialize attention atten_ = register_module( @@ -138,27 +132,6 @@ class GPT2AttentionImpl : public Module { return c_proj_(output); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // GPT-2 implementation uses Conv1D instead of Linear. As a result, we - // need to transpose the weight. 
- c_attn_->load_state_dict(state_dict.select_with_transform( - "c_attn.", - [](const std::string_view& /*name*/, const torch::Tensor& tensor) { - return tensor.t(); - })); - c_proj_->load_state_dict(state_dict.select_with_transform( - "c_proj.", - [](const std::string_view& /*name*/, const torch::Tensor& tensor) { - return tensor.t(); - })); - } - - void verify_loaded_weights(const std::string& prefix) const { - c_attn_->verify_loaded_weights(prefix + "c_attn."); - c_proj_->verify_loaded_weights(prefix + "c_proj."); - } - private: // parameter members, must be registered ColumnParallelLinear c_attn_{nullptr}; @@ -207,22 +180,6 @@ class GPT2BlockImpl : public Module { return h + mlp_(ln_2_(h)); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - // call each submodule's load_state_dict function - attn_->load_state_dict(state_dict.select("attn.")); - mlp_->load_state_dict(state_dict.select("mlp.")); - ln_1_->load_state_dict(state_dict.select("ln_1.")); - ln_2_->load_state_dict(state_dict.select("ln_2.")); - } - - void verify_loaded_weights(const std::string& prefix) const { - attn_->verify_loaded_weights(prefix + "attn."); - mlp_->verify_loaded_weights(prefix + "mlp."); - ln_1_->verify_loaded_weights(prefix + "ln_1."); - ln_2_->verify_loaded_weights(prefix + "ln_2."); - } - private: // parameter members, must be registered GPT2Attention attn_{nullptr}; @@ -282,27 +239,6 @@ class GPT2ModelImpl : public Module { return ln_f_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - wte_->load_state_dict(state_dict.select("wte.")); - wpe_->load_state_dict(state_dict.select("wpe.")); - // call each layer's load_state_dict function - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->load_state_dict( - state_dict.select("h." 
+ std::to_string(i) + ".")); - } - ln_f_->load_state_dict(state_dict.select("ln_f.")); - } - - void verify_loaded_weights() const { - wte_->verify_loaded_weights("wte."); - wpe_->verify_loaded_weights("wpe."); - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->verify_loaded_weights("h." + std::to_string(i) + "."); - } - ln_f_->verify_loaded_weights("ln_f."); - } - private: // parameter members, must be registered ParallelEmbedding wte_{nullptr}; @@ -328,10 +264,13 @@ class GPT2ForCausalLMImpl : public Module { const ParallelArgs& parallel_args, const torch::TensorOptions& options) { // register submodules - model_ = register_module( - "model", GPT2Model(args, quant_args, parallel_args, options)); + model_ = + register_module("model", + GPT2Model(args, quant_args, parallel_args, options), + /*selector=*/nullptr); - lm_head_ = register_module("lm_head", + // load wte.weight + lm_head_ = register_module("wte", ColumnParallelLinear(args.hidden_size(), args.vocab_size(), /*bias=*/false, @@ -363,18 +302,6 @@ class GPT2ForCausalLMImpl : public Module { return lm_head_(h); } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - model_->load_state_dict(state_dict); - // TODO: share wte_ weight with lm_head_ to save memory - lm_head_->load_state_dict(state_dict.select("wte.")); - } - - void verify_loaded_weights() const { - model_->verify_loaded_weights(); - lm_head_->verify_loaded_weights("wte."); - } - private: // parameter members, must be registered GPT2Model model_{nullptr}; diff --git a/src/models/registered_models.h b/src/models/registered_models.h new file mode 100644 index 00000000..fba44c42 --- /dev/null +++ b/src/models/registered_models.h @@ -0,0 +1,26 @@ +#pragma once + +// list all registered models here +// Google +#include "google/gemma.h" // IWYU pragma: keep +#include "google/gemma2.h" // IWYU pragma: keep +// OpenAI +#include "openai/gpt2.h" // IWYU pragma: keep +// Meta +#include "meta/llama.h" // IWYU 
pragma: keep +// Microsoft +#include "microsoft/phi.h" // IWYU pragma: keep +// Alibaba +#include "alibaba/qwen.h" // IWYU pragma: keep +#include "alibaba/qwen2.h" // IWYU pragma: keep + +// Deprecated models +// #include "deprecated/aquila.h" +// #include "deprecated/baichuan.h" +// #include "deprecated/bloom.h" +// #include "deprecated/chatglm.h" +// #include "deprecated/gpt_j.h" +// #include "deprecated/gpt_neox.h" +// #include "deprecated/internlm.h" +// #include "deprecated/mistral.h" +// #include "deprecated/mpt.h" diff --git a/src/module/CMakeLists.txt b/src/module/CMakeLists.txt index ceaab617..043f5de7 100644 --- a/src/module/CMakeLists.txt +++ b/src/module/CMakeLists.txt @@ -1,4 +1,5 @@ include(cc_library) +include(cc_test) cc_library( NAME @@ -13,3 +14,15 @@ cc_library( glog::glog torch ) + +cc_test( + NAME + module_test + SRCS + module_test.cpp + DEPS + :module + :state_dict + :gtest_main + torch +) diff --git a/src/module/module.cpp b/src/module/module.cpp index 45f2546c..29d730b7 100644 --- a/src/module/module.cpp +++ b/src/module/module.cpp @@ -2,6 +2,8 @@ #include +#include "model_loader/state_dict.h" + namespace llm { using namespace torch; @@ -49,8 +51,9 @@ OrderedDict Module::named_parameters(bool recurse) const { OrderedDict result; if (!recurse) { for (const auto& parameter : parameters_) { - if (parameter.value().defined()) { - result.insert(parameter.key(), parameter.value()); + const auto& param_tensor = parameter.value().tensor; + if (param_tensor.defined()) { + result.insert(parameter.key(), param_tensor); } } } else { @@ -125,12 +128,20 @@ OrderedDict> Module::named_modules( } std::vector> Module::children() const { - return children_.values(); + std::vector> result; + for (const auto& item : children_) { + result.push_back(item.value().module); + } + return result; } OrderedDict> Module::named_children() const { - return children_; + OrderedDict> result; + for (const auto& item : children_) { + result.insert(item.key(), 
item.value().module); + } + return result; } void Module::apply(const ModuleApplyFunction& function) { @@ -198,31 +209,51 @@ void Module::to(torch::Device device, bool non_blocking) { to_impl(device, non_blocking); } +Tensor& Module::register_parameter(std::string name, + Tensor tensor, + const TensorLoader& loader) { + TORCH_CHECK(!name.empty(), "Parameter name must not be empty"); + // tensor.set_requires_grad(false); + Parameter param{ + .tensor = std::move(tensor), .loader = loader, .is_loaded = false}; + return parameters_.insert(std::move(name), std::move(param)).tensor; +} + Tensor& Module::register_parameter(std::string name, Tensor tensor) { + auto default_loader = [](const StateDict& sd, const std::string& key) { + return sd.get_tensor(key); + }; + return register_parameter(std::move(name), std::move(tensor), default_loader); +} + +Tensor& Module::register_sharded_parameter(std::string name, + int dim, + int rank, + int world_size, + Tensor tensor) { TORCH_CHECK(!name.empty(), "Parameter name must not be empty"); - TORCH_CHECK(name.find('.') == std::string::npos, - "Parameter name must not contain a dot (got '", - name, - "')"); - tensor.set_requires_grad(false); - return parameters_.insert(std::move(name), std::move(tensor)); + // tensor.set_requires_grad(false); + Parameter param{ + .tensor = std::move(tensor), + .loader = + [dim, rank, world_size](const StateDict& sd, const std::string& key) { + return sd.get_sharded_tensor(key, dim, rank, world_size); + }, + .is_loaded = false}; + return parameters_.insert(std::move(name), std::move(param)).tensor; } Tensor& Module::register_buffer(std::string name, Tensor tensor) { TORCH_CHECK(!name.empty(), "Buffer name must not be empty"); - TORCH_CHECK(name.find('.') == std::string::npos, - "Buffer name must not contain a dot (got '", - name, - "')"); return buffers_.insert(std::move(name), std::move(tensor)); } -void Module::unregister_module(const std::string& name) { - TORCH_CHECK(children_.contains(name), - 
"No Module with name `", - name, - "` is registered"); +bool Module::unregister_module(const std::string& name) { + if (!children_.contains(name)) { + return false; + } children_.erase(name); + return true; } void Module::pretty_print(std::ostream& stream) const { stream << name(); } @@ -235,7 +266,7 @@ void Module::pretty_print_recursive(std::ostream& stream, const std::string next_indentation = indentation + " "; for (const auto& child : children_) { stream << next_indentation << "(" << child.key() << "): "; - child.value()->pretty_print_recursive(stream, next_indentation); + child.value().module->pretty_print_recursive(stream, next_indentation); stream << '\n'; } stream << indentation << ")"; @@ -248,8 +279,8 @@ void Module::apply_to_submodules( const std::string& name_prefix) const { for (const auto& child : children_) { auto qualified_name = join_name(name_prefix, child.key()); - function(qualified_name, child.value()); - child.value()->apply_to_submodules(function, qualified_name); + function(qualified_name, child.value().module); + child.value().module->apply_to_submodules(function, qualified_name); } } @@ -276,4 +307,89 @@ std::ostream& operator<<(std::ostream& stream, const Module& module) { module.pretty_print_recursive(stream, ""); return stream; } + +// load weights from the checkpoint, override this method if necessary +// NOLINTNEXTLINE(misc-no-recursion) +size_t Module::load(const StateDict& state_dict, + const std::string& name_prefix) { + size_t total_loaded = 0; + // load parameters one by one + for (auto& item : parameters_) { + const auto& key = item.key(); + auto& param = item.value(); + const auto& param_tensor = param.tensor; + // skip loading for undefined tensor + if (!param_tensor.defined()) { + // mark as loaded to pass verification + param.is_loaded = true; + continue; + } + + const auto tensor = param.loader(state_dict, key); + if (!tensor.defined()) { + continue; + } + + if (param.is_loaded) { + LOG(WARNING) << "Parameter " << 
join_name(name_prefix, key) + << " is already loaded"; + } + + if (param_tensor.sizes() == tensor.sizes()) { + // LOG(INFO) << "Loading parameter: " << join_name(name_prefix, key) + // << " of size " << tensor.sizes(); + // copy data to the parameter tensor + param_tensor.copy_(tensor); + // mark as loaded + param.is_loaded = true; + ++total_loaded; + } else { + LOG(ERROR) << "Size mismatch for parameter " + << join_name(name_prefix, key) << ": expected " + << param_tensor.sizes() << ", got " << tensor.sizes(); + } + } + + // don't need to load buffers, since they are initialized in the constructor + + // recursively load children modules + for (const auto& item : children_) { + const auto& key = item.key(); + const auto& child = item.value(); + if (child.selector) { + // select state dict for the child module + const auto child_state_dict = child.selector(state_dict, key); + total_loaded += + child.module->load(child_state_dict, join_name(name_prefix, key)); + } else { + total_loaded += child.module->load(state_dict, name_prefix); + } + } + return total_loaded; +} + +// verify whether the weights are loaded, override this method if necessary +// NOLINTNEXTLINE(misc-no-recursion) +bool Module::verify(const std::string& name_prefix) const { + bool all_loaded = true; + for (const auto& item : parameters_) { + const auto& key = item.key(); + const auto& param = item.value(); + if (!param.is_loaded) { + LOG(ERROR) << "Missing parameter: " << join_name(name_prefix, key) + << ", size: " << param.tensor.sizes(); + } + all_loaded = all_loaded && param.is_loaded; + } + + for (const auto& item : children_) { + const auto& key = item.key(); + const auto& child = item.value(); + const std::string prefix = + child.selector ? 
join_name(name_prefix, key) : name_prefix; + const bool child_loaded = child.module->verify(prefix); + all_loaded = all_loaded && child_loaded; + } + return all_loaded; +} } // namespace llm diff --git a/src/module/module.h b/src/module/module.h index 69a63109..6cb96ff6 100644 --- a/src/module/module.h +++ b/src/module/module.h @@ -8,7 +8,9 @@ #include #include #include +#include +#include "model_loader/state_dict.h" #include "module_holder.h" namespace llm { @@ -148,43 +150,61 @@ class Module : public std::enable_shared_from_this { /// `stream` should be returned from the method, to allow easy chaining. virtual void pretty_print(std::ostream& stream) const; + // load weights from the checkpoint, override this method if necessary + // returns the number of loaded parameters + virtual size_t load(const StateDict& state_dict, + const std::string& name_prefix = std::string()); + + // verify whether the weights are loaded, override this method if necessary + virtual bool verify(const std::string& name_prefix = std::string()) const; + /// Registers a parameter with this `Module`. + using TensorLoader = + std::function; + torch::Tensor& register_parameter(std::string name, + torch::Tensor tensor, + const TensorLoader& loader); + torch::Tensor& register_parameter(std::string name, torch::Tensor tensor); + torch::Tensor& register_sharded_parameter(std::string name, + int dim, + int rank, + int world_size, + torch::Tensor tensor); + /// Registers a buffer with this `Module`. torch::Tensor& register_buffer(std::string name, torch::Tensor tensor); /// Registers a submodule with this `Module`. + using StateDictSelector = + std::function; + template std::shared_ptr register_module( std::string name, - std::shared_ptr module); + std::shared_ptr module, + const StateDictSelector& selector); template std::shared_ptr register_module( std::string name, - ModuleHolder module_holder); - - /// Replaces a registered submodule with this `Module`. 
- template - std::shared_ptr replace_module( - const std::string& name, std::shared_ptr module); template - std::shared_ptr replace_module( - const std::string& name, + std::shared_ptr register_module( + std::string name, ModuleHolder module_holder); - /// Unregisters a submodule from this `Module`. If there is no such module - /// with `name` an exception is thrown. - void unregister_module(const std::string& name); + template + std::shared_ptr register_module( + std::string name, + ModuleHolder module_holder, + const StateDictSelector& selector); - protected: - /// The registered parameters of this `Module`. - /// Inorder to access parameters_ in ParameterDict and ParameterList - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - torch::OrderedDict parameters_; + /// Unregisters a submodule from this `Module`. Returns false if no such + /// submodule was registered. + bool unregister_module(const std::string& name); private: /// Pretty prints the given `Module` into the `ostream`. @@ -209,11 +229,25 @@ class Module : public std::enable_shared_from_this { /// Returns a shared_ptr to `this` in a safe (checked) way. std::shared_ptr shared_from_this_checked() const; + struct Parameter { + torch::Tensor tensor; + std::function loader; + bool is_loaded = false; + }; + + struct Child { + std::shared_ptr module; + std::function selector; + }; + + /// The registered parameters of this `Module`. + torch::OrderedDict parameters_; + /// The registered buffers of this `Module`. torch::OrderedDict buffers_; /// The registered (direct) submodules of this `Module`. - torch::OrderedDict> children_; + torch::OrderedDict children_; /// The module's name (e.g. "LSTM"). 
mutable std::optional name_; @@ -248,43 +282,45 @@ const ModuleType* Module::as() const noexcept { template std::shared_ptr Module::register_module( std::string name, - std::shared_ptr module) { + std::shared_ptr module, + const StateDictSelector& selector) { TORCH_CHECK(!name.empty(), "Submodule name must not be empty"); - TORCH_CHECK(name.find('.') == std::string::npos, - "Submodule name must not contain a dot (got '", - name, - "')"); - auto& base_module = children_.insert(std::move(name), std::move(module)); + Child child{.module = std::move(module), .selector = selector}; + auto& base_module = + children_.insert(std::move(name), std::move(child)).module; return std::dynamic_pointer_cast(base_module); } template std::shared_ptr Module::register_module( std::string name, - ModuleHolder module_holder) { - return register_module(std::move(name), module_holder.ptr()); + std::shared_ptr module) { + auto default_selector = [](const StateDict& sd, const std::string& key) { + return sd.select(key + "."); + }; + return register_module(std::move(name), std::move(module), default_selector); } template -std::shared_ptr Module::replace_module( - const std::string& name, - std::shared_ptr module) { - auto& base_module = (children_[name] = std::move(module)); - return std::dynamic_pointer_cast(base_module); +std::shared_ptr Module::register_module( + std::string name, + ModuleHolder module_holder) { + return register_module(std::move(name), module_holder.ptr()); } template -std::shared_ptr Module::replace_module( - const std::string& name, - ModuleHolder module_holder) { - return replace_module(name, module_holder.ptr()); +std::shared_ptr Module::register_module( + std::string name, + ModuleHolder module_holder, + const StateDictSelector& selector) { + return register_module(std::move(name), module_holder.ptr(), selector); } template void Module::to_impl(Ts&&... ts) { // First call `to()` on every child module. 
for (auto& child : children_) { - child.value()->to(ts...); + child.value().module->to(ts...); } // Then move every parameter to the new dtype/device. for (auto& parameter : named_parameters(/*recurse=*/false)) { diff --git a/src/module/module_list.h b/src/module/module_list.h index 3ae3cb45..ce0bfee6 100644 --- a/src/module/module_list.h +++ b/src/module/module_list.h @@ -1,9 +1,5 @@ #pragma once -#include -#include -#include - #include #include @@ -11,6 +7,16 @@ #include "module_holder.h" namespace llm { +namespace detail { +/// A type trait whose `value` member is true if `M` derives from `Module`. +template +using is_module = std::is_base_of>; + +template +using enable_if_module_t = std::enable_if_t::value, T>; + +} // namespace detail + /// A list of `Module`s that registers its elements. class ModuleListImpl : public Module { public: @@ -40,7 +46,7 @@ class ModuleListImpl : public Module { /// Adds a new `Module` to the `ModuleList` container, moving or copying /// it into a `shared_ptr` internally. This method allows passing value types, /// and letting the container deal with the boxing. - template > + template > void push_back(M&& module) { using Type = std::remove_reference_t; push_back(std::make_shared(std::forward(module))); @@ -78,7 +84,7 @@ class ModuleListImpl : public Module { /// match. template T& at(size_t index) { - static_assert(torch::detail::is_module::value, + static_assert(detail::is_module::value, "Can only call ModuleList::at with an nn::Module type"); TORCH_CHECK(index < size(), "Index out of range"); auto module = modules_[index]->as(); @@ -95,7 +101,7 @@ class ModuleListImpl : public Module { /// match. 
template const T& at(size_t index) const { - static_assert(torch::detail::is_module::value, + static_assert(detail::is_module::value, "Can only call ModuleList::at with an nn::Module type"); TORCH_CHECK(index < size(), "Index out of range"); const auto module = modules_[index]->as(); @@ -120,7 +126,7 @@ class ModuleListImpl : public Module { /// match. template std::shared_ptr ptr(size_t index) const { - static_assert(torch::detail::is_module::value, + static_assert(detail::is_module::value, "Can only call ModuleList::ptr with an nn::Module type"); TORCH_CHECK(index < size(), "Index out of range"); return std::dynamic_pointer_cast(modules_[index]); @@ -138,39 +144,6 @@ class ModuleListImpl : public Module { /// True if there are no modules in the `ModuleList`. bool is_empty() const noexcept { return size() == 0; } - void insert(size_t index, std::shared_ptr module) { - TORCH_CHECK(index <= size(), "Index out of range"); - - if (index == size()) - push_back(std::move(module)); - else { - modules_.insert(modules_.begin() + Iterator::difference_type(index), - std::move(module)); - - for (const auto i : c10::irange(index, size() - 1)) { - (void)i; // Suppress unused variable warning - replace_module(std::to_string(index), modules_[index]); - } - register_module(std::to_string(size() - 1), modules_.back()); - } - } - - /// Unwraps the contained module of a `ModuleHolder` and inserts it in the - /// `ModuleList`. - template - void insert(size_t index, const ModuleHolder& module_holder) { - insert(index, module_holder.ptr()); - } - - /// inserts a new `Module` to the `ModuleList` container, moving or copying - /// it into a `shared_ptr` internally. This method allows passing value types, - /// and letting the container deal with the boxing. - template > - void insert(size_t index, M&& module) { - using Type = std::remove_reference_t; - insert(index, std::make_shared(std::forward(module))); - } - private: template void push_back_var(Head&& head, Tail&&... 
tail) { diff --git a/src/module/module_test.cpp b/src/module/module_test.cpp new file mode 100644 index 00000000..475f6264 --- /dev/null +++ b/src/module/module_test.cpp @@ -0,0 +1,131 @@ +#include "module.h" + +#include +#include +#include + +namespace llm { + +class ParametersModel : public Module { + public: + ParametersModel(bool bias = true) { + // register some parameters + weight_ = register_parameter("weight", torch::randn({16, 16})); + if (bias) { + bias_ = register_parameter("bias", torch::randn({32, 32})); + } + + sharded_param_ = register_sharded_parameter("sharded_param", + /*dim=*/0, + /*rank=*/0, + /*world_size=*/2, + torch::randn({16, 32})); + undefined_param_ = register_parameter("undefined_param", torch::Tensor()); + } + + torch::Tensor weight_; + torch::Tensor bias_; + + torch::Tensor sharded_param_; + torch::Tensor undefined_param_; +}; + +class SubModel : public Module { + public: + SubModel() { + param1_ = register_parameter("param1", torch::randn({8, 8})); + param2_ = register_parameter("param2", torch::randn({8, 16})); + x_ = register_module("x", std::make_shared(/*bias=*/true)); + y_ = + register_module("y", std::make_shared(/*bias=*/false)); + } + + torch::Tensor param1_; + torch::Tensor param2_; + std::shared_ptr x_; + std::shared_ptr y_; +}; + +class Model : public Module { + public: + Model() { + param1_ = register_parameter("param1", torch::randn({8, 8})); + param2_ = register_parameter("param2", torch::randn({8, 16})); + sub_model_ = register_module("submodel", std::make_shared()); + } + + torch::Tensor param1_; + torch::Tensor param2_; + std::shared_ptr sub_model_; +}; + +TEST(ModuleTest, Parameters) { + std::unordered_map dict = { + {"weight", torch::randn({16, 16})}, + {"bias", torch::randn({32, 32})}, + {"param1", torch::randn({16, 16})}, + {"param2", torch::randn({16, 16})}, + {"sharded_param", torch::randn({32, 32})}, + {"undefined_param", torch::randn({32, 32})}}; + + StateDict state_dict(dict); + EXPECT_EQ(state_dict.size(), 
dict.size()); + + ParametersModel model; + EXPECT_EQ(model.load(state_dict), 3); + + EXPECT_TRUE(torch::equal(model.weight_, dict["weight"])); + EXPECT_TRUE(torch::equal(model.bias_, dict["bias"])); + EXPECT_TRUE(torch::equal(model.sharded_param_, + dict["sharded_param"].chunk(2, /*dim=*/0)[0])); + EXPECT_FALSE(model.undefined_param_.defined()); + + EXPECT_TRUE(model.verify()); +} + +TEST(ModuleTest, SubModules) { + std::unordered_map dict = { + {"param1", torch::randn({8, 8})}, + {"param2", torch::randn({8, 16})}, + {"submodel.param1", torch::randn({8, 8})}, + {"submodel.param2", torch::randn({8, 16})}, + {"submodel.x.weight", torch::randn({16, 16})}, + {"submodel.x.bias", torch::randn({32, 32})}, + {"submodel.x.sharded_param", torch::randn({32, 32})}, + {"submodel.x.undefined_param", torch::randn({32, 32})}, + {"submodel.y.weight", torch::randn({16, 16})}, + {"submodel.y.bias", torch::randn({32, 32})}, + {"submodel.y.sharded_param", torch::randn({32, 32})}, + {"submodel.y.undefined_param", torch::randn({32, 32})}}; + + StateDict state_dict(dict); + EXPECT_EQ(state_dict.size(), dict.size()); + + Model model; + EXPECT_EQ(model.load(state_dict), 9); + + EXPECT_TRUE(torch::equal(model.param1_, dict["param1"])); + EXPECT_TRUE(torch::equal(model.param2_, dict["param2"])); + // submodel + EXPECT_TRUE(torch::equal(model.sub_model_->param1_, dict["submodel.param1"])); + EXPECT_TRUE(torch::equal(model.sub_model_->param2_, dict["submodel.param2"])); + EXPECT_TRUE( + torch::equal(model.sub_model_->x_->weight_, dict["submodel.x.weight"])); + EXPECT_TRUE( + torch::equal(model.sub_model_->x_->bias_, dict["submodel.x.bias"])); + EXPECT_TRUE( + torch::equal(model.sub_model_->x_->sharded_param_, + dict["submodel.x.sharded_param"].chunk(2, /*dim=*/0)[0])); + EXPECT_FALSE(model.sub_model_->x_->undefined_param_.defined()); + EXPECT_TRUE( + torch::equal(model.sub_model_->y_->weight_, dict["submodel.y.weight"])); + EXPECT_FALSE(model.sub_model_->y_->bias_.defined()); + EXPECT_TRUE( + 
torch::equal(model.sub_model_->y_->sharded_param_, + dict["submodel.y.sharded_param"].chunk(2, /*dim=*/0)[0])); + EXPECT_FALSE(model.sub_model_->y_->undefined_param_.defined()); + + EXPECT_TRUE(model.verify()); +} + +} // namespace llm diff --git a/src/quantization/qlinear_gptq_marlin_impl.h b/src/quantization/qlinear_gptq_marlin_impl.h index dc6b3f40..e996b479 100644 --- a/src/quantization/qlinear_gptq_marlin_impl.h +++ b/src/quantization/qlinear_gptq_marlin_impl.h @@ -35,11 +35,6 @@ class ColumnParallelQLinearGPTQMarlinImpl : public ParallelLinearImpl { void load_state_dict(const StateDict& state_dict, const std::vector& prefixes) override; - void pretty_print(std::ostream& stream) const override { - stream << name() << " qweight=" << qweight_.sizes() - << " scales=" << scales_.sizes() << " device=" << qweight_.device(); - } - private: // parameter members, must be registered DEFINE_FUSED_WEIGHT(qweight); @@ -91,11 +86,6 @@ class RowParallelQLinearGPTQMarlinImpl : public ParallelLinearImpl { // whether the weight is loaded void verify_loaded_weights(const std::string& prefix = "") const override; - void pretty_print(std::ostream& stream) const override { - stream << name() << " qweight=" << qweight_.sizes() - << " scales=" << scales_.sizes() << " device=" << qweight_.device(); - } - private: // parameter members, must be registered DEFINE_WEIGHT(qweight); diff --git a/src/quantization/qlinear_impl.h b/src/quantization/qlinear_impl.h index c442a8bd..0269246f 100644 --- a/src/quantization/qlinear_impl.h +++ b/src/quantization/qlinear_impl.h @@ -76,12 +76,6 @@ class ColumnParallelQLinearImpl : public ParallelLinearImpl { void load_state_dict(const StateDict& state_dict, const std::vector& prefixes) override; - void pretty_print(std::ostream& stream) const override { - stream << name() << " qweight=" << qweight_.sizes() - << " qzeros=" << qzeros_.sizes() << " scales=" << scales_.sizes() - << " device=" << qweight_.device(); - } - private: // parameter members, must 
be registered DEFINE_FUSED_WEIGHT(qweight); @@ -148,12 +142,6 @@ class RowParallelQLinearImpl : public ParallelLinearImpl { // whether the weight is loaded void verify_loaded_weights(const std::string& prefix = "") const override; - void pretty_print(std::ostream& stream) const override { - stream << name() << " qweight=" << qweight_.sizes() - << " qzeros=" << qzeros_.sizes() << " scales=" << scales_.sizes() - << " device=" << qweight_.device(); - } - private: // parameter members, must be registered DEFINE_WEIGHT(qweight);