diff --git a/delta-lake/common/src/main/delta-33x-40x/scala/com/nvidia/spark/rapids/delta/common/DeltaProviderBase.scala b/delta-lake/common/src/main/delta-33x-40x/scala/com/nvidia/spark/rapids/delta/common/DeltaProviderBase.scala index 36c0eb6ea5d..5a98329e3b0 100644 --- a/delta-lake/common/src/main/delta-33x-40x/scala/com/nvidia/spark/rapids/delta/common/DeltaProviderBase.scala +++ b/delta-lake/common/src/main/delta-33x-40x/scala/com/nvidia/spark/rapids/delta/common/DeltaProviderBase.scala @@ -185,7 +185,7 @@ abstract class DeltaProviderBase extends DeltaIOProvider { // case dvRoot @ GpuProjectExec(outputList, dvFilter @ GpuFilterExec(condition, - dvFilterInput @ GpuProjectExec(inputList, fsse: GpuFileSourceScanExec, _)), _) + dvFilterInput @ GpuProjectExec(inputList, fsse: GpuFileSourceScanExec, _, _)), _, _) if condition.references.exists(_.name == IS_ROW_DELETED_COLUMN_NAME) && !outputList.exists(_.name == "_metadata") && inputList.exists(_.name == "_metadata") => dvRoot.withNewChildren(Seq( @@ -256,7 +256,7 @@ object DVPredicatePushdown extends ShimPredicateHelper { def pruneIsRowDeletedColumn(plan: SparkPlan): SparkPlan = { plan.transformUp { - case project @ GpuProjectExec(projectList, _, _) => + case project @ GpuProjectExec(projectList, _, _, _) => val newProjList = projectList.filterNot(isRowDeletedColumnRef(_)) project.copy(projectList = newProjList) case fsse: GpuFileSourceScanExec => @@ -307,11 +307,16 @@ object DVPredicatePushdown extends ShimPredicateHelper { def mergeIdenticalProjects(plan: SparkPlan): SparkPlan = { plan.transformUp { case p @ GpuProjectExec(projList1, - GpuProjectExec(projList2, child, enablePreSplit1), enablePreSplit2) => + GpuProjectExec(projList2, child, enablePreSplit1, forcePostProjectCoalesce1), + enablePreSplit2, forcePostProjectCoalesce2) => val projSet1 = projList1.map(_.exprId).toSet val projSet2 = projList2.map(_.exprId).toSet if (projSet1 == projSet2) { - GpuProjectExec(projList1, child, enablePreSplit1 && 
enablePreSplit2) + GpuProjectExec( + projList1, + child, + enablePreSplit1 && enablePreSplit2, + forcePostProjectCoalesce1 || forcePostProjectCoalesce2) } else { p } diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 888f4fa807c..196f518c6a7 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -8267,7 +8267,7 @@ are limited. -PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT, DAYTIME, YEARMONTH
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT, DAYTIME, YEARMONTH
@@ -8290,7 +8290,7 @@ are limited. -PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT, DAYTIME, YEARMONTH
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT, DAYTIME, YEARMONTH
diff --git a/integration_tests/run_pyspark_from_build.sh b/integration_tests/run_pyspark_from_build.sh index a619e6c66ae..5a7db886da1 100755 --- a/integration_tests/run_pyspark_from_build.sh +++ b/integration_tests/run_pyspark_from_build.sh @@ -46,6 +46,9 @@ # To run all tests, including Avro tests: # INCLUDE_SPARK_AVRO_JAR=true ./run_pyspark_from_build.sh # +# To run tests WITHOUT Protobuf tests (protobuf is included by default): +# INCLUDE_SPARK_PROTOBUF_JAR=false ./run_pyspark_from_build.sh +# # To run a specific test: # TEST=my_test ./run_pyspark_from_build.sh # @@ -141,9 +144,82 @@ else AVRO_JARS="" fi - # ALL_JARS includes dist.jar integration-test.jar avro.jar parquet.jar if they exist + # Protobuf support: Include spark-protobuf jar by default for protobuf_test.py + # Set INCLUDE_SPARK_PROTOBUF_JAR=false to disable + PROTOBUF_JARS="" + if [[ $( echo ${INCLUDE_SPARK_PROTOBUF_JAR} | tr '[:upper:]' '[:lower:]' ) != "false" ]]; + then + export INCLUDE_SPARK_PROTOBUF_JAR=true + mkdir -p "${TARGET_DIR}/dependency" + + # Download spark-protobuf jar if not already in target/dependency + PROTOBUF_JAR_NAME="spark-protobuf_${SCALA_VERSION}-${VERSION_STRING}.jar" + PROTOBUF_JAR_PATH="${TARGET_DIR}/dependency/${PROTOBUF_JAR_NAME}" + + if [[ ! -f "$PROTOBUF_JAR_PATH" ]]; then + echo "Downloading spark-protobuf jar..." + PROTOBUF_MAVEN_URL="https://repo1.maven.org/maven2/org/apache/spark/spark-protobuf_${SCALA_VERSION}/${VERSION_STRING}/${PROTOBUF_JAR_NAME}" + if curl -fsL -o "$PROTOBUF_JAR_PATH" "$PROTOBUF_MAVEN_URL"; then + echo "Downloaded spark-protobuf jar to $PROTOBUF_JAR_PATH" + else + echo "WARNING: Failed to download spark-protobuf jar from $PROTOBUF_MAVEN_URL" + rm -f "$PROTOBUF_JAR_PATH" + fi + fi + + # Also download protobuf-java jar (required dependency). + # Detect version from the jar bundled with Spark, fall back to version mapping. 
+ PROTOBUF_JAVA_VERSION="" + BUNDLED_PB_JAR=$(ls "$SPARK_HOME"/jars/protobuf-java-[0-9]*.jar 2>/dev/null | head -1) + if [[ -n "$BUNDLED_PB_JAR" ]]; then + PROTOBUF_JAVA_VERSION=$(basename "$BUNDLED_PB_JAR" | sed 's/protobuf-java-\(.*\)\.jar/\1/') + echo "Detected protobuf-java version $PROTOBUF_JAVA_VERSION from SPARK_HOME" + fi + if [[ -z "$PROTOBUF_JAVA_VERSION" ]]; then + case "$VERSION_STRING" in + 3.4.*) PROTOBUF_JAVA_VERSION="3.25.1" ;; + 3.5.*) PROTOBUF_JAVA_VERSION="3.25.1" ;; + 4.0.*) PROTOBUF_JAVA_VERSION="4.29.3" ;; + *) PROTOBUF_JAVA_VERSION="3.25.1" ;; + esac + echo "Using protobuf-java version $PROTOBUF_JAVA_VERSION based on Spark $VERSION_STRING" + fi + PROTOBUF_JAVA_JAR_NAME="protobuf-java-${PROTOBUF_JAVA_VERSION}.jar" + PROTOBUF_JAVA_JAR_PATH="${TARGET_DIR}/dependency/${PROTOBUF_JAVA_JAR_NAME}" + + if [[ ! -f "$PROTOBUF_JAVA_JAR_PATH" ]]; then + echo "Downloading protobuf-java jar..." + PROTOBUF_JAVA_MAVEN_URL="https://repo1.maven.org/maven2/com/google/protobuf/protobuf-java/${PROTOBUF_JAVA_VERSION}/${PROTOBUF_JAVA_JAR_NAME}" + if curl -fsL -o "$PROTOBUF_JAVA_JAR_PATH" "$PROTOBUF_JAVA_MAVEN_URL"; then + echo "Downloaded protobuf-java jar to $PROTOBUF_JAVA_JAR_PATH" + else + echo "WARNING: Failed to download protobuf-java jar from $PROTOBUF_JAVA_MAVEN_URL" + rm -f "$PROTOBUF_JAVA_JAR_PATH" + fi + fi + + if [[ -f "$PROTOBUF_JAR_PATH" ]]; then + PROTOBUF_JARS="$PROTOBUF_JAR_PATH" + echo "Including spark-protobuf jar: $PROTOBUF_JAR_PATH" + fi + if [[ -f "$PROTOBUF_JAVA_JAR_PATH" ]]; then + PROTOBUF_JARS="${PROTOBUF_JARS:+$PROTOBUF_JARS }$PROTOBUF_JAVA_JAR_PATH" + echo "Including protobuf-java jar: $PROTOBUF_JAVA_JAR_PATH" + fi + # Also add protobuf jars to driver classpath for Class.forName() to work + # This is needed because --jars only adds to executor classpath + if [[ -n "$PROTOBUF_JARS" ]]; then + PROTOBUF_DRIVER_CP=$(echo "$PROTOBUF_JARS" | tr ' ' ':') + export 
PYSP_TEST_spark_driver_extraClassPath="${PYSP_TEST_spark_driver_extraClassPath:+${PYSP_TEST_spark_driver_extraClassPath}:}${PROTOBUF_DRIVER_CP}" + echo "Added protobuf jars to driver classpath" + fi + else + export INCLUDE_SPARK_PROTOBUF_JAR=false + fi + + # ALL_JARS includes dist.jar integration-test.jar avro.jar parquet.jar protobuf.jar if they exist # Remove non-existing paths and canonicalize the paths including get rid of links and `..` - ALL_JARS=$(readlink -e $PLUGIN_JAR $TEST_JARS $AVRO_JARS $PARQUET_HADOOP_TESTS || true) + ALL_JARS=$(readlink -e $PLUGIN_JAR $TEST_JARS $AVRO_JARS $PARQUET_HADOOP_TESTS $PROTOBUF_JARS || true) # `:` separated jars ALL_JARS="${ALL_JARS//$'\n'/:}" @@ -411,6 +487,7 @@ else export PYSP_TEST_spark_gluten_loadLibFromJar=true fi + SPARK_SHELL_SMOKE_TEST="${SPARK_SHELL_SMOKE_TEST:-0}" EXPLAIN_ONLY_CPU_SMOKE_TEST="${EXPLAIN_ONLY_CPU_SMOKE_TEST:-0}" SPARK_CONNECT_SMOKE_TEST="${SPARK_CONNECT_SMOKE_TEST:-0}" diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py index fa7decac82d..d6583e3fd00 100644 --- a/integration_tests/src/main/python/data_gen.py +++ b/integration_tests/src/main/python/data_gen.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2025, NVIDIA CORPORATION. +# Copyright (c) 2020-2026, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -857,6 +857,389 @@ def gen_bytes(): return bytes([ rand.randint(0, 255) for _ in range(length) ]) self._start(rand, gen_bytes) + +# ----------------------------------------------------------------------------- +# Protobuf (simple types) generators/utilities (for from_protobuf/to_protobuf tests) +# ----------------------------------------------------------------------------- + +_PROTOBUF_WIRE_VARINT = 0 +_PROTOBUF_WIRE_64BIT = 1 +_PROTOBUF_WIRE_LEN_DELIM = 2 +_PROTOBUF_WIRE_32BIT = 5 + +def _encode_protobuf_uvarint(value): + """Encode a non-negative integer as protobuf varint.""" + if value is None: + raise ValueError("value must not be None") + if value < 0: + raise ValueError("uvarint only supports non-negative integers") + out = bytearray() + v = int(value) + while True: + b = v & 0x7F + v >>= 7 + if v: + out.append(b | 0x80) + else: + out.append(b) + break + return bytes(out) + +def _encode_protobuf_key(field_number, wire_type): + return _encode_protobuf_uvarint((int(field_number) << 3) | int(wire_type)) + +def _encode_protobuf_zigzag32(value): + """Encode a signed 32-bit integer using zigzag encoding (sint32).""" + return (value << 1) ^ (value >> 31) + +def _encode_protobuf_zigzag64(value): + """Encode a signed 64-bit integer using zigzag encoding (sint64).""" + return (value << 1) ^ (value >> 63) + +def _encode_protobuf_field(field_number, spark_type, value, encoding='default'): + """ + Encode a single protobuf field for a subset of scalar types. + + encoding: 'default', 'fixed', 'zigzag' - determines how integers are encoded + + Notes on signed ints: + - Protobuf `int32`/`int64` use *varint* encoding of the two's-complement integer. + - Negative `int32` values are encoded as a 10-byte varint (because they are sign-extended to 64 bits). + - `sint32`/`sint64` use zigzag encoding for efficient negative number storage. + - `fixed32`/`fixed64` use fixed-width little-endian encoding. 
+ """ + if value is None: + return b"" + + if isinstance(spark_type, BooleanType): + return _encode_protobuf_key(field_number, _PROTOBUF_WIRE_VARINT) + _encode_protobuf_uvarint(1 if value else 0) + elif isinstance(spark_type, IntegerType): + if encoding == 'fixed': + # fixed32 / sfixed32: 4-byte little-endian (mask handles both signed and unsigned) + return _encode_protobuf_key(field_number, _PROTOBUF_WIRE_32BIT) + struct.pack(" bool: + """ + `spark-protobuf` is an optional external module. PySpark may have the Python wrappers + even when the JVM side isn't present on the classpath, which manifests as: + TypeError: 'JavaPackage' object is not callable + when calling into `sc._jvm.org.apache.spark.sql.protobuf.functions.from_protobuf`. + """ + jvm = spark.sparkContext._jvm + candidates = [ + # Scala object `functions` compiles to `functions$` + "org.apache.spark.sql.protobuf.functions$", + # Some environments may expose it differently + "org.apache.spark.sql.protobuf.functions", + ] + for cls in candidates: + try: + jvm.java.lang.Class.forName(cls) + return True + except Exception: + continue + return False + + +# --------------------------------------------------------------------------- +# Shared fixture and helpers to reduce per-test boilerplate +# --------------------------------------------------------------------------- + +@pytest.fixture(scope="module") +def from_protobuf_fn(): + """Skip the entire module if from_protobuf or the JVM module is unavailable.""" + fn = _try_import_from_protobuf() + if fn is None: + pytest.skip("from_protobuf not available") + if not with_cpu_session(_spark_protobuf_jvm_available): + pytest.skip("spark-protobuf JVM not available") + return fn + + +def _setup_protobuf_desc(spark_tmp_path, desc_name, build_fn): + """Build descriptor bytes via JVM, write to HDFS, return (desc_path, desc_bytes).""" + desc_path = spark_tmp_path + "/" + desc_name + desc_bytes = with_cpu_session(build_fn) + with_cpu_session( + lambda spark: 
_write_bytes_to_hadoop_path(spark, desc_path, desc_bytes)) + return desc_path, desc_bytes + + +def _call_from_protobuf(from_protobuf_fn, col, message_name, + desc_path, desc_bytes, options=None): + """Call from_protobuf using the right API variant.""" + sig = inspect.signature(from_protobuf_fn) + if "binaryDescriptorSet" in sig.parameters: + kw = dict(binaryDescriptorSet=bytearray(desc_bytes)) + if options is not None: + kw["options"] = options + return from_protobuf_fn(col, message_name, **kw) + if options is not None: + return from_protobuf_fn(col, message_name, desc_path, options) + return from_protobuf_fn(col, message_name, desc_path) + + +def test_call_from_protobuf_preserves_options_for_legacy_signature(): + calls = [] + + def fake_from_protobuf(col, message_name, desc_path, *args): + calls.append((col, message_name, desc_path, args)) + return "ok" + + options = {"enums.as.ints": "true"} + result = _call_from_protobuf( + fake_from_protobuf, "col", "msg", "/tmp/test.desc", b"desc", options=options) + + assert result == "ok" + assert calls == [("col", "msg", "/tmp/test.desc", (options,))] + + +def test_encode_protobuf_packed_repeated_fixed_uses_unsigned_twos_complement(): + i32_encoded = _encode_protobuf_packed_repeated( + 1, IntegerType(), [0xFFFFFFFF], encoding='fixed') + i64_encoded = _encode_protobuf_packed_repeated( + 1, LongType(), [0xFFFFFFFFFFFFFFFF], encoding='fixed') + + assert i32_encoded == b"\x0a\x04" + struct.pack(" 1, 1 -> 2, -2 -> 3, 2 -> 4, etc. 
+ """ + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "signed.desc", _build_signed_int_descriptor_set_bytes) + message_name = "test.WithSignedInts" + + data_gen = ProtobufMessageGen([ + PbScalar("si32", 1, IntegerGen( + special_cases=[-1, 0, 1, -2147483648, 2147483647]), encoding='zigzag'), + PbScalar("si64", 2, LongGen( + special_cases=[-1, 0, 1, -9223372036854775808, 9223372036854775807]), + encoding='zigzag'), + PbScalar("sf32", 3, IntegerGen( + special_cases=[0, 1, -1, 2147483647, -2147483648]), encoding='fixed'), + PbScalar("sf64", 4, LongGen( + special_cases=[0, 1, -1]), encoding='fixed'), + ]) + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + + return df.select( + decoded.getField("si32").alias("si32"), + decoded.getField("si64").alias("si64"), + decoded.getField("sf32").alias("sf32"), + decoded.getField("sf64").alias("sf64"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +def _build_fixed_int_descriptor_set_bytes(spark): + """ + Build a FileDescriptorSet for message with fixed-width integer types: + message WithFixedInts { + optional fixed32 fx32 = 1; + optional fixed64 fx64 = 2; + } + """ + D, fd = _new_proto2_file(spark, "fixed_int.proto") + + label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL + + msg = D.DescriptorProto.newBuilder().setName("WithFixedInts") + msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("fx32") + .setNumber(1) + .setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_FIXED32) + .build() + ) + msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("fx64") + .setNumber(2) + .setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_FIXED64) + .build() + ) + fd.addMessageType(msg.build()) + + fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build() + return bytes(fds.toByteArray()) + + +@pytest.mark.skipif(is_before_spark_340(), 
reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_fixed_integers(spark_tmp_path, from_protobuf_fn): + """ + Test decoding fixed-width unsigned integer types. + """ + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "fixed.desc", _build_fixed_int_descriptor_set_bytes) + message_name = "test.WithFixedInts" + + data_gen = ProtobufMessageGen([ + PbScalar("fx32", 1, IntegerGen( + special_cases=[0, 1, -1, 2147483647, -2147483648]), encoding='fixed'), + PbScalar("fx64", 2, LongGen( + special_cases=[0, 1, -1]), encoding='fixed'), + ]) + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + return df.select( + decoded.getField("fx32").alias("fx32"), + decoded.getField("fx64").alias("fx64"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +def _build_schema_projection_descriptor_set_bytes(spark): + """ + Build a FileDescriptorSet for nested schema projection testing: + message Detail { + optional int32 a = 1; + optional int32 b = 2; + optional string c = 3; + } + message SchemaProj { + optional int32 id = 1; + optional string name = 2; + optional Detail detail = 3; + repeated Detail items = 4; + } + The Detail message has 3 fields so we can test pruning subsets. 
+ """ + D, fd = _new_proto2_file(spark, "schema_proj.proto") + + label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL + label_rep = D.FieldDescriptorProto.Label.LABEL_REPEATED + + # Detail message: { a: int32, b: int32, c: string } + detail_msg = D.DescriptorProto.newBuilder().setName("Detail") + detail_msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("a").setNumber(1).setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_INT32).build()) + detail_msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("b").setNumber(2).setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_INT32).build()) + detail_msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("c").setNumber(3).setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_STRING).build()) + fd.addMessageType(detail_msg.build()) + + # SchemaProj message: { id, name, detail: Detail, items: repeated Detail } + main_msg = D.DescriptorProto.newBuilder().setName("SchemaProj") + main_msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("id").setNumber(1).setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_INT32).build()) + main_msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("name").setNumber(2).setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_STRING).build()) + main_msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("detail").setNumber(3).setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE) + .setTypeName(".test.Detail").build()) + main_msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("items").setNumber(4).setLabel(label_rep) + .setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE) + .setTypeName(".test.Detail").build()) + fd.addMessageType(main_msg.build()) + + fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build() + return bytes(fds.toByteArray()) + + +# Field descriptors for SchemaProj: {id, name, detail: {a, b, c}, items[]: {a, b, c}} +_detail_children 
= [ + PbScalar("a", 1, IntegerGen()), + PbScalar("b", 2, IntegerGen()), + PbScalar("c", 3, StringGen()), +] +_schema_proj_fields = [ + PbScalar("id", 1, IntegerGen()), + PbScalar("name", 2, StringGen()), + PbNested("detail", 3, _detail_children), + PbRepeatedMessage("items", 4, _detail_children), +] + +_schema_proj_test_data = [ + encode_pb_message(_schema_proj_fields, + [1, "alice", (10, 20, "d1"), [(100, 200, "i1"), (101, 201, "i2")]]), + encode_pb_message(_schema_proj_fields, + [2, "bob", (30, 40, "d2"), [(300, 400, "i3")]]), + encode_pb_message(_schema_proj_fields, + [3, "carol", (50, 60, "d3"), []]), +] + + +_schema_proj_cases = [ + ("nested_single_field", [("id", ("id",)), ("detail_a", ("detail", "a"))]), + ("nested_two_fields", [("detail_a", ("detail", "a")), ("detail_c", ("detail", "c"))]), + ("whole_struct_no_pruning", [("id", ("id",)), ("detail", ("detail",))]), + ("whole_and_subfield", [("detail", ("detail",)), ("detail_a", ("detail", "a"))]), + ("scalar_plus_nested", [("id", ("id",)), ("name", ("name",)), ("detail_a", ("detail", "a"))]), + ("repeated_msg_single_subfield", [("id", ("id",)), ("items_a", ("items", "a"))]), + ("repeated_msg_two_subfields", [("items_a", ("items", "a")), ("items_c", ("items", "c"))]), + ("repeated_whole_no_pruning", [("id", ("id",)), ("items", ("items",))]), + ("mix_struct_and_repeated", [("id", ("id",)), ("detail_a", ("detail", "a")), ("items_c", ("items", "c"))]), +] + + +def _get_field_by_path(expr, path): + current = expr + for name in path: + current = current.getField(name) + return current + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_projection_across_alias_project_boundary( + spark_tmp_path, from_protobuf_fn): + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "schema_proj_alias.desc", + _build_schema_projection_descriptor_set_bytes) + message_name = "test.SchemaProj" + + def run_on_spark(spark): + df = 
spark.createDataFrame([(row,) for row in _schema_proj_test_data], schema="bin binary") + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + aliased = df.select(decoded.alias("decoded")) + return aliased.select( + f.col("decoded").getField("detail").getField("a").alias("detail_a"), + f.col("decoded").getField("id").alias("id")) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_projection_across_withcolumn_boundary( + spark_tmp_path, from_protobuf_fn): + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "schema_proj_withcolumn.desc", + _build_schema_projection_descriptor_set_bytes) + message_name = "test.SchemaProj" + + def run_on_spark(spark): + df = spark.createDataFrame([(row,) for row in _schema_proj_test_data], schema="bin binary") + with_decoded = df.withColumn( + "decoded", + _call_from_protobuf(from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)) + return with_decoded.select( + f.col("decoded").getField("items").getField("a").alias("items_a"), + f.col("decoded").getField("id").alias("id")) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +def _build_dual_message_projection_descriptor_set_bytes(spark): + """ + message BytesView { + optional int32 status = 1; + optional bytes payload = 2; + } + message NestedPayload { + optional int32 count = 1; + } + message NestedView { + optional int32 status = 1; + optional NestedPayload payload = 2; + } + """ + D, fd = _new_proto2_file(spark, "dual_projection.proto") + label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL + + nested_msg = D.DescriptorProto.newBuilder().setName("NestedPayload") + nested_msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("count").setNumber(1).setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_INT32).build()) + 
fd.addMessageType(nested_msg.build()) + + bytes_view = D.DescriptorProto.newBuilder().setName("BytesView") + bytes_view.addField( + D.FieldDescriptorProto.newBuilder() + .setName("status").setNumber(1).setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_INT32).build()) + bytes_view.addField( + D.FieldDescriptorProto.newBuilder() + .setName("payload").setNumber(2).setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_BYTES).build()) + fd.addMessageType(bytes_view.build()) + + nested_view = D.DescriptorProto.newBuilder().setName("NestedView") + nested_view.addField( + D.FieldDescriptorProto.newBuilder() + .setName("status").setNumber(1).setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_INT32).build()) + nested_view.addField( + D.FieldDescriptorProto.newBuilder() + .setName("payload").setNumber(2).setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE) + .setTypeName(".test.NestedPayload").build()) + fd.addMessageType(nested_view.build()) + + fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build() + return bytes(fds.toByteArray()) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_different_messages_same_binary_column_do_not_interfere( + spark_tmp_path, from_protobuf_fn): + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "dual_projection.desc", + _build_dual_message_projection_descriptor_set_bytes) + + payload_keep = _encode_tag(1, 0) + _encode_varint(7) + payload_drop = _encode_tag(1, 0) + _encode_varint(9) + row_keep = (_encode_tag(1, 0) + _encode_varint(1) + + _encode_tag(2, 2) + _encode_varint(len(payload_keep)) + payload_keep) + row_drop = (_encode_tag(1, 0) + _encode_varint(0) + + _encode_tag(2, 2) + _encode_varint(len(payload_drop)) + payload_drop) + + def run_on_spark(spark): + df = spark.createDataFrame([(row_keep,), (row_drop,)], schema="bin binary") + bytes_view = _call_from_protobuf( + 
from_protobuf_fn, f.col("bin"), "test.BytesView", desc_path, desc_bytes) + nested_view = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), "test.NestedView", desc_path, desc_bytes) + return df.filter(bytes_view.getField("status") == 1).select( + nested_view.getField("payload").getField("count").alias("payload_count")) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +def _build_deep_nested_5_level_descriptor_set_bytes(spark): + """ + Build a FileDescriptorSet for 5-level deep nesting: + message Level5 { optional int32 val5 = 1; } + message Level4 { optional int32 val4 = 1; optional Level5 level5 = 2; } + message Level3 { optional int32 val3 = 1; optional Level4 level4 = 2; } + message Level2 { optional int32 val2 = 1; optional Level3 level3 = 2; } + message Level1 { optional int32 val1 = 1; optional Level2 level2 = 2; } + """ + D, fd = _new_proto2_file(spark, "deep_nested_5_level.proto") + + label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL + + # Level5 message + level5_msg = D.DescriptorProto.newBuilder().setName("Level5") + level5_msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("val5") + .setNumber(1) + .setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_INT32) + .build() + ) + fd.addMessageType(level5_msg.build()) + + # Level4 message + level4_msg = D.DescriptorProto.newBuilder().setName("Level4") + level4_msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("val4") + .setNumber(1) + .setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_INT32) + .build() + ) + level4_msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("level5") + .setNumber(2) + .setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE) + .setTypeName(".test.Level5") + .build() + ) + fd.addMessageType(level4_msg.build()) + + # Level3 message + level3_msg = D.DescriptorProto.newBuilder().setName("Level3") + level3_msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("val3") + .setNumber(1) + 
.setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_INT32) + .build() + ) + level3_msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("level4") + .setNumber(2) + .setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE) + .setTypeName(".test.Level4") + .build() + ) + fd.addMessageType(level3_msg.build()) + + # Level2 message + level2_msg = D.DescriptorProto.newBuilder().setName("Level2") + level2_msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("val2") + .setNumber(1) + .setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_INT32) + .build() + ) + level2_msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("level3") + .setNumber(2) + .setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE) + .setTypeName(".test.Level3") + .build() + ) + fd.addMessageType(level2_msg.build()) + + # Level1 message + level1_msg = D.DescriptorProto.newBuilder().setName("Level1") + level1_msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("val1") + .setNumber(1) + .setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_INT32) + .build() + ) + level1_msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("level2") + .setNumber(2) + .setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE) + .setTypeName(".test.Level2") + .build() + ) + fd.addMessageType(level1_msg.build()) + + fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build() + return bytes(fds.toByteArray()) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_deep_nesting_5_levels(spark_tmp_path, from_protobuf_fn): + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "deep_nested_5_level.desc", + _build_deep_nested_5_level_descriptor_set_bytes) + message_name = "test.Level1" + data_gen = ProtobufMessageGen([ + PbScalar("val1", 1, IntegerGen()), + PbNested("level2", 2, [ + PbScalar("val2", 
1, IntegerGen()), + PbNested("level3", 2, [ + PbScalar("val3", 1, IntegerGen()), + PbNested("level4", 2, [ + PbScalar("val4", 1, IntegerGen()), + PbNested("level5", 2, [ + PbScalar("val5", 1, IntegerGen()), + ]), + ]), + ]), + ]), + ]) + def run_on_spark(spark): + df = gen_df(spark, data_gen) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + return df.select( + decoded.getField("val1").alias("val1"), + decoded.getField("level2").alias("level2"), + ) + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@pytest.mark.parametrize("case_id,select_specs", _schema_proj_cases, ids=lambda c: c[0] if isinstance(c, tuple) else str(c)) +@ignore_order(local=True) +def test_from_protobuf_schema_projection_cases( + spark_tmp_path, from_protobuf_fn, case_id, select_specs): + """Parametrized nested-schema projection tests.""" + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "schema_proj.desc", _build_schema_projection_descriptor_set_bytes) + message_name = "test.SchemaProj" + + def run_on_spark(spark): + df = spark.createDataFrame( + [(d,) for d in _schema_proj_test_data], schema="bin binary") + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + selected = [_get_field_by_path(decoded, path).alias(alias) + for alias, path in select_specs] + return df.select(*selected) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + +def _build_name_collision_descriptor_set_bytes(spark): + D, fd = _new_proto2_file(spark, "name_collision.proto") + label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL + + # User message + user_msg = D.DescriptorProto.newBuilder().setName("User") + user_msg.addField(D.FieldDescriptorProto.newBuilder().setName("age").setNumber(1).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_INT32).build()) + 
user_msg.addField(D.FieldDescriptorProto.newBuilder().setName("id").setNumber(2).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_INT32).build()) + fd.addMessageType(user_msg.build()) + + # Ad message + ad_msg = D.DescriptorProto.newBuilder().setName("Ad") + ad_msg.addField(D.FieldDescriptorProto.newBuilder().setName("id").setNumber(1).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_INT32).build()) + fd.addMessageType(ad_msg.build()) + + # Event message + event_msg = D.DescriptorProto.newBuilder().setName("Event") + event_msg.addField(D.FieldDescriptorProto.newBuilder().setName("user_info").setNumber(1).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE).setTypeName(".test.User").build()) + event_msg.addField(D.FieldDescriptorProto.newBuilder().setName("ad_info").setNumber(2).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE).setTypeName(".test.Ad").build()) + fd.addMessageType(event_msg.build()) + + fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build() + return bytes(fds.toByteArray()) + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_bug1_name_collision(spark_tmp_path, from_protobuf_fn): + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "name_collision.desc", + _build_name_collision_descriptor_set_bytes) + message_name = "test.Event" + + data_gen = ProtobufMessageGen([ + PbNested("user_info", 1, [ + PbScalar("age", 1, IntegerGen()), + PbScalar("id", 2, IntegerGen()), + ]), + PbNested("ad_info", 2, [ + PbScalar("id", 1, IntegerGen()), + ]), + ]) + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + + return df.select( + decoded.getField("user_info").getField("age").alias("age"), + decoded.getField("user_info").getField("id").alias("user_id"), + 
decoded.getField("ad_info").getField("id").alias("ad_id") + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +def _build_filter_jump_descriptor_set_bytes(spark): + D, fd = _new_proto2_file(spark, "filter_jump.proto") + label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL + + msg = D.DescriptorProto.newBuilder().setName("Event") + msg.addField(D.FieldDescriptorProto.newBuilder().setName("status").setNumber(1).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_INT32).build()) + msg.addField(D.FieldDescriptorProto.newBuilder().setName("ad_info").setNumber(2).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_STRING).build()) + fd.addMessageType(msg.build()) + + fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build() + return bytes(fds.toByteArray()) + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_bug2_filter_jump(spark_tmp_path, from_protobuf_fn): + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "filter_jump.desc", + _build_filter_jump_descriptor_set_bytes) + message_name = "test.Event" + + data_gen = ProtobufMessageGen([ + PbScalar("status", 1, IntegerGen(min_val=1, max_val=1)), + PbScalar("ad_info", 2, StringGen()), + ]) + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + pb_expr1 = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + pb_expr2 = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + + return df.filter(pb_expr1.getField("status") == 1).select(pb_expr2.getField("ad_info").alias("ad_info")) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +def _build_unrelated_struct_name_collision_descriptor_set_bytes(spark): + D, fd = _new_proto2_file(spark, "unrelated_struct.proto") + label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL + + # Nested message + nested_msg = 
D.DescriptorProto.newBuilder().setName("Nested") + nested_msg.addField(D.FieldDescriptorProto.newBuilder().setName("dummy").setNumber(1).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_INT32).build()) + nested_msg.addField(D.FieldDescriptorProto.newBuilder().setName("winfoid").setNumber(2).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_INT32).build()) + fd.addMessageType(nested_msg.build()) + + msg = D.DescriptorProto.newBuilder().setName("Event") + msg.addField(D.FieldDescriptorProto.newBuilder().setName("ad_info").setNumber(1).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE).setTypeName(".test.Nested").build()) + fd.addMessageType(msg.build()) + + fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build() + return bytes(fds.toByteArray()) + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_bug3_unrelated_struct_name_collision(spark_tmp_path, from_protobuf_fn): + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "unrelated_struct.desc", + _build_unrelated_struct_name_collision_descriptor_set_bytes) + message_name = "test.Event" + + data_gen = ProtobufMessageGen([ + PbNested("ad_info", 1, [ + PbScalar("dummy", 1, IntegerGen()), + PbScalar("winfoid", 2, IntegerGen()), + ]), + ]) + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + # Write to parquet to prevent Catalyst from optimizing away the GetStructField, + # and to ensure it runs on the GPU. + df_with_other = df.withColumn("other_struct", + f.struct(f.lit("hello").alias("dummy_str"), f.lit(42).alias("winfoid"))) + + path = spark_tmp_path + "/bug3_data.parquet" + df_with_other.write.mode("overwrite").parquet(path) + read_df = spark.read.parquet(path) + + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + + # We only select decoded.ad_info.winfoid, so dummy is pruned. 
+ # winfoid gets ordinal 0 in the pruned schema. + # But for other_struct, winfoid is ordinal 1. + # GpuGetStructFieldMeta will see "winfoid", query the ThreadLocal, get 0, + # and extract ordinal 0 ("hello") for other_winfoid, causing a mismatch! + return read_df.select( + decoded.getField("ad_info").getField("winfoid").alias("pb_winfoid"), + f.col("other_struct").getField("winfoid").alias("other_winfoid") + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +def _build_max_depth_descriptor_set_bytes(spark): + D, fd = _new_proto2_file(spark, "max_depth.proto") + label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL + + # Generate 12 levels of nesting + for i in range(12, 0, -1): + msg = D.DescriptorProto.newBuilder().setName(f"Level{i}") + msg.addField(D.FieldDescriptorProto.newBuilder().setName(f"val{i}").setNumber(1).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_INT32).build()) + if i < 12: + msg.addField(D.FieldDescriptorProto.newBuilder().setName(f"level{i+1}").setNumber(2).setLabel(label_opt).setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE).setTypeName(f".test.Level{i+1}").build()) + fd.addMessageType(msg.build()) + + fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build() + return bytes(fds.toByteArray()) + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_bug4_max_depth(spark_tmp_path, from_protobuf_fn): + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "max_depth.desc", + _build_max_depth_descriptor_set_bytes) + message_name = "test.Level1" + + # Build the deeply nested data gen spec + def build_nested_gen(level): + if level == 12: + return [PbScalar(f"val{level}", 1, IntegerGen())] + return [ + PbScalar(f"val{level}", 1, IntegerGen()), + PbNested(f"level{level+1}", 2, build_nested_gen(level+1)) + ] + + data_gen = ProtobufMessageGen(build_nested_gen(1)) + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + 
decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + # Deep access + field = decoded + for i in range(2, 13): + field = field.getField(f"level{i}") + return df.select(field.getField("val12").alias("val12")) + + # Depth 12 exceeds GPU max nesting depth (10), so the query should + # gracefully fall back to CPU. Verify that it still produces correct + # results (CPU path) without crashing. + from spark_session import with_cpu_session + cpu_result = with_cpu_session(lambda spark: run_on_spark(spark).collect()) + assert len(cpu_result) > 0 + + +# =========================================================================== +# Regression tests for known bugs found by code review +# =========================================================================== + +def _encode_varint(value): + """Encode a non-negative integer as a protobuf varint (for hand-crafting test bytes).""" + out = bytearray() + v = int(value) + while True: + b = v & 0x7F + v >>= 7 + if v: + out.append(b | 0x80) + else: + out.append(b) + break + return bytes(out) + + +def _encode_tag(field_number, wire_type): + return _encode_varint((field_number << 3) | wire_type) + + +# --------------------------------------------------------------------------- +# Bug 1: BOOL8 truncation — non-canonical bool varint values >= 256 +# +# Protobuf spec: bool is a varint; any non-zero value means true. +# CPU decoder (protobuf-java): CodedInputStream.readBool() = readRawVarint64() != 0 → true +# GPU decoder: extract_varint_kernel writes static_cast(v). +# For v = 256, static_cast(256) == 0 → false. BUG. +# --------------------------------------------------------------------------- + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_bool_noncanonical_varint_scalar(spark_tmp_path, from_protobuf_fn): + """Regression test: scalar bool encoded as non-canonical varint (e.g. 
256) must decode as true. + + Protobuf allows any non-zero varint for bool true. The GPU decoder previously + truncated to uint8_t, causing values >= 256 to wrap to 0 (false). + """ + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "simple_bool_bug.desc", _build_simple_descriptor_set_bytes) + message_name = "test.Simple" + + # varint(256) = 0x80 0x02, varint(512) = 0x80 0x04 — valid non-canonical bool true + row_bool_256 = _encode_tag(1, 0) + _encode_varint(256) + \ + _encode_tag(2, 0) + _encode_varint(99) + + # Control row: canonical bool true (varint 1) — should work on both + row_bool_1 = _encode_tag(1, 0) + _encode_varint(1) + \ + _encode_tag(2, 0) + _encode_varint(100) + + # Another non-canonical value: varint(512) + row_bool_512 = _encode_tag(1, 0) + _encode_varint(512) + \ + _encode_tag(2, 0) + _encode_varint(101) + + def run_on_spark(spark): + df = spark.createDataFrame( + [(row_bool_256,), (row_bool_1,), (row_bool_512,)], + schema="bin binary", + ) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + return df.select( + decoded.getField("b").alias("b"), + decoded.getField("i32").alias("i32"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +def _build_repeated_bool_descriptor_set_bytes(spark): + """ + message WithRepeatedBool { + optional int32 id = 1; + repeated bool flags = 2; + } + """ + D, fd = _new_proto2_file(spark, "repeated_bool.proto") + label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL + label_rep = D.FieldDescriptorProto.Label.LABEL_REPEATED + msg = D.DescriptorProto.newBuilder().setName("WithRepeatedBool") + msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("id").setNumber(1).setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_INT32).build()) + msg.addField( + D.FieldDescriptorProto.newBuilder() + .setName("flags").setNumber(2).setLabel(label_rep) + .setType(D.FieldDescriptorProto.Type.TYPE_BOOL).build()) + 
fd.addMessageType(msg.build()) + fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build() + return bytes(fds.toByteArray()) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_bool_noncanonical_varint_repeated(spark_tmp_path, from_protobuf_fn): + """Regression test: repeated bool with non-canonical varint values must all decode as true. + + Same uint8_t truncation issue as the scalar case, exercised with repeated fields. + """ + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "repeated_bool_bug.desc", _build_repeated_bool_descriptor_set_bytes) + message_name = "test.WithRepeatedBool" + + # Repeated bool field 2 (wire type 0 = varint), unpacked. + # Three elements: varint(256), varint(1), varint(512) — all should decode as true. + row = (_encode_tag(1, 0) + _encode_varint(42) + + _encode_tag(2, 0) + _encode_varint(256) + + _encode_tag(2, 0) + _encode_varint(1) + + _encode_tag(2, 0) + _encode_varint(512)) + + def run_on_spark(spark): + df = spark.createDataFrame([(row,)], schema="bin binary") + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + return df.select( + decoded.getField("id").alias("id"), + decoded.getField("flags").alias("flags"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +# --------------------------------------------------------------------------- +# Regression guard: nested message child field default values +# --------------------------------------------------------------------------- + +def _build_nested_with_defaults_descriptor_set_bytes(spark): + """ + message Inner { + optional int32 count = 1 [default = 42]; + optional string label = 2 [default = "hello"]; + optional bool flag = 3 [default = true]; + } + message OuterWithNestedDefaults { + optional int32 id = 1; + optional Inner inner = 2; + } + """ + D, fd = _new_proto2_file(spark, "nested_defaults.proto") + 
label_opt = D.FieldDescriptorProto.Label.LABEL_OPTIONAL + + inner = D.DescriptorProto.newBuilder().setName("Inner") + inner.addField( + D.FieldDescriptorProto.newBuilder() + .setName("count").setNumber(1).setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_INT32) + .setDefaultValue("42").build()) + inner.addField( + D.FieldDescriptorProto.newBuilder() + .setName("label").setNumber(2).setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_STRING) + .setDefaultValue("hello").build()) + inner.addField( + D.FieldDescriptorProto.newBuilder() + .setName("flag").setNumber(3).setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_BOOL) + .setDefaultValue("true").build()) + fd.addMessageType(inner.build()) + + outer = D.DescriptorProto.newBuilder().setName("OuterWithNestedDefaults") + outer.addField( + D.FieldDescriptorProto.newBuilder() + .setName("id").setNumber(1).setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_INT32).build()) + outer.addField( + D.FieldDescriptorProto.newBuilder() + .setName("inner").setNumber(2).setLabel(label_opt) + .setType(D.FieldDescriptorProto.Type.TYPE_MESSAGE) + .setTypeName(".test.Inner").build()) + fd.addMessageType(outer.build()) + + fds = D.FileDescriptorSet.newBuilder().addFile(fd.build()).build() + return bytes(fds.toByteArray()) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_nested_child_default_values(spark_tmp_path, from_protobuf_fn): + """Regression test: proto2 default values for fields inside nested messages must be honored. + + When `inner` is present but its child fields are absent, the decoder must + return the proto2 defaults (count=42, label="hello", flag=true), not null. 
+ """ + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "nested_defaults.desc", + _build_nested_with_defaults_descriptor_set_bytes) + message_name = "test.OuterWithNestedDefaults" + + # Row 1: outer.id = 10, inner is present but EMPTY (0-length nested message). + # Wire: field 1 varint(10), field 2 length-delimited with length 0. + # CPU should fill inner.count=42, inner.label="hello", inner.flag=true. + row_empty_inner = (_encode_tag(1, 0) + _encode_varint(10) + + _encode_tag(2, 2) + _encode_varint(0)) + + # Row 2: outer.id = 20, inner has only count=7 (label and flag should get defaults). + inner_partial = _encode_tag(1, 0) + _encode_varint(7) + row_partial_inner = (_encode_tag(1, 0) + _encode_varint(20) + + _encode_tag(2, 2) + _encode_varint(len(inner_partial)) + + inner_partial) + + # Row 3: outer.id = 30, inner is fully absent → inner itself is null. + row_no_inner = _encode_tag(1, 0) + _encode_varint(30) + + def run_on_spark(spark): + df = spark.createDataFrame( + [(row_empty_inner,), (row_partial_inner,), (row_no_inner,)], + schema="bin binary", + ) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + return df.select( + decoded.getField("id").alias("id"), + decoded.getField("inner").getField("count").alias("inner_count"), + decoded.getField("inner").getField("label").alias("inner_label"), + decoded.getField("inner").getField("flag").alias("inner_flag"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +# =========================================================================== +# Deep nested schema pruning tests +# +# These verify that the GPU path correctly prunes nested fields at depth > 2. +# Previously, collectStructFieldReferences only recognized 2-level +# GetStructField chains, so accessing decoded.level2.level3.val3 would +# decode ALL of level3's children instead of only val3. 
+# =========================================================================== + +def _deep_5_level_data_gen(): + return ProtobufMessageGen([ + PbScalar("val1", 1, IntegerGen()), + PbNested("level2", 2, [ + PbScalar("val2", 1, IntegerGen()), + PbNested("level3", 2, [ + PbScalar("val3", 1, IntegerGen()), + PbNested("level4", 2, [ + PbScalar("val4", 1, IntegerGen()), + PbNested("level5", 2, [ + PbScalar("val5", 1, IntegerGen()), + ]), + ]), + ]), + ]), + ]) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_deep_pruning_3_level_leaf(spark_tmp_path, from_protobuf_fn): + """Access decoded.level2.level3.val3 -- triggers 3-level deep pruning.""" + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "dp3.desc", _build_deep_nested_5_level_descriptor_set_bytes) + message_name = "test.Level1" + data_gen = _deep_5_level_data_gen() + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + return df.select( + decoded.getField("val1").alias("val1"), + decoded.getField("level2").getField("level3").getField("val3").alias("deep_val3"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_deep_pruning_5_level_leaf(spark_tmp_path, from_protobuf_fn): + """Access decoded.level2.level3.level4.level5.val5 -- deepest leaf.""" + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "dp5.desc", _build_deep_nested_5_level_descriptor_set_bytes) + message_name = "test.Level1" + data_gen = _deep_5_level_data_gen() + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + return df.select( + _get_field_by_path(decoded, ["level2", "level3", "level4", "level5", 
"val5"]) + .alias("val5"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_deep_pruning_mixed_depths(spark_tmp_path, from_protobuf_fn): + """Access leaves at different depths in the same query. + + Select val1 (depth 1), val2 (depth 2), val3 (depth 3), and val5 (depth 5) + to exercise pruning at every intermediate level simultaneously. + """ + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "dp_mix.desc", _build_deep_nested_5_level_descriptor_set_bytes) + message_name = "test.Level1" + data_gen = _deep_5_level_data_gen() + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + return df.select( + decoded.getField("val1").alias("val1"), + decoded.getField("level2").getField("val2").alias("val2"), + _get_field_by_path(decoded, ["level2", "level3", "val3"]).alias("val3"), + _get_field_by_path(decoded, ["level2", "level3", "level4", "level5", "val5"]) + .alias("val5"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_deep_pruning_sibling_at_depth_3(spark_tmp_path, from_protobuf_fn): + """At depth 3, access val3 but NOT level4 -- level4 subtree should be pruned.""" + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "dp_sib3.desc", _build_deep_nested_5_level_descriptor_set_bytes) + message_name = "test.Level1" + data_gen = _deep_5_level_data_gen() + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + return df.select( + _get_field_by_path(decoded, ["level2", "level3", "val3"]).alias("val3"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + 
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_deep_pruning_whole_struct_at_depth_3(spark_tmp_path, from_protobuf_fn): + """Select the whole level3 struct -- no deep pruning inside level3.""" + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "dp_whole3.desc", _build_deep_nested_5_level_descriptor_set_bytes) + message_name = "test.Level1" + data_gen = _deep_5_level_data_gen() + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + return df.select( + decoded.getField("level2").getField("level3").alias("level3"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +# =========================================================================== +# FAILFAST mode tests +# =========================================================================== + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +def test_from_protobuf_failfast_malformed_data(spark_tmp_path, from_protobuf_fn): + """FAILFAST mode should throw on malformed protobuf data (both CPU and GPU).""" + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "failfast.desc", _build_simple_descriptor_set_bytes) + message_name = "test.Simple" + + # Craft a valid row and a malformed row (truncated varint with continuation bit) + valid_row = _encode_tag(1, 0) + _encode_varint(1) + \ + _encode_tag(2, 0) + _encode_varint(42) + malformed_row = bytes([0x08, 0x80]) # field 1, varint, but only continuation byte -- no end + + def run_on_spark(spark): + df = spark.createDataFrame( + [(valid_row,), (malformed_row,)], + schema="bin binary", + ) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes, + options={"mode": "FAILFAST"}) + # Must call .collect() so the exception surfaces inside with_*_session + return 
df.select(decoded.getField("b").alias("b")).collect() + + assert_gpu_and_cpu_error(run_on_spark, {}, "Malformed") + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_permissive_malformed_returns_null(spark_tmp_path, from_protobuf_fn): + """PERMISSIVE mode should return null for malformed rows, not throw. + + Note: Spark's from_protobuf defaults to FAILFAST (unlike JSON/CSV which + default to PERMISSIVE), so mode must be set explicitly. + """ + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "permissive.desc", _build_simple_descriptor_set_bytes) + message_name = "test.Simple" + + valid_row = _encode_tag(2, 0) + _encode_varint(99) + malformed_row = bytes([0x08, 0x80]) # truncated varint + + def run_on_spark(spark): + df = spark.createDataFrame( + [(valid_row,), (malformed_row,)], + schema="bin binary", + ) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes, + options={"mode": "PERMISSIVE"}) + return df.select( + decoded.getField("i32").alias("i32"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_all_null_input(spark_tmp_path, from_protobuf_fn): + """All rows in the input binary column are null (not empty bytes, actual nulls). 
+ GPU should produce all-null struct rows matching CPU behavior.""" + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "allnull.desc", _build_simple_descriptor_set_bytes) + message_name = "test.Simple" + + def run_on_spark(spark): + df = spark.createDataFrame( + [(None,), (None,), (None,)], + schema="bin binary", + ) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + return df.select( + decoded.getField("i32").alias("i32"), + decoded.getField("s").alias("s"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) diff --git a/integration_tests/src/main/python/spark_init_internal.py b/integration_tests/src/main/python/spark_init_internal.py index 6d67538d513..95f988cbc2b 100644 --- a/integration_tests/src/main/python/spark_init_internal.py +++ b/integration_tests/src/main/python/spark_init_internal.py @@ -61,11 +61,46 @@ def findspark_init(): if spark_jars is not None: logging.info(f"Adding to findspark jars: {spark_jars}") findspark.add_jars(spark_jars) + # Also add to driver classpath so classes are available to Class.forName() + # This is needed for optional modules like spark-protobuf + _add_driver_classpath(spark_jars) if spark_jars_packages is not None: logging.info(f"Adding to findspark packages: {spark_jars_packages}") findspark.add_packages(spark_jars_packages) + +def _add_driver_classpath(jars): + """ + Add jars to the driver classpath via PYSPARK_SUBMIT_ARGS. + findspark.add_jars() only adds --jars, which doesn't make classes available + to Class.forName() on the driver. This function adds --driver-class-path. 
+ """ + if not jars: + return + current_args = os.environ.get('PYSPARK_SUBMIT_ARGS', '') + # Remove trailing 'pyspark-shell' if present + if current_args.endswith('pyspark-shell'): + current_args = current_args[:-len('pyspark-shell')].strip() + jar_list = [j.strip() for j in jars.split(',') if j.strip()] + new_cp = os.pathsep.join(jar_list) + if '--driver-class-path' in current_args: + match = re.search(r'--driver-class-path\s+(\S+)', current_args) + if match: + existing_cp = match.group(1) + merged_cp = existing_cp + os.pathsep + new_cp + current_args = re.sub( + r'--driver-class-path\s+\S+', + lambda m: f'--driver-class-path {merged_cp}', + current_args) + else: + current_args += f' --driver-class-path {new_cp}' + else: + current_args += f' --driver-class-path {new_cp}' + new_args = f"{current_args} pyspark-shell".strip() + os.environ['PYSPARK_SUBMIT_ARGS'] = new_args + logging.info(f"Updated PYSPARK_SUBMIT_ARGS with driver-class-path") + def running_with_xdist(session, is_worker): try: import xdist diff --git a/integration_tests/src/test/resources/protobuf_test/gen_nested_proto_data.sh b/integration_tests/src/test/resources/protobuf_test/gen_nested_proto_data.sh new file mode 100755 index 00000000000..2283ac8a751 --- /dev/null +++ b/integration_tests/src/test/resources/protobuf_test/gen_nested_proto_data.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# Convenience script: compile nested_proto .proto files into a descriptor set. +# +# Usage: +# ./gen_nested_proto_data.sh +# +# The generated .desc file is checked into the repository and used by +# integration tests in protobuf_test.py. Re-run this script whenever +# the .proto definitions under nested_proto/ change. 
+ +set -e + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROTO_DIR="${SCRIPT_DIR}/nested_proto" +OUTPUT_DIR="${SCRIPT_DIR}/nested_proto/generated" + +echo "=== Protobuf Descriptor Compiler ===" +echo "Proto dir: ${PROTO_DIR}" +echo "" + +# Create output directory +mkdir -p "${OUTPUT_DIR}" + +# Compile proto files into a descriptor set (includes all imports) +DESC_FILE="${OUTPUT_DIR}/main_log.desc" +echo "Compiling proto files..." +protoc \ + --descriptor_set_out="${DESC_FILE}" \ + --include_imports \ + -I"${PROTO_DIR}" \ + "${PROTO_DIR}/main_log.proto" + +echo "Generated: ${DESC_FILE}" +echo "=== Done ===" diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/device_req.proto b/integration_tests/src/test/resources/protobuf_test/nested_proto/device_req.proto new file mode 100644 index 00000000000..c4d86951d96 --- /dev/null +++ b/integration_tests/src/test/resources/protobuf_test/nested_proto/device_req.proto @@ -0,0 +1,11 @@ +syntax = "proto2"; + +package com.test.proto.sample; + +option java_outer_classname = "DeviceReqBean"; + +// Device request field +message DeviceReqField { + optional int32 os_type = 1; // int32 + optional bytes device_id = 2; // bytes +} diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/generated/main_log.desc b/integration_tests/src/test/resources/protobuf_test/nested_proto/generated/main_log.desc new file mode 100644 index 00000000000..6e8155238b3 Binary files /dev/null and b/integration_tests/src/test/resources/protobuf_test/nested_proto/generated/main_log.desc differ diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/main_log.proto b/integration_tests/src/test/resources/protobuf_test/nested_proto/main_log.proto new file mode 100644 index 00000000000..f9a325a9f2e --- /dev/null +++ b/integration_tests/src/test/resources/protobuf_test/nested_proto/main_log.proto @@ -0,0 +1,103 @@ +syntax = "proto2"; + +package com.test.proto.sample; + +import "module_a_res.proto"; 
+import "module_b_res.proto"; +import "device_req.proto"; + +// ========== Enum type tests ========== +enum SourceType { + WEB = 0; // web + APP = 1; // application + MOBILE = 4; // mobile +} + +enum ChannelType { + CHANNEL_A = 0; + CHANNEL_B = 1; +} + +// ========== Main log record ========== +message MainLogRecord { + // required fields + required SourceType source = 1; // required enum + required uint64 timestamp = 2; // required uint64 + + // optional scalar types - one of each + optional string user_id = 3; // string + optional int64 account_id = 4; // int64 + optional fixed32 client_ip = 5; // fixed32 + + // nested message + optional LogContent log_content = 6; +} + +// ========== Log content (multi-level nesting) ========== +message LogContent { + optional BasicInfo basic_info = 1; + repeated ChannelInfo channel_list = 2; + repeated DataSourceField source_list = 3; +} + +// ========== Basic info (three-level nesting) ========== +message BasicInfo { + optional RequestInfo request_info = 1; + optional ExtendedReqInfo extended_req_info = 2; + optional ServerAddedField server_added_field = 3; +} + +// ========== Request info ========== +message RequestInfo { + optional uint32 page_num = 1; // uint32 + optional string channel_code = 2; // string + repeated uint32 experiment_ids = 3; // repeated uint32 + optional bool is_filtered = 4; // bool +} + +// ========== Extended request info (cross-file import) ========== +message ExtendedReqInfo { + optional DeviceReqField device_req_field = 1; // reference to external proto +} + +// ========== Server-added fields ========== +message ServerAddedField { + optional uint32 region_code = 1; // uint32 + optional string flow_type = 2; // string + optional int32 filter_result = 3; // int32 + repeated int32 hit_rule_list = 4; // repeated int32 + optional uint64 request_time = 5; // uint64 + optional bool skip_flag = 6; // bool +} + +// ========== Channel info ========== +message ChannelInfo { + optional int32 channel_id = 1; + 
optional ModuleAResField module_a_res = 2; // reference to external proto +} + +// ========== Source channel info ========== +message SrcChannelInfo { + optional int32 channel_id = 1; + optional ModuleASrcResField module_a_src_res = 2; // reference to external proto +} + +// ========== Data source field ========== +message DataSourceField { + optional uint32 source_id = 1; + repeated SrcChannelInfo src_channel_list = 2; + optional string billing_name = 3; + repeated ItemDetailField item_list = 4; + optional bool is_free = 5; +} + +// ========== Item detail field ========== +message ItemDetailField { + optional uint32 rank = 1; // uint32 + optional uint64 record_id = 2; // uint64 + optional string keyword = 3; // string + + // cross-file message references + optional ModuleADetailField module_a_detail = 4; + optional ModuleBDetailField module_b_detail = 5; +} diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/module_a_res.proto b/integration_tests/src/test/resources/protobuf_test/nested_proto/module_a_res.proto new file mode 100644 index 00000000000..9a2674e5dd5 --- /dev/null +++ b/integration_tests/src/test/resources/protobuf_test/nested_proto/module_a_res.proto @@ -0,0 +1,92 @@ +syntax = "proto2"; + +package com.test.proto.sample; + +option java_outer_classname = "ModuleARes"; + +import "predictor_schema.proto"; + +// ========== Test default values ========== +message PartnerInfo { + optional string token = 1 [default = ""]; // default empty string + optional uint64 partner_id = 2 [default = 0]; // default 0 +} + +// ========== Test coordinate structure (simple nesting) ========== +message Coordinate { + optional double x = 1; // x coordinate - double + optional double y = 2; // y coordinate - double +} + +// ========== Test multi-level nesting ========== +message LocationPoint { + optional uint32 frequency = 1; // frequency - uint32 + optional Coordinate coord = 2; // coordinate - nested message + optional uint64 timestamp = 3; // 
timestamp - uint64 +} + +// ========== Test change log ========== +message ChangeLog { + optional uint32 value_before = 1; // value before change + optional string parameters = 2; // parameters +} + +// ========== Test price change log ========== +message PriceLog { + optional uint32 price_before = 1; +} + +// ========== Module A response-level fields ========== +message ModuleAResField { + optional string route_tag = 1; // route tag - string + optional int32 status_tag = 2; // status tag - int32 + optional uint32 region_id = 3; // region id - uint32 + repeated string experiment_ids = 4; // experiment id list - repeated string + optional double quality_score = 5; // quality score - double + repeated LocationPoint location_points = 6; // location points - repeated nested message + repeated uint64 interest_ids = 7; // interest id list - repeated uint64 +} + +// ========== Module A source response field ========== +message ModuleASrcResField { + optional uint32 match_type = 1; // match type +} + +// ========== Key-value pair ========== +message KVPair { + optional bytes key = 1; // key - bytes + optional bytes value = 2; // value - bytes +} + +// ========== Style configuration ========== +message StyleConfig { + optional uint32 style_id = 1; // style id + repeated KVPair kv_pairs = 2; // kv pair list - repeated nested message +} + +// ========== Module A detail field (core complex structure) ========== +message ModuleADetailField { + // scalar types - one or two of each + optional uint32 type_code = 1; // uint32 + optional uint64 item_id = 2; // uint64 + optional int32 strategy_type = 3; // int32 + optional int64 min_value = 4; // int64 + optional bytes target_url = 5; // bytes + optional string title = 6; // string + optional bool is_valid = 7; // bool + optional float score_ratio = 8; // float + + // repeated scalar types + repeated uint32 template_ids = 9; // repeated uint32 + repeated uint64 material_ids = 10; // repeated uint64 + + // repeated nested messages + 
repeated StyleConfig styles = 11; // repeated message + repeated ChangeLog change_logs = 12; // repeated message + + // nested message + optional PartnerInfo partner_info = 13; // nested message + + // cross-file import + optional PredictorSchema predictor_schema = 14; // reference to external proto +} diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/module_b_res.proto b/integration_tests/src/test/resources/protobuf_test/nested_proto/module_b_res.proto new file mode 100644 index 00000000000..0c742739835 --- /dev/null +++ b/integration_tests/src/test/resources/protobuf_test/nested_proto/module_b_res.proto @@ -0,0 +1,29 @@ +syntax = "proto2"; + +package com.test.proto.sample; + +// Module B response field +message ModuleBResField { + optional uint32 type_code = 1; // uint32 + optional string extra_info = 2; // string +} + +// Block element - tests repeated scalar types +message BlockElement { + optional uint64 element_id = 1; // uint64 + repeated uint64 ref_ids = 2; // repeated uint64 +} + +// Block info - tests repeated nested messages +message BlockInfo { + optional uint64 block_id = 1; // uint64 + repeated BlockElement elements = 2; // repeated message +} + +// Module B detail field +message ModuleBDetailField { + repeated uint32 tags = 1; // repeated uint32 + optional uint64 item_id = 2; // uint64 + optional string name = 3; // string + repeated BlockInfo blocks = 4; // repeated message +} diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/predictor_schema.proto b/integration_tests/src/test/resources/protobuf_test/nested_proto/predictor_schema.proto new file mode 100644 index 00000000000..bc7d86c67e2 --- /dev/null +++ b/integration_tests/src/test/resources/protobuf_test/nested_proto/predictor_schema.proto @@ -0,0 +1,82 @@ +syntax = "proto2"; + +package com.test.proto.sample; + +// Predictor schema - tests multi-level schema nesting + +// ========== Main schema structure ========== +message PredictorSchema { + 
optional SchemaTypeA type_a_schema = 1; // nested + optional SchemaTypeB type_b_schema = 2; // nested + optional SchemaTypeC type_c_schema = 3; // nested (with repeated) +} + +// ========== TypeA Query Schema ========== +message TypeAQuerySchema { + optional string keyword = 1; // keyword + optional string session_id = 2; // session id +} + +// ========== TypeA Pair Schema ========== +message TypeAPairSchema { + optional string record_id = 1; // record id + optional string item_id = 2; // item id +} + +// ========== TypeA Schema ========== +message SchemaTypeA { + optional TypeAQuerySchema query_schema = 1; // query-level schema + repeated TypeAPairSchema pair_schema = 2; // pair list - repeated nested +} + +// ========== TypeB Query Schema ========== +message TypeBQuerySchema { + optional string profile_tag_id = 1; // profile tag id + optional string entity_id = 2; // entity id +} + +// ========== TypeB Style Element ========== +message TypeBStyleElem { + optional string template_id = 1; // template id + optional string material_id = 2; // material id +} + +// ========== TypeB Style Schema ========== +message TypeBStyleSchema { + repeated TypeBStyleElem values = 1; // element list - repeated nested +} + +// ========== TypeB Schema ========== +message SchemaTypeB { + optional TypeBQuerySchema query_schema = 1; // query-level schema + repeated TypeBStyleSchema style_schema = 2; // style list - repeated nested +} + +// ========== TypeC Query Schema ========== +message TypeCQuerySchema { + optional string keyword = 1; // keyword + optional string category = 2; // category +} + +// ========== TypeC Pair Schema ========== +message TypeCPairSchema { + optional string item_id = 1; // item id + optional string target_url = 2; // target url +} + +// ========== TypeC Style Element (empty structure test) ========== +message TypeCStyleElem { + // empty message - tests empty structure handling +} + +// ========== TypeC Style Schema ========== +message TypeCStyleSchema { + 
repeated TypeCStyleElem values = 1; // empty element list +} + +// ========== TypeC Schema (full three-level structure) ========== +message SchemaTypeC { + optional TypeCQuerySchema query_schema = 1; // query schema + repeated TypeCPairSchema pair_schema = 2; // pair list + repeated TypeCStyleSchema style_schema = 3; // style list +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index fef7bea85dd..29ad563f6b3 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2628,10 +2628,7 @@ object GpuOverrides extends Logging { TypeSig.STRUCT.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.BINARY), TypeSig.STRUCT.nested(TypeSig.all)), - (expr, conf, p, r) => new UnaryExprMeta[GetStructField](expr, conf, p, r) { - override def convertToGpu(arr: Expression): GpuExpression = - GpuGetStructField(arr, expr.ordinal, expr.name) - }), + (expr, conf, p, r) => new GpuGetStructFieldMeta(expr, conf, p, r)), expr[GetArrayItem]( "Gets the field at `ordinal` in the Array", ExprChecks.binaryProject( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 5a507e7dab7..5243bcff6c6 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -1600,6 +1600,15 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern") .booleanConf .createWithDefault(true) + val ENABLE_PROTOBUF_BATCH_MERGE_AFTER_PROJECT = + conf("spark.rapids.sql.protobuf.batchMergeAfterProject.enabled") + .doc("When set to true, allows a GPU Project containing a schema-pruned " + + "`from_protobuf` decode to request a post-project 
batch coalesce. This is intended " + + "to reduce tiny batches produced after protobuf schema projection.") + .internal() + .booleanConf + .createWithDefault(false) + val ENABLE_ORC_FLOAT_TYPES_TO_STRING = conf("spark.rapids.sql.format.orc.floatTypesToString.enable") .doc("When reading an ORC file, the source data schemas(schemas of ORC file) may differ " + @@ -3539,6 +3548,9 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val isCoalesceAfterExpandEnabled: Boolean = get(ENABLE_COALESCE_AFTER_EXPAND) + lazy val isProtobufBatchMergeAfterProjectEnabled: Boolean = + get(ENABLE_PROTOBUF_BATCH_MERGE_AFTER_PROJECT) + lazy val multiThreadReadNumThreads: Int = { // Use the largest value set among all the options. val deprecatedConfs = Seq( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala index a97f830fe3e..5db9f27ad0b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala @@ -47,6 +47,49 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.random.BernoulliCellSampler +object GpuProjectExecMeta { + private def isProtobufDecodeExpr(expr: Expression): Boolean = + expr.getClass.getName.endsWith("ProtobufDataToCatalyst") + + private def protobufAliasExprIds(plan: SparkPlan): Set[ExprId] = plan match { + case ProjectExec(projectList, _) => + projectList.collect { + case alias: Alias if isProtobufDecodeExpr(alias.child) => alias.exprId + }.toSet + case _ => Set.empty + } + + private def isRootedAtProtobufDecode(expr: Expression, protobufExprIds: Set[ExprId]): Boolean = + expr match { + case attr: AttributeReference => + protobufExprIds.contains(attr.exprId) + case decode if isProtobufDecodeExpr(decode) => + true + case 
GetStructField(child, _, _) => + isRootedAtProtobufDecode(child, protobufExprIds) + case gasf: GetArrayStructFields => + isRootedAtProtobufDecode(gasf.child, protobufExprIds) + case _ => + false + } + + private def extractsFromProtobufAlias( + expr: Expression, + protobufExprIds: Set[ExprId]): Boolean = expr match { + case GetStructField(child, _, _) => + isRootedAtProtobufDecode(child, protobufExprIds) + case gasf: GetArrayStructFields => + isRootedAtProtobufDecode(gasf.child, protobufExprIds) + case other => + other.children.exists(child => extractsFromProtobufAlias(child, protobufExprIds)) + } + + private[rapids] def shouldCoalesceAfterProject(plan: ProjectExec): Boolean = { + val protobufExprIds = protobufAliasExprIds(plan.child) + plan.projectList.exists(expr => extractsFromProtobufAlias(expr, protobufExprIds)) + } +} + class GpuProjectExecMeta( proj: ProjectExec, conf: RapidsConf, @@ -57,11 +100,13 @@ class GpuProjectExecMeta( // Force list to avoid recursive Java serialization of lazy list Seq implementation val gpuExprs = childExprs.map(_.convertToGpu().asInstanceOf[NamedExpression]).toList val gpuChild = childPlans.head.convertIfNeeded() + val forcePostProjectCoalesce = conf.isProtobufBatchMergeAfterProjectEnabled && + GpuProjectExecMeta.shouldCoalesceAfterProject(proj) if (conf.isProjectAstEnabled) { // cuDF requires return column is fixed width val allReturnTypesFixedWidth = gpuExprs.forall(e => GpuBatchUtils.isFixedWidth(e.dataType)) if (allReturnTypesFixedWidth && childExprs.forall(_.canThisBeAst)) { - return GpuProjectAstExec(gpuExprs, gpuChild) + return GpuProjectAstExec(gpuExprs, gpuChild, forcePostProjectCoalesce) } // explain AST because this is optional and it is sometimes hard to debug if (conf.shouldExplain) { @@ -76,7 +121,7 @@ class GpuProjectExecMeta( } } } - GpuProjectExec(gpuExprs, gpuChild) + GpuProjectExec(gpuExprs, gpuChild, forcePostProjectCoalesce = forcePostProjectCoalesce) } } @@ -290,6 +335,7 @@ object GpuProjectExecLike { trait 
GpuProjectExecLike extends GpuPartitioningPreservingUnaryExecNode with GpuExec { def projectList: Seq[Expression] + def forcePostProjectCoalesce: Boolean override lazy val additionalMetrics: Map[String, GpuMetric] = Map( OP_TIME_LEGACY -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_OP_TIME_LEGACY)) @@ -299,8 +345,16 @@ trait GpuProjectExecLike extends GpuPartitioningPreservingUnaryExecNode with Gpu override def doExecute(): RDD[InternalRow] = throw new IllegalStateException(s"Row-based execution should not occur for $this") - // The same as what feeds us - override def outputBatching: CoalesceGoal = GpuExec.outputBatching(child) + override def coalesceAfter: Boolean = forcePostProjectCoalesce + + // Flagged protobuf projects intentionally drop the output batching guarantee so that + // the post-project coalesce inserted by transition rules is not optimized away. + override def outputBatching: CoalesceGoal = + if (forcePostProjectCoalesce) { + null + } else { + GpuExec.outputBatching(child) + } } /** @@ -763,14 +817,19 @@ case class GpuProjectExec( // immutable/List.scala#L516 projectList: List[NamedExpression], child: SparkPlan, - enablePreSplit: Boolean = true) extends GpuProjectExecLike { + enablePreSplit: Boolean = true, + forcePostProjectCoalesce: Boolean = false) extends GpuProjectExecLike { override def output: Seq[Attribute] = projectList.map(_.toAttribute) override def outputBatching: CoalesceGoal = if (enablePreSplit) { // Pre-split will make sure the size of each output batch will not be larger // than the splitUntilSize. 
- TargetSize(PreProjectSplitIterator.getSplitUntilSize) + if (forcePostProjectCoalesce) { + super.outputBatching + } else { + TargetSize(PreProjectSplitIterator.getSplitUntilSize) + } } else { super.outputBatching } @@ -822,7 +881,8 @@ case class GpuProjectAstExec( // serde: https://github.com/scala/scala/blob/2.12.x/src/library/scala/collection/ // immutable/List.scala#L516 projectList: List[Expression], - child: SparkPlan + child: SparkPlan, + forcePostProjectCoalesce: Boolean = false ) extends GpuProjectExecLike { override def output: Seq[Attribute] = { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala new file mode 100644 index 00000000000..3cf6a290041 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala @@ -0,0 +1,203 @@ +/* + * Copyright (c) 2025-2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.rapids + +import java.util.Arrays + +import ai.rapids.cudf +import ai.rapids.cudf.{BinaryOp, CudfException, DType} +import com.nvidia.spark.rapids.{GpuColumnVector, GpuUnaryExpression} +import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.jni.{Protobuf, ProtobufSchemaDescriptor} +import com.nvidia.spark.rapids.shims.NullIntolerantShim + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression} +import org.apache.spark.sql.types._ + +/** + * GPU implementation for Spark's `from_protobuf` decode path. + * + * This is designed to replace `org.apache.spark.sql.protobuf.ProtobufDataToCatalyst` when + * supported. + * + * The implementation uses a flattened schema representation where nested fields have parent + * indices pointing to their containing message field. For pure scalar schemas, all fields + * are top-level (parentIndices == -1, depthLevels == 0, isRepeated == false). + * + * Schema projection is supported: `decodedSchema` contains only the top-level fields and + * nested children that are actually referenced by downstream operators. Downstream + * `GetStructField` and `GetArrayStructFields` nodes have their ordinals rewritten via + * `PRUNED_ORDINAL_TAG` to index into the pruned schema. Unreferenced fields are never + * accessed, so no null-column filling is needed. + * + * @param decodedSchema The pruned schema containing only the fields decoded by the GPU. + * Only fields referenced by downstream operators are included; + * ordinal remapping ensures correct field access into the pruned output. 
+ * @param fieldNumbers Protobuf field numbers for all fields in flattened schema + * @param parentIndices Parent indices for all fields (-1 for top-level) + * @param depthLevels Nesting depth for all fields (0 for top-level) + * @param wireTypes Wire types for all fields + * @param outputTypeIds cuDF type IDs for all fields + * @param encodings Encodings for all fields + * @param isRepeated Whether each field is repeated + * @param isRequired Whether each field is required + * @param hasDefaultValue Whether each field has a default value + * @param defaultInts Default int/long values + * @param defaultFloats Default float/double values + * @param defaultBools Default bool values + * @param defaultStrings Default string/bytes values + * @param enumValidValues Valid enum values for each field + * @param enumNames Enum value names for enum-as-string fields. Parallel to enumValidValues. + * @param failOnErrors If true, throw exception on malformed data + */ +case class GpuFromProtobuf( + decodedSchema: StructType, + fieldNumbers: Array[Int], + parentIndices: Array[Int], + depthLevels: Array[Int], + wireTypes: Array[Int], + outputTypeIds: Array[Int], + encodings: Array[Int], + isRepeated: Array[Boolean], + isRequired: Array[Boolean], + hasDefaultValue: Array[Boolean], + defaultInts: Array[Long], + defaultFloats: Array[Double], + defaultBools: Array[Boolean], + defaultStrings: Array[Array[Byte]], + enumValidValues: Array[Array[Int]], + enumNames: Array[Array[Array[Byte]]], + failOnErrors: Boolean, + child: Expression) + extends GpuUnaryExpression with ExpectsInputTypes with NullIntolerantShim with Logging { + + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) + + override def dataType: DataType = decodedSchema + + override def nullable: Boolean = true + + override def equals(other: Any): Boolean = other match { + case that: GpuFromProtobuf => + decodedSchema == that.decodedSchema && + Arrays.equals(fieldNumbers, that.fieldNumbers) && + 
Arrays.equals(parentIndices, that.parentIndices) && + Arrays.equals(depthLevels, that.depthLevels) && + Arrays.equals(wireTypes, that.wireTypes) && + Arrays.equals(outputTypeIds, that.outputTypeIds) && + Arrays.equals(encodings, that.encodings) && + Arrays.equals(isRepeated, that.isRepeated) && + Arrays.equals(isRequired, that.isRequired) && + Arrays.equals(hasDefaultValue, that.hasDefaultValue) && + Arrays.equals(defaultInts, that.defaultInts) && + Arrays.equals(defaultFloats, that.defaultFloats) && + Arrays.equals(defaultBools, that.defaultBools) && + GpuFromProtobuf.deepEquals(defaultStrings, that.defaultStrings) && + GpuFromProtobuf.deepEquals(enumValidValues, that.enumValidValues) && + GpuFromProtobuf.deepEquals(enumNames, that.enumNames) && + failOnErrors == that.failOnErrors && + child == that.child + case _ => false + } + + override def hashCode(): Int = { + var result = decodedSchema.hashCode() + result = 31 * result + Arrays.hashCode(fieldNumbers) + result = 31 * result + Arrays.hashCode(parentIndices) + result = 31 * result + Arrays.hashCode(depthLevels) + result = 31 * result + Arrays.hashCode(wireTypes) + result = 31 * result + Arrays.hashCode(outputTypeIds) + result = 31 * result + Arrays.hashCode(encodings) + result = 31 * result + Arrays.hashCode(isRepeated) + result = 31 * result + Arrays.hashCode(isRequired) + result = 31 * result + Arrays.hashCode(hasDefaultValue) + result = 31 * result + Arrays.hashCode(defaultInts) + result = 31 * result + Arrays.hashCode(defaultFloats) + result = 31 * result + Arrays.hashCode(defaultBools) + result = 31 * result + GpuFromProtobuf.deepHashCode(defaultStrings) + result = 31 * result + GpuFromProtobuf.deepHashCode(enumValidValues) + result = 31 * result + GpuFromProtobuf.deepHashCode(enumNames) + result = 31 * result + failOnErrors.hashCode() + result = 31 * result + child.hashCode() + result + } + + // ProtobufSchemaDescriptor is a pure-Java immutable holder for validated schema arrays. 
+ // It does not own native resources, so task-scoped close hooks are not required here. + @transient private lazy val protobufSchema = new ProtobufSchemaDescriptor( + fieldNumbers, parentIndices, depthLevels, wireTypes, outputTypeIds, encodings, + isRepeated, isRequired, hasDefaultValue, defaultInts, defaultFloats, defaultBools, + defaultStrings, enumValidValues, enumNames) + + override protected def doColumnar(input: GpuColumnVector): cudf.ColumnVector = { + val jniResult = try { + Protobuf.decodeToStruct(input.getBase, protobufSchema, failOnErrors) + } catch { + case e: CudfException if failOnErrors => + throw new org.apache.spark.SparkException("Malformed protobuf message", e) + case e: CudfException => + logWarning(s"Unexpected CudfException in PERMISSIVE mode: ${e.getMessage}", e) + throw e + } + + // Apply input nulls to output + if (input.getBase.hasNulls) { + withResource(jniResult) { _ => + jniResult.mergeAndSetValidity(BinaryOp.BITWISE_AND, input.getBase) + } + } else { + jniResult + } + } +} + +object GpuFromProtobuf { + val ENC_DEFAULT = 0 + val ENC_FIXED = 1 + val ENC_ZIGZAG = 2 + val ENC_ENUM_STRING = 3 + + /** + * Maps a Spark DataType to the corresponding cuDF native type ID. + * Note: The encoding (varint/zigzag/fixed) is determined by the protobuf field type, + * not the Spark data type, so it must be set separately based on the protobuf schema. 
+ * + * @return Some(typeId) for supported types, None for unsupported types + */ + def sparkTypeToCudfIdOpt(dt: DataType): Option[Int] = dt match { + case BooleanType => Some(DType.BOOL8.getTypeId.getNativeId) + case IntegerType => Some(DType.INT32.getTypeId.getNativeId) + case LongType => Some(DType.INT64.getTypeId.getNativeId) + case FloatType => Some(DType.FLOAT32.getTypeId.getNativeId) + case DoubleType => Some(DType.FLOAT64.getTypeId.getNativeId) + case StringType => Some(DType.STRING.getTypeId.getNativeId) + case BinaryType => Some(DType.LIST.getTypeId.getNativeId) + case _ => None + } + + /** + * Check if a Spark DataType is supported by the GPU protobuf decoder. + */ + def isTypeSupported(dt: DataType): Boolean = sparkTypeToCudfIdOpt(dt).isDefined + + private def deepEquals[T](left: Array[T], right: Array[T]): Boolean = + Arrays.deepEquals(left.asInstanceOf[Array[Object]], right.asInstanceOf[Array[Object]]) + + private def deepHashCode[T](arr: Array[T]): Int = + Arrays.deepHashCode(arr.asInstanceOf[Array[Object]]) +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index 9afb3b854d6..988c03b00b1 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -32,7 +32,10 @@ import org.apache.spark.sql.rapids.shims.RapidsErrorUtils import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, IntegralType, LongType, MapType, StructField, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch -case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[String] = None) +case class GpuGetStructField( + child: Expression, + ordinal: Int, + name: Option[String] = None) extends ShimUnaryExpression with GpuExpression with ShimGetStructField @@ -41,15 
+44,23 @@ case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[Strin lazy val childSchema: StructType = child.dataType.asInstanceOf[StructType] override def dataType: DataType = childSchema(ordinal).dataType - override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable + + override def nullable: Boolean = + child.nullable || childSchema(ordinal).nullable override def toString: String = { - val fieldName = if (resolved) childSchema(ordinal).name else s"_$ordinal" + val fieldName = if (resolved) { + childSchema(ordinal).name + } else { + s"_$ordinal" + } s"$child.${name.getOrElse(fieldName)}" } - override def sql: String = - child.sql + s".${quoteIdentifier(name.getOrElse(childSchema(ordinal).name))}" + override def sql: String = { + val fieldName = childSchema(ordinal).name + child.sql + s".${quoteIdentifier(name.getOrElse(fieldName))}" + } override def columnarEvalAny(batch: ColumnarBatch): Any = { val dt = dataType @@ -59,7 +70,6 @@ case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[Strin GpuColumnVector.from(view.copyToColumnVector(), dt) } case s: GpuScalar => - // For a scalar in we want a scalar out. 
if (!s.isValid) { GpuScalar(null, dt) } else { @@ -402,6 +412,27 @@ case class GpuArrayPosition(left: Expression, right: Expression) } } +object GpuStructFieldOrdinalTag { + val PRUNED_ORDINAL_TAG = + new org.apache.spark.sql.catalyst.trees.TreeNodeTag[Int]("GPU_PRUNED_ORDINAL") +} + +class GpuGetStructFieldMeta( + expr: GetStructField, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends UnaryExprMeta[GetStructField](expr, conf, parent, rule) { + + def convertToGpu(child: Expression): GpuExpression = { + val runtimeOrd = expr.getTagValue( + GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG).getOrElse(-1) + val effectiveOrd = + if (runtimeOrd >= 0) runtimeOrd else expr.ordinal + GpuGetStructField(child, effectiveOrd, expr.name) + } +} + class GpuGetArrayStructFieldsMeta( expr: GetArrayStructFields, conf: RapidsConf, @@ -409,8 +440,32 @@ class GpuGetArrayStructFieldsMeta( rule: DataFromReplacementRule) extends UnaryExprMeta[GetArrayStructFields](expr, conf, parent, rule) { - def convertToGpu(child: Expression): GpuExpression = - GpuGetArrayStructFields(child, expr.field, expr.ordinal, expr.numFields, expr.containsNull) + def convertToGpu(child: Expression): GpuExpression = { + val runtimeOrd = expr.getTagValue( + GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG).getOrElse(-1) + val effectiveOrd = + if (runtimeOrd >= 0) runtimeOrd else expr.ordinal + val effectiveNumFields = + GpuGetArrayStructFieldsMeta.effectiveNumFields(child, expr, runtimeOrd) + GpuGetArrayStructFields(child, expr.field, + effectiveOrd, effectiveNumFields, expr.containsNull) + } +} + +object GpuGetArrayStructFieldsMeta { + def effectiveNumFields( + child: Expression, + expr: GetArrayStructFields, + runtimeOrd: Int): Int = { + if (runtimeOrd >= 0) { + child.dataType match { + case ArrayType(st: StructType, _) => st.fields.length + case _ => expr.numFields + } + } else { + expr.numFields + } + } } /** diff --git 
a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaExtractor.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaExtractor.scala new file mode 100644 index 00000000000..a6617e1c89a --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaExtractor.scala @@ -0,0 +1,215 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.protobuf + +import scala.collection.mutable + +import com.nvidia.spark.rapids.jni.Protobuf.{WT_32BIT, WT_64BIT, WT_LEN, WT_VARINT} + +import org.apache.spark.sql.rapids.GpuFromProtobuf +import org.apache.spark.sql.types._ + +object ProtobufSchemaExtractor { + def analyzeAllFields( + schema: StructType, + msgDesc: ProtobufMessageDescriptor, + enumsAsInts: Boolean, + messageName: String): Either[String, Map[String, ProtobufFieldInfo]] = { + val result = mutable.Map[String, ProtobufFieldInfo]() + + schema.fields.foreach { sf => + val fieldInfo = msgDesc.findField(sf.name) match { + case None => + unsupportedFieldInfo( + sf, + None, + s"Protobuf field '${sf.name}' not found in message '$messageName'") + case Some(fd) => + extractFieldInfo(sf, fd, enumsAsInts) match { + case Right(info) => + info + case Left(reason) => + unsupportedFieldInfo(sf, Some(fd), reason) + } + } + result(sf.name) = fieldInfo + } + + Right(result.toMap) + } + + def extractFieldInfo( + 
sparkField: StructField,
      fieldDescriptor: ProtobufFieldDescriptor,
      enumsAsInts: Boolean): Either[String, ProtobufFieldInfo] = {
    val (isSupported, unsupportedReason, encoding) = checkFieldSupport(
      sparkField.dataType,
      fieldDescriptor.protoTypeName,
      fieldDescriptor.isRepeated,
      enumsAsInts)

    // defaultValueResult is an Either; its failure reason propagates unchanged.
    fieldDescriptor.defaultValueResult.map { defaultValue =>
      ProtobufFieldInfo(
        fieldNumber = fieldDescriptor.fieldNumber,
        protoTypeName = fieldDescriptor.protoTypeName,
        sparkType = sparkField.dataType,
        encoding = encoding,
        isSupported = isSupported,
        unsupportedReason = unsupportedReason,
        isRequired = fieldDescriptor.isRequired,
        defaultValue = defaultValue,
        enumMetadata = fieldDescriptor.enumMetadata,
        isRepeated = fieldDescriptor.isRepeated)
    }
  }

  /**
   * Builds a ProtobufFieldInfo that records `sparkField` as unsupported for the
   * given reason. When no descriptor is available, placeholder metadata is used.
   */
  private def unsupportedFieldInfo(
      sparkField: StructField,
      fieldDescriptor: Option[ProtobufFieldDescriptor],
      reason: String): ProtobufFieldInfo = {
    ProtobufFieldInfo(
      fieldNumber = fieldDescriptor.map(_.fieldNumber).getOrElse(-1),
      protoTypeName = fieldDescriptor.map(_.protoTypeName).getOrElse("UNKNOWN"),
      sparkType = sparkField.dataType,
      encoding = GpuFromProtobuf.ENC_DEFAULT,
      isSupported = false,
      unsupportedReason = Some(reason),
      isRequired = fieldDescriptor.exists(_.isRequired),
      defaultValue = None,
      enumMetadata = fieldDescriptor.flatMap(_.enumMetadata),
      isRepeated = fieldDescriptor.exists(_.isRepeated))
  }

  /**
   * Decides whether a (Spark type, protobuf type) pairing can be decoded on
   * the GPU.
   *
   * @return a triple of (isSupported, reason when unsupported, encoding id).
   */
  def checkFieldSupport(
      sparkType: DataType,
      protoTypeName: String,
      isRepeated: Boolean,
      enumsAsInts: Boolean): (Boolean, Option[String], Int) = {
    def unsupported(reason: String): (Boolean, Option[String], Int) =
      (false, Some(reason), GpuFromProtobuf.ENC_DEFAULT)

    if (isRepeated) {
      sparkType match {
        case ArrayType(elementType, _) =>
          elementType match {
            case BooleanType | IntegerType | LongType | FloatType | DoubleType |
                 StringType | BinaryType =>
              checkScalarEncoding(elementType, protoTypeName, enumsAsInts)
            case _: StructType =>
              // Repeated message fields are decoded as lists of structs.
              (true, None, GpuFromProtobuf.ENC_DEFAULT)
            case _ =>
              unsupported(s"unsupported repeated element type: $elementType")
          }
        case _ =>
          unsupported(s"repeated field should map to ArrayType, got: $sparkType")
      }
    } else if (protoTypeName == "MESSAGE") {
      sparkType match {
        case _: StructType => (true, None, GpuFromProtobuf.ENC_DEFAULT)
        case _ =>
          unsupported(s"nested message should map to StructType, got: $sparkType")
      }
    } else {
      sparkType match {
        case BooleanType | IntegerType | LongType | FloatType | DoubleType |
             StringType | BinaryType =>
          checkScalarEncoding(sparkType, protoTypeName, enumsAsInts)
        case other =>
          unsupported(s"unsupported Spark type: $other")
      }
    }
  }

  /**
   * Resolves the numeric encoding for a scalar pairing, or an error when the
   * Spark and protobuf types do not line up. LongType additionally accepts
   * the 32-bit integer wire types (the value is widened during decode).
   */
  def checkScalarEncoding(
      sparkType: DataType,
      protoTypeName: String,
      enumsAsInts: Boolean): (Boolean, Option[String], Int) = {
    val encoding: Option[Int] = (sparkType, protoTypeName) match {
      case (BooleanType, "BOOL") => Some(GpuFromProtobuf.ENC_DEFAULT)
      case (IntegerType, "INT32" | "UINT32") => Some(GpuFromProtobuf.ENC_DEFAULT)
      case (IntegerType, "SINT32") => Some(GpuFromProtobuf.ENC_ZIGZAG)
      case (IntegerType, "FIXED32" | "SFIXED32") => Some(GpuFromProtobuf.ENC_FIXED)
      // LongType covers both 64-bit and widened 32-bit integer wire formats.
      case (LongType, "INT64" | "UINT64" | "INT32" | "UINT32") =>
        Some(GpuFromProtobuf.ENC_DEFAULT)
      case (LongType, "SINT64" | "SINT32") => Some(GpuFromProtobuf.ENC_ZIGZAG)
      case (LongType, "FIXED64" | "SFIXED64" | "FIXED32" | "SFIXED32") =>
        Some(GpuFromProtobuf.ENC_FIXED)
      case (FloatType, "FLOAT") => Some(GpuFromProtobuf.ENC_DEFAULT)
      case (DoubleType, "DOUBLE") => Some(GpuFromProtobuf.ENC_DEFAULT)
      case (StringType, "STRING") => Some(GpuFromProtobuf.ENC_DEFAULT)
      case (BinaryType, "BYTES") => Some(GpuFromProtobuf.ENC_DEFAULT)
      case (IntegerType, "ENUM") if enumsAsInts => Some(GpuFromProtobuf.ENC_DEFAULT)
      case (StringType, "ENUM") if !enumsAsInts => Some(GpuFromProtobuf.ENC_ENUM_STRING)
      case _ => None
    }

    encoding match {
      case Some(enc) => (true, None, enc)
      case None =>
        (false,
          Some(s"type mismatch: Spark $sparkType vs Protobuf $protoTypeName"),
          GpuFromProtobuf.ENC_DEFAULT)
    }
  }

  /**
   * Maps a protobuf type name (plus resolved encoding) to its wire type.
   * Unknown names produce a Left so the planner can fall back to CPU.
   */
  def getWireType(protoTypeName: String, encoding: Int): Either[String, Int] = {
    protoTypeName match {
      case "FIXED32" | "SFIXED32" | "FLOAT" =>
        Right(WT_32BIT)
      case "FIXED64" | "SFIXED64" | "DOUBLE" =>
        Right(WT_64BIT)
      case "STRING" | "BYTES" | "MESSAGE" =>
        Right(WT_LEN)
      case "BOOL" | "INT32" | "UINT32" | "SINT32" | "INT64" | "UINT64" | "SINT64" | "ENUM" =>
        // Normally varint; a fixed-width override selects 32/64-bit by name.
        if (encoding == GpuFromProtobuf.ENC_FIXED) {
          Right(if (protoTypeName.contains("64")) WT_64BIT else WT_32BIT)
        } else {
          Right(WT_VARINT)
        }
      case other =>
        Left(
          s"Unknown protobuf type name '$other' - cannot determine wire type; falling back to CPU")
    }
  }
}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaModel.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaModel.scala
new file mode 100644
index 00000000000..77c1cbdf8f3
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaModel.scala
@@ -0,0 +1,135 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.protobuf + +import java.util.Arrays + +import org.apache.spark.sql.types.DataType + +sealed trait ProtobufDescriptorSource + +object ProtobufDescriptorSource { + final case class DescriptorPath(path: String) extends ProtobufDescriptorSource + final case class DescriptorBytes(bytes: Array[Byte]) extends ProtobufDescriptorSource { + override def equals(other: Any): Boolean = other match { + case DescriptorBytes(otherBytes) => Arrays.equals(bytes, otherBytes) + case _ => false + } + + override def hashCode(): Int = Arrays.hashCode(bytes) + } +} + +final case class ProtobufExprInfo( + messageName: String, + descriptorSource: ProtobufDescriptorSource, + options: Map[String, String]) + +final case class ProtobufPlannerOptions( + enumsAsInts: Boolean, + failOnErrors: Boolean) + +sealed trait ProtobufDefaultValue + +object ProtobufDefaultValue { + final case class BoolValue(value: Boolean) extends ProtobufDefaultValue + final case class IntValue(value: Long) extends ProtobufDefaultValue + final case class FloatValue(value: Float) extends ProtobufDefaultValue + final case class DoubleValue(value: Double) extends ProtobufDefaultValue + final case class StringValue(value: String) extends ProtobufDefaultValue + final case class BinaryValue(value: Array[Byte]) extends ProtobufDefaultValue + final case class EnumValue(number: Int, name: String) extends ProtobufDefaultValue +} + +final case class ProtobufEnumValue(number: Int, name: String) + +final case class ProtobufEnumMetadata(values: Seq[ProtobufEnumValue]) { + lazy val validValues: Array[Int] = values.map(_.number).toArray + lazy val orderedNames: Array[Array[Byte]] = values.map(_.name.getBytes("UTF-8")).toArray + lazy val namesByNumber: Map[Int, String] = values.map(v => v.number -> v.name).toMap + + def enumDefault(number: Int): ProtobufDefaultValue.EnumValue = { + val name 
= namesByNumber.getOrElse(number, s"$number") + ProtobufDefaultValue.EnumValue(number, name) + } +} + +trait ProtobufMessageDescriptor { + def syntax: String + def findField(name: String): Option[ProtobufFieldDescriptor] +} + +trait ProtobufFieldDescriptor { + def name: String + def fieldNumber: Int + def protoTypeName: String + def isRepeated: Boolean + def isRequired: Boolean + def defaultValueResult: Either[String, Option[ProtobufDefaultValue]] + def enumMetadata: Option[ProtobufEnumMetadata] + def messageDescriptor: Option[ProtobufMessageDescriptor] +} + +final case class ProtobufFieldInfo( + fieldNumber: Int, + protoTypeName: String, + sparkType: DataType, + encoding: Int, + isSupported: Boolean, + unsupportedReason: Option[String], + isRequired: Boolean, + defaultValue: Option[ProtobufDefaultValue], + enumMetadata: Option[ProtobufEnumMetadata], + isRepeated: Boolean = false) { + def hasDefaultValue: Boolean = defaultValue.isDefined +} + +final case class FlattenedFieldDescriptor( + fieldNumber: Int, + parentIdx: Int, + depth: Int, + wireType: Int, + outputTypeId: Int, + encoding: Int, + isRepeated: Boolean, + isRequired: Boolean, + hasDefaultValue: Boolean, + defaultInt: Long, + defaultFloat: Double, + defaultBool: Boolean, + defaultString: Array[Byte], + enumValidValues: Array[Int], + enumNames: Array[Array[Byte]] +) + +final case class FlattenedSchemaArrays( + fieldNumbers: Array[Int], + parentIndices: Array[Int], + depthLevels: Array[Int], + wireTypes: Array[Int], + outputTypeIds: Array[Int], + encodings: Array[Int], + isRepeated: Array[Boolean], + isRequired: Array[Boolean], + hasDefaultValue: Array[Boolean], + defaultInts: Array[Long], + defaultFloats: Array[Double], + defaultBools: Array[Boolean], + defaultStrings: Array[Array[Byte]], + enumValidValues: Array[Array[Int]], + enumNames: Array[Array[Array[Byte]]] +) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaValidator.scala 
b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaValidator.scala
new file mode 100644
index 00000000000..627fed83296
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaValidator.scala
@@ -0,0 +1,177 @@
/*
 * Copyright (c) 2026, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.rapids.protobuf

import org.apache.spark.sql.rapids.GpuFromProtobuf
import org.apache.spark.sql.types._

/**
 * Validates planner field metadata and converts it into the flattened,
 * parent-indexed form consumed by the GPU protobuf decoder.
 */
object ProtobufSchemaValidator {
  // Primitive default-value slots as passed across the JNI boundary. Only the
  // slot matching the field's type is populated; the others stay zero/null.
  private final case class JniDefaultValues(
      defaultInt: Long,
      defaultFloat: Double,
      defaultBool: Boolean,
      defaultString: Array[Byte])

  /**
   * Validates one field and builds its [[FlattenedFieldDescriptor]], or a
   * human-readable reason why the field cannot be handled.
   */
  def toFlattenedFieldDescriptor(
      path: String,
      field: StructField,
      fieldInfo: ProtobufFieldInfo,
      parentIdx: Int,
      depth: Int,
      outputTypeId: Int): Either[String, FlattenedFieldDescriptor] = {
    for {
      _ <- validateFieldInfo(path, field, fieldInfo)
      wireType <- ProtobufSchemaExtractor.getWireType(fieldInfo.protoTypeName, fieldInfo.encoding)
      defaults <- encodeDefaultValue(path, field.dataType, fieldInfo)
    } yield {
      val enumValidValues = fieldInfo.enumMetadata.map(_.validValues).orNull
      // Names are only shipped when the output is enum-as-string.
      val enumNames =
        if (fieldInfo.encoding == GpuFromProtobuf.ENC_ENUM_STRING) {
          fieldInfo.enumMetadata.map(_.orderedNames).orNull
        } else {
          null
        }

      FlattenedFieldDescriptor(
        fieldNumber = fieldInfo.fieldNumber,
        parentIdx = parentIdx,
        depth = depth,
        wireType = wireType,
        outputTypeId = outputTypeId,
        encoding = fieldInfo.encoding,
        isRepeated = fieldInfo.isRepeated,
        isRequired = fieldInfo.isRequired,
        hasDefaultValue = fieldInfo.hasDefaultValue,
        defaultInt = defaults.defaultInt,
        defaultFloat = defaults.defaultFloat,
        defaultBool = defaults.defaultBool,
        defaultString = defaults.defaultString,
        enumValidValues = enumValidValues,
        enumNames = enumNames)
    }
  }

  /**
   * Structural sanity checks over a complete flattened schema: parent indices
   * must point backwards, depth must agree with parentage, and enum-string
   * fields must carry consistent metadata. Returns the first problem found.
   */
  def validateFlattenedSchema(flatFields: Seq[FlattenedFieldDescriptor]): Either[String, Unit] = {
    val firstProblem = flatFields.zipWithIndex.iterator.map { case (field, idx) =>
      if (field.parentIdx >= idx) {
        Some(s"Flattened protobuf schema has invalid parent index at position $idx")
      } else if (field.parentIdx == -1 && field.depth != 0) {
        Some(s"Top-level protobuf field at position $idx must have depth 0")
      } else if (field.parentIdx >= 0 && field.depth <= 0) {
        Some(s"Nested protobuf field at position $idx must have positive depth")
      } else if (field.encoding == GpuFromProtobuf.ENC_ENUM_STRING &&
          (field.enumValidValues == null || field.enumNames == null)) {
        Some(s"Enum-string field at position $idx is missing enum metadata")
      } else if (field.encoding == GpuFromProtobuf.ENC_ENUM_STRING &&
          field.enumValidValues.length != field.enumNames.length) {
        Some(s"Enum-string field at position $idx has mismatched enum metadata")
      } else {
        None
      }
    }.collectFirst { case Some(reason) => reason }
    firstProblem.toLeft(())
  }

  /** Transposes per-field descriptors into the column-oriented JNI layout. */
  def toFlattenedSchemaArrays(
      flatFields: Array[FlattenedFieldDescriptor]): FlattenedSchemaArrays = {
    FlattenedSchemaArrays(
      fieldNumbers = flatFields.map(_.fieldNumber),
      parentIndices = flatFields.map(_.parentIdx),
      depthLevels = flatFields.map(_.depth),
      wireTypes = flatFields.map(_.wireType),
      outputTypeIds = flatFields.map(_.outputTypeId),
      encodings = flatFields.map(_.encoding),
      isRepeated = flatFields.map(_.isRepeated),
      isRequired = flatFields.map(_.isRequired),
      hasDefaultValue = flatFields.map(_.hasDefaultValue),
      defaultInts = flatFields.map(_.defaultInt),
      defaultFloats = flatFields.map(_.defaultFloat),
      defaultBools = flatFields.map(_.defaultBool),
      defaultStrings = flatFields.map(_.defaultString),
      enumValidValues = flatFields.map(_.enumValidValues),
      enumNames = flatFields.map(_.enumNames))
  }

  /**
   * Per-field consistency checks: repeated fields cannot carry defaults, and
   * enum metadata must be present exactly for ENUM-typed fields.
   */
  private def validateFieldInfo(
      path: String,
      field: StructField,
      fieldInfo: ProtobufFieldInfo): Either[String, Unit] = {
    val enumProblem = fieldInfo.enumMetadata match {
      case Some(enumMeta) if enumMeta.values.isEmpty =>
        Some(s"Enum field '$path' is missing enum values")
      case Some(_) if fieldInfo.protoTypeName != "ENUM" =>
        Some(s"Non-enum field '$path' should not carry enum metadata")
      case None if fieldInfo.protoTypeName == "ENUM" =>
        Some(s"Enum field '$path' is missing enum metadata")
      case _ => None
    }

    if (fieldInfo.isRepeated && fieldInfo.hasDefaultValue) {
      Left(s"Repeated protobuf field '$path' cannot carry a default value")
    } else {
      enumProblem match {
        case Some(reason) => Left(reason)
        case None =>
          if (fieldInfo.encoding == GpuFromProtobuf.ENC_ENUM_STRING &&
              fieldInfo.enumMetadata.isEmpty) {
            Left(s"Enum-string field '$path' is missing enum metadata")
          } else {
            Right(())
          }
      }
    }
  }

  /**
   * Packs a declared default value into the primitive JNI slots. Enum
   * defaults populate the numeric slot and, for string output, the UTF-8
   * name slot as well. A mismatched (type, default) pairing is an error.
   */
  private def encodeDefaultValue(
      path: String,
      dataType: DataType,
      fieldInfo: ProtobufFieldInfo): Either[String, JniDefaultValues] = {
    val empty = JniDefaultValues(0L, 0.0, defaultBool = false, null)
    fieldInfo.defaultValue match {
      case None => Right(empty)
      case Some(defaultValue) =>
        // For repeated fields the default applies to the element type.
        val targetType = dataType match {
          case ArrayType(elementType, _) => elementType
          case other => other
        }
        (targetType, defaultValue) match {
          case (BooleanType, ProtobufDefaultValue.BoolValue(v)) =>
            Right(empty.copy(defaultBool = v))
          case (IntegerType | LongType, ProtobufDefaultValue.IntValue(v)) =>
            Right(empty.copy(defaultInt = v))
          case (IntegerType | LongType, ProtobufDefaultValue.EnumValue(number, _)) =>
            Right(empty.copy(defaultInt = number.toLong))
          case (FloatType, ProtobufDefaultValue.FloatValue(v)) =>
            Right(empty.copy(defaultFloat = v.toDouble))
          case (DoubleType, ProtobufDefaultValue.DoubleValue(v)) =>
            Right(empty.copy(defaultFloat = v))
          case (StringType, ProtobufDefaultValue.StringValue(v)) =>
            Right(empty.copy(defaultString = v.getBytes("UTF-8")))
          case (StringType, ProtobufDefaultValue.EnumValue(number, name)) =>
            Right(empty.copy(
              defaultInt = number.toLong,
              defaultString = name.getBytes("UTF-8")))
          case (BinaryType, ProtobufDefaultValue.BinaryValue(v)) =>
            Right(empty.copy(defaultString = v))
          case _ =>
            Left(s"Incompatible default value for protobuf field '$path': $defaultValue")
        }
    }
  }
}
diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala
new file mode 100644
index 00000000000..44f67ffff10
--- /dev/null
+++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala
@@ -0,0 +1,791 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "340"} +{"spark": "341"} +{"spark": "342"} +{"spark": "343"} +{"spark": "344"} +{"spark": "350"} +{"spark": "351"} +{"spark": "352"} +{"spark": "353"} +{"spark": "354"} +{"spark": "355"} +{"spark": "356"} +{"spark": "357"} +{"spark": "400"} +{"spark": "401"} +{"spark": "402"} +{"spark": "411"} +spark-rapids-shim-json-lines ***/ + +package com.nvidia.spark.rapids.shims + +import scala.collection.mutable + +import ai.rapids.cudf.DType +import com.nvidia.spark.rapids._ + +import org.apache.spark.sql.catalyst.expressions.{ + AttributeReference, Expression, GetArrayStructFields, GetStructField, UnaryExpression +} +import org.apache.spark.sql.execution.ProjectExec +import org.apache.spark.sql.rapids.GpuFromProtobuf +import org.apache.spark.sql.rapids.protobuf.{ + FlattenedFieldDescriptor, + ProtobufFieldInfo, + ProtobufMessageDescriptor, + ProtobufSchemaExtractor, + ProtobufSchemaValidator +} +import org.apache.spark.sql.types._ + +/** + * Spark 3.4+ optional integration for spark-protobuf expressions. + * + * spark-protobuf is an external module, so these rules must be registered by reflection. 
+ */ +object ProtobufExprShims extends org.apache.spark.internal.Logging { + private[this] val protobufDataToCatalystClassName = + "org.apache.spark.sql.protobuf.ProtobufDataToCatalyst" + + val PRUNED_ORDINAL_TAG = + org.apache.spark.sql.rapids.GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG + + def exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = { + try { + val clazz = ShimReflectionUtils.loadClass(protobufDataToCatalystClassName) + .asInstanceOf[Class[_ <: UnaryExpression]] + Map(clazz.asInstanceOf[Class[_ <: Expression]] -> fromProtobufRule) + } catch { + case _: ClassNotFoundException => Map.empty + case e: Exception => + logWarning(s"Failed to load $protobufDataToCatalystClassName: ${e.getMessage}") + Map.empty + } + } + + private def fromProtobufRule: ExprRule[_ <: Expression] = { + GpuOverrides.expr[UnaryExpression]( + "Decode a BinaryType column (protobuf) into a Spark SQL struct", + ExprChecks.unaryProject( + // Use TypeSig.all here because schema projection determines which fields + // actually need GPU support. Detailed type checking is done in tagExprForGpu. 
+ TypeSig.all, + TypeSig.all, + TypeSig.BINARY, + TypeSig.BINARY), + (e, conf, p, r) => new UnaryExprMeta[UnaryExpression](e, conf, p, r) { + + private var fullSchema: StructType = _ + private var failOnErrors: Boolean = _ + + // Flattened schema variables for GPU decoding + private var flatFieldNumbers: Array[Int] = _ + private var flatParentIndices: Array[Int] = _ + private var flatDepthLevels: Array[Int] = _ + private var flatWireTypes: Array[Int] = _ + private var flatOutputTypeIds: Array[Int] = _ + private var flatEncodings: Array[Int] = _ + private var flatIsRepeated: Array[Boolean] = _ + private var flatIsRequired: Array[Boolean] = _ + private var flatHasDefaultValue: Array[Boolean] = _ + private var flatDefaultInts: Array[Long] = _ + private var flatDefaultFloats: Array[Double] = _ + private var flatDefaultBools: Array[Boolean] = _ + private var flatDefaultStrings: Array[Array[Byte]] = _ + private var flatEnumValidValues: Array[Array[Int]] = _ + private var flatEnumNames: Array[Array[Array[Byte]]] = _ + // Indices in fullSchema for top-level fields that were decoded (for schema projection) + private var decodedTopLevelIndices: Array[Int] = _ + + override def tagExprForGpu(): Unit = { + fullSchema = e.dataType match { + case st: StructType => st + case other => + willNotWorkOnGpu( + s"Only StructType output is supported for from_protobuf, got $other") + return + } + + val exprInfo = SparkProtobufCompat.extractExprInfo(e) match { + case Right(info) => info + case Left(reason) => + willNotWorkOnGpu(reason) + return + } + val unsupportedOptions = SparkProtobufCompat.unsupportedOptions(exprInfo.options) + if (unsupportedOptions.nonEmpty) { + val keys = unsupportedOptions.mkString(",") + willNotWorkOnGpu( + s"from_protobuf options are not supported yet on GPU: $keys") + return + } + + val plannerOptions = SparkProtobufCompat.parsePlannerOptions(exprInfo.options) match { + case Right(opts) => opts + case Left(reason) => + willNotWorkOnGpu(reason) + return + } + 
val enumsAsInts = plannerOptions.enumsAsInts + failOnErrors = plannerOptions.failOnErrors + val messageName = exprInfo.messageName + + val msgDesc = SparkProtobufCompat.resolveMessageDescriptor(exprInfo) match { + case Right(desc) => desc + case Left(reason) => + willNotWorkOnGpu(reason) + return + } + + // Reject proto3 descriptors — GPU decoder only supports proto2 semantics. + // proto3 has different null/default-value behavior that the GPU path doesn't handle. + val protoSyntax = msgDesc.syntax + if (!SparkProtobufCompat.isGpuSupportedProtoSyntax(protoSyntax)) { + willNotWorkOnGpu( + "proto3/editions syntax is not supported by the GPU protobuf decoder; " + + "only proto2 is supported. The query will fall back to CPU.") + return + } + + // Step 1: Analyze all fields and build field info map + val fieldsInfoMap = + ProtobufSchemaExtractor.analyzeAllFields(fullSchema, msgDesc, enumsAsInts, messageName) + .fold({ reason => + willNotWorkOnGpu(reason) + return + }, identity) + + // Step 2: Determine which fields are actually required by downstream operations + val requiredFieldNames = analyzeRequiredFields(fieldsInfoMap.keySet) + + // Step 3: Check if all required fields are supported + val unsupportedRequired = requiredFieldNames.filter { name => + fieldsInfoMap.get(name).exists(!_.isSupported) + } + + if (unsupportedRequired.nonEmpty) { + val reasons = unsupportedRequired.map { name => + val info = fieldsInfoMap(name) + s"${name}: ${info.unsupportedReason.getOrElse("unknown reason")}" + } + willNotWorkOnGpu( + s"Required fields not supported for from_protobuf: ${reasons.mkString(", ")}") + return + } + + // Step 4: Identify which fields in fullSchema need to be decoded + // These are fields that are required AND supported + val indicesToDecode = fullSchema.fields.zipWithIndex.collect { + case (sf, idx) if requiredFieldNames.contains(sf.name) => idx + } + + // Verify all fields to be decoded are actually supported + // (This catches edge cases where field analysis 
might have issues) + val unsupportedInDecode = indicesToDecode.filter { idx => + val sf = fullSchema.fields(idx) + fieldsInfoMap.get(sf.name).exists(!_.isSupported) + } + if (unsupportedInDecode.nonEmpty) { + val reasons = unsupportedInDecode.map { idx => + val sf = fullSchema.fields(idx) + val info = fieldsInfoMap(sf.name) + s"${sf.name}: ${info.unsupportedReason.getOrElse("unknown reason")}" + } + willNotWorkOnGpu( + s"Fields not supported for from_protobuf: ${reasons.mkString(", ")}") + return + } + + // Step 5: Build flattened schema for GPU decoding. + // The flattened schema represents nested fields with parent indices. + // For pure scalar schemas, all fields are top-level (parentIdx == -1, depth == 0). + { + val flatFields = mutable.ArrayBuffer[FlattenedFieldDescriptor]() + + // Helper to add a field and its children recursively. + // pathPrefix is the dot-path of ancestor fields (empty for top-level). + def addFieldWithChildren( + sf: StructField, + info: ProtobufFieldInfo, + parentIdx: Int, + depth: Int, + nestedMsgDesc: ProtobufMessageDescriptor, + pathPrefix: String = ""): Unit = { + + val currentIdx = flatFields.size + + if (depth >= 10) { + willNotWorkOnGpu("Protobuf nesting depth exceeds maximum supported depth of 10") + return + } + + val outputType = sf.dataType match { + case ArrayType(elemType, _) => + elemType match { + case _: StructType => + // Repeated message field: ArrayType(StructType) - element type is STRUCT + DType.STRUCT.getTypeId.getNativeId + case other => + GpuFromProtobuf.sparkTypeToCudfIdOpt(other) + .getOrElse(DType.INT8.getTypeId.getNativeId) + } + case _: StructType => + DType.STRUCT.getTypeId.getNativeId + case other => + GpuFromProtobuf.sparkTypeToCudfIdOpt(other) + .getOrElse(DType.INT8.getTypeId.getNativeId) + } + + val path = if (pathPrefix.isEmpty) sf.name else s"$pathPrefix.${sf.name}" + ProtobufSchemaValidator.toFlattenedFieldDescriptor( + path, + sf, + info, + parentIdx, + depth, + outputType).fold({ reason => + 
willNotWorkOnGpu(reason) + return + }, flatFields += _) + + // For nested struct types (including repeated message = ArrayType(StructType)), + // add child fields + sf.dataType match { + case st: StructType if nestedMsgDesc != null => + addChildFieldsFromStruct( + st, nestedMsgDesc, sf.name, currentIdx, depth, pathPrefix) + + case ArrayType(st: StructType, _) if nestedMsgDesc != null => + addChildFieldsFromStruct( + st, nestedMsgDesc, sf.name, currentIdx, depth, pathPrefix) + + case _ => // Not a struct, no children to add + } + } + + // Helper to add child fields from a struct type. + // Applies nested schema pruning at arbitrary depth using path-based + // lookup into nestedFieldRequirements. + def addChildFieldsFromStruct( + st: StructType, + parentMsgDesc: ProtobufMessageDescriptor, + fieldName: String, + parentIdx: Int, + parentDepth: Int, + pathPrefix: String): Unit = { + val path = if (pathPrefix.isEmpty) fieldName else s"$pathPrefix.$fieldName" + val parentField = parentMsgDesc.findField(fieldName) + if (parentField.isEmpty) { + willNotWorkOnGpu( + s"Nested field '$fieldName' not found in protobuf descriptor at '$path'") + } else { + parentField.get.messageDescriptor match { + case Some(childMsgDesc) => + val requiredChildren = nestedFieldRequirements.get(path) + val filteredFields = requiredChildren match { + case Some(Some(childNames)) => + st.fields.filter(f => childNames.contains(f.name)) + case _ => + st.fields + } + filteredFields.foreach { childSf => + childMsgDesc.findField(childSf.name) match { + case None => + willNotWorkOnGpu( + s"Nested field '${childSf.name}' not found in protobuf " + + s"descriptor for message at '$path'") + return + case Some(childFd) => + ProtobufSchemaExtractor + .extractFieldInfo(childSf, childFd, enumsAsInts) match { + case Left(reason) => + willNotWorkOnGpu(reason) + return + case Right(childInfo) => + if (!childInfo.isSupported) { + willNotWorkOnGpu( + s"Nested field '${childSf.name}' at '$path': " + + 
childInfo.unsupportedReason.getOrElse("unsupported type")) + return + } else { + addFieldWithChildren( + childSf, childInfo, parentIdx, parentDepth + 1, childMsgDesc, + path) + } + } + } + } + case None => + willNotWorkOnGpu( + s"Nested field '$fieldName' at '$path' did not resolve to a message type") + } + } + } + + // Only add top-level fields that are actually required (schema projection). + // This significantly reduces GPU memory and computation for schemas with many + // fields when only a few are needed. Downstream GetStructField ordinals are + // remapped via PRUNED_ORDINAL_TAG to index into the pruned output. + decodedTopLevelIndices = indicesToDecode + indicesToDecode.foreach { schemaIdx => + val sf = fullSchema.fields(schemaIdx) + val info = fieldsInfoMap(sf.name) + addFieldWithChildren(sf, info, -1, 0, msgDesc) + } + + // Populate flattened schema variables + val flat = flatFields.toArray + ProtobufSchemaValidator.validateFlattenedSchema(flat).fold({ reason => + willNotWorkOnGpu(reason) + return + }, identity) + val arrays = ProtobufSchemaValidator.toFlattenedSchemaArrays(flat) + flatFieldNumbers = arrays.fieldNumbers + flatParentIndices = arrays.parentIndices + flatDepthLevels = arrays.depthLevels + flatWireTypes = arrays.wireTypes + flatOutputTypeIds = arrays.outputTypeIds + flatEncodings = arrays.encodings + flatIsRepeated = arrays.isRepeated + flatIsRequired = arrays.isRequired + flatHasDefaultValue = arrays.hasDefaultValue + flatDefaultInts = arrays.defaultInts + flatDefaultFloats = arrays.defaultFloats + flatDefaultBools = arrays.defaultBools + flatDefaultStrings = arrays.defaultStrings + flatEnumValidValues = arrays.enumValidValues + flatEnumNames = arrays.enumNames + } + } + + /** + * Analyze which fields are actually required by downstream operations. 
+ * Traverses parent plan nodes upward, collecting struct field references from + * ProjectExec, FilterExec, and transparent pass-through nodes (AggregateExec, + * SortExec, WindowExec, etc.), then returns the set of required top-level + * field names. + * + * @param allFieldNames All field names in the full schema + * @return Set of field names that are actually required + */ + private var targetExprsToRemap: Seq[Expression] = Seq.empty + + private def analyzeRequiredFields(allFieldNames: Set[String]): Set[String] = { + val fieldReqs = mutable.Map[String, Option[Set[String]]]() + protobufOutputExprIds = Set.empty + var hasDirectStructRef = false + val holder = () => { hasDirectStructRef = true } + + var currentMeta: Option[SparkPlanMeta[_]] = findParentPlanMeta() + var safeToPrune = true + val collectedExprs = mutable.ArrayBuffer[Expression]() + + def advanceToParent(): Unit = { + currentMeta = currentMeta.get.parent match { + case Some(pm: SparkPlanMeta[_]) => Some(pm) + case _ => None + } + } + + while (currentMeta.isDefined && safeToPrune) { + currentMeta.get.wrapped match { + case p: ProjectExec => + collectedExprs ++= p.projectList + p.projectList.foreach { + case alias: org.apache.spark.sql.catalyst.expressions.Alias + if isProtobufStructReference(alias.child) => + protobufOutputExprIds += alias.exprId + case _ => + } + p.projectList.foreach(collectStructFieldReferences(_, fieldReqs, holder)) + currentMeta = None + case f: org.apache.spark.sql.execution.FilterExec => + collectedExprs += f.condition + collectStructFieldReferences(f.condition, fieldReqs, holder) + advanceToParent() + case a: org.apache.spark.sql.execution.aggregate.BaseAggregateExec => + val exprs = a.aggregateExpressions ++ a.groupingExpressions + collectedExprs ++= exprs + exprs.foreach(collectStructFieldReferences(_, fieldReqs, holder)) + advanceToParent() + case s: org.apache.spark.sql.execution.SortExec => + val exprs = s.sortOrder + collectedExprs ++= exprs + 
exprs.foreach(collectStructFieldReferences(_, fieldReqs, holder)) + advanceToParent() + case w: org.apache.spark.sql.execution.window.WindowExec => + val exprs = w.windowExpression + collectedExprs ++= exprs + exprs.foreach(collectStructFieldReferences(_, fieldReqs, holder)) + advanceToParent() + case _ => + safeToPrune = false + } + } + + if (!safeToPrune || collectedExprs.isEmpty || hasDirectStructRef || fieldReqs.isEmpty) { + targetExprsToRemap = Seq.empty + allFieldNames + } else { + nestedFieldRequirements = fieldReqs.toMap + targetExprsToRemap = collectedExprs.toSeq + val prunedFieldsMap = buildPrunedFieldsMap() + val topLevelIndices = fullSchema.fields.zipWithIndex.collect { + case (sf, idx) if fieldReqs.keySet.contains(sf.name) => idx + } + targetExprsToRemap.foreach( + registerPrunedOrdinals(_, prunedFieldsMap, topLevelIndices)) + fieldReqs.keySet.toSet + } + } + + /** + * Find the parent SparkPlanMeta by traversing up the parent chain. + */ + private def findParentPlanMeta(): Option[SparkPlanMeta[_]] = { + def traverse(meta: Option[RapidsMeta[_, _, _]]): Option[SparkPlanMeta[_]] = { + meta match { + case Some(p: SparkPlanMeta[_]) => Some(p) + case Some(p: RapidsMeta[_, _, _]) => traverse(p.parent) + case _ => None + } + } + traverse(parent) + } + + /** + * Nested field requirements: maps a field path to child requirements. + * Keys are dot-separated paths from the protobuf root: + * - "level1" -> Some(Set("level2")) (top-level struct pruning) + * - "level1.level2" -> Some(Set("level3")) (deep nested pruning) + * - "field" -> None (whole field needed) + * + * Top-level names (keys without dots) also determine which fields are decoded. 
+ */ + private var nestedFieldRequirements: Map[String, Option[Set[String]]] = Map.empty + private var protobufOutputExprIds: Set[ + org.apache.spark.sql.catalyst.expressions.ExprId] = Set.empty + + private def getFieldName(ordinal: Int, nameOpt: Option[String], + schema: StructType): String = { + nameOpt.getOrElse { + if (ordinal < schema.fields.length) schema.fields(ordinal).name + else s"_$ordinal" + } + } + + /** + * Navigate the Spark schema tree by following a dot-separated path of + * field names. Returns the StructType at the end of the path, unwrapping + * ArrayType(StructType) along the way, or null if the path is invalid. + */ + private def resolveSchemaAtPath(root: StructType, path: Seq[String]): StructType = { + var current: StructType = root + for (name <- path) { + val field = current.fields.find(_.name == name).orNull + if (field == null) return null + field.dataType match { + case st: StructType => current = st + case ArrayType(st: StructType, _) => current = st + case _ => return null + } + } + current + } + + /** + * Walk a GetStructField chain upward until it reaches the protobuf + * reference expression, returning the sequence of field names forming + * the access path. Returns None if the chain does not terminate at a + * protobuf reference. 
+ * + * Example: for `GetStructField(GetStructField(decoded, a_ord), b_ord)` + * → Some(Seq("a", "b")) + */ + private def resolveFieldAccessChain( + expr: Expression): Option[Seq[String]] = { + expr match { + case GetStructField(child, ordinal, nameOpt) => + if (isProtobufStructReference(child)) { + Some(Seq(getFieldName(ordinal, nameOpt, fullSchema))) + } else { + resolveFieldAccessChain(child).flatMap { parentPath => + val parentSchema = if (parentPath.isEmpty) fullSchema + else resolveSchemaAtPath(fullSchema, parentPath) + if (parentSchema != null) { + Some(parentPath :+ getFieldName(ordinal, nameOpt, parentSchema)) + } else { + None + } + } + } + case _ if isProtobufStructReference(expr) => + Some(Seq.empty) + case _ => + None + } + } + + private def addNestedFieldReq( + fieldReqs: mutable.Map[String, Option[Set[String]]], + parentKey: String, + childName: String): Unit = { + fieldReqs.get(parentKey) match { + case Some(None) => // Already need whole field, keep it + case Some(Some(existing)) => + fieldReqs(parentKey) = Some(existing + childName) + case None => + fieldReqs(parentKey) = Some(Set(childName)) + } + } + + /** + * Register pruning requirements at every level of a field access path. 
+ * For path = ["a", "b"] with leafName = "c": + * "a" -> needs child "b" + * "a.b" -> needs child "c" + */ + private def registerPathRequirements( + fieldReqs: mutable.Map[String, Option[Set[String]]], + path: Seq[String], + leafName: String): Unit = { + for (i <- path.indices) { + val pathKey = path.take(i + 1).mkString(".") + val childName = if (i < path.length - 1) path(i + 1) else leafName + addNestedFieldReq(fieldReqs, pathKey, childName) + } + } + + private def collectStructFieldReferences( + expr: Expression, + fieldReqs: mutable.Map[String, Option[Set[String]]], + hasDirectStructRefHolder: () => Unit): Unit = { + expr match { + case GetStructField(child, ordinal, nameOpt) => + resolveFieldAccessChain(child) match { + case Some(parentPath) => + val parentSchema = if (parentPath.isEmpty) fullSchema + else resolveSchemaAtPath(fullSchema, parentPath) + if (parentSchema != null) { + val fieldName = getFieldName(ordinal, nameOpt, parentSchema) + if (parentPath.isEmpty) { + // Direct top-level access: decoded.field_name (whole field) + fieldReqs(fieldName) = None + } else { + registerPathRequirements(fieldReqs, parentPath, fieldName) + } + } else { + collectStructFieldReferences(child, fieldReqs, hasDirectStructRefHolder) + } + case None => + collectStructFieldReferences(child, fieldReqs, hasDirectStructRefHolder) + } + + case gasf: GetArrayStructFields => + resolveFieldAccessChain(gasf.child) match { + case Some(parentPath) if parentPath.nonEmpty => + registerPathRequirements(fieldReqs, parentPath, gasf.field.name) + case Some(parentPath) if parentPath.isEmpty => + fieldReqs(gasf.field.name) = None + case _ => + gasf.children.foreach { child => + collectStructFieldReferences(child, fieldReqs, hasDirectStructRefHolder) + } + } + + case alias: org.apache.spark.sql.catalyst.expressions.Alias => + if (!isProtobufStructReference(alias.child)) { + collectStructFieldReferences(alias.child, fieldReqs, hasDirectStructRefHolder) + } + + case _ => + if 
(isProtobufStructReference(expr)) { + hasDirectStructRefHolder() + } + expr.children.foreach { child => + collectStructFieldReferences(child, fieldReqs, hasDirectStructRefHolder) + } + } + } + + private def buildPrunedFieldsMap(): Map[String, Seq[String]] = { + nestedFieldRequirements.collect { + case (pathKey, Some(childNames)) => + val pathParts = pathKey.split("\\.").toSeq + val childSchema = resolveSchemaAtPath(fullSchema, pathParts) + if (childSchema != null) { + val orderedNames = childSchema.fields + .map(_.name) + .filter(childNames.contains) + .toSeq + pathKey -> orderedNames + } else { + pathKey -> childNames.toSeq + } + } + } + + private def registerPrunedOrdinals( + expr: Expression, + prunedFieldsMap: Map[String, Seq[String]], + topLevelIndices: Seq[Int]): Unit = { + expr match { + case gsf @ GetStructField(childExpr, ordinal, nameOpt) => + resolveFieldAccessChain(childExpr) match { + case Some(parentPath) if parentPath.nonEmpty => + val parentSchema = resolveSchemaAtPath(fullSchema, parentPath) + if (parentSchema != null) { + val pathKey = parentPath.mkString(".") + val childName = getFieldName(ordinal, nameOpt, parentSchema) + prunedFieldsMap.get(pathKey).foreach { orderedChildren => + val runtimeOrd = orderedChildren.indexOf(childName) + if (runtimeOrd >= 0) { + gsf.setTagValue(ProtobufExprShims.PRUNED_ORDINAL_TAG, runtimeOrd) + } + } + } + case Some(parentPath) if parentPath.isEmpty => + val runtimeOrd = topLevelIndices.indexOf(ordinal) + if (runtimeOrd >= 0) { + gsf.setTagValue(ProtobufExprShims.PRUNED_ORDINAL_TAG, runtimeOrd) + } + case _ => + } + case gasf @ GetArrayStructFields(childExpr, field, _, _, _) => + resolveFieldAccessChain(childExpr) match { + case Some(parentPath) if parentPath.nonEmpty => + val pathKey = parentPath.mkString(".") + prunedFieldsMap.get(pathKey).foreach { orderedChildren => + val runtimeOrd = orderedChildren.indexOf(field.name) + if (runtimeOrd >= 0) { + gasf.setTagValue(ProtobufExprShims.PRUNED_ORDINAL_TAG, 
runtimeOrd) + } + } + case _ => + } + case _ => + } + expr.children.foreach(registerPrunedOrdinals(_, prunedFieldsMap, topLevelIndices)) + } + + /** + * Check if an expression references the output of a protobuf decode expression. + * This can be either: + * 1. The ProtobufDataToCatalyst expression itself + * 2. An AttributeReference that references the output of ProtobufDataToCatalyst + * (when accessing from a downstream ProjectExec) + */ + private def isProtobufStructReference(expr: Expression): Boolean = { + if ((expr eq e) || expr.semanticEquals(e)) { + return true + } + + // Catalyst may create duplicate ProtobufDataToCatalyst + // instances for each GetStructField access. Match copies + // by class + identical input child + identical decode + // semantics so that + // analyzeRequiredFields detects all field accesses in one + // pass, keeping schema projection correct. + if (expr.getClass == e.getClass && + expr.children.nonEmpty && + e.children.nonEmpty && + ((expr.children.head eq e.children.head) || + expr.children.head.semanticEquals( + e.children.head)) && + SparkProtobufCompat.sameDecodeSemantics(expr, e)) { + return true + } + + val protobufOutputExprId + : Option[org.apache.spark.sql.catalyst.expressions.ExprId] = + parent.flatMap { meta => + meta.wrapped match { + case alias: org.apache.spark.sql.catalyst.expressions + .Alias if alias.child.semanticEquals(e) => + Some(alias.exprId) + case _ => None + } + } + + expr match { + case attr: AttributeReference => + protobufOutputExprIds.contains(attr.exprId) || protobufOutputExprId.exists(_ == attr.exprId) + case _ => false + } + } + + override def convertToGpu(child: Expression): GpuExpression = { + val prunedFieldsMap = buildPrunedFieldsMap() + targetExprsToRemap.foreach( + registerPrunedOrdinals(_, prunedFieldsMap, decodedTopLevelIndices.toSeq)) + + val decodedSchema = { + def applyPruning(field: StructField, prefix: String): StructField = { + val path = if (prefix.isEmpty) field.name else 
s"$prefix.${field.name}" + prunedFieldsMap.get(path) match { + case Some(childNames) => + field.dataType match { + case ArrayType(st: StructType, cn) => + val pruned = StructType( + st.fields.filter(f => childNames.contains(f.name)) + .map(f => applyPruning(f, path))) + field.copy(dataType = ArrayType(pruned, cn)) + case st: StructType => + val pruned = StructType( + st.fields.filter(f => childNames.contains(f.name)) + .map(f => applyPruning(f, path))) + field.copy(dataType = pruned) + case _ => field + } + case None => + field.dataType match { + case ArrayType(st: StructType, cn) => + val recursed = StructType(st.fields.map(f => applyPruning(f, path))) + if (recursed != st) field.copy(dataType = ArrayType(recursed, cn)) + else field + case st: StructType => + val recursed = StructType(st.fields.map(f => applyPruning(f, path))) + if (recursed != st) field.copy(dataType = recursed) + else field + case _ => field + } + } + } + + val decodedFields = decodedTopLevelIndices.map { idx => + applyPruning(fullSchema.fields(idx), "") + } + StructType(decodedFields.map(f => + f.copy(nullable = true))) + } + + GpuFromProtobuf( + decodedSchema, + flatFieldNumbers, flatParentIndices, + flatDepthLevels, flatWireTypes, flatOutputTypeIds, flatEncodings, + flatIsRepeated, flatIsRequired, flatHasDefaultValue, flatDefaultInts, + flatDefaultFloats, flatDefaultBools, flatDefaultStrings, flatEnumValidValues, + flatEnumNames, failOnErrors, child) + } + } + ) + } + +} diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala index f6110539a35..10ae5efacbf 100644 --- a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala +++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala @@ -165,7 +165,7 @@ trait Spark340PlusNonDBShims extends Spark331PlusNonDBShims { ), 
GpuElementAtMeta.elementAtRule(true) ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap - super.getExprs ++ shimExprs + super.getExprs ++ shimExprs ++ ProtobufExprShims.exprs } override def getDataWriteCmds: Map[Class[_ <: DataWritingCommand], diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/SparkProtobufCompat.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/SparkProtobufCompat.scala new file mode 100644 index 00000000000..ef2ffc5ef57 --- /dev/null +++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/SparkProtobufCompat.scala @@ -0,0 +1,344 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "340"} +{"spark": "341"} +{"spark": "342"} +{"spark": "343"} +{"spark": "344"} +{"spark": "350"} +{"spark": "351"} +{"spark": "352"} +{"spark": "353"} +{"spark": "354"} +{"spark": "355"} +{"spark": "356"} +{"spark": "357"} +{"spark": "400"} +{"spark": "401"} +{"spark": "402"} +{"spark": "411"} +spark-rapids-shim-json-lines ***/ + +package com.nvidia.spark.rapids.shims + +import java.lang.reflect.Method +import java.nio.file.{Files, Paths} + +import scala.util.Try + +import com.nvidia.spark.rapids.ShimReflectionUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.rapids.protobuf._ + +private[shims] object SparkProtobufCompat extends Logging { + private[this] val sparkProtobufUtilsObjectClassName = + "org.apache.spark.sql.protobuf.utils.ProtobufUtils$" + + val SupportedOptions: Set[String] = Set("enums.as.ints", "mode") + + def extractExprInfo(e: Expression): Either[String, ProtobufExprInfo] = { + for { + messageName <- reflectMessageName(e) + options <- reflectOptions(e) + descriptorSource <- reflectDescriptorSource(e) + } yield ProtobufExprInfo(messageName, descriptorSource, options) + } + + def sameDecodeSemantics(left: Expression, right: Expression): Boolean = { + (extractExprInfo(left), extractExprInfo(right)) match { + case (Right(leftInfo), Right(rightInfo)) => leftInfo == rightInfo + case _ => false + } + } + + def parsePlannerOptions( + options: Map[String, String]): Either[String, ProtobufPlannerOptions] = { + val enumsAsInts = Try(options.getOrElse("enums.as.ints", "false").toBoolean) + .toEither + .left + .map { _ => + "Invalid value for from_protobuf option 'enums.as.ints': " + + s"'${options.getOrElse("enums.as.ints", "")}' (expected true/false)" + } + enumsAsInts.map(v => + ProtobufPlannerOptions( + enumsAsInts = v, + failOnErrors = options.getOrElse("mode", "FAILFAST").equalsIgnoreCase("FAILFAST"))) + } + + def 
unsupportedOptions(options: Map[String, String]): Seq[String] = + options.keys.filterNot(SupportedOptions.contains).toSeq.sorted + + def isGpuSupportedProtoSyntax(syntax: String): Boolean = + syntax.nonEmpty && syntax != "PROTO3" && syntax != "EDITIONS" + + def resolveMessageDescriptor( + exprInfo: ProtobufExprInfo): Either[String, ProtobufMessageDescriptor] = { + Try(buildMessageDescriptor(exprInfo.messageName, exprInfo.descriptorSource)) + .toEither + .left + .map { t => + s"Failed to resolve protobuf descriptor for message '${exprInfo.messageName}': " + + s"${t.getMessage}" + } + .map(new ReflectiveMessageDescriptor(_)) + } + + private def reflectMessageName(e: Expression): Either[String, String] = + Try(PbReflect.invoke0[String](e, "messageName")).toEither.left.map { t => + s"Cannot read from_protobuf messageName via reflection: ${t.getMessage}" + } + + private def reflectOptions(e: Expression): Either[String, Map[String, String]] = { + Try(PbReflect.invoke0[scala.collection.Map[String, String]](e, "options")) + .map(_.toMap) + .toEither.left.map { _ => + "Cannot read from_protobuf options via reflection; falling back to CPU" + } + } + + private def reflectDescriptorSource(e: Expression): Either[String, ProtobufDescriptorSource] = { + reflectDescFilePath(e).map(ProtobufDescriptorSource.DescriptorPath).orElse( + reflectDescriptorBytes(e).map(ProtobufDescriptorSource.DescriptorBytes)).toRight( + "from_protobuf requires a descriptor set (descFilePath or binaryFileDescriptorSet)") + } + + private def reflectDescFilePath(e: Expression): Option[String] = + Try(PbReflect.invoke0[Option[String]](e, "descFilePath")).toOption.flatten + + private def reflectDescriptorBytes(e: Expression): Option[Array[Byte]] = { + val spark35Result = Try(PbReflect.invoke0[Option[Array[Byte]]](e, "binaryFileDescriptorSet")) + .toOption.flatten + spark35Result.orElse { + val direct = Try(PbReflect.invoke0[Array[Byte]](e, "binaryDescriptorSet")).toOption + direct.orElse { + 
Try(PbReflect.invoke0[Option[Array[Byte]]](e, "binaryDescriptorSet")).toOption.flatten + } + } + } + + private def buildMessageDescriptor( + messageName: String, + descriptorSource: ProtobufDescriptorSource): AnyRef = { + val cls = ShimReflectionUtils.loadClass(sparkProtobufUtilsObjectClassName) + val module = cls.getField("MODULE$").get(null) + val buildMethod = cls.getMethod("buildDescriptor", classOf[String], classOf[scala.Option[_]]) + + invokeBuildDescriptor( + buildMethod, + module, + messageName, + descriptorSource, + filePath => Files.readAllBytes(Paths.get(filePath))) + } + + private[shims] def invokeBuildDescriptor( + buildMethod: Method, + module: AnyRef, + messageName: String, + descriptorSource: ProtobufDescriptorSource, + readDescriptorFile: String => Array[Byte]): AnyRef = { + descriptorSource match { + case ProtobufDescriptorSource.DescriptorBytes(bytes) => + buildMethod.invoke(module, messageName, Some(bytes)).asInstanceOf[AnyRef] + case ProtobufDescriptorSource.DescriptorPath(filePath) => + try { + buildMethod.invoke(module, messageName, Some(filePath)).asInstanceOf[AnyRef] + } catch { + // Spark 3.5+ changed the descriptor payload from Option[String] to Option[Array[Byte]] + // while keeping the same erased JVM signature. Retry with file contents when the + // path-based invocation clearly hit that binary-descriptor variant. 
+ case ex: java.lang.reflect.InvocationTargetException + if ex.getCause.isInstanceOf[ClassCastException] || + ex.getCause.isInstanceOf[MatchError] => + buildMethod.invoke( + module, messageName, Some(readDescriptorFile(filePath))).asInstanceOf[AnyRef] + } + } + } + + private def typeName(t: AnyRef): String = + if (t == null) "null" else Try(PbReflect.invoke0[String](t, "name")).getOrElse(t.toString) + + private final class ReflectiveMessageDescriptor(raw: AnyRef) extends ProtobufMessageDescriptor { + override lazy val syntax: String = PbReflect.getFileSyntax(raw, typeName) + + override def findField(name: String): Option[ProtobufFieldDescriptor] = + Option(PbReflect.findFieldByName(raw, name)).map(new ReflectiveFieldDescriptor(_)) + } + + private final class ReflectiveFieldDescriptor(raw: AnyRef) extends ProtobufFieldDescriptor { + override lazy val name: String = PbReflect.invoke0[String](raw, "getName") + override lazy val fieldNumber: Int = PbReflect.getFieldNumber(raw) + override lazy val protoTypeName: String = typeName(PbReflect.getFieldType(raw)) + override lazy val isRepeated: Boolean = PbReflect.isRepeated(raw) + override lazy val isRequired: Boolean = PbReflect.isRequired(raw) + override lazy val enumMetadata: Option[ProtobufEnumMetadata] = + if (protoTypeName == "ENUM") { + Some(ProtobufEnumMetadata(PbReflect.getEnumValues(PbReflect.getEnumType(raw)))) + } else { + None + } + override lazy val defaultValueResult: Either[String, Option[ProtobufDefaultValue]] = + Try { + if (PbReflect.hasDefaultValue(raw)) { + PbReflect.getDefaultValue(raw).map(toDefaultValue(_, protoTypeName, enumMetadata)) + } else { + None + } + }.toEither.left.map { t => + s"Failed to read protobuf default value for field '$name': ${t.getMessage}" + } + override lazy val messageDescriptor: Option[ProtobufMessageDescriptor] = + if (protoTypeName == "MESSAGE") { + Some(new ReflectiveMessageDescriptor(PbReflect.getMessageType(raw))) + } else { + None + } + } + + private def 
toDefaultValue( + rawDefault: AnyRef, + protoTypeName: String, + enumMetadata: Option[ProtobufEnumMetadata]): ProtobufDefaultValue = protoTypeName match { + case "BOOL" => + ProtobufDefaultValue.BoolValue(rawDefault.asInstanceOf[java.lang.Boolean].booleanValue()) + case "FLOAT" => + ProtobufDefaultValue.FloatValue(rawDefault.asInstanceOf[java.lang.Float].floatValue()) + case "DOUBLE" => + ProtobufDefaultValue.DoubleValue(rawDefault.asInstanceOf[java.lang.Double].doubleValue()) + case "STRING" => + ProtobufDefaultValue.StringValue(if (rawDefault == null) null else rawDefault.toString) + case "BYTES" => + ProtobufDefaultValue.BinaryValue(extractBytes(rawDefault)) + case "ENUM" => + val number = extractNumber(rawDefault).intValue() + enumMetadata.map(_.enumDefault(number)) + .getOrElse(ProtobufDefaultValue.EnumValue(number, rawDefault.toString)) + case "INT32" | "UINT32" | "SINT32" | "FIXED32" | "SFIXED32" | + "INT64" | "UINT64" | "SINT64" | "FIXED64" | "SFIXED64" => + ProtobufDefaultValue.IntValue(extractNumber(rawDefault).longValue()) + case other => + throw new IllegalStateException( + s"Unsupported protobuf default value type '$other' for value ${rawDefault.toString}") + } + + private def extractNumber(rawDefault: AnyRef): java.lang.Number = rawDefault match { + case n: java.lang.Number => n + case ref: AnyRef => + Try { + ref.getClass.getMethod("getNumber").invoke(ref).asInstanceOf[java.lang.Number] + }.getOrElse { + throw new IllegalStateException( + s"Unsupported protobuf numeric default value class: ${ref.getClass.getName}") + } + case _ => + throw new IllegalStateException("Unexpected protobuf numeric default value") + } + + private def extractBytes(rawDefault: AnyRef): Array[Byte] = rawDefault match { + case bytes: Array[Byte] => bytes + case ref: AnyRef => + Try { + ref.getClass.getMethod("toByteArray").invoke(ref).asInstanceOf[Array[Byte]] + }.getOrElse { + throw new IllegalStateException( + s"Unsupported protobuf bytes default value class: 
${ref.getClass.getName}") + } + case _ => + throw new IllegalStateException("Unexpected protobuf bytes default value") + } + + private object PbReflect { + private val cache = new java.util.concurrent.ConcurrentHashMap[String, Method]() + + private def protobufJavaVersion: String = Try { + val rtCls = Class.forName("com.google.protobuf.RuntimeVersion") + val domain = rtCls.getField("DOMAIN").get(null) + val major = rtCls.getField("MAJOR").get(null) + val minor = rtCls.getField("MINOR").get(null) + val patch = rtCls.getField("PATCH").get(null) + s"$domain-$major.$minor.$patch" + }.getOrElse("unknown") + + private def cached(cls: Class[_], name: String, paramTypes: Class[_]*): Method = { + val key = s"${cls.getName}#$name(${paramTypes.map(_.getName).mkString(",")})" + cache.computeIfAbsent(key, _ => { + try { + cls.getMethod(name, paramTypes: _*) + } catch { + case ex: NoSuchMethodException => + throw new UnsupportedOperationException( + s"protobuf-java method not found: ${cls.getSimpleName}.$name " + + s"(protobuf-java version: $protobufJavaVersion). 
" + + s"This may indicate an incompatible protobuf-java library version.", + ex) + } + }) + } + + def invoke0[T](obj: AnyRef, method: String): T = + cached(obj.getClass, method).invoke(obj).asInstanceOf[T] + + def invoke1[T](obj: AnyRef, method: String, arg0Cls: Class[_], arg0: AnyRef): T = + cached(obj.getClass, method, arg0Cls).invoke(obj, arg0).asInstanceOf[T] + + def findFieldByName(msgDesc: AnyRef, name: String): AnyRef = + invoke1[AnyRef](msgDesc, "findFieldByName", classOf[String], name) + + def getFieldNumber(fd: AnyRef): Int = + invoke0[java.lang.Integer](fd, "getNumber").intValue() + + def getFieldType(fd: AnyRef): AnyRef = invoke0[AnyRef](fd, "getType") + + def isRepeated(fd: AnyRef): Boolean = + invoke0[java.lang.Boolean](fd, "isRepeated").booleanValue() + + def isRequired(fd: AnyRef): Boolean = + invoke0[java.lang.Boolean](fd, "isRequired").booleanValue() + + def hasDefaultValue(fd: AnyRef): Boolean = + invoke0[java.lang.Boolean](fd, "hasDefaultValue").booleanValue() + + def getDefaultValue(fd: AnyRef): Option[AnyRef] = + Option(invoke0[AnyRef](fd, "getDefaultValue")) + + def getMessageType(fd: AnyRef): AnyRef = invoke0[AnyRef](fd, "getMessageType") + + def getEnumType(fd: AnyRef): AnyRef = invoke0[AnyRef](fd, "getEnumType") + + def getEnumValues(enumType: AnyRef): Seq[ProtobufEnumValue] = { + import scala.collection.JavaConverters._ + val values = invoke0[java.util.List[_]](enumType, "getValues") + values.asScala.map { v => + val ev = v.asInstanceOf[AnyRef] + val num = invoke0[java.lang.Integer](ev, "getNumber").intValue() + val enumName = invoke0[String](ev, "getName") + ProtobufEnumValue(num, enumName) + }.toSeq + } + + def getFileSyntax(msgDesc: AnyRef, typeNameFn: AnyRef => String): String = Try { + val fileDesc = invoke0[AnyRef](msgDesc, "getFile") + val syntaxObj = invoke0[AnyRef](fileDesc, "getSyntax") + typeNameFn(syntaxObj) + }.getOrElse("") + } +} diff --git a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/ProtobufBatchMergeSuite.scala 
b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/ProtobufBatchMergeSuite.scala new file mode 100644 index 00000000000..1f2d4492108 --- /dev/null +++ b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/ProtobufBatchMergeSuite.scala @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, GetStructField, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.execution.{LeafExecNode, ProjectExec, SparkPlan} +import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch + +class ProtobufBatchMergeSuite extends AnyFunSuite { + + private case class DummyColumnarLeaf(output: Seq[AttributeReference]) extends LeafExecNode { + override protected def doExecute(): RDD[InternalRow] = + throw new UnsupportedOperationException("not needed for unit test") + + override protected def doExecuteColumnar(): RDD[ColumnarBatch] = + throw new UnsupportedOperationException("not needed for unit test") + + override def supportsColumnar: Boolean = true + } + + private case class FakeProtobufDataToCatalyst(child: 
Expression) extends UnaryExpression { + override def dataType: StructType = StructType(Seq( + StructField("search_id", IntegerType, nullable = true), + StructField("nested", StructType(Seq(StructField("value", IntegerType, nullable = true))), + nullable = true))) + override def nullable: Boolean = true + override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + throw new UnsupportedOperationException("codegen not needed in unit test") + } + + test("project meta detects extractor project over protobuf child project") { + val binAttr = AttributeReference("bin", BinaryType)() + val childScan = DummyColumnarLeaf(Seq(binAttr)) + val innerProject = ProjectExec( + Seq(Alias(FakeProtobufDataToCatalyst(binAttr), "decoded")()), + childScan) + val decodedAttr = innerProject.output.head.toAttribute + val outerProject = ProjectExec(Seq( + Alias(GetStructField(decodedAttr, 0, None), "search_id")(), + Alias(GetStructField(GetStructField(decodedAttr, 1, None), 0, None), "value")()), + innerProject) + + assert(GpuProjectExecMeta.shouldCoalesceAfterProject(outerProject)) + assert(!GpuProjectExecMeta.shouldCoalesceAfterProject(innerProject)) + } + + test("project meta detects direct protobuf extraction in same project") { + val binAttr = AttributeReference("bin", BinaryType)() + val childScan = DummyColumnarLeaf(Seq(binAttr)) + val directProject = ProjectExec( + Seq( + Alias(GetStructField(FakeProtobufDataToCatalyst(binAttr), 0, None), "search_id")(), + Alias( + GetStructField( + GetStructField(FakeProtobufDataToCatalyst(binAttr), 1, None), + 0, + None), + "value")()), + childScan) + + assert(GpuProjectExecMeta.shouldCoalesceAfterProject(directProject)) + } + + test("protobuf batch merge config defaults off and can be enabled") { + val enabledConf = new RapidsConf(Map( + RapidsConf.ENABLE_PROTOBUF_BATCH_MERGE_AFTER_PROJECT.key -> "true")) + + assert(!new 
RapidsConf(Map.empty[String, String]).isProtobufBatchMergeAfterProjectEnabled) + assert(enabledConf.isProtobufBatchMergeAfterProjectEnabled) + } + + test("flagged gpu project drops output batching guarantee for post-project merge") { + val childAttr = AttributeReference("value", IntegerType)() + val child: SparkPlan = DummyColumnarLeaf(Seq(childAttr)) + + val unflaggedProject = GpuProjectExec( + projectList = child.output.map(a => Alias(a, a.name)()).toList, + child = child, + enablePreSplit = true, + forcePostProjectCoalesce = false) + val flaggedProject = GpuProjectExec( + projectList = child.output.map(a => Alias(a, a.name)()).toList, + child = child, + enablePreSplit = true, + forcePostProjectCoalesce = true) + + assert(!unflaggedProject.coalesceAfter) + assert(flaggedProject.coalesceAfter) + assert(unflaggedProject.outputBatching.isInstanceOf[TargetSize]) + assert(flaggedProject.outputBatching == null) + assert(CoalesceGoal.satisfies(unflaggedProject.outputBatching, TargetSize(1L))) + assert(!CoalesceGoal.satisfies(flaggedProject.outputBatching, TargetSize(1L))) + } +} diff --git a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/shims/ProtobufExprShimsSuite.scala b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/shims/ProtobufExprShimsSuite.scala new file mode 100644 index 00000000000..1ab2c16cdf6 --- /dev/null +++ b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/shims/ProtobufExprShimsSuite.scala @@ -0,0 +1,451 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.shims + +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.GetArrayStructFields +import org.apache.spark.sql.rapids.{ + GpuFromProtobuf, + GpuGetArrayStructFieldsMeta, + GpuStructFieldOrdinalTag +} +import org.apache.spark.sql.rapids.protobuf._ +import org.apache.spark.sql.types._ + +class ProtobufExprShimsSuite extends AnyFunSuite { + private val outputSchema = StructType(Seq( + StructField("id", IntegerType, nullable = true), + StructField("name", StringType, nullable = true))) + + private case class FakeExprChild() extends Expression { + override def children: Seq[Expression] = Nil + override def nullable: Boolean = true + override def dataType: DataType = BinaryType + override def eval(input: org.apache.spark.sql.catalyst.InternalRow): Any = null + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + throw new UnsupportedOperationException("not needed") + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): Expression = { + assert(newChildren.isEmpty) + this + } + } + + private abstract class FakeBaseProtobufExpr(childExpr: Expression) extends UnaryExpression { + override def child: Expression = childExpr + override def nullable: Boolean = true + override def dataType: DataType = outputSchema + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + throw new UnsupportedOperationException("not needed") + override protected def withNewChildInternal(newChild: Expression): Expression = this + } + + private case class FakePathProtobufExpr(override val child: Expression) + extends FakeBaseProtobufExpr(child) { + def messageName: 
String = "test.Message" + def descFilePath: Option[String] = Some("/tmp/test.desc") + def options: scala.collection.Map[String, String] = Map("mode" -> "FAILFAST") + } + + private case class FakeBytesProtobufExpr(override val child: Expression) + extends FakeBaseProtobufExpr(child) { + def messageName: String = "test.Message" + def binaryDescriptorSet: Array[Byte] = Array[Byte](1, 2, 3) + def options: scala.collection.Map[String, String] = + Map("mode" -> "PERMISSIVE", "enums.as.ints" -> "true") + } + + private case class FakeMissingOptionsExpr(override val child: Expression) + extends FakeBaseProtobufExpr(child) { + def messageName: String = "test.Message" + def descFilePath: Option[String] = Some("/tmp/test.desc") + } + + private case class FakeDifferentMessageExpr(override val child: Expression) + extends FakeBaseProtobufExpr(child) { + def messageName: String = "test.OtherMessage" + def descFilePath: Option[String] = Some("/tmp/test.desc") + def options: scala.collection.Map[String, String] = Map("mode" -> "FAILFAST") + } + + private case class FakeDifferentDescriptorExpr(override val child: Expression) + extends FakeBaseProtobufExpr(child) { + def messageName: String = "test.Message" + def descFilePath: Option[String] = Some("/tmp/other.desc") + def options: scala.collection.Map[String, String] = Map("mode" -> "FAILFAST") + } + + private case class FakeDifferentOptionsExpr(override val child: Expression) + extends FakeBaseProtobufExpr(child) { + def messageName: String = "test.Message" + def descFilePath: Option[String] = Some("/tmp/test.desc") + def options: scala.collection.Map[String, String] = Map("mode" -> "PERMISSIVE") + } + + private case class FakeTypedUnaryExpr( + dt: DataType, + override val child: Expression = FakeExprChild()) extends UnaryExpression { + override def nullable: Boolean = true + override def dataType: DataType = dt + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + throw new 
+      UnsupportedOperationException("not needed")
+    override protected def withNewChildInternal(newChild: Expression): Expression = copy(child =
+      newChild)
+  }
+
+  // Minimal stand-ins exposing the two buildDescriptor signatures that
+  // SparkProtobufCompat.invokeBuildDescriptor dispatches to via reflection:
+  // the Spark 3.4 shape takes a descriptor-file path Option[String], the
+  // Spark 3.5 shape takes a serialized descriptor set Option[Array[Byte]].
+  private object FakeSpark34ProtobufUtils {
+    def buildDescriptor(messageName: String, descFilePath: Option[String]): String =
+      s"$messageName:${descFilePath.getOrElse("none")}"
+  }
+
+  private object FakeSpark35ProtobufUtils {
+    def buildDescriptor(messageName: String, binaryFileDescriptorSet: Option[Array[Byte]]): String =
+      s"$messageName:${binaryFileDescriptorSet.map(_.mkString(",")).getOrElse("none")}"
+  }
+
+  // Scriptable descriptor fakes: defaultValueError lets a test force
+  // defaultValueResult to yield Left(reason) without real protobuf reflection.
+  private case class FakeMessageDescriptor(
+      syntax: String,
+      fields: Map[String, ProtobufFieldDescriptor]) extends ProtobufMessageDescriptor {
+    override def findField(name: String): Option[ProtobufFieldDescriptor] = fields.get(name)
+  }
+
+  private case class FakeFieldDescriptor(
+      name: String,
+      fieldNumber: Int,
+      protoTypeName: String,
+      isRepeated: Boolean = false,
+      isRequired: Boolean = false,
+      defaultValue: Option[ProtobufDefaultValue] = None,
+      defaultValueError: Option[String] = None,
+      enumMetadata: Option[ProtobufEnumMetadata] = None,
+      messageDescriptor: Option[ProtobufMessageDescriptor] = None) extends ProtobufFieldDescriptor {
+    override lazy val defaultValueResult: Either[String, Option[ProtobufDefaultValue]] =
+      defaultValueError match {
+        case Some(reason) => Left(reason)
+        case None => Right(defaultValue)
+      }
+  }
+
+  test("compat extracts descriptor path and options from legacy expression") {
+    val exprInfo = SparkProtobufCompat.extractExprInfo(FakePathProtobufExpr(FakeExprChild()))
+    assert(exprInfo.isRight)
+    val info = exprInfo.toOption.get
+    assert(info.messageName == "test.Message")
+    assert(info.options == Map("mode" -> "FAILFAST"))
+    assert(info.descriptorSource ==
+      ProtobufDescriptorSource.DescriptorPath("/tmp/test.desc"))
+  }
+
+  test("compat extracts binary descriptor source and planner options") {
+    val exprInfo =
+      SparkProtobufCompat.extractExprInfo(FakeBytesProtobufExpr(FakeExprChild()))
+    assert(exprInfo.isRight)
+    val info = exprInfo.toOption.get
+    info.descriptorSource match {
+      case ProtobufDescriptorSource.DescriptorBytes(bytes) =>
+        assert(bytes.sameElements(Array[Byte](1, 2, 3)))
+      case other =>
+        fail(s"Unexpected descriptor source: $other")
+    }
+    val plannerOptions = SparkProtobufCompat.parsePlannerOptions(info.options)
+    assert(plannerOptions ==
+      Right(ProtobufPlannerOptions(enumsAsInts = true, failOnErrors = false)))
+  }
+
+  test("compat invokes Spark 3.4 descriptor builder with descriptor path") {
+    val buildMethod = FakeSpark34ProtobufUtils.getClass.getMethod(
+      "buildDescriptor", classOf[String], classOf[scala.Option[_]])
+
+    val result = SparkProtobufCompat.invokeBuildDescriptor(
+      buildMethod,
+      FakeSpark34ProtobufUtils,
+      "test.Message",
+      ProtobufDescriptorSource.DescriptorPath("/tmp/test.desc"),
+      _ => fail("path-to-bytes fallback should not be needed for Spark 3.4"))
+
+    assert(result == "test.Message:/tmp/test.desc")
+  }
+
+  test("compat retries descriptor path as bytes for Spark 3.5 descriptor builder") {
+    val buildMethod = FakeSpark35ProtobufUtils.getClass.getMethod(
+      "buildDescriptor", classOf[String], classOf[scala.Option[_]])
+    var readCalls = 0
+
+    val result = SparkProtobufCompat.invokeBuildDescriptor(
+      buildMethod,
+      FakeSpark35ProtobufUtils,
+      "test.Message",
+      ProtobufDescriptorSource.DescriptorPath("/tmp/test.desc"),
+      _ => {
+        readCalls += 1
+        Array[Byte](1, 2, 3)
+      })
+
+    assert(readCalls == 1)
+    assert(result == "test.Message:1,2,3")
+  }
+
+  test("compat passes bytes directly to Spark 3.5 descriptor builder") {
+    val buildMethod = FakeSpark35ProtobufUtils.getClass.getMethod(
+      "buildDescriptor", classOf[String], classOf[scala.Option[_]])
+
+    val result = SparkProtobufCompat.invokeBuildDescriptor(
+      buildMethod,
+      FakeSpark35ProtobufUtils,
+      "test.Message",
+      ProtobufDescriptorSource.DescriptorBytes(Array[Byte](4, 5, 6)),
+      _ => fail("binary descriptor source should not read a file"))
+
+    assert(result == "test.Message:4,5,6")
+  }
+
+  test("compat distinguishes decode semantics across message descriptor and options") {
+    val child = FakeExprChild()
+
+    assert(SparkProtobufCompat.sameDecodeSemantics(
+      FakePathProtobufExpr(child), FakePathProtobufExpr(child)))
+    assert(SparkProtobufCompat.sameDecodeSemantics(
+      FakeBytesProtobufExpr(child), FakeBytesProtobufExpr(child)))
+    assert(!SparkProtobufCompat.sameDecodeSemantics(
+      FakePathProtobufExpr(child), FakeDifferentMessageExpr(child)))
+    assert(!SparkProtobufCompat.sameDecodeSemantics(
+      FakePathProtobufExpr(child), FakeDifferentDescriptorExpr(child)))
+    assert(!SparkProtobufCompat.sameDecodeSemantics(
+      FakePathProtobufExpr(child), FakeDifferentOptionsExpr(child)))
+  }
+
+  test("compat reports missing options accessor as cpu fallback reason") {
+    val exprInfo = SparkProtobufCompat.extractExprInfo(FakeMissingOptionsExpr(FakeExprChild()))
+    assert(exprInfo.left.toOption.exists(
+      _.contains("Cannot read from_protobuf options via reflection")))
+  }
+
+  test("compat detects unsupported options and proto3 syntax") {
+    assert(SparkProtobufCompat.unsupportedOptions(Map("mode" -> "FAILFAST", "foo" -> "bar")) ==
+      Seq("foo"))
+    assert(!SparkProtobufCompat.isGpuSupportedProtoSyntax("PROTO3"))
+    assert(!SparkProtobufCompat.isGpuSupportedProtoSyntax("EDITIONS"))
+    assert(!SparkProtobufCompat.isGpuSupportedProtoSyntax(""))
+    assert(SparkProtobufCompat.isGpuSupportedProtoSyntax("PROTO2"))
+  }
+
+  test("extractor preserves typed enum defaults") {
+    val enumMeta = ProtobufEnumMetadata(Seq(
+      ProtobufEnumValue(0, "UNKNOWN"),
+      ProtobufEnumValue(1, "EN"),
+      ProtobufEnumValue(2, "ZH")))
+    val msgDesc = FakeMessageDescriptor(
+      syntax = "PROTO2",
+      fields = Map(
+        "language" -> FakeFieldDescriptor(
+          name = "language",
+          fieldNumber = 1,
+          protoTypeName = "ENUM",
+          defaultValue = Some(ProtobufDefaultValue.EnumValue(1, "EN")),
+          enumMetadata = Some(enumMeta))))
+    val schema = StructType(Seq(StructField("language", StringType, nullable = true)))
+
+    val infos = ProtobufSchemaExtractor.analyzeAllFields(
+      schema, msgDesc, enumsAsInts = false, "test.Message")
+
+    assert(infos.isRight)
+    assert(infos.toOption.get("language").defaultValue.contains(
+      ProtobufDefaultValue.EnumValue(1, "EN")))
+  }
+
+  test("extractor records reflection failures as unsupported field info") {
+    val msgDesc = FakeMessageDescriptor(
+      syntax = "PROTO2",
+      fields = Map(
+        "ok" -> FakeFieldDescriptor(
+          name = "ok",
+          fieldNumber = 1,
+          protoTypeName = "INT32"),
+        "id" -> FakeFieldDescriptor(
+          name = "id",
+          fieldNumber = 2,
+          protoTypeName = "INT32",
+          defaultValueError =
+            Some("Failed to read protobuf default value for field 'id': unsupported type"))))
+    val schema = StructType(Seq(
+      StructField("ok", IntegerType, nullable = true),
+      StructField("id", IntegerType, nullable = true)))
+
+    val infos = ProtobufSchemaExtractor.analyzeAllFields(
+      schema, msgDesc, enumsAsInts = true, "test.Message")
+
+    assert(infos.isRight)
+    assert(infos.toOption.get("ok").isSupported)
+    assert(!infos.toOption.get("id").isSupported)
+    assert(infos.toOption.get("id").unsupportedReason.exists(
+      _.contains("Failed to read protobuf default value for field 'id'")))
+  }
+
+  test("validator encodes enum-string defaults into both numeric and string payloads") {
+    val enumMeta = ProtobufEnumMetadata(Seq(
+      ProtobufEnumValue(0, "UNKNOWN"),
+      ProtobufEnumValue(1, "EN")))
+    val info = ProtobufFieldInfo(
+      fieldNumber = 2,
+      protoTypeName = "ENUM",
+      sparkType = StringType,
+      encoding = GpuFromProtobuf.ENC_ENUM_STRING,
+      isSupported = true,
+      unsupportedReason = None,
+      isRequired = false,
+      defaultValue = Some(ProtobufDefaultValue.EnumValue(1, "EN")),
+      enumMetadata = Some(enumMeta),
+      isRepeated = false)
+
+    val flat = ProtobufSchemaValidator.toFlattenedFieldDescriptor(
+      path = "common.language",
+      field = StructField("language", StringType, nullable = true),
+      fieldInfo = info,
+      parentIdx = 0,
+      depth = 1,
+      outputTypeId = 6)
+
+    assert(flat.isRight)
+    assert(flat.toOption.get.defaultInt == 1L)
+    assert(new String(flat.toOption.get.defaultString, "UTF-8") == "EN")
+    assert(flat.toOption.get.enumValidValues.sameElements(Array(0, 1)))
+    assert(flat.toOption.get.enumNames
+      .map(new String(_, "UTF-8"))
+      .sameElements(Array("UNKNOWN", "EN")))
+  }
+
+  test("validator rejects enum-string field without enum metadata") {
+    val info = ProtobufFieldInfo(
+      fieldNumber = 2,
+      protoTypeName = "ENUM",
+      sparkType = StringType,
+      encoding = GpuFromProtobuf.ENC_ENUM_STRING,
+      isSupported = true,
+      unsupportedReason = None,
+      isRequired = false,
+      defaultValue = Some(ProtobufDefaultValue.EnumValue(1, "EN")),
+      enumMetadata = None,
+      isRepeated = false)
+
+    val flat = ProtobufSchemaValidator.toFlattenedFieldDescriptor(
+      path = "common.language",
+      field = StructField("language", StringType, nullable = true),
+      fieldInfo = info,
+      parentIdx = 0,
+      depth = 1,
+      outputTypeId = 6)
+
+    assert(flat.left.toOption.exists(_.contains("missing enum metadata")))
+  }
+
+  test("validator returns Left for incompatible default type instead of throwing") {
+    val info = ProtobufFieldInfo(
+      fieldNumber = 3,
+      protoTypeName = "FLOAT",
+      sparkType = DoubleType,
+      encoding = GpuFromProtobuf.ENC_DEFAULT,
+      isSupported = true,
+      unsupportedReason = None,
+      isRequired = false,
+      defaultValue = Some(ProtobufDefaultValue.FloatValue(1.5f)),
+      enumMetadata = None,
+      isRepeated = false)
+
+    val flat = ProtobufSchemaValidator.toFlattenedFieldDescriptor(
+      path = "common.score",
+      field = StructField("score", DoubleType, nullable = true),
+      fieldInfo = info,
+      parentIdx = 0,
+      depth = 1,
+      outputTypeId = 6)
+
+    assert(flat.left.toOption.exists(
+      _.contains("Incompatible default value for protobuf field 'common.score'")))
+  }
+
+  test("array struct field meta uses pruned child field count after ordinal remap") {
+    val originalStruct = StructType(Seq(
+      StructField("a", IntegerType, nullable = true),
+      StructField("b", IntegerType, nullable = true),
+      StructField("c", IntegerType, nullable = true)))
+    val prunedStruct = StructType(Seq(StructField("b", IntegerType, nullable = true)))
+    val originalChild = FakeTypedUnaryExpr(ArrayType(originalStruct, containsNull = true))
+    val sparkExpr = GetArrayStructFields(
+      child = originalChild,
+      field = originalStruct.fields(1),
+      ordinal = 1,
+      numFields = originalStruct.fields.length,
+      containsNull = true)
+    sparkExpr.setTagValue(GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG, 0)
+
+    val prunedChild = FakeTypedUnaryExpr(ArrayType(prunedStruct, containsNull = true))
+    val runtimeOrd = sparkExpr.getTagValue(GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG).get
+
+    assert(runtimeOrd == 0)
+    assert(
+      GpuGetArrayStructFieldsMeta.effectiveNumFields(prunedChild, sparkExpr, runtimeOrd) == 1)
+  }
+
+  test("GpuFromProtobuf semantic equality is content-based for schema arrays") {
+    def emptyEnumNames: Array[Array[Byte]] = Array.empty[Array[Byte]]
+
+    val expr1 = GpuFromProtobuf(
+      decodedSchema = outputSchema,
+      fieldNumbers = Array(1, 2),
+      parentIndices = Array(-1, -1),
+      depthLevels = Array(0, 0),
+      wireTypes = Array(0, 2),
+      outputTypeIds = Array(3, 6),
+      encodings = Array(0, 0),
+      isRepeated = Array(false, false),
+      isRequired = Array(false, false),
+      hasDefaultValue = Array(false, false),
+      defaultInts = Array(0L, 0L),
+      defaultFloats = Array(0.0, 0.0),
+      defaultBools = Array(false, false),
+      defaultStrings = Array(Array.emptyByteArray, Array.emptyByteArray),
+      enumValidValues = Array(Array.emptyIntArray, Array.emptyIntArray),
+      enumNames = Array(emptyEnumNames, emptyEnumNames),
+      failOnErrors = true,
+      child = FakeExprChild())
+
+    val expr2 = GpuFromProtobuf(
+      decodedSchema = outputSchema,
+      fieldNumbers = Array(1, 2),
+      parentIndices = Array(-1, -1),
+      depthLevels = Array(0, 0),
+      wireTypes = Array(0, 2),
+      outputTypeIds = Array(3, 6),
+      encodings = Array(0, 0),
+      isRepeated = Array(false, false),
+      isRequired = Array(false, false),
+      hasDefaultValue = Array(false, false),
+      defaultInts = Array(0L, 0L),
+      defaultFloats = Array(0.0, 0.0),
+      defaultBools = Array(false, false),
+      defaultStrings = Array(Array.emptyByteArray, Array.emptyByteArray),
+      enumValidValues = Array(Array.emptyIntArray, Array.emptyIntArray),
+      enumNames = Array(emptyEnumNames.map(identity), emptyEnumNames.map(identity)),
+      failOnErrors = true,
+      child = FakeExprChild())
+
+    assert(expr1.semanticEquals(expr2))
+    assert(expr1.semanticHash() == expr2.semanticHash())
+  }
+}