1 change: 1 addition & 0 deletions graphrag/config/create_graphrag_config.py
@@ -369,6 +369,7 @@ def hydrate_parallelization_params(
)
with reader.envvar_prefix(Section.chunk), reader.use(values.get("chunks")):
chunks_model = ChunkingConfig(
type=reader.str("type") or None,
size=reader.int("size") or defs.CHUNK_SIZE,
overlap=reader.int("overlap") or defs.CHUNK_OVERLAP,
group_by_columns=reader.list("group_by_columns", "BY_COLUMNS")
3 changes: 2 additions & 1 deletion graphrag/config/models/chunking_config.py
@@ -11,6 +11,7 @@
class ChunkingConfig(BaseModel):
"""Configuration section for chunking."""

type: str | None = Field(description="The chunk strategy type to use.", default=None)
size: int = Field(description="The chunk size to use.", default=defs.CHUNK_SIZE)
overlap: int = Field(
description="The chunk overlap to use.", default=defs.CHUNK_OVERLAP
@@ -29,7 +30,7 @@ def resolved_strategy(self) -> dict:
from graphrag.index.verbs.text.chunk import ChunkStrategyType

return self.strategy or {
"type": ChunkStrategyType.tokens,
"type": self.type or ChunkStrategyType.tokens,
"chunk_size": self.size,
"chunk_overlap": self.overlap,
"group_by_columns": self.group_by_columns,
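For reference, a minimal sketch of how the new field resolves (assuming the remaining fields keep their defaults; the values are illustrative):

    from graphrag.config.models.chunking_config import ChunkingConfig

    chunks = ChunkingConfig(type="chinese", size=300, overlap=0)
    # With type unset, `self.type or ChunkStrategyType.tokens` falls back
    # to the token-based strategy.
    assert chunks.resolved_strategy()["type"] == "chinese"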
125 changes: 125 additions & 0 deletions graphrag/index/verbs/text/chunk/strategies/chinese.py
@@ -0,0 +1,125 @@
import logging
import re
from collections.abc import Iterable
from typing import Any, List, Optional

from datashaper import ProgressTicker
from langchain.text_splitter import RecursiveCharacterTextSplitter

from .typing import TextChunk

DEFAULT_CHUNK_SIZE = 300  # chars
DEFAULT_CHUNK_OVERLAP = 0  # chars

log = logging.getLogger(__name__)


def run(
    input: list[str], args: dict[str, Any], tick: ProgressTicker
) -> Iterable[TextChunk]:
    """Chunk Chinese text into smaller pieces using a recursive character splitter."""
    log.info("using chinese text splitter configuration: %s", args)
    keep_separator = args.get("keep_separator", True)
    is_separator_regex = args.get("is_separator_regex", True)
    chunk_size = args.get("chunk_size", DEFAULT_CHUNK_SIZE)
    chunk_overlap = args.get("chunk_overlap", DEFAULT_CHUNK_OVERLAP)
    text_splitter = ChineseRecursiveTextSplitter(
        keep_separator=keep_separator,
        is_separator_regex=is_separator_regex,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    text_chunks = []
    for doc_idx, text in enumerate(input):
        chunks = text_splitter.split_text(text)
        for chunk in chunks:
            text_chunks.append(
                TextChunk(
                    text_chunk=chunk,
                    source_doc_indices=[doc_idx],
                )
            )
        tick(1)
    return text_chunks


def _split_text_with_regex_from_end(
    text: str, separator: str, keep_separator: bool
) -> List[str]:
    """Split text on a regex separator, keeping each separator attached to the preceding segment."""
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            _splits = re.split(f"({separator})", text)
            # Re-attach each delimiter to the text that precedes it.
            splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])]
            if len(_splits) % 2 == 1:
                splits += _splits[-1:]
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s != ""]
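# Worked example of the pairing above (illustrative input):
#   re.split(r"(。|!)", "你好。再见!") -> ["你好", "。", "再见", "!", ""]
# zipping the even/odd slices re-attaches each delimiter to the preceding
# segment, giving ["你好。", "再见!"]; the final filter drops the empty tail.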


class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = True,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter tuned for Chinese punctuation."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        # Ordered from coarsest to finest: paragraph breaks, line breaks,
        # sentence-ending punctuation (full- and half-width), then clauses.
        self._separators = separators or [
            "\n\n",
            "\n",
            "。|!|?",
            r"\.\s|\!\s|\?\s",
            r";|;\s",
            r",|,\s",
        ]
        self._is_separator_regex = is_separator_regex

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Pick the first separator that appears in the text; keep the
        # remaining, finer-grained separators for recursive splitting.
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1:]
                break

        _separator = separator if self._is_separator_regex else re.escape(separator)
        splits = _split_text_with_regex_from_end(text, _separator, self._keep_separator)

        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        # Collapse runs of blank lines and drop empty chunks.
        return [
            re.sub(r"\n{2,}", "\n", chunk.strip())
            for chunk in final_chunks
            if chunk.strip() != ""
        ]
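A minimal usage sketch of the splitter (the sample text and sizes are illustrative, not part of the change):

    splitter = ChineseRecursiveTextSplitter(chunk_size=50, chunk_overlap=0)
    for chunk in splitter.split_text("今天天气很好。我们去公园散步吧!你觉得怎么样?"):
        print(chunk)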
25 changes: 15 additions & 10 deletions graphrag/index/verbs/text/chunk/text_chunk.py
@@ -35,6 +35,7 @@ class ChunkStrategyType(str, Enum):

tokens = "tokens"
sentence = "sentence"
chinese = "chinese"

def __repr__(self):
"""Get a string representation."""
@@ -43,12 +44,12 @@ def __repr__(self):

@verb(name="chunk")
def chunk(
    input: VerbInput,
    column: str,
    to: str,
    callbacks: VerbCallbacks,
    strategy: dict[str, Any] | None = None,
    **_kwargs,
) -> TableContainer:
"""
Chunk a piece of text into smaller pieces.
@@ -106,10 +107,10 @@ def chunk(


def run_strategy(
    strategy: ChunkStrategy,
    input: ChunkInput,
    strategy_args: dict[str, Any],
    tick: ProgressTicker,
) -> list[str | tuple[list[str] | None, str, int]]:
"""Run strategy method definition."""
if isinstance(input, str):
@@ -157,6 +158,10 @@ def load_strategy(strategy: ChunkStrategyType) -> ChunkStrategy:

bootstrap()
return run_sentence
case ChunkStrategyType.chinese:
from .strategies.chinese import run as run_chinese

return run_chinese
case _:
msg = f"Unknown strategy: {strategy}"
raise ValueError(msg)
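A sketch of the new dispatch path end to end (the input text and the no-op ticker are stand-ins; a real pipeline passes a datashaper ProgressTicker):

    from graphrag.index.verbs.text.chunk.text_chunk import (
        ChunkStrategyType,
        load_strategy,
    )

    run_chinese = load_strategy(ChunkStrategyType.chinese)
    chunks = run_chinese(
        ["今天天气很好。我们去公园散步吧!"],
        {"chunk_size": 100, "chunk_overlap": 0},
        lambda _: None,  # no-op ticker
    )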
3 changes: 3 additions & 0 deletions graphrag/index/verbs/text/embed/strategies/openai.py
@@ -31,6 +31,7 @@ async def run(
if is_null(input):
return TextEmbeddingResult(embeddings=None)

log.debug("text embedding input=%s", input)
llm_config = args.get("llm", {})
batch_size = args.get("batch_size", 16)
batch_max_tokens = args.get("batch_max_tokens", 8191)
@@ -47,6 +48,7 @@
batch_max_tokens,
splitter,
)
log.debug("text embedding prepared input=%s, size=%s", texts, input_sizes)
log.info(
"embedding %d inputs via %d snippets using %d batches. max_batch_size=%d, max_tokens=%d",
len(input),
@@ -95,6 +97,7 @@ async def _execute(
semaphore: asyncio.Semaphore,
) -> list[list[float]]:
async def embed(chunk: list[str]):
log.debug("text embedding chunk=%s", chunk)
async with semaphore:
chunk_embeddings = await llm(chunk)
result = np.array(chunk_embeddings.output)
3 changes: 3 additions & 0 deletions graphrag/llm/openai/utils.py
@@ -106,6 +106,9 @@ def get_sleep_time_from_error(e: Any) -> float:
if isinstance(e, RateLimitError) and _please_retry_after in str(e):
# could be second or seconds
sleep_time = int(str(e).split(_please_retry_after)[1].split(" second")[0])
elif isinstance(e, RateLimitError):
    # some Chinese model APIs return no retry-after hint,
    # so fall back to a fixed sleep
    sleep_time = 5.0

return sleep_time

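For illustration, the branches above behave roughly as follows (assuming `_please_retry_after` is a literal such as "Please retry after"):

    # RateLimitError("... Please retry after 7 seconds.") -> 7 (parsed from the message)
    # RateLimitError without a retry-after hint           -> 5.0 (fixed fallback)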
2 changes: 1 addition & 1 deletion graphrag/query/cli.py
@@ -151,7 +151,7 @@ def run_local_search(
response_type=response_type,
)

result = search_engine.search(query=query)
result = search_engine.search(query=query, chunk_type=config.chunks.type)
reporter.success(f"Local Search Response: {result.response}")
return result.response

4 changes: 3 additions & 1 deletion graphrag/query/context_builder/entity_extraction.py
@@ -12,6 +12,7 @@
)
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.vector_stores import BaseVectorStore
from graphrag.index.verbs.text.chunk.text_chunk import ChunkStrategyType


class EntityVectorStoreKey(str, Enum):
@@ -42,6 +43,7 @@ def map_query_to_entities(
exclude_entity_names: list[str] | None = None,
k: int = 10,
oversample_scaler: int = 2,
chunk_type: str = ChunkStrategyType.tokens,
) -> list[Entity]:
"""Extract entities that match a given query using semantic similarity of text embeddings of query and entity descriptions."""
if include_entity_names is None:
@@ -54,7 +56,7 @@
# oversample to account for excluded entities
search_results = text_embedding_vectorstore.similarity_search_by_text(
text=query,
text_embedder=lambda t: text_embedder.embed(t),
text_embedder=lambda t: text_embedder.embed(t, chunk_type=chunk_type),
k=k * oversample_scaler,
)
for result in search_results:
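A hedged sketch of a direct call with the new parameter (the leading argument names are inferred from the function body; `store` and `embedder` are hypothetical instances):

    entities = map_query_to_entities(
        query="谁是主角?",
        text_embedding_vectorstore=store,
        text_embedder=embedder,
        chunk_type="chinese",  # defaults to ChunkStrategyType.tokens
    )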
32 changes: 29 additions & 3 deletions graphrag/query/llm/oai/embedding.py
@@ -26,6 +26,7 @@
)
from graphrag.query.llm.text_utils import chunk_text
from graphrag.query.progress import StatusReporter
from graphrag.index.verbs.text.chunk.text_chunk import ChunkStrategyType


class OpenAIEmbedding(BaseTextEmbedding, OpenAILLMImpl):
@@ -75,6 +76,18 @@ def embed(self, text: str, **kwargs: Any) -> list[float]:
For text longer than max_tokens, chunk texts into max_tokens, embed each chunk, then combine using weighted average.
Please refer to: https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
"""
chunk_type = kwargs.get("chunk_type", ChunkStrategyType.tokens)
if chunk_type == ChunkStrategyType.chinese:
    try:
        # Chinese chunks are embedded in a single call; on failure we
        # fall through to the token-based chunking path below.
        embedding, _ = self._embed_with_retry(text, **kwargs)
        return embedding
    # TODO: catch a more specific exception
    except Exception as e:  # noqa BLE001
        self._reporter.error(
            message="Error embedding chunk",
            details={self.__class__.__name__: str(e)},
        )

token_chunks = chunk_text(
text=text, token_encoder=self.token_encoder, max_tokens=self.max_tokens
)
@@ -91,8 +104,11 @@ def embed(self, text: str, **kwargs: Any) -> list[float]:
message="Error embedding chunk",
details={self.__class__.__name__: str(e)},
)

continue

if sum(chunk_lens) == 0:
    raise ValueError("Sum of chunk_lens weights is zero; cannot compute the weighted average.")

chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens)
chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings)
return chunk_embeddings.tolist()
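# Hedged usage note: callers opt into the single-call path with, e.g.,
#   embedder.embed(text, chunk_type=ChunkStrategyType.chinese)
# while the default (ChunkStrategyType.tokens) keeps the chunk-and-average
# behavior above.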
@@ -114,13 +130,21 @@ async def aembed(self, text: str, **kwargs: Any) -> list[float]:
embedding_results = [result for result in embedding_results if result[0]]
chunk_embeddings = [result[0] for result in embedding_results]
chunk_lens = [result[1] for result in embedding_results]

if sum(chunk_lens) == 0:
    raise ValueError("Sum of chunk_lens weights is zero; cannot compute the weighted average.")

chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens) # type: ignore
chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings)
return chunk_embeddings.tolist()

def _embed_with_retry(
self, text: str | tuple, **kwargs: Any
) -> tuple[list[float], int]:
# ensure chunk_type is not forwarded to the underlying create() call
kwargs.pop("chunk_type", None)

try:
retryer = Retrying(
stop=stop_after_attempt(self.max_retries),
@@ -148,12 +172,15 @@ def _embed_with_retry(
)
return ([], 0)
else:
# TODO: why not just throw in this case?
return ([], 0)

async def _aembed_with_retry(
self, text: str | tuple, **kwargs: Any
) -> tuple[list[float], int]:
# ensure chunk_type is not forwarded to the underlying create() call
kwargs.pop("chunk_type", None)

try:
retryer = AsyncRetrying(
stop=stop_after_attempt(self.max_retries),
Expand All @@ -178,5 +205,4 @@ async def _aembed_with_retry(
)
return ([], 0)
else:
# TODO: why not just throw in this case?
return ([], 0)