diff --git a/.github/workflows/raydp.yml b/.github/workflows/raydp.yml
index a24746b9..3e70132a 100644
--- a/.github/workflows/raydp.yml
+++ b/.github/workflows/raydp.yml
@@ -32,76 +32,58 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
- python-version: [3.9, 3.10.14]
- spark-version: [3.3.2, 3.4.0, 3.5.0]
- ray-version: [2.37.0, 2.40.0, 2.50.0]
+ python-version: ["3.10", "3.12"]
+ spark-version: [4.1.1]
+ ray-version: [2.53.0]
runs-on: ${{ matrix.os }}
steps:
- - uses: actions/checkout@ee0669bd1cc54295c223e0bb666b733df41de1c5 # v2.7.0
+ - uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e # v2.3.4
+ uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Set up JDK 17
- uses: actions/setup-java@v5
+ uses: actions/setup-java@v4
with:
java-version: 17
distribution: 'corretto'
- - name: Install extra dependencies for macOS
- if: matrix.os == 'macos-latest'
- run: |
- brew install pkg-config
- brew install libuv libomp mpich
- name: Install extra dependencies for Ubuntu
if: matrix.os == 'ubuntu-latest'
run: |
sudo apt-get install -y mpich
- - name: Cache pip - Ubuntu
- if: matrix.os == 'ubuntu-latest'
- uses: actions/cache@8492260343ad570701412c2f464a5877dc76bace # v2
+ - name: Cache pip
+ uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ matrix.os }}-${{ matrix.python-version }}-pip
- - name: Cache pip - MacOS
- if: matrix.os == 'macos-latest'
- uses: actions/cache@8492260343ad570701412c2f464a5877dc76bace # v2
- with:
- path: ~/Library/Caches/pip
- key: ${{ matrix.os }}-${{ matrix.python-version }}-pip
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install wheel
- pip install "numpy<1.24" "click<8.3.0"
- pip install "pydantic<2.0"
- SUBVERSION=$(python -c 'import sys; print(sys.version_info[1])')
+ pip install wheel build
if [ "$(uname -s)" == "Linux" ]
then
pip install torch --index-url https://download.pytorch.org/whl/cpu
else
pip install torch
fi
- pip install pyarrow "ray[train,default]==${{ matrix.ray-version }}" tqdm pytest tensorflow==2.13.1 tabulate grpcio-tools wget
- pip install "xgboost_ray[default]<=0.1.13"
- pip install "xgboost<=2.0.3"
- pip install torchmetrics
+ pip install "pandas>=2.2,<3" pyarrow "ray[train,default]==${{ matrix.ray-version }}" tqdm pytest tabulate grpcio-tools wget
+ pip install torchmetrics xgboost
+ pip install tensorflow-cpu
- name: Cache Maven
- uses: actions/cache@8492260343ad570701412c2f464a5877dc76bace # v2
+ uses: actions/cache@v4
with:
path: ~/.m2
key: ${{ matrix.os }}-m2-${{ hashFiles('core/pom.xml') }}
- name: Build and install
- env:
- GITHUB_CI: 1
run: |
pip install pyspark==${{ matrix.spark-version }}
./build.sh
pip install dist/raydp-*.whl
- name: Lint
run: |
- pip install pylint==2.8.3
+ pip install pylint==3.3.6
pylint --rcfile=python/pylintrc python/raydp
pylint --rcfile=python/pylintrc examples/*.py
- name: Test with pytest
@@ -115,8 +97,4 @@ jobs:
python examples/raydp-submit.py
python examples/test_raydp_submit_pyfiles.py
ray stop
- python examples/pytorch_nyctaxi.py
- python examples/tensorflow_nyctaxi.py
- python examples/xgboost_ray_nyctaxi.py
- # python examples/raytrain_nyctaxi.py
python examples/data_process.py
diff --git a/.gitignore b/.gitignore
index 571df90a..20115429 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,3 +26,10 @@ _SUCCESS
.metals/
.bloop/
+
+.venv/
+
+# Generated by python/setup.py during build
+python/raydp/jars/__init__.py
+python/raydp/bin/__init__.py
+python/raydp/bin/raydp-submit
diff --git a/README.md b/README.md
index 4a780a78..f862ee60 100644
--- a/README.md
+++ b/README.md
@@ -37,7 +37,13 @@ RayDP supports Ray as a Spark resource manager and runs Spark executors in Ray a
## Installation
-You can install latest RayDP using pip. RayDP requires Ray and PySpark. Please also make sure java is installed and JAVA_HOME is set properly.
+RayDP 2.0 requires **Java 17**, **Spark 4.1+**, and **Ray 2.53+**.
+
+**Prerequisites:**
+1. Install **Java 17**.
+2. Set `JAVA_HOME` to your Java 17 installation.
+
+You can install the latest RayDP using pip:
```shell
pip install raydp
diff --git a/build.sh b/build.sh
index 9644adf3..c19e4729 100755
--- a/build.sh
+++ b/build.sh
@@ -39,12 +39,7 @@ fi
# build core part
CORE_DIR="${CURRENT_DIR}/core"
pushd ${CORE_DIR}
-if [[ -z $GITHUB_CI ]];
-then
- mvn clean package -q -DskipTests
-else
- mvn verify -q
-fi
+mvn clean package -q -DskipTests
popd # core dir
# build python part
diff --git a/core/agent/pom.xml b/core/agent/pom.xml
index cba5c222..2f263b09 100644
--- a/core/agent/pom.xml
+++ b/core/agent/pom.xml
@@ -7,7 +7,7 @@
com.intel
raydp-parent
- 1.7.0-SNAPSHOT
+ 2.0.0-SNAPSHOT
../pom.xml
diff --git a/core/pom.xml b/core/pom.xml
index ebe301c6..e0e279e8 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -6,32 +6,28 @@
com.intel
raydp-parent
- 1.7.0-SNAPSHOT
+ 2.0.0-SNAPSHOT
pom
RayDP Parent Pom
https://github.com/ray-project/raydp.git
- 3.3.3
- 3.2.2
- 3.3.0
- 3.4.0
- 3.5.0
- 1.1.10.4
- 4.1.94.Final
- 1.10.0
- 1.26.0
+ 2.47.1
+ 4.1.1
+ 1.1.10.5
+ 4.1.108.Final
+ 1.12.0
+ 1.26.1
1.7.14.1
3.25.5
2.5.2
UTF-8
UTF-8
- 1.8
- 1.8
- 2.12.15
- 2.15.0
- 2.12
+ 17
+ 2.13.17
+ 2.17.0
+ 2.13
5.10.1
@@ -144,7 +140,7 @@
org.apache.commons
commons-lang3
- 3.18.0
+ 3.17.0
@@ -197,6 +193,24 @@
+
+
+
+ org.scalastyle
+ scalastyle-maven-plugin
+ 1.0.0
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+
+ ${java.version}
+ ${java.version}
+ ${java.version}
+
+
+
+
org.apache.maven.plugins
diff --git a/core/raydp-main/pom.xml b/core/raydp-main/pom.xml
index 3c791a65..d95f634d 100644
--- a/core/raydp-main/pom.xml
+++ b/core/raydp-main/pom.xml
@@ -7,17 +7,13 @@
com.intel
raydp-parent
- 1.7.0-SNAPSHOT
+ 2.0.0-SNAPSHOT
../pom.xml
raydp
raydp
-
- 2.1.0
-
-
sonatype
@@ -128,7 +124,12 @@
org.apache.commons
commons-lang3
- 3.18.0
+
+
+
+ javax.ws.rs
+ javax.ws.rs-api
+ 2.1.1
@@ -184,10 +185,6 @@
org.apache.maven.plugins
maven-compiler-plugin
3.8.0
-
- 1.8
- 1.8
-
org.apache.maven.plugins
@@ -202,7 +199,7 @@
net.alchim31.maven
scala-maven-plugin
- 3.3.3
+ 4.8.1
scala-compile-first
@@ -228,7 +225,7 @@
org.apache.maven.plugins
maven-surefire-plugin
- 2.7
+ 3.2.5
diff --git a/core/raydp-main/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/raydp-main/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 829517e5..98d98685 100644
--- a/core/raydp-main/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/raydp-main/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -258,7 +258,6 @@ private[spark] class SparkSubmit extends Logging {
}
if (clusterManager == KUBERNETES) {
- args.master = Utils.checkAndGetK8sMasterUrl(args.master)
// Make sure KUBERNETES is included in our build if we're trying to use it
if (!Utils.classIsLoadable(KUBERNETES_CLUSTER_SUBMIT_CLASS) && !Utils.isTesting) {
error(
@@ -1044,7 +1043,7 @@ object SparkSubmit extends CommandLineUtils with Logging {
super.doSubmit(args)
} catch {
case e: SparkUserAppException =>
- exitFn(e.exitCode)
+ exitFn(e.exitCode, None)
}
}
diff --git a/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala
index f4cc823d..85520787 100644
--- a/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala
+++ b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala
@@ -268,7 +268,6 @@ class RayAppMaster(host: String,
s"{ CPU: $rayActorCPU, " +
s"${appInfo.desc.resourceReqsPerExecutor
.map { case (name, amount) => s"$name: $amount" }.mkString(", ")} }..")
- // TODO: Support generic fractional logical resources using prefix spark.ray.actor.resource.*
// This will check with dynamic auto scale no additional pending executor actor added more
// than max executors count as this result in executor even running after job completion
@@ -292,8 +291,6 @@ class RayAppMaster(host: String,
getAppMasterEndpointUrl(),
rayActorCPU,
memory,
- // This won't work, Spark expect integer in custom resources,
- // please see python test test_spark_on_fractional_custom_resource
appInfo.desc.resourceReqsPerExecutor
.map { case (name, amount) => (name, Double.box(amount)) }.asJava,
placementGroup,
diff --git a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala
index 0ed699dd..4f66edb7 100644
--- a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala
+++ b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala
@@ -324,29 +324,51 @@ class RayDPExecutor(
val env = SparkEnv.get
val context = SparkShimLoader.getSparkShims.getDummyTaskContext(partitionId, env)
TaskContext.setTaskContext(context)
- val schema = Schema.fromJSON(schemaStr)
- val blockId = BlockId.apply("rdd_" + rddId + "_" + partitionId)
- val iterator = env.blockManager.get(blockId)(classTag[Array[Byte]]) match {
- case Some(blockResult) =>
- blockResult.data.asInstanceOf[Iterator[Array[Byte]]]
- case None =>
- logWarning("The cached block has been lost. Cache it again via driver agent")
- requestRecacheRDD(rddId, driverAgentUrl)
- env.blockManager.get(blockId)(classTag[Array[Byte]]) match {
- case Some(blockResult) =>
- blockResult.data.asInstanceOf[Iterator[Array[Byte]]]
- case None =>
- throw new RayDPException("Still cannot get the block after recache!")
- }
+ val taskAttemptId = context.taskAttemptId()
+ env.blockManager.registerTask(taskAttemptId)
+ try {
+ val schema = Schema.fromJSON(schemaStr)
+ val blockId = BlockId.apply("rdd_" + rddId + "_" + partitionId)
+      val iterator = env.blockManager.get(blockId)(classTag[Array[Byte]]) match {
+        case Some(blockResult) =>
+          blockResult.data.asInstanceOf[Iterator[Array[Byte]]]
+        case None =>
+          logWarning("The cached block has been lost. Cache it again via driver agent")
+          requestRecacheRDD(rddId, driverAgentUrl)
+          env.blockManager.get(blockId)(classTag[Array[Byte]]) match {
+            case Some(blockResult) =>
+              blockResult.data.asInstanceOf[Iterator[Array[Byte]]]
+            case None =>
+              throw new RayDPException("Still cannot get the block after recache!")
+          }
+      }
+ // Collect batch byte arrays (references only; data already in memory from BlockManager)
+ val batches = iterator.toArray
+
+ // Serialize schema header and EOS marker into small temp buffers
+ val schemaBuf = new ByteArrayOutputStream(512)
+ MessageSerializer.serialize(new WriteChannel(Channels.newChannel(schemaBuf)), schema)
+ val schemaBytes = schemaBuf.toByteArray
+
+ val eosBuf = new ByteArrayOutputStream(16)
+ ArrowStreamWriter.writeEndOfStream(new WriteChannel(Channels.newChannel(eosBuf)), new IpcOption)
+ val eosBytes = eosBuf.toByteArray
+
+      // Single allocation: copy schema + batches + EOS directly (avoids BAOS + toByteArray double-copy).
+      val totalSize: Long = schemaBytes.length.toLong + batches.map(_.length.toLong).sum + eosBytes.length
+      val result = { require(totalSize <= Int.MaxValue, s"Arrow partition too large for one byte array: $totalSize bytes"); new Array[Byte](totalSize.toInt) }
+ var offset = 0
+ System.arraycopy(schemaBytes, 0, result, offset, schemaBytes.length)
+ offset += schemaBytes.length
+ batches.foreach { batch =>
+ System.arraycopy(batch, 0, result, offset, batch.length)
+ offset += batch.length
+ }
+ System.arraycopy(eosBytes, 0, result, offset, eosBytes.length)
+ result
+ } finally {
+ env.blockManager.releaseAllLocksForTask(taskAttemptId)
+ TaskContext.unset()
}
- val byteOut = new ByteArrayOutputStream()
- val writeChannel = new WriteChannel(Channels.newChannel(byteOut))
- MessageSerializer.serialize(writeChannel, schema)
- iterator.foreach(writeChannel.write)
- ArrowStreamWriter.writeEndOfStream(writeChannel, new IpcOption)
- val result = byteOut.toByteArray
- writeChannel.close
- byteOut.close
- result
}
}
diff --git a/core/raydp-main/src/main/scala/org/apache/spark/rdd/RayDatasetRDD.scala b/core/raydp-main/src/main/scala/org/apache/spark/rdd/RayDatasetRDD.scala
deleted file mode 100644
index 1992b9a3..00000000
--- a/core/raydp-main/src/main/scala/org/apache/spark/rdd/RayDatasetRDD.scala
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import java.util.List;
-
-import scala.collection.JavaConverters._
-
-import io.ray.runtime.generated.Common.Address
-
-import org.apache.spark.{Partition, SparkContext, TaskContext}
-import org.apache.spark.api.java.JavaSparkContext
-import org.apache.spark.raydp.RayDPUtils
-import org.apache.spark.sql.raydp.ObjectStoreReader
-
-private[spark] class RayDatasetRDDPartition(val ref: Array[Byte], idx: Int) extends Partition {
- val index = idx
-}
-
-private[spark]
-class RayDatasetRDD(
- jsc: JavaSparkContext,
- @transient val objectIds: List[Array[Byte]],
- locations: List[Array[Byte]])
- extends RDD[Array[Byte]](jsc.sc, Nil) {
-
- override def getPartitions: Array[Partition] = {
- objectIds.asScala.zipWithIndex.map { case (k, i) =>
- new RayDatasetRDDPartition(k, i).asInstanceOf[Partition]
- }.toArray
- }
-
- override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
- val ref = split.asInstanceOf[RayDatasetRDDPartition].ref
- ObjectStoreReader.getBatchesFromStream(ref, locations.get(split.index))
- }
-
- override def getPreferredLocations(split: Partition): Seq[String] = {
- val address = Address.parseFrom(locations.get(split.index))
- Seq(address.getIpAddress())
- }
-}
diff --git a/core/raydp-main/src/main/scala/org/apache/spark/scheduler/cluster/raydp/RayCoarseGrainedSchedulerBackend.scala b/core/raydp-main/src/main/scala/org/apache/spark/scheduler/cluster/raydp/RayCoarseGrainedSchedulerBackend.scala
index dce83a78..aa5fba47 100644
--- a/core/raydp-main/src/main/scala/org/apache/spark/scheduler/cluster/raydp/RayCoarseGrainedSchedulerBackend.scala
+++ b/core/raydp-main/src/main/scala/org/apache/spark/scheduler/cluster/raydp/RayCoarseGrainedSchedulerBackend.scala
@@ -21,7 +21,7 @@ import java.net.URI
import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
-import scala.collection.JavaConverters._
+import scala.jdk.CollectionConverters._
import scala.collection.mutable.HashMap
import scala.concurrent.Future
@@ -89,7 +89,7 @@ class RayCoarseGrainedSchedulerBackend(
val appMasterResources = conf.getAll.filter {
case (k, v) => k.startsWith(SparkOnRayConfigs.SPARK_MASTER_ACTOR_RESOURCE_PREFIX)
- }.map{ case (k, v) => k->double2Double(v.toDouble) }
+ }.map{ case (k, v) => k->Double.box(v.toDouble) }
masterHandle = RayAppMasterUtils.createAppMaster(cp, null, options.toBuffer.asJava,
appMasterResources.toMap.asJava)
@@ -154,14 +154,13 @@ class RayCoarseGrainedSchedulerBackend(
}
// Start executors with a few necessary configs for registering with the scheduler
- val sparkJavaOpts = Utils.sparkJavaOpts(conf, SparkConf.isExecutorStartupConf)
// add Xmx, it should not be set in java opts, because Spark is not allowed.
// We also add Xms to ensure the Xmx >= Xms
val memoryLimit = Seq(s"-Xms${sc.executorMemory}M", s"-Xmx${sc.executorMemory}M")
- val javaOpts = sparkJavaOpts ++ extraJavaOpts ++ memoryLimit ++ javaAgentOpt()
+ val javaOpts = extraJavaOpts ++ memoryLimit ++ javaAgentOpt()
- val command = Command(driverUrl, sc.executorEnvs,
+ val command = Command(driverUrl, sc.executorEnvs.toMap,
classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts)
val coresPerExecutor = conf.getOption(config.EXECUTOR_CORES.key).map(_.toInt)
@@ -250,7 +249,7 @@ class RayCoarseGrainedSchedulerBackend(
} catch {
case e: Exception =>
logWarning("Failed to connect to app master", e)
- stop()
+ RayCoarseGrainedSchedulerBackend.this.stop()
}
}
diff --git a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala
index 39a1c2b1..a88042c2 100644
--- a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala
+++ b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala
@@ -17,18 +17,10 @@
package org.apache.spark.sql.raydp
-import java.io.ByteArrayInputStream
-import java.nio.channels.{Channels, ReadableByteChannel}
import java.util.List
-import com.intel.raydp.shims.SparkShimLoader
-
-import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
-import org.apache.spark.raydp.RayDPUtils
-import org.apache.spark.rdd.{RayDatasetRDD, RayObjectRefRDD}
-import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
-import org.apache.spark.sql.catalyst.expressions.GenericRow
-import org.apache.spark.sql.execution.arrow.ArrowConverters
+import org.apache.spark.rdd.RayObjectRefRDD
+import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StructType}
object ObjectStoreReader {
@@ -39,19 +31,4 @@ object ObjectStoreReader {
val schema = new StructType().add("idx", IntegerType)
spark.createDataFrame(rdd, schema)
}
-
- def RayDatasetToDataFrame(
- sparkSession: SparkSession,
- rdd: RayDatasetRDD,
- schema: String): DataFrame = {
- SparkShimLoader.getSparkShims.toDataFrame(JavaRDD.fromRDD(rdd), schema, sparkSession)
- }
-
- def getBatchesFromStream(
- ref: Array[Byte],
- ownerAddress: Array[Byte]): Iterator[Array[Byte]] = {
- val objectRef = RayDPUtils.readBinary(ref, classOf[Array[Byte]], ownerAddress)
- ArrowConverters.getBatchesFromStream(
- Channels.newChannel(new ByteArrayInputStream(objectRef.get)))
- }
}
diff --git a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala
index 7ff22660..0b34564b 100644
--- a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala
+++ b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala
@@ -18,50 +18,20 @@
package org.apache.spark.sql.raydp
import com.intel.raydp.shims.SparkShimLoader
-import io.ray.api.{ActorHandle, ObjectRef, Ray}
-import io.ray.runtime.AbstractRayRuntime
-import java.util.{List, UUID}
-import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
-import java.util.function.{Function => JFunction}
+import io.ray.api.{ActorHandle, Ray}
import org.apache.arrow.vector.types.pojo.Schema
-import scala.collection.JavaConverters._
-import scala.collection.mutable
-import org.apache.spark.{RayDPException, SparkContext}
+import org.apache.spark.{RayDPException, SparkContext, SparkEnv}
import org.apache.spark.deploy.raydp._
import org.apache.spark.executor.RayDPExecutor
-import org.apache.spark.network.util.JavaUtils
-import org.apache.spark.raydp.{RayDPUtils, RayExecutorUtils}
+import org.apache.spark.raydp.RayExecutorUtils
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.storage.StorageLevel
-
-class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
-
- val uuid: UUID = ObjectStoreWriter.dfToId.getOrElseUpdate(df, UUID.randomUUID())
-
- /**
- * For test.
- */
- def getRandomRef(): List[Array[Byte]] = {
-
- df.queryExecution.toRdd.mapPartitions { _ =>
- Iterator(ObjectRefHolder.getRandom(uuid))
- }.collect().toSeq.asJava
- }
-
- def clean(): Unit = {
- ObjectStoreWriter.dfToId.remove(df)
- ObjectRefHolder.removeQueue(uuid)
- }
-
-}
+import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel}
object ObjectStoreWriter {
- val dfToId = new mutable.HashMap[DataFrame, UUID]()
var driverAgent: RayDPDriverAgent = _
var driverAgentUrl: String = _
- var address: Array[Byte] = null
def connectToRay(): Unit = {
if (!Ray.isInitialized) {
@@ -73,80 +43,11 @@ object ObjectStoreWriter {
}
}
- private def parseMemoryBytes(value: String): Double = {
- if (value == null || value.isEmpty) {
- 0.0
- } else {
- // Spark parser supports both plain numbers (bytes) and strings like "100M", "2g".
- JavaUtils.byteStringAsBytes(value).toDouble
- }
- }
-
- def getAddress(): Array[Byte] = {
- if (address == null) {
- val objectRef = Ray.put(1)
- val objectRefImpl = RayDPUtils.convert(objectRef)
- val objectId = objectRefImpl.getId
- val runtime = Ray.internal.asInstanceOf[AbstractRayRuntime]
- address = runtime.getObjectStore.getOwnershipInfo(objectId)
- }
- address
- }
-
def toArrowSchema(df: DataFrame): Schema = {
val conf = df.queryExecution.sparkSession.sessionState.conf
val timeZoneId = conf.getConf(SQLConf.SESSION_LOCAL_TIMEZONE)
- SparkShimLoader.getSparkShims.toArrowSchema(df.schema, timeZoneId)
- }
-
- @deprecated
- def fromSparkRDD(df: DataFrame, storageLevel: StorageLevel): Array[Array[Byte]] = {
- if (!Ray.isInitialized) {
- throw new RayDPException(
- "Not yet connected to Ray! Please set fault_tolerant_mode=True when starting RayDP.")
- }
- val uuid = dfToId.getOrElseUpdate(df, UUID.randomUUID())
- val queue = ObjectRefHolder.getQueue(uuid)
- val rdd = df.toArrowBatchRdd
- rdd.persist(storageLevel)
- rdd.count()
- var executorIds = df.sqlContext.sparkContext.getExecutorIds.toArray
- val numExecutors = executorIds.length
- val appMasterHandle = Ray.getActor(RayAppMaster.ACTOR_NAME)
- .get.asInstanceOf[ActorHandle[RayAppMaster]]
- val restartedExecutors = RayAppMasterUtils.getRestartedExecutors(appMasterHandle)
- // Check if there is any restarted executors
- if (!restartedExecutors.isEmpty) {
- // If present, need to use the old id to find ray actors
- for (i <- 0 until numExecutors) {
- if (restartedExecutors.containsKey(executorIds(i))) {
- val oldId = restartedExecutors.get(executorIds(i))
- executorIds(i) = oldId
- }
- }
- }
- val schema = ObjectStoreWriter.toArrowSchema(df).toJson
- val numPartitions = rdd.getNumPartitions
- val results = new Array[Array[Byte]](numPartitions)
- val refs = new Array[ObjectRef[Array[Byte]]](numPartitions)
- val handles = executorIds.map {id =>
- Ray.getActor("raydp-executor-" + id)
- .get
- .asInstanceOf[ActorHandle[RayDPExecutor]]
- }
- val handlesMap = (executorIds zip handles).toMap
- val locations = RayExecutorUtils.getBlockLocations(
- handles(0), rdd.id, numPartitions)
- for (i <- 0 until numPartitions) {
- // TODO use getPreferredLocs, but we don't have a host ip to actor table now
- refs(i) = RayExecutorUtils.getRDDPartition(
- handlesMap(locations(i)), rdd.id, i, schema, driverAgentUrl)
- queue.add(refs(i))
- }
- for (i <- 0 until numPartitions) {
- results(i) = RayDPUtils.convert(refs(i)).getId.getBytes
- }
- results
+ val largeVarTypes = conf.arrowUseLargeVarTypes
+ SparkShimLoader.getSparkShims.toArrowSchema(df.schema, timeZoneId, largeVarTypes)
}
/**
@@ -167,11 +68,11 @@ object ObjectStoreWriter {
"Not yet connected to Ray! Please set fault_tolerant_mode=True when starting RayDP.")
}
- val rdd = df.toArrowBatchRdd
+ val rdd = SparkShimLoader.getSparkShims.toArrowBatchRdd(df)
rdd.persist(storageLevel)
rdd.count()
- var executorIds = df.sqlContext.sparkContext.getExecutorIds.toArray
+ var executorIds = df.sparkSession.sparkContext.getExecutorIds.toArray
val numExecutors = executorIds.length
val appMasterHandle = Ray.getActor(RayAppMaster.ACTOR_NAME)
.get.asInstanceOf[ActorHandle[RayAppMaster]]
@@ -198,6 +99,51 @@ object ObjectStoreWriter {
RecoverableRDDInfo(rdd.id, numPartitions, schemaJson, driverAgentUrl, locations)
}
+ /**
+ * Streaming variant: starts materialization in a background thread and returns
+ * a handle that Python can poll for completed partitions. This lets Ray fetch
+ * tasks overlap with Spark partition computation instead of blocking on rdd.count().
+ */
+ def startStreamingRecoverableRDD(
+ df: DataFrame,
+ storageLevel: StorageLevel): StreamingRecoverableRDD = {
+ if (!Ray.isInitialized) {
+ throw new RayDPException(
+ "Not yet connected to Ray! Please set fault_tolerant_mode=True when starting RayDP.")
+ }
+
+ val rdd = SparkShimLoader.getSparkShims.toArrowBatchRdd(df)
+ rdd.persist(storageLevel)
+
+ val appMasterHandle = Ray.getActor(RayAppMaster.ACTOR_NAME)
+ .get.asInstanceOf[ActorHandle[RayAppMaster]]
+ val restartedExecutors = RayAppMasterUtils.getRestartedExecutors(appMasterHandle)
+
+ val schemaJson = ObjectStoreWriter.toArrowSchema(df).toJson
+ val numPartitions = rdd.getNumPartitions
+
+ val handle = new StreamingRecoverableRDD(
+ rdd.id, numPartitions, schemaJson, driverAgentUrl,
+ restartedExecutors, SparkEnv.get)
+
+ // Start materialization in background — partitions become visible via getReadyPartitions()
+ val thread = new Thread("raydp-materialize-" + rdd.id) {
+ setDaemon(true)
+ override def run(): Unit = {
+ try {
+ rdd.count()
+ } catch {
+ case e: Throwable => handle.setError(e)
+ } finally {
+ handle.setComplete()
+ }
+ }
+ }
+ thread.start()
+
+ handle
+ }
+
}
case class RecoverableRDDInfo(
@@ -212,43 +158,50 @@ object RecoverableRDDInfo {
def empty: RecoverableRDDInfo = RecoverableRDDInfo(0, 0, "", "", Array.empty[String])
}
-object ObjectRefHolder {
- type Queue = ConcurrentLinkedQueue[ObjectRef[Array[Byte]]]
- private val dfToQueue = new ConcurrentHashMap[UUID, Queue]()
-
- def getQueue(df: UUID): Queue = {
- dfToQueue.computeIfAbsent(df, new JFunction[UUID, Queue] {
- override def apply(v1: UUID): Queue = {
- new Queue()
- }
- })
- }
+/**
+ * Handle returned by [[ObjectStoreWriter.startStreamingRecoverableRDD]].
+ * Python polls [[getReadyPartitions]] to discover which partitions have been
+ * materialized in Spark's BlockManager, then immediately submits Ray fetch
+ * tasks for those partitions — overlapping Spark computation with Ray transfer.
+ */
+class StreamingRecoverableRDD(
+ val rddId: Int,
+ val numPartitions: Int,
+ val schemaJson: String,
+ val driverAgentUrl: String,
+ private val restartedExecutors: java.util.Map[String, String],
+ private val env: SparkEnv) {
- @inline
- def checkQueueExists(df: UUID): Queue = {
- val queue = dfToQueue.get(df)
- if (queue == null) {
- throw new RuntimeException("The DataFrame does not exist")
- }
- queue
- }
+ @volatile private var _error: Throwable = _
+ @volatile private var _complete: Boolean = false
- def getQueueSize(df: UUID): Int = {
- val queue = checkQueueExists(df)
- queue.size()
- }
+ private val blockIds: Array[BlockId] = (0 until numPartitions).map(i =>
+ BlockId.apply("rdd_" + rddId + "_" + i)
+ ).toArray
- def getRandom(df: UUID): Array[Byte] = {
- val queue = checkQueueExists(df)
- val ref = RayDPUtils.convert(queue.peek())
- ref.get()
- }
+ def setError(e: Throwable): Unit = { _error = e }
+ def setComplete(): Unit = { _complete = true }
- def removeQueue(df: UUID): Unit = {
- dfToQueue.remove(df)
- }
+ def isComplete: Boolean = _complete
+ def getError: String = if (_error != null) _error.getMessage else null
- def clean(): Unit = {
- dfToQueue.clear()
+ /**
+ * Returns an Array[String] of length numPartitions.
+ * For materialized partitions: the (mapped) executor ID suitable for Ray actor lookup.
+ * For not-yet-ready partitions: null.
+ */
+ def getReadyPartitions(): Array[String] = {
+ val locations = BlockManager.blockIdsToLocations(blockIds, env)
+ val result = new Array[String](numPartitions)
+ for ((key, value) <- locations if value.nonEmpty) {
+ val partitionId = key.name.substring(key.name.lastIndexOf('_') + 1).toInt
+ var executorId = value(0).substring(value(0).lastIndexOf('_') + 1)
+ if (restartedExecutors.containsKey(executorId)) {
+ executorId = restartedExecutors.get(executorId)
+ }
+ result(partitionId) = executorId
+ }
+ result
}
}
+
diff --git a/core/shims/common/pom.xml b/core/shims/common/pom.xml
index f4f8dcc4..b6c4ca26 100644
--- a/core/shims/common/pom.xml
+++ b/core/shims/common/pom.xml
@@ -7,13 +7,13 @@
com.intel
raydp-shims
- 1.7.0-SNAPSHOT
+ 2.0.0-SNAPSHOT
../pom.xml
raydp-shims-common
RayDP Shims Common
- 1.7.0-SNAPSHOT
+ 2.0.0-SNAPSHOT
jar
@@ -25,7 +25,7 @@
net.alchim31.maven
scala-maven-plugin
- 3.2.2
+ ${scala.plugin.version}
scala-compile-first
diff --git a/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala
index 2ca83522..31d507cf 100644
--- a/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala
+++ b/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala
@@ -21,6 +21,7 @@ import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.executor.RayDPExecutorBackendFactory
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SparkSession}
@@ -39,5 +40,7 @@ trait SparkShims {
def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext
- def toArrowSchema(schema : StructType, timeZoneId : String) : Schema
+ def toArrowSchema(schema : StructType, timeZoneId : String, largeVarTypes : Boolean = false) : Schema
+
+ def toArrowBatchRdd(df: DataFrame): RDD[Array[Byte]]
}
diff --git a/core/shims/pom.xml b/core/shims/pom.xml
index c013538b..ca625c7c 100644
--- a/core/shims/pom.xml
+++ b/core/shims/pom.xml
@@ -7,7 +7,7 @@
com.intel
raydp-parent
- 1.7.0-SNAPSHOT
+ 2.0.0-SNAPSHOT
../pom.xml
@@ -17,15 +17,12 @@
common
- spark322
- spark330
- spark340
- spark350
+ spark411
- 2.12
- 4.3.0
+ 2.13
+ 4.8.1
3.2.2
diff --git a/core/shims/spark322/pom.xml b/core/shims/spark322/pom.xml
deleted file mode 100644
index faff6ac5..00000000
--- a/core/shims/spark322/pom.xml
+++ /dev/null
@@ -1,132 +0,0 @@
-
-
-
- 4.0.0
-
-
- com.intel
- raydp-shims
- 1.7.0-SNAPSHOT
- ../pom.xml
-
-
- raydp-shims-spark322
- RayDP Shims for Spark 3.2.2
- jar
-
-
- 2.12.15
- 2.13.5
-
-
-
-
-
- org.scalastyle
- scalastyle-maven-plugin
-
-
- net.alchim31.maven
- scala-maven-plugin
- 3.2.2
-
-
- scala-compile-first
- process-resources
-
- compile
-
-
-
- scala-test-compile-first
- process-test-resources
-
- testCompile
-
-
-
-
-
-
-
-
- src/main/resources
-
-
-
-
-
-
- com.intel
- raydp-shims-common
- ${project.version}
- compile
-
-
- org.apache.spark
- spark-sql_${scala.binary.version}
- ${spark322.version}
- provided
-
-
- com.google.protobuf
- protobuf-java
-
-
-
-
- org.apache.spark
- spark-core_${scala.binary.version}
- ${spark322.version}
- provided
-
-
- org.xerial.snappy
- snappy-java
-
-
- org.apache.commons
- commons-compress
-
-
- org.apache.commons
- commons-text
-
-
- org.apache.ivy
- ivy
-
-
- log4j
- log4j
-
-
-
-
- org.xerial.snappy
- snappy-java
- ${snappy.version}
-
-
- org.apache.commons
- commons-compress
- ${commons.compress.version}
-
-
- org.apache.commons
- commons-text
- ${commons.text.version}
-
-
- org.apache.ivy
- ivy
- ${ivy.version}
-
-
- com.google.protobuf
- protobuf-java
- ${protobuf.version}
-
-
-
diff --git a/core/shims/spark322/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider b/core/shims/spark322/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider
deleted file mode 100644
index 0ce0b134..00000000
--- a/core/shims/spark322/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider
+++ /dev/null
@@ -1 +0,0 @@
-com.intel.raydp.shims.spark322.SparkShimProvider
diff --git a/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala b/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala
deleted file mode 100644
index db53a74a..00000000
--- a/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.intel.raydp.shims.spark322
-
-import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor}
-
-object SparkShimProvider {
- val SPARK311_DESCRIPTOR = SparkShimDescriptor(3, 1, 1)
- val SPARK312_DESCRIPTOR = SparkShimDescriptor(3, 1, 2)
- val SPARK313_DESCRIPTOR = SparkShimDescriptor(3, 1, 3)
- val SPARK320_DESCRIPTOR = SparkShimDescriptor(3, 2, 0)
- val SPARK321_DESCRIPTOR = SparkShimDescriptor(3, 2, 1)
- val SPARK322_DESCRIPTOR = SparkShimDescriptor(3, 2, 2)
- val SPARK323_DESCRIPTOR = SparkShimDescriptor(3, 2, 3)
- val SPARK324_DESCRIPTOR = SparkShimDescriptor(3, 2, 4)
- val DESCRIPTOR_STRINGS =
- Seq(s"$SPARK311_DESCRIPTOR", s"$SPARK312_DESCRIPTOR" ,s"$SPARK313_DESCRIPTOR",
- s"$SPARK320_DESCRIPTOR", s"$SPARK321_DESCRIPTOR", s"$SPARK322_DESCRIPTOR",
- s"$SPARK323_DESCRIPTOR", s"$SPARK324_DESCRIPTOR")
- val DESCRIPTOR = SPARK323_DESCRIPTOR
-}
-
-class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider {
- def createShim: SparkShims = {
- new Spark322Shims()
- }
-
- def matches(version: String): Boolean = {
- SparkShimProvider.DESCRIPTOR_STRINGS.contains(version)
- }
-}
diff --git a/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala
deleted file mode 100644
index 6ea817db..00000000
--- a/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.intel.raydp.shims.spark322
-
-import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.executor.RayDPExecutorBackendFactory
-import org.apache.spark.executor.spark322._
-import org.apache.spark.spark322.TaskContextUtils
-import org.apache.spark.sql.{DataFrame, SparkSession}
-import org.apache.spark.sql.spark322.SparkSqlUtils
-import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
-import org.apache.arrow.vector.types.pojo.Schema
-import org.apache.spark.sql.types.StructType
-
-class Spark322Shims extends SparkShims {
- override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
-
- override def toDataFrame(
- rdd: JavaRDD[Array[Byte]],
- schema: String,
- session: SparkSession): DataFrame = {
- SparkSqlUtils.toDataFrame(rdd, schema, session)
- }
-
- override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = {
- new RayDPSpark322ExecutorBackendFactory()
- }
-
- override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
- TaskContextUtils.getDummyTaskContext(partitionId, env)
- }
-
- override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
- SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
- }
-}
diff --git a/core/shims/spark322/src/main/scala/org/apache/spark/TaskContextUtils.scala b/core/shims/spark322/src/main/scala/org/apache/spark/TaskContextUtils.scala
deleted file mode 100644
index d658cf98..00000000
--- a/core/shims/spark322/src/main/scala/org/apache/spark/TaskContextUtils.scala
+++ /dev/null
@@ -1,30 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.spark322
-
-import java.util.Properties
-
-import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}
-import org.apache.spark.memory.TaskMemoryManager
-
-object TaskContextUtils {
- def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
- new TaskContextImpl(0, 0, partitionId, -1024, 0,
- new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem)
- }
-}
diff --git a/core/shims/spark322/src/main/scala/org/apache/spark/executor/RayDPSpark322ExecutorBackendFactory.scala b/core/shims/spark322/src/main/scala/org/apache/spark/executor/RayDPSpark322ExecutorBackendFactory.scala
deleted file mode 100644
index d8673679..00000000
--- a/core/shims/spark322/src/main/scala/org/apache/spark/executor/RayDPSpark322ExecutorBackendFactory.scala
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.executor.spark322
-
-import java.net.URL
-
-import org.apache.spark.SparkEnv
-import org.apache.spark.executor.CoarseGrainedExecutorBackend
-import org.apache.spark.executor.RayDPExecutorBackendFactory
-import org.apache.spark.resource.ResourceProfile
-import org.apache.spark.rpc.RpcEnv
-
-class RayDPSpark322ExecutorBackendFactory
- extends RayDPExecutorBackendFactory {
- override def createExecutorBackend(
- rpcEnv: RpcEnv,
- driverUrl: String,
- executorId: String,
- bindAddress: String,
- hostname: String,
- cores: Int,
- userClassPath: Seq[URL],
- env: SparkEnv,
- resourcesFileOpt: Option[String],
- resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = {
- new CoarseGrainedExecutorBackend(
- rpcEnv,
- driverUrl,
- executorId,
- bindAddress,
- hostname,
- cores,
- userClassPath,
- env,
- resourcesFileOpt,
- resourceProfile)
- }
-}
diff --git a/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
deleted file mode 100644
index be9b409c..00000000
--- a/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.spark322
-
-import org.apache.arrow.vector.types.pojo.Schema
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
-import org.apache.spark.sql.execution.arrow.ArrowConverters
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.ArrowUtils
-
-object SparkSqlUtils {
- def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = {
- ArrowConverters.toDataFrame(rdd, schema, new SQLContext(session))
- }
-
- def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
- ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
- }
-}
diff --git a/core/shims/spark330/pom.xml b/core/shims/spark330/pom.xml
deleted file mode 100644
index 4443f658..00000000
--- a/core/shims/spark330/pom.xml
+++ /dev/null
@@ -1,128 +0,0 @@
-
-
-
- 4.0.0
-
-
- com.intel
- raydp-shims
- 1.7.0-SNAPSHOT
- ../pom.xml
-
-
- raydp-shims-spark330
- RayDP Shims for Spark 3.3.0
- jar
-
-
- 2.12.15
- 2.13.5
-
-
-
-
-
- org.scalastyle
- scalastyle-maven-plugin
-
-
- net.alchim31.maven
- scala-maven-plugin
- 3.2.2
-
-
- scala-compile-first
- process-resources
-
- compile
-
-
-
- scala-test-compile-first
- process-test-resources
-
- testCompile
-
-
-
-
-
-
-
-
- src/main/resources
-
-
-
-
-
-
- com.intel
- raydp-shims-common
- ${project.version}
- compile
-
-
- org.apache.spark
- spark-sql_${scala.binary.version}
- ${spark330.version}
- provided
-
-
- com.google.protobuf
- protobuf-java
-
-
-
-
- org.apache.spark
- spark-core_${scala.binary.version}
- ${spark330.version}
- provided
-
-
- org.xerial.snappy
- snappy-java
-
-
- io.netty
- netty-handler
-
-
- org.apache.commons
- commons-text
-
-
- org.apache.ivy
- ivy
-
-
-
-
- org.xerial.snappy
- snappy-java
- ${snappy.version}
-
-
- io.netty
- netty-handler
- ${netty.version}
-
-
- org.apache.commons
- commons-text
- ${commons.text.version}
-
-
- org.apache.ivy
- ivy
- ${ivy.version}
-
-
- com.google.protobuf
- protobuf-java
- ${protobuf.version}
-
-
-
diff --git a/core/shims/spark330/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider b/core/shims/spark330/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider
deleted file mode 100644
index 184e2dfa..00000000
--- a/core/shims/spark330/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider
+++ /dev/null
@@ -1 +0,0 @@
-com.intel.raydp.shims.spark330.SparkShimProvider
diff --git a/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala b/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala
deleted file mode 100644
index 7a5f5481..00000000
--- a/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.intel.raydp.shims.spark330
-
-import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor}
-
-object SparkShimProvider {
- val SPARK330_DESCRIPTOR = SparkShimDescriptor(3, 3, 0)
- val SPARK331_DESCRIPTOR = SparkShimDescriptor(3, 3, 1)
- val SPARK332_DESCRIPTOR = SparkShimDescriptor(3, 3, 2)
- val SPARK333_DESCRIPTOR = SparkShimDescriptor(3, 3, 3)
- val DESCRIPTOR_STRINGS = Seq(s"$SPARK330_DESCRIPTOR", s"$SPARK331_DESCRIPTOR",
- s"$SPARK332_DESCRIPTOR", s"$SPARK333_DESCRIPTOR")
- val DESCRIPTOR = SPARK332_DESCRIPTOR
-}
-
-class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider {
- def createShim: SparkShims = {
- new Spark330Shims()
- }
-
- def matches(version: String): Boolean = {
- SparkShimProvider.DESCRIPTOR_STRINGS.contains(version)
- }
-}
diff --git a/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala
deleted file mode 100644
index 4f1a50b5..00000000
--- a/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.intel.raydp.shims.spark330
-
-import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.executor.RayDPExecutorBackendFactory
-import org.apache.spark.executor.spark330._
-import org.apache.spark.spark330.TaskContextUtils
-import org.apache.spark.sql.{DataFrame, SparkSession}
-import org.apache.spark.sql.spark330.SparkSqlUtils
-import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
-import org.apache.arrow.vector.types.pojo.Schema
-import org.apache.spark.sql.types.StructType
-
-class Spark330Shims extends SparkShims {
- override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
-
- override def toDataFrame(
- rdd: JavaRDD[Array[Byte]],
- schema: String,
- session: SparkSession): DataFrame = {
- SparkSqlUtils.toDataFrame(rdd, schema, session)
- }
-
- override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = {
- new RayDPSpark330ExecutorBackendFactory()
- }
-
- override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
- TaskContextUtils.getDummyTaskContext(partitionId, env)
- }
-
- override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
- SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
- }
-}
diff --git a/core/shims/spark330/src/main/scala/org/apache/spark/executor/RayDPSpark330ExecutorBackendFactory.scala b/core/shims/spark330/src/main/scala/org/apache/spark/executor/RayDPSpark330ExecutorBackendFactory.scala
deleted file mode 100644
index 7f01e979..00000000
--- a/core/shims/spark330/src/main/scala/org/apache/spark/executor/RayDPSpark330ExecutorBackendFactory.scala
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.executor.spark330
-
-import java.net.URL
-
-import org.apache.spark.SparkEnv
-import org.apache.spark.executor._
-import org.apache.spark.resource.ResourceProfile
-import org.apache.spark.rpc.RpcEnv
-
-class RayDPSpark330ExecutorBackendFactory
- extends RayDPExecutorBackendFactory {
- override def createExecutorBackend(
- rpcEnv: RpcEnv,
- driverUrl: String,
- executorId: String,
- bindAddress: String,
- hostname: String,
- cores: Int,
- userClassPath: Seq[URL],
- env: SparkEnv,
- resourcesFileOpt: Option[String],
- resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = {
- new RayCoarseGrainedExecutorBackend(
- rpcEnv,
- driverUrl,
- executorId,
- bindAddress,
- hostname,
- cores,
- userClassPath,
- env,
- resourcesFileOpt,
- resourceProfile)
- }
-}
diff --git a/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
deleted file mode 100644
index 162371ad..00000000
--- a/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.spark330
-
-import org.apache.arrow.vector.types.pojo.Schema
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
-import org.apache.spark.sql.execution.arrow.ArrowConverters
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.ArrowUtils
-
-object SparkSqlUtils {
- def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = {
- ArrowConverters.toDataFrame(rdd, schema, session)
- }
-
- def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
- ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
- }
-}
diff --git a/core/shims/spark340/pom.xml b/core/shims/spark340/pom.xml
deleted file mode 100644
index 1b312747..00000000
--- a/core/shims/spark340/pom.xml
+++ /dev/null
@@ -1,99 +0,0 @@
-
-
-
- 4.0.0
-
-
- com.intel
- raydp-shims
- 1.7.0-SNAPSHOT
- ../pom.xml
-
-
- raydp-shims-spark340
- RayDP Shims for Spark 3.4.0
- jar
-
-
- 2.12.15
- 2.13.5
-
-
-
-
-
- org.scalastyle
- scalastyle-maven-plugin
-
-
- net.alchim31.maven
- scala-maven-plugin
- 3.2.2
-
-
- scala-compile-first
- process-resources
-
- compile
-
-
-
- scala-test-compile-first
- process-test-resources
-
- testCompile
-
-
-
-
-
-
-
-
- src/main/resources
-
-
-
-
-
-
- com.intel
- raydp-shims-common
- ${project.version}
- compile
-
-
- org.apache.spark
- spark-sql_${scala.binary.version}
- ${spark340.version}
- provided
-
-
- org.apache.spark
- spark-core_${scala.binary.version}
- ${spark340.version}
- provided
-
-
- org.xerial.snappy
- snappy-java
-
-
- io.netty
- netty-handler
-
-
-
-
- org.xerial.snappy
- snappy-java
- ${snappy.version}
-
-
- io.netty
- netty-handler
- ${netty.version}
-
-
-
diff --git a/core/shims/spark340/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider b/core/shims/spark340/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider
deleted file mode 100644
index 515c47a6..00000000
--- a/core/shims/spark340/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider
+++ /dev/null
@@ -1 +0,0 @@
-com.intel.raydp.shims.spark340.SparkShimProvider
diff --git a/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala b/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala
deleted file mode 100644
index 229263e2..00000000
--- a/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.intel.raydp.shims.spark340
-
-import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor}
-
-object SparkShimProvider {
- val SPARK340_DESCRIPTOR = SparkShimDescriptor(3, 4, 0)
- val SPARK341_DESCRIPTOR = SparkShimDescriptor(3, 4, 1)
- val SPARK342_DESCRIPTOR = SparkShimDescriptor(3, 4, 2)
- val SPARK343_DESCRIPTOR = SparkShimDescriptor(3, 4, 3)
- val SPARK344_DESCRIPTOR = SparkShimDescriptor(3, 4, 4)
- val DESCRIPTOR_STRINGS = Seq(s"$SPARK340_DESCRIPTOR", s"$SPARK341_DESCRIPTOR", s"$SPARK342_DESCRIPTOR",
- s"$SPARK343_DESCRIPTOR", s"$SPARK344_DESCRIPTOR")
- val DESCRIPTOR = SPARK341_DESCRIPTOR
-}
-
-class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider {
- def createShim: SparkShims = {
- new Spark340Shims()
- }
-
- def matches(version: String): Boolean = {
- SparkShimProvider.DESCRIPTOR_STRINGS.contains(version)
- }
-}
diff --git a/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala
deleted file mode 100644
index c444373f..00000000
--- a/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.intel.raydp.shims.spark340
-
-import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.executor.RayDPExecutorBackendFactory
-import org.apache.spark.executor.spark340._
-import org.apache.spark.spark340.TaskContextUtils
-import org.apache.spark.sql.{DataFrame, SparkSession}
-import org.apache.spark.sql.spark340.SparkSqlUtils
-import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
-import org.apache.arrow.vector.types.pojo.Schema
-import org.apache.spark.sql.types.StructType
-
-class Spark340Shims extends SparkShims {
- override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
-
- override def toDataFrame(
- rdd: JavaRDD[Array[Byte]],
- schema: String,
- session: SparkSession): DataFrame = {
- SparkSqlUtils.toDataFrame(rdd, schema, session)
- }
-
- override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = {
- new RayDPSpark340ExecutorBackendFactory()
- }
-
- override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
- TaskContextUtils.getDummyTaskContext(partitionId, env)
- }
-
- override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
- SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
- }
-}
diff --git a/core/shims/spark340/src/main/scala/org/apache/spark/TaskContextUtils.scala b/core/shims/spark340/src/main/scala/org/apache/spark/TaskContextUtils.scala
deleted file mode 100644
index 780920da..00000000
--- a/core/shims/spark340/src/main/scala/org/apache/spark/TaskContextUtils.scala
+++ /dev/null
@@ -1,30 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.spark340
-
-import java.util.Properties
-
-import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}
-import org.apache.spark.memory.TaskMemoryManager
-
-object TaskContextUtils {
- def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
- new TaskContextImpl(0, 0, partitionId, -1024, 0, 0,
- new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem)
- }
-}
diff --git a/core/shims/spark340/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala b/core/shims/spark340/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala
deleted file mode 100644
index 2e6b5e25..00000000
--- a/core/shims/spark340/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.executor
-
-import java.net.URL
-
-import org.apache.spark.SparkEnv
-import org.apache.spark.resource.ResourceProfile
-import org.apache.spark.rpc.RpcEnv
-
-class RayCoarseGrainedExecutorBackend(
- rpcEnv: RpcEnv,
- driverUrl: String,
- executorId: String,
- bindAddress: String,
- hostname: String,
- cores: Int,
- userClassPath: Seq[URL],
- env: SparkEnv,
- resourcesFileOpt: Option[String],
- resourceProfile: ResourceProfile)
- extends CoarseGrainedExecutorBackend(
- rpcEnv,
- driverUrl,
- executorId,
- bindAddress,
- hostname,
- cores,
- env,
- resourcesFileOpt,
- resourceProfile) {
-
- override def getUserClassPath: Seq[URL] = userClassPath
-
-}
diff --git a/core/shims/spark340/src/main/scala/org/apache/spark/executor/RayDPSpark340ExecutorBackendFactory.scala b/core/shims/spark340/src/main/scala/org/apache/spark/executor/RayDPSpark340ExecutorBackendFactory.scala
deleted file mode 100644
index 72abdc52..00000000
--- a/core/shims/spark340/src/main/scala/org/apache/spark/executor/RayDPSpark340ExecutorBackendFactory.scala
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.executor.spark340
-
-import java.net.URL
-
-import org.apache.spark.SparkEnv
-import org.apache.spark.executor._
-import org.apache.spark.resource.ResourceProfile
-import org.apache.spark.rpc.RpcEnv
-
-class RayDPSpark340ExecutorBackendFactory
- extends RayDPExecutorBackendFactory {
- override def createExecutorBackend(
- rpcEnv: RpcEnv,
- driverUrl: String,
- executorId: String,
- bindAddress: String,
- hostname: String,
- cores: Int,
- userClassPath: Seq[URL],
- env: SparkEnv,
- resourcesFileOpt: Option[String],
- resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = {
- new RayCoarseGrainedExecutorBackend(
- rpcEnv,
- driverUrl,
- executorId,
- bindAddress,
- hostname,
- cores,
- userClassPath,
- env,
- resourcesFileOpt,
- resourceProfile)
- }
-}
diff --git a/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
deleted file mode 100644
index eb52d8e7..00000000
--- a/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.spark340
-
-import org.apache.arrow.vector.types.pojo.Schema
-import org.apache.spark.TaskContext
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
-import org.apache.spark.sql.execution.arrow.ArrowConverters
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.util.ArrowUtils
-
-object SparkSqlUtils {
- def toDataFrame(
- arrowBatchRDD: JavaRDD[Array[Byte]],
- schemaString: String,
- session: SparkSession): DataFrame = {
- val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
- val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
- val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
- val context = TaskContext.get()
- ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
- }
- session.internalCreateDataFrame(rdd.setName("arrow"), schema)
- }
-
- def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
- ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
- }
-}
diff --git a/core/shims/spark350/pom.xml b/core/shims/spark350/pom.xml
deleted file mode 100644
index 2368daa2..00000000
--- a/core/shims/spark350/pom.xml
+++ /dev/null
@@ -1,99 +0,0 @@
-
-
-
- 4.0.0
-
-
- com.intel
- raydp-shims
- 1.7.0-SNAPSHOT
- ../pom.xml
-
-
- raydp-shims-spark350
- RayDP Shims for Spark 3.5.0
- jar
-
-
- 2.12.15
- 2.13.5
-
-
-
-
-
- org.scalastyle
- scalastyle-maven-plugin
-
-
- net.alchim31.maven
- scala-maven-plugin
- 3.2.2
-
-
- scala-compile-first
- process-resources
-
- compile
-
-
-
- scala-test-compile-first
- process-test-resources
-
- testCompile
-
-
-
-
-
-
-
-
- src/main/resources
-
-
-
-
-
-
- com.intel
- raydp-shims-common
- ${project.version}
- compile
-
-
- org.apache.spark
- spark-sql_${scala.binary.version}
- ${spark350.version}
- provided
-
-
- org.apache.spark
- spark-core_${scala.binary.version}
- ${spark350.version}
- provided
-
-
- org.xerial.snappy
- snappy-java
-
-
- io.netty
- netty-handler
-
-
-
-
- org.xerial.snappy
- snappy-java
- ${snappy.version}
-
-
- io.netty
- netty-handler
- ${netty.version}
-
-
-
diff --git a/core/shims/spark350/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider b/core/shims/spark350/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider
deleted file mode 100644
index 6e5a394e..00000000
--- a/core/shims/spark350/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider
+++ /dev/null
@@ -1 +0,0 @@
-com.intel.raydp.shims.spark350.SparkShimProvider
diff --git a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala b/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala
deleted file mode 100644
index 4a260e6f..00000000
--- a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.intel.raydp.shims.spark350
-
-import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor}
-
-object SparkShimProvider {
- private val SUPPORTED_PATCHES = 0 to 7
- val DESCRIPTORS = SUPPORTED_PATCHES.map(p => SparkShimDescriptor(3, 5, p))
- val DESCRIPTOR_STRINGS = DESCRIPTORS.map(_.toString)
- val DESCRIPTOR = DESCRIPTORS.head
-}
-
-class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider {
- def createShim: SparkShims = {
- new Spark350Shims()
- }
-
- def matches(version: String): Boolean = {
- SparkShimProvider.DESCRIPTOR_STRINGS.contains(version)
- }
-}
diff --git a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala
deleted file mode 100644
index 721d6923..00000000
--- a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.intel.raydp.shims.spark350
-
-import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.executor.RayDPExecutorBackendFactory
-import org.apache.spark.executor.spark350._
-import org.apache.spark.spark350.TaskContextUtils
-import org.apache.spark.sql.{DataFrame, SparkSession}
-import org.apache.spark.sql.spark350.SparkSqlUtils
-import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
-import org.apache.arrow.vector.types.pojo.Schema
-import org.apache.spark.sql.types.StructType
-
-class Spark350Shims extends SparkShims {
- override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
-
- override def toDataFrame(
- rdd: JavaRDD[Array[Byte]],
- schema: String,
- session: SparkSession): DataFrame = {
- SparkSqlUtils.toDataFrame(rdd, schema, session)
- }
-
- override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = {
- new RayDPSpark350ExecutorBackendFactory()
- }
-
- override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
- TaskContextUtils.getDummyTaskContext(partitionId, env)
- }
-
- override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
- SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
- }
-}
diff --git a/core/shims/spark350/src/main/scala/org/apache/spark/TaskContextUtils.scala b/core/shims/spark350/src/main/scala/org/apache/spark/TaskContextUtils.scala
deleted file mode 100644
index 0f38bbb9..00000000
--- a/core/shims/spark350/src/main/scala/org/apache/spark/TaskContextUtils.scala
+++ /dev/null
@@ -1,30 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.spark350
-
-import java.util.Properties
-
-import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}
-import org.apache.spark.memory.TaskMemoryManager
-
-object TaskContextUtils {
- def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
- new TaskContextImpl(0, 0, partitionId, -1024, 0, 0,
- new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem)
- }
-}
diff --git a/core/shims/spark350/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala b/core/shims/spark350/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala
deleted file mode 100644
index 2e6b5e25..00000000
--- a/core/shims/spark350/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.executor
-
-import java.net.URL
-
-import org.apache.spark.SparkEnv
-import org.apache.spark.resource.ResourceProfile
-import org.apache.spark.rpc.RpcEnv
-
-class RayCoarseGrainedExecutorBackend(
- rpcEnv: RpcEnv,
- driverUrl: String,
- executorId: String,
- bindAddress: String,
- hostname: String,
- cores: Int,
- userClassPath: Seq[URL],
- env: SparkEnv,
- resourcesFileOpt: Option[String],
- resourceProfile: ResourceProfile)
- extends CoarseGrainedExecutorBackend(
- rpcEnv,
- driverUrl,
- executorId,
- bindAddress,
- hostname,
- cores,
- env,
- resourcesFileOpt,
- resourceProfile) {
-
- override def getUserClassPath: Seq[URL] = userClassPath
-
-}
diff --git a/core/shims/spark350/src/main/scala/org/apache/spark/executor/RayDPSpark350ExecutorBackendFactory.scala b/core/shims/spark350/src/main/scala/org/apache/spark/executor/RayDPSpark350ExecutorBackendFactory.scala
deleted file mode 100644
index 54d53d7d..00000000
--- a/core/shims/spark350/src/main/scala/org/apache/spark/executor/RayDPSpark350ExecutorBackendFactory.scala
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.executor.spark350
-
-import java.net.URL
-
-import org.apache.spark.SparkEnv
-import org.apache.spark.executor._
-import org.apache.spark.resource.ResourceProfile
-import org.apache.spark.rpc.RpcEnv
-
-class RayDPSpark350ExecutorBackendFactory
- extends RayDPExecutorBackendFactory {
- override def createExecutorBackend(
- rpcEnv: RpcEnv,
- driverUrl: String,
- executorId: String,
- bindAddress: String,
- hostname: String,
- cores: Int,
- userClassPath: Seq[URL],
- env: SparkEnv,
- resourcesFileOpt: Option[String],
- resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = {
- new RayCoarseGrainedExecutorBackend(
- rpcEnv,
- driverUrl,
- executorId,
- bindAddress,
- hostname,
- cores,
- userClassPath,
- env,
- resourcesFileOpt,
- resourceProfile)
- }
-}
diff --git a/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
deleted file mode 100644
index dfd063f7..00000000
--- a/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.spark350
-
-import org.apache.arrow.vector.types.pojo.Schema
-import org.apache.spark.TaskContext
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
-import org.apache.spark.sql.execution.arrow.ArrowConverters
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.util.ArrowUtils
-
-object SparkSqlUtils {
- def toDataFrame(
- arrowBatchRDD: JavaRDD[Array[Byte]],
- schemaString: String,
- session: SparkSession): DataFrame = {
- val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
- val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
- val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
- val context = TaskContext.get()
- ArrowConverters.fromBatchIterator(iter, schema, timeZoneId,false, context)
- }
- session.internalCreateDataFrame(rdd.setName("arrow"), schema)
- }
-
- def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
- ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId, errorOnDuplicatedFieldNames = false)
- }
-}
diff --git a/core/shims/spark411/pom.xml b/core/shims/spark411/pom.xml
new file mode 100644
index 00000000..9810a847
--- /dev/null
+++ b/core/shims/spark411/pom.xml
@@ -0,0 +1,93 @@
+
+
+
+ raydp-shims
+ com.intel
+ 2.0.0-SNAPSHOT
+
+ 4.0.0
+
+ raydp-shims-spark411
+
+
+ 4.1.1
+
+
+
+
+ com.intel
+ raydp-shims-common
+ ${project.version}
+
+
+ org.apache.spark
+ spark-core_${scala.binary.version}
+ ${spark.version}
+ provided
+
+
+ org.apache.spark
+ spark-sql_${scala.binary.version}
+ ${spark.version}
+ provided
+
+
+ org.scalatest
+ scalatest_${scala.binary.version}
+ 3.2.18
+ test
+
+
+
+
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+ ${scala.plugin.version}
+
+
+ scala-compile-first
+ process-resources
+
+ add-source
+ compile
+
+
+
+ scala-test-compile
+ process-test-resources
+
+ testCompile
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-surefire-plugin
+
+ true
+
+
+
+ org.scalatest
+ scalatest-maven-plugin
+ 2.2.0
+
+ --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.useDirectMethodHandle=none
+
+
+
+ test
+
+ test
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/core/shims/spark411/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider b/core/shims/spark411/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider
new file mode 100644
index 00000000..e81063ad
--- /dev/null
+++ b/core/shims/spark411/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider
@@ -0,0 +1 @@
+com.intel.raydp.shims.spark411.SparkShimProvider
diff --git a/core/shims/spark411/src/main/scala/com/intel/raydp/shims/SparkShims411.scala b/core/shims/spark411/src/main/scala/com/intel/raydp/shims/SparkShims411.scala
new file mode 100644
index 00000000..c44183be
--- /dev/null
+++ b/core/shims/spark411/src/main/scala/com/intel/raydp/shims/SparkShims411.scala
@@ -0,0 +1,33 @@
+package com.intel.raydp.shims
+
+import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.spark.{Spark411Helper, SparkEnv, TaskContext}
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.executor.RayDPExecutorBackendFactory
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.{DataFrame, Spark411SQLHelper, SparkSession}
+
+class SparkShims411 extends SparkShims {
+ override def getShimDescriptor: ShimDescriptor = SparkShimDescriptor(4, 1, 1)
+
+ override def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = {
+ Spark411SQLHelper.toDataFrame(rdd, schema, session)
+ }
+
+ override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = {
+ Spark411Helper.getExecutorBackendFactory
+ }
+
+ override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
+ Spark411Helper.getDummyTaskContext(partitionId, env)
+ }
+
+ override def toArrowSchema(schema: StructType, timeZoneId: String, largeVarTypes: Boolean = false): Schema = {
+ Spark411SQLHelper.toArrowSchema(schema, timeZoneId, largeVarTypes)
+ }
+
+ override def toArrowBatchRdd(df: DataFrame): RDD[Array[Byte]] = {
+ Spark411SQLHelper.toArrowBatchRdd(df)
+ }
+}
diff --git a/core/shims/spark330/src/main/scala/org/apache/spark/TaskContextUtils.scala b/core/shims/spark411/src/main/scala/com/intel/raydp/shims/spark411/SparkShimProvider.scala
similarity index 65%
rename from core/shims/spark330/src/main/scala/org/apache/spark/TaskContextUtils.scala
rename to core/shims/spark411/src/main/scala/com/intel/raydp/shims/spark411/SparkShimProvider.scala
index 431167f4..6e0f62b8 100644
--- a/core/shims/spark330/src/main/scala/org/apache/spark/TaskContextUtils.scala
+++ b/core/shims/spark411/src/main/scala/com/intel/raydp/shims/spark411/SparkShimProvider.scala
@@ -15,16 +15,14 @@
* limitations under the License.
*/
-package org.apache.spark.spark330
+package com.intel.raydp.shims.spark411
-import java.util.Properties
+import com.intel.raydp.shims.{SparkShimProvider => BaseSparkShimProvider, SparkShims, SparkShims411}
-import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}
-import org.apache.spark.memory.TaskMemoryManager
+class SparkShimProvider extends BaseSparkShimProvider {
+ override def createShim: SparkShims = new SparkShims411()
-object TaskContextUtils {
- def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
- new TaskContextImpl(0, 0, partitionId, -1024, 0,
- new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem)
+ override def matches(version: String): Boolean = {
+ version.startsWith("4.1")
}
}
diff --git a/core/shims/spark411/src/main/scala/org/apache/spark/Spark411Helper.scala b/core/shims/spark411/src/main/scala/org/apache/spark/Spark411Helper.scala
new file mode 100644
index 00000000..24afe7dc
--- /dev/null
+++ b/core/shims/spark411/src/main/scala/org/apache/spark/Spark411Helper.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.executor.{CoarseGrainedExecutorBackend, RayCoarseGrainedExecutorBackend, RayDPExecutorBackendFactory, TaskMetrics}
+import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.rpc.RpcEnv
+
+import java.net.URL
+import java.util.concurrent.atomic.AtomicLong
+
+object Spark411Helper {
+ private val nextTaskAttemptId = new AtomicLong(1000000L)
+ def getExecutorBackendFactory(): RayDPExecutorBackendFactory = {
+ new RayDPExecutorBackendFactory {
+ override def createExecutorBackend(
+ rpcEnv: RpcEnv,
+ driverUrl: String,
+ executorId: String,
+ bindAddress: String,
+ hostname: String,
+ cores: Int,
+ userClassPath: Seq[URL],
+ env: SparkEnv,
+ resourcesFileOpt: Option[String],
+ resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = {
+ new RayCoarseGrainedExecutorBackend(
+ rpcEnv,
+ driverUrl,
+ executorId,
+ bindAddress,
+ hostname,
+ cores,
+ userClassPath,
+ env,
+ resourcesFileOpt,
+ resourceProfile)
+ }
+ }
+ }
+
+ def setTaskContext(ctx: TaskContext): Unit = TaskContext.setTaskContext(ctx)
+
+ def unsetTaskContext(): Unit = TaskContext.unset()
+
+ def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
+ val taskAttemptId = nextTaskAttemptId.getAndIncrement()
+ new TaskContextImpl(
+ stageId = 0,
+ stageAttemptNumber = 0,
+ partitionId = partitionId,
+ taskAttemptId = taskAttemptId,
+ attemptNumber = 0,
+ numPartitions = 0,
+ taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskAttemptId),
+ localProperties = new java.util.Properties,
+ metricsSystem = env.metricsSystem,
+ taskMetrics = TaskMetrics.empty,
+ cpus = 0,
+ resources = Map.empty
+ )
+ }
+}
diff --git a/core/shims/spark330/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala b/core/shims/spark411/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala
similarity index 100%
rename from core/shims/spark330/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala
rename to core/shims/spark411/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala
diff --git a/core/shims/spark411/src/main/scala/org/apache/spark/sql/Spark411SQLHelper.scala b/core/shims/spark411/src/main/scala/org/apache/spark/sql/Spark411SQLHelper.scala
new file mode 100644
index 00000000..20bb0419
--- /dev/null
+++ b/core/shims/spark411/src/main/scala/org/apache/spark/sql/Spark411SQLHelper.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.execution.arrow.ArrowConverters
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.sql.classic.{SparkSession => ClassicSparkSession}
+
+object Spark411SQLHelper {
+ def toArrowSchema(schema: StructType, timeZoneId: String, largeVarTypes: Boolean = false): Schema = {
+ ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames = true, largeVarTypes = largeVarTypes)
+ }
+
+ def toArrowBatchRdd(df: DataFrame): org.apache.spark.rdd.RDD[Array[Byte]] = {
+ val conf = df.sparkSession.asInstanceOf[ClassicSparkSession].sessionState.conf
+ val timeZoneId = conf.sessionLocalTimeZone
+ val maxRecordsPerBatch = conf.getConf(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)
+ val largeVarTypes = conf.arrowUseLargeVarTypes
+ val schema = df.schema
+ df.queryExecution.toRdd.mapPartitions(iter => {
+ val context = TaskContext.get()
+ ArrowConverters.toBatchIterator(
+ iter,
+ schema,
+ maxRecordsPerBatch,
+ timeZoneId,
+ true, // errorOnDuplicatedFieldNames
+ largeVarTypes,
+ context)
+ })
+ }
+
+ /**
+ * Converts a JavaRDD of Arrow batches (serialized as byte arrays) to a DataFrame.
+ * This is the reverse operation of toArrowBatchRdd.
+ *
+ * @param rdd JavaRDD containing Arrow batches serialized as byte arrays
+ * @param schema JSON string representation of the StructType schema
+ * @param session SparkSession to use for DataFrame creation
+ * @return DataFrame reconstructed from the Arrow batches
+ */
+ def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = {
+ val structType = DataType.fromJson(schema).asInstanceOf[StructType]
+ val classicSession = session.asInstanceOf[ClassicSparkSession]
+
+ // Capture timezone and largeVarTypes on driver side - cannot access sessionState on executors
+ val timeZoneId = classicSession.sessionState.conf.sessionLocalTimeZone
+ val largeVarTypes = classicSession.sessionState.conf.arrowUseLargeVarTypes
+
+ // Create an RDD of InternalRow by deserializing Arrow batches per partition
+ val rowRdd = rdd.rdd.flatMap { arrowBatch =>
+ ArrowConverters.fromBatchIterator(
+ Iterator(arrowBatch),
+ structType,
+ timeZoneId, // Use captured value, not sessionState
+ true, // errorOnDuplicatedFieldNames
+ largeVarTypes,
+ TaskContext.get()
+ )
+ }
+
+ classicSession.internalCreateDataFrame(rowRdd, structType)
+ }
+}
diff --git a/core/shims/spark411/src/test/scala/com/intel/raydp/shims/SparkShims411Suite.scala b/core/shims/spark411/src/test/scala/com/intel/raydp/shims/SparkShims411Suite.scala
new file mode 100644
index 00000000..26a0908d
--- /dev/null
+++ b/core/shims/spark411/src/test/scala/com/intel/raydp/shims/SparkShims411Suite.scala
@@ -0,0 +1,372 @@
+package com.intel.raydp.shims
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+
+import java.sql.{Date, Timestamp}
+import java.time.{LocalDate, ZoneId}
+
+import scala.reflect.classTag
+
+import org.apache.arrow.vector.types.pojo.ArrowType
+import org.apache.spark.{Spark411Helper, SparkEnv, TaskContext}
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.types._
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
+
+class SparkShims411Suite extends AnyFunSuite with BeforeAndAfterAll {
+
+ private var spark: SparkSession = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ spark = SparkSession.builder()
+ .master("local[2]")
+ .appName("SparkShims411Suite")
+ .config("spark.driver.host", "127.0.0.1")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.ui.enabled", "false")
+ .getOrCreate()
+ }
+
+ override def afterAll(): Unit = {
+ if (spark != null) {
+ spark.stop()
+ spark = null
+ }
+ super.afterAll()
+ }
+
+ test("shim descriptor returns 4.1.1") {
+ val shim = new SparkShims411()
+ val descriptor = shim.getShimDescriptor
+ assert(descriptor.isInstanceOf[SparkShimDescriptor])
+ val sparkDescriptor = descriptor.asInstanceOf[SparkShimDescriptor]
+ assert(sparkDescriptor.major === 4)
+ assert(sparkDescriptor.minor === 1)
+ assert(sparkDescriptor.patch === 1)
+ assert(sparkDescriptor.toString === "4.1.1")
+ }
+
+ test("provider matches 4.1 versions") {
+ val provider = new spark411.SparkShimProvider()
+ assert(provider.matches("4.1.1"))
+ assert(provider.matches("4.1.0"))
+ assert(!provider.matches("3.5.0"))
+ assert(!provider.matches("4.0.0"))
+ }
+
+ test("provider creates SparkShims411 instance") {
+ val provider = new spark411.SparkShimProvider()
+ val shim = provider.createShim
+ assert(shim.isInstanceOf[SparkShims411])
+ }
+
+ test("SPI service loading works") {
+ SparkShimLoader.setSparkShimProviderClass(
+ "com.intel.raydp.shims.spark411.SparkShimProvider")
+ val shim = SparkShimLoader.getSparkShims
+ assert(shim.isInstanceOf[SparkShims411])
+ assert(shim.getShimDescriptor.toString === "4.1.1")
+ }
+
+ test("toArrowSchema produces valid Arrow schema") {
+ val shim = new SparkShims411()
+ val sparkSchema = new StructType()
+ .add("id", IntegerType)
+ .add("name", StringType)
+ .add("value", DoubleType)
+
+ val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+ val arrowSchema = shim.toArrowSchema(sparkSchema, timeZoneId)
+
+ assert(arrowSchema.getFields.size() === 3)
+ assert(arrowSchema.getFields.get(0).getName === "id")
+ assert(arrowSchema.getFields.get(1).getName === "name")
+ assert(arrowSchema.getFields.get(2).getName === "value")
+ }
+
+ test("Arrow round-trip: DataFrame to ArrowBatch to DataFrame") {
+ val shim = new SparkShims411()
+
+ val schema = new StructType()
+ .add("id", IntegerType)
+ .add("name", StringType)
+ .add("value", DoubleType)
+ val rows = java.util.Arrays.asList(
+ Row(1, "alice", 10.0),
+ Row(2, "bob", 20.0),
+ Row(3, "carol", 30.0))
+ val original = spark.createDataFrame(rows, schema)
+
+ val arrowRdd = shim.toArrowBatchRdd(original)
+ val schemaJson = original.schema.json
+
+ val restored = shim.toDataFrame(
+ arrowRdd.toJavaRDD(), schemaJson, spark)
+
+ val originalRows = original.collect().sortBy(_.getInt(0))
+ val restoredRows = restored.collect().sortBy(_.getInt(0))
+
+ assert(originalRows.length === restoredRows.length)
+ originalRows.zip(restoredRows).foreach { case (orig, rest) =>
+ assert(orig.getInt(0) === rest.getInt(0))
+ assert(orig.getString(1) === rest.getString(1))
+ assert(orig.getDouble(2) === rest.getDouble(2))
+ }
+ }
+
+ test("toArrowSchema maps Spark types to correct Arrow types") {
+ val shim = new SparkShims411()
+ val sparkSchema = new StructType()
+ .add("bool", BooleanType)
+ .add("byte", ByteType)
+ .add("short", ShortType)
+ .add("int", IntegerType)
+ .add("long", LongType)
+ .add("float", FloatType)
+ .add("double", DoubleType)
+ .add("str", StringType)
+ .add("bin", BinaryType)
+
+ val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+ val arrowSchema = shim.toArrowSchema(sparkSchema, timeZoneId)
+ val fields = arrowSchema.getFields
+
+ assert(fields.get(0).getType.isInstanceOf[ArrowType.Bool])
+ assert(fields.get(1).getType === new ArrowType.Int(8, true))
+ assert(fields.get(2).getType === new ArrowType.Int(16, true))
+ assert(fields.get(3).getType === new ArrowType.Int(32, true))
+ assert(fields.get(4).getType === new ArrowType.Int(64, true))
+ assert(fields.get(5).getType === new ArrowType.FloatingPoint(
+ org.apache.arrow.vector.types.FloatingPointPrecision.SINGLE))
+ assert(fields.get(6).getType === new ArrowType.FloatingPoint(
+ org.apache.arrow.vector.types.FloatingPointPrecision.DOUBLE))
+ assert(fields.get(7).getType.isInstanceOf[ArrowType.Utf8])
+ assert(fields.get(8).getType.isInstanceOf[ArrowType.Binary])
+ }
+
+ test("Arrow round-trip preserves null values") {
+ val shim = new SparkShims411()
+
+ val schema = new StructType()
+ .add("id", IntegerType, nullable = false)
+ .add("name", StringType, nullable = true)
+ .add("value", DoubleType, nullable = true)
+ val rows = java.util.Arrays.asList(
+ Row(1, "alice", 10.0),
+ Row(2, null, null),
+ Row(3, "carol", 30.0))
+ val original = spark.createDataFrame(rows, schema)
+
+ val arrowRdd = shim.toArrowBatchRdd(original)
+ val schemaJson = original.schema.json
+ val restored = shim.toDataFrame(arrowRdd.toJavaRDD(), schemaJson, spark)
+
+ val restoredRows = restored.collect().sortBy(_.getInt(0))
+ assert(restoredRows.length === 3)
+ assert(restoredRows(0).getString(1) === "alice")
+ assert(restoredRows(1).isNullAt(1))
+ assert(restoredRows(1).isNullAt(2))
+ assert(restoredRows(2).getDouble(2) === 30.0)
+ }
+
+ test("Arrow round-trip with multiple batches") {
+ val shim = new SparkShims411()
+
+ // Force small batches: 2 records per batch
+ spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "2")
+ try {
+ val schema = new StructType()
+ .add("id", IntegerType)
+ .add("label", StringType)
+ val rows = (1 to 10).map(i => Row(i, s"row_$i"))
+ val original = spark.createDataFrame(
+ java.util.Arrays.asList(rows: _*), schema)
+ .repartition(1) // single partition so multiple batches are in one partition
+
+ val arrowRdd = shim.toArrowBatchRdd(original)
+ val schemaJson = original.schema.json
+ val restored = shim.toDataFrame(arrowRdd.toJavaRDD(), schemaJson, spark)
+
+ val restoredRows = restored.collect().sortBy(_.getInt(0))
+ assert(restoredRows.length === 10)
+ (1 to 10).foreach { i =>
+ assert(restoredRows(i - 1).getInt(0) === i)
+ assert(restoredRows(i - 1).getString(1) === s"row_$i")
+ }
+ } finally {
+ spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch")
+ }
+ }
+
+ test("Arrow round-trip preserves Timestamp and Date values") {
+ val shim = new SparkShims411()
+
+ val ts1 = Timestamp.valueOf("2025-01-15 10:30:00")
+ val ts2 = Timestamp.valueOf("2025-06-30 23:59:59")
+ val d1 = Date.valueOf("2025-01-15")
+ val d2 = Date.valueOf("2025-06-30")
+
+ val schema = new StructType()
+ .add("id", IntegerType)
+ .add("ts", TimestampType)
+ .add("dt", DateType)
+ val rows = java.util.Arrays.asList(
+ Row(1, ts1, d1),
+ Row(2, ts2, d2))
+ val original = spark.createDataFrame(rows, schema)
+
+ val arrowRdd = shim.toArrowBatchRdd(original)
+ val schemaJson = original.schema.json
+ val restored = shim.toDataFrame(arrowRdd.toJavaRDD(), schemaJson, spark)
+
+ val restoredRows = restored.collect().sortBy(_.getInt(0))
+ assert(restoredRows.length === 2)
+ assert(restoredRows(0).getTimestamp(1) === ts1)
+ assert(restoredRows(0).getDate(2) === d1)
+ assert(restoredRows(1).getTimestamp(1) === ts2)
+ assert(restoredRows(1).getDate(2) === d2)
+ }
+
+ test("Arrow round-trip preserves Decimal values") {
+ val shim = new SparkShims411()
+
+ val schema = new StructType()
+ .add("id", IntegerType)
+ .add("price", DecimalType(18, 6))
+ .add("quantity", DecimalType(10, 0))
+ val rows = java.util.Arrays.asList(
+ Row(1, new java.math.BigDecimal("12345.678900"), new java.math.BigDecimal("100")),
+ Row(2, new java.math.BigDecimal("0.000001"), new java.math.BigDecimal("0")),
+ Row(3, new java.math.BigDecimal("-9999.999999"), new java.math.BigDecimal("999")))
+ val original = spark.createDataFrame(rows, schema)
+
+ val arrowRdd = shim.toArrowBatchRdd(original)
+ val schemaJson = original.schema.json
+ val restored = shim.toDataFrame(arrowRdd.toJavaRDD(), schemaJson, spark)
+
+ val restoredRows = restored.collect().sortBy(_.getInt(0))
+ assert(restoredRows.length === 3)
+ assert(restoredRows(0).getDecimal(1).compareTo(
+ new java.math.BigDecimal("12345.678900")) === 0)
+ assert(restoredRows(1).getDecimal(1).compareTo(
+ new java.math.BigDecimal("0.000001")) === 0)
+ assert(restoredRows(2).getDecimal(1).compareTo(
+ new java.math.BigDecimal("-9999.999999")) === 0)
+ assert(restoredRows(2).getDecimal(2).compareTo(
+ new java.math.BigDecimal("999")) === 0)
+ }
+
+ test("Arrow round-trip with empty DataFrame") {
+ val shim = new SparkShims411()
+
+ val schema = new StructType()
+ .add("id", IntegerType)
+ .add("name", StringType)
+ val rows = java.util.Arrays.asList[Row]()
+ val original = spark.createDataFrame(rows, schema)
+
+ val arrowRdd = shim.toArrowBatchRdd(original)
+ val schemaJson = original.schema.json
+ val restored = shim.toDataFrame(arrowRdd.toJavaRDD(), schemaJson, spark)
+
+ assert(restored.collect().length === 0)
+ assert(restored.schema === original.schema)
+ }
+
+ test("getDummyTaskContext returns valid TaskContext") {
+ val shim = new SparkShims411()
+ val env = SparkEnv.get
+
+ val ctx = shim.getDummyTaskContext(42, env)
+
+ assert(ctx.isInstanceOf[TaskContext])
+ assert(ctx.partitionId() === 42)
+ assert(ctx.stageId() === 0)
+ assert(ctx.attemptNumber() === 0)
+ }
+
+ test("getDummyTaskContext generates unique taskAttemptIds") {
+ val shim = new SparkShims411()
+ val env = SparkEnv.get
+
+ val ids = (0 until 100).map { i =>
+ shim.getDummyTaskContext(i, env).taskAttemptId()
+ }
+
+ assert(ids.distinct.size === 100, "all taskAttemptIds must be unique")
+ assert(ids.forall(_ >= 1000000L), "taskAttemptIds must start at 1000000+")
+ }
+
+ test("BlockManager.get NPEs without registerTask (regression)") {
+ // This documents the exact NPE bug we fixed: Spark 4.1's BlockInfoManager
+ // requires tasks to be registered before they can acquire read locks.
+ // Setting a TaskContext with a specific taskAttemptId WITHOUT calling
+ // registerTask causes lockForReading to NPE on:
+ // readLocksByTask.get(taskAttemptId).add(blockId)
+ // Note: if NO TaskContext were set, BlockInfoManager would use the
+ // pre-registered NON_TASK_WRITER (-1024) and no NPE would occur.
+ val env = SparkEnv.get
+ val blockManager = env.blockManager
+
+ // Cache a small RDD to have a block to read
+ val rdd = spark.sparkContext.parallelize(Seq(Array[Byte](1, 2, 3)), 1)
+ rdd.persist(StorageLevel.MEMORY_ONLY)
+ rdd.count() // force materialization
+ val blockId = RDDBlockId(rdd.id, 0)
+
+ // Without registerTask: set TaskContext but do NOT register → NPE
+ val ctx = Spark411Helper.getDummyTaskContext(0, env)
+ Spark411Helper.setTaskContext(ctx)
+ try {
+      // intercept itself fails the test if no NPE is thrown,
+      // so no further assertion on the exception is needed
+      intercept[NullPointerException] {
+        blockManager.get(blockId)(classTag[Array[Byte]])
+      }
+ } finally {
+ Spark411Helper.unsetTaskContext()
+ }
+
+ rdd.unpersist(blocking = true)
+ }
+
+ test("BlockManager.get succeeds with registerTask and cleanup") {
+ // Full lifecycle test mirroring RayDPExecutor.getRDDPartition:
+ // register task → read block → release locks → unset context
+ val env = SparkEnv.get
+ val blockManager = env.blockManager
+
+ val rdd = spark.sparkContext.parallelize(Seq(Array[Byte](10, 20, 30)), 1)
+ rdd.persist(StorageLevel.MEMORY_ONLY)
+ rdd.count()
+ val blockId = RDDBlockId(rdd.id, 0)
+
+ val ctx = Spark411Helper.getDummyTaskContext(0, env)
+ Spark411Helper.setTaskContext(ctx)
+ val taskAttemptId = ctx.taskAttemptId()
+ blockManager.registerTask(taskAttemptId)
+ try {
+ val result = blockManager.get(blockId)(classTag[Array[Byte]])
+ assert(result.isDefined, "block should be readable after registerTask")
+ } finally {
+ blockManager.releaseAllLocksForTask(taskAttemptId)
+ Spark411Helper.unsetTaskContext()
+ }
+
+ // Verify lock was released: a second task can also read the same block
+ val ctx2 = Spark411Helper.getDummyTaskContext(0, env)
+ Spark411Helper.setTaskContext(ctx2)
+ val taskAttemptId2 = ctx2.taskAttemptId()
+ blockManager.registerTask(taskAttemptId2)
+ try {
+ val result2 = blockManager.get(blockId)(classTag[Array[Byte]])
+ assert(result2.isDefined, "block should be readable by a second task after cleanup")
+ } finally {
+ blockManager.releaseAllLocksForTask(taskAttemptId2)
+ Spark411Helper.unsetTaskContext()
+ }
+
+ rdd.unpersist(blocking = true)
+ }
+}
diff --git a/python/pylintrc b/python/pylintrc
index 48bc7e2e..b11070cd 100644
--- a/python/pylintrc
+++ b/python/pylintrc
@@ -49,7 +49,7 @@ unsafe-load-any-extension=no
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code
-extension-pkg-whitelist=netifaces
+extension-pkg-allow-list=netifaces
[MESSAGES CONTROL]
@@ -74,65 +74,33 @@ confidence=
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=abstract-method,
- apply-builtin,
arguments-differ,
attribute-defined-outside-init,
- backtick,
- basestring-builtin,
broad-except,
- buffer-builtin,
- cmp-builtin,
- cmp-method,
- coerce-builtin,
- coerce-method,
+ broad-exception-caught,
+ broad-exception-raised,
+ consider-using-dict-items,
+ consider-using-f-string,
+ consider-using-from-import,
+ consider-using-generator,
dangerous-default-value,
- delslice-method,
duplicate-code,
- execfile-builtin,
- file-builtin,
- filter-builtin-not-iterating,
fixme,
- getslice-method,
global-statement,
- hex-method,
+ global-variable-not-assigned,
import-error,
import-self,
- import-star-module-level,
- input-builtin,
- intern-builtin,
invalid-name,
locally-disabled,
logging-fstring-interpolation,
- long-builtin,
- long-suffix,
- map-builtin-not-iterating,
missing-docstring,
missing-function-docstring,
- metaclass-assignment,
- next-method-called,
- next-method-defined,
- no-absolute-import,
no-else-return,
no-member,
no-name-in-module,
- no-self-use,
- nonzero-method,
- oct-method,
- old-division,
- old-ne-operator,
- old-octal-literal,
- old-raise-syntax,
- parameter-unpacking,
- print-statement,
+ possibly-used-before-assignment,
protected-access,
- raising-string,
- range-builtin-not-iterating,
redefined-outer-name,
- reduce-builtin,
- reload-builtin,
- round-builtin,
- setslice-method,
- standarderror-builtin,
suppressed-message,
too-few-public-methods,
too-many-ancestors,
@@ -140,21 +108,19 @@ disable=abstract-method,
too-many-branches,
too-many-instance-attributes,
too-many-locals,
+ too-many-positional-arguments,
too-many-public-methods,
too-many-return-statements,
too-many-statements,
- unichr-builtin,
- unicode-builtin,
- unpacking-in-except,
unused-argument,
unused-import,
+ unreachable,
+ unspecified-encoding,
unused-variable,
+ use-dict-literal,
useless-else-on-loop,
useless-suppression,
- using-cmp-argument,
wrong-import-order,
- xrange-builtin,
- zip-builtin-not-iterating,
[REPORTS]
@@ -164,12 +130,6 @@ disable=abstract-method,
# mypackage.mymodule.MyReporterClass.
output-format=text
-# Put messages in a separate file for each module / package specified on the
-# command line instead of printing them on stdout. Reports (if any) will be
-# written in a file name "pylint_global.[txt|html]". This option is deprecated
-# and it will be removed in Pylint 2.0.
-files-output=no
-
# Tells whether to display a full report or only the messages
reports=no
@@ -215,63 +175,33 @@ property-classes=abc.abstractproperty
function-rgx=[a-z_][a-z0-9_]{2,30}$
-# Naming hint for function names
-function-name-hint=[a-z_][a-z0-9_]{2,30}$
-
# Regular expression matching correct variable names
variable-rgx=[a-z_][a-z0-9_]{2,30}$
-# Naming hint for variable names
-variable-name-hint=[a-z_][a-z0-9_]{2,30}$
-
# Regular expression matching correct constant names
const-rgx=(([A-Za-z_][A-Za-z0-9_]*)|(__.*__))$
-# Naming hint for constant names
-const-name-hint=(([A-a-zZ_][A-Za-z0-9_]*)|(__.*__))$
-
# Regular expression matching correct attribute names
attr-rgx=[a-z_][a-z0-9_]{2,30}$
-# Naming hint for attribute names
-attr-name-hint=[a-z_][a-z0-9_]{2,30}$
-
# Regular expression matching correct argument names
argument-rgx=[a-z_][a-z0-9_]{2,30}$
-# Naming hint for argument names
-argument-name-hint=[a-z_][a-z0-9_]{2,30}$
-
# Regular expression matching correct class attribute names
class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
-# Naming hint for class attribute names
-class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
-
# Regular expression matching correct inline iteration names
inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$
-# Naming hint for inline iteration names
-inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$
-
# Regular expression matching correct class names
class-rgx=[A-Z_][a-zA-Z0-9]+$
-# Naming hint for class names
-class-name-hint=[A-Z_][a-zA-Z0-9]+$
-
# Regular expression matching correct module names
module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
-# Naming hint for module names
-module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
-
# Regular expression matching correct method names
method-rgx=[a-z_][a-z0-9_]{2,30}$
-# Naming hint for method names
-method-name-hint=[a-z_][a-z0-9_]{2,30}$
-
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
@@ -326,12 +256,6 @@ ignore-long-lines=(?x)(
# else.
single-line-if-stmt=yes
-# List of optional constructs for which whitespace checking is disabled. `dict-
-# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
-# `trailing-comma` allows a space between comma and closing bracket: (a, ).
-# `empty-line` allows space-only lines.
-no-space-check=trailing-comma,dict-separator
-
# Maximum number of lines in a module
max-module-lines=1000
@@ -481,6 +405,5 @@ valid-metaclass-classmethod-first-arg=mcs
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
-overgeneral-exceptions=StandardError,
- Exception,
- BaseException
+overgeneral-exceptions=builtins.Exception,
+ builtins.BaseException
diff --git a/python/pyproject.toml b/python/pyproject.toml
new file mode 100644
index 00000000..ec5392e3
--- /dev/null
+++ b/python/pyproject.toml
@@ -0,0 +1,32 @@
+[build-system]
+requires = ["setuptools", "wheel", "grpcio-tools"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "raydp"
+dynamic = ["version", "readme"]  # NOTE(review): only "version" is defined under [tool.setuptools.dynamic]; "readme" must be supplied by setup.py (long_description) — confirm
+description = "RayDP: Distributed Data Processing on Ray"
+authors = [
+ {name = "RayDP Developers", email = "raydp-dev@googlegroups.com"}
+]
+license = "Apache-2.0"
+keywords = ["raydp", "spark", "ray", "distributed", "data-processing"]
+classifiers = [
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+]
+requires-python = ">=3.10"
+dependencies = [
+ "numpy",
+ "pandas >= 1.1.4",
+ "psutil",
+ "pyarrow >= 4.0.1",
+ "ray >= 2.53.0",
+ "pyspark >= 4.1.0",
+ "netifaces",
+ "protobuf > 3.19.5"
+]
+
+[tool.setuptools.dynamic]
+version = {attr = "setup.VERSION"}
diff --git a/python/raydp/spark/dataset.py b/python/raydp/spark/dataset.py
index 6afd78e5..bd35864b 100644
--- a/python/raydp/spark/dataset.py
+++ b/python/raydp/spark/dataset.py
@@ -14,12 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import time
import uuid
from typing import Callable, List, Optional, Union
from dataclasses import dataclass
from packaging import version
-import pandas as pd
import pyarrow as pa
import pyspark.sql as sql
from pyspark.sql import SparkSession
@@ -77,56 +77,6 @@ def _fetch_arrow_table_from_executor(executor_actor_name: str,
return reader.read_all()
-class RecordPiece:
- def __init__(self, row_ids, num_rows: int):
- self.row_ids = row_ids
- self.num_rows = num_rows
-
- def read(self, shuffle: bool) -> pd.DataFrame:
- raise NotImplementedError
-
- def with_row_ids(self, new_row_ids) -> "RecordPiece":
- raise NotImplementedError
-
- def __len__(self):
- """Return the number of rows"""
- return self.num_rows
-
-
-class RayObjectPiece(RecordPiece):
- def __init__(self,
- obj_id: ray.ObjectRef,
- row_ids: Optional[List[int]],
- num_rows: int):
- super().__init__(row_ids, num_rows)
- self.obj_id = obj_id
-
- def read(self, shuffle: bool) -> pd.DataFrame:
- data = ray.get(self.obj_id)
- reader = pa.ipc.open_stream(data)
- tb = reader.read_all()
- df: pd.DataFrame = tb.to_pandas()
- if self.row_ids:
- df = df.loc[self.row_ids]
-
- if shuffle:
- df = df.sample(frac=1.0)
- return df
-
- def with_row_ids(self, new_row_ids) -> "RayObjectPiece":
- """chang the num_rows to the length of new_row_ids. Keep the original size if
- the new_row_ids is None.
- """
-
- if new_row_ids:
- num_rows = len(new_row_ids)
- else:
- num_rows = self.num_rows
-
- return RayObjectPiece(self.obj_id, new_row_ids, num_rows)
-
-
-
@dataclass
class PartitionObjectsOwner:
# Actor owner name
@@ -165,23 +115,26 @@ def spark_dataframe_to_ray_dataset(df: sql.DataFrame,
def from_spark_recoverable(df: sql.DataFrame,
storage_level: StorageLevel = StorageLevel.MEMORY_AND_DISK,
parallelism: Optional[int] = None):
- """Recoverable Spark->Ray conversion that survives executor loss."""
+ """Recoverable Spark->Ray conversion that survives executor loss.
+
+ Materialization and Ray fetching are overlapped: as each Spark partition
+ completes, its Ray fetch task is submitted immediately rather than waiting
+ for all partitions to finish first.
+ """
num_part = df.rdd.getNumPartitions()
if parallelism is not None:
if parallelism != num_part:
df = df.repartition(parallelism)
- sc = df.sql_ctx.sparkSession.sparkContext
+ sc = df.sparkSession.sparkContext
storage_level = sc._getJavaStorageLevel(storage_level)
object_store_writer = sc._jvm.org.apache.spark.sql.raydp.ObjectStoreWriter
- # Recoverable conversion for Ray node loss:
- # - cache Arrow bytes in Spark
- # - build Ray Dataset blocks via Ray tasks (lineage), each task refetches bytes via JVM actors
- info = object_store_writer.prepareRecoverableRDD(df._jdf, storage_level)
- rdd_id = info.rddId()
- num_partitions = info.numPartitions()
- schema_json = info.schemaJson()
- driver_agent_url = info.driverAgentUrl()
- locations = info.locations()
+
+ # Start materialization in the background; returns a handle we can poll.
+ handle = object_store_writer.startStreamingRecoverableRDD(df._jdf, storage_level)
+ rdd_id = handle.rddId()
+ num_partitions = handle.numPartitions()
+ schema_json = handle.schemaJson()
+ driver_agent_url = handle.driverAgentUrl()
spark_conf = sc.getConf()
fetch_num_cpus = float(
@@ -189,26 +142,30 @@ def from_spark_recoverable(df: sql.DataFrame,
fetch_memory_str = spark_conf.get(
"spark.ray.raydp_recoverable_fetch.task.resource.memory", "0")
fetch_memory = float(parse_memory_size(fetch_memory_str))
-
-
- refs: List[ObjectRef] = []
- for i in range(num_partitions):
- executor_id = locations[i]
- executor_actor_name = f"raydp-executor-{executor_id}"
- task_opts = {
- "num_cpus": fetch_num_cpus,
- "memory": fetch_memory,
- }
- fetch_task = _fetch_arrow_table_from_executor.options(**task_opts)
- refs.append(
- fetch_task.remote(
- executor_actor_name,
- rdd_id,
- i,
- schema_json,
- driver_agent_url,
- )
- )
+ task_opts = {"num_cpus": fetch_num_cpus, "memory": fetch_memory}
+ fetch_task = _fetch_arrow_table_from_executor.options(**task_opts)
+
+ # Poll for completed partitions and dispatch Ray fetch tasks as they become ready.
+ refs: List[Optional[ObjectRef]] = [None] * num_partitions
+ dispatched = set()
+
+ while len(dispatched) < num_partitions:
+ err = handle.getError()
+ if err is not None:
+ raise RuntimeError(f"Spark materialization failed: {err}")
+
+ ready = handle.getReadyPartitions()
+ new_count = 0
+ for i in range(num_partitions):
+ if i not in dispatched and ready[i] is not None:
+ executor_actor_name = f"raydp-executor-{ready[i]}"
+ refs[i] = fetch_task.remote(
+ executor_actor_name, rdd_id, i, schema_json, driver_agent_url)
+ dispatched.add(i)
+ new_count += 1
+
+ if len(dispatched) < num_partitions and new_count == 0:
+ time.sleep(0.1)
return from_arrow_refs(refs)
@@ -226,45 +183,29 @@ def _convert_by_udf(spark: sql.SparkSession,
jdf = object_store_reader.createRayObjectRefDF(spark._jsparkSession, locations)
current_namespace = ray.get_runtime_context().namespace
ray_address = ray.get(holder.get_ray_address.remote())
- blocks_df = DataFrame(jdf, spark._wrapped if hasattr(spark, "_wrapped") else spark)
- def _convert_blocks_to_dataframe(blocks):
+ blocks_df = DataFrame(jdf, spark)
+ def _convert_blocks_to_batches(batches):
# connect to ray
if not ray.is_initialized():
ray.init(address=ray_address,
namespace=current_namespace,
logging_level=logging.WARN)
obj_holder = ray.get_actor(holder_name)
- for block in blocks:
- dfs = []
- for idx in block["idx"]:
- ref = ray.get(obj_holder.get_object.remote(df_id, idx))
- data = ray.get(ref)
- dfs.append(data.to_pandas())
- yield pd.concat(dfs)
- df = blocks_df.mapInPandas(_convert_blocks_to_dataframe, schema)
+ for batch in batches:
+ indices = batch.column("idx").to_pylist()
+ # Batch both actor lookups and data fetches so Ray can pipeline them
+ ref_futures = [obj_holder.get_object.remote(df_id, idx) for idx in indices]
+ refs = ray.get(ref_futures)
+ tables = ray.get(list(refs))
+        for fetched in tables:  # per-table yield; avoids pa.concat_tables([]) raising on an empty batch
+            yield from fetched.to_batches()
+ df = blocks_df.mapInArrow(_convert_blocks_to_batches, schema)
return df
-def _convert_by_rdd(spark: sql.SparkSession,
- blocks: Dataset,
- locations: List[bytes],
- schema: StructType) -> DataFrame:
- object_ids = [block.binary() for block in blocks]
- schema_str = schema.json()
- jvm = spark.sparkContext._jvm
- # create rdd in java
- rdd = jvm.org.apache.spark.rdd.RayDatasetRDD(spark._jsc, object_ids, locations)
- # convert the rdd to dataframe
- object_store_reader = jvm.org.apache.spark.sql.raydp.ObjectStoreReader
- jdf = object_store_reader.RayDatasetToDataFrame(spark._jsparkSession, rdd, schema_str)
- return DataFrame(jdf, spark._wrapped if hasattr(spark, "_wrapped") else spark)
-
@client_mode_wrap
def get_locations(blocks):
- core_worker = ray.worker.global_worker.core_worker
- return [
- core_worker.get_owner_address(block)
- for block in blocks
- ]
+ core_worker = _ray_global_worker.core_worker
+ return [core_worker.get_owner_address(b) for b in blocks]
def ray_dataset_to_spark_dataframe(spark: sql.SparkSession,
arrow_schema,
@@ -279,14 +220,7 @@ def ray_dataset_to_spark_dataframe(spark: sql.SparkSession,
schema = StructType()
for field in arrow_schema:
schema.add(field.name, from_arrow_type(field.type), nullable=field.nullable)
- #TODO how to branch on type of block?
- sample = ray.get(blocks[0])
- if isinstance(sample, bytes):
- return _convert_by_rdd(spark, blocks, locations, schema)
- elif isinstance(sample, pa.Table):
- return _convert_by_udf(spark, blocks, locations, schema)
- else:
- raise RuntimeError("ray.to_spark only supports arrow type blocks")
+ return _convert_by_udf(spark, blocks, locations, schema)
def read_spark_parquet(path: str) -> Dataset:
diff --git a/python/raydp/spark/ray_cluster.py b/python/raydp/spark/ray_cluster.py
index 10816d25..2f04be1a 100644
--- a/python/raydp/spark/ray_cluster.py
+++ b/python/raydp/spark/ray_cluster.py
@@ -119,6 +119,18 @@ def _prepare_spark_configs(self):
self._configs["spark.executor.instances"] = str(self._num_executors)
self._configs["spark.executor.cores"] = str(self._executor_cores)
self._configs["spark.executor.memory"] = str(self._executor_memory)
+ # Enable Arrow IPC lz4 compression by default (Spark 4.1+).
+ # lz4 is faster than zstd for intra-cluster transfers where RayDP operates.
+ # Users can override by passing their own value for this config.
+ arrow_codec_key = "spark.sql.execution.arrow.compression.codec"
+ if arrow_codec_key not in self._configs:
+ self._configs[arrow_codec_key] = "lz4"
+ # RayDP converts whole partitions (not streaming UDFs), so one Arrow batch
+ # per partition minimizes per-batch overhead. 0 = unlimited rows per batch.
+ arrow_batch_key = "spark.sql.execution.arrow.maxRecordsPerBatch"
+ if arrow_batch_key not in self._configs:
+ self._configs[arrow_batch_key] = "0"
+
if platform.system() != "Darwin":
driver_node_ip = ray.util.get_node_ip_address()
if "spark.driver.host" not in self._configs:
diff --git a/python/raydp/tf/estimator.py b/python/raydp/tf/estimator.py
index 1eb3bf1d..c7615ff9 100644
--- a/python/raydp/tf/estimator.py
+++ b/python/raydp/tf/estimator.py
@@ -15,20 +15,22 @@
# limitations under the License.
#
+import os
from packaging import version
import platform
import tempfile
from typing import Any, Dict, List, NoReturn, Optional, Union
+import numpy as np
+
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.callbacks import Callback
import ray
-from ray.air import session
-from ray.air.config import ScalingConfig, RunConfig, FailureConfig
-from ray.train import Checkpoint
-from ray.train.tensorflow import TensorflowCheckpoint, TensorflowTrainer
+import ray.train
+from ray.train import ScalingConfig, RunConfig, FailureConfig, Checkpoint
+from ray.train.tensorflow import TensorflowTrainer
from ray.data.dataset import Dataset
from ray.data.preprocessors import Concatenator
@@ -168,46 +170,53 @@ def build_and_compile_model(config):
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
return model
+ @staticmethod
+ def _materialize_tf_dataset(data_iter, feature_cols, label_cols,
+ batch_size, drop_last):
+ """Materialize a Ray DataIterator into a finite tf.data.Dataset."""
+ batches = list(data_iter.iter_batches(batch_format="numpy"))
+ def _concat(col):
+ if isinstance(col, str):
+ return np.concatenate([b[col] for b in batches])
+ return {c: np.concatenate([b[c] for b in batches]) for c in col}
+ ds = tf.data.Dataset.from_tensor_slices(
+ (_concat(feature_cols), _concat(label_cols)))
+ return ds.batch(batch_size, drop_remainder=drop_last)
+
@staticmethod
def train_func(config):
- strategy = tf.distribute.MultiWorkerMirroredStrategy()
- with strategy.scope():
- # Model building/compiling need to be within `strategy.scope()`.
- multi_worker_model = TFEstimator.build_and_compile_model(config)
+ # NOTE: MultiWorkerMirroredStrategy is incompatible with Keras 3
+ # (PerReplica conversion error). See:
+ # https://github.com/keras-team/keras/issues/20585
+ # https://github.com/ray-project/ray/issues/47464
+ # Each Ray worker trains independently on its data shard instead.
+ model = TFEstimator.build_and_compile_model(config)
- train_dataset = session.get_dataset_shard("train")
- train_tf_dataset = train_dataset.to_tf(
- feature_columns=config["feature_columns"],
- label_columns=config["label_columns"],
- batch_size=config["batch_size"],
- drop_last=config["drop_last"]
- )
+ train_tf_dataset = TFEstimator._materialize_tf_dataset(
+ ray.train.get_dataset_shard("train"),
+ config["feature_columns"], config["label_columns"],
+ config["batch_size"], config["drop_last"])
if config["evaluate"]:
- eval_dataset = session.get_dataset_shard("evaluate")
- eval_tf_dataset = eval_dataset.to_tf(
- feature_columns=config["feature_columns"],
- label_columns=config["label_columns"],
- batch_size=config["batch_size"],
- drop_last=config["drop_last"]
- )
+ eval_tf_dataset = TFEstimator._materialize_tf_dataset(
+ ray.train.get_dataset_shard("evaluate"),
+ config["feature_columns"], config["label_columns"],
+ config["batch_size"], config["drop_last"])
results = []
callbacks = config["callbacks"]
for _ in range(config["num_epochs"]):
- train_history = multi_worker_model.fit(train_tf_dataset, callbacks=callbacks)
+ train_history = model.fit(train_tf_dataset, callbacks=callbacks)
results.append(train_history.history)
if config["evaluate"]:
- test_history = multi_worker_model.evaluate(eval_tf_dataset, callbacks=callbacks)
+ test_history = model.evaluate(eval_tf_dataset, callbacks=callbacks)
results.append(test_history)
- # Only save checkpoint from the chief worker to avoid race conditions.
- # However, we need to call save on all workers to avoid deadlock.
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
- multi_worker_model.save(temp_checkpoint_dir, save_format="tf")
+ model.save(os.path.join(temp_checkpoint_dir, "model.keras"))
checkpoint = None
- if session.get_world_rank() == 0:
+ if ray.train.get_context().get_world_rank() == 0:
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
- session.report({}, checkpoint=checkpoint)
+ ray.train.report({}, checkpoint=checkpoint)
def fit(self,
train_ds: Dataset,
@@ -280,7 +289,7 @@ def fit_on_spark(self,
train_df = self._check_and_convert(train_df)
evaluate_ds = None
if fs_directory is not None:
- app_id = train_df.sql_ctx.sparkSession.sparkContext.applicationId
+ app_id = train_df.sparkSession.sparkContext.applicationId
path = fs_directory.rstrip("/") + f"/{app_id}"
train_df.write.parquet(path+"/train", compression=compression)
train_ds = read_spark_parquet(path+"/train")
@@ -291,7 +300,7 @@ def fit_on_spark(self,
else:
owner = None
if stop_spark_after_conversion:
- owner = get_raydp_master_owner(train_df.sql_ctx.sparkSession)
+ owner = get_raydp_master_owner(train_df.sparkSession)
train_ds = spark_dataframe_to_ray_dataset(train_df,
owner=owner)
if evaluate_df is not None:
@@ -305,6 +314,5 @@ def fit_on_spark(self,
def get_model(self) -> Any:
assert self._trainer, "Trainer has not been created"
- return TensorflowCheckpoint.from_saved_model(
- self._results.checkpoint.to_directory()
- ).get_model()
+ checkpoint_dir = self._results.checkpoint.to_directory()
+ return keras.models.load_model(os.path.join(checkpoint_dir, "model.keras"))
diff --git a/python/raydp/torch/estimator.py b/python/raydp/torch/estimator.py
index 4b4ba4fb..31bf9ec9 100644
--- a/python/raydp/torch/estimator.py
+++ b/python/raydp/torch/estimator.py
@@ -30,11 +30,9 @@
from raydp.spark import spark_dataframe_to_ray_dataset, get_raydp_master_owner
from raydp.spark.dataset import read_spark_parquet
from raydp.torch.config import TorchConfig
-from ray import train
-from ray.train import Checkpoint
-from ray.train.torch import TorchTrainer, TorchCheckpoint
-from ray.air.config import ScalingConfig, RunConfig, FailureConfig
-from ray.air import session
+import ray.train
+from ray.train import Checkpoint, ScalingConfig, RunConfig, FailureConfig
+from ray.train.torch import TorchTrainer
from ray.data.dataset import Dataset
from ray.tune.search.sample import Domain
@@ -223,7 +221,7 @@ def train_func(config):
metrics = config["metrics"]
# create dataset
- train_data_shard = session.get_dataset_shard("train")
+ train_data_shard = ray.train.get_dataset_shard("train")
train_dataset = train_data_shard.to_torch(feature_columns=config["feature_columns"],
feature_column_dtypes=config["feature_types"],
label_column=config["label_column"],
@@ -231,7 +229,7 @@ def train_func(config):
batch_size=config["batch_size"],
drop_last=config["drop_last"])
if config["evaluate"]:
- evaluate_data_shard = session.get_dataset_shard("evaluate")
+ evaluate_data_shard = ray.train.get_dataset_shard("evaluate")
evaluate_dataset = evaluate_data_shard.to_torch(
feature_columns=config["feature_columns"],
label_column=config["label_column"],
@@ -240,16 +238,16 @@ def train_func(config):
batch_size=config["batch_size"],
drop_last=config["drop_last"])
- model = train.torch.prepare_model(model)
+ model = ray.train.torch.prepare_model(model)
loss_results = []
for epoch in range(config["num_epochs"]):
train_res, train_loss = TorchEstimator.train_epoch(train_dataset, model, loss,
optimizer, metrics, lr_scheduler)
- session.report(dict(epoch=epoch, train_res=train_res, train_loss=train_loss))
+ ray.train.report(dict(epoch=epoch, train_res=train_res, train_loss=train_loss))
if config["evaluate"]:
eval_res, evaluate_loss = TorchEstimator.evaluate_epoch(evaluate_dataset,
model, loss, metrics)
- session.report(dict(epoch=epoch, eval_res=eval_res, test_loss=evaluate_loss))
+ ray.train.report(dict(epoch=epoch, eval_res=eval_res, test_loss=evaluate_loss))
loss_results.append(evaluate_loss)
if hasattr(model, "module"):
states = model.module.state_dict()
@@ -260,14 +258,14 @@ def train_func(config):
checkpoint = None
# In standard DDP training, where the model is the same across all ranks,
# only the global rank 0 worker needs to save and report the checkpoint
- if train.get_context().get_world_rank() == 0:
+ if ray.train.get_context().get_world_rank() == 0:
torch.save(
states,
os.path.join(temp_checkpoint_dir, "model.pt"),
)
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
- session.report({}, checkpoint=checkpoint)
+ ray.train.report({}, checkpoint=checkpoint)
@staticmethod
def train_epoch(dataset, model, criterion, optimizer, metrics, scheduler=None):
@@ -366,7 +364,7 @@ def fit_on_spark(self,
train_df = self._check_and_convert(train_df)
evaluate_ds = None
if fs_directory is not None:
- app_id = train_df.sql_ctx.sparkSession.sparkContext.applicationId
+ app_id = train_df.sparkSession.sparkContext.applicationId
path = fs_directory.rstrip("/") + f"/{app_id}"
train_df.write.parquet(path+"/train", compression=compression)
train_ds = read_spark_parquet(path+"/train")
@@ -377,7 +375,7 @@ def fit_on_spark(self,
else:
owner = None
if stop_spark_after_conversion:
- owner = get_raydp_master_owner(train_df.sql_ctx.sparkSession)
+ owner = get_raydp_master_owner(train_df.sparkSession)
train_ds = spark_dataframe_to_ray_dataset(train_df,
owner=owner)
if evaluate_df is not None:
@@ -391,6 +389,14 @@ def fit_on_spark(self,
def get_model(self):
assert self._trainer is not None, "Must call fit first"
- return TorchCheckpoint(
- self._trained_results.checkpoint.to_directory()
- ).get_model(self._model)
+ checkpoint_dir = self._trained_results.checkpoint.to_directory()
+ state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"), weights_only=True)
+ if isinstance(self._model, torch.nn.Module):
+ self._model.load_state_dict(state_dict)
+ return self._model
+ elif callable(self._model):
+ model = self._model({})
+ model.load_state_dict(state_dict)
+ return model
+ else:
+ raise ValueError("Cannot load model: unsupported model type")
diff --git a/python/raydp/xgboost/estimator.py b/python/raydp/xgboost/estimator.py
index 0b6ac1f6..3ecc0a37 100644
--- a/python/raydp/xgboost/estimator.py
+++ b/python/raydp/xgboost/estimator.py
@@ -15,16 +15,21 @@
# limitations under the License.
#
+import os
from typing import Any, Callable, List, NoReturn, Optional, Union, Dict
+import pandas as pd
+import xgboost
+import ray.train
+from ray.train import ScalingConfig, RunConfig, FailureConfig, CheckpointConfig, Checkpoint
+from ray.train.xgboost import XGBoostTrainer, RayTrainReportCallback
+from ray.data.dataset import Dataset
+
from raydp.estimator import EstimatorInterface
from raydp.spark.interfaces import SparkEstimatorInterface, DF, OPTIONAL_DF
from raydp import stop_spark
from raydp.spark import spark_dataframe_to_ray_dataset, get_raydp_master_owner
from raydp.spark.dataset import read_spark_parquet
-from ray.air.config import ScalingConfig, RunConfig, FailureConfig, CheckpointConfig
-from ray.data.dataset import Dataset
-from ray.train.xgboost import XGBoostTrainer, XGBoostCheckpoint
class XGBoostEstimator(EstimatorInterface, SparkEstimatorInterface):
def __init__(self,
@@ -55,13 +60,24 @@ def fit(self,
train_ds: Dataset,
evaluate_ds: Optional[Dataset] = None,
max_retries=3) -> NoReturn:
+ label_column = self._label_column
+ xgboost_params = self._xgboost_params
+
+ def train_loop_per_worker(config):
+ train_shard = ray.train.get_dataset_shard("train")
+ train_df = pd.concat(list(train_shard.iter_batches(batch_format="pandas")))
+ train_y = train_df.pop(label_column)
+ dtrain = xgboost.DMatrix(train_df, label=train_y)
+ xgboost.train(
+ xgboost_params,
+ dtrain,
+ callbacks=[RayTrainReportCallback()],
+ )
+
scaling_config = ScalingConfig(num_workers=self._num_workers,
resources_per_worker=self._resources_per_worker)
run_config = RunConfig(
checkpoint_config=CheckpointConfig(
- # Checkpoint every iteration.
- checkpoint_frequency=1,
- # Only keep the latest checkpoint and delete the others.
num_to_keep=1,
),
failure_config=FailureConfig(max_failures=max_retries)
@@ -73,11 +89,10 @@ def fit(self,
datasets = {"train": train_ds}
if evaluate_ds:
datasets["evaluate"] = evaluate_ds
- trainer = XGBoostTrainer(scaling_config=scaling_config,
- datasets=datasets,
- label_column=self._label_column,
- params=self._xgboost_params,
- run_config=run_config)
+ trainer = XGBoostTrainer(train_loop_per_worker,
+ scaling_config=scaling_config,
+ datasets=datasets,
+ run_config=run_config)
self._results = trainer.fit()
def fit_on_spark(self,
@@ -90,7 +105,7 @@ def fit_on_spark(self,
train_df = self._check_and_convert(train_df)
evaluate_ds = None
if fs_directory is not None:
- app_id = train_df.sql_ctx.sparkSession.sparkContext.applicationId
+ app_id = train_df.sparkSession.sparkContext.applicationId
path = fs_directory.rstrip("/") + f"/{app_id}"
train_df.write.parquet(path+"/train", compression=compression)
train_ds = read_spark_parquet(path+"/train")
@@ -101,7 +116,7 @@ def fit_on_spark(self,
else:
owner = None
if stop_spark_after_conversion:
- owner = get_raydp_master_owner(train_df.sql_ctx.sparkSession)
+ owner = get_raydp_master_owner(train_df.sparkSession)
train_ds = spark_dataframe_to_ray_dataset(train_df,
parallelism=self._num_workers,
owner=owner)
@@ -116,4 +131,7 @@ def fit_on_spark(self,
train_ds, evaluate_ds, max_retries)
def get_model(self):
- return XGBoostTrainer.get_model(self._results.checkpoint)
+ checkpoint_dir = self._results.checkpoint.to_directory()
+ model = xgboost.Booster()
+ model.load_model(os.path.join(checkpoint_dir, RayTrainReportCallback.CHECKPOINT_NAME))
+ return model
diff --git a/python/setup.py b/python/setup.py
index dc9bbcc7..0ac907e3 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -22,12 +22,11 @@
from datetime import datetime
from shutil import copy2, rmtree
-from grpc_tools.command import build_package_protos
from setuptools import find_packages, setup, Command
build_mode = os.getenv("RAYDP_BUILD_MODE", "")
package_name = os.getenv("RAYDP_PACKAGE_NAME", "raydp")
-BASE_VERSION = "1.7.0"
+BASE_VERSION = "2.0.0"
if build_mode == "nightly":
VERSION = BASE_VERSION + datetime.today().strftime("b%Y%m%d.dev0")
# for legacy raydp_nightly package
@@ -38,30 +37,14 @@
ROOT_DIR = os.path.dirname(__file__)
-TEMP_PATH = "deps"
CORE_DIR = os.path.abspath("../core")
BIN_DIR = os.path.abspath("../bin")
JARS_PATH = glob.glob(os.path.join(CORE_DIR, f"**/target/raydp-*.jar"), recursive=True)
-JARS_TARGET = os.path.join(TEMP_PATH, "jars")
+JARS_TARGET = os.path.join(ROOT_DIR, "raydp", "jars")
SCRIPT_PATH = os.path.join(BIN_DIR, f"raydp-submit")
-SCRIPT_TARGET = os.path.join(TEMP_PATH, "bin")
-
-if len(JARS_PATH) == 0:
- print("Can't find core module jars, you need to build the jars with 'mvn clean package'"
- " under core directory first.", file=sys.stderr)
- sys.exit(-1)
-
-# build the temp dir
-try:
- os.mkdir(TEMP_PATH)
- os.mkdir(JARS_TARGET)
- os.mkdir(SCRIPT_TARGET)
-except:
- print(f"Temp path for symlink to parent already exists {TEMP_PATH}", file=sys.stderr)
- sys.exit(-1)
-
+SCRIPT_TARGET = os.path.join(ROOT_DIR, "raydp", "bin")
class CustomBuildPackageProtos(Command):
"""Command to generate project *_pb2.py modules from proto files.
@@ -80,6 +63,7 @@ def finalize_options(self):
pass
def run(self):
+ from grpc_tools.command import build_package_protos
# due to limitations of the proto generator, we require that only *one*
# directory is provided as an 'include' directory. We assume it's the '' key
# to `self.distribution.package_dir` (and get a key error if it's not
@@ -87,24 +71,29 @@ def run(self):
build_package_protos(self.distribution.package_dir["mpi_network_proto"],
self.strict_mode)
+if __name__ == "__main__":
+ if len(JARS_PATH) == 0:
+ print("Can't find core module jars, you need to build the jars with 'mvn clean package'"
+ " under core directory first.", file=sys.stderr)
+ sys.exit(-1)
+
+ # rebuild the in-package jars/bin dirs (raydp/jars, raydp/bin)
+ if os.path.exists(JARS_TARGET):
+ rmtree(JARS_TARGET)
+ if os.path.exists(SCRIPT_TARGET):
+ rmtree(SCRIPT_TARGET)
+ os.mkdir(JARS_TARGET)
+ os.mkdir(SCRIPT_TARGET)
+ with open(os.path.join(JARS_TARGET, "__init__.py"), "w") as f:
+ f.write("")
+ with open(os.path.join(SCRIPT_TARGET, "__init__.py"), "w") as f:
+ f.write("")
-try:
for jar_path in JARS_PATH:
print(f"Copying {jar_path} to {JARS_TARGET}")
copy2(jar_path, JARS_TARGET)
copy2(SCRIPT_PATH, SCRIPT_TARGET)
- install_requires = [
- "numpy",
- "pandas >= 1.1.4",
- "psutil",
- "pyarrow >= 4.0.1",
- "ray >= 2.37.0",
- "pyspark >= 3.1.1, <=3.5.7",
- "netifaces",
- "protobuf > 3.19.5"
- ]
-
_packages = find_packages()
_packages.append("raydp.jars")
_packages.append("raydp.bin")
@@ -112,12 +101,6 @@ def run(self):
setup(
name=package_name,
version=VERSION,
- author="RayDP Developers",
- author_email="raydp-dev@googlegroups.com",
- license="Apache 2.0",
- url="https://github.com/ray-project/raydp",
- keywords="raydp spark ray distributed data-processing",
- description="RayDP: Distributed Data Processing on Ray",
long_description=io.open(
os.path.join(ROOT_DIR, os.path.pardir, "README.md"),
"r",
@@ -125,23 +108,9 @@ def run(self):
long_description_content_type="text/markdown",
packages=_packages,
include_package_data=True,
- package_dir={"raydp.jars": "deps/jars", "raydp.bin": "deps/bin",
- "mpi_network_proto": "raydp/mpi/network"},
+ package_dir={"mpi_network_proto": "raydp/mpi/network"},
package_data={"raydp.jars": ["*.jar"], "raydp.bin": ["raydp-submit"]},
cmdclass={
'build_proto_modules': CustomBuildPackageProtos,
},
- install_requires=install_requires,
- setup_requires=["grpcio-tools"],
- python_requires='>=3.6',
- classifiers=[
- 'License :: OSI Approved :: Apache Software License',
- 'Programming Language :: Python :: 3.8',
- 'Programming Language :: Python :: 3.9',
- 'Programming Language :: Python :: 3.10',
- ]
)
-finally:
- rmtree(os.path.join(TEMP_PATH, "jars"))
- rmtree(os.path.join(TEMP_PATH, "bin"))
- os.rmdir(TEMP_PATH)