From fabc62862d87ca4bf600494bfc767c946c9bdae0 Mon Sep 17 00:00:00 2001
From: 2pk03
Date: Mon, 15 Dec 2025 15:44:06 +0100
Subject: [PATCH 1/5] Spark DataFrames support / Optimizer load profiles

---
 .../org/apache/wayang/api/DataQuanta.scala         |  18 +-
 .../apache/wayang/api/JavaPlanBuilder.scala        |  10 ++
 .../org/apache/wayang/api/PlanBuilder.scala        |  10 ++
 .../wayang/basic/operators/ParquetSink.java        |  58 +++++++
 .../wayang/basic/operators/ParquetSource.java      |  12 ++
 .../spark/channels/ChannelConversions.java         |  51 +++++-
 .../wayang/spark/channels/DatasetChannel.java      | 123 ++++++++++++++
 .../apache/wayang/spark/mapping/Mappings.java      |   3 +-
 .../spark/mapping/ParquetSinkMapping.java          |  56 +++++++
 .../operators/SparkDatasetToRddOperator.java       |  96 +++++++++++
 .../spark/operators/SparkParquetSink.java          |  90 ++++++++++
 .../spark/operators/SparkParquetSource.java        |  36 ++--
 .../operators/SparkRddToDatasetOperator.java       | 104 ++++++++++++
 .../wayang/spark/util/DatasetConverters.java       | 157 ++++++++++++++++++
 .../wayang-spark-defaults.properties               |  22 +++
 .../spark/operators/DatasetChannelTest.java        |  98 +++++++++++
 .../spark/operators/DatasetTestUtils.java          |  84 ++++++++++
 .../SparkDatasetToRddOperatorTest.java             |  46 +++++
 .../operators/SparkOperatorTestBase.java           |  47 ++++--
 .../spark/operators/SparkParquetSinkTest.java      |  70 ++++++++
 .../SparkParquetSourceDatasetOutputTest.java       |  57 +++++++
 .../SparkRddToDatasetOperatorTest.java             |  43 +++++
 .../wayang/spark/test/ChannelFactory.java          |  17 ++
 23 files changed, 1279 insertions(+), 29 deletions(-)
 create mode 100644 wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/ParquetSink.java
 create mode 100644 wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/channels/DatasetChannel.java
 create mode 100644 wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/mapping/ParquetSinkMapping.java
 create mode 100644 wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkDatasetToRddOperator.java
 create mode 100644 wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkParquetSink.java
 create mode 100644 wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkRddToDatasetOperator.java
 create mode 100644 wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/util/DatasetConverters.java
 create mode 100644 wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/DatasetChannelTest.java
 create mode 100644 wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/DatasetTestUtils.java
 create mode 100644 wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkDatasetToRddOperatorTest.java
 create mode 100644 wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkParquetSinkTest.java
 create mode 100644 wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkParquetSourceDatasetOutputTest.java
 create mode 100644 wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkRddToDatasetOperatorTest.java

diff --git a/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuanta.scala b/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuanta.scala
index 99f6f9cc7..e9673950e 100644
--- a/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuanta.scala
+++ b/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuanta.scala
@@
-36,7 +36,7 @@ import org.apache.wayang.core.optimizer.costs.LoadProfileEstimator import org.apache.wayang.core.plan.wayangplan._ import org.apache.wayang.core.platform.Platform import org.apache.wayang.core.util.{Tuple => WayangTuple} -import org.apache.wayang.basic.data.{Tuple2 => WayangTuple2} +import org.apache.wayang.basic.data.{Record, Tuple2 => WayangTuple2} import org.apache.wayang.basic.model.{DLModel, LogisticRegressionModel,DecisionTreeRegressionModel}; import org.apache.wayang.commons.util.profiledb.model.Experiment import com.google.protobuf.ByteString; @@ -1027,6 +1027,12 @@ class DataQuanta[Out: ClassTag](val operator: ElementaryOperator, outputIndex: I writeTextFileJava(url, toSerializableFunction(formatterUdf), udfLoad) } + def writeParquet(url: String, overwrite: Boolean = false)(implicit ev: Out =:= Record): Unit = + writeParquetJava(url, overwrite, preferDataset = false) + + def writeParquetAsDataset(url: String, overwrite: Boolean = true)(implicit ev: Out =:= Record): Unit = + writeParquetJava(url, overwrite, preferDataset = true) + /** * Write the data quanta in this instance to a text file. Triggers execution. * @@ -1090,6 +1096,16 @@ class DataQuanta[Out: ClassTag](val operator: ElementaryOperator, outputIndex: I this.planBuilder.sinks.clear() } + private def writeParquetJava(url: String, overwrite: Boolean, preferDataset: Boolean)(implicit ev: Out =:= Record): Unit = { + val _ = ev + val sink = new ParquetSink(url, overwrite, preferDataset) + sink.setName(s"Write parquet $url") + this.connectTo(sink, 0) + this.planBuilder.sinks += sink + this.planBuilder.buildAndExecute() + this.planBuilder.sinks.clear() + } + /** * Write the data quanta in this instance to a Object file. Triggers execution. * diff --git a/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/JavaPlanBuilder.scala b/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/JavaPlanBuilder.scala index d1f9a118a..8792e80f0 100644 --- a/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/JavaPlanBuilder.scala +++ b/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/JavaPlanBuilder.scala @@ -72,6 +72,16 @@ class JavaPlanBuilder(wayangCtx: WayangContext, jobName: String) { def readParquet(url: String, projection: Array[String] = null): UnarySourceDataQuantaBuilder[UnarySourceDataQuantaBuilder[_, Record], Record] = createSourceBuilder(ParquetSource.create(url, projection))(ClassTag(classOf[Record])) + /** + * Read a parquet file and provide it as a dataset of [[Record]]s backed by Spark Datasets. + * + * @param url the URL of the Parquet file + * @param projection the projection, if any + * @return [[DataQuantaBuilder]] for the file + */ + def readParquetAsDataset(url: String, projection: Array[String] = null): UnarySourceDataQuantaBuilder[UnarySourceDataQuantaBuilder[_, Record], Record] = + createSourceBuilder(ParquetSource.create(url, projection).preferDatasetOutput(true))(ClassTag(classOf[Record])) + /** * Read a text file from a Google Cloud Storage bucket and provide it as a dataset of [[String]]s, one per line. 
* diff --git a/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/PlanBuilder.scala b/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/PlanBuilder.scala index 648755492..d441a40b8 100644 --- a/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/PlanBuilder.scala +++ b/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/PlanBuilder.scala @@ -140,6 +140,16 @@ class PlanBuilder(private[api] val wayangContext: WayangContext, private var job */ def readParquet(url: String, projection: Array[String] = null): DataQuanta[Record] = load(ParquetSource.create(url, projection)) + /** + * Read a parquet file and keep it backed by a Spark Dataset throughout execution. + * + * @param url the URL of the Parquet file + * @param projection the projection, if any + * @return [[DataQuanta]] of [[Record]]s backed by a Spark Dataset when executed on Spark + */ + def readParquetAsDataset(url: String, projection: Array[String] = null): DataQuanta[Record] = + load(ParquetSource.create(url, projection).preferDatasetOutput(true)) + /** * Read a text file from a Google Cloud Storage bucket and provide it as a dataset of [[String]]s, one per line. * diff --git a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/ParquetSink.java b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/ParquetSink.java new file mode 100644 index 000000000..ae283991f --- /dev/null +++ b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/ParquetSink.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.basic.operators; + +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.core.plan.wayangplan.UnarySink; +import org.apache.wayang.core.types.DataSetType; + +/** + * Logical operator that writes {@link Record}s into a Parquet file. 
+ */ +public class ParquetSink extends UnarySink { + + private final String outputUrl; + + private final boolean isOverwrite; + + private final boolean preferDataset; + + public ParquetSink(String outputUrl, boolean isOverwrite, boolean preferDataset, DataSetType type) { + super(type); + this.outputUrl = outputUrl; + this.isOverwrite = isOverwrite; + this.preferDataset = preferDataset; + } + + public ParquetSink(String outputUrl, boolean isOverwrite, boolean preferDataset) { + this(outputUrl, isOverwrite, preferDataset, DataSetType.createDefault(Record.class)); + } + + public String getOutputUrl() { + return this.outputUrl; + } + + public boolean isOverwrite() { + return this.isOverwrite; + } + + public boolean prefersDataset() { + return this.preferDataset; + } +} diff --git a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/ParquetSource.java b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/ParquetSource.java index 943bdbc5f..c5df31937 100644 --- a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/ParquetSource.java +++ b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/ParquetSource.java @@ -57,6 +57,8 @@ public class ParquetSource extends UnarySource { private MessageType schema; + private boolean preferDatasetOutput = false; + /** * Creates a new instance. * @@ -124,6 +126,16 @@ public ParquetSource(ParquetSource that) { this.projection = that.getProjection(); this.metadata = that.getMetadata(); this.schema = that.getSchema(); + this.preferDatasetOutput = that.preferDatasetOutput; + } + + public ParquetSource preferDatasetOutput(boolean preferDataset) { + this.preferDatasetOutput = preferDataset; + return this; + } + + public boolean isDatasetOutputPreferred() { + return this.preferDatasetOutput; } @Override diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/channels/ChannelConversions.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/channels/ChannelConversions.java index 9f8fca121..87ece4576 100644 --- a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/channels/ChannelConversions.java +++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/channels/ChannelConversions.java @@ -19,9 +19,12 @@ package org.apache.wayang.spark.channels; import org.apache.wayang.basic.channels.FileChannel; +import org.apache.wayang.basic.data.Record; import org.apache.wayang.basic.data.Tuple2; +import org.apache.wayang.core.api.Configuration; import org.apache.wayang.core.optimizer.channels.ChannelConversion; import org.apache.wayang.core.optimizer.channels.DefaultChannelConversion; +import org.apache.wayang.core.plan.executionplan.Channel; import org.apache.wayang.core.types.DataSetType; import org.apache.wayang.java.channels.CollectionChannel; import org.apache.wayang.java.platform.JavaPlatform; @@ -29,8 +32,10 @@ import org.apache.wayang.spark.operators.SparkCacheOperator; import org.apache.wayang.spark.operators.SparkCollectOperator; import org.apache.wayang.spark.operators.SparkCollectionSource; +import org.apache.wayang.spark.operators.SparkDatasetToRddOperator; import org.apache.wayang.spark.operators.SparkObjectFileSink; import org.apache.wayang.spark.operators.SparkObjectFileSource; +import org.apache.wayang.spark.operators.SparkRddToDatasetOperator; import org.apache.wayang.spark.operators.SparkTsvFileSink; import org.apache.wayang.spark.operators.SparkTsvFileSource; @@ -108,6 +113,32 @@ public 
class ChannelConversions { () -> new SparkObjectFileSource<>(DataSetType.createDefault(Void.class)) ); + public static final ChannelConversion DATASET_TO_UNCACHED_RDD = new DefaultChannelConversion( + DatasetChannel.UNCACHED_DESCRIPTOR, + RddChannel.UNCACHED_DESCRIPTOR, + () -> new SparkDatasetToRddOperator() + ); + + public static final ChannelConversion CACHED_DATASET_TO_UNCACHED_RDD = new DefaultChannelConversion( + DatasetChannel.CACHED_DESCRIPTOR, + RddChannel.UNCACHED_DESCRIPTOR, + () -> new SparkDatasetToRddOperator() + ); + + public static final ChannelConversion UNCACHED_RDD_TO_UNCACHED_DATASET = new DefaultChannelConversion( + RddChannel.UNCACHED_DESCRIPTOR, + DatasetChannel.UNCACHED_DESCRIPTOR, + ChannelConversions::createRddToDatasetOperator, + "via SparkRddToDatasetOperator" + ); + + public static final ChannelConversion CACHED_RDD_TO_UNCACHED_DATASET = new DefaultChannelConversion( + RddChannel.CACHED_DESCRIPTOR, + DatasetChannel.UNCACHED_DESCRIPTOR, + ChannelConversions::createRddToDatasetOperator, + "via SparkRddToDatasetOperator" + ); + public static Collection ALL = Arrays.asList( UNCACHED_RDD_TO_CACHED_RDD, COLLECTION_TO_BROADCAST, @@ -119,6 +150,24 @@ public class ChannelConversions { HDFS_OBJECT_FILE_TO_UNCACHED_RDD, // HDFS_TSV_TO_UNCACHED_RDD, CACHED_RDD_TO_HDFS_TSV, - UNCACHED_RDD_TO_HDFS_TSV + UNCACHED_RDD_TO_HDFS_TSV, + DATASET_TO_UNCACHED_RDD, + CACHED_DATASET_TO_UNCACHED_RDD, + UNCACHED_RDD_TO_UNCACHED_DATASET, + CACHED_RDD_TO_UNCACHED_DATASET ); + + private static SparkRddToDatasetOperator createRddToDatasetOperator(Channel sourceChannel, + Configuration configuration) { + DataSetType type = DataSetType.createDefault(Record.class); + if (sourceChannel != null) { + DataSetType sourceType = sourceChannel.getDataSetType(); + if (Record.class.isAssignableFrom(sourceType.getDataUnitType().getTypeClass())) { + @SuppressWarnings("unchecked") + DataSetType casted = (DataSetType) sourceType; + type = casted; + } + } + return new SparkRddToDatasetOperator(type); + } } diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/channels/DatasetChannel.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/channels/DatasetChannel.java new file mode 100644 index 000000000..9d3057fae --- /dev/null +++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/channels/DatasetChannel.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.wayang.spark.channels; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.plan.executionplan.Channel; +import org.apache.wayang.core.plan.wayangplan.OutputSlot; +import org.apache.wayang.core.platform.AbstractChannelInstance; +import org.apache.wayang.core.platform.ChannelDescriptor; +import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.core.platform.Executor; +import org.apache.wayang.core.util.Actions; +import org.apache.wayang.spark.execution.SparkExecutor; + +/** + * {@link Channel} implementation that transports Spark {@link Dataset}s (i.e., DataFrames). + */ +public class DatasetChannel extends Channel { + + public static final ChannelDescriptor UNCACHED_DESCRIPTOR = new ChannelDescriptor( + DatasetChannel.class, false, false + ); + + public static final ChannelDescriptor CACHED_DESCRIPTOR = new ChannelDescriptor( + DatasetChannel.class, true, true + ); + + public DatasetChannel(ChannelDescriptor descriptor, OutputSlot outputSlot) { + super(descriptor, outputSlot); + assert descriptor == UNCACHED_DESCRIPTOR || descriptor == CACHED_DESCRIPTOR; + } + + private DatasetChannel(DatasetChannel parent) { + super(parent); + } + + @Override + public DatasetChannel copy() { + return new DatasetChannel(this); + } + + @Override + public Instance createInstance(Executor executor, + OptimizationContext.OperatorContext producerOperatorContext, + int producerOutputIndex) { + return new Instance((SparkExecutor) executor, producerOperatorContext, producerOutputIndex); + } + + /** + * {@link ChannelInstance} for {@link DatasetChannel}s. + */ + public class Instance extends AbstractChannelInstance { + + private Dataset dataset; + + public Instance(SparkExecutor executor, + OptimizationContext.OperatorContext producerOperatorContext, + int producerOutputIndex) { + super(executor, producerOperatorContext, producerOutputIndex); + } + + /** + * Store a {@link Dataset} in this channel and optionally measure its cardinality. + * + * @param dataset the {@link Dataset} to store + * @param sparkExecutor the {@link SparkExecutor} handling this channel + */ + public void accept(Dataset dataset, SparkExecutor sparkExecutor) { + this.dataset = dataset; + if (this.isMarkedForInstrumentation()) { + this.measureCardinality(dataset); + } + } + + /** + * Provide the stored {@link Dataset}. 
+ * + * @return the stored {@link Dataset} + */ + public Dataset provideDataset() { + return this.dataset; + } + + @Override + protected void doDispose() { + if (this.isDatasetCached() && this.dataset != null) { + Actions.doSafe(() -> this.dataset.unpersist()); + this.dataset = null; + } + } + + private void measureCardinality(Dataset dataset) { + this.setMeasuredCardinality(dataset.count()); + } + + private boolean isDatasetCached() { + return this.getChannel().isReusable(); + } + + @Override + public DatasetChannel getChannel() { + return DatasetChannel.this; + } + } +} diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/mapping/Mappings.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/mapping/Mappings.java index bc42956f9..1e3642a9a 100644 --- a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/mapping/Mappings.java +++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/mapping/Mappings.java @@ -60,7 +60,8 @@ public class Mappings { new SampleMapping(), new ZipWithIdMapping(), new KafkaTopicSinkMapping(), - new KafkaTopicSourceMapping() + new KafkaTopicSourceMapping(), + new ParquetSinkMapping() ); diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/mapping/ParquetSinkMapping.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/mapping/ParquetSinkMapping.java new file mode 100644 index 000000000..b5b9b636f --- /dev/null +++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/mapping/ParquetSinkMapping.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.wayang.spark.mapping; + +import org.apache.wayang.basic.operators.ParquetSink; +import org.apache.wayang.core.mapping.Mapping; +import org.apache.wayang.core.mapping.OperatorPattern; +import org.apache.wayang.core.mapping.PlanTransformation; +import org.apache.wayang.core.mapping.ReplacementSubplanFactory; +import org.apache.wayang.core.mapping.SubplanPattern; +import org.apache.wayang.spark.operators.SparkParquetSink; +import org.apache.wayang.spark.platform.SparkPlatform; + +import java.util.Collection; +import java.util.Collections; + +public class ParquetSinkMapping implements Mapping { + + @Override + public Collection getTransformations() { + return Collections.singleton(new PlanTransformation( + this.createSubplanPattern(), + this.createReplacementSubplanFactory(), + SparkPlatform.getInstance() + )); + } + + private SubplanPattern createSubplanPattern() { + OperatorPattern operatorPattern = new OperatorPattern<>( + "sink", new ParquetSink("", true, true), false + ); + return SubplanPattern.createSingleton(operatorPattern); + } + + private ReplacementSubplanFactory createReplacementSubplanFactory() { + return new ReplacementSubplanFactory.OfSingleOperators( + (matchedOperator, epoch) -> new SparkParquetSink(matchedOperator).at(epoch) + ); + } +} diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkDatasetToRddOperator.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkDatasetToRddOperator.java new file mode 100644 index 000000000..16222cabf --- /dev/null +++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkDatasetToRddOperator.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.wayang.spark.operators; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.optimizer.costs.LoadProfileEstimator; +import org.apache.wayang.core.optimizer.costs.LoadProfileEstimators; +import org.apache.wayang.core.plan.wayangplan.ExecutionOperator; +import org.apache.wayang.core.plan.wayangplan.UnaryToUnaryOperator; +import org.apache.wayang.core.platform.ChannelDescriptor; +import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.core.platform.lineage.ExecutionLineageNode; +import org.apache.wayang.core.types.DataSetType; +import org.apache.wayang.core.util.Tuple; +import org.apache.wayang.spark.channels.DatasetChannel; +import org.apache.wayang.spark.channels.RddChannel; +import org.apache.wayang.spark.execution.SparkExecutor; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.Optional; + +/** + * Conversion operator from {@link DatasetChannel} to {@link RddChannel}. + */ +public class SparkDatasetToRddOperator extends UnaryToUnaryOperator implements SparkExecutionOperator { + + public SparkDatasetToRddOperator() { + super(DataSetType.createDefault(Row.class), DataSetType.createDefault(Row.class), false); + } + + @Override + public Tuple, Collection> evaluate(ChannelInstance[] inputs, + ChannelInstance[] outputs, + SparkExecutor sparkExecutor, + OptimizationContext.OperatorContext operatorContext) { + DatasetChannel.Instance input = (DatasetChannel.Instance) inputs[0]; + RddChannel.Instance output = (RddChannel.Instance) outputs[0]; + + Dataset dataset = input.provideDataset(); + output.accept(dataset.toJavaRDD(), sparkExecutor); + + return ExecutionOperator.modelLazyExecution(inputs, outputs, operatorContext); + } + + @Override + public List getSupportedInputChannels(int index) { + return Arrays.asList(DatasetChannel.UNCACHED_DESCRIPTOR, DatasetChannel.CACHED_DESCRIPTOR); + } + + @Override + public List getSupportedOutputChannels(int index) { + return Collections.singletonList(RddChannel.UNCACHED_DESCRIPTOR); + } + + @Override + public boolean containsAction() { + return false; + } + + @Override + public String getLoadProfileEstimatorConfigurationKey() { + return "wayang.spark.dataset-to-rdd.load"; + } + + @Override + public Optional createLoadProfileEstimator(Configuration configuration) { + return Optional.ofNullable( + LoadProfileEstimators.createFromSpecification( + this.getLoadProfileEstimatorConfigurationKey(), configuration + ) + ); + } +} diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkParquetSink.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkParquetSink.java new file mode 100644 index 000000000..587033156 --- /dev/null +++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkParquetSink.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.spark.operators; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.basic.operators.ParquetSink; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.plan.wayangplan.ExecutionOperator; +import org.apache.wayang.core.platform.ChannelDescriptor; +import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.core.platform.lineage.ExecutionLineageNode; +import org.apache.wayang.core.util.Tuple; +import org.apache.wayang.spark.channels.DatasetChannel; +import org.apache.wayang.spark.channels.RddChannel; +import org.apache.wayang.spark.execution.SparkExecutor; +import org.apache.wayang.spark.util.DatasetConverters; + +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +/** + * Writes records to Parquet using Spark. + */ +public class SparkParquetSink extends ParquetSink implements SparkExecutionOperator { + + private final SaveMode saveMode; + + public SparkParquetSink(ParquetSink that) { + super(that.getOutputUrl(), that.isOverwrite(), that.prefersDataset(), that.getType()); + this.saveMode = that.isOverwrite() ? SaveMode.Overwrite : SaveMode.ErrorIfExists; + } + + @Override + public Tuple, Collection> evaluate(ChannelInstance[] inputs, + ChannelInstance[] outputs, + SparkExecutor sparkExecutor, + OptimizationContext.OperatorContext operatorContext) { + final Dataset dataset = this.obtainDataset(inputs[0], sparkExecutor); + dataset.write().mode(this.saveMode).parquet(this.getOutputUrl()); + return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext); + } + + private Dataset obtainDataset(ChannelInstance input, SparkExecutor sparkExecutor) { + if (input instanceof DatasetChannel.Instance) { + return ((DatasetChannel.Instance) input).provideDataset(); + } + JavaRDD rdd = ((RddChannel.Instance) input).provideRdd(); + return DatasetConverters.recordsToDataset(rdd, this.getType(), sparkExecutor.ss); + } + + @Override + public List getSupportedInputChannels(int index) { + if (this.prefersDataset()) { + return Arrays.asList(DatasetChannel.UNCACHED_DESCRIPTOR, DatasetChannel.CACHED_DESCRIPTOR); + } + return Arrays.asList(DatasetChannel.UNCACHED_DESCRIPTOR, DatasetChannel.CACHED_DESCRIPTOR, + RddChannel.UNCACHED_DESCRIPTOR, RddChannel.CACHED_DESCRIPTOR); + } + + @Override + public List getSupportedOutputChannels(int index) { + throw new UnsupportedOperationException("This operator has no outputs."); + } + + @Override + public boolean containsAction() { + return true; + } +} diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkParquetSource.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkParquetSource.java index adae5a0d7..315acbb2c 100644 --- a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkParquetSource.java +++ 
b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkParquetSource.java @@ -29,6 +29,7 @@ import org.apache.wayang.core.platform.ChannelInstance; import org.apache.wayang.core.platform.lineage.ExecutionLineageNode; import org.apache.wayang.core.util.Tuple; +import org.apache.wayang.spark.channels.DatasetChannel; import org.apache.wayang.spark.channels.RddChannel; import org.apache.wayang.spark.execution.SparkExecutor; @@ -64,8 +65,6 @@ public Tuple, Collection> eval assert inputs.length == this.getNumInputs(); assert outputs.length == this.getNumOutputs(); - RddChannel.Instance output = (RddChannel.Instance) outputs[0]; - Dataset table = sparkExecutor.ss.read().parquet(this.getInputUrl().trim()); // Reads a projection, if any (loads the complete file if no projection defined) @@ -74,16 +73,6 @@ public Tuple, Collection> eval table = table.selectExpr(projection); } - // Wrap dataset into a JavaRDD and convert Row's to Record's - JavaRDD rdd = table.toJavaRDD().map(row -> { - List values = IntStream.range(0, row.size()) - .mapToObj(row::get) - .collect(Collectors.toList()); - return new Record(values); - }); - this.name(rdd); - output.accept(rdd, sparkExecutor); - ExecutionLineageNode prepareLineageNode = new ExecutionLineageNode(operatorContext); prepareLineageNode.add(LoadProfileEstimators.createFromSpecification( "wayang.spark.parquetsource.load.prepare", sparkExecutor.getConfiguration() @@ -92,7 +81,25 @@ public Tuple, Collection> eval mainLineageNode.add(LoadProfileEstimators.createFromSpecification( "wayang.spark.parquetsource.load.main", sparkExecutor.getConfiguration() )); - output.getLineage().addPredecessor(mainLineageNode); + + if (this.isDatasetOutputPreferred() && outputs[0] instanceof DatasetChannel.Instance) { + DatasetChannel.Instance datasetOutput = + (DatasetChannel.Instance) outputs[0]; + datasetOutput.accept(table, sparkExecutor); + datasetOutput.getLineage().addPredecessor(mainLineageNode); + } else { + RddChannel.Instance output = (RddChannel.Instance) outputs[0]; + // Wrap dataset into a JavaRDD and convert Row's to Record's + JavaRDD rdd = table.toJavaRDD().map(row -> { + List values = IntStream.range(0, row.size()) + .mapToObj(row::get) + .collect(Collectors.toList()); + return new Record(values); + }); + this.name(rdd); + output.accept(rdd, sparkExecutor); + output.getLineage().addPredecessor(mainLineageNode); + } return prepareLineageNode.collectAndMark(); } @@ -109,6 +116,9 @@ public List getSupportedInputChannels(int index) { @Override public List getSupportedOutputChannels(int index) { + if (this.isDatasetOutputPreferred()) { + return Arrays.asList(DatasetChannel.UNCACHED_DESCRIPTOR, RddChannel.UNCACHED_DESCRIPTOR); + } return Collections.singletonList(RddChannel.UNCACHED_DESCRIPTOR); } diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkRddToDatasetOperator.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkRddToDatasetOperator.java new file mode 100644 index 000000000..ecde6033a --- /dev/null +++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkRddToDatasetOperator.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.spark.operators; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.optimizer.costs.LoadProfileEstimator; +import org.apache.wayang.core.optimizer.costs.LoadProfileEstimators; +import org.apache.wayang.core.plan.wayangplan.ExecutionOperator; +import org.apache.wayang.core.plan.wayangplan.UnaryToUnaryOperator; +import org.apache.wayang.core.platform.ChannelDescriptor; +import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.core.platform.lineage.ExecutionLineageNode; +import org.apache.wayang.core.types.DataSetType; +import org.apache.wayang.core.util.Tuple; +import org.apache.wayang.spark.channels.DatasetChannel; +import org.apache.wayang.spark.channels.RddChannel; +import org.apache.wayang.spark.execution.SparkExecutor; +import org.apache.wayang.spark.util.DatasetConverters; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.Optional; + +/** + * Conversion operator from {@link RddChannel} to {@link DatasetChannel}. 
+ */ +public class SparkRddToDatasetOperator extends UnaryToUnaryOperator implements SparkExecutionOperator { + + public SparkRddToDatasetOperator() { + this(DataSetType.createDefault(Record.class)); + } + + public SparkRddToDatasetOperator(DataSetType type) { + super(type, type, false); + } + + @Override + public Tuple, Collection> evaluate(ChannelInstance[] inputs, + ChannelInstance[] outputs, + SparkExecutor sparkExecutor, + OptimizationContext.OperatorContext operatorContext) { + RddChannel.Instance input = (RddChannel.Instance) inputs[0]; + DatasetChannel.Instance output = (DatasetChannel.Instance) outputs[0]; + + JavaRDD records = input.provideRdd(); + Dataset dataset = DatasetConverters.recordsToDataset(records, this.getInputType(), sparkExecutor.ss); + output.accept(dataset, sparkExecutor); + + return ExecutionOperator.modelLazyExecution(inputs, outputs, operatorContext); + } + + @Override + public List getSupportedInputChannels(int index) { + return Arrays.asList(RddChannel.UNCACHED_DESCRIPTOR, RddChannel.CACHED_DESCRIPTOR); + } + + @Override + public List getSupportedOutputChannels(int index) { + return Arrays.asList(DatasetChannel.UNCACHED_DESCRIPTOR, DatasetChannel.CACHED_DESCRIPTOR); + } + + @Override + public boolean containsAction() { + return false; + } + + @Override + public String getLoadProfileEstimatorConfigurationKey() { + return "wayang.spark.rdd-to-dataset.load"; + } + + @Override + public Optional createLoadProfileEstimator(Configuration configuration) { + return Optional.ofNullable( + LoadProfileEstimators.createFromSpecification( + this.getLoadProfileEstimatorConfigurationKey(), configuration + ) + ); + } +} diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/util/DatasetConverters.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/util/DatasetConverters.java new file mode 100644 index 000000000..bf7bd7ff7 --- /dev/null +++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/util/DatasetConverters.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.wayang.spark.util; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.basic.types.RecordType; +import org.apache.wayang.core.types.DataSetType; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.sql.Timestamp; +import java.util.ArrayList; +import java.util.List; + +/** + * Utility methods to convert {@link Record}-backed RDDs into Spark {@link Dataset}s. + */ +public final class DatasetConverters { + + private static final int SCHEMA_SAMPLE_SIZE = 50; + + private DatasetConverters() { + } + + /** + * Convert an RDD of {@link Record}s into a Spark {@link Dataset}. + * + * @param records the records to convert + * @param dataSetType type information about the records (field names etc.) + * @param sparkSession the {@link SparkSession} used to create the {@link Dataset} + * @return a {@link Dataset} view over {@code records} + */ + public static Dataset recordsToDataset(JavaRDD records, + DataSetType dataSetType, + SparkSession sparkSession) { + StructType schema = deriveSchema(records, dataSetType); + JavaRDD rows = records.map(record -> RowFactory.create(record.getValues())); + return sparkSession.createDataFrame(rows, schema); + } + + private static StructType deriveSchema(JavaRDD rdd, DataSetType dataSetType) { + List samples = rdd.take(SCHEMA_SAMPLE_SIZE); + RecordType recordType = extractRecordType(dataSetType); + String[] fieldNames = resolveFieldNames(samples, recordType); + + List fields = new ArrayList<>(fieldNames.length); + for (int column = 0; column < fieldNames.length; column++) { + DataType dataType = inferColumnType(samples, column); + fields.add(DataTypes.createStructField(fieldNames[column], dataType, true)); + } + return new StructType(fields.toArray(new StructField[0])); + } + + private static RecordType extractRecordType(DataSetType dataSetType) { + if (dataSetType == null || dataSetType.getDataUnitType() == null) { + return null; + } + if (dataSetType.getDataUnitType() instanceof RecordType) { + return (RecordType) dataSetType.getDataUnitType(); + } + if (dataSetType.getDataUnitType().toBasicDataUnitType() instanceof RecordType) { + return (RecordType) dataSetType.getDataUnitType().toBasicDataUnitType(); + } + return null; + } + + private static String[] resolveFieldNames(List samples, RecordType recordType) { + if (recordType != null && recordType.getFieldNames() != null && recordType.getFieldNames().length > 0) { + return recordType.getFieldNames(); + } + Record sample = samples.isEmpty() ? null : samples.get(0); + int numFields = sample == null ? 
0 : sample.size(); + String[] names = new String[numFields]; + for (int index = 0; index < numFields; index++) { + names[index] = "field" + index; + } + return names; + } + + private static DataType inferColumnType(List samples, int columnIndex) { + for (Record sample : samples) { + if (sample == null || columnIndex >= sample.size()) { + continue; + } + Object value = sample.getField(columnIndex); + if (value == null) { + continue; + } + DataType dataType = toSparkType(value); + if (dataType != null) { + return dataType; + } + } + return DataTypes.StringType; + } + + private static DataType toSparkType(Object value) { + if (value instanceof String || value instanceof Character) { + return DataTypes.StringType; + } else if (value instanceof Integer) { + return DataTypes.IntegerType; + } else if (value instanceof Long) { + return DataTypes.LongType; + } else if (value instanceof Short) { + return DataTypes.ShortType; + } else if (value instanceof Byte) { + return DataTypes.ByteType; + } else if (value instanceof Double) { + return DataTypes.DoubleType; + } else if (value instanceof Float) { + return DataTypes.FloatType; + } else if (value instanceof Boolean) { + return DataTypes.BooleanType; + } else if (value instanceof Timestamp) { + return DataTypes.TimestampType; + } else if (value instanceof java.sql.Date) { + return DataTypes.DateType; + } else if (value instanceof byte[]) { + return DataTypes.BinaryType; + } else if (value instanceof BigDecimal) { + BigDecimal decimal = (BigDecimal) value; + int precision = Math.min(38, Math.max(decimal.precision(), decimal.scale())); + int scale = Math.max(0, decimal.scale()); + return DataTypes.createDecimalType(precision, scale); + } else if (value instanceof BigInteger) { + BigInteger bigInteger = (BigInteger) value; + int precision = Math.min(38, bigInteger.toString().length()); + return DataTypes.createDecimalType(precision, 0); + } + return null; + } +} diff --git a/wayang-platforms/wayang-spark/src/main/resources/wayang-spark-defaults.properties b/wayang-platforms/wayang-spark/src/main/resources/wayang-spark-defaults.properties index e96446f91..0ffe1a449 100644 --- a/wayang-platforms/wayang-spark/src/main/resources/wayang-spark-defaults.properties +++ b/wayang-platforms/wayang-spark/src/main/resources/wayang-spark-defaults.properties @@ -198,6 +198,28 @@ wayang.spark.sort.load = {\ "ru":"${wayang:logGrowth(0.1, 0.1, 1000000, in0)}"\ } +wayang.spark.dataset-to-rdd.load = {\ + "in":1, "out":1,\ + "cpu":"${900*in0 + 250000}",\ + "ram":"10000",\ + "disk":"0",\ + "net":"0",\ + "p":0.9,\ + "overhead":0,\ + "ru":"${wayang:logGrowth(0.1, 0.1, 1000000, in0)}"\ +} + +wayang.spark.rdd-to-dataset.load = {\ + "in":1, "out":1,\ + "cpu":"${1200*in0 + 750000}",\ + "ram":"10000",\ + "disk":"0",\ + "net":"0",\ + "p":0.9,\ + "overhead":0,\ + "ru":"${wayang:logGrowth(0.1, 0.1, 1000000, in0)}"\ +} + wayang.spark.globalreduce.load.template = {\ "type":"mathex", "in":1, "out":1,\ "cpu":"?*in0 + ?"\ diff --git a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/DatasetChannelTest.java b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/DatasetChannelTest.java new file mode 100644 index 000000000..38450e5f3 --- /dev/null +++ b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/DatasetChannelTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.spark.operators; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.wayang.spark.channels.DatasetChannel; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class DatasetChannelTest extends SparkOperatorTestBase { + + @Test + void acceptAndProvideDataset() { + List rows = sampleRows(); + Dataset dataset = createDataset(rows); + DatasetChannel.Instance instance = createDatasetChannelInstance(false); + + instance.accept(dataset, this.sparkExecutor); + + assertEquals(rows, instance.provideDataset().collectAsList()); + } + + @Test + void instrumentationCountsRows() { + List rows = sampleRows(); + Dataset dataset = createDataset(rows); + DatasetChannel.Instance instance = createDatasetChannelInstance(true); + + instance.accept(dataset, this.sparkExecutor); + + assertTrue(instance.getMeasuredCardinality().isPresent()); + assertEquals(rows.size(), instance.getMeasuredCardinality().getAsLong()); + } + + @Test + void noInstrumentationLeavesCardinalityEmpty() { + DatasetChannel.Instance instance = createDatasetChannelInstance(false); + + instance.accept(createDataset(sampleRows()), this.sparkExecutor); + + assertFalse(instance.getMeasuredCardinality().isPresent()); + } + + private DatasetChannel.Instance createDatasetChannelInstance(boolean instrumented) { + DatasetChannel channel = (DatasetChannel) DatasetChannel.UNCACHED_DESCRIPTOR + .createChannel(null, this.configuration); + if (instrumented) { + channel.markForInstrumentation(); + } + return (DatasetChannel.Instance) channel.createInstance(this.sparkExecutor, null, -1); + } + + private Dataset createDataset(List rows) { + return this.sparkExecutor.ss.createDataFrame(rows, sampleSchema()); + } + + private StructType sampleSchema() { + StructField[] fields = new StructField[]{ + DataTypes.createStructField("name", DataTypes.StringType, false), + DataTypes.createStructField("age", DataTypes.IntegerType, false) + }; + return new StructType(fields); + } + + private List sampleRows() { + return Arrays.asList( + RowFactory.create("alice", 30), + RowFactory.create("bob", 25) + ); + } +} diff --git a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/DatasetTestUtils.java b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/DatasetTestUtils.java new file mode 100644 index 000000000..e6736464e --- /dev/null +++ 
b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/DatasetTestUtils.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.spark.operators; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.spark.execution.SparkExecutor; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.List; +import java.util.Comparator; + +final class DatasetTestUtils { + + private DatasetTestUtils() { + } + + static Dataset createSampleDataset(SparkExecutor sparkExecutor) { + return sparkExecutor.ss.createDataFrame(sampleRows(), sampleSchema()); + } + + static List sampleRows() { + return Arrays.asList( + RowFactory.create("alice", 30), + RowFactory.create("bob", 25), + RowFactory.create("carol", 41) + ); + } + + static List sampleRecords() { + return Arrays.asList( + new Record("alice", 30), + new Record("bob", 25), + new Record("carol", 41) + ); + } + + static StructType sampleSchema() { + StructField[] fields = new StructField[]{ + DataTypes.createStructField("name", DataTypes.StringType, false), + DataTypes.createStructField("age", DataTypes.IntegerType, false) + }; + return new StructType(fields); + } + + static void deleteRecursively(Path directory) throws IOException { + if (Files.notExists(directory)) { + return; + } + Files.walk(directory) + .sorted(Comparator.reverseOrder()) + .forEach(path -> { + try { + Files.deleteIfExists(path); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } +} diff --git a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkDatasetToRddOperatorTest.java b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkDatasetToRddOperatorTest.java new file mode 100644 index 000000000..29ec75eec --- /dev/null +++ b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkDatasetToRddOperatorTest.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.spark.operators; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.spark.channels.RddChannel; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class SparkDatasetToRddOperatorTest extends SparkOperatorTestBase { + + @Test + void testConversionPreservesRows() { + Dataset dataset = DatasetTestUtils.createSampleDataset(this.sparkExecutor); + SparkDatasetToRddOperator operator = new SparkDatasetToRddOperator(); + + ChannelInstance[] inputs = new ChannelInstance[]{this.createDatasetChannelInstance(dataset)}; + ChannelInstance[] outputs = new ChannelInstance[]{this.createRddChannelInstance()}; + + this.evaluate(operator, inputs, outputs); + + List rows = ((RddChannel.Instance) outputs[0]).provideRdd().collect(); + assertEquals(dataset.collectAsList(), rows); + } +} diff --git a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkOperatorTestBase.java b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkOperatorTestBase.java index 222954da9..f2d8fb68f 100644 --- a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkOperatorTestBase.java +++ b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkOperatorTestBase.java @@ -19,26 +19,29 @@ package org.apache.wayang.spark.operators; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.wayang.core.api.Configuration; import org.apache.wayang.core.api.Job; import org.apache.wayang.core.optimizer.DefaultOptimizationContext; import org.apache.wayang.core.optimizer.OptimizationContext; import org.apache.wayang.core.plan.wayangplan.Operator; +import org.apache.wayang.core.plan.wayangplan.WayangPlan; import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.java.channels.CollectionChannel; import org.apache.wayang.core.platform.CrossPlatformExecutor; import org.apache.wayang.core.profiling.FullInstrumentationStrategy; -import org.apache.wayang.java.channels.CollectionChannel; import org.apache.wayang.spark.channels.RddChannel; +import org.apache.wayang.spark.channels.DatasetChannel; import org.apache.wayang.spark.execution.SparkExecutor; import org.apache.wayang.spark.platform.SparkPlatform; import org.apache.wayang.spark.test.ChannelFactory; +import org.apache.wayang.core.api.WayangContext; import org.junit.jupiter.api.BeforeEach; +import java.lang.reflect.Field; import java.util.Collection; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - /** * Test base for {@link SparkExecutionOperator} tests. 
*/ @@ -48,22 +51,32 @@ class SparkOperatorTestBase { protected SparkExecutor sparkExecutor; + private Job job; + @BeforeEach void setUp() { - this.configuration = new Configuration(); - if(sparkExecutor == null) - this.sparkExecutor = (SparkExecutor) SparkPlatform.getInstance().getExecutorFactory().create(this.mockJob()); + WayangContext context = new WayangContext(new Configuration()); + this.job = context.createJob("spark-operator-test", new WayangPlan()); + this.configuration = this.job.getConfiguration(); + this.ensureCrossPlatformExecutor(); + this.sparkExecutor = (SparkExecutor) SparkPlatform.getInstance().getExecutorFactory().create(this.job); } - Job mockJob() { - final Job job = mock(Job.class); - when(job.getConfiguration()).thenReturn(this.configuration); - when(job.getCrossPlatformExecutor()).thenReturn(new CrossPlatformExecutor(job, new FullInstrumentationStrategy())); - return job; + private void ensureCrossPlatformExecutor() { + try { + Field field = Job.class.getDeclaredField("crossPlatformExecutor"); + field.setAccessible(true); + if (field.get(this.job) == null) { + CrossPlatformExecutor executor = new CrossPlatformExecutor(this.job, new FullInstrumentationStrategy()); + field.set(this.job, executor); + } + } catch (ReflectiveOperationException e) { + throw new RuntimeException("Failed to initialize CrossPlatformExecutor for tests.", e); + } } protected OptimizationContext.OperatorContext createOperatorContext(Operator operator) { - OptimizationContext optimizationContext = new DefaultOptimizationContext(mockJob()); + OptimizationContext optimizationContext = new DefaultOptimizationContext(this.job); return optimizationContext.addOneTimeOperator(operator); } @@ -81,6 +94,14 @@ RddChannel.Instance createRddChannelInstance(Collection collection) { return ChannelFactory.createRddChannelInstance(collection, this.sparkExecutor, this.configuration); } + DatasetChannel.Instance createDatasetChannelInstance() { + return ChannelFactory.createDatasetChannelInstance(this.configuration); + } + + DatasetChannel.Instance createDatasetChannelInstance(Dataset dataset) { + return ChannelFactory.createDatasetChannelInstance(dataset, this.sparkExecutor, this.configuration); + } + protected CollectionChannel.Instance createCollectionChannelInstance() { return ChannelFactory.createCollectionChannelInstance(this.configuration); } diff --git a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkParquetSinkTest.java b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkParquetSinkTest.java new file mode 100644 index 000000000..eb7db83b1 --- /dev/null +++ b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkParquetSinkTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.spark.operators; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.wayang.basic.operators.ParquetSink; +import org.apache.wayang.core.platform.ChannelInstance; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class SparkParquetSinkTest extends SparkOperatorTestBase { + + @Test + void writesDatasetToParquet() throws IOException { + Dataset dataset = DatasetTestUtils.createSampleDataset(this.sparkExecutor); + Path outputDir = Files.createTempDirectory("wayang-dataset-parquet-sink"); + try { + SparkParquetSink sink = new SparkParquetSink(new ParquetSink(outputDir.toString(), true, true)); + + ChannelInstance[] inputs = new ChannelInstance[]{this.createDatasetChannelInstance(dataset)}; + ChannelInstance[] outputs = new ChannelInstance[0]; + + this.evaluate(sink, inputs, outputs); + + Dataset stored = this.sparkExecutor.ss.read().parquet(outputDir.toString()); + assertEquals(dataset.collectAsList(), stored.collectAsList()); + } finally { + DatasetTestUtils.deleteRecursively(outputDir); + } + } + + @Test + void writesRddToParquet() throws IOException { + Path outputDir = Files.createTempDirectory("wayang-rdd-parquet-sink"); + try { + SparkParquetSink sink = new SparkParquetSink(new ParquetSink(outputDir.toString(), true, false)); + ChannelInstance[] inputs = new ChannelInstance[]{this.createRddChannelInstance(DatasetTestUtils.sampleRecords())}; + ChannelInstance[] outputs = new ChannelInstance[0]; + + this.evaluate(sink, inputs, outputs); + + Dataset stored = this.sparkExecutor.ss.read().parquet(outputDir.toString()); + assertEquals(DatasetTestUtils.sampleRows(), stored.collectAsList()); + } finally { + DatasetTestUtils.deleteRecursively(outputDir); + } + } +} diff --git a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkParquetSourceDatasetOutputTest.java b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkParquetSourceDatasetOutputTest.java new file mode 100644 index 000000000..d4249aed0 --- /dev/null +++ b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkParquetSourceDatasetOutputTest.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.wayang.spark.operators; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.spark.channels.DatasetChannel; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class SparkParquetSourceDatasetOutputTest extends SparkOperatorTestBase { + + @Test + void producesDatasetChannel() throws IOException { + Dataset dataset = DatasetTestUtils.createSampleDataset(this.sparkExecutor); + Path inputPath = Files.createTempDirectory("wayang-parquet-source"); + try { + dataset.write().mode(SaveMode.Overwrite).parquet(inputPath.toString()); + + SparkParquetSource source = new SparkParquetSource(inputPath.toString(), null); + source.preferDatasetOutput(true); + + ChannelInstance[] inputs = new ChannelInstance[0]; + ChannelInstance[] outputs = new ChannelInstance[]{this.createDatasetChannelInstance()}; + + this.evaluate(source, inputs, outputs); + + Dataset result = ((DatasetChannel.Instance) outputs[0]).provideDataset(); + assertEquals(dataset.collectAsList(), result.collectAsList()); + } finally { + DatasetTestUtils.deleteRecursively(inputPath); + } + } +} diff --git a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkRddToDatasetOperatorTest.java b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkRddToDatasetOperatorTest.java new file mode 100644 index 000000000..0762e55e0 --- /dev/null +++ b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkRddToDatasetOperatorTest.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.wayang.spark.operators; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.spark.channels.DatasetChannel; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class SparkRddToDatasetOperatorTest extends SparkOperatorTestBase { + + @Test + void convertsRecordsToDataset() { + SparkRddToDatasetOperator operator = new SparkRddToDatasetOperator(); + + ChannelInstance[] inputs = new ChannelInstance[]{this.createRddChannelInstance(DatasetTestUtils.sampleRecords())}; + ChannelInstance[] outputs = new ChannelInstance[]{this.createDatasetChannelInstance()}; + + this.evaluate(operator, inputs, outputs); + + Dataset dataset = ((DatasetChannel.Instance) outputs[0]).provideDataset(); + assertEquals(DatasetTestUtils.sampleRows(), dataset.collectAsList()); + } +} diff --git a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/test/ChannelFactory.java b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/test/ChannelFactory.java index 3c430d116..565cb0619 100644 --- a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/test/ChannelFactory.java +++ b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/test/ChannelFactory.java @@ -23,6 +23,9 @@ import org.apache.wayang.core.platform.ChannelDescriptor; import org.apache.wayang.core.util.WayangCollections; import org.apache.wayang.java.channels.CollectionChannel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.wayang.spark.channels.DatasetChannel; import org.apache.wayang.spark.channels.RddChannel; import org.apache.wayang.spark.execution.SparkExecutor; import org.junit.jupiter.api.BeforeEach; @@ -73,4 +76,18 @@ public static CollectionChannel.Instance createCollectionChannelInstance(Collect return instance; } + public static DatasetChannel.Instance createDatasetChannelInstance(Configuration configuration) { + return (DatasetChannel.Instance) DatasetChannel.UNCACHED_DESCRIPTOR + .createChannel(null, configuration) + .createInstance(sparkExecutor, null, -1); + } + + public static DatasetChannel.Instance createDatasetChannelInstance(Dataset dataset, + SparkExecutor sparkExecutor, + Configuration configuration) { + DatasetChannel.Instance instance = createDatasetChannelInstance(configuration); + instance.accept(dataset, sparkExecutor); + return instance; + } + } From 5f42f3c5232f865a0d8c3812ea35844acb126402 Mon Sep 17 00:00:00 2001 From: 2pk03 Date: Mon, 15 Dec 2025 15:56:07 +0100 Subject: [PATCH 2/5] Update readme / add documentation --- README.md | 14 ++++++++++++++ guides/spark-datasets.md | 42 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 guides/spark-datasets.md diff --git a/README.md b/README.md index c9cf4b472..a3d9560cd 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,20 @@ Apache Wayang provides a flexible architecture which enables easy addition of ne For a quick guide on how to run WordCount see [here](guides/tutorial.md). +### Spark Dataset / DataFrame pipelines + +Wayang’s Spark platform can now execute end-to-end pipelines on Spark `Dataset[Row]` (aka DataFrames). This is particularly useful when working with lakehouse-style storage (Parquet/Delta) or when you want to plug Spark ML stages into a Wayang plan without repeatedly falling back to RDDs. + +To build a Dataset-backed pipeline: + +1. 
**Use the Dataset-aware plan builder APIs.**
+   - `PlanBuilder.readParquetAsDataset(...)` (or the Java equivalent) reads Parquet files directly into a Dataset channel.
+   - `DataQuanta.writeParquetAsDataset(...)` writes a Dataset channel without converting it back to an RDD.
+2. **Keep operators dataset-compatible.** Most operators continue to work unchanged; if an operator explicitly prefers RDDs, Wayang will insert the necessary conversions automatically (at an additional cost). Custom operators can expose `DatasetChannel` descriptors to stay in the dataframe world.
+3. **Let the optimizer do the rest.** The optimizer now assigns a higher cost to Dataset↔RDD conversions, so once you opt into Dataset sources/sinks the plan will stay in Dataset form by default.
+
+No extra flags are required—just opt into the Dataset-based APIs where you want dataframe semantics. If you see unexpected conversions in your execution plan, check that the upstream/downstream operators you use can consume `DatasetChannel`s; otherwise Wayang will insert a conversion operator for you.
+
 ## Quick Guide for Developing with Wayang

 For a quick guide on how to use Wayang in your Java/Scala project see [here](guides/develop-with-Wayang.md).
diff --git a/guides/spark-datasets.md b/guides/spark-datasets.md
new file mode 100644
index 000000000..718807e03
--- /dev/null
+++ b/guides/spark-datasets.md
@@ -0,0 +1,42 @@
+---
+title: Spark Dataset pipelines
+description: How to build Wayang jobs that stay on Spark Datasets/DataFrames from source to sink.
+---
+
+Wayang’s Spark backend can now run entire pipelines on Spark `Dataset[Row]` (a.k.a. DataFrames). Use this mode when you ingest from lakehouse formats (Parquet/Delta), interoperate with Spark ML stages, or simply prefer schema-aware processing. This guide explains how to opt in.
+
+## When to use Dataset channels
+
+- **Lakehouse storage:** Reading Parquet/Delta directly into datasets avoids repeated schema inference and keeps Spark’s optimized Parquet reader in play.
+- **Spark ML:** Our ML operators already convert RDDs into DataFrames internally. Feeding them a dataset channel skips that conversion and preserves column names.
+- **Federated pipelines:** You can mix dataset-backed stages on Spark with other platforms; Wayang will insert conversions only when strictly necessary.
+
+## Enable Dataset sources and sinks
+
+1. **Plan builder APIs:**
+   - `PlanBuilder.readParquetAsDataset(...)` (Scala/Java) loads Parquet files into a `DatasetChannel` instead of an `RddChannel`.
+   - `DataQuanta.writeParquetAsDataset(...)` writes a dataset back to Parquet without converting to an RDD first.
+2. **Prefer dataset-friendly operators:** Most unary/binary operators accept either channel type, but custom operators can advertise dataset descriptors explicitly. See `DatasetChannel` in `wayang-platforms/wayang-spark` for details.
+3. **Let the optimizer keep it:** The optimizer now assigns costs to Dataset↔RDD conversions, so once your plan starts with a dataset channel it will stay in dataset form unless an operator demands an RDD.
+
+## Mixing with RDD operators
+
+If a stage only supports RDDs, Wayang inserts conversion operators automatically:
+
+- `SparkRddToDatasetOperator` converts an RDD of `org.apache.wayang.basic.data.Record` into a Spark `Dataset[Row]` (using sampled schema inference or `RecordType`).
+- `SparkDatasetToRddOperator` turns a `Dataset[Row]` back into a `JavaRDD` of `Record` objects.
+
+Both conversions carry non-trivial load profiles.
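+To make the round trip concrete, here is a minimal sketch of a pipeline that starts and ends on Dataset channels but routes one stage through a plain RDD-style operator. It is illustrative only: the single-path overloads of `readParquetAsDataset`/`writeParquetAsDataset`, the paths, the job name, and the `getField(0)` predicate are assumptions, so check `PlanBuilder.scala` and `DataQuanta.scala` in this patch for the exact signatures.
+
+```scala
+import org.apache.wayang.api._
+import org.apache.wayang.core.api.WayangContext
+import org.apache.wayang.spark.Spark
+
+// Set up a Wayang context with the Spark backend.
+val wayangContext = new WayangContext().withPlugin(Spark.basicPlugin())
+val planBuilder   = new PlanBuilder(wayangContext).withJobName("dataset-roundtrip")
+
+planBuilder
+  .readParquetAsDataset("hdfs://warehouse/events.parquet")        // starts on a DatasetChannel (hypothetical path)
+  .filter(record => record.getField(0) != null)                   // RDD-style operator: forces a Dataset -> RDD conversion
+  .writeParquetAsDataset("hdfs://warehouse/events_clean.parquet") // an RDD -> Dataset conversion then feeds the Parquet sink
+```
+
+When such a plan runs, the source and sink stay on `DatasetChannel`s and only the filter stage is wrapped by the two conversion operators.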
You’ll see them in plan explanations if you mix dataset- and RDD-only operators. + +## Developer checklist + +- **Use `RecordType` when possible.** Providing field names in your logical operators helps the converter derive a precise schema. +- **Re-use `sparkExecutor.ss`.** When writing custom Spark operators that build DataFrames, use the provided `SparkExecutor` instead of `SparkSession.builder()` to avoid extra contexts. +- **Watch plan explanations.** Run `PlanBuilder.buildAndExplain(true)` to verify whether conversions are inserted. If they are, consider adding dataset descriptors to your operators. + +## Current limitations + +- Only Parquet sources/sinks expose dataset-specific APIs today. Text/Object sources still produce RDD channels. +- ML4All pipelines currently emit plain `double[]`/`Double` RDDs. They still benefit from the internal DataFrame conversions but do not expose dataset channels yet. + +Contributions to widen dataset support (e.g., dataset-aware `map`/`filter` or ML4All stages) are welcome. From 599508d25c189f3d34aaf4d06f6ffc26e6cd53e9 Mon Sep 17 00:00:00 2001 From: 2pk03 Date: Mon, 15 Dec 2025 16:02:52 +0100 Subject: [PATCH 3/5] add license header --- guides/spark-datasets.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/guides/spark-datasets.md b/guides/spark-datasets.md index 718807e03..cf5a83938 100644 --- a/guides/spark-datasets.md +++ b/guides/spark-datasets.md @@ -1,3 +1,22 @@ + + --- title: Spark Dataset pipelines description: How to build Wayang jobs that stay on Spark Datasets/DataFrames from source to sink. From 7bcf2ff375ece76c218368b2b822431f25ca7f6d Mon Sep 17 00:00:00 2001 From: 2pk03 Date: Mon, 15 Dec 2025 16:39:11 +0100 Subject: [PATCH 4/5] Deterministic plan selection with stable comparator + add integration test --- .../costs/DefaultEstimatableCost.java | 6 +- .../LatentOperatorPruningStrategy.java | 2 +- .../enumeration/PlanEnumeration.java | 19 ++- .../enumeration/PlanImplementation.java | 76 +++++++++++ .../enumeration/TopKPruningStrategy.java | 10 +- .../org/apache/wayang/core/util/MultiMap.java | 8 +- .../PlanEnumerationDeterminismTest.java | 126 ++++++++++++++++++ 7 files changed, 218 insertions(+), 29 deletions(-) create mode 100644 wayang-commons/wayang-core/src/test/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumerationDeterminismTest.java diff --git a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/costs/DefaultEstimatableCost.java b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/costs/DefaultEstimatableCost.java index ebc0f8cd2..2981d4b2a 100644 --- a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/costs/DefaultEstimatableCost.java +++ b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/costs/DefaultEstimatableCost.java @@ -66,11 +66,7 @@ public class DefaultEstimatableCost implements EstimatableCost { Set executedStages ) { final PlanImplementation bestPlanImplementation = executionPlans.stream() - .reduce((p1, p2) -> { - final double t1 = p1.getSquashedCostEstimate(); - final double t2 = p2.getSquashedCostEstimate(); - return t1 < t2 ? 
p1 : p2; - }) + .min(PlanImplementation.costComparator()) .orElseThrow(() -> new WayangException("Could not find an execution plan.")); return bestPlanImplementation; } diff --git a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/LatentOperatorPruningStrategy.java b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/LatentOperatorPruningStrategy.java index a228f148f..fa797e900 100644 --- a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/LatentOperatorPruningStrategy.java +++ b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/LatentOperatorPruningStrategy.java @@ -85,7 +85,7 @@ private PlanImplementation selectBestPlanBinary(PlanImplementation p1, PlanImplementation p2) { final double t1 = p1.getSquashedCostEstimate(true); final double t2 = p2.getSquashedCostEstimate(true); - final boolean isPickP1 = t1 <= t2; + final boolean isPickP1 = PlanImplementation.costComparator().compare(p1, p2) <= 0; if (logger.isDebugEnabled()) { if (isPickP1) { LogManager.getLogger(LatentOperatorPruningStrategy.class).debug( diff --git a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumeration.java b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumeration.java index e00753c37..ba1d7d4b5 100644 --- a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumeration.java +++ b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumeration.java @@ -412,16 +412,15 @@ private Collection concatenatePartialPlansBatchwise( if (junction == null) continue; // If we found a junction, then we can enumerate all PlanImplementation combinations. 
- final List> groupPlans = WayangCollections.map( - concatGroupCombo, - concatGroup -> { - Set concatDescriptors = concatGroup2concatDescriptor.get(concatGroup); - Set planImplementations = new HashSet<>(concatDescriptors.size()); - for (PlanImplementation.ConcatenationDescriptor concatDescriptor : concatDescriptors) { - planImplementations.add(concatDescriptor.getPlanImplementation()); - } - return planImplementations; - }); + final List> groupPlans = WayangCollections.map( + concatGroupCombo, + concatGroup -> { + Set concatDescriptors = concatGroup2concatDescriptor.get(concatGroup); + return concatDescriptors.stream() + .map(PlanImplementation.ConcatenationDescriptor::getPlanImplementation) + .sorted(PlanImplementation.structuralComparator()) + .collect(Collectors.toList()); + }); for (List planCombo : WayangCollections.streamedCrossProduct(groupPlans)) { PlanImplementation basePlan = planCombo.get(0); diff --git a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanImplementation.java b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanImplementation.java index 320f75ce9..501d92e11 100644 --- a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanImplementation.java +++ b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanImplementation.java @@ -49,6 +49,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -59,6 +60,7 @@ import java.util.Set; import java.util.function.ToDoubleFunction; import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; /** @@ -67,6 +69,11 @@ public class PlanImplementation { private static final Logger logger = LogManager.getLogger(PlanImplementation.class); + private static final Comparator COST_COMPARATOR = + Comparator.comparingDouble((PlanImplementation plan) -> plan.getSquashedCostEstimate(true)) + .thenComparing(PlanImplementation::getDeterministicIdentifier); + private static final Comparator STRUCTURAL_COMPARATOR = + Comparator.comparing(PlanImplementation::getDeterministicIdentifier); /** * {@link ExecutionOperator}s contained in this instance. @@ -182,6 +189,14 @@ private PlanImplementation(PlanEnumeration planEnumeration, assert this.planEnumeration != null; } + public static Comparator costComparator() { + return COST_COMPARATOR; + } + + public static Comparator structuralComparator() { + return STRUCTURAL_COMPARATOR; + } + /** * @return the {@link PlanEnumeration} this instance belongs to @@ -978,6 +993,67 @@ Stream streamOperators() { return operatorStream; } + /** + * Provides a deterministic identifier that captures the current state of this plan. While not guaranteed to + * be unique, it is stable across runs for the same logical plan and can therefore be used for reproducible + * ordering. 
+ * + * @return the deterministic identifier + */ + public String getDeterministicIdentifier() { + final String operatorDescriptor = this.operators.stream() + .map(PlanImplementation::describeOperator) + .sorted() + .collect(Collectors.joining("|")); + final String junctionDescriptor = this.junctions.values().stream() + .map(PlanImplementation::describeJunction) + .sorted() + .collect(Collectors.joining("|")); + final String loopDescriptor = this.loopImplementations.entrySet().stream() + .map(entry -> describeLoop(entry.getKey(), entry.getValue())) + .sorted() + .collect(Collectors.joining("|")); + return operatorDescriptor + "#" + junctionDescriptor + "#" + loopDescriptor; + } + + private static String describeOperator(Operator operator) { + final String name = operator.getName() == null ? "" : operator.getName(); + return operator.getClass().getName() + ":" + name + ":" + operator.getEpoch(); + } + + private static String describeJunction(Junction junction) { + final String source = describeOutputSlot(junction.getSourceOutput()); + final String targets = IntStream.range(0, junction.getNumTargets()) + .mapToObj(i -> describeInputSlot(junction.getTargetInput(i))) + .sorted() + .collect(Collectors.joining(",")); + return source + "->" + targets; + } + + private static String describeLoop(LoopSubplan loop, LoopImplementation implementation) { + final String descriptor = describeOperator(loop); + final String iterationDescriptor = implementation.getIterationImplementations().stream() + .map(iteration -> Integer.toString(iteration.getNumIterations())) + .collect(Collectors.joining(",")); + return descriptor + ":" + iterationDescriptor; + } + + private static String describeInputSlot(InputSlot slot) { + if (slot == null) { + return "null"; + } + final Operator owner = slot.getOwner(); + return describeOperator(owner) + ".in[" + slot.getIndex() + "]:" + slot.getName(); + } + + private static String describeOutputSlot(OutputSlot slot) { + if (slot == null) { + return "null"; + } + final Operator owner = slot.getOwner(); + return describeOperator(owner) + ".out[" + slot.getIndex() + "]:" + slot.getName(); + } + @Override public String toString() { return String.format("PlanImplementation[%s, %s, costs=%s]", diff --git a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/TopKPruningStrategy.java b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/TopKPruningStrategy.java index a80f9d99c..2519bc13c 100644 --- a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/TopKPruningStrategy.java +++ b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/TopKPruningStrategy.java @@ -40,16 +40,8 @@ public void prune(PlanEnumeration planEnumeration) { if (planEnumeration.getPlanImplementations().size() <= this.k) return; ArrayList planImplementations = new ArrayList<>(planEnumeration.getPlanImplementations()); - planImplementations.sort(this::comparePlanImplementations); + planImplementations.sort(PlanImplementation.costComparator()); planEnumeration.getPlanImplementations().retainAll(planImplementations.subList(0, this.k)); } - - private int comparePlanImplementations(PlanImplementation p1, - PlanImplementation p2) { - final double t1 = p1.getSquashedCostEstimate(true); - final double t2 = p2.getSquashedCostEstimate(true); - return Double.compare(t1, t2); - } - } diff --git a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/util/MultiMap.java 
b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/util/MultiMap.java index 574d4ff44..4fa6bf3a3 100644 --- a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/util/MultiMap.java +++ b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/util/MultiMap.java @@ -18,14 +18,14 @@ package org.apache.wayang.core.util; -import java.util.HashMap; -import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.Set; /** * Maps keys to multiple values. Each key value pair is unique. */ -public class MultiMap extends HashMap> { +public class MultiMap extends LinkedHashMap> { /** * Associate a key with a new value. @@ -35,7 +35,7 @@ public class MultiMap extends HashMap> { * @return whether the value was not yet associated with the key */ public boolean putSingle(K key, V value) { - final Set values = this.computeIfAbsent(key, k -> new HashSet<>()); + final Set values = this.computeIfAbsent(key, k -> new LinkedHashSet<>()); return values.add(value); } diff --git a/wayang-commons/wayang-core/src/test/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumerationDeterminismTest.java b/wayang-commons/wayang-core/src/test/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumerationDeterminismTest.java new file mode 100644 index 000000000..d6e6efcda --- /dev/null +++ b/wayang-commons/wayang-core/src/test/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumerationDeterminismTest.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.wayang.core.optimizer.enumeration; + +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.Job; +import org.apache.wayang.core.api.configuration.ExplicitCollectionProvider; +import org.apache.wayang.core.optimizer.DefaultOptimizationContext; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.optimizer.costs.ConstantLoadProfileEstimator; +import org.apache.wayang.core.optimizer.costs.LoadEstimate; +import org.apache.wayang.core.optimizer.costs.LoadProfile; +import org.apache.wayang.core.plan.executionplan.Channel; +import org.apache.wayang.core.plan.wayangplan.ExecutionOperator; +import org.apache.wayang.core.plan.wayangplan.InputSlot; +import org.apache.wayang.core.plan.wayangplan.OutputSlot; +import org.apache.wayang.core.test.DummyExecutionOperator; +import org.apache.wayang.core.test.DummyReusableChannel; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Integration test that exercises {@link PlanEnumeration#concatenate(OutputSlot, Collection, Map, OptimizationContext, org.apache.wayang.commons.util.profiledb.model.measurement.TimeMeasurement)} + * to ensure that plan combinations are produced deterministically. + */ +class PlanEnumerationDeterminismTest { + + @Test + void concatenationProducesStablePlanOrdering() { + Configuration configuration = new Configuration(); + configuration.setPruningStrategyClassProvider( + new ExplicitCollectionProvider>(configuration) + ); + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + + DummyExecutionOperator producer = new DummyExecutionOperator(0, 1, false); + DummyExecutionOperator consumer = new DummyExecutionOperator(1, 0, false); + registerChannelDescriptors(producer, consumer); + registerLoadEstimator(configuration, producer, 10); + registerLoadEstimator(configuration, consumer, 5); + + List firstRun = enumerateDeterministicIds(job, producer, consumer); + List secondRun = enumerateDeterministicIds(job, producer, consumer); + + assertTrue(firstRun.size() > 1, "Expected multiple plan implementations."); + assertEquals(firstRun, secondRun, "Enumeration order must be deterministic."); + } + + private static List enumerateDeterministicIds(Job job, + ExecutionOperator producer, + ExecutionOperator consumer) { + DefaultOptimizationContext optimizationContext = new DefaultOptimizationContext(job); + optimizationContext.addOneTimeOperator(producer); + optimizationContext.addOneTimeOperator(consumer); + + PlanEnumeration baseEnumeration = PlanEnumeration.createSingleton(producer, optimizationContext); + duplicatePlanImplementations(baseEnumeration, 3); + + PlanEnumeration targetEnumeration = PlanEnumeration.createSingleton(consumer, optimizationContext); + duplicatePlanImplementations(targetEnumeration, 2); + + Map, PlanEnumeration> targets = new LinkedHashMap<>(); + targets.put(consumer.getInput(0), targetEnumeration); + + PlanEnumeration concatenated = baseEnumeration.concatenate( + producer.getOutput(0), + Collections.emptyList(), + targets, + optimizationContext, + null + ); + + return concatenated.getPlanImplementations().stream() + 
.map(PlanImplementation::getDeterministicIdentifier) + .collect(Collectors.toList()); + } + + private static void duplicatePlanImplementations(PlanEnumeration enumeration, int desiredCount) { + PlanImplementation template = enumeration.getPlanImplementations().iterator().next(); + while (enumeration.getPlanImplementations().size() < desiredCount) { + enumeration.add(new PlanImplementation(template)); + } + } + + private static void registerLoadEstimator(Configuration configuration, + ExecutionOperator operator, + long cpuCost) { + ConstantLoadProfileEstimator estimator = new ConstantLoadProfileEstimator( + new LoadProfile(new LoadEstimate(cpuCost), new LoadEstimate(1)) + ); + configuration.getOperatorLoadProfileEstimatorProvider().set(operator, estimator); + } + + private static void registerChannelDescriptors(DummyExecutionOperator producer, + DummyExecutionOperator consumer) { + producer.getSupportedOutputChannels(0).add(DummyReusableChannel.DESCRIPTOR); + consumer.getSupportedInputChannels(0).add(DummyReusableChannel.DESCRIPTOR); + } +} From 93fbadb35a1132f9495b88e7ebe64c2a281b7937 Mon Sep 17 00:00:00 2001 From: 2pk03 Date: Tue, 16 Dec 2025 08:41:55 +0100 Subject: [PATCH 5/5] Ensure plan enumeration & channel conversion are deterministic. Use ordered sets for scope/slot tracking, add stable cost tiebreaker, and cover with determinism tests. --- .../channels/ChannelConversionGraph.java | 48 ++++-- .../enumeration/PlanEnumeration.java | 4 +- .../enumeration/PlanImplementation.java | 9 +- .../wayang/core/util/WayangCollections.java | 8 +- ...ChannelConversionGraphDeterminismTest.java | 137 ++++++++++++++++++ .../PlanEnumerationDeterminismTest.java | 51 +++++-- 6 files changed, 225 insertions(+), 32 deletions(-) create mode 100644 wayang-commons/wayang-core/src/test/java/org/apache/wayang/core/optimizer/channels/ChannelConversionGraphDeterminismTest.java diff --git a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/channels/ChannelConversionGraph.java b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/channels/ChannelConversionGraph.java index 3e06c95f7..f56c6080d 100644 --- a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/channels/ChannelConversionGraph.java +++ b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/channels/ChannelConversionGraph.java @@ -47,15 +47,16 @@ import java.util.Collections; import java.util.Comparator; import java.util.HashMap; -import java.util.HashSet; import java.util.Iterator; import java.util.LinkedList; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Random; import java.util.Set; import java.util.function.ToDoubleFunction; import java.util.stream.Collectors; +import java.util.stream.StreamSupport; /** * This graph contains a set of {@link ChannelConversion}s. 
@@ -168,7 +169,7 @@ private Tree mergeTrees(Collection trees) { final Tree firstTree = iterator.next(); Bitmask combinationSettledIndices = new Bitmask(firstTree.settledDestinationIndices); int maxSettledIndices = combinationSettledIndices.cardinality(); - final HashSet employedChannelDescriptors = new HashSet<>(firstTree.employedChannelDescriptors); + final LinkedHashSet employedChannelDescriptors = new LinkedHashSet<>(firstTree.employedChannelDescriptors); int maxVisitedChannelDescriptors = employedChannelDescriptors.size(); double costs = firstTree.costs; TreeVertex newRoot = new TreeVertex(firstTree.root.channelDescriptor, firstTree.root.settledIndices); @@ -222,7 +223,11 @@ public static class CostbasedTreeSelectionStrategy implements TreeSelectionStrat @Override public Tree select(Tree t1, Tree t2) { - return t1.costs <= t2.costs ? t1 : t2; + int cmp = Double.compare(t1.costs, t2.costs); + if (cmp == 0) { + cmp = t1.getDeterministicSignature().compareTo(t2.getDeterministicSignature()); + } + return cmp <= 0 ? t1 : t2; } } @@ -381,7 +386,7 @@ private ShortestTreeSearcher(OutputSlot sourceOutput, this.existingDestinationChannelIndices = new Bitmask(); this.collectExistingChannels(sourceChannel); - this.openChannelDescriptors = new HashSet<>(openChannels.size()); + this.openChannelDescriptors = new LinkedHashSet<>(openChannels.size()); for (Channel openChannel : openChannels) { this.openChannelDescriptors.add(openChannel.getDescriptor()); } @@ -477,7 +482,9 @@ private Set resolveSupportedChannels(final InputSlot input final List supportedInputChannels = owner.getSupportedInputChannels(input.getIndex()); if (input.isLoopInvariant()) { // Loop input is needed in several iterations and must therefore be reusable. - return supportedInputChannels.stream().filter(ChannelDescriptor::isReusable).collect(Collectors.toSet()); + return supportedInputChannels.stream() + .filter(ChannelDescriptor::isReusable) + .collect(Collectors.toCollection(LinkedHashSet::new)); } else { return WayangCollections.asSet(supportedInputChannels); } @@ -546,7 +553,7 @@ private void kernelizeChannelRequests() { } if (channelDescriptors.size() - numReusableChannels == 1) { iterator.remove(); - channelDescriptors = new HashSet<>(channelDescriptors); + channelDescriptors = new LinkedHashSet<>(channelDescriptors); channelDescriptors.removeIf(channelDescriptor -> !channelDescriptor.isReusable()); kernelDestChannelDescriptorSetsToIndicesUpdates.add(new Tuple<>(channelDescriptors, indices)); } @@ -575,7 +582,7 @@ private void kernelizeChannelRequests() { */ private Tree searchTree() { // Prepare the recursive traversal. - final HashSet visitedChannelDescriptors = new HashSet<>(16); + final LinkedHashSet visitedChannelDescriptors = new LinkedHashSet<>(16); visitedChannelDescriptors.add(this.sourceChannelDescriptor); // Perform the traversal. 
@@ -777,7 +784,7 @@ private Set getSuccessorChannelDescriptors(ChannelDescriptor final Channel channel = this.existingChannels.get(descriptor); if (channel == null || this.openChannelDescriptors.contains(descriptor)) return null; - Set result = new HashSet<>(); + Set result = new LinkedHashSet<>(); for (ExecutionTask consumer : channel.getConsumers()) { if (!consumer.getOperator().isAuxiliary()) continue; for (Channel successorChannel : consumer.getOutputChannels()) { @@ -988,7 +995,12 @@ private static class Tree { * * @see TreeVertex#channelDescriptor */ - private final Set employedChannelDescriptors = new HashSet<>(); + private final Set employedChannelDescriptors = new LinkedHashSet<>(); + + /** + * Cached deterministic signature for tie-breaking. + */ + private String deterministicSignature; /** * The sum of the costs of all {@link TreeEdge}s of this instance. @@ -1010,6 +1022,7 @@ static Tree singleton(ChannelDescriptor channelDescriptor, Bitmask settledIndice this.root = root; this.settledDestinationIndices = settledDestinationIndices; this.employedChannelDescriptors.add(root.channelDescriptor); + this.deterministicSignature = null; } /** @@ -1033,6 +1046,21 @@ void reroot(ChannelDescriptor newRootChannelDescriptor, this.employedChannelDescriptors.add(newRootChannelDescriptor); this.settledDestinationIndices.orInPlace(newRootSettledIndices); this.costs += edge.costEstimate; + this.deterministicSignature = null; + } + + private String getDeterministicSignature() { + if (this.deterministicSignature == null) { + final String descriptorSignature = this.employedChannelDescriptors.stream() + .map(Object::toString) + .sorted() + .collect(Collectors.joining("|")); + final String indexSignature = StreamSupport.stream(this.settledDestinationIndices.spliterator(), false) + .map(String::valueOf) + .collect(Collectors.joining(",")); + this.deterministicSignature = descriptorSignature + "#" + indexSignature; + } + return this.deterministicSignature; } @Override @@ -1090,7 +1118,7 @@ private void copyEdgesFrom(TreeVertex that) { * @return a {@link Set} of said {@link ChannelConversion}s */ private Set getChildChannelConversions() { - Set channelConversions = new HashSet<>(); + Set channelConversions = new LinkedHashSet<>(); for (TreeEdge edge : this.outEdges) { channelConversions.add(edge.channelConversion); channelConversions.addAll(edge.destination.getChildChannelConversions()); diff --git a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumeration.java b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumeration.java index ba1d7d4b5..bec527383 100644 --- a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumeration.java +++ b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumeration.java @@ -42,7 +42,7 @@ import java.util.Collection; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -91,7 +91,7 @@ public class PlanEnumeration { * Creates a new instance. 
*/ public PlanEnumeration() { - this(new HashSet<>(), new HashSet<>(), new HashSet<>()); + this(new LinkedHashSet<>(), new LinkedHashSet<>(), new LinkedHashSet<>()); } /** diff --git a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanImplementation.java b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanImplementation.java index 501d92e11..9c0147af8 100644 --- a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanImplementation.java +++ b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanImplementation.java @@ -53,6 +53,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; +import java.util.LinkedHashSet; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -255,7 +256,7 @@ Collection> findExecutionOperatorInputs(final InputSlot someInpu // Discern LoopHeadOperator InputSlots and loop body InputSlots. final List iterationImpls = loopImplementation.getIterationImplementations(); - final Collection> collector = new HashSet<>(innerInputs.size()); + final Collection> collector = new LinkedHashSet<>(innerInputs.size()); for (InputSlot innerInput : innerInputs) { if (innerInput.getOwner() == loopSubplan.getLoopHead()) { final LoopImplementation.IterationImplementation initialIterationImpl = iterationImpls.get(0); @@ -329,7 +330,7 @@ Collection, PlanImplementation>> findExecutionOperatorOutput // For all the iterations, return the potential OutputSlots. final List iterationImpls = loopImplementation.getIterationImplementations(); - final Set, PlanImplementation>> collector = new HashSet<>(iterationImpls.size()); + final Set, PlanImplementation>> collector = new LinkedHashSet<>(iterationImpls.size()); for (LoopImplementation.IterationImplementation iterationImpl : iterationImpls) { final Collection, PlanImplementation>> outputsWithContext = iterationImpl.getBodyImplementation().findExecutionOperatorOutputWithContext(innerOutput); @@ -695,8 +696,8 @@ public double getSquashedCostEstimate() { private Tuple, List> getParallelOperatorJunctionAllCostEstimate(Operator operator) { - Set inputOperators = new HashSet<>(); - Set inputJunction = new HashSet<>(); + Set inputOperators = new LinkedHashSet<>(); + Set inputJunction = new LinkedHashSet<>(); List probalisticCost = new ArrayList<>(); List squashedCost = new ArrayList<>(); diff --git a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/util/WayangCollections.java b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/util/WayangCollections.java index 5b3e8918f..b1eb6d787 100644 --- a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/util/WayangCollections.java +++ b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/util/WayangCollections.java @@ -24,7 +24,7 @@ import java.util.Collection; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.Iterator; import java.util.LinkedList; import java.util.List; @@ -59,7 +59,7 @@ public static Set asSet(Collection collection) { if (collection instanceof Set) { return (Set) collection; } - return new HashSet<>(collection); + return new LinkedHashSet<>(collection); } /** @@ -69,7 +69,7 @@ public static Set asSet(Iterable iterable) { if (iterable instanceof Set) { return (Set) iterable; } - Set set = new HashSet<>(); + Set set = new LinkedHashSet<>(); for (T t : iterable) { 
set.add(t); } @@ -80,7 +80,7 @@ public static Set asSet(Iterable iterable) { * Provides the given {@code values} as {@link Set}. */ public static Set asSet(T... values) { - Set set = new HashSet<>(values.length); + Set set = new LinkedHashSet<>(values.length); for (T value : values) { set.add(value); } diff --git a/wayang-commons/wayang-core/src/test/java/org/apache/wayang/core/optimizer/channels/ChannelConversionGraphDeterminismTest.java b/wayang-commons/wayang-core/src/test/java/org/apache/wayang/core/optimizer/channels/ChannelConversionGraphDeterminismTest.java new file mode 100644 index 000000000..7bdab816f --- /dev/null +++ b/wayang-commons/wayang-core/src/test/java/org/apache/wayang/core/optimizer/channels/ChannelConversionGraphDeterminismTest.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.core.optimizer.channels; + +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.Job; +import org.apache.wayang.core.optimizer.DefaultOptimizationContext; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.optimizer.OptimizationUtils; +import org.apache.wayang.core.optimizer.cardinality.CardinalityEstimate; +import org.apache.wayang.core.plan.executionplan.Channel; +import org.apache.wayang.core.plan.executionplan.ExecutionTask; +import org.apache.wayang.core.plan.wayangplan.ExecutionOperator; +import org.apache.wayang.core.plan.wayangplan.InputSlot; +import org.apache.wayang.core.plan.wayangplan.OutputSlot; +import org.apache.wayang.core.platform.ChannelDescriptor; +import org.apache.wayang.core.platform.Junction; +import org.apache.wayang.core.test.DummyExecutionOperator; +import org.apache.wayang.core.test.DummyExternalReusableChannel; +import org.apache.wayang.core.test.DummyNonReusableChannel; +import org.apache.wayang.core.test.DummyReusableChannel; +import org.apache.wayang.core.test.MockFactory; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class ChannelConversionGraphDeterminismTest { + + private static Supplier createDummyExecutionOperatorFactory(ChannelDescriptor channelDescriptor) { + return () -> { + ExecutionOperator execOp = new DummyExecutionOperator(1, 1, false); + execOp.getSupportedOutputChannels(0).add(channelDescriptor); + return execOp; + }; + } + + private static DefaultChannelConversion conversion(ChannelDescriptor source, ChannelDescriptor target) { + return new DefaultChannelConversion(source, target, createDummyExecutionOperatorFactory(target)); + } + + 
@Test + void channelConversionSelectionIsStable() { + List first = computeJunctionFingerprint(); + List second = computeJunctionFingerprint(); + assertEquals(first, second, "Channel conversion choices must be deterministic."); + } + + private static List computeJunctionFingerprint() { + Configuration configuration = new Configuration(); + ChannelConversionGraph graph = new ChannelConversionGraph(configuration); + graph.add(conversion(DummyReusableChannel.DESCRIPTOR, DummyNonReusableChannel.DESCRIPTOR)); + graph.add(conversion(DummyReusableChannel.DESCRIPTOR, DummyExternalReusableChannel.DESCRIPTOR)); + graph.add(conversion(DummyExternalReusableChannel.DESCRIPTOR, DummyNonReusableChannel.DESCRIPTOR)); + graph.add(conversion(DummyNonReusableChannel.DESCRIPTOR, DummyReusableChannel.DESCRIPTOR)); + + Job job = MockFactory.createJob(configuration); + OptimizationContext optimizationContext = new DefaultOptimizationContext(job); + + DummyExecutionOperator sourceOperator = new DummyExecutionOperator(0, 1, false); + sourceOperator.getSupportedOutputChannels(0).add(DummyReusableChannel.DESCRIPTOR); + optimizationContext.addOneTimeOperator(sourceOperator) + .setOutputCardinality(0, new CardinalityEstimate(1000, 1000, 1d)); + + DummyExecutionOperator destOperator0 = new DummyExecutionOperator(1, 1, false); + destOperator0.getSupportedInputChannels(0).add(DummyNonReusableChannel.DESCRIPTOR); + + DummyExecutionOperator destOperator1 = new DummyExecutionOperator(1, 1, false); + destOperator1.getSupportedInputChannels(0).add(DummyExternalReusableChannel.DESCRIPTOR); + + Junction junction = graph.findMinimumCostJunction( + sourceOperator.getOutput(0), + Arrays.asList(destOperator0.getInput(0), destOperator1.getInput(0)), + optimizationContext, + false + ); + + return describeJunction(junction); + } + + private static List describeJunction(Junction junction) { + List descriptorList = new ArrayList<>(); + descriptorList.add(describeChannel(junction.getSourceChannel(), true)); + for (int i = 0; i < junction.getNumTargets(); i++) { + descriptorList.add(describeChannel(junction.getTargetChannel(i), false)); + } + return descriptorList; + } + + private static String describeChannel(Channel channel, boolean isSourceChannel) { + if (channel == null) { + return "null"; + } + List descriptors = new ArrayList<>(); + Channel cursor = channel; + while (cursor != null) { + descriptors.add(cursor.getDescriptor().toString() + (cursor.isCopy() ? ":copy" : ":orig")); + ExecutionTask producer = cursor.getProducer(); + if (producer == null || producer.getNumInputChannels() == 0) { + break; + } + // If we are describing the top-level source channel (junction entry), stop once we reach the producer that + // has no inputs. For target channels, follow until the conversion tree ends. + if (isSourceChannel) { + cursor = producer.getNumInputChannels() == 0 ? 
null : producer.getInputChannel(0); + } else if (producer.getNumInputChannels() == 0) { + cursor = null; + } else { + cursor = producer.getInputChannel(0); + } + } + Collections.reverse(descriptors); + return descriptors.stream().collect(Collectors.joining("->")); + } +} diff --git a/wayang-commons/wayang-core/src/test/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumerationDeterminismTest.java b/wayang-commons/wayang-core/src/test/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumerationDeterminismTest.java index d6e6efcda..f72525740 100644 --- a/wayang-commons/wayang-core/src/test/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumerationDeterminismTest.java +++ b/wayang-commons/wayang-core/src/test/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumerationDeterminismTest.java @@ -34,6 +34,7 @@ import org.apache.wayang.core.test.DummyReusableChannel; import org.junit.jupiter.api.Test; +import java.util.Arrays; import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; @@ -62,32 +63,58 @@ void concatenationProducesStablePlanOrdering() { DummyExecutionOperator producer = new DummyExecutionOperator(0, 1, false); DummyExecutionOperator consumer = new DummyExecutionOperator(1, 0, false); - registerChannelDescriptors(producer, consumer); + registerChannelDescriptors(producer, Collections.singletonList(consumer)); registerLoadEstimator(configuration, producer, 10); registerLoadEstimator(configuration, consumer, 5); - List firstRun = enumerateDeterministicIds(job, producer, consumer); - List secondRun = enumerateDeterministicIds(job, producer, consumer); + List firstRun = enumerateDeterministicIds(job, producer, Collections.singletonList(consumer), 3, 2); + List secondRun = enumerateDeterministicIds(job, producer, Collections.singletonList(consumer), 3, 2); assertTrue(firstRun.size() > 1, "Expected multiple plan implementations."); assertEquals(firstRun, secondRun, "Enumeration order must be deterministic."); } + @Test + void concatenationWithMultipleTargetsRemainsStable() { + Configuration configuration = new Configuration(); + configuration.setPruningStrategyClassProvider( + new ExplicitCollectionProvider>(configuration) + ); + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + + DummyExecutionOperator producer = new DummyExecutionOperator(0, 1, false); + DummyExecutionOperator consumerA = new DummyExecutionOperator(1, 0, false); + DummyExecutionOperator consumerB = new DummyExecutionOperator(1, 0, false); + registerChannelDescriptors(producer, Arrays.asList(consumerA, consumerB)); + registerLoadEstimator(configuration, producer, 10); + registerLoadEstimator(configuration, consumerA, 7); + registerLoadEstimator(configuration, consumerB, 3); + + List firstRun = enumerateDeterministicIds(job, producer, Arrays.asList(consumerA, consumerB), 4, 2); + List secondRun = enumerateDeterministicIds(job, producer, Arrays.asList(consumerA, consumerB), 4, 2); + + assertEquals(firstRun, secondRun, "Enumeration order with multiple targets must be deterministic."); + } + private static List enumerateDeterministicIds(Job job, ExecutionOperator producer, - ExecutionOperator consumer) { + List consumers, + int numBaseCopies, + int numTargetCopies) { DefaultOptimizationContext optimizationContext = new DefaultOptimizationContext(job); optimizationContext.addOneTimeOperator(producer); - optimizationContext.addOneTimeOperator(consumer); + consumers.forEach(optimizationContext::addOneTimeOperator); PlanEnumeration 
baseEnumeration = PlanEnumeration.createSingleton(producer, optimizationContext); - duplicatePlanImplementations(baseEnumeration, 3); - - PlanEnumeration targetEnumeration = PlanEnumeration.createSingleton(consumer, optimizationContext); - duplicatePlanImplementations(targetEnumeration, 2); + duplicatePlanImplementations(baseEnumeration, numBaseCopies); Map, PlanEnumeration> targets = new LinkedHashMap<>(); - targets.put(consumer.getInput(0), targetEnumeration); + consumers.forEach(consumer -> { + PlanEnumeration targetEnumeration = PlanEnumeration.createSingleton((ExecutionOperator) consumer, optimizationContext); + duplicatePlanImplementations(targetEnumeration, numTargetCopies); + targets.put(consumer.getInput(0), targetEnumeration); + }); PlanEnumeration concatenated = baseEnumeration.concatenate( producer.getOutput(0), @@ -119,8 +146,8 @@ private static void registerLoadEstimator(Configuration configuration, } private static void registerChannelDescriptors(DummyExecutionOperator producer, - DummyExecutionOperator consumer) { + List consumers) { producer.getSupportedOutputChannels(0).add(DummyReusableChannel.DESCRIPTOR); - consumer.getSupportedInputChannels(0).add(DummyReusableChannel.DESCRIPTOR); + consumers.forEach(consumer -> consumer.getSupportedInputChannels(0).add(DummyReusableChannel.DESCRIPTOR)); } }