Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.idea/
build/
boost_1_76_0/
include/boost
Expand Down
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
set(CMAKE_CXX_FLAGS_RELEASE "-O3")
set(CMAKE_CXX_FLAGS_RELEASE "-O3")
endif()

FetchContent_Declare(sqlite3 URL "http://sqlite.org/2021/sqlite-amalgamation-3340100.zip")
Expand Down
7 changes: 4 additions & 3 deletions Main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "ClusteringReader.h"
#include "RuleApplication.h"
#include "JaccardEngine.h"
#include "Explanation.h"
#include "SQLiteExplanation.h"
#include "Util.hpp"
#include <chrono>
#include <stdio.h>
Expand Down Expand Up @@ -54,7 +54,8 @@ int main(int argc, char** argv)
Properties::get().REL_SIZE = index->getRelSize();

std::cout << "Reading testset..." << std::endl;
TesttripleReader* ttr = new TesttripleReader(Properties::get().PATH_TEST, index, graph, Properties::get().TRIAL);
TesttripleReader* ttr = new TesttripleReader(index, graph, Properties::get().TRIAL);
ttr->read(Properties::get().PATH_TEST);
finish = std::chrono::high_resolution_clock::now();
milliseconds = std::chrono::duration_cast<std::chrono::milliseconds>(finish - start);
std::cout << "Testset read in " << milliseconds.count() << " ms\n";
Expand Down Expand Up @@ -86,7 +87,7 @@ int main(int argc, char** argv)
Explanation* explanation = nullptr;
if (Properties::get().EXPLAIN == 1) {
std::cout << "Writing entities, relations and rules to db file..." << std::endl;
explanation = new Explanation(Properties::get().PATH_EXPLAIN, true);
explanation = new SQLiteExplanation(Properties::get().PATH_EXPLAIN, true);
explanation->begin_tr();
explanation->insertEntities(index);
explanation->insertRelations(index);
Expand Down
39 changes: 13 additions & 26 deletions include/Explanation.h
Original file line number Diff line number Diff line change
@@ -1,42 +1,29 @@
#ifndef EXPL_H
#define EXPL_H

#include <sqlite3.h>
#include "Index.h"
#include "Rule.h"
#include "RuleReader.h"
#include "ClusteringReader.h"

class Explanation {
public:
Explanation(std::string dbName, bool init = false);
~Explanation();
Explanation() {}
virtual ~Explanation() {}

void begin();
void commit();
void begin_tr();
void commit_tr();
void insertEntities(Index* index);
void insertRelations(Index* index);
void insertRules(RuleReader* rr, int relsize, ClusteringReader* cr);
virtual void begin() = 0;
virtual void commit() = 0;
virtual void begin_tr() = 0;
virtual void commit_tr() = 0;
virtual void insertEntities(Index* index) = 0;
virtual void insertRelations(Index* index) = 0;
virtual void insertRules(RuleReader* rr, int relsize, ClusteringReader* cr) = 0;

void insertTask(int prediction_id, bool is_head, int relation_id, int entity_id);
void insertPrediction(int task_id, int entity_id, bool hit, double confidence);
void insertRule_Entity(int rule_id, int task_id, int entity_id);
virtual void insertTask(int prediction_id, bool is_head, int relation_id, int entity_id) = 0;
virtual void insertPrediction(int task_id, int entity_id, bool hit, double confidence) = 0;
virtual void insertRule_Entity(int rule_id, int task_id, int entity_id) = 0;

// OLD
void insertCluster(int prediction_id, int entity_id, int cluster_id, double confidence);
void insertRule_Cluster(int prediction_id, int entity_id, int cluster_id, int rule_id);

int getNextTaskID();
private:
sqlite3* db;
int task_id = 0;
void initDb();
void checkErrorCode(int code);
void checkErrorCode(int code, char* sql);
sqlite3_stmt* prepare(char* sql);
void finalize(sqlite3_stmt* stmt);
virtual int getNextTaskID() = 0;
};

#endif //EXPL_H
34 changes: 34 additions & 0 deletions include/InMemoryExplanation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef INMEMORY_EXPL_H
#define INMEMORY_EXPL_H

#include "Explanation.h"

class InMemoryExplanation: public Explanation {
public:
InMemoryExplanation();
~InMemoryExplanation();

void begin();
void commit();
void begin_tr();
void commit_tr();
void insertEntities(Index* index);
void insertRelations(Index* index);
void insertRules(RuleReader* rr, int relsize, ClusteringReader* cr);

void insertTask(int task_id, bool is_head, int relation_id, int entity_id);
void insertPrediction(int task_id, int entity_id, bool hit, double confidence);
void insertRule_Entity(int rule_id, int task_id, int entity_id);

int getNextTaskID();

std::unordered_map<int, std::unordered_map<int, std::unordered_map<int, int>>> tripleBestRules;

private:
int _task_id = 0;
std::unordered_map<int, float> ruleConfidences;
std::unordered_map<int, std::tuple<int, int, bool>> tasks;
std::unordered_map<int, std::unordered_map<int, int>> taskEntityBestRules;
};

#endif //INMEMORY_EXPL_H
4 changes: 4 additions & 0 deletions include/RuleApplication.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <map>
#include <functional>
#include <math.h>
#include <tuple>
#include "Index.h"
#include "TraintripleReader.h"
#include "TesttripleReader.h"
Expand All @@ -27,9 +28,12 @@ class RuleApplication
{
public:
RuleApplication(Index* index, TraintripleReader* graph, TesttripleReader* ttr, ValidationtripleReader* vtr, RuleReader* rr, Explanation* exp);
RuleApplication(Index* index, TraintripleReader* graph, ValidationtripleReader* vtr, RuleReader* rr, Explanation* exp);
void apply_nr_noisy(std::unordered_map<int, std::pair<std::pair<bool, std::vector<std::vector<int>>>, std::pair<bool, std::vector<std::vector<int>>>>> rel2clusters);
void apply_only_noisy();
void apply_only_max();
void updateTTR(TesttripleReader* ttr);
std::vector<std::tuple<int, int, int, float50>> apply_only_max_in_memory(size_t K);

private:
Index* index;
Expand Down
2 changes: 2 additions & 0 deletions include/RuleGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
class RuleGraph {
public:
RuleGraph(int nodesize, TraintripleReader* graph);
RuleGraph(int nodesize, TraintripleReader* graph, ValidationtripleReader* vtr);
RuleGraph(int nodesize, TraintripleReader* graph, TesttripleReader* ttr, ValidationtripleReader* vtr);
void updateTTR(TesttripleReader *ttr);
void searchDFSSingleStart_filt(bool headNotTail, int filt_v, int v, Rule& r, bool bwd, std::vector<int>& solution, bool filtValidNotTest, bool filtExceptions);
void searchDFSMultiStart_filt(bool headNotTail, int filt_v, Rule& r, bool bwd, std::vector<int>& solution, bool filtValidNotTest, bool filtExceptions);
bool existsAcyclic(int* valId, Rule& rule, bool filtValidNotTest);
Expand Down
39 changes: 39 additions & 0 deletions include/SQLiteExplanation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#ifndef SQLITE_EXPL_H
#define SQLITE_EXPL_H

#include <sqlite3.h>
#include "Explanation.h"

class SQLiteExplanation: public Explanation {
public:
SQLiteExplanation(std::string dbName, bool init = false);
~SQLiteExplanation();

void begin();
void commit();
void begin_tr();
void commit_tr();
void insertEntities(Index* index);
void insertRelations(Index* index);
void insertRules(RuleReader* rr, int relsize, ClusteringReader* cr);

void insertTask(int task_id, bool is_head, int relation_id, int entity_id);
void insertPrediction(int task_id, int entity_id, bool hit, double confidence);
void insertRule_Entity(int rule_id, int task_id, int entity_id);

// OLD
void insertCluster(int prediction_id, int entity_id, int cluster_id, double confidence);
void insertRule_Cluster(int prediction_id, int entity_id, int cluster_id, int rule_id);

int getNextTaskID();
private:
sqlite3* db;
int _task_id = 0;
void initDb();
void checkErrorCode(int code);
void checkErrorCode(int code, char* sql);
sqlite3_stmt* prepare(char* sql);
void finalize(sqlite3_stmt* stmt);
};

#endif //SQLITE_EXPL_H
7 changes: 4 additions & 3 deletions include/TesttripleReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@ class TesttripleReader
{

public:
TesttripleReader(std::string filepath, Index* index, TraintripleReader* graph, int is_trial);
TesttripleReader(Index* index, TraintripleReader* graph, int is_trial);

int** getTesttriples();
int* getTesttriplesSize();
CSR<int, int>* getCSR();
RelNodeToNodes& getRelHeadToTails();
RelNodeToNodes& getRelTailToHeads();

void read(std::string filepath);
void read(std::vector<std::tuple<std::string, std::string, std::string>> & triples);

protected:

private:
Expand All @@ -41,8 +44,6 @@ class TesttripleReader

RelNodeToNodes relHeadToTails;
RelNodeToNodes relTailToHeads;

void read(std::string filepath);
};

#endif // TESTTRIPLEREADER_H
6 changes: 6 additions & 0 deletions python_bindings/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
build/
dist/
*.egg-info/
__pycache__
safran_wrapper.py
safran_wrapper_wrap.cpp
18 changes: 18 additions & 0 deletions python_bindings/safran.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import itertools
import collections
from typing import List, Dict, Tuple
from safran_wrapper import pysafran, query_output_t, query_triples_t

class SAFRAN(pysafran):
def __init__(self, train_path: str, rule_path: str, n_jobs: int = 1):
pysafran.__init__(self, train_path, rule_path, n_jobs)

def query(self, triples: List[List[str]], k: int = 100, action: str = 'applymax') -> Dict[Tuple[str, str], List[Tuple[str, float, int]]]:
if action != 'applymax':
raise ValueError('Actions supported in the SAFRAN Python wrapper are: applymax')
flat_triples = list(itertools.chain.from_iterable(triples))
pred_vals = query_output_t(pysafran.query(self, action, k, query_triples_t(flat_triples)))
out = collections.defaultdict(list)
for head, (pred, (tail, (val, rule_id))) in pred_vals:
out[head, pred].append((tail, val, rule_id))
return out
69 changes: 69 additions & 0 deletions python_bindings/safran_wrapper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#include "safran_wrapper.h"

pysafran::pysafran(std::string train_path, std::string rule_path, int num_threads) {
// TODO: initialize these globally, or make those a part of pysafran instance.
Properties::get().VERBOSE = 0;
Properties::get().PREDICT_UNKNOWN = 1;
if (num_threads != -1) {
omp_set_num_threads(num_threads);
}

this->index = new Index();
index->addNode(Properties::get().REFLEXIV_TOKEN);
index->addNode(Properties::get().UNK_TOKEN);

this->graph = new TraintripleReader(train_path, index);
Properties::get().REL_SIZE = index->getRelSize();

this->rr = new RuleReader(rule_path, index, graph);
this->vtr = new ValidationtripleReader("/dev/null", index, graph);

this->exp = new InMemoryExplanation();
this->exp->insertRules(rr, index->getRelSize(), nullptr);

this->ra = new RuleApplication(index, graph, vtr, rr, exp);
}

pysafran::~pysafran() {
delete this->ra;
delete this->vtr;
delete this->rr;
delete this->graph;
delete this->index;
delete this->exp;
}

std::vector<std::pair<std::string, std::pair<std::string, std::pair<std::string, std::pair<float, int>>>>> pysafran::query(const std::string & action, size_t k, const std::vector<std::string> & flat_triples) const {
// "Read" triples
TesttripleReader ttr(index, graph, 0);
std::vector<std::tuple<std::string, std::string, std::string>> triples;
size_t i = 0;
while (i < flat_triples.size()) {
triples.emplace_back(flat_triples[i], flat_triples[i + 1], flat_triples[i + 2]);
i += 3;
}
ttr.read(triples);
ra->updateTTR(&ttr);

// Apply specific rule application
std::vector<std::tuple<int, int, int, float50>> outAction;
if (action == "applymax") {
outAction = ra->apply_only_max_in_memory(k);
}

// Transform node ids to node names and only retain top-k
std::vector<std::pair<std::string, std::pair<std::string, std::pair<std::string, std::pair<float, int>>>>> out;
for (auto & p : outAction) {
int head_id = std::get<0>(p), relation_id = std::get<1>(p), tail_id = std::get<2>(p);
float val = (float)std::get<3>(p);
std::string headStr = *this->index->getStringOfNodeId(head_id);
std::string predStr = *this->index->getStringOfRelId(relation_id);
std::string tailStr = *this->index->getStringOfNodeId(tail_id);
int rule_id = exp->tripleBestRules[head_id][relation_id][tail_id];
std::pair<float, int> p1 = {val, rule_id};
std::pair<std::string, std::pair<float, int>> p2 = {tailStr, p1};
std::pair<std::string, std::pair<std::string, std::pair<float, int>>> p3 = {predStr, p2};
out.emplace_back(headStr, p3);
}
return out;
}
31 changes: 31 additions & 0 deletions python_bindings/safran_wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#include "Index.h"
#include "TraintripleReader.h"
#include "ValidationtripleReader.h"
#include "RuleReader.h"
#include "RuleApplication.h"
#include "Properties.hpp"
#include "InMemoryExplanation.h"

#include <utility>
#include <string>
#include <vector>
#include <tuple>

using string_vector_t = std::vector<std::string>;

class pysafran {
public:
pysafran(std::string train_path, std::string rule_path, int num_threads);
~pysafran();

// [(head, pred, tail)] -> [(head, pred, tail, rule_confidence, rule_id)]
std::vector<std::pair<std::string, std::pair<std::string, std::pair<std::string, std::pair<float, int>>>>> query(const std::string & action, size_t k, const std::vector<std::string> & flat_triples) const;

private:
Index *index;
TraintripleReader* graph;
ValidationtripleReader* vtr;
RuleReader* rr;
RuleApplication *ra;
InMemoryExplanation *exp;
};
20 changes: 20 additions & 0 deletions python_bindings/safran_wrapper.i
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
%module safran_wrapper

%{
#define SWIG_FILE_WITH_INIT
#include "safran_wrapper.h"
%}

%include <std_string.i>
%include <std_vector.i>
%include <std_pair.i>
%include "safran_wrapper.h"

namespace std {
%template(query_triples_t) vector<string>;
%template(query_p1) pair<float, int>;
%template(query_p2) pair<string, pair<float, int>>;
%template(query_p3) pair<string, pair<string, pair<float, int>>>;
%template(query_p4) pair<string, pair<string, pair<string, pair<float, int>>>>;
%template(query_output_t) vector<pair<string, pair<string, pair<string, pair<float, int>>>>>;
}
Loading