diff --git a/delta-lake/common/src/main/delta-33x-40x/scala/com/nvidia/spark/rapids/delta/common/DeltaProviderBase.scala b/delta-lake/common/src/main/delta-33x-40x/scala/com/nvidia/spark/rapids/delta/common/DeltaProviderBase.scala
index 36c0eb6ea5d..5a98329e3b0 100644
--- a/delta-lake/common/src/main/delta-33x-40x/scala/com/nvidia/spark/rapids/delta/common/DeltaProviderBase.scala
+++ b/delta-lake/common/src/main/delta-33x-40x/scala/com/nvidia/spark/rapids/delta/common/DeltaProviderBase.scala
@@ -185,7 +185,7 @@ abstract class DeltaProviderBase extends DeltaIOProvider {
//
case dvRoot @ GpuProjectExec(outputList,
dvFilter @ GpuFilterExec(condition,
- dvFilterInput @ GpuProjectExec(inputList, fsse: GpuFileSourceScanExec, _)), _)
+ dvFilterInput @ GpuProjectExec(inputList, fsse: GpuFileSourceScanExec, _, _)), _, _)
if condition.references.exists(_.name == IS_ROW_DELETED_COLUMN_NAME) &&
!outputList.exists(_.name == "_metadata") && inputList.exists(_.name == "_metadata") =>
dvRoot.withNewChildren(Seq(
@@ -256,7 +256,7 @@ object DVPredicatePushdown extends ShimPredicateHelper {
def pruneIsRowDeletedColumn(plan: SparkPlan): SparkPlan = {
plan.transformUp {
- case project @ GpuProjectExec(projectList, _, _) =>
+ case project @ GpuProjectExec(projectList, _, _, _) =>
val newProjList = projectList.filterNot(isRowDeletedColumnRef(_))
project.copy(projectList = newProjList)
case fsse: GpuFileSourceScanExec =>
@@ -307,11 +307,16 @@ object DVPredicatePushdown extends ShimPredicateHelper {
def mergeIdenticalProjects(plan: SparkPlan): SparkPlan = {
plan.transformUp {
case p @ GpuProjectExec(projList1,
- GpuProjectExec(projList2, child, enablePreSplit1), enablePreSplit2) =>
+ GpuProjectExec(projList2, child, enablePreSplit1, forcePostProjectCoalesce1),
+ enablePreSplit2, forcePostProjectCoalesce2) =>
val projSet1 = projList1.map(_.exprId).toSet
val projSet2 = projList2.map(_.exprId).toSet
if (projSet1 == projSet2) {
- GpuProjectExec(projList1, child, enablePreSplit1 && enablePreSplit2)
+ GpuProjectExec(
+ projList1,
+ child,
+ enablePreSplit1 && enablePreSplit2,
+ forcePostProjectCoalesce1 || forcePostProjectCoalesce2)
} else {
p
}
diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 888f4fa807c..196f518c6a7 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -8267,7 +8267,7 @@ are limited.
|
|
|
-PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT, DAYTIME, YEARMONTH |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT, DAYTIME, YEARMONTH |
|
|
|
@@ -8290,7 +8290,7 @@ are limited.
|
|
|
-PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT, DAYTIME, YEARMONTH |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT, DAYTIME, YEARMONTH |
|
|
|
diff --git a/integration_tests/run_pyspark_from_build.sh b/integration_tests/run_pyspark_from_build.sh
index a619e6c66ae..5a7db886da1 100755
--- a/integration_tests/run_pyspark_from_build.sh
+++ b/integration_tests/run_pyspark_from_build.sh
@@ -46,6 +46,9 @@
# To run all tests, including Avro tests:
# INCLUDE_SPARK_AVRO_JAR=true ./run_pyspark_from_build.sh
#
+# To run tests WITHOUT Protobuf tests (protobuf is included by default):
+# INCLUDE_SPARK_PROTOBUF_JAR=false ./run_pyspark_from_build.sh
+#
# To run a specific test:
# TEST=my_test ./run_pyspark_from_build.sh
#
@@ -141,9 +144,82 @@ else
AVRO_JARS=""
fi
- # ALL_JARS includes dist.jar integration-test.jar avro.jar parquet.jar if they exist
+ # Protobuf support: Include spark-protobuf jar by default for protobuf_test.py
+ # Set INCLUDE_SPARK_PROTOBUF_JAR=false to disable
+ PROTOBUF_JARS=""
+ if [[ $( echo ${INCLUDE_SPARK_PROTOBUF_JAR} | tr '[:upper:]' '[:lower:]' ) != "false" ]];
+ then
+ export INCLUDE_SPARK_PROTOBUF_JAR=true
+ mkdir -p "${TARGET_DIR}/dependency"
+
+ # Download spark-protobuf jar if not already in target/dependency
+ PROTOBUF_JAR_NAME="spark-protobuf_${SCALA_VERSION}-${VERSION_STRING}.jar"
+ PROTOBUF_JAR_PATH="${TARGET_DIR}/dependency/${PROTOBUF_JAR_NAME}"
+
+ if [[ ! -f "$PROTOBUF_JAR_PATH" ]]; then
+ echo "Downloading spark-protobuf jar..."
+ PROTOBUF_MAVEN_URL="https://repo1.maven.org/maven2/org/apache/spark/spark-protobuf_${SCALA_VERSION}/${VERSION_STRING}/${PROTOBUF_JAR_NAME}"
+ if curl -fsL -o "$PROTOBUF_JAR_PATH" "$PROTOBUF_MAVEN_URL"; then
+ echo "Downloaded spark-protobuf jar to $PROTOBUF_JAR_PATH"
+ else
+ echo "WARNING: Failed to download spark-protobuf jar from $PROTOBUF_MAVEN_URL"
+ rm -f "$PROTOBUF_JAR_PATH"
+ fi
+ fi
+
+ # Also download protobuf-java jar (required dependency).
+ # Detect version from the jar bundled with Spark, fall back to version mapping.
+ PROTOBUF_JAVA_VERSION=""
+ BUNDLED_PB_JAR=$(ls "$SPARK_HOME"/jars/protobuf-java-[0-9]*.jar 2>/dev/null | head -1)
+ if [[ -n "$BUNDLED_PB_JAR" ]]; then
+ PROTOBUF_JAVA_VERSION=$(basename "$BUNDLED_PB_JAR" | sed 's/protobuf-java-\(.*\)\.jar/\1/')
+ echo "Detected protobuf-java version $PROTOBUF_JAVA_VERSION from SPARK_HOME"
+ fi
+ if [[ -z "$PROTOBUF_JAVA_VERSION" ]]; then
+ case "$VERSION_STRING" in
+ 3.4.*) PROTOBUF_JAVA_VERSION="3.25.1" ;;
+ 3.5.*) PROTOBUF_JAVA_VERSION="3.25.1" ;;
+ 4.0.*) PROTOBUF_JAVA_VERSION="4.29.3" ;;
+ *) PROTOBUF_JAVA_VERSION="3.25.1" ;;
+ esac
+ echo "Using protobuf-java version $PROTOBUF_JAVA_VERSION based on Spark $VERSION_STRING"
+ fi
+ PROTOBUF_JAVA_JAR_NAME="protobuf-java-${PROTOBUF_JAVA_VERSION}.jar"
+ PROTOBUF_JAVA_JAR_PATH="${TARGET_DIR}/dependency/${PROTOBUF_JAVA_JAR_NAME}"
+
+ if [[ ! -f "$PROTOBUF_JAVA_JAR_PATH" ]]; then
+ echo "Downloading protobuf-java jar..."
+ PROTOBUF_JAVA_MAVEN_URL="https://repo1.maven.org/maven2/com/google/protobuf/protobuf-java/${PROTOBUF_JAVA_VERSION}/${PROTOBUF_JAVA_JAR_NAME}"
+ if curl -fsL -o "$PROTOBUF_JAVA_JAR_PATH" "$PROTOBUF_JAVA_MAVEN_URL"; then
+ echo "Downloaded protobuf-java jar to $PROTOBUF_JAVA_JAR_PATH"
+ else
+ echo "WARNING: Failed to download protobuf-java jar from $PROTOBUF_JAVA_MAVEN_URL"
+ rm -f "$PROTOBUF_JAVA_JAR_PATH"
+ fi
+ fi
+
+ if [[ -f "$PROTOBUF_JAR_PATH" ]]; then
+ PROTOBUF_JARS="$PROTOBUF_JAR_PATH"
+ echo "Including spark-protobuf jar: $PROTOBUF_JAR_PATH"
+ fi
+ if [[ -f "$PROTOBUF_JAVA_JAR_PATH" ]]; then
+ PROTOBUF_JARS="${PROTOBUF_JARS:+$PROTOBUF_JARS }$PROTOBUF_JAVA_JAR_PATH"
+ echo "Including protobuf-java jar: $PROTOBUF_JAVA_JAR_PATH"
+ fi
+    # Also add protobuf jars to the driver extraClassPath so Class.forName() works:
+    # jars passed via --jars are added to the driver's classloader only after JVM startup
+ if [[ -n "$PROTOBUF_JARS" ]]; then
+ PROTOBUF_DRIVER_CP=$(echo "$PROTOBUF_JARS" | tr ' ' ':')
+ export PYSP_TEST_spark_driver_extraClassPath="${PYSP_TEST_spark_driver_extraClassPath:+${PYSP_TEST_spark_driver_extraClassPath}:}${PROTOBUF_DRIVER_CP}"
+ echo "Added protobuf jars to driver classpath"
+ fi
+ else
+ export INCLUDE_SPARK_PROTOBUF_JAR=false
+ fi
+
+ # ALL_JARS includes dist.jar integration-test.jar avro.jar parquet.jar protobuf.jar if they exist
# Remove non-existing paths and canonicalize the paths including get rid of links and `..`
- ALL_JARS=$(readlink -e $PLUGIN_JAR $TEST_JARS $AVRO_JARS $PARQUET_HADOOP_TESTS || true)
+ ALL_JARS=$(readlink -e $PLUGIN_JAR $TEST_JARS $AVRO_JARS $PARQUET_HADOOP_TESTS $PROTOBUF_JARS || true)
# `:` separated jars
ALL_JARS="${ALL_JARS//$'\n'/:}"
@@ -411,6 +487,7 @@ else
export PYSP_TEST_spark_gluten_loadLibFromJar=true
fi
+
SPARK_SHELL_SMOKE_TEST="${SPARK_SHELL_SMOKE_TEST:-0}"
EXPLAIN_ONLY_CPU_SMOKE_TEST="${EXPLAIN_ONLY_CPU_SMOKE_TEST:-0}"
SPARK_CONNECT_SMOKE_TEST="${SPARK_CONNECT_SMOKE_TEST:-0}"
diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py
index fa7decac82d..d6583e3fd00 100644
--- a/integration_tests/src/main/python/data_gen.py
+++ b/integration_tests/src/main/python/data_gen.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2025, NVIDIA CORPORATION.
+# Copyright (c) 2020-2026, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -857,6 +857,389 @@ def gen_bytes():
return bytes([ rand.randint(0, 255) for _ in range(length) ])
self._start(rand, gen_bytes)
+
+# -----------------------------------------------------------------------------
+# Protobuf (simple types) generators/utilities (for from_protobuf/to_protobuf tests)
+# -----------------------------------------------------------------------------
+
+_PROTOBUF_WIRE_VARINT = 0
+_PROTOBUF_WIRE_64BIT = 1
+_PROTOBUF_WIRE_LEN_DELIM = 2
+_PROTOBUF_WIRE_32BIT = 5
+
+def _encode_protobuf_uvarint(value):
+ """Encode a non-negative integer as protobuf varint."""
+ if value is None:
+ raise ValueError("value must not be None")
+ if value < 0:
+ raise ValueError("uvarint only supports non-negative integers")
+ out = bytearray()
+ v = int(value)
+ while True:
+ b = v & 0x7F
+ v >>= 7
+ if v:
+ out.append(b | 0x80)
+ else:
+ out.append(b)
+ break
+ return bytes(out)
+
+def _encode_protobuf_key(field_number, wire_type):
+ return _encode_protobuf_uvarint((int(field_number) << 3) | int(wire_type))
+
+def _encode_protobuf_zigzag32(value):
+ """Encode a signed 32-bit integer using zigzag encoding (sint32)."""
+ return (value << 1) ^ (value >> 31)
+
+def _encode_protobuf_zigzag64(value):
+ """Encode a signed 64-bit integer using zigzag encoding (sint64)."""
+ return (value << 1) ^ (value >> 63)
+
+def _encode_protobuf_field(field_number, spark_type, value, encoding='default'):
+ """
+ Encode a single protobuf field for a subset of scalar types.
+
+ encoding: 'default', 'fixed', 'zigzag' - determines how integers are encoded
+
+ Notes on signed ints:
+ - Protobuf `int32`/`int64` use *varint* encoding of the two's-complement integer.
+ - Negative `int32` values are encoded as a 10-byte varint (because they are sign-extended to 64 bits).
+ - `sint32`/`sint64` use zigzag encoding for efficient negative number storage.
+ - `fixed32`/`fixed64` use fixed-width little-endian encoding.
+ """
+ if value is None:
+ return b""
+
+ if isinstance(spark_type, BooleanType):
+ return _encode_protobuf_key(field_number, _PROTOBUF_WIRE_VARINT) + _encode_protobuf_uvarint(1 if value else 0)
+ elif isinstance(spark_type, IntegerType):
+ if encoding == 'fixed':
+ # fixed32 / sfixed32: 4-byte little-endian (mask handles both signed and unsigned)
+ return _encode_protobuf_key(field_number, _PROTOBUF_WIRE_32BIT) + struct.pack(" bool:
+ """
+ `spark-protobuf` is an optional external module. PySpark may have the Python wrappers
+ even when the JVM side isn't present on the classpath, which manifests as:
+ TypeError: 'JavaPackage' object is not callable
+ when calling into `sc._jvm.org.apache.spark.sql.protobuf.functions.from_protobuf`.
+ """
+ jvm = spark.sparkContext._jvm
+ candidates = [
+ # Scala object `functions` compiles to `functions$`
+ "org.apache.spark.sql.protobuf.functions$",
+ # Some environments may expose it differently
+ "org.apache.spark.sql.protobuf.functions",
+ ]
+ for cls in candidates:
+ try:
+ jvm.java.lang.Class.forName(cls)
+ return True
+ except Exception:
+ continue
+ return False
+
+
+# ---------------------------------------------------------------------------
+# Shared fixture and helpers to reduce per-test boilerplate
+# ---------------------------------------------------------------------------
+
+@pytest.fixture(scope="module")
+def from_protobuf_fn():
+ """Skip the entire module if from_protobuf or the JVM module is unavailable."""
+ fn = _try_import_from_protobuf()
+ if fn is None:
+ pytest.skip("from_protobuf not available")
+ if not with_cpu_session(_spark_protobuf_jvm_available):
+ pytest.skip("spark-protobuf JVM not available")
+ return fn
+
+
+def _setup_protobuf_desc(spark_tmp_path, desc_name, build_fn):
+ """Build descriptor bytes via JVM, write to HDFS, return (desc_path, desc_bytes)."""
+ desc_path = spark_tmp_path + "/" + desc_name
+ desc_bytes = with_cpu_session(build_fn)
+ with_cpu_session(
+ lambda spark: _write_bytes_to_hadoop_path(spark, desc_path, desc_bytes))
+ return desc_path, desc_bytes
+
+
+def _call_from_protobuf(from_protobuf_fn, col, message_name,
+ desc_path, desc_bytes, options=None):
+ """Call from_protobuf using the right API variant."""
+ sig = inspect.signature(from_protobuf_fn)
+ if "binaryDescriptorSet" in sig.parameters:
+ kw = dict(binaryDescriptorSet=bytearray(desc_bytes))
+ if options is not None:
+ kw["options"] = options
+ return from_protobuf_fn(col, message_name, **kw)
+ if options is not None:
+ return from_protobuf_fn(col, message_name, desc_path, options)
+ return from_protobuf_fn(col, message_name, desc_path)
+
+
+def test_call_from_protobuf_preserves_options_for_legacy_signature():
+ calls = []
+
+ def fake_from_protobuf(col, message_name, desc_path, *args):
+ calls.append((col, message_name, desc_path, args))
+ return "ok"
+
+ options = {"enums.as.ints": "true"}
+ result = _call_from_protobuf(
+ fake_from_protobuf, "col", "msg", "/tmp/test.desc", b"desc", options=options)
+
+ assert result == "ok"
+ assert calls == [("col", "msg", "/tmp/test.desc", (options,))]
+
+
+def test_encode_protobuf_packed_repeated_fixed_uses_unsigned_twos_complement():
+ i32_encoded = _encode_protobuf_packed_repeated(
+ 1, IntegerType(), [0xFFFFFFFF], encoding='fixed')
+ i64_encoded = _encode_protobuf_packed_repeated(
+ 1, LongType(), [0xFFFFFFFFFFFFFFFF], encoding='fixed')
+
+ assert i32_encoded == b"\x0a\x04" + struct.pack(" 1, 1 -> 2, -2 -> 3, 2 -> 4, etc.
+ """
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "signed.desc", _build_signed_int_descriptor_set_bytes)
+ message_name = "test.WithSignedInts"
+
+ data_gen = ProtobufMessageGen([
+ PbScalar("si32", 1, IntegerGen(
+ special_cases=[-1, 0, 1, -2147483648, 2147483647]), encoding='zigzag'),
+ PbScalar("si64", 2, LongGen(
+ special_cases=[-1, 0, 1, -9223372036854775808, 9223372036854775807]),
+ encoding='zigzag'),
+ PbScalar("sf32", 3, IntegerGen(
+ special_cases=[0, 1, -1, 2147483647, -2147483648]), encoding='fixed'),
+ PbScalar("sf64", 4, LongGen(
+ special_cases=[0, 1, -1]), encoding='fixed'),
+ ])
+
+ def run_on_spark(spark):
+ df = gen_df(spark, data_gen)
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+
+ return df.select(
+ decoded.getField("si32").alias("si32"),
+ decoded.getField("si64").alias("si64"),
+ decoded.getField("sf32").alias("sf32"),
+ decoded.getField("sf64").alias("sf64"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+def _build_fixed_int_descriptor_set_bytes(spark):
+ """
+ Build a FileDescriptorSet for message with fixed-width integer types:
+ message WithFixedInts {
+ optional fixed32 fx32 = 1;
+ optional fixed64 fx64 = 2;
+ }
+ """
+ D, fd = _new_proto2_file(spark, "fixed_int.proto")
+
+ label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL
+
+ msg = D.DescriptorProto.newBuilder().setName("WithFixedInts")
+ msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("fx32")
+ .setNumber(1)
+ .setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_FIXED32)
+ .build()
+ )
+ msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("fx64")
+ .setNumber(2)
+ .setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_FIXED64)
+ .build()
+ )
+ fd.addMessageType(msg.build())
+
+ fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build()
+ return bytes(fds.toByteArray())
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_fixed_integers(spark_tmp_path, from_protobuf_fn):
+ """
+ Test decoding fixed-width unsigned integer types.
+ """
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "fixed.desc", _build_fixed_int_descriptor_set_bytes)
+ message_name = "test.WithFixedInts"
+
+ data_gen = ProtobufMessageGen([
+ PbScalar("fx32", 1, IntegerGen(
+ special_cases=[0, 1, -1, 2147483647, -2147483648]), encoding='fixed'),
+ PbScalar("fx64", 2, LongGen(
+ special_cases=[0, 1, -1]), encoding='fixed'),
+ ])
+
+ def run_on_spark(spark):
+ df = gen_df(spark, data_gen)
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ return df.select(
+ decoded.getField("fx32").alias("fx32"),
+ decoded.getField("fx64").alias("fx64"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+def _build_schema_projection_descriptor_set_bytes(spark):
+ """
+ Build a FileDescriptorSet for nested schema projection testing:
+ message Detail {
+ optional int32 a = 1;
+ optional int32 b = 2;
+ optional string c = 3;
+ }
+ message SchemaProj {
+ optional int32 id = 1;
+ optional string name = 2;
+ optional Detail detail = 3;
+ repeated Detail items = 4;
+ }
+ The Detail message has 3 fields so we can test pruning subsets.
+ """
+ D, fd = _new_proto2_file(spark, "schema_proj.proto")
+
+ label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL
+ label_rep = D.FieldDescriptorProto.Label.LABEL_REPEATED
+
+ # Detail message: { a: int32, b: int32, c: string }
+ detail_msg = D.DescriptorProto.newBuilder().setName("Detail")
+ detail_msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("a").setNumber(1).setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_INT32).build())
+ detail_msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("b").setNumber(2).setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_INT32).build())
+ detail_msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("c").setNumber(3).setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_STRING).build())
+ fd.addMessageType(detail_msg.build())
+
+ # SchemaProj message: { id, name, detail: Detail, items: repeated Detail }
+ main_msg = D.DescriptorProto.newBuilder().setName("SchemaProj")
+ main_msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("id").setNumber(1).setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_INT32).build())
+ main_msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("name").setNumber(2).setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_STRING).build())
+ main_msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("detail").setNumber(3).setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE)
+ .setTypeName(".test.Detail").build())
+ main_msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("items").setNumber(4).setLabel(label_rep)
+ .setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE)
+ .setTypeName(".test.Detail").build())
+ fd.addMessageType(main_msg.build())
+
+ fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build()
+ return bytes(fds.toByteArray())
+
+
+# Field descriptors for SchemaProj: {id, name, detail: {a, b, c}, items[]: {a, b, c}}
+_detail_children = [
+ PbScalar("a", 1, IntegerGen()),
+ PbScalar("b", 2, IntegerGen()),
+ PbScalar("c", 3, StringGen()),
+]
+_schema_proj_fields = [
+ PbScalar("id", 1, IntegerGen()),
+ PbScalar("name", 2, StringGen()),
+ PbNested("detail", 3, _detail_children),
+ PbRepeatedMessage("items", 4, _detail_children),
+]
+
+_schema_proj_test_data = [
+ encode_pb_message(_schema_proj_fields,
+ [1, "alice", (10, 20, "d1"), [(100, 200, "i1"), (101, 201, "i2")]]),
+ encode_pb_message(_schema_proj_fields,
+ [2, "bob", (30, 40, "d2"), [(300, 400, "i3")]]),
+ encode_pb_message(_schema_proj_fields,
+ [3, "carol", (50, 60, "d3"), []]),
+]
+
+
+_schema_proj_cases = [
+ ("nested_single_field", [("id", ("id",)), ("detail_a", ("detail", "a"))]),
+ ("nested_two_fields", [("detail_a", ("detail", "a")), ("detail_c", ("detail", "c"))]),
+ ("whole_struct_no_pruning", [("id", ("id",)), ("detail", ("detail",))]),
+ ("whole_and_subfield", [("detail", ("detail",)), ("detail_a", ("detail", "a"))]),
+ ("scalar_plus_nested", [("id", ("id",)), ("name", ("name",)), ("detail_a", ("detail", "a"))]),
+ ("repeated_msg_single_subfield", [("id", ("id",)), ("items_a", ("items", "a"))]),
+ ("repeated_msg_two_subfields", [("items_a", ("items", "a")), ("items_c", ("items", "c"))]),
+ ("repeated_whole_no_pruning", [("id", ("id",)), ("items", ("items",))]),
+ ("mix_struct_and_repeated", [("id", ("id",)), ("detail_a", ("detail", "a")), ("items_c", ("items", "c"))]),
+]
+
+
+def _get_field_by_path(expr, path):
+ current = expr
+ for name in path:
+ current = current.getField(name)
+ return current
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_projection_across_alias_project_boundary(
+ spark_tmp_path, from_protobuf_fn):
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "schema_proj_alias.desc",
+ _build_schema_projection_descriptor_set_bytes)
+ message_name = "test.SchemaProj"
+
+ def run_on_spark(spark):
+ df = spark.createDataFrame([(row,) for row in _schema_proj_test_data], schema="bin binary")
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ aliased = df.select(decoded.alias("decoded"))
+ return aliased.select(
+ f.col("decoded").getField("detail").getField("a").alias("detail_a"),
+ f.col("decoded").getField("id").alias("id"))
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_projection_across_withcolumn_boundary(
+ spark_tmp_path, from_protobuf_fn):
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "schema_proj_withcolumn.desc",
+ _build_schema_projection_descriptor_set_bytes)
+ message_name = "test.SchemaProj"
+
+ def run_on_spark(spark):
+ df = spark.createDataFrame([(row,) for row in _schema_proj_test_data], schema="bin binary")
+ with_decoded = df.withColumn(
+ "decoded",
+ _call_from_protobuf(from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes))
+ return with_decoded.select(
+ f.col("decoded").getField("items").getField("a").alias("items_a"),
+ f.col("decoded").getField("id").alias("id"))
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+def _build_dual_message_projection_descriptor_set_bytes(spark):
+ """
+ message BytesView {
+ optional int32 status = 1;
+ optional bytes payload = 2;
+ }
+ message NestedPayload {
+ optional int32 count = 1;
+ }
+ message NestedView {
+ optional int32 status = 1;
+ optional NestedPayload payload = 2;
+ }
+ """
+ D, fd = _new_proto2_file(spark, "dual_projection.proto")
+ label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL
+
+ nested_msg = D.DescriptorProto.newBuilder().setName("NestedPayload")
+ nested_msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("count").setNumber(1).setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_INT32).build())
+ fd.addMessageType(nested_msg.build())
+
+ bytes_view = D.DescriptorProto.newBuilder().setName("BytesView")
+ bytes_view.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("status").setNumber(1).setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_INT32).build())
+ bytes_view.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("payload").setNumber(2).setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_BYTES).build())
+ fd.addMessageType(bytes_view.build())
+
+ nested_view = D.DescriptorProto.newBuilder().setName("NestedView")
+ nested_view.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("status").setNumber(1).setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_INT32).build())
+ nested_view.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("payload").setNumber(2).setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE)
+ .setTypeName(".test.NestedPayload").build())
+ fd.addMessageType(nested_view.build())
+
+ fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build()
+ return bytes(fds.toByteArray())
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_different_messages_same_binary_column_do_not_interfere(
+ spark_tmp_path, from_protobuf_fn):
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "dual_projection.desc",
+ _build_dual_message_projection_descriptor_set_bytes)
+
+ payload_keep = _encode_tag(1, 0) + _encode_varint(7)
+ payload_drop = _encode_tag(1, 0) + _encode_varint(9)
+ row_keep = (_encode_tag(1, 0) + _encode_varint(1) +
+ _encode_tag(2, 2) + _encode_varint(len(payload_keep)) + payload_keep)
+ row_drop = (_encode_tag(1, 0) + _encode_varint(0) +
+ _encode_tag(2, 2) + _encode_varint(len(payload_drop)) + payload_drop)
+
+ def run_on_spark(spark):
+ df = spark.createDataFrame([(row_keep,), (row_drop,)], schema="bin binary")
+ bytes_view = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), "test.BytesView", desc_path, desc_bytes)
+ nested_view = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), "test.NestedView", desc_path, desc_bytes)
+ return df.filter(bytes_view.getField("status") == 1).select(
+ nested_view.getField("payload").getField("count").alias("payload_count"))
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+def _build_deep_nested_5_level_descriptor_set_bytes(spark):
+ """
+ Build a FileDescriptorSet for 5-level deep nesting:
+ message Level5 { optional int32 val5 = 1; }
+ message Level4 { optional int32 val4 = 1; optional Level5 level5 = 2; }
+ message Level3 { optional int32 val3 = 1; optional Level4 level4 = 2; }
+ message Level2 { optional int32 val2 = 1; optional Level3 level3 = 2; }
+ message Level1 { optional int32 val1 = 1; optional Level2 level2 = 2; }
+ """
+ D, fd = _new_proto2_file(spark, "deep_nested_5_level.proto")
+
+ label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL
+
+ # Level5 message
+ level5_msg = D.DescriptorProto.newBuilder().setName("Level5")
+ level5_msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("val5")
+ .setNumber(1)
+ .setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_INT32)
+ .build()
+ )
+ fd.addMessageType(level5_msg.build())
+
+ # Level4 message
+ level4_msg = D.DescriptorProto.newBuilder().setName("Level4")
+ level4_msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("val4")
+ .setNumber(1)
+ .setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_INT32)
+ .build()
+ )
+ level4_msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("level5")
+ .setNumber(2)
+ .setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE)
+ .setTypeName(".test.Level5")
+ .build()
+ )
+ fd.addMessageType(level4_msg.build())
+
+ # Level3 message
+ level3_msg = D.DescriptorProto.newBuilder().setName("Level3")
+ level3_msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("val3")
+ .setNumber(1)
+ .setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_INT32)
+ .build()
+ )
+ level3_msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("level4")
+ .setNumber(2)
+ .setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE)
+ .setTypeName(".test.Level4")
+ .build()
+ )
+ fd.addMessageType(level3_msg.build())
+
+ # Level2 message
+ level2_msg = D.DescriptorProto.newBuilder().setName("Level2")
+ level2_msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("val2")
+ .setNumber(1)
+ .setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_INT32)
+ .build()
+ )
+ level2_msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("level3")
+ .setNumber(2)
+ .setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE)
+ .setTypeName(".test.Level3")
+ .build()
+ )
+ fd.addMessageType(level2_msg.build())
+
+ # Level1 message
+ level1_msg = D.DescriptorProto.newBuilder().setName("Level1")
+ level1_msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("val1")
+ .setNumber(1)
+ .setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_INT32)
+ .build()
+ )
+ level1_msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("level2")
+ .setNumber(2)
+ .setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE)
+ .setTypeName(".test.Level2")
+ .build()
+ )
+ fd.addMessageType(level1_msg.build())
+
+ fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build()
+ return bytes(fds.toByteArray())
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_deep_nesting_5_levels(spark_tmp_path, from_protobuf_fn):
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "deep_nested_5_level.desc",
+ _build_deep_nested_5_level_descriptor_set_bytes)
+ message_name = "test.Level1"
+ data_gen = ProtobufMessageGen([
+ PbScalar("val1", 1, IntegerGen()),
+ PbNested("level2", 2, [
+ PbScalar("val2", 1, IntegerGen()),
+ PbNested("level3", 2, [
+ PbScalar("val3", 1, IntegerGen()),
+ PbNested("level4", 2, [
+ PbScalar("val4", 1, IntegerGen()),
+ PbNested("level5", 2, [
+ PbScalar("val5", 1, IntegerGen()),
+ ]),
+ ]),
+ ]),
+ ]),
+ ])
+ def run_on_spark(spark):
+ df = gen_df(spark, data_gen)
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ return df.select(
+ decoded.getField("val1").alias("val1"),
+ decoded.getField("level2").alias("level2"),
+ )
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@pytest.mark.parametrize("case_id,select_specs", _schema_proj_cases, ids=lambda c: c[0] if isinstance(c, tuple) else str(c))
+@ignore_order(local=True)
+def test_from_protobuf_schema_projection_cases(
+ spark_tmp_path, from_protobuf_fn, case_id, select_specs):
+ """Parametrized nested-schema projection tests."""
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "schema_proj.desc", _build_schema_projection_descriptor_set_bytes)
+ message_name = "test.SchemaProj"
+
+ def run_on_spark(spark):
+ df = spark.createDataFrame(
+ [(d,) for d in _schema_proj_test_data], schema="bin binary")
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ selected = [_get_field_by_path(decoded, path).alias(alias)
+ for alias, path in select_specs]
+ return df.select(*selected)
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+def _build_name_collision_descriptor_set_bytes(spark):
+ D, fd = _new_proto2_file(spark, "name_collision.proto")
+ label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL
+
+ # User message
+ user_msg = D.DescriptorProto.newBuilder().setName("User")
+ user_msg.addField(D.FieldDescriptorProto.newBuilder().setName("age").setNumber(1).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_INT32).build())
+ user_msg.addField(D.FieldDescriptorProto.newBuilder().setName("id").setNumber(2).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_INT32).build())
+ fd.addMessageType(user_msg.build())
+
+ # Ad message
+ ad_msg = D.DescriptorProto.newBuilder().setName("Ad")
+ ad_msg.addField(D.FieldDescriptorProto.newBuilder().setName("id").setNumber(1).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_INT32).build())
+ fd.addMessageType(ad_msg.build())
+
+ # Event message
+ event_msg = D.DescriptorProto.newBuilder().setName("Event")
+ event_msg.addField(D.FieldDescriptorProto.newBuilder().setName("user_info").setNumber(1).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE).setTypeName(".test.User").build())
+ event_msg.addField(D.FieldDescriptorProto.newBuilder().setName("ad_info").setNumber(2).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE).setTypeName(".test.Ad").build())
+ fd.addMessageType(event_msg.build())
+
+ fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build()
+ return bytes(fds.toByteArray())
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_bug1_name_collision(spark_tmp_path, from_protobuf_fn):
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "name_collision.desc",
+ _build_name_collision_descriptor_set_bytes)
+ message_name = "test.Event"
+
+ data_gen = ProtobufMessageGen([
+ PbNested("user_info", 1, [
+ PbScalar("age", 1, IntegerGen()),
+ PbScalar("id", 2, IntegerGen()),
+ ]),
+ PbNested("ad_info", 2, [
+ PbScalar("id", 1, IntegerGen()),
+ ]),
+ ])
+
+ def run_on_spark(spark):
+ df = gen_df(spark, data_gen)
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+
+ return df.select(
+ decoded.getField("user_info").getField("age").alias("age"),
+ decoded.getField("user_info").getField("id").alias("user_id"),
+ decoded.getField("ad_info").getField("id").alias("ad_id")
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+def _build_filter_jump_descriptor_set_bytes(spark):
+ D, fd = _new_proto2_file(spark, "filter_jump.proto")
+ label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL
+
+ msg = D.DescriptorProto.newBuilder().setName("Event")
+ msg.addField(D.FieldDescriptorProto.newBuilder().setName("status").setNumber(1).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_INT32).build())
+ msg.addField(D.FieldDescriptorProto.newBuilder().setName("ad_info").setNumber(2).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_STRING).build())
+ fd.addMessageType(msg.build())
+
+ fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build()
+ return bytes(fds.toByteArray())
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_bug2_filter_jump(spark_tmp_path, from_protobuf_fn):
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "filter_jump.desc",
+ _build_filter_jump_descriptor_set_bytes)
+ message_name = "test.Event"
+
+ data_gen = ProtobufMessageGen([
+ PbScalar("status", 1, IntegerGen(min_val=1, max_val=1)),
+ PbScalar("ad_info", 2, StringGen()),
+ ])
+
+ def run_on_spark(spark):
+ df = gen_df(spark, data_gen)
+ pb_expr1 = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ pb_expr2 = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+
+ return df.filter(pb_expr1.getField("status") == 1).select(pb_expr2.getField("ad_info").alias("ad_info"))
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+def _build_unrelated_struct_name_collision_descriptor_set_bytes(spark):
+ D, fd = _new_proto2_file(spark, "unrelated_struct.proto")
+ label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL
+
+ # Nested message
+ nested_msg = D.DescriptorProto.newBuilder().setName("Nested")
+ nested_msg.addField(D.FieldDescriptorProto.newBuilder().setName("dummy").setNumber(1).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_INT32).build())
+ nested_msg.addField(D.FieldDescriptorProto.newBuilder().setName("winfoid").setNumber(2).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_INT32).build())
+ fd.addMessageType(nested_msg.build())
+
+ msg = D.DescriptorProto.newBuilder().setName("Event")
+ msg.addField(D.FieldDescriptorProto.newBuilder().setName("ad_info").setNumber(1).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE).setTypeName(".test.Nested").build())
+ fd.addMessageType(msg.build())
+
+ fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build()
+ return bytes(fds.toByteArray())
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_bug3_unrelated_struct_name_collision(spark_tmp_path, from_protobuf_fn):
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "unrelated_struct.desc",
+ _build_unrelated_struct_name_collision_descriptor_set_bytes)
+ message_name = "test.Event"
+
+ data_gen = ProtobufMessageGen([
+ PbNested("ad_info", 1, [
+ PbScalar("dummy", 1, IntegerGen()),
+ PbScalar("winfoid", 2, IntegerGen()),
+ ]),
+ ])
+
+ def run_on_spark(spark):
+ df = gen_df(spark, data_gen)
+ # Write to parquet to prevent Catalyst from optimizing away the GetStructField,
+ # and to ensure it runs on the GPU.
+ df_with_other = df.withColumn("other_struct",
+ f.struct(f.lit("hello").alias("dummy_str"), f.lit(42).alias("winfoid")))
+
+ path = spark_tmp_path + "/bug3_data.parquet"
+ df_with_other.write.mode("overwrite").parquet(path)
+ read_df = spark.read.parquet(path)
+
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+
+ # We only select decoded.ad_info.winfoid, so dummy is pruned.
+ # winfoid gets ordinal 0 in the pruned schema.
+ # But for other_struct, winfoid is ordinal 1.
+ # GpuGetStructFieldMeta will see "winfoid", query the ThreadLocal, get 0,
+ # and extract ordinal 0 ("hello") for other_winfoid, causing a mismatch!
+ return read_df.select(
+ decoded.getField("ad_info").getField("winfoid").alias("pb_winfoid"),
+ f.col("other_struct").getField("winfoid").alias("other_winfoid")
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+def _build_max_depth_descriptor_set_bytes(spark):
+ D, fd = _new_proto2_file(spark, "max_depth.proto")
+ label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL
+
+ # Generate 12 levels of nesting
+ for i in range(12, 0, -1):
+ msg = D.DescriptorProto.newBuilder().setName(f"Level{i}")
+ msg.addField(D.FieldDescriptorProto.newBuilder().setName(f"val{i}").setNumber(1).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_INT32).build())
+ if i < 12:
+ msg.addField(D.FieldDescriptorProto.newBuilder().setName(f"level{i+1}").setNumber(2).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE).setTypeName(f".test.Level{i+1}").build())
+ fd.addMessageType(msg.build())
+
+ fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build()
+ return bytes(fds.toByteArray())
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_bug4_max_depth(spark_tmp_path, from_protobuf_fn):
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "max_depth.desc",
+ _build_max_depth_descriptor_set_bytes)
+ message_name = "test.Level1"
+
+ # Build the deeply nested data gen spec
+ def build_nested_gen(level):
+ if level == 12:
+ return [PbScalar(f"val{level}", 1, IntegerGen())]
+ return [
+ PbScalar(f"val{level}", 1, IntegerGen()),
+ PbNested(f"level{level+1}", 2, build_nested_gen(level+1))
+ ]
+
+ data_gen = ProtobufMessageGen(build_nested_gen(1))
+
+ def run_on_spark(spark):
+ df = gen_df(spark, data_gen)
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ # Deep access
+ field = decoded
+ for i in range(2, 13):
+ field = field.getField(f"level{i}")
+ return df.select(field.getField("val12").alias("val12"))
+
+ # Depth 12 exceeds GPU max nesting depth (10), so the query should
+ # gracefully fall back to CPU. Verify that it still produces correct
+ # results (CPU path) without crashing.
+ from spark_session import with_cpu_session
+ cpu_result = with_cpu_session(lambda spark: run_on_spark(spark).collect())
+ assert len(cpu_result) > 0
+
+
+# ===========================================================================
+# Regression tests for known bugs found by code review
+# ===========================================================================
+
+def _encode_varint(value):
+ """Encode a non-negative integer as a protobuf varint (for hand-crafting test bytes)."""
+ out = bytearray()
+ v = int(value)
+ while True:
+ b = v & 0x7F
+ v >>= 7
+ if v:
+ out.append(b | 0x80)
+ else:
+ out.append(b)
+ break
+ return bytes(out)
+
+
+def _encode_tag(field_number, wire_type):
+ return _encode_varint((field_number << 3) | wire_type)
+
+
+# ---------------------------------------------------------------------------
+# Bug 1: BOOL8 truncation — non-canonical bool varint values >= 256
+#
+# Protobuf spec: bool is a varint; any non-zero value means true.
+# CPU decoder (protobuf-java): CodedInputStream.readBool() = readRawVarint64() != 0 → true
+# GPU decoder: extract_varint_kernel writes static_cast<uint8_t>(v).
+# For v = 256, static_cast<uint8_t>(256) == 0 → false. BUG.
+# ---------------------------------------------------------------------------
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_bool_noncanonical_varint_scalar(spark_tmp_path, from_protobuf_fn):
+ """Regression test: scalar bool encoded as non-canonical varint (e.g. 256) must decode as true.
+
+ Protobuf allows any non-zero varint for bool true. The GPU decoder previously
+ truncated to uint8_t, causing values >= 256 to wrap to 0 (false).
+ """
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "simple_bool_bug.desc", _build_simple_descriptor_set_bytes)
+ message_name = "test.Simple"
+
+ # varint(256) = 0x80 0x02, varint(512) = 0x80 0x04 — valid non-canonical bool true
+ row_bool_256 = _encode_tag(1, 0) + _encode_varint(256) + \
+ _encode_tag(2, 0) + _encode_varint(99)
+
+ # Control row: canonical bool true (varint 1) — should work on both
+ row_bool_1 = _encode_tag(1, 0) + _encode_varint(1) + \
+ _encode_tag(2, 0) + _encode_varint(100)
+
+ # Another non-canonical value: varint(512)
+ row_bool_512 = _encode_tag(1, 0) + _encode_varint(512) + \
+ _encode_tag(2, 0) + _encode_varint(101)
+
+ def run_on_spark(spark):
+ df = spark.createDataFrame(
+ [(row_bool_256,), (row_bool_1,), (row_bool_512,)],
+ schema="bin binary",
+ )
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ return df.select(
+ decoded.getField("b").alias("b"),
+ decoded.getField("i32").alias("i32"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+def _build_repeated_bool_descriptor_set_bytes(spark):
+ """
+ message WithRepeatedBool {
+ optional int32 id = 1;
+ repeated bool flags = 2;
+ }
+ """
+ D, fd = _new_proto2_file(spark, "repeated_bool.proto")
+ label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL
+ label_rep = D.FieldDescriptorProto.Label.LABEL_REPEATED
+ msg = D.DescriptorProto.newBuilder().setName("WithRepeatedBool")
+ msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("id").setNumber(1).setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_INT32).build())
+ msg.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("flags").setNumber(2).setLabel(label_rep)
+ .setType(D.FieldDescriptorProto.Type.TYPE_BOOL).build())
+ fd.addMessageType(msg.build())
+ fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build()
+ return bytes(fds.toByteArray())
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_bool_noncanonical_varint_repeated(spark_tmp_path, from_protobuf_fn):
+ """Regression test: repeated bool with non-canonical varint values must all decode as true.
+
+ Same uint8_t truncation issue as the scalar case, exercised with repeated fields.
+ """
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "repeated_bool_bug.desc", _build_repeated_bool_descriptor_set_bytes)
+ message_name = "test.WithRepeatedBool"
+
+ # Repeated bool field 2 (wire type 0 = varint), unpacked.
+ # Three elements: varint(256), varint(1), varint(512) — all should decode as true.
+ row = (_encode_tag(1, 0) + _encode_varint(42) +
+ _encode_tag(2, 0) + _encode_varint(256) +
+ _encode_tag(2, 0) + _encode_varint(1) +
+ _encode_tag(2, 0) + _encode_varint(512))
+
+ def run_on_spark(spark):
+ df = spark.createDataFrame([(row,)], schema="bin binary")
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ return df.select(
+ decoded.getField("id").alias("id"),
+ decoded.getField("flags").alias("flags"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+# ---------------------------------------------------------------------------
+# Regression guard: nested message child field default values
+# ---------------------------------------------------------------------------
+
+def _build_nested_with_defaults_descriptor_set_bytes(spark):
+ """
+ message Inner {
+ optional int32 count = 1 [default = 42];
+ optional string label = 2 [default = "hello"];
+ optional bool flag = 3 [default = true];
+ }
+ message OuterWithNestedDefaults {
+ optional int32 id = 1;
+ optional Inner inner = 2;
+ }
+ """
+ D, fd = _new_proto2_file(spark, "nested_defaults.proto")
+ label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL
+
+ inner = D.DescriptorProto.newBuilder().setName("Inner")
+ inner.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("count").setNumber(1).setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_INT32)
+ .setDefaultValue("42").build())
+ inner.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("label").setNumber(2).setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_STRING)
+ .setDefaultValue("hello").build())
+ inner.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("flag").setNumber(3).setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_BOOL)
+ .setDefaultValue("true").build())
+ fd.addMessageType(inner.build())
+
+ outer = D.DescriptorProto.newBuilder().setName("OuterWithNestedDefaults")
+ outer.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("id").setNumber(1).setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_INT32).build())
+ outer.addField(
+ D.FieldDescriptorProto.newBuilder()
+ .setName("inner").setNumber(2).setLabel(label_opt)
+ .setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE)
+ .setTypeName(".test.Inner").build())
+ fd.addMessageType(outer.build())
+
+ fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build()
+ return bytes(fds.toByteArray())
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_nested_child_default_values(spark_tmp_path, from_protobuf_fn):
+ """Regression test: proto2 default values for fields inside nested messages must be honored.
+
+ When `inner` is present but its child fields are absent, the decoder must
+ return the proto2 defaults (count=42, label="hello", flag=true), not null.
+ """
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "nested_defaults.desc",
+ _build_nested_with_defaults_descriptor_set_bytes)
+ message_name = "test.OuterWithNestedDefaults"
+
+ # Row 1: outer.id = 10, inner is present but EMPTY (0-length nested message).
+ # Wire: field 1 varint(10), field 2 length-delimited with length 0.
+ # CPU should fill inner.count=42, inner.label="hello", inner.flag=true.
+ row_empty_inner = (_encode_tag(1, 0) + _encode_varint(10) +
+ _encode_tag(2, 2) + _encode_varint(0))
+
+ # Row 2: outer.id = 20, inner has only count=7 (label and flag should get defaults).
+ inner_partial = _encode_tag(1, 0) + _encode_varint(7)
+ row_partial_inner = (_encode_tag(1, 0) + _encode_varint(20) +
+ _encode_tag(2, 2) + _encode_varint(len(inner_partial)) +
+ inner_partial)
+
+ # Row 3: outer.id = 30, inner is fully absent → inner itself is null.
+ row_no_inner = _encode_tag(1, 0) + _encode_varint(30)
+
+ def run_on_spark(spark):
+ df = spark.createDataFrame(
+ [(row_empty_inner,), (row_partial_inner,), (row_no_inner,)],
+ schema="bin binary",
+ )
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ return df.select(
+ decoded.getField("id").alias("id"),
+ decoded.getField("inner").getField("count").alias("inner_count"),
+ decoded.getField("inner").getField("label").alias("inner_label"),
+ decoded.getField("inner").getField("flag").alias("inner_flag"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+# ===========================================================================
+# Deep nested schema pruning tests
+#
+# These verify that the GPU path correctly prunes nested fields at depth > 2.
+# Previously, collectStructFieldReferences only recognized 2-level
+# GetStructField chains, so accessing decoded.level2.level3.val3 would
+# decode ALL of level3's children instead of only val3.
+# ===========================================================================
+
+def _deep_5_level_data_gen():
+ return ProtobufMessageGen([
+ PbScalar("val1", 1, IntegerGen()),
+ PbNested("level2", 2, [
+ PbScalar("val2", 1, IntegerGen()),
+ PbNested("level3", 2, [
+ PbScalar("val3", 1, IntegerGen()),
+ PbNested("level4", 2, [
+ PbScalar("val4", 1, IntegerGen()),
+ PbNested("level5", 2, [
+ PbScalar("val5", 1, IntegerGen()),
+ ]),
+ ]),
+ ]),
+ ]),
+ ])
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_deep_pruning_3_level_leaf(spark_tmp_path, from_protobuf_fn):
+ """Access decoded.level2.level3.val3 -- triggers 3-level deep pruning."""
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "dp3.desc", _build_deep_nested_5_level_descriptor_set_bytes)
+ message_name = "test.Level1"
+ data_gen = _deep_5_level_data_gen()
+
+ def run_on_spark(spark):
+ df = gen_df(spark, data_gen)
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ return df.select(
+ decoded.getField("val1").alias("val1"),
+ decoded.getField("level2").getField("level3").getField("val3").alias("deep_val3"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_deep_pruning_5_level_leaf(spark_tmp_path, from_protobuf_fn):
+ """Access decoded.level2.level3.level4.level5.val5 -- deepest leaf."""
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "dp5.desc", _build_deep_nested_5_level_descriptor_set_bytes)
+ message_name = "test.Level1"
+ data_gen = _deep_5_level_data_gen()
+
+ def run_on_spark(spark):
+ df = gen_df(spark, data_gen)
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ return df.select(
+ _get_field_by_path(decoded, ["level2", "level3", "level4", "level5", "val5"])
+ .alias("val5"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_deep_pruning_mixed_depths(spark_tmp_path, from_protobuf_fn):
+ """Access leaves at different depths in the same query.
+
+ Select val1 (depth 1), val2 (depth 2), val3 (depth 3), and val5 (depth 5)
+ to exercise pruning at every intermediate level simultaneously.
+ """
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "dp_mix.desc", _build_deep_nested_5_level_descriptor_set_bytes)
+ message_name = "test.Level1"
+ data_gen = _deep_5_level_data_gen()
+
+ def run_on_spark(spark):
+ df = gen_df(spark, data_gen)
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ return df.select(
+ decoded.getField("val1").alias("val1"),
+ decoded.getField("level2").getField("val2").alias("val2"),
+ _get_field_by_path(decoded, ["level2", "level3", "val3"]).alias("val3"),
+ _get_field_by_path(decoded, ["level2", "level3", "level4", "level5", "val5"])
+ .alias("val5"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_deep_pruning_sibling_at_depth_3(spark_tmp_path, from_protobuf_fn):
+ """At depth 3, access val3 but NOT level4 -- level4 subtree should be pruned."""
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "dp_sib3.desc", _build_deep_nested_5_level_descriptor_set_bytes)
+ message_name = "test.Level1"
+ data_gen = _deep_5_level_data_gen()
+
+ def run_on_spark(spark):
+ df = gen_df(spark, data_gen)
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ return df.select(
+ _get_field_by_path(decoded, ["level2", "level3", "val3"]).alias("val3"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_deep_pruning_whole_struct_at_depth_3(spark_tmp_path, from_protobuf_fn):
+ """Select the whole level3 struct -- no deep pruning inside level3."""
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "dp_whole3.desc", _build_deep_nested_5_level_descriptor_set_bytes)
+ message_name = "test.Level1"
+ data_gen = _deep_5_level_data_gen()
+
+ def run_on_spark(spark):
+ df = gen_df(spark, data_gen)
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ return df.select(
+ decoded.getField("level2").getField("level3").alias("level3"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+# ===========================================================================
+# FAILFAST mode tests
+# ===========================================================================
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+def test_from_protobuf_failfast_malformed_data(spark_tmp_path, from_protobuf_fn):
+ """FAILFAST mode should throw on malformed protobuf data (both CPU and GPU)."""
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "failfast.desc", _build_simple_descriptor_set_bytes)
+ message_name = "test.Simple"
+
+ # Craft a valid row and a malformed row (truncated varint with continuation bit)
+ valid_row = _encode_tag(1, 0) + _encode_varint(1) + \
+ _encode_tag(2, 0) + _encode_varint(42)
+ malformed_row = bytes([0x08, 0x80]) # field 1, varint, but only continuation byte -- no end
+
+ def run_on_spark(spark):
+ df = spark.createDataFrame(
+ [(valid_row,), (malformed_row,)],
+ schema="bin binary",
+ )
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes,
+ options={"mode": "FAILFAST"})
+ # Must call .collect() so the exception surfaces inside with_*_session
+ return df.select(decoded.getField("b").alias("b")).collect()
+
+ assert_gpu_and_cpu_error(run_on_spark, {}, "Malformed")
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_permissive_malformed_returns_null(spark_tmp_path, from_protobuf_fn):
+ """PERMISSIVE mode should return null for malformed rows, not throw.
+
+ Note: Spark's from_protobuf defaults to FAILFAST (unlike JSON/CSV which
+ default to PERMISSIVE), so mode must be set explicitly.
+ """
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "permissive.desc", _build_simple_descriptor_set_bytes)
+ message_name = "test.Simple"
+
+ valid_row = _encode_tag(2, 0) + _encode_varint(99)
+ malformed_row = bytes([0x08, 0x80]) # truncated varint
+
+ def run_on_spark(spark):
+ df = spark.createDataFrame(
+ [(valid_row,), (malformed_row,)],
+ schema="bin binary",
+ )
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes,
+ options={"mode": "PERMISSIVE"})
+ return df.select(
+ decoded.getField("i32").alias("i32"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_all_null_input(spark_tmp_path, from_protobuf_fn):
+ """All rows in the input binary column are null (not empty bytes, actual nulls).
+ GPU should produce all-null struct rows matching CPU behavior."""
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "allnull.desc", _build_simple_descriptor_set_bytes)
+ message_name = "test.Simple"
+
+ def run_on_spark(spark):
+ df = spark.createDataFrame(
+ [(None,), (None,), (None,)],
+ schema="bin binary",
+ )
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ return df.select(
+ decoded.getField("i32").alias("i32"),
+ decoded.getField("s").alias("s"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
diff --git a/integration_tests/src/main/python/spark_init_internal.py b/integration_tests/src/main/python/spark_init_internal.py
index 6d67538d513..95f988cbc2b 100644
--- a/integration_tests/src/main/python/spark_init_internal.py
+++ b/integration_tests/src/main/python/spark_init_internal.py
@@ -61,11 +61,46 @@ def findspark_init():
if spark_jars is not None:
logging.info(f"Adding to findspark jars: {spark_jars}")
findspark.add_jars(spark_jars)
+ # Also add to driver classpath so classes are available to Class.forName()
+ # This is needed for optional modules like spark-protobuf
+ _add_driver_classpath(spark_jars)
if spark_jars_packages is not None:
logging.info(f"Adding to findspark packages: {spark_jars_packages}")
findspark.add_packages(spark_jars_packages)
+
+def _add_driver_classpath(jars):
+ """
+ Add jars to the driver classpath via PYSPARK_SUBMIT_ARGS.
+ findspark.add_jars() only adds --jars, which doesn't make classes available
+ to Class.forName() on the driver. This function adds --driver-class-path.
+ """
+ if not jars:
+ return
+ current_args = os.environ.get('PYSPARK_SUBMIT_ARGS', '')
+ # Remove trailing 'pyspark-shell' if present
+ if current_args.endswith('pyspark-shell'):
+ current_args = current_args[:-len('pyspark-shell')].strip()
+ jar_list = [j.strip() for j in jars.split(',') if j.strip()]
+ new_cp = os.pathsep.join(jar_list)
+ if '--driver-class-path' in current_args:
+ match = re.search(r'--driver-class-path\s+(\S+)', current_args)
+ if match:
+ existing_cp = match.group(1)
+ merged_cp = existing_cp + os.pathsep + new_cp
+ current_args = re.sub(
+ r'--driver-class-path\s+\S+',
+ lambda m: f'--driver-class-path {merged_cp}',
+ current_args)
+ else:
+ current_args += f' --driver-class-path {new_cp}'
+ else:
+ current_args += f' --driver-class-path {new_cp}'
+ new_args = f"{current_args} pyspark-shell".strip()
+ os.environ['PYSPARK_SUBMIT_ARGS'] = new_args
+ logging.info(f"Updated PYSPARK_SUBMIT_ARGS with driver-class-path")
+
def running_with_xdist(session, is_worker):
try:
import xdist
diff --git a/integration_tests/src/test/resources/protobuf_test/gen_nested_proto_data.sh b/integration_tests/src/test/resources/protobuf_test/gen_nested_proto_data.sh
new file mode 100755
index 00000000000..2283ac8a751
--- /dev/null
+++ b/integration_tests/src/test/resources/protobuf_test/gen_nested_proto_data.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+# Convenience script: compile nested_proto .proto files into a descriptor set.
+#
+# Usage:
+# ./gen_nested_proto_data.sh
+#
+# The generated .desc file is checked into the repository and used by
+# integration tests in protobuf_test.py. Re-run this script whenever
+# the .proto definitions under nested_proto/ change.
+
+set -e
+
+SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
+PROTO_DIR="${SCRIPT_DIR}/nested_proto"
+OUTPUT_DIR="${SCRIPT_DIR}/nested_proto/generated"
+
+echo "=== Protobuf Descriptor Compiler ==="
+echo "Proto dir: ${PROTO_DIR}"
+echo ""
+
+# Create output directory
+mkdir -p "${OUTPUT_DIR}"
+
+# Compile proto files into a descriptor set (includes all imports)
+DESC_FILE="${OUTPUT_DIR}/main_log.desc"
+echo "Compiling proto files..."
+protoc \
+ --descriptor_set_out="${DESC_FILE}" \
+ --include_imports \
+ -I"${PROTO_DIR}" \
+ "${PROTO_DIR}/main_log.proto"
+
+echo "Generated: ${DESC_FILE}"
+echo "=== Done ==="
diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/device_req.proto b/integration_tests/src/test/resources/protobuf_test/nested_proto/device_req.proto
new file mode 100644
index 00000000000..c4d86951d96
--- /dev/null
+++ b/integration_tests/src/test/resources/protobuf_test/nested_proto/device_req.proto
@@ -0,0 +1,11 @@
+syntax = "proto2";
+
+package com.test.proto.sample;
+
+option java_outer_classname = "DeviceReqBean";
+
+// Device request field
+message DeviceReqField {
+ optional int32 os_type = 1; // int32
+ optional bytes device_id = 2; // bytes
+}
diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/generated/main_log.desc b/integration_tests/src/test/resources/protobuf_test/nested_proto/generated/main_log.desc
new file mode 100644
index 00000000000..6e8155238b3
Binary files /dev/null and b/integration_tests/src/test/resources/protobuf_test/nested_proto/generated/main_log.desc differ
diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/main_log.proto b/integration_tests/src/test/resources/protobuf_test/nested_proto/main_log.proto
new file mode 100644
index 00000000000..f9a325a9f2e
--- /dev/null
+++ b/integration_tests/src/test/resources/protobuf_test/nested_proto/main_log.proto
@@ -0,0 +1,103 @@
+syntax = "proto2";
+
+package com.test.proto.sample;
+
+import "module_a_res.proto";
+import "module_b_res.proto";
+import "device_req.proto";
+
+// ========== Enum type tests ==========
+enum SourceType {
+ WEB = 0; // web
+ APP = 1; // application
+ MOBILE = 4; // mobile
+}
+
+enum ChannelType {
+ CHANNEL_A = 0;
+ CHANNEL_B = 1;
+}
+
+// ========== Main log record ==========
+message MainLogRecord {
+ // required fields
+ required SourceType source = 1; // required enum
+ required uint64 timestamp = 2; // required uint64
+
+ // optional scalar types - one of each
+ optional string user_id = 3; // string
+ optional int64 account_id = 4; // int64
+ optional fixed32 client_ip = 5; // fixed32
+
+ // nested message
+ optional LogContent log_content = 6;
+}
+
+// ========== Log content (multi-level nesting) ==========
+message LogContent {
+ optional BasicInfo basic_info = 1;
+ repeated ChannelInfo channel_list = 2;
+ repeated DataSourceField source_list = 3;
+}
+
+// ========== Basic info (three-level nesting) ==========
+message BasicInfo {
+ optional RequestInfo request_info = 1;
+ optional ExtendedReqInfo extended_req_info = 2;
+ optional ServerAddedField server_added_field = 3;
+}
+
+// ========== Request info ==========
+message RequestInfo {
+ optional uint32 page_num = 1; // uint32
+ optional string channel_code = 2; // string
+ repeated uint32 experiment_ids = 3; // repeated uint32
+ optional bool is_filtered = 4; // bool
+}
+
+// ========== Extended request info (cross-file import) ==========
+message ExtendedReqInfo {
+ optional DeviceReqField device_req_field = 1; // reference to external proto
+}
+
+// ========== Server-added fields ==========
+message ServerAddedField {
+ optional uint32 region_code = 1; // uint32
+ optional string flow_type = 2; // string
+ optional int32 filter_result = 3; // int32
+ repeated int32 hit_rule_list = 4; // repeated int32
+ optional uint64 request_time = 5; // uint64
+ optional bool skip_flag = 6; // bool
+}
+
+// ========== Channel info ==========
+message ChannelInfo {
+ optional int32 channel_id = 1;
+ optional ModuleAResField module_a_res = 2; // reference to external proto
+}
+
+// ========== Source channel info ==========
+message SrcChannelInfo {
+ optional int32 channel_id = 1;
+ optional ModuleASrcResField module_a_src_res = 2; // reference to external proto
+}
+
+// ========== Data source field ==========
+message DataSourceField {
+ optional uint32 source_id = 1;
+ repeated SrcChannelInfo src_channel_list = 2;
+ optional string billing_name = 3;
+ repeated ItemDetailField item_list = 4;
+ optional bool is_free = 5;
+}
+
+// ========== Item detail field ==========
+message ItemDetailField {
+ optional uint32 rank = 1; // uint32
+ optional uint64 record_id = 2; // uint64
+ optional string keyword = 3; // string
+
+ // cross-file message references
+ optional ModuleADetailField module_a_detail = 4;
+ optional ModuleBDetailField module_b_detail = 5;
+}
diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/module_a_res.proto b/integration_tests/src/test/resources/protobuf_test/nested_proto/module_a_res.proto
new file mode 100644
index 00000000000..9a2674e5dd5
--- /dev/null
+++ b/integration_tests/src/test/resources/protobuf_test/nested_proto/module_a_res.proto
@@ -0,0 +1,92 @@
+syntax = "proto2";
+
+package com.test.proto.sample;
+
+option java_outer_classname = "ModuleARes";
+
+import "predictor_schema.proto";
+
+// ========== Test default values ==========
+message PartnerInfo {
+ optional string token = 1 [default = ""]; // default empty string
+ optional uint64 partner_id = 2 [default = 0]; // default 0
+}
+
+// ========== Test coordinate structure (simple nesting) ==========
+message Coordinate {
+ optional double x = 1; // x coordinate - double
+ optional double y = 2; // y coordinate - double
+}
+
+// ========== Test multi-level nesting ==========
+message LocationPoint {
+ optional uint32 frequency = 1; // frequency - uint32
+ optional Coordinate coord = 2; // coordinate - nested message
+ optional uint64 timestamp = 3; // timestamp - uint64
+}
+
+// ========== Test change log ==========
+message ChangeLog {
+ optional uint32 value_before = 1; // value before change
+ optional string parameters = 2; // parameters
+}
+
+// ========== Test price change log ==========
+message PriceLog {
+ optional uint32 price_before = 1;
+}
+
+// ========== Module A response-level fields ==========
+message ModuleAResField {
+ optional string route_tag = 1; // route tag - string
+ optional int32 status_tag = 2; // status tag - int32
+ optional uint32 region_id = 3; // region id - uint32
+ repeated string experiment_ids = 4; // experiment id list - repeated string
+ optional double quality_score = 5; // quality score - double
+ repeated LocationPoint location_points = 6; // location points - repeated nested message
+ repeated uint64 interest_ids = 7; // interest id list - repeated uint64
+}
+
+// ========== Module A source response field ==========
+message ModuleASrcResField {
+ optional uint32 match_type = 1; // match type
+}
+
+// ========== Key-value pair ==========
+message KVPair {
+ optional bytes key = 1; // key - bytes
+ optional bytes value = 2; // value - bytes
+}
+
+// ========== Style configuration ==========
+message StyleConfig {
+ optional uint32 style_id = 1; // style id
+ repeated KVPair kv_pairs = 2; // kv pair list - repeated nested message
+}
+
+// ========== Module A detail field (core complex structure) ==========
+message ModuleADetailField {
+ // scalar types - one or two of each
+ optional uint32 type_code = 1; // uint32
+ optional uint64 item_id = 2; // uint64
+ optional int32 strategy_type = 3; // int32
+ optional int64 min_value = 4; // int64
+ optional bytes target_url = 5; // bytes
+ optional string title = 6; // string
+ optional bool is_valid = 7; // bool
+ optional float score_ratio = 8; // float
+
+ // repeated scalar types
+ repeated uint32 template_ids = 9; // repeated uint32
+ repeated uint64 material_ids = 10; // repeated uint64
+
+ // repeated nested messages
+ repeated StyleConfig styles = 11; // repeated message
+ repeated ChangeLog change_logs = 12; // repeated message
+
+ // nested message
+ optional PartnerInfo partner_info = 13; // nested message
+
+ // cross-file import
+ optional PredictorSchema predictor_schema = 14; // reference to external proto
+}
diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/module_b_res.proto b/integration_tests/src/test/resources/protobuf_test/nested_proto/module_b_res.proto
new file mode 100644
index 00000000000..0c742739835
--- /dev/null
+++ b/integration_tests/src/test/resources/protobuf_test/nested_proto/module_b_res.proto
@@ -0,0 +1,29 @@
+syntax = "proto2";
+
+package com.test.proto.sample;
+
+// Module B response field
+message ModuleBResField {
+ optional uint32 type_code = 1; // uint32
+ optional string extra_info = 2; // string
+}
+
+// Block element - tests repeated scalar types
+message BlockElement {
+ optional uint64 element_id = 1; // uint64
+ repeated uint64 ref_ids = 2; // repeated uint64
+}
+
+// Block info - tests repeated nested messages
+message BlockInfo {
+ optional uint64 block_id = 1; // uint64
+ repeated BlockElement elements = 2; // repeated message
+}
+
+// Module B detail field
+message ModuleBDetailField {
+ repeated uint32 tags = 1; // repeated uint32
+ optional uint64 item_id = 2; // uint64
+ optional string name = 3; // string
+ repeated BlockInfo blocks = 4; // repeated message
+}
diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/predictor_schema.proto b/integration_tests/src/test/resources/protobuf_test/nested_proto/predictor_schema.proto
new file mode 100644
index 00000000000..bc7d86c67e2
--- /dev/null
+++ b/integration_tests/src/test/resources/protobuf_test/nested_proto/predictor_schema.proto
@@ -0,0 +1,82 @@
+syntax = "proto2";
+
+package com.test.proto.sample;
+
+// Predictor schema - tests multi-level schema nesting
+
+// ========== Main schema structure ==========
+message PredictorSchema {
+ optional SchemaTypeA type_a_schema = 1; // nested
+ optional SchemaTypeB type_b_schema = 2; // nested
+ optional SchemaTypeC type_c_schema = 3; // nested (with repeated)
+}
+
+// ========== TypeA Query Schema ==========
+message TypeAQuerySchema {
+ optional string keyword = 1; // keyword
+ optional string session_id = 2; // session id
+}
+
+// ========== TypeA Pair Schema ==========
+message TypeAPairSchema {
+ optional string record_id = 1; // record id
+ optional string item_id = 2; // item id
+}
+
+// ========== TypeA Schema ==========
+message SchemaTypeA {
+ optional TypeAQuerySchema query_schema = 1; // query-level schema
+ repeated TypeAPairSchema pair_schema = 2; // pair list - repeated nested
+}
+
+// ========== TypeB Query Schema ==========
+message TypeBQuerySchema {
+ optional string profile_tag_id = 1; // profile tag id
+ optional string entity_id = 2; // entity id
+}
+
+// ========== TypeB Style Element ==========
+message TypeBStyleElem {
+ optional string template_id = 1; // template id
+ optional string material_id = 2; // material id
+}
+
+// ========== TypeB Style Schema ==========
+message TypeBStyleSchema {
+ repeated TypeBStyleElem values = 1; // element list - repeated nested
+}
+
+// ========== TypeB Schema ==========
+message SchemaTypeB {
+ optional TypeBQuerySchema query_schema = 1; // query-level schema
+ repeated TypeBStyleSchema style_schema = 2; // style list - repeated nested
+}
+
+// ========== TypeC Query Schema ==========
+message TypeCQuerySchema {
+ optional string keyword = 1; // keyword
+ optional string category = 2; // category
+}
+
+// ========== TypeC Pair Schema ==========
+message TypeCPairSchema {
+ optional string item_id = 1; // item id
+ optional string target_url = 2; // target url
+}
+
+// ========== TypeC Style Element (empty structure test) ==========
+message TypeCStyleElem {
+ // empty message - tests empty structure handling
+}
+
+// ========== TypeC Style Schema ==========
+message TypeCStyleSchema {
+ repeated TypeCStyleElem values = 1; // empty element list
+}
+
+// ========== TypeC Schema (full three-level structure) ==========
+message SchemaTypeC {
+ optional TypeCQuerySchema query_schema = 1; // query schema
+ repeated TypeCPairSchema pair_schema = 2; // pair list
+ repeated TypeCStyleSchema style_schema = 3; // style list
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index fef7bea85dd..29ad563f6b3 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -2628,10 +2628,7 @@ object GpuOverrides extends Logging {
TypeSig.STRUCT.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.MAP + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.BINARY),
TypeSig.STRUCT.nested(TypeSig.all)),
- (expr, conf, p, r) => new UnaryExprMeta[GetStructField](expr, conf, p, r) {
- override def convertToGpu(arr: Expression): GpuExpression =
- GpuGetStructField(arr, expr.ordinal, expr.name)
- }),
+ (expr, conf, p, r) => new GpuGetStructFieldMeta(expr, conf, p, r)),
expr[GetArrayItem](
"Gets the field at `ordinal` in the Array",
ExprChecks.binaryProject(
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index 5a507e7dab7..5243bcff6c6 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -1600,6 +1600,15 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern")
.booleanConf
.createWithDefault(true)
+ val ENABLE_PROTOBUF_BATCH_MERGE_AFTER_PROJECT =
+ conf("spark.rapids.sql.protobuf.batchMergeAfterProject.enabled")
+ .doc("When set to true, allows a GPU Project containing a schema-pruned " +
+ "`from_protobuf` decode to request a post-project batch coalesce. This is intended " +
+ "to reduce tiny batches produced after protobuf schema projection.")
+ .internal()
+ .booleanConf
+ .createWithDefault(false)
+
val ENABLE_ORC_FLOAT_TYPES_TO_STRING =
conf("spark.rapids.sql.format.orc.floatTypesToString.enable")
.doc("When reading an ORC file, the source data schemas(schemas of ORC file) may differ " +
@@ -3539,6 +3548,9 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
lazy val isCoalesceAfterExpandEnabled: Boolean = get(ENABLE_COALESCE_AFTER_EXPAND)
+ lazy val isProtobufBatchMergeAfterProjectEnabled: Boolean =
+ get(ENABLE_PROTOBUF_BATCH_MERGE_AFTER_PROJECT)
+
lazy val multiThreadReadNumThreads: Int = {
// Use the largest value set among all the options.
val deprecatedConfs = Seq(
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala
index a97f830fe3e..5db9f27ad0b 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala
@@ -47,6 +47,49 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.random.BernoulliCellSampler
+object GpuProjectExecMeta {
+ private def isProtobufDecodeExpr(expr: Expression): Boolean =
+ expr.getClass.getName.endsWith("ProtobufDataToCatalyst")
+
+ private def protobufAliasExprIds(plan: SparkPlan): Set[ExprId] = plan match {
+ case ProjectExec(projectList, _) =>
+ projectList.collect {
+ case alias: Alias if isProtobufDecodeExpr(alias.child) => alias.exprId
+ }.toSet
+ case _ => Set.empty
+ }
+
+ private def isRootedAtProtobufDecode(expr: Expression, protobufExprIds: Set[ExprId]): Boolean =
+ expr match {
+ case attr: AttributeReference =>
+ protobufExprIds.contains(attr.exprId)
+ case decode if isProtobufDecodeExpr(decode) =>
+ true
+ case GetStructField(child, _, _) =>
+ isRootedAtProtobufDecode(child, protobufExprIds)
+ case gasf: GetArrayStructFields =>
+ isRootedAtProtobufDecode(gasf.child, protobufExprIds)
+ case _ =>
+ false
+ }
+
+ private def extractsFromProtobufAlias(
+ expr: Expression,
+ protobufExprIds: Set[ExprId]): Boolean = expr match {
+ case GetStructField(child, _, _) =>
+ isRootedAtProtobufDecode(child, protobufExprIds)
+ case gasf: GetArrayStructFields =>
+ isRootedAtProtobufDecode(gasf.child, protobufExprIds)
+ case other =>
+ other.children.exists(child => extractsFromProtobufAlias(child, protobufExprIds))
+ }
+
+ private[rapids] def shouldCoalesceAfterProject(plan: ProjectExec): Boolean = {
+ val protobufExprIds = protobufAliasExprIds(plan.child)
+ plan.projectList.exists(expr => extractsFromProtobufAlias(expr, protobufExprIds))
+ }
+}
+
class GpuProjectExecMeta(
proj: ProjectExec,
conf: RapidsConf,
@@ -57,11 +100,13 @@ class GpuProjectExecMeta(
// Force list to avoid recursive Java serialization of lazy list Seq implementation
val gpuExprs = childExprs.map(_.convertToGpu().asInstanceOf[NamedExpression]).toList
val gpuChild = childPlans.head.convertIfNeeded()
+ val forcePostProjectCoalesce = conf.isProtobufBatchMergeAfterProjectEnabled &&
+ GpuProjectExecMeta.shouldCoalesceAfterProject(proj)
if (conf.isProjectAstEnabled) {
// cuDF requires return column is fixed width
val allReturnTypesFixedWidth = gpuExprs.forall(e => GpuBatchUtils.isFixedWidth(e.dataType))
if (allReturnTypesFixedWidth && childExprs.forall(_.canThisBeAst)) {
- return GpuProjectAstExec(gpuExprs, gpuChild)
+ return GpuProjectAstExec(gpuExprs, gpuChild, forcePostProjectCoalesce)
}
// explain AST because this is optional and it is sometimes hard to debug
if (conf.shouldExplain) {
@@ -76,7 +121,7 @@ class GpuProjectExecMeta(
}
}
}
- GpuProjectExec(gpuExprs, gpuChild)
+ GpuProjectExec(gpuExprs, gpuChild, forcePostProjectCoalesce = forcePostProjectCoalesce)
}
}
@@ -290,6 +335,7 @@ object GpuProjectExecLike {
trait GpuProjectExecLike extends GpuPartitioningPreservingUnaryExecNode with GpuExec {
def projectList: Seq[Expression]
+ def forcePostProjectCoalesce: Boolean
override lazy val additionalMetrics: Map[String, GpuMetric] = Map(
OP_TIME_LEGACY -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_OP_TIME_LEGACY))
@@ -299,8 +345,16 @@ trait GpuProjectExecLike extends GpuPartitioningPreservingUnaryExecNode with Gpu
override def doExecute(): RDD[InternalRow] =
throw new IllegalStateException(s"Row-based execution should not occur for $this")
- // The same as what feeds us
- override def outputBatching: CoalesceGoal = GpuExec.outputBatching(child)
+ override def coalesceAfter: Boolean = forcePostProjectCoalesce
+
+  // Flagged protobuf projects intentionally report no output batching guarantee (a null
+  // goal) so the post-project coalesce inserted by transition rules is not optimized away.
+ override def outputBatching: CoalesceGoal =
+ if (forcePostProjectCoalesce) {
+ null
+ } else {
+ GpuExec.outputBatching(child)
+ }
}
/**
@@ -763,14 +817,19 @@ case class GpuProjectExec(
// immutable/List.scala#L516
projectList: List[NamedExpression],
child: SparkPlan,
- enablePreSplit: Boolean = true) extends GpuProjectExecLike {
+ enablePreSplit: Boolean = true,
+ forcePostProjectCoalesce: Boolean = false) extends GpuProjectExecLike {
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
override def outputBatching: CoalesceGoal = if (enablePreSplit) {
// Pre-split will make sure the size of each output batch will not be larger
// than the splitUntilSize.
- TargetSize(PreProjectSplitIterator.getSplitUntilSize)
+ if (forcePostProjectCoalesce) {
+ super.outputBatching
+ } else {
+ TargetSize(PreProjectSplitIterator.getSplitUntilSize)
+ }
} else {
super.outputBatching
}
@@ -822,7 +881,8 @@ case class GpuProjectAstExec(
// serde: https://github.com/scala/scala/blob/2.12.x/src/library/scala/collection/
// immutable/List.scala#L516
projectList: List[Expression],
- child: SparkPlan
+ child: SparkPlan,
+ forcePostProjectCoalesce: Boolean = false
) extends GpuProjectExecLike {
override def output: Seq[Attribute] = {
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala
new file mode 100644
index 00000000000..3cf6a290041
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala
@@ -0,0 +1,203 @@
+/*
+ * Copyright (c) 2025-2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids
+
+import java.util.Arrays
+
+import ai.rapids.cudf
+import ai.rapids.cudf.{BinaryOp, CudfException, DType}
+import com.nvidia.spark.rapids.{GpuColumnVector, GpuUnaryExpression}
+import com.nvidia.spark.rapids.Arm.withResource
+import com.nvidia.spark.rapids.jni.{Protobuf, ProtobufSchemaDescriptor}
+import com.nvidia.spark.rapids.shims.NullIntolerantShim
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
+import org.apache.spark.sql.types._
+
+/**
+ * GPU implementation for Spark's `from_protobuf` decode path.
+ *
+ * This is designed to replace `org.apache.spark.sql.protobuf.ProtobufDataToCatalyst` when
+ * supported.
+ *
+ * The implementation uses a flattened schema representation where nested fields have parent
+ * indices pointing to their containing message field. For pure scalar schemas, all fields
+ * are top-level (parentIndices == -1, depthLevels == 0, isRepeated == false).
+ *
+ * Schema projection is supported: `decodedSchema` contains only the top-level fields and
+ * nested children that are actually referenced by downstream operators. Downstream
+ * `GetStructField` and `GetArrayStructFields` nodes have their ordinals rewritten via
+ * `PRUNED_ORDINAL_TAG` to index into the pruned schema. Unreferenced fields are never
+ * accessed, so no null-column filling is needed.
+ *
+ * @param decodedSchema The pruned schema containing only the fields decoded by the GPU.
+ * Only fields referenced by downstream operators are included;
+ * ordinal remapping ensures correct field access into the pruned output.
+ * @param fieldNumbers Protobuf field numbers for all fields in flattened schema
+ * @param parentIndices Parent indices for all fields (-1 for top-level)
+ * @param depthLevels Nesting depth for all fields (0 for top-level)
+ * @param wireTypes Wire types for all fields
+ * @param outputTypeIds cuDF type IDs for all fields
+ * @param encodings Encodings for all fields
+ * @param isRepeated Whether each field is repeated
+ * @param isRequired Whether each field is required
+ * @param hasDefaultValue Whether each field has a default value
+ * @param defaultInts Default int/long values
+ * @param defaultFloats Default float/double values
+ * @param defaultBools Default bool values
+ * @param defaultStrings Default string/bytes values
+ * @param enumValidValues Valid enum values for each field
+ * @param enumNames Enum value names for enum-as-string fields. Parallel to enumValidValues.
+ * @param failOnErrors If true, throw exception on malformed data
+ */
+case class GpuFromProtobuf(
+ decodedSchema: StructType,
+ fieldNumbers: Array[Int],
+ parentIndices: Array[Int],
+ depthLevels: Array[Int],
+ wireTypes: Array[Int],
+ outputTypeIds: Array[Int],
+ encodings: Array[Int],
+ isRepeated: Array[Boolean],
+ isRequired: Array[Boolean],
+ hasDefaultValue: Array[Boolean],
+ defaultInts: Array[Long],
+ defaultFloats: Array[Double],
+ defaultBools: Array[Boolean],
+ defaultStrings: Array[Array[Byte]],
+ enumValidValues: Array[Array[Int]],
+ enumNames: Array[Array[Array[Byte]]],
+ failOnErrors: Boolean,
+ child: Expression)
+ extends GpuUnaryExpression with ExpectsInputTypes with NullIntolerantShim with Logging {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)
+
+ override def dataType: DataType = decodedSchema
+
+ override def nullable: Boolean = true
+
+ override def equals(other: Any): Boolean = other match {
+ case that: GpuFromProtobuf =>
+ decodedSchema == that.decodedSchema &&
+ Arrays.equals(fieldNumbers, that.fieldNumbers) &&
+ Arrays.equals(parentIndices, that.parentIndices) &&
+ Arrays.equals(depthLevels, that.depthLevels) &&
+ Arrays.equals(wireTypes, that.wireTypes) &&
+ Arrays.equals(outputTypeIds, that.outputTypeIds) &&
+ Arrays.equals(encodings, that.encodings) &&
+ Arrays.equals(isRepeated, that.isRepeated) &&
+ Arrays.equals(isRequired, that.isRequired) &&
+ Arrays.equals(hasDefaultValue, that.hasDefaultValue) &&
+ Arrays.equals(defaultInts, that.defaultInts) &&
+ Arrays.equals(defaultFloats, that.defaultFloats) &&
+ Arrays.equals(defaultBools, that.defaultBools) &&
+ GpuFromProtobuf.deepEquals(defaultStrings, that.defaultStrings) &&
+ GpuFromProtobuf.deepEquals(enumValidValues, that.enumValidValues) &&
+ GpuFromProtobuf.deepEquals(enumNames, that.enumNames) &&
+ failOnErrors == that.failOnErrors &&
+ child == that.child
+ case _ => false
+ }
+
+ override def hashCode(): Int = {
+ var result = decodedSchema.hashCode()
+ result = 31 * result + Arrays.hashCode(fieldNumbers)
+ result = 31 * result + Arrays.hashCode(parentIndices)
+ result = 31 * result + Arrays.hashCode(depthLevels)
+ result = 31 * result + Arrays.hashCode(wireTypes)
+ result = 31 * result + Arrays.hashCode(outputTypeIds)
+ result = 31 * result + Arrays.hashCode(encodings)
+ result = 31 * result + Arrays.hashCode(isRepeated)
+ result = 31 * result + Arrays.hashCode(isRequired)
+ result = 31 * result + Arrays.hashCode(hasDefaultValue)
+ result = 31 * result + Arrays.hashCode(defaultInts)
+ result = 31 * result + Arrays.hashCode(defaultFloats)
+ result = 31 * result + Arrays.hashCode(defaultBools)
+ result = 31 * result + GpuFromProtobuf.deepHashCode(defaultStrings)
+ result = 31 * result + GpuFromProtobuf.deepHashCode(enumValidValues)
+ result = 31 * result + GpuFromProtobuf.deepHashCode(enumNames)
+ result = 31 * result + failOnErrors.hashCode()
+ result = 31 * result + child.hashCode()
+ result
+ }
+
+ // ProtobufSchemaDescriptor is a pure-Java immutable holder for validated schema arrays.
+ // It does not own native resources, so task-scoped close hooks are not required here.
+ @transient private lazy val protobufSchema = new ProtobufSchemaDescriptor(
+ fieldNumbers, parentIndices, depthLevels, wireTypes, outputTypeIds, encodings,
+ isRepeated, isRequired, hasDefaultValue, defaultInts, defaultFloats, defaultBools,
+ defaultStrings, enumValidValues, enumNames)
+
+ override protected def doColumnar(input: GpuColumnVector): cudf.ColumnVector = {
+ val jniResult = try {
+ Protobuf.decodeToStruct(input.getBase, protobufSchema, failOnErrors)
+ } catch {
+ case e: CudfException if failOnErrors =>
+ throw new org.apache.spark.SparkException("Malformed protobuf message", e)
+ case e: CudfException =>
+ logWarning(s"Unexpected CudfException in PERMISSIVE mode: ${e.getMessage}", e)
+ throw e
+ }
+
+ // Apply input nulls to output
+ if (input.getBase.hasNulls) {
+ withResource(jniResult) { _ =>
+ jniResult.mergeAndSetValidity(BinaryOp.BITWISE_AND, input.getBase)
+ }
+ } else {
+ jniResult
+ }
+ }
+}
+
+object GpuFromProtobuf {
+ val ENC_DEFAULT = 0
+ val ENC_FIXED = 1
+ val ENC_ZIGZAG = 2
+ val ENC_ENUM_STRING = 3
+
+ /**
+ * Maps a Spark DataType to the corresponding cuDF native type ID.
+ * Note: The encoding (varint/zigzag/fixed) is determined by the protobuf field type,
+ * not the Spark data type, so it must be set separately based on the protobuf schema.
+ *
+   * @return Some(typeId) for supported types (note BinaryType maps to LIST, i.e. list of
+   *         bytes in cuDF), None for unsupported types
+ */
+ def sparkTypeToCudfIdOpt(dt: DataType): Option[Int] = dt match {
+ case BooleanType => Some(DType.BOOL8.getTypeId.getNativeId)
+ case IntegerType => Some(DType.INT32.getTypeId.getNativeId)
+ case LongType => Some(DType.INT64.getTypeId.getNativeId)
+ case FloatType => Some(DType.FLOAT32.getTypeId.getNativeId)
+ case DoubleType => Some(DType.FLOAT64.getTypeId.getNativeId)
+ case StringType => Some(DType.STRING.getTypeId.getNativeId)
+ case BinaryType => Some(DType.LIST.getTypeId.getNativeId)
+ case _ => None
+ }
+
+ /**
+ * Check if a Spark DataType is supported by the GPU protobuf decoder.
+ */
+ def isTypeSupported(dt: DataType): Boolean = sparkTypeToCudfIdOpt(dt).isDefined
+
+ private def deepEquals[T](left: Array[T], right: Array[T]): Boolean =
+ Arrays.deepEquals(left.asInstanceOf[Array[Object]], right.asInstanceOf[Array[Object]])
+
+ private def deepHashCode[T](arr: Array[T]): Int =
+ Arrays.deepHashCode(arr.asInstanceOf[Array[Object]])
+}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala
index 9afb3b854d6..988c03b00b1 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala
@@ -32,7 +32,10 @@ import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, IntegralType, LongType, MapType, StructField, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
-case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
+case class GpuGetStructField(
+ child: Expression,
+ ordinal: Int,
+ name: Option[String] = None)
extends ShimUnaryExpression
with GpuExpression
with ShimGetStructField
@@ -41,15 +44,23 @@ case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[Strin
lazy val childSchema: StructType = child.dataType.asInstanceOf[StructType]
override def dataType: DataType = childSchema(ordinal).dataType
- override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable
+
+ override def nullable: Boolean =
+ child.nullable || childSchema(ordinal).nullable
override def toString: String = {
- val fieldName = if (resolved) childSchema(ordinal).name else s"_$ordinal"
+ val fieldName = if (resolved) {
+ childSchema(ordinal).name
+ } else {
+ s"_$ordinal"
+ }
s"$child.${name.getOrElse(fieldName)}"
}
- override def sql: String =
- child.sql + s".${quoteIdentifier(name.getOrElse(childSchema(ordinal).name))}"
+ override def sql: String = {
+ val fieldName = childSchema(ordinal).name
+ child.sql + s".${quoteIdentifier(name.getOrElse(fieldName))}"
+ }
override def columnarEvalAny(batch: ColumnarBatch): Any = {
val dt = dataType
@@ -59,7 +70,6 @@ case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[Strin
GpuColumnVector.from(view.copyToColumnVector(), dt)
}
case s: GpuScalar =>
- // For a scalar in we want a scalar out.
if (!s.isValid) {
GpuScalar(null, dt)
} else {
@@ -402,6 +412,27 @@ case class GpuArrayPosition(left: Expression, right: Expression)
}
}
+object GpuStructFieldOrdinalTag {
+ val PRUNED_ORDINAL_TAG =
+ new org.apache.spark.sql.catalyst.trees.TreeNodeTag[Int]("GPU_PRUNED_ORDINAL")
+}
+
+class GpuGetStructFieldMeta(
+ expr: GetStructField,
+ conf: RapidsConf,
+ parent: Option[RapidsMeta[_, _, _]],
+ rule: DataFromReplacementRule)
+ extends UnaryExprMeta[GetStructField](expr, conf, parent, rule) {
+
+ def convertToGpu(child: Expression): GpuExpression = {
+ val runtimeOrd = expr.getTagValue(
+ GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG).getOrElse(-1)
+ val effectiveOrd =
+ if (runtimeOrd >= 0) runtimeOrd else expr.ordinal
+ GpuGetStructField(child, effectiveOrd, expr.name)
+ }
+}
+
class GpuGetArrayStructFieldsMeta(
expr: GetArrayStructFields,
conf: RapidsConf,
@@ -409,8 +440,32 @@ class GpuGetArrayStructFieldsMeta(
rule: DataFromReplacementRule)
extends UnaryExprMeta[GetArrayStructFields](expr, conf, parent, rule) {
- def convertToGpu(child: Expression): GpuExpression =
- GpuGetArrayStructFields(child, expr.field, expr.ordinal, expr.numFields, expr.containsNull)
+ def convertToGpu(child: Expression): GpuExpression = {
+ val runtimeOrd = expr.getTagValue(
+ GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG).getOrElse(-1)
+ val effectiveOrd =
+ if (runtimeOrd >= 0) runtimeOrd else expr.ordinal
+ val effectiveNumFields =
+ GpuGetArrayStructFieldsMeta.effectiveNumFields(child, expr, runtimeOrd)
+ GpuGetArrayStructFields(child, expr.field,
+ effectiveOrd, effectiveNumFields, expr.containsNull)
+ }
+}
+
+object GpuGetArrayStructFieldsMeta {
+ def effectiveNumFields(
+ child: Expression,
+ expr: GetArrayStructFields,
+ runtimeOrd: Int): Int = {
+ if (runtimeOrd >= 0) {
+ child.dataType match {
+ case ArrayType(st: StructType, _) => st.fields.length
+ case _ => expr.numFields
+ }
+ } else {
+ expr.numFields
+ }
+ }
}
/**
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaExtractor.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaExtractor.scala
new file mode 100644
index 00000000000..a6617e1c89a
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaExtractor.scala
@@ -0,0 +1,215 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids.protobuf
+
+import scala.collection.mutable
+
+import com.nvidia.spark.rapids.jni.Protobuf.{WT_32BIT, WT_64BIT, WT_LEN, WT_VARINT}
+
+import org.apache.spark.sql.rapids.GpuFromProtobuf
+import org.apache.spark.sql.types._
+
+object ProtobufSchemaExtractor {
+ def analyzeAllFields(
+ schema: StructType,
+ msgDesc: ProtobufMessageDescriptor,
+ enumsAsInts: Boolean,
+ messageName: String): Either[String, Map[String, ProtobufFieldInfo]] = {
+ val result = mutable.Map[String, ProtobufFieldInfo]()
+
+ schema.fields.foreach { sf =>
+ val fieldInfo = msgDesc.findField(sf.name) match {
+ case None =>
+ unsupportedFieldInfo(
+ sf,
+ None,
+ s"Protobuf field '${sf.name}' not found in message '$messageName'")
+ case Some(fd) =>
+ extractFieldInfo(sf, fd, enumsAsInts) match {
+ case Right(info) =>
+ info
+ case Left(reason) =>
+ unsupportedFieldInfo(sf, Some(fd), reason)
+ }
+ }
+ result(sf.name) = fieldInfo
+ }
+
+ Right(result.toMap)
+ }
+
+ def extractFieldInfo(
+ sparkField: StructField,
+ fieldDescriptor: ProtobufFieldDescriptor,
+ enumsAsInts: Boolean): Either[String, ProtobufFieldInfo] = {
+ val (isSupported, unsupportedReason, encoding) =
+ checkFieldSupport(
+ sparkField.dataType,
+ fieldDescriptor.protoTypeName,
+ fieldDescriptor.isRepeated,
+ enumsAsInts)
+
+ fieldDescriptor.defaultValueResult.map { defaultValue =>
+ ProtobufFieldInfo(
+ fieldNumber = fieldDescriptor.fieldNumber,
+ protoTypeName = fieldDescriptor.protoTypeName,
+ sparkType = sparkField.dataType,
+ encoding = encoding,
+ isSupported = isSupported,
+ unsupportedReason = unsupportedReason,
+ isRequired = fieldDescriptor.isRequired,
+ defaultValue = defaultValue,
+ enumMetadata = fieldDescriptor.enumMetadata,
+ isRepeated = fieldDescriptor.isRepeated
+ )
+ }
+ }
+
+ private def unsupportedFieldInfo(
+ sparkField: StructField,
+ fieldDescriptor: Option[ProtobufFieldDescriptor],
+ reason: String): ProtobufFieldInfo = {
+ ProtobufFieldInfo(
+ fieldNumber = fieldDescriptor.map(_.fieldNumber).getOrElse(-1),
+ protoTypeName = fieldDescriptor.map(_.protoTypeName).getOrElse("UNKNOWN"),
+ sparkType = sparkField.dataType,
+ encoding = GpuFromProtobuf.ENC_DEFAULT,
+ isSupported = false,
+ unsupportedReason = Some(reason),
+ isRequired = fieldDescriptor.exists(_.isRequired),
+ defaultValue = None,
+ enumMetadata = fieldDescriptor.flatMap(_.enumMetadata),
+ isRepeated = fieldDescriptor.exists(_.isRepeated)
+ )
+ }
+
+ def checkFieldSupport(
+ sparkType: DataType,
+ protoTypeName: String,
+ isRepeated: Boolean,
+ enumsAsInts: Boolean): (Boolean, Option[String], Int) = {
+
+ if (isRepeated) {
+ sparkType match {
+ case ArrayType(elementType, _) =>
+ elementType match {
+ case BooleanType | IntegerType | LongType | FloatType | DoubleType |
+ StringType | BinaryType =>
+ return checkScalarEncoding(elementType, protoTypeName, enumsAsInts)
+ case _: StructType =>
+ return (true, None, GpuFromProtobuf.ENC_DEFAULT)
+ case _ =>
+ return (
+ false,
+ Some(s"unsupported repeated element type: $elementType"),
+ GpuFromProtobuf.ENC_DEFAULT)
+ }
+ case _ =>
+ return (
+ false,
+ Some(s"repeated field should map to ArrayType, got: $sparkType"),
+ GpuFromProtobuf.ENC_DEFAULT)
+ }
+ }
+
+ if (protoTypeName == "MESSAGE") {
+ sparkType match {
+ case _: StructType =>
+ return (true, None, GpuFromProtobuf.ENC_DEFAULT)
+ case _ =>
+ return (
+ false,
+ Some(s"nested message should map to StructType, got: $sparkType"),
+ GpuFromProtobuf.ENC_DEFAULT)
+ }
+ }
+
+ sparkType match {
+ case BooleanType | IntegerType | LongType | FloatType | DoubleType |
+ StringType | BinaryType =>
+ case other =>
+ return (
+ false,
+ Some(s"unsupported Spark type: $other"),
+ GpuFromProtobuf.ENC_DEFAULT)
+ }
+
+ checkScalarEncoding(sparkType, protoTypeName, enumsAsInts)
+ }
+
+ def checkScalarEncoding(
+ sparkType: DataType,
+ protoTypeName: String,
+ enumsAsInts: Boolean): (Boolean, Option[String], Int) = {
+ val encoding = (sparkType, protoTypeName) match {
+ case (BooleanType, "BOOL") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (IntegerType, "INT32" | "UINT32") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (IntegerType, "SINT32") => Some(GpuFromProtobuf.ENC_ZIGZAG)
+ case (IntegerType, "FIXED32" | "SFIXED32") => Some(GpuFromProtobuf.ENC_FIXED)
+ case (LongType, "INT64" | "UINT64") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (LongType, "SINT64") => Some(GpuFromProtobuf.ENC_ZIGZAG)
+ case (LongType, "FIXED64" | "SFIXED64") => Some(GpuFromProtobuf.ENC_FIXED)
+ case (LongType, "INT32" | "UINT32" | "SINT32" | "FIXED32" | "SFIXED32") =>
+ val enc = protoTypeName match {
+ case "SINT32" => GpuFromProtobuf.ENC_ZIGZAG
+ case "FIXED32" | "SFIXED32" => GpuFromProtobuf.ENC_FIXED
+ case _ => GpuFromProtobuf.ENC_DEFAULT
+ }
+ Some(enc)
+ case (FloatType, "FLOAT") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (DoubleType, "DOUBLE") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (StringType, "STRING") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (BinaryType, "BYTES") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (IntegerType, "ENUM") if enumsAsInts => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (StringType, "ENUM") if !enumsAsInts => Some(GpuFromProtobuf.ENC_ENUM_STRING)
+ case _ => None
+ }
+
+ encoding match {
+ case Some(enc) => (true, None, enc)
+ case None =>
+ (false,
+ Some(s"type mismatch: Spark $sparkType vs Protobuf $protoTypeName"),
+ GpuFromProtobuf.ENC_DEFAULT)
+ }
+ }
+
+ def getWireType(protoTypeName: String, encoding: Int): Either[String, Int] = {
+ val wireType = protoTypeName match {
+ case "BOOL" | "INT32" | "UINT32" | "SINT32" | "INT64" | "UINT64" | "SINT64" | "ENUM" =>
+ if (encoding == GpuFromProtobuf.ENC_FIXED) {
+ if (protoTypeName.contains("64")) {
+ WT_64BIT
+ } else {
+ WT_32BIT
+ }
+ } else {
+ WT_VARINT
+ }
+ case "FIXED32" | "SFIXED32" | "FLOAT" =>
+ WT_32BIT
+ case "FIXED64" | "SFIXED64" | "DOUBLE" =>
+ WT_64BIT
+ case "STRING" | "BYTES" | "MESSAGE" =>
+ WT_LEN
+ case other =>
+ return Left(
+ s"Unknown protobuf type name '$other' - cannot determine wire type; falling back to CPU")
+ }
+ Right(wireType)
+ }
+}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaModel.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaModel.scala
new file mode 100644
index 00000000000..77c1cbdf8f3
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaModel.scala
@@ -0,0 +1,135 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids.protobuf
+
+import java.util.Arrays
+
+import org.apache.spark.sql.types.DataType
+
/** Where the protobuf descriptor for an expression comes from. */
sealed trait ProtobufDescriptorSource

object ProtobufDescriptorSource {
  /** Descriptor referenced by a file path. */
  final case class DescriptorPath(path: String) extends ProtobufDescriptorSource
  /**
   * Descriptor supplied inline as serialized bytes.
   *
   * equals/hashCode are overridden because the case-class generated versions
   * would compare the Array[Byte] by reference rather than by content.
   */
  final case class DescriptorBytes(bytes: Array[Byte]) extends ProtobufDescriptorSource {
    override def equals(other: Any): Boolean = other match {
      case DescriptorBytes(otherBytes) => Arrays.equals(bytes, otherBytes)
      case _ => false
    }

    override def hashCode(): Int = Arrays.hashCode(bytes)
  }
}
+
/**
 * Information extracted from a from_protobuf expression.
 *
 * @param messageName protobuf message name to decode
 * @param descriptorSource where the descriptor comes from (file path or inline bytes)
 * @param options the raw options map passed to the expression
 */
final case class ProtobufExprInfo(
    messageName: String,
    descriptorSource: ProtobufDescriptorSource,
    options: Map[String, String])
+
/**
 * Parsed options that influence GPU planning for from_protobuf.
 *
 * @param enumsAsInts when true, enum fields decode to their integer numbers
 *                    instead of strings
 * @param failOnErrors NOTE(review): presumably controls whether malformed input
 *                     fails the query instead of producing nulls — confirm
 *                     against the option parser
 */
final case class ProtobufPlannerOptions(
    enumsAsInts: Boolean,
    failOnErrors: Boolean)
+
/** A protobuf field default value as declared in the descriptor. */
sealed trait ProtobufDefaultValue

object ProtobufDefaultValue {
  final case class BoolValue(value: Boolean) extends ProtobufDefaultValue
  /** Integral default; all protobuf integer widths are widened to Long. */
  final case class IntValue(value: Long) extends ProtobufDefaultValue
  final case class FloatValue(value: Float) extends ProtobufDefaultValue
  final case class DoubleValue(value: Double) extends ProtobufDefaultValue
  final case class StringValue(value: String) extends ProtobufDefaultValue
  /**
   * Binary default value.
   *
   * equals/hashCode compare the byte content: the case-class generated versions
   * would use Array reference equality, unlike the content-based comparison
   * ProtobufDescriptorSource.DescriptorBytes already uses in this file.
   */
  final case class BinaryValue(value: Array[Byte]) extends ProtobufDefaultValue {
    override def equals(other: Any): Boolean = other match {
      case BinaryValue(otherBytes) => Arrays.equals(value, otherBytes)
      case _ => false
    }

    override def hashCode(): Int = Arrays.hashCode(value)
  }
  /** Enum default, carrying both the wire number and the declared name. */
  final case class EnumValue(number: Int, name: String) extends ProtobufDefaultValue
}
+
/** One declared enum value: its wire number and its protobuf name. */
final case class ProtobufEnumValue(number: Int, name: String)
+
/** Metadata for an enum type: the declared values in declaration order. */
final case class ProtobufEnumMetadata(values: Seq[ProtobufEnumValue]) {
  // Declared enum numbers, in declaration order.
  lazy val validValues: Array[Int] = values.iterator.map(_.number).toArray
  // UTF-8 bytes of the declared names, parallel to validValues.
  lazy val orderedNames: Array[Array[Byte]] =
    values.iterator.map(_.name.getBytes("UTF-8")).toArray
  // Number-to-name lookup used when resolving defaults.
  lazy val namesByNumber: Map[Int, String] =
    values.iterator.map(v => (v.number, v.name)).toMap

  /**
   * Build the default value for this enum at the given number, falling back to
   * the numeric string when the number has no declared name.
   */
  def enumDefault(number: Int): ProtobufDefaultValue.EnumValue =
    ProtobufDefaultValue.EnumValue(number, namesByNumber.getOrElse(number, number.toString))
}
+
/** Minimal view of a protobuf message descriptor needed by the GPU planner. */
trait ProtobufMessageDescriptor {
  // Descriptor syntax string, e.g. "proto2"; non-proto2 syntaxes are rejected by the planner.
  def syntax: String
  // Look up a direct child field by its protobuf name.
  def findField(name: String): Option[ProtobufFieldDescriptor]
}
+
/** Minimal view of a protobuf field descriptor needed by the GPU planner. */
trait ProtobufFieldDescriptor {
  def name: String
  // Field number as used on the wire.
  def fieldNumber: Int
  // Upper-case protobuf type name, e.g. "INT32", "STRING", "MESSAGE".
  def protoTypeName: String
  def isRepeated: Boolean
  def isRequired: Boolean
  // Declared default value if any, or Left(reason) when it cannot be represented.
  def defaultValueResult: Either[String, Option[ProtobufDefaultValue]]
  // Expected to be present exactly for ENUM-typed fields.
  def enumMetadata: Option[ProtobufEnumMetadata]
  // Expected to be present for MESSAGE-typed fields.
  def messageDescriptor: Option[ProtobufMessageDescriptor]
}
+
/**
 * Analyzed metadata for one protobuf field paired with its Spark output type.
 *
 * @param fieldNumber protobuf field number on the wire
 * @param protoTypeName upper-case protobuf type name, e.g. "INT32"
 * @param sparkType Spark SQL type the field decodes to
 * @param encoding one of the GpuFromProtobuf.ENC_* constants
 * @param isSupported whether the GPU decoder can handle this field
 * @param unsupportedReason human-readable reason when isSupported is false
 * @param isRequired whether the field is marked required in the descriptor
 * @param defaultValue declared default value, if any
 * @param enumMetadata enum values, present for enum fields
 * @param isRepeated whether the field is repeated
 */
final case class ProtobufFieldInfo(
    fieldNumber: Int,
    protoTypeName: String,
    sparkType: DataType,
    encoding: Int,
    isSupported: Boolean,
    unsupportedReason: Option[String],
    isRequired: Boolean,
    defaultValue: Option[ProtobufDefaultValue],
    enumMetadata: Option[ProtobufEnumMetadata],
    isRepeated: Boolean = false) {
  // True when the descriptor declared an explicit default for this field.
  def hasDefaultValue: Boolean = defaultValue.isDefined
}
+
/**
 * One entry of the flattened (pre-order) schema handed to the GPU decoder.
 * Nesting is expressed via parentIdx/depth rather than recursion: a top-level
 * field has parentIdx == -1 and depth == 0; a child points at the index of its
 * parent entry and has depth == parent depth + 1.
 */
final case class FlattenedFieldDescriptor(
    fieldNumber: Int,
    parentIdx: Int,
    depth: Int,
    wireType: Int,
    // Native cudf type id for the decoded output column.
    outputTypeId: Int,
    encoding: Int,
    isRepeated: Boolean,
    isRequired: Boolean,
    hasDefaultValue: Boolean,
    // Primitive default slots; only the one matching the field type is meaningful.
    defaultInt: Long,
    defaultFloat: Double,
    defaultBool: Boolean,
    defaultString: Array[Byte],
    // Enum metadata; null for non-enum fields (and enumNames also null unless
    // decoding enums to strings).
    enumValidValues: Array[Int],
    enumNames: Array[Array[Byte]]
)
+
/**
 * Structure-of-arrays form of a flattened schema, one parallel array per
 * FlattenedFieldDescriptor attribute, as produced by
 * ProtobufSchemaValidator.toFlattenedSchemaArrays for the JNI boundary.
 */
final case class FlattenedSchemaArrays(
    fieldNumbers: Array[Int],
    parentIndices: Array[Int],
    depthLevels: Array[Int],
    wireTypes: Array[Int],
    outputTypeIds: Array[Int],
    encodings: Array[Int],
    isRepeated: Array[Boolean],
    isRequired: Array[Boolean],
    hasDefaultValue: Array[Boolean],
    defaultInts: Array[Long],
    defaultFloats: Array[Double],
    defaultBools: Array[Boolean],
    defaultStrings: Array[Array[Byte]],
    enumValidValues: Array[Array[Int]],
    enumNames: Array[Array[Array[Byte]]]
)
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaValidator.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaValidator.scala
new file mode 100644
index 00000000000..627fed83296
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaValidator.scala
@@ -0,0 +1,177 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids.protobuf
+
+import org.apache.spark.sql.rapids.GpuFromProtobuf
+import org.apache.spark.sql.types._
+
/**
 * Validates analyzed protobuf field metadata and converts it into the
 * flattened descriptor / parallel-array forms consumed by the GPU decoder.
 */
object ProtobufSchemaValidator {
  /** Default values in the primitive slot form handed across the JNI boundary. */
  private final case class JniDefaultValues(
      defaultInt: Long,
      defaultFloat: Double,
      defaultBool: Boolean,
      defaultString: Array[Byte])

  /**
   * Validate one Spark/protobuf field pairing and produce its flattened descriptor.
   *
   * @param path dot-separated path from the message root, for error messages
   * @param field the Spark schema field being decoded
   * @param fieldInfo analyzed protobuf metadata for the field
   * @param parentIdx index of the parent entry in the flattened schema, -1 for top level
   * @param depth nesting depth, 0 for top-level fields
   * @param outputTypeId native cudf type id for the decoded output column
   * @return Right(descriptor) or Left(reason) when the field cannot run on the GPU
   */
  def toFlattenedFieldDescriptor(
      path: String,
      field: StructField,
      fieldInfo: ProtobufFieldInfo,
      parentIdx: Int,
      depth: Int,
      outputTypeId: Int): Either[String, FlattenedFieldDescriptor] = {
    // Either is right-biased, so the first failing step short-circuits.
    for {
      _ <- validateFieldInfo(path, fieldInfo)
      wireType <- ProtobufSchemaExtractor.getWireType(fieldInfo.protoTypeName, fieldInfo.encoding)
      defaults <- encodeDefaultValue(path, field.dataType, fieldInfo)
    } yield {
      val enumValidValues = fieldInfo.enumMetadata.map(_.validValues).orNull
      // Enum names are only shipped when decoding enums to strings.
      val enumNames =
        if (fieldInfo.encoding == GpuFromProtobuf.ENC_ENUM_STRING) {
          fieldInfo.enumMetadata.map(_.orderedNames).orNull
        } else {
          null
        }

      FlattenedFieldDescriptor(
        fieldNumber = fieldInfo.fieldNumber,
        parentIdx = parentIdx,
        depth = depth,
        wireType = wireType,
        outputTypeId = outputTypeId,
        encoding = fieldInfo.encoding,
        isRepeated = fieldInfo.isRepeated,
        isRequired = fieldInfo.isRequired,
        hasDefaultValue = fieldInfo.hasDefaultValue,
        defaultInt = defaults.defaultInt,
        defaultFloat = defaults.defaultFloat,
        defaultBool = defaults.defaultBool,
        defaultString = defaults.defaultString,
        enumValidValues = enumValidValues,
        enumNames = enumNames
      )
    }
  }

  /**
   * Check structural invariants of a flattened schema: parents precede their
   * children, depths agree with parent indices, and enum-string fields carry
   * consistent enum metadata. Returns the first violation found.
   */
  def validateFlattenedSchema(flatFields: Seq[FlattenedFieldDescriptor]): Either[String, Unit] = {
    // collectFirst over a lazy iterator replaces the non-local `return`s of the
    // imperative version while preserving first-error semantics.
    flatFields.zipWithIndex.iterator
      .map { case (field, idx) => checkFlattenedField(field, idx) }
      .collectFirst { case Some(reason) => reason }
      .toLeft(())
  }

  /** First invariant violation for a single flattened entry, if any. */
  private def checkFlattenedField(
      field: FlattenedFieldDescriptor,
      idx: Int): Option[String] = {
    if (field.parentIdx >= idx) {
      Some(s"Flattened protobuf schema has invalid parent index at position $idx")
    } else if (field.parentIdx == -1 && field.depth != 0) {
      Some(s"Top-level protobuf field at position $idx must have depth 0")
    } else if (field.parentIdx >= 0 && field.depth <= 0) {
      Some(s"Nested protobuf field at position $idx must have positive depth")
    } else if (field.encoding == GpuFromProtobuf.ENC_ENUM_STRING) {
      if (field.enumValidValues == null || field.enumNames == null) {
        Some(s"Enum-string field at position $idx is missing enum metadata")
      } else if (field.enumValidValues.length != field.enumNames.length) {
        Some(s"Enum-string field at position $idx has mismatched enum metadata")
      } else {
        None
      }
    } else {
      None
    }
  }

  /** Convert flattened descriptors into the parallel-array form for the JNI layer. */
  def toFlattenedSchemaArrays(
      flatFields: Array[FlattenedFieldDescriptor]): FlattenedSchemaArrays = {
    FlattenedSchemaArrays(
      fieldNumbers = flatFields.map(_.fieldNumber),
      parentIndices = flatFields.map(_.parentIdx),
      depthLevels = flatFields.map(_.depth),
      wireTypes = flatFields.map(_.wireType),
      outputTypeIds = flatFields.map(_.outputTypeId),
      encodings = flatFields.map(_.encoding),
      isRepeated = flatFields.map(_.isRepeated),
      isRequired = flatFields.map(_.isRequired),
      hasDefaultValue = flatFields.map(_.hasDefaultValue),
      defaultInts = flatFields.map(_.defaultInt),
      defaultFloats = flatFields.map(_.defaultFloat),
      defaultBools = flatFields.map(_.defaultBool),
      defaultStrings = flatFields.map(_.defaultString),
      enumValidValues = flatFields.map(_.enumValidValues),
      enumNames = flatFields.map(_.enumNames)
    )
  }

  /**
   * Sanity checks on analyzed field metadata that apply regardless of wire
   * type. Checks run in priority order; the first failure wins.
   * (The unused StructField parameter of the previous version was dropped;
   * this method is private so no external callers are affected.)
   */
  private def validateFieldInfo(
      path: String,
      fieldInfo: ProtobufFieldInfo): Either[String, Unit] = {
    val issue = if (fieldInfo.isRepeated && fieldInfo.hasDefaultValue) {
      Some(s"Repeated protobuf field '$path' cannot carry a default value")
    } else {
      val enumIssue = fieldInfo.enumMetadata match {
        case Some(enumMeta) if enumMeta.values.isEmpty =>
          Some(s"Enum field '$path' is missing enum values")
        case Some(_) if fieldInfo.protoTypeName != "ENUM" =>
          Some(s"Non-enum field '$path' should not carry enum metadata")
        case None if fieldInfo.protoTypeName == "ENUM" =>
          Some(s"Enum field '$path' is missing enum metadata")
        case _ => None
      }
      enumIssue.orElse {
        if (fieldInfo.encoding == GpuFromProtobuf.ENC_ENUM_STRING &&
            fieldInfo.enumMetadata.isEmpty) {
          Some(s"Enum-string field '$path' is missing enum metadata")
        } else {
          None
        }
      }
    }
    issue.toLeft(())
  }

  /**
   * Encode the declared default value (if any) into the primitive JNI slots.
   * Array element types are unwrapped first so repeated scalars share the
   * scalar rules. A type/default mismatch yields Left.
   */
  private def encodeDefaultValue(
      path: String,
      dataType: DataType,
      fieldInfo: ProtobufFieldInfo): Either[String, JniDefaultValues] = {
    val empty = JniDefaultValues(0L, 0.0, defaultBool = false, null)
    fieldInfo.defaultValue match {
      case None => Right(empty)
      case Some(defaultValue) =>
        val targetType = dataType match {
          case ArrayType(elementType, _) => elementType
          case other => other
        }
        (targetType, defaultValue) match {
          case (BooleanType, ProtobufDefaultValue.BoolValue(value)) =>
            Right(empty.copy(defaultBool = value))
          case (IntegerType | LongType, ProtobufDefaultValue.IntValue(value)) =>
            Right(empty.copy(defaultInt = value))
          case (IntegerType | LongType, ProtobufDefaultValue.EnumValue(number, _)) =>
            Right(empty.copy(defaultInt = number.toLong))
          case (FloatType, ProtobufDefaultValue.FloatValue(value)) =>
            Right(empty.copy(defaultFloat = value.toDouble))
          case (DoubleType, ProtobufDefaultValue.DoubleValue(value)) =>
            Right(empty.copy(defaultFloat = value))
          case (StringType, ProtobufDefaultValue.StringValue(value)) =>
            Right(empty.copy(defaultString = value.getBytes("UTF-8")))
          case (StringType, ProtobufDefaultValue.EnumValue(number, name)) =>
            // String-typed enum defaults carry both the number and the UTF-8 name.
            Right(empty.copy(
              defaultInt = number.toLong,
              defaultString = name.getBytes("UTF-8")))
          case (BinaryType, ProtobufDefaultValue.BinaryValue(value)) =>
            Right(empty.copy(defaultString = value))
          case _ =>
            Left(s"Incompatible default value for protobuf field '$path': $defaultValue")
        }
    }
  }
}
diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala
new file mode 100644
index 00000000000..44f67ffff10
--- /dev/null
+++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala
@@ -0,0 +1,791 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "340"}
+{"spark": "341"}
+{"spark": "342"}
+{"spark": "343"}
+{"spark": "344"}
+{"spark": "350"}
+{"spark": "351"}
+{"spark": "352"}
+{"spark": "353"}
+{"spark": "354"}
+{"spark": "355"}
+{"spark": "356"}
+{"spark": "357"}
+{"spark": "400"}
+{"spark": "401"}
+{"spark": "402"}
+{"spark": "411"}
+spark-rapids-shim-json-lines ***/
+
+package com.nvidia.spark.rapids.shims
+
+import scala.collection.mutable
+
+import ai.rapids.cudf.DType
+import com.nvidia.spark.rapids._
+
+import org.apache.spark.sql.catalyst.expressions.{
+ AttributeReference, Expression, GetArrayStructFields, GetStructField, UnaryExpression
+}
+import org.apache.spark.sql.execution.ProjectExec
+import org.apache.spark.sql.rapids.GpuFromProtobuf
+import org.apache.spark.sql.rapids.protobuf.{
+ FlattenedFieldDescriptor,
+ ProtobufFieldInfo,
+ ProtobufMessageDescriptor,
+ ProtobufSchemaExtractor,
+ ProtobufSchemaValidator
+}
+import org.apache.spark.sql.types._
+
+/**
+ * Spark 3.4+ optional integration for spark-protobuf expressions.
+ *
+ * spark-protobuf is an external module, so these rules must be registered by reflection.
+ */
+object ProtobufExprShims extends org.apache.spark.internal.Logging {
  // Referenced by reflection because spark-protobuf is an optional external module.
  private[this] val protobufDataToCatalystClassName =
    "org.apache.spark.sql.protobuf.ProtobufDataToCatalyst"

  // Tree-node tag used to remap GetStructField ordinals after schema pruning.
  val PRUNED_ORDINAL_TAG =
    org.apache.spark.sql.rapids.GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG
+
+ def exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
+ try {
+ val clazz = ShimReflectionUtils.loadClass(protobufDataToCatalystClassName)
+ .asInstanceOf[Class[_ <: UnaryExpression]]
+ Map(clazz.asInstanceOf[Class[_ <: Expression]] -> fromProtobufRule)
+ } catch {
+ case _: ClassNotFoundException => Map.empty
+ case e: Exception =>
+ logWarning(s"Failed to load $protobufDataToCatalystClassName: ${e.getMessage}")
+ Map.empty
+ }
+ }
+
+ private def fromProtobufRule: ExprRule[_ <: Expression] = {
+ GpuOverrides.expr[UnaryExpression](
+ "Decode a BinaryType column (protobuf) into a Spark SQL struct",
+ ExprChecks.unaryProject(
+ // Use TypeSig.all here because schema projection determines which fields
+ // actually need GPU support. Detailed type checking is done in tagExprForGpu.
+ TypeSig.all,
+ TypeSig.all,
+ TypeSig.BINARY,
+ TypeSig.BINARY),
+ (e, conf, p, r) => new UnaryExprMeta[UnaryExpression](e, conf, p, r) {
+
        // Resolved output schema of the from_protobuf call (always a StructType).
        private var fullSchema: StructType = _
        // Whether decode errors should fail the query, from the parsed planner options.
        private var failOnErrors: Boolean = _

        // Flattened schema variables for GPU decoding: parallel arrays with one
        // entry per decoded (possibly nested) field, populated by tagExprForGpu.
        private var flatFieldNumbers: Array[Int] = _
        private var flatParentIndices: Array[Int] = _
        private var flatDepthLevels: Array[Int] = _
        private var flatWireTypes: Array[Int] = _
        private var flatOutputTypeIds: Array[Int] = _
        private var flatEncodings: Array[Int] = _
        private var flatIsRepeated: Array[Boolean] = _
        private var flatIsRequired: Array[Boolean] = _
        private var flatHasDefaultValue: Array[Boolean] = _
        private var flatDefaultInts: Array[Long] = _
        private var flatDefaultFloats: Array[Double] = _
        private var flatDefaultBools: Array[Boolean] = _
        private var flatDefaultStrings: Array[Array[Byte]] = _
        private var flatEnumValidValues: Array[Array[Int]] = _
        private var flatEnumNames: Array[Array[Array[Byte]]] = _
        // Indices in fullSchema for top-level fields that were decoded (for schema projection)
        private var decodedTopLevelIndices: Array[Int] = _
+
        /**
         * Decide whether this from_protobuf expression can run on the GPU.
         *
         * Flow: validate the output type and options, resolve the (proto2-only)
         * descriptor, analyze which output fields downstream operators actually
         * need, then build the flattened schema arrays the GPU decoder consumes.
         * Any failure calls willNotWorkOnGpu with a reason and returns early.
         * NOTE(review): several `return`s below sit inside fold/foreach lambdas
         * and are non-local returns from the enclosing nested defs.
         */
        override def tagExprForGpu(): Unit = {
          fullSchema = e.dataType match {
            case st: StructType => st
            case other =>
              willNotWorkOnGpu(
                s"Only StructType output is supported for from_protobuf, got $other")
              return
          }

          val exprInfo = SparkProtobufCompat.extractExprInfo(e) match {
            case Right(info) => info
            case Left(reason) =>
              willNotWorkOnGpu(reason)
              return
          }
          val unsupportedOptions = SparkProtobufCompat.unsupportedOptions(exprInfo.options)
          if (unsupportedOptions.nonEmpty) {
            val keys = unsupportedOptions.mkString(",")
            willNotWorkOnGpu(
              s"from_protobuf options are not supported yet on GPU: $keys")
            return
          }

          val plannerOptions = SparkProtobufCompat.parsePlannerOptions(exprInfo.options) match {
            case Right(opts) => opts
            case Left(reason) =>
              willNotWorkOnGpu(reason)
              return
          }
          val enumsAsInts = plannerOptions.enumsAsInts
          failOnErrors = plannerOptions.failOnErrors
          val messageName = exprInfo.messageName

          val msgDesc = SparkProtobufCompat.resolveMessageDescriptor(exprInfo) match {
            case Right(desc) => desc
            case Left(reason) =>
              willNotWorkOnGpu(reason)
              return
          }

          // Reject proto3 descriptors — GPU decoder only supports proto2 semantics.
          // proto3 has different null/default-value behavior that the GPU path doesn't handle.
          val protoSyntax = msgDesc.syntax
          if (!SparkProtobufCompat.isGpuSupportedProtoSyntax(protoSyntax)) {
            willNotWorkOnGpu(
              "proto3/editions syntax is not supported by the GPU protobuf decoder; " +
              "only proto2 is supported. The query will fall back to CPU.")
            return
          }

          // Step 1: Analyze all fields and build field info map
          val fieldsInfoMap =
            ProtobufSchemaExtractor.analyzeAllFields(fullSchema, msgDesc, enumsAsInts, messageName)
              .fold({ reason =>
                willNotWorkOnGpu(reason)
                return
              }, identity)

          // Step 2: Determine which fields are actually required by downstream operations
          val requiredFieldNames = analyzeRequiredFields(fieldsInfoMap.keySet)

          // Step 3: Check if all required fields are supported
          val unsupportedRequired = requiredFieldNames.filter { name =>
            fieldsInfoMap.get(name).exists(!_.isSupported)
          }

          if (unsupportedRequired.nonEmpty) {
            val reasons = unsupportedRequired.map { name =>
              val info = fieldsInfoMap(name)
              s"${name}: ${info.unsupportedReason.getOrElse("unknown reason")}"
            }
            willNotWorkOnGpu(
              s"Required fields not supported for from_protobuf: ${reasons.mkString(", ")}")
            return
          }

          // Step 4: Identify which fields in fullSchema need to be decoded
          // These are fields that are required AND supported
          val indicesToDecode = fullSchema.fields.zipWithIndex.collect {
            case (sf, idx) if requiredFieldNames.contains(sf.name) => idx
          }

          // Verify all fields to be decoded are actually supported
          // (This catches edge cases where field analysis might have issues)
          val unsupportedInDecode = indicesToDecode.filter { idx =>
            val sf = fullSchema.fields(idx)
            fieldsInfoMap.get(sf.name).exists(!_.isSupported)
          }
          if (unsupportedInDecode.nonEmpty) {
            val reasons = unsupportedInDecode.map { idx =>
              val sf = fullSchema.fields(idx)
              val info = fieldsInfoMap(sf.name)
              s"${sf.name}: ${info.unsupportedReason.getOrElse("unknown reason")}"
            }
            willNotWorkOnGpu(
              s"Fields not supported for from_protobuf: ${reasons.mkString(", ")}")
            return
          }

          // Step 5: Build flattened schema for GPU decoding.
          // The flattened schema represents nested fields with parent indices.
          // For pure scalar schemas, all fields are top-level (parentIdx == -1, depth == 0).
          {
            val flatFields = mutable.ArrayBuffer[FlattenedFieldDescriptor]()

            // Helper to add a field and its children recursively.
            // pathPrefix is the dot-path of ancestor fields (empty for top-level).
            def addFieldWithChildren(
                sf: StructField,
                info: ProtobufFieldInfo,
                parentIdx: Int,
                depth: Int,
                nestedMsgDesc: ProtobufMessageDescriptor,
                pathPrefix: String = ""): Unit = {

              // Captured before appending so children can point back at this entry.
              val currentIdx = flatFields.size

              if (depth >= 10) {
                willNotWorkOnGpu("Protobuf nesting depth exceeds maximum supported depth of 10")
                return
              }

              val outputType = sf.dataType match {
                case ArrayType(elemType, _) =>
                  elemType match {
                    case _: StructType =>
                      // Repeated message field: ArrayType(StructType) - element type is STRUCT
                      DType.STRUCT.getTypeId.getNativeId
                    case other =>
                      GpuFromProtobuf.sparkTypeToCudfIdOpt(other)
                        .getOrElse(DType.INT8.getTypeId.getNativeId)
                  }
                case _: StructType =>
                  DType.STRUCT.getTypeId.getNativeId
                case other =>
                  GpuFromProtobuf.sparkTypeToCudfIdOpt(other)
                    .getOrElse(DType.INT8.getTypeId.getNativeId)
              }

              val path = if (pathPrefix.isEmpty) sf.name else s"$pathPrefix.${sf.name}"
              // The `return` in the left lambda is a non-local return from this def.
              ProtobufSchemaValidator.toFlattenedFieldDescriptor(
                path,
                sf,
                info,
                parentIdx,
                depth,
                outputType).fold({ reason =>
                  willNotWorkOnGpu(reason)
                  return
                }, flatFields += _)

              // For nested struct types (including repeated message = ArrayType(StructType)),
              // add child fields
              sf.dataType match {
                case st: StructType if nestedMsgDesc != null =>
                  addChildFieldsFromStruct(
                    st, nestedMsgDesc, sf.name, currentIdx, depth, pathPrefix)

                case ArrayType(st: StructType, _) if nestedMsgDesc != null =>
                  addChildFieldsFromStruct(
                    st, nestedMsgDesc, sf.name, currentIdx, depth, pathPrefix)

                case _ => // Not a struct, no children to add
              }
            }

            // Helper to add child fields from a struct type.
            // Applies nested schema pruning at arbitrary depth using path-based
            // lookup into nestedFieldRequirements.
            def addChildFieldsFromStruct(
                st: StructType,
                parentMsgDesc: ProtobufMessageDescriptor,
                fieldName: String,
                parentIdx: Int,
                parentDepth: Int,
                pathPrefix: String): Unit = {
              val path = if (pathPrefix.isEmpty) fieldName else s"$pathPrefix.$fieldName"
              val parentField = parentMsgDesc.findField(fieldName)
              if (parentField.isEmpty) {
                willNotWorkOnGpu(
                  s"Nested field '$fieldName' not found in protobuf descriptor at '$path'")
              } else {
                parentField.get.messageDescriptor match {
                  case Some(childMsgDesc) =>
                    // Prune children to those required downstream, when known.
                    val requiredChildren = nestedFieldRequirements.get(path)
                    val filteredFields = requiredChildren match {
                      case Some(Some(childNames)) =>
                        st.fields.filter(f => childNames.contains(f.name))
                      case _ =>
                        st.fields
                    }
                    filteredFields.foreach { childSf =>
                      childMsgDesc.findField(childSf.name) match {
                        case None =>
                          willNotWorkOnGpu(
                            s"Nested field '${childSf.name}' not found in protobuf " +
                            s"descriptor for message at '$path'")
                          return
                        case Some(childFd) =>
                          ProtobufSchemaExtractor
                            .extractFieldInfo(childSf, childFd, enumsAsInts) match {
                            case Left(reason) =>
                              willNotWorkOnGpu(reason)
                              return
                            case Right(childInfo) =>
                              if (!childInfo.isSupported) {
                                willNotWorkOnGpu(
                                  s"Nested field '${childSf.name}' at '$path': " +
                                  childInfo.unsupportedReason.getOrElse("unsupported type"))
                                return
                              } else {
                                addFieldWithChildren(
                                  childSf, childInfo, parentIdx, parentDepth + 1, childMsgDesc,
                                  path)
                              }
                          }
                      }
                    }
                  case None =>
                    willNotWorkOnGpu(
                      s"Nested field '$fieldName' at '$path' did not resolve to a message type")
                }
              }
            }

            // Only add top-level fields that are actually required (schema projection).
            // This significantly reduces GPU memory and computation for schemas with many
            // fields when only a few are needed. Downstream GetStructField ordinals are
            // remapped via PRUNED_ORDINAL_TAG to index into the pruned output.
            decodedTopLevelIndices = indicesToDecode
            indicesToDecode.foreach { schemaIdx =>
              val sf = fullSchema.fields(schemaIdx)
              val info = fieldsInfoMap(sf.name)
              addFieldWithChildren(sf, info, -1, 0, msgDesc)
            }

            // Populate flattened schema variables
            val flat = flatFields.toArray
            ProtobufSchemaValidator.validateFlattenedSchema(flat).fold({ reason =>
              willNotWorkOnGpu(reason)
              return
            }, identity)
            val arrays = ProtobufSchemaValidator.toFlattenedSchemaArrays(flat)
            flatFieldNumbers = arrays.fieldNumbers
            flatParentIndices = arrays.parentIndices
            flatDepthLevels = arrays.depthLevels
            flatWireTypes = arrays.wireTypes
            flatOutputTypeIds = arrays.outputTypeIds
            flatEncodings = arrays.encodings
            flatIsRepeated = arrays.isRepeated
            flatIsRequired = arrays.isRequired
            flatHasDefaultValue = arrays.hasDefaultValue
            flatDefaultInts = arrays.defaultInts
            flatDefaultFloats = arrays.defaultFloats
            flatDefaultBools = arrays.defaultBools
            flatDefaultStrings = arrays.defaultStrings
            flatEnumValidValues = arrays.enumValidValues
            flatEnumNames = arrays.enumNames
          }
        }
+
+ /**
+ * Analyze which fields are actually required by downstream operations.
+ * Traverses parent plan nodes upward, collecting struct field references from
+ * ProjectExec, FilterExec, and transparent pass-through nodes (AggregateExec,
+ * SortExec, WindowExec, etc.), then returns the set of required top-level
+ * field names.
+ *
+ * @param allFieldNames All field names in the full schema
+ * @return Set of field names that are actually required
+ */
        // Expressions whose GetStructField ordinals must be remapped after pruning.
        // NOTE(review): declared between the scaladoc above and the method it
        // documents because analyzeRequiredFields populates it as a side effect.
        private var targetExprsToRemap: Seq[Expression] = Seq.empty

        private def analyzeRequiredFields(allFieldNames: Set[String]): Set[String] = {
          val fieldReqs = mutable.Map[String, Option[Set[String]]]()
          protobufOutputExprIds = Set.empty
          var hasDirectStructRef = false
          // Invoked when the decoded struct itself is referenced; any such
          // reference disables pruning entirely.
          val holder = () => { hasDirectStructRef = true }

          var currentMeta: Option[SparkPlanMeta[_]] = findParentPlanMeta()
          var safeToPrune = true
          val collectedExprs = mutable.ArrayBuffer[Expression]()

          def advanceToParent(): Unit = {
            currentMeta = currentMeta.get.parent match {
              case Some(pm: SparkPlanMeta[_]) => Some(pm)
              case _ => None
            }
          }

          // Walk up the plan; stop at the first Project (it fixes the required
          // columns) or at any node we don't know how to see through.
          while (currentMeta.isDefined && safeToPrune) {
            currentMeta.get.wrapped match {
              case p: ProjectExec =>
                collectedExprs ++= p.projectList
                p.projectList.foreach {
                  case alias: org.apache.spark.sql.catalyst.expressions.Alias
                      if isProtobufStructReference(alias.child) =>
                    protobufOutputExprIds += alias.exprId
                  case _ =>
                }
                p.projectList.foreach(collectStructFieldReferences(_, fieldReqs, holder))
                currentMeta = None
              case f: org.apache.spark.sql.execution.FilterExec =>
                collectedExprs += f.condition
                collectStructFieldReferences(f.condition, fieldReqs, holder)
                advanceToParent()
              case a: org.apache.spark.sql.execution.aggregate.BaseAggregateExec =>
                val exprs = a.aggregateExpressions ++ a.groupingExpressions
                collectedExprs ++= exprs
                exprs.foreach(collectStructFieldReferences(_, fieldReqs, holder))
                advanceToParent()
              case s: org.apache.spark.sql.execution.SortExec =>
                val exprs = s.sortOrder
                collectedExprs ++= exprs
                exprs.foreach(collectStructFieldReferences(_, fieldReqs, holder))
                advanceToParent()
              case w: org.apache.spark.sql.execution.window.WindowExec =>
                val exprs = w.windowExpression
                collectedExprs ++= exprs
                exprs.foreach(collectStructFieldReferences(_, fieldReqs, holder))
                advanceToParent()
              case _ =>
                // Unknown operator: we cannot prove which fields it reads.
                safeToPrune = false
            }
          }

          // Fall back to decoding everything unless pruning is provably safe.
          if (!safeToPrune || collectedExprs.isEmpty || hasDirectStructRef || fieldReqs.isEmpty) {
            targetExprsToRemap = Seq.empty
            allFieldNames
          } else {
            nestedFieldRequirements = fieldReqs.toMap
            targetExprsToRemap = collectedExprs.toSeq
            val prunedFieldsMap = buildPrunedFieldsMap()
            val topLevelIndices = fullSchema.fields.zipWithIndex.collect {
              case (sf, idx) if fieldReqs.keySet.contains(sf.name) => idx
            }
            targetExprsToRemap.foreach(
              registerPrunedOrdinals(_, prunedFieldsMap, topLevelIndices))
            fieldReqs.keySet.toSet
          }
        }
+
+ /**
+ * Find the parent SparkPlanMeta by traversing up the parent chain.
+ */
+ private def findParentPlanMeta(): Option[SparkPlanMeta[_]] = {
+ def traverse(meta: Option[RapidsMeta[_, _, _]]): Option[SparkPlanMeta[_]] = {
+ meta match {
+ case Some(p: SparkPlanMeta[_]) => Some(p)
+ case Some(p: RapidsMeta[_, _, _]) => traverse(p.parent)
+ case _ => None
+ }
+ }
+ traverse(parent)
+ }
+
+ /**
+ * Nested field requirements: maps a field path to child requirements.
+ * Keys are dot-separated paths from the protobuf root:
+ * - "level1" -> Some(Set("level2")) (top-level struct pruning)
+ * - "level1.level2" -> Some(Set("level3")) (deep nested pruning)
+ * - "field" -> None (whole field needed)
+ *
+ * Top-level names (keys without dots) also determine which fields are decoded.
+ */
        private var nestedFieldRequirements: Map[String, Option[Set[String]]] = Map.empty
        // ExprIds of project-list aliases whose value is the decoded protobuf struct itself.
        private var protobufOutputExprIds: Set[
          org.apache.spark.sql.catalyst.expressions.ExprId] = Set.empty
+
+ private def getFieldName(ordinal: Int, nameOpt: Option[String],
+ schema: StructType): String = {
+ nameOpt.getOrElse {
+ if (ordinal < schema.fields.length) schema.fields(ordinal).name
+ else s"_$ordinal"
+ }
+ }
+
+ /**
+ * Navigate the Spark schema tree by following a dot-separated path of
+ * field names. Returns the StructType at the end of the path, unwrapping
+ * ArrayType(StructType) along the way, or null if the path is invalid.
+ */
+ private def resolveSchemaAtPath(root: StructType, path: Seq[String]): StructType = {
+ var current: StructType = root
+ for (name <- path) {
+ val field = current.fields.find(_.name == name).orNull
+ if (field == null) return null
+ field.dataType match {
+ case st: StructType => current = st
+ case ArrayType(st: StructType, _) => current = st
+ case _ => return null
+ }
+ }
+ current
+ }
+
+ /**
+ * Walk a GetStructField chain upward until it reaches the protobuf
+ * reference expression, returning the sequence of field names forming
+ * the access path. Returns None if the chain does not terminate at a
+ * protobuf reference.
+ *
+ * Example: for `GetStructField(GetStructField(decoded, a_ord), b_ord)`
+ * → Some(Seq("a", "b"))
+ */
        private def resolveFieldAccessChain(
            expr: Expression): Option[Seq[String]] = {
          expr match {
            case GetStructField(child, ordinal, nameOpt) =>
              if (isProtobufStructReference(child)) {
                // Chain bottoms out directly at the decoded struct: one-element path.
                Some(Seq(getFieldName(ordinal, nameOpt, fullSchema)))
              } else {
                // Resolve the parent chain first, then this ordinal against the
                // schema at the parent path; an unresolvable schema invalidates
                // the whole chain.
                resolveFieldAccessChain(child).flatMap { parentPath =>
                  val parentSchema = if (parentPath.isEmpty) fullSchema
                  else resolveSchemaAtPath(fullSchema, parentPath)
                  if (parentSchema != null) {
                    Some(parentPath :+ getFieldName(ordinal, nameOpt, parentSchema))
                  } else {
                    None
                  }
                }
              }
            case _ if isProtobufStructReference(expr) =>
              // The expression IS the decoded struct: empty access path.
              Some(Seq.empty)
            case _ =>
              None
          }
        }
+
+ private def addNestedFieldReq(
+ fieldReqs: mutable.Map[String, Option[Set[String]]],
+ parentKey: String,
+ childName: String): Unit = {
+ fieldReqs.get(parentKey) match {
+ case Some(None) => // Already need whole field, keep it
+ case Some(Some(existing)) =>
+ fieldReqs(parentKey) = Some(existing + childName)
+ case None =>
+ fieldReqs(parentKey) = Some(Set(childName))
+ }
+ }
+
+ /**
+ * Register pruning requirements at every level of a field access path.
+ * For path = ["a", "b"] with leafName = "c":
+ * "a" -> needs child "b"
+ * "a.b" -> needs child "c"
+ */
+ private def registerPathRequirements(
+ fieldReqs: mutable.Map[String, Option[Set[String]]],
+ path: Seq[String],
+ leafName: String): Unit = {
+ for (i <- path.indices) {
+ val pathKey = path.take(i + 1).mkString(".")
+ val childName = if (i < path.length - 1) path(i + 1) else leafName
+ addNestedFieldReq(fieldReqs, pathKey, childName)
+ }
+ }
+
        /**
         * Recursively collect which protobuf output fields an expression reads,
         * recording per-path child requirements into fieldReqs. A reference to
         * the decoded struct itself triggers hasDirectStructRefHolder, which
         * disables pruning in the caller.
         */
        private def collectStructFieldReferences(
            expr: Expression,
            fieldReqs: mutable.Map[String, Option[Set[String]]],
            hasDirectStructRefHolder: () => Unit): Unit = {
          expr match {
            case GetStructField(child, ordinal, nameOpt) =>
              resolveFieldAccessChain(child) match {
                case Some(parentPath) =>
                  val parentSchema = if (parentPath.isEmpty) fullSchema
                  else resolveSchemaAtPath(fullSchema, parentPath)
                  if (parentSchema != null) {
                    val fieldName = getFieldName(ordinal, nameOpt, parentSchema)
                    if (parentPath.isEmpty) {
                      // Direct top-level access: decoded.field_name (whole field)
                      fieldReqs(fieldName) = None
                    } else {
                      registerPathRequirements(fieldReqs, parentPath, fieldName)
                    }
                  } else {
                    // Unresolvable path: fall back to scanning the child expr.
                    collectStructFieldReferences(child, fieldReqs, hasDirectStructRefHolder)
                  }
                case None =>
                  collectStructFieldReferences(child, fieldReqs, hasDirectStructRefHolder)
              }

            case gasf: GetArrayStructFields =>
              resolveFieldAccessChain(gasf.child) match {
                case Some(parentPath) if parentPath.nonEmpty =>
                  registerPathRequirements(fieldReqs, parentPath, gasf.field.name)
                case Some(parentPath) if parentPath.isEmpty =>
                  // Accessing a field of a top-level repeated message: whole field.
                  fieldReqs(gasf.field.name) = None
                case _ =>
                  gasf.children.foreach { child =>
                    collectStructFieldReferences(child, fieldReqs, hasDirectStructRefHolder)
                  }
              }

            case alias: org.apache.spark.sql.catalyst.expressions.Alias =>
              // An alias OF the decoded struct is tracked separately (via exprId),
              // so only recurse for other aliased expressions.
              if (!isProtobufStructReference(alias.child)) {
                collectStructFieldReferences(alias.child, fieldReqs, hasDirectStructRefHolder)
              }

            case _ =>
              if (isProtobufStructReference(expr)) {
                hasDirectStructRefHolder()
              }
              expr.children.foreach { child =>
                collectStructFieldReferences(child, fieldReqs, hasDirectStructRefHolder)
              }
          }
        }
+
+ /**
+ * Turn the collected nested-field requirements into an ordered map:
+ * dotted path key -> child names to keep, ordered as they appear in the schema.
+ * Entries whose requirement is None (whole field needed) are dropped, so they
+ * are never pruned.
+ */
+ private def buildPrunedFieldsMap(): Map[String, Seq[String]] = {
+ nestedFieldRequirements.collect {
+ case (pathKey, Some(childNames)) =>
+ val pathParts = pathKey.split("\\.").toSeq
+ val childSchema = resolveSchemaAtPath(fullSchema, pathParts)
+ if (childSchema != null) {
+ // Preserve the declared schema order of the surviving children.
+ val orderedNames = childSchema.fields
+ .map(_.name)
+ .filter(childNames.contains)
+ .toSeq
+ pathKey -> orderedNames
+ } else {
+ // NOTE(review): fallback uses Set iteration order, which is not
+ // deterministic across runs — confirm this path is unreachable in practice.
+ pathKey -> childNames.toSeq
+ }
+ }
+ }
+
+ /**
+ * Walk `expr` and tag every GetStructField / GetArrayStructFields whose parent
+ * struct is being pruned with its new runtime ordinal (its index in the pruned
+ * child list) via PRUNED_ORDINAL_TAG, so downstream extraction uses the right slot.
+ *
+ * @param prunedFieldsMap dotted path key -> ordered surviving child names
+ * @param topLevelIndices original top-level field ordinals that remain decoded
+ */
+ private def registerPrunedOrdinals(
+ expr: Expression,
+ prunedFieldsMap: Map[String, Seq[String]],
+ topLevelIndices: Seq[Int]): Unit = {
+ expr match {
+ case gsf @ GetStructField(childExpr, ordinal, nameOpt) =>
+ resolveFieldAccessChain(childExpr) match {
+ case Some(parentPath) if parentPath.nonEmpty =>
+ val parentSchema = resolveSchemaAtPath(fullSchema, parentPath)
+ if (parentSchema != null) {
+ val pathKey = parentPath.mkString(".")
+ val childName = getFieldName(ordinal, nameOpt, parentSchema)
+ prunedFieldsMap.get(pathKey).foreach { orderedChildren =>
+ // New ordinal = position within the pruned child list.
+ val runtimeOrd = orderedChildren.indexOf(childName)
+ if (runtimeOrd >= 0) {
+ gsf.setTagValue(ProtobufExprShims.PRUNED_ORDINAL_TAG, runtimeOrd)
+ }
+ }
+ }
+ case Some(parentPath) if parentPath.isEmpty =>
+ // Top-level access: remap against the surviving top-level indices.
+ val runtimeOrd = topLevelIndices.indexOf(ordinal)
+ if (runtimeOrd >= 0) {
+ gsf.setTagValue(ProtobufExprShims.PRUNED_ORDINAL_TAG, runtimeOrd)
+ }
+ case _ =>
+ }
+ case gasf @ GetArrayStructFields(childExpr, field, _, _, _) =>
+ resolveFieldAccessChain(childExpr) match {
+ case Some(parentPath) if parentPath.nonEmpty =>
+ val pathKey = parentPath.mkString(".")
+ prunedFieldsMap.get(pathKey).foreach { orderedChildren =>
+ val runtimeOrd = orderedChildren.indexOf(field.name)
+ if (runtimeOrd >= 0) {
+ gasf.setTagValue(ProtobufExprShims.PRUNED_ORDINAL_TAG, runtimeOrd)
+ }
+ }
+ case _ =>
+ }
+ case _ =>
+ }
+ // Keep walking: nested accesses may appear below already-handled nodes.
+ expr.children.foreach(registerPrunedOrdinals(_, prunedFieldsMap, topLevelIndices))
+ }
+
+ /**
+ * Check if an expression references the output of a protobuf decode expression.
+ * This can be either:
+ * 1. The ProtobufDataToCatalyst expression itself
+ * 2. An AttributeReference that references the output of ProtobufDataToCatalyst
+ * (when accessing from a downstream ProjectExec)
+ */
+ private def isProtobufStructReference(expr: Expression): Boolean = {
+ // Fast path: the very same node, or one semantically equal to it.
+ if ((expr eq e) || expr.semanticEquals(e)) {
+ return true
+ }
+
+ // Catalyst may create duplicate ProtobufDataToCatalyst
+ // instances for each GetStructField access. Match copies
+ // by class + identical input child + identical decode
+ // semantics so that
+ // analyzeRequiredFields detects all field accesses in one
+ // pass, keeping schema projection correct.
+ if (expr.getClass == e.getClass &&
+ expr.children.nonEmpty &&
+ e.children.nonEmpty &&
+ ((expr.children.head eq e.children.head) ||
+ expr.children.head.semanticEquals(
+ e.children.head)) &&
+ SparkProtobufCompat.sameDecodeSemantics(expr, e)) {
+ return true
+ }
+
+ // If the parent meta wraps `e` in an Alias, the alias' ExprId identifies
+ // attributes in downstream projects that refer to the decode output.
+ val protobufOutputExprId
+ : Option[org.apache.spark.sql.catalyst.expressions.ExprId] =
+ parent.flatMap { meta =>
+ meta.wrapped match {
+ case alias: org.apache.spark.sql.catalyst.expressions
+ .Alias if alias.child.semanticEquals(e) =>
+ Some(alias.exprId)
+ case _ => None
+ }
+ }
+
+ expr match {
+ case attr: AttributeReference =>
+ // Matches either a pre-collected decode-output ExprId or the parent alias' one.
+ protobufOutputExprIds.contains(attr.exprId) || protobufOutputExprId.exists(_ == attr.exprId)
+ case _ => false
+ }
+ }
+
+ /**
+ * Build the GPU decode expression: tag pruned ordinals on downstream extractors,
+ * then construct the pruned decode schema and wrap it in GpuFromProtobuf.
+ */
+ override def convertToGpu(child: Expression): GpuExpression = {
+ val prunedFieldsMap = buildPrunedFieldsMap()
+ // Remap GetStructField ordinals on every expression that reads decode output.
+ targetExprsToRemap.foreach(
+ registerPrunedOrdinals(_, prunedFieldsMap, decodedTopLevelIndices.toSeq))
+
+ val decodedSchema = {
+ // Recursively drop struct children that were not requested; `prefix` is the
+ // dotted path of the enclosing field ("" at the top level).
+ def applyPruning(field: StructField, prefix: String): StructField = {
+ val path = if (prefix.isEmpty) field.name else s"$prefix.${field.name}"
+ prunedFieldsMap.get(path) match {
+ case Some(childNames) =>
+ // This field has an explicit keep-list: filter, then recurse into kept children.
+ field.dataType match {
+ case ArrayType(st: StructType, cn) =>
+ val pruned = StructType(
+ st.fields.filter(f => childNames.contains(f.name))
+ .map(f => applyPruning(f, path)))
+ field.copy(dataType = ArrayType(pruned, cn))
+ case st: StructType =>
+ val pruned = StructType(
+ st.fields.filter(f => childNames.contains(f.name))
+ .map(f => applyPruning(f, path)))
+ field.copy(dataType = pruned)
+ case _ => field
+ }
+ case None =>
+ // No pruning at this level, but descendants may still be pruned.
+ field.dataType match {
+ case ArrayType(st: StructType, cn) =>
+ val recursed = StructType(st.fields.map(f => applyPruning(f, path)))
+ if (recursed != st) field.copy(dataType = ArrayType(recursed, cn))
+ else field
+ case st: StructType =>
+ val recursed = StructType(st.fields.map(f => applyPruning(f, path)))
+ if (recursed != st) field.copy(dataType = recursed)
+ else field
+ case _ => field
+ }
+ }
+ }
+
+ val decodedFields = decodedTopLevelIndices.map { idx =>
+ applyPruning(fullSchema.fields(idx), "")
+ }
+ // All decoded protobuf fields are nullable on the GPU path.
+ StructType(decodedFields.map(f =>
+ f.copy(nullable = true)))
+ }
+
+ GpuFromProtobuf(
+ decodedSchema,
+ flatFieldNumbers, flatParentIndices,
+ flatDepthLevels, flatWireTypes, flatOutputTypeIds, flatEncodings,
+ flatIsRepeated, flatIsRequired, flatHasDefaultValue, flatDefaultInts,
+ flatDefaultFloats, flatDefaultBools, flatDefaultStrings, flatEnumValidValues,
+ flatEnumNames, failOnErrors, child)
+ }
+ }
+ )
+ }
+
+}
diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala
index f6110539a35..10ae5efacbf 100644
--- a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala
+++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala
@@ -165,7 +165,7 @@ trait Spark340PlusNonDBShims extends Spark331PlusNonDBShims {
),
GpuElementAtMeta.elementAtRule(true)
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
- super.getExprs ++ shimExprs
+ super.getExprs ++ shimExprs ++ ProtobufExprShims.exprs
}
override def getDataWriteCmds: Map[Class[_ <: DataWritingCommand],
diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/SparkProtobufCompat.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/SparkProtobufCompat.scala
new file mode 100644
index 00000000000..ef2ffc5ef57
--- /dev/null
+++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/SparkProtobufCompat.scala
@@ -0,0 +1,344 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "340"}
+{"spark": "341"}
+{"spark": "342"}
+{"spark": "343"}
+{"spark": "344"}
+{"spark": "350"}
+{"spark": "351"}
+{"spark": "352"}
+{"spark": "353"}
+{"spark": "354"}
+{"spark": "355"}
+{"spark": "356"}
+{"spark": "357"}
+{"spark": "400"}
+{"spark": "401"}
+{"spark": "402"}
+{"spark": "411"}
+spark-rapids-shim-json-lines ***/
+
+package com.nvidia.spark.rapids.shims
+
+import java.lang.reflect.Method
+import java.nio.file.{Files, Paths}
+
+import scala.util.Try
+
+import com.nvidia.spark.rapids.ShimReflectionUtils
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.rapids.protobuf._
+
+private[shims] object SparkProtobufCompat extends Logging {
+ private[this] val sparkProtobufUtilsObjectClassName =
+ "org.apache.spark.sql.protobuf.utils.ProtobufUtils$"
+
+ val SupportedOptions: Set[String] = Set("enums.as.ints", "mode")
+
+ /**
+ * Reflectively read messageName, options, and the descriptor source from a
+ * (shaded/version-varying) ProtobufDataToCatalyst expression. Returns a
+ * human-readable reason on the Left when any piece cannot be read.
+ */
+ def extractExprInfo(e: Expression): Either[String, ProtobufExprInfo] = {
+ for {
+ messageName <- reflectMessageName(e)
+ options <- reflectOptions(e)
+ descriptorSource <- reflectDescriptorSource(e)
+ } yield ProtobufExprInfo(messageName, descriptorSource, options)
+ }
+
+ /**
+ * Two decode expressions have the same decode semantics when both can be
+ * introspected and their extracted info is equal. Any reflection failure
+ * conservatively yields false.
+ */
+ def sameDecodeSemantics(left: Expression, right: Expression): Boolean = {
+ (extractExprInfo(left), extractExprInfo(right)) match {
+ case (Right(leftInfo), Right(rightInfo)) => leftInfo == rightInfo
+ case _ => false
+ }
+ }
+
+ /**
+ * Parse the from_protobuf options relevant to GPU planning.
+ * "enums.as.ints" must be a valid boolean (defaults to false);
+ * failOnErrors is true only for mode=FAILFAST (the default).
+ */
+ def parsePlannerOptions(
+ options: Map[String, String]): Either[String, ProtobufPlannerOptions] = {
+ val enumsAsInts = Try(options.getOrElse("enums.as.ints", "false").toBoolean)
+ .toEither
+ .left
+ .map { _ =>
+ "Invalid value for from_protobuf option 'enums.as.ints': " +
+ s"'${options.getOrElse("enums.as.ints", "")}' (expected true/false)"
+ }
+ enumsAsInts.map(v =>
+ ProtobufPlannerOptions(
+ enumsAsInts = v,
+ failOnErrors = options.getOrElse("mode", "FAILFAST").equalsIgnoreCase("FAILFAST")))
+ }
+
+ /** Option keys present in `options` that the GPU path does not understand, sorted. */
+ def unsupportedOptions(options: Map[String, String]): Seq[String] =
+ options.keys.filterNot(SupportedOptions.contains).toSeq.sorted
+
+ // Accepts only a non-empty syntax that is neither PROTO3 nor EDITIONS —
+ // i.e. effectively proto2 files. NOTE(review): confirm PROTO3 exclusion is
+ // intentional and not inverted.
+ def isGpuSupportedProtoSyntax(syntax: String): Boolean =
+ syntax.nonEmpty && syntax != "PROTO3" && syntax != "EDITIONS"
+
+ /**
+ * Build a protobuf message Descriptor via Spark's ProtobufUtils and wrap it in
+ * the reflection-based adapter; any failure becomes a descriptive Left message.
+ */
+ def resolveMessageDescriptor(
+ exprInfo: ProtobufExprInfo): Either[String, ProtobufMessageDescriptor] = {
+ Try(buildMessageDescriptor(exprInfo.messageName, exprInfo.descriptorSource))
+ .toEither
+ .left
+ .map { t =>
+ s"Failed to resolve protobuf descriptor for message '${exprInfo.messageName}': " +
+ s"${t.getMessage}"
+ }
+ .map(new ReflectiveMessageDescriptor(_))
+ }
+
+ // Reflectively read the `messageName` member of the decode expression.
+ private def reflectMessageName(e: Expression): Either[String, String] =
+ Try(PbReflect.invoke0[String](e, "messageName")).toEither.left.map { t =>
+ s"Cannot read from_protobuf messageName via reflection: ${t.getMessage}"
+ }
+
+ // Reflectively read the `options` map of the decode expression.
+ private def reflectOptions(e: Expression): Either[String, Map[String, String]] = {
+ Try(PbReflect.invoke0[scala.collection.Map[String, String]](e, "options"))
+ .map(_.toMap)
+ .toEither.left.map { _ =>
+ "Cannot read from_protobuf options via reflection; falling back to CPU"
+ }
+ }
+
+ // Prefer a descriptor file path; fall back to inline descriptor bytes.
+ private def reflectDescriptorSource(e: Expression): Either[String, ProtobufDescriptorSource] = {
+ reflectDescFilePath(e).map(ProtobufDescriptorSource.DescriptorPath).orElse(
+ reflectDescriptorBytes(e).map(ProtobufDescriptorSource.DescriptorBytes)).toRight(
+ "from_protobuf requires a descriptor set (descFilePath or binaryFileDescriptorSet)")
+ }
+
+ // Reads the optional `descFilePath` member; None if absent or unreadable.
+ private def reflectDescFilePath(e: Expression): Option[String] =
+ Try(PbReflect.invoke0[Option[String]](e, "descFilePath")).toOption.flatten
+
+ /**
+ * Read the serialized FileDescriptorSet bytes, probing the accessor names used
+ * across Spark versions: `binaryFileDescriptorSet` (3.5+), then
+ * `binaryDescriptorSet` as either a raw array or an Option.
+ */
+ private def reflectDescriptorBytes(e: Expression): Option[Array[Byte]] = {
+ val spark35Result = Try(PbReflect.invoke0[Option[Array[Byte]]](e, "binaryFileDescriptorSet"))
+ .toOption.flatten
+ spark35Result.orElse {
+ val direct = Try(PbReflect.invoke0[Array[Byte]](e, "binaryDescriptorSet")).toOption
+ direct.orElse {
+ Try(PbReflect.invoke0[Option[Array[Byte]]](e, "binaryDescriptorSet")).toOption.flatten
+ }
+ }
+ }
+
+ /**
+ * Locate Spark's ProtobufUtils singleton and invoke its `buildDescriptor`
+ * method, reading the descriptor file from disk only if the version-specific
+ * retry in invokeBuildDescriptor requires it.
+ */
+ private def buildMessageDescriptor(
+ messageName: String,
+ descriptorSource: ProtobufDescriptorSource): AnyRef = {
+ val cls = ShimReflectionUtils.loadClass(sparkProtobufUtilsObjectClassName)
+ // Scala object singleton instance lives in the static MODULE$ field.
+ val module = cls.getField("MODULE$").get(null)
+ val buildMethod = cls.getMethod("buildDescriptor", classOf[String], classOf[scala.Option[_]])
+
+ invokeBuildDescriptor(
+ buildMethod,
+ module,
+ messageName,
+ descriptorSource,
+ filePath => Files.readAllBytes(Paths.get(filePath)))
+ }
+
+ /**
+ * Invoke `buildDescriptor`, adapting to the erased Option parameter whose
+ * element type differs across Spark versions (path String vs descriptor bytes).
+ *
+ * @param readDescriptorFile injected file reader (injectable for unit tests)
+ */
+ private[shims] def invokeBuildDescriptor(
+ buildMethod: Method,
+ module: AnyRef,
+ messageName: String,
+ descriptorSource: ProtobufDescriptorSource,
+ readDescriptorFile: String => Array[Byte]): AnyRef = {
+ descriptorSource match {
+ case ProtobufDescriptorSource.DescriptorBytes(bytes) =>
+ buildMethod.invoke(module, messageName, Some(bytes)).asInstanceOf[AnyRef]
+ case ProtobufDescriptorSource.DescriptorPath(filePath) =>
+ try {
+ buildMethod.invoke(module, messageName, Some(filePath)).asInstanceOf[AnyRef]
+ } catch {
+ // Spark 3.5+ changed the descriptor payload from Option[String] to Option[Array[Byte]]
+ // while keeping the same erased JVM signature. Retry with file contents when the
+ // path-based invocation clearly hit that binary-descriptor variant.
+ case ex: java.lang.reflect.InvocationTargetException
+ if ex.getCause.isInstanceOf[ClassCastException] ||
+ ex.getCause.isInstanceOf[MatchError] =>
+ buildMethod.invoke(
+ module, messageName, Some(readDescriptorFile(filePath))).asInstanceOf[AnyRef]
+ }
+ }
+ }
+
+ // Best-effort readable name for a protobuf type enum value; falls back to toString.
+ private def typeName(t: AnyRef): String =
+ if (t == null) "null" else Try(PbReflect.invoke0[String](t, "name")).getOrElse(t.toString)
+
+ /** Adapter exposing a raw protobuf-java Descriptor through reflection. */
+ private final class ReflectiveMessageDescriptor(raw: AnyRef) extends ProtobufMessageDescriptor {
+ // File-level syntax string (e.g. "PROTO2"); empty when it cannot be read.
+ override lazy val syntax: String = PbReflect.getFileSyntax(raw, typeName)
+
+ override def findField(name: String): Option[ProtobufFieldDescriptor] =
+ Option(PbReflect.findFieldByName(raw, name)).map(new ReflectiveFieldDescriptor(_))
+ }
+
+ /** Adapter exposing a raw protobuf-java FieldDescriptor through reflection. */
+ private final class ReflectiveFieldDescriptor(raw: AnyRef) extends ProtobufFieldDescriptor {
+ override lazy val name: String = PbReflect.invoke0[String](raw, "getName")
+ override lazy val fieldNumber: Int = PbReflect.getFieldNumber(raw)
+ override lazy val protoTypeName: String = typeName(PbReflect.getFieldType(raw))
+ override lazy val isRepeated: Boolean = PbReflect.isRepeated(raw)
+ override lazy val isRequired: Boolean = PbReflect.isRequired(raw)
+ // Only ENUM fields carry enum metadata (the valid value set).
+ override lazy val enumMetadata: Option[ProtobufEnumMetadata] =
+ if (protoTypeName == "ENUM") {
+ Some(ProtobufEnumMetadata(PbReflect.getEnumValues(PbReflect.getEnumType(raw))))
+ } else {
+ None
+ }
+ // Left is a readable failure reason; Right(None) means no explicit default.
+ override lazy val defaultValueResult: Either[String, Option[ProtobufDefaultValue]] =
+ Try {
+ if (PbReflect.hasDefaultValue(raw)) {
+ PbReflect.getDefaultValue(raw).map(toDefaultValue(_, protoTypeName, enumMetadata))
+ } else {
+ None
+ }
+ }.toEither.left.map { t =>
+ s"Failed to read protobuf default value for field '$name': ${t.getMessage}"
+ }
+ // Only MESSAGE fields have a nested message descriptor.
+ override lazy val messageDescriptor: Option[ProtobufMessageDescriptor] =
+ if (protoTypeName == "MESSAGE") {
+ Some(new ReflectiveMessageDescriptor(PbReflect.getMessageType(raw)))
+ } else {
+ None
+ }
+ }
+
+ /**
+ * Convert a raw protobuf-java default value object to the plugin's typed
+ * ProtobufDefaultValue, dispatching on the proto type name. Throws
+ * IllegalStateException for types that have no supported default mapping.
+ */
+ private def toDefaultValue(
+ rawDefault: AnyRef,
+ protoTypeName: String,
+ enumMetadata: Option[ProtobufEnumMetadata]): ProtobufDefaultValue = protoTypeName match {
+ case "BOOL" =>
+ ProtobufDefaultValue.BoolValue(rawDefault.asInstanceOf[java.lang.Boolean].booleanValue())
+ case "FLOAT" =>
+ ProtobufDefaultValue.FloatValue(rawDefault.asInstanceOf[java.lang.Float].floatValue())
+ case "DOUBLE" =>
+ ProtobufDefaultValue.DoubleValue(rawDefault.asInstanceOf[java.lang.Double].doubleValue())
+ case "STRING" =>
+ ProtobufDefaultValue.StringValue(if (rawDefault == null) null else rawDefault.toString)
+ case "BYTES" =>
+ ProtobufDefaultValue.BinaryValue(extractBytes(rawDefault))
+ case "ENUM" =>
+ // Prefer the named default from enum metadata; fall back to number + toString.
+ val number = extractNumber(rawDefault).intValue()
+ enumMetadata.map(_.enumDefault(number))
+ .getOrElse(ProtobufDefaultValue.EnumValue(number, rawDefault.toString))
+ case "INT32" | "UINT32" | "SINT32" | "FIXED32" | "SFIXED32" |
+ "INT64" | "UINT64" | "SINT64" | "FIXED64" | "SFIXED64" =>
+ // All integer variants are widened to a long-backed IntValue.
+ ProtobufDefaultValue.IntValue(extractNumber(rawDefault).longValue())
+ case other =>
+ throw new IllegalStateException(
+ s"Unsupported protobuf default value type '$other' for value ${rawDefault.toString}")
+ }
+
+ // Coerce a raw default to a Number, probing getNumber() for enum-value objects.
+ private def extractNumber(rawDefault: AnyRef): java.lang.Number = rawDefault match {
+ case n: java.lang.Number => n
+ case ref: AnyRef =>
+ Try {
+ ref.getClass.getMethod("getNumber").invoke(ref).asInstanceOf[java.lang.Number]
+ }.getOrElse {
+ throw new IllegalStateException(
+ s"Unsupported protobuf numeric default value class: ${ref.getClass.getName}")
+ }
+ case _ =>
+ throw new IllegalStateException("Unexpected protobuf numeric default value")
+ }
+
+ // Coerce a raw default to bytes, probing toByteArray() for ByteString-like objects.
+ private def extractBytes(rawDefault: AnyRef): Array[Byte] = rawDefault match {
+ case bytes: Array[Byte] => bytes
+ case ref: AnyRef =>
+ Try {
+ ref.getClass.getMethod("toByteArray").invoke(ref).asInstanceOf[Array[Byte]]
+ }.getOrElse {
+ throw new IllegalStateException(
+ s"Unsupported protobuf bytes default value class: ${ref.getClass.getName}")
+ }
+ case _ =>
+ throw new IllegalStateException("Unexpected protobuf bytes default value")
+ }
+
+ /**
+ * Small reflection helper over protobuf-java descriptor objects, with a
+ * method-handle cache so repeated lookups stay cheap.
+ */
+ private object PbReflect {
+ // Cache key: class name + method name + parameter type names.
+ private val cache = new java.util.concurrent.ConcurrentHashMap[String, Method]()
+
+ // Best-effort protobuf-java runtime version, for diagnostics only.
+ private def protobufJavaVersion: String = Try {
+ val rtCls = Class.forName("com.google.protobuf.RuntimeVersion")
+ val domain = rtCls.getField("DOMAIN").get(null)
+ val major = rtCls.getField("MAJOR").get(null)
+ val minor = rtCls.getField("MINOR").get(null)
+ val patch = rtCls.getField("PATCH").get(null)
+ s"$domain-$major.$minor.$patch"
+ }.getOrElse("unknown")
+
+ // Look up (and cache) a public method; failures surface as a clear
+ // UnsupportedOperationException naming the protobuf-java version.
+ private def cached(cls: Class[_], name: String, paramTypes: Class[_]*): Method = {
+ val key = s"${cls.getName}#$name(${paramTypes.map(_.getName).mkString(",")})"
+ cache.computeIfAbsent(key, _ => {
+ try {
+ cls.getMethod(name, paramTypes: _*)
+ } catch {
+ case ex: NoSuchMethodException =>
+ throw new UnsupportedOperationException(
+ s"protobuf-java method not found: ${cls.getSimpleName}.$name " +
+ s"(protobuf-java version: $protobufJavaVersion). " +
+ s"This may indicate an incompatible protobuf-java library version.",
+ ex)
+ }
+ })
+ }
+
+ /** Invoke a 0-arg method on `obj` and cast the result. */
+ def invoke0[T](obj: AnyRef, method: String): T =
+ cached(obj.getClass, method).invoke(obj).asInstanceOf[T]
+
+ /** Invoke a 1-arg method on `obj` and cast the result. */
+ def invoke1[T](obj: AnyRef, method: String, arg0Cls: Class[_], arg0: AnyRef): T =
+ cached(obj.getClass, method, arg0Cls).invoke(obj, arg0).asInstanceOf[T]
+
+ def findFieldByName(msgDesc: AnyRef, name: String): AnyRef =
+ invoke1[AnyRef](msgDesc, "findFieldByName", classOf[String], name)
+
+ def getFieldNumber(fd: AnyRef): Int =
+ invoke0[java.lang.Integer](fd, "getNumber").intValue()
+
+ def getFieldType(fd: AnyRef): AnyRef = invoke0[AnyRef](fd, "getType")
+
+ def isRepeated(fd: AnyRef): Boolean =
+ invoke0[java.lang.Boolean](fd, "isRepeated").booleanValue()
+
+ def isRequired(fd: AnyRef): Boolean =
+ invoke0[java.lang.Boolean](fd, "isRequired").booleanValue()
+
+ def hasDefaultValue(fd: AnyRef): Boolean =
+ invoke0[java.lang.Boolean](fd, "hasDefaultValue").booleanValue()
+
+ def getDefaultValue(fd: AnyRef): Option[AnyRef] =
+ Option(invoke0[AnyRef](fd, "getDefaultValue"))
+
+ def getMessageType(fd: AnyRef): AnyRef = invoke0[AnyRef](fd, "getMessageType")
+
+ def getEnumType(fd: AnyRef): AnyRef = invoke0[AnyRef](fd, "getEnumType")
+
+ /** Read every (number, name) pair of an enum descriptor's values. */
+ def getEnumValues(enumType: AnyRef): Seq[ProtobufEnumValue] = {
+ import scala.collection.JavaConverters._
+ val values = invoke0[java.util.List[_]](enumType, "getValues")
+ values.asScala.map { v =>
+ val ev = v.asInstanceOf[AnyRef]
+ val num = invoke0[java.lang.Integer](ev, "getNumber").intValue()
+ val enumName = invoke0[String](ev, "getName")
+ ProtobufEnumValue(num, enumName)
+ }.toSeq
+ }
+
+ /** File-level syntax of a message's .proto file, or "" on any failure. */
+ def getFileSyntax(msgDesc: AnyRef, typeNameFn: AnyRef => String): String = Try {
+ val fileDesc = invoke0[AnyRef](msgDesc, "getFile")
+ val syntaxObj = invoke0[AnyRef](fileDesc, "getSyntax")
+ typeNameFn(syntaxObj)
+ }.getOrElse("")
+ }
+}
diff --git a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/ProtobufBatchMergeSuite.scala b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/ProtobufBatchMergeSuite.scala
new file mode 100644
index 00000000000..1f2d4492108
--- /dev/null
+++ b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/ProtobufBatchMergeSuite.scala
@@ -0,0 +1,115 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids
+
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, GetStructField, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.execution.{LeafExecNode, ProjectExec, SparkPlan}
+import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+/**
+ * Unit tests for the protobuf post-project batch-merge feature: plan-shape
+ * detection, the config flag, and the GpuProjectExec batching behavior.
+ */
+class ProtobufBatchMergeSuite extends AnyFunSuite {
+
+ // Minimal columnar leaf node; execution methods are never exercised here.
+ private case class DummyColumnarLeaf(output: Seq[AttributeReference]) extends LeafExecNode {
+ override protected def doExecute(): RDD[InternalRow] =
+ throw new UnsupportedOperationException("not needed for unit test")
+
+ override protected def doExecuteColumnar(): RDD[ColumnarBatch] =
+ throw new UnsupportedOperationException("not needed for unit test")
+
+ override def supportsColumnar: Boolean = true
+ }
+
+ // Stand-in for ProtobufDataToCatalyst with a fixed two-field struct output.
+ private case class FakeProtobufDataToCatalyst(child: Expression) extends UnaryExpression {
+ override def dataType: StructType = StructType(Seq(
+ StructField("search_id", IntegerType, nullable = true),
+ StructField("nested", StructType(Seq(StructField("value", IntegerType, nullable = true))),
+ nullable = true)))
+ override def nullable: Boolean = true
+ override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ throw new UnsupportedOperationException("codegen not needed in unit test")
+ }
+
+ // The outer extractor project over a decode project should trigger coalescing;
+ // the inner decode project itself should not.
+ test("project meta detects extractor project over protobuf child project") {
+ val binAttr = AttributeReference("bin", BinaryType)()
+ val childScan = DummyColumnarLeaf(Seq(binAttr))
+ val innerProject = ProjectExec(
+ Seq(Alias(FakeProtobufDataToCatalyst(binAttr), "decoded")(),
+ childScan)
+ val decodedAttr = innerProject.output.head.toAttribute
+ val outerProject = ProjectExec(Seq(
+ Alias(GetStructField(decodedAttr, 0, None), "search_id")(),
+ Alias(GetStructField(GetStructField(decodedAttr, 1, None), 0, None), "value")()),
+ innerProject)
+
+ assert(GpuProjectExecMeta.shouldCoalesceAfterProject(outerProject))
+ assert(!GpuProjectExecMeta.shouldCoalesceAfterProject(innerProject))
+ }
+
+ // Decode-and-extract fused into a single project should also be detected.
+ test("project meta detects direct protobuf extraction in same project") {
+ val binAttr = AttributeReference("bin", BinaryType)()
+ val childScan = DummyColumnarLeaf(Seq(binAttr))
+ val directProject = ProjectExec(
+ Seq(
+ Alias(GetStructField(FakeProtobufDataToCatalyst(binAttr), 0, None), "search_id")(),
+ Alias(
+ GetStructField(
+ GetStructField(FakeProtobufDataToCatalyst(binAttr), 1, None),
+ 0,
+ None),
+ "value")()),
+ childScan)
+
+ assert(GpuProjectExecMeta.shouldCoalesceAfterProject(directProject))
+ }
+
+ // The feature flag must default to disabled.
+ test("protobuf batch merge config defaults off and can be enabled") {
+ val enabledConf = new RapidsConf(Map(
+ RapidsConf.ENABLE_PROTOBUF_BATCH_MERGE_AFTER_PROJECT.key -> "true"))
+
+ assert(!new RapidsConf(Map.empty[String, String]).isProtobufBatchMergeAfterProjectEnabled)
+ assert(enabledConf.isProtobufBatchMergeAfterProjectEnabled)
+ }
+
+ // forcePostProjectCoalesce=true must enable coalesceAfter and drop the
+ // outputBatching guarantee (null) so a coalesce can be inserted downstream.
+ test("flagged gpu project drops output batching guarantee for post-project merge") {
+ val childAttr = AttributeReference("value", IntegerType)()
+ val child: SparkPlan = DummyColumnarLeaf(Seq(childAttr))
+
+ val unflaggedProject = GpuProjectExec(
+ projectList = child.output.map(a => Alias(a, a.name)()).toList,
+ child = child,
+ enablePreSplit = true,
+ forcePostProjectCoalesce = false)
+ val flaggedProject = GpuProjectExec(
+ projectList = child.output.map(a => Alias(a, a.name)()).toList,
+ child = child,
+ enablePreSplit = true,
+ forcePostProjectCoalesce = true)
+
+ assert(!unflaggedProject.coalesceAfter)
+ assert(flaggedProject.coalesceAfter)
+ assert(unflaggedProject.outputBatching.isInstanceOf[TargetSize])
+ assert(flaggedProject.outputBatching == null)
+ assert(CoalesceGoal.satisfies(unflaggedProject.outputBatching, TargetSize(1L)))
+ assert(!CoalesceGoal.satisfies(flaggedProject.outputBatching, TargetSize(1L)))
+ }
+}
diff --git a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/shims/ProtobufExprShimsSuite.scala b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/shims/ProtobufExprShimsSuite.scala
new file mode 100644
index 00000000000..1ab2c16cdf6
--- /dev/null
+++ b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/shims/ProtobufExprShimsSuite.scala
@@ -0,0 +1,451 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims
+
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.GetArrayStructFields
+import org.apache.spark.sql.rapids.{
+ GpuFromProtobuf,
+ GpuGetArrayStructFieldsMeta,
+ GpuStructFieldOrdinalTag
+}
+import org.apache.spark.sql.rapids.protobuf._
+import org.apache.spark.sql.types._
+
+class ProtobufExprShimsSuite extends AnyFunSuite {
+ private val outputSchema = StructType(Seq(
+ StructField("id", IntegerType, nullable = true),
+ StructField("name", StringType, nullable = true)))
+
+ private case class FakeExprChild() extends Expression {
+ override def children: Seq[Expression] = Nil
+ override def nullable: Boolean = true
+ override def dataType: DataType = BinaryType
+ override def eval(input: org.apache.spark.sql.catalyst.InternalRow): Any = null
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ throw new UnsupportedOperationException("not needed")
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): Expression = {
+ assert(newChildren.isEmpty)
+ this
+ }
+ }
+
+ private abstract class FakeBaseProtobufExpr(childExpr: Expression) extends UnaryExpression {
+ override def child: Expression = childExpr
+ override def nullable: Boolean = true
+ override def dataType: DataType = outputSchema
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ throw new UnsupportedOperationException("not needed")
+ override protected def withNewChildInternal(newChild: Expression): Expression = this
+ }
+
+ private case class FakePathProtobufExpr(override val child: Expression)
+ extends FakeBaseProtobufExpr(child) {
+ def messageName: String = "test.Message"
+ def descFilePath: Option[String] = Some("/tmp/test.desc")
+ def options: scala.collection.Map[String, String] = Map("mode" -> "FAILFAST")
+ }
+
+ private case class FakeBytesProtobufExpr(override val child: Expression)
+ extends FakeBaseProtobufExpr(child) {
+ def messageName: String = "test.Message"
+ def binaryDescriptorSet: Array[Byte] = Array[Byte](1, 2, 3)
+ def options: scala.collection.Map[String, String] =
+ Map("mode" -> "PERMISSIVE", "enums.as.ints" -> "true")
+ }
+
+ private case class FakeMissingOptionsExpr(override val child: Expression)
+ extends FakeBaseProtobufExpr(child) {
+ def messageName: String = "test.Message"
+ def descFilePath: Option[String] = Some("/tmp/test.desc")
+ }
+
+ private case class FakeDifferentMessageExpr(override val child: Expression)
+ extends FakeBaseProtobufExpr(child) {
+ def messageName: String = "test.OtherMessage"
+ def descFilePath: Option[String] = Some("/tmp/test.desc")
+ def options: scala.collection.Map[String, String] = Map("mode" -> "FAILFAST")
+ }
+
+ private case class FakeDifferentDescriptorExpr(override val child: Expression)
+ extends FakeBaseProtobufExpr(child) {
+ def messageName: String = "test.Message"
+ def descFilePath: Option[String] = Some("/tmp/other.desc")
+ def options: scala.collection.Map[String, String] = Map("mode" -> "FAILFAST")
+ }
+
+ private case class FakeDifferentOptionsExpr(override val child: Expression)
+ extends FakeBaseProtobufExpr(child) {
+ def messageName: String = "test.Message"
+ def descFilePath: Option[String] = Some("/tmp/test.desc")
+ def options: scala.collection.Map[String, String] = Map("mode" -> "PERMISSIVE")
+ }
+
+ private case class FakeTypedUnaryExpr(
+ dt: DataType,
+ override val child: Expression = FakeExprChild()) extends UnaryExpression {
+ override def nullable: Boolean = true
+ override def dataType: DataType = dt
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ throw new UnsupportedOperationException("not needed")
+ override protected def withNewChildInternal(newChild: Expression): Expression = copy(child =
+ newChild)
+ }
+
+ private object FakeSpark34ProtobufUtils {
+ def buildDescriptor(messageName: String, descFilePath: Option[String]): String =
+ s"$messageName:${descFilePath.getOrElse("none")}"
+ }
+
+ private object FakeSpark35ProtobufUtils {
+ def buildDescriptor(messageName: String, binaryFileDescriptorSet: Option[Array[Byte]]): String =
+ s"$messageName:${binaryFileDescriptorSet.map(_.mkString(",")).getOrElse("none")}"
+ }
+
+ private case class FakeMessageDescriptor(
+ syntax: String,
+ fields: Map[String, ProtobufFieldDescriptor]) extends ProtobufMessageDescriptor {
+ override def findField(name: String): Option[ProtobufFieldDescriptor] = fields.get(name)
+ }
+
+ private case class FakeFieldDescriptor(
+ name: String,
+ fieldNumber: Int,
+ protoTypeName: String,
+ isRepeated: Boolean = false,
+ isRequired: Boolean = false,
+ defaultValue: Option[ProtobufDefaultValue] = None,
+ defaultValueError: Option[String] = None,
+ enumMetadata: Option[ProtobufEnumMetadata] = None,
+ messageDescriptor: Option[ProtobufMessageDescriptor] = None) extends ProtobufFieldDescriptor {
+ override lazy val defaultValueResult: Either[String, Option[ProtobufDefaultValue]] =
+ defaultValueError match {
+ case Some(reason) => Left(reason)
+ case None => Right(defaultValue)
+ }
+ }
+
+ test("compat extracts descriptor path and options from legacy expression") {
+ val exprInfo = SparkProtobufCompat.extractExprInfo(FakePathProtobufExpr(FakeExprChild()))
+ assert(exprInfo.isRight)
+ val info = exprInfo.toOption.get
+ assert(info.messageName == "test.Message")
+ assert(info.options == Map("mode" -> "FAILFAST"))
+ assert(info.descriptorSource ==
+ ProtobufDescriptorSource.DescriptorPath("/tmp/test.desc"))
+ }
+
+ test("compat extracts binary descriptor source and planner options") {
+ val exprInfo = SparkProtobufCompat.extractExprInfo(FakeBytesProtobufExpr(FakeExprChild()))
+ assert(exprInfo.isRight)
+ val info = exprInfo.toOption.get
+ info.descriptorSource match {
+ case ProtobufDescriptorSource.DescriptorBytes(bytes) =>
+ assert(bytes.sameElements(Array[Byte](1, 2, 3)))
+ case other =>
+ fail(s"Unexpected descriptor source: $other")
+ }
+ val plannerOptions = SparkProtobufCompat.parsePlannerOptions(info.options)
+ assert(plannerOptions ==
+ Right(ProtobufPlannerOptions(enumsAsInts = true, failOnErrors = false)))
+ }
+
+ test("compat invokes Spark 3.4 descriptor builder with descriptor path") {
+ val buildMethod = FakeSpark34ProtobufUtils.getClass.getMethod(
+ "buildDescriptor", classOf[String], classOf[scala.Option[_]])
+
+ val result = SparkProtobufCompat.invokeBuildDescriptor(
+ buildMethod,
+ FakeSpark34ProtobufUtils,
+ "test.Message",
+ ProtobufDescriptorSource.DescriptorPath("/tmp/test.desc"),
+ _ => fail("path-to-bytes fallback should not be needed for Spark 3.4"))
+
+ assert(result == "test.Message:/tmp/test.desc")
+ }
+
+ test("compat retries descriptor path as bytes for Spark 3.5 descriptor builder") {
+ val buildMethod = FakeSpark35ProtobufUtils.getClass.getMethod(
+ "buildDescriptor", classOf[String], classOf[scala.Option[_]])
+ var readCalls = 0
+
+ val result = SparkProtobufCompat.invokeBuildDescriptor(
+ buildMethod,
+ FakeSpark35ProtobufUtils,
+ "test.Message",
+ ProtobufDescriptorSource.DescriptorPath("/tmp/test.desc"),
+ _ => {
+ readCalls += 1
+ Array[Byte](1, 2, 3)
+ })
+
+ assert(readCalls == 1)
+ assert(result == "test.Message:1,2,3")
+ }
+
+ test("compat passes bytes directly to Spark 3.5 descriptor builder") {
+ val buildMethod = FakeSpark35ProtobufUtils.getClass.getMethod(
+ "buildDescriptor", classOf[String], classOf[scala.Option[_]])
+
+ val result = SparkProtobufCompat.invokeBuildDescriptor(
+ buildMethod,
+ FakeSpark35ProtobufUtils,
+ "test.Message",
+ ProtobufDescriptorSource.DescriptorBytes(Array[Byte](4, 5, 6)),
+ _ => fail("binary descriptor source should not read a file"))
+
+ assert(result == "test.Message:4,5,6")
+ }
+
+ test("compat distinguishes decode semantics across message descriptor and options") {
+ val child = FakeExprChild()
+
+ assert(SparkProtobufCompat.sameDecodeSemantics(
+ FakePathProtobufExpr(child), FakePathProtobufExpr(child)))
+ assert(SparkProtobufCompat.sameDecodeSemantics(
+ FakeBytesProtobufExpr(child), FakeBytesProtobufExpr(child)))
+ assert(!SparkProtobufCompat.sameDecodeSemantics(
+ FakePathProtobufExpr(child), FakeDifferentMessageExpr(child)))
+ assert(!SparkProtobufCompat.sameDecodeSemantics(
+ FakePathProtobufExpr(child), FakeDifferentDescriptorExpr(child)))
+ assert(!SparkProtobufCompat.sameDecodeSemantics(
+ FakePathProtobufExpr(child), FakeDifferentOptionsExpr(child)))
+ }
+
+ test("compat reports missing options accessor as cpu fallback reason") {
+ val exprInfo = SparkProtobufCompat.extractExprInfo(FakeMissingOptionsExpr(FakeExprChild()))
+ assert(exprInfo.left.toOption.exists(
+ _.contains("Cannot read from_protobuf options via reflection")))
+ }
+
+ test("compat detects unsupported options and proto3 syntax") {
+ assert(SparkProtobufCompat.unsupportedOptions(Map("mode" -> "FAILFAST", "foo" -> "bar")) ==
+ Seq("foo"))
+ assert(!SparkProtobufCompat.isGpuSupportedProtoSyntax("PROTO3"))
+ assert(!SparkProtobufCompat.isGpuSupportedProtoSyntax("EDITIONS"))
+ assert(!SparkProtobufCompat.isGpuSupportedProtoSyntax(""))
+ assert(SparkProtobufCompat.isGpuSupportedProtoSyntax("PROTO2"))
+ }
+
+ test("extractor preserves typed enum defaults") {
+ val enumMeta = ProtobufEnumMetadata(Seq(
+ ProtobufEnumValue(0, "UNKNOWN"),
+ ProtobufEnumValue(1, "EN"),
+ ProtobufEnumValue(2, "ZH")))
+ val msgDesc = FakeMessageDescriptor(
+ syntax = "PROTO2",
+ fields = Map(
+ "language" -> FakeFieldDescriptor(
+ name = "language",
+ fieldNumber = 1,
+ protoTypeName = "ENUM",
+ defaultValue = Some(ProtobufDefaultValue.EnumValue(1, "EN")),
+ enumMetadata = Some(enumMeta))))
+ val schema = StructType(Seq(StructField("language", StringType, nullable = true)))
+
+ val infos = ProtobufSchemaExtractor.analyzeAllFields(
+ schema, msgDesc, enumsAsInts = false, "test.Message")
+
+ assert(infos.isRight)
+ assert(infos.toOption.get("language").defaultValue.contains(
+ ProtobufDefaultValue.EnumValue(1, "EN")))
+ }
+
+ test("extractor records reflection failures as unsupported field info") {
+ val msgDesc = FakeMessageDescriptor(
+ syntax = "PROTO2",
+ fields = Map(
+ "ok" -> FakeFieldDescriptor(
+ name = "ok",
+ fieldNumber = 1,
+ protoTypeName = "INT32"),
+ "id" -> FakeFieldDescriptor(
+ name = "id",
+ fieldNumber = 2,
+ protoTypeName = "INT32",
+ defaultValueError =
+ Some("Failed to read protobuf default value for field 'id': unsupported type"))))
+ val schema = StructType(Seq(
+ StructField("ok", IntegerType, nullable = true),
+ StructField("id", IntegerType, nullable = true)))
+
+ val infos = ProtobufSchemaExtractor.analyzeAllFields(
+ schema, msgDesc, enumsAsInts = true, "test.Message")
+
+ assert(infos.isRight)
+ assert(infos.toOption.get("ok").isSupported)
+ assert(!infos.toOption.get("id").isSupported)
+ assert(infos.toOption.get("id").unsupportedReason.exists(
+ _.contains("Failed to read protobuf default value for field 'id'")))
+ }
+
+ test("validator encodes enum-string defaults into both numeric and string payloads") {
+ val enumMeta = ProtobufEnumMetadata(Seq(
+ ProtobufEnumValue(0, "UNKNOWN"),
+ ProtobufEnumValue(1, "EN")))
+ val info = ProtobufFieldInfo(
+ fieldNumber = 2,
+ protoTypeName = "ENUM",
+ sparkType = StringType,
+ encoding = GpuFromProtobuf.ENC_ENUM_STRING,
+ isSupported = true,
+ unsupportedReason = None,
+ isRequired = false,
+ defaultValue = Some(ProtobufDefaultValue.EnumValue(1, "EN")),
+ enumMetadata = Some(enumMeta),
+ isRepeated = false)
+
+ val flat = ProtobufSchemaValidator.toFlattenedFieldDescriptor(
+ path = "common.language",
+ field = StructField("language", StringType, nullable = true),
+ fieldInfo = info,
+ parentIdx = 0,
+ depth = 1,
+ outputTypeId = 6)
+
+ assert(flat.isRight)
+ assert(flat.toOption.get.defaultInt == 1L)
+ assert(new String(flat.toOption.get.defaultString, "UTF-8") == "EN")
+ assert(flat.toOption.get.enumValidValues.sameElements(Array(0, 1)))
+ assert(flat.toOption.get.enumNames
+ .map(new String(_, "UTF-8"))
+ .sameElements(Array("UNKNOWN", "EN")))
+ }
+
+ test("validator rejects enum-string field without enum metadata") {
+ val info = ProtobufFieldInfo(
+ fieldNumber = 2,
+ protoTypeName = "ENUM",
+ sparkType = StringType,
+ encoding = GpuFromProtobuf.ENC_ENUM_STRING,
+ isSupported = true,
+ unsupportedReason = None,
+ isRequired = false,
+ defaultValue = Some(ProtobufDefaultValue.EnumValue(1, "EN")),
+ enumMetadata = None,
+ isRepeated = false)
+
+ val flat = ProtobufSchemaValidator.toFlattenedFieldDescriptor(
+ path = "common.language",
+ field = StructField("language", StringType, nullable = true),
+ fieldInfo = info,
+ parentIdx = 0,
+ depth = 1,
+ outputTypeId = 6)
+
+ assert(flat.left.toOption.exists(_.contains("missing enum metadata")))
+ }
+
+ test("validator returns Left for incompatible default type instead of throwing") {
+ val info = ProtobufFieldInfo(
+ fieldNumber = 3,
+ protoTypeName = "FLOAT",
+ sparkType = DoubleType,
+ encoding = GpuFromProtobuf.ENC_DEFAULT,
+ isSupported = true,
+ unsupportedReason = None,
+ isRequired = false,
+ defaultValue = Some(ProtobufDefaultValue.FloatValue(1.5f)),
+ enumMetadata = None,
+ isRepeated = false)
+
+ val flat = ProtobufSchemaValidator.toFlattenedFieldDescriptor(
+ path = "common.score",
+ field = StructField("score", DoubleType, nullable = true),
+ fieldInfo = info,
+ parentIdx = 0,
+ depth = 1,
+ outputTypeId = 6)
+
+ assert(flat.left.toOption.exists(
+ _.contains("Incompatible default value for protobuf field 'common.score'")))
+ }
+
+ test("array struct field meta uses pruned child field count after ordinal remap") {
+ val originalStruct = StructType(Seq(
+ StructField("a", IntegerType, nullable = true),
+ StructField("b", IntegerType, nullable = true),
+ StructField("c", IntegerType, nullable = true)))
+ val prunedStruct = StructType(Seq(StructField("b", IntegerType, nullable = true)))
+ val originalChild = FakeTypedUnaryExpr(ArrayType(originalStruct, containsNull = true))
+ val sparkExpr = GetArrayStructFields(
+ child = originalChild,
+ field = originalStruct.fields(1),
+ ordinal = 1,
+ numFields = originalStruct.fields.length,
+ containsNull = true)
+ sparkExpr.setTagValue(GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG, 0)
+
+ val prunedChild = FakeTypedUnaryExpr(ArrayType(prunedStruct, containsNull = true))
+ val runtimeOrd = sparkExpr.getTagValue(GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG).get
+
+ assert(runtimeOrd == 0)
+ assert(
+ GpuGetArrayStructFieldsMeta.effectiveNumFields(prunedChild, sparkExpr, runtimeOrd) == 1)
+ }
+
+ test("GpuFromProtobuf semantic equality is content-based for schema arrays") {
+ def emptyEnumNames: Array[Array[Byte]] = Array.empty[Array[Byte]]
+
+ val expr1 = GpuFromProtobuf(
+ decodedSchema = outputSchema,
+ fieldNumbers = Array(1, 2),
+ parentIndices = Array(-1, -1),
+ depthLevels = Array(0, 0),
+ wireTypes = Array(0, 2),
+ outputTypeIds = Array(3, 6),
+ encodings = Array(0, 0),
+ isRepeated = Array(false, false),
+ isRequired = Array(false, false),
+ hasDefaultValue = Array(false, false),
+ defaultInts = Array(0L, 0L),
+ defaultFloats = Array(0.0, 0.0),
+ defaultBools = Array(false, false),
+ defaultStrings = Array(Array.emptyByteArray, Array.emptyByteArray),
+ enumValidValues = Array(Array.emptyIntArray, Array.emptyIntArray),
+ enumNames = Array(emptyEnumNames, emptyEnumNames),
+ failOnErrors = true,
+ child = FakeExprChild())
+
+ val expr2 = GpuFromProtobuf(
+ decodedSchema = outputSchema,
+ fieldNumbers = Array(1, 2),
+ parentIndices = Array(-1, -1),
+ depthLevels = Array(0, 0),
+ wireTypes = Array(0, 2),
+ outputTypeIds = Array(3, 6),
+ encodings = Array(0, 0),
+ isRepeated = Array(false, false),
+ isRequired = Array(false, false),
+ hasDefaultValue = Array(false, false),
+ defaultInts = Array(0L, 0L),
+ defaultFloats = Array(0.0, 0.0),
+ defaultBools = Array(false, false),
+ defaultStrings = Array(Array.emptyByteArray, Array.emptyByteArray),
+ enumValidValues = Array(Array.emptyIntArray, Array.emptyIntArray),
+ enumNames = Array(emptyEnumNames.map(identity), emptyEnumNames.map(identity)),
+ failOnErrors = true,
+ child = FakeExprChild())
+
+ assert(expr1.semanticEquals(expr2))
+ assert(expr1.semanticHash() == expr2.semanticHash())
+ }
+}