Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
*/
package org.apache.spark.sql.execution.datasources.v2.clickhouse

import org.apache.gluten.sql.shims.SparkShimLoader

import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException}
Expand All @@ -35,6 +33,7 @@ import org.apache.spark.sql.delta.metering.DeltaLogging
import org.apache.spark.sql.delta.sources.{DeltaSourceUtils, DeltaSQLConf}
import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils}
import org.apache.spark.sql.execution.datasources.v2.clickhouse.utils.CHDataSourceUtils
import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil
import org.apache.spark.sql.sources.InsertableRelation
import org.apache.spark.sql.types.StructType

Expand Down Expand Up @@ -119,7 +118,7 @@ class ClickHouseSparkCatalog
sourceQuery: Option[DataFrame],
operation: TableCreationModes.CreationMode): Table = {
val (partitionColumns, maybeBucketSpec) =
SparkShimLoader.getSparkShims.convertPartitionTransforms(partitions)
CatalogUtil.convertPartitionTransforms(partitions)
var newSchema = schema
var newPartitionColumns = partitionColumns
var newBucketSpec = maybeBucketSpec
Expand Down Expand Up @@ -232,7 +231,7 @@ class ClickHouseSparkCatalog
case _ => true
}.toMap
val (partitionColumns, maybeBucketSpec) =
SparkShimLoader.getSparkShims.convertPartitionTransforms(partitions)
CatalogUtil.convertPartitionTransforms(partitions)
var newSchema = schema
var newPartitionColumns = partitionColumns
var newBucketSpec = maybeBucketSpec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
*/
package org.apache.spark.sql.execution.datasources.v2.clickhouse

import org.apache.gluten.sql.shims.SparkShimLoader

import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException}
Expand All @@ -35,6 +33,7 @@ import org.apache.spark.sql.delta.metering.DeltaLogging
import org.apache.spark.sql.delta.sources.{DeltaSourceUtils, DeltaSQLConf}
import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils}
import org.apache.spark.sql.execution.datasources.v2.clickhouse.utils.CHDataSourceUtils
import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil
import org.apache.spark.sql.sources.InsertableRelation
import org.apache.spark.sql.types.StructType

Expand Down Expand Up @@ -119,7 +118,7 @@ class ClickHouseSparkCatalog
sourceQuery: Option[DataFrame],
operation: TableCreationModes.CreationMode): Table = {
val (partitionColumns, maybeBucketSpec) =
SparkShimLoader.getSparkShims.convertPartitionTransforms(partitions)
CatalogUtil.convertPartitionTransforms(partitions)
var newSchema = schema
var newPartitionColumns = partitionColumns
var newBucketSpec = maybeBucketSpec
Expand Down Expand Up @@ -232,7 +231,7 @@ class ClickHouseSparkCatalog
case _ => true
}.toMap
val (partitionColumns, maybeBucketSpec) =
SparkShimLoader.getSparkShims.convertPartitionTransforms(partitions)
CatalogUtil.convertPartitionTransforms(partitions)
var newSchema = schema
var newPartitionColumns = partitionColumns
var newBucketSpec = maybeBucketSpec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
*/
package org.apache.spark.sql.execution.datasources.v2.clickhouse

import org.apache.gluten.sql.shims.SparkShimLoader

import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
Expand All @@ -39,6 +37,7 @@ import org.apache.spark.sql.delta.sources.{DeltaSourceUtils, DeltaSQLConf}
import org.apache.spark.sql.delta.stats.StatisticsCollection
import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils}
import org.apache.spark.sql.execution.datasources.v2.clickhouse.utils.CHDataSourceUtils
import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil
import org.apache.spark.sql.sources.InsertableRelation
import org.apache.spark.sql.types.StructType

Expand Down Expand Up @@ -136,7 +135,7 @@ class ClickHouseSparkCatalog
sourceQuery: Option[DataFrame],
operation: TableCreationModes.CreationMode): Table = {
val (partitionColumns, maybeBucketSpec) =
SparkShimLoader.getSparkShims.convertPartitionTransforms(partitions)
CatalogUtil.convertPartitionTransforms(partitions)
var newSchema = schema
var newPartitionColumns = partitionColumns
var newBucketSpec = maybeBucketSpec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,8 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
}
partitionColumns.add(partitionColumn)

val (fileSize, modificationTime) =
SparkShimLoader.getSparkShims.getFileSizeAndModificationTime(file)
(fileSize, modificationTime) match {
case (Some(size), Some(time)) =>
fileSizes.add(JLong.valueOf(size))
modificationTimes.add(JLong.valueOf(time))
case _ =>
}
fileSizes.add(file.fileSize)
modificationTimes.add(file.modificationTime)

val otherConstantMetadataColumnValues =
DeltaShimLoader.getDeltaShims.convertRowIndexFilterIdEncoded(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import org.apache.spark.shuffle.utils.CHShuffleUtil
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, CollectList, CollectSet}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate, CollectList, CollectSet}
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, HashPartitioning, Partitioning, RangePartitioning}
Expand All @@ -56,7 +56,7 @@ import org.apache.spark.sql.execution.joins.{BuildSideRelation, ClickHouseBuildS
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.utils.{CHExecUtil, PushDownUtil}
import org.apache.spark.sql.execution.window._
import org.apache.spark.sql.types.{DecimalType, StructType}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.SparkVersionUtil

Expand Down Expand Up @@ -602,7 +602,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
CHFlattenedExpression.sigOr
) ++
ExpressionExtensionTrait.expressionExtensionSigList ++
SparkShimLoader.getSparkShims.bloomFilterExpressionMappings()
Seq(
Sig[BloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN),
Sig[BloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG)
)
}

/** Define backend-specific expression converter. */
Expand Down Expand Up @@ -940,12 +943,6 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {

override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = generate

override def genDecimalRoundExpressionOutput(
decimalType: DecimalType,
toScale: Int): DecimalType = {
SparkShimLoader.getSparkShims.genDecimalRoundExpressionOutput(decimalType, toScale)
}

override def genWindowGroupLimitTransformer(
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ class VeloxIteratorApi extends IteratorApi with Logging {
val locations = filePartitions.flatMap(p => SoftAffinity.getFilePartitionLocations(p))
val (paths, starts, lengths) = getPartitionedFileInfo(partitionFiles).unzip3
val (fileSizes, modificationTimes) = partitionFiles
.map(f => SparkShimLoader.getSparkShims.getFileSizeAndModificationTime(f))
.map(f => (f.fileSize, f.modificationTime))
.collect {
case (Some(size), Some(time)) =>
case (size, time) =>
(JLong.valueOf(size), JLong.valueOf(time))
}
.unzip
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
*/
package org.apache.gluten.expression

import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.utils.VeloxBloomFilter

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression}
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, BloomFilterMightContain, Expression}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.types.DataType
Expand All @@ -43,8 +42,7 @@ case class VeloxBloomFilterMightContain(
extends BinaryExpression {
import VeloxBloomFilterMightContain._

private val delegate =
SparkShimLoader.getSparkShims.newMightContain(bloomFilterExpression, valueExpression)
private val delegate = BloomFilterMightContain(bloomFilterExpression, valueExpression)

override def prettyName: String = "velox_might_contain"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@
*/
package org.apache.gluten.expression.aggregate

import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.utils.VeloxBloomFilter

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.trees.TernaryLike
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand All @@ -47,12 +46,12 @@ case class VeloxBloomFilterAggregate(
extends TypedImperativeAggregate[BloomFilter]
with TernaryLike[Expression] {

private val delegate = SparkShimLoader.getSparkShims.newBloomFilterAggregate[BloomFilter](
private val delegate = BloomFilterAggregate(
child,
estimatedNumItemsExpression,
numBitsExpression,
mutableAggBufferOffset,
inputAggBufferOffset)
inputAggBufferOffset).asInstanceOf[TypedImperativeAggregate[BloomFilter]]

override def prettyName: String = "velox_bloom_filter_agg"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@ package org.apache.gluten.extension
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.datasource.ArrowCSVFileFormat
import org.apache.gluten.datasource.v2.ArrowCSVTable
import org.apache.gluten.sql.shims.SparkShimLoader

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.PermissiveMode
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, PermissiveMode}
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
Expand Down Expand Up @@ -102,6 +101,7 @@ case class ArrowConvertorRule(session: SparkSession) extends Rule[LogicalPlan] {
}

private def checkCsvOptions(csvOptions: CSVOptions, timeZone: String): Boolean = {
val default = new CSVOptions(CaseInsensitiveMap(Map()), csvOptions.columnPruning, timeZone)
csvOptions.headerFlag && !csvOptions.multiLine &&
csvOptions.delimiter.length == 1 &&
csvOptions.quote == '\"' &&
Expand All @@ -112,7 +112,9 @@ case class ArrowConvertorRule(session: SparkSession) extends Rule[LogicalPlan] {
csvOptions.nullValue == "" &&
csvOptions.emptyValueInRead == "" && csvOptions.comment == '\u0000' &&
csvOptions.columnPruning &&
SparkShimLoader.getSparkShims.dateTimestampFormatInReadIsDefaultValue(csvOptions, timeZone)
csvOptions.dateFormatInRead == default.dateFormatInRead &&
csvOptions.timestampFormatInRead == default.timestampFormatInRead &&
csvOptions.timestampNTZFormatInRead == default.timestampNTZFormatInRead
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to confirm timestampNTZFormatInRead exists in Spark 3.3's CSVOptions

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just confirmed in source code. It exists. And the compilation also help ensures this. Thanks.

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ package org.apache.gluten.extension
import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.expression.VeloxBloomFilterMightContain
import org.apache.gluten.expression.aggregate.VeloxBloomFilterAggregate
import org.apache.gluten.sql.shims.SparkShimLoader

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, BloomFilterMightContain, Expression}
import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

Expand All @@ -40,12 +41,42 @@ case class BloomFilterMightContainJointRewriteRule(
out
}

/**
 * Rewrites a vanilla Spark [[BloomFilterAggregate]] into a backend-specific aggregate
 * produced by the supplied factory. Any other expression is returned unchanged.
 *
 * @param expr the expression to inspect
 * @param bloomFilterAggReplacer factory building the replacement aggregate from the
 *        original aggregate's child, estimated-items / num-bits expressions and the
 *        two aggregation buffer offsets
 */
private def replaceBloomFilterAggregate[T](
    expr: Expression,
    bloomFilterAggReplacer: (
        Expression,
        Expression,
        Expression,
        Int,
        Int) => TypedImperativeAggregate[T]): Expression = {
  expr match {
    case agg: BloomFilterAggregate =>
      // Re-create the aggregate through the factory, forwarding every field as-is.
      bloomFilterAggReplacer(
        agg.child,
        agg.estimatedNumItemsExpression,
        agg.numBitsExpression,
        agg.mutableAggBufferOffset,
        agg.inputAggBufferOffset)
    case _ => expr
  }
}

/**
 * Swaps a vanilla Spark [[BloomFilterMightContain]] for the backend implementation
 * produced by the given factory. Any other expression passes through untouched.
 *
 * @param expr the expression to inspect
 * @param mightContainReplacer factory building the replacement from the original's
 *        bloom-filter and value expressions
 */
private def replaceMightContain[T](
    expr: Expression,
    mightContainReplacer: (Expression, Expression) => BinaryExpression): Expression = {
  expr match {
    case mc: BloomFilterMightContain =>
      mightContainReplacer(mc.bloomFilterExpression, mc.valueExpression)
    case _ => expr
  }
}

private def applyForNode(p: SparkPlan) = {
p.transformExpressions {
case e =>
SparkShimLoader.getSparkShims.replaceMightContain(
SparkShimLoader.getSparkShims
.replaceBloomFilterAggregate(e, VeloxBloomFilterAggregate.apply),
replaceMightContain(
replaceBloomFilterAggregate(e, VeloxBloomFilterAggregate.apply),
VeloxBloomFilterMightContain.apply)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ trait ColumnarShuffledJoin extends BaseJoinExec {
// partitioning doesn't satisfy `HashClusteredDistribution`.
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
} else {
SparkShimLoader.getSparkShims.getDistribution(leftKeys, rightKeys)
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.hash.ConsistentHash
import org.apache.gluten.logging.LogLevelUtil
import org.apache.gluten.softaffinity.strategy.{ConsistentHashSoftAffinityStrategy, ExecutorNode}
import org.apache.gluten.sql.shims.SparkShimLoader

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
Expand Down Expand Up @@ -185,7 +184,7 @@ abstract class AffinityManager extends LogLevelUtil with Logging {
val partitions = rddPartitionInfoMap.getIfPresent(rddId)
if (partitions != null) {
val key = partitions
.filter(p => p._1 == SparkShimLoader.getSparkShims.getPartitionId(event.taskInfo))
.filter(p => p._1 == event.taskInfo.partitionId)
.map(pInfo => s"${pInfo._2}_${pInfo._3}_${pInfo._4}")
.sortBy(p => p)
.mkString(",")
Expand Down Expand Up @@ -322,8 +321,7 @@ object SoftAffinityManager extends AffinityManager {
override lazy val detectDuplicateReading: Boolean = SparkEnv.get.conf.getBoolean(
GlutenConfig.GLUTEN_SOFT_AFFINITY_DUPLICATE_READING_DETECT_ENABLED.key,
GlutenConfig.GLUTEN_SOFT_AFFINITY_DUPLICATE_READING_DETECT_ENABLED.defaultValue.get
) &&
SparkShimLoader.getSparkShims.supportDuplicateReadingTracking
)

override lazy val duplicateReadingMaxCacheItems: Int = SparkEnv.get.conf.getInt(
GlutenConfig.GLUTEN_SOFT_AFFINITY_DUPLICATE_READING_MAX_CACHE_ITEMS.key,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@ package org.apache.spark.shuffle

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.vectorized.NativePartitioning

import org.apache.spark.{SparkConf, TaskContext}
import org.apache.spark.{ShuffleUtils, SparkConf, TaskContext}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ShuffleUtils class should be in the shims (it exists in shims/spark34/ etc.). We need to confirm it's available for Spark 3.3 as well

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it also exists in shims/spark33.

import org.apache.spark.internal.config._
import org.apache.spark.shuffle.api.ShuffleExecutorComponents
import org.apache.spark.shuffle.sort.ColumnarShuffleHandle
Expand Down Expand Up @@ -120,12 +119,7 @@ object GlutenShuffleUtils {
startPartition: Int,
endPartition: Int
): Tuple2[Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], Boolean] = {
SparkShimLoader.getSparkShims.getShuffleReaderParam(
handle,
startMapIndex,
endMapIndex,
startPartition,
endPartition)
ShuffleUtils.getReaderParam(handle, startMapIndex, endMapIndex, startPartition, endPartition)
}

def getSortShuffleWriter[K, V](
Expand Down
Loading