refactor: clean up legacy load_state_dict for linear layers#503

Merged
guocuimi merged 13 commits into main from linear_refactor
Oct 8, 2025

@guocuimi
Collaborator

No description provided.

@guocuimi guocuimi requested a review from Copilot September 25, 2025 01:25
Copilot AI left a comment

Pull Request Overview

This PR refactors the legacy load_state_dict method for linear layers by removing the transform function variant and updating method calls to use simplified loading mechanisms. The changes also improve code structure by using structured bindings for QKV projections and updating parameter registration to use sharded parameters.

  • Removes the legacy load_state_dict method with transform functions from linear layer implementations
  • Updates QKV projection calls across multiple model architectures to use structured bindings instead of array indexing
  • Refactors parameter registration in quantized linear layers to use register_sharded_parameter instead of register_parameter
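
The move from array indexing to structured bindings can be sketched as follows. This is a minimal illustration rather than the project's code: `Tensor` is a stand-in for `torch::Tensor`, and `qkv_as_vector`, `qkv_as_tuple`, and `dot_qk` are hypothetical names invented for the example. It assumes C++17 and an input whose length is a multiple of three.

```cpp
#include <cstddef>
#include <tuple>
#include <vector>

// Stand-in for a tensor type; the PR itself works with torch::Tensor.
using Tensor = std::vector<float>;

// Old style (hypothetical): return a vector that callers index as
// qkv[0], qkv[1], qkv[2]. The number of outputs is not part of the type.
std::vector<Tensor> qkv_as_vector(const Tensor& x) {
  const auto n = static_cast<std::ptrdiff_t>(x.size() / 3);
  return {Tensor(x.begin(), x.begin() + n),
          Tensor(x.begin() + n, x.begin() + 2 * n),
          Tensor(x.begin() + 2 * n, x.begin() + 3 * n)};
}

// New style (hypothetical): return a tuple, so exactly three outputs
// are guaranteed at compile time.
std::tuple<Tensor, Tensor, Tensor> qkv_as_tuple(const Tensor& x) {
  const auto n = static_cast<std::ptrdiff_t>(x.size() / 3);
  return std::make_tuple(Tensor(x.begin(), x.begin() + n),
                         Tensor(x.begin() + n, x.begin() + 2 * n),
                         Tensor(x.begin() + 2 * n, x.begin() + 3 * n));
}

// Caller using a structured binding instead of index arithmetic.
float dot_qk(const Tensor& x) {
  const auto [q, k, v] = qkv_as_tuple(x);
  (void)v;  // unused in this sketch
  float acc = 0.0f;
  for (std::size_t i = 0; i < q.size(); ++i) acc += q[i] * k[i];
  return acc;
}
```

With a tuple return, a caller that binds the wrong number of names fails to compile, whereas an out-of-range index into a vector is only caught at run time, if at all.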

Reviewed Changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| src/quantization/qlinear_impl_test.cpp | Updates test calls to use load() and verify() methods instead of legacy state dict methods |
| src/quantization/qlinear_impl.cpp | Replaces parameter registration calls with sharded parameter registration including rank and world_size |
| src/models/meta/llama.h | Replaces QKV array indexing with structured binding for cleaner code |
| src/models/google/gemma2.h | Replaces QKV array indexing with structured binding for cleaner code |
| src/models/google/gemma.h | Replaces QKV array indexing with structured binding for cleaner code |
| src/models/alibaba/qwen2.h | Replaces QKV array indexing with structured binding for cleaner code |
| src/layers/qkv_linear_test.cpp | Updates test to use structured binding for QKV outputs |
| src/layers/qkv_linear.h | Updates forward method signature to return tuple instead of vector |
| src/layers/linear_impl.h | Removes legacy transform function overload declaration |
| src/layers/linear_impl.cpp | Removes legacy transform function implementation and simplifies loading |
| src/layers/linear.h | Removes virtual transform function method from base class |
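
The summary for qlinear_impl.cpp mentions that sharded parameter registration takes a rank and a world_size. One common way such a pair maps to a slice of a weight dimension is sketched below. This is an assumption for illustration only: `shard_range` is a hypothetical helper, not a function from this PR, and the even-split policy is a guess at the usual tensor-parallel convention, not a description of what `register_sharded_parameter` does internally.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Hypothetical helper (not part of the PR): returns the half-open range
// [begin, end) of a dimension owned by `rank` when the dimension is split
// evenly across `world_size` ranks.
std::pair<std::int64_t, std::int64_t> shard_range(std::int64_t dim_size,
                                                  std::int32_t rank,
                                                  std::int32_t world_size) {
  assert(world_size > 0 && rank >= 0 && rank < world_size);
  // Assumption: the dimension divides evenly; real code may pad or reject.
  assert(dim_size % world_size == 0);
  const std::int64_t shard = dim_size / world_size;
  return {rank * shard, (rank + 1) * shard};
}
```

Under these assumptions, rank 1 of 4 over a dimension of size 8 would own the range [2, 4), and a single-rank setup would own the whole dimension.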

torch::Tensor forward(torch::Tensor x) {
  const auto gate_up = gate_up_proj_(x);
  return down_proj_(act_func_(gate_up[0]) * gate_up[1]);
  // const auto gate_up = gate_up_proj_(x);
Collaborator Author

revert this

@guocuimi guocuimi merged commit f29965e into main Oct 8, 2025
3 checks passed
@guocuimi guocuimi deleted the linear_refactor branch October 8, 2025 00:21