From ac2e71eb750704918c5d40bfcf01c45e151f05c9 Mon Sep 17 00:00:00 2001 From: PHILO-HE Date: Tue, 3 Mar 2026 11:35:27 +0800 Subject: [PATCH] Initial --- .../clickhouse/ClickHouseSparkCatalog.scala | 7 +- .../clickhouse/ClickHouseSparkCatalog.scala | 7 +- .../clickhouse/ClickHouseSparkCatalog.scala | 5 +- .../clickhouse/CHIteratorApi.scala | 10 +- .../clickhouse/CHSparkPlanExecApi.scala | 15 +- .../backendsapi/velox/VeloxIteratorApi.scala | 4 +- .../VeloxBloomFilterMightContain.scala | 6 +- .../aggregate/VeloxBloomFilterAggregate.scala | 7 +- .../gluten/extension/ArrowConvertorRule.scala | 8 +- ...omFilterMightContainJointRewriteRule.scala | 39 ++++- .../execution/JoinExecTransformer.scala | 2 +- .../softaffinity/SoftAffinityManager.scala | 6 +- .../spark/shuffle/GlutenShuffleUtils.scala | 10 +- .../SoftAffinityWithRDDInfoSuite.scala | 164 +++++++++--------- .../apache/gluten/sql/shims/SparkShims.scala | 91 +--------- .../sql/shims/spark33/Spark33Shims.scala | 129 +------------- .../sql/shims/spark34/Spark34Shims.scala | 128 +------------- .../sql/shims/spark35/Spark35Shims.scala | 129 +------------- .../sql/shims/spark40/Spark40Shims.scala | 129 +------------- .../sql/shims/spark41/Spark41Shims.scala | 129 +------------- 20 files changed, 160 insertions(+), 865 deletions(-) diff --git a/backends-clickhouse/src-delta20/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseSparkCatalog.scala b/backends-clickhouse/src-delta20/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseSparkCatalog.scala index 61e1da44d0af..47b2ae2bd1a1 100644 --- a/backends-clickhouse/src-delta20/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseSparkCatalog.scala +++ b/backends-clickhouse/src-delta20/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseSparkCatalog.scala @@ -16,8 +16,6 @@ */ package org.apache.spark.sql.execution.datasources.v2.clickhouse -import org.apache.gluten.sql.shims.SparkShimLoader - import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException} @@ -35,6 +33,7 @@ import org.apache.spark.sql.delta.metering.DeltaLogging import org.apache.spark.sql.delta.sources.{DeltaSourceUtils, DeltaSQLConf} import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils} import org.apache.spark.sql.execution.datasources.v2.clickhouse.utils.CHDataSourceUtils +import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil import org.apache.spark.sql.sources.InsertableRelation import org.apache.spark.sql.types.StructType @@ -119,7 +118,7 @@ class ClickHouseSparkCatalog sourceQuery: Option[DataFrame], operation: TableCreationModes.CreationMode): Table = { val (partitionColumns, maybeBucketSpec) = - SparkShimLoader.getSparkShims.convertPartitionTransforms(partitions) + CatalogUtil.convertPartitionTransforms(partitions) var newSchema = schema var newPartitionColumns = partitionColumns var newBucketSpec = maybeBucketSpec @@ -232,7 +231,7 @@ class ClickHouseSparkCatalog case _ => true }.toMap val (partitionColumns, maybeBucketSpec) = - SparkShimLoader.getSparkShims.convertPartitionTransforms(partitions) + CatalogUtil.convertPartitionTransforms(partitions) var newSchema = schema var newPartitionColumns = partitionColumns var newBucketSpec = maybeBucketSpec diff --git 
a/backends-clickhouse/src-delta23/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseSparkCatalog.scala b/backends-clickhouse/src-delta23/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseSparkCatalog.scala index 61e1da44d0af..47b2ae2bd1a1 100644 --- a/backends-clickhouse/src-delta23/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseSparkCatalog.scala +++ b/backends-clickhouse/src-delta23/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseSparkCatalog.scala @@ -16,8 +16,6 @@ */ package org.apache.spark.sql.execution.datasources.v2.clickhouse -import org.apache.gluten.sql.shims.SparkShimLoader - import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException} @@ -35,6 +33,7 @@ import org.apache.spark.sql.delta.metering.DeltaLogging import org.apache.spark.sql.delta.sources.{DeltaSourceUtils, DeltaSQLConf} import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils} import org.apache.spark.sql.execution.datasources.v2.clickhouse.utils.CHDataSourceUtils +import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil import org.apache.spark.sql.sources.InsertableRelation import org.apache.spark.sql.types.StructType @@ -119,7 +118,7 @@ class ClickHouseSparkCatalog sourceQuery: Option[DataFrame], operation: TableCreationModes.CreationMode): Table = { val (partitionColumns, maybeBucketSpec) = - SparkShimLoader.getSparkShims.convertPartitionTransforms(partitions) + CatalogUtil.convertPartitionTransforms(partitions) var newSchema = schema var newPartitionColumns = partitionColumns var newBucketSpec = maybeBucketSpec @@ -232,7 +231,7 @@ class ClickHouseSparkCatalog case _ => true }.toMap val (partitionColumns, maybeBucketSpec) = - SparkShimLoader.getSparkShims.convertPartitionTransforms(partitions) + CatalogUtil.convertPartitionTransforms(partitions) var newSchema = schema var newPartitionColumns = partitionColumns var newBucketSpec = maybeBucketSpec diff --git a/backends-clickhouse/src-delta33/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseSparkCatalog.scala b/backends-clickhouse/src-delta33/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseSparkCatalog.scala index dde7013962d0..6873ad776f8b 100644 --- a/backends-clickhouse/src-delta33/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseSparkCatalog.scala +++ b/backends-clickhouse/src-delta33/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseSparkCatalog.scala @@ -16,8 +16,6 @@ */ package org.apache.spark.sql.execution.datasources.v2.clickhouse -import org.apache.gluten.sql.shims.SparkShimLoader - import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier @@ -39,6 +37,7 @@ import org.apache.spark.sql.delta.sources.{DeltaSourceUtils, DeltaSQLConf} import org.apache.spark.sql.delta.stats.StatisticsCollection import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils} import org.apache.spark.sql.execution.datasources.v2.clickhouse.utils.CHDataSourceUtils +import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil import org.apache.spark.sql.sources.InsertableRelation import 
org.apache.spark.sql.types.StructType @@ -136,7 +135,7 @@ class ClickHouseSparkCatalog sourceQuery: Option[DataFrame], operation: TableCreationModes.CreationMode): Table = { val (partitionColumns, maybeBucketSpec) = - SparkShimLoader.getSparkShims.convertPartitionTransforms(partitions) + CatalogUtil.convertPartitionTransforms(partitions) var newSchema = schema var newPartitionColumns = partitionColumns var newBucketSpec = maybeBucketSpec diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala index 2cd9d8516493..16dc2bccba92 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala @@ -199,14 +199,8 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { } partitionColumns.add(partitionColumn) - val (fileSize, modificationTime) = - SparkShimLoader.getSparkShims.getFileSizeAndModificationTime(file) - (fileSize, modificationTime) match { - case (Some(size), Some(time)) => - fileSizes.add(JLong.valueOf(size)) - modificationTimes.add(JLong.valueOf(time)) - case _ => - } + fileSizes.add(file.fileSize) + modificationTimes.add(file.modificationTime) val otherConstantMetadataColumnValues = DeltaShimLoader.getDeltaShims.convertRowIndexFilterIdEncoded( diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index 9208b48740ce..cdf2eae418b4 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -41,7 +41,7 @@ import org.apache.spark.shuffle.utils.CHShuffleUtil import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, CollectList, CollectSet} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate, CollectList, CollectSet} import org.apache.spark.sql.catalyst.optimizer.BuildSide import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, HashPartitioning, Partitioning, RangePartitioning} @@ -56,7 +56,7 @@ import org.apache.spark.sql.execution.joins.{BuildSideRelation, ClickHouseBuildS import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.utils.{CHExecUtil, PushDownUtil} import org.apache.spark.sql.execution.window._ -import org.apache.spark.sql.types.{DecimalType, StructType} +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SparkVersionUtil @@ -602,7 +602,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging { CHFlattenedExpression.sigOr ) ++ ExpressionExtensionTrait.expressionExtensionSigList ++ - SparkShimLoader.getSparkShims.bloomFilterExpressionMappings() + Seq( + Sig[BloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN), + Sig[BloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG) + ) } /** Define 
backend-specific expression converter. */ @@ -940,12 +943,6 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging { override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = generate - override def genDecimalRoundExpressionOutput( - decimalType: DecimalType, - toScale: Int): DecimalType = { - SparkShimLoader.getSparkShims.genDecimalRoundExpressionOutput(decimalType, toScale) - } - override def genWindowGroupLimitTransformer( partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala index 668e60b20542..575d3844fd53 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala @@ -83,9 +83,9 @@ class VeloxIteratorApi extends IteratorApi with Logging { val locations = filePartitions.flatMap(p => SoftAffinity.getFilePartitionLocations(p)) val (paths, starts, lengths) = getPartitionedFileInfo(partitionFiles).unzip3 val (fileSizes, modificationTimes) = partitionFiles - .map(f => SparkShimLoader.getSparkShims.getFileSizeAndModificationTime(f)) + .map(f => (f.fileSize, f.modificationTime)) .collect { - case (Some(size), Some(time)) => + case (size, time) => (JLong.valueOf(size), JLong.valueOf(time)) } .unzip diff --git a/backends-velox/src/main/scala/org/apache/gluten/expression/VeloxBloomFilterMightContain.scala b/backends-velox/src/main/scala/org/apache/gluten/expression/VeloxBloomFilterMightContain.scala index 220d98f40fcf..0e5ab142cb2a 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/expression/VeloxBloomFilterMightContain.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/expression/VeloxBloomFilterMightContain.scala @@ -16,12 +16,11 @@ */ package org.apache.gluten.expression -import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.gluten.utils.VeloxBloomFilter import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression} +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, BloomFilterMightContain, Expression} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper import org.apache.spark.sql.types.DataType @@ -43,8 +42,7 @@ case class VeloxBloomFilterMightContain( extends BinaryExpression { import VeloxBloomFilterMightContain._ - private val delegate = - SparkShimLoader.getSparkShims.newMightContain(bloomFilterExpression, valueExpression) + private val delegate = BloomFilterMightContain(bloomFilterExpression, valueExpression) override def prettyName: String = "velox_might_contain" diff --git a/backends-velox/src/main/scala/org/apache/gluten/expression/aggregate/VeloxBloomFilterAggregate.scala b/backends-velox/src/main/scala/org/apache/gluten/expression/aggregate/VeloxBloomFilterAggregate.scala index a3d6f738a2b5..0632e5a81f48 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/expression/aggregate/VeloxBloomFilterAggregate.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/expression/aggregate/VeloxBloomFilterAggregate.scala @@ -16,13 +16,12 @@ */ package org.apache.gluten.expression.aggregate -import org.apache.gluten.sql.shims.SparkShimLoader import 
org.apache.gluten.utils.VeloxBloomFilter import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate +import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.trees.TernaryLike import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -47,12 +46,12 @@ case class VeloxBloomFilterAggregate( extends TypedImperativeAggregate[BloomFilter] with TernaryLike[Expression] { - private val delegate = SparkShimLoader.getSparkShims.newBloomFilterAggregate[BloomFilter]( + private val delegate = BloomFilterAggregate( child, estimatedNumItemsExpression, numBitsExpression, mutableAggBufferOffset, - inputAggBufferOffset) + inputAggBufferOffset).asInstanceOf[TypedImperativeAggregate[BloomFilter]] override def prettyName: String = "velox_bloom_filter_agg" diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/ArrowConvertorRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/ArrowConvertorRule.scala index 925f2a6be94f..009674e810e2 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/ArrowConvertorRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/ArrowConvertorRule.scala @@ -19,14 +19,13 @@ package org.apache.gluten.extension import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.datasource.ArrowCSVFileFormat import org.apache.gluten.datasource.v2.ArrowCSVTable -import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.annotation.Experimental import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.PermissiveMode +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, PermissiveMode} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -102,6 +101,7 @@ case class ArrowConvertorRule(session: SparkSession) extends Rule[LogicalPlan] { } private def checkCsvOptions(csvOptions: CSVOptions, timeZone: String): Boolean = { + val default = new CSVOptions(CaseInsensitiveMap(Map()), csvOptions.columnPruning, timeZone) csvOptions.headerFlag && !csvOptions.multiLine && csvOptions.delimiter.length == 1 && csvOptions.quote == '\"' && @@ -112,7 +112,9 @@ case class ArrowConvertorRule(session: SparkSession) extends Rule[LogicalPlan] { csvOptions.nullValue == "" && csvOptions.emptyValueInRead == "" && csvOptions.comment == '\u0000' && csvOptions.columnPruning && - SparkShimLoader.getSparkShims.dateTimestampFormatInReadIsDefaultValue(csvOptions, timeZone) + csvOptions.dateFormatInRead == default.dateFormatInRead && + csvOptions.timestampFormatInRead == default.timestampFormatInRead && + csvOptions.timestampNTZFormatInRead == default.timestampNTZFormatInRead } } diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/BloomFilterMightContainJointRewriteRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/BloomFilterMightContainJointRewriteRule.scala index 9b743a4f22f5..efbedc6ca007 100644 --- 
a/backends-velox/src/main/scala/org/apache/gluten/extension/BloomFilterMightContainJointRewriteRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/BloomFilterMightContainJointRewriteRule.scala @@ -19,9 +19,10 @@ package org.apache.gluten.extension import org.apache.gluten.config.GlutenConfig import org.apache.gluten.expression.VeloxBloomFilterMightContain import org.apache.gluten.expression.aggregate.VeloxBloomFilterAggregate -import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, BloomFilterMightContain, Expression} +import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlan @@ -40,12 +41,42 @@ case class BloomFilterMightContainJointRewriteRule( out } + private def replaceBloomFilterAggregate[T]( + expr: Expression, + bloomFilterAggReplacer: ( + Expression, + Expression, + Expression, + Int, + Int) => TypedImperativeAggregate[T]): Expression = expr match { + case BloomFilterAggregate( + child, + estimatedNumItemsExpression, + numBitsExpression, + mutableAggBufferOffset, + inputAggBufferOffset) => + bloomFilterAggReplacer( + child, + estimatedNumItemsExpression, + numBitsExpression, + mutableAggBufferOffset, + inputAggBufferOffset) + case other => other + } + + private def replaceMightContain[T]( + expr: Expression, + mightContainReplacer: (Expression, Expression) => BinaryExpression): Expression = expr match { + case BloomFilterMightContain(bloomFilterExpression, valueExpression) => + mightContainReplacer(bloomFilterExpression, valueExpression) + case other => other + } + private def applyForNode(p: SparkPlan) = { p.transformExpressions { case e => - SparkShimLoader.getSparkShims.replaceMightContain( - SparkShimLoader.getSparkShims - .replaceBloomFilterAggregate(e, VeloxBloomFilterAggregate.apply), + replaceMightContain( + replaceBloomFilterAggregate(e, VeloxBloomFilterAggregate.apply), VeloxBloomFilterMightContain.apply) } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala index 980843f05c7a..e5db3385154d 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala @@ -55,7 +55,7 @@ trait ColumnarShuffledJoin extends BaseJoinExec { // partitioning doesn't satisfy `HashClusteredDistribution`. 
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil } else { - SparkShimLoader.getSparkShims.getDistribution(leftKeys, rightKeys) + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/softaffinity/SoftAffinityManager.scala b/gluten-substrait/src/main/scala/org/apache/gluten/softaffinity/SoftAffinityManager.scala index bba178a796fc..1691a336acbf 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/softaffinity/SoftAffinityManager.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/softaffinity/SoftAffinityManager.scala @@ -20,7 +20,6 @@ import org.apache.gluten.config.GlutenConfig import org.apache.gluten.hash.ConsistentHash import org.apache.gluten.logging.LogLevelUtil import org.apache.gluten.softaffinity.strategy.{ConsistentHashSoftAffinityStrategy, ExecutorNode} -import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging @@ -185,7 +184,7 @@ abstract class AffinityManager extends LogLevelUtil with Logging { val partitions = rddPartitionInfoMap.getIfPresent(rddId) if (partitions != null) { val key = partitions - .filter(p => p._1 == SparkShimLoader.getSparkShims.getPartitionId(event.taskInfo)) + .filter(p => p._1 == event.taskInfo.partitionId) .map(pInfo => s"${pInfo._2}_${pInfo._3}_${pInfo._4}") .sortBy(p => p) .mkString(",") @@ -322,8 +321,7 @@ object SoftAffinityManager extends AffinityManager { override lazy val detectDuplicateReading: Boolean = SparkEnv.get.conf.getBoolean( GlutenConfig.GLUTEN_SOFT_AFFINITY_DUPLICATE_READING_DETECT_ENABLED.key, GlutenConfig.GLUTEN_SOFT_AFFINITY_DUPLICATE_READING_DETECT_ENABLED.defaultValue.get - ) && - SparkShimLoader.getSparkShims.supportDuplicateReadingTracking + ) override lazy val duplicateReadingMaxCacheItems: Int = SparkEnv.get.conf.getInt( GlutenConfig.GLUTEN_SOFT_AFFINITY_DUPLICATE_READING_MAX_CACHE_ITEMS.key, diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala index 80b0e94830c9..213b2831f46c 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala @@ -18,10 +18,9 @@ package org.apache.spark.shuffle import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.config.GlutenConfig -import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.gluten.vectorized.NativePartitioning -import org.apache.spark.{SparkConf, TaskContext} +import org.apache.spark.{ShuffleUtils, SparkConf, TaskContext} import org.apache.spark.internal.config._ import org.apache.spark.shuffle.api.ShuffleExecutorComponents import org.apache.spark.shuffle.sort.ColumnarShuffleHandle @@ -120,12 +119,7 @@ object GlutenShuffleUtils { startPartition: Int, endPartition: Int ): Tuple2[Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], Boolean] = { - SparkShimLoader.getSparkShims.getShuffleReaderParam( - handle, - startMapIndex, - endMapIndex, - startPartition, - endPartition) + ShuffleUtils.getReaderParam(handle, startMapIndex, endMapIndex, startPartition, endPartition) } def getSortShuffleWriter[K, V]( diff --git a/gluten-substrait/src/test/scala/org/apache/spark/softaffinity/SoftAffinityWithRDDInfoSuite.scala b/gluten-substrait/src/test/scala/org/apache/spark/softaffinity/SoftAffinityWithRDDInfoSuite.scala index 
5bfb781e7aea..f3af21939b1b 100644 --- a/gluten-substrait/src/test/scala/org/apache/spark/softaffinity/SoftAffinityWithRDDInfoSuite.scala +++ b/gluten-substrait/src/test/scala/org/apache/spark/softaffinity/SoftAffinityWithRDDInfoSuite.scala @@ -51,97 +51,93 @@ class SoftAffinityWithRDDInfoSuite extends QueryTest with SharedSparkSession wit .set("spark.ui.enabled", "false") test("Soft Affinity Scheduler with duplicate reading detection") { - if (SparkShimLoader.getSparkShims.supportDuplicateReadingTracking) { - val addEvent0 = SparkListenerExecutorAdded( - System.currentTimeMillis(), - "0", - new ExecutorInfo("host-0", 3, null)) - val addEvent1 = SparkListenerExecutorAdded( - System.currentTimeMillis(), - "1", - new ExecutorInfo("host-1", 3, null)) - val removedEvent0 = SparkListenerExecutorRemoved(System.currentTimeMillis(), "0", "") - val removedEvent1 = SparkListenerExecutorRemoved(System.currentTimeMillis(), "1", "") - val rdd1 = new RDDInfo(1, "", 3, StorageLevel.NONE, false, Seq.empty) - val rdd2 = new RDDInfo(2, "", 3, StorageLevel.NONE, false, Seq.empty) - var stage1 = new StageInfo(1, 0, "", 1, Seq(rdd1, rdd2), Seq.empty, "", resourceProfileId = 0) - val stage1SubmitEvent = SparkListenerStageSubmitted(stage1) - val stage1EndEvent = SparkListenerStageCompleted(stage1) - val taskEnd1 = SparkListenerTaskEnd( - 1, + val addEvent0 = SparkListenerExecutorAdded( + System.currentTimeMillis(), + "0", + new ExecutorInfo("host-0", 3, null)) + val addEvent1 = SparkListenerExecutorAdded( + System.currentTimeMillis(), + "1", + new ExecutorInfo("host-1", 3, null)) + val removedEvent0 = SparkListenerExecutorRemoved(System.currentTimeMillis(), "0", "") + val removedEvent1 = SparkListenerExecutorRemoved(System.currentTimeMillis(), "1", "") + val rdd1 = new RDDInfo(1, "", 3, StorageLevel.NONE, false, Seq.empty) + val rdd2 = new RDDInfo(2, "", 3, StorageLevel.NONE, false, Seq.empty) + var stage1 = new StageInfo(1, 0, "", 1, Seq(rdd1, rdd2), Seq.empty, "", resourceProfileId = 0) + val stage1SubmitEvent = SparkListenerStageSubmitted(stage1) + val stage1EndEvent = SparkListenerStageCompleted(stage1) + val taskEnd1 = SparkListenerTaskEnd( + 1, + 0, + "", + org.apache.spark.Success, + // this is a little tricky here for 3.2 compatibility, we use -1 for partition id. + new TaskInfo(1, 1, 1, 1L, "0", "host-0", TaskLocality.ANY, false), + null, + null + ) + val files = Seq( + SparkShimLoader.getSparkShims.generatePartitionedFile( + InternalRow.empty, + "fakePath0", 0, - "", - org.apache.spark.Success, - // this is little tricky here for 3.2 compatibility, we use -1 for partition id. - new TaskInfo(1, 1, 1, 1L, "0", "host-0", TaskLocality.ANY, false), - null, - null - ) - val files = Seq( - SparkShimLoader.getSparkShims.generatePartitionedFile( - InternalRow.empty, - "fakePath0", - 0, - 100, - Array("host-3")), - SparkShimLoader.getSparkShims.generatePartitionedFile( - InternalRow.empty, - "fakePath0", - 100, - 200, - Array("host-3")) - ).toArray - val filePartition = FilePartition(-1, files) - val softAffinityListener = new SoftAffinityListener() - softAffinityListener.onExecutorAdded(addEvent0) - softAffinityListener.onExecutorAdded(addEvent1) - SoftAffinityManager.updatePartitionMap(filePartition, 1) - assert(SoftAffinityManager.rddPartitionInfoMap.size == 1) - softAffinityListener.onStageSubmitted(stage1SubmitEvent) - softAffinityListener.onTaskEnd(taskEnd1) - assert(SoftAffinityManager.duplicateReadingInfos.size == 1) - // check location (executor 0) of dulicate reading is returned.
- val locations = SoftAffinity.getFilePartitionLocations(filePartition) + 100, + Array("host-3")), + SparkShimLoader.getSparkShims.generatePartitionedFile( + InternalRow.empty, + "fakePath0", + 100, + 200, + Array("host-3")) + ).toArray + val filePartition = FilePartition(-1, files) + val softAffinityListener = new SoftAffinityListener() + softAffinityListener.onExecutorAdded(addEvent0) + softAffinityListener.onExecutorAdded(addEvent1) + SoftAffinityManager.updatePartitionMap(filePartition, 1) + assert(SoftAffinityManager.rddPartitionInfoMap.size == 1) + softAffinityListener.onStageSubmitted(stage1SubmitEvent) + softAffinityListener.onTaskEnd(taskEnd1) + assert(SoftAffinityManager.duplicateReadingInfos.size == 1) + // check location (executor 0) of duplicate reading is returned. + val locations = SoftAffinity.getFilePartitionLocations(filePartition) - assertResult(Set("executor_host-0_0")) { - locations.toSet - } - softAffinityListener.onStageCompleted(stage1EndEvent) - // stage 1 completed, check all middle status is cleared. - assert(SoftAffinityManager.rddPartitionInfoMap.size == 0) - assert(SoftAffinityManager.stageInfoMap.size == 0) - softAffinityListener.onExecutorRemoved(removedEvent0) - softAffinityListener.onExecutorRemoved(removedEvent1) - // executor 0 is removed, return empty. - assert(SoftAffinityManager.askExecutors(filePartition).isEmpty) + assertResult(Set("executor_host-0_0")) { + locations.toSet } + softAffinityListener.onStageCompleted(stage1EndEvent) + // stage 1 completed, check all middle status is cleared. + assert(SoftAffinityManager.rddPartitionInfoMap.size == 0) + assert(SoftAffinityManager.stageInfoMap.size == 0) + softAffinityListener.onExecutorRemoved(removedEvent0) + softAffinityListener.onExecutorRemoved(removedEvent1) + // executor 0 is removed, return empty. + assert(SoftAffinityManager.askExecutors(filePartition).isEmpty) } test("Duplicate reading detection limits middle states count") { // This test simulate the case listener bus stucks. We need to make sure the middle states // count would not exceed the configed threshold.
- if (SparkShimLoader.getSparkShims.supportDuplicateReadingTracking) { - val files = Seq( - SparkShimLoader.getSparkShims.generatePartitionedFile( - InternalRow.empty, - "fakePath0", - 0, - 100, - Array("host-3")), - SparkShimLoader.getSparkShims.generatePartitionedFile( - InternalRow.empty, - "fakePath0", - 100, - 200, - Array("host-3")) - ).toArray - val filePartition = FilePartition(-1, files) - FakeSoftAffinityManager.updatePartitionMap(filePartition, 1) - assert(FakeSoftAffinityManager.rddPartitionInfoMap.size == 1) - val filePartition1 = FilePartition(-1, files) - FakeSoftAffinityManager.updatePartitionMap(filePartition1, 2) - assert(FakeSoftAffinityManager.rddPartitionInfoMap.size == 1) - assert(FakeSoftAffinityManager.stageInfoMap.size <= 1) - } + val files = Seq( + SparkShimLoader.getSparkShims.generatePartitionedFile( + InternalRow.empty, + "fakePath0", + 0, + 100, + Array("host-3")), + SparkShimLoader.getSparkShims.generatePartitionedFile( + InternalRow.empty, + "fakePath0", + 100, + 200, + Array("host-3")) + ).toArray + val filePartition = FilePartition(-1, files) + FakeSoftAffinityManager.updatePartitionMap(filePartition, 1) + assert(FakeSoftAffinityManager.rddPartitionInfoMap.size == 1) + val filePartition1 = FilePartition(-1, files) + FakeSoftAffinityManager.updatePartitionMap(filePartition1, 2) + assert(FakeSoftAffinityManager.rddPartitionInfoMap.size == 1) + assert(FakeSoftAffinityManager.stageInfoMap.size <= 1) } } diff --git a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala index 1e03b7921a10..1f6d015393f1 100644 --- a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala +++ b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala @@ -22,21 +22,15 @@ import org.apache.gluten.expression.Sig import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.io.FileCommitProtocol -import org.apache.spark.scheduler.TaskInfo -import org.apache.spark.shuffle.ShuffleHandle import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.csv.CSVOptions -import org.apache.spark.sql.catalyst.expressions.{Attribute, BinaryExpression, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RaiseError, UnBase64} -import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RaiseError, UnBase64} import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.Table -import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{InputPartition, Scan} import org.apache.spark.sql.connector.read.streaming.SparkDataStream import org.apache.spark.sql.execution._ @@ -44,13 +38,10 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.datasources._ import 
org.apache.spark.sql.execution.datasources.parquet.ParquetFilters import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanExecBase} -import org.apache.spark.sql.execution.datasources.v2.text.TextScan import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike} import org.apache.spark.sql.execution.window.WindowGroupLimitExecShim import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DecimalType, StructType} -import org.apache.spark.sql.util.CaseInsensitiveStringMap -import org.apache.spark.storage.{BlockId, BlockManagerId} import org.apache.spark.util.SparkShimVersionUtil import org.apache.hadoop.fs.{FileStatus, Path} @@ -85,9 +76,6 @@ object SparkShimDescriptor { } trait SparkShims { - // for this purpose, change HashClusteredDistribution to ClusteredDistribution - // https://github.com/apache/spark/pull/32875 - def getDistribution(leftKeys: Seq[Expression], rightKeys: Seq[Expression]): Seq[Distribution] def scalarExpressionMappings: Seq[Sig] @@ -95,24 +83,12 @@ trait SparkShims { def runtimeReplaceableExpressionMappings: Seq[Sig] - def convertPartitionTransforms(partitions: Seq[Transform]): (Seq[String], Option[BucketSpec]) - def generateFileScanRDD( sparkSession: SparkSession, readFunction: PartitionedFile => Iterator[InternalRow], filePartitions: Seq[FilePartition], fileSourceScanExec: FileSourceScanExec): FileScanRDD - def getTextScan( - sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex, - dataSchema: StructType, - readDataSchema: StructType, - readPartitionSchema: StructType, - options: CaseInsensitiveStringMap, - partitionFilters: Seq[Expression] = Seq.empty, - dataFilters: Seq[Expression] = Seq.empty): TextScan - def filesGroupedToBuckets( selectedPartitions: Array[PartitionDirectory]): Map[Int, Array[PartitionedFile]] @@ -127,32 +103,6 @@ trait SparkShims { length: Long, @transient locations: Array[String] = Array.empty): PartitionedFile - def bloomFilterExpressionMappings(): Seq[Sig] - - def newBloomFilterAggregate[T]( - child: Expression, - estimatedNumItemsExpression: Expression, - numBitsExpression: Expression, - mutableAggBufferOffset: Int, - inputAggBufferOffset: Int): TypedImperativeAggregate[T] - - def newMightContain( - bloomFilterExpression: Expression, - valueExpression: Expression): BinaryExpression - - def replaceBloomFilterAggregate[T]( - expr: Expression, - bloomFilterAggReplacer: ( - Expression, - Expression, - Expression, - Int, - Int) => TypedImperativeAggregate[T]): Expression - - def replaceMightContain[T]( - expr: Expression, - mightContainReplacer: (Expression, Expression) => BinaryExpression): Expression - def isWindowGroupLimitExec(plan: SparkPlan): Boolean = false def getWindowGroupLimitExecShim(plan: SparkPlan): WindowGroupLimitExecShim = null @@ -194,22 +144,9 @@ trait SparkShims { sc: SparkContext, broadcastExchange: BroadcastExchangeLike): Unit - def getShuffleReaderParam[K, C]( - handle: ShuffleHandle, - startMapIndex: Int, - endMapIndex: Int, - startPartition: Int, - endPartition: Int): Tuple2[Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], Boolean] - // Compatible with Spark-3.5 and later def getShuffleAdvisoryPartitionSize(shuffle: ShuffleExchangeLike): Option[Long] = None - // Partition id in TaskInfo is only available after spark 3.3. 
- def getPartitionId(taskInfo: TaskInfo): Int - - // Because above, this feature is only supported after spark 3.3 - def supportDuplicateReadingTracking: Boolean - def getFileStatus(partition: PartitionDirectory): Seq[(FileStatus, Map[String, Any])] def isFileSplittable(relation: HadoopFsRelation, filePath: Path, sparkSchema: StructType): Boolean @@ -231,9 +168,6 @@ trait SparkShims { def attributesFromStruct(structType: StructType): Seq[Attribute] - // Spark 3.3 and later only have file size and modification time in PartitionedFile - def getFileSizeAndModificationTime(file: PartitionedFile): (Option[Long], Option[Long]) - def generateMetadataColumns( file: PartitionedFile, metadataColumnNames: Seq[String] = Seq.empty): Map[String, String] = { @@ -278,32 +212,11 @@ trait SparkShims { def withAnsiEvalMode(expr: Expression): Boolean = false - def dateTimestampFormatInReadIsDefaultValue(csvOptions: CSVOptions, timeZone: String): Boolean - def createParquetFilters( conf: SQLConf, schema: MessageType, caseSensitive: Option[Boolean] = None): ParquetFilters - def genDecimalRoundExpressionOutput(decimalType: DecimalType, toScale: Int): DecimalType = { - val p = decimalType.precision - val s = decimalType.scale - // After rounding we may need one more digit in the integral part, - // e.g. `ceil(9.9, 0)` -> `10`, `ceil(99, -1)` -> `100`. - val integralLeastNumDigits = p - s + 1 - if (toScale < 0) { - // negative scale means we need to adjust `-scale` number of digits before the decimal - // point, which means we need at lease `-scale + 1` digits (after rounding). - val newPrecision = math.max(integralLeastNumDigits, -toScale + 1) - // We have to accept the risk of overflow as we can't exceed the max precision. - DecimalType(math.min(newPrecision, DecimalType.MAX_PRECISION), 0) - } else { - val newScale = math.min(s, toScale) - // We have to accept the risk of overflow as we can't exceed the max precision. 
- DecimalType(math.min(integralLeastNumDigits + newScale, 38), newScale) - } - } - def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = { throw new UnsupportedOperationException("ArrayInsert not supported.") } diff --git a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala index 19bb10c0eea7..0ea3d7d9b8fe 100644 --- a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala +++ b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala @@ -23,37 +23,27 @@ import org.apache.gluten.sql.shims.SparkShims import org.apache.gluten.utils.ExceptionUtils import org.apache.spark._ -import org.apache.spark.scheduler.TaskInfo -import org.apache.spark.shuffle.ShuffleHandle import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.DecimalPrecision -import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, RegrR2, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.aggregate.RegrR2 import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, TimestampFormatter} import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec +import org.apache.spark.sql.catalyst.util.TimestampFormatter import org.apache.spark.sql.connector.catalog.Table -import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, QueryExecution, SparkPlan, SparkPlanner} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.FileFormatWriter.Empty2Null import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters import org.apache.spark.sql.execution.datasources.v2.BatchScanExec -import org.apache.spark.sql.execution.datasources.v2.text.TextScan -import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.types.{DecimalType, StructField, StructType} -import org.apache.spark.sql.util.CaseInsensitiveStringMap -import org.apache.spark.storage.{BlockId, BlockManagerId} import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.parquet.crypto.ParquetCryptoRuntimeException @@ -65,11 +55,6 @@ import java.time.ZoneOffset import scala.collection.mutable class Spark33Shims extends SparkShims { - override def getDistribution( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression]): Seq[Distribution] = { - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - } override def scalarExpressionMappings: Seq[Sig] = { Seq( @@ -99,11 +84,6 @@ class Spark33Shims extends SparkShims { ) } - override def convertPartitionTransforms( - partitions: 
Seq[Transform]): (Seq[String], Option[BucketSpec]) = { - CatalogUtil.convertPartitionTransforms(partitions) - } - override def generateFileScanRDD( sparkSession: SparkSession, readFunction: PartitionedFile => Iterator[InternalRow], @@ -120,26 +100,6 @@ class Spark33Shims extends SparkShims { ) } - override def getTextScan( - sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex, - dataSchema: StructType, - readDataSchema: StructType, - readPartitionSchema: StructType, - options: CaseInsensitiveStringMap, - partitionFilters: Seq[Expression], - dataFilters: Seq[Expression]): TextScan = { - TextScan( - sparkSession, - fileIndex, - dataSchema, - readDataSchema, - readPartitionSchema, - options, - partitionFilters, - dataFilters) - } - override def filesGroupedToBuckets( selectedPartitions: Array[PartitionDirectory]): Map[Int, Array[PartitionedFile]] = { selectedPartitions @@ -164,67 +124,6 @@ class Spark33Shims extends SparkShims { @transient locations: Array[String] = Array.empty): PartitionedFile = PartitionedFile(partitionValues, filePath, start, length, locations) - override def bloomFilterExpressionMappings(): Seq[Sig] = Seq( - Sig[BloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN), - Sig[BloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG) - ) - - override def newBloomFilterAggregate[T]( - child: Expression, - estimatedNumItemsExpression: Expression, - numBitsExpression: Expression, - mutableAggBufferOffset: Int, - inputAggBufferOffset: Int): TypedImperativeAggregate[T] = { - BloomFilterAggregate( - child, - estimatedNumItemsExpression, - numBitsExpression, - mutableAggBufferOffset, - inputAggBufferOffset).asInstanceOf[TypedImperativeAggregate[T]] - } - - override def newMightContain( - bloomFilterExpression: Expression, - valueExpression: Expression): BinaryExpression = { - BloomFilterMightContain(bloomFilterExpression, valueExpression) - } - - override def replaceBloomFilterAggregate[T]( - expr: Expression, - bloomFilterAggReplacer: ( - Expression, - Expression, - Expression, - Int, - Int) => TypedImperativeAggregate[T]): Expression = expr match { - case BloomFilterAggregate( - child, - estimatedNumItemsExpression, - numBitsExpression, - mutableAggBufferOffset, - inputAggBufferOffset) => - bloomFilterAggReplacer( - child, - estimatedNumItemsExpression, - numBitsExpression, - mutableAggBufferOffset, - inputAggBufferOffset) - case other => other - } - - override def replaceMightContain[T]( - expr: Expression, - mightContainReplacer: (Expression, Expression) => BinaryExpression): Expression = expr match { - case BloomFilterMightContain(bloomFilterExpression, valueExpression) => - mightContainReplacer(bloomFilterExpression, valueExpression) - case other => other - } - - override def getFileSizeAndModificationTime( - file: PartitionedFile): (Option[Long], Option[Long]) = { - (Some(file.fileSize), Some(file.modificationTime)) - } - override def generateMetadataColumns( file: PartitionedFile, metadataColumnNames: Seq[String]): Map[String, String] = { @@ -275,21 +174,6 @@ class Spark33Shims extends SparkShims { sc.cancelJobGroup(broadcastExchange.runId.toString) } - override def getShuffleReaderParam[K, C]( - handle: ShuffleHandle, - startMapIndex: Int, - endMapIndex: Int, - startPartition: Int, - endPartition: Int): Tuple2[Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], Boolean] = { - ShuffleUtils.getReaderParam(handle, startMapIndex, endMapIndex, startPartition, endPartition) - } - - override def getPartitionId(taskInfo: TaskInfo): Int = { - taskInfo.partitionId 
- } - - override def supportDuplicateReadingTracking: Boolean = true - def getFileStatus(partition: PartitionDirectory): Seq[(FileStatus, Map[String, Any])] = partition.files.map(f => (f, Map.empty[String, Any])) @@ -345,15 +229,6 @@ class Spark33Shims extends SparkShims { } } - override def dateTimestampFormatInReadIsDefaultValue( - csvOptions: CSVOptions, - timeZone: String): Boolean = { - val default = new CSVOptions(CaseInsensitiveMap(Map()), csvOptions.columnPruning, timeZone) - csvOptions.dateFormatInRead == default.dateFormatInRead && - csvOptions.timestampFormatInRead == default.timestampFormatInRead && - csvOptions.timestampNTZFormatInRead == default.timestampNTZFormatInRead - } - override def createParquetFilters( conf: SQLConf, schema: MessageType, diff --git a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala index 199c5313da46..61bc3bc94568 100644 --- a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala +++ b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala @@ -25,38 +25,29 @@ import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.paths.SparkPath -import org.apache.spark.scheduler.TaskInfo -import org.apache.spark.shuffle.ShuffleHandle import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.DecimalPrecision -import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, KeyGroupedPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, InternalRowComparableWrapper, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.{InternalRowComparableWrapper, TimestampFormatter} import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec import org.apache.spark.sql.connector.catalog.Table -import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, Scan} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanExecBase} -import org.apache.spark.sql.execution.datasources.v2.text.TextScan -import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike import org.apache.spark.sql.extension.RewriteCreateTableAsSelect import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.types.{DecimalType, IntegerType, LongType, StructField, StructType} -import 
org.apache.spark.sql.util.CaseInsensitiveStringMap -import org.apache.spark.storage.{BlockId, BlockManagerId} import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.parquet.crypto.ParquetCryptoRuntimeException @@ -69,11 +60,6 @@ import scala.collection.mutable import scala.reflect.ClassTag class Spark34Shims extends SparkShims { - override def getDistribution( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression]): Seq[Distribution] = { - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - } override def scalarExpressionMappings: Seq[Sig] = { Seq( @@ -116,11 +102,6 @@ class Spark34Shims extends SparkShims { ) } - override def convertPartitionTransforms( - partitions: Seq[Transform]): (Seq[String], Option[BucketSpec]) = { - CatalogUtil.convertPartitionTransforms(partitions) - } - override def generateFileScanRDD( sparkSession: SparkSession, readFunction: PartitionedFile => Iterator[InternalRow], @@ -137,26 +118,6 @@ class Spark34Shims extends SparkShims { ) } - override def getTextScan( - sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex, - dataSchema: StructType, - readDataSchema: StructType, - readPartitionSchema: StructType, - options: CaseInsensitiveStringMap, - partitionFilters: Seq[Expression], - dataFilters: Seq[Expression]): TextScan = { - new TextScan( - sparkSession, - fileIndex, - dataSchema, - readDataSchema, - readPartitionSchema, - options, - partitionFilters, - dataFilters) - } - override def filesGroupedToBuckets( selectedPartitions: Array[PartitionDirectory]): Map[Int, Array[PartitionedFile]] = { selectedPartitions @@ -181,67 +142,6 @@ class Spark34Shims extends SparkShims { @transient locations: Array[String] = Array.empty): PartitionedFile = PartitionedFile(partitionValues, SparkPath.fromPathString(filePath), start, length, locations) - override def bloomFilterExpressionMappings(): Seq[Sig] = Seq( - Sig[BloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN), - Sig[BloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG) - ) - - override def newBloomFilterAggregate[T]( - child: Expression, - estimatedNumItemsExpression: Expression, - numBitsExpression: Expression, - mutableAggBufferOffset: Int, - inputAggBufferOffset: Int): TypedImperativeAggregate[T] = { - BloomFilterAggregate( - child, - estimatedNumItemsExpression, - numBitsExpression, - mutableAggBufferOffset, - inputAggBufferOffset).asInstanceOf[TypedImperativeAggregate[T]] - } - - override def newMightContain( - bloomFilterExpression: Expression, - valueExpression: Expression): BinaryExpression = { - BloomFilterMightContain(bloomFilterExpression, valueExpression) - } - - override def replaceBloomFilterAggregate[T]( - expr: Expression, - bloomFilterAggReplacer: ( - Expression, - Expression, - Expression, - Int, - Int) => TypedImperativeAggregate[T]): Expression = expr match { - case BloomFilterAggregate( - child, - estimatedNumItemsExpression, - numBitsExpression, - mutableAggBufferOffset, - inputAggBufferOffset) => - bloomFilterAggReplacer( - child, - estimatedNumItemsExpression, - numBitsExpression, - mutableAggBufferOffset, - inputAggBufferOffset) - case other => other - } - - override def replaceMightContain[T]( - expr: Expression, - mightContainReplacer: (Expression, Expression) => BinaryExpression): Expression = expr match { - case BloomFilterMightContain(bloomFilterExpression, valueExpression) => - mightContainReplacer(bloomFilterExpression, valueExpression) - case other => other - } - - override def getFileSizeAndModificationTime( - file: 
PartitionedFile): (Option[Long], Option[Long]) = { - (Some(file.fileSize), Some(file.modificationTime)) - } - override def generateMetadataColumns( file: PartitionedFile, metadataColumnNames: Seq[String]): Map[String, String] = { @@ -338,21 +238,6 @@ class Spark34Shims extends SparkShims { sc.cancelJobGroup(broadcastExchange.runId.toString) } - override def getShuffleReaderParam[K, C]( - handle: ShuffleHandle, - startMapIndex: Int, - endMapIndex: Int, - startPartition: Int, - endPartition: Int): Tuple2[Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], Boolean] = { - ShuffleUtils.getReaderParam(handle, startMapIndex, endMapIndex, startPartition, endPartition) - } - - override def getPartitionId(taskInfo: TaskInfo): Int = { - taskInfo.partitionId - } - - override def supportDuplicateReadingTracking: Boolean = true - def getFileStatus(partition: PartitionDirectory): Seq[(FileStatus, Map[String, Any])] = partition.files.map(f => (f, Map.empty[String, Any])) @@ -565,15 +450,6 @@ class Spark34Shims extends SparkShims { } } - override def dateTimestampFormatInReadIsDefaultValue( - csvOptions: CSVOptions, - timeZone: String): Boolean = { - val default = new CSVOptions(CaseInsensitiveMap(Map()), csvOptions.columnPruning, timeZone) - csvOptions.dateFormatInRead == default.dateFormatInRead && - csvOptions.timestampFormatInRead == default.timestampFormatInRead && - csvOptions.timestampNTZFormatInRead == default.timestampNTZFormatInRead - } - override def createParquetFilters( conf: SQLConf, schema: MessageType, diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala index 0ee78c08a6c4..7e31af9b672c 100644 --- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala +++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala @@ -24,38 +24,29 @@ import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.paths.SparkPath -import org.apache.spark.scheduler.TaskInfo -import org.apache.spark.shuffle.ShuffleHandle import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow} import org.apache.spark.sql.catalyst.analysis.DecimalPrecision -import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, KeyGroupedPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, InternalRowComparableWrapper, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.{InternalRowComparableWrapper, TimestampFormatter} import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec import org.apache.spark.sql.connector.catalog.Table -import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, Scan} import 
org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetFilters}
 import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanExecBase}
-import org.apache.spark.sql.execution.datasources.v2.text.TextScan
-import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.window.{Final, GlutenFinal, GlutenPartial, Partial, WindowGroupLimitExec, WindowGroupLimitExecShim}
 import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 import org.apache.spark.sql.types.{DecimalType, IntegerType, LongType, StructField, StructType}
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
-import org.apache.spark.storage.{BlockId, BlockManagerId}
 
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.parquet.hadoop.metadata.FileMetaData.EncryptionType
@@ -71,12 +62,6 @@ import scala.reflect.ClassTag
 
 class Spark35Shims extends SparkShims {
 
-  override def getDistribution(
-      leftKeys: Seq[Expression],
-      rightKeys: Seq[Expression]): Seq[Distribution] = {
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
-  }
-
   override def scalarExpressionMappings: Seq[Sig] = {
     Seq(
       Sig[SplitPart](ExpressionNames.SPLIT_PART),
@@ -121,11 +106,6 @@ class Spark35Shims extends SparkShims {
     )
   }
 
-  override def convertPartitionTransforms(
-      partitions: Seq[Transform]): (Seq[String], Option[BucketSpec]) = {
-    CatalogUtil.convertPartitionTransforms(partitions)
-  }
-
   override def generateFileScanRDD(
       sparkSession: SparkSession,
       readFunction: PartitionedFile => Iterator[InternalRow],
@@ -142,26 +122,6 @@ class Spark35Shims extends SparkShims {
     )
   }
 
-  override def getTextScan(
-      sparkSession: SparkSession,
-      fileIndex: PartitioningAwareFileIndex,
-      dataSchema: StructType,
-      readDataSchema: StructType,
-      readPartitionSchema: StructType,
-      options: CaseInsensitiveStringMap,
-      partitionFilters: Seq[Expression],
-      dataFilters: Seq[Expression]): TextScan = {
-    new TextScan(
-      sparkSession,
-      fileIndex,
-      dataSchema,
-      readDataSchema,
-      readPartitionSchema,
-      options,
-      partitionFilters,
-      dataFilters)
-  }
-
   override def filesGroupedToBuckets(
       selectedPartitions: Array[PartitionDirectory]): Map[Int, Array[PartitionedFile]] = {
     selectedPartitions
@@ -184,67 +144,6 @@ class Spark35Shims extends SparkShims {
       @transient locations: Array[String] = Array.empty): PartitionedFile =
     PartitionedFile(partitionValues, SparkPath.fromPathString(filePath), start, length, locations)
 
-  override def bloomFilterExpressionMappings(): Seq[Sig] = Seq(
-    Sig[BloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN),
-    Sig[BloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG)
-  )
-
-  override def newBloomFilterAggregate[T](
-      child: Expression,
-      estimatedNumItemsExpression: Expression,
-      numBitsExpression: Expression,
-      mutableAggBufferOffset: Int,
-      inputAggBufferOffset: Int): TypedImperativeAggregate[T] = {
-    BloomFilterAggregate(
-      child,
-      estimatedNumItemsExpression,
-      numBitsExpression,
-      mutableAggBufferOffset,
-      inputAggBufferOffset).asInstanceOf[TypedImperativeAggregate[T]]
-  }
-
-  override def newMightContain(
-      bloomFilterExpression: Expression,
-      valueExpression: Expression): BinaryExpression = {
-    BloomFilterMightContain(bloomFilterExpression, valueExpression)
-  }
-
-  override def replaceBloomFilterAggregate[T](
-      expr: Expression,
-      bloomFilterAggReplacer: (
-          Expression,
-          Expression,
-          Expression,
-          Int,
-          Int) => TypedImperativeAggregate[T]): Expression = expr match {
-    case BloomFilterAggregate(
-          child,
-          estimatedNumItemsExpression,
-          numBitsExpression,
-          mutableAggBufferOffset,
-          inputAggBufferOffset) =>
-      bloomFilterAggReplacer(
-        child,
-        estimatedNumItemsExpression,
-        numBitsExpression,
-        mutableAggBufferOffset,
-        inputAggBufferOffset)
-    case other => other
-  }
-
-  override def replaceMightContain[T](
-      expr: Expression,
-      mightContainReplacer: (Expression, Expression) => BinaryExpression): Expression = expr match {
-    case BloomFilterMightContain(bloomFilterExpression, valueExpression) =>
-      mightContainReplacer(bloomFilterExpression, valueExpression)
-    case other => other
-  }
-
-  override def getFileSizeAndModificationTime(
-      file: PartitionedFile): (Option[Long], Option[Long]) = {
-    (Some(file.fileSize), Some(file.modificationTime))
-  }
-
   override def generateMetadataColumns(
       file: PartitionedFile,
       metadataColumnNames: Seq[String]): Map[String, String] = {
@@ -376,24 +275,9 @@ class Spark35Shims extends SparkShims {
     sc.cancelJobsWithTag(broadcastExchange.jobTag)
   }
 
-  override def getShuffleReaderParam[K, C](
-      handle: ShuffleHandle,
-      startMapIndex: Int,
-      endMapIndex: Int,
-      startPartition: Int,
-      endPartition: Int): Tuple2[Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], Boolean] = {
-    ShuffleUtils.getReaderParam(handle, startMapIndex, endMapIndex, startPartition, endPartition)
-  }
-
   override def getShuffleAdvisoryPartitionSize(shuffle: ShuffleExchangeLike): Option[Long] =
     shuffle.advisoryPartitionSize
 
-  override def getPartitionId(taskInfo: TaskInfo): Int = {
-    taskInfo.partitionId
-  }
-
-  override def supportDuplicateReadingTracking: Boolean = true
-
   def getFileStatus(partition: PartitionDirectory): Seq[(FileStatus, Map[String, Any])] =
     partition.files.map(f => (f.fileStatus, f.metadata))
 
@@ -603,15 +487,6 @@ class Spark35Shims extends SparkShims {
     }
   }
 
-  override def dateTimestampFormatInReadIsDefaultValue(
-      csvOptions: CSVOptions,
-      timeZone: String): Boolean = {
-    val default = new CSVOptions(CaseInsensitiveMap(Map()), csvOptions.columnPruning, timeZone)
-    csvOptions.dateFormatInRead == default.dateFormatInRead &&
-    csvOptions.timestampFormatInRead == default.timestampFormatInRead &&
-    csvOptions.timestampNTZFormatInRead == default.timestampNTZFormatInRead
-  }
-
   override def createParquetFilters(
       conf: SQLConf,
       schema: MessageType,
diff --git a/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala b/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
index bbbad38010f1..0c3b163ce533 100644
--- a/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
+++ b/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
@@ -24,26 +24,21 @@ import org.apache.spark._
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.paths.SparkPath
-import org.apache.spark.scheduler.TaskInfo
-import org.apache.spark.shuffle.ShuffleHandle
 import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.DecimalPrecisionTypeCoercion
-import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.csv.CSVOptions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.{JoinType, LeftSingle}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, InternalRowComparableWrapper, MapData, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.{InternalRowComparableWrapper, MapData, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
 import org.apache.spark.sql.classic.ClassicConversions._
 import org.apache.spark.sql.connector.catalog.Table
-import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, Scan}
 import org.apache.spark.sql.connector.read.streaming.SparkDataStream
 import org.apache.spark.sql.execution._
@@ -51,14 +46,10 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetFilters}
 import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, BatchScanExecShim, DataSourceV2ScanExecBase}
-import org.apache.spark.sql.execution.datasources.v2.text.TextScan
-import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.window.{Final, Partial, _}
 import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
-import org.apache.spark.storage.{BlockId, BlockManagerId}
 
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.parquet.hadoop.metadata.{CompressionCodecName, ParquetMetadata}
@@ -74,12 +65,6 @@ import scala.reflect.ClassTag
 
 class Spark40Shims extends SparkShims {
 
-  override def getDistribution(
-      leftKeys: Seq[Expression],
-      rightKeys: Seq[Expression]): Seq[Distribution] = {
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
-  }
-
   override def scalarExpressionMappings: Seq[Sig] = {
     Seq(
       Sig[SplitPart](ExpressionNames.SPLIT_PART),
@@ -126,11 +111,6 @@ class Spark40Shims extends SparkShims {
     )
   }
 
-  override def convertPartitionTransforms(
-      partitions: Seq[Transform]): (Seq[String], Option[BucketSpec]) = {
-    CatalogUtil.convertPartitionTransforms(partitions)
-  }
-
   override def generateFileScanRDD(
       sparkSession: SparkSession,
       readFunction: PartitionedFile => Iterator[InternalRow],
@@ -147,26 +127,6 @@ class Spark40Shims extends SparkShims {
     )
   }
 
-  override def getTextScan(
-      sparkSession: SparkSession,
-      fileIndex: PartitioningAwareFileIndex,
-      dataSchema: StructType,
-      readDataSchema: StructType,
-      readPartitionSchema: StructType,
-      options: CaseInsensitiveStringMap,
-      partitionFilters: Seq[Expression],
-      dataFilters: Seq[Expression]): TextScan = {
-    TextScan(
-      sparkSession,
-      fileIndex,
-      dataSchema,
-      readDataSchema,
-      readPartitionSchema,
-      options,
-      partitionFilters,
-      dataFilters)
-  }
-
   override def filesGroupedToBuckets(
       selectedPartitions: Array[PartitionDirectory]): Map[Int, Array[PartitionedFile]] = {
     selectedPartitions
@@ -189,67 +149,6 @@ class Spark40Shims extends SparkShims {
       @transient locations: Array[String] = Array.empty): PartitionedFile =
     PartitionedFile(partitionValues, SparkPath.fromPathString(filePath), start, length, locations)
 
-  override def bloomFilterExpressionMappings(): Seq[Sig] = Seq(
-    Sig[BloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN),
-    Sig[BloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG)
-  )
-
-  override def newBloomFilterAggregate[T](
-      child: Expression,
-      estimatedNumItemsExpression: Expression,
-      numBitsExpression: Expression,
-      mutableAggBufferOffset: Int,
-      inputAggBufferOffset: Int): TypedImperativeAggregate[T] = {
-    BloomFilterAggregate(
-      child,
-      estimatedNumItemsExpression,
-      numBitsExpression,
-      mutableAggBufferOffset,
-      inputAggBufferOffset).asInstanceOf[TypedImperativeAggregate[T]]
-  }
-
-  override def newMightContain(
-      bloomFilterExpression: Expression,
-      valueExpression: Expression): BinaryExpression = {
-    BloomFilterMightContain(bloomFilterExpression, valueExpression)
-  }
-
-  override def replaceBloomFilterAggregate[T](
-      expr: Expression,
-      bloomFilterAggReplacer: (
-          Expression,
-          Expression,
-          Expression,
-          Int,
-          Int) => TypedImperativeAggregate[T]): Expression = expr match {
-    case BloomFilterAggregate(
-          child,
-          estimatedNumItemsExpression,
-          numBitsExpression,
-          mutableAggBufferOffset,
-          inputAggBufferOffset) =>
-      bloomFilterAggReplacer(
-        child,
-        estimatedNumItemsExpression,
-        numBitsExpression,
-        mutableAggBufferOffset,
-        inputAggBufferOffset)
-    case other => other
-  }
-
-  override def replaceMightContain[T](
-      expr: Expression,
-      mightContainReplacer: (Expression, Expression) => BinaryExpression): Expression = expr match {
-    case BloomFilterMightContain(bloomFilterExpression, valueExpression) =>
-      mightContainReplacer(bloomFilterExpression, valueExpression)
-    case other => other
-  }
-
-  override def getFileSizeAndModificationTime(
-      file: PartitionedFile): (Option[Long], Option[Long]) = {
-    (Some(file.fileSize), Some(file.modificationTime))
-  }
-
   override def generateMetadataColumns(
       file: PartitionedFile,
       metadataColumnNames: Seq[String]): Map[String, String] = {
@@ -381,24 +280,9 @@ class Spark40Shims extends SparkShims {
     sc.cancelJobsWithTag(broadcastExchange.jobTag)
   }
 
-  override def getShuffleReaderParam[K, C](
-      handle: ShuffleHandle,
-      startMapIndex: Int,
-      endMapIndex: Int,
-      startPartition: Int,
-      endPartition: Int): Tuple2[Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], Boolean] = {
-    ShuffleUtils.getReaderParam(handle, startMapIndex, endMapIndex, startPartition, endPartition)
-  }
-
   override def getShuffleAdvisoryPartitionSize(shuffle: ShuffleExchangeLike): Option[Long] =
     shuffle.advisoryPartitionSize
 
-  override def getPartitionId(taskInfo: TaskInfo): Int = {
-    taskInfo.partitionId
-  }
-
-  override def supportDuplicateReadingTracking: Boolean = true
-
   def getFileStatus(partition: PartitionDirectory): Seq[(FileStatus, Map[String, Any])] =
     partition.files.map(f => (f.fileStatus, f.metadata))
 
@@ -628,15 +512,6 @@ class Spark40Shims extends SparkShims {
     }
   }
 
-  override def dateTimestampFormatInReadIsDefaultValue(
-      csvOptions: CSVOptions,
-      timeZone: String): Boolean = {
-    val default = new CSVOptions(CaseInsensitiveMap(Map()), csvOptions.columnPruning, timeZone)
-    csvOptions.dateFormatInRead == default.dateFormatInRead &&
-    csvOptions.timestampFormatInRead == default.timestampFormatInRead &&
-    csvOptions.timestampNTZFormatInRead == default.timestampNTZFormatInRead
-  }
-
   override def createParquetFilters(
       conf: SQLConf,
       schema: MessageType,
diff --git a/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala b/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala
index a031becdb064..0e3e752f9970 100644
--- a/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala
+++ b/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala
@@ -24,25 +24,20 @@ import org.apache.spark._
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.paths.SparkPath
-import org.apache.spark.scheduler.TaskInfo
-import org.apache.spark.shuffle.ShuffleHandle
 import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.DecimalPrecisionTypeCoercion
-import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.csv.CSVOptions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.{JoinType, LeftSingle}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, InternalRowComparableWrapper, MapData, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.{InternalRowComparableWrapper, MapData, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
 import org.apache.spark.sql.connector.catalog.Table
-import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, Scan}
 import org.apache.spark.sql.connector.read.streaming.SparkDataStream
 import org.apache.spark.sql.execution._
@@ -50,14 +45,10 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetFilters}
 import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, BatchScanExecShim, DataSourceV2ScanExecBase}
-import org.apache.spark.sql.execution.datasources.v2.text.TextScan
-import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.window.{Final, Partial, _}
 import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
-import org.apache.spark.storage.{BlockId, BlockManagerId}
 
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.parquet.hadoop.metadata.{CompressionCodecName, ParquetMetadata}
@@ -73,12 +64,6 @@ import scala.reflect.ClassTag
 
 class Spark41Shims extends SparkShims {
 
-  override def getDistribution(
-      leftKeys: Seq[Expression],
-      rightKeys: Seq[Expression]): Seq[Distribution] = {
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
-  }
-
   override def scalarExpressionMappings: Seq[Sig] = {
     Seq(
       Sig[SplitPart](ExpressionNames.SPLIT_PART),
@@ -125,11 +110,6 @@ class Spark41Shims extends SparkShims {
     )
   }
 
-  override def convertPartitionTransforms(
-      partitions: Seq[Transform]): (Seq[String], Option[BucketSpec]) = {
-    CatalogUtil.convertPartitionTransforms(partitions)
-  }
-
   override def generateFileScanRDD(
       sparkSession: SparkSession,
       readFunction: PartitionedFile => Iterator[InternalRow],
@@ -146,26 +126,6 @@ class Spark41Shims extends SparkShims {
     )
   }
 
-  override def getTextScan(
-      sparkSession: SparkSession,
-      fileIndex: PartitioningAwareFileIndex,
-      dataSchema: StructType,
-      readDataSchema: StructType,
-      readPartitionSchema: StructType,
-      options: CaseInsensitiveStringMap,
-      partitionFilters: Seq[Expression],
-      dataFilters: Seq[Expression]): TextScan = {
-    TextScan(
-      sparkSession,
-      fileIndex,
-      dataSchema,
-      readDataSchema,
-      readPartitionSchema,
-      options,
-      partitionFilters,
-      dataFilters)
-  }
-
   override def filesGroupedToBuckets(
       selectedPartitions: Array[PartitionDirectory]): Map[Int, Array[PartitionedFile]] = {
     selectedPartitions
@@ -188,67 +148,6 @@ class Spark41Shims extends SparkShims {
       @transient locations: Array[String] = Array.empty): PartitionedFile =
     PartitionedFile(partitionValues, SparkPath.fromPathString(filePath), start, length, locations)
 
-  override def bloomFilterExpressionMappings(): Seq[Sig] = Seq(
-    Sig[BloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN),
-    Sig[BloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG)
-  )
-
-  override def newBloomFilterAggregate[T](
-      child: Expression,
-      estimatedNumItemsExpression: Expression,
-      numBitsExpression: Expression,
-      mutableAggBufferOffset: Int,
-      inputAggBufferOffset: Int): TypedImperativeAggregate[T] = {
-    BloomFilterAggregate(
-      child,
-      estimatedNumItemsExpression,
-      numBitsExpression,
-      mutableAggBufferOffset,
-      inputAggBufferOffset).asInstanceOf[TypedImperativeAggregate[T]]
-  }
-
-  override def newMightContain(
-      bloomFilterExpression: Expression,
-      valueExpression: Expression): BinaryExpression = {
-    BloomFilterMightContain(bloomFilterExpression, valueExpression)
-  }
-
-  override def replaceBloomFilterAggregate[T](
-      expr: Expression,
-      bloomFilterAggReplacer: (
-          Expression,
-          Expression,
-          Expression,
-          Int,
-          Int) => TypedImperativeAggregate[T]): Expression = expr match {
-    case BloomFilterAggregate(
-          child,
-          estimatedNumItemsExpression,
-          numBitsExpression,
-          mutableAggBufferOffset,
-          inputAggBufferOffset) =>
-      bloomFilterAggReplacer(
-        child,
-        estimatedNumItemsExpression,
-        numBitsExpression,
-        mutableAggBufferOffset,
-        inputAggBufferOffset)
-    case other => other
-  }
-
-  override def replaceMightContain[T](
-      expr: Expression,
-      mightContainReplacer: (Expression, Expression) => BinaryExpression): Expression = expr match {
-    case BloomFilterMightContain(bloomFilterExpression, valueExpression) =>
-      mightContainReplacer(bloomFilterExpression, valueExpression)
-    case other => other
-  }
-
-  override def getFileSizeAndModificationTime(
-      file: PartitionedFile): (Option[Long], Option[Long]) = {
-    (Some(file.fileSize), Some(file.modificationTime))
-  }
-
   override def generateMetadataColumns(
       file: PartitionedFile,
       metadataColumnNames: Seq[String]): Map[String, String] = {
@@ -380,24 +279,9 @@ class Spark41Shims extends SparkShims {
     sc.cancelJobsWithTag(broadcastExchange.jobTag)
   }
 
-  override def getShuffleReaderParam[K, C](
-      handle: ShuffleHandle,
-      startMapIndex: Int,
-      endMapIndex: Int,
-      startPartition: Int,
-      endPartition: Int): Tuple2[Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], Boolean] = {
-    ShuffleUtils.getReaderParam(handle, startMapIndex, endMapIndex, startPartition, endPartition)
-  }
-
   override def getShuffleAdvisoryPartitionSize(shuffle: ShuffleExchangeLike): Option[Long] =
     shuffle.advisoryPartitionSize
 
-  override def getPartitionId(taskInfo: TaskInfo): Int = {
-    taskInfo.partitionId
-  }
-
-  override def supportDuplicateReadingTracking: Boolean = true
-
   def getFileStatus(partition: PartitionDirectory): Seq[(FileStatus, Map[String, Any])] =
     partition.files.map(f => (f.fileStatus, f.metadata))
 
@@ -627,15 +511,6 @@ class Spark41Shims extends SparkShims {
     }
   }
 
-  override def dateTimestampFormatInReadIsDefaultValue(
-      csvOptions: CSVOptions,
-      timeZone: String): Boolean = {
-    val default = new CSVOptions(CaseInsensitiveMap(Map()), csvOptions.columnPruning, timeZone)
-    csvOptions.dateFormatInRead == default.dateFormatInRead &&
-    csvOptions.timestampFormatInRead == default.timestampFormatInRead &&
-    csvOptions.timestampNTZFormatInRead == default.timestampNTZFormatInRead
-  }
-
   override def createParquetFilters(
       conf: SQLConf,
       schema: MessageType,