From 16f2e895db62a0253498cfb54f6febb268782156 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 7 Oct 2025 22:45:25 -0700 Subject: [PATCH 1/2] feat: clean up load_state_dict for quant linear --- src/engine/llm_engine.cpp | 4 +- src/engine/worker.cpp | 13 +- src/engine/worker.h | 7 +- src/engine/worker_test.cpp | 4 +- src/layers/linear/parallel_linear.cpp | 16 +- src/layers/linear/parallel_linear.h | 15 -- .../linear/qkv_parallel_linear_test.cpp | 2 +- src/layers/quantization/CMakeLists.txt | 13 +- ...{qlinear_impl.cpp => parallel_qlinear.cpp} | 143 +++++------------- .../{qlinear_impl.h => parallel_qlinear.h} | 16 -- ...ptq_impl.cpp => parallel_qlinear_gptq.cpp} | 2 +- ...ar_gptq_impl.h => parallel_qlinear_gptq.h} | 2 +- ...est.cpp => parallel_qlinear_gptq_test.cpp} | 21 +-- .../quantization/parallel_qlinear_test.cpp | 115 ++++++++++++++ src/layers/quantization/qlinear_awq_impl.h | 2 +- .../quantization/qlinear_awq_marlin_impl.cpp | 74 --------- .../quantization/qlinear_awq_marlin_impl.h | 16 -- .../quantization/qlinear_exllamav2_impl.h | 2 +- .../quantization/qlinear_gptq_marlin_impl.cpp | 87 ----------- .../quantization/qlinear_gptq_marlin_impl.h | 16 -- src/models/causal_lm.h | 10 +- 21 files changed, 198 insertions(+), 382 deletions(-) rename src/layers/quantization/{qlinear_impl.cpp => parallel_qlinear.cpp} (66%) rename src/layers/quantization/{qlinear_impl.h => parallel_qlinear.h} (89%) rename src/layers/quantization/{qlinear_gptq_impl.cpp => parallel_qlinear_gptq.cpp} (99%) rename src/layers/quantization/{qlinear_gptq_impl.h => parallel_qlinear_gptq.h} (99%) rename src/layers/quantization/{qlinear_impl_test.cpp => parallel_qlinear_gptq_test.cpp} (77%) create mode 100644 src/layers/quantization/parallel_qlinear_test.cpp diff --git a/src/engine/llm_engine.cpp b/src/engine/llm_engine.cpp index 15642685..97e9796b 100644 --- a/src/engine/llm_engine.cpp +++ b/src/engine/llm_engine.cpp @@ -193,7 +193,7 @@ bool LLMEngine::init_model(const std::string& model_weights_path) { std::vector> futures; futures.reserve(workers_.size()); for (auto& worker : workers_) { - futures.push_back(worker->load_state_dict_async(state_dict)); + futures.push_back(worker->load_async(state_dict)); } // wait for all futures to complete auto results = folly::collectAll(futures).get(); @@ -206,7 +206,7 @@ bool LLMEngine::init_model(const std::string& model_weights_path) { // verify the weights are loaded correctly for (const auto& worker : workers_) { - worker->verify_loaded_weights(); + worker->verify(); } return true; } diff --git a/src/engine/worker.cpp b/src/engine/worker.cpp index 72c3d83f..0ddec1a8 100644 --- a/src/engine/worker.cpp +++ b/src/engine/worker.cpp @@ -88,14 +88,14 @@ void Worker::capture_cuda_graph(uint32_t batch_size) { return model_runner_->capture_cuda_graphs(batch_size, kv_caches_); } -void Worker::load_state_dict(const StateDict& state_dict) { +void Worker::load(const StateDict& state_dict) { CHECK(model_ != nullptr) << "Model is not initialized."; - model_->load_state_dict(state_dict); + model_->load(state_dict); } -void Worker::verify_loaded_weights() const { +void Worker::verify() const { CHECK(model_ != nullptr) << "Model is not initialized."; - model_->verify_loaded_weights(); + model_->verify(); } std::tuple Worker::profile_device_memory() { @@ -270,14 +270,13 @@ folly::SemiFuture Worker::capture_cuda_graph_async( return future; } -folly::SemiFuture Worker::load_state_dict_async( - const StateDict& state_dict) { +folly::SemiFuture Worker::load_async(const StateDict& state_dict) { folly::Promise promise; auto future = promise.getSemiFuture(); threadpool_.schedule( [this, &state_dict, promise = std::move(promise)]() mutable { // load the model weights from state_dict within the working thread - this->load_state_dict(state_dict); + this->load(state_dict); promise.setValue(); }); return future; diff --git a/src/engine/worker.h b/src/engine/worker.h index 9a929b2b..798a3f2c 100644 --- a/src/engine/worker.h +++ b/src/engine/worker.h @@ -30,10 +30,10 @@ class Worker final { // Load the model weights from state_dict. blocking call // can be called multiple times to reload the model with different parameters - void load_state_dict(const StateDict& state_dict); + void load(const StateDict& state_dict); // verify if the model is loaded correctly - void verify_loaded_weights() const; + void verify() const; // returns available memory and total memory std::tuple profile_device_memory(); @@ -57,8 +57,7 @@ class Worker final { // Load the model weights from state_dict. async call // the future returns a successfull status with no meaningful value - folly::SemiFuture load_state_dict_async( - const StateDict& state_dict); + folly::SemiFuture load_async(const StateDict& state_dict); folly::SemiFuture> profile_device_memory_async(); diff --git a/src/engine/worker_test.cpp b/src/engine/worker_test.cpp index 4df260cb..b38fe873 100644 --- a/src/engine/worker_test.cpp +++ b/src/engine/worker_test.cpp @@ -45,8 +45,8 @@ class TestableWorker { } auto state_dict = create_state_dict(); - worker_->load_state_dict(state_dict); - worker_->verify_loaded_weights(); + worker_->load(state_dict); + worker_->verify(); return true; } diff --git a/src/layers/linear/parallel_linear.cpp b/src/layers/linear/parallel_linear.cpp index ed591457..c44bdcf2 100644 --- a/src/layers/linear/parallel_linear.cpp +++ b/src/layers/linear/parallel_linear.cpp @@ -8,10 +8,10 @@ #include #include "layers/module/module.h" +#include "layers/quantization/parallel_qlinear_gptq.h" #include "layers/quantization/qlinear_awq_impl.h" #include "layers/quantization/qlinear_awq_marlin_impl.h" #include "layers/quantization/qlinear_exllamav2_impl.h" -#include "layers/quantization/qlinear_gptq_impl.h" #include "layers/quantization/qlinear_gptq_marlin_impl.h" #include "model_parallel/model_parallel.h" @@ -182,13 +182,13 @@ std::shared_ptr create_column_parallel_linear( parallel_args, options); } - return std ::make_shared(in_features, - out_features, - bias, - gather_output, - parallel_args, - options, - prefix); + return std::make_shared(in_features, + out_features, + bias, + gather_output, + parallel_args, + options, + prefix); } std::shared_ptr create_row_parallel_linear( diff --git a/src/layers/linear/parallel_linear.h b/src/layers/linear/parallel_linear.h index d1ebee68..a0fd3f2e 100644 --- a/src/layers/linear/parallel_linear.h +++ b/src/layers/linear/parallel_linear.h @@ -19,21 +19,6 @@ class ParallelLinearImpl : public Module { ~ParallelLinearImpl() override = default; virtual torch::Tensor forward(torch::Tensor input) = 0; - - // TODO: clean up the interface of load_state_dict - virtual void load_state_dict(const StateDict& state_dict) { - LOG(FATAL) << "not implemented"; - } - - virtual void verify_loaded_weights(const std::string& prefix = "") const { - LOG(FATAL) << "not implemented"; - } - - // special load_state_dict for fused cases - virtual void load_state_dict(const StateDict& /*state_dict*/, - const std::vector& /*prefixes*/) { - LOG(FATAL) << "not implemented"; - } }; // Linear layer with column parallelism. diff --git a/src/layers/linear/qkv_parallel_linear_test.cpp b/src/layers/linear/qkv_parallel_linear_test.cpp index a45b1eb5..e3aa53a7 100644 --- a/src/layers/linear/qkv_parallel_linear_test.cpp +++ b/src/layers/linear/qkv_parallel_linear_test.cpp @@ -62,7 +62,7 @@ TEST_P(QKVColumnParallelLinearTest, LoadFusedWeight) { quant_args, parallel_args, options); - linear.load(state_dict); + EXPECT_EQ(linear.load(state_dict), 3); EXPECT_TRUE(linear.verify()); // generate random input and compare with the output diff --git a/src/layers/quantization/CMakeLists.txt b/src/layers/quantization/CMakeLists.txt index 94761865..6af1d5f7 100644 --- a/src/layers/quantization/CMakeLists.txt +++ b/src/layers/quantization/CMakeLists.txt @@ -6,16 +6,16 @@ cc_library( quantization HDRS pack_utils.h - qlinear_impl.h - qlinear_gptq_impl.h + parallel_qlinear.h + parallel_qlinear_gptq.h qlinear_exllamav2_impl.h qlinear_awq_impl.h qlinear_gptq_marlin_impl.h qlinear_awq_marlin_impl.h SRCS pack_utils.cpp - qlinear_impl.cpp - qlinear_gptq_impl.cpp + parallel_qlinear.cpp + parallel_qlinear_gptq.cpp qlinear_exllamav2_impl.cpp qlinear_awq_impl.cpp qlinear_gptq_marlin_impl.cpp @@ -34,10 +34,11 @@ cc_library( cc_test( NAME - quantization_test + qlinear_test SRCS pack_utils_test.cpp - qlinear_impl_test.cpp + parallel_qlinear_test.cpp + parallel_qlinear_gptq_test.cpp DEPS :quantization :state_dict diff --git a/src/layers/quantization/qlinear_impl.cpp b/src/layers/quantization/parallel_qlinear.cpp similarity index 66% rename from src/layers/quantization/qlinear_impl.cpp rename to src/layers/quantization/parallel_qlinear.cpp index bb41ec04..17aa668d 100644 --- a/src/layers/quantization/qlinear_impl.cpp +++ b/src/layers/quantization/parallel_qlinear.cpp @@ -1,4 +1,4 @@ -#include "qlinear_impl.h" +#include "parallel_qlinear.h" #include #include @@ -18,13 +18,20 @@ namespace detail { // construct weights matrix for gptq from quantized weights // return the weights matrix [in_features, out_features] with following formula: // weights = scales * (qweights - qzeros) +// pack_factor = 32 / bits +// n_in_ints = in_features / pack_factor +// n_out_ints = out_features / pack_factor +// n_groups = ceil(in_features / group_size) torch::Tensor construct_weights( - const torch::Tensor& qweights, // [n_ints, out_features] IntTensor - const torch::Tensor& qzeros, // [n_groups, n_ints] IntTensor + const torch::Tensor& qweights, // [n_in_ints, out_features] IntTensor + const torch::Tensor& qzeros, // [n_groups, n_out_ints] IntTensor const torch::Tensor& scales, // [n_groups, out_features] HalfTensor const torch::Tensor& g_idx, // [in_features] IntTensor int64_t bits) { CHECK(bits == 2 || bits == 4 || bits == 8) << "Only 2,4,8 bits are supported"; + const int64_t pack_factor = 32 / bits; + const int64_t n_groups = scales.size(0); + const int64_t out_features = scales.size(1); std::vector bits_to_shift; for (int32_t i = 0; i < 32; i += bits) { @@ -35,9 +42,9 @@ torch::Tensor construct_weights( torch::tensor(bits_to_shift, qweights.options()).unsqueeze(0); const auto dtype = (bits == 8) ? torch::kInt16 : torch::kInt8; const uint16_t mask = static_cast(std::pow(2, bits) - 1); - // [n_groups, out_features/n_bits, n_ints] + // [n_groups, out_features/pack_factor, pack_factor] auto zeros = torch::bitwise_right_shift( - qzeros.unsqueeze(2).expand({-1, -1, 32 / bits}), + qzeros.unsqueeze(2).expand({-1, -1, pack_factor}), shift_bits.unsqueeze(0)) .to(dtype); zeros.bitwise_and_(mask); @@ -45,14 +52,17 @@ torch::Tensor construct_weights( // [n_groups, out_features] zeros = zeros.reshape(scales.sizes()); + // [in_features/pack_factor, pack_factor, out_features] auto weights = torch::bitwise_right_shift( - qweights.unsqueeze(1).expand({-1, 32 / bits, -1}), + qweights.unsqueeze(1).expand({-1, pack_factor, -1}), shift_bits.unsqueeze(-1)) .to(dtype); weights.bitwise_and_(mask); - weights = weights.reshape({-1, qweights.size(1)}); + // [in_features, out_features] + weights = weights.reshape({-1, out_features}); // auto gathered_scales = scales.gather(/*dim=*/0, /*index=*/g_idx); // auto gathered_zeros = zeros.gather(/*dim=*/0, /*index=*/g_idx); + // return gathered_scales * (weights - gathered_zeros); return scales.index({g_idx}) * (weights - zeros.index({g_idx})); } @@ -61,11 +71,14 @@ torch::Tensor construct_weights( // return the weights matrix [in_features, out_features] with following formula: // weights = scales * (qweights - qzeros) torch::Tensor construct_weights( - const torch::Tensor& qweights, // [n_ints, out_features] IntTensor - const torch::Tensor& qzeros, // [n_groups, n_ints] IntTensor + const torch::Tensor& qweights, // [n_in_ints, out_features] IntTensor + const torch::Tensor& qzeros, // [n_groups, n_out_ints] IntTensor const torch::Tensor& scales, // [n_groups, out_features] HalfTensor int64_t bits) { CHECK(bits == 2 || bits == 4 || bits == 8) << "Only 2,4,8 bits are supported"; + const int64_t pack_factor = 32 / bits; + const int64_t n_groups = scales.size(0); + const int64_t out_features = scales.size(1); std::vector bits_to_shift; for (int32_t i = 0; i < 32; i += bits) { @@ -77,25 +90,28 @@ torch::Tensor construct_weights( torch::tensor(bits_to_shift, qweights.options()).unsqueeze(0); const auto dtype = (bits == 8) ? torch::kInt16 : torch::kInt8; const uint16_t mask = static_cast(std::pow(2, bits) - 1); - // [n_groups, out_features/n_bits, n_ints] + // [n_groups, out_features/pack_factor, pack_factor] auto zeros = torch::bitwise_right_shift( - qzeros.unsqueeze(2).expand({-1, -1, 32 / bits}), + qzeros.unsqueeze(2).expand({-1, -1, pack_factor}), shift_bits.unsqueeze(0)) .to(dtype); zeros.bitwise_and_(mask); zeros.add_(1); // [n_groups, 1, out_features] - zeros = zeros.reshape({scales.size(0), 1, scales.size(1)}); + zeros = zeros.reshape({n_groups, 1, out_features}); + // [in_features/pack_factor, pack_factor, out_features] auto weights = torch::bitwise_right_shift( - qweights.unsqueeze(1).expand({-1, 32 / bits, -1}), + qweights.unsqueeze(1).expand({-1, pack_factor, -1}), shift_bits.unsqueeze(-1)) .to(dtype); weights.bitwise_and_(mask); - // [n_groups, group_size, out_features] - weights = weights.reshape({scales.size(0), -1, scales.size(1)}); + // [in_features, out_features] => [n_groups, group_size, out_features] + weights = weights.reshape({n_groups, -1, out_features}); + // [n_groups, 1, out_features] * [n_groups, group_size, out_features] weights = scales.unsqueeze(1) * (weights - zeros); - return weights.reshape({-1, scales.size(1)}); + // [n_groups, group_size, out_features] => [in_features, out_features] + return weights.reshape({-1, out_features}); } } // namespace detail @@ -141,13 +157,13 @@ ColumnParallelQLinearImpl::ColumnParallelQLinearImpl( torch::empty({in_features, out_features_per_partition / pack_factor}, options.dtype(torch::kInt32))); } + const int64_t n_groups = round_up(in_features, group_size); qzeros_ = register_sharded_parameter( "qzeros", /*dim=*/1, rank, world_size, - torch::empty({round_up(in_features, group_size), - out_features_per_partition / pack_factor}, + torch::empty({n_groups, out_features_per_partition / pack_factor}, options.dtype(torch::kInt32))); scales_ = register_sharded_parameter( @@ -155,9 +171,8 @@ ColumnParallelQLinearImpl::ColumnParallelQLinearImpl( /*dim=*/1, rank, world_size, - torch::empty( - {round_up(in_features, group_size), out_features_per_partition}, - options)); + torch::empty({n_groups, out_features_per_partition}, options)); + if (bias) { bias_ = register_sharded_parameter( "bias", @@ -174,58 +189,11 @@ torch::Tensor ColumnParallelQLinearImpl::quant_matmul( const torch::Tensor& qzeros, const torch::Tensor& scales) const { const int64_t out_features = qweight.size(-1); - torch::Tensor output = - torch::zeros({input.size(0), out_features}, input.options()); + // scales * (qweights - qzeros): [in_features, out_features] const auto weights = detail::construct_weights(qweight, qzeros, scales, bits_); - torch::matmul_out(/*out=*/output, /*self=*/input, /*other=*/weights); - return output; -} - -// load the weight from the checkpoint -void ColumnParallelQLinearImpl::load_state_dict(const StateDict& state_dict) { - const auto rank = parallel_args_.rank(); - const auto world_size = parallel_args_.world_size(); - - // load sharded weights on dim 1 - LOAD_SHARDED_WEIGHT(qweight, 1); - LOAD_SHARDED_WEIGHT(qzeros, 1); - LOAD_SHARDED_WEIGHT(scales, 1); - - // load bias if defined - if (bias_.defined()) { - // load sharded bias on dim 0 - LOAD_SHARDED_WEIGHT(bias, 0); - } -} - -// special load_state_dict for fused cases -void ColumnParallelQLinearImpl::load_state_dict( - const StateDict& state_dict, - const std::vector& prefixes) { - const auto rank = parallel_args_.rank(); - const auto world_size = parallel_args_.world_size(); - - // load and merge weights on dim 1 - LOAD_FUSED_WEIGHT(qweight, 1); - LOAD_FUSED_WEIGHT(qzeros, 1); - LOAD_FUSED_WEIGHT(scales, 1); - - // load bias if defined - if (bias_.defined()) { - // load and merge bias on dim 0 - LOAD_FUSED_WEIGHT(bias, 0); - } -} - -void ColumnParallelQLinearImpl::verify_loaded_weights( - const std::string& prefix) const { - CHECK(qweight_is_loaded_) - << "qweight is not loaded for " << prefix + "qweight"; - CHECK(qzeros_is_loaded_) << "qzeros is not loaded for " << prefix + "qzeros"; - CHECK(scales_is_loaded_) << "scales is not loaded for " << prefix + "scales"; - CHECK(!bias_.defined() || bias_is_loaded_) - << "bias is not loaded for " << prefix + "bias"; + // output: [batch, out_features] + return torch::matmul(input, weights); } RowParallelQLinearImpl::RowParallelQLinearImpl( @@ -299,38 +267,11 @@ torch::Tensor RowParallelQLinearImpl::quant_matmul( const torch::Tensor& qzeros, const torch::Tensor& scales) const { const int64_t out_features = qweight.size(-1); - torch::Tensor output = - torch::zeros({input.size(0), out_features}, input.options()); + // scales * (qweights - qzeros): [in_features, out_features] const auto weights = detail::construct_weights(qweight, qzeros, scales, bits_); - torch::matmul_out(/*out=*/output, /*self=*/input, /*other=*/weights); - return output; -} - -// load the weight from the checkpoint -void RowParallelQLinearImpl::load_state_dict(const StateDict& state_dict) { - const auto rank = parallel_args_.rank(); - const auto world_size = parallel_args_.world_size(); - - // load sharded weights on dim 0 - LOAD_SHARDED_WEIGHT(qweight, 0); - LOAD_SHARDED_WEIGHT(qzeros, 0); - LOAD_SHARDED_WEIGHT(scales, 0); - - if (bias_.defined()) { - // load bias - LOAD_WEIGHT(bias); - } -} - -void RowParallelQLinearImpl::verify_loaded_weights( - const std::string& prefix) const { - CHECK(qweight_is_loaded_) - << "qweight is not loaded for " << prefix + "qweight"; - CHECK(qzeros_is_loaded_) << "qzeros is not loaded for " << prefix + "qzeros"; - CHECK(scales_is_loaded_) << "scales is not loaded for " << prefix + "scales"; - CHECK(!bias_.defined() || bias_is_loaded_) - << "bias is not loaded for " << prefix + "bias"; + // output: [batch, out_features] + return torch::matmul(input, weights); } } // namespace llm diff --git a/src/layers/quantization/qlinear_impl.h b/src/layers/quantization/parallel_qlinear.h similarity index 89% rename from src/layers/quantization/qlinear_impl.h rename to src/layers/quantization/parallel_qlinear.h index d02e2186..4b1e1fb2 100644 --- a/src/layers/quantization/qlinear_impl.h +++ b/src/layers/quantization/parallel_qlinear.h @@ -48,9 +48,6 @@ class ColumnParallelQLinearImpl : public ParallelLinearImpl { const ParallelArgs& parallel_args, const torch::TensorOptions& options); - // verify if the weight is loaded correctly - void verify_loaded_weights(const std::string& prefix = "") const override; - // all subclasses must implement this function virtual torch::Tensor quant_matmul(const torch::Tensor& input, const torch::Tensor& qweight, @@ -68,13 +65,6 @@ class ColumnParallelQLinearImpl : public ParallelLinearImpl { return output; } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) override; - - // special load_state_dict for fused cases - void load_state_dict(const StateDict& state_dict, - const std::vector& prefixes) override; - private: // parameter members, must be registered DEFINE_FUSED_WEIGHT(qweight); @@ -135,12 +125,6 @@ class RowParallelQLinearImpl : public ParallelLinearImpl { return output; } - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) override; - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix = "") const override; - private: // parameter members, must be registered DEFINE_WEIGHT(qweight); diff --git a/src/layers/quantization/qlinear_gptq_impl.cpp b/src/layers/quantization/parallel_qlinear_gptq.cpp similarity index 99% rename from src/layers/quantization/qlinear_gptq_impl.cpp rename to src/layers/quantization/parallel_qlinear_gptq.cpp index 56a0f819..ef0bd2e3 100644 --- a/src/layers/quantization/qlinear_gptq_impl.cpp +++ b/src/layers/quantization/parallel_qlinear_gptq.cpp @@ -1,4 +1,4 @@ -#include "qlinear_gptq_impl.h" +#include "parallel_qlinear_gptq.h" #include #include diff --git a/src/layers/quantization/qlinear_gptq_impl.h b/src/layers/quantization/parallel_qlinear_gptq.h similarity index 99% rename from src/layers/quantization/qlinear_gptq_impl.h rename to src/layers/quantization/parallel_qlinear_gptq.h index d1e85835..af76947c 100644 --- a/src/layers/quantization/qlinear_gptq_impl.h +++ b/src/layers/quantization/parallel_qlinear_gptq.h @@ -5,7 +5,7 @@ #include "model_loader/state_dict.h" #include "models/model_args.h" -#include "qlinear_impl.h" +#include "parallel_qlinear.h" namespace llm { diff --git a/src/layers/quantization/qlinear_impl_test.cpp b/src/layers/quantization/parallel_qlinear_gptq_test.cpp similarity index 77% rename from src/layers/quantization/qlinear_impl_test.cpp rename to src/layers/quantization/parallel_qlinear_gptq_test.cpp index 26291ab4..18287a45 100644 --- a/src/layers/quantization/qlinear_impl_test.cpp +++ b/src/layers/quantization/parallel_qlinear_gptq_test.cpp @@ -1,27 +1,14 @@ +#include "parallel_qlinear_gptq.h" + #include #include #include #include "model_loader/state_dict.h" -#include "qlinear_gptq_impl.h" namespace llm { -TEST(QlinearTest, Basic) { - auto state_dict = StateDict::load_safetensors("data/gptq_small.safetensors"); - auto weights = detail::construct_weights(state_dict->get_tensor("qweight"), - state_dict->get_tensor("qzeros"), - state_dict->get_tensor("scales"), - /*bits=*/4); - auto weights_2 = detail::construct_weights(state_dict->get_tensor("qweight"), - state_dict->get_tensor("qzeros"), - state_dict->get_tensor("scales"), - state_dict->get_tensor("g_idx"), - /*bits=*/4); - EXPECT_TRUE(torch::allclose(weights, weights_2)); -} - -TEST(QlinearTest, ColumnParallelQuantLinear) { +TEST(GPTQQlinearTest, ColumnParallelQLinear) { if (!torch::cuda::is_available()) { GTEST_SKIP() << "CUDA not available, skipping test"; } @@ -58,7 +45,7 @@ TEST(QlinearTest, ColumnParallelQuantLinear) { /*atol=*/1e-02)); } -TEST(QlinearTest, RowParallelQuantLinear) { +TEST(GPTQQlinearTest, RowParallelQLinear) { if (!torch::cuda::is_available()) { GTEST_SKIP() << "CUDA not available, skipping test"; } diff --git a/src/layers/quantization/parallel_qlinear_test.cpp b/src/layers/quantization/parallel_qlinear_test.cpp new file mode 100644 index 00000000..5d53a2cc --- /dev/null +++ b/src/layers/quantization/parallel_qlinear_test.cpp @@ -0,0 +1,115 @@ +#include "parallel_qlinear.h" + +#include +#include +#include + +#include "model_loader/state_dict.h" + +namespace llm { +namespace { +// TODO: create a quantized state dict for testing +// std::shared_ptr create_quant_state_dict(int64_t in_features, +// int64_t out_features, +// int64_t bits, +// int64_t group_size) { +// const auto options = torch::dtype(torch::kFloat32).device(torch::kCPU); + +// // create random weight +// auto weight = torch::rand({out_features, in_features}, options); + +// return std::make_unique(); +// } +} // namespace + +TEST(QlinearTest, Basic) { + auto state_dict = StateDict::load_safetensors("data/gptq_small.safetensors"); + auto weights = detail::construct_weights(state_dict->get_tensor("qweight"), + state_dict->get_tensor("qzeros"), + state_dict->get_tensor("scales"), + /*bits=*/4); + auto weights_2 = detail::construct_weights(state_dict->get_tensor("qweight"), + state_dict->get_tensor("qzeros"), + state_dict->get_tensor("scales"), + state_dict->get_tensor("g_idx"), + /*bits=*/4); + EXPECT_TRUE(torch::allclose(weights, weights_2)); +} + +TEST(QlinearTest, ColumnParallelQLinear) { + if (!torch::cuda::is_available()) { + GTEST_SKIP() << "CUDA not available, skipping test"; + } + + const int64_t in_features = 4096; + const int64_t out_features = 4096; + QuantArgs quant_args; + quant_args.bits(4); + quant_args.group_size(128); + const auto options = torch::dtype(torch::kHalf).device(torch::kCUDA); + ColumnParallelQLinearImpl qlinear(in_features, + out_features, + /*bias=*/false, + quant_args, + /*qweight_pack_dim=*/0, + /*gather_output=*/false, + ParallelArgs(0, 1, nullptr), + options); + auto state_dict = StateDict::load_safetensors("data/gptq.safetensors"); + auto weights = detail::construct_weights(state_dict->get_tensor("qweight"), + state_dict->get_tensor("qzeros"), + state_dict->get_tensor("scales"), + /*bits=*/4); + weights = weights.to(torch::kCUDA); + + qlinear.load(*state_dict); + EXPECT_TRUE(qlinear.verify()); + + auto input = torch::rand({40960, in_features}, options); + auto output = qlinear.forward(input); + auto desired_output = torch::matmul(input, weights); + EXPECT_TRUE(torch::allclose(output, + desired_output, + /*rtol=*/1e-01, + /*atol=*/1e-02)); +} + +TEST(QlinearTest, RowParallelQuantLinear) { + if (!torch::cuda::is_available()) { + GTEST_SKIP() << "CUDA not available, skipping test"; + } + + const int64_t in_features = 4096; + const int64_t out_features = 4096; + QuantArgs quant_args; + quant_args.bits(4); + quant_args.group_size(128); + const auto options = torch::dtype(torch::kHalf).device(torch::kCUDA); + RowParallelQLinearImpl qlinear(in_features, + out_features, + /*bias=*/false, + quant_args, + /*qweight_pack_dim=*/0, + /*input_is_parallelized=*/true, + ParallelArgs(0, 1, nullptr), + options); + auto state_dict = StateDict::load_safetensors("data/gptq.safetensors"); + auto weights = detail::construct_weights(state_dict->get_tensor("qweight"), + state_dict->get_tensor("qzeros"), + state_dict->get_tensor("scales"), + /*bits=*/4); + weights = weights.to(torch::kCUDA); + + qlinear.load(*state_dict); + EXPECT_TRUE(qlinear.verify()); + + auto input = torch::rand({40960, in_features}, options); + auto output = qlinear.forward(input); + auto desired_output = torch::matmul(input, weights); + EXPECT_TRUE(torch::allclose(output, + desired_output, + /*rtol=*/1e-01, + /*atol=*/1e-02)); +} + +} // namespace llm diff --git a/src/layers/quantization/qlinear_awq_impl.h b/src/layers/quantization/qlinear_awq_impl.h index 593611d3..084c3a6e 100644 --- a/src/layers/quantization/qlinear_awq_impl.h +++ b/src/layers/quantization/qlinear_awq_impl.h @@ -4,7 +4,7 @@ #include "model_loader/state_dict.h" #include "models/model_args.h" -#include "qlinear_impl.h" +#include "parallel_qlinear.h" namespace llm { // quantized linear layers using awq diff --git a/src/layers/quantization/qlinear_awq_marlin_impl.cpp b/src/layers/quantization/qlinear_awq_marlin_impl.cpp index a3c9922f..999f446b 100644 --- a/src/layers/quantization/qlinear_awq_marlin_impl.cpp +++ b/src/layers/quantization/qlinear_awq_marlin_impl.cpp @@ -180,53 +180,6 @@ ColumnParallelQLinearAWQMarlinImpl::ColumnParallelQLinearAWQMarlinImpl( perm_ = torch::empty({0}, options.dtype(torch::kInt32)); } -// load the weight from the checkpoint -void ColumnParallelQLinearAWQMarlinImpl::load_state_dict( - const StateDict& state_dict) { - const auto rank = parallel_args_.rank(); - const auto world_size = parallel_args_.world_size(); - - // load sharded weights on dim 1 - LOAD_SHARDED_WEIGHT(qweight, 1); - LOAD_SHARDED_WEIGHT(qzeros, 1); - LOAD_SHARDED_WEIGHT(scales, 1); - - // load bias if defined - if (bias_.defined()) { - // load sharded bias on dim 0 - LOAD_SHARDED_WEIGHT(bias, 0); - } -} - -// special load_state_dict for fused cases -void ColumnParallelQLinearAWQMarlinImpl::load_state_dict( - const StateDict& state_dict, - const std::vector& prefixes) { - const auto rank = parallel_args_.rank(); - const auto world_size = parallel_args_.world_size(); - - // load and merge weights on dim 1 - LOAD_FUSED_WEIGHT(qweight, 1); - LOAD_FUSED_WEIGHT(qzeros, 1); - LOAD_FUSED_WEIGHT(scales, 1); - - // load bias if defined - if (bias_.defined()) { - // load and merge bias on dim 0 - LOAD_FUSED_WEIGHT(bias, 0); - } -} - -void ColumnParallelQLinearAWQMarlinImpl::verify_loaded_weights( - const std::string& prefix) const { - CHECK(qweight_is_loaded_) - << "qweight is not loaded for " << prefix + "qweight"; - CHECK(qzeros_is_loaded_) << "qzeros is not loaded for " << prefix + "qzeros"; - CHECK(scales_is_loaded_) << "scales is not loaded for " << prefix + "scales"; - CHECK(!bias_.defined() || bias_is_loaded_) - << "bias is not loaded for " << prefix + "bias"; -} - torch::Tensor ColumnParallelQLinearAWQMarlinImpl::forward(torch::Tensor input) { // repack qweight and scales to marlin compatible format at the first call if (!weight_repacked_) { @@ -302,33 +255,6 @@ RowParallelQLinearAWQMarlinImpl::RowParallelQLinearAWQMarlinImpl( perm_ = torch::empty({0}, options.dtype(torch::kInt32)); } -// load the weight from the checkpoint -void RowParallelQLinearAWQMarlinImpl::load_state_dict( - const StateDict& state_dict) { - const auto rank = parallel_args_.rank(); - const auto world_size = parallel_args_.world_size(); - - // load sharded weights on dim 0 - LOAD_SHARDED_WEIGHT(qweight, 0); - LOAD_SHARDED_WEIGHT(qzeros, 0); - LOAD_SHARDED_WEIGHT(scales, 0); - - if (bias_.defined()) { - // load bias - LOAD_WEIGHT(bias); - } -} - -void RowParallelQLinearAWQMarlinImpl::verify_loaded_weights( - const std::string& prefix) const { - CHECK(qweight_is_loaded_) - << "qweight is not loaded for " << prefix + "qweight"; - CHECK(qzeros_is_loaded_) << "qzeros is not loaded for " << prefix + "qzeros"; - CHECK(scales_is_loaded_) << "scales is not loaded for " << prefix + "scales"; - CHECK(!bias_.defined() || bias_is_loaded_) - << "bias is not loaded for " << prefix + "bias"; -} - torch::Tensor RowParallelQLinearAWQMarlinImpl::forward(torch::Tensor input) { // repack qweight and scales to marlin compatible format at the first call if (!weight_repacked_) { diff --git a/src/layers/quantization/qlinear_awq_marlin_impl.h b/src/layers/quantization/qlinear_awq_marlin_impl.h index fcdb5d11..4fd5bf17 100644 --- a/src/layers/quantization/qlinear_awq_marlin_impl.h +++ b/src/layers/quantization/qlinear_awq_marlin_impl.h @@ -20,18 +20,8 @@ class ColumnParallelQLinearAWQMarlinImpl : public ParallelLinearImpl { const ParallelArgs& parallel_args, const torch::TensorOptions& options); - // verify if the weight is loaded correctly - void verify_loaded_weights(const std::string& prefix = "") const override; - torch::Tensor forward(torch::Tensor input) override; - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) override; - - // special load_state_dict for fused cases - void load_state_dict(const StateDict& state_dict, - const std::vector& prefixes) override; - private: // parameter members, must be registered DEFINE_FUSED_WEIGHT(qweight); @@ -67,12 +57,6 @@ class RowParallelQLinearAWQMarlinImpl : public ParallelLinearImpl { torch::Tensor forward(torch::Tensor input) override; - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) override; - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix = "") const override; - private: // parameter members, must be registered DEFINE_WEIGHT(qweight); diff --git a/src/layers/quantization/qlinear_exllamav2_impl.h b/src/layers/quantization/qlinear_exllamav2_impl.h index 2299dd3d..832206fb 100644 --- a/src/layers/quantization/qlinear_exllamav2_impl.h +++ b/src/layers/quantization/qlinear_exllamav2_impl.h @@ -5,7 +5,7 @@ #include "model_loader/state_dict.h" #include "models/model_args.h" -#include "qlinear_impl.h" +#include "parallel_qlinear.h" namespace llm { diff --git a/src/layers/quantization/qlinear_gptq_marlin_impl.cpp b/src/layers/quantization/qlinear_gptq_marlin_impl.cpp index 1fb11a72..e7bbe36a 100644 --- a/src/layers/quantization/qlinear_gptq_marlin_impl.cpp +++ b/src/layers/quantization/qlinear_gptq_marlin_impl.cpp @@ -123,58 +123,6 @@ ColumnParallelQLinearGPTQMarlinImpl::ColumnParallelQLinearGPTQMarlinImpl( workspace_ = torch::zeros({max_workspace_size}, options.dtype(torch::kInt32)); } -// load the weight from the checkpoint -void ColumnParallelQLinearGPTQMarlinImpl::load_state_dict( - const StateDict& state_dict) { - const auto rank = parallel_args_.rank(); - const auto world_size = parallel_args_.world_size(); - - // load sharded weights on dim 1 - LOAD_SHARDED_WEIGHT(qweight, 1); - LOAD_SHARDED_WEIGHT(scales, 1); - - if (act_order_) { - LOAD_WEIGHT(g_idx); - } - - // load bias if defined - if (bias_.defined()) { - // load sharded bias on dim 0 - LOAD_SHARDED_WEIGHT(bias, 0); - } -} - -// special load_state_dict for fused cases -void ColumnParallelQLinearGPTQMarlinImpl::load_state_dict( - const StateDict& state_dict, - const std::vector& prefixes) { - CHECK(!act_order_) << "fused weight does not support desc_act"; - - const auto rank = parallel_args_.rank(); - const auto world_size = parallel_args_.world_size(); - - // load and merge weights on dim 1 - LOAD_FUSED_WEIGHT(qweight, 1); - LOAD_FUSED_WEIGHT(scales, 1); - - // load bias if defined - if (bias_.defined()) { - // load and merge bias on dim 0 - LOAD_FUSED_WEIGHT(bias, 0); - } -} - -void ColumnParallelQLinearGPTQMarlinImpl::verify_loaded_weights( - const std::string& prefix) const { - CHECK(qweight_is_loaded_) - << "qweight is not loaded for " << prefix + "qweight"; - CHECK(scales_is_loaded_) << "scales is not loaded for " << prefix + "scales"; - CHECK(!act_order_ || g_idx_is_loaded_) - << "g_idx is not loaded for " << prefix + "g_idx"; - CHECK(!bias_.defined() || bias_is_loaded_) - << "bias is not loaded for " << prefix + "bias"; -} - torch::Tensor ColumnParallelQLinearGPTQMarlinImpl::forward( torch::Tensor input) { // repack qweight and scales to marlin compatible format at the first call @@ -259,41 +207,6 @@ RowParallelQLinearGPTQMarlinImpl::RowParallelQLinearGPTQMarlinImpl( workspace_ = torch::zeros({max_workspace_size}, options.dtype(torch::kInt32)); } -// load the weight from the checkpoint -void RowParallelQLinearGPTQMarlinImpl::load_state_dict( - const StateDict& state_dict) { - const auto rank = parallel_args_.rank(); - const auto world_size = parallel_args_.world_size(); - - // load sharded weights on dim 0 - LOAD_SHARDED_WEIGHT(qweight, 0); - if (load_full_scales_) { - LOAD_WEIGHT(scales); - } else { - LOAD_SHARDED_WEIGHT(scales, 0); - } - - if (act_order_) { - LOAD_SHARDED_WEIGHT(g_idx, 0); - } - - if (bias_.defined()) { - // load bias - LOAD_WEIGHT(bias); - } -} - -void RowParallelQLinearGPTQMarlinImpl::verify_loaded_weights( - const std::string& prefix) const { - CHECK(qweight_is_loaded_) - << "qweight is not loaded for " << prefix + "qweight"; - CHECK(scales_is_loaded_) << "scales is not loaded for " << prefix + "scales"; - CHECK(!act_order_ || g_idx_is_loaded_) - << "g_idx is not loaded for " << prefix + "g_idx"; - CHECK(!bias_.defined() || bias_is_loaded_) - << "bias is not loaded for " << prefix + "bias"; -} - torch::Tensor RowParallelQLinearGPTQMarlinImpl::forward(torch::Tensor input) { // repack qweight and scales to marlin compatible format at the first call if (!perm_.defined()) { diff --git a/src/layers/quantization/qlinear_gptq_marlin_impl.h b/src/layers/quantization/qlinear_gptq_marlin_impl.h index f44fb9a6..00bf532b 100644 --- a/src/layers/quantization/qlinear_gptq_marlin_impl.h +++ b/src/layers/quantization/qlinear_gptq_marlin_impl.h @@ -23,18 +23,8 @@ class ColumnParallelQLinearGPTQMarlinImpl : public ParallelLinearImpl { const ParallelArgs& parallel_args, const torch::TensorOptions& options); - // verify if the weight is loaded correctly - void verify_loaded_weights(const std::string& prefix = "") const override; - torch::Tensor forward(torch::Tensor input) override; - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) override; - - // special load_state_dict for fused cases - void load_state_dict(const StateDict& state_dict, - const std::vector& prefixes) override; - private: // parameter members, must be registered DEFINE_FUSED_WEIGHT(qweight); @@ -80,12 +70,6 @@ class RowParallelQLinearGPTQMarlinImpl : public ParallelLinearImpl { torch::Tensor forward(torch::Tensor input) override; - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) override; - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix = "") const override; - private: // parameter members, must be registered DEFINE_WEIGHT(qweight); diff --git a/src/models/causal_lm.h b/src/models/causal_lm.h index e65aa368..5973008f 100644 --- a/src/models/causal_lm.h +++ b/src/models/causal_lm.h @@ -34,10 +34,10 @@ class CausalLM { const torch::Tensor& seleted_idxes) = 0; // load the model from the given state_dict - virtual void load_state_dict(const StateDict& state_dict) = 0; + virtual void load(const StateDict& state_dict) = 0; // verify if the model is loaded correctly - virtual void verify_loaded_weights() const = 0; + virtual void verify() const = 0; virtual torch::Device device() const = 0; @@ -69,11 +69,9 @@ class CausalLMImpl : public CausalLM { return model_->logits(hidden_states, seleted_idxes); } - void load_state_dict(const StateDict& state_dict) override { - model_->load(state_dict); - } + void load(const StateDict& state_dict) override { model_->load(state_dict); } - void verify_loaded_weights() const override { + void verify() const override { bool success = model_->verify(); if (!success) { LOG(FATAL) << "Failed to verify loaded weights for the model." From 7566bd34760cc66556d838130b1c548b9984cb8e Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Sun, 2 Nov 2025 19:59:55 -0800 Subject: [PATCH 2/2] update --- src/tokenizer/tokenizer.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tokenizer/tokenizer.h b/src/tokenizer/tokenizer.h index 87508cce..00247d1a 100644 --- a/src/tokenizer/tokenizer.h +++ b/src/tokenizer/tokenizer.h @@ -20,6 +20,7 @@ namespace llm { // 2. Reversing this process by converting a sequence of integers back into // human-readable text using the same vocabulary. // +// // For example: // ids = tokenizer.Encode("Hello, world!") # [1, 2, 3] // text = tokenizer.Decode(ids) # "Hello, world!"