From ca9db0d7709c943113cbf584cd31bedb8aca38c5 Mon Sep 17 00:00:00 2001 From: f-hafner Date: Thu, 15 Dec 2022 15:48:39 +0100 Subject: [PATCH 1/3] add coref switch to ED --- rebl/ed/entity_disambiguation.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/rebl/ed/entity_disambiguation.py b/rebl/ed/entity_disambiguation.py index 96458ea..5558abe 100644 --- a/rebl/ed/entity_disambiguation.py +++ b/rebl/ed/entity_disambiguation.py @@ -29,7 +29,7 @@ def __init__(self, **kwargs): self.stream_raw_source_file = input_stream_gen_lines(self.arguments['source_file']) self.mention_detection = MentionDetection(self.arguments['base_url'], self.arguments['wiki_version']) self.model = RelED(self.arguments['base_url'], self.arguments['wiki_version'], self.config, - reset_embeddings=True) + reset_embeddings=True, no_corefs=self.arguments['no_corefs']) self.docs_done = 0 def get_ids(self): @@ -151,7 +151,8 @@ def get_arguments(kwargs): 'base_url': None, 'wiki_version': None, 'identifier': 'docid', - 'write_batch_size': 10000 + 'write_batch_size': 10000, + "no_corefs": None } for key, item in arguments.items(): if kwargs.get(key) is not None: @@ -162,6 +163,7 @@ def get_arguments(kwargs): for key in ['md_file', 'source_file', 'out_file', 'base_url', 'wiki_version']: if arguments[key] is None: raise IOError(f'Argument {key} needs to be provided') + print(arguments) return arguments @@ -221,5 +223,11 @@ def get_arguments(kwargs): help='Write batch size', default=10000 ) + parser.add_argument( + "--no_corefs", + action="store_true", + help="use function with_coref for entity disambiguation()?", + default=False + ) ed = EntityDisambiguation(**vars(parser.parse_args())) ed.process() From 958e48ae9942ebaaa42dc8e51bb8be0aea728b9f Mon Sep 17 00:00:00 2001 From: f-hafner Date: Fri, 16 Dec 2022 11:07:52 +0100 Subject: [PATCH 2/3] add coref indicator to results --- rebl/ed/entity_disambiguation.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/rebl/ed/entity_disambiguation.py b/rebl/ed/entity_disambiguation.py index 5558abe..9e18f07 100644 --- a/rebl/ed/entity_disambiguation.py +++ b/rebl/ed/entity_disambiguation.py @@ -2,6 +2,7 @@ import gzip import json import time +import pdb import pandas as pd import pyarrow as pa @@ -115,16 +116,16 @@ def create_disambiguate_batches(self): if len(batch) >= self.arguments['write_batch_size']: yield batch batch = [] - for start_pos, span, text, entity, ed_score, tag, md_score in result[identifier]: + for start_pos, span, text, entity, ed_score, tag, md_score, is_coref in result[identifier]: doc_id, field = identifier.split('+') field = self.fields_inverted[field] - batch.append([doc_id, field, start_pos, start_pos + span, entity, ed_score, tag, md_score]) + batch.append([doc_id, field, start_pos, start_pos + span, entity, ed_score, tag, md_score, is_coref]) yield batch def process(self): gen = self.create_disambiguate_batches() df = pd.DataFrame(next(gen), - columns=['doc_id', 'field', 'start_pos', 'end_pos', 'entity', 'ed_score', 'tag', 'md_score']) + columns=['doc_id', 'field', 'start_pos', 'end_pos', 'entity', 'ed_score', 'tag', 'md_score', 'is_coref']) table = pa.Table.from_pandas(df=df, preserve_index=False) t = time.time() with pq.ParquetWriter(self.out_file, schema=table.schema) as writer: @@ -132,7 +133,7 @@ def process(self): for batch in gen: df = pd.DataFrame(batch, columns=['doc_id', 'field', 'start_pos', 'end_pos', 'entity', 'ed_score', 'tag', - 'md_score']) + 'md_score', 'is_coref']) table = pa.Table.from_pandas(df=df, preserve_index=False) writer.write_table(table) batch_time = time.time() - t From 4c8d5d69b5f7639e67e0838129f45a557b91fb53 Mon Sep 17 00:00:00 2001 From: f-hafner Date: Mon, 9 Jan 2023 10:57:39 +0100 Subject: [PATCH 3/3] integrate search_corefs option from REL --- rebl/ed/entity_disambiguation.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/rebl/ed/entity_disambiguation.py b/rebl/ed/entity_disambiguation.py index 9e18f07..959779f 100644 --- a/rebl/ed/entity_disambiguation.py +++ b/rebl/ed/entity_disambiguation.py @@ -30,7 +30,7 @@ def __init__(self, **kwargs): self.stream_raw_source_file = input_stream_gen_lines(self.arguments['source_file']) self.mention_detection = MentionDetection(self.arguments['base_url'], self.arguments['wiki_version']) self.model = RelED(self.arguments['base_url'], self.arguments['wiki_version'], self.config, - reset_embeddings=True, no_corefs=self.arguments['no_corefs']) + reset_embeddings=True, search_corefs=self.arguments['search_corefs']) self.docs_done = 0 def get_ids(self): @@ -153,7 +153,7 @@ def get_arguments(kwargs): 'wiki_version': None, 'identifier': 'docid', 'write_batch_size': 10000, - "no_corefs": None + "search_corefs": None } for key, item in arguments.items(): if kwargs.get(key) is not None: @@ -224,11 +224,12 @@ def get_arguments(kwargs): help='Write batch size', default=10000 ) - parser.add_argument( - "--no_corefs", - action="store_true", - help="use function with_coref for entity disambiguation()?", - default=False + parser.add_argument( + '--search_corefs', + type=str, + choices=['all', 'lsh', 'off'], + required=True, + help="Setting for search_corefs in Entity Disambiguation." ) ed = EntityDisambiguation(**vars(parser.parse_args())) ed.process()