From e3fb475251e8c8050ec83a699de74bb4ed9140b6 Mon Sep 17 00:00:00 2001 From: Aakash Thatte Date: Sun, 9 Nov 2025 23:36:29 +0530 Subject: [PATCH 1/4] add initial implementation for shortgpt --- src/pruna/algorithms/shortgpt.py | 149 +++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 src/pruna/algorithms/shortgpt.py diff --git a/src/pruna/algorithms/shortgpt.py b/src/pruna/algorithms/shortgpt.py new file mode 100644 index 00000000..5f8ded6b --- /dev/null +++ b/src/pruna/algorithms/shortgpt.py @@ -0,0 +1,149 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +import torch +import torch.nn.functional as F +import numpy as np +from tqdm import tqdm +from typing import Any, Dict, List + +from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase +from pruna.algorithms.base.tags import AlgorithmTag as tags +from pruna.config.hyperparameters import Boolean +from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.engine.save import SAVE_FUNCTIONS +from pruna.logging.logger import pruna_logger + +class ShortGPT(PrunaAlgorithmBase): + """ + ShortGPT algorithm for pruning transformer layers using a block influence metric. + + ShortGPT identifies and prunes less important blocks in transformer models based on their + BI scores, which uses the similarity between a layers input and output to measure its importance. + """ + + algorithm_name: str = "shortgpt" + group_tags: list[str] = [tags.PRUNER] + references: dict[str, str] = { + "Paper": "https://arxiv.org/pdf/2403.03853", + } + save_fn = SAVE_FUNCTIONS.pickled + tokenizer_required: bool = True + dataset_required: bool = True + processor_required: bool = False + runs_on: list[str] = ["cuda", "cpu"] + + def get_hyperparameters(self) -> list: + from ConfigSpace import CategoricalHyperparameter, UniformFloatHyperparameter, UniformIntegerHyperparameter + return [ + CategoricalHyperparameter( + "metric_type", ["BI"], default_value="BI", + meta=dict(desc="Metric type for layer importance: Block Influence") + ), + UniformFloatHyperparameter( + "prune_ratio", lower=0.0, upper=0.8, default_value=0.25, + meta=dict(desc="Fraction of layers to prune") + ), + Boolean("angular", meta=dict(desc="Use angular distance for BI computation")), + UniformIntegerHyperparameter( + "calibration_samples", lower=8, upper=512, default_value=64, + meta=dict(desc="Number of calibration samples to compute metrics") + ), + ] + + + @staticmethod + @torch.inference_mode() + def compute_block_influence(model, tokenizer, texts, angular=False, device="cuda", max_samples=64): + model.eval().to(device) + num_layers = len(model.model.layers) + bis = torch.zeros(num_layers + 1, device=device) + counts = 0 + + for text in tqdm(texts[:max_samples], desc="Computing Block Influence"): + inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device) + input_ids = inputs["input_ids"] + hiddens = [] + + def hook_fn(_, __, out): + if isinstance(out, tuple): out = out[0] + hiddens.append(out) + + handles = [layer.register_forward_hook(hook_fn) for layer in model.model.layers] + _ = model(input_ids=input_ids) + for h in handles: h.remove() + + hiddens.insert(0, model.model.embed_tokens(input_ids)) + hiddens.append(model.model.norm(hiddens[-1])) + + for i in range(len(hiddens) - 1): + in_h, out_h = hiddens[i].float(), hiddens[i + 1].float() + cos = F.cosine_similarity( + in_h.view(-1, in_h.shape[-1]), + out_h.view(-1, out_h.shape[-1]), + dim=-1 + ) + if angular: + cos = cos.clamp(-1 + 1e-7, 1 - 1e-7) + bi = torch.acos(cos).mean() / np.pi + else: + bi = (1 - cos).mean() + bis[i] += bi + counts += 1 + + bis /= counts + return bis.tolist() + + def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: + device = smash_config["device"] + model = model.to(device) + model.eval() + + pruna_logger.info(f"[ShortGPT] Starting layer pruning for model on device: {device}") + pruna_logger.info(f"[ShortGPT] Model depth: {len(model.model.layers)}") + pruna_logger.info(f"[ShortGPT] Model parameters: {sum(p.numel() for p in model.parameters()) / 1_000_000:.2f}M") + tokenizer = smash_config["tokenizer"] + + texts = smash_config["texts"] + + metric_type = smash_config["metric_type"] + prune_ratio = smash_config["prune_ratio"] + angular = smash_config["angular"] + + pruna_logger.info(f"[ShortGPT] Running {metric_type}-based layer pruning (ratio={prune_ratio:.2f})") + + scores = self.compute_block_influence(model, tokenizer, texts, angular=angular, device=device) + + num_layers = len(model.model.layers) + n_prune = int(prune_ratio * num_layers) + layer_scores = np.array(scores[1:num_layers+1]) # skip embedding span + + prune_indices = np.argsort(layer_scores)[:n_prune].tolist() + keep_indices = [i for i in range(num_layers) if i not in prune_indices] + + pruna_logger.info(f"[ShortGPT] Pruning {n_prune}/{num_layers} layers: {prune_indices}") + + kept_layers = torch.nn.ModuleList([layer for i, layer in enumerate(model.model.layers) if i in keep_indices]) + model.model.layers = kept_layers + + pruna_logger.info(f"[ShortGPT] Pruned model depth: {len(model.model.layers)}") + pruna_logger.info(f"[ShortGPT] Pruned model parameters: {sum(p.numel() for p in model.parameters()) / 1_000_000:.2f}M") + + return model + + + def model_check_fn(self, model): + return isinstance(model, torch.nn.Module) + From 393163d003b41d1f6dfaafafa751636496449ae1 Mon Sep 17 00:00:00 2001 From: Aakash Thatte Date: Wed, 12 Nov 2025 20:30:33 +0530 Subject: [PATCH 2/4] support dataloder --- src/pruna/algorithms/shortgpt.py | 109 ++++++++++++++++++++++--------- 1 file changed, 78 insertions(+), 31 deletions(-) diff --git a/src/pruna/algorithms/shortgpt.py b/src/pruna/algorithms/shortgpt.py index 5f8ded6b..8d0de180 100644 --- a/src/pruna/algorithms/shortgpt.py +++ b/src/pruna/algorithms/shortgpt.py @@ -13,11 +13,14 @@ # limitations under the License. from __future__ import annotations -import torch -import torch.nn.functional as F + +from typing import Any + import numpy as np +import torch +import torch.nn.functional as f +from ConfigSpace import CategoricalHyperparameter, UniformFloatHyperparameter from tqdm import tqdm -from typing import Any, Dict, List from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase from pruna.algorithms.base.tags import AlgorithmTag as tags @@ -26,11 +29,12 @@ from pruna.engine.save import SAVE_FUNCTIONS from pruna.logging.logger import pruna_logger + class ShortGPT(PrunaAlgorithmBase): """ ShortGPT algorithm for pruning transformer layers using a block influence metric. - ShortGPT identifies and prunes less important blocks in transformer models based on their + ShortGPT identifies and prunes less important blocks in transformer models based on their BI scores, which uses the similarity between a layers input and output to measure its importance. """ @@ -46,54 +50,82 @@ class ShortGPT(PrunaAlgorithmBase): runs_on: list[str] = ["cuda", "cpu"] def get_hyperparameters(self) -> list: - from ConfigSpace import CategoricalHyperparameter, UniformFloatHyperparameter, UniformIntegerHyperparameter + """ + Configure all algorithm-specific hyperparameters with ConfigSpace. + + Returns + ------- + list + The hyperparameters. + """ return [ CategoricalHyperparameter( - "metric_type", ["BI"], default_value="BI", - meta=dict(desc="Metric type for layer importance: Block Influence") + "metric_type", + ["BI"], + default_value="BI", + meta=dict(desc="Metric type for layer importance: Block Influence"), ), UniformFloatHyperparameter( - "prune_ratio", lower=0.0, upper=0.8, default_value=0.25, - meta=dict(desc="Fraction of layers to prune") + "prune_ratio", + lower=0.0, + upper=0.8, + default_value=0.25, + meta=dict(desc="Fraction of layers to prune"), ), Boolean("angular", meta=dict(desc="Use angular distance for BI computation")), - UniformIntegerHyperparameter( - "calibration_samples", lower=8, upper=512, default_value=64, - meta=dict(desc="Number of calibration samples to compute metrics") - ), ] - @staticmethod @torch.inference_mode() - def compute_block_influence(model, tokenizer, texts, angular=False, device="cuda", max_samples=64): + def compute_block_influence(model, tokenizer, dataloader, angular=False, device="cuda"): + """ + Compute the block influence scores for each transformer layer in the model. + + The block influence score for a layer is given as 1 - the cosine similarity + between the layer's input and output activations, averaged over the dataset. + """ model.eval().to(device) num_layers = len(model.model.layers) bis = torch.zeros(num_layers + 1, device=device) counts = 0 - for text in tqdm(texts[:max_samples], desc="Computing Block Influence"): - inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device) + for batch_idx, batch in enumerate(tqdm(dataloader, desc="Computing Block Influence")): + if isinstance(batch, dict) and "text" in batch: + texts = batch["text"] + elif isinstance(batch, list): + texts = batch + else: + raise ValueError(f"Unsupported batch type: {type(batch)}") + + inputs = tokenizer( + texts, + return_tensors="pt", + truncation=True, + max_length=512, + padding=True, + ).to(device) input_ids = inputs["input_ids"] hiddens = [] def hook_fn(_, __, out): - if isinstance(out, tuple): out = out[0] + if isinstance(out, tuple): + out = out[0] hiddens.append(out) handles = [layer.register_forward_hook(hook_fn) for layer in model.model.layers] _ = model(input_ids=input_ids) - for h in handles: h.remove() + for h in handles: + h.remove() hiddens.insert(0, model.model.embed_tokens(input_ids)) hiddens.append(model.model.norm(hiddens[-1])) for i in range(len(hiddens) - 1): in_h, out_h = hiddens[i].float(), hiddens[i + 1].float() - cos = F.cosine_similarity( + cos = f.cosine_similarity( in_h.view(-1, in_h.shape[-1]), out_h.view(-1, out_h.shape[-1]), - dim=-1 + dim=-1, ) if angular: cos = cos.clamp(-1 + 1e-7, 1 - 1e-7) @@ -105,7 +137,7 @@ def hook_fn(_, __, out): bis /= counts return bis.tolist() - + def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: device = smash_config["device"] model = model.to(device) @@ -115,35 +147,50 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: pruna_logger.info(f"[ShortGPT] Model depth: {len(model.model.layers)}") pruna_logger.info(f"[ShortGPT] Model parameters: {sum(p.numel() for p in model.parameters()) / 1_000_000:.2f}M") tokenizer = smash_config["tokenizer"] - - texts = smash_config["texts"] - metric_type = smash_config["metric_type"] + dataloader = smash_config["train_dataloader"] prune_ratio = smash_config["prune_ratio"] angular = smash_config["angular"] - pruna_logger.info(f"[ShortGPT] Running {metric_type}-based layer pruning (ratio={prune_ratio:.2f})") + pruna_logger.info(f"[ShortGPT] Running layer pruning (ratio={prune_ratio:.2f})") - scores = self.compute_block_influence(model, tokenizer, texts, angular=angular, device=device) + scores = self.compute_block_influence(model, tokenizer, dataloader, angular=angular, device=device) num_layers = len(model.model.layers) n_prune = int(prune_ratio * num_layers) - layer_scores = np.array(scores[1:num_layers+1]) # skip embedding span + + # not using the final norm layer score, because paper only mentions only transformer layers # noqa + # TODO: Should we even compute the norm layer score? # noqa + layer_scores = np.array(scores[:num_layers]) prune_indices = np.argsort(layer_scores)[:n_prune].tolist() keep_indices = [i for i in range(num_layers) if i not in prune_indices] pruna_logger.info(f"[ShortGPT] Pruning {n_prune}/{num_layers} layers: {prune_indices}") + pruna_logger.info(f"[ShortGPT] Removing layers: {prune_indices}") kept_layers = torch.nn.ModuleList([layer for i, layer in enumerate(model.model.layers) if i in keep_indices]) model.model.layers = kept_layers pruna_logger.info(f"[ShortGPT] Pruned model depth: {len(model.model.layers)}") - pruna_logger.info(f"[ShortGPT] Pruned model parameters: {sum(p.numel() for p in model.parameters()) / 1_000_000:.2f}M") + pruna_logger.info( + f"[ShortGPT] Pruned model parameters: {sum(p.numel() for p in model.parameters()) / 1_000_000:.2f}M" + ) return model - def model_check_fn(self, model): + """ + Check if the model is a torch.nn.Module. + + Parameters + ---------- + model : Any + The model to check. + + Returns + ------- + bool + True if the model is a torch.nn.Module, False otherwise. + """ return isinstance(model, torch.nn.Module) - From 2b8b0192d643058e06bc5a58c12fad699979e6a3 Mon Sep 17 00:00:00 2001 From: Aakash Thatte Date: Wed, 12 Nov 2025 20:32:38 +0530 Subject: [PATCH 3/4] simplify capturing hidden states --- src/pruna/algorithms/shortgpt.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/src/pruna/algorithms/shortgpt.py b/src/pruna/algorithms/shortgpt.py index 8d0de180..03d7781b 100644 --- a/src/pruna/algorithms/shortgpt.py +++ b/src/pruna/algorithms/shortgpt.py @@ -105,20 +105,9 @@ def compute_block_influence(model, tokenizer, dataloader, angular=False, device= padding=True, ).to(device) input_ids = inputs["input_ids"] - hiddens = [] - def hook_fn(_, __, out): - if isinstance(out, tuple): - out = out[0] - hiddens.append(out) - - handles = [layer.register_forward_hook(hook_fn) for layer in model.model.layers] - _ = model(input_ids=input_ids) - for h in handles: - h.remove() - - hiddens.insert(0, model.model.embed_tokens(input_ids)) - hiddens.append(model.model.norm(hiddens[-1])) + outputs = model(input_ids=input_ids, output_hidden_states=True) + hiddens = list(outputs.hidden_states) for i in range(len(hiddens) - 1): in_h, out_h = hiddens[i].float(), hiddens[i + 1].float() From 15bc073f9bcf9600535762067efccc93bd3e8156 Mon Sep 17 00:00:00 2001 From: Aakash Thatte Date: Wed, 12 Nov 2025 21:18:39 +0530 Subject: [PATCH 4/4] add todo --- src/pruna/algorithms/shortgpt.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pruna/algorithms/shortgpt.py b/src/pruna/algorithms/shortgpt.py index 03d7781b..fd458c99 100644 --- a/src/pruna/algorithms/shortgpt.py +++ b/src/pruna/algorithms/shortgpt.py @@ -89,6 +89,8 @@ def compute_block_influence(model, tokenizer, dataloader, angular=False, device= bis = torch.zeros(num_layers + 1, device=device) counts = 0 + # TODO: Discuss if we should keep clearing device cache in case of gpu, + # because model and data keep moving to device for batch_idx, batch in enumerate(tqdm(dataloader, desc="Computing Block Influence")): if isinstance(batch, dict) and "text" in batch: texts = batch["text"]