Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/engine/llm_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ bool LLMEngine::init_model(const std::string& model_weights_path) {
std::vector<folly::SemiFuture<folly::Unit>> futures;
futures.reserve(workers_.size());
for (auto& worker : workers_) {
futures.push_back(worker->load_state_dict_async(state_dict));
futures.push_back(worker->load_async(state_dict));
}
// wait for all futures to complete
auto results = folly::collectAll(futures).get();
Expand All @@ -206,7 +206,7 @@ bool LLMEngine::init_model(const std::string& model_weights_path) {

// verify the weights are loaded correctly
for (const auto& worker : workers_) {
worker->verify_loaded_weights();
worker->verify();
}
return true;
}
Expand Down
13 changes: 6 additions & 7 deletions src/engine/worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,14 @@ void Worker::capture_cuda_graph(uint32_t batch_size) {
return model_runner_->capture_cuda_graphs(batch_size, kv_caches_);
}

void Worker::load_state_dict(const StateDict& state_dict) {
void Worker::load(const StateDict& state_dict) {
CHECK(model_ != nullptr) << "Model is not initialized.";
model_->load_state_dict(state_dict);
model_->load(state_dict);
}

void Worker::verify_loaded_weights() const {
void Worker::verify() const {
CHECK(model_ != nullptr) << "Model is not initialized.";
model_->verify_loaded_weights();
model_->verify();
}

std::tuple<int64_t, int64_t> Worker::profile_device_memory() {
Expand Down Expand Up @@ -270,14 +270,13 @@ folly::SemiFuture<folly::Unit> Worker::capture_cuda_graph_async(
return future;
}

folly::SemiFuture<folly::Unit> Worker::load_state_dict_async(
const StateDict& state_dict) {
folly::SemiFuture<folly::Unit> Worker::load_async(const StateDict& state_dict) {
folly::Promise<folly::Unit> promise;
auto future = promise.getSemiFuture();
threadpool_.schedule(
[this, &state_dict, promise = std::move(promise)]() mutable {
// load the model weights from state_dict within the working thread
this->load_state_dict(state_dict);
this->load(state_dict);
promise.setValue();
});
return future;
Expand Down
7 changes: 3 additions & 4 deletions src/engine/worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ class Worker final {

// Load the model weights from state_dict. blocking call
// can be called multiple times to reload the model with different parameters
void load_state_dict(const StateDict& state_dict);
void load(const StateDict& state_dict);

// verify if the model is loaded correctly
void verify_loaded_weights() const;
void verify() const;

// returns available memory and total memory
std::tuple<int64_t, int64_t> profile_device_memory();
Expand All @@ -57,8 +57,7 @@ class Worker final {

// Load the model weights from state_dict. async call
// the future returns a successfull status with no meaningful value
folly::SemiFuture<folly::Unit> load_state_dict_async(
const StateDict& state_dict);
folly::SemiFuture<folly::Unit> load_async(const StateDict& state_dict);

folly::SemiFuture<std::tuple<int64_t, int64_t>> profile_device_memory_async();

Expand Down
4 changes: 2 additions & 2 deletions src/engine/worker_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ class TestableWorker {
}

auto state_dict = create_state_dict();
worker_->load_state_dict(state_dict);
worker_->verify_loaded_weights();
worker_->load(state_dict);
worker_->verify();
return true;
}

Expand Down
16 changes: 8 additions & 8 deletions src/layers/linear/parallel_linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
#include <memory>

#include "layers/module/module.h"
#include "layers/quantization/parallel_qlinear_gptq.h"
#include "layers/quantization/qlinear_awq_impl.h"
#include "layers/quantization/qlinear_awq_marlin_impl.h"
#include "layers/quantization/qlinear_exllamav2_impl.h"
#include "layers/quantization/qlinear_gptq_impl.h"
#include "layers/quantization/qlinear_gptq_marlin_impl.h"
#include "model_parallel/model_parallel.h"

Expand Down Expand Up @@ -182,13 +182,13 @@ std::shared_ptr<ParallelLinearImpl> create_column_parallel_linear(
parallel_args,
options);
}
return std ::make_shared<ColumnParallelLinearImpl>(in_features,
out_features,
bias,
gather_output,
parallel_args,
options,
prefix);
return std::make_shared<ColumnParallelLinearImpl>(in_features,
out_features,
bias,
gather_output,
parallel_args,
options,
prefix);
}

std::shared_ptr<ParallelLinearImpl> create_row_parallel_linear(
Expand Down
15 changes: 0 additions & 15 deletions src/layers/linear/parallel_linear.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,6 @@ class ParallelLinearImpl : public Module {
~ParallelLinearImpl() override = default;

virtual torch::Tensor forward(torch::Tensor input) = 0;

// TODO: clean up the interface of load_state_dict
virtual void load_state_dict(const StateDict& state_dict) {
LOG(FATAL) << "not implemented";
}

virtual void verify_loaded_weights(const std::string& prefix = "") const {
LOG(FATAL) << "not implemented";
}

// special load_state_dict for fused cases
virtual void load_state_dict(const StateDict& /*state_dict*/,
const std::vector<std::string>& /*prefixes*/) {
LOG(FATAL) << "not implemented";
}
};

// Linear layer with column parallelism.
Expand Down
2 changes: 1 addition & 1 deletion src/layers/linear/qkv_parallel_linear_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ TEST_P(QKVColumnParallelLinearTest, LoadFusedWeight) {
quant_args,
parallel_args,
options);
linear.load(state_dict);
EXPECT_EQ(linear.load(state_dict), 3);
EXPECT_TRUE(linear.verify());

// generate random input and compare with the output
Expand Down
13 changes: 7 additions & 6 deletions src/layers/quantization/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ cc_library(
quantization
HDRS
pack_utils.h
qlinear_impl.h
qlinear_gptq_impl.h
parallel_qlinear.h
parallel_qlinear_gptq.h
qlinear_exllamav2_impl.h
qlinear_awq_impl.h
qlinear_gptq_marlin_impl.h
qlinear_awq_marlin_impl.h
SRCS
pack_utils.cpp
qlinear_impl.cpp
qlinear_gptq_impl.cpp
parallel_qlinear.cpp
parallel_qlinear_gptq.cpp
qlinear_exllamav2_impl.cpp
qlinear_awq_impl.cpp
qlinear_gptq_marlin_impl.cpp
Expand All @@ -34,10 +34,11 @@ cc_library(

cc_test(
NAME
quantization_test
qlinear_test
SRCS
pack_utils_test.cpp
qlinear_impl_test.cpp
parallel_qlinear_test.cpp
parallel_qlinear_gptq_test.cpp
DEPS
:quantization
:state_dict
Expand Down
Loading
Loading