From e251edc60edbe1c32e3d749857a1e7fba5943256 Mon Sep 17 00:00:00 2001 From: Anton Belyy Date: Fri, 22 Apr 2022 12:46:51 -0400 Subject: [PATCH 1/2] Implement Python bindings for the applymax action --- .github/workflows/build_cmake.yml | 12 +- .gitignore | 1 + CMakeLists.txt | 29 +- CMakeSettings.json | 28 ++ Main.cpp | 128 +++---- SAFRAN.vcxproj | 2 + SAFRAN.vcxproj.filters | 6 + include/Explanation.h | 29 ++ include/InMemoryExplanation.h | 34 ++ include/Properties.hpp | 87 ++++- include/Rule.h | 22 +- include/RuleApplication.h | 9 +- include/RuleGraph.h | 2 + include/RuleReader.h | 2 +- include/SQLiteExplanation.h | 39 ++ include/TesttripleReader.h | 7 +- include/Util.hpp | 12 +- python_bindings/.gitignore | 6 + python_bindings/safran.py | 16 + python_bindings/safran_wrapper.cpp | 95 +++++ python_bindings/safran_wrapper.h | 32 ++ python_bindings/safran_wrapper.i | 20 ++ python_bindings/setup.py | 51 +++ src/InMemoryExplanation.cpp | 96 +++++ src/JaccardEngine.cpp | 64 +--- src/Rule.cpp | 53 +-- src/RuleApplication.cpp | 554 ++++++++++++----------------- src/RuleGraph.cpp | 38 +- src/RuleReader.cpp | 7 +- src/SQLiteExplanation.cpp | 332 +++++++++++++++++ src/TesttripleReader.cpp | 53 ++- 31 files changed, 1313 insertions(+), 553 deletions(-) create mode 100644 CMakeSettings.json create mode 100644 include/Explanation.h create mode 100644 include/InMemoryExplanation.h create mode 100644 include/SQLiteExplanation.h create mode 100644 python_bindings/.gitignore create mode 100644 python_bindings/safran.py create mode 100644 python_bindings/safran_wrapper.cpp create mode 100644 python_bindings/safran_wrapper.h create mode 100644 python_bindings/safran_wrapper.i create mode 100644 python_bindings/setup.py create mode 100644 src/InMemoryExplanation.cpp create mode 100644 src/SQLiteExplanation.cpp diff --git a/.github/workflows/build_cmake.yml b/.github/workflows/build_cmake.yml index ed14d53..6bc9886 100644 --- a/.github/workflows/build_cmake.yml +++ b/.github/workflows/build_cmake.yml @@ -16,12 +16,12 @@ jobs: fail-fast: false matrix: config: - - { - name: "Windows Latest MSVC", artifact: "Windows-MSVC.7z", - os: windows-latest, - cc: "cl", cxx: "cl", - environment_script: "C:/Program Files (x86)/Microsoft Visual Studio/2019/Enterprise/VC/Auxiliary/Build/vcvars64.bat" - } +# - { +# name: "Windows Latest MSVC", artifact: "Windows-MSVC.7z", +# os: windows-latest, +# cc: "cl", cxx: "cl", +# environment_script: "C:/Program Files (x86)/Microsoft Visual Studio/2019/Enterprise/VC/Auxiliary/Build/vcvars64.bat" +# } - { name: "Windows Latest MinGW", artifact: "Windows-MinGW.7z", os: windows-latest, diff --git a/.gitignore b/.gitignore index 0250e6e..7ef241b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.idea/ build/ boost_1_76_0/ include/boost diff --git a/CMakeLists.txt b/CMakeLists.txt index 9fe7ba4..88efba1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,42 +1,52 @@ cmake_minimum_required(VERSION 3.9.0) -project(safran CXX) + +include(FetchContent) + +project(safran CXX C) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) -set(CMAKE_BUILD_TYPE Release) -set(CMAKE_CXX_FLAGS_RELEASE "-O3") +if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + set(CMAKE_CXX_FLAGS_RELEASE "-O3") +endif() + +FetchContent_Declare(sqlite3 URL "http://sqlite.org/2021/sqlite-amalgamation-3340100.zip") +set(FETCHCONTENT_QUIET OFF) +FetchContent_MakeAvailable(sqlite3) + +message(STATUS "${sqlite3_SOURCE_DIR}") + include_directories( ${PROJECT_SOURCE_DIR}/boost_1_76_0 ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/src + ${sqlite3_SOURCE_DIR} ) file(GLOB all_src "${PROJECT_SOURCE_DIR}/include/*.h" "${PROJECT_SOURCE_DIR}/include/*.hpp" "${PROJECT_SOURCE_DIR}/src/*.cpp" + "${sqlite3_SOURCE_DIR}/*.c" + "${sqlite3_SOURCE_DIR}/*.h" ) - add_library(lsafran ${all_src}) find_package(OpenMP REQUIRED) -target_link_libraries(lsafran PUBLIC OpenMP::OpenMP_CXX) +target_link_libraries(lsafran PUBLIC OpenMP::OpenMP_CXX ${CMAKE_DL_LIBS}) add_executable(safran ${PROJECT_SOURCE_DIR}/Main.cpp) -target_link_libraries(safran lsafran) +target_link_libraries(safran PUBLIC lsafran) install(TARGETS safran) -Include(FetchContent) - FetchContent_Declare( Catch2 GIT_REPOSITORY https://github.com/catchorg/Catch2.git GIT_TAG v3.0.0-preview3 ) - FetchContent_MakeAvailable(Catch2) add_executable(tests tests/test.cpp) @@ -47,4 +57,3 @@ include(CTest) include(Catch) catch_discover_tests(tests) - diff --git a/CMakeSettings.json b/CMakeSettings.json new file mode 100644 index 0000000..5268670 --- /dev/null +++ b/CMakeSettings.json @@ -0,0 +1,28 @@ +{ + "configurations": [ + { + "name": "x64-Debug (Standard)", + "generator": "Ninja", + "configurationType": "Debug", + "inheritEnvironments": [ "msvc_x64_x64" ], + "buildRoot": "${projectDir}\\out\\build\\${name}", + "installRoot": "${projectDir}\\out\\install\\${name}", + "cmakeCommandArgs": "", + "buildCommandArgs": "", + "ctestCommandArgs": "", + "variables": [] + }, + { + "name": "x64-Release", + "generator": "Ninja", + "configurationType": "Release", + "buildRoot": "${projectDir}\\out\\build\\${name}", + "installRoot": "${projectDir}\\out\\install\\${name}", + "cmakeCommandArgs": "", + "buildCommandArgs": "", + "ctestCommandArgs": "", + "inheritEnvironments": [ "msvc_x64_x64" ], + "variables": [] + } + ] +} \ No newline at end of file diff --git a/Main.cpp b/Main.cpp index 5c3679b..0210c02 100644 --- a/Main.cpp +++ b/Main.cpp @@ -9,6 +9,8 @@ #include "ClusteringReader.h" #include "RuleApplication.h" #include "JaccardEngine.h" +#include "SQLiteExplanation.h" +#include "Util.hpp" #include #include @@ -22,7 +24,8 @@ int main(int argc, char** argv) std::cout << "Wrong number of startup arguments, please make sure that arguments are in form of {action} {path to properties}" << std::endl; exit(-1); } - Properties::get().ACTION = argv[1]; + Properties::get().setAction(argv[1]); + Action action = Properties::get().ACTION; bool success = Properties::get().read(argv[2]); if (!success) { std::cout << "No properties file found, falling back to default\n"; @@ -36,6 +39,7 @@ int main(int argc, char** argv) auto start = std::chrono::high_resolution_clock::now(); Index* index = new Index(); + // Adding the reflexive token to the index index->addNode(Properties::get().REFLEXIV_TOKEN); index->addNode(Properties::get().UNK_TOKEN); @@ -49,17 +53,9 @@ int main(int argc, char** argv) Properties::get().REL_SIZE = index->getRelSize(); - - //"C:\\Users\\Simon\\Desktop\\data\\alpha-50" - std::cout << "Reading rules..." << std::endl; - RuleReader* rr = new RuleReader(Properties::get().PATH_RULES, index, graph); - finish = std::chrono::high_resolution_clock::now(); - milliseconds = std::chrono::duration_cast(finish - start); - std::cout << "Rules read in " << milliseconds.count() << " ms\n"; - start = finish; - std::cout << "Reading testset..." << std::endl; - TesttripleReader* ttr = new TesttripleReader(Properties::get().PATH_TEST, index, graph, Properties::get().TRIAL); + TesttripleReader* ttr = new TesttripleReader(index, graph, Properties::get().TRIAL); + ttr->read(Properties::get().PATH_TEST); finish = std::chrono::high_resolution_clock::now(); milliseconds = std::chrono::duration_cast(finish - start); std::cout << "Testset read in " << milliseconds.count() << " ms\n"; @@ -72,79 +68,65 @@ int main(int argc, char** argv) std::cout << "Validationset read in " << milliseconds.count() << " ms\n"; start = finish; - /* - Rule* rules_adj_list = rr->getCSR()->getAdjList(); - int* adj_begin = rr->getCSR()->getAdjBegin(); - std::string r("/award/award_winning_work/awards_won./award/award_honor/award_winner"); - int relation = *index->getIdOfRelationstring(r); - - std::string t("/m/0g69lg"); - int tail = *index->getIdOfNodestring(t); - - std::string h("/m/015ppk"); - int head = *index->getIdOfNodestring(h); - - auto tails = ttr->getRelTailToHeads()[relation][tail]; - - int reflexiv_token = 0; - - int ind_ptr = adj_begin[3 + relation]; - int lenRules = adj_begin[3 + relation + 1] - ind_ptr; - RuleGraph* rulegraph = new RuleGraph(index->getNodeSize(), graph, ttr, vtr); - for (int i = 0; i < lenRules; i++) { - std::vector headresults_vec; - Rule& r = rules_adj_list[ind_ptr + i]; - if (r.getRulestring().compare("/award/award_winning_work/awards_won./award/award_honor/award_winner(/m/015ppk,Y) <= /award/award_winner/awards_won./award/award_honor/award_winner(Y,/m/05cqhl)") == 0) { - if (r.is_ac2() && r.getRuletype() == Ruletype::YRule && *r.getHeadconstant() != tail) { - if (rulegraph->existsAcyclic(&tail, r, true)) { - headresults_vec.push_back(*r.getHeadconstant()); - } - } - } - std::vector filtered_headresults_vec; - for (auto a : headresults_vec) { - if (a == tail) continue; - if (a == reflexiv_token) { - a = tail; - } - if (a == head || tails.find(a) == tails.end()) { - filtered_headresults_vec.push_back(a); - } - } - for (auto a : filtered_headresults_vec) { - std::cout << a << " " << *index->getStringOfNodeId(a) << " "; - } + std::cout << "Reading rules..." << std::endl; + RuleReader* rr = new RuleReader(Properties::get().PATH_RULES, index, graph); + + // READ Clustering if present + ClusteringReader* cr = nullptr; + if (action == applynrnoisy) { + cr = new ClusteringReader(Properties::get().PATH_CLUSTER, rr->getCSR(), index, graph); } + finish = std::chrono::high_resolution_clock::now(); + milliseconds = std::chrono::duration_cast(finish - start); + std::cout << "Rules read in " << milliseconds.count() << " ms\n"; + start = finish; - exit(-1); - */ + - std::cout << "Applying rules..." << std::endl; + // PREPARE Explanation DB if wanted + Explanation* explanation = nullptr; + if (Properties::get().EXPLAIN == 1) { + std::cout << "Writing entities, relations and rules to db file..." << std::endl; + explanation = new SQLiteExplanation(Properties::get().PATH_EXPLAIN, true); + explanation->begin_tr(); + explanation->insertEntities(index); + explanation->insertRelations(index); + explanation->insertRules(rr, index->getRelSize(), cr); + finish = std::chrono::high_resolution_clock::now(); + milliseconds = std::chrono::duration_cast(finish - start); + std::cout << "Written in " << milliseconds.count() << " ms\n"; + start = finish; + } - if (Properties::get().ACTION.compare("learnnrnoisy") == 0) { + std::cout << "Applying rules..." << std::endl; + if (action == learnnrnoisy) { ClusteringEngine* ce = new ClusteringEngine(index, graph, ttr, vtr, rr); ce->learn(); } - else if (Properties::get().ACTION.compare("applynrnoisy") == 0) { - ClusteringReader* cr = new ClusteringReader(Properties::get().PATH_CLUSTER, rr->getCSR(), index, graph); - RuleApplication* ca = new RuleApplication(index, graph, ttr, vtr, rr); - ca->apply_nr_noisy(cr->getRelToClusters()); - } - else if (Properties::get().ACTION.compare("applymax") == 0) { - RuleApplication* ca = new RuleApplication(index, graph, ttr, vtr, rr); - ca->apply_only_max(); - } - else if (Properties::get().ACTION.compare("applynoisy") == 0) { - RuleApplication* ca = new RuleApplication(index, graph, ttr, vtr, rr); - ca->apply_only_noisy(); - } - else if (Properties::get().ACTION.compare("calcjacc") == 0) { + else if (action == calcjacc) { JaccardEngine* jacccalc = new JaccardEngine(index, graph, vtr, rr); jacccalc->calculate_jaccard(); } else { - std::cout << "ACTION not found" << "\n"; - exit(-1); + if (action == applynrnoisy) { + RuleApplication* ca = new RuleApplication(index, graph, ttr, vtr, rr, explanation); + ca->apply_nr_noisy(cr->getRelToClusters()); + } + else if (action == applymax) { + RuleApplication* ca = new RuleApplication(index, graph, ttr, vtr, rr, explanation); + ca->apply_only_max(); + } + else if (action == applynoisy) { + RuleApplication* ca = new RuleApplication(index, graph, ttr, vtr, rr, explanation); + ca->apply_only_noisy(); + } + else { + std::cout << "ACTION not found" << "\n"; + exit(-1); + } + if (Properties::get().EXPLAIN == 1) { + explanation->commit_tr(); + } } finish = std::chrono::high_resolution_clock::now(); milliseconds = std::chrono::duration_cast(finish - start); diff --git a/SAFRAN.vcxproj b/SAFRAN.vcxproj index 1063b13..0cac131 100644 --- a/SAFRAN.vcxproj +++ b/SAFRAN.vcxproj @@ -19,6 +19,7 @@ + @@ -43,6 +44,7 @@ + diff --git a/SAFRAN.vcxproj.filters b/SAFRAN.vcxproj.filters index 2d0620f..ec57e6a 100644 --- a/SAFRAN.vcxproj.filters +++ b/SAFRAN.vcxproj.filters @@ -81,6 +81,9 @@ Header Files + + Header Files + @@ -140,5 +143,8 @@ Source Files + + Source Files + \ No newline at end of file diff --git a/include/Explanation.h b/include/Explanation.h new file mode 100644 index 0000000..627b91f --- /dev/null +++ b/include/Explanation.h @@ -0,0 +1,29 @@ +#ifndef EXPL_H +#define EXPL_H + +#include "Index.h" +#include "Rule.h" +#include "RuleReader.h" +#include "ClusteringReader.h" + +class Explanation { +public: + Explanation() {} + virtual ~Explanation() {} + + virtual void begin() = 0; + virtual void commit() = 0; + virtual void begin_tr() = 0; + virtual void commit_tr() = 0; + virtual void insertEntities(Index* index) = 0; + virtual void insertRelations(Index* index) = 0; + virtual void insertRules(RuleReader* rr, int relsize, ClusteringReader* cr) = 0; + + virtual void insertTask(int prediction_id, bool is_head, int relation_id, int entity_id) = 0; + virtual void insertPrediction(int task_id, int entity_id, bool hit, double confidence) = 0; + virtual void insertRule_Entity(int rule_id, int task_id, int entity_id) = 0; + + virtual int getNextTaskID() = 0; +}; + +#endif //EXPL_H \ No newline at end of file diff --git a/include/InMemoryExplanation.h b/include/InMemoryExplanation.h new file mode 100644 index 0000000..525a199 --- /dev/null +++ b/include/InMemoryExplanation.h @@ -0,0 +1,34 @@ +#ifndef INMEMORY_EXPL_H +#define INMEMORY_EXPL_H + +#include "Explanation.h" + +class InMemoryExplanation: public Explanation { +public: + InMemoryExplanation(); + ~InMemoryExplanation(); + + void begin(); + void commit(); + void begin_tr(); + void commit_tr(); + void insertEntities(Index* index); + void insertRelations(Index* index); + void insertRules(RuleReader* rr, int relsize, ClusteringReader* cr); + + void insertTask(int task_id, bool is_head, int relation_id, int entity_id); + void insertPrediction(int task_id, int entity_id, bool hit, double confidence); + void insertRule_Entity(int rule_id, int task_id, int entity_id); + + int getNextTaskID(); + + std::unordered_map>> tripleBestRules; + +private: + int _task_id = 0; + std::unordered_map ruleConfidences; + std::unordered_map> tasks; + std::unordered_map> taskEntityBestRules; +}; + +#endif //INMEMORY_EXPL_H \ No newline at end of file diff --git a/include/Properties.hpp b/include/Properties.hpp index 008d14a..845c007 100644 --- a/include/Properties.hpp +++ b/include/Properties.hpp @@ -12,11 +12,13 @@ #include "Util.hpp" +enum Action { learnnrnoisy, calcjacc, applynoisy, applymax, applynrnoisy }; + class Properties { public: // ACTION - std::string ACTION = "applymax"; + Action ACTION = applymax; // PATHS std::string PATH_TRAINING = "train.txt"; @@ -37,6 +39,11 @@ class Properties { std::string UNK_TOKEN = "UNKNOWN"; int UNSEEN_NEGATIVE_EXAMPLES = 5; int ONLY_UNCONNECTED = 0; + + // APPLY + int EXPLAIN = 0; + std::string PATH_EXPLAIN = "explanation.db"; + int ONLY_XY = 0; int VERBOSE = 1; int PREDICT_UNKNOWN = 0; @@ -149,9 +156,16 @@ class Properties { else if (strKey.compare("SEED") == 0) { SEED = std::stoi(strVal); } + else if (strKey.compare("EXPLAIN") == 0) { + EXPLAIN = std::stoi(strVal); + } + else if (strKey.compare("PATH_EXPLAIN") == 0) { + PATH_EXPLAIN = strVal; + } else if (strKey.compare("ONLY_XY") == 0) { ONLY_XY = std::stoi(strVal); } + else if (strKey.compare("PREDICT_UNKNOWN") == 0) { PREDICT_UNKNOWN = std::stoi(strVal); } @@ -176,21 +190,22 @@ class Properties { std::string toString() { std::ostringstream string_rep; - string_rep << "ACTION = " << ACTION << std::endl; + string_rep << "ACTION = " << getAction(ACTION) << std::endl; // PATHS string_rep << "PATH_TRAINING = " << PATH_TRAINING << std::endl; string_rep << "PATH_TEST = " << PATH_TEST << std::endl; string_rep << "PATH_VALID = " << PATH_VALID << std::endl; string_rep << "PATH_RULES = " << PATH_RULES << std::endl; - if (ACTION.compare("learnnrnoisy") == 0 || ACTION.compare("calcjacc") == 0) { + if (ACTION == learnnrnoisy || ACTION == calcjacc) { string_rep << "PATH_JACCARD = " << PATH_JACCARD << std::endl; } - if (ACTION.compare("learnnrnoisy") == 0 || ACTION.compare("applynrnoisy") == 0) { + if (ACTION == learnnrnoisy || ACTION == applynrnoisy) { string_rep << "PATH_CLUSTER = " << PATH_CLUSTER << std::endl; } - if (ACTION.compare("applymax") == 0 || ACTION.compare("applynoisy") == 0 || ACTION.compare("applynrnoisy") == 0) { + if (ACTION == applymax || ACTION == applynoisy || ACTION == applynrnoisy) { string_rep << "PATH_OUTPUT = " << PATH_OUTPUT << std::endl; + string_rep << "EXPLAIN = " << EXPLAIN << std::endl; } // GENERAL PROPS @@ -199,32 +214,28 @@ class Properties { string_rep << "UNSEEN_NEGATIVE_EXAMPLES = " << UNSEEN_NEGATIVE_EXAMPLES << std::endl; string_rep << "TOP_K_OUTPUT = " << TOP_K_OUTPUT << std::endl; string_rep << "REFLEXIV_TOKEN = " << REFLEXIV_TOKEN << std::endl; + string_rep << "ONLY_UNCONNECTED = " << ONLY_UNCONNECTED << std::endl; - // SPECIFIC PROPS - if (ACTION.compare("applymax") == 0 || ACTION.compare("applynoisy") == 0 || ACTION.compare("applynrnoisy") == 0 || ACTION.compare("learnnrnoisy") == 0) { - string_rep << "ONLY_UNCONNECTED = " << ONLY_UNCONNECTED << std::endl; - } - - if (ACTION.compare("learnnrnoisy") == 0) { + if (ACTION == learnnrnoisy) { string_rep << "STRATEGY = " << STRATEGY << std::endl; if (STRATEGY.compare("random") == 0) { string_rep << "ITERATIONS = " << ITERATIONS << std::endl; } } - if (ACTION.compare("learnnrnoisy") == 0 || ACTION.compare("calcjacc") == 0) { + if (ACTION == learnnrnoisy || ACTION == calcjacc) { string_rep << "RESOLUTION = " << RESOLUTION << std::endl; string_rep << "SEED = " << SEED << std::endl; } - if (ACTION.compare("learnnrnoisy") == 0) { + if (ACTION == learnnrnoisy) { string_rep << "BUFFER_SIZE = " << BUFFER_SIZE << std::endl; } - if (ACTION.compare("calcjacc") == 0) { + if (ACTION == calcjacc) { string_rep << "CLUSTER_SET = " << CLUSTER_SET << std::endl; } - if (ACTION.compare("applymax") == 0 || ACTION.compare("applynoisy") == 0 || ACTION.compare("applynrnoisy") == 0) { + if (ACTION == applymax || ACTION == applynoisy || ACTION == applynrnoisy) { string_rep << "TRIAL = " << TRIAL << std::endl; if (TRIAL == 1) { string_rep << "TRIAL_SIZE = " << TRIAL_SIZE << std::endl; @@ -239,6 +250,49 @@ class Properties { return string_rep.str().c_str(); } + void setAction(std::string action) { + if (action.compare("learnnrnoisy") == 0) { + this->ACTION = learnnrnoisy; + } + else if (action.compare("calcjacc") == 0) { + this->ACTION = calcjacc; + } + else if (action.compare("applynoisy") == 0) { + this->ACTION = applynoisy; + } + else if (action.compare("applymax") == 0) { + this->ACTION = applymax; + } + else if (action.compare("applynrnoisy") == 0) { + this->ACTION = applynrnoisy; + } + else { + std::cout << "ACTION" << action << " not found" << "\n"; + exit(-1); + } + } + + std::string getAction(int action) { + if (action == learnnrnoisy) { + return "learnnrnoisy"; + } + else if (action == calcjacc) { + return "calcjacc"; + } + else if (action == applynoisy) { + return "applynoisy"; + } + else if (action == applymax) { + return "applymax"; + } + else if (action == applynrnoisy) { + return "applynrnoisy"; + } + else { + std::cout << "ACTION" << action << " not found" << "\n"; + exit(-1); + } + } private: Properties() {}; Properties(const Properties&); @@ -261,6 +315,9 @@ class Properties { { return rtrim(ltrim(s)); } + + + }; #endif //PROPERTIES_H diff --git a/include/Rule.h b/include/Rule.h index 84620b5..052bf7c 100644 --- a/include/Rule.h +++ b/include/Rule.h @@ -7,7 +7,7 @@ #include #include #include - +#include #include "Properties.hpp" enum Ruletype { XYRule, XRule, YRule, None }; @@ -18,6 +18,7 @@ class Rule Rule() {} Rule(int no1, int no2, double confidence); + void setID(int ID); void print(); //Setter @@ -29,8 +30,11 @@ class Rule void setHeadconstant(int* constant); void setBodyconstantId(int* id); void setRulestring(std::string rule); + void setBuffer(std::vector buffer); bool isBuffered(); + std::vector& getBuffer(); + void removeBuffer(); void setHeadBuffer(int head, std::vector buffer); bool isHeadBuffered(int head); @@ -46,10 +50,8 @@ class Rule bool is_ac1(); bool is_ac2(); - void add_head_exception(int val); - void add_tail_exception(int val); - //Getter + int& getID(); Ruletype getRuletype(); int& getRulelength(); int* getHeadrelation(); @@ -61,19 +63,17 @@ class Rule long long getPredicted(); double getAppliedConfidence(); std::string getRulestring(); - std::vector& getBuffer(); long long get_body_hash(); - void removeBuffer(); void compute_body_hash(); bool is_body_equal(Rule other); Rule& operator=(Rule* other); - std::unordered_set head_exceptions; - std::unordered_set tail_exceptions; protected: private: + //ID (Needed for explaination) + int ID; //Type of rule XYRule, XRule, YRule Ruletype type; //Length of body @@ -96,11 +96,11 @@ class Rule std::string rulestring; - std::vector buffer; + std::optional> buffer{}; bool buffered = false; - std::unordered_map> tailBuffer; - std::unordered_map> headBuffer; + std::optional>> tailBuffer{}; + std::optional>> headBuffer{}; }; diff --git a/include/RuleApplication.h b/include/RuleApplication.h index e969d21..5c17abe 100644 --- a/include/RuleApplication.h +++ b/include/RuleApplication.h @@ -4,6 +4,7 @@ #include #include #include +#include #include "Index.h" #include "TraintripleReader.h" #include "TesttripleReader.h" @@ -15,6 +16,7 @@ #include "Util.hpp" #include "ScoreTree.h" #include "boost/multiprecision/cpp_bin_float.hpp" +#include "Explanation.h" #include #include @@ -25,10 +27,13 @@ typedef boost::multiprecision::cpp_bin_float_50 float50; class RuleApplication { public: - RuleApplication(Index* index, TraintripleReader* graph, TesttripleReader* ttr, ValidationtripleReader* vtr, RuleReader* rr); + RuleApplication(Index* index, TraintripleReader* graph, TesttripleReader* ttr, ValidationtripleReader* vtr, RuleReader* rr, Explanation* exp); + RuleApplication(Index* index, TraintripleReader* graph, ValidationtripleReader* vtr, RuleReader* rr, Explanation* exp); void apply_nr_noisy(std::unordered_map>>, std::pair>>>> rel2clusters); void apply_only_noisy(); void apply_only_max(); + void updateTTR(TesttripleReader* ttr); + std::vector> apply_only_max_in_memory(size_t K); private: Index* index; @@ -37,6 +42,8 @@ class RuleApplication ValidationtripleReader* vtr; RuleReader* rr; + Explanation* exp; + FILE* pFile; RuleGraph* rulegraph; int reflexiv_token; diff --git a/include/RuleGraph.h b/include/RuleGraph.h index 79e0ceb..3e54578 100644 --- a/include/RuleGraph.h +++ b/include/RuleGraph.h @@ -16,7 +16,9 @@ class RuleGraph { public: RuleGraph(int nodesize, TraintripleReader* graph); + RuleGraph(int nodesize, TraintripleReader* graph, ValidationtripleReader* vtr); RuleGraph(int nodesize, TraintripleReader* graph, TesttripleReader* ttr, ValidationtripleReader* vtr); + void updateTTR(TesttripleReader *ttr); void searchDFSSingleStart_filt(bool headNotTail, int filt_v, int v, Rule& r, bool bwd, std::vector& solution, bool filtValidNotTest, bool filtExceptions); void searchDFSMultiStart_filt(bool headNotTail, int filt_v, Rule& r, bool bwd, std::vector& solution, bool filtValidNotTest, bool filtExceptions); bool existsAcyclic(int* valId, Rule& rule, bool filtValidNotTest); diff --git a/include/RuleReader.h b/include/RuleReader.h index 99ee747..55f4d0e 100644 --- a/include/RuleReader.h +++ b/include/RuleReader.h @@ -28,7 +28,7 @@ class RuleReader CSR* csr; void read(std::string filepath); - Rule* parseRule(std::vector rule); + Rule* parseRule(std::vector rule, int currID); bool parseXtoY(Ruletype type, std::string& head, std::string& tail); std::string getRelation(std::string atom, std::string previous, int* relation); std::pair getHeadTail(std::string& atom); diff --git a/include/SQLiteExplanation.h b/include/SQLiteExplanation.h new file mode 100644 index 0000000..a88aff0 --- /dev/null +++ b/include/SQLiteExplanation.h @@ -0,0 +1,39 @@ +#ifndef SQLITE_EXPL_H +#define SQLITE_EXPL_H + +#include +#include "Explanation.h" + +class SQLiteExplanation: public Explanation { +public: + SQLiteExplanation(std::string dbName, bool init = false); + ~SQLiteExplanation(); + + void begin(); + void commit(); + void begin_tr(); + void commit_tr(); + void insertEntities(Index* index); + void insertRelations(Index* index); + void insertRules(RuleReader* rr, int relsize, ClusteringReader* cr); + + void insertTask(int task_id, bool is_head, int relation_id, int entity_id); + void insertPrediction(int task_id, int entity_id, bool hit, double confidence); + void insertRule_Entity(int rule_id, int task_id, int entity_id); + + // OLD + void insertCluster(int prediction_id, int entity_id, int cluster_id, double confidence); + void insertRule_Cluster(int prediction_id, int entity_id, int cluster_id, int rule_id); + + int getNextTaskID(); +private: + sqlite3* db; + int _task_id = 0; + void initDb(); + void checkErrorCode(int code); + void checkErrorCode(int code, char* sql); + sqlite3_stmt* prepare(char* sql); + void finalize(sqlite3_stmt* stmt); +}; + +#endif //SQLITE_EXPL_H \ No newline at end of file diff --git a/include/TesttripleReader.h b/include/TesttripleReader.h index 8610031..a0361df 100644 --- a/include/TesttripleReader.h +++ b/include/TesttripleReader.h @@ -21,7 +21,7 @@ class TesttripleReader { public: - TesttripleReader(std::string filepath, Index* index, TraintripleReader* graph, int is_trial); + TesttripleReader(Index* index, TraintripleReader* graph, int is_trial); int** getTesttriples(); int* getTesttriplesSize(); @@ -29,6 +29,9 @@ class TesttripleReader RelNodeToNodes& getRelHeadToTails(); RelNodeToNodes& getRelTailToHeads(); + void read(std::string filepath); + void read(std::vector> & triples); + protected: private: @@ -41,8 +44,6 @@ class TesttripleReader RelNodeToNodes relHeadToTails; RelNodeToNodes relTailToHeads; - - void read(std::string filepath); }; #endif // TESTTRIPLEREADER_H diff --git a/include/Util.hpp b/include/Util.hpp index a088f5d..09b9c56 100644 --- a/include/Util.hpp +++ b/include/Util.hpp @@ -17,6 +17,7 @@ #include #include #include +#include namespace util{ inline std::vector split(const std::string& str, char delim = ' ') { @@ -152,7 +153,7 @@ namespace util{ // That is faster than reading them one-by-one using the std::istream. // Code that uses streambuf this way must be guarded by a sentry object. // The sentry object performs various tasks, - // such as thread synchronization and updating the stream state. + // such as thread synchronization && updating the stream state. std::istream::sentry se(is, true); std::streambuf* sb = is.rdbuf(); @@ -195,6 +196,15 @@ namespace util{ return 0; } } + + inline std::string getDbName() { + time_t now = time(0); + struct tm tstruct; + char buf[80]; + tstruct = *localtime(&now); + strftime(buf, sizeof(buf), "%Y-%m-%d-%H-%M-%S.db", &tstruct); + return buf; + } } #endif // UTIL_H diff --git a/python_bindings/.gitignore b/python_bindings/.gitignore new file mode 100644 index 0000000..d6b83bb --- /dev/null +++ b/python_bindings/.gitignore @@ -0,0 +1,6 @@ +build/ +dist/ +*.egg-info/ +__pycache__ +safran_wrapper.py +safran_wrapper_wrap.cpp diff --git a/python_bindings/safran.py b/python_bindings/safran.py new file mode 100644 index 0000000..9848f7f --- /dev/null +++ b/python_bindings/safran.py @@ -0,0 +1,16 @@ +import itertools +import collections +from typing import List, Dict, Tuple +from safran_wrapper import pysafran, query_output_t, query_triples_t + +class SAFRAN(pysafran): + def __init__(self, train_path: str, rule_path: str, n_jobs: int = 1): + pysafran.__init__(self, train_path, rule_path, n_jobs) + + def query(self, triples: List[List[str]], k: int = 100, action: str = 'applymax') -> Dict[Tuple[str, str], List[Tuple[str, float, int]]]: + flat_triples = list(itertools.chain.from_iterable(triples)) + pred_vals = query_output_t(pysafran.query(self, action, k, query_triples_t(flat_triples))) + out = collections.defaultdict(list) + for head, (pred, (tail, (val, rule_id))) in pred_vals: + out[head, pred].append((tail, val, rule_id)) + return out \ No newline at end of file diff --git a/python_bindings/safran_wrapper.cpp b/python_bindings/safran_wrapper.cpp new file mode 100644 index 0000000..5c1641f --- /dev/null +++ b/python_bindings/safran_wrapper.cpp @@ -0,0 +1,95 @@ +#include "safran_wrapper.h" + +pysafran::pysafran(std::string train_path, std::string rule_path, int num_threads) { + // TODO (Anton): Perhaps instead initialize these globally, or make those a part of pysafran instance. + Properties::get().VERBOSE = 0; + Properties::get().PREDICT_UNKNOWN = 1; + if (num_threads != -1) { + omp_set_num_threads(num_threads); + } + + this->index = new Index(); + index->addNode(Properties::get().REFLEXIV_TOKEN); + index->addNode(Properties::get().UNK_TOKEN); + + this->graph = new TraintripleReader(train_path, index); + Properties::get().REL_SIZE = index->getRelSize(); + + this->rr = new RuleReader(rule_path, index, graph); + this->vtr = new ValidationtripleReader("/dev/null", index, graph); + + this->exp = new InMemoryExplanation(); + this->exp->insertRules(rr, index->getRelSize(), nullptr); + + this->ra = new RuleApplication(index, graph, vtr, rr, exp); +} + +pysafran::~pysafran() { + delete this->ra; + delete this->vtr; + delete this->rr; + delete this->graph; + delete this->index; + delete this->exp; +} + +std::vector>>>> pysafran::query(const std::string & action, size_t k, + const std::vector & flat_triples) const { + // "Read" triples + if (Properties::get().VERBOSE == 1) { + std::cout << "read start" << std::endl; + } + TesttripleReader ttr(index, graph, 0); + std::vector> triples; + size_t i = 0; + while (i < flat_triples.size()) { + triples.emplace_back(flat_triples[i], flat_triples[i + 1], flat_triples[i + 2]); + i += 3; + } + ttr.read(triples); + if (Properties::get().VERBOSE == 1) { + std::cout << "read finish" << std::endl; + } + + if (Properties::get().VERBOSE == 1) { + std::cout << "apply start" << std::endl; + } + // Apply specific rule application + std::vector> outAction; + if (action == "applymax") { + if (Properties::get().VERBOSE == 1) { + std::cout << "ra start" << std::endl; + } + ra->updateTTR(&ttr); + if (Properties::get().VERBOSE == 1) { + std::cout << "ra end" << std::endl; + } + outAction = ra->apply_only_max_in_memory(k); + } + if (Properties::get().VERBOSE == 1) { + std::cout << "apply finish" << std::endl; + } + + // Transform node ids to node names and only retain top-k + std::vector>>>> out; + if (Properties::get().VERBOSE == 1) { + std::cout << "transform start" << std::endl; + } + for (auto & p : outAction) { + int head_id = std::get<0>(p), relation_id = std::get<1>(p), tail_id = std::get<2>(p); + float val = (float)std::get<3>(p); + std::string headStr = *this->index->getStringOfNodeId(head_id); + std::string predStr = *this->index->getStringOfRelId(relation_id); + std::string tailStr = *this->index->getStringOfNodeId(tail_id); + // TODO: check that exists? + int rule_id = exp->tripleBestRules[head_id][relation_id][tail_id]; + std::pair p1 = {val, rule_id}; + std::pair> p2 = {tailStr, p1}; + std::pair>> p3 = {predStr, p2}; + out.emplace_back(headStr, p3); + } + if (Properties::get().VERBOSE == 1) { + std::cout << "transform finish" << std::endl; + } + return out; +} diff --git a/python_bindings/safran_wrapper.h b/python_bindings/safran_wrapper.h new file mode 100644 index 0000000..d6cf74c --- /dev/null +++ b/python_bindings/safran_wrapper.h @@ -0,0 +1,32 @@ +#include "Index.h" +#include "TraintripleReader.h" +#include "ValidationtripleReader.h" +#include "RuleReader.h" +#include "RuleApplication.h" +#include "Properties.hpp" +#include "InMemoryExplanation.h" + +#include +#include +#include +#include + +using string_vector_t = std::vector; + +class pysafran { +public: + pysafran(std::string train_path, std::string rule_path, int num_threads); + ~pysafran(); + + // [(head, pred, tail, confidence, rule_id)] + std::vector>>>> query(const std::string & action, size_t k, + const std::vector & flat_triples) const; + +private: + Index *index; + TraintripleReader* graph; + ValidationtripleReader* vtr; + RuleReader* rr; + RuleApplication *ra; + InMemoryExplanation *exp; +}; \ No newline at end of file diff --git a/python_bindings/safran_wrapper.i b/python_bindings/safran_wrapper.i new file mode 100644 index 0000000..69e5d10 --- /dev/null +++ b/python_bindings/safran_wrapper.i @@ -0,0 +1,20 @@ +%module safran_wrapper + +%{ +#define SWIG_FILE_WITH_INIT +#include "safran_wrapper.h" +%} + +%include +%include +%include +%include "safran_wrapper.h" + +namespace std { + %template(query_triples_t) vector; + %template(query_p1) pair; + %template(query_p2) pair>; + %template(query_p3) pair>>; + %template(query_p4) pair>>>; + %template(query_output_t) vector>>>>; +} diff --git a/python_bindings/setup.py b/python_bindings/setup.py new file mode 100644 index 0000000..9a90adb --- /dev/null +++ b/python_bindings/setup.py @@ -0,0 +1,51 @@ +import os +import sys +import setuptools +from glob import glob +from setuptools import setup, Extension +from setuptools.command.install import install +from distutils.command.build import build + +__version__ = '1.0.0' + +source_files = ['safran_wrapper.cpp', 'safran_wrapper.i'] + +source_files += glob('../src/*.cpp') +source_files = [src for src in source_files if src not in {'../src/SQLiteExplanation.cpp'}] + +ext_modules = [ + Extension( + '_safran_wrapper', + source_files, + swig_opts=["-c++", "-extranative"], + language='c++', + extra_compile_args=["-std=c++17", "-I../include", "-I../boost_1_76_0", "-fopenmp"], + extra_link_args=["-lgomp"], + ), +] + +class CustomBuild(build): + def run(self): + self.run_command('build_ext') + build.run(self) + +class CustomInstall(install): + def run(self): + self.run_command('build_ext') + self.do_egg_install() + +custom_cmdclass = {'build': CustomBuild, 'install': CustomInstall} + +setup( + name='safran', + version=__version__, + description='SAFRAN: Scalable and fast non-redundant rule application', + author='S. Ott, C. Melicke, M. Samwald, A. Belyy', + url='https://github.com/AVBelyy/SAFRAN', + ext_modules=ext_modules, + cmdclass=custom_cmdclass, + install_requires=[], + setup_requires=[], + py_modules=['safran', 'safran_wrapper'], + zip_safe=False, +) diff --git a/src/InMemoryExplanation.cpp b/src/InMemoryExplanation.cpp new file mode 100644 index 0000000..8447f03 --- /dev/null +++ b/src/InMemoryExplanation.cpp @@ -0,0 +1,96 @@ +#include "InMemoryExplanation.h" + +InMemoryExplanation::InMemoryExplanation() { +} + +InMemoryExplanation::~InMemoryExplanation() = default; + +void InMemoryExplanation::insertEntities(Index* index) { +} + +void InMemoryExplanation::insertRelations(Index* index) { +} + +void InMemoryExplanation::insertRules(RuleReader* rr, int relsize, ClusteringReader* cr) { + int* adj_begin = rr->getCSR()->getAdjBegin(); + Rule* rules_adj_list = rr->getCSR()->getAdjList(); + + if (cr == nullptr) { + for (int rel = 0; rel < relsize; rel++) { + int ind_ptr = adj_begin[3 + rel]; + int lenRules = adj_begin[3 + rel + 1] - ind_ptr; + for (int i = 0; i < lenRules; i++) { + Rule &r = rules_adj_list[ind_ptr + i]; + int rule_id = r.getID(); + auto conf = (float)r.getAppliedConfidence(); + ruleConfidences[rule_id] = conf; + } + } + } else { + // Not implemented yet. + return; + } +} + +void InMemoryExplanation::insertTask(int task_id, bool is_head, int relation_id, int entity_id) { +#pragma omp critical + { + tasks[task_id] = {relation_id, entity_id, is_head}; + } +} + +void InMemoryExplanation::insertPrediction(int task_id, int entity_id, bool hit, double confidence) { +} + +void InMemoryExplanation::insertRule_Entity(int rule_id, int task_id, int other_entity_id) { + bool need_update = true; + if (taskEntityBestRules.find(task_id) != taskEntityBestRules.end()) { + if (taskEntityBestRules[task_id].find(other_entity_id) != taskEntityBestRules[task_id].end()) { + int prev_rule_id = taskEntityBestRules[task_id][other_entity_id]; + float prev_conf = ruleConfidences[prev_rule_id], new_conf = ruleConfidences[rule_id]; + if (new_conf <= prev_conf) { + need_update = false; + } + } + } + if (need_update) { + auto & p = tasks[task_id]; + int relation_id = std::get<0>(p), entity_id = std::get<1>(p), is_head = std::get<2>(p); + int head_id, tail_id; + if (is_head) { + head_id = entity_id; + tail_id = other_entity_id; + } else { + head_id = other_entity_id; + tail_id = entity_id; + } +#pragma omp critical + // Update the current best rule, according to the "applymax" action + { + taskEntityBestRules[task_id][other_entity_id] = rule_id; + tripleBestRules[head_id][relation_id][tail_id] = rule_id; + } + } +} + +void InMemoryExplanation::begin_tr() { +} + +void InMemoryExplanation::begin() { +} + +void InMemoryExplanation::commit() { +} + +void InMemoryExplanation::commit_tr() { +} + +int InMemoryExplanation::getNextTaskID() { + int task_id_; +#pragma omp critical + { + _task_id++; + task_id_ = _task_id; + } + return task_id_; +} \ No newline at end of file diff --git a/src/JaccardEngine.cpp b/src/JaccardEngine.cpp index 338c1c9..fb87093 100644 --- a/src/JaccardEngine.cpp +++ b/src/JaccardEngine.cpp @@ -92,15 +92,7 @@ void JaccardEngine::calc_sols(std::vector* solutions, Rule** rules, i std::vector results; rulegraph->searchDFSMultiStart(currRule, true, results); - std::vector filt_results; - for (auto res : results) { - if (currRule.head_exceptions.find(res) != currRule.head_exceptions.end()) { - continue; - } - filt_results.push_back(res); - } - - heads.push_back(filt_results); + heads.push_back(results); tails.push_back(std::vector {*currRule.getHeadconstant()}); } } @@ -116,16 +108,7 @@ void JaccardEngine::calc_sols(std::vector* solutions, Rule** rules, i std::vector results; rulegraph->searchDFSMultiStart(currRule, true, results); heads.push_back(std::vector {*currRule.getHeadconstant()}); - - std::vector filt_results; - for (auto res : results) { - if (currRule.tail_exceptions.find(res) != currRule.tail_exceptions.end()) { - continue; - } - filt_results.push_back(res); - } - - tails.push_back(filt_results); + tails.push_back(results); } } else { @@ -144,22 +127,11 @@ void JaccardEngine::calc_sols(std::vector* solutions, Rule** rules, i int ind_ptr = adj_list[start_indptr + val]; int len = adj_list[start_indptr + val + 1] - ind_ptr; if (len > 0) { - if (currRule.head_exceptions.find(val) != currRule.head_exceptions.end()) { - continue; - } std::vector results; rulegraph->searchDFSSingleStart(val, currRule, false, results, previous, visited); if (results.size() > 0) { - std::vector filt_results; - for (auto res : results) { - if (currRule.tail_exceptions.find(res) != currRule.tail_exceptions.end()) { - continue; - } - filt_results.push_back(res); - } - heads.push_back(std::vector {val}); - tails.push_back(filt_results); + tails.push_back(results); } for (int i = 0; i < rulelength; i++) { std::fill(visited[i], visited[i] + size, false); @@ -188,38 +160,10 @@ void JaccardEngine::calc_jaccs(std::vector* solutions, Rule** rules, } for (int j = 0; j < len; j++) { if (i != j) { - /* - Rule rule_i = rules[i]; - Rule rule_j = rules[j]; - - if (rule_i.is_c() && rule_j.is_c()) { - double jaccard = calc_jacc_samp(solutions[i], solutions[j], true); - if (jaccard > 0.0) { - jacc[i].push_back(std::make_pair(j, jaccard)); - } - } - else if (rule_i.is_ac2() && rule_j.is_ac2()) { - double jaccard = calc_jacc_samp(solutions[i], solutions[j], true); - if (jaccard > 0.0) { - jacc[i].push_back(std::make_pair(j, jaccard)); - } - } - else if ((rule_i.is_c() && rule_j.is_ac2()) || (rule_i.is_ac2() && rule_j.is_c())) { - double jaccard = calc_jacc_samp(solutions[i], solutions[j], true); - if (jaccard > 0.0) { - jacc[i].push_back(std::make_pair(j, jaccard)); - } - } - */ Rule& rule_i = *rules[i]; Rule& rule_j = *rules[j]; - /* - if (rule_i.get_body_hash() == rule_j.get_body_hash()) { - jacc[i].push_back(std::make_pair(j, 1.0)); - } - else { - */ + if ((rule_i.is_ac2() || rule_i.is_ac1()) && (rule_j.is_ac2() || rule_j.is_ac1()) && rule_i.getRuletype() == rule_j.getRuletype() && *rule_i.getHeadconstant() != *rule_j.getHeadconstant()) { continue; } diff --git a/src/Rule.cpp b/src/Rule.cpp index 53250c3..1813450 100644 --- a/src/Rule.cpp +++ b/src/Rule.cpp @@ -43,6 +43,9 @@ void Rule::setBodyconstantId(int * id) { void Rule::setRulestring(std::string rule) { rulestring = rule; } +void Rule::setID(int ID) { + this->ID = ID; +} void Rule::setBuffer(std::vector buffer) { this->buffer = buffer; @@ -51,31 +54,44 @@ void Rule::setBuffer(std::vector buffer) { bool Rule::isBuffered() { return buffered; } +std::vector& Rule::getBuffer() { + return *buffer; +} +void Rule::removeBuffer() { + (*buffer).clear(); + buffered = false; +} void Rule::setHeadBuffer(int head, std::vector buffer) { - headBuffer[head] = buffer; + if (!headBuffer.has_value()) { + headBuffer = std::unordered_map>(); + } + (*headBuffer)[head] = buffer; } bool Rule::isHeadBuffered(int head) { - return headBuffer.find(head) != headBuffer.end(); + return (*headBuffer).find(head) != (*headBuffer).end(); } std::vector& Rule::getHeadBuffered(int head) { - return headBuffer[head]; + return (*headBuffer)[head]; } void Rule::clearHeadBuffer() { - headBuffer.clear(); + (*headBuffer).clear(); } void Rule::setTailBuffer(int tail, std::vector buffer) { - tailBuffer[tail] = buffer; + if (!tailBuffer.has_value()) { + tailBuffer = std::unordered_map>(); + } + (*tailBuffer)[tail] = buffer; } bool Rule::isTailBuffered(int tail) { - return tailBuffer.find(tail) != tailBuffer.end(); + return (*tailBuffer).find(tail) != (*tailBuffer).end(); } std::vector& Rule::getTailBuffered(int tail) { - return tailBuffer[tail]; + return (*tailBuffer)[tail]; } void Rule::clearTailBuffer() { - tailBuffer.clear(); + (*tailBuffer).clear(); } bool Rule::is_c() { @@ -99,15 +115,10 @@ bool Rule::is_ac2() { return false; } -void Rule::add_head_exception(int val) { - head_exceptions.insert(val); -} - -void Rule::add_tail_exception(int val) { - tail_exceptions.insert(val); -} - //Getter +int& Rule::getID() { +return ID; +} Ruletype Rule::getRuletype() { return type; } @@ -141,13 +152,6 @@ long long Rule::getPredicted() { std::string Rule::getRulestring() { return rulestring; } -std::vector& Rule::getBuffer() { - return buffer; -} -void Rule::removeBuffer() { - buffer.clear(); - buffered = false; -} long long Rule::get_body_hash() { return bodyhash; @@ -199,6 +203,7 @@ bool Rule::is_body_equal(Rule other) { Rule& Rule::operator=(Rule* other) { + ID = other->ID; type = other->type; rulelength = other->rulelength; predicted = other->predicted; @@ -212,7 +217,5 @@ Rule& Rule::operator=(Rule* other) rulestring = other->rulestring; applied_confidence = other->applied_confidence; bodyhash = other->bodyhash; - tail_exceptions = other->tail_exceptions; - head_exceptions = other->head_exceptions; return *this; } diff --git a/src/RuleApplication.cpp b/src/RuleApplication.cpp index 90fbd62..d3d14d2 100644 --- a/src/RuleApplication.cpp +++ b/src/RuleApplication.cpp @@ -1,6 +1,17 @@ #include "RuleApplication.h" -RuleApplication::RuleApplication(Index* index, TraintripleReader* graph, TesttripleReader* ttr, ValidationtripleReader* vtr, RuleReader* rr) { +RuleApplication::RuleApplication(Index* index, TraintripleReader* graph, ValidationtripleReader* vtr, RuleReader* rr, Explanation* exp) { + this->index = index; + this->graph = graph; + this->vtr = vtr; + this->rr = rr; + this->rulegraph = new RuleGraph(index->getNodeSize(), graph, vtr); + reflexiv_token = *index->getIdOfNodestring(Properties::get().REFLEXIV_TOKEN); + this->k = Properties::get().TOP_K_OUTPUT; + this->exp = exp; +} + +RuleApplication::RuleApplication(Index* index, TraintripleReader* graph, TesttripleReader* ttr, ValidationtripleReader* vtr, RuleReader* rr, Explanation* exp) { this->index = index; this->graph = graph; this->ttr = ttr; @@ -9,6 +20,12 @@ RuleApplication::RuleApplication(Index* index, TraintripleReader* graph, Testtri this->rulegraph = new RuleGraph(index->getNodeSize(), graph, ttr, vtr); reflexiv_token = *index->getIdOfNodestring(Properties::get().REFLEXIV_TOKEN); this->k = Properties::get().TOP_K_OUTPUT; + this->exp = exp; +} + +void RuleApplication::updateTTR(TesttripleReader* testReader) { + this->ttr = testReader; + this->rulegraph->updateTTR(testReader); } void RuleApplication::apply_nr_noisy(std::unordered_map>>, std::pair>>>> rel2clusters) { @@ -147,325 +164,54 @@ void RuleApplication::apply_only_max() { fclose(pFile); } -/* -void RuleApplication::max(int rel, std::vector> clusters) { - - int* adj_lists = graph->getCSR()->getAdjList(); - int* adj_list_starts = graph->getCSR()->getAdjBegin(); - - Rule* rules_adj_list = rr->getCSR()->getAdjList(); - int* adj_begin = rr->getCSR()->getAdjBegin(); - - int** testtriples = ttr->getTesttriples(); - int* testtriplessize = ttr->getTesttriplesSize(); - - int* tt_adj_list = ttr->getCSR()->getAdjList(); - int* tt_adj_begin = ttr->getCSR()->getAdjBegin(); - - int* vt_adj_list = vtr->getCSR()->getAdjList(); - int* vt_adj_begin = vtr->getCSR()->getAdjBegin(); - - int nodesize = index->getNodeSize(); - int ind_ptr = adj_begin[3 + rel]; - int lenRules = adj_begin[3 + rel + 1] - ind_ptr; - - - std::unordered_map>>> headTailResults; - std::unordered_map>>> tailHeadResults; - - { - // adj list of testtriple x r ? - int* t_adj_list = &(tt_adj_list[tt_adj_begin[rel * 2]]); - int lenHeads = t_adj_list[1]; // size + 1 of testtriple testtriple heads of a specific relation - auto rHT = ttr->getRelHeadToTails()[rel]; -#pragma omp parallel for schedule(dynamic) - for (int b = 0; b < rHT.bucket_count(); b++) { - double* result_tail = new double[nodesize]; - int * count_tail = new int[nodesize]; - std::fill(result_tail, result_tail + nodesize, 0.0); - std::fill(count_tail, count_tail + nodesize, 0); - for (auto heads = rHT.begin(b); heads != rHT.end(b); heads++) { - int head = heads->first; - int* head_ind_ptr = &t_adj_list[3 + head]; - int lenTails = heads->second.size(); - - - if (lenTails > 0) { - std::vector touched_tails; - - for (int ruleIndex = 0; ruleIndex < lenRules; ruleIndex++) { - Rule& currRule = rules_adj_list[ind_ptr + ruleIndex]; - - std::vector tailresults_vec; - - if (currRule.is_c()) { - rulegraph->searchDFSSingleStart_filt(true, head, head, currRule, false, tailresults_vec, true, true); - } - else { - - if (currRule.isBuffered()) { - if (currRule.getRuletype() == Ruletype::XRule) { - if (util::in_sorted(currRule.getBuffer(), head)) { - tailresults_vec.push_back(*currRule.getHeadconstant()); - } - } - else if (currRule.getRuletype() == Ruletype::YRule && head == *currRule.getHeadconstant()) { - tailresults_vec = currRule.getBuffer(); - } - } - else { - if (currRule.is_ac2() && currRule.getRuletype() == Ruletype::XRule) { - std::vector comp; - rulegraph->searchDFSSingleStart_filt(false, *currRule.getHeadconstant(), *currRule.getBodyconstantId(), currRule, true, comp, true, true); -#pragma omp critical - { - if (!currRule.isBuffered())currRule.setBuffer(comp); - } - if (util::in_sorted(comp, head)) { - tailresults_vec.push_back(*currRule.getHeadconstant()); - } - } - else if (currRule.is_ac2() && currRule.getRuletype() == Ruletype::YRule && head == *currRule.getHeadconstant()) { - rulegraph->searchDFSSingleStart_filt(true, *currRule.getHeadconstant(), *currRule.getBodyconstantId(), currRule, true, tailresults_vec, true, true); -#pragma omp critical - { - if (!currRule.isBuffered())currRule.setBuffer(tailresults_vec); - } - } - else if (currRule.is_ac1() && currRule.getRuletype() == Ruletype::XRule) { - std::vector comp; - rulegraph->searchDFSMultiStart_filt(false, *currRule.getHeadconstant(), currRule, true, comp, true, true); -#pragma omp critical - { - if (!currRule.isBuffered())currRule.setBuffer(comp); - } - if (util::in_sorted(comp, head)) { - tailresults_vec.push_back(*currRule.getHeadconstant()); - } - } - else if (currRule.is_ac1() && currRule.getRuletype() == Ruletype::YRule && head == *currRule.getHeadconstant()) { - rulegraph->searchDFSMultiStart_filt(true, *currRule.getHeadconstant(), currRule, true, tailresults_vec, true, true); -#pragma omp critical - { - if (!currRule.isBuffered())currRule.setBuffer(tailresults_vec); - } - } - } - } - - if (tailresults_vec.size() > 0) { - for (auto tailresult : tailresults_vec) { - if (result_tail[tailresult] == 0.0) { - result_tail[tailresult] = currRule.getAppliedConfidence(); - count_tail[tailresult]++; - touched_tails.push_back(tailresult); - } - else { - if (result_tail[tailresult] < currRule.getAppliedConfidence()) { - throw std::runtime_error("HI"); - } - else { - result_tail[tailresult] = result_tail[tailresult] + (pow(0.0001,(double)count_tail[tailresult]) * currRule.getAppliedConfidence()); - count_tail[tailresult]++; - } - } - } - } - } - - for (int tailIndex = 0; tailIndex < lenTails; tailIndex++) { - int tail = t_adj_list[3 + lenHeads + *head_ind_ptr + tailIndex]; - - MinHeap tails(k); - for (auto i : touched_tails) { - if (result_tail[i] >= tails.getMin().second) { - double confidence = result_tail[i]; - if (i == reflexiv_token) { - i = head; - } - if (i == tail || heads->second.find(i) == heads->second.end()) { - tails.deleteMin(); - tails.insertKey(std::make_pair(i, confidence)); - } - } - } - - std::vector> tailresults_vec; - for (int i = k - 1; i >= 0; i--) { - std::pair tail_pred = tails.extractMin(); - if (tail_pred.first != -1) tailresults_vec.push_back(tail_pred); - } - std::reverse(tailresults_vec.begin(), tailresults_vec.end()); -#pragma omp critical - { - headTailResults[head][tail] = tailresults_vec; - } - - } - for (auto i : touched_tails) { - result_tail[i] = 0.0; - count_tail[i] = 0; - } - } - } - delete[] count_tail; - delete[] result_tail; - } - } - { - // adj list of testtriple x r ? - int* t_adj_list = &(tt_adj_list[tt_adj_begin[rel * 2 + 1]]); - int lenTails = t_adj_list[1]; // size + 1 of testtriple testtriple heads of a specific relation - auto rTH = ttr->getRelTailToHeads()[rel]; -#pragma omp parallel for schedule(dynamic) - for (int b = 0; b < rTH.bucket_count(); b++) { - double* result_head = new double[nodesize]; - int* count_head = new int[nodesize]; - std::fill(result_head, result_head + nodesize, 0.0); - std::fill(count_head, count_head + nodesize, 0); - - for (auto tails = rTH.begin(b); tails != rTH.end(b); tails++) { - int tail = tails->first; - int* tail_ind_ptr = &t_adj_list[3 + tail]; - int lenHeads = tails->second.size(); - - - if (lenHeads > 0) { - std::vector touched_heads; - - for (auto ruleIndex : clusters[0]) { - Rule& currRule = rules_adj_list[ind_ptr + ruleIndex]; - - std::vector headresults_vec; - - if (currRule.is_c()) { - rulegraph->searchDFSSingleStart_filt(false, tail, tail, currRule, true, headresults_vec, true, true); - } - else { - if (currRule.isBuffered()) { - if (currRule.getRuletype() == Ruletype::XRule && tail == *currRule.getHeadconstant()) { - headresults_vec = currRule.getBuffer(); - } - else if (currRule.getRuletype() == Ruletype::YRule) { - if (util::in_sorted(currRule.getBuffer(), tail)) { - headresults_vec.push_back(*currRule.getHeadconstant()); - } - } - } - else { - if (currRule.is_ac2() && currRule.getRuletype() == Ruletype::XRule && tail == *currRule.getHeadconstant()) { - rulegraph->searchDFSSingleStart_filt(false, *currRule.getHeadconstant(), *currRule.getBodyconstantId(), currRule, true, headresults_vec, true, true); -#pragma omp critical - { - if (!currRule.isBuffered()) currRule.setBuffer(headresults_vec); - } - } - else if (currRule.is_ac2() && currRule.getRuletype() == Ruletype::YRule) { - std::vector comp; - rulegraph->searchDFSSingleStart_filt(true, *currRule.getHeadconstant(), *currRule.getBodyconstantId(), currRule, true, comp, true, true); -#pragma omp critical - { - if (!currRule.isBuffered())currRule.setBuffer(comp); - } - if (util::in_sorted(comp, tail)) { - headresults_vec.push_back(*currRule.getHeadconstant()); - } - } - else if (currRule.is_ac1() && currRule.getRuletype() == Ruletype::XRule && tail == *currRule.getHeadconstant()) { - rulegraph->searchDFSMultiStart_filt(false, *currRule.getHeadconstant(), currRule, true, headresults_vec, true, true); -#pragma omp critical - { - if (!currRule.isBuffered()) currRule.setBuffer(headresults_vec); - } - } - else if (currRule.is_ac1() && currRule.getRuletype() == Ruletype::YRule) { - std::vector comp; - rulegraph->searchDFSMultiStart_filt(true, *currRule.getHeadconstant(), currRule, true, comp, true, true); -#pragma omp critical - { - if (!currRule.isBuffered())currRule.setBuffer(comp); - } - if (util::in_sorted(comp, tail)) { - headresults_vec.push_back(*currRule.getHeadconstant()); - } - } - } - } - - if (headresults_vec.size() > 0) { - for (auto headresult : headresults_vec) { - if (result_head[headresult] == 0.0) { - result_head[headresult] = currRule.getAppliedConfidence(); - touched_heads.push_back(headresult); - count_head[headresult]++; - } - else { - if (result_head[headresult] < currRule.getAppliedConfidence()) { - throw std::runtime_error("HI"); - } - else { - result_head[headresult] = result_head[headresult] + (pow(0.0001,count_head[headresult]) * currRule.getAppliedConfidence()); - count_head[headresult]++; - } - } - } - } - } - - for (int headIndex = 0; headIndex < lenHeads; headIndex++) { - int head = t_adj_list[3 + lenTails + *tail_ind_ptr + headIndex]; - - MinHeap heads(k); - for (auto i : touched_heads) { - if (result_head[i] >= heads.getMin().second) { - double confidence = result_head[i]; - if (i == reflexiv_token) { - i = tail; - } - if (i == head || tails->second.find(i) == tails->second.end()) { - heads.deleteMin(); - heads.insertKey(std::make_pair(i, confidence)); - } - } - } - - std::vector> headresults_vec; - for (int i = k - 1; i >= 0; i--) { - std::pair head_pred = heads.extractMin(); - if (head_pred.first != -1) headresults_vec.push_back(head_pred); - } - std::reverse(headresults_vec.begin(), headresults_vec.end()); -#pragma omp critical - { - tailHeadResults[tail][head] = headresults_vec; - } - } - for (auto i : touched_heads) { - result_head[i] = 0.0; - count_head[i] = 0; - } - } - } - delete[] count_head; - delete[] result_head; - } - } -#pragma omp critical - { - auto it_head = headTailResults.begin(); - while (it_head != headTailResults.end()) { - auto it_tail = it_head->second.begin(); - while (it_tail != it_head->second.end()) { - { - writeTopKCandidates(it_head->first, rel, it_tail->first, tailHeadResults[it_tail->first][it_head->first], it_tail->second, pFile, k); - } - it_tail++; - } - it_head++; - } - } +std::vector> RuleApplication::apply_only_max_in_memory(size_t K) { + if (Properties::get().VERBOSE == 1) { + std::cout << "aomim init start" << std::endl; + } + int* adj_begin = rr->getCSR()->getAdjBegin(); + std::vector> out; + + int iterations = index->getRelSize(); + if (Properties::get().VERBOSE == 1) { + std::cout << "aomim init end" << std::endl; + } + for (int rel = 0; rel < iterations; rel++) { + // TODO (Anton): precompute clusters outside of this call! + int ind_ptr = adj_begin[3 + rel]; + int lenRules = adj_begin[3 + rel + 1] - ind_ptr; + if (Properties::get().VERBOSE == 1) { + std::cout << "pred " << rel << " nrules=" << lenRules; + } + std::vector> clusters; + std::vector cluster; + for (int i = 0; i < lenRules; i++) { + cluster.push_back(i); + } + clusters.push_back(cluster); + + std::unordered_map>>> headTailResults; + headTailResults = max(rel, clusters, false); + if (Properties::get().VERBOSE == 1) { + std::cout << " nresults=" << headTailResults.size() << std::endl; + } + + for (auto & [head, tailResults] : headTailResults) { + for (auto & [_, tails] : tailResults) { + size_t maxSize = std::min(tails.size(), K); + size_t i = 0; + for (auto [tail, val] : tails) { + if (i >= maxSize) { + break; + } + out.emplace_back(head, rel, tail, val); + i++; + } + } + } + } + + return out; } -*/ std::unordered_map>>> RuleApplication::noisy(int rel, std::vector> clusters, bool predictHeadNotTail) { int* adj_lists = graph->getCSR()->getAdjList(); @@ -501,6 +247,7 @@ std::unordered_map> entityToRules; for (auto heads = rHT.begin(b); heads != rHT.end(b); heads++) { int head = heads->first; int* head_ind_ptr = &t_adj_list[3 + head]; @@ -542,6 +289,9 @@ std::unordered_map 0) { for (auto tailresult : tailresults_vec) { + if (this->exp != nullptr) { + entityToRules[tailresult].push_back(currRule.getID()); + } if (cluster_result_tail[tailresult] == 0.0) { cluster_result_tail[tailresult] = currRule.getAppliedConfidence(); touched_cluster_tails.push_back(tailresult); @@ -563,6 +313,44 @@ std::unordered_mapexp != nullptr) { + + MinHeap tails(10); + for (auto i : touched_tails) { + if (result_tail[i] >= tails.getMin().second) { + tails.deleteMin(); + tails.insertKey(std::make_pair(i, result_tail[i])); + } + } + + std::vector> tailresults_vec; + for (int i = 9; i >= 0; i--) { + std::pair tail_pred = tails.extractMin(); + if (tail_pred.first != -1) tailresults_vec.push_back(tail_pred); + } + std::reverse(tailresults_vec.begin(), tailresults_vec.end()); + + int task_id = exp->getNextTaskID(); + exp->begin(); + exp->insertTask(task_id, true, rel, head); + for (auto p : tailresults_vec) { + bool hit = false; + if (heads->second.find(p.first) != heads->second.end()) { + hit = true; + } + exp->insertPrediction(task_id, p.first, hit, (double) p.second); + auto it = entityToRules.find(p.first); + if (it != entityToRules.end()) { + for (auto rule : it->second) { + exp->insertRule_Entity(rule, task_id, p.first); + } + } + } + exp->commit(); + entityToRules.clear(); + } + for (int tailIndex = 0; tailIndex < lenTails; tailIndex++) { int tail = t_adj_list[3 + lenHeads + *head_ind_ptr + tailIndex]; @@ -612,7 +400,7 @@ std::unordered_map> entityToRules; for (auto tails = rTH.begin(b); tails != rTH.end(b); tails++) { int tail = tails->first; int* tail_ind_ptr = &t_adj_list[3 + tail]; @@ -654,6 +442,9 @@ std::unordered_map 0) { for (auto headresult : headresults_vec) { + if (this->exp != nullptr) { + entityToRules[headresult].push_back(currRule.getID()); + } if (cluster_result_head[headresult] == 0.0) { cluster_result_head[headresult] = currRule.getAppliedConfidence(); touched_cluster_heads.push_back(headresult); @@ -675,6 +466,43 @@ std::unordered_mapexp != nullptr) { + MinHeap heads(10); + for (auto i : touched_heads) { + if (result_head[i] >= heads.getMin().second) { + heads.deleteMin(); + heads.insertKey(std::make_pair(i, result_head[i])); + } + } + + std::vector> headresults_vec; + for (int i = 9; i >= 0; i--) { + std::pair head_pred = heads.extractMin(); + if (head_pred.first != -1) headresults_vec.push_back(head_pred); + } + std::reverse(headresults_vec.begin(), headresults_vec.end()); + + int task_id = exp->getNextTaskID(); + exp->begin(); + exp->insertTask(task_id, false, rel, tail); + for (auto p : headresults_vec) { + bool hit = false; + if (tails->second.find(p.first) != tails->second.end()) { + hit = true; + } + exp->insertPrediction(task_id, p.first, hit, (double) p.second); + auto it = entityToRules.find(p.first); + if (it != entityToRules.end()) { + for (auto rule : it->second) { + exp->insertRule_Entity(rule, task_id, p.first); + } + } + } + exp->commit(); + entityToRules.clear(); + } + for (int headIndex = 0; headIndex < lenHeads; headIndex++) { int head = t_adj_list[3 + lenTails + *tail_ind_ptr + headIndex]; @@ -740,7 +568,6 @@ std::unordered_map>>> entityEntityResults; if (!predictHeadNotTail) @@ -751,6 +578,7 @@ std::unordered_mapgetRelHeadToTails()[rel]; #pragma omp parallel for schedule(dynamic) for (int b = 0; b < rHT.bucket_count(); b++) { + std::unordered_map> entityToRules; for (auto heads = rHT.begin(b); heads != rHT.end(b); heads++) { int head = heads->first; int* head_ind_ptr = &t_adj_list[3 + head]; @@ -760,6 +588,12 @@ std::unordered_map 0) { ScoreTree* tailScoreTrees = new ScoreTree[lenTails]; std::vector fineScoreTrees(lenTails); + + ScoreTree* expScoreTree = nullptr; + if (this->exp != nullptr) { + expScoreTree = new ScoreTree(); + } + bool stop = false; for (auto ruleIndex : clusters[0]) { Rule& currRule = rules_adj_list[ind_ptr + ruleIndex]; @@ -788,6 +622,14 @@ std::unordered_map 0) { stop = true; + + if (this->exp != nullptr) { + for (auto a : tailresults_vec) { + entityToRules[a].push_back(currRule.getID()); + } + expScoreTree->addValues(currRule.getAppliedConfidence(), &tailresults_vec[0], tailresults_vec.size()); + } + for (int tailIndex = 0; tailIndex < lenTails; tailIndex++) { if (fineScoreTrees[tailIndex] == false) { int tail = t_adj_list[3 + lenHeads + *head_ind_ptr + tailIndex]; @@ -816,6 +658,31 @@ std::unordered_mapexp != nullptr) { + std::vector> tailresults_vec; + expScoreTree->getResults(tailresults_vec); + std::sort(tailresults_vec.begin(), tailresults_vec.end(), finalResultComperator); + + int task_id = exp->getNextTaskID(); + exp->begin(); + exp->insertTask(task_id, true, rel, head); + for (auto p : tailresults_vec) { + bool hit = false; + if (heads->second.find(p.first) != heads->second.end()) { + hit = true; + } + exp->insertPrediction(task_id, p.first, hit, p.second); + auto it = entityToRules.find(p.first); + if (it != entityToRules.end()) { + for (auto rule : it->second) { + exp->insertRule_Entity(rule, task_id, p.first); + } + } + } + exp->commit(); + entityToRules.clear(); + } + for (int tailIndex = 0; tailIndex < lenTails; tailIndex++) { int tail = t_adj_list[3 + lenHeads + *head_ind_ptr + tailIndex]; auto cmp = [](std::pair const& a, std::pair const& b) @@ -842,6 +709,9 @@ std::unordered_mapexp != nullptr) { + delete expScoreTree; + } } } } @@ -854,6 +724,7 @@ std::unordered_mapgetRelTailToHeads()[rel]; #pragma omp parallel for schedule(dynamic) for (int b = 0; b < rTH.bucket_count(); b++) { + std::unordered_map> entityToRules; for (auto tails = rTH.begin(b); tails != rTH.end(b); tails++) { int tail = tails->first; @@ -863,6 +734,12 @@ std::unordered_map 0) { ScoreTree* headScoreTrees = new ScoreTree[lenHeads]; std::vector fineScoreTrees(lenHeads); + + ScoreTree* expScoreTree = nullptr; + if (this->exp != nullptr) { + expScoreTree = new ScoreTree(); + } + bool stop = false; for (auto ruleIndex : clusters[0]) { Rule& currRule = rules_adj_list[ind_ptr + ruleIndex]; @@ -892,6 +769,13 @@ std::unordered_map 0) { + if (this->exp != nullptr) { + for (auto a : headresults_vec) { + entityToRules[a].push_back(currRule.getID()); + } + expScoreTree->addValues(currRule.getAppliedConfidence(), &headresults_vec[0], headresults_vec.size()); + } + stop = true; for (int headIndex = 0; headIndex < lenHeads; headIndex++) { if (fineScoreTrees[headIndex] == false) { @@ -922,6 +806,31 @@ std::unordered_mapexp != nullptr) { + std::vector> headresults_vec; + expScoreTree->getResults(headresults_vec); + std::sort(headresults_vec.begin(), headresults_vec.end(), finalResultComperator); + + int task_id = exp->getNextTaskID(); + exp->begin(); + exp->insertTask(task_id, false, rel, tail); + for (auto p : headresults_vec) { + bool hit = false; + if (tails->second.find(p.first) != tails->second.end()) { + hit = true; + } + exp->insertPrediction(task_id, p.first, hit, p.second); + auto it = entityToRules.find(p.first); + if (it != entityToRules.end()) { + for (auto rule : it->second) { + exp->insertRule_Entity(rule, task_id, p.first); + } + } + } + exp->commit(); + entityToRules.clear(); + } + for (int headIndex = 0; headIndex < lenHeads; headIndex++) { int head = t_adj_list[3 + lenTails + *tail_ind_ptr + headIndex]; // Get Headresults and final sorting @@ -944,6 +853,9 @@ std::unordered_mapexp != nullptr) { + delete expScoreTree; + } } } } diff --git a/src/RuleGraph.cpp b/src/RuleGraph.cpp index 4035a46..bb48b21 100644 --- a/src/RuleGraph.cpp +++ b/src/RuleGraph.cpp @@ -21,6 +21,23 @@ RuleGraph::RuleGraph(int nodesize, TraintripleReader* graph, TesttripleReader* t this->relCounter = graph->getRelCounter(); } +RuleGraph::RuleGraph(int nodesize, TraintripleReader* graph, ValidationtripleReader* vtr) { + this->size = nodesize; + this->graph = graph; + adj_lists = graph->getCSR()->getAdjList(); + adj_list_starts = graph->getCSR()->getAdjBegin(); + this->train_relHeadToTails = graph->getRelHeadToTails(); + this->train_relTailToHeads = graph->getRelTailToHeads(); + this->valid_relHeadToTails = vtr->getRelHeadToTails(); + this->valid_relTailToHeads = vtr->getRelTailToHeads(); + this->relCounter = graph->getRelCounter(); +} + +void RuleGraph::updateTTR(TesttripleReader* ttr) { + this->test_relHeadToTails = ttr->getRelHeadToTails(); + this->test_relTailToHeads = ttr->getRelTailToHeads(); +} + void RuleGraph::searchDFSSingleStart_filt(bool headNotTail, int filt_v, int v, Rule& r, bool bwd, std::vector& solution, bool filtValidNotTest, bool filtExceptions) { int rulelength = r.getRulelength(); @@ -183,12 +200,6 @@ void RuleGraph::searchDFSMultiStart_filt(bool headNotTail, int filt_v, Rule& r, for (int val = 0; val < size_indptr - 1; val++) { //OI if (val == *r.getHeadconstant()) continue; - if (r.getRuletype() == Ruletype::YRule && r.head_exceptions.find(val) != r.head_exceptions.end()) { - continue; - } - if (r.getRuletype() == Ruletype::XRule && r.tail_exceptions.find(val) != r.tail_exceptions.end()) { - continue; - } int ind_ptr = adj_list[start_indptr + val]; int len = adj_list[start_indptr + val + 1] - ind_ptr; if (len > 0) { @@ -325,12 +336,6 @@ void RuleGraph::searchDFSMultiStart(Rule& r, bool bwd, std::vector& solutio for (int val = 0; val < size_indptr - 1; val++) { //OI if (val == *r.getHeadconstant()) continue; - if (r.getRuletype() == Ruletype::YRule && r.head_exceptions.find(val) != r.head_exceptions.end()) { - continue; - } - if (r.getRuletype() == Ruletype::XRule && r.tail_exceptions.find(val) != r.tail_exceptions.end()) { - continue; - } int ind_ptr = adj_list[start_indptr + val]; int len = adj_list[start_indptr + val + 1] - ind_ptr; if (len > 0) { @@ -384,15 +389,6 @@ void RuleGraph::searchDFSUtil_filt(Rule* r, bool headNotTail, int filt_value, in return; } } - - if (filtExceptions && (r->is_c() || r->is_ac1())) { - if (!headNotTail && r->head_exceptions.find(ex_head) != r->head_exceptions.end()) { - return; - } - if (headNotTail && r->tail_exceptions.find(ex_tail) != r->tail_exceptions.end()) { - return; - } - } if (Properties::get().ONLY_UNCONNECTED == 1) { if ((*relCounter).find(head) != (*relCounter).end()) { diff --git a/src/RuleReader.cpp b/src/RuleReader.cpp index 5a54f63..6db1d19 100644 --- a/src/RuleReader.cpp +++ b/src/RuleReader.cpp @@ -11,6 +11,7 @@ CSR * RuleReader::getCSR() { } void RuleReader::read(std::string filepath) { + int currID = 0; RelToRules rules; std::string line; std::ifstream myfile(filepath); @@ -18,7 +19,8 @@ void RuleReader::read(std::string filepath) { while (!util::safeGetline(myfile, line).eof()) { std::vector rawrule = util::split(line, '\t'); - Rule * r = parseRule(rawrule); + Rule * r = parseRule(rawrule, currID); + currID++; if (r != nullptr) { if (Properties::get().ONLY_XY == 0 || (Properties::get().ONLY_XY == 1 && r->getRuletype() == Ruletype::XYRule)) { //TODO no insert if rule bad, is probably never the cas (Rules are sampled from trainset) @@ -38,11 +40,12 @@ void RuleReader::read(std::string filepath) { csr = new CSR(index->getRelSize(), rules); } -Rule* RuleReader::parseRule(std::vector rule) { +Rule* RuleReader::parseRule(std::vector rule, int currID) { Rule * ruleObj = new Rule(std::stoi(rule[0]), std::stoi(rule[1]), std::stod(rule[2])); std::string rawrule = rule[3]; ruleObj->setRulestring(rawrule); + ruleObj->setID(currID); std::stringstream ss(rawrule); Ruletype type = Ruletype::None; diff --git a/src/SQLiteExplanation.cpp b/src/SQLiteExplanation.cpp new file mode 100644 index 0000000..a609043 --- /dev/null +++ b/src/SQLiteExplanation.cpp @@ -0,0 +1,332 @@ +#include "SQLiteExplanation.h" + +SQLiteExplanation::SQLiteExplanation(std::string dbName, bool init) { + checkErrorCode(sqlite3_open_v2(dbName.c_str(), &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr)); + if (init) initDb(); +} + +SQLiteExplanation::~SQLiteExplanation() { + sqlite3_close(db); +} + +void SQLiteExplanation::initDb() { + + checkErrorCode(sqlite3_exec(db, "PRAGMA synchronous=OFF", NULL, NULL, NULL)); + checkErrorCode(sqlite3_exec(db, "PRAGMA count_changes=OFF", NULL, NULL, NULL)); + checkErrorCode(sqlite3_exec(db, "PRAGMA journal_mode=MEMORY", NULL, NULL, NULL)); + checkErrorCode(sqlite3_exec(db, "PRAGMA temp_store=MEMORY", NULL, NULL, NULL)); + + char* sql; + + sql = "CREATE TABLE Entity(" \ + "ID INT PRIMARY KEY NOT NULL," \ + "NAME TEXT NOT NULL" \ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); + + sql = "CREATE TABLE Relation(" \ + "ID INT PRIMARY KEY NOT NULL," \ + "NAME TEXT NOT NULL" \ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); + + sql = "CREATE TABLE Rule(" \ + "ID INT PRIMARY KEY NOT NULL," \ + "HEAD_CLUSTER_ID INT NOT NULL," \ + "TAIL_CLUSTER_ID INT NOT NULL," \ + "DEF TEXT NOT NULL," \ + "CONFIDENCE DOUBLE NOT NULL,"\ + "PREDICTED INT NOT NULL,"\ + "CORRECTLY_PREDICTED INT NOT NULL"\ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); + + sql = "CREATE TABLE Task(" \ + "ID INT PRIMARY KEY NOT NULL," \ + "IsHead BOOL," \ + "EntityID INT," \ + "RelationID INT," \ + "FOREIGN KEY(EntityID) REFERENCES Entity(ID)," \ + "FOREIGN KEY(RelationID) REFERENCES Relation(ID)" \ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); + + sql = "CREATE TABLE Prediction(" \ + "TaskID INT NOT NULL," \ + "EntityID INT NOT NULL," \ + "Hit BOOL NOT NULL," + "CONFIDENCE REAL NOT NULL,"\ + "FOREIGN KEY(TaskID) REFERENCES Task(ID)," \ + "FOREIGN KEY(EntityID) REFERENCES Entity(ID)," \ + "PRIMARY KEY(TaskID, EntityID)" \ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); + + sql = "CREATE TABLE Rule_Entity(" \ + "RuleID INT NOT NULL," \ + "TaskID INT NOT NULL," \ + "EntityID INT NOT NULL," \ + "FOREIGN KEY(RuleID) REFERENCES Rule(ID)," \ + "FOREIGN KEY(TaskID) REFERENCES Task(ID)," \ + "FOREIGN KEY(EntityID) REFERENCES Entity(ID)," \ + "PRIMARY KEY(RuleID, TaskID, EntityID)" \ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); + + sql = "CREATE INDEX asdf ON Rule_Entity (" \ + "TaskID," \ + "EntityID" \ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); +} + +void SQLiteExplanation::insertEntities(Index* index) { + char* sql = "INSERT INTO Entity (ID,NAME) VALUES (?,?);"; + sqlite3_stmt* stmt = prepare(sql); + for (int id = 0; id < index->getNodeSize(); id++) { + + // bind the value + checkErrorCode(sqlite3_bind_int(stmt, 1, id)); + + std::string curie = *index->getStringOfNodeId(id); + checkErrorCode(sqlite3_bind_text(stmt, 2, curie.c_str(), strlen(curie.c_str()), 0)); + + checkErrorCode(sqlite3_step(stmt)); + checkErrorCode(sqlite3_reset(stmt)); + } + finalize(stmt); +} +void SQLiteExplanation::insertRelations(Index* index) { + char* sql = "INSERT INTO Relation (ID,NAME) VALUES (?,?);"; + sqlite3_stmt* stmt = prepare(sql); + for (int id = 0; id < index->getRelSize(); id++) { + + // bind the value + checkErrorCode(sqlite3_bind_int(stmt, 1, id)); + + std::string relation = *index->getStringOfRelId(id); + checkErrorCode(sqlite3_bind_text(stmt, 2, relation.c_str(), strlen(relation.c_str()), 0)); + + checkErrorCode(sqlite3_step(stmt)); + checkErrorCode(sqlite3_reset(stmt)); + } + finalize(stmt); +} + +void SQLiteExplanation::insertRules(RuleReader* rr, int relsize, ClusteringReader* cr) { + int* adj_begin = rr->getCSR()->getAdjBegin(); + Rule* rules_adj_list = rr->getCSR()->getAdjList(); + + char* sql = "INSERT INTO Rule (ID,HEAD_CLUSTER_ID,TAIL_CLUSTER_ID,DEF,CONFIDENCE,PREDICTED,CORRECTLY_PREDICTED) VALUES (?,?,?,?,?,?,?);"; + sqlite3_stmt* stmt = prepare(sql); + + if (cr == nullptr) { + for (int rel = 0; rel < relsize; rel++) { + int ind_ptr = adj_begin[3 + rel]; + int lenRules = adj_begin[3 + rel + 1] - ind_ptr; + for (int i = 0; i < lenRules; i++) { + + Rule& r = rules_adj_list[ind_ptr + i]; + // bind the value + checkErrorCode(sqlite3_bind_int(stmt, 1, r.getID())); + + checkErrorCode(sqlite3_bind_int(stmt, 2, 0)); + checkErrorCode(sqlite3_bind_int(stmt, 3, 0)); + + std::string rulestr = r.getRulestring(); + checkErrorCode(sqlite3_bind_text(stmt, 4, rulestr.c_str(), strlen(rulestr.c_str()), 0)); + + checkErrorCode(sqlite3_bind_double(stmt, 5, r.getAppliedConfidence())); + checkErrorCode(sqlite3_bind_int(stmt, 6, r.getPredicted())); + checkErrorCode(sqlite3_bind_int(stmt, 7, r.getCorrectlyPredicted())); + + checkErrorCode(sqlite3_step(stmt)); + checkErrorCode(sqlite3_reset(stmt)); + } + } + } + else { + std::unordered_map>>, std::pair>>>> rel2clusters = cr->getRelToClusters(); + for (int rel = 0; rel < relsize; rel++) { + int ind_ptr = adj_begin[3 + rel]; + int lenRules = adj_begin[3 + rel + 1] - ind_ptr; + + std::unordered_map ruleToHeadCluster; + std::unordered_map ruleToTailCluster; + + // Head clusters + std::vector> cluster = rel2clusters[rel].first.second; + for (int cluster_id = 0; cluster_id < cluster.size(); cluster_id++) { + for (auto rule : cluster[cluster_id]) { + Rule& r = rules_adj_list[ind_ptr + rule]; + ruleToHeadCluster[r.getID()] = cluster_id; + } + } + + cluster = rel2clusters[rel].second.second; + for (int cluster_id = 0; cluster_id < cluster.size(); cluster_id++) { + for (auto rule : cluster[cluster_id]) { + Rule& r = rules_adj_list[ind_ptr + rule]; + ruleToTailCluster[r.getID()] = cluster_id; + } + } + + for (int i = 0; i < lenRules; i++) { + + Rule& r = rules_adj_list[ind_ptr + i]; + // bind the value + checkErrorCode(sqlite3_bind_int(stmt, 1, r.getID())); + + checkErrorCode(sqlite3_bind_int(stmt, 2, ruleToHeadCluster[r.getID()])); + + checkErrorCode(sqlite3_bind_int(stmt, 3, ruleToTailCluster[r.getID()])); + + std::string rulestr = r.getRulestring(); + checkErrorCode(sqlite3_bind_text(stmt, 4, rulestr.c_str(), strlen(rulestr.c_str()), 0)); + + checkErrorCode(sqlite3_bind_double(stmt, 5, r.getAppliedConfidence())); + checkErrorCode(sqlite3_bind_int(stmt, 6, r.getPredicted())); + checkErrorCode(sqlite3_bind_int(stmt, 7, r.getCorrectlyPredicted())); + + checkErrorCode(sqlite3_step(stmt)); + checkErrorCode(sqlite3_reset(stmt)); + } + } + } + + finalize(stmt); +} + +void SQLiteExplanation::insertTask(int task_id, bool is_head, int relation_id, int entity_id) { + //std::cout << "Task " << task_id << " " << is_head << " " << relation_id << " " << entity_id << "\n"; + char* sql = "INSERT INTO Task (ID,IsHead,RelationID,EntityId) VALUES (?,?,?,?);"; + sqlite3_stmt* stmt = prepare(sql); + + checkErrorCode(sqlite3_bind_int(stmt, 1, task_id), "insert task bind 1"); + checkErrorCode(sqlite3_bind_int(stmt, 2, (int)is_head), "insert task bind 2"); + checkErrorCode(sqlite3_bind_int(stmt, 3, relation_id), "insert task bind 3"); + checkErrorCode(sqlite3_bind_int(stmt, 4, entity_id), "insert task bind 4"); + + checkErrorCode(sqlite3_step(stmt), "insert task step"); + finalize(stmt); +} + +void SQLiteExplanation::insertPrediction(int task_id, int entity_id, bool hit, double confidence) { + //std::cout << "Prediction " << task_id << " " << entity_id << " " << " " << confidence << "\n"; + char* sql = "INSERT INTO Prediction (TaskID,EntityID,Hit,Confidence) VALUES (?,?,?,?);"; + sqlite3_stmt* stmt = prepare(sql); + + checkErrorCode(sqlite3_bind_int(stmt, 1, task_id), "insert pred bind 1"); + checkErrorCode(sqlite3_bind_int(stmt, 2, entity_id), "insert pred bind 2"); + checkErrorCode(sqlite3_bind_double(stmt, 3, (int)hit), "insert pred bind 3"); + checkErrorCode(sqlite3_bind_double(stmt, 4, confidence), "insert pred bind 4"); + + checkErrorCode(sqlite3_step(stmt), "insert pred step"); + finalize(stmt); +} + +void SQLiteExplanation::insertCluster(int task_id, int entity_id, int cluster_id, double confidence) { + //std::cout << "Cluster " << task_id << " " << entity_id << " " << cluster_id << " " << confidence << "\n"; + char* sql = "INSERT INTO Cluster (PredictionTaskID, PredictionEntityID, ID, Confidence) VALUES (?,?,?,?);"; + sqlite3_stmt* stmt = prepare(sql); + + checkErrorCode(sqlite3_bind_int(stmt, 1, task_id), "insert clus bind 1"); + checkErrorCode(sqlite3_bind_int(stmt, 2, entity_id), "insert clus bind 2"); + checkErrorCode(sqlite3_bind_int(stmt, 3, cluster_id), "insert clus bind 3"); + checkErrorCode(sqlite3_bind_double(stmt, 4, confidence), "insert clus bind 1"); + + checkErrorCode(sqlite3_step(stmt), sql); + finalize(stmt); +} + +void SQLiteExplanation::insertRule_Cluster(int task_id, int entity_id, int cluster_id, int rule_id) { + //std::cout << "Rule_Cluster " << task_id << " " << entity_id << " " << cluster_id << " " << rule_id << "\n"; + char* sql = "INSERT INTO Rule_Cluster (ClusterPredictionTaskID,ClusterPredictionEntityID,ClusterID,RuleID) VALUES (?,?,?,?);"; + sqlite3_stmt* stmt = prepare(sql); + + checkErrorCode(sqlite3_bind_int(stmt, 1, task_id)); + checkErrorCode(sqlite3_bind_int(stmt, 2, entity_id)); + checkErrorCode(sqlite3_bind_int(stmt, 3, cluster_id)); + checkErrorCode(sqlite3_bind_int(stmt, 4, rule_id)); + + checkErrorCode(sqlite3_step(stmt)); + finalize(stmt); +} + +void SQLiteExplanation::insertRule_Entity(int rule_id, int task_id, int entity_id) { + //std::cout << "Rule_Cluster " << task_id << " " << entity_id << " " << cluster_id << " " << rule_id << "\n"; + char* sql = "INSERT OR IGNORE INTO Rule_Entity (RuleID, TaskID, EntityID) VALUES (?,?,?);"; + sqlite3_stmt* stmt = prepare(sql); + + checkErrorCode(sqlite3_bind_int(stmt, 1, rule_id)); + checkErrorCode(sqlite3_bind_int(stmt, 2, task_id)); + checkErrorCode(sqlite3_bind_int(stmt, 3, entity_id)); + + checkErrorCode(sqlite3_step(stmt)); + finalize(stmt); +} + +void SQLiteExplanation::begin_tr() { + checkErrorCode(sqlite3_exec(db, "BEGIN TRANSACTION", NULL, NULL, NULL), "BEGIN TRANSACTION"); +} + +void SQLiteExplanation::begin() { + //sqlite3_mutex_enter(sqlite3_db_mutex(db)); +} + +sqlite3_stmt* SQLiteExplanation::prepare(char* sql) { + sqlite3_stmt* stmt; + checkErrorCode(sqlite3_prepare(db, sql, -1, &stmt, NULL), sql); + return stmt; +} + +void SQLiteExplanation::finalize(sqlite3_stmt* stmt) { + checkErrorCode(sqlite3_finalize(stmt)); +} + +void SQLiteExplanation::commit() { + //sqlite3_mutex_leave(sqlite3_db_mutex(db)); +} + +void SQLiteExplanation::commit_tr() { + checkErrorCode(sqlite3_exec(db, "COMMIT TRANSACTION", NULL, NULL, NULL), "COMMIT TRANSACTION"); +} + +void SQLiteExplanation::checkErrorCode(int code) { + if (code != SQLITE_OK && code != SQLITE_DONE) { + const char* err; + if (this->db) { + err = sqlite3_errmsg(this->db); + } + else { + err = sqlite3_errstr(code); + } + std::cerr << "Error " << code << ": " << err << '\n'; + std::exit(EXIT_FAILURE); + } +} + +void SQLiteExplanation::checkErrorCode(int code, char* sql) { + if (code != SQLITE_OK && code != SQLITE_DONE) { + const char* err; + if (this->db) { + err = sqlite3_errmsg(this->db); + } + else { + err = sqlite3_errstr(code); + } + std::cerr << sql << "\n"; + std::cerr << "Error " << code << ": " << err << '\n'; + std::exit(EXIT_FAILURE); + } +} + +int SQLiteExplanation::getNextTaskID() { + int task_id_; +#pragma omp critical + { + _task_id++; + task_id_ = _task_id; + } + return task_id_; +} diff --git a/src/TesttripleReader.cpp b/src/TesttripleReader.cpp index 57be33c..0656628 100644 --- a/src/TesttripleReader.cpp +++ b/src/TesttripleReader.cpp @@ -1,10 +1,9 @@ #include "TesttripleReader.h" -TesttripleReader::TesttripleReader(std::string filepath, Index * index, TraintripleReader* graph, int is_trial) { +TesttripleReader::TesttripleReader(Index * index, TraintripleReader* graph, int is_trial) { this->index = index; this->graph = graph; this->is_trial = is_trial; - read(filepath); } int ** TesttripleReader::getTesttriples() { @@ -76,7 +75,8 @@ void TesttripleReader::read(std::string filepath) { fclose(test_sample_file); std::cout << "Written test sample to " << Properties::get().PATH_TEST_SAMPLE << std::endl; - TesttripleReader* sample_reader = new TesttripleReader(Properties::get().PATH_TEST_SAMPLE.c_str(), index, graph, 0); + TesttripleReader* sample_reader = new TesttripleReader(index, graph, 0); + sample_reader->read(Properties::get().PATH_TEST_SAMPLE); csr = sample_reader->getCSR(); delete sample_reader; } @@ -103,3 +103,50 @@ void TesttripleReader::read(std::string filepath) { exit(-1); } } + +void TesttripleReader::read(std::vector> & triples) { + std::vector> testtriplesVector; + size_t numTriplesRead = 0; + for (auto & [head, rel, tail] : triples) + { + try { + int* headId = index->getIdOfNodestring(head); + int* relId = index->getIdOfRelationstring(rel); + int* tailId = index->getIdOfNodestring(tail); + std::vector testtriple; + testtriple.push_back(headId); + testtriple.push_back(relId); + testtriple.push_back(tailId); + testtriplesVector.push_back(testtriple); + + relHeadToTails[*relId][*headId].insert(*tailId); + relTailToHeads[*relId][*tailId].insert(*headId); + + numTriplesRead++; + } + catch (std::runtime_error& e) {} + } + if (Properties::get().VERBOSE == 1) { + std::cout << "read " << numTriplesRead << " triples" << std::endl; + std::cout << "csr start" << std::endl; + std::cout << "relsize=" << index->getRelSize() << " nodesize=" << index->getNodeSize() << std::endl; + } + csr = new CSR(index->getRelSize(), index->getNodeSize(), relHeadToTails, relTailToHeads); + if (Properties::get().VERBOSE == 1) { + std::cout << "csr end" << std::endl; + } + + testtripleSize = new int; + *testtripleSize = testtriplesVector.size(); + // Convert to pointers of pointers + int * testtriplesstore; + testtriples = new int*[*testtripleSize]; + testtriplesstore = new int[(*testtripleSize) * 3]; + + for (int i = 0; i < (*testtripleSize); i++) { + testtriplesstore[i * 3] = *(testtriplesVector[i][0]); + testtriplesstore[i * 3 + 1] = *(testtriplesVector[i][1]); + testtriplesstore[i * 3 + 2] = *(testtriplesVector[i][2]); + testtriples[i] = &testtriplesstore[i * 3]; + } +} From 115d0481c48302aaa958e364f357cfc7510131e6 Mon Sep 17 00:00:00 2001 From: Anton Belyy Date: Fri, 22 Apr 2022 16:22:40 -0400 Subject: [PATCH 2/2] Nicer error messages and codestyle fixes --- CMakeLists.txt | 2 +- Main.cpp | 2 +- include/InMemoryExplanation.h | 38 +-- include/RuleApplication.h | 6 +- include/SQLiteExplanation.h | 48 +-- include/TesttripleReader.h | 4 +- python_bindings/safran.py | 20 +- python_bindings/safran_wrapper.cpp | 131 +++----- python_bindings/safran_wrapper.h | 21 +- python_bindings/safran_wrapper.i | 12 +- python_bindings/setup.py | 50 +-- src/Explanation.cpp | 332 ------------------- src/InMemoryExplanation.cpp | 104 +++--- src/RuleApplication.cpp | 100 +++--- src/RuleGraph.cpp | 22 +- src/SQLiteExplanation.cpp | 508 ++++++++++++++--------------- src/TesttripleReader.cpp | 82 +++-- 17 files changed, 550 insertions(+), 932 deletions(-) delete mode 100644 src/Explanation.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 88efba1..99d903b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,7 +9,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - set(CMAKE_CXX_FLAGS_RELEASE "-O3") + set(CMAKE_CXX_FLAGS_RELEASE "-O3") endif() FetchContent_Declare(sqlite3 URL "http://sqlite.org/2021/sqlite-amalgamation-3340100.zip") diff --git a/Main.cpp b/Main.cpp index 0210c02..d399a9e 100644 --- a/Main.cpp +++ b/Main.cpp @@ -55,7 +55,7 @@ int main(int argc, char** argv) std::cout << "Reading testset..." << std::endl; TesttripleReader* ttr = new TesttripleReader(index, graph, Properties::get().TRIAL); - ttr->read(Properties::get().PATH_TEST); + ttr->read(Properties::get().PATH_TEST); finish = std::chrono::high_resolution_clock::now(); milliseconds = std::chrono::duration_cast(finish - start); std::cout << "Testset read in " << milliseconds.count() << " ms\n"; diff --git a/include/InMemoryExplanation.h b/include/InMemoryExplanation.h index 525a199..cb8289e 100644 --- a/include/InMemoryExplanation.h +++ b/include/InMemoryExplanation.h @@ -5,30 +5,30 @@ class InMemoryExplanation: public Explanation { public: - InMemoryExplanation(); - ~InMemoryExplanation(); + InMemoryExplanation(); + ~InMemoryExplanation(); - void begin(); - void commit(); - void begin_tr(); - void commit_tr(); - void insertEntities(Index* index); - void insertRelations(Index* index); - void insertRules(RuleReader* rr, int relsize, ClusteringReader* cr); + void begin(); + void commit(); + void begin_tr(); + void commit_tr(); + void insertEntities(Index* index); + void insertRelations(Index* index); + void insertRules(RuleReader* rr, int relsize, ClusteringReader* cr); - void insertTask(int task_id, bool is_head, int relation_id, int entity_id); - void insertPrediction(int task_id, int entity_id, bool hit, double confidence); - void insertRule_Entity(int rule_id, int task_id, int entity_id); + void insertTask(int task_id, bool is_head, int relation_id, int entity_id); + void insertPrediction(int task_id, int entity_id, bool hit, double confidence); + void insertRule_Entity(int rule_id, int task_id, int entity_id); - int getNextTaskID(); + int getNextTaskID(); - std::unordered_map>> tripleBestRules; + std::unordered_map>> tripleBestRules; private: - int _task_id = 0; - std::unordered_map ruleConfidences; - std::unordered_map> tasks; - std::unordered_map> taskEntityBestRules; + int _task_id = 0; + std::unordered_map ruleConfidences; + std::unordered_map> tasks; + std::unordered_map> taskEntityBestRules; }; -#endif //INMEMORY_EXPL_H \ No newline at end of file +#endif //INMEMORY_EXPL_H diff --git a/include/RuleApplication.h b/include/RuleApplication.h index 5c17abe..4c63adb 100644 --- a/include/RuleApplication.h +++ b/include/RuleApplication.h @@ -28,12 +28,12 @@ class RuleApplication { public: RuleApplication(Index* index, TraintripleReader* graph, TesttripleReader* ttr, ValidationtripleReader* vtr, RuleReader* rr, Explanation* exp); - RuleApplication(Index* index, TraintripleReader* graph, ValidationtripleReader* vtr, RuleReader* rr, Explanation* exp); + RuleApplication(Index* index, TraintripleReader* graph, ValidationtripleReader* vtr, RuleReader* rr, Explanation* exp); void apply_nr_noisy(std::unordered_map>>, std::pair>>>> rel2clusters); void apply_only_noisy(); void apply_only_max(); - void updateTTR(TesttripleReader* ttr); - std::vector> apply_only_max_in_memory(size_t K); + void updateTTR(TesttripleReader* ttr); + std::vector> apply_only_max_in_memory(size_t K); private: Index* index; diff --git a/include/SQLiteExplanation.h b/include/SQLiteExplanation.h index a88aff0..08b89d1 100644 --- a/include/SQLiteExplanation.h +++ b/include/SQLiteExplanation.h @@ -6,34 +6,34 @@ class SQLiteExplanation: public Explanation { public: - SQLiteExplanation(std::string dbName, bool init = false); - ~SQLiteExplanation(); + SQLiteExplanation(std::string dbName, bool init = false); + ~SQLiteExplanation(); - void begin(); - void commit(); - void begin_tr(); - void commit_tr(); - void insertEntities(Index* index); - void insertRelations(Index* index); - void insertRules(RuleReader* rr, int relsize, ClusteringReader* cr); + void begin(); + void commit(); + void begin_tr(); + void commit_tr(); + void insertEntities(Index* index); + void insertRelations(Index* index); + void insertRules(RuleReader* rr, int relsize, ClusteringReader* cr); - void insertTask(int task_id, bool is_head, int relation_id, int entity_id); - void insertPrediction(int task_id, int entity_id, bool hit, double confidence); - void insertRule_Entity(int rule_id, int task_id, int entity_id); + void insertTask(int task_id, bool is_head, int relation_id, int entity_id); + void insertPrediction(int task_id, int entity_id, bool hit, double confidence); + void insertRule_Entity(int rule_id, int task_id, int entity_id); - // OLD - void insertCluster(int prediction_id, int entity_id, int cluster_id, double confidence); - void insertRule_Cluster(int prediction_id, int entity_id, int cluster_id, int rule_id); + // OLD + void insertCluster(int prediction_id, int entity_id, int cluster_id, double confidence); + void insertRule_Cluster(int prediction_id, int entity_id, int cluster_id, int rule_id); - int getNextTaskID(); + int getNextTaskID(); private: - sqlite3* db; - int _task_id = 0; - void initDb(); - void checkErrorCode(int code); - void checkErrorCode(int code, char* sql); - sqlite3_stmt* prepare(char* sql); - void finalize(sqlite3_stmt* stmt); + sqlite3* db; + int _task_id = 0; + void initDb(); + void checkErrorCode(int code); + void checkErrorCode(int code, char* sql); + sqlite3_stmt* prepare(char* sql); + void finalize(sqlite3_stmt* stmt); }; -#endif //SQLITE_EXPL_H \ No newline at end of file +#endif //SQLITE_EXPL_H diff --git a/include/TesttripleReader.h b/include/TesttripleReader.h index a0361df..8e0cc56 100644 --- a/include/TesttripleReader.h +++ b/include/TesttripleReader.h @@ -21,7 +21,7 @@ class TesttripleReader { public: - TesttripleReader(Index* index, TraintripleReader* graph, int is_trial); + TesttripleReader(Index* index, TraintripleReader* graph, int is_trial); int** getTesttriples(); int* getTesttriplesSize(); @@ -30,7 +30,7 @@ class TesttripleReader RelNodeToNodes& getRelTailToHeads(); void read(std::string filepath); - void read(std::vector> & triples); + void read(std::vector> & triples); protected: diff --git a/python_bindings/safran.py b/python_bindings/safran.py index 9848f7f..604fae3 100644 --- a/python_bindings/safran.py +++ b/python_bindings/safran.py @@ -4,13 +4,15 @@ from safran_wrapper import pysafran, query_output_t, query_triples_t class SAFRAN(pysafran): - def __init__(self, train_path: str, rule_path: str, n_jobs: int = 1): - pysafran.__init__(self, train_path, rule_path, n_jobs) + def __init__(self, train_path: str, rule_path: str, n_jobs: int = 1): + pysafran.__init__(self, train_path, rule_path, n_jobs) - def query(self, triples: List[List[str]], k: int = 100, action: str = 'applymax') -> Dict[Tuple[str, str], List[Tuple[str, float, int]]]: - flat_triples = list(itertools.chain.from_iterable(triples)) - pred_vals = query_output_t(pysafran.query(self, action, k, query_triples_t(flat_triples))) - out = collections.defaultdict(list) - for head, (pred, (tail, (val, rule_id))) in pred_vals: - out[head, pred].append((tail, val, rule_id)) - return out \ No newline at end of file + def query(self, triples: List[List[str]], k: int = 100, action: str = 'applymax') -> Dict[Tuple[str, str], List[Tuple[str, float, int]]]: + if action != 'applymax': + raise ValueError('Actions supported in the SAFRAN Python wrapper are: applymax') + flat_triples = list(itertools.chain.from_iterable(triples)) + pred_vals = query_output_t(pysafran.query(self, action, k, query_triples_t(flat_triples))) + out = collections.defaultdict(list) + for head, (pred, (tail, (val, rule_id))) in pred_vals: + out[head, pred].append((tail, val, rule_id)) + return out diff --git a/python_bindings/safran_wrapper.cpp b/python_bindings/safran_wrapper.cpp index 06b9eac..5da50b2 100644 --- a/python_bindings/safran_wrapper.cpp +++ b/python_bindings/safran_wrapper.cpp @@ -1,94 +1,69 @@ #include "safran_wrapper.h" pysafran::pysafran(std::string train_path, std::string rule_path, int num_threads) { - // TODO: initialize these globally, or make those a part of pysafran instance. - Properties::get().VERBOSE = 0; - Properties::get().PREDICT_UNKNOWN = 1; - if (num_threads != -1) { - omp_set_num_threads(num_threads); - } + // TODO: initialize these globally, or make those a part of pysafran instance. + Properties::get().VERBOSE = 0; + Properties::get().PREDICT_UNKNOWN = 1; + if (num_threads != -1) { + omp_set_num_threads(num_threads); + } - this->index = new Index(); - index->addNode(Properties::get().REFLEXIV_TOKEN); - index->addNode(Properties::get().UNK_TOKEN); + this->index = new Index(); + index->addNode(Properties::get().REFLEXIV_TOKEN); + index->addNode(Properties::get().UNK_TOKEN); - this->graph = new TraintripleReader(train_path, index); - Properties::get().REL_SIZE = index->getRelSize(); + this->graph = new TraintripleReader(train_path, index); + Properties::get().REL_SIZE = index->getRelSize(); - this->rr = new RuleReader(rule_path, index, graph); - this->vtr = new ValidationtripleReader("/dev/null", index, graph); + this->rr = new RuleReader(rule_path, index, graph); + this->vtr = new ValidationtripleReader("/dev/null", index, graph); - this->exp = new InMemoryExplanation(); - this->exp->insertRules(rr, index->getRelSize(), nullptr); + this->exp = new InMemoryExplanation(); + this->exp->insertRules(rr, index->getRelSize(), nullptr); - this->ra = new RuleApplication(index, graph, vtr, rr, exp); + this->ra = new RuleApplication(index, graph, vtr, rr, exp); } pysafran::~pysafran() { - delete this->ra; - delete this->vtr; - delete this->rr; - delete this->graph; - delete this->index; - delete this->exp; + delete this->ra; + delete this->vtr; + delete this->rr; + delete this->graph; + delete this->index; + delete this->exp; } -std::vector>>>> pysafran::query(const std::string & action, size_t k, - const std::vector & flat_triples) const { - // "Read" triples - if (Properties::get().VERBOSE == 1) { - std::cout << "read start" << std::endl; - } - TesttripleReader ttr(index, graph, 0); - std::vector> triples; - size_t i = 0; - while (i < flat_triples.size()) { - triples.emplace_back(flat_triples[i], flat_triples[i + 1], flat_triples[i + 2]); - i += 3; - } - ttr.read(triples); - if (Properties::get().VERBOSE == 1) { - std::cout << "read finish" << std::endl; - } +std::vector>>>> pysafran::query(const std::string & action, size_t k, const std::vector & flat_triples) const { + // "Read" triples + TesttripleReader ttr(index, graph, 0); + std::vector> triples; + size_t i = 0; + while (i < flat_triples.size()) { + triples.emplace_back(flat_triples[i], flat_triples[i + 1], flat_triples[i + 2]); + i += 3; + } + ttr.read(triples); + ra->updateTTR(&ttr); - if (Properties::get().VERBOSE == 1) { - std::cout << "apply start" << std::endl; - } - // Apply specific rule application - std::vector> outAction; - if (action == "applymax") { - if (Properties::get().VERBOSE == 1) { - std::cout << "ra start" << std::endl; - } - ra->updateTTR(&ttr); - if (Properties::get().VERBOSE == 1) { - std::cout << "ra end" << std::endl; - } - outAction = ra->apply_only_max_in_memory(k); - } - if (Properties::get().VERBOSE == 1) { - std::cout << "apply finish" << std::endl; - } + // Apply specific rule application + std::vector> outAction; + if (action == "applymax") { + outAction = ra->apply_only_max_in_memory(k); + } - // Transform node ids to node names and only retain top-k - std::vector>>>> out; - if (Properties::get().VERBOSE == 1) { - std::cout << "transform start" << std::endl; - } - for (auto & p : outAction) { - int head_id = std::get<0>(p), relation_id = std::get<1>(p), tail_id = std::get<2>(p); - float val = (float)std::get<3>(p); - std::string headStr = *this->index->getStringOfNodeId(head_id); - std::string predStr = *this->index->getStringOfRelId(relation_id); - std::string tailStr = *this->index->getStringOfNodeId(tail_id); - int rule_id = exp->tripleBestRules[head_id][relation_id][tail_id]; - std::pair p1 = {val, rule_id}; - std::pair> p2 = {tailStr, p1}; - std::pair>> p3 = {predStr, p2}; - out.emplace_back(headStr, p3); - } - if (Properties::get().VERBOSE == 1) { - std::cout << "transform finish" << std::endl; - } - return out; + // Transform node ids to node names and only retain top-k + std::vector>>>> out; + for (auto & p : outAction) { + int head_id = std::get<0>(p), relation_id = std::get<1>(p), tail_id = std::get<2>(p); + float val = (float)std::get<3>(p); + std::string headStr = *this->index->getStringOfNodeId(head_id); + std::string predStr = *this->index->getStringOfRelId(relation_id); + std::string tailStr = *this->index->getStringOfNodeId(tail_id); + int rule_id = exp->tripleBestRules[head_id][relation_id][tail_id]; + std::pair p1 = {val, rule_id}; + std::pair> p2 = {tailStr, p1}; + std::pair>> p3 = {predStr, p2}; + out.emplace_back(headStr, p3); + } + return out; } diff --git a/python_bindings/safran_wrapper.h b/python_bindings/safran_wrapper.h index d6cf74c..243e789 100644 --- a/python_bindings/safran_wrapper.h +++ b/python_bindings/safran_wrapper.h @@ -15,18 +15,17 @@ using string_vector_t = std::vector; class pysafran { public: - pysafran(std::string train_path, std::string rule_path, int num_threads); - ~pysafran(); + pysafran(std::string train_path, std::string rule_path, int num_threads); + ~pysafran(); - // [(head, pred, tail, confidence, rule_id)] - std::vector>>>> query(const std::string & action, size_t k, - const std::vector & flat_triples) const; + // [(head, pred, tail)] -> [(head, pred, tail, rule_confidence, rule_id)] + std::vector>>>> query(const std::string & action, size_t k, const std::vector & flat_triples) const; private: - Index *index; - TraintripleReader* graph; - ValidationtripleReader* vtr; - RuleReader* rr; - RuleApplication *ra; - InMemoryExplanation *exp; + Index *index; + TraintripleReader* graph; + ValidationtripleReader* vtr; + RuleReader* rr; + RuleApplication *ra; + InMemoryExplanation *exp; }; \ No newline at end of file diff --git a/python_bindings/safran_wrapper.i b/python_bindings/safran_wrapper.i index 69e5d10..247715e 100644 --- a/python_bindings/safran_wrapper.i +++ b/python_bindings/safran_wrapper.i @@ -11,10 +11,10 @@ %include "safran_wrapper.h" namespace std { - %template(query_triples_t) vector; - %template(query_p1) pair; - %template(query_p2) pair>; - %template(query_p3) pair>>; - %template(query_p4) pair>>>; - %template(query_output_t) vector>>>>; + %template(query_triples_t) vector; + %template(query_p1) pair; + %template(query_p2) pair>; + %template(query_p3) pair>>; + %template(query_p4) pair>>>; + %template(query_output_t) vector>>>>; } diff --git a/python_bindings/setup.py b/python_bindings/setup.py index 9a90adb..67a3428 100644 --- a/python_bindings/setup.py +++ b/python_bindings/setup.py @@ -14,38 +14,38 @@ source_files = [src for src in source_files if src not in {'../src/SQLiteExplanation.cpp'}] ext_modules = [ - Extension( - '_safran_wrapper', - source_files, - swig_opts=["-c++", "-extranative"], - language='c++', - extra_compile_args=["-std=c++17", "-I../include", "-I../boost_1_76_0", "-fopenmp"], - extra_link_args=["-lgomp"], - ), + Extension( + '_safran_wrapper', + source_files, + swig_opts=["-c++", "-extranative"], + language='c++', + extra_compile_args=["-std=c++17", "-I../include", "-I../boost_1_76_0", "-fopenmp"], + extra_link_args=["-fopenmp"], + ), ] class CustomBuild(build): - def run(self): - self.run_command('build_ext') - build.run(self) + def run(self): + self.run_command('build_ext') + build.run(self) class CustomInstall(install): - def run(self): - self.run_command('build_ext') - self.do_egg_install() + def run(self): + self.run_command('build_ext') + self.do_egg_install() custom_cmdclass = {'build': CustomBuild, 'install': CustomInstall} setup( - name='safran', - version=__version__, - description='SAFRAN: Scalable and fast non-redundant rule application', - author='S. Ott, C. Melicke, M. Samwald, A. Belyy', - url='https://github.com/AVBelyy/SAFRAN', - ext_modules=ext_modules, - cmdclass=custom_cmdclass, - install_requires=[], - setup_requires=[], - py_modules=['safran', 'safran_wrapper'], - zip_safe=False, + name='safran', + version=__version__, + description='SAFRAN: Scalable and fast non-redundant rule application', + author='S. Ott, C. Melicke, M. Samwald, A. Belyy', + url='https://github.com/AVBelyy/SAFRAN', + ext_modules=ext_modules, + cmdclass=custom_cmdclass, + install_requires=[], + setup_requires=[], + py_modules=['safran', 'safran_wrapper'], + zip_safe=False, ) diff --git a/src/Explanation.cpp b/src/Explanation.cpp deleted file mode 100644 index cecf4ae..0000000 --- a/src/Explanation.cpp +++ /dev/null @@ -1,332 +0,0 @@ -#include "Explanation.h" - -Explanation::Explanation(std::string dbName, bool init) { - checkErrorCode(sqlite3_open_v2(dbName.c_str(), &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr)); - if (init) initDb(); -} - -Explanation::~Explanation() { - sqlite3_close(db); -} - -void Explanation::initDb() { - - checkErrorCode(sqlite3_exec(db, "PRAGMA synchronous=OFF", NULL, NULL, NULL)); - checkErrorCode(sqlite3_exec(db, "PRAGMA count_changes=OFF", NULL, NULL, NULL)); - checkErrorCode(sqlite3_exec(db, "PRAGMA journal_mode=MEMORY", NULL, NULL, NULL)); - checkErrorCode(sqlite3_exec(db, "PRAGMA temp_store=MEMORY", NULL, NULL, NULL)); - - char* sql; - - sql = "CREATE TABLE Entity(" \ - "ID INT PRIMARY KEY NOT NULL," \ - "NAME TEXT NOT NULL" \ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); - - sql = "CREATE TABLE Relation(" \ - "ID INT PRIMARY KEY NOT NULL," \ - "NAME TEXT NOT NULL" \ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); - - sql = "CREATE TABLE Rule(" \ - "ID INT PRIMARY KEY NOT NULL," \ - "HEAD_CLUSTER_ID INT NOT NULL," \ - "TAIL_CLUSTER_ID INT NOT NULL," \ - "DEF TEXT NOT NULL," \ - "CONFIDENCE DOUBLE NOT NULL,"\ - "PREDICTED INT NOT NULL,"\ - "CORRECTLY_PREDICTED INT NOT NULL"\ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); - - sql = "CREATE TABLE Task(" \ - "ID INT PRIMARY KEY NOT NULL," \ - "IsHead BOOL," \ - "EntityID INT," \ - "RelationID INT," \ - "FOREIGN KEY(EntityID) REFERENCES Entity(ID)," \ - "FOREIGN KEY(RelationID) REFERENCES Relation(ID)" \ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); - - sql = "CREATE TABLE Prediction(" \ - "TaskID INT NOT NULL," \ - "EntityID INT NOT NULL," \ - "Hit BOOL NOT NULL," - "CONFIDENCE REAL NOT NULL,"\ - "FOREIGN KEY(TaskID) REFERENCES Task(ID)," \ - "FOREIGN KEY(EntityID) REFERENCES Entity(ID)," \ - "PRIMARY KEY(TaskID, EntityID)" \ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); - - sql = "CREATE TABLE Rule_Entity(" \ - "RuleID INT NOT NULL," \ - "TaskID INT NOT NULL," \ - "EntityID INT NOT NULL," \ - "FOREIGN KEY(RuleID) REFERENCES Rule(ID)," \ - "FOREIGN KEY(TaskID) REFERENCES Task(ID)," \ - "FOREIGN KEY(EntityID) REFERENCES Entity(ID)," \ - "PRIMARY KEY(RuleID, TaskID, EntityID)" \ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); - - sql = "CREATE INDEX asdf ON Rule_Entity (" \ - "TaskID," \ - "EntityID" \ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); -} - -void Explanation::insertEntities(Index* index) { - char* sql = "INSERT INTO Entity (ID,NAME) VALUES (?,?);"; - sqlite3_stmt* stmt = prepare(sql); - for (int id = 0; id < index->getNodeSize(); id++) { - - // bind the value - checkErrorCode(sqlite3_bind_int(stmt, 1, id)); - - std::string curie = *index->getStringOfNodeId(id); - checkErrorCode(sqlite3_bind_text(stmt, 2, curie.c_str(), strlen(curie.c_str()), 0)); - - checkErrorCode(sqlite3_step(stmt)); - checkErrorCode(sqlite3_reset(stmt)); - } - finalize(stmt); -} -void Explanation::insertRelations(Index* index) { - char* sql = "INSERT INTO Relation (ID,NAME) VALUES (?,?);"; - sqlite3_stmt* stmt = prepare(sql); - for (int id = 0; id < index->getRelSize(); id++) { - - // bind the value - checkErrorCode(sqlite3_bind_int(stmt, 1, id)); - - std::string relation = *index->getStringOfRelId(id); - checkErrorCode(sqlite3_bind_text(stmt, 2, relation.c_str(), strlen(relation.c_str()), 0)); - - checkErrorCode(sqlite3_step(stmt)); - checkErrorCode(sqlite3_reset(stmt)); - } - finalize(stmt); -} - -void Explanation::insertRules(RuleReader* rr, int relsize, ClusteringReader* cr) { - int* adj_begin = rr->getCSR()->getAdjBegin(); - Rule* rules_adj_list = rr->getCSR()->getAdjList(); - - char* sql = "INSERT INTO Rule (ID,HEAD_CLUSTER_ID,TAIL_CLUSTER_ID,DEF,CONFIDENCE,PREDICTED,CORRECTLY_PREDICTED) VALUES (?,?,?,?,?,?,?);"; - sqlite3_stmt* stmt = prepare(sql); - - if (cr == nullptr) { - for (int rel = 0; rel < relsize; rel++) { - int ind_ptr = adj_begin[3 + rel]; - int lenRules = adj_begin[3 + rel + 1] - ind_ptr; - for (int i = 0; i < lenRules; i++) { - - Rule& r = rules_adj_list[ind_ptr + i]; - // bind the value - checkErrorCode(sqlite3_bind_int(stmt, 1, r.getID())); - - checkErrorCode(sqlite3_bind_int(stmt, 2, 0)); - checkErrorCode(sqlite3_bind_int(stmt, 3, 0)); - - std::string rulestr = r.getRulestring(); - checkErrorCode(sqlite3_bind_text(stmt, 4, rulestr.c_str(), strlen(rulestr.c_str()), 0)); - - checkErrorCode(sqlite3_bind_double(stmt, 5, r.getAppliedConfidence())); - checkErrorCode(sqlite3_bind_int(stmt, 6, r.getPredicted())); - checkErrorCode(sqlite3_bind_int(stmt, 7, r.getCorrectlyPredicted())); - - checkErrorCode(sqlite3_step(stmt)); - checkErrorCode(sqlite3_reset(stmt)); - } - } - } - else { - std::unordered_map>>, std::pair>>>> rel2clusters = cr->getRelToClusters(); - for (int rel = 0; rel < relsize; rel++) { - int ind_ptr = adj_begin[3 + rel]; - int lenRules = adj_begin[3 + rel + 1] - ind_ptr; - - std::unordered_map ruleToHeadCluster; - std::unordered_map ruleToTailCluster; - - // Head clusters - std::vector> cluster = rel2clusters[rel].first.second; - for (int cluster_id = 0; cluster_id < cluster.size(); cluster_id++) { - for (auto rule : cluster[cluster_id]) { - Rule& r = rules_adj_list[ind_ptr + rule]; - ruleToHeadCluster[r.getID()] = cluster_id; - } - } - - cluster = rel2clusters[rel].second.second; - for (int cluster_id = 0; cluster_id < cluster.size(); cluster_id++) { - for (auto rule : cluster[cluster_id]) { - Rule& r = rules_adj_list[ind_ptr + rule]; - ruleToTailCluster[r.getID()] = cluster_id; - } - } - - for (int i = 0; i < lenRules; i++) { - - Rule& r = rules_adj_list[ind_ptr + i]; - // bind the value - checkErrorCode(sqlite3_bind_int(stmt, 1, r.getID())); - - checkErrorCode(sqlite3_bind_int(stmt, 2, ruleToHeadCluster[r.getID()])); - - checkErrorCode(sqlite3_bind_int(stmt, 3, ruleToTailCluster[r.getID()])); - - std::string rulestr = r.getRulestring(); - checkErrorCode(sqlite3_bind_text(stmt, 4, rulestr.c_str(), strlen(rulestr.c_str()), 0)); - - checkErrorCode(sqlite3_bind_double(stmt, 5, r.getAppliedConfidence())); - checkErrorCode(sqlite3_bind_int(stmt, 6, r.getPredicted())); - checkErrorCode(sqlite3_bind_int(stmt, 7, r.getCorrectlyPredicted())); - - checkErrorCode(sqlite3_step(stmt)); - checkErrorCode(sqlite3_reset(stmt)); - } - } - } - - finalize(stmt); -} - -void Explanation::insertTask(int task_id, bool is_head, int relation_id, int entity_id) { - //std::cout << "Task " << task_id << " " << is_head << " " << relation_id << " " << entity_id << "\n"; - char* sql = "INSERT INTO Task (ID,IsHead,RelationID,EntityId) VALUES (?,?,?,?);"; - sqlite3_stmt* stmt = prepare(sql); - - checkErrorCode(sqlite3_bind_int(stmt, 1, task_id), "insert task bind 1"); - checkErrorCode(sqlite3_bind_int(stmt, 2, (int)is_head), "insert task bind 2"); - checkErrorCode(sqlite3_bind_int(stmt, 3, relation_id), "insert task bind 3"); - checkErrorCode(sqlite3_bind_int(stmt, 4, entity_id), "insert task bind 4"); - - checkErrorCode(sqlite3_step(stmt), "insert task step"); - finalize(stmt); -} - -void Explanation::insertPrediction(int task_id, int entity_id, bool hit, double confidence) { - //std::cout << "Prediction " << task_id << " " << entity_id << " " << " " << confidence << "\n"; - char* sql = "INSERT INTO Prediction (TaskID,EntityID,Hit,Confidence) VALUES (?,?,?,?);"; - sqlite3_stmt* stmt = prepare(sql); - - checkErrorCode(sqlite3_bind_int(stmt, 1, task_id), "insert pred bind 1"); - checkErrorCode(sqlite3_bind_int(stmt, 2, entity_id), "insert pred bind 2"); - checkErrorCode(sqlite3_bind_double(stmt, 3, (int)hit), "insert pred bind 3"); - checkErrorCode(sqlite3_bind_double(stmt, 4, confidence), "insert pred bind 4"); - - checkErrorCode(sqlite3_step(stmt), "insert pred step"); - finalize(stmt); -} - -void Explanation::insertCluster(int task_id, int entity_id, int cluster_id, double confidence) { - //std::cout << "Cluster " << task_id << " " << entity_id << " " << cluster_id << " " << confidence << "\n"; - char* sql = "INSERT INTO Cluster (PredictionTaskID, PredictionEntityID, ID, Confidence) VALUES (?,?,?,?);"; - sqlite3_stmt* stmt = prepare(sql); - - checkErrorCode(sqlite3_bind_int(stmt, 1, task_id), "insert clus bind 1"); - checkErrorCode(sqlite3_bind_int(stmt, 2, entity_id), "insert clus bind 2"); - checkErrorCode(sqlite3_bind_int(stmt, 3, cluster_id), "insert clus bind 3"); - checkErrorCode(sqlite3_bind_double(stmt, 4, confidence), "insert clus bind 1"); - - checkErrorCode(sqlite3_step(stmt), sql); - finalize(stmt); -} - -void Explanation::insertRule_Cluster(int task_id, int entity_id, int cluster_id, int rule_id) { - //std::cout << "Rule_Cluster " << task_id << " " << entity_id << " " << cluster_id << " " << rule_id << "\n"; - char* sql = "INSERT INTO Rule_Cluster (ClusterPredictionTaskID,ClusterPredictionEntityID,ClusterID,RuleID) VALUES (?,?,?,?);"; - sqlite3_stmt* stmt = prepare(sql); - - checkErrorCode(sqlite3_bind_int(stmt, 1, task_id)); - checkErrorCode(sqlite3_bind_int(stmt, 2, entity_id)); - checkErrorCode(sqlite3_bind_int(stmt, 3, cluster_id)); - checkErrorCode(sqlite3_bind_int(stmt, 4, rule_id)); - - checkErrorCode(sqlite3_step(stmt)); - finalize(stmt); -} - -void Explanation::insertRule_Entity(int rule_id, int task_id, int entity_id) { - //std::cout << "Rule_Cluster " << task_id << " " << entity_id << " " << cluster_id << " " << rule_id << "\n"; - char* sql = "INSERT OR IGNORE INTO Rule_Entity (RuleID, TaskID, EntityID) VALUES (?,?,?);"; - sqlite3_stmt* stmt = prepare(sql); - - checkErrorCode(sqlite3_bind_int(stmt, 1, rule_id)); - checkErrorCode(sqlite3_bind_int(stmt, 2, task_id)); - checkErrorCode(sqlite3_bind_int(stmt, 3, entity_id)); - - checkErrorCode(sqlite3_step(stmt)); - finalize(stmt); -} - -void Explanation::begin_tr() { - checkErrorCode(sqlite3_exec(db, "BEGIN TRANSACTION", NULL, NULL, NULL), "BEGIN TRANSACTION"); -} - -void Explanation::begin() { - //sqlite3_mutex_enter(sqlite3_db_mutex(db)); -} - -sqlite3_stmt* Explanation::prepare(char* sql) { - sqlite3_stmt* stmt; - checkErrorCode(sqlite3_prepare(db, sql, -1, &stmt, NULL), sql); - return stmt; -} - -void Explanation::finalize(sqlite3_stmt* stmt) { - checkErrorCode(sqlite3_finalize(stmt)); -} - -void Explanation::commit() { - //sqlite3_mutex_leave(sqlite3_db_mutex(db)); -} - -void Explanation::commit_tr() { - checkErrorCode(sqlite3_exec(db, "COMMIT TRANSACTION", NULL, NULL, NULL), "COMMIT TRANSACTION"); -} - -void Explanation::checkErrorCode(int code) { - if (code != SQLITE_OK && code != SQLITE_DONE) { - const char* err; - if (this->db) { - err = sqlite3_errmsg(this->db); - } - else { - err = sqlite3_errstr(code); - } - std::cerr << "Error " << code << ": " << err << '\n'; - std::exit(EXIT_FAILURE); - } -} - -void Explanation::checkErrorCode(int code, char* sql) { - if (code != SQLITE_OK && code != SQLITE_DONE) { - const char* err; - if (this->db) { - err = sqlite3_errmsg(this->db); - } - else { - err = sqlite3_errstr(code); - } - std::cerr << sql << "\n"; - std::cerr << "Error " << code << ": " << err << '\n'; - std::exit(EXIT_FAILURE); - } -} - -int Explanation::getNextTaskID() { - int task_id_; -#pragma omp critical - { - task_id++; - task_id_ = task_id; - } - return task_id_; -} \ No newline at end of file diff --git a/src/InMemoryExplanation.cpp b/src/InMemoryExplanation.cpp index 8447f03..7fd9d01 100644 --- a/src/InMemoryExplanation.cpp +++ b/src/InMemoryExplanation.cpp @@ -1,7 +1,6 @@ #include "InMemoryExplanation.h" -InMemoryExplanation::InMemoryExplanation() { -} +InMemoryExplanation::InMemoryExplanation() = default; InMemoryExplanation::~InMemoryExplanation() = default; @@ -12,65 +11,60 @@ void InMemoryExplanation::insertRelations(Index* index) { } void InMemoryExplanation::insertRules(RuleReader* rr, int relsize, ClusteringReader* cr) { - int* adj_begin = rr->getCSR()->getAdjBegin(); - Rule* rules_adj_list = rr->getCSR()->getAdjList(); + int* adj_begin = rr->getCSR()->getAdjBegin(); + Rule* rules_adj_list = rr->getCSR()->getAdjList(); - if (cr == nullptr) { - for (int rel = 0; rel < relsize; rel++) { - int ind_ptr = adj_begin[3 + rel]; - int lenRules = adj_begin[3 + rel + 1] - ind_ptr; - for (int i = 0; i < lenRules; i++) { - Rule &r = rules_adj_list[ind_ptr + i]; - int rule_id = r.getID(); - auto conf = (float)r.getAppliedConfidence(); - ruleConfidences[rule_id] = conf; - } - } - } else { - // Not implemented yet. - return; - } + for (int rel = 0; rel < relsize; rel++) { + int ind_ptr = adj_begin[3 + rel]; + int lenRules = adj_begin[3 + rel + 1] - ind_ptr; + for (int i = 0; i < lenRules; i++) { + Rule &r = rules_adj_list[ind_ptr + i]; + int rule_id = r.getID(); + auto conf = (float)r.getAppliedConfidence(); + ruleConfidences[rule_id] = conf; + } + } } void InMemoryExplanation::insertTask(int task_id, bool is_head, int relation_id, int entity_id) { #pragma omp critical - { - tasks[task_id] = {relation_id, entity_id, is_head}; - } + { + tasks[task_id] = {relation_id, entity_id, is_head}; + } } void InMemoryExplanation::insertPrediction(int task_id, int entity_id, bool hit, double confidence) { } void InMemoryExplanation::insertRule_Entity(int rule_id, int task_id, int other_entity_id) { - bool need_update = true; - if (taskEntityBestRules.find(task_id) != taskEntityBestRules.end()) { - if (taskEntityBestRules[task_id].find(other_entity_id) != taskEntityBestRules[task_id].end()) { - int prev_rule_id = taskEntityBestRules[task_id][other_entity_id]; - float prev_conf = ruleConfidences[prev_rule_id], new_conf = ruleConfidences[rule_id]; - if (new_conf <= prev_conf) { - need_update = false; - } - } - } - if (need_update) { - auto & p = tasks[task_id]; - int relation_id = std::get<0>(p), entity_id = std::get<1>(p), is_head = std::get<2>(p); - int head_id, tail_id; - if (is_head) { - head_id = entity_id; - tail_id = other_entity_id; - } else { - head_id = other_entity_id; - tail_id = entity_id; - } + bool need_update = true; + if (taskEntityBestRules.find(task_id) != taskEntityBestRules.end()) { + if (taskEntityBestRules[task_id].find(other_entity_id) != taskEntityBestRules[task_id].end()) { + int prev_rule_id = taskEntityBestRules[task_id][other_entity_id]; + float prev_conf = ruleConfidences[prev_rule_id], new_conf = ruleConfidences[rule_id]; + if (new_conf <= prev_conf) { + need_update = false; + } + } + } + if (need_update) { + auto & p = tasks[task_id]; + int relation_id = std::get<0>(p), entity_id = std::get<1>(p), is_head = std::get<2>(p); + int head_id, tail_id; + if (is_head) { + head_id = entity_id; + tail_id = other_entity_id; + } else { + head_id = other_entity_id; + tail_id = entity_id; + } #pragma omp critical - // Update the current best rule, according to the "applymax" action - { - taskEntityBestRules[task_id][other_entity_id] = rule_id; - tripleBestRules[head_id][relation_id][tail_id] = rule_id; - } - } + // Update the current best rule, according to the "applymax" action + { + taskEntityBestRules[task_id][other_entity_id] = rule_id; + tripleBestRules[head_id][relation_id][tail_id] = rule_id; + } + } } void InMemoryExplanation::begin_tr() { @@ -86,11 +80,11 @@ void InMemoryExplanation::commit_tr() { } int InMemoryExplanation::getNextTaskID() { - int task_id_; + int task_id_; #pragma omp critical - { - _task_id++; - task_id_ = _task_id; - } - return task_id_; + { + _task_id++; + task_id_ = _task_id; + } + return task_id_; } \ No newline at end of file diff --git a/src/RuleApplication.cpp b/src/RuleApplication.cpp index 17a7752..6d2185b 100644 --- a/src/RuleApplication.cpp +++ b/src/RuleApplication.cpp @@ -1,14 +1,14 @@ #include "RuleApplication.h" RuleApplication::RuleApplication(Index* index, TraintripleReader* graph, ValidationtripleReader* vtr, RuleReader* rr, Explanation* exp) { - this->index = index; - this->graph = graph; - this->vtr = vtr; - this->rr = rr; - this->rulegraph = new RuleGraph(index->getNodeSize(), graph, vtr); - reflexiv_token = *index->getIdOfNodestring(Properties::get().REFLEXIV_TOKEN); - this->k = Properties::get().TOP_K_OUTPUT; - this->exp = exp; + this->index = index; + this->graph = graph; + this->vtr = vtr; + this->rr = rr; + this->rulegraph = new RuleGraph(index->getNodeSize(), graph, vtr); + reflexiv_token = *index->getIdOfNodestring(Properties::get().REFLEXIV_TOKEN); + this->k = Properties::get().TOP_K_OUTPUT; + this->exp = exp; } RuleApplication::RuleApplication(Index* index, TraintripleReader* graph, TesttripleReader* ttr, ValidationtripleReader* vtr, RuleReader* rr, Explanation* exp) { @@ -24,8 +24,8 @@ RuleApplication::RuleApplication(Index* index, TraintripleReader* graph, Testtri } void RuleApplication::updateTTR(TesttripleReader* testReader) { - this->ttr = testReader; - this->rulegraph->updateTTR(testReader); + this->ttr = testReader; + this->rulegraph->updateTTR(testReader); } void RuleApplication::apply_nr_noisy(std::unordered_map>>, std::pair>>>> rel2clusters) { @@ -165,52 +165,40 @@ void RuleApplication::apply_only_max() { } std::vector> RuleApplication::apply_only_max_in_memory(size_t K) { - if (Properties::get().VERBOSE == 1) { - std::cout << "aomim init start" << std::endl; - } - int* adj_begin = rr->getCSR()->getAdjBegin(); - std::vector> out; - - int iterations = index->getRelSize(); - if (Properties::get().VERBOSE == 1) { - std::cout << "aomim init end" << std::endl; - } - for (int rel = 0; rel < iterations; rel++) { - // TODO: precompute clusters outside of this call - int ind_ptr = adj_begin[3 + rel]; - int lenRules = adj_begin[3 + rel + 1] - ind_ptr; - if (Properties::get().VERBOSE == 1) { - std::cout << "pred " << rel << " nrules=" << lenRules; - } - std::vector> clusters; - std::vector cluster; - for (int i = 0; i < lenRules; i++) { - cluster.push_back(i); - } - clusters.push_back(cluster); - - std::unordered_map>>> headTailResults; - headTailResults = max(rel, clusters, false); - if (Properties::get().VERBOSE == 1) { - std::cout << " nresults=" << headTailResults.size() << std::endl; - } - - for (auto & [head, tailResults] : headTailResults) { - for (auto & [_, tails] : tailResults) { - size_t maxSize = std::min(tails.size(), K); - size_t i = 0; - for (auto [tail, val] : tails) { - if (i >= maxSize) { - break; - } - out.emplace_back(head, rel, tail, val); - i++; - } - } - } - } - - return out; + int* adj_begin = rr->getCSR()->getAdjBegin(); + std::vector> out; + + int iterations = index->getRelSize(); + for (int rel = 0; rel < iterations; rel++) { + // TODO: precompute clusters outside of this call + int ind_ptr = adj_begin[3 + rel]; + int lenRules = adj_begin[3 + rel + 1] - ind_ptr; + std::vector> clusters; + std::vector cluster; + for (int i = 0; i < lenRules; i++) { + cluster.push_back(i); + } + clusters.push_back(cluster); + + std::unordered_map>>> headTailResults; + headTailResults = max(rel, clusters, false); + + for (auto & [head, tailResults] : headTailResults) { + for (auto & [_, tails] : tailResults) { + size_t maxSize = std::min(tails.size(), K); + size_t i = 0; + for (auto [tail, val] : tails) { + if (i >= maxSize) { + break; + } + out.emplace_back(head, rel, tail, val); + i++; + } + } + } + } + + return out; } std::unordered_map>>> RuleApplication::noisy(int rel, std::vector> clusters, bool predictHeadNotTail) { diff --git a/src/RuleGraph.cpp b/src/RuleGraph.cpp index bb48b21..416d6f2 100644 --- a/src/RuleGraph.cpp +++ b/src/RuleGraph.cpp @@ -22,20 +22,20 @@ RuleGraph::RuleGraph(int nodesize, TraintripleReader* graph, TesttripleReader* t } RuleGraph::RuleGraph(int nodesize, TraintripleReader* graph, ValidationtripleReader* vtr) { - this->size = nodesize; - this->graph = graph; - adj_lists = graph->getCSR()->getAdjList(); - adj_list_starts = graph->getCSR()->getAdjBegin(); - this->train_relHeadToTails = graph->getRelHeadToTails(); - this->train_relTailToHeads = graph->getRelTailToHeads(); - this->valid_relHeadToTails = vtr->getRelHeadToTails(); - this->valid_relTailToHeads = vtr->getRelTailToHeads(); - this->relCounter = graph->getRelCounter(); + this->size = nodesize; + this->graph = graph; + adj_lists = graph->getCSR()->getAdjList(); + adj_list_starts = graph->getCSR()->getAdjBegin(); + this->train_relHeadToTails = graph->getRelHeadToTails(); + this->train_relTailToHeads = graph->getRelTailToHeads(); + this->valid_relHeadToTails = vtr->getRelHeadToTails(); + this->valid_relTailToHeads = vtr->getRelTailToHeads(); + this->relCounter = graph->getRelCounter(); } void RuleGraph::updateTTR(TesttripleReader* ttr) { - this->test_relHeadToTails = ttr->getRelHeadToTails(); - this->test_relTailToHeads = ttr->getRelTailToHeads(); + this->test_relHeadToTails = ttr->getRelHeadToTails(); + this->test_relTailToHeads = ttr->getRelTailToHeads(); } void RuleGraph::searchDFSSingleStart_filt(bool headNotTail, int filt_v, int v, Rule& r, bool bwd, std::vector& solution, bool filtValidNotTest, bool filtExceptions) { diff --git a/src/SQLiteExplanation.cpp b/src/SQLiteExplanation.cpp index a609043..2b090bb 100644 --- a/src/SQLiteExplanation.cpp +++ b/src/SQLiteExplanation.cpp @@ -1,332 +1,332 @@ #include "SQLiteExplanation.h" SQLiteExplanation::SQLiteExplanation(std::string dbName, bool init) { - checkErrorCode(sqlite3_open_v2(dbName.c_str(), &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr)); - if (init) initDb(); + checkErrorCode(sqlite3_open_v2(dbName.c_str(), &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr)); + if (init) initDb(); } SQLiteExplanation::~SQLiteExplanation() { - sqlite3_close(db); + sqlite3_close(db); } void SQLiteExplanation::initDb() { - checkErrorCode(sqlite3_exec(db, "PRAGMA synchronous=OFF", NULL, NULL, NULL)); - checkErrorCode(sqlite3_exec(db, "PRAGMA count_changes=OFF", NULL, NULL, NULL)); - checkErrorCode(sqlite3_exec(db, "PRAGMA journal_mode=MEMORY", NULL, NULL, NULL)); - checkErrorCode(sqlite3_exec(db, "PRAGMA temp_store=MEMORY", NULL, NULL, NULL)); - - char* sql; - - sql = "CREATE TABLE Entity(" \ - "ID INT PRIMARY KEY NOT NULL," \ - "NAME TEXT NOT NULL" \ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); - - sql = "CREATE TABLE Relation(" \ - "ID INT PRIMARY KEY NOT NULL," \ - "NAME TEXT NOT NULL" \ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); - - sql = "CREATE TABLE Rule(" \ - "ID INT PRIMARY KEY NOT NULL," \ - "HEAD_CLUSTER_ID INT NOT NULL," \ - "TAIL_CLUSTER_ID INT NOT NULL," \ - "DEF TEXT NOT NULL," \ - "CONFIDENCE DOUBLE NOT NULL,"\ - "PREDICTED INT NOT NULL,"\ - "CORRECTLY_PREDICTED INT NOT NULL"\ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); - - sql = "CREATE TABLE Task(" \ - "ID INT PRIMARY KEY NOT NULL," \ - "IsHead BOOL," \ - "EntityID INT," \ - "RelationID INT," \ - "FOREIGN KEY(EntityID) REFERENCES Entity(ID)," \ - "FOREIGN KEY(RelationID) REFERENCES Relation(ID)" \ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); - - sql = "CREATE TABLE Prediction(" \ - "TaskID INT NOT NULL," \ - "EntityID INT NOT NULL," \ - "Hit BOOL NOT NULL," - "CONFIDENCE REAL NOT NULL,"\ - "FOREIGN KEY(TaskID) REFERENCES Task(ID)," \ - "FOREIGN KEY(EntityID) REFERENCES Entity(ID)," \ - "PRIMARY KEY(TaskID, EntityID)" \ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); - - sql = "CREATE TABLE Rule_Entity(" \ - "RuleID INT NOT NULL," \ - "TaskID INT NOT NULL," \ - "EntityID INT NOT NULL," \ - "FOREIGN KEY(RuleID) REFERENCES Rule(ID)," \ - "FOREIGN KEY(TaskID) REFERENCES Task(ID)," \ - "FOREIGN KEY(EntityID) REFERENCES Entity(ID)," \ - "PRIMARY KEY(RuleID, TaskID, EntityID)" \ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); - - sql = "CREATE INDEX asdf ON Rule_Entity (" \ - "TaskID," \ - "EntityID" \ - ");"; - checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); + checkErrorCode(sqlite3_exec(db, "PRAGMA synchronous=OFF", NULL, NULL, NULL)); + checkErrorCode(sqlite3_exec(db, "PRAGMA count_changes=OFF", NULL, NULL, NULL)); + checkErrorCode(sqlite3_exec(db, "PRAGMA journal_mode=MEMORY", NULL, NULL, NULL)); + checkErrorCode(sqlite3_exec(db, "PRAGMA temp_store=MEMORY", NULL, NULL, NULL)); + + char* sql; + + sql = "CREATE TABLE Entity(" \ + "ID INT PRIMARY KEY NOT NULL," \ + "NAME TEXT NOT NULL" \ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); + + sql = "CREATE TABLE Relation(" \ + "ID INT PRIMARY KEY NOT NULL," \ + "NAME TEXT NOT NULL" \ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); + + sql = "CREATE TABLE Rule(" \ + "ID INT PRIMARY KEY NOT NULL," \ + "HEAD_CLUSTER_ID INT NOT NULL," \ + "TAIL_CLUSTER_ID INT NOT NULL," \ + "DEF TEXT NOT NULL," \ + "CONFIDENCE DOUBLE NOT NULL,"\ + "PREDICTED INT NOT NULL,"\ + "CORRECTLY_PREDICTED INT NOT NULL"\ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); + + sql = "CREATE TABLE Task(" \ + "ID INT PRIMARY KEY NOT NULL," \ + "IsHead BOOL," \ + "EntityID INT," \ + "RelationID INT," \ + "FOREIGN KEY(EntityID) REFERENCES Entity(ID)," \ + "FOREIGN KEY(RelationID) REFERENCES Relation(ID)" \ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); + + sql = "CREATE TABLE Prediction(" \ + "TaskID INT NOT NULL," \ + "EntityID INT NOT NULL," \ + "Hit BOOL NOT NULL," + "CONFIDENCE REAL NOT NULL,"\ + "FOREIGN KEY(TaskID) REFERENCES Task(ID)," \ + "FOREIGN KEY(EntityID) REFERENCES Entity(ID)," \ + "PRIMARY KEY(TaskID, EntityID)" \ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); + + sql = "CREATE TABLE Rule_Entity(" \ + "RuleID INT NOT NULL," \ + "TaskID INT NOT NULL," \ + "EntityID INT NOT NULL," \ + "FOREIGN KEY(RuleID) REFERENCES Rule(ID)," \ + "FOREIGN KEY(TaskID) REFERENCES Task(ID)," \ + "FOREIGN KEY(EntityID) REFERENCES Entity(ID)," \ + "PRIMARY KEY(RuleID, TaskID, EntityID)" \ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); + + sql = "CREATE INDEX asdf ON Rule_Entity (" \ + "TaskID," \ + "EntityID" \ + ");"; + checkErrorCode(sqlite3_exec(db, sql, NULL, NULL, NULL), sql); } void SQLiteExplanation::insertEntities(Index* index) { - char* sql = "INSERT INTO Entity (ID,NAME) VALUES (?,?);"; - sqlite3_stmt* stmt = prepare(sql); - for (int id = 0; id < index->getNodeSize(); id++) { + char* sql = "INSERT INTO Entity (ID,NAME) VALUES (?,?);"; + sqlite3_stmt* stmt = prepare(sql); + for (int id = 0; id < index->getNodeSize(); id++) { - // bind the value - checkErrorCode(sqlite3_bind_int(stmt, 1, id)); + // bind the value + checkErrorCode(sqlite3_bind_int(stmt, 1, id)); - std::string curie = *index->getStringOfNodeId(id); - checkErrorCode(sqlite3_bind_text(stmt, 2, curie.c_str(), strlen(curie.c_str()), 0)); + std::string curie = *index->getStringOfNodeId(id); + checkErrorCode(sqlite3_bind_text(stmt, 2, curie.c_str(), strlen(curie.c_str()), 0)); - checkErrorCode(sqlite3_step(stmt)); - checkErrorCode(sqlite3_reset(stmt)); - } - finalize(stmt); + checkErrorCode(sqlite3_step(stmt)); + checkErrorCode(sqlite3_reset(stmt)); + } + finalize(stmt); } void SQLiteExplanation::insertRelations(Index* index) { - char* sql = "INSERT INTO Relation (ID,NAME) VALUES (?,?);"; - sqlite3_stmt* stmt = prepare(sql); - for (int id = 0; id < index->getRelSize(); id++) { + char* sql = "INSERT INTO Relation (ID,NAME) VALUES (?,?);"; + sqlite3_stmt* stmt = prepare(sql); + for (int id = 0; id < index->getRelSize(); id++) { - // bind the value - checkErrorCode(sqlite3_bind_int(stmt, 1, id)); + // bind the value + checkErrorCode(sqlite3_bind_int(stmt, 1, id)); - std::string relation = *index->getStringOfRelId(id); - checkErrorCode(sqlite3_bind_text(stmt, 2, relation.c_str(), strlen(relation.c_str()), 0)); + std::string relation = *index->getStringOfRelId(id); + checkErrorCode(sqlite3_bind_text(stmt, 2, relation.c_str(), strlen(relation.c_str()), 0)); - checkErrorCode(sqlite3_step(stmt)); - checkErrorCode(sqlite3_reset(stmt)); - } - finalize(stmt); + checkErrorCode(sqlite3_step(stmt)); + checkErrorCode(sqlite3_reset(stmt)); + } + finalize(stmt); } void SQLiteExplanation::insertRules(RuleReader* rr, int relsize, ClusteringReader* cr) { - int* adj_begin = rr->getCSR()->getAdjBegin(); - Rule* rules_adj_list = rr->getCSR()->getAdjList(); - - char* sql = "INSERT INTO Rule (ID,HEAD_CLUSTER_ID,TAIL_CLUSTER_ID,DEF,CONFIDENCE,PREDICTED,CORRECTLY_PREDICTED) VALUES (?,?,?,?,?,?,?);"; - sqlite3_stmt* stmt = prepare(sql); - - if (cr == nullptr) { - for (int rel = 0; rel < relsize; rel++) { - int ind_ptr = adj_begin[3 + rel]; - int lenRules = adj_begin[3 + rel + 1] - ind_ptr; - for (int i = 0; i < lenRules; i++) { - - Rule& r = rules_adj_list[ind_ptr + i]; - // bind the value - checkErrorCode(sqlite3_bind_int(stmt, 1, r.getID())); - - checkErrorCode(sqlite3_bind_int(stmt, 2, 0)); - checkErrorCode(sqlite3_bind_int(stmt, 3, 0)); - - std::string rulestr = r.getRulestring(); - checkErrorCode(sqlite3_bind_text(stmt, 4, rulestr.c_str(), strlen(rulestr.c_str()), 0)); - - checkErrorCode(sqlite3_bind_double(stmt, 5, r.getAppliedConfidence())); - checkErrorCode(sqlite3_bind_int(stmt, 6, r.getPredicted())); - checkErrorCode(sqlite3_bind_int(stmt, 7, r.getCorrectlyPredicted())); - - checkErrorCode(sqlite3_step(stmt)); - checkErrorCode(sqlite3_reset(stmt)); - } - } - } - else { - std::unordered_map>>, std::pair>>>> rel2clusters = cr->getRelToClusters(); - for (int rel = 0; rel < relsize; rel++) { - int ind_ptr = adj_begin[3 + rel]; - int lenRules = adj_begin[3 + rel + 1] - ind_ptr; - - std::unordered_map ruleToHeadCluster; - std::unordered_map ruleToTailCluster; - - // Head clusters - std::vector> cluster = rel2clusters[rel].first.second; - for (int cluster_id = 0; cluster_id < cluster.size(); cluster_id++) { - for (auto rule : cluster[cluster_id]) { - Rule& r = rules_adj_list[ind_ptr + rule]; - ruleToHeadCluster[r.getID()] = cluster_id; - } - } - - cluster = rel2clusters[rel].second.second; - for (int cluster_id = 0; cluster_id < cluster.size(); cluster_id++) { - for (auto rule : cluster[cluster_id]) { - Rule& r = rules_adj_list[ind_ptr + rule]; - ruleToTailCluster[r.getID()] = cluster_id; - } - } - - for (int i = 0; i < lenRules; i++) { - - Rule& r = rules_adj_list[ind_ptr + i]; - // bind the value - checkErrorCode(sqlite3_bind_int(stmt, 1, r.getID())); - - checkErrorCode(sqlite3_bind_int(stmt, 2, ruleToHeadCluster[r.getID()])); - - checkErrorCode(sqlite3_bind_int(stmt, 3, ruleToTailCluster[r.getID()])); - - std::string rulestr = r.getRulestring(); - checkErrorCode(sqlite3_bind_text(stmt, 4, rulestr.c_str(), strlen(rulestr.c_str()), 0)); - - checkErrorCode(sqlite3_bind_double(stmt, 5, r.getAppliedConfidence())); - checkErrorCode(sqlite3_bind_int(stmt, 6, r.getPredicted())); - checkErrorCode(sqlite3_bind_int(stmt, 7, r.getCorrectlyPredicted())); - - checkErrorCode(sqlite3_step(stmt)); - checkErrorCode(sqlite3_reset(stmt)); - } - } - } - - finalize(stmt); + int* adj_begin = rr->getCSR()->getAdjBegin(); + Rule* rules_adj_list = rr->getCSR()->getAdjList(); + + char* sql = "INSERT INTO Rule (ID,HEAD_CLUSTER_ID,TAIL_CLUSTER_ID,DEF,CONFIDENCE,PREDICTED,CORRECTLY_PREDICTED) VALUES (?,?,?,?,?,?,?);"; + sqlite3_stmt* stmt = prepare(sql); + + if (cr == nullptr) { + for (int rel = 0; rel < relsize; rel++) { + int ind_ptr = adj_begin[3 + rel]; + int lenRules = adj_begin[3 + rel + 1] - ind_ptr; + for (int i = 0; i < lenRules; i++) { + + Rule& r = rules_adj_list[ind_ptr + i]; + // bind the value + checkErrorCode(sqlite3_bind_int(stmt, 1, r.getID())); + + checkErrorCode(sqlite3_bind_int(stmt, 2, 0)); + checkErrorCode(sqlite3_bind_int(stmt, 3, 0)); + + std::string rulestr = r.getRulestring(); + checkErrorCode(sqlite3_bind_text(stmt, 4, rulestr.c_str(), strlen(rulestr.c_str()), 0)); + + checkErrorCode(sqlite3_bind_double(stmt, 5, r.getAppliedConfidence())); + checkErrorCode(sqlite3_bind_int(stmt, 6, r.getPredicted())); + checkErrorCode(sqlite3_bind_int(stmt, 7, r.getCorrectlyPredicted())); + + checkErrorCode(sqlite3_step(stmt)); + checkErrorCode(sqlite3_reset(stmt)); + } + } + } + else { + std::unordered_map>>, std::pair>>>> rel2clusters = cr->getRelToClusters(); + for (int rel = 0; rel < relsize; rel++) { + int ind_ptr = adj_begin[3 + rel]; + int lenRules = adj_begin[3 + rel + 1] - ind_ptr; + + std::unordered_map ruleToHeadCluster; + std::unordered_map ruleToTailCluster; + + // Head clusters + std::vector> cluster = rel2clusters[rel].first.second; + for (int cluster_id = 0; cluster_id < cluster.size(); cluster_id++) { + for (auto rule : cluster[cluster_id]) { + Rule& r = rules_adj_list[ind_ptr + rule]; + ruleToHeadCluster[r.getID()] = cluster_id; + } + } + + cluster = rel2clusters[rel].second.second; + for (int cluster_id = 0; cluster_id < cluster.size(); cluster_id++) { + for (auto rule : cluster[cluster_id]) { + Rule& r = rules_adj_list[ind_ptr + rule]; + ruleToTailCluster[r.getID()] = cluster_id; + } + } + + for (int i = 0; i < lenRules; i++) { + + Rule& r = rules_adj_list[ind_ptr + i]; + // bind the value + checkErrorCode(sqlite3_bind_int(stmt, 1, r.getID())); + + checkErrorCode(sqlite3_bind_int(stmt, 2, ruleToHeadCluster[r.getID()])); + + checkErrorCode(sqlite3_bind_int(stmt, 3, ruleToTailCluster[r.getID()])); + + std::string rulestr = r.getRulestring(); + checkErrorCode(sqlite3_bind_text(stmt, 4, rulestr.c_str(), strlen(rulestr.c_str()), 0)); + + checkErrorCode(sqlite3_bind_double(stmt, 5, r.getAppliedConfidence())); + checkErrorCode(sqlite3_bind_int(stmt, 6, r.getPredicted())); + checkErrorCode(sqlite3_bind_int(stmt, 7, r.getCorrectlyPredicted())); + + checkErrorCode(sqlite3_step(stmt)); + checkErrorCode(sqlite3_reset(stmt)); + } + } + } + + finalize(stmt); } void SQLiteExplanation::insertTask(int task_id, bool is_head, int relation_id, int entity_id) { - //std::cout << "Task " << task_id << " " << is_head << " " << relation_id << " " << entity_id << "\n"; - char* sql = "INSERT INTO Task (ID,IsHead,RelationID,EntityId) VALUES (?,?,?,?);"; - sqlite3_stmt* stmt = prepare(sql); + //std::cout << "Task " << task_id << " " << is_head << " " << relation_id << " " << entity_id << "\n"; + char* sql = "INSERT INTO Task (ID,IsHead,RelationID,EntityId) VALUES (?,?,?,?);"; + sqlite3_stmt* stmt = prepare(sql); - checkErrorCode(sqlite3_bind_int(stmt, 1, task_id), "insert task bind 1"); - checkErrorCode(sqlite3_bind_int(stmt, 2, (int)is_head), "insert task bind 2"); - checkErrorCode(sqlite3_bind_int(stmt, 3, relation_id), "insert task bind 3"); - checkErrorCode(sqlite3_bind_int(stmt, 4, entity_id), "insert task bind 4"); + checkErrorCode(sqlite3_bind_int(stmt, 1, task_id), "insert task bind 1"); + checkErrorCode(sqlite3_bind_int(stmt, 2, (int)is_head), "insert task bind 2"); + checkErrorCode(sqlite3_bind_int(stmt, 3, relation_id), "insert task bind 3"); + checkErrorCode(sqlite3_bind_int(stmt, 4, entity_id), "insert task bind 4"); - checkErrorCode(sqlite3_step(stmt), "insert task step"); - finalize(stmt); + checkErrorCode(sqlite3_step(stmt), "insert task step"); + finalize(stmt); } void SQLiteExplanation::insertPrediction(int task_id, int entity_id, bool hit, double confidence) { - //std::cout << "Prediction " << task_id << " " << entity_id << " " << " " << confidence << "\n"; - char* sql = "INSERT INTO Prediction (TaskID,EntityID,Hit,Confidence) VALUES (?,?,?,?);"; - sqlite3_stmt* stmt = prepare(sql); + //std::cout << "Prediction " << task_id << " " << entity_id << " " << " " << confidence << "\n"; + char* sql = "INSERT INTO Prediction (TaskID,EntityID,Hit,Confidence) VALUES (?,?,?,?);"; + sqlite3_stmt* stmt = prepare(sql); - checkErrorCode(sqlite3_bind_int(stmt, 1, task_id), "insert pred bind 1"); - checkErrorCode(sqlite3_bind_int(stmt, 2, entity_id), "insert pred bind 2"); - checkErrorCode(sqlite3_bind_double(stmt, 3, (int)hit), "insert pred bind 3"); - checkErrorCode(sqlite3_bind_double(stmt, 4, confidence), "insert pred bind 4"); + checkErrorCode(sqlite3_bind_int(stmt, 1, task_id), "insert pred bind 1"); + checkErrorCode(sqlite3_bind_int(stmt, 2, entity_id), "insert pred bind 2"); + checkErrorCode(sqlite3_bind_double(stmt, 3, (int)hit), "insert pred bind 3"); + checkErrorCode(sqlite3_bind_double(stmt, 4, confidence), "insert pred bind 4"); - checkErrorCode(sqlite3_step(stmt), "insert pred step"); - finalize(stmt); + checkErrorCode(sqlite3_step(stmt), "insert pred step"); + finalize(stmt); } void SQLiteExplanation::insertCluster(int task_id, int entity_id, int cluster_id, double confidence) { - //std::cout << "Cluster " << task_id << " " << entity_id << " " << cluster_id << " " << confidence << "\n"; - char* sql = "INSERT INTO Cluster (PredictionTaskID, PredictionEntityID, ID, Confidence) VALUES (?,?,?,?);"; - sqlite3_stmt* stmt = prepare(sql); + //std::cout << "Cluster " << task_id << " " << entity_id << " " << cluster_id << " " << confidence << "\n"; + char* sql = "INSERT INTO Cluster (PredictionTaskID, PredictionEntityID, ID, Confidence) VALUES (?,?,?,?);"; + sqlite3_stmt* stmt = prepare(sql); - checkErrorCode(sqlite3_bind_int(stmt, 1, task_id), "insert clus bind 1"); - checkErrorCode(sqlite3_bind_int(stmt, 2, entity_id), "insert clus bind 2"); - checkErrorCode(sqlite3_bind_int(stmt, 3, cluster_id), "insert clus bind 3"); - checkErrorCode(sqlite3_bind_double(stmt, 4, confidence), "insert clus bind 1"); + checkErrorCode(sqlite3_bind_int(stmt, 1, task_id), "insert clus bind 1"); + checkErrorCode(sqlite3_bind_int(stmt, 2, entity_id), "insert clus bind 2"); + checkErrorCode(sqlite3_bind_int(stmt, 3, cluster_id), "insert clus bind 3"); + checkErrorCode(sqlite3_bind_double(stmt, 4, confidence), "insert clus bind 1"); - checkErrorCode(sqlite3_step(stmt), sql); - finalize(stmt); + checkErrorCode(sqlite3_step(stmt), sql); + finalize(stmt); } void SQLiteExplanation::insertRule_Cluster(int task_id, int entity_id, int cluster_id, int rule_id) { - //std::cout << "Rule_Cluster " << task_id << " " << entity_id << " " << cluster_id << " " << rule_id << "\n"; - char* sql = "INSERT INTO Rule_Cluster (ClusterPredictionTaskID,ClusterPredictionEntityID,ClusterID,RuleID) VALUES (?,?,?,?);"; - sqlite3_stmt* stmt = prepare(sql); + //std::cout << "Rule_Cluster " << task_id << " " << entity_id << " " << cluster_id << " " << rule_id << "\n"; + char* sql = "INSERT INTO Rule_Cluster (ClusterPredictionTaskID,ClusterPredictionEntityID,ClusterID,RuleID) VALUES (?,?,?,?);"; + sqlite3_stmt* stmt = prepare(sql); - checkErrorCode(sqlite3_bind_int(stmt, 1, task_id)); - checkErrorCode(sqlite3_bind_int(stmt, 2, entity_id)); - checkErrorCode(sqlite3_bind_int(stmt, 3, cluster_id)); - checkErrorCode(sqlite3_bind_int(stmt, 4, rule_id)); + checkErrorCode(sqlite3_bind_int(stmt, 1, task_id)); + checkErrorCode(sqlite3_bind_int(stmt, 2, entity_id)); + checkErrorCode(sqlite3_bind_int(stmt, 3, cluster_id)); + checkErrorCode(sqlite3_bind_int(stmt, 4, rule_id)); - checkErrorCode(sqlite3_step(stmt)); - finalize(stmt); + checkErrorCode(sqlite3_step(stmt)); + finalize(stmt); } void SQLiteExplanation::insertRule_Entity(int rule_id, int task_id, int entity_id) { - //std::cout << "Rule_Cluster " << task_id << " " << entity_id << " " << cluster_id << " " << rule_id << "\n"; - char* sql = "INSERT OR IGNORE INTO Rule_Entity (RuleID, TaskID, EntityID) VALUES (?,?,?);"; - sqlite3_stmt* stmt = prepare(sql); + //std::cout << "Rule_Cluster " << task_id << " " << entity_id << " " << cluster_id << " " << rule_id << "\n"; + char* sql = "INSERT OR IGNORE INTO Rule_Entity (RuleID, TaskID, EntityID) VALUES (?,?,?);"; + sqlite3_stmt* stmt = prepare(sql); - checkErrorCode(sqlite3_bind_int(stmt, 1, rule_id)); - checkErrorCode(sqlite3_bind_int(stmt, 2, task_id)); - checkErrorCode(sqlite3_bind_int(stmt, 3, entity_id)); + checkErrorCode(sqlite3_bind_int(stmt, 1, rule_id)); + checkErrorCode(sqlite3_bind_int(stmt, 2, task_id)); + checkErrorCode(sqlite3_bind_int(stmt, 3, entity_id)); - checkErrorCode(sqlite3_step(stmt)); - finalize(stmt); + checkErrorCode(sqlite3_step(stmt)); + finalize(stmt); } void SQLiteExplanation::begin_tr() { - checkErrorCode(sqlite3_exec(db, "BEGIN TRANSACTION", NULL, NULL, NULL), "BEGIN TRANSACTION"); + checkErrorCode(sqlite3_exec(db, "BEGIN TRANSACTION", NULL, NULL, NULL), "BEGIN TRANSACTION"); } void SQLiteExplanation::begin() { - //sqlite3_mutex_enter(sqlite3_db_mutex(db)); + //sqlite3_mutex_enter(sqlite3_db_mutex(db)); } sqlite3_stmt* SQLiteExplanation::prepare(char* sql) { - sqlite3_stmt* stmt; - checkErrorCode(sqlite3_prepare(db, sql, -1, &stmt, NULL), sql); - return stmt; + sqlite3_stmt* stmt; + checkErrorCode(sqlite3_prepare(db, sql, -1, &stmt, NULL), sql); + return stmt; } void SQLiteExplanation::finalize(sqlite3_stmt* stmt) { - checkErrorCode(sqlite3_finalize(stmt)); + checkErrorCode(sqlite3_finalize(stmt)); } void SQLiteExplanation::commit() { - //sqlite3_mutex_leave(sqlite3_db_mutex(db)); + //sqlite3_mutex_leave(sqlite3_db_mutex(db)); } void SQLiteExplanation::commit_tr() { - checkErrorCode(sqlite3_exec(db, "COMMIT TRANSACTION", NULL, NULL, NULL), "COMMIT TRANSACTION"); + checkErrorCode(sqlite3_exec(db, "COMMIT TRANSACTION", NULL, NULL, NULL), "COMMIT TRANSACTION"); } void SQLiteExplanation::checkErrorCode(int code) { - if (code != SQLITE_OK && code != SQLITE_DONE) { - const char* err; - if (this->db) { - err = sqlite3_errmsg(this->db); - } - else { - err = sqlite3_errstr(code); - } - std::cerr << "Error " << code << ": " << err << '\n'; - std::exit(EXIT_FAILURE); - } + if (code != SQLITE_OK && code != SQLITE_DONE) { + const char* err; + if (this->db) { + err = sqlite3_errmsg(this->db); + } + else { + err = sqlite3_errstr(code); + } + std::cerr << "Error " << code << ": " << err << '\n'; + std::exit(EXIT_FAILURE); + } } void SQLiteExplanation::checkErrorCode(int code, char* sql) { - if (code != SQLITE_OK && code != SQLITE_DONE) { - const char* err; - if (this->db) { - err = sqlite3_errmsg(this->db); - } - else { - err = sqlite3_errstr(code); - } - std::cerr << sql << "\n"; - std::cerr << "Error " << code << ": " << err << '\n'; - std::exit(EXIT_FAILURE); - } + if (code != SQLITE_OK && code != SQLITE_DONE) { + const char* err; + if (this->db) { + err = sqlite3_errmsg(this->db); + } + else { + err = sqlite3_errstr(code); + } + std::cerr << sql << "\n"; + std::cerr << "Error " << code << ": " << err << '\n'; + std::exit(EXIT_FAILURE); + } } int SQLiteExplanation::getNextTaskID() { - int task_id_; + int task_id_; #pragma omp critical - { - _task_id++; - task_id_ = _task_id; - } - return task_id_; + { + _task_id++; + task_id_ = _task_id; + } + return task_id_; } diff --git a/src/TesttripleReader.cpp b/src/TesttripleReader.cpp index 0656628..4d7ad72 100644 --- a/src/TesttripleReader.cpp +++ b/src/TesttripleReader.cpp @@ -76,7 +76,7 @@ void TesttripleReader::read(std::string filepath) { std::cout << "Written test sample to " << Properties::get().PATH_TEST_SAMPLE << std::endl; TesttripleReader* sample_reader = new TesttripleReader(index, graph, 0); - sample_reader->read(Properties::get().PATH_TEST_SAMPLE); + sample_reader->read(Properties::get().PATH_TEST_SAMPLE); csr = sample_reader->getCSR(); delete sample_reader; } @@ -105,48 +105,40 @@ void TesttripleReader::read(std::string filepath) { } void TesttripleReader::read(std::vector> & triples) { - std::vector> testtriplesVector; - size_t numTriplesRead = 0; - for (auto & [head, rel, tail] : triples) - { - try { - int* headId = index->getIdOfNodestring(head); - int* relId = index->getIdOfRelationstring(rel); - int* tailId = index->getIdOfNodestring(tail); - std::vector testtriple; - testtriple.push_back(headId); - testtriple.push_back(relId); - testtriple.push_back(tailId); - testtriplesVector.push_back(testtriple); - - relHeadToTails[*relId][*headId].insert(*tailId); - relTailToHeads[*relId][*tailId].insert(*headId); - - numTriplesRead++; - } - catch (std::runtime_error& e) {} - } - if (Properties::get().VERBOSE == 1) { - std::cout << "read " << numTriplesRead << " triples" << std::endl; - std::cout << "csr start" << std::endl; - std::cout << "relsize=" << index->getRelSize() << " nodesize=" << index->getNodeSize() << std::endl; - } - csr = new CSR(index->getRelSize(), index->getNodeSize(), relHeadToTails, relTailToHeads); - if (Properties::get().VERBOSE == 1) { - std::cout << "csr end" << std::endl; - } - - testtripleSize = new int; - *testtripleSize = testtriplesVector.size(); - // Convert to pointers of pointers - int * testtriplesstore; - testtriples = new int*[*testtripleSize]; - testtriplesstore = new int[(*testtripleSize) * 3]; - - for (int i = 0; i < (*testtripleSize); i++) { - testtriplesstore[i * 3] = *(testtriplesVector[i][0]); - testtriplesstore[i * 3 + 1] = *(testtriplesVector[i][1]); - testtriplesstore[i * 3 + 2] = *(testtriplesVector[i][2]); - testtriples[i] = &testtriplesstore[i * 3]; - } + std::vector> testtriplesVector; + size_t numTriplesRead = 0; + for (auto & [head, rel, tail] : triples) + { + try { + int* headId = index->getIdOfNodestring(head); + int* relId = index->getIdOfRelationstring(rel); + int* tailId = index->getIdOfNodestring(tail); + std::vector testtriple; + testtriple.push_back(headId); + testtriple.push_back(relId); + testtriple.push_back(tailId); + testtriplesVector.push_back(testtriple); + + relHeadToTails[*relId][*headId].insert(*tailId); + relTailToHeads[*relId][*tailId].insert(*headId); + + numTriplesRead++; + } + catch (std::runtime_error& e) {} + } + csr = new CSR(index->getRelSize(), index->getNodeSize(), relHeadToTails, relTailToHeads); + + testtripleSize = new int; + *testtripleSize = testtriplesVector.size(); + // Convert to pointers of pointers + int * testtriplesstore; + testtriples = new int*[*testtripleSize]; + testtriplesstore = new int[(*testtripleSize) * 3]; + + for (int i = 0; i < (*testtripleSize); i++) { + testtriplesstore[i * 3] = *(testtriplesVector[i][0]); + testtriplesstore[i * 3 + 1] = *(testtriplesVector[i][1]); + testtriplesstore[i * 3 + 2] = *(testtriplesVector[i][2]); + testtriples[i] = &testtriplesstore[i * 3]; + } }