diff --git a/.gitignore b/.gitignore index 0250e6e..7ef241b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.idea/ build/ boost_1_76_0/ include/boost diff --git a/CMakeLists.txt b/CMakeLists.txt index 88efba1..99d903b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,7 +9,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - set(CMAKE_CXX_FLAGS_RELEASE "-O3") + set(CMAKE_CXX_FLAGS_RELEASE "-O3") endif() FetchContent_Declare(sqlite3 URL "http://sqlite.org/2021/sqlite-amalgamation-3340100.zip") diff --git a/Main.cpp b/Main.cpp index b3ed7b5..d399a9e 100644 --- a/Main.cpp +++ b/Main.cpp @@ -9,7 +9,7 @@ #include "ClusteringReader.h" #include "RuleApplication.h" #include "JaccardEngine.h" -#include "Explanation.h" +#include "SQLiteExplanation.h" #include "Util.hpp" #include #include @@ -54,7 +54,8 @@ int main(int argc, char** argv) Properties::get().REL_SIZE = index->getRelSize(); std::cout << "Reading testset..." << std::endl; - TesttripleReader* ttr = new TesttripleReader(Properties::get().PATH_TEST, index, graph, Properties::get().TRIAL); + TesttripleReader* ttr = new TesttripleReader(index, graph, Properties::get().TRIAL); + ttr->read(Properties::get().PATH_TEST); finish = std::chrono::high_resolution_clock::now(); milliseconds = std::chrono::duration_cast(finish - start); std::cout << "Testset read in " << milliseconds.count() << " ms\n"; @@ -86,7 +87,7 @@ int main(int argc, char** argv) Explanation* explanation = nullptr; if (Properties::get().EXPLAIN == 1) { std::cout << "Writing entities, relations and rules to db file..." << std::endl; - explanation = new Explanation(Properties::get().PATH_EXPLAIN, true); + explanation = new SQLiteExplanation(Properties::get().PATH_EXPLAIN, true); explanation->begin_tr(); explanation->insertEntities(index); explanation->insertRelations(index); diff --git a/include/Explanation.h b/include/Explanation.h index 726ad90..627b91f 100644 --- a/include/Explanation.h +++ b/include/Explanation.h @@ -1,7 +1,6 @@ #ifndef EXPL_H #define EXPL_H -#include #include "Index.h" #include "Rule.h" #include "RuleReader.h" @@ -9,34 +8,22 @@ class Explanation { public: - Explanation(std::string dbName, bool init = false); - ~Explanation(); + Explanation() {} + virtual ~Explanation() {} - void begin(); - void commit(); - void begin_tr(); - void commit_tr(); - void insertEntities(Index* index); - void insertRelations(Index* index); - void insertRules(RuleReader* rr, int relsize, ClusteringReader* cr); + virtual void begin() = 0; + virtual void commit() = 0; + virtual void begin_tr() = 0; + virtual void commit_tr() = 0; + virtual void insertEntities(Index* index) = 0; + virtual void insertRelations(Index* index) = 0; + virtual void insertRules(RuleReader* rr, int relsize, ClusteringReader* cr) = 0; - void insertTask(int prediction_id, bool is_head, int relation_id, int entity_id); - void insertPrediction(int task_id, int entity_id, bool hit, double confidence); - void insertRule_Entity(int rule_id, int task_id, int entity_id); + virtual void insertTask(int prediction_id, bool is_head, int relation_id, int entity_id) = 0; + virtual void insertPrediction(int task_id, int entity_id, bool hit, double confidence) = 0; + virtual void insertRule_Entity(int rule_id, int task_id, int entity_id) = 0; - // OLD - void insertCluster(int prediction_id, int entity_id, int cluster_id, double confidence); - void insertRule_Cluster(int prediction_id, int entity_id, int cluster_id, int rule_id); - - int getNextTaskID(); -private: - sqlite3* db; - int task_id = 0; - void initDb(); - void checkErrorCode(int code); - void checkErrorCode(int code, char* sql); - sqlite3_stmt* prepare(char* sql); - void finalize(sqlite3_stmt* stmt); + virtual int getNextTaskID() = 0; }; #endif //EXPL_H \ No newline at end of file diff --git a/include/InMemoryExplanation.h b/include/InMemoryExplanation.h new file mode 100644 index 0000000..cb8289e --- /dev/null +++ b/include/InMemoryExplanation.h @@ -0,0 +1,34 @@ +#ifndef INMEMORY_EXPL_H +#define INMEMORY_EXPL_H + +#include "Explanation.h" + +class InMemoryExplanation: public Explanation { +public: + InMemoryExplanation(); + ~InMemoryExplanation(); + + void begin(); + void commit(); + void begin_tr(); + void commit_tr(); + void insertEntities(Index* index); + void insertRelations(Index* index); + void insertRules(RuleReader* rr, int relsize, ClusteringReader* cr); + + void insertTask(int task_id, bool is_head, int relation_id, int entity_id); + void insertPrediction(int task_id, int entity_id, bool hit, double confidence); + void insertRule_Entity(int rule_id, int task_id, int entity_id); + + int getNextTaskID(); + + std::unordered_map>> tripleBestRules; + +private: + int _task_id = 0; + std::unordered_map ruleConfidences; + std::unordered_map> tasks; + std::unordered_map> taskEntityBestRules; +}; + +#endif //INMEMORY_EXPL_H diff --git a/include/RuleApplication.h b/include/RuleApplication.h index 3ae3420..4c63adb 100644 --- a/include/RuleApplication.h +++ b/include/RuleApplication.h @@ -4,6 +4,7 @@ #include #include #include +#include #include "Index.h" #include "TraintripleReader.h" #include "TesttripleReader.h" @@ -27,9 +28,12 @@ class RuleApplication { public: RuleApplication(Index* index, TraintripleReader* graph, TesttripleReader* ttr, ValidationtripleReader* vtr, RuleReader* rr, Explanation* exp); + RuleApplication(Index* index, TraintripleReader* graph, ValidationtripleReader* vtr, RuleReader* rr, Explanation* exp); void apply_nr_noisy(std::unordered_map>>, std::pair>>>> rel2clusters); void apply_only_noisy(); void apply_only_max(); + void updateTTR(TesttripleReader* ttr); + std::vector> apply_only_max_in_memory(size_t K); private: Index* index; diff --git a/include/RuleGraph.h b/include/RuleGraph.h index 79e0ceb..3e54578 100644 --- a/include/RuleGraph.h +++ b/include/RuleGraph.h @@ -16,7 +16,9 @@ class RuleGraph { public: RuleGraph(int nodesize, TraintripleReader* graph); + RuleGraph(int nodesize, TraintripleReader* graph, ValidationtripleReader* vtr); RuleGraph(int nodesize, TraintripleReader* graph, TesttripleReader* ttr, ValidationtripleReader* vtr); + void updateTTR(TesttripleReader *ttr); void searchDFSSingleStart_filt(bool headNotTail, int filt_v, int v, Rule& r, bool bwd, std::vector& solution, bool filtValidNotTest, bool filtExceptions); void searchDFSMultiStart_filt(bool headNotTail, int filt_v, Rule& r, bool bwd, std::vector& solution, bool filtValidNotTest, bool filtExceptions); bool existsAcyclic(int* valId, Rule& rule, bool filtValidNotTest); diff --git a/include/SQLiteExplanation.h b/include/SQLiteExplanation.h new file mode 100644 index 0000000..08b89d1 --- /dev/null +++ b/include/SQLiteExplanation.h @@ -0,0 +1,39 @@ +#ifndef SQLITE_EXPL_H +#define SQLITE_EXPL_H + +#include +#include "Explanation.h" + +class SQLiteExplanation: public Explanation { +public: + SQLiteExplanation(std::string dbName, bool init = false); + ~SQLiteExplanation(); + + void begin(); + void commit(); + void begin_tr(); + void commit_tr(); + void insertEntities(Index* index); + void insertRelations(Index* index); + void insertRules(RuleReader* rr, int relsize, ClusteringReader* cr); + + void insertTask(int task_id, bool is_head, int relation_id, int entity_id); + void insertPrediction(int task_id, int entity_id, bool hit, double confidence); + void insertRule_Entity(int rule_id, int task_id, int entity_id); + + // OLD + void insertCluster(int prediction_id, int entity_id, int cluster_id, double confidence); + void insertRule_Cluster(int prediction_id, int entity_id, int cluster_id, int rule_id); + + int getNextTaskID(); +private: + sqlite3* db; + int _task_id = 0; + void initDb(); + void checkErrorCode(int code); + void checkErrorCode(int code, char* sql); + sqlite3_stmt* prepare(char* sql); + void finalize(sqlite3_stmt* stmt); +}; + +#endif //SQLITE_EXPL_H diff --git a/include/TesttripleReader.h b/include/TesttripleReader.h index 8610031..8e0cc56 100644 --- a/include/TesttripleReader.h +++ b/include/TesttripleReader.h @@ -21,7 +21,7 @@ class TesttripleReader { public: - TesttripleReader(std::string filepath, Index* index, TraintripleReader* graph, int is_trial); + TesttripleReader(Index* index, TraintripleReader* graph, int is_trial); int** getTesttriples(); int* getTesttriplesSize(); @@ -29,6 +29,9 @@ class TesttripleReader RelNodeToNodes& getRelHeadToTails(); RelNodeToNodes& getRelTailToHeads(); + void read(std::string filepath); + void read(std::vector> & triples); + protected: private: @@ -41,8 +44,6 @@ class TesttripleReader RelNodeToNodes relHeadToTails; RelNodeToNodes relTailToHeads; - - void read(std::string filepath); }; #endif // TESTTRIPLEREADER_H diff --git a/python_bindings/.gitignore b/python_bindings/.gitignore new file mode 100644 index 0000000..d6b83bb --- /dev/null +++ b/python_bindings/.gitignore @@ -0,0 +1,6 @@ +build/ +dist/ +*.egg-info/ +__pycache__ +safran_wrapper.py +safran_wrapper_wrap.cpp diff --git a/python_bindings/safran.py b/python_bindings/safran.py new file mode 100644 index 0000000..604fae3 --- /dev/null +++ b/python_bindings/safran.py @@ -0,0 +1,18 @@ +import itertools +import collections +from typing import List, Dict, Tuple +from safran_wrapper import pysafran, query_output_t, query_triples_t + +class SAFRAN(pysafran): + def __init__(self, train_path: str, rule_path: str, n_jobs: int = 1): + pysafran.__init__(self, train_path, rule_path, n_jobs) + + def query(self, triples: List[List[str]], k: int = 100, action: str = 'applymax') -> Dict[Tuple[str, str], List[Tuple[str, float, int]]]: + if action != 'applymax': + raise ValueError('Actions supported in the SAFRAN Python wrapper are: applymax') + flat_triples = list(itertools.chain.from_iterable(triples)) + pred_vals = query_output_t(pysafran.query(self, action, k, query_triples_t(flat_triples))) + out = collections.defaultdict(list) + for head, (pred, (tail, (val, rule_id))) in pred_vals: + out[head, pred].append((tail, val, rule_id)) + return out diff --git a/python_bindings/safran_wrapper.cpp b/python_bindings/safran_wrapper.cpp new file mode 100644 index 0000000..5da50b2 --- /dev/null +++ b/python_bindings/safran_wrapper.cpp @@ -0,0 +1,69 @@ +#include "safran_wrapper.h" + +pysafran::pysafran(std::string train_path, std::string rule_path, int num_threads) { + // TODO: initialize these globally, or make those a part of pysafran instance. + Properties::get().VERBOSE = 0; + Properties::get().PREDICT_UNKNOWN = 1; + if (num_threads != -1) { + omp_set_num_threads(num_threads); + } + + this->index = new Index(); + index->addNode(Properties::get().REFLEXIV_TOKEN); + index->addNode(Properties::get().UNK_TOKEN); + + this->graph = new TraintripleReader(train_path, index); + Properties::get().REL_SIZE = index->getRelSize(); + + this->rr = new RuleReader(rule_path, index, graph); + this->vtr = new ValidationtripleReader("/dev/null", index, graph); + + this->exp = new InMemoryExplanation(); + this->exp->insertRules(rr, index->getRelSize(), nullptr); + + this->ra = new RuleApplication(index, graph, vtr, rr, exp); +} + +pysafran::~pysafran() { + delete this->ra; + delete this->vtr; + delete this->rr; + delete this->graph; + delete this->index; + delete this->exp; +} + +std::vector>>>> pysafran::query(const std::string & action, size_t k, const std::vector & flat_triples) const { + // "Read" triples + TesttripleReader ttr(index, graph, 0); + std::vector> triples; + size_t i = 0; + while (i < flat_triples.size()) { + triples.emplace_back(flat_triples[i], flat_triples[i + 1], flat_triples[i + 2]); + i += 3; + } + ttr.read(triples); + ra->updateTTR(&ttr); + + // Apply specific rule application + std::vector> outAction; + if (action == "applymax") { + outAction = ra->apply_only_max_in_memory(k); + } + + // Transform node ids to node names and only retain top-k + std::vector>>>> out; + for (auto & p : outAction) { + int head_id = std::get<0>(p), relation_id = std::get<1>(p), tail_id = std::get<2>(p); + float val = (float)std::get<3>(p); + std::string headStr = *this->index->getStringOfNodeId(head_id); + std::string predStr = *this->index->getStringOfRelId(relation_id); + std::string tailStr = *this->index->getStringOfNodeId(tail_id); + int rule_id = exp->tripleBestRules[head_id][relation_id][tail_id]; + std::pair p1 = {val, rule_id}; + std::pair> p2 = {tailStr, p1}; + std::pair>> p3 = {predStr, p2}; + out.emplace_back(headStr, p3); + } + return out; +} diff --git a/python_bindings/safran_wrapper.h b/python_bindings/safran_wrapper.h new file mode 100644 index 0000000..243e789 --- /dev/null +++ b/python_bindings/safran_wrapper.h @@ -0,0 +1,31 @@ +#include "Index.h" +#include "TraintripleReader.h" +#include "ValidationtripleReader.h" +#include "RuleReader.h" +#include "RuleApplication.h" +#include "Properties.hpp" +#include "InMemoryExplanation.h" + +#include +#include +#include +#include + +using string_vector_t = std::vector; + +class pysafran { +public: + pysafran(std::string train_path, std::string rule_path, int num_threads); + ~pysafran(); + + // [(head, pred, tail)] -> [(head, pred, tail, rule_confidence, rule_id)] + std::vector>>>> query(const std::string & action, size_t k, const std::vector & flat_triples) const; + +private: + Index *index; + TraintripleReader* graph; + ValidationtripleReader* vtr; + RuleReader* rr; + RuleApplication *ra; + InMemoryExplanation *exp; +}; \ No newline at end of file diff --git a/python_bindings/safran_wrapper.i b/python_bindings/safran_wrapper.i new file mode 100644 index 0000000..247715e --- /dev/null +++ b/python_bindings/safran_wrapper.i @@ -0,0 +1,20 @@ +%module safran_wrapper + +%{ +#define SWIG_FILE_WITH_INIT +#include "safran_wrapper.h" +%} + +%include +%include +%include +%include "safran_wrapper.h" + +namespace std { + %template(query_triples_t) vector; + %template(query_p1) pair; + %template(query_p2) pair>; + %template(query_p3) pair>>; + %template(query_p4) pair>>>; + %template(query_output_t) vector>>>>; +} diff --git a/python_bindings/setup.py b/python_bindings/setup.py new file mode 100644 index 0000000..67a3428 --- /dev/null +++ b/python_bindings/setup.py @@ -0,0 +1,51 @@ +import os +import sys +import setuptools +from glob import glob +from setuptools import setup, Extension +from setuptools.command.install import install +from distutils.command.build import build + +__version__ = '1.0.0' + +source_files = ['safran_wrapper.cpp', 'safran_wrapper.i'] + +source_files += glob('../src/*.cpp') +source_files = [src for src in source_files if src not in {'../src/SQLiteExplanation.cpp'}] + +ext_modules = [ + Extension( + '_safran_wrapper', + source_files, + swig_opts=["-c++", "-extranative"], + language='c++', + extra_compile_args=["-std=c++17", "-I../include", "-I../boost_1_76_0", "-fopenmp"], + extra_link_args=["-fopenmp"], + ), +] + +class CustomBuild(build): + def run(self): + self.run_command('build_ext') + build.run(self) + +class CustomInstall(install): + def run(self): + self.run_command('build_ext') + self.do_egg_install() + +custom_cmdclass = {'build': CustomBuild, 'install': CustomInstall} + +setup( + name='safran', + version=__version__, + description='SAFRAN: Scalable and fast non-redundant rule application', + author='S. Ott, C. Melicke, M. Samwald, A. Belyy', + url='https://github.com/AVBelyy/SAFRAN', + ext_modules=ext_modules, + cmdclass=custom_cmdclass, + install_requires=[], + setup_requires=[], + py_modules=['safran', 'safran_wrapper'], + zip_safe=False, +) diff --git a/src/Explanation.cpp b/src/Explanation.cpp deleted file mode 100644 index cecf4ae..0000000 --- a/src/Explanation.cpp +++ /dev/null @@ -1,332 +0,0 @@ -#include "Explanation.h" - -Explanation::Explanation(std::string dbName, bool init) { - checkErrorCode(sqlite3_open_v2(dbName.c_str(), &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr)); - if (init) initDb(); -} - -Explanation::~Explanation() { - sqlite3_close(db); -} - -void Explanation::initDb() { - - checkErrorCode(sqlite3_exec(db, "PRAGMA synchronous=OFF", NULL, NULL, NULL)); - checkErrorCode(sqlite3_exec(db, "PRAGMA count_changes=OFF", NULL, NULL, NULL)); - checkErrorCode(sqlite3_exec(db, "PRAGMA journal_mode=MEMORY", NULL, NULL, NULL)); - checkErrorCode(sqlite3_exec(db, "PRAGMA temp_store=MEMORY", NULL, NULL, NULL)); - - char* sql; - - sql = "CREATE TABLE Entity(" \ - "ID INT PRIMARY KEY NOT NULL," \ - "NAME TEXT NOT NULL" \ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); - - sql = "CREATE TABLE Relation(" \ - "ID INT PRIMARY KEY NOT NULL," \ - "NAME TEXT NOT NULL" \ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); - - sql = "CREATE TABLE Rule(" \ - "ID INT PRIMARY KEY NOT NULL," \ - "HEAD_CLUSTER_ID INT NOT NULL," \ - "TAIL_CLUSTER_ID INT NOT NULL," \ - "DEF TEXT NOT NULL," \ - "CONFIDENCE DOUBLE NOT NULL,"\ - "PREDICTED INT NOT NULL,"\ - "CORRECTLY_PREDICTED INT NOT NULL"\ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); - - sql = "CREATE TABLE Task(" \ - "ID INT PRIMARY KEY NOT NULL," \ - "IsHead BOOL," \ - "EntityID INT," \ - "RelationID INT," \ - "FOREIGN KEY(EntityID) REFERENCES Entity(ID)," \ - "FOREIGN KEY(RelationID) REFERENCES Relation(ID)" \ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); - - sql = "CREATE TABLE Prediction(" \ - "TaskID INT NOT NULL," \ - "EntityID INT NOT NULL," \ - "Hit BOOL NOT NULL," - "CONFIDENCE REAL NOT NULL,"\ - "FOREIGN KEY(TaskID) REFERENCES Task(ID)," \ - "FOREIGN KEY(EntityID) REFERENCES Entity(ID)," \ - "PRIMARY KEY(TaskID, EntityID)" \ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); - - sql = "CREATE TABLE Rule_Entity(" \ - "RuleID INT NOT NULL," \ - "TaskID INT NOT NULL," \ - "EntityID INT NOT NULL," \ - "FOREIGN KEY(RuleID) REFERENCES Rule(ID)," \ - "FOREIGN KEY(TaskID) REFERENCES Task(ID)," \ - "FOREIGN KEY(EntityID) REFERENCES Entity(ID)," \ - "PRIMARY KEY(RuleID, TaskID, EntityID)" \ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); - - sql = "CREATE INDEX asdf ON Rule_Entity (" \ - "TaskID," \ - "EntityID" \ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); -} - -void Explanation::insertEntities(Index* index) { - char* sql = "INSERT INTO Entity (ID,NAME) VALUES (?,?);"; - sqlite3_stmt* stmt = prepare(sql); - for (int id = 0; id < index->getNodeSize(); id++) { - - // bind the value - checkErrorCode(sqlite3_bind_int(stmt, 1, id)); - - std::string curie = *index->getStringOfNodeId(id); - checkErrorCode(sqlite3_bind_text(stmt, 2, curie.c_str(), strlen(curie.c_str()), 0)); - - checkErrorCode(sqlite3_step(stmt)); - checkErrorCode(sqlite3_reset(stmt)); - } - finalize(stmt); -} -void Explanation::insertRelations(Index* index) { - char* sql = "INSERT INTO Relation (ID,NAME) VALUES (?,?);"; - sqlite3_stmt* stmt = prepare(sql); - for (int id = 0; id < index->getRelSize(); id++) { - - // bind the value - checkErrorCode(sqlite3_bind_int(stmt, 1, id)); - - std::string relation = *index->getStringOfRelId(id); - checkErrorCode(sqlite3_bind_text(stmt, 2, relation.c_str(), strlen(relation.c_str()), 0)); - - checkErrorCode(sqlite3_step(stmt)); - checkErrorCode(sqlite3_reset(stmt)); - } - finalize(stmt); -} - -void Explanation::insertRules(RuleReader* rr, int relsize, ClusteringReader* cr) { - int* adj_begin = rr->getCSR()->getAdjBegin(); - Rule* rules_adj_list = rr->getCSR()->getAdjList(); - - char* sql = "INSERT INTO Rule (ID,HEAD_CLUSTER_ID,TAIL_CLUSTER_ID,DEF,CONFIDENCE,PREDICTED,CORRECTLY_PREDICTED) VALUES (?,?,?,?,?,?,?);"; - sqlite3_stmt* stmt = prepare(sql); - - if (cr == nullptr) { - for (int rel = 0; rel < relsize; rel++) { - int ind_ptr = adj_begin[3 + rel]; - int lenRules = adj_begin[3 + rel + 1] - ind_ptr; - for (int i = 0; i < lenRules; i++) { - - Rule& r = rules_adj_list[ind_ptr + i]; - // bind the value - checkErrorCode(sqlite3_bind_int(stmt, 1, r.getID())); - - checkErrorCode(sqlite3_bind_int(stmt, 2, 0)); - checkErrorCode(sqlite3_bind_int(stmt, 3, 0)); - - std::string rulestr = r.getRulestring(); - checkErrorCode(sqlite3_bind_text(stmt, 4, rulestr.c_str(), strlen(rulestr.c_str()), 0)); - - checkErrorCode(sqlite3_bind_double(stmt, 5, r.getAppliedConfidence())); - checkErrorCode(sqlite3_bind_int(stmt, 6, r.getPredicted())); - checkErrorCode(sqlite3_bind_int(stmt, 7, r.getCorrectlyPredicted())); - - checkErrorCode(sqlite3_step(stmt)); - checkErrorCode(sqlite3_reset(stmt)); - } - } - } - else { - std::unordered_map>>, std::pair>>>> rel2clusters = cr->getRelToClusters(); - for (int rel = 0; rel < relsize; rel++) { - int ind_ptr = adj_begin[3 + rel]; - int lenRules = adj_begin[3 + rel + 1] - ind_ptr; - - std::unordered_map ruleToHeadCluster; - std::unordered_map ruleToTailCluster; - - // Head clusters - std::vector> cluster = rel2clusters[rel].first.second; - for (int cluster_id = 0; cluster_id < cluster.size(); cluster_id++) { - for (auto rule : cluster[cluster_id]) { - Rule& r = rules_adj_list[ind_ptr + rule]; - ruleToHeadCluster[r.getID()] = cluster_id; - } - } - - cluster = rel2clusters[rel].second.second; - for (int cluster_id = 0; cluster_id < cluster.size(); cluster_id++) { - for (auto rule : cluster[cluster_id]) { - Rule& r = rules_adj_list[ind_ptr + rule]; - ruleToTailCluster[r.getID()] = cluster_id; - } - } - - for (int i = 0; i < lenRules; i++) { - - Rule& r = rules_adj_list[ind_ptr + i]; - // bind the value - checkErrorCode(sqlite3_bind_int(stmt, 1, r.getID())); - - checkErrorCode(sqlite3_bind_int(stmt, 2, ruleToHeadCluster[r.getID()])); - - checkErrorCode(sqlite3_bind_int(stmt, 3, ruleToTailCluster[r.getID()])); - - std::string rulestr = r.getRulestring(); - checkErrorCode(sqlite3_bind_text(stmt, 4, rulestr.c_str(), strlen(rulestr.c_str()), 0)); - - checkErrorCode(sqlite3_bind_double(stmt, 5, r.getAppliedConfidence())); - checkErrorCode(sqlite3_bind_int(stmt, 6, r.getPredicted())); - checkErrorCode(sqlite3_bind_int(stmt, 7, r.getCorrectlyPredicted())); - - checkErrorCode(sqlite3_step(stmt)); - checkErrorCode(sqlite3_reset(stmt)); - } - } - } - - finalize(stmt); -} - -void Explanation::insertTask(int task_id, bool is_head, int relation_id, int entity_id) { - //std::cout << "Task " << task_id << " " << is_head << " " << relation_id << " " << entity_id << "\n"; - char* sql = "INSERT INTO Task (ID,IsHead,RelationID,EntityId) VALUES (?,?,?,?);"; - sqlite3_stmt* stmt = prepare(sql); - - checkErrorCode(sqlite3_bind_int(stmt, 1, task_id), "insert task bind 1"); - checkErrorCode(sqlite3_bind_int(stmt, 2, (int)is_head), "insert task bind 2"); - checkErrorCode(sqlite3_bind_int(stmt, 3, relation_id), "insert task bind 3"); - checkErrorCode(sqlite3_bind_int(stmt, 4, entity_id), "insert task bind 4"); - - checkErrorCode(sqlite3_step(stmt), "insert task step"); - finalize(stmt); -} - -void Explanation::insertPrediction(int task_id, int entity_id, bool hit, double confidence) { - //std::cout << "Prediction " << task_id << " " << entity_id << " " << " " << confidence << "\n"; - char* sql = "INSERT INTO Prediction (TaskID,EntityID,Hit,Confidence) VALUES (?,?,?,?);"; - sqlite3_stmt* stmt = prepare(sql); - - checkErrorCode(sqlite3_bind_int(stmt, 1, task_id), "insert pred bind 1"); - checkErrorCode(sqlite3_bind_int(stmt, 2, entity_id), "insert pred bind 2"); - checkErrorCode(sqlite3_bind_double(stmt, 3, (int)hit), "insert pred bind 3"); - checkErrorCode(sqlite3_bind_double(stmt, 4, confidence), "insert pred bind 4"); - - checkErrorCode(sqlite3_step(stmt), "insert pred step"); - finalize(stmt); -} - -void Explanation::insertCluster(int task_id, int entity_id, int cluster_id, double confidence) { - //std::cout << "Cluster " << task_id << " " << entity_id << " " << cluster_id << " " << confidence << "\n"; - char* sql = "INSERT INTO Cluster (PredictionTaskID, PredictionEntityID, ID, Confidence) VALUES (?,?,?,?);"; - sqlite3_stmt* stmt = prepare(sql); - - checkErrorCode(sqlite3_bind_int(stmt, 1, task_id), "insert clus bind 1"); - checkErrorCode(sqlite3_bind_int(stmt, 2, entity_id), "insert clus bind 2"); - checkErrorCode(sqlite3_bind_int(stmt, 3, cluster_id), "insert clus bind 3"); - checkErrorCode(sqlite3_bind_double(stmt, 4, confidence), "insert clus bind 1"); - - checkErrorCode(sqlite3_step(stmt), sql); - finalize(stmt); -} - -void Explanation::insertRule_Cluster(int task_id, int entity_id, int cluster_id, int rule_id) { - //std::cout << "Rule_Cluster " << task_id << " " << entity_id << " " << cluster_id << " " << rule_id << "\n"; - char* sql = "INSERT INTO Rule_Cluster (ClusterPredictionTaskID,ClusterPredictionEntityID,ClusterID,RuleID) VALUES (?,?,?,?);"; - sqlite3_stmt* stmt = prepare(sql); - - checkErrorCode(sqlite3_bind_int(stmt, 1, task_id)); - checkErrorCode(sqlite3_bind_int(stmt, 2, entity_id)); - checkErrorCode(sqlite3_bind_int(stmt, 3, cluster_id)); - checkErrorCode(sqlite3_bind_int(stmt, 4, rule_id)); - - checkErrorCode(sqlite3_step(stmt)); - finalize(stmt); -} - -void Explanation::insertRule_Entity(int rule_id, int task_id, int entity_id) { - //std::cout << "Rule_Cluster " << task_id << " " << entity_id << " " << cluster_id << " " << rule_id << "\n"; - char* sql = "INSERT OR IGNORE INTO Rule_Entity (RuleID, TaskID, EntityID) VALUES (?,?,?);"; - sqlite3_stmt* stmt = prepare(sql); - - checkErrorCode(sqlite3_bind_int(stmt, 1, rule_id)); - checkErrorCode(sqlite3_bind_int(stmt, 2, task_id)); - checkErrorCode(sqlite3_bind_int(stmt, 3, entity_id)); - - checkErrorCode(sqlite3_step(stmt)); - finalize(stmt); -} - -void Explanation::begin_tr() { - checkErrorCode(sqlite3_exec(db, "BEGIN TRANSACTION", NULL, NULL, NULL), "BEGIN TRANSACTION"); -} - -void Explanation::begin() { - //sqlite3_mutex_enter(sqlite3_db_mutex(db)); -} - -sqlite3_stmt* Explanation::prepare(char* sql) { - sqlite3_stmt* stmt; - checkErrorCode(sqlite3_prepare(db, sql, -1, &stmt, NULL), sql); - return stmt; -} - -void Explanation::finalize(sqlite3_stmt* stmt) { - checkErrorCode(sqlite3_finalize(stmt)); -} - -void Explanation::commit() { - //sqlite3_mutex_leave(sqlite3_db_mutex(db)); -} - -void Explanation::commit_tr() { - checkErrorCode(sqlite3_exec(db, "COMMIT TRANSACTION", NULL, NULL, NULL), "COMMIT TRANSACTION"); -} - -void Explanation::checkErrorCode(int code) { - if (code != SQLITE_OK && code != SQLITE_DONE) { - const char* err; - if (this->db) { - err = sqlite3_errmsg(this->db); - } - else { - err = sqlite3_errstr(code); - } - std::cerr << "Error " << code << ": " << err << '\n'; - std::exit(EXIT_FAILURE); - } -} - -void Explanation::checkErrorCode(int code, char* sql) { - if (code != SQLITE_OK && code != SQLITE_DONE) { - const char* err; - if (this->db) { - err = sqlite3_errmsg(this->db); - } - else { - err = sqlite3_errstr(code); - } - std::cerr << sql << "\n"; - std::cerr << "Error " << code << ": " << err << '\n'; - std::exit(EXIT_FAILURE); - } -} - -int Explanation::getNextTaskID() { - int task_id_; -#pragma omp critical - { - task_id++; - task_id_ = task_id; - } - return task_id_; -} \ No newline at end of file diff --git a/src/InMemoryExplanation.cpp b/src/InMemoryExplanation.cpp new file mode 100644 index 0000000..7fd9d01 --- /dev/null +++ b/src/InMemoryExplanation.cpp @@ -0,0 +1,90 @@ +#include "InMemoryExplanation.h" + +InMemoryExplanation::InMemoryExplanation() = default; + +InMemoryExplanation::~InMemoryExplanation() = default; + +void InMemoryExplanation::insertEntities(Index* index) { +} + +void InMemoryExplanation::insertRelations(Index* index) { +} + +void InMemoryExplanation::insertRules(RuleReader* rr, int relsize, ClusteringReader* cr) { + int* adj_begin = rr->getCSR()->getAdjBegin(); + Rule* rules_adj_list = rr->getCSR()->getAdjList(); + + for (int rel = 0; rel < relsize; rel++) { + int ind_ptr = adj_begin[3 + rel]; + int lenRules = adj_begin[3 + rel + 1] - ind_ptr; + for (int i = 0; i < lenRules; i++) { + Rule &r = rules_adj_list[ind_ptr + i]; + int rule_id = r.getID(); + auto conf = (float)r.getAppliedConfidence(); + ruleConfidences[rule_id] = conf; + } + } +} + +void InMemoryExplanation::insertTask(int task_id, bool is_head, int relation_id, int entity_id) { +#pragma omp critical + { + tasks[task_id] = {relation_id, entity_id, is_head}; + } +} + +void InMemoryExplanation::insertPrediction(int task_id, int entity_id, bool hit, double confidence) { +} + +void InMemoryExplanation::insertRule_Entity(int rule_id, int task_id, int other_entity_id) { + bool need_update = true; + if (taskEntityBestRules.find(task_id) != taskEntityBestRules.end()) { + if (taskEntityBestRules[task_id].find(other_entity_id) != taskEntityBestRules[task_id].end()) { + int prev_rule_id = taskEntityBestRules[task_id][other_entity_id]; + float prev_conf = ruleConfidences[prev_rule_id], new_conf = ruleConfidences[rule_id]; + if (new_conf <= prev_conf) { + need_update = false; + } + } + } + if (need_update) { + auto & p = tasks[task_id]; + int relation_id = std::get<0>(p), entity_id = std::get<1>(p), is_head = std::get<2>(p); + int head_id, tail_id; + if (is_head) { + head_id = entity_id; + tail_id = other_entity_id; + } else { + head_id = other_entity_id; + tail_id = entity_id; + } +#pragma omp critical + // Update the current best rule, according to the "applymax" action + { + taskEntityBestRules[task_id][other_entity_id] = rule_id; + tripleBestRules[head_id][relation_id][tail_id] = rule_id; + } + } +} + +void InMemoryExplanation::begin_tr() { +} + +void InMemoryExplanation::begin() { +} + +void InMemoryExplanation::commit() { +} + +void InMemoryExplanation::commit_tr() { +} + +int InMemoryExplanation::getNextTaskID() { + int task_id_; +#pragma omp critical + { + _task_id++; + task_id_ = _task_id; + } + return task_id_; +} \ No newline at end of file diff --git a/src/RuleApplication.cpp b/src/RuleApplication.cpp index 5d40787..6d2185b 100644 --- a/src/RuleApplication.cpp +++ b/src/RuleApplication.cpp @@ -1,5 +1,16 @@ #include "RuleApplication.h" +RuleApplication::RuleApplication(Index* index, TraintripleReader* graph, ValidationtripleReader* vtr, RuleReader* rr, Explanation* exp) { + this->index = index; + this->graph = graph; + this->vtr = vtr; + this->rr = rr; + this->rulegraph = new RuleGraph(index->getNodeSize(), graph, vtr); + reflexiv_token = *index->getIdOfNodestring(Properties::get().REFLEXIV_TOKEN); + this->k = Properties::get().TOP_K_OUTPUT; + this->exp = exp; +} + RuleApplication::RuleApplication(Index* index, TraintripleReader* graph, TesttripleReader* ttr, ValidationtripleReader* vtr, RuleReader* rr, Explanation* exp) { this->index = index; this->graph = graph; @@ -12,6 +23,11 @@ RuleApplication::RuleApplication(Index* index, TraintripleReader* graph, Testtri this->exp = exp; } +void RuleApplication::updateTTR(TesttripleReader* testReader) { + this->ttr = testReader; + this->rulegraph->updateTTR(testReader); +} + void RuleApplication::apply_nr_noisy(std::unordered_map>>, std::pair>>>> rel2clusters) { int* adj_begin = rr->getCSR()->getAdjBegin(); @@ -148,6 +164,43 @@ void RuleApplication::apply_only_max() { fclose(pFile); } +std::vector> RuleApplication::apply_only_max_in_memory(size_t K) { + int* adj_begin = rr->getCSR()->getAdjBegin(); + std::vector> out; + + int iterations = index->getRelSize(); + for (int rel = 0; rel < iterations; rel++) { + // TODO: precompute clusters outside of this call + int ind_ptr = adj_begin[3 + rel]; + int lenRules = adj_begin[3 + rel + 1] - ind_ptr; + std::vector> clusters; + std::vector cluster; + for (int i = 0; i < lenRules; i++) { + cluster.push_back(i); + } + clusters.push_back(cluster); + + std::unordered_map>>> headTailResults; + headTailResults = max(rel, clusters, false); + + for (auto & [head, tailResults] : headTailResults) { + for (auto & [_, tails] : tailResults) { + size_t maxSize = std::min(tails.size(), K); + size_t i = 0; + for (auto [tail, val] : tails) { + if (i >= maxSize) { + break; + } + out.emplace_back(head, rel, tail, val); + i++; + } + } + } + } + + return out; +} + std::unordered_map>>> RuleApplication::noisy(int rel, std::vector> clusters, bool predictHeadNotTail) { int* adj_lists = graph->getCSR()->getAdjList(); int* adj_list_starts = graph->getCSR()->getAdjBegin(); diff --git a/src/RuleGraph.cpp b/src/RuleGraph.cpp index e364692..416d6f2 100644 --- a/src/RuleGraph.cpp +++ b/src/RuleGraph.cpp @@ -21,6 +21,23 @@ RuleGraph::RuleGraph(int nodesize, TraintripleReader* graph, TesttripleReader* t this->relCounter = graph->getRelCounter(); } +RuleGraph::RuleGraph(int nodesize, TraintripleReader* graph, ValidationtripleReader* vtr) { + this->size = nodesize; + this->graph = graph; + adj_lists = graph->getCSR()->getAdjList(); + adj_list_starts = graph->getCSR()->getAdjBegin(); + this->train_relHeadToTails = graph->getRelHeadToTails(); + this->train_relTailToHeads = graph->getRelTailToHeads(); + this->valid_relHeadToTails = vtr->getRelHeadToTails(); + this->valid_relTailToHeads = vtr->getRelTailToHeads(); + this->relCounter = graph->getRelCounter(); +} + +void RuleGraph::updateTTR(TesttripleReader* ttr) { + this->test_relHeadToTails = ttr->getRelHeadToTails(); + this->test_relTailToHeads = ttr->getRelTailToHeads(); +} + void RuleGraph::searchDFSSingleStart_filt(bool headNotTail, int filt_v, int v, Rule& r, bool bwd, std::vector& solution, bool filtValidNotTest, bool filtExceptions) { int rulelength = r.getRulelength(); diff --git a/src/SQLiteExplanation.cpp b/src/SQLiteExplanation.cpp new file mode 100644 index 0000000..2b090bb --- /dev/null +++ b/src/SQLiteExplanation.cpp @@ -0,0 +1,332 @@ +#include "SQLiteExplanation.h" + +SQLiteExplanation::SQLiteExplanation(std::string dbName, bool init) { + checkErrorCode(sqlite3_open_v2(dbName.c_str(), &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr)); + if (init) initDb(); +} + +SQLiteExplanation::~SQLiteExplanation() { + sqlite3_close(db); +} + +void SQLiteExplanation::initDb() { + + checkErrorCode(sqlite3_exec(db, "PRAGMA synchronous=OFF", NULL, NULL, NULL)); + checkErrorCode(sqlite3_exec(db, "PRAGMA count_changes=OFF", NULL, NULL, NULL)); + checkErrorCode(sqlite3_exec(db, "PRAGMA journal_mode=MEMORY", NULL, NULL, NULL)); + checkErrorCode(sqlite3_exec(db, "PRAGMA temp_store=MEMORY", NULL, NULL, NULL)); + + char* sql; + + sql = "CREATE TABLE Entity(" \ + "ID INT PRIMARY KEY NOT NULL," \ + "NAME TEXT NOT NULL" \ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); + + sql = "CREATE TABLE Relation(" \ + "ID INT PRIMARY KEY NOT NULL," \ + "NAME TEXT NOT NULL" \ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); + + sql = "CREATE TABLE Rule(" \ + "ID INT PRIMARY KEY NOT NULL," \ + "HEAD_CLUSTER_ID INT NOT NULL," \ + "TAIL_CLUSTER_ID INT NOT NULL," \ + "DEF TEXT NOT NULL," \ + "CONFIDENCE DOUBLE NOT NULL,"\ + "PREDICTED INT NOT NULL,"\ + "CORRECTLY_PREDICTED INT NOT NULL"\ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); + + sql = "CREATE TABLE Task(" \ + "ID INT PRIMARY KEY NOT NULL," \ + "IsHead BOOL," \ + "EntityID INT," \ + "RelationID INT," \ + "FOREIGN KEY(EntityID) REFERENCES Entity(ID)," \ + "FOREIGN KEY(RelationID) REFERENCES Relation(ID)" \ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); + + sql = "CREATE TABLE Prediction(" \ + "TaskID INT NOT NULL," \ + "EntityID INT NOT NULL," \ + "Hit BOOL NOT NULL," + "CONFIDENCE REAL NOT NULL,"\ + "FOREIGN KEY(TaskID) REFERENCES Task(ID)," \ + "FOREIGN KEY(EntityID) REFERENCES Entity(ID)," \ + "PRIMARY KEY(TaskID, EntityID)" \ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); + + sql = "CREATE TABLE Rule_Entity(" \ + "RuleID INT NOT NULL," \ + "TaskID INT NOT NULL," \ + "EntityID INT NOT NULL," \ + "FOREIGN KEY(RuleID) REFERENCES Rule(ID)," \ + "FOREIGN KEY(TaskID) REFERENCES Task(ID)," \ + "FOREIGN KEY(EntityID) REFERENCES Entity(ID)," \ + "PRIMARY KEY(RuleID, TaskID, EntityID)" \ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); + + sql = "CREATE INDEX asdf ON Rule_Entity (" \ + "TaskID," \ + "EntityID" \ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); +} + +void SQLiteExplanation::insertEntities(Index* index) { + char* sql = "INSERT INTO Entity (ID,NAME) VALUES (?,?);"; + sqlite3_stmt* stmt = prepare(sql); + for (int id = 0; id < index->getNodeSize(); id++) { + + // bind the value + checkErrorCode(sqlite3_bind_int(stmt, 1, id)); + + std::string curie = *index->getStringOfNodeId(id); + checkErrorCode(sqlite3_bind_text(stmt, 2, curie.c_str(), strlen(curie.c_str()), 0)); + + checkErrorCode(sqlite3_step(stmt)); + checkErrorCode(sqlite3_reset(stmt)); + } + finalize(stmt); +} +void SQLiteExplanation::insertRelations(Index* index) { + char* sql = "INSERT INTO Relation (ID,NAME) VALUES (?,?);"; + sqlite3_stmt* stmt = prepare(sql); + for (int id = 0; id < index->getRelSize(); id++) { + + // bind the value + checkErrorCode(sqlite3_bind_int(stmt, 1, id)); + + std::string relation = *index->getStringOfRelId(id); + checkErrorCode(sqlite3_bind_text(stmt, 2, relation.c_str(), strlen(relation.c_str()), 0)); + + checkErrorCode(sqlite3_step(stmt)); + checkErrorCode(sqlite3_reset(stmt)); + } + finalize(stmt); +} + +void SQLiteExplanation::insertRules(RuleReader* rr, int relsize, ClusteringReader* cr) { + int* adj_begin = rr->getCSR()->getAdjBegin(); + Rule* rules_adj_list = rr->getCSR()->getAdjList(); + + char* sql = "INSERT INTO Rule (ID,HEAD_CLUSTER_ID,TAIL_CLUSTER_ID,DEF,CONFIDENCE,PREDICTED,CORRECTLY_PREDICTED) VALUES (?,?,?,?,?,?,?);"; + sqlite3_stmt* stmt = prepare(sql); + + if (cr == nullptr) { + for (int rel = 0; rel < relsize; rel++) { + int ind_ptr = adj_begin[3 + rel]; + int lenRules = adj_begin[3 + rel + 1] - ind_ptr; + for (int i = 0; i < lenRules; i++) { + + Rule& r = rules_adj_list[ind_ptr + i]; + // bind the value + checkErrorCode(sqlite3_bind_int(stmt, 1, r.getID())); + + checkErrorCode(sqlite3_bind_int(stmt, 2, 0)); + checkErrorCode(sqlite3_bind_int(stmt, 3, 0)); + + std::string rulestr = r.getRulestring(); + checkErrorCode(sqlite3_bind_text(stmt, 4, rulestr.c_str(), strlen(rulestr.c_str()), 0)); + + checkErrorCode(sqlite3_bind_double(stmt, 5, r.getAppliedConfidence())); + checkErrorCode(sqlite3_bind_int(stmt, 6, r.getPredicted())); + checkErrorCode(sqlite3_bind_int(stmt, 7, r.getCorrectlyPredicted())); + + checkErrorCode(sqlite3_step(stmt)); + checkErrorCode(sqlite3_reset(stmt)); + } + } + } + else { + std::unordered_map>>, std::pair>>>> rel2clusters = cr->getRelToClusters(); + for (int rel = 0; rel < relsize; rel++) { + int ind_ptr = adj_begin[3 + rel]; + int lenRules = adj_begin[3 + rel + 1] - ind_ptr; + + std::unordered_map ruleToHeadCluster; + std::unordered_map ruleToTailCluster; + + // Head clusters + std::vector> cluster = rel2clusters[rel].first.second; + for (int cluster_id = 0; cluster_id < cluster.size(); cluster_id++) { + for (auto rule : cluster[cluster_id]) { + Rule& r = rules_adj_list[ind_ptr + rule]; + ruleToHeadCluster[r.getID()] = cluster_id; + } + } + + cluster = rel2clusters[rel].second.second; + for (int cluster_id = 0; cluster_id < cluster.size(); cluster_id++) { + for (auto rule : cluster[cluster_id]) { + Rule& r = rules_adj_list[ind_ptr + rule]; + ruleToTailCluster[r.getID()] = cluster_id; + } + } + + for (int i = 0; i < lenRules; i++) { + + Rule& r = rules_adj_list[ind_ptr + i]; + // bind the value + checkErrorCode(sqlite3_bind_int(stmt, 1, r.getID())); + + checkErrorCode(sqlite3_bind_int(stmt, 2, ruleToHeadCluster[r.getID()])); + + checkErrorCode(sqlite3_bind_int(stmt, 3, ruleToTailCluster[r.getID()])); + + std::string rulestr = r.getRulestring(); + checkErrorCode(sqlite3_bind_text(stmt, 4, rulestr.c_str(), strlen(rulestr.c_str()), 0)); + + checkErrorCode(sqlite3_bind_double(stmt, 5, r.getAppliedConfidence())); + checkErrorCode(sqlite3_bind_int(stmt, 6, r.getPredicted())); + checkErrorCode(sqlite3_bind_int(stmt, 7, r.getCorrectlyPredicted())); + + checkErrorCode(sqlite3_step(stmt)); + checkErrorCode(sqlite3_reset(stmt)); + } + } + } + + finalize(stmt); +} + +void SQLiteExplanation::insertTask(int task_id, bool is_head, int relation_id, int entity_id) { + //std::cout << "Task " << task_id << " " << is_head << " " << relation_id << " " << entity_id << "\n"; + char* sql = "INSERT INTO Task (ID,IsHead,RelationID,EntityId) VALUES (?,?,?,?);"; + sqlite3_stmt* stmt = prepare(sql); + + checkErrorCode(sqlite3_bind_int(stmt, 1, task_id), "insert task bind 1"); + checkErrorCode(sqlite3_bind_int(stmt, 2, (int)is_head), "insert task bind 2"); + checkErrorCode(sqlite3_bind_int(stmt, 3, relation_id), "insert task bind 3"); + checkErrorCode(sqlite3_bind_int(stmt, 4, entity_id), "insert task bind 4"); + + checkErrorCode(sqlite3_step(stmt), "insert task step"); + finalize(stmt); +} + +void SQLiteExplanation::insertPrediction(int task_id, int entity_id, bool hit, double confidence) { + //std::cout << "Prediction " << task_id << " " << entity_id << " " << " " << confidence << "\n"; + char* sql = "INSERT INTO Prediction (TaskID,EntityID,Hit,Confidence) VALUES (?,?,?,?);"; + sqlite3_stmt* stmt = prepare(sql); + + checkErrorCode(sqlite3_bind_int(stmt, 1, task_id), "insert pred bind 1"); + checkErrorCode(sqlite3_bind_int(stmt, 2, entity_id), "insert pred bind 2"); + checkErrorCode(sqlite3_bind_double(stmt, 3, (int)hit), "insert pred bind 3"); + checkErrorCode(sqlite3_bind_double(stmt, 4, confidence), "insert pred bind 4"); + + checkErrorCode(sqlite3_step(stmt), "insert pred step"); + finalize(stmt); +} + +void SQLiteExplanation::insertCluster(int task_id, int entity_id, int cluster_id, double confidence) { + //std::cout << "Cluster " << task_id << " " << entity_id << " " << cluster_id << " " << confidence << "\n"; + char* sql = "INSERT INTO Cluster (PredictionTaskID, PredictionEntityID, ID, Confidence) VALUES (?,?,?,?);"; + sqlite3_stmt* stmt = prepare(sql); + + checkErrorCode(sqlite3_bind_int(stmt, 1, task_id), "insert clus bind 1"); + checkErrorCode(sqlite3_bind_int(stmt, 2, entity_id), "insert clus bind 2"); + checkErrorCode(sqlite3_bind_int(stmt, 3, cluster_id), "insert clus bind 3"); + checkErrorCode(sqlite3_bind_double(stmt, 4, confidence), "insert clus bind 1"); + + checkErrorCode(sqlite3_step(stmt), sql); + finalize(stmt); +} + +void SQLiteExplanation::insertRule_Cluster(int task_id, int entity_id, int cluster_id, int rule_id) { + //std::cout << "Rule_Cluster " << task_id << " " << entity_id << " " << cluster_id << " " << rule_id << "\n"; + char* sql = "INSERT INTO Rule_Cluster (ClusterPredictionTaskID,ClusterPredictionEntityID,ClusterID,RuleID) VALUES (?,?,?,?);"; + sqlite3_stmt* stmt = prepare(sql); + + checkErrorCode(sqlite3_bind_int(stmt, 1, task_id)); + checkErrorCode(sqlite3_bind_int(stmt, 2, entity_id)); + checkErrorCode(sqlite3_bind_int(stmt, 3, cluster_id)); + checkErrorCode(sqlite3_bind_int(stmt, 4, rule_id)); + + checkErrorCode(sqlite3_step(stmt)); + finalize(stmt); +} + +void SQLiteExplanation::insertRule_Entity(int rule_id, int task_id, int entity_id) { + //std::cout << "Rule_Cluster " << task_id << " " << entity_id << " " << cluster_id << " " << rule_id << "\n"; + char* sql = "INSERT OR IGNORE INTO Rule_Entity (RuleID, TaskID, EntityID) VALUES (?,?,?);"; + sqlite3_stmt* stmt = prepare(sql); + + checkErrorCode(sqlite3_bind_int(stmt, 1, rule_id)); + checkErrorCode(sqlite3_bind_int(stmt, 2, task_id)); + checkErrorCode(sqlite3_bind_int(stmt, 3, entity_id)); + + checkErrorCode(sqlite3_step(stmt)); + finalize(stmt); +} + +void SQLiteExplanation::begin_tr() { + checkErrorCode(sqlite3_exec(db, "BEGIN TRANSACTION", NULL, NULL, NULL), "BEGIN TRANSACTION"); +} + +void SQLiteExplanation::begin() { + //sqlite3_mutex_enter(sqlite3_db_mutex(db)); +} + +sqlite3_stmt* SQLiteExplanation::prepare(char* sql) { + sqlite3_stmt* stmt; + checkErrorCode(sqlite3_prepare(db, sql, -1, &stmt, NULL), sql); + return stmt; +} + +void SQLiteExplanation::finalize(sqlite3_stmt* stmt) { + checkErrorCode(sqlite3_finalize(stmt)); +} + +void SQLiteExplanation::commit() { + //sqlite3_mutex_leave(sqlite3_db_mutex(db)); +} + +void SQLiteExplanation::commit_tr() { + checkErrorCode(sqlite3_exec(db, "COMMIT TRANSACTION", NULL, NULL, NULL), "COMMIT TRANSACTION"); +} + +void SQLiteExplanation::checkErrorCode(int code) { + if (code != SQLITE_OK && code != SQLITE_DONE) { + const char* err; + if (this->db) { + err = sqlite3_errmsg(this->db); + } + else { + err = sqlite3_errstr(code); + } + std::cerr << "Error " << code << ": " << err << '\n'; + std::exit(EXIT_FAILURE); + } +} + +void SQLiteExplanation::checkErrorCode(int code, char* sql) { + if (code != SQLITE_OK && code != SQLITE_DONE) { + const char* err; + if (this->db) { + err = sqlite3_errmsg(this->db); + } + else { + err = sqlite3_errstr(code); + } + std::cerr << sql << "\n"; + std::cerr << "Error " << code << ": " << err << '\n'; + std::exit(EXIT_FAILURE); + } +} + +int SQLiteExplanation::getNextTaskID() { + int task_id_; +#pragma omp critical + { + _task_id++; + task_id_ = _task_id; + } + return task_id_; +} diff --git a/src/TesttripleReader.cpp b/src/TesttripleReader.cpp index 57be33c..4d7ad72 100644 --- a/src/TesttripleReader.cpp +++ b/src/TesttripleReader.cpp @@ -1,10 +1,9 @@ #include "TesttripleReader.h" -TesttripleReader::TesttripleReader(std::string filepath, Index * index, TraintripleReader* graph, int is_trial) { +TesttripleReader::TesttripleReader(Index * index, TraintripleReader* graph, int is_trial) { this->index = index; this->graph = graph; this->is_trial = is_trial; - read(filepath); } int ** TesttripleReader::getTesttriples() { @@ -76,7 +75,8 @@ void TesttripleReader::read(std::string filepath) { fclose(test_sample_file); std::cout << "Written test sample to " << Properties::get().PATH_TEST_SAMPLE << std::endl; - TesttripleReader* sample_reader = new TesttripleReader(Properties::get().PATH_TEST_SAMPLE.c_str(), index, graph, 0); + TesttripleReader* sample_reader = new TesttripleReader(index, graph, 0); + sample_reader->read(Properties::get().PATH_TEST_SAMPLE); csr = sample_reader->getCSR(); delete sample_reader; } @@ -103,3 +103,42 @@ void TesttripleReader::read(std::string filepath) { exit(-1); } } + +void TesttripleReader::read(std::vector> & triples) { + std::vector> testtriplesVector; + size_t numTriplesRead = 0; + for (auto & [head, rel, tail] : triples) + { + try { + int* headId = index->getIdOfNodestring(head); + int* relId = index->getIdOfRelationstring(rel); + int* tailId = index->getIdOfNodestring(tail); + std::vector testtriple; + testtriple.push_back(headId); + testtriple.push_back(relId); + testtriple.push_back(tailId); + testtriplesVector.push_back(testtriple); + + relHeadToTails[*relId][*headId].insert(*tailId); + relTailToHeads[*relId][*tailId].insert(*headId); + + numTriplesRead++; + } + catch (std::runtime_error& e) {} + } + csr = new CSR(index->getRelSize(), index->getNodeSize(), relHeadToTails, relTailToHeads); + + testtripleSize = new int; + *testtripleSize = testtriplesVector.size(); + // Convert to pointers of pointers + int * testtriplesstore; + testtriples = new int*[*testtripleSize]; + testtriplesstore = new int[(*testtripleSize) * 3]; + + for (int i = 0; i < (*testtripleSize); i++) { + testtriplesstore[i * 3] = *(testtriplesVector[i][0]); + testtriplesstore[i * 3 + 1] = *(testtriplesVector[i][1]); + testtriplesstore[i * 3 + 2] = *(testtriplesVector[i][2]); + testtriples[i] = &testtriplesstore[i * 3]; + } +}