diff --git a/src/examples/filtered/filtered.py b/src/examples/filtered/filtered.py index 3f77c85d..85575097 100644 --- a/src/examples/filtered/filtered.py +++ b/src/examples/filtered/filtered.py @@ -177,7 +177,7 @@ def read_sparse_matrix(fname, do_mmap=False): print('Constructing the LSH table') t1 = timeit.default_timer() table = falconn.LSHIndex(params_cp) - table.setup(dataset, mydic) + table.setup(dataset, mydic, {}) t2 = timeit.default_timer() print('Done') print('Construction time: {}'.format(t2 - t1)) @@ -227,7 +227,7 @@ def evaluate_number_of_probes(number_of_probes): res = 0 right = 0 for (i, query) in enumerate(queries): - res = query_object.find_nearest_neighbor(query, queries_metadata[i].indices) + res = query_object.find_nearest_neighbor(query, queries_metadata[i].indices, 20) real_point = None # for result in res: # breakpoint() diff --git a/src/examples/glove/glove.cc b/src/examples/glove/glove.cc index 0b8f428e..442b395f 100644 --- a/src/examples/glove/glove.cc +++ b/src/examples/glove/glove.cc @@ -178,7 +178,7 @@ double evaluate_num_probes(LSHNearestNeighborTable *table, int num_matches = 0; vector candidates; for (const auto &query : queries) { - query_object->get_candidates_with_duplicates(query, &candidates); + query_object->get_candidates_with_duplicates(query, &candidates, 0); for (auto x : candidates) { if (x == answers[outer_counter]) { ++num_matches; diff --git a/src/include/falconn/core/lsh_function_helpers.h b/src/include/falconn/core/lsh_function_helpers.h index d2e6780f..7d91f799 100644 --- a/src/include/falconn/core/lsh_function_helpers.h +++ b/src/include/falconn/core/lsh_function_helpers.h @@ -92,7 +92,8 @@ class HashObjectQuery { void get_probes_by_table(const VectorType& point, std::vector>* probes, - int_fast64_t num_probes) { + int_fast64_t num_probes, + int_fast8_t iterations) { if (num_probes < parent_.l_) { throw LSHFunctionError( "Number of probes must be at least " @@ -106,11 +107,22 @@ class HashObjectQuery { } hash_transformation_.apply(point, &transformed_vector_); - multiprobe_.setup_probing(transformed_vector_, num_probes); + multiprobe_.setup_probing(transformed_vector_, num_probes * (iterations+1)); int_fast32_t cur_table; HashType cur_probe; + bool stop = false; + for (int_fast8_t its = 0; its* result) { + std::vector* result, + int_fast8_t iterations) { if (result == nullptr) { throw LSHTableError("Results vector pointer is nullptr."); } @@ -134,7 +135,7 @@ class StaticLSHTable auto start_time = std::chrono::high_resolution_clock::now(); stats_.num_queries += 1; - lsh_query_.get_probes_by_table(p, &tmp_probes_by_table_, num_probes); + lsh_query_.get_probes_by_table(p, &tmp_probes_by_table_, num_probes, iterations); auto lsh_end_time = std::chrono::high_resolution_clock::now(); auto elapsed_lsh = @@ -180,7 +181,8 @@ class StaticLSHTable void get_unique_candidates(const PointType& p, int_fast64_t num_probes, int_fast64_t max_num_candidates, - std::vector* result) { + std::vector* result, + int_fast8_t iterations = 0) { if (result == nullptr) { throw LSHTableError("Results vector pointer is nullptr."); } @@ -188,7 +190,7 @@ class StaticLSHTable auto start_time = std::chrono::high_resolution_clock::now(); stats_.num_queries += 1; - get_unique_candidates_internal(p, num_probes, max_num_candidates, result); + get_unique_candidates_internal(p, num_probes, max_num_candidates, result, iterations); auto end_time = std::chrono::high_resolution_clock::now(); auto elapsed_total = @@ -222,10 +224,11 @@ class StaticLSHTable void get_unique_candidates_internal(const PointType& p, int_fast64_t num_probes, int_fast64_t max_num_candidates, - std::vector* result) { + std::vector* result, + int_fast8_t iterations = 0) { auto start_time = std::chrono::high_resolution_clock::now(); - lsh_query_.get_probes_by_table(p, &tmp_probes_by_table_, num_probes); + lsh_query_.get_probes_by_table(p, &tmp_probes_by_table_, num_probes, iterations); auto lsh_end_time = std::chrono::high_resolution_clock::now(); auto elapsed_lsh = diff --git a/src/include/falconn/core/nn_query.h b/src/include/falconn/core/nn_query.h index b1a45a69..d3db5d63 100644 --- a/src/include/falconn/core/nn_query.h +++ b/src/include/falconn/core/nn_query.h @@ -42,9 +42,10 @@ class NearestNeighborQuery { const ComparisonPointType& q_comp, const std::set& q_filter, int_fast64_t num_probes, - int_fast64_t max_num_candidates) { + int_fast64_t max_num_candidates, + int_fast64_t max_iterations) { auto start_time = std::chrono::high_resolution_clock::now(); - printf("HELLO WHAT"); + // printf("HELLO WHAT"); auto distance_start_time = std::chrono::high_resolution_clock::now(); LSHTableKeyType best_key = -1; @@ -59,7 +60,7 @@ class NearestNeighborQuery { } } - printf("small label: %d\n", smallest_label); + // printf("small label: %d\n", smallest_label); if(smallest_label != -1) { std::vector indices = small_labels_store_.get_indices_for_label(smallest_label); @@ -90,11 +91,11 @@ class NearestNeighborQuery { } int iteration = 0; - while (best_key == -1 && iteration < 5) { + while (best_key == -1 && iteration < max_iterations) { // printf("Start iteration %d\n", iteration); + table_query_->get_unique_candidates(q, num_probes, max_num_candidates, + &candidates_, iteration); iteration += 1; - table_query_->get_unique_candidates(q, num_probes*iteration, max_num_candidates, - &candidates_); // TODO: use nullptr for pointer types // printf("Fundet candidates %ld\n", candidates_.size()); if (candidates_.size() > 0) { @@ -130,7 +131,7 @@ class NearestNeighborQuery { } if(is_good) { DistanceType cur_distance = dst_(q_comp, point); - printf("%d %f\n", iter.get_key(), cur_distance); + // printf("%d %f\n", iter.get_key(), cur_distance); if (cur_distance < best_distance || no_distance_found) { best_distance = cur_distance; @@ -263,11 +264,12 @@ class NearestNeighborQuery { void get_candidates_with_duplicates(const LSHTablePointType& q, int_fast64_t num_probes, int_fast64_t max_num_candidates, - std::vector* result) { + std::vector* result, + int_fast8_t iterations) { auto start_time = std::chrono::high_resolution_clock::now(); table_query_->get_candidates_with_duplicates(q, num_probes, - max_num_candidates, result); + max_num_candidates, result, iterations); auto end_time = std::chrono::high_resolution_clock::now(); auto elapsed_total = @@ -279,11 +281,12 @@ class NearestNeighborQuery { void get_unique_candidates(const LSHTablePointType& q, int_fast64_t num_probes, int_fast64_t max_num_candidates, - std::vector* result) { + std::vector* result, + int_fast8_t iterations = 0) { auto start_time = std::chrono::high_resolution_clock::now(); table_query_->get_unique_candidates(q, num_probes, max_num_candidates, - result); + result, iterations); auto end_time = std::chrono::high_resolution_clock::now(); auto elapsed_total = diff --git a/src/include/falconn/lsh_nn_table.h b/src/include/falconn/lsh_nn_table.h index 90ad71db..aa0aa9e7 100644 --- a/src/include/falconn/lsh_nn_table.h +++ b/src/include/falconn/lsh_nn_table.h @@ -62,7 +62,7 @@ class LSHNearestNeighborQuery { /// /// Finds the key of the closest candidate in the probing sequence for q. /// - virtual KeyType find_nearest_neighbor(const PointType& q, std::set filters) = 0; + virtual KeyType find_nearest_neighbor(const PointType& q, std::set filters, int_fast64_t max_iterations) = 0; /// /// Find the keys of the k closest candidates in the probing sequence for q. @@ -96,7 +96,8 @@ class LSHNearestNeighborQuery { /// appear in the probing sequence. /// virtual void get_candidates_with_duplicates(const PointType& q, - std::vector* result) = 0; + std::vector* result, + int_fast8_t iterations) = 0; /// /// Resets the query statistics. @@ -149,7 +150,7 @@ class LSHNearestNeighborQueryPool { /// /// Finds the key of the closest candidate in the probing sequence for q. /// - virtual KeyType find_nearest_neighbor(const PointType& q, std::set filters) = 0; + virtual KeyType find_nearest_neighbor(const PointType& q, std::set filters, int_fast64_t max_iterations) = 0; /// /// Find the keys of the k closest candidates in the probing sequence for q. @@ -179,7 +180,8 @@ class LSHNearestNeighborQueryPool { /// See the documentation for LSHNearestNeighborQuery. /// virtual void get_candidates_with_duplicates(const PointType& q, - std::vector* result) = 0; + std::vector* result, + int_fast8_t iterations) = 0; /// /// Resets the query statistics. diff --git a/src/include/falconn/wrapper/cpp_wrapper_impl.h b/src/include/falconn/wrapper/cpp_wrapper_impl.h index 3b8cbd90..3265addf 100644 --- a/src/include/falconn/wrapper/cpp_wrapper_impl.h +++ b/src/include/falconn/wrapper/cpp_wrapper_impl.h @@ -305,9 +305,9 @@ class LSHNNQueryWrapper : public LSHNearestNeighborQuery { new NNQueryType(internal_query_.get(), data_storage, metadata_storage, small_labels_store)); } - KeyType find_nearest_neighbor(const PointType& q, std::set filters) { + KeyType find_nearest_neighbor(const PointType& q, std::set filters, int_fast64_t max_iterations) { return internal_nn_query_->find_nearest_neighbor(q, q, filters, num_probes_, - max_num_candidates_); + max_num_candidates_, max_iterations); } void find_k_nearest_neighbors(const PointType& q, int_fast64_t k, @@ -323,9 +323,10 @@ class LSHNNQueryWrapper : public LSHNearestNeighborQuery { } void get_candidates_with_duplicates(const PointType& q, - std::vector* result) { + std::vector* result, + int_fast8_t iterations) { internal_nn_query_->get_candidates_with_duplicates( - q, num_probes_, max_num_candidates_, result); + q, num_probes_, max_num_candidates_, result, iterations); } void get_unique_candidates(const PointType& q, std::vector* result) { @@ -403,10 +404,10 @@ class LSHNNQueryPool : public LSHNearestNeighborQueryPool { } } - KeyType find_nearest_neighbor(const PointType& q, std::set filters) { + KeyType find_nearest_neighbor(const PointType& q, std::set filters, int_fast64_t max_iterations) { int_fast32_t query_index = get_query_index_and_lock(); KeyType res = internal_nn_queries_[query_index]->find_nearest_neighbor( - q, q, filters, num_probes_, max_num_candidates_); + q, q, filters, num_probes_, max_num_candidates_, max_iterations); unlock_query(query_index); return res; } @@ -428,10 +429,11 @@ class LSHNNQueryPool : public LSHNearestNeighborQueryPool { } void get_candidates_with_duplicates(const PointType& q, - std::vector* result) { + std::vector* result, + int_fast8_t iterations) { int_fast32_t query_index = get_query_index_and_lock(); internal_nn_queries_[query_index]->get_candidates_with_duplicates( - q, num_probes_, max_num_candidates_, result); + q, num_probes_, max_num_candidates_, result, iterations); unlock_query(query_index); } diff --git a/src/python/package/falconn/__init__.py b/src/python/package/falconn/__init__.py index 009ba25c..9db8ba1b 100644 --- a/src/python/package/falconn/__init__.py +++ b/src/python/package/falconn/__init__.py @@ -137,7 +137,7 @@ def find_near_neighbors(self, query, threshold): format(threshold)) return self._inner_entity.find_near_neighbors(query, threshold) - def find_nearest_neighbor(self, query, query_filters): + def find_nearest_neighbor(self, query, query_filters, max_iterations): """Find the key of the closest candidate. Finds the key of the closest candidate in the probing sequence @@ -150,7 +150,7 @@ def find_nearest_neighbor(self, query, query_filters): the second dimension of the dataset. """ self._check_query(query) - return self._inner_entity.find_nearest_neighbor(query, query_filters) + return self._inner_entity.find_nearest_neighbor(query, query_filters, max_iterations) def get_candidates_with_duplicates(self, query): """Retrieve all the candidates for a given query. diff --git a/src/python/wrapper/python_wrapper.cc b/src/python/wrapper/python_wrapper.cc index f9325e42..88e38d58 100644 --- a/src/python/wrapper/python_wrapper.cc +++ b/src/python/wrapper/python_wrapper.cc @@ -94,12 +94,12 @@ class PyLSHNearestNeighborQueryDenseFloat { return inner_entity_->get_max_num_candidates(); } - int32_t find_nearest_neighbor(OuterNumPyArray q, const std::vector &q_filters) { + int32_t find_nearest_neighbor(OuterNumPyArray q, const std::vector &q_filters, int_fast64_t max_iterations) { InnerEigenMap converted_query = numpy_to_eigen(q); std::set converted_query_filters(q_filters.begin(), q_filters.end()); py::gil_scoped_release release; - return inner_entity_->find_nearest_neighbor(converted_query, converted_query_filters); + return inner_entity_->find_nearest_neighbor(converted_query, converted_query_filters, max_iterations); } std::vector find_k_nearest_neighbors(OuterNumPyArray q, @@ -128,11 +128,11 @@ class PyLSHNearestNeighborQueryDenseFloat { return result; } - std::vector get_candidates_with_duplicates(OuterNumPyArray q) { + std::vector get_candidates_with_duplicates(OuterNumPyArray q, int_fast8_t iterations) { InnerEigenMap converted_query = numpy_to_eigen(q); py::gil_scoped_release release; std::vector result; - inner_entity_->get_candidates_with_duplicates(converted_query, &result); + inner_entity_->get_candidates_with_duplicates(converted_query, &result, iterations); return result; } @@ -176,12 +176,12 @@ class PyLSHNearestNeighborQueryPoolDenseFloat { return inner_entity_->get_max_num_candidates(); } - int32_t find_nearest_neighbor(OuterNumPyArray q, const std::vector &q_filters) { + int32_t find_nearest_neighbor(OuterNumPyArray q, const std::vector &q_filters, int_fast64_t max_iterations) { InnerEigenMap converted_query = numpy_to_eigen(q); std::set converted_query_filters(q_filters.begin(), q_filters.end()); py::gil_scoped_release release; - return inner_entity_->find_nearest_neighbor(converted_query, converted_query_filters); + return inner_entity_->find_nearest_neighbor(converted_query, converted_query_filters, max_iterations); } std::vector find_k_nearest_neighbors(OuterNumPyArray q, @@ -210,11 +210,11 @@ class PyLSHNearestNeighborQueryPoolDenseFloat { return result; } - std::vector get_candidates_with_duplicates(OuterNumPyArray q) { + std::vector get_candidates_with_duplicates(OuterNumPyArray q, int_fast8_t iterations) { InnerEigenMap converted_query = numpy_to_eigen(q); py::gil_scoped_release release; std::vector result; - inner_entity_->get_candidates_with_duplicates(converted_query, &result); + inner_entity_->get_candidates_with_duplicates(converted_query, &result, iterations); return result; } @@ -315,12 +315,12 @@ class PyLSHNearestNeighborQueryDenseDouble { return inner_entity_->get_max_num_candidates(); } - int32_t find_nearest_neighbor(OuterNumPyArray q, const std::vector &q_filters) { + int32_t find_nearest_neighbor(OuterNumPyArray q, const std::vector &q_filters, int_fast64_t max_iterations) { InnerEigenMap converted_query = numpy_to_eigen(q); std::set converted_query_filters(q_filters.begin(), q_filters.end()); py::gil_scoped_release release; - return inner_entity_->find_nearest_neighbor(converted_query, converted_query_filters); + return inner_entity_->find_nearest_neighbor(converted_query, converted_query_filters, max_iterations); } std::vector find_k_nearest_neighbors(OuterNumPyArray q, @@ -349,11 +349,11 @@ class PyLSHNearestNeighborQueryDenseDouble { return result; } - std::vector get_candidates_with_duplicates(OuterNumPyArray q) { + std::vector get_candidates_with_duplicates(OuterNumPyArray q, int_fast8_t iterations) { InnerEigenMap converted_query = numpy_to_eigen(q); py::gil_scoped_release release; std::vector result; - inner_entity_->get_candidates_with_duplicates(converted_query, &result); + inner_entity_->get_candidates_with_duplicates(converted_query, &result, iterations); return result; } @@ -397,12 +397,12 @@ class PyLSHNearestNeighborQueryPoolDenseDouble { return inner_entity_->get_max_num_candidates(); } - int32_t find_nearest_neighbor(OuterNumPyArray q, const std::vector &q_filters) { + int32_t find_nearest_neighbor(OuterNumPyArray q, const std::vector &q_filters, int_fast64_t max_iterations) { InnerEigenMap converted_query = numpy_to_eigen(q); std::set converted_query_filters(q_filters.begin(), q_filters.end()); py::gil_scoped_release release; - return inner_entity_->find_nearest_neighbor(converted_query, converted_query_filters); + return inner_entity_->find_nearest_neighbor(converted_query, converted_query_filters, max_iterations); } std::vector find_k_nearest_neighbors(OuterNumPyArray q, @@ -431,11 +431,11 @@ class PyLSHNearestNeighborQueryPoolDenseDouble { return result; } - std::vector get_candidates_with_duplicates(OuterNumPyArray q) { + std::vector get_candidates_with_duplicates(OuterNumPyArray q, int_fast8_t iterations) { InnerEigenMap converted_query = numpy_to_eigen(q); py::gil_scoped_release release; std::vector result; - inner_entity_->get_candidates_with_duplicates(converted_query, &result); + inner_entity_->get_candidates_with_duplicates(converted_query, &result, iterations); return result; }