def _encode_single_document(self, text: str) -> "SparseVector":
    """
    Encode one document as a sparse vector of length-normalised term frequencies

    Args:
        text: the document to encode

    Returns: an ordered mapping with the keys
        indptrs: only present when `self.indptrs` is truthy; passed through from `self._tf`
        indices: term indices as returned by `self._tf`
        values: term frequencies normalised with `self.k1`, `self.b` and `self.avgdl`
    """
    positions, term_ids, raw_counts = self._tf(text)

    counts = np.array(raw_counts)
    doc_length = sum(counts)
    # Length penalty shared by every term of the document: grows with the
    # ratio of this document's length to the average length `self.avgdl`.
    length_penalty = self.k1 * (1.0 - self.b + self.b * (doc_length / self.avgdl))
    weights = counts / (length_penalty + counts)

    # Key order matters to consumers: optional "indptrs" first, then
    # "indices", then "values".
    fields = []
    if self.indptrs:
        fields.append(("indptrs", positions))
    fields.append(("indices", term_ids))
    fields.append(("values", weights.tolist()))
    return OrderedDict(fields)
def _tf(self, text: str) -> Tuple[List[int], List[int], List[int]]:
    """
    Calculate term frequency for a given text

    Args:
        text: a document to calculate term frequency for

    Returns: a tuple of three lists:
        indptrs: position of each term within `self.doc_freq`, aligned with
                 `indices`; an empty list when `self.indptrs` is disabled
        indices: list of term indices (hashed tokens)
        values: list of term frequencies

    Raises:
        ValueError: if `self.indptrs` is enabled and a term of `text` is not
                    present in `self.doc_freq`
    """
    counts = Counter(self._hash_text(token) for token in self._tokenizer(text))
    indices = list(counts)
    values = list(counts.values())

    if not self.indptrs:
        # Positions are opt-in. Skipping the lookup keeps the default path
        # independent of `doc_freq`, so terms missing from it (or an unset
        # `doc_freq`) cannot make plain term counting fail.
        return [], indices, values

    # One pass over `doc_freq` instead of a linear `.index()` scan per term.
    # `setdefault` keeps the first occurrence, matching `list.index`.
    # NOTE(review): iterating also accepts a mapping keyed by term index, in
    # which case positions follow key order - confirm the container type.
    vocabulary = self.doc_freq if self.doc_freq is not None else ()
    positions = {}
    for position, term in enumerate(vocabulary):
        positions.setdefault(term, position)

    try:
        indptrs = [positions[idx] for idx in indices]
    except KeyError as err:
        # Same exception type the `.index()` lookup raised for unknown terms.
        raise ValueError(f"term index {err.args[0]} is not in doc_freq") from None
    return indptrs, indices, values