diff --git a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java index 59df0423fad26..c70382bf93190 100644 --- a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java +++ b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java @@ -232,6 +232,7 @@ public enum LogKeys implements LogKey { EXPR, EXPR_TERMS, EXTENDED_EXPLAIN_GENERATOR, FAILED_STAGE, FAILED_STAGE_NAME, FAILURES, + FALLBACK_STORAGE_BLOCKS_SIZE, @@ -473,6 +474,7 @@ public enum LogKeys implements LogKey { NUM_EXECUTOR_DESIRED, NUM_EXECUTOR_LAUNCH, NUM_EXECUTOR_TARGET, NUM_FAILURES, + NUM_FALLBACK_STORAGE_BLOCKS, NUM_FEATURES, NUM_FILES, diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 9fee7a36a0445..feddf7d12b340 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1501,6 +1501,13 @@ package object config { "maxRemoteBlockSizeFetchToMem cannot be larger than (Int.MaxValue - 512) bytes.") .createWithDefaultString("200m") + private[spark] val REDUCER_FALLBACK_STORAGE_READ_THREADS = + ConfigBuilder("spark.reducer.fallbackStorage.readThreads") + .doc("Number of threads used by the reducer to read shuffle blocks from fallback storage.") + .version("4.2.0") + .intConf + .createWithDefault(5) + private[spark] val TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES = ConfigBuilder("spark.taskMetrics.trackUpdatedBlockStatuses") .doc("Enable tracking of updatedBlockStatuses in the TaskMetrics. 
Off by default since " + diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 89177346a789a..6350c2eef785c 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -53,6 +53,12 @@ trait BlockDataManager { */ def getLocalBlockData(blockId: BlockId): ManagedBuffer + /** + * Interface to get fallback storage block data. Throws an exception if the block cannot be found + * or cannot be read successfully. + */ + def getFallbackStorageBlockData(blockId: BlockId): ManagedBuffer + /** * Put the block locally, using the given storage level. * diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 7918d1618eb06..381089ff8bdd2 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -83,6 +83,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM), + SparkEnv.get.conf.get(config.REDUCER_FALLBACK_STORAGE_READ_THREADS), SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED), diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 5fbc8dca74f68..3e69af157693a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -759,16 +759,7 @@ private[spark] class 
BlockManager( override def getLocalBlockData(blockId: BlockId): ManagedBuffer = { if (blockId.isShuffle) { logDebug(s"Getting local shuffle block ${blockId}") - try { - shuffleManager.shuffleBlockResolver.getBlockData(blockId) - } catch { - case e: IOException => - if (conf.get(config.STORAGE_DECOMMISSION_FALLBACK_STORAGE_PATH).isDefined) { - FallbackStorage.read(conf, blockId) - } else { - throw e - } - } + shuffleManager.shuffleBlockResolver.getBlockData(blockId) } else { getLocalBytes(blockId) match { case Some(blockData) => @@ -783,6 +774,25 @@ private[spark] class BlockManager( } } + /** + * Interface to get fallback storage block data. Throws an exception if the block cannot be found + * or cannot be read successfully. + */ + override def getFallbackStorageBlockData(blockId: BlockId): ManagedBuffer = { + require(conf.get(config.STORAGE_DECOMMISSION_FALLBACK_STORAGE_PATH).isDefined) + + if (blockId.isShuffle) { + logDebug(s"Getting fallback storage block ${blockId}") + FallbackStorage.read(conf, blockId) + } else { + // If this block manager receives a request for a block that it doesn't have then it's + // likely that the master has outdated block statuses for this block. Therefore, we send + // an RPC so that this block is marked as being unavailable from this block manager. + reportBlockStatus(blockId, BlockStatus.empty) + throw SparkCoreErrors.blockNotFoundError(blockId) + } + } + /** * Put the block locally, using the given storage level. 
* diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index cc552a2985f7e..f8dddee9cacac 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import java.io.{InputStream, IOException} import java.nio.channels.ClosedByInterruptException -import java.util.concurrent.{LinkedBlockingDeque, TimeUnit} +import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit} import java.util.concurrent.atomic.AtomicBoolean import java.util.zip.CheckedInputStream import javax.annotation.concurrent.GuardedBy @@ -27,6 +27,7 @@ import javax.annotation.concurrent.GuardedBy import scala.collection import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future} import scala.util.{Failure, Success} import io.netty.util.internal.OutOfDirectMemoryError @@ -37,12 +38,12 @@ import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID import org.apache.spark.errors.SparkCoreErrors import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys._ -import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} import org.apache.spark.network.util.{NettyUtils, TransportConf} import org.apache.spark.shuffle.ShuffleReadMetricsReporter -import org.apache.spark.util.{Clock, CompletionIterator, SystemClock, TaskCompletionListener, Utils} +import org.apache.spark.util.{Clock, CompletionIterator, SystemClock, 
TaskCompletionListener, ThreadUtils, Utils} /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block @@ -73,6 +74,7 @@ import org.apache.spark.util.{Clock, CompletionIterator, SystemClock, TaskComple * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. * @param maxAttemptsOnNettyOOM The max number of a block could retry due to Netty OOM before * throwing the fetch failure. + * @param fallbackStorageReadThreads number of threads reading concurrently from fallback storage * @param detectCorrupt whether to detect any corruption in fetched blocks. * @param checksumEnabled whether the shuffle checksum is enabled. When enabled, Spark will try to * diagnose the cause of the block corruption. @@ -95,6 +97,7 @@ final class ShuffleBlockFetcherIterator( maxBlocksInFlightPerAddress: Int, val maxReqSizeShuffleToMem: Long, maxAttemptsOnNettyOOM: Int, + fallbackStorageReadThreads: Int, detectCorrupt: Boolean, detectCorruptUseExtraMemory: Boolean, checksumEnabled: Boolean, @@ -139,9 +142,26 @@ final class ShuffleBlockFetcherIterator( */ @volatile private[this] var currentResult: SuccessFetchResult = null + /** + * Queue of fallback storage requests to issue; we'll pull requests off this gradually to make + * sure that the number of bytes and requests in flight is limited to maxBytesInFlight and + * maxReqsInFlight. + */ + private[this] val fallbackStorageRequests = new Queue[FallbackStorageRequest] + + /** + * Thread pool reading from fallback storage, first creating FallbackStorageRequest from + * block id and map index, then materializing requests to SuccessFetchResult. 
+ */ + // This is visible for testing + private[storage] val fallbackStorageReadPool: ThreadPoolExecutor = + ThreadUtils.newDaemonFixedThreadPool(fallbackStorageReadThreads, "fallback-storage-read") + private[this] val fallbackStorageReadContext: ExecutionContextExecutor = + ExecutionContext.fromExecutor(fallbackStorageReadPool) + /** * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that - * the number of bytes in flight is limited to maxBytesInFlight. + * the number of bytes and requests in flight is limited to maxBytesInFlight and maxReqsInFlight. */ private[this] val fetchRequests = new Queue[FetchRequest] @@ -259,6 +279,35 @@ final class ShuffleBlockFetcherIterator( logWarning(log"Failed to cleanup shuffle fetch temp file ${MDC(PATH, file.path())}") } } + fallbackStorageReadPool.shutdownNow() + } + + private[this] def createFallbackStorageRequest(blockId: BlockId, mapIndex: Int): Unit = { + Future { + if (!isZombie) { + try { + val block = blockManager.getFallbackStorageBlockData(blockId) + val request = FallbackStorageRequest(blockId, mapIndex, block) + results.put(PreparedFallbackStorageRequestResult(request)) + } catch { + case e: Throwable => + // the FailureFetchResult will stop iteration of this iterator + // task completion listener will shut down the thread pool / execution context + // the synchronized protects isZombie and blocks cleanup() from calling + // fallbackStorageReadPool.shutdownNow(), which would interrupt results.put + // that interrupted exception would kill the executor + synchronized { + if (!isZombie) { + logError(log"Failed to prepare request to read block ${MDC(BLOCK_ID, blockId)} " + + log"from fallback storage", e) + val result = FailureFetchResult( + blockId, mapIndex, FallbackStorage.FALLBACK_BLOCK_MANAGER_ID, e) + results.putFirst(result) + } + } + } + } + }(fallbackStorageReadContext) } private[this] def sendRequest(req: FetchRequest): Unit = { @@ -393,7 +442,8 @@ final class 
ShuffleBlockFetcherIterator( localBlocks: mutable.LinkedHashSet[(BlockId, Int)], hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]], - pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = { + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId], + fallbackStorageBlocks: mutable.LinkedHashSet[(BlockId, Int)]): ArrayBuffer[FetchRequest] = { logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: " + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress") @@ -402,13 +452,15 @@ final class ShuffleBlockFetcherIterator( // in order to limit the amount of data in flight val collectedRemoteRequests = new ArrayBuffer[FetchRequest] var localBlockBytes = 0L + var fallbackStorageBlockBytes = 0L var hostLocalBlockBytes = 0L var numHostLocalBlocks = 0 var pushMergedLocalBlockBytes = 0L val prevNumBlocksToFetch = numBlocksToFetch - val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId - val localExecIds = Set(blockManager.blockManagerId.executorId, fallback) + val localExecId = blockManager.blockManagerId.executorId + val fallbackExecId = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId + val localAndFallbackExecIds = Set(localExecId, fallbackExecId) for ((address, blockInfos) <- blocksByAddress) { checkBlockSizes(blockInfos) if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) { @@ -420,12 +472,17 @@ final class ShuffleBlockFetcherIterator( } else { collectFetchRequests(address, blockInfos, collectedRemoteRequests) } - } else if (localExecIds.contains(address.executorId)) { + } else if (localAndFallbackExecIds.contains(address.executorId)) { val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch) numBlocksToFetch += mergedBlockInfos.size - localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex)) - 
localBlockBytes += mergedBlockInfos.map(_.size).sum + if (address.executorId == localExecId) { + localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex)) + localBlockBytes += mergedBlockInfos.map(_.size).sum + } else { + fallbackStorageBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex)) + fallbackStorageBlockBytes += mergedBlockInfos.map(_.size).sum + } } else if (blockManager.hostLocalDirManager.isDefined && address.host == blockManager.blockManagerId.host) { val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( @@ -445,13 +502,14 @@ final class ShuffleBlockFetcherIterator( } val (remoteBlockBytes, numRemoteBlocks) = collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size)) - val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes + - pushMergedLocalBlockBytes + val totalBytes = localBlockBytes + fallbackStorageBlockBytes + remoteBlockBytes + + hostLocalBlockBytes + pushMergedLocalBlockBytes val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch - assert(blocksToFetchCurrentIteration == localBlocks.size + + assert(blocksToFetchCurrentIteration == localBlocks.size + fallbackStorageBlocks.size + numHostLocalBlocks + numRemoteBlocks + pushMergedLocalBlocks.size, s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to the sum " + s"of the number of local blocks ${localBlocks.size} + " + + s"the number of fallback storage blocks ${fallbackStorageBlocks.size} + " + s"the number of host-local blocks ${numHostLocalBlocks} " + s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " + s"+ the number of remote blocks ${numRemoteBlocks} ") @@ -459,8 +517,10 @@ final class ShuffleBlockFetcherIterator( log"Getting ${MDC(NUM_BLOCKS, blocksToFetchCurrentIteration)} " + log"(${MDC(TOTAL_SIZE, Utils.bytesToString(totalBytes))}) non-empty blocks including " + log"${MDC(NUM_LOCAL_BLOCKS, localBlocks.size)} " + - 
log"(${MDC(LOCAL_BLOCKS_SIZE, Utils.bytesToString(localBlockBytes))}) local and " + - log"${MDC(NUM_HOST_LOCAL_BLOCKS, numHostLocalBlocks)} " + + log"(${MDC(LOCAL_BLOCKS_SIZE, Utils.bytesToString(localBlockBytes))}) " + + log"local and ${MDC(NUM_FALLBACK_STORAGE_BLOCKS, fallbackStorageBlocks.size)} " + + log"(${MDC(FALLBACK_STORAGE_BLOCKS_SIZE, Utils.bytesToString(fallbackStorageBlockBytes))}) " + + log"fallback storage and ${MDC(NUM_HOST_LOCAL_BLOCKS, numHostLocalBlocks)} " + log"(${MDC(HOST_LOCAL_BLOCKS_SIZE, Utils.bytesToString(hostLocalBlockBytes))}) " + log"host-local and ${MDC(NUM_PUSH_MERGED_LOCAL_BLOCKS, pushMergedLocalBlocks.size)} " + log"(${MDC(PUSH_MERGED_LOCAL_BLOCKS_SIZE, Utils.bytesToString(pushMergedLocalBlockBytes))})" + @@ -712,13 +772,22 @@ final class ShuffleBlockFetcherIterator( context.addTaskCompletionListener(onCompleteCallback) // Local blocks to fetch, excluding zero-sized blocks. val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val fallbackBlocks = mutable.LinkedHashSet[(BlockId, Int)]() val hostLocalBlocksByExecutor = mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]]() val pushMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() - // Partition blocks by the different fetch modes: local, host-local, push-merged-local and - // remote blocks. + + // Partition blocks by the different fetch modes: local, host-local, push-merged-local, + // fallback storage and remote blocks. val remoteRequests = partitionBlocksByFetchMode( - blocksByAddress, localBlocks, hostLocalBlocksByExecutor, pushMergedLocalBlocks) + blocksByAddress, localBlocks, hostLocalBlocksByExecutor, + pushMergedLocalBlocks, fallbackBlocks) + + // Turn the fallback storage blocks into read requests in random order. 
+ Utils.randomize(fallbackBlocks).foreach { case (blockId, mapIndex) => + createFallbackStorageRequest(blockId, mapIndex) + } + // Add the remote requests into our queue in a random order fetchRequests ++= Utils.randomize(remoteRequests) assert ((0 == reqsInFlight) == (0 == bytesInFlight), @@ -738,6 +807,7 @@ final class ShuffleBlockFetcherIterator( // Get Local Blocks fetchLocalBlocks(localBlocks) logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}") + // Get host local blocks if any withFetchWaitTimeTracked(fetchAllHostLocalBlocks(hostLocalBlocksByExecutor)) pushBasedFetchHelper.fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks) @@ -831,7 +901,9 @@ final class ShuffleBlockFetcherIterator( // It is a host local block or a local shuffle chunk shuffleMetricsUpdate(blockId, buf, local = true) } else { - numBlocksInFlightPerAddress(address) -= 1 + if (address != FallbackStorage.FALLBACK_BLOCK_MANAGER_ID) { + numBlocksInFlightPerAddress(address) -= 1 + } shuffleMetricsUpdate(blockId, buf, local = false) bytesInFlight -= size } @@ -998,6 +1070,10 @@ final class ShuffleBlockFetcherIterator( defReqQueue.enqueue(request) result = null + case PreparedFallbackStorageRequestResult(request) => + fallbackStorageRequests.enqueue(request) + result = null + case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) => // We get this result in 3 cases: // 1. Failure to fetch the data of a remote shuffle chunk. In this case, the @@ -1188,13 +1264,18 @@ final class ShuffleBlockFetcherIterator( } } + // Send fallback storage requests up to maxBytesInFlight + while (isBlockFetchable(fallbackStorageRequests)) { + sendFallbackStorageRequest(fallbackStorageRequests.dequeue()) + } + // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host // immediately, defer the request until the next time it can be processed. // Process any outstanding deferred fetch requests if possible. 
if (deferredFetchRequests.nonEmpty) { for ((remoteAddress, defReqQueue) <- deferredFetchRequests) { - while (isRemoteBlockFetchable(defReqQueue) && + while (isBlockFetchable(defReqQueue) && !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) { val request = defReqQueue.dequeue() logDebug(s"Processing deferred fetch request for $remoteAddress with " @@ -1208,7 +1289,7 @@ final class ShuffleBlockFetcherIterator( } // Process any regular fetch requests if possible. - while (isRemoteBlockFetchable(fetchRequests)) { + while (isBlockFetchable(fetchRequests)) { val request = fetchRequests.dequeue() val remoteAddress = request.address if (isRemoteAddressMaxedOut(remoteAddress, request)) { @@ -1231,7 +1312,45 @@ final class ShuffleBlockFetcherIterator( numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size } - def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = { + def sendFallbackStorageRequest(request: FallbackStorageRequest): Unit = { + bytesInFlight += request.size + reqsInFlight += 1 + + Future { + if (!isZombie) { + logDebug(log"Reading block ${MDC(BLOCK_ID, request.blockId)} from fallback storage") + try { + // materialize the block ManagedBuffer and store data in SuccessFetchResult + val buf = new NioManagedBuffer(request.block.nioByteBuffer()) + // TODO: add fallback storage metrics + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + val result = SuccessFetchResult( + request.blockId, request.mapIndex, FallbackStorage.FALLBACK_BLOCK_MANAGER_ID, + request.size, buf, isNetworkReqDone = true) + results.put(result) + } catch { + case e: Throwable => + // the FailureFetchResult will stop iteration of this iterator + // task completion listener will shut down the thread pool / execution context + // the synchronized protects isZombie and blocks cleanup() from calling + // fallbackStorageReadPool.shutdownNow(), which would interrupt results.put + // that interrupted exception would 
kill the executor + synchronized { + if (!isZombie) { + logError(log"Failed to read block ${MDC(BLOCK_ID, request.blockId)} " + + log"from fallback storage", e) + val result = FailureFetchResult( + request.blockId, request.mapIndex, FallbackStorage.FALLBACK_BLOCK_MANAGER_ID, e) + results.putFirst(result) + } + } + } + } + }(fallbackStorageReadContext) + } + + def isBlockFetchable[T <: Request](fetchReqQueue: Queue[T]): Boolean = { fetchReqQueue.nonEmpty && (bytesInFlight == 0 || (reqsInFlight + 1 <= maxReqsInFlight && @@ -1287,11 +1406,17 @@ final class ShuffleBlockFetcherIterator( originalBlocksByAddr: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])]): Unit = { val originalLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val originalFallbackStorageBlocks = mutable.LinkedHashSet[(BlockId, Int)]() val originalHostLocalBlocksByExecutor = mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]]() val originalMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() val originalRemoteReqs = partitionBlocksByFetchMode(originalBlocksByAddr, - originalLocalBlocks, originalHostLocalBlocksByExecutor, originalMergedLocalBlocks) + originalLocalBlocks, originalHostLocalBlocksByExecutor, + originalMergedLocalBlocks, originalFallbackStorageBlocks) + // Turn the fallback storage blocks into read requests in random order. + Utils.randomize(originalFallbackStorageBlocks).foreach { case (blockId, mapIndex) => + createFallbackStorageRequest(blockId, mapIndex) + } // Add the remote requests into our queue in a random order fetchRequests ++= Utils.randomize(originalRemoteReqs) logInfo(log"Created ${MDC(NUM_REQUESTS, originalRemoteReqs.size)} fallback remote requests " + @@ -1538,6 +1663,10 @@ object ShuffleBlockFetcherIterator { result } + private[storage] trait Request { + val size: Long + } + /** * The block information to fetch used in FetchRequest. 
* @param blockId block id @@ -1560,10 +1689,25 @@ object ShuffleBlockFetcherIterator { case class FetchRequest( address: BlockManagerId, blocks: collection.Seq[FetchBlockInfo], - forMergedMetas: Boolean = false) { + forMergedMetas: Boolean = false) extends Request { val size = blocks.map(_.size).sum } + /** + * A request to fetch blocks from the Fallback Storage. Holds block data lazily. + * We read the data asynchronously and multithreaded. The result is a SuccessFetchResult + * where buf contains the materialized data. + * @param blockId The block id to read + * @param mapIndex The map index of the block + * @param block the block as a lazy ManagedBuffer + */ + case class FallbackStorageRequest( + blockId: BlockId, + mapIndex: Int, + block: ManagedBuffer) extends Request { + val size: Long = block.size() + } + /** * Result of a fetch from a remote block. */ @@ -1610,6 +1754,16 @@ object ShuffleBlockFetcherIterator { private[storage] case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult + /** + * Fetching block data from the fallback storage is a two-step process: + * 1. read offset and size of the shuffle block from fallback storage + * 2. read the block data from fallback storage + * A PreparedFallbackStorageRequestResult is the outcome of the first step, + * the SuccessFetchResult is the outcome of the second step. + */ + private[storage] case class PreparedFallbackStorageRequestResult( + fallbackStorageRequest: FallbackStorageRequest) extends FetchResult + /** * Result of an un-successful fetch of either of these: * 1) Remote shuffle chunk. 
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 08220a26010fc..fe07d6d2cbb90 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.io._ import java.nio.ByteBuffer +import java.util import java.util.UUID import java.util.concurrent.{CompletableFuture, Semaphore} import java.util.zip.CheckedInputStream @@ -36,6 +37,8 @@ import org.mockito.Mockito.{doThrow, mock, never, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.roaringbitmap.RoaringBitmap +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar.convertIntToGrainOfTime import org.apache.spark.{MapOutputTracker, SparkFunSuite, TaskContext} import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID @@ -49,7 +52,7 @@ import org.apache.spark.storage.ShuffleBlockFetcherIterator._ import org.apache.spark.util.Utils -class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { +class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with Eventually { private var transfer: BlockTransferService = _ private var mapOutputTracker: MapOutputTracker = _ @@ -153,6 +156,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { val in = mock(classOf[InputStream]) when(in.read(any())).thenReturn(1) when(in.read(any(), any(), any())).thenReturn(1) + val buf = ByteBuffer.allocate(size) + util.Arrays.fill(buf.array(), 1.byteValue) + when(mockManagedBuffer.nioByteBuffer()).thenReturn(buf) when(mockManagedBuffer.createInputStream()).thenReturn(in) when(mockManagedBuffer.size()).thenReturn(size) mockManagedBuffer @@ -191,6 +197,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { 
maxBlocksInFlightPerAddress: Int = Int.MaxValue, maxReqSizeShuffleToMem: Int = Int.MaxValue, maxAttemptsOnNettyOOM: Int = 10, + fallbackStorageReadThreads: Int = 5, detectCorrupt: Boolean = true, detectCorruptUseExtraMemory: Boolean = true, checksumEnabled: Boolean = true, @@ -217,6 +224,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { maxBlocksInFlightPerAddress, maxReqSizeShuffleToMem, maxAttemptsOnNettyOOM, + fallbackStorageReadThreads, detectCorrupt, detectCorruptUseExtraMemory, checksumEnabled, @@ -340,7 +348,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { ShuffleBlockId(0, 9, 0) -> createMockManagedBuffer(), ShuffleBlockId(0, 10, 0) -> createMockManagedBuffer()) fallbackBlocks.foreach { case (blockId, buf) => - doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId)) + doReturn(buf).when(blockManager).getFallbackStorageBlockData(meq(blockId)) } val iterator = createShuffleBlockIteratorWithDefaults( @@ -353,9 +361,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { blockManager = Some(blockManager) ) - // 3 local blocks and 2 fallback blocks fetched in initialization - verify(blockManager, times(3 + 2)).getLocalBlockData(any()) + // 3 local blocks fetched in initialization + verify(blockManager, times(3)).getLocalBlockData(any()) + // 2 fallback storage blocks fetched in initialization + // initialize creates futures that eventually call into getFallbackStorageBlockData + eventually(timeout(1.seconds), interval(10.millis)) { + assert(iterator.fallbackStorageReadPool.getCompletedTaskCount >= 2) + } + verify(blockManager, times(2)).getFallbackStorageBlockData(any()) // SPARK-55469: but buffer data have never been materialized fallbackBlocks.values.foreach { mockBuf => verify(mockBuf, never()).nioByteBuffer() @@ -372,7 +386,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure we release buffers when a wrapped input stream is closed. 
val mockBuf = allBlocks(blockId) - verifyBufferRelease(mockBuf, inputStream) + if (!fallbackBlocks.contains(blockId)) { + verifyBufferRelease(mockBuf, inputStream) + } } assert(!iterator.hasNext) @@ -386,8 +402,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // SPARK-55469: fallback buffer data have been materialized once fallbackBlocks.values.foreach { mockBuf => - verify(mockBuf, never()).nioByteBuffer() - verify(mockBuf, times(1)).createInputStream() + verify(mockBuf, times(1)).nioByteBuffer() + verify(mockBuf, never()).createInputStream() verify(mockBuf, never()).convertToNetty() verify(mockBuf, never()).convertToNettyForSsl() } @@ -502,7 +518,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { val mergedFallbackBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockBatchId(0, 1, 0, 2) -> createMockManagedBuffer()) mergedFallbackBlocks.foreach { case (blockId, buf) => - doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId)) + doReturn(buf).when(blockManager).getFallbackStorageBlockData(meq(blockId)) } // Make sure remote blocks would return the merged block @@ -544,9 +560,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { doBatchFetch = true ) - // 1 local merge block and 1 fallback merge block fetched in initialization - verify(blockManager, times(1 + 1)).getLocalBlockData(any()) + // 1 local merge block fetched in initialization + verify(blockManager, times(1)).getLocalBlockData(any()) + // 1 fallback merge block fetched in initialization + // initialize creates futures that eventually call into getFallbackStorageBlockData + eventually(timeout(1.seconds), interval(10.millis)) { + assert(iterator.fallbackStorageReadPool.getCompletedTaskCount >= 1) + } + verify(blockManager, times(1)).getFallbackStorageBlockData(any()) // SPARK-55469: but buffer data have never been materialized mergedFallbackBlocks.values.foreach { mockBuf => verify(mockBuf, never()).nioByteBuffer() @@ -563,7 +585,9 @@ class 
ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { verifyFetchBlocksInvocationCount(1) // Make sure we release buffers when a wrapped input stream is closed. val mockBuf = allBlocks(blockId) - verifyBufferRelease(mockBuf, inputStream) + if (!mergedFallbackBlocks.contains(blockId)) { + verifyBufferRelease(mockBuf, inputStream) + } } assert(!iterator.hasNext) @@ -571,15 +595,17 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { verify(blockManager, times(1)) .getHostLocalShuffleData(any(), meq(Array("local-dir"))) + // 1 merged remote block is read from the same block manager + verifyFetchBlocksInvocationCount(1) + assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1) + // SPARK-55469: merged fallback buffer data have been materialized once mergedFallbackBlocks.values.foreach { mockBuf => - verify(mockBuf, never()).nioByteBuffer() - verify(mockBuf, times(1)).createInputStream() + verify(mockBuf, times(1)).nioByteBuffer() + verify(mockBuf, never()).createInputStream() verify(mockBuf, never()).convertToNetty() verify(mockBuf, never()).convertToNettyForSsl() } - - assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1) } test("fetch continuous blocks in batch should respect maxBytesInFlight") { @@ -2137,46 +2163,4 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { assert(iterator.next()._1 === ShuffleBlockId(0, 1, 0)) assert(!iterator.hasNext) } - - test("Fast fail when failed to get fallback storage blocks") { - val blockManager = createMockBlockManager() - - // Make sure blockManager.getBlockData would return the blocks - val localBmId = blockManager.blockManagerId - val localBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), - ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer()) - localBlocks.foreach { case (blockId, buf) => - doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId)) - } - - // Make sure fallback storage would return 
the blocks - val fallbackBmId = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID - val fallbackBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer(), - ShuffleBlockId(0, 3, 0) -> createMockManagedBuffer()) - fallbackBlocks.take(1).foreach { case (blockId, buf) => - doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId)) - } - fallbackBlocks.takeRight(1).foreach { case (blockId, _) => - doThrow(new RuntimeException("Cannot read from fallback storage")) - .when(blockManager).getLocalBlockData(meq(blockId)) - } - - val iterator = createShuffleBlockIteratorWithDefaults( - Map( - localBmId -> toBlockList(localBlocks.keys, 1L, 0), - fallbackBmId -> toBlockList(fallbackBlocks.keys, 1L, 1) - ), - blockManager = Some(blockManager) - ) - - // Fetch failure should be placed in the head of results, exception should be thrown for the - // 1st instance. - intercept[FetchFailedException] { iterator.next() } - assert(iterator.next()._1 === ShuffleBlockId(0, 0, 0)) - assert(iterator.next()._1 === ShuffleBlockId(0, 1, 0)) - assert(iterator.next()._1 === ShuffleBlockId(0, 2, 0)) - assert(!iterator.hasNext) - } }