Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions velox/experimental/cudf/connectors/hive/CudfHiveDataSource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
#include <cudf/stream_compaction.hpp>
#include <cudf/table/table.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/transform.hpp>

#include <cuda_runtime.h>
#include <nvtx3/nvtx3.hpp>
Expand Down Expand Up @@ -99,7 +98,7 @@ CudfHiveDataSource::CudfHiveDataSource(
VELOX_CHECK_NOT_NULL(
tableHandle_, "TableHandle must be an instance of HiveTableHandle");

// Copy subfield filters
// Copy subfield filters.
for (const auto& [k, v] : tableHandle_->subfieldFilters()) {
subfieldFilters_.emplace(k.clone(), v->clone());
// Add fields in the filter to the columns to read if not there
Expand Down Expand Up @@ -493,6 +492,40 @@ void CudfHiveDataSource::setupCudfDataSourceAndOptions() {
dataSource_ = std::move(cudf::io::make_datasources(sourceInfo).front());
}

RowTypePtr readerFilterType = nullptr;
bool hasDecimalFilter = false;
if (subfieldFilters_.size()) {
readerFilterType = [&] {
if (tableHandle_->dataColumns()) {
std::vector<std::string> newNames;
std::vector<TypePtr> newTypes;

for (const auto& name : readColumnNames_) {
// Ensure all columns being read are available to the filter
auto parsedType = tableHandle_->dataColumns()->findChild(name);
newNames.emplace_back(std::move(name));
newTypes.push_back(parsedType);
}

return ROW(std::move(newNames), std::move(newTypes));
} else {
return outputType_;
}
}();

for (const auto& [field, _] : subfieldFilters_) {
if (!field.valid()) {
continue;
}
const auto& fieldName = field.baseName();
const auto fieldType = readerFilterType->findChild(fieldName);
if (fieldType && fieldType->isDecimal()) {
hasDecimalFilter = true;
break;
}
}
}

// Reader options
readerOptions_ =
cudf::io::parquet_reader_options::builder(std::move(sourceInfo))
Expand All @@ -501,6 +534,7 @@ void CudfHiveDataSource::setupCudfDataSourceAndOptions() {
.allow_mismatched_pq_schemas(
cudfHiveConfig_->isAllowMismatchedCudfHiveSchemas())
.timestamp_type(cudfHiveConfig_->timestampType())
.use_jit_filter(hasDecimalFilter)
.build();

// Set skip_bytes and num_bytes if available
Expand All @@ -511,6 +545,7 @@ void CudfHiveDataSource::setupCudfDataSourceAndOptions() {
readerOptions_.set_num_bytes(split_->size());
}

// Set filter expression created in constructor if any subfield filters
if (subfieldFilterExpr_ != nullptr) {
readerOptions_.set_filter(*subfieldFilterExpr_);
}
Expand Down
134 changes: 94 additions & 40 deletions velox/experimental/cudf/exec/CudfHashJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "velox/experimental/cudf/exec/Utilities.h"
#include "velox/experimental/cudf/exec/VeloxCudfInterop.h"
#include "velox/experimental/cudf/expression/AstExpression.h"
#include "velox/experimental/cudf/expression/AstExpressionUtils.h"
#include "velox/experimental/cudf/expression/ExpressionEvaluator.h"

#include "velox/core/PlanNode.h"
Expand Down Expand Up @@ -416,6 +417,8 @@ CudfHashJoinProbe::CudfHashJoinProbe(
// simplify expression
exec::ExprSet exprs({joinNode_->filter()}, operatorCtx_->execCtx());
VELOX_CHECK_EQ(exprs.exprs().size(), 1);
useAstFilter_ = CudfConfig::getInstance().astExpressionEnabled &&
!containsDecimalType(exprs.exprs()[0]);

// Create a reusable evaluator for the filter column. This is expensive to
// build, and the expression + input schema are stable for the lifetime of
Expand All @@ -431,25 +434,27 @@ CudfHashJoinProbe::CudfHashJoinProbe(
// and the column locations in that schema translate to column locations
// in whole tables

// create ast tree
if (joinNode_->isRightJoin() || joinNode_->isRightSemiFilterJoin()) {
createAstTree(
exprs.exprs()[0],
tree_,
scalars_,
buildType_,
probeType_,
rightPrecomputeInstructions_,
leftPrecomputeInstructions_);
} else {
createAstTree(
exprs.exprs()[0],
tree_,
scalars_,
probeType_,
buildType_,
leftPrecomputeInstructions_,
rightPrecomputeInstructions_);
if (useAstFilter_) {
// create ast tree
if (joinNode_->isRightJoin() || joinNode_->isRightSemiFilterJoin()) {
createAstTree(
exprs.exprs()[0],
tree_,
scalars_,
buildType_,
probeType_,
rightPrecomputeInstructions_,
leftPrecomputeInstructions_);
} else {
createAstTree(
exprs.exprs()[0],
tree_,
scalars_,
probeType_,
buildType_,
leftPrecomputeInstructions_,
rightPrecomputeInstructions_);
}
}
}
}
Expand Down Expand Up @@ -803,15 +808,35 @@ std::vector<std::unique_ptr<cudf::table>> CudfHashJoinProbe::innerJoin(
std::vector<std::unique_ptr<cudf::column>> joinedCols;

if (joinNode_->filter()) {
cudfOutputs.push_back(filteredOutputIndices(
leftTableView,
leftIndicesCol,
rightTableView,
rightIndicesCol,
extendedLeftView,
extendedRightView,
cudf::join_kind::INNER_JOIN,
stream));
if (useAstFilter_) {
cudfOutputs.push_back(filteredOutputIndices(
leftTableView,
leftIndicesCol,
rightTableView,
rightIndicesCol,
extendedLeftView,
extendedRightView,
cudf::join_kind::INNER_JOIN,
stream));
} else {
auto filterFunc =
[stream](
std::vector<std::unique_ptr<cudf::column>>&& joinedCols,
cudf::column_view filterColumn) {
auto filterTable =
std::make_unique<cudf::table>(std::move(joinedCols));
auto filteredTable = cudf::apply_boolean_mask(
*filterTable, filterColumn, stream, get_output_mr());
return filteredTable->release();
};
cudfOutputs.push_back(filteredOutput(
leftTableView,
leftIndicesCol,
rightTableView,
rightIndicesCol,
filterFunc,
stream));
}
} else {
cudfOutputs.push_back(unfilteredOutput(
leftTableView,
Expand Down Expand Up @@ -878,15 +903,35 @@ std::vector<std::unique_ptr<cudf::table>> CudfHashJoinProbe::leftJoin(
std::vector<std::unique_ptr<cudf::column>> joinedCols;

if (joinNode_->filter()) {
cudfOutputs.push_back(filteredOutputIndices(
leftTableView,
leftIndicesCol,
rightTableView,
rightIndicesCol,
extendedLeftView,
extendedRightView,
cudf::join_kind::LEFT_JOIN,
stream));
if (useAstFilter_) {
cudfOutputs.push_back(filteredOutputIndices(
leftTableView,
leftIndicesCol,
rightTableView,
rightIndicesCol,
extendedLeftView,
extendedRightView,
cudf::join_kind::LEFT_JOIN,
stream));
} else {
auto filterFunc =
[stream](
std::vector<std::unique_ptr<cudf::column>>&& joinedCols,
cudf::column_view filterColumn) {
auto filterTable =
std::make_unique<cudf::table>(std::move(joinedCols));
auto filteredTable = cudf::apply_boolean_mask(
*filterTable, filterColumn, stream, get_output_mr());
return filteredTable->release();
};
cudfOutputs.push_back(filteredOutput(
leftTableView,
leftIndicesCol,
rightTableView,
rightIndicesCol,
filterFunc,
stream));
}
} else {
cudfOutputs.push_back(unfilteredOutput(
leftTableView,
Expand Down Expand Up @@ -1194,6 +1239,9 @@ std::vector<std::unique_ptr<cudf::table>> CudfHashJoinProbe::leftSemiFilterJoin(
std::unique_ptr<rmm::device_uvector<cudf::size_type>> leftJoinIndices;

if (joinNode_->filter()) {
if (!useAstFilter_) {
VELOX_NYI("Join filter requires AST for semi joins");
}
leftJoinIndices = cudf::mixed_left_semi_join(
leftTableView.select(leftKeyIndices_),
rightTableView.select(rightKeyIndices_),
Expand Down Expand Up @@ -1244,6 +1292,9 @@ CudfHashJoinProbe::rightSemiFilterJoin(

std::unique_ptr<rmm::device_uvector<cudf::size_type>> rightJoinIndices;
if (joinNode_->filter()) {
if (!useAstFilter_) {
VELOX_NYI("Join filter requires AST for semi joins");
}
rightJoinIndices = cudf::mixed_left_semi_join(
rightTableView.select(rightKeyIndices_),
leftTableView.select(leftKeyIndices_),
Expand Down Expand Up @@ -1313,6 +1364,9 @@ std::vector<std::unique_ptr<cudf::table>> CudfHashJoinProbe::antiJoin(

std::unique_ptr<rmm::device_uvector<cudf::size_type>> leftJoinIndices;
if (joinNode_->filter()) {
if (!useAstFilter_) {
VELOX_NYI("Join filter requires AST for anti joins");
}
leftJoinIndices = cudf::mixed_left_anti_join(
leftTableView.select(leftKeyIndices_),
rightTableView.select(rightKeyIndices_),
Expand Down Expand Up @@ -1402,10 +1456,10 @@ RowVectorPtr CudfHashJoinProbe::getOutput() {
for (size_t li = 0; li < leftColumnOutputIndices_.size(); ++li) {
auto outIdx = leftColumnOutputIndices_[li];
auto probeChannel = leftColumnIndicesToGather_[li];
auto leftCudfType =
veloxToCudfTypeId(probeType_->childAt(probeChannel));
auto leftCudfDataType =
veloxToCudfDataType(probeType_->childAt(probeChannel));
auto nullScalar = cudf::make_default_constructed_scalar(
cudf::data_type{leftCudfType}, stream, get_temp_mr());
leftCudfDataType, stream, get_temp_mr());
outCols[outIdx] = cudf::make_column_from_scalar(
*nullScalar, m, stream, get_output_mr());
}
Expand Down
1 change: 1 addition & 0 deletions velox/experimental/cudf/exec/CudfHashJoin.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ class CudfHashJoinProbe : public exec::Operator, public NvtxHelper {
/** @brief Output column positions for right table columns */
std::vector<size_t> rightColumnOutputIndices_;
bool finished_{false};
bool useAstFilter_{true};

// Copied from HashProbe.h
// Indicates whether to skip probe input data processing or not. It only
Expand Down
3 changes: 1 addition & 2 deletions velox/experimental/cudf/expression/AstExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ ColumnOrView ASTExpression::eval(
}
}();
if (finalize) {
const auto requestedType =
cudf::data_type(cudf_velox::veloxToCudfTypeId(expr_->type()));
const auto requestedType = cudf_velox::veloxToCudfDataType(expr_->type());
auto resultView = asView(result);
if (resultView.type() != requestedType) {
result = cudf::cast(resultView, requestedType, stream, mr);
Expand Down
2 changes: 2 additions & 0 deletions velox/experimental/cudf/expression/AstExpression.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

#include <cudf/ast/expressions.hpp>

#include <utility>

namespace facebook::velox::cudf_velox {

const std::string kAstEvaluatorName = "ast";
Expand Down
20 changes: 16 additions & 4 deletions velox/experimental/cudf/expression/AstExpressionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "velox/experimental/cudf/expression/AstUtils.h"
// TODO(kn): in another PR
// #include "velox/experimental/cudf/CudfNoDefaults.h"
#include "velox/experimental/cudf/expression/DecimalUtils.h"

#include "velox/expression/ConstantExpr.h"
#include "velox/expression/FieldReference.h"
Expand Down Expand Up @@ -225,14 +226,22 @@ bool isAstExprSupported(const std::shared_ptr<velox::exec::Expr>& expr) {
using velox::exec::FieldReference;
using Op = cudf::ast::ast_operator;

// reject anything with DECIMAL for now
// @TODO implement DECIMAL in AST and JIT
if (containsDecimalType(expr)) {
LOG(WARNING) << "DECIMAL expression not supported by AST/JIT: "
<< expr->toString();
return false;
}

const auto name =
stripPrefix(expr->name(), CudfConfig::getInstance().functionNamePrefix);
const auto len = expr->inputs().size();

// Literals and field references are always supported
auto isSupportedLiteral = [&](const TypePtr& type) {
try {
auto cudfType = cudf::data_type(veloxToCudfTypeId(type));
auto cudfType = veloxToCudfDataType(type);
return cudf::is_fixed_width(cudfType) ||
cudfType.id() == cudf::type_id::STRING;
} catch (...) {
Expand Down Expand Up @@ -260,8 +269,7 @@ bool isAstExprSupported(const std::shared_ptr<velox::exec::Expr>& expr) {
inputCudfDataTypes.reserve(len);
for (const auto& input : expr->inputs()) {
try {
inputCudfDataTypes.push_back(
cudf::data_type(veloxToCudfTypeId(input->type())));
inputCudfDataTypes.push_back(veloxToCudfDataType(input->type()));
} catch (...) {
return false;
}
Expand Down Expand Up @@ -386,7 +394,11 @@ cudf::ast::expression const& AstContext::addPrecomputeInstructionOnSide(
auto nestedIndices = getNestedColumnIndices(
inputRowSchema[sideIdx].get()->childAt(columnIndex), fieldName);
precomputeInstructions[sideIdx].get().emplace_back(
columnIndex, instruction, newColumnIndex, nestedIndices, node);
columnIndex,
instruction,
newColumnIndex,
std::move(nestedIndices),
node);
}
auto side = static_cast<cudf::ast::table_reference>(sideIdx);
return tree.push(cudf::ast::column_reference(newColumnIndex, side));
Expand Down
Loading
Loading