From 8467c27d88b9c4c85cecd36a19d848b05017a33c Mon Sep 17 00:00:00 2001 From: Rexwell Minnis Date: Tue, 17 Feb 2026 09:50:12 -0500 Subject: [PATCH] Fix TaskContext leak and resource cleanup in getRDDPartition getRDDPartition sets a dummy TaskContext but never unsets it. Since executor actors reuse threads (maxConcurrency=2), a stale TaskContext from one call leaks into subsequent operations on the same thread, which can cause subtle issues with metrics, accumulators, and task-local state. Additionally, if an exception occurs mid-method, neither the WriteChannel/ByteArrayOutputStream nor the BlockManager read locks are cleaned up. Wrap the method body in try/finally to guarantee: - BlockManager read locks are released via releaseAllLocksForTask - TaskContext is unset via TaskContext.unset() - WriteChannel and ByteArrayOutputStream are closed. --- .../apache/spark/executor/RayDPExecutor.scala | 53 +++++++++++-------- 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala index 0ed699dd..30ab8fb6 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala @@ -324,29 +324,36 @@ class RayDPExecutor( val env = SparkEnv.get val context = SparkShimLoader.getSparkShims.getDummyTaskContext(partitionId, env) TaskContext.setTaskContext(context) - val schema = Schema.fromJSON(schemaStr) - val blockId = BlockId.apply("rdd_" + rddId + "_" + partitionId) - val iterator = env.blockManager.get(blockId)(classTag[Array[Byte]]) match { - case Some(blockResult) => - blockResult.data.asInstanceOf[Iterator[Array[Byte]]] - case None => - logWarning("The cached block has been lost. Cache it again via driver agent") - requestRecacheRDD(rddId, driverAgentUrl) - env.blockManager.get(blockId)(classTag[Array[Byte]]) match { - case Some(blockResult) => - blockResult.data.asInstanceOf[Iterator[Array[Byte]]] - case None => - throw new RayDPException("Still cannot get the block after recache!") - } + try { + val schema = Schema.fromJSON(schemaStr) + val blockId = BlockId.apply("rdd_" + rddId + "_" + partitionId) + val iterator = env.blockManager.get(blockId)(classTag[Array[Byte]]) match { + case Some(blockResult) => + blockResult.data.asInstanceOf[Iterator[Array[Byte]]] + case None => + logWarning("The cached block has been lost. Cache it again via driver agent") + requestRecacheRDD(rddId, driverAgentUrl) + env.blockManager.get(blockId)(classTag[Array[Byte]]) match { + case Some(blockResult) => + blockResult.data.asInstanceOf[Iterator[Array[Byte]]] + case None => + throw new RayDPException("Still cannot get the block after recache!") + } + } + val byteOut = new ByteArrayOutputStream() + val writeChannel = new WriteChannel(Channels.newChannel(byteOut)) + try { + MessageSerializer.serialize(writeChannel, schema) + iterator.foreach(writeChannel.write) + ArrowStreamWriter.writeEndOfStream(writeChannel, new IpcOption) + byteOut.toByteArray + } finally { + writeChannel.close() + byteOut.close() + } + } finally { + env.blockManager.releaseAllLocksForTask(context.taskAttemptId()) + TaskContext.unset() + } - val byteOut = new ByteArrayOutputStream() - val writeChannel = new WriteChannel(Channels.newChannel(byteOut)) - MessageSerializer.serialize(writeChannel, schema) - iterator.foreach(writeChannel.write) - ArrowStreamWriter.writeEndOfStream(writeChannel, new IpcOption) - val result = byteOut.toByteArray - writeChannel.close - byteOut.close - result } }