From 084fd6db6f640fd7fe92ef22bf3bf98510c1c2b7 Mon Sep 17 00:00:00 2001 From: zizhao Date: Tue, 25 Apr 2023 07:10:40 +0300 Subject: [PATCH 01/33] bind worker to thread --- .../compat/spark_2_4/UcxShuffleClient.scala | 7 +- .../compat/spark_2_4/UcxShuffleReader.scala | 2 +- .../compat/spark_3_0/UcxShuffleClient.scala | 9 +- .../compat/spark_3_0/UcxShuffleReader.scala | 2 +- .../shuffle/ucx/UcxShuffleTransport.scala | 93 ++++++++++++++----- .../spark/shuffle/ucx/UcxWorkerWrapper.scala | 47 +++++++++- .../ucx/rpc/GlobalWorkerRpcThread.scala | 9 +- 7 files changed, 132 insertions(+), 37 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala index cff68d1d..46cc4003 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala @@ -8,6 +8,7 @@ import org.apache.spark.shuffle.utils.UnsafeUtils import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId} class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient{ + val worker = transport.selectLocalWorker() override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String], listener: BlockFetchingListener, downloadFileManager: DownloadFileManager): Unit = { @@ -28,10 +29,14 @@ class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient } } val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size) - transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds, resultBufferAllocator, callbacks) + worker.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds, resultBufferAllocator, callbacks) } override def close(): Unit = { + transport.releaseLocalWorker() + } + def progress(): Unit = { + worker.progress() } } diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala index a19e07a5..1bc8a66f 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala @@ -62,7 +62,7 @@ class UcxShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C], override def next(): (BlockId, InputStream) = { val startTime = System.currentTimeMillis() while (resultQueue.isEmpty) { - transport.progress() + shuffleClient.progress() } shuffleMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime) wrappedStreams.next() diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala index 50b6cfd5..0a046b66 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala @@ -12,7 +12,7 @@ import org.apache.spark.shuffle.utils.UnsafeUtils import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId} class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Map[Long, Int]) extends BlockStoreClient with Logging { - + val worker = transport.selectLocalWorker() override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String], listener: BlockFetchingListener, downloadFileManager: DownloadFileManager): Unit = { @@ -40,11 +40,14 @@ class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Ma } } val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size) - transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds, resultBufferAllocator, callbacks) - transport.progress() + worker.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds, resultBufferAllocator, callbacks) } override def close(): Unit = { + transport.releaseLocalWorker() + } + def progress(): Unit = { + worker.progress() } } diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala index 37c68efb..655c5828 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala @@ -117,7 +117,7 @@ private[spark] class UcxShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C], override def next(): (BlockId, InputStream) = { val startTime = System.nanoTime() while (resultQueue.isEmpty) { - transport.progress() + shuffleClient.progress() } val fetchWaitTime = System.nanoTime() - startTime readMetrics.incFetchWaitTime(TimeUnit.NANOSECONDS.toMillis(fetchWaitTime)) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index b7c6fbe7..9f5208a9 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -14,8 +14,11 @@ import org.openucx.jucx.UcxException import org.openucx.jucx.ucp._ import org.openucx.jucx.ucs.UcsConstants +import java.lang.ThreadLocal import java.net.InetSocketAddress import java.nio.ByteBuffer +import java.util.concurrent.ArrayBlockingQueue +import java.util.concurrent.atomic.AtomicInteger import scala.collection.concurrent.TrieMap import scala.collection.mutable @@ -81,8 +84,15 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo val endpoints = mutable.Set.empty[UcpEndpoint] val executorAddresses = new TrieMap[ExecutorId, ByteBuffer] + private val localWorker = new ThreadLocal[UcxWorkerWrapper] { + override def initialValue = null + } + private var allocatedClientWorkers: Array[UcxWorkerWrapper] = _ - private var allocatedServerWorkers: Array[UcxWorkerWrapper] = _ + private val clientWorkerIds = new ArrayBlockingQueue[Int](ucxShuffleConf.numWorkers) + + private var allocatedServerThreads: Array[UcxWorkerThread] = _ + private val serverThreadId = new AtomicInteger() private val registeredBlocks = new TrieMap[BlockId, Block] private var progressThread: Thread = _ @@ -122,11 +132,13 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo globalWorker = ucxContext.newWorker(ucpWorkerParams) hostBounceBufferMemoryPool = new UcxHostBounceBuffersPool(ucxShuffleConf, ucxContext) - allocatedServerWorkers = new Array[UcxWorkerWrapper](ucxShuffleConf.numListenerThreads) + allocatedServerThreads = new Array[UcxWorkerThread](ucxShuffleConf.numListenerThreads) logInfo(s"Allocating ${ucxShuffleConf.numListenerThreads} server workers") for (i <- 0 until ucxShuffleConf.numListenerThreads) { val worker = ucxContext.newWorker(ucpWorkerParams) - allocatedServerWorkers(i) = UcxWorkerWrapper(worker, this, isClientWorker = false) + val workerWrapper = UcxWorkerWrapper(worker, this, isClientWorker = false, i.toLong) + allocatedServerThreads(i) = new UcxWorkerThread(workerWrapper) + allocatedServerThreads(i).start() } val Array(host, port) = ucxShuffleConf.listenerAddress.split(":") @@ -148,6 +160,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo ucpWorkerParams.setClientId(clientId) val worker = ucxContext.newWorker(ucpWorkerParams) allocatedClientWorkers(i) = UcxWorkerWrapper(worker, this, isClientWorker = true, clientId) + clientWorkerIds.offer(i) } initialized = true @@ -166,7 +179,6 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo hostBounceBufferMemoryPool.close() allocatedClientWorkers.foreach(_.close()) - allocatedServerWorkers.foreach(_.close()) if (listener != null) { listener.close() @@ -183,6 +195,9 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo globalWorker = null } + allocatedServerThreads.foreach(_.close()) + allocatedServerThreads.foreach(_.join(10)) + if (ucxContext != null) { ucxContext.close() ucxContext = null @@ -268,34 +283,68 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo override def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId], resultBufferAllocator: BufferAllocator, callbacks: Seq[OperationCallback]): Seq[Request] = { - allocatedClientWorkers((Thread.currentThread().getId % allocatedClientWorkers.length).toInt) + selectLocalWorker() .fetchBlocksByBlockIds(executorId, blockIds, resultBufferAllocator, callbacks) } def connectServerWorkers(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = { executorAddresses.put(executorId, workerAddress) - allocatedServerWorkers.foreach(w => w.connectByWorkerAddress(executorId, workerAddress)) + allocatedServerThreads.foreach(t => t.submit(new Runnable { + override def run(): Unit = { + t.workerWrapper.connectByWorkerAddress(executorId, workerAddress) + } + })) } def handleFetchBlockRequest(replyTag: Int, amData: UcpAmData, replyExecutor: Long): Unit = { - val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress, amData.getLength.toInt) - val blockIds = mutable.ArrayBuffer.empty[BlockId] - - // 1. Deserialize blockIds from header - while (buffer.remaining() > 0) { - val blockId = UcxShuffleBockId.deserialize(buffer) - if (!registeredBlocks.contains(blockId)) { - throw new UcxException(s"$blockId is not registered") + val server = selectServerThread() + server.submit(new Runnable { + override def run(): Unit = { + val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress, amData.getLength.toInt) + val blockIds = mutable.ArrayBuffer.empty[BlockId] + + // 1. Deserialize blockIds from header + while (buffer.remaining() > 0) { + val blockId = UcxShuffleBockId.deserialize(buffer) + if (!registeredBlocks.contains(blockId)) { + throw new UcxException(s"$blockId is not registered") + } + blockIds += blockId + } + + val blocks = blockIds.map(bid => registeredBlocks(bid)) + amData.close() + + server.workerWrapper.handleFetchBlockRequest(blocks, replyTag, replyExecutor) + } + }) + } + + def selectLocalWorker(): UcxWorkerWrapper = { + Option(localWorker.get()) match { + case Some(worker) => worker + case None => { + val worker = allocatedClientWorkers(clientWorkerIds.poll()) + localWorker.set(worker) + worker } - blockIds += blockId } + } - val blocks = blockIds.map(bid => registeredBlocks(bid)) - amData.close() - allocatedServerWorkers((Thread.currentThread().getId % allocatedServerWorkers.length).toInt) - .handleFetchBlockRequest(blocks, replyTag, replyExecutor) + def releaseLocalWorker(): Unit = { + Option(localWorker.get()) match { + case Some(worker) => { + clientWorkerIds.offer((worker.id >> 32).toInt - 1) + localWorker.set(null) + } + case None => {} + } } + def selectServerThread(): UcxWorkerThread = { + allocatedServerThreads( + (serverThreadId.incrementAndGet() % allocatedServerThreads.length).abs) + } /** * Progress outstanding operations. This routine is blocking (though may poll for event). @@ -304,9 +353,9 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo * Return from this method guarantees that at least some operation was progressed. * But not guaranteed that at least one [[ fetchBlocksByBlockIds ]] completed! */ - override def progress(): Unit = { - allocatedClientWorkers((Thread.currentThread().getId % allocatedClientWorkers.length).toInt).progress() - } + override def progress(): Unit = { + selectLocalWorker().progress() + } def progressConnect(): Unit = { allocatedClientWorkers.par.foreach(_.progressConnect()) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index 6ad88feb..bcdde33e 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -6,7 +6,7 @@ package org.apache.spark.shuffle.ucx import java.io.Closeable import java.util.concurrent.ConcurrentLinkedQueue -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.{AtomicInteger, AtomicBoolean} import scala.collection.concurrent.TrieMap import scala.util.Random import org.openucx.jucx.ucp._ @@ -72,6 +72,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i transport.ucxShuffleConf.numIoThreads) private val ioTaskSupport = new ForkJoinTaskSupport(ioThreadPool) + if (isClientWorker) { // Receive block data handler worker.setAmRecvHandler(1, @@ -332,3 +333,47 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i } } + +class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with Logging { + val id = workerWrapper.id + val transport = workerWrapper.transport + val useWakeup = workerWrapper.transport.ucxShuffleConf.useWakeup + + private val stopping = new AtomicBoolean(false) + private val outstandingRequests = new ConcurrentLinkedQueue[UcpRequest]() + private val outstandingTasks = new ConcurrentLinkedQueue[Runnable]() + + setDaemon(true) + setName(s"UCX-worker $id") + + override def run(): Unit = { + logDebug(s"UCX-worker $id started") + while (!stopping.get()) { + processTask() + } + workerWrapper.close() + logDebug(s"UCX-worker $id stopped") + } + + def processTask(): Unit = { + Option(outstandingTasks.poll()) match { + case Some(task) => { + task.run() + } + case None => { + workerWrapper.worker.waitForEvents() + } + } + } + + def submit(task: Runnable): Unit = { + outstandingTasks.offer(task) + workerWrapper.worker.signal() + } + + def close(): Unit = { + logDebug(s"UCX-worker $id stopping") + stopping.set(true) + workerWrapper.worker.signal() + } +} \ No newline at end of file diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala index a9f27b83..8b4676e9 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala @@ -16,19 +16,12 @@ class GlobalWorkerRpcThread(globalWorker: UcpWorker, transport: UcxShuffleTransp setDaemon(true) setName("Global worker progress thread") - private val replyWorkersThreadPool = ThreadUtils.newDaemonFixedThreadPool(transport.ucxShuffleConf.numListenerThreads, - "UcxListenerThread") - // Main RPC thread. Submit each RPC request to separate thread and send reply back from separate worker. globalWorker.setAmRecvHandler(0, (headerAddress: Long, headerSize: Long, amData: UcpAmData, _: UcpEndpoint) => { val header = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt) val replyTag = header.getInt val replyExecutor = header.getLong - replyWorkersThreadPool.submit(new Runnable { - override def run(): Unit = { - transport.handleFetchBlockRequest(replyTag, amData, replyExecutor) - } - }) + transport.handleFetchBlockRequest(replyTag, amData, replyExecutor) UcsConstants.STATUS.UCS_INPROGRESS }, UcpConstants.UCP_AM_FLAG_PERSISTENT_DATA | UcpConstants.UCP_AM_FLAG_WHOLE_MSG ) From 761792d377921fee3cf1fa93ac2d880f588dc677 Mon Sep 17 00:00:00 2001 From: zizhao Date: Tue, 25 Apr 2023 09:13:08 +0300 Subject: [PATCH 02/33] use wakeup/yield in server/client worker --- .../compat/spark_2_4/UcxShuffleClient.scala | 9 +++++- .../compat/spark_3_0/UcxShuffleClient.scala | 9 +++++- .../spark/shuffle/ucx/UcxWorkerWrapper.scala | 29 ++++++++++++------- 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala index 46cc4003..75ec5aa0 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala @@ -9,6 +9,7 @@ import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => Spar class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient{ val worker = transport.selectLocalWorker() + var numFetched = 0 override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String], listener: BlockFetchingListener, downloadFileManager: DownloadFileManager): Unit = { @@ -26,6 +27,7 @@ class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient this } }) + numFetched += 1 } } val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size) @@ -37,6 +39,11 @@ class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient } def progress(): Unit = { - worker.progress() + numFetched = 0 + while (numFetched == 0) { + if (worker.worker.progress() == 0) { + Thread.`yield`() + } + } } } diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala index 0a046b66..de98dad0 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala @@ -13,6 +13,7 @@ import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => Spar class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Map[Long, Int]) extends BlockStoreClient with Logging { val worker = transport.selectLocalWorker() + var numFetched = 0 override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String], listener: BlockFetchingListener, downloadFileManager: DownloadFileManager): Unit = { @@ -37,6 +38,7 @@ class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Ma this } }) + numFetched += 1 } } val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size) @@ -48,6 +50,11 @@ class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Ma } def progress(): Unit = { - worker.progress() + numFetched = 0 + while (numFetched == 0) { + if (worker.worker.progress() == 0) { + Thread.`yield`() + } + } } } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index bcdde33e..49fc7bd6 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -189,7 +189,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i def connectByWorkerAddress(executorId: transport.ExecutorId, workerAddress: ByteBuffer): Unit = { logDebug(s"Worker $this connecting back to $executorId by worker address") val ep = worker.newEndpoint(new UcpEndpointParams().setName(s"Server connection to $executorId") - .setUcpAddress(workerAddress)) + .setUcpAddress(workerAddress))) connections.put(executorId, ep) } @@ -268,13 +268,12 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i ep.sendAmNonBlocking(0, address, headerSize, dataAddress, buffer.capacity() - headerSize, UcpConstants.UCP_AM_SEND_FLAG_EAGER, new UcxCallback() { - override def onSuccess(request: UcpRequest): Unit = { - buffer.clear() - logDebug(s"Sent message on $ep to $executorId to fetch ${blockIds.length} blocks on tag $t id $id" + - s"in ${System.nanoTime() - startTime} ns") - } - }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) - + override def onSuccess(request: UcpRequest): Unit = { + buffer.clear() + logDebug(s"Sent message on $ep to $executorId to fetch ${blockIds.length} blocks on tag $t id $id" + + s"in ${System.nanoTime() - startTime} ns") + } + }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) worker.progressRequest(ep.flushNonBlocking(null)) Seq(request) } @@ -325,8 +324,16 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i }, new UcpRequestParams().setMemoryType(UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) .setMemoryHandle(resultMemory.memory)) - while (!req.isCompleted) { - progress() + if (transport.ucxShuffleConf.useWakeup) { + while (!req.isCompleted) { + if (worker.progress() == 0) { + worker.waitForEvents() + } + } + } else { + while (!req.isCompleted) { + worker.progress() + } } } catch { case ex: Throwable => logError(s"Failed to read and send data: $ex") @@ -376,4 +383,4 @@ class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with L stopping.set(true) workerWrapper.worker.signal() } -} \ No newline at end of file +} From 87b19e06ad0484b1b4a355319a5cba713c7da546 Mon Sep 17 00:00:00 2001 From: zizhao Date: Tue, 25 Apr 2023 10:14:22 +0300 Subject: [PATCH 03/33] use asyn --- .../org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index 9f5208a9..e423552d 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -315,7 +315,10 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo val blocks = blockIds.map(bid => registeredBlocks(bid)) amData.close() - server.workerWrapper.handleFetchBlockRequest(blocks, replyTag, replyExecutor) + Option(server.workerWrapper.handleFetchBlockRequest(blocks, replyTag, replyExecutor)) match { + case Some(req) => server.submit(req) + case None => {} + } } }) } From c0755053546e2e3ca8cba9862821b758e2fefcd9 Mon Sep 17 00:00:00 2001 From: zizhao Date: Tue, 25 Apr 2023 10:14:44 +0300 Subject: [PATCH 04/33] use loop --- .../spark/shuffle/ucx/UcxWorkerWrapper.scala | 59 ++++++++++--------- 1 file changed, 32 insertions(+), 27 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index 49fc7bd6..cdcbff85 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -189,7 +189,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i def connectByWorkerAddress(executorId: transport.ExecutorId, workerAddress: ByteBuffer): Unit = { logDebug(s"Worker $this connecting back to $executorId by worker address") val ep = worker.newEndpoint(new UcpEndpointParams().setName(s"Server connection to $executorId") - .setUcpAddress(workerAddress))) + .setUcpAddress(workerAddress)) connections.put(executorId, ep) } @@ -268,17 +268,17 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i ep.sendAmNonBlocking(0, address, headerSize, dataAddress, buffer.capacity() - headerSize, UcpConstants.UCP_AM_SEND_FLAG_EAGER, new UcxCallback() { - override def onSuccess(request: UcpRequest): Unit = { - buffer.clear() - logDebug(s"Sent message on $ep to $executorId to fetch ${blockIds.length} blocks on tag $t id $id" + - s"in ${System.nanoTime() - startTime} ns") - } - }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) + override def onSuccess(request: UcpRequest): Unit = { + buffer.clear() + logDebug(s"Sent message on $ep to $executorId to fetch ${blockIds.length} blocks on tag $t id $id" + + s"in ${System.nanoTime() - startTime} ns") + } + }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) worker.progressRequest(ep.flushNonBlocking(null)) Seq(request) } - def handleFetchBlockRequest(blocks: Seq[Block], replyTag: Int, replyExecutor: Long): Unit = try { + def handleFetchBlockRequest(blocks: Seq[Block], replyTag: Int, replyExecutor: Long): UcpRequest = try { val tagAndSizes = UnsafeUtils.INT_SIZE + UnsafeUtils.INT_SIZE * blocks.length val resultMemory = transport.hostBounceBufferMemoryPool.get(tagAndSizes + blocks.map(_.getSize).sum) .asInstanceOf[UcxBounceBufferMemoryBlock] @@ -310,7 +310,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i } val startTime = System.nanoTime() - val req = getConnection(replyExecutor).sendAmNonBlocking(1, resultMemory.address, tagAndSizes, + getConnection(replyExecutor).sendAmNonBlocking(1, resultMemory.address, tagAndSizes, resultMemory.address + tagAndSizes, resultMemory.size - tagAndSizes, 0, new UcxCallback { override def onSuccess(request: UcpRequest): Unit = { logTrace(s"Sent ${blocks.length} blocks of size: ${resultMemory.size} " + @@ -323,26 +323,16 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i } }, new UcpRequestParams().setMemoryType(UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) .setMemoryHandle(resultMemory.memory)) - - if (transport.ucxShuffleConf.useWakeup) { - while (!req.isCompleted) { - if (worker.progress() == 0) { - worker.waitForEvents() - } - } - } else { - while (!req.isCompleted) { - worker.progress() - } - } } catch { case ex: Throwable => logError(s"Failed to read and send data: $ex") + null } } class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with Logging { val id = workerWrapper.id + val worker = workerWrapper.worker val transport = workerWrapper.transport val useWakeup = workerWrapper.transport.ucxShuffleConf.useWakeup @@ -357,6 +347,7 @@ class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with L logDebug(s"UCX-worker $id started") while (!stopping.get()) { processTask() + processRequest() } workerWrapper.close() logDebug(s"UCX-worker $id stopped") @@ -367,20 +358,34 @@ class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with L case Some(task) => { task.run() } - case None => { - workerWrapper.worker.waitForEvents() - } + case None => {} + } + } + + def processRequest(): Unit = { + var req = outstandingRequests.peek() + while(req != null && req.isCompleted) { + outstandingRequests.poll() + req = outstandingRequests.peek() + } + while (worker.progress() != 0) {} + if (outstandingTasks.isEmpty && useWakeup) { + worker.waitForEvents() } } def submit(task: Runnable): Unit = { outstandingTasks.offer(task) - workerWrapper.worker.signal() + worker.signal() + } + + def submit(request: UcpRequest): Unit = { + outstandingRequests.offer(request) } def close(): Unit = { logDebug(s"UCX-worker $id stopping") stopping.set(true) - workerWrapper.worker.signal() + worker.signal() } -} +} \ No newline at end of file From fb0bb3415327ebf2481aab53b1b67bb0ec66b66a Mon Sep 17 00:00:00 2001 From: zizhao Date: Thu, 4 May 2023 09:13:41 +0300 Subject: [PATCH 05/33] connect executors in setup thread --- .../apache/spark/shuffle/ucx/UcxShuffleTransport.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index b7c6fbe7..83c09494 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -197,6 +197,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo override def addExecutor(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = { executorAddresses.put(executorId, workerAddress) allocatedClientWorkers.foreach(w => { + // logDebug(s" connect ($executorId)[$workerAddress]") w.getConnection(executorId) w.progressConnect() }) @@ -205,11 +206,14 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo def addExecutors(executorIdsToAddress: Map[ExecutorId, SerializableDirectBuffer]): Unit = { executorIdsToAddress.foreach { case (executorId, address) => executorAddresses.put(executorId, address.value) + allocatedClientWorkers.foreach(_.getConnection(executorId)) + // logDebug(s" addExecutors ($executorId)[$workerAddress]") } } - def preConnect(): Unit = { - allocatedClientWorkers.foreach(_.preconnect()) + def preConnect(): Unit = { + println(s" addExecutors ${executorAddresses.keys}") + allocatedClientWorkers.foreach(_.progressConnect) } /** From 6b4d43e932155883f0eab5df3d746d9431ac0dd5 Mon Sep 17 00:00:00 2001 From: zizhao Date: Wed, 10 May 2023 06:12:45 +0300 Subject: [PATCH 06/33] yield in loop --- .../shuffle/compat/spark_2_4/UcxShuffleClient.scala | 9 ++------- .../org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala | 6 ++++-- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala index 75ec5aa0..afb1717f 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala @@ -9,7 +9,6 @@ import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => Spar class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient{ val worker = transport.selectLocalWorker() - var numFetched = 0 override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String], listener: BlockFetchingListener, downloadFileManager: DownloadFileManager): Unit = { @@ -27,7 +26,6 @@ class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient this } }) - numFetched += 1 } } val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size) @@ -39,11 +37,8 @@ class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient } def progress(): Unit = { - numFetched = 0 - while (numFetched == 0) { - if (worker.worker.progress() == 0) { - Thread.`yield`() - } + while (worker.worker.progress() == 0) { + Thread.`yield`() } } } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index cdcbff85..3f82201b 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -165,7 +165,9 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i */ def progressConnect(): Unit = { while (!flushRequests.isEmpty) { - progress() + if (0 == progress()) { + Thread.`yield`() + } flushRequests.removeIf(_.isCompleted) } logTrace(s"Flush completed. Number of connections: ${connections.keys.size}") @@ -274,7 +276,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i s"in ${System.nanoTime() - startTime} ns") } }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) - worker.progressRequest(ep.flushNonBlocking(null)) + // worker.progressRequest(ep.flushNonBlocking(null)) Seq(request) } From d6c1978ee77f322253c8f29c5e1452e57ece0aed Mon Sep 17 00:00:00 2001 From: zizhao Date: Tue, 23 May 2023 09:45:08 +0300 Subject: [PATCH 07/33] reverse client bind --- .../compat/spark_2_4/UcxShuffleClient.scala | 2 +- .../shuffle/ucx/UcxShuffleTransport.scala | 48 ++++++++++--------- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala index afb1717f..4ebfd81d 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala @@ -33,7 +33,7 @@ class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient } override def close(): Unit = { - transport.releaseLocalWorker() + // transport.releaseLocalWorker() } def progress(): Unit = { diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index 28463572..b67e8f05 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -84,12 +84,12 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo val endpoints = mutable.Set.empty[UcpEndpoint] val executorAddresses = new TrieMap[ExecutorId, ByteBuffer] - private val localWorker = new ThreadLocal[UcxWorkerWrapper] { - override def initialValue = null - } + // private val localWorker = new ThreadLocal[UcxWorkerWrapper] { + // override def initialValue = null + // } private var allocatedClientWorkers: Array[UcxWorkerWrapper] = _ - private val clientWorkerIds = new ArrayBlockingQueue[Int](ucxShuffleConf.numWorkers) + // private val clientWorkerId = new AtomicInteger() private var allocatedServerThreads: Array[UcxWorkerThread] = _ private val serverThreadId = new AtomicInteger() @@ -160,7 +160,6 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo ucpWorkerParams.setClientId(clientId) val worker = ucxContext.newWorker(ucpWorkerParams) allocatedClientWorkers(i) = UcxWorkerWrapper(worker, this, isClientWorker = true, clientId) - clientWorkerIds.offer(i) } initialized = true @@ -328,25 +327,30 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo } def selectLocalWorker(): UcxWorkerWrapper = { - Option(localWorker.get()) match { - case Some(worker) => worker - case None => { - val worker = allocatedClientWorkers(clientWorkerIds.poll()) - localWorker.set(worker) - worker - } - } + allocatedClientWorkers( + (Thread.currentThread().getId % allocatedClientWorkers.length).toInt) } - def releaseLocalWorker(): Unit = { - Option(localWorker.get()) match { - case Some(worker) => { - clientWorkerIds.offer((worker.id >> 32).toInt - 1) - localWorker.set(null) - } - case None => {} - } - } + // def selectLocalWorker(): UcxWorkerWrapper = { + // Option(localWorker.get()) match { + // case Some(worker) => worker + // case None => { + // val worker = allocatedClientWorkers( + // (clientWorkerId.incrementAndGet() % allocatedServerThreads.length).abs) + // localWorker.set(worker) + // worker + // } + // } + // } + + // def releaseLocalWorker(): Unit = { + // Option(localWorker.get()) match { + // case Some(worker) => { + // localWorker.set(null) + // } + // case None => {} + // } + // } def selectServerThread(): UcxWorkerThread = { allocatedServerThreads( From 731d5d86fba38d3808d98ebd6bd1e2ca9567178d Mon Sep 17 00:00:00 2001 From: zizhao Date: Wed, 31 May 2023 12:53:49 +0300 Subject: [PATCH 08/33] replace while with if in client progress --- .../shuffle/compat/spark_2_4/UcxShuffleClient.scala | 2 +- .../shuffle/compat/spark_3_0/UcxShuffleClient.scala | 9 ++------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala index 4ebfd81d..010edcbb 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala @@ -37,7 +37,7 @@ class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient } def progress(): Unit = { - while (worker.worker.progress() == 0) { + if (worker.worker.progress() == 0) { Thread.`yield`() } } diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala index de98dad0..bdba1806 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala @@ -13,7 +13,6 @@ import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => Spar class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Map[Long, Int]) extends BlockStoreClient with Logging { val worker = transport.selectLocalWorker() - var numFetched = 0 override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String], listener: BlockFetchingListener, downloadFileManager: DownloadFileManager): Unit = { @@ -38,7 +37,6 @@ class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Ma this } }) - numFetched += 1 } } val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size) @@ -50,11 +48,8 @@ class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Ma } def progress(): Unit = { - numFetched = 0 - while (numFetched == 0) { - if (worker.worker.progress() == 0) { - Thread.`yield`() - } + if (worker.worker.progress() == 0) { + Thread.`yield`() } } } From 1279bc839ad9f2f4546dbbf4faad0925772a2a44 Mon Sep 17 00:00:00 2001 From: zizhao Date: Wed, 31 May 2023 13:17:45 +0300 Subject: [PATCH 09/33] remove useless modifications --- .../compat/spark_2_4/UcxShuffleClient.scala | 1 - .../compat/spark_3_0/UcxShuffleClient.scala | 1 - .../shuffle/ucx/UcxShuffleTransport.scala | 30 ------------------- .../spark/shuffle/ucx/UcxWorkerWrapper.scala | 2 -- 4 files changed, 34 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala index 010edcbb..071e805a 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala @@ -33,7 +33,6 @@ class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient } override def close(): Unit = { - // transport.releaseLocalWorker() } def progress(): Unit = { diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala index bdba1806..cc03b413 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala @@ -44,7 +44,6 @@ class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Ma } override def close(): Unit = { - transport.releaseLocalWorker() } def progress(): Unit = { diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index b67e8f05..991437fd 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -14,7 +14,6 @@ import org.openucx.jucx.UcxException import org.openucx.jucx.ucp._ import org.openucx.jucx.ucs.UcsConstants -import java.lang.ThreadLocal import java.net.InetSocketAddress import java.nio.ByteBuffer import java.util.concurrent.ArrayBlockingQueue @@ -84,12 +83,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo val endpoints = mutable.Set.empty[UcpEndpoint] val executorAddresses = new TrieMap[ExecutorId, ByteBuffer] - // private val localWorker = new ThreadLocal[UcxWorkerWrapper] { - // override def initialValue = null - // } - private var allocatedClientWorkers: Array[UcxWorkerWrapper] = _ - // private val clientWorkerId = new AtomicInteger() private var allocatedServerThreads: Array[UcxWorkerThread] = _ private val serverThreadId = new AtomicInteger() @@ -211,7 +205,6 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo override def addExecutor(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = { executorAddresses.put(executorId, workerAddress) allocatedClientWorkers.foreach(w => { - // logDebug(s" connect ($executorId)[$workerAddress]") w.getConnection(executorId) w.progressConnect() }) @@ -221,12 +214,10 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo executorIdsToAddress.foreach { case (executorId, address) => executorAddresses.put(executorId, address.value) allocatedClientWorkers.foreach(_.getConnection(executorId)) - // logDebug(s" addExecutors ($executorId)[$workerAddress]") } } def preConnect(): Unit = { - println(s" addExecutors ${executorAddresses.keys}") allocatedClientWorkers.foreach(_.progressConnect) } @@ -331,27 +322,6 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo (Thread.currentThread().getId % allocatedClientWorkers.length).toInt) } - // def selectLocalWorker(): UcxWorkerWrapper = { - // Option(localWorker.get()) match { - // case Some(worker) => worker - // case None => { - // val worker = allocatedClientWorkers( - // (clientWorkerId.incrementAndGet() % allocatedServerThreads.length).abs) - // localWorker.set(worker) - // worker - // } - // } - // } - - // def releaseLocalWorker(): Unit = { - // Option(localWorker.get()) match { - // case Some(worker) => { - // localWorker.set(null) - // } - // case None => {} - // } - // } - def selectServerThread(): UcxWorkerThread = { allocatedServerThreads( (serverThreadId.incrementAndGet() % allocatedServerThreads.length).abs) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index 3f82201b..4a445e32 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -72,7 +72,6 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i transport.ucxShuffleConf.numIoThreads) private val ioTaskSupport = new ForkJoinTaskSupport(ioThreadPool) - if (isClientWorker) { // Receive block data handler worker.setAmRecvHandler(1, @@ -276,7 +275,6 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i s"in ${System.nanoTime() - startTime} ns") } }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) - // worker.progressRequest(ep.flushNonBlocking(null)) Seq(request) } From a1e09ec4fa19ba47fd27081f448684f6bc6edf43 Mon Sep 17 00:00:00 2001 From: zizhao Date: Tue, 6 Jun 2023 06:48:15 +0300 Subject: [PATCH 10/33] use sync send --- .../shuffle/ucx/UcxShuffleTransport.scala | 9 +-- .../spark/shuffle/ucx/UcxWorkerWrapper.scala | 58 +++++++------------ 2 files changed, 25 insertions(+), 42 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index 991437fd..f3be26b5 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -309,10 +309,11 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo val blocks = blockIds.map(bid => registeredBlocks(bid)) amData.close() - Option(server.workerWrapper.handleFetchBlockRequest(blocks, replyTag, replyExecutor)) match { - case Some(req) => server.submit(req) - case None => {} - } + server.workerWrapper.handleFetchBlockRequest(blocks, replyTag, replyExecutor) + // Option(server.workerWrapper.handleFetchBlockRequest(blocks, replyTag, replyExecutor)) match { + // case Some(req) => server.submit(req) + // case None => {} + // } } }) } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index 4a445e32..1bf93917 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -5,7 +5,7 @@ package org.apache.spark.shuffle.ucx import java.io.Closeable -import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.{ConcurrentLinkedQueue, LinkedBlockingQueue} import java.util.concurrent.atomic.{AtomicInteger, AtomicBoolean} import scala.collection.concurrent.TrieMap import scala.util.Random @@ -164,9 +164,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i */ def progressConnect(): Unit = { while (!flushRequests.isEmpty) { - if (0 == progress()) { - Thread.`yield`() - } + progress() flushRequests.removeIf(_.isCompleted) } logTrace(s"Flush completed. Number of connections: ${connections.keys.size}") @@ -278,7 +276,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i Seq(request) } - def handleFetchBlockRequest(blocks: Seq[Block], replyTag: Int, replyExecutor: Long): UcpRequest = try { + def handleFetchBlockRequest(blocks: Seq[Block], replyTag: Int, replyExecutor: Long): Unit = try { val tagAndSizes = UnsafeUtils.INT_SIZE + UnsafeUtils.INT_SIZE * blocks.length val resultMemory = transport.hostBounceBufferMemoryPool.get(tagAndSizes + blocks.map(_.getSize).sum) .asInstanceOf[UcxBounceBufferMemoryBlock] @@ -310,7 +308,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i } val startTime = System.nanoTime() - getConnection(replyExecutor).sendAmNonBlocking(1, resultMemory.address, tagAndSizes, + val req = getConnection(replyExecutor).sendAmNonBlocking(1, resultMemory.address, tagAndSizes, resultMemory.address + tagAndSizes, resultMemory.size - tagAndSizes, 0, new UcxCallback { override def onSuccess(request: UcpRequest): Unit = { logTrace(s"Sent ${blocks.length} blocks of size: ${resultMemory.size} " + @@ -323,22 +321,30 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i } }, new UcpRequestParams().setMemoryType(UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) .setMemoryHandle(resultMemory.memory)) + + val useWakeup = transport.ucxShuffleConf.useWakeup + if (useWakeup) { + while (!req.isCompleted) { + if (worker.progress() == 0) { + worker.waitForEvents() + } + } + } else { + while (!req.isCompleted) { + worker.progress() + } + } } catch { case ex: Throwable => logError(s"Failed to read and send data: $ex") - null } } class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with Logging { val id = workerWrapper.id - val worker = workerWrapper.worker - val transport = workerWrapper.transport - val useWakeup = workerWrapper.transport.ucxShuffleConf.useWakeup private val stopping = new AtomicBoolean(false) - private val outstandingRequests = new ConcurrentLinkedQueue[UcpRequest]() - private val outstandingTasks = new ConcurrentLinkedQueue[Runnable]() + private val outstandingTasks = new LinkedBlockingQueue[Runnable]() setDaemon(true) setName(s"UCX-worker $id") @@ -347,45 +353,21 @@ class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with L logDebug(s"UCX-worker $id started") while (!stopping.get()) { processTask() - processRequest() } workerWrapper.close() logDebug(s"UCX-worker $id stopped") } def processTask(): Unit = { - Option(outstandingTasks.poll()) match { - case Some(task) => { - task.run() - } - case None => {} - } - } - - def processRequest(): Unit = { - var req = outstandingRequests.peek() - while(req != null && req.isCompleted) { - outstandingRequests.poll() - req = outstandingRequests.peek() - } - while (worker.progress() != 0) {} - if (outstandingTasks.isEmpty && useWakeup) { - worker.waitForEvents() - } + outstandingTasks.take().run() } def submit(task: Runnable): Unit = { - outstandingTasks.offer(task) - worker.signal() - } - - def submit(request: UcpRequest): Unit = { - outstandingRequests.offer(request) + outstandingTasks.put(task) } def close(): Unit = { logDebug(s"UCX-worker $id stopping") stopping.set(true) - worker.signal() } } \ No newline at end of file From 92674143655211b03d6736e55763cc35691459a5 Mon Sep 17 00:00:00 2001 From: zizhao Date: Mon, 12 Jun 2023 18:29:53 +0300 Subject: [PATCH 11/33] bind worker to thread --- .../compat/spark_2_4/UcxShuffleClient.scala | 7 +-- .../compat/spark_2_4/UcxShuffleReader.scala | 5 +-- .../compat/spark_3_0/UcxShuffleClient.scala | 7 +-- .../compat/spark_3_0/UcxShuffleReader.scala | 5 +-- .../shuffle/ucx/UcxShuffleTransport.scala | 41 +++++++++-------- .../spark/shuffle/ucx/UcxWorkerWrapper.scala | 45 ++++++++++--------- 6 files changed, 53 insertions(+), 57 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala index 071e805a..f3ed456e 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala @@ -8,7 +8,7 @@ import org.apache.spark.shuffle.utils.UnsafeUtils import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId} class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient{ - val worker = transport.selectLocalWorker() + val worker = transport.selectClientWorker override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String], listener: BlockFetchingListener, downloadFileManager: DownloadFileManager): Unit = { @@ -33,11 +33,6 @@ class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient } override def close(): Unit = { - } - def progress(): Unit = { - if (worker.worker.progress() == 0) { - Thread.`yield`() - } } } diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala index 1bc8a66f..b59b6ad8 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala @@ -52,6 +52,7 @@ class UcxShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C], // Ucx shuffle logic // Java reflection to get access to private results queue + val worker = shuffleClient.worker val queueField = wrappedStreams.getClass.getDeclaredField( "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results") queueField.setAccessible(true) @@ -61,9 +62,7 @@ class UcxShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C], val ucxWrappedStream = new Iterator[(BlockId, InputStream)] { override def next(): (BlockId, InputStream) = { val startTime = System.currentTimeMillis() - while (resultQueue.isEmpty) { - shuffleClient.progress() - } + worker.progressBlocked(() => !resultQueue.isEmpty) shuffleMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime) wrappedStreams.next() } diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala index cc03b413..e94953f4 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala @@ -12,7 +12,7 @@ import org.apache.spark.shuffle.utils.UnsafeUtils import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId} class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Map[Long, Int]) extends BlockStoreClient with Logging { - val worker = transport.selectLocalWorker() + val worker = transport.selectClientWorker override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String], listener: BlockFetchingListener, downloadFileManager: DownloadFileManager): Unit = { @@ -44,11 +44,6 @@ class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Ma } override def close(): Unit = { - } - def progress(): Unit = { - if (worker.worker.progress() == 0) { - Thread.`yield`() - } } } diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala index 655c5828..3f658b56 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala @@ -107,6 +107,7 @@ private[spark] class UcxShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C], // Ucx shuffle logic // Java reflection to get access to private results queue + val worker = shuffleClient.worker val queueField = shuffleIterator.getClass.getDeclaredField( "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results") queueField.setAccessible(true) @@ -116,9 +117,7 @@ private[spark] class UcxShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C], val ucxWrappedStream = new Iterator[(BlockId, InputStream)] { override def next(): (BlockId, InputStream) = { val startTime = System.nanoTime() - while (resultQueue.isEmpty) { - shuffleClient.progress() - } + worker.progressBlocked(() => !resultQueue.isEmpty) val fetchWaitTime = System.nanoTime() - startTime readMetrics.incFetchWaitTime(TimeUnit.NANOSECONDS.toMillis(fetchWaitTime)) wrappedStreams.next() diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index f3be26b5..96836eb4 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -14,6 +14,7 @@ import org.openucx.jucx.UcxException import org.openucx.jucx.ucp._ import org.openucx.jucx.ucs.UcsConstants +import java.lang.ThreadLocal import java.net.InetSocketAddress import java.nio.ByteBuffer import java.util.concurrent.ArrayBlockingQueue @@ -84,6 +85,8 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo val executorAddresses = new TrieMap[ExecutorId, ByteBuffer] private var allocatedClientWorkers: Array[UcxWorkerWrapper] = _ + private val clientWorkerId = new AtomicInteger() + private val clientWorker = new ThreadLocal[UcxWorkerWrapper] private var allocatedServerThreads: Array[UcxWorkerThread] = _ private val serverThreadId = new AtomicInteger() @@ -213,12 +216,11 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo def addExecutors(executorIdsToAddress: Map[ExecutorId, SerializableDirectBuffer]): Unit = { executorIdsToAddress.foreach { case (executorId, address) => executorAddresses.put(executorId, address.value) - allocatedClientWorkers.foreach(_.getConnection(executorId)) } } - def preConnect(): Unit = { - allocatedClientWorkers.foreach(_.progressConnect) + def preConnect(): Unit = { + allocatedClientWorkers.foreach(_.preconnect()) } /** @@ -277,7 +279,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo override def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId], resultBufferAllocator: BufferAllocator, callbacks: Seq[OperationCallback]): Seq[Request] = { - selectLocalWorker() + selectClientWorker .fetchBlocksByBlockIds(executorId, blockIds, resultBufferAllocator, callbacks) } @@ -291,7 +293,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo } def handleFetchBlockRequest(replyTag: Int, amData: UcpAmData, replyExecutor: Long): Unit = { - val server = selectServerThread() + val server = selectServerThread server.submit(new Runnable { override def run(): Unit = { val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress, amData.getLength.toInt) @@ -310,23 +312,24 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo amData.close() server.workerWrapper.handleFetchBlockRequest(blocks, replyTag, replyExecutor) - // Option(server.workerWrapper.handleFetchBlockRequest(blocks, replyTag, replyExecutor)) match { - // case Some(req) => server.submit(req) - // case None => {} - // } } }) } - def selectLocalWorker(): UcxWorkerWrapper = { - allocatedClientWorkers( - (Thread.currentThread().getId % allocatedClientWorkers.length).toInt) + def selectClientWorker(): UcxWorkerWrapper = Option(clientWorker.get) match { + case Some(worker) => worker + case None => { + val worker = allocatedClientWorkers( + (clientWorkerId.incrementAndGet() % allocatedClientWorkers.length).abs) + clientWorker.set(worker) + worker + } } - def selectServerThread(): UcxWorkerThread = { - allocatedServerThreads( - (serverThreadId.incrementAndGet() % allocatedServerThreads.length).abs) - } + @inline + def selectServerThread(): UcxWorkerThread = allocatedServerThreads( + (serverThreadId.incrementAndGet() % allocatedServerThreads.length).abs + ) /** * Progress outstanding operations. This routine is blocking (though may poll for event). @@ -335,9 +338,9 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo * Return from this method guarantees that at least some operation was progressed. * But not guaranteed that at least one [[ fetchBlocksByBlockIds ]] completed! */ - override def progress(): Unit = { - selectLocalWorker().progress() - } + override def progress(): Unit = { + selectClientWorker.progress() + } def progressConnect(): Unit = { allocatedClientWorkers.par.foreach(_.progressConnect()) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index 1bf93917..41708d0d 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -177,6 +177,20 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i worker.progress() } + @inline + def progressBlocked(isFinished: () => Boolean): Unit = { + transport.ucxShuffleConf.useWakeup match { + case true => while (!isFinished()) { + if (worker.progress() == 0) { + worker.waitForEvents() + } + } + case false => while (!isFinished()) { + worker.progress() + } + } + } + /** * Establish connections to known instances. */ @@ -241,14 +255,14 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i val headerSize = UnsafeUtils.INT_SIZE + UnsafeUtils.LONG_SIZE val ep = getConnection(executorId) - if (worker.getMaxAmHeaderSize <= - headerSize + UnsafeUtils.INT_SIZE * blockIds.length) { - val (b1, b2) = blockIds.splitAt(blockIds.length / 2) - val (c1, c2) = callbacks.splitAt(callbacks.length / 2) - val r1 = fetchBlocksByBlockIds(executorId, b1, resultBufferAllocator, c1) - val r2 = fetchBlocksByBlockIds(executorId, b2, resultBufferAllocator, c2) - return r1 ++ r2 - } + // if (worker.getMaxAmHeaderSize <= + // headerSize + UnsafeUtils.INT_SIZE * blockIds.length) { + // val (b1, b2) = blockIds.splitAt(blockIds.length / 2) + // val (c1, c2) = callbacks.splitAt(callbacks.length / 2) + // val r1 = fetchBlocksByBlockIds(executorId, b1, resultBufferAllocator, c1) + // val r2 = fetchBlocksByBlockIds(executorId, b2, resultBufferAllocator, c2) + // return r1 ++ r2 + // } val t = tag.incrementAndGet() @@ -322,18 +336,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i }, new UcpRequestParams().setMemoryType(UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) .setMemoryHandle(resultMemory.memory)) - val useWakeup = transport.ucxShuffleConf.useWakeup - if (useWakeup) { - while (!req.isCompleted) { - if (worker.progress() == 0) { - worker.waitForEvents() - } - } - } else { - while (!req.isCompleted) { - worker.progress() - } - } + progressBlocked(() => req.isCompleted) } catch { case ex: Throwable => logError(s"Failed to read and send data: $ex") } @@ -358,10 +361,12 @@ class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with L logDebug(s"UCX-worker $id stopped") } + @inline def processTask(): Unit = { outstandingTasks.take().run() } + @inline def submit(task: Runnable): Unit = { outstandingTasks.put(task) } From f94359f2fee8ddc13e5642660bf41eb9b59cf8a7 Mon Sep 17 00:00:00 2001 From: zizhao Date: Tue, 13 Jun 2023 10:20:04 +0300 Subject: [PATCH 12/33] bind client worker --- .../compat/spark_2_4/UcxShuffleClient.scala | 2 +- .../shuffle/ucx/UcxShuffleTransport.scala | 24 ++++++++++++++++--- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala index f3ed456e..22301d7f 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala @@ -33,6 +33,6 @@ class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient } override def close(): Unit = { - + transport.releaseClientWorker() } } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index 96836eb4..b8583b6f 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -85,7 +85,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo val executorAddresses = new TrieMap[ExecutorId, ByteBuffer] private var allocatedClientWorkers: Array[UcxWorkerWrapper] = _ - private val clientWorkerId = new AtomicInteger() + private var clientWorkerIds: ArrayBlockingQueue[Int] = _ private val clientWorker = new ThreadLocal[UcxWorkerWrapper] private var allocatedServerThreads: Array[UcxWorkerThread] = _ @@ -151,12 +151,14 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo progressThread.start() allocatedClientWorkers = new Array[UcxWorkerWrapper](ucxShuffleConf.numWorkers) + clientWorkerIds = new ArrayBlockingQueue[Int](ucxShuffleConf.numWorkers) logInfo(s"Allocating ${ucxShuffleConf.numWorkers} client workers") for (i <- 0 until ucxShuffleConf.numWorkers) { - val clientId: Long = ((i.toLong + 1L) << 32) | executorId + val clientId = genClientId(i.toLong, executorId) ucpWorkerParams.setClientId(clientId) val worker = ucxContext.newWorker(ucpWorkerParams) allocatedClientWorkers(i) = UcxWorkerWrapper(worker, this, isClientWorker = true, clientId) + clientWorkerIds.add(i) } initialized = true @@ -316,16 +318,32 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo }) } + @inline def selectClientWorker(): UcxWorkerWrapper = Option(clientWorker.get) match { case Some(worker) => worker case None => { val worker = allocatedClientWorkers( - (clientWorkerId.incrementAndGet() % allocatedClientWorkers.length).abs) + (clientWorkerIds.poll() % allocatedClientWorkers.length).abs) clientWorker.set(worker) worker } } + @inline + def releaseClientWorker(): Unit = Option(clientWorker.get) match { + case Some(worker) => { + clientWorker.set(null) + clientWorkerIds.add(getWorkerId(worker.id).toInt) + } + case None => {} + } + + @inline + def genClientId(workerId: Long, executorId: Long): Long = ((workerId + 1L) << 32) | executorId + + @inline + def getWorkerId(clientId: Long): Long = (clientId >> 32) - 1L + @inline def selectServerThread(): UcxWorkerThread = allocatedServerThreads( (serverThreadId.incrementAndGet() % allocatedServerThreads.length).abs From 4c3ba6e607775a1272d6f577ecb7ac02c3f2affd Mon Sep 17 00:00:00 2001 From: zizhao Date: Tue, 13 Jun 2023 12:47:26 +0300 Subject: [PATCH 13/33] use rr instead of blocked queue to select client --- .../compat/spark_2_4/UcxShuffleClient.scala | 2 +- .../shuffle/ucx/UcxShuffleTransport.scala | 23 +++---------------- 2 files changed, 4 insertions(+), 21 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala index 22301d7f..f3ed456e 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala @@ -33,6 +33,6 @@ class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient } override def close(): Unit = { - transport.releaseClientWorker() + } } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index b8583b6f..a01e6bcc 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -85,7 +85,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo val executorAddresses = new TrieMap[ExecutorId, ByteBuffer] private var allocatedClientWorkers: Array[UcxWorkerWrapper] = _ - private var clientWorkerIds: ArrayBlockingQueue[Int] = _ + private var clientWorkerId = new AtomicInteger() private val clientWorker = new ThreadLocal[UcxWorkerWrapper] private var allocatedServerThreads: Array[UcxWorkerThread] = _ @@ -151,14 +151,12 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo progressThread.start() allocatedClientWorkers = new Array[UcxWorkerWrapper](ucxShuffleConf.numWorkers) - clientWorkerIds = new ArrayBlockingQueue[Int](ucxShuffleConf.numWorkers) logInfo(s"Allocating ${ucxShuffleConf.numWorkers} client workers") for (i <- 0 until ucxShuffleConf.numWorkers) { - val clientId = genClientId(i.toLong, executorId) + val clientId: Long = ((i.toLong + 1L) << 32) | executorId ucpWorkerParams.setClientId(clientId) val worker = ucxContext.newWorker(ucpWorkerParams) allocatedClientWorkers(i) = UcxWorkerWrapper(worker, this, isClientWorker = true, clientId) - clientWorkerIds.add(i) } initialized = true @@ -323,27 +321,12 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo case Some(worker) => worker case None => { val worker = allocatedClientWorkers( - (clientWorkerIds.poll() % allocatedClientWorkers.length).abs) + (clientWorkerId.incrementAndGet() % allocatedClientWorkers.length).abs) clientWorker.set(worker) worker } } - @inline - def releaseClientWorker(): Unit = Option(clientWorker.get) match { - case Some(worker) => { - clientWorker.set(null) - clientWorkerIds.add(getWorkerId(worker.id).toInt) - } - case None => {} - } - - @inline - def genClientId(workerId: Long, executorId: Long): Long = ((workerId + 1L) << 32) | executorId - - @inline - def getWorkerId(clientId: Long): Long = (clientId >> 32) - 1L - @inline def selectServerThread(): UcxWorkerThread = allocatedServerThreads( (serverThreadId.incrementAndGet() % allocatedServerThreads.length).abs From d8c12f1a3a95aa1543a8bf1bc7fac7929f011c98 Mon Sep 17 00:00:00 2001 From: zizhao Date: Tue, 13 Jun 2023 18:36:18 +0300 Subject: [PATCH 14/33] add progressBlocked for flush --- .../org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala | 7 +++++-- .../org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala | 6 ++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index a01e6bcc..06a68c1d 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -191,8 +191,11 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo globalWorker = null } - allocatedServerThreads.foreach(_.close()) - allocatedServerThreads.foreach(_.join(10)) + allocatedServerThreads.foreach{ case(t) => + t.close() + t.join(10) + t.workerWrapper.close() + } if (ucxContext != null) { ucxContext.close() diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index 41708d0d..8c33e12a 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -287,6 +287,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i s"in ${System.nanoTime() - startTime} ns") } }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) + Seq(request) } @@ -348,6 +349,7 @@ class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with L private val stopping = new AtomicBoolean(false) private val outstandingTasks = new LinkedBlockingQueue[Runnable]() + private val dummy = new Runnable { override def run = {}} setDaemon(true) setName(s"UCX-worker $id") @@ -357,7 +359,6 @@ class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with L while (!stopping.get()) { processTask() } - workerWrapper.close() logDebug(s"UCX-worker $id stopped") } @@ -371,8 +372,9 @@ class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with L outstandingTasks.put(task) } + @inline def close(): Unit = { - logDebug(s"UCX-worker $id stopping") stopping.set(true) + outstandingTasks.put(dummy) } } \ No newline at end of file From 893ee5c0c79545f5f333edbd6f6a408a5749ba50 Mon Sep 17 00:00:00 2001 From: zizhao Date: Wed, 14 Jun 2023 05:33:44 +0300 Subject: [PATCH 15/33] add asyn --- .../shuffle/ucx/UcxShuffleTransport.scala | 4 +- .../spark/shuffle/ucx/UcxWorkerWrapper.scala | 58 +++++++++++++------ 2 files changed, 43 insertions(+), 19 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index 06a68c1d..8b00a5a2 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -192,9 +192,9 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo } allocatedServerThreads.foreach{ case(t) => - t.close() + t.interrupt() t.join(10) - t.workerWrapper.close() + t.close() } if (ucxContext != null) { diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index 8c33e12a..15097980 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -6,7 +6,7 @@ package org.apache.spark.shuffle.ucx import java.io.Closeable import java.util.concurrent.{ConcurrentLinkedQueue, LinkedBlockingQueue} -import java.util.concurrent.atomic.{AtomicInteger, AtomicBoolean} +import java.util.concurrent.atomic.AtomicInteger import scala.collection.concurrent.TrieMap import scala.util.Random import org.openucx.jucx.ucp._ @@ -62,6 +62,7 @@ class UcxRefCountMemoryBlock(baseBlock: MemoryBlock, offset: Long, size: Long, case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, isClientWorker: Boolean, id: Long = 0L) extends Closeable with Logging { + private val useWakeup = transport.ucxShuffleConf.useWakeup private final val connections = new TrieMap[transport.ExecutorId, UcpEndpoint] private val requestData = new TrieMap[Int, (Seq[OperationCallback], UcxRequest, transport.BufferAllocator)] @@ -179,13 +180,14 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i @inline def progressBlocked(isFinished: () => Boolean): Unit = { - transport.ucxShuffleConf.useWakeup match { - case true => while (!isFinished()) { + if (useWakeup) { + while (!isFinished()) { if (worker.progress() == 0) { worker.waitForEvents() } } - case false => while (!isFinished()) { + } else { + while (!isFinished()) { worker.progress() } } @@ -291,7 +293,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i Seq(request) } - def handleFetchBlockRequest(blocks: Seq[Block], replyTag: Int, replyExecutor: Long): Unit = try { + def handleFetchBlockRequest(blocks: Seq[Block], replyTag: Int, replyExecutor: Long): UcpRequest = try { val tagAndSizes = UnsafeUtils.INT_SIZE + UnsafeUtils.INT_SIZE * blocks.length val resultMemory = transport.hostBounceBufferMemoryPool.get(tagAndSizes + blocks.map(_.getSize).sum) .asInstanceOf[UcxBounceBufferMemoryBlock] @@ -323,7 +325,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i } val startTime = System.nanoTime() - val req = getConnection(replyExecutor).sendAmNonBlocking(1, resultMemory.address, tagAndSizes, + getConnection(replyExecutor).sendAmNonBlocking(1, resultMemory.address, tagAndSizes, resultMemory.address + tagAndSizes, resultMemory.size - tagAndSizes, 0, new UcxCallback { override def onSuccess(request: UcpRequest): Unit = { logTrace(s"Sent ${blocks.length} blocks of size: ${resultMemory.size} " + @@ -336,45 +338,67 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i } }, new UcpRequestParams().setMemoryType(UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) .setMemoryHandle(resultMemory.memory)) - - progressBlocked(() => req.isCompleted) } catch { case ex: Throwable => logError(s"Failed to read and send data: $ex") + null } } class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with Logging { val id = workerWrapper.id + val worker = workerWrapper.worker + val transport = workerWrapper.transport + val useWakeup = workerWrapper.transport.ucxShuffleConf.useWakeup - private val stopping = new AtomicBoolean(false) - private val outstandingTasks = new LinkedBlockingQueue[Runnable]() - private val dummy = new Runnable { override def run = {}} + private val outstandingRequests = new ConcurrentLinkedQueue[UcpRequest]() + private val outstandingTasks = new ConcurrentLinkedQueue[Runnable]() setDaemon(true) setName(s"UCX-worker $id") override def run(): Unit = { logDebug(s"UCX-worker $id started") - while (!stopping.get()) { + while (!isInterrupted) { processTask() + processRequest() } logDebug(s"UCX-worker $id stopped") } @inline - def processTask(): Unit = { - outstandingTasks.take().run() + def processTask(): Unit = Option(outstandingTasks.poll()) match { + case Some(task) => task.run() + case None => {} + } + + @inline + def processRequest(): Unit = { + var req = outstandingRequests.peek() + while(req != null && req.isCompleted) { + outstandingRequests.poll() + req = outstandingRequests.peek() + } + while (worker.progress() != 0) {} + if (outstandingTasks.isEmpty && useWakeup) { + worker.waitForEvents() + } } @inline def submit(task: Runnable): Unit = { - outstandingTasks.put(task) + outstandingTasks.offer(task) + worker.signal() + } + + @inline + def submit(request: UcpRequest): Unit = { + outstandingRequests.offer(request) + worker.signal() } @inline def close(): Unit = { - stopping.set(true) - outstandingTasks.put(dummy) + workerWrapper.close() } } \ No newline at end of file From 1546d738b67a186868826abc5e01139594469557 Mon Sep 17 00:00:00 2001 From: zizhao Date: Wed, 14 Jun 2023 06:24:39 +0300 Subject: [PATCH 16/33] add pre connection --- .../org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index 8b00a5a2..d941409f 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -219,11 +219,12 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo def addExecutors(executorIdsToAddress: Map[ExecutorId, SerializableDirectBuffer]): Unit = { executorIdsToAddress.foreach { case (executorId, address) => executorAddresses.put(executorId, address.value) + allocatedClientWorkers.foreach(_.getConnection(executorId)) } } - def preConnect(): Unit = { - allocatedClientWorkers.foreach(_.preconnect()) + def preConnect(): Unit = { + allocatedClientWorkers.foreach(_.progressConnect) } /** From 3f14725616104c2095d7804998b904536bc01233 Mon Sep 17 00:00:00 2001 From: zizhao Date: Wed, 14 Jun 2023 07:43:47 +0800 Subject: [PATCH 17/33] submit req after fetch --- .../org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index d941409f..775be74c 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -315,7 +315,10 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo val blocks = blockIds.map(bid => registeredBlocks(bid)) amData.close() - server.workerWrapper.handleFetchBlockRequest(blocks, replyTag, replyExecutor) + Option(server.workerWrapper.handleFetchBlockRequest(blocks, replyTag, replyExecutor)) match { + case Some(req) => server.submit(req) + case None => {} + } } }) } From 931fba1e06a2fcf94dd315ce410b961a6384754a Mon Sep 17 00:00:00 2001 From: zizhao Date: Tue, 20 Jun 2023 11:50:59 +0800 Subject: [PATCH 18/33] bind client worker to fixed thread --- .../compat/spark_2_4/UcxShuffleClient.scala | 4 +- .../compat/spark_2_4/UcxShuffleReader.scala | 29 +----- .../compat/spark_3_0/UcxShuffleClient.scala | 4 +- .../compat/spark_3_0/UcxShuffleReader.scala | 31 +----- .../spark/shuffle/ucx/ShuffleTransport.scala | 2 +- .../shuffle/ucx/UcxShuffleTransport.scala | 94 +++++++++++-------- .../spark/shuffle/ucx/UcxWorkerWrapper.scala | 62 ++++-------- .../shuffle/ucx/perf/UcxPerfBenchmark.scala | 6 +- .../ucx/rpc/UcxExecutorRpcEndpoint.scala | 13 +-- 9 files changed, 85 insertions(+), 160 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala index f3ed456e..20f87378 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala @@ -8,7 +8,7 @@ import org.apache.spark.shuffle.utils.UnsafeUtils import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId} class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient{ - val worker = transport.selectClientWorker + override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String], listener: BlockFetchingListener, downloadFileManager: DownloadFileManager): Unit = { @@ -29,7 +29,7 @@ class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient } } val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size) - worker.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds, resultBufferAllocator, callbacks) + transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds, resultBufferAllocator, callbacks) } override def close(): Unit = { diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala index b59b6ad8..6ff24cf6 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala @@ -50,35 +50,8 @@ class UcxShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C], SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) - // Ucx shuffle logic - // Java reflection to get access to private results queue - val worker = shuffleClient.worker - val queueField = wrappedStreams.getClass.getDeclaredField( - "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results") - queueField.setAccessible(true) - val resultQueue = queueField.get(wrappedStreams).asInstanceOf[LinkedBlockingQueue[_]] - - // Do progress if queue is empty before calling next on ShuffleIterator - val ucxWrappedStream = new Iterator[(BlockId, InputStream)] { - override def next(): (BlockId, InputStream) = { - val startTime = System.currentTimeMillis() - worker.progressBlocked(() => !resultQueue.isEmpty) - shuffleMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime) - wrappedStreams.next() - } - - override def hasNext: Boolean = { - val result = wrappedStreams.hasNext - if (!result) { - shuffleClient.close() - } - result - } - } - // End of ucx shuffle logic - val serializerInstance = dep.serializer.newInstance() - val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) => + val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) => // Note: the asKeyValueIterator below wraps a key/value iterator inside of a // NextIterator. The NextIterator makes sure that close() is called on the // underlying InputStream when all records have been read. diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala index e94953f4..d2d6d50e 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala @@ -12,7 +12,7 @@ import org.apache.spark.shuffle.utils.UnsafeUtils import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId} class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Map[Long, Int]) extends BlockStoreClient with Logging { - val worker = transport.selectClientWorker + override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String], listener: BlockFetchingListener, downloadFileManager: DownloadFileManager): Unit = { @@ -40,7 +40,7 @@ class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Ma } } val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size) - worker.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds, resultBufferAllocator, callbacks) + transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds, resultBufferAllocator, callbacks) } override def close(): Unit = { diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala index 3f658b56..2b6e84a4 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala @@ -104,39 +104,10 @@ private[spark] class UcxShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C], val wrappedStreams = shuffleIterator.toCompletionIterator - - // Ucx shuffle logic - // Java reflection to get access to private results queue - val worker = shuffleClient.worker - val queueField = shuffleIterator.getClass.getDeclaredField( - "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results") - queueField.setAccessible(true) - val resultQueue = queueField.get(shuffleIterator).asInstanceOf[LinkedBlockingQueue[_]] - - // Do progress if queue is empty before calling next on ShuffleIterator - val ucxWrappedStream = new Iterator[(BlockId, InputStream)] { - override def next(): (BlockId, InputStream) = { - val startTime = System.nanoTime() - worker.progressBlocked(() => !resultQueue.isEmpty) - val fetchWaitTime = System.nanoTime() - startTime - readMetrics.incFetchWaitTime(TimeUnit.NANOSECONDS.toMillis(fetchWaitTime)) - wrappedStreams.next() - } - - override def hasNext: Boolean = { - val result = wrappedStreams.hasNext - if (!result) { - shuffleClient.close() - } - result - } - } - // End of ucx shuffle logic - val serializerInstance = dep.serializer.newInstance() // Create a key/value iterator for each stream - val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) => + val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) => // Note: the asKeyValueIterator below wraps a key/value iterator inside of a // NextIterator. The NextIterator makes sure that close() is called on the // underlying InputStream when all records have been read. diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala index 314b88a3..cf6dbc98 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala @@ -155,7 +155,7 @@ trait ShuffleTransport { */ def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId], resultBufferAllocator: BufferAllocator, - callbacks: Seq[OperationCallback]): Seq[Request] + callbacks: Seq[OperationCallback]): Unit /** * Progress outstanding operations. This routine is blocking (though may poll for event). diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index 775be74c..ea6e39c5 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -17,7 +17,6 @@ import org.openucx.jucx.ucs.UcsConstants import java.lang.ThreadLocal import java.net.InetSocketAddress import java.nio.ByteBuffer -import java.util.concurrent.ArrayBlockingQueue import java.util.concurrent.atomic.AtomicInteger import scala.collection.concurrent.TrieMap import scala.collection.mutable @@ -84,9 +83,9 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo val endpoints = mutable.Set.empty[UcpEndpoint] val executorAddresses = new TrieMap[ExecutorId, ByteBuffer] - private var allocatedClientWorkers: Array[UcxWorkerWrapper] = _ - private var clientWorkerId = new AtomicInteger() - private val clientWorker = new ThreadLocal[UcxWorkerWrapper] + private var allocatedClientThreads: Array[UcxWorkerThread] = _ + private var clientThreadId = new AtomicInteger() + // private var clientLocal = new ThreadLocal[UcxWorkerThread] = _ private var allocatedServerThreads: Array[UcxWorkerThread] = _ private val serverThreadId = new AtomicInteger() @@ -150,13 +149,15 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo progressThread = new GlobalWorkerRpcThread(globalWorker, this) progressThread.start() - allocatedClientWorkers = new Array[UcxWorkerWrapper](ucxShuffleConf.numWorkers) + allocatedClientThreads = new Array[UcxWorkerThread](ucxShuffleConf.numWorkers) logInfo(s"Allocating ${ucxShuffleConf.numWorkers} client workers") for (i <- 0 until ucxShuffleConf.numWorkers) { val clientId: Long = ((i.toLong + 1L) << 32) | executorId ucpWorkerParams.setClientId(clientId) val worker = ucxContext.newWorker(ucpWorkerParams) - allocatedClientWorkers(i) = UcxWorkerWrapper(worker, this, isClientWorker = true, clientId) + val workerWrapper = UcxWorkerWrapper(worker, this, isClientWorker = true, clientId) + allocatedClientThreads(i) = new UcxWorkerThread(workerWrapper) + allocatedClientThreads(i).start() } initialized = true @@ -174,7 +175,11 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo hostBounceBufferMemoryPool.close() - allocatedClientWorkers.foreach(_.close()) + allocatedClientThreads.foreach{ case(t) => + t.interrupt() + t.join(10) + t.close() + } if (listener != null) { listener.close() @@ -210,21 +215,33 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo */ override def addExecutor(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = { executorAddresses.put(executorId, workerAddress) - allocatedClientWorkers.foreach(w => { - w.getConnection(executorId) - w.progressConnect() - }) + allocatedClientThreads.foreach { t => t.submit( + new Runnable { + override def run = { + t.workerWrapper.getConnection(executorId) + t.workerWrapper.progressConnect() + } + }) + } } def addExecutors(executorIdsToAddress: Map[ExecutorId, SerializableDirectBuffer]): Unit = { executorIdsToAddress.foreach { case (executorId, address) => executorAddresses.put(executorId, address.value) - allocatedClientWorkers.foreach(_.getConnection(executorId)) + } + allocatedClientThreads.foreach { t => t.submit( + new Runnable { + override def run = { + executorIdsToAddress.foreach { + case (executorId, _) => t.workerWrapper.getConnection(executorId) + } + t.workerWrapper.progressConnect() + } + }) } } - def preConnect(): Unit = { - allocatedClientWorkers.foreach(_.progressConnect) + def preConnect(): Unit = { } /** @@ -282,24 +299,25 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo */ override def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId], resultBufferAllocator: BufferAllocator, - callbacks: Seq[OperationCallback]): Seq[Request] = { - selectClientWorker - .fetchBlocksByBlockIds(executorId, blockIds, resultBufferAllocator, callbacks) + callbacks: Seq[OperationCallback]): Unit = { + val client = selectClientThread + client.submit(new Runnable { + override def run = client.workerWrapper.fetchBlocksByBlockIds( + executorId, blockIds, resultBufferAllocator, callbacks) + }) } def connectServerWorkers(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = { executorAddresses.put(executorId, workerAddress) allocatedServerThreads.foreach(t => t.submit(new Runnable { - override def run(): Unit = { - t.workerWrapper.connectByWorkerAddress(executorId, workerAddress) - } + override def run = t.workerWrapper.connectByWorkerAddress(executorId, workerAddress) })) } def handleFetchBlockRequest(replyTag: Int, amData: UcpAmData, replyExecutor: Long): Unit = { val server = selectServerThread server.submit(new Runnable { - override def run(): Unit = { + override def run = { val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress, amData.getLength.toInt) val blockIds = mutable.ArrayBuffer.empty[BlockId] @@ -315,29 +333,28 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo val blocks = blockIds.map(bid => registeredBlocks(bid)) amData.close() - Option(server.workerWrapper.handleFetchBlockRequest(blocks, replyTag, replyExecutor)) match { - case Some(req) => server.submit(req) - case None => {} - } + server.workerWrapper.handleFetchBlockRequest(blocks, replyTag, replyExecutor) } }) } + // @inline + // def selectClientThread(): UcxWorkerThread = Option(clientLocal.get) match { + // case Some(client) => client + // case None => + // val client = allocatedClientThreads( + // (clientThreadId.incrementAndGet() % allocatedClientThreads.length).abs) + // clientLocal.set(client) + // client + // } + @inline - def selectClientWorker(): UcxWorkerWrapper = Option(clientWorker.get) match { - case Some(worker) => worker - case None => { - val worker = allocatedClientWorkers( - (clientWorkerId.incrementAndGet() % allocatedClientWorkers.length).abs) - clientWorker.set(worker) - worker - } - } + def selectClientThread(): UcxWorkerThread = allocatedClientThreads( + (clientThreadId.incrementAndGet() % allocatedClientThreads.length).abs) @inline def selectServerThread(): UcxWorkerThread = allocatedServerThreads( - (serverThreadId.incrementAndGet() % allocatedServerThreads.length).abs - ) + (serverThreadId.incrementAndGet() % allocatedServerThreads.length).abs) /** * Progress outstanding operations. This routine is blocking (though may poll for event). @@ -347,10 +364,5 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo * But not guaranteed that at least one [[ fetchBlocksByBlockIds ]] completed! */ override def progress(): Unit = { - selectClientWorker.progress() - } - - def progressConnect(): Unit = { - allocatedClientWorkers.par.foreach(_.progressConnect()) } } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index 15097980..042c579e 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -5,7 +5,7 @@ package org.apache.spark.shuffle.ucx import java.io.Closeable -import java.util.concurrent.{ConcurrentLinkedQueue, LinkedBlockingQueue} +import java.util.concurrent.{ConcurrentLinkedQueue, Callable, FutureTask} import java.util.concurrent.atomic.AtomicInteger import scala.collection.concurrent.TrieMap import scala.util.Random @@ -252,20 +252,11 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i def fetchBlocksByBlockIds(executorId: transport.ExecutorId, blockIds: Seq[BlockId], resultBufferAllocator: transport.BufferAllocator, - callbacks: Seq[OperationCallback]): Seq[Request] = { + callbacks: Seq[OperationCallback]): Unit = { val startTime = System.nanoTime() val headerSize = UnsafeUtils.INT_SIZE + UnsafeUtils.LONG_SIZE val ep = getConnection(executorId) - // if (worker.getMaxAmHeaderSize <= - // headerSize + UnsafeUtils.INT_SIZE * blockIds.length) { - // val (b1, b2) = blockIds.splitAt(blockIds.length / 2) - // val (c1, c2) = callbacks.splitAt(callbacks.length / 2) - // val r1 = fetchBlocksByBlockIds(executorId, b1, resultBufferAllocator, c1) - // val r2 = fetchBlocksByBlockIds(executorId, b2, resultBufferAllocator, c2) - // return r1 ++ r2 - // } - val t = tag.incrementAndGet() val buffer = Platform.allocateDirectBuffer(headerSize + blockIds.map(_.serializedSize).sum) @@ -289,11 +280,9 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i s"in ${System.nanoTime() - startTime} ns") } }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) - - Seq(request) } - def handleFetchBlockRequest(blocks: Seq[Block], replyTag: Int, replyExecutor: Long): UcpRequest = try { + def handleFetchBlockRequest(blocks: Seq[Block], replyTag: Int, replyExecutor: Long): Unit = try { val tagAndSizes = UnsafeUtils.INT_SIZE + UnsafeUtils.INT_SIZE * blocks.length val resultMemory = transport.hostBounceBufferMemoryPool.get(tagAndSizes + blocks.map(_.getSize).sum) .asInstanceOf[UcxBounceBufferMemoryBlock] @@ -340,7 +329,6 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i .setMemoryHandle(resultMemory.memory)) } catch { case ex: Throwable => logError(s"Failed to read and send data: $ex") - null } } @@ -351,8 +339,7 @@ class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with L val transport = workerWrapper.transport val useWakeup = workerWrapper.transport.ucxShuffleConf.useWakeup - private val outstandingRequests = new ConcurrentLinkedQueue[UcpRequest]() - private val outstandingTasks = new ConcurrentLinkedQueue[Runnable]() + private val taskQueue = new ConcurrentLinkedQueue[FutureTask[_]]() setDaemon(true) setName(s"UCX-worker $id") @@ -360,41 +347,32 @@ class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with L override def run(): Unit = { logDebug(s"UCX-worker $id started") while (!isInterrupted) { - processTask() - processRequest() + Option(taskQueue.poll()) match { + case Some(task) => task.run + case None => {} + } + while (worker.progress() != 0) {} + if(taskQueue.isEmpty && useWakeup) { + worker.waitForEvents() + } } logDebug(s"UCX-worker $id stopped") } @inline - def processTask(): Unit = Option(outstandingTasks.poll()) match { - case Some(task) => task.run() - case None => {} - } - - @inline - def processRequest(): Unit = { - var req = outstandingRequests.peek() - while(req != null && req.isCompleted) { - outstandingRequests.poll() - req = outstandingRequests.peek() - } - while (worker.progress() != 0) {} - if (outstandingTasks.isEmpty && useWakeup) { - worker.waitForEvents() - } - } - - @inline - def submit(task: Runnable): Unit = { - outstandingTasks.offer(task) + def submit(task: Callable[_]) = { + val future = new FutureTask(task) + taskQueue.offer(future) worker.signal() + future } @inline - def submit(request: UcpRequest): Unit = { - outstandingRequests.offer(request) + def submit(task: Runnable) = { + val future = new FutureTask(task, Unit) + taskQueue.offer(future) worker.signal() + future } @inline diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala b/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala index 16bd821b..d81bb47c 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala @@ -145,9 +145,9 @@ object UcxPerfBenchmark extends App with Logging { } } val requests = ucxTransport.fetchBlocksByBlockIds(1, blocks, resultBufferAllocator, callbacks) - while (!requests.forall(_.isCompleted)) { - ucxTransport.progress() - } + // while (!requests.forall(_.isCompleted)) { + // ucxTransport.progress() + // } } } ucxTransport.close() diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala index bd9ebb74..7f0be8ab 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala @@ -20,18 +20,9 @@ class UcxExecutorRpcEndpoint(override val rpcEnv: RpcEnv, transport: UcxShuffleT case ExecutorAdded(executorId: Long, _: RpcEndpointRef, ucxWorkerAddress: SerializableDirectBuffer) => logDebug(s"Received ExecutorAdded($executorId)") - executorService.submit(new Runnable() { - override def run(): Unit = { - transport.addExecutor(executorId, ucxWorkerAddress.value) - } - }) + transport.addExecutor(executorId, ucxWorkerAddress.value) case IntroduceAllExecutors(executorIdToWorkerAdresses: Map[Long, SerializableDirectBuffer]) => logDebug(s"Received IntroduceAllExecutors(${executorIdToWorkerAdresses.keys.mkString(",")}") - executorService.submit(new Runnable() { - override def run(): Unit = { - transport.addExecutors(executorIdToWorkerAdresses) - transport.preConnect() - } - }) + transport.addExecutors(executorIdToWorkerAdresses) } } From 6a49fd6700d776c7d45bcf2cd8ca740f88a6b5fc Mon Sep 17 00:00:00 2001 From: zizhao Date: Sun, 25 Jun 2023 06:02:05 +0300 Subject: [PATCH 19/33] temp --- .../shuffle/ucx/UcxShuffleTransport.scala | 35 ++++++++++--------- .../ucx/rpc/UcxExecutorRpcEndpoint.scala | 4 ++- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index ea6e39c5..c8410967 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -215,13 +215,11 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo */ override def addExecutor(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = { executorAddresses.put(executorId, workerAddress) - allocatedClientThreads.foreach { t => t.submit( - new Runnable { - override def run = { - t.workerWrapper.getConnection(executorId) - t.workerWrapper.progressConnect() - } - }) + allocatedClientThreads.foreach { t => t.submit(new Runnable { + override def run = { + t.workerWrapper.getConnection(executorId) + t.workerWrapper.progressConnect() + }}) } } @@ -229,16 +227,21 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo executorIdsToAddress.foreach { case (executorId, address) => executorAddresses.put(executorId, address.value) } - allocatedClientThreads.foreach { t => t.submit( - new Runnable { - override def run = { - executorIdsToAddress.foreach { - case (executorId, _) => t.workerWrapper.getConnection(executorId) - } - t.workerWrapper.progressConnect() + allocatedClientThreads.foreach(t => t.submit(new Runnable { + override def run = { + val startTime = System.currentTimeMillis() + executorIdsToAddress.foreach { + case (executorId, _) => t.workerWrapper.getConnection(executorId) } - }) - } + t.workerWrapper.progressConnect() + logInfo(s"preconnect cost ${System.currentTimeMillis() - startTime}ms") + } + })) + // allocatedClientThreads.foreach(t => t.submit(new Runnable { + // override def run = { + // t.workerWrapper.progressConnect() + // } + // })) } def preConnect(): Unit = { diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala index 7f0be8ab..e8376f6d 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala @@ -22,7 +22,9 @@ class UcxExecutorRpcEndpoint(override val rpcEnv: RpcEnv, transport: UcxShuffleT logDebug(s"Received ExecutorAdded($executorId)") transport.addExecutor(executorId, ucxWorkerAddress.value) case IntroduceAllExecutors(executorIdToWorkerAdresses: Map[Long, SerializableDirectBuffer]) => - logDebug(s"Received IntroduceAllExecutors(${executorIdToWorkerAdresses.keys.mkString(",")}") + // logDebug(s"Received IntroduceAllExecutors(${executorIdToWorkerAdresses.keys.mkString(",")}") + val startTime = System.currentTimeMillis() transport.addExecutors(executorIdToWorkerAdresses) + logInfo(s"IntroduceAllExecutors cost ${System.currentTimeMillis() - startTime}ms") } } From e393aa3afc438e1768355d7ed6cbed940c0d699c Mon Sep 17 00:00:00 2001 From: zizhao Date: Mon, 26 Jun 2023 05:56:18 +0300 Subject: [PATCH 20/33] split context --- .../shuffle/ucx/UcxShuffleTransport.scala | 22 ++++++++++++------- .../ucx/rpc/GlobalWorkerRpcThread.scala | 3 +++ .../ucx/rpc/UcxExecutorRpcEndpoint.scala | 4 ++-- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index c8410967..ae50b3dc 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -83,10 +83,12 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo val endpoints = mutable.Set.empty[UcpEndpoint] val executorAddresses = new TrieMap[ExecutorId, ByteBuffer] + private[ucx] var cliContext: Array[UcpContext] = _ private var allocatedClientThreads: Array[UcxWorkerThread] = _ private var clientThreadId = new AtomicInteger() // private var clientLocal = new ThreadLocal[UcxWorkerThread] = _ - + + private[ucx] var srvContext: Array[UcpContext] = _ private var allocatedServerThreads: Array[UcxWorkerThread] = _ private val serverThreadId = new AtomicInteger() @@ -128,10 +130,12 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo globalWorker = ucxContext.newWorker(ucpWorkerParams) hostBounceBufferMemoryPool = new UcxHostBounceBuffersPool(ucxShuffleConf, ucxContext) + srvContext = new Array[UcpContext](ucxShuffleConf.numListenerThreads) allocatedServerThreads = new Array[UcxWorkerThread](ucxShuffleConf.numListenerThreads) logInfo(s"Allocating ${ucxShuffleConf.numListenerThreads} server workers") for (i <- 0 until ucxShuffleConf.numListenerThreads) { - val worker = ucxContext.newWorker(ucpWorkerParams) + srvContext(i) = new UcpContext(params) + val worker = srvContext(i).newWorker(ucpWorkerParams) val workerWrapper = UcxWorkerWrapper(worker, this, isClientWorker = false, i.toLong) allocatedServerThreads(i) = new UcxWorkerThread(workerWrapper) allocatedServerThreads(i).start() @@ -149,12 +153,14 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo progressThread = new GlobalWorkerRpcThread(globalWorker, this) progressThread.start() + cliContext = new Array[UcpContext](ucxShuffleConf.numWorkers) allocatedClientThreads = new Array[UcxWorkerThread](ucxShuffleConf.numWorkers) logInfo(s"Allocating ${ucxShuffleConf.numWorkers} client workers") for (i <- 0 until ucxShuffleConf.numWorkers) { + cliContext(i) = new UcpContext(params) val clientId: Long = ((i.toLong + 1L) << 32) | executorId ucpWorkerParams.setClientId(clientId) - val worker = ucxContext.newWorker(ucpWorkerParams) + val worker = cliContext(i).newWorker(ucpWorkerParams) val workerWrapper = UcxWorkerWrapper(worker, this, isClientWorker = true, clientId) allocatedClientThreads(i) = new UcxWorkerThread(workerWrapper) allocatedClientThreads(i).start() @@ -215,12 +221,12 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo */ override def addExecutor(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = { executorAddresses.put(executorId, workerAddress) - allocatedClientThreads.foreach { t => t.submit(new Runnable { + allocatedClientThreads.foreach(t => t.submit(new Runnable { override def run = { t.workerWrapper.getConnection(executorId) t.workerWrapper.progressConnect() - }}) - } + } + })) } def addExecutors(executorIdsToAddress: Map[ExecutorId, SerializableDirectBuffer]): Unit = { @@ -230,7 +236,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo allocatedClientThreads.foreach(t => t.submit(new Runnable { override def run = { val startTime = System.currentTimeMillis() - executorIdsToAddress.foreach { + executorAddresses.foreach { case (executorId, _) => t.workerWrapper.getConnection(executorId) } t.workerWrapper.progressConnect() @@ -306,7 +312,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo val client = selectClientThread client.submit(new Runnable { override def run = client.workerWrapper.fetchBlocksByBlockIds( - executorId, blockIds, resultBufferAllocator, callbacks) + executorId, blockIds, resultBufferAllocator, callbacks) }) } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala index 8b4676e9..ec7bb995 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala @@ -4,6 +4,7 @@ */ package org.apache.spark.shuffle.ucx.rpc +import java.nio.ByteBuffer import org.openucx.jucx.ucp.{UcpAmData, UcpConstants, UcpEndpoint, UcpWorker} import org.openucx.jucx.ucs.UcsConstants import org.apache.spark.internal.Logging @@ -32,6 +33,8 @@ class GlobalWorkerRpcThread(globalWorker: UcpWorker, transport: UcxShuffleTransp val header = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt) val executorId = header.getLong val workerAddress = UnsafeUtils.getByteBufferView(amData.getDataAddress, amData.getLength.toInt) + // val copiedAddress = ByteBuffer.allocateDirect(workerAddress.capacity).put(workerAddress) + // amData.close transport.connectServerWorkers(executorId, workerAddress) UcsConstants.STATUS.UCS_OK }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala index e8376f6d..47383300 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala @@ -20,11 +20,11 @@ class UcxExecutorRpcEndpoint(override val rpcEnv: RpcEnv, transport: UcxShuffleT case ExecutorAdded(executorId: Long, _: RpcEndpointRef, ucxWorkerAddress: SerializableDirectBuffer) => logDebug(s"Received ExecutorAdded($executorId)") - transport.addExecutor(executorId, ucxWorkerAddress.value) + transport.addExecutor(executorId, ucxWorkerAddress.value) // may need to use ucxWorkerAddress instead of its value case IntroduceAllExecutors(executorIdToWorkerAdresses: Map[Long, SerializableDirectBuffer]) => // logDebug(s"Received IntroduceAllExecutors(${executorIdToWorkerAdresses.keys.mkString(",")}") val startTime = System.currentTimeMillis() transport.addExecutors(executorIdToWorkerAdresses) - logInfo(s"IntroduceAllExecutors cost ${System.currentTimeMillis() - startTime}ms") + logInfo(s" IntroduceAllExecutors cost ${System.currentTimeMillis() - startTime}ms") } } From 7c471705d9c04cef92805f51fd8d25e4b3ad93e1 Mon Sep 17 00:00:00 2001 From: zizhao Date: Mon, 26 Jun 2023 07:11:16 +0300 Subject: [PATCH 21/33] client use monitor --- .../shuffle/ucx/UcxShuffleTransport.scala | 46 ++++++------------- .../spark/shuffle/ucx/UcxWorkerWrapper.scala | 4 +- .../ucx/rpc/GlobalWorkerRpcThread.scala | 2 - .../ucx/rpc/UcxExecutorRpcEndpoint.scala | 17 +++++-- 4 files changed, 28 insertions(+), 41 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index ae50b3dc..2e89efa2 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -83,12 +83,10 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo val endpoints = mutable.Set.empty[UcpEndpoint] val executorAddresses = new TrieMap[ExecutorId, ByteBuffer] - private[ucx] var cliContext: Array[UcpContext] = _ private var allocatedClientThreads: Array[UcxWorkerThread] = _ private var clientThreadId = new AtomicInteger() // private var clientLocal = new ThreadLocal[UcxWorkerThread] = _ - - private[ucx] var srvContext: Array[UcpContext] = _ + private var allocatedServerThreads: Array[UcxWorkerThread] = _ private val serverThreadId = new AtomicInteger() @@ -130,12 +128,10 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo globalWorker = ucxContext.newWorker(ucpWorkerParams) hostBounceBufferMemoryPool = new UcxHostBounceBuffersPool(ucxShuffleConf, ucxContext) - srvContext = new Array[UcpContext](ucxShuffleConf.numListenerThreads) allocatedServerThreads = new Array[UcxWorkerThread](ucxShuffleConf.numListenerThreads) logInfo(s"Allocating ${ucxShuffleConf.numListenerThreads} server workers") for (i <- 0 until ucxShuffleConf.numListenerThreads) { - srvContext(i) = new UcpContext(params) - val worker = srvContext(i).newWorker(ucpWorkerParams) + val worker = ucxContext.newWorker(ucpWorkerParams) val workerWrapper = UcxWorkerWrapper(worker, this, isClientWorker = false, i.toLong) allocatedServerThreads(i) = new UcxWorkerThread(workerWrapper) allocatedServerThreads(i).start() @@ -153,14 +149,12 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo progressThread = new GlobalWorkerRpcThread(globalWorker, this) progressThread.start() - cliContext = new Array[UcpContext](ucxShuffleConf.numWorkers) allocatedClientThreads = new Array[UcxWorkerThread](ucxShuffleConf.numWorkers) logInfo(s"Allocating ${ucxShuffleConf.numWorkers} client workers") for (i <- 0 until ucxShuffleConf.numWorkers) { - cliContext(i) = new UcpContext(params) val clientId: Long = ((i.toLong + 1L) << 32) | executorId ucpWorkerParams.setClientId(clientId) - val worker = cliContext(i).newWorker(ucpWorkerParams) + val worker = ucxContext.newWorker(ucpWorkerParams) val workerWrapper = UcxWorkerWrapper(worker, this, isClientWorker = true, clientId) allocatedClientThreads(i) = new UcxWorkerThread(workerWrapper) allocatedClientThreads(i).start() @@ -221,36 +215,22 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo */ override def addExecutor(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = { executorAddresses.put(executorId, workerAddress) - allocatedClientThreads.foreach(t => t.submit(new Runnable { - override def run = { - t.workerWrapper.getConnection(executorId) - t.workerWrapper.progressConnect() - } - })) + allocatedClientThreads.foreach(t => t.worker.synchronized { + t.workerWrapper.getConnection(executorId) + t.workerWrapper.progressConnect() + }) } def addExecutors(executorIdsToAddress: Map[ExecutorId, SerializableDirectBuffer]): Unit = { executorIdsToAddress.foreach { case (executorId, address) => executorAddresses.put(executorId, address.value) } - allocatedClientThreads.foreach(t => t.submit(new Runnable { - override def run = { - val startTime = System.currentTimeMillis() - executorAddresses.foreach { - case (executorId, _) => t.workerWrapper.getConnection(executorId) - } - t.workerWrapper.progressConnect() - logInfo(s"preconnect cost ${System.currentTimeMillis() - startTime}ms") - } - })) - // allocatedClientThreads.foreach(t => t.submit(new Runnable { - // override def run = { - // t.workerWrapper.progressConnect() - // } - // })) } def preConnect(): Unit = { + allocatedClientThreads.foreach(t => t.worker.synchronized { + t.workerWrapper.preconnect() + }) } /** @@ -318,9 +298,9 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo def connectServerWorkers(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = { executorAddresses.put(executorId, workerAddress) - allocatedServerThreads.foreach(t => t.submit(new Runnable { - override def run = t.workerWrapper.connectByWorkerAddress(executorId, workerAddress) - })) + allocatedServerThreads.foreach(t => t.worker.synchronized { + t.workerWrapper.connectByWorkerAddress(executorId, workerAddress) + }) } def handleFetchBlockRequest(replyTag: Int, amData: UcpAmData, replyExecutor: Long): Unit = { diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index 042c579e..5b86eb81 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -351,7 +351,9 @@ class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with L case Some(task) => task.run case None => {} } - while (worker.progress() != 0) {} + worker.synchronized { + while (worker.progress() != 0) {} + } if(taskQueue.isEmpty && useWakeup) { worker.waitForEvents() } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala index ec7bb995..d434542d 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala @@ -33,8 +33,6 @@ class GlobalWorkerRpcThread(globalWorker: UcpWorker, transport: UcxShuffleTransp val header = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt) val executorId = header.getLong val workerAddress = UnsafeUtils.getByteBufferView(amData.getDataAddress, amData.getLength.toInt) - // val copiedAddress = ByteBuffer.allocateDirect(workerAddress.capacity).put(workerAddress) - // amData.close transport.connectServerWorkers(executorId, workerAddress) UcsConstants.STATUS.UCS_OK }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala index 47383300..bd9ebb74 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala @@ -20,11 +20,18 @@ class UcxExecutorRpcEndpoint(override val rpcEnv: RpcEnv, transport: UcxShuffleT case ExecutorAdded(executorId: Long, _: RpcEndpointRef, ucxWorkerAddress: SerializableDirectBuffer) => logDebug(s"Received ExecutorAdded($executorId)") - transport.addExecutor(executorId, ucxWorkerAddress.value) // may need to use ucxWorkerAddress instead of its value + executorService.submit(new Runnable() { + override def run(): Unit = { + transport.addExecutor(executorId, ucxWorkerAddress.value) + } + }) case IntroduceAllExecutors(executorIdToWorkerAdresses: Map[Long, SerializableDirectBuffer]) => - // logDebug(s"Received IntroduceAllExecutors(${executorIdToWorkerAdresses.keys.mkString(",")}") - val startTime = System.currentTimeMillis() - transport.addExecutors(executorIdToWorkerAdresses) - logInfo(s" IntroduceAllExecutors cost ${System.currentTimeMillis() - startTime}ms") + logDebug(s"Received IntroduceAllExecutors(${executorIdToWorkerAdresses.keys.mkString(",")}") + executorService.submit(new Runnable() { + override def run(): Unit = { + transport.addExecutors(executorIdToWorkerAdresses) + transport.preConnect() + } + }) } } From cdc9992b32fd770884f1f6110ef74cbf9831adb2 Mon Sep 17 00:00:00 2001 From: zizhao Date: Tue, 27 Jun 2023 05:56:34 +0300 Subject: [PATCH 22/33] rm monitor --- .../spark/shuffle/ucx/UcxShuffleTransport.scala | 11 ++++++----- .../apache/spark/shuffle/ucx/UcxWorkerWrapper.scala | 4 +--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index 2e89efa2..08397457 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -215,7 +215,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo */ override def addExecutor(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = { executorAddresses.put(executorId, workerAddress) - allocatedClientThreads.foreach(t => t.worker.synchronized { + allocatedClientThreads.foreach(t => { t.workerWrapper.getConnection(executorId) t.workerWrapper.progressConnect() }) @@ -225,12 +225,13 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo executorIdsToAddress.foreach { case (executorId, address) => executorAddresses.put(executorId, address.value) } + allocatedClientThreads.foreach(t => { + executorIdsToAddress.keys.foreach(t.workerWrapper.getConnection(_)) + t.workerWrapper.progressConnect() + }) } def preConnect(): Unit = { - allocatedClientThreads.foreach(t => t.worker.synchronized { - t.workerWrapper.preconnect() - }) } /** @@ -298,7 +299,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo def connectServerWorkers(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = { executorAddresses.put(executorId, workerAddress) - allocatedServerThreads.foreach(t => t.worker.synchronized { + allocatedServerThreads.foreach(t => { t.workerWrapper.connectByWorkerAddress(executorId, workerAddress) }) } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index 5b86eb81..042c579e 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -351,9 +351,7 @@ class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with L case Some(task) => task.run case None => {} } - worker.synchronized { - while (worker.progress() != 0) {} - } + while (worker.progress() != 0) {} if(taskQueue.isEmpty && useWakeup) { worker.waitForEvents() } From b04c2251dae3b28d09d1fe920ecfd48bb526308d Mon Sep 17 00:00:00 2001 From: zizhao Date: Fri, 30 Jun 2023 09:17:16 +0300 Subject: [PATCH 23/33] add sync in worker thread --- .../spark/shuffle/ucx/UcxWorkerWrapper.scala | 27 +++++-------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index 042c579e..c2783833 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -5,7 +5,7 @@ package org.apache.spark.shuffle.ucx import java.io.Closeable -import java.util.concurrent.{ConcurrentLinkedQueue, Callable, FutureTask} +import java.util.concurrent.{ConcurrentLinkedQueue, Callable, Future, FutureTask} import java.util.concurrent.atomic.AtomicInteger import scala.collection.concurrent.TrieMap import scala.util.Random @@ -178,21 +178,6 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i worker.progress() } - @inline - def progressBlocked(isFinished: () => Boolean): Unit = { - if (useWakeup) { - while (!isFinished()) { - if (worker.progress() == 0) { - worker.waitForEvents() - } - } - } else { - while (!isFinished()) { - worker.progress() - } - } - } - /** * Establish connections to known instances. */ @@ -339,7 +324,7 @@ class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with L val transport = workerWrapper.transport val useWakeup = workerWrapper.transport.ucxShuffleConf.useWakeup - private val taskQueue = new ConcurrentLinkedQueue[FutureTask[_]]() + private val taskQueue = new ConcurrentLinkedQueue[Runnable]() setDaemon(true) setName(s"UCX-worker $id") @@ -351,7 +336,9 @@ class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with L case Some(task) => task.run case None => {} } - while (worker.progress() != 0) {} + worker.synchronized { + while (worker.progress() != 0) {} + } if(taskQueue.isEmpty && useWakeup) { worker.waitForEvents() } @@ -360,7 +347,7 @@ class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with L } @inline - def submit(task: Callable[_]) = { + def submit(task: Callable[_]): Future[_] = { val future = new FutureTask(task) taskQueue.offer(future) worker.signal() @@ -368,7 +355,7 @@ class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with L } @inline - def submit(task: Runnable) = { + def submit(task: Runnable): Future[Unit.type] = { val future = new FutureTask(task, Unit) taskQueue.offer(future) worker.signal() From 8267cf79f183ce199ac97168dd363a0b36d475ca Mon Sep 17 00:00:00 2001 From: zizhao Date: Mon, 3 Jul 2023 16:35:15 +0300 Subject: [PATCH 24/33] connection monitor --- .../org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index c2783833..c5ea6304 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -188,8 +188,10 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i def connectByWorkerAddress(executorId: transport.ExecutorId, workerAddress: ByteBuffer): Unit = { logDebug(s"Worker $this connecting back to $executorId by worker address") - val ep = worker.newEndpoint(new UcpEndpointParams().setName(s"Server connection to $executorId") + val ep = worker.synchronized { + worker.newEndpoint(new UcpEndpointParams().setName(s"Server connection to $executorId") .setUcpAddress(workerAddress)) + } connections.put(executorId, ep) } @@ -203,7 +205,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i } } - connections.getOrElseUpdate(executorId, { + connections.getOrElseUpdate(executorId, worker.synchronized { val address = transport.executorAddresses(executorId) val endpointParams = new UcpEndpointParams().setPeerErrorHandlingMode() .setSocketAddress(SerializationUtils.deserializeInetAddress(address)).sendClientId() From 1b5ffd5bb87d82e67dcd3a8aa07a41331eec3c00 Mon Sep 17 00:00:00 2001 From: zizhao Date: Wed, 5 Jul 2023 13:03:43 +0300 Subject: [PATCH 25/33] reduce synchronized region --- .../shuffle/ucx/UcxShuffleTransport.scala | 17 ++------ .../spark/shuffle/ucx/UcxWorkerWrapper.scala | 40 ++++++++++--------- 2 files changed, 25 insertions(+), 32 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index 08397457..a4e31f66 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -175,11 +175,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo hostBounceBufferMemoryPool.close() - allocatedClientThreads.foreach{ case(t) => - t.interrupt() - t.join(10) - t.close() - } + allocatedClientThreads.foreach(_.close) if (listener != null) { listener.close() @@ -196,11 +192,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo globalWorker = null } - allocatedServerThreads.foreach{ case(t) => - t.interrupt() - t.join(10) - t.close() - } + allocatedServerThreads.foreach(_.close) if (ucxContext != null) { ucxContext.close() @@ -225,13 +217,10 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo executorIdsToAddress.foreach { case (executorId, address) => executorAddresses.put(executorId, address.value) } - allocatedClientThreads.foreach(t => { - executorIdsToAddress.keys.foreach(t.workerWrapper.getConnection(_)) - t.workerWrapper.progressConnect() - }) } def preConnect(): Unit = { + allocatedClientThreads.foreach(_.workerWrapper.preconnect()) } /** diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index c5ea6304..3f6f810c 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -190,7 +190,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i logDebug(s"Worker $this connecting back to $executorId by worker address") val ep = worker.synchronized { worker.newEndpoint(new UcpEndpointParams().setName(s"Server connection to $executorId") - .setUcpAddress(workerAddress)) + .setUcpAddress(workerAddress)) } connections.put(executorId, ep) } @@ -205,7 +205,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i } } - connections.getOrElseUpdate(executorId, worker.synchronized { + connections.getOrElseUpdate(executorId, { val address = transport.executorAddresses(executorId) val endpointParams = new UcpEndpointParams().setPeerErrorHandlingMode() .setSocketAddress(SerializationUtils.deserializeInetAddress(address)).sendClientId() @@ -218,22 +218,24 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i logDebug(s"Worker $this connecting to Executor($executorId, " + s"${SerializationUtils.deserializeInetAddress(address)}") - val ep = worker.newEndpoint(endpointParams) - val header = Platform.allocateDirectBuffer(UnsafeUtils.LONG_SIZE) - header.putLong(id) - header.rewind() - val workerAddress = worker.getAddress - - ep.sendAmNonBlocking(1, UcxUtils.getAddress(header), UnsafeUtils.LONG_SIZE, - UcxUtils.getAddress(workerAddress), workerAddress.capacity().toLong, UcpConstants.UCP_AM_SEND_FLAG_EAGER, - new UcxCallback() { - override def onSuccess(request: UcpRequest): Unit = { - header.clear() - workerAddress.clear() - } - }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) - flushRequests.add(ep.flushNonBlocking(null)) - ep + worker.synchronized { + val ep = worker.newEndpoint(endpointParams) + val header = Platform.allocateDirectBuffer(UnsafeUtils.LONG_SIZE) + header.putLong(id) + header.rewind() + val workerAddress = worker.getAddress + + ep.sendAmNonBlocking(1, UcxUtils.getAddress(header), UnsafeUtils.LONG_SIZE, + UcxUtils.getAddress(workerAddress), workerAddress.capacity().toLong, UcpConstants.UCP_AM_SEND_FLAG_EAGER, + new UcxCallback() { + override def onSuccess(request: UcpRequest): Unit = { + header.clear() + workerAddress.clear() + } + }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) + flushRequests.add(ep.flushNonBlocking(null)) + ep + } }) } @@ -366,6 +368,8 @@ class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with L @inline def close(): Unit = { + interrupt() + join(10) workerWrapper.close() } } \ No newline at end of file From 4b9554749cffef93bffa37e56ca3adaecf48debd Mon Sep 17 00:00:00 2001 From: zizhao Date: Mon, 17 Jul 2023 06:10:31 +0300 Subject: [PATCH 26/33] rm use code --- .../compat/spark_2_4/UcxShuffleClient.scala | 1 - .../spark/shuffle/ucx/ShuffleTransport.scala | 7 ------- .../spark/shuffle/ucx/UcxShuffleTransport.scala | 14 +------------- .../spark/shuffle/ucx/UcxWorkerWrapper.scala | 3 +-- .../spark/shuffle/ucx/perf/UcxPerfBenchmark.scala | 5 +---- 5 files changed, 3 insertions(+), 27 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala index 20f87378..cff68d1d 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala @@ -8,7 +8,6 @@ import org.apache.spark.shuffle.utils.UnsafeUtils import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId} class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient{ - override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String], listener: BlockFetchingListener, downloadFileManager: DownloadFileManager): Unit = { diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala index cf6dbc98..c4679eda 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala @@ -150,13 +150,6 @@ trait ShuffleTransport { */ def unregister(blockId: BlockId): Unit - /** - * Batch version of [[ fetchBlocksByBlockIds ]]. - */ - def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId], - resultBufferAllocator: BufferAllocator, - callbacks: Seq[OperationCallback]): Unit - /** * Progress outstanding operations. This routine is blocking (though may poll for event). * It's required to call this routine within same thread that submitted [[ fetchBlocksByBlockIds ]]. diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index a4e31f66..d2b44150 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -14,7 +14,6 @@ import org.openucx.jucx.UcxException import org.openucx.jucx.ucp._ import org.openucx.jucx.ucs.UcsConstants -import java.lang.ThreadLocal import java.net.InetSocketAddress import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicInteger @@ -85,7 +84,6 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo private var allocatedClientThreads: Array[UcxWorkerThread] = _ private var clientThreadId = new AtomicInteger() - // private var clientLocal = new ThreadLocal[UcxWorkerThread] = _ private var allocatedServerThreads: Array[UcxWorkerThread] = _ private val serverThreadId = new AtomicInteger() @@ -278,7 +276,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo */ override def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId], resultBufferAllocator: BufferAllocator, - callbacks: Seq[OperationCallback]): Unit = { + callbacks: Seq[OperationCallback]) = { val client = selectClientThread client.submit(new Runnable { override def run = client.workerWrapper.fetchBlocksByBlockIds( @@ -317,16 +315,6 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo }) } - // @inline - // def selectClientThread(): UcxWorkerThread = Option(clientLocal.get) match { - // case Some(client) => client - // case None => - // val client = allocatedClientThreads( - // (clientThreadId.incrementAndGet() % allocatedClientThreads.length).abs) - // clientLocal.set(client) - // client - // } - @inline def selectClientThread(): UcxWorkerThread = allocatedClientThreads( (clientThreadId.incrementAndGet() % allocatedClientThreads.length).abs) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index 3f6f810c..723cfe9e 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -62,7 +62,6 @@ class UcxRefCountMemoryBlock(baseBlock: MemoryBlock, offset: Long, size: Long, case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, isClientWorker: Boolean, id: Long = 0L) extends Closeable with Logging { - private val useWakeup = transport.ucxShuffleConf.useWakeup private final val connections = new TrieMap[transport.ExecutorId, UcpEndpoint] private val requestData = new TrieMap[Int, (Seq[OperationCallback], UcxRequest, transport.BufferAllocator)] @@ -205,7 +204,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i } } - connections.getOrElseUpdate(executorId, { + connections.getOrElseUpdate(executorId, { val address = transport.executorAddresses(executorId) val endpointParams = new UcpEndpointParams().setPeerErrorHandlingMode() .setSocketAddress(SerializationUtils.deserializeInetAddress(address)).sendClientId() diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala b/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala index d81bb47c..6e98b6ee 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala @@ -144,10 +144,7 @@ object UcxPerfBenchmark extends App with Logging { } } } - val requests = ucxTransport.fetchBlocksByBlockIds(1, blocks, resultBufferAllocator, callbacks) - // while (!requests.forall(_.isCompleted)) { - // ucxTransport.progress() - // } + ucxTransport.fetchBlocksByBlockIds(1, blocks, resultBufferAllocator, callbacks).get } } ucxTransport.close() From a608bb10bce550dfa2dcdab22b6bcd187393b613 Mon Sep 17 00:00:00 2001 From: zizhao Date: Mon, 17 Jul 2023 06:26:55 +0300 Subject: [PATCH 27/33] reset RR code --- .../apache/spark/shuffle/ucx/UcxShuffleTransport.scala | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index d2b44150..7a23ca58 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -16,7 +16,6 @@ import org.openucx.jucx.ucs.UcsConstants import java.net.InetSocketAddress import java.nio.ByteBuffer -import java.util.concurrent.atomic.AtomicInteger import scala.collection.concurrent.TrieMap import scala.collection.mutable @@ -83,10 +82,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo val executorAddresses = new TrieMap[ExecutorId, ByteBuffer] private var allocatedClientThreads: Array[UcxWorkerThread] = _ - private var clientThreadId = new AtomicInteger() - private var allocatedServerThreads: Array[UcxWorkerThread] = _ - private val serverThreadId = new AtomicInteger() private val registeredBlocks = new TrieMap[BlockId, Block] private var progressThread: Thread = _ @@ -274,7 +270,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo /** * Batch version of [[ fetchBlocksByBlockIds ]]. */ - override def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId], + def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId], resultBufferAllocator: BufferAllocator, callbacks: Seq[OperationCallback]) = { val client = selectClientThread @@ -317,11 +313,11 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo @inline def selectClientThread(): UcxWorkerThread = allocatedClientThreads( - (clientThreadId.incrementAndGet() % allocatedClientThreads.length).abs) + (Thread.currentThread().getId % allocatedClientThreads.length).toInt) @inline def selectServerThread(): UcxWorkerThread = allocatedServerThreads( - (serverThreadId.incrementAndGet() % allocatedServerThreads.length).abs) + (Thread.currentThread().getId % allocatedServerThreads.length).toInt) /** * Progress outstanding operations. This routine is blocking (though may poll for event). From c55a642621ff6bde57cdb2fd9456ff67bb94f307 Mon Sep 17 00:00:00 2001 From: zizhao Date: Mon, 17 Jul 2023 07:01:04 +0300 Subject: [PATCH 28/33] add latch for perf bench --- .../org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala | 2 +- .../apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index 7a23ca58..3f8a3df3 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -272,7 +272,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo */ def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId], resultBufferAllocator: BufferAllocator, - callbacks: Seq[OperationCallback]) = { + callbacks: Seq[OperationCallback]): Unit = { val client = selectClientThread client.submit(new Runnable { override def run = client.workerWrapper.fetchBlocksByBlockIds( diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala b/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala index 6e98b6ee..3d9be266 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala @@ -9,6 +9,7 @@ import java.net.InetSocketAddress import java.nio.ByteBuffer import java.nio.charset.StandardCharsets import java.nio.channels.FileChannel +import java.util.concurrent.CountDownLatch import java.util.concurrent.atomic.AtomicInteger import org.apache.commons.cli.{GnuParser, HelpFormatter, Options} import org.apache.spark.SparkConf @@ -127,6 +128,7 @@ object UcxPerfBenchmark extends App with Logging { } for (_ <- 0 until options.numIterations) { + val latch = new CountDownLatch(options.numOutstanding) for (b <- blockCollection) { requestInFlight.set(options.numOutstanding) for (o <- 0 until options.numOutstanding) { @@ -142,9 +144,11 @@ object UcxPerfBenchmark extends App with Logging { (options.blockSize * options.numOutstanding * options.numThreads) / (1024.0 * 1024.0 * (stats.getElapsedTimeNs / 1e9))) } + latch.countDown } } - ucxTransport.fetchBlocksByBlockIds(1, blocks, resultBufferAllocator, callbacks).get + ucxTransport.fetchBlocksByBlockIds(1, blocks, resultBufferAllocator, callbacks) + latch.await } } ucxTransport.close() From bb1718001a0a66105e64ab8a0ca873d02f8f8a0e Mon Sep 17 00:00:00 2001 From: Zihao Zhao Date: Thu, 30 May 2024 07:53:28 +0800 Subject: [PATCH 29/33] add thread pool + progress thread --- .../shuffle/ucx/UcxShuffleTransport.scala | 110 ++++++----- .../spark/shuffle/ucx/UcxWorkerWrapper.scala | 181 ++++++++---------- .../ucx/rpc/GlobalWorkerRpcThread.scala | 20 +- 3 files changed, 156 insertions(+), 155 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index 3f8a3df3..12cafc67 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -9,6 +9,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.shuffle.ucx.memory.UcxHostBounceBuffersPool import org.apache.spark.shuffle.ucx.rpc.GlobalWorkerRpcThread import org.apache.spark.shuffle.ucx.utils.{SerializableDirectBuffer, SerializationUtils} +import org.apache.spark.util.ThreadUtils import org.apache.spark.shuffle.utils.UnsafeUtils import org.openucx.jucx.UcxException import org.openucx.jucx.ucp._ @@ -16,6 +17,7 @@ import org.openucx.jucx.ucs.UcsConstants import java.net.InetSocketAddress import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicInteger import scala.collection.concurrent.TrieMap import scala.collection.mutable @@ -53,7 +55,7 @@ class UcxStats extends OperationStats { } case class UcxShuffleBockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { - override def serializedSize: Int = 12 + override def serializedSize: Int = UcxShuffleBockId.serializedSize override def serialize(byteBuffer: ByteBuffer): Unit = { byteBuffer.putInt(shuffleId) @@ -63,6 +65,8 @@ case class UcxShuffleBockId(shuffleId: Int, mapId: Int, reduceId: Int) extends B } object UcxShuffleBockId { + val serializedSize: Int = 12 + def deserialize(byteBuffer: ByteBuffer): UcxShuffleBockId = { val shuffleId = byteBuffer.getInt val mapId = byteBuffer.getInt @@ -81,13 +85,20 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo val endpoints = mutable.Set.empty[UcpEndpoint] val executorAddresses = new TrieMap[ExecutorId, ByteBuffer] - private var allocatedClientThreads: Array[UcxWorkerThread] = _ - private var allocatedServerThreads: Array[UcxWorkerThread] = _ + private var allocatedClientWorkers: Array[UcxWorkerWrapper] = _ + private var clientWorkerId = new AtomicInteger() + + private var allocatedServerWorkers: Array[UcxWorkerWrapper] = _ + private val serverWorkerId = new AtomicInteger() + private var serverLocal = new ThreadLocal[UcxWorkerWrapper] private val registeredBlocks = new TrieMap[BlockId, Block] private var progressThread: Thread = _ var hostBounceBufferMemoryPool: UcxHostBounceBuffersPool = _ + private[spark] lazy val replyThreadPool = ThreadUtils.newForkJoinPool( + "UcxListenerThread", ucxShuffleConf.numListenerThreads) + private val errorHandler = new UcpEndpointErrorHandler { override def onError(ucpEndpoint: UcpEndpoint, errorCode: Int, errorString: String): Unit = { if (errorCode == UcsConstants.STATUS.UCS_ERR_CONNECTION_RESET) { @@ -122,13 +133,11 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo globalWorker = ucxContext.newWorker(ucpWorkerParams) hostBounceBufferMemoryPool = new UcxHostBounceBuffersPool(ucxShuffleConf, ucxContext) - allocatedServerThreads = new Array[UcxWorkerThread](ucxShuffleConf.numListenerThreads) + allocatedServerWorkers = new Array[UcxWorkerWrapper](ucxShuffleConf.numListenerThreads) logInfo(s"Allocating ${ucxShuffleConf.numListenerThreads} server workers") for (i <- 0 until ucxShuffleConf.numListenerThreads) { val worker = ucxContext.newWorker(ucpWorkerParams) - val workerWrapper = UcxWorkerWrapper(worker, this, isClientWorker = false, i.toLong) - allocatedServerThreads(i) = new UcxWorkerThread(workerWrapper) - allocatedServerThreads(i).start() + allocatedServerWorkers(i) = UcxWorkerWrapper(worker, this, isClientWorker = false, i.toLong) } val Array(host, port) = ucxShuffleConf.listenerAddress.split(":") @@ -143,17 +152,17 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo progressThread = new GlobalWorkerRpcThread(globalWorker, this) progressThread.start() - allocatedClientThreads = new Array[UcxWorkerThread](ucxShuffleConf.numWorkers) + allocatedClientWorkers = new Array[UcxWorkerWrapper](ucxShuffleConf.numWorkers) logInfo(s"Allocating ${ucxShuffleConf.numWorkers} client workers") for (i <- 0 until ucxShuffleConf.numWorkers) { val clientId: Long = ((i.toLong + 1L) << 32) | executorId ucpWorkerParams.setClientId(clientId) val worker = ucxContext.newWorker(ucpWorkerParams) - val workerWrapper = UcxWorkerWrapper(worker, this, isClientWorker = true, clientId) - allocatedClientThreads(i) = new UcxWorkerThread(workerWrapper) - allocatedClientThreads(i).start() + allocatedClientWorkers(i) = UcxWorkerWrapper(worker, this, isClientWorker = true, clientId) } + allocatedServerWorkers.foreach(_.progressStart()) + allocatedClientWorkers.foreach(_.progressStart()) initialized = true logInfo(s"Started listener on ${listener.getAddress}") SerializationUtils.serializeInetAddress(listener.getAddress) @@ -169,7 +178,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo hostBounceBufferMemoryPool.close() - allocatedClientThreads.foreach(_.close) + allocatedClientWorkers.foreach(_.close()) if (listener != null) { listener.close() @@ -186,7 +195,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo globalWorker = null } - allocatedServerThreads.foreach(_.close) + allocatedServerWorkers.foreach(_.close()) if (ucxContext != null) { ucxContext.close() @@ -200,11 +209,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo * connection establishment outside of UcxShuffleManager. */ override def addExecutor(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = { - executorAddresses.put(executorId, workerAddress) - allocatedClientThreads.foreach(t => { - t.workerWrapper.getConnection(executorId) - t.workerWrapper.progressConnect() - }) + allocatedClientWorkers.foreach(_.getConnection(executorId)) } def addExecutors(executorIdsToAddress: Map[ExecutorId, SerializableDirectBuffer]): Unit = { @@ -214,7 +219,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo } def preConnect(): Unit = { - allocatedClientThreads.foreach(_.workerWrapper.preconnect()) + allocatedClientWorkers.foreach(_.preconnect()) } /** @@ -273,51 +278,39 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId], resultBufferAllocator: BufferAllocator, callbacks: Seq[OperationCallback]): Unit = { - val client = selectClientThread - client.submit(new Runnable { - override def run = client.workerWrapper.fetchBlocksByBlockIds( - executorId, blockIds, resultBufferAllocator, callbacks) - }) + selectClientWorker.fetchBlocksByBlockIds(executorId, blockIds, + resultBufferAllocator, callbacks) } def connectServerWorkers(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = { - executorAddresses.put(executorId, workerAddress) - allocatedServerThreads.foreach(t => { - t.workerWrapper.connectByWorkerAddress(executorId, workerAddress) - }) + allocatedServerWorkers.foreach( + _.connectByWorkerAddress(executorId, workerAddress)) } - def handleFetchBlockRequest(replyTag: Int, amData: UcpAmData, replyExecutor: Long): Unit = { - val server = selectServerThread - server.submit(new Runnable { - override def run = { - val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress, amData.getLength.toInt) - val blockIds = mutable.ArrayBuffer.empty[BlockId] - - // 1. Deserialize blockIds from header - while (buffer.remaining() > 0) { - val blockId = UcxShuffleBockId.deserialize(buffer) - if (!registeredBlocks.contains(blockId)) { - throw new UcxException(s"$blockId is not registered") - } - blockIds += blockId - } - + def handleFetchBlockRequest(replyTag: Int, blockIds: Seq[BlockId], + replyExecutor: Long): Unit = { + replyThreadPool.submit(new Runnable { + override def run(): Unit = { val blocks = blockIds.map(bid => registeredBlocks(bid)) - amData.close() - - server.workerWrapper.handleFetchBlockRequest(blocks, replyTag, replyExecutor) + selectServerWorker.handleFetchBlockRequest(blocks, replyTag, + replyExecutor) } }) } @inline - def selectClientThread(): UcxWorkerThread = allocatedClientThreads( - (Thread.currentThread().getId % allocatedClientThreads.length).toInt) + def selectClientWorker(): UcxWorkerWrapper = allocatedClientWorkers( + (clientWorkerId.incrementAndGet() % allocatedClientWorkers.length).abs) @inline - def selectServerThread(): UcxWorkerThread = allocatedServerThreads( - (Thread.currentThread().getId % allocatedServerThreads.length).toInt) + def selectServerWorker(): UcxWorkerWrapper = Option(serverLocal.get) match { + case Some(server) => server + case None => + val server = allocatedServerWorkers( + (serverWorkerId.incrementAndGet() % allocatedServerWorkers.length).abs) + serverLocal.set(server) + server + } /** * Progress outstanding operations. This routine is blocking (though may poll for event). @@ -329,3 +322,18 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo override def progress(): Unit = { } } + +private[ucx] class UcxSucceedOperationResult( + mem: MemoryBlock, stats: OperationStats) extends OperationResult { + override def getStatus: OperationStatus.Value = OperationStatus.SUCCESS + + override def getError: TransportError = null + + override def getStats: Option[OperationStats] = Option(stats) + + override def getData: MemoryBlock = mem +} + +private[ucx] class UcxFetchState(val callbacks: Seq[OperationCallback], + val request: UcxRequest, + val timestamp: Long) {} \ No newline at end of file diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index 723cfe9e..f945c5e7 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -62,15 +62,19 @@ class UcxRefCountMemoryBlock(baseBlock: MemoryBlock, offset: Long, size: Long, case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, isClientWorker: Boolean, id: Long = 0L) extends Closeable with Logging { - - private final val connections = new TrieMap[transport.ExecutorId, UcpEndpoint] - private val requestData = new TrieMap[Int, (Seq[OperationCallback], UcxRequest, transport.BufferAllocator)] + private[ucx] final val timeout = transport.ucxShuffleConf.getSparkConf.getTimeAsSeconds( + "spark.network.timeout", "120s") * 1000 + private[ucx] final val connections = new TrieMap[transport.ExecutorId, UcpEndpoint] + private[ucx] lazy val requestData = new TrieMap[Int, UcxFetchState] private val tag = new AtomicInteger(Random.nextInt()) - private val flushRequests = new ConcurrentLinkedQueue[UcpRequest]() - private val ioThreadPool = ThreadUtils.newForkJoinPool("IO threads", + private[ucx] lazy val ioThreadOn = transport.ucxShuffleConf.numIoThreads > 1 + private[ucx] lazy val ioThreadPool = ThreadUtils.newForkJoinPool("IO threads", transport.ucxShuffleConf.numIoThreads) - private val ioTaskSupport = new ForkJoinTaskSupport(ioThreadPool) + private[ucx] lazy val ioTaskSupport = new ForkJoinTaskSupport(ioThreadPool) + private[ucx] var progressThread: Thread = _ + + private[ucx] lazy val memPool = transport.hostBounceBufferMemoryPool if (isClientWorker) { // Receive block data handler @@ -84,7 +88,9 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i throw new UcxException(s"No data for tag $i.") } - val (callbacks, request, allocator) = data.get + val fetchState = data.get + val callbacks = fetchState.callbacks + val request = fetchState.request val stats = request.getStats.get.asInstanceOf[UcxStats] stats.receiveSize = ucpAmData.getLength @@ -116,14 +122,14 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i } if (callbacks.isEmpty) UcsConstants.STATUS.UCS_OK else UcsConstants.STATUS.UCS_INPROGRESS } else { - val mem = allocator(ucpAmData.getLength) + val mem = memPool.get(ucpAmData.getLength) stats.amHandleTime = System.nanoTime() request.setRequest(worker.recvAmDataNonBlocking(ucpAmData.getDataHandle, mem.address, ucpAmData.getLength, new UcxCallback() { override def onSuccess(r: UcpRequest): Unit = { request.completed = true stats.endTime = System.nanoTime() - logDebug(s"Received rndv data of size: ${mem.size} for tag $i in " + + logDebug(s"Received rndv data of size: ${ucpAmData.getLength} for tag $i in " + s"${stats.getElapsedTimeNs} ns " + s"time from amHandle: ${System.nanoTime() - stats.amHandleTime} ns") for (b <- 0 until numBlocks) { @@ -148,26 +154,35 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i } override def close(): Unit = { - val closeRequests = connections.map { - case (_, endpoint) => endpoint.closeNonBlockingForce() - } - while (!closeRequests.forall(_.isCompleted)) { - progress() + if (isClientWorker) { + val closeRequests = connections.map { + case (_, endpoint) => endpoint.closeNonBlockingForce() + } + while (!closeRequests.forall(_.isCompleted)) { + progress() + } } - ioThreadPool.shutdown() connections.clear() + if (progressThread != null) { + progressThread.interrupt() + progressThread.join(1) + } + if (ioThreadOn) { + ioThreadPool.shutdown() + } worker.close() } + def progressStart(): Unit = { + progressThread = new ProgressThread(s"UCX-progress-$id", worker, + transport.ucxShuffleConf.useWakeup) + progressThread.start() + } + /** * Blocking progress until there's outstanding flush requests. */ def progressConnect(): Unit = { - while (!flushRequests.isEmpty) { - progress() - flushRequests.removeIf(_.isCompleted) - } - logTrace(s"Flush completed. Number of connections: ${connections.keys.size}") } /** @@ -196,15 +211,18 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i def getConnection(executorId: transport.ExecutorId): UcpEndpoint = { - val startTime = System.currentTimeMillis() - while (!transport.executorAddresses.contains(executorId)) { - if (System.currentTimeMillis() - startTime > - transport.ucxShuffleConf.getSparkConf.getTimeAsMs("spark.network.timeout", "100")) { - throw new UcxException(s"Don't get a worker address for $executorId") + if (!connections.contains(executorId)) { + if (!transport.executorAddresses.contains(executorId)) { + val startTime = System.currentTimeMillis() + while (!transport.executorAddresses.contains(executorId)) { + if (System.currentTimeMillis() - startTime > timeout) { + throw new UcxException(s"Don't get a worker address for $executorId") + } + } } } - connections.getOrElseUpdate(executorId, { + connections.getOrElseUpdate(executorId, { val address = transport.executorAddresses(executorId) val endpointParams = new UcpEndpointParams().setPeerErrorHandlingMode() .setSocketAddress(SerializationUtils.deserializeInetAddress(address)).sendClientId() @@ -215,8 +233,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i } }).setName(s"Endpoint to $executorId") - logDebug(s"Worker $this connecting to Executor($executorId, " + - s"${SerializationUtils.deserializeInetAddress(address)}") + logDebug(s"Worker ${id.toInt}:${id>>32} connecting to Executor($executorId)") worker.synchronized { val ep = worker.newEndpoint(endpointParams) val header = Platform.allocateDirectBuffer(UnsafeUtils.LONG_SIZE) @@ -231,8 +248,10 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i header.clear() workerAddress.clear() } + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + logError(s"Failed to send $errorMsg") + } }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) - flushRequests.add(ep.flushNonBlocking(null)) ep } }) @@ -243,7 +262,6 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i callbacks: Seq[OperationCallback]): Unit = { val startTime = System.nanoTime() val headerSize = UnsafeUtils.INT_SIZE + UnsafeUtils.LONG_SIZE - val ep = getConnection(executorId) val t = tag.incrementAndGet() @@ -253,29 +271,31 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i blockIds.foreach(b => b.serialize(buffer)) val request = new UcxRequest(null, new UcxStats()) - requestData.put(t, (callbacks, request, resultBufferAllocator)) + requestData.put(t, new UcxFetchState(callbacks, request, startTime)) buffer.rewind() val address = UnsafeUtils.getAdress(buffer) val dataAddress = address + headerSize - ep.sendAmNonBlocking(0, address, - headerSize, dataAddress, buffer.capacity() - headerSize, - UcpConstants.UCP_AM_SEND_FLAG_EAGER, new UcxCallback() { - override def onSuccess(request: UcpRequest): Unit = { - buffer.clear() - logDebug(s"Sent message on $ep to $executorId to fetch ${blockIds.length} blocks on tag $t id $id" + - s"in ${System.nanoTime() - startTime} ns") - } - }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) + val ep = getConnection(executorId) + worker.synchronized { + ep.sendAmNonBlocking(0, address, + headerSize, dataAddress, buffer.capacity() - headerSize, + UcpConstants.UCP_AM_SEND_FLAG_EAGER, new UcxCallback() { + override def onSuccess(request: UcpRequest): Unit = { + buffer.clear() + logDebug(s"Sent message on $ep to $executorId to fetch ${blockIds.length} blocks on tag $t id $id" + + s"in ${System.nanoTime() - startTime} ns") + } + }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) + } } def handleFetchBlockRequest(blocks: Seq[Block], replyTag: Int, replyExecutor: Long): Unit = try { val tagAndSizes = UnsafeUtils.INT_SIZE + UnsafeUtils.INT_SIZE * blocks.length - val resultMemory = transport.hostBounceBufferMemoryPool.get(tagAndSizes + blocks.map(_.getSize).sum) - .asInstanceOf[UcxBounceBufferMemoryBlock] - val resultBuffer = UcxUtils.getByteBufferView(resultMemory.address, - resultMemory.size) + val msgSize = tagAndSizes + blocks.map(_.getSize).sum + val resultMemory = memPool.get(msgSize).asInstanceOf[UcxBounceBufferMemoryBlock] + val resultBuffer = UcxUtils.getByteBufferView(resultMemory.address, msgSize) resultBuffer.putInt(replyTag) var offset = 0 @@ -289,7 +309,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i localBuffer } // Do parallel read of blocks - val blocksCollection = if (transport.ucxShuffleConf.numIoThreads > 1) { + val blocksCollection = if (ioThreadOn) { val parCollection = blocks.indices.par parCollection.tasksupport = ioTaskSupport parCollection @@ -302,73 +322,42 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i } val startTime = System.nanoTime() - getConnection(replyExecutor).sendAmNonBlocking(1, resultMemory.address, tagAndSizes, - resultMemory.address + tagAndSizes, resultMemory.size - tagAndSizes, 0, new UcxCallback { - override def onSuccess(request: UcpRequest): Unit = { - logTrace(s"Sent ${blocks.length} blocks of size: ${resultMemory.size} " + - s"to tag $replyTag in ${System.nanoTime() - startTime} ns.") - transport.hostBounceBufferMemoryPool.put(resultMemory) - } + val ep = getConnection(replyExecutor) + worker.synchronized { + ep.sendAmNonBlocking(1, resultMemory.address, tagAndSizes, + resultMemory.address + tagAndSizes, msgSize - tagAndSizes, 0, new UcxCallback { + override def onSuccess(request: UcpRequest): Unit = { + logTrace(s"Sent ${blocks.length} blocks of size: ${msgSize} " + + s"to tag $replyTag in ${System.nanoTime() - startTime} ns.") + resultMemory.close() + } - override def onError(ucsStatus: Int, errorMsg: String): Unit = { - logError(s"Failed to send $errorMsg") - } - }, new UcpRequestParams().setMemoryType(UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) - .setMemoryHandle(resultMemory.memory)) + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + logError(s"Failed to send $errorMsg") + resultMemory.close() + } + }, new UcpRequestParams().setMemoryType(UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) + .setMemoryHandle(resultMemory.memory)) + } } catch { case ex: Throwable => logError(s"Failed to read and send data: $ex") } } -class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with Logging { - val id = workerWrapper.id - val worker = workerWrapper.worker - val transport = workerWrapper.transport - val useWakeup = workerWrapper.transport.ucxShuffleConf.useWakeup - - private val taskQueue = new ConcurrentLinkedQueue[Runnable]() - +private[ucx] class ProgressThread( + name: String, worker: UcpWorker, useWakeup: Boolean) extends Thread { setDaemon(true) - setName(s"UCX-worker $id") + setName(name) override def run(): Unit = { - logDebug(s"UCX-worker $id started") while (!isInterrupted) { - Option(taskQueue.poll()) match { - case Some(task) => task.run - case None => {} - } worker.synchronized { while (worker.progress() != 0) {} } - if(taskQueue.isEmpty && useWakeup) { + if (useWakeup) { worker.waitForEvents() } } - logDebug(s"UCX-worker $id stopped") - } - - @inline - def submit(task: Callable[_]): Future[_] = { - val future = new FutureTask(task) - taskQueue.offer(future) - worker.signal() - future - } - - @inline - def submit(task: Runnable): Future[Unit.type] = { - val future = new FutureTask(task, Unit) - taskQueue.offer(future) - worker.signal() - future - } - - @inline - def close(): Unit = { - interrupt() - join(10) - workerWrapper.close() } } \ No newline at end of file diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala index d434542d..d8e42404 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala @@ -8,7 +8,7 @@ import java.nio.ByteBuffer import org.openucx.jucx.ucp.{UcpAmData, UcpConstants, UcpEndpoint, UcpWorker} import org.openucx.jucx.ucs.UcsConstants import org.apache.spark.internal.Logging -import org.apache.spark.shuffle.ucx.UcxShuffleTransport +import org.apache.spark.shuffle.ucx.{UcxShuffleTransport, UcxShuffleBockId} import org.apache.spark.shuffle.utils.UnsafeUtils import org.apache.spark.util.ThreadUtils @@ -22,9 +22,14 @@ class GlobalWorkerRpcThread(globalWorker: UcpWorker, transport: UcxShuffleTransp val header = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt) val replyTag = header.getInt val replyExecutor = header.getLong - transport.handleFetchBlockRequest(replyTag, amData, replyExecutor) - UcsConstants.STATUS.UCS_INPROGRESS - }, UcpConstants.UCP_AM_FLAG_PERSISTENT_DATA | UcpConstants.UCP_AM_FLAG_WHOLE_MSG ) + val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress, + amData.getLength.toInt) + val blockNum = buffer.remaining() / UcxShuffleBockId.serializedSize + val blockIds = (0 until blockNum).map( + _ => UcxShuffleBockId.deserialize(buffer)) + transport.handleFetchBlockRequest(replyTag, blockIds, replyExecutor) + UcsConstants.STATUS.UCS_OK + }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG ) // AM to get worker address for client worker and connect server workers to it @@ -40,13 +45,12 @@ class GlobalWorkerRpcThread(globalWorker: UcpWorker, transport: UcxShuffleTransp override def run(): Unit = { if (transport.ucxShuffleConf.useWakeup) { while (!isInterrupted) { - if (globalWorker.progress() == 0) { - globalWorker.waitForEvents() - } + while (globalWorker.progress != 0) {} + globalWorker.waitForEvents() } } else { while (!isInterrupted) { - globalWorker.progress() + while (globalWorker.progress != 0) {} } } } From 84189059a45431d4812bdbbb267a093d0cfd7248 Mon Sep 17 00:00:00 2001 From: zizhao Date: Mon, 3 Jun 2024 06:06:26 +0300 Subject: [PATCH 30/33] shuffle: simplify --- .../spark/shuffle/ucx/ShuffleTransport.scala | 4 ++++ .../spark/shuffle/ucx/UcxShuffleTransport.scala | 17 +---------------- 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala index c4679eda..30bb1056 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala @@ -160,3 +160,7 @@ trait ShuffleTransport { def progress(): Unit } + +class UcxFetchState(val callbacks: Seq[OperationCallback], + val request: UcxRequest, + val timestamp: Long) {} \ No newline at end of file diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index 12cafc67..e48add01 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -321,19 +321,4 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo */ override def progress(): Unit = { } -} - -private[ucx] class UcxSucceedOperationResult( - mem: MemoryBlock, stats: OperationStats) extends OperationResult { - override def getStatus: OperationStatus.Value = OperationStatus.SUCCESS - - override def getError: TransportError = null - - override def getStats: Option[OperationStats] = Option(stats) - - override def getData: MemoryBlock = mem -} - -private[ucx] class UcxFetchState(val callbacks: Seq[OperationCallback], - val request: UcxRequest, - val timestamp: Long) {} \ No newline at end of file +} \ No newline at end of file From 9a931e9f29c4a6200d9e9233a22c56880843f6b8 Mon Sep 17 00:00:00 2001 From: zizhao Date: Mon, 3 Jun 2024 06:56:59 +0300 Subject: [PATCH 31/33] use awaitUcxTransport instead of sleep in UcxLocalDiskShuffleExecutorComponents --- .../spark_3_0/UcxLocalDiskShuffleExecutorComponents.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleExecutorComponents.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleExecutorComponents.scala index 931d17b6..ece175a6 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleExecutorComponents.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleExecutorComponents.scala @@ -23,9 +23,7 @@ class UcxLocalDiskShuffleExecutorComponents(sparkConf: SparkConf) override def initializeExecutor(appId: String, execId: String, extraConfigs: util.Map[String, String]): Unit = { val ucxShuffleManager = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager] - while (ucxShuffleManager.ucxTransport == null) { - Thread.sleep(5) - } + ucxShuffleManager.awaitUcxTransport() blockResolver = ucxShuffleManager.shuffleBlockResolver } From 2d1c464c4f02528f66592a2e6886f311182b35c9 Mon Sep 17 00:00:00 2001 From: zizhao Date: Mon, 3 Jun 2024 07:29:30 +0300 Subject: [PATCH 32/33] simplify code Signed-off-by: zihao zhao --- .../spark/shuffle/ucx/UcxShuffleTransport.scala | 1 + .../spark/shuffle/ucx/UcxWorkerWrapper.scala | 14 ++++++-------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index 5ca7ecd3..eccd281e 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -205,6 +205,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo * connection establishment outside of UcxShuffleManager. */ override def addExecutor(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = { + executorAddresses.put(executorId, workerAddress) allocatedClientWorkers.foreach(_.getConnection(executorId)) } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index dee7899c..aedd06f8 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -5,7 +5,6 @@ package org.apache.spark.shuffle.ucx import java.io.Closeable -import java.util.concurrent.{ConcurrentLinkedQueue, Callable, Future, FutureTask} import java.util.concurrent.atomic.AtomicInteger import scala.collection.concurrent.TrieMap import scala.util.Random @@ -211,13 +210,12 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i def getConnection(executorId: transport.ExecutorId): UcpEndpoint = { - if (!connections.contains(executorId)) { - if (!transport.executorAddresses.contains(executorId)) { - val startTime = System.currentTimeMillis() - while (!transport.executorAddresses.contains(executorId)) { - if (System.currentTimeMillis() - startTime > timeout) { - throw new UcxException(s"Don't get a worker address for $executorId") - } + if ((!connections.contains(executorId)) && + (!transport.executorAddresses.contains(executorId))) { + val startTime = System.currentTimeMillis() + while (!transport.executorAddresses.contains(executorId)) { + if (System.currentTimeMillis() - startTime > timeout) { + throw new UcxException(s"Don't get a worker address for $executorId") } } } From 7438e8fea79998583f1280b443d51c8cf9078a83 Mon Sep 17 00:00:00 2001 From: zizhao Date: Mon, 17 Jul 2023 07:06:08 +0300 Subject: [PATCH 33/33] select worker round robin Signed-off-by: zihao zhao --- .../spark/shuffle/ucx/UcxShuffleTransport.scala | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index eccd281e..87c8c863 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -86,7 +86,11 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo val executorAddresses = new TrieMap[ExecutorId, ByteBuffer] private var allocatedClientWorkers: Array[UcxWorkerWrapper] = _ + private val clientWorkerId = new AtomicInteger() + private var allocatedServerWorkers: Array[UcxWorkerWrapper] = _ + private val serverWorkerId = new AtomicInteger() + private val serverLocal = new ThreadLocal[UcxWorkerWrapper] private val registeredBlocks = new TrieMap[BlockId, Block] private var progressThread: Thread = _ @@ -296,11 +300,17 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo @inline def selectClientWorker(): UcxWorkerWrapper = allocatedClientWorkers( - (Thread.currentThread().getId % allocatedClientWorkers.length).toInt) + (clientWorkerId.incrementAndGet() % allocatedClientWorkers.length).abs) @inline - def selectServerWorker(): UcxWorkerWrapper = allocatedServerWorkers( - (Thread.currentThread().getId % allocatedServerWorkers.length).toInt) + def selectServerWorker(): UcxWorkerWrapper = Option(serverLocal.get) match { + case Some(server) => server + case None => + val server = allocatedServerWorkers( + (serverWorkerId.incrementAndGet() % allocatedServerWorkers.length).abs) + serverLocal.set(server) + server + } /** * Progress outstanding operations. This routine is blocking (though may poll for event).