diff --git a/graphrag/config/create_graphrag_config.py b/graphrag/config/create_graphrag_config.py
index c2acd7b1d6..1f9cff8bf1 100644
--- a/graphrag/config/create_graphrag_config.py
+++ b/graphrag/config/create_graphrag_config.py
@@ -369,6 +369,7 @@ def hydrate_parallelization_params(
         )
     with reader.envvar_prefix(Section.chunk), reader.use(values.get("chunks")):
         chunks_model = ChunkingConfig(
+            type=reader.str("type") or None,
             size=reader.int("size") or defs.CHUNK_SIZE,
             overlap=reader.int("overlap") or defs.CHUNK_OVERLAP,
             group_by_columns=reader.list("group_by_columns", "BY_COLUMNS")
diff --git a/graphrag/config/models/chunking_config.py b/graphrag/config/models/chunking_config.py
index ad7e3d0a9d..3f79bfc605 100644
--- a/graphrag/config/models/chunking_config.py
+++ b/graphrag/config/models/chunking_config.py
@@ -11,6 +11,7 @@
 class ChunkingConfig(BaseModel):
     """Configuration section for chunking."""
 
+    type: str | None = Field(description="The chunk strategy type to use.", default=None)
     size: int = Field(description="The chunk size to use.", default=defs.CHUNK_SIZE)
     overlap: int = Field(
         description="The chunk overlap to use.", default=defs.CHUNK_OVERLAP
@@ -29,7 +30,7 @@ def resolved_strategy(self) -> dict:
         from graphrag.index.verbs.text.chunk import ChunkStrategyType
 
         return self.strategy or {
-            "type": ChunkStrategyType.tokens,
+            "type": self.type or ChunkStrategyType.tokens,
             "chunk_size": self.size,
             "chunk_overlap": self.overlap,
             "group_by_columns": self.group_by_columns,
diff --git a/graphrag/index/verbs/text/chunk/strategies/chinese.py b/graphrag/index/verbs/text/chunk/strategies/chinese.py
new file mode 100644
index 0000000000..eace8f494e
--- /dev/null
+++ b/graphrag/index/verbs/text/chunk/strategies/chinese.py
@@ -0,0 +1,127 @@
+"""A module containing a 'chinese' chunking strategy based on recursive character splitting."""
+
+import logging
+import re
+from collections.abc import Iterable
+from typing import Any
+
+from datashaper import ProgressTicker
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from .typing import TextChunk
+
+DEFAULT_CHUNK_SIZE = 300  # chars
+DEFAULT_CHUNK_OVERLAP = 0  # chars
+
+log = logging.getLogger(__name__)
+
+
+def run(
+    input: list[str], args: dict[str, Any], tick: ProgressTicker
+) -> Iterable[TextChunk]:
+    """Chunk each input document using Chinese-aware sentence separators."""
+    log.info("using chinese text splitter configuration: %s", args)
+    keep_separator = args.get("keep_separator", True)
+    is_separator_regex = args.get("is_separator_regex", True)
+    chunk_size = args.get("chunk_size", DEFAULT_CHUNK_SIZE)
+    chunk_overlap = args.get("chunk_overlap", DEFAULT_CHUNK_OVERLAP)
+    text_splitter = ChineseRecursiveTextSplitter(
+        keep_separator=keep_separator,
+        is_separator_regex=is_separator_regex,
+        chunk_size=chunk_size,
+        chunk_overlap=chunk_overlap,
+    )
+    text_chunks = []
+    for doc_idx, text in enumerate(input):
+        chunks = text_splitter.split_text(text)
+        for chunk in chunks:
+            text_chunks.append(
+                TextChunk(text_chunk=chunk, source_doc_indices=[doc_idx])
+            )
+        tick(1)
+    return text_chunks
+
+
+def _split_text_with_regex_from_end(
+    text: str, separator: str, keep_separator: bool
+) -> list[str]:
+    """Split text on a separator, re-attaching each separator to the text before it."""
+    if separator:
+        if keep_separator:
+            # The parentheses in the pattern keep the delimiters in the result.
+            _splits = re.split(f"({separator})", text)
+            # Pair each text piece with the delimiter that follows it.
+            splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])]
+            if len(_splits) % 2 == 1:
+                splits += _splits[-1:]
+        else:
+            splits = re.split(separator, text)
+    else:
+        splits = list(text)
+    return [s for s in splits if s != ""]
+
+
+class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
+    """A recursive splitter that uses Chinese sentence punctuation as separators."""
+
+    def __init__(
+        self,
+        separators: list[str] | None = None,
+        keep_separator: bool = True,
+        is_separator_regex: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        """Create a new TextSplitter."""
+        super().__init__(keep_separator=keep_separator, **kwargs)
+        self._separators = separators or [
+            "\n\n",
+            "\n",
+            "。|!|?",
+            r"\.\s|\!\s|\?\s",
+            r";|;\s",
+            r",|,\s",
+        ]
+        self._is_separator_regex = is_separator_regex
+
+    def _split_text(self, text: str, separators: list[str]) -> list[str]:
+        """Split incoming text and return chunks."""
+        final_chunks = []
+        # Find the first separator that actually occurs in the text.
+        separator = separators[-1]
+        new_separators = []
+        for i, _s in enumerate(separators):
+            _separator = _s if self._is_separator_regex else re.escape(_s)
+            if _s == "":
+                separator = _s
+                break
+            if re.search(_separator, text):
+                separator = _s
+                new_separators = separators[i + 1 :]
+                break
+
+        _separator = separator if self._is_separator_regex else re.escape(separator)
+        splits = _split_text_with_regex_from_end(text, _separator, self._keep_separator)
+
+        # Merge short splits; recursively split any piece that is still too long.
+        _good_splits = []
+        _separator = "" if self._keep_separator else separator
+        for s in splits:
+            if self._length_function(s) < self._chunk_size:
+                _good_splits.append(s)
+            else:
+                if _good_splits:
+                    merged_text = self._merge_splits(_good_splits, _separator)
+                    final_chunks.extend(merged_text)
+                    _good_splits = []
+                if not new_separators:
+                    final_chunks.append(s)
+                else:
+                    final_chunks.extend(self._split_text(s, new_separators))
+        if _good_splits:
+            merged_text = self._merge_splits(_good_splits, _separator)
+            final_chunks.extend(merged_text)
+        return [
+            re.sub(r"\n{2,}", "\n", chunk.strip())
+            for chunk in final_chunks
+            if chunk.strip() != ""
+        ]
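Reviewer note: a minimal, self-contained sketch of the new splitter outside the pipeline. The sample text and sizes are illustrative only; it assumes this branch and `langchain` are installed.

```python
from graphrag.index.verbs.text.chunk.strategies.chinese import (
    ChineseRecursiveTextSplitter,
)

# chunk_size counts characters here (the splitter's default length function).
splitter = ChineseRecursiveTextSplitter(chunk_size=50, chunk_overlap=0)
text = "图谱索引先切分文本。然后抽取实体与关系!最后生成社区摘要?效果需要实验验证。"
for i, chunk in enumerate(splitter.split_text(text)):
    # Each chunk ends at Chinese sentence punctuation rather than mid-sentence.
    print(i, chunk)
```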
diff --git a/graphrag/index/verbs/text/chunk/text_chunk.py b/graphrag/index/verbs/text/chunk/text_chunk.py
index d8fab44f64..a41468b06f 100644
--- a/graphrag/index/verbs/text/chunk/text_chunk.py
+++ b/graphrag/index/verbs/text/chunk/text_chunk.py
@@ -35,6 +35,7 @@ class ChunkStrategyType(str, Enum):
 
     tokens = "tokens"
     sentence = "sentence"
+    chinese = "chinese"
 
     def __repr__(self):
         """Get a string representation."""
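With the enum value and the loader case below in place, the new strategy is reachable through the existing configuration path. A sketch of the resolution (values are illustrative):

```python
from graphrag.config.models.chunking_config import ChunkingConfig

# type=None preserves today's token-based default; type="chinese" routes
# index-time chunking through strategies/chinese.py via load_strategy().
config = ChunkingConfig(type="chinese", size=300, overlap=0)
print(config.resolved_strategy())
# {'type': 'chinese', 'chunk_size': 300, 'chunk_overlap': 0,
#  'group_by_columns': [...]}  # group_by_columns keeps its default
```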
@@ -157,6 +158,10 @@ def load_strategy(strategy: ChunkStrategyType) -> ChunkStrategy:
 
             bootstrap()
             return run_sentence
+        case ChunkStrategyType.chinese:
+            from .strategies.chinese import run as run_chinese
+
+            return run_chinese
         case _:
             msg = f"Unknown strategy: {strategy}"
             raise ValueError(msg)
diff --git a/graphrag/index/verbs/text/embed/strategies/openai.py b/graphrag/index/verbs/text/embed/strategies/openai.py
index 0658d604cc..de8a745bf1 100644
--- a/graphrag/index/verbs/text/embed/strategies/openai.py
+++ b/graphrag/index/verbs/text/embed/strategies/openai.py
@@ -31,6 +31,7 @@ async def run(
     if is_null(input):
         return TextEmbeddingResult(embeddings=None)
 
+    log.debug("text embedding input=%s", input)
     llm_config = args.get("llm", {})
     batch_size = args.get("batch_size", 16)
     batch_max_tokens = args.get("batch_max_tokens", 8191)
@@ -47,6 +48,7 @@ async def run(
         batch_max_tokens,
         splitter,
     )
+    log.debug("text embedding prepared input=%s, size=%s", texts, input_sizes)
     log.info(
         "embedding %d inputs via %d snippets using %d batches. max_batch_size=%d, max_tokens=%d",
         len(input),
@@ -95,6 +97,7 @@ async def _execute(
     semaphore: asyncio.Semaphore,
 ) -> list[list[float]]:
     async def embed(chunk: list[str]):
+        log.debug("text embedding chunk=%s", chunk)
         async with semaphore:
             chunk_embeddings = await llm(chunk)
             result = np.array(chunk_embeddings.output)
diff --git a/graphrag/llm/openai/utils.py b/graphrag/llm/openai/utils.py
index d529a8c069..c68274e506 100644
--- a/graphrag/llm/openai/utils.py
+++ b/graphrag/llm/openai/utils.py
@@ -106,6 +106,9 @@ def get_sleep_time_from_error(e: Any) -> float:
     sleep_time = 0.0
     if isinstance(e, RateLimitError) and _please_retry_after in str(e):
         # could be second or seconds
         sleep_time = int(str(e).split(_please_retry_after)[1].split(" second")[0])
+    elif isinstance(e, RateLimitError):
+        # some Chinese model APIs omit a retry-after hint; fall back to a fixed delay
+        sleep_time = 5.0
 
     return sleep_time
diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py
index aef7c2965b..c1b14d1a10 100644
--- a/graphrag/query/cli.py
+++ b/graphrag/query/cli.py
@@ -151,7 +151,7 @@ def run_local_search(
         response_type=response_type,
     )
 
-    result = search_engine.search(query=query)
+    result = search_engine.search(query=query, chunk_type=config.chunks.type)
     reporter.success(f"Local Search Response: {result.response}")
     return result.response
 
diff --git a/graphrag/query/context_builder/entity_extraction.py b/graphrag/query/context_builder/entity_extraction.py
index 82a0699cd8..d0179a8b02 100644
--- a/graphrag/query/context_builder/entity_extraction.py
+++ b/graphrag/query/context_builder/entity_extraction.py
@@ -12,6 +12,7 @@
 )
 from graphrag.query.llm.base import BaseTextEmbedding
 from graphrag.vector_stores import BaseVectorStore
+from graphrag.index.verbs.text.chunk.text_chunk import ChunkStrategyType
 
 
 class EntityVectorStoreKey(str, Enum):
@@ -42,6 +43,7 @@ def map_query_to_entities(
     exclude_entity_names: list[str] | None = None,
     k: int = 10,
     oversample_scaler: int = 2,
+    chunk_type: str = ChunkStrategyType.tokens,
 ) -> list[Entity]:
     """Extract entities that match a given query using semantic similarity of text embeddings of query and entity descriptions."""
     if include_entity_names is None:
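At query time the new `chunk_type` kwarg simply rides along from the CLI through `map_query_to_entities` into the embedder's `**kwargs`. A toy stand-in (not the real class) for the branch it ultimately selects in `OpenAIEmbedding.embed`:

```python
from graphrag.index.verbs.text.chunk.text_chunk import ChunkStrategyType


def embed(text: str, **kwargs) -> str:
    """Toy stand-in for OpenAIEmbedding.embed showing the added branch."""
    chunk_type = kwargs.get("chunk_type", ChunkStrategyType.tokens)
    if chunk_type == ChunkStrategyType.chinese:
        return "whole-text embedding: one API call, no token chunking"
    return "token-chunked embedding: weighted average of chunk vectors"


print(embed("你好,世界", chunk_type=ChunkStrategyType.chinese))
print(embed("hello world"))  # default path unchanged
```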
@@ -54,7 +56,7 @@ def map_query_to_entities(
     # oversample to account for excluded entities
     search_results = text_embedding_vectorstore.similarity_search_by_text(
         text=query,
-        text_embedder=lambda t: text_embedder.embed(t),
+        text_embedder=lambda t: text_embedder.embed(t, chunk_type=chunk_type),
         k=k * oversample_scaler,
     )
     for result in search_results:
diff --git a/graphrag/query/llm/oai/embedding.py b/graphrag/query/llm/oai/embedding.py
index f40372dbce..5cb5f60d6b 100644
--- a/graphrag/query/llm/oai/embedding.py
+++ b/graphrag/query/llm/oai/embedding.py
@@ -26,6 +26,7 @@
 )
 from graphrag.query.llm.text_utils import chunk_text
 from graphrag.query.progress import StatusReporter
+from graphrag.index.verbs.text.chunk.text_chunk import ChunkStrategyType
 
 
 class OpenAIEmbedding(BaseTextEmbedding, OpenAILLMImpl):
@@ -75,6 +76,19 @@ def embed(self, text: str, **kwargs: Any) -> list[float]:
         For text longer than max_tokens, chunk texts into max_tokens, embed each chunk, then combine using weighted average.
         Please refer to: https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
         """
+        chunk_type = kwargs.get("chunk_type", ChunkStrategyType.tokens)
+        if chunk_type == ChunkStrategyType.chinese:
+            try:
+                embedding, _ = self._embed_with_retry(text, **kwargs)
+                return embedding
+            # TODO: catch a more specific exception
+            except Exception as e:  # noqa BLE001
+                self._reporter.error(
+                    message="Error embedding chunk",
+                    details={self.__class__.__name__: str(e)},
+                )
+                # fall through to the token-chunking path below
+
         token_chunks = chunk_text(
             text=text, token_encoder=self.token_encoder, max_tokens=self.max_tokens
         )
@@ -91,8 +105,11 @@ def embed(self, text: str, **kwargs: Any) -> list[float]:
                 self._reporter.error(
                     message="Error embedding chunk",
                     details={self.__class__.__name__: str(e)},
                 )
-                continue
+
+        if sum(chunk_lens) == 0:
+            raise ValueError("The chunk_lens weights sum to zero; cannot compute a weighted average.")
+
         chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens)
         chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings)
         return chunk_embeddings.tolist()
@@ -114,6 +131,10 @@ async def aembed(self, text: str, **kwargs: Any) -> list[float]:
         embedding_results = [result for result in embedding_results if result[0]]
         chunk_embeddings = [result[0] for result in embedding_results]
         chunk_lens = [result[1] for result in embedding_results]
+
+        if sum(chunk_lens) == 0:
+            raise ValueError("The chunk_lens weights sum to zero; cannot compute a weighted average.")
+
         chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens)  # type: ignore
         chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings)
         return chunk_embeddings.tolist()
@@ -121,6 +142,9 @@
     def _embed_with_retry(
         self, text: str | tuple, **kwargs: Any
     ) -> tuple[list[float], int]:
+        # make sure the chunk_type kwarg is not forwarded to the embeddings API call
+        kwargs.pop("chunk_type", None)
+
         try:
             retryer = Retrying(
                 stop=stop_after_attempt(self.max_retries),
@@ -148,12 +172,14 @@ def _embed_with_retry(
             )
             return ([], 0)
         else:
-            # TODO: why not just throw in this case?
             return ([], 0)
 
     async def _aembed_with_retry(
         self, text: str | tuple, **kwargs: Any
     ) -> tuple[list[float], int]:
+        # make sure the chunk_type kwarg is not forwarded to the embeddings API call
+        kwargs.pop("chunk_type", None)
+
         try:
             retryer = AsyncRetrying(
                 stop=stop_after_attempt(self.max_retries),
@@ -178,5 +204,4 @@ async def _aembed_with_retry(
             )
             return ([], 0)
         else:
-            # TODO: why not just throw in this case?
             return ([], 0)
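For context on the new `sum(chunk_lens) == 0` guard: this is the weighted-average math from `embed`, shown standalone. Without the guard, `np.average` surfaces a bare `ZeroDivisionError` only after every chunk has already failed and been reported.

```python
import numpy as np

chunk_embeddings = [[0.1, 0.2], [0.3, 0.4]]  # one vector per token chunk
chunk_lens = [3, 1]  # token counts serve as weights

combined = np.average(chunk_embeddings, axis=0, weights=chunk_lens)
combined = combined / np.linalg.norm(combined)  # normalize, as embed() does
print(combined.tolist())

# The failure mode the guard converts into a clear ValueError:
try:
    np.average(chunk_embeddings, axis=0, weights=[0, 0])
except ZeroDivisionError as e:
    print(e)  # "Weights sum to zero, can't be normalized"
```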