From 1b8f069722401329d66583170c31321af0e6cd5b Mon Sep 17 00:00:00 2001 From: Liyun Xiu Date: Thu, 19 Mar 2026 15:35:21 +0800 Subject: [PATCH] Fix IDMapIndex read/write segment fault & fix ExactMatch score no scaling back Signed-off-by: Liyun Xiu --- nsparse/id_map_index.cpp | 4 +- nsparse/io/index_io.cpp | 33 ++-- nsparse/io/index_io.h | 7 +- nsparse/seismic_scalar_quantized_index.cpp | 9 +- tests/index_io_test.cpp | 152 ++++++++++++++++++ tests/seismic_scalar_quantized_index_test.cpp | 58 +++++++ 6 files changed, 250 insertions(+), 13 deletions(-) diff --git a/nsparse/id_map_index.cpp b/nsparse/id_map_index.cpp index 28bbe26..e821423 100644 --- a/nsparse/id_map_index.cpp +++ b/nsparse/id_map_index.cpp @@ -75,7 +75,7 @@ void IDMapIndex::add_with_ids(idx_t n, const idx_t* indptr, } } void IDMapIndex::write_index(IOWriter* io_writer) { - nsparse::write_index(delegate_, io_writer); + nsparse::detail::write_index(delegate_, io_writer, true); // Write internal_to_external_ vector size_t map_size = internal_to_external_.size(); @@ -86,7 +86,7 @@ void IDMapIndex::write_index(IOWriter* io_writer) { } void IDMapIndex::read_index(IOReader* io_reader) { - delegate_ = nsparse::read_index(io_reader); + delegate_ = nsparse::detail::read_index(io_reader, true); // Read internal_to_external_ vector size_t map_size = 0; diff --git a/nsparse/io/index_io.cpp b/nsparse/io/index_io.cpp index f9d9c42..75b4776 100644 --- a/nsparse/io/index_io.cpp +++ b/nsparse/io/index_io.cpp @@ -57,7 +57,9 @@ Index* read_header(IOReader* io_reader) { } } } // namespace -void write_index(Index* index, IOWriter* io_writer) { + +namespace detail { +void write_index(Index* index, IOWriter* io_writer, bool keep_open) { auto* index_io = dynamic_cast(index); if (index_io == nullptr) { throw std::runtime_error("Index does not support serialization"); @@ -66,24 +68,37 @@ void write_index(Index* index, IOWriter* io_writer) { write_header(index, io_writer); // write index customized payload 
index_io->write_index(io_writer); - io_writer->close(); -} - -void write_index(Index* index, char* filename) { - FileIOWriter writer(filename); - write_index(index, &writer); + if (!keep_open) { + io_writer->close(); + } } -Index* read_index(IOReader* io_reader) { +Index* read_index(IOReader* io_reader, bool keep_open) { Index* index = read_header(io_reader); auto* index_io = dynamic_cast(index); if (index_io == nullptr) { throw std::runtime_error("Index does not support serialization"); } index_io->read_index(io_reader); - io_reader->close(); + if (!keep_open) { + io_reader->close(); + } return index; } +} // namespace detail + +void write_index(Index* index, IOWriter* io_writer) { + detail::write_index(index, io_writer, false); +} + +void write_index(Index* index, char* filename) { + FileIOWriter writer(filename); + write_index(index, &writer); +} + +Index* read_index(IOReader* io_reader) { + return detail::read_index(io_reader, false); +} Index* read_index(char* filename) { FileIOReader reader(filename); diff --git a/nsparse/io/index_io.h b/nsparse/io/index_io.h index ea96bbe..e08a43d 100644 --- a/nsparse/io/index_io.h +++ b/nsparse/io/index_io.h @@ -14,9 +14,14 @@ #include "nsparse/io/io.h" namespace nsparse { +namespace detail { +void write_index(Index* index, IOWriter* io_writer, bool keep_open); +Index* read_index(IOReader* io_reader, bool keep_open); +} // namespace detail + void write_index(Index* index, char* filename); -Index* read_index(char* filename); void write_index(Index* index, IOWriter* io_writer); +Index* read_index(char* filename); Index* read_index(IOReader* io_reader); } // namespace nsparse diff --git a/nsparse/seismic_scalar_quantized_index.cpp b/nsparse/seismic_scalar_quantized_index.cpp index 73cdd2e..41b1111 100644 --- a/nsparse/seismic_scalar_quantized_index.cpp +++ b/nsparse/seismic_scalar_quantized_index.cpp @@ -208,11 +208,18 @@ auto SeismicScalarQuantizedIndex::search(idx_t n, const idx_t* indptr, // if filter ids size is <= k, just 
run exact match if (detail::should_run_exact_match(search_parameters->get_id_selector(), k, &query_vectors)) { - return detail::ExactMatcher::search( + auto [distances, labels] = detail::ExactMatcher::search( vectors_.get(), dynamic_cast( search_parameters->get_id_selector()), &query_vectors, element_size, k); + // Decode quantized dot product scores + for (auto& query_distances : distances) { + for (auto& dist : query_distances) { + dist = sq_.decode_dot_product(dist, query_sq); + } + } + return {distances, labels}; } std::vector> result_distances(n); diff --git a/tests/index_io_test.cpp b/tests/index_io_test.cpp index 840058e..fa70be1 100644 --- a/tests/index_io_test.cpp +++ b/tests/index_io_test.cpp @@ -19,6 +19,7 @@ #include "nsparse/brutal_index.h" #include "nsparse/id_map_index.h" #include "nsparse/index.h" +#include "nsparse/inverted_index.h" #include "nsparse/io/buffered_io.h" #include "nsparse/io/io.h" #include "nsparse/seismic_index.h" @@ -28,6 +29,49 @@ namespace { +// IOWriter that throws if write() is called after close(). +// Simulates real file I/O where writing to a closed stream is invalid. +class StrictBufferedIOWriter : public nsparse::IOWriter { +public: + void write(void* ptr, size_t size, size_t nitems) override { + if (closed_) { + throw std::runtime_error( + "Write after close: stream already closed"); + } + delegate_.write(ptr, size, nitems); + } + + void close() override { closed_ = true; } + + const std::vector& data() const { return delegate_.data(); } + size_t size() const { return delegate_.size(); } + +private: + nsparse::BufferedIOWriter delegate_; + bool closed_ = false; +}; + +// IOReader that throws if read() is called after close(). +// Simulates real file I/O where reading from a closed stream is invalid. 
+class StrictBufferedIOReader : public nsparse::IOReader { +public: + explicit StrictBufferedIOReader(const std::vector& data) + : delegate_(data) {} + + size_t read(void* ptr, size_t size, size_t nitems) override { + if (closed_) { + throw std::runtime_error("Read after close: stream already closed"); + } + return delegate_.read(ptr, size, nitems); + } + + void close() override { closed_ = true; } + +private: + nsparse::BufferedIOReader delegate_; + bool closed_ = false; +}; + // Mock Index that implements both Index and IndexIO for testing class MockIndex : public nsparse::Index, public nsparse::IndexIO { public: @@ -302,3 +346,111 @@ TEST(IndexIO, RoundtripIDMapIndexWithData) { delete original; delete loaded; } + +// Verify StrictBufferedIOWriter throws on write after close +TEST(IndexIO, StrictWriterThrowsAfterClose) { + StrictBufferedIOWriter writer; + int val = 42; + writer.write(&val, sizeof(int), 1); + writer.close(); + ASSERT_THROW(writer.write(&val, sizeof(int), 1), std::runtime_error); +} + +// Verify StrictBufferedIOReader throws on read after close +TEST(IndexIO, StrictReaderThrowsAfterClose) { + std::vector buf(sizeof(int), 0); + StrictBufferedIOReader reader(buf); + int val = 0; + reader.read(&val, sizeof(int), 1); + reader.close(); + ASSERT_THROW(reader.read(&val, sizeof(int), 1), std::runtime_error); +} + +// IDMapIndex wrapping SeismicIndex: write then read with strict IO. +// Before the keep_open fix, write_index closed the stream before +// IDMapIndex could write its id map, causing a write-after-close crash. 
+TEST(IndexIO, StrictIO_RoundtripIDMapSeismicIndex) { + auto* seismic = new nsparse::SeismicIndex(128); + auto* original = new nsparse::IDMapIndex(seismic); + + std::vector indptr = {0, 2, 4}; + std::vector indices = {0, 1, 2, 3}; + std::vector values = {1.0F, 0.5F, 0.8F, 0.3F}; + std::vector ids = {100, 200}; + original->add_with_ids(2, indptr.data(), indices.data(), values.data(), + ids.data()); + + StrictBufferedIOWriter writer; + ASSERT_NO_THROW(nsparse::write_index(original, &writer)); + + StrictBufferedIOReader reader(writer.data()); + nsparse::Index* loaded = nullptr; + ASSERT_NO_THROW(loaded = nsparse::read_index(&reader)); + + ASSERT_NE(loaded, nullptr); + ASSERT_EQ(loaded->id(), original->id()); + ASSERT_EQ(loaded->num_vectors(), 2); + + delete original; + delete loaded; +} + +// IDMapIndex wrapping InvertedIndex: the original segfault scenario. +// InvertedIndex has a non-trivial write_index/read_index, so the +// stream must stay open for IDMapIndex to write/read its id map after +// the delegate is serialized. +// Note: InvertedIndex::get_vectors() returns nullptr after build() +// (vectors_ is consumed to create inverted_lists_), so num_vectors() +// returns 0. We verify the roundtrip by checking the index type instead. 
+TEST(IndexIO, StrictIO_RoundtripIDMapInvertedIndex) { + auto* inverted = new nsparse::InvertedIndex(128); + auto* original = new nsparse::IDMapIndex(inverted); + + std::vector indptr = {0, 2, 4}; + std::vector indices = {0, 1, 2, 3}; + std::vector values = {1.0F, 0.5F, 0.8F, 0.3F}; + std::vector ids = {100, 200}; + original->add_with_ids(2, indptr.data(), indices.data(), values.data(), + ids.data()); + original->build(); + + StrictBufferedIOWriter writer; + ASSERT_NO_THROW(nsparse::write_index(original, &writer)); + + StrictBufferedIOReader reader(writer.data()); + nsparse::Index* loaded = nullptr; + ASSERT_NO_THROW(loaded = nsparse::read_index(&reader)); + + ASSERT_NE(loaded, nullptr); + ASSERT_EQ(loaded->id(), original->id()); + + delete original; + delete loaded; +} + +// IDMapIndex wrapping SeismicScalarQuantizedIndex with strict IO +TEST(IndexIO, StrictIO_RoundtripIDMapSeismicSQIndex) { + auto* sq_index = new nsparse::SeismicScalarQuantizedIndex(128); + auto* original = new nsparse::IDMapIndex(sq_index); + + std::vector indptr = {0, 2, 4}; + std::vector indices = {0, 1, 2, 3}; + std::vector values = {1.0F, 0.5F, 0.8F, 0.3F}; + std::vector ids = {100, 200}; + original->add_with_ids(2, indptr.data(), indices.data(), values.data(), + ids.data()); + + StrictBufferedIOWriter writer; + ASSERT_NO_THROW(nsparse::write_index(original, &writer)); + + StrictBufferedIOReader reader(writer.data()); + nsparse::Index* loaded = nullptr; + ASSERT_NO_THROW(loaded = nsparse::read_index(&reader)); + + ASSERT_NE(loaded, nullptr); + ASSERT_EQ(loaded->id(), original->id()); + ASSERT_EQ(loaded->num_vectors(), 2); + + delete original; + delete loaded; +} diff --git a/tests/seismic_scalar_quantized_index_test.cpp b/tests/seismic_scalar_quantized_index_test.cpp index 86fdc28..acd70d7 100644 --- a/tests/seismic_scalar_quantized_index_test.cpp +++ b/tests/seismic_scalar_quantized_index_test.cpp @@ -727,6 +727,64 @@ TEST(SeismicSQIndexSearch, search_exact_match_with_small_selector) { 
EXPECT_GT(distances[0], distances[1]); } +TEST(SeismicSQIndexSearch, search_exact_match_scores_match_normal_path_scores) { + TestableSeismicSQIndex index(QuantizerType::QT_8bit, 0.0F, 1.0F, 10, 3, + 0.5F, 3); + Index* idx = &index; + + // doc0: term0=1.0, doc1: term0=0.5, doc2: term0=0.8 + index.add_docs({{{0, 1.0F}}, {{0, 0.5F}}, {{0, 0.8F}}}); + index.build(); + + std::vector query_indptr = {0, 1}; + std::vector query_indices = {0}; + std::vector query_values = {1.0F}; + + // Normal path (no selector): k=3, all 3 docs returned + std::vector labels_normal(3, -1); + std::vector distances_normal(3, -1.0F); + SeismicSearchParameters params_normal(5, 1000.0F); + idx->search(1, query_indptr.data(), query_indices.data(), + query_values.data(), 3, distances_normal.data(), + labels_normal.data(), &params_normal); + + // Exact match path: selector size (2) <= k (2) + std::vector allowed_ids = {0, 2}; + ArrayIDSelector selector(allowed_ids.size(), allowed_ids.data()); + SeismicSearchParameters params_filtered(5, 1000.0F); + params_filtered.set_id_selector(&selector); + + std::vector labels_filtered(2, -1); + std::vector distances_filtered(2, -1.0F); + idx->search(1, query_indptr.data(), query_indices.data(), + query_values.data(), 2, distances_filtered.data(), + labels_filtered.data(), &params_filtered); + + // Find doc0's score from the normal path + float doc0_score_normal = -1.0F; + for (int i = 0; i < 3; ++i) { + if (labels_normal[i] == 0) { + doc0_score_normal = distances_normal[i]; + break; + } + } + // Find doc0's score from the exact match path + float doc0_score_filtered = -1.0F; + for (int i = 0; i < 2; ++i) { + if (labels_filtered[i] == 0) { + doc0_score_filtered = distances_filtered[i]; + break; + } + } + + ASSERT_GE(doc0_score_normal, 0.0F); + ASSERT_GE(doc0_score_filtered, 0.0F); + // Scores must match — both paths should decode quantized dot products + EXPECT_FLOAT_EQ(doc0_score_normal, doc0_score_filtered); + // Sanity: decoded score should be in a reasonable range 
(not raw quantized) + EXPECT_LT(doc0_score_filtered, 10.0F); +} + TEST(SeismicSQIndexSearch, search_with_id_selector_filters_results) { TestableSeismicSQIndex index(QuantizerType::QT_8bit, 0.0F, 1.0F, 10, 3, 0.5F, 3);