@@ -174,7 +174,7 @@ def load(self, file_path: str) -> None:
174174 def query_papers_by_doc_similarity (self , query_docs : List [str ], sentence_search : bool = False ,
175175 remove_sections : List [PaperSections ] = None ,
176176 must_be_present : List [PaperSections ] = None , path_to_model : str = None ,
177- average_match : bool = True ) -> List [SimilarityResult ]:
177+ average_match : bool = True , num_best : int = 10 ) -> List [SimilarityResult ]:
178178 """query papers in the corpus by similarity with the provided query documents, which can be fulltext documents
179179 or sentences
180180
@@ -186,6 +186,7 @@ def query_papers_by_doc_similarity(self, query_docs: List[str], sentence_search:
186186 sections
187187 path_to_model (str): path to word2vec model
188188 average_match (bool): merge query documents and calculate average similarity to them
189+ num_best (int): limit to the first n results by similarity score
189190
190191 Returns:
191192 List[SimilarityResult]: list of papers most similar to the provided query documents
@@ -198,12 +199,13 @@ def query_papers_by_doc_similarity(self, query_docs: List[str], sentence_search:
198199 split_sentences = sentence_search , remove_sections = remove_sections , must_be_present = must_be_present ,
199200 lowercase = False , tokenize = False , remove_stopwords = False , remove_alpha = False )
200201 docsim_index , dictionary = get_softcosine_index (model = model , model_path = path_to_model ,
201- corpus_list_token = corpus_list_token )
202+ corpus_list_token = corpus_list_token , num_best = num_best )
202203 query_docs_preprocessed = [preprocess (doc = sentence , lower = True , tokenize = True , remove_stopwords = True ,
203204 remove_alpha = True ) for sentence in query_docs ]
204205 sims = get_similar_documents (docsim_index , dictionary , query_docs_preprocessed , idx_paperid_map ,
205206 average_match = average_match )
206- return [SimilarityResult (score = sim .score , paper_id = sim .paper_id , match_idx = sim .match_idx ,
207- query_idx = sim .query_idx , match = "\" " + corpus_list_token_orig [sim .match_idx ] + "\" " ,
208- query = "\" " + (" " .join (query_docs ) if average_match else query_docs [sim .query_idx ]) +
209- "\" " ) for sim in sims ]
207+ results = [SimilarityResult (score = sim .score , paper_id = sim .paper_id , match_idx = sim .match_idx ,
208+ query_idx = sim .query_idx , match = "\" " + corpus_list_token_orig [sim .match_idx ] + "\" " ,
209+ query = "\" " + (" " .join (query_docs ) if average_match else query_docs [sim .query_idx ]
210+ ) + "\" " ) for sim in sims ]
211+ return results [0 :num_best ] if len (results ) > num_best else results
0 commit comments