From 637a7b54625e38fd5cd8b07c777ddb815e92090b Mon Sep 17 00:00:00 2001 From: Stef Smeets Date: Mon, 19 Dec 2022 15:29:19 +0100 Subject: [PATCH 1/2] Re-implement server using fastapi/pydantic --- requirements.txt | 9 +- setup.cfg | 9 +- src/REL/db/base.py | 2 +- src/REL/server.py | 236 +++++++++++++++++---------------------------- 4 files changed, 101 insertions(+), 155 deletions(-) diff --git a/requirements.txt b/requirements.txt index c5a5969..c84bf33 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,10 @@ +anyascii colorama -konoha +fastapi flair>=0.11 +konoha +nltk +pydantic segtok torch -nltk -anyascii +uvicorn diff --git a/setup.cfg b/setup.cfg index 8fbd4af..19ab16c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,13 +43,16 @@ package_dir = = src include_package_data = True install_requires = + anyascii colorama - konoha + fastapi flair>=0.11 + konoha + nltk + pydantic segtok torch - nltk - anyascii + uvicorn [options.extras_require] develop = diff --git a/src/REL/db/base.py b/src/REL/db/base.py index 8eec44d..2526946 100644 --- a/src/REL/db/base.py +++ b/src/REL/db/base.py @@ -40,7 +40,7 @@ def initialize_db(self, fname, table_name, columns): db (sqlite3.Connection): a SQLite3 database with an embeddings table. """ # open database in autocommit mode by setting isolation_level to None. - db = sqlite3.connect(fname, isolation_level=None) + db = sqlite3.connect(fname, isolation_level=None, check_same_thread=False) q = "create table if not exists {}(word text primary key, {})".format( table_name, ", ".join(["{} {}".format(k, v) for k, v in columns.items()]) diff --git a/src/REL/server.py b/src/REL/server.py index d26d6a9..c488330 100644 --- a/src/REL/server.py +++ b/src/REL/server.py @@ -1,152 +1,99 @@ -import json -from http.server import BaseHTTPRequestHandler - +from REL.entity_disambiguation import EntityDisambiguation +from REL.ner import load_flair_ner from flair.models import SequenceTagger - from REL.mention_detection import MentionDetection from REL.utils import process_results -API_DOC = "API_DOC" - - - -def make_handler(base_url, wiki_version, model, tagger_ner): - """ - Class/function combination that is used to setup an API that can be used for e.g. GERBIL evaluation. - """ - class GetHandler(BaseHTTPRequestHandler): - def __init__(self, *args, **kwargs): - self.model = model - self.tagger_ner = tagger_ner - - self.base_url = base_url - self.wiki_version = wiki_version - - self.custom_ner = not isinstance(tagger_ner, SequenceTagger) - self.mention_detection = MentionDetection(base_url, wiki_version) - - super().__init__(*args, **kwargs) - - def do_GET(self): - self.send_response(200) - self.end_headers() - self.wfile.write( - bytes( - json.dumps( - { - "schemaVersion": 1, - "label": "status", - "message": "up", - "color": "green", - } - ), - "utf-8", - ) - ) - return - - def do_HEAD(self): - # send bad request response code - self.send_response(400) - self.end_headers() - self.wfile.write(bytes(json.dumps([]), "utf-8")) - return - - def do_POST(self): - """ - Returns response. - - :return: - """ - try: - content_length = int(self.headers["Content-Length"]) - post_data = self.rfile.read(content_length) - self.send_response(200) - self.end_headers() - - text, spans = self.read_json(post_data) - response = self.generate_response(text, spans) - - self.wfile.write(bytes(json.dumps(response), "utf-8")) - except Exception as e: - print(f"Encountered exception: {repr(e)}") - self.send_response(400) - self.end_headers() - self.wfile.write(bytes(json.dumps([]), "utf-8")) - return - - def read_json(self, post_data): - """ - Reads input JSON message. - - :return: document text and spans. - """ - - data = json.loads(post_data.decode("utf-8")) - text = data["text"] - text = text.replace("&", "&") - - # GERBIL sends dictionary, users send list of lists. - if "spans" in data: - try: - spans = [list(d.values()) for d in data["spans"]] - except Exception: - spans = data["spans"] - pass - else: - spans = [] - - return text, spans - - def generate_response(self, text, spans): - """ - Generates response for API. Can be either ED only or EL, meaning end-to-end. - - :return: list of tuples for each entity found. - """ - - if len(text) == 0: - return [] - - if len(spans) > 0: - # ED. - processed = {API_DOC: [text, spans]} - mentions_dataset, total_ment = self.mention_detection.format_spans( - processed - ) - else: - # EL - processed = {API_DOC: [text, spans]} - mentions_dataset, total_ment = self.mention_detection.find_mentions( - processed, self.tagger_ner - ) - - # Disambiguation - predictions, timing = self.model.predict(mentions_dataset) - - # Process result. - result = process_results( - mentions_dataset, - predictions, - processed, - include_offset=False if ((len(spans) > 0) or self.custom_ner) else True, - ) - - # Singular document. - if len(result) > 0: - return [*result.values()][0] +class ModelHandler: + API_DOC = "API_DOC" + + def __init__(self, base_url, wiki_version, ed_model, ner_model): + self.model = model + self.tagger_ner = tagger_ner + + self.base_url = base_url + self.wiki_version = wiki_version + + self.custom_ner = not isinstance(tagger_ner, SequenceTagger) + self.mention_detection = MentionDetection(base_url, wiki_version) + + def generate_response(self, + *, + text: list, + spans: list, + conversation: bool = False): + """ + Generates response for API. Can be either ED only or EL, meaning end-to-end. + + :return: list of tuples for each entity found. + """ + + if len(text) == 0: return [] - return GetHandler + if len(spans) > 0: + # ED. + processed = {self.API_DOC: [text, spans]} + mentions_dataset, total_ment = self.mention_detection.format_spans( + processed) + else: + # EL + processed = {self.API_DOC: [text, spans]} + mentions_dataset, total_ment = self.mention_detection.find_mentions( + processed, self.tagger_ner) + + # Disambiguation + predictions, timing = self.model.predict(mentions_dataset) + + # Process result. + result = process_results( + mentions_dataset, + predictions, + processed, + include_offset=False if + ((len(spans) > 0) or self.custom_ner) else True, + ) + + # Singular document. + if len(result) > 0: + return [*result.values()][0] + + return [] + + +from fastapi import FastAPI +from pydantic import BaseModel +from typing import List + +app = FastAPI() + +@app.get("/") +def root(): + """Returns server status.""" + return { + "schemaVersion": 1, + "label": "status", + "message": "up", + "color": "green", + } + + +class EntityConfig(BaseModel): + text: str = Field(..., description="Text for entity linking or disambiguation.") + spans: List[str] = Field(..., description="Spans for entity disambiguation.") + + +@app.post("/") +def root(config: EntityConfig): + """Submit your text here for entity disambiguation or linking.""" + response = handler.generate_response(text=config.text, spans=config.spans) + return response if __name__ == "__main__": import argparse - from http.server import HTTPServer - - from REL.entity_disambiguation import EntityDisambiguation - from REL.ner import load_flair_ner + import uvicorn p = argparse.ArgumentParser() p.add_argument("base_url") @@ -161,14 +108,7 @@ def generate_response(self, text, spans): ed_model = EntityDisambiguation( args.base_url, args.wiki_version, {"mode": "eval", "model_path": args.ed_model} ) - server_address = (args.bind, args.port) - server = HTTPServer( - server_address, - make_handler(args.base_url, args.wiki_version, ed_model, ner_model), - ) - try: - print("Ready for listening.") - server.serve_forever() - except KeyboardInterrupt: - exit(0) + handler = ModelHandler(args.base_url, args.wiki_version, ed_model, ner_model) + + uvicorn.run(app, port=args.port, host=args.bind) From fcd57454593a70de8c84d89cbc19ce9d1f92b369 Mon Sep 17 00:00:00 2001 From: Stef Smeets Date: Mon, 19 Dec 2022 15:35:11 +0100 Subject: [PATCH 2/2] Fix names --- src/REL/server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/REL/server.py b/src/REL/server.py index c488330..4946851 100644 --- a/src/REL/server.py +++ b/src/REL/server.py @@ -8,7 +8,7 @@ class ModelHandler: API_DOC = "API_DOC" - def __init__(self, base_url, wiki_version, ed_model, ner_model): + def __init__(self, base_url, wiki_version, model, tagger_ner): self.model = model self.tagger_ner = tagger_ner @@ -63,7 +63,7 @@ def generate_response(self, from fastapi import FastAPI -from pydantic import BaseModel +from pydantic import BaseModel, Field from typing import List app = FastAPI()