diff --git a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala index ef9877011b5d..5bde6796c573 100644 --- a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala +++ b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala @@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{SHUFFLE_COMPRESS, SHUFFLE_DISK_WRITE_BUFFER_SIZE, SHUFFLE_FILE_BUFFER_SIZE, SHUFFLE_SORT_INIT_BUFFER_SIZE, SHUFFLE_SORT_USE_RADIXSORT} import org.apache.spark.memory.SparkMemoryUtil import org.apache.spark.scheduler.MapStatus +import org.apache.spark.shuffle.sort.ColumnarIndexShuffleBlockResolver import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.{SparkDirectoryUtil, SparkResourceUtil, Utils} @@ -42,6 +43,8 @@ class ColumnarShuffleWriter[K, V]( with Logging { private val dep = handle.dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]] + private var columnarShuffleBlockResolver: ColumnarIndexShuffleBlockResolver = _ + private var partitionUseMultipleSegments: Boolean = false dep.shuffleWriterType match { case HashShuffleWriterType | SortShuffleWriterType | GpuHashShuffleWriterType => @@ -52,6 +55,18 @@ class ColumnarShuffleWriter[K, V]( s"expected one of: ${HashShuffleWriterType.name}, ${SortShuffleWriterType.name}") } + shuffleBlockResolver match { + case resolver: ColumnarIndexShuffleBlockResolver => + if (resolver.canUseNewFormat() && !GlutenConfig.get.columnarShuffleEnableDictionary) { + // For Dictionary encoding, the dict only finalizes after all batches are processed, + // and the dict is required to be saved at the head of the partition data. + // So we cannot use multiple segments to save partition data incrementally.
+ partitionUseMultipleSegments = true + columnarShuffleBlockResolver = resolver + } + case _ => + } + protected val isSort: Boolean = dep.shuffleWriterType == SortShuffleWriterType private val numPartitions: Int = dep.partitioner.numPartitions @@ -136,6 +151,11 @@ class ColumnarShuffleWriter[K, V]( } val tempDataFile = Utils.tempFileWith(shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)) + val tempIndexFile = if (partitionUseMultipleSegments) { + Utils.tempFileWith(columnarShuffleBlockResolver.getIndexFile(dep.shuffleId, mapId)) + } else { + null + } while (records.hasNext) { val cb = records.next()._2.asInstanceOf[ColumnarBatch] @@ -155,6 +175,7 @@ class ColumnarShuffleWriter[K, V]( blockManager.subDirsPerLocalDir, conf.get(SHUFFLE_FILE_BUFFER_SIZE).toInt, tempDataFile.getAbsolutePath, + if (tempIndexFile != null) tempIndexFile.getAbsolutePath else "", localDirs, GlutenConfig.get.columnarShuffleEnableDictionary ) @@ -264,16 +285,29 @@ class ColumnarShuffleWriter[K, V]( partitionLengths = splitResult.getPartitionLengths try { - shuffleBlockResolver.writeMetadataFileAndCommit( - dep.shuffleId, - mapId, - partitionLengths, - Array[Long](), - tempDataFile) + if (partitionUseMultipleSegments) { + columnarShuffleBlockResolver.writeIndexFileAndCommit( + dep.shuffleId, + mapId, + tempDataFile, + tempIndexFile, + numPartitions, + None) + } else { + shuffleBlockResolver.writeMetadataFileAndCommit( + dep.shuffleId, + mapId, + partitionLengths, + Array[Long](), + tempDataFile) + } } finally { if (tempDataFile.exists() && !tempDataFile.delete()) { logError(s"Error while deleting temp file ${tempDataFile.getAbsolutePath}") } + if (tempIndexFile != null && tempIndexFile.exists() && !tempIndexFile.delete()) { + logError(s"Error while deleting temp file ${tempIndexFile.getAbsolutePath}") + } } // The partitionLength is much more than vanilla spark partitionLengths, diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc index c48886195b22..b24b56d92294 100644 --- 
a/cpp/core/jni/JniWrapper.cc +++ b/cpp/core/jni/JniWrapper.cc @@ -943,6 +943,7 @@ Java_org_apache_gluten_vectorized_LocalPartitionWriterJniWrapper_createPartition jint numSubDirs, jint shuffleFileBufferSize, jstring dataFileJstr, + jstring indexFileJstr, jstring localDirsJstr, jboolean enableDictionary) { JNI_METHOD_START @@ -950,6 +951,7 @@ Java_org_apache_gluten_vectorized_LocalPartitionWriterJniWrapper_createPartition const auto ctx = getRuntime(env, wrapper); auto dataFile = jStringToCString(env, dataFileJstr); + auto indexFile = jStringToCString(env, indexFileJstr); auto localDirs = splitPaths(jStringToCString(env, localDirsJstr)); auto partitionWriterOptions = std::make_shared( @@ -968,7 +970,8 @@ Java_org_apache_gluten_vectorized_LocalPartitionWriterJniWrapper_createPartition ctx->memoryManager(), partitionWriterOptions, dataFile, - std::move(localDirs)); + std::move(localDirs), + indexFile); return ctx->saveObject(partitionWriter); JNI_METHOD_END(kInvalidObjectHandle) diff --git a/cpp/core/shuffle/LocalPartitionWriter.cc b/cpp/core/shuffle/LocalPartitionWriter.cc index 948a2b0e05ef..b9dca98cc1c8 100644 --- a/cpp/core/shuffle/LocalPartitionWriter.cc +++ b/cpp/core/shuffle/LocalPartitionWriter.cc @@ -356,6 +356,25 @@ class LocalPartitionWriter::PayloadCache { return arrow::Status::OK(); } + arrow::Result writeIncremental(uint32_t partitionId, arrow::io::OutputStream *os) { + GLUTEN_CHECK(!enableDictionary_, "Incremental write is not supported when dictionary is enabled."); + if ((partitionInUse_.has_value() && partitionInUse_.value() == partitionId) || !hasCachedPayloads(partitionId)) { + return false; + } + + auto& payloads = partitionCachedPayload_[partitionId]; + while (!payloads.empty()) { + const auto payload = std::move(payloads.front()); + payloads.pop_front(); + uint8_t blockType = static_cast(BlockType::kPlainPayload); + RETURN_NOT_OK(os->Write(&blockType, sizeof(blockType))); + RETURN_NOT_OK(payload->serialize(os)); + compressTime_ += 
payload->getCompressTime(); + writeTime_ += payload->getWriteTime(); + } + return true; + } + bool canSpill() { for (auto pid = 0; pid < numPartitions_; ++pid) { if (partitionInUse_.has_value() && partitionInUse_.value() == pid) { @@ -506,10 +525,12 @@ LocalPartitionWriter::LocalPartitionWriter( MemoryManager* memoryManager, const std::shared_ptr& options, const std::string& dataFile, - std::vector localDirs) + std::vector localDirs, + const std::string& indexFile) : PartitionWriter(numPartitions, std::move(codec), memoryManager), options_(options), dataFile_(dataFile), + indexFile_(indexFile), localDirs_(std::move(localDirs)) { init(); } @@ -562,6 +583,56 @@ void LocalPartitionWriter::init() { std::default_random_engine engine(rd()); std::shuffle(localDirs_.begin(), localDirs_.end(), engine); subDirSelection_.assign(localDirs_.size(), 0); + + if (!indexFile_.empty()) { + usePartitionMultipleSegments_ = true; + partitionSegments_.resize(numPartitions_); + } +} + +// Helper for big-endian conversion (network order) +#include +static uint64_t htonll(uint64_t value) { +#if __BYTE_ORDER == __LITTLE_ENDIAN + return (((uint64_t)htonl(value & 0xFFFFFFFFULL)) << 32) | htonl(value >> 32); +#else + return value; +#endif +} + +arrow::Status LocalPartitionWriter::writeIndexFile() { + if (!usePartitionMultipleSegments_) { + return arrow::Status::OK(); + } + + ARROW_ASSIGN_OR_RAISE(auto indexFileOs, openFile(indexFile_, options_->shuffleFileBufferSize)); + + uint64_t segmentOffset = (numPartitions_ + 1) * sizeof(int64_t); + // write segment index of each partition in big-endian + for (uint32_t pid = 0; pid < numPartitions_; ++pid) { + uint64_t beOffset = htonll(segmentOffset); + RETURN_NOT_OK(indexFileOs->Write(reinterpret_cast(&beOffset), sizeof(beOffset))); + const auto& segments = partitionSegments_[pid]; + segmentOffset += (segments.size() * 2 * sizeof(int64_t)); + } + uint64_t beOffset = htonll(segmentOffset); + RETURN_NOT_OK(indexFileOs->Write(reinterpret_cast(&beOffset), 
sizeof(beOffset))); + // Write partition segments info in big-endian + for (uint32_t pid = 0; pid < numPartitions_; ++pid) { + const auto& segments = partitionSegments_[pid]; + for (const auto& segment : segments) { + uint64_t beFirst = htonll(segment.first); + uint64_t beSecond = htonll(segment.second); + RETURN_NOT_OK(indexFileOs->Write(reinterpret_cast(&beFirst), sizeof(beFirst))); + RETURN_NOT_OK(indexFileOs->Write(reinterpret_cast(&beSecond), sizeof(beSecond))); + } + } + + // Write an ending marker byte with value 1 + const uint8_t marker = 1; + RETURN_NOT_OK(indexFileOs->Write(&marker, 1)); + RETURN_NOT_OK(indexFileOs->Close()); + return arrow::Status::OK(); } arrow::Result LocalPartitionWriter::mergeSpills(uint32_t partitionId, arrow::io::OutputStream* os) { @@ -600,13 +671,61 @@ arrow::Status LocalPartitionWriter::writeCachedPayloads(uint32_t partitionId, ar return arrow::Status::OK(); } +arrow::Status LocalPartitionWriter::flushCachedPayloads() { + if (dataFileOs_ == nullptr) { + ARROW_ASSIGN_OR_RAISE(dataFileOs_, openFile(dataFile_, options_->shuffleFileBufferSize)); + } + ARROW_ASSIGN_OR_RAISE(int64_t endInDataFile, dataFileOs_->Tell()); + for (auto pid = 0; pid < numPartitions_; ++pid) { + auto startInDataFile = endInDataFile; + ARROW_ASSIGN_OR_RAISE(int64_t spillWrittenBytes, mergeSpills(pid, dataFileOs_.get())); + ARROW_ASSIGN_OR_RAISE(bool cachePayloadWritten, payloadCache_->writeIncremental(pid, dataFileOs_.get())); + if (spillWrittenBytes > 0 || cachePayloadWritten) { + ARROW_ASSIGN_OR_RAISE(endInDataFile, dataFileOs_->Tell()); + auto bytesWritten = endInDataFile - startInDataFile; + partitionSegments_[pid].emplace_back(startInDataFile, bytesWritten); + partitionLengths_[pid] += bytesWritten; + } + } + + return arrow::Status::OK(); +} + +arrow::Status LocalPartitionWriter::writeMemoryPayload(uint32_t partitionId, std::unique_ptr payload) { + if (dataFileOs_ == nullptr) { + ARROW_ASSIGN_OR_RAISE(dataFileOs_, openFile(dataFile_, 
options_->shuffleFileBufferSize)); + } + + ARROW_ASSIGN_OR_RAISE(int64_t startOffset, dataFileOs_->Tell()); + if (codec_ != nullptr) { + ARROW_ASSIGN_OR_RAISE(auto compressOs, ShuffleCompressedOutputStream::Make(codec_.get(), options_->compressionBufferSize, dataFileOs_, arrow::default_memory_pool())); + RETURN_NOT_OK(payload->serialize(compressOs.get())); + RETURN_NOT_OK(compressOs->Flush()); + compressTime_ += compressOs->compressTime(); + RETURN_NOT_OK(compressOs->Close()); + } else { + RETURN_NOT_OK(payload->serialize(dataFileOs_.get())); + } + ARROW_ASSIGN_OR_RAISE(int64_t endOffset, dataFileOs_->Tell()); + auto bytesWritten = endOffset - startOffset; + partitionSegments_[partitionId].emplace_back(startOffset, bytesWritten); + partitionLengths_[partitionId] += bytesWritten; + + return arrow::Status::OK(); +} + arrow::Status LocalPartitionWriter::stop(ShuffleWriterMetrics* metrics, int64_t& evictBytes) { if (stopped_) { return arrow::Status::OK(); } stopped_ = true; - if (useSpillFileAsDataFile_) { + if (usePartitionMultipleSegments_) { + RETURN_NOT_OK(finishSpill()); + RETURN_NOT_OK(finishMerger()); + RETURN_NOT_OK(flushCachedPayloads()); + RETURN_NOT_OK(writeIndexFile()); + } else if (useSpillFileAsDataFile_) { ARROW_ASSIGN_OR_RAISE(auto spill, spiller_->finish()); // Merge the remaining partitions from spills. @@ -750,6 +869,9 @@ arrow::Status LocalPartitionWriter::hashEvict( for (auto& payload : merged) { RETURN_NOT_OK(payloadCache_->cache(partitionId, std::move(payload))); } + if (usePartitionMultipleSegments_) { + RETURN_NOT_OK(flushCachedPayloads()); + } merged.clear(); } return arrow::Status::OK(); @@ -759,6 +881,12 @@ arrow::Status LocalPartitionWriter::sortEvict(uint32_t partitionId, std::unique_ptr inMemoryPayload, bool isFinal, int64_t& evictBytes) { rawPartitionLengths_[partitionId] += inMemoryPayload->rawSize(); + if (usePartitionMultipleSegments_) { + // If multiple segments per partition is enabled, write directly to the final data file. 
+ RETURN_NOT_OK(writeMemoryPayload(partitionId, std::move(inMemoryPayload))); + return arrow::Status::OK(); + } + if (lastEvictPid_ != -1 && (partitionId < lastEvictPid_ || (isFinal && !dataFileOs_))) { lastEvictPid_ = -1; RETURN_NOT_OK(finishSpill()); diff --git a/cpp/core/shuffle/LocalPartitionWriter.h b/cpp/core/shuffle/LocalPartitionWriter.h index 113e5a3cfd2f..37fc01015775 100644 --- a/cpp/core/shuffle/LocalPartitionWriter.h +++ b/cpp/core/shuffle/LocalPartitionWriter.h @@ -33,7 +33,8 @@ class LocalPartitionWriter : public PartitionWriter { MemoryManager* memoryManager, const std::shared_ptr& options, const std::string& dataFile, - std::vector localDirs); + std::vector localDirs, + const std::string& indexFile = ""); arrow::Status hashEvict( uint32_t partitionId, @@ -90,6 +91,8 @@ class LocalPartitionWriter : public PartitionWriter { void init(); + arrow::Status writeIndexFile(); + arrow::Status requestSpill(bool isFinal); arrow::Status finishSpill(); @@ -102,12 +105,17 @@ class LocalPartitionWriter : public PartitionWriter { arrow::Status writeCachedPayloads(uint32_t partitionId, arrow::io::OutputStream* os) const; + arrow::Status flushCachedPayloads(); + + arrow::Status writeMemoryPayload(uint32_t partitionId, std::unique_ptr inMemoryPayload); + arrow::Status clearResource(); arrow::Status populateMetrics(ShuffleWriterMetrics* metrics); std::shared_ptr options_; std::string dataFile_; + std::string indexFile_; std::vector localDirs_; bool stopped_{false}; @@ -128,6 +136,11 @@ class LocalPartitionWriter : public PartitionWriter { std::vector partitionLengths_; std::vector rawPartitionLengths_; + bool usePartitionMultipleSegments_{false}; + // For each partition, record all segments' (start, length) in the final data file. + // partitionSegments_[pid] = [(start1, length1), (start2, length2), ...] 
+ std::vector>> partitionSegments_{}; + int32_t lastEvictPid_{-1}; }; } // namespace gluten diff --git a/gluten-arrow/pom.xml b/gluten-arrow/pom.xml index a5552719b0f7..e83a710ce57f 100644 --- a/gluten-arrow/pom.xml +++ b/gluten-arrow/pom.xml @@ -34,6 +34,13 @@ ${project.version} compile + + org.apache.gluten + gluten-substrait + ${project.version} + test-jar + test + org.apache.gluten ${sparkshim.artifactId} diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/JniByteInputStreams.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/JniByteInputStreams.java index 0749d0cff3fd..0eb24ef9a21a 100644 --- a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/JniByteInputStreams.java +++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/JniByteInputStreams.java @@ -53,6 +53,9 @@ public static JniByteInputStream create(InputStream in) { if (LowCopyFileSegmentJniByteInputStream.isSupported(unwrapped)) { return new LowCopyFileSegmentJniByteInputStream(in); } + if (LowCopyFileSegmentsJniByteInputStream.isSupported(unwrapped)) { + return new LowCopyFileSegmentsJniByteInputStream(in); + } return new OnHeapJniByteInputStream(in); } diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/LocalPartitionWriterJniWrapper.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/LocalPartitionWriterJniWrapper.java index 3269e7b17610..f727a52ae29d 100644 --- a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/LocalPartitionWriterJniWrapper.java +++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/LocalPartitionWriterJniWrapper.java @@ -47,6 +47,7 @@ public native long createPartitionWriter( int subDirsPerLocalDir, int shuffleFileBufferSize, String dataFile, + String indexFile, String localDirs, boolean enableDictionary); } diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStream.java 
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStream.java new file mode 100644 index 000000000000..845e2f3d3f10 --- /dev/null +++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStream.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.vectorized; + +import org.apache.gluten.exception.GlutenException; + +import org.apache.spark.shuffle.sort.FileSegmentsInputStream; + +import java.io.IOException; +import java.io.InputStream; + +/** + * This implementation is targeted to optimize against Gluten's {@link + * org.apache.spark.shuffle.sort.FileSegmentsManagedBuffer} to make sure shuffle data is shared over + * JNI without unnecessary copy. 
+ */ +public class LowCopyFileSegmentsJniByteInputStream implements JniByteInputStream { + private final FileSegmentsInputStream fsin; + private long bytesRead = 0L; + private long left; + + public LowCopyFileSegmentsJniByteInputStream(InputStream in) { + final InputStream unwrapped = JniByteInputStreams.unwrapSparkInputStream(in); + this.fsin = (FileSegmentsInputStream) unwrapped; + left = this.fsin.remainingBytes(); + } + + public static boolean isSupported(InputStream in) { + return in instanceof FileSegmentsInputStream; + } + + @Override + public long read(long destAddress, long maxSize) { + long bytesToRead = Math.min(left, maxSize); + if (bytesToRead == 0) { + return 0; + } + try { + long bytes = fsin.read(destAddress, bytesToRead); + if (bytes == 0) { + return 0; + } + bytesRead += bytes; + left -= bytes; + return bytes; + } catch (IOException e) { + throw new GlutenException(e); + } + } + + @Override + public long tell() { + return bytesRead; + } + + @Override + public void close() { + try { + fsin.close(); + } catch (IOException e) { + throw new GlutenException(e); + } + } +} diff --git a/gluten-arrow/src/test/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStreamTest.java b/gluten-arrow/src/test/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStreamTest.java new file mode 100644 index 000000000000..7d69075f1764 --- /dev/null +++ b/gluten-arrow/src/test/java/org/apache/gluten/vectorized/LowCopyFileSegmentsJniByteInputStreamTest.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.vectorized; + +import io.netty.util.internal.PlatformDependent; +import org.apache.spark.shuffle.sort.FileSegmentsInputStream; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.io.FileOutputStream; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.util.Arrays; + +import scala.Tuple2; +import scala.collection.JavaConverters; +import scala.collection.Seq; + +public class LowCopyFileSegmentsJniByteInputStreamTest { + private File tempFile; + + @Before + public void setUp() throws Exception { + tempFile = Files.createTempFile("gluten-segments", ".bin").toFile(); + } + + @After + public void tearDown() throws Exception { + if (tempFile != null && tempFile.exists()) { + Files.deleteIfExists(tempFile.toPath()); + } + } + + @Test + public void testReadAcrossSegments() throws Exception { + byte[] bytes = "abcdefg-0123456".getBytes(StandardCharsets.UTF_8); + try (FileOutputStream out = new FileOutputStream(tempFile)) { + out.write(bytes); + } + + int firstLen = 4; + int secondLen = bytes.length - firstLen; + Seq> segments = + toScalaSeq( + Arrays.asList( + new Tuple2<>(0L, (long) firstLen), + new Tuple2<>((long) firstLen, (long) secondLen))); + + FileSegmentsInputStream segmentStream = new FileSegmentsInputStream(tempFile, segments); + Assert.assertTrue(LowCopyFileSegmentsJniByteInputStream.isSupported(segmentStream)); + + LowCopyFileSegmentsJniByteInputStream in = + new 
LowCopyFileSegmentsJniByteInputStream(segmentStream); + ByteBuffer buffer = PlatformDependent.allocateDirectNoCleaner(bytes.length); + long addr = PlatformDependent.directBufferAddress(buffer); + + try { + long firstRead = in.read(addr, 3); + long secondRead = in.read(addr + firstRead, bytes.length - firstRead); + long totalRead = firstRead + secondRead; + + Assert.assertEquals(bytes.length, totalRead); + Assert.assertEquals(bytes.length, in.tell()); + + buffer.limit(bytes.length); + byte[] out = new byte[bytes.length]; + buffer.get(out); + Assert.assertArrayEquals(bytes, out); + } finally { + PlatformDependent.freeDirectNoCleaner(buffer); + in.close(); + } + } + + @Test + public void testReadNonContiguousSegments() throws Exception { + byte[] bytes = "abcdefghij123456789".getBytes(StandardCharsets.UTF_8); + try (FileOutputStream out = new FileOutputStream(tempFile)) { + out.write(bytes); + } + + // Select two non-contiguous segments: [2,5) and [10,15) + Seq> segments = + toScalaSeq(Arrays.asList(new Tuple2<>(2L, 3L), new Tuple2<>(10L, 5L))); + + FileSegmentsInputStream segmentStream = new FileSegmentsInputStream(tempFile, segments); + Assert.assertTrue(LowCopyFileSegmentsJniByteInputStream.isSupported(segmentStream)); + + LowCopyFileSegmentsJniByteInputStream in = + new LowCopyFileSegmentsJniByteInputStream(segmentStream); + ByteBuffer buffer = PlatformDependent.allocateDirectNoCleaner(8); + long addr = PlatformDependent.directBufferAddress(buffer); + + try { + long read = in.read(addr, 8); + Assert.assertEquals(8, read); + Assert.assertEquals(8, in.tell()); + + buffer.limit(8); + byte[] out = new byte[8]; + buffer.get(out); + // Expected: "cde12345" + Assert.assertArrayEquals("cde12345".getBytes(StandardCharsets.UTF_8), out); + } finally { + PlatformDependent.freeDirectNoCleaner(buffer); + in.close(); + } + } + + private static Seq> toScalaSeq( + java.util.List> segments) { + return JavaConverters.asScalaBuffer(segments).toSeq(); + } +} diff --git 
a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala new file mode 100644 index 000000000000..ea5ccac06d8f --- /dev/null +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarIndexShuffleBlockResolver.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.shuffle.sort + +import org.apache.spark.SparkConf +import org.apache.spark.SparkException +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.shuffle._ +import org.apache.spark.storage._ + +import java.io.File +import java.io.RandomAccessFile + +class ColumnarIndexShuffleBlockResolver( + conf: SparkConf, + private var blockManager: BlockManager = null) + extends IndexShuffleBlockResolver(conf, blockManager) { + + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") + private val SHUFFLE_SERVICE_ENABLED = "spark.shuffle.service.enabled" + private val PUSH_BASED_SHUFFLE_ENABLED = "spark.shuffle.push.based.enabled" + private var newFormatEnabled: Boolean = _ + + // When external shuffle service or push-based shuffle is enabled, + // it may directly access the shuffle files without going through this resolver. + // So we cannot use the new index format, in order to maintain backward compatibility. + newFormatEnabled = !(conf.getBoolean(PUSH_BASED_SHUFFLE_ENABLED, false) || + conf.getBoolean(SHUFFLE_SERVICE_ENABLED, false)) + + def canUseNewFormat(): Boolean = newFormatEnabled + + private def isNewFormat(index: File): Boolean = { + // Simple heuristic to determine the new format by appending 1 extra byte. + // Old index file always has length multiple of 8 bytes. + if (index.length() % 8 == 1) true else false + } + + private def getSegmentsFromIndex( + index: RandomAccessFile, + startId: Int, + endId: Int): Seq[(Long, Long)] = { + // New Index Format: + // To support a partition index with multiple segments, + // the index file is composed of three parts: + // 1) Partition Index: index_0, index_1, ... index_N + // Each index_i is an 8-byte long integer representing the byte offset in the file + // For partition i, its segments are stored between index_i and index_(i+1).
+ // 2) Segment Data: [offset_0][size_0][offset_1][size_1]...[offset_N][size_N] + // Each segment is represented by two 8-byte long integers: offset and size. + // offset is the byte offset in the data file, size is the length in bytes of this segment. + // All segments for all partitions are stored sequentially after the partition index. + // 3) One extra byte at the end to distinguish from old format. + index.seek(startId * 8L) + var startOffset = index.readLong() + index.seek(endId * 8L) + val endOffset = index.readLong() + if (endOffset < startOffset || (endOffset - startOffset) % 16 != 0) { + throw new IllegalStateException( + s"Index file: Invalid index to segments ($startOffset, $endOffset)") + } + val segmentCount = (endOffset - startOffset) / 16 + // Read segments + index.seek(startOffset) + val segments = for (i <- 0 until segmentCount.toInt) yield { + val offset = index.readLong() + val size = index.readLong() + (offset, size) + } + segments.filter(_._2 > 0) // filter out zero-size segments + } + + private def checkIndexAndDataFile(indexFile: File, dataFile: File, numPartitions: Int): Unit = { + if (!indexFile.exists()) { + throw new IllegalStateException(s"Index file $indexFile does not exist") + } + if (!dataFile.exists()) { + throw new IllegalStateException(s"Data file $dataFile does not exist") + } + if (!isNewFormat(indexFile)) { + throw new IllegalStateException(s"Index file $indexFile is not in the new format") + } + + var index = new RandomAccessFile(indexFile, "r") + var dataFileSize = dataFile.length() + try { + for (i <- 0 until numPartitions) { + var segments = getSegmentsFromIndex(index, startId = i, endId = i + 1) + segments.foreach { + case (offset, size) => + if (offset < 0 || size < 0 || offset + size > dataFileSize) { + throw new IllegalStateException( + s"Index file $indexFile has invalid segment ($offset, $size) " + + s"for partition $i in data file of size $dataFileSize") + } + } + } + } finally { + index.close() + } + } + + def 
writeIndexFileAndCommit( + shuffleId: Int, + mapId: Long, + dataTmp: File, + indexTmp: File, + numPartitions: Int, + checksums: Option[Array[Long]] = None + ): Unit = { + val indexFile = getIndexFile(shuffleId, mapId) + val dataFile = getDataFile(shuffleId, mapId) + this.synchronized { + checkIndexAndDataFile(indexTmp, dataTmp, numPartitions) + dataTmp.renameTo(dataFile) + indexTmp.renameTo(indexFile) + } + } + + override def getBlockData(blockId: BlockId, dirs: Option[Array[String]]): ManagedBuffer = { + val (shuffleId, mapId, startReduceId, endReduceId) = blockId match { + case id: ShuffleBlockId => + (id.shuffleId, id.mapId, id.reduceId, id.reduceId + 1) + case batchId: ShuffleBlockBatchId => + (batchId.shuffleId, batchId.mapId, batchId.startReduceId, batchId.endReduceId) + case _ => + throw SparkException.internalError( + s"unexpected shuffle block id format: $blockId", + category = "SHUFFLE") + } + val indexFile = getIndexFile(shuffleId, mapId, dirs) + + if (!isNewFormat(indexFile)) { + // fallback to old implementation + return super.getBlockData(blockId, dirs) + } + + var index = new RandomAccessFile(indexFile, "r") + try { + val segments = getSegmentsFromIndex(index, startReduceId, endReduceId) + val dataFile = getDataFile(shuffleId, mapId, dirs) + new FileSegmentsManagedBuffer(transportConf, dataFile, segments) + } finally { + index.close() + } + } +} diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala index 904c2dff6ce7..c4dd0db04a1a 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala @@ -39,7 +39,7 @@ class ColumnarShuffleManager(conf: SparkConf) import ColumnarShuffleManager._ private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) - 
override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) + override val shuffleBlockResolver = new ColumnarIndexShuffleBlockResolver(conf) /** A mapping from shuffle ids to the number of mappers producing output for those shuffles. */ private[this] val taskIdMapsForShuffle = new ConcurrentHashMap[Int, OpenHashSet[Long]]() diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala new file mode 100644 index 000000000000..c25175634443 --- /dev/null +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/DiscontiguousFileRegion.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.sort + +// org.apache.spark.shuffle.sort has io objects conflicting with Netty's io objects. 
import _root_.io.netty.channel.FileRegion
import _root_.io.netty.util.AbstractReferenceCounted

import java.io.File
import java.io.IOException
import java.nio.channels.{FileChannel, WritableByteChannel}
import java.nio.file.StandardOpenOption

/**
 * A [[FileRegion]] that maps a continuous logical stream (0 to N) onto multiple discontiguous
 * physical file segments, so Netty can zero-copy-transfer scattered byte ranges of one file as
 * if they were a single contiguous block.
 *
 * @param file
 *   the backing file all segments refer to
 * @param segments
 *   (physical file offset, length) pairs in logical order; may be empty to represent a valid
 *   zero-length region (e.g., an empty shuffle block)
 * @param lazyOpen
 *   if true, delay opening the file until the first transfer
 */
class DiscontiguousFileRegion(
    private val file: File,
    private val segments: Seq[(Long, Long)],
    private val lazyOpen: Boolean = false
) extends AbstractReferenceCounted
  with FileRegion {

  private val totalCount: Long = segments.map(_._2).sum
  private var bytesTransferred: Long = 0L
  private var fileChannel: FileChannel = null
  private var closed: Boolean = false

  if (!lazyOpen && totalCount > 0) {
    ensureOpen()
  }

  private def ensureOpen(): Unit = {
    // For zero-length regions, there's nothing to read; don't open the file.
    if (totalCount == 0L) {
      return
    }
    if (closed) {
      throw new IOException("File is already closed")
    }
    if (fileChannel == null) {
      fileChannel = FileChannel.open(file.toPath, StandardOpenOption.READ)
    }
  }

  /**
   * Transfers data starting from the LOGICAL 'position'.
   *
   * @param target
   *   the socket/channel to write to
   * @param position
   *   the relative position (0 to totalCount) within this specific region where the transfer
   *   should begin
   * @return
   *   bytes written by this call; 0 when the target is full immediately or the position is at
   *   or past the end of the region
   */
  override def transferTo(target: WritableByteChannel, position: Long): Long = {
    var logicalPos = position
    var totalWritten = 0L

    ensureOpen()

    if (logicalPos >= totalCount) {
      return 0L
    }

    // Walk the segment list: skip segments entirely before 'position', then transfer from the
    // segment containing 'position' onward until the target stops accepting bytes.
    var segmentIndex = 0
    var currentSegmentBase = 0L // Logical start of the current segment

    while (segmentIndex < segments.length) {
      val (phyOffset, length) = segments(segmentIndex)
      val currentSegmentEnd = currentSegmentBase + length

      if (logicalPos < currentSegmentEnd) {
        // The request starts inside this segment: translate logical -> physical offsets.
        val offsetInSegment = logicalPos - currentSegmentBase
        val physicalPos = phyOffset + offsetInSegment
        val remainingInSegment = length - offsetInSegment

        // Zero-copy transfer of the remainder of this segment.
        val written = fileChannel.transferTo(physicalPos, remainingInSegment, target)

        if (written == 0) {
          // Socket buffer full. Return what we have so far.
          return totalWritten
        }

        bytesTransferred += written
        totalWritten += written
        logicalPos += written

        if (written == remainingInSegment) {
          // Finished this segment exactly; continue to the next one immediately instead of
          // returning to the Netty loop just to be called back for the next byte.
          currentSegmentBase += length
          segmentIndex += 1
        } else {
          // Partial write (socket likely full). Done for now.
          return totalWritten
        }
      } else {
        // Target position is further down. Skip this segment.
        currentSegmentBase += length
        segmentIndex += 1
      }
    }

    totalWritten
  }

  override def count(): Long = totalCount

  // Position within the REGION, always 0 for the start of the region object.
  override def position(): Long = 0

  override def transferred(): Long = bytesTransferred
  @deprecated override def transfered(): Long = bytesTransferred

  // --- Reference Counting ---
  override def touch(hint: Any): FileRegion = this
  override def touch(): FileRegion = this
  override def retain(): FileRegion = { super.retain(); this }
  override def retain(increment: Int): FileRegion = { super.retain(increment); this }

  override protected def deallocate(): Unit = {
    // BUGFIX: mark the region closed unconditionally. Previously 'closed' was only set when a
    // channel had actually been opened, so a lazily-opened region that was released before its
    // first use kept closed == false, and a later transferTo() would silently reopen the file
    // on a dead region instead of failing with IOException via ensureOpen().
    closed = true
    if (fileChannel != null) {
      fileChannel.close()
      fileChannel = null
    }
  }
}
package org.apache.spark.shuffle.sort

import _root_.io.netty.util.internal.PlatformDependent

import java.io.{File, InputStream, IOException, RandomAccessFile}
import java.nio.ByteBuffer

/**
 * An InputStream that reads a list of non-contiguous file segments sequentially, presenting
 * them to callers as one continuous stream.
 *
 * @param file
 *   the file to read
 * @param segments
 *   a list of (file_offset, size) segments to read in sequence
 */
class FileSegmentsInputStream(file: File, segments: Seq[(Long, Long)]) extends InputStream {
  private val raf = new RandomAccessFile(file, "r")
  private val channel = raf.getChannel

  private var currentIndex = 0
  private var currentOffset = 0L
  private var remainingInSegment = 0L
  private var closed = false

  // Position the channel at the start of the first segment, if any.
  if (segments.nonEmpty) {
    currentOffset = segments.head._1
    remainingInSegment = segments.head._2
    channel.position(currentOffset)
  }

  /** Returns the number of bytes remaining to be read from all segments. */
  def remainingBytes: Long = {
    if (closed) return 0L
    var total = remainingInSegment
    var idx = currentIndex + 1
    while (idx < segments.length) {
      total += segments(idx)._2
      idx += 1
    }
    total
  }

  /**
   * Read bytes directly into native memory at the given address.
   *
   * @param destAddress
   *   destination memory address
   * @param maxSize
   *   max number of bytes to read
   * @return
   *   number of bytes read, or 0 if no more data
   */
  @throws[IOException]
  def read(destAddress: Long, maxSize: Long): Long = {
    if (closed) {
      throw new IOException("Stream is closed")
    }

    var totalRead = 0L
    var remaining = maxSize
    while (remaining > 0 && currentIndex < segments.length) {
      if (remainingInSegment == 0) {
        if (!advanceSegment()) {
          return totalRead
        }
      }
      // BUGFIX: clamp the per-read size to Int.MaxValue before narrowing. Previously
      // Math.min(remainingInSegment, remaining).toInt could overflow to a negative size when
      // both values exceed 2 GiB, producing a corrupt wrapped direct buffer.
      val bytesThisRead =
        Math.min(Math.min(remainingInSegment, remaining), Int.MaxValue.toLong).toInt
      val direct = PlatformDependent.directBuffer(destAddress + totalRead, bytesThisRead)
      val n = channel.read(direct)
      if (n == -1) {
        // EOF in the underlying file; report what was read so far (0 means no more data).
        return totalRead
      }
      remainingInSegment -= n
      totalRead += n
      remaining -= n
    }
    totalRead
  }

  override def read(): Int = {
    val buf = new Array[Byte](1)
    val n = read(buf, 0, 1)
    if (n <= 0) {
      -1
    } else {
      buf(0) & 0xff
    }
  }

  override def read(b: Array[Byte], off: Int, len: Int): Int = {
    if (closed) {
      throw new IOException("Stream is closed")
    }
    if (b == null) {
      throw new NullPointerException("buffer is null")
    }
    // BUGFIX: overflow-safe bounds check. 'off + len > b.length' can wrap around Int.MaxValue
    // and wrongly pass; 'len > b.length - off' cannot (off >= 0 is checked first).
    if (off < 0 || len < 0 || len > b.length - off) {
      throw new IndexOutOfBoundsException(s"offset=$off, length=$len, buffer length=${b.length}")
    }
    if (len == 0) {
      return 0
    }

    var totalRead = 0
    var remaining = len
    while (remaining > 0 && currentIndex < segments.length) {
      if (remainingInSegment == 0) {
        if (!advanceSegment()) {
          if (totalRead == 0) return -1
          return totalRead
        }
      }
      val bytesToRead = Math.min(remainingInSegment, remaining.toLong).toInt
      val buf = ByteBuffer.wrap(b, off + totalRead, bytesToRead)
      val n = channel.read(buf)
      if (n == -1) {
        if (totalRead == 0) return -1
        return totalRead
      }
      remainingInSegment -= n
      totalRead += n
      remaining -= n
    }

    // Per InputStream contract: -1 only when no byte was read and the stream is exhausted.
    if (totalRead == 0 && currentIndex >= segments.length) {
      -1
    } else {
      totalRead
    }
  }

  override def skip(n: Long): Long = {
    if (closed) {
      throw new IOException("Stream is closed")
    }
    if (n <= 0) {
      return 0L
    }
    var remaining = n
    var skipped = 0L
    while (remaining > 0 && currentIndex < segments.length) {
      if (remainingInSegment == 0 && !advanceSegment()) {
        return skipped
      }
      val toSkip = Math.min(remainingInSegment, remaining)
      // Skipping within a segment is just moving the channel position forward.
      channel.position(channel.position() + toSkip)
      remainingInSegment -= toSkip
      remaining -= toSkip
      skipped += toSkip
    }
    skipped
  }

  override def available(): Int = {
    // Reports only the current segment's remainder (a valid lower-bound estimate).
    if (remainingInSegment > Int.MaxValue) Int.MaxValue else remainingInSegment.toInt
  }

  @throws[IOException]
  override def close(): Unit = {
    if (!closed) {
      closed = true
      channel.close()
      raf.close()
    }
  }

  /** Moves to the next segment; returns false when all segments are exhausted. */
  private def advanceSegment(): Boolean = {
    currentIndex += 1
    if (currentIndex >= segments.length) {
      return false
    }
    val (offset, length) = segments(currentIndex)
    currentOffset = offset
    remainingInSegment = length
    channel.position(currentOffset)
    true
  }
}
package org.apache.spark.shuffle.sort

import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.network.util.TransportConf

import java.io.{EOFException, File, InputStream, RandomAccessFile}
import java.nio.ByteBuffer
import java.nio.channels.FileChannel

/** A {@link ManagedBuffer} backed by a set of discontiguous segments of a single file. */
class FileSegmentsManagedBuffer(
    conf: TransportConf,
    file: File,
    segments: Seq[(Long, Long)]
) extends ManagedBuffer {

  // The logical size is the sum of all segment lengths.
  private val totalCount: Long = segments.map(_._2).sum

  override def size(): Long = totalCount

  // Lifetime is tied to the underlying file, so reference counting is a no-op here.
  override def retain(): ManagedBuffer = this
  override def release(): ManagedBuffer = this

  override def nioByteBuffer(): ByteBuffer = {
    val fileChannel = new RandomAccessFile(file, "r").getChannel
    try {
      val mmapEligible =
        conf != null && totalCount >= conf.memoryMapBytes() && segments.length == 1
      if (mmapEligible) {
        // Zero-copy path: a single segment that is large enough gets memory-mapped.
        val (offset, length) = segments.head
        fileChannel.map(FileChannel.MapMode.READ_ONLY, offset, length)
      } else {
        // Copy path: read every segment, in order, into one heap buffer.
        val result = ByteBuffer.allocate(Math.toIntExact(totalCount))
        var destPos = 0
        for ((offset, length) <- segments) {
          result.position(destPos)
          result.limit(destPos + length.toInt)
          fileChannel.position(offset)
          var left = length
          while (left > 0) {
            val bytesRead = fileChannel.read(result)
            if (bytesRead == -1) {
              throw new EOFException(s"EOF reached while reading segment at offset $offset")
            }
            left -= bytesRead
          }
          destPos += length.toInt
        }
        result.flip()
        result
      }
    } finally {
      JavaUtils.closeQuietly(fileChannel)
    }
  }

  override def createInputStream(): InputStream =
    new FileSegmentsInputStream(file, segments)

  override def convertToNetty(): AnyRef = {
    // Honor the transport's lazy-FD setting when handing the region to Netty.
    val lazyOpen = if (conf != null) conf.lazyFileDescriptor() else false
    new DiscontiguousFileRegion(file, segments, lazyOpen)
  }
}
package org.apache.spark.shuffle.sort

import org.apache.spark.SparkConf
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.storage._
import org.apache.spark.util.Utils

import org.mockito.{Mock, MockitoAnnotations}
import org.mockito.Answers.RETURNS_SMART_NULLS
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import org.scalatest.funsuite.AnyFunSuite

import java.io.File
import java.io.RandomAccessFile
import java.nio.ByteBuffer
import java.nio.channels.FileChannel
import java.nio.file.StandardOpenOption

/**
 * Tests for ColumnarIndexShuffleBlockResolver: index parsing, block-data retrieval over
 * discontiguous segments, and new-format eligibility checks.
 */
class ColumnarIndexShuffleBlockResolverSuite
  extends AnyFunSuite
  with BeforeAndAfterAll
  with BeforeAndAfterEach {

  @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
  @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _

  private val conf: SparkConf = new SparkConf(loadDefaults = false)
  private var tempDir: File = _
  private val appId = "TESTAPP"

  private var indexFile: File = _
  private var dataFile: File = _
  private var resolver: ColumnarIndexShuffleBlockResolver = _

  override def beforeAll(): Unit = {
    indexFile = createIndexFile()
    dataFile = createDataFile()
  }

  /** Builds an index file describing 4 partitions with 0, 1, 2 and 3 segments respectively. */
  private def createIndexFile(): File = {
    val file = File.createTempFile("index", ".bin")
    val channel = FileChannel.open(file.toPath, StandardOpenOption.WRITE, StandardOpenOption.READ)
    // Partition index: 4 partitions, so 5 offsets.
    // Offsets: [indexEnd, seg1Start, seg2Start, seg3Start, segEnd].
    // The offset table itself is 5 * 8 = 40 bytes, so segment records start at 40.
    val offsets = Array(40L, 40L, 56L, 88L, 136L)
    val offsetsBuf = ByteBuffer.allocate(5 * 8)
    offsets.foreach(offsetsBuf.putLong)
    offsetsBuf.flip()
    channel.write(offsetsBuf)
    // Segment records: 6 segments total (1 + 2 + 3), each an (offset, length) pair over a
    // 200-byte data file.
    val segmentsBuf = ByteBuffer.allocate(6 * 16)
    // Partition 1:
    segmentsBuf.putLong(0L); segmentsBuf.putLong(10L)
    // Partition 2:
    segmentsBuf.putLong(10L); segmentsBuf.putLong(20L)
    segmentsBuf.putLong(100L); segmentsBuf.putLong(50L)
    // Partition 3 (includes an empty trailing segment):
    segmentsBuf.putLong(30L); segmentsBuf.putLong(70L)
    segmentsBuf.putLong(150L); segmentsBuf.putLong(50L)
    segmentsBuf.putLong(200L); segmentsBuf.putLong(0L)
    segmentsBuf.flip()
    channel.write(segmentsBuf)
    // One extra trailing byte marks the new index format.
    channel.write(ByteBuffer.wrap(Array(1.toByte)))
    channel.close()
    file
  }

  /** Builds a 200-byte data file containing bytes 0..199. */
  private def createDataFile(): File = {
    val file = File.createTempFile("data", ".bin")
    val out = new java.io.FileOutputStream(file)
    out.write((0 until 200).map(_.toByte).toArray)
    out.close()
    file
  }

  override def beforeEach(): Unit = {
    tempDir = Utils.createTempDir()
    MockitoAnnotations.initMocks(this)

    // Route every block-manager file lookup into the per-test temp directory.
    when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
    when(diskBlockManager.getFile(any[BlockId])).thenAnswer(
      (invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString))
    when(diskBlockManager.getFile(any[String])).thenAnswer(
      (invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString))
    when(diskBlockManager.getMergedShuffleFile(any[BlockId], any[Option[Array[String]]]))
      .thenAnswer(
        (invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString))
    when(diskBlockManager.localDirs).thenReturn(Array(tempDir))
    when(diskBlockManager.createTempFileWith(any(classOf[File])))
      .thenAnswer {
        invocation =>
          val target = invocation.getArguments()(0).asInstanceOf[File]
          Utils.tempFileWith(target)
      }
    conf.set("spark.app.id", appId)

    resolver = new ColumnarIndexShuffleBlockResolver(conf, blockManager)
  }

  override def afterEach(): Unit = {
    resolver = null
    Utils.deleteRecursively(tempDir)
  }

  override def afterAll(): Unit = {
    if (indexFile != null) indexFile.delete()
    if (dataFile != null) dataFile.delete()
  }

  test("getSegmentsFromIndex returns correct segments for complex index file") {
    val index = new RandomAccessFile(indexFile, "r")
    try {
      // The method under test is private, so reach it via reflection.
      val getSegments = classOf[ColumnarIndexShuffleBlockResolver].getDeclaredMethod(
        "getSegmentsFromIndex",
        classOf[java.io.RandomAccessFile],
        classOf[Int],
        classOf[Int])
      getSegments.setAccessible(true)
      // Partition 0 has no segments.
      val partition0 =
        getSegments.invoke(resolver, index, Int.box(0), Int.box(1)).asInstanceOf[Seq[(Long, Long)]]
      assert(partition0.isEmpty)
      // Partitions 1-2 combined span three segments.
      val partitions1To2 =
        getSegments.invoke(resolver, index, Int.box(1), Int.box(3)).asInstanceOf[Seq[(Long, Long)]]
      assert(partitions1To2 == Seq((0L, 10L), (10L, 20L), (100L, 50L)))
      // Partition 3 has two non-empty segments; its zero-length segment is filtered out.
      val partition3 =
        getSegments.invoke(resolver, index, Int.box(3), Int.box(4)).asInstanceOf[Seq[(Long, Long)]]
      assert(partition3 == Seq((30L, 70L), (150L, 50L)))
    } finally {
      index.close()
    }
  }

  test("getBlockData returns correct data for partition segments") {
    val shuffleId = 1
    val mapId = 0L
    val partitionId = 2
    val blockId = ShuffleBlockId(shuffleId, mapId, partitionId)
    // Commit the prepared index and data files first.
    resolver.writeIndexFileAndCommit(shuffleId, mapId, dataFile, indexFile, numPartitions = 4)

    // getBlockData should stitch partition 2's segments (10,20) and (100,50) together.
    val buffer = resolver.getBlockData(blockId)
    val bytes = readManagedBuffer(buffer)
    assert(bytes.length == 70, s"Expected 70 bytes, got ${bytes.length}")
    val expected = ((10 until 30) ++ (100 until 150)).map(_.toByte).toArray
    assert(
      bytes.sameElements(expected),
      s"Expected ${expected.mkString(",")}, got ${bytes.mkString(",")}")
  }

  /** Drains a ManagedBuffer's input stream into a byte array. */
  private def readManagedBuffer(buffer: ManagedBuffer): Array[Byte] = {
    val in = buffer.createInputStream()
    val collected = new scala.collection.mutable.ArrayBuffer[Byte]()
    val chunk = new Array[Byte](4096)
    var n = in.read(chunk)
    while (n != -1) {
      collected ++= chunk.take(n)
      n = in.read(chunk)
    }
    in.close()
    collected.toArray
  }

  test("canUseNewFormat returns correct value based on config") {
    // NOTE(review): these keys are passed through verbatim. Spark's push-based shuffle flag is
    // normally "spark.shuffle.push.enabled"; confirm the resolver reads this exact key.
    val bothOff = new SparkConf()
      .set("spark.shuffle.push.based.enabled", "false")
      .set("spark.shuffle.service.enabled", "false")
    assert(
      new ColumnarIndexShuffleBlockResolver(bothOff).canUseNewFormat(),
      "Should use new format when both configs are false")

    val pushOn = new SparkConf()
      .set("spark.shuffle.push.based.enabled", "true")
      .set("spark.shuffle.service.enabled", "false")
    assert(
      !new ColumnarIndexShuffleBlockResolver(pushOn).canUseNewFormat(),
      "Should not use new format when push-based shuffle is enabled")

    val serviceOn = new SparkConf()
      .set("spark.shuffle.push.based.enabled", "false")
      .set("spark.shuffle.service.enabled", "true")
    assert(
      !new ColumnarIndexShuffleBlockResolver(serviceOn).canUseNewFormat(),
      "Should not use new format when shuffle service is enabled")

    val bothOn = new SparkConf()
      .set("spark.shuffle.push.based.enabled", "true")
      .set("spark.shuffle.service.enabled", "true")
    assert(
      !new ColumnarIndexShuffleBlockResolver(bothOn).canUseNewFormat(),
      "Should not use new format when both are enabled")
  }
}
package org.apache.spark.shuffle.sort

import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite

import java.io.{ByteArrayOutputStream, File, FileOutputStream}
import java.nio.ByteBuffer
import java.nio.channels.WritableByteChannel
import java.nio.charset.StandardCharsets

// A mock socket: a WritableByteChannel that captures every byte written to it.
class ByteArrayWritableChannel extends WritableByteChannel {
  private val captured = new ByteArrayOutputStream()
  private var open = true

  def toByteArray: Array[Byte] = captured.toByteArray

  override def write(src: ByteBuffer): Int = {
    val n = src.remaining()
    val chunk = new Array[Byte](n)
    src.get(chunk)
    captured.write(chunk)
    n
  }

  override def isOpen: Boolean = open
  override def close(): Unit = { open = false }
}

class DiscontiguousFileRegionSuite extends AnyFunSuite with BeforeAndAfterAll {
  var tempFile: File = _

  // 26 bytes, one per letter, so file offsets map directly to letters (A=0 ... Z=25).
  val fileContent = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

  override def beforeAll(): Unit = {
    tempFile = File.createTempFile("netty-test", ".dat")
    val out = new FileOutputStream(tempFile)
    out.write(fileContent.getBytes(StandardCharsets.UTF_8))
    out.close()
  }

  override def afterAll(): Unit = {
    if (tempFile != null) tempFile.delete()
  }

  test("transfer a single segment correctly") {
    // Segment "BCD": offset 1, length 3.
    val region = new DiscontiguousFileRegion(tempFile, Seq((1L, 3L)))
    val sink = new ByteArrayWritableChannel()

    val written = region.transferTo(sink, 0)

    assert(written == 3, "Should have written exactly 3 bytes")
    assert(new String(sink.toByteArray) == "BCD", "Content should match the first 3 letters")
  }

  test("transferTo works with lazyOpen=true") {
    val region = new DiscontiguousFileRegion(tempFile, Seq((2L, 4L)), lazyOpen = true)
    val sink = new ByteArrayWritableChannel()
    val written = region.transferTo(sink, 0)
    assert(written == 4)
    assert(new String(sink.toByteArray) == "CDEF")
  }

  test("concatenate multiple discontiguous segments") {
    // "AB" (0, 2) followed by "YZ" (24, 2).
    val region = new DiscontiguousFileRegion(tempFile, Seq((0L, 2L), (24L, 2L)))
    val sink = new ByteArrayWritableChannel()

    val written = region.transferTo(sink, 0)

    assert(written == 4)
    assert(new String(sink.toByteArray) == "ABYZ")
  }

  test("handle the 'position' parameter correctly (Mid-segment start)") {
    // Logical view "ABC" + "XYZ" = "ABCXYZ"; starting at logical index 2 ('C') should yield
    // the rest of segment 1 and then all of segment 2.
    val region = new DiscontiguousFileRegion(tempFile, Seq((0L, 3L), (23L, 3L)))
    val sink = new ByteArrayWritableChannel()

    val written = region.transferTo(sink, 2)

    assert(written == 4)
    assert(new String(sink.toByteArray) == "CXYZ")
  }

  test("handle starting exactly at the boundary of the second segment") {
    // Logical view "AB" + "YZ"; position 2 is exactly the start of the second segment.
    val region = new DiscontiguousFileRegion(tempFile, Seq((0L, 2L), (24L, 2L)))
    val sink = new ByteArrayWritableChannel()

    val written = region.transferTo(sink, 2)

    assert(written == 2)
    assert(new String(sink.toByteArray) == "YZ")
  }

  test("handle a position that skips multiple initial segments") {
    // Four one-byte segments "A", "B", "C", "D"; position 3 leaves only "D".
    val region = new DiscontiguousFileRegion(
      tempFile,
      Seq(
        (0L, 1L),
        (1L, 1L),
        (2L, 1L),
        (3L, 1L)
      ))
    val sink = new ByteArrayWritableChannel()

    val written = region.transferTo(sink, 3)

    assert(written == 1)
    assert(new String(sink.toByteArray) == "D")
  }

  test("return 0 if position is beyond total size") {
    val region = new DiscontiguousFileRegion(tempFile, Seq((0L, 5L)))
    val sink = new ByteArrayWritableChannel()

    val written = region.transferTo(sink, 100)

    assert(written == 0)
    assert(sink.toByteArray.length == 0)
  }
}
package org.apache.spark.shuffle.sort

import _root_.io.netty.util.internal.PlatformDependent
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite

import java.io.{File, FileOutputStream}
import java.nio.file.Files

class FileSegmentsInputStreamSuite extends AnyFunSuite with BeforeAndAfterAll {
  private var tempFile: File = _
  // 100 bytes where value == index, so offsets are directly checkable.
  private val fileData: Array[Byte] = (0 until 100).map(_.toByte).toArray

  override def beforeAll(): Unit = {
    tempFile = Files.createTempFile("fsegments-inputstream", ".bin").toFile
    val out = new FileOutputStream(tempFile)
    out.write(fileData)
    out.close()
  }

  override def afterAll(): Unit = {
    if (tempFile != null && tempFile.exists()) tempFile.delete()
  }

  test("read single segment") {
    val stream = new FileSegmentsInputStream(tempFile, Seq((10L, 5L)))
    val actual = readAll(stream, 5)
    assert(actual.sameElements(fileData.slice(10, 15)))
    assert(stream.read() == -1)
    stream.close()
  }

  test("read multiple non-contiguous segments") {
    val stream = new FileSegmentsInputStream(tempFile, Seq((2L, 3L), (10L, 4L), (50L, 2L)))
    val actual = readAll(stream, 9)
    val expected = fileData.slice(2, 5) ++ fileData.slice(10, 14) ++ fileData.slice(50, 52)
    assert(actual.sameElements(expected))
    assert(stream.read() == -1)
    stream.close()
  }

  test("read single bytes sequentially") {
    val stream = new FileSegmentsInputStream(tempFile, Seq((20L, 3L)))
    assert(stream.read() == (fileData(20) & 0xff))
    assert(stream.read() == (fileData(21) & 0xff))
    assert(stream.read() == (fileData(22) & 0xff))
    assert(stream.read() == -1)
    stream.close()
  }

  test("skip bytes across segments") {
    val stream = new FileSegmentsInputStream(tempFile, Seq((10L, 3L), (20L, 4L)))
    // Skip 2 bytes within the first segment; the next byte is fileData(12).
    assert(stream.skip(2) == 2)
    assert(stream.read() == (fileData(12) & 0xff))
    // Skip 2 more, which crosses into the second segment; the next byte is fileData(22).
    assert(stream.skip(2) == 2)
    assert(stream.read() == (fileData(22) & 0xff))
    // Skipping past the end only consumes the single remaining byte.
    assert(stream.skip(100) == 1)
    assert(stream.read() == -1)
    stream.close()
  }

  test("read into native memory") {
    val stream = new FileSegmentsInputStream(tempFile, Seq((2L, 3L), (10L, 4L)))
    val expected = fileData.slice(2, 5) ++ fileData.slice(10, 14)

    val buffer = java.nio.ByteBuffer.allocateDirect(expected.length)
    val addr = PlatformDependent.directBufferAddress(buffer)
    assert(stream.read(addr, expected.length) == expected.length)

    buffer.limit(expected.length)
    val actual = new Array[Byte](expected.length)
    buffer.get(actual)
    assert(actual.sameElements(expected))
    // A native-memory read past the end of all segments returns 0.
    assert(stream.read(addr, 1) == 0)
    stream.close()
  }

  /** Reads exactly 'size' bytes (or until EOF) from the stream into a fresh array. */
  private def readAll(in: FileSegmentsInputStream, size: Int): Array[Byte] = {
    val buffer = new Array[Byte](size)
    var total = 0
    var n = 0
    while (n >= 0 && total < size) {
      n = in.read(buffer, total, size - total)
      if (n > 0) {
        total += n
      }
    }
    buffer
  }
}
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.sort + +import org.apache.spark.network.util.TransportConf + +import _root_.io.netty.channel.FileRegion +import org.scalatest.BeforeAndAfterAll +import org.scalatest.funsuite.AnyFunSuite + +import java.io._ +import java.nio.channels.WritableByteChannel +import java.nio.file.Files + +class FakeTransportConf(private val lazyOpen: Boolean = false) extends TransportConf("test", null) { + override def memoryMapBytes(): Int = 4096 // 4 KB + override def lazyFileDescriptor(): Boolean = lazyOpen +} + +class FileSegmentsManagedBufferSuite extends AnyFunSuite with BeforeAndAfterAll { + private class ByteArrayWritableChannel extends WritableByteChannel { + private val out = new ByteArrayOutputStream() + private var open = true + + def toByteArray: Array[Byte] = out.toByteArray + + override def write(src: java.nio.ByteBuffer): Int = { + val size = src.remaining() + val buf = new Array[Byte](size) + src.get(buf) + out.write(buf) + size + } + + override def isOpen: Boolean = open + override def close(): Unit = { open = false } + } + + private var tempFile: File = _ + private val fileData: Array[Byte] = (0 until 100).map(_.toByte).toArray + + override def beforeAll(): Unit = { + tempFile = Files.createTempFile("fsegments-test", ".bin").toFile + val fos = new FileOutputStream(tempFile) + fos.write(fileData) + fos.close() + } + + override def afterAll(): Unit = { + if 
(tempFile != null && tempFile.exists()) tempFile.delete() + } + + test("convertToNetty returns FileRegion with correct count and content") { + val conf = new FakeTransportConf() + val segments = Seq((2L, 3L), (10L, 4L)) + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + val nettyObj = buf.convertToNetty() + assert(nettyObj.isInstanceOf[FileRegion]) + val region = nettyObj.asInstanceOf[FileRegion] + assert(region.count() == 7L) + assert(region.position() == 0L) + val target = new ByteArrayWritableChannel() + val written = region.transferTo(target, 0) + assert(written == 7L) + val expected = fileData.slice(2, 5) ++ fileData.slice(10, 14) + assert(target.toByteArray.sameElements(expected)) + } + + test("convertToNetty supports lazyOpen via FileRegion transfer") { + val conf = new FakeTransportConf(lazyOpen = true) + val segments = Seq((5L, 3L)) + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + val nettyObj = buf.convertToNetty() + assert(nettyObj.isInstanceOf[FileRegion]) + val region = nettyObj.asInstanceOf[FileRegion] + val target = new ByteArrayWritableChannel() + val written = region.transferTo(target, 0) + assert(written == 3L) + assert(target.toByteArray.sameElements(fileData.slice(5, 8))) + } + + test("size returns sum of segment lengths") { + val conf = null.asInstanceOf[TransportConf] + val segments = Seq((2L, 10L), (20L, 5L), (50L, 15L)) + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + assert(buf.size() == 10L + 5L + 15L) + } + + test("empty segments returns zero size") { + val conf = null.asInstanceOf[TransportConf] + val segments = Seq.empty[(Long, Long)] + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + assert(buf.size() == 0L) + } + + test("nioByteBuffer reads single segment correctly") { + val conf = null.asInstanceOf[TransportConf] + val segments = Seq((10L, 5L)) + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + val bb = buf.nioByteBuffer() + val arr = 
new Array[Byte](5) + bb.get(arr) + assert(arr.sameElements(fileData.slice(10, 15))) + assert(!bb.hasRemaining) + } + + test("nioByteBuffer reads multiple segments in sequence") { + val conf = null.asInstanceOf[TransportConf] + val segments = Seq((5L, 3L), (20L, 2L), (40L, 4L)) + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + val bb = buf.nioByteBuffer() + val arr = new Array[Byte](9) + bb.get(arr) + assert(arr.slice(0, 3).sameElements(fileData.slice(5, 8))) + assert(arr.slice(3, 5).sameElements(fileData.slice(20, 22))) + assert(arr.slice(5, 9).sameElements(fileData.slice(40, 44))) + assert(!bb.hasRemaining) + } + + test("nioByteBuffer throws EOFException if segment exceeds file length") { + val conf = null.asInstanceOf[TransportConf] + val segments = Seq((95L, 10L)) // goes past EOF + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + val thrown = intercept[EOFException] { + buf.nioByteBuffer() + } + assert(thrown.getMessage.contains("EOF reached while reading segment")) + } + + test("nioByteBuffer with empty segments returns empty buffer") { + val conf = null.asInstanceOf[TransportConf] + val segments = Seq.empty[(Long, Long)] + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + val bb = buf.nioByteBuffer() + assert(bb.remaining() == 0) + assert(!bb.hasRemaining) + } + + test("nioByteBuffer with mmap for a single segment") { + // Create a large file (16 KB) + val largeData = (0 until 16384).map(_.toByte).toArray + val largeFile = Files.createTempFile("fsegments-mmap-test", ".bin").toFile + val fos = new FileOutputStream(largeFile) + fos.write(largeData) + fos.close() + + val conf = new FakeTransportConf() + val segment = (0L, 16384L) + val buf = new FileSegmentsManagedBuffer(conf, largeFile, Seq(segment)) + val nioBuf = buf.nioByteBuffer() + val arr = new Array[Byte](16384) + nioBuf.get(arr) + assert(arr.sameElements(largeData)) + + largeFile.delete() + } + + test("createInputStream reads single segment 
correctly") { + val conf = null.asInstanceOf[TransportConf] // Not used in this test + val segments = Seq((10L, 5L)) + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + val in = buf.createInputStream() + val read = new Array[Byte](5) + assert(in.read(read) == 5) + assert(read.sameElements(fileData.slice(10, 15))) + assert(in.read() == -1) + in.close() + } + + test("createInputStream reads multiple segments in sequence") { + val conf = null.asInstanceOf[TransportConf] + val segments = Seq((5L, 3L), (20L, 2L), (40L, 4L)) + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + val in = buf.createInputStream() + val out = new Array[Byte](9) + var total = 0 + var n = 0 + while (n > 0 || total == 0) { + n = in.read(out, total, out.length - total) + if (n > 0) total += n + } + assert(total == 9) + assert(out.slice(0, 3).sameElements(fileData.slice(5, 8))) + assert(out.slice(3, 5).sameElements(fileData.slice(20, 22))) + assert(out.slice(5, 9).sameElements(fileData.slice(40, 44))) + assert(in.read() == -1) + in.close() + } + + test("createInputStream read nothing if segment exceeds file length") { + val conf = null.asInstanceOf[TransportConf] + val segments = Seq((105L, 10L)) // goes past EOF + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + // should raise EOFException or read nothing + try { + val in = buf.createInputStream() + val read = new Array[Byte](10) + assert(in.read(read) == -1) + in.close() + } catch { + case e: EOFException => + assert(e.getMessage.contains("reached end of stream")) + case t: Throwable => + fail(s"Unexpected exception: $t") + } + } + + test("createInputStream with empty segments returns EOF immediately") { + val conf = null.asInstanceOf[TransportConf] + val segments = Seq.empty[(Long, Long)] + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + val in = buf.createInputStream() + assert(in.read() == -1) + val out = new Array[Byte](8) + assert(in.read(out) == -1) + in.close() + } 
+ + test("convertToNetty with empty segments has zero readable bytes") { + val conf = new FakeTransportConf() + val segments = Seq.empty[(Long, Long)] + val buf = new FileSegmentsManagedBuffer(conf, tempFile, segments) + val nettyObj = buf.convertToNetty() + assert(nettyObj.isInstanceOf[FileRegion]) + val region = nettyObj.asInstanceOf[FileRegion] + assert(region.count() == 0L) + val target = new ByteArrayWritableChannel() + assert(region.transferTo(target, 0) == 0L) + assert(target.toByteArray.isEmpty) + } +}