Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
* limitations under the License.
*/
package org.apache.spark.sql.delta

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.delta.stats.DeltaScan

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources

import org.apache.gluten.backendsapi.BackendsApiManager

import org.apache.spark.sql.SparkSession
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1408,8 +1408,11 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
|(8, '4Z+i4Z+g4Z+i4Z+lLeGfoeGfoS3hn6Hhn6M='),
|(9, null),
|(10, '4Keo4Kem4Keo4KerLeCnp+Cnpy3gp6fgp6k='),
|(11, 'MjAyNS0xMS0xMg==')
|(11, 'MjAyNS0xMS0xMg=='),
|(12, '4LmS4LmQ4LmS4LmVLeC5keC5kS3guZHguZM=')
|""".stripMargin)
// base64 inputs decode to local digit dates:
// 1-3 Arabic-Indic, 5 Persian, 7 Devanagari, 8 Khmer, 10 Bengali, 11 ASCII, 12 Thai
var query_sql = """
|select
|from_unixtime(unix_timestamp(cast(unbase64(d) as string), 'yyyy-MM-dd')),
Expand All @@ -1433,6 +1436,28 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
|""".stripMargin
compareResultsAgainstVanillaSpark(query_sql, true, { _ => })

query_sql = """
|select from_unixtime(
| unix_timestamp(
| regexp_replace(
| cast(unbase64('4LmS4LmQ4LmS4LmVLeC5keC5kS3guZHguZM=') as string),
| '-0', '-'),
| 'yyyy-MM-dd'),
| 'yyyy-MM-dd')
|""".stripMargin
compareResultsAgainstVanillaSpark(query_sql, true, { _ => })

query_sql = """
|select from_unixtime(
| unix_timestamp(
| regexp_replace(
| cast(unbase64('4Z+i4Z+g4Z+i4Z+lLeGfoeGfoS3hn6Hhn6M=') as string),
| '-0', '-'),
| 'yyyy-MM-dd'),
| 'yyyy-MM-dd')
|""".stripMargin
compareResultsAgainstVanillaSpark(query_sql, true, { _ => })

sql("drop table tb_local_date")
}
}
Expand Down
229 changes: 175 additions & 54 deletions cpp-ch/local-engine/Functions/LocalDigitsToAsciiDigitForDate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,31 @@
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <simdjson/implementation_detection.h>
#if SIMDJSON_IMPLEMENTATION_ICELAKE && defined(__AVX512F__) && defined(__AVX512BW__)
#include <simdjson/icelake/simd.h>
namespace simdjson_impl = simdjson::icelake::simd;
#elif SIMDJSON_IMPLEMENTATION_HASWELL && defined(__AVX2__)
#include <simdjson/haswell/simd.h>
namespace simdjson_impl = simdjson::haswell::simd;
#elif SIMDJSON_IMPLEMENTATION_WESTMERE && defined(__SSE4_2__)
#include <simdjson/westmere/simd.h>
namespace simdjson_impl = simdjson::westmere::simd;
#elif SIMDJSON_IMPLEMENTATION_ARM64
#include <simdjson/arm64/simd.h>
namespace simdjson_impl = simdjson::arm64::simd;
#elif SIMDJSON_IMPLEMENTATION_PPC64
#include <simdjson/ppc64/simd.h>
namespace simdjson_impl = simdjson::ppc64::simd;
#elif SIMDJSON_IMPLEMENTATION_LSX
#include <simdjson/lsx/simd.h>
namespace simdjson_impl = simdjson::lsx::simd;
#elif SIMDJSON_IMPLEMENTATION_LASX
#include <simdjson/lasx/simd.h>
namespace simdjson_impl = simdjson::lasx::simd;
#else
#define SIMDJSON_NO_SIMD 1
#endif
#include <boost/iostreams/detail/select.hpp>
#include <Common/Exception.h>
#include <Common/logger_useful.h>
Expand Down Expand Up @@ -97,7 +122,9 @@ class LocalDigitsToAsciiDigitForDateFunction : public DB::IFunction
getName(),
data_col->getName());
auto date_str = col_str->getDataAt(0);
auto new_str = convertLocalDigit(date_str);
std::string new_str;
if (!convertLocalDigitIfNeeded(date_str, new_str))
return arguments[0].column;
auto new_data_col = data_col->cloneEmpty();
new_data_col->insertData(new_str.c_str(), new_str.size());
return DB::ColumnConst::create(std::move(new_data_col), input_rows_count);
Expand All @@ -120,62 +147,43 @@ class LocalDigitsToAsciiDigitForDateFunction : public DB::IFunction
getName(),
data_col->getName());

auto nested_data_col = DB::removeNullable(arguments[0].column);
bool has_local_digit = false;
size_t row_index = 0;
for (row_index = 0; row_index < input_rows_count; ++row_index)
std::string converted;
DB::MutableColumnPtr res_col;
for (size_t row_index = 0; row_index < input_rows_count; ++row_index)
{
if (null_map && (*null_map)[row_index])
{
if (res_col)
res_col->insertDefault();
continue;
}
auto str = col_str->getDataAt(row_index);
if (hasLocalDigit(str))
if (convertLocalDigitIfNeeded(str, converted))
{
has_local_digit = true;
break;
if (!res_col)
{
res_col = data_col->cloneEmpty();
if (row_index)
res_col->insertManyFrom(*data_col, 0, row_index);
}
LOG_DEBUG(
getLogger("LocalDigitsToAsciiDigitForDateFunction"),
"Converted local digit string {} to ascii digit string: {}",
col_str->getDataAt(row_index).toString(),
converted);
res_col->insertData(converted.c_str(), converted.size());
}
}

if (!has_local_digit)
{
// No local language digits found, return the original column
return arguments[0].column;
}

auto res_col = data_col->cloneEmpty();
if (row_index)
{
res_col->insertManyFrom(*data_col, 0, row_index);
}
for (; row_index < input_rows_count; ++row_index)
{
if (null_map && (*null_map)[row_index])
else if (res_col)
{
res_col->insertDefault();
continue;
res_col->insertFrom(*data_col, row_index);
}
auto str = convertLocalDigit(col_str->getDataAt(row_index));
LOG_ERROR(getLogger("LocalDigitsToAsciiDigitForDateFunction"), "Converted local digit string {} to ascii digit string: {}", col_str->getDataAt(row_index).toString(), str);
res_col->insertData(str.c_str(), str.size());
}
if (!res_col)
return arguments[0].column;
return res_col;
}

private:
// Cheap heuristic for "this string may contain local-language digits".
// Only the first byte is inspected: an empty string, or one whose first byte is an
// ASCII digit ('0'..'9'), is reported as having no local digits; anything else is
// reported as possibly having them.
bool hasLocalDigit(StringRef str) const
{
    if (str.size == 0)
        return false;

    // In most cases the first byte of a date string is a digit, so this one
    // comparison decides the common case.
    const char first = str.data[0];
    const bool starts_with_ascii_digit = first >= '0' && first <= '9';
    return !starts_with_ascii_digit;
}

char toAsciiDigit(char32_t c) const {
// In Thai and Persian, dates typically do not use the Gregorian calendar.
// This may cause failures in unix_timestamp parsing.
Expand All @@ -195,41 +203,154 @@ class LocalDigitsToAsciiDigitForDateFunction : public DB::IFunction
return 0;
}

String convertLocalDigit(const StringRef & str) const
bool hasNonAsciiSimd(const char * data, size_t size) const
{
#if SIMDJSON_NO_SIMD
const unsigned char * bytes = reinterpret_cast<const unsigned char *>(data);
for (size_t i = 0; i < size; ++i)
{
if (bytes[i] & 0x80)
return true;
}
return false;
#else
using simd8_u8 = simdjson_impl::simd8<uint8_t>;
constexpr size_t kBlockSize = simd8_u8::SIZE;
size_t i = 0;
for (; i + kBlockSize <= size; i += kBlockSize)
{
if (!simd8_u8::load(reinterpret_cast<const uint8_t *>(data + i)).is_ascii())
return true;
}
for (; i < size; ++i)
{
if (static_cast<unsigned char>(data[i]) & 0x80)
return true;
}
return false;
#endif
}

// Rewrites local-language decimal digits found in a UTF-8 string as ASCII digits.
//
// str:    input bytes, interpreted as UTF-8.
// result: caller-owned output buffer (reused across calls to avoid reallocations).
//         It is left untouched when the input is empty or pure ASCII; otherwise it is
//         overwritten with the converted copy of `str`.
// Returns true when at least one digit was replaced, in which case `result` holds the
// string to use. Returns false when the caller should keep the original string (the
// content of `result` is then not meaningful).
//
// Malformed input (a stray continuation byte, an invalid lead byte, a sequence truncated
// by the end of the string, or a lead byte not followed by continuation bytes) is copied
// through one byte at a time, so the loop always advances and never reads past `str.size`.
// Only the lead/continuation structure is checked; overlong encodings are not rejected.
bool convertLocalDigitIfNeeded(StringRef str, std::string & result) const
{
    if (!str.size)
        return false;
    // Pure-ASCII input cannot contain local digits: skip decoding and copying entirely.
    if (!hasNonAsciiSimd(str.data, str.size))
        return false;

    result.clear();
    result.reserve(str.size);
    bool has_local_digit = false;

    const auto byte_at = [&str](size_t pos) { return static_cast<unsigned char>(str.data[pos]); };
    const auto is_continuation = [](unsigned char b) { return (b & 0xC0) == 0x80; };

    size_t i = 0;
    while (i < str.size)
    {
        const unsigned char c = byte_at(i);

        // Sequence length announced by the lead byte; 0 means "not a valid lead byte".
        size_t len = 0;
        if ((c & 0x80) == 0)
            len = 1;
        else if ((c & 0xE0) == 0xC0)
            len = 2;
        else if ((c & 0xF0) == 0xE0)
            len = 3;
        else if ((c & 0xF8) == 0xF0)
            len = 4;

        // The whole sequence must fit in the string and consist of continuation bytes.
        bool well_formed = len != 0 && len <= str.size - i;
        for (size_t k = 1; well_formed && k < len; ++k)
            well_formed = is_continuation(byte_at(i + k));

        if (len == 1 || !well_formed)
        {
            // ASCII byte, or a malformed byte passed through unchanged.
            result.push_back(static_cast<char>(c));
            i += 1;
            continue;
        }

        // 0 means "this code point is not a local digit".
        char digit = 0;
        if (len == 2)
        {
            const unsigned char b1 = byte_at(i + 1);
            if (c == 0xD9 && b1 >= 0xA0 && b1 <= 0xA9) // Arabic-Indic, U+0660..U+0669
                digit = static_cast<char>('0' + (b1 - 0xA0));
            else if (c == 0xDB && b1 >= 0xB0 && b1 <= 0xB9) // Eastern Arabic-Indic (Persian), U+06F0..U+06F9
                digit = static_cast<char>('0' + (b1 - 0xB0));
            else
                digit = toAsciiDigit(static_cast<char32_t>(((c & 0x1F) << 6) | (b1 & 0x3F)));
        }
        else if (len == 3)
        {
            const unsigned char b1 = byte_at(i + 1);
            const unsigned char b2 = byte_at(i + 2);
            if (c == 0xE0 && (b1 == 0xA5 || b1 == 0xA7) && b2 >= 0xA6 && b2 <= 0xAF) // Devanagari / Bengali
                digit = static_cast<char>('0' + (b2 - 0xA6));
            else if (c == 0xE0 && b1 == 0xB9 && b2 >= 0x90 && b2 <= 0x99) // Thai, U+0E50..U+0E59
                digit = static_cast<char>('0' + (b2 - 0x90));
            else if (c == 0xE1 && b1 == 0x9F && b2 >= 0xA0 && b2 <= 0xA9) // Khmer, U+17E0..U+17E9
                digit = static_cast<char>('0' + (b2 - 0xA0));
            else
                digit = toAsciiDigit(static_cast<char32_t>(((c & 0x0F) << 12) | ((b1 & 0x3F) << 6) | (b2 & 0x3F)));
        }
        else
        {
            const unsigned char b1 = byte_at(i + 1);
            const unsigned char b2 = byte_at(i + 2);
            const unsigned char b3 = byte_at(i + 3);
            digit = toAsciiDigit(
                static_cast<char32_t>(((c & 0x07) << 18) | ((b1 & 0x3F) << 12) | ((b2 & 0x3F) << 6) | (b3 & 0x3F)));
        }

        if (digit)
        {
            result.push_back(digit);
            has_local_digit = true;
        }
        else
        {
            // Not a digit: keep the original encoded bytes as they are.
            for (size_t k = 0; k < len; ++k)
                result.push_back(static_cast<char>(byte_at(i + k)));
        }
        i += len;
    }
    return has_local_digit;
}
};

Expand Down