diff --git a/core/raydp-main/src/main/scala/org/apache/spark/rdd/RayObjectRefRDD.scala b/core/raydp-main/src/main/scala/org/apache/spark/rdd/RayObjectRefRDD.scala
index 2c76643c..794d246d 100644
--- a/core/raydp-main/src/main/scala/org/apache/spark/rdd/RayObjectRefRDD.scala
+++ b/core/raydp-main/src/main/scala/org/apache/spark/rdd/RayObjectRefRDD.scala
@@ -27,24 +27,26 @@ import org.apache.spark.{Partition, SparkContext, TaskContext}
 import org.apache.spark.raydp.RayDPUtils
 import org.apache.spark.sql.Row
 
-private[spark] class RayObjectRefRDDPartition(idx: Int) extends Partition {
+private[spark] class RayObjectRefRDDPartition(idx: Int, hex: String) extends Partition {
   val index = idx
+  val objHex = hex
 }
 
 private[spark] class RayObjectRefRDD(
     sc: SparkContext,
+    blocksHex: List[String],
     locations: List[Array[Byte]]) extends RDD[Row](sc, Nil) {
 
   override def getPartitions: Array[Partition] = {
-    (0 until locations.size()).map { i =>
-      new RayObjectRefRDDPartition(i).asInstanceOf[Partition]
+    (0 until blocksHex.size()).map { i =>
+      new RayObjectRefRDDPartition(i, blocksHex.get(i)).asInstanceOf[Partition]
     }.toArray
   }
 
   override def compute(split: Partition, context: TaskContext): Iterator[Row] = {
-    (Row(split.index) :: Nil).iterator
+    (Row(split.asInstanceOf[RayObjectRefRDDPartition].objHex) :: Nil).iterator
   }
 
   override def getPreferredLocations(split: Partition): Seq[String] = {
diff --git a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala
index 31fbc366..43a9e806 100644
--- a/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala
+++ b/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala
@@ -29,14 +29,15 @@ import org.apache.spark.rdd.{RayDatasetRDD, RayObjectRefRDD}
 import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
 import org.apache.spark.sql.catalyst.expressions.GenericRow
 import org.apache.spark.sql.execution.arrow.ArrowConverters
-import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.types.{StringType, StructType}
 
 object ObjectStoreReader {
   def createRayObjectRefDF(
       spark: SparkSession,
+      blocksHex: List[String],
       locations: List[Array[Byte]]): DataFrame = {
-    val rdd = new RayObjectRefRDD(spark.sparkContext, locations)
-    val schema = new StructType().add("idx", IntegerType)
+    val rdd = new RayObjectRefRDD(spark.sparkContext, blocksHex, locations)
+    val schema = new StructType().add("hex", StringType)
     spark.createDataFrame(rdd, schema)
   }
 
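Note: on the Scala side, each partition's single row now carries the block's Ray ObjectRef ID as a hex string, and the helper DataFrame's schema changes from `idx: IntegerType` to `hex: StringType`, so executors can look blocks up by key instead of by positional index. A minimal sketch of why a hex string works as the key, assuming only Ray's public `ObjectRef.hex()` API (variable names here are illustrative, not from the patch):

```python
import ray

ray.init()

# ray.put returns an ObjectRef; its hex() is a stable string ID, so it
# can travel through a Spark DataFrame as a plain StringType column.
block = ray.put({"a": [1, 2, 3]})
key = block.hex()
assert isinstance(key, str)

# Keyed lookup mirrors what the holder actor does after this change.
objects = {key: block}
assert ray.get(objects[key]) == {"a": [1, 2, 3]}
```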
diff --git a/python/raydp/spark/dataset.py b/python/raydp/spark/dataset.py
index d6a64764..ee0b8d31 100644
--- a/python/raydp/spark/dataset.py
+++ b/python/raydp/spark/dataset.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 
 import logging
-import uuid
 from typing import Callable, Dict, List, NoReturn, Optional, Iterable, Union
 
 import numpy as np
@@ -164,8 +163,7 @@ def _save_spark_df_to_object_store(df: sql.DataFrame, use_batch: bool = True,
 
     if _use_owner is True:
         holder = ray.get_actor(obj_holder_name)
-        df_id = uuid.uuid4()
-        ray.get(holder.add_objects.remote(df_id, blocks))
+        ray.get(holder.add_objects.remote(blocks))
     return blocks, block_sizes
 
@@ -210,12 +208,12 @@ def _convert_by_udf(spark: sql.SparkSession,
                     schema: StructType) -> DataFrame:
     holder_name = spark.sparkContext.appName + RAYDP_SPARK_MASTER_SUFFIX
     holder = ray.get_actor(holder_name)
-    df_id = uuid.uuid4()
-    ray.get(holder.add_objects.remote(df_id, blocks))
+    ray.get(holder.add_objects.remote(blocks))
     jvm = spark.sparkContext._jvm
     object_store_reader = jvm.org.apache.spark.sql.raydp.ObjectStoreReader
     # create the rdd then dataframe to utilize locality
-    jdf = object_store_reader.createRayObjectRefDF(spark._jsparkSession, locations)
+    blocks_hex = [block.hex() for block in blocks]
+    jdf = object_store_reader.createRayObjectRefDF(spark._jsparkSession, blocks_hex, locations)
     current_namespace = ray.get_runtime_context().namespace
     ray_address = ray.get(holder.get_ray_address.remote())
     blocks_df = DataFrame(jdf, spark._wrapped if hasattr(spark, "_wrapped") else spark)
@@ -228,8 +226,8 @@ def _convert_blocks_to_dataframe(blocks):
         obj_holder = ray.get_actor(holder_name)
         for block in blocks:
             dfs = []
-            for idx in block["idx"]:
-                ref = ray.get(obj_holder.get_object.remote(df_id, idx))
+            for obj_hex in block["hex"]:
+                ref = ray.get(obj_holder.get_object.remote(obj_hex))
                 data = ray.get(ref)
                 dfs.append(data.to_pandas())
             yield pd.concat(dfs)
diff --git a/python/raydp/spark/ray_cluster_master.py b/python/raydp/spark/ray_cluster_master.py
index d5ff4617..6e9d43d3 100644
--- a/python/raydp/spark/ray_cluster_master.py
+++ b/python/raydp/spark/ray_cluster_master.py
@@ -212,11 +212,15 @@ def get_spark_home(self) -> str:
         assert self._started_up
         return self._spark_home
 
-    def add_objects(self, timestamp, objects):
-        self._objects[timestamp] = objects
+    def add_objects(self, objects):
+        for obj in objects:
+            self._objects[obj.hex()] = obj
 
-    def get_object(self, timestamp, idx):
-        return self._objects[timestamp][idx]
+    def get_object(self, obj_hex):
+        return self._objects[obj_hex]
+
+    def remove_object(self, obj_hex):
+        del self._objects[obj_hex]
 
     def get_ray_address(self):
         return ray.worker.global_worker.node.address
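Note: keying the holder actor's dict by each ObjectRef's own hex ID, rather than by a per-DataFrame `uuid.uuid4()` plus positional index, makes every block independently addressable; that is what enables the new `remove_object` to free blocks one at a time. A self-contained sketch of the pattern, assuming a local `ray.init()` (the `ObjectHolder` name is illustrative; in the patch the holder is the Spark master actor):

```python
import ray

ray.init()

@ray.remote
class ObjectHolder:
    """Illustrative stand-in for the holder actor in the patch."""

    def __init__(self):
        self._objects = {}

    def add_objects(self, objects):
        # ObjectRefs nested inside a list are not resolved by Ray, so the
        # actor receives the refs themselves and keys them by hex ID.
        for obj in objects:
            self._objects[obj.hex()] = obj

    def get_object(self, obj_hex):
        return self._objects[obj_hex]

    def remove_object(self, obj_hex):
        # Dropping the stored ref lets Ray reclaim the block once no
        # other reference remains, giving per-block cleanup.
        del self._objects[obj_hex]

holder = ObjectHolder.remote()
ref = ray.put("block-0")
ray.get(holder.add_objects.remote([ref]))

# Two-step fetch, as in _convert_blocks_to_dataframe: first get the
# stored ref back from the actor, then get the data it points to.
fetched = ray.get(holder.get_object.remote(ref.hex()))
assert ray.get(fetched) == "block-0"
ray.get(holder.remove_object.remote(ref.hex()))
```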