Merged
13 changes: 2 additions & 11 deletions README.md
@@ -32,11 +32,12 @@ ScaleLLM

<div align="left">

[ScaleLLM](#) is a cutting-edge inference system engineered for large language models (LLMs), designed to meet the demands of production environments. It extends its support to a wide range of popular open-source models, including [Llama3.1](https://github.com/meta-llama/llama3), [Gemma2](https://github.com/google-deepmind/gemma), Bloom, GPT-NeoX, and more.
[ScaleLLM](#) is a cutting-edge inference system engineered for large language models (LLMs), designed to meet the demands of production environments. It extends its support to a wide range of popular open-source models, including [Llama3.1](https://github.com/meta-llama/llama3), [Gemma2](https://github.com/google-deepmind/gemma), [Phi](https://huggingface.co/microsoft/phi-2), and more.

ScaleLLM is currently undergoing active development. We are fully committed to consistently enhancing its efficiency while also incorporating additional features. Feel free to explore our [**_Roadmap_**](https://github.com/vectorch-ai/ScaleLLM/issues/84) for more details.

## News:
* [01/2025] - Optimized in-house attention kernels
* [06/2024] - ScaleLLM is now available on [PyPI](https://pypi.org/project/scalellm/). You can install it using `pip install scalellm`.
* [03/2024] - [Advanced features](#advanced-features) support for ✅ [CUDA graph](#cuda-graph), ✅ [prefix cache](#prefix-cache), ✅ [chunked prefill](#chunked-prefill) and ✅ [speculative decoding](#speculative-decoding).
* [11/2023] - [First release](https://github.com/vectorch-ai/ScaleLLM/releases/tag/v0.0.1) with support for popular [open-source models](#supported-models).
@@ -274,21 +275,11 @@ Quantization is a crucial process for reducing the memory footprint of models. S

| Models | Tensor Parallel | Quantization | Chat API | HF models examples |
| :--------: | :-------------: | :----------: | :------: | :---------------------------:|
| Aquila | Yes | Yes | Yes | [BAAI/Aquila-7B](https://huggingface.co/BAAI/Aquila-7B), [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B) |
| Bloom | Yes | Yes | No | [bigscience/bloom](https://huggingface.co/bigscience/bloom) |
| Baichuan | Yes | Yes | Yes | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) |
| ChatGLM4/3 | Yes | Yes | Yes | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b) |
| Gemma2 | Yes | Yes | Yes | [google/gemma-2-2b](https://huggingface.co/google/gemma-2-2b) |
| GPT_j | Yes | Yes | No | [EleutherAI/gpt-j-6b](https://huggingface.co/EleutherAI/gpt-j-6b) |
| GPT_NeoX | Yes | Yes | No | [EleutherAI/gpt-neox-20b](https://huggingface.co/EleutherAI/gpt-neox-20b) |
| GPT2 | Yes | Yes | No | [gpt2](https://huggingface.co/gpt2)|
| InternLM | Yes | Yes | Yes | [internlm/internlm-7b](https://huggingface.co/internlm/internlm-7b) |
| Llama3/2 | Yes | Yes | Yes | [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct), [meta-llama/Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) |
| Mistral | Yes | Yes | Yes | [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) |
| MPT | Yes | Yes | Yes | [mosaicml/mpt-30b](https://huggingface.co/mosaicml/mpt-30b) |
| Phi2 | Yes | Yes | No | [microsoft/phi-2](https://huggingface.co/microsoft/phi-2) |
| Qwen2 | Yes | Yes | Yes | [Qwen/Qwen-72B-Chat](https://huggingface.co/Qwen/Qwen-72B-Chat) |
| Yi | Yes | Yes | Yes |[01-ai/Yi-6B](https://huggingface.co/01-ai/Yi-6B), [01-ai/Yi-34B-Chat-4bits](https://huggingface.co/01-ai/Yi-34B-Chat-4bits), [01-ai/Yi-6B-200K](https://huggingface.co/01-ai/Yi-6B-200K) |

If your model is not included in the supported list, we are more than willing to assist you. Please feel free to create a request for adding a new model on [GitHub Issues](https://github.com/vectorch-ai/ScaleLLM/issues).

87 changes: 14 additions & 73 deletions src/layers/embedding.h
@@ -33,26 +33,6 @@ class EmbeddingImpl : public Module {
return F::embedding(input, weight_);
}

// load the weight from the checkpoint
void load_state_dict(const StateDict& state_dict) {
const auto weight = state_dict.get_tensor("weight");
if (weight.defined()) {
CHECK_EQ(weight_.sizes(), weight.sizes())
<< "weight size mismatch for " << name();
weight_.copy_(weight);
is_loaded_ = true;
}
}

// whether the weight is loaded
void verify_loaded_weights(const std::string& prefix) const {
CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight";
}

void pretty_print(std::ostream& stream) const override {
stream << name() << " " << weight_.sizes() << " " << weight_.device();
}

// return the weight (for testing)
torch::Tensor weight() const { return weight_; }

@@ -73,15 +53,19 @@ class ParallelEmbeddingImpl : public Module {
const ParallelArgs& parallel_args,
const torch::TensorOptions& options)
: parallel_args_(parallel_args) {
const auto rank = parallel_args_.rank();
const auto world_size = parallel_args_.world_size();
CHECK(embedding_dim % world_size == 0)
<< "out_features " << embedding_dim << " not divisible by world_size "
<< world_size;
const int64_t embedding_dim_per_partition = embedding_dim / world_size;

// register the weight parameter
weight_ = register_parameter(
weight_ = register_sharded_parameter(
"weight",
/*dim=*/1,
rank,
world_size,
torch::empty({num_embeddings, embedding_dim_per_partition}, options));
}
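The constructor above shards the embedding weight along `dim=1`, so each rank allocates only its `embedding_dim / world_size` slice of the columns. The slice a rank owns can be sketched as a small standalone helper (a hypothetical reconstruction of what `register_sharded_parameter` computes internally; the name `shard_range` does not appear in ScaleLLM itself):

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Compute the half-open [start, end) slice a given rank owns when a tensor
// dimension of size `dim_size` is partitioned evenly across `world_size`
// ranks. Sketch only; assumes even divisibility, as the CHECK above enforces.
std::pair<int64_t, int64_t> shard_range(int64_t dim_size,
                                        int32_t rank,
                                        int32_t world_size) {
  assert(dim_size % world_size == 0 && "dimension must divide evenly");
  const int64_t per_partition = dim_size / world_size;
  const int64_t start = per_partition * rank;
  return {start, start + per_partition};
}
```

For example, rank 1 of 4 with `embedding_dim = 4096` owns columns `[1024, 2048)`; the loader copies exactly that slice out of the checkpoint tensor.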

@@ -96,30 +80,6 @@ class ParallelEmbeddingImpl : public Module {
return output;
}

// load the weight from the checkpoint
void load_state_dict(const StateDict& state_dict) {
const auto weight = state_dict.get_sharded_tensor(
"weight",
/*dim=*/1,
/*rank=*/parallel_args_.rank(),
/*world_size=*/parallel_args_.world_size());
if (weight.defined()) {
CHECK_EQ(weight_.sizes(), weight.sizes())
<< "weight size mismatch for " << name();
weight_.copy_(weight);
is_loaded_ = true;
}
}

// whether the weight is loaded
void verify_loaded_weights(const std::string& prefix) const {
CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight";
}

void pretty_print(std::ostream& stream) const override {
stream << name() << " " << weight_.sizes() << " " << weight_.device();
}

// return the weight (for testing)
torch::Tensor weight() const { return weight_; }

@@ -143,14 +103,19 @@ class VocabParallelEmbeddingImpl : public Module {
const ParallelArgs& parallel_args,
const torch::TensorOptions& options)
: parallel_args_(parallel_args) {
const int64_t num_embeddings_per_partition =
num_embeddings / parallel_args_.world_size();
start_index_ = num_embeddings_per_partition * parallel_args_.rank();
const auto rank = parallel_args_.rank();
const auto world_size = parallel_args_.world_size();

const int64_t num_embeddings_per_partition = num_embeddings / world_size;
start_index_ = num_embeddings_per_partition * rank;
end_index_ = start_index_ + num_embeddings_per_partition;

// register the weight parameter
weight_ = register_parameter(
weight_ = register_sharded_parameter(
"weight",
/*dim=*/0,
rank,
world_size,
torch::empty({num_embeddings_per_partition, embedding_dim}, options));
}
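Here the vocabulary itself is partitioned (shard `dim=0`), so each rank holds rows for token ids in `[start_index_, end_index_)`. In the forward pass, ids outside that range must be masked and their contribution zeroed so the subsequent `reduce_from_model_parallel_region` all-reduce recovers the full embedding. A minimal standalone sketch of that index mapping (hypothetical; `LocalLookup`/`map_token` are illustrative names, not ScaleLLM API):

```cpp
#include <cstdint>

// Result of remapping a global token id onto this rank's local shard.
struct LocalLookup {
  int64_t local_index;  // row to read on this rank
  bool masked;          // true => contribution zeroed before the all-reduce
};

// Token ids inside [start, end) map to local row indices; out-of-range ids
// are looked up at a dummy row (0) and flagged for zeroing.
LocalLookup map_token(int64_t token_id, int64_t start, int64_t end) {
  if (token_id < start || token_id >= end) {
    return {0, true};
  }
  return {token_id - start, false};
}
```

With 4 ranks and a 16384-entry vocabulary, rank 1 owns ids `[4096, 8192)`: id 5000 maps to local row 904, while id 100 is masked on that rank (and resolved by rank 0 via the all-reduce).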

@@ -174,30 +139,6 @@ class VocabParallelEmbeddingImpl : public Module {
return reduce_from_model_parallel_region(output, parallel_args_);
}

// load the weight from the checkpoint
void load_state_dict(const StateDict& state_dict) {
const auto weight = state_dict.get_sharded_tensor(
"weight",
/*dim=*/0,
/*rank=*/parallel_args_.rank(),
/*world_size=*/parallel_args_.world_size());
if (weight.defined()) {
CHECK_EQ(weight_.sizes(), weight.sizes())
<< "weight size mismatch for " << name();
weight_.copy_(weight);
is_loaded_ = true;
}
}

// whether the weight is loaded
void verify_loaded_weights(const std::string& prefix = "") const {
CHECK(is_loaded_) << "weight is not loaded for " << prefix + "weight";
}

void pretty_print(std::ostream& stream) const override {
stream << name() << " " << weight_.sizes() << " " << weight_.device();
}

// return the weight (for testing)
torch::Tensor weight() const { return weight_; }

23 changes: 13 additions & 10 deletions src/layers/fused_linear.cpp
@@ -13,11 +13,13 @@ namespace llm {
FusedColumnParallelLinearImpl::FusedColumnParallelLinearImpl(
int64_t in_features,
const std::vector<int64_t>& out_features_vec,
const std::vector<std::string>& prefixes,
bool bias,
bool gather_output,
const QuantArgs& quant_args,
const ParallelArgs& parallel_args,
const torch::TensorOptions& options) {
prefixes_ = prefixes;
// check if the linear layers can be fused
fused_ = quant_args.can_be_fused();
if (fused_) {
@@ -72,28 +74,29 @@ std::vector<torch::Tensor> FusedColumnParallelLinearImpl::forward(
return outputs;
}

void FusedColumnParallelLinearImpl::load_state_dict(
const StateDict& state_dict,
const std::vector<std::string>& prefixes) {
size_t FusedColumnParallelLinearImpl::load(const StateDict& state_dict,
const std::string&) {
if (fused_) {
fused_linear_->load_state_dict(state_dict, prefixes);
fused_linear_->load_state_dict(state_dict, prefixes_);
} else {
CHECK_EQ(parallel_linears_.size(), prefixes.size());
CHECK_EQ(parallel_linears_.size(), prefixes_.size());
for (size_t i = 0; i < parallel_linears_.size(); ++i) {
parallel_linears_[i]->load_state_dict(state_dict.select(prefixes[i]));
parallel_linears_[i]->load_state_dict(state_dict.select(prefixes_[i]));
}
}
return 0;
}

void FusedColumnParallelLinearImpl::verify_loaded_weights(
const std::string& prefix) const {
bool FusedColumnParallelLinearImpl::verify(
const std::string& name_prefix) const {
if (fused_) {
fused_linear_->verify_loaded_weights(prefix);
fused_linear_->verify_loaded_weights(name_prefix);
} else {
for (const auto& parallel_linear : parallel_linears_) {
parallel_linear->verify_loaded_weights(prefix);
parallel_linear->verify_loaded_weights(name_prefix);
}
}
return true;
}

} // namespace llm
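The fused path concatenates several projections (e.g. q/k/v) into one GEMM and later splits the output at `split_sizes_` boundaries. Because the layer is column-parallel, each logical projection's `out_features` is first divided by `world_size` to get the per-rank split sizes. A small sketch of that computation (hypothetical helper; `per_rank_splits` is an illustrative name, not part of ScaleLLM):

```cpp
#include <cstdint>
#include <vector>

// Derive the per-rank output split sizes for a fused column-parallel linear:
// each projection's out_features shrinks by a factor of world_size, and the
// fused GEMM output is split back at these boundaries.
// Assumes every entry divides evenly by world_size.
std::vector<int64_t> per_rank_splits(const std::vector<int64_t>& out_features,
                                     int32_t world_size) {
  std::vector<int64_t> splits;
  splits.reserve(out_features.size());
  for (int64_t f : out_features) {
    splits.push_back(f / world_size);
  }
  return splits;
}
```

For instance, fusing q/k/v projections of sizes `{4096, 1024, 1024}` across 4 ranks yields local splits `{1024, 256, 256}`, so each rank's fused output tensor is 1536 columns wide before splitting.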
13 changes: 9 additions & 4 deletions src/layers/fused_linear.h
@@ -16,6 +16,7 @@ class FusedColumnParallelLinearImpl : public Module {
public:
FusedColumnParallelLinearImpl(int64_t in_features,
const std::vector<int64_t>& out_features,
const std::vector<std::string>& prefixes,
bool bias,
bool gather_output,
const QuantArgs& quant_args,
@@ -24,11 +25,13 @@ class FusedColumnParallelLinearImpl : public Module {

std::vector<torch::Tensor> forward(torch::Tensor input);

// load_state_dict for fused weights
void load_state_dict(const StateDict& state_dict,
const std::vector<std::string>& prefixes);
// load weights from the checkpoint, override this method if necessary
// returns the number of loaded parameters
size_t load(const StateDict& state_dict,
const std::string& name_prefix = std::string()) override;

void verify_loaded_weights(const std::string& prefix = "") const;
// verify whether the weights are loaded, override this method if necessary
bool verify(const std::string& name_prefix = std::string()) const override;

// whether the linear layer is fused
bool fused() const { return fused_; }
@@ -43,6 +46,8 @@ class FusedColumnParallelLinearImpl : public Module {
// sizes for each split
std::vector<int64_t> split_sizes_;

std::vector<std::string> prefixes_;

// whether the linear layer is fused
bool fused_ = false;
};
20 changes: 16 additions & 4 deletions src/layers/linear_impl.cpp
@@ -18,6 +18,7 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl(
const ParallelArgs& parallel_args,
const torch::TensorOptions& options)
: gather_output_(gather_output), parallel_args_(parallel_args) {
const auto rank = parallel_args_.rank();
const auto world_size = parallel_args_.world_size();
CHECK(out_features % world_size == 0)
<< "out_features " << out_features << " not divisible by world_size "
@@ -26,13 +27,20 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl(

// Note: torch.nn.functional.linear performs XA^T + b and as a result
// we allocate the transpose.
weight_ = register_parameter(
weight_ = register_sharded_parameter(
"weight",
/*dim=*/0,
rank,
world_size,
torch::empty({out_features_per_partition, in_features}, options));

if (bias) {
bias_ = register_parameter(
"bias", torch::empty({out_features_per_partition}, options));
bias_ = register_sharded_parameter(
"bias",
/*dim=*/0,
rank,
world_size,
torch::empty({out_features_per_partition}, options));
}
}

@@ -93,14 +101,18 @@ RowParallelLinearImpl::RowParallelLinearImpl(
const torch::TensorOptions& options)
: input_is_parallelized_(input_is_parallelized),
parallel_args_(parallel_args) {
const auto rank = parallel_args_.rank();
const auto world_size = parallel_args_.world_size();
CHECK(in_features % world_size == 0)
<< "in_features " << in_features << " not divisible by world_size "
<< world_size;
const int64_t in_features_per_partition = in_features / world_size;
// Allocate the transpose since linear performs XA^T.
weight_ = register_parameter(
weight_ = register_sharded_parameter(
"weight",
/*dim=*/1,
rank,
world_size,
torch::empty({out_features, in_features_per_partition}, options));

if (bias) {
8 changes: 0 additions & 8 deletions src/layers/linear_impl.h
@@ -42,10 +42,6 @@ class ColumnParallelLinearImpl : public ParallelLinearImpl {
<< "bias is not loaded for " << prefix + "bias";
}

void pretty_print(std::ostream& stream) const override {
stream << name() << " " << weight_.sizes() << " " << weight_.device();
}

// return the weight (for testing)
torch::Tensor weight() const { return weight_; }

@@ -95,10 +91,6 @@ class RowParallelLinearImpl : public ParallelLinearImpl {
<< "bias is not loaded for " << prefix + "bias";
}

void pretty_print(std::ostream& stream) const override {
stream << name() << " " << weight_.sizes() << " " << weight_.device();
}

// return the weight (for testing)
torch::Tensor weight() const { return weight_; }
