Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.gluten.vectorized.{NativeColumnarToRowInfo, NativeColumnarToRo
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.{BroadcastUtils, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -74,7 +74,7 @@ case class VeloxColumnarToRowExec(child: SparkPlan) extends ColumnarToRowExecBas
child.executeColumnar().mapPartitions {
it =>
VeloxColumnarToRowExec
.toRowIterator(it, output, numOutputRows, numInputBatches, convertTime)
.toRowIterator(it, numOutputRows, numInputBatches, convertTime)
}
}

Expand All @@ -89,7 +89,7 @@ case class VeloxColumnarToRowExec(child: SparkPlan) extends ColumnarToRowExecBas
sparkContext,
mode,
relation,
VeloxColumnarToRowExec.toRowIterator(_, output, numOutputRows, numInputBatches, convertTime))
VeloxColumnarToRowExec.toRowIterator(_, numOutputRows, numInputBatches, convertTime))
}

protected def withNewChildInternal(newChild: SparkPlan): VeloxColumnarToRowExec =
Expand All @@ -98,37 +98,30 @@ case class VeloxColumnarToRowExec(child: SparkPlan) extends ColumnarToRowExecBas

object VeloxColumnarToRowExec {

def toRowIterator(
batches: Iterator[ColumnarBatch],
output: Seq[Attribute]): Iterator[InternalRow] = {
def toRowIterator(batches: Iterator[ColumnarBatch]): Iterator[InternalRow] = {
val numOutputRows = new SQLMetric("numOutputRows")
val numInputBatches = new SQLMetric("numInputBatches")
val convertTime = new SQLMetric("convertTime")
toRowIterator(
batches,
output,
numOutputRows,
numInputBatches,
convertTime
)
}

def toRowIterator(
batches: Iterator[ColumnarBatch],
output: Seq[Attribute],
numOutputRows: SQLMetric,
numInputBatches: SQLMetric,
convertTime: SQLMetric): Iterator[InternalRow] = {
if (batches.isEmpty) {
return Iterator.empty
}

val runtime = Runtimes.contextInstance(BackendsApiManager.getBackendName, "ColumnarToRow")
// TODO: Pass the jni jniWrapper and arrowSchema and serializeSchema method by broadcast.
val jniWrapper = NativeColumnarToRowJniWrapper.create(runtime)
val c2rId = jniWrapper.nativeColumnarToRowInit()
val converter = new Converter(convertTime)

val res: Iterator[Iterator[InternalRow]] = new Iterator[Iterator[InternalRow]] {

override def hasNext: Boolean = {
batches.hasNext
}
Expand All @@ -137,61 +130,80 @@ object VeloxColumnarToRowExec {
val batch = batches.next()
numInputBatches += 1
numOutputRows += batch.numRows()
converter.toRowIterator(batch)
}
}
Iterators
.wrap(res.flatten)
.protectInvocationFlow() // Spark may call `hasNext()` again after a false output which
// is not allowed by Gluten iterators. E.g. GroupedIterator#fetchNextGroupIterator
.recycleIterator {
converter.close()
}
.create()
}

if (batch.numRows == 0) {
batch.close()
return Iterator.empty
}
/**
* A convenient C2R API to allow caller converts batches on demand without having to pass in an
* Iterator[ColumnarBatch].
*/
class Converter(convertTime: SQLMetric) {
private val runtime =
Runtimes.contextInstance(BackendsApiManager.getBackendName, "VeloxColumnarToRow")
// TODO: Pass the jni jniWrapper and arrowSchema and serializeSchema method by broadcast.
private val jniWrapper = NativeColumnarToRowJniWrapper.create(runtime)
private val c2rId = jniWrapper.nativeColumnarToRowInit()

if (output.isEmpty) {
val rows = ColumnarBatches.emptyRowIterator(batch.numRows()).asScala
batch.close()
return rows
}
def toRowIterator(batch: ColumnarBatch): Iterator[InternalRow] = {
if (batch.numRows() == 0) {
return Iterator.empty
}

if (batch.numCols() == 0) {
val rows = ColumnarBatches.emptyRowIterator(batch.numRows()).asScala
return rows
}

VeloxColumnarBatches.checkVeloxBatch(batch)
VeloxColumnarBatches.checkVeloxBatch(batch)

val cols = batch.numCols()
val rows = batch.numRows()
val batchHandle = ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName, batch)
var info: NativeColumnarToRowInfo = null
new Iterator[InternalRow] {
private val cols = batch.numCols()
private val rows = batch.numRows()
private val batchHandle =
ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName, batch)

new Iterator[InternalRow] {
var rowId = 0
var baseLength = 0
val row = new UnsafeRow(cols)
// Mutable members.
private var rowId = 0
private var baseLength = 0
private val row = new UnsafeRow(cols)
private var info: NativeColumnarToRowInfo = _

override def hasNext: Boolean = {
rowId < rows
}
override def hasNext: Boolean = {
rowId < rows
}

override def next: UnsafeRow = {
if (rowId == 0 || rowId == baseLength + info.lengths.length) {
baseLength = if (info == null) {
baseLength
} else {
baseLength + info.lengths.length
}
val before = System.currentTimeMillis()
info = jniWrapper.nativeColumnarToRowConvert(c2rId, batchHandle, rowId)
convertTime += (System.currentTimeMillis() - before)
override def next(): InternalRow = {
if (rowId == 0 || rowId == baseLength + info.lengths.length) {
baseLength = if (info == null) {
baseLength
} else {
baseLength + info.lengths.length
}
val (offset, length) =
(info.offsets(rowId - baseLength), info.lengths(rowId - baseLength))
row.pointTo(null, info.memoryAddress + offset, length)
rowId += 1
row
val before = System.currentTimeMillis()
info = jniWrapper.nativeColumnarToRowConvert(c2rId, batchHandle, rowId)
convertTime += (System.currentTimeMillis() - before)
}
val (offset, length) =
(info.offsets(rowId - baseLength), info.lengths(rowId - baseLength))
row.pointTo(null, info.memoryAddress + offset, length)
rowId += 1
row
}
}
}
Iterators
.wrap(res.flatten)
.protectInvocationFlow() // Spark may call `hasNext()` again after a false output which
// is not allowed by Gluten iterators. E.g. GroupedIterator#fetchNextGroupIterator
.recycleIterator {
jniWrapper.nativeClose(c2rId)
}
.create()

def close(): Unit = {
jniWrapper.nativeClose(c2rId)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,7 @@ class ColumnarCachedBatchSerializer extends CachedBatchSerializer with Logging {

val rddColumnarBatch =
convertCachedBatchToColumnarBatch(input, cacheAttributes, selectedAttributes, conf)
rddColumnarBatch.mapPartitions {
it => VeloxColumnarToRowExec.toRowIterator(it, selectedAttributes)
}
rddColumnarBatch.mapPartitions(it => VeloxColumnarToRowExec.toRowIterator(it))
}

override def convertColumnarBatchToCachedBatch(
Expand Down