diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 02a76f69f0..9322149e46 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -105,6 +105,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Log] -> CometLog, classOf[Log2] -> CometLog2, classOf[Log10] -> CometLog10, + classOf[Logarithm] -> CometLogarithm, classOf[Multiply] -> CometMultiply, classOf[Pow] -> CometScalarFunction("pow"), classOf[Rand] -> CometRand, diff --git a/spark/src/main/scala/org/apache/comet/serde/math.scala b/spark/src/main/scala/org/apache/comet/serde/math.scala index 5a0393142a..b258ca6fae 100644 --- a/spark/src/main/scala/org/apache/comet/serde/math.scala +++ b/spark/src/main/scala/org/apache/comet/serde/math.scala @@ -19,7 +19,7 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Tan, Unhex} +import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Logarithm, Tan, Unhex} import org.apache.spark.sql.types.{DecimalType, NumericType} import org.apache.comet.CometSparkSessionExtensions.withInfo @@ -138,6 +138,20 @@ object CometLog2 extends CometExpressionSerde[Log2] with MathExprBase { } } +object CometLogarithm extends CometExpressionSerde[Logarithm] with MathExprBase { + override def convert( + expr: Logarithm, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + // Spark's Logarithm(left=base, right=value) returns null when result is NaN, + // which happens when base <= 0 or value <= 0. Apply nullIfNegative to both.
+ val leftExpr = exprToProtoInternal(nullIfNegative(expr.left), inputs, binding) + val rightExpr = exprToProtoInternal(nullIfNegative(expr.right), inputs, binding) + val optExpr = scalarFunctionExprToProto("log", leftExpr, rightExpr) + optExprWithInfo(optExpr, expr, expr.left, expr.right) + } +} + object CometHex extends CometExpressionSerde[Hex] with MathExprBase { override def convert( expr: Hex, diff --git a/spark/src/test/resources/sql-tests/expressions/math/tan.sql b/spark/src/test/resources/sql-tests/expressions/math/tan.sql index 21bd44f907..9496844804 100644 --- a/spark/src/test/resources/sql-tests/expressions/math/tan.sql +++ b/spark/src/test/resources/sql-tests/expressions/math/tan.sql @@ -16,6 +16,7 @@ -- under the License. -- ConfigMatrix: parquet.enable.dictionary=false,true +-- Config: spark.comet.expression.Tan.allowIncompatible=true statement CREATE TABLE test_tan(d double) USING parquet diff --git a/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala b/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala index 020759a7a6..5a0b34e056 100644 --- a/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala @@ -102,7 +102,7 @@ class CometSqlFileTestSuite extends CometTestBase with AdaptiveSparkPlanHelper { case SparkAnswerOnly => checkSparkAnswer(sql) case WithTolerance(tol) => - checkSparkAnswerWithTolerance(sql, tol) + checkSparkAnswerAndOperatorWithTolerance(sql, tol) case ExpectFallback(reason) => checkSparkAnswerAndFallbackReason(sql, reason) case Ignore(reason) => diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 33c1d444b9..138b073b87 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -190,6 +190,17 @@ abstract class CometTestBase 
internalCheckSparkAnswer(df, assertCometNative = false, withTol = Some(absTol)) } + /** + * Check that the query returns the correct results when Comet is enabled and that Comet + * replaced all possible operators. Use the provided `absTol` when comparing floating-point + * results. + */ + protected def checkSparkAnswerAndOperatorWithTolerance( + query: String, + absTol: Double = 1e-6): (SparkPlan, SparkPlan) = { + internalCheckSparkAnswer(sql(query), assertCometNative = true, withTol = Some(absTol)) + } + /** + * Check that the query returns the correct results when Comet is enabled and that Comet + * replaced all possible operators except for those specified in the excluded list.