
"ValueError: Expected 1 columns but got 8 columns." when quantizing Llama3.2-1B #201

@SunXT-0719

Description


Thank you for providing such a great repo, but I can't run quantization for another model, even though it's a Llama model.
The command is:

python run_vptq.py \
        --model_name meta-llama/Llama-3.2-1B \
        --output_dir /data/sunxuetin/Llama-3.2-1B-VPTQ \
        --vector_lens -1 8 \
        --group_num 1 \
        --num_centroids -1 65536 \
        --num_res_centroids -1 256 \
        --blocksize 128 \
        --new_eval \
        --seq_len 8192 \
        --kmeans_mode hessian \
        --num_gpus 1 \
        --enable_perm \
        --enable_norm \
        --save_model \
        --save_packed_model \
        --hessian_path /data/sunxuetin/hessian/llama3.2-1B \
        --inv_hessian_path /data/sunxuetin/hessian_inv/llama3.2-1B \
        --ktol 1e-5 --kiter 100

I collected the Hessians via https://github.com/VPTQ/hessian_collector/blob/master/hessian_collector.py, as suggested in other issues.
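
For reference, this is roughly how I sanity-check the collected Hessian files before quantizing (just an assumption on my side that each file such as 0_qkv.pt holds the Hessian for one layer's inputs; the exact container format depends on hessian_collector, so I simply print whatever is inside):

import torch

# Load one layer's Hessian and inverse Hessian from the paths passed to run_vptq.py.
hessian = torch.load("/data/sunxuetin/hessian/llama3.2-1B/0_qkv.pt", map_location="cpu")
inv_hessian = torch.load("/data/sunxuetin/hessian_inv/llama3.2-1B/0_qkv.pt", map_location="cpu")

for name, obj in [("hessian", hessian), ("inv_hessian", inv_hessian)]:
    if isinstance(obj, dict):
        # Print each entry's shape if the file stores a dict of tensors.
        print(name, {k: getattr(v, "shape", type(v)) for k, v in obj.items()})
    else:
        print(name, getattr(obj, "shape", type(obj)))

# For the Llama-3.2-1B attention projections I would expect 2048 x 2048 matrices here,
# matching the "data shape: torch.Size([2048, 2048])" line in the log below.
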
The complete error message is:

args: VPTQArguments(model_name='meta-llama/Llama-3.2-1B', seq_len=8192, quant_step=1, percdamp=0.01, blocksize=128, output_dir='/data/sunxuetin/Llama-3.2-1B-VPTQ/2025-10-21-10-09-01', seed=42, eval=False, new_eval=True, save_model=True, save_packed_model=True, disable_actorder=False, hessian_path='/data/sunxuetin/hessian/llama3.2-1B', inv_hessian_path='/data/sunxuetin/hessian_inv/llama3.2-1B', num_gpus=1, eval_nsamples=128, save_qlinear=True, absorb_perm=False)
quant_args: QuantizationArguments(vector_lens=[-1, 8], num_centroids=[-1, 65536], num_res_centroids=[-1, 256], npercent=0, group_num=1, group_size=-1, kiter=100, ktol=1e-05, kseed=0, kmeans_mode='hessian', kmeans_alpha=0, enable_norm=True, norm_dim=0, enable_perm=True)
Starting VPTQ...
model dtype: torch.bfloat16
----quantization start ...---- 2025-10-21 10:09:04
gpu 0 tasks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
INFO - ----Quantizing on cuda:0----
INFO - ----Quantizing layer 0 ...---- 2025-10-21 10:09:09 on cuda:0 dtype torch.bfloat16
INFO - dict_keys(['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj', 'mlp.gate_proj', 'mlp.up_proj', 'mlp.down_proj'])
INFO - load Hessian from /data/sunxuetin/hessian/llama3.2-1B/0_qkv.pt
INFO - load inv Hessian from /data/sunxuetin/hessian_inv/llama3.2-1B/0_qkv.pt
INFO - ----Quantizing llama ...---- 2025-10-21 10:09:09 0.self_attn.q_proj
INFO - enabling norm dim 0, layer_name:0.self_attn.q_proj, scale:torch.Size([2048]), bias:torch.Size([2048])
INFO - kmeans_mode: hessian, enable_perm: True, enable_norm: True,
INFO - data shape: torch.Size([2048, 2048]), weights shape: torch.Size([2048, 2048])
INFO - group_size: 2048 number of groups: 1
INFO - idx: 0, num_centroids: -1, skip
INFO - cuml kmeans 23 iterations, error 1998.8406982421875
INFO - idx: 1, quant_data shape: torch.Size([524288, 8])
INFO - idx: 1, quant_data shape: torch.Size([2048, 2048])
INFO - quantized_data shape: torch.Size([2048, 2048])
INFO - 0.self_attn.q_proj 1st kmeans time: 2.732691526412964
INFO - 0.self_attn.q_proj qweight init shape: torch.Size([2048, 2048]), weight shape: torch.Size([2048, 2048])
INFO - 0.self_attn.q_proj proxy error before VPTQ: 0.0001238604891113937, 0.03463037312030792, 0.0035766432993113995
INFO - 0.self_attn.q_proj 1st error time: 1.029219150543213
INFO - 0.self_attn.q_proj proxy error after VPTQ: 6.773966742912307e-05, 0.03463037312030792, 0.0019560768269002438
INFO - group_size: 2048 number of groups: 1
INFO - idx: 0, num_centroids: -1, skip
INFO - kmeans_mode: hessian, cuml kmeans, 256 clusters
Traceback (most recent call last):
  File "/data/sunxuetin/projects/VPTQ/run_vptq.py", line 86, in <module>
    model, quantizers = quant_llama(model, args, quant_args)
  File "/data/sunxuetin/projects/VPTQ/vptq/models/llama.py", line 122, in quant_llama
    layer_state_dicts, layer_qlinear_args = quantize_executer(0, tasks[0], args, quant_args, None, None)
  File "/data/sunxuetin/projects/VPTQ/vptq/quantize_executer.py", line 69, in quantize_executer
    layer, qlinear_args = layer_quantizer(
  File "/data/sunxuetin/projects/VPTQ/vptq/layer_quantizer.py", line 105, in layer_quantizer
    _vptq.fast_vector_quant()
  File "/data/sunxuetin/projects/VPTQ/vptq/vptq.py", line 261, in fast_vector_quant
    qweight_residual = self.quantizer.init_res_centroids_indices(qerror, kmeans_weight)
  File "/data/sunxuetin/projects/VPTQ/vptq/quantizer.py", line 403, in init_res_centroids_indices
    _kmeans.fit(sub_vectors, sample_weight=vector_weights)
  File "/data/sunxuetin/anaconda3/envs/vptq-algo/lib/python3.10/site-packages/cuml/internals/api_decorators.py", line 211, in wrapper
    ret = func(*args, **kwargs)
  File "cuml/cluster/kmeans.pyx", line 520, in cuml.cluster.kmeans.KMeans.fit
  File "/data/sunxuetin/anaconda3/envs/vptq-algo/lib/python3.10/site-packages/cuml/internals/api_decorators.py", line 211, in wrapper
    ret = func(*args, **kwargs)
  File "cuml/cluster/kmeans.pyx", line 535, in cuml.cluster.kmeans.KMeans._fit
  File "/data/sunxuetin/anaconda3/envs/vptq-algo/lib/python3.10/site-packages/cuml/internals/input_utils.py", line 358, in input_to_cuml_array
    arr = CumlArray.from_input(
  File "/data/sunxuetin/anaconda3/envs/vptq-algo/lib/python3.10/site-packages/cuml/internals/memory_utils.py", line 74, in cupy_rmm_wrapper
    return func(*args, **kwargs)
  File "/data/sunxuetin/anaconda3/envs/vptq-algo/lib/python3.10/site-packages/cuml/internals/array.py", line 1216, in from_input
    raise ValueError(
ValueError: Expected 1 columns but got 8 columns.

I've printed the shape of sub_vectors; it's (524288, 8).
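
Since sub_vectors with 8 columns should be a perfectly valid KMeans input, my guess is that the column check fails on the sample_weight argument (vector_weights) rather than on the data itself. Here is a minimal sketch with random data, assuming that is the case, which reproduces the same kind of cuml error when the weights are 2-D instead of one value per row:

import numpy as np
from cuml.cluster import KMeans

n, dim = 524288, 8
sub_vectors = np.random.randn(n, dim).astype(np.float32)

km = KMeans(n_clusters=256, max_iter=100, tol=1e-5)

# One weight per row: this is what cuml expects, and fit() works.
km.fit(sub_vectors, sample_weight=np.random.rand(n).astype(np.float32))

# One weight per element, shape (n, 8): this raises
# "ValueError: Expected 1 columns but got 8 columns."
km.fit(sub_vectors, sample_weight=np.random.rand(n, dim).astype(np.float32))

If that is indeed what happens here, the shape of vector_weights passed in init_res_centroids_indices would be the thing to check.
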
Has anyone met the same problem? Thanks a lot if you can give me a solution!
