From bfbd8a282ddfe781b6a512a6d4aea26dc32264f8 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Fri, 30 May 2025 23:11:11 +0800 Subject: [PATCH] Fix bug in casting string to timestamp: Spark400+ and DB35 do not support pattern: spaces + Thh:mm:ss Signed-off-by: Chong Gao --- docs/compatibility.md | 6 ------ .../src/main/python/cast_test.py | 20 +------------------ .../com/nvidia/spark/rapids/GpuCast.scala | 4 ++-- .../nvidia/spark/rapids/VersionUtils.scala | 18 +++++++++++++++++ 4 files changed, 21 insertions(+), 27 deletions(-) diff --git a/docs/compatibility.md b/docs/compatibility.md index 755f1372776..d478ba9f894 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -753,12 +753,6 @@ The following formats/patterns are supported on the GPU. Timezone of UTC is assu | `"tomorrow"` | Yes | | `"yesterday"` | Yes | -### String to Timestamp -GPU Aligns to Spark except a known case which is actually a Spark bug. -Spark 35x supports the following case(spaces + Thh:mm:ss) while Spark 400 does not: -cast(' T00:00:00' as timestamp) -For more details, refer to [bug link](https://github.com/NVIDIA/spark-rapids-jni/issues/3401) - ### Constant Folding ConstantFolding is an operator optimization rule in Catalyst that replaces expressions that can diff --git a/integration_tests/src/main/python/cast_test.py b/integration_tests/src/main/python/cast_test.py index e9418c34039..9305cfc7b0a 100644 --- a/integration_tests/src/main/python/cast_test.py +++ b/integration_tests/src/main/python/cast_test.py @@ -935,6 +935,7 @@ def _gen_df(spark): ("T23:17:50",), ("T23:17:50",), ("T23:17:50",), + (" \r\n\tT23:17:50",), # This is testing issue: https://github.com/NVIDIA/spark-rapids-jni/issues/3401 ("T23:17:50 \r\n\t",), ("T00",), ("T1:2",), @@ -963,25 +964,6 @@ def _query(spark): assert_gpu_and_cpu_are_equal_collect(lambda spark: _query(spark)) -# Spark 400 and DB35 can not handle pattern: left spaces + Thh:mm:ss, refer to the bug link -@pytest.mark.skipif(is_spark_400_or_later() or 
is_databricks_version_or_later(14, 3), - reason="https://github.com/NVIDIA/spark-rapids-jni/issues/3401") -def test_cast_string_to_timestamp_for_just_time_spaces_leading(): - def _gen_df(spark): - return spark.createDataFrame( - [ - (" \r\n\tT23:17:50 \r\n\t",), - (" \r\n\tT23:17:50",), - ], - 'str_col string') - - def _query(spark): - spark._jvm.com.nvidia.spark.rapids.jni.GpuTimeZoneDB.cacheDatabase(2200) - return _gen_df(spark).selectExpr("cast(str_col as timestamp)") - - assert_gpu_and_cpu_are_equal_collect(lambda spark: _query(spark)) - - def test_cast_string_to_timestamp_valid_just_time_with_timezone(): # For the just time strings, will get current date to fill the missing date. # E.g.: "T00:00:00" will be "2025-05-23T00:00:00" diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index af9b534ec2f..c18dfbe2f7b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -1370,8 +1370,8 @@ object GpuCast { defaultTimeZone: Option[String] = Option.empty[String]): ColumnVector = { val tz = defaultTimeZone.getOrElse("Z") val normalizedTZ = ZoneId.of(tz, ZoneId.SHORT_IDS).normalized().toString - val isSpark320 = VersionUtils.cmpSparkVersion(3, 2, 0) == 0 - closeOnExcept(CastStrings.toTimestamp(input, normalizedTZ, ansiMode, isSpark320)) { result => + val versionForJni = VersionUtils.getVersionForJni + closeOnExcept(CastStrings.toTimestamp(input, normalizedTZ, ansiMode, versionForJni)) { result => if (ansiMode && result == null) { throw new DateTimeException("One or more values is not a valid timestamp") } else { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/VersionUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/VersionUtils.scala index a9ac8fc1b72..3456f0ca66f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/VersionUtils.scala +++ 
b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/VersionUtils.scala @@ -16,6 +16,8 @@ package com.nvidia.spark.rapids +import com.nvidia.spark.rapids.jni.{Version => VersionForJni, Platform => PlatformForJni} + object VersionUtils { lazy val isSpark320OrLater: Boolean = cmpSparkVersion(3, 2, 0) >= 0 @@ -43,4 +45,20 @@ object VersionUtils { val sparkFullVersion = ((sparkMajor.toLong * 1000) + sparkMinor) * 1000 + sparkBugfix sparkFullVersion.compareTo(fullVersion) } + + /** + * Builds the `com.nvidia.spark.rapids.jni.Version` passed to JNI from the current shim version. + * Carries the platform and three version fields; any fourth shim field is dropped. + */ + def getVersionForJni: VersionForJni = { + val sparkShimVersion = ShimLoader.getShimVersion + sparkShimVersion match { + case SparkShimVersion(a, b, c) => + new VersionForJni(PlatformForJni.SPARK, a, b, c) + case DatabricksShimVersion(a, b, c, _) => + new VersionForJni(PlatformForJni.DATABRICKS, a, b, c) + case ClouderaShimVersion(a, b, c, _) => + new VersionForJni(PlatformForJni.CLOUDERA, a, b, c) + } + } }