Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,9 @@ dist/

/cpp-ch/local-engine/Parser/*_udf
!/cpp-ch/local-engine/Parser/example_udf


# build arrow
dev/arrow_ep/
ep/_ep/

Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package org.apache.gluten.planner.cost

import org.apache.gluten.execution._
import org.apache.gluten.extension.columnar.enumerated.RemoveFilter
import org.apache.gluten.extension.columnar.enumerated.RemoveFilter.NoopFilter
import org.apache.gluten.extension.columnar.transition.{ColumnarToRowLike, RowToColumnarLike}
import org.apache.gluten.planner.plan.GlutenPlanModel.GroupLeafExec
import org.apache.gluten.ras.{Cost, CostModel}
import org.apache.gluten.utils.PlanUtil

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.types.{ArrayType, MapType, StructType}

/**
 * Rough RAS cost model for Velox-backed Gluten. Assigns heuristic per-node
 * costs that bias the planner towards offloading joins, aggregations and
 * windows to Gluten, and away from expensive row/columnar transitions.
 */
class VeloxCostModel extends CostModel[SparkPlan] with Logging {
  import VeloxCostModel._

  /** Sentinel cost for plans the planner must never pick. */
  private val infLongCost = Long.MaxValue

  logInfo(s"Created cost model: ${classOf[VeloxCostModel]}")

  override def costOf(node: SparkPlan): GlutenCost = node match {
    // Group leaves are RAS placeholders; asking for their cost is a planner bug.
    case _: GroupLeafExec => throw new IllegalStateException()
    case _ => GlutenCost(longCostOf(node))
  }

  /** Cost of the whole subtree rooted at `node`, saturating at Long.MaxValue. */
  private def longCostOf(node: SparkPlan): Long = {
    val selfCost = selfLongCostOf(node)

    // Saturating addition so deep trees with huge per-node costs cannot overflow.
    def safeSum(a: Long, b: Long): Long = {
      assert(a >= 0)
      assert(b >= 0)
      val sum = a + b
      if (sum < a || sum < b) Long.MaxValue else sum
    }

    (node.children.map(longCostOf).toList :+ selfCost).reduce(safeSum)
  }

  // A plan next to a scan may be able to push runtime filters down to the scan
  // in Velox, so such joins are costed cheaper than their standalone variants.
  // TODO: Check join build side.
  private def isNextToScanTransformer(plan: SparkPlan): Boolean = {
    plan.children.exists {
      case NoopFilter(_: BasicScanExecTransformer, _) => true
      case _: BasicScanExecTransformer => true
      case _ => false
    }
  }

  /**
   * Heuristic per-node cost. A very rough estimation as of now.
   *
   * Case order matters: guarded transformer cases and the complex-type
   * transition case must stay ahead of the generic fallbacks at the bottom.
   */
  private def selfLongCostOf(node: SparkPlan): Long = {
    node match {
      case _: ProjectExecTransformer => 0L
      case _: ProjectExec => 0L

      case _: ShuffleExchangeExec => 3L
      case _: VeloxResizeBatchesExec => 0L
      case _: ColumnarShuffleExchangeExec => 2L

      // Joins, aggregations and windows are the highest priority for Gluten to
      // offload: vanilla variants are costed well above their transformers.
      case _: BroadcastHashJoinExec => 60L
      case j: BroadcastHashJoinExecTransformer if isNextToScanTransformer(j) => 32L
      case _: BroadcastHashJoinExecTransformer => 52L

      case _: ShuffledHashJoinExec => 80L
      case j: ShuffledHashJoinExecTransformer if isNextToScanTransformer(j) => 32L
      case _: ShuffledHashJoinExecTransformer => 52L

      case _: SortMergeJoinExec => 80L
      case j: SortMergeJoinExecTransformer if isNextToScanTransformer(j) => 32L
      case _: SortMergeJoinExecTransformer => 52L

      case _: HashAggregateExec => 80L
      case _: ObjectHashAggregateExec => 80L
      case _: SortAggregateExec => 80L
      case _: HashAggregateExecTransformer => 52L

      case _: WindowExec => 80L
      case _: WindowExecTransformer => 52L

      case r2c: RowToColumnarExecBase if hasComplexTypes(r2c.schema) =>
        // Avoid moving computation back to native when the transition carries
        // complex types in its schema. Such transitions are observed to be
        // extremely expensive as of now.
        Long.MaxValue

      // Row-to-Velox is observed much more expensive than Velox-to-row.
      case ColumnarToRowExec(_) => 2L
      case RowToColumnarExec(_) => 15L
      case ColumnarToRowLike(_) => 2L
      case RowToColumnarLike(_) => 15L

      case _: NoopFilter =>
        // Free, so the planner prefers trees produced by rule PushFilterToScan.
        0L

      case p if PlanUtil.isGlutenColumnarOp(p) => 2L
      case p if PlanUtil.isVanillaColumnarOp(p) => 3L
      // Other row ops. Usually a vanilla row op.
      case _ => 5L
    }
  }

  /** True if any top-level field of `schema` is a struct, array or map. */
  private def hasComplexTypes(schema: StructType): Boolean = {
    schema.exists(_.dataType match {
      case _: StructType => true
      case _: ArrayType => true
      case _: MapType => true
      case _ => false
    })
  }

  override def costComparator(): Ordering[Cost] = Ordering.Long.on {
    case GlutenCost(value) => value
    case _ => throw new IllegalStateException("Unexpected cost type")
  }

  override def makeInfCost(): Cost = GlutenCost(infLongCost)
}

object VeloxCostModel {
  /** Cost wrapper used by [[VeloxCostModel]]; compared via its underlying long value. */
  case class GlutenCost(value: Long) extends Cost
}
7 changes: 7 additions & 0 deletions cpp/velox/udf/examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,10 @@ target_link_libraries(myudf velox)

add_library(myudaf SHARED "MyUDAF.cc")
target_link_libraries(myudaf velox)

# Example UDF shared libraries loaded by Gluten at runtime.
# NOTE(review): LinkedIdUDF.cc includes <openssl/md5.h> but only links velox —
# confirm OpenSSL symbols are provided transitively by velox; otherwise add
# OpenSSL::Crypto to its target_link_libraries.
add_library(LinkedIdUDF SHARED "LinkedIdUDF.cc")
target_link_libraries(LinkedIdUDF velox)

add_library(ListVectorSimilarityUDF SHARED "ListVectorSimilarityUDF.cc")
target_link_libraries(ListVectorSimilarityUDF velox)

117 changes: 117 additions & 0 deletions cpp/velox/udf/examples/LinkedIdUDF.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@

#include <velox/expression/VectorFunction.h>
#include <velox/functions/Macros.h>
#include <velox/functions/Registerer.h>

#include <openssl/evp.h>
#include <openssl/md5.h>

#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

#include "udf/Udf.h"
#include "udf/examples/UdfCommon.h"

using namespace facebook::velox;
using namespace facebook::velox::exec;

namespace {

static const char* kString="string";

std::string GetHashKey(const std::string& str) {
if (str.empty()) {
throw std::invalid_argument("Input string cannot be null");
}
unsigned char digest[MD5_DIGEST_LENGTH];
MD5(reinterpret_cast<const unsigned char*>(str.c_str()), str.size(), digest);
std::string key;
key.reserve(MD5_DIGEST_LENGTH * 2);
for (unsigned char i : digest) {
char buf[3];
snprintf(buf, sizeof(buf), "%02x", i);
key.append(buf);
}
return key;
}

// Simple-function UDF: maps a varchar argument to its hex MD5 "link id".
// Empty input maps to an empty output rather than an error.
template <typename T>
struct LinkIdUDF {
  VELOX_DEFINE_FUNCTION_TYPES(T);

  void call(out_type<Varchar>& out, const arg_type<Varchar>& arg1) {
    const std::string input(arg1);
    if (input.empty()) {
      out = "";
      return;
    }
    out = GetHashKey(input);
  }
};

class LinkIdUDFRegisterer final : public gluten::UdfRegisterer {
public:
int getNumUdf() override {
return 1;
}

void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override {
udfEntries[index++] = {name_.c_str(), kString, 1, myArg_};
}

void registerSignatures() override {
facebook::velox::registerFunction<LinkIdUDF, Varchar, Varchar>({name_});
}

private:
const std::string name_ = "com.pinterest.hadoop.hive.LinkIdUDF";
const char* myArg_[1] = {kString};
};

// Process-wide list of UDF registerers. A function-local static avoids
// static-initialization-order issues across translation units.
std::vector<std::shared_ptr<gluten::UdfRegisterer>>& globalRegisters() {
  static std::vector<std::shared_ptr<gluten::UdfRegisterer>> instances;
  return instances;
}

// Populates globalRegisters() exactly once; subsequent calls are no-ops.
// NOTE(review): the boolean guard is not thread-safe — assumes registration
// happens on a single loader thread; confirm before calling concurrently.
void setupRegisterers() {
  static bool initialized = false;
  if (!initialized) {
    globalRegisters().push_back(std::make_shared<LinkIdUDFRegisterer>());
    initialized = true;
  }
}
} // namespace

// Entry point queried by Gluten: total number of UDFs across all registerers.
DEFINE_GET_NUM_UDF {
  setupRegisterers();

  int total = 0;
  for (const auto& r : globalRegisters()) {
    total += r->getNumUdf();
  }
  return total;
}

// Entry point queried by Gluten: fills `udfEntries` with every registered
// UDF's metadata; each registerer advances the shared write cursor.
DEFINE_GET_UDF_ENTRIES {
  setupRegisterers();

  int nextSlot = 0;
  for (const auto& r : globalRegisters()) {
    r->populateUdfEntries(nextSlot, udfEntries);
  }
}

// Entry point invoked by Gluten: registers every UDF's Velox signatures.
DEFINE_REGISTER_UDF {
  setupRegisterers();

  for (const auto& r : globalRegisters()) {
    r->registerSignatures();
  }
}
Loading