Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
anyascii
colorama
fastapi
flair>=0.11
konoha
nltk
pydantic
segtok
torch
uvicorn
9 changes: 6 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,16 @@ package_dir =
= src
include_package_data = True
install_requires =
anyascii
colorama
fastapi
flair>=0.11
konoha
nltk
pydantic
segtok
torch
uvicorn

[options.extras_require]
develop =
Expand Down
2 changes: 1 addition & 1 deletion src/REL/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def initialize_db(self, fname, table_name, columns):
db (sqlite3.Connection): a SQLite3 database with an embeddings table.
"""
# open database in autocommit mode by setting isolation_level to None.
db = sqlite3.connect(fname, isolation_level=None)
db = sqlite3.connect(fname, isolation_level=None, check_same_thread=False)

q = "create table if not exists {}(word text primary key, {})".format(
table_name, ", ".join(["{} {}".format(k, v) for k, v in columns.items()])
Expand Down
236 changes: 88 additions & 148 deletions src/REL/server.py
Original file line number Diff line number Diff line change
@@ -1,152 +1,99 @@
import json
from http.server import BaseHTTPRequestHandler

from REL.entity_disambiguation import EntityDisambiguation
from REL.ner import load_flair_ner
from flair.models import SequenceTagger

from REL.mention_detection import MentionDetection
from REL.utils import process_results

API_DOC = "API_DOC"



def make_handler(base_url, wiki_version, model, tagger_ner):
"""
Class/function combination that is used to setup an API that can be used for e.g. GERBIL evaluation.
"""
class GetHandler(BaseHTTPRequestHandler):
def __init__(self, *args, **kwargs):
    # Capture the ED model, NER tagger and configuration from the
    # enclosing make_handler closure onto the handler instance.
    self.model = model
    self.tagger_ner = tagger_ner

    self.base_url = base_url
    self.wiki_version = wiki_version

    # Anything that is not a flair SequenceTagger is treated as a
    # user-supplied (custom) NER component.
    self.custom_ner = not isinstance(tagger_ner, SequenceTagger)
    self.mention_detection = MentionDetection(base_url, wiki_version)

    # BaseHTTPRequestHandler.__init__ dispatches the request immediately,
    # so it must run only after our attributes are set.
    super().__init__(*args, **kwargs)

def do_GET(self):
    """Health check: reply 200 with a shields.io-style status JSON payload."""
    self.send_response(200)
    self.end_headers()
    self.wfile.write(
        bytes(
            json.dumps(
                {
                    "schemaVersion": 1,
                    "label": "status",
                    "message": "up",
                    "color": "green",
                }
            ),
            "utf-8",
        )
    )
    return

def do_HEAD(self):
    """Reject HEAD requests with a 400 and an empty JSON list body."""
    # NOTE(review): writing a body on a HEAD response is unusual — confirm
    # this is what the GERBIL client expects.
    # send bad request response code
    self.send_response(400)
    self.end_headers()
    self.wfile.write(bytes(json.dumps([]), "utf-8"))
    return

def do_POST(self):
    """
    Returns response.

    Reads the JSON request body, runs mention detection plus entity
    disambiguation, and writes the predictions back as JSON. On any
    failure a 400 with an empty JSON list is written instead.

    :return:
    """
    try:
        content_length = int(self.headers["Content-Length"])
        post_data = self.rfile.read(content_length)
        # NOTE(review): the 200 status is committed before the request is
        # processed, so the except branch below emits a second status line
        # on an already-started response — confirm this is intentional.
        self.send_response(200)
        self.end_headers()

        text, spans = self.read_json(post_data)
        response = self.generate_response(text, spans)

        self.wfile.write(bytes(json.dumps(response), "utf-8"))
    except Exception as e:
        print(f"Encountered exception: {repr(e)}")
        self.send_response(400)
        self.end_headers()
        self.wfile.write(bytes(json.dumps([]), "utf-8"))
    return

def read_json(self, post_data):
    """
    Reads input JSON message.

    :param post_data: raw UTF-8 encoded JSON bytes with a mandatory
        "text" field and an optional "spans" field.
    :return: tuple (text, spans); spans is a list of [start, length]
        pairs, possibly empty.
    """
    data = json.loads(post_data.decode("utf-8"))
    text = data["text"]
    # GERBIL HTML-escapes ampersands; undo the entity before processing.
    # (The rendered diff showed this as a no-op replace("&", "&") — the
    # intended call is the &amp; unescape below.)
    text = text.replace("&amp;", "&")

    # GERBIL sends a list of {"start": ..., "length": ...} dicts, while
    # end users send plain [start, length] lists; accept both shapes.
    if "spans" in data:
        try:
            spans = [list(d.values()) for d in data["spans"]]
        except Exception:
            # Entries were already lists, not dicts.
            spans = data["spans"]
    else:
        spans = []

    return text, spans

def generate_response(self, text, spans):
    """
    Generates response for API. Can be either ED only or EL, meaning end-to-end.

    :param text: document text to annotate.
    :param spans: list of [start, length] mention spans; empty for EL.
    :return: list of tuples for each entity found.
    """
    # Empty document: nothing to link.
    if len(text) == 0:
        return []

    if len(spans) > 0:
        # ED: the caller supplied mention spans, only disambiguate them.
        processed = {API_DOC: [text, spans]}
        mentions_dataset, total_ment = self.mention_detection.format_spans(
            processed
        )
    else:
        # EL: no spans given, detect mentions with the NER tagger first.
        processed = {API_DOC: [text, spans]}
        mentions_dataset, total_ment = self.mention_detection.find_mentions(
            processed, self.tagger_ner
        )

    # Disambiguation
    predictions, timing = self.model.predict(mentions_dataset)

    # Process result.
    result = process_results(
        mentions_dataset,
        predictions,
        processed,
        # Offsets are already absolute when spans were supplied or when a
        # custom NER component produced them.
        include_offset=False if ((len(spans) > 0) or self.custom_ner) else True,
    )

    # Singular document (keyed by API_DOC) — return its entity list.
    # NOTE(review): the diff view cuts off here; the original presumably
    # falls through to `return []` when result is empty — confirm.
    if len(result) > 0:
        return [*result.values()][0]

class ModelHandler:
    """Bundles the entity-disambiguation model, NER tagger and mention
    detection, and turns (text, spans) queries into entity predictions —
    ED when spans are supplied, end-to-end EL otherwise."""

    # Key used for the single-document batch fed to mention detection.
    API_DOC = "API_DOC"

    def __init__(self, base_url, wiki_version, model, tagger_ner):
        """
        :param base_url: path to the REL data directory.
        :param wiki_version: Wikipedia dump the models were built on.
        :param model: an EntityDisambiguation instance.
        :param tagger_ner: a flair SequenceTagger or a custom NER component.
        """
        self.model = model
        self.tagger_ner = tagger_ner

        self.base_url = base_url
        self.wiki_version = wiki_version

        # Anything that is not a flair SequenceTagger is treated as a
        # user-supplied (custom) NER component.
        self.custom_ner = not isinstance(tagger_ner, SequenceTagger)
        self.mention_detection = MentionDetection(base_url, wiki_version)

    def generate_response(self,
                          *,
                          text: str,
                          spans: list,
                          conversation: bool = False):
        """
        Generates response for API. Can be either ED only or EL, meaning end-to-end.

        :param text: document text to annotate (original annotated this as
            ``list``, but it is used as a string throughout).
        :param spans: list of [start, length] mention spans; empty for EL.
        :param conversation: reserved for future use; currently unread.
        :return: list of tuples for each entity found.
        """
        # Empty document: nothing to link.
        if len(text) == 0:
            return []

        if len(spans) > 0:
            # ED: the caller supplied mention spans, only disambiguate them.
            processed = {self.API_DOC: [text, spans]}
            mentions_dataset, total_ment = self.mention_detection.format_spans(
                processed)
        else:
            # EL: no spans given, detect mentions with the NER tagger first.
            processed = {self.API_DOC: [text, spans]}
            mentions_dataset, total_ment = self.mention_detection.find_mentions(
                processed, self.tagger_ner)

        # Disambiguation.
        predictions, timing = self.model.predict(mentions_dataset)

        # Process result.
        result = process_results(
            mentions_dataset,
            predictions,
            processed,
            # Offsets are already absolute when spans were supplied or when
            # a custom NER component produced them.
            include_offset=False if
            ((len(spans) > 0) or self.custom_ner) else True,
        )

        # Singular document (keyed by API_DOC) — return its entity list.
        if len(result) > 0:
            return [*result.values()][0]

        return []


from fastapi import FastAPI
from pydantic import BaseModel, Field
from typing import List

app = FastAPI()

@app.get("/")
def root():
"""Returns server status."""
return {
"schemaVersion": 1,
"label": "status",
"message": "up",
"color": "green",
}


class EntityConfig(BaseModel):
    """Request body for the entity linking / disambiguation endpoint."""

    # Document to annotate.
    text: str = Field(..., description="Text for entity linking or disambiguation.")
    # Mention spans for ED. Defaults to [] so pure end-to-end linking
    # requests need not send the field — matching the HTTP handler this
    # replaces, where missing "spans" meant an empty list.
    # NOTE(review): that handler accepted spans as [start, length] pairs,
    # so List[str] looks mistyped — confirm against callers before changing.
    spans: List[str] = Field([], description="Spans for entity disambiguation.")


@app.post("/")
def root(config: EntityConfig):
"""Submit your text here for entity disambiguation or linking."""
response = handler.generate_response(text=config.text, spans=config.spans)
return response


if __name__ == "__main__":
import argparse
from http.server import HTTPServer

from REL.entity_disambiguation import EntityDisambiguation
from REL.ner import load_flair_ner
import uvicorn

p = argparse.ArgumentParser()
p.add_argument("base_url")
Expand All @@ -161,14 +108,7 @@ def generate_response(self, text, spans):
ed_model = EntityDisambiguation(
args.base_url, args.wiki_version, {"mode": "eval", "model_path": args.ed_model}
)
server_address = (args.bind, args.port)
server = HTTPServer(
server_address,
make_handler(args.base_url, args.wiki_version, ed_model, ner_model),
)

try:
print("Ready for listening.")
server.serve_forever()
except KeyboardInterrupt:
exit(0)
handler = ModelHandler(args.base_url, args.wiki_version, ed_model, ner_model)

uvicorn.run(app, port=args.port, host=args.bind)