
metal-quant-ext2

metal-quant-ext2 is a repository for my research on PyTorch MPS kernel extensions written with Apple Metal. The quant in the name reflects its focus on quantization.

The goal is to develop PyTorch MPS extensions for efficient local fine-tuning of Hugging Face models using PyTorch and TRL.

Requirements

  • macOS 15.3.1 or later
  • Python 3.12
  • PyTorch

Usage

pip3 install -r requirements.txt
pip3 install --ignore-installed .
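After installing, a quick sanity check is to confirm that PyTorch can see the MPS backend before calling any of the extension's kernels. This snippet uses only standard PyTorch calls, not this package's API:

```python
import torch

# Check whether PyTorch was built with MPS support and whether
# the backend is usable on this machine (Apple Silicon + recent macOS).
print("MPS built:", torch.backends.mps.is_built())
print("MPS available:", torch.backends.mps.is_available())
```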

Blockwise Quantization 8-bit

blockwise_quant is a function that applies symmetric blockwise 8-bit quantization to a PyTorch tensor.

import torch
from metal_quant_ext2 import blockwise_quant, dequantize

mps_device = torch.device("mps")
cpu_device = torch.device("cpu")

input_tensor = torch.randn(1024, device=mps_device, dtype=torch.float32)

# block_size is assumed to be 64 here; match it to the kernel's actual block size
block_size = 64
num_blocks = input_tensor.numel() // block_size

quantized = torch.empty_like(input_tensor, dtype=torch.int8) # Will inherit device from input_tensor (MPS)

scales = torch.empty(num_blocks, device=cpu_device, dtype=torch.float32)

offsets = torch.empty(num_blocks, device=cpu_device, dtype=torch.float32)

# the actual MTL call
blockwise_quant(input_tensor, quantized, scales, offsets)

print(f"quantized: {quantized}")
assert torch.all(quantized.cpu() >= -127) and torch.all(quantized.cpu() <= 127)

# Dequantize MTL call
scales = scales.to(mps_device)
output = torch.empty_like(input_tensor)
dequantize(quantized, scales, output)

Testing

Check out the test file with assertions: test-blockwise-quant.py

Blockwise Quantization Example

Below is a Python script that helped me understand blockwise quantization:

code-samples/blockwise-quantization.py
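For context, the core idea can be sketched in pure PyTorch on the CPU. This is only an illustration of symmetric blockwise 8-bit quantization, not this repository's Metal implementation; the names blockwise_quant_ref and dequantize_ref and the block size of 64 are my own assumptions:

```python
import torch

def blockwise_quant_ref(x: torch.Tensor, block_size: int = 64):
    """Symmetric blockwise 8-bit quantization (pure-PyTorch reference)."""
    blocks = x.reshape(-1, block_size)
    # One scale per block: map each block's max magnitude onto [-127, 127]
    absmax = blocks.abs().amax(dim=1, keepdim=True)
    scales = absmax / 127.0
    scales = torch.where(scales == 0, torch.ones_like(scales), scales)  # avoid div by zero
    q = torch.clamp(torch.round(blocks / scales), -127, 127).to(torch.int8)
    return q.reshape(x.shape), scales.squeeze(1)

def dequantize_ref(q: torch.Tensor, scales: torch.Tensor, block_size: int = 64):
    """Reconstruct the float tensor from int8 blocks and per-block scales."""
    blocks = q.reshape(-1, block_size).to(torch.float32)
    return (blocks * scales.unsqueeze(1)).reshape(q.shape)

x = torch.randn(1024)
q, s = blockwise_quant_ref(x)
x_hat = dequantize_ref(q, s)
print((x - x_hat).abs().max())  # small round-trip quantization error
```

The round-trip error per element is bounded by half a quantization step (scale / 2), which is why smaller blocks generally give tighter reconstruction at the cost of storing more scales.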