diff --git a/README.md b/README.md
index 8d64eef0..14b8faa2 100644
--- a/README.md
+++ b/README.md
@@ -32,11 +32,12 @@ ScaleLLM
-[ScaleLLM](#) is a cutting-edge inference system engineered for large language models (LLMs), designed to meet the demands of production environments. It extends its support to a wide range of popular open-source models, including [Llama3.1](https://github.com/meta-llama/llama3), [Gemma2](https://github.com/google-deepmind/gemma), Bloom, GPT-NeoX, and more.
+[ScaleLLM](#) is a cutting-edge inference system engineered for large language models (LLMs), designed to meet the demands of production environments. It extends its support to a wide range of popular open-source models, including [Llama3.1](https://github.com/meta-llama/llama3), [Gemma2](https://github.com/google-deepmind/gemma), [Phi](https://huggingface.co/microsoft/phi-2), and more.
ScaleLLM is currently undergoing active development. We are fully committed to consistently enhancing its efficiency while also incorporating additional features. Feel free to explore our [**_Roadmap_**](https://github.com/vectorch-ai/ScaleLLM/issues/84) for more details.
## News:
+* [01/2025] - Optimized in-house attention kernels
* [06/2024] - ScaleLLM is now available on [PyPI](https://pypi.org/project/scalellm/). You can install it using `pip install scalellm`.
* [03/2024] - [Advanced features](#advanced-features) support for ✅ [CUDA graph](#cuda-graph), ✅ [prefix cache](#prefix-cache), ✅ [chunked prefill](#chunked-prefill) and ✅ [speculative decoding](#speculative-decoding).
* [11/2023] - [First release](https://github.com/vectorch-ai/ScaleLLM/releases/tag/v0.0.1) with support for popular [open-source models](#supported-models).
@@ -274,21 +275,11 @@ Quantization is a crucial process for reducing the memory footprint of models. S
| Models | Tensor Parallel | Quantization | Chat API | HF models examples |
| :--------: | :-------------: | :----------: | :------: | :---------------------------:|
-| Aquila | Yes | Yes | Yes | [BAAI/Aquila-7B](https://huggingface.co/BAAI/Aquila-7B), [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B) |
-| Bloom | Yes | Yes | No | [bigscience/bloom](https://huggingface.co/bigscience/bloom) |
-| Baichuan | Yes | Yes | Yes | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) |
-| ChatGLM4/3 | Yes | Yes | Yes | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b) |
| Gemma2 | Yes | Yes | Yes | [google/gemma-2-2b](https://huggingface.co/google/gemma-2-2b) |
-| GPT_j | Yes | Yes | No | [EleutherAI/gpt-j-6b](https://huggingface.co/EleutherAI/gpt-j-6b) |
-| GPT_NeoX | Yes | Yes | No | [EleutherAI/gpt-neox-20b](https://huggingface.co/EleutherAI/gpt-neox-20b) |
| GPT2 | Yes | Yes | No | [gpt2](https://huggingface.co/gpt2)|
-| InternLM | Yes | Yes | Yes | [internlm/internlm-7b](https://huggingface.co/internlm/internlm-7b) |
| Llama3/2 | Yes | Yes | Yes | [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct), [meta-llama/Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) |
-| Mistral | Yes | Yes | Yes | [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) |
-| MPT | Yes | Yes | Yes | [mosaicml/mpt-30b](https://huggingface.co/mosaicml/mpt-30b) |
| Phi2 | Yes | Yes | No | [microsoft/phi-2](https://huggingface.co/microsoft/phi-2) |
| Qwen2 | Yes | Yes | Yes | [Qwen/Qwen-72B-Chat](https://huggingface.co/Qwen/Qwen-72B-Chat) |
-| Yi | Yes | Yes | Yes |[01-ai/Yi-6B](https://huggingface.co/01-ai/Yi-6B), [01-ai/Yi-34B-Chat-4bits](https://huggingface.co/01-ai/Yi-34B-Chat-4bits), [01-ai/Yi-6B-200K](https://huggingface.co/01-ai/Yi-6B-200K) |
If your model is not included in the supported list, we are more than willing to assist you. Please feel free to create a request for adding a new model on [GitHub Issues](https://github.com/vectorch-ai/ScaleLLM/issues).
diff --git a/src/layers/embedding.h b/src/layers/embedding.h
index 9abe231d..9ef07033 100644
--- a/src/layers/embedding.h
+++ b/src/layers/embedding.h
@@ -33,26 +33,6 @@ class EmbeddingImpl : public Module {
return F::embedding(input, weight_);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- const auto weight = state_dict.get_tensor("weight");
- if (weight.defined()) {
- CHECK_EQ(weight_.sizes(), weight.sizes())
- << "weight size mismatch for " << name();
- weight_.copy_(weight);
- is_loaded_ = true;
- }
- }
-
- // whether the weight is loaded
- void verify_loaded_weights(const std::string& prefix) const {
- CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight";
- }
-
- void pretty_print(std::ostream& stream) const override {
- stream << name() << " " << weight_.sizes() << " " << weight_.device();
- }
-
// return the weight (for testing)
torch::Tensor weight() const { return weight_; }
@@ -73,6 +53,7 @@ class ParallelEmbeddingImpl : public Module {
const ParallelArgs& parallel_args,
const torch::TensorOptions& options)
: parallel_args_(parallel_args) {
+ const auto rank = parallel_args_.rank();
const auto world_size = parallel_args_.world_size();
CHECK(embedding_dim % world_size == 0)
<< "out_features " << embedding_dim << " not divisible by world_size "
@@ -80,8 +61,11 @@ class ParallelEmbeddingImpl : public Module {
const int64_t embedding_dim_per_partition = embedding_dim / world_size;
// register the weight parameter
- weight_ = register_parameter(
+ weight_ = register_sharded_parameter(
"weight",
+ /*dim=*/1,
+ rank,
+ world_size,
torch::empty({num_embeddings, embedding_dim_per_partition}, options));
}
@@ -96,30 +80,6 @@ class ParallelEmbeddingImpl : public Module {
return output;
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- const auto weight = state_dict.get_sharded_tensor(
- "weight",
- /*dim=*/1,
- /*rank=*/parallel_args_.rank(),
- /*world_size=*/parallel_args_.world_size());
- if (weight.defined()) {
- CHECK_EQ(weight_.sizes(), weight.sizes())
- << "weight size mismatch for " << name();
- weight_.copy_(weight);
- is_loaded_ = true;
- }
- }
-
- // whether the weight is loaded
- void verify_loaded_weights(const std::string& prefix) const {
- CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight";
- }
-
- void pretty_print(std::ostream& stream) const override {
- stream << name() << " " << weight_.sizes() << " " << weight_.device();
- }
-
// return the weight (for testing)
torch::Tensor weight() const { return weight_; }
@@ -143,14 +103,19 @@ class VocabParallelEmbeddingImpl : public Module {
const ParallelArgs& parallel_args,
const torch::TensorOptions& options)
: parallel_args_(parallel_args) {
- const int64_t num_embeddings_per_partition =
- num_embeddings / parallel_args_.world_size();
- start_index_ = num_embeddings_per_partition * parallel_args_.rank();
+ const auto rank = parallel_args_.rank();
+ const auto world_size = parallel_args_.world_size();
+
+ const int64_t num_embeddings_per_partition = num_embeddings / world_size;
+ start_index_ = num_embeddings_per_partition * rank;
end_index_ = start_index_ + num_embeddings_per_partition;
// register the weight parameter
- weight_ = register_parameter(
+ weight_ = register_sharded_parameter(
"weight",
+ /*dim=*/0,
+ rank,
+ world_size,
torch::empty({num_embeddings_per_partition, embedding_dim}, options));
}
@@ -174,30 +139,6 @@ class VocabParallelEmbeddingImpl : public Module {
return reduce_from_model_parallel_region(output, parallel_args_);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- const auto weight = state_dict.get_sharded_tensor(
- "weight",
- /*dim=*/0,
- /*rank=*/parallel_args_.rank(),
- /*world_size=*/parallel_args_.world_size());
- if (weight.defined()) {
- CHECK_EQ(weight_.sizes(), weight.sizes())
- << "weight size mismatch for " << name();
- weight_.copy_(weight);
- is_loaded_ = true;
- }
- }
-
- // whether the weight is loaded
- void verify_loaded_weights(const std::string& prefix = "") const {
- CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight";
- }
-
- void pretty_print(std::ostream& stream) const override {
- stream << name() << " " << weight_.sizes() << " " << weight_.device();
- }
-
// return the weight (for testing)
torch::Tensor weight() const { return weight_; }
diff --git a/src/layers/fused_linear.cpp b/src/layers/fused_linear.cpp
index 917ab6ec..86b5ca59 100644
--- a/src/layers/fused_linear.cpp
+++ b/src/layers/fused_linear.cpp
@@ -13,11 +13,13 @@ namespace llm {
FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl(
int64_t in_features,
const std::vector
& out_features_vec,
+ const std::vector& prefixes,
bool bias,
bool gather_output,
const QuantArgs& quant_args,
const ParallelArgs& parallel_args,
const torch::TensorOptions& options) {
+ prefixes_ = prefixes;
// check if the linear layers can be fused
fused_ = quant_args.can_be_fused();
if (fused_) {
@@ -72,28 +74,29 @@ std::vector<torch::Tensor> FusedColumnParallelLinearImpl::forward(
return outputs;
}
-void FusedColumnParallelLinearImpl::load_state_dict(
- const StateDict& state_dict,
- const std::vector<std::string>& prefixes) {
+size_t FusedColumnParallelLinearImpl::load(const StateDict& state_dict,
+ const std::string&) {
if (fused_) {
- fused_linear_->load_state_dict(state_dict, prefixes);
+ fused_linear_->load_state_dict(state_dict, prefixes_);
} else {
- CHECK_EQ(parallel_linears_.size(), prefixes.size());
+ CHECK_EQ(parallel_linears_.size(), prefixes_.size());
for (size_t i = 0; i < parallel_linears_.size(); ++i) {
- parallel_linears_[i]->load_state_dict(state_dict.select(prefixes[i]));
+ parallel_linears_[i]->load_state_dict(state_dict.select(prefixes_[i]));
}
}
+ return 0;
}
-void FusedColumnParallelLinearImpl::verify_loaded_weights(
- const std::string& prefix) const {
+bool FusedColumnParallelLinearImpl::verify(
+ const std::string& name_prefix) const {
if (fused_) {
- fused_linear_->verify_loaded_weights(prefix);
+ fused_linear_->verify_loaded_weights(name_prefix);
} else {
for (const auto& parallel_linear : parallel_linears_) {
- parallel_linear->verify_loaded_weights(prefix);
+ parallel_linear->verify_loaded_weights(name_prefix);
}
}
+ return true;
}
} // namespace llm
diff --git a/src/layers/fused_linear.h b/src/layers/fused_linear.h
index 191922e6..73323479 100644
--- a/src/layers/fused_linear.h
+++ b/src/layers/fused_linear.h
@@ -16,6 +16,7 @@ class FusedColumnParallelLinearImpl : public Module {
public:
FusedColumnParallelLinearImpl(int64_t in_features,
const std::vector<int64_t>& out_features,
+ const std::vector<std::string>& prefixes,
bool bias,
bool gather_output,
const QuantArgs& quant_args,
@@ -24,11 +25,13 @@ class FusedColumnParallelLinearImpl : public Module {
std::vector<torch::Tensor> forward(torch::Tensor input);
- // load_state_dict for fused weights
- void load_state_dict(const StateDict& state_dict,
- const std::vector<std::string>& prefixes);
+ // load weights from the checkpoint, override this method if necessary
+ // returns the number of loaded parameters
+ size_t load(const StateDict& state_dict,
+ const std::string& name_prefix = std::string()) override;
- void verify_loaded_weights(const std::string& prefix = "") const;
+ // verify whether the weights are loaded, override this method if necessary
+ bool verify(const std::string& name_prefix = std::string()) const override;
// whether the linear layer is fused
bool fused() const { return fused_; }
@@ -43,6 +46,8 @@ class FusedColumnParallelLinearImpl : public Module {
// sizes for each split
std::vector<int64_t> split_sizes_;
+ std::vector<std::string> prefixes_;
+
// whether the linear layer is fused
bool fused_ = false;
};
diff --git a/src/layers/linear_impl.cpp b/src/layers/linear_impl.cpp
index 3f301868..7b6d8f04 100644
--- a/src/layers/linear_impl.cpp
+++ b/src/layers/linear_impl.cpp
@@ -18,6 +18,7 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl(
const ParallelArgs& parallel_args,
const torch::TensorOptions& options)
: gather_output_(gather_output), parallel_args_(parallel_args) {
+ const auto rank = parallel_args_.rank();
const auto world_size = parallel_args_.world_size();
CHECK(out_features % world_size == 0)
<< "out_features " << out_features << " not divisible by world_size "
@@ -26,13 +27,20 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl(
// Note: torch.nn.functional.linear performs XA^T + b and as a result
// we allocate the transpose.
- weight_ = register_parameter(
+ weight_ = register_sharded_parameter(
"weight",
+ /*dim=*/0,
+ rank,
+ world_size,
torch::empty({out_features_per_partition, in_features}, options));
if (bias) {
- bias_ = register_parameter(
- "bias", torch::empty({out_features_per_partition}, options));
+ bias_ = register_sharded_parameter(
+ "bias",
+ /*dim=*/0,
+ rank,
+ world_size,
+ torch::empty({out_features_per_partition}, options));
}
}
@@ -93,14 +101,18 @@ RowParallelLinearImpl::RowParallelLinearImpl(
const torch::TensorOptions& options)
: input_is_parallelized_(input_is_parallelized),
parallel_args_(parallel_args) {
+ const auto rank = parallel_args_.rank();
const auto world_size = parallel_args_.world_size();
CHECK(in_features % world_size == 0)
<< "in_features " << in_features << " not divisible by world_size "
<< world_size;
const int64_t in_features_per_partition = in_features / world_size;
// Allocate the transpose since linear performs XA^T.
- weight_ = register_parameter(
+ weight_ = register_sharded_parameter(
"weight",
+ /*dim=*/1,
+ rank,
+ world_size,
torch::empty({out_features, in_features_per_partition}, options));
if (bias) {
diff --git a/src/layers/linear_impl.h b/src/layers/linear_impl.h
index fe1f6b7c..ff551649 100644
--- a/src/layers/linear_impl.h
+++ b/src/layers/linear_impl.h
@@ -42,10 +42,6 @@ class ColumnParallelLinearImpl : public ParallelLinearImpl {
<< "bias is not loaded for " << prefix + "bias";
}
- void pretty_print(std::ostream& stream) const override {
- stream << name() << " " << weight_.sizes() << " " << weight_.device();
- }
-
// return the weight (for testing)
torch::Tensor weight() const { return weight_; }
@@ -95,10 +91,6 @@ class RowParallelLinearImpl : public ParallelLinearImpl {
<< "bias is not loaded for " << prefix + "bias";
}
- void pretty_print(std::ostream& stream) const override {
- stream << name() << " " << weight_.sizes() << " " << weight_.device();
- }
-
// return the weight (for testing)
torch::Tensor weight() const { return weight_; }
diff --git a/src/layers/normalization.h b/src/layers/normalization.h
index 8996ce78..da75a81b 100644
--- a/src/layers/normalization.h
+++ b/src/layers/normalization.h
@@ -94,38 +94,6 @@ class LayerNormImpl : public Module {
input, normalized_shape_, weight_, bias_, eps_);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- const auto weight = state_dict.get_tensor("weight");
- if (weight.defined()) {
- CHECK_EQ(weight_.sizes(), weight.sizes())
- << "weight size mismatch for " << name();
- weight_.copy_(weight);
- weight_is_loaded_ = true;
- }
- if (bias_.defined()) {
- const auto bias = state_dict.get_tensor("bias");
- if (bias.defined()) {
- CHECK_EQ(bias_.sizes(), bias.sizes())
- << "bias size mismatch for " << name();
- bias_.copy_(bias);
- bias_is_loaded_ = true;
- }
- }
- }
-
- // whether the weight is loaded
- void verify_loaded_weights(const std::string& prefix = "") const {
- CHECK(weight_is_loaded_)
- << "weight is not loaded for " << prefix + "weight";
- CHECK(!bias_.defined() || bias_is_loaded_)
- << "bias is not loaded for " << prefix + "bias";
- }
-
- void pretty_print(std::ostream& stream) const override {
- stream << name() << " " << weight_.sizes() << " " << weight_.device();
- }
-
private:
// parameter members, must be registered
torch::Tensor weight_{nullptr};
@@ -159,26 +127,6 @@ class RMSNormImpl : public Module {
return detail::rms_norm(input, weight_, eps_);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- const auto weight = state_dict.get_tensor("weight");
- if (weight.defined()) {
- CHECK_EQ(weight_.sizes(), weight.sizes())
- << "weight size mismatch for " << name();
- weight_.copy_(weight);
- is_loaded_ = true;
- }
- }
-
- // whether the weight is loaded
- void verify_loaded_weights(const std::string& prefix = "") const {
- CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight";
- }
-
- void pretty_print(std::ostream& stream) const override {
- stream << name() << " " << weight_.sizes() << " " << weight_.device();
- }
-
private:
// parameter members, must be registered
torch::Tensor weight_{nullptr};
@@ -207,26 +155,6 @@ class GemmaRMSNormImpl : public Module {
return detail::gemma_rms_norm(input, weight_, eps_);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- const auto weight = state_dict.get_tensor("weight");
- if (weight.defined()) {
- CHECK_EQ(weight_.sizes(), weight.sizes())
- << "weight size mismatch for " << name();
- weight_.copy_(weight);
- is_loaded_ = true;
- }
- }
-
- // whether the weight is loaded
- void verify_loaded_weights(const std::string& prefix = "") const {
- CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight";
- }
-
- void pretty_print(std::ostream& stream) const override {
- stream << name() << " " << weight_.sizes() << " " << weight_.device();
- }
-
private:
// parameter members, must be registered
torch::Tensor weight_{nullptr};
@@ -268,26 +196,6 @@ class RMSNormResidualImpl : public Module {
return detail::rms_norm(input, weight_, eps_);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- const auto weight = state_dict.get_tensor("weight");
- if (weight.defined()) {
- CHECK_EQ(weight_.sizes(), weight.sizes())
- << "weight size mismatch for " << name();
- weight_.copy_(weight);
- is_loaded_ = true;
- }
- }
-
- // whether the weight is loaded
- void verify_loaded_weights(const std::string& prefix = "") const {
- CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight";
- }
-
- void pretty_print(std::ostream& stream) const override {
- stream << name() << " " << weight_.sizes() << " " << weight_.device();
- }
-
private:
// parameter members, must be registered
torch::Tensor weight_{nullptr};
diff --git a/src/layers/normalization_test.cpp b/src/layers/normalization_test.cpp
index ba25019b..c2aa4031 100644
--- a/src/layers/normalization_test.cpp
+++ b/src/layers/normalization_test.cpp
@@ -26,8 +26,8 @@ TEST(NormalizationTest, LayerNorm) {
LayerNorm norm(dim, eps, /*bias=*/true, options);
// test load state dict
- norm->load_state_dict(state_dict);
- norm->verify_loaded_weights();
+ norm->load(state_dict);
+ EXPECT_TRUE(norm->verify());
// verify output
const auto input = torch::randn({100, dim});
@@ -88,8 +88,8 @@ TEST(NormalizationTest, RMSNorm) {
RMSNorm norm(dim, eps, options);
// test load state dict
- norm->load_state_dict(state_dict);
- norm->verify_loaded_weights();
+ norm->load(state_dict);
+ EXPECT_TRUE(norm->verify());
// verify output
const auto input = torch::randn({100, dim});
diff --git a/src/layers/qkv_linear.cpp b/src/layers/qkv_linear.cpp
index fe4eab14..8ef9b532 100644
--- a/src/layers/qkv_linear.cpp
+++ b/src/layers/qkv_linear.cpp
@@ -10,15 +10,21 @@ QKVColumnParallelLinearImpl::QKVColumnParallelLinearImpl(
int64_t n_heads,
int64_t n_kv_heads,
int64_t head_dim,
+ const std::vector<std::string>& prefixes,
bool bias,
bool gather_output,
const QuantArgs& quant_args,
const ParallelArgs& parallel_args,
- const torch::TensorOptions& options)
- : n_kv_heads_(n_kv_heads), head_dim_(head_dim) {
+ const torch::TensorOptions& options) {
+ CHECK_EQ(prefixes.size(), 3)
+ << "prefixes size must be 3 for q, k, v projections";
+
// calculate logical kv heads with support of MQA/GQA
const int32_t world_size = parallel_args.world_size();
int64_t effective_kv_heads = n_kv_heads;
+ // replication ratio of kv heads for MQA/GQA cases
+ int64_t kv_replication_ratio = 0;
+
if (n_kv_heads >= world_size) {
// partition kv heads evenly across world_size for MHA
CHECK_EQ(n_kv_heads % world_size, 0)
@@ -27,7 +33,7 @@ QKVColumnParallelLinearImpl::QKVColumnParallelLinearImpl(
// replicate kv heads evenly across world_size for GQA/MQA
CHECK_EQ(world_size % n_kv_heads, 0)
<< "kv heads can't be replicated evenly across world_size";
- kv_replication_ratio_ = world_size / n_kv_heads;
+ kv_replication_ratio = world_size / n_kv_heads;
effective_kv_heads = world_size;
}
@@ -36,44 +42,43 @@ QKVColumnParallelLinearImpl::QKVColumnParallelLinearImpl(
effective_kv_heads * head_dim,
effective_kv_heads * head_dim};
- parallel_linear_ = register_module("parallel_linear",
- FusedColumnParallelLinear(hidden_size,
- out_features,
- bias,
- gather_output,
- quant_args,
- parallel_args,
- options));
-}
-
-// special load_state_dict for fused cases
-void QKVColumnParallelLinearImpl::load_state_dict(
- const StateDict& state_dict,
- const std::vector<std::string>& prefixes,
- const std::vector<std::string>& kv_prefixes) {
- if (kv_replication_ratio_ > 1) {
- // replicate kv heads
- auto kv_replicated_state_dict = state_dict.select_with_transform(
- "", [&](const std::string& name, const torch::Tensor& tensor) {
- for (const auto& kv_prefix : kv_prefixes) {
- if (absl::StartsWith(name, kv_prefix)) {
+ // create state_dict selector to handle MQA/GQA cases
+ // for MQA/GQA cases, we need to replicate the weights of kv heads.
+ auto state_dict_selector = [=](const StateDict& sd, const std::string&) {
+ if (kv_replication_ratio <= 1) {
+ return sd.select("");
+ }
+ // replicate kv heads for MQA/GQA cases
+ return sd.select_with_transform(
+ "", [=](const std::string& tensor_name, const torch::Tensor& tensor) {
+ // skip query weights
+ for (size_t i = 1; i < prefixes.size(); ++i) {
+ const auto& kv_prefix = prefixes[i];
+ if (absl::StartsWith(tensor_name, kv_prefix)) {
// reshape to [n_kv_heads, head_dim, ...]
- auto reshaped_tensor =
- tensor.reshape({n_kv_heads_, head_dim_, -1});
+ auto reshaped_tensor = tensor.reshape({n_kv_heads, head_dim, -1});
// interleave repeat kv heads along kv_head dim
reshaped_tensor = reshaped_tensor.repeat_interleave(
- kv_replication_ratio_, /*dim=*/0);
+ kv_replication_ratio, /*dim=*/0);
// reshape to [n_kv_heads * kv_replication_ratio * head_dim, ...]
return reshaped_tensor.reshape(
- {n_kv_heads_ * kv_replication_ratio_ * head_dim_, -1});
+ {n_kv_heads * kv_replication_ratio * head_dim, -1});
}
}
return tensor;
});
- parallel_linear_->load_state_dict(kv_replicated_state_dict, prefixes);
- } else {
- parallel_linear_->load_state_dict(state_dict, prefixes);
- }
+ };
+
+ parallel_linear_ = register_module("qkv_parallel_linear",
+ FusedColumnParallelLinear(hidden_size,
+ out_features,
+ prefixes,
+ bias,
+ gather_output,
+ quant_args,
+ parallel_args,
+ options),
+ state_dict_selector);
}
} // namespace llm
diff --git a/src/layers/qkv_linear.h b/src/layers/qkv_linear.h
index 2e4e2e39..64923e27 100644
--- a/src/layers/qkv_linear.h
+++ b/src/layers/qkv_linear.h
@@ -20,6 +20,7 @@ class QKVColumnParallelLinearImpl : public Module {
int64_t n_heads,
int64_t n_kv_heads,
int64_t head_dim,
+ const std::vector<std::string>& prefixes,
bool bias,
bool gather_output,
const QuantArgs& quant_args,
@@ -30,24 +31,9 @@ class QKVColumnParallelLinearImpl : public Module {
return parallel_linear_->forward(input);
}
- // special load_state_dict for fused cases
- void load_state_dict(const StateDict& state_dict,
- const std::vector<std::string>& prefixes,
- const std::vector<std::string>& kv_prefixes);
-
- void verify_loaded_weights(const std::string& prefix = "") const {
- parallel_linear_->verify_loaded_weights(prefix);
- }
-
private:
+ // registered modules
FusedColumnParallelLinear parallel_linear_{nullptr};
-
- // replication ratio of kv heads for MQA/GQA cases
- int64_t kv_replication_ratio_ = 0;
-
- int64_t n_kv_heads_ = 0;
-
- int64_t head_dim_ = 0;
};
LLM_MODULE(QKVColumnParallelLinear);
diff --git a/src/layers/qkv_linear_test.cpp b/src/layers/qkv_linear_test.cpp
index 359b68e3..86ae2f6b 100644
--- a/src/layers/qkv_linear_test.cpp
+++ b/src/layers/qkv_linear_test.cpp
@@ -9,7 +9,7 @@
namespace llm {
-class QKVLinearTest
+class QKVColumnParallelLinearTest
: public ::testing::TestWithParam<std::tuple<int64_t, int64_t, int64_t, int64_t, int64_t, int64_t>> {};
-TEST_P(QKVLinearTest, LoadFusedWeight) {
+TEST_P(QKVColumnParallelLinearTest, LoadFusedWeight) {
const auto& [n_tokens, n_heads, n_kv_heads, n_shards, head_dim, hidden_size] =
GetParam();
@@ -51,18 +51,19 @@ TEST_P(QKVLinearTest, LoadFusedWeight) {
for (int32_t shard_id = 0; shard_id < n_shards; ++shard_id) {
QuantArgs quant_args;
ParallelArgs parallel_args(shard_id, n_shards, nullptr);
- QKVColumnParallelLinearImpl linear(hidden_size,
- n_heads,
- n_kv_heads,
- head_dim,
- /*bias=*/false,
- /*gather_output=*/false,
- quant_args,
- parallel_args,
- options);
- linear.load_state_dict(state_dict,
- /*prefixes=*/{"query.", "key.", "value."},
- /*kv_prefixes=*/{"key.", "value."});
+ QKVColumnParallelLinearImpl linear(
+ hidden_size,
+ n_heads,
+ n_kv_heads,
+ head_dim,
+ std::vector<std::string>{"query.", "key.", "value."},
+ /*bias=*/false,
+ /*gather_output=*/false,
+ quant_args,
+ parallel_args,
+ options);
+ linear.load(state_dict);
+ EXPECT_TRUE(linear.verify());
// generate random input and compare with the output
auto input = torch::randn({n_tokens, hidden_size}, options);
@@ -84,7 +85,7 @@ TEST_P(QKVLinearTest, LoadFusedWeight) {
INSTANTIATE_TEST_SUITE_P(
QKVLinearTestSuite,
- QKVLinearTest,
+ QKVColumnParallelLinearTest,
::testing::Combine(::testing::Values(10, 32), // n_tokens
::testing::Values(8, 16, 32), // n_heads
::testing::Values(1, 2, 4, 8), // n_kv_heads
diff --git a/src/model_loader/state_dict.cpp b/src/model_loader/state_dict.cpp
index bdfc96b9..f20f9109 100644
--- a/src/model_loader/state_dict.cpp
+++ b/src/model_loader/state_dict.cpp
@@ -219,4 +219,14 @@ StateDict StateDict::select_with_transform(
return selected;
}
+StateDict StateDict::select_with_transform(
+ const std::string& prefix,
+ std::function<torch::Tensor(const torch::Tensor&)> transform_func) const {
+ return select_with_transform(
+ prefix,
+ [transform_func](const std::string&, const torch::Tensor& tensor) {
+ return transform_func(tensor);
+ });
+}
+
} // namespace llm
diff --git a/src/model_loader/state_dict.h b/src/model_loader/state_dict.h
index 58c51039..d130c53e 100644
--- a/src/model_loader/state_dict.h
+++ b/src/model_loader/state_dict.h
@@ -46,6 +46,10 @@ class StateDict final {
StateDict select_with_transform(const std::string& prefix,
TensorTransform transform_func) const;
+ StateDict select_with_transform(
+ const std::string& prefix,
+ std::function<torch::Tensor(const torch::Tensor&)> transform_func) const;
+
size_t size() const { return dict_.size(); }
std::string_view prefix() const { return prefix_; }
diff --git a/src/models/aquila.h b/src/models/_deprecated/aquila.h
similarity index 100%
rename from src/models/aquila.h
rename to src/models/_deprecated/aquila.h
diff --git a/src/models/baichuan.h b/src/models/_deprecated/baichuan.h
similarity index 100%
rename from src/models/baichuan.h
rename to src/models/_deprecated/baichuan.h
diff --git a/src/models/bloom.h b/src/models/_deprecated/bloom.h
similarity index 100%
rename from src/models/bloom.h
rename to src/models/_deprecated/bloom.h
diff --git a/src/models/chatglm.h b/src/models/_deprecated/chatglm.h
similarity index 100%
rename from src/models/chatglm.h
rename to src/models/_deprecated/chatglm.h
diff --git a/src/models/gpt_j.h b/src/models/_deprecated/gpt_j.h
similarity index 100%
rename from src/models/gpt_j.h
rename to src/models/_deprecated/gpt_j.h
diff --git a/src/models/gpt_neox.h b/src/models/_deprecated/gpt_neox.h
similarity index 100%
rename from src/models/gpt_neox.h
rename to src/models/_deprecated/gpt_neox.h
diff --git a/src/models/internlm.h b/src/models/_deprecated/internlm.h
similarity index 100%
rename from src/models/internlm.h
rename to src/models/_deprecated/internlm.h
diff --git a/src/models/mistral.h b/src/models/_deprecated/mistral.h
similarity index 100%
rename from src/models/mistral.h
rename to src/models/_deprecated/mistral.h
diff --git a/src/models/mpt.h b/src/models/_deprecated/mpt.h
similarity index 100%
rename from src/models/mpt.h
rename to src/models/_deprecated/mpt.h
diff --git a/src/models/simple_model.h b/src/models/_deprecated/simple_model.h
similarity index 100%
rename from src/models/simple_model.h
rename to src/models/_deprecated/simple_model.h
diff --git a/src/models/qwen.h b/src/models/alibaba/qwen.h
similarity index 77%
rename from src/models/qwen.h
rename to src/models/alibaba/qwen.h
index d1867e4e..61cf7490 100644
--- a/src/models/qwen.h
+++ b/src/models/alibaba/qwen.h
@@ -10,6 +10,7 @@
#include "layers/attention/attention.h"
#include "layers/attention/handler.h"
#include "layers/embedding.h"
+#include "layers/fused_linear.h"
#include "layers/linear.h"
#include "layers/normalization.h"
#include "memory/kv_cache.h"
@@ -20,7 +21,7 @@
#include "module/module_holder.h"
#include "module/module_list.h"
// QWen model compatible with huggingface weights
-// adopted from https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
+// Adapted from https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
namespace llm::hf {
class QWenMLPImpl : public Module {
@@ -38,14 +39,18 @@ class QWenMLPImpl : public Module {
const int64_t intermediate_size = args.intermediate_size() / 2;
// register the weight parameter
- w1_w2_proj_ = register_module("gate_up_proj",
- ColumnParallelLinear(hidden_size,
- intermediate_size * 2,
- /*bias=*/false,
- /*gather_output=*/false,
- quant_args,
- parallel_args,
- options));
+ gate_up_proj_ = register_module(
+ "gate_up_proj",
+ FusedColumnParallelLinear(
+ hidden_size,
+ std::vector<int64_t>{intermediate_size, intermediate_size},
+ std::vector<std::string>{"w1.", "w2."},
+ /*bias=*/false,
+ /*gather_output=*/false,
+ quant_args,
+ parallel_args,
+ options),
+ /*selector=*/nullptr);
c_proj_ = register_module("c_proj",
RowParallelLinear(intermediate_size,
hidden_size,
@@ -57,26 +62,13 @@ class QWenMLPImpl : public Module {
}
torch::Tensor forward(torch::Tensor x) {
- auto gate_up_proj = w1_w2_proj_(x);
- auto chunks = gate_up_proj.chunk(/*chunks=*/2, /*dim=*/-1);
- return c_proj_(chunks[0] * act_(chunks[1]));
- }
-
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- // call each submodule's load_state_dict function
- w1_w2_proj_->load_state_dict(state_dict, {"w1.", "w2."});
- c_proj_->load_state_dict(state_dict.select("c_proj."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- w1_w2_proj_->verify_loaded_weights(prefix + "[w1,w2].");
- c_proj_->verify_loaded_weights(prefix + "c_proj.");
+ const auto gate_up = gate_up_proj_(x);
+ return c_proj_(gate_up[0] * act_(gate_up[1]));
}
private:
// parameter members, must be registered
- ColumnParallelLinear w1_w2_proj_{nullptr};
+ FusedColumnParallelLinear gate_up_proj_{nullptr};
RowParallelLinear c_proj_{nullptr};
ActFunc act_{nullptr};
@@ -133,18 +125,6 @@ class QWenAttentionImpl : public Module {
return c_proj_(output);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- // call each submodule's load_state_dict function
- c_attn_->load_state_dict(state_dict.select("c_attn."));
- c_proj_->load_state_dict(state_dict.select("c_proj."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- c_attn_->verify_loaded_weights(prefix + "c_attn.");
- c_proj_->verify_loaded_weights(prefix + "c_proj.");
- }
-
private:
// parameter members, must be registered
ColumnParallelLinear c_attn_{nullptr};
@@ -183,22 +163,6 @@ class QWenBlockImpl : public Module {
return h + mlp_(ln_2_(h));
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- // call each submodule's load_state_dict function
- attn_->load_state_dict(state_dict.select("attn."));
- mlp_->load_state_dict(state_dict.select("mlp."));
- ln_1_->load_state_dict(state_dict.select("ln_1."));
- ln_2_->load_state_dict(state_dict.select("ln_2."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- attn_->verify_loaded_weights(prefix + "attn.");
- mlp_->verify_loaded_weights(prefix + "mlp.");
- ln_1_->verify_loaded_weights(prefix + "ln_1.");
- ln_2_->verify_loaded_weights(prefix + "ln_2.");
- }
-
private:
// parameter members, must be registered
QWenAttention attn_{nullptr};
@@ -226,7 +190,7 @@ class QWenModelImpl : public Module {
handler_ = AttentionHandler::create_handler_with_rope(
args, /*interleaved=*/false, options);
- blocks_ = register_module("layers", ModuleList());
+ blocks_ = register_module("h", ModuleList());
layers_.reserve(args.n_layers());
for (int32_t i = 0; i < args.n_layers(); i++) {
auto block =
@@ -254,26 +218,6 @@ class QWenModelImpl : public Module {
return ln_f_(h);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- wte_->load_state_dict(state_dict.select("wte."));
- // call each layer's load_state_dict function
- for (int i = 0; i < layers_.size(); i++) {
- layers_[i]->load_state_dict(
- state_dict.select("h." + std::to_string(i) + "."));
- }
- ln_f_->load_state_dict(state_dict.select("ln_f."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- wte_->verify_loaded_weights(prefix + "wte.");
- for (int i = 0; i < layers_.size(); i++) {
- layers_[i]->verify_loaded_weights(prefix + "h." + std::to_string(i) +
- ".");
- }
- ln_f_->verify_loaded_weights(prefix + "ln_f.");
- }
-
private:
// parameter members, must be registered
ParallelEmbedding wte_{nullptr};
@@ -331,17 +275,6 @@ class QWenForCausalLMImpl : public Module {
return lm_head_(h);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- transformer_->load_state_dict(state_dict.select("transformer."));
- lm_head_->load_state_dict(state_dict.select("lm_head."));
- }
-
- void verify_loaded_weights() const {
- transformer_->verify_loaded_weights("transformer.");
- lm_head_->verify_loaded_weights("lm_head.");
- }
-
private:
// parameter members, must be registered
QWenModel transformer_{nullptr};
diff --git a/src/models/qwen2.h b/src/models/alibaba/qwen2.h
similarity index 78%
rename from src/models/qwen2.h
rename to src/models/alibaba/qwen2.h
index d93de88a..84758c2d 100644
--- a/src/models/qwen2.h
+++ b/src/models/alibaba/qwen2.h
@@ -43,11 +43,13 @@ class QWen2MLPImpl : public Module {
FusedColumnParallelLinear(
hidden_size,
std::vector{intermediate_size, intermediate_size},
+ std::vector{"gate_proj.", "up_proj."},
/*bias=*/false,
/*gather_output=*/false,
quant_args,
parallel_args,
- options));
+ options),
+ /*selector=*/nullptr);
down_proj_ =
register_module("down_proj",
RowParallelLinear(intermediate_size,
@@ -64,18 +66,6 @@ class QWen2MLPImpl : public Module {
return down_proj_(act_func_(gate_up[0]) * gate_up[1]);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- // call each submodule's load_state_dict function
- gate_up_proj_->load_state_dict(state_dict, {"gate_proj.", "up_proj."});
- down_proj_->load_state_dict(state_dict.select("down_proj."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- gate_up_proj_->verify_loaded_weights(prefix + "[gate_proj,up_proj].");
- down_proj_->verify_loaded_weights(prefix + "down_proj.");
- }
-
private:
// parameter members, must be registered
FusedColumnParallelLinear gate_up_proj_{nullptr};
@@ -104,16 +94,20 @@ class QWen2AttentionImpl : public Module {
std::max(1, n_kv_heads / world_size);
// register submodules
- qkv_proj_ = register_module("qkv_proj",
- QKVColumnParallelLinear(hidden_size,
- n_heads,
- n_kv_heads,
- head_dim,
- /*bias=*/true,
- /*gather_output=*/false,
- quant_args,
- parallel_args,
- options));
+ qkv_proj_ = register_module(
+ "qkv_proj",
+ QKVColumnParallelLinear(
+ hidden_size,
+ n_heads,
+ n_kv_heads,
+ head_dim,
+ std::vector{"q_proj.", "k_proj.", "v_proj."},
+ /*bias=*/true,
+ /*gather_output=*/false,
+ quant_args,
+ parallel_args,
+ options),
+ /*selector=*/nullptr);
o_proj_ = register_module("o_proj",
RowParallelLinear(hidden_size,
@@ -146,19 +140,6 @@ class QWen2AttentionImpl : public Module {
return o_proj_(output);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- // call each submodule's load_state_dict function
- qkv_proj_->load_state_dict(
- state_dict, {"q_proj.", "k_proj.", "v_proj."}, {"k_proj.", "v_proj."});
- o_proj_->load_state_dict(state_dict.select("o_proj."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- qkv_proj_->verify_loaded_weights(prefix + "[q_proj,k_proj,v_proj].");
- o_proj_->verify_loaded_weights(prefix + "o_proj.");
- }
-
private:
// parameter members, must be registered
QKVColumnParallelLinear qkv_proj_{nullptr};
@@ -208,24 +189,6 @@ class QWen2DecoderLayerImpl : public Module {
return hidden_states;
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- // call each submodule's load_state_dict function
- self_attn_->load_state_dict(state_dict.select("self_attn."));
- mlp_->load_state_dict(state_dict.select("mlp."));
- input_layernorm_->load_state_dict(state_dict.select("input_layernorm."));
- post_attention_layernorm_->load_state_dict(
- state_dict.select("post_attention_layernorm."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- self_attn_->verify_loaded_weights(prefix + "self_attn.");
- mlp_->verify_loaded_weights(prefix + "mlp.");
- input_layernorm_->verify_loaded_weights(prefix + "input_layernorm.");
- post_attention_layernorm_->verify_loaded_weights(
- prefix + "post_attention_layernorm.");
- }
-
private:
// parameter members, must be registered
QWen2Attention self_attn_{nullptr};
@@ -291,26 +254,6 @@ class QWen2ModelImpl : public Module {
return norm_(h, residual);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- embed_tokens_->load_state_dict(state_dict.select("embed_tokens."));
- // call each layer's load_state_dict function
- for (int i = 0; i < layers_.size(); i++) {
- layers_[i]->load_state_dict(
- state_dict.select("layers." + std::to_string(i) + "."));
- }
- norm_->load_state_dict(state_dict.select("norm."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- embed_tokens_->verify_loaded_weights(prefix + "embed_tokens.");
- for (int i = 0; i < layers_.size(); i++) {
- layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) +
- ".");
- }
- norm_->verify_loaded_weights(prefix + "norm.");
- }
-
private:
// parameter members, must be registered
ParallelEmbedding embed_tokens_{nullptr};
@@ -368,17 +311,6 @@ class QWen2ForCausalLMImpl : public Module {
return lm_head_(h);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- model_->load_state_dict(state_dict.select("model."));
- lm_head_->load_state_dict(state_dict.select("lm_head."));
- }
-
- void verify_loaded_weights() const {
- model_->verify_loaded_weights("model.");
- lm_head_->verify_loaded_weights("lm_head.");
- }
-
private:
// parameter members, must be registered
QWen2Model model_{nullptr};
diff --git a/src/models/causal_lm.h b/src/models/causal_lm.h
index 02a54e29..eaed0ffe 100644
--- a/src/models/causal_lm.h
+++ b/src/models/causal_lm.h
@@ -70,11 +70,15 @@ class CausalLMImpl : public CausalLM {
}
void load_state_dict(const StateDict& state_dict) override {
- model_->load_state_dict(state_dict);
+ model_->load(state_dict);
}
void verify_loaded_weights() const override {
- return model_->verify_loaded_weights();
+ bool success = model_->verify();
+ if (!success) {
+ LOG(FATAL) << "Failed to verify loaded weights for the model."
+ << " Please check the log for more details.";
+ }
}
torch::Device device() const override { return options_.device(); }
diff --git a/src/models/deepseek/README.md b/src/models/deepseek/README.md
new file mode 100644
index 00000000..bc60bf4b
--- /dev/null
+++ b/src/models/deepseek/README.md
@@ -0,0 +1 @@
+TODO:
diff --git a/src/models/gemma.h b/src/models/google/gemma.h
similarity index 77%
rename from src/models/gemma.h
rename to src/models/google/gemma.h
index dcb2c692..7934c9b6 100644
--- a/src/models/gemma.h
+++ b/src/models/google/gemma.h
@@ -11,6 +11,7 @@
#include "layers/attention/handler.h"
#include "layers/embedding.h"
#include "layers/linear.h"
+#include "layers/linear_impl.h"
#include "layers/normalization.h"
#include "layers/qkv_linear.h"
#include "memory/kv_cache.h"
@@ -42,11 +43,13 @@ class GemmaMLPImpl : public Module {
FusedColumnParallelLinear(
hidden_size,
std::vector{intermediate_size, intermediate_size},
+ std::vector{"gate_proj.", "up_proj."},
/*bias=*/false,
/*gather_output=*/false,
quant_args,
parallel_args,
- options));
+ options),
+ /*selector=*/nullptr);
down_proj_ =
register_module("down_proj",
RowParallelLinear(intermediate_size,
@@ -63,18 +66,6 @@ class GemmaMLPImpl : public Module {
return down_proj_(act_func_(gate_up[0]) * gate_up[1]);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- // call each submodule's load_state_dict function
- gate_up_proj_->load_state_dict(state_dict, {"gate_proj.", "up_proj."});
- down_proj_->load_state_dict(state_dict.select("down_proj."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- gate_up_proj_->verify_loaded_weights(prefix + "[gate_proj,up_proj].");
- down_proj_->verify_loaded_weights(prefix + "down_proj.");
- }
-
private:
// parameter members, must be registered
FusedColumnParallelLinear gate_up_proj_{nullptr};
@@ -102,16 +93,20 @@ class GemmaAttentionImpl : public Module {
std::max(1, n_kv_heads / world_size);
// register submodules
- qkv_proj_ = register_module("qkv_proj",
- QKVColumnParallelLinear(hidden_size,
- n_heads,
- n_kv_heads,
- head_dim,
- /*bias=*/false,
- /*gather_output=*/false,
- quant_args,
- parallel_args,
- options));
+ qkv_proj_ = register_module(
+ "qkv_proj",
+ QKVColumnParallelLinear(
+ hidden_size,
+ n_heads,
+ n_kv_heads,
+ head_dim,
+ std::vector{"q_proj.", "k_proj.", "v_proj."},
+ /*bias=*/false,
+ /*gather_output=*/false,
+ quant_args,
+ parallel_args,
+ options),
+ /*selector=*/nullptr);
o_proj_ = register_module("o_proj",
RowParallelLinear(n_heads * head_dim,
@@ -141,19 +136,6 @@ class GemmaAttentionImpl : public Module {
return o_proj_(output);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- // call each submodule's load_state_dict function
- qkv_proj_->load_state_dict(
- state_dict, {"q_proj.", "k_proj.", "v_proj."}, {"k_proj.", "v_proj."});
- o_proj_->load_state_dict(state_dict.select("o_proj."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- qkv_proj_->verify_loaded_weights(prefix + "[q_proj,k_proj,v_proj].");
- o_proj_->verify_loaded_weights(prefix + "o_proj.");
- }
-
private:
// parameter members, must be registered
QKVColumnParallelLinear qkv_proj_{nullptr};
@@ -207,21 +189,6 @@ class GemmaDecoderLayerImpl : public Module {
hidden_states += residual;
return hidden_states;
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- input_layernorm_->load_state_dict((state_dict.select("input_layernorm.")));
- mlp_->load_state_dict(state_dict.select("mlp."));
- post_attention_layernorm_->load_state_dict(
- (state_dict.select("post_attention_layernorm.")));
- self_attn_->load_state_dict(state_dict.select("self_attn."));
- }
- void verify_loaded_weights(const std::string& prefix) const {
- self_attn_->verify_loaded_weights(prefix + "self_attn.");
- mlp_->verify_loaded_weights(prefix + "mlp.");
- input_layernorm_->verify_loaded_weights(prefix + "input_layernorm.");
- post_attention_layernorm_->verify_loaded_weights(
- prefix + "post_attention_layernorm.");
- }
private:
GemmaAttention self_attn_{nullptr};
@@ -286,26 +253,6 @@ class GemmaModelImpl : public Module {
return norm_(h);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- embed_tokens_->load_state_dict(state_dict.select("embed_tokens."));
- // call each layer's load_state_dict function
- for (int i = 0; i < layers_.size(); i++) {
- layers_[i]->load_state_dict(
- state_dict.select("layers." + std::to_string(i) + "."));
- }
- norm_->load_state_dict((state_dict.select("norm.")));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- embed_tokens_->verify_loaded_weights(prefix + "embed_tokens.");
- for (int i = 0; i < layers_.size(); i++) {
- layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) +
- ".");
- }
- norm_->verify_loaded_weights(prefix + "norm.");
- }
-
private:
ModelArgs modelArgs_;
@@ -337,7 +284,7 @@ class GemmaForCausalLMImpl : public Module {
model_ = register_module(
"model", GemmaModel(args, quant_args, parallel_args, options));
- lm_head_ = register_module("lm_head",
+ lm_head_ = register_module("model.embed_tokens",
ColumnParallelLinear(args.hidden_size(),
args.vocab_size(),
/*bias=*/false,
@@ -370,19 +317,6 @@ class GemmaForCausalLMImpl : public Module {
return lm_head_(h);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- model_->load_state_dict(state_dict.select("model."));
-
- // Share the embedding weights with the final llm_head layer.
- lm_head_->load_state_dict(state_dict.select("model.embed_tokens."));
- }
-
- void verify_loaded_weights() const {
- model_->verify_loaded_weights("model.");
- lm_head_->verify_loaded_weights("model.embed_tokens.");
- }
-
private:
// parameter members, must be registered
GemmaModel model_{nullptr};
diff --git a/src/models/gemma2.h b/src/models/google/gemma2.h
similarity index 76%
rename from src/models/gemma2.h
rename to src/models/google/gemma2.h
index 6282af6e..6b4e75d6 100644
--- a/src/models/gemma2.h
+++ b/src/models/google/gemma2.h
@@ -42,11 +42,13 @@ class Gemma2MLPImpl : public Module {
FusedColumnParallelLinear(
hidden_size,
std::vector{intermediate_size, intermediate_size},
+ std::vector{"gate_proj.", "up_proj."},
/*bias=*/false,
/*gather_output=*/false,
quant_args,
parallel_args,
- options));
+ options),
+ /*selector=*/nullptr);
down_proj_ =
register_module("down_proj",
RowParallelLinear(intermediate_size,
@@ -63,18 +65,6 @@ class Gemma2MLPImpl : public Module {
return down_proj_(act_func_(gate_up[0]) * gate_up[1]);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- // call each submodule's load_state_dict function
- gate_up_proj_->load_state_dict(state_dict, {"gate_proj.", "up_proj."});
- down_proj_->load_state_dict(state_dict.select("down_proj."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- gate_up_proj_->verify_loaded_weights(prefix + "[gate_proj,up_proj].");
- down_proj_->verify_loaded_weights(prefix + "down_proj.");
- }
-
private:
// parameter members, must be registered
FusedColumnParallelLinear gate_up_proj_{nullptr};
@@ -103,16 +93,20 @@ class Gemma2AttentionImpl : public Module {
std::max(1, n_kv_heads / world_size);
// register submodules
- qkv_proj_ = register_module("qkv_proj",
- QKVColumnParallelLinear(hidden_size,
- n_heads,
- n_kv_heads,
- head_dim,
- args.attn_bias(),
- /*gather_output=*/false,
- quant_args,
- parallel_args,
- options));
+ qkv_proj_ = register_module(
+ "qkv_proj",
+ QKVColumnParallelLinear(
+ hidden_size,
+ n_heads,
+ n_kv_heads,
+ head_dim,
+ std::vector{"q_proj.", "k_proj.", "v_proj."},
+ args.attn_bias(),
+ /*gather_output=*/false,
+ quant_args,
+ parallel_args,
+ options),
+ /*selector=*/nullptr);
o_proj_ = register_module("o_proj",
RowParallelLinear(n_heads * head_dim,
@@ -146,19 +140,6 @@ class Gemma2AttentionImpl : public Module {
return o_proj_(output);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- // call each submodule's load_state_dict function
- qkv_proj_->load_state_dict(
- state_dict, {"q_proj.", "k_proj.", "v_proj."}, {"k_proj.", "v_proj."});
- o_proj_->load_state_dict(state_dict.select("o_proj."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- qkv_proj_->verify_loaded_weights(prefix + "[q_proj,k_proj,v_proj].");
- o_proj_->verify_loaded_weights(prefix + "o_proj.");
- }
-
private:
// parameter members, must be registered
QKVColumnParallelLinear qkv_proj_{nullptr};
@@ -167,9 +148,6 @@ class Gemma2AttentionImpl : public Module {
// module members without parameters
Attention atten_{nullptr};
-
- // size for q, k, v
- std::vector qkv_sizes_;
};
LLM_MODULE(Gemma2Attention);
@@ -226,29 +204,6 @@ class Gemma2DecoderLayerImpl : public Module {
hidden_states += residual;
return hidden_states;
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- input_layernorm_->load_state_dict((state_dict.select("input_layernorm.")));
- self_attn_->load_state_dict(state_dict.select("self_attn."));
- mlp_->load_state_dict(state_dict.select("mlp."));
- post_attention_layernorm_->load_state_dict(
- (state_dict.select("post_attention_layernorm.")));
- pre_feedforward_layernorm_->load_state_dict(
- state_dict.select("pre_feedforward_layernorm."));
- post_feedforward_layernorm_->load_state_dict(
- state_dict.select("post_feedforward_layernorm."));
- }
- void verify_loaded_weights(const std::string& prefix) const {
- input_layernorm_->verify_loaded_weights(prefix + "input_layernorm.");
- self_attn_->verify_loaded_weights(prefix + "self_attn.");
- mlp_->verify_loaded_weights(prefix + "mlp.");
- post_attention_layernorm_->verify_loaded_weights(
- prefix + "post_attention_layernorm.");
- pre_feedforward_layernorm_->verify_loaded_weights(
- prefix + "pre_feedforward_layernorm.");
- post_feedforward_layernorm_->verify_loaded_weights(
- prefix + "post_feedforward_layernorm.");
- }
private:
Gemma2Attention self_attn_{nullptr};
@@ -322,26 +277,6 @@ class Gemma2ModelImpl : public Module {
return norm_(h);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- embed_tokens_->load_state_dict(state_dict.select("embed_tokens."));
- // call each layer's load_state_dict function
- for (int i = 0; i < layers_.size(); i++) {
- layers_[i]->load_state_dict(
- state_dict.select("layers." + std::to_string(i) + "."));
- }
- norm_->load_state_dict((state_dict.select("norm.")));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- embed_tokens_->verify_loaded_weights(prefix + "embed_tokens.");
- for (int i = 0; i < layers_.size(); i++) {
- layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) +
- ".");
- }
- norm_->verify_loaded_weights(prefix + "norm.");
- }
-
private:
ModelArgs modelArgs_;
@@ -375,7 +310,7 @@ class Gemma2ForCausalLMImpl : public Module {
model_ = register_module(
"model", Gemma2Model(args, quant_args, parallel_args, options));
- lm_head_ = register_module("lm_head",
+ lm_head_ = register_module("model.embed_tokens",
ColumnParallelLinear(args.hidden_size(),
args.vocab_size(),
/*bias=*/false,
@@ -409,19 +344,6 @@ class Gemma2ForCausalLMImpl : public Module {
final_logit_soft_cap_;
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- model_->load_state_dict(state_dict.select("model."));
-
- // Share the embedding weights with the final llm_head layer.
- lm_head_->load_state_dict(state_dict.select("model.embed_tokens."));
- }
-
- void verify_loaded_weights() const {
- model_->verify_loaded_weights("model.");
- lm_head_->verify_loaded_weights("model.embed_tokens.");
- }
-
private:
// parameter members, must be registered
Gemma2Model model_{nullptr};
diff --git a/src/models/llama.h b/src/models/meta/llama.h
similarity index 80%
rename from src/models/llama.h
rename to src/models/meta/llama.h
index 120ff48d..d5b3a17d 100644
--- a/src/models/llama.h
+++ b/src/models/meta/llama.h
@@ -40,11 +40,13 @@ class LlamaMLPImpl : public Module {
FusedColumnParallelLinear(
hidden_size,
std::vector{intermediate_size, intermediate_size},
+ std::vector{"gate_proj.", "up_proj."},
/*bias=*/false,
/*gather_output=*/false,
quant_args,
parallel_args,
- options));
+ options),
+ /*selector=*/nullptr);
down_proj_ =
register_module("down_proj",
@@ -62,18 +64,6 @@ class LlamaMLPImpl : public Module {
return down_proj_(act_func_(gate_up[0]) * gate_up[1]);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- // call each submodule's load_state_dict function
- gate_up_proj_->load_state_dict(state_dict, {"gate_proj.", "up_proj."});
- down_proj_->load_state_dict(state_dict.select("down_proj."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- gate_up_proj_->verify_loaded_weights(prefix + "[gate_proj,up_proj].");
- down_proj_->verify_loaded_weights(prefix + "down_proj.");
- }
-
private:
// parameter members, must be registered
FusedColumnParallelLinear gate_up_proj_{nullptr};
@@ -101,16 +91,20 @@ class LlamaAttentionImpl : public Module {
std::max(1, n_kv_heads / world_size);
// register submodules
- qkv_proj_ = register_module("qkv_proj",
- QKVColumnParallelLinear(hidden_size,
- n_heads,
- n_kv_heads,
- head_dim,
- /*bias=*/false,
- /*gather_output=*/false,
- quant_args,
- parallel_args,
- options));
+ qkv_proj_ = register_module(
+ "qkv_proj",
+ QKVColumnParallelLinear(
+ hidden_size,
+ n_heads,
+ n_kv_heads,
+ head_dim,
+ std::vector{"q_proj.", "k_proj.", "v_proj."},
+ /*bias=*/false,
+ /*gather_output=*/false,
+ quant_args,
+ parallel_args,
+ options),
+ /*selector=*/nullptr);
o_proj_ = register_module("o_proj",
RowParallelLinear(hidden_size,
@@ -140,19 +134,6 @@ class LlamaAttentionImpl : public Module {
return o_proj_(output);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- // call each submodule's load_state_dict function
- qkv_proj_->load_state_dict(
- state_dict, {"q_proj.", "k_proj.", "v_proj."}, {"k_proj.", "v_proj."});
- o_proj_->load_state_dict(state_dict.select("o_proj."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- qkv_proj_->verify_loaded_weights(prefix + "[q_proj,k_proj,v_proj].");
- o_proj_->verify_loaded_weights(prefix + "o_proj.");
- }
-
private:
// parameter members, must be registered
QKVColumnParallelLinear qkv_proj_{nullptr};
@@ -197,24 +178,6 @@ class LlamaDecoderLayerImpl : public Module {
return h + mlp_(post_attention_layernorm_(h));
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- // call each submodule's load_state_dict function
- self_attn_->load_state_dict(state_dict.select("self_attn."));
- mlp_->load_state_dict(state_dict.select("mlp."));
- input_layernorm_->load_state_dict(state_dict.select("input_layernorm."));
- post_attention_layernorm_->load_state_dict(
- state_dict.select("post_attention_layernorm."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- self_attn_->verify_loaded_weights(prefix + "self_attn.");
- mlp_->verify_loaded_weights(prefix + "mlp.");
- input_layernorm_->verify_loaded_weights(prefix + "input_layernorm.");
- post_attention_layernorm_->verify_loaded_weights(
- prefix + "post_attention_layernorm.");
- }
-
private:
// parameter members, must be registered
LlamaAttention self_attn_{nullptr};
@@ -270,26 +233,6 @@ class LlamaModelImpl : public Module {
return norm_(h);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- embed_tokens_->load_state_dict(state_dict.select("embed_tokens."));
- // call each layer's load_state_dict function
- for (int i = 0; i < layers_.size(); i++) {
- layers_[i]->load_state_dict(
- state_dict.select("layers." + std::to_string(i) + "."));
- }
- norm_->load_state_dict(state_dict.select("norm."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- embed_tokens_->verify_loaded_weights(prefix + "embed_tokens.");
- for (int i = 0; i < layers_.size(); i++) {
- layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) +
- ".");
- }
- norm_->verify_loaded_weights(prefix + "norm.");
- }
-
private:
// parameter members, must be registered
ParallelEmbedding embed_tokens_{nullptr};
@@ -347,17 +290,6 @@ class LlamaForCausalLMImpl : public Module {
return lm_head_(h);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- model_->load_state_dict(state_dict.select("model."));
- lm_head_->load_state_dict(state_dict.select("lm_head."));
- }
-
- void verify_loaded_weights() const {
- model_->verify_loaded_weights("model.");
- lm_head_->verify_loaded_weights("lm_head.");
- }
-
private:
// parameter members, must be registered
LlamaModel model_{nullptr};
diff --git a/src/models/phi.h b/src/models/microsoft/phi.h
similarity index 80%
rename from src/models/phi.h
rename to src/models/microsoft/phi.h
index 99e8dc07..f81ed123 100644
--- a/src/models/phi.h
+++ b/src/models/microsoft/phi.h
@@ -52,18 +52,6 @@ class PhiMLPImpl : public Module {
torch::Tensor forward(torch::Tensor x) { return fc2_(act_(fc1_(x))); }
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- // call each submodule's load_state_dict function
- fc1_->load_state_dict(state_dict.select("fc1."));
- fc2_->load_state_dict(state_dict.select("fc2."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- fc1_->verify_loaded_weights(prefix + "fc1.");
- fc2_->verify_loaded_weights(prefix + "fc2.");
- }
-
private:
// parameter members, must be registered
ColumnParallelLinear fc1_{nullptr};
@@ -134,18 +122,6 @@ class PhiAttentionImpl : public Module {
return out_proj_(output);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- // call each submodule's load_state_dict function
- Wqkv_->load_state_dict(state_dict.select("Wqkv."));
- out_proj_->load_state_dict(state_dict.select("out_proj."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- Wqkv_->verify_loaded_weights(prefix + "Wqkv.");
- out_proj_->verify_loaded_weights(prefix + "out_proj.");
- }
-
private:
// parameter members, must be registered
ColumnParallelLinear Wqkv_{nullptr};
@@ -191,20 +167,6 @@ class PhiBlockImpl : public Module {
return x + attn_output + mlp_output;
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- // call each submodule's load_state_dict function
- mixer_->load_state_dict(state_dict.select("mixer."));
- mlp_->load_state_dict(state_dict.select("mlp."));
- ln_->load_state_dict(state_dict.select("ln."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- mixer_->verify_loaded_weights(prefix + "mixer.");
- mlp_->verify_loaded_weights(prefix + "mlp.");
- ln_->verify_loaded_weights(prefix + "ln.");
- }
-
private:
// parameter members, must be registered
PhiAttention mixer_{nullptr};
@@ -223,7 +185,7 @@ class PhiModelImpl : public Module {
const torch::TensorOptions& options) {
// register submodules
wte_ = register_module(
- "wte",
+ "embd.wte",
ParallelEmbedding(
args.vocab_size(), args.hidden_size(), parallel_args, options));
@@ -256,24 +218,6 @@ class PhiModelImpl : public Module {
return h;
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- wte_->load_state_dict(state_dict.select("embd.wte."));
- // call each layer's load_state_dict function
- for (int i = 0; i < layers_.size(); i++) {
- layers_[i]->load_state_dict(
- state_dict.select("h." + std::to_string(i) + "."));
- }
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- wte_->verify_loaded_weights(prefix + "embd.wte.");
- for (int i = 0; i < layers_.size(); i++) {
- layers_[i]->verify_loaded_weights(prefix + "h." + std::to_string(i) +
- ".");
- }
- }
-
private:
// parameter members, must be registered
ParallelEmbedding wte_{nullptr};
@@ -310,17 +254,6 @@ class PhiLMHeadImpl : public Module {
torch::Tensor forward(torch::Tensor x) { return linear_(ln_(x)); }
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- ln_->load_state_dict(state_dict.select("ln."));
- linear_->load_state_dict(state_dict.select("linear."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- ln_->verify_loaded_weights(prefix + "ln.");
- linear_->verify_loaded_weights(prefix + "linear.");
- }
-
private:
// parameter members, must be registered
LayerNorm ln_{nullptr};
@@ -366,17 +299,6 @@ class PhiForCausalLMImpl : public Module {
return lm_head_(h);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- transformer_->load_state_dict(state_dict.select("transformer."));
- lm_head_->load_state_dict(state_dict.select("lm_head."));
- }
-
- void verify_loaded_weights() const {
- transformer_->verify_loaded_weights("transformer.");
- lm_head_->verify_loaded_weights("lm_head.");
- }
-
private:
// parameter members, must be registered
PhiModel transformer_{nullptr};
diff --git a/src/models/model_registry.cpp b/src/models/model_registry.cpp
index b4c7c5b9..c41ace43 100644
--- a/src/models/model_registry.cpp
+++ b/src/models/model_registry.cpp
@@ -2,7 +2,7 @@
#include
-#include "models.h" // IWYU pragma: keep
+#include "registered_models.h" // IWYU pragma: keep
namespace llm {
diff --git a/src/models/models.h b/src/models/models.h
deleted file mode 100644
index e569821e..00000000
--- a/src/models/models.h
+++ /dev/null
@@ -1,19 +0,0 @@
-#pragma once
-
-// list all registered models here
-#include "aquila.h" // IWYU pragma: keep
-#include "baichuan.h" // IWYU pargma: keep
-#include "bloom.h" // IWYU pragma: keep
-#include "chatglm.h" // IWYU pragma: keep
-#include "gemma.h" // IWYU pragma: keep
-#include "gemma2.h" // IWYU pragma: keep
-#include "gpt2.h" // IWYU pragma: keep
-#include "gpt_j.h" // IWYU pragma: keep
-#include "gpt_neox.h" // IWYU pragma: keep
-#include "internlm.h" // IWYU pragma: keep
-#include "llama.h" // IWYU pragma: keep
-#include "mistral.h" // IWYU pragma: keep
-#include "mpt.h" // IWYU pragma: keep
-#include "phi.h" // IWYU pragma: keep
-#include "qwen.h" // IWYU pragma: keep
-#include "qwen2.h" // IWYU pragma: keep
diff --git a/src/models/gpt2.h b/src/models/openai/gpt2.h
similarity index 74%
rename from src/models/gpt2.h
rename to src/models/openai/gpt2.h
index afde1084..2644ec48 100644
--- a/src/models/gpt2.h
+++ b/src/models/openai/gpt2.h
@@ -8,6 +8,7 @@
#include "layers/attention/handler.h"
#include "layers/embedding.h"
#include "layers/linear.h"
+#include "layers/linear_impl.h"
#include "layers/normalization.h"
#include "memory/kv_cache.h"
#include "models/model_args.h"
@@ -21,6 +22,17 @@
namespace llm::hf {
+namespace detail {
+// GPT-2 implementation uses Conv1D instead of Linear. As a result, we
+// need to transpose the weight for linear layer when loading from checkpoint.
+static StateDict transpose_selector(const StateDict& sd,
+ const std::string& key) {
+ // transpose the weight
+ return sd.select_with_transform(
+ key + ".", [](const torch::Tensor& tensor) { return tensor.t(); });
+}
+} // namespace detail
+
class GPT2MLPImpl : public Module {
public:
GPT2MLPImpl(const ModelArgs& args,
@@ -41,7 +53,8 @@ class GPT2MLPImpl : public Module {
/*gather_output=*/false,
quant_args,
parallel_args,
- options));
+ options),
+ detail::transpose_selector);
c_proj_ = register_module("c_proj",
RowParallelLinear(intermediate_size,
hidden_size,
@@ -49,33 +62,12 @@ class GPT2MLPImpl : public Module {
/*input_is_parallelized=*/true,
quant_args,
parallel_args,
- options));
+ options),
+ detail::transpose_selector);
}
torch::Tensor forward(torch::Tensor x) { return c_proj_(act_(c_fc_(x))); }
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- // call each submodule's load_state_dict function
- // GPT-2 implementation uses Conv1D instead of Linear. As a result, we
- // need to transpose the weight.
- c_fc_->load_state_dict(state_dict.select_with_transform(
- "c_fc.",
- [](const std::string_view& /*name*/, const torch::Tensor& tensor) {
- return tensor.t();
- }));
- c_proj_->load_state_dict(state_dict.select_with_transform(
- "c_proj.",
- [](const std::string_view& /*name*/, const torch::Tensor& tensor) {
- return tensor.t();
- }));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- c_fc_->verify_loaded_weights(prefix + "c_fc.");
- c_proj_->verify_loaded_weights(prefix + "c_proj.");
- }
-
private:
// parameter members, must be registered
ColumnParallelLinear c_fc_{nullptr};
@@ -105,7 +97,8 @@ class GPT2AttentionImpl : public Module {
/*gather_output=*/false,
quant_args,
parallel_args,
- options));
+ options),
+ detail::transpose_selector);
c_proj_ = register_module("c_proj",
RowParallelLinear(hidden_size_,
@@ -114,7 +107,8 @@ class GPT2AttentionImpl : public Module {
/*input_is_parallelized=*/true,
quant_args,
parallel_args,
- options));
+ options),
+ detail::transpose_selector);
// initialize attention
atten_ = register_module(
@@ -138,27 +132,6 @@ class GPT2AttentionImpl : public Module {
return c_proj_(output);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- // GPT-2 implementation uses Conv1D instead of Linear. As a result, we
- // need to transpose the weight.
- c_attn_->load_state_dict(state_dict.select_with_transform(
- "c_attn.",
- [](const std::string_view& /*name*/, const torch::Tensor& tensor) {
- return tensor.t();
- }));
- c_proj_->load_state_dict(state_dict.select_with_transform(
- "c_proj.",
- [](const std::string_view& /*name*/, const torch::Tensor& tensor) {
- return tensor.t();
- }));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- c_attn_->verify_loaded_weights(prefix + "c_attn.");
- c_proj_->verify_loaded_weights(prefix + "c_proj.");
- }
-
private:
// parameter members, must be registered
ColumnParallelLinear c_attn_{nullptr};
@@ -207,22 +180,6 @@ class GPT2BlockImpl : public Module {
return h + mlp_(ln_2_(h));
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- // call each submodule's load_state_dict function
- attn_->load_state_dict(state_dict.select("attn."));
- mlp_->load_state_dict(state_dict.select("mlp."));
- ln_1_->load_state_dict(state_dict.select("ln_1."));
- ln_2_->load_state_dict(state_dict.select("ln_2."));
- }
-
- void verify_loaded_weights(const std::string& prefix) const {
- attn_->verify_loaded_weights(prefix + "attn.");
- mlp_->verify_loaded_weights(prefix + "mlp.");
- ln_1_->verify_loaded_weights(prefix + "ln_1.");
- ln_2_->verify_loaded_weights(prefix + "ln_2.");
- }
-
private:
// parameter members, must be registered
GPT2Attention attn_{nullptr};
@@ -282,27 +239,6 @@ class GPT2ModelImpl : public Module {
return ln_f_(h);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- wte_->load_state_dict(state_dict.select("wte."));
- wpe_->load_state_dict(state_dict.select("wpe."));
- // call each layer's load_state_dict function
- for (int i = 0; i < layers_.size(); i++) {
- layers_[i]->load_state_dict(
- state_dict.select("h." + std::to_string(i) + "."));
- }
- ln_f_->load_state_dict(state_dict.select("ln_f."));
- }
-
- void verify_loaded_weights() const {
- wte_->verify_loaded_weights("wte.");
- wpe_->verify_loaded_weights("wpe.");
- for (int i = 0; i < layers_.size(); i++) {
- layers_[i]->verify_loaded_weights("h." + std::to_string(i) + ".");
- }
- ln_f_->verify_loaded_weights("ln_f.");
- }
-
private:
// parameter members, must be registered
ParallelEmbedding wte_{nullptr};
@@ -328,10 +264,13 @@ class GPT2ForCausalLMImpl : public Module {
const ParallelArgs& parallel_args,
const torch::TensorOptions& options) {
// register submodules
- model_ = register_module(
- "model", GPT2Model(args, quant_args, parallel_args, options));
+ model_ =
+ register_module("model",
+ GPT2Model(args, quant_args, parallel_args, options),
+ /*selector=*/nullptr);
- lm_head_ = register_module("lm_head",
+ // load wte.weight
+ lm_head_ = register_module("wte",
ColumnParallelLinear(args.hidden_size(),
args.vocab_size(),
/*bias=*/false,
@@ -363,18 +302,6 @@ class GPT2ForCausalLMImpl : public Module {
return lm_head_(h);
}
- // load the weight from the checkpoint
- void load_state_dict(const StateDict& state_dict) {
- model_->load_state_dict(state_dict);
- // TODO: share wte_ weight with lm_head_ to save memory
- lm_head_->load_state_dict(state_dict.select("wte."));
- }
-
- void verify_loaded_weights() const {
- model_->verify_loaded_weights();
- lm_head_->verify_loaded_weights("wte.");
- }
-
private:
// parameter members, must be registered
GPT2Model model_{nullptr};
diff --git a/src/models/registered_models.h b/src/models/registered_models.h
new file mode 100644
index 00000000..fba44c42
--- /dev/null
+++ b/src/models/registered_models.h
@@ -0,0 +1,26 @@
+#pragma once
+
+// list all registered models here
+// Google
+#include "google/gemma.h" // IWYU pragma: keep
+#include "google/gemma2.h" // IWYU pragma: keep
+// OpenAI
+#include "openai/gpt2.h" // IWYU pragma: keep
+// Meta
+#include "meta/llama.h" // IWYU pragma: keep
+// Microsoft
+#include "microsoft/phi.h" // IWYU pragma: keep
+// Alibaba
+#include "alibaba/qwen.h" // IWYU pragma: keep
+#include "alibaba/qwen2.h" // IWYU pragma: keep
+
+// Deprecated models
+// #include "deprecated/aquila.h"
+// #include "deprecated/baichuan.h"
+// #include "deprecated/bloom.h"
+// #include "deprecated/chatglm.h"
+// #include "deprecated/gpt_j.h"
+// #include "deprecated/gpt_neox.h"
+// #include "deprecated/internlm.h"
+// #include "deprecated/mistral.h"
+// #include "deprecated/mpt.h"
diff --git a/src/module/CMakeLists.txt b/src/module/CMakeLists.txt
index ceaab617..043f5de7 100644
--- a/src/module/CMakeLists.txt
+++ b/src/module/CMakeLists.txt
@@ -1,4 +1,5 @@
include(cc_library)
+include(cc_test)
cc_library(
NAME
@@ -13,3 +14,15 @@ cc_library(
glog::glog
torch
)
+
+cc_test(
+ NAME
+ module_test
+ SRCS
+ module_test.cpp
+ DEPS
+ :module
+ :state_dict
+ :gtest_main
+ torch
+)
diff --git a/src/module/module.cpp b/src/module/module.cpp
index 45f2546c..29d730b7 100644
--- a/src/module/module.cpp
+++ b/src/module/module.cpp
@@ -2,6 +2,8 @@
#include
+#include "model_loader/state_dict.h"
+
namespace llm {
using namespace torch;
@@ -49,8 +51,9 @@ OrderedDict Module::named_parameters(bool recurse) const {
OrderedDict result;
if (!recurse) {
for (const auto& parameter : parameters_) {
- if (parameter.value().defined()) {
- result.insert(parameter.key(), parameter.value());
+ const auto& param_tensor = parameter.value().tensor;
+ if (param_tensor.defined()) {
+ result.insert(parameter.key(), param_tensor);
}
}
} else {
@@ -125,12 +128,20 @@ OrderedDict> Module::named_modules(
}
std::vector> Module::children() const {
- return children_.values();
+  std::vector<std::shared_ptr<Module>> result;
+ for (const auto& item : children_) {
+ result.push_back(item.value().module);
+ }
+ return result;
}
OrderedDict> Module::named_children()
const {
- return children_;
+  OrderedDict<std::string, std::shared_ptr<Module>> result;
+ for (const auto& item : children_) {
+ result.insert(item.key(), item.value().module);
+ }
+ return result;
}
void Module::apply(const ModuleApplyFunction& function) {
@@ -198,31 +209,51 @@ void Module::to(torch::Device device, bool non_blocking) {
to_impl(device, non_blocking);
}
+Tensor& Module::register_parameter(std::string name,
+ Tensor tensor,
+ const TensorLoader& loader) {
+ TORCH_CHECK(!name.empty(), "Parameter name must not be empty");
+ // tensor.set_requires_grad(false);
+ Parameter param{
+ .tensor = std::move(tensor), .loader = loader, .is_loaded = false};
+ return parameters_.insert(std::move(name), std::move(param)).tensor;
+}
+
Tensor& Module::register_parameter(std::string name, Tensor tensor) {
+ auto default_loader = [](const StateDict& sd, const std::string& key) {
+ return sd.get_tensor(key);
+ };
+ return register_parameter(std::move(name), std::move(tensor), default_loader);
+}
+
+Tensor& Module::register_sharded_parameter(std::string name,
+ int dim,
+ int rank,
+ int world_size,
+ Tensor tensor) {
TORCH_CHECK(!name.empty(), "Parameter name must not be empty");
- TORCH_CHECK(name.find('.') == std::string::npos,
- "Parameter name must not contain a dot (got '",
- name,
- "')");
- tensor.set_requires_grad(false);
- return parameters_.insert(std::move(name), std::move(tensor));
+ // tensor.set_requires_grad(false);
+ Parameter param{
+ .tensor = std::move(tensor),
+ .loader =
+ [dim, rank, world_size](const StateDict& sd, const std::string& key) {
+ return sd.get_sharded_tensor(key, dim, rank, world_size);
+ },
+ .is_loaded = false};
+ return parameters_.insert(std::move(name), std::move(param)).tensor;
}
Tensor& Module::register_buffer(std::string name, Tensor tensor) {
TORCH_CHECK(!name.empty(), "Buffer name must not be empty");
- TORCH_CHECK(name.find('.') == std::string::npos,
- "Buffer name must not contain a dot (got '",
- name,
- "')");
return buffers_.insert(std::move(name), std::move(tensor));
}
-void Module::unregister_module(const std::string& name) {
- TORCH_CHECK(children_.contains(name),
- "No Module with name `",
- name,
- "` is registered");
+bool Module::unregister_module(const std::string& name) {
+ if (!children_.contains(name)) {
+ return false;
+ }
children_.erase(name);
+ return true;
}
void Module::pretty_print(std::ostream& stream) const { stream << name(); }
@@ -235,7 +266,7 @@ void Module::pretty_print_recursive(std::ostream& stream,
const std::string next_indentation = indentation + " ";
for (const auto& child : children_) {
stream << next_indentation << "(" << child.key() << "): ";
- child.value()->pretty_print_recursive(stream, next_indentation);
+ child.value().module->pretty_print_recursive(stream, next_indentation);
stream << '\n';
}
stream << indentation << ")";
@@ -248,8 +279,8 @@ void Module::apply_to_submodules(
const std::string& name_prefix) const {
for (const auto& child : children_) {
auto qualified_name = join_name(name_prefix, child.key());
- function(qualified_name, child.value());
- child.value()->apply_to_submodules(function, qualified_name);
+ function(qualified_name, child.value().module);
+ child.value().module->apply_to_submodules(function, qualified_name);
}
}
@@ -276,4 +307,89 @@ std::ostream& operator<<(std::ostream& stream, const Module& module) {
module.pretty_print_recursive(stream, "");
return stream;
}
+
+// load weights from the checkpoint, override this method if necessary
+// NOLINTNEXTLINE(misc-no-recursion)
+size_t Module::load(const StateDict& state_dict,
+ const std::string& name_prefix) {
+ size_t total_loaded = 0;
+ // load parameters one by one
+ for (auto& item : parameters_) {
+ const auto& key = item.key();
+ auto& param = item.value();
+ const auto& param_tensor = param.tensor;
+ // skip loading for undefined tensor
+ if (!param_tensor.defined()) {
+ // mark as loaded to pass verification
+ param.is_loaded = true;
+ continue;
+ }
+
+ const auto tensor = param.loader(state_dict, key);
+ if (!tensor.defined()) {
+ continue;
+ }
+
+ if (param.is_loaded) {
+ LOG(WARNING) << "Parameter " << join_name(name_prefix, key)
+ << " is already loaded";
+ }
+
+ if (param_tensor.sizes() == tensor.sizes()) {
+ // LOG(INFO) << "Loading parameter: " << join_name(name_prefix, key)
+ // << " of size " << tensor.sizes();
+ // copy data to the parameter tensor
+ param_tensor.copy_(tensor);
+ // mark as loaded
+ param.is_loaded = true;
+ ++total_loaded;
+ } else {
+ LOG(ERROR) << "Size mismatch for parameter "
+ << join_name(name_prefix, key) << ": expected "
+ << param_tensor.sizes() << ", got " << tensor.sizes();
+ }
+ }
+
+ // don't need to load buffers, since they are initialized in the constructor
+
+ // recursively load children modules
+ for (const auto& item : children_) {
+ const auto& key = item.key();
+ const auto& child = item.value();
+ if (child.selector) {
+ // select state dict for the child module
+ const auto child_state_dict = child.selector(state_dict, key);
+ total_loaded +=
+ child.module->load(child_state_dict, join_name(name_prefix, key));
+ } else {
+ total_loaded += child.module->load(state_dict, name_prefix);
+ }
+ }
+ return total_loaded;
+}
+
+// verify whether the weights are loaded, override this method if necessary
+// NOLINTNEXTLINE(misc-no-recursion)
+bool Module::verify(const std::string& name_prefix) const {
+ bool all_loaded = true;
+ for (const auto& item : parameters_) {
+ const auto& key = item.key();
+ const auto& param = item.value();
+ if (!param.is_loaded) {
+ LOG(ERROR) << "Missing parameter: " << join_name(name_prefix, key)
+ << ", size: " << param.tensor.sizes();
+ }
+ all_loaded = all_loaded && param.is_loaded;
+ }
+
+ for (const auto& item : children_) {
+ const auto& key = item.key();
+ const auto& child = item.value();
+ const std::string prefix =
+ child.selector ? join_name(name_prefix, key) : name_prefix;
+ const bool child_loaded = child.module->verify(prefix);
+ all_loaded = all_loaded && child_loaded;
+ }
+ return all_loaded;
+}
} // namespace llm
diff --git a/src/module/module.h b/src/module/module.h
index 69a63109..6cb96ff6 100644
--- a/src/module/module.h
+++ b/src/module/module.h
@@ -8,7 +8,9 @@
#include
#include
#include
+#include <functional>
+#include "model_loader/state_dict.h"
#include "module_holder.h"
namespace llm {
@@ -148,43 +150,61 @@ class Module : public std::enable_shared_from_this {
/// `stream` should be returned from the method, to allow easy chaining.
virtual void pretty_print(std::ostream& stream) const;
+ // load weights from the checkpoint, override this method if necessary
+ // returns the number of loaded parameters
+ virtual size_t load(const StateDict& state_dict,
+ const std::string& name_prefix = std::string());
+
+ // verify whether the weights are loaded, override this method if necessary
+ virtual bool verify(const std::string& name_prefix = std::string()) const;
+
/// Registers a parameter with this `Module`.
+  using TensorLoader =
+      std::function<torch::Tensor(const StateDict&, const std::string&)>;
+ torch::Tensor& register_parameter(std::string name,
+ torch::Tensor tensor,
+ const TensorLoader& loader);
+
torch::Tensor& register_parameter(std::string name, torch::Tensor tensor);
+ torch::Tensor& register_sharded_parameter(std::string name,
+ int dim,
+ int rank,
+ int world_size,
+ torch::Tensor tensor);
+
/// Registers a buffer with this `Module`.
torch::Tensor& register_buffer(std::string name, torch::Tensor tensor);
/// Registers a submodule with this `Module`.
+  using StateDictSelector =
+      std::function<StateDict(const StateDict&, const std::string&)>;
+
template
std::shared_ptr register_module(
std::string name,
- std::shared_ptr module);
+      std::shared_ptr<ModuleType> module,
+ const StateDictSelector& selector);
template
std::shared_ptr register_module(
std::string name,
- ModuleHolder module_holder);
-
- /// Replaces a registered submodule with this `Module`.
- template
- std::shared_ptr replace_module(
- const std::string& name,
std::shared_ptr module);
template
- std::shared_ptr replace_module(
- const std::string& name,
+  std::shared_ptr<ModuleType> register_module(
+ std::string name,
ModuleHolder module_holder);
- /// Unregisters a submodule from this `Module`. If there is no such module
- /// with `name` an exception is thrown.
- void unregister_module(const std::string& name);
+  template <typename ModuleType>
+  std::shared_ptr<ModuleType> register_module(
+      std::string name,
+      ModuleHolder<ModuleType> module_holder,
+      const StateDictSelector& selector);
- protected:
- /// The registered parameters of this `Module`.
- /// Inorder to access parameters_ in ParameterDict and ParameterList
- // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
- torch::OrderedDict parameters_;
+ /// Unregisters a submodule from this `Module`. Returns false if no such
+ /// submodule was registered.
+ bool unregister_module(const std::string& name);
private:
/// Pretty prints the given `Module` into the `ostream`.
@@ -209,11 +229,25 @@ class Module : public std::enable_shared_from_this {
/// Returns a shared_ptr to `this` in a safe (checked) way.
std::shared_ptr shared_from_this_checked() const;
+  struct Parameter {
+    torch::Tensor tensor;
+    std::function<torch::Tensor(const StateDict&, const std::string&)> loader;
+    bool is_loaded = false;
+  };
+
+  struct Child {
+    std::shared_ptr<Module> module;
+    std::function<StateDict(const StateDict&, const std::string&)> selector;
+  };
+
+ /// The registered parameters of this `Module`.
+  torch::OrderedDict<std::string, Parameter> parameters_;
+
/// The registered buffers of this `Module`.
torch::OrderedDict buffers_;
/// The registered (direct) submodules of this `Module`.
- torch::OrderedDict> children_;
+  torch::OrderedDict<std::string, Child> children_;
/// The module's name (e.g. "LSTM").
mutable std::optional name_;
@@ -248,43 +282,45 @@ const ModuleType* Module::as() const noexcept {
template
std::shared_ptr Module::register_module(
std::string name,
- std::shared_ptr module) {
+    std::shared_ptr<ModuleType> module,
+ const StateDictSelector& selector) {
TORCH_CHECK(!name.empty(), "Submodule name must not be empty");
- TORCH_CHECK(name.find('.') == std::string::npos,
- "Submodule name must not contain a dot (got '",
- name,
- "')");
- auto& base_module = children_.insert(std::move(name), std::move(module));
+ Child child{.module = std::move(module), .selector = selector};
+ auto& base_module =
+ children_.insert(std::move(name), std::move(child)).module;
return std::dynamic_pointer_cast(base_module);
}
template
std::shared_ptr Module::register_module(
std::string name,
- ModuleHolder module_holder) {
- return register_module(std::move(name), module_holder.ptr());
+    std::shared_ptr<ModuleType> module) {
+ auto default_selector = [](const StateDict& sd, const std::string& key) {
+ return sd.select(key + ".");
+ };
+ return register_module(std::move(name), std::move(module), default_selector);
}
template
-std::shared_ptr Module::replace_module(
- const std::string& name,
- std::shared_ptr module) {
- auto& base_module = (children_[name] = std::move(module));
- return std::dynamic_pointer_cast(base_module);
+std::shared_ptr<ModuleType> Module::register_module(
+    std::string name,
+    ModuleHolder<ModuleType> module_holder) {
+ return register_module(std::move(name), module_holder.ptr());
}
template
-std::shared_ptr Module::replace_module(
- const std::string& name,
- ModuleHolder module_holder) {
- return replace_module(name, module_holder.ptr());
+std::shared_ptr<ModuleType> Module::register_module(
+    std::string name,
+    ModuleHolder<ModuleType> module_holder,
+    const StateDictSelector& selector) {
+ return register_module(std::move(name), module_holder.ptr(), selector);
}
template
void Module::to_impl(Ts&&... ts) {
// First call `to()` on every child module.
for (auto& child : children_) {
- child.value()->to(ts...);
+ child.value().module->to(ts...);
}
// Then move every parameter to the new dtype/device.
for (auto& parameter : named_parameters(/*recurse=*/false)) {
diff --git a/src/module/module_list.h b/src/module/module_list.h
index 3ae3cb45..ce0bfee6 100644
--- a/src/module/module_list.h
+++ b/src/module/module_list.h
@@ -1,9 +1,5 @@
#pragma once
-#include
-#include
-#include
-
#include
#include
@@ -11,6 +7,16 @@
#include "module_holder.h"
namespace llm {
+namespace detail {
+/// A type trait whose `value` member is true if `M` derives from `Module`.
+template <typename M>
+using is_module = std::is_base_of<Module, std::decay_t<M>>;
+
+template <typename M, typename T = void>
+using enable_if_module_t = std::enable_if_t<is_module<M>::value, T>;
+
+} // namespace detail
+
/// A list of `Module`s that registers its elements.
class ModuleListImpl : public Module {
public:
@@ -40,7 +46,7 @@ class ModuleListImpl : public Module {
/// Adds a new `Module` to the `ModuleList` container, moving or copying
/// it into a `shared_ptr` internally. This method allows passing value types,
/// and letting the container deal with the boxing.
- template >
+  template <typename M, typename = detail::enable_if_module_t<M>>
void push_back(M&& module) {
using Type = std::remove_reference_t;
push_back(std::make_shared(std::forward(module)));
@@ -78,7 +84,7 @@ class ModuleListImpl : public Module {
/// match.
template
T& at(size_t index) {
- static_assert(torch::detail::is_module::value,
+    static_assert(detail::is_module<T>::value,
"Can only call ModuleList::at with an nn::Module type");
TORCH_CHECK(index < size(), "Index out of range");
auto module = modules_[index]->as();
@@ -95,7 +101,7 @@ class ModuleListImpl : public Module {
/// match.
template
const T& at(size_t index) const {
- static_assert(torch::detail::is_module::value,
+    static_assert(detail::is_module<T>::value,
"Can only call ModuleList::at with an nn::Module type");
TORCH_CHECK(index < size(), "Index out of range");
const auto module = modules_[index]->as();
@@ -120,7 +126,7 @@ class ModuleListImpl : public Module {
/// match.
template
std::shared_ptr ptr(size_t index) const {
- static_assert(torch::detail::is_module::value,
+    static_assert(detail::is_module<T>::value,
"Can only call ModuleList::ptr with an nn::Module type");
TORCH_CHECK(index < size(), "Index out of range");
return std::dynamic_pointer_cast(modules_[index]);
@@ -138,39 +144,6 @@ class ModuleListImpl : public Module {
/// True if there are no modules in the `ModuleList`.
bool is_empty() const noexcept { return size() == 0; }
- void insert(size_t index, std::shared_ptr module) {
- TORCH_CHECK(index <= size(), "Index out of range");
-
- if (index == size())
- push_back(std::move(module));
- else {
- modules_.insert(modules_.begin() + Iterator::difference_type(index),
- std::move(module));
-
- for (const auto i : c10::irange(index, size() - 1)) {
- (void)i; // Suppress unused variable warning
- replace_module(std::to_string(index), modules_[index]);
- }
- register_module(std::to_string(size() - 1), modules_.back());
- }
- }
-
- /// Unwraps the contained module of a `ModuleHolder` and inserts it in the
- /// `ModuleList`.
- template
- void insert(size_t index, const ModuleHolder& module_holder) {
- insert(index, module_holder.ptr());
- }
-
- /// inserts a new `Module` to the `ModuleList` container, moving or copying
- /// it into a `shared_ptr` internally. This method allows passing value types,
- /// and letting the container deal with the boxing.
- template >
- void insert(size_t index, M&& module) {
- using Type = std::remove_reference_t;
- insert(index, std::make_shared(std::forward(module)));
- }
-
private:
template
void push_back_var(Head&& head, Tail&&... tail) {
diff --git a/src/module/module_test.cpp b/src/module/module_test.cpp
new file mode 100644
index 00000000..475f6264
--- /dev/null
+++ b/src/module/module_test.cpp
@@ -0,0 +1,131 @@
+#include "module.h"
+
+#include <gtest/gtest.h>
+#include <torch/torch.h>
+#include <unordered_map>
+
+namespace llm {
+
+class ParametersModel : public Module {
+ public:
+ ParametersModel(bool bias = true) {
+ // register some parameters
+ weight_ = register_parameter("weight", torch::randn({16, 16}));
+ if (bias) {
+ bias_ = register_parameter("bias", torch::randn({32, 32}));
+ }
+
+ sharded_param_ = register_sharded_parameter("sharded_param",
+ /*dim=*/0,
+ /*rank=*/0,
+ /*world_size=*/2,
+ torch::randn({16, 32}));
+ undefined_param_ = register_parameter("undefined_param", torch::Tensor());
+ }
+
+ torch::Tensor weight_;
+ torch::Tensor bias_;
+
+ torch::Tensor sharded_param_;
+ torch::Tensor undefined_param_;
+};
+
+class SubModel : public Module {
+ public:
+ SubModel() {
+ param1_ = register_parameter("param1", torch::randn({8, 8}));
+ param2_ = register_parameter("param2", torch::randn({8, 16}));
+    x_ = register_module("x", std::make_shared<ParametersModel>(/*bias=*/true));
+    y_ =
+        register_module("y", std::make_shared<ParametersModel>(/*bias=*/false));
+ }
+
+ torch::Tensor param1_;
+ torch::Tensor param2_;
+  std::shared_ptr<ParametersModel> x_;
+  std::shared_ptr<ParametersModel> y_;
+};
+
+class Model : public Module {
+ public:
+ Model() {
+ param1_ = register_parameter("param1", torch::randn({8, 8}));
+ param2_ = register_parameter("param2", torch::randn({8, 16}));
+    sub_model_ = register_module("submodel", std::make_shared<SubModel>());
+ }
+
+ torch::Tensor param1_;
+ torch::Tensor param2_;
+  std::shared_ptr<SubModel> sub_model_;
+};
+
+TEST(ModuleTest, Parameters) {
+  std::unordered_map<std::string, torch::Tensor> dict = {
+ {"weight", torch::randn({16, 16})},
+ {"bias", torch::randn({32, 32})},
+ {"param1", torch::randn({16, 16})},
+ {"param2", torch::randn({16, 16})},
+ {"sharded_param", torch::randn({32, 32})},
+ {"undefined_param", torch::randn({32, 32})}};
+
+ StateDict state_dict(dict);
+ EXPECT_EQ(state_dict.size(), dict.size());
+
+ ParametersModel model;
+ EXPECT_EQ(model.load(state_dict), 3);
+
+ EXPECT_TRUE(torch::equal(model.weight_, dict["weight"]));
+ EXPECT_TRUE(torch::equal(model.bias_, dict["bias"]));
+ EXPECT_TRUE(torch::equal(model.sharded_param_,
+ dict["sharded_param"].chunk(2, /*dim=*/0)[0]));
+ EXPECT_FALSE(model.undefined_param_.defined());
+
+ EXPECT_TRUE(model.verify());
+}
+
+TEST(ModuleTest, SubModules) {
+  std::unordered_map<std::string, torch::Tensor> dict = {
+ {"param1", torch::randn({8, 8})},
+ {"param2", torch::randn({8, 16})},
+ {"submodel.param1", torch::randn({8, 8})},
+ {"submodel.param2", torch::randn({8, 16})},
+ {"submodel.x.weight", torch::randn({16, 16})},
+ {"submodel.x.bias", torch::randn({32, 32})},
+ {"submodel.x.sharded_param", torch::randn({32, 32})},
+ {"submodel.x.undefined_param", torch::randn({32, 32})},
+ {"submodel.y.weight", torch::randn({16, 16})},
+ {"submodel.y.bias", torch::randn({32, 32})},
+ {"submodel.y.sharded_param", torch::randn({32, 32})},
+ {"submodel.y.undefined_param", torch::randn({32, 32})}};
+
+ StateDict state_dict(dict);
+ EXPECT_EQ(state_dict.size(), dict.size());
+
+ Model model;
+ EXPECT_EQ(model.load(state_dict), 9);
+
+ EXPECT_TRUE(torch::equal(model.param1_, dict["param1"]));
+ EXPECT_TRUE(torch::equal(model.param2_, dict["param2"]));
+ // submodel
+ EXPECT_TRUE(torch::equal(model.sub_model_->param1_, dict["submodel.param1"]));
+ EXPECT_TRUE(torch::equal(model.sub_model_->param2_, dict["submodel.param2"]));
+ EXPECT_TRUE(
+ torch::equal(model.sub_model_->x_->weight_, dict["submodel.x.weight"]));
+ EXPECT_TRUE(
+ torch::equal(model.sub_model_->x_->bias_, dict["submodel.x.bias"]));
+ EXPECT_TRUE(
+ torch::equal(model.sub_model_->x_->sharded_param_,
+ dict["submodel.x.sharded_param"].chunk(2, /*dim=*/0)[0]));
+ EXPECT_FALSE(model.sub_model_->x_->undefined_param_.defined());
+ EXPECT_TRUE(
+ torch::equal(model.sub_model_->y_->weight_, dict["submodel.y.weight"]));
+ EXPECT_FALSE(model.sub_model_->y_->bias_.defined());
+ EXPECT_TRUE(
+ torch::equal(model.sub_model_->y_->sharded_param_,
+ dict["submodel.y.sharded_param"].chunk(2, /*dim=*/0)[0]));
+ EXPECT_FALSE(model.sub_model_->y_->undefined_param_.defined());
+
+ EXPECT_TRUE(model.verify());
+}
+
+} // namespace llm
diff --git a/src/quantization/qlinear_gptq_marlin_impl.h b/src/quantization/qlinear_gptq_marlin_impl.h
index dc6b3f40..e996b479 100644
--- a/src/quantization/qlinear_gptq_marlin_impl.h
+++ b/src/quantization/qlinear_gptq_marlin_impl.h
@@ -35,11 +35,6 @@ class ColumnParallelQLinearGPTQMarlinImpl : public ParallelLinearImpl {
void load_state_dict(const StateDict& state_dict,
const std::vector& prefixes) override;
- void pretty_print(std::ostream& stream) const override {
- stream << name() << " qweight=" << qweight_.sizes()
- << " scales=" << scales_.sizes() << " device=" << qweight_.device();
- }
-
private:
// parameter members, must be registered
DEFINE_FUSED_WEIGHT(qweight);
@@ -91,11 +86,6 @@ class RowParallelQLinearGPTQMarlinImpl : public ParallelLinearImpl {
// whether the weight is loaded
void verify_loaded_weights(const std::string& prefix = "") const override;
- void pretty_print(std::ostream& stream) const override {
- stream << name() << " qweight=" << qweight_.sizes()
- << " scales=" << scales_.sizes() << " device=" << qweight_.device();
- }
-
private:
// parameter members, must be registered
DEFINE_WEIGHT(qweight);
diff --git a/src/quantization/qlinear_impl.h b/src/quantization/qlinear_impl.h
index c442a8bd..0269246f 100644
--- a/src/quantization/qlinear_impl.h
+++ b/src/quantization/qlinear_impl.h
@@ -76,12 +76,6 @@ class ColumnParallelQLinearImpl : public ParallelLinearImpl {
void load_state_dict(const StateDict& state_dict,
const std::vector& prefixes) override;
- void pretty_print(std::ostream& stream) const override {
- stream << name() << " qweight=" << qweight_.sizes()
- << " qzeros=" << qzeros_.sizes() << " scales=" << scales_.sizes()
- << " device=" << qweight_.device();
- }
-
private:
// parameter members, must be registered
DEFINE_FUSED_WEIGHT(qweight);
@@ -148,12 +142,6 @@ class RowParallelQLinearImpl : public ParallelLinearImpl {
// whether the weight is loaded
void verify_loaded_weights(const std::string& prefix = "") const override;
- void pretty_print(std::ostream& stream) const override {
- stream << name() << " qweight=" << qweight_.sizes()
- << " qzeros=" << qzeros_.sizes() << " scales=" << scales_.sizes()
- << " device=" << qweight_.device();
- }
-
private:
// parameter members, must be registered
DEFINE_WEIGHT(qweight);