diff --git a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala index 0ed699dd..30ab8fb6 100644 --- a/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala +++ b/core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala @@ -324,29 +324,36 @@ class RayDPExecutor( val env = SparkEnv.get val context = SparkShimLoader.getSparkShims.getDummyTaskContext(partitionId, env) TaskContext.setTaskContext(context) - val schema = Schema.fromJSON(schemaStr) - val blockId = BlockId.apply("rdd_" + rddId + "_" + partitionId) - val iterator = env.blockManager.get(blockId)(classTag[Array[Byte]]) match { - case Some(blockResult) => - blockResult.data.asInstanceOf[Iterator[Array[Byte]]] - case None => - logWarning("The cached block has been lost. Cache it again via driver agent") - requestRecacheRDD(rddId, driverAgentUrl) - env.blockManager.get(blockId)(classTag[Array[Byte]]) match { - case Some(blockResult) => - blockResult.data.asInstanceOf[Iterator[Array[Byte]]] - case None => - throw new RayDPException("Still cannot get the block after recache!") - } + try { + val schema = Schema.fromJSON(schemaStr) + val blockId = BlockId.apply("rdd_" + rddId + "_" + partitionId) + val iterator = env.blockManager.get(blockId)(classTag[Array[Byte]]) match { + case Some(blockResult) => + blockResult.data.asInstanceOf[Iterator[Array[Byte]]] + case None => + logWarning("The cached block has been lost. Cache it again via driver agent") + requestRecacheRDD(rddId, driverAgentUrl) + env.blockManager.get(blockId)(classTag[Array[Byte]]) match { + case Some(blockResult) => + blockResult.data.asInstanceOf[Iterator[Array[Byte]]] + case None => + throw new RayDPException("Still cannot get the block after recache!") + } + } + val byteOut = new ByteArrayOutputStream() + val writeChannel = new WriteChannel(Channels.newChannel(byteOut)) + try { + MessageSerializer.serialize(writeChannel, schema) + iterator.foreach(writeChannel.write) + ArrowStreamWriter.writeEndOfStream(writeChannel, new IpcOption) + byteOut.toByteArray + } finally { + writeChannel.close() + byteOut.close() + } + } finally { + env.blockManager.releaseAllLocksForTask(context.taskAttemptId()) + TaskContext.unset() } - val byteOut = new ByteArrayOutputStream() - val writeChannel = new WriteChannel(Channels.newChannel(byteOut)) - MessageSerializer.serialize(writeChannel, schema) - iterator.foreach(writeChannel.write) - ArrowStreamWriter.writeEndOfStream(writeChannel, new IpcOption) - val result = byteOut.toByteArray - writeChannel.close - byteOut.close - result } }