Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions qa/L0_pytorch_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor
# Each pytest invocation writes its JUnit report to $XML_LOG_DIR as pytest_<test module>.xml;
# when pytest exits non-zero, test_fail is called with the name of the test file.
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_grouped_tensor.xml $TE_PATH/tests/pytorch/test_grouped_tensor.py || test_fail "test_grouped_tensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor

import pytest
import torch
Expand Down
100 changes: 44 additions & 56 deletions tests/pytorch/test_grouped_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor
from transformer_engine.pytorch import (
Quantizer,
Float8Quantizer,
Expand Down Expand Up @@ -125,7 +125,7 @@ def test_basic_construction_all_same_shape(self) -> None:

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
shapes=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
Expand All @@ -147,7 +147,7 @@ def test_basic_construction_varying_first_dim(self) -> None:

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
shapes=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
Expand All @@ -170,14 +170,18 @@ def test_split_into_quantized_tensors_no_quantization(self) -> None:

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
shapes=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)

# Get the original data pointer
original_data_ptr = grouped_tensor.data.data_ptr()
# GroupedTensor is a wrapper; use backing storage buffer pointer.
storage = grouped_tensor.rowwise_data
if storage is None:
storage = grouped_tensor.columnwise_data
assert storage is not None
original_data_ptr = storage.data_ptr()

# Split into tensors
tensors = grouped_tensor.split_into_quantized_tensors()
Expand Down Expand Up @@ -207,13 +211,18 @@ def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
shapes=shape,
quantizer=quantizer,
device="cuda",
dtype=torch.float32,
)

# Get the original data pointer
original_data_ptr = grouped_tensor.data.data_ptr()
# GroupedTensor is a wrapper; use backing storage buffer pointer.
storage = grouped_tensor.rowwise_data
if storage is None:
storage = grouped_tensor.columnwise_data
assert storage is not None
original_data_ptr = storage.data_ptr()

# Split into tensors
tensors = grouped_tensor.split_into_quantized_tensors()
Expand All @@ -236,13 +245,17 @@ def test_split_varying_shapes(self) -> None:

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
shapes=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)

original_data_ptr = grouped_tensor.data.data_ptr()
storage = grouped_tensor.rowwise_data
if storage is None:
storage = grouped_tensor.columnwise_data
assert storage is not None
original_data_ptr = storage.data_ptr()
tensors = grouped_tensor.split_into_quantized_tensors()

assert len(tensors) == num_tensors
Expand All @@ -264,13 +277,18 @@ def test_quantize_inplace(self, quantization: str) -> None:

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
shapes=shape,
quantizer=quantizer,
device="cuda",
dtype=torch.float32,
)

# Get original data pointers before quantization
original_data_ptr = grouped_tensor.data.data_ptr()
storage = grouped_tensor.rowwise_data
if storage is None:
storage = grouped_tensor.columnwise_data
assert storage is not None
original_data_ptr = storage.data_ptr()
original_scale_inv_ptr = grouped_tensor.scale_inv.data_ptr()
original_scale_ptr = (
grouped_tensor.scale.data_ptr() if grouped_tensor.scale is not None else None
Expand All @@ -283,7 +301,7 @@ def test_quantize_inplace(self, quantization: str) -> None:
quantized_tensors = grouped_tensor.quantize(input_tensors)

# Verify data pointers haven't changed (in-place operation)
assert grouped_tensor.data.data_ptr() == original_data_ptr
assert storage.data_ptr() == original_data_ptr
assert grouped_tensor.scale_inv.data_ptr() == original_scale_inv_ptr
if original_scale_ptr is not None:
assert grouped_tensor.scale.data_ptr() == original_scale_ptr
Expand All @@ -304,13 +322,18 @@ def test_quantize_varying_shapes(self, quantization: str) -> None:

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
shapes=shape,
quantizer=quantizer,
device="cuda",
dtype=torch.float32,
)

# Get original data pointers
original_data_ptr = grouped_tensor.data.data_ptr()
storage = grouped_tensor.rowwise_data
if storage is None:
storage = grouped_tensor.columnwise_data
assert storage is not None
original_data_ptr = storage.data_ptr()

# Create input tensors with varying shapes
input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]
Expand All @@ -319,7 +342,7 @@ def test_quantize_varying_shapes(self, quantization: str) -> None:
quantized_tensors = grouped_tensor.quantize(input_tensors)

# Verify data pointer hasn't changed
assert grouped_tensor.data.data_ptr() == original_data_ptr
assert storage.data_ptr() == original_data_ptr

# Verify each tensor points to correct location
cumulative_numel = 0
Expand All @@ -329,38 +352,6 @@ def test_quantize_varying_shapes(self, quantization: str) -> None:
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
cumulative_numel += tensor_shape[0] * tensor_shape[1]

@pytest.mark.parametrize("quantization", _quantization_params)
def test_static_quantize_method(self, quantization: str) -> None:
    """Exercise ``GroupedTensor.create_and_quantize`` on three equally shaped inputs.

    The test builds three (512, 512) FP32 CUDA tensors, quantizes them in one
    call, and then checks that:

    * the resulting grouped tensor reports three members and has data,
    * ``quantized_tensors`` is populated with one entry per input, and
    * the row-wise data pointer of each entry equals the grouped buffer's
      base pointer plus the byte offset that ``_rowwise_offset_bytes``
      returns for the number of elements preceding it.
    """
    num_tensors = 3
    shape = [(512, 512)] * num_tensors
    quantizer = make_quantizer(quantization, num_tensors, shape)

    # One random FP32 input per entry in ``shape``, allocated on the GPU.
    input_tensors = []
    for dims in shape:
        input_tensors.append(torch.randn(dims, dtype=torch.float32, device="cuda"))

    # Allocate and quantize in a single step through the static entry point.
    grouped_tensor = GroupedTensor.create_and_quantize(
        tensors=input_tensors,
        quantizer=quantizer,
        device="cuda",
    )

    # Basic bookkeeping on the grouped container.
    assert grouped_tensor.num_tensors == num_tensors
    assert grouped_tensor.has_data()

    # The per-tensor views must exist, one per input.
    members = grouped_tensor.quantized_tensors
    assert members is not None
    assert len(members) == num_tensors

    # Every view must start at the expected offset from the grouped buffer's base pointer.
    base_ptr = grouped_tensor.data.data_ptr()
    for index, ((rows, cols), member) in enumerate(zip(shape, members)):
        elements_before = index * (rows * cols)
        rowwise = _get_rowwise_data_tensor(member, quantization)
        assert rowwise.data_ptr() == base_ptr + _rowwise_offset_bytes(
            elements_before, quantization
        )

@pytest.mark.parametrize(
"shape",
[[(256, 512), (512, 512), (768, 512)], [(512, 512), (512, 512), (512, 512)]],
Expand All @@ -374,9 +365,6 @@ def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None:

# Create BF16 input tensors and pack into a 2D tensor
input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape]
quantized_tensors = [
MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(tensor) for tensor in input_tensors
]
grouped_input = torch.cat(input_tensors, dim=0)

# Create MXFP8 output grouped tensor (rowwise only for easier validation)
Expand Down Expand Up @@ -406,7 +394,7 @@ def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None:
expected_data = torch.cat(expected_data)
expected_scale_inv = torch.cat(expected_scale_inv)

assert torch.equal(grouped_output.data, expected_data)
assert torch.equal(grouped_output.rowwise_data, expected_data)
assert torch.equal(grouped_output.scale_inv, expected_scale_inv)

@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
Expand Down Expand Up @@ -451,7 +439,7 @@ def test_group_quantize_cudagraph_capturable(self) -> None:
torch.cuda.synchronize()

expected = tex.group_quantize(static_input, quantizer, num_tensors, static_first_dims)
assert torch.equal(static_output.data, expected.data)
assert torch.equal(static_output.rowwise_data, expected.rowwise_data)
assert torch.equal(static_output.scale_inv, expected.scale_inv)

def test_clear(self) -> None:
Expand All @@ -461,7 +449,7 @@ def test_clear(self) -> None:

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
shapes=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
Expand All @@ -474,5 +462,5 @@ def test_clear(self) -> None:

assert not grouped_tensor.has_data()
assert grouped_tensor.num_tensors == 0
assert grouped_tensor.data is None
assert grouped_tensor.rowwise_data is None
assert grouped_tensor.logical_shape == (0, 0)
Loading
Loading