diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/ArrowColumnarToVeloxColumnarExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/ArrowColumnarToVeloxColumnarExec.scala index 920988675087..cca381ee8436 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/ArrowColumnarToVeloxColumnarExec.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/ArrowColumnarToVeloxColumnarExec.scala @@ -19,12 +19,19 @@ package org.apache.gluten.execution import org.apache.gluten.backendsapi.arrow.ArrowBatchTypes.ArrowNativeBatchType import org.apache.gluten.backendsapi.velox.VeloxBatchType import org.apache.gluten.columnarbatch.VeloxColumnarBatches +import org.apache.gluten.extension.columnar.transition.Convention import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.vectorized.ColumnarBatch case class ArrowColumnarToVeloxColumnarExec(override val child: SparkPlan) - extends ColumnarToColumnarExec(ArrowNativeBatchType, VeloxBatchType) { + extends ColumnarToColumnarExec(child) + with GlutenColumnarToColumnarTransition { + + override protected val from: Convention.BatchType = ArrowNativeBatchType + + override protected val to: Convention.BatchType = VeloxBatchType + override protected def mapIterator(in: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = { in.map(b => VeloxColumnarBatches.toVeloxBatch(b)) } diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxResizeBatchesExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxResizeBatchesExec.scala index a1ec54ffbc13..acca0c7b68aa 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxResizeBatchesExec.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxResizeBatchesExec.scala @@ -17,6 +17,7 @@ package org.apache.gluten.execution import org.apache.gluten.backendsapi.velox.VeloxBatchType +import org.apache.gluten.extension.columnar.transition.Convention import org.apache.gluten.iterator.ClosableIterator import org.apache.gluten.utils.VeloxBatchResizer @@ -35,7 +36,7 @@ case class VeloxResizeBatchesExec( override val child: SparkPlan, minOutputBatchSize: Int, maxOutputBatchSize: Int) - extends ColumnarToColumnarExec(VeloxBatchType, VeloxBatchType) { + extends ColumnarToColumnarExec(child) { override protected def mapIterator(in: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = { VeloxBatchResizer.create(minOutputBatchSize, maxOutputBatchSize, in.asJava).asScala @@ -54,4 +55,8 @@ case class VeloxResizeBatchesExec( override def outputOrdering: Seq[SortOrder] = child.outputOrdering override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) + + override def batchType(): Convention.BatchType = VeloxBatchType + + override def rowType0(): Convention.RowType = Convention.RowType.None } diff --git a/gluten-arrow/src/main/scala/org/apache/gluten/execution/LoadArrowDataExec.scala b/gluten-arrow/src/main/scala/org/apache/gluten/execution/LoadArrowDataExec.scala index 1f1750113ad5..133d67b0f164 100644 --- a/gluten-arrow/src/main/scala/org/apache/gluten/execution/LoadArrowDataExec.scala +++ b/gluten-arrow/src/main/scala/org/apache/gluten/execution/LoadArrowDataExec.scala @@ -18,6 +18,7 @@ package org.apache.gluten.execution import org.apache.gluten.backendsapi.arrow.ArrowBatchTypes.{ArrowJavaBatchType, ArrowNativeBatchType} import org.apache.gluten.columnarbatch.ColumnarBatches +import org.apache.gluten.extension.columnar.transition.Convention import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators import org.apache.spark.sql.execution.SparkPlan @@ -25,7 +26,13 @@ import org.apache.spark.sql.vectorized.ColumnarBatch /** Converts input data with batch type [[ArrowNativeBatchType]] to type [[ArrowJavaBatchType]]. */ case class LoadArrowDataExec(override val child: SparkPlan) - extends ColumnarToColumnarExec(ArrowNativeBatchType, ArrowJavaBatchType) { + extends ColumnarToColumnarExec(child) + with GlutenColumnarToColumnarTransition { + + override protected val from: Convention.BatchType = ArrowNativeBatchType + + override protected val to: Convention.BatchType = ArrowJavaBatchType + override protected def mapIterator(in: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = { in.map(b => ColumnarBatches.load(ArrowBufferAllocators.contextInstance, b)) } diff --git a/gluten-arrow/src/main/scala/org/apache/gluten/execution/OffloadArrowDataExec.scala b/gluten-arrow/src/main/scala/org/apache/gluten/execution/OffloadArrowDataExec.scala index 6e548adbf6cf..256c178fdb25 100644 --- a/gluten-arrow/src/main/scala/org/apache/gluten/execution/OffloadArrowDataExec.scala +++ b/gluten-arrow/src/main/scala/org/apache/gluten/execution/OffloadArrowDataExec.scala @@ -18,6 +18,7 @@ package org.apache.gluten.execution import org.apache.gluten.backendsapi.arrow.ArrowBatchTypes.{ArrowJavaBatchType, ArrowNativeBatchType} import org.apache.gluten.columnarbatch.ColumnarBatches +import org.apache.gluten.extension.columnar.transition.Convention import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators import org.apache.spark.sql.execution.SparkPlan @@ -25,7 +26,13 @@ import org.apache.spark.sql.vectorized.ColumnarBatch /** Converts input data with batch type [[ArrowJavaBatchType]] to type [[ArrowNativeBatchType]]. */ case class OffloadArrowDataExec(override val child: SparkPlan) - extends ColumnarToColumnarExec(ArrowJavaBatchType, ArrowNativeBatchType) { + extends ColumnarToColumnarExec(child) + with GlutenColumnarToColumnarTransition { + + override protected val from: Convention.BatchType = ArrowJavaBatchType + + override protected val to: Convention.BatchType = ArrowNativeBatchType + override protected def mapIterator(in: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = { in.map(b => ColumnarBatches.offload(ArrowBufferAllocators.contextInstance, b)) } diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/ColumnarToColumnarExec.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/ColumnarToColumnarExec.scala index f19b89898388..f825bc1f1810 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/ColumnarToColumnarExec.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/ColumnarToColumnarExec.scala @@ -16,25 +16,20 @@ */ package org.apache.gluten.execution -import org.apache.gluten.extension.columnar.transition.{Convention, ConventionReq} import org.apache.gluten.iterator.Iterators import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.vectorized.ColumnarBatch import java.util.concurrent.atomic.AtomicLong -abstract class ColumnarToColumnarExec(from: Convention.BatchType, to: Convention.BatchType) - extends ColumnarToColumnarTransition - with GlutenPlan { - - override def isSameConvention: Boolean = from == to - - def child: SparkPlan +abstract class ColumnarToColumnarExec(override val child: SparkPlan) + extends GlutenPlan + with UnaryExecNode { protected def mapIterator(in: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] @@ -51,17 +46,8 @@ abstract class ColumnarToColumnarExec(from: Convention.BatchType, to: Convention "selfTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to convert batches") ) - override def batchType(): Convention.BatchType = to - - override def rowType0(): Convention.RowType = { - Convention.RowType.None - } - - override def requiredChildConvention(): Seq[ConventionReq] = { - List(ConventionReq.ofBatch(ConventionReq.BatchType.Is(from))) - } - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override protected def doExecuteColumnar(): RDD[ColumnarBatch] = { val numInputRows = longMetric("numInputRows") val numInputBatches = longMetric("numInputBatches") diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/ColumnarToColumnarTransition.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/ColumnarToColumnarTransition.scala index 8c4757b45c83..8284db0f62d2 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/ColumnarToColumnarTransition.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/ColumnarToColumnarTransition.scala @@ -18,6 +18,9 @@ package org.apache.gluten.execution import org.apache.spark.sql.execution.UnaryExecNode -trait ColumnarToColumnarTransition extends UnaryExecNode { - def isSameConvention: Boolean -} +/** + * A columnar-to-columnar transition. By implementing this trait, the class will be seen by + * [[org.apache.gluten.extension.columnar.transition.RemoveTransitions]] and removed when that rule + * is executed. + */ +trait ColumnarToColumnarTransition extends UnaryExecNode diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/GlutenColumnarToColumnarTransition.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/GlutenColumnarToColumnarTransition.scala new file mode 100644 index 000000000000..d2f62f0bf911 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/GlutenColumnarToColumnarTransition.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.execution + +import org.apache.gluten.extension.columnar.transition.{Convention, ConventionReq} + +/** + * A convenience trait for [[GlutenPlan]] to implement [[ColumnarToColumnarTransition]] at the same + * time. Note the implementation class will be seen by + * [[org.apache.gluten.extension.columnar.transition.RemoveTransitions]] and removed when that rule + * is executed. + */ +trait GlutenColumnarToColumnarTransition extends ColumnarToColumnarTransition with GlutenPlan { + protected val from: Convention.BatchType + protected val to: Convention.BatchType + + override def batchType(): Convention.BatchType = to + + override def rowType0(): Convention.RowType = { + Convention.RowType.None + } + + override def requiredChildConvention(): Seq[ConventionReq] = { + List(ConventionReq.ofBatch(ConventionReq.BatchType.Is(from))) + } +} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/package.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/package.scala index 9f309b843f55..dbec061b15e3 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/package.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/package.scala @@ -69,7 +69,7 @@ package object transition { object ColumnarToColumnarLike { def unapply(plan: SparkPlan): Option[SparkPlan] = { plan match { - case c2c: ColumnarToColumnarTransition if !c2c.isSameConvention => + case c2c: ColumnarToColumnarTransition => Some(c2c.child) case _ => None } diff --git a/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuiteBase.scala b/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuiteBase.scala index 37240085094f..827881ebc95c 100644 --- a/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuiteBase.scala +++ b/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuiteBase.scala @@ -16,7 +16,7 @@ */ package org.apache.gluten.extension.columnar.transition -import org.apache.gluten.execution.{ColumnarToColumnarExec, GlutenPlan} +import org.apache.gluten.execution.{ColumnarToColumnarExec, GlutenColumnarToColumnarTransition, GlutenPlan} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -146,10 +146,11 @@ object TransitionSuiteBase { } case class BatchToBatch( - from: Convention.BatchType, - to: Convention.BatchType, + override val from: Convention.BatchType, + override val to: Convention.BatchType, override val child: SparkPlan) - extends ColumnarToColumnarExec(from, to) { + extends ColumnarToColumnarExec(child) + with GlutenColumnarToColumnarTransition { override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException()