diff --git a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangProjectVisitor.java b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangProjectVisitor.java index 775acc49b..eb4342a9b 100755 --- a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangProjectVisitor.java +++ b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangProjectVisitor.java @@ -19,14 +19,15 @@ package org.apache.wayang.api.sql.calcite.converter; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rex.RexNode; import org.apache.wayang.api.sql.calcite.converter.functions.ProjectMapFuncImpl; import org.apache.wayang.api.sql.calcite.rel.WayangProject; import org.apache.wayang.basic.data.Record; +import org.apache.wayang.basic.function.ProjectionDescriptor; import org.apache.wayang.basic.operators.MapOperator; import org.apache.wayang.core.plan.wayangplan.Operator; +import org.apache.wayang.core.types.BasicDataUnitType; import java.util.List; @@ -39,14 +40,15 @@ public class WayangProjectVisitor extends WayangRelNodeVisitor { Operator visit(final WayangProject wayangRelNode) { final Operator childOp = wayangRelConverter.convert(wayangRelNode.getInput(0)); - /* Quick check */ - final List projects = ((Project) wayangRelNode).getProjects(); + final List projects = wayangRelNode.getProjects(); - // TODO: create a map with specific dataset type - final MapOperator projection = new MapOperator<>( + final ProjectionDescriptor projectionDescriptor = new ProjectionDescriptor<>( new ProjectMapFuncImpl(projects), - Record.class, - Record.class); + wayangRelNode.getRowType().getFieldNames(), + BasicDataUnitType.createBasic(Record.class), + BasicDataUnitType.createBasic(Record.class)); + + final MapOperator projection = new MapOperator<>(projectionDescriptor); childOp.connectTo(0, projection, 0); diff --git a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangTableScanVisitor.java b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangTableScanVisitor.java index b82d4a92b..569384697 100755 --- a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangTableScanVisitor.java +++ b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangTableScanVisitor.java @@ -26,8 +26,11 @@ import org.apache.wayang.api.sql.sources.fs.JavaCSVTableSource; import org.apache.wayang.core.plan.wayangplan.Operator; import org.apache.wayang.core.types.DataSetType; +import org.apache.wayang.jdbc.operators.JdbcTableSource; +import org.apache.wayang.jdbc.platform.JdbcPlatformTemplate; import org.apache.wayang.postgres.operators.PostgresTableSource; import org.apache.wayang.basic.data.Record; +import org.apache.wayang.basic.operators.TableSource; import java.util.List; import java.util.stream.Collectors; @@ -49,9 +52,7 @@ Operator visit(final WayangTableScan wayangRelNode) { if (tableSource.equals("postgres")) { return new PostgresTableSource(tableName, columnNames.toArray(String[]::new)); - } - - if (tableSource.equals("fs")) { + } else if (tableSource.equals("fs")) { final ModelParser modelParser; try { modelParser = this.wayangRelConverter.getConfiguration() == null @@ -72,7 +73,18 @@ Operator visit(final WayangTableScan wayangRelNode) { final char separator = modelParser.getSchemaDelimiter(tableSource); return new JavaCSVTableSource<>(url, DataSetType.createDefault(Record.class), fieldTypes, separator); + } else if (wayangRelNode.getTable().getQualifiedName().size() == 1) { + // we assume that it is coming from a test environement or in memory db. + + return new JdbcTableSource(wayangRelNode.getTable().getQualifiedName().get(0), wayangRelNode.getRowType().getFieldNames().toArray(String[]::new)) { + + @Override + public JdbcPlatformTemplate getPlatform() { + throw new UnsupportedOperationException("Unimplemented method 'getPlatform'"); + } + }; } else - throw new RuntimeException("Source not supported"); + throw new RuntimeException( + "Source not supported, got: " + tableSource + ", expected either postgres or filesystem (fs)."); } } diff --git a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/CallTreeFactory.java b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/CallTreeFactory.java index 086d3356d..6c93c4221 100644 --- a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/CallTreeFactory.java +++ b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/CallTreeFactory.java @@ -58,14 +58,14 @@ public default Node fromRexNode(final RexNode node) { * @return a serializable function of +, -, * or / * @throws UnsupportedOperationException on unrecognized {@link SqlKind} */ - public SerializableFunction, Object> deriveOperation(SqlKind kind); + public SerializableFunction, Object> deriveOperation(final SqlKind kind); } interface Node extends Serializable { public Object evaluate(final Record rec); } -class Call implements Node { +final class Call implements Node { private final List operands; final SerializableFunction, Object> operation; @@ -83,7 +83,7 @@ public Object evaluate(final Record rec) { } } -class Literal implements Node { +final class Literal implements Node { final Serializable value; Literal(final RexLiteral literal) { @@ -109,7 +109,7 @@ public Object evaluate(final Record rec) { } } -class InputRef implements Node { +final class InputRef implements Node { private final int key; InputRef(final RexInputRef inputRef) { diff --git a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/optimizer/Optimizer.java b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/optimizer/Optimizer.java index 382a7cac0..8a48d8697 100755 --- a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/optimizer/Optimizer.java +++ b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/optimizer/Optimizer.java @@ -211,11 +211,11 @@ public RelNode optimize(RelNode node, RelTraitSet requiredTraitSet, RuleSet rule ); } - public WayangPlan convert(RelNode relNode) { + public static WayangPlan convert(RelNode relNode) { return convert(relNode, new ArrayList<>()); } - public WayangPlan convert(RelNode relNode, Collection collector) { + public static WayangPlan convert(RelNode relNode, Collection collector) { LocalCallbackSink sink = LocalCallbackSink.createCollectingSink(collector, Record.class); @@ -225,8 +225,7 @@ public WayangPlan convert(RelNode relNode, Collection collector) { return new WayangPlan(sink); } - public WayangPlan convertWithConfig(RelNode relNode, Configuration configuration, Collection collector) { - + public static WayangPlan convertWithConfig(RelNode relNode, Configuration configuration, Collection collector) { LocalCallbackSink sink = LocalCallbackSink.createCollectingSink(collector, Record.class); Operator op = new WayangRelConverter(configuration).convert(relNode); diff --git a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/rel/WayangTableScan.java b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/rel/WayangTableScan.java index 432b7903e..e002dbcb1 100755 --- a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/rel/WayangTableScan.java +++ b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/rel/WayangTableScan.java @@ -27,12 +27,12 @@ import org.apache.calcite.rel.hint.RelHint; import org.apache.calcite.schema.Table; import org.apache.wayang.api.sql.calcite.convention.WayangConvention; -import org.apache.wayang.api.sql.calcite.utils.ModelParser; import java.util.List; public class WayangTableScan extends TableScan implements WayangRel { + //TODO: fields are never queried, why? private final int[] fields; public WayangTableScan(RelOptCluster cluster, @@ -83,11 +83,15 @@ public String toString() { } public String getQualifiedName() { - return table.getQualifiedName().get(1); + return table.getQualifiedName().size() == 1 + ? table.getQualifiedName().get(0) + : table.getQualifiedName().get(1); } public String getTableName() { - return table.getQualifiedName().get(1); + return table.getQualifiedName().size() == 1 + ? table.getQualifiedName().get(0) + : table.getQualifiedName().get(1); } public List getColumnNames() { diff --git a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/context/SqlContext.java b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/context/SqlContext.java index af26273cf..3b8209979 100755 --- a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/context/SqlContext.java +++ b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/context/SqlContext.java @@ -167,7 +167,7 @@ public static void main(final String[] args) throws Exception { PrintUtils.print("After translating logical intermediate plan", wayangRel); final Collection collector = new ArrayList<>(); - final WayangPlan wayangPlan = optimizer.convertWithConfig(wayangRel, configuration, collector); + final WayangPlan wayangPlan = Optimizer.convertWithConfig(wayangRel, configuration, collector); collector.add(new Record(wayangRel.getRowType().getFieldNames().toArray())); context.execute(getJobName(), wayangPlan); @@ -182,7 +182,6 @@ public static void main(final String[] args) throws Exception { } public Collection executeSql(final String sql) throws SqlParseException { - final Properties configProperties = Optimizer.ConfigProperties.getDefaults(); final RelDataTypeFactory relDataTypeFactory = new JavaTypeFactoryImpl(); @@ -216,7 +215,7 @@ public Collection executeSql(final String sql) throws SqlParseException PrintUtils.print("After translating logical intermediate plan", wayangRel); final Collection collector = new ArrayList<>(); - final WayangPlan wayangPlan = optimizer.convert(wayangRel, collector); + final WayangPlan wayangPlan = Optimizer.convert(wayangRel, collector); this.execute(getJobName(), wayangPlan); diff --git a/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlToWayangRelTest.java b/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlToWayangRelTest.java index ad918036c..9d10106ab 100755 --- a/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlToWayangRelTest.java +++ b/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlToWayangRelTest.java @@ -17,19 +17,55 @@ package org.apache.wayang.api.sql; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.stream.Collectors; + +import org.apache.calcite.jdbc.CalciteSchema; import org.apache.calcite.jdbc.JavaTypeFactoryImpl; +import org.apache.calcite.plan.Contexts; +import org.apache.calcite.plan.ConventionTraitDef; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptCostImpl; +import org.apache.calcite.plan.RelOptSchema; +import org.apache.calcite.plan.RelOptTable; +import org.apache.calcite.plan.volcano.VolcanoPlanner; +import org.apache.calcite.prepare.CalciteCatalogReader; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.TableScan; +import org.apache.calcite.rel.logical.LogicalTableScan; +import org.apache.calcite.rel.rel2sql.RelToSqlConverter; import org.apache.calcite.rel.rules.CoreRules; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeFactory.Builder; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.schema.impl.AbstractTable; +import org.apache.calcite.sql.SqlDialect; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.tools.Frameworks; import org.apache.calcite.tools.RuleSet; import org.apache.calcite.tools.RuleSets; import org.apache.wayang.api.sql.calcite.convention.WayangConvention; @@ -43,36 +79,25 @@ import org.apache.wayang.basic.data.Record; import org.apache.wayang.basic.data.Tuple2; import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.Job; +import org.apache.wayang.core.api.WayangContext; import org.apache.wayang.core.function.FunctionDescriptor.SerializablePredicate; +import org.apache.wayang.core.mapping.PlanTransformation; +import org.apache.wayang.core.plan.wayangplan.Operator; import org.apache.wayang.core.plan.wayangplan.PlanTraversal; import org.apache.wayang.core.plan.wayangplan.WayangPlan; import org.apache.wayang.java.Java; +import org.apache.wayang.jdbc.execution.JdbcExecutor; +import org.apache.wayang.jdbc.operators.JdbcProjectionOperator; +import org.apache.wayang.jdbc.operators.JdbcTableSource; +import org.apache.wayang.postgres.mapping.ProjectionMapping; import org.apache.wayang.spark.Spark; - import org.json.simple.parser.ParseException; import org.junit.jupiter.api.Test; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.List; -import java.util.Map; -import java.util.Properties; -import java.util.stream.Collectors; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; - class SqlToWayangRelTest { /** @@ -117,12 +142,92 @@ private Tuple2, WayangPlan> buildCollectorAndWayangPlan(final final Collection collector = new ArrayList<>(); - final WayangPlan wayangPlan = optimizer.convertWithConfig(wayangRel, context.getConfiguration(), + final WayangPlan wayangPlan = Optimizer.convertWithConfig(wayangRel, context.getConfiguration(), collector); return new Tuple2<>(collector, wayangPlan); } + @Test + void sqlApiSourceTest() throws Exception { + final JavaTypeFactoryImpl typeFactory = new JavaTypeFactoryImpl(); + final RexBuilder rexBuilder = new RexBuilder(typeFactory); + + final VolcanoPlanner planner = new VolcanoPlanner(RelOptCostImpl.FACTORY, Contexts.empty()); + planner.addRelTraitDef(ConventionTraitDef.INSTANCE); + final RelOptCluster cluster = RelOptCluster.create(planner, rexBuilder); + + final SchemaPlus rootSchema = Frameworks.createRootSchema(true); + final RelDataType rowType = new Builder(typeFactory) + .add("ID", typeFactory.createJavaType(Integer.class)) + .add("NAME", typeFactory.createJavaType(String.class)) + .build(); + + rootSchema.add("T1", new AbstractTable() { + @Override + public RelDataType getRowType(final RelDataTypeFactory typeFactory) { + return rowType; + } + }); + + final RelOptSchema relOptSchema = new CalciteCatalogReader( + CalciteSchema.from(rootSchema), + CalciteSchema.from(rootSchema).path(null), + typeFactory, + mock()); + + final RelOptTable t1 = relOptSchema.getTableForMember(Arrays.asList("T1")); + + final TableScan scan1 = LogicalTableScan.create(cluster, t1, List.of()); + + final SqlDialect dialect = SqlDialect.DatabaseProduct.CALCITE.getDialect(); + final RelToSqlConverter converter = new RelToSqlConverter(dialect); + final SqlNode sqlNode = converter.visitRoot(scan1).asStatement(); + + final Properties configProperties = Optimizer.ConfigProperties.getDefaults(); + final RelDataTypeFactory relDataTypeFactory = new JavaTypeFactoryImpl(); + + final Optimizer optimizer = Optimizer.create( + CalciteSchema.from(rootSchema), + configProperties, + relDataTypeFactory); + + final SqlNode validatedSqlNode = optimizer.validate(sqlNode); + final RelNode relNode = optimizer.convert(validatedSqlNode); + + final RuleSet rules = RuleSets.ofList( + CoreRules.FILTER_INTO_JOIN, + WayangRules.WAYANG_TABLESCAN_RULE, + WayangRules.WAYANG_TABLESCAN_ENUMERABLE_RULE, + WayangRules.WAYANG_PROJECT_RULE, + WayangRules.WAYANG_FILTER_RULE, + WayangRules.WAYANG_JOIN_RULE, + WayangRules.WAYANG_AGGREGATE_RULE, + WayangRules.WAYANG_SORT_RULE); + + final RelNode wayangRel = optimizer.optimize( + relNode, + relNode.getTraitSet().plus(WayangConvention.INSTANCE), + rules); + + final WayangPlan plan = Optimizer.convert(wayangRel, new ArrayList()); + + final ProjectionMapping projectionMapping = new ProjectionMapping(); + final PlanTransformation projectionTransformation = projectionMapping.getTransformations().iterator().next().thatReplaces(); + + plan.applyTransformations(List.of(projectionTransformation)); + + final Collection operators = PlanTraversal.upstream().traverse(plan.getSinks()).getTraversedNodes(); + + final JdbcTableSource table = operators.stream().filter(op -> op instanceof JdbcTableSource).map(JdbcTableSource.class::cast).findFirst().orElseThrow(); + final JdbcProjectionOperator projection = operators.stream().filter(op -> op instanceof JdbcProjectionOperator).map(JdbcProjectionOperator.class::cast).findFirst().orElseThrow(); + + final JdbcExecutor jdbcExecutor = mock(); + final StringBuilder query = JdbcExecutor.createSqlString(jdbcExecutor, table, List.of(), projection, List.of()); + + assertEquals("SELECT ID, NAME FROM T1;", query.toString()); + } + @Test void javaJoinTest() throws Exception { final SqlContext sqlContext = this.createSqlContext("/data/largeLeftTableIndex.csv"); @@ -637,7 +742,8 @@ void exampleCustomDelimiter() throws Exception { " \"type\": \"custom\",\r\n" + // " \"factory\": \"org.apache.calcite.adapter.file.FileSchemaFactory\",\r\n" + // " \"operand\": {\r\n" + // - " \"directory\": \"" + "/" + this.getClass().getResource("/data").getPath() + "\",\r\n" + // + " \"directory\": \"" + "/" + this.getClass().getResource("/data").getPath() + + "\",\r\n" + // " \"delimiter\": \"|\"" + " }\r\n" + // " }\r\n" + // @@ -695,7 +801,8 @@ private SqlContext createSqlContext(final String tableResourceName) " \"type\": \"custom\",\r\n" + " \"factory\": \"org.apache.calcite.adapter.file.FileSchemaFactory\",\r\n" + " \"operand\": {\r\n" + - " \"directory\": \"" + "/" + this.getClass().getResource("/data").getPath() + "\"\r\n" + + " \"directory\": \"" + "/" + this.getClass().getResource("/data").getPath() + + "\"\r\n" + " }\r\n" + " }\r\n" + " ]\r\n" + diff --git a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/function/ProjectionDescriptor.java b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/function/ProjectionDescriptor.java index 5db00d778..15aeabc53 100644 --- a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/function/ProjectionDescriptor.java +++ b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/function/ProjectionDescriptor.java @@ -18,140 +18,42 @@ package org.apache.wayang.basic.function; +import java.lang.reflect.Field; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + import org.apache.wayang.basic.data.Record; import org.apache.wayang.basic.types.RecordType; import org.apache.wayang.core.function.FunctionDescriptor; import org.apache.wayang.core.function.TransformationDescriptor; import org.apache.wayang.core.types.BasicDataUnitType; -import java.lang.reflect.Field; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - /** - * This descriptor pertains to projections. It takes field names of the input type to describe the projection. + * This descriptor pertains to projections. It takes field names of the input + * type to describe the projection. */ public class ProjectionDescriptor extends TransformationDescriptor { - private List fieldNames; - - /** - * Creates a new instance. - * - * @param inputTypeClass input type - * @param outputTypeClass output type - * @param fieldNames names of the fields to be projected - */ - public ProjectionDescriptor(Class inputTypeClass, - Class outputTypeClass, - String... fieldNames) { - this(BasicDataUnitType.createBasic(inputTypeClass), - BasicDataUnitType.createBasic(outputTypeClass), - fieldNames); - } - - /** - * Creates a new instance. - * - * @param inputType input type - * @param outputType output type - * @param fieldNames names of the fields to be projected - */ - public ProjectionDescriptor(BasicDataUnitType inputType, BasicDataUnitType outputType, String... fieldNames) { - this(createPojoJavaImplementation(fieldNames, inputType), - Collections.unmodifiableList(Arrays.asList(fieldNames)), - inputType, - outputType); - } - - /** - * Basic constructor. - * - * @param javaImplementation Java-based implementation of the projection - * @param fieldNames names of the fields to be projected - * @param inputType input {@link BasicDataUnitType} - * @param outputType output {@link BasicDataUnitType} - */ - private ProjectionDescriptor(SerializableFunction javaImplementation, - List fieldNames, - BasicDataUnitType inputType, - BasicDataUnitType outputType) { - super(javaImplementation, inputType, outputType); - this.fieldNames = fieldNames; - } - - /** - * Creates a new instance that specifically projects {@link Record}s. - * - * @param inputType input {@link RecordType} - * @param fieldNames names of fields to be projected - * @return the new instance - */ - public static ProjectionDescriptor createForRecords(RecordType inputType, String... fieldNames) { - final SerializableFunction javaImplementation = createRecordJavaImplementation(fieldNames, inputType); - return new ProjectionDescriptor<>( - javaImplementation, - Arrays.asList(fieldNames), - inputType, - new RecordType(fieldNames) - ); - } - - private static FunctionDescriptor.SerializableFunction - createPojoJavaImplementation(String[] fieldNames, BasicDataUnitType inputType) { - // Get the names of the fields to be projected. - if (fieldNames.length != 1) { - return t -> { - throw new IllegalStateException("The projection descriptor currently supports only a single field."); - }; - } - String fieldName = fieldNames[0]; - return new PojoImplementation<>(fieldName); - } - - private static FunctionDescriptor.SerializableFunction - createRecordJavaImplementation(String[] fieldNames, RecordType inputType) { - return new RecordImplementation(inputType, fieldNames); - } - - /** - * Transforms an array of {@link RecordType} field names to indices. - * - * @param recordType that maps field names to indices - * @param fieldNames the field names - * @return the field indices - */ - private static int[] toIndices(RecordType recordType, String[] fieldNames) { - int[] fieldIndices = new int[fieldNames.length]; - for (int i = 0; i < fieldNames.length; i++) { - String fieldName = fieldNames[i]; - fieldIndices[i] = recordType.getIndex(fieldName); - } - return fieldIndices; - } - - public List getFieldNames() { - return this.fieldNames; - } - /** * Java implementation of a projection on POJOs via reflection. */ - // TODO: Revise implementation to support multiple field projection, by names and indexes. - private static class PojoImplementation implements FunctionDescriptor.SerializableFunction { + // TODO: Revise implementation to support multiple field projection, by names + // and indexes. + private static class PojoImplementation + implements FunctionDescriptor.SerializableFunction { private final String fieldName; private Field field; - private PojoImplementation(String fieldName) { + private PojoImplementation(final String fieldName) { this.fieldName = fieldName; } @Override @SuppressWarnings("unchecked") - public Output apply(Input input) { + public Output apply(final Input input) { // Initialization code. if (this.field == null) { @@ -161,7 +63,7 @@ public Output apply(Input input) { // Find the projection field via reflection. try { this.field = typeClass.getField(this.fieldName); - } catch (Exception e) { + } catch (final Exception e) { throw new IllegalStateException("The configuration of the projection seems to be illegal.", e); } } @@ -169,7 +71,7 @@ public Output apply(Input input) { // Actual function. try { return (Output) this.field.get(input); - } catch (IllegalAccessException e) { + } catch (final IllegalAccessException e) { throw new RuntimeException("Illegal projection function.", e); } } @@ -191,19 +93,121 @@ private static class RecordImplementation implements FunctionDescriptor.Serializ * @param recordType {@link RecordType} of input {@link Record}s * @param fieldNames that should be projected on */ - private RecordImplementation(RecordType recordType, String... fieldNames) { + private RecordImplementation(final RecordType recordType, final String... fieldNames) { this.fieldIndices = toIndices(recordType, fieldNames); } @Override @SuppressWarnings("unchecked") - public Record apply(Record input) { - Object[] projectedFields = new Object[this.fieldIndices.length]; + public Record apply(final Record input) { + final Object[] projectedFields = new Object[this.fieldIndices.length]; for (int i = 0; i < this.fieldIndices.length; i++) { - int fieldIndex = this.fieldIndices[i]; + final int fieldIndex = this.fieldIndices[i]; projectedFields[i] = input.getField(fieldIndex); } return new Record(projectedFields); } } + + /** + * Creates a new instance that specifically projects {@link Record}s. + * + * @param inputType input {@link RecordType} + * @param fieldNames names of fields to be projected + * @return the new instance + */ + public static ProjectionDescriptor createForRecords(final RecordType inputType, final String... fieldNames) { + final SerializableFunction javaImplementation = createRecordJavaImplementation(fieldNames, + inputType); + return new ProjectionDescriptor<>( + javaImplementation, + Arrays.asList(fieldNames), + inputType, + new RecordType(fieldNames)); + } + + private static FunctionDescriptor.SerializableFunction createPojoJavaImplementation( + final String[] fieldNames, final BasicDataUnitType inputType) { + // Get the names of the fields to be projected. + if (fieldNames.length != 1) { + return t -> { + throw new IllegalStateException("The projection descriptor currently supports only a single field."); + }; + } + final String fieldName = fieldNames[0]; + return new PojoImplementation<>(fieldName); + } + + private static FunctionDescriptor.SerializableFunction createRecordJavaImplementation( + final String[] fieldNames, final RecordType inputType) { + return new RecordImplementation(inputType, fieldNames); + } + + /** + * Transforms an array of {@link RecordType} field names to indices. + * + * @param recordType that maps field names to indices + * @param fieldNames the field names + * @return the field indices + */ + private static int[] toIndices(final RecordType recordType, final String[] fieldNames) { + final int[] fieldIndices = new int[fieldNames.length]; + for (int i = 0; i < fieldNames.length; i++) { + final String fieldName = fieldNames[i]; + fieldIndices[i] = recordType.getIndex(fieldName); + } + return fieldIndices; + } + + private final List fieldNames; + + /** + * Creates a new instance. + * + * @param inputTypeClass input type + * @param outputTypeClass output type + * @param fieldNames names of the fields to be projected + */ + public ProjectionDescriptor(final Class inputTypeClass, + final Class outputTypeClass, + final String... fieldNames) { + this(BasicDataUnitType.createBasic(inputTypeClass), + BasicDataUnitType.createBasic(outputTypeClass), + fieldNames); + } + + /** + * Creates a new instance. + * + * @param inputType input type + * @param outputType output type + * @param fieldNames names of the fields to be projected + */ + public ProjectionDescriptor(final BasicDataUnitType inputType, final BasicDataUnitType outputType, + final String... fieldNames) { + this(createPojoJavaImplementation(fieldNames, inputType), + Collections.unmodifiableList(Arrays.asList(fieldNames)), + inputType, + outputType); + } + + /** + * Basic constructor. + * + * @param javaImplementation Java-based implementation of the projection + * @param fieldNames names of the fields to be projected + * @param inputType input {@link BasicDataUnitType} + * @param outputType output {@link BasicDataUnitType} + */ + public ProjectionDescriptor(final SerializableFunction javaImplementation, + final List fieldNames, + final BasicDataUnitType inputType, + final BasicDataUnitType outputType) { + super(javaImplementation, inputType, outputType); + this.fieldNames = fieldNames; + } + + public List getFieldNames() { + return this.fieldNames; + } } diff --git a/wayang-platforms/wayang-jdbc-template/src/main/java/org/apache/wayang/jdbc/execution/JdbcExecutor.java b/wayang-platforms/wayang-jdbc-template/src/main/java/org/apache/wayang/jdbc/execution/JdbcExecutor.java index 4a7df7d3c..f7a9d7c5a 100644 --- a/wayang-platforms/wayang-jdbc-template/src/main/java/org/apache/wayang/jdbc/execution/JdbcExecutor.java +++ b/wayang-platforms/wayang-jdbc-template/src/main/java/org/apache/wayang/jdbc/execution/JdbcExecutor.java @@ -27,7 +27,6 @@ import org.apache.wayang.core.plan.executionplan.Channel; import org.apache.wayang.core.plan.executionplan.ExecutionStage; import org.apache.wayang.core.plan.executionplan.ExecutionTask; -import org.apache.wayang.core.plan.wayangplan.Operator; import org.apache.wayang.core.platform.ExecutionState; import org.apache.wayang.core.platform.Executor; import org.apache.wayang.core.platform.ExecutorTemplate; @@ -41,6 +40,7 @@ import org.apache.wayang.jdbc.operators.JdbcFilterOperator; import org.apache.wayang.jdbc.operators.JdbcJoinOperator; import org.apache.wayang.jdbc.operators.JdbcProjectionOperator; +import org.apache.wayang.jdbc.operators.JdbcTableSource; import org.apache.wayang.jdbc.platform.JdbcPlatformTemplate; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -78,7 +78,7 @@ public JdbcExecutor(final JdbcPlatformTemplate platform, final Job job) { @Override public void execute(final ExecutionStage stage, final OptimizationContext optimizationContext, final ExecutionState executionState) { - final Tuple2 pair = this.createSqlQuery(stage, optimizationContext); + final Tuple2 pair = JdbcExecutor.createSqlQuery(stage, optimizationContext, this); final String query = pair.field0; final SqlQueryChannel.Instance queryChannel = pair.field1; @@ -98,7 +98,7 @@ public void execute(final ExecutionStage stage, final OptimizationContext optimi * @param stage in which the follow-up {@link ExecutionTask} should be * @return the said follow-up {@link ExecutionTask} or {@code null} if none */ - private ExecutionTask findJdbcExecutionOperatorTaskInStage(final ExecutionTask task, final ExecutionStage stage) { + private static ExecutionTask findJdbcExecutionOperatorTaskInStage(final ExecutionTask task, final ExecutionStage stage) { assert task.getNumOuputChannels() == 1; final Channel outputChannel = task.getOutputChannel(0); final ExecutionTask consumer = WayangCollections.getSingle(outputChannel.getConsumers()); @@ -116,15 +116,15 @@ private ExecutionTask findJdbcExecutionOperatorTaskInStage(final ExecutionTask t * {@link ExecutionTask} * @return the {@link SqlQueryChannel.Instance} */ - private SqlQueryChannel.Instance instantiateOutboundChannel(final ExecutionTask task, - final OptimizationContext optimizationContext) { + private static SqlQueryChannel.Instance instantiateOutboundChannel(final ExecutionTask task, + final OptimizationContext optimizationContext, final JdbcExecutor jdbcExecutor) { assert task.getNumOuputChannels() == 1 : String.format("Illegal task: %s.", task); assert task.getOutputChannel(0) instanceof SqlQueryChannel : String.format("Illegal task: %s.", task); final SqlQueryChannel outputChannel = (SqlQueryChannel) task.getOutputChannel(0); final OptimizationContext.OperatorContext operatorContext = optimizationContext .getOperatorContext(task.getOperator()); - return outputChannel.createInstance(this, operatorContext, 0); + return outputChannel.createInstance(jdbcExecutor, operatorContext, 0); } /** @@ -139,10 +139,10 @@ private SqlQueryChannel.Instance instantiateOutboundChannel(final ExecutionTask * to keep track of lineage * @return the {@link SqlQueryChannel.Instance} */ - private SqlQueryChannel.Instance instantiateOutboundChannel(final ExecutionTask task, + private static SqlQueryChannel.Instance instantiateOutboundChannel(final ExecutionTask task, final OptimizationContext optimizationContext, - final SqlQueryChannel.Instance predecessorChannelInstance) { - final SqlQueryChannel.Instance newInstance = this.instantiateOutboundChannel(task, optimizationContext); + final SqlQueryChannel.Instance predecessorChannelInstance, final JdbcExecutor jdbcExecutor) { + final SqlQueryChannel.Instance newInstance = JdbcExecutor.instantiateOutboundChannel(task, optimizationContext, jdbcExecutor); newInstance.getLineage().addPredecessor(predecessorChannelInstance.getLineage()); return newInstance; } @@ -154,8 +154,8 @@ private SqlQueryChannel.Instance instantiateOutboundChannel(final ExecutionTask * @param context * @return a tuple containing the sql statement */ - protected Tuple2 createSqlQuery(final ExecutionStage stage, - final OptimizationContext context) { + protected static Tuple2 createSqlQuery(final ExecutionStage stage, + final OptimizationContext context, final JdbcExecutor jdbcExecutor) { final Collection startTasks = stage.getStartTasks(); final Collection termTasks = stage.getTerminalTasks(); @@ -168,44 +168,49 @@ protected Tuple2 createSqlQuery(final Executio : "Invalid JDBC stage: Start task has to be a TableSource"; // Extract the different types of ExecutionOperators from the stage. - final TableSource tableOp = (TableSource) startTask.getOperator(); - SqlQueryChannel.Instance tipChannelInstance = this.instantiateOutboundChannel(startTask, context); - final Collection filterTasks = new ArrayList<>(4); - ExecutionTask projectionTask = null; - final Collection joinTasks = new ArrayList<>(); + final JdbcTableSource tableOp = (JdbcTableSource) startTask.getOperator(); + SqlQueryChannel.Instance tipChannelInstance = JdbcExecutor.instantiateOutboundChannel(startTask, context, jdbcExecutor); + final Collection filterTasks = new ArrayList<>(4); + JdbcProjectionOperator projectionTask = null; + final Collection> joinTasks = new ArrayList<>(); final Set allTasks = stage.getAllTasks(); assert allTasks.size() <= 3; - ExecutionTask nextTask = this.findJdbcExecutionOperatorTaskInStage(startTask, stage); + ExecutionTask nextTask = JdbcExecutor.findJdbcExecutionOperatorTaskInStage(startTask, stage); while (nextTask != null) { // Evaluate the nextTask. - if (nextTask.getOperator() instanceof JdbcFilterOperator) { - filterTasks.add(nextTask); - } else if (nextTask.getOperator() instanceof JdbcProjectionOperator) { + if (nextTask.getOperator() instanceof final JdbcFilterOperator filterOperator) { + filterTasks.add(filterOperator); + } else if (nextTask.getOperator() instanceof JdbcProjectionOperator projectionOperator) { assert projectionTask == null; // Allow one projection operator per stage for now. - projectionTask = nextTask; - } else if (nextTask.getOperator() instanceof JdbcJoinOperator) { - joinTasks.add(nextTask); + projectionTask = projectionOperator; + } else if (nextTask.getOperator() instanceof JdbcJoinOperator joinOperator) { + joinTasks.add(joinOperator); } else { throw new WayangException(String.format("Unsupported JDBC execution task %s", nextTask.toString())); } // Move the tipChannelInstance. - tipChannelInstance = this.instantiateOutboundChannel(nextTask, context, tipChannelInstance); + tipChannelInstance = JdbcExecutor.instantiateOutboundChannel(nextTask, context, tipChannelInstance, jdbcExecutor); // Go to the next nextTask. - nextTask = this.findJdbcExecutionOperatorTaskInStage(nextTask, stage); + nextTask = JdbcExecutor.findJdbcExecutionOperatorTaskInStage(nextTask, stage); } // Create the SQL query. - final String tableName = this.getSqlClause(tableOp); + final StringBuilder query = createSqlString(jdbcExecutor, tableOp, filterTasks, projectionTask, joinTasks); + return new Tuple2<>(query.toString(), tipChannelInstance); + } + + public static StringBuilder createSqlString(final JdbcExecutor jdbcExecutor, final JdbcTableSource tableOp, + final Collection filterTasks, JdbcProjectionOperator projectionTask, + final Collection> joinTasks) { + final String tableName = tableOp.createSqlClause(jdbcExecutor.connection, jdbcExecutor.functionCompiler); final Collection conditions = filterTasks.stream() - .map(ExecutionTask::getOperator) - .map(this::getSqlClause) + .map(op -> op.createSqlClause(jdbcExecutor.connection, jdbcExecutor.functionCompiler)) .collect(Collectors.toList()); - final String projection = projectionTask == null ? "*" : this.getSqlClause(projectionTask.getOperator()); + final String projection = projectionTask == null ? "*" : projectionTask.createSqlClause(jdbcExecutor.connection, jdbcExecutor.functionCompiler); final Collection joins = joinTasks.stream() - .map(ExecutionTask::getOperator) - .map(this::getSqlClause) + .map(op -> op.createSqlClause(jdbcExecutor.connection, jdbcExecutor.functionCompiler)) .collect(Collectors.toList()); final StringBuilder sb = new StringBuilder(1000); @@ -225,17 +230,7 @@ protected Tuple2 createSqlQuery(final Executio } } sb.append(';'); - return new Tuple2<>(sb.toString(), tipChannelInstance); - } - - /** - * Creates a SQL clause that corresponds to the given {@link Operator}. - * - * @param operator for that the SQL clause should be generated - * @return the SQL clause - */ - private String getSqlClause(final Operator operator) { - return ((JdbcExecutionOperator) operator).createSqlClause(this.connection, this.functionCompiler); + return sb; } @Override diff --git a/wayang-platforms/wayang-postgres/src/main/java/org/apache/wayang/postgres/mapping/ProjectionMapping.java b/wayang-platforms/wayang-postgres/src/main/java/org/apache/wayang/postgres/mapping/ProjectionMapping.java index 0fae1a861..85ed2d2a5 100644 --- a/wayang-platforms/wayang-postgres/src/main/java/org/apache/wayang/postgres/mapping/ProjectionMapping.java +++ b/wayang-platforms/wayang-postgres/src/main/java/org/apache/wayang/postgres/mapping/ProjectionMapping.java @@ -34,10 +34,8 @@ import java.util.Collections; /** - * /** * Mapping from {@link MapOperator} to {@link PostgresProjectionOperator}. */ -@SuppressWarnings("unchecked") public class ProjectionMapping implements Mapping { @Override @@ -45,8 +43,7 @@ public Collection getTransformations() { return Collections.singleton(new PlanTransformation( this.createSubplanPattern(), this.createReplacementSubplanFactory(), - PostgresPlatform.getInstance() - )); + PostgresPlatform.getInstance())); } private SubplanPattern createSubplanPattern() { @@ -55,10 +52,8 @@ private SubplanPattern createSubplanPattern() { new MapOperator<>( null, DataSetType.createDefault(Record.class), - DataSetType.createDefault(Record.class) - ), - false - ) + DataSetType.createDefault(Record.class)), + false) .withAdditionalTest(op -> op.getFunctionDescriptor() instanceof ProjectionDescriptor) .withAdditionalTest(op -> op.getNumInputs() == 1); // No broadcasts. return SubplanPattern.createSingleton(operatorPattern); @@ -66,7 +61,6 @@ private SubplanPattern createSubplanPattern() { private ReplacementSubplanFactory createReplacementSubplanFactory() { return new ReplacementSubplanFactory.OfSingleOperators>( - (matchedOperator, epoch) -> new PostgresProjectionOperator(matchedOperator).at(epoch) - ); + (matchedOperator, epoch) -> new PostgresProjectionOperator(matchedOperator).at(epoch)); } }