From 9b67bbbc83b37dc3e348d50f3fff4d2a730ae6e6 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Thu, 22 Jan 2026 10:30:20 +0800 Subject: [PATCH 01/20] Add ColumnarIndexShuffleBlockResolver --- .../ColumnarIndexShuffleBlockResolver.scala | 22 ++++++++++++++++ .../shuffle/sort/ColumnarShuffleManager.scala | 2 +- ...lumnarIndexShuffleBlockResolverSuite.scala | 25 +++++++++++++++++++ 3 files changed, 48 insertions(+), 1 deletion(-) create mode 100644 gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala create mode 100644 gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala new file mode 100644 index 000000000000..e5ec39a06003 --- /dev/null +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.shuffle.sort + +import org.apache.spark.SparkConf +import org.apache.spark.shuffle._ + +class ColumnarIndexShuffleBlockResolver(conf: SparkConf) extends IndexShuffleBlockResolver(conf) {} diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala index 904c2dff6ce7..c4dd0db04a1a 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala @@ -39,7 +39,7 @@ class ColumnarShuffleManager(conf: SparkConf) import ColumnarShuffleManager._ private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) - override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) + override val shuffleBlockResolver = new ColumnarIndexShuffleBlockResolver(conf) /** A mapping from shuffle ids to the number of mappers producing output for those shuffles. */ private[this] val taskIdMapsForShuffle = new ConcurrentHashMap[Int, OpenHashSet[Long]]() diff --git a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala new file mode 100644 index 000000000000..556bf7cf5694 --- /dev/null +++ b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.sort +import org.scalatest.BeforeAndAfterAll +import org.scalatest.funsuite.AnyFunSuite + +class ColumnarIndexShuffleBlockResolverSuite extends AnyFunSuite with BeforeAndAfterAll { + test("dummy test") { + assert(1 + 1 == 2) + } +} From 0be5dae78907dafa4735b168ce1201a05d7c6f79 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Thu, 22 Jan 2026 17:28:59 +0800 Subject: [PATCH 02/20] Add DiscontiguousFileRegion --- .../sort/DiscontiguousFileRegion.scala | 128 ++++++++++++++ .../sort/DiscontigousFileRegionSuite.scala | 162 ++++++++++++++++++ 2 files changed, 290 insertions(+) create mode 100644 gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala create mode 100644 gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/DiscontigousFileRegionSuite.scala diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala new file mode 100644 index 000000000000..0c12375974de --- /dev/null +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.sort; + +// org.apache.spark.shuffle.sort has io objects conflicting with Netty's io objects. +import _root_.io.netty.channel.FileRegion +import _root_.io.netty.util.AbstractReferenceCounted + +import java.io.IOException +import java.nio.channels.{FileChannel, WritableByteChannel} + +/** + * A FileRegion that maps a continuous logical stream (0 to N) onto multiple discontiguous physical + * file segments. + */ +class DiscontiguousFileRegion( + private val fileChannel: FileChannel, + private val segments: Seq[(Long, Long)] // (Physical File Offset, Length) +) extends AbstractReferenceCounted + with FileRegion { + + require(segments.nonEmpty, "Must provide at least one segment") + + private val totalCount: Long = segments.map(_._2).sum + private var bytesTransferred: Long = 0L + + /** + * Transfers data starting from the LOGICAL 'position'. + * @param target + * The socket/channel to write to. + * @param position + * The relative position (0 to totalCount) within this specific region where the transfer should + * begin. + */ + override def transferTo(target: WritableByteChannel, position: Long): Long = { + var logicalPos = position + var totalWritten = 0L + + if (logicalPos >= totalCount) { + return 0L + } + + // 1. 
Locate the starting segment + var segmentIndex = 0 + var currentSegmentBase = 0L // Logical start of the current segment + + // Fast-forward to the segment containing 'position' + while (segmentIndex < segments.length) { + val (phyOffset, length) = segments(segmentIndex) + val currentSegmentEnd = currentSegmentBase + length + + if (logicalPos < currentSegmentEnd) { + // FOUND IT: The request starts inside this segment. + + // 2. Calculate offsets + val offsetInSegment = logicalPos - currentSegmentBase + val physicalPos = phyOffset + offsetInSegment + val remainingInSegment = length - offsetInSegment + + // 3. Perform the transfer (Zero-Copy) + val written = fileChannel.transferTo(physicalPos, remainingInSegment, target) + + if (written == 0) { + // Socket buffer full. Return what we have. + return totalWritten + } else if (written == -1) { + throw new IOException("EOF encountered in underlying file") + } + + // 4. Update state + bytesTransferred += written + totalWritten += written + logicalPos += written + + // Optimization: If we finished this segment exactly, loop immediately to the next + // so we don't return to the Netty loop just to call us back for the next byte. + if (written == remainingInSegment) { + currentSegmentBase += length + segmentIndex += 1 + // The loop continues to the next segment... + } else { + // We wrote only part of the segment (socket likely full). Done for now. + return totalWritten + } + + } else { + // Target position is further down. Skip this segment. + currentSegmentBase += length + segmentIndex += 1 + } + } + + totalWritten + } + + override def count(): Long = totalCount + + // This returns the position within the REGION, usually 0 for the start of the region object. 
+ override def position(): Long = 0 + + override def transferred(): Long = bytesTransferred + @deprecated override def transfered(): Long = bytesTransferred + + // --- Reference Counting --- + override def touch(hint: Any): FileRegion = this + override def touch(): FileRegion = this + override def retain(): FileRegion = { super.retain(); this } + override def retain(increment: Int): FileRegion = { super.retain(increment); this } + + override protected def deallocate(): Unit = { + // Own the file descriptor, close it here. + fileChannel.close() + } +} diff --git a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/DiscontigousFileRegionSuite.scala b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/DiscontigousFileRegionSuite.scala new file mode 100644 index 000000000000..6977a6c08ccd --- /dev/null +++ b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/DiscontigousFileRegionSuite.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.shuffle.sort + +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} +import org.scalatest.funsuite.AnyFunSuite + +import java.io.{ByteArrayOutputStream, File, FileOutputStream, RandomAccessFile} +import java.nio.ByteBuffer +import java.nio.channels.{FileChannel, WritableByteChannel} +import java.nio.charset.StandardCharsets + +// The Helper Class (Mock Socket) +class ByteArrayWritableChannel extends WritableByteChannel { + private val out = new ByteArrayOutputStream() + private var open = true + + def toByteArray: Array[Byte] = out.toByteArray + + override def write(src: ByteBuffer): Int = { + val size = src.remaining() + val buf = new Array[Byte](size) + src.get(buf) + out.write(buf) + size + } + + override def isOpen: Boolean = open + override def close(): Unit = { open = false } +} + +class DiscontiguousFileRegionSuite + extends AnyFunSuite + with BeforeAndAfterEach + with BeforeAndAfterAll { + + var tempFile: File = _ + var raf: RandomAccessFile = _ + var fileChannel: FileChannel = _ + + val fileContent = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + + override def beforeAll(): Unit = { + tempFile = File.createTempFile("netty-test", ".dat") + val fos = new FileOutputStream(tempFile) + fos.write(fileContent.getBytes(StandardCharsets.UTF_8)) + fos.close() + } + + override def beforeEach(): Unit = { + raf = new RandomAccessFile(tempFile, "r") + fileChannel = raf.getChannel + } + + override def afterEach(): Unit = { + if (fileChannel != null && fileChannel.isOpen) fileChannel.close() + if (raf != null) raf.close() + } + + override def afterAll(): Unit = { + if (tempFile != null) tempFile.delete() + } + + // --- TESTS --- + test("transfer a single segment correctly") { + // Segment: "BCD" (Offset 1, Length 3) + val region = new DiscontiguousFileRegion(fileChannel, Seq((1L, 3L))) + val target = new ByteArrayWritableChannel() + + val written = region.transferTo(target, 0) + + assert(written == 3, "Should have written exactly 3 bytes") + 
assert(new String(target.toByteArray) == "BCD", "Content should match the first 3 letters") + } + + test("concatenate multiple discontiguous segments") { + // Segment 1: "AB" (0, 2) + // Segment 2: "YZ" (24, 2) + val region = new DiscontiguousFileRegion(fileChannel, Seq((0L, 2L), (24L, 2L))) + val target = new ByteArrayWritableChannel() + + val written = region.transferTo(target, 0) + + assert(written == 4) + assert(new String(target.toByteArray) == "ABYZ") + } + + test("handle the 'position' parameter correctly (Mid-segment start)") { + // Combined View: "ABC" + "XYZ" = "ABCXYZ" + // Indices: 012 345 + val region = new DiscontiguousFileRegion(fileChannel, Seq((0L, 3L), (23L, 3L))) + val target = new ByteArrayWritableChannel() + + // Act: Start from logical position 2 (Letter 'C') + // Expectation: Should read 'C' (from seg1) then "XYZ" (from seg2) + val written = region.transferTo(target, 2) + + assert(written == 4) + assert(new String(target.toByteArray) == "CXYZ") + } + + test("handle starting exactly at the boundary of the second segment") { + // Combined View: "AB" (size 2) + "YZ" (size 2) + val region = new DiscontiguousFileRegion(fileChannel, Seq((0L, 2L), (24L, 2L))) + val target = new ByteArrayWritableChannel() + + // Act: Start from logical position 2 (Start of second segment) + val written = region.transferTo(target, 2) + + assert(written == 2) + assert(new String(target.toByteArray) == "YZ") + } + + test("handle a position that skips multiple initial segments") { + // Segments: "A", "B", "C", "D" + val region = new DiscontiguousFileRegion( + fileChannel, + Seq( + (0L, 1L), + (1L, 1L), + (2L, 1L), + (3L, 1L) + )) + val target = new ByteArrayWritableChannel() + + // Act: Start at logical pos 3 (Should be only "D") + val written = region.transferTo(target, 3) + + assert(written == 1) + assert(new String(target.toByteArray) == "D") + } + + test("return 0 if position is beyond total size") { + val region = new DiscontiguousFileRegion(fileChannel, Seq((0L, 
5L))) + val target = new ByteArrayWritableChannel() + + val written = region.transferTo(target, 100) + + assert(written == 0) + assert(target.toByteArray.length == 0) + } + + test("file channel should closed after release") { + val region = new DiscontiguousFileRegion(fileChannel, Seq((0L, 5L))) + region.release() + assert(!fileChannel.isOpen) + } +} From 74e9692e2574c4bbef8245d0d31a5b9731cb930f Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Fri, 23 Jan 2026 09:30:43 +0800 Subject: [PATCH 03/20] Add FileSegmentsBuffer#createInputStream --- .../shuffle/sort/FileSegmentsBuffer.scala | 95 +++++++++++++++++ .../sort/FileSegmentsManagedBufferSuite.scala | 100 ++++++++++++++++++ 2 files changed, 195 insertions(+) create mode 100644 gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala create mode 100644 gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala new file mode 100644 index 000000000000..0b94259ede31 --- /dev/null +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.sort + +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.network.util.LimitedInputStream +import org.apache.spark.network.util.TransportConf + +import java.io.{BufferedInputStream, EOFException, File, FileInputStream, InputStream, IOException, SequenceInputStream} +import java.nio.ByteBuffer +import java.util.Vector + +/** A {@link ManagedBuffer} backed by a set of segments in a file. */ +class FileSegmentsManagedBuffer( + conf: TransportConf, + file: File, + segments: Seq[(Long, Long)] +) extends ManagedBuffer { + private val totalCount: Long = segments.map(_._2).sum + + override def size(): Long = totalCount + + override def retain(): ManagedBuffer = this + override def release(): ManagedBuffer = this + + override def nioByteBuffer(): ByteBuffer = { + // Implement logic to return a ByteBuffer representing the file segments + throw new UnsupportedOperationException("Not implemented yet") + } + + @throws[IOException] + private def skipFully(in: InputStream, n: Long): Unit = { + var remaining = n + val toSkip = n + while (remaining > 0L) { + val amt = in.skip(remaining) + if (amt == 0L) { + // Try reading a byte to see if we are at EOF + if (in.read() == -1) { + val skipped = toSkip - remaining + throw new EOFException( + s"reached end of stream after skipping $skipped bytes; $toSkip bytes expected") + } + remaining -= 1 + } else { + remaining -= amt + } + } + } + + @throws[IOException] + override def createInputStream(): InputStream = { + 
val streams = new Vector[InputStream]() + val filesToClose = new Vector[FileInputStream]() + var shouldCloseFile = true + try { + segments.foreach { + case (offset, length) => + val fis = new FileInputStream(file) + filesToClose.add(fis) + skipFully(fis, offset) + streams.add(new BufferedInputStream(new LimitedInputStream(fis, length))) + } + shouldCloseFile = false + new SequenceInputStream(streams.elements()) + } finally { + if (shouldCloseFile) { + val it = filesToClose.elements() + while (it.hasMoreElements) { + JavaUtils.closeQuietly(it.nextElement()) + } + } + } + } + + override def convertToNetty(): AnyRef = { + // Implement logic to convert to Netty buffer if needed + throw new UnsupportedOperationException("Not implemented yet") + } +} diff --git a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala new file mode 100644 index 000000000000..7b941f42980d --- /dev/null +++ b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.shuffle.sort + +import org.apache.spark.network.util.TransportConf + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.funsuite.AnyFunSuite + +import java.io._ +import java.nio.file.Files + +class FileSegmentsManagedBufferSuite extends AnyFunSuite with BeforeAndAfterAll { + + private var tempFile: File = _ + private val fileData: Array[Byte] = (0 until 100).map(_.toByte).toArray + + override def beforeAll(): Unit = { + tempFile = Files.createTempFile("fsegments-test", ".bin").toFile + val fos = new FileOutputStream(tempFile) + fos.write(fileData) + fos.close() + } + + override def afterAll(): Unit = { + if (tempFile != null && tempFile.exists()) tempFile.delete() + } + + test("size returns sum of segment lengths") { + val conf = null.asInstanceOf[TransportConf] + val segments = Seq((2L, 10L), (20L, 5L), (50L, 15L)) + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + assert(buf.size() == 10L + 5L + 15L) + } + + test("createInputStream reads single segment correctly") { + val conf = null.asInstanceOf[TransportConf] // Not used in this test + val segments = Seq((10L, 5L)) + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + val in = buf.createInputStream() + val read = new Array[Byte](5) + assert(in.read(read) == 5) + assert(read.sameElements(fileData.slice(10, 15))) + assert(in.read() == -1) + in.close() + } + + test("createInputStream reads multiple segments in sequence") { + val conf = null.asInstanceOf[TransportConf] + val segments = Seq((5L, 3L), (20L, 2L), (40L, 4L)) + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + val in = buf.createInputStream() + val out = new Array[Byte](9) + var total = 0 + var n = 0 + while (n > 0 || total == 0) { + n = in.read(out, total, out.length - total) + if (n > 0) total += n + else if (total == 0) n = 1 // ensure loop runs at least once + } + assert(total == 9) + assert(out.slice(0, 3).sameElements(fileData.slice(5, 8))) + 
assert(out.slice(3, 5).sameElements(fileData.slice(20, 22))) + assert(out.slice(5, 9).sameElements(fileData.slice(40, 44))) + assert(in.read() == -1) + in.close() + } + + test("createInputStream read nothing if segment exceeds file length") { + val conf = null.asInstanceOf[TransportConf] + val segments = Seq((105L, 10L)) // goes past EOF + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + // should raise EOFException or read nothing + try { + val in = buf.createInputStream() + val read = new Array[Byte](10) + assert(in.read(read) == -1) + in.close() + } catch { + case e: EOFException => + assert(e.getMessage.contains("reached end of stream")) + case t: Throwable => + fail(s"Unexpected exception: $t") + } + } +} From 15f2226e8a05018a5a8d916007248c5ac83ca796 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Fri, 23 Jan 2026 10:52:07 +0800 Subject: [PATCH 04/20] Impl FileSegmentsBuffer#nioByteBuffer --- .../shuffle/sort/FileSegmentsBuffer.scala | 28 +++++++++++++-- .../sort/FileSegmentsManagedBufferSuite.scala | 34 +++++++++++++++++++ 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala index 0b94259ede31..81feda8a5aa5 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala @@ -21,7 +21,7 @@ import org.apache.spark.network.util.JavaUtils import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.network.util.TransportConf -import java.io.{BufferedInputStream, EOFException, File, FileInputStream, InputStream, IOException, SequenceInputStream} +import java.io.{BufferedInputStream, EOFException, File, FileInputStream, InputStream, IOException, RandomAccessFile, SequenceInputStream} import java.nio.ByteBuffer import 
java.util.Vector @@ -39,8 +39,30 @@ class FileSegmentsManagedBuffer( override def release(): ManagedBuffer = this override def nioByteBuffer(): ByteBuffer = { - // Implement logic to return a ByteBuffer representing the file segments - throw new UnsupportedOperationException("Not implemented yet") + val buffer = ByteBuffer.allocate(size().toInt) + val channel = new RandomAccessFile(file, "r").getChannel + try { + var destPos = 0 + segments.foreach { + case (offset, length) => + buffer.position(destPos) + buffer.limit(destPos + length.toInt) + channel.position(offset) + var remaining = length + while (remaining > 0) { + val n = channel.read(buffer) + if (n == -1) { + throw new EOFException(s"EOF reached while reading segment at offset $offset") + } + remaining -= n + } + destPos += length.toInt + } + buffer.flip() + buffer + } finally { + channel.close() + } } @throws[IOException] diff --git a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala index 7b941f42980d..21be9281055e 100644 --- a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala +++ b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala @@ -47,6 +47,40 @@ class FileSegmentsManagedBufferSuite extends AnyFunSuite with BeforeAndAfterAll assert(buf.size() == 10L + 5L + 15L) } + test("nioByteBuffer reads single segment correctly") { + val conf = null.asInstanceOf[TransportConf] + val segments = Seq((10L, 5L)) + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + val bb = buf.nioByteBuffer() + val arr = new Array[Byte](5) + bb.get(arr) + assert(arr.sameElements(fileData.slice(10, 15))) + assert(!bb.hasRemaining) + } + + test("nioByteBuffer reads multiple segments in sequence") { + val conf = null.asInstanceOf[TransportConf] + val segments = Seq((5L, 3L), 
(20L, 2L), (40L, 4L)) + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + val bb = buf.nioByteBuffer() + val arr = new Array[Byte](9) + bb.get(arr) + assert(arr.slice(0, 3).sameElements(fileData.slice(5, 8))) + assert(arr.slice(3, 5).sameElements(fileData.slice(20, 22))) + assert(arr.slice(5, 9).sameElements(fileData.slice(40, 44))) + assert(!bb.hasRemaining) + } + + test("nioByteBuffer throws EOFException if segment exceeds file length") { + val conf = null.asInstanceOf[TransportConf] + val segments = Seq((95L, 10L)) // goes past EOF + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + val thrown = intercept[EOFException] { + buf.nioByteBuffer() + } + assert(thrown.getMessage.contains("EOF reached while reading segment")) + } + test("createInputStream reads single segment correctly") { val conf = null.asInstanceOf[TransportConf] // Not used in this test val segments = Seq((10L, 5L)) From 54d9dd18a10106811f17c98adf9726210a74aace Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Fri, 23 Jan 2026 13:09:54 +0800 Subject: [PATCH 05/20] Add mmap opt for single large file --- .../shuffle/sort/FileSegmentsBuffer.scala | 8 ++++++- .../sort/FileSegmentsManagedBufferSuite.scala | 23 +++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala index 81feda8a5aa5..3107287fe5eb 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala @@ -23,6 +23,7 @@ import org.apache.spark.network.util.TransportConf import java.io.{BufferedInputStream, EOFException, File, FileInputStream, InputStream, IOException, RandomAccessFile, SequenceInputStream} import java.nio.ByteBuffer +import java.nio.channels.FileChannel import 
java.util.Vector /** A {@link ManagedBuffer} backed by a set of segments in a file. */ @@ -42,6 +43,11 @@ class FileSegmentsManagedBuffer( val buffer = ByteBuffer.allocate(size().toInt) val channel = new RandomAccessFile(file, "r").getChannel try { + if (conf != null && size() >= conf.memoryMapBytes() && segments.length == 1) { + // zero-copy for single segment that is large enough + val (offset, length) = segments.head + return channel.map(FileChannel.MapMode.READ_ONLY, offset, length) + } var destPos = 0 segments.foreach { case (offset, length) => @@ -61,7 +67,7 @@ class FileSegmentsManagedBuffer( buffer.flip() buffer } finally { - channel.close() + JavaUtils.closeQuietly(channel) } } diff --git a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala index 21be9281055e..e20be7c8fe81 100644 --- a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala +++ b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala @@ -24,6 +24,10 @@ import org.scalatest.funsuite.AnyFunSuite import java.io._ import java.nio.file.Files +class FakeTransportConf extends TransportConf("test", null) { + override def memoryMapBytes(): Int = 4096 // 4 KB +} + class FileSegmentsManagedBufferSuite extends AnyFunSuite with BeforeAndAfterAll { private var tempFile: File = _ @@ -81,6 +85,25 @@ class FileSegmentsManagedBufferSuite extends AnyFunSuite with BeforeAndAfterAll assert(thrown.getMessage.contains("EOF reached while reading segment")) } + test("nioByteBuffer with mmap for a single segment") { + // Create a large file (16 KB) + val largeData = (0 until 16384).map(_.toByte).toArray + val largeFile = Files.createTempFile("fsegments-mmap-test", ".bin").toFile + val fos = new FileOutputStream(largeFile) + fos.write(largeData) + fos.close() + + val conf = new 
FakeTransportConf() + val segment = (0L, 16384L) + val buf = new FileSegmentsManagedBuffer(conf, largeFile, Seq(segment)) + val nioBuf = buf.nioByteBuffer() + val arr = new Array[Byte](16384) + nioBuf.get(arr) + assert(arr.sameElements(largeData)) + + largeFile.delete() + } + test("createInputStream reads single segment correctly") { val conf = null.asInstanceOf[TransportConf] // Not used in this test val segments = Seq((10L, 5L)) From c072f0e3b8a016bf1bded5e23e4639e84cea1db0 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Fri, 23 Jan 2026 14:17:36 +0800 Subject: [PATCH 06/20] DiscontiguousFileRegion support lazy open --- .../sort/DiscontiguousFileRegion.scala | 31 ++++++++-- .../sort/DiscontigousFileRegionSuite.scala | 59 +++++++++---------- 2 files changed, 55 insertions(+), 35 deletions(-) diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala index 0c12375974de..2dbc3e85c3c5 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala @@ -20,16 +20,19 @@ package org.apache.spark.shuffle.sort; import _root_.io.netty.channel.FileRegion import _root_.io.netty.util.AbstractReferenceCounted +import java.io.File import java.io.IOException import java.nio.channels.{FileChannel, WritableByteChannel} +import java.nio.file.StandardOpenOption /** * A FileRegion that maps a continuous logical stream (0 to N) onto multiple discontiguous physical * file segments. 
*/ class DiscontiguousFileRegion( - private val fileChannel: FileChannel, - private val segments: Seq[(Long, Long)] // (Physical File Offset, Length) + private val file: File, + private val segments: Seq[(Long, Long)], // (Physical File Offset, Length) + private val lazyOpen: Boolean = false // If true, delay opening the file until first use ) extends AbstractReferenceCounted with FileRegion { @@ -37,6 +40,21 @@ class DiscontiguousFileRegion( private val totalCount: Long = segments.map(_._2).sum private var bytesTransferred: Long = 0L + private var fileChannel: FileChannel = null + private var closed: Boolean = false + + if (!lazyOpen) { + ensureOpen() + } + + private def ensureOpen(): Unit = { + if (closed) { + throw new IOException("File is already closed") + } + if (fileChannel == null) { + fileChannel = FileChannel.open(file.toPath, StandardOpenOption.READ) + } + } /** * Transfers data starting from the LOGICAL 'position'. @@ -50,6 +68,8 @@ class DiscontiguousFileRegion( var logicalPos = position var totalWritten = 0L + ensureOpen() + if (logicalPos >= totalCount) { return 0L } @@ -122,7 +142,10 @@ class DiscontiguousFileRegion( override def retain(increment: Int): FileRegion = { super.retain(increment); this } override protected def deallocate(): Unit = { - // Own the file descriptor, close it here. 
- fileChannel.close() + if (fileChannel != null && !closed) { + fileChannel.close() + fileChannel = null + closed = true + } } } diff --git a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/DiscontigousFileRegionSuite.scala b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/DiscontigousFileRegionSuite.scala index 6977a6c08ccd..8adf9146b35b 100644 --- a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/DiscontigousFileRegionSuite.scala +++ b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/DiscontigousFileRegionSuite.scala @@ -16,12 +16,12 @@ */ package org.apache.spark.shuffle.sort -import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} +import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuite -import java.io.{ByteArrayOutputStream, File, FileOutputStream, RandomAccessFile} +import java.io.{ByteArrayOutputStream, File, FileOutputStream} import java.nio.ByteBuffer -import java.nio.channels.{FileChannel, WritableByteChannel} +import java.nio.channels.WritableByteChannel import java.nio.charset.StandardCharsets // The Helper Class (Mock Socket) @@ -43,14 +43,19 @@ class ByteArrayWritableChannel extends WritableByteChannel { override def close(): Unit = { open = false } } -class DiscontiguousFileRegionSuite - extends AnyFunSuite - with BeforeAndAfterEach - with BeforeAndAfterAll { +class DiscontiguousFileRegionSuite extends AnyFunSuite with BeforeAndAfterAll { + + test("transferTo throws after release is called") { + val region = new DiscontiguousFileRegion(tempFile, Seq((0L, 3L)), lazyOpen = true) + region.release() + val target = new ByteArrayWritableChannel() + val thrown = intercept[Exception] { + region.transferTo(target, 0) + } + assert(thrown.getMessage.toLowerCase.contains("closed") || thrown.isInstanceOf[IllegalStateException]) + } var tempFile: File = _ - var raf: RandomAccessFile = _ - var fileChannel: FileChannel = _ val fileContent = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" @@ -61,16 
+66,6 @@ class DiscontiguousFileRegionSuite fos.close() } - override def beforeEach(): Unit = { - raf = new RandomAccessFile(tempFile, "r") - fileChannel = raf.getChannel - } - - override def afterEach(): Unit = { - if (fileChannel != null && fileChannel.isOpen) fileChannel.close() - if (raf != null) raf.close() - } - override def afterAll(): Unit = { if (tempFile != null) tempFile.delete() } @@ -78,7 +73,7 @@ class DiscontiguousFileRegionSuite // --- TESTS --- test("transfer a single segment correctly") { // Segment: "BCD" (Offset 1, Length 3) - val region = new DiscontiguousFileRegion(fileChannel, Seq((1L, 3L))) + val region = new DiscontiguousFileRegion(tempFile, Seq((1L, 3L))) val target = new ByteArrayWritableChannel() val written = region.transferTo(target, 0) @@ -87,10 +82,18 @@ class DiscontiguousFileRegionSuite assert(new String(target.toByteArray) == "BCD", "Content should match the first 3 letters") } + test("transferTo works with lazyOpen=true") { + val region = new DiscontiguousFileRegion(tempFile, Seq((2L, 4L)), lazyOpen = true) + val target = new ByteArrayWritableChannel() + val written = region.transferTo(target, 0) + assert(written == 4) + assert(new String(target.toByteArray) == "CDEF") + } + test("concatenate multiple discontiguous segments") { // Segment 1: "AB" (0, 2) // Segment 2: "YZ" (24, 2) - val region = new DiscontiguousFileRegion(fileChannel, Seq((0L, 2L), (24L, 2L))) + val region = new DiscontiguousFileRegion(tempFile, Seq((0L, 2L), (24L, 2L))) val target = new ByteArrayWritableChannel() val written = region.transferTo(target, 0) @@ -102,7 +105,7 @@ class DiscontiguousFileRegionSuite test("handle the 'position' parameter correctly (Mid-segment start)") { // Combined View: "ABC" + "XYZ" = "ABCXYZ" // Indices: 012 345 - val region = new DiscontiguousFileRegion(fileChannel, Seq((0L, 3L), (23L, 3L))) + val region = new DiscontiguousFileRegion(tempFile, Seq((0L, 3L), (23L, 3L))) val target = new ByteArrayWritableChannel() // Act: Start from 
logical position 2 (Letter 'C') @@ -115,7 +118,7 @@ class DiscontiguousFileRegionSuite test("handle starting exactly at the boundary of the second segment") { // Combined View: "AB" (size 2) + "YZ" (size 2) - val region = new DiscontiguousFileRegion(fileChannel, Seq((0L, 2L), (24L, 2L))) + val region = new DiscontiguousFileRegion(tempFile, Seq((0L, 2L), (24L, 2L))) val target = new ByteArrayWritableChannel() // Act: Start from logical position 2 (Start of second segment) @@ -128,7 +131,7 @@ class DiscontiguousFileRegionSuite test("handle a position that skips multiple initial segments") { // Segments: "A", "B", "C", "D" val region = new DiscontiguousFileRegion( - fileChannel, + tempFile, Seq( (0L, 1L), (1L, 1L), @@ -145,7 +148,7 @@ class DiscontiguousFileRegionSuite } test("return 0 if position is beyond total size") { - val region = new DiscontiguousFileRegion(fileChannel, Seq((0L, 5L))) + val region = new DiscontiguousFileRegion(tempFile, Seq((0L, 5L))) val target = new ByteArrayWritableChannel() val written = region.transferTo(target, 100) @@ -153,10 +156,4 @@ class DiscontiguousFileRegionSuite assert(written == 0) assert(target.toByteArray.length == 0) } - - test("file channel should closed after release") { - val region = new DiscontiguousFileRegion(fileChannel, Seq((0L, 5L))) - region.release() - assert(!fileChannel.isOpen) - } } From d534477bbc5996cc4aa89969108687af5e1f4240 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Fri, 23 Jan 2026 15:21:40 +0800 Subject: [PATCH 07/20] Impl FileSegmentsManagedBuffer#convertToNetty --- .../shuffle/sort/FileSegmentsBuffer.scala | 5 +- .../sort/DiscontigousFileRegionSuite.scala | 11 ---- .../sort/FileSegmentsManagedBufferSuite.scala | 51 ++++++++++++++++++- 3 files changed, 53 insertions(+), 14 deletions(-) diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala index 
3107287fe5eb..6a0f72cf9cfa 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala @@ -117,7 +117,8 @@ class FileSegmentsManagedBuffer( } override def convertToNetty(): AnyRef = { - // Implement logic to convert to Netty buffer if needed - throw new UnsupportedOperationException("Not implemented yet") + var lazyOpen = false + if (conf != null) lazyOpen = conf.lazyFileDescriptor() + new DiscontiguousFileRegion(file, segments, lazyOpen) } } diff --git a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/DiscontigousFileRegionSuite.scala b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/DiscontigousFileRegionSuite.scala index 8adf9146b35b..bfac0518ccb5 100644 --- a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/DiscontigousFileRegionSuite.scala +++ b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/DiscontigousFileRegionSuite.scala @@ -44,17 +44,6 @@ class ByteArrayWritableChannel extends WritableByteChannel { } class DiscontiguousFileRegionSuite extends AnyFunSuite with BeforeAndAfterAll { - - test("transferTo throws after release is called") { - val region = new DiscontiguousFileRegion(tempFile, Seq((0L, 3L)), lazyOpen = true) - region.release() - val target = new ByteArrayWritableChannel() - val thrown = intercept[Exception] { - region.transferTo(target, 0) - } - assert(thrown.getMessage.toLowerCase.contains("closed") || thrown.isInstanceOf[IllegalStateException]) - } - var tempFile: File = _ val fileContent = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" diff --git a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala index e20be7c8fe81..b171b4712bd5 100644 --- 
a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala +++ b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala @@ -18,17 +18,66 @@ package org.apache.spark.shuffle.sort import org.apache.spark.network.util.TransportConf +import _root_.io.netty.channel.FileRegion import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuite import java.io._ +import java.nio.channels.WritableByteChannel import java.nio.file.Files -class FakeTransportConf extends TransportConf("test", null) { +class FakeTransportConf(private val lazyOpen: Boolean = false) extends TransportConf("test", null) { override def memoryMapBytes(): Int = 4096 // 4 KB + override def lazyFileDescriptor(): Boolean = lazyOpen } class FileSegmentsManagedBufferSuite extends AnyFunSuite with BeforeAndAfterAll { + private class ByteArrayWritableChannel extends WritableByteChannel { + private val out = new ByteArrayOutputStream() + private var open = true + + def toByteArray: Array[Byte] = out.toByteArray + + override def write(src: java.nio.ByteBuffer): Int = { + val size = src.remaining() + val buf = new Array[Byte](size) + src.get(buf) + out.write(buf) + size + } + + override def isOpen: Boolean = open + override def close(): Unit = { open = false } + } + + test("convertToNetty returns FileRegion with correct count and content") { + val conf = new FakeTransportConf() + val segments = Seq((2L, 3L), (10L, 4L)) + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + val nettyObj = buf.convertToNetty() + assert(nettyObj.isInstanceOf[FileRegion]) + val region = nettyObj.asInstanceOf[FileRegion] + assert(region.count() == 7L) + assert(region.position() == 0L) + val target = new ByteArrayWritableChannel() + val written = region.transferTo(target, 0) + assert(written == 7L) + val expected = fileData.slice(2, 5) ++ fileData.slice(10, 14) + assert(target.toByteArray.sameElements(expected)) + } + + 
test("convertToNetty supports lazyOpen via FileRegion transfer") { + val conf = new FakeTransportConf(lazyOpen = true) + val segments = Seq((5L, 3L)) + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + val nettyObj = buf.convertToNetty() + assert(nettyObj.isInstanceOf[FileRegion]) + val region = nettyObj.asInstanceOf[FileRegion] + val target = new ByteArrayWritableChannel() + val written = region.transferTo(target, 0) + assert(written == 3L) + assert(target.toByteArray.sameElements(fileData.slice(5, 8))) + } private var tempFile: File = _ private val fileData: Array[Byte] = (0 until 100).map(_.toByte).toArray From f96ca3eb5443e80e4e89fe7d7ff82dd3dc25cb10 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Fri, 23 Jan 2026 19:16:05 +0800 Subject: [PATCH 08/20] Add ColumnarIndexShuffleBlockResolver#getSegmentsFromIndex --- .../ColumnarIndexShuffleBlockResolver.scala | 48 +++++++++- ...lumnarIndexShuffleBlockResolverSuite.scala | 90 ++++++++++++++++++- 2 files changed, 133 insertions(+), 5 deletions(-) diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala index e5ec39a06003..f9b28373df4d 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala @@ -19,4 +19,50 @@ package org.apache.spark.shuffle.sort import org.apache.spark.SparkConf import org.apache.spark.shuffle._ -class ColumnarIndexShuffleBlockResolver(conf: SparkConf) extends IndexShuffleBlockResolver(conf) {} +import java.io.DataInputStream +import java.io.File +import java.nio.channels.Channels +import java.nio.channels.SeekableByteChannel + +class ColumnarIndexShuffleBlockResolver(conf: SparkConf) extends IndexShuffleBlockResolver(conf) { + private def 
isNewFormat(index: File): Boolean = { + // Simple heuristic to determine new format by appending 1 extra bytes. + // Old index file always has length multiple of 8 bytes. + if (index.length() % 8 == 1) true else false + } + + private def getSegmentsFromIndex( + channel: SeekableByteChannel, + startId: Int, + endId: Int): Seq[(Long, Long)] = { + // New Index Format: + // To support a partition index with multiple segments, + // the index file is composed of three parts: + // 1) Partition Index: index_0, index_1, ... index_N + // Each index_i is an 8-byte long integer representing the byte offset in the file + // For partition i, its segments are stored between index_i and index_(i+1). + // 2) Segment Data: [offset_0][size_0][offset_1][size_1]...[offset_N][size_N] + // Each segment is represented by two 8-byte long integers: offset and size. + // offset is the byte offset in the data file, size is the length in bytes of this segment. + // All segments for all partitions are stored sequentially after the partition index. + // 3) One extra byte at the end to distinguish from old format. 
+ channel.position(startId * 8L) + val in = new DataInputStream(Channels.newInputStream(channel)) + var startOffset = in.readLong() + channel.position(endId * 8L) + val endOffset = in.readLong() + if (endOffset < startOffset || (endOffset - startOffset) % 16 != 0) { + throw new IllegalStateException( + s"Index file: Invalid index to segments ($startOffset, $endOffset)") + } + val segmentCount = (endOffset - startOffset) / 16 + // Read segments + channel.position(startOffset) + val segments = for (i <- 0 until segmentCount.toInt) yield { + val offset = in.readLong() + val size = in.readLong() + (offset, size) + } + segments.filter(_._2 > 0) // filter out zero-size segments + } +} diff --git a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala index 556bf7cf5694..c6a8fcfad8cb 100644 --- a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala +++ b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala @@ -15,11 +15,93 @@ * limitations under the License. 
*/ package org.apache.spark.shuffle.sort -import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkConf + +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.scalatest.funsuite.AnyFunSuite -class ColumnarIndexShuffleBlockResolverSuite extends AnyFunSuite with BeforeAndAfterAll { - test("dummy test") { - assert(1 + 1 == 2) +import java.io.File +import java.nio.ByteBuffer +import java.nio.channels.FileChannel +import java.nio.file.StandardOpenOption + +class ColumnarIndexShuffleBlockResolverSuite + extends AnyFunSuite + with BeforeAndAfterAll + with BeforeAndAfterEach { + + private var tmpFile: File = _ + private var channel: FileChannel = _ + + override def beforeAll(): Unit = { + tmpFile = File.createTempFile("index", ".bin") + val fc = FileChannel.open(tmpFile.toPath, StandardOpenOption.WRITE, StandardOpenOption.READ) + // Partition index: 4 partitions, so 5 offsets + // Partition 0: 0 segments + // Partition 1: 1 segment + // Partition 2: 2 segments + // Partition 3: 3 segments + // Offsets: [indexEnd, seg1Start, seg2Start, seg3Start, segEnd] + // Let's say index is 5*8 = 40 bytes, segments start at 40 + val segOffsets = Array(40L, 40L, 56L, 88L, 136L) + val indexBuf = ByteBuffer.allocate(5 * 8) + segOffsets.foreach(indexBuf.putLong) + indexBuf.flip() + fc.write(indexBuf) + // Segment data: total 6 segments (1+2+3) + val segBuf = ByteBuffer.allocate(6 * 16) + // Define a data file with 200 bytes. 
+ // Partition 1: + segBuf.putLong(0L); segBuf.putLong(10L) + // Partition 2: + segBuf.putLong(10L); segBuf.putLong(20L) + segBuf.putLong(100L); segBuf.putLong(50L) + // Partition 3: + segBuf.putLong(30L); segBuf.putLong(70L) + segBuf.putLong(150L); segBuf.putLong(50L) + segBuf.putLong(200L); segBuf.putLong(0L) + segBuf.flip() + fc.write(segBuf) + // Add one extra byte to mark new format + fc.write(ByteBuffer.wrap(Array(1.toByte))) + fc.close() + } + + override def beforeEach(): Unit = { + channel = FileChannel.open(tmpFile.toPath, StandardOpenOption.READ) + } + + override def afterEach(): Unit = { + if (channel != null) { + channel.close() + channel = null + } + } + + override def afterAll(): Unit = { + if (tmpFile != null) tmpFile.delete() + } + + test("getSegmentsFromIndex returns correct segments for complex index file") { + val resolver = new ColumnarIndexShuffleBlockResolver(new SparkConf()) + val method = classOf[ColumnarIndexShuffleBlockResolver].getDeclaredMethod( + "getSegmentsFromIndex", + classOf[java.nio.channels.SeekableByteChannel], + classOf[Int], + classOf[Int]) + method.setAccessible(true) + // Partition 0: 0 segments + val segs0 = + method.invoke(resolver, channel, Int.box(0), Int.box(1)).asInstanceOf[Seq[(Long, Long)]] + assert(segs0.isEmpty) + // Partition 1 - 2: 3 segment + val segs1 = + method.invoke(resolver, channel, Int.box(1), Int.box(3)).asInstanceOf[Seq[(Long, Long)]] + assert(segs1 == Seq((0L, 10L), (10L, 20L), (100L, 50L))) + // Partition 3: 2 segments, empty segment will be filter out + val segs3 = + method.invoke(resolver, channel, Int.box(3), Int.box(4)).asInstanceOf[Seq[(Long, Long)]] + assert(segs3 == Seq((30L, 70L), (150L, 50L))) } } From 4d1a015af58ace9d14bb8ef633cbd440966ec27a Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Mon, 26 Jan 2026 19:17:30 +0800 Subject: [PATCH 09/20] Impl ColumnarIndexShuffleBlockResolver#getBlockData --- .../ColumnarIndexShuffleBlockResolver.scala | 89 +++++++++++- 
...lumnarIndexShuffleBlockResolverSuite.scala | 135 ++++++++++++++---- 2 files changed, 194 insertions(+), 30 deletions(-) diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala index f9b28373df4d..0d7e54237460 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala @@ -17,14 +17,26 @@ package org.apache.spark.shuffle.sort import org.apache.spark.SparkConf +import org.apache.spark.SparkException +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.shuffle._ +import org.apache.spark.storage._ import java.io.DataInputStream import java.io.File import java.nio.channels.Channels +import java.nio.channels.FileChannel import java.nio.channels.SeekableByteChannel +import java.nio.file.StandardOpenOption + +class ColumnarIndexShuffleBlockResolver( + conf: SparkConf, + private var blockManager: BlockManager = null) + extends IndexShuffleBlockResolver(conf, blockManager) { + + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") -class ColumnarIndexShuffleBlockResolver(conf: SparkConf) extends IndexShuffleBlockResolver(conf) { private def isNewFormat(index: File): Boolean = { // Simple heuristic to determine new format by appending 1 extra bytes. // Old index file always has length multiple of 8 bytes. 
@@ -65,4 +77,79 @@ class ColumnarIndexShuffleBlockResolver(conf: SparkConf) extends IndexShuffleBlo } segments.filter(_._2 > 0) // filter out zero-size segments } + + private def checkIndexAndDataFile(indexFile: File, dataFile: File, numPartitions: Int): Unit = { + if (!indexFile.exists()) { + throw new IllegalStateException(s"Index file $indexFile does not exist") + } + if (!dataFile.exists()) { + throw new IllegalStateException(s"Data file $dataFile does not exist") + } + if (!isNewFormat(indexFile)) { + throw new IllegalStateException(s"Index file $indexFile is not in the new format") + } + + var index = FileChannel.open(indexFile.toPath, StandardOpenOption.READ) + var dataFileSize = dataFile.length() + try { + for (i <- 0 until numPartitions) { + var segments = getSegmentsFromIndex(index, startId = i, endId = i + 1) + segments.foreach { + case (offset, size) => + if (offset < 0 || size < 0 || offset + size > dataFileSize) { + throw new IllegalStateException( + s"Index file $indexFile has invalid segment ($offset, $size) " + + s"for partition $i in data file of size $dataFileSize") + } + } + } + } finally { + index.close() + } + } + + def writeIndexFileAndCommit( + shuffleId: Int, + mapId: Long, + dataTmp: File, + indexTmp: File, + numPartitions: Int, + checksums: Option[Array[Long]] = None + ): Unit = { + val indexFile = getIndexFile(shuffleId, mapId) + val dataFile = getDataFile(shuffleId, mapId) + this.synchronized { + checkIndexAndDataFile(indexTmp, dataTmp, numPartitions) + dataTmp.renameTo(dataFile) + indexTmp.renameTo(indexFile) + } + } + + override def getBlockData(blockId: BlockId, dirs: Option[Array[String]]): ManagedBuffer = { + val (shuffleId, mapId, startReduceId, endReduceId) = blockId match { + case id: ShuffleBlockId => + (id.shuffleId, id.mapId, id.reduceId, id.reduceId + 1) + case batchId: ShuffleBlockBatchId => + (batchId.shuffleId, batchId.mapId, batchId.startReduceId, batchId.endReduceId) + case _ => + throw SparkException.internalError( + 
s"unexpected shuffle block id format: $blockId", + category = "SHUFFLE") + } + val indexFile = getIndexFile(shuffleId, mapId, dirs) + + if (!isNewFormat(indexFile)) { + // fallback to old implementation + return super.getBlockData(blockId, dirs) + } + + var index = FileChannel.open(indexFile.toPath, StandardOpenOption.READ) + try { + val segments = getSegmentsFromIndex(index, startReduceId, endReduceId) + val dataFile = getDataFile(shuffleId, mapId, dirs) + new FileSegmentsManagedBuffer(transportConf, dataFile, segments) + } finally { + index.close() + } + } } diff --git a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala index c6a8fcfad8cb..15b67ae80256 100644 --- a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala +++ b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala @@ -17,7 +17,15 @@ package org.apache.spark.shuffle.sort import org.apache.spark.SparkConf +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.storage._ +import org.apache.spark.util.Utils +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.scalatest.funsuite.AnyFunSuite @@ -31,12 +39,20 @@ class ColumnarIndexShuffleBlockResolverSuite with BeforeAndAfterAll with BeforeAndAfterEach { - private var tmpFile: File = _ - private var channel: FileChannel = _ + @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _ + @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _ + + private val conf: SparkConf = new 
SparkConf(loadDefaults = false) + private var tempDir: File = _ + private val appId = "TESTAPP" + + private var indexFile: File = _ + private var dataFile: File = _ + private var resolver: ColumnarIndexShuffleBlockResolver = _ override def beforeAll(): Unit = { - tmpFile = File.createTempFile("index", ".bin") - val fc = FileChannel.open(tmpFile.toPath, StandardOpenOption.WRITE, StandardOpenOption.READ) + indexFile = File.createTempFile("index", ".bin") + val fc = FileChannel.open(indexFile.toPath, StandardOpenOption.WRITE, StandardOpenOption.READ) // Partition index: 4 partitions, so 5 offsets // Partition 0: 0 segments // Partition 1: 1 segment @@ -66,42 +82,103 @@ class ColumnarIndexShuffleBlockResolverSuite // Add one extra byte to mark new format fc.write(ByteBuffer.wrap(Array(1.toByte))) fc.close() + + // data file with 200 bytes, 0-199 + dataFile = File.createTempFile("data", ".bin") + val out = new java.io.FileOutputStream(dataFile) + out.write((0 until 200).map(_.toByte).toArray) + out.close() } override def beforeEach(): Unit = { - channel = FileChannel.open(tmpFile.toPath, StandardOpenOption.READ) + tempDir = Utils.createTempDir() + MockitoAnnotations.initMocks(this) + + when(blockManager.diskBlockManager).thenReturn(diskBlockManager) + when(diskBlockManager.getFile(any[BlockId])).thenAnswer( + (invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString)) + when(diskBlockManager.getFile(any[String])).thenAnswer( + (invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString)) + when(diskBlockManager.getMergedShuffleFile(any[BlockId], any[Option[Array[String]]])) + .thenAnswer( + (invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString)) + when(diskBlockManager.localDirs).thenReturn(Array(tempDir)) + when(diskBlockManager.createTempFileWith(any(classOf[File]))) + .thenAnswer { + invocationOnMock => + val file = invocationOnMock.getArguments()(0).asInstanceOf[File] + 
Utils.tempFileWith(file) + } + conf.set("spark.app.id", appId) + + resolver = new ColumnarIndexShuffleBlockResolver(conf, blockManager) } override def afterEach(): Unit = { - if (channel != null) { - channel.close() - channel = null - } + resolver = null + Utils.deleteRecursively(tempDir) } override def afterAll(): Unit = { - if (tmpFile != null) tmpFile.delete() + if (indexFile != null) indexFile.delete() + if (dataFile != null) dataFile.delete() } test("getSegmentsFromIndex returns correct segments for complex index file") { - val resolver = new ColumnarIndexShuffleBlockResolver(new SparkConf()) - val method = classOf[ColumnarIndexShuffleBlockResolver].getDeclaredMethod( - "getSegmentsFromIndex", - classOf[java.nio.channels.SeekableByteChannel], - classOf[Int], - classOf[Int]) - method.setAccessible(true) - // Partition 0: 0 segments - val segs0 = - method.invoke(resolver, channel, Int.box(0), Int.box(1)).asInstanceOf[Seq[(Long, Long)]] - assert(segs0.isEmpty) - // Partition 1 - 2: 3 segment - val segs1 = - method.invoke(resolver, channel, Int.box(1), Int.box(3)).asInstanceOf[Seq[(Long, Long)]] - assert(segs1 == Seq((0L, 10L), (10L, 20L), (100L, 50L))) - // Partition 3: 2 segments, empty segment will be filter out - val segs3 = - method.invoke(resolver, channel, Int.box(3), Int.box(4)).asInstanceOf[Seq[(Long, Long)]] - assert(segs3 == Seq((30L, 70L), (150L, 50L))) + val channel = FileChannel.open(indexFile.toPath, StandardOpenOption.READ) + try { + val method = classOf[ColumnarIndexShuffleBlockResolver].getDeclaredMethod( + "getSegmentsFromIndex", + classOf[java.nio.channels.SeekableByteChannel], + classOf[Int], + classOf[Int]) + method.setAccessible(true) + // Partition 0: 0 segments + val segs0 = + method.invoke(resolver, channel, Int.box(0), Int.box(1)).asInstanceOf[Seq[(Long, Long)]] + assert(segs0.isEmpty) + // Partition 1 - 2: 3 segment + val segs1_2 = + method.invoke(resolver, channel, Int.box(1), Int.box(3)).asInstanceOf[Seq[(Long, Long)]] + 
assert(segs1_2 == Seq((0L, 10L), (10L, 20L), (100L, 50L))) + // Partition 3: 2 segments, empty segment will be filter out + val segs3 = + method.invoke(resolver, channel, Int.box(3), Int.box(4)).asInstanceOf[Seq[(Long, Long)]] + assert(segs3 == Seq((30L, 70L), (150L, 50L))) + } finally { + channel.close() + } + } + + test("getBlockData returns correct data for partition segments") { + val shuffleId = 1 + val mapId = 0L + val partitionId = 2 + val blockId = ShuffleBlockId(shuffleId, mapId, partitionId) + // commit index and data files + resolver.writeIndexFileAndCommit(shuffleId, mapId, dataFile, indexFile, numPartitions = 4) + + // getBlockData should return a ManagedBuffer for the requested block + val buffer = resolver.getBlockData(blockId) + val bytes = readManagedBuffer(buffer) + // Partition 2 segments: (10,20), (100,50) + assert(bytes.length == 70, "Expected 70 bytes, got ${bytes.length}") + val expected = ((10 until 30) ++ (100 until 150)).map(_.toByte).toArray + assert( + bytes.sameElements(expected), + s"Expected ${expected.mkString(",")}, got ${bytes.mkString(",")}") + } + + private def readManagedBuffer(buffer: ManagedBuffer): Array[Byte] = { + val in = buffer.createInputStream() + val out = new scala.collection.mutable.ArrayBuffer[Byte]() + val buf = new Array[Byte](4096) + var n = in.read(buf) + while (n != -1) { + out ++= buf.take(n) + n = in.read(buf) + } + in.close() + out.toArray } } From d6159c101290679f036d8455f6c5372cb503200a Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Tue, 27 Jan 2026 20:30:20 +0800 Subject: [PATCH 10/20] ColumnarIndexShuffleBlockResolver: add limitation of using new format --- .../ColumnarIndexShuffleBlockResolver.scala | 11 +++++++ ...lumnarIndexShuffleBlockResolverSuite.scala | 30 +++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala 
b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala index 0d7e54237460..450ecf5bfa67 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala @@ -36,6 +36,17 @@ class ColumnarIndexShuffleBlockResolver( extends IndexShuffleBlockResolver(conf, blockManager) { private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") + private val SHUFFLE_SERVICE_ENABLED = "spark.shuffle.service.enabled" + private val PUSH_BASED_SHUFFLE_ENABLED = "spark.shuffle.push.based.enabled" + private var newFormatEnabled: Boolean = _ + + // When external shuffle service or push-based shuffle is enabled, + // it may directly access the shuffle files without going through this resolver. + // we cannot use the new index format to maintain backward compatibility. + newFormatEnabled = !(conf.getBoolean(PUSH_BASED_SHUFFLE_ENABLED, false) || + conf.getBoolean(SHUFFLE_SERVICE_ENABLED, false)) + + def canUseNewFormat(): Boolean = newFormatEnabled private def isNewFormat(index: File): Boolean = { // Simple heuristic to determine new format by appending 1 extra bytes. 
diff --git a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala index 15b67ae80256..fb6288d6739b 100644 --- a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala +++ b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala @@ -181,4 +181,34 @@ class ColumnarIndexShuffleBlockResolverSuite in.close() out.toArray } + + test("canUseNewFormat returns correct value based on config") { + val conf1 = new SparkConf() + .set("spark.shuffle.push.based.enabled", "false") + .set("spark.shuffle.service.enabled", "false") + val resolver1 = new ColumnarIndexShuffleBlockResolver(conf1) + assert(resolver1.canUseNewFormat(), "Should use new format when both configs are false") + + val conf2 = new SparkConf() + .set("spark.shuffle.push.based.enabled", "true") + .set("spark.shuffle.service.enabled", "false") + val resolver2 = new ColumnarIndexShuffleBlockResolver(conf2) + assert( + !resolver2.canUseNewFormat(), + "Should not use new format when push-based shuffle is enabled") + + val conf3 = new SparkConf() + .set("spark.shuffle.push.based.enabled", "false") + .set("spark.shuffle.service.enabled", "true") + val resolver3 = new ColumnarIndexShuffleBlockResolver(conf3) + assert( + !resolver3.canUseNewFormat(), + "Should not use new format when shuffle service is enabled") + + val conf4 = new SparkConf() + .set("spark.shuffle.push.based.enabled", "true") + .set("spark.shuffle.service.enabled", "true") + val resolver4 = new ColumnarIndexShuffleBlockResolver(conf4) + assert(!resolver4.canUseNewFormat(), "Should not use new format when both are enabled") + } } From 62865052ef9d681dcc0b95295344dad4ae7b1455 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Thu, 29 Jan 2026 10:01:58 +0800 Subject: [PATCH 11/20] Add an optional 
indexFile params to LocalPartitionWriter for multiple segments support --- .../spark/shuffle/ColumnarShuffleWriter.scala | 23 +++++++++++++++++++ cpp/core/jni/JniWrapper.cc | 5 +++- cpp/core/shuffle/LocalPartitionWriter.cc | 3 ++- cpp/core/shuffle/LocalPartitionWriter.h | 4 +++- .../LocalPartitionWriterJniWrapper.java | 1 + 5 files changed, 33 insertions(+), 3 deletions(-) diff --git a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala index ef9877011b5d..d761c1794b53 100644 --- a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala +++ b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala @@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{SHUFFLE_COMPRESS, SHUFFLE_DISK_WRITE_BUFFER_SIZE, SHUFFLE_FILE_BUFFER_SIZE, SHUFFLE_SORT_INIT_BUFFER_SIZE, SHUFFLE_SORT_USE_RADIXSORT} import org.apache.spark.memory.SparkMemoryUtil import org.apache.spark.scheduler.MapStatus +import org.apache.spark.shuffle.sort.ColumnarIndexShuffleBlockResolver import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.{SparkDirectoryUtil, SparkResourceUtil, Utils} @@ -42,6 +43,8 @@ class ColumnarShuffleWriter[K, V]( with Logging { private val dep = handle.dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]] + private var columnarShuffleBlockResolver: ColumnarIndexShuffleBlockResolver = _ + private var partitionUseMultipleSegments: Boolean = false dep.shuffleWriterType match { case HashShuffleWriterType | SortShuffleWriterType | GpuHashShuffleWriterType => @@ -52,6 +55,17 @@ class ColumnarShuffleWriter[K, V]( s"expected one of: ${HashShuffleWriterType.name}, ${SortShuffleWriterType.name}") } + shuffleBlockResolver match { + case resolver: ColumnarIndexShuffleBlockResolver => + if (resolver.canUseNewFormat() && 
!GlutenConfig.get.columnarShuffleEnableDictionary) { + // For Dictionary encoding, the dict only finializes after all batches are processed, + // and dict is required to saved at the head of the partition data. + // So we cannot use multiple segments to save partition data incrementally. + partitionUseMultipleSegments = true + columnarShuffleBlockResolver = resolver + } + } + protected val isSort: Boolean = dep.shuffleWriterType == SortShuffleWriterType private val numPartitions: Int = dep.partitioner.numPartitions @@ -136,6 +150,11 @@ class ColumnarShuffleWriter[K, V]( } val tempDataFile = Utils.tempFileWith(shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)) + val tempIndexFile = if (partitionUseMultipleSegments) { + Utils.tempFileWith(columnarShuffleBlockResolver.getIndexFile(dep.shuffleId, mapId)) + } else { + null + } while (records.hasNext) { val cb = records.next()._2.asInstanceOf[ColumnarBatch] @@ -155,6 +174,7 @@ class ColumnarShuffleWriter[K, V]( blockManager.subDirsPerLocalDir, conf.get(SHUFFLE_FILE_BUFFER_SIZE).toInt, tempDataFile.getAbsolutePath, + if (tempIndexFile != null) tempIndexFile.getAbsolutePath else null, localDirs, GlutenConfig.get.columnarShuffleEnableDictionary ) @@ -274,6 +294,9 @@ class ColumnarShuffleWriter[K, V]( if (tempDataFile.exists() && !tempDataFile.delete()) { logError(s"Error while deleting temp file ${tempDataFile.getAbsolutePath}") } + if (tempIndexFile != null && tempIndexFile.exists() && !tempIndexFile.delete()) { + logError(s"Error while deleting temp file ${tempIndexFile.getAbsolutePath}") + } } // The partitionLength is much more than vanilla spark partitionLengths, diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc index c48886195b22..b24b56d92294 100644 --- a/cpp/core/jni/JniWrapper.cc +++ b/cpp/core/jni/JniWrapper.cc @@ -943,6 +943,7 @@ Java_org_apache_gluten_vectorized_LocalPartitionWriterJniWrapper_createPartition jint numSubDirs, jint shuffleFileBufferSize, jstring dataFileJstr, + jstring 
indexFileJstr, jstring localDirsJstr, jboolean enableDictionary) { JNI_METHOD_START @@ -950,6 +951,7 @@ Java_org_apache_gluten_vectorized_LocalPartitionWriterJniWrapper_createPartition const auto ctx = getRuntime(env, wrapper); auto dataFile = jStringToCString(env, dataFileJstr); + auto indexFile = jStringToCString(env, indexFileJstr); auto localDirs = splitPaths(jStringToCString(env, localDirsJstr)); auto partitionWriterOptions = std::make_shared( @@ -968,7 +970,8 @@ Java_org_apache_gluten_vectorized_LocalPartitionWriterJniWrapper_createPartition ctx->memoryManager(), partitionWriterOptions, dataFile, - std::move(localDirs)); + std::move(localDirs), + indexFile); return ctx->saveObject(partitionWriter); JNI_METHOD_END(kInvalidObjectHandle) diff --git a/cpp/core/shuffle/LocalPartitionWriter.cc b/cpp/core/shuffle/LocalPartitionWriter.cc index 948a2b0e05ef..62c0a3dc92ee 100644 --- a/cpp/core/shuffle/LocalPartitionWriter.cc +++ b/cpp/core/shuffle/LocalPartitionWriter.cc @@ -506,7 +506,8 @@ LocalPartitionWriter::LocalPartitionWriter( MemoryManager* memoryManager, const std::shared_ptr& options, const std::string& dataFile, - std::vector localDirs) + std::vector localDirs, + const std::string& indexFile) : PartitionWriter(numPartitions, std::move(codec), memoryManager), options_(options), dataFile_(dataFile), diff --git a/cpp/core/shuffle/LocalPartitionWriter.h b/cpp/core/shuffle/LocalPartitionWriter.h index 113e5a3cfd2f..c7a52d317f25 100644 --- a/cpp/core/shuffle/LocalPartitionWriter.h +++ b/cpp/core/shuffle/LocalPartitionWriter.h @@ -33,7 +33,8 @@ class LocalPartitionWriter : public PartitionWriter { MemoryManager* memoryManager, const std::shared_ptr& options, const std::string& dataFile, - std::vector localDirs); + std::vector localDirs, + const std::string& indexFile = nullptr); arrow::Status hashEvict( uint32_t partitionId, @@ -108,6 +109,7 @@ class LocalPartitionWriter : public PartitionWriter { std::shared_ptr options_; std::string dataFile_; + std::string 
indexFile_; std::vector localDirs_; bool stopped_{false}; diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/LocalPartitionWriterJniWrapper.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/LocalPartitionWriterJniWrapper.java index 3269e7b17610..f727a52ae29d 100644 --- a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/LocalPartitionWriterJniWrapper.java +++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/LocalPartitionWriterJniWrapper.java @@ -47,6 +47,7 @@ public native long createPartitionWriter( int subDirsPerLocalDir, int shuffleFileBufferSize, String dataFile, + String indexFile, String localDirs, boolean enableDictionary); } From f24c45ae1931c915a527176a310bd10faa7d52d8 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Sun, 1 Feb 2026 20:40:40 +0800 Subject: [PATCH 12/20] LocalPartitionWriter: support multiple segments of partition --- .../spark/shuffle/ColumnarShuffleWriter.scala | 22 ++- cpp/core/shuffle/LocalPartitionWriter.cc | 133 +++++++++++++++++- cpp/core/shuffle/LocalPartitionWriter.h | 11 ++ 3 files changed, 159 insertions(+), 7 deletions(-) diff --git a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala index d761c1794b53..1ab23d9bf12f 100644 --- a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala +++ b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala @@ -284,12 +284,22 @@ class ColumnarShuffleWriter[K, V]( partitionLengths = splitResult.getPartitionLengths try { - shuffleBlockResolver.writeMetadataFileAndCommit( - dep.shuffleId, - mapId, - partitionLengths, - Array[Long](), - tempDataFile) + if (partitionUseMultipleSegments) { + columnarShuffleBlockResolver.writeIndexFileAndCommit( + dep.shuffleId, + mapId, + tempDataFile, + tempIndexFile, + numPartitions, + None) + } else { + 
shuffleBlockResolver.writeMetadataFileAndCommit( + dep.shuffleId, + mapId, + partitionLengths, + Array[Long](), + tempDataFile) + } } finally { if (tempDataFile.exists() && !tempDataFile.delete()) { logError(s"Error while deleting temp file ${tempDataFile.getAbsolutePath}") diff --git a/cpp/core/shuffle/LocalPartitionWriter.cc b/cpp/core/shuffle/LocalPartitionWriter.cc index 62c0a3dc92ee..82a1bd298925 100644 --- a/cpp/core/shuffle/LocalPartitionWriter.cc +++ b/cpp/core/shuffle/LocalPartitionWriter.cc @@ -30,6 +30,7 @@ #include #include #include +#include "LocalPartitionWriter.h" namespace gluten { @@ -356,6 +357,25 @@ class LocalPartitionWriter::PayloadCache { return arrow::Status::OK(); } + arrow::Result writeIncremental(uint32_t partitionId, arrow::io::OutputStream *os) { + GLUTEN_CHECK(!enableDictionary_, "Incremental write is not supported when dictionary is enabled."); + if ((partitionInUse_.has_value() && partitionInUse_.value() == partitionId) || !hasCachedPayloads(partitionId)) { + return false; + } + + auto& payloads = partitionCachedPayload_[partitionId]; + while (!payloads.empty()) { + const auto payload = std::move(payloads.front()); + payloads.pop_front(); + uint8_t blockType = static_cast(BlockType::kPlainPayload); + RETURN_NOT_OK(os->Write(&blockType, sizeof(blockType))); + RETURN_NOT_OK(payload->serialize(os)); + compressTime_ += payload->getCompressTime(); + writeTime_ += payload->getWriteTime(); + } + return true; + } + bool canSpill() { for (auto pid = 0; pid < numPartitions_; ++pid) { if (partitionInUse_.has_value() && partitionInUse_.value() == pid) { @@ -511,6 +531,7 @@ LocalPartitionWriter::LocalPartitionWriter( : PartitionWriter(numPartitions, std::move(codec), memoryManager), options_(options), dataFile_(dataFile), + indexFile_(indexFile), localDirs_(std::move(localDirs)) { init(); } @@ -563,6 +584,55 @@ void LocalPartitionWriter::init() { std::default_random_engine engine(rd()); std::shuffle(localDirs_.begin(), localDirs_.end(), engine); 
subDirSelection_.assign(localDirs_.size(), 0); + + if (!indexFile_.empty()) { + usePartitionMultipleSegments_ = true; + partitionSegments_.resize(numPartitions_); + } +} + +// Helper for big-endian conversion (network order) +#include +static uint64_t htonll(uint64_t value) { +#if __BYTE_ORDER == __LITTLE_ENDIAN + return (((uint64_t)htonl(value & 0xFFFFFFFFULL)) << 32) | htonl(value >> 32); +#else + return value; +#endif +} + +arrow::Status LocalPartitionWriter::writeIndexFile() { + if (!usePartitionMultipleSegments_) { + return arrow::Status::OK(); + } + + ARROW_ASSIGN_OR_RAISE(auto indexFileOs, openFile(indexFile_, options_->shuffleFileBufferSize)); + + uint64_t segmentOffset = (numPartitions_ + 1) * sizeof(int64_t); + // write segment index of each partition in big-endian + for (uint32_t pid = 0; pid < numPartitions_; ++pid) { + uint64_t beOffset = htonll(segmentOffset); + RETURN_NOT_OK(indexFileOs->Write(reinterpret_cast(&beOffset), sizeof(beOffset))); + const auto& segments = partitionSegments_[pid]; + segmentOffset += (segments.size() * 2 * sizeof(int64_t)); + } + uint64_t beOffset = htonll(segmentOffset); + RETURN_NOT_OK(indexFileOs->Write(reinterpret_cast(&beOffset), sizeof(beOffset))); + // Write partition segments info in big-endian + for (uint32_t pid = 0; pid < numPartitions_; ++pid) { + const auto& segments = partitionSegments_[pid]; + for (const auto& segment : segments) { + uint64_t beFirst = htonll(segment.first); + uint64_t beSecond = htonll(segment.second); + RETURN_NOT_OK(indexFileOs->Write(reinterpret_cast(&beFirst), sizeof(beFirst))); + RETURN_NOT_OK(indexFileOs->Write(reinterpret_cast(&beSecond), sizeof(beSecond))); + } + } + + // Write a ending char + RETURN_NOT_OK(indexFileOs->Write("", 1)); + RETURN_NOT_OK(indexFileOs->Close()); + return arrow::Status::OK(); } arrow::Result LocalPartitionWriter::mergeSpills(uint32_t partitionId, arrow::io::OutputStream* os) { @@ -601,13 +671,67 @@ arrow::Status 
LocalPartitionWriter::writeCachedPayloads(uint32_t partitionId, ar return arrow::Status::OK(); } +arrow::Status LocalPartitionWriter::flushCachedPlayloads() { + if (payloadCache_ == nullptr) { + return arrow::Status::OK(); + } + if (dataFileOs_ == nullptr) { + ARROW_ASSIGN_OR_RAISE(dataFileOs_, openFile(dataFile_, options_->shuffleFileBufferSize)); + } + ARROW_ASSIGN_OR_RAISE(int64_t endInDataFile, dataFileOs_->Tell()); + for (auto pid = 0; pid < numPartitions_; ++pid) { + auto startInDataFile = endInDataFile; + RETURN_NOT_OK(mergeSpills(pid, dataFileOs_.get())); + RETURN_NOT_OK(payloadCache_->writeIncremental(pid, dataFileOs_.get())); + ARROW_ASSIGN_OR_RAISE(endInDataFile, dataFileOs_->Tell()); + + auto bytesWritten = endInDataFile - startInDataFile; + if (bytesWritten > 0) { + partitionSegments_[pid].emplace_back(startInDataFile, bytesWritten); + partitionLengths_[pid] += bytesWritten; + } + } + spills_.clear(); + return arrow::Status::OK(); +} + +arrow::Status LocalPartitionWriter::writeMemoryPayload(uint32_t partitionId, std::unique_ptr payload) { + if (dataFileOs_ == nullptr) { + ARROW_ASSIGN_OR_RAISE(dataFileOs_, openFile(dataFile_, options_->shuffleFileBufferSize)); + } + + auto shouldCompress = codec_ != nullptr && payload->numRows() >= options_->compressionThreshold; + + ARROW_ASSIGN_OR_RAISE( + auto block, + payload->toBlockPayload(shouldCompress ? 
Payload::kToBeCompressed : Payload::kUncompressed, payloadPool_.get(), codec_.get())); + + ARROW_ASSIGN_OR_RAISE(int64_t startOffset, dataFileOs_->Tell()); + uint8_t blockType = static_cast(BlockType::kPlainPayload); + RETURN_NOT_OK(dataFileOs_->Write(&blockType, sizeof(blockType))); + RETURN_NOT_OK(block->serialize(dataFileOs_.get())); + compressTime_ += block->getCompressTime(); + writeTime_ += block->getWriteTime(); + ARROW_ASSIGN_OR_RAISE(int64_t endOffset, dataFileOs_->Tell()); + auto bytesWritten = endOffset - startOffset; + partitionSegments_[partitionId].emplace_back(startOffset, bytesWritten); + partitionLengths_[partitionId] += bytesWritten; + + return arrow::Status::OK(); +} + arrow::Status LocalPartitionWriter::stop(ShuffleWriterMetrics* metrics, int64_t& evictBytes) { if (stopped_) { return arrow::Status::OK(); } stopped_ = true; - if (useSpillFileAsDataFile_) { + if (usePartitionMultipleSegments_) { + RETURN_NOT_OK(finishSpill()); + RETURN_NOT_OK(finishMerger()); + RETURN_NOT_OK(flushCachedPlayloads()); + RETURN_NOT_OK(writeIndexFile()); + } else if (useSpillFileAsDataFile_) { ARROW_ASSIGN_OR_RAISE(auto spill, spiller_->finish()); // Merge the remaining partitions from spills. @@ -751,6 +875,7 @@ arrow::Status LocalPartitionWriter::hashEvict( for (auto& payload : merged) { RETURN_NOT_OK(payloadCache_->cache(partitionId, std::move(payload))); } + RETURN_NOT_OK(flushCachedPlayloads()); merged.clear(); } return arrow::Status::OK(); @@ -760,6 +885,12 @@ arrow::Status LocalPartitionWriter::sortEvict(uint32_t partitionId, std::unique_ptr inMemoryPayload, bool isFinal, int64_t& evictBytes) { rawPartitionLengths_[partitionId] += inMemoryPayload->rawSize(); + if (usePartitionMultipleSegments_) { + // If multiple segments per partition is enabled, write directly to the final data file. 
+ RETURN_NOT_OK(writeMemoryPayload(partitionId, std::move(inMemoryPayload))); + return arrow::Status::OK(); + } + if (lastEvictPid_ != -1 && (partitionId < lastEvictPid_ || (isFinal && !dataFileOs_))) { lastEvictPid_ = -1; RETURN_NOT_OK(finishSpill()); diff --git a/cpp/core/shuffle/LocalPartitionWriter.h b/cpp/core/shuffle/LocalPartitionWriter.h index c7a52d317f25..6d3051d7e571 100644 --- a/cpp/core/shuffle/LocalPartitionWriter.h +++ b/cpp/core/shuffle/LocalPartitionWriter.h @@ -91,6 +91,8 @@ class LocalPartitionWriter : public PartitionWriter { void init(); + arrow::Status writeIndexFile(); + arrow::Status requestSpill(bool isFinal); arrow::Status finishSpill(); @@ -103,6 +105,10 @@ class LocalPartitionWriter : public PartitionWriter { arrow::Status writeCachedPayloads(uint32_t partitionId, arrow::io::OutputStream* os) const; + arrow::Status flushCachedPlayloads(); + + arrow::Status writeMemoryPayload(uint32_t partitionId, std::unique_ptr inMemoryPayload); + arrow::Status clearResource(); arrow::Status populateMetrics(ShuffleWriterMetrics* metrics); @@ -130,6 +136,11 @@ class LocalPartitionWriter : public PartitionWriter { std::vector partitionLengths_; std::vector rawPartitionLengths_; + bool usePartitionMultipleSegments_{false}; + // For each partition, record all segments' (start, length) in the final data file. + // partitionSegments_[pid] = [(start1, length1), (start2, length2), ...] 
+ std::vector>> partitionSegments_{}; + int32_t lastEvictPid_{-1}; }; } // namespace gluten From 5a2ea3ee4b882ea376900cb5d1cdc9e9a9fa7a03 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Mon, 2 Feb 2026 16:45:52 +0800 Subject: [PATCH 13/20] Add LowCopyFileSegmentsJniByteInputStream to support native read --- .../vectorized/JniByteInputStreams.java | 3 + ...LowCopyFileSegmentsJniByteInputStream.java | 165 ++++++++++++++++++ ...opyFileSegmentsJniByteInputStreamTest.java | 132 ++++++++++++++ .../shuffle/sort/FileSegmentsBuffer.scala | 4 +- 4 files changed, 302 insertions(+), 2 deletions(-) create mode 100644 gluten-arrow/src/main/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStream.java create mode 100644 gluten-arrow/src/test/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStreamTest.java diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/JniByteInputStreams.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/JniByteInputStreams.java index 0749d0cff3fd..0eb24ef9a21a 100644 --- a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/JniByteInputStreams.java +++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/JniByteInputStreams.java @@ -53,6 +53,9 @@ public static JniByteInputStream create(InputStream in) { if (LowCopyFileSegmentJniByteInputStream.isSupported(unwrapped)) { return new LowCopyFileSegmentJniByteInputStream(in); } + if (LowCopyFileSegmentsJniByteInputStream.isSupported(unwrapped)) { + return new LowCopyFileSegmentsJniByteInputStream(in); + } return new OnHeapJniByteInputStream(in); } diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStream.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStream.java new file mode 100644 index 000000000000..bd6259b01786 --- /dev/null +++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStream.java @@ -0,0 +1,165 
@@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.vectorized; + +import org.apache.gluten.exception.GlutenException; + +import java.io.IOException; +import java.io.InputStream; +import java.io.SequenceInputStream; +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Enumeration; +import java.util.List; + +/** + * This implementation is targeted to optimize against Spark's {@link + * org.apache.spark.network.buffer.FileSegmentManagedBuffer} to make sure shuffle data is shared + * over JNI without unnecessary copy. 
+ */ +public class LowCopyFileSegmentsJniByteInputStream implements JniByteInputStream { + private static final Field FIELD_SequenceInputStream_in; + private static final Field FIELD_SequenceInputStream_e; + + static { + try { + FIELD_SequenceInputStream_in = SequenceInputStream.class.getDeclaredField("in"); + FIELD_SequenceInputStream_in.setAccessible(true); + FIELD_SequenceInputStream_e = SequenceInputStream.class.getDeclaredField("e"); + FIELD_SequenceInputStream_e.setAccessible(true); + } catch (NoSuchFieldException e) { + throw new GlutenException(e); + } + } + + private final InputStream in; + private final List segments; + + private int currentIndex = 0; + private long bytesRead = 0L; + + public LowCopyFileSegmentsJniByteInputStream(InputStream in) { + this.in = in; // to prevent underlying netty buffer from being collected by GC + final InputStream unwrapped = JniByteInputStreams.unwrapSparkInputStream(in); + final SequenceInputStream sin = (SequenceInputStream) unwrapped; + final List streams = collectStreams(sin, false); + this.segments = buildSegments(streams); + } + + public static boolean isSupported(InputStream in) { + if (!(in instanceof SequenceInputStream)) { + return false; + } + final SequenceInputStream sin = (SequenceInputStream) in; + final List streams = collectStreams(sin, true); + for (InputStream stream : streams) { + final InputStream unwrapped = JniByteInputStreams.unwrapSparkInputStream(stream); + if (!LowCopyFileSegmentJniByteInputStream.isSupported(unwrapped)) { + return false; + } + } + return true; + } + + @Override + public long read(long destAddress, long maxSize) { + if (maxSize <= 0) { + return 0; + } + long remaining = maxSize; + long totalRead = 0L; + while (remaining > 0 && currentIndex < segments.size()) { + LowCopyFileSegmentJniByteInputStream segment = segments.get(currentIndex); + long request = remaining; + long bytes = segment.read(destAddress + totalRead, request); + if (bytes == 0) { + currentIndex++; + continue; + } 
+ bytesRead += bytes; + remaining -= bytes; + totalRead += bytes; + } + return totalRead; + } + + @Override + public long tell() { + return bytesRead; + } + + @Override + public void close() { + try { + for (LowCopyFileSegmentJniByteInputStream segment : segments) { + segment.close(); + } + in.close(); + } catch (IOException e) { + throw new GlutenException(e); + } + } + + private static List buildSegments( + List streams) { + List segments = new ArrayList<>(streams.size()); + for (InputStream stream : streams) { + segments.add(new LowCopyFileSegmentJniByteInputStream(stream)); + } + return segments; + } + + private static List collectStreams(SequenceInputStream sin, boolean restore) { + final List streams = new ArrayList<>(); + final InputStream current; + final Enumeration enumeration; + try { + current = (InputStream) FIELD_SequenceInputStream_in.get(sin); + enumeration = (Enumeration) FIELD_SequenceInputStream_e.get(sin); + } catch (IllegalAccessException e) { + throw new GlutenException(e); + } + if (current != null) { + streams.add(current); + } + if (enumeration != null) { + while (enumeration.hasMoreElements()) { + streams.add(enumeration.nextElement()); + } + } + + if (restore) { + // restore the enumeration in SequenceInputStream.e + try { + if (streams.isEmpty()) { + FIELD_SequenceInputStream_e.set(sin, Collections.enumeration(Collections.emptyList())); + } else { + if (streams.size() == 1) { + FIELD_SequenceInputStream_e.set(sin, Collections.enumeration(Collections.emptyList())); + } else { + FIELD_SequenceInputStream_e.set( + sin, Collections.enumeration(streams.subList(1, streams.size()))); + } + } + } catch (IllegalAccessException e) { + throw new GlutenException(e); + } + } + return streams; + } +} diff --git a/gluten-arrow/src/test/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStreamTest.java b/gluten-arrow/src/test/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStreamTest.java new file mode 100644 index 
000000000000..f7d2efbf78f4 --- /dev/null +++ b/gluten-arrow/src/test/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStreamTest.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.vectorized; + +import io.netty.util.internal.PlatformDependent; +import org.apache.spark.network.util.LimitedInputStream; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.InputStream; +import java.io.SequenceInputStream; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +public class LowCopyFileSegmentsJniByteInputStreamTest { + private File tempFile; + + @Before + public void setUp() throws Exception { + tempFile = Files.createTempFile("gluten-segments", ".bin").toFile(); + } + + @After + public void tearDown() throws Exception { + if (tempFile != null && tempFile.exists()) { + Files.deleteIfExists(tempFile.toPath()); + } + } + + @Test + public void testReadAcrossSegments() throws 
Exception { + byte[] bytes = "abcdefg-0123456".getBytes(StandardCharsets.UTF_8); + try (FileOutputStream out = new FileOutputStream(tempFile)) { + out.write(bytes); + } + + int firstLen = 4; + int secondLen = bytes.length - firstLen; + InputStream first = createLimited(tempFile, 0, firstLen); + InputStream second = createLimited(tempFile, firstLen, secondLen); + List streams = Arrays.asList(first, second); + + SequenceInputStream sin = new SequenceInputStream(Collections.enumeration(streams)); + Assert.assertTrue(LowCopyFileSegmentsJniByteInputStream.isSupported(sin)); + + LowCopyFileSegmentsJniByteInputStream in = new LowCopyFileSegmentsJniByteInputStream(sin); + ByteBuffer buffer = PlatformDependent.allocateDirectNoCleaner(bytes.length); + long addr = PlatformDependent.directBufferAddress(buffer); + + long firstRead = in.read(addr, 3); + long secondRead = in.read(addr + firstRead, bytes.length - firstRead); + long totalRead = firstRead + secondRead; + + Assert.assertEquals(bytes.length, totalRead); + Assert.assertEquals(bytes.length, in.tell()); + + buffer.limit(bytes.length); + byte[] out = new byte[bytes.length]; + buffer.get(out); + Assert.assertArrayEquals(bytes, out); + + in.close(); + } + + @Test + public void testReadNonContiguousSegments() throws Exception { + byte[] bytes = "abcdefghij123456789".getBytes(StandardCharsets.UTF_8); + try (FileOutputStream out = new FileOutputStream(tempFile)) { + out.write(bytes); + } + + // Select two non-contiguous segments: [2,5) and [10,15) + InputStream seg1 = createLimited(tempFile, 2, 3); // "cde" + InputStream seg2 = createLimited(tempFile, 10, 5); // "12345" + List streams = Arrays.asList(seg1, seg2); + + SequenceInputStream sin = new SequenceInputStream(Collections.enumeration(streams)); + Assert.assertTrue(LowCopyFileSegmentsJniByteInputStream.isSupported(sin)); + + LowCopyFileSegmentsJniByteInputStream in = new LowCopyFileSegmentsJniByteInputStream(sin); + ByteBuffer buffer = 
PlatformDependent.allocateDirectNoCleaner(8); + long addr = PlatformDependent.directBufferAddress(buffer); + + long read = in.read(addr, 8); + Assert.assertEquals(8, read); + Assert.assertEquals(8, in.tell()); + + buffer.limit(8); + byte[] out = new byte[8]; + buffer.get(out); + // Expected: "cde12345" + Assert.assertArrayEquals("cde12345".getBytes(StandardCharsets.UTF_8), out); + + in.close(); + } + + private static InputStream createLimited(File file, long offset, long length) throws Exception { + FileInputStream fin = new FileInputStream(file); + long skipped = 0; + while (skipped < offset) { + long step = fin.skip(offset - skipped); + if (step <= 0) { + throw new IllegalStateException("Unable to skip to offset " + offset); + } + skipped += step; + } + return new LimitedInputStream(fin, length); + } +} diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala index 6a0f72cf9cfa..70e6b1f08f75 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala @@ -21,7 +21,7 @@ import org.apache.spark.network.util.JavaUtils import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.network.util.TransportConf -import java.io.{BufferedInputStream, EOFException, File, FileInputStream, InputStream, IOException, RandomAccessFile, SequenceInputStream} +import java.io.{EOFException, File, FileInputStream, InputStream, IOException, RandomAccessFile, SequenceInputStream} import java.nio.ByteBuffer import java.nio.channels.FileChannel import java.util.Vector @@ -102,7 +102,7 @@ class FileSegmentsManagedBuffer( val fis = new FileInputStream(file) filesToClose.add(fis) skipFully(fis, offset) - streams.add(new BufferedInputStream(new LimitedInputStream(fis, length))) + streams.add(new 
LimitedInputStream(fis, length)) } shouldCloseFile = false new SequenceInputStream(streams.elements()) From 9ceb7fd0d0779f753c4c757590536d21d4123e91 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Thu, 5 Feb 2026 11:14:32 +0800 Subject: [PATCH 14/20] Add FileSegmentsInputStream to handle segments read --- .../shuffle/sort/FileSegmentsBuffer.scala | 47 +----- .../sort/FileSegmentsInputStream.scala | 142 ++++++++++++++++++ .../sort/FileSegmentsInputStreamSuite.scala | 107 +++++++++++++ 3 files changed, 251 insertions(+), 45 deletions(-) create mode 100644 gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsInputStream.scala create mode 100644 gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsInputStreamSuite.scala diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala index 70e6b1f08f75..b2736e71b069 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala @@ -18,13 +18,11 @@ package org.apache.spark.shuffle.sort import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.util.JavaUtils -import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.network.util.TransportConf -import java.io.{EOFException, File, FileInputStream, InputStream, IOException, RandomAccessFile, SequenceInputStream} +import java.io.{EOFException, File, InputStream, RandomAccessFile} import java.nio.ByteBuffer import java.nio.channels.FileChannel -import java.util.Vector /** A {@link ManagedBuffer} backed by a set of segments in a file. 
*/ class FileSegmentsManagedBuffer( @@ -71,49 +69,8 @@ class FileSegmentsManagedBuffer( } } - @throws[IOException] - private def skipFully(in: InputStream, n: Long): Unit = { - var remaining = n - val toSkip = n - while (remaining > 0L) { - val amt = in.skip(remaining) - if (amt == 0L) { - // Try reading a byte to see if we are at EOF - if (in.read() == -1) { - val skipped = toSkip - remaining - throw new EOFException( - s"reached end of stream after skipping $skipped bytes; $toSkip bytes expected") - } - remaining -= 1 - } else { - remaining -= amt - } - } - } - - @throws[IOException] override def createInputStream(): InputStream = { - val streams = new Vector[InputStream]() - val filesToClose = new Vector[FileInputStream]() - var shouldCloseFile = true - try { - segments.foreach { - case (offset, length) => - val fis = new FileInputStream(file) - filesToClose.add(fis) - skipFully(fis, offset) - streams.add(new LimitedInputStream(fis, length)) - } - shouldCloseFile = false - new SequenceInputStream(streams.elements()) - } finally { - if (shouldCloseFile) { - val it = filesToClose.elements() - while (it.hasMoreElements) { - JavaUtils.closeQuietly(it.nextElement()) - } - } - } + new FileSegmentsInputStream(file, segments) } override def convertToNetty(): AnyRef = { diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsInputStream.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsInputStream.scala new file mode 100644 index 000000000000..949c360bd985 --- /dev/null +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsInputStream.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.sort + +import java.io.{File, InputStream, IOException, RandomAccessFile} +import java.nio.ByteBuffer + +/** + * An InputStream that reads a list of non-contiguous file segments sequentially. + * + * @param file + * the file to read + * @param segments + * a list of (file_offset, size) segments to read in sequence + */ +class FileSegmentsInputStream(file: File, segments: Seq[(Long, Long)]) extends InputStream { + private val raf = new RandomAccessFile(file, "r") + private val channel = raf.getChannel + + private var currentIndex = 0 + private var currentOffset = 0L + private var remainingInSegment = 0L + private var closed = false + + if (segments.nonEmpty) { + currentOffset = segments.head._1 + remainingInSegment = segments.head._2 + channel.position(currentOffset) + } + + override def read(): Int = { + val buf = new Array[Byte](1) + val n = read(buf, 0, 1) + if (n <= 0) { + -1 + } else { + buf(0) & 0xff + } + } + + override def read(b: Array[Byte], off: Int, len: Int): Int = { + if (closed) { + throw new IOException("Stream is closed") + } + if (b == null) { + throw new NullPointerException("buffer is null") + } + if (off < 0 || len < 0 || off + len > b.length) { + throw new IndexOutOfBoundsException(s"offset=$off, length=$len, buffer length=${b.length}") + } + if (len == 0) { + return 0 + } + + var totalRead = 0 + var remaining = len + while 
(remaining > 0 && currentIndex < segments.length) { + if (remainingInSegment == 0) { + if (!advanceSegment()) { + if (totalRead == 0) return -1 + return totalRead + } + } + val bytesToRead = Math.min(remainingInSegment, remaining.toLong).toInt + val buf = ByteBuffer.wrap(b, off + totalRead, bytesToRead) + val n = channel.read(buf) + if (n == -1) { + if (totalRead == 0) return -1 + return totalRead + } + remainingInSegment -= n + totalRead += n + remaining -= n + } + + if (totalRead == 0 && currentIndex >= segments.length) { + -1 + } else { + totalRead + } + } + + override def skip(n: Long): Long = { + if (closed) { + throw new IOException("Stream is closed") + } + if (n <= 0) { + return 0L + } + var remaining = n + var skipped = 0L + while (remaining > 0 && currentIndex < segments.length) { + if (remainingInSegment == 0 && !advanceSegment()) { + return skipped + } + val toSkip = Math.min(remainingInSegment, remaining) + channel.position(channel.position() + toSkip) + remainingInSegment -= toSkip + remaining -= toSkip + skipped += toSkip + } + skipped + } + + override def available(): Int = { + if (remainingInSegment > Int.MaxValue) Int.MaxValue else remainingInSegment.toInt + } + + override def close(): Unit = { + if (!closed) { + closed = true + channel.close() + raf.close() + } + } + + private def advanceSegment(): Boolean = { + currentIndex += 1 + if (currentIndex >= segments.length) { + return false + } + val (offset, length) = segments(currentIndex) + currentOffset = offset + remainingInSegment = length + channel.position(currentOffset) + true + } +} diff --git a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsInputStreamSuite.scala b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsInputStreamSuite.scala new file mode 100644 index 000000000000..024a7f5e5556 --- /dev/null +++ b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsInputStreamSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.sort + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.funsuite.AnyFunSuite + +import java.io.{File, FileOutputStream} +import java.nio.file.Files + +class FileSegmentsInputStreamSuite extends AnyFunSuite with BeforeAndAfterAll { + private var tempFile: File = _ + private val fileData: Array[Byte] = (0 until 100).map(_.toByte).toArray + + override def beforeAll(): Unit = { + tempFile = Files.createTempFile("fsegments-inputstream", ".bin").toFile + val fos = new FileOutputStream(tempFile) + fos.write(fileData) + fos.close() + } + + override def afterAll(): Unit = { + if (tempFile != null && tempFile.exists()) tempFile.delete() + } + + test("read single segment") { + val segments = Seq((10L, 5L)) + val in = new FileSegmentsInputStream(tempFile, segments) + val out = readAll(in, 5) + assert(out.sameElements(fileData.slice(10, 15))) + assert(in.read() == -1) + in.close() + } + + test("read multiple non-contiguous segments") { + val segments = Seq((2L, 3L), (10L, 4L), (50L, 2L)) + val in = new FileSegmentsInputStream(tempFile, segments) + val out = readAll(in, 9) + val expected = fileData.slice(2, 5) ++ fileData.slice(10, 14) ++ 
fileData.slice(50, 52) + assert(out.sameElements(expected)) + assert(in.read() == -1) + in.close() + } + + test("read single bytes sequentially") { + val segments = Seq((20L, 3L)) + val in = new FileSegmentsInputStream(tempFile, segments) + val b0 = in.read() + val b1 = in.read() + val b2 = in.read() + val b3 = in.read() + assert(b0 == (fileData(20) & 0xff)) + assert(b1 == (fileData(21) & 0xff)) + assert(b2 == (fileData(22) & 0xff)) + assert(b3 == -1) + in.close() + } + + test("skip bytes across segments") { + val segments = Seq((10L, 3L), (20L, 4L)) + val in = new FileSegmentsInputStream(tempFile, segments) + // Skip 2 bytes in the first segment + val skipped1 = in.skip(2) + assert(skipped1 == 2) + // Should now be at fileData(12) + val b = in.read() + assert(b == (fileData(12) & 0xff)) + // Skip 2 bytes (should move to the second segment) + val skipped2 = in.skip(2) + assert(skipped2 == 2) + // Should now be at fileData(22) + val b2 = in.read() + assert(b2 == (fileData(22) & 0xff)) + // Skip more than remaining (should reach end) + val skipped3 = in.skip(100) + assert(skipped3 == 1) // 1 byte left in segments + assert(in.read() == -1) + in.close() + } + + private def readAll(in: FileSegmentsInputStream, size: Int): Array[Byte] = { + val buffer = new Array[Byte](size) + var total = 0 + var n = 0 + while (n >= 0 && total < size) { + n = in.read(buffer, total, size - total) + if (n > 0) { + total += n + } + } + buffer + } +} From 66d3742fbdc5bc407f0506cc8f9be76e9f08116c Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Thu, 5 Feb 2026 13:40:48 +0800 Subject: [PATCH 15/20] LowCopyFileSegmentsJniByteInputStream use FileSegmentsInputStream --- gluten-arrow/pom.xml | 7 + ...LowCopyFileSegmentsJniByteInputStream.java | 127 +++--------------- ...opyFileSegmentsJniByteInputStreamTest.java | 52 ++++--- .../sort/FileSegmentsInputStream.scala | 52 +++++++ .../sort/FileSegmentsInputStreamSuite.scala | 19 +++ 5 files changed, 121 insertions(+), 136 deletions(-) diff --git 
a/gluten-arrow/pom.xml b/gluten-arrow/pom.xml index a5552719b0f7..e83a710ce57f 100644 --- a/gluten-arrow/pom.xml +++ b/gluten-arrow/pom.xml @@ -34,6 +34,13 @@ ${project.version} compile + + org.apache.gluten + gluten-substrait + ${project.version} + test-jar + test + org.apache.gluten ${sparkshim.artifactId} diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStream.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStream.java index bd6259b01786..845e2f3d3f10 100644 --- a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStream.java +++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStream.java @@ -18,84 +18,48 @@ import org.apache.gluten.exception.GlutenException; +import org.apache.spark.shuffle.sort.FileSegmentsInputStream; + import java.io.IOException; import java.io.InputStream; -import java.io.SequenceInputStream; -import java.lang.reflect.Field; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Enumeration; -import java.util.List; /** - * This implementation is targeted to optimize against Spark's {@link - * org.apache.spark.network.buffer.FileSegmentManagedBuffer} to make sure shuffle data is shared - * over JNI without unnecessary copy. + * This implementation is targeted to optimize against Gluten's {@link + * org.apache.spark.shuffle.sort.FileSegmentsManagedBuffer} to make sure shuffle data is shared over + * JNI without unnecessary copy. 
*/ public class LowCopyFileSegmentsJniByteInputStream implements JniByteInputStream { - private static final Field FIELD_SequenceInputStream_in; - private static final Field FIELD_SequenceInputStream_e; - - static { - try { - FIELD_SequenceInputStream_in = SequenceInputStream.class.getDeclaredField("in"); - FIELD_SequenceInputStream_in.setAccessible(true); - FIELD_SequenceInputStream_e = SequenceInputStream.class.getDeclaredField("e"); - FIELD_SequenceInputStream_e.setAccessible(true); - } catch (NoSuchFieldException e) { - throw new GlutenException(e); - } - } - - private final InputStream in; - private final List segments; - - private int currentIndex = 0; + private final FileSegmentsInputStream fsin; private long bytesRead = 0L; + private long left; public LowCopyFileSegmentsJniByteInputStream(InputStream in) { - this.in = in; // to prevent underlying netty buffer from being collected by GC final InputStream unwrapped = JniByteInputStreams.unwrapSparkInputStream(in); - final SequenceInputStream sin = (SequenceInputStream) unwrapped; - final List streams = collectStreams(sin, false); - this.segments = buildSegments(streams); + this.fsin = (FileSegmentsInputStream) unwrapped; + left = this.fsin.remainingBytes(); } public static boolean isSupported(InputStream in) { - if (!(in instanceof SequenceInputStream)) { - return false; - } - final SequenceInputStream sin = (SequenceInputStream) in; - final List streams = collectStreams(sin, true); - for (InputStream stream : streams) { - final InputStream unwrapped = JniByteInputStreams.unwrapSparkInputStream(stream); - if (!LowCopyFileSegmentJniByteInputStream.isSupported(unwrapped)) { - return false; - } - } - return true; + return in instanceof FileSegmentsInputStream; } @Override public long read(long destAddress, long maxSize) { - if (maxSize <= 0) { + long bytesToRead = Math.min(left, maxSize); + if (bytesToRead == 0) { return 0; } - long remaining = maxSize; - long totalRead = 0L; - while (remaining > 0 && 
currentIndex < segments.size()) { - LowCopyFileSegmentJniByteInputStream segment = segments.get(currentIndex); - long request = remaining; - long bytes = segment.read(destAddress + totalRead, request); + try { + long bytes = fsin.read(destAddress, bytesToRead); if (bytes == 0) { - currentIndex++; - continue; + return 0; } bytesRead += bytes; - remaining -= bytes; - totalRead += bytes; + left -= bytes; + return bytes; + } catch (IOException e) { + throw new GlutenException(e); } - return totalRead; } @Override @@ -106,60 +70,9 @@ public long tell() { @Override public void close() { try { - for (LowCopyFileSegmentJniByteInputStream segment : segments) { - segment.close(); - } - in.close(); + fsin.close(); } catch (IOException e) { throw new GlutenException(e); } } - - private static List buildSegments( - List streams) { - List segments = new ArrayList<>(streams.size()); - for (InputStream stream : streams) { - segments.add(new LowCopyFileSegmentJniByteInputStream(stream)); - } - return segments; - } - - private static List collectStreams(SequenceInputStream sin, boolean restore) { - final List streams = new ArrayList<>(); - final InputStream current; - final Enumeration enumeration; - try { - current = (InputStream) FIELD_SequenceInputStream_in.get(sin); - enumeration = (Enumeration) FIELD_SequenceInputStream_e.get(sin); - } catch (IllegalAccessException e) { - throw new GlutenException(e); - } - if (current != null) { - streams.add(current); - } - if (enumeration != null) { - while (enumeration.hasMoreElements()) { - streams.add(enumeration.nextElement()); - } - } - - if (restore) { - // restore the enumeration in SequenceInputStream.e - try { - if (streams.isEmpty()) { - FIELD_SequenceInputStream_e.set(sin, Collections.enumeration(Collections.emptyList())); - } else { - if (streams.size() == 1) { - FIELD_SequenceInputStream_e.set(sin, Collections.enumeration(Collections.emptyList())); - } else { - FIELD_SequenceInputStream_e.set( - sin, 
Collections.enumeration(streams.subList(1, streams.size()))); - } - } - } catch (IllegalAccessException e) { - throw new GlutenException(e); - } - } - return streams; - } } diff --git a/gluten-arrow/src/test/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStreamTest.java b/gluten-arrow/src/test/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStreamTest.java index f7d2efbf78f4..a0d06a2e0c63 100644 --- a/gluten-arrow/src/test/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStreamTest.java +++ b/gluten-arrow/src/test/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStreamTest.java @@ -17,23 +17,22 @@ package org.apache.gluten.vectorized; import io.netty.util.internal.PlatformDependent; -import org.apache.spark.network.util.LimitedInputStream; +import org.apache.spark.shuffle.sort.FileSegmentsInputStream; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import java.io.File; -import java.io.FileInputStream; import java.io.FileOutputStream; -import java.io.InputStream; -import java.io.SequenceInputStream; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.util.Arrays; -import java.util.Collections; -import java.util.List; + +import scala.Tuple2; +import scala.collection.JavaConverters; +import scala.collection.Seq; public class LowCopyFileSegmentsJniByteInputStreamTest { private File tempFile; @@ -59,14 +58,17 @@ public void testReadAcrossSegments() throws Exception { int firstLen = 4; int secondLen = bytes.length - firstLen; - InputStream first = createLimited(tempFile, 0, firstLen); - InputStream second = createLimited(tempFile, firstLen, secondLen); - List streams = Arrays.asList(first, second); + Seq> segments = + toScalaSeq( + Arrays.asList( + new Tuple2<>(0L, (long) firstLen), + new Tuple2<>((long) firstLen, (long) secondLen))); - SequenceInputStream sin = new 
SequenceInputStream(Collections.enumeration(streams)); - Assert.assertTrue(LowCopyFileSegmentsJniByteInputStream.isSupported(sin)); + FileSegmentsInputStream segmentStream = new FileSegmentsInputStream(tempFile, segments); + Assert.assertTrue(LowCopyFileSegmentsJniByteInputStream.isSupported(segmentStream)); - LowCopyFileSegmentsJniByteInputStream in = new LowCopyFileSegmentsJniByteInputStream(sin); + LowCopyFileSegmentsJniByteInputStream in = + new LowCopyFileSegmentsJniByteInputStream(segmentStream); ByteBuffer buffer = PlatformDependent.allocateDirectNoCleaner(bytes.length); long addr = PlatformDependent.directBufferAddress(buffer); @@ -93,14 +95,14 @@ public void testReadNonContiguousSegments() throws Exception { } // Select two non-contiguous segments: [2,5) and [10,15) - InputStream seg1 = createLimited(tempFile, 2, 3); // "cde" - InputStream seg2 = createLimited(tempFile, 10, 5); // "12345" - List streams = Arrays.asList(seg1, seg2); + Seq> segments = + toScalaSeq(Arrays.asList(new Tuple2<>(2L, 3L), new Tuple2<>(10L, 5L))); - SequenceInputStream sin = new SequenceInputStream(Collections.enumeration(streams)); - Assert.assertTrue(LowCopyFileSegmentsJniByteInputStream.isSupported(sin)); + FileSegmentsInputStream segmentStream = new FileSegmentsInputStream(tempFile, segments); + Assert.assertTrue(LowCopyFileSegmentsJniByteInputStream.isSupported(segmentStream)); - LowCopyFileSegmentsJniByteInputStream in = new LowCopyFileSegmentsJniByteInputStream(sin); + LowCopyFileSegmentsJniByteInputStream in = + new LowCopyFileSegmentsJniByteInputStream(segmentStream); ByteBuffer buffer = PlatformDependent.allocateDirectNoCleaner(8); long addr = PlatformDependent.directBufferAddress(buffer); @@ -117,16 +119,8 @@ public void testReadNonContiguousSegments() throws Exception { in.close(); } - private static InputStream createLimited(File file, long offset, long length) throws Exception { - FileInputStream fin = new FileInputStream(file); - long skipped = 0; - while (skipped < 
offset) { - long step = fin.skip(offset - skipped); - if (step <= 0) { - throw new IllegalStateException("Unable to skip to offset " + offset); - } - skipped += step; - } - return new LimitedInputStream(fin, length); + private static Seq> toScalaSeq( + java.util.List> segments) { + return JavaConverters.asScalaBuffer(segments).toSeq(); } } diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsInputStream.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsInputStream.scala index 949c360bd985..2d79587c1149 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsInputStream.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsInputStream.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.shuffle.sort +import _root_.io.netty.util.internal.PlatformDependent + import java.io.{File, InputStream, IOException, RandomAccessFile} import java.nio.ByteBuffer @@ -42,6 +44,55 @@ class FileSegmentsInputStream(file: File, segments: Seq[(Long, Long)]) extends I channel.position(currentOffset) } + /** Returns the number of bytes remaining to be read from all segments. */ + def remainingBytes: Long = { + if (closed) return 0L + var total = remainingInSegment + var idx = currentIndex + 1 + while (idx < segments.length) { + total += segments(idx)._2 + idx += 1 + } + total + } + + /** + * Read bytes directly into native memory at the given address. 
+ * + * @param destAddress + * destination memory address + * @param maxSize + * max number of bytes to read + * @return + * number of bytes read, or 0 if no more data + */ + @throws[IOException] + def read(destAddress: Long, maxSize: Long): Long = { + if (closed) { + throw new IOException("Stream is closed") + } + + var totalRead = 0 + var remaining = maxSize + while (remaining > 0 && currentIndex < segments.length) { + if (remainingInSegment == 0) { + if (!advanceSegment()) { + return totalRead + } + } + val bytesThisRead = Math.min(remainingInSegment, remaining.toLong).toInt + val direct = PlatformDependent.directBuffer(destAddress + totalRead, bytesThisRead) + val n = channel.read(direct) + if (n == -1) { + return totalRead + } + remainingInSegment -= n + totalRead += n + remaining -= n + } + totalRead + } + override def read(): Int = { val buf = new Array[Byte](1) val n = read(buf, 0, 1) @@ -120,6 +171,7 @@ class FileSegmentsInputStream(file: File, segments: Seq[(Long, Long)]) extends I if (remainingInSegment > Int.MaxValue) Int.MaxValue else remainingInSegment.toInt } + @throws[IOException] override def close(): Unit = { if (!closed) { closed = true diff --git a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsInputStreamSuite.scala b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsInputStreamSuite.scala index 024a7f5e5556..47fe9fb4fedd 100644 --- a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsInputStreamSuite.scala +++ b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsInputStreamSuite.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.shuffle.sort +import _root_.io.netty.util.internal.PlatformDependent import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuite @@ -92,6 +93,24 @@ class FileSegmentsInputStreamSuite extends AnyFunSuite with BeforeAndAfterAll { in.close() } + test("read into native memory") { + val segments = Seq((2L, 
3L), (10L, 4L)) + val in = new FileSegmentsInputStream(tempFile, segments) + val expected = fileData.slice(2, 5) ++ fileData.slice(10, 14) + + val buffer = java.nio.ByteBuffer.allocateDirect(expected.length) + val addr = PlatformDependent.directBufferAddress(buffer) + val read = in.read(addr, expected.length) + assert(read == expected.length) + + buffer.limit(expected.length) + val out = new Array[Byte](expected.length) + buffer.get(out) + assert(out.sameElements(expected)) + assert(in.read(addr, 1) == 0) + in.close() + } + private def readAll(in: FileSegmentsInputStream, size: Int): Array[Byte] = { val buffer = new Array[Byte](size) var total = 0 From e8f1b5434a6f82c4a0380c20677a3ea2962c8b08 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Fri, 6 Feb 2026 14:14:20 +0800 Subject: [PATCH 16/20] Avoid frequently calling fp.Tell() --- cpp/core/shuffle/LocalPartitionWriter.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/cpp/core/shuffle/LocalPartitionWriter.cc b/cpp/core/shuffle/LocalPartitionWriter.cc index 82a1bd298925..8222a0ad4244 100644 --- a/cpp/core/shuffle/LocalPartitionWriter.cc +++ b/cpp/core/shuffle/LocalPartitionWriter.cc @@ -681,12 +681,11 @@ arrow::Status LocalPartitionWriter::flushCachedPlayloads() { ARROW_ASSIGN_OR_RAISE(int64_t endInDataFile, dataFileOs_->Tell()); for (auto pid = 0; pid < numPartitions_; ++pid) { auto startInDataFile = endInDataFile; - RETURN_NOT_OK(mergeSpills(pid, dataFileOs_.get())); - RETURN_NOT_OK(payloadCache_->writeIncremental(pid, dataFileOs_.get())); - ARROW_ASSIGN_OR_RAISE(endInDataFile, dataFileOs_->Tell()); - - auto bytesWritten = endInDataFile - startInDataFile; - if (bytesWritten > 0) { + ARROW_ASSIGN_OR_RAISE(int64_t spillWrittenBytes, mergeSpills(pid, dataFileOs_.get())); + ARROW_ASSIGN_OR_RAISE(bool cachePlayoadWritten, payloadCache_->writeIncremental(pid, dataFileOs_.get())); + if (spillWrittenBytes > 0 || cachePlayoadWritten) { + ARROW_ASSIGN_OR_RAISE(endInDataFile, 
dataFileOs_->Tell()); + auto bytesWritten = endInDataFile - startInDataFile; partitionSegments_[pid].emplace_back(startInDataFile, bytesWritten); partitionLengths_[pid] += bytesWritten; } From 9cac12785f1ccdc2816a7a3f3199f6951b6c4c9f Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Tue, 3 Mar 2026 13:12:36 +0800 Subject: [PATCH 17/20] LocalPartitionWriter: fix output format for sortEvict --- cpp/core/shuffle/LocalPartitionWriter.cc | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/cpp/core/shuffle/LocalPartitionWriter.cc b/cpp/core/shuffle/LocalPartitionWriter.cc index 8222a0ad4244..49c4be54311a 100644 --- a/cpp/core/shuffle/LocalPartitionWriter.cc +++ b/cpp/core/shuffle/LocalPartitionWriter.cc @@ -699,18 +699,16 @@ arrow::Status LocalPartitionWriter::writeMemoryPayload(uint32_t partitionId, std ARROW_ASSIGN_OR_RAISE(dataFileOs_, openFile(dataFile_, options_->shuffleFileBufferSize)); } - auto shouldCompress = codec_ != nullptr && payload->numRows() >= options_->compressionThreshold; - - ARROW_ASSIGN_OR_RAISE( - auto block, - payload->toBlockPayload(shouldCompress ? 
Payload::kToBeCompressed : Payload::kUncompressed, payloadPool_.get(), codec_.get())); - ARROW_ASSIGN_OR_RAISE(int64_t startOffset, dataFileOs_->Tell()); - uint8_t blockType = static_cast(BlockType::kPlainPayload); - RETURN_NOT_OK(dataFileOs_->Write(&blockType, sizeof(blockType))); - RETURN_NOT_OK(block->serialize(dataFileOs_.get())); - compressTime_ += block->getCompressTime(); - writeTime_ += block->getWriteTime(); + if (codec_ != nullptr) { + ARROW_ASSIGN_OR_RAISE(auto compressOs, ShuffleCompressedOutputStream::Make(codec_.get(), options_->compressionBufferSize, dataFileOs_, arrow::default_memory_pool())); + RETURN_NOT_OK(payload->serialize(compressOs.get())); + RETURN_NOT_OK(compressOs->Flush()); + compressTime_ += compressOs->compressTime(); + RETURN_NOT_OK(compressOs->Close()); + } else { + RETURN_NOT_OK(payload->serialize(dataFileOs_.get())); + } ARROW_ASSIGN_OR_RAISE(int64_t endOffset, dataFileOs_->Tell()); auto bytesWritten = endOffset - startOffset; partitionSegments_[partitionId].emplace_back(startOffset, bytesWritten); From 93b3f53d753eb8cbe8f609e59b3463e5072cfc03 Mon Sep 17 00:00:00 2001 From: Wangyang Guo Date: Thu, 5 Mar 2026 10:21:13 +0800 Subject: [PATCH 18/20] Various fixes --- .../spark/shuffle/ColumnarShuffleWriter.scala | 3 ++- cpp/core/shuffle/LocalPartitionWriter.cc | 15 ++++++++------- cpp/core/shuffle/LocalPartitionWriter.h | 2 +- .../sort/ColumnarIndexShuffleBlockResolver.scala | 2 +- .../shuffle/sort/DiscontiguousFileRegion.scala | 2 +- .../spark/shuffle/sort/FileSegmentsBuffer.scala | 6 ++++-- .../shuffle/sort/FileSegmentsInputStream.scala | 2 +- .../ColumnarIndexShuffleBlockResolverSuite.scala | 2 +- ...e.scala => DiscontiguousFileRegionSuite.scala} | 0 9 files changed, 19 insertions(+), 15 deletions(-) rename gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/{DiscontigousFileRegionSuite.scala => DiscontiguousFileRegionSuite.scala} (100%) diff --git 
a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala index 1ab23d9bf12f..50649a8c7777 100644 --- a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala +++ b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala @@ -58,12 +58,13 @@ class ColumnarShuffleWriter[K, V]( shuffleBlockResolver match { case resolver: ColumnarIndexShuffleBlockResolver => if (resolver.canUseNewFormat() && !GlutenConfig.get.columnarShuffleEnableDictionary) { - // For Dictionary encoding, the dict only finializes after all batches are processed, + // For Dictionary encoding, the dict only finalizes after all batches are processed, // and dict is required to saved at the head of the partition data. // So we cannot use multiple segments to save partition data incrementally. partitionUseMultipleSegments = true columnarShuffleBlockResolver = resolver } + case _ => } protected val isSort: Boolean = dep.shuffleWriterType == SortShuffleWriterType diff --git a/cpp/core/shuffle/LocalPartitionWriter.cc b/cpp/core/shuffle/LocalPartitionWriter.cc index 49c4be54311a..fe3fb096d3d0 100644 --- a/cpp/core/shuffle/LocalPartitionWriter.cc +++ b/cpp/core/shuffle/LocalPartitionWriter.cc @@ -30,7 +30,6 @@ #include #include #include -#include "LocalPartitionWriter.h" namespace gluten { @@ -671,7 +670,7 @@ arrow::Status LocalPartitionWriter::writeCachedPayloads(uint32_t partitionId, ar return arrow::Status::OK(); } -arrow::Status LocalPartitionWriter::flushCachedPlayloads() { +arrow::Status LocalPartitionWriter::flushCachedPayloads() { if (payloadCache_ == nullptr) { return arrow::Status::OK(); } @@ -682,15 +681,15 @@ arrow::Status LocalPartitionWriter::flushCachedPlayloads() { for (auto pid = 0; pid < numPartitions_; ++pid) { auto startInDataFile = endInDataFile; ARROW_ASSIGN_OR_RAISE(int64_t spillWrittenBytes, mergeSpills(pid, 
dataFileOs_.get())); - ARROW_ASSIGN_OR_RAISE(bool cachePlayoadWritten, payloadCache_->writeIncremental(pid, dataFileOs_.get())); - if (spillWrittenBytes > 0 || cachePlayoadWritten) { + ARROW_ASSIGN_OR_RAISE(bool cachePayloadWritten, payloadCache_->writeIncremental(pid, dataFileOs_.get())); + if (spillWrittenBytes > 0 || cachePayloadWritten) { ARROW_ASSIGN_OR_RAISE(endInDataFile, dataFileOs_->Tell()); auto bytesWritten = endInDataFile - startInDataFile; partitionSegments_[pid].emplace_back(startInDataFile, bytesWritten); partitionLengths_[pid] += bytesWritten; } } - spills_.clear(); + return arrow::Status::OK(); } @@ -726,7 +725,7 @@ arrow::Status LocalPartitionWriter::stop(ShuffleWriterMetrics* metrics, int64_t& if (usePartitionMultipleSegments_) { RETURN_NOT_OK(finishSpill()); RETURN_NOT_OK(finishMerger()); - RETURN_NOT_OK(flushCachedPlayloads()); + RETURN_NOT_OK(flushCachedPayloads()); RETURN_NOT_OK(writeIndexFile()); } else if (useSpillFileAsDataFile_) { ARROW_ASSIGN_OR_RAISE(auto spill, spiller_->finish()); @@ -872,7 +871,9 @@ arrow::Status LocalPartitionWriter::hashEvict( for (auto& payload : merged) { RETURN_NOT_OK(payloadCache_->cache(partitionId, std::move(payload))); } - RETURN_NOT_OK(flushCachedPlayloads()); + if (usePartitionMultipleSegments_) { + RETURN_NOT_OK(flushCachedPayloads()); + } merged.clear(); } return arrow::Status::OK(); diff --git a/cpp/core/shuffle/LocalPartitionWriter.h b/cpp/core/shuffle/LocalPartitionWriter.h index 6d3051d7e571..c44b0810b1f6 100644 --- a/cpp/core/shuffle/LocalPartitionWriter.h +++ b/cpp/core/shuffle/LocalPartitionWriter.h @@ -105,7 +105,7 @@ class LocalPartitionWriter : public PartitionWriter { arrow::Status writeCachedPayloads(uint32_t partitionId, arrow::io::OutputStream* os) const; - arrow::Status flushCachedPlayloads(); + arrow::Status flushCachedPayloads(); arrow::Status writeMemoryPayload(uint32_t partitionId, std::unique_ptr inMemoryPayload); diff --git 
a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala index 450ecf5bfa67..ebd2b3570ff3 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala @@ -59,7 +59,7 @@ class ColumnarIndexShuffleBlockResolver( startId: Int, endId: Int): Seq[(Long, Long)] = { // New Index Format: - // To support a partiton index with multiple segments, + // To support a partition index with multiple segments, // the index file is composed of three parts: // 1) Partition Index: index_0, index_1, ... index_N // Each index_i is an 8-byte long integer representing the byte offset in the file diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala index 2dbc3e85c3c5..06e863550d45 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.shuffle.sort; +package org.apache.spark.shuffle.sort // org.apache.spark.shuffle.sort has io objects conflicting with Netty's io objects. 
import _root_.io.netty.channel.FileRegion diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala index b2736e71b069..271c0d41eb29 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala @@ -38,14 +38,16 @@ class FileSegmentsManagedBuffer( override def release(): ManagedBuffer = this override def nioByteBuffer(): ByteBuffer = { - val buffer = ByteBuffer.allocate(size().toInt) + val totalSize = size() val channel = new RandomAccessFile(file, "r").getChannel try { - if (conf != null && size() >= conf.memoryMapBytes() && segments.length == 1) { + if (conf != null && totalSize >= conf.memoryMapBytes() && segments.length == 1) { // zero-copy for single segment that is large enough val (offset, length) = segments.head return channel.map(FileChannel.MapMode.READ_ONLY, offset, length) } + val totalSizeInt = Math.toIntExact(totalSize) + val buffer = ByteBuffer.allocate(totalSizeInt) var destPos = 0 segments.foreach { case (offset, length) => diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsInputStream.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsInputStream.scala index 2d79587c1149..41ef6ffa2b4a 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsInputStream.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsInputStream.scala @@ -72,7 +72,7 @@ class FileSegmentsInputStream(file: File, segments: Seq[(Long, Long)]) extends I throw new IOException("Stream is closed") } - var totalRead = 0 + var totalRead = 0L var remaining = maxSize while (remaining > 0 && currentIndex < segments.length) { if (remainingInSegment == 0) { diff --git 
a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala index fb6288d6739b..14c7e5da8109 100644 --- a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala +++ b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala @@ -162,7 +162,7 @@ class ColumnarIndexShuffleBlockResolverSuite val buffer = resolver.getBlockData(blockId) val bytes = readManagedBuffer(buffer) // Partition 2 segments: (10,20), (100,50) - assert(bytes.length == 70, "Expected 70 bytes, got ${bytes.length}") + assert(bytes.length == 70, s"Expected 70 bytes, got ${bytes.length}") val expected = ((10 until 30) ++ (100 until 150)).map(_.toByte).toArray assert( bytes.sameElements(expected), diff --git a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/DiscontigousFileRegionSuite.scala b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegionSuite.scala similarity index 100% rename from gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/DiscontigousFileRegionSuite.scala rename to gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegionSuite.scala From 5857915958ad6d411370cb3de97200e9fc333dbc Mon Sep 17 00:00:00 2001 From: Guo Wangyang Date: Fri, 6 Mar 2026 10:29:14 +0800 Subject: [PATCH 19/20] FileSegmentsManagedBuffer: empty segments should work --- .../sort/DiscontiguousFileRegion.scala | 8 ++- .../sort/FileSegmentsManagedBufferSuite.scala | 68 +++++++++++++++---- 2 files changed, 60 insertions(+), 16 deletions(-) diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala index 06e863550d45..660c16b5f45a 100644 --- 
a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala @@ -36,18 +36,22 @@ class DiscontiguousFileRegion( ) extends AbstractReferenceCounted with FileRegion { - require(segments.nonEmpty, "Must provide at least one segment") + // segments may be empty to represent a valid zero-length region (e.g., an empty shuffle block) private val totalCount: Long = segments.map(_._2).sum private var bytesTransferred: Long = 0L private var fileChannel: FileChannel = null private var closed: Boolean = false - if (!lazyOpen) { + if (!lazyOpen && totalCount > 0) { ensureOpen() } private def ensureOpen(): Unit = { + // For zero-length regions, there's nothing to read; don't open the file. + if (totalCount == 0L) { + return + } if (closed) { throw new IOException("File is already closed") } diff --git a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala index b171b4712bd5..cbda66f5024a 100644 --- a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala +++ b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala @@ -50,6 +50,20 @@ class FileSegmentsManagedBufferSuite extends AnyFunSuite with BeforeAndAfterAll override def close(): Unit = { open = false } } + private var tempFile: File = _ + private val fileData: Array[Byte] = (0 until 100).map(_.toByte).toArray + + override def beforeAll(): Unit = { + tempFile = Files.createTempFile("fsegments-test", ".bin").toFile + val fos = new FileOutputStream(tempFile) + fos.write(fileData) + fos.close() + } + + override def afterAll(): Unit = { + if (tempFile != null && tempFile.exists()) tempFile.delete() + } + test("convertToNetty returns FileRegion with correct count and 
content") { val conf = new FakeTransportConf() val segments = Seq((2L, 3L), (10L, 4L)) @@ -79,20 +93,6 @@ class FileSegmentsManagedBufferSuite extends AnyFunSuite with BeforeAndAfterAll assert(target.toByteArray.sameElements(fileData.slice(5, 8))) } - private var tempFile: File = _ - private val fileData: Array[Byte] = (0 until 100).map(_.toByte).toArray - - override def beforeAll(): Unit = { - tempFile = Files.createTempFile("fsegments-test", ".bin").toFile - val fos = new FileOutputStream(tempFile) - fos.write(fileData) - fos.close() - } - - override def afterAll(): Unit = { - if (tempFile != null && tempFile.exists()) tempFile.delete() - } - test("size returns sum of segment lengths") { val conf = null.asInstanceOf[TransportConf] val segments = Seq((2L, 10L), (20L, 5L), (50L, 15L)) @@ -100,6 +100,13 @@ class FileSegmentsManagedBufferSuite extends AnyFunSuite with BeforeAndAfterAll assert(buf.size() == 10L + 5L + 15L) } + test("empty segments returns zero size") { + val conf = null.asInstanceOf[TransportConf] + val segments = Seq.empty[(Long, Long)] + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + assert(buf.size() == 0L) + } + test("nioByteBuffer reads single segment correctly") { val conf = null.asInstanceOf[TransportConf] val segments = Seq((10L, 5L)) @@ -134,6 +141,15 @@ class FileSegmentsManagedBufferSuite extends AnyFunSuite with BeforeAndAfterAll assert(thrown.getMessage.contains("EOF reached while reading segment")) } + test("nioByteBuffer with empty segments returns empty buffer") { + val conf = null.asInstanceOf[TransportConf] + val segments = Seq.empty[(Long, Long)] + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + val bb = buf.nioByteBuffer() + assert(bb.remaining() == 0) + assert(!bb.hasRemaining) + } + test("nioByteBuffer with mmap for a single segment") { // Create a large file (16 KB) val largeData = (0 until 16384).map(_.toByte).toArray @@ -203,4 +219,28 @@ class FileSegmentsManagedBufferSuite extends 
AnyFunSuite with BeforeAndAfterAll fail(s"Unexpected exception: $t") } } + + test("createInputStream with empty segments returns EOF immediately") { + val conf = null.asInstanceOf[TransportConf] + val segments = Seq.empty[(Long, Long)] + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + val in = buf.createInputStream() + assert(in.read() == -1) + val out = new Array[Byte](8) + assert(in.read(out) == -1) + in.close() + } + + test("convertToNetty with empty segments has zero readable bytes") { + val conf = new FakeTransportConf() + val segments = Seq.empty[(Long, Long)] + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + val nettyObj = buf.convertToNetty() + assert(nettyObj.isInstanceOf[FileRegion]) + val region = nettyObj.asInstanceOf[FileRegion] + assert(region.count() == 0L) + val target = new ByteArrayWritableChannel() + assert(region.transferTo(target, 0) == 0L) + assert(target.toByteArray.isEmpty) + } } From ecde73fa2d3d3c3f4f5ea07de893949c9e141e02 Mon Sep 17 00:00:00 2001 From: Guo Wangyang Date: Fri, 6 Mar 2026 13:10:03 +0800 Subject: [PATCH 20/20] Fixes --- .../spark/shuffle/ColumnarShuffleWriter.scala | 2 +- cpp/core/shuffle/LocalPartitionWriter.cc | 8 ++- cpp/core/shuffle/LocalPartitionWriter.h | 2 +- ...opyFileSegmentsJniByteInputStreamTest.java | 54 ++++++++++--------- .../ColumnarIndexShuffleBlockResolver.scala | 27 ++++------ .../sort/DiscontiguousFileRegion.scala | 2 - ....scala => FileSegmentsManagedBuffer.scala} | 0 ...lumnarIndexShuffleBlockResolverSuite.scala | 13 ++--- .../sort/FileSegmentsManagedBufferSuite.scala | 1 - 9 files changed, 53 insertions(+), 56 deletions(-) rename gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/{FileSegmentsBuffer.scala => FileSegmentsManagedBuffer.scala} (100%) diff --git a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala index 
50649a8c7777..5bde6796c573 100644 --- a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala +++ b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala @@ -175,7 +175,7 @@ class ColumnarShuffleWriter[K, V]( blockManager.subDirsPerLocalDir, conf.get(SHUFFLE_FILE_BUFFER_SIZE).toInt, tempDataFile.getAbsolutePath, - if (tempIndexFile != null) tempIndexFile.getAbsolutePath else null, + if (tempIndexFile != null) tempIndexFile.getAbsolutePath else "", localDirs, GlutenConfig.get.columnarShuffleEnableDictionary ) diff --git a/cpp/core/shuffle/LocalPartitionWriter.cc b/cpp/core/shuffle/LocalPartitionWriter.cc index fe3fb096d3d0..b9dca98cc1c8 100644 --- a/cpp/core/shuffle/LocalPartitionWriter.cc +++ b/cpp/core/shuffle/LocalPartitionWriter.cc @@ -628,8 +628,9 @@ arrow::Status LocalPartitionWriter::writeIndexFile() { } } - // Write a ending char - RETURN_NOT_OK(indexFileOs->Write("", 1)); + // Write an ending marker byte with value 1 + const uint8_t marker = 1; + RETURN_NOT_OK(indexFileOs->Write(&marker, 1)); RETURN_NOT_OK(indexFileOs->Close()); return arrow::Status::OK(); } @@ -671,9 +672,6 @@ arrow::Status LocalPartitionWriter::writeCachedPayloads(uint32_t partitionId, ar } arrow::Status LocalPartitionWriter::flushCachedPayloads() { - if (payloadCache_ == nullptr) { - return arrow::Status::OK(); - } if (dataFileOs_ == nullptr) { ARROW_ASSIGN_OR_RAISE(dataFileOs_, openFile(dataFile_, options_->shuffleFileBufferSize)); } diff --git a/cpp/core/shuffle/LocalPartitionWriter.h b/cpp/core/shuffle/LocalPartitionWriter.h index c44b0810b1f6..37fc01015775 100644 --- a/cpp/core/shuffle/LocalPartitionWriter.h +++ b/cpp/core/shuffle/LocalPartitionWriter.h @@ -34,7 +34,7 @@ class LocalPartitionWriter : public PartitionWriter { const std::shared_ptr& options, const std::string& dataFile, std::vector localDirs, - const std::string& indexFile = nullptr); + const std::string& indexFile = ""); arrow::Status hashEvict( uint32_t 
partitionId, diff --git a/gluten-arrow/src/test/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStreamTest.java b/gluten-arrow/src/test/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStreamTest.java index a0d06a2e0c63..7d69075f1764 100644 --- a/gluten-arrow/src/test/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStreamTest.java +++ b/gluten-arrow/src/test/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStreamTest.java @@ -72,19 +72,22 @@ public void testReadAcrossSegments() throws Exception { ByteBuffer buffer = PlatformDependent.allocateDirectNoCleaner(bytes.length); long addr = PlatformDependent.directBufferAddress(buffer); - long firstRead = in.read(addr, 3); - long secondRead = in.read(addr + firstRead, bytes.length - firstRead); - long totalRead = firstRead + secondRead; - - Assert.assertEquals(bytes.length, totalRead); - Assert.assertEquals(bytes.length, in.tell()); - - buffer.limit(bytes.length); - byte[] out = new byte[bytes.length]; - buffer.get(out); - Assert.assertArrayEquals(bytes, out); - - in.close(); + try { + long firstRead = in.read(addr, 3); + long secondRead = in.read(addr + firstRead, bytes.length - firstRead); + long totalRead = firstRead + secondRead; + + Assert.assertEquals(bytes.length, totalRead); + Assert.assertEquals(bytes.length, in.tell()); + + buffer.limit(bytes.length); + byte[] out = new byte[bytes.length]; + buffer.get(out); + Assert.assertArrayEquals(bytes, out); + } finally { + PlatformDependent.freeDirectNoCleaner(buffer); + in.close(); + } } @Test @@ -106,17 +109,20 @@ public void testReadNonContiguousSegments() throws Exception { ByteBuffer buffer = PlatformDependent.allocateDirectNoCleaner(8); long addr = PlatformDependent.directBufferAddress(buffer); - long read = in.read(addr, 8); - Assert.assertEquals(8, read); - Assert.assertEquals(8, in.tell()); - - buffer.limit(8); - byte[] out = new byte[8]; - buffer.get(out); - // Expected: "cde12345" - 
Assert.assertArrayEquals("cde12345".getBytes(StandardCharsets.UTF_8), out); - - in.close(); + try { + long read = in.read(addr, 8); + Assert.assertEquals(8, read); + Assert.assertEquals(8, in.tell()); + + buffer.limit(8); + byte[] out = new byte[8]; + buffer.get(out); + // Expected: "cde12345" + Assert.assertArrayEquals("cde12345".getBytes(StandardCharsets.UTF_8), out); + } finally { + PlatformDependent.freeDirectNoCleaner(buffer); + in.close(); + } } private static Seq> toScalaSeq( diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala index ebd2b3570ff3..ea5ccac06d8f 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala @@ -23,12 +23,8 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.shuffle._ import org.apache.spark.storage._ -import java.io.DataInputStream import java.io.File -import java.nio.channels.Channels -import java.nio.channels.FileChannel -import java.nio.channels.SeekableByteChannel -import java.nio.file.StandardOpenOption +import java.io.RandomAccessFile class ColumnarIndexShuffleBlockResolver( conf: SparkConf, @@ -55,7 +51,7 @@ class ColumnarIndexShuffleBlockResolver( } private def getSegmentsFromIndex( - channel: SeekableByteChannel, + index: RandomAccessFile, startId: Int, endId: Int): Seq[(Long, Long)] = { // New Index Format: @@ -69,21 +65,20 @@ class ColumnarIndexShuffleBlockResolver( // offset is the byte offset in the data file, size is the length in bytes of this segment. // All segments for all partitions are stored sequentially after the partition index. // 3) One extra byte at the end to distinguish from old format. 
- channel.position(startId * 8L) - val in = new DataInputStream(Channels.newInputStream(channel)) - var startOffset = in.readLong() - channel.position(endId * 8L) - val endOffset = in.readLong() + index.seek(startId * 8L) + var startOffset = index.readLong() + index.seek(endId * 8L) + val endOffset = index.readLong() if (endOffset < startOffset || (endOffset - startOffset) % 16 != 0) { throw new IllegalStateException( s"Index file: Invalid index to segments ($startOffset, $endOffset)") } val segmentCount = (endOffset - startOffset) / 16 // Read segments - channel.position(startOffset) + index.seek(startOffset) val segments = for (i <- 0 until segmentCount.toInt) yield { - val offset = in.readLong() - val size = in.readLong() + val offset = index.readLong() + val size = index.readLong() (offset, size) } segments.filter(_._2 > 0) // filter out zero-size segments @@ -100,7 +95,7 @@ class ColumnarIndexShuffleBlockResolver( throw new IllegalStateException(s"Index file $indexFile is not in the new format") } - var index = FileChannel.open(indexFile.toPath, StandardOpenOption.READ) + var index = new RandomAccessFile(indexFile, "r") var dataFileSize = dataFile.length() try { for (i <- 0 until numPartitions) { @@ -154,7 +149,7 @@ class ColumnarIndexShuffleBlockResolver( return super.getBlockData(blockId, dirs) } - var index = FileChannel.open(indexFile.toPath, StandardOpenOption.READ) + var index = new RandomAccessFile(indexFile, "r") try { val segments = getSegmentsFromIndex(index, startReduceId, endReduceId) val dataFile = getDataFile(shuffleId, mapId, dirs) diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala index 660c16b5f45a..c25175634443 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala +++ 
b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala @@ -101,8 +101,6 @@ class DiscontiguousFileRegion( if (written == 0) { // Socket buffer full. Return what we have. return totalWritten - } else if (written == -1) { - throw new IOException("EOF encountered in underlying file") } // 4. Update state diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBuffer.scala similarity index 100% rename from gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsBuffer.scala rename to gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBuffer.scala diff --git a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala index 14c7e5da8109..d68f9c325b2c 100644 --- a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala +++ b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolverSuite.scala @@ -30,6 +30,7 @@ import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.scalatest.funsuite.AnyFunSuite import java.io.File +import java.io.RandomAccessFile import java.nio.ByteBuffer import java.nio.channels.FileChannel import java.nio.file.StandardOpenOption @@ -125,28 +126,28 @@ class ColumnarIndexShuffleBlockResolverSuite } test("getSegmentsFromIndex returns correct segments for complex index file") { - val channel = FileChannel.open(indexFile.toPath, StandardOpenOption.READ) + val index = new RandomAccessFile(indexFile, "r") try { val method = classOf[ColumnarIndexShuffleBlockResolver].getDeclaredMethod( "getSegmentsFromIndex", - classOf[java.nio.channels.SeekableByteChannel], + 
classOf[java.io.RandomAccessFile], classOf[Int], classOf[Int]) method.setAccessible(true) // Partition 0: 0 segments val segs0 = - method.invoke(resolver, channel, Int.box(0), Int.box(1)).asInstanceOf[Seq[(Long, Long)]] + method.invoke(resolver, index, Int.box(0), Int.box(1)).asInstanceOf[Seq[(Long, Long)]] assert(segs0.isEmpty) // Partition 1 - 2: 3 segment val segs1_2 = - method.invoke(resolver, channel, Int.box(1), Int.box(3)).asInstanceOf[Seq[(Long, Long)]] + method.invoke(resolver, index, Int.box(1), Int.box(3)).asInstanceOf[Seq[(Long, Long)]] assert(segs1_2 == Seq((0L, 10L), (10L, 20L), (100L, 50L))) // Partition 3: 2 segments, empty segment will be filter out val segs3 = - method.invoke(resolver, channel, Int.box(3), Int.box(4)).asInstanceOf[Seq[(Long, Long)]] + method.invoke(resolver, index, Int.box(3), Int.box(4)).asInstanceOf[Seq[(Long, Long)]] assert(segs3 == Seq((30L, 70L), (150L, 50L))) } finally { - channel.close() + index.close() } } diff --git a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala index cbda66f5024a..1ea3ebed8ae1 100644 --- a/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala +++ b/gluten-substrait/src/test/scala/org/apache/spark/shuffle/sort/FileSegmentsManagedBufferSuite.scala @@ -192,7 +192,6 @@ class FileSegmentsManagedBufferSuite extends AnyFunSuite with BeforeAndAfterAll while (n > 0 || total == 0) { n = in.read(out, total, out.length - total) if (n > 0) total += n - else if (total == 0) n = 1 // ensure loop runs at least once } assert(total == 9) assert(out.slice(0, 3).sameElements(fileData.slice(5, 8)))