diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index e0ad09200d..f2b0b07fed 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -37,6 +37,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_grouped_tensor.xml $TE_PATH/tests/pytorch/test_grouped_tensor.py || test_fail "test_grouped_tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py index 1e62f91eb8..8d81d578a7 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py @@ -10,7 +10,7 @@ from transformer_engine.pytorch.custom_recipes import utils from transformer_engine.pytorch.constants import TE_DType from transformer_engine.common.recipe 
import NVFP4BlockScaling -from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor +from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor import pytest import torch diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index ad08c0474d..9dd965fa94 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -8,7 +8,7 @@ import pytest import torch import transformer_engine.pytorch as te -from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor +from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor from transformer_engine.pytorch import ( Quantizer, Float8Quantizer, @@ -125,7 +125,7 @@ def test_basic_construction_all_same_shape(self) -> None: grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, + shapes=shape, quantizer=None, device="cuda", dtype=torch.float32, @@ -147,7 +147,7 @@ def test_basic_construction_varying_first_dim(self) -> None: grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, + shapes=shape, quantizer=None, device="cuda", dtype=torch.float32, @@ -170,14 +170,18 @@ def test_split_into_quantized_tensors_no_quantization(self) -> None: grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, + shapes=shape, quantizer=None, device="cuda", dtype=torch.float32, ) - # Get the original data pointer - original_data_ptr = grouped_tensor.data.data_ptr() + # GroupedTensor is a wrapper; use backing storage buffer pointer. 
+ storage = grouped_tensor.rowwise_data + if storage is None: + storage = grouped_tensor.columnwise_data + assert storage is not None + original_data_ptr = storage.data_ptr() # Split into tensors tensors = grouped_tensor.split_into_quantized_tensors() @@ -207,13 +211,18 @@ def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, + shapes=shape, quantizer=quantizer, device="cuda", + dtype=torch.float32, ) - # Get the original data pointer - original_data_ptr = grouped_tensor.data.data_ptr() + # GroupedTensor is a wrapper; use backing storage buffer pointer. + storage = grouped_tensor.rowwise_data + if storage is None: + storage = grouped_tensor.columnwise_data + assert storage is not None + original_data_ptr = storage.data_ptr() # Split into tensors tensors = grouped_tensor.split_into_quantized_tensors() @@ -236,13 +245,17 @@ def test_split_varying_shapes(self) -> None: grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, + shapes=shape, quantizer=None, device="cuda", dtype=torch.float32, ) - original_data_ptr = grouped_tensor.data.data_ptr() + storage = grouped_tensor.rowwise_data + if storage is None: + storage = grouped_tensor.columnwise_data + assert storage is not None + original_data_ptr = storage.data_ptr() tensors = grouped_tensor.split_into_quantized_tensors() assert len(tensors) == num_tensors @@ -264,13 +277,18 @@ def test_quantize_inplace(self, quantization: str) -> None: grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, + shapes=shape, quantizer=quantizer, device="cuda", + dtype=torch.float32, ) # Get original data pointers before quantization - original_data_ptr = grouped_tensor.data.data_ptr() + storage = grouped_tensor.rowwise_data + if storage is None: + storage = grouped_tensor.columnwise_data + assert storage is not 
None + original_data_ptr = storage.data_ptr() original_scale_inv_ptr = grouped_tensor.scale_inv.data_ptr() original_scale_ptr = ( grouped_tensor.scale.data_ptr() if grouped_tensor.scale is not None else None @@ -283,7 +301,7 @@ def test_quantize_inplace(self, quantization: str) -> None: quantized_tensors = grouped_tensor.quantize(input_tensors) # Verify data pointers haven't changed (in-place operation) - assert grouped_tensor.data.data_ptr() == original_data_ptr + assert storage.data_ptr() == original_data_ptr assert grouped_tensor.scale_inv.data_ptr() == original_scale_inv_ptr if original_scale_ptr is not None: assert grouped_tensor.scale.data_ptr() == original_scale_ptr @@ -304,13 +322,18 @@ def test_quantize_varying_shapes(self, quantization: str) -> None: grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, + shapes=shape, quantizer=quantizer, device="cuda", + dtype=torch.float32, ) # Get original data pointers - original_data_ptr = grouped_tensor.data.data_ptr() + storage = grouped_tensor.rowwise_data + if storage is None: + storage = grouped_tensor.columnwise_data + assert storage is not None + original_data_ptr = storage.data_ptr() # Create input tensors with varying shapes input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] @@ -319,7 +342,7 @@ def test_quantize_varying_shapes(self, quantization: str) -> None: quantized_tensors = grouped_tensor.quantize(input_tensors) # Verify data pointer hasn't changed - assert grouped_tensor.data.data_ptr() == original_data_ptr + assert storage.data_ptr() == original_data_ptr # Verify each tensor points to correct location cumulative_numel = 0 @@ -329,38 +352,6 @@ def test_quantize_varying_shapes(self, quantization: str) -> None: assert rowwise_data.data_ptr() == original_data_ptr + expected_offset cumulative_numel += tensor_shape[0] * tensor_shape[1] - @pytest.mark.parametrize("quantization", _quantization_params) - def 
test_static_quantize_method(self, quantization: str) -> None: - """Test the static quantize method""" - num_tensors = 3 - shape = [(512, 512) for _ in range(num_tensors)] - quantizer = make_quantizer(quantization, num_tensors, shape) - - # Create input tensors - input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] - - # Use static quantize method - grouped_tensor = GroupedTensor.create_and_quantize( - tensors=input_tensors, - quantizer=quantizer, - device="cuda", - ) - - # Verify the grouped tensor was created correctly - assert grouped_tensor.num_tensors == num_tensors - assert grouped_tensor.has_data() - - # Verify quantized_tensors were created and point to same storage - assert grouped_tensor.quantized_tensors is not None - assert len(grouped_tensor.quantized_tensors) == num_tensors - - original_data_ptr = grouped_tensor.data.data_ptr() - for i, qtensor in enumerate(grouped_tensor.quantized_tensors): - rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) - numel = shape[i][0] * shape[i][1] - expected_offset = _rowwise_offset_bytes(i * numel, quantization) - assert rowwise_data.data_ptr() == original_data_ptr + expected_offset - @pytest.mark.parametrize( "shape", [[(256, 512), (512, 512), (768, 512)], [(512, 512), (512, 512), (512, 512)]], @@ -374,9 +365,6 @@ def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None: # Create BF16 input tensors and pack into a 2D tensor input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape] - quantized_tensors = [ - MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(tensor) for tensor in input_tensors - ] grouped_input = torch.cat(input_tensors, dim=0) # Create MXFP8 output grouped tensor (rowwise only for easier validation) @@ -406,7 +394,7 @@ def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None: expected_data = torch.cat(expected_data) expected_scale_inv = torch.cat(expected_scale_inv) - assert 
torch.equal(grouped_output.data, expected_data) + assert torch.equal(grouped_output.rowwise_data, expected_data) assert torch.equal(grouped_output.scale_inv, expected_scale_inv) @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) @@ -451,7 +439,7 @@ def test_group_quantize_cudagraph_capturable(self) -> None: torch.cuda.synchronize() expected = tex.group_quantize(static_input, quantizer, num_tensors, static_first_dims) - assert torch.equal(static_output.data, expected.data) + assert torch.equal(static_output.rowwise_data, expected.rowwise_data) assert torch.equal(static_output.scale_inv, expected.scale_inv) def test_clear(self) -> None: @@ -461,7 +449,7 @@ def test_clear(self) -> None: grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, + shapes=shape, quantizer=None, device="cuda", dtype=torch.float32, @@ -474,5 +462,5 @@ def test_clear(self) -> None: assert not grouped_tensor.has_data() assert grouped_tensor.num_tensors == 0 - assert grouped_tensor.data is None + assert grouped_tensor.rowwise_data is None assert grouped_tensor.logical_shape == (0, 0) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 3ef8c0983f..384b6774f6 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -138,115 +138,21 @@ def reset_global_fp8_state(): FP8GlobalStateManager.reset() -def check_grouped_tensor_pointers_helper(tensors, num_elems_in_byte=1, tensor_name="tensor"): - """ - Verify that tensors are stored in contiguous memory. 
- - Args: - tensors: List or iterable of tensors to check - num_elems_in_byte: Number of elements packed per byte (1 for normal, 2 for NVFP4) - tensor_name: Name to use in error messages - """ - tensor_list = list(tensors) - if len(tensor_list) < 2: - return # Nothing to check - - for i in range(1, len(tensor_list)): - prev_tensor = tensor_list[i - 1] - curr_tensor = tensor_list[i] - - # Calculate expected offset based on previous tensor size - prev_numel = prev_tensor.numel() - expected_offset = (prev_numel // num_elems_in_byte) * prev_tensor.element_size() - - # Verify current tensor's data pointer is correctly offset - expected_ptr = prev_tensor.data_ptr() + expected_offset - actual_ptr = curr_tensor.data_ptr() - - assert ( - actual_ptr == expected_ptr - ), f"{tensor_name} {i} data pointer mismatch: expected {expected_ptr}, got {actual_ptr}" - - -def check_grouped_tensor_pointers( - weights: List[torch.Tensor], fp8_recipe: Optional[recipe.Recipe] = None +def check_grouped_weight( + module: GroupedLinear, num_gemms: int, out_features: int, in_features: int ): """ - Verify that the pointers of the weights are in contiguous memory for GroupedTensor. - TODO(ksivaman): This check can be made way more efficient but for now leaving the brute force approach. + Verify GroupedLinear exposes one grouped weight parameter with shape + [num_gemms, out_features, in_features]. """ - - num_elems_in_a_data_byte = 1 if fp8_recipe is None else 2 if fp8_recipe.nvfp4() else 1 - - # Check data. - if hasattr(weights[0], "_data") and weights[0]._data is not None: - data_tensors = [w._data for w in weights] - check_grouped_tensor_pointers_helper(data_tensors, num_elems_in_byte=1, tensor_name="data") - - # Check transpose. - if hasattr(weights[0], "_transpose") and weights[0]._transpose is not None: - transpose_tensors = [w._transpose for w in weights] - check_grouped_tensor_pointers_helper( - transpose_tensors, num_elems_in_byte=1, tensor_name="transpose" - ) - - # Check scale_inv. 
- if hasattr(weights[0], "_scale_inv") and weights[0]._scale_inv is not None: - scale_inv_tensors = [w._scale_inv for w in weights] - check_grouped_tensor_pointers_helper( - scale_inv_tensors, num_elems_in_byte=1, tensor_name="scale_inv" - ) - - # Check rowwise scale_inv. - if hasattr(weights[0], "_rowwise_scale_inv") and weights[0]._rowwise_scale_inv is not None: - scale_inv_tensors = [w._rowwise_scale_inv for w in weights] - check_grouped_tensor_pointers_helper( - scale_inv_tensors, num_elems_in_byte=1, tensor_name="rowwise_scale_inv" - ) - - # Check columnwise scale_inv. - if ( - hasattr(weights[0], "_columnwise_scale_inv") - and weights[0]._columnwise_scale_inv is not None - ): - columnwise_scale_inv_tensors = [w._columnwise_scale_inv for w in weights] - check_grouped_tensor_pointers_helper( - columnwise_scale_inv_tensors, - num_elems_in_byte=1, - tensor_name="columnwise scale_inv", - ) - - # Check rowwise amax. - if hasattr(weights[0], "_rowwise_amax") and weights[0]._rowwise_amax is not None: - rowwise_amax_tensors = [w._rowwise_amax for w in weights] - check_grouped_tensor_pointers_helper( - rowwise_amax_tensors, num_elems_in_byte=1, tensor_name="rowwise amax" - ) - - # Check columnwise amax. - if hasattr(weights[0], "_columnwise_amax") and weights[0]._columnwise_amax is not None: - columnwise_amax_tensors = [w._columnwise_amax for w in weights] - check_grouped_tensor_pointers_helper( - columnwise_amax_tensors, num_elems_in_byte=1, tensor_name="columnwise amax" - ) - - # Check rowwise data. - if hasattr(weights[0], "_rowwise_data") and weights[0]._rowwise_data is not None: - rowwise_data_tensors = [w._rowwise_data for w in weights] - check_grouped_tensor_pointers_helper( - rowwise_data_tensors, - num_elems_in_byte=num_elems_in_a_data_byte, - tensor_name="rowwise data", - ) - - # Check columnwise data. 
- if hasattr(weights[0], "_columnwise_data") and weights[0]._columnwise_data is not None: - columnwise_data_tensors = [w._columnwise_data for w in weights] - check_grouped_tensor_pointers_helper( - columnwise_data_tensors, - num_elems_in_byte=num_elems_in_a_data_byte, - tensor_name="columnwise data", - ) + weight_params = [(name, p) for name, p in module.named_parameters() if "weight" in name] + assert len(weight_params) == 1, f"Expected 1 grouped weight parameter, got {len(weight_params)}" + name, weight = weight_params[0] + assert name == "weight", f"Expected grouped parameter name 'weight', got {name}" + assert tuple(weight.shape) == (num_gemms, out_features, in_features), ( + "Grouped weight has unexpected shape. " + f"Expected {(num_gemms, out_features, in_features)}, got {tuple(weight.shape)}" + ) def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): @@ -603,9 +509,6 @@ def test_sanity_grouped_linear( bs = bs * 16 num_tokens = bs * config.max_seqlen_q * (num_gemms - 1) - if single_param: - os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] = "1" - if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") @@ -620,13 +523,13 @@ def test_sanity_grouped_linear( ffn_hidden_size, bias=use_bias, params_dtype=dtype, + single_grouped_parameter=single_param, ).cuda() - # Verify that weights are stored in contiguous GroupedTensor storage. - weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)] + # Verify grouped linear exposes a single grouped weight parameter. 
if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()): if single_param: - check_grouped_tensor_pointers(weights, fp8_recipe) + check_grouped_weight(te_grouped_linear, num_gemms, ffn_hidden_size, config.hidden_size) inp_hidden_states = torch.randn( num_tokens, config.hidden_size, dtype=dtype, requires_grad=True @@ -645,9 +548,6 @@ def test_sanity_grouped_linear( loss.backward() assert out.shape == (num_tokens, ffn_hidden_size) - if single_param: - del os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] - @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index b1d60cc3da..8302a13010 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -35,8 +35,9 @@ PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; PyTypeObject *NVFP4TensorPythonClass = nullptr; PyTypeObject *NVFP4TensorStoragePythonClass = nullptr; PyTypeObject *NVFP4QuantizerClass = nullptr; -std::once_flag extension_init_flag; +PyTypeObject *GroupedTensorPythonClass = nullptr; PyTypeObject *GroupedTensorStoragePythonClass = nullptr; +std::once_flag extension_init_flag; void init_float8_extension() { auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor"); @@ -103,11 +104,17 @@ void init_nvfp4_extensions() { } void init_grouped_tensor_extension() { - if (GroupedTensorStoragePythonClass) return; + if (GroupedTensorPythonClass && GroupedTensorStoragePythonClass) return; auto grouped_tensor_module = - py::module_::import("transformer_engine.pytorch.tensor.storage.grouped_tensor"); - GroupedTensorStoragePythonClass = reinterpret_cast( + py::module_::import("transformer_engine.pytorch.tensor.grouped_tensor"); + GroupedTensorPythonClass = reinterpret_cast( 
PyObject_GetAttrString(grouped_tensor_module.ptr(), "GroupedTensor")); + auto grouped_tensor_storage_module = + py::module_::import("transformer_engine.pytorch.tensor.storage.grouped_tensor_storage"); + GroupedTensorStoragePythonClass = reinterpret_cast( + PyObject_GetAttrString(grouped_tensor_storage_module.ptr(), "GroupedTensorStorage")); + NVTE_CHECK(GroupedTensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch grouped tensor extension."); NVTE_CHECK(GroupedTensorStoragePythonClass != nullptr, "Internal error: could not initialize pyTorch grouped tensor extension."); } diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 059eb5e3fb..9e640537f9 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -43,6 +43,7 @@ extern PyTypeObject *Float8BlockwiseQuantizerClass; extern PyTypeObject *NVFP4TensorPythonClass; extern PyTypeObject *NVFP4TensorStoragePythonClass; extern PyTypeObject *NVFP4QuantizerClass; +extern PyTypeObject *GroupedTensorPythonClass; extern PyTypeObject *GroupedTensorStoragePythonClass; void init_extension(); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 0da5f69197..0135c7f01c 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -67,13 +67,13 @@ std::optional build_grouped_tensor_offsets(const size_t num_tensors, } const auto& first_dims_tensor = first_dims.value(); + NVTE_CHECK(first_dims_tensor.is_cuda(), "first_dims must be on CUDA."); NVTE_CHECK(first_dims_tensor.scalar_type() == at::kLong, "first_dims must have dtype int64."); NVTE_CHECK(static_cast(first_dims_tensor.numel()) == num_tensors, "first_dims must have length ", num_tensors, "."); const int64_t logical_last_dim_i64 = static_cast(logical_last_dim); - auto scaled_first_dims = first_dims_tensor * logical_last_dim_i64; - + auto 
scaled_first_dims = (first_dims_tensor * logical_last_dim_i64).contiguous(); // Single kernel needed for these ops. auto cumsum = at::cumsum(scaled_first_dims, 0); auto zero = at::zeros({1}, cumsum.options()); @@ -88,6 +88,11 @@ py::object maybe_tensor_to_py(const std::optional& tensor) { return tensor ? py::cast(*tensor) : py::none(); } +py::handle grouped_tensor_python_class(const bool internal) { + PyTypeObject* cls = internal ? GroupedTensorStoragePythonClass : GroupedTensorPythonClass; + return py::handle(reinterpret_cast(cls)); +} + } // namespace constexpr size_t NVFP4_BLOCK_SIZE = 16; @@ -172,18 +177,30 @@ std::pair NoneQuantizer::create_grouped_tensor getTensorShape(*tensor_offsets)); } - py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); - py::object out_py = GroupedTensorClass( - "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), - "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), - "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), "scale_inv"_a = py::none(), - "columnwise_scale_inv"_a = py::none(), "amax"_a = py::none(), - "columnwise_amax"_a = py::none(), "scale"_a = py::none(), - "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), - "last_dims"_a = py::none(), - "tensor_offsets"_a = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(), - "logical_shape"_a = std::vector{static_cast(logical_first_dim), - static_cast(logical_last_dim)}); + py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); + py::dict kwargs; + py::tuple args(0); + kwargs["shape"] = py::cast(std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["num_tensors"] = py::cast(num_tensors); + kwargs["quantizer"] = quantizer; + kwargs["data"] = maybe_tensor_to_py(rowwise_data); + kwargs["columnwise_data"] = maybe_tensor_to_py(columnwise_data); + kwargs["scale_inv"] = py::none(); + kwargs["columnwise_scale_inv"] = py::none(); + kwargs["amax"] = py::none(); + kwargs["columnwise_amax"] = py::none(); + kwargs["scale"] = py::none(); + kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); + kwargs["last_dims"] = py::none(); + kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); + PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + NVTE_CHECK(result != nullptr, "Failed to create GroupedTensor instance"); + py::object out_py = py::reinterpret_steal(result); return {std::move(out_cpp), std::move(out_py)}; } @@ -366,19 +383,30 @@ std::pair Float8Quantizer::create_grouped_tens getTensorShape(*tensor_offsets)); } - py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); - py::object out_py = GroupedTensorClass( - "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), - "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), - "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), - "scale_inv"_a = maybe_tensor_to_py(rowwise_scale_inv), - "columnwise_scale_inv"_a = maybe_tensor_to_py(columnwise_scale_inv), "amax"_a = amax, - "columnwise_amax"_a = py::none(), "scale"_a = py::none(), - 
"first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), - "last_dims"_a = py::none(), - "tensor_offsets"_a = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(), - "logical_shape"_a = std::vector{static_cast(logical_first_dim), - static_cast(logical_last_dim)}); + py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); + py::dict kwargs; + py::tuple args(0); + kwargs["shape"] = py::cast(std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["num_tensors"] = py::cast(num_tensors); + kwargs["quantizer"] = quantizer; + kwargs["data"] = maybe_tensor_to_py(rowwise_data); + kwargs["columnwise_data"] = maybe_tensor_to_py(columnwise_data); + kwargs["scale_inv"] = maybe_tensor_to_py(rowwise_scale_inv); + kwargs["columnwise_scale_inv"] = maybe_tensor_to_py(columnwise_scale_inv); + kwargs["amax"] = amax; + kwargs["columnwise_amax"] = py::none(); + kwargs["scale"] = py::none(); + kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); + kwargs["last_dims"] = py::none(); + kwargs["tensor_offsets"] = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(); + PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + NVTE_CHECK(result != nullptr, "Failed to create GroupedTensor instance"); + py::object out_py = py::reinterpret_steal(result); return {std::move(out_cpp), std::move(out_py)}; } @@ -673,19 +701,30 @@ std::pair Float8CurrentScalingQuantizer::creat getTensorShape(*tensor_offsets)); } - py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); - py::object out_py = GroupedTensorClass( - "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), - "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), - "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), - "scale_inv"_a = maybe_tensor_to_py(rowwise_scale_inv), - "columnwise_scale_inv"_a = maybe_tensor_to_py(columnwise_scale_inv), "amax"_a = amax, - "columnwise_amax"_a = py::none(), "scale"_a = scale, - "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), - "last_dims"_a = py::none(), - "tensor_offsets"_a = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(), - "logical_shape"_a = std::vector{static_cast(logical_first_dim), - static_cast(logical_last_dim)}); + py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); + py::dict kwargs; + py::tuple args(0); + kwargs["shape"] = py::cast(std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["num_tensors"] = py::cast(num_tensors); + kwargs["quantizer"] = quantizer; + kwargs["data"] = maybe_tensor_to_py(rowwise_data); + kwargs["columnwise_data"] = maybe_tensor_to_py(columnwise_data); + kwargs["scale_inv"] = maybe_tensor_to_py(rowwise_scale_inv); + kwargs["columnwise_scale_inv"] = maybe_tensor_to_py(columnwise_scale_inv); + kwargs["amax"] = amax; + kwargs["columnwise_amax"] = py::none(); + kwargs["scale"] = scale; + kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); + kwargs["last_dims"] = py::none(); + kwargs["tensor_offsets"] = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(); + PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + NVTE_CHECK(result != nullptr, "Failed to create GroupedTensor instance"); + py::object out_py = py::reinterpret_steal(result); return {std::move(out_cpp), std::move(out_py)}; } @@ -1020,19 +1059,30 @@ std::pair Float8BlockQuantizer::create_grouped getTensorShape(*tensor_offsets)); } - py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); - py::object out_py = GroupedTensorClass( - "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), - "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), - "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), - "scale_inv"_a = maybe_tensor_to_py(rowwise_scale_inv), - "columnwise_scale_inv"_a = maybe_tensor_to_py(columnwise_scale_inv), "amax"_a = py::none(), - "columnwise_amax"_a = py::none(), "scale"_a = py::none(), - "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), - "last_dims"_a = py::none(), - "tensor_offsets"_a = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(), - "logical_shape"_a = std::vector{static_cast(logical_first_dim), - static_cast(logical_last_dim)}); + py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); + py::dict kwargs; + py::tuple args(0); + kwargs["shape"] = py::cast(std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["num_tensors"] = py::cast(num_tensors); + kwargs["quantizer"] = quantizer; + kwargs["data"] = maybe_tensor_to_py(rowwise_data); + kwargs["columnwise_data"] = maybe_tensor_to_py(columnwise_data); + kwargs["scale_inv"] = maybe_tensor_to_py(rowwise_scale_inv); + kwargs["columnwise_scale_inv"] = maybe_tensor_to_py(columnwise_scale_inv); + kwargs["amax"] = py::none(); + kwargs["columnwise_amax"] = py::none(); + kwargs["scale"] = py::none(); + kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); + kwargs["last_dims"] = py::none(); + kwargs["tensor_offsets"] = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(); + PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + NVTE_CHECK(result != nullptr, "Failed to create GroupedTensor instance"); + py::object out_py = py::reinterpret_steal(result); return {std::move(out_cpp), std::move(out_py)}; } @@ -1425,19 +1475,30 @@ std::pair MXFP8Quantizer::create_grouped_tenso out_cpp.set_with_gemm_swizzled_scales(this->optimize_for_gemm); - py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); - py::object out_py = GroupedTensorClass( - "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), - "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), - "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), - "scale_inv"_a = maybe_tensor_to_py(rowwise_scale_inv), - "columnwise_scale_inv"_a = maybe_tensor_to_py(columnwise_scale_inv), "amax"_a = py::none(), - "columnwise_amax"_a = py::none(), "scale"_a = py::none(), - "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), - "last_dims"_a = py::none(), - "tensor_offsets"_a = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(), - "logical_shape"_a = std::vector{static_cast(logical_first_dim), - static_cast(logical_last_dim)}); + py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); + py::dict kwargs; + py::tuple args(0); + kwargs["shape"] = py::cast(std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["num_tensors"] = py::cast(num_tensors); + kwargs["quantizer"] = quantizer; + kwargs["data"] = maybe_tensor_to_py(rowwise_data); + kwargs["columnwise_data"] = maybe_tensor_to_py(columnwise_data); + kwargs["scale_inv"] = maybe_tensor_to_py(rowwise_scale_inv); + kwargs["columnwise_scale_inv"] = maybe_tensor_to_py(columnwise_scale_inv); + kwargs["amax"] = py::none(); + kwargs["columnwise_amax"] = py::none(); + kwargs["scale"] = py::none(); + kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); + kwargs["last_dims"] = py::none(); + kwargs["tensor_offsets"] = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(); + PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + NVTE_CHECK(result != nullptr, "Failed to create GroupedTensor instance"); + py::object out_py = py::reinterpret_steal(result); return {std::move(out_cpp), std::move(out_py)}; } @@ -1842,20 +1903,30 @@ std::pair NVFP4Quantizer::create_grouped_tenso out_cpp.set_with_gemm_swizzled_scales(this->optimize_for_gemm); - py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); - py::object out_py = GroupedTensorClass( - "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), - "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), - "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), - "scale_inv"_a = maybe_tensor_to_py(rowwise_scale_inv), - "columnwise_scale_inv"_a = maybe_tensor_to_py(columnwise_scale_inv), - "amax"_a = maybe_tensor_to_py(rowwise_amax), - "columnwise_amax"_a = maybe_tensor_to_py(columnwise_amax), "scale"_a = py::none(), - "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), - "last_dims"_a = py::none(), - "tensor_offsets"_a = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(), - "logical_shape"_a = std::vector{static_cast(logical_first_dim), - static_cast(logical_last_dim)}); + py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); + py::dict kwargs; + py::tuple args(0); + kwargs["shape"] = py::cast(std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["num_tensors"] = py::cast(num_tensors); + kwargs["quantizer"] = quantizer; + kwargs["data"] = maybe_tensor_to_py(rowwise_data); + kwargs["columnwise_data"] = maybe_tensor_to_py(columnwise_data); + kwargs["scale_inv"] = maybe_tensor_to_py(rowwise_scale_inv); + kwargs["columnwise_scale_inv"] = maybe_tensor_to_py(columnwise_scale_inv); + kwargs["amax"] = maybe_tensor_to_py(rowwise_amax); + kwargs["columnwise_amax"] = maybe_tensor_to_py(columnwise_amax); + kwargs["scale"] = py::none(); + kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); + kwargs["last_dims"] = py::none(); + kwargs["tensor_offsets"] = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(); + PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + NVTE_CHECK(result != nullptr, "Failed to create GroupedTensor instance"); + py::object out_py = py::reinterpret_steal(result); return {std::move(out_cpp), std::move(out_py)}; } diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index b381073d78..f3e7b57cf1 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -6,7 +6,6 @@ from typing import Union, Optional, Callable, Tuple, List from itertools import chain import warnings -import os import functools import torch @@ -14,7 +13,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe -from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor +from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor from .base import ( get_dummy_wgrad, TransformerEngineBaseModule, @@ -595,6 +594,10 @@ class GroupedLinear(TransformerEngineBaseModule): cast tensor. In some scenarios, the input tensor is used by multiple modules, and saving the original input tensor may reduce the memory usage. Cannot work with FP8 DelayedScaling recipe. + single_grouped_parameter : bool, default = False + If set to ``True``, grouped weights are stored as a single grouped parameter + instead of one parameter per GEMM. + EXPERIMENTAL and subject to change. 
Notes ----- @@ -625,6 +628,7 @@ def __init__( ub_name: Optional[str] = None, delay_wgrad_compute: bool = False, save_original_input: bool = False, + single_grouped_parameter: bool = False, name: Optional[str] = None, ) -> None: super().__init__(name) @@ -641,6 +645,7 @@ self.ub_overlap_ag = ub_overlap_ag self.ub_name = ub_name self.save_original_input = save_original_input + self.single_grouped_parameter = single_grouped_parameter assert ( not ub_overlap_rs and not ub_overlap_ag ), "GroupedLinear doesn't support Userbuffer overlap." @@ -767,7 +772,7 @@ def make_grouped_weights(self, defer_init=False) -> None: # Create the weight storage. grouped_weights = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=self.num_gemms, - shape=[(self.out_features, self.in_features)] * self.num_gemms, + shapes=[(self.out_features, self.in_features)] * self.num_gemms, quantizer=weight_quantizers[0], dtype=self.params_dtype, device=weights[0].device, @@ -781,22 +786,26 @@ else: grouped_weights.quantized_tensors[i].copy_(weights[i]) - # Re-register the grouped weights as parameters. + # Re-register as a single grouped weight parameter. + assert isinstance(grouped_weights, torch.Tensor) and ( + weight_quantizers[0] is None or not weight_quantizers[0].internal + ), "Found internal quantizer with `single_grouped_parameter=True`."
+ self.register_parameter( + "weight", + torch.nn.Parameter(grouped_weights), + init_fn=self.init_method, + get_rng_state_tracker=self.get_rng_state_tracker, + fp8_meta_index=self._offsets["weight"], + ) for i in range(self.num_gemms): - self.register_parameter( - f"weight{i}", - torch.nn.Parameter(grouped_weights.quantized_tensors[i]), - init_fn=self.init_method, - get_rng_state_tracker=self.get_rng_state_tracker, - fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"], - ) + self.register_parameter(f"weight{i}", None) self.set_tensor_parallel_attributes(defer_init=defer_init) def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) # Grouped tensor weights is an opt-in feature. - if bool(int(os.getenv("NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS", "0"))): + if self.single_grouped_parameter: self.make_grouped_weights(defer_init=defer_init) def set_tensor_parallel_attributes(self, defer_init=False) -> None: @@ -804,13 +814,22 @@ def set_tensor_parallel_attributes(self, defer_init=False) -> None: if not defer_init: # Set parallelism attributes for linear weights - for i in range(self.num_gemms): + grouped_weight = getattr(self, "weight", None) + if grouped_weight is not None: set_tensor_model_parallel_attributes( - tensor=getattr(self, f"weight{i}"), + tensor=grouped_weight, is_parallel=True, dim=1 if self.parallel_mode == "row" else 0, stride=1, ) + else: + for i in range(self.num_gemms): + set_tensor_model_parallel_attributes( + tensor=getattr(self, f"weight{i}"), + is_parallel=True, + dim=1 if self.parallel_mode == "row" else 0, + stride=1, + ) # Set parallelism attributes for linear biases if self.use_bias: @@ -933,7 +952,7 @@ def backward_dw(self): with get_nvtx_range_context("_GroupedLinear_wgrad"): (_, grad_biases_, _), tensor_list = self.wgrad_store.pop() wgrad_list = tensor_list[2] - weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] + weight_params = 
self._get_weight_tensors() bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] if not self.fuse_wgrad_accumulation: for i in range(self.num_gemms): @@ -983,7 +1002,14 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" - weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] + grouped_weight = getattr(self, "weight", None) + if grouped_weight is not None: + weight_tensors = grouped_weight.quantized_tensors + if weight_tensors is None: + # TODO(ksivaman): Remove this after GEMM integration. + weight_tensors = grouped_weight.split_into_quantized_tensors() + else: + weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] if not self.fp8 and any(isinstance(w, QuantizedTensorStorage) for w in weight_tensors): warnings.warn( "You are using quantized weights without quantized compute. 
" diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index cb199d24b5..5668056700 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -17,10 +17,12 @@ from .storage.mxfp8_tensor_storage import MXFP8TensorStorage from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from .storage.nvfp4_tensor_storage import NVFP4TensorStorage +from .storage.grouped_tensor_storage import GroupedTensorStorage from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer from .nvfp4_tensor import NVFP4Tensor, NVFP4Quantizer +from .grouped_tensor import GroupedTensor from .utils import cast_master_weights_to_fp8, replace_raw_data __all__ = [ @@ -35,11 +37,13 @@ "MXFP8TensorStorage", "Float8BlockwiseQTensorStorage", "NVFP4TensorStorage", + "GroupedTensorStorage", "QuantizedTensor", "Float8Tensor", "MXFP8Tensor", "Float8BlockwiseQTensor", "NVFP4Tensor", + "GroupedTensor", "prepare_for_saving", "restore_from_saved", ] @@ -89,5 +93,7 @@ def get_all_tensor_types(): Float8BlockwiseQTensorStorage, NVFP4Tensor, NVFP4TensorStorage, + GroupedTensor, + GroupedTensorStorage, ] return all_tensor_types diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py new file mode 100644 index 0000000000..767b0ccb35 --- /dev/null +++ b/transformer_engine/pytorch/tensor/grouped_tensor.py @@ -0,0 +1,205 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Grouped tensor class for handling collections of tensors with different shapes""" +from __future__ import annotations + +from typing import List, Optional, Tuple + +import torch +from torch.utils._pytree import tree_map + +from ..quantized_tensor import QuantizedTensorStorage, Quantizer +from .storage.grouped_tensor_storage import GroupedTensorStorage + + +# For now, conservatively ban all shape manipulating ops. +BANNED_SHAPE_OPS = { + torch.ops.aten.view.default, + torch.ops.aten._unsafe_view.default, + torch.ops.aten.reshape.default, + torch.ops.aten._reshape_alias.default, + torch.ops.aten.flatten.using_ints, + torch.ops.aten.unflatten.int, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze.dims, + torch.ops.aten.unsqueeze.default, + torch.ops.aten.transpose.int, + torch.ops.aten.permute.default, + torch.ops.aten.movedim.int, + torch.ops.aten.t.default, + torch.ops.aten.slice.Tensor, + torch.ops.aten.narrow.default, + torch.ops.aten.select.int, + torch.ops.aten.split.Tensor, + torch.ops.aten.chunk.default, + torch.ops.aten.expand.default, + torch.ops.aten.expand_as.default, + torch.ops.aten.cat.default, + torch.ops.aten.stack.default, +} + + +class GroupedTensor(GroupedTensorStorage, torch.Tensor): + """Tensor wrapper class for grouped tensor storage.""" + + def __new__( + cls, + shape: Tuple[int, int], + dtype: torch.dtype, + num_tensors: int, + shapes: Optional[List[Tuple[int, int]]] = None, + quantizer: Optional[Quantizer] = None, + data: Optional[torch.Tensor] = None, + columnwise_data: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + columnwise_scale_inv: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + columnwise_amax: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + first_dims: Optional[torch.Tensor] = None, + last_dims: Optional[torch.Tensor] = None, + tensor_offsets: Optional[torch.Tensor] = None, + offsets: Optional[List[int]] = None, + scale_inv_offsets: 
Optional[List[int]] = None, + columnwise_scale_inv_offsets: Optional[List[int]] = None, + ): + del quantizer + del offsets + del scale_inv_offsets + del columnwise_scale_inv_offsets + + if ( + shapes is not None + and len(shapes) == num_tensors + and num_tensors > 0 + and all(shapes[0] == s for s in shapes) + ): + wrapper_shape = (num_tensors, shapes[0][0], shapes[0][1]) + else: + wrapper_shape = shape + + device = None + for maybe_tensor in ( + data, + columnwise_data, + scale_inv, + columnwise_scale_inv, + amax, + columnwise_amax, + scale, + first_dims, + last_dims, + tensor_offsets, + ): + if maybe_tensor is not None: + device = maybe_tensor.device + break + if device is None: + device = torch.device("cuda") + + strides = [1] * len(wrapper_shape) + for i in range(len(wrapper_shape) - 2, -1, -1): + strides[i] = strides[i + 1] * wrapper_shape[i + 1] + return torch.Tensor._make_wrapper_subclass( + cls, + wrapper_shape, + strides=tuple(strides), + storage_offset=0, + dtype=dtype, + layout=torch.strided, + requires_grad=False, + device=device, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + """Dispatch by dequantizing grouped members, then requantizing writes.""" + if kwargs is None: + kwargs = {} + + # Parameter construction calls detach()/alias-like paths. + if func in (torch.ops.aten.detach.default, torch.ops.aten.alias.default): + return args[0] + + # Don't allow reshape/view etc. 
+ if func in BANNED_SHAPE_OPS: + raise RuntimeError(f"{cls.__name__} forbids shape-manipulation op: {func} ") + + def grouped_to_stacked_tensor(grouped: GroupedTensor) -> torch.Tensor: + if not grouped.all_same_shape(): + raise NotImplementedError( + "GroupedTensor __torch_dispatch__ currently supports only uniform member shapes" + ) + grouped_members = grouped.quantized_tensors + if grouped_members is None: + grouped_members = grouped.split_into_quantized_tensors() + dequantized_members = [ + ( + member.dequantize(dtype=grouped.get_dtype()) + if isinstance(member, QuantizedTensorStorage) + else member + ) + for member in grouped_members + ] + return torch.stack(dequantized_members, dim=0) + + def maybe_unwrap(arg): + if isinstance(arg, GroupedTensor): + return grouped_to_stacked_tensor(arg) + return arg + + def update_grouped_tensor_inplace(grouped: GroupedTensor, updated: torch.Tensor): + if not grouped.all_same_shape(): + raise NotImplementedError( + "GroupedTensor __torch_dispatch__ currently supports only uniform member shapes" + ) + updated_members = list(updated.unbind(dim=0)) + if grouped.quantizer is None: + grouped_members = grouped.quantized_tensors + if grouped_members is None: + grouped_members = grouped.split_into_quantized_tensors() + for dst, src in zip(grouped_members, updated_members): + dst.copy_(src) + else: + grouped.quantize(updated_members) + + def maybe_update_inplace(arg, new_arg, schema_arg): + if ( + isinstance(arg, GroupedTensor) + and isinstance(new_arg, torch.Tensor) + and hasattr(schema_arg, "alias_info") + and hasattr(schema_arg.alias_info, "is_write") + and schema_arg.alias_info.is_write + ): + update_grouped_tensor_inplace(arg, new_arg) + elif isinstance(arg, list) and isinstance(new_arg, list): + for a, na in zip(arg, new_arg): + maybe_update_inplace(a, na, schema_arg) + + # In-place op: dequantize members, perform op, write back into grouped storage. 
+ if func._schema.is_mutable: + new_args = tree_map(maybe_unwrap, args) + new_kwargs = tree_map(maybe_unwrap, kwargs) + schema_args = func._schema.arguments + args_len = len(args) + super().__torch_dispatch__(func, types, new_args, new_kwargs) + for arg, new_arg, schema_arg in zip(args, new_args, schema_args): + maybe_update_inplace(arg, new_arg, schema_arg) + for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]): + assert kwarg == new_kwarg == schema_arg.name, "name of kwarg should match schema" + maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg) + return None + + # Default op: operate on dequantized stacked tensors. + new_args = tree_map(maybe_unwrap, args) + new_kwargs = tree_map(maybe_unwrap, kwargs) + return super().__torch_dispatch__(func, types, new_args, new_kwargs) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + # Do not force GroupedTensor on outputs. + return torch._C._disabled_torch_function_impl(func, types, args, kwargs) diff --git a/transformer_engine/pytorch/tensor/storage/__init__.py b/transformer_engine/pytorch/tensor/storage/__init__.py index 7c8a014c1d..44a77d975f 100644 --- a/transformer_engine/pytorch/tensor/storage/__init__.py +++ b/transformer_engine/pytorch/tensor/storage/__init__.py @@ -7,4 +7,4 @@ from .mxfp8_tensor_storage import MXFP8TensorStorage # noqa: F401 from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage # noqa: F401 from .nvfp4_tensor_storage import NVFP4TensorStorage # noqa: F401 -from .grouped_tensor import GroupedTensor # noqa: F401 +from .grouped_tensor_storage import GroupedTensorStorage # noqa: F401 diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py similarity index 87% rename from transformer_engine/pytorch/tensor/storage/grouped_tensor.py rename to 
transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index bf5792ffc9..92006ba45b 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -2,13 +2,12 @@ # # See LICENSE for license information. -"""Grouped tensor class for handling collections of tensors with different shapes""" +"""Grouped tensor storage class for handling collections of tensors with different shapes""" from __future__ import annotations from typing import Optional, Tuple, List, Union import math import torch - from ...quantized_tensor import QuantizedTensorStorage, Quantizer from ..mxfp8_tensor import MXFP8Tensor @@ -21,7 +20,7 @@ from .nvfp4_tensor_storage import NVFP4TensorStorage -class GroupedTensor: +class GroupedTensorStorage: """ EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. @@ -51,10 +50,11 @@ class GroupedTensor: def __init__( self, + shape: Tuple[int, int], + dtype: torch.dtype, num_tensors: int, - shape: Optional[List[Tuple[int, int]]] = None, + shapes: Optional[List[Tuple[int, int]]] = None, quantizer: Optional[Quantizer] = None, - dtype: Optional[torch.dtype] = None, data: Optional[torch.Tensor] = None, columnwise_data: Optional[torch.Tensor] = None, scale_inv: Optional[torch.Tensor] = None, @@ -68,15 +68,16 @@ def __init__( offsets: Optional[List[int]] = None, scale_inv_offsets: Optional[List[int]] = None, columnwise_scale_inv_offsets: Optional[List[int]] = None, - logical_shape: Optional[Tuple[int, int]] = None, ) -> None: """ Initialize a GroupedTensor. 
Args: + shape: 2D tuple representing conceptual shape + dtype: Data type of the grouped tensor num_tensors: Number of tensors in the group - shape: 2D shape of each tensor (len num_tensors) - quantizer: Quantizer for the grouped tensor + shapes: 2D shape of each tensor (len num_tensors) + quantizer: Quantizer used for all tensors in the group data: Row-wise data buffer (1D flattened) columnwise_data: Column-wise data buffer (1D flattened) scale_inv: Row-wise scale inverse buffer @@ -88,17 +89,14 @@ def __init__( last_dims: Device tensor of int64 array of length num_tensors (or None if uniform) tensor_offsets: Device tensor of int64 array of length num_tensors (or None if uniform) offsets: Vector of integer offsets for each tensor. - logical_shape: 2D tuple representing conceptual shape """ self.num_tensors = num_tensors self.quantizer = quantizer - self.shape = shape - self.dtype = ( - dtype if dtype is not None else torch.float32 - ) # Default to float32 if not provided + self.tensor_shapes = shapes + self.fake_dtype = dtype # Data buffers - self.data = data + self.rowwise_data = data self.columnwise_data = columnwise_data self.scale_inv = scale_inv self.columnwise_scale_inv = columnwise_scale_inv @@ -132,7 +130,7 @@ def __init__( # Logical shape: conceptual 2D shape of the grouped data (REQUIRED) # Represents how the 1D flattened data should be interpreted as 2D # Always 2D with positive dimensions - self.logical_shape = logical_shape if logical_shape is not None else (0, 0) + self.logical_shape = shape # Hold a reference to the quantized tensors that occupy same storage as the GroupedTensor. # Used as a convenience. 
@@ -145,7 +143,7 @@ def has_data(self) -> bool: Returns: True if data buffer is initialized, False otherwise """ - return self.data is not None + return self.rowwise_data is not None def has_columnwise_data(self) -> bool: """ @@ -239,14 +237,13 @@ def get_dtype(self) -> torch.dtype: The high precision dtype of the data buffer """ - return self.dtype + return self.fake_dtype def clear(self) -> None: """ Reset tensor data and clear all buffers. """ - self.shape = None - self.data = None + self.rowwise_data = None self.columnwise_data = None self.scale_inv = None self.columnwise_scale_inv = None @@ -263,49 +260,34 @@ def clear(self) -> None: self.offsets = None self.scale_inv_offsets = None self.columnwise_scale_inv_offsets = None + self.tensor_shapes = [] + self.fake_dtype = torch.float32 def __repr__(self) -> str: - """String representation of the GroupedTensor.""" + """String representation of the GroupedTensorStorage.""" return ( - f"GroupedTensor(num_tensors={self.num_tensors}, " - f"shape={self.shape}, " + f"GroupedTensorStorage(num_tensors={self.num_tensors}, " + f"shapes={self.tensor_shapes}, " f"logical_shape={self.logical_shape}, " + f"quantizer={self.quantizer}, " f"dtype={self.get_dtype()})" ) - def __str__(self) -> str: - """User-friendly string representation.""" - shape_info = [] - if self.all_same_shape(): - shape_info.append("uniform shape") - else: - if not self.all_same_first_dim(): - shape_info.append("varying first dim") - if not self.all_same_last_dim(): - shape_info.append("varying last dim") - - return ( - f"GroupedTensor with {self.num_tensors} tensors " - f"({', '.join(shape_info) if shape_info else 'uniform'}), " - f"logical_shape={self.logical_shape}, " - f"dtype={self.get_dtype()}" - ) - @staticmethod def make_grouped_tensor_with_shapes( num_tensors: int, - shape: List[Tuple[int, int]], + shapes: List[Tuple[int, int]], quantizer: Optional[Quantizer] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - ) -> 
GroupedTensor: + ) -> GroupedTensorStorage: """ Create a GroupedTensor for storing multiple weight tensors of the same shape. Args: num_tensors: Number of tensors - shape: 2D shape of each tensor (len num_tensors) - quantizer: Quantizer for each tensor + shapes: 2D shape of each tensor (len num_tensors) + quantizer: Quantizer used for all tensors device: Device to allocate tensors on, defaults to current cuda device dtype: Data type of the tensor (for high precision case) @@ -314,20 +296,20 @@ def make_grouped_tensor_with_shapes( """ # First dim - first_dim_list = [s[0] for s in shape] + first_dim_list = [s[0] for s in shapes] uniform_first_dim = all(first_dim_list[0] == x for x in first_dim_list) logical_first_dim = sum(first_dim_list) if uniform_first_dim: first_dims = None else: - first_dims = torch.tensor([s[0] for s in shape], dtype=torch.int64, device=device) + first_dims = torch.tensor([s[0] for s in shapes], dtype=torch.int64, device=device) # Last dim - last_dim_list = [s[1] for s in shape] + last_dim_list = [s[1] for s in shapes] logical_last_dim = last_dim_list[0] assert all(logical_last_dim == x for x in last_dim_list), "Last dims should be uniform" - return GroupedTensor.make_grouped_tensor( + return GroupedTensorStorage.make_grouped_tensor( num_tensors=num_tensors, first_dims=first_dims, last_dims=None, @@ -348,7 +330,7 @@ def make_grouped_tensor( quantizer: Optional[Quantizer] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - ) -> GroupedTensor: + ) -> GroupedTensorStorage: """ Create a GroupedTensor for storing multiple weight tensors of the same shape. @@ -358,8 +340,8 @@ def make_grouped_tensor( last_dims: Device tensor of int64 array of length num_tensors (or None if uniform) logical_first_dim: Logical first dimension logical_last_dim: Logical last dimension - quantizer: Quantizer for each tensor - Used to figure out the recipe and what to allocate. + quantizer: Quantizer used for all tensors. 
Used to figure out recipe + and what to allocate. device: Device to allocate tensors on, defaults to current cuda device dtype: Data type of the tensor (for high precision case) @@ -574,10 +556,22 @@ def make_grouped_tensor( else: raise ValueError(f"Unsupported quantizer for GroupedTensor: {quantizer}") - grouped_tensor = GroupedTensor( + # Construct wrapper vs storage based on quantizer.internal. + # If quantizer is None (high precision path), default to wrapper class. + # TODO(ksivaman): Properly handle high precision path. + internal = False if quantizer is None else quantizer.internal + if internal: + grouped_tensor_class = GroupedTensorStorage + else: + from ..grouped_tensor import GroupedTensor + + grouped_tensor_class = GroupedTensor + + grouped_tensor = grouped_tensor_class( + logical_shape, + dtype, num_tensors=num_tensors, - shape=shape, - dtype=dtype, + shapes=shape, quantizer=quantizer, data=data, columnwise_data=columnwise_data, @@ -592,7 +586,6 @@ def make_grouped_tensor( offsets=offsets, scale_inv_offsets=scale_inv_offsets, columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, - logical_shape=logical_shape, ) grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors() @@ -620,8 +613,8 @@ def split_into_quantized_tensors( no_quantization = self.quantizer is None - # if self.shape is None, then trigger D2H copy and get the shape (not graph safe) - if self.shape is None: + # if self.tensor_shapes is None, then trigger D2H copy and get the shape (not graph safe) + if self.tensor_shapes is None: first_dims_list = ( [self.logical_shape[0]] * self.num_tensors if self.first_dims is None @@ -635,7 +628,7 @@ def split_into_quantized_tensors( shape_list = [] for i in range(self.num_tensors): shape_list.append((first_dims_list[i], last_dims_list[i])) - self.shape = shape_list + self.tensor_shapes = shape_list # edge case: handle the case where tensor_offsets is given but offsets is not set if self.offsets is None and self.tensor_offsets 
is not None: @@ -645,7 +638,7 @@ def split_into_quantized_tensors( if no_quantization: for i in range(self.num_tensors): # Get tensor shape - tensor_shape = self.shape[i] + tensor_shape = self.tensor_shapes[i] # Get tensor data slice if self.offsets is not None: @@ -654,7 +647,7 @@ def split_into_quantized_tensors( end_offset = start_offset + numel if self.has_data(): - tensor_data = self.data[start_offset:end_offset].view(tensor_shape) + tensor_data = self.rowwise_data[start_offset:end_offset].view(tensor_shape) result.append(tensor_data) elif self.has_columnwise_data(): tensor_data = self.columnwise_data[start_offset:end_offset].view( @@ -670,7 +663,7 @@ def split_into_quantized_tensors( end_offset = start_offset + numel if self.has_data(): - tensor_data = self.data[start_offset:end_offset].view(tensor_shape) + tensor_data = self.rowwise_data[start_offset:end_offset].view(tensor_shape) result.append(tensor_data) elif self.has_columnwise_data(): tensor_data = self.columnwise_data[start_offset:end_offset].view( @@ -698,8 +691,9 @@ def split_into_quantized_tensors( self.columnwise_scale_inv_offsets = self.tensor_offsets // 32 for i in range(self.num_tensors): + quantizer = self.quantizer # Get tensor shape - tensor_shape = self.shape[i] + tensor_shape = self.tensor_shapes[i] numel = tensor_shape[0] * tensor_shape[1] # Get data offsets @@ -712,7 +706,7 @@ def split_into_quantized_tensors( data_end = data_start + numel # Special shape handling for NVFP4. 
- nvfp4 = self.quantizer._get_compatible_recipe().nvfp4() + nvfp4 = quantizer._get_compatible_recipe().nvfp4() if nvfp4: data_start = data_start // 2 data_end = data_end // 2 @@ -723,15 +717,15 @@ def split_into_quantized_tensors( if self.has_data(): if nvfp4: - rowwise_tensor_shape = self.quantizer.convert_shape_for_fp4(tensor_shape) + rowwise_tensor_shape = quantizer.convert_shape_for_fp4(tensor_shape) else: rowwise_tensor_shape = tensor_shape - rowwise_data = self.data[data_start:data_end].view(rowwise_tensor_shape) + rowwise_data = self.rowwise_data[data_start:data_end].view(rowwise_tensor_shape) if self.has_columnwise_data(): - columnwise_tensor_shape = self.quantizer.get_columnwise_shape(tensor_shape) + columnwise_tensor_shape = quantizer.get_columnwise_shape(tensor_shape) if nvfp4: - columnwise_tensor_shape = self.quantizer.convert_shape_for_fp4( + columnwise_tensor_shape = quantizer.convert_shape_for_fp4( columnwise_tensor_shape ) columnwise_data = self.columnwise_data[data_start:data_end].view( @@ -750,7 +744,7 @@ def split_into_quantized_tensors( scale_end = self.scale_inv_offsets[i + 1] # Calculate expected scale shape for MXFP8 - scale_shape = self.quantizer.get_scale_shape(tensor_shape, False) + scale_shape = quantizer.get_scale_shape(tensor_shape, False) rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) if ( @@ -761,25 +755,25 @@ def split_into_quantized_tensors( # for paged stashing, columnwise_scale_inv should depend on the split offsets cscale_end = self.columnwise_scale_inv_offsets[i + 1] - cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True) + cscale_shape = quantizer.get_scale_shape(tensor_shape, True) columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( cscale_shape ) - if self.quantizer.internal: + if quantizer.internal: mxfp8_tensor_class = MXFP8TensorStorage else: mxfp8_tensor_class = MXFP8Tensor tensor = mxfp8_tensor_class( shape=tensor_shape, - dtype=self.dtype, + 
dtype=self.fake_dtype, rowwise_data=rowwise_data, rowwise_scale_inv=rowwise_scale_inv, columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, - fp8_dtype=self.quantizer.dtype, - quantizer=self.quantizer, - with_gemm_swizzled_scales=self.quantizer.optimize_for_gemm, + fp8_dtype=quantizer.dtype, + quantizer=quantizer, + with_gemm_swizzled_scales=quantizer.optimize_for_gemm, ) result.append(tensor) @@ -790,18 +784,18 @@ def split_into_quantized_tensors( if self.scale_inv is not None: scale_inv = self.scale_inv[i : i + 1] - if self.quantizer.internal: + if quantizer.internal: float8_tensor_class = Float8TensorStorage else: float8_tensor_class = Float8Tensor tensor = float8_tensor_class( shape=tensor_shape, - dtype=self.dtype, + dtype=self.fake_dtype, data=rowwise_data, fp8_scale_inv=scale_inv, - fp8_dtype=self.quantizer.dtype, - quantizer=self.quantizer, + fp8_dtype=quantizer.dtype, + quantizer=quantizer, data_transpose=columnwise_data, ) result.append(tensor) @@ -818,7 +812,7 @@ def split_into_quantized_tensors( scale_end = self.scale_inv_offsets[i + 1] # Get scale shape from quantizer - scale_shape = self.quantizer.get_scale_shape(tensor_shape, False) + scale_shape = quantizer.get_scale_shape(tensor_shape, False) rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) if ( @@ -830,28 +824,28 @@ def split_into_quantized_tensors( cscale_end = self.columnwise_scale_inv_offsets[i + 1] # Get columnwise scale shape from quantizer - cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True) + cscale_shape = quantizer.get_scale_shape(tensor_shape, True) columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( cscale_shape ) # Compute is_2D_scaled and data_format from quantizer attributes - is_2D_scaled = self.quantizer.block_scaling_dim == 2 + is_2D_scaled = quantizer.block_scaling_dim == 2 - if self.quantizer.internal: + if quantizer.internal: float8_blockwise_q_tensor_class = 
Float8BlockwiseQTensorStorage else: float8_blockwise_q_tensor_class = Float8BlockwiseQTensor tensor = float8_blockwise_q_tensor_class( shape=tensor_shape, - dtype=self.dtype, + dtype=self.fake_dtype, rowwise_data=rowwise_data, rowwise_scale_inv=rowwise_scale_inv, columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, - fp8_dtype=self.quantizer.dtype, - quantizer=self.quantizer, + fp8_dtype=quantizer.dtype, + quantizer=quantizer, is_2D_scaled=is_2D_scaled, ) result.append(tensor) @@ -870,7 +864,7 @@ def split_into_quantized_tensors( scale_end = self.scale_inv_offsets[i + 1] # Get scale shape from quantizer - scale_shape = self.quantizer.get_scale_shape(tensor_shape, False) + scale_shape = quantizer.get_scale_shape(tensor_shape, False) rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) if ( @@ -882,7 +876,7 @@ def split_into_quantized_tensors( cscale_end = self.columnwise_scale_inv_offsets[i + 1] # Get columnwise scale shape from quantizer - cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True) + cscale_shape = quantizer.get_scale_shape(tensor_shape, True) columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( cscale_shape ) @@ -894,23 +888,23 @@ def split_into_quantized_tensors( if self.columnwise_amax is not None: amax_columnwise = self.columnwise_amax[i : i + 1] - if self.quantizer.internal: + if quantizer.internal: nvfp4_tensor_class = NVFP4TensorStorage else: nvfp4_tensor_class = NVFP4Tensor tensor = nvfp4_tensor_class( shape=tensor_shape, - dtype=self.dtype, + dtype=self.fake_dtype, rowwise_data=rowwise_data, rowwise_scale_inv=rowwise_scale_inv, columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, amax_rowwise=amax_rowwise, amax_columnwise=amax_columnwise, - fp4_dtype=self.quantizer.dtype, - quantizer=self.quantizer, - with_gemm_swizzled_scales=self.quantizer.optimize_for_gemm, + fp4_dtype=quantizer.dtype, + quantizer=quantizer, + 
with_gemm_swizzled_scales=quantizer.optimize_for_gemm, ) result.append(tensor) @@ -919,32 +913,6 @@ def split_into_quantized_tensors( return result - @staticmethod - def create_and_quantize( - tensors: int, - quantizer: None | Quantizer, - *, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - noop_flag: Optional[torch.Tensor] = None, - ) -> Tuple[QuantizedTensorStorage, ...]: - """ - Quantize given tensors into quantized tensors with underlying - storage allocated in a GroupedTensor. - """ - - grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( - num_tensors=len(tensors), - shape=[t.shape for t in tensors], - quantizer=quantizer, - device=device, - dtype=dtype, - ) - - grouped_tensor.quantize(tensors, noop_flag=noop_flag) - - return grouped_tensor - def quantize( self, tensors: List[torch.Tensor],