diff --git a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecSink.java b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecSink.java index 8dbac05783c3..81ab40c1c8cd 100644 --- a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecSink.java +++ b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecSink.java @@ -16,9 +16,6 @@ */ package org.apache.flink.table.planner.plan.nodes.exec.common; -import org.apache.gluten.table.runtime.operators.GlutenOneInputOperator; -import org.apache.gluten.util.LogicalTypeConverter; -import org.apache.gluten.util.PlanNodeIdGenerator; import org.apache.gluten.velox.VeloxSourceSinkFactory; import org.apache.flink.api.common.io.OutputFormat; @@ -73,6 +70,7 @@ import org.apache.flink.table.runtime.operators.sink.ConstraintEnforcer; import org.apache.flink.table.runtime.operators.sink.RowKindSetter; import org.apache.flink.table.runtime.operators.sink.SinkOperator; +import org.apache.flink.table.runtime.operators.sink.StreamRecordTimestampInserter; import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; import org.apache.flink.table.types.logical.BinaryType; import org.apache.flink.table.types.logical.CharType; @@ -558,11 +556,6 @@ private Transformation applyRowtimeTransformation( if (rowtimeFieldIndex == -1) { return inputTransform; } - // --- Begin Gluten-specific code changes --- - io.github.zhztheplayer.velox4j.type.RowType outputType = - (io.github.zhztheplayer.velox4j.type.RowType) - LogicalTypeConverter.toVLType( - ((InternalTypeInfo) inputTransform.getOutputType()).toLogicalType()); return ExecNodeUtil.createOneInputTransformation( inputTransform, createTransformationMeta( @@ -570,15 +563,49 @@ private Transformation applyRowtimeTransformation( String.format("StreamRecordTimestampInserter(rowtime field: %s)", rowtimeFieldIndex), 
"StreamRecordTimestampInserter", config), - // TODO: support it, Map.of() will not be used, hardcode it here. - new GlutenOneInputOperator( - null, PlanNodeIdGenerator.newId(), null, Map.of("1", outputType)), + new StreamRecordTimestampInserter(rowtimeFieldIndex), inputTransform.getOutputType(), sinkParallelism, sinkParallelismConfigured); - // --- End Gluten-specific code changes --- } + /* + private Transformation applyRowtimeTransformation( + Transformation inputTransform, + int rowtimeFieldIndex, + int sinkParallelism, + ExecNodeConfig config) { + // Don't apply the transformation/operator if there is no rowtimeFieldIndex + if (rowtimeFieldIndex == -1) { + return inputTransform; + } + // --- Begin Gluten-specific code changes --- + io.github.zhztheplayer.velox4j.type.RowType outputType = + (io.github.zhztheplayer.velox4j.type.RowType) + LogicalTypeConverter.toVLType( + ((InternalTypeInfo) inputTransform.getOutputType()).toLogicalType()); + return ExecNodeUtil.createOneInputTransformation( + inputTransform, + createTransformationMeta( + TIMESTAMP_INSERTER_TRANSFORMATION, + String.format("StreamRecordTimestampInserter(rowtime field: %s)", rowtimeFieldIndex), + "StreamRecordTimestampInserter", + config), + // TODO: support it, Map.of() will not be used, hardcode it here. 
+ new GlutenOneInputOperator( + null, + PlanNodeIdGenerator.newId(), + null, + Map.of("1", outputType), + RowData.class, + RowData.class), + inputTransform.getOutputType(), + sinkParallelism, + sinkParallelismConfigured); + // --- End Gluten-specific code changes --- + } + */ + private InternalTypeInfo getInputTypeInfo() { return InternalTypeInfo.of(getInputEdges().get(0).getOutputType()); } diff --git a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecCalc.java b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecCalc.java index bc98c5154cc9..1062f3c08e2a 100644 --- a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecCalc.java +++ b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecCalc.java @@ -19,7 +19,7 @@ import org.apache.gluten.rexnode.RexConversionContext; import org.apache.gluten.rexnode.RexNodeConverter; import org.apache.gluten.rexnode.Utils; -import org.apache.gluten.table.runtime.operators.GlutenVectorOneInputOperator; +import org.apache.gluten.table.runtime.operators.GlutenOneInputOperator; import org.apache.gluten.util.LogicalTypeConverter; import org.apache.gluten.util.PlanNodeIdGenerator; @@ -53,6 +53,8 @@ import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty; import org.apache.calcite.rex.RexNode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import javax.annotation.Nullable; @@ -68,6 +70,7 @@ minPlanVersion = FlinkVersion.v1_15, minStateVersion = FlinkVersion.v1_15) public class StreamExecCalc extends CommonExecCalc implements StreamExecNode { + private static final Logger LOG = LoggerFactory.getLogger(StreamExecCalc.class); public StreamExecCalc( ReadableConfig tableConfig, @@ -142,11 +145,14 @@ public Transformation translateToPlanInternal( (io.github.zhztheplayer.velox4j.type.RowType) 
LogicalTypeConverter.toVLType(getOutputType()); final OneInputStreamOperator calOperator = - new GlutenVectorOneInputOperator( + new GlutenOneInputOperator( new StatefulPlanNode(project.getId(), project), PlanNodeIdGenerator.newId(), inputType, - Map.of(project.getId(), outputType)); + Map.of(project.getId(), outputType), + RowData.class, + RowData.class, + "StreamExecCalc"); return ExecNodeUtil.createOneInputTransformation( inputTransform, new TransformationMetadata("gluten-calc", "Gluten cal operator"), diff --git a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeduplicate.java b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeduplicate.java index ed89b717747e..899a98365e51 100644 --- a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeduplicate.java +++ b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeduplicate.java @@ -16,7 +16,7 @@ */ package org.apache.flink.table.planner.plan.nodes.exec.stream; -import org.apache.gluten.table.runtime.operators.GlutenVectorOneInputOperator; +import org.apache.gluten.table.runtime.operators.GlutenOneInputOperator; import org.apache.gluten.util.LogicalTypeConverter; import org.apache.gluten.util.PlanNodeIdGenerator; @@ -212,11 +212,14 @@ protected Transformation translateToPlanInternal( deduplicateNode, outputType); operator = - new GlutenVectorOneInputOperator( + new GlutenOneInputOperator( new StatefulPlanNode(streamRankNode.getId(), streamRankNode), PlanNodeIdGenerator.newId(), inputType, - Map.of(streamRankNode.getId(), outputType)); + Map.of(streamRankNode.getId(), outputType), + RowData.class, + RowData.class, + "StreamExecDeduplicate"); } else { throw new RuntimeException("ProcTime in deduplicate is not supported."); } diff --git 
a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecExchange.java b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecExchange.java index e6ec9162c2a5..17e3966d52a7 100644 --- a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecExchange.java +++ b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecExchange.java @@ -19,7 +19,7 @@ import org.apache.gluten.streaming.api.operators.GlutenOperator; import org.apache.gluten.streaming.runtime.partitioner.GlutenKeyGroupStreamPartitioner; import org.apache.gluten.table.runtime.keyselector.GlutenKeySelector; -import org.apache.gluten.table.runtime.operators.GlutenVectorOneInputOperator; +import org.apache.gluten.table.runtime.operators.GlutenOneInputOperator; import org.apache.gluten.util.LogicalTypeConverter; import org.apache.gluten.util.PlanNodeIdGenerator; @@ -164,8 +164,14 @@ protected Transformation translateToPlanInternal( partitionFunctionSpec); PlanNode exchange = new StreamPartitionNode(id, localPartition, parallelism); final OneInputStreamOperator exchangeKeyGenerator = - new GlutenVectorOneInputOperator( - new StatefulPlanNode(id, exchange), id, glutenInputType, Map.of(id, outputType)); + new GlutenOneInputOperator( + new StatefulPlanNode(id, exchange), + id, + glutenInputType, + Map.of(id, outputType), + RowData.class, + RowData.class, + "StreamExecExchange"); inputTransform = ExecNodeUtil.createOneInputTransformation( inputTransform, diff --git a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java index ee8c91a3757c..9cbc017232d5 100644 --- 
a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java +++ b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java @@ -19,7 +19,7 @@ import org.apache.gluten.rexnode.AggregateCallConverter; import org.apache.gluten.rexnode.Utils; import org.apache.gluten.rexnode.WindowUtils; -import org.apache.gluten.table.runtime.operators.GlutenVectorOneInputOperator; +import org.apache.gluten.table.runtime.operators.GlutenOneInputOperator; import org.apache.gluten.util.LogicalTypeConverter; import org.apache.gluten.util.PlanNodeIdGenerator; @@ -254,11 +254,14 @@ protected Transformation translateToPlanInternal( outputType, rowtimeIndex); final OneInputStreamOperator windowOperator = - new GlutenVectorOneInputOperator( + new GlutenOneInputOperator( new StatefulPlanNode(windowAgg.getId(), windowAgg), PlanNodeIdGenerator.newId(), inputType, - Map.of(windowAgg.getId(), outputType)); + Map.of(windowAgg.getId(), outputType), + RowData.class, + RowData.class, + "StreamExecGlobalWindowAggregate"); // --- End Gluten-specific code changes --- final RowDataKeySelector selector = diff --git a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupAggregate.java b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupAggregate.java index 6758111254d9..0f9ff4a9a5e4 100644 --- a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupAggregate.java +++ b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupAggregate.java @@ -18,7 +18,7 @@ import org.apache.gluten.rexnode.AggregateCallConverter; import org.apache.gluten.rexnode.Utils; -import org.apache.gluten.table.runtime.operators.GlutenVectorOneInputOperator; +import 
org.apache.gluten.table.runtime.operators.GlutenOneInputOperator; import org.apache.gluten.util.LogicalTypeConverter; import org.apache.gluten.util.PlanNodeIdGenerator; @@ -206,11 +206,14 @@ protected Transformation translateToPlanInternal( new GroupAggregationNode( PlanNodeIdGenerator.newId(), aggsHandlerNode, keySelectorSpec, outputType); final OneInputStreamOperator operator = - new GlutenVectorOneInputOperator( + new GlutenOneInputOperator( new StatefulPlanNode(aggregation.getId(), aggregation), PlanNodeIdGenerator.newId(), inputType, - Map.of(aggregation.getId(), outputType)); + Map.of(aggregation.getId(), outputType), + RowData.class, + RowData.class, + "StreamExecGroupAggregate"); // --- End Gluten-specific code changes --- // partitioned aggregation diff --git a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupWindowAggregate.java b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupWindowAggregate.java index a2ce1425eac8..2fd681519668 100644 --- a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupWindowAggregate.java +++ b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupWindowAggregate.java @@ -18,7 +18,7 @@ import org.apache.gluten.rexnode.AggregateCallConverter; import org.apache.gluten.rexnode.Utils; -import org.apache.gluten.table.runtime.operators.GlutenVectorOneInputOperator; +import org.apache.gluten.table.runtime.operators.GlutenOneInputOperator; import org.apache.gluten.util.LogicalTypeConverter; import org.apache.gluten.util.PlanNodeIdGenerator; @@ -240,11 +240,14 @@ protected Transformation translateToPlanInternal( 1, // TODO: get from window attributes outputType); final OneInputStreamOperator windowOperator = - new GlutenVectorOneInputOperator( + new GlutenOneInputOperator( new StatefulPlanNode(windowAgg.getId(), windowAgg), 
PlanNodeIdGenerator.newId(), inputType, - Map.of(windowAgg.getId(), outputType)); + Map.of(windowAgg.getId(), outputType), + RowData.class, + RowData.class, + "StreamExecGroupWindowAggregate"); // --- End Gluten-specific code changes --- final OneInputTransformation transform = @@ -264,6 +267,9 @@ protected Transformation translateToPlanInternal( InternalTypeInfo.of(inputRowType)); transform.setStateKeySelector(selector); transform.setStateKeyType(selector.getProducedType()); + + // GlutenKeySelector selector = new GlutenKeySelector(); + // transform.setStateKeySelector(selector); return transform; } } diff --git a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecJoin.java b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecJoin.java index fe8cec7974d6..cf66cbbee4be 100644 --- a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecJoin.java +++ b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecJoin.java @@ -19,7 +19,7 @@ import org.apache.gluten.rexnode.RexConversionContext; import org.apache.gluten.rexnode.RexNodeConverter; import org.apache.gluten.rexnode.Utils; -import org.apache.gluten.table.runtime.operators.GlutenVectorTwoInputOperator; +import org.apache.gluten.table.runtime.operators.GlutenTwoInputOperator; import org.apache.gluten.util.LogicalTypeConverter; import org.apache.gluten.util.PlanNodeIdGenerator; @@ -308,13 +308,16 @@ protected Transformation translateToPlanInternal( outputType, 1024); operator = - new GlutenVectorTwoInputOperator( + new GlutenTwoInputOperator( new StatefulPlanNode(join.getId(), join), leftInput.getId(), rightInput.getId(), leftInputType, rightInputType, - Map.of(join.getId(), outputType)); + Map.of(join.getId(), outputType), + RowData.class, + RowData.class, + "StreamExecJoin"); // --- End Gluten-specific code changes --- } } diff 
--git a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java index 4ffcf7998c27..fd0cc04a3077 100644 --- a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java +++ b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java @@ -19,7 +19,7 @@ import org.apache.gluten.rexnode.AggregateCallConverter; import org.apache.gluten.rexnode.Utils; import org.apache.gluten.rexnode.WindowUtils; -import org.apache.gluten.table.runtime.operators.GlutenVectorOneInputOperator; +import org.apache.gluten.table.runtime.operators.GlutenOneInputOperator; import org.apache.gluten.util.LogicalTypeConverter; import org.apache.gluten.util.PlanNodeIdGenerator; @@ -222,11 +222,14 @@ protected Transformation translateToPlanInternal( outputType, rowtimeIndex); final OneInputStreamOperator localAggOperator = - new GlutenVectorOneInputOperator( + new GlutenOneInputOperator( new StatefulPlanNode(windowAgg.getId(), windowAgg), PlanNodeIdGenerator.newId(), inputType, - Map.of(windowAgg.getId(), outputType)); + Map.of(windowAgg.getId(), outputType), + RowData.class, + RowData.class, + "StreamExecLocalWindowAggregate"); // --- End Gluten-specific code changes --- return ExecNodeUtil.createOneInputTransformation( diff --git a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecRank.java b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecRank.java index c1b87e8f60ac..eeaa27b7504a 100644 --- a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecRank.java +++ 
b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecRank.java @@ -17,7 +17,7 @@ package org.apache.flink.table.planner.plan.nodes.exec.stream; import org.apache.gluten.rexnode.Utils; -import org.apache.gluten.table.runtime.operators.GlutenVectorOneInputOperator; +import org.apache.gluten.table.runtime.operators.GlutenOneInputOperator; import org.apache.gluten.util.LogicalTypeConverter; import org.apache.gluten.util.PlanNodeIdGenerator; @@ -268,11 +268,14 @@ protected Transformation translateToPlanInternal( streamTopNNode, outputType); final OneInputStreamOperator operator = - new GlutenVectorOneInputOperator( + new GlutenOneInputOperator( new StatefulPlanNode(streamRankNode.getId(), streamRankNode), PlanNodeIdGenerator.newId(), inputType, - Map.of(streamRankNode.getId(), outputType)); + Map.of(streamRankNode.getId(), outputType), + RowData.class, + RowData.class, + "StreamExecRank"); // --- End Gluten-specific code changes --- OneInputTransformation transform = diff --git a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWatermarkAssigner.java b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWatermarkAssigner.java index e0a56ac918ba..3230bac8c24d 100644 --- a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWatermarkAssigner.java +++ b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWatermarkAssigner.java @@ -19,7 +19,7 @@ import org.apache.gluten.rexnode.RexConversionContext; import org.apache.gluten.rexnode.RexNodeConverter; import org.apache.gluten.rexnode.Utils; -import org.apache.gluten.table.runtime.operators.GlutenVectorOneInputOperator; +import org.apache.gluten.table.runtime.operators.GlutenOneInputOperator; import org.apache.gluten.util.LogicalTypeConverter; import 
org.apache.gluten.util.PlanNodeIdGenerator; @@ -162,11 +162,14 @@ protected Transformation translateToPlanInternal( rowtimeFieldIndex, watermarkInterval); final OneInputStreamOperator watermarkOperator = - new GlutenVectorOneInputOperator( + new GlutenOneInputOperator( new StatefulPlanNode(watermark.getId(), watermark), PlanNodeIdGenerator.newId(), inputType, - Map.of(watermark.getId(), outputType)); + Map.of(watermark.getId(), outputType), + RowData.class, + RowData.class, + "StreamExecWatermarkAssigner"); return ExecNodeUtil.createOneInputTransformation( inputTransform, diff --git a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java index afe5d0148396..499a65516db3 100644 --- a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java +++ b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java @@ -19,7 +19,7 @@ import org.apache.gluten.rexnode.AggregateCallConverter; import org.apache.gluten.rexnode.Utils; import org.apache.gluten.rexnode.WindowUtils; -import org.apache.gluten.table.runtime.operators.GlutenVectorOneInputOperator; +import org.apache.gluten.table.runtime.operators.GlutenOneInputOperator; import org.apache.gluten.util.LogicalTypeConverter; import org.apache.gluten.util.PlanNodeIdGenerator; @@ -246,11 +246,14 @@ protected Transformation translateToPlanInternal( outputType, rowtimeIndex); final OneInputStreamOperator windowOperator = - new GlutenVectorOneInputOperator( + new GlutenOneInputOperator( new StatefulPlanNode(windowAgg.getId(), windowAgg), PlanNodeIdGenerator.newId(), inputType, - Map.of(windowAgg.getId(), outputType)); + Map.of(windowAgg.getId(), outputType), + RowData.class, + RowData.class, + "StreamExecWindowAggregate"); // 
--- End Gluten-specific code changes --- final RowDataKeySelector selector = @@ -270,6 +273,8 @@ protected Transformation translateToPlanInternal( false); // set KeyType and Selector for state + // This key selector will be updated in OffloadedJobGraphGenerator to GlutenKeySelector as + // needed. transform.setStateKeySelector(selector); transform.setStateKeyType(selector.getProducedType()); return transform; diff --git a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowJoin.java b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowJoin.java index d32b22911b53..400a86ccb053 100644 --- a/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowJoin.java +++ b/gluten-flink/planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowJoin.java @@ -17,7 +17,7 @@ package org.apache.flink.table.planner.plan.nodes.exec.stream; import org.apache.gluten.rexnode.Utils; -import org.apache.gluten.table.runtime.operators.GlutenVectorTwoInputOperator; +import org.apache.gluten.table.runtime.operators.GlutenTwoInputOperator; import org.apache.gluten.util.LogicalTypeConverter; import org.apache.gluten.util.PlanNodeIdGenerator; @@ -227,13 +227,16 @@ protected Transformation translateToPlanInternal( leftWindowEndIndex, rightWindowEndIndex); final TwoInputStreamOperator operator = - new GlutenVectorTwoInputOperator( + new GlutenTwoInputOperator( new StatefulPlanNode(join.getId(), join), leftInput.getId(), rightInput.getId(), leftInputType, rightInputType, - Map.of(join.getId(), outputType)); + Map.of(join.getId(), outputType), + RowData.class, + RowData.class, + "StreamExecWindowJoin"); // --- End Gluten-specific code changes --- final RowType returnType = (RowType) getOutputType(); diff --git a/gluten-flink/planner/src/main/java/org/apache/gluten/velox/FileSystemSinkFactory.java 
b/gluten-flink/planner/src/main/java/org/apache/gluten/velox/FileSystemSinkFactory.java index 657e32176241..ff08dcec638d 100644 --- a/gluten-flink/planner/src/main/java/org/apache/gluten/velox/FileSystemSinkFactory.java +++ b/gluten-flink/planner/src/main/java/org/apache/gluten/velox/FileSystemSinkFactory.java @@ -128,7 +128,10 @@ public Transformation buildVeloxSink( new StatefulPlanNode(fileSystemWriteNode.getId(), fileSystemWriteNode), PlanNodeIdGenerator.newId(), inputDataColumns, - Map.of(fileSystemWriteNode.getId(), ignore)); + Map.of(fileSystemWriteNode.getId(), ignore), + RowData.class, + RowData.class, + "FileSystemInsertTable"); GlutenOneInputOperatorFactory operatorFactory = new GlutenOneInputOperatorFactory(onewInputOperator); Transformation veloxFileWriterTransformation = diff --git a/gluten-flink/planner/src/main/java/org/apache/gluten/velox/FromElementsSourceFactory.java b/gluten-flink/planner/src/main/java/org/apache/gluten/velox/FromElementsSourceFactory.java index da31edeccd14..40c8b86c11f5 100644 --- a/gluten-flink/planner/src/main/java/org/apache/gluten/velox/FromElementsSourceFactory.java +++ b/gluten-flink/planner/src/main/java/org/apache/gluten/velox/FromElementsSourceFactory.java @@ -17,7 +17,7 @@ package org.apache.gluten.velox; import org.apache.gluten.streaming.api.operators.GlutenStreamSource; -import org.apache.gluten.table.runtime.operators.GlutenVectorSourceFunction; +import org.apache.gluten.table.runtime.operators.GlutenSourceFunction; import org.apache.gluten.util.LogicalTypeConverter; import org.apache.gluten.util.PlanNodeIdGenerator; import org.apache.gluten.util.ReflectUtils; @@ -94,11 +94,13 @@ public Transformation buildVeloxSource( new TableScanNode(PlanNodeIdGenerator.newId(), rowType, tableHandle, List.of()); GlutenStreamSource op = new GlutenStreamSource( - new GlutenVectorSourceFunction( + new GlutenSourceFunction( new StatefulPlanNode(scanNode.getId(), scanNode), Map.of(scanNode.getId(), rowType), scanNode.getId(), - new 
FromElementsConnectorSplit("connector-from-elements", 0, false))); + new FromElementsConnectorSplit("connector-from-elements", 0, false), + RowData.class), + "FromElementsSource"); return new LegacySourceTransformation( sourceTransformation.getName(), op, diff --git a/gluten-flink/planner/src/main/java/org/apache/gluten/velox/FuzzerSourceSinkFactory.java b/gluten-flink/planner/src/main/java/org/apache/gluten/velox/FuzzerSourceSinkFactory.java index 16ac7f083167..6e9c232f048a 100644 --- a/gluten-flink/planner/src/main/java/org/apache/gluten/velox/FuzzerSourceSinkFactory.java +++ b/gluten-flink/planner/src/main/java/org/apache/gluten/velox/FuzzerSourceSinkFactory.java @@ -18,8 +18,8 @@ import org.apache.gluten.streaming.api.operators.GlutenOneInputOperatorFactory; import org.apache.gluten.streaming.api.operators.GlutenStreamSource; -import org.apache.gluten.table.runtime.operators.GlutenVectorOneInputOperator; -import org.apache.gluten.table.runtime.operators.GlutenVectorSourceFunction; +import org.apache.gluten.table.runtime.operators.GlutenOneInputOperator; +import org.apache.gluten.table.runtime.operators.GlutenSourceFunction; import org.apache.gluten.util.LogicalTypeConverter; import org.apache.gluten.util.PlanNodeIdGenerator; @@ -86,11 +86,12 @@ public Transformation buildVeloxSource( PlanNode tableScan = new TableScanNode(id, outputType, tableHandle, List.of()); GlutenStreamSource sourceOp = new GlutenStreamSource( - new GlutenVectorSourceFunction( + new GlutenSourceFunction( new StatefulPlanNode(id, tableScan), Map.of(id, outputType), id, - new FuzzerConnectorSplit("connector-fuzzer", 1000))); + new FuzzerConnectorSplit("connector-fuzzer", 1000), + RowData.class)); return new LegacySourceTransformation( sourceTransformation.getName(), sourceOp, @@ -128,11 +129,14 @@ public Transformation buildVeloxSink( List.of(new EmptyNode(outputType))); GlutenOneInputOperatorFactory operatorFactory = new GlutenOneInputOperatorFactory( - new GlutenVectorOneInputOperator( + 
new GlutenOneInputOperator( new StatefulPlanNode(plan.getId(), plan), PlanNodeIdGenerator.newId(), outputType, - Map.of(plan.getId(), ignore))); + Map.of(plan.getId(), ignore), + RowData.class, + RowData.class, + "FuzzerSink")); DataStream newInputStream = sinkTransformation .getInputStream() diff --git a/gluten-flink/planner/src/main/java/org/apache/gluten/velox/KafkaSourceSinkFactory.java b/gluten-flink/planner/src/main/java/org/apache/gluten/velox/KafkaSourceSinkFactory.java index 54e2b8b4361a..8644dac3dac5 100644 --- a/gluten-flink/planner/src/main/java/org/apache/gluten/velox/KafkaSourceSinkFactory.java +++ b/gluten-flink/planner/src/main/java/org/apache/gluten/velox/KafkaSourceSinkFactory.java @@ -17,7 +17,7 @@ package org.apache.gluten.velox; import org.apache.gluten.streaming.api.operators.GlutenStreamSource; -import org.apache.gluten.table.runtime.operators.GlutenVectorSourceFunction; +import org.apache.gluten.table.runtime.operators.GlutenSourceFunction; import org.apache.gluten.util.LogicalTypeConverter; import org.apache.gluten.util.PlanNodeIdGenerator; import org.apache.gluten.util.ReflectUtils; @@ -115,11 +115,13 @@ public Transformation buildVeloxSource( TableScanNode kafkaScan = new TableScanNode(planId, outputType, kafkaTableHandle, List.of()); GlutenStreamSource sourceOp = new GlutenStreamSource( - new GlutenVectorSourceFunction( + new GlutenSourceFunction( new StatefulPlanNode(kafkaScan.getId(), kafkaScan), Map.of(kafkaScan.getId(), outputType), kafkaScan.getId(), - connectorSplit)); + connectorSplit, + RowData.class), + "KafkaSource"); SourceTransformation sourceTransformation = (SourceTransformation) transformation; return new LegacySourceTransformation( sourceTransformation.getName(), diff --git a/gluten-flink/planner/src/main/java/org/apache/gluten/velox/NexmarkSourceFactory.java b/gluten-flink/planner/src/main/java/org/apache/gluten/velox/NexmarkSourceFactory.java index 736f3cc3c72b..ae0d94fb3cc8 100644 --- 
a/gluten-flink/planner/src/main/java/org/apache/gluten/velox/NexmarkSourceFactory.java +++ b/gluten-flink/planner/src/main/java/org/apache/gluten/velox/NexmarkSourceFactory.java @@ -17,7 +17,7 @@ package org.apache.gluten.velox; import org.apache.gluten.streaming.api.operators.GlutenStreamSource; -import org.apache.gluten.table.runtime.operators.GlutenVectorSourceFunction; +import org.apache.gluten.table.runtime.operators.GlutenSourceFunction; import org.apache.gluten.util.LogicalTypeConverter; import org.apache.gluten.util.PlanNodeIdGenerator; import org.apache.gluten.util.ReflectUtils; @@ -35,10 +35,14 @@ import org.apache.flink.table.data.RowData; import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.util.List; import java.util.Map; public class NexmarkSourceFactory implements VeloxSourceSinkFactory { + private static final Logger LOG = LoggerFactory.getLogger(NexmarkSourceFactory.class); @SuppressWarnings("rawtypes") @Override @@ -79,13 +83,15 @@ public Transformation buildVeloxSource( new TableScanNode(id, outputType, new NexmarkTableHandle("connector-nexmark"), List.of()); GlutenStreamSource sourceOp = new GlutenStreamSource( - new GlutenVectorSourceFunction( + new GlutenSourceFunction( new StatefulPlanNode(tableScan.getId(), tableScan), Map.of(id, outputType), id, new NexmarkConnectorSplit( "connector-nexmark", - maxEvents > Integer.MAX_VALUE ? Integer.MAX_VALUE : maxEvents.intValue()))); + maxEvents > Integer.MAX_VALUE ? 
Integer.MAX_VALUE : maxEvents.intValue()), + RowData.class)); + return new LegacySourceTransformation( transformation.getName(), sourceOp, diff --git a/gluten-flink/planner/src/main/java/org/apache/gluten/velox/PrintSinkFactory.java b/gluten-flink/planner/src/main/java/org/apache/gluten/velox/PrintSinkFactory.java index b00a76a21f16..737d6bab7e78 100644 --- a/gluten-flink/planner/src/main/java/org/apache/gluten/velox/PrintSinkFactory.java +++ b/gluten-flink/planner/src/main/java/org/apache/gluten/velox/PrintSinkFactory.java @@ -17,7 +17,7 @@ package org.apache.gluten.velox; import org.apache.gluten.streaming.api.operators.GlutenOneInputOperatorFactory; -import org.apache.gluten.table.runtime.operators.GlutenVectorOneInputOperator; +import org.apache.gluten.table.runtime.operators.GlutenOneInputOperator; import org.apache.gluten.util.LogicalTypeConverter; import org.apache.gluten.util.PlanNodeIdGenerator; @@ -114,11 +114,14 @@ public Transformation buildVeloxSink( inputTrans, transformation.getName(), new GlutenOneInputOperatorFactory( - new GlutenVectorOneInputOperator( + new GlutenOneInputOperator( new StatefulPlanNode(tableWriteNode.getId(), tableWriteNode), PlanNodeIdGenerator.newId(), inputColumns, - Map.of(tableWriteNode.getId(), ignore))), + Map.of(tableWriteNode.getId(), ignore), + RowData.class, + RowData.class, + "PrintSink")), transformation.getParallelism()); } } diff --git a/gluten-flink/runtime/src/main/java/org/apache/flink/client/StreamGraphTranslator.java b/gluten-flink/runtime/src/main/java/org/apache/flink/client/StreamGraphTranslator.java index db702d56053e..6be550e71cd5 100644 --- a/gluten-flink/runtime/src/main/java/org/apache/flink/client/StreamGraphTranslator.java +++ b/gluten-flink/runtime/src/main/java/org/apache/flink/client/StreamGraphTranslator.java @@ -16,41 +16,16 @@ */ package org.apache.flink.client; -import org.apache.gluten.streaming.api.operators.GlutenOneInputOperatorFactory; -import 
org.apache.gluten.streaming.api.operators.GlutenOperator; -import org.apache.gluten.streaming.api.operators.GlutenStreamSource; -import org.apache.gluten.table.runtime.keyselector.GlutenKeySelector; -import org.apache.gluten.table.runtime.operators.GlutenVectorOneInputOperator; -import org.apache.gluten.table.runtime.operators.GlutenVectorSourceFunction; -import org.apache.gluten.table.runtime.operators.GlutenVectorTwoInputOperator; -import org.apache.gluten.table.runtime.typeutils.GlutenRowVectorSerializer; -import org.apache.gluten.util.Utils; - -import io.github.zhztheplayer.velox4j.plan.StatefulPlanNode; -import io.github.zhztheplayer.velox4j.type.RowType; +import org.apache.gluten.client.OffloadedJobGraphGenerator; import org.apache.flink.api.dag.Pipeline; import org.apache.flink.configuration.Configuration; -import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.JobGraph; -import org.apache.flink.runtime.jobgraph.JobVertex; -import org.apache.flink.streaming.api.graph.NonChainedOutput; -import org.apache.flink.streaming.api.graph.StreamConfig; -import org.apache.flink.streaming.api.graph.StreamEdge; import org.apache.flink.streaming.api.graph.StreamGraph; -import org.apache.flink.streaming.api.operators.SimpleOperatorFactory; -import org.apache.flink.streaming.api.operators.StreamOperator; -import org.apache.flink.streaming.api.operators.StreamOperatorFactory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; - import static org.apache.flink.util.Preconditions.checkArgument; /** @@ -73,12 +48,17 @@ public StreamGraphTranslator(ClassLoader userClassloader) { @Override public JobGraph translateToJobGraph( Pipeline pipeline, Configuration optimizerConfiguration, int defaultParallelism) { + // --- Begin Gluten-specific code changes --- + checkArgument( pipeline instanceof 
StreamGraph, "Given pipeline is not a DataStream StreamGraph."); - StreamGraph streamGraph = (StreamGraph) pipeline; JobGraph jobGraph = streamGraph.getJobGraph(userClassloader, null); - return mergeGlutenOperators(jobGraph); + OffloadedJobGraphGenerator generator = + new OffloadedJobGraphGenerator(jobGraph, userClassloader); + return generator.generate(); + // --- End Gluten-specific code changes --- + } @Override @@ -95,155 +75,4 @@ public String translateToJSONExecutionPlan(Pipeline pipeline) { public boolean canTranslate(Pipeline pipeline) { return pipeline instanceof StreamGraph; } - - // --- Begin Gluten-specific code changes --- - private JobGraph mergeGlutenOperators(JobGraph jobGraph) { - for (JobVertex vertex : jobGraph.getVertices()) { - StreamConfig streamConfig = new StreamConfig(vertex.getConfiguration()); - buildGlutenChains(streamConfig); - LOG.debug("Vertex {} is {}.", vertex.getName(), streamConfig); - } - return jobGraph; - } - - // A JobVertex may contain several operators chained like this: Source-->Op1-->Op2-->Sink1. - // -->Sink2. - // If the operators connected all support translated to gluten, we merge them into - // a single GlutenOperator to avoid data transferred between flink and native. - // One operator may be followed by several other operators. - private void buildGlutenChains(StreamConfig vertexConfig) { - Map serializedTasks = - vertexConfig.getTransitiveChainedTaskConfigs(userClassloader); - Map chainedTasks = new HashMap<>(serializedTasks.size()); - serializedTasks.forEach( - (id, config) -> chainedTasks.put(id, new StreamConfig(config.getConfiguration()))); - buildGlutenChains(vertexConfig, chainedTasks); - // TODO: may need fallback if failed. 
- vertexConfig.setAndSerializeTransitiveChainedTaskConfigs(chainedTasks); - } - - private void buildGlutenChains(StreamConfig taskConfig, Map chainedTasks) { - List outEdges = taskConfig.getChainedOutputs(userClassloader); - Optional sourceOperatorOpt = getGlutenOperator(taskConfig); - GlutenOperator sourceOperator = sourceOperatorOpt.orElse(null); - boolean isSourceGluten = sourceOperatorOpt.isPresent(); - if (outEdges == null || outEdges.isEmpty()) { - LOG.debug("{} has no chained task.", taskConfig.getOperatorName()); - // TODO: judge whether can set? - if (isSourceGluten) { - if (taskConfig.getOperatorName().equals("exchange-hash")) { - taskConfig.setTypeSerializerOut(new GlutenRowVectorSerializer(null)); - } - Map nodeToNonChainedOuts = new HashMap<>(outEdges.size()); - taskConfig - .getOperatorNonChainedOutputs(userClassloader) - .forEach(edge -> nodeToNonChainedOuts.put(edge.getDataSetId(), sourceOperator.getId())); - Utils.setNodeToNonChainedOutputs(taskConfig, nodeToNonChainedOuts); - taskConfig.serializeAllConfigs(); - } - return; - } - Map nodeToChainedOuts = new HashMap<>(outEdges.size()); - Map nodeToNonChainedOuts = new HashMap<>(outEdges.size()); - Map nodeToOutTypes = new HashMap<>(outEdges.size()); - List chainedOutputs = new ArrayList<>(outEdges.size()); - List nonChainedOutputs = new ArrayList<>(outEdges.size()); - StatefulPlanNode sourceNode = isSourceGluten ? 
sourceOperator.getPlanNode() : null; - boolean allGluten = true; - LOG.debug("Edge size {}, OP {}", outEdges.size(), sourceOperator); - for (StreamEdge outEdge : outEdges) { - StreamConfig outTask = chainedTasks.get(outEdge.getTargetId()); - if (outTask == null) { - LOG.error("Not find task {} in Chained tasks", outEdge.getTargetId()); - allGluten = false; - break; - } - buildGlutenChains(outTask, chainedTasks); - Optional outOperator = getGlutenOperator(outTask); - if (isSourceGluten && outOperator.isPresent()) { - StatefulPlanNode outNode = outOperator.get().getPlanNode(); - if (sourceNode != null) { - sourceNode.addTarget(outNode); - LOG.debug("Add {} target {}", sourceNode, outNode); - } else { - sourceNode = outNode; - LOG.debug("Set target node to {}", outNode); - } - Map node2Out = Utils.getNodeToChainedOutputs(outTask, userClassloader); - if (node2Out != null) { - nodeToChainedOuts.putAll(node2Out); - nodeToNonChainedOuts.putAll(Utils.getNodeToNonChainedOutputs(outTask, userClassloader)); - } else { - outTask - .getChainedOutputs(userClassloader) - .forEach(edge -> nodeToChainedOuts.put(outNode.getId(), edge.getTargetId())); - outTask - .getOperatorNonChainedOutputs(userClassloader) - .forEach(edge -> nodeToNonChainedOuts.put(edge.getDataSetId(), outNode.getId())); - } - nodeToOutTypes.putAll(outOperator.get().getOutputTypes()); - chainedOutputs.addAll(outTask.getChainedOutputs(userClassloader)); - nonChainedOutputs.addAll(outTask.getOperatorNonChainedOutputs(userClassloader)); - } else { - allGluten = false; - LOG.debug( - "{} and {} can not be merged", taskConfig.getOperatorName(), outTask.getOperatorName()); - break; - } - } - if (allGluten) { - if (sourceOperator instanceof GlutenStreamSource) { - GlutenStreamSource streamSource = (GlutenStreamSource) sourceOperator; - taskConfig.setStreamOperator( - new GlutenStreamSource( - new GlutenVectorSourceFunction( - sourceNode, - nodeToOutTypes, - sourceOperator.getId(), - streamSource.getConnectorSplit()))); 
- } else if (sourceOperator instanceof GlutenVectorTwoInputOperator) { - GlutenVectorTwoInputOperator twoInputOperator = - (GlutenVectorTwoInputOperator) sourceOperator; - taskConfig.setStreamOperator( - new GlutenVectorTwoInputOperator( - sourceNode, - twoInputOperator.getLeftId(), - twoInputOperator.getRightId(), - twoInputOperator.getLeftInputType(), - twoInputOperator.getRightInputType(), - nodeToOutTypes)); - // TODO: judge whether can set? - taskConfig.setStatePartitioner(0, new GlutenKeySelector()); - taskConfig.setStatePartitioner(1, new GlutenKeySelector()); - taskConfig.setupNetworkInputs( - new GlutenRowVectorSerializer(null), new GlutenRowVectorSerializer(null)); - } else { - taskConfig.setStreamOperator( - new GlutenVectorOneInputOperator( - sourceNode, sourceOperator.getId(), sourceOperator.getInputType(), nodeToOutTypes)); - // TODO: judge whether can set? - taskConfig.setStatePartitioner(0, new GlutenKeySelector()); - taskConfig.setupNetworkInputs(new GlutenRowVectorSerializer(null)); - } - Utils.setNodeToChainedOutputs(taskConfig, nodeToChainedOuts); - Utils.setNodeToNonChainedOutputs(taskConfig, nodeToNonChainedOuts); - taskConfig.setChainedOutputs(chainedOutputs); - taskConfig.setOperatorNonChainedOutputs(nonChainedOutputs); - taskConfig.serializeAllConfigs(); - } - } - - private Optional getGlutenOperator(StreamConfig taskConfig) { - StreamOperatorFactory operatorFactory = taskConfig.getStreamOperatorFactory(userClassloader); - if (operatorFactory instanceof SimpleOperatorFactory) { - StreamOperator streamOperator = taskConfig.getStreamOperator(userClassloader); - if (streamOperator instanceof GlutenOperator) { - return Optional.of((GlutenOperator) streamOperator); - } - } else if (operatorFactory instanceof GlutenOneInputOperatorFactory) { - return Optional.of(((GlutenOneInputOperatorFactory) operatorFactory).getOperator()); - } - return Optional.empty(); - } - // --- End Gluten-specific code changes --- } diff --git 
a/gluten-flink/runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java b/gluten-flink/runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java deleted file mode 100644 index 082bd6edaf57..000000000000 --- a/gluten-flink/runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java +++ /dev/null @@ -1,944 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.flink.streaming.runtime.tasks; - -import org.apache.gluten.streaming.runtime.tasks.GlutenOutputCollector; -import org.apache.gluten.util.Utils; - -import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.metrics.Counter; -import org.apache.flink.metrics.SimpleCounter; -import org.apache.flink.metrics.groups.OperatorMetricGroup; -import org.apache.flink.runtime.checkpoint.CheckpointException; -import org.apache.flink.runtime.checkpoint.CheckpointMetaData; -import org.apache.flink.runtime.checkpoint.CheckpointOptions; -import org.apache.flink.runtime.checkpoint.StateObjectCollection; -import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; -import org.apache.flink.runtime.event.AbstractEvent; -import org.apache.flink.runtime.execution.Environment; -import org.apache.flink.runtime.io.network.api.StopMode; -import org.apache.flink.runtime.io.network.api.writer.RecordWriter; -import org.apache.flink.runtime.io.network.api.writer.RecordWriterDelegate; -import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate; -import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; -import org.apache.flink.runtime.jobgraph.OperatorID; -import org.apache.flink.runtime.metrics.MetricNames; -import org.apache.flink.runtime.metrics.groups.InternalOperatorMetricGroup; -import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup; -import org.apache.flink.runtime.operators.coordination.AcknowledgeCheckpointEvent; -import org.apache.flink.runtime.operators.coordination.OperatorEvent; -import org.apache.flink.runtime.operators.coordination.OperatorEventDispatcher; -import org.apache.flink.runtime.plugable.SerializationDelegate; -import org.apache.flink.runtime.state.CheckpointStreamFactory; -import org.apache.flink.runtime.state.SnapshotResult; -import 
org.apache.flink.streaming.api.graph.NonChainedOutput; -import org.apache.flink.streaming.api.graph.StreamConfig; -import org.apache.flink.streaming.api.graph.StreamEdge; -import org.apache.flink.streaming.api.operators.BoundedMultiInput; -import org.apache.flink.streaming.api.operators.CountingOutput; -import org.apache.flink.streaming.api.operators.Input; -import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator; -import org.apache.flink.streaming.api.operators.OneInputStreamOperator; -import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures; -import org.apache.flink.streaming.api.operators.Output; -import org.apache.flink.streaming.api.operators.SourceOperator; -import org.apache.flink.streaming.api.operators.StreamOperator; -import org.apache.flink.streaming.api.operators.StreamOperatorFactory; -import org.apache.flink.streaming.api.operators.StreamOperatorFactoryUtil; -import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer; -import org.apache.flink.streaming.runtime.io.RecordWriterOutput; -import org.apache.flink.streaming.runtime.io.StreamTaskSourceInput; -import org.apache.flink.streaming.runtime.operators.sink.SinkWriterOperatorFactory; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxExecutorFactory; -import org.apache.flink.util.CollectionUtil; -import org.apache.flink.util.FlinkException; -import org.apache.flink.util.OutputTag; -import org.apache.flink.util.SerializedValue; - -import org.apache.flink.shaded.guava31.com.google.common.io.Closer; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import javax.annotation.Nullable; - -import java.io.Closeable; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.function.Supplier; 
-import java.util.stream.Collectors; - -import static org.apache.flink.util.Preconditions.checkArgument; -import static org.apache.flink.util.Preconditions.checkNotNull; -import static org.apache.flink.util.Preconditions.checkState; - -/** - * The {@code OperatorChain} contains all operators that are executed as one chain within a single - * {@link StreamTask}. - * - *

The main entry point to the chain is it's {@code mainOperator}. {@code mainOperator} is - * driving the execution of the {@link StreamTask}, by pulling the records from network inputs - * and/or source inputs and pushing produced records to the remaining chained operators. - * - * @param The type of elements accepted by the chain, i.e., the input type of the chain's main - * operator. - */ -public abstract class OperatorChain> - implements BoundedMultiInput, Closeable { - - private static final Logger LOG = LoggerFactory.getLogger(OperatorChain.class); - - protected final RecordWriterOutput[] streamOutputs; - - protected final WatermarkGaugeExposingOutput> mainOperatorOutput; - - /** - * For iteration, {@link StreamIterationHead} and {@link StreamIterationTail} used for executing - * feedback edges do not contain any operators, in which case, {@code mainOperatorWrapper} and - * {@code tailOperatorWrapper} are null. - * - *

Usually first operator in the chain is the same as {@link #mainOperatorWrapper}, but that's - * not the case if there are chained source inputs. In this case, one of the source inputs will be - * the first operator. For example the following operator chain is possible: - * - *

-   * first
-   *      \
-   *      main (multi-input) -> ... -> tail
-   *      /
-   * second
-   * 
- * - *

Where "first" and "second" (there can be more) are chained source operators. When it comes - * to things like closing, stat initialisation or state snapshotting, the operator chain is - * traversed: first, second, main, ..., tail or in reversed order: tail, ..., main, second, first - */ - @Nullable protected final StreamOperatorWrapper mainOperatorWrapper; - - @Nullable protected final StreamOperatorWrapper firstOperatorWrapper; - @Nullable protected final StreamOperatorWrapper tailOperatorWrapper; - - protected final Map chainedSources; - - protected final int numOperators; - - protected final OperatorEventDispatcherImpl operatorEventDispatcher; - - protected final Closer closer = Closer.create(); - - protected final @Nullable FinishedOnRestoreInput finishedOnRestoreInput; - - protected boolean isClosed; - - public OperatorChain( - StreamTask containingTask, - RecordWriterDelegate>> recordWriterDelegate) { - - this.operatorEventDispatcher = - new OperatorEventDispatcherImpl( - containingTask.getEnvironment().getUserCodeClassLoader().asClassLoader(), - containingTask.getEnvironment().getOperatorCoordinatorEventGateway()); - - final ClassLoader userCodeClassloader = containingTask.getUserCodeClassLoader(); - final StreamConfig configuration = containingTask.getConfiguration(); - - StreamOperatorFactory operatorFactory = - configuration.getStreamOperatorFactory(userCodeClassloader); - - // we read the chained configs, and the order of record writer registrations by output name - Map chainedConfigs = - configuration.getTransitiveChainedTaskConfigsWithSelf(userCodeClassloader); - - // create the final output stream writers - // we iterate through all the out edges from this job vertex and create a stream output - List outputsInOrder = - configuration.getVertexNonChainedOutputs(userCodeClassloader); - Map> recordWriterOutputs = - CollectionUtil.newHashMapWithExpectedSize(outputsInOrder.size()); - this.streamOutputs = new RecordWriterOutput[outputsInOrder.size()]; - 
this.finishedOnRestoreInput = - this.isTaskDeployedAsFinished() - ? new FinishedOnRestoreInput( - streamOutputs, configuration.getInputs(userCodeClassloader).length) - : null; - - // from here on, we need to make sure that the output writers are shut down again on failure - boolean success = false; - try { - createChainOutputs( - outputsInOrder, - recordWriterDelegate, - chainedConfigs, - containingTask, - recordWriterOutputs); - - // we create the chain of operators and grab the collector that leads into the chain - List> allOpWrappers = new ArrayList<>(chainedConfigs.size()); - this.mainOperatorOutput = - createOutputCollector( - containingTask, - configuration, - chainedConfigs, - userCodeClassloader, - recordWriterOutputs, - allOpWrappers, - containingTask.getMailboxExecutorFactory(), - operatorFactory != null); - - if (operatorFactory != null) { - Tuple2> mainOperatorAndTimeService = - StreamOperatorFactoryUtil.createOperator( - operatorFactory, - containingTask, - configuration, - mainOperatorOutput, - operatorEventDispatcher); - - OP mainOperator = mainOperatorAndTimeService.f0; - mainOperator - .getMetricGroup() - .gauge(MetricNames.IO_CURRENT_OUTPUT_WATERMARK, mainOperatorOutput.getWatermarkGauge()); - this.mainOperatorWrapper = - createOperatorWrapper( - mainOperator, containingTask, configuration, mainOperatorAndTimeService.f1, true); - - // add main operator to end of chain - allOpWrappers.add(mainOperatorWrapper); - - this.tailOperatorWrapper = allOpWrappers.get(0); - } else { - checkState(allOpWrappers.size() == 0); - this.mainOperatorWrapper = null; - this.tailOperatorWrapper = null; - } - - this.chainedSources = - createChainedSources( - containingTask, - configuration.getInputs(userCodeClassloader), - chainedConfigs, - userCodeClassloader, - allOpWrappers); - - this.numOperators = allOpWrappers.size(); - - firstOperatorWrapper = linkOperatorWrappers(allOpWrappers); - - success = true; - } finally { - // make sure we clean up after ourselves in case 
of a failure after acquiring - // the first resources - if (!success) { - for (int i = 0; i < streamOutputs.length; i++) { - if (streamOutputs[i] != null) { - streamOutputs[i].close(); - } - streamOutputs[i] = null; - } - } - } - } - - @VisibleForTesting - OperatorChain( - List> allOperatorWrappers, - RecordWriterOutput[] streamOutputs, - WatermarkGaugeExposingOutput> mainOperatorOutput, - StreamOperatorWrapper mainOperatorWrapper) { - this.streamOutputs = streamOutputs; - this.finishedOnRestoreInput = null; - this.mainOperatorOutput = checkNotNull(mainOperatorOutput); - this.operatorEventDispatcher = null; - - checkState(allOperatorWrappers != null && allOperatorWrappers.size() > 0); - this.mainOperatorWrapper = checkNotNull(mainOperatorWrapper); - this.tailOperatorWrapper = allOperatorWrappers.get(0); - this.numOperators = allOperatorWrappers.size(); - this.chainedSources = Collections.emptyMap(); - - firstOperatorWrapper = linkOperatorWrappers(allOperatorWrappers); - } - - public abstract boolean isTaskDeployedAsFinished(); - - public abstract void dispatchOperatorEvent( - OperatorID operator, SerializedValue event) throws FlinkException; - - public abstract void prepareSnapshotPreBarrier(long checkpointId) throws Exception; - - /** - * Ends the main operator input specified by {@code inputId}). - * - * @param inputId the input ID starts from 1 which indicates the first input. - */ - public abstract void endInput(int inputId) throws Exception; - - /** - * Initialize state and open all operators in the chain from tail to heads, contrary to - * {@link StreamOperator#close()} which happens heads to tail (see {@link - * #finishOperators(StreamTaskActionExecutor, StopMode)}). - */ - public abstract void initializeStateAndOpenOperators( - StreamTaskStateInitializer streamTaskStateInitializer) throws Exception; - - /** - * Closes all operators in a chain effect way. 
Closing happens from heads to tail operator - * in the chain, contrary to {@link StreamOperator#open()} which happens tail to heads (see - * {@link #initializeStateAndOpenOperators(StreamTaskStateInitializer)}). - */ - public abstract void finishOperators(StreamTaskActionExecutor actionExecutor, StopMode stopMode) - throws Exception; - - public abstract void notifyCheckpointComplete(long checkpointId) throws Exception; - - public abstract void notifyCheckpointAborted(long checkpointId) throws Exception; - - public abstract void notifyCheckpointSubsumed(long checkpointId) throws Exception; - - public abstract void snapshotState( - Map operatorSnapshotsInProgress, - CheckpointMetaData checkpointMetaData, - CheckpointOptions checkpointOptions, - Supplier isRunning, - ChannelStateWriter.ChannelStateWriteResult channelStateWriteResult, - CheckpointStreamFactory storage) - throws Exception; - - public OperatorEventDispatcher getOperatorEventDispatcher() { - return operatorEventDispatcher; - } - - public void broadcastEvent(AbstractEvent event) throws IOException { - broadcastEvent(event, false); - } - - public void broadcastEvent(AbstractEvent event, boolean isPriorityEvent) throws IOException { - for (RecordWriterOutput streamOutput : streamOutputs) { - streamOutput.broadcastEvent(event, isPriorityEvent); - } - } - - public void alignedBarrierTimeout(long checkpointId) throws IOException { - for (RecordWriterOutput streamOutput : streamOutputs) { - streamOutput.alignedBarrierTimeout(checkpointId); - } - } - - public void abortCheckpoint(long checkpointId, CheckpointException cause) { - for (RecordWriterOutput streamOutput : streamOutputs) { - streamOutput.abortCheckpoint(checkpointId, cause); - } - } - - /** - * Execute {@link StreamOperator#close()} of each operator in the chain of this {@link - * StreamTask}. Closing happens from tail to head operator in the chain. 
- */ - public void closeAllOperators() throws Exception { - isClosed = true; - } - - public RecordWriterOutput[] getStreamOutputs() { - return streamOutputs; - } - - /** Returns an {@link Iterable} which traverses all operators in forward topological order. */ - @VisibleForTesting - public Iterable> getAllOperators() { - return getAllOperators(false); - } - - /** - * Returns an {@link Iterable} which traverses all operators in forward or reverse topological - * order. - */ - protected Iterable> getAllOperators(boolean reverse) { - return reverse - ? new StreamOperatorWrapper.ReadIterator(tailOperatorWrapper, true) - : new StreamOperatorWrapper.ReadIterator(mainOperatorWrapper, false); - } - - public Input getFinishedOnRestoreInputOrDefault(Input defaultInput) { - return finishedOnRestoreInput == null ? defaultInput : finishedOnRestoreInput; - } - - public int getNumberOfOperators() { - return numOperators; - } - - public WatermarkGaugeExposingOutput> getMainOperatorOutput() { - return mainOperatorOutput; - } - - public ChainedSource getChainedSource(StreamConfig.SourceInputConfig sourceInput) { - checkArgument( - chainedSources.containsKey(sourceInput), - "Chained source with sourcedId = [%s] was not found", - sourceInput); - return chainedSources.get(sourceInput); - } - - public List>> getChainedSourceOutputs() { - return chainedSources.values().stream() - .map(ChainedSource::getSourceOutput) - .collect(Collectors.toList()); - } - - public StreamTaskSourceInput getSourceTaskInput(StreamConfig.SourceInputConfig sourceInput) { - checkArgument( - chainedSources.containsKey(sourceInput), - "Chained source with sourcedId = [%s] was not found", - sourceInput); - return chainedSources.get(sourceInput).getSourceTaskInput(); - } - - public List> getSourceTaskInputs() { - return chainedSources.values().stream() - .map(ChainedSource::getSourceTaskInput) - .collect(Collectors.toList()); - } - - /** - * This method should be called before finishing the record emission, to make 
sure any data that - * is still buffered will be sent. It also ensures that all data sending related exceptions are - * recognized. - * - * @throws IOException Thrown, if the buffered data cannot be pushed into the output streams. - */ - public void flushOutputs() throws IOException { - for (RecordWriterOutput streamOutput : getStreamOutputs()) { - streamOutput.flush(); - } - } - - /** - * This method releases all resources of the record writer output. It stops the output flushing - * thread (if there is one) and releases all buffers currently held by the output serializers. - * - *

This method should never fail. - */ - public void close() throws IOException { - closer.close(); - } - - @Nullable - public OP getMainOperator() { - return (mainOperatorWrapper == null) ? null : mainOperatorWrapper.getStreamOperator(); - } - - @Nullable - protected StreamOperator getTailOperator() { - return (tailOperatorWrapper == null) ? null : tailOperatorWrapper.getStreamOperator(); - } - - protected void snapshotChannelStates( - StreamOperator op, - ChannelStateWriter.ChannelStateWriteResult channelStateWriteResult, - OperatorSnapshotFutures snapshotInProgress) { - if (op == getMainOperator()) { - snapshotInProgress.setInputChannelStateFuture( - channelStateWriteResult - .getInputChannelStateHandles() - .thenApply(StateObjectCollection::new) - .thenApply(SnapshotResult::of)); - } - if (op == getTailOperator()) { - snapshotInProgress.setResultSubpartitionStateFuture( - channelStateWriteResult - .getResultSubpartitionStateHandles() - .thenApply(StateObjectCollection::new) - .thenApply(SnapshotResult::of)); - } - } - - public boolean isClosed() { - return isClosed; - } - - /** Wrapper class to access the chained sources and their's outputs. 
*/ - public static class ChainedSource { - private final WatermarkGaugeExposingOutput> chainedSourceOutput; - private final StreamTaskSourceInput sourceTaskInput; - - public ChainedSource( - WatermarkGaugeExposingOutput> chainedSourceOutput, - StreamTaskSourceInput sourceTaskInput) { - this.chainedSourceOutput = chainedSourceOutput; - this.sourceTaskInput = sourceTaskInput; - } - - public WatermarkGaugeExposingOutput> getSourceOutput() { - return chainedSourceOutput; - } - - public StreamTaskSourceInput getSourceTaskInput() { - return sourceTaskInput; - } - } - - // ------------------------------------------------------------------------ - // initialization utilities - // ------------------------------------------------------------------------ - - private void createChainOutputs( - List outputsInOrder, - RecordWriterDelegate>> recordWriterDelegate, - Map chainedConfigs, - StreamTask containingTask, - Map> recordWriterOutputs) { - for (int i = 0; i < outputsInOrder.size(); ++i) { - NonChainedOutput output = outputsInOrder.get(i); - - RecordWriterOutput recordWriterOutput = - createStreamOutput( - recordWriterDelegate.getRecordWriter(i), - output, - chainedConfigs.get(output.getSourceNodeId()), - containingTask.getEnvironment()); - - this.streamOutputs[i] = recordWriterOutput; - recordWriterOutputs.put(output.getDataSetId(), recordWriterOutput); - } - } - - private RecordWriterOutput createStreamOutput( - RecordWriter>> recordWriter, - NonChainedOutput streamOutput, - StreamConfig upStreamConfig, - Environment taskEnvironment) { - OutputTag sideOutputTag = - streamOutput.getOutputTag(); // OutputTag, return null if not sideOutput - - TypeSerializer outSerializer; - - if (streamOutput.getOutputTag() != null) { - // side output - outSerializer = - upStreamConfig.getTypeSerializerSideOut( - streamOutput.getOutputTag(), - taskEnvironment.getUserCodeClassLoader().asClassLoader()); - } else { - // main output - outSerializer = - upStreamConfig.getTypeSerializerOut( - 
taskEnvironment.getUserCodeClassLoader().asClassLoader()); - } - - return closer.register( - new RecordWriterOutput( - recordWriter, - outSerializer, - sideOutputTag, - streamOutput.supportsUnalignedCheckpoints())); - } - - @SuppressWarnings("rawtypes") - private Map createChainedSources( - StreamTask containingTask, - StreamConfig.InputConfig[] configuredInputs, - Map chainedConfigs, - ClassLoader userCodeClassloader, - List> allOpWrappers) { - if (Arrays.stream(configuredInputs) - .noneMatch(input -> input instanceof StreamConfig.SourceInputConfig)) { - return Collections.emptyMap(); - } - checkState( - mainOperatorWrapper.getStreamOperator() instanceof MultipleInputStreamOperator, - "Creating chained input is only supported with MultipleInputStreamOperator and MultipleInputStreamTask"); - Map chainedSourceInputs = new HashMap<>(); - MultipleInputStreamOperator multipleInputOperator = - (MultipleInputStreamOperator) mainOperatorWrapper.getStreamOperator(); - List operatorInputs = multipleInputOperator.getInputs(); - - int sourceInputGateIndex = - Arrays.stream(containingTask.getEnvironment().getAllInputGates()) - .mapToInt(IndexedInputGate::getInputGateIndex) - .max() - .orElse(-1) - + 1; - - for (int inputId = 0; inputId < configuredInputs.length; inputId++) { - if (!(configuredInputs[inputId] instanceof StreamConfig.SourceInputConfig)) { - continue; - } - StreamConfig.SourceInputConfig sourceInput = - (StreamConfig.SourceInputConfig) configuredInputs[inputId]; - int sourceEdgeId = sourceInput.getInputEdge().getSourceId(); - StreamConfig sourceInputConfig = chainedConfigs.get(sourceEdgeId); - OutputTag outputTag = sourceInput.getInputEdge().getOutputTag(); - - WatermarkGaugeExposingOutput chainedSourceOutput = - createChainedSourceOutput( - containingTask, - sourceInputConfig, - userCodeClassloader, - getFinishedOnRestoreInputOrDefault(operatorInputs.get(inputId)), - multipleInputOperator.getMetricGroup(), - outputTag); - - SourceOperator sourceOperator = - 
(SourceOperator) - createOperator( - containingTask, - sourceInputConfig, - userCodeClassloader, - (WatermarkGaugeExposingOutput>) chainedSourceOutput, - allOpWrappers, - true); - chainedSourceInputs.put( - sourceInput, - new ChainedSource( - chainedSourceOutput, - this.isTaskDeployedAsFinished() - ? new StreamTaskFinishedOnRestoreSourceInput<>( - sourceOperator, sourceInputGateIndex++, inputId) - : new StreamTaskSourceInput<>(sourceOperator, sourceInputGateIndex++, inputId))); - } - return chainedSourceInputs; - } - - /** - * Get the numRecordsOut counter for the operator represented by the given config. And re-use the - * operator-level counter for the task-level numRecordsOut counter if this operator is at the end - * of the operator chain. - * - *

Return null if we should not use the numRecordsOut counter to track the records emitted by - * this operator. - */ - @Nullable - private Counter getOperatorRecordsOutCounter( - StreamTask containingTask, StreamConfig operatorConfig) { - ClassLoader userCodeClassloader = containingTask.getUserCodeClassLoader(); - Class> streamOperatorFactoryClass = - operatorConfig.getStreamOperatorFactoryClass(userCodeClassloader); - - // Do not use the numRecordsOut counter on output if this operator is SinkWriterOperator. - // - // Metric "numRecordsOut" is defined as the total number of records written to the - // external system in FLIP-33, but this metric is occupied in AbstractStreamOperator as the - // number of records sent to downstream operators, which is number of Committable batches - // sent to SinkCommitter. So we skip registering this metric on output and leave this metric - // to sink writer implementations to report. - try { - Class sinkWriterFactoryClass = - userCodeClassloader.loadClass(SinkWriterOperatorFactory.class.getName()); - if (sinkWriterFactoryClass.isAssignableFrom(streamOperatorFactoryClass)) { - return null; - } - } catch (ClassNotFoundException e) { - throw new StreamTaskException( - "Could not load SinkWriterOperatorFactory class from userCodeClassloader.", e); - } - - InternalOperatorMetricGroup operatorMetricGroup = - containingTask - .getEnvironment() - .getMetricGroup() - .getOrAddOperator(operatorConfig.getOperatorID(), operatorConfig.getOperatorName()); - - return operatorMetricGroup.getIOMetricGroup().getNumRecordsOutCounter(); - } - - @SuppressWarnings({"rawtypes", "unchecked"}) - private WatermarkGaugeExposingOutput createChainedSourceOutput( - StreamTask containingTask, - StreamConfig sourceInputConfig, - ClassLoader userCodeClassloader, - Input input, - OperatorMetricGroup metricGroup, - OutputTag outputTag) { - - Counter recordsOutCounter = getOperatorRecordsOutCounter(containingTask, sourceInputConfig); - - WatermarkGaugeExposingOutput 
chainedSourceOutput; - if (containingTask.getExecutionConfig().isObjectReuseEnabled()) { - chainedSourceOutput = new ChainingOutput(input, recordsOutCounter, metricGroup, outputTag); - } else { - TypeSerializer inSerializer = sourceInputConfig.getTypeSerializerOut(userCodeClassloader); - chainedSourceOutput = - new CopyingChainingOutput(input, inSerializer, recordsOutCounter, metricGroup, outputTag); - } - /** - * Chained sources are closed when {@link - * org.apache.flink.streaming.runtime.io.StreamTaskSourceInput} are being closed. - */ - return closer.register(chainedSourceOutput); - } - - private WatermarkGaugeExposingOutput> createOutputCollector( - StreamTask containingTask, - StreamConfig operatorConfig, - Map chainedConfigs, - ClassLoader userCodeClassloader, - Map> recordWriterOutputs, - List> allOperatorWrappers, - MailboxExecutorFactory mailboxExecutorFactory, - boolean shouldAddMetric) { - // --- Begin Gluten-specific code changes --- - List>> allOutputs = new ArrayList<>(4); - Map>> glutenOutputs = new HashMap<>(); - - Map node2outputs = - Utils.getNodeToNonChainedOutputs(operatorConfig, userCodeClassloader); - // create collectors for the network outputs - for (NonChainedOutput streamOutput : - operatorConfig.getOperatorNonChainedOutputs(userCodeClassloader)) { - @SuppressWarnings("unchecked") - RecordWriterOutput recordWriterOutput = - (RecordWriterOutput) recordWriterOutputs.get(streamOutput.getDataSetId()); - - allOutputs.add(recordWriterOutput); - glutenOutputs.put(node2outputs.get(streamOutput.getDataSetId()), recordWriterOutput); - } - - // Create collectors for the chained outputs - for (StreamEdge outputEdge : operatorConfig.getChainedOutputs(userCodeClassloader)) { - int outputId = outputEdge.getTargetId(); - StreamConfig chainedOpConfig = chainedConfigs.get(outputId); - - WatermarkGaugeExposingOutput> output = - createOperatorChain( - containingTask, - operatorConfig, - chainedOpConfig, - chainedConfigs, - userCodeClassloader, - 
recordWriterOutputs, - allOperatorWrappers, - outputEdge.getOutputTag(), - mailboxExecutorFactory, - shouldAddMetric); - checkState(output instanceof OutputWithChainingCheck); - allOutputs.add((OutputWithChainingCheck) output); - - // If the operator has multiple downstream chained operators, only one of them should - // increment the recordsOutCounter for this operator. Set shouldAddMetric to false - // so that we would skip adding the counter to other downstream operators. - shouldAddMetric = false; - } - - WatermarkGaugeExposingOutput> result; - - if (allOutputs.size() == 1) { - result = allOutputs.get(0); - // only if this is a single RecordWriterOutput, reuse its numRecordOut for task. - if (result instanceof RecordWriterOutput) { - Counter numRecordsOutCounter = createNumRecordsOutCounter(containingTask); - ((RecordWriterOutput) result).setNumRecordsOut(numRecordsOutCounter); - } - } else { - if (glutenOutputs.size() > 0 && allOutputs.size() != glutenOutputs.size()) { - throw new RuntimeException("Number of outputs and gluten outputs do not match."); - } - // TODO: add counter - result = closer.register(new GlutenOutputCollector<>(glutenOutputs, null)); - } - // --- End Gluten-specific code changes --- - - if (shouldAddMetric) { - // Create a CountingOutput to increment the recordsOutCounter for this operator - // if we have not added the counter to any downstream chained operator. 
- Counter recordsOutCounter = getOperatorRecordsOutCounter(containingTask, operatorConfig); - if (recordsOutCounter != null) { - result = new CountingOutput<>(result, recordsOutCounter); - } - } - return result; - } - - private static Counter createNumRecordsOutCounter(StreamTask containingTask) { - TaskIOMetricGroup taskIOMetricGroup = - containingTask.getEnvironment().getMetricGroup().getIOMetricGroup(); - Counter counter = new SimpleCounter(); - taskIOMetricGroup.reuseRecordsOutputCounter(counter); - return counter; - } - - /** - * Recursively create chain of operators that starts from the given {@param operatorConfig}. - * Operators are created tail to head and wrapped into an {@link WatermarkGaugeExposingOutput}. - */ - private WatermarkGaugeExposingOutput> createOperatorChain( - StreamTask containingTask, - StreamConfig prevOperatorConfig, - StreamConfig operatorConfig, - Map chainedConfigs, - ClassLoader userCodeClassloader, - Map> recordWriterOutputs, - List> allOperatorWrappers, - OutputTag outputTag, - MailboxExecutorFactory mailboxExecutorFactory, - boolean shouldAddMetricForPrevOperator) { - // create the output that the operator writes to first. this may recursively create more - // operators - WatermarkGaugeExposingOutput> chainedOperatorOutput = - createOutputCollector( - containingTask, - operatorConfig, - chainedConfigs, - userCodeClassloader, - recordWriterOutputs, - allOperatorWrappers, - mailboxExecutorFactory, - true); - - OneInputStreamOperator chainedOperator = - createOperator( - containingTask, - operatorConfig, - userCodeClassloader, - chainedOperatorOutput, - allOperatorWrappers, - false); - - return wrapOperatorIntoOutput( - chainedOperator, - containingTask, - prevOperatorConfig, - operatorConfig, - userCodeClassloader, - outputTag, - shouldAddMetricForPrevOperator); - } - - /** - * Create and return a single operator from the given {@param operatorConfig} that will be - * producing records to the {@param output}. 
- */ - private > OP createOperator( - StreamTask containingTask, - StreamConfig operatorConfig, - ClassLoader userCodeClassloader, - WatermarkGaugeExposingOutput> output, - List> allOperatorWrappers, - boolean isHead) { - - // now create the operator and give it the output collector to write its output to - Tuple2> chainedOperatorAndTimeService = - StreamOperatorFactoryUtil.createOperator( - operatorConfig.getStreamOperatorFactory(userCodeClassloader), - containingTask, - operatorConfig, - output, - operatorEventDispatcher); - - OP chainedOperator = chainedOperatorAndTimeService.f0; - allOperatorWrappers.add( - createOperatorWrapper( - chainedOperator, - containingTask, - operatorConfig, - chainedOperatorAndTimeService.f1, - isHead)); - - chainedOperator - .getMetricGroup() - .gauge(MetricNames.IO_CURRENT_OUTPUT_WATERMARK, output.getWatermarkGauge()::getValue); - return chainedOperator; - } - - private WatermarkGaugeExposingOutput> wrapOperatorIntoOutput( - OneInputStreamOperator operator, - StreamTask containingTask, - StreamConfig prevOperatorConfig, - StreamConfig operatorConfig, - ClassLoader userCodeClassloader, - OutputTag outputTag, - boolean shouldAddMetricForPrevOperator) { - - Counter recordsOutCounter = null; - - if (shouldAddMetricForPrevOperator) { - recordsOutCounter = getOperatorRecordsOutCounter(containingTask, prevOperatorConfig); - } - - WatermarkGaugeExposingOutput> currentOperatorOutput; - if (containingTask.getExecutionConfig().isObjectReuseEnabled()) { - currentOperatorOutput = - new ChainingOutput<>(operator, recordsOutCounter, operator.getMetricGroup(), outputTag); - } else { - TypeSerializer inSerializer = operatorConfig.getTypeSerializerIn1(userCodeClassloader); - currentOperatorOutput = - new CopyingChainingOutput<>( - operator, inSerializer, recordsOutCounter, operator.getMetricGroup(), outputTag); - } - - // wrap watermark gauges since registered metrics must be unique - operator - .getMetricGroup() - .gauge( - 
MetricNames.IO_CURRENT_INPUT_WATERMARK, - currentOperatorOutput.getWatermarkGauge()::getValue); - - return closer.register(currentOperatorOutput); - } - - /** - * Links operator wrappers in forward topological order. - * - * @param allOperatorWrappers is an operator wrapper list of reverse topological order - */ - private StreamOperatorWrapper linkOperatorWrappers( - List> allOperatorWrappers) { - StreamOperatorWrapper previous = null; - for (StreamOperatorWrapper current : allOperatorWrappers) { - if (previous != null) { - previous.setPrevious(current); - } - current.setNext(previous); - previous = current; - } - return previous; - } - - private > StreamOperatorWrapper createOperatorWrapper( - P operator, - StreamTask containingTask, - StreamConfig operatorConfig, - Optional processingTimeService, - boolean isHead) { - return new StreamOperatorWrapper<>( - operator, - processingTimeService, - containingTask.getMailboxExecutorFactory().createExecutor(operatorConfig.getChainIndex()), - isHead); - } - - protected void sendAcknowledgeCheckpointEvent(long checkpointId) { - if (operatorEventDispatcher == null) { - return; - } - - operatorEventDispatcher - .getRegisteredOperators() - .forEach( - x -> - operatorEventDispatcher - .getOperatorEventGateway(x) - .sendEventToCoordinator(new AcknowledgeCheckpointEvent(checkpointId))); - } -} diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/client/OffloadedJobGraphGenerator.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/client/OffloadedJobGraphGenerator.java new file mode 100644 index 000000000000..cd04b988554b --- /dev/null +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/client/OffloadedJobGraphGenerator.java @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.client; + +import org.apache.gluten.streaming.api.operators.GlutenOneInputOperatorFactory; +import org.apache.gluten.streaming.api.operators.GlutenOperator; +import org.apache.gluten.streaming.api.operators.GlutenStreamSource; +import org.apache.gluten.table.runtime.keyselector.GlutenKeySelector; +import org.apache.gluten.table.runtime.operators.GlutenOneInputOperator; +import org.apache.gluten.table.runtime.operators.GlutenSourceFunction; +import org.apache.gluten.table.runtime.operators.GlutenTwoInputOperator; +import org.apache.gluten.table.runtime.typeutils.GlutenStatefulRecordSerializer; + +import io.github.zhztheplayer.velox4j.plan.StatefulPlanNode; +import io.github.zhztheplayer.velox4j.stateful.StatefulRecord; +import io.github.zhztheplayer.velox4j.type.RowType; + +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.runtime.jobgraph.JobVertex; +import org.apache.flink.streaming.api.graph.StreamConfig; +import org.apache.flink.streaming.api.graph.StreamEdge; +import org.apache.flink.streaming.api.graph.StreamNode; +import org.apache.flink.streaming.api.operators.SimpleOperatorFactory; +import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; 
+import org.apache.flink.table.data.RowData; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/* + * If a operator is offloadable + * - update its input/output serializers as needed. + * - update its key selectors as needed. + * - coalesce it with its siblings as needed. Also need to update the stream edges. + */ +public class OffloadedJobGraphGenerator { + private static final Logger LOG = LoggerFactory.getLogger(OffloadedJobGraphGenerator.class); + private final JobGraph jobGraph; + private final ClassLoader userClassloader; + private boolean hasGenerated = false; + + public OffloadedJobGraphGenerator(JobGraph jobGraph, ClassLoader userClassloader) { + this.jobGraph = jobGraph; + this.userClassloader = userClassloader; + } + + public JobGraph generate() { + if (hasGenerated) { + throw new IllegalStateException("JobGraph has been generated."); + } + hasGenerated = true; + for (JobVertex jobVertex : jobGraph.getVertices()) { + offloadJobVertex(jobVertex); + } + return jobGraph; + } + + private void offloadJobVertex(JobVertex jobVertex) { + OperatorChainSliceGraphGenerator graphGenerator = + new OperatorChainSliceGraphGenerator(jobVertex, userClassloader); + OperatorChainSliceGraph chainSliceGraph = graphGenerator.getGraph(); + chainSliceGraph.dumpLog(); + + OperatorChainSlice sourceChainSlice = chainSliceGraph.getSourceSlice(); + OperatorChainSliceGraph targetChainSliceGraph = new OperatorChainSliceGraph(); + visitAndOffloadChainOperators(sourceChainSlice, chainSliceGraph, targetChainSliceGraph, 0); + visitAndUpdateStreamEdges(sourceChainSlice, chainSliceGraph, targetChainSliceGraph); + serializeAllOperatorsConfigs(targetChainSliceGraph); + + StreamConfig sourceConfig = sourceChainSlice.getOperatorConfigs().get(0); + StreamConfig targetSourceConfig = + 
targetChainSliceGraph.getSlice(sourceChainSlice.id()).getOperatorConfigs().get(0); + + Map chainedConfig = new HashMap(); + if (sourceChainSlice.isOffloadable()) { + // Update the first operator config + sourceConfig.setStreamOperatorFactory( + targetSourceConfig.getStreamOperatorFactory(userClassloader)); + List chainedOutputs = targetSourceConfig.getChainedOutputs(userClassloader); + sourceConfig.setChainedOutputs(targetSourceConfig.getChainedOutputs(userClassloader)); + + // Update the serializers and partitioners + sourceConfig.setTypeSerializerOut(targetSourceConfig.getTypeSerializerOut(userClassloader)); + sourceConfig.setInputs(targetSourceConfig.getInputs(userClassloader)); + KeySelector keySelector0 = targetSourceConfig.getStatePartitioner(0, userClassloader); + if (keySelector0 != null) { + sourceConfig.setStatePartitioner(0, keySelector0); + } + KeySelector keySelector1 = targetSourceConfig.getStatePartitioner(1, userClassloader); + if (keySelector1 != null) { + sourceConfig.setStatePartitioner(1, keySelector1); + } + + // The chained operators should be empty. 
+ } else { + List operatorConfigs = sourceChainSlice.getOperatorConfigs(); + for (int i = 1; i < operatorConfigs.size(); i++) { + StreamConfig opConfig = operatorConfigs.get(i); + chainedConfig.put(opConfig.getVertexID(), opConfig); + } + } + for (OperatorChainSlice chainSlice : targetChainSliceGraph.getSlices().values()) { + if (chainSlice.id().equals(sourceChainSlice.id())) { + continue; + } + List operatorConfigs = chainSlice.getOperatorConfigs(); + for (StreamConfig opConfig : operatorConfigs) { + chainedConfig.put(opConfig.getVertexID(), opConfig); + } + } + sourceConfig.setAndSerializeTransitiveChainedTaskConfigs(chainedConfig); + sourceConfig.serializeAllConfigs(); + } + + // Fold offloadable operator chain slice + private void visitAndOffloadChainOperators( + OperatorChainSlice chainSlice, + OperatorChainSliceGraph originalChainSliceGraph, + OperatorChainSliceGraph targetChainSliceGraph, + Integer chainedIndex) { + List outputs = chainSlice.getOutputs(); + List outputIndex = new ArrayList<>(); + OperatorChainSlice finalChainSlice = null; + if (chainSlice.isOffloadable()) { + finalChainSlice = + OffloadOperatorChainSlice(originalChainSliceGraph, chainSlice, chainedIndex); + chainedIndex = chainedIndex + 1; + } else { + finalChainSlice = applyUnoffloadableOperatorChainSlice(chainSlice, chainedIndex); + chainedIndex = chainedIndex + chainSlice.getOperatorConfigs().size(); + } + + finalChainSlice.getInputs().addAll(chainSlice.getInputs()); + finalChainSlice.getOutputs().addAll(chainSlice.getOutputs()); + targetChainSliceGraph.addSlice(chainSlice.id(), finalChainSlice); + + for (Integer outputChainIndex : outputs) { + OperatorChainSlice outputChainSlice = originalChainSliceGraph.getSlice(outputChainIndex); + OperatorChainSlice outputResultChainSlice = targetChainSliceGraph.getSlice(outputChainIndex); + if (outputResultChainSlice == null) { + visitAndOffloadChainOperators( + outputChainSlice, originalChainSliceGraph, targetChainSliceGraph, chainedIndex); + } + } 
+ } + + // Keep the original operator chain slice as is. + private OperatorChainSlice applyUnoffloadableOperatorChainSlice( + OperatorChainSlice originalChainSlice, Integer chainedIndex) { + OperatorChainSlice finalChainSlice = new OperatorChainSlice(originalChainSlice.id()); + List operatorConfigs = originalChainSlice.getOperatorConfigs(); + for (StreamConfig opConfig : operatorConfigs) { + StreamConfig newOpConfig = new StreamConfig(new Configuration(opConfig.getConfiguration())); + newOpConfig.setChainIndex(chainedIndex); + finalChainSlice.getOperatorConfigs().add(newOpConfig); + } + finalChainSlice.setOffloadable(false); + return finalChainSlice; + } + + // Fold offloadable operator chain slice, and update the input/output channel serializers + private OperatorChainSlice OffloadOperatorChainSlice( + OperatorChainSliceGraph chainSliceGraph, + OperatorChainSlice originalChainSlice, + Integer chainedIndex) { + OperatorChainSlice finalChainSlice = new OperatorChainSlice(originalChainSlice.id()); + List operatorConfigs = originalChainSlice.getOperatorConfigs(); + + // May coalesce multiple operators into one in the future. + if (operatorConfigs.size() != 1) { + throw new UnsupportedOperationException( + "Only one operator is supported for offloaded operator chain slice."); + } + + StreamConfig originalOpConfig = operatorConfigs.get(0); + GlutenOperator originalOp = getGlutenOperator(originalOpConfig).get(); + StatefulPlanNode planNode = originalOp.getPlanNode(); + // Create a new operator config for the offloaded operator. + StreamConfig finalOpConfig = + new StreamConfig(new Configuration(originalOpConfig.getConfiguration())); + if (originalOp instanceof GlutenStreamSource) { + boolean couldOutputRowVector = couldOutputRowVector(originalChainSlice, chainSliceGraph); + Class outClass = couldOutputRowVector ? 
StatefulRecord.class : RowData.class; + GlutenStreamSource newSourceOp = + new GlutenStreamSource( + new GlutenSourceFunction<>( + planNode, + originalOp.getOutputTypes(), + originalOp.getId(), + ((GlutenStreamSource) originalOp).getConnectorSplit(), + outClass)); + finalOpConfig.setStreamOperator(newSourceOp); + if (couldOutputRowVector) { + RowType rowType = originalOp.getOutputTypes().entrySet().iterator().next().getValue(); + finalOpConfig.setTypeSerializerOut( + new GlutenStatefulRecordSerializer(rowType, originalOp.getId())); + } + } else if (originalOp instanceof GlutenOneInputOperator) { + boolean couldOutputRowVector = couldOutputRowVector(originalChainSlice, chainSliceGraph); + boolean couldInputRowVector = couldInputRowVector(originalChainSlice, chainSliceGraph); + Class inClass = couldInputRowVector ? StatefulRecord.class : RowData.class; + Class outClass = couldOutputRowVector ? StatefulRecord.class : RowData.class; + GlutenOneInputOperator newOneInputOp = + new GlutenOneInputOperator<>( + planNode, + originalOp.getId(), + originalOp.getInputType(), + originalOp.getOutputTypes(), + inClass, + outClass, + originalOp.getDescription()); + finalOpConfig.setStreamOperator(newOneInputOp); + if (couldOutputRowVector) { + RowType rowType = originalOp.getOutputTypes().entrySet().iterator().next().getValue(); + finalOpConfig.setTypeSerializerOut( + new GlutenStatefulRecordSerializer(rowType, originalOp.getId())); + } + if (couldInputRowVector) { + finalOpConfig.setupNetworkInputs( + new GlutenStatefulRecordSerializer(originalOp.getInputType(), originalOp.getId())); + + // This node is the first node in the chain. If it has a state partitioner, we need to + // change it to GlutenKeySelector. 
+ KeySelector keySelector = originalOpConfig.getStatePartitioner(0, userClassloader); + if (keySelector != null) { + LOG.info( + "State partitioner ({}) found in the first node {}, change it to GlutenKeySelector.", + keySelector.getClass().getName(), + originalOp.getDescription()); + finalOpConfig.setStatePartitioner(0, new GlutenKeySelector()); + } + } + } else if (originalOp instanceof GlutenTwoInputOperator) { + GlutenTwoInputOperator twoInputOp = (GlutenTwoInputOperator) originalOp; + boolean couldOutputRowVector = couldOutputRowVector(originalChainSlice, chainSliceGraph); + boolean couldInputRowVector = couldInputRowVector(originalChainSlice, chainSliceGraph); + KeySelector keySelector0 = originalOpConfig.getStatePartitioner(0, userClassloader); + if (keySelector0 != null) { + LOG.info( + "State partitioner ({}) found in the first node {}, change it to GlutenKeySelector.", + keySelector0.getClass().getName(), + originalOp.getDescription()); + finalOpConfig.setStatePartitioner(0, new GlutenKeySelector()); + } + KeySelector keySelector1 = originalOpConfig.getStatePartitioner(1, userClassloader); + if (keySelector1 != null) { + LOG.info( + "State partitioner ({}) found in the second node {}, change it to GlutenKeySelector.", + keySelector1.getClass().getName(), + originalOp.getDescription()); + finalOpConfig.setStatePartitioner(1, new GlutenKeySelector()); + } + Class inClass = couldInputRowVector ? StatefulRecord.class : RowData.class; + Class outClass = couldOutputRowVector ? 
StatefulRecord.class : RowData.class; + GlutenTwoInputOperator newTwoInputOp = + new GlutenTwoInputOperator<>( + planNode, + twoInputOp.getLeftId(), + twoInputOp.getRightId(), + twoInputOp.getLeftInputType(), + twoInputOp.getRightInputType(), + twoInputOp.getOutputTypes(), + inClass, + outClass); + finalOpConfig.setStreamOperator(newTwoInputOp); + finalOpConfig.setStatePartitioner(0, new GlutenKeySelector()); + finalOpConfig.setStatePartitioner(1, new GlutenKeySelector()); + // Update the output channel serializer + if (couldOutputRowVector) { + RowType rowType = twoInputOp.getOutputTypes().entrySet().iterator().next().getValue(); + finalOpConfig.setTypeSerializerOut( + new GlutenStatefulRecordSerializer(rowType, twoInputOp.getId())); + } + // Update the input channel serializers + if (couldInputRowVector) { + finalOpConfig.setupNetworkInputs( + new GlutenStatefulRecordSerializer(twoInputOp.getLeftInputType(), twoInputOp.getId()), + new GlutenStatefulRecordSerializer(twoInputOp.getRightInputType(), twoInputOp.getId())); + } + } else { + throw new UnsupportedOperationException( + "Only GlutenStreamSource is supported for offloaded operator chain slice."); + } + + finalOpConfig.setChainIndex(chainedIndex); + finalChainSlice.getOperatorConfigs().add(finalOpConfig); + finalChainSlice.setOffloadable(true); + return finalChainSlice; + } + + private StreamNode mockStreamNode(StreamConfig streamConfig) { + return new StreamNode( + streamConfig.getVertexID(), + null, + null, + (StreamOperatorFactory) streamConfig.getStreamOperatorFactory(userClassloader), + streamConfig.getOperatorName(), + null); + } + + // Incase the vetexs has been changed, update the stream edges. 
+ private void visitAndUpdateStreamEdges( + OperatorChainSlice originalChainSlice, + OperatorChainSliceGraph originalChainSliceGraph, + OperatorChainSliceGraph targetChainSliceGraph) { + OperatorChainSlice targetChainSlice = targetChainSliceGraph.getSlice(originalChainSlice.id()); + if (targetChainSlice.isOffloadable()) { + List outputIDs = originalChainSlice.getOutputs(); + List operatorConfigs = targetChainSlice.getOperatorConfigs(); + StreamConfig targetOpConfig = operatorConfigs.get(0); + if (outputIDs.size() == 0) { + targetOpConfig.setChainedOutputs(new ArrayList<>()); + return; + } + List newOutputEdges = new ArrayList<>(); + List originalOutputEdges = + originalChainSlice + .getOperatorConfigs() + .get(originalChainSlice.getOperatorConfigs().size() - 1) + .getChainedOutputs(userClassloader); + for (int i = 0; i < outputIDs.size(); i++) { + Integer outputID = outputIDs.get(i); + OperatorChainSlice outputOriginalChainSlice = originalChainSliceGraph.getSlice(outputID); + OperatorChainSlice outputTargetChainSlice = targetChainSliceGraph.getSlice(outputID); + StreamConfig outputOpConfig = + outputTargetChainSlice + .getOperatorConfigs() + .get(0); // The first operator config is the representative. 
+ StreamEdge originalEdge = originalOutputEdges.get(i); + StreamEdge newEdge = + new StreamEdge( + mockStreamNode(targetOpConfig), + mockStreamNode(outputOpConfig), + originalEdge.getTypeNumber(), + originalEdge.getBufferTimeout(), + originalEdge.getPartitioner(), + originalEdge.getOutputTag(), + originalEdge.getExchangeMode(), + 0, + originalEdge.getIntermediateDatasetIdToProduce()); + newOutputEdges.add(newEdge); + } + targetOpConfig.setChainedOutputs(newOutputEdges); + } + + for (Integer outputChain : originalChainSlice.getOutputs()) { + visitAndUpdateStreamEdges( + originalChainSliceGraph.getSlice(outputChain), + originalChainSliceGraph, + targetChainSliceGraph); + } + } + + void serializeAllOperatorsConfigs(OperatorChainSliceGraph chainSliceGraph) { + for (OperatorChainSlice chainSlice : chainSliceGraph.getSlices().values()) { + List operatorConfigs = chainSlice.getOperatorConfigs(); + for (StreamConfig opConfig : operatorConfigs) { + opConfig.serializeAllConfigs(); + } + } + } + + private Optional getGlutenOperator(StreamConfig taskConfig) { + StreamOperatorFactory operatorFactory = taskConfig.getStreamOperatorFactory(userClassloader); + if (operatorFactory instanceof SimpleOperatorFactory) { + StreamOperator streamOperator = taskConfig.getStreamOperator(userClassloader); + if (streamOperator instanceof GlutenOperator) { + return Optional.of((GlutenOperator) streamOperator); + } + } else if (operatorFactory instanceof GlutenOneInputOperatorFactory) { + return Optional.of(((GlutenOneInputOperatorFactory) operatorFactory).getOperator()); + } + return Optional.empty(); + } + + boolean isAllOffloadable(OperatorChainSliceGraph chainSliceGraph, List chainIDs) { + for (Integer chainID : chainIDs) { + OperatorChainSlice chainSlice = chainSliceGraph.getSlice(chainID); + if (!chainSlice.isOffloadable()) { + return false; + } + } + return true; + } + + boolean couldOutputRowVector( + OperatorChainSlice chainSlice, OperatorChainSliceGraph chainSliceGraph) { + boolean 
could = true; + for (Integer outputID : chainSlice.getOutputs()) { + OperatorChainSlice outputChainSlice = chainSliceGraph.getSlice(outputID); + if (!outputChainSlice.isOffloadable()) { + could = false; + break; + } + List inputs = outputChainSlice.getInputs(); + if (!isAllOffloadable(chainSliceGraph, inputs)) { + could = false; + break; + } + } + return could; + } + + boolean couldInputRowVector( + OperatorChainSlice chainSlice, OperatorChainSliceGraph chainSliceGraph) { + boolean could = true; + for (Integer inputID : chainSlice.getInputs()) { + OperatorChainSlice inputChainSlice = chainSliceGraph.getSlice(inputID); + if (!inputChainSlice.isOffloadable()) { + could = false; + break; + } + if (!couldOutputRowVector(inputChainSlice, chainSliceGraph)) { + could = false; + break; + } + } + return could; + } +} diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/client/OperatorChainSlice.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/client/OperatorChainSlice.java new file mode 100644 index 000000000000..91b96d811d04 --- /dev/null +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/client/OperatorChainSlice.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * A contiguous run of chained operators ("slice") carved out of a Flink operator chain for
 * offloading decisions. Every operator inside one slice shares the same verdict: either all of
 * them can be offloaded to the native engine, or none can.
 */
public class OperatorChainSlice {
  /** Ids of the upstream slices feeding this slice. */
  private final List inputs = new ArrayList<>();
  /** Ids of the downstream slices fed by this slice. */
  private final List outputs = new ArrayList<>();
  /** Configs of the operators belonging to this slice (elements: StreamConfig). */
  private final List operatorConfigs = new ArrayList<>();
  /** Identifier of this slice within its graph. */
  private final Integer id;
  /** Whether every operator of this slice can be offloaded; defaults to false. */
  private Boolean offloadable = false;

  public OperatorChainSlice(Integer id) {
    this.id = id;
  }

  public Integer id() {
    return id;
  }

  public List getInputs() {
    return inputs;
  }

  public List getOutputs() {
    return outputs;
  }

  public List getOperatorConfigs() {
    return operatorConfigs;
  }

  public Boolean isOffloadable() {
    return offloadable;
  }

  public void setOffloadable(Boolean offloadable) {
    this.offloadable = offloadable;
  }
}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.client; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class OperatorChainSliceGraph { + private static final Logger LOG = LoggerFactory.getLogger(OperatorChainSliceGraph.class); + private Map slices; + + public OperatorChainSliceGraph() { + slices = new HashMap<>(); + } + + public void addSlice(Integer id, OperatorChainSlice chainSlice) { + slices.put(id, chainSlice); + } + + public OperatorChainSlice getSlice(Integer id) { + return slices.get(id); + } + + public void removeSlice(Integer id) { + slices.remove(id); + } + + public OperatorChainSlice getSourceSlice() { + List sourceCandidates = new ArrayList<>(); + + for (OperatorChainSlice chainSlice : slices.values()) { + if (chainSlice.getInputs().isEmpty()) { + sourceCandidates.add(chainSlice); + } + } + + if (sourceCandidates.isEmpty()) { + throw new IllegalStateException( + "No source suboperator chain found (no suboperator chain with empty inputs)"); + } else if (sourceCandidates.size() > 1) { + throw new IllegalStateException( + "Multiple source suboperator chains found: " + + sourceCandidates.size() + + " suboperator chains have empty inputs"); + } + + return sourceCandidates.get(0); + } + + public Map getSlices() { + return slices; + } + + public void dumpLog() { + for (OperatorChainSlice chainSlice : slices.values()) { + LOG.info("Slice ID: {}, offloadable: {}", chainSlice.id(), chainSlice.isOffloadable()); + LOG.info(" Inputs: {}", 
chainSlice.getInputs().toString()); + LOG.info(" Outputs: {}", chainSlice.getOutputs().toString()); + LOG.info( + " Operator Configs: {}", + chainSlice.getOperatorConfigs().stream() + .map(config -> config.getOperatorName() + "(" + config.getVertexID() + ")") + .reduce((a, b) -> a + ", " + b) + .orElse("")); + } + } +} diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/client/OperatorChainSliceGraphGenerator.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/client/OperatorChainSliceGraphGenerator.java new file mode 100644 index 000000000000..35b22ad54a3f --- /dev/null +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/client/OperatorChainSliceGraphGenerator.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.client; + +import org.apache.gluten.streaming.api.operators.GlutenOneInputOperatorFactory; +import org.apache.gluten.streaming.api.operators.GlutenOperator; + +import org.apache.flink.runtime.jobgraph.JobVertex; +import org.apache.flink.streaming.api.graph.StreamConfig; +import org.apache.flink.streaming.api.graph.StreamEdge; +import org.apache.flink.streaming.api.operators.SimpleOperatorFactory; +import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +class OperatorChainSliceGraphGenerator { + private static final Logger LOG = LoggerFactory.getLogger(OperatorChainSliceGraphGenerator.class); + private OperatorChainSliceGraph chainSliceGraph = null; + private Map> operatorParents; + private JobVertex jobVertex; + private Map chainedConfigs; + private final ClassLoader userClassloader; + + public OperatorChainSliceGraphGenerator(JobVertex jobVertex, ClassLoader userClassloader) { + this.operatorParents = new HashMap<>(); + this.jobVertex = jobVertex; + this.userClassloader = userClassloader; + } + + public OperatorChainSliceGraph getGraph() { + generateInternal(); + return chainSliceGraph; + } + + private void generateInternal() { + if (chainSliceGraph != null) { + return; + } + chainSliceGraph = new OperatorChainSliceGraph(); + + StreamConfig rootOpConfig = new StreamConfig(jobVertex.getConfiguration()); + + chainedConfigs = new HashMap<>(); + rootOpConfig + .getTransitiveChainedTaskConfigs(userClassloader) + .forEach( + (id, config) -> { + chainedConfigs.put(id, new StreamConfig(config.getConfiguration())); + }); + chainedConfigs.put(rootOpConfig.getVertexID(), rootOpConfig); + + collectOperatorParents(rootOpConfig, null); + + OperatorChainSlice chainSlice = new 
OperatorChainSlice(rootOpConfig.getVertexID()); + chainSlice.setOffloadable(isOffloadableOperator(rootOpConfig)); + chainSlice.getOperatorConfigs().add(rootOpConfig); + chainSliceGraph.addSlice(chainSlice.id(), chainSlice); + + advanceOperatorChainSlice(chainSlice, rootOpConfig); + } + + private void advanceOperatorChainSlice( + OperatorChainSlice chainSlice, StreamConfig currentOpConfig) { + List outputEdges = currentOpConfig.getChainedOutputs(userClassloader); + if (outputEdges == null || outputEdges.isEmpty()) { + return; + } + if (outputEdges.size() == 1) { + Integer targetId = outputEdges.get(0).getTargetId(); + StreamConfig childOpConfig = chainedConfigs.get(targetId); + // We don't coalesce operators into the same velox plan at present. Each operator is a + // separate velox plan. + startNewOperatorChainSlice(chainSlice, childOpConfig); + } else { + for (StreamEdge edge : outputEdges) { + Integer targetId = edge.getTargetId(); + StreamConfig childOpConfig = chainedConfigs.get(targetId); + startNewOperatorChainSlice(chainSlice, childOpConfig); + } + } + } + + private void startNewOperatorChainSlice( + OperatorChainSlice parentChainSlice, StreamConfig childOpConfig) { + Boolean isFistVisit = false; + OperatorChainSlice childChainSlice = chainSliceGraph.getSlice(childOpConfig.getVertexID()); + if (childChainSlice == null) { + isFistVisit = true; + childChainSlice = new OperatorChainSlice(childOpConfig.getVertexID()); + } + + parentChainSlice.getOutputs().add(childChainSlice.id()); + childChainSlice.getInputs().add(parentChainSlice.id()); + // If this path has been visited, do not advance again. 
+ if (isFistVisit) { + childChainSlice.setOffloadable(isOffloadableOperator(childOpConfig)); + childChainSlice.getOperatorConfigs().add(childOpConfig); + chainSliceGraph.addSlice(childOpConfig.getVertexID(), childChainSlice); + advanceOperatorChainSlice(childChainSlice, childOpConfig); + } + } + + private void collectOperatorParents(StreamConfig currentOp, StreamConfig parentOp) { + List parents = + operatorParents.computeIfAbsent(currentOp.getVertexID(), k -> new ArrayList<>()); + if (parentOp != null) { + parents.add(parentOp.getVertexID()); + } + + List outputEdges = currentOp.getChainedOutputs(userClassloader); + if (outputEdges == null || outputEdges.isEmpty()) { + return; + } + for (StreamEdge edge : outputEdges) { + Integer targetId = edge.getTargetId(); + StreamConfig childOp = chainedConfigs.get(targetId); + + collectOperatorParents(childOp, currentOp); + } + } + + private boolean isOffloadableOperator(StreamConfig opConfig) { + StreamOperatorFactory operatorFactory = opConfig.getStreamOperatorFactory(userClassloader); + if (operatorFactory instanceof SimpleOperatorFactory) { + StreamOperator streamOperator = opConfig.getStreamOperator(userClassloader); + if (streamOperator instanceof GlutenOperator) { + return true; + } + } else if (operatorFactory instanceof GlutenOneInputOperatorFactory) { + return true; + } + return false; + } + + private void alignOffloadableOperatorChainSlice( + OperatorChainSlice chainSlice, OperatorChainSliceGraph chainSliceGraph) { + List outputIDs = chainSlice.getOutputs(); + for (Integer outputID : outputIDs) { + OperatorChainSlice outputChainSlice = chainSliceGraph.getSlice(outputID); + alignOffloadableOperatorChainSlice(outputChainSlice, chainSliceGraph); + } + if (!chainSlice.isOffloadable()) { + return; + } + + if (outputIDs.size() > 1) { + boolean allOffloadable = true; + boolean hasOffloadable = false; + for (Integer outputID : outputIDs) { + OperatorChainSlice outputChainSlice = chainSliceGraph.getSlice(outputID); + if 
(!outputChainSlice.isOffloadable()) { + allOffloadable = false; + } else { + hasOffloadable = true; + } + } + if (hasOffloadable && !allOffloadable) { + chainSlice.setOffloadable(false); + } + } + + List inputIDs = chainSlice.getInputs(); + if (inputIDs.size() > 1) { + boolean allOffloadable = true; + boolean hasOffloadable = false; + for (Integer inputID : inputIDs) { + OperatorChainSlice inputChainSlice = chainSliceGraph.getSlice(inputID); + if (!inputChainSlice.isOffloadable()) { + allOffloadable = false; + } else { + hasOffloadable = true; + } + } + } + } +} diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/streaming/api/operators/GlutenOperator.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/streaming/api/operators/GlutenOperator.java index 6ed7d0497b69..97205519679d 100644 --- a/gluten-flink/runtime/src/main/java/org/apache/gluten/streaming/api/operators/GlutenOperator.java +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/streaming/api/operators/GlutenOperator.java @@ -30,4 +30,8 @@ public interface GlutenOperator { public Map getOutputTypes(); public String getId(); + + public default String getDescription() { + return ""; + } } diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/streaming/api/operators/GlutenStreamSource.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/streaming/api/operators/GlutenStreamSource.java index 0281def229ea..349261ff4d5c 100644 --- a/gluten-flink/runtime/src/main/java/org/apache/gluten/streaming/api/operators/GlutenStreamSource.java +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/streaming/api/operators/GlutenStreamSource.java @@ -16,7 +16,7 @@ */ package org.apache.gluten.streaming.api.operators; -import org.apache.gluten.table.runtime.operators.GlutenVectorSourceFunction; +import org.apache.gluten.table.runtime.operators.GlutenSourceFunction; import org.apache.gluten.util.ReflectUtils; import org.apache.gluten.util.Utils; @@ -32,11 +32,24 @@ /** Legacy 
stream source operator in gluten, which will call Velox to run. */ public class GlutenStreamSource extends StreamSource implements GlutenOperator { - private final GlutenVectorSourceFunction sourceFunction; + private final GlutenSourceFunction sourceFunction; + private final String description; - public GlutenStreamSource(GlutenVectorSourceFunction function) { + public GlutenStreamSource(GlutenSourceFunction function) { super(function); sourceFunction = function; + this.description = ""; + } + + public GlutenStreamSource(GlutenSourceFunction function, String description) { + super(function); + sourceFunction = function; + this.description = description; + } + + @Override + public String getDescription() { + return description; } @Override diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenOneInputOperator.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenOneInputOperator.java index 09af2f4576ba..2e41c193e83b 100644 --- a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenOneInputOperator.java +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenOneInputOperator.java @@ -19,32 +19,28 @@ import org.apache.gluten.streaming.api.operators.GlutenOperator; import org.apache.gluten.table.runtime.config.VeloxConnectorConfig; import org.apache.gluten.table.runtime.config.VeloxQueryConfig; -import org.apache.gluten.vectorized.FlinkRowToVLVectorConvertor; -import io.github.zhztheplayer.velox4j.Velox4j; import io.github.zhztheplayer.velox4j.connector.ExternalStreamConnectorSplit; import io.github.zhztheplayer.velox4j.connector.ExternalStreamTableHandle; import io.github.zhztheplayer.velox4j.connector.ExternalStreams; -import io.github.zhztheplayer.velox4j.data.RowVector; import io.github.zhztheplayer.velox4j.iterator.UpIterator; -import io.github.zhztheplayer.velox4j.memory.AllocationListener; -import 
io.github.zhztheplayer.velox4j.memory.MemoryManager; import io.github.zhztheplayer.velox4j.plan.StatefulPlanNode; import io.github.zhztheplayer.velox4j.plan.TableScanNode; import io.github.zhztheplayer.velox4j.query.Query; import io.github.zhztheplayer.velox4j.query.SerialTask; import io.github.zhztheplayer.velox4j.serde.Serde; -import io.github.zhztheplayer.velox4j.session.Session; import io.github.zhztheplayer.velox4j.stateful.StatefulElement; +import io.github.zhztheplayer.velox4j.stateful.StatefulRecord; +import io.github.zhztheplayer.velox4j.stateful.StatefulWatermark; import io.github.zhztheplayer.velox4j.type.RowType; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.table.data.RowData; import org.apache.flink.table.runtime.operators.TableStreamOperator; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -52,8 +48,8 @@ import java.util.Map; /** Calculate operator in gluten, which will call Velox to run. 
*/ -public class GlutenOneInputOperator extends TableStreamOperator - implements OneInputStreamOperator, GlutenOperator { +public class GlutenOneInputOperator extends TableStreamOperator + implements OneInputStreamOperator, GlutenOperator { private static final Logger LOG = LoggerFactory.getLogger(GlutenOneInputOperator.class); @@ -61,33 +57,72 @@ public class GlutenOneInputOperator extends TableStreamOperator private final String id; private final RowType inputType; private final Map outputTypes; + private final RowType outputType; + private final String description; - private StreamRecord outElement = null; - - private MemoryManager memoryManager; - private Session session; - private Query query; - private ExternalStreams.BlockingQueue inputQueue; - private BufferAllocator allocator; - private SerialTask task; + private transient GlutenSessionResource sessionResource; + private transient Query query; + private transient ExternalStreams.BlockingQueue inputQueue; + private transient SerialTask task; + private final Class inClass; + private final Class outClass; + private transient VectorInputBridge inputBridge; + private transient VectorOutputBridge outputBridge; public GlutenOneInputOperator( - StatefulPlanNode plan, String id, RowType inputType, Map outputTypes) { + StatefulPlanNode plan, + String id, + RowType inputType, + Map outputTypes, + Class inClass, + Class outClass, + String description) { + if (plan == null) { + throw new IllegalArgumentException("plan is null"); + } this.glutenPlan = plan; this.id = id; this.inputType = inputType; this.outputTypes = outputTypes; + this.inClass = inClass; + this.outClass = outClass; + this.inputBridge = new VectorInputBridge<>(inClass, getId()); + this.outputBridge = new VectorOutputBridge<>(outClass); + this.outputType = outputTypes.values().iterator().next(); + this.description = description; + } + + public GlutenOneInputOperator( + StatefulPlanNode plan, + String id, + RowType inputType, + Map outputTypes, + Class 
inClass, + Class outClass) { + this(plan, id, inputType, outputTypes, inClass, outClass, ""); } @Override - public void open() throws Exception { - super.open(); - outElement = new StreamRecord(null); - memoryManager = MemoryManager.create(AllocationListener.NOOP); - session = Velox4j.newSession(memoryManager); + public String getDescription() { + return description; + } - inputQueue = session.externalStreamOps().newBlockingQueue(); + void initSession() { + if (sessionResource != null) { + return; + } + if (inputBridge == null) { + inputBridge = new VectorInputBridge<>(inClass, getId()); + } + if (outputBridge == null) { + outputBridge = new VectorOutputBridge<>(outClass); + } + sessionResource = new GlutenSessionResource(); + inputQueue = sessionResource.getSession().externalStreamOps().newBlockingQueue(); // add a mock input as velox not allow the source is empty. + if (inputType == null) { + throw new IllegalArgumentException("inputType is null. plan is " + Serde.toJson(glutenPlan)); + } StatefulPlanNode mockInput = new StatefulPlanNode( id, @@ -104,43 +139,66 @@ public void open() throws Exception { mockInput, VeloxQueryConfig.getConfig(getRuntimeContext()), VeloxConnectorConfig.getConfig(getRuntimeContext())); - allocator = new RootAllocator(Long.MAX_VALUE); - task = session.queryOps().execute(query); - ExternalStreamConnectorSplit split = - new ExternalStreamConnectorSplit("connector-external-stream", inputQueue.id()); - task.addSplit(id, split); + task = sessionResource.getSession().queryOps().execute(query); + task.addSplit( + id, new ExternalStreamConnectorSplit("connector-external-stream", inputQueue.id())); task.noMoreSplits(id); } @Override - public void processElement(StreamRecord element) { - try (RowVector inRv = - FlinkRowToVLVectorConvertor.fromRowData( - element.getValue(), allocator, session, inputType)) { - inputQueue.put(inRv); + public void open() throws Exception { + super.open(); + initSession(); + } + + @Override + public void 
processElement(StreamRecord element) { + if (element.getValue() == null) { + return; + } + StatefulRecord statefulRecord = + inputBridge.getRowVector( + element, sessionResource.getAllocator(), sessionResource.getSession(), inputType); + inputQueue.put(statefulRecord.getRowVector()); + + // Only the rowvectors generated by this operator should be closed here. + if (getId().equals(statefulRecord.getNodeId())) { + statefulRecord.close(); + } + processElementInternal(); + } + + private void processElementInternal() { + while (true) { UpIterator.State state = task.advance(); if (state == UpIterator.State.AVAILABLE) { final StatefulElement statefulElement = task.statefulGet(); - - try (RowVector outRv = statefulElement.asRecord().getRowVector()) { - List rows = - FlinkRowToVLVectorConvertor.toRowData( - outRv, allocator, outputTypes.values().iterator().next()); - for (RowData row : rows) { - output.collect(outElement.replace(row)); - } + if (statefulElement.isWatermark()) { + StatefulWatermark watermark = statefulElement.asWatermark(); + output.emitWatermark(new Watermark(watermark.getTimestamp())); + } else { + outputBridge.collect( + output, statefulElement.asRecord(), sessionResource.getAllocator(), outputType); } + statefulElement.close(); + } else { + break; } } } @Override public void close() throws Exception { - inputQueue.close(); - task.close(); - session.close(); - memoryManager.close(); - allocator.close(); + if (inputQueue != null) { + inputQueue.noMoreInput(); + inputQueue.close(); + } + if (task != null) { + task.close(); + } + if (sessionResource != null) { + sessionResource.close(); + } } @Override @@ -162,4 +220,41 @@ public Map getOutputTypes() { public String getId() { return id; } + + @Override + public void prepareSnapshotPreBarrier(long checkpointId) throws Exception { + // TODO: notify velox + super.prepareSnapshotPreBarrier(checkpointId); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + // TODO: 
implement it + task.snapshotState(0); + super.snapshotState(context); + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + if (task == null) { + initSession(); + } + // TODO: implement it + task.initializeState(0); + super.initializeState(context); + } + + @Override + public void notifyCheckpointComplete(long checkpointId) throws Exception { + // TODO: notify velox + task.notifyCheckpointComplete(checkpointId); + super.notifyCheckpointComplete(checkpointId); + } + + @Override + public void notifyCheckpointAborted(long checkpointId) throws Exception { + // TODO: notify velox + task.notifyCheckpointAborted(checkpointId); + super.notifyCheckpointAborted(checkpointId); + } } diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSessionResource.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSessionResource.java new file mode 100644 index 000000000000..b54102c466f5 --- /dev/null +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSessionResource.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.table.runtime.operators; + +import io.github.zhztheplayer.velox4j.Velox4j; +import io.github.zhztheplayer.velox4j.memory.AllocationListener; +import io.github.zhztheplayer.velox4j.memory.MemoryManager; +import io.github.zhztheplayer.velox4j.session.Session; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; + +// Manage the session and resource for Velox. +class GlutenSessionResource { + private Session session; + private MemoryManager memoryManager; + private BufferAllocator allocator; + + public GlutenSessionResource() { + this.memoryManager = MemoryManager.create(AllocationListener.NOOP); + this.session = Velox4j.newSession(memoryManager); + this.allocator = new RootAllocator(Long.MAX_VALUE); + } + + public void close() { + if (session != null) { + session.close(); + session = null; + } + if (memoryManager != null) { + memoryManager.close(); + memoryManager = null; + } + if (allocator != null) { + allocator.close(); + allocator = null; + } + } + + public Session getSession() { + return session; + } + + public MemoryManager getMemoryManager() { + return memoryManager; + } + + public BufferAllocator getAllocator() { + return allocator; + } +} diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSourceFunction.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSourceFunction.java index 54c33f5159b8..461d5e360720 100644 --- a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSourceFunction.java +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSourceFunction.java @@ -18,35 +18,39 @@ import org.apache.gluten.table.runtime.config.VeloxConnectorConfig; import org.apache.gluten.table.runtime.config.VeloxQueryConfig; +import org.apache.gluten.table.runtime.metrics.SourceTaskMetrics; import 
org.apache.gluten.vectorized.FlinkRowToVLVectorConvertor; -import io.github.zhztheplayer.velox4j.Velox4j; import io.github.zhztheplayer.velox4j.connector.ConnectorSplit; -import io.github.zhztheplayer.velox4j.data.RowVector; import io.github.zhztheplayer.velox4j.iterator.UpIterator; -import io.github.zhztheplayer.velox4j.memory.AllocationListener; -import io.github.zhztheplayer.velox4j.memory.MemoryManager; import io.github.zhztheplayer.velox4j.plan.StatefulPlanNode; import io.github.zhztheplayer.velox4j.query.Query; import io.github.zhztheplayer.velox4j.query.SerialTask; -import io.github.zhztheplayer.velox4j.serde.Serde; import io.github.zhztheplayer.velox4j.session.Session; import io.github.zhztheplayer.velox4j.stateful.StatefulElement; +import io.github.zhztheplayer.velox4j.stateful.StatefulRecord; import io.github.zhztheplayer.velox4j.type.RowType; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; +import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.table.data.RowData; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.List; import java.util.Map; -/** Gluten legacy source function, call velox plan to execute. */ -public class GlutenSourceFunction extends RichParallelSourceFunction { +/** + * Gluten legacy source function, call velox plan to execute. It sends RowVector to downstream + * instead of RowData to avoid data convert. 
+ */ +public class GlutenSourceFunction extends RichParallelSourceFunction + implements CheckpointedFunction { private static final Logger LOG = LoggerFactory.getLogger(GlutenSourceFunction.class); private final StatefulPlanNode planNode; @@ -55,20 +59,24 @@ public class GlutenSourceFunction extends RichParallelSourceFunction { private final ConnectorSplit split; private volatile boolean isRunning = true; - private Session session; + private GlutenSessionResource sessionResource; private Query query; - BufferAllocator allocator; - private MemoryManager memoryManager; + private SerialTask task; + private SourceTaskMetrics taskMetrics; + private final Class outClass; + private boolean isClosed = false; public GlutenSourceFunction( StatefulPlanNode planNode, Map outputTypes, String id, - ConnectorSplit split) { + ConnectorSplit split, + Class outClass) { this.planNode = planNode; this.outputTypes = outputTypes; this.id = id; this.split = split; + this.outClass = outClass; } public StatefulPlanNode getPlanNode() { @@ -88,48 +96,104 @@ public ConnectorSplit getConnectorSplit() { } @Override - public void run(SourceContext sourceContext) throws Exception { - LOG.debug("Running GlutenSourceFunction: " + Serde.toJson(planNode)); - memoryManager = MemoryManager.create(AllocationListener.NOOP); - session = Velox4j.newSession(memoryManager); - query = - new Query( - planNode, - VeloxQueryConfig.getConfig(getRuntimeContext()), - VeloxConnectorConfig.getConfig(getRuntimeContext())); - allocator = new RootAllocator(Long.MAX_VALUE); + public void open(Configuration parameters) throws Exception { + initSession(); + } + + @Override + public void run(SourceContext sourceContext) throws Exception { - SerialTask task = session.queryOps().execute(query); - task.addSplit(id, split); - task.noMoreSplits(id); while (isRunning) { UpIterator.State state = task.advance(); if (state == UpIterator.State.AVAILABLE) { - final StatefulElement element = task.statefulGet(); - try (final RowVector 
outRv = element.asRecord().getRowVector()) { - List rows = - FlinkRowToVLVectorConvertor.toRowData( - outRv, allocator, outputTypes.values().iterator().next()); - for (RowData row : rows) { - sourceContext.collect(row); + StatefulElement element = task.statefulGet(); + if (element.isRecord()) { + StatefulRecord record = element.asRecord(); + if (outClass.isAssignableFrom(RowData.class)) { + List rows = + FlinkRowToVLVectorConvertor.toRowData( + record.getRowVector(), sessionResource.getAllocator(), outputTypes.get(id)); + for (RowData row : rows) { + sourceContext.collect((OUT) row); + } + } else if (outClass.isAssignableFrom(StatefulRecord.class)) { + StatefulRecord statefulRecord = (StatefulRecord) record; + sourceContext.collect((OUT) record); + } else { + throw new UnsupportedOperationException( + "Unsupported output class: " + outClass.getName()); } + } else if (element.isWatermark()) { + sourceContext.emitWatermark(new Watermark(element.asWatermark().getTimestamp())); + } else { + LOG.debug("ignore not record or watermark element"); } + element.close(); } else if (state == UpIterator.State.BLOCKED) { LOG.debug("Get empty row"); } else { LOG.info("Velox task finished"); break; } + taskMetrics.updateMetrics(task, id); } - - task.close(); - session.close(); - memoryManager.close(); - allocator.close(); } @Override public void cancel() { isRunning = false; } + + @Override + public void close() throws Exception { + isRunning = false; + if (task != null) { + task.close(); + task = null; + } + if (sessionResource != null) { + sessionResource.close(); + sessionResource = null; + } + } + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception { + // TODO: implement it + this.task.snapshotState(0); + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + initSession(); + // TODO: implement it + this.task.initializeState(0); + } + + public String[] notifyCheckpointComplete(long 
checkpointId) throws Exception { + // TODO: notify velox + return this.task.notifyCheckpointComplete(checkpointId); + } + + public void notifyCheckpointAborted(long checkpointId) throws Exception { + // TODO: notify velox + this.task.notifyCheckpointAborted(checkpointId); + } + + private void initSession() { + if (sessionResource != null) { + return; + } + sessionResource = new GlutenSessionResource(); + Session session = sessionResource.getSession(); + query = + new Query( + planNode, + VeloxQueryConfig.getConfig(getRuntimeContext()), + VeloxConnectorConfig.getConfig(getRuntimeContext())); + task = session.queryOps().execute(query); + task.addSplit(id, split); + task.noMoreSplits(id); + taskMetrics = new SourceTaskMetrics(getRuntimeContext().getMetricGroup()); + } } diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenVectorTwoInputOperator.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenTwoInputOperator.java similarity index 65% rename from gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenVectorTwoInputOperator.java rename to gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenTwoInputOperator.java index 4fbfe8c7ad23..60e0a4b206b0 100644 --- a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenVectorTwoInputOperator.java +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenTwoInputOperator.java @@ -20,18 +20,12 @@ import org.apache.gluten.table.runtime.config.VeloxConnectorConfig; import org.apache.gluten.table.runtime.config.VeloxQueryConfig; -import io.github.zhztheplayer.velox4j.Velox4j; import io.github.zhztheplayer.velox4j.connector.ExternalStreamConnectorSplit; import io.github.zhztheplayer.velox4j.connector.ExternalStreams; -import io.github.zhztheplayer.velox4j.data.RowVector; import io.github.zhztheplayer.velox4j.iterator.UpIterator; 
-import io.github.zhztheplayer.velox4j.memory.AllocationListener; -import io.github.zhztheplayer.velox4j.memory.MemoryManager; import io.github.zhztheplayer.velox4j.plan.StatefulPlanNode; import io.github.zhztheplayer.velox4j.query.Query; import io.github.zhztheplayer.velox4j.query.SerialTask; -import io.github.zhztheplayer.velox4j.serde.Serde; -import io.github.zhztheplayer.velox4j.session.Session; import io.github.zhztheplayer.velox4j.stateful.StatefulElement; import io.github.zhztheplayer.velox4j.stateful.StatefulRecord; import io.github.zhztheplayer.velox4j.stateful.StatefulWatermark; @@ -53,11 +47,10 @@ * Two input operator in gluten, which will call Velox to run. It receives RowVector from upstream * instead of flink RowData. */ -public class GlutenVectorTwoInputOperator extends AbstractStreamOperator - implements TwoInputStreamOperator, - GlutenOperator { +public class GlutenTwoInputOperator extends AbstractStreamOperator + implements TwoInputStreamOperator, GlutenOperator { - private static final Logger LOG = LoggerFactory.getLogger(GlutenVectorTwoInputOperator.class); + private static final Logger LOG = LoggerFactory.getLogger(GlutenTwoInputOperator.class); private final StatefulPlanNode glutenPlan; private final String leftId; @@ -65,82 +58,99 @@ public class GlutenVectorTwoInputOperator extends AbstractStreamOperator outputTypes; + private final RowType outputType; - private StreamRecord outElement = null; - - private MemoryManager memoryManager; - private Session session; + private GlutenSessionResource sessionResource; private Query query; private ExternalStreams.BlockingQueue leftInputQueue; private ExternalStreams.BlockingQueue rightInputQueue; private SerialTask task; + private final Class inClass; + private final Class outClass; + private VectorInputBridge inputBridge; + private VectorOutputBridge outputBridge; + private String description; - public GlutenVectorTwoInputOperator( + public GlutenTwoInputOperator( StatefulPlanNode plan, String leftId, 
String rightId, RowType leftInputType, RowType rightInputType, - Map outputTypes) { + Map outputTypes, + Class inClass, + Class outClass, + String description) { this.glutenPlan = plan; this.leftId = leftId; this.rightId = rightId; this.leftInputType = leftInputType; this.rightInputType = rightInputType; this.outputTypes = outputTypes; + this.inClass = inClass; + this.outClass = outClass; + this.inputBridge = new VectorInputBridge<>(inClass, getId()); + this.outputBridge = new VectorOutputBridge<>(outClass); + this.outputType = outputTypes.values().iterator().next(); + this.description = description; } - // initializeState is called before open, so need to init gluten task first. - private void initGlutenTask() { - memoryManager = MemoryManager.create(AllocationListener.NOOP); - session = Velox4j.newSession(memoryManager); - query = - new Query( - glutenPlan, - VeloxQueryConfig.getConfig(getRuntimeContext()), - VeloxConnectorConfig.getConfig(getRuntimeContext())); - task = session.queryOps().execute(query); - LOG.debug("Gluten Plan: {}", Serde.toJson(glutenPlan)); - LOG.debug("OutTypes: {}", outputTypes.keySet()); - LOG.debug("RuntimeContext: {}", getRuntimeContext().getClass().getName()); + public GlutenTwoInputOperator( + StatefulPlanNode plan, + String leftId, + String rightId, + RowType leftInputType, + RowType rightInputType, + Map outputTypes, + Class inClass, + Class outClass) { + this(plan, leftId, rightId, leftInputType, rightInputType, outputTypes, inClass, outClass, ""); + } + + @Override + public String getDescription() { + return description; } @Override public void open() throws Exception { super.open(); - if (task == null) { - initGlutenTask(); - } - outElement = new StreamRecord(null); - leftInputQueue = session.externalStreamOps().newBlockingQueue(); - rightInputQueue = session.externalStreamOps().newBlockingQueue(); - ExternalStreamConnectorSplit leftSplit = - new ExternalStreamConnectorSplit("connector-external-stream", leftInputQueue.id()); - 
ExternalStreamConnectorSplit rightSplit = - new ExternalStreamConnectorSplit("connector-external-stream", rightInputQueue.id()); - task.addSplit(leftId, leftSplit); - task.noMoreSplits(leftId); - task.addSplit(rightId, rightSplit); - task.noMoreSplits(rightId); + initSession(); } @Override - public void processElement1(StreamRecord element) { - final RowVector inRv = element.getValue().getRowVector(); - leftInputQueue.put(inRv); - processElement(); - inRv.close(); + public String getId() { + return glutenPlan.getId(); } @Override - public void processElement2(StreamRecord element) { - final RowVector inRv = element.getValue().getRowVector(); - rightInputQueue.put(inRv); - processElement(); - inRv.close(); + public void processElement1(StreamRecord element) { + StatefulRecord statefulRecord = + inputBridge.getRowVector( + element, sessionResource.getAllocator(), sessionResource.getSession(), leftInputType); + leftInputQueue.put(statefulRecord.getRowVector()); + // Only the rowvectors generated by this operator should be closed here. + if (getId().equals(statefulRecord.getNodeId())) { + statefulRecord.close(); + } + processElementInternal(); + } + + @Override + public void processElement2(StreamRecord element) { + StatefulRecord statefulRecord = + inputBridge.getRowVector( + element, sessionResource.getAllocator(), sessionResource.getSession(), rightInputType); + rightInputQueue.put(statefulRecord.getRowVector()); + // Only the rowvectors generated by this operator should be closed here. 
+ + if (getId().equals(statefulRecord.getNodeId())) { + statefulRecord.close(); + } + processElementInternal(); } - private void processElement() { + private void processElementInternal() { while (true) { UpIterator.State state = task.advance(); if (state == UpIterator.State.AVAILABLE) { @@ -149,10 +159,10 @@ private void processElement() { StatefulWatermark watermark = element.asWatermark(); output.emitWatermark(new Watermark(watermark.getTimestamp())); } else { - final StatefulRecord statefulRecord = element.asRecord(); - output.collect(outElement.replace(statefulRecord)); - statefulRecord.close(); + outputBridge.collect( + output, element.asRecord(), sessionResource.getAllocator(), outputType); } + element.close(); } else { break; } @@ -163,23 +173,30 @@ private void processElement() { public void processWatermark1(Watermark mark) throws Exception { // TODO: implement it; task.notifyWatermark(mark.getTimestamp(), 1); - processElement(); + processElementInternal(); } @Override public void processWatermark2(Watermark mark) throws Exception { // TODO: implement it; task.notifyWatermark(mark.getTimestamp(), 2); - processElement(); + processElementInternal(); } @Override public void close() throws Exception { - leftInputQueue.close(); - rightInputQueue.close(); - task.close(); - session.close(); - memoryManager.close(); + if (leftInputQueue != null) { + leftInputQueue.close(); + } + if (rightInputQueue != null) { + rightInputQueue.close(); + } + if (task != null) { + task.close(); + } + if (sessionResource != null) { + sessionResource.close(); + } } @Override @@ -189,7 +206,7 @@ public StatefulPlanNode getPlanNode() { @Override public RowType getInputType() { - throw new RuntimeException("Should not call getInputType on GlutenVectorTwoInputOperator"); + throw new RuntimeException("Should not call getInputType on GlutenTwoInputOperator"); } public RowType getLeftInputType() { @@ -205,11 +222,6 @@ public Map getOutputTypes() { return outputTypes; } - @Override - public 
String getId() { - throw new RuntimeException("Should not call getId on GlutenVectorTwoInputOperator"); - } - public String getLeftId() { return leftId; } @@ -233,14 +245,39 @@ public void snapshotState(StateSnapshotContext context) throws Exception { @Override public void initializeState(StateInitializationContext context) throws Exception { - if (task == null) { - initGlutenTask(); - } + initSession(); // TODO: implement it task.initializeState(0); super.initializeState(context); } + private void initSession() { + if (sessionResource != null) { + return; + } + + sessionResource = new GlutenSessionResource(); + + leftInputQueue = sessionResource.getSession().externalStreamOps().newBlockingQueue(); + rightInputQueue = sessionResource.getSession().externalStreamOps().newBlockingQueue(); + + query = + new Query( + glutenPlan, + VeloxQueryConfig.getConfig(getRuntimeContext()), + VeloxConnectorConfig.getConfig(getRuntimeContext())); + task = sessionResource.getSession().queryOps().execute(query); + + ExternalStreamConnectorSplit leftSplit = + new ExternalStreamConnectorSplit("connector-external-stream", leftInputQueue.id()); + ExternalStreamConnectorSplit rightSplit = + new ExternalStreamConnectorSplit("connector-external-stream", rightInputQueue.id()); + task.addSplit(leftId, leftSplit); + task.noMoreSplits(leftId); + task.addSplit(rightId, rightSplit); + task.noMoreSplits(rightId); + } + @Override public void notifyCheckpointComplete(long checkpointId) throws Exception { // TODO: notify velox diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenVectorOneInputOperator.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenVectorOneInputOperator.java deleted file mode 100644 index b55ebb686890..000000000000 --- a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenVectorOneInputOperator.java +++ /dev/null @@ -1,203 +0,0 @@ -/* - * Licensed to the Apache Software 
Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.gluten.table.runtime.operators; - -import org.apache.gluten.streaming.api.operators.GlutenOperator; -import org.apache.gluten.table.runtime.config.VeloxConnectorConfig; -import org.apache.gluten.table.runtime.config.VeloxQueryConfig; - -import io.github.zhztheplayer.velox4j.Velox4j; -import io.github.zhztheplayer.velox4j.connector.ExternalStreamConnectorSplit; -import io.github.zhztheplayer.velox4j.connector.ExternalStreamTableHandle; -import io.github.zhztheplayer.velox4j.connector.ExternalStreams; -import io.github.zhztheplayer.velox4j.data.RowVector; -import io.github.zhztheplayer.velox4j.iterator.UpIterator; -import io.github.zhztheplayer.velox4j.memory.AllocationListener; -import io.github.zhztheplayer.velox4j.memory.MemoryManager; -import io.github.zhztheplayer.velox4j.plan.StatefulPlanNode; -import io.github.zhztheplayer.velox4j.plan.TableScanNode; -import io.github.zhztheplayer.velox4j.query.Query; -import io.github.zhztheplayer.velox4j.query.SerialTask; -import io.github.zhztheplayer.velox4j.serde.Serde; -import io.github.zhztheplayer.velox4j.session.Session; -import io.github.zhztheplayer.velox4j.stateful.StatefulElement; -import 
io.github.zhztheplayer.velox4j.stateful.StatefulRecord; -import io.github.zhztheplayer.velox4j.stateful.StatefulWatermark; -import io.github.zhztheplayer.velox4j.type.RowType; - -import org.apache.flink.runtime.state.StateInitializationContext; -import org.apache.flink.runtime.state.StateSnapshotContext; -import org.apache.flink.streaming.api.operators.OneInputStreamOperator; -import org.apache.flink.streaming.api.watermark.Watermark; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.table.runtime.operators.TableStreamOperator; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.List; -import java.util.Map; - -/** Calculate operator in gluten, which will call Velox to run. */ -public class GlutenVectorOneInputOperator extends TableStreamOperator - implements OneInputStreamOperator, GlutenOperator { - - private static final Logger LOG = LoggerFactory.getLogger(GlutenVectorOneInputOperator.class); - - private final StatefulPlanNode glutenPlan; - private final String id; - private final RowType inputType; - private final Map outputTypes; - - private StreamRecord outElement = null; - - private MemoryManager memoryManager; - private Session session; - private Query query; - private ExternalStreams.BlockingQueue inputQueue; - protected SerialTask task; - - public GlutenVectorOneInputOperator( - StatefulPlanNode plan, String id, RowType inputType, Map outputTypes) { - this.glutenPlan = plan; - this.id = id; - this.inputType = inputType; - this.outputTypes = outputTypes; - } - - void initGlutenTask() { - memoryManager = MemoryManager.create(AllocationListener.NOOP); - session = Velox4j.newSession(memoryManager); - // add a mock input as velox not allow the source is empty. 
- StatefulPlanNode mockInput = - new StatefulPlanNode( - id, - new TableScanNode( - id, - inputType, - new ExternalStreamTableHandle("connector-external-stream"), - List.of())); - mockInput.addTarget(glutenPlan); - LOG.debug("Gluten Plan: {}", Serde.toJson(mockInput)); - LOG.debug("OutTypes: {}", outputTypes.keySet()); - query = - new Query( - mockInput, - VeloxQueryConfig.getConfig(getRuntimeContext()), - VeloxConnectorConfig.getConfig(getRuntimeContext())); - task = session.queryOps().execute(query); - } - - @Override - public void open() throws Exception { - super.open(); - outElement = new StreamRecord(null); - inputQueue = session.externalStreamOps().newBlockingQueue(); - ExternalStreamConnectorSplit split = - new ExternalStreamConnectorSplit("connector-external-stream", inputQueue.id()); - task.addSplit(id, split); - task.noMoreSplits(id); - } - - @Override - public void processElement(StreamRecord element) { - RowVector inRv = element.getValue().getRowVector(); - inputQueue.put(inRv); - while (true) { - UpIterator.State state = task.advance(); - if (state == UpIterator.State.AVAILABLE) { - final StatefulElement statefulElement = task.statefulGet(); - if (statefulElement.isWatermark()) { - StatefulWatermark watermark = statefulElement.asWatermark(); - output.emitWatermark(new Watermark(watermark.getTimestamp())); - } else { - final StatefulRecord statefulRecord = statefulElement.asRecord(); - output.collect(outElement.replace(statefulRecord)); - statefulRecord.close(); - } - } else { - break; - } - } - inRv.close(); - } - - @Override - public void close() throws Exception { - inputQueue.close(); - task.close(); - session.close(); - memoryManager.close(); - } - - @Override - public StatefulPlanNode getPlanNode() { - return glutenPlan; - } - - @Override - public RowType getInputType() { - return inputType; - } - - @Override - public Map getOutputTypes() { - return outputTypes; - } - - @Override - public String getId() { - return id; - } - - @Override - public 
void prepareSnapshotPreBarrier(long checkpointId) throws Exception { - // TODO: notify velox - super.prepareSnapshotPreBarrier(checkpointId); - } - - @Override - public void snapshotState(StateSnapshotContext context) throws Exception { - // TODO: implement it - task.snapshotState(0); - super.snapshotState(context); - } - - @Override - public void initializeState(StateInitializationContext context) throws Exception { - if (task == null) { - initGlutenTask(); - } - // TODO: implement it - task.initializeState(0); - super.initializeState(context); - } - - @Override - public void notifyCheckpointComplete(long checkpointId) throws Exception { - // TODO: notify velox - task.notifyCheckpointComplete(checkpointId); - super.notifyCheckpointComplete(checkpointId); - } - - @Override - public void notifyCheckpointAborted(long checkpointId) throws Exception { - // TODO: notify velox - task.notifyCheckpointAborted(checkpointId); - super.notifyCheckpointAborted(checkpointId); - } -} diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenVectorSourceFunction.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenVectorSourceFunction.java deleted file mode 100644 index 67bc802946d3..000000000000 --- a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenVectorSourceFunction.java +++ /dev/null @@ -1,185 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.gluten.table.runtime.operators; - -import org.apache.gluten.table.runtime.config.VeloxConnectorConfig; -import org.apache.gluten.table.runtime.config.VeloxQueryConfig; -import org.apache.gluten.table.runtime.metrics.SourceTaskMetrics; - -import io.github.zhztheplayer.velox4j.Velox4j; -import io.github.zhztheplayer.velox4j.connector.ConnectorSplit; -import io.github.zhztheplayer.velox4j.iterator.UpIterator; -import io.github.zhztheplayer.velox4j.memory.AllocationListener; -import io.github.zhztheplayer.velox4j.memory.MemoryManager; -import io.github.zhztheplayer.velox4j.plan.StatefulPlanNode; -import io.github.zhztheplayer.velox4j.query.Query; -import io.github.zhztheplayer.velox4j.query.SerialTask; -import io.github.zhztheplayer.velox4j.serde.Serde; -import io.github.zhztheplayer.velox4j.session.Session; -import io.github.zhztheplayer.velox4j.stateful.StatefulElement; -import io.github.zhztheplayer.velox4j.type.RowType; - -import org.apache.flink.configuration.Configuration; -import org.apache.flink.runtime.state.FunctionInitializationContext; -import org.apache.flink.runtime.state.FunctionSnapshotContext; -import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; -import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; -import org.apache.flink.streaming.api.watermark.Watermark; - -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.Map; - -/** - * Gluten legacy source 
function, call velox plan to execute. It sends RowVector to downstream - * instead of RowData to avoid data convert. - */ -public class GlutenVectorSourceFunction extends RichParallelSourceFunction - implements CheckpointedFunction { - private static final Logger LOG = LoggerFactory.getLogger(GlutenVectorSourceFunction.class); - - private final StatefulPlanNode planNode; - private final Map outputTypes; - private final String id; - private final ConnectorSplit split; - private volatile boolean isRunning = true; - - private Session session; - private Query query; - private BufferAllocator allocator; - private MemoryManager memoryManager; - private SerialTask task; - private SourceTaskMetrics taskMetrics; - - public GlutenVectorSourceFunction( - StatefulPlanNode planNode, - Map outputTypes, - String id, - ConnectorSplit split) { - this.planNode = planNode; - this.outputTypes = outputTypes; - this.id = id; - this.split = split; - } - - public StatefulPlanNode getPlanNode() { - return planNode; - } - - public Map getOutputTypes() { - return outputTypes; - } - - public String getId() { - return id; - } - - public ConnectorSplit getConnectorSplit() { - return split; - } - - @Override - public void open(Configuration parameters) throws Exception { - if (memoryManager == null) { - memoryManager = MemoryManager.create(AllocationListener.NOOP); - session = Velox4j.newSession(memoryManager); - query = - new Query( - planNode, - VeloxQueryConfig.getConfig(getRuntimeContext()), - VeloxConnectorConfig.getConfig(getRuntimeContext())); - allocator = new RootAllocator(Long.MAX_VALUE); - - task = session.queryOps().execute(query); - task.addSplit(id, split); - task.noMoreSplits(id); - } - taskMetrics = new SourceTaskMetrics(getRuntimeContext().getMetricGroup()); - } - - @Override - public void run(SourceContext sourceContext) throws Exception { - while (isRunning) { - UpIterator.State state = task.advance(); - if (state == UpIterator.State.AVAILABLE) { - final StatefulElement 
element = task.statefulGet(); - if (element.isWatermark()) { - sourceContext.emitWatermark(new Watermark(element.asWatermark().getTimestamp())); - } else { - sourceContext.collect(element); - } - element.close(); - } else if (state == UpIterator.State.BLOCKED) { - LOG.debug("Get empty row"); - } else { - LOG.info("Velox task finished"); - break; - } - taskMetrics.updateMetrics(task, id); - } - - task.close(); - session.close(); - memoryManager.close(); - allocator.close(); - } - - @Override - public void cancel() { - isRunning = false; - } - - @Override - public void snapshotState(FunctionSnapshotContext context) throws Exception { - // TODO: implement it - this.task.snapshotState(0); - } - - @Override - public void initializeState(FunctionInitializationContext context) throws Exception { - if (memoryManager == null) { - LOG.debug("Running GlutenSourceFunction: " + Serde.toJson(planNode)); - memoryManager = MemoryManager.create(AllocationListener.NOOP); - session = Velox4j.newSession(memoryManager); - query = - new Query( - planNode, - VeloxQueryConfig.getConfig(getRuntimeContext()), - VeloxConnectorConfig.getConfig(getRuntimeContext())); - allocator = new RootAllocator(Long.MAX_VALUE); - - task = session.queryOps().execute(query); - task.addSplit(id, split); - task.noMoreSplits(id); - } - // TODO: implement it - this.task.initializeState(0); - } - - public String[] notifyCheckpointComplete(long checkpointId) throws Exception { - // TODO: notify velox - return this.task.notifyCheckpointComplete(checkpointId); - } - - public void notifyCheckpointAborted(long checkpointId) throws Exception { - // TODO: notify velox - this.task.notifyCheckpointAborted(checkpointId); - } -} diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/typeutils/GlutenRowVectorSerializer.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/typeutils/GlutenStatefulRecordSerializer.java similarity index 83% rename from 
gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/typeutils/GlutenRowVectorSerializer.java rename to gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/typeutils/GlutenStatefulRecordSerializer.java index db17a47a2d59..1040e776c132 100644 --- a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/typeutils/GlutenRowVectorSerializer.java +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/typeutils/GlutenStatefulRecordSerializer.java @@ -31,24 +31,31 @@ import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.io.Closeable; import java.io.IOException; /** Serializer for {@link RowVector}. */ @Internal -public class GlutenRowVectorSerializer extends TypeSerializer implements Closeable { +public class GlutenStatefulRecordSerializer extends TypeSerializer + implements Closeable { + private static final Logger LOG = LoggerFactory.getLogger(GlutenStatefulRecordSerializer.class); private static final long serialVersionUID = 1L; private final RowType rowType; private transient MemoryManager memoryManager; private transient Session session; + private final String nodeId; - public GlutenRowVectorSerializer(RowType rowType) { + public GlutenStatefulRecordSerializer(RowType rowType, String nodeId) { + this.nodeId = nodeId; this.rowType = rowType; } @Override public TypeSerializer duplicate() { - return new GlutenRowVectorSerializer(rowType); + return new GlutenStatefulRecordSerializer(rowType, nodeId); } @Override @@ -73,7 +80,7 @@ public StatefulRecord deserialize(DataInputView source) throws IOException { byte[] str = new byte[len]; source.readFully(str); RowVector rowVector = session.baseVectorOps().deserializeOne(new String(str)).asRowVector(); - StatefulRecord record = new StatefulRecord(null, 0, 0, false, -1); + StatefulRecord record = new StatefulRecord(nodeId, 0, 0, false, -1); 
record.setRowVector(rowVector); return record; } @@ -85,7 +92,7 @@ public StatefulRecord deserialize(StatefulRecord reuse, DataInputView source) th @Override public StatefulRecord copy(StatefulRecord from) { - throw new RuntimeException("Not implemented for gluten"); + return from; } @Override @@ -100,9 +107,9 @@ public void copy(DataInputView source, DataOutputView target) throws IOException @Override public boolean equals(Object obj) { - if (obj instanceof GlutenRowVectorSerializer) { + if (obj instanceof GlutenStatefulRecordSerializer) { if (rowType != null) { - GlutenRowVectorSerializer other = (GlutenRowVectorSerializer) obj; + GlutenStatefulRecordSerializer other = (GlutenStatefulRecordSerializer) obj; return rowType.equals(other.rowType); } return true; @@ -121,7 +128,7 @@ public int hashCode() { @Override public boolean isImmutableType() { - return false; + return true; } @Override @@ -131,7 +138,7 @@ public int getLength() { @Override public TypeSerializerSnapshot snapshotConfiguration() { - return new RowVectorSerializerSnapshot(rowType); + return new RowVectorSerializerSnapshot(rowType, nodeId); } @Override @@ -148,13 +155,15 @@ public static final class RowVectorSerializerSnapshot private static final int CURRENT_VERSION = 1; private RowType rowType; + private String nodeId; @SuppressWarnings("unused") public RowVectorSerializerSnapshot() { // this constructor is used when restoring from a checkpoint/savepoint. 
} - RowVectorSerializerSnapshot(RowType rowType) { + RowVectorSerializerSnapshot(RowType rowType, String nodeId) { + this.nodeId = nodeId; this.rowType = rowType; } @@ -171,8 +180,8 @@ public void readSnapshot(int readVersion, DataInputView in, ClassLoader userCode throws IOException {} @Override - public GlutenRowVectorSerializer restoreSerializer() { - return new GlutenRowVectorSerializer(rowType); + public GlutenStatefulRecordSerializer restoreSerializer() { + return new GlutenStatefulRecordSerializer(rowType, nodeId); } @Override diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/util/VectorInputBridge.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/util/VectorInputBridge.java new file mode 100644 index 000000000000..7c9e3e5e3d54 --- /dev/null +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/util/VectorInputBridge.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.apache.gluten.table.runtime.operators;
// NOTE(review): this file lives under .../org/apache/gluten/util/ but declares the
// ...table.runtime.operators package (its callers use it without an import, so the package is
// the intended one) — the file should be moved to match the package.

import org.apache.gluten.vectorized.FlinkRowToVLVectorConvertor;

import io.github.zhztheplayer.velox4j.data.RowVector;
import io.github.zhztheplayer.velox4j.session.Session;
import io.github.zhztheplayer.velox4j.stateful.StatefulRecord;
import io.github.zhztheplayer.velox4j.type.RowType;

import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.data.RowData;

import org.apache.arrow.memory.BufferAllocator;

import java.io.Serializable;

/**
 * Converts an incoming {@link StreamRecord} payload into a {@link StatefulRecord} carrying a
 * Velox {@link RowVector}, so Gluten operators can feed Velox regardless of whether the upstream
 * emits Flink {@link RowData} or already-vectorized {@link StatefulRecord}s.
 *
 * @param <IN> runtime type of the upstream payload (RowData or StatefulRecord)
 */
public class VectorInputBridge<IN> implements Serializable {
  private static final long serialVersionUID = 1L;

  /** Runtime class of the upstream payload; decides the conversion path in getRowVector(). */
  private final Class<IN> inClass;

  /** Plan-node id stamped on records this bridge creates; callers use it to decide ownership. */
  private final String nodeId;

  public VectorInputBridge(Class<IN> inClass, String nodeId) {
    this.inClass = inClass;
    this.nodeId = nodeId;
  }

  /**
   * Wraps {@code inputData} as a {@link StatefulRecord}.
   *
   * <p>For {@link RowData} input, a new row vector is built and the resulting record is tagged
   * with this bridge's {@code nodeId} so the caller knows it owns (and must close) the vector.
   * For {@link StatefulRecord} input, the upstream record is returned untouched.
   *
   * @throws UnsupportedOperationException if {@code inClass} is neither RowData nor
   *     StatefulRecord compatible
   */
  public StatefulRecord getRowVector(
      StreamRecord<IN> inputData, BufferAllocator allocator, Session session, RowType inputType) {
    if (inClass.isAssignableFrom(RowData.class)) {
      RowData rowData = (RowData) inputData.getValue();
      RowVector rowVector =
          FlinkRowToVLVectorConvertor.fromRowData(rowData, allocator, session, inputType);
      StatefulRecord statefulRecord = new StatefulRecord(nodeId, rowVector.id(), 0, false, -1);
      statefulRecord.setRowVector(rowVector);
      return statefulRecord;
    } else if (inClass.isAssignableFrom(StatefulRecord.class)) {
      // Pass-through: the record (and its vector) stays owned by whoever produced it upstream.
      return (StatefulRecord) inputData.getValue();
    } else {
      throw new UnsupportedOperationException("Unsupported input class: " + inClass.getName());
    }
  }
}
+ * and collect the output data to the collector. + */ +public class VectorOutputBridge implements Serializable { + private static final Logger LOG = LoggerFactory.getLogger(VectorOutputBridge.class); + private static final long serialVersionUID = 1L; + private final Class outClass; + private transient StreamRecord outElement; + + public VectorOutputBridge(Class outClass) { + this.outClass = outClass; + this.outElement = new StreamRecord<>(null); + } + + private StreamRecord getOutElement() { + if (outElement == null) { + outElement = new StreamRecord<>(null); + } + return outElement; + } + + public void collect( + Output> collector, + StatefulRecord record, + BufferAllocator allocator, + RowType outputType) { + if (outClass.isAssignableFrom(RowData.class)) { + List rows = + FlinkRowToVLVectorConvertor.toRowData(record.getRowVector(), allocator, outputType); + for (RowData row : rows) { + collector.collect(getOutElement().replace((OUT) row)); + } + } else if (outClass.isAssignableFrom(StatefulRecord.class)) { + collector.collect(getOutElement().replace((OUT) record)); + } else { + throw new UnsupportedOperationException("Unsupported output class: " + outClass.getName()); + } + } +} diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java index d96f48ecbe2f..e6a511a49211 100644 --- a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java @@ -395,11 +395,18 @@ protected void setValue(int index, byte[] value) { } } -class TimestampVectorWriter extends BaseVectorWriter { +class TimestampVectorWriter extends BaseVectorWriter { private final int precision = 3; // Millisecond precision public TimestampVectorWriter(Type fieldType, BufferAllocator allocator, FieldVector vector) { super(vector); + // Verify that the vector 
is a timestamp vector (either TimeStampMilliVector or + // TimeStampMilliTZVector) + if (!(vector instanceof TimeStampMilliVector) && !(vector instanceof TimeStampMilliTZVector)) { + throw new IllegalArgumentException( + "Expected TimeStampMilliVector or TimeStampMilliTZVector, but got: " + + vector.getClass().getName()); + } } @Override @@ -414,7 +421,16 @@ protected Long getValue(ArrayData arrayData, int index) { @Override protected void setValue(int index, Long value) { - this.typedVector.setSafe(index, value); + + // Both TimeStampMilliVector and TimeStampMilliTZVector support setSafe with long value + if (this.typedVector instanceof TimeStampMilliVector) { + ((TimeStampMilliVector) this.typedVector).setSafe(index, value); + } else if (this.typedVector instanceof TimeStampMilliTZVector) { + ((TimeStampMilliTZVector) this.typedVector).setSafe(index, value); + } else { + throw new IllegalStateException( + "Unexpected vector type: " + this.typedVector.getClass().getName()); + } } } diff --git a/gluten-flink/ut/src/test/java/org/apache/gluten/streaming/api/operators/GlutenStreamFilterTest.java b/gluten-flink/ut/src/test/java/org/apache/gluten/streaming/api/operators/GlutenStreamFilterTest.java index 15a3968fe96e..1955ece25ae0 100644 --- a/gluten-flink/ut/src/test/java/org/apache/gluten/streaming/api/operators/GlutenStreamFilterTest.java +++ b/gluten-flink/ut/src/test/java/org/apache/gluten/streaming/api/operators/GlutenStreamFilterTest.java @@ -216,7 +216,9 @@ public TestableGlutenOneInputOperator( new io.github.zhztheplayer.velox4j.plan.StatefulPlanNode(veloxPlan.getId(), veloxPlan), PlanNodeIdGenerator.newId(), veloxType, - Map.of(veloxPlan.getId(), veloxType)); + Map.of(veloxPlan.getId(), veloxType), + RowData.class, + RowData.class); } @Override diff --git a/gluten-flink/ut/src/test/java/org/apache/gluten/streaming/api/operators/GlutenStreamJoinOperatorTest.java 
b/gluten-flink/ut/src/test/java/org/apache/gluten/streaming/api/operators/GlutenStreamJoinOperatorTest.java index 88532398651a..9215431a5455 100644 --- a/gluten-flink/ut/src/test/java/org/apache/gluten/streaming/api/operators/GlutenStreamJoinOperatorTest.java +++ b/gluten-flink/ut/src/test/java/org/apache/gluten/streaming/api/operators/GlutenStreamJoinOperatorTest.java @@ -16,7 +16,7 @@ */ package org.apache.gluten.streaming.api.operators; -import org.apache.gluten.table.runtime.operators.GlutenVectorTwoInputOperator; +import org.apache.gluten.table.runtime.operators.GlutenTwoInputOperator; import org.apache.flink.table.data.GenericRowData; import org.apache.flink.table.data.RowData; @@ -77,7 +77,7 @@ public static void setupTestData() { @Test public void testInnerJoin() throws Exception { - GlutenVectorTwoInputOperator operator = createGlutenJoinOperator(FlinkJoinType.INNER); + GlutenTwoInputOperator operator = createGlutenJoinOperator(FlinkJoinType.INNER); List expectedOutput = Arrays.asList( @@ -99,7 +99,7 @@ public void testInnerJoin() throws Exception { @Test @Disabled public void testLeftJoin() throws Exception { - GlutenVectorTwoInputOperator operator = createGlutenJoinOperator(FlinkJoinType.LEFT); + GlutenTwoInputOperator operator = createGlutenJoinOperator(FlinkJoinType.LEFT); List expectedOutput = Arrays.asList( @@ -127,7 +127,7 @@ public void testLeftJoin() throws Exception { @Test @Disabled public void testRightJoin() throws Exception { - GlutenVectorTwoInputOperator operator = createGlutenJoinOperator(FlinkJoinType.RIGHT); + GlutenTwoInputOperator operator = createGlutenJoinOperator(FlinkJoinType.RIGHT); List expectedOutput = Arrays.asList( @@ -151,7 +151,7 @@ public void testRightJoin() throws Exception { @Test @Disabled public void testFullOuterJoin() throws Exception { - GlutenVectorTwoInputOperator operator = createGlutenJoinOperator(FlinkJoinType.FULL); + GlutenTwoInputOperator operator = createGlutenJoinOperator(FlinkJoinType.FULL); List 
expectedOutput = Arrays.asList( @@ -181,7 +181,7 @@ public void testFullOuterJoin() throws Exception { @Test public void testInnerJoinWithNonEquiCondition() throws Exception { RexNode nonEquiCondition = createNonEquiCondition(); - GlutenVectorTwoInputOperator operator = + GlutenTwoInputOperator operator = createGlutenJoinOperator(FlinkJoinType.INNER, nonEquiCondition); List expectedOutput = diff --git a/gluten-flink/ut/src/test/java/org/apache/gluten/streaming/api/operators/GlutenStreamJoinOperatorTestBase.java b/gluten-flink/ut/src/test/java/org/apache/gluten/streaming/api/operators/GlutenStreamJoinOperatorTestBase.java index 7aa7f57d68be..8b0e040a3d23 100644 --- a/gluten-flink/ut/src/test/java/org/apache/gluten/streaming/api/operators/GlutenStreamJoinOperatorTestBase.java +++ b/gluten-flink/ut/src/test/java/org/apache/gluten/streaming/api/operators/GlutenStreamJoinOperatorTestBase.java @@ -19,7 +19,7 @@ import org.apache.gluten.rexnode.RexConversionContext; import org.apache.gluten.rexnode.RexNodeConverter; import org.apache.gluten.rexnode.Utils; -import org.apache.gluten.table.runtime.operators.GlutenVectorTwoInputOperator; +import org.apache.gluten.table.runtime.operators.GlutenTwoInputOperator; import org.apache.gluten.table.runtime.stream.common.Velox4jEnvironment; import org.apache.gluten.util.LogicalTypeConverter; import org.apache.gluten.util.PlanNodeIdGenerator; @@ -160,11 +160,11 @@ protected org.apache.flink.table.types.logical.RowType getOutputType() { return outputRowType; } - protected GlutenVectorTwoInputOperator createGlutenJoinOperator(FlinkJoinType joinType) { + protected GlutenTwoInputOperator createGlutenJoinOperator(FlinkJoinType joinType) { return createGlutenJoinOperator(joinType, null); } - protected GlutenVectorTwoInputOperator createGlutenJoinOperator( + protected GlutenTwoInputOperator createGlutenJoinOperator( FlinkJoinType joinType, RexNode nonEquiCondition) { JoinType veloxJoinType = Utils.toVLJoinType(joinType); @@ -226,13 +226,15 @@ 
protected GlutenVectorTwoInputOperator createGlutenJoinOperator( outputVeloxType, 1024); - return new GlutenVectorTwoInputOperator( + return new GlutenTwoInputOperator( new StatefulPlanNode(join.getId(), join), leftInput.getId(), rightInput.getId(), leftVeloxType, rightVeloxType, - Map.of(join.getId(), outputVeloxType)); + Map.of(join.getId(), outputVeloxType), + RowData.class, + RowData.class); } protected void processTestData( @@ -281,7 +283,7 @@ private StatefulRecord convertToStatefulRecord(RowData rowData, RowType rowType) } protected void executeJoinTest( - GlutenVectorTwoInputOperator operator, + GlutenTwoInputOperator operator, List leftData, List rightData, List expectedOutput) diff --git a/gluten-flink/ut/src/test/java/org/apache/gluten/streaming/api/operators/GlutenStreamOperatorTestBase.java b/gluten-flink/ut/src/test/java/org/apache/gluten/streaming/api/operators/GlutenStreamOperatorTestBase.java index 224dcfd1911e..f26bb9038d85 100644 --- a/gluten-flink/ut/src/test/java/org/apache/gluten/streaming/api/operators/GlutenStreamOperatorTestBase.java +++ b/gluten-flink/ut/src/test/java/org/apache/gluten/streaming/api/operators/GlutenStreamOperatorTestBase.java @@ -165,6 +165,8 @@ protected GlutenOneInputOperator createTestOperator( new StatefulPlanNode(veloxPlan.getId(), veloxPlan), PlanNodeIdGenerator.newId(), inputVeloxType, - Map.of(veloxPlan.getId(), outputVeloxType)); + Map.of(veloxPlan.getId(), outputVeloxType), + RowData.class, + RowData.class); } }