diff --git a/rebl/ed/entity_disambiguation.py b/rebl/ed/entity_disambiguation.py
index 6fc69a7..c367bb3 100644
--- a/rebl/ed/entity_disambiguation.py
+++ b/rebl/ed/entity_disambiguation.py
@@ -2,6 +2,7 @@
 import gzip
 import json
 import time
+import pdb
 
 import pandas as pd
 import pyarrow as pa
@@ -29,7 +30,7 @@ def __init__(self, **kwargs):
         self.stream_raw_source_file = input_stream_gen_lines(self.arguments['source_file'])
         self.mention_detection = MentionDetection(self.arguments['base_url'], self.arguments['wiki_version'])
         self.model = RelED(self.arguments['base_url'], self.arguments['wiki_version'], self.config,
-                           reset_embeddings=True)
+                           reset_embeddings=True, search_corefs=self.arguments['search_corefs'])
         self.docs_done = 0
 
     def get_ids(self):
@@ -135,16 +136,16 @@ def create_disambiguate_batches(self):
             if len(batch) >= self.arguments['write_batch_size']:
                 yield batch
                 batch = []
-            for start_pos, span, text, entity, ed_score, tag, md_score in result[identifier]:
+            for start_pos, span, text, entity, ed_score, tag, md_score, is_coref in result[identifier]:
                 doc_id, field = identifier.split('+')
                 field = self.fields_inverted[field]
-                batch.append([doc_id, field, start_pos, start_pos + span, entity, ed_score, tag, md_score])
+                batch.append([doc_id, field, start_pos, start_pos + span, entity, ed_score, tag, md_score, is_coref])
         yield batch
 
     def process(self):
         gen = self.create_disambiguate_batches()
         df = pd.DataFrame(next(gen),
-                          columns=['doc_id', 'field', 'start_pos', 'end_pos', 'entity', 'ed_score', 'tag', 'md_score'])
+                          columns=['doc_id', 'field', 'start_pos', 'end_pos', 'entity', 'ed_score', 'tag', 'md_score', 'is_coref'])
         table = pa.Table.from_pandas(df=df, preserve_index=False)
         t = time.time()
         with pq.ParquetWriter(self.out_file, schema=table.schema) as writer:
@@ -152,7 +153,7 @@
             for batch in gen:
                 df = pd.DataFrame(batch,
                                   columns=['doc_id', 'field', 'start_pos', 'end_pos', 'entity', 'ed_score', 'tag',
-                                           'md_score'])
+                                           'md_score', 'is_coref'])
                 table = pa.Table.from_pandas(df=df, preserve_index=False)
                 writer.write_table(table)
                 batch_time = time.time() - t
@@ -171,7 +172,8 @@ def get_arguments(kwargs):
         'base_url': None,
         'wiki_version': None,
         'identifier': 'docid',
-        'write_batch_size': 10000
+        'write_batch_size': 10000,
+        "search_corefs": None
     }
     for key, item in arguments.items():
         if kwargs.get(key) is not None:
@@ -182,6 +184,7 @@
     for key in ['md_file', 'source_file', 'out_file', 'base_url', 'wiki_version']:
         if arguments[key] is None:
             raise IOError(f'Argument {key} needs to be provided')
+    print(arguments)
     return arguments
 
 
@@ -238,8 +241,15 @@
     parser.add_argument(
         '--write_batch_size',
         type=int,
         help='Write batch size',
         default=10000
     )
+    parser.add_argument(
+        '--search_corefs',
+        type=str,
+        choices=['all', 'lsh', 'off'],
+        required=True,
+        help="Setting for search_corefs in Entity Disambiguation."
+    )
     ed = EntityDisambiguation(**vars(parser.parse_args()))
     ed.process()
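Usage note: --search_corefs is now a required flag, so every run must choose one of 'all', 'lsh', or 'off'. A sketch of a full invocation, assuming the parser also exposes the remaining keys checked in get_arguments (--md_file, --source_file, --out_file, --base_url, --wiki_version) as same-named flags; the file paths and wiki version below are placeholders:

    python rebl/ed/entity_disambiguation.py \
        --md_file md_results.parquet \
        --source_file collection.jsonl.gz \
        --out_file ed_results.parquet \
        --base_url ~/rel_data \
        --wiki_version wiki_2019 \
        --search_corefs lsh

The chosen value is forwarded to RelED as search_corefs, and each disambiguated mention now carries an is_coref flag, so the output Parquet schema widens from eight to nine columns; downstream readers of out_file should expect the extra is_coref column.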