From b9011b3dd7a1713a6d4300104ab1b4d4374e2dda Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Wed, 24 Sep 2025 18:58:18 +0200 Subject: [PATCH] [VL] Add convenient C2R API --- .../execution/VeloxColumnarToRowExec.scala | 126 ++++++++++-------- .../ColumnarCachedBatchSerializer.scala | 4 +- 2 files changed, 70 insertions(+), 60 deletions(-) diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala index 31cda32dad18..097c7d489785 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala @@ -26,7 +26,7 @@ import org.apache.gluten.vectorized.{NativeColumnarToRowInfo, NativeColumnarToRo import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.{BroadcastUtils, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types._ @@ -74,7 +74,7 @@ case class VeloxColumnarToRowExec(child: SparkPlan) extends ColumnarToRowExecBas child.executeColumnar().mapPartitions { it => VeloxColumnarToRowExec - .toRowIterator(it, output, numOutputRows, numInputBatches, convertTime) + .toRowIterator(it, numOutputRows, numInputBatches, convertTime) } } @@ -89,7 +89,7 @@ case class VeloxColumnarToRowExec(child: SparkPlan) extends ColumnarToRowExecBas sparkContext, mode, relation, - VeloxColumnarToRowExec.toRowIterator(_, output, numOutputRows, numInputBatches, convertTime)) + VeloxColumnarToRowExec.toRowIterator(_, numOutputRows, numInputBatches, convertTime)) } protected def withNewChildInternal(newChild: SparkPlan): VeloxColumnarToRowExec = @@ -98,23 +98,20 @@ case class VeloxColumnarToRowExec(child: SparkPlan) extends ColumnarToRowExecBas object VeloxColumnarToRowExec { - def toRowIterator( - batches: Iterator[ColumnarBatch], - output: Seq[Attribute]): Iterator[InternalRow] = { + def toRowIterator(batches: Iterator[ColumnarBatch]): Iterator[InternalRow] = { val numOutputRows = new SQLMetric("numOutputRows") val numInputBatches = new SQLMetric("numInputBatches") val convertTime = new SQLMetric("convertTime") toRowIterator( batches, - output, numOutputRows, numInputBatches, convertTime ) } + def toRowIterator( batches: Iterator[ColumnarBatch], - output: Seq[Attribute], numOutputRows: SQLMetric, numInputBatches: SQLMetric, convertTime: SQLMetric): Iterator[InternalRow] = { @@ -122,13 +119,9 @@ object VeloxColumnarToRowExec { return Iterator.empty } - val runtime = Runtimes.contextInstance(BackendsApiManager.getBackendName, "ColumnarToRow") - // TODO: Pass the jni jniWrapper and arrowSchema and serializeSchema method by broadcast. - val jniWrapper = NativeColumnarToRowJniWrapper.create(runtime) - val c2rId = jniWrapper.nativeColumnarToRowInit() + val converter = new Converter(convertTime) val res: Iterator[Iterator[InternalRow]] = new Iterator[Iterator[InternalRow]] { - override def hasNext: Boolean = { batches.hasNext } @@ -137,61 +130,80 @@ object VeloxColumnarToRowExec { val batch = batches.next() numInputBatches += 1 numOutputRows += batch.numRows() + converter.toRowIterator(batch) + } + } + Iterators + .wrap(res.flatten) + .protectInvocationFlow() // Spark may call `hasNext()` again after a false output which + // is not allowed by Gluten iterators. E.g. GroupedIterator#fetchNextGroupIterator + .recycleIterator { + converter.close() + } + .create() + } - if (batch.numRows == 0) { - batch.close() - return Iterator.empty - } + /** + * A convenient C2R API to allow caller converts batches on demand without having to pass in an + * Iterator[ColumnarBatch]. + */ + class Converter(convertTime: SQLMetric) { + private val runtime = + Runtimes.contextInstance(BackendsApiManager.getBackendName, "VeloxColumnarToRow") + // TODO: Pass the jni jniWrapper and arrowSchema and serializeSchema method by broadcast. + private val jniWrapper = NativeColumnarToRowJniWrapper.create(runtime) + private val c2rId = jniWrapper.nativeColumnarToRowInit() - if (output.isEmpty) { - val rows = ColumnarBatches.emptyRowIterator(batch.numRows()).asScala - batch.close() - return rows - } + def toRowIterator(batch: ColumnarBatch): Iterator[InternalRow] = { + if (batch.numRows() == 0) { + return Iterator.empty + } + + if (batch.numCols() == 0) { + val rows = ColumnarBatches.emptyRowIterator(batch.numRows()).asScala + return rows + } - VeloxColumnarBatches.checkVeloxBatch(batch) + VeloxColumnarBatches.checkVeloxBatch(batch) - val cols = batch.numCols() - val rows = batch.numRows() - val batchHandle = ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName, batch) - var info: NativeColumnarToRowInfo = null + new Iterator[InternalRow] { + private val cols = batch.numCols() + private val rows = batch.numRows() + private val batchHandle = + ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName, batch) - new Iterator[InternalRow] { - var rowId = 0 - var baseLength = 0 - val row = new UnsafeRow(cols) + // Mutable members. + private var rowId = 0 + private var baseLength = 0 + private val row = new UnsafeRow(cols) + private var info: NativeColumnarToRowInfo = _ - override def hasNext: Boolean = { - rowId < rows - } + override def hasNext: Boolean = { + rowId < rows + } - override def next: UnsafeRow = { - if (rowId == 0 || rowId == baseLength + info.lengths.length) { - baseLength = if (info == null) { - baseLength - } else { - baseLength + info.lengths.length - } - val before = System.currentTimeMillis() - info = jniWrapper.nativeColumnarToRowConvert(c2rId, batchHandle, rowId) - convertTime += (System.currentTimeMillis() - before) + override def next(): InternalRow = { + if (rowId == 0 || rowId == baseLength + info.lengths.length) { + baseLength = if (info == null) { + baseLength + } else { + baseLength + info.lengths.length } - val (offset, length) = - (info.offsets(rowId - baseLength), info.lengths(rowId - baseLength)) - row.pointTo(null, info.memoryAddress + offset, length) - rowId += 1 - row + val before = System.currentTimeMillis() + info = jniWrapper.nativeColumnarToRowConvert(c2rId, batchHandle, rowId) + convertTime += (System.currentTimeMillis() - before) } + val (offset, length) = + (info.offsets(rowId - baseLength), info.lengths(rowId - baseLength)) + row.pointTo(null, info.memoryAddress + offset, length) + rowId += 1 + row } } } - Iterators - .wrap(res.flatten) - .protectInvocationFlow() // Spark may call `hasNext()` again after a false output which - // is not allowed by Gluten iterators. E.g. GroupedIterator#fetchNextGroupIterator - .recycleIterator { - jniWrapper.nativeClose(c2rId) - } - .create() + + def close(): Unit = { + jniWrapper.nativeClose(c2rId) + } } } diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala index f8d6bd886b6b..80e5039e76da 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala @@ -152,9 +152,7 @@ class ColumnarCachedBatchSerializer extends CachedBatchSerializer with Logging { val rddColumnarBatch = convertCachedBatchToColumnarBatch(input, cacheAttributes, selectedAttributes, conf) - rddColumnarBatch.mapPartitions { - it => VeloxColumnarToRowExec.toRowIterator(it, selectedAttributes) - } + rddColumnarBatch.mapPartitions(it => VeloxColumnarToRowExec.toRowIterator(it)) } override def convertColumnarBatchToCachedBatch(