Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions core/src/main/scala/org/apache/spark/internal/config/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,16 @@ package object config {
.booleanConf
.createWithDefault(false)

// When enabled, shuffle data is read exclusively from the fallback storage
// instead of executor-hosted blocks. Only effective together with
// STORAGE_DECOMMISSION_FALLBACK_STORAGE_PROACTIVE_RELIABLE (the gate is
// enforced in FallbackStorage.isAlwaysRead). Default: false.
private[spark] val STORAGE_DECOMMISSION_FALLBACK_STORAGE_ALWAYS_READ =
ConfigBuilder("spark.storage.decommission.fallbackStorage.alwaysRead")
.doc("If true, Spark reads shuffle data only from fallback storage. " +
s"This requires ${STORAGE_DECOMMISSION_FALLBACK_STORAGE_PROACTIVE_RELIABLE.key} " +
s"to be true. This is useful to avoid disruption due to executor decommission " +
s"or executor failures, or to benchmark reading from the fallback storage.")
.version("4.2.0")
.booleanConf
.createWithDefault(false)

private[spark] val STORAGE_DECOMMISSION_SHUFFLE_MAX_DISK_SIZE =
ConfigBuilder("spark.storage.decommission.shuffleBlocks.maxDiskSize")
.doc("Maximum disk space to use to store shuffle blocks before rejecting remote " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ private[spark] class ShuffleWriteProcessor extends Serializable with Logging {
fallbackStorage.foreach(
_.copy(shuffleBlockInfo, blockManager, isAsyncCopy = false, reportBlockStatus = false)
)

// point map status directly to the fallback storage if always-read is enabled
if (FallbackStorage.isAlwaysRead(SparkEnv.get.conf)) {
mapStatus.foreach(_.updateLocation(FallbackStorage.FALLBACK_BLOCK_MANAGER_ID))
}
} else {
// we ignore exceptions that occur asynchronously, this is best-effort replication
// we do not want to defer the task in any way
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys._
import org.apache.spark.internal.config.{STORAGE_DECOMMISSION_FALLBACK_STORAGE_CLEANUP, STORAGE_DECOMMISSION_FALLBACK_STORAGE_PATH, STORAGE_DECOMMISSION_FALLBACK_STORAGE_PROACTIVE_ENABLED, STORAGE_DECOMMISSION_FALLBACK_STORAGE_PROACTIVE_RELIABLE}
import org.apache.spark.internal.config.{STORAGE_DECOMMISSION_FALLBACK_STORAGE_ALWAYS_READ, STORAGE_DECOMMISSION_FALLBACK_STORAGE_CLEANUP, STORAGE_DECOMMISSION_FALLBACK_STORAGE_PATH, STORAGE_DECOMMISSION_FALLBACK_STORAGE_PROACTIVE_ENABLED, STORAGE_DECOMMISSION_FALLBACK_STORAGE_PROACTIVE_RELIABLE}
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcTimeout}
Expand Down Expand Up @@ -176,6 +176,10 @@ private[spark] object FallbackStorage extends Logging {
isProactive(conf) &&
conf.get(STORAGE_DECOMMISSION_FALLBACK_STORAGE_PROACTIVE_RELIABLE)

/**
 * Whether shuffle data should be read only from the fallback storage.
 *
 * True only when proactive fallback replication is enabled (`isProactive`)
 * AND `STORAGE_DECOMMISSION_FALLBACK_STORAGE_ALWAYS_READ` is set; the
 * proactive gate ensures the fallback copy exists before readers rely on it.
 *
 * @param conf the Spark configuration to inspect
 * @return true if readers must fetch shuffle blocks from fallback storage only
 */
def isAlwaysRead(conf: SparkConf): Boolean =
isProactive(conf) &&
conf.get(STORAGE_DECOMMISSION_FALLBACK_STORAGE_ALWAYS_READ)

def getPath(conf: SparkConf): Path =
new Path(conf.get(STORAGE_DECOMMISSION_FALLBACK_STORAGE_PATH).get)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle

import org.mockito.ArgumentMatchers.{any, eq => meq}
import org.mockito.Mockito.when
import org.scalatestplus.mockito.MockitoSugar.mock

import org.apache.spark.{HashPartitioner, Partition, ShuffleDependency, SparkConf, SparkContext, SparkEnv, SparkFunSuite, TaskContextImpl}
import org.apache.spark.internal.config.{STORAGE_DECOMMISSION_FALLBACK_STORAGE_ALWAYS_READ, STORAGE_DECOMMISSION_FALLBACK_STORAGE_PATH, STORAGE_DECOMMISSION_FALLBACK_STORAGE_PROACTIVE_ENABLED, STORAGE_DECOMMISSION_FALLBACK_STORAGE_PROACTIVE_RELIABLE}
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.api.ShuffleDriverComponents
import org.apache.spark.storage.{BlockManager, BlockManagerId, BlockManagerMaster, DiskBlockManager, FallbackStorage}

class ShuffleWriteProcessorSuite extends SparkFunSuite {

  test("write returns map status given by writer.stop") {
    doTest()
  }

  test("write returns map status with fallback storage location") {
    // Enable the full proactive + reliable + always-read chain so that
    // ShuffleWriteProcessor rewrites the map status location to the
    // fallback storage's synthetic block manager id.
    val conf = new SparkConf(false)
    conf.set("spark.app.id", "testing")
    conf.set(STORAGE_DECOMMISSION_FALLBACK_STORAGE_PATH, "/tmp/")
    conf.set(STORAGE_DECOMMISSION_FALLBACK_STORAGE_PROACTIVE_ENABLED, true)
    conf.set(STORAGE_DECOMMISSION_FALLBACK_STORAGE_PROACTIVE_RELIABLE, true)
    conf.set(STORAGE_DECOMMISSION_FALLBACK_STORAGE_ALWAYS_READ, true)
    doTest(Some(conf), Some(FallbackStorage.FALLBACK_BLOCK_MANAGER_ID))
  }

  /**
   * Drives a single ShuffleWriteProcessor.write call against a fully mocked
   * SparkEnv and asserts on the location reported by the returned MapStatus.
   *
   * @param confOpt optional SparkConf; defaults to an empty, defaults-free conf
   * @param expectedMapStatusLocation expected location of the returned map
   *                                  status; defaults to the writer's own
   *                                  block manager id when None
   */
  def doTest(
      confOpt: Option[SparkConf] = None,
      expectedMapStatusLocation: Option[BlockManagerId] = None): Unit = {
    val conf = confOpt.getOrElse(new SparkConf(false))

    val taskContext = new TaskContextImpl(1, 1, 0, 1, 1, 2, null, null, null, cpus = 1)
    val testMapId = 1L

    // The map status the stubbed writer reports on stop(); its location is
    // what the always-read path may overwrite.
    val writerLocation = BlockManagerId("exec-1", "host", 1234)
    val stopStatus = MapStatus(writerLocation, Array(10L, 20L), testMapId, 0L)

    val shuffleWriter = mock[ShuffleWriter[Int, Int]]
    when(shuffleWriter.stop(true)).thenReturn(Some(stopStatus))

    val manager = mock[ShuffleManager]
    when(manager.getWriter[Int, Int](any(), meq(testMapId), meq(taskContext), any()))
      .thenReturn(shuffleWriter)

    // Real disk block manager + resolver backed by mocks, so the fallback
    // copy path has something concrete to resolve against.
    val bm = mock[BlockManager]
    val diskManager = new DiskBlockManager(conf, deleteFilesOnStop = false, isDriver = false)
    val master = mock[BlockManagerMaster]
    when(bm.diskBlockManager).thenReturn(diskManager)
    when(bm.master).thenReturn(master)
    val blockResolver = new IndexShuffleBlockResolver(conf, bm)
    when(bm.migratableResolver).thenReturn(blockResolver)

    val sparkEnv = mock[SparkEnv]
    when(sparkEnv.conf).thenReturn(conf)
    when(sparkEnv.shuffleManager).thenReturn(manager)
    when(sparkEnv.blockManager).thenReturn(bm)
    SparkEnv.set(sparkEnv)

    val sparkContext = mock[SparkContext]
    when(sparkContext.env).thenReturn(sparkEnv)
    when(sparkContext.conf).thenReturn(conf)
    when(sparkContext.newShuffleId()).thenReturn(1)
    when(sparkContext.cleaner).thenReturn(None)
    when(sparkContext.shuffleDriverComponents).thenReturn(mock[ShuffleDriverComponents])

    // Two trivial partitions are enough for a HashPartitioner(2) dependency.
    val parts: Array[Partition] =
      Array.tabulate(2)(id => new Partition { override def index: Int = id })
    val mockRdd = mock[RDD[Product2[Int, Int]]]
    when(mockRdd.context).thenReturn(sparkContext)
    when(mockRdd.sparkContext).thenReturn(sparkContext)
    when(mockRdd.partitions).thenReturn(parts)

    val records = Iterator.empty
    val dependency = new ShuffleDependency[Int, Int, Int](mockRdd, new HashPartitioner(2), null,
      shuffleWriterProcessor = new ShuffleWriteProcessor())
    val result = dependency.shuffleWriterProcessor.write(records, dependency, 1L, 0, taskContext)
    assert(result.location === expectedMapStatusLocation.getOrElse(writerLocation))
  }

}
Loading