Skip to content

Commit e2bbdba

Browse files
committed
fix(similarity_query): new param to set max num of results
1 parent fc2814c commit e2bbdba

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

wbtools/literature/corpus.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -174,7 +174,7 @@ def load(self, file_path: str) -> None:
174174
def query_papers_by_doc_similarity(self, query_docs: List[str], sentence_search: bool = False,
175175
remove_sections: List[PaperSections] = None,
176176
must_be_present: List[PaperSections] = None, path_to_model: str = None,
177-
average_match: bool = True) -> List[SimilarityResult]:
177+
average_match: bool = True, num_best: int = 10) -> List[SimilarityResult]:
178178
"""query papers in the corpus by similarity with the provided query documents, which can be fulltext documents
179179
or sentences
180180
@@ -186,6 +186,7 @@ def query_papers_by_doc_similarity(self, query_docs: List[str], sentence_search:
186186
sections
187187
path_to_model (str): path to word2vec model
188188
average_match (bool): merge query documents and calculate average similarity to them
189+
num_best (int): limit to the first n results by similarity score
189190
190191
Returns:
191192
List[SimilarityResult]: list of papers most similar to the provided query documents
@@ -198,12 +199,13 @@ def query_papers_by_doc_similarity(self, query_docs: List[str], sentence_search:
198199
split_sentences=sentence_search, remove_sections=remove_sections, must_be_present=must_be_present,
199200
lowercase=False, tokenize=False, remove_stopwords=False, remove_alpha=False)
200201
docsim_index, dictionary = get_softcosine_index(model=model, model_path=path_to_model,
201-
corpus_list_token=corpus_list_token)
202+
corpus_list_token=corpus_list_token, num_best=num_best)
202203
query_docs_preprocessed = [preprocess(doc=sentence, lower=True, tokenize=True, remove_stopwords=True,
203204
remove_alpha=True) for sentence in query_docs]
204205
sims = get_similar_documents(docsim_index, dictionary, query_docs_preprocessed, idx_paperid_map,
205206
average_match=average_match)
206-
return [SimilarityResult(score=sim.score, paper_id=sim.paper_id, match_idx=sim.match_idx,
207-
query_idx=sim.query_idx, match="\"" + corpus_list_token_orig[sim.match_idx] + "\"",
208-
query="\"" + (" ".join(query_docs) if average_match else query_docs[sim.query_idx]) +
209-
"\"") for sim in sims]
207+
results = [SimilarityResult(score=sim.score, paper_id=sim.paper_id, match_idx=sim.match_idx,
208+
query_idx=sim.query_idx, match="\"" + corpus_list_token_orig[sim.match_idx] + "\"",
209+
query="\"" + (" ".join(query_docs) if average_match else query_docs[sim.query_idx]
210+
) + "\"") for sim in sims]
211+
return results[0:num_best] if len(results) > num_best else results

0 commit comments

Comments
 (0)