Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -113,17 +113,6 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
new CloseableCHColumnBatchIterator(iter, Some(pipelineTime))
}

// only set file schema for text format table
/**
 * Pushes the table's file schema down to the native reader, but only for
 * text-format tables — other formats carry their own schema in the file.
 *
 * @param localFilesNode split node that will receive the schema
 * @param scan           scan transformer providing output attributes and data schema
 */
private def setFileSchemaForLocalFiles(
    localFilesNode: LocalFilesNode,
    scan: BasicScanExecTransformer): Unit =
  if (scan.fileFormat == ReadFileFormat.TextReadFormat) {
    // Attribute names are collected without their expression IDs so they
    // match the physical column names in the file.
    val attributeNames = ConverterUtils.collectAttributeNamesWithoutExprId(scan.output)
    val fileSchema = getFileSchema(scan.getDataSchema, attributeNames.asScala.toSeq)
    localFilesNode.setFileSchema(fileSchema)
  }

override def genSplitInfo(
partition: InputPartition,
partitionSchema: StructType,
Expand Down Expand Up @@ -248,9 +237,6 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
case (split, i) =>
split match {
case filesNode: LocalFilesNode if leaves(i).isInstanceOf[BasicScanExecTransformer] =>
setFileSchemaForLocalFiles(
filesNode,
leaves(i).asInstanceOf[BasicScanExecTransformer])
filesNode.toProtobuf.toByteArray
case extensionTableNode: ExtensionTableNode =>
extensionTableNode.toProtobuf.toByteArray
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ case class CHRangeExecTransformer(
nameList,
columnTypeNodes,
null,
null,
extensionNode,
context,
context.nextOperatorId(this.nodeName))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,7 @@ object MergeTreePartsPartitionsUtil extends Logging {
typeNodes,
nameList,
columnTypeNodes,
tableSchema,
transformer.map(_.doTransform(substraitContext)).orNull,
extensionNode,
substraitContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ object CHMergeTreeWriterInjects {
typeNodes,
nameList,
columnTypeNodes,
tableSchema,
null,
extensionNode,
substraitContext,
Expand Down
7 changes: 6 additions & 1 deletion cpp-ch/local-engine/Parser/RelParsers/ReadRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,12 @@ QueryPlanStepPtr ReadRelParser::parseReadRelWithLocalFile(const substrait::ReadR
debug::dumpMessage(local_files, "local_files");
}

auto source = std::make_shared<SubstraitFileSource>(getContext(), header, local_files);
DB::Block tableHeader = header;
if (rel.has_table_schema()) {
tableHeader = TypeParser::buildBlockFromNamedStructWithoutDFS(rel.table_schema());
}

auto source = std::make_shared<SubstraitFileSource>(getContext(), header, local_files, tableHeader);
auto source_pipe = Pipe(source);
auto source_step = std::make_unique<SubstraitFileSourceStep>(getContext(), std::move(source_pipe), "substrait local files");
source_step->setStepDescription("read local files");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ ExcelTextFormatFile::createInputFormat(const DB::Block & header, const std::shar

std::shared_ptr<DB::PeekableReadBuffer> buffer = std::make_unique<DB::PeekableReadBuffer>(*read_buffer);
DB::Names column_names;
column_names.reserve(file_info.schema().names_size());
for (const auto & item : file_info.schema().names())
column_names.reserve(schema_.getNames().size());
for (const auto & item : schema_.getNames())
column_names.push_back(item);

auto txt_input_format
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ class ExcelTextFormatFile : public FormatFile

public:
explicit ExcelTextFormatFile(
DB::ContextPtr context_, const substrait::ReadRel::LocalFiles::FileOrFiles & file_info_, ReadBufferBuilderPtr read_buffer_builder_)
: FormatFile(context_, file_info_, read_buffer_builder_)
DB::ContextPtr context_, const DB::Block& input_header, const substrait::ReadRel::LocalFiles::FileOrFiles & file_info_, ReadBufferBuilderPtr read_buffer_builder_)
: FormatFile(context_, file_info_, read_buffer_builder_),
schema_(input_header.getNamesAndTypesList())
{
}

Expand All @@ -55,6 +56,8 @@ class ExcelTextFormatFile : public FormatFile

private:
DB::FormatSettings createFormatSettings() const;

const DB::NamesAndTypesList schema_;
};


Expand Down
6 changes: 3 additions & 3 deletions cpp-ch/local-engine/Storages/SubstraitSource/FormatFile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ FormatFile::FormatFile(DB::ContextPtr context_, const SubstraitInputFile & file_
}

FormatFilePtr FormatFileUtil::createFile(
DB::ContextPtr context, ReadBufferBuilderPtr read_buffer_builder, const substrait::ReadRel::LocalFiles::FileOrFiles & file)
DB::ContextPtr context, const DB::Block& input_header, ReadBufferBuilderPtr read_buffer_builder, const substrait::ReadRel::LocalFiles::FileOrFiles & file)
{
#if USE_PARQUET
if (file.has_parquet() || (file.has_iceberg() && file.iceberg().has_parquet()))
Expand All @@ -160,9 +160,9 @@ FormatFilePtr FormatFileUtil::createFile(
if (file.has_text())
{
if (ExcelTextFormatFile::useThis(context))
return std::make_shared<ExcelTextFormatFile>(context, file, read_buffer_builder);
return std::make_shared<ExcelTextFormatFile>(context, input_header, file, read_buffer_builder);
else
return std::make_shared<TextFormatFile>(context, file, read_buffer_builder);
return std::make_shared<TextFormatFile>(context, input_header, file, read_buffer_builder);
}
#endif

Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Storages/SubstraitSource/FormatFile.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,6 @@ class FormatFileUtil
{
public:
static FormatFilePtr
createFile(DB::ContextPtr context, ReadBufferBuilderPtr read_buffer_builder, const substrait::ReadRel::LocalFiles::FileOrFiles & file);
createFile(DB::ContextPtr context, const DB::Block& input_header, ReadBufferBuilderPtr read_buffer_builder, const substrait::ReadRel::LocalFiles::FileOrFiles & file);
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@
namespace local_engine
{

static std::vector<FormatFilePtr> initializeFiles(const substrait::ReadRel::LocalFiles & file_infos, const DB::ContextPtr & context)
static std::vector<FormatFilePtr> initializeFiles(const substrait::ReadRel::LocalFiles & file_infos, const DB::ContextPtr & context, const DB::Block & input_header)
{
if (file_infos.items().empty())
return {};
std::vector<FormatFilePtr> files;
const Poco::URI file_uri(file_infos.items().Get(0).uri_file());
ReadBufferBuilderPtr read_buffer_builder = ReadBufferBuilderFactory::instance().createBuilder(file_uri.getScheme(), context);
for (const auto & item : file_infos.items())
files.emplace_back(FormatFileUtil::createFile(context, read_buffer_builder, item));
files.emplace_back(FormatFileUtil::createFile(context, input_header, read_buffer_builder, item));
return files;
}

Expand All @@ -53,9 +53,9 @@ static DB::Block initReadHeader(const DB::Block & block, const FormatFiles & fil
}

SubstraitFileSource::SubstraitFileSource(
const DB::ContextPtr & context_, const DB::Block & outputHeader_, const substrait::ReadRel::LocalFiles & file_infos)
const DB::ContextPtr & context_, const DB::Block & outputHeader_, const substrait::ReadRel::LocalFiles & file_infos, const DB::Block & input_header_)
: DB::ISource(toShared(BaseReader::buildRowCountHeader(outputHeader_)), false)
, files(initializeFiles(file_infos, context_))
, files(initializeFiles(file_infos, context_, input_header_))
, outputHeader(outputHeader_)
, readHeader(initReadHeader(outputHeader, files))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ using FormatFiles = std::vector<FormatFilePtr>;
class SubstraitFileSource : public DB::ISource
{
public:
SubstraitFileSource(const DB::ContextPtr & context_, const DB::Block & header_, const substrait::ReadRel::LocalFiles & file_infos);
SubstraitFileSource(const DB::ContextPtr & context_, const DB::Block & header_, const substrait::ReadRel::LocalFiles & file_infos, const DB::Block & input_header_);
~SubstraitFileSource() override;

String getName() const override { return "SubstraitFileSource"; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ namespace local_engine
{

TextFormatFile::TextFormatFile(
DB::ContextPtr context_, const substrait::ReadRel::LocalFiles::FileOrFiles & file_info_, ReadBufferBuilderPtr read_buffer_builder_)
: FormatFile(context_, file_info_, read_buffer_builder_)
DB::ContextPtr context_, const DB::Block& input_header, const substrait::ReadRel::LocalFiles::FileOrFiles & file_info_, ReadBufferBuilderPtr read_buffer_builder_)
: FormatFile(context_, file_info_, read_buffer_builder_),
schema_(input_header.getNamesAndTypesList())
{
}

Expand Down
12 changes: 5 additions & 7 deletions cpp-ch/local-engine/Storages/SubstraitSource/TextFormatFile.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,19 @@ class TextFormatFile : public FormatFile
{
public:
explicit TextFormatFile(
DB::ContextPtr context_, const substrait::ReadRel::LocalFiles::FileOrFiles & file_info_, ReadBufferBuilderPtr read_buffer_builder_);
DB::ContextPtr context_, const DB::Block& input_header, const substrait::ReadRel::LocalFiles::FileOrFiles & file_info_, ReadBufferBuilderPtr read_buffer_builder_);
~TextFormatFile() override = default;

FormatFile::InputFormatPtr
createInputFormat(const DB::Block & header, const std::shared_ptr<const DB::ActionsDAG> & filter_actions_dag = nullptr) override;

DB::NamesAndTypesList getSchema() const
{
const auto & schema = file_info.schema();
auto header = TypeParser::buildBlockFromNamedStructWithoutDFS(schema);
return header.getNamesAndTypesList();
}
DB::NamesAndTypesList getSchema() const { return schema_; }

bool supportSplit() const override { return true; }
String getFileFormat() const override { return "HiveText"; }

private:
const DB::NamesAndTypesList schema_;
};

}
Expand Down
31 changes: 25 additions & 6 deletions cpp/velox/substrait/SubstraitToVeloxPlan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1282,26 +1282,45 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
SubstraitParser::parseColumnTypes(baseSchema, columnTypes);
}

// Velox requires Filter Pushdown must being enabled.
bool filterPushdownEnabled = true;
auto names = colNameList;
auto types = veloxTypeList;
auto dataColumns = ROW(std::move(names), std::move(types));
// The columns we project from the file.
auto baseSchema = ROW(std::move(names), std::move(types));
// The columns present in the table, if not available default to the baseSchema.
auto tableSchema = baseSchema;
if (readRel.has_table_schema()) {
const auto& tableSchemaStruct = readRel.table_schema();
std::vector<std::string> tableColNames;
std::vector<TypePtr> tableColTypes;
tableColNames.reserve(tableSchemaStruct.names().size());
for (const auto& name : tableSchemaStruct.names()) {
std::string fieldName = name;
if (asLowerCase) {
folly::toLowerAscii(fieldName);
}
tableColNames.emplace_back(fieldName);
}
tableColTypes = SubstraitParser::parseNamedStruct(tableSchemaStruct, asLowerCase);
tableSchema = ROW(std::move(tableColNames), std::move(tableColTypes));
}

// Velox requires Filter Pushdown must being enabled.
bool filterPushdownEnabled = true;
std::shared_ptr<connector::hive::HiveTableHandle> tableHandle;
if (!readRel.has_filter()) {
tableHandle = std::make_shared<connector::hive::HiveTableHandle>(
kHiveConnectorId, "hive_table", filterPushdownEnabled, common::SubfieldFilters{}, nullptr, dataColumns);
kHiveConnectorId, "hive_table", filterPushdownEnabled, common::SubfieldFilters{}, nullptr, tableSchema);
} else {
common::SubfieldFilters subfieldFilters;
auto remainingFilter = exprConverter_->toVeloxExpr(readRel.filter(), dataColumns);
auto remainingFilter = exprConverter_->toVeloxExpr(readRel.filter(), baseSchema);

tableHandle = std::make_shared<connector::hive::HiveTableHandle>(
kHiveConnectorId,
"hive_table",
filterPushdownEnabled,
std::move(subfieldFilters),
remainingFilter,
dataColumns);
tableSchema);
}

// Get assignments and out names.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ case class IcebergScanTransformer(
override lazy val getPartitionSchema: StructType =
GlutenIcebergSourceUtil.getReadPartitionSchema(scan)

override def getDataSchema: StructType = new StructType()
override def getDataSchema: StructType = GlutenIcebergSourceUtil.getDataSchema(scan)

// TODO: get root paths from table.
override def getRootPathsInternal: Seq[String] = Seq.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,17 @@ object GlutenIcebergSourceUtil {
throw new UnsupportedOperationException("Only support iceberg SparkBatchQueryScan.")
}

/**
 * Extracts the table data schema from an Iceberg scan.
 *
 * Only [[SparkBatchQueryScan]] is supported; the schema is taken from the
 * partition spec of the first file scan task.
 *
 * @throws UnsupportedOperationException if the scan is not a
 *         SparkBatchQueryScan or no file scan task is available
 */
def getDataSchema(sparkScan: Scan): StructType = sparkScan match {
  case scan: SparkBatchQueryScan =>
    val tasks = scan.tasks().asScala
    // Use headOption instead of `foreach(task => return ...)`: a nonlocal
    // `return` inside a lambda is exception-based and deprecated in Scala 3.
    asFileScanTask(tasks.toList).headOption
      .map(task => SparkSchemaUtil.convert(task.spec().schema()))
      .getOrElse(
        throw new UnsupportedOperationException(
          "Failed to get data schema from iceberg SparkBatchQueryScan."))
  case _ =>
    throw new UnsupportedOperationException("Only support iceberg SparkBatchQueryScan.")
}

private def asFileScanTask(tasks: List[ScanTask]): List[FileScanTask] = {
if (tasks.forall(_.isFileScanTask)) {
tasks.map(_.asFileScanTask())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,10 @@
package org.apache.gluten.substrait.rel;

import org.apache.gluten.config.GlutenConfig;
import org.apache.gluten.expression.ConverterUtils;
import org.apache.gluten.substrait.utils.SubstraitUtil;

import io.substrait.proto.NamedStruct;
import io.substrait.proto.ReadRel;
import io.substrait.proto.Type;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
Expand Down Expand Up @@ -106,21 +103,6 @@ public void setFileSchema(StructType schema) {
this.fileSchema = schema;
}

/**
 * Serializes {@code fileSchema} into a Substrait {@link NamedStruct} proto.
 *
 * <p>When {@code fileSchema} is {@code null}, an empty NamedStruct is
 * returned (no names, no struct types).
 */
private NamedStruct buildNamedStruct() {
  NamedStruct.Builder named = NamedStruct.newBuilder();
  if (fileSchema != null) {
    Type.Struct.Builder struct = Type.Struct.newBuilder();
    for (StructField field : fileSchema.fields()) {
      // Field types and names are emitted in lockstep: the i-th name
      // corresponds to the i-th struct type.
      struct.addTypes(
          ConverterUtils.getTypeNode(field.dataType(), field.nullable()).toProtobuf());
      named.addNames(ConverterUtils.normalizeColName(field.name()));
    }
    named.setStruct(struct.build());
  }
  return named.build();
}

@Override
public List<String> preferredLocations() {
return this.preferredLocations;
Expand Down Expand Up @@ -195,7 +177,7 @@ public ReadRel.LocalFiles toProtobuf() {
ReadRel.LocalFiles.FileOrFiles.metadataColumn.newBuilder();
fileBuilder.addMetadataColumns(mcBuilder.build());
}
NamedStruct namedStruct = buildNamedStruct();
NamedStruct namedStruct = org.apache.gluten.utils.SubstraitUtil.createNamedStruct(fileSchema);
fileBuilder.setSchema(namedStruct);

if (!otherMetadataColumns.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.substrait.proto.ReadRel;
import io.substrait.proto.Rel;
import io.substrait.proto.RelCommon;
import org.apache.spark.sql.types.StructType;

import java.io.Serializable;
import java.util.ArrayList;
Expand All @@ -35,18 +36,21 @@ public class ReadRelNode implements RelNode, Serializable {
private final List<TypeNode> types = new ArrayList<>();
private final List<String> names = new ArrayList<>();
private final List<ColumnTypeNode> columnTypeNodes = new ArrayList<>();
private final StructType tableSchema;
private final ExpressionNode filterNode;
private final AdvancedExtensionNode extensionNode;
private boolean streamKafka = false;

ReadRelNode(
List<TypeNode> types,
List<String> names,
StructType tableSchema,
ExpressionNode filterNode,
List<ColumnTypeNode> columnTypeNodes,
AdvancedExtensionNode extensionNode) {
this.types.addAll(types);
this.names.addAll(names);
this.tableSchema = tableSchema;
this.filterNode = filterNode;
this.columnTypeNodes.addAll(columnTypeNodes);
this.extensionNode = extensionNode;
Expand All @@ -69,6 +73,10 @@ public Rel toProtobuf() {
readBuilder.setBaseSchema(nStructBuilder.build());
readBuilder.setStreamKafka(streamKafka);

if (tableSchema != null) {
readBuilder.setTableSchema(SubstraitUtil.createNamedStruct(tableSchema));
}

if (filterNode != null) {
readBuilder.setFilter(filterNode.toProtobuf());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import io.substrait.proto.*;
import org.apache.spark.sql.catalyst.expressions.Attribute;
import org.apache.spark.sql.types.StructType;

import java.util.List;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -149,12 +150,13 @@ public static RelNode makeReadRel(
List<TypeNode> types,
List<String> names,
List<ColumnTypeNode> columnTypeNodes,
StructType tableSchema,
ExpressionNode filter,
AdvancedExtensionNode extensionNode,
SubstraitContext context,
Long operatorId) {
context.registerRelToOperator(operatorId);
return new ReadRelNode(types, names, filter, columnTypeNodes, extensionNode);
return new ReadRelNode(types, names, tableSchema, filter, columnTypeNodes, extensionNode);
}

public static RelNode makeReadRelForInputIterator(
Expand Down
Loading
Loading