From 2ab55570ec443419c2c540ebe157837f02961fb3 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 19 Dec 2025 17:24:22 +0800 Subject: [PATCH 1/9] AI draft for protocol buffer support Signed-off-by: Haoyang Li --- integration_tests/pom.xml | 20 ++ integration_tests/run_pyspark_from_build.sh | 23 +- integration_tests/src/main/python/data_gen.py | 110 ++++++++ .../src/main/python/protobuf_test.py | 229 +++++++++++++++++ pom.xml | 11 + .../protobuf/ProtobufDescriptorUtils.scala | 82 ++++++ .../sql/rapids/GpuFromProtobufSimple.scala | 79 ++++++ .../rapids/shims/ProtobufExprShims.scala | 235 ++++++++++++++++++ .../rapids/shims/Spark340PlusNonDBShims.scala | 2 +- 9 files changed, 788 insertions(+), 3 deletions(-) create mode 100644 integration_tests/src/main/python/protobuf_test.py create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobufSimple.scala create mode 100644 sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala diff --git a/integration_tests/pom.xml b/integration_tests/pom.xml index e3d91be0ce3..825083b7fbe 100644 --- a/integration_tests/pom.xml +++ b/integration_tests/pom.xml @@ -142,6 +142,7 @@ parquet-hadoop*.jar spark-avro*.jar + spark-protobuf*.jar @@ -176,6 +177,24 @@ + + copy-spark-protobuf + package + + copy + + + ${spark.protobuf.copy.skip} + true + + + org.apache.spark + spark-protobuf_${scala.binary.version} + ${spark.version} + + + + @@ -216,4 +235,5 @@ + diff --git a/integration_tests/run_pyspark_from_build.sh b/integration_tests/run_pyspark_from_build.sh index 6550a3cc59f..baf04d44282 100755 --- a/integration_tests/run_pyspark_from_build.sh +++ b/integration_tests/run_pyspark_from_build.sh @@ -29,6 +29,7 @@ # - SPARK_HOME: Path to your Apache Spark installation. # - SKIP_TESTS: If set to true, skips running the Python integration tests. # - INCLUDE_SPARK_AVRO_JAR: If set to true, includes Avro tests. +# - INCLUDE_SPARK_PROTOBUF_JAR: If set to true, includes spark-protobuf (Spark 3.4.0+) on the JVM classpath. # - TEST: Specifies a specific test to run. # - TEST_TAGS: Allows filtering tests based on tags. # - TEST_TYPE: Specifies the type of tests to run. @@ -100,6 +101,7 @@ else # support alternate local jars NOT building from the source code if [ -d "$LOCAL_JAR_PATH" ]; then AVRO_JARS=$(echo "$LOCAL_JAR_PATH"/spark-avro*.jar) + PROTOBUF_JARS=$(echo "$LOCAL_JAR_PATH"/spark-protobuf*.jar) PLUGIN_JAR=$(echo "$LOCAL_JAR_PATH"/rapids-4-spark_*.jar) if [ -f $(echo $LOCAL_JAR_PATH/parquet-hadoop*.jar) ]; then export INCLUDE_PARQUET_HADOOP_TEST_JAR=true @@ -116,6 +118,7 @@ else else [[ "$SCALA_VERSION" != "2.12" ]] && TARGET_DIR=${TARGET_DIR/integration_tests/scala$SCALA_VERSION\/integration_tests} AVRO_JARS=$(echo "$TARGET_DIR"/dependency/spark-avro*.jar) + PROTOBUF_JARS=$(echo "$TARGET_DIR"/dependency/spark-protobuf*.jar) PARQUET_HADOOP_TESTS=$(echo "$TARGET_DIR"/dependency/parquet-hadoop*.jar) # remove the log4j.properties file so it doesn't conflict with ours, ignore errors # if it isn't present or already removed @@ -141,9 +144,25 @@ else AVRO_JARS="" fi - # ALL_JARS includes dist.jar integration-test.jar avro.jar parquet.jar if they exist + # spark-protobuf is an optional Spark module that exists in Spark 3.4.0+. If we have the jar staged + # under target/dependency, include it so from_protobuf() is callable from PySpark. + if [[ $( echo ${INCLUDE_SPARK_PROTOBUF_JAR:-true} | tr '[:upper:]' '[:lower:]' ) == "true" ]]; + then + # VERSION_STRING >= 3.4.0 ? + if printf '%s\n' "3.4.0" "$VERSION_STRING" | sort -V | head -1 | grep -qx "3.4.0"; then + export INCLUDE_SPARK_PROTOBUF_JAR=true + else + export INCLUDE_SPARK_PROTOBUF_JAR=false + PROTOBUF_JARS="" + fi + else + export INCLUDE_SPARK_PROTOBUF_JAR=false + PROTOBUF_JARS="" + fi + + # ALL_JARS includes dist.jar integration-test.jar avro.jar protobuf.jar parquet.jar if they exist # Remove non-existing paths and canonicalize the paths including get rid of links and `..` - ALL_JARS=$(readlink -e $PLUGIN_JAR $TEST_JARS $AVRO_JARS $PARQUET_HADOOP_TESTS || true) + ALL_JARS=$(readlink -e $PLUGIN_JAR $TEST_JARS $AVRO_JARS $PROTOBUF_JARS $PARQUET_HADOOP_TESTS || true) # `:` separated jars ALL_JARS="${ALL_JARS//$'\n'/:}" diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py index fa7decac82d..837d4990832 100644 --- a/integration_tests/src/main/python/data_gen.py +++ b/integration_tests/src/main/python/data_gen.py @@ -857,6 +857,116 @@ def gen_bytes(): return bytes([ rand.randint(0, 255) for _ in range(length) ]) self._start(rand, gen_bytes) + +# ----------------------------------------------------------------------------- +# Protobuf (simple types) generators/utilities (for from_protobuf/to_protobuf tests) +# ----------------------------------------------------------------------------- + +_PROTOBUF_WIRE_VARINT = 0 +_PROTOBUF_WIRE_64BIT = 1 +_PROTOBUF_WIRE_LEN_DELIM = 2 +_PROTOBUF_WIRE_32BIT = 5 + +def _encode_protobuf_uvarint(value): + """Encode a non-negative integer as protobuf varint.""" + if value is None: + raise ValueError("value must not be None") + if value < 0: + raise ValueError("uvarint only supports non-negative integers") + out = bytearray() + v = int(value) + while True: + b = v & 0x7F + v >>= 7 + if v: + out.append(b | 0x80) + else: + out.append(b) + break + return bytes(out) + +def _encode_protobuf_key(field_number, wire_type): + return _encode_protobuf_uvarint((int(field_number) << 3) | int(wire_type)) + +def _encode_protobuf_field(field_number, spark_type, value): + """ + Encode a single protobuf field for a subset of scalar types. + Notes on signed ints: + - Protobuf `int32`/`int64` use *varint* encoding of the two's-complement integer. + - Negative `int32` values are encoded as a 10-byte varint (because they are sign-extended to 64 bits). + """ + if value is None: + return b"" + + if isinstance(spark_type, BooleanType): + return _encode_protobuf_key(field_number, _PROTOBUF_WIRE_VARINT) + _encode_protobuf_uvarint(1 if value else 0) + elif isinstance(spark_type, IntegerType): + # Match protobuf-java behavior for writeInt32NoTag: negative values are sign-extended and written as uint64. + u64 = int(value) & 0xFFFFFFFFFFFFFFFF + return _encode_protobuf_key(field_number, _PROTOBUF_WIRE_VARINT) + _encode_protobuf_uvarint(u64) + elif isinstance(spark_type, LongType): + u64 = int(value) & 0xFFFFFFFFFFFFFFFF + return _encode_protobuf_key(field_number, _PROTOBUF_WIRE_VARINT) + _encode_protobuf_uvarint(u64) + elif isinstance(spark_type, FloatType): + return _encode_protobuf_key(field_number, _PROTOBUF_WIRE_32BIT) + struct.pack(" bool: + """ + `spark-protobuf` is an optional external module. PySpark may have the Python wrappers + even when the JVM side isn't present on the classpath, which manifests as: + TypeError: 'JavaPackage' object is not callable + when calling into `sc._jvm.org.apache.spark.sql.protobuf.functions.from_protobuf`. + """ + jvm = spark.sparkContext._jvm + candidates = [ + # Scala object `functions` compiles to `functions$` + "org.apache.spark.sql.protobuf.functions$", + # Some environments may expose it differently + "org.apache.spark.sql.protobuf.functions", + ] + for cls in candidates: + try: + jvm.java.lang.Class.forName(cls) + return True + except Exception: + continue + return False + + +def _build_simple_descriptor_set_bytes(spark): + """ + Build a FileDescriptorSet for: + package test; + syntax = "proto2"; + message Simple { + optional bool b = 1; + optional int32 i32 = 2; + optional int64 i64 = 3; + optional float f32 = 4; + optional double f64 = 5; + optional string s = 6; + } + """ + jvm = spark.sparkContext._jvm + D = jvm.com.google.protobuf.DescriptorProtos + + fd = D.FileDescriptorProto.newBuilder() \ + .setName("simple.proto") \ + .setPackage("test") + # Some Spark distributions bring an older protobuf-java where FileDescriptorProto.Builder + # does not expose setSyntax(String). For this test we only need proto2 semantics, and + # leaving syntax unset is sufficient/compatible. + try: + fd = fd.setSyntax("proto2") + except Exception: + pass + + msg = D.DescriptorProto.newBuilder().setName("Simple") + label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL + + def add_field(name, number, ftype): + msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName(name) + .setNumber(number) + .setLabel(label_opt) + .setType(ftype) + .build() + ) + + add_field("b", 1, D.FieldDescriptorProto.Type.TYPE_BOOL) + add_field("i32", 2, D.FieldDescriptorProto.Type.TYPE_INT32) + add_field("i64", 3, D.FieldDescriptorProto.Type.TYPE_INT64) + add_field("f32", 4, D.FieldDescriptorProto.Type.TYPE_FLOAT) + add_field("f64", 5, D.FieldDescriptorProto.Type.TYPE_DOUBLE) + add_field("s", 6, D.FieldDescriptorProto.Type.TYPE_STRING) + + fd.addMessageType(msg.build()) + + fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build() + # py4j converts Java byte[] to a Python bytes-like object + return bytes(fds.toByteArray()) + + +def _write_bytes_to_hadoop_path(spark, path_str, data_bytes): + sc = spark.sparkContext + config = sc._jsc.hadoopConfiguration() + jpath = sc._jvm.org.apache.hadoop.fs.Path(path_str) + fs = sc._jvm.org.apache.hadoop.fs.FileSystem.get(config) + out = fs.create(jpath, True) + try: + out.write(bytearray(data_bytes)) + finally: + out.close() + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_simple_parquet_binary_round_trip(spark_tmp_path): + from_protobuf = _try_import_from_protobuf() + # if from_protobuf is None: + # pytest.skip("pyspark.sql.protobuf.functions.from_protobuf is not available") + # if not with_cpu_session(lambda spark: _spark_protobuf_jvm_available(spark)): + # pytest.skip("spark-protobuf JVM module is not available on the classpath") + + data_path = spark_tmp_path + "/PROTOBUF_SIMPLE_PARQUET/" + desc_path = spark_tmp_path + "/simple.desc" + message_name = "test.Simple" + + # Generate descriptor bytes once using the JVM (no protoc dependency) + desc_bytes = with_cpu_session(lambda spark: _build_simple_descriptor_set_bytes(spark)) + with_cpu_session(lambda spark: _write_bytes_to_hadoop_path(spark, desc_path, desc_bytes)) + + # Build a DF with scalar columns + binary protobuf column and write to parquet + row_gen = ProtobufSimpleMessageRowGen([ + ("b", 1, BooleanGen(nullable=True)), + ("i32", 2, IntegerGen(nullable=True, min_val=0, max_val=1 << 20)), + ("i64", 3, LongGen(nullable=True, min_val=0, max_val=1 << 40, special_cases=[])), + ("f32", 4, FloatGen(nullable=True, no_nans=True)), + ("f64", 5, DoubleGen(nullable=True, no_nans=True)), + ("s", 6, StringGen(nullable=True)), + ], binary_col_name="bin") + + def write_parquet(spark): + df = gen_df(spark, row_gen, length=512) + df.write.mode("overwrite").parquet(data_path) + + with_cpu_session(write_parquet) + + # Sanity check correctness on CPU (decoded struct matches the original scalar columns) + def cpu_correctness_check(spark): + df = spark.read.parquet(data_path) + expected = f.struct( + f.col("b").alias("b"), + f.col("i32").alias("i32"), + f.col("i64").alias("i64"), + f.col("f32").alias("f32"), + f.col("f64").alias("f64"), + f.col("s").alias("s"), + ).alias("expected") + + sig = inspect.signature(from_protobuf) + if "binaryDescriptorSet" in sig.parameters: + decoded = from_protobuf(f.col("bin"), message_name, binaryDescriptorSet=bytearray(desc_bytes)).alias("decoded") + else: + decoded = from_protobuf(f.col("bin"), message_name, desc_path).alias("decoded") + + rows = df.select(expected, decoded).collect() + for r in rows: + assert r["expected"] == r["decoded"] + + with_cpu_session(cpu_correctness_check) + + # Main assertion: CPU and GPU results match for from_protobuf on a binary column read from parquet + def run_on_spark(spark): + df = spark.read.parquet(data_path) + sig = inspect.signature(from_protobuf) + if "binaryDescriptorSet" in sig.parameters: + decoded = from_protobuf(f.col("bin"), message_name, binaryDescriptorSet=bytearray(desc_bytes)) + else: + decoded = from_protobuf(f.col("bin"), message_name, desc_path) + return df.select(decoded.alias("decoded")) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_simple_null_input_returns_null(spark_tmp_path): + from_protobuf = _try_import_from_protobuf() + desc_path = spark_tmp_path + "/simple_null_input.desc" + message_name = "test.Simple" + + # Generate descriptor bytes once using the JVM (no protoc dependency) + desc_bytes = with_cpu_session(lambda spark: _build_simple_descriptor_set_bytes(spark)) + with_cpu_session(lambda spark: _write_bytes_to_hadoop_path(spark, desc_path, desc_bytes)) + + # Spark's ProtobufDataToCatalyst is NullIntolerant (null input -> null output). + def run_on_spark(spark): + df = spark.createDataFrame( + [(None,), (bytes([0x08, 0x01, 0x10, 0x7B]),)], # b=true, i32=123 + schema="bin binary", + ) + sig = inspect.signature(from_protobuf) + if "binaryDescriptorSet" in sig.parameters: + decoded = from_protobuf( + f.col("bin"), + message_name, + binaryDescriptorSet=bytearray(desc_bytes), + ) + else: + decoded = from_protobuf(f.col("bin"), message_name, desc_path) + return df.select(decoded.alias("decoded")) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + diff --git a/pom.xml b/pom.xml index 6eeff9d35be..8679b7ddf7e 100644 --- a/pom.xml +++ b/pom.xml @@ -318,6 +318,7 @@ 1.12.3 rapids-4-spark-delta-24x 2.0.6 + false delta-lake/delta-24x @@ -338,6 +339,7 @@ 1.12.3 rapids-4-spark-delta-24x 2.0.6 + false delta-lake/delta-24x @@ -358,6 +360,7 @@ 1.12.3 rapids-4-spark-delta-24x 2.0.6 + false delta-lake/delta-24x @@ -378,6 +381,7 @@ 1.12.3 rapids-4-spark-delta-24x 2.0.6 + false delta-lake/delta-24x @@ -398,6 +402,7 @@ 1.12.3 rapids-4-spark-delta-24x 2.0.6 + false delta-lake/delta-24x @@ -895,6 +900,12 @@ developer false + + + true diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala new file mode 100644 index 00000000000..f40cc2af03f --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.protobuf + +import scala.collection.mutable +import scala.collection.JavaConverters._ + +import com.google.protobuf.DescriptorProtos +import com.google.protobuf.Descriptors + +/** + * Minimal descriptor utilities for locating a message descriptor in a FileDescriptorSet. + * + * This is intentionally lightweight for the "simple types" from_protobuf patch: it supports + * descriptor sets produced by `protoc --include_imports --descriptor_set_out=...`. + */ +object ProtobufDescriptorUtils { + + def buildMessageDescriptor( + fileDescriptorSetBytes: Array[Byte], + messageName: String): Descriptors.Descriptor = { + val fds = DescriptorProtos.FileDescriptorSet.parseFrom(fileDescriptorSetBytes) + val protos = fds.getFileList.asScala.toSeq + val byName = protos.map(p => p.getName -> p).toMap + val cache = mutable.HashMap.empty[String, Descriptors.FileDescriptor] + + def buildFileDescriptor(name: String): Descriptors.FileDescriptor = { + cache.getOrElseUpdate(name, { + val p = byName.getOrElse(name, + throw new IllegalArgumentException(s"Missing FileDescriptorProto for '$name'")) + val deps = p.getDependencyList.asScala.map(buildFileDescriptor _).toArray + Descriptors.FileDescriptor.buildFrom(p, deps) + }) + } + + val fileDescriptors = protos.map(p => buildFileDescriptor(p.getName)) + val candidates = fileDescriptors.iterator.flatMap(fd => findMessageDescriptors(fd, messageName)) + .toSeq + + candidates match { + case Seq(d) => d + case Seq() => + throw new IllegalArgumentException( + s"Message '$messageName' not found in FileDescriptorSet") + case many => + val names = many.map(_.getFullName).distinct.sorted + throw new IllegalArgumentException( + s"Message '$messageName' is ambiguous; matches: ${names.mkString(", ")}") + } + } + + private def findMessageDescriptors( + fd: Descriptors.FileDescriptor, + messageName: String): Iterator[Descriptors.Descriptor] = { + def matches(d: Descriptors.Descriptor): Boolean = { + d.getName == messageName || d.getFullName == messageName || d.getFullName.endsWith("." + messageName) + } + + def walk(d: Descriptors.Descriptor): Iterator[Descriptors.Descriptor] = { + val nested = d.getNestedTypes.asScala.iterator.flatMap(walk _) + if (matches(d)) Iterator.single(d) ++ nested else nested + } + + fd.getMessageTypes.asScala.iterator.flatMap(walk _) + } +} + + diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobufSimple.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobufSimple.scala new file mode 100644 index 00000000000..73c23fe2f82 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobufSimple.scala @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids + +import ai.rapids.cudf +import ai.rapids.cudf.BinaryOp +import ai.rapids.cudf.DType +import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.{GpuColumnVector, GpuUnaryExpression} +import com.nvidia.spark.rapids.jni.ProtobufSimple +import com.nvidia.spark.rapids.shims.NullIntolerantShim + +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression} +import org.apache.spark.sql.types._ + +/** + * GPU implementation for Spark's `from_protobuf` decode path (simple types only). + * + * This is designed to replace `org.apache.spark.sql.protobuf.ProtobufDataToCatalyst` when supported. + */ +case class GpuFromProtobufSimple( + outputSchema: StructType, + fieldNumbers: Array[Int], + cudfTypeIds: Array[Int], + cudfTypeScales: Array[Int], + child: Expression) + extends GpuUnaryExpression with ExpectsInputTypes with NullIntolerantShim { + + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) + + override def dataType: DataType = outputSchema.asNullable + + override def nullable: Boolean = true + + override protected def doColumnar(input: GpuColumnVector): cudf.ColumnVector = { + // Spark BinaryType is represented in cuDF as a LIST. + // ProtobufSimple returns a non-null STRUCT with nullable children. Spark's + // ProtobufDataToCatalyst is NullIntolerant, so if the input binary row is null the output + // struct row must be null as well. + val decoded = ProtobufSimple.decodeToStruct(input.getBase, fieldNumbers, cudfTypeIds, cudfTypeScales) + if (input.getBase.hasNulls) { + withResource(decoded) { _ => + decoded.mergeAndSetValidity(BinaryOp.BITWISE_AND, input.getBase) + } + } else { + decoded + } + } +} + +object GpuFromProtobufSimple { + def sparkTypeToCudfId(dt: DataType): (Int, Int) = dt match { + case BooleanType => (DType.BOOL8.getTypeId.getNativeId, 0) + case IntegerType => (DType.INT32.getTypeId.getNativeId, 0) + case LongType => (DType.INT64.getTypeId.getNativeId, 0) + case FloatType => (DType.FLOAT32.getTypeId.getNativeId, 0) + case DoubleType => (DType.FLOAT64.getTypeId.getNativeId, 0) + case StringType => (DType.STRING.getTypeId.getNativeId, 0) + case other => + throw new IllegalArgumentException(s"Unsupported Spark type for protobuf(simple): $other") + } +} + + + diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala new file mode 100644 index 00000000000..a75dda64b14 --- /dev/null +++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala @@ -0,0 +1,235 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "340"} +{"spark": "341"} +{"spark": "342"} +{"spark": "343"} +{"spark": "344"} +{"spark": "350"} +{"spark": "351"} +{"spark": "352"} +{"spark": "353"} +{"spark": "354"} +{"spark": "355"} +{"spark": "356"} +{"spark": "357"} +{"spark": "400"} +{"spark": "401"} +spark-rapids-shim-json-lines ***/ + +package com.nvidia.spark.rapids.shims + +import java.nio.file.{Files, Path} + +import scala.util.Try + +import com.nvidia.spark.rapids._ +import org.apache.spark.sql.rapids.GpuFromProtobufSimple + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.UnaryExpression +import org.apache.spark.sql.types._ + +/** + * Spark 3.4+ optional integration for spark-protobuf expressions. + * + * spark-protobuf is an external module, so these rules must be registered by reflection. + */ +object ProtobufExprShims { + private[this] val protobufDataToCatalystClassName = + "org.apache.spark.sql.protobuf.ProtobufDataToCatalyst" + + private[this] val sparkProtobufUtilsObjectClassName = + "org.apache.spark.sql.protobuf.utils.ProtobufUtils$" + + def exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = { + try { + val clazz = ShimReflectionUtils.loadClass(protobufDataToCatalystClassName) + .asInstanceOf[Class[_ <: UnaryExpression]] + Map(clazz.asInstanceOf[Class[_ <: Expression]] -> fromProtobufRule) + } catch { + case _: ClassNotFoundException => Map.empty + } + } + + private def fromProtobufRule: ExprRule[_ <: Expression] = { + GpuOverrides.expr[UnaryExpression]( + "Decode a BinaryType column (protobuf) into a Spark SQL struct (simple types only)", + ExprChecks.unaryProject( + // Output is a struct; the rule does detailed checks in tagExprForGpu. + TypeSig.STRUCT.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRING), + TypeSig.all, + TypeSig.BINARY, + TypeSig.BINARY), + (e, conf, p, r) => new UnaryExprMeta[UnaryExpression](e, conf, p, r) { + + private var schema: StructType = _ + private var fieldNumbers: Array[Int] = _ + private var cudfTypeIds: Array[Int] = _ + private var cudfTypeScales: Array[Int] = _ + + override def tagExprForGpu(): Unit = { + schema = e.dataType match { + case st: StructType => st + case other => + willNotWorkOnGpu(s"Only StructType output is supported for from_protobuf(simple), got $other") + return + } + + val options = getOptionsMap(e) + if (options.nonEmpty) { + willNotWorkOnGpu(s"from_protobuf options are not supported yet on GPU: ${options.keys.mkString(",")}") + return + } + + val messageName = getMessageName(e) + val descFilePathOpt = getDescFilePath(e).orElse { + // Newer Spark may embed a descriptor set (binaryDescriptorSet). Write it to a temp file so we can + // reuse Spark's own ProtobufUtils + shaded protobuf classes to resolve the descriptor. + getDescriptorBytes(e).map(writeTempDescFile) + } + if (descFilePathOpt.isEmpty) { + willNotWorkOnGpu("from_protobuf(simple) requires a descriptor set (descFilePath or binaryDescriptorSet)") + return + } + + val msgDesc = try { + // Spark 3.4.x builds the descriptor as: ProtobufUtils.buildDescriptor(messageName, descFilePathOpt) + buildMessageDescriptorWithSparkProtobuf(messageName, descFilePathOpt) + } catch { + case t: Throwable => + willNotWorkOnGpu(s"Failed to resolve protobuf descriptor for message '$messageName': ${t.getMessage}") + return + } + + val fields = schema.fields + val fnums = new Array[Int](fields.length) + val typeIds = new Array[Int](fields.length) + val scales = new Array[Int](fields.length) + + fields.zipWithIndex.foreach { case (sf, idx) => + sf.dataType match { + case BooleanType | IntegerType | LongType | FloatType | DoubleType | StringType => + case other => + willNotWorkOnGpu(s"Unsupported field type for from_protobuf(simple): ${sf.name}: $other") + return + } + + val fd = invoke1[AnyRef](msgDesc, "findFieldByName", classOf[String], sf.name) + if (fd == null) { + willNotWorkOnGpu(s"Protobuf field '${sf.name}' not found in message '$messageName'") + return + } + + val isRepeated = Try(invoke0[java.lang.Boolean](fd, "isRepeated").booleanValue()).getOrElse(false) + if (isRepeated) { + willNotWorkOnGpu(s"Repeated fields are not supported for from_protobuf(simple): ${sf.name}") + return + } + + val protoType = invoke0[AnyRef](fd, "getType") + val protoTypeName = typeName(protoType) + val ok = (sf.dataType, protoTypeName) match { + case (BooleanType, "BOOL") => true + case (IntegerType, "INT32") => true + case (LongType, "INT64") => true + case (FloatType, "FLOAT") => true + case (DoubleType, "DOUBLE") => true + case (StringType, "STRING") => true + case _ => false + } + if (!ok) { + willNotWorkOnGpu(s"Field type mismatch for '${sf.name}': Spark ${sf.dataType} vs Protobuf $protoTypeName") + return + } + + fnums(idx) = invoke0[java.lang.Integer](fd, "getNumber").intValue() + val (tid, scale) = GpuFromProtobufSimple.sparkTypeToCudfId(sf.dataType) + typeIds(idx) = tid + scales(idx) = scale + } + + fieldNumbers = fnums + cudfTypeIds = typeIds + cudfTypeScales = scales + } + + override def convertToGpu(child: Expression): GpuExpression = { + GpuFromProtobufSimple(schema, fieldNumbers, cudfTypeIds, cudfTypeScales, child) + } + } + ) + } + + private def getMessageName(e: Expression): String = + invoke0[String](e, "messageName") + + /** + * Newer Spark versions may carry an in-expression descriptor set payload (e.g. binaryDescriptorSet). + * Spark 3.4.x does not, so callers should fall back to descFilePath(). + */ + private def getDescriptorBytes(e: Expression): Option[Array[Byte]] = { + // Spark 4.x/3.5+ (depending on the API): may be Array[Byte] or Option[Array[Byte]]. + val direct = Try(invoke0[Array[Byte]](e, "binaryDescriptorSet")).toOption + direct.orElse { + Try(invoke0[Option[Array[Byte]]](e, "binaryDescriptorSet")).toOption.flatten + } + } + + private def getDescFilePath(e: Expression): Option[String] = + Try(invoke0[Option[String]](e, "descFilePath")).toOption.flatten + + private def writeTempDescFile(descBytes: Array[Byte]): String = { + val tmp: Path = Files.createTempFile("spark-rapids-protobuf-desc-", ".desc") + Files.write(tmp, descBytes) + tmp.toFile.deleteOnExit() + tmp.toString + } + + private def buildMessageDescriptorWithSparkProtobuf( + messageName: String, + descFilePathOpt: Option[String]): AnyRef = { + val cls = ShimReflectionUtils.loadClass(sparkProtobufUtilsObjectClassName) + val module = cls.getField("MODULE$").get(null) + // buildDescriptor(messageName: String, descFilePath: Option[String]) + val m = cls.getMethod("buildDescriptor", classOf[String], classOf[scala.Option[_]]) + m.invoke(module, messageName, descFilePathOpt).asInstanceOf[AnyRef] + } + + private def typeName(t: AnyRef): String = { + if (t == null) { + "null" + } else { + // Prefer Enum.name() when available; fall back to toString. + Try(invoke0[String](t, "name")).getOrElse(t.toString) + } + } + + private def getOptionsMap(e: Expression): Map[String, String] = { + val opt = Try(invoke0[scala.collection.Map[String, String]](e, "options")).toOption + opt.map(_.toMap).getOrElse(Map.empty) + } + + private def invoke0[T](obj: AnyRef, method: String): T = + obj.getClass.getMethod(method).invoke(obj).asInstanceOf[T] + + private def invoke1[T](obj: AnyRef, method: String, arg0Cls: Class[_], arg0: AnyRef): T = + obj.getClass.getMethod(method, arg0Cls).invoke(obj, arg0).asInstanceOf[T] +} + + diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala index 6e28a071a00..cc406a156fd 100644 --- a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala +++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala @@ -162,7 +162,7 @@ trait Spark340PlusNonDBShims extends Spark331PlusNonDBShims { ), GpuElementAtMeta.elementAtRule(true) ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap - super.getExprs ++ shimExprs + super.getExprs ++ shimExprs ++ ProtobufExprShims.exprs } override def getDataWriteCmds: Map[Class[_ <: DataWritingCommand], From 084e9c2a65eb13f71e45243f65c536c0b85020b9 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 23 Dec 2025 17:00:38 +0800 Subject: [PATCH 2/9] style --- .../protobuf/ProtobufDescriptorUtils.scala | 6 ++- .../sql/rapids/GpuFromProtobufSimple.scala | 11 +++-- .../rapids/shims/ProtobufExprShims.scala | 45 ++++++++++++------- 3 files changed, 42 insertions(+), 20 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala index f40cc2af03f..1975db14966 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala @@ -16,8 +16,8 @@ package com.nvidia.spark.rapids.protobuf -import scala.collection.mutable import scala.collection.JavaConverters._ +import scala.collection.mutable import com.google.protobuf.DescriptorProtos import com.google.protobuf.Descriptors @@ -67,7 +67,9 @@ object ProtobufDescriptorUtils { fd: Descriptors.FileDescriptor, messageName: String): Iterator[Descriptors.Descriptor] = { def matches(d: Descriptors.Descriptor): Boolean = { - d.getName == messageName || d.getFullName == messageName || d.getFullName.endsWith("." + messageName) + d.getName == messageName || + d.getFullName == messageName || + d.getFullName.endsWith("." + messageName) } def walk(d: Descriptors.Descriptor): Iterator[Descriptors.Descriptor] = { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobufSimple.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobufSimple.scala index 73c23fe2f82..7d85d277e40 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobufSimple.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobufSimple.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.rapids import ai.rapids.cudf import ai.rapids.cudf.BinaryOp import ai.rapids.cudf.DType -import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.{GpuColumnVector, GpuUnaryExpression} +import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.jni.ProtobufSimple import com.nvidia.spark.rapids.shims.NullIntolerantShim @@ -30,7 +30,8 @@ import org.apache.spark.sql.types._ /** * GPU implementation for Spark's `from_protobuf` decode path (simple types only). * - * This is designed to replace `org.apache.spark.sql.protobuf.ProtobufDataToCatalyst` when supported. + * This is designed to replace `org.apache.spark.sql.protobuf.ProtobufDataToCatalyst` when + * supported. */ case class GpuFromProtobufSimple( outputSchema: StructType, @@ -51,7 +52,11 @@ case class GpuFromProtobufSimple( // ProtobufSimple returns a non-null STRUCT with nullable children. Spark's // ProtobufDataToCatalyst is NullIntolerant, so if the input binary row is null the output // struct row must be null as well. - val decoded = ProtobufSimple.decodeToStruct(input.getBase, fieldNumbers, cudfTypeIds, cudfTypeScales) + val decoded = ProtobufSimple.decodeToStruct( + input.getBase, + fieldNumbers, + cudfTypeIds, + cudfTypeScales) if (input.getBase.hasNulls) { withResource(decoded) { _ => decoded.mergeAndSetValidity(BinaryOp.BITWISE_AND, input.getBase) diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala index a75dda64b14..629a119aaf8 100644 --- a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala +++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala @@ -39,10 +39,9 @@ import java.nio.file.{Files, Path} import scala.util.Try import com.nvidia.spark.rapids._ -import org.apache.spark.sql.rapids.GpuFromProtobufSimple -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.expressions.UnaryExpression +import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression} +import org.apache.spark.sql.rapids.GpuFromProtobufSimple import org.apache.spark.sql.types._ /** @@ -87,33 +86,42 @@ object ProtobufExprShims { schema = e.dataType match { case st: StructType => st case other => - willNotWorkOnGpu(s"Only StructType output is supported for from_protobuf(simple), got $other") + willNotWorkOnGpu( + s"Only StructType output is supported for from_protobuf(simple), got $other") return } val options = getOptionsMap(e) if (options.nonEmpty) { - willNotWorkOnGpu(s"from_protobuf options are not supported yet on GPU: ${options.keys.mkString(",")}") + val keys = options.keys.mkString(",") + willNotWorkOnGpu( + s"from_protobuf options are not supported yet on GPU: $keys") return } val messageName = getMessageName(e) val descFilePathOpt = getDescFilePath(e).orElse { - // Newer Spark may embed a descriptor set (binaryDescriptorSet). Write it to a temp file so we can - // reuse Spark's own ProtobufUtils + shaded protobuf classes to resolve the descriptor. + // Newer Spark may embed a descriptor set (binaryDescriptorSet). Write it to a temp file + // so we can reuse Spark's ProtobufUtils (and its shaded protobuf classes) to resolve + // the descriptor. getDescriptorBytes(e).map(writeTempDescFile) } if (descFilePathOpt.isEmpty) { - willNotWorkOnGpu("from_protobuf(simple) requires a descriptor set (descFilePath or binaryDescriptorSet)") + willNotWorkOnGpu( + "from_protobuf(simple) requires a descriptor set " + + "(descFilePath or binaryDescriptorSet)") return } val msgDesc = try { - // Spark 3.4.x builds the descriptor as: ProtobufUtils.buildDescriptor(messageName, descFilePathOpt) + // Spark 3.4.x builds the descriptor as: + // ProtobufUtils.buildDescriptor(messageName, descFilePathOpt) buildMessageDescriptorWithSparkProtobuf(messageName, descFilePathOpt) } catch { case t: Throwable => - willNotWorkOnGpu(s"Failed to resolve protobuf descriptor for message '$messageName': ${t.getMessage}") + willNotWorkOnGpu( + s"Failed to resolve protobuf descriptor for message '$messageName': " + + s"${t.getMessage}") return } @@ -126,7 +134,8 @@ object ProtobufExprShims { sf.dataType match { case BooleanType | IntegerType | LongType | FloatType | DoubleType | StringType => case other => - willNotWorkOnGpu(s"Unsupported field type for from_protobuf(simple): ${sf.name}: $other") + willNotWorkOnGpu( + s"Unsupported field type for from_protobuf(simple): ${sf.name}: $other") return } @@ -136,9 +145,12 @@ object ProtobufExprShims { return } - val isRepeated = Try(invoke0[java.lang.Boolean](fd, "isRepeated").booleanValue()).getOrElse(false) + val isRepeated = Try { + invoke0[java.lang.Boolean](fd, "isRepeated").booleanValue() + }.getOrElse(false) if (isRepeated) { - willNotWorkOnGpu(s"Repeated fields are not supported for from_protobuf(simple): ${sf.name}") + willNotWorkOnGpu( + s"Repeated fields are not supported for from_protobuf(simple): ${sf.name}") return } @@ -154,7 +166,9 @@ object ProtobufExprShims { case _ => false } if (!ok) { - willNotWorkOnGpu(s"Field type mismatch for '${sf.name}': Spark ${sf.dataType} vs Protobuf $protoTypeName") + willNotWorkOnGpu( + s"Field type mismatch for '${sf.name}': Spark ${sf.dataType} vs " + + s"Protobuf $protoTypeName") return } @@ -180,7 +194,8 @@ object ProtobufExprShims { invoke0[String](e, "messageName") /** - * Newer Spark versions may carry an in-expression descriptor set payload (e.g. binaryDescriptorSet). + * Newer Spark versions may carry an in-expression descriptor set payload + * (e.g. binaryDescriptorSet). * Spark 3.4.x does not, so callers should fall back to descFilePath(). */ private def getDescriptorBytes(e: Expression): Option[Array[Byte]] = { From 7606925907476c407bbeef5202a1e3599ddc65d3 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 25 Dec 2025 14:11:11 +0800 Subject: [PATCH 3/9] address comments Signed-off-by: Haoyang Li --- .../src/main/python/protobuf_test.py | 20 ++++--- .../src/main/python/spark_init_internal.py | 27 +++++++++ .../protobuf/ProtobufDescriptorUtils.scala | 2 - .../sql/rapids/GpuFromProtobufSimple.scala | 25 ++++---- .../rapids/shims/ProtobufExprShims.scala | 58 +++++++++++++------ 5 files changed, 94 insertions(+), 38 deletions(-) diff --git a/integration_tests/src/main/python/protobuf_test.py b/integration_tests/src/main/python/protobuf_test.py index f85f1384b1f..3e9d0a1d1cd 100644 --- a/integration_tests/src/main/python/protobuf_test.py +++ b/integration_tests/src/main/python/protobuf_test.py @@ -85,6 +85,7 @@ def _build_simple_descriptor_set_bytes(spark): try: fd = fd.setSyntax("proto2") except Exception: + # If setSyntax is unavailable (older protobuf-java), we intentionally leave syntax unset. pass msg = D.DescriptorProto.newBuilder().setName("Simple") @@ -130,17 +131,17 @@ def _write_bytes_to_hadoop_path(spark, path_str, data_bytes): @ignore_order(local=True) def test_from_protobuf_simple_parquet_binary_round_trip(spark_tmp_path): from_protobuf = _try_import_from_protobuf() - # if from_protobuf is None: - # pytest.skip("pyspark.sql.protobuf.functions.from_protobuf is not available") - # if not with_cpu_session(lambda spark: _spark_protobuf_jvm_available(spark)): - # pytest.skip("spark-protobuf JVM module is not available on the classpath") + if from_protobuf is None: + pytest.skip("pyspark.sql.protobuf.functions.from_protobuf is not available") + if not with_cpu_session(lambda spark: _spark_protobuf_jvm_available(spark)): + pytest.skip("spark-protobuf JVM module is not available on the classpath") data_path = spark_tmp_path + "/PROTOBUF_SIMPLE_PARQUET/" desc_path = spark_tmp_path + "/simple.desc" message_name = "test.Simple" # Generate descriptor bytes once using the JVM (no protoc dependency) - desc_bytes = with_cpu_session(lambda spark: _build_simple_descriptor_set_bytes(spark)) + desc_bytes = with_cpu_session(_build_simple_descriptor_set_bytes) with_cpu_session(lambda spark: _write_bytes_to_hadoop_path(spark, desc_path, desc_bytes)) # Build a DF with scalar columns + binary protobuf column and write to parquet @@ -200,11 +201,16 @@ def run_on_spark(spark): @ignore_order(local=True) def test_from_protobuf_simple_null_input_returns_null(spark_tmp_path): from_protobuf = _try_import_from_protobuf() + if from_protobuf is None: + pytest.skip("pyspark.sql.protobuf.functions.from_protobuf is not available") + if not with_cpu_session(lambda spark: _spark_protobuf_jvm_available(spark)): + pytest.skip("spark-protobuf JVM module is not available on the classpath") + desc_path = spark_tmp_path + "/simple_null_input.desc" message_name = "test.Simple" # Generate descriptor bytes once using the JVM (no protoc dependency) - desc_bytes = with_cpu_session(lambda spark: _build_simple_descriptor_set_bytes(spark)) + desc_bytes = with_cpu_session(_build_simple_descriptor_set_bytes) with_cpu_session(lambda spark: _write_bytes_to_hadoop_path(spark, desc_path, desc_bytes)) # Spark's ProtobufDataToCatalyst is NullIntolerant (null input -> null output). @@ -225,5 +231,3 @@ def run_on_spark(spark): return df.select(decoded.alias("decoded")) assert_gpu_and_cpu_are_equal_collect(run_on_spark) - - diff --git a/integration_tests/src/main/python/spark_init_internal.py b/integration_tests/src/main/python/spark_init_internal.py index 90861746b64..787dc2d7eb0 100644 --- a/integration_tests/src/main/python/spark_init_internal.py +++ b/integration_tests/src/main/python/spark_init_internal.py @@ -61,11 +61,38 @@ def findspark_init(): if spark_jars is not None: logging.info(f"Adding to findspark jars: {spark_jars}") findspark.add_jars(spark_jars) + # Also add to driver classpath so classes are available to Class.forName() + # This is needed for optional modules like spark-protobuf + _add_driver_classpath(spark_jars) if spark_jars_packages is not None: logging.info(f"Adding to findspark packages: {spark_jars_packages}") findspark.add_packages(spark_jars_packages) + +def _add_driver_classpath(jars): + """ + Add jars to the driver classpath via PYSPARK_SUBMIT_ARGS. + findspark.add_jars() only adds --jars, which doesn't make classes available + to Class.forName() on the driver. This function adds --driver-class-path. + """ + if not jars: + return + current_args = os.environ.get('PYSPARK_SUBMIT_ARGS', '') + # Remove trailing 'pyspark-shell' if present + if current_args.endswith('pyspark-shell'): + current_args = current_args[:-len('pyspark-shell')].strip() + # Skip if driver-class-path is already present + if '--driver-class-path' in current_args: + logging.info("driver-class-path already in PYSPARK_SUBMIT_ARGS, skipping") + return + # Add driver-class-path for each jar + jar_list = jars.replace(',', ' ').split() + driver_cp = ':'.join(jar_list) + new_args = f"{current_args} --driver-class-path {driver_cp} pyspark-shell".strip() + os.environ['PYSPARK_SUBMIT_ARGS'] = new_args + logging.info(f"Updated PYSPARK_SUBMIT_ARGS with driver-class-path") + def running_with_xdist(session, is_worker): try: import xdist diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala index 1975db14966..89ed22f6e2d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala @@ -80,5 +80,3 @@ object ProtobufDescriptorUtils { fd.getMessageTypes.asScala.iterator.flatMap(walk _) } } - - diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobufSimple.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobufSimple.scala index 7d85d277e40..6153f3f96ab 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobufSimple.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobufSimple.scala @@ -38,6 +38,7 @@ case class GpuFromProtobufSimple( fieldNumbers: Array[Int], cudfTypeIds: Array[Int], cudfTypeScales: Array[Int], + failOnErrors: Boolean, child: Expression) extends GpuUnaryExpression with ExpectsInputTypes with NullIntolerantShim { @@ -56,7 +57,8 @@ case class GpuFromProtobufSimple( input.getBase, fieldNumbers, cudfTypeIds, - cudfTypeScales) + cudfTypeScales, + failOnErrors) if (input.getBase.hasNulls) { withResource(decoded) { _ => decoded.mergeAndSetValidity(BinaryOp.BITWISE_AND, input.getBase) @@ -68,17 +70,20 @@ case class GpuFromProtobufSimple( } object GpuFromProtobufSimple { + // Encodings from com.nvidia.spark.rapids.jni.ProtobufSimple + val ENC_DEFAULT = 0 + val ENC_FIXED = 1 + val ENC_ZIGZAG = 2 + def sparkTypeToCudfId(dt: DataType): (Int, Int) = dt match { - case BooleanType => (DType.BOOL8.getTypeId.getNativeId, 0) - case IntegerType => (DType.INT32.getTypeId.getNativeId, 0) - case LongType => (DType.INT64.getTypeId.getNativeId, 0) - case FloatType => (DType.FLOAT32.getTypeId.getNativeId, 0) - case DoubleType => (DType.FLOAT64.getTypeId.getNativeId, 0) - case StringType => (DType.STRING.getTypeId.getNativeId, 0) + case BooleanType => (DType.BOOL8.getTypeId.getNativeId, ENC_DEFAULT) + case IntegerType => (DType.INT32.getTypeId.getNativeId, ENC_DEFAULT) + case LongType => (DType.INT64.getTypeId.getNativeId, ENC_DEFAULT) + case FloatType => (DType.FLOAT32.getTypeId.getNativeId, ENC_DEFAULT) + case DoubleType => (DType.FLOAT64.getTypeId.getNativeId, ENC_DEFAULT) + case StringType => (DType.STRING.getTypeId.getNativeId, ENC_DEFAULT) + case BinaryType => (DType.LIST.getTypeId.getNativeId, ENC_DEFAULT) case other => throw new IllegalArgumentException(s"Unsupported Spark type for protobuf(simple): $other") } } - - - diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala index 629a119aaf8..a4ad3f42145 100644 --- a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala +++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala @@ -71,7 +71,8 @@ object ProtobufExprShims { "Decode a BinaryType column (protobuf) into a Spark SQL struct (simple types only)", ExprChecks.unaryProject( // Output is a struct; the rule does detailed checks in tagExprForGpu. - TypeSig.STRUCT.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRING), + TypeSig.STRUCT.nested( + TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRING + TypeSig.BINARY), TypeSig.all, TypeSig.BINARY, TypeSig.BINARY), @@ -81,6 +82,7 @@ object ProtobufExprShims { private var fieldNumbers: Array[Int] = _ private var cudfTypeIds: Array[Int] = _ private var cudfTypeScales: Array[Int] = _ + private var failOnErrors: Boolean = _ override def tagExprForGpu(): Unit = { schema = e.dataType match { @@ -92,13 +94,17 @@ object ProtobufExprShims { } val options = getOptionsMap(e) - if (options.nonEmpty) { - val keys = options.keys.mkString(",") + val supportedOptions = Set("enums.as.ints", "mode") + val unsupportedOptions = options.keys.filterNot(supportedOptions.contains) + if (unsupportedOptions.nonEmpty) { + val keys = unsupportedOptions.mkString(",") willNotWorkOnGpu( s"from_protobuf options are not supported yet on GPU: $keys") return } + val enumsAsInts = options.getOrElse("enums.as.ints", "false").toBoolean + failOnErrors = options.getOrElse("mode", "PERMISSIVE").equalsIgnoreCase("FAILFAST") val messageName = getMessageName(e) val descFilePathOpt = getDescFilePath(e).orElse { // Newer Spark may embed a descriptor set (binaryDescriptorSet). Write it to a temp file @@ -132,7 +138,8 @@ object ProtobufExprShims { fields.zipWithIndex.foreach { case (sf, idx) => sf.dataType match { - case BooleanType | IntegerType | LongType | FloatType | DoubleType | StringType => + case BooleanType | IntegerType | LongType | FloatType | DoubleType | + StringType | BinaryType => case other => willNotWorkOnGpu( s"Unsupported field type for from_protobuf(simple): ${sf.name}: $other") @@ -156,16 +163,32 @@ object ProtobufExprShims { val protoType = invoke0[AnyRef](fd, "getType") val protoTypeName = typeName(protoType) - val ok = (sf.dataType, protoTypeName) match { - case (BooleanType, "BOOL") => true - case (IntegerType, "INT32") => true - case (LongType, "INT64") => true - case (FloatType, "FLOAT") => true - case (DoubleType, "DOUBLE") => true - case (StringType, "STRING") => true - case _ => false + + val encoding = (sf.dataType, protoTypeName) match { + case (BooleanType, "BOOL") => Some(GpuFromProtobufSimple.ENC_DEFAULT) + case (IntegerType, "INT32" | "UINT32") => Some(GpuFromProtobufSimple.ENC_DEFAULT) + case (IntegerType, "SINT32") => Some(GpuFromProtobufSimple.ENC_ZIGZAG) + case (IntegerType, "FIXED32" | "SFIXED32") => Some(GpuFromProtobufSimple.ENC_FIXED) + case (LongType, "INT64" | "UINT64") => Some(GpuFromProtobufSimple.ENC_DEFAULT) + case (LongType, "SINT64") => Some(GpuFromProtobufSimple.ENC_ZIGZAG) + case (LongType, "FIXED64" | "SFIXED64") => Some(GpuFromProtobufSimple.ENC_FIXED) + // Spark may upcast smaller integers to LongType + case (LongType, "INT32" | "UINT32" | "SINT32" | "FIXED32" | "SFIXED32") => + val enc = protoTypeName match { + case "SINT32" => GpuFromProtobufSimple.ENC_ZIGZAG + case "FIXED32" | "SFIXED32" => GpuFromProtobufSimple.ENC_FIXED + case _ => GpuFromProtobufSimple.ENC_DEFAULT + } + Some(enc) + case (FloatType, "FLOAT") => Some(GpuFromProtobufSimple.ENC_DEFAULT) + case (DoubleType, "DOUBLE") => Some(GpuFromProtobufSimple.ENC_DEFAULT) + case (StringType, "STRING") => Some(GpuFromProtobufSimple.ENC_DEFAULT) + case (BinaryType, "BYTES") => Some(GpuFromProtobufSimple.ENC_DEFAULT) + case (IntegerType, "ENUM") if enumsAsInts => Some(GpuFromProtobufSimple.ENC_DEFAULT) + case _ => None } - if (!ok) { + + if (encoding.isEmpty) { willNotWorkOnGpu( s"Field type mismatch for '${sf.name}': Spark ${sf.dataType} vs " + s"Protobuf $protoTypeName") @@ -173,9 +196,9 @@ object ProtobufExprShims { } fnums(idx) = invoke0[java.lang.Integer](fd, "getNumber").intValue() - val (tid, scale) = GpuFromProtobufSimple.sparkTypeToCudfId(sf.dataType) + val (tid, _) = GpuFromProtobufSimple.sparkTypeToCudfId(sf.dataType) typeIds(idx) = tid - scales(idx) = scale + scales(idx) = encoding.get } fieldNumbers = fnums @@ -184,7 +207,8 @@ object ProtobufExprShims { } override def convertToGpu(child: Expression): GpuExpression = { - GpuFromProtobufSimple(schema, fieldNumbers, cudfTypeIds, cudfTypeScales, child) + GpuFromProtobufSimple( + schema, fieldNumbers, cudfTypeIds, cudfTypeScales, failOnErrors, child) } } ) @@ -246,5 +270,3 @@ object ProtobufExprShims { private def invoke1[T](obj: AnyRef, method: String, arg0Cls: Class[_], arg0: AnyRef): T = obj.getClass.getMethod(method, arg0Cls).invoke(obj, arg0).asInstanceOf[T] } - - From c6cde2d4f87d4d5d6d5077947d5d9d2e43a8ce00 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 25 Dec 2025 15:38:38 +0800 Subject: [PATCH 4/9] address comments Signed-off-by: Haoyang Li --- integration_tests/src/main/python/data_gen.py | 2 +- .../src/main/python/protobuf_test.py | 4 ++-- .../src/main/python/spark_init_internal.py | 4 ++-- .../protobuf/ProtobufDescriptorUtils.scala | 7 +++++- .../sql/rapids/GpuFromProtobufSimple.scala | 23 +++++++++++-------- .../rapids/shims/ProtobufExprShims.scala | 5 ++-- 6 files changed, 27 insertions(+), 18 deletions(-) diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py index 837d4990832..9298bd189ec 100644 --- a/integration_tests/src/main/python/data_gen.py +++ b/integration_tests/src/main/python/data_gen.py @@ -925,7 +925,7 @@ class ProtobufSimpleMessageRowGen(DataGen): - one column per message field (Spark scalar types) - a binary column containing a serialized protobuf message containing those fields - This is intentionally limited to the simple scalar types supported in Patch 1: + This is intentionally limited to the simple scalar types currently supported: boolean/int32/int64/float/double/string. Fields are omitted from the encoded message if the corresponding value is None. diff --git a/integration_tests/src/main/python/protobuf_test.py b/integration_tests/src/main/python/protobuf_test.py index 3e9d0a1d1cd..4694d0aaa38 100644 --- a/integration_tests/src/main/python/protobuf_test.py +++ b/integration_tests/src/main/python/protobuf_test.py @@ -133,7 +133,7 @@ def test_from_protobuf_simple_parquet_binary_round_trip(spark_tmp_path): from_protobuf = _try_import_from_protobuf() if from_protobuf is None: pytest.skip("pyspark.sql.protobuf.functions.from_protobuf is not available") - if not with_cpu_session(lambda spark: _spark_protobuf_jvm_available(spark)): + if not with_cpu_session(_spark_protobuf_jvm_available): pytest.skip("spark-protobuf JVM module is not available on the classpath") data_path = spark_tmp_path + "/PROTOBUF_SIMPLE_PARQUET/" @@ -203,7 +203,7 @@ def test_from_protobuf_simple_null_input_returns_null(spark_tmp_path): from_protobuf = _try_import_from_protobuf() if from_protobuf is None: pytest.skip("pyspark.sql.protobuf.functions.from_protobuf is not available") - if not with_cpu_session(lambda spark: _spark_protobuf_jvm_available(spark)): + if not with_cpu_session(_spark_protobuf_jvm_available): pytest.skip("spark-protobuf JVM module is not available on the classpath") desc_path = spark_tmp_path + "/simple_null_input.desc" diff --git a/integration_tests/src/main/python/spark_init_internal.py b/integration_tests/src/main/python/spark_init_internal.py index 787dc2d7eb0..e9fc3ca8413 100644 --- a/integration_tests/src/main/python/spark_init_internal.py +++ b/integration_tests/src/main/python/spark_init_internal.py @@ -86,9 +86,9 @@ def _add_driver_classpath(jars): if '--driver-class-path' in current_args: logging.info("driver-class-path already in PYSPARK_SUBMIT_ARGS, skipping") return - # Add driver-class-path for each jar + # Add driver-class-path for each jar (use os.pathsep for platform independence) jar_list = jars.replace(',', ' ').split() - driver_cp = ':'.join(jar_list) + driver_cp = os.pathsep.join(jar_list) new_args = f"{current_args} --driver-class-path {driver_cp} pyspark-shell".strip() os.environ['PYSPARK_SUBMIT_ARGS'] = new_args logging.info(f"Updated PYSPARK_SUBMIT_ARGS with driver-class-path") diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala index 89ed22f6e2d..aade729f000 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala @@ -27,6 +27,11 @@ import com.google.protobuf.Descriptors * * This is intentionally lightweight for the "simple types" from_protobuf patch: it supports * descriptor sets produced by `protoc --include_imports --descriptor_set_out=...`. + * + * NOTE: This utility is currently not used in the initial implementation, which relies on + * Spark's ProtobufUtils via reflection (buildMessageDescriptorWithSparkProtobuf). This class + * is preserved for potential future use cases where direct descriptor parsing is needed + * without depending on Spark's shaded protobuf classes. */ object ProtobufDescriptorUtils { @@ -79,4 +84,4 @@ object ProtobufDescriptorUtils { fd.getMessageTypes.asScala.iterator.flatMap(walk _) } -} +} \ No newline at end of file diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobufSimple.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobufSimple.scala index 6153f3f96ab..c41269f418f 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobufSimple.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobufSimple.scala @@ -75,15 +75,20 @@ object GpuFromProtobufSimple { val ENC_FIXED = 1 val ENC_ZIGZAG = 2 - def sparkTypeToCudfId(dt: DataType): (Int, Int) = dt match { - case BooleanType => (DType.BOOL8.getTypeId.getNativeId, ENC_DEFAULT) - case IntegerType => (DType.INT32.getTypeId.getNativeId, ENC_DEFAULT) - case LongType => (DType.INT64.getTypeId.getNativeId, ENC_DEFAULT) - case FloatType => (DType.FLOAT32.getTypeId.getNativeId, ENC_DEFAULT) - case DoubleType => (DType.FLOAT64.getTypeId.getNativeId, ENC_DEFAULT) - case StringType => (DType.STRING.getTypeId.getNativeId, ENC_DEFAULT) - case BinaryType => (DType.LIST.getTypeId.getNativeId, ENC_DEFAULT) + /** + * Maps a Spark DataType to the corresponding cuDF native type ID. + * Note: The encoding (varint/zigzag/fixed) is determined by the protobuf field type, + * not the Spark data type, so it must be set separately based on the protobuf schema. + */ + def sparkTypeToCudfId(dt: DataType): Int = dt match { + case BooleanType => DType.BOOL8.getTypeId.getNativeId + case IntegerType => DType.INT32.getTypeId.getNativeId + case LongType => DType.INT64.getTypeId.getNativeId + case FloatType => DType.FLOAT32.getTypeId.getNativeId + case DoubleType => DType.FLOAT64.getTypeId.getNativeId + case StringType => DType.STRING.getTypeId.getNativeId + case BinaryType => DType.LIST.getTypeId.getNativeId case other => throw new IllegalArgumentException(s"Unsupported Spark type for protobuf(simple): $other") } -} +} \ No newline at end of file diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala index a4ad3f42145..8aeab9b34b2 100644 --- a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala +++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala @@ -196,8 +196,7 @@ object ProtobufExprShims { } fnums(idx) = invoke0[java.lang.Integer](fd, "getNumber").intValue() - val (tid, _) = GpuFromProtobufSimple.sparkTypeToCudfId(sf.dataType) - typeIds(idx) = tid + typeIds(idx) = GpuFromProtobufSimple.sparkTypeToCudfId(sf.dataType) scales(idx) = encoding.get } @@ -269,4 +268,4 @@ object ProtobufExprShims { private def invoke1[T](obj: AnyRef, method: String, arg0Cls: Class[_], arg0: AnyRef): T = obj.getClass.getMethod(method, arg0Cls).invoke(obj, arg0).asInstanceOf[T] -} +} \ No newline at end of file From 6d4eb166202ada65418598a1efddfdc581060f0b Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Sun, 4 Jan 2026 15:43:57 +0800 Subject: [PATCH 5/9] copyrights Signed-off-by: Haoyang Li --- integration_tests/pom.xml | 2 +- integration_tests/run_pyspark_from_build.sh | 2 +- integration_tests/src/main/python/data_gen.py | 2 +- integration_tests/src/main/python/protobuf_test.py | 2 +- integration_tests/src/main/python/spark_init_internal.py | 2 +- pom.xml | 2 +- .../nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala | 2 +- .../org/apache/spark/sql/rapids/GpuFromProtobufSimple.scala | 2 +- .../scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala | 2 +- .../com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala | 2 +- 10 files changed, 10 insertions(+), 10 deletions(-) diff --git a/integration_tests/pom.xml b/integration_tests/pom.xml index 825083b7fbe..8180c1dbd49 100644 --- a/integration_tests/pom.xml +++ b/integration_tests/pom.xml @@ -1,6 +1,6 @@