diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala
index a19e07a5..6ff24cf6 100755
--- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala
@@ -50,36 +50,8 @@ class UcxShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C],
       SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
       SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
 
-    // Ucx shuffle logic
-    // Java reflection to get access to private results queue
-    val queueField = wrappedStreams.getClass.getDeclaredField(
-      "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results")
-    queueField.setAccessible(true)
-    val resultQueue = queueField.get(wrappedStreams).asInstanceOf[LinkedBlockingQueue[_]]
-
-    // Do progress if queue is empty before calling next on ShuffleIterator
-    val ucxWrappedStream = new Iterator[(BlockId, InputStream)] {
-      override def next(): (BlockId, InputStream) = {
-        val startTime = System.currentTimeMillis()
-        while (resultQueue.isEmpty) {
-          transport.progress()
-        }
-        shuffleMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime)
-        wrappedStreams.next()
-      }
-
-      override def hasNext: Boolean = {
-        val result = wrappedStreams.hasNext
-        if (!result) {
-          shuffleClient.close()
-        }
-        result
-      }
-    }
-    // End of ucx shuffle logic
-
     val serializerInstance = dep.serializer.newInstance()
-    val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) =>
+    val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
       // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
       // NextIterator. The NextIterator makes sure that close() is called on the
       // underlying InputStream when all records have been read.
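Context for the hunk above: the removed wrapper iterator existed only to drive transport.progress() from the reader thread while the fetcher's private results queue was empty. After this patch, each UcpWorker is driven by its own daemon progress thread (the ProgressThread added to UcxWorkerWrapper.scala further down), so the reader can consume wrappedStreams directly. A minimal sketch of that pattern, using only the jucx UcpWorker calls that appear elsewhere in this patch; the class name here is illustrative:

    import org.openucx.jucx.ucp.UcpWorker

    // Sketch: a dedicated daemon thread keeps the worker progressing so that
    // callers never have to invoke transport.progress() themselves.
    class ProgressLoop(worker: UcpWorker, useWakeup: Boolean) extends Thread {
      setDaemon(true)
      override def run(): Unit = {
        while (!isInterrupted) {
          // Synchronize with threads that post sends on the same worker.
          worker.synchronized {
            while (worker.progress() != 0) {} // drain all ready completions
          }
          if (useWakeup) {
            worker.waitForEvents() // block until the next event instead of spinning
          }
        }
      }
    }
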
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleExecutorComponents.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleExecutorComponents.scala
index 931d17b6..ece175a6 100755
--- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleExecutorComponents.scala
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleExecutorComponents.scala
@@ -23,9 +23,7 @@ class UcxLocalDiskShuffleExecutorComponents(sparkConf: SparkConf)
   override def initializeExecutor(appId: String, execId: String, extraConfigs: util.Map[String, String]): Unit = {
     val ucxShuffleManager = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager]
-    while (ucxShuffleManager.ucxTransport == null) {
-      Thread.sleep(5)
-    }
+    ucxShuffleManager.awaitUcxTransport()
     blockResolver = ucxShuffleManager.shuffleBlockResolver
   }
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala
index 50b6cfd5..d2d6d50e 100755
--- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala
@@ -41,7 +41,6 @@ class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Ma
     }
     val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size)
     transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds, resultBufferAllocator, callbacks)
-    transport.progress()
   }
 
   override def close(): Unit = {
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala
index 37c68efb..2b6e84a4 100755
--- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala
@@ -104,40 +104,10 @@ private[spark] class UcxShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C],
     val wrappedStreams = shuffleIterator.toCompletionIterator
-
-    // Ucx shuffle logic
-    // Java reflection to get access to private results queue
-    val queueField = shuffleIterator.getClass.getDeclaredField(
-      "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results")
-    queueField.setAccessible(true)
-    val resultQueue = queueField.get(shuffleIterator).asInstanceOf[LinkedBlockingQueue[_]]
-
-    // Do progress if queue is empty before calling next on ShuffleIterator
-    val ucxWrappedStream = new Iterator[(BlockId, InputStream)] {
-      override def next(): (BlockId, InputStream) = {
-        val startTime = System.nanoTime()
-        while (resultQueue.isEmpty) {
-          transport.progress()
-        }
-        val fetchWaitTime = System.nanoTime() - startTime
-        readMetrics.incFetchWaitTime(TimeUnit.NANOSECONDS.toMillis(fetchWaitTime))
-        wrappedStreams.next()
-      }
-
-      override def hasNext: Boolean = {
-        val result = wrappedStreams.hasNext
-        if (!result) {
-          shuffleClient.close()
-        }
-        result
-      }
-    }
-    // End of ucx shuffle logic
-
     val serializerInstance = dep.serializer.newInstance()
 
     // Create a key/value iterator for each stream
-    val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) =>
+    val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
       // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
       // NextIterator. The NextIterator makes sure that close() is called on the
       // underlying InputStream when all records have been read.
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala
index 314b88a3..30bb1056 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala
@@ -150,13 +150,6 @@ trait ShuffleTransport {
    */
   def unregister(blockId: BlockId): Unit
 
-  /**
-   * Batch version of [[ fetchBlocksByBlockIds ]].
-   */
-  def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId],
-                            resultBufferAllocator: BufferAllocator,
-                            callbacks: Seq[OperationCallback]): Seq[Request]
-
   /**
    * Progress outstanding operations. This routine is blocking (though may poll for event).
    * It's required to call this routine within same thread that submitted [[ fetchBlocksByBlockIds ]].
@@ -167,3 +160,7 @@ trait ShuffleTransport {
   def progress(): Unit
 
 }
+
+class UcxFetchState(val callbacks: Seq[OperationCallback],
+                    val request: UcxRequest,
+                    val timestamp: Long) {}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala
index ae8bb119..eccd281e 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala
@@ -9,6 +9,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.shuffle.ucx.memory.UcxHostBounceBuffersPool
 import org.apache.spark.shuffle.ucx.rpc.GlobalWorkerRpcThread
 import org.apache.spark.shuffle.ucx.utils.{SerializableDirectBuffer, SerializationUtils}
+import org.apache.spark.util.ThreadUtils
 import org.apache.spark.shuffle.utils.UnsafeUtils
 import org.openucx.jucx.UcxException
 import org.openucx.jucx.ucp._
@@ -16,6 +17,7 @@ import org.openucx.jucx.ucs.UcsConstants
 
 import java.net.InetSocketAddress
 import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicInteger
 import scala.collection.concurrent.TrieMap
 import scala.collection.mutable
 
@@ -53,7 +55,7 @@ class UcxStats extends OperationStats {
 }
 
 case class UcxShuffleBockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
-  override def serializedSize: Int = 12
+  override def serializedSize: Int = UcxShuffleBockId.serializedSize
 
   override def serialize(byteBuffer: ByteBuffer): Unit = {
     byteBuffer.putInt(shuffleId)
@@ -63,6 +65,8 @@ case class UcxShuffleBockId(shuffleId: Int, mapId: Int, reduceId: Int) extends B
 }
 
 object UcxShuffleBockId {
+  val serializedSize: Int = 12
+
   def deserialize(byteBuffer: ByteBuffer): UcxShuffleBockId = {
     val shuffleId = byteBuffer.getInt
     val mapId = byteBuffer.getInt
@@ -88,6 +92,9 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
   private var progressThread: Thread = _
   var hostBounceBufferMemoryPool: UcxHostBounceBuffersPool = _
 
+  private[spark] lazy val replyThreadPool = ThreadUtils.newForkJoinPool(
+    "UcxListenerThread", ucxShuffleConf.numListenerThreads)
+
   private val errorHandler = new UcpEndpointErrorHandler {
     override def onError(ucpEndpoint: UcpEndpoint, errorCode: Int, errorString: String): Unit = {
       if (errorCode == UcsConstants.STATUS.UCS_ERR_CONNECTION_RESET) {
@@ -126,7 +133,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
     logInfo(s"Allocating ${ucxShuffleConf.numListenerThreads} server workers")
     for (i <- 0 until ucxShuffleConf.numListenerThreads) {
       val worker = ucxContext.newWorker(ucpWorkerParams)
-      allocatedServerWorkers(i) = UcxWorkerWrapper(worker, this, isClientWorker = false)
+      allocatedServerWorkers(i) = UcxWorkerWrapper(worker, this, isClientWorker = false, i.toLong)
     }
 
     val Array(host, port) = ucxShuffleConf.listenerAddress.split(":")
@@ -150,6 +157,8 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
       allocatedClientWorkers(i) = UcxWorkerWrapper(worker, this, isClientWorker = true, clientId)
     }
 
+    allocatedServerWorkers.foreach(_.progressStart())
+    allocatedClientWorkers.foreach(_.progressStart())
     initialized = true
     logInfo(s"Started listener on ${listener.getAddress}")
     SerializationUtils.serializeInetAddress(listener.getAddress)
@@ -166,7 +175,6 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
       hostBounceBufferMemoryPool.close()
 
       allocatedClientWorkers.foreach(_.close())
-      allocatedServerWorkers.foreach(_.close())
 
       if (listener != null) {
         listener.close()
@@ -183,6 +191,8 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
         globalWorker = null
       }
 
+      allocatedServerWorkers.foreach(_.close())
+
       if (ucxContext != null) {
         ucxContext.close()
         ucxContext = null
@@ -196,10 +206,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
    */
   override def addExecutor(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = {
     executorAddresses.put(executorId, workerAddress)
-    allocatedClientWorkers.foreach(w => {
-      w.getConnection(executorId)
-      w.progressConnect()
-    })
+    allocatedClientWorkers.foreach(_.getConnection(executorId))
   }
 
   def addExecutors(executorIdsToAddress: Map[ExecutorId, SerializableDirectBuffer]): Unit = {
@@ -265,36 +272,35 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
   /**
   * Batch version of [[ fetchBlocksByBlockIds ]].
   */
-  override def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId],
+  def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId],
                             resultBufferAllocator: BufferAllocator,
-                            callbacks: Seq[OperationCallback]): Seq[Request] = {
-    allocatedClientWorkers((Thread.currentThread().getId % allocatedClientWorkers.length).toInt)
-      .fetchBlocksByBlockIds(executorId, blockIds, resultBufferAllocator, callbacks)
+                            callbacks: Seq[OperationCallback]): Unit = {
+    selectClientWorker.fetchBlocksByBlockIds(executorId, blockIds,
+      resultBufferAllocator, callbacks)
   }
 
   def connectServerWorkers(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = {
     allocatedServerWorkers.foreach(w => w.connectByWorkerAddress(executorId, workerAddress))
   }
 
-  def handleFetchBlockRequest(replyTag: Int, amData: UcpAmData, replyExecutor: Long): Unit = {
-    val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress, amData.getLength.toInt)
-    val blockIds = mutable.ArrayBuffer.empty[BlockId]
-
-    // 1. Deserialize blockIds from header
-    while (buffer.remaining() > 0) {
-      val blockId = UcxShuffleBockId.deserialize(buffer)
-      if (!registeredBlocks.contains(blockId)) {
-        throw new UcxException(s"$blockId is not registered")
+  def handleFetchBlockRequest(replyTag: Int, blockIds: Seq[BlockId],
+                              replyExecutor: Long): Unit = {
+    replyThreadPool.submit(new Runnable {
+      override def run(): Unit = {
+        val blocks = blockIds.map(bid => registeredBlocks(bid))
+        selectServerWorker.handleFetchBlockRequest(blocks, replyTag,
+          replyExecutor)
       }
-      blockIds += blockId
-    }
-
-    val blocks = blockIds.map(bid => registeredBlocks(bid))
-    amData.close()
-    allocatedServerWorkers((Thread.currentThread().getId % allocatedServerWorkers.length).toInt)
-      .handleFetchBlockRequest(blocks, replyTag, replyExecutor)
+    })
   }
 
+  @inline
+  def selectClientWorker(): UcxWorkerWrapper = allocatedClientWorkers(
+    (Thread.currentThread().getId % allocatedClientWorkers.length).toInt)
+
+  @inline
+  def selectServerWorker(): UcxWorkerWrapper = allocatedServerWorkers(
+    (Thread.currentThread().getId % allocatedServerWorkers.length).toInt)
 
   /**
    * Progress outstanding operations. This routine is blocking (though may poll for event).
@@ -304,10 +310,5 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
    * But not guaranteed that at least one [[ fetchBlocksByBlockIds ]] completed!
    */
   override def progress(): Unit = {
-    allocatedClientWorkers((Thread.currentThread().getId % allocatedClientWorkers.length).toInt).progress()
   }
-
-  def progressConnect(): Unit = {
-    allocatedClientWorkers.par.foreach(_.progressConnect())
-  }
-}
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala
index 66cefd76..aedd06f8 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala
@@ -5,7 +5,6 @@ package org.apache.spark.shuffle.ucx
 
 import java.io.Closeable
-import java.util.concurrent.ConcurrentLinkedQueue
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.concurrent.TrieMap
 import scala.util.Random
@@ -62,15 +61,19 @@ class UcxRefCountMemoryBlock(baseBlock: MemoryBlock, offset: Long, size: Long,
 case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, isClientWorker: Boolean,
                             id: Long = 0L)
   extends Closeable with Logging {
-
-  private final val connections = new TrieMap[transport.ExecutorId, UcpEndpoint]
-  private val requestData = new TrieMap[Int, (Seq[OperationCallback], UcxRequest, transport.BufferAllocator)]
+  private[ucx] final val timeout = transport.ucxShuffleConf.getSparkConf.getTimeAsSeconds(
+    "spark.network.timeout", "120s") * 1000
+  private[ucx] final val connections = new TrieMap[transport.ExecutorId, UcpEndpoint]
+  private[ucx] lazy val requestData = new TrieMap[Int, UcxFetchState]
   private val tag = new AtomicInteger(Random.nextInt())
-  private val flushRequests = new ConcurrentLinkedQueue[UcpRequest]()
 
-  private val ioThreadPool = ThreadUtils.newForkJoinPool("IO threads",
+  private[ucx] lazy val ioThreadOn = transport.ucxShuffleConf.numIoThreads > 1
+  private[ucx] lazy val ioThreadPool = ThreadUtils.newForkJoinPool("IO threads",
     transport.ucxShuffleConf.numIoThreads)
-  private val ioTaskSupport = new ForkJoinTaskSupport(ioThreadPool)
+  private[ucx] lazy val ioTaskSupport = new ForkJoinTaskSupport(ioThreadPool)
+  private[ucx] var progressThread: Thread = _
+
+  private[ucx] lazy val memPool = transport.hostBounceBufferMemoryPool
 
   if (isClientWorker) {
     // Receive block data handler
@@ -84,7 +87,9 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i
         throw new UcxException(s"No data for tag $i.")
       }
 
-      val (callbacks, request, allocator) = data.get
+      val fetchState = data.get
+      val callbacks = fetchState.callbacks
+      val request = fetchState.request
       val stats = request.getStats.get.asInstanceOf[UcxStats]
       stats.receiveSize = ucpAmData.getLength
@@ -116,14 +121,14 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i
       }
       if (callbacks.isEmpty) UcsConstants.STATUS.UCS_OK else UcsConstants.STATUS.UCS_INPROGRESS
     } else {
-      val mem = allocator(ucpAmData.getLength)
+      val mem = memPool.get(ucpAmData.getLength)
       stats.amHandleTime = System.nanoTime()
       request.setRequest(worker.recvAmDataNonBlocking(ucpAmData.getDataHandle, mem.address, ucpAmData.getLength,
         new UcxCallback() {
           override def onSuccess(r: UcpRequest): Unit = {
             request.completed = true
             stats.endTime = System.nanoTime()
-            logDebug(s"Received rndv data of size: ${mem.size} for tag $i in " +
+            logDebug(s"Received rndv data of size: ${ucpAmData.getLength} for tag $i in " +
               s"${stats.getElapsedTimeNs} ns " +
               s"time from amHandle: ${System.nanoTime() - stats.amHandleTime} ns")
             for (b <- 0 until numBlocks) {
@@ -148,26 +153,35 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i
   }
 
   override def close(): Unit = {
-    val closeRequests = connections.map {
-      case (_, endpoint) => endpoint.closeNonBlockingForce()
-    }
-    while (!closeRequests.forall(_.isCompleted)) {
-      progress()
+    if (isClientWorker) {
+      val closeRequests = connections.map {
+        case (_, endpoint) => endpoint.closeNonBlockingForce()
+      }
+      while (!closeRequests.forall(_.isCompleted)) {
+        progress()
+      }
     }
-    ioThreadPool.shutdown()
     connections.clear()
+    if (progressThread != null) {
+      progressThread.interrupt()
+      progressThread.join(1)
+    }
+    if (ioThreadOn) {
+      ioThreadPool.shutdown()
+    }
     worker.close()
   }
 
+  def progressStart(): Unit = {
+    progressThread = new ProgressThread(s"UCX-progress-$id", worker,
+      transport.ucxShuffleConf.useWakeup)
+    progressThread.start()
+  }
+
   /**
    * Blocking progress until there's outstanding flush requests.
    */
   def progressConnect(): Unit = {
-    while (!flushRequests.isEmpty) {
-      progress()
-      flushRequests.removeIf(_.isCompleted)
-    }
-    logTrace(s"Flush completed. Number of connections: ${connections.keys.size}")
   }
 
   /**
@@ -187,22 +201,26 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i
 
   def connectByWorkerAddress(executorId: transport.ExecutorId, workerAddress: ByteBuffer): Unit = {
     logDebug(s"Worker $this connecting back to $executorId by worker address")
-    val ep = worker.newEndpoint(new UcpEndpointParams().setName(s"Server connection to $executorId")
-      .setUcpAddress(workerAddress))
+    val ep = worker.synchronized {
+      worker.newEndpoint(new UcpEndpointParams().setName(s"Server connection to $executorId")
+        .setUcpAddress(workerAddress))
+    }
     connections.put(executorId, ep)
   }
 
   def getConnection(executorId: transport.ExecutorId): UcpEndpoint = {
-    val startTime = System.currentTimeMillis()
-    while (!transport.executorAddresses.contains(executorId)) {
-      if (System.currentTimeMillis() - startTime >
-        transport.ucxShuffleConf.getSparkConf.getTimeAsMs("spark.network.timeout", "100")) {
-        throw new UcxException(s"Don't get a worker address for $executorId")
+    if ((!connections.contains(executorId)) &&
+      (!transport.executorAddresses.contains(executorId))) {
+      val startTime = System.currentTimeMillis()
+      while (!transport.executorAddresses.contains(executorId)) {
+        if (System.currentTimeMillis() - startTime > timeout) {
+          throw new UcxException(s"Don't get a worker address for $executorId")
+        }
       }
     }
 
-    connections.getOrElseUpdate(executorId, {
+    connections.getOrElseUpdate(executorId, {
       val address = transport.executorAddresses(executorId)
       val endpointParams = new UcpEndpointParams().setPeerErrorHandlingMode()
         .setSocketAddress(SerializationUtils.deserializeInetAddress(address)).sendClientId()
@@ -213,42 +231,35 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i
         }
       }).setName(s"Endpoint to $executorId")
 
-      logDebug(s"Worker $this connecting to Executor($executorId, " +
-        s"${SerializationUtils.deserializeInetAddress(address)}")
-      val ep = worker.newEndpoint(endpointParams)
-      val header = Platform.allocateDirectBuffer(UnsafeUtils.LONG_SIZE)
-      header.putLong(id)
-      header.rewind()
-      val workerAddress = worker.getAddress
-
-      ep.sendAmNonBlocking(1, UcxUtils.getAddress(header), UnsafeUtils.LONG_SIZE,
-        UcxUtils.getAddress(workerAddress), workerAddress.capacity().toLong, UcpConstants.UCP_AM_SEND_FLAG_EAGER,
-        new UcxCallback() {
-          override def onSuccess(request: UcpRequest): Unit = {
-            header.clear()
-            workerAddress.clear()
-          }
-        }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
-      flushRequests.add(ep.flushNonBlocking(null))
-      ep
+      logDebug(s"Worker ${id.toInt}:${id>>32} connecting to Executor($executorId)")
+      worker.synchronized {
+        val ep = worker.newEndpoint(endpointParams)
+        val header = Platform.allocateDirectBuffer(UnsafeUtils.LONG_SIZE)
+        header.putLong(id)
+        header.rewind()
+        val workerAddress = worker.getAddress
+
+        ep.sendAmNonBlocking(1, UcxUtils.getAddress(header), UnsafeUtils.LONG_SIZE,
+          UcxUtils.getAddress(workerAddress), workerAddress.capacity().toLong, UcpConstants.UCP_AM_SEND_FLAG_EAGER,
+          new UcxCallback() {
+            override def onSuccess(request: UcpRequest): Unit = {
+              header.clear()
+              workerAddress.clear()
+            }
+            override def onError(ucsStatus: Int, errorMsg: String): Unit = {
+              logError(s"Failed to send $errorMsg")
+            }
+          }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
+        ep
+      }
     })
   }
 
   def fetchBlocksByBlockIds(executorId: transport.ExecutorId, blockIds: Seq[BlockId],
                             resultBufferAllocator: transport.BufferAllocator,
-                            callbacks: Seq[OperationCallback]): Seq[Request] = {
+                            callbacks: Seq[OperationCallback]): Unit = {
     val startTime = System.nanoTime()
     val headerSize = UnsafeUtils.INT_SIZE + UnsafeUtils.LONG_SIZE
-    val ep = getConnection(executorId)
-
-    if (worker.getMaxAmHeaderSize <=
-      headerSize + UnsafeUtils.INT_SIZE * blockIds.length) {
-      val (b1, b2) = blockIds.splitAt(blockIds.length / 2)
-      val (c1, c2) = callbacks.splitAt(callbacks.length / 2)
-      val r1 = fetchBlocksByBlockIds(executorId, b1, resultBufferAllocator, c1)
-      val r2 = fetchBlocksByBlockIds(executorId, b2, resultBufferAllocator, c2)
-      return r1 ++ r2
-    }
 
     val t = tag.incrementAndGet()
@@ -258,32 +269,31 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i
     blockIds.foreach(b => b.serialize(buffer))
 
     val request = new UcxRequest(null, new UcxStats())
-    requestData.put(t, (callbacks, request, resultBufferAllocator))
+    requestData.put(t, new UcxFetchState(callbacks, request, startTime))
 
     buffer.rewind()
     val address = UnsafeUtils.getAdress(buffer)
     val dataAddress = address + headerSize
 
-    ep.sendAmNonBlocking(0, address,
-      headerSize, dataAddress, buffer.capacity() - headerSize,
-      UcpConstants.UCP_AM_SEND_FLAG_EAGER, new UcxCallback() {
-        override def onSuccess(request: UcpRequest): Unit = {
-          buffer.clear()
-          logDebug(s"Sent message on $ep to $executorId to fetch ${blockIds.length} blocks on tag $t id $id" +
-            s"in ${System.nanoTime() - startTime} ns")
-        }
-      }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
-
-    worker.progressRequest(ep.flushNonBlocking(null))
-    Seq(request)
+    val ep = getConnection(executorId)
+    worker.synchronized {
+      ep.sendAmNonBlocking(0, address,
+        headerSize, dataAddress, buffer.capacity() - headerSize,
+        UcpConstants.UCP_AM_SEND_FLAG_EAGER, new UcxCallback() {
+          override def onSuccess(request: UcpRequest): Unit = {
+            buffer.clear()
+            logDebug(s"Sent message on $ep to $executorId to fetch ${blockIds.length} blocks on tag $t id $id" +
+              s"in ${System.nanoTime() - startTime} ns")
+          }
+        }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
+    }
   }
 
   def handleFetchBlockRequest(blocks: Seq[Block], replyTag: Int, replyExecutor: Long): Unit = try {
     val tagAndSizes = UnsafeUtils.INT_SIZE + UnsafeUtils.INT_SIZE * blocks.length
-    val resultMemory = transport.hostBounceBufferMemoryPool.get(tagAndSizes + blocks.map(_.getSize).sum)
-      .asInstanceOf[UcxBounceBufferMemoryBlock]
-    val resultBuffer = UcxUtils.getByteBufferView(resultMemory.address,
-      resultMemory.size)
+    val msgSize = tagAndSizes + blocks.map(_.getSize).sum
+    val resultMemory = memPool.get(msgSize).asInstanceOf[UcxBounceBufferMemoryBlock]
+    val resultBuffer = UcxUtils.getByteBufferView(resultMemory.address, msgSize)
     resultBuffer.putInt(replyTag)
 
     var offset = 0
@@ -297,7 +307,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i
       localBuffer
     }
     // Do parallel read of blocks
-    val blocksCollection = if (transport.ucxShuffleConf.numIoThreads > 1) {
+    val blocksCollection = if (ioThreadOn) {
       val parCollection = blocks.indices.par
       parCollection.tasksupport = ioTaskSupport
       parCollection
@@ -310,25 +320,41 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i
     }
 
     val startTime = System.nanoTime()
-    val req = connections(replyExecutor).sendAmNonBlocking(1, resultMemory.address, tagAndSizes,
-      resultMemory.address + tagAndSizes, resultMemory.size - tagAndSizes, 0, new UcxCallback {
-        override def onSuccess(request: UcpRequest): Unit = {
-          logTrace(s"Sent ${blocks.length} blocks of size: ${resultMemory.size} " +
-            s"to tag $replyTag in ${System.nanoTime() - startTime} ns.")
-          transport.hostBounceBufferMemoryPool.put(resultMemory)
-        }
-
-        override def onError(ucsStatus: Int, errorMsg: String): Unit = {
-          logError(s"Failed to send $errorMsg")
-        }
-      }, new UcpRequestParams().setMemoryType(UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
-      .setMemoryHandle(resultMemory.memory))
+    worker.synchronized {
+      connections(replyExecutor).sendAmNonBlocking(1, resultMemory.address, tagAndSizes,
+        resultMemory.address + tagAndSizes, msgSize - tagAndSizes, 0, new UcxCallback {
+          override def onSuccess(request: UcpRequest): Unit = {
+            logTrace(s"Sent ${blocks.length} blocks of size: ${msgSize} " +
+              s"to tag $replyTag in ${System.nanoTime() - startTime} ns.")
+            resultMemory.close()
+          }
 
-    while (!req.isCompleted) {
-      progress()
-    }
+          override def onError(ucsStatus: Int, errorMsg: String): Unit = {
+            logError(s"Failed to send $errorMsg")
+            resultMemory.close()
+          }
+        }, new UcpRequestParams().setMemoryType(UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
+        .setMemoryHandle(resultMemory.memory))
+    }
   } catch {
     case ex: Throwable => logError(s"Failed to read and send data: $ex")
   }
 }
+
+private[ucx] class ProgressThread(
+    name: String, worker: UcpWorker, useWakeup: Boolean) extends Thread {
+  setDaemon(true)
+  setName(name)
+
+  override def run(): Unit = {
+    while (!isInterrupted) {
+      worker.synchronized {
+        while (worker.progress() != 0) {}
+      }
+      if (useWakeup) {
+        worker.waitForEvents()
+      }
+    }
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala b/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala
index 16bd821b..3d9be266 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala
@@ -9,6 +9,7 @@ import java.net.InetSocketAddress
 import java.nio.ByteBuffer
 import java.nio.charset.StandardCharsets
 import java.nio.channels.FileChannel
+import java.util.concurrent.CountDownLatch
 import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.commons.cli.{GnuParser, HelpFormatter, Options}
 import org.apache.spark.SparkConf
@@ -127,6 +128,7 @@ object UcxPerfBenchmark extends App with Logging {
     }
 
     for (_ <- 0 until options.numIterations) {
+      val latch = new CountDownLatch(options.numOutstanding)
       for (b <- blockCollection) {
         requestInFlight.set(options.numOutstanding)
         for (o <- 0 until options.numOutstanding) {
@@ -142,12 +144,11 @@ object UcxPerfBenchmark extends App with Logging {
               (options.blockSize * options.numOutstanding * options.numThreads) /
                 (1024.0 * 1024.0 * (stats.getElapsedTimeNs / 1e9)))
           }
+          latch.countDown
         }
       }
-      val requests = ucxTransport.fetchBlocksByBlockIds(1, blocks, resultBufferAllocator, callbacks)
-      while (!requests.forall(_.isCompleted)) {
-        ucxTransport.progress()
-      }
+      ucxTransport.fetchBlocksByBlockIds(1, blocks, resultBufferAllocator, callbacks)
+      latch.await
     }
   }
   ucxTransport.close()
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala
index a9f27b83..d8e42404 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala
@@ -4,10 +4,11 @@
  */
 package org.apache.spark.shuffle.ucx.rpc
 
+import java.nio.ByteBuffer
 import org.openucx.jucx.ucp.{UcpAmData, UcpConstants, UcpEndpoint, UcpWorker}
 import org.openucx.jucx.ucs.UcsConstants
 import org.apache.spark.internal.Logging
-import org.apache.spark.shuffle.ucx.UcxShuffleTransport
+import org.apache.spark.shuffle.ucx.{UcxShuffleTransport, UcxShuffleBockId}
 import org.apache.spark.shuffle.utils.UnsafeUtils
 import org.apache.spark.util.ThreadUtils
 
@@ -16,21 +17,19 @@ class GlobalWorkerRpcThread(globalWorker: UcpWorker, transport: UcxShuffleTransp
   setDaemon(true)
   setName("Global worker progress thread")
 
-  private val replyWorkersThreadPool = ThreadUtils.newDaemonFixedThreadPool(transport.ucxShuffleConf.numListenerThreads,
-    "UcxListenerThread")
-
   // Main RPC thread. Submit each RPC request to separate thread and send reply back from separate worker.
   globalWorker.setAmRecvHandler(0, (headerAddress: Long, headerSize: Long, amData: UcpAmData, _: UcpEndpoint) => {
     val header = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt)
    val replyTag = header.getInt
     val replyExecutor = header.getLong
-    replyWorkersThreadPool.submit(new Runnable {
-      override def run(): Unit = {
-        transport.handleFetchBlockRequest(replyTag, amData, replyExecutor)
-      }
-    })
-    UcsConstants.STATUS.UCS_INPROGRESS
-  }, UcpConstants.UCP_AM_FLAG_PERSISTENT_DATA | UcpConstants.UCP_AM_FLAG_WHOLE_MSG )
+    val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress,
+      amData.getLength.toInt)
+    val blockNum = buffer.remaining() / UcxShuffleBockId.serializedSize
+    val blockIds = (0 until blockNum).map(
+      _ => UcxShuffleBockId.deserialize(buffer))
+    transport.handleFetchBlockRequest(replyTag, blockIds, replyExecutor)
+    UcsConstants.STATUS.UCS_OK
+  }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG )
 
   // AM to get worker address for client worker and connect server workers to it
@@ -46,13 +45,12 @@ class GlobalWorkerRpcThread(globalWorker: UcpWorker, transport: UcxShuffleTransp
   override def run(): Unit = {
     if (transport.ucxShuffleConf.useWakeup) {
       while (!isInterrupted) {
-        if (globalWorker.progress() == 0) {
-          globalWorker.waitForEvents()
-        }
+        while (globalWorker.progress != 0) {}
+        globalWorker.waitForEvents()
       }
     } else {
       while (!isInterrupted) {
-        globalWorker.progress()
+        while (globalWorker.progress != 0) {}
       }
     }
   }
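
For reference, the AM handler above can decode the block ids directly from the eager payload and return UCS_OK (instead of retaining amData with UCP_AM_FLAG_PERSISTENT_DATA and returning UCS_INPROGRESS) because UcxShuffleBockId uses fixed-size framing: every id serializes to exactly UcxShuffleBockId.serializedSize (12) bytes, so the count is simply buffer.remaining() / 12. A self-contained sketch of the encode/decode pair, mirroring the definitions in UcxShuffleTransport.scala; the readBlockIds helper is illustrative, not part of the patch:

    import java.nio.ByteBuffer

    case class UcxShuffleBockId(shuffleId: Int, mapId: Int, reduceId: Int) {
      def serialize(byteBuffer: ByteBuffer): Unit = {
        byteBuffer.putInt(shuffleId)
        byteBuffer.putInt(mapId)
        byteBuffer.putInt(reduceId)
      }
    }

    object UcxShuffleBockId {
      val serializedSize: Int = 12 // three 4-byte Ints

      def deserialize(byteBuffer: ByteBuffer): UcxShuffleBockId =
        UcxShuffleBockId(byteBuffer.getInt, byteBuffer.getInt, byteBuffer.getInt)
    }

    // Receiver side, as in GlobalWorkerRpcThread: the payload is a packed
    // sequence of fixed-size records, so the count falls out of the length.
    def readBlockIds(buffer: ByteBuffer): Seq[UcxShuffleBockId] = {
      val blockNum = buffer.remaining() / UcxShuffleBockId.serializedSize
      (0 until blockNum).map(_ => UcxShuffleBockId.deserialize(buffer))
    }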