From a45140ed4a8c31db9d9166e9accd2439a17a345a Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Sat, 31 Jan 2026 23:57:59 -0800 Subject: [PATCH 01/34] do one hop forward fetch if recache data change executor --- .../apache/spark/executor/RayDPExecutor.scala | 88 ++++++++++++++++++- 1 file changed, 86 insertions(+), 2 deletions(-) diff --git a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala index afc8de0b..cb8ff3b4 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala @@ -25,6 +25,7 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.reflect.classTag import com.intel.raydp.shims.SparkShimLoader +import io.ray.api.ActorHandle import io.ray.api.Ray import io.ray.runtime.config.RayConfig import org.apache.arrow.vector.ipc.{ArrowStreamWriter, WriteChannel} @@ -309,11 +310,52 @@ class RayDPExecutor( env.shutdown } - def getRDDPartition( + private def parseExecutorIdFromLocation(loc: String): String = { + loc.substring(loc.lastIndexOf('_') + 1) + } + + /** Refresh the current executor ID that owns a cached Spark block, if any. */ + private def getCurrentBlockOwnerExecutorId(blockId: BlockId): Option[String] = { + val env = SparkEnv.get + val locations = BlockManager.blockIdsToLocations(Array(blockId), env) + val locs = locations.getOrElse(blockId, Seq.empty[String]) + if (locs.nonEmpty) { + Some(parseExecutorIdFromLocation(locs.head)) + } else { + None + } + } + + /** + * Map a (potentially restarted) Spark executor ID to the Ray actor-name executor ID. + * + * When a RayDP executor actor restarts, it keeps its Ray actor name, but Spark may assign a new + * executor ID. RayAppMaster tracks a mapping (new -> old). We must use the old ID to resolve + * the Ray actor by name. 
+ */ + private def resolveRayActorExecutorId(sparkExecutorId: String): String = { + try { + val appMasterHandle = + Ray.getActor(RayAppMaster.ACTOR_NAME).get.asInstanceOf[ActorHandle[RayAppMaster]] + val restartedExecutors = RayAppMasterUtils.getRestartedExecutors(appMasterHandle) + if (restartedExecutors != null && restartedExecutors.containsKey(sparkExecutorId)) { + restartedExecutors.get(sparkExecutorId) + } else { + sparkExecutorId + } + } catch { + case _: Throwable => + // Best-effort: if we cannot query the app master for any reason, fall back to the given ID. + sparkExecutorId + } + } + + private def getRDDPartitionInternal( rddId: Int, partitionId: Int, schemaStr: String, - driverAgentUrl: String): Array[Byte] = { + driverAgentUrl: String, + allowForward: Boolean): Array[Byte] = { while (!started.get) { // wait until executor is started // this might happen if executor restarts @@ -339,6 +381,30 @@ class RayDPExecutor( env.blockManager.get(blockId)(classTag[Array[Byte]]) match { case Some(blockResult) => blockResult.data.asInstanceOf[Iterator[Array[Byte]]] + case None if allowForward => + // The block may have been (re)cached on a different executor after recache. + val ownerOpt = getCurrentBlockOwnerExecutorId(blockId) + ownerOpt match { + case Some(ownerSparkExecutorId) if ownerSparkExecutorId != executorId => + val ownerRayExecutorId = resolveRayActorExecutorId(ownerSparkExecutorId) + logWarning( + s"Cached block $blockId not found on executor $executorId after recache. " + + s"Forwarding fetch to executor $ownerSparkExecutorId " + + s"(ray actor id $ownerRayExecutorId).") + val otherHandle = + Ray.getActor("raydp-executor-" + ownerRayExecutorId).get + .asInstanceOf[ActorHandle[RayDPExecutor]] + // One-hop forward only: call no-forward variant on the target executor and + // return the Arrow IPC bytes directly. 
+ return otherHandle + .task( + (e: RayDPExecutor) => + e.getRDDPartitionNoForward(rddId, partitionId, schemaStr, driverAgentUrl)) + .remote() + .get() + case Some(_) | None => + throw new RayDPException("Still cannot get the block after recache!") + } case None => throw new RayDPException("Still cannot get the block after recache!") } @@ -353,4 +419,22 @@ class RayDPExecutor( byteOut.close() result } + + /** Public entry-point used by cross-language calls. Allows forwarding. */ + def getRDDPartition( + rddId: Int, + partitionId: Int, + schemaStr: String, + driverAgentUrl: String): Array[Byte] = { + getRDDPartitionInternal(rddId, partitionId, schemaStr, driverAgentUrl, allowForward = true) + } + + /** Internal one-hop target to prevent forward loops. */ + def getRDDPartitionNoForward( + rddId: Int, + partitionId: Int, + schemaStr: String, + driverAgentUrl: String): Array[Byte] = { + getRDDPartitionInternal(rddId, partitionId, schemaStr, driverAgentUrl, allowForward = false) + } } From 22259a68dd9bb36d88b5059661a18bc522ef9ae5 Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Sun, 1 Feb 2026 00:15:08 -0800 Subject: [PATCH 02/34] more robust executor id parse --- .../org/apache/spark/executor/RayDPExecutor.scala | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala index cb8ff3b4..59f905ad 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala @@ -310,20 +310,11 @@ class RayDPExecutor( env.shutdown } - private def parseExecutorIdFromLocation(loc: String): String = { - loc.substring(loc.lastIndexOf('_') + 1) - } - /** Refresh the current executor ID that owns a cached Spark block, if any. 
*/ private def getCurrentBlockOwnerExecutorId(blockId: BlockId): Option[String] = { val env = SparkEnv.get - val locations = BlockManager.blockIdsToLocations(Array(blockId), env) - val locs = locations.getOrElse(blockId, Seq.empty[String]) - if (locs.nonEmpty) { - Some(parseExecutorIdFromLocation(locs.head)) - } else { - None - } + val locs = env.blockManager.master.getLocations(blockId) + if (locs != null && locs.nonEmpty) Some(locs.head.executorId) else None } /** From 7b505580824d87e5ec618a1fff9319c80ffc84ff Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Sun, 1 Feb 2026 01:03:39 -0800 Subject: [PATCH 03/34] add test --- .../tests/test_recoverable_forwarding.py | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 python/raydp/tests/test_recoverable_forwarding.py diff --git a/python/raydp/tests/test_recoverable_forwarding.py b/python/raydp/tests/test_recoverable_forwarding.py new file mode 100644 index 00000000..9d567d65 --- /dev/null +++ b/python/raydp/tests/test_recoverable_forwarding.py @@ -0,0 +1,78 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import platform + +import pytest +from pyspark.storagelevel import StorageLevel +import ray +import ray.util.client as ray_client + +from raydp.spark import dataset as spark_dataset + + +if platform.system() == "Darwin": + # Spark-on-Ray recoverable path is unstable on macOS and can crash the raylet. + pytest.skip("Skip recoverable forwarding test on macOS", allow_module_level=True) + + +@pytest.mark.parametrize("spark_on_ray_2_executors", ["local"], indirect=True) +def test_recoverable_forwarding_via_fetch_task(spark_on_ray_2_executors): + """Verify JVM-side forwarding in recoverable Spark->Ray conversion. + + We deliberately trigger the recoverable fetch task to contact an executor actor that is not + the current owner of the cached Spark block for the chosen partition. The request should still + succeed because the executor refreshes the block owner and forwards the fetch one hop. + """ + if ray_client.ray.is_connected(): + pytest.skip("Skip forwarding test in Ray client mode") + + spark = spark_on_ray_2_executors + + # Create enough partitions so that at least two different executors own cached blocks. + df = spark.range(0, 10000, numPartitions=8) + + sc = spark.sparkContext + storage_level = sc._getJavaStorageLevel(StorageLevel.MEMORY_AND_DISK) + object_store_writer = sc._jvm.org.apache.spark.sql.raydp.ObjectStoreWriter + + info = object_store_writer.prepareRecoverableRDD(df._jdf, storage_level) + rdd_id = info.rddId() + schema_json = info.schemaJson() + driver_agent_url = info.driverAgentUrl() + locations = list(info.locations()) + + assert locations + unique_execs = sorted(set(locations)) + assert len(unique_execs) >= 2, f"Need >=2 executors, got {unique_execs}" + + # Pick a partition and intentionally target the *wrong* executor actor. + partition_id = 0 + owner_executor_id = locations[partition_id] + wrong_executor_id = next(e for e in unique_execs if e != owner_executor_id) + + # Ensure Ray cross-language calls are enabled for the worker side. 
+ spark_dataset._enable_load_code_from_local() + + wrong_executor_actor_name = f"raydp-executor-{wrong_executor_id}" + table = ray.get( + spark_dataset._fetch_arrow_table_from_executor.remote( + wrong_executor_actor_name, rdd_id, partition_id, schema_json, driver_agent_url + ) + ) + assert table.num_rows > 0 + From 75bff2e79da7e163428342b7ec7720c02bd03d8c Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Sun, 1 Feb 2026 16:29:02 -0800 Subject: [PATCH 04/34] add test --- .../apache/spark/executor/RayDPExecutor.scala | 2 +- .../tests/test_recoverable_forwarding.py | 201 ++++++++++++++---- 2 files changed, 157 insertions(+), 46 deletions(-) diff --git a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala index 59f905ad..0b9fb837 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala @@ -383,7 +383,7 @@ class RayDPExecutor( s"Forwarding fetch to executor $ownerSparkExecutorId " + s"(ray actor id $ownerRayExecutorId).") val otherHandle = - Ray.getActor("raydp-executor-" + ownerRayExecutorId).get + Ray.getActor("raydp-executor-" + ownerRayExecutorId).get() .asInstanceOf[ActorHandle[RayDPExecutor]] // One-hop forward only: call no-forward variant on the target executor and // return the Arrow IPC bytes directly. diff --git a/python/raydp/tests/test_recoverable_forwarding.py b/python/raydp/tests/test_recoverable_forwarding.py index 9d567d65..6f205d93 100644 --- a/python/raydp/tests/test_recoverable_forwarding.py +++ b/python/raydp/tests/test_recoverable_forwarding.py @@ -15,64 +15,175 @@ # limitations under the License. 
# -import platform - +import os import pytest +import pyarrow as pa from pyspark.storagelevel import StorageLevel import ray +from ray.cluster_utils import Cluster +from ray.data import from_arrow_refs import ray.util.client as ray_client +import raydp + +try: + # Ray cross-language calls require enabling load_code_from_local. + # This is an internal Ray API; keep it isolated and optional. + from ray._private.worker import global_worker as _ray_global_worker # type: ignore +except Exception: # pragma: no cover + _ray_global_worker = None + +@ray.remote(max_retries=-1) +def _fetch_arrow_table_from_executor( + executor_actor_name: str, + rdd_id: int, + partition_id: int, + schema_json: str, + driver_agent_url: str, +) -> pa.Table: + """Fetch Arrow table bytes from a JVM executor actor and decode to `pyarrow.Table`. + + This is a test-local version of RayDP's recoverable fetch task. Keeping it in this test + avoids Ray remote function registration issues when driver/workers import different `raydp` + versions. + """ + if _ray_global_worker is not None: + _ray_global_worker.set_load_code_from_local(True) -from raydp.spark import dataset as spark_dataset - + executor_actor = ray.get_actor(executor_actor_name) + ipc_bytes = ray.get( + executor_actor.getRDDPartition.remote( + rdd_id, partition_id, schema_json, driver_agent_url + ) + ) + reader = pa.ipc.open_stream(pa.BufferReader(ipc_bytes)) + table = reader.read_all() + # Match RayDP behavior: strip schema metadata for stability. + table = table.replace_schema_metadata() + return table -if platform.system() == "Darwin": - # Spark-on-Ray recoverable path is unstable on macOS and can crash the raylet. 
- pytest.skip("Skip recoverable forwarding test on macOS", allow_module_level=True) -@pytest.mark.parametrize("spark_on_ray_2_executors", ["local"], indirect=True) -def test_recoverable_forwarding_via_fetch_task(spark_on_ray_2_executors): +def test_recoverable_forwarding_via_fetch_task(jdk17_extra_spark_configs): """Verify JVM-side forwarding in recoverable Spark->Ray conversion. - We deliberately trigger the recoverable fetch task to contact an executor actor that is not - the current owner of the cached Spark block for the chosen partition. The request should still - succeed because the executor refreshes the block owner and forwards the fetch one hop. + This test intentionally calls the recoverable fetch task on the *wrong* Spark executor actor. + It should still succeed because `RayDPExecutor.getRDDPartition` refreshes the block owner and + forwards the fetch one hop. """ if ray_client.ray.is_connected(): pytest.skip("Skip forwarding test in Ray client mode") - spark = spark_on_ray_2_executors - - # Create enough partitions so that at least two different executors own cached blocks. - df = spark.range(0, 10000, numPartitions=8) - - sc = spark.sparkContext - storage_level = sc._getJavaStorageLevel(StorageLevel.MEMORY_AND_DISK) - object_store_writer = sc._jvm.org.apache.spark.sql.raydp.ObjectStoreWriter - - info = object_store_writer.prepareRecoverableRDD(df._jdf, storage_level) - rdd_id = info.rddId() - schema_json = info.schemaJson() - driver_agent_url = info.driverAgentUrl() - locations = list(info.locations()) - - assert locations - unique_execs = sorted(set(locations)) - assert len(unique_execs) >= 2, f"Need >=2 executors, got {unique_execs}" - - # Pick a partition and intentionally target the *wrong* executor actor. - partition_id = 0 - owner_executor_id = locations[partition_id] - wrong_executor_id = next(e for e in unique_execs if e != owner_executor_id) - - # Ensure Ray cross-language calls are enabled for the worker side. 
- spark_dataset._enable_load_code_from_local() - - wrong_executor_actor_name = f"raydp-executor-{wrong_executor_id}" - table = ray.get( - spark_dataset._fetch_arrow_table_from_executor.remote( - wrong_executor_actor_name, rdd_id, partition_id, schema_json, driver_agent_url - ) + stop_after = os.environ.get("RAYDP_TRACE_STOP_AFTER", "").strip().lower() + fetch_mode = os.environ.get("RAYDP_FETCH_MODE", "task").strip().lower() + cluster = Cluster( + initialize_head=True, + head_node_args={ + "num_cpus": 2, + "resources": {"master": 10}, + "include_dashboard": True, + "dashboard_port": 0, + }, ) - assert table.num_rows > 0 + cluster.add_node(num_cpus=4, resources={"spark_executor": 10}) + + def phase(name: str) -> None: + # Prints are the most reliable breadcrumb if the raylet crashes. + print(f"\n=== PHASE: {name} ===", flush=True) + + def should_stop(name: str) -> bool: + return bool(stop_after) and stop_after == name.lower() + + spark = None + try: + # Single-node Ray is sufficient to reproduce / bisect the crash. 
+ phase("ray.init") + ray.shutdown() + ray.init(address=cluster.address, include_dashboard=False) + if should_stop("ray.init"): + return + + phase("raydp.init_spark") + node_ip = ray.util.get_node_ip_address() + spark = raydp.init_spark( + app_name="test_recoverable_forwarding_via_fetch_task", + num_executors=2, + executor_cores=1, + executor_memory="500M", + configs={ + "spark.driver.host": node_ip, + "spark.driver.bindAddress": node_ip, + **jdk17_extra_spark_configs, + }, + ) + if should_stop("raydp.init_spark"): + return + + phase("spark.range.count") + df = spark.range(0, 10000, numPartitions=8) + _ = df.count() + if should_stop("spark.range.count"): + return + + phase("prepareRecoverableRDD") + sc = spark.sparkContext + storage_level = sc._getJavaStorageLevel(StorageLevel.MEMORY_AND_DISK) + object_store_writer = sc._jvm.org.apache.spark.sql.raydp.ObjectStoreWriter + info = object_store_writer.prepareRecoverableRDD(df._jdf, storage_level) + rdd_id = info.rddId() + schema_json = info.schemaJson() + driver_agent_url = info.driverAgentUrl() + locations = list(info.locations()) + if should_stop("preparerecoverablerdd"): + return + + assert locations + unique_execs = sorted(set(locations)) + assert len(unique_execs) >= 2, f"Need >=2 executors, got {unique_execs}" + + partition_id = 0 + owner_executor_id = locations[partition_id] + wrong_executor_id = next(e for e in unique_execs if e != owner_executor_id) + wrong_executor_actor_name = f"raydp-executor-{wrong_executor_id}" + + phase("fetch_wrong_executor") + + phase("get_wrong_executor_actor") + wrong_executor_actor = ray.get_actor(wrong_executor_actor_name) + if should_stop("get_wrong_executor_actor"): + return + + phase("call_fetch_task") + if fetch_mode == "driver": + phase("driver_call_java_actor") + if _ray_global_worker is not None: + _ray_global_worker.set_load_code_from_local(True) + ipc_bytes = ray.get( + wrong_executor_actor.getRDDPartition.remote( + rdd_id, partition_id, schema_json, driver_agent_url + ) + 
) + reader = pa.ipc.open_stream(pa.BufferReader(ipc_bytes)) + table = reader.read_all() + table = table.replace_schema_metadata() + else: + phase("task_call_java_actor") + refs: list[ray.ObjectRef] = [] + refs.append( + _fetch_arrow_table_from_executor.remote( + wrong_executor_actor_name, + rdd_id, + partition_id, + schema_json, + driver_agent_url, + ) + ) + table = from_arrow_refs(refs) + assert table.count() > 0 + finally: + phase("teardown") + + spark.stop() + raydp.stop_spark() + ray.shutdown() + cluster.shutdown() From 32935ab55d924d26b280866e7306fc83b5560094 Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Mon, 2 Feb 2026 00:30:22 -0800 Subject: [PATCH 05/34] remove test --- .../tests/test_recoverable_forwarding.py | 189 ------------------ 1 file changed, 189 deletions(-) delete mode 100644 python/raydp/tests/test_recoverable_forwarding.py diff --git a/python/raydp/tests/test_recoverable_forwarding.py b/python/raydp/tests/test_recoverable_forwarding.py deleted file mode 100644 index 6f205d93..00000000 --- a/python/raydp/tests/test_recoverable_forwarding.py +++ /dev/null @@ -1,189 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -import os -import pytest -import pyarrow as pa -from pyspark.storagelevel import StorageLevel -import ray -from ray.cluster_utils import Cluster -from ray.data import from_arrow_refs -import ray.util.client as ray_client -import raydp - -try: - # Ray cross-language calls require enabling load_code_from_local. - # This is an internal Ray API; keep it isolated and optional. - from ray._private.worker import global_worker as _ray_global_worker # type: ignore -except Exception: # pragma: no cover - _ray_global_worker = None - -@ray.remote(max_retries=-1) -def _fetch_arrow_table_from_executor( - executor_actor_name: str, - rdd_id: int, - partition_id: int, - schema_json: str, - driver_agent_url: str, -) -> pa.Table: - """Fetch Arrow table bytes from a JVM executor actor and decode to `pyarrow.Table`. - - This is a test-local version of RayDP's recoverable fetch task. Keeping it in this test - avoids Ray remote function registration issues when driver/workers import different `raydp` - versions. - """ - if _ray_global_worker is not None: - _ray_global_worker.set_load_code_from_local(True) - - executor_actor = ray.get_actor(executor_actor_name) - ipc_bytes = ray.get( - executor_actor.getRDDPartition.remote( - rdd_id, partition_id, schema_json, driver_agent_url - ) - ) - reader = pa.ipc.open_stream(pa.BufferReader(ipc_bytes)) - table = reader.read_all() - # Match RayDP behavior: strip schema metadata for stability. - table = table.replace_schema_metadata() - return table - - - -def test_recoverable_forwarding_via_fetch_task(jdk17_extra_spark_configs): - """Verify JVM-side forwarding in recoverable Spark->Ray conversion. - - This test intentionally calls the recoverable fetch task on the *wrong* Spark executor actor. - It should still succeed because `RayDPExecutor.getRDDPartition` refreshes the block owner and - forwards the fetch one hop. 
- """ - if ray_client.ray.is_connected(): - pytest.skip("Skip forwarding test in Ray client mode") - - stop_after = os.environ.get("RAYDP_TRACE_STOP_AFTER", "").strip().lower() - fetch_mode = os.environ.get("RAYDP_FETCH_MODE", "task").strip().lower() - cluster = Cluster( - initialize_head=True, - head_node_args={ - "num_cpus": 2, - "resources": {"master": 10}, - "include_dashboard": True, - "dashboard_port": 0, - }, - ) - cluster.add_node(num_cpus=4, resources={"spark_executor": 10}) - - def phase(name: str) -> None: - # Prints are the most reliable breadcrumb if the raylet crashes. - print(f"\n=== PHASE: {name} ===", flush=True) - - def should_stop(name: str) -> bool: - return bool(stop_after) and stop_after == name.lower() - - spark = None - try: - # Single-node Ray is sufficient to reproduce / bisect the crash. - phase("ray.init") - ray.shutdown() - ray.init(address=cluster.address, include_dashboard=False) - if should_stop("ray.init"): - return - - phase("raydp.init_spark") - node_ip = ray.util.get_node_ip_address() - spark = raydp.init_spark( - app_name="test_recoverable_forwarding_via_fetch_task", - num_executors=2, - executor_cores=1, - executor_memory="500M", - configs={ - "spark.driver.host": node_ip, - "spark.driver.bindAddress": node_ip, - **jdk17_extra_spark_configs, - }, - ) - if should_stop("raydp.init_spark"): - return - - phase("spark.range.count") - df = spark.range(0, 10000, numPartitions=8) - _ = df.count() - if should_stop("spark.range.count"): - return - - phase("prepareRecoverableRDD") - sc = spark.sparkContext - storage_level = sc._getJavaStorageLevel(StorageLevel.MEMORY_AND_DISK) - object_store_writer = sc._jvm.org.apache.spark.sql.raydp.ObjectStoreWriter - info = object_store_writer.prepareRecoverableRDD(df._jdf, storage_level) - rdd_id = info.rddId() - schema_json = info.schemaJson() - driver_agent_url = info.driverAgentUrl() - locations = list(info.locations()) - if should_stop("preparerecoverablerdd"): - return - - assert locations - 
unique_execs = sorted(set(locations)) - assert len(unique_execs) >= 2, f"Need >=2 executors, got {unique_execs}" - - partition_id = 0 - owner_executor_id = locations[partition_id] - wrong_executor_id = next(e for e in unique_execs if e != owner_executor_id) - wrong_executor_actor_name = f"raydp-executor-{wrong_executor_id}" - - phase("fetch_wrong_executor") - - phase("get_wrong_executor_actor") - wrong_executor_actor = ray.get_actor(wrong_executor_actor_name) - if should_stop("get_wrong_executor_actor"): - return - - phase("call_fetch_task") - if fetch_mode == "driver": - phase("driver_call_java_actor") - if _ray_global_worker is not None: - _ray_global_worker.set_load_code_from_local(True) - ipc_bytes = ray.get( - wrong_executor_actor.getRDDPartition.remote( - rdd_id, partition_id, schema_json, driver_agent_url - ) - ) - reader = pa.ipc.open_stream(pa.BufferReader(ipc_bytes)) - table = reader.read_all() - table = table.replace_schema_metadata() - else: - phase("task_call_java_actor") - refs: list[ray.ObjectRef] = [] - refs.append( - _fetch_arrow_table_from_executor.remote( - wrong_executor_actor_name, - rdd_id, - partition_id, - schema_json, - driver_agent_url, - ) - ) - table = from_arrow_refs(refs) - assert table.count() > 0 - finally: - phase("teardown") - - spark.stop() - raydp.stop_spark() - ray.shutdown() - cluster.shutdown() - From 099007fc19c112f1b4e82f0c78529c8004bdc529 Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Mon, 2 Feb 2026 00:33:42 -0800 Subject: [PATCH 06/34] revert change in dataset.py --- python/raydp/spark/dataset.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/raydp/spark/dataset.py b/python/raydp/spark/dataset.py index 19a368e5..06d205f7 100644 --- a/python/raydp/spark/dataset.py +++ b/python/raydp/spark/dataset.py @@ -77,10 +77,6 @@ def _fetch_arrow_table_from_executor(executor_actor_name: str, rdd_id, partition_id, schema_json, driver_agent_url)) reader = pa.ipc.open_stream(pa.BufferReader(ipc_bytes)) table = 
reader.read_all() - # Spark's Arrow conversion may attach schema metadata. Ray Data metadata extraction - # can be sensitive to unexpected schema metadata in some Ray/PyArrow combinations. - # Strip schema metadata to make blocks more portable/deterministic. - table = table.replace_schema_metadata() return table From a489788a67cbbbfb7e0fbe7a4e132def4d083855 Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Mon, 2 Feb 2026 01:08:42 -0800 Subject: [PATCH 07/34] clean up --- .../main/scala/org/apache/spark/executor/RayDPExecutor.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala index 0b9fb837..853c38e1 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala @@ -376,7 +376,7 @@ class RayDPExecutor( // The block may have been (re)cached on a different executor after recache. val ownerOpt = getCurrentBlockOwnerExecutorId(blockId) ownerOpt match { - case Some(ownerSparkExecutorId) if ownerSparkExecutorId != executorId => + case Some(ownerSparkExecutorId) => val ownerRayExecutorId = resolveRayActorExecutorId(ownerSparkExecutorId) logWarning( s"Cached block $blockId not found on executor $executorId after recache. 
" + @@ -393,7 +393,7 @@ class RayDPExecutor( e.getRDDPartitionNoForward(rddId, partitionId, schemaStr, driverAgentUrl)) .remote() .get() - case Some(_) | None => + case None => throw new RayDPException("Still cannot get the block after recache!") } case None => From 2c8df451cb52e44a0ab525fed84d47b8dab6e3fe Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Mon, 2 Feb 2026 01:11:26 -0800 Subject: [PATCH 08/34] clean up --- .../main/scala/org/apache/spark/executor/RayDPExecutor.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala index 853c38e1..9ec06431 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala @@ -394,7 +394,8 @@ class RayDPExecutor( .remote() .get() case None => - throw new RayDPException("Still cannot get the block after recache!") + throw new RayDPException( + s"Still cannot get block $blockId for RDD $rddId after recache!") } case None => throw new RayDPException("Still cannot get the block after recache!") From fb31b09c872cc0b479e776b9bdfd0b746d426c7d Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Mon, 2 Feb 2026 12:33:06 -0800 Subject: [PATCH 09/34] strip off table metadata again --- README.md | 2 +- python/raydp/spark/dataset.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4a780a78..07ed1395 100644 --- a/README.md +++ b/README.md @@ -155,7 +155,7 @@ Please refer to [NYC Taxi PyTorch Estimator](./examples/pytorch_nyctaxi.py) and RayDP now converts Spark DataFrames to Ray Datasets using a recoverable pipeline by default. This makes the resulting Ray Dataset resilient to Spark executor loss (the Arrow IPC bytes are cached in Spark and fetched via Ray tasks with lineage). 
-The recoverable conversion is also available directly via `raydp.spark.from_spark_recoverable`, and it persists (caches) the Spark DataFrame. You can provide the storage level through the `storage_level` keyword parameter. +The recoverable conversion is also available directly via `raydp.spark.from_spark_recoverable`, and it persists (caches) the Spark DataFrame. By default it uses disk-only persistence; you can override this via the `storage_level` keyword parameter. ```python import ray diff --git a/python/raydp/spark/dataset.py b/python/raydp/spark/dataset.py index 06d205f7..ce49f20a 100644 --- a/python/raydp/spark/dataset.py +++ b/python/raydp/spark/dataset.py @@ -77,6 +77,10 @@ def _fetch_arrow_table_from_executor(executor_actor_name: str, rdd_id, partition_id, schema_json, driver_agent_url)) reader = pa.ipc.open_stream(pa.BufferReader(ipc_bytes)) table = reader.read_all() + # Spark's Arrow conversion may attach schema metadata. Ray Data metadata extraction + # can be sensitive to unexpected schema metadata in some Ray/PyArrow combinations. + # Strip schema metadata to make blocks more portable/deterministic. 
+ table = table.replace_schema_metadata() return table @@ -166,7 +170,7 @@ def spark_dataframe_to_ray_dataset(df: sql.DataFrame, return from_spark_recoverable(df, parallelism=parallelism) def from_spark_recoverable(df: sql.DataFrame, - storage_level: StorageLevel = StorageLevel.MEMORY_AND_DISK, + storage_level: StorageLevel = StorageLevel.DISK_ONLY, parallelism: Optional[int] = None): """Recoverable Spark->Ray conversion that survives executor loss.""" original_df = df From b60bffd84d68b024a187840913add8e6a25c8dbe Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Mon, 16 Feb 2026 23:56:47 +0800 Subject: [PATCH 10/34] fix spark gc race condition --- README.md | 2 +- python/raydp/spark/dataset.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 07ed1395..4a780a78 100644 --- a/README.md +++ b/README.md @@ -155,7 +155,7 @@ Please refer to [NYC Taxi PyTorch Estimator](./examples/pytorch_nyctaxi.py) and RayDP now converts Spark DataFrames to Ray Datasets using a recoverable pipeline by default. This makes the resulting Ray Dataset resilient to Spark executor loss (the Arrow IPC bytes are cached in Spark and fetched via Ray tasks with lineage). -The recoverable conversion is also available directly via `raydp.spark.from_spark_recoverable`, and it persists (caches) the Spark DataFrame. By default it uses disk-only persistence; you can override this via the `storage_level` keyword parameter. +The recoverable conversion is also available directly via `raydp.spark.from_spark_recoverable`, and it persists (caches) the Spark DataFrame. You can provide the storage level through the `storage_level` keyword parameter. 
```python import ray diff --git a/python/raydp/spark/dataset.py b/python/raydp/spark/dataset.py index ce49f20a..19a368e5 100644 --- a/python/raydp/spark/dataset.py +++ b/python/raydp/spark/dataset.py @@ -170,7 +170,7 @@ def spark_dataframe_to_ray_dataset(df: sql.DataFrame, return from_spark_recoverable(df, parallelism=parallelism) def from_spark_recoverable(df: sql.DataFrame, - storage_level: StorageLevel = StorageLevel.DISK_ONLY, + storage_level: StorageLevel = StorageLevel.MEMORY_AND_DISK, parallelism: Optional[int] = None): """Recoverable Spark->Ray conversion that survives executor loss.""" original_df = df From 86564cb3748c2e15c4706eea9b734d56b94652f9 Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Tue, 24 Dec 2024 13:20:31 -0800 Subject: [PATCH 11/34] Add spark 3.4.4 and 3.5.4 support From 5334d13e0fc1804a42b2b0d3051e09da4d6be9c7 Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Thu, 21 Aug 2025 09:30:52 -0700 Subject: [PATCH 12/34] Support Spark 4.0.0 --- .github/workflows/ray_nightly_test.yml | 2 +- core/pom.xml | 7 +- .../spark/deploy/raydp/RayAppMaster.scala | 22 +++-- .../spark/sql/raydp/ObjectStoreWriter.scala | 1 - core/shims/pom.xml | 3 +- core/shims/spark322/pom.xml | 2 +- core/shims/spark330/pom.xml | 2 +- core/shims/spark340/pom.xml | 2 +- core/shims/spark350/pom.xml | 2 +- core/shims/spark400/pom.xml | 99 +++++++++++++++++++ .../com.intel.raydp.shims.SparkShimProvider | 1 + .../intel/raydp/shims/SparkShimProvider.scala | 36 +++++++ .../com/intel/raydp/shims/SparkShims.scala | 51 ++++++++++ .../org/apache/spark/TaskContextUtils.scala | 30 ++++++ .../RayCoarseGrainedExecutorBackend.scala | 50 ++++++++++ .../RayDPSpark400ExecutorBackendFactory.scala | 51 ++++++++++ .../org/apache/spark/sql/SparkSqlUtils.scala | 60 +++++++++++ python/raydp/tests/conftest.py | 5 +- python/setup.py | 2 +- 19 files changed, 408 insertions(+), 20 deletions(-) create mode 100644 core/shims/spark400/pom.xml create mode 100644 
core/shims/spark400/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider create mode 100644 core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala create mode 100644 core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShims.scala create mode 100644 core/shims/spark400/src/main/scala/org/apache/spark/TaskContextUtils.scala create mode 100644 core/shims/spark400/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala create mode 100644 core/shims/spark400/src/main/scala/org/apache/spark/executor/RayDPSpark400ExecutorBackendFactory.scala create mode 100644 core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala diff --git a/.github/workflows/ray_nightly_test.yml b/.github/workflows/ray_nightly_test.yml index 95b4eb96..68df8ae4 100644 --- a/.github/workflows/ray_nightly_test.yml +++ b/.github/workflows/ray_nightly_test.yml @@ -32,7 +32,7 @@ jobs: matrix: os: [ ubuntu-latest ] python-version: [3.9, 3.10.14] - spark-version: [3.3.2, 3.4.0, 3.5.0] + spark-version: [3.2.4, 3.3.2, 3.4.0, 3.5.6, 4.0.0] runs-on: ${{ matrix.os }} diff --git a/core/pom.xml b/core/pom.xml index 67d945d7..6641e7fe 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -18,6 +18,7 @@ 3.3.0 3.4.0 3.5.0 + 4.0.0 1.1.10.4 4.1.94.Final 1.10.0 @@ -29,9 +30,9 @@ UTF-8 1.8 1.8 - 2.12.15 - 2.18.6 - 2.12 + 2.13.12 + 2.15.0 + 2.13 5.10.1 diff --git a/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala index f4cc823d..2fa7f9fe 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala @@ -22,15 +22,16 @@ import java.text.SimpleDateFormat import java.util.{Date, Locale} import javax.xml.bind.DatatypeConverter -import scala.collection.JavaConverters._ import scala.collection.mutable.HashMap 
+import scala.jdk.CollectionConverters._ + +import com.fasterxml.jackson.core.JsonFactory +import com.fasterxml.jackson.databind.ObjectMapper import io.ray.api.{ActorHandle, PlacementGroups, Ray} import io.ray.api.id.PlacementGroupId import io.ray.api.placementgroup.PlacementGroup import io.ray.runtime.config.RayConfig -import org.json4s._ -import org.json4s.jackson.JsonMethods._ import org.apache.spark.{RayDPException, SecurityManager, SparkConf} import org.apache.spark.executor.RayDPExecutor @@ -39,6 +40,7 @@ import org.apache.spark.raydp.{RayExecutorUtils, SparkOnRayConfigs} import org.apache.spark.rpc._ import org.apache.spark.util.Utils + class RayAppMaster(host: String, port: Int, actorExtraClasspath: String) extends Serializable with Logging { @@ -298,7 +300,7 @@ class RayAppMaster(host: String, .map { case (name, amount) => (name, Double.box(amount)) }.asJava, placementGroup, getNextBundleIndex, - seqAsJavaList(appInfo.desc.command.javaOpts)) + appInfo.desc.command.javaOpts.asJava) appInfo.addPendingRegisterExecutor(executorId, handler, sparkCoresPerExecutor, memory) } @@ -356,11 +358,15 @@ object RayAppMaster extends Serializable { val ACTOR_NAME = "RAY_APP_MASTER" def setProperties(properties: String): Unit = { - implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats - val parsed = parse(properties).extract[Map[String, String]] - parsed.foreach{ case (key, value) => - System.setProperty(key, value) + // Use Jackson ObjectMapper directly to avoid JSON4S version conflicts + val mapper = new ObjectMapper() + val javaMap = mapper.readValue(properties, classOf[java.util.Map[String, Object]]) + val scalaMap = javaMap.asScala.toMap + scalaMap.foreach{ case (key, value) => + // Convert all values to strings since System.setProperty expects String + System.setProperty(key, value.toString) } + // Use the same session dir as the python side RayConfig.create().setSessionDir(System.getProperty("ray.session-dir")) } diff --git 
a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala index 2d607044..133656dc 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala @@ -24,7 +24,6 @@ import java.util.{List, UUID} import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} import java.util.function.{Function => JFunction} import org.apache.arrow.vector.types.pojo.Schema -import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.spark.{RayDPException, SparkContext} diff --git a/core/shims/pom.xml b/core/shims/pom.xml index c013538b..ac16dba7 100644 --- a/core/shims/pom.xml +++ b/core/shims/pom.xml @@ -21,10 +21,11 @@ spark330 spark340 spark350 + spark400 - 2.12 + 2.13 4.3.0 3.2.2 diff --git a/core/shims/spark322/pom.xml b/core/shims/spark322/pom.xml index faff6ac5..0e9100c1 100644 --- a/core/shims/spark322/pom.xml +++ b/core/shims/spark322/pom.xml @@ -16,7 +16,7 @@ jar - 2.12.15 + 2.13.12 2.13.5 diff --git a/core/shims/spark330/pom.xml b/core/shims/spark330/pom.xml index 4443f658..3e229ade 100644 --- a/core/shims/spark330/pom.xml +++ b/core/shims/spark330/pom.xml @@ -16,7 +16,7 @@ jar - 2.12.15 + 2.13.12 2.13.5 diff --git a/core/shims/spark340/pom.xml b/core/shims/spark340/pom.xml index 1b312747..684309bd 100644 --- a/core/shims/spark340/pom.xml +++ b/core/shims/spark340/pom.xml @@ -16,7 +16,7 @@ jar - 2.12.15 + 2.13.12 2.13.5 diff --git a/core/shims/spark350/pom.xml b/core/shims/spark350/pom.xml index 2368daa2..f33c4a98 100644 --- a/core/shims/spark350/pom.xml +++ b/core/shims/spark350/pom.xml @@ -16,7 +16,7 @@ jar - 2.12.15 + 2.13.12 2.13.5 diff --git a/core/shims/spark400/pom.xml b/core/shims/spark400/pom.xml new file mode 100644 index 00000000..1a1c1e6f --- /dev/null +++ b/core/shims/spark400/pom.xml @@ -0,0 
+1,99 @@ + + + + 4.0.0 + + + com.intel + raydp-shims + 1.7.0-SNAPSHOT + ../pom.xml + + + raydp-shims-spark400 + RayDP Shims for Spark 4.0.0 + jar + + + 2.13.12 + 2.13.5 + + + + + + org.scalastyle + scalastyle-maven-plugin + + + net.alchim31.maven + scala-maven-plugin + 3.2.2 + + + scala-compile-first + process-resources + + compile + + + + scala-test-compile-first + process-test-resources + + testCompile + + + + + + + + + src/main/resources + + + + + + + com.intel + raydp-shims-common + ${project.version} + compile + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark400.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark400.version} + provided + + + org.xerial.snappy + snappy-java + + + io.netty + netty-handler + + + + + org.xerial.snappy + snappy-java + ${snappy.version} + + + io.netty + netty-handler + ${netty.version} + + + diff --git a/core/shims/spark400/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider b/core/shims/spark400/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider new file mode 100644 index 00000000..f88bbd7a --- /dev/null +++ b/core/shims/spark400/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider @@ -0,0 +1 @@ +com.intel.raydp.shims.spark400.SparkShimProvider diff --git a/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala b/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala new file mode 100644 index 00000000..a39b57f6 --- /dev/null +++ b/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.raydp.shims.spark400 + +import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor} + +object SparkShimProvider { + val SPARK400_DESCRIPTOR = SparkShimDescriptor(4, 0, 0) + val DESCRIPTOR_STRINGS = Seq(s"$SPARK400_DESCRIPTOR") + val DESCRIPTOR = SPARK400_DESCRIPTOR +} + +class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider { + def createShim: SparkShims = { + new Spark400Shims() + } + + def matches(version: String): Boolean = { + SparkShimProvider.DESCRIPTOR_STRINGS.contains(version) + } +} diff --git a/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShims.scala new file mode 100644 index 00000000..f815bfbd --- /dev/null +++ b/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.raydp.shims.spark400 + +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.executor.{RayDPExecutorBackendFactory, RayDPSpark400ExecutorBackendFactory} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.spark400.SparkSqlUtils +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.spark400.TaskContextUtils +import com.intel.raydp.shims.{ShimDescriptor, SparkShims} + +class Spark400Shims extends SparkShims { + override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR + + override def toDataFrame( + rdd: JavaRDD[Array[Byte]], + schema: String, + session: SparkSession): DataFrame = { + SparkSqlUtils.toDataFrame(rdd, schema, session) + } + + override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = { + new RayDPSpark400ExecutorBackendFactory() + } + + override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { + TaskContextUtils.getDummyTaskContext(partitionId, env) + } + + override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { + SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) + } +} diff --git a/core/shims/spark400/src/main/scala/org/apache/spark/TaskContextUtils.scala b/core/shims/spark400/src/main/scala/org/apache/spark/TaskContextUtils.scala new file mode 100644 index 00000000..287105cd --- /dev/null +++ 
b/core/shims/spark400/src/main/scala/org/apache/spark/TaskContextUtils.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.spark400 + +import java.util.Properties + +import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} +import org.apache.spark.memory.TaskMemoryManager + +object TaskContextUtils { + def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { + new TaskContextImpl(0, 0, partitionId, -1024, 0, 0, + new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem) + } +} diff --git a/core/shims/spark400/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala b/core/shims/spark400/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala new file mode 100644 index 00000000..2e6b5e25 --- /dev/null +++ b/core/shims/spark400/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import java.net.URL + +import org.apache.spark.SparkEnv +import org.apache.spark.resource.ResourceProfile +import org.apache.spark.rpc.RpcEnv + +class RayCoarseGrainedExecutorBackend( + rpcEnv: RpcEnv, + driverUrl: String, + executorId: String, + bindAddress: String, + hostname: String, + cores: Int, + userClassPath: Seq[URL], + env: SparkEnv, + resourcesFileOpt: Option[String], + resourceProfile: ResourceProfile) + extends CoarseGrainedExecutorBackend( + rpcEnv, + driverUrl, + executorId, + bindAddress, + hostname, + cores, + env, + resourcesFileOpt, + resourceProfile) { + + override def getUserClassPath: Seq[URL] = userClassPath + +} diff --git a/core/shims/spark400/src/main/scala/org/apache/spark/executor/RayDPSpark400ExecutorBackendFactory.scala b/core/shims/spark400/src/main/scala/org/apache/spark/executor/RayDPSpark400ExecutorBackendFactory.scala new file mode 100644 index 00000000..eed998bd --- /dev/null +++ b/core/shims/spark400/src/main/scala/org/apache/spark/executor/RayDPSpark400ExecutorBackendFactory.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import org.apache.spark.SparkEnv +import org.apache.spark.resource.ResourceProfile +import org.apache.spark.rpc.RpcEnv + +import java.net.URL + +class RayDPSpark400ExecutorBackendFactory + extends RayDPExecutorBackendFactory { + override def createExecutorBackend( + rpcEnv: RpcEnv, + driverUrl: String, + executorId: String, + bindAddress: String, + hostname: String, + cores: Int, + userClassPath: Seq[URL], + env: SparkEnv, + resourcesFileOpt: Option[String], + resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = { + new RayCoarseGrainedExecutorBackend( + rpcEnv, + driverUrl, + executorId, + bindAddress, + hostname, + cores, + userClassPath, + env, + resourcesFileOpt, + resourceProfile) + } +} diff --git a/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala new file mode 100644 index 00000000..9f6abced --- /dev/null +++ b/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.spark400 + +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.TaskContext +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.execution.arrow.ArrowConverters +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.catalyst.InternalRow + +object SparkSqlUtils { + def toDataFrame( + arrowBatchRDD: JavaRDD[Array[Byte]], + schemaString: String, + session: SparkSession): DataFrame = { + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + val timeZoneId = session.sessionState.conf.sessionLocalTimeZone + val internalRowRdd = arrowBatchRDD.rdd.mapPartitions { iter => + val context = TaskContext.get() + ArrowConverters.fromBatchIterator( + arrowBatchIter = iter, + schema = schema, + timeZoneId = timeZoneId, + errorOnDuplicatedFieldNames = false, + largeVarTypes = true, + context = context) + } + val rowRdd = internalRowRdd.map { internalRow => + Row.fromSeq(internalRow.toSeq(schema)) + } + session.createDataFrame(rowRdd.setName("arrow"), schema) + } + + def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { + ArrowUtils.toArrowSchema( + schema = schema, + timeZoneId = timeZoneId, + errorOnDuplicatedFieldNames = false, + largeVarTypes = true + ) + } +} diff --git 
a/python/raydp/tests/conftest.py b/python/raydp/tests/conftest.py index a2a43b6d..8e01615f 100644 --- a/python/raydp/tests/conftest.py +++ b/python/raydp/tests/conftest.py @@ -65,7 +65,8 @@ def jdk17_extra_spark_configs() -> Dict[str, str]: @pytest.fixture(scope="function") def spark_session(request, jdk17_extra_spark_configs): - builder = SparkSession.builder.master("local[2]").appName("RayDP test") + builder = SparkSession.builder.master("local[2]").appName("RayDP test") \ + .config("spark.sql.ansi.enabled", "false") for k, v in jdk17_extra_spark_configs.items(): builder = builder.config(k, v) spark = builder.getOrCreate() @@ -98,6 +99,7 @@ def spark_on_ray_small(request, jdk17_extra_spark_configs): extra_configs = { "spark.driver.host": node_ip, "spark.driver.bindAddress": node_ip, + "spark.sql.ansi.enabled": "false", **jdk17_extra_spark_configs } spark = raydp.init_spark("test", 1, 1, "500M", configs=extra_configs) @@ -126,6 +128,7 @@ def spark_on_ray_2_executors(request, jdk17_extra_spark_configs): extra_configs = { "spark.driver.host": node_ip, "spark.driver.bindAddress": node_ip, + "spark.sql.ansi.enabled": "false", **jdk17_extra_spark_configs } spark = raydp.init_spark("test", 2, 1, "500M", configs=extra_configs) diff --git a/python/setup.py b/python/setup.py index 54077281..879b51f7 100644 --- a/python/setup.py +++ b/python/setup.py @@ -100,7 +100,7 @@ def run(self): "psutil", "pyarrow >= 4.0.1", "ray >= 2.37.0", - "pyspark >= 3.1.1, <=3.5.7", + "pyspark >= 3.1.1, <=4.1.1", "protobuf > 3.19.5" ] From b64a28445b0e828219744f340986f73865977c9d Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Mon, 25 Aug 2025 02:20:20 -0700 Subject: [PATCH 13/34] exclude spark 3.x --- .github/workflows/ray_nightly_test.yml | 2 +- .github/workflows/raydp.yml | 4 ++-- python/setup.py | 3 ++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ray_nightly_test.yml b/.github/workflows/ray_nightly_test.yml index 68df8ae4..5c72fc82 100644 --- 
a/.github/workflows/ray_nightly_test.yml +++ b/.github/workflows/ray_nightly_test.yml @@ -32,7 +32,7 @@ jobs: matrix: os: [ ubuntu-latest ] python-version: [3.9, 3.10.14] - spark-version: [3.2.4, 3.3.2, 3.4.0, 3.5.6, 4.0.0] + spark-version: [4.0.0] runs-on: ${{ matrix.os }} diff --git a/.github/workflows/raydp.yml b/.github/workflows/raydp.yml index a24746b9..a2203737 100644 --- a/.github/workflows/raydp.yml +++ b/.github/workflows/raydp.yml @@ -33,8 +33,8 @@ jobs: matrix: os: [ubuntu-latest] python-version: [3.9, 3.10.14] - spark-version: [3.3.2, 3.4.0, 3.5.0] - ray-version: [2.37.0, 2.40.0, 2.50.0] + spark-version: [4.0.0] + ray-version: [2.34.0, 2.40.0, 2.50.0] runs-on: ${{ matrix.os }} diff --git a/python/setup.py b/python/setup.py index 879b51f7..38e31102 100644 --- a/python/setup.py +++ b/python/setup.py @@ -100,7 +100,8 @@ def run(self): "psutil", "pyarrow >= 4.0.1", "ray >= 2.37.0", - "pyspark >= 3.1.1, <=4.1.1", + "pyspark >= 4.0.0", + "netifaces", "protobuf > 3.19.5" ] From 64f97948a7ec381b08d11af19659d92510ae5335 Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Mon, 25 Aug 2025 02:32:59 -0700 Subject: [PATCH 14/34] add distribution From 67faeb58d57fb88bf3de18e1891abc242e11748d Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Mon, 25 Aug 2025 09:37:50 -0700 Subject: [PATCH 15/34] lint --- .../main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala index 2fa7f9fe..f835106a 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayAppMaster.scala @@ -27,7 +27,6 @@ import scala.jdk.CollectionConverters._ import com.fasterxml.jackson.core.JsonFactory import com.fasterxml.jackson.databind.ObjectMapper - import io.ray.api.{ActorHandle, PlacementGroups, 
Ray} import io.ray.api.id.PlacementGroupId import io.ray.api.placementgroup.PlacementGroup From 0a71973a9394d03b974c71c9e3425cb81e9b642f Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Tue, 26 Aug 2025 00:08:42 -0700 Subject: [PATCH 16/34] Do not use largeVarTypes --- .../src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala index 9f6abced..90cd1b95 100644 --- a/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala +++ b/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala @@ -40,7 +40,7 @@ object SparkSqlUtils { schema = schema, timeZoneId = timeZoneId, errorOnDuplicatedFieldNames = false, - largeVarTypes = true, + largeVarTypes = false, context = context) } val rowRdd = internalRowRdd.map { internalRow => @@ -54,7 +54,7 @@ object SparkSqlUtils { schema = schema, timeZoneId = timeZoneId, errorOnDuplicatedFieldNames = false, - largeVarTypes = true + largeVarTypes = false ) } } From 067f277214bd2c4c1cf429504b9cd9a862fbf767 Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Tue, 26 Aug 2025 01:03:31 -0700 Subject: [PATCH 17/34] class to classic session to convert internalRowRdd to rdd --- .../main/scala/org/apache/spark/sql/SparkSqlUtils.scala | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala index 90cd1b95..d40bce7c 100644 --- a/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala +++ b/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.spark400 import org.apache.arrow.vector.types.pojo.Schema import org.apache.spark.TaskContext import 
org.apache.spark.api.java.JavaRDD -import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.classic.ClassicConversions +import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.catalyst.InternalRow object SparkSqlUtils { def toDataFrame( @@ -43,10 +43,7 @@ object SparkSqlUtils { largeVarTypes = false, context = context) } - val rowRdd = internalRowRdd.map { internalRow => - Row.fromSeq(internalRow.toSeq(schema)) - } - session.createDataFrame(rowRdd.setName("arrow"), schema) + ClassicConversions.castToImpl(session).internalCreateDataFrame(internalRowRdd.setName("arrow"), schema) } def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { From 0651c9d0ab7646f6e0b43e18476f03aff638d724 Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Tue, 26 Aug 2025 02:27:00 -0700 Subject: [PATCH 18/34] arrow to rdd --- .../spark/sql/raydp/ObjectStoreWriter.scala | 4 ++-- .../com/intel/raydp/shims/SparkShims.scala | 5 ++++- .../com/intel/raydp/shims/SparkShims.scala | 12 ++++++++-- .../org/apache/spark/sql/SparkSqlUtils.scala | 7 +++++- .../com/intel/raydp/shims/SparkShims.scala | 16 ++++++++++++-- .../org/apache/spark/sql/SparkSqlUtils.scala | 7 +++++- .../com/intel/raydp/shims/SparkShims.scala | 16 ++++++++++++-- .../org/apache/spark/sql/SparkSqlUtils.scala | 7 +++++- .../com/intel/raydp/shims/SparkShims.scala | 15 +++++++++++-- .../org/apache/spark/sql/SparkSqlUtils.scala | 16 ++++++++++++-- .../com/intel/raydp/shims/SparkShims.scala | 18 ++++++++++++--- .../org/apache/spark/sql/SparkSqlUtils.scala | 22 ++++++++++++++----- 12 files changed, 120 insertions(+), 25 deletions(-) diff --git a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala index 133656dc..f6fc0898 
100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala @@ -97,7 +97,7 @@ object ObjectStoreWriter { def toArrowSchema(df: DataFrame): Schema = { val conf = df.queryExecution.sparkSession.sessionState.conf val timeZoneId = conf.getConf(SQLConf.SESSION_LOCAL_TIMEZONE) - SparkShimLoader.getSparkShims.toArrowSchema(df.schema, timeZoneId) + SparkShimLoader.getSparkShims.toArrowSchema(df.schema, timeZoneId, df.sparkSession) } @deprecated @@ -108,7 +108,7 @@ object ObjectStoreWriter { } val uuid = dfToId.getOrElseUpdate(df, UUID.randomUUID()) val queue = ObjectRefHolder.getQueue(uuid) - val rdd = df.toArrowBatchRdd + val rdd = SparkShimLoader.getSparkShims.toArrowBatchRDD(df) rdd.persist(storageLevel) rdd.count() var executorIds = df.sqlContext.sparkContext.getExecutorIds.toArray diff --git a/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala index 2ca83522..c1f47fc2 100644 --- a/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ b/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -21,6 +21,7 @@ import org.apache.arrow.vector.types.pojo.Schema import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.java.JavaRDD import org.apache.spark.executor.RayDPExecutorBackendFactory +import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, SparkSession} @@ -39,5 +40,7 @@ trait SparkShims { def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext - def toArrowSchema(schema : StructType, timeZoneId : String) : Schema + def toArrowSchema(schema : StructType, timeZoneId : String, sparkSession: SparkSession) : Schema + + def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] } diff --git 
a/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala index 6ea817db..6c423e33 100644 --- a/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ b/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.spark322.SparkSqlUtils import com.intel.raydp.shims.{ShimDescriptor, SparkShims} import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.StructType class Spark322Shims extends SparkShims { @@ -46,7 +47,14 @@ class Spark322Shims extends SparkShims { TaskContextUtils.getDummyTaskContext(partitionId, env) } - override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { - SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) + override def toArrowSchema( + schema : StructType, + timeZoneId : String, + sparkSession: SparkSession) : Schema = { + SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId, session = sparkSession) + } + + override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = { + SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession) } } diff --git a/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala index be9b409c..609c7112 100644 --- a/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala +++ b/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.spark322 import org.apache.arrow.vector.types.pojo.Schema import org.apache.spark.api.java.JavaRDD +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession} import org.apache.spark.sql.execution.arrow.ArrowConverters 
import org.apache.spark.sql.types.StructType @@ -29,7 +30,11 @@ object SparkSqlUtils { ArrowConverters.toDataFrame(rdd, schema, new SQLContext(session)) } - def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { + def toArrowSchema(schema : StructType, timeZoneId : String, session: SparkSession) : Schema = { ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) } + + def toArrowRDD(dataFrame: DataFrame, sparkSession: SparkSession): RDD[Array[Byte]] = { + dataFrame.toArrowBatchRdd + } } diff --git a/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala index 4f1a50b5..26197052 100644 --- a/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ b/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.spark330.SparkSqlUtils import com.intel.raydp.shims.{ShimDescriptor, SparkShims} import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.StructType class Spark330Shims extends SparkShims { @@ -46,7 +47,18 @@ class Spark330Shims extends SparkShims { TaskContextUtils.getDummyTaskContext(partitionId, env) } - override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { - SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) + override def toArrowSchema( + schema : StructType, + timeZoneId : String, + sparkSession: SparkSession) : Schema = { + SparkSqlUtils.toArrowSchema( + schema = schema, + timeZoneId = timeZoneId, + sparkSession = sparkSession + ) + } + + override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = { + SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession) } } diff --git a/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala 
b/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala index 162371ad..8c937dcd 100644 --- a/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala +++ b/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.spark330 import org.apache.arrow.vector.types.pojo.Schema import org.apache.spark.api.java.JavaRDD +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.types.StructType @@ -29,7 +30,11 @@ object SparkSqlUtils { ArrowConverters.toDataFrame(rdd, schema, session) } - def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { + def toArrowSchema(schema : StructType, timeZoneId : String, sparkSession: SparkSession) : Schema = { ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) } + + def toArrowRDD(dataFrame: DataFrame, sparkSession: SparkSession): RDD[Array[Byte]] = { + dataFrame.toArrowBatchRdd + } } diff --git a/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala index c444373f..26717840 100644 --- a/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ b/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.spark340.SparkSqlUtils import com.intel.raydp.shims.{ShimDescriptor, SparkShims} import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.StructType class Spark340Shims extends SparkShims { @@ -46,7 +47,18 @@ class Spark340Shims extends SparkShims { TaskContextUtils.getDummyTaskContext(partitionId, env) } - override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { - 
SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) + override def toArrowSchema( + schema : StructType, + timeZoneId : String, + sparkSession: SparkSession) : Schema = { + SparkSqlUtils.toArrowSchema( + schema = schema, + timeZoneId = timeZoneId, + sparkSession = sparkSession + ) + } + + override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = { + SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession) + } } diff --git a/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala index eb52d8e7..3ec33569 100644 --- a/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala +++ b/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.spark340 import org.apache.arrow.vector.types.pojo.Schema import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.types._ @@ -39,7 +40,11 @@ object SparkSqlUtils { session.internalCreateDataFrame(rdd.setName("arrow"), schema) } - def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { + def toArrowSchema(schema : StructType, timeZoneId : String, sparkSession: SparkSession) : Schema = { ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) } + + def toArrowRDD(dataFrame: DataFrame, sparkSession: SparkSession): RDD[Array[Byte]] = { + dataFrame.toArrowBatchRdd + } } diff --git a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala index 721d6923..5b2f2eec 100644 --- a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ 
b/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.spark350.SparkSqlUtils import com.intel.raydp.shims.{ShimDescriptor, SparkShims} import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.StructType class Spark350Shims extends SparkShims { @@ -46,7 +47,17 @@ class Spark350Shims extends SparkShims { TaskContextUtils.getDummyTaskContext(partitionId, env) } - override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { - SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) + override def toArrowSchema(schema : StructType, + timeZoneId : String, + sparkSession: SparkSession) : Schema = { + SparkSqlUtils.toArrowSchema( + schema = schema, + timeZoneId = timeZoneId, + sparkSession = sparkSession + ) + } + + override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = { + SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession) } } diff --git a/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala index dfd063f7..a12c4256 100644 --- a/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala +++ b/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.spark350 import org.apache.arrow.vector.types.pojo.Schema import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.types._ @@ -39,7 +40,18 @@ object SparkSqlUtils { session.internalCreateDataFrame(rdd.setName("arrow"), schema) } - def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { - 
ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId, errorOnDuplicatedFieldNames = false) + def toArrowSchema(schema : StructType, timeZoneId : String, sparkSession: SparkSession) : Schema = { + val errorOnDuplicatedFieldNames = + sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy" + + ArrowUtils.toArrowSchema( + schema = schema, + timeZoneId = timeZoneId, + errorOnDuplicatedFieldNames = errorOnDuplicatedFieldNames + ) + } + + def toArrowRDD(dataFrame: DataFrame, sparkSession: SparkSession): RDD[Array[Byte]] = { + dataFrame.toArrowBatchRdd } } diff --git a/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShims.scala index f815bfbd..540edd2f 100644 --- a/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ b/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -22,10 +22,15 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.executor.{RayDPExecutorBackendFactory, RayDPSpark400ExecutorBackendFactory} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.spark400.SparkSqlUtils -import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.spark400.TaskContextUtils import com.intel.raydp.shims.{ShimDescriptor, SparkShims} +import org.apache.spark.rdd.{MapPartitionsRDD, RDD} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.arrow.ArrowConverters + +import scala.reflect.ClassTag class Spark400Shims extends SparkShims { override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR @@ -45,7 +50,14 @@ class Spark400Shims extends SparkShims { TaskContextUtils.getDummyTaskContext(partitionId, env) } - override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { - 
SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) + override def toArrowSchema( + schema: StructType, + timeZoneId: String, + sparkSession: SparkSession): Schema = { + SparkSqlUtils.toArrowSchema(schema, timeZoneId, sparkSession) + } + + override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = { + SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession) } } diff --git a/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala index d40bce7c..aab0e2fe 100644 --- a/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala +++ b/core/shims/spark400/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.spark400 import org.apache.arrow.vector.types.pojo.Schema import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD -import org.apache.spark.sql.classic.ClassicConversions -import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.classic.ClassicConversions.castToImpl +import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils @@ -43,15 +44,24 @@ object SparkSqlUtils { largeVarTypes = false, context = context) } - ClassicConversions.castToImpl(session).internalCreateDataFrame(internalRowRdd.setName("arrow"), schema) + session.internalCreateDataFrame(internalRowRdd.setName("arrow"), schema) } - def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { + def toArrowRDD(dataFrame: DataFrame, sparkSession: SparkSession): RDD[Array[Byte]] = { + dataFrame.toArrowBatchRdd + } + + def toArrowSchema(schema : StructType, timeZoneId : String, sparkSession: SparkSession) : Schema = { + val errorOnDuplicatedFieldNames = + 
sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy" + val largeVarTypes = + sparkSession.sessionState.conf.arrowUseLargeVarTypes + ArrowUtils.toArrowSchema( schema = schema, timeZoneId = timeZoneId, - errorOnDuplicatedFieldNames = false, - largeVarTypes = false + errorOnDuplicatedFieldNames = errorOnDuplicatedFieldNames, + largeVarTypes = largeVarTypes ) } } From f3f6e752065ef0f5a9d1a7ab1bcef5ce4b139028 Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Fri, 26 Sep 2025 01:18:10 -0700 Subject: [PATCH 19/34] pin click<8.3.0 --- .github/workflows/raydp.yml | 4 ++-- .../scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/raydp.yml b/.github/workflows/raydp.yml index a2203737..e2916eef 100644 --- a/.github/workflows/raydp.yml +++ b/.github/workflows/raydp.yml @@ -74,8 +74,8 @@ jobs: run: | python -m pip install --upgrade pip pip install wheel - pip install "numpy<1.24" "click<8.3.0" - pip install "pydantic<2.0" + pip install "numpy<1.24" + pip install "pydantic<2.0" "click<8.3.0" SUBVERSION=$(python -c 'import sys; print(sys.version_info[1])') if [ "$(uname -s)" == "Linux" ] then diff --git a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala index f6fc0898..8feecb23 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala @@ -95,7 +95,7 @@ object ObjectStoreWriter { } def toArrowSchema(df: DataFrame): Schema = { - val conf = df.queryExecution.sparkSession.sessionState.conf + val conf = df.sparkSession.sessionState.conf val timeZoneId = conf.getConf(SQLConf.SESSION_LOCAL_TIMEZONE) SparkShimLoader.getSparkShims.toArrowSchema(df.schema, timeZoneId, df.sparkSession) } @@ -111,7 +111,7 @@ object ObjectStoreWriter { val rdd = 
SparkShimLoader.getSparkShims.toArrowBatchRDD(df) rdd.persist(storageLevel) rdd.count() - var executorIds = df.sqlContext.sparkContext.getExecutorIds.toArray + val executorIds = df.sparkSession.sparkContext.getExecutorIds.toArray val numExecutors = executorIds.length val appMasterHandle = Ray.getActor(RayAppMaster.ACTOR_NAME) .get.asInstanceOf[ActorHandle[RayAppMaster]] From 66656ee8292de2e67db332fe24e419d1898d053b Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Thu, 20 Nov 2025 21:28:19 -0800 Subject: [PATCH 20/34] make jackson provided --- core/pom.xml | 7 ++++++- core/raydp-main/pom.xml | 5 ----- core/shims/spark322/pom.xml | 1 - core/shims/spark330/pom.xml | 1 - core/shims/spark340/pom.xml | 1 - core/shims/spark350/pom.xml | 1 - core/shims/spark400/pom.xml | 1 - 7 files changed, 6 insertions(+), 11 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 6641e7fe..79c5efed 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -31,7 +31,7 @@ 1.8 1.8 2.13.12 - 2.15.0 + 2.18.2 2.13 5.10.1 @@ -152,16 +152,19 @@ com.fasterxml.jackson.core jackson-core ${jackson.version} + provided com.fasterxml.jackson.core jackson-databind ${jackson.version} + provided com.fasterxml.jackson.core jackson-annotations ${jackson.version} + provided @@ -169,6 +172,7 @@ com.fasterxml.jackson.module jackson-module-scala_${scala.binary.version} ${jackson.version} + provided com.google.guava @@ -180,6 +184,7 @@ com.fasterxml.jackson.module jackson-module-jaxb-annotations ${jackson.version} + provided diff --git a/core/raydp-main/pom.xml b/core/raydp-main/pom.xml index 3c791a65..78effa21 100644 --- a/core/raydp-main/pom.xml +++ b/core/raydp-main/pom.xml @@ -134,24 +134,20 @@ com.fasterxml.jackson.core jackson-core - ${jackson.version} com.fasterxml.jackson.core jackson-databind - ${jackson.version} com.fasterxml.jackson.core jackson-annotations - ${jackson.version} com.fasterxml.jackson.module jackson-module-scala_${scala.binary.version} - ${jackson.version} com.google.guava @@ -162,7 +158,6 @@ 
com.fasterxml.jackson.module jackson-module-jaxb-annotations - ${jackson.version} diff --git a/core/shims/spark322/pom.xml b/core/shims/spark322/pom.xml index 0e9100c1..295b3d73 100644 --- a/core/shims/spark322/pom.xml +++ b/core/shims/spark322/pom.xml @@ -17,7 +17,6 @@ 2.13.12 - 2.13.5 diff --git a/core/shims/spark330/pom.xml b/core/shims/spark330/pom.xml index 3e229ade..6972fef1 100644 --- a/core/shims/spark330/pom.xml +++ b/core/shims/spark330/pom.xml @@ -17,7 +17,6 @@ 2.13.12 - 2.13.5 diff --git a/core/shims/spark340/pom.xml b/core/shims/spark340/pom.xml index 684309bd..52af6ed5 100644 --- a/core/shims/spark340/pom.xml +++ b/core/shims/spark340/pom.xml @@ -17,7 +17,6 @@ 2.13.12 - 2.13.5 diff --git a/core/shims/spark350/pom.xml b/core/shims/spark350/pom.xml index f33c4a98..0afa3289 100644 --- a/core/shims/spark350/pom.xml +++ b/core/shims/spark350/pom.xml @@ -17,7 +17,6 @@ 2.13.12 - 2.13.5 diff --git a/core/shims/spark400/pom.xml b/core/shims/spark400/pom.xml index 1a1c1e6f..fd3f8494 100644 --- a/core/shims/spark400/pom.xml +++ b/core/shims/spark400/pom.xml @@ -17,7 +17,6 @@ 2.13.12 - 2.13.5 From 90a46992a92a1910f38b3604c46c58bdab58ca63 Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Thu, 20 Nov 2025 21:30:20 -0800 Subject: [PATCH 21/34] Support spark 4.0.1 --- .../main/scala/com/intel/raydp/shims/SparkShimProvider.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala b/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala index a39b57f6..6652c182 100644 --- a/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala +++ b/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala @@ -21,7 +21,10 @@ import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor} object SparkShimProvider { val SPARK400_DESCRIPTOR = SparkShimDescriptor(4, 0, 0) - val DESCRIPTOR_STRINGS = 
Seq(s"$SPARK400_DESCRIPTOR") + val SPARK401_DESCRIPTOR = SparkShimDescriptor(4, 0, 1) + val DESCRIPTOR_STRINGS = Seq( + s"$SPARK400_DESCRIPTOR", s"$SPARK401_DESCRIPTOR" + ) val DESCRIPTOR = SPARK400_DESCRIPTOR } From 201e96ac9fff40ab17a863973f82970e64c6e05d Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Thu, 20 Nov 2025 22:19:28 -0800 Subject: [PATCH 22/34] tf/estimator.py: only write checkpoint in rank0 --- python/raydp/tf/estimator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/raydp/tf/estimator.py b/python/raydp/tf/estimator.py index 1eb3bf1d..c96cb387 100644 --- a/python/raydp/tf/estimator.py +++ b/python/raydp/tf/estimator.py @@ -207,7 +207,7 @@ def train_func(config): if session.get_world_rank() == 0: checkpoint = Checkpoint.from_directory(temp_checkpoint_dir) - session.report({}, checkpoint=checkpoint) + session.report({}, checkpoint=checkpoint) def fit(self, train_ds: Dataset, From f520469a3249cf670eb8baa784aef20c63371561 Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Mon, 8 Dec 2025 14:42:19 -0800 Subject: [PATCH 23/34] revert tf/estimator.py --- .../scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala | 3 ++- python/raydp/tf/estimator.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala index 8feecb23..d84be564 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala @@ -24,6 +24,7 @@ import java.util.{List, UUID} import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} import java.util.function.{Function => JFunction} import org.apache.arrow.vector.types.pojo.Schema +import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.spark.{RayDPException, SparkContext} @@ -168,7 
+169,7 @@ object ObjectStoreWriter { "Not yet connected to Ray! Please set fault_tolerant_mode=True when starting RayDP.") } - val rdd = df.toArrowBatchRdd + val rdd = SparkShimLoader.getSparkShims.toArrowBatchRDD(df) rdd.persist(storageLevel) rdd.count() // Keep a strong reference so Spark's ContextCleaner does not GC the cached blocks diff --git a/python/raydp/tf/estimator.py b/python/raydp/tf/estimator.py index c96cb387..1eb3bf1d 100644 --- a/python/raydp/tf/estimator.py +++ b/python/raydp/tf/estimator.py @@ -207,7 +207,7 @@ def train_func(config): if session.get_world_rank() == 0: checkpoint = Checkpoint.from_directory(temp_checkpoint_dir) - session.report({}, checkpoint=checkpoint) + session.report({}, checkpoint=checkpoint) def fit(self, train_ds: Dataset, From 679eca61c513554bd160caa023be6331a8ee6d61 Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Tue, 17 Feb 2026 01:15:23 +0800 Subject: [PATCH 24/34] support spark 4.1.x --- .github/workflows/pypi.yml | 5 +- .github/workflows/pypi_release.yml | 7 +- .github/workflows/ray_nightly_test.yml | 4 +- .github/workflows/raydp.yml | 4 +- core/pom.xml | 1 + core/shims/pom.xml | 1 + .../intel/raydp/shims/SparkShimProvider.scala | 10 +- core/shims/spark410/pom.xml | 98 +++++++++++++++++++ .../com.intel.raydp.shims.SparkShimProvider | 1 + .../intel/raydp/shims/SparkShimProvider.scala | 37 +++++++ .../com/intel/raydp/shims/SparkShims.scala | 59 +++++++++++ .../org/apache/spark/TaskContextUtils.scala | 30 ++++++ .../RayCoarseGrainedExecutorBackend.scala | 50 ++++++++++ .../RayDPSpark410ExecutorBackendFactory.scala | 51 ++++++++++ .../org/apache/spark/sql/SparkSqlUtils.scala | 67 +++++++++++++ python/setup.py | 4 +- 16 files changed, 412 insertions(+), 17 deletions(-) create mode 100644 core/shims/spark410/pom.xml create mode 100644 core/shims/spark410/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider create mode 100644 core/shims/spark410/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala 
create mode 100644 core/shims/spark410/src/main/scala/com/intel/raydp/shims/SparkShims.scala create mode 100644 core/shims/spark410/src/main/scala/org/apache/spark/TaskContextUtils.scala create mode 100644 core/shims/spark410/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala create mode 100644 core/shims/spark410/src/main/scala/org/apache/spark/executor/RayDPSpark410ExecutorBackendFactory.scala create mode 100644 core/shims/spark410/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 15c10874..9725b3dd 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -38,10 +38,11 @@ jobs: uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e # v2.3.4 with: python-version: 3.10.14 - - name: Set up JDK 1.8 + - name: Set up JDK 17 uses: actions/setup-java@b6e674f4b717d7b0ae3baee0fbe79f498905dfde # v1.4.4 with: - java-version: 1.8 + java-version: 17 + distribution: 'corretto' - name: days since the commit date run: | : diff --git a/.github/workflows/pypi_release.yml b/.github/workflows/pypi_release.yml index 6e284b03..00ea2411 100644 --- a/.github/workflows/pypi_release.yml +++ b/.github/workflows/pypi_release.yml @@ -35,7 +35,7 @@ jobs: name: build wheel and upload release runs-on: ubuntu-latest env: - PYSPARK_VERSION: "3.5.7" + PYSPARK_VERSION: "4.1.0" RAY_VERSION: "2.40.0" steps: - uses: actions/checkout@61b9e3751b92087fd0b06925ba6dd6314e06f089 # master @@ -46,10 +46,11 @@ jobs: uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e # v2.3.4 with: python-version: 3.10.14 - - name: Set up JDK 1.8 + - name: Set up JDK 17 uses: actions/setup-java@b6e674f4b717d7b0ae3baee0fbe79f498905dfde # v1.4.4 with: - java-version: 1.8 + java-version: 17 + distribution: 'corretto' - name: Install extra dependencies for Ubuntu run: | sudo apt-get install -y mpich diff --git a/.github/workflows/ray_nightly_test.yml 
b/.github/workflows/ray_nightly_test.yml index 5c72fc82..ab579cca 100644 --- a/.github/workflows/ray_nightly_test.yml +++ b/.github/workflows/ray_nightly_test.yml @@ -31,8 +31,8 @@ jobs: strategy: matrix: os: [ ubuntu-latest ] - python-version: [3.9, 3.10.14] - spark-version: [4.0.0] + python-version: [3.10.14, 3.11, 3.12] + spark-version: [4.0.0, 4.1.0] runs-on: ${{ matrix.os }} diff --git a/.github/workflows/raydp.yml b/.github/workflows/raydp.yml index e2916eef..b713be26 100644 --- a/.github/workflows/raydp.yml +++ b/.github/workflows/raydp.yml @@ -32,8 +32,8 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.9, 3.10.14] - spark-version: [4.0.0] + python-version: [3.10.14, 3.11, 3.12] + spark-version: [4.0.0, 4.1.0] ray-version: [2.34.0, 2.40.0, 2.50.0] runs-on: ${{ matrix.os }} diff --git a/core/pom.xml b/core/pom.xml index 79c5efed..a0cbfb29 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -19,6 +19,7 @@ 3.4.0 3.5.0 4.0.0 + 4.1.0 1.1.10.4 4.1.94.Final 1.10.0 diff --git a/core/shims/pom.xml b/core/shims/pom.xml index ac16dba7..75c8d0cc 100644 --- a/core/shims/pom.xml +++ b/core/shims/pom.xml @@ -22,6 +22,7 @@ spark340 spark350 spark400 + spark410 diff --git a/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala b/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala index 6652c182..70eeef10 100644 --- a/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala +++ b/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala @@ -20,12 +20,10 @@ package com.intel.raydp.shims.spark400 import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor} object SparkShimProvider { - val SPARK400_DESCRIPTOR = SparkShimDescriptor(4, 0, 0) - val SPARK401_DESCRIPTOR = SparkShimDescriptor(4, 0, 1) - val DESCRIPTOR_STRINGS = Seq( - s"$SPARK400_DESCRIPTOR", s"$SPARK401_DESCRIPTOR" - ) - val DESCRIPTOR = SPARK400_DESCRIPTOR + private val SUPPORTED_PATCHES = 0 to 2 + 
val DESCRIPTORS = SUPPORTED_PATCHES.map(p => SparkShimDescriptor(4, 0, p)) + val DESCRIPTOR_STRINGS = DESCRIPTORS.map(_.toString) + val DESCRIPTOR = DESCRIPTORS.head } class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider { diff --git a/core/shims/spark410/pom.xml b/core/shims/spark410/pom.xml new file mode 100644 index 00000000..df13c9cb --- /dev/null +++ b/core/shims/spark410/pom.xml @@ -0,0 +1,98 @@ + + + + 4.0.0 + + + com.intel + raydp-shims + 1.7.0-SNAPSHOT + ../pom.xml + + + raydp-shims-spark410 + RayDP Shims for Spark 4.1.0 + jar + + + 2.13.12 + + + + + + org.scalastyle + scalastyle-maven-plugin + + + net.alchim31.maven + scala-maven-plugin + 3.2.2 + + + scala-compile-first + process-resources + + compile + + + + scala-test-compile-first + process-test-resources + + testCompile + + + + + + + + + src/main/resources + + + + + + + com.intel + raydp-shims-common + ${project.version} + compile + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark410.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark410.version} + provided + + + org.xerial.snappy + snappy-java + + + io.netty + netty-handler + + + + + org.xerial.snappy + snappy-java + ${snappy.version} + + + io.netty + netty-handler + ${netty.version} + + + diff --git a/core/shims/spark410/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider b/core/shims/spark410/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider new file mode 100644 index 00000000..38ccea86 --- /dev/null +++ b/core/shims/spark410/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider @@ -0,0 +1 @@ +com.intel.raydp.shims.spark410.SparkShimProvider diff --git a/core/shims/spark410/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala b/core/shims/spark410/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala new file mode 100644 index 00000000..96d8b7a0 --- /dev/null +++ 
b/core/shims/spark410/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.raydp.shims.spark410 + +import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor} + +object SparkShimProvider { + private val SUPPORTED_PATCHES = 0 to 1 + val DESCRIPTORS = SUPPORTED_PATCHES.map(p => SparkShimDescriptor(4, 1, p)) + val DESCRIPTOR_STRINGS = DESCRIPTORS.map(_.toString) + val DESCRIPTOR = DESCRIPTORS.head +} + +class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider { + def createShim: SparkShims = { + new Spark410Shims() + } + + def matches(version: String): Boolean = { + SparkShimProvider.DESCRIPTOR_STRINGS.contains(version) + } +} diff --git a/core/shims/spark410/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark410/src/main/scala/com/intel/raydp/shims/SparkShims.scala new file mode 100644 index 00000000..cea3b323 --- /dev/null +++ b/core/shims/spark410/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.raydp.shims.spark410 + +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.executor.{RayDPExecutorBackendFactory, RayDPSpark410ExecutorBackendFactory} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.spark410.SparkSqlUtils +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.spark410.TaskContextUtils +import com.intel.raydp.shims.{ShimDescriptor, SparkShims} +import org.apache.spark.rdd.RDD + +class Spark410Shims extends SparkShims { + override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR + + override def toDataFrame( + rdd: JavaRDD[Array[Byte]], + schema: String, + session: SparkSession): DataFrame = { + SparkSqlUtils.toDataFrame(rdd, schema, session) + } + + override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = { + new RayDPSpark410ExecutorBackendFactory() + } + + override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { + TaskContextUtils.getDummyTaskContext(partitionId, env) + } + + override def toArrowSchema( + schema: StructType, + timeZoneId: String, + sparkSession: SparkSession): Schema = { + 
SparkSqlUtils.toArrowSchema(schema, timeZoneId, sparkSession) + } + + override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = { + SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession) + } +} diff --git a/core/shims/spark410/src/main/scala/org/apache/spark/TaskContextUtils.scala b/core/shims/spark410/src/main/scala/org/apache/spark/TaskContextUtils.scala new file mode 100644 index 00000000..d46b6822 --- /dev/null +++ b/core/shims/spark410/src/main/scala/org/apache/spark/TaskContextUtils.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.spark410 + +import java.util.Properties + +import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} +import org.apache.spark.memory.TaskMemoryManager + +object TaskContextUtils { + def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { + new TaskContextImpl(0, 0, partitionId, -1024, 0, 0, + new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem) + } +} diff --git a/core/shims/spark410/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala b/core/shims/spark410/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala new file mode 100644 index 00000000..2e6b5e25 --- /dev/null +++ b/core/shims/spark410/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.executor + +import java.net.URL + +import org.apache.spark.SparkEnv +import org.apache.spark.resource.ResourceProfile +import org.apache.spark.rpc.RpcEnv + +class RayCoarseGrainedExecutorBackend( + rpcEnv: RpcEnv, + driverUrl: String, + executorId: String, + bindAddress: String, + hostname: String, + cores: Int, + userClassPath: Seq[URL], + env: SparkEnv, + resourcesFileOpt: Option[String], + resourceProfile: ResourceProfile) + extends CoarseGrainedExecutorBackend( + rpcEnv, + driverUrl, + executorId, + bindAddress, + hostname, + cores, + env, + resourcesFileOpt, + resourceProfile) { + + override def getUserClassPath: Seq[URL] = userClassPath + +} diff --git a/core/shims/spark410/src/main/scala/org/apache/spark/executor/RayDPSpark410ExecutorBackendFactory.scala b/core/shims/spark410/src/main/scala/org/apache/spark/executor/RayDPSpark410ExecutorBackendFactory.scala new file mode 100644 index 00000000..35433182 --- /dev/null +++ b/core/shims/spark410/src/main/scala/org/apache/spark/executor/RayDPSpark410ExecutorBackendFactory.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.executor + +import org.apache.spark.SparkEnv +import org.apache.spark.resource.ResourceProfile +import org.apache.spark.rpc.RpcEnv + +import java.net.URL + +class RayDPSpark410ExecutorBackendFactory + extends RayDPExecutorBackendFactory { + override def createExecutorBackend( + rpcEnv: RpcEnv, + driverUrl: String, + executorId: String, + bindAddress: String, + hostname: String, + cores: Int, + userClassPath: Seq[URL], + env: SparkEnv, + resourcesFileOpt: Option[String], + resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = { + new RayCoarseGrainedExecutorBackend( + rpcEnv, + driverUrl, + executorId, + bindAddress, + hostname, + cores, + userClassPath, + env, + resourcesFileOpt, + resourceProfile) + } +} diff --git a/core/shims/spark410/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark410/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala new file mode 100644 index 00000000..cdbb0809 --- /dev/null +++ b/core/shims/spark410/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.spark410 + +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.TaskContext +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.classic.ClassicConversions.castToImpl +import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.execution.arrow.ArrowConverters +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils + +object SparkSqlUtils { + def toDataFrame( + arrowBatchRDD: JavaRDD[Array[Byte]], + schemaString: String, + session: SparkSession): DataFrame = { + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + val timeZoneId = session.sessionState.conf.sessionLocalTimeZone + val internalRowRdd = arrowBatchRDD.rdd.mapPartitions { iter => + val context = TaskContext.get() + ArrowConverters.fromBatchIterator( + arrowBatchIter = iter, + schema = schema, + timeZoneId = timeZoneId, + errorOnDuplicatedFieldNames = false, + largeVarTypes = false, + context = context) + } + session.internalCreateDataFrame(internalRowRdd.setName("arrow"), schema) + } + + def toArrowRDD(dataFrame: DataFrame, sparkSession: SparkSession): RDD[Array[Byte]] = { + dataFrame.toArrowBatchRdd + } + + def toArrowSchema(schema : StructType, timeZoneId : String, sparkSession: SparkSession) : Schema = { + val errorOnDuplicatedFieldNames = + sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy" + val largeVarTypes = + sparkSession.sessionState.conf.arrowUseLargeVarTypes + + ArrowUtils.toArrowSchema( + schema = schema, + timeZoneId = timeZoneId, + errorOnDuplicatedFieldNames = errorOnDuplicatedFieldNames, + largeVarTypes = largeVarTypes + ) + } +} diff --git a/python/setup.py b/python/setup.py index 38e31102..f3b9a1f1 100644 --- a/python/setup.py +++ b/python/setup.py @@ -136,9 +136,9 @@ def run(self): python_requires='>=3.6', classifiers=[ 'License :: OSI Approved :: Apache Software License', - 'Programming 
Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', ] ) finally: From 6c7a1b3e13b81e26e47e9b649cde1c7616c00695 Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Tue, 17 Feb 2026 01:45:39 +0800 Subject: [PATCH 25/34] deprecate python 3.9, add 3.11 to CI --- .github/workflows/ray_nightly_test.yml | 2 +- .github/workflows/raydp.yml | 4 ++-- python/setup.py | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ray_nightly_test.yml b/.github/workflows/ray_nightly_test.yml index ab579cca..db5b162a 100644 --- a/.github/workflows/ray_nightly_test.yml +++ b/.github/workflows/ray_nightly_test.yml @@ -31,7 +31,7 @@ jobs: strategy: matrix: os: [ ubuntu-latest ] - python-version: [3.10.14, 3.11, 3.12] + python-version: [3.10.14, 3.11] spark-version: [4.0.0, 4.1.0] runs-on: ${{ matrix.os }} diff --git a/.github/workflows/raydp.yml b/.github/workflows/raydp.yml index b713be26..480405bf 100644 --- a/.github/workflows/raydp.yml +++ b/.github/workflows/raydp.yml @@ -32,9 +32,9 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.10.14, 3.11, 3.12] + python-version: [3.10.14, 3.11] spark-version: [4.0.0, 4.1.0] - ray-version: [2.34.0, 2.40.0, 2.50.0] + ray-version: [2.37.0, 2.40.0, 2.50.0] runs-on: ${{ matrix.os }} diff --git a/python/setup.py b/python/setup.py index f3b9a1f1..06540c81 100644 --- a/python/setup.py +++ b/python/setup.py @@ -133,12 +133,11 @@ def run(self): }, install_requires=install_requires, setup_requires=["grpcio-tools"], - python_requires='>=3.6', + python_requires='>=3.10', classifiers=[ 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3.12', ] ) finally: From 187ce969f68fcd1b854b494e39eca775141947f5 Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Tue, 
17 Feb 2026 02:04:43 +0800 Subject: [PATCH 26/34] update pylint --- .github/workflows/raydp.yml | 2 +- python/pylintrc | 102 +++--------------------------------- 2 files changed, 9 insertions(+), 95 deletions(-) diff --git a/.github/workflows/raydp.yml b/.github/workflows/raydp.yml index 480405bf..1916886d 100644 --- a/.github/workflows/raydp.yml +++ b/.github/workflows/raydp.yml @@ -101,7 +101,7 @@ jobs: pip install dist/raydp-*.whl - name: Lint run: | - pip install pylint==2.8.3 + pip install pylint==3.2.7 pylint --rcfile=python/pylintrc python/raydp pylint --rcfile=python/pylintrc examples/*.py - name: Test with pytest diff --git a/python/pylintrc b/python/pylintrc index 907de8cb..d0c6b5c7 100644 --- a/python/pylintrc +++ b/python/pylintrc @@ -74,65 +74,30 @@ confidence= # no Warning level messages displayed, use"--disable=all --enable=classes # --disable=W" disable=abstract-method, - apply-builtin, arguments-differ, attribute-defined-outside-init, - backtick, - basestring-builtin, broad-except, - buffer-builtin, - cmp-builtin, - cmp-method, - coerce-builtin, - coerce-method, + broad-exception-raised, + consider-using-dict-items, + consider-using-from-import, + consider-using-generator, dangerous-default-value, - delslice-method, duplicate-code, - execfile-builtin, - file-builtin, - filter-builtin-not-iterating, fixme, - getslice-method, global-statement, - hex-method, + global-variable-not-assigned, import-error, import-self, - import-star-module-level, - input-builtin, - intern-builtin, invalid-name, locally-disabled, logging-fstring-interpolation, - long-builtin, - long-suffix, - map-builtin-not-iterating, missing-docstring, missing-function-docstring, - metaclass-assignment, - next-method-called, - next-method-defined, - no-absolute-import, no-else-return, no-member, no-name-in-module, - no-self-use, - nonzero-method, - oct-method, - old-division, - old-ne-operator, - old-octal-literal, - old-raise-syntax, - parameter-unpacking, - print-statement, 
protected-access, - raising-string, - range-builtin-not-iterating, redefined-outer-name, - reduce-builtin, - reload-builtin, - round-builtin, - setslice-method, - standarderror-builtin, suppressed-message, too-few-public-methods, too-many-ancestors, @@ -143,18 +108,13 @@ disable=abstract-method, too-many-public-methods, too-many-return-statements, too-many-statements, - unichr-builtin, - unicode-builtin, - unpacking-in-except, unused-argument, unused-import, unused-variable, + use-dict-literal, useless-else-on-loop, useless-suppression, - using-cmp-argument, wrong-import-order, - xrange-builtin, - zip-builtin-not-iterating, [REPORTS] @@ -164,12 +124,6 @@ disable=abstract-method, # mypackage.mymodule.MyReporterClass. output-format=text -# Put messages in a separate file for each module / package specified on the -# command line instead of printing them on stdout. Reports (if any) will be -# written in a file name "pylint_global.[txt|html]". This option is deprecated -# and it will be removed in Pylint 2.0. -files-output=no - # Tells whether to display a full report or only the messages reports=no @@ -206,72 +160,39 @@ bad-names=foo,bar,baz,toto,tutu,tata # the name regexes allow several styles. name-group= -# Include a hint for the correct naming format with invalid-name -include-naming-hint=no - # List of decorators that produce properties, such as abc.abstractproperty. Add # to this list to register other decorators that produce valid properties. 
property-classes=abc.abstractproperty function-rgx=[a-z_][a-z0-9_]{2,30}$ -# Naming hint for function names -function-name-hint=[a-z_][a-z0-9_]{2,30}$ - # Regular expression matching correct variable names variable-rgx=[a-z_][a-z0-9_]{2,30}$ -# Naming hint for variable names -variable-name-hint=[a-z_][a-z0-9_]{2,30}$ - # Regular expression matching correct constant names const-rgx=(([A-Za-z_][A-Za-z0-9_]*)|(__.*__))$ -# Naming hint for constant names -const-name-hint=(([A-a-zZ_][A-Za-z0-9_]*)|(__.*__))$ - # Regular expression matching correct attribute names attr-rgx=[a-z_][a-z0-9_]{2,30}$ -# Naming hint for attribute names -attr-name-hint=[a-z_][a-z0-9_]{2,30}$ - # Regular expression matching correct argument names argument-rgx=[a-z_][a-z0-9_]{2,30}$ -# Naming hint for argument names -argument-name-hint=[a-z_][a-z0-9_]{2,30}$ - # Regular expression matching correct class attribute names class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ -# Naming hint for class attribute names -class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ - # Regular expression matching correct inline iteration names inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ -# Naming hint for inline iteration names -inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ - # Regular expression matching correct class names class-rgx=[A-Z_][a-zA-Z0-9]+$ -# Naming hint for class names -class-name-hint=[A-Z_][a-zA-Z0-9]+$ - # Regular expression matching correct module names module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ -# Naming hint for module names -module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ - # Regular expression matching correct method names method-rgx=[a-z_][a-z0-9_]{2,30}$ -# Naming hint for method names -method-name-hint=[a-z_][a-z0-9_]{2,30}$ - # Regular expression which should only match function or class names that do # not require a docstring. no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ @@ -326,12 +247,6 @@ ignore-long-lines=(?x)( # else. 
single-line-if-stmt=yes -# List of optional constructs for which whitespace checking is disabled. `dict- -# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. -# `trailing-comma` allows a space between comma and closing bracket: (a, ). -# `empty-line` allows space-only lines. -no-space-check=trailing-comma,dict-separator - # Maximum number of lines in a module max-module-lines=1000 @@ -481,6 +396,5 @@ valid-metaclass-classmethod-first-arg=mcs # Exceptions that will emit a warning when being caught. Defaults to # "Exception" -overgeneral-exceptions=StandardError, - Exception, - BaseException +overgeneral-exceptions=builtins.Exception, + builtins.BaseException From 8cfc83217e9ba98763c69f1984f10eb1e5a0fc9b Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Tue, 17 Feb 2026 02:18:41 +0800 Subject: [PATCH 27/34] fix pyint rules --- python/pylintrc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/pylintrc b/python/pylintrc index d0c6b5c7..103656f9 100644 --- a/python/pylintrc +++ b/python/pylintrc @@ -79,6 +79,7 @@ disable=abstract-method, broad-except, broad-exception-raised, consider-using-dict-items, + consider-using-f-string, consider-using-from-import, consider-using-generator, dangerous-default-value, @@ -95,6 +96,8 @@ disable=abstract-method, missing-function-docstring, no-else-return, no-member, + not-callable, + possibly-used-before-assignment, no-name-in-module, protected-access, redefined-outer-name, @@ -110,6 +113,8 @@ disable=abstract-method, too-many-statements, unused-argument, unused-import, + unreachable, + unspecified-encoding, unused-variable, use-dict-literal, useless-else-on-loop, From 6473ccc5a396b82ab4bc3bc76bc34e210761521a Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Tue, 17 Feb 2026 02:34:16 +0800 Subject: [PATCH 28/34] fix tensorflow version --- .github/workflows/ray_nightly_test.yml | 12 ++++++------ .github/workflows/raydp.yml | 3 +-- python/setup.py | 2 +- 3 files changed, 8 insertions(+), 9 deletions(-) diff 
--git a/.github/workflows/ray_nightly_test.yml b/.github/workflows/ray_nightly_test.yml index db5b162a..8d846d60 100644 --- a/.github/workflows/ray_nightly_test.yml +++ b/.github/workflows/ray_nightly_test.yml @@ -74,7 +74,7 @@ jobs: run: | python -m pip install --upgrade pip pip install wheel - pip install "numpy<1.24" "click<8.3.0" + pip install "click<8.3.0" SUBVERSION=$(python -c 'import sys; print(sys.version_info[1])') if [ "$(uname -s)" == "Linux" ] then @@ -83,14 +83,14 @@ jobs: pip install torch fi case $PYTHON_VERSION in - 3.9) - pip install "ray[train,default] @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp39-cp39-manylinux2014_x86_64.whl" - ;; 3.10.14) pip install "ray[train,default] @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl" ;; + 3.11) + pip install "ray[train,default] @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp311-cp311-manylinux2014_x86_64.whl" + ;; esac - pip install pyarrow tqdm pytest tensorflow==2.13.1 tabulate grpcio-tools wget + pip install pyarrow tqdm pytest "tensorflow>=2.16.1,<2.19" tabulate grpcio-tools wget pip install "xgboost_ray[default]<=0.1.13" pip install torchmetrics HOROVOD_WITH_GLOO=1 @@ -110,7 +110,7 @@ jobs: pip install dist/raydp-*.whl - name: Lint run: | - pip install pylint==2.8.3 + pip install pylint==3.2.7 pylint --rcfile=python/pylintrc python/raydp pylint --rcfile=python/pylintrc examples/*.py - name: Test with pytest diff --git a/.github/workflows/raydp.yml b/.github/workflows/raydp.yml index 1916886d..2baa7e47 100644 --- a/.github/workflows/raydp.yml +++ b/.github/workflows/raydp.yml @@ -74,7 +74,6 @@ jobs: run: | python -m pip install --upgrade pip pip install wheel - pip install "numpy<1.24" pip install "pydantic<2.0" "click<8.3.0" SUBVERSION=$(python -c 'import sys; print(sys.version_info[1])') if [ "$(uname -s)" == "Linux" ] @@ -83,7 +82,7 @@ jobs: else pip install torch fi - pip install pyarrow 
"ray[train,default]==${{ matrix.ray-version }}" tqdm pytest tensorflow==2.13.1 tabulate grpcio-tools wget + pip install pyarrow "ray[train,default]==${{ matrix.ray-version }}" tqdm pytest "tensorflow>=2.16.1,<2.19" tabulate grpcio-tools wget pip install "xgboost_ray[default]<=0.1.13" pip install "xgboost<=2.0.3" pip install torchmetrics diff --git a/python/setup.py b/python/setup.py index 06540c81..c03fb34d 100644 --- a/python/setup.py +++ b/python/setup.py @@ -96,7 +96,7 @@ def run(self): install_requires = [ "numpy", - "pandas >= 1.1.4", + "pandas >= 1.1.4, < 2.2", "psutil", "pyarrow >= 4.0.1", "ray >= 2.37.0", From 93d5d42950737f0d1243d0400ac21f35c70b2b37 Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Tue, 17 Feb 2026 02:55:38 +0800 Subject: [PATCH 29/34] pin pandas<3 version --- python/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/setup.py b/python/setup.py index c03fb34d..da87d762 100644 --- a/python/setup.py +++ b/python/setup.py @@ -96,7 +96,7 @@ def run(self): install_requires = [ "numpy", - "pandas >= 1.1.4, < 2.2", + "pandas >= 2.2.0, < 3.0.0", "psutil", "pyarrow >= 4.0.1", "ray >= 2.37.0", From f83db26d5071a02f06356f77d0ee562dc2365db9 Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Tue, 17 Feb 2026 03:09:01 +0800 Subject: [PATCH 30/34] remove df.sqlContext reference --- .../scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala index d84be564..36e86da3 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala @@ -176,7 +176,7 @@ object ObjectStoreWriter { // before Ray tasks fetch them. 
recoverableRDDs.put(rdd.id, rdd) - var executorIds = df.sqlContext.sparkContext.getExecutorIds.toArray + val executorIds = df.sparkSession.sparkContext.getExecutorIds.toArray val numExecutors = executorIds.length val appMasterHandle = Ray.getActor(RayAppMaster.ACTOR_NAME) .get.asInstanceOf[ActorHandle[RayAppMaster]] From 7d4c4dedbfb0a50c119a975377cc00c0adf5f186 Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Tue, 17 Feb 2026 10:22:11 +0800 Subject: [PATCH 31/34] extract commandlineutils to custom spark submit --- core/pom.xml | 2 +- .../main/scala/org/apache/spark/deploy/SparkSubmit.scala | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index a0cbfb29..f6cdbb39 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -13,7 +13,7 @@ https://github.com/ray-project/raydp.git - 3.3.3 + 4.0.0 3.2.2 3.3.0 3.4.0 diff --git a/core/raydp-main/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/raydp-main/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 829517e5..e4208547 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -986,7 +986,13 @@ private[spark] object InProcessSparkSubmit { } -object SparkSubmit extends CommandLineUtils with Logging { +object SparkSubmit extends Logging { + + // Inlined from CommandLineLoggingUtils to avoid binary incompatibility + // between Spark 4.0.x (exitFn: Int => Unit) and 4.1.x (exitFn: (Int, Option[Throwable]) => Unit) + private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) + private[spark] var printStream: PrintStream = System.err + private[spark] def printMessage(str: String): Unit = printStream.println(str) // Cluster managers private val YARN = 1 From e7148fef817d65e5b567224e68cd257e2cd8c369 Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Tue, 17 Feb 2026 19:30:17 +0800 Subject: [PATCH 32/34] add new shims --- core/pom.xml | 6 
++-- .../org/apache/spark/deploy/SparkSubmit.scala | 34 +++++++++++-------- .../RayCoarseGrainedSchedulerBackend.scala | 5 ++- .../raydp/shims/CommandLineUtilsBridge.scala | 29 ++++++++++++++++ .../com/intel/raydp/shims/SparkShims.scala | 2 ++ .../com/intel/raydp/shims/SparkShims.scala | 7 +++- .../deploy/spark322/SparkSubmitUtils.scala | 33 ++++++++++++++++++ .../com/intel/raydp/shims/SparkShims.scala | 7 +++- .../deploy/spark330/SparkSubmitUtils.scala | 33 ++++++++++++++++++ .../com/intel/raydp/shims/SparkShims.scala | 7 +++- .../deploy/spark340/SparkSubmitUtils.scala | 33 ++++++++++++++++++ .../org/apache/spark/sql/SparkSqlUtils.scala | 2 +- .../com/intel/raydp/shims/SparkShims.scala | 7 +++- .../deploy/spark350/SparkSubmitUtils.scala | 33 ++++++++++++++++++ .../com/intel/raydp/shims/SparkShims.scala | 7 +++- .../deploy/spark400/SparkSubmitUtils.scala | 33 ++++++++++++++++++ .../com/intel/raydp/shims/SparkShims.scala | 7 +++- .../deploy/spark410/SparkSubmitUtils.scala | 33 ++++++++++++++++++ 18 files changed, 293 insertions(+), 25 deletions(-) create mode 100644 core/shims/common/src/main/scala/com/intel/raydp/shims/CommandLineUtilsBridge.scala create mode 100644 core/shims/spark322/src/main/scala/org/apache/spark/deploy/spark322/SparkSubmitUtils.scala create mode 100644 core/shims/spark330/src/main/scala/org/apache/spark/deploy/spark330/SparkSubmitUtils.scala create mode 100644 core/shims/spark340/src/main/scala/org/apache/spark/deploy/spark340/SparkSubmitUtils.scala create mode 100644 core/shims/spark350/src/main/scala/org/apache/spark/deploy/spark350/SparkSubmitUtils.scala create mode 100644 core/shims/spark400/src/main/scala/org/apache/spark/deploy/spark400/SparkSubmitUtils.scala create mode 100644 core/shims/spark410/src/main/scala/org/apache/spark/deploy/spark410/SparkSubmitUtils.scala diff --git a/core/pom.xml b/core/pom.xml index f6cdbb39..5d070063 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -13,7 +13,7 @@ https://github.com/ray-project/raydp.git - 
4.0.0 + 3.3.3 3.2.2 3.3.0 3.4.0 @@ -29,8 +29,8 @@ 2.5.2 UTF-8 UTF-8 - 1.8 - 1.8 + 17 + 17 2.13.12 2.18.2 2.13 diff --git a/core/raydp-main/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/raydp-main/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index e4208547..d9e8396a 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -24,13 +24,13 @@ import java.security.PrivilegedExceptionAction import java.text.ParseException import java.util.{ServiceLoader, UUID} import java.util.jar.JarInputStream -import javax.ws.rs.core.UriBuilder import scala.annotation.tailrec import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.util.{Properties, Try} +import com.intel.raydp.shims.SparkShimLoader import org.apache.commons.lang3.StringUtils import org.apache.hadoop.conf.{Configuration => HadoopConfiguration} import org.apache.hadoop.fs.{FileSystem, Path} @@ -258,7 +258,10 @@ private[spark] class SparkSubmit extends Logging { } if (clusterManager == KUBERNETES) { - args.master = Utils.checkAndGetK8sMasterUrl(args.master) + val checkedMaster = Utils.checkAndGetK8sMasterUrl(args.master) + SparkShimLoader.getSparkShims + .getCommandLineUtilsBridge + .setSubmitMaster(args, checkedMaster) // Make sure KUBERNETES is included in our build if we're trying to use it if (!Utils.classIsLoadable(KUBERNETES_CLUSTER_SUBMIT_CLASS) && !Utils.isTesting) { error( @@ -340,7 +343,7 @@ private[spark] class SparkSubmit extends Logging { // update spark config from args args.toSparkConf(Option(sparkConf)) - val hadoopConf = conf.getOrElse(SparkHadoopUtil.newConfiguration(sparkConf)) + val hadoopConf = conf.getOrElse(SparkHadoopUtil.get.newConfiguration(sparkConf)) val targetDir = Utils.createTempDir() // Kerberos is not supported in standalone mode, and keytab support is not yet available @@ -393,8 +396,10 @@ private[spark] 
class SparkSubmit extends Logging { val archiveLocalFiles = Option(args.archives).map { uris => val resolvedUris = Utils.stringToSeq(uris).map(Utils.resolveURI) val localArchives = downloadFileList( - resolvedUris.map( - UriBuilder.fromUri(_).fragment(null).build().toString).mkString(","), + resolvedUris.map { uri => + new URI(uri.getScheme, + uri.getRawSchemeSpecificPart, null).toString + }.mkString(","), targetDir, sparkConf, hadoopConf) // SPARK-33748: this mimics the behaviour of Yarn cluster mode. If the driver is running @@ -413,8 +418,9 @@ private[spark] class SparkSubmit extends Logging { Utils.unpack(source, dest) // Keep the URIs of local files with the given fragments. - UriBuilder.fromUri( - localArchive).fragment(resolvedUri.getFragment).build().toString + new URI(localArchive.getScheme, + localArchive.getRawSchemeSpecificPart, + resolvedUri.getFragment).toString }.mkString(",") }.orNull args.files = filesLocalFiles @@ -988,11 +994,10 @@ private[spark] object InProcessSparkSubmit { object SparkSubmit extends Logging { - // Inlined from CommandLineLoggingUtils to avoid binary incompatibility - // between Spark 4.0.x (exitFn: Int => Unit) and 4.1.x (exitFn: (Int, Option[Throwable]) => Unit) - private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) - private[spark] var printStream: PrintStream = System.err - private[spark] def printMessage(str: String): Unit = printStream.println(str) + var printStream: PrintStream = System.err + // scalastyle:off println + def printMessage(str: String): Unit = printStream.println(str) + // scalastyle:on println // Cluster managers private val YARN = 1 @@ -1025,7 +1030,7 @@ object SparkSubmit extends Logging { private[deploy] val KUBERNETES_CLUSTER_SUBMIT_CLASS = "org.apache.spark.deploy.k8s.submit.KubernetesClientApplication" - override def main(args: Array[String]): Unit = { + def main(args: Array[String]): Unit = { val submit = new SparkSubmit() { self => @@ -1050,7 +1055,8 @@ object 
SparkSubmit extends Logging { super.doSubmit(args) } catch { case e: SparkUserAppException => - exitFn(e.exitCode) + SparkShimLoader.getSparkShims + .getCommandLineUtilsBridge.callExit(e.exitCode) } } diff --git a/core/raydp-main/src/main/scala/org/apache/spark/scheduler/cluster/raydp/RayCoarseGrainedSchedulerBackend.scala b/core/raydp-main/src/main/scala/org/apache/spark/scheduler/cluster/raydp/RayCoarseGrainedSchedulerBackend.scala index dce83a78..d4387867 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/scheduler/cluster/raydp/RayCoarseGrainedSchedulerBackend.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/scheduler/cluster/raydp/RayCoarseGrainedSchedulerBackend.scala @@ -154,7 +154,10 @@ class RayCoarseGrainedSchedulerBackend( } // Start executors with a few necessary configs for registering with the scheduler - val sparkJavaOpts = Utils.sparkJavaOpts(conf, SparkConf.isExecutorStartupConf) + val sparkJavaOpts = conf.getAll + .filter { case (k, _) => SparkConf.isExecutorStartupConf(k) } + .map { case (k, v) => s"-D$k=$v" } + .toSeq // add Xmx, it should not be set in java opts, because Spark is not allowed. // We also add Xms to ensure the Xmx >= Xms val memoryLimit = Seq(s"-Xms${sc.executorMemory}M", s"-Xmx${sc.executorMemory}M") diff --git a/core/shims/common/src/main/scala/com/intel/raydp/shims/CommandLineUtilsBridge.scala b/core/shims/common/src/main/scala/com/intel/raydp/shims/CommandLineUtilsBridge.scala new file mode 100644 index 00000000..2d4b1995 --- /dev/null +++ b/core/shims/common/src/main/scala/com/intel/raydp/shims/CommandLineUtilsBridge.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.raydp.shims + +/** + * Bridge trait that delegates CommandLineUtils behavior to + * version-specific helper objects. Each shim module provides + * a SparkSubmitUtils that extends the real Spark CommandLineUtils + * and implements this bridge. + */ +trait CommandLineUtilsBridge { + def callExit(code: Int): Unit + def setSubmitMaster(args: Any, master: String): Unit +} diff --git a/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala index c1f47fc2..a4e9d22e 100644 --- a/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ b/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -43,4 +43,6 @@ trait SparkShims { def toArrowSchema(schema : StructType, timeZoneId : String, sparkSession: SparkSession) : Schema def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] + + def getCommandLineUtilsBridge: CommandLineUtilsBridge } diff --git a/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala index 6c423e33..a782ecd0 100644 --- a/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ b/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -24,8 +24,9 @@ import 
org.apache.spark.executor.spark322._ import org.apache.spark.spark322.TaskContextUtils import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.spark322.SparkSqlUtils -import com.intel.raydp.shims.{ShimDescriptor, SparkShims} +import com.intel.raydp.shims.{CommandLineUtilsBridge, ShimDescriptor, SparkShims} import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.deploy.spark322.SparkSubmitUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.StructType @@ -57,4 +58,8 @@ class Spark322Shims extends SparkShims { override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = { SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession) } + + override def getCommandLineUtilsBridge: CommandLineUtilsBridge = { + SparkSubmitUtils + } } diff --git a/core/shims/spark322/src/main/scala/org/apache/spark/deploy/spark322/SparkSubmitUtils.scala b/core/shims/spark322/src/main/scala/org/apache/spark/deploy/spark322/SparkSubmitUtils.scala new file mode 100644 index 00000000..53e1057c --- /dev/null +++ b/core/shims/spark322/src/main/scala/org/apache/spark/deploy/spark322/SparkSubmitUtils.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.spark322 + +import org.apache.spark.deploy.SparkSubmitArguments +import org.apache.spark.util.CommandLineUtils +import com.intel.raydp.shims.CommandLineUtilsBridge + +object SparkSubmitUtils + extends CommandLineUtils with CommandLineUtilsBridge { + override def main(args: Array[String]): Unit = {} + + override def callExit(code: Int): Unit = exitFn(code) + + override def setSubmitMaster(args: Any, master: String): Unit = { + args.asInstanceOf[SparkSubmitArguments].master = master + } +} diff --git a/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala index 26197052..e4236709 100644 --- a/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ b/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -24,8 +24,9 @@ import org.apache.spark.executor.spark330._ import org.apache.spark.spark330.TaskContextUtils import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.spark330.SparkSqlUtils -import com.intel.raydp.shims.{ShimDescriptor, SparkShims} +import com.intel.raydp.shims.{CommandLineUtilsBridge, ShimDescriptor, SparkShims} import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.deploy.spark330.SparkSubmitUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.StructType @@ -61,4 +62,8 @@ class Spark330Shims extends SparkShims { override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = { SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession) } + + override def getCommandLineUtilsBridge: CommandLineUtilsBridge = { + SparkSubmitUtils + } } diff --git a/core/shims/spark330/src/main/scala/org/apache/spark/deploy/spark330/SparkSubmitUtils.scala b/core/shims/spark330/src/main/scala/org/apache/spark/deploy/spark330/SparkSubmitUtils.scala new file mode 100644 index 00000000..55f4c4e1 --- /dev/null +++ 
b/core/shims/spark330/src/main/scala/org/apache/spark/deploy/spark330/SparkSubmitUtils.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.spark330 + +import org.apache.spark.deploy.SparkSubmitArguments +import org.apache.spark.util.CommandLineUtils +import com.intel.raydp.shims.CommandLineUtilsBridge + +object SparkSubmitUtils + extends CommandLineUtils with CommandLineUtilsBridge { + override def main(args: Array[String]): Unit = {} + + override def callExit(code: Int): Unit = exitFn(code) + + override def setSubmitMaster(args: Any, master: String): Unit = { + args.asInstanceOf[SparkSubmitArguments].master = master + } +} diff --git a/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala index 26717840..65295753 100644 --- a/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ b/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -24,8 +24,9 @@ import org.apache.spark.executor.spark340._ import org.apache.spark.spark340.TaskContextUtils import org.apache.spark.sql.{DataFrame, SparkSession} import 
org.apache.spark.sql.spark340.SparkSqlUtils -import com.intel.raydp.shims.{ShimDescriptor, SparkShims} +import com.intel.raydp.shims.{CommandLineUtilsBridge, ShimDescriptor, SparkShims} import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.deploy.spark340.SparkSubmitUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.StructType @@ -61,4 +62,8 @@ class Spark340Shims extends SparkShims { override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = { SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession) } + + override def getCommandLineUtilsBridge: CommandLineUtilsBridge = { + SparkSubmitUtils + } } diff --git a/core/shims/spark340/src/main/scala/org/apache/spark/deploy/spark340/SparkSubmitUtils.scala b/core/shims/spark340/src/main/scala/org/apache/spark/deploy/spark340/SparkSubmitUtils.scala new file mode 100644 index 00000000..8cb64f64 --- /dev/null +++ b/core/shims/spark340/src/main/scala/org/apache/spark/deploy/spark340/SparkSubmitUtils.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.spark340 + +import org.apache.spark.deploy.SparkSubmitArguments +import org.apache.spark.util.CommandLineUtils +import com.intel.raydp.shims.CommandLineUtilsBridge + +object SparkSubmitUtils + extends CommandLineUtils with CommandLineUtilsBridge { + override def main(args: Array[String]): Unit = {} + + override def callExit(code: Int): Unit = exitFn(code) + + override def setSubmitMaster(args: Any, master: String): Unit = { + args.asInstanceOf[SparkSubmitArguments].maybeMaster = Some(master) + } +} diff --git a/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala b/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala index 3ec33569..55bee8dd 100644 --- a/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala +++ b/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala @@ -45,6 +45,6 @@ object SparkSqlUtils { } def toArrowRDD(dataFrame: DataFrame, sparkSession: SparkSession): RDD[Array[Byte]] = { - SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession) + dataFrame.toArrowBatchRdd } } diff --git a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala index 5b2f2eec..dcde188c 100644 --- a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ b/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -24,8 +24,9 @@ import org.apache.spark.executor.spark350._ import org.apache.spark.spark350.TaskContextUtils import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.spark350.SparkSqlUtils -import com.intel.raydp.shims.{ShimDescriptor, SparkShims} +import com.intel.raydp.shims.{CommandLineUtilsBridge, ShimDescriptor, SparkShims} import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.deploy.spark350.SparkSubmitUtils import org.apache.spark.rdd.RDD import 
org.apache.spark.sql.types.StructType @@ -60,4 +61,8 @@ class Spark350Shims extends SparkShims { override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = { SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession) } + + override def getCommandLineUtilsBridge: CommandLineUtilsBridge = { + SparkSubmitUtils + } } diff --git a/core/shims/spark350/src/main/scala/org/apache/spark/deploy/spark350/SparkSubmitUtils.scala b/core/shims/spark350/src/main/scala/org/apache/spark/deploy/spark350/SparkSubmitUtils.scala new file mode 100644 index 00000000..1802edbe --- /dev/null +++ b/core/shims/spark350/src/main/scala/org/apache/spark/deploy/spark350/SparkSubmitUtils.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.spark350 + +import org.apache.spark.deploy.SparkSubmitArguments +import org.apache.spark.util.CommandLineUtils +import com.intel.raydp.shims.CommandLineUtilsBridge + +object SparkSubmitUtils + extends CommandLineUtils with CommandLineUtilsBridge { + override def main(args: Array[String]): Unit = {} + + override def callExit(code: Int): Unit = exitFn(code) + + override def setSubmitMaster(args: Any, master: String): Unit = { + args.asInstanceOf[SparkSubmitArguments].maybeMaster = Some(master) + } +} diff --git a/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShims.scala index 540edd2f..722099b3 100644 --- a/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ b/core/shims/spark400/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.spark400.SparkSqlUtils import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.spark400.TaskContextUtils -import com.intel.raydp.shims.{ShimDescriptor, SparkShims} +import com.intel.raydp.shims.{CommandLineUtilsBridge, ShimDescriptor, SparkShims} +import org.apache.spark.deploy.spark400.SparkSubmitUtils import org.apache.spark.rdd.{MapPartitionsRDD, RDD} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.arrow.ArrowConverters @@ -60,4 +61,8 @@ class Spark400Shims extends SparkShims { override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = { SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession) } + + override def getCommandLineUtilsBridge: CommandLineUtilsBridge = { + SparkSubmitUtils + } } diff --git a/core/shims/spark400/src/main/scala/org/apache/spark/deploy/spark400/SparkSubmitUtils.scala b/core/shims/spark400/src/main/scala/org/apache/spark/deploy/spark400/SparkSubmitUtils.scala new file mode 
100644 index 00000000..0f5aac07 --- /dev/null +++ b/core/shims/spark400/src/main/scala/org/apache/spark/deploy/spark400/SparkSubmitUtils.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.spark400 + +import org.apache.spark.deploy.SparkSubmitArguments +import org.apache.spark.util.CommandLineUtils +import com.intel.raydp.shims.CommandLineUtilsBridge + +object SparkSubmitUtils + extends CommandLineUtils with CommandLineUtilsBridge { + override def main(args: Array[String]): Unit = {} + + override def callExit(code: Int): Unit = exitFn(code) + + override def setSubmitMaster(args: Any, master: String): Unit = { + args.asInstanceOf[SparkSubmitArguments].maybeMaster = Some(master) + } +} diff --git a/core/shims/spark410/src/main/scala/com/intel/raydp/shims/SparkShims.scala b/core/shims/spark410/src/main/scala/com/intel/raydp/shims/SparkShims.scala index cea3b323..3efb60dc 100644 --- a/core/shims/spark410/src/main/scala/com/intel/raydp/shims/SparkShims.scala +++ b/core/shims/spark410/src/main/scala/com/intel/raydp/shims/SparkShims.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.spark410.SparkSqlUtils import org.apache.spark.sql.{DataFrame, SparkSession} import 
org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.spark410.TaskContextUtils -import com.intel.raydp.shims.{ShimDescriptor, SparkShims} +import com.intel.raydp.shims.{CommandLineUtilsBridge, ShimDescriptor, SparkShims} +import org.apache.spark.deploy.spark410.SparkSubmitUtils import org.apache.spark.rdd.RDD class Spark410Shims extends SparkShims { @@ -56,4 +57,8 @@ class Spark410Shims extends SparkShims { override def toArrowBatchRDD(dataFrame: DataFrame): RDD[Array[Byte]] = { SparkSqlUtils.toArrowRDD(dataFrame, dataFrame.sparkSession) } + + override def getCommandLineUtilsBridge: CommandLineUtilsBridge = { + SparkSubmitUtils + } } diff --git a/core/shims/spark410/src/main/scala/org/apache/spark/deploy/spark410/SparkSubmitUtils.scala b/core/shims/spark410/src/main/scala/org/apache/spark/deploy/spark410/SparkSubmitUtils.scala new file mode 100644 index 00000000..fc559f76 --- /dev/null +++ b/core/shims/spark410/src/main/scala/org/apache/spark/deploy/spark410/SparkSubmitUtils.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.spark410 + +import org.apache.spark.deploy.SparkSubmitArguments +import org.apache.spark.util.CommandLineUtils +import com.intel.raydp.shims.CommandLineUtilsBridge + +object SparkSubmitUtils + extends CommandLineUtils with CommandLineUtilsBridge { + override def main(args: Array[String]): Unit = {} + + override def callExit(code: Int): Unit = exitFn(code, None) + + override def setSubmitMaster(args: Any, master: String): Unit = { + args.asInstanceOf[SparkSubmitArguments].maybeMaster = Some(master) + } +} From 1989452b4fc5236cfecca83b32411817fb92d382 Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Wed, 18 Feb 2026 00:07:11 +0800 Subject: [PATCH 33/34] compile against 4.0.0 --- core/pom.xml | 2 +- python/raydp/tf/estimator.py | 13 +++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 5d070063..2ebaab55 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -13,7 +13,7 @@ https://github.com/ray-project/raydp.git - 3.3.3 + 4.0.0 3.2.2 3.3.0 3.4.0 diff --git a/python/raydp/tf/estimator.py b/python/raydp/tf/estimator.py index 1eb3bf1d..b9b9b537 100644 --- a/python/raydp/tf/estimator.py +++ b/python/raydp/tf/estimator.py @@ -175,13 +175,22 @@ def train_func(config): # Model building/compiling need to be within `strategy.scope()`. multi_worker_model = TFEstimator.build_and_compile_model(config) + # Disable auto-sharding since Ray already handles data distribution + # across workers. Without this, MultiWorkerMirroredStrategy tries to + # re-shard the dataset, producing PerReplica objects that Keras 3.x + # cannot convert back to tensors. 
+ ds_options = tf.data.Options() + ds_options.experimental_distribute.auto_shard_policy = ( + tf.data.experimental.AutoShardPolicy.OFF + ) + train_dataset = session.get_dataset_shard("train") train_tf_dataset = train_dataset.to_tf( feature_columns=config["feature_columns"], label_columns=config["label_columns"], batch_size=config["batch_size"], drop_last=config["drop_last"] - ) + ).with_options(ds_options) if config["evaluate"]: eval_dataset = session.get_dataset_shard("evaluate") eval_tf_dataset = eval_dataset.to_tf( @@ -189,7 +198,7 @@ def train_func(config): label_columns=config["label_columns"], batch_size=config["batch_size"], drop_last=config["drop_last"] - ) + ).with_options(ds_options) results = [] callbacks = config["callbacks"] for _ in range(config["num_epochs"]): From a86d51dd482e32486ecf8c5443b22987275f854c Mon Sep 17 00:00:00 2001 From: Pang Wu Date: Wed, 18 Feb 2026 01:16:28 +0800 Subject: [PATCH 34/34] use legacy keras --- .github/workflows/pypi_release.yml | 2 +- .github/workflows/ray_nightly_test.yml | 4 ++-- .github/workflows/raydp.yml | 4 ++-- examples/tensorflow_titanic.ipynb | 1 + python/raydp/tf/estimator.py | 14 +++----------- python/setup.py | 5 ++++- 6 files changed, 13 insertions(+), 17 deletions(-) diff --git a/.github/workflows/pypi_release.yml b/.github/workflows/pypi_release.yml index 00ea2411..38845301 100644 --- a/.github/workflows/pypi_release.yml +++ b/.github/workflows/pypi_release.yml @@ -66,7 +66,7 @@ jobs: pip install "numpy<1.24" "click<8.3.0" pip install "pydantic<2.0" pip install torch --index-url https://download.pytorch.org/whl/cpu - pip install pyarrow "ray[train,default]==${{ env.RAY_VERSION }}" tqdm pytest tensorflow==2.13.1 tabulate grpcio-tools wget + pip install pyarrow "ray[train,default]==${{ env.RAY_VERSION }}" tqdm pytest tensorflow==2.16.1 tf_keras tabulate grpcio-tools wget pip install "xgboost_ray[default]<=0.1.13" pip install "xgboost<=2.0.3" pip install torchmetrics diff --git 
a/.github/workflows/ray_nightly_test.yml b/.github/workflows/ray_nightly_test.yml index 8d846d60..6da5888a 100644 --- a/.github/workflows/ray_nightly_test.yml +++ b/.github/workflows/ray_nightly_test.yml @@ -90,7 +90,7 @@ jobs: pip install "ray[train,default] @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp311-cp311-manylinux2014_x86_64.whl" ;; esac - pip install pyarrow tqdm pytest "tensorflow>=2.16.1,<2.19" tabulate grpcio-tools wget + pip install pyarrow tqdm pytest tabulate grpcio-tools wget pip install "xgboost_ray[default]<=0.1.13" pip install torchmetrics HOROVOD_WITH_GLOO=1 @@ -107,7 +107,7 @@ jobs: run: | pip install pyspark==${{ matrix.spark-version }} ./build.sh - pip install dist/raydp-*.whl + pip install "$(ls dist/raydp-*.whl)[tensorflow]" - name: Lint run: | pip install pylint==3.2.7 diff --git a/.github/workflows/raydp.yml b/.github/workflows/raydp.yml index 2baa7e47..0ad7d4a4 100644 --- a/.github/workflows/raydp.yml +++ b/.github/workflows/raydp.yml @@ -82,7 +82,7 @@ jobs: else pip install torch fi - pip install pyarrow "ray[train,default]==${{ matrix.ray-version }}" tqdm pytest "tensorflow>=2.16.1,<2.19" tabulate grpcio-tools wget + pip install pyarrow "ray[train,default]==${{ matrix.ray-version }}" tqdm pytest tabulate grpcio-tools wget pip install "xgboost_ray[default]<=0.1.13" pip install "xgboost<=2.0.3" pip install torchmetrics @@ -97,7 +97,7 @@ jobs: run: | pip install pyspark==${{ matrix.spark-version }} ./build.sh - pip install dist/raydp-*.whl + pip install "$(ls dist/raydp-*.whl)[tensorflow]" - name: Lint run: | pip install pylint==3.2.7 diff --git a/examples/tensorflow_titanic.ipynb b/examples/tensorflow_titanic.ipynb index e6f6c1e0..4837e5cc 100644 --- a/examples/tensorflow_titanic.ipynb +++ b/examples/tensorflow_titanic.ipynb @@ -15,6 +15,7 @@ "source": [ "import ray\n", "import os\n", + "os.environ[\"TF_USE_LEGACY_KERAS\"] = \"1\"\n", "import re\n", "import pandas as pd, numpy as np\n", "\n", diff --git 
a/python/raydp/tf/estimator.py b/python/raydp/tf/estimator.py index b9b9b537..fdbbcc8d 100644 --- a/python/raydp/tf/estimator.py +++ b/python/raydp/tf/estimator.py @@ -38,6 +38,7 @@ from raydp.spark.interfaces import SparkEstimatorInterface, DF, OPTIONAL_DF from raydp import stop_spark + class TFEstimator(EstimatorInterface, SparkEstimatorInterface): def __init__(self, num_workers: int = 1, @@ -175,22 +176,13 @@ def train_func(config): # Model building/compiling need to be within `strategy.scope()`. multi_worker_model = TFEstimator.build_and_compile_model(config) - # Disable auto-sharding since Ray already handles data distribution - # across workers. Without this, MultiWorkerMirroredStrategy tries to - # re-shard the dataset, producing PerReplica objects that Keras 3.x - # cannot convert back to tensors. - ds_options = tf.data.Options() - ds_options.experimental_distribute.auto_shard_policy = ( - tf.data.experimental.AutoShardPolicy.OFF - ) - train_dataset = session.get_dataset_shard("train") train_tf_dataset = train_dataset.to_tf( feature_columns=config["feature_columns"], label_columns=config["label_columns"], batch_size=config["batch_size"], drop_last=config["drop_last"] - ).with_options(ds_options) + ) if config["evaluate"]: eval_dataset = session.get_dataset_shard("evaluate") eval_tf_dataset = eval_dataset.to_tf( @@ -198,7 +190,7 @@ def train_func(config): label_columns=config["label_columns"], batch_size=config["batch_size"], drop_last=config["drop_last"] - ).with_options(ds_options) + ) results = [] callbacks = config["callbacks"] for _ in range(config["num_epochs"]): diff --git a/python/setup.py b/python/setup.py index da87d762..ed6b05ed 100644 --- a/python/setup.py +++ b/python/setup.py @@ -101,7 +101,6 @@ def run(self): "pyarrow >= 4.0.1", "ray >= 2.37.0", "pyspark >= 4.0.0", - "netifaces", "protobuf > 3.19.5" ] @@ -132,6 +131,10 @@ def run(self): 'build_proto_modules': CustomBuildPackageProtos, }, install_requires=install_requires, + extras_require={ + 
"tensorflow": ["tensorflow>=2.15.1,<2.16"], + "tensorflow-gpu": ["tensorflow[and-cuda]>=2.15.1,<2.16"], + }, setup_requires=["grpcio-tools"], python_requires='>=3.10', classifiers=[