13 changes: 11 additions & 2 deletions include/lsh.h
@@ -17,25 +17,33 @@ namespace grann {
HashTable(_u32 table_size, _u32 vector_dim);
~HashTable();

// samples the data to choose a per-hyperplane bias (median projection)
template<typename T>
void generate_hps(const T* sample_data, const size_t ndata);

std::vector<_u32> get_bucket(bitstring bucket_id);

template<typename T>
bitstring get_hash(const T *input_vector);

void add_vector(bitstring vector_hash, _u32 vector_id);
void add_hp(std::vector<float>& hp);
void set_bias(std::vector<float>& bias);

void write_to_file(std::ofstream &out);
void read_from_file(std::ifstream &in);

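// dumps, for every hyperplane, how many vectors fell on its positive vs negative side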
void print_balance();

protected:
_u32 vector_dim; // dimension of points stored/each hp vector
_u32 table_size; // number of hyperplanes
// float **random_hps;
std::vector<std::vector<float>> random_hps;
std::map<size_t, std::vector<_u32>> hashed_vectors;

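// per-hyperplane tallies of vectors hashed to bit 1 (plus) and bit 0 (minus)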
std::vector<size_t> count_plus, count_minus;

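// per-hyperplane threshold used by get_hash(); filled by generate_hps() with the
// median projection of the sample data, or loaded via set_bias()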
std::vector<float> bias;
};

template<typename T>
@@ -53,6 +61,7 @@ namespace grann {
_u32 search(const T *query, _u32 res_count, const Parameters &search_params,
_u32 *indices, float *distances, QueryStats *stats = nullptr,
std::vector<label> search_filters = std::vector<label>());
void print_balance();

protected:
_u32 num_tables;
89 changes: 77 additions & 12 deletions src/lsh.cpp
@@ -19,12 +19,18 @@ namespace grann {

vector_dim = vector_d;
table_size = table_s;

for (_u32 i=0; i<table_s; ++i) {
count_minus.push_back(0);
count_plus.push_back(0);
}
}

HashTable::~HashTable() {
}

template<typename T>
void HashTable::generate_hps(const T* sample_data, const size_t nsample) {
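// For each of the table_size hyperplanes: draw Gaussian coefficients, then set its
// bias to the median projection of the sample so each bit splits the data roughly in half.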
std::random_device r;
std::default_random_engine rng{r()};
std::normal_distribution<float> gaussian_dist;
@@ -36,6 +42,19 @@ namespace grann {
random_hp.push_back(add);
}
add_hp(random_hp);

// projections are floats regardless of T; size the vector up front so each
// OpenMP thread writes its own slot instead of racing on push_back
std::vector<float> hash_values(nsample);
#pragma omp parallel for
for (size_t i = 0; i < nsample; ++i) {
float dot_p = 0.0;
for (size_t j = 0; j < vector_dim; j++) {
dot_p += random_hp[j] * sample_data[i * vector_dim + j];
}
hash_values[i] = dot_p;
}
std::sort(hash_values.begin(), hash_values.end());
bias.push_back(hash_values[hash_values.size() / 2]);
}
}

@@ -45,32 +64,45 @@ namespace grann {

template<typename T>
bitstring HashTable::get_hash(const T *input_vector) {
bitstring bits;
for (size_t i = 0; i < table_size; i++) {
// float dot_p = std::inner_product(random_hps[i].begin(),
// random_hps[i].end(), input_vector, 0.0);
float dot_p = 0.0;
for (size_t j = 0; j < vector_dim; j++) {
dot_p += random_hps[i][j] * input_vector[j];
}

// compare against the learned bias instead of 0, and track how the data splits
if (dot_p > bias[i]) {
bits[i] = 1;
count_plus[i]++;
}
else {
bits[i] = 0;
count_minus[i]++;
}
}
return bits;
}

void HashTable::print_balance() {
for (_u32 i=0; i < table_size; ++i) {
std::cout << count_plus[i] << "|" << count_minus[i] << " ";
}
}

void HashTable::add_vector(bitstring vector_hash, _u32 vector_id) {
hashed_vectors[(size_t) vector_hash.to_ulong()].push_back(vector_id);
}

void HashTable::add_hp(std::vector<float>& hp) {
random_hps.push_back(hp);
}

void HashTable::set_bias(std::vector<float>& bias) {
this->bias = bias;
}

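// On-disk layout per table (sections numbered below): '@' start marker, the hyperplane
// coefficients (presumably written in section 1, which is elided from this diff), '%',
// the bias floats, another '%', then the bucket map.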
void HashTable::write_to_file(std::ofstream &out) {
// 0. write '@' marker to denote the start of a hashtable
out.put('@');
@@ -86,7 +118,14 @@ namespace grann {

out.put('%');

// 2. write bias values
for (const auto &b : bias) {
out.write(reinterpret_cast<const char *>(&b), sizeof(float));
}

out.put('%');

// 3. write map
_u32 num_buckets = hashed_vectors.size();
out.write(reinterpret_cast<const char *>(&num_buckets), sizeof(_u32));
for (const auto &bucket : hashed_vectors) {
@@ -134,7 +173,7 @@
// 1. generate hyperplanes for the table
for (size_t i = 0; i < num_tables; i++) {
HashTable table = HashTable(table_size, this->_aligned_dim);
table.generate_hps(this->_data, this->_num_points);
tables.push_back(table);
}

@@ -146,6 +185,15 @@
}
}
}

template<typename T>
void LSHIndex<T>::print_balance() {
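// prints one line per hash table: "plus|minus" counts for every hyperplane bit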
for (auto &table : tables) {
std::cout << std::endl;
table.print_balance();
}
std::cout << std::endl;
}

template<typename T>
_u32 LSHIndex<T>::search(const T *query, _u32 res_count,
@@ -274,6 +322,23 @@ namespace grann {
exit(1);
}

std::vector<float> bias;
for (size_t j = 0; j < f_table_size; j++) {
float next_bias_element;
in.read(reinterpret_cast<char *>(&next_bias_element), sizeof(float));
bias.push_back(next_bias_element);
}
curr_table.set_bias(bias);

in.get(mid);
if ((mid) != '%') {
printf("%d\n", mid);
perror(
"Malformed index file: missing '%' separator after bias values. "
"Exiting...");
exit(1);
}

// 3. get map
_u32 f_num_buckets;
in.read(reinterpret_cast<char *>(&f_num_buckets), sizeof(_u32));
1 change: 1 addition & 0 deletions tests/build_lsh.cpp
@@ -14,6 +14,7 @@ int build_lsh_index(const std::string &data_path, const std::string &labels_file

grann::LSHIndex<T> lsh(m, data_path.c_str(), idmap, labels_file);
lsh.build(params);
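// report how evenly each hyperplane split the data (plus|minus counts per bit)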
lsh.print_balance();
lsh.save(save_path.c_str());

return 0;