Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions nsparse/id_map_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void IDMapIndex::add_with_ids(idx_t n, const idx_t* indptr,
}
}
void IDMapIndex::write_index(IOWriter* io_writer) {
nsparse::write_index(delegate_, io_writer);
nsparse::detail::write_index(delegate_, io_writer, true);

// Write internal_to_external_ vector
size_t map_size = internal_to_external_.size();
Expand All @@ -86,7 +86,7 @@ void IDMapIndex::write_index(IOWriter* io_writer) {
}

void IDMapIndex::read_index(IOReader* io_reader) {
delegate_ = nsparse::read_index(io_reader);
delegate_ = nsparse::detail::read_index(io_reader, true);

// Read internal_to_external_ vector
size_t map_size = 0;
Expand Down
33 changes: 24 additions & 9 deletions nsparse/io/index_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ Index* read_header(IOReader* io_reader) {
}
}
} // namespace
void write_index(Index* index, IOWriter* io_writer) {

namespace detail {
void write_index(Index* index, IOWriter* io_writer, bool keep_open) {
auto* index_io = dynamic_cast<IndexIO*>(index);
if (index_io == nullptr) {
throw std::runtime_error("Index does not support serialization");
Expand All @@ -66,24 +68,37 @@ void write_index(Index* index, IOWriter* io_writer) {
write_header(index, io_writer);
// write index customized payload
index_io->write_index(io_writer);
io_writer->close();
}

void write_index(Index* index, char* filename) {
FileIOWriter writer(filename);
write_index(index, &writer);
if (!keep_open) {
io_writer->close();
}
}

/**
 * Reads an index from `io_reader`: first the header (which yields the index
 * object), then the index-specific payload.
 *
 * @param io_reader  Source stream positioned at the start of a serialized index.
 * @param keep_open  When true the reader is left open so the caller can keep
 *                   consuming the same stream (IDMapIndex reads its id map
 *                   right after the delegate index). When false the reader is
 *                   closed once the payload has been read.
 * @return The loaded index; the caller takes ownership.
 * @throws std::runtime_error if the index type does not implement IndexIO.
 */
Index* read_index(IOReader* io_reader, bool keep_open) {
    Index* index = read_header(io_reader);
    try {
        auto* index_io = dynamic_cast<IndexIO*>(index);
        if (index_io == nullptr) {
            throw std::runtime_error("Index does not support serialization");
        }
        index_io->read_index(io_reader);
        // Close only when this call owns the end of the stream; an
        // unconditional close here would break wrappers that read more data.
        if (!keep_open) {
            io_reader->close();
        }
    } catch (...) {
        // The index has not been handed to the caller yet, so release it
        // here instead of leaking it on the error path.
        delete index;
        throw;
    }
    return index;
}
} // namespace detail

// Public stream-based serializer: delegates to detail::write_index with
// keep_open disabled, so the writer is closed once the index is written.
void write_index(Index* index, IOWriter* io_writer) {
    constexpr bool kKeepOpen = false;
    detail::write_index(index, io_writer, kKeepOpen);
}

// File-based serializer: wraps `filename` in a FileIOWriter and writes the
// index through it with keep_open disabled (same path as the stream overload).
void write_index(Index* index, char* filename) {
    FileIOWriter file_writer(filename);
    detail::write_index(index, &file_writer, false);
}

// Public stream-based deserializer: delegates to detail::read_index with
// keep_open disabled, so the reader is closed once the index is loaded.
Index* read_index(IOReader* io_reader) {
    constexpr bool kKeepOpen = false;
    Index* loaded = detail::read_index(io_reader, kKeepOpen);
    return loaded;
}

Index* read_index(char* filename) {
FileIOReader reader(filename);
Expand Down
7 changes: 6 additions & 1 deletion nsparse/io/index_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@
#include "nsparse/io/io.h"
namespace nsparse {

// Internal serialization entry points with explicit control over the stream's
// lifetime. Wrapper indexes (e.g. IDMapIndex) pass keep_open = true so they
// can continue using the same stream after the delegate index is processed.
namespace detail {
// Writes `index` to `io_writer`; the writer is closed afterwards only when
// keep_open is false.
void write_index(Index* index, IOWriter* io_writer, bool keep_open);
// Reads an index from `io_reader`; the reader is closed afterwards only when
// keep_open is false.
Index* read_index(IOReader* io_reader, bool keep_open);
} // namespace detail

// File-based overloads: the definitions wrap `filename` in a FileIOWriter /
// FileIOReader.
void write_index(Index* index, char* filename);
Index* read_index(char* filename);
// Stream-based write: forwards to detail::write_index with keep_open = false.
void write_index(Index* index, IOWriter* io_writer);
// NOTE(review): repeats the read_index(char*) declaration above; the
// redeclaration is legal but redundant.
Index* read_index(char* filename);
// Stream-based read: forwards to detail::read_index with keep_open = false.
Index* read_index(IOReader* io_reader);

} // namespace nsparse
Expand Down
9 changes: 8 additions & 1 deletion nsparse/seismic_scalar_quantized_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,18 @@ auto SeismicScalarQuantizedIndex::search(idx_t n, const idx_t* indptr,
// if filter ids size is <= k, just run exact match
if (detail::should_run_exact_match(search_parameters->get_id_selector(), k,
&query_vectors)) {
return detail::ExactMatcher::search(
auto [distances, labels] = detail::ExactMatcher::search(
vectors_.get(),
dynamic_cast<const IDSelectorEnumerable*>(
search_parameters->get_id_selector()),
&query_vectors, element_size, k);
// Decode quantized dot product scores
for (auto& query_distances : distances) {
for (auto& dist : query_distances) {
dist = sq_.decode_dot_product(dist, query_sq);
}
}
return {distances, labels};
}

std::vector<std::vector<float>> result_distances(n);
Expand Down
152 changes: 152 additions & 0 deletions tests/index_io_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "nsparse/brutal_index.h"
#include "nsparse/id_map_index.h"
#include "nsparse/index.h"
#include "nsparse/inverted_index.h"
#include "nsparse/io/buffered_io.h"
#include "nsparse/io/io.h"
#include "nsparse/seismic_index.h"
Expand All @@ -28,6 +29,49 @@

namespace {

// IOWriter that throws if write() is called after close().
// Simulates real file I/O where writing to a closed stream is invalid.
class StrictBufferedIOWriter : public nsparse::IOWriter {
public:
void write(void* ptr, size_t size, size_t nitems) override {
if (closed_) {
throw std::runtime_error(
"Write after close: stream already closed");
}
delegate_.write(ptr, size, nitems);
}

void close() override { closed_ = true; }

const std::vector<uint8_t>& data() const { return delegate_.data(); }
size_t size() const { return delegate_.size(); }

private:
nsparse::BufferedIOWriter delegate_;
bool closed_ = false;
};

// IOReader that throws if read() is called after close().
// Simulates real file I/O where reading from a closed stream is invalid.
class StrictBufferedIOReader : public nsparse::IOReader {
public:
explicit StrictBufferedIOReader(const std::vector<uint8_t>& data)
: delegate_(data) {}

size_t read(void* ptr, size_t size, size_t nitems) override {
if (closed_) {
throw std::runtime_error("Read after close: stream already closed");
}
return delegate_.read(ptr, size, nitems);
}

void close() override { closed_ = true; }

private:
nsparse::BufferedIOReader delegate_;
bool closed_ = false;
};

// Mock Index that implements both Index and IndexIO for testing
class MockIndex : public nsparse::Index, public nsparse::IndexIO {
public:
Expand Down Expand Up @@ -302,3 +346,111 @@ TEST(IndexIO, RoundtripIDMapIndexWithData) {
delete original;
delete loaded;
}

// Sanity check of the test double itself: a write succeeds while the writer
// is open and raises std::runtime_error after close().
TEST(IndexIO, StrictWriterThrowsAfterClose) {
    StrictBufferedIOWriter strict_writer;
    int payload = 42;

    strict_writer.write(&payload, sizeof(payload), 1);
    strict_writer.close();

    ASSERT_THROW(strict_writer.write(&payload, sizeof(payload), 1),
                 std::runtime_error);
}

// Sanity check of the test double itself: a read succeeds while the reader
// is open and raises std::runtime_error after close().
TEST(IndexIO, StrictReaderThrowsAfterClose) {
    std::vector<uint8_t> zero_bytes(sizeof(int), 0);
    StrictBufferedIOReader strict_reader(zero_bytes);
    int sink = 0;

    strict_reader.read(&sink, sizeof(sink), 1);
    strict_reader.close();

    ASSERT_THROW(strict_reader.read(&sink, sizeof(sink), 1),
                 std::runtime_error);
}

// Roundtrip of an IDMapIndex over a SeismicIndex through the strict in-memory
// streams. Before the keep_open fix, write_index closed the stream before
// IDMapIndex could append its id map, which the strict writer reports as a
// write-after-close error.
TEST(IndexIO, StrictIO_RoundtripIDMapSeismicIndex) {
    auto* delegate = new nsparse::SeismicIndex(128);
    auto* source_index = new nsparse::IDMapIndex(delegate);

    // Two sparse documents in CSR layout, with external ids 100 and 200.
    std::vector<nsparse::idx_t> row_offsets = {0, 2, 4};
    std::vector<nsparse::term_t> term_ids = {0, 1, 2, 3};
    std::vector<float> weights = {1.0F, 0.5F, 0.8F, 0.3F};
    std::vector<nsparse::idx_t> external_ids = {100, 200};
    source_index->add_with_ids(2, row_offsets.data(), term_ids.data(),
                               weights.data(), external_ids.data());

    StrictBufferedIOWriter strict_writer;
    ASSERT_NO_THROW(nsparse::write_index(source_index, &strict_writer));

    StrictBufferedIOReader strict_reader(strict_writer.data());
    nsparse::Index* restored = nullptr;
    ASSERT_NO_THROW(restored = nsparse::read_index(&strict_reader));

    ASSERT_NE(restored, nullptr);
    ASSERT_EQ(restored->id(), source_index->id());
    ASSERT_EQ(restored->num_vectors(), 2);

    delete source_index;
    delete restored;
}

// Roundtrip of an IDMapIndex over an InvertedIndex: the original segfault
// scenario. InvertedIndex has a non-trivial write_index/read_index, so the
// stream has to stay open for IDMapIndex to write/read its id map once the
// delegate has been serialized.
// Note: per the original test notes, InvertedIndex::get_vectors() returns
// nullptr after build() (vectors_ is consumed to create inverted_lists_), so
// num_vectors() is 0; the roundtrip is therefore verified via the index id.
TEST(IndexIO, StrictIO_RoundtripIDMapInvertedIndex) {
    auto* delegate = new nsparse::InvertedIndex(128);
    auto* source_index = new nsparse::IDMapIndex(delegate);

    // Two sparse documents in CSR layout, with external ids 100 and 200.
    std::vector<nsparse::idx_t> row_offsets = {0, 2, 4};
    std::vector<nsparse::term_t> term_ids = {0, 1, 2, 3};
    std::vector<float> weights = {1.0F, 0.5F, 0.8F, 0.3F};
    std::vector<nsparse::idx_t> external_ids = {100, 200};
    source_index->add_with_ids(2, row_offsets.data(), term_ids.data(),
                               weights.data(), external_ids.data());
    source_index->build();

    StrictBufferedIOWriter strict_writer;
    ASSERT_NO_THROW(nsparse::write_index(source_index, &strict_writer));

    StrictBufferedIOReader strict_reader(strict_writer.data());
    nsparse::Index* restored = nullptr;
    ASSERT_NO_THROW(restored = nsparse::read_index(&strict_reader));

    ASSERT_NE(restored, nullptr);
    ASSERT_EQ(restored->id(), source_index->id());

    delete source_index;
    delete restored;
}

// Roundtrip of an IDMapIndex over a SeismicScalarQuantizedIndex through the
// strict in-memory streams; any stream access after close() fails the test.
TEST(IndexIO, StrictIO_RoundtripIDMapSeismicSQIndex) {
    auto* delegate = new nsparse::SeismicScalarQuantizedIndex(128);
    auto* source_index = new nsparse::IDMapIndex(delegate);

    // Two sparse documents in CSR layout, with external ids 100 and 200.
    std::vector<nsparse::idx_t> row_offsets = {0, 2, 4};
    std::vector<nsparse::term_t> term_ids = {0, 1, 2, 3};
    std::vector<float> weights = {1.0F, 0.5F, 0.8F, 0.3F};
    std::vector<nsparse::idx_t> external_ids = {100, 200};
    source_index->add_with_ids(2, row_offsets.data(), term_ids.data(),
                               weights.data(), external_ids.data());

    StrictBufferedIOWriter strict_writer;
    ASSERT_NO_THROW(nsparse::write_index(source_index, &strict_writer));

    StrictBufferedIOReader strict_reader(strict_writer.data());
    nsparse::Index* restored = nullptr;
    ASSERT_NO_THROW(restored = nsparse::read_index(&strict_reader));

    ASSERT_NE(restored, nullptr);
    ASSERT_EQ(restored->id(), source_index->id());
    ASSERT_EQ(restored->num_vectors(), 2);

    delete source_index;
    delete restored;
}
58 changes: 58 additions & 0 deletions tests/seismic_scalar_quantized_index_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,64 @@ TEST(SeismicSQIndexSearch, search_exact_match_with_small_selector) {
EXPECT_GT(distances[0], distances[1]);
}

// The exact-match path (taken when the id selector holds <= k ids) must
// report the same decoded score for a document as the regular search path.
TEST(SeismicSQIndexSearch, search_exact_match_scores_match_normal_path_scores) {
    TestableSeismicSQIndex index(QuantizerType::QT_8bit, 0.0F, 1.0F, 10, 3,
                                 0.5F, 3);
    Index* base = &index;

    // doc0: term0=1.0, doc1: term0=0.5, doc2: term0=0.8
    index.add_docs({{{0, 1.0F}}, {{0, 0.5F}}, {{0, 0.8F}}});
    index.build();

    // Single query touching term0 with weight 1.0.
    std::vector<idx_t> query_indptr = {0, 1};
    std::vector<term_t> query_indices = {0};
    std::vector<float> query_values = {1.0F};

    // Returns the score reported for doc 0, or -1 when doc 0 is absent.
    const auto score_of_doc0 = [](const std::vector<idx_t>& labels,
                                  const std::vector<float>& scores) {
        for (size_t pos = 0; pos < labels.size(); ++pos) {
            if (labels[pos] == 0) {
                return scores[pos];
            }
        }
        return -1.0F;
    };

    // Regular path (no selector): k=3, so all three docs come back.
    std::vector<idx_t> labels_normal(3, -1);
    std::vector<float> distances_normal(3, -1.0F);
    SeismicSearchParameters params_normal(5, 1000.0F);
    base->search(1, query_indptr.data(), query_indices.data(),
                 query_values.data(), 3, distances_normal.data(),
                 labels_normal.data(), &params_normal);

    // Exact-match path: selector size (2) <= k (2).
    std::vector<idx_t> allowed_ids = {0, 2};
    ArrayIDSelector selector(allowed_ids.size(), allowed_ids.data());
    SeismicSearchParameters params_filtered(5, 1000.0F);
    params_filtered.set_id_selector(&selector);

    std::vector<idx_t> labels_filtered(2, -1);
    std::vector<float> distances_filtered(2, -1.0F);
    base->search(1, query_indptr.data(), query_indices.data(),
                 query_values.data(), 2, distances_filtered.data(),
                 labels_filtered.data(), &params_filtered);

    const float doc0_score_normal =
        score_of_doc0(labels_normal, distances_normal);
    const float doc0_score_filtered =
        score_of_doc0(labels_filtered, distances_filtered);

    // Both searches must have returned doc0.
    ASSERT_GE(doc0_score_normal, 0.0F);
    ASSERT_GE(doc0_score_filtered, 0.0F);
    // Scores must match: both paths should decode quantized dot products.
    EXPECT_FLOAT_EQ(doc0_score_normal, doc0_score_filtered);
    // Sanity: a decoded score should be in a reasonable range (not raw
    // quantized).
    EXPECT_LT(doc0_score_filtered, 10.0F);
}

TEST(SeismicSQIndexSearch, search_with_id_selector_filters_results) {
TestableSeismicSQIndex index(QuantizerType::QT_8bit, 0.0F, 1.0F, 10, 3,
0.5F, 3);
Expand Down
Loading