From 7cac40112dd82e747f6625673e34375178ad27fe Mon Sep 17 00:00:00 2001 From: Aakash Thatte Date: Wed, 15 Oct 2025 23:18:20 +0530 Subject: [PATCH 1/6] add initial code for ganq from author's implementation --- .../quantization/backends/ganq/__init__.py | 13 ++ .../quantization/backends/ganq/ganq.py | 85 ++++++++ .../quantization/backends/ganq/lut_quant.py | 187 ++++++++++++++++++ .../quantization/backends/ganq/utils.py | 97 +++++++++ 4 files changed, 382 insertions(+) create mode 100644 src/pruna/algorithms/quantization/backends/ganq/__init__.py create mode 100644 src/pruna/algorithms/quantization/backends/ganq/ganq.py create mode 100644 src/pruna/algorithms/quantization/backends/ganq/lut_quant.py create mode 100644 src/pruna/algorithms/quantization/backends/ganq/utils.py diff --git a/src/pruna/algorithms/quantization/backends/ganq/__init__.py b/src/pruna/algorithms/quantization/backends/ganq/__init__.py new file mode 100644 index 00000000..38e0d7e5 --- /dev/null +++ b/src/pruna/algorithms/quantization/backends/ganq/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/src/pruna/algorithms/quantization/backends/ganq/ganq.py b/src/pruna/algorithms/quantization/backends/ganq/ganq.py new file mode 100644 index 00000000..a723deeb --- /dev/null +++ b/src/pruna/algorithms/quantization/backends/ganq/ganq.py @@ -0,0 +1,85 @@ +# ruff: noqa: N806, N803, N802 +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import transformers + +from pruna.algorithms.quantization.backends.ganq.lut_quant import LUTQuant + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + + +class GANQ: + """GANQ class for quantizing neural network layers.""" + + def __init__(self, layer, model_type): + self.layer = layer + self.dev = self.layer.weight.device + self.model_type = model_type + W = layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + self.rows = W.shape[0] + self.columns = W.shape[1] + self.XXt = torch.zeros((self.columns, self.columns), device=self.dev) + + def add_batch(self, inp, out): + """Accumulate input statistics for quantization.""" + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + + if isinstance(self.layer, (nn.Linear, transformers.Conv1D)): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + + inp = inp.float() + self.XXt += inp @ inp.T + + def fasterquant(self, sparsity=0.0, 
bits=4, max_epoch=10, pre_process=True, full_rows=0): + """Main function to perform GANQ quantization.""" + W = self.layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + W = W.float() + + quant = LUTQuant( + bits=bits, + W=W, + XXt=self.XXt, + max_epoch=max_epoch, + sparsity=sparsity, + model_type=self.model_type, + pre_process=pre_process, + full_rows=full_rows, + ) + W = quant.quantization() + + torch.cuda.synchronize() + + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) + + def free(self): + """Free up memory.""" + self.XXt = None + torch.cuda.empty_cache() diff --git a/src/pruna/algorithms/quantization/backends/ganq/lut_quant.py b/src/pruna/algorithms/quantization/backends/ganq/lut_quant.py new file mode 100644 index 00000000..e5c1e5ef --- /dev/null +++ b/src/pruna/algorithms/quantization/backends/ganq/lut_quant.py @@ -0,0 +1,187 @@ +# ruff: noqa: N806, N803, N802 +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch + +from pruna.algorithms.quantization.backends.ganq.utils import ( + denormalize, + init_t_3bit, + init_t_4bit, + norm_params, + normalize, +) + + +class LUTQuant: + """LUTQuant class for quantization using Look-Up Tables (LUTs).""" + + def __init__( + self, + bits, + W, + XXt, + max_epoch=10, + sparsity=0.0, + full_rows=0, + model_type="opt", + pre_process=True, + ): + if bits not in [3, 4]: + raise NotImplementedError("Only 3 and 4 bits are supported") + + # Set quantization parameters + self.bits = bits + self.num_levels = 2**self.bits + self.sparsity = sparsity + self.full_rows = full_rows + self.max_epoch = max_epoch + self.model_type = model_type + self.pre_process = pre_process + + # Store input tensors + self.W = W + self.m, self.n = self.W.shape + self.device = W.device + self.XXt = XXt + + # Handle outliers if sparsity is enabled + if self.sparsity > 0: + ratio = 1 - 0.5 * self.sparsity + self.W, self.S, self.row_mask = self.outlier_split(ratio=ratio, full_rows=self.full_rows) + + # Normalize weights if pre-processing is enabled + if self.pre_process: + self.median, self.iqr = norm_params(self.W) + self.W = normalize(self.W, self.median, self.iqr) + + # Compute Cholesky decomposition with numerical stability + if self.model_type == "opt": + dead = torch.diag(self.XXt) == 0 + self.XXt[dead, dead] = 1 + self.W[:, dead] = torch.mean(self.W[:, ~dead], dim=1, keepdim=True) + offset = (torch.sum(torch.abs(self.XXt), dim=1) - 2 * torch.diag(self.XXt)).clamp(min=1e-8) + self.L = torch.linalg.cholesky(self.XXt + torch.diag(offset)) + else: + try: + self.L = torch.linalg.cholesky(self.XXt) + except RuntimeError: + offset = (torch.sum(torch.abs(self.XXt), dim=1) - 2 * torch.diag(self.XXt)).clamp(min=1e-8) + self.L = torch.linalg.cholesky(self.XXt + torch.diag(offset)) + + # Precompute matrices + self.LLt = self.L @ self.L.T + self.WLLt = self.W @ self.LLt + self.WL = self.W @ self.L + self.WXXt = self.W @ self.XXt + + def dequantization(self, T, S): + 
"""Dequantize using the selection matrix S and table T.""" + return torch.einsum("mil, mln -> min", T.unsqueeze(1), S).squeeze(1) + + def initialize_T(self, W, seg_levels=None, seg_boundaries=None): + """Initialize the table T using distribution-based splits.""" + if self.bits == 4: + seg_levels = seg_levels or [2, 6, 6, 2] if self.sparsity > 0 else seg_levels or [3, 5, 5, 3] + seg_boundaries = seg_boundaries or [0.0, 0.25, 0.5, 0.75, 1.0] + assert sum(seg_levels) == self.num_levels, "Number of levels must sum up to total num_levels." + return init_t_4bit(W, seg_levels, seg_boundaries) + elif self.bits == 3: + seg_levels = seg_levels or [2, 4, 2] + seg_boundaries = seg_boundaries or [0.0, 0.20, 0.80, 1.0] + assert sum(seg_levels) == self.num_levels, "Number of levels must sum up to total num_levels." + return init_t_3bit(W, seg_levels, seg_boundaries) + else: + raise NotImplementedError("Only 3 and 4 bits are supported") + + def outlier_split(self, ratio=0.9975, full_rows=0): + """Outlier detection version 3 (more sophisticated).""" + outlier_mask = torch.zeros_like(self.W) + if full_rows > 0: + row_variances = torch.var(self.W, dim=1, unbiased=False) + _, largest_variance_indices = torch.topk(row_variances, full_rows, largest=True, sorted=False) + outlier_mask[largest_variance_indices] = 1.0 + ratio = 2 * (1 - ratio) * self.m * self.n / ((self.m - full_rows) * self.n) + ratio = 1 - 2 * ratio + + row_mask = 1.0 - outlier_mask + cutoff_idx = int(self.n * ratio) - 1 + + sorted_W, sorted_indices = torch.sort(self.W, dim=1, stable=True) + cutoff_values = sorted_W[:, cutoff_idx].unsqueeze(1) + + lower_cutoff_idx = round(self.n * (1 - ratio) + 0.5) + 1 + lower_cutoff_values = sorted_W[:, lower_cutoff_idx + 1].unsqueeze(1) + + outliers = (cutoff_values <= self.W) | (lower_cutoff_values >= self.W) + outlier_mask[outliers] = 1.0 + + S_prime = self.W * outlier_mask + W_prime = self.W - S_prime + + return W_prime, S_prime, row_mask + + def update_S(self, T): + """Update the 
selection matrix S.""" + W_q = torch.zeros_like(self.W, device=self.device) + S = torch.zeros(self.m, self.num_levels, self.n, device=self.device) + + for i in range(self.n - 1, -1, -1): + residuals = self.WL[:, i] - torch.sum(W_q[:, i + 1 :] * self.L[i + 1 :, i], dim=1) + candidates = T * self.L[i, i] + + differences = torch.abs(candidates - residuals.unsqueeze(1)) + closest_indices = torch.argmin(differences, dim=1) + W_q[:, i] = T.gather(1, closest_indices.unsqueeze(1)).squeeze(1) + S[torch.arange(self.m), closest_indices, i] = 1 + + return S + + def update_T(self, S): + """Update the table T.""" + St = S.transpose(-1, -2) + SLLt = S @ self.LLt + SLLtSt = torch.matmul(SLLt, St) + + numerator = torch.matmul(self.WLLt.unsqueeze(1), St) + denominator = torch.linalg.pinv(SLLtSt.to(torch.float64)).to(torch.float32) + + T = torch.matmul(numerator, denominator) + + return T.squeeze(1) + + def quantization(self): + """Main quantization loop.""" + T = self.initialize_T(self.W) + best_diff = float("inf") + best_S, best_T = None, None + + for _ in range(self.max_epoch): + S = self.update_S(T) + T = self.update_T(S) + + residual = self.W - self.dequantization(T, S) + current_diff = torch.trace(residual.T @ residual @ self.XXt).item() + + if current_diff < best_diff: + best_S, best_T = S.detach(), T.detach() + best_diff = current_diff + + if self.max_epoch > 0: + if self.pre_process: + best_T = denormalize(best_T, self.median, self.iqr) + self.W = self.dequantization(best_T, best_S) + if self.sparsity > 0: + self.W = self.W * self.row_mask + self.S + + return self.W diff --git a/src/pruna/algorithms/quantization/backends/ganq/utils.py b/src/pruna/algorithms/quantization/backends/ganq/utils.py new file mode 100644 index 00000000..c9017d72 --- /dev/null +++ b/src/pruna/algorithms/quantization/backends/ganq/utils.py @@ -0,0 +1,97 @@ +# ruff: noqa: N806, N803, N802 +# Copyright 2025 - Pruna AI GmbH. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + + +def init_t_4bit(W, seg_levels=[3, 5, 5, 3], seg_boundaries=[0.0, 0.25, 0.5, 0.75, 1.0]): + """4-bit quantization with non-uniform levels.""" + W_min = W.min(dim=1).values + W_max = W.max(dim=1).values + range_vals = W_max - W_min + + T_segments = [] + + for i, levels in enumerate(seg_levels): + seg_start_ratio = seg_boundaries[i] + seg_end_ratio = seg_boundaries[i + 1] + seg_start = W_min + range_vals * seg_start_ratio + seg_end = W_min + range_vals * seg_end_ratio + seg_lin = torch.linspace(0, 1, steps=levels, device=W.device) + seg_T = seg_start.unsqueeze(1) + (seg_end - seg_start).unsqueeze(1) * seg_lin.unsqueeze(0) + T_segments.append(seg_T) + + return torch.cat(T_segments, dim=1) + + +def init_t_3bit(W, sub_counts=[2, 4, 2], fraction_boundaries=[0.0, 0.20, 0.80, 1.0]): + """3-bit quantization with distribution-based splits.""" + device = W.device + B, N = W.shape + sorted_data, _ = W.sort(dim=1) + + frac_t = torch.tensor(fraction_boundaries, device=device, dtype=sorted_data.dtype) + M = len(sub_counts) + + splitted_list = [] + + for j in range(M): + c_j = sub_counts[j] + if c_j == 0: + continue + + f0 = frac_t[j] + f1 = frac_t[j + 1] + + if j == 0: + base = torch.linspace(0.0, 1.0, steps=c_j, device=device, dtype=sorted_data.dtype) + frac_grid = f0 + base * (f1 - f0) + else: + base_full = torch.linspace(0.0, 1.0, steps=c_j + 1, device=device, dtype=sorted_data.dtype) + base_segment = base_full[1:] 
+ frac_grid = f0 + base_segment * (f1 - f0) + + frac_grid = frac_grid.unsqueeze(0).expand(B, -1) + indexf = frac_grid * (N - 1) + left_idx = torch.floor(indexf).long() + right_idx = torch.clamp(left_idx + 1, max=N - 1) + alpha = indexf - left_idx + + left_vals = torch.gather(sorted_data, 1, left_idx) + right_vals = torch.gather(sorted_data, 1, right_idx) + sub_points = (1.0 - alpha) * left_vals + alpha * right_vals + + splitted_list.append(sub_points) + + return ( + torch.cat(splitted_list, dim=1) if splitted_list else torch.empty((B, 0), device=device, dtype=sorted_data.dtype) + ) + + +def norm_params(W): + """Compute the median and IQR for normalization.""" + median = W.median(dim=1, keepdim=True).values + q75, q25 = torch.quantile(W, 0.75, dim=1, keepdim=True), torch.quantile(W, 0.25, dim=1, keepdim=True) + iqr = q75 - q25 + 1e-8 + return median, iqr + + +def normalize(tensor, median, iqr): + """Normalize the tensor using median and IQR.""" + return (tensor - median) / iqr + + +def denormalize(tensor, median, iqr): + """Denormalize the tensor using median and IQR.""" + return tensor * iqr + median From de045b8c7cba1a978a120b84bf433647c4e17750 Mon Sep 17 00:00:00 2001 From: Aakash Thatte Date: Sat, 18 Oct 2025 14:47:05 +0530 Subject: [PATCH 2/6] add func to find layers --- .../quantization/backends/ganq/utils.py | 36 ++++++++++++++++--- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/src/pruna/algorithms/quantization/backends/ganq/utils.py b/src/pruna/algorithms/quantization/backends/ganq/utils.py index c9017d72..f731ff2e 100644 --- a/src/pruna/algorithms/quantization/backends/ganq/utils.py +++ b/src/pruna/algorithms/quantization/backends/ganq/utils.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import torch +import torch.nn as nn def init_t_4bit(W, seg_levels=[3, 5, 5, 3], seg_boundaries=[0.0, 0.25, 0.5, 0.75, 1.0]): @@ -29,7 +30,9 @@ def init_t_4bit(W, seg_levels=[3, 5, 5, 3], seg_boundaries=[0.0, 0.25, 0.5, 0.75 seg_start = W_min + range_vals * seg_start_ratio seg_end = W_min + range_vals * seg_end_ratio seg_lin = torch.linspace(0, 1, steps=levels, device=W.device) - seg_T = seg_start.unsqueeze(1) + (seg_end - seg_start).unsqueeze(1) * seg_lin.unsqueeze(0) + seg_T = seg_start.unsqueeze(1) + (seg_end - seg_start).unsqueeze( + 1 + ) * seg_lin.unsqueeze(0) T_segments.append(seg_T) return torch.cat(T_segments, dim=1) @@ -55,10 +58,14 @@ def init_t_3bit(W, sub_counts=[2, 4, 2], fraction_boundaries=[0.0, 0.20, 0.80, 1 f1 = frac_t[j + 1] if j == 0: - base = torch.linspace(0.0, 1.0, steps=c_j, device=device, dtype=sorted_data.dtype) + base = torch.linspace( + 0.0, 1.0, steps=c_j, device=device, dtype=sorted_data.dtype + ) frac_grid = f0 + base * (f1 - f0) else: - base_full = torch.linspace(0.0, 1.0, steps=c_j + 1, device=device, dtype=sorted_data.dtype) + base_full = torch.linspace( + 0.0, 1.0, steps=c_j + 1, device=device, dtype=sorted_data.dtype + ) base_segment = base_full[1:] frac_grid = f0 + base_segment * (f1 - f0) @@ -75,14 +82,18 @@ def init_t_3bit(W, sub_counts=[2, 4, 2], fraction_boundaries=[0.0, 0.20, 0.80, 1 splitted_list.append(sub_points) return ( - torch.cat(splitted_list, dim=1) if splitted_list else torch.empty((B, 0), device=device, dtype=sorted_data.dtype) + torch.cat(splitted_list, dim=1) + if splitted_list + else torch.empty((B, 0), device=device, dtype=sorted_data.dtype) ) def norm_params(W): """Compute the median and IQR for normalization.""" median = W.median(dim=1, keepdim=True).values - q75, q25 = torch.quantile(W, 0.75, dim=1, keepdim=True), torch.quantile(W, 0.25, dim=1, keepdim=True) + q75, q25 = torch.quantile(W, 0.75, dim=1, keepdim=True), torch.quantile( + W, 0.25, dim=1, keepdim=True + ) iqr = q75 - q25 + 1e-8 return median, iqr 
@@ -95,3 +106,18 @@ def normalize(tensor, median, iqr): def denormalize(tensor, median, iqr): """Denormalize the tensor using median and IQR.""" return tensor * iqr + median + + +def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""): + """For a given module, find all sub-modules of specified layer types + and return a dictionary of their names and instances.""" + if type(module) in layers: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update( + find_layers( + child, layers=layers, name=name + "." + name1 if name != "" else name1 + ) + ) + return res From c1889d5fbc89b1b71bd77822ebf1dfbc51c17a76 Mon Sep 17 00:00:00 2001 From: Aakash Thatte Date: Sat, 18 Oct 2025 14:47:47 +0530 Subject: [PATCH 3/6] add note --- .../algorithms/quantization/backends/ganq/ganq.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/pruna/algorithms/quantization/backends/ganq/ganq.py b/src/pruna/algorithms/quantization/backends/ganq/ganq.py index a723deeb..de27975e 100644 --- a/src/pruna/algorithms/quantization/backends/ganq/ganq.py +++ b/src/pruna/algorithms/quantization/backends/ganq/ganq.py @@ -45,14 +45,19 @@ def add_batch(self, inp, out): inp = inp.unsqueeze(0) if isinstance(self.layer, (nn.Linear, transformers.Conv1D)): - if len(inp.shape) == 3: + + # Note: Official implementation uses == 3 condition, + # refer here - https://github.com/Evans-Z/GANQ/blob/176a87701fd0e07aea1ccd4f3faff84871d79f44/ganq.py#L39 + if len(inp.shape) > 2: inp = inp.reshape((-1, inp.shape[-1])) inp = inp.t() inp = inp.float() self.XXt += inp @ inp.T - def fasterquant(self, sparsity=0.0, bits=4, max_epoch=10, pre_process=True, full_rows=0): + def fasterquant( + self, sparsity=0.0, bits=4, max_epoch=10, pre_process=True, full_rows=0 + ): """Main function to perform GANQ quantization.""" W = self.layer.weight.data.clone() if isinstance(self.layer, nn.Conv2d): @@ -77,7 +82,9 @@ def fasterquant(self, sparsity=0.0, bits=4, 
max_epoch=10, pre_process=True, full if isinstance(self.layer, transformers.Conv1D): W = W.t() - self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) + self.layer.weight.data = W.reshape(self.layer.weight.shape).to( + self.layer.weight.data.dtype + ) def free(self): """Free up memory.""" From 56490fb3faf9d0f35b21600ea0e0751f4ab788f7 Mon Sep 17 00:00:00 2001 From: Aakash Thatte Date: Sat, 18 Oct 2025 14:48:09 +0530 Subject: [PATCH 4/6] add imports to init --- .../quantization/backends/ganq/__init__.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/pruna/algorithms/quantization/backends/ganq/__init__.py b/src/pruna/algorithms/quantization/backends/ganq/__init__.py index 38e0d7e5..ba4b41cf 100644 --- a/src/pruna/algorithms/quantization/backends/ganq/__init__.py +++ b/src/pruna/algorithms/quantization/backends/ganq/__init__.py @@ -11,3 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .ganq import GANQ +from .lut_quant import LUTQuant +from .utils import * + +__all__ = [ + "GANQ", + "LUTQuant", + "init_t_3bit", + "init_t_4bit", + "normalize", + "denormalize", + "norm_params", + "find_layers", +] From e57992bf16edcc1e717160e5dbe1d989ab146405 Mon Sep 17 00:00:00 2001 From: Aakash Thatte Date: Sat, 18 Oct 2025 14:49:11 +0530 Subject: [PATCH 5/6] add initial ganq quantizer --- src/pruna/algorithms/quantization/ganq.py | 220 ++++++++++++++++++++++ 1 file changed, 220 insertions(+) create mode 100644 src/pruna/algorithms/quantization/ganq.py diff --git a/src/pruna/algorithms/quantization/ganq.py b/src/pruna/algorithms/quantization/ganq.py new file mode 100644 index 00000000..3532d6fc --- /dev/null +++ b/src/pruna/algorithms/quantization/ganq.py @@ -0,0 +1,220 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any + +import torch +from ConfigSpace import Constant, OrdinalHyperparameter + +from pruna.algorithms.quantization import PrunaQuantizer +from pruna.config.hyperparameters import Boolean +from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.engine.model_checks import ( + is_causal_lm, +) + +from pruna.engine.utils import safe_memory_cleanup +from pruna.logging.filter import SuppressOutput +from pruna.logging.logger import pruna_logger + + +class GANQQuantizer(PrunaQuantizer): + """GPU-Adaptive Non-Uniform Quantization (GANQ). + + GANQ performs layer-wise LUT-based non-uniform quantization by + alternating optimization of codebook (T) and selection matrix (S). 
+ It adapts to the distribution of each layer’s weights and supports + optional normalization and outlier handling.""" + + algorithm_name: str = "ganq" + references: dict[str, str] = { + "GitHub": "https://github.com/Evans-Z/GANQ", + "Article": "https://arxiv.org/pdf/2501.12956", + } + runs_on: list[str] = ["cuda"] + dataset_required: bool = True + compatible_algorithms = dict(compiler=["torch_compile"]) + processor_required: bool = False + tokenizer_required: bool = False + + def get_hyperparameters(self): + return [ + OrdinalHyperparameter( + "weight_bits", + [3, 4], + default_value=4, + meta=dict(desc="Bit width for weight quantization."), + ), + OrdinalHyperparameter( + "max_epoch", + [5, 10, 20], + default_value=10, + meta=dict(desc="Number of GANQ alternating steps."), + ), + Boolean( + "pre_process", + default=True, + meta=dict( + desc="Normalize weights with median/IQR before quantization." + ), + ), + Constant("sparsity", value=0.0), + Constant("full_rows", value=0), + ] + + def model_check_fn(self, model: Any) -> bool: + return is_causal_lm(model) + + def _apply(self, model, smash_config: SmashConfigPrefixWrapper): + imported_packages = self.import_algorithm_packages() + GANQ, find_layers = imported_packages["GANQ"], imported_packages["find_layers"] + + pruna_logger.info("Running GANQ layer-wise quantization...") + model.eval() + device = smash_config["device"] + + val_dl = smash_config.val_dataloader() + + # TODO: Align on whether to use batch, or use calib data and use tokenizer + # calib_data = recover_text_from_dataloader(val_dl, smash_config.tokenizer) # type: ignore[arg-type] + # tokenizer = smash_config.tokenizer + # max_len = getattr(model.config, "max_position_embeddings") + safe_memory_cleanup() + + layers = model.model.layers + print(layers) + inps = [] + cache = {"i": 0, "attention_mask": None, "position_ids": None} + + # Note: This is only used to capture inputs to layer 0, hence the value error + class Catcher(torch.nn.Module): + def 
__init__(self, module): + super().__init__() + self.module = module + + def forward(self, inp, **kwargs): + if cache["i"] < len(inps): + inps[cache["i"]] = inp + else: + inps.append(inp) + cache["i"] += 1 + cache["attention_mask"] = kwargs.get("attention_mask", None) + cache["position_ids"] = kwargs.get("position_ids", None) + raise ValueError + + layers[0] = Catcher(layers[0]) + + for batch in val_dl: + try: + model(batch[0].to(device)) + except ValueError: + pass + + layers[0] = layers[0].module + pruna_logger.info(f"Captured {len(inps)} input samples for layer 0.") + + outs = torch.zeros_like(inps[0]) + attention_mask = cache["attention_mask"] + position_ids = cache["position_ids"] + + for i, layer in enumerate(layers): + layer = layer.to(device) + layer_dict = find_layers(layer) + + pruna_logger.info(f"Quantizing layer {i} ({len(layer_dict)} submodules)...") + + # Group submodules like the official repo + sequential_groups = [ + ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], + ["self_attn.o_proj"], + ["mlp.up_proj", "mlp.gate_proj"], + ["mlp.down_proj"], + ] + + for group in sequential_groups: + subset = {n: layer_dict[n] for n in group if n in layer_dict} + if not subset: + continue + + gpts = { + name: GANQ(mod, model_type="hf") for name, mod in subset.items() + } + + def make_hook(name): + def hook_fn(_, inp, out): + gpts[name].add_batch(inp[0].detach(), out.detach()) + + return hook_fn + + handles = [ + mod.register_forward_hook(make_hook(name)) + for name, mod in subset.items() + ] + + with torch.no_grad(): + for j in range(len(inps)): + + # Note: This is different from author's implementation, because of change in RoPE handling in transformers version 4.42 and 4.53 + # refer original implementation here - https://github.com/Evans-Z/GANQ/blob/176a87701fd0e07aea1ccd4f3faff84871d79f44/llama.py#L127 + hs = inps[j].to(device) # [1, seq, hidden] + seq_len = hs.shape[1] + pos_ids = position_ids[:1, :seq_len].to(device) + cos, sin = 
model.model.rotary_emb(hs, pos_ids) + cache_pos = torch.arange(seq_len, device=device) + + outs = layer( + hs, + attention_mask=attention_mask, + position_ids=pos_ids, + position_embeddings=(cos, sin), + )[0] + + for h in handles: + h.remove() + + # Run quantization per submodule + for name, gpt in gpts.items(): + pruna_logger.info(f"Quantizing {name}...") + gpt.fasterquant( + # TODO: Figure out a way to pass these parameters + # sparsity=smash_config["sparsity"], + # bits=smash_config["weight_bits"], + # max_epoch=smash_config["max_epoch"], + # pre_process=smash_config["pre_process"], + # full_rows=smash_config["full_rows"], + ) + gpt.free() + + inps = [outs] + layer = layer.cpu() + torch.cuda.empty_cache() + + pruna_logger.info("✅ GANQ quantization complete.") + return model + + def import_algorithm_packages(self): + + with SuppressOutput(): + from pruna.algorithms.quantization.backends.ganq import ( + GANQ, + LUTQuant, + find_layers, + ) + + return dict( + GANQ=GANQ, + LUTQuant=LUTQuant, + find_layers=find_layers, + ) From a06cb52f85766e5a3e514119c35682ba3dd48b57 Mon Sep 17 00:00:00 2001 From: Aakash Thatte Date: Sat, 18 Oct 2025 14:59:26 +0530 Subject: [PATCH 6/6] add basic readme --- .../quantization/backends/ganq/README.md | 129 ++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 src/pruna/algorithms/quantization/backends/ganq/README.md diff --git a/src/pruna/algorithms/quantization/backends/ganq/README.md b/src/pruna/algorithms/quantization/backends/ganq/README.md new file mode 100644 index 00000000..325bc8b3 --- /dev/null +++ b/src/pruna/algorithms/quantization/backends/ganq/README.md @@ -0,0 +1,129 @@ +#### Guide to use GANQ quantization + +**Quantize a model** + +```python +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from pruna.config.smash_config import SmashConfig +from pruna.data.pruna_datamodule import PrunaDataModule + + +import torch +from transformers import AutoModelForCausalLM + +import torch 
+from pruna.algorithms.quantization.ganq import GANQQuantizer + +# ------------------------------------------------------------------------- +# 1. Load model and tokenizer +# ------------------------------------------------------------------------- +model_name = "HuggingFaceTB/SmolLM2-135M" +tokenizer = AutoTokenizer.from_pretrained(model_name) +if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token +model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=torch.float16, device_map="auto" +) +model.eval() + +# ------------------------------------------------------------------------- +# 2. Build SmashConfig for Pruna Quantizer +# ------------------------------------------------------------------------- +smash_config = SmashConfig( + batch_size=4, + device="cuda" if torch.cuda.is_available() else "cpu", + cache_dir_prefix="./cache_ganq", +) + +# Add tokenizer +smash_config.add_tokenizer(tokenizer) + +# Use Pruna's built-in WikiText dataset (handles train/val/test splits automatically) +data_module = PrunaDataModule.from_string( + "WikiText", + tokenizer=tokenizer, + collate_fn_args=dict(max_seq_len=256), +) +data_module.limit_datasets(32) # Limit to 32 examples per split for quick testing +smash_config.add_data(data_module) + +# Configure quantizer parameters +smash_config.load_dict( + { + "quantizer": "ganq", + "ganq_weight_bits": 4, + "ganq_max_epoch": 10, + "ganq_pre_process": True, + } +) + +# ------------------------------------------------------------------------- +# 4. Run Quantization +# ------------------------------------------------------------------------- +quantizer = GANQQuantizer() + +quantized_model = quantizer._apply(model, smash_config) + +# ------------------------------------------------------------------------- +# 5. 
Save the quantized model
+# -------------------------------------------------------------------------
+quantized_model.save_pretrained("./ganq_quantized_smollm")
+tokenizer.save_pretrained("./ganq_quantized_smollm")
+
+print("✅ GANQ quantization complete and model saved at ./ganq_quantized_smollm")
+
+
+def model_size_in_mb(model):
+    param_size = 0
+    for param in model.parameters():
+        param_size += param.nelement() * param.element_size()
+    buffer_size = 0
+    for buffer in model.buffers():
+        buffer_size += buffer.nelement() * buffer.element_size()
+    size_all_mb = (param_size + buffer_size) / 1024**2
+    return size_all_mb
+
+
+original_size = model_size_in_mb(model)
+quantized_size = model_size_in_mb(quantized_model)
+print(f"Original model size: {original_size:.2f} MB")
+print(f"Quantized model size: {quantized_size:.2f} MB")
+
+```
+
+
+**Verify if quantization worked**
+
+The logic here is that since GANQ uses a codebook of size (m, L) for a weight matrix of size (m, n), where L is 2^k (k = number of bits), each row in the weight matrix W should only contain values from the corresponding row in the codebook, where selection is driven by the one-hot matrix S. So the number of unique values in each row of W should be exactly L.
+ +```python +from transformers import AutoModelForCausalLM, AutoTokenizer +import torch + +model_name = "HuggingFaceTB/SmolLM2-135M" +tokenizer = AutoTokenizer.from_pretrained(model_name) +if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token +model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=torch.float16, device_map="auto" +) +model.eval() + +model_q = AutoModelForCausalLM.from_pretrained( + "ganq_quantized_smollm" +) + +def verify_unique_entries_in_row(layer, row_idx=0): + Wq = layer.self_attn.q_proj.weight.data + unique_entries = torch.unique(Wq[row_idx]) + print(f"Number of unique entries in row {row_idx}: {unique_entries.numel()}") + +verify_unique_entries_in_row(model_q.model.layers[1], row_idx=1) +verify_unique_entries_in_row(model.model.layers[1], row_idx=1) + +# In my experiments, it gave this: +# Number of unique entries in row 1: 16 (since I used 4-bit quantization) +# Number of unique entries in row 1: 471 +``` \ No newline at end of file