From c9c8b4a8c966d1376ce83502436837a80902333e Mon Sep 17 00:00:00 2001 From: MithunR Date: Wed, 11 Mar 2026 13:33:23 -0700 Subject: [PATCH] BloomFilter v2 support Fixes #14148. This commit adds support for the new BloomFilter v2 format that was added in Apache Spark 4.1.1 (via https://github.com/apache/spark/commit/a08d8b093c0ec09cc6ce2c3642502f4842aebd86). The v1 format used INT32s for bit index calculation. When the number of items in the bloom-filter approaches INT_MAX, one sees a higher rate of collisions. The v2 format uses INT64 values for bit index calculations, allowing the full bit space to be addressed. This reduces the false-positive rate for large filters. Before the fix in this PR was applied to spark-rapids, certain bloom filter join tests failed against Apache Spark 4.1.1; specifically: 1. `test_bloom_filter_join_cpu_build`, where the bloom filter is built on CPU and then probed on GPU. This failed because the CPU would produce a v2 filter that couldn't be treated as a v1 format on GPU. 2. `test_bloom_filter_join_split_cpu_build`, where the bloom filter is partially aggregated on CPU, then merged on GPU. Again, the GPU-side merging expected v1 format, while the CPU produced v2. Note that `test_bloom_filter_join_cpu_probe` and `test_bloom_filter_join` did not actually fail on 4.1.1. That is because: 1. `test_bloom_filter_join_cpu_probe` tests CPU probing, which supports v1 and v2 flexibly. 2. `test_bloom_filter_join` tests the build and probe running together, either both on CPU or both on GPU. The CPU ran the v2 format, the GPU ran v1. Both produce the same query results, albeit with different formats. The fix in this commit allows v1 and v2 formats to be jointly supported on GPU, depending on the Spark version. 
Signed-off-by: MithunR --- .../src/main/python/join_test.py | 2 - .../nvidia/spark/rapids/GpuBloomFilter.scala | 57 ++++++++++++++----- .../shims/BloomFilterConstantsShims.scala | 49 ++++++++++++++++ .../spark/rapids/shims/BloomFilterShims.scala | 5 +- .../aggregate/GpuBloomFilterAggregate.scala | 15 +++-- .../shims/BloomFilterConstantsShims.scala | 24 ++++++++ .../BloomFilterAggregateQuerySuite.scala | 17 +++++- 7 files changed, 148 insertions(+), 21 deletions(-) create mode 100644 sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/BloomFilterConstantsShims.scala create mode 100644 sql-plugin/src/main/spark411/scala/com/nvidia/spark/rapids/shims/BloomFilterConstantsShims.scala diff --git a/integration_tests/src/main/python/join_test.py b/integration_tests/src/main/python/join_test.py index 5225a2082b5..bd2e357eb4d 100644 --- a/integration_tests/src/main/python/join_test.py +++ b/integration_tests/src/main/python/join_test.py @@ -1502,7 +1502,6 @@ def test_bloom_filter_join_cpu_probe(is_multi_column, kudo_enabled): @pytest.mark.parametrize("is_multi_column", [False, True], ids=idfn) @pytest.mark.skipif(is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/8921") @pytest.mark.skipif(is_before_spark_330(), reason="Bloom filter joins added in Spark 3.3.0") -@pytest.mark.xfail(condition=is_spark_411_or_later(), reason="https://github.com/NVIDIA/spark-rapids/issues/14148") @pytest.mark.parametrize("kudo_enabled", ["true", "false"], ids=idfn) def test_bloom_filter_join_cpu_build(is_multi_column, kudo_enabled): conf = {"spark.rapids.sql.expression.BloomFilterAggregate": "false", @@ -1517,7 +1516,6 @@ def test_bloom_filter_join_cpu_build(is_multi_column, kudo_enabled): @pytest.mark.parametrize("is_multi_column", [False, True], ids=idfn) @pytest.mark.skipif(is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/8921") @pytest.mark.skipif(is_before_spark_330(), reason="Bloom filter joins added in Spark 3.3.0") 
-@pytest.mark.xfail(condition=is_spark_411_or_later(), reason="https://github.com/NVIDIA/spark-rapids/issues/14148") @pytest.mark.parametrize("kudo_enabled", ["true", "false"], ids=idfn) def test_bloom_filter_join_split_cpu_build(agg_replace_mode, is_multi_column, kudo_enabled): conf = {"spark.rapids.sql.hashAgg.replaceMode": agg_replace_mode, diff --git a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/GpuBloomFilter.scala b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/GpuBloomFilter.scala index 74a37d875be..0dac6a4f79c 100644 --- a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/GpuBloomFilter.scala +++ b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/GpuBloomFilter.scala @@ -45,7 +45,8 @@ spark-rapids-shim-json-lines ***/ package com.nvidia.spark.rapids -import ai.rapids.cudf.{BaseDeviceMemoryBuffer, ColumnVector, Cuda, DeviceMemoryBuffer, DType} +import ai.rapids.cudf.{BaseDeviceMemoryBuffer, ColumnVector, Cuda, DeviceMemoryBuffer, DType, + HostMemoryBuffer} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.jni.BloomFilter @@ -78,14 +79,33 @@ class GpuBloomFilter(buffer: DeviceMemoryBuffer) extends AutoCloseable { } object GpuBloomFilter { - // Spark serializes their bloom filters in a specific format, see BloomFilterImpl.readFrom. - // Data is written via DataOutputStream, so everything is big-endian. - // Byte Offset Size Description - // 0 4 Version ID (see Spark's BloomFilter.Version) - // 4 4 Number of hash functions - // 8 4 Number of longs, N - // 12 N*8 Bloom filter data buffer as longs - private val HEADER_SIZE = 12 + // Spark serializes bloom filters in one of two formats. All values are big-endian. 
+ // + // V1 (12-byte header): + // Byte Offset Size Description + // 0 4 Version ID (1) + // 4 4 Number of hash functions + // 8 4 Number of longs, N + // 12 N*8 Bloom filter data buffer as longs + // + // V2 (16-byte header): + // Byte Offset Size Description + // 0 4 Version ID (2) + // 4 4 Number of hash functions + // 8 4 Hash seed + // 12 4 Number of longs, N + // 16 N*8 Bloom filter data buffer as longs + private val HEADER_SIZE_V1 = 12 + private val HEADER_SIZE_V2 = 16 + + private def readVersionFromDevice(data: BaseDeviceMemoryBuffer): Int = { + withResource(data.sliceWithCopy(0, 4)) { versionSlice => + withResource(HostMemoryBuffer.allocate(4)) { hostBuf => + hostBuf.copyFromDeviceBuffer(versionSlice) + Integer.reverseBytes(hostBuf.getInt(0)) + } + } + } def apply(s: GpuScalar): GpuBloomFilter = { s.dataType match { @@ -100,11 +120,22 @@ object GpuBloomFilter { } def deserialize(data: BaseDeviceMemoryBuffer): GpuBloomFilter = { - // Sanity check bloom filter header val totalLen = data.getLength - val bitBufferLen = totalLen - HEADER_SIZE - require(totalLen >= HEADER_SIZE, s"header size is $totalLen") - require(bitBufferLen % 8 == 0, "buffer length not a multiple of 8") + require(totalLen >= HEADER_SIZE_V1, s"buffer too small: $totalLen") + + val version = readVersionFromDevice(data) + val headerSize = version match { + case 1 => HEADER_SIZE_V1 + case 2 => HEADER_SIZE_V2 + case _ => throw new IllegalArgumentException( + s"Unknown bloom filter version: $version") + } + require(totalLen >= headerSize, + s"buffer too small for bloom filter V$version: $totalLen") + val bitBufferLen = totalLen - headerSize + require(bitBufferLen % 8 == 0, + s"bit buffer length ($bitBufferLen) not a multiple of 8") + val filterBuffer = DeviceMemoryBuffer.allocate(totalLen) closeOnExcept(filterBuffer) { buf => buf.copyFromDeviceBufferAsync(0, data, 0, buf.getLength, Cuda.DEFAULT_STREAM) diff --git 
a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/BloomFilterConstantsShims.scala b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/BloomFilterConstantsShims.scala new file mode 100644 index 00000000000..3332ba7d8f4 --- /dev/null +++ b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/BloomFilterConstantsShims.scala @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +{"spark": "330db"} +{"spark": "331"} +{"spark": "332"} +{"spark": "332db"} +{"spark": "333"} +{"spark": "334"} +{"spark": "340"} +{"spark": "341"} +{"spark": "341db"} +{"spark": "342"} +{"spark": "343"} +{"spark": "344"} +{"spark": "350"} +{"spark": "350db143"} +{"spark": "351"} +{"spark": "352"} +{"spark": "353"} +{"spark": "354"} +{"spark": "355"} +{"spark": "356"} +{"spark": "357"} +{"spark": "358"} +{"spark": "400"} +{"spark": "401"} +{"spark": "402"} +spark-rapids-shim-json-lines ***/ +package com.nvidia.spark.rapids.shims + +object BloomFilterConstantsShims { + val BLOOM_FILTER_FORMAT_VERSION: Int = 1 +} \ No newline at end of file diff --git a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/BloomFilterShims.scala b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/BloomFilterShims.scala index 7590a075e89..707d83fc608 100644 --- 
a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/BloomFilterShims.scala +++ b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/BloomFilterShims.scala @@ -46,6 +46,7 @@ spark-rapids-shim-json-lines ***/ package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.jni.BloomFilter import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate @@ -80,7 +81,9 @@ object BloomFilterShims { GpuBloomFilterAggregate( childExprs.head.convertToGpu(), a.estimatedNumItemsExpression.eval().asInstanceOf[Number].longValue, - a.numBitsExpression.eval().asInstanceOf[Number].longValue) + a.numBitsExpression.eval().asInstanceOf[Number].longValue, + BloomFilterConstantsShims.BLOOM_FILTER_FORMAT_VERSION, + BloomFilter.DEFAULT_SEED) } }) ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap diff --git a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/aggregate/GpuBloomFilterAggregate.scala b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/aggregate/GpuBloomFilterAggregate.scala index 2e0cab83747..05ea339a6b1 100644 --- a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/aggregate/GpuBloomFilterAggregate.scala +++ b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/aggregate/GpuBloomFilterAggregate.scala @@ -59,7 +59,9 @@ import org.apache.spark.sql.types.{BinaryType, DataType} case class GpuBloomFilterAggregate( child: Expression, estimatedNumItemsRequested: Long, - numBitsRequested: Long) extends GpuAggregateFunction { + numBitsRequested: Long, + version: Int = BloomFilter.VERSION_2, + seed: Int = BloomFilter.DEFAULT_SEED) extends GpuAggregateFunction { override def nullable: Boolean = true @@ -81,7 +83,8 @@ case class GpuBloomFilterAggregate( override val inputProjection: Seq[Expression] = Seq(child) - override val updateAggregates: Seq[CudfAggregate] = 
Seq(GpuBloomFilterUpdate(numHashes, numBits)) + override val updateAggregates: Seq[CudfAggregate] = + Seq(GpuBloomFilterUpdate(numHashes, numBits, version, seed)) override val mergeAggregates: Seq[CudfAggregate] = Seq(GpuBloomFilterMerge()) @@ -110,9 +113,13 @@ object GpuBloomFilterAggregate { } } -case class GpuBloomFilterUpdate(numHashes: Int, numBits: Long) extends CudfAggregate { +case class GpuBloomFilterUpdate( + numHashes: Int, + numBits: Long, + version: Int, + seed: Int) extends CudfAggregate { override val reductionAggregate: ColumnVector => Scalar = (col: ColumnVector) => { - closeOnExcept(BloomFilter.create(numHashes, numBits)) { bloomFilter => + closeOnExcept(BloomFilter.create(version, numHashes, numBits, seed)) { bloomFilter => BloomFilter.put(bloomFilter, col) bloomFilter } diff --git a/sql-plugin/src/main/spark411/scala/com/nvidia/spark/rapids/shims/BloomFilterConstantsShims.scala b/sql-plugin/src/main/spark411/scala/com/nvidia/spark/rapids/shims/BloomFilterConstantsShims.scala new file mode 100644 index 00000000000..66b12a915f6 --- /dev/null +++ b/sql-plugin/src/main/spark411/scala/com/nvidia/spark/rapids/shims/BloomFilterConstantsShims.scala @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "411"} +spark-rapids-shim-json-lines ***/ +package com.nvidia.spark.rapids.shims + +object BloomFilterConstantsShims { + val BLOOM_FILTER_FORMAT_VERSION: Int = 2 +} \ No newline at end of file diff --git a/tests/src/test/spark330/scala/com/nvidia/spark/rapids/BloomFilterAggregateQuerySuite.scala b/tests/src/test/spark330/scala/com/nvidia/spark/rapids/BloomFilterAggregateQuerySuite.scala index b13d8266333..3936407e406 100644 --- a/tests/src/test/spark330/scala/com/nvidia/spark/rapids/BloomFilterAggregateQuerySuite.scala +++ b/tests/src/test/spark330/scala/com/nvidia/spark/rapids/BloomFilterAggregateQuerySuite.scala @@ -177,8 +177,9 @@ class BloomFilterAggregateQuerySuite extends SparkQueryCompareTestSuite { } } + // V1 literal: version=1, numHashes=5, numLongs=3, followed by 3 longs of bit data testSparkResultsAreEqual( - "might_contain with literal bloom filter buffer", + "might_contain with V1 literal bloom filter buffer", spark => spark.range(1, 1).asInstanceOf[DataFrame], conf=bloomFilterEnabledConf.clone()) { df => @@ -190,6 +191,20 @@ class BloomFilterAggregateQuerySuite extends SparkQueryCompareTestSuite { } } + // V2 literal: version=2, numHashes=5, seed=0, numLongs=3, followed by 3 longs of bit data + testSparkResultsAreEqual( + "might_contain with V2 literal bloom filter buffer", + spark => spark.range(1, 1).asInstanceOf[DataFrame], + conf=bloomFilterEnabledConf.clone()) { + df => + withExposedSqlFuncs(df.sparkSession) { spark => + spark.sql( + """SELECT might_contain( + |X'0000000200000005000000000000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267', + |cast(201 as long))""".stripMargin) + } + } + testSparkResultsAreEqual( "might_contain with all NULL inputs", spark => spark.range(1, 1).asInstanceOf[DataFrame],