Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/examples/filtered/filtered.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def read_sparse_matrix(fname, do_mmap=False):
print('Constructing the LSH table')
t1 = timeit.default_timer()
table = falconn.LSHIndex(params_cp)
table.setup(dataset, mydic)
table.setup(dataset, mydic, {})
t2 = timeit.default_timer()
print('Done')
print('Construction time: {}'.format(t2 - t1))
Expand Down Expand Up @@ -227,7 +227,7 @@ def evaluate_number_of_probes(number_of_probes):
res = 0
right = 0
for (i, query) in enumerate(queries):
res = query_object.find_nearest_neighbor(query, queries_metadata[i].indices)
res = query_object.find_nearest_neighbor(query, queries_metadata[i].indices, 20)
real_point = None
# for result in res:
# breakpoint()
Expand Down
2 changes: 1 addition & 1 deletion src/examples/glove/glove.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ double evaluate_num_probes(LSHNearestNeighborTable<Point> *table,
int num_matches = 0;
vector<int32_t> candidates;
for (const auto &query : queries) {
query_object->get_candidates_with_duplicates(query, &candidates);
query_object->get_candidates_with_duplicates(query, &candidates, 0);
for (auto x : candidates) {
if (x == answers[outer_counter]) {
++num_matches;
Expand Down
16 changes: 14 additions & 2 deletions src/include/falconn/core/lsh_function_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ class HashObjectQuery {

void get_probes_by_table(const VectorType& point,
std::vector<std::vector<HashType>>* probes,
int_fast64_t num_probes) {
int_fast64_t num_probes,
int_fast8_t iterations) {
if (num_probes < parent_.l_) {
throw LSHFunctionError(
"Number of probes must be at least "
Expand All @@ -106,11 +107,22 @@ class HashObjectQuery {
}

hash_transformation_.apply(point, &transformed_vector_);
multiprobe_.setup_probing(transformed_vector_, num_probes);
multiprobe_.setup_probing(transformed_vector_, num_probes * (iterations+1));

int_fast32_t cur_table;
HashType cur_probe;
bool stop = false;

for (int_fast8_t its = 0; its<iterations*num_probes; ++its) {
if (!multiprobe_.get_next_probe(&cur_probe, &cur_table)) {
stop = true;
break;
}
}

if (stop) {
return;
}
for (int_fast64_t ii = 0; ii < num_probes; ++ii) {
if (!multiprobe_.get_next_probe(&cur_probe, &cur_table)) {
break;
Expand Down
15 changes: 9 additions & 6 deletions src/include/falconn/core/lsh_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,16 @@ class StaticLSHTable
void get_candidates_with_duplicates(const PointType& p,
int_fast64_t num_probes,
int_fast64_t max_num_candidates,
std::vector<KeyType>* result) {
std::vector<KeyType>* result,
int_fast8_t iterations) {
if (result == nullptr) {
throw LSHTableError("Results vector pointer is nullptr.");
}

auto start_time = std::chrono::high_resolution_clock::now();
stats_.num_queries += 1;

lsh_query_.get_probes_by_table(p, &tmp_probes_by_table_, num_probes);
lsh_query_.get_probes_by_table(p, &tmp_probes_by_table_, num_probes, iterations);

auto lsh_end_time = std::chrono::high_resolution_clock::now();
auto elapsed_lsh =
Expand Down Expand Up @@ -180,15 +181,16 @@ class StaticLSHTable

void get_unique_candidates(const PointType& p, int_fast64_t num_probes,
int_fast64_t max_num_candidates,
std::vector<KeyType>* result) {
std::vector<KeyType>* result,
int_fast8_t iterations = 0) {
if (result == nullptr) {
throw LSHTableError("Results vector pointer is nullptr.");
}

auto start_time = std::chrono::high_resolution_clock::now();
stats_.num_queries += 1;

get_unique_candidates_internal(p, num_probes, max_num_candidates, result);
get_unique_candidates_internal(p, num_probes, max_num_candidates, result, iterations);

auto end_time = std::chrono::high_resolution_clock::now();
auto elapsed_total =
Expand Down Expand Up @@ -222,10 +224,11 @@ class StaticLSHTable
void get_unique_candidates_internal(const PointType& p,
int_fast64_t num_probes,
int_fast64_t max_num_candidates,
std::vector<KeyType>* result) {
std::vector<KeyType>* result,
int_fast8_t iterations = 0) {
auto start_time = std::chrono::high_resolution_clock::now();

lsh_query_.get_probes_by_table(p, &tmp_probes_by_table_, num_probes);
lsh_query_.get_probes_by_table(p, &tmp_probes_by_table_, num_probes, iterations);

auto lsh_end_time = std::chrono::high_resolution_clock::now();
auto elapsed_lsh =
Expand Down
25 changes: 14 additions & 11 deletions src/include/falconn/core/nn_query.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ class NearestNeighborQuery {
const ComparisonPointType& q_comp,
const std::set<int>& q_filter,
int_fast64_t num_probes,
int_fast64_t max_num_candidates) {
int_fast64_t max_num_candidates,
int_fast64_t max_iterations) {
auto start_time = std::chrono::high_resolution_clock::now();
printf("HELLO WHAT");
// printf("HELLO WHAT");

auto distance_start_time = std::chrono::high_resolution_clock::now();
LSHTableKeyType best_key = -1;
Expand All @@ -59,7 +60,7 @@ class NearestNeighborQuery {
}
}

printf("small label: %d\n", smallest_label);
// printf("small label: %d\n", smallest_label);
if(smallest_label != -1)
{
std::vector<int> indices = small_labels_store_.get_indices_for_label(smallest_label);
Expand Down Expand Up @@ -90,11 +91,11 @@ class NearestNeighborQuery {

}
int iteration = 0;
while (best_key == -1 && iteration < 5) {
while (best_key == -1 && iteration < max_iterations) {
// printf("Start iteration %d\n", iteration);
table_query_->get_unique_candidates(q, num_probes, max_num_candidates,
&candidates_, iteration);
iteration += 1;
table_query_->get_unique_candidates(q, num_probes*iteration, max_num_candidates,
&candidates_);
// TODO: use nullptr for pointer types
// printf("Fundet candidates %ld\n", candidates_.size());
if (candidates_.size() > 0) {
Expand Down Expand Up @@ -130,7 +131,7 @@ class NearestNeighborQuery {
}
if(is_good) {
DistanceType cur_distance = dst_(q_comp, point);
printf("%d %f\n", iter.get_key(), cur_distance);
// printf("%d %f\n", iter.get_key(), cur_distance);
if (cur_distance < best_distance || no_distance_found) {

best_distance = cur_distance;
Expand Down Expand Up @@ -263,11 +264,12 @@ class NearestNeighborQuery {
void get_candidates_with_duplicates(const LSHTablePointType& q,
int_fast64_t num_probes,
int_fast64_t max_num_candidates,
std::vector<LSHTableKeyType>* result) {
std::vector<LSHTableKeyType>* result,
int_fast8_t iterations) {
auto start_time = std::chrono::high_resolution_clock::now();

table_query_->get_candidates_with_duplicates(q, num_probes,
max_num_candidates, result);
max_num_candidates, result, iterations);

auto end_time = std::chrono::high_resolution_clock::now();
auto elapsed_total =
Expand All @@ -279,11 +281,12 @@ class NearestNeighborQuery {
void get_unique_candidates(const LSHTablePointType& q,
int_fast64_t num_probes,
int_fast64_t max_num_candidates,
std::vector<LSHTableKeyType>* result) {
std::vector<LSHTableKeyType>* result,
int_fast8_t iterations = 0) {
auto start_time = std::chrono::high_resolution_clock::now();

table_query_->get_unique_candidates(q, num_probes, max_num_candidates,
result);
result, iterations);

auto end_time = std::chrono::high_resolution_clock::now();
auto elapsed_total =
Expand Down
10 changes: 6 additions & 4 deletions src/include/falconn/lsh_nn_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class LSHNearestNeighborQuery {
///
/// Finds the key of the closest candidate in the probing sequence for q.
///
virtual KeyType find_nearest_neighbor(const PointType& q, std::set<int> filters) = 0;
virtual KeyType find_nearest_neighbor(const PointType& q, std::set<int> filters, int_fast64_t max_iterations) = 0;

///
/// Find the keys of the k closest candidates in the probing sequence for q.
Expand Down Expand Up @@ -96,7 +96,8 @@ class LSHNearestNeighborQuery {
/// appear in the probing sequence.
///
virtual void get_candidates_with_duplicates(const PointType& q,
std::vector<KeyType>* result) = 0;
std::vector<KeyType>* result,
int_fast8_t iterations) = 0;

///
/// Resets the query statistics.
Expand Down Expand Up @@ -149,7 +150,7 @@ class LSHNearestNeighborQueryPool {
///
/// Finds the key of the closest candidate in the probing sequence for q.
///
virtual KeyType find_nearest_neighbor(const PointType& q, std::set<int> filters) = 0;
virtual KeyType find_nearest_neighbor(const PointType& q, std::set<int> filters, int_fast64_t max_iterations) = 0;

///
/// Find the keys of the k closest candidates in the probing sequence for q.
Expand Down Expand Up @@ -179,7 +180,8 @@ class LSHNearestNeighborQueryPool {
/// See the documentation for LSHNearestNeighborQuery.
///
virtual void get_candidates_with_duplicates(const PointType& q,
std::vector<KeyType>* result) = 0;
std::vector<KeyType>* result,
int_fast8_t iterations) = 0;

///
/// Resets the query statistics.
Expand Down
18 changes: 10 additions & 8 deletions src/include/falconn/wrapper/cpp_wrapper_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,9 @@ class LSHNNQueryWrapper : public LSHNearestNeighborQuery<PointType, KeyType> {
new NNQueryType(internal_query_.get(), data_storage, metadata_storage, small_labels_store));
}

KeyType find_nearest_neighbor(const PointType& q, std::set<int> filters) {
KeyType find_nearest_neighbor(const PointType& q, std::set<int> filters, int_fast64_t max_iterations) {
return internal_nn_query_->find_nearest_neighbor(q, q, filters, num_probes_,
max_num_candidates_);
max_num_candidates_, max_iterations);
}

void find_k_nearest_neighbors(const PointType& q, int_fast64_t k,
Expand All @@ -323,9 +323,10 @@ class LSHNNQueryWrapper : public LSHNearestNeighborQuery<PointType, KeyType> {
}

void get_candidates_with_duplicates(const PointType& q,
std::vector<KeyType>* result) {
std::vector<KeyType>* result,
int_fast8_t iterations) {
internal_nn_query_->get_candidates_with_duplicates(
q, num_probes_, max_num_candidates_, result);
q, num_probes_, max_num_candidates_, result, iterations);
}

void get_unique_candidates(const PointType& q, std::vector<KeyType>* result) {
Expand Down Expand Up @@ -403,10 +404,10 @@ class LSHNNQueryPool : public LSHNearestNeighborQueryPool<PointType, KeyType> {
}
}

KeyType find_nearest_neighbor(const PointType& q, std::set<int> filters) {
KeyType find_nearest_neighbor(const PointType& q, std::set<int> filters, int_fast64_t max_iterations) {
int_fast32_t query_index = get_query_index_and_lock();
KeyType res = internal_nn_queries_[query_index]->find_nearest_neighbor(
q, q, filters, num_probes_, max_num_candidates_);
q, q, filters, num_probes_, max_num_candidates_, max_iterations);
unlock_query(query_index);
return res;
}
Expand All @@ -428,10 +429,11 @@ class LSHNNQueryPool : public LSHNearestNeighborQueryPool<PointType, KeyType> {
}

void get_candidates_with_duplicates(const PointType& q,
std::vector<KeyType>* result) {
std::vector<KeyType>* result,
int_fast8_t iterations) {
int_fast32_t query_index = get_query_index_and_lock();
internal_nn_queries_[query_index]->get_candidates_with_duplicates(
q, num_probes_, max_num_candidates_, result);
q, num_probes_, max_num_candidates_, result, iterations);
unlock_query(query_index);
}

Expand Down
4 changes: 2 additions & 2 deletions src/python/package/falconn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def find_near_neighbors(self, query, threshold):
format(threshold))
return self._inner_entity.find_near_neighbors(query, threshold)

def find_nearest_neighbor(self, query, query_filters):
def find_nearest_neighbor(self, query, query_filters, max_iterations):
"""Find the key of the closest candidate.

Finds the key of the closest candidate in the probing sequence
Expand All @@ -150,7 +150,7 @@ def find_nearest_neighbor(self, query, query_filters):
the second dimension of the dataset.
"""
self._check_query(query)
return self._inner_entity.find_nearest_neighbor(query, query_filters)
return self._inner_entity.find_nearest_neighbor(query, query_filters, max_iterations)

def get_candidates_with_duplicates(self, query):
"""Retrieve all the candidates for a given query.
Expand Down
Loading