From 0308a64711308c6cf0f30f5ce0d86a087ef7f8a2 Mon Sep 17 00:00:00 2001 From: aki916 Date: Mon, 6 Apr 2026 01:48:23 +0900 Subject: [PATCH 1/2] detect current device and be applicable for mps (apple silicon) --- onecomp/cli.py | 4 +-- onecomp/qep/_qep_config.py | 13 +++++-- onecomp/qep/_quantize_with_qep_arch.py | 3 +- onecomp/quantized_model_loader.py | 5 +-- onecomp/quantizer/_quantizer.py | 34 +++++++++++++++---- onecomp/quantizer/autobit/activation_stats.py | 11 +++--- onecomp/quantizer/gptq/_gptq.py | 12 ++++--- onecomp/quantizer/gptq/gptq_layer.py | 10 ++++-- onecomp/runner.py | 23 ++++++++----- .../runner_methods/chunked_quantization.py | 7 ++-- onecomp/utils/__init__.py | 5 +++ onecomp/utils/device.py | 31 +++++++++++++++++ 12 files changed, 119 insertions(+), 39 deletions(-) create mode 100644 onecomp/utils/device.py diff --git a/onecomp/cli.py b/onecomp/cli.py index 37497a4..949a3a1 100644 --- a/onecomp/cli.py +++ b/onecomp/cli.py @@ -40,8 +40,8 @@ def main(): ) parser.add_argument( "--device", - default="cuda:0", - help="device to place the model on (default: cuda:0)", + default=None, + help="device to place the model on (default: auto-detect CUDA -> MPS -> CPU)", ) parser.add_argument( "--no-qep", diff --git a/onecomp/qep/_qep_config.py b/onecomp/qep/_qep_config.py index c1801e0..7d24cf5 100644 --- a/onecomp/qep/_qep_config.py +++ b/onecomp/qep/_qep_config.py @@ -8,6 +8,8 @@ from dataclasses import dataclass, field +from onecomp.utils.device import get_default_device + @dataclass class QEPConfig: @@ -23,8 +25,9 @@ class QEPConfig: Default is 0.01. perccorr (float): Correction percentage for error propagation. Default is 0.5. - device (str): Device to use for QEP computations (e.g., "cuda"). - Default is "cuda:0". + device (str or None): Device to use for QEP computations + (e.g., "cuda", "mps", "cpu"). When ``None`` (default), + auto-detected at runtime (CUDA > MPS > CPU). exclude_layer_keywords (list[str]): List of keywords to identify layers excluded from error propagation. Layers whose names contain any of these keywords will be excluded. @@ -51,6 +54,10 @@ class QEPConfig: general: bool = False percdamp: float = 0.01 perccorr: float = 0.5 - device: str = "cuda:0" + device: str = None exclude_layer_keywords: list[str] = field(default_factory=lambda: ["mlp.down_proj"]) # TODO: exclude_layer_keywords depends on the architecture and needs to be fixed + + def __post_init__(self): + if self.device is None: + self.device = str(get_default_device()) diff --git a/onecomp/qep/_quantize_with_qep_arch.py b/onecomp/qep/_quantize_with_qep_arch.py index 793b61c..ab94efa 100644 --- a/onecomp/qep/_quantize_with_qep_arch.py +++ b/onecomp/qep/_quantize_with_qep_arch.py @@ -32,6 +32,7 @@ move_kwargs_to_device, expand_kwargs_batch, ) +from onecomp.utils.device import empty_cache logger = getLogger(__name__) @@ -355,6 +356,6 @@ def run_quantize_with_qep_arch( # free memory block_q.cpu() - torch.cuda.empty_cache() + empty_cache(device) quantizer.execute_post_processing() diff --git a/onecomp/quantized_model_loader.py b/onecomp/quantized_model_loader.py index a9ac04f..570be7a 100644 --- a/onecomp/quantized_model_loader.py +++ b/onecomp/quantized_model_loader.py @@ -22,6 +22,7 @@ from .quantizer.dbf.dbf_layer import DoubleBinaryLinear from .quantizer.gptq.config import resolve_gptq_layer_wbits, resolve_gptq_layer_group_size from .quantizer.gptq.gptq_layer import GPTQLinear +from .utils.device import get_default_device from .utils.quant_config import get_quant_param logger = getLogger(__name__) @@ -119,7 +120,7 @@ def load_quantized_model( device_map_resolved = infer_auto_device_map(model) model = dispatch_model(model, device_map=device_map_resolved) except ImportError: - model = model.to("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(get_default_device()) tokenizer = AutoTokenizer.from_pretrained( save_directory, @@ -181,7 +182,7 @@ def load_quantized_model_pt( device_map_resolved = infer_auto_device_map(model) model = dispatch_model(model, device_map=device_map_resolved) except ImportError: - model = model.to("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(get_default_device()) tokenizer = AutoTokenizer.from_pretrained( save_directory, diff --git a/onecomp/quantizer/_quantizer.py b/onecomp/quantizer/_quantizer.py index b252dfe..96781ba 100644 --- a/onecomp/quantizer/_quantizer.py +++ b/onecomp/quantizer/_quantizer.py @@ -17,6 +17,26 @@ import torch from torch.nn import Linear, Conv2d, Conv1d +from onecomp.utils.device import empty_cache + + +def _safe_cholesky(tensor, **kwargs): + if tensor.device.type == "mps": + return torch.linalg.cholesky(tensor.cpu(), **kwargs).to(tensor.device) + return torch.linalg.cholesky(tensor, **kwargs) + + +def _safe_cholesky_inverse(tensor): + if tensor.device.type == "mps": + return torch.cholesky_inverse(tensor.cpu()).to(tensor.device) + return torch.cholesky_inverse(tensor) + + +def _safe_cholesky_solve(b, u): + if b.device.type == "mps": + return torch.cholesky_solve(b.cpu(), u.cpu()).to(b.device) + return torch.cholesky_solve(b, u) + @dataclass class QuantizationResult: @@ -202,7 +222,7 @@ def quantize( result.quantization_time = end_time - start_time self.results[name] = result - torch.cuda.empty_cache() + empty_cache() if self.calc_quant_error: # Record quantization error @@ -248,7 +268,7 @@ def quantize_with_qep( percdamp=percdamp, perccorr=perccorr, ) - torch.cuda.empty_cache() + empty_cache() self.logger.info("Quantizing layer: %s", name) result = self.quantize_layer(module, quant_input_activation, hessian=hessian) @@ -262,7 +282,7 @@ def quantize_with_qep( result.quantization_time = end_time - start_time self.results[name] = result - torch.cuda.empty_cache() + empty_cache() if self.calc_quant_error: # Record quantization error @@ -298,7 +318,7 @@ def _record_quantization_error( result.relative_weight_squared_error, ) = self.calculate_weight_quantization_error(module, dequantized_weight) - torch.cuda.empty_cache() + empty_cache() def adjust_weight( self, @@ -343,9 +363,9 @@ def adjust_weight( damp = percdamp * torch.mean(torch.diag(hessian)) diag = torch.arange(hessian.shape[0], device=hessian.device) hessian[diag, diag] += damp - cholesky = torch.linalg.cholesky(hessian) + cholesky = _safe_cholesky(hessian) rhs = weight @ delta_hatX - delta_weight = torch.cholesky_solve(rhs.t(), cholesky).t() + delta_weight = _safe_cholesky_solve(rhs.t(), cholesky).t() weight = weight + (perccorr * delta_weight) if isinstance(module, Conv1d): @@ -847,7 +867,7 @@ def calculate_output_quantization_error( del batch_diff, batch_X_T - torch.cuda.empty_cache() + empty_cache() # MSE = output_squared_error / (out_features * total_samples) mean_output_squared_error = output_squared_error / num_elements diff --git a/onecomp/quantizer/autobit/activation_stats.py b/onecomp/quantizer/autobit/activation_stats.py index 14bf10f..cd56173 100644 --- a/onecomp/quantizer/autobit/activation_stats.py +++ b/onecomp/quantizer/autobit/activation_stats.py @@ -16,6 +16,7 @@ forward_input, move_kwargs_to_device, ) +from onecomp.utils.device import get_default_device, empty_cache def _find_head_modules(model, blocks): @@ -85,14 +86,14 @@ def collect_activation_stats_blockwise( from onecomp.utils.calibration import prepare_calibration_dataset if device is None: - device = torch.device("cuda") + device = get_default_device() original_device = next(model.parameters()).device if original_device.type != "cpu": if logger: logger.info("Moving model to CPU for block-wise activation collection") model.to("cpu") - torch.cuda.empty_cache() + empty_cache(original_device) model_id = getattr(model.config, "_name_or_path", None) @@ -152,7 +153,7 @@ def collect_activation_stats_blockwise( for h in hooks: h.remove() block.cpu() - torch.cuda.empty_cache() + empty_cache(device) # Collect b_diag if use_curvature_b: @@ -202,7 +203,7 @@ def collect_activation_stats_blockwise( for h in hooks: h.remove() block.cpu() - torch.cuda.empty_cache() + empty_cache(device) a_diag = {} b_diag = {} @@ -270,6 +271,6 @@ def _compute_loss_grad(final_hidden, norm, lm_head, input_ids, device): norm.cpu() lm_head.cpu() - torch.cuda.empty_cache() + empty_cache(device) return torch.cat(all_grads) diff --git a/onecomp/quantizer/gptq/_gptq.py b/onecomp/quantizer/gptq/_gptq.py index a5ef0d7..22e06b1 100644 --- a/onecomp/quantizer/gptq/_gptq.py +++ b/onecomp/quantizer/gptq/_gptq.py @@ -19,8 +19,9 @@ from torch import nn from transformers import Conv1D -from onecomp.quantizer._quantizer import Quantizer, QuantizationResult +from onecomp.quantizer._quantizer import Quantizer, QuantizationResult, _safe_cholesky, _safe_cholesky_inverse from onecomp.utils.quant_config import get_quant_param +from onecomp.utils.device import empty_cache @dataclass @@ -517,9 +518,9 @@ def run_gptq( # pylint: disable=too-many-positional-arguments damp = percdamp * torch.mean(torch.diag(hessian)) diag = torch.arange(hessian.shape[0], device=hessian.device) hessian[diag, diag] += damp - hessian = torch.linalg.cholesky(hessian) - hessian = torch.cholesky_inverse(hessian) - hessian = torch.linalg.cholesky(hessian, upper=True) + hessian = _safe_cholesky(hessian) + hessian = _safe_cholesky_inverse(hessian) + hessian = _safe_cholesky(hessian, upper=True) Hinv = hessian # Accumulate per-group scale/zero for grouped quantization @@ -598,9 +599,10 @@ def run_gptq( # pylint: disable=too-many-positional-arguments zero = quantizer.zero.to(dtype=torch.int32, device="cpu") perm = perm.cpu() if perm is not None else None + _device = quantized_weight.device del hessian, Hinv, matrix_W, Q_int gc.collect() - torch.cuda.empty_cache() + empty_cache(_device) return { "qweight": quantized_weight, diff --git a/onecomp/quantizer/gptq/gptq_layer.py b/onecomp/quantizer/gptq/gptq_layer.py index 97370a7..1db7605 100644 --- a/onecomp/quantizer/gptq/gptq_layer.py +++ b/onecomp/quantizer/gptq/gptq_layer.py @@ -224,7 +224,7 @@ def __init__( # pylint: disable=too-many-positional-arguments zero: torch.Tensor, # FP16 perm: Optional[torch.Tensor] = None, # INT64 bias: Optional[torch.Tensor] = None, - device: str = "cuda", + device: Union[str, torch.device, None] = None, pack_weights: bool = True, # Pack INT weights for memory efficiency use_gemlite: Optional[bool] = None, # GemLite flag ): @@ -236,7 +236,11 @@ def __init__( # pylint: disable=too-many-positional-arguments self.groupsize = groupsize self.actorder = actorder - device = torch.device(device) if isinstance(device, str) else device + if device is None: + from onecomp.utils.device import get_default_device + device = get_default_device() + elif isinstance(device, str): + device = torch.device(device) # Decide whether to use GemLite if use_gemlite is None: @@ -384,7 +388,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @classmethod def from_quantization_result( # pylint: disable=too-many-positional-arguments - cls, result, bias=None, device="cuda", pack_weights=True, use_gemlite=None + cls, result, bias=None, device=None, pack_weights=True, use_gemlite=None ): """ Build GPTQLinear from GPTQResult (quantizer.results). diff --git a/onecomp/runner.py b/onecomp/runner.py index af022e2..0275ca0 100644 --- a/onecomp/runner.py +++ b/onecomp/runner.py @@ -26,6 +26,7 @@ from .utils import calculate_perplexity as calc_perplexity from .utils import prepare_calibration_dataset as prepare_calib_dataset from .log import setup_logger +from .utils import get_default_device, empty_cache class Runner: @@ -368,7 +369,7 @@ def auto_run( wbits: Optional[float] = None, total_vram_gb: Optional[float] = None, groupsize: int = 128, - device: str = "cuda:0", + device: str = None, qep: bool = True, evaluate: bool = True, eval_original_model: bool = False, @@ -394,7 +395,8 @@ def auto_run( automatically. groupsize (int): GPTQ group size (default: 128). Use -1 to disable grouping. - device (str): Device to place the model on (default: "cuda:0"). + device (str or None): Device to place the model on. + When ``None`` (default), auto-detected (CUDA > MPS > CPU). qep (bool): Whether to use QEP (default: True). evaluate (bool): Whether to calculate perplexity and accuracy after quantization (default: True). @@ -443,6 +445,10 @@ def auto_run( setup_logger() logger = getLogger(__name__) + if device is None: + device = str(get_default_device()) + logger.info("Auto-detected device: %s", device) + candidate_bits = (2, 3, 4, 8) if wbits is None: @@ -478,7 +484,8 @@ def auto_run( save_path=save_dir if save_dir is not None else None, enable_fused_groups=True, ) - runner = cls(model_config=model_config, quantizer=quantizer, qep=qep) + qep_config = QEPConfig(device=device) + runner = cls(model_config=model_config, quantizer=quantizer, qep=qep, qep_config=qep_config) runner.run() if evaluate: @@ -1018,7 +1025,7 @@ def _calculate_evaluation( tokenizer = self.model_config.load_tokenizer() original_result = eval_function(model=model, tokenizer=tokenizer, **eval_args) del model, tokenizer - torch.cuda.empty_cache() + empty_cache() if quantized_model: try: @@ -1035,7 +1042,7 @@ def _calculate_evaluation( model.to(self.model_config.device) quantized_result = eval_function(model=model, tokenizer=tokenizer, **eval_args) del model, tokenizer - torch.cuda.empty_cache() + empty_cache() except NotImplementedError: logger.warning( "This quantization method does not support creating a quantized model; " @@ -1050,7 +1057,7 @@ def _calculate_evaluation( self.update_model_weights(model, quantizer=quantizer) dequantized_result = eval_function(model=model, tokenizer=tokenizer, **eval_args) del model, tokenizer - torch.cuda.empty_cache() + empty_cache() return original_result, dequantized_result, quantized_result @@ -1781,7 +1788,7 @@ def analyze_cumulative_error( ) # Release fragmented GPU memory from previous operations (e.g., run()) gc.collect() - torch.cuda.empty_cache() + empty_cache() model = self.model_config.load_model() input_device = next(model.parameters()).device @@ -1805,7 +1812,7 @@ def analyze_cumulative_error( ) # Release fragmented GPU memory from previous operations (e.g., run()) gc.collect() - torch.cuda.empty_cache() + empty_cache() model = self.model_config.load_model() input_device = next(model.parameters()).device diff --git a/onecomp/runner_methods/chunked_quantization.py b/onecomp/runner_methods/chunked_quantization.py index 2946e1e..5fd7d4d 100644 --- a/onecomp/runner_methods/chunked_quantization.py +++ b/onecomp/runner_methods/chunked_quantization.py @@ -34,6 +34,7 @@ from onecomp.model_config import ModelConfig from onecomp.quantizer._quantizer import Quantizer, QuantizationResult from onecomp.utils import prepare_calibration_dataset +from onecomp.utils.device import empty_cache logger = getLogger(__name__) @@ -247,7 +248,7 @@ def hook(_module, input, _output): # pylint: disable=redefined-builtin # Free memory del chunk_inputs - torch.cuda.empty_cache() + empty_cache() logger.info( " Chunk %d/%d done (samples %d-%d)", @@ -318,7 +319,7 @@ def quantize_group(quantizer, group, xtx_dict, nsamples): result.quantization_time = end_time - start_time quantizer.results[name] = result - torch.cuda.empty_cache() + empty_cache() # ============================================================================= @@ -385,4 +386,4 @@ def record_quantization_errors(quantizer, group, xtx_dict, nsamples): else None ) - torch.cuda.empty_cache() + empty_cache() diff --git a/onecomp/utils/__init__.py b/onecomp/utils/__init__.py index 9c45844..abdbf7e 100644 --- a/onecomp/utils/__init__.py +++ b/onecomp/utils/__init__.py @@ -38,3 +38,8 @@ move_kwargs_to_device, expand_kwargs_batch, ) + +from .device import ( + get_default_device, + empty_cache, +) diff --git a/onecomp/utils/device.py b/onecomp/utils/device.py new file mode 100644 index 0000000..ed3eff6 --- /dev/null +++ b/onecomp/utils/device.py @@ -0,0 +1,31 @@ +""" +Device-related utilities for cross-platform support (CUDA / MPS / CPU). + +Copyright 2025-2026 Fujitsu Ltd. + +""" + +import torch + + +def get_default_device() -> torch.device: + """Return the best available device: CUDA > MPS > CPU.""" + if torch.cuda.is_available(): + return torch.device("cuda") + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return torch.device("mps") + return torch.device("cpu") + + +def empty_cache(device: torch.device | str | None = None) -> None: + """Release device memory cache for the given device type. + + Safe to call on any platform — silently does nothing when the + device backend is not available. + """ + device_type = torch.device(device if device is not None else get_default_device()).type + + if device_type == "cuda": + torch.cuda.empty_cache() + elif device_type == "mps" and hasattr(torch.mps, "empty_cache"): + torch.mps.empty_cache() From 1b76811a58f4faea593be872c7ce4f2d1d4b228e Mon Sep 17 00:00:00 2001 From: aki916 Date: Thu, 9 Apr 2026 09:49:57 +0900 Subject: [PATCH 2/2] calculate NLL with cpu even user specify for using mps --- onecomp/utils/perplexity.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/onecomp/utils/perplexity.py b/onecomp/utils/perplexity.py index 82c9505..7f415ad 100644 --- a/onecomp/utils/perplexity.py +++ b/onecomp/utils/perplexity.py @@ -123,7 +123,9 @@ def calculate_perplexity( if stride is None: stride = max_length seq_len = encodings.input_ids.size(1) - nll_sum = torch.tensor(0.0, dtype=torch.float64, device=device) + use_cpu_accum = device.type == "mps" if isinstance(device, torch.device) else str(device).startswith("mps") + accum_device = torch.device("cpu") if use_cpu_accum else device + nll_sum = torch.tensor(0.0, dtype=torch.float64, device=accum_device) n_tokens = 0 prev_end_loc = 0 for begin_loc in tqdm(range(0, seq_len, stride)): @@ -138,7 +140,7 @@ def calculate_perplexity( # N.B. the model only calculates loss over trg_len - 1 labels, # because it internally shifts the labels # to the left by 1. - neg_log_likelihood = outputs.loss.to(torch.float64) + neg_log_likelihood = outputs.loss.to(accum_device).to(torch.float64) # Accumulate the total negative log-likelihood and the total number of tokens num_valid_tokens = ( (target_ids != -100).sum().item()