Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions onecomp/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def main():
)
parser.add_argument(
"--device",
default="cuda:0",
help="device to place the model on (default: cuda:0)",
default=None,
help="device to place the model on (default: auto-detect CUDA -> MPS -> CPU)",
)
parser.add_argument(
"--no-qep",
Expand Down
13 changes: 10 additions & 3 deletions onecomp/qep/_qep_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from dataclasses import dataclass, field

from onecomp.utils.device import get_default_device


@dataclass
class QEPConfig:
Expand All @@ -23,8 +25,9 @@ class QEPConfig:
Default is 0.01.
perccorr (float): Correction percentage for error propagation.
Default is 0.5.
device (str): Device to use for QEP computations (e.g., "cuda").
Default is "cuda:0".
device (str or None): Device to use for QEP computations
(e.g., "cuda", "mps", "cpu"). When ``None`` (default),
auto-detected at runtime (CUDA > MPS > CPU).
exclude_layer_keywords (list[str]): List of keywords to identify
layers excluded from error propagation. Layers whose names
contain any of these keywords will be excluded.
Expand All @@ -51,6 +54,10 @@ class QEPConfig:
general: bool = False
percdamp: float = 0.01
perccorr: float = 0.5
device: str = "cuda:0"
device: str = None
exclude_layer_keywords: list[str] = field(default_factory=lambda: ["mlp.down_proj"])
# TODO: exclude_layer_keywords depends on the architecture and needs to be fixed

def __post_init__(self):
if self.device is None:
self.device = str(get_default_device())
3 changes: 2 additions & 1 deletion onecomp/qep/_quantize_with_qep_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
move_kwargs_to_device,
expand_kwargs_batch,
)
from onecomp.utils.device import empty_cache

logger = getLogger(__name__)

Expand Down Expand Up @@ -355,6 +356,6 @@ def run_quantize_with_qep_arch(

# free memory
block_q.cpu()
torch.cuda.empty_cache()
empty_cache(device)

quantizer.execute_post_processing()
5 changes: 3 additions & 2 deletions onecomp/quantized_model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .quantizer.dbf.dbf_layer import DoubleBinaryLinear
from .quantizer.gptq.config import resolve_gptq_layer_wbits, resolve_gptq_layer_group_size
from .quantizer.gptq.gptq_layer import GPTQLinear
from .utils.device import get_default_device
from .utils.quant_config import get_quant_param

logger = getLogger(__name__)
Expand Down Expand Up @@ -119,7 +120,7 @@ def load_quantized_model(
device_map_resolved = infer_auto_device_map(model)
model = dispatch_model(model, device_map=device_map_resolved)
except ImportError:
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(get_default_device())

tokenizer = AutoTokenizer.from_pretrained(
save_directory,
Expand Down Expand Up @@ -181,7 +182,7 @@ def load_quantized_model_pt(
device_map_resolved = infer_auto_device_map(model)
model = dispatch_model(model, device_map=device_map_resolved)
except ImportError:
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(get_default_device())

tokenizer = AutoTokenizer.from_pretrained(
save_directory,
Expand Down
34 changes: 27 additions & 7 deletions onecomp/quantizer/_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,26 @@
import torch
from torch.nn import Linear, Conv2d, Conv1d

from onecomp.utils.device import empty_cache


def _safe_cholesky(tensor, **kwargs):
if tensor.device.type == "mps":
return torch.linalg.cholesky(tensor.cpu(), **kwargs).to(tensor.device)
return torch.linalg.cholesky(tensor, **kwargs)


def _safe_cholesky_inverse(tensor):
if tensor.device.type == "mps":
return torch.cholesky_inverse(tensor.cpu()).to(tensor.device)
return torch.cholesky_inverse(tensor)


def _safe_cholesky_solve(b, u):
if b.device.type == "mps":
return torch.cholesky_solve(b.cpu(), u.cpu()).to(b.device)
return torch.cholesky_solve(b, u)


@dataclass
class QuantizationResult:
Expand Down Expand Up @@ -202,7 +222,7 @@ def quantize(
result.quantization_time = end_time - start_time

self.results[name] = result
torch.cuda.empty_cache()
empty_cache()

if self.calc_quant_error:
# Record quantization error
Expand Down Expand Up @@ -248,7 +268,7 @@ def quantize_with_qep(
percdamp=percdamp,
perccorr=perccorr,
)
torch.cuda.empty_cache()
empty_cache()

self.logger.info("Quantizing layer: %s", name)
result = self.quantize_layer(module, quant_input_activation, hessian=hessian)
Expand All @@ -262,7 +282,7 @@ def quantize_with_qep(
result.quantization_time = end_time - start_time

self.results[name] = result
torch.cuda.empty_cache()
empty_cache()

if self.calc_quant_error:
# Record quantization error
Expand Down Expand Up @@ -298,7 +318,7 @@ def _record_quantization_error(
result.relative_weight_squared_error,
) = self.calculate_weight_quantization_error(module, dequantized_weight)

torch.cuda.empty_cache()
empty_cache()

def adjust_weight(
self,
Expand Down Expand Up @@ -343,9 +363,9 @@ def adjust_weight(
damp = percdamp * torch.mean(torch.diag(hessian))
diag = torch.arange(hessian.shape[0], device=hessian.device)
hessian[diag, diag] += damp
cholesky = torch.linalg.cholesky(hessian)
cholesky = _safe_cholesky(hessian)
rhs = weight @ delta_hatX
delta_weight = torch.cholesky_solve(rhs.t(), cholesky).t()
delta_weight = _safe_cholesky_solve(rhs.t(), cholesky).t()
weight = weight + (perccorr * delta_weight)

if isinstance(module, Conv1d):
Expand Down Expand Up @@ -847,7 +867,7 @@ def calculate_output_quantization_error(

del batch_diff, batch_X_T

torch.cuda.empty_cache()
empty_cache()

# MSE = output_squared_error / (out_features * total_samples)
mean_output_squared_error = output_squared_error / num_elements
Expand Down
11 changes: 6 additions & 5 deletions onecomp/quantizer/autobit/activation_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
forward_input,
move_kwargs_to_device,
)
from onecomp.utils.device import get_default_device, empty_cache


def _find_head_modules(model, blocks):
Expand Down Expand Up @@ -85,14 +86,14 @@ def collect_activation_stats_blockwise(
from onecomp.utils.calibration import prepare_calibration_dataset

if device is None:
device = torch.device("cuda")
device = get_default_device()

original_device = next(model.parameters()).device
if original_device.type != "cpu":
if logger:
logger.info("Moving model to CPU for block-wise activation collection")
model.to("cpu")
torch.cuda.empty_cache()
empty_cache(original_device)

model_id = getattr(model.config, "_name_or_path", None)

Expand Down Expand Up @@ -152,7 +153,7 @@ def collect_activation_stats_blockwise(
for h in hooks:
h.remove()
block.cpu()
torch.cuda.empty_cache()
empty_cache(device)

# Collect b_diag
if use_curvature_b:
Expand Down Expand Up @@ -202,7 +203,7 @@ def collect_activation_stats_blockwise(
for h in hooks:
h.remove()
block.cpu()
torch.cuda.empty_cache()
empty_cache(device)

a_diag = {}
b_diag = {}
Expand Down Expand Up @@ -270,6 +271,6 @@ def _compute_loss_grad(final_hidden, norm, lm_head, input_ids, device):

norm.cpu()
lm_head.cpu()
torch.cuda.empty_cache()
empty_cache(device)

return torch.cat(all_grads)
12 changes: 7 additions & 5 deletions onecomp/quantizer/gptq/_gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
from torch import nn
from transformers import Conv1D

from onecomp.quantizer._quantizer import Quantizer, QuantizationResult
from onecomp.quantizer._quantizer import Quantizer, QuantizationResult, _safe_cholesky, _safe_cholesky_inverse
from onecomp.utils.quant_config import get_quant_param
from onecomp.utils.device import empty_cache


@dataclass
Expand Down Expand Up @@ -517,9 +518,9 @@ def run_gptq( # pylint: disable=too-many-positional-arguments
damp = percdamp * torch.mean(torch.diag(hessian))
diag = torch.arange(hessian.shape[0], device=hessian.device)
hessian[diag, diag] += damp
hessian = torch.linalg.cholesky(hessian)
hessian = torch.cholesky_inverse(hessian)
hessian = torch.linalg.cholesky(hessian, upper=True)
hessian = _safe_cholesky(hessian)
hessian = _safe_cholesky_inverse(hessian)
hessian = _safe_cholesky(hessian, upper=True)
Hinv = hessian

# Accumulate per-group scale/zero for grouped quantization
Expand Down Expand Up @@ -598,9 +599,10 @@ def run_gptq( # pylint: disable=too-many-positional-arguments
zero = quantizer.zero.to(dtype=torch.int32, device="cpu")
perm = perm.cpu() if perm is not None else None

_device = quantized_weight.device
del hessian, Hinv, matrix_W, Q_int
gc.collect()
torch.cuda.empty_cache()
empty_cache(_device)

return {
"qweight": quantized_weight,
Expand Down
10 changes: 7 additions & 3 deletions onecomp/quantizer/gptq/gptq_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
zero: torch.Tensor, # FP16
perm: Optional[torch.Tensor] = None, # INT64
bias: Optional[torch.Tensor] = None,
device: str = "cuda",
device: Union[str, torch.device, None] = None,
pack_weights: bool = True, # Pack INT weights for memory efficiency
use_gemlite: Optional[bool] = None, # GemLite flag
):
Expand All @@ -236,7 +236,11 @@ def __init__( # pylint: disable=too-many-positional-arguments
self.groupsize = groupsize
self.actorder = actorder

device = torch.device(device) if isinstance(device, str) else device
if device is None:
from onecomp.utils.device import get_default_device
device = get_default_device()
elif isinstance(device, str):
device = torch.device(device)

# Decide whether to use GemLite
if use_gemlite is None:
Expand Down Expand Up @@ -384,7 +388,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

@classmethod
def from_quantization_result( # pylint: disable=too-many-positional-arguments
cls, result, bias=None, device="cuda", pack_weights=True, use_gemlite=None
cls, result, bias=None, device=None, pack_weights=True, use_gemlite=None
):
"""
Build GPTQLinear from GPTQResult (quantizer.results).
Expand Down
23 changes: 15 additions & 8 deletions onecomp/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .utils import calculate_perplexity as calc_perplexity
from .utils import prepare_calibration_dataset as prepare_calib_dataset
from .log import setup_logger
from .utils import get_default_device, empty_cache


class Runner:
Expand Down Expand Up @@ -368,7 +369,7 @@ def auto_run(
wbits: Optional[float] = None,
total_vram_gb: Optional[float] = None,
groupsize: int = 128,
device: str = "cuda:0",
device: str = None,
qep: bool = True,
evaluate: bool = True,
eval_original_model: bool = False,
Expand All @@ -394,7 +395,8 @@ def auto_run(
automatically.
groupsize (int): GPTQ group size (default: 128).
Use -1 to disable grouping.
device (str): Device to place the model on (default: "cuda:0").
device (str or None): Device to place the model on.
When ``None`` (default), auto-detected (CUDA > MPS > CPU).
qep (bool): Whether to use QEP (default: True).
evaluate (bool): Whether to calculate perplexity and
accuracy after quantization (default: True).
Expand Down Expand Up @@ -443,6 +445,10 @@ def auto_run(
setup_logger()
logger = getLogger(__name__)

if device is None:
device = str(get_default_device())
logger.info("Auto-detected device: %s", device)

candidate_bits = (2, 3, 4, 8)

if wbits is None:
Expand Down Expand Up @@ -478,7 +484,8 @@ def auto_run(
save_path=save_dir if save_dir is not None else None,
enable_fused_groups=True,
)
runner = cls(model_config=model_config, quantizer=quantizer, qep=qep)
qep_config = QEPConfig(device=device)
runner = cls(model_config=model_config, quantizer=quantizer, qep=qep, qep_config=qep_config)
runner.run()

if evaluate:
Expand Down Expand Up @@ -1018,7 +1025,7 @@ def _calculate_evaluation(
tokenizer = self.model_config.load_tokenizer()
original_result = eval_function(model=model, tokenizer=tokenizer, **eval_args)
del model, tokenizer
torch.cuda.empty_cache()
empty_cache()

if quantized_model:
try:
Expand All @@ -1035,7 +1042,7 @@ def _calculate_evaluation(
model.to(self.model_config.device)
quantized_result = eval_function(model=model, tokenizer=tokenizer, **eval_args)
del model, tokenizer
torch.cuda.empty_cache()
empty_cache()
except NotImplementedError:
logger.warning(
"This quantization method does not support creating a quantized model; "
Expand All @@ -1050,7 +1057,7 @@ def _calculate_evaluation(
self.update_model_weights(model, quantizer=quantizer)
dequantized_result = eval_function(model=model, tokenizer=tokenizer, **eval_args)
del model, tokenizer
torch.cuda.empty_cache()
empty_cache()

return original_result, dequantized_result, quantized_result

Expand Down Expand Up @@ -1781,7 +1788,7 @@ def analyze_cumulative_error(
)
# Release fragmented GPU memory from previous operations (e.g., run())
gc.collect()
torch.cuda.empty_cache()
empty_cache()

model = self.model_config.load_model()
input_device = next(model.parameters()).device
Expand All @@ -1805,7 +1812,7 @@ def analyze_cumulative_error(
)
# Release fragmented GPU memory from previous operations (e.g., run())
gc.collect()
torch.cuda.empty_cache()
empty_cache()

model = self.model_config.load_model()
input_device = next(model.parameters()).device
Expand Down
Loading