From 3333ffa654b50134b1fd25efa0be2762af2c6a24 Mon Sep 17 00:00:00 2001 From: Rexwell Minnis Date: Sat, 7 Feb 2026 10:38:03 -0500 Subject: [PATCH 01/16] RayDP 2.0: migrate to Spark 4.1.1 / Java 17 / Scala 2.13 - Upgrade core POMs to Spark 4.1.1, Java 17, Scala 2.13.12 - Add spark411 shim module (SparkShims411, Spark411Helper, Spark411SQLHelper) with proper TaskContextImpl, ArrowConverters, and ClassicSparkSession APIs - Implement toDataFrame for Arrow batch deserialization via shim layer - Add RayCoarseGrainedExecutorBackend for Spark 4.1 executor backend - Disable legacy shim modules (spark322, spark330, spark340, spark350) - Add pyproject.toml (PEP 517) and simplify setup.py - Update CI workflow for Spark 4.1.1 --- .github/workflows/raydp.yml | 6 +- core/agent/pom.xml | 2 +- core/pom.xml | 42 +++++++--- core/raydp-main/pom.xml | 21 +++-- .../org/apache/spark/deploy/SparkSubmit.scala | 3 +- .../RayCoarseGrainedSchedulerBackend.scala | 11 ++- .../spark/sql/raydp/ObjectStoreWriter.scala | 8 +- core/shims/common/pom.xml | 6 +- .../com/intel/raydp/shims/SparkShims.scala | 3 + core/shims/pom.xml | 11 +-- core/shims/spark411/pom.xml | 64 +++++++++++++++ .../com.intel.raydp.shims.SparkShimProvider | 1 + .../com/intel/raydp/shims/SparkShims411.scala | 33 ++++++++ .../shims/spark411/SparkShimProvider.scala | 28 +++++++ .../org/apache/spark/Spark411Helper.scala | 72 +++++++++++++++++ .../RayCoarseGrainedExecutorBackend.scala | 50 ++++++++++++ .../apache/spark/sql/Spark411SQLHelper.scala | 81 +++++++++++++++++++ python/pyproject.toml | 45 +++++++++++ python/setup.py | 75 +++++------------ 19 files changed, 459 insertions(+), 103 deletions(-) create mode 100644 core/shims/spark411/pom.xml create mode 100644 core/shims/spark411/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider create mode 100644 core/shims/spark411/src/main/scala/com/intel/raydp/shims/SparkShims411.scala create mode 100644 core/shims/spark411/src/main/scala/com/intel/raydp/shims/spark411/SparkShimProvider.scala create mode 100644 core/shims/spark411/src/main/scala/org/apache/spark/Spark411Helper.scala create mode 100644 core/shims/spark411/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala create mode 100644 core/shims/spark411/src/main/scala/org/apache/spark/sql/Spark411SQLHelper.scala create mode 100644 python/pyproject.toml diff --git a/.github/workflows/raydp.yml b/.github/workflows/raydp.yml index a24746b9..56491e91 100644 --- a/.github/workflows/raydp.yml +++ b/.github/workflows/raydp.yml @@ -32,9 +32,9 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.9, 3.10.14] - spark-version: [3.3.2, 3.4.0, 3.5.0] - ray-version: [2.37.0, 2.40.0, 2.50.0] + python-version: [3.10, 3.12] + spark-version: [4.1.1] + ray-version: [2.53.0] runs-on: ${{ matrix.os }} diff --git a/core/agent/pom.xml b/core/agent/pom.xml index cba5c222..2f263b09 100644 --- a/core/agent/pom.xml +++ b/core/agent/pom.xml @@ -7,7 +7,7 @@ com.intel raydp-parent - 1.7.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/core/pom.xml b/core/pom.xml index ebe301c6..bbaa074d 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -6,32 +6,32 @@ com.intel raydp-parent - 1.7.0-SNAPSHOT + 2.0.0-SNAPSHOT pom RayDP Parent Pom https://github.com/ray-project/raydp.git - 3.3.3 + 2.34.0 + 4.1.1 3.2.2 3.3.0 3.4.0 3.5.0 - 1.1.10.4 - 4.1.94.Final - 1.10.0 - 1.26.0 + 1.1.10.5 + 4.1.108.Final + 1.12.0 + 1.26.1 1.7.14.1 3.25.5 2.5.2 UTF-8 UTF-8 - 1.8 - 1.8 - 2.12.15 - 2.15.0 - 2.12 + 17 + 2.13.12 + 2.17.0 + 2.13 5.10.1 @@ -144,7 +144,7 @@ org.apache.commons commons-lang3 - 3.18.0 + 3.17.0 @@ -197,6 +197,24 @@ + + + + org.scalastyle + scalastyle-maven-plugin + 1.0.0 + + + org.apache.maven.plugins + maven-compiler-plugin + + ${java.version} + ${java.version} + ${java.version} + + + + org.apache.maven.plugins diff --git a/core/raydp-main/pom.xml b/core/raydp-main/pom.xml index 3c791a65..d95f634d 100644 --- a/core/raydp-main/pom.xml +++ b/core/raydp-main/pom.xml @@ -7,17 +7,13 @@ com.intel raydp-parent - 1.7.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml raydp raydp - - 2.1.0 - - sonatype @@ -128,7 +124,12 @@ org.apache.commons commons-lang3 - 3.18.0 + + + + javax.ws.rs + javax.ws.rs-api + 2.1.1 @@ -184,10 +185,6 @@ org.apache.maven.plugins maven-compiler-plugin 3.8.0 - - 1.8 - 1.8 - org.apache.maven.plugins @@ -202,7 +199,7 @@ net.alchim31.maven scala-maven-plugin - 3.3.3 + 4.8.1 scala-compile-first @@ -228,7 +225,7 @@ org.apache.maven.plugins maven-surefire-plugin - 2.7 + 3.2.5 diff --git a/core/raydp-main/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/raydp-main/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 829517e5..98d98685 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -258,7 +258,6 @@ private[spark] class SparkSubmit extends Logging { } if (clusterManager == KUBERNETES) { - args.master = Utils.checkAndGetK8sMasterUrl(args.master) // Make sure KUBERNETES is included in our build if we're trying to use it if (!Utils.classIsLoadable(KUBERNETES_CLUSTER_SUBMIT_CLASS) && !Utils.isTesting) { error( @@ -1044,7 +1043,7 @@ object SparkSubmit extends CommandLineUtils with Logging { super.doSubmit(args) } catch { case e: SparkUserAppException => - exitFn(e.exitCode) + exitFn(e.exitCode, None) } } diff --git a/core/raydp-main/src/main/scala/org/apache/spark/scheduler/cluster/raydp/RayCoarseGrainedSchedulerBackend.scala b/core/raydp-main/src/main/scala/org/apache/spark/scheduler/cluster/raydp/RayCoarseGrainedSchedulerBackend.scala index dce83a78..aa5fba47 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/scheduler/cluster/raydp/RayCoarseGrainedSchedulerBackend.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/scheduler/cluster/raydp/RayCoarseGrainedSchedulerBackend.scala @@ -21,7 +21,7 @@ import java.net.URI import java.util.concurrent.Semaphore import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} -import scala.collection.JavaConverters._ +import scala.jdk.CollectionConverters._ import scala.collection.mutable.HashMap import scala.concurrent.Future @@ -89,7 +89,7 @@ class RayCoarseGrainedSchedulerBackend( val appMasterResources = conf.getAll.filter { case (k, v) => k.startsWith(SparkOnRayConfigs.SPARK_MASTER_ACTOR_RESOURCE_PREFIX) - }.map{ case (k, v) => k->double2Double(v.toDouble) } + }.map{ case (k, v) => k->Double.box(v.toDouble) } masterHandle = RayAppMasterUtils.createAppMaster(cp, null, options.toBuffer.asJava, appMasterResources.toMap.asJava) @@ -154,14 +154,13 @@ class RayCoarseGrainedSchedulerBackend( } // Start executors with a few necessary configs for registering with the scheduler - val sparkJavaOpts = Utils.sparkJavaOpts(conf, SparkConf.isExecutorStartupConf) // add Xmx, it should not be set in java opts, because Spark is not allowed. // We also add Xms to ensure the Xmx >= Xms val memoryLimit = Seq(s"-Xms${sc.executorMemory}M", s"-Xmx${sc.executorMemory}M") - val javaOpts = sparkJavaOpts ++ extraJavaOpts ++ memoryLimit ++ javaAgentOpt() + val javaOpts = extraJavaOpts ++ memoryLimit ++ javaAgentOpt() - val command = Command(driverUrl, sc.executorEnvs, + val command = Command(driverUrl, sc.executorEnvs.toMap, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts) val coresPerExecutor = conf.getOption(config.EXECUTOR_CORES.key).map(_.toInt) @@ -250,7 +249,7 @@ class RayCoarseGrainedSchedulerBackend( } catch { case e: Exception => logWarning("Failed to connect to app master", e) - stop() + RayCoarseGrainedSchedulerBackend.this.stop() } } diff --git a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala index 7ff22660..4c910062 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala @@ -107,10 +107,10 @@ object ObjectStoreWriter { } val uuid = dfToId.getOrElseUpdate(df, UUID.randomUUID()) val queue = ObjectRefHolder.getQueue(uuid) - val rdd = df.toArrowBatchRdd + val rdd = SparkShimLoader.getSparkShims.toArrowBatchRdd(df) rdd.persist(storageLevel) rdd.count() - var executorIds = df.sqlContext.sparkContext.getExecutorIds.toArray + var executorIds = df.sparkSession.sparkContext.getExecutorIds.toArray val numExecutors = executorIds.length val appMasterHandle = Ray.getActor(RayAppMaster.ACTOR_NAME) .get.asInstanceOf[ActorHandle[RayAppMaster]] @@ -167,11 +167,11 @@ object ObjectStoreWriter { "Not yet connected to Ray! Please set fault_tolerant_mode=True when starting RayDP.") } - val rdd = df.toArrowBatchRdd + val rdd = SparkShimLoader.getSparkShims.toArrowBatchRdd(df) rdd.persist(storageLevel) rdd.count() - var executorIds = df.sqlContext.sparkContext.getExecutorIds.toArray + var executorIds = df.sparkSession.sparkContext.getExecutorIds.toArray val numExecutors = executorIds.length val appMasterHandle = Ray.getActor(RayAppMaster.ACTOR_NAME) .get.asInstanceOf[ActorHandle[RayAppMaster]] diff --git a/core/shims/common/pom.xml b/core/shims/common/pom.xml index f4f8dcc4..b6c4ca26 100644 --- a/core/shims/common/pom.xml +++ b/core/shims/common/pom.xml @@ -7,13 +7,13 @@ com.intel raydp-shims - 1.7.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml raydp-shims-common RayDP Shims Common - 1.7.0-SNAPSHOT + 2.0.0-SNAPSHOT jar @@ -25,7 +25,7 @@ net.alchim31.maven scala-maven-plugin - 3.2.2 + ${scala.plugin.version} scala-compile-first diff --git a/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala index 2ca83522..c9e864da 100644 --- a/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ b/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -21,6 +21,7 @@ import org.apache.arrow.vector.types.pojo.Schema import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.java.JavaRDD import org.apache.spark.executor.RayDPExecutorBackendFactory +import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, SparkSession} @@ -40,4 +41,6 @@ trait SparkShims { def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext def toArrowSchema(schema : StructType, timeZoneId : String) : Schema + + def toArrowBatchRdd(df: DataFrame): RDD[Array[Byte]] } diff --git a/core/shims/pom.xml b/core/shims/pom.xml index c013538b..ca625c7c 100644 --- a/core/shims/pom.xml +++ b/core/shims/pom.xml @@ -7,7 +7,7 @@ com.intel raydp-parent - 1.7.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml @@ -17,15 +17,12 @@ common - spark322 - spark330 - spark340 - spark350 + spark411 - 2.12 - 4.3.0 + 2.13 + 4.8.1 3.2.2 diff --git a/core/shims/spark411/pom.xml b/core/shims/spark411/pom.xml new file mode 100644 index 00000000..1c80830c --- /dev/null +++ b/core/shims/spark411/pom.xml @@ -0,0 +1,64 @@ + + + + raydp-shims + com.intel + 2.0.0-SNAPSHOT + + 4.0.0 + + raydp-shims-spark411 + + + 4.1.1 + + + + + com.intel + raydp-shims-common + ${project.version} + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + provided + + + + + + + net.alchim31.maven + scala-maven-plugin + ${scala.plugin.version} + + + scala-compile-first + process-resources + + add-source + compile + + + + scala-test-compile + process-test-resources + + testCompile + + + + + + + \ No newline at end of file diff --git a/core/shims/spark411/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider b/core/shims/spark411/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider new file mode 100644 index 00000000..e81063ad --- /dev/null +++ b/core/shims/spark411/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider @@ -0,0 +1 @@ +com.intel.raydp.shims.spark411.SparkShimProvider diff --git a/core/shims/spark411/src/main/scala/com/intel/raydp/shims/SparkShims411.scala b/core/shims/spark411/src/main/scala/com/intel/raydp/shims/SparkShims411.scala new file mode 100644 index 00000000..7cb7b8e7 --- /dev/null +++ b/core/shims/spark411/src/main/scala/com/intel/raydp/shims/SparkShims411.scala @@ -0,0 +1,33 @@ +package com.intel.raydp.shims + +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.{Spark411Helper, SparkEnv, TaskContext} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.executor.RayDPExecutorBackendFactory +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.{DataFrame, Spark411SQLHelper, SparkSession} + +class SparkShims411 extends SparkShims { + override def getShimDescriptor: ShimDescriptor = SparkShimDescriptor(4, 1, 1) + + override def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = { + Spark411SQLHelper.toDataFrame(rdd, schema, session) + } + + override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = { + Spark411Helper.getExecutorBackendFactory + } + + override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { + Spark411Helper.getDummyTaskContext(partitionId, env) + } + + override def toArrowSchema(schema: StructType, timeZoneId: String): Schema = { + Spark411SQLHelper.toArrowSchema(schema, timeZoneId) + } + + override def toArrowBatchRdd(df: DataFrame): RDD[Array[Byte]] = { + Spark411SQLHelper.toArrowBatchRdd(df) + } +} diff --git a/core/shims/spark411/src/main/scala/com/intel/raydp/shims/spark411/SparkShimProvider.scala b/core/shims/spark411/src/main/scala/com/intel/raydp/shims/spark411/SparkShimProvider.scala new file mode 100644 index 00000000..6e0f62b8 --- /dev/null +++ b/core/shims/spark411/src/main/scala/com/intel/raydp/shims/spark411/SparkShimProvider.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.raydp.shims.spark411 + +import com.intel.raydp.shims.{SparkShimProvider => BaseSparkShimProvider, SparkShims, SparkShims411} + +class SparkShimProvider extends BaseSparkShimProvider { + override def createShim: SparkShims = new SparkShims411() + + override def matches(version: String): Boolean = { + version.startsWith("4.1") + } +} diff --git a/core/shims/spark411/src/main/scala/org/apache/spark/Spark411Helper.scala b/core/shims/spark411/src/main/scala/org/apache/spark/Spark411Helper.scala new file mode 100644 index 00000000..ccf37472 --- /dev/null +++ b/core/shims/spark411/src/main/scala/org/apache/spark/Spark411Helper.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.executor.{CoarseGrainedExecutorBackend, RayCoarseGrainedExecutorBackend, RayDPExecutorBackendFactory, TaskMetrics} +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.resource.ResourceProfile +import org.apache.spark.rpc.RpcEnv + +import java.net.URL + +object Spark411Helper { + def getExecutorBackendFactory(): RayDPExecutorBackendFactory = { + new RayDPExecutorBackendFactory { + override def createExecutorBackend( + rpcEnv: RpcEnv, + driverUrl: String, + executorId: String, + bindAddress: String, + hostname: String, + cores: Int, + userClassPath: Seq[URL], + env: SparkEnv, + resourcesFileOpt: Option[String], + resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = { + new RayCoarseGrainedExecutorBackend( + rpcEnv, + driverUrl, + executorId, + bindAddress, + hostname, + cores, + userClassPath, + env, + resourcesFileOpt, + resourceProfile) + } + } + } + + def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { + new TaskContextImpl( + stageId = 0, + stageAttemptNumber = 0, + partitionId = partitionId, + taskAttemptId = 0, + attemptNumber = 0, + numPartitions = 0, + taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0), + localProperties = new java.util.Properties, + metricsSystem = env.metricsSystem, + taskMetrics = TaskMetrics.empty, + cpus = 0, + resources = Map.empty + ) + } +} diff --git a/core/shims/spark411/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala b/core/shims/spark411/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala new file mode 100644 index 00000000..2e6b5e25 --- /dev/null +++ b/core/shims/spark411/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import java.net.URL + +import org.apache.spark.SparkEnv +import org.apache.spark.resource.ResourceProfile +import org.apache.spark.rpc.RpcEnv + +class RayCoarseGrainedExecutorBackend( + rpcEnv: RpcEnv, + driverUrl: String, + executorId: String, + bindAddress: String, + hostname: String, + cores: Int, + userClassPath: Seq[URL], + env: SparkEnv, + resourcesFileOpt: Option[String], + resourceProfile: ResourceProfile) + extends CoarseGrainedExecutorBackend( + rpcEnv, + driverUrl, + executorId, + bindAddress, + hostname, + cores, + env, + resourcesFileOpt, + resourceProfile) { + + override def getUserClassPath: Seq[URL] = userClassPath + +} diff --git a/core/shims/spark411/src/main/scala/org/apache/spark/sql/Spark411SQLHelper.scala b/core/shims/spark411/src/main/scala/org/apache/spark/sql/Spark411SQLHelper.scala new file mode 100644 index 00000000..b6853022 --- /dev/null +++ b/core/shims/spark411/src/main/scala/org/apache/spark/sql/Spark411SQLHelper.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.TaskContext +import org.apache.spark.sql.execution.arrow.ArrowConverters +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.sql.classic.{SparkSession => ClassicSparkSession} + +object Spark411SQLHelper { + def toArrowSchema(schema: StructType, timeZoneId: String): Schema = { + ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames = true, largeVarTypes = false) + } + + def toArrowBatchRdd(df: DataFrame): org.apache.spark.rdd.RDD[Array[Byte]] = { + val conf = df.sparkSession.asInstanceOf[ClassicSparkSession].sessionState.conf + val timeZoneId = conf.sessionLocalTimeZone + val maxRecordsPerBatch = conf.getConf(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) + df.queryExecution.toRdd.mapPartitions(iter => { + val context = TaskContext.get() + ArrowConverters.toBatchIterator( + iter, + df.schema, + maxRecordsPerBatch, + timeZoneId, + true, // errorOnDuplicatedFieldNames + false, // largeVarTypes + context) + }) + } + + /** + * Converts a JavaRDD of Arrow batches (serialized as byte arrays) to a DataFrame. + * This is the reverse operation of toArrowBatchRdd. + * + * @param rdd JavaRDD containing Arrow batches serialized as byte arrays + * @param schema JSON string representation of the StructType schema + * @param session SparkSession to use for DataFrame creation + * @return DataFrame reconstructed from the Arrow batches + */ + def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = { + val structType = DataType.fromJson(schema).asInstanceOf[StructType] + val classicSession = session.asInstanceOf[ClassicSparkSession] + + // Capture timezone on driver side - cannot access sessionState on executors + val timeZoneId = classicSession.sessionState.conf.sessionLocalTimeZone + + // Create an RDD of InternalRow by deserializing Arrow batches per partition + val rowRdd = rdd.rdd.flatMap { arrowBatch => + ArrowConverters.fromBatchIterator( + Iterator(arrowBatch), + structType, + timeZoneId, // Use captured value, not sessionState + true, // errorOnDuplicatedFieldNames + false, // largeVarTypes + TaskContext.get() + ) + } + + classicSession.internalCreateDataFrame(rowRdd, structType) + } +} diff --git a/python/pyproject.toml b/python/pyproject.toml new file mode 100644 index 00000000..cc948b9c --- /dev/null +++ b/python/pyproject.toml @@ -0,0 +1,45 @@ +[build-system] +requires = ["setuptools", "wheel", "grpcio-tools"] +build-backend = "setuptools.build_meta" + +[project] +name = "raydp" +dynamic = ["version", "readme"] +description = "RayDP: Distributed Data Processing on Ray" +authors = [ + {name = "RayDP Developers", email = "raydp-dev@googlegroups.com"} +] +license = "Apache-2.0" +keywords = ["raydp", "spark", "ray", "distributed", "data-processing"] +classifiers = [ + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] +requires-python = ">=3.10" +dependencies = [ + "numpy", + "pandas >= 1.1.4", + "psutil", + "pyarrow >= 4.0.1", + "ray >= 2.53.0", + "pyspark >= 4.1.0", + "netifaces", + "protobuf > 3.19.5" +] + +[tool.setuptools] +packages = ["raydp", "raydp.jars", "raydp.bin"] +include-package-data = true + +[tool.setuptools.package-dir] +"raydp.jars" = "deps/jars" +"raydp.bin" = "deps/bin" +"mpi_network_proto" = "raydp/mpi/network" + +[tool.setuptools.package-data] +"raydp.jars" = ["*.jar"] +"raydp.bin" = ["raydp-submit"] + +[tool.setuptools.dynamic] +version = {attr = "setup.VERSION"} diff --git a/python/setup.py b/python/setup.py index dc9bbcc7..0ac907e3 100644 --- a/python/setup.py +++ b/python/setup.py @@ -22,12 +22,11 @@ from datetime import datetime from shutil import copy2, rmtree -from grpc_tools.command import build_package_protos from setuptools import find_packages, setup, Command build_mode = os.getenv("RAYDP_BUILD_MODE", "") package_name = os.getenv("RAYDP_PACKAGE_NAME", "raydp") -BASE_VERSION = "1.7.0" +BASE_VERSION = "2.0.0" if build_mode == "nightly": VERSION = BASE_VERSION + datetime.today().strftime("b%Y%m%d.dev0") # for legacy raydp_nightly package @@ -38,30 +37,14 @@ ROOT_DIR = os.path.dirname(__file__) -TEMP_PATH = "deps" CORE_DIR = os.path.abspath("../core") BIN_DIR = os.path.abspath("../bin") JARS_PATH = glob.glob(os.path.join(CORE_DIR, f"**/target/raydp-*.jar"), recursive=True) -JARS_TARGET = os.path.join(TEMP_PATH, "jars") +JARS_TARGET = os.path.join(ROOT_DIR, "raydp", "jars") SCRIPT_PATH = os.path.join(BIN_DIR, f"raydp-submit") -SCRIPT_TARGET = os.path.join(TEMP_PATH, "bin") - -if len(JARS_PATH) == 0: - print("Can't find core module jars, you need to build the jars with 'mvn clean package'" - " under core directory first.", file=sys.stderr) - sys.exit(-1) - -# build the temp dir -try: - os.mkdir(TEMP_PATH) - os.mkdir(JARS_TARGET) - os.mkdir(SCRIPT_TARGET) -except: - print(f"Temp path for symlink to parent already exists {TEMP_PATH}", file=sys.stderr) - sys.exit(-1) - +SCRIPT_TARGET = os.path.join(ROOT_DIR, "raydp", "bin") class CustomBuildPackageProtos(Command): """Command to generate project *_pb2.py modules from proto files. @@ -80,6 +63,7 @@ def finalize_options(self): pass def run(self): + from grpc_tools.command import build_package_protos # due to limitations of the proto generator, we require that only *one* # directory is provided as an 'include' directory. We assume it's the '' key # to `self.distribution.package_dir` (and get a key error if it's not @@ -87,24 +71,29 @@ def run(self): build_package_protos(self.distribution.package_dir["mpi_network_proto"], self.strict_mode) +if __name__ == "__main__": + if len(JARS_PATH) == 0: + print("Can't find core module jars, you need to build the jars with 'mvn clean package'" + " under core directory first.", file=sys.stderr) + sys.exit(-1) + + # build the temp dir + if os.path.exists(JARS_TARGET): + rmtree(JARS_TARGET) + if os.path.exists(SCRIPT_TARGET): + rmtree(SCRIPT_TARGET) + os.mkdir(JARS_TARGET) + os.mkdir(SCRIPT_TARGET) + with open(os.path.join(JARS_TARGET, "__init__.py"), "w") as f: + f.write("") + with open(os.path.join(SCRIPT_TARGET, "__init__.py"), "w") as f: + f.write("") -try: for jar_path in JARS_PATH: print(f"Copying {jar_path} to {JARS_TARGET}") copy2(jar_path, JARS_TARGET) copy2(SCRIPT_PATH, SCRIPT_TARGET) - install_requires = [ - "numpy", - "pandas >= 1.1.4", - "psutil", - "pyarrow >= 4.0.1", - "ray >= 2.37.0", - "pyspark >= 3.1.1, <=3.5.7", - "netifaces", - "protobuf > 3.19.5" - ] - _packages = find_packages() _packages.append("raydp.jars") _packages.append("raydp.bin") @@ -112,12 +101,6 @@ def run(self): setup( name=package_name, version=VERSION, - author="RayDP Developers", - author_email="raydp-dev@googlegroups.com", - license="Apache 2.0", - url="https://github.com/ray-project/raydp", - keywords="raydp spark ray distributed data-processing", - description="RayDP: Distributed Data Processing on Ray", long_description=io.open( os.path.join(ROOT_DIR, os.path.pardir, "README.md"), "r", @@ -125,23 +108,9 @@ def run(self): long_description_content_type="text/markdown", packages=_packages, include_package_data=True, - package_dir={"raydp.jars": "deps/jars", "raydp.bin": "deps/bin", - "mpi_network_proto": "raydp/mpi/network"}, + package_dir={"mpi_network_proto": "raydp/mpi/network"}, package_data={"raydp.jars": ["*.jar"], "raydp.bin": ["raydp-submit"]}, cmdclass={ 'build_proto_modules': CustomBuildPackageProtos, }, - install_requires=install_requires, - setup_requires=["grpcio-tools"], - python_requires='>=3.6', - classifiers=[ - 'License :: OSI Approved :: Apache Software License', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - ] ) -finally: - rmtree(os.path.join(TEMP_PATH, "jars")) - rmtree(os.path.join(TEMP_PATH, "bin")) - os.rmdir(TEMP_PATH) From 453b2ae4305daa44dc0a70b9d84909a138aa483e Mon Sep 17 00:00:00 2001 From: Rexwell Minnis Date: Sat, 7 Feb 2026 11:43:30 -0500 Subject: [PATCH 02/16] Add ScalaTest suite for Spark 4.1.1 shim, bump Scala to 2.13.17 - Add 13-test ScalaTest suite validating shim descriptor, SPI loading, Arrow schema mapping, and Arrow round-trip conversions (including nulls, timestamps, decimals, empty DataFrames, and multi-batch) - Fix bug in Spark411SQLHelper.toArrowBatchRdd where df.schema was captured inside mapPartitions closure causing CANNOT_INVOKE_IN_TRANSFORMATIONS on Spark 4.1 - Bump Scala from 2.13.12 to 2.13.17 to match Spark 4.1.1 --- core/pom.xml | 2 +- core/shims/spark411/pom.xml | 29 ++ .../apache/spark/sql/Spark411SQLHelper.scala | 3 +- .../raydp/shims/SparkShims411Suite.scala | 286 ++++++++++++++++++ 4 files changed, 318 insertions(+), 2 deletions(-) create mode 100644 core/shims/spark411/src/test/scala/com/intel/raydp/shims/SparkShims411Suite.scala diff --git a/core/pom.xml b/core/pom.xml index bbaa074d..f97c32e1 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -29,7 +29,7 @@ UTF-8 UTF-8 17 - 2.13.12 + 2.13.17 2.17.0 2.13 5.10.1 diff --git a/core/shims/spark411/pom.xml b/core/shims/spark411/pom.xml index 1c80830c..9810a847 100644 --- a/core/shims/spark411/pom.xml +++ b/core/shims/spark411/pom.xml @@ -33,6 +33,12 @@ ${spark.version} provided + + org.scalatest + scalatest_${scala.binary.version} + 3.2.18 + test + @@ -59,6 +65,29 @@ + + org.apache.maven.plugins + maven-surefire-plugin + + true + + + + org.scalatest + scalatest-maven-plugin + 2.2.0 + + --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.useDirectMethodHandle=none + + + + test + + test + + + + \ No newline at end of file diff --git a/core/shims/spark411/src/main/scala/org/apache/spark/sql/Spark411SQLHelper.scala b/core/shims/spark411/src/main/scala/org/apache/spark/sql/Spark411SQLHelper.scala index b6853022..0b97fd78 100644 --- a/core/shims/spark411/src/main/scala/org/apache/spark/sql/Spark411SQLHelper.scala +++ b/core/shims/spark411/src/main/scala/org/apache/spark/sql/Spark411SQLHelper.scala @@ -35,11 +35,12 @@ object Spark411SQLHelper { val conf = df.sparkSession.asInstanceOf[ClassicSparkSession].sessionState.conf val timeZoneId = conf.sessionLocalTimeZone val maxRecordsPerBatch = conf.getConf(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) + val schema = df.schema df.queryExecution.toRdd.mapPartitions(iter => { val context = TaskContext.get() ArrowConverters.toBatchIterator( iter, - df.schema, + schema, maxRecordsPerBatch, timeZoneId, true, // errorOnDuplicatedFieldNames diff --git a/core/shims/spark411/src/test/scala/com/intel/raydp/shims/SparkShims411Suite.scala b/core/shims/spark411/src/test/scala/com/intel/raydp/shims/SparkShims411Suite.scala new file mode 100644 index 00000000..96b6082b --- /dev/null +++ b/core/shims/spark411/src/test/scala/com/intel/raydp/shims/SparkShims411Suite.scala @@ -0,0 +1,286 @@ +package com.intel.raydp.shims + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.funsuite.AnyFunSuite + +import java.sql.{Date, Timestamp} +import java.time.{LocalDate, ZoneId} + +import org.apache.arrow.vector.types.pojo.ArrowType +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.types._ + +class SparkShims411Suite extends AnyFunSuite with BeforeAndAfterAll { + + private var spark: SparkSession = _ + + override def beforeAll(): Unit = { + super.beforeAll() + spark = SparkSession.builder() + .master("local[2]") + .appName("SparkShims411Suite") + .config("spark.driver.host", "127.0.0.1") + .config("spark.driver.bindAddress", "127.0.0.1") + .config("spark.ui.enabled", "false") + .getOrCreate() + } + + override def afterAll(): Unit = { + if (spark != null) { + spark.stop() + spark = null + } + super.afterAll() + } + + test("shim descriptor returns 4.1.1") { + val shim = new SparkShims411() + val descriptor = shim.getShimDescriptor + assert(descriptor.isInstanceOf[SparkShimDescriptor]) + val sparkDescriptor = descriptor.asInstanceOf[SparkShimDescriptor] + assert(sparkDescriptor.major === 4) + assert(sparkDescriptor.minor === 1) + assert(sparkDescriptor.patch === 1) + assert(sparkDescriptor.toString === "4.1.1") + } + + test("provider matches 4.1 versions") { + val provider = new spark411.SparkShimProvider() + assert(provider.matches("4.1.1")) + assert(provider.matches("4.1.0")) + assert(!provider.matches("3.5.0")) + assert(!provider.matches("4.0.0")) + } + + test("provider creates SparkShims411 instance") { + val provider = new spark411.SparkShimProvider() + val shim = provider.createShim + assert(shim.isInstanceOf[SparkShims411]) + } + + test("SPI service loading works") { + SparkShimLoader.setSparkShimProviderClass( + "com.intel.raydp.shims.spark411.SparkShimProvider") + val shim = SparkShimLoader.getSparkShims + assert(shim.isInstanceOf[SparkShims411]) + assert(shim.getShimDescriptor.toString === "4.1.1") + } + + test("toArrowSchema produces valid Arrow schema") { + val shim = new SparkShims411() + val sparkSchema = new StructType() + .add("id", IntegerType) + .add("name", StringType) + .add("value", DoubleType) + + val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone + val arrowSchema = shim.toArrowSchema(sparkSchema, timeZoneId) + + assert(arrowSchema.getFields.size() === 3) + assert(arrowSchema.getFields.get(0).getName === "id") + assert(arrowSchema.getFields.get(1).getName === "name") + assert(arrowSchema.getFields.get(2).getName === "value") + } + + test("Arrow round-trip: DataFrame to ArrowBatch to DataFrame") { + val shim = new SparkShims411() + + val schema = new StructType() + .add("id", IntegerType) + .add("name", StringType) + .add("value", DoubleType) + val rows = java.util.Arrays.asList( + Row(1, "alice", 10.0), + Row(2, "bob", 20.0), + Row(3, "carol", 30.0)) + val original = spark.createDataFrame(rows, schema) + + val arrowRdd = shim.toArrowBatchRdd(original) + val schemaJson = original.schema.json + + val restored = shim.toDataFrame( + arrowRdd.toJavaRDD(), schemaJson, spark) + + val originalRows = original.collect().sortBy(_.getInt(0)) + val restoredRows = restored.collect().sortBy(_.getInt(0)) + + assert(originalRows.length === restoredRows.length) + originalRows.zip(restoredRows).foreach { case (orig, rest) => + assert(orig.getInt(0) === rest.getInt(0)) + assert(orig.getString(1) === rest.getString(1)) + assert(orig.getDouble(2) === rest.getDouble(2)) + } + } + + test("toArrowSchema maps Spark types to correct Arrow types") { + val shim = new SparkShims411() + val sparkSchema = new StructType() + .add("bool", BooleanType) + .add("byte", ByteType) + .add("short", ShortType) + .add("int", IntegerType) + .add("long", LongType) + .add("float", FloatType) + .add("double", DoubleType) + .add("str", StringType) + .add("bin", BinaryType) + + val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone + val arrowSchema = shim.toArrowSchema(sparkSchema, timeZoneId) + val fields = arrowSchema.getFields + + assert(fields.get(0).getType.isInstanceOf[ArrowType.Bool]) + assert(fields.get(1).getType === new ArrowType.Int(8, true)) + assert(fields.get(2).getType === new ArrowType.Int(16, true)) + assert(fields.get(3).getType === new ArrowType.Int(32, true)) + assert(fields.get(4).getType === new ArrowType.Int(64, true)) + assert(fields.get(5).getType === new ArrowType.FloatingPoint( + org.apache.arrow.vector.types.FloatingPointPrecision.SINGLE)) + assert(fields.get(6).getType === new ArrowType.FloatingPoint( + org.apache.arrow.vector.types.FloatingPointPrecision.DOUBLE)) + assert(fields.get(7).getType.isInstanceOf[ArrowType.Utf8]) + assert(fields.get(8).getType.isInstanceOf[ArrowType.Binary]) + } + + test("Arrow round-trip preserves null values") { + val shim = new SparkShims411() + + val schema = new StructType() + .add("id", IntegerType, nullable = false) + .add("name", StringType, nullable = true) + .add("value", DoubleType, nullable = true) + val rows = java.util.Arrays.asList( + Row(1, "alice", 10.0), + Row(2, null, null), + Row(3, "carol", 30.0)) + val original = spark.createDataFrame(rows, schema) + + val arrowRdd = shim.toArrowBatchRdd(original) + val schemaJson = original.schema.json + val restored = shim.toDataFrame(arrowRdd.toJavaRDD(), schemaJson, spark) + + val restoredRows = restored.collect().sortBy(_.getInt(0)) + assert(restoredRows.length === 3) + assert(restoredRows(0).getString(1) === "alice") + assert(restoredRows(1).isNullAt(1)) + assert(restoredRows(1).isNullAt(2)) + assert(restoredRows(2).getDouble(2) === 30.0) + } + + test("Arrow round-trip with multiple batches") { + val shim = new SparkShims411() + + // Force small batches: 2 records per batch + spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "2") + try { + val schema = new StructType() + .add("id", IntegerType) + .add("label", StringType) + val rows = (1 to 10).map(i => Row(i, s"row_$i")) + val original = spark.createDataFrame( + java.util.Arrays.asList(rows: _*), schema) + .repartition(1) // single partition so multiple batches are in one partition + + val arrowRdd = shim.toArrowBatchRdd(original) + val schemaJson = original.schema.json + val restored = shim.toDataFrame(arrowRdd.toJavaRDD(), schemaJson, spark) + + val restoredRows = restored.collect().sortBy(_.getInt(0)) + assert(restoredRows.length === 10) + (1 to 10).foreach { i => + assert(restoredRows(i - 1).getInt(0) === i) + assert(restoredRows(i - 1).getString(1) === s"row_$i") + } + } finally { + spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") + } + } + + test("Arrow round-trip preserves Timestamp and Date values") { + val shim = new SparkShims411() + + val ts1 = Timestamp.valueOf("2025-01-15 10:30:00") + val ts2 = Timestamp.valueOf("2025-06-30 23:59:59") + val d1 = Date.valueOf("2025-01-15") + val d2 = Date.valueOf("2025-06-30") + + val schema = new StructType() + .add("id", IntegerType) + .add("ts", TimestampType) + .add("dt", DateType) + val rows = java.util.Arrays.asList( + Row(1, ts1, d1), + Row(2, ts2, d2)) + val original = spark.createDataFrame(rows, schema) + + val arrowRdd = shim.toArrowBatchRdd(original) + val schemaJson = original.schema.json + val restored = shim.toDataFrame(arrowRdd.toJavaRDD(), schemaJson, spark) + + val restoredRows = restored.collect().sortBy(_.getInt(0)) + assert(restoredRows.length === 2) + assert(restoredRows(0).getTimestamp(1) === ts1) + assert(restoredRows(0).getDate(2) === d1) + assert(restoredRows(1).getTimestamp(1) === ts2) + assert(restoredRows(1).getDate(2) === d2) + } + + test("Arrow round-trip preserves Decimal values") { + val shim = new SparkShims411() + + val schema = new StructType() + .add("id", IntegerType) + .add("price", DecimalType(18, 6)) + .add("quantity", DecimalType(10, 0)) + val rows = java.util.Arrays.asList( + Row(1, new java.math.BigDecimal("12345.678900"), new java.math.BigDecimal("100")), + Row(2, new java.math.BigDecimal("0.000001"), new java.math.BigDecimal("0")), + Row(3, new java.math.BigDecimal("-9999.999999"), new java.math.BigDecimal("999"))) + val original = spark.createDataFrame(rows, schema) + + val arrowRdd = shim.toArrowBatchRdd(original) + val schemaJson = original.schema.json + val restored = shim.toDataFrame(arrowRdd.toJavaRDD(), schemaJson, spark) + + val restoredRows = restored.collect().sortBy(_.getInt(0)) + assert(restoredRows.length === 3) + assert(restoredRows(0).getDecimal(1).compareTo( + new java.math.BigDecimal("12345.678900")) === 0) + assert(restoredRows(1).getDecimal(1).compareTo( + new java.math.BigDecimal("0.000001")) === 0) + assert(restoredRows(2).getDecimal(1).compareTo( + new java.math.BigDecimal("-9999.999999")) === 0) + assert(restoredRows(2).getDecimal(2).compareTo( + new java.math.BigDecimal("999")) === 0) + } + + test("Arrow round-trip with empty DataFrame") { + val shim = new SparkShims411() + + val schema = new StructType() + .add("id", IntegerType) + .add("name", StringType) + val rows = java.util.Arrays.asList[Row]() + val original = spark.createDataFrame(rows, schema) + + val arrowRdd = shim.toArrowBatchRdd(original) + val schemaJson = original.schema.json + val restored = shim.toDataFrame(arrowRdd.toJavaRDD(), schemaJson, spark) + + assert(restored.collect().length === 0) + assert(restored.schema === original.schema) + } + + test("getDummyTaskContext returns valid TaskContext") { + val shim = new SparkShims411() + val env = SparkEnv.get + + val ctx = shim.getDummyTaskContext(42, env) + + assert(ctx.isInstanceOf[TaskContext]) + assert(ctx.partitionId() === 42) + assert(ctx.stageId() === 0) + assert(ctx.attemptNumber() === 0) + } +} From bd16e3cc411bf04a6ed28007321134fa049e8f1d Mon Sep 17 00:00:00 2001 From: Rexwell Minnis Date: Sat, 7 Feb 2026 12:38:06 -0500 Subject: [PATCH 03/16] Fix NPE in BlockInfoManager.lockForReading for Spark 4.1.1 Spark 4.1.1's BlockInfoManager requires tasks to be registered via registerTask() before they can acquire read locks. getRDDPartition was creating a dummy TaskContext without registering, causing an NPE in lockForReading when converting DataFrames to Ray Datasets. - Generate unique taskAttemptIds via AtomicLong (starting at 1000000L) to avoid collisions with real Spark tasks and concurrent calls - Register task with BlockManager before block access - Release all locks and unset TaskContext in finally block - Add regression test confirming the NPE without registration - Add lifecycle test validating register/read/cleanup flow --- .../apache/spark/executor/RayDPExecutor.scala | 53 ++++++----- .../org/apache/spark/Spark411Helper.scala | 11 ++- .../raydp/shims/SparkShims411Suite.scala | 88 ++++++++++++++++++- 3 files changed, 126 insertions(+), 26 deletions(-) diff --git a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala index 0ed699dd..7962bb90 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala @@ -324,29 +324,36 @@ class RayDPExecutor( val env = SparkEnv.get val context = SparkShimLoader.getSparkShims.getDummyTaskContext(partitionId, env) TaskContext.setTaskContext(context) - val schema = Schema.fromJSON(schemaStr) - val blockId = BlockId.apply("rdd_" + rddId + "_" + partitionId) - val iterator = env.blockManager.get(blockId)(classTag[Array[Byte]]) match { - case Some(blockResult) => - blockResult.data.asInstanceOf[Iterator[Array[Byte]]] - case None => - logWarning("The cached block has been lost. Cache it again via driver agent") - requestRecacheRDD(rddId, driverAgentUrl) - env.blockManager.get(blockId)(classTag[Array[Byte]]) match { - case Some(blockResult) => - blockResult.data.asInstanceOf[Iterator[Array[Byte]]] - case None => - throw new RayDPException("Still cannot get the block after recache!") - } + val taskAttemptId = context.taskAttemptId() + env.blockManager.registerTask(taskAttemptId) + try { + val schema = Schema.fromJSON(schemaStr) + val blockId = BlockId.apply("rdd_" + rddId + "_" + partitionId) + val iterator = env.blockManager.get(blockId)(classTag[Array[Byte]]) match { + case Some(blockResult) => + blockResult.data.asInstanceOf[Iterator[Array[Byte]]] + case None => + logWarning("The cached block has been lost. Cache it again via driver agent") + requestRecacheRDD(rddId, driverAgentUrl) + env.blockManager.get(blockId)(classTag[Array[Byte]]) match { + case Some(blockResult) => + blockResult.data.asInstanceOf[Iterator[Array[Byte]]] + case None => + throw new RayDPException("Still cannot get the block after recache!") + } + } + val byteOut = new ByteArrayOutputStream() + val writeChannel = new WriteChannel(Channels.newChannel(byteOut)) + MessageSerializer.serialize(writeChannel, schema) + iterator.foreach(writeChannel.write) + ArrowStreamWriter.writeEndOfStream(writeChannel, new IpcOption) + val result = byteOut.toByteArray + writeChannel.close + byteOut.close + result + } finally { + env.blockManager.releaseAllLocksForTask(taskAttemptId) + TaskContext.unset() } - val byteOut = new ByteArrayOutputStream() - val writeChannel = new WriteChannel(Channels.newChannel(byteOut)) - MessageSerializer.serialize(writeChannel, schema) - iterator.foreach(writeChannel.write) - ArrowStreamWriter.writeEndOfStream(writeChannel, new IpcOption) - val result = byteOut.toByteArray - writeChannel.close - byteOut.close - result } } diff --git a/core/shims/spark411/src/main/scala/org/apache/spark/Spark411Helper.scala b/core/shims/spark411/src/main/scala/org/apache/spark/Spark411Helper.scala index ccf37472..24afe7dc 100644 --- a/core/shims/spark411/src/main/scala/org/apache/spark/Spark411Helper.scala +++ b/core/shims/spark411/src/main/scala/org/apache/spark/Spark411Helper.scala @@ -23,8 +23,10 @@ import org.apache.spark.resource.ResourceProfile import org.apache.spark.rpc.RpcEnv import java.net.URL +import java.util.concurrent.atomic.AtomicLong object Spark411Helper { + private val nextTaskAttemptId = new AtomicLong(1000000L) def getExecutorBackendFactory(): RayDPExecutorBackendFactory = { new RayDPExecutorBackendFactory { override def createExecutorBackend( @@ -53,15 +55,20 @@ object Spark411Helper { } } + def setTaskContext(ctx: TaskContext): Unit = TaskContext.setTaskContext(ctx) + + def unsetTaskContext(): Unit = TaskContext.unset() + def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { + val taskAttemptId = nextTaskAttemptId.getAndIncrement() new TaskContextImpl( stageId = 0, stageAttemptNumber = 0, partitionId = partitionId, - taskAttemptId = 0, + taskAttemptId = taskAttemptId, attemptNumber = 0, numPartitions = 0, - taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0), + taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskAttemptId), localProperties = new java.util.Properties, metricsSystem = env.metricsSystem, taskMetrics = TaskMetrics.empty, diff --git a/core/shims/spark411/src/test/scala/com/intel/raydp/shims/SparkShims411Suite.scala b/core/shims/spark411/src/test/scala/com/intel/raydp/shims/SparkShims411Suite.scala index 96b6082b..26a0908d 100644 --- a/core/shims/spark411/src/test/scala/com/intel/raydp/shims/SparkShims411Suite.scala +++ b/core/shims/spark411/src/test/scala/com/intel/raydp/shims/SparkShims411Suite.scala @@ -6,10 +6,13 @@ import org.scalatest.funsuite.AnyFunSuite import java.sql.{Date, Timestamp} import java.time.{LocalDate, ZoneId} +import scala.reflect.classTag + import org.apache.arrow.vector.types.pojo.ArrowType -import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.{Spark411Helper, SparkEnv, TaskContext} import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.types._ +import org.apache.spark.storage.{RDDBlockId, StorageLevel} class SparkShims411Suite extends AnyFunSuite with BeforeAndAfterAll { @@ -283,4 +286,87 @@ class SparkShims411Suite extends AnyFunSuite with BeforeAndAfterAll { assert(ctx.stageId() === 0) assert(ctx.attemptNumber() === 0) } + + test("getDummyTaskContext generates unique taskAttemptIds") { + val shim = new SparkShims411() + val env = SparkEnv.get + + val ids = (0 until 100).map { i => + shim.getDummyTaskContext(i, env).taskAttemptId() + } + + assert(ids.distinct.size === 100, "all taskAttemptIds must be unique") + assert(ids.forall(_ >= 1000000L), "taskAttemptIds must start at 1000000+") + } + + test("BlockManager.get NPEs without registerTask (regression)") { + // This documents the exact NPE bug we fixed: Spark 4.1's BlockInfoManager + // requires tasks to be registered before they can acquire read locks. + // Setting a TaskContext with a specific taskAttemptId WITHOUT calling + // registerTask causes lockForReading to NPE on: + // readLocksByTask.get(taskAttemptId).add(blockId) + // Note: if NO TaskContext were set, BlockInfoManager would use the + // pre-registered NON_TASK_WRITER (-1024) and no NPE would occur. + val env = SparkEnv.get + val blockManager = env.blockManager + + // Cache a small RDD to have a block to read + val rdd = spark.sparkContext.parallelize(Seq(Array[Byte](1, 2, 3)), 1) + rdd.persist(StorageLevel.MEMORY_ONLY) + rdd.count() // force materialization + val blockId = RDDBlockId(rdd.id, 0) + + // Without registerTask: set TaskContext but do NOT register → NPE + val ctx = Spark411Helper.getDummyTaskContext(0, env) + Spark411Helper.setTaskContext(ctx) + try { + val ex = intercept[NullPointerException] { + blockManager.get(blockId)(classTag[Array[Byte]]) + } + assert(ex != null) + } finally { + Spark411Helper.unsetTaskContext() + } + + rdd.unpersist(blocking = true) + } + + test("BlockManager.get succeeds with registerTask and cleanup") { + // Full lifecycle test mirroring RayDPExecutor.getRDDPartition: + // register task → read block → release locks → unset context + val env = SparkEnv.get + val blockManager = env.blockManager + + val rdd = spark.sparkContext.parallelize(Seq(Array[Byte](10, 20, 30)), 1) + rdd.persist(StorageLevel.MEMORY_ONLY) + rdd.count() + val blockId = RDDBlockId(rdd.id, 0) + + val ctx = Spark411Helper.getDummyTaskContext(0, env) + Spark411Helper.setTaskContext(ctx) + val taskAttemptId = ctx.taskAttemptId() + blockManager.registerTask(taskAttemptId) + try { + val result = blockManager.get(blockId)(classTag[Array[Byte]]) + assert(result.isDefined, "block should be readable after registerTask") + } finally { + blockManager.releaseAllLocksForTask(taskAttemptId) + Spark411Helper.unsetTaskContext() + } + + // Verify lock was released: a second task can also read the same block + val ctx2 = Spark411Helper.getDummyTaskContext(0, env) + Spark411Helper.setTaskContext(ctx2) + val taskAttemptId2 = ctx2.taskAttemptId() + blockManager.registerTask(taskAttemptId2) + try { + val result2 = blockManager.get(blockId)(classTag[Array[Byte]]) + assert(result2.isDefined, "block should be readable by a second task after cleanup") + } finally { + blockManager.releaseAllLocksForTask(taskAttemptId2) + Spark411Helper.unsetTaskContext() + } + + rdd.unpersist(blocking = true) + } } From 1e9347a9358ed9c46ff3b573fa31d0cb804aebaa Mon Sep 17 00:00:00 2001 From: Rexwell Minnis Date: Sat, 7 Feb 2026 14:16:16 -0500 Subject: [PATCH 04/16] Arrow performance optimizations for Spark 4.1.1 - Default Arrow IPC zstd compression (spark.sql.execution.arrow.compression.codec) - Pre-size ByteArrayOutputStream in getRDDPartition using BlockResult.bytes() - Read arrowUseLargeVarTypes from SQLConf instead of hardcoding false --- .../org/apache/spark/executor/RayDPExecutor.scala | 9 +++++---- .../apache/spark/sql/raydp/ObjectStoreWriter.scala | 3 ++- .../scala/com/intel/raydp/shims/SparkShims.scala | 2 +- .../scala/com/intel/raydp/shims/SparkShims411.scala | 4 ++-- .../org/apache/spark/sql/Spark411SQLHelper.scala | 12 +++++++----- python/raydp/spark/ray_cluster.py | 6 ++++++ 6 files changed, 23 insertions(+), 13 deletions(-) diff --git a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala index 7962bb90..79063553 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala @@ -329,20 +329,21 @@ class RayDPExecutor( try { val schema = Schema.fromJSON(schemaStr) val blockId = BlockId.apply("rdd_" + rddId + "_" + partitionId) - val iterator = env.blockManager.get(blockId)(classTag[Array[Byte]]) match { + val (iterator, blockBytes) = env.blockManager.get(blockId)(classTag[Array[Byte]]) match { case Some(blockResult) => - blockResult.data.asInstanceOf[Iterator[Array[Byte]]] + (blockResult.data.asInstanceOf[Iterator[Array[Byte]]], blockResult.bytes) case None => logWarning("The cached block has been lost. Cache it again via driver agent") requestRecacheRDD(rddId, driverAgentUrl) env.blockManager.get(blockId)(classTag[Array[Byte]]) match { case Some(blockResult) => - blockResult.data.asInstanceOf[Iterator[Array[Byte]]] + (blockResult.data.asInstanceOf[Iterator[Array[Byte]]], blockResult.bytes) case None => throw new RayDPException("Still cannot get the block after recache!") } } - val byteOut = new ByteArrayOutputStream() + val estimatedSize = Math.max(blockBytes + 1024, 1024).toInt + val byteOut = new ByteArrayOutputStream(estimatedSize) val writeChannel = new WriteChannel(Channels.newChannel(byteOut)) MessageSerializer.serialize(writeChannel, schema) iterator.foreach(writeChannel.write) diff --git a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala index 4c910062..282781f3 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala @@ -96,7 +96,8 @@ object ObjectStoreWriter { def toArrowSchema(df: DataFrame): Schema = { val conf = df.queryExecution.sparkSession.sessionState.conf val timeZoneId = conf.getConf(SQLConf.SESSION_LOCAL_TIMEZONE) - SparkShimLoader.getSparkShims.toArrowSchema(df.schema, timeZoneId) + val largeVarTypes = conf.arrowUseLargeVarTypes + SparkShimLoader.getSparkShims.toArrowSchema(df.schema, timeZoneId, largeVarTypes) } @deprecated diff --git a/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala index c9e864da..31d507cf 100644 --- a/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ b/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -40,7 +40,7 @@ trait SparkShims { def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext - def toArrowSchema(schema : StructType, timeZoneId : String) : Schema + def toArrowSchema(schema : StructType, timeZoneId : String, largeVarTypes : Boolean = false) : Schema def toArrowBatchRdd(df: DataFrame): RDD[Array[Byte]] } diff --git a/core/shims/spark411/src/main/scala/com/intel/raydp/shims/SparkShims411.scala b/core/shims/spark411/src/main/scala/com/intel/raydp/shims/SparkShims411.scala index 7cb7b8e7..c44183be 100644 --- a/core/shims/spark411/src/main/scala/com/intel/raydp/shims/SparkShims411.scala +++ b/core/shims/spark411/src/main/scala/com/intel/raydp/shims/SparkShims411.scala @@ -23,8 +23,8 @@ class SparkShims411 extends SparkShims { Spark411Helper.getDummyTaskContext(partitionId, env) } - override def toArrowSchema(schema: StructType, timeZoneId: String): Schema = { - Spark411SQLHelper.toArrowSchema(schema, timeZoneId) + override def toArrowSchema(schema: StructType, timeZoneId: String, largeVarTypes: Boolean = false): Schema = { + Spark411SQLHelper.toArrowSchema(schema, timeZoneId, largeVarTypes) } override def toArrowBatchRdd(df: DataFrame): RDD[Array[Byte]] = { diff --git a/core/shims/spark411/src/main/scala/org/apache/spark/sql/Spark411SQLHelper.scala b/core/shims/spark411/src/main/scala/org/apache/spark/sql/Spark411SQLHelper.scala index 0b97fd78..20bb0419 100644 --- a/core/shims/spark411/src/main/scala/org/apache/spark/sql/Spark411SQLHelper.scala +++ b/core/shims/spark411/src/main/scala/org/apache/spark/sql/Spark411SQLHelper.scala @@ -27,14 +27,15 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.sql.classic.{SparkSession => ClassicSparkSession} object Spark411SQLHelper { - def toArrowSchema(schema: StructType, timeZoneId: String): Schema = { - ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames = true, largeVarTypes = false) + def toArrowSchema(schema: StructType, timeZoneId: String, largeVarTypes: Boolean = false): Schema = { + ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames = true, largeVarTypes = largeVarTypes) } def toArrowBatchRdd(df: DataFrame): org.apache.spark.rdd.RDD[Array[Byte]] = { val conf = df.sparkSession.asInstanceOf[ClassicSparkSession].sessionState.conf val timeZoneId = conf.sessionLocalTimeZone val maxRecordsPerBatch = conf.getConf(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) + val largeVarTypes = conf.arrowUseLargeVarTypes val schema = df.schema df.queryExecution.toRdd.mapPartitions(iter => { val context = TaskContext.get() @@ -44,7 +45,7 @@ object Spark411SQLHelper { maxRecordsPerBatch, timeZoneId, true, // errorOnDuplicatedFieldNames - false, // largeVarTypes + largeVarTypes, context) }) } @@ -62,8 +63,9 @@ object Spark411SQLHelper { val structType = DataType.fromJson(schema).asInstanceOf[StructType] val classicSession = session.asInstanceOf[ClassicSparkSession] - // Capture timezone on driver side - cannot access sessionState on executors + // Capture timezone and largeVarTypes on driver side - cannot access sessionState on executors val timeZoneId = classicSession.sessionState.conf.sessionLocalTimeZone + val largeVarTypes = classicSession.sessionState.conf.arrowUseLargeVarTypes // Create an RDD of InternalRow by deserializing Arrow batches per partition val rowRdd = rdd.rdd.flatMap { arrowBatch => @@ -72,7 +74,7 @@ object Spark411SQLHelper { structType, timeZoneId, // Use captured value, not sessionState true, // errorOnDuplicatedFieldNames - false, // largeVarTypes + largeVarTypes, TaskContext.get() ) } diff --git a/python/raydp/spark/ray_cluster.py b/python/raydp/spark/ray_cluster.py index 10816d25..4e712d5b 100644 --- a/python/raydp/spark/ray_cluster.py +++ b/python/raydp/spark/ray_cluster.py @@ -119,6 +119,12 @@ def _prepare_spark_configs(self): self._configs["spark.executor.instances"] = str(self._num_executors) self._configs["spark.executor.cores"] = str(self._executor_cores) self._configs["spark.executor.memory"] = str(self._executor_memory) + # Enable Arrow IPC zstd compression by default (Spark 4.1+). + # Users can override by passing their own value for this config. + arrow_codec_key = "spark.sql.execution.arrow.compression.codec" + if arrow_codec_key not in self._configs: + self._configs[arrow_codec_key] = "zstd" + if platform.system() != "Darwin": driver_node_ip = ray.util.get_node_ip_address() if "spark.driver.host" not in self._configs: From ae7aeff16b0f64f6475ae9508380e0538d3bab9a Mon Sep 17 00:00:00 2001 From: Rexwell Minnis Date: Sat, 7 Feb 2026 14:37:14 -0500 Subject: [PATCH 05/16] Arrow performance optimizations and bump Ray to 2.47.1 - Bump Ray compile-time dependency from 2.34.0 to 2.47.1 (latest on Maven Central) --- core/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/pom.xml b/core/pom.xml index f97c32e1..0eca3538 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -13,7 +13,7 @@ https://github.com/ray-project/raydp.git - 2.34.0 + 2.47.1 4.1.1 3.2.2 3.3.0 From 2b4d26467388b7dd85d65d57d6b137aa6f3b9d0d Mon Sep 17 00:00:00 2001 From: Rexwell Minnis Date: Sat, 7 Feb 2026 14:53:14 -0500 Subject: [PATCH 06/16] Remove legacy Spark 3.x shim modules (spark322/330/340/350) RayDP 2.0 only targets Spark 4.1 These modules were already excluded from the build but cluttered the source tree. --- .gitignore | 2 + core/pom.xml | 4 - core/shims/spark322/pom.xml | 132 ------------------ .../com.intel.raydp.shims.SparkShimProvider | 1 - .../intel/raydp/shims/SparkShimProvider.scala | 46 ------ .../com/intel/raydp/shims/SparkShims.scala | 52 ------- .../org/apache/spark/TaskContextUtils.scala | 30 ---- .../RayDPSpark322ExecutorBackendFactory.scala | 53 ------- .../org/apache/spark/sql/SparkSqlUtils.scala | 35 ----- core/shims/spark330/pom.xml | 128 ----------------- .../com.intel.raydp.shims.SparkShimProvider | 1 - .../intel/raydp/shims/SparkShimProvider.scala | 40 ------ .../com/intel/raydp/shims/SparkShims.scala | 52 ------- .../org/apache/spark/TaskContextUtils.scala | 30 ---- .../RayCoarseGrainedExecutorBackend.scala | 50 ------- .../RayDPSpark330ExecutorBackendFactory.scala | 52 ------- .../org/apache/spark/sql/SparkSqlUtils.scala | 35 ----- core/shims/spark340/pom.xml | 99 ------------- .../com.intel.raydp.shims.SparkShimProvider | 1 - .../intel/raydp/shims/SparkShimProvider.scala | 41 ------ .../com/intel/raydp/shims/SparkShims.scala | 52 ------- .../org/apache/spark/TaskContextUtils.scala | 30 ---- .../RayCoarseGrainedExecutorBackend.scala | 50 ------- .../RayDPSpark340ExecutorBackendFactory.scala | 52 ------- .../org/apache/spark/sql/SparkSqlUtils.scala | 45 ------ core/shims/spark350/pom.xml | 99 ------------- .../com.intel.raydp.shims.SparkShimProvider | 1 - .../intel/raydp/shims/SparkShimProvider.scala | 37 ----- .../com/intel/raydp/shims/SparkShims.scala | 52 ------- .../org/apache/spark/TaskContextUtils.scala | 30 ---- .../RayCoarseGrainedExecutorBackend.scala | 50 ------- .../RayDPSpark350ExecutorBackendFactory.scala | 52 ------- .../org/apache/spark/sql/SparkSqlUtils.scala | 45 ------ 33 files changed, 2 insertions(+), 1477 deletions(-) delete mode 100644 core/shims/spark322/pom.xml delete mode 100644 core/shims/spark322/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider delete mode 100644 core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala delete mode 100644 core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala delete mode 100644 core/shims/spark322/src/main/scala/org/apache/spark/TaskContextUtils.scala delete mode 100644 core/shims/spark322/src/main/scala/org/apache/spark/executor/RayDPSpark322ExecutorBackendFactory.scala delete mode 100644 core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala delete mode 100644 core/shims/spark330/pom.xml delete mode 100644 core/shims/spark330/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider delete mode 100644 core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala delete mode 100644 core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala delete mode 100644 core/shims/spark330/src/main/scala/org/apache/spark/TaskContextUtils.scala delete mode 100644 core/shims/spark330/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala delete mode 100644 core/shims/spark330/src/main/scala/org/apache/spark/executor/RayDPSpark330ExecutorBackendFactory.scala delete mode 100644 core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala delete mode 100644 core/shims/spark340/pom.xml delete mode 100644 core/shims/spark340/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider delete mode 100644 core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala delete mode 100644 core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala delete mode 100644 core/shims/spark340/src/main/scala/org/apache/spark/TaskContextUtils.scala delete mode 100644 core/shims/spark340/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala delete mode 100644 core/shims/spark340/src/main/scala/org/apache/spark/executor/RayDPSpark340ExecutorBackendFactory.scala delete mode 100644 core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala delete mode 100644 core/shims/spark350/pom.xml delete mode 100644 core/shims/spark350/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider delete mode 100644 core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala delete mode 100644 core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala delete mode 100644 core/shims/spark350/src/main/scala/org/apache/spark/TaskContextUtils.scala delete mode 100644 core/shims/spark350/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala delete mode 100644 core/shims/spark350/src/main/scala/org/apache/spark/executor/RayDPSpark350ExecutorBackendFactory.scala delete mode 100644 core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala diff --git a/.gitignore b/.gitignore index 571df90a..11145713 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,5 @@ _SUCCESS .metals/ .bloop/ + +.venv/ \ No newline at end of file diff --git a/core/pom.xml b/core/pom.xml index 0eca3538..e0e279e8 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -15,10 +15,6 @@ 2.47.1 4.1.1 - 3.2.2 - 3.3.0 - 3.4.0 - 3.5.0 1.1.10.5 4.1.108.Final 1.12.0 diff --git a/core/shims/spark322/pom.xml b/core/shims/spark322/pom.xml deleted file mode 100644 index faff6ac5..00000000 --- a/core/shims/spark322/pom.xml +++ /dev/null @@ -1,132 +0,0 @@ - - - - 4.0.0 - - - com.intel - raydp-shims - 1.7.0-SNAPSHOT - ../pom.xml - - - raydp-shims-spark322 - RayDP Shims for Spark 3.2.2 - jar - - - 2.12.15 - 2.13.5 - - - - - - org.scalastyle - scalastyle-maven-plugin - - - net.alchim31.maven - scala-maven-plugin - 3.2.2 - - - scala-compile-first - process-resources - - compile - - - - scala-test-compile-first - process-test-resources - - testCompile - - - - - - - - - src/main/resources - - - - - - - com.intel - raydp-shims-common - ${project.version} - compile - - - org.apache.spark - spark-sql_${scala.binary.version} - ${spark322.version} - provided - - - com.google.protobuf - protobuf-java - - - - - org.apache.spark - spark-core_${scala.binary.version} - ${spark322.version} - provided - - - org.xerial.snappy - snappy-java - - - org.apache.commons - commons-compress - - - org.apache.commons - commons-text - - - org.apache.ivy - ivy - - - log4j - log4j - - - - - org.xerial.snappy - snappy-java - ${snappy.version} - - - org.apache.commons - commons-compress - ${commons.compress.version} - - - org.apache.commons - commons-text - ${commons.text.version} - - - org.apache.ivy - ivy - ${ivy.version} - - - com.google.protobuf - protobuf-java - ${protobuf.version} - - - diff --git a/core/shims/spark322/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider b/core/shims/spark322/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider deleted file mode 100644 index 0ce0b134..00000000 --- a/core/shims/spark322/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider +++ /dev/null @@ -1 +0,0 @@ -com.intel.raydp.shims.spark322.SparkShimProvider diff --git a/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala b/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala deleted file mode 100644 index db53a74a..00000000 --- a/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.intel.raydp.shims.spark322 - -import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor} - -object SparkShimProvider { - val SPARK311_DESCRIPTOR = SparkShimDescriptor(3, 1, 1) - val SPARK312_DESCRIPTOR = SparkShimDescriptor(3, 1, 2) - val SPARK313_DESCRIPTOR = SparkShimDescriptor(3, 1, 3) - val SPARK320_DESCRIPTOR = SparkShimDescriptor(3, 2, 0) - val SPARK321_DESCRIPTOR = SparkShimDescriptor(3, 2, 1) - val SPARK322_DESCRIPTOR = SparkShimDescriptor(3, 2, 2) - val SPARK323_DESCRIPTOR = SparkShimDescriptor(3, 2, 3) - val SPARK324_DESCRIPTOR = SparkShimDescriptor(3, 2, 4) - val DESCRIPTOR_STRINGS = - Seq(s"$SPARK311_DESCRIPTOR", s"$SPARK312_DESCRIPTOR" ,s"$SPARK313_DESCRIPTOR", - s"$SPARK320_DESCRIPTOR", s"$SPARK321_DESCRIPTOR", s"$SPARK322_DESCRIPTOR", - s"$SPARK323_DESCRIPTOR", s"$SPARK324_DESCRIPTOR") - val DESCRIPTOR = SPARK323_DESCRIPTOR -} - -class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider { - def createShim: SparkShims = { - new Spark322Shims() - } - - def matches(version: String): Boolean = { - SparkShimProvider.DESCRIPTOR_STRINGS.contains(version) - } -} diff --git a/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala deleted file mode 100644 index 6ea817db..00000000 --- a/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.intel.raydp.shims.spark322 - -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.executor.RayDPExecutorBackendFactory -import org.apache.spark.executor.spark322._ -import org.apache.spark.spark322.TaskContextUtils -import org.apache.spark.sql.{DataFrame, SparkSession} -import org.apache.spark.sql.spark322.SparkSqlUtils -import com.intel.raydp.shims.{ShimDescriptor, SparkShims} -import org.apache.arrow.vector.types.pojo.Schema -import org.apache.spark.sql.types.StructType - -class Spark322Shims extends SparkShims { - override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR - - override def toDataFrame( - rdd: JavaRDD[Array[Byte]], - schema: String, - session: SparkSession): DataFrame = { - SparkSqlUtils.toDataFrame(rdd, schema, session) - } - - override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = { - new RayDPSpark322ExecutorBackendFactory() - } - - override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { - TaskContextUtils.getDummyTaskContext(partitionId, env) - } - - override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { - SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) - } -} diff --git a/core/shims/spark322/src/main/scala/org/apache/spark/TaskContextUtils.scala b/core/shims/spark322/src/main/scala/org/apache/spark/TaskContextUtils.scala deleted file mode 100644 index d658cf98..00000000 --- a/core/shims/spark322/src/main/scala/org/apache/spark/TaskContextUtils.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.spark322 - -import java.util.Properties - -import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} -import org.apache.spark.memory.TaskMemoryManager - -object TaskContextUtils { - def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { - new TaskContextImpl(0, 0, partitionId, -1024, 0, - new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem) - } -} diff --git a/core/shims/spark322/src/main/scala/org/apache/spark/executor/RayDPSpark322ExecutorBackendFactory.scala b/core/shims/spark322/src/main/scala/org/apache/spark/executor/RayDPSpark322ExecutorBackendFactory.scala deleted file mode 100644 index d8673679..00000000 --- a/core/shims/spark322/src/main/scala/org/apache/spark/executor/RayDPSpark322ExecutorBackendFactory.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.executor.spark322 - -import java.net.URL - -import org.apache.spark.SparkEnv -import org.apache.spark.executor.CoarseGrainedExecutorBackend -import org.apache.spark.executor.RayDPExecutorBackendFactory -import org.apache.spark.resource.ResourceProfile -import org.apache.spark.rpc.RpcEnv - -class RayDPSpark322ExecutorBackendFactory - extends RayDPExecutorBackendFactory { - override def createExecutorBackend( - rpcEnv: RpcEnv, - driverUrl: String, - executorId: String, - bindAddress: String, - hostname: String, - cores: Int, - userClassPath: Seq[URL], - env: SparkEnv, - resourcesFileOpt: Option[String], - resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = { - new CoarseGrainedExecutorBackend( - rpcEnv, - driverUrl, - executorId, - bindAddress, - hostname, - cores, - userClassPath, - env, - resourcesFileOpt, - resourceProfile) - } -} diff --git a/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala deleted file mode 100644 index be9b409c..00000000 --- a/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.spark322 - -import org.apache.arrow.vector.types.pojo.Schema -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession} -import org.apache.spark.sql.execution.arrow.ArrowConverters -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.ArrowUtils - -object SparkSqlUtils { - def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = { - ArrowConverters.toDataFrame(rdd, schema, new SQLContext(session)) - } - - def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { - ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) - } -} diff --git a/core/shims/spark330/pom.xml b/core/shims/spark330/pom.xml deleted file mode 100644 index 4443f658..00000000 --- a/core/shims/spark330/pom.xml +++ /dev/null @@ -1,128 +0,0 @@ - - - - 4.0.0 - - - com.intel - raydp-shims - 1.7.0-SNAPSHOT - ../pom.xml - - - raydp-shims-spark330 - RayDP Shims for Spark 3.3.0 - jar - - - 2.12.15 - 2.13.5 - - - - - - org.scalastyle - scalastyle-maven-plugin - - - net.alchim31.maven - scala-maven-plugin - 3.2.2 - - - scala-compile-first - process-resources - - compile - - - - scala-test-compile-first - process-test-resources - - testCompile - - - - - - - - - src/main/resources - - - - - - - com.intel - raydp-shims-common - ${project.version} - compile - - - org.apache.spark - spark-sql_${scala.binary.version} - ${spark330.version} - provided - - - com.google.protobuf - protobuf-java - - - - - org.apache.spark - spark-core_${scala.binary.version} - ${spark330.version} - provided - - - org.xerial.snappy - snappy-java - - - io.netty - netty-handler - - - org.apache.commons - commons-text - - - org.apache.ivy - ivy - - - - - org.xerial.snappy - snappy-java - ${snappy.version} - - - io.netty - netty-handler - ${netty.version} - - - org.apache.commons - commons-text - ${commons.text.version} - - - org.apache.ivy - ivy - ${ivy.version} - - - com.google.protobuf - protobuf-java - ${protobuf.version} - - - diff --git a/core/shims/spark330/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider b/core/shims/spark330/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider deleted file mode 100644 index 184e2dfa..00000000 --- a/core/shims/spark330/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider +++ /dev/null @@ -1 +0,0 @@ -com.intel.raydp.shims.spark330.SparkShimProvider diff --git a/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala b/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala deleted file mode 100644 index 7a5f5481..00000000 --- a/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.intel.raydp.shims.spark330 - -import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor} - -object SparkShimProvider { - val SPARK330_DESCRIPTOR = SparkShimDescriptor(3, 3, 0) - val SPARK331_DESCRIPTOR = SparkShimDescriptor(3, 3, 1) - val SPARK332_DESCRIPTOR = SparkShimDescriptor(3, 3, 2) - val SPARK333_DESCRIPTOR = SparkShimDescriptor(3, 3, 3) - val DESCRIPTOR_STRINGS = Seq(s"$SPARK330_DESCRIPTOR", s"$SPARK331_DESCRIPTOR", - s"$SPARK332_DESCRIPTOR", s"$SPARK333_DESCRIPTOR") - val DESCRIPTOR = SPARK332_DESCRIPTOR -} - -class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider { - def createShim: SparkShims = { - new Spark330Shims() - } - - def matches(version: String): Boolean = { - SparkShimProvider.DESCRIPTOR_STRINGS.contains(version) - } -} diff --git a/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala deleted file mode 100644 index 4f1a50b5..00000000 --- a/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.intel.raydp.shims.spark330 - -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.executor.RayDPExecutorBackendFactory -import org.apache.spark.executor.spark330._ -import org.apache.spark.spark330.TaskContextUtils -import org.apache.spark.sql.{DataFrame, SparkSession} -import org.apache.spark.sql.spark330.SparkSqlUtils -import com.intel.raydp.shims.{ShimDescriptor, SparkShims} -import org.apache.arrow.vector.types.pojo.Schema -import org.apache.spark.sql.types.StructType - -class Spark330Shims extends SparkShims { - override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR - - override def toDataFrame( - rdd: JavaRDD[Array[Byte]], - schema: String, - session: SparkSession): DataFrame = { - SparkSqlUtils.toDataFrame(rdd, schema, session) - } - - override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = { - new RayDPSpark330ExecutorBackendFactory() - } - - override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { - TaskContextUtils.getDummyTaskContext(partitionId, env) - } - - override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { - SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) - } -} diff --git a/core/shims/spark330/src/main/scala/org/apache/spark/TaskContextUtils.scala b/core/shims/spark330/src/main/scala/org/apache/spark/TaskContextUtils.scala deleted file mode 100644 index 431167f4..00000000 --- a/core/shims/spark330/src/main/scala/org/apache/spark/TaskContextUtils.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.spark330 - -import java.util.Properties - -import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} -import org.apache.spark.memory.TaskMemoryManager - -object TaskContextUtils { - def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { - new TaskContextImpl(0, 0, partitionId, -1024, 0, - new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem) - } -} diff --git a/core/shims/spark330/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala b/core/shims/spark330/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala deleted file mode 100644 index 2e6b5e25..00000000 --- a/core/shims/spark330/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.executor - -import java.net.URL - -import org.apache.spark.SparkEnv -import org.apache.spark.resource.ResourceProfile -import org.apache.spark.rpc.RpcEnv - -class RayCoarseGrainedExecutorBackend( - rpcEnv: RpcEnv, - driverUrl: String, - executorId: String, - bindAddress: String, - hostname: String, - cores: Int, - userClassPath: Seq[URL], - env: SparkEnv, - resourcesFileOpt: Option[String], - resourceProfile: ResourceProfile) - extends CoarseGrainedExecutorBackend( - rpcEnv, - driverUrl, - executorId, - bindAddress, - hostname, - cores, - env, - resourcesFileOpt, - resourceProfile) { - - override def getUserClassPath: Seq[URL] = userClassPath - -} diff --git a/core/shims/spark330/src/main/scala/org/apache/spark/executor/RayDPSpark330ExecutorBackendFactory.scala b/core/shims/spark330/src/main/scala/org/apache/spark/executor/RayDPSpark330ExecutorBackendFactory.scala deleted file mode 100644 index 7f01e979..00000000 --- a/core/shims/spark330/src/main/scala/org/apache/spark/executor/RayDPSpark330ExecutorBackendFactory.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.executor.spark330 - -import java.net.URL - -import org.apache.spark.SparkEnv -import org.apache.spark.executor._ -import org.apache.spark.resource.ResourceProfile -import org.apache.spark.rpc.RpcEnv - -class RayDPSpark330ExecutorBackendFactory - extends RayDPExecutorBackendFactory { - override def createExecutorBackend( - rpcEnv: RpcEnv, - driverUrl: String, - executorId: String, - bindAddress: String, - hostname: String, - cores: Int, - userClassPath: Seq[URL], - env: SparkEnv, - resourcesFileOpt: Option[String], - resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = { - new RayCoarseGrainedExecutorBackend( - rpcEnv, - driverUrl, - executorId, - bindAddress, - hostname, - cores, - userClassPath, - env, - resourcesFileOpt, - resourceProfile) - } -} diff --git a/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala deleted file mode 100644 index 162371ad..00000000 --- a/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.spark330 - -import org.apache.arrow.vector.types.pojo.Schema -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession} -import org.apache.spark.sql.execution.arrow.ArrowConverters -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.ArrowUtils - -object SparkSqlUtils { - def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = { - ArrowConverters.toDataFrame(rdd, schema, session) - } - - def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { - ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) - } -} diff --git a/core/shims/spark340/pom.xml b/core/shims/spark340/pom.xml deleted file mode 100644 index 1b312747..00000000 --- a/core/shims/spark340/pom.xml +++ /dev/null @@ -1,99 +0,0 @@ - - - - 4.0.0 - - - com.intel - raydp-shims - 1.7.0-SNAPSHOT - ../pom.xml - - - raydp-shims-spark340 - RayDP Shims for Spark 3.4.0 - jar - - - 2.12.15 - 2.13.5 - - - - - - org.scalastyle - scalastyle-maven-plugin - - - net.alchim31.maven - scala-maven-plugin - 3.2.2 - - - scala-compile-first - process-resources - - compile - - - - scala-test-compile-first - process-test-resources - - testCompile - - - - - - - - - src/main/resources - - - - - - - com.intel - raydp-shims-common - ${project.version} - compile - - - org.apache.spark - spark-sql_${scala.binary.version} - ${spark340.version} - provided - - - org.apache.spark - spark-core_${scala.binary.version} - ${spark340.version} - provided - - - org.xerial.snappy - snappy-java - - - io.netty - netty-handler - - - - - org.xerial.snappy - snappy-java - ${snappy.version} - - - io.netty - netty-handler - ${netty.version} - - - diff --git a/core/shims/spark340/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider b/core/shims/spark340/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider deleted file mode 100644 index 515c47a6..00000000 --- a/core/shims/spark340/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider +++ /dev/null @@ -1 +0,0 @@ -com.intel.raydp.shims.spark340.SparkShimProvider diff --git a/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala b/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala deleted file mode 100644 index 229263e2..00000000 --- a/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.intel.raydp.shims.spark340 - -import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor} - -object SparkShimProvider { - val SPARK340_DESCRIPTOR = SparkShimDescriptor(3, 4, 0) - val SPARK341_DESCRIPTOR = SparkShimDescriptor(3, 4, 1) - val SPARK342_DESCRIPTOR = SparkShimDescriptor(3, 4, 2) - val SPARK343_DESCRIPTOR = SparkShimDescriptor(3, 4, 3) - val SPARK344_DESCRIPTOR = SparkShimDescriptor(3, 4, 4) - val DESCRIPTOR_STRINGS = Seq(s"$SPARK340_DESCRIPTOR", s"$SPARK341_DESCRIPTOR", s"$SPARK342_DESCRIPTOR", - s"$SPARK343_DESCRIPTOR", s"$SPARK344_DESCRIPTOR") - val DESCRIPTOR = SPARK341_DESCRIPTOR -} - -class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider { - def createShim: SparkShims = { - new Spark340Shims() - } - - def matches(version: String): Boolean = { - SparkShimProvider.DESCRIPTOR_STRINGS.contains(version) - } -} diff --git a/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala deleted file mode 100644 index c444373f..00000000 --- a/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.intel.raydp.shims.spark340 - -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.executor.RayDPExecutorBackendFactory -import org.apache.spark.executor.spark340._ -import org.apache.spark.spark340.TaskContextUtils -import org.apache.spark.sql.{DataFrame, SparkSession} -import org.apache.spark.sql.spark340.SparkSqlUtils -import com.intel.raydp.shims.{ShimDescriptor, SparkShims} -import org.apache.arrow.vector.types.pojo.Schema -import org.apache.spark.sql.types.StructType - -class Spark340Shims extends SparkShims { - override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR - - override def toDataFrame( - rdd: JavaRDD[Array[Byte]], - schema: String, - session: SparkSession): DataFrame = { - SparkSqlUtils.toDataFrame(rdd, schema, session) - } - - override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = { - new RayDPSpark340ExecutorBackendFactory() - } - - override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { - TaskContextUtils.getDummyTaskContext(partitionId, env) - } - - override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { - SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) - } -} diff --git a/core/shims/spark340/src/main/scala/org/apache/spark/TaskContextUtils.scala b/core/shims/spark340/src/main/scala/org/apache/spark/TaskContextUtils.scala deleted file mode 100644 index 780920da..00000000 --- a/core/shims/spark340/src/main/scala/org/apache/spark/TaskContextUtils.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.spark340 - -import java.util.Properties - -import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} -import org.apache.spark.memory.TaskMemoryManager - -object TaskContextUtils { - def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { - new TaskContextImpl(0, 0, partitionId, -1024, 0, 0, - new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem) - } -} diff --git a/core/shims/spark340/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala b/core/shims/spark340/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala deleted file mode 100644 index 2e6b5e25..00000000 --- a/core/shims/spark340/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.executor - -import java.net.URL - -import org.apache.spark.SparkEnv -import org.apache.spark.resource.ResourceProfile -import org.apache.spark.rpc.RpcEnv - -class RayCoarseGrainedExecutorBackend( - rpcEnv: RpcEnv, - driverUrl: String, - executorId: String, - bindAddress: String, - hostname: String, - cores: Int, - userClassPath: Seq[URL], - env: SparkEnv, - resourcesFileOpt: Option[String], - resourceProfile: ResourceProfile) - extends CoarseGrainedExecutorBackend( - rpcEnv, - driverUrl, - executorId, - bindAddress, - hostname, - cores, - env, - resourcesFileOpt, - resourceProfile) { - - override def getUserClassPath: Seq[URL] = userClassPath - -} diff --git a/core/shims/spark340/src/main/scala/org/apache/spark/executor/RayDPSpark340ExecutorBackendFactory.scala b/core/shims/spark340/src/main/scala/org/apache/spark/executor/RayDPSpark340ExecutorBackendFactory.scala deleted file mode 100644 index 72abdc52..00000000 --- a/core/shims/spark340/src/main/scala/org/apache/spark/executor/RayDPSpark340ExecutorBackendFactory.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.executor.spark340 - -import java.net.URL - -import org.apache.spark.SparkEnv -import org.apache.spark.executor._ -import org.apache.spark.resource.ResourceProfile -import org.apache.spark.rpc.RpcEnv - -class RayDPSpark340ExecutorBackendFactory - extends RayDPExecutorBackendFactory { - override def createExecutorBackend( - rpcEnv: RpcEnv, - driverUrl: String, - executorId: String, - bindAddress: String, - hostname: String, - cores: Int, - userClassPath: Seq[URL], - env: SparkEnv, - resourcesFileOpt: Option[String], - resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = { - new RayCoarseGrainedExecutorBackend( - rpcEnv, - driverUrl, - executorId, - bindAddress, - hostname, - cores, - userClassPath, - env, - resourcesFileOpt, - resourceProfile) - } -} diff --git a/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala deleted file mode 100644 index eb52d8e7..00000000 --- a/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.spark340 - -import org.apache.arrow.vector.types.pojo.Schema -import org.apache.spark.TaskContext -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession} -import org.apache.spark.sql.execution.arrow.ArrowConverters -import org.apache.spark.sql.types._ -import org.apache.spark.sql.util.ArrowUtils - -object SparkSqlUtils { - def toDataFrame( - arrowBatchRDD: JavaRDD[Array[Byte]], - schemaString: String, - session: SparkSession): DataFrame = { - val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] - val timeZoneId = session.sessionState.conf.sessionLocalTimeZone - val rdd = arrowBatchRDD.rdd.mapPartitions { iter => - val context = TaskContext.get() - ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) - } - session.internalCreateDataFrame(rdd.setName("arrow"), schema) - } - - def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { - ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) - } -} diff --git a/core/shims/spark350/pom.xml b/core/shims/spark350/pom.xml deleted file mode 100644 index 2368daa2..00000000 --- a/core/shims/spark350/pom.xml +++ /dev/null @@ -1,99 +0,0 @@ - - - - 4.0.0 - - - com.intel - raydp-shims - 1.7.0-SNAPSHOT - ../pom.xml - - - raydp-shims-spark350 - RayDP Shims for Spark 3.5.0 - jar - - - 2.12.15 - 2.13.5 - - - - - - org.scalastyle - scalastyle-maven-plugin - - - net.alchim31.maven - scala-maven-plugin - 3.2.2 - - - scala-compile-first - process-resources - - compile - - - - scala-test-compile-first - process-test-resources - - testCompile - - - - - - - - - src/main/resources - - - - - - - com.intel - raydp-shims-common - ${project.version} - compile - - - org.apache.spark - spark-sql_${scala.binary.version} - ${spark350.version} - provided - - - org.apache.spark - spark-core_${scala.binary.version} - ${spark350.version} - provided - - - org.xerial.snappy - snappy-java - - - io.netty - netty-handler - - - - - org.xerial.snappy - snappy-java - ${snappy.version} - - - io.netty - netty-handler - ${netty.version} - - - diff --git a/core/shims/spark350/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider b/core/shims/spark350/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider deleted file mode 100644 index 6e5a394e..00000000 --- a/core/shims/spark350/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider +++ /dev/null @@ -1 +0,0 @@ -com.intel.raydp.shims.spark350.SparkShimProvider diff --git a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala b/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala deleted file mode 100644 index 4a260e6f..00000000 --- a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.intel.raydp.shims.spark350 - -import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor} - -object SparkShimProvider { - private val SUPPORTED_PATCHES = 0 to 7 - val DESCRIPTORS = SUPPORTED_PATCHES.map(p => SparkShimDescriptor(3, 5, p)) - val DESCRIPTOR_STRINGS = DESCRIPTORS.map(_.toString) - val DESCRIPTOR = DESCRIPTORS.head -} - -class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider { - def createShim: SparkShims = { - new Spark350Shims() - } - - def matches(version: String): Boolean = { - SparkShimProvider.DESCRIPTOR_STRINGS.contains(version) - } -} diff --git a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala deleted file mode 100644 index 721d6923..00000000 --- a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.intel.raydp.shims.spark350 - -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.executor.RayDPExecutorBackendFactory -import org.apache.spark.executor.spark350._ -import org.apache.spark.spark350.TaskContextUtils -import org.apache.spark.sql.{DataFrame, SparkSession} -import org.apache.spark.sql.spark350.SparkSqlUtils -import com.intel.raydp.shims.{ShimDescriptor, SparkShims} -import org.apache.arrow.vector.types.pojo.Schema -import org.apache.spark.sql.types.StructType - -class Spark350Shims extends SparkShims { - override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR - - override def toDataFrame( - rdd: JavaRDD[Array[Byte]], - schema: String, - session: SparkSession): DataFrame = { - SparkSqlUtils.toDataFrame(rdd, schema, session) - } - - override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = { - new RayDPSpark350ExecutorBackendFactory() - } - - override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { - TaskContextUtils.getDummyTaskContext(partitionId, env) - } - - override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { - SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) - } -} diff --git a/core/shims/spark350/src/main/scala/org/apache/spark/TaskContextUtils.scala b/core/shims/spark350/src/main/scala/org/apache/spark/TaskContextUtils.scala deleted file mode 100644 index 0f38bbb9..00000000 --- a/core/shims/spark350/src/main/scala/org/apache/spark/TaskContextUtils.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.spark350 - -import java.util.Properties - -import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} -import org.apache.spark.memory.TaskMemoryManager - -object TaskContextUtils { - def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { - new TaskContextImpl(0, 0, partitionId, -1024, 0, 0, - new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem) - } -} diff --git a/core/shims/spark350/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala b/core/shims/spark350/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala deleted file mode 100644 index 2e6b5e25..00000000 --- a/core/shims/spark350/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.executor - -import java.net.URL - -import org.apache.spark.SparkEnv -import org.apache.spark.resource.ResourceProfile -import org.apache.spark.rpc.RpcEnv - -class RayCoarseGrainedExecutorBackend( - rpcEnv: RpcEnv, - driverUrl: String, - executorId: String, - bindAddress: String, - hostname: String, - cores: Int, - userClassPath: Seq[URL], - env: SparkEnv, - resourcesFileOpt: Option[String], - resourceProfile: ResourceProfile) - extends CoarseGrainedExecutorBackend( - rpcEnv, - driverUrl, - executorId, - bindAddress, - hostname, - cores, - env, - resourcesFileOpt, - resourceProfile) { - - override def getUserClassPath: Seq[URL] = userClassPath - -} diff --git a/core/shims/spark350/src/main/scala/org/apache/spark/executor/RayDPSpark350ExecutorBackendFactory.scala b/core/shims/spark350/src/main/scala/org/apache/spark/executor/RayDPSpark350ExecutorBackendFactory.scala deleted file mode 100644 index 54d53d7d..00000000 --- a/core/shims/spark350/src/main/scala/org/apache/spark/executor/RayDPSpark350ExecutorBackendFactory.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.executor.spark350 - -import java.net.URL - -import org.apache.spark.SparkEnv -import org.apache.spark.executor._ -import org.apache.spark.resource.ResourceProfile -import org.apache.spark.rpc.RpcEnv - -class RayDPSpark350ExecutorBackendFactory - extends RayDPExecutorBackendFactory { - override def createExecutorBackend( - rpcEnv: RpcEnv, - driverUrl: String, - executorId: String, - bindAddress: String, - hostname: String, - cores: Int, - userClassPath: Seq[URL], - env: SparkEnv, - resourcesFileOpt: Option[String], - resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = { - new RayCoarseGrainedExecutorBackend( - rpcEnv, - driverUrl, - executorId, - bindAddress, - hostname, - cores, - userClassPath, - env, - resourcesFileOpt, - resourceProfile) - } -} diff --git a/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala deleted file mode 100644 index dfd063f7..00000000 --- a/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.spark350 - -import org.apache.arrow.vector.types.pojo.Schema -import org.apache.spark.TaskContext -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession} -import org.apache.spark.sql.execution.arrow.ArrowConverters -import org.apache.spark.sql.types._ -import org.apache.spark.sql.util.ArrowUtils - -object SparkSqlUtils { - def toDataFrame( - arrowBatchRDD: JavaRDD[Array[Byte]], - schemaString: String, - session: SparkSession): DataFrame = { - val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] - val timeZoneId = session.sessionState.conf.sessionLocalTimeZone - val rdd = arrowBatchRDD.rdd.mapPartitions { iter => - val context = TaskContext.get() - ArrowConverters.fromBatchIterator(iter, schema, timeZoneId,false, context) - } - session.internalCreateDataFrame(rdd.setName("arrow"), schema) - } - - def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { - ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId, errorOnDuplicatedFieldNames = false) - } -} From 105abeae1333e6aaa6fa97a3334697c2f05732f2 Mon Sep 17 00:00:00 2001 From: Rexwell Minnis Date: Sat, 7 Feb 2026 16:18:36 -0500 Subject: [PATCH 07/16] Remove deprecated fromSparkRDD pipeline and dead bytes code path The old fromSparkRDD/RayDatasetRDD pipeline produced bytes blocks and was already @deprecated with zero callers. Remove it along with all supporting dead code: _convert_by_rdd (Python), RayDatasetRDD.scala, ObjectRefHolder, ObjectStoreWriter instance class, and the wasteful ray.get(blocks[0]) type-check in ray_dataset_to_spark_dataframe. --- .../org/apache/spark/rdd/RayDatasetRDD.scala | 57 ------- .../spark/sql/raydp/ObjectStoreReader.scala | 27 +--- .../spark/sql/raydp/ObjectStoreWriter.scala | 144 +----------------- python/raydp/spark/dataset.py | 23 +-- 4 files changed, 5 insertions(+), 246 deletions(-) delete mode 100644 core/raydp-main/src/main/scala/org/apache/spark/rdd/RayDatasetRDD.scala diff --git a/core/raydp-main/src/main/scala/org/apache/spark/rdd/RayDatasetRDD.scala b/core/raydp-main/src/main/scala/org/apache/spark/rdd/RayDatasetRDD.scala deleted file mode 100644 index 1992b9a3..00000000 --- a/core/raydp-main/src/main/scala/org/apache/spark/rdd/RayDatasetRDD.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rdd - -import java.util.List; - -import scala.collection.JavaConverters._ - -import io.ray.runtime.generated.Common.Address - -import org.apache.spark.{Partition, SparkContext, TaskContext} -import org.apache.spark.api.java.JavaSparkContext -import org.apache.spark.raydp.RayDPUtils -import org.apache.spark.sql.raydp.ObjectStoreReader - -private[spark] class RayDatasetRDDPartition(val ref: Array[Byte], idx: Int) extends Partition { - val index = idx -} - -private[spark] -class RayDatasetRDD( - jsc: JavaSparkContext, - @transient val objectIds: List[Array[Byte]], - locations: List[Array[Byte]]) - extends RDD[Array[Byte]](jsc.sc, Nil) { - - override def getPartitions: Array[Partition] = { - objectIds.asScala.zipWithIndex.map { case (k, i) => - new RayDatasetRDDPartition(k, i).asInstanceOf[Partition] - }.toArray - } - - override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val ref = split.asInstanceOf[RayDatasetRDDPartition].ref - ObjectStoreReader.getBatchesFromStream(ref, locations.get(split.index)) - } - - override def getPreferredLocations(split: Partition): Seq[String] = { - val address = Address.parseFrom(locations.get(split.index)) - Seq(address.getIpAddress()) - } -} diff --git a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala index 39a1c2b1..a88042c2 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala @@ -17,18 +17,10 @@ package org.apache.spark.sql.raydp -import java.io.ByteArrayInputStream -import java.nio.channels.{Channels, ReadableByteChannel} import java.util.List -import com.intel.raydp.shims.SparkShimLoader - -import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} -import org.apache.spark.raydp.RayDPUtils -import org.apache.spark.rdd.{RayDatasetRDD, RayObjectRefRDD} -import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext} -import org.apache.spark.sql.catalyst.expressions.GenericRow -import org.apache.spark.sql.execution.arrow.ArrowConverters +import org.apache.spark.rdd.RayObjectRefRDD +import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.types.{IntegerType, StructType} object ObjectStoreReader { @@ -39,19 +31,4 @@ object ObjectStoreReader { val schema = new StructType().add("idx", IntegerType) spark.createDataFrame(rdd, schema) } - - def RayDatasetToDataFrame( - sparkSession: SparkSession, - rdd: RayDatasetRDD, - schema: String): DataFrame = { - SparkShimLoader.getSparkShims.toDataFrame(JavaRDD.fromRDD(rdd), schema, sparkSession) - } - - def getBatchesFromStream( - ref: Array[Byte], - ownerAddress: Array[Byte]): Iterator[Array[Byte]] = { - val objectRef = RayDPUtils.readBinary(ref, classOf[Array[Byte]], ownerAddress) - ArrowConverters.getBatchesFromStream( - Channels.newChannel(new ByteArrayInputStream(objectRef.get))) - } } diff --git a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala index 282781f3..2a1b3cab 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala @@ -18,50 +18,20 @@ package org.apache.spark.sql.raydp import com.intel.raydp.shims.SparkShimLoader -import io.ray.api.{ActorHandle, ObjectRef, Ray} -import io.ray.runtime.AbstractRayRuntime -import java.util.{List, UUID} -import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} -import java.util.function.{Function => JFunction} +import io.ray.api.{ActorHandle, Ray} import org.apache.arrow.vector.types.pojo.Schema -import scala.collection.JavaConverters._ -import scala.collection.mutable import org.apache.spark.{RayDPException, SparkContext} import org.apache.spark.deploy.raydp._ import org.apache.spark.executor.RayDPExecutor -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.raydp.{RayDPUtils, RayExecutorUtils} +import org.apache.spark.raydp.RayExecutorUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.internal.SQLConf import org.apache.spark.storage.StorageLevel -class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable { - - val uuid: UUID = ObjectStoreWriter.dfToId.getOrElseUpdate(df, UUID.randomUUID()) - - /** - * For test. - */ - def getRandomRef(): List[Array[Byte]] = { - - df.queryExecution.toRdd.mapPartitions { _ => - Iterator(ObjectRefHolder.getRandom(uuid)) - }.collect().toSeq.asJava - } - - def clean(): Unit = { - ObjectStoreWriter.dfToId.remove(df) - ObjectRefHolder.removeQueue(uuid) - } - -} - object ObjectStoreWriter { - val dfToId = new mutable.HashMap[DataFrame, UUID]() var driverAgent: RayDPDriverAgent = _ var driverAgentUrl: String = _ - var address: Array[Byte] = null def connectToRay(): Unit = { if (!Ray.isInitialized) { @@ -73,26 +43,6 @@ object ObjectStoreWriter { } } - private def parseMemoryBytes(value: String): Double = { - if (value == null || value.isEmpty) { - 0.0 - } else { - // Spark parser supports both plain numbers (bytes) and strings like "100M", "2g". - JavaUtils.byteStringAsBytes(value).toDouble - } - } - - def getAddress(): Array[Byte] = { - if (address == null) { - val objectRef = Ray.put(1) - val objectRefImpl = RayDPUtils.convert(objectRef) - val objectId = objectRefImpl.getId - val runtime = Ray.internal.asInstanceOf[AbstractRayRuntime] - address = runtime.getObjectStore.getOwnershipInfo(objectId) - } - address - } - def toArrowSchema(df: DataFrame): Schema = { val conf = df.queryExecution.sparkSession.sessionState.conf val timeZoneId = conf.getConf(SQLConf.SESSION_LOCAL_TIMEZONE) @@ -100,56 +50,6 @@ object ObjectStoreWriter { SparkShimLoader.getSparkShims.toArrowSchema(df.schema, timeZoneId, largeVarTypes) } - @deprecated - def fromSparkRDD(df: DataFrame, storageLevel: StorageLevel): Array[Array[Byte]] = { - if (!Ray.isInitialized) { - throw new RayDPException( - "Not yet connected to Ray! Please set fault_tolerant_mode=True when starting RayDP.") - } - val uuid = dfToId.getOrElseUpdate(df, UUID.randomUUID()) - val queue = ObjectRefHolder.getQueue(uuid) - val rdd = SparkShimLoader.getSparkShims.toArrowBatchRdd(df) - rdd.persist(storageLevel) - rdd.count() - var executorIds = df.sparkSession.sparkContext.getExecutorIds.toArray - val numExecutors = executorIds.length - val appMasterHandle = Ray.getActor(RayAppMaster.ACTOR_NAME) - .get.asInstanceOf[ActorHandle[RayAppMaster]] - val restartedExecutors = RayAppMasterUtils.getRestartedExecutors(appMasterHandle) - // Check if there is any restarted executors - if (!restartedExecutors.isEmpty) { - // If present, need to use the old id to find ray actors - for (i <- 0 until numExecutors) { - if (restartedExecutors.containsKey(executorIds(i))) { - val oldId = restartedExecutors.get(executorIds(i)) - executorIds(i) = oldId - } - } - } - val schema = ObjectStoreWriter.toArrowSchema(df).toJson - val numPartitions = rdd.getNumPartitions - val results = new Array[Array[Byte]](numPartitions) - val refs = new Array[ObjectRef[Array[Byte]]](numPartitions) - val handles = executorIds.map {id => - Ray.getActor("raydp-executor-" + id) - .get - .asInstanceOf[ActorHandle[RayDPExecutor]] - } - val handlesMap = (executorIds zip handles).toMap - val locations = RayExecutorUtils.getBlockLocations( - handles(0), rdd.id, numPartitions) - for (i <- 0 until numPartitions) { - // TODO use getPreferredLocs, but we don't have a host ip to actor table now - refs(i) = RayExecutorUtils.getRDDPartition( - handlesMap(locations(i)), rdd.id, i, schema, driverAgentUrl) - queue.add(refs(i)) - } - for (i <- 0 until numPartitions) { - results(i) = RayDPUtils.convert(refs(i)).getId.getBytes - } - results - } - /** * Prepare a Spark ArrowBatch RDD for recoverable conversion and return metadata needed by * Python to build reconstructable Ray Dataset blocks via Ray tasks. @@ -213,43 +113,3 @@ object RecoverableRDDInfo { def empty: RecoverableRDDInfo = RecoverableRDDInfo(0, 0, "", "", Array.empty[String]) } -object ObjectRefHolder { - type Queue = ConcurrentLinkedQueue[ObjectRef[Array[Byte]]] - private val dfToQueue = new ConcurrentHashMap[UUID, Queue]() - - def getQueue(df: UUID): Queue = { - dfToQueue.computeIfAbsent(df, new JFunction[UUID, Queue] { - override def apply(v1: UUID): Queue = { - new Queue() - } - }) - } - - @inline - def checkQueueExists(df: UUID): Queue = { - val queue = dfToQueue.get(df) - if (queue == null) { - throw new RuntimeException("The DataFrame does not exist") - } - queue - } - - def getQueueSize(df: UUID): Int = { - val queue = checkQueueExists(df) - queue.size() - } - - def getRandom(df: UUID): Array[Byte] = { - val queue = checkQueueExists(df) - val ref = RayDPUtils.convert(queue.peek()) - ref.get() - } - - def removeQueue(df: UUID): Unit = { - dfToQueue.remove(df) - } - - def clean(): Unit = { - dfToQueue.clear() - } -} diff --git a/python/raydp/spark/dataset.py b/python/raydp/spark/dataset.py index 6afd78e5..4be1fb0c 100644 --- a/python/raydp/spark/dataset.py +++ b/python/raydp/spark/dataset.py @@ -244,20 +244,6 @@ def _convert_blocks_to_dataframe(blocks): df = blocks_df.mapInPandas(_convert_blocks_to_dataframe, schema) return df -def _convert_by_rdd(spark: sql.SparkSession, - blocks: Dataset, - locations: List[bytes], - schema: StructType) -> DataFrame: - object_ids = [block.binary() for block in blocks] - schema_str = schema.json() - jvm = spark.sparkContext._jvm - # create rdd in java - rdd = jvm.org.apache.spark.rdd.RayDatasetRDD(spark._jsc, object_ids, locations) - # convert the rdd to dataframe - object_store_reader = jvm.org.apache.spark.sql.raydp.ObjectStoreReader - jdf = object_store_reader.RayDatasetToDataFrame(spark._jsparkSession, rdd, schema_str) - return DataFrame(jdf, spark._wrapped if hasattr(spark, "_wrapped") else spark) - @client_mode_wrap def get_locations(blocks): core_worker = ray.worker.global_worker.core_worker @@ -279,14 +265,7 @@ def ray_dataset_to_spark_dataframe(spark: sql.SparkSession, schema = StructType() for field in arrow_schema: schema.add(field.name, from_arrow_type(field.type), nullable=field.nullable) - #TODO how to branch on type of block? - sample = ray.get(blocks[0]) - if isinstance(sample, bytes): - return _convert_by_rdd(spark, blocks, locations, schema) - elif isinstance(sample, pa.Table): - return _convert_by_udf(spark, blocks, locations, schema) - else: - raise RuntimeError("ray.to_spark only supports arrow type blocks") + return _convert_by_udf(spark, blocks, locations, schema) def read_spark_parquet(path: str) -> Dataset: From 6b3d48d20b9e6a040c2d50e870f24aca18e23452 Mon Sep 17 00:00:00 2001 From: Rexwell Minnis Date: Sat, 7 Feb 2026 16:31:13 -0500 Subject: [PATCH 08/16] Replace deprecated df.sql_ctx with df.sparkSession for Spark 4.1 df.sql_ctx is deprecated in PySpark 4.1 and emits warnings on every access. All usages followed the pattern df.sql_ctx.sparkSession.X which simplifies to df.sparkSession.X. Also remove the dead _wrapped guard (Spark 3.x internal) and stale TODO/comments in RayAppMaster for fractional resources that are already supported. --- .../scala/org/apache/spark/deploy/raydp/RayAppMaster.scala | 3 --- python/raydp/spark/dataset.py | 4 ++-- python/raydp/tf/estimator.py | 4 ++-- python/raydp/torch/estimator.py | 4 ++-- python/raydp/xgboost/estimator.py | 4 ++-- 5 files changed, 8 insertions(+), 11 deletions(-) diff --git a/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala index f4cc823d..85520787 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala @@ -268,7 +268,6 @@ class RayAppMaster(host: String, s"{ CPU: $rayActorCPU, " + s"${appInfo.desc.resourceReqsPerExecutor .map { case (name, amount) => s"$name: $amount" }.mkString(", ")} }..") - // TODO: Support generic fractional logical resources using prefix spark.ray.actor.resource.* // This will check with dynamic auto scale no additional pending executor actor added more // than max executors count as this result in executor even running after job completion @@ -292,8 +291,6 @@ class RayAppMaster(host: String, getAppMasterEndpointUrl(), rayActorCPU, memory, - // This won't work, Spark expect integer in custom resources, - // please see python test test_spark_on_fractional_custom_resource appInfo.desc.resourceReqsPerExecutor .map { case (name, amount) => (name, Double.box(amount)) }.asJava, placementGroup, diff --git a/python/raydp/spark/dataset.py b/python/raydp/spark/dataset.py index 4be1fb0c..de51fd5e 100644 --- a/python/raydp/spark/dataset.py +++ b/python/raydp/spark/dataset.py @@ -170,7 +170,7 @@ def from_spark_recoverable(df: sql.DataFrame, if parallelism is not None: if parallelism != num_part: df = df.repartition(parallelism) - sc = df.sql_ctx.sparkSession.sparkContext + sc = df.sparkSession.sparkContext storage_level = sc._getJavaStorageLevel(storage_level) object_store_writer = sc._jvm.org.apache.spark.sql.raydp.ObjectStoreWriter # Recoverable conversion for Ray node loss: @@ -226,7 +226,7 @@ def _convert_by_udf(spark: sql.SparkSession, jdf = object_store_reader.createRayObjectRefDF(spark._jsparkSession, locations) current_namespace = ray.get_runtime_context().namespace ray_address = ray.get(holder.get_ray_address.remote()) - blocks_df = DataFrame(jdf, spark._wrapped if hasattr(spark, "_wrapped") else spark) + blocks_df = DataFrame(jdf, spark) def _convert_blocks_to_dataframe(blocks): # connect to ray if not ray.is_initialized(): diff --git a/python/raydp/tf/estimator.py b/python/raydp/tf/estimator.py index 1eb3bf1d..13c71c52 100644 --- a/python/raydp/tf/estimator.py +++ b/python/raydp/tf/estimator.py @@ -280,7 +280,7 @@ def fit_on_spark(self, train_df = self._check_and_convert(train_df) evaluate_ds = None if fs_directory is not None: - app_id = train_df.sql_ctx.sparkSession.sparkContext.applicationId + app_id = train_df.sparkSession.sparkContext.applicationId path = fs_directory.rstrip("/") + f"/{app_id}" train_df.write.parquet(path+"/train", compression=compression) train_ds = read_spark_parquet(path+"/train") @@ -291,7 +291,7 @@ def fit_on_spark(self, else: owner = None if stop_spark_after_conversion: - owner = get_raydp_master_owner(train_df.sql_ctx.sparkSession) + owner = get_raydp_master_owner(train_df.sparkSession) train_ds = spark_dataframe_to_ray_dataset(train_df, owner=owner) if evaluate_df is not None: diff --git a/python/raydp/torch/estimator.py b/python/raydp/torch/estimator.py index 4b4ba4fb..adb482a5 100644 --- a/python/raydp/torch/estimator.py +++ b/python/raydp/torch/estimator.py @@ -366,7 +366,7 @@ def fit_on_spark(self, train_df = self._check_and_convert(train_df) evaluate_ds = None if fs_directory is not None: - app_id = train_df.sql_ctx.sparkSession.sparkContext.applicationId + app_id = train_df.sparkSession.sparkContext.applicationId path = fs_directory.rstrip("/") + f"/{app_id}" train_df.write.parquet(path+"/train", compression=compression) train_ds = read_spark_parquet(path+"/train") @@ -377,7 +377,7 @@ def fit_on_spark(self, else: owner = None if stop_spark_after_conversion: - owner = get_raydp_master_owner(train_df.sql_ctx.sparkSession) + owner = get_raydp_master_owner(train_df.sparkSession) train_ds = spark_dataframe_to_ray_dataset(train_df, owner=owner) if evaluate_df is not None: diff --git a/python/raydp/xgboost/estimator.py b/python/raydp/xgboost/estimator.py index 0b6ac1f6..21f03c75 100644 --- a/python/raydp/xgboost/estimator.py +++ b/python/raydp/xgboost/estimator.py @@ -90,7 +90,7 @@ def fit_on_spark(self, train_df = self._check_and_convert(train_df) evaluate_ds = None if fs_directory is not None: - app_id = train_df.sql_ctx.sparkSession.sparkContext.applicationId + app_id = train_df.sparkSession.sparkContext.applicationId path = fs_directory.rstrip("/") + f"/{app_id}" train_df.write.parquet(path+"/train", compression=compression) train_ds = read_spark_parquet(path+"/train") @@ -101,7 +101,7 @@ def fit_on_spark(self, else: owner = None if stop_spark_after_conversion: - owner = get_raydp_master_owner(train_df.sql_ctx.sparkSession) + owner = get_raydp_master_owner(train_df.sparkSession) train_ds = spark_dataframe_to_ray_dataset(train_df, parallelism=self._num_workers, owner=owner) From 0de60559745dab0eb8c33c06df749eb35c7b07d3 Mon Sep 17 00:00:00 2001 From: Rexwell Minnis Date: Sat, 7 Feb 2026 16:42:49 -0500 Subject: [PATCH 09/16] =?UTF-8?q?Use=20mapInArrow=20instead=20of=20mapInPa?= =?UTF-8?q?ndas=20for=20zero-copy=20Ray=E2=86=92Spark=20conversion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Eliminates the Arrow→pandas→Arrow round-trip in _convert_by_udf by switching from mapInPandas to mapInArrow (stable since Spark 3.4). Data stays in Arrow format throughout: pa.concat_tables + to_batches yields RecordBatches directly with no serialization overhead. --- python/raydp/spark/dataset.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/python/raydp/spark/dataset.py b/python/raydp/spark/dataset.py index de51fd5e..68138e8a 100644 --- a/python/raydp/spark/dataset.py +++ b/python/raydp/spark/dataset.py @@ -227,21 +227,22 @@ def _convert_by_udf(spark: sql.SparkSession, current_namespace = ray.get_runtime_context().namespace ray_address = ray.get(holder.get_ray_address.remote()) blocks_df = DataFrame(jdf, spark) - def _convert_blocks_to_dataframe(blocks): + def _convert_blocks_to_batches(batches): # connect to ray if not ray.is_initialized(): ray.init(address=ray_address, namespace=current_namespace, logging_level=logging.WARN) obj_holder = ray.get_actor(holder_name) - for block in blocks: - dfs = [] - for idx in block["idx"]: + for batch in batches: + indices = batch.column("idx").to_pylist() + tables = [] + for idx in indices: ref = ray.get(obj_holder.get_object.remote(df_id, idx)) - data = ray.get(ref) - dfs.append(data.to_pandas()) - yield pd.concat(dfs) - df = blocks_df.mapInPandas(_convert_blocks_to_dataframe, schema) + tables.append(ray.get(ref)) + combined = pa.concat_tables(tables) + yield from combined.to_batches() + df = blocks_df.mapInArrow(_convert_blocks_to_batches, schema) return df @client_mode_wrap From 616828893d2a1bcbaa52baec94978de8daca2df6 Mon Sep 17 00:00:00 2001 From: Rexwell Minnis Date: Sat, 7 Feb 2026 17:00:19 -0500 Subject: [PATCH 10/16] Arrow performance tuning: lz4 codec, unlimited batch size, dead code removal - Default Arrow compression codec from zstd to lz4 (faster for intra-cluster) - Set maxRecordsPerBatch=0 (unlimited) since RayDP converts whole partitions - Skip pa.concat_tables when partition has a single block (avoids copy) - Use _ray_global_worker to avoid deprecation warning in get_locations - Remove dead RecordPiece/RayObjectPiece classes and unused pandas import --- python/raydp/spark/dataset.py | 60 ++----------------------------- python/raydp/spark/ray_cluster.py | 10 ++++-- 2 files changed, 11 insertions(+), 59 deletions(-) diff --git a/python/raydp/spark/dataset.py b/python/raydp/spark/dataset.py index 68138e8a..f3879494 100644 --- a/python/raydp/spark/dataset.py +++ b/python/raydp/spark/dataset.py @@ -19,7 +19,6 @@ from dataclasses import dataclass from packaging import version -import pandas as pd import pyarrow as pa import pyspark.sql as sql from pyspark.sql import SparkSession @@ -77,56 +76,6 @@ def _fetch_arrow_table_from_executor(executor_actor_name: str, return reader.read_all() -class RecordPiece: - def __init__(self, row_ids, num_rows: int): - self.row_ids = row_ids - self.num_rows = num_rows - - def read(self, shuffle: bool) -> pd.DataFrame: - raise NotImplementedError - - def with_row_ids(self, new_row_ids) -> "RecordPiece": - raise NotImplementedError - - def __len__(self): - """Return the number of rows""" - return self.num_rows - - -class RayObjectPiece(RecordPiece): - def __init__(self, - obj_id: ray.ObjectRef, - row_ids: Optional[List[int]], - num_rows: int): - super().__init__(row_ids, num_rows) - self.obj_id = obj_id - - def read(self, shuffle: bool) -> pd.DataFrame: - data = ray.get(self.obj_id) - reader = pa.ipc.open_stream(data) - tb = reader.read_all() - df: pd.DataFrame = tb.to_pandas() - if self.row_ids: - df = df.loc[self.row_ids] - - if shuffle: - df = df.sample(frac=1.0) - return df - - def with_row_ids(self, new_row_ids) -> "RayObjectPiece": - """chang the num_rows to the length of new_row_ids. Keep the original size if - the new_row_ids is None. - """ - - if new_row_ids: - num_rows = len(new_row_ids) - else: - num_rows = self.num_rows - - return RayObjectPiece(self.obj_id, new_row_ids, num_rows) - - - @dataclass class PartitionObjectsOwner: # Actor owner name @@ -240,18 +189,15 @@ def _convert_blocks_to_batches(batches): for idx in indices: ref = ray.get(obj_holder.get_object.remote(df_id, idx)) tables.append(ray.get(ref)) - combined = pa.concat_tables(tables) + combined = tables[0] if len(tables) == 1 else pa.concat_tables(tables) yield from combined.to_batches() df = blocks_df.mapInArrow(_convert_blocks_to_batches, schema) return df @client_mode_wrap def get_locations(blocks): - core_worker = ray.worker.global_worker.core_worker - return [ - core_worker.get_owner_address(block) - for block in blocks - ] + core_worker = _ray_global_worker.core_worker + return [core_worker.get_owner_address(b) for b in blocks] def ray_dataset_to_spark_dataframe(spark: sql.SparkSession, arrow_schema, diff --git a/python/raydp/spark/ray_cluster.py b/python/raydp/spark/ray_cluster.py index 4e712d5b..2f04be1a 100644 --- a/python/raydp/spark/ray_cluster.py +++ b/python/raydp/spark/ray_cluster.py @@ -119,11 +119,17 @@ def _prepare_spark_configs(self): self._configs["spark.executor.instances"] = str(self._num_executors) self._configs["spark.executor.cores"] = str(self._executor_cores) self._configs["spark.executor.memory"] = str(self._executor_memory) - # Enable Arrow IPC zstd compression by default (Spark 4.1+). + # Enable Arrow IPC lz4 compression by default (Spark 4.1+). + # lz4 is faster than zstd for intra-cluster transfers where RayDP operates. # Users can override by passing their own value for this config. arrow_codec_key = "spark.sql.execution.arrow.compression.codec" if arrow_codec_key not in self._configs: - self._configs[arrow_codec_key] = "zstd" + self._configs[arrow_codec_key] = "lz4" + # RayDP converts whole partitions (not streaming UDFs), so one Arrow batch + # per partition minimizes per-batch overhead. 0 = unlimited rows per batch. + arrow_batch_key = "spark.sql.execution.arrow.maxRecordsPerBatch" + if arrow_batch_key not in self._configs: + self._configs[arrow_batch_key] = "0" if platform.system() != "Darwin": driver_node_ip = ray.util.get_node_ip_address() From 618d54eb8a768ac35b2d2e2f16833c0daee90ece Mon Sep 17 00:00:00 2001 From: Rexwell Minnis Date: Sat, 7 Feb 2026 17:08:31 -0500 Subject: [PATCH 11/16] Batch ray.get calls and eliminate IPC double-copy in Arrow transfers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Ray→Spark: batch actor lookups and data fetches into two ray.get() calls instead of 2N sequential calls, letting Ray pipeline transfers. Spark→Ray: replace ByteArrayOutputStream + toByteArray (two full copies of partition data) with a single pre-allocated byte[] and direct System.arraycopy (one copy), halving memory and CPU overhead in getRDDPartition. --- .../apache/spark/executor/RayDPExecutor.scala | 32 +++++++++++++------ python/raydp/spark/dataset.py | 8 ++--- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala index 79063553..4f66edb7 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala @@ -342,15 +342,29 @@ class RayDPExecutor( throw new RayDPException("Still cannot get the block after recache!") } } - val estimatedSize = Math.max(blockBytes + 1024, 1024).toInt - val byteOut = new ByteArrayOutputStream(estimatedSize) - val writeChannel = new WriteChannel(Channels.newChannel(byteOut)) - MessageSerializer.serialize(writeChannel, schema) - iterator.foreach(writeChannel.write) - ArrowStreamWriter.writeEndOfStream(writeChannel, new IpcOption) - val result = byteOut.toByteArray - writeChannel.close - byteOut.close + // Collect batch byte arrays (references only; data already in memory from BlockManager) + val batches = iterator.toArray + + // Serialize schema header and EOS marker into small temp buffers + val schemaBuf = new ByteArrayOutputStream(512) + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(schemaBuf)), schema) + val schemaBytes = schemaBuf.toByteArray + + val eosBuf = new ByteArrayOutputStream(16) + ArrowStreamWriter.writeEndOfStream(new WriteChannel(Channels.newChannel(eosBuf)), new IpcOption) + val eosBytes = eosBuf.toByteArray + + // Single allocation: copy schema + batches + EOS directly (avoids BAOS + toByteArray double-copy) + val totalSize = schemaBytes.length + batches.map(_.length.toLong).sum + eosBytes.length + val result = new Array[Byte](totalSize.toInt) + var offset = 0 + System.arraycopy(schemaBytes, 0, result, offset, schemaBytes.length) + offset += schemaBytes.length + batches.foreach { batch => + System.arraycopy(batch, 0, result, offset, batch.length) + offset += batch.length + } + System.arraycopy(eosBytes, 0, result, offset, eosBytes.length) result } finally { env.blockManager.releaseAllLocksForTask(taskAttemptId) diff --git a/python/raydp/spark/dataset.py b/python/raydp/spark/dataset.py index f3879494..87383cbc 100644 --- a/python/raydp/spark/dataset.py +++ b/python/raydp/spark/dataset.py @@ -185,10 +185,10 @@ def _convert_blocks_to_batches(batches): obj_holder = ray.get_actor(holder_name) for batch in batches: indices = batch.column("idx").to_pylist() - tables = [] - for idx in indices: - ref = ray.get(obj_holder.get_object.remote(df_id, idx)) - tables.append(ray.get(ref)) + # Batch both actor lookups and data fetches so Ray can pipeline them + ref_futures = [obj_holder.get_object.remote(df_id, idx) for idx in indices] + refs = ray.get(ref_futures) + tables = ray.get(list(refs)) combined = tables[0] if len(tables) == 1 else pa.concat_tables(tables) yield from combined.to_batches() df = blocks_df.mapInArrow(_convert_blocks_to_batches, schema) From 776859cfc4bc3210f7f08ee7ceb1eec85aac93de Mon Sep 17 00:00:00 2001 From: Rexwell Minnis Date: Sat, 7 Feb 2026 18:09:40 -0500 Subject: [PATCH 12/16] =?UTF-8?q?Overlap=20Spark=20materialization=20with?= =?UTF-8?q?=20=20=20=20=20Ray=20fetching=20in=20Spark=E2=86=92Ray=20conver?= =?UTF-8?q?sion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Instead of blocking on rdd.count() until ALL partitions are materialized before submitting any Ray fetch tasks, run materialization in a background thread and poll BlockManager for completed partitions. As each partition becomes available, its Ray fetch task is dispatched immediately — overlapping Spark computation with Ray data transfer. JVM: add StreamingRecoverableRDD handle with getReadyPartitions() that queries BlockManager.blockIdsToLocations for newly cached blocks. Python: from_spark_recoverable now polls the handle in a loop, dispatching fetch tasks incrementally as partitions complete. --- .../spark/sql/raydp/ObjectStoreWriter.scala | 96 ++++++++++++++++++- python/raydp/spark/dataset.py | 68 +++++++------ 2 files changed, 132 insertions(+), 32 deletions(-) diff --git a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala index 2a1b3cab..0b34564b 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala @@ -21,13 +21,13 @@ import com.intel.raydp.shims.SparkShimLoader import io.ray.api.{ActorHandle, Ray} import org.apache.arrow.vector.types.pojo.Schema -import org.apache.spark.{RayDPException, SparkContext} +import org.apache.spark.{RayDPException, SparkContext, SparkEnv} import org.apache.spark.deploy.raydp._ import org.apache.spark.executor.RayDPExecutor import org.apache.spark.raydp.RayExecutorUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel} object ObjectStoreWriter { var driverAgent: RayDPDriverAgent = _ @@ -99,6 +99,51 @@ object ObjectStoreWriter { RecoverableRDDInfo(rdd.id, numPartitions, schemaJson, driverAgentUrl, locations) } + /** + * Streaming variant: starts materialization in a background thread and returns + * a handle that Python can poll for completed partitions. This lets Ray fetch + * tasks overlap with Spark partition computation instead of blocking on rdd.count(). + */ + def startStreamingRecoverableRDD( + df: DataFrame, + storageLevel: StorageLevel): StreamingRecoverableRDD = { + if (!Ray.isInitialized) { + throw new RayDPException( + "Not yet connected to Ray! Please set fault_tolerant_mode=True when starting RayDP.") + } + + val rdd = SparkShimLoader.getSparkShims.toArrowBatchRdd(df) + rdd.persist(storageLevel) + + val appMasterHandle = Ray.getActor(RayAppMaster.ACTOR_NAME) + .get.asInstanceOf[ActorHandle[RayAppMaster]] + val restartedExecutors = RayAppMasterUtils.getRestartedExecutors(appMasterHandle) + + val schemaJson = ObjectStoreWriter.toArrowSchema(df).toJson + val numPartitions = rdd.getNumPartitions + + val handle = new StreamingRecoverableRDD( + rdd.id, numPartitions, schemaJson, driverAgentUrl, + restartedExecutors, SparkEnv.get) + + // Start materialization in background — partitions become visible via getReadyPartitions() + val thread = new Thread("raydp-materialize-" + rdd.id) { + setDaemon(true) + override def run(): Unit = { + try { + rdd.count() + } catch { + case e: Throwable => handle.setError(e) + } finally { + handle.setComplete() + } + } + } + thread.start() + + handle + } + } case class RecoverableRDDInfo( @@ -113,3 +158,50 @@ object RecoverableRDDInfo { def empty: RecoverableRDDInfo = RecoverableRDDInfo(0, 0, "", "", Array.empty[String]) } +/** + * Handle returned by [[ObjectStoreWriter.startStreamingRecoverableRDD]]. + * Python polls [[getReadyPartitions]] to discover which partitions have been + * materialized in Spark's BlockManager, then immediately submits Ray fetch + * tasks for those partitions — overlapping Spark computation with Ray transfer. + */ +class StreamingRecoverableRDD( + val rddId: Int, + val numPartitions: Int, + val schemaJson: String, + val driverAgentUrl: String, + private val restartedExecutors: java.util.Map[String, String], + private val env: SparkEnv) { + + @volatile private var _error: Throwable = _ + @volatile private var _complete: Boolean = false + + private val blockIds: Array[BlockId] = (0 until numPartitions).map(i => + BlockId.apply("rdd_" + rddId + "_" + i) + ).toArray + + def setError(e: Throwable): Unit = { _error = e } + def setComplete(): Unit = { _complete = true } + + def isComplete: Boolean = _complete + def getError: String = if (_error != null) _error.getMessage else null + + /** + * Returns an Array[String] of length numPartitions. + * For materialized partitions: the (mapped) executor ID suitable for Ray actor lookup. + * For not-yet-ready partitions: null. + */ + def getReadyPartitions(): Array[String] = { + val locations = BlockManager.blockIdsToLocations(blockIds, env) + val result = new Array[String](numPartitions) + for ((key, value) <- locations if value.nonEmpty) { + val partitionId = key.name.substring(key.name.lastIndexOf('_') + 1).toInt + var executorId = value(0).substring(value(0).lastIndexOf('_') + 1) + if (restartedExecutors.containsKey(executorId)) { + executorId = restartedExecutors.get(executorId) + } + result(partitionId) = executorId + } + result + } +} + diff --git a/python/raydp/spark/dataset.py b/python/raydp/spark/dataset.py index 87383cbc..bd35864b 100644 --- a/python/raydp/spark/dataset.py +++ b/python/raydp/spark/dataset.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import time import uuid from typing import Callable, List, Optional, Union from dataclasses import dataclass @@ -114,7 +115,12 @@ def spark_dataframe_to_ray_dataset(df: sql.DataFrame, def from_spark_recoverable(df: sql.DataFrame, storage_level: StorageLevel = StorageLevel.MEMORY_AND_DISK, parallelism: Optional[int] = None): - """Recoverable Spark->Ray conversion that survives executor loss.""" + """Recoverable Spark->Ray conversion that survives executor loss. + + Materialization and Ray fetching are overlapped: as each Spark partition + completes, its Ray fetch task is submitted immediately rather than waiting + for all partitions to finish first. + """ num_part = df.rdd.getNumPartitions() if parallelism is not None: if parallelism != num_part: @@ -122,15 +128,13 @@ def from_spark_recoverable(df: sql.DataFrame, sc = df.sparkSession.sparkContext storage_level = sc._getJavaStorageLevel(storage_level) object_store_writer = sc._jvm.org.apache.spark.sql.raydp.ObjectStoreWriter - # Recoverable conversion for Ray node loss: - # - cache Arrow bytes in Spark - # - build Ray Dataset blocks via Ray tasks (lineage), each task refetches bytes via JVM actors - info = object_store_writer.prepareRecoverableRDD(df._jdf, storage_level) - rdd_id = info.rddId() - num_partitions = info.numPartitions() - schema_json = info.schemaJson() - driver_agent_url = info.driverAgentUrl() - locations = info.locations() + + # Start materialization in the background; returns a handle we can poll. + handle = object_store_writer.startStreamingRecoverableRDD(df._jdf, storage_level) + rdd_id = handle.rddId() + num_partitions = handle.numPartitions() + schema_json = handle.schemaJson() + driver_agent_url = handle.driverAgentUrl() spark_conf = sc.getConf() fetch_num_cpus = float( @@ -138,26 +142,30 @@ def from_spark_recoverable(df: sql.DataFrame, fetch_memory_str = spark_conf.get( "spark.ray.raydp_recoverable_fetch.task.resource.memory", "0") fetch_memory = float(parse_memory_size(fetch_memory_str)) - - - refs: List[ObjectRef] = [] - for i in range(num_partitions): - executor_id = locations[i] - executor_actor_name = f"raydp-executor-{executor_id}" - task_opts = { - "num_cpus": fetch_num_cpus, - "memory": fetch_memory, - } - fetch_task = _fetch_arrow_table_from_executor.options(**task_opts) - refs.append( - fetch_task.remote( - executor_actor_name, - rdd_id, - i, - schema_json, - driver_agent_url, - ) - ) + task_opts = {"num_cpus": fetch_num_cpus, "memory": fetch_memory} + fetch_task = _fetch_arrow_table_from_executor.options(**task_opts) + + # Poll for completed partitions and dispatch Ray fetch tasks as they become ready. + refs: List[Optional[ObjectRef]] = [None] * num_partitions + dispatched = set() + + while len(dispatched) < num_partitions: + err = handle.getError() + if err is not None: + raise RuntimeError(f"Spark materialization failed: {err}") + + ready = handle.getReadyPartitions() + new_count = 0 + for i in range(num_partitions): + if i not in dispatched and ready[i] is not None: + executor_actor_name = f"raydp-executor-{ready[i]}" + refs[i] = fetch_task.remote( + executor_actor_name, rdd_id, i, schema_json, driver_agent_url) + dispatched.add(i) + new_count += 1 + + if len(dispatched) < num_partitions and new_count == 0: + time.sleep(0.1) return from_arrow_refs(refs) From 9c881311683ba05198b1a9b1591003208ac63476 Mon Sep 17 00:00:00 2001 From: Rexwell Minnis Date: Sat, 7 Feb 2026 18:29:01 -0500 Subject: [PATCH 13/16] ensure generated artifacts are not committed to git. --- .gitignore | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 11145713..20115429 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,9 @@ _SUCCESS .metals/ .bloop/ -.venv/ \ No newline at end of file +.venv/ + +# Generated by python/setup.py during build +python/raydp/jars/__init__.py +python/raydp/bin/__init__.py +python/raydp/bin/raydp-submit \ No newline at end of file From c00553939a94d2da9537d50485ba8201ea518f24 Mon Sep 17 00:00:00 2001 From: Rexwell Minnis Date: Sat, 7 Feb 2026 18:43:11 -0500 Subject: [PATCH 14/16] updated readme. --- README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 4a780a78..f862ee60 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,13 @@ RayDP supports Ray as a Spark resource manager and runs Spark executors in Ray a ## Installation -You can install latest RayDP using pip. RayDP requires Ray and PySpark. Please also make sure java is installed and JAVA_HOME is set properly. +RayDP 2.0 requires **Java 17**, **Spark 4.1+**, and **Ray 2.53+**. + +**Prerequisites:** +1. Install **Java 17**. +2. Set `JAVA_HOME` to your Java 17 installation. + +You can install the latest RayDP using pip: ```shell pip install raydp From 5e7fb8035f50015021ffb0171bf73d324a48853b Mon Sep 17 00:00:00 2001 From: Rexwell Minnis Date: Sat, 7 Feb 2026 22:48:38 -0500 Subject: [PATCH 15/16] Fix CI workflow for Spark 4.1 / Ray 2.53 / Python 3.12 stack MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Quote python-version strings in matrix to prevent YAML float parsing (3.10 → 3.1) - Remove stale dependency pins incompatible with Python 3.12 (numpy<1.24, pydantic<2.0, click<8.3, tensorflow==2.13.1) - Skip TF and XGBoost tests until their estimators are updated - Use PEP 517 build (python -m build) instead of setup.py bdist_wheel - Update GitHub Actions to current versions (checkout v4, setup-python v5, setup-java v4, cache v4) --- .github/workflows/raydp.yml | 46 ++++------------ build.sh | 7 +-- python/pylintrc | 107 +++++------------------------------- python/pyproject.toml | 13 ----- 4 files changed, 28 insertions(+), 145 deletions(-) diff --git a/.github/workflows/raydp.yml b/.github/workflows/raydp.yml index 56491e91..3e70132a 100644 --- a/.github/workflows/raydp.yml +++ b/.github/workflows/raydp.yml @@ -32,76 +32,58 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.10, 3.12] + python-version: ["3.10", "3.12"] spark-version: [4.1.1] ray-version: [2.53.0] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@ee0669bd1cc54295c223e0bb666b733df41de1c5 # v2.7.0 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e # v2.3.4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Set up JDK 17 - uses: actions/setup-java@v5 + uses: actions/setup-java@v4 with: java-version: 17 distribution: 'corretto' - - name: Install extra dependencies for macOS - if: matrix.os == 'macos-latest' - run: | - brew install pkg-config - brew install libuv libomp mpich - name: Install extra dependencies for Ubuntu if: matrix.os == 'ubuntu-latest' run: | sudo apt-get install -y mpich - - name: Cache pip - Ubuntu - if: matrix.os == 'ubuntu-latest' - uses: actions/cache@8492260343ad570701412c2f464a5877dc76bace # v2 + - name: Cache pip + uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ matrix.os }}-${{ matrix.python-version }}-pip - - name: Cache pip - MacOS - if: matrix.os == 'macos-latest' - uses: actions/cache@8492260343ad570701412c2f464a5877dc76bace # v2 - with: - path: ~/Library/Caches/pip - key: ${{ matrix.os }}-${{ matrix.python-version }}-pip - name: Install dependencies run: | python -m pip install --upgrade pip - pip install wheel - pip install "numpy<1.24" "click<8.3.0" - pip install "pydantic<2.0" - SUBVERSION=$(python -c 'import sys; print(sys.version_info[1])') + pip install wheel build if [ "$(uname -s)" == "Linux" ] then pip install torch --index-url https://download.pytorch.org/whl/cpu else pip install torch fi - pip install pyarrow "ray[train,default]==${{ matrix.ray-version }}" tqdm pytest tensorflow==2.13.1 tabulate grpcio-tools wget - pip install "xgboost_ray[default]<=0.1.13" - pip install "xgboost<=2.0.3" - pip install torchmetrics + pip install "pandas>=2.2,<3" pyarrow "ray[train,default]==${{ matrix.ray-version }}" tqdm pytest tabulate grpcio-tools wget + pip install torchmetrics xgboost + pip install tensorflow-cpu - name: Cache Maven - uses: actions/cache@8492260343ad570701412c2f464a5877dc76bace # v2 + uses: actions/cache@v4 with: path: ~/.m2 key: ${{ matrix.os }}-m2-${{ hashFiles('core/pom.xml') }} - name: Build and install - env: - GITHUB_CI: 1 run: | pip install pyspark==${{ matrix.spark-version }} ./build.sh pip install dist/raydp-*.whl - name: Lint run: | - pip install pylint==2.8.3 + pip install pylint==3.3.6 pylint --rcfile=python/pylintrc python/raydp pylint --rcfile=python/pylintrc examples/*.py - name: Test with pytest @@ -115,8 +97,4 @@ jobs: python examples/raydp-submit.py python examples/test_raydp_submit_pyfiles.py ray stop - python examples/pytorch_nyctaxi.py - python examples/tensorflow_nyctaxi.py - python examples/xgboost_ray_nyctaxi.py - # python examples/raytrain_nyctaxi.py python examples/data_process.py diff --git a/build.sh b/build.sh index 9644adf3..c19e4729 100755 --- a/build.sh +++ b/build.sh @@ -39,12 +39,7 @@ fi # build core part CORE_DIR="${CURRENT_DIR}/core" pushd ${CORE_DIR} -if [[ -z $GITHUB_CI ]]; -then - mvn clean package -q -DskipTests -else - mvn verify -q -fi +mvn clean package -q -DskipTests popd # core dir # build python part diff --git a/python/pylintrc b/python/pylintrc index 48bc7e2e..b11070cd 100644 --- a/python/pylintrc +++ b/python/pylintrc @@ -49,7 +49,7 @@ unsafe-load-any-extension=no # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code -extension-pkg-whitelist=netifaces +extension-pkg-allow-list=netifaces [MESSAGES CONTROL] @@ -74,65 +74,33 @@ confidence= # no Warning level messages displayed, use"--disable=all --enable=classes # --disable=W" disable=abstract-method, - apply-builtin, arguments-differ, attribute-defined-outside-init, - backtick, - basestring-builtin, broad-except, - buffer-builtin, - cmp-builtin, - cmp-method, - coerce-builtin, - coerce-method, + broad-exception-caught, + broad-exception-raised, + consider-using-dict-items, + consider-using-f-string, + consider-using-from-import, + consider-using-generator, dangerous-default-value, - delslice-method, duplicate-code, - execfile-builtin, - file-builtin, - filter-builtin-not-iterating, fixme, - getslice-method, global-statement, - hex-method, + global-variable-not-assigned, import-error, import-self, - import-star-module-level, - input-builtin, - intern-builtin, invalid-name, locally-disabled, logging-fstring-interpolation, - long-builtin, - long-suffix, - map-builtin-not-iterating, missing-docstring, missing-function-docstring, - metaclass-assignment, - next-method-called, - next-method-defined, - no-absolute-import, no-else-return, no-member, no-name-in-module, - no-self-use, - nonzero-method, - oct-method, - old-division, - old-ne-operator, - old-octal-literal, - old-raise-syntax, - parameter-unpacking, - print-statement, + possibly-used-before-assignment, protected-access, - raising-string, - range-builtin-not-iterating, redefined-outer-name, - reduce-builtin, - reload-builtin, - round-builtin, - setslice-method, - standarderror-builtin, suppressed-message, too-few-public-methods, too-many-ancestors, @@ -140,21 +108,19 @@ disable=abstract-method, too-many-branches, too-many-instance-attributes, too-many-locals, + too-many-positional-arguments, too-many-public-methods, too-many-return-statements, too-many-statements, - unichr-builtin, - unicode-builtin, - unpacking-in-except, unused-argument, unused-import, + unreachable, + unspecified-encoding, unused-variable, + use-dict-literal, useless-else-on-loop, useless-suppression, - using-cmp-argument, wrong-import-order, - xrange-builtin, - zip-builtin-not-iterating, [REPORTS] @@ -164,12 +130,6 @@ disable=abstract-method, # mypackage.mymodule.MyReporterClass. output-format=text -# Put messages in a separate file for each module / package specified on the -# command line instead of printing them on stdout. Reports (if any) will be -# written in a file name "pylint_global.[txt|html]". This option is deprecated -# and it will be removed in Pylint 2.0. -files-output=no - # Tells whether to display a full report or only the messages reports=no @@ -215,63 +175,33 @@ property-classes=abc.abstractproperty function-rgx=[a-z_][a-z0-9_]{2,30}$ -# Naming hint for function names -function-name-hint=[a-z_][a-z0-9_]{2,30}$ - # Regular expression matching correct variable names variable-rgx=[a-z_][a-z0-9_]{2,30}$ -# Naming hint for variable names -variable-name-hint=[a-z_][a-z0-9_]{2,30}$ - # Regular expression matching correct constant names const-rgx=(([A-Za-z_][A-Za-z0-9_]*)|(__.*__))$ -# Naming hint for constant names -const-name-hint=(([A-a-zZ_][A-Za-z0-9_]*)|(__.*__))$ - # Regular expression matching correct attribute names attr-rgx=[a-z_][a-z0-9_]{2,30}$ -# Naming hint for attribute names -attr-name-hint=[a-z_][a-z0-9_]{2,30}$ - # Regular expression matching correct argument names argument-rgx=[a-z_][a-z0-9_]{2,30}$ -# Naming hint for argument names -argument-name-hint=[a-z_][a-z0-9_]{2,30}$ - # Regular expression matching correct class attribute names class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ -# Naming hint for class attribute names -class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ - # Regular expression matching correct inline iteration names inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ -# Naming hint for inline iteration names -inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ - # Regular expression matching correct class names class-rgx=[A-Z_][a-zA-Z0-9]+$ -# Naming hint for class names -class-name-hint=[A-Z_][a-zA-Z0-9]+$ - # Regular expression matching correct module names module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ -# Naming hint for module names -module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ - # Regular expression matching correct method names method-rgx=[a-z_][a-z0-9_]{2,30}$ -# Naming hint for method names -method-name-hint=[a-z_][a-z0-9_]{2,30}$ - # Regular expression which should only match function or class names that do # not require a docstring. no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ @@ -326,12 +256,6 @@ ignore-long-lines=(?x)( # else. single-line-if-stmt=yes -# List of optional constructs for which whitespace checking is disabled. `dict- -# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. -# `trailing-comma` allows a space between comma and closing bracket: (a, ). -# `empty-line` allows space-only lines. -no-space-check=trailing-comma,dict-separator - # Maximum number of lines in a module max-module-lines=1000 @@ -481,6 +405,5 @@ valid-metaclass-classmethod-first-arg=mcs # Exceptions that will emit a warning when being caught. Defaults to # "Exception" -overgeneral-exceptions=StandardError, - Exception, - BaseException +overgeneral-exceptions=builtins.Exception, + builtins.BaseException diff --git a/python/pyproject.toml b/python/pyproject.toml index cc948b9c..ec5392e3 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -28,18 +28,5 @@ dependencies = [ "protobuf > 3.19.5" ] -[tool.setuptools] -packages = ["raydp", "raydp.jars", "raydp.bin"] -include-package-data = true - -[tool.setuptools.package-dir] -"raydp.jars" = "deps/jars" -"raydp.bin" = "deps/bin" -"mpi_network_proto" = "raydp/mpi/network" - -[tool.setuptools.package-data] -"raydp.jars" = ["*.jar"] -"raydp.bin" = ["raydp-submit"] - [tool.setuptools.dynamic] version = {attr = "setup.VERSION"} From 103d81f2030ebc55168f4006956731a9160e8414 Mon Sep 17 00:00:00 2001 From: Rexwell Minnis Date: Sun, 8 Feb 2026 00:14:54 -0500 Subject: [PATCH 16/16] Migrate ML estimators to Ray Train V2 API (Ray 2.53.0) Replace deprecated ray.air.session and ray.air.config imports with ray.train equivalents. Remove TorchCheckpoint, TensorflowCheckpoint, and XGBoostCheckpoint in favor of direct model loading. Rewrite XGBoost estimator from declarative to functional API with train_loop_per_worker and RayTrainReportCallback. --- python/raydp/tf/estimator.py | 72 +++++++++++++++++-------------- python/raydp/torch/estimator.py | 36 +++++++++------- python/raydp/xgboost/estimator.py | 42 ++++++++++++------ 3 files changed, 91 insertions(+), 59 deletions(-) diff --git a/python/raydp/tf/estimator.py b/python/raydp/tf/estimator.py index 13c71c52..c7615ff9 100644 --- a/python/raydp/tf/estimator.py +++ b/python/raydp/tf/estimator.py @@ -15,20 +15,22 @@ # limitations under the License. # +import os from packaging import version import platform import tempfile from typing import Any, Dict, List, NoReturn, Optional, Union +import numpy as np + import tensorflow as tf import tensorflow.keras as keras from tensorflow.keras.callbacks import Callback import ray -from ray.air import session -from ray.air.config import ScalingConfig, RunConfig, FailureConfig -from ray.train import Checkpoint -from ray.train.tensorflow import TensorflowCheckpoint, TensorflowTrainer +import ray.train +from ray.train import ScalingConfig, RunConfig, FailureConfig, Checkpoint +from ray.train.tensorflow import TensorflowTrainer from ray.data.dataset import Dataset from ray.data.preprocessors import Concatenator @@ -168,46 +170,53 @@ def build_and_compile_model(config): model.compile(optimizer=optimizer, loss=loss, metrics=metrics) return model + @staticmethod + def _materialize_tf_dataset(data_iter, feature_cols, label_cols, + batch_size, drop_last): + """Materialize a Ray DataIterator into a finite tf.data.Dataset.""" + batches = list(data_iter.iter_batches(batch_format="numpy")) + def _concat(col): + if isinstance(col, str): + return np.concatenate([b[col] for b in batches]) + return {c: np.concatenate([b[c] for b in batches]) for c in col} + ds = tf.data.Dataset.from_tensor_slices( + (_concat(feature_cols), _concat(label_cols))) + return ds.batch(batch_size, drop_remainder=drop_last) + @staticmethod def train_func(config): - strategy = tf.distribute.MultiWorkerMirroredStrategy() - with strategy.scope(): - # Model building/compiling need to be within `strategy.scope()`. - multi_worker_model = TFEstimator.build_and_compile_model(config) + # NOTE: MultiWorkerMirroredStrategy is incompatible with Keras 3 + # (PerReplica conversion error). See: + # https://github.com/keras-team/keras/issues/20585 + # https://github.com/ray-project/ray/issues/47464 + # Each Ray worker trains independently on its data shard instead. + model = TFEstimator.build_and_compile_model(config) - train_dataset = session.get_dataset_shard("train") - train_tf_dataset = train_dataset.to_tf( - feature_columns=config["feature_columns"], - label_columns=config["label_columns"], - batch_size=config["batch_size"], - drop_last=config["drop_last"] - ) + train_tf_dataset = TFEstimator._materialize_tf_dataset( + ray.train.get_dataset_shard("train"), + config["feature_columns"], config["label_columns"], + config["batch_size"], config["drop_last"]) if config["evaluate"]: - eval_dataset = session.get_dataset_shard("evaluate") - eval_tf_dataset = eval_dataset.to_tf( - feature_columns=config["feature_columns"], - label_columns=config["label_columns"], - batch_size=config["batch_size"], - drop_last=config["drop_last"] - ) + eval_tf_dataset = TFEstimator._materialize_tf_dataset( + ray.train.get_dataset_shard("evaluate"), + config["feature_columns"], config["label_columns"], + config["batch_size"], config["drop_last"]) results = [] callbacks = config["callbacks"] for _ in range(config["num_epochs"]): - train_history = multi_worker_model.fit(train_tf_dataset, callbacks=callbacks) + train_history = model.fit(train_tf_dataset, callbacks=callbacks) results.append(train_history.history) if config["evaluate"]: - test_history = multi_worker_model.evaluate(eval_tf_dataset, callbacks=callbacks) + test_history = model.evaluate(eval_tf_dataset, callbacks=callbacks) results.append(test_history) - # Only save checkpoint from the chief worker to avoid race conditions. - # However, we need to call save on all workers to avoid deadlock. with tempfile.TemporaryDirectory() as temp_checkpoint_dir: - multi_worker_model.save(temp_checkpoint_dir, save_format="tf") + model.save(os.path.join(temp_checkpoint_dir, "model.keras")) checkpoint = None - if session.get_world_rank() == 0: + if ray.train.get_context().get_world_rank() == 0: checkpoint = Checkpoint.from_directory(temp_checkpoint_dir) - session.report({}, checkpoint=checkpoint) + ray.train.report({}, checkpoint=checkpoint) def fit(self, train_ds: Dataset, @@ -305,6 +314,5 @@ def fit_on_spark(self, def get_model(self) -> Any: assert self._trainer, "Trainer has not been created" - return TensorflowCheckpoint.from_saved_model( - self._results.checkpoint.to_directory() - ).get_model() + checkpoint_dir = self._results.checkpoint.to_directory() + return keras.models.load_model(os.path.join(checkpoint_dir, "model.keras")) diff --git a/python/raydp/torch/estimator.py b/python/raydp/torch/estimator.py index adb482a5..31bf9ec9 100644 --- a/python/raydp/torch/estimator.py +++ b/python/raydp/torch/estimator.py @@ -30,11 +30,9 @@ from raydp.spark import spark_dataframe_to_ray_dataset, get_raydp_master_owner from raydp.spark.dataset import read_spark_parquet from raydp.torch.config import TorchConfig -from ray import train -from ray.train import Checkpoint -from ray.train.torch import TorchTrainer, TorchCheckpoint -from ray.air.config import ScalingConfig, RunConfig, FailureConfig -from ray.air import session +import ray.train +from ray.train import Checkpoint, ScalingConfig, RunConfig, FailureConfig +from ray.train.torch import TorchTrainer from ray.data.dataset import Dataset from ray.tune.search.sample import Domain @@ -223,7 +221,7 @@ def train_func(config): metrics = config["metrics"] # create dataset - train_data_shard = session.get_dataset_shard("train") + train_data_shard = ray.train.get_dataset_shard("train") train_dataset = train_data_shard.to_torch(feature_columns=config["feature_columns"], feature_column_dtypes=config["feature_types"], label_column=config["label_column"], @@ -231,7 +229,7 @@ def train_func(config): batch_size=config["batch_size"], drop_last=config["drop_last"]) if config["evaluate"]: - evaluate_data_shard = session.get_dataset_shard("evaluate") + evaluate_data_shard = ray.train.get_dataset_shard("evaluate") evaluate_dataset = evaluate_data_shard.to_torch( feature_columns=config["feature_columns"], label_column=config["label_column"], @@ -240,16 +238,16 @@ def train_func(config): batch_size=config["batch_size"], drop_last=config["drop_last"]) - model = train.torch.prepare_model(model) + model = ray.train.torch.prepare_model(model) loss_results = [] for epoch in range(config["num_epochs"]): train_res, train_loss = TorchEstimator.train_epoch(train_dataset, model, loss, optimizer, metrics, lr_scheduler) - session.report(dict(epoch=epoch, train_res=train_res, train_loss=train_loss)) + ray.train.report(dict(epoch=epoch, train_res=train_res, train_loss=train_loss)) if config["evaluate"]: eval_res, evaluate_loss = TorchEstimator.evaluate_epoch(evaluate_dataset, model, loss, metrics) - session.report(dict(epoch=epoch, eval_res=eval_res, test_loss=evaluate_loss)) + ray.train.report(dict(epoch=epoch, eval_res=eval_res, test_loss=evaluate_loss)) loss_results.append(evaluate_loss) if hasattr(model, "module"): states = model.module.state_dict() @@ -260,14 +258,14 @@ def train_func(config): checkpoint = None # In standard DDP training, where the model is the same across all ranks, # only the global rank 0 worker needs to save and report the checkpoint - if train.get_context().get_world_rank() == 0: + if ray.train.get_context().get_world_rank() == 0: torch.save( states, os.path.join(temp_checkpoint_dir, "model.pt"), ) checkpoint = Checkpoint.from_directory(temp_checkpoint_dir) - session.report({}, checkpoint=checkpoint) + ray.train.report({}, checkpoint=checkpoint) @staticmethod def train_epoch(dataset, model, criterion, optimizer, metrics, scheduler=None): @@ -391,6 +389,14 @@ def fit_on_spark(self, def get_model(self): assert self._trainer is not None, "Must call fit first" - return TorchCheckpoint( - self._trained_results.checkpoint.to_directory() - ).get_model(self._model) + checkpoint_dir = self._trained_results.checkpoint.to_directory() + state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"), weights_only=True) + if isinstance(self._model, torch.nn.Module): + self._model.load_state_dict(state_dict) + return self._model + elif callable(self._model): + model = self._model({}) + model.load_state_dict(state_dict) + return model + else: + raise ValueError("Cannot load model: unsupported model type") diff --git a/python/raydp/xgboost/estimator.py b/python/raydp/xgboost/estimator.py index 21f03c75..3ecc0a37 100644 --- a/python/raydp/xgboost/estimator.py +++ b/python/raydp/xgboost/estimator.py @@ -15,16 +15,21 @@ # limitations under the License. # +import os from typing import Any, Callable, List, NoReturn, Optional, Union, Dict +import pandas as pd +import xgboost +import ray.train +from ray.train import ScalingConfig, RunConfig, FailureConfig, CheckpointConfig, Checkpoint +from ray.train.xgboost import XGBoostTrainer, RayTrainReportCallback +from ray.data.dataset import Dataset + from raydp.estimator import EstimatorInterface from raydp.spark.interfaces import SparkEstimatorInterface, DF, OPTIONAL_DF from raydp import stop_spark from raydp.spark import spark_dataframe_to_ray_dataset, get_raydp_master_owner from raydp.spark.dataset import read_spark_parquet -from ray.air.config import ScalingConfig, RunConfig, FailureConfig, CheckpointConfig -from ray.data.dataset import Dataset -from ray.train.xgboost import XGBoostTrainer, XGBoostCheckpoint class XGBoostEstimator(EstimatorInterface, SparkEstimatorInterface): def __init__(self, @@ -55,13 +60,24 @@ def fit(self, train_ds: Dataset, evaluate_ds: Optional[Dataset] = None, max_retries=3) -> NoReturn: + label_column = self._label_column + xgboost_params = self._xgboost_params + + def train_loop_per_worker(config): + train_shard = ray.train.get_dataset_shard("train") + train_df = pd.concat(list(train_shard.iter_batches(batch_format="pandas"))) + train_y = train_df.pop(label_column) + dtrain = xgboost.DMatrix(train_df, label=train_y) + xgboost.train( + xgboost_params, + dtrain, + callbacks=[RayTrainReportCallback()], + ) + scaling_config = ScalingConfig(num_workers=self._num_workers, resources_per_worker=self._resources_per_worker) run_config = RunConfig( checkpoint_config=CheckpointConfig( - # Checkpoint every iteration. - checkpoint_frequency=1, - # Only keep the latest checkpoint and delete the others. num_to_keep=1, ), failure_config=FailureConfig(max_failures=max_retries) @@ -73,11 +89,10 @@ def fit(self, datasets = {"train": train_ds} if evaluate_ds: datasets["evaluate"] = evaluate_ds - trainer = XGBoostTrainer(scaling_config=scaling_config, - datasets=datasets, - label_column=self._label_column, - params=self._xgboost_params, - run_config=run_config) + trainer = XGBoostTrainer(train_loop_per_worker, + scaling_config=scaling_config, + datasets=datasets, + run_config=run_config) self._results = trainer.fit() def fit_on_spark(self, @@ -116,4 +131,7 @@ def fit_on_spark(self, train_ds, evaluate_ds, max_retries) def get_model(self): - return XGBoostTrainer.get_model(self._results.checkpoint) + checkpoint_dir = self._results.checkpoint.to_directory() + model = xgboost.Booster() + model.load_model(os.path.join(checkpoint_dir, RayTrainReportCallback.CHECKPOINT_NAME)) + return model