05_weight_quantization_math_.md
Welcome back to the TinyQ tutorial! In our journey so far, we've seen how the Quantizer (Chapter 1) orchestrates the process, how different methods like W8A32 and W8A16 specify the target precision (Chapter 2), how Custom Quantized Layers are built to handle lower precision numbers, and how TinyQ replaces standard layers with these custom ones (Chapter 4).
Now, let's get to the heart of the conversion process itself: the Weight Quantization Math. This is the specific set of calculations that takes the original, detailed float32 weights and squeezes them into the much smaller int8 format used by our custom layers.
Imagine you have a very detailed measurement, like 5.3768214. A standard float32 number can store this level of detail. Quantization to int8 means we need to represent this number using a whole number between -128 and 127. How do you convert 5.3768214 into a single integer like 5 or 6 without losing too much information?
Directly rounding might work sometimes, but it doesn't account for the range of the original numbers. What if your numbers range from -1000 to 1000? Simply rounding them to integers between -128 and 127 won't work correctly!
We need a way to map the original range of floating-point values onto the limited range of int8 integer values. This mapping is done using a simple mathematical formula:
Real Value ≈ Scale * Quantized Value + Zero Point
Let's break this down:
- **Real Value**: the original `float32` number we want to represent (or recover).
- **Quantized Value**: the low-precision integer (`int8`) that stores the compressed information.
- **Scale**: a multiplication factor. It's like the 'step size' or 'conversion rate' between the integer steps and the original real values. A larger scale means each integer step (`int8` value) represents a larger step in the original floating-point range.
- **Zero Point**: an offset. It tells you which integer value corresponds to the real value `0`, which helps handle ranges that are not centered around zero.
To quantize means: given a Real Value, calculate the Quantized Value (and find the Scale and Zero Point).
To dequantize means: given a Quantized Value, Scale, and Zero Point, calculate the approximate Real Value.
TinyQ focuses on a common and relatively simple type of quantization for weights: Symmetric Quantization.
Symmetric quantization makes a simplifying assumption: the range of original values is centered around zero. This allows us to set the Zero Point to 0.
The formula becomes much simpler:
Real Value ≈ Scale * Quantized Value
And conversely, to quantize:
Quantized Value ≈ Real Value / Scale
The main challenge then becomes finding the correct Scale. In symmetric quantization, the scale is typically calculated based on the maximum absolute value in the original floating-point tensor. If the range of int8 is [-Q_max, Q_max], the scale is simply:
Scale = Max Absolute Real Value / Q_max
For torch.int8, Q_max is typically 127 (the maximum positive value). So the formula is often Scale = Max Absolute Real Value / 127.
Once you have the scale, you quantize by dividing the original values by the scale and rounding to the nearest integer: quantized_value = round(real_value / scale). These rounded values are then clamped to the valid int8 range [-128, 127].
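These two steps can be sketched in plain Python (illustrative values, no PyTorch; the variable names here are mine, not TinyQ's):

```python
# Symmetric int8 quantization of a few example values (illustrative sketch).
values = [5.3768214, -2.5, 0.0, 10.0, -10.0]

q_max = 127                                   # largest positive int8 value
r_max = max(abs(v) for v in values)           # 10.0
scale = r_max / q_max                         # ~0.0787

def quantize(v, scale):
    # round(real / scale), then clamp to the int8 range [-128, 127]
    return max(-128, min(127, round(v / scale)))

quantized = [quantize(v, scale) for v in values]
print(quantized)                              # [68, -32, 0, 127, -127]

# Dequantize: real ≈ scale * quantized
recovered = [scale * q for q in quantized]
print(round(recovered[0], 4))                 # ~5.3543, close to 5.3768214
```

Note how `10.0` maps to the full-range value `127`, while `5.3768214` comes back as roughly `5.3543`: quantization is lossy, but the error stays within half a step size.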
Neural network weights often have different ranges of values across different output channels (or input features, depending on the layer type and convention). For example, the weights producing output channel 1 might range from -0.1 to 0.1, while the weights for output channel 2 might range from -10.0 to 10.0.
If you calculate a single Scale based on the maximum absolute value across the entire weight tensor, the scale will be determined by the channel with the largest range (e.g., the channel ranging up to 10.0). Applying this large scale to channels with smaller ranges (like -0.1 to 0.1) will make most of their quantized values cluster tightly around zero, losing a lot of precision.
Per-channel quantization solves this by calculating a separate Scale for each output channel (each row of the weight matrix in a typical nn.Linear layer). This allows each channel's unique range to be mapped effectively onto the full int8 range [-128, 127], preserving more information per channel.
This is the method TinyQ uses for weights in both W8A32 and W8A16.
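To see why per-channel scales matter, here is a toy comparison in plain Python (the weight values are made up for illustration):

```python
# Two "channels" (rows) with very different ranges.
weights = [
    [0.1, -0.06, 0.08],   # channel 0: small values
    [10.0, -7.5, 3.2],    # channel 1: large values
]

q_max = 127

# Per-tensor: one scale, driven by the channel with the largest value (10.0)
tensor_scale = max(abs(v) for row in weights for v in row) / q_max
ch0_per_tensor = [round(v / tensor_scale) for v in weights[0]]
print(ch0_per_tensor)     # [1, -1, 1] -- almost all precision is lost

# Per-channel: channel 0 gets its own scale and spans the full int8 range
ch0_scale = max(abs(v) for v in weights[0]) / q_max
ch0_per_channel = [round(v / ch0_scale) for v in weights[0]]
print(ch0_per_channel)    # [127, -76, 102]
```

With a single per-tensor scale, the small-range channel collapses to a handful of integers near zero; with its own scale, it uses the whole `int8` range.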
In TinyQ, the specific implementation of this per-channel symmetric quantization math for weights is handled by the function linear_q_symmetric_per_channel in utils.py.
This function takes the original float32 weights of an nn.Linear layer and calculates the int8 weights and the per-channel scales. It's a key piece of the puzzle, used by the quantize() method within the Custom Quantized Layers.
You, as a user, don't call linear_q_symmetric_per_channel directly. This math is performed automatically when the Quantizer (Chapter 1) calls the replace_linear_with_target_and_quantize function (Chapter 4), which in turn calls the quantize() method on the newly created Custom Quantized Layers.
Let's revisit the flow from Chapter 4 and zoom into the "quantize weights" step:
```mermaid
sequenceDiagram
    participant Q as Quantizer
    participant R as replace_linear...
    participant OriginalLinear as nn.Linear Layer
    participant CustomLayer as W8A32LinearLayer<br/>or W8A16LinearLayer
    participant QuantizeMethod as CustomLayer.quantize()
    participant QuantMath as linear_q_symmetric_per_channel

    Q->>R: Start replacement<br/>(module=Model, target_class=...)
    R->>R: Traverse model structure
    R->>OriginalLinear: Find an nn.Linear layer
    R->>OriginalLinear: Get weights & bias
    R->>CustomLayer: Create instance(...)
    R->>QuantizeMethod: Call quantize(original_weights)
    QuantizeMethod->>QuantMath: Call linear_q_symmetric_per_channel(original_weights, ...)
    QuantMath-->>QuantizeMethod: Return int8_weights, scales
    QuantizeMethod->>CustomLayer: Store int8_weights, scales, zero_points
    QuantizeMethod-->>R: Finished quantizing
    R->>CustomLayer: Copy original_bias
    R->>R: Replace OriginalLinear<br/>with CustomLayer in model
    Note over R: Continues for all layers...
    R-->>Q: Return modified Module
```
As you can see, linear_q_symmetric_per_channel is called by the custom layer's quantize() method, which is itself called during the layer replacement process.
Let's examine the core utils.py functions involved in symmetric per-channel quantization.
First, a helper function to calculate the scale for a single tensor (or a single channel tensor):
```python
# From utils.py
import torch

def get_q_scale_symmetric(tensor, dtype=torch.int8):
    """Calculate symmetric quantization scale.

    Args:
        tensor: Input tensor
        dtype: Target quantization type (default: torch.int8)

    Returns:
        Quantization scale factor
    """
    # Find the maximum absolute value in the tensor
    r_max = tensor.abs().max().item()
    # Get the max value for the target integer type (127 for int8)
    q_max = torch.iinfo(dtype).max
    # Calculate scale = Real_max_abs / Quant_max
    return r_max / q_max
```

This `get_q_scale_symmetric` function takes a tensor (one row/channel of the weight matrix when called from `linear_q_symmetric_per_channel`) and calculates the scale from its largest absolute value and the maximum value of the target integer type (`int8`).
Next, the main per-channel quantization function:
```python
# From utils.py
def linear_q_symmetric_per_channel(r_tensor, dim, dtype=torch.int8):
    """Performs symmetric per-channel quantization.

    Args:
        r_tensor: Input tensor (original float32 weights)
        dim: Dimension along which to quantize (0 for rows/output channels)
        dtype: Target quantization type (default: torch.int8)

    Returns:
        tuple: (quantized_tensor, scale)
    """
    # Get the number of channels/rows to quantize along 'dim'
    output_dim = r_tensor.shape[dim]
    # Create a tensor to store scales, one for each channel
    scale = torch.zeros(output_dim)

    # Iterate through each channel/row
    for index in range(output_dim):
        # Extract the sub-tensor for the current channel
        sub_tensor = r_tensor.select(dim=dim, index=index)
        # Calculate and store the scale for this specific channel
        scale[index] = get_q_scale_symmetric(sub_tensor, dtype=dtype)

    # Reshape the scale tensor so it can be correctly broadcast
    # during the division/multiplication step (e.g., [out_features, 1])
    scale_shape = [1] * r_tensor.dim()
    scale_shape[dim] = -1  # Use -1 for the dimension being indexed (e.g., dim 0)
    scale = scale.view(scale_shape)

    # Quantize the entire tensor using the calculated per-channel scales
    quantized_tensor = linear_q_with_scale_and_zero_point(
        r_tensor, scale=scale, zero_point=0, dtype=dtype  # zero point is 0 for symmetric
    )
    return quantized_tensor, scale
```

This function loops through each channel (dimension `dim`, which is 0 for the output channels of an `nn.Linear` layer's weight matrix), calculates a scale for that channel using `get_q_scale_symmetric`, collects all the scales, reshapes them for broadcasting, and then applies the general `linear_q_with_scale_and_zero_point` function (with `zero_point=0`) to quantize the original `float32` tensor using the calculated per-channel scales.
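Stripped of PyTorch, the per-channel loop boils down to something like this (a plain-Python sketch with made-up names, not the TinyQ implementation):

```python
def per_channel_symmetric_quantize(rows, q_max=127):
    """One scale per row (dim=0); returns int8-style values and scales."""
    scales, q_rows = [], []
    for row in rows:
        s = max(abs(v) for v in row) / q_max   # per-channel scale
        scales.append(s)
        # divide, round, clamp -- the zero point is 0 for symmetric
        q_rows.append([max(-128, min(127, round(v / s))) for v in row])
    return q_rows, scales

q_rows, scales = per_channel_symmetric_quantize([[0.5, -1.25], [200.0, 50.0]])
print(q_rows)    # [[51, -127], [127, 32]]
```

Each row's largest absolute value maps to ±127, regardless of how different the two rows' ranges are.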
Finally, linear_q_with_scale_and_zero_point performs the actual division, rounding, and clamping for the quantization step (and its inverse for dequantization, though dequantization for W8A32 happens in the forward pass):
```python
# From utils.py
def linear_q_with_scale_and_zero_point(tensor, scale, zero_point, dtype=torch.int8):
    """Performs linear quantization using scale and zero-point.

    Args:
        tensor: Input tensor to quantize (float32)
        scale: Scale factor(s) (can be per-channel)
        zero_point: Zero point offset (0 for symmetric)
        dtype: Target quantization type (default: torch.int8)

    Returns:
        Quantized tensor (int8)
    """
    # Apply the formula: round(Real Value / Scale + Zero Point)
    # Since zero_point is 0 for symmetric, this is round(tensor / scale)
    scaled_tensor = tensor / scale + zero_point
    rounded_tensor = torch.round(scaled_tensor)

    # Get the min/max range for the target integer type (e.g., -128, 127 for int8)
    q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max

    # Clamp the rounded values to fit within the int8 range and cast to int8
    return rounded_tensor.clamp(min=q_min, max=q_max).to(dtype=dtype)


def linear_dequantization(quantized_tensor, scale, zero_point):
    """Performs linear dequantization using scale and zero-point.

    (Shown conceptually; for W8A32 the actual dequantization
    happens in the forward pass.)
    """
    # The inverse formula: scale * (quantized_tensor.float() - zero_point)
    # For symmetric (zero_point=0), this is scale * quantized_tensor.float()
    return scale * (quantized_tensor.float() - zero_point)
```

This is the core math: divide by scale, add the zero point (if any), round, and clamp.
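Put together with dequantization, the round trip looks like this in plain Python (a sketch on single values; TinyQ's real functions operate on tensors):

```python
def linear_quantize(v, scale, zero_point=0):
    # round(real / scale + zero_point), clamped to the int8 range
    return max(-128, min(127, round(v / scale + zero_point)))

def linear_dequantize(q, scale, zero_point=0):
    # the inverse: scale * (quantized - zero_point)
    return scale * (q - zero_point)

scale = 0.1
original = 3.14159
q = linear_quantize(original, scale)       # 31
restored = linear_dequantize(q, scale)     # 3.1
print(q, restored)
# The rounding error is bounded by half a step: |original - restored| <= scale / 2
```

The smaller the scale (i.e. the tighter the original range), the smaller this round-trip error.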
Let's look briefly at how W8A32LinearLayer's quantize method calls linear_q_symmetric_per_channel:
```python
# From tinyq.py - Inside W8A32LinearLayer.quantize()
def quantize(self, weights):
    """
    Quantizes the input FP32 weights to INT8 and stores them along with
    their quantization parameters.
    """
    # Ensure weights are float32 for calculations
    w_fp32 = weights.clone().to(torch.float32)

    # Call the per-channel symmetric quantization function.
    # dim=0 means quantize independently for each output channel (row)
    int8_weights, scales = linear_q_symmetric_per_channel(w_fp32, dim=0, dtype=torch.int8)

    # Store the results as buffers in the layer
    self.int8_weights = int8_weights
    # Squeeze the scales to remove the extra dimension added for broadcasting
    self.scales = scales.squeeze()
    # Store zero points (always 0 for symmetric quantization)
    self.zero_points = torch.zeros_like(self.scales, dtype=torch.int32)
```

This `quantize` method is called once per layer during the model replacement process (Chapter 4). It takes the original `float32` weights, passes them to `linear_q_symmetric_per_channel`, and then stores the resulting `int8_weights`, `scales`, and `zero_points` (which are 0 for symmetric quantization) inside the `W8A32LinearLayer` instance.
The W8A16LinearLayer.quantize method does something very similar:
```python
# From tinyq.py - Inside W8A16LinearLayer.quantize()
def quantize(self, weights):
    w_fp32 = weights.clone().to(torch.float32)

    # W8A16 computes the scale directly: the max absolute value
    # per channel (row), divided by 127.
    # The implementation differs slightly from linear_q_symmetric_per_channel,
    # but the concept (a per-channel symmetric scale) is the same.
    scales = w_fp32.abs().max(dim=-1).values / 127
    scales = scales.to(weights.dtype)  # keep scales in the original weight dtype (e.g. float16 if loaded that way)

    # Quantize using the calculated scales and an implicit zero point of 0
    int8_weights = torch.round(weights / scales.unsqueeze(1)).to(torch.int8)

    # Store the results
    self.int8_weights = int8_weights
    self.scales = scales  # Note: W8A16 keeps scales in float16/float32 depending on load dtype
```

While `W8A16LinearLayer` calculates the scale slightly differently (a more direct computation using `max(dim=-1).values`), the core idea remains the same: compute per-channel scales from the maximum absolute value and use them to map the original `float32` weights to `int8` values. Both use symmetric quantization (the zero point is effectively 0).
Weight Quantization Math is the engine that converts your large float32 weights into compact int8 weights. In TinyQ, for both W8A32 and W8A16 methods, this relies on symmetric per-channel quantization.
This involves:
- Calculating a `Scale` for each output channel based on the maximum absolute value within that channel.
- Assuming a `Zero Point` of 0 (symmetric).
- Converting the original `float32` values in each channel to `int8` by dividing by the channel's scale, rounding, and clamping.
The linear_q_symmetric_per_channel function in utils.py performs these steps for W8A32, while W8A16LinearLayer implements a similar calculation directly in its quantize method. This math is executed automatically when the Quantizer replaces nn.Linear layers with their custom quantized counterparts.
Now that we understand how the weights are converted and stored in a low-precision format, the next logical step is to see how the model actually uses these low-precision weights during inference – the Quantized Forward Pass Functions.