diff --git a/integration_tests/pom.xml b/integration_tests/pom.xml index e3d91be0ce3..f178e84ffba 100644 --- a/integration_tests/pom.xml +++ b/integration_tests/pom.xml @@ -1,6 +1,6 @@ + + true diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala new file mode 100644 index 00000000000..cabc8d2905d --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.protobuf + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import com.google.protobuf.DescriptorProtos +import com.google.protobuf.Descriptors + +/** + * Minimal descriptor utilities for locating a message descriptor in a FileDescriptorSet. + * + * This is intentionally lightweight for the "simple types" from_protobuf patch: it supports + * descriptor sets produced by `protoc --include_imports --descriptor_set_out=...`. + * + * NOTE: This utility is currently not used in the initial implementation, which relies on + * Spark's ProtobufUtils via reflection (buildMessageDescriptorWithSparkProtobuf). This class + * is preserved for potential future use cases where direct descriptor parsing is needed + * without depending on Spark's shaded protobuf classes. + */ +object ProtobufDescriptorUtils { + + def buildMessageDescriptor( + fileDescriptorSetBytes: Array[Byte], + messageName: String): Descriptors.Descriptor = { + val fds = DescriptorProtos.FileDescriptorSet.parseFrom(fileDescriptorSetBytes) + val protos = fds.getFileList.asScala.toSeq + val byName = protos.map(p => p.getName -> p).toMap + val cache = mutable.HashMap.empty[String, Descriptors.FileDescriptor] + + def buildFileDescriptor(name: String): Descriptors.FileDescriptor = { + cache.getOrElseUpdate(name, { + val p = byName.getOrElse(name, + throw new IllegalArgumentException(s"Missing FileDescriptorProto for '$name'")) + val deps = p.getDependencyList.asScala.map(buildFileDescriptor _).toArray + Descriptors.FileDescriptor.buildFrom(p, deps) + }) + } + + val fileDescriptors = protos.map(p => buildFileDescriptor(p.getName)) + val candidates = fileDescriptors.iterator.flatMap(fd => findMessageDescriptors(fd, messageName)) + .toSeq + + candidates match { + case Seq(d) => d + case Seq() => + throw new IllegalArgumentException( + s"Message '$messageName' not found in FileDescriptorSet") + case many => + val names = many.map(_.getFullName).distinct.sorted + throw new IllegalArgumentException( + s"Message '$messageName' is ambiguous; matches: ${names.mkString(", ")}") + } + } + + private def findMessageDescriptors( + fd: Descriptors.FileDescriptor, + messageName: String): Iterator[Descriptors.Descriptor] = { + def matches(d: Descriptors.Descriptor): Boolean = { + d.getName == messageName || + d.getFullName == messageName || + d.getFullName.endsWith("." + messageName) + } + + def walk(d: Descriptors.Descriptor): Iterator[Descriptors.Descriptor] = { + val nested = d.getNestedTypes.asScala.iterator.flatMap(walk _) + if (matches(d)) Iterator.single(d) ++ nested else nested + } + + fd.getMessageTypes.asScala.iterator.flatMap(walk _) + } +} \ No newline at end of file diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala new file mode 100644 index 00000000000..0437a85a5a4 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala @@ -0,0 +1,232 @@ +/* + * Copyright (c) 2025-2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids + +import ai.rapids.cudf +import ai.rapids.cudf.{BinaryOp, CudfException, DType} +import com.nvidia.spark.rapids.{GpuColumnVector, GpuUnaryExpression} +import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.jni.Protobuf +import com.nvidia.spark.rapids.shims.NullIntolerantShim + +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression} +import org.apache.spark.sql.types._ + +/** + * GPU implementation for Spark's `from_protobuf` decode path. + * + * This is designed to replace `org.apache.spark.sql.protobuf.ProtobufDataToCatalyst` when + * supported. + * + * The implementation uses a two-pass approach in the CUDA kernel: + * - Pass 1: Scan all messages once, recording (offset, length) for each requested field + * - Pass 2: Extract data in parallel using the recorded locations + * + * This is significantly faster than per-field parsing when decoding multiple fields, + * as each message is only parsed once regardless of the number of fields. + * + * @param fullSchema The complete output schema (must match the original expression's dataType) + * @param decodedFieldIndices Indices into fullSchema for fields that will be decoded by GPU. + * Fields not in this array will be null columns. + * @param fieldNumbers Protobuf field numbers for decoded fields (parallel to decodedFieldIndices) + * @param cudfTypeIds cuDF type IDs for ALL fields in fullSchema + * @param cudfTypeScales Encodings for decoded fields (parallel to decodedFieldIndices) + * @param failOnErrors If true, throw exception on malformed data; if false, return null + */ +case class GpuFromProtobuf( + fullSchema: StructType, + decodedFieldIndices: Array[Int], + fieldNumbers: Array[Int], + cudfTypeIds: Array[Int], + cudfTypeScales: Array[Int], + failOnErrors: Boolean, + child: Expression) + extends GpuUnaryExpression with ExpectsInputTypes with NullIntolerantShim { + + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) + + override def dataType: DataType = fullSchema.asNullable + + override def nullable: Boolean = true + + // Lazy computation of unsupported field indices (complex types like StructType) + @transient + private lazy val unsupportedFieldIndices: Set[Int] = { + fullSchema.fields.zipWithIndex.collect { + case (sf, idx) if !GpuFromProtobuf.isTypeSupported(sf.dataType) => idx + }.toSet + } + + override protected def doColumnar(input: GpuColumnVector): cudf.ColumnVector = { + val numRows = input.getRowCount.toInt + + // Call the optimized JNI API that: + // 1. Uses fused kernel to scan all fields in one pass + // 2. Creates LIST directly for bytes fields (no intermediate strings column) + // 3. Returns struct with decoded fields + null columns for supported types + val jniResult = try { + Protobuf.decodeToStruct( + input.getBase, + fullSchema.fields.length, // total number of fields in output + decodedFieldIndices, // which fields to decode + fieldNumbers, // protobuf field numbers + cudfTypeIds, // types for ALL fields (INT8 placeholder for unsupported) + cudfTypeScales, // encodings for decoded fields + failOnErrors) + } catch { + case e: CudfException if failOnErrors => + // Re-throw as a SparkException for consistent error handling + throw new org.apache.spark.SparkException("Malformed protobuf message", e) + } + + // If there are fields with unsupported types, we need to replace placeholder columns + // with properly typed null columns + val result = if (unsupportedFieldIndices.isEmpty) { + jniResult + } else { + withResource(jniResult) { struct => + // Build children array, replacing placeholders with properly typed null columns + val children = new Array[cudf.ColumnVector](fullSchema.fields.length) + try { + for (i <- fullSchema.fields.indices) { + if (unsupportedFieldIndices.contains(i)) { + // Create properly typed null column for unsupported types + children(i) = GpuFromProtobuf.createNullColumn(fullSchema.fields(i).dataType, numRows) + } else { + // Copy the column from JNI result (incRefCount to share ownership) + children(i) = struct.getChildColumnView(i).copyToColumnVector() + } + } + cudf.ColumnVector.makeStruct(numRows, children: _*) + } finally { + children.foreach(c => if (c != null) c.close()) + } + } + } + + // Apply input nulls to output + if (input.getBase.hasNulls) { + withResource(result) { _ => + result.mergeAndSetValidity(BinaryOp.BITWISE_AND, input.getBase) + } + } else { + result + } + } +} + +object GpuFromProtobuf { + // Encodings from com.nvidia.spark.rapids.jni.Protobuf + val ENC_DEFAULT = 0 + val ENC_FIXED = 1 + val ENC_ZIGZAG = 2 + + /** + * Maps a Spark DataType to the corresponding cuDF native type ID. + * Note: The encoding (varint/zigzag/fixed) is determined by the protobuf field type, + * not the Spark data type, so it must be set separately based on the protobuf schema. + * + * @return Some(typeId) for supported types, None for unsupported types + */ + def sparkTypeToCudfIdOpt(dt: DataType): Option[Int] = dt match { + case BooleanType => Some(DType.BOOL8.getTypeId.getNativeId) + case IntegerType => Some(DType.INT32.getTypeId.getNativeId) + case LongType => Some(DType.INT64.getTypeId.getNativeId) + case FloatType => Some(DType.FLOAT32.getTypeId.getNativeId) + case DoubleType => Some(DType.FLOAT64.getTypeId.getNativeId) + case StringType => Some(DType.STRING.getTypeId.getNativeId) + case BinaryType => Some(DType.LIST.getTypeId.getNativeId) + case _ => None + } + + /** + * Check if a Spark DataType is supported by the GPU protobuf decoder. + */ + def isTypeSupported(dt: DataType): Boolean = sparkTypeToCudfIdOpt(dt).isDefined + + /** + * Create an all-null column of the specified Spark DataType. + * This is used for fields with unsupported types (nested structs, arrays, etc.) + * that are not decoded but need to be present in the output struct. + */ + def createNullColumn(dt: DataType, numRows: Int): cudf.ColumnVector = { + // Helper to create null arrays for boxed types + def nullBools = Array.fill[java.lang.Boolean](numRows)(null) + def nullInts = Array.fill[java.lang.Integer](numRows)(null) + def nullLongs = Array.fill[java.lang.Long](numRows)(null) + def nullFloats = Array.fill[java.lang.Float](numRows)(null) + def nullDoubles = Array.fill[java.lang.Double](numRows)(null) + + dt match { + case BooleanType => cudf.ColumnVector.fromBoxedBooleans(nullBools: _*) + case IntegerType => cudf.ColumnVector.fromBoxedInts(nullInts: _*) + case LongType => cudf.ColumnVector.fromBoxedLongs(nullLongs: _*) + case FloatType => cudf.ColumnVector.fromBoxedFloats(nullFloats: _*) + case DoubleType => cudf.ColumnVector.fromBoxedDoubles(nullDoubles: _*) + case StringType => cudf.ColumnVector.fromStrings(Array.fill[String](numRows)(null): _*) + case BinaryType => + // Binary is LIST - create all-null list column using Scalar API + val elementType = new cudf.HostColumnVector.BasicType(true, DType.INT8) + withResource(cudf.Scalar.listFromNull(elementType)) { nullScalar => + cudf.ColumnVector.fromScalar(nullScalar, numRows) + } + case st: StructType => + // Recursively create null columns for struct fields + val children = st.fields.map(f => createNullColumn(f.dataType, numRows)) + try { + withResource(cudf.ColumnVector.makeStruct(numRows, children: _*)) { structCol => + // Set all rows to null - mergeAndSetValidity returns a NEW column + withResource(cudf.ColumnVector.fromBoxedBooleans(nullBools: _*)) { nullMask => + structCol.mergeAndSetValidity(BinaryOp.BITWISE_AND, nullMask) + } + } + } finally { + children.foreach(_.close()) + } + case ArrayType(elementType, _) => + // Create empty arrays with all nulls using Scalar API + val cudfElementDType = sparkTypeToCudfIdOpt(elementType) + .map(id => DType.fromNative(id, 0)) + .getOrElse(DType.INT8) // fallback for nested complex types + val elemType = new cudf.HostColumnVector.BasicType(true, cudfElementDType) + withResource(cudf.Scalar.listFromNull(elemType)) { nullScalar => + cudf.ColumnVector.fromScalar(nullScalar, numRows) + } + case MapType(keyType, valueType, _) => + // Maps are represented as LIST> in cuDF + // For all-null maps, we create a list column with STRUCT element type + val cudfKeyDType = sparkTypeToCudfIdOpt(keyType) + .map(id => DType.fromNative(id, 0)) + .getOrElse(DType.INT8) + val cudfValueDType = sparkTypeToCudfIdOpt(valueType) + .map(id => DType.fromNative(id, 0)) + .getOrElse(DType.INT8) + // Create the struct type for map entries (key, value) + val keyFieldType = new cudf.HostColumnVector.BasicType(true, cudfKeyDType) + val valueFieldType = new cudf.HostColumnVector.BasicType(true, cudfValueDType) + val structType = new cudf.HostColumnVector.StructType(true, keyFieldType, valueFieldType) + // Create an all-null map column (list of structs) + withResource(cudf.Scalar.listFromNull(structType)) { nullScalar => + cudf.ColumnVector.fromScalar(nullScalar, numRows) + } + case _ => + // Fallback for any other types - create INT8 nulls as placeholder + // This should not happen in practice since unsupported types should be caught earlier + cudf.ColumnVector.fromBoxedBytes(Array.fill[java.lang.Byte](numRows)(null): _*) + } + } +} diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala new file mode 100644 index 00000000000..927747516fb --- /dev/null +++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala @@ -0,0 +1,540 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "340"} +{"spark": "341"} +{"spark": "342"} +{"spark": "343"} +{"spark": "344"} +{"spark": "350"} +{"spark": "351"} +{"spark": "352"} +{"spark": "353"} +{"spark": "354"} +{"spark": "355"} +{"spark": "356"} +{"spark": "357"} +{"spark": "400"} +{"spark": "401"} +spark-rapids-shim-json-lines ***/ + +package com.nvidia.spark.rapids.shims + +import java.nio.file.{Files, Path} + +import scala.collection.mutable +import scala.util.Try + +import ai.rapids.cudf.DType +import com.nvidia.spark.rapids._ + +import org.apache.spark.sql.catalyst.expressions.{ + AttributeReference, Expression, GetStructField, UnaryExpression +} +import org.apache.spark.sql.execution.ProjectExec +import org.apache.spark.sql.rapids.GpuFromProtobuf +import org.apache.spark.sql.types._ + +/** + * Information about a protobuf field for schema projection support. + */ +private[shims] case class ProtobufFieldInfo( + fieldNumber: Int, + protoTypeName: String, + sparkType: DataType, + encoding: Int, + isSupported: Boolean, + unsupportedReason: Option[String] +) + +/** + * Spark 3.4+ optional integration for spark-protobuf expressions. + * + * spark-protobuf is an external module, so these rules must be registered by reflection. + */ +object ProtobufExprShims { + private[this] val protobufDataToCatalystClassName = + "org.apache.spark.sql.protobuf.ProtobufDataToCatalyst" + + private[this] val sparkProtobufUtilsObjectClassName = + "org.apache.spark.sql.protobuf.utils.ProtobufUtils$" + + def exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = { + try { + val clazz = ShimReflectionUtils.loadClass(protobufDataToCatalystClassName) + .asInstanceOf[Class[_ <: UnaryExpression]] + Map(clazz.asInstanceOf[Class[_ <: Expression]] -> fromProtobufRule) + } catch { + case _: ClassNotFoundException => Map.empty + } + } + + private def fromProtobufRule: ExprRule[_ <: Expression] = { + GpuOverrides.expr[UnaryExpression]( + "Decode a BinaryType column (protobuf) into a Spark SQL struct", + ExprChecks.unaryProject( + // Use TypeSig.all here because schema projection determines which fields + // actually need GPU support. Detailed type checking is done in tagExprForGpu. + TypeSig.all, + TypeSig.all, + TypeSig.BINARY, + TypeSig.BINARY), + (e, conf, p, r) => new UnaryExprMeta[UnaryExpression](e, conf, p, r) { + + // Full schema from the expression (must match original dataType for compatibility) + private var fullSchema: StructType = _ + // Indices into fullSchema for fields that will be decoded by GPU + private var decodedFieldIndices: Array[Int] = _ + private var fieldNumbers: Array[Int] = _ + // cudfTypeIds contains type IDs for ALL fields in fullSchema (for the new optimized API) + private var cudfTypeIds: Array[Int] = _ + // cudfTypeScales: encodings for decoded fields (parallel to decodedFieldIndices) + private var cudfTypeScales: Array[Int] = _ + private var failOnErrors: Boolean = _ + + override def tagExprForGpu(): Unit = { + fullSchema = e.dataType match { + case st: StructType => st + case other => + willNotWorkOnGpu( + s"Only StructType output is supported for from_protobuf, got $other") + return + } + + val options = getOptionsMap(e) + val supportedOptions = Set("enums.as.ints", "mode") + val unsupportedOptions = options.keys.filterNot(supportedOptions.contains) + if (unsupportedOptions.nonEmpty) { + val keys = unsupportedOptions.mkString(",") + willNotWorkOnGpu( + s"from_protobuf options are not supported yet on GPU: $keys") + return + } + + val enumsAsInts = options.getOrElse("enums.as.ints", "false").toBoolean + failOnErrors = options.getOrElse("mode", "PERMISSIVE").equalsIgnoreCase("FAILFAST") + val messageName = getMessageName(e) + val descFilePathOpt = getDescFilePath(e).orElse { + // Newer Spark may embed a descriptor set (binaryDescriptorSet). Write it to a temp file + // so we can reuse Spark's ProtobufUtils (and its shaded protobuf classes) to resolve + // the descriptor. + getDescriptorBytes(e).map(writeTempDescFile) + } + if (descFilePathOpt.isEmpty) { + willNotWorkOnGpu( + "from_protobuf requires a descriptor set " + + "(descFilePath or binaryDescriptorSet)") + return + } + + val msgDesc = try { + // Spark 3.4.x builds the descriptor as: + // ProtobufUtils.buildDescriptor(messageName, descFilePathOpt) + buildMessageDescriptorWithSparkProtobuf(messageName, descFilePathOpt) + } catch { + case t: Throwable => + willNotWorkOnGpu( + s"Failed to resolve protobuf descriptor for message '$messageName': " + + s"${t.getMessage}") + return + } + + // Step 1: Analyze all fields and build field info map + val allFieldsInfo = analyzeAllFields(fullSchema, msgDesc, enumsAsInts, messageName) + if (allFieldsInfo.isEmpty) { + // Error was already reported in analyzeAllFields + return + } + val fieldsInfoMap = allFieldsInfo.get + + // Step 2: Determine which fields are actually required by downstream operations + val requiredFieldNames = analyzeRequiredFields(fieldsInfoMap.keySet) + + // Step 3: Check if all required fields are supported + val unsupportedRequired = requiredFieldNames.filter { name => + fieldsInfoMap.get(name).exists(!_.isSupported) + } + + if (unsupportedRequired.nonEmpty) { + val reasons = unsupportedRequired.map { name => + val info = fieldsInfoMap(name) + s"${name}: ${info.unsupportedReason.getOrElse("unknown reason")}" + } + willNotWorkOnGpu( + s"Required fields not supported for from_protobuf: ${reasons.mkString(", ")}") + return + } + + // Step 4: Identify which fields in fullSchema need to be decoded + // These are fields that are required AND supported + val indicesToDecode = fullSchema.fields.zipWithIndex.collect { + case (sf, idx) if requiredFieldNames.contains(sf.name) => idx + } + decodedFieldIndices = indicesToDecode + + // Step 5: Build cudfTypeIds for ALL fields in fullSchema + // For unsupported types (nested struct, array, etc.), use INT8 as placeholder. + // These placeholder columns will be replaced with properly typed null columns in Scala. + cudfTypeIds = fullSchema.fields.map { sf => + GpuFromProtobuf.sparkTypeToCudfIdOpt(sf.dataType) + .getOrElse(DType.INT8.getTypeId.getNativeId) // placeholder for unsupported types + } + + // Step 6: Build arrays for decoded fields only (parallel to decodedFieldIndices) + val fnums = new Array[Int](indicesToDecode.length) + val scales = new Array[Int](indicesToDecode.length) + + indicesToDecode.zipWithIndex.foreach { case (schemaIdx, arrIdx) => + val sf = fullSchema.fields(schemaIdx) + val info = fieldsInfoMap(sf.name) + fnums(arrIdx) = info.fieldNumber + scales(arrIdx) = info.encoding + } + + fieldNumbers = fnums + cudfTypeScales = scales + } + + /** + * Analyze all fields in the schema and build a map of field name to ProtobufFieldInfo. + * Returns None if there's an error that should abort processing. + */ + private def analyzeAllFields( + schema: StructType, + msgDesc: AnyRef, + enumsAsInts: Boolean, + messageName: String): Option[Map[String, ProtobufFieldInfo]] = { + val result = mutable.Map[String, ProtobufFieldInfo]() + + for (sf <- schema.fields) { + val fd = invoke1[AnyRef](msgDesc, "findFieldByName", classOf[String], sf.name) + if (fd == null) { + willNotWorkOnGpu( + s"Protobuf field '${sf.name}' not found in message '$messageName'") + return None + } + + val isRepeated = Try { + invoke0[java.lang.Boolean](fd, "isRepeated").booleanValue() + }.getOrElse(false) + + val protoType = invoke0[AnyRef](fd, "getType") + val protoTypeName = typeName(protoType) + val fieldNumber = invoke0[java.lang.Integer](fd, "getNumber").intValue() + + // Check field support and determine encoding + val (isSupported, unsupportedReason, encoding) = + checkFieldSupport(sf.dataType, protoTypeName, isRepeated, enumsAsInts) + + result(sf.name) = ProtobufFieldInfo( + fieldNumber = fieldNumber, + protoTypeName = protoTypeName, + sparkType = sf.dataType, + encoding = encoding, + isSupported = isSupported, + unsupportedReason = unsupportedReason + ) + } + + Some(result.toMap) + } + + /** + * Check if a field type is supported and return encoding information. + * @return (isSupported, unsupportedReason, encoding) + */ + private def checkFieldSupport( + sparkType: DataType, + protoTypeName: String, + isRepeated: Boolean, + enumsAsInts: Boolean): (Boolean, Option[String], Int) = { + + if (isRepeated) { + return (false, Some("repeated fields are not supported"), GpuFromProtobuf.ENC_DEFAULT) + } + + // Check Spark type is one of the supported simple types + sparkType match { + case BooleanType | IntegerType | LongType | FloatType | DoubleType | + StringType | BinaryType => + // Supported Spark type, continue to check encoding + case other => + return (false, Some(s"unsupported Spark type: $other"), GpuFromProtobuf.ENC_DEFAULT) + } + + // Determine encoding based on Spark type and proto type combination + val encoding = (sparkType, protoTypeName) match { + case (BooleanType, "BOOL") => Some(GpuFromProtobuf.ENC_DEFAULT) + case (IntegerType, "INT32" | "UINT32") => Some(GpuFromProtobuf.ENC_DEFAULT) + case (IntegerType, "SINT32") => Some(GpuFromProtobuf.ENC_ZIGZAG) + case (IntegerType, "FIXED32" | "SFIXED32") => Some(GpuFromProtobuf.ENC_FIXED) + case (LongType, "INT64" | "UINT64") => Some(GpuFromProtobuf.ENC_DEFAULT) + case (LongType, "SINT64") => Some(GpuFromProtobuf.ENC_ZIGZAG) + case (LongType, "FIXED64" | "SFIXED64") => Some(GpuFromProtobuf.ENC_FIXED) + // Spark may upcast smaller integers to LongType + case (LongType, "INT32" | "UINT32" | "SINT32" | "FIXED32" | "SFIXED32") => + val enc = protoTypeName match { + case "SINT32" => GpuFromProtobuf.ENC_ZIGZAG + case "FIXED32" | "SFIXED32" => GpuFromProtobuf.ENC_FIXED + case _ => GpuFromProtobuf.ENC_DEFAULT + } + Some(enc) + case (FloatType, "FLOAT") => Some(GpuFromProtobuf.ENC_DEFAULT) + case (DoubleType, "DOUBLE") => Some(GpuFromProtobuf.ENC_DEFAULT) + case (StringType, "STRING") => Some(GpuFromProtobuf.ENC_DEFAULT) + case (BinaryType, "BYTES") => Some(GpuFromProtobuf.ENC_DEFAULT) + case (IntegerType, "ENUM") if enumsAsInts => Some(GpuFromProtobuf.ENC_DEFAULT) + case _ => None + } + + encoding match { + case Some(enc) => (true, None, enc) + case None => + (false, + Some(s"type mismatch: Spark $sparkType vs Protobuf $protoTypeName"), + GpuFromProtobuf.ENC_DEFAULT) + } + } + + /** + * Analyze which fields are actually required by downstream operations. + * Currently supports analyzing parent Project expressions. + * + * @param allFieldNames All field names in the full schema + * @return Set of field names that are actually required + */ + private def analyzeRequiredFields(allFieldNames: Set[String]): Set[String] = { + // Try to find parent SparkPlanMeta and analyze downstream Project + val parentPlanOpt = findParentPlanMeta() + + parentPlanOpt match { + case Some(planMeta) => + // First, try to analyze the immediate parent + analyzeDownstreamProject(planMeta) match { + case Some(fields) if fields.nonEmpty => + // Successfully identified required fields via schema projection + fields + case _ => + // The immediate parent might be a ProjectExec that just aliases the output. + // Try to look at its parent (the grandparent) for GetStructField references. + planMeta.parent match { + case Some(grandParentMeta: SparkPlanMeta[_]) => + analyzeDownstreamProject(grandParentMeta) match { + case Some(fields) if fields.nonEmpty => fields + case _ => allFieldNames + } + case _ => allFieldNames + } + } + case None => + // No parent SparkPlanMeta found in the meta tree, assume all fields are needed + allFieldNames + } + } + + /** + * Find the parent SparkPlanMeta by traversing up the parent chain. + */ + private def findParentPlanMeta(): Option[SparkPlanMeta[_]] = { + def traverse(meta: Option[RapidsMeta[_, _, _]]): Option[SparkPlanMeta[_]] = { + meta match { + case Some(p: SparkPlanMeta[_]) => Some(p) + case Some(p: RapidsMeta[_, _, _]) => traverse(p.parent) + case _ => None + } + } + traverse(parent) + } + + /** + * Analyze a Project plan to find which struct fields are actually used. + * This looks for GetStructField expressions that reference our protobuf output. + */ + private def analyzeDownstreamProject(planMeta: SparkPlanMeta[_]): Option[Set[String]] = { + planMeta.wrapped match { + case p: ProjectExec => + // Collect all GetStructField references from the project list + val fieldRefs = mutable.Set[String]() + var hasDirectStructRef = false + + p.projectList.foreach { expr => + collectStructFieldReferences(expr, fieldRefs, hasDirectStructRefHolder = () => { + hasDirectStructRef = true + }) + } + + if (hasDirectStructRef) { + // If the entire struct is referenced directly (not via GetStructField), + // we need all fields + None + } else if (fieldRefs.nonEmpty) { + Some(fieldRefs.toSet) + } else { + // No GetStructField found - this shouldn't happen for valid plans + // where from_protobuf is followed by field access + None + } + case _ => + // Not a ProjectExec, cannot analyze schema projection + None + } + } + + /** + * Recursively collect field names from GetStructField expressions. + * Also tracks if the struct is used directly without field extraction. + */ + private def collectStructFieldReferences( + expr: Expression, + fieldRefs: mutable.Set[String], + hasDirectStructRefHolder: () => Unit): Unit = { + expr match { + case GetStructField(child, ordinal, nameOpt) => + // Check if this GetStructField extracts from our protobuf struct + if (isProtobufStructReference(child)) { + // Get field name from the schema using ordinal + val fieldName = nameOpt.getOrElse { + if (ordinal < fullSchema.fields.length) { + fullSchema.fields(ordinal).name + } else { + s"_$ordinal" + } + } + fieldRefs += fieldName + // Don't recurse into child - we've handled this protobuf reference + } else { + // Child is not a protobuf struct, recurse to check for nested access + collectStructFieldReferences(child, fieldRefs, hasDirectStructRefHolder) + } + + case _ => + // Check if this expression directly references our protobuf struct + // without extracting a field (e.g., passing the whole struct to a function) + if (isProtobufStructReference(expr)) { + hasDirectStructRefHolder() + } + // Recursively check children + expr.children.foreach { child => + collectStructFieldReferences(child, fieldRefs, hasDirectStructRefHolder) + } + } + } + + /** + * Check if an expression references the output of a protobuf decode expression. + * This can be either: + * 1. The ProtobufDataToCatalyst expression itself + * 2. An AttributeReference that references the output of ProtobufDataToCatalyst + * (when accessing from a downstream ProjectExec) + */ + private def isProtobufStructReference(expr: Expression): Boolean = { + // Check if expr is a ProtobufDataToCatalyst expression + if (expr.getClass.getName.contains("ProtobufDataToCatalyst")) { + return true + } + + // Check if expr is an AttributeReference with the same schema as our protobuf output + // This handles the case where GetStructField references a column from a parent Project + expr match { + case attr: AttributeReference => + // Check if the data type matches our full schema (struct type from protobuf) + attr.dataType match { + case st: StructType => + // Compare field names and types only. We intentionally do not compare + // nullable flags because schema transformations (like projections or + // certain optimizations) may change nullability while the underlying + // schema structure remains the same. For schema projection detection, + // matching names and types is sufficient to identify protobuf output. + st.fields.length == fullSchema.fields.length && + st.fields.zip(fullSchema.fields).forall { case (a, b) => + a.name == b.name && a.dataType == b.dataType + } + case _ => false + } + case _ => false + } + } + + override def convertToGpu(child: Expression): GpuExpression = { + GpuFromProtobuf( + fullSchema, decodedFieldIndices, fieldNumbers, cudfTypeIds, cudfTypeScales, + failOnErrors, child) + } + } + ) + } + + private def getMessageName(e: Expression): String = + invoke0[String](e, "messageName") + + /** + * Newer Spark versions may carry an in-expression descriptor set payload + * (e.g. binaryDescriptorSet). + * Spark 3.4.x does not, so callers should fall back to descFilePath(). + */ + private def getDescriptorBytes(e: Expression): Option[Array[Byte]] = { + // Spark 4.x/3.5+ (depending on the API): may be Array[Byte] or Option[Array[Byte]]. + val direct = Try(invoke0[Array[Byte]](e, "binaryDescriptorSet")).toOption + direct.orElse { + Try(invoke0[Option[Array[Byte]]](e, "binaryDescriptorSet")).toOption.flatten + } + } + + private def getDescFilePath(e: Expression): Option[String] = + Try(invoke0[Option[String]](e, "descFilePath")).toOption.flatten + + private def writeTempDescFile(descBytes: Array[Byte]): String = { + val tmp: Path = Files.createTempFile("spark-rapids-protobuf-desc-", ".desc") + Files.write(tmp, descBytes) + // deleteOnExit() is not guaranteed to run on abnormal JVM termination, but these + // descriptor files are small (typically < 10KB) and only created when using + // binaryDescriptorSet (Spark 4.0+). The risk of temporary file accumulation is + // acceptable for this use case. + tmp.toFile.deleteOnExit() + tmp.toString + } + + private def buildMessageDescriptorWithSparkProtobuf( + messageName: String, + descFilePathOpt: Option[String]): AnyRef = { + val cls = ShimReflectionUtils.loadClass(sparkProtobufUtilsObjectClassName) + val module = cls.getField("MODULE$").get(null) + // buildDescriptor(messageName: String, descFilePath: Option[String]) + val m = cls.getMethod("buildDescriptor", classOf[String], classOf[scala.Option[_]]) + m.invoke(module, messageName, descFilePathOpt).asInstanceOf[AnyRef] + } + + private def typeName(t: AnyRef): String = { + if (t == null) { + "null" + } else { + // Prefer Enum.name() when available; fall back to toString. + Try(invoke0[String](t, "name")).getOrElse(t.toString) + } + } + + private def getOptionsMap(e: Expression): Map[String, String] = { + val opt = Try(invoke0[scala.collection.Map[String, String]](e, "options")).toOption + opt.map(_.toMap).getOrElse(Map.empty) + } + + private def invoke0[T](obj: AnyRef, method: String): T = + obj.getClass.getMethod(method).invoke(obj).asInstanceOf[T] + + private def invoke1[T](obj: AnyRef, method: String, arg0Cls: Class[_], arg0: AnyRef): T = + obj.getClass.getMethod(method, arg0Cls).invoke(obj, arg0).asInstanceOf[T] +} diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala index 6e28a071a00..56bfa229051 100644 --- a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala +++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. + * Copyright (c) 2022-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -162,7 +162,7 @@ trait Spark340PlusNonDBShims extends Spark331PlusNonDBShims { ), GpuElementAtMeta.elementAtRule(true) ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap - super.getExprs ++ shimExprs + super.getExprs ++ shimExprs ++ ProtobufExprShims.exprs } override def getDataWriteCmds: Map[Class[_ <: DataWritingCommand],