From d3445afbc05ce1e474ac68d12fc60bdd4a87f959 Mon Sep 17 00:00:00 2001 From: zizhao Date: Thu, 4 Jan 2024 15:42:00 +0200 Subject: [PATCH] wait for driver endpoint startup in manager initialization Signed-off-by: zizhao --- .../shuffle/ucx/CommonUcxShuffleManager.scala | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/CommonUcxShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/ucx/CommonUcxShuffleManager.scala index d9bbcaa2..e751be5c 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/CommonUcxShuffleManager.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/CommonUcxShuffleManager.scala @@ -9,7 +9,8 @@ import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.concurrent.ExecutionContext.Implicits.global import scala.util.Success -import org.apache.spark.rpc.RpcEnv +import org.apache.spark.SparkException +import org.apache.spark.rpc.{RpcEnv, RpcEndpointRef} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.shuffle.ucx.rpc.{UcxDriverRpcEndpoint, UcxExecutorRpcEndpoint} import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{ExecutorAdded, IntroduceAllExecutors} @@ -81,18 +82,28 @@ abstract class CommonUcxShuffleManager(val conf: SparkConf, isDriver: Boolean) e val address = transport.init() ucxTransport = transport latch.countDown() - val rpcEnv = RpcEnv.create("ucx-rpc-env", blockManager.host, blockManager.port, - conf, new SecurityManager(conf), clientMode = false) + val rpcEnv = SparkEnv.get.rpcEnv executorEndpoint = new UcxExecutorRpcEndpoint(rpcEnv, ucxTransport, setupThread) val endpoint = rpcEnv.setupEndpoint( s"ucx-shuffle-executor-${blockManager.executorId}", executorEndpoint) - val driverEndpoint = RpcUtils.makeDriverRef(driverRpcName, conf, rpcEnv) + var driverCost = 0 + var driverEndpoint: RpcEndpointRef = null + while (driverEndpoint == null) { + try { + driverEndpoint = RpcUtils.makeDriverRef(driverRpcName, conf, 
rpcEnv) + } catch { + case e: SparkException => { + Thread.sleep(5) + driverCost += 5 + } + } + } driverEndpoint.ask[IntroduceAllExecutors](ExecutorAdded(blockManager.executorId.toLong, endpoint, new SerializableDirectBuffer(address))) .andThen { case Success(msg) => - logInfo(s"Receive reply $msg") + logInfo(s"Driver take $driverCost ms. Receive reply ${msg.asInstanceOf[IntroduceAllExecutors].executorIdToAddress.keys}") executorEndpoint.receive(msg) } }