Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions torchft/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _fused_kernel_quantize_into_fp8(
# Compute maximum for the current row block by block
col_offsets = tl.arange(0, BLOCK_SIZE)
col_maxes = tl.full((BLOCK_SIZE,), 0, dtype=tl.float32)
for i_b in range(0, tl.cdiv(i_cols_num, BLOCK_SIZE)):
for _i_b in range(0, tl.cdiv(i_cols_num, BLOCK_SIZE)):
i_row_block = tl.load(
i_ptr + i_row_idx * i_row_stride + col_offsets * i_col_stride,
mask=col_offsets < i_cols_num,
Expand All @@ -146,7 +146,7 @@ def _fused_kernel_quantize_into_fp8(

# Scale and quantize current row block by block
col_offsets = tl.arange(0, BLOCK_SIZE)
for i_b in range(0, tl.cdiv(i_cols_num, BLOCK_SIZE)):
for _i_b in range(0, tl.cdiv(i_cols_num, BLOCK_SIZE)):
i_row_block = tl.load(
i_ptr + i_row_idx * i_row_stride + col_offsets * i_col_stride,
mask=col_offsets < i_cols_num,
Expand Down Expand Up @@ -240,7 +240,7 @@ def _fused_kernel_dequantize_from_fp8(

# Dequantize and store current row block by block
col_offsets = tl.arange(0, BLOCK_SIZE)
for i_b in range(0, tl.cdiv(i_cols_num, BLOCK_SIZE)):
for _i_b in range(0, tl.cdiv(i_cols_num, BLOCK_SIZE)):
i_quant_row_block = tl.load(
o_quant_ptr + col_offsets,
mask=col_offsets < i_cols_num,
Expand Down Expand Up @@ -315,7 +315,7 @@ def _fused_kernel_reduce_fp8(
col_offsets = tl.arange(0, BLOCK_SIZE)
# Compute scaling factor the reduced row
o_row_max = 0.0
for o_b in range(0, tl.cdiv(i_cols_num, BLOCK_SIZE)):
for _o_b in range(0, tl.cdiv(i_cols_num, BLOCK_SIZE)):
o_row_block_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
col_offsets_mask = col_offsets < i_cols_num
# Load blocks of quantized rows, dequantize and accumulate
Expand Down Expand Up @@ -347,7 +347,7 @@ def _fused_kernel_reduce_fp8(

col_offsets = tl.arange(0, BLOCK_SIZE)
# Reduce the row in blocks and write them out
for o_b in range(0, tl.cdiv(i_cols_num, BLOCK_SIZE)):
for _o_b in range(0, tl.cdiv(i_cols_num, BLOCK_SIZE)):
o_row_block_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
col_offsets_mask = col_offsets < i_cols_num
# Load blocks of quantized rows, dequantize and accumulate
Expand Down
Loading