diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala index 62ca100cc354..160ffa1425fa 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala @@ -113,17 +113,6 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { new CloseableCHColumnBatchIterator(iter, Some(pipelineTime)) } - // only set file schema for text format table - private def setFileSchemaForLocalFiles( - localFilesNode: LocalFilesNode, - scan: BasicScanExecTransformer): Unit = { - if (scan.fileFormat == ReadFileFormat.TextReadFormat) { - val names = - ConverterUtils.collectAttributeNamesWithoutExprId(scan.output) - localFilesNode.setFileSchema(getFileSchema(scan.getDataSchema, names.asScala.toSeq)) - } - } - override def genSplitInfo( partition: InputPartition, partitionSchema: StructType, @@ -248,9 +237,6 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { case (split, i) => split match { case filesNode: LocalFilesNode if leaves(i).isInstanceOf[BasicScanExecTransformer] => - setFileSchemaForLocalFiles( - filesNode, - leaves(i).asInstanceOf[BasicScanExecTransformer]) filesNode.toProtobuf.toByteArray case extensionTableNode: ExtensionTableNode => extensionTableNode.toProtobuf.toByteArray diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHRangeExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHRangeExecTransformer.scala index bdb716c67640..5acd7d7abac9 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHRangeExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHRangeExecTransformer.scala @@ -97,6 +97,7 @@ case class CHRangeExecTransformer( nameList, columnTypeNodes, null, + null, extensionNode, context, context.nextOperatorId(this.nodeName)) diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/clickhouse/utils/MergeTreePartsPartitionsUtil.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/clickhouse/utils/MergeTreePartsPartitionsUtil.scala index 749de3f9d493..2c07bfae6d72 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/clickhouse/utils/MergeTreePartsPartitionsUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/clickhouse/utils/MergeTreePartsPartitionsUtil.scala @@ -635,6 +635,7 @@ object MergeTreePartsPartitionsUtil extends Logging { typeNodes, nameList, columnTypeNodes, + tableSchema, transformer.map(_.doTransform(substraitContext)).orNull, extensionNode, substraitContext, diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHMergeTreeWriterInjects.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHMergeTreeWriterInjects.scala index cf5a0cadba2c..5d6f5fc9d00e 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHMergeTreeWriterInjects.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHMergeTreeWriterInjects.scala @@ -200,6 +200,7 @@ object CHMergeTreeWriterInjects { typeNodes, nameList, columnTypeNodes, + tableSchema, null, extensionNode, substraitContext, diff --git a/cpp-ch/local-engine/Parser/RelParsers/ReadRelParser.cpp b/cpp-ch/local-engine/Parser/RelParsers/ReadRelParser.cpp index 455219e43263..9f76a0edf839 100644 --- a/cpp-ch/local-engine/Parser/RelParsers/ReadRelParser.cpp +++ b/cpp-ch/local-engine/Parser/RelParsers/ReadRelParser.cpp @@ -191,7 +191,12 @@ QueryPlanStepPtr ReadRelParser::parseReadRelWithLocalFile(const substrait::ReadR debug::dumpMessage(local_files, "local_files"); } - auto source = std::make_shared(getContext(), header, local_files); + DB::Block tableHeader = header; + if (rel.has_table_schema()) { + tableHeader = TypeParser::buildBlockFromNamedStructWithoutDFS(rel.table_schema()); + } + + auto source = std::make_shared(getContext(), header, local_files, tableHeader); auto source_pipe = Pipe(source); auto source_step = std::make_unique(getContext(), std::move(source_pipe), "substrait local files"); source_step->setStepDescription("read local files"); diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/ExcelTextFormatFile.cpp b/cpp-ch/local-engine/Storages/SubstraitSource/ExcelTextFormatFile.cpp index c87facdd9b59..40addd5f18df 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/ExcelTextFormatFile.cpp +++ b/cpp-ch/local-engine/Storages/SubstraitSource/ExcelTextFormatFile.cpp @@ -74,8 +74,8 @@ ExcelTextFormatFile::createInputFormat(const DB::Block & header, const std::shar std::shared_ptr buffer = std::make_unique(*read_buffer); DB::Names column_names; - column_names.reserve(file_info.schema().names_size()); - for (const auto & item : file_info.schema().names()) + column_names.reserve(schema_.getNames().size()); + for (const auto & item : schema_.getNames()) column_names.push_back(item); auto txt_input_format diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/ExcelTextFormatFile.h b/cpp-ch/local-engine/Storages/SubstraitSource/ExcelTextFormatFile.h index b762fce57637..ef740f1da2c3 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/ExcelTextFormatFile.h +++ b/cpp-ch/local-engine/Storages/SubstraitSource/ExcelTextFormatFile.h @@ -40,8 +40,9 @@ class ExcelTextFormatFile : public FormatFile public: explicit ExcelTextFormatFile( - DB::ContextPtr context_, const substrait::ReadRel::LocalFiles::FileOrFiles & file_info_, ReadBufferBuilderPtr read_buffer_builder_) - : FormatFile(context_, file_info_, read_buffer_builder_) + DB::ContextPtr context_, const DB::Block& input_header, const substrait::ReadRel::LocalFiles::FileOrFiles & file_info_, ReadBufferBuilderPtr read_buffer_builder_) + : FormatFile(context_, file_info_, read_buffer_builder_), + schema_(input_header.getNamesAndTypesList()) { } @@ -55,6 +56,8 @@ class ExcelTextFormatFile : public FormatFile private: DB::FormatSettings createFormatSettings() const; + + const DB::NamesAndTypesList schema_; }; diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/FormatFile.cpp b/cpp-ch/local-engine/Storages/SubstraitSource/FormatFile.cpp index 6829d34d4adb..3537259b4534 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/FormatFile.cpp +++ b/cpp-ch/local-engine/Storages/SubstraitSource/FormatFile.cpp @@ -141,7 +141,7 @@ FormatFile::FormatFile(DB::ContextPtr context_, const SubstraitInputFile & file_ } FormatFilePtr FormatFileUtil::createFile( - DB::ContextPtr context, ReadBufferBuilderPtr read_buffer_builder, const substrait::ReadRel::LocalFiles::FileOrFiles & file) + DB::ContextPtr context, const DB::Block& input_header, ReadBufferBuilderPtr read_buffer_builder, const substrait::ReadRel::LocalFiles::FileOrFiles & file) { #if USE_PARQUET if (file.has_parquet() || (file.has_iceberg() && file.iceberg().has_parquet())) @@ -160,9 +160,9 @@ FormatFilePtr FormatFileUtil::createFile( if (file.has_text()) { if (ExcelTextFormatFile::useThis(context)) - return std::make_shared(context, file, read_buffer_builder); + return std::make_shared(context, input_header, file, read_buffer_builder); else - return std::make_shared(context, file, read_buffer_builder); + return std::make_shared(context, input_header, file, read_buffer_builder); } #endif diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/FormatFile.h b/cpp-ch/local-engine/Storages/SubstraitSource/FormatFile.h index 60d28f5bbae4..3be8a2919e6e 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/FormatFile.h +++ b/cpp-ch/local-engine/Storages/SubstraitSource/FormatFile.h @@ -165,6 +165,6 @@ class FormatFileUtil { public: static FormatFilePtr - createFile(DB::ContextPtr context, ReadBufferBuilderPtr read_buffer_builder, const substrait::ReadRel::LocalFiles::FileOrFiles & file); + createFile(DB::ContextPtr context, const DB::Block& input_header, ReadBufferBuilderPtr read_buffer_builder, const substrait::ReadRel::LocalFiles::FileOrFiles & file); }; } diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/SubstraitFileSource.cpp b/cpp-ch/local-engine/Storages/SubstraitSource/SubstraitFileSource.cpp index e32de4b28cd4..1ee38a574ce6 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/SubstraitFileSource.cpp +++ b/cpp-ch/local-engine/Storages/SubstraitSource/SubstraitFileSource.cpp @@ -25,7 +25,7 @@ namespace local_engine { -static std::vector initializeFiles(const substrait::ReadRel::LocalFiles & file_infos, const DB::ContextPtr & context) +static std::vector initializeFiles(const substrait::ReadRel::LocalFiles & file_infos, const DB::ContextPtr & context, const DB::Block & input_header) { if (file_infos.items().empty()) return {}; @@ -33,7 +33,7 @@ static std::vector initializeFiles(const substrait::ReadRel::Loca const Poco::URI file_uri(file_infos.items().Get(0).uri_file()); ReadBufferBuilderPtr read_buffer_builder = ReadBufferBuilderFactory::instance().createBuilder(file_uri.getScheme(), context); for (const auto & item : file_infos.items()) - files.emplace_back(FormatFileUtil::createFile(context, read_buffer_builder, item)); + files.emplace_back(FormatFileUtil::createFile(context, input_header, read_buffer_builder, item)); return files; } @@ -53,9 +53,9 @@ static DB::Block initReadHeader(const DB::Block & block, const FormatFiles & fil } SubstraitFileSource::SubstraitFileSource( - const DB::ContextPtr & context_, const DB::Block & outputHeader_, const substrait::ReadRel::LocalFiles & file_infos) + const DB::ContextPtr & context_, const DB::Block & outputHeader_, const substrait::ReadRel::LocalFiles & file_infos, const DB::Block & input_header_) : DB::ISource(toShared(BaseReader::buildRowCountHeader(outputHeader_)), false) - , files(initializeFiles(file_infos, context_)) + , files(initializeFiles(file_infos, context_, input_header_)) , outputHeader(outputHeader_) , readHeader(initReadHeader(outputHeader, files)) { diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/SubstraitFileSource.h b/cpp-ch/local-engine/Storages/SubstraitSource/SubstraitFileSource.h index b40dd9a82d11..24e1400e7b70 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/SubstraitFileSource.h +++ b/cpp-ch/local-engine/Storages/SubstraitSource/SubstraitFileSource.h @@ -38,7 +38,7 @@ using FormatFiles = std::vector; class SubstraitFileSource : public DB::ISource { public: - SubstraitFileSource(const DB::ContextPtr & context_, const DB::Block & header_, const substrait::ReadRel::LocalFiles & file_infos); + SubstraitFileSource(const DB::ContextPtr & context_, const DB::Block & header_, const substrait::ReadRel::LocalFiles & file_infos, const DB::Block & input_header_); ~SubstraitFileSource() override; String getName() const override { return "SubstraitFileSource"; } diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/TextFormatFile.cpp b/cpp-ch/local-engine/Storages/SubstraitSource/TextFormatFile.cpp index 2b9f3c225de5..3391dd1b1ecc 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/TextFormatFile.cpp +++ b/cpp-ch/local-engine/Storages/SubstraitSource/TextFormatFile.cpp @@ -28,8 +28,9 @@ namespace local_engine { TextFormatFile::TextFormatFile( - DB::ContextPtr context_, const substrait::ReadRel::LocalFiles::FileOrFiles & file_info_, ReadBufferBuilderPtr read_buffer_builder_) - : FormatFile(context_, file_info_, read_buffer_builder_) + DB::ContextPtr context_, const DB::Block& input_header, const substrait::ReadRel::LocalFiles::FileOrFiles & file_info_, ReadBufferBuilderPtr read_buffer_builder_) + : FormatFile(context_, file_info_, read_buffer_builder_), + schema_(input_header.getNamesAndTypesList()) { } diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/TextFormatFile.h b/cpp-ch/local-engine/Storages/SubstraitSource/TextFormatFile.h index 0e6827e6530b..3c88d0a6405d 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/TextFormatFile.h +++ b/cpp-ch/local-engine/Storages/SubstraitSource/TextFormatFile.h @@ -29,21 +29,19 @@ class TextFormatFile : public FormatFile { public: explicit TextFormatFile( - DB::ContextPtr context_, const substrait::ReadRel::LocalFiles::FileOrFiles & file_info_, ReadBufferBuilderPtr read_buffer_builder_); + DB::ContextPtr context_, const DB::Block& input_header, const substrait::ReadRel::LocalFiles::FileOrFiles & file_info_, ReadBufferBuilderPtr read_buffer_builder_); ~TextFormatFile() override = default; FormatFile::InputFormatPtr createInputFormat(const DB::Block & header, const std::shared_ptr & filter_actions_dag = nullptr) override; - DB::NamesAndTypesList getSchema() const - { - const auto & schema = file_info.schema(); - auto header = TypeParser::buildBlockFromNamedStructWithoutDFS(schema); - return header.getNamesAndTypesList(); - } + DB::NamesAndTypesList getSchema() const { return schema_; } bool supportSplit() const override { return true; } String getFileFormat() const override { return "HiveText"; } + +private: + const DB::NamesAndTypesList schema_; }; } diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc b/cpp/velox/substrait/SubstraitToVeloxPlan.cc index 2df53f56e406..ea26ff0b8c6c 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc @@ -1282,18 +1282,37 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: SubstraitParser::parseColumnTypes(baseSchema, columnTypes); } - // Velox requires Filter Pushdown must being enabled. - bool filterPushdownEnabled = true; auto names = colNameList; auto types = veloxTypeList; - auto dataColumns = ROW(std::move(names), std::move(types)); + // The columns we project from the file. + auto baseSchema = ROW(std::move(names), std::move(types)); + // The columns present in the table, if not available default to the baseSchema. + auto tableSchema = baseSchema; + if (readRel.has_table_schema()) { + const auto& tableSchemaStruct = readRel.table_schema(); + std::vector tableColNames; + std::vector tableColTypes; + tableColNames.reserve(tableSchemaStruct.names().size()); + for (const auto& name : tableSchemaStruct.names()) { + std::string fieldName = name; + if (asLowerCase) { + folly::toLowerAscii(fieldName); + } + tableColNames.emplace_back(fieldName); + } + tableColTypes = SubstraitParser::parseNamedStruct(tableSchemaStruct, asLowerCase); + tableSchema = ROW(std::move(tableColNames), std::move(tableColTypes)); + } + + // Velox requires Filter Pushdown must being enabled. + bool filterPushdownEnabled = true; std::shared_ptr tableHandle; if (!readRel.has_filter()) { tableHandle = std::make_shared( - kHiveConnectorId, "hive_table", filterPushdownEnabled, common::SubfieldFilters{}, nullptr, dataColumns); + kHiveConnectorId, "hive_table", filterPushdownEnabled, common::SubfieldFilters{}, nullptr, tableSchema); } else { common::SubfieldFilters subfieldFilters; - auto remainingFilter = exprConverter_->toVeloxExpr(readRel.filter(), dataColumns); + auto remainingFilter = exprConverter_->toVeloxExpr(readRel.filter(), baseSchema); tableHandle = std::make_shared( kHiveConnectorId, @@ -1301,7 +1320,7 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: filterPushdownEnabled, std::move(subfieldFilters), remainingFilter, - dataColumns); + tableSchema); } // Get assignments and out names. diff --git a/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala b/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala index 4cdc51fe2aa0..9cdefdc351dd 100644 --- a/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala +++ b/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala @@ -130,7 +130,7 @@ case class IcebergScanTransformer( override lazy val getPartitionSchema: StructType = GlutenIcebergSourceUtil.getReadPartitionSchema(scan) - override def getDataSchema: StructType = new StructType() + override def getDataSchema: StructType = GlutenIcebergSourceUtil.getDataSchema(scan) // TODO: get root paths from table. override def getRootPathsInternal: Seq[String] = Seq.empty diff --git a/gluten-iceberg/src/main/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtil.scala b/gluten-iceberg/src/main/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtil.scala index 3816499a97a1..4f2a521390a8 100644 --- a/gluten-iceberg/src/main/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtil.scala +++ b/gluten-iceberg/src/main/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtil.scala @@ -195,6 +195,17 @@ object GlutenIcebergSourceUtil { throw new UnsupportedOperationException("Only support iceberg SparkBatchQueryScan.") } + def getDataSchema(sparkScan: Scan): StructType = sparkScan match { + case scan: SparkBatchQueryScan => + val tasks = scan.tasks().asScala + asFileScanTask(tasks.toList).foreach( + task => return SparkSchemaUtil.convert(task.spec().schema())) + throw new UnsupportedOperationException( + "Failed to get data schema from iceberg SparkBatchQueryScan.") + case _ => + throw new UnsupportedOperationException("Only support iceberg SparkBatchQueryScan.") + } + private def asFileScanTask(tasks: List[ScanTask]): List[FileScanTask] = { if (tasks.forall(_.isFileScanTask)) { tasks.map(_.asFileScanTask()) diff --git a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/LocalFilesNode.java b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/LocalFilesNode.java index f3faee09e742..80e365ec0e46 100644 --- a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/LocalFilesNode.java +++ b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/LocalFilesNode.java @@ -17,13 +17,10 @@ package org.apache.gluten.substrait.rel; import org.apache.gluten.config.GlutenConfig; -import org.apache.gluten.expression.ConverterUtils; import org.apache.gluten.substrait.utils.SubstraitUtil; import io.substrait.proto.NamedStruct; import io.substrait.proto.ReadRel; -import io.substrait.proto.Type; -import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import java.util.ArrayList; @@ -106,21 +103,6 @@ public void setFileSchema(StructType schema) { this.fileSchema = schema; } - private NamedStruct buildNamedStruct() { - NamedStruct.Builder namedStructBuilder = NamedStruct.newBuilder(); - - if (fileSchema != null) { - Type.Struct.Builder structBuilder = Type.Struct.newBuilder(); - for (StructField field : fileSchema.fields()) { - structBuilder.addTypes( - ConverterUtils.getTypeNode(field.dataType(), field.nullable()).toProtobuf()); - namedStructBuilder.addNames(ConverterUtils.normalizeColName(field.name())); - } - namedStructBuilder.setStruct(structBuilder.build()); - } - return namedStructBuilder.build(); - } - @Override public List preferredLocations() { return this.preferredLocations; @@ -195,7 +177,7 @@ public ReadRel.LocalFiles toProtobuf() { ReadRel.LocalFiles.FileOrFiles.metadataColumn.newBuilder(); fileBuilder.addMetadataColumns(mcBuilder.build()); } - NamedStruct namedStruct = buildNamedStruct(); + NamedStruct namedStruct = org.apache.gluten.utils.SubstraitUtil.createNamedStruct(fileSchema); fileBuilder.setSchema(namedStruct); if (!otherMetadataColumns.isEmpty()) { diff --git a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/ReadRelNode.java b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/ReadRelNode.java index b82e05fd367e..b2d7df60ec0f 100644 --- a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/ReadRelNode.java +++ b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/ReadRelNode.java @@ -26,6 +26,7 @@ import io.substrait.proto.ReadRel; import io.substrait.proto.Rel; import io.substrait.proto.RelCommon; +import org.apache.spark.sql.types.StructType; import java.io.Serializable; import java.util.ArrayList; @@ -35,6 +36,7 @@ public class ReadRelNode implements RelNode, Serializable { private final List types = new ArrayList<>(); private final List names = new ArrayList<>(); private final List columnTypeNodes = new ArrayList<>(); + private final StructType tableSchema; private final ExpressionNode filterNode; private final AdvancedExtensionNode extensionNode; private boolean streamKafka = false; @@ -42,11 +44,13 @@ public class ReadRelNode implements RelNode, Serializable { ReadRelNode( List types, List names, + StructType tableSchema, ExpressionNode filterNode, List columnTypeNodes, AdvancedExtensionNode extensionNode) { this.types.addAll(types); this.names.addAll(names); + this.tableSchema = tableSchema; this.filterNode = filterNode; this.columnTypeNodes.addAll(columnTypeNodes); this.extensionNode = extensionNode; @@ -69,6 +73,10 @@ public Rel toProtobuf() { readBuilder.setBaseSchema(nStructBuilder.build()); readBuilder.setStreamKafka(streamKafka); + if (tableSchema != null) { + readBuilder.setTableSchema(SubstraitUtil.createNamedStruct(tableSchema)); + } + if (filterNode != null) { readBuilder.setFilter(filterNode.toProtobuf()); } diff --git a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java index c8a028d0be4d..3d08aa662668 100644 --- a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java +++ b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java @@ -30,6 +30,7 @@ import io.substrait.proto.*; import org.apache.spark.sql.catalyst.expressions.Attribute; +import org.apache.spark.sql.types.StructType; import java.util.List; import java.util.stream.Collectors; @@ -149,12 +150,13 @@ public static RelNode makeReadRel( List types, List names, List columnTypeNodes, + StructType tableSchema, ExpressionNode filter, AdvancedExtensionNode extensionNode, SubstraitContext context, Long operatorId) { context.registerRelToOperator(operatorId); - return new ReadRelNode(types, names, filter, columnTypeNodes, extensionNode); + return new ReadRelNode(types, names, tableSchema, filter, columnTypeNodes, extensionNode); } public static RelNode makeReadRelForInputIterator( diff --git a/gluten-substrait/src/main/resources/substrait/proto/substrait/algebra.proto b/gluten-substrait/src/main/resources/substrait/proto/substrait/algebra.proto index 176a4e4c2585..4c07b0cbd873 100644 --- a/gluten-substrait/src/main/resources/substrait/proto/substrait/algebra.proto +++ b/gluten-substrait/src/main/resources/substrait/proto/substrait/algebra.proto @@ -62,6 +62,7 @@ message ReadRel { Expression best_effort_filter = 11; Expression.MaskExpression projection = 4; substrait.extensions.AdvancedExtension advanced_extension = 10; + NamedStruct table_schema = 12; // Definition of which type of scan operation is to be performed oneof read_type { @@ -218,7 +219,7 @@ message ReadRel { repeated partitionColumn partition_columns = 17; /// File schema - NamedStruct schema = 18; + NamedStruct schema = 18 [deprecated=true]; message metadataColumn { string key = 1; diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala index 79671224243c..4afdd13ffbd2 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala @@ -152,6 +152,7 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource typeNodes, nameList, columnTypeNodes, + getDataSchema, exprNode, extensionNode, context, diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala b/gluten-substrait/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala index 9e1117085214..2323f45ef552 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala @@ -24,6 +24,7 @@ import org.apache.gluten.substrait.expression.ExpressionNode import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.types.StructType import io.substrait.proto.{CrossRel, JoinRel, NamedStruct, Type} @@ -107,4 +108,18 @@ object SubstraitUtil { val nameList = ConverterUtils.collectAttributeNamesWithExprId(output) createNameStructBuilder(typeList, nameList, Collections.emptyList()).build() } + + def createNamedStruct(struct: StructType): NamedStruct = { + val namedStructBuilder = NamedStruct.newBuilder + if (struct != null) { + val structBuilder = Type.Struct.newBuilder + for (field <- struct.fields) { + structBuilder.addTypes( + ConverterUtils.getTypeNode(field.dataType, field.nullable).toProtobuf) + namedStructBuilder.addNames(ConverterUtils.normalizeColName(field.name)) + } + namedStructBuilder.setStruct(structBuilder.build) + } + namedStructBuilder.build + } }