diff --git a/include/lsh.h b/include/lsh.h index ac69c27..6bdb494 100644 --- a/include/lsh.h +++ b/include/lsh.h @@ -17,7 +17,8 @@ namespace grann { HashTable(_u32 table_size, _u32 vector_dim); ~HashTable(); - void generate_hps(); + template + void generate_hps(const T* sample_data, const size_t ndata); std::vector<_u32> get_bucket(bitstring bucket_id); @@ -25,17 +26,24 @@ namespace grann { bitstring get_hash(const T *input_vector); void add_vector(bitstring vector_hash, _u32 vector_id); - void add_hp(std::vector hp); + void add_hp(std::vector& hp); + void set_bias(std::vector& bias); void write_to_file(std::ofstream &out); void read_from_file(std::ifstream &in); + void print_balance(); + protected: _u32 vector_dim; // dimension of points stored/each hp vector _u32 table_size; // number of hyperplanes // float **random_hps; std::vector> random_hps; std::map> hashed_vectors; + + std::vector count_plus, count_minus; + + std::vector bias; }; template @@ -53,6 +61,7 @@ namespace grann { _u32 search(const T *query, _u32 res_count, const Parameters &search_params, _u32 *indices, float *distances, QueryStats *stats = nullptr, std::vector