From 2263c344b0d82502438e4cb02b5cf00918bc5ab0 Mon Sep 17 00:00:00 2001 From: sujunjun Date: Tue, 16 Jul 2024 10:56:10 +0800 Subject: [PATCH 1/6] Chinese Splitter --- graphrag/config/create_graphrag_config.py | 2 ++ graphrag/config/models/chunking_config.py | 5 ++-- graphrag/index/verbs/text/chunk/text_chunk.py | 25 +++++++++++-------- graphrag/llm/openai/utils.py | 3 +++ 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/graphrag/config/create_graphrag_config.py b/graphrag/config/create_graphrag_config.py index c2acd7b1d6..2017548853 100644 --- a/graphrag/config/create_graphrag_config.py +++ b/graphrag/config/create_graphrag_config.py @@ -55,6 +55,7 @@ UmapConfig, ) from .read_dotenv import read_dotenv +from ..index.verbs.text.chunk import ChunkStrategyType InputModelValidator = TypeAdapter(GraphRagConfigInput) @@ -369,6 +370,7 @@ def hydrate_parallelization_params( ) with reader.envvar_prefix(Section.chunk), reader.use(values.get("chunks")): chunks_model = ChunkingConfig( + type=reader.str("type") or ChunkStrategyType.tokens, size=reader.int("size") or defs.CHUNK_SIZE, overlap=reader.int("overlap") or defs.CHUNK_OVERLAP, group_by_columns=reader.list("group_by_columns", "BY_COLUMNS") diff --git a/graphrag/config/models/chunking_config.py b/graphrag/config/models/chunking_config.py index ad7e3d0a9d..8ded61446e 100644 --- a/graphrag/config/models/chunking_config.py +++ b/graphrag/config/models/chunking_config.py @@ -6,11 +6,13 @@ from pydantic import BaseModel, Field import graphrag.config.defaults as defs +from graphrag.index.verbs.text.chunk import ChunkStrategyType class ChunkingConfig(BaseModel): """Configuration section for chunking.""" + type: str = Field(description="The Split type", default=ChunkStrategyType.tokens) size: int = Field(description="The chunk size to use.", default=defs.CHUNK_SIZE) overlap: int = Field( description="The chunk overlap to use.", default=defs.CHUNK_OVERLAP @@ -26,10 +28,9 @@ class ChunkingConfig(BaseModel): def resolved_strategy(self) -> dict: """Get the resolved chunking strategy.""" - from graphrag.index.verbs.text.chunk import ChunkStrategyType return self.strategy or { - "type": ChunkStrategyType.tokens, + "type": self.type, "chunk_size": self.size, "chunk_overlap": self.overlap, "group_by_columns": self.group_by_columns, diff --git a/graphrag/index/verbs/text/chunk/text_chunk.py b/graphrag/index/verbs/text/chunk/text_chunk.py index d8fab44f64..a41468b06f 100644 --- a/graphrag/index/verbs/text/chunk/text_chunk.py +++ b/graphrag/index/verbs/text/chunk/text_chunk.py @@ -35,6 +35,7 @@ class ChunkStrategyType(str, Enum): tokens = "tokens" sentence = "sentence" + chinese = "chinese" def __repr__(self): """Get a string representation.""" @@ -43,12 +44,12 @@ def __repr__(self): @verb(name="chunk") def chunk( - input: VerbInput, - column: str, - to: str, - callbacks: VerbCallbacks, - strategy: dict[str, Any] | None = None, - **_kwargs, + input: VerbInput, + column: str, + to: str, + callbacks: VerbCallbacks, + strategy: dict[str, Any] | None = None, + **_kwargs, ) -> TableContainer: """ Chunk a piece of text into smaller pieces. @@ -106,10 +107,10 @@ def chunk( def run_strategy( - strategy: ChunkStrategy, - input: ChunkInput, - strategy_args: dict[str, Any], - tick: ProgressTicker, + strategy: ChunkStrategy, + input: ChunkInput, + strategy_args: dict[str, Any], + tick: ProgressTicker, ) -> list[str | tuple[list[str] | None, str, int]]: """Run strategy method definition.""" if isinstance(input, str): @@ -157,6 +158,10 @@ def load_strategy(strategy: ChunkStrategyType) -> ChunkStrategy: bootstrap() return run_sentence + case ChunkStrategyType.chinese: + from .strategies.chinese import run as run_chinese + + return run_chinese case _: msg = f"Unknown strategy: {strategy}" raise ValueError(msg) diff --git a/graphrag/llm/openai/utils.py b/graphrag/llm/openai/utils.py index d529a8c069..c68274e506 100644 --- a/graphrag/llm/openai/utils.py +++ b/graphrag/llm/openai/utils.py @@ -106,6 +106,9 @@ def get_sleep_time_from_error(e: Any) -> float: if isinstance(e, RateLimitError) and _please_retry_after in str(e): # could be second or seconds sleep_time = int(str(e).split(_please_retry_after)[1].split(" second")[0]) + elif isinstance(e, RateLimitError): + # for Chinese model api cloud be fixed second + sleep_time = 5.0 return sleep_time From dba9411871e3bce730f71a26e8edcc5a735d8963 Mon Sep 17 00:00:00 2001 From: sujunjun Date: Tue, 16 Jul 2024 11:14:04 +0800 Subject: [PATCH 2/6] support chinese splitter --- .../verbs/text/chunk/strategies/chinese.py | 195 ++++++++++++++++++ 1 file changed, 195 insertions(+) create mode 100644 graphrag/index/verbs/text/chunk/strategies/chinese.py diff --git a/graphrag/index/verbs/text/chunk/strategies/chinese.py b/graphrag/index/verbs/text/chunk/strategies/chinese.py new file mode 100644 index 0000000000..bf51bffe02 --- /dev/null +++ b/graphrag/index/verbs/text/chunk/strategies/chinese.py @@ -0,0 +1,195 @@ +from collections.abc import Iterable + +from datashaper import ProgressTicker + +from .typing import TextChunk + +import re + +from langchain.text_splitter import CharacterTextSplitter + +from typing import Any, List, Optional + +from langchain.text_splitter import RecursiveCharacterTextSplitter + +DEFAULT_CHUNK_SIZE = 500 # tokens +DEFAULT_CHUNK_OVERLAP = 0 # tokens + + +def run( + input: list[str], args: dict[str, Any], tick: ProgressTicker +) -> Iterable[TextChunk]: + keep_separator = args.get("keep_separator", True) + is_separator_regex = args.get("is_separator_regex", True) + chunk_size = args.get("chunk_size", DEFAULT_CHUNK_SIZE) + chunk_overlap = args.get("chunk_overlap", DEFAULT_CHUNK_OVERLAP) + text_splitter = ChineseRecursiveTextSplitter( + keep_separator=keep_separator, is_separator_regex=is_separator_regex, chunk_size=chunk_size, + chunk_overlap=chunk_overlap + ) + textChunks = [] + for doc_idx, text in enumerate(input): + chunks = text_splitter.split_text(text) + for chunk in chunks: + textChunks.append(TextChunk( + text_chunk=chunk, + source_doc_indices=[doc_idx], + )) + return textChunks + + +def _split_text_with_regex_from_end( + text: str, separator: str, keep_separator: bool +) -> List[str]: + # Now that we have the separator, split the text + if separator: + if keep_separator: + # The parentheses in the pattern keep the delimiters in the result. + _splits = re.split(f"({separator})", text) + splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])] + if len(_splits) % 2 == 1: + splits += _splits[-1:] + # splits = [_splits[0]] + splits + else: + splits = re.split(separator, text) + else: + splits = list(text) + return [s for s in splits if s != ""] + + +class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter): + def __init__( + self, + separators: Optional[List[str]] = None, + keep_separator: bool = True, + is_separator_regex: bool = True, + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(keep_separator=keep_separator, **kwargs) + self._separators = separators or [ + "\n\n", + "\n", + "。|!|?", + "\.\s|\!\s|\?\s", + ";|;\s", + ",|,\s", + ] + self._is_separator_regex = is_separator_regex + + def _split_text(self, text: str, separators: List[str]) -> List[str]: + """Split incoming text and return chunks.""" + final_chunks = [] + # Get appropriate separator to use + separator = separators[-1] + new_separators = [] + for i, _s in enumerate(separators): + _separator = _s if self._is_separator_regex else re.escape(_s) + if _s == "": + separator = _s + break + if re.search(_separator, text): + separator = _s + new_separators = separators[i + 1:] + break + + _separator = separator if self._is_separator_regex else re.escape(separator) + splits = _split_text_with_regex_from_end(text, _separator, self._keep_separator) + + # Now go merging things, recursively splitting longer texts. + _good_splits = [] + _separator = "" if self._keep_separator else separator + for s in splits: + if self._length_function(s) < self._chunk_size: + _good_splits.append(s) + else: + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + _good_splits = [] + if not new_separators: + final_chunks.append(s) + else: + other_info = self._split_text(s, new_separators) + final_chunks.extend(other_info) + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + return [ + re.sub(r"\n{2,}", "\n", chunk.strip()) + for chunk in final_chunks + if chunk.strip() != "" + ] + + +class ChineseTextSplitter(CharacterTextSplitter): + def __init__(self, pdf: bool = False, sentence_size: int = 250, **kwargs): + super().__init__(**kwargs) + self.pdf = pdf + self.sentence_size = sentence_size + + def split_text1(self, text: str) -> List[str]: + if self.pdf: + text = re.sub(r"\n{3,}", "\n", text) + text = re.sub("\s", " ", text) + text = text.replace("\n\n", "") + sent_sep_pattern = re.compile( + '([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))' + ) # del :; + sent_list = [] + for ele in sent_sep_pattern.split(text): + if sent_sep_pattern.match(ele) and sent_list: + sent_list[-1] += ele + elif ele: + sent_list.append(ele) + return sent_list + + def split_text(self, text: str) -> List[str]: ##此处需要进一步优化逻辑 + if self.pdf: + text = re.sub(r"\n{3,}", r"\n", text) + text = re.sub("\s", " ", text) + text = re.sub("\n\n", "", text) + + text = re.sub(r"([;;.!?。!?\?])([^”’])", r"\1\n\2", text) # 单字符断句符 + text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text) # 英文省略号 + text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text) # 中文省略号 + text = re.sub( + r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r"\1\n\2", text + ) + # 如果双引号前有终止符,那么双引号才是句子的终点,把分句符\n放到双引号后,注意前面的几句都小心保留了双引号 + text = text.rstrip() # 段尾如果有多余的\n就去掉它 + # 很多规则中会考虑分号;,但是这里我把它忽略不计,破折号、英文双引号等同样忽略,需要的再做些简单调整即可。 + ls = [i for i in text.split("\n") if i] + for ele in ls: + if len(ele) > self.sentence_size: + ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r"\1\n\2", ele) + ele1_ls = ele1.split("\n") + for ele_ele1 in ele1_ls: + if len(ele_ele1) > self.sentence_size: + ele_ele2 = re.sub( + r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', + r"\1\n\2", + ele_ele1, + ) + ele2_ls = ele_ele2.split("\n") + for ele_ele2 in ele2_ls: + if len(ele_ele2) > self.sentence_size: + ele_ele3 = re.sub( + '( ["’”」』]{0,2})([^ ])', r"\1\n\2", ele_ele2 + ) + ele2_id = ele2_ls.index(ele_ele2) + ele2_ls = ( + ele2_ls[:ele2_id] + + [i for i in ele_ele3.split("\n") if i] + + ele2_ls[ele2_id + 1:] + ) + ele_id = ele1_ls.index(ele_ele1) + ele1_ls = ( + ele1_ls[:ele_id] + + [i for i in ele2_ls if i] + + ele1_ls[ele_id + 1:] + ) + + id = ls.index(ele) + ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:] + return ls From d2602a3bf15a214329dd916065c2ed6f1e5e12f2 Mon Sep 17 00:00:00 2001 From: sujunjun Date: Tue, 16 Jul 2024 14:19:13 +0800 Subject: [PATCH 3/6] chinese --- graphrag/config/create_graphrag_config.py | 3 +-- graphrag/config/models/chunking_config.py | 6 +++--- graphrag/index/verbs/text/chunk/strategies/chinese.py | 5 +++++ 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/graphrag/config/create_graphrag_config.py b/graphrag/config/create_graphrag_config.py index 2017548853..1f9cff8bf1 100644 --- a/graphrag/config/create_graphrag_config.py +++ b/graphrag/config/create_graphrag_config.py @@ -55,7 +55,6 @@ UmapConfig, ) from .read_dotenv import read_dotenv -from ..index.verbs.text.chunk import ChunkStrategyType InputModelValidator = TypeAdapter(GraphRagConfigInput) @@ -370,7 +369,7 @@ def hydrate_parallelization_params( ) with reader.envvar_prefix(Section.chunk), reader.use(values.get("chunks")): chunks_model = ChunkingConfig( - type=reader.str("type") or ChunkStrategyType.tokens, + type=reader.str("type") or None, size=reader.int("size") or defs.CHUNK_SIZE, overlap=reader.int("overlap") or defs.CHUNK_OVERLAP, group_by_columns=reader.list("group_by_columns", "BY_COLUMNS") diff --git a/graphrag/config/models/chunking_config.py b/graphrag/config/models/chunking_config.py index 8ded61446e..3f79bfc605 100644 --- a/graphrag/config/models/chunking_config.py +++ b/graphrag/config/models/chunking_config.py @@ -6,13 +6,12 @@ from pydantic import BaseModel, Field import graphrag.config.defaults as defs -from graphrag.index.verbs.text.chunk import ChunkStrategyType class ChunkingConfig(BaseModel): """Configuration section for chunking.""" - type: str = Field(description="The Split type", default=ChunkStrategyType.tokens) + type: str = Field(description="The Split type", default=None) size: int = Field(description="The chunk size to use.", default=defs.CHUNK_SIZE) overlap: int = Field( description="The chunk overlap to use.", default=defs.CHUNK_OVERLAP @@ -28,9 +27,10 @@ class ChunkingConfig(BaseModel): def resolved_strategy(self) -> dict: """Get the resolved chunking strategy.""" + from graphrag.index.verbs.text.chunk import ChunkStrategyType return self.strategy or { - "type": self.type, + "type": self.type or ChunkStrategyType.tokens, "chunk_size": self.size, "chunk_overlap": self.overlap, "group_by_columns": self.group_by_columns, diff --git a/graphrag/index/verbs/text/chunk/strategies/chinese.py b/graphrag/index/verbs/text/chunk/strategies/chinese.py index bf51bffe02..c1fa98e872 100644 --- a/graphrag/index/verbs/text/chunk/strategies/chinese.py +++ b/graphrag/index/verbs/text/chunk/strategies/chinese.py @@ -5,6 +5,7 @@ from .typing import TextChunk import re +import logging from langchain.text_splitter import CharacterTextSplitter @@ -15,10 +16,13 @@ DEFAULT_CHUNK_SIZE = 500 # tokens DEFAULT_CHUNK_OVERLAP = 0 # tokens +log = logging.getLogger(__name__) + def run( input: list[str], args: dict[str, Any], tick: ProgressTicker ) -> Iterable[TextChunk]: + log.info("using chinese text splitter configuration: %s", args) keep_separator = args.get("keep_separator", True) is_separator_regex = args.get("is_separator_regex", True) chunk_size = args.get("chunk_size", DEFAULT_CHUNK_SIZE) @@ -35,6 +39,7 @@ def run( text_chunk=chunk, source_doc_indices=[doc_idx], )) + tick(1) return textChunks From a2c81adb8ca7613f0436b3e3d52d3fc212018948 Mon Sep 17 00:00:00 2001 From: sujunjun Date: Tue, 16 Jul 2024 14:22:05 +0800 Subject: [PATCH 4/6] chinese --- .../verbs/text/chunk/strategies/chinese.py | 81 +------------------ 1 file changed, 3 insertions(+), 78 deletions(-) diff --git a/graphrag/index/verbs/text/chunk/strategies/chinese.py b/graphrag/index/verbs/text/chunk/strategies/chinese.py index c1fa98e872..eace8f494e 100644 --- a/graphrag/index/verbs/text/chunk/strategies/chinese.py +++ b/graphrag/index/verbs/text/chunk/strategies/chinese.py @@ -7,14 +7,12 @@ import re import logging -from langchain.text_splitter import CharacterTextSplitter +from langchain.text_splitter import RecursiveCharacterTextSplitter from typing import Any, List, Optional -from langchain.text_splitter import RecursiveCharacterTextSplitter - -DEFAULT_CHUNK_SIZE = 500 # tokens -DEFAULT_CHUNK_OVERLAP = 0 # tokens +DEFAULT_CHUNK_SIZE = 300 # chars +DEFAULT_CHUNK_OVERLAP = 0 # chars log = logging.getLogger(__name__) @@ -125,76 +123,3 @@ def _split_text(self, text: str, separators: List[str]) -> List[str]: for chunk in final_chunks if chunk.strip() != "" ] - - -class ChineseTextSplitter(CharacterTextSplitter): - def __init__(self, pdf: bool = False, sentence_size: int = 250, **kwargs): - super().__init__(**kwargs) - self.pdf = pdf - self.sentence_size = sentence_size - - def split_text1(self, text: str) -> List[str]: - if self.pdf: - text = re.sub(r"\n{3,}", "\n", text) - text = re.sub("\s", " ", text) - text = text.replace("\n\n", "") - sent_sep_pattern = re.compile( - '([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))' - ) # del :; - sent_list = [] - for ele in sent_sep_pattern.split(text): - if sent_sep_pattern.match(ele) and sent_list: - sent_list[-1] += ele - elif ele: - sent_list.append(ele) - return sent_list - - def split_text(self, text: str) -> List[str]: ##此处需要进一步优化逻辑 - if self.pdf: - text = re.sub(r"\n{3,}", r"\n", text) - text = re.sub("\s", " ", text) - text = re.sub("\n\n", "", text) - - text = re.sub(r"([;;.!?。!?\?])([^”’])", r"\1\n\2", text) # 单字符断句符 - text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text) # 英文省略号 - text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text) # 中文省略号 - text = re.sub( - r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r"\1\n\2", text - ) - # 如果双引号前有终止符,那么双引号才是句子的终点,把分句符\n放到双引号后,注意前面的几句都小心保留了双引号 - text = text.rstrip() # 段尾如果有多余的\n就去掉它 - # 很多规则中会考虑分号;,但是这里我把它忽略不计,破折号、英文双引号等同样忽略,需要的再做些简单调整即可。 - ls = [i for i in text.split("\n") if i] - for ele in ls: - if len(ele) > self.sentence_size: - ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r"\1\n\2", ele) - ele1_ls = ele1.split("\n") - for ele_ele1 in ele1_ls: - if len(ele_ele1) > self.sentence_size: - ele_ele2 = re.sub( - r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', - r"\1\n\2", - ele_ele1, - ) - ele2_ls = ele_ele2.split("\n") - for ele_ele2 in ele2_ls: - if len(ele_ele2) > self.sentence_size: - ele_ele3 = re.sub( - '( ["’”」』]{0,2})([^ ])', r"\1\n\2", ele_ele2 - ) - ele2_id = ele2_ls.index(ele_ele2) - ele2_ls = ( - ele2_ls[:ele2_id] - + [i for i in ele_ele3.split("\n") if i] - + ele2_ls[ele2_id + 1:] - ) - ele_id = ele1_ls.index(ele_ele1) - ele1_ls = ( - ele1_ls[:ele_id] - + [i for i in ele2_ls if i] - + ele1_ls[ele_id + 1:] - ) - - id = ls.index(ele) - ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:] - return ls From 7422c946a78655c38c171aa498d43811bbd79a5c Mon Sep 17 00:00:00 2001 From: sujunjun Date: Thu, 18 Jul 2024 12:41:01 +0800 Subject: [PATCH 5/6] =?UTF-8?q?=E7=BB=9F=E4=B8=80embedding=E8=BE=93?= =?UTF-8?q?=E5=85=A5=E5=86=85=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../index/verbs/text/embed/strategies/openai.py | 3 +++ graphrag/query/cli.py | 2 +- graphrag/query/context_builder/entity_extraction.py | 4 +++- graphrag/query/llm/oai/embedding.py | 13 +++++++++++++ 4 files changed, 20 insertions(+), 2 deletions(-) diff --git a/graphrag/index/verbs/text/embed/strategies/openai.py b/graphrag/index/verbs/text/embed/strategies/openai.py index 0658d604cc..de8a745bf1 100644 --- a/graphrag/index/verbs/text/embed/strategies/openai.py +++ b/graphrag/index/verbs/text/embed/strategies/openai.py @@ -31,6 +31,7 @@ async def run( if is_null(input): return TextEmbeddingResult(embeddings=None) + log.debug("text embedding input=%s", input) llm_config = args.get("llm", {}) batch_size = args.get("batch_size", 16) batch_max_tokens = args.get("batch_max_tokens", 8191) @@ -47,6 +48,7 @@ async def run( batch_max_tokens, splitter, ) + log.debug("text embedding prepared input=%s, size=%s", texts, input_sizes) log.info( "embedding %d inputs via %d snippets using %d batches. max_batch_size=%d, max_tokens=%d", len(input), @@ -95,6 +97,7 @@ async def _execute( semaphore: asyncio.Semaphore, ) -> list[list[float]]: async def embed(chunk: list[str]): + log.debug("text embedding chunk=%s", chunk) async with semaphore: chunk_embeddings = await llm(chunk) result = np.array(chunk_embeddings.output) diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index aef7c2965b..c1b14d1a10 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -151,7 +151,7 @@ def run_local_search( response_type=response_type, ) - result = search_engine.search(query=query) + result = search_engine.search(query=query, chunk_type=config.chunks.type) reporter.success(f"Local Search Response: {result.response}") return result.response diff --git a/graphrag/query/context_builder/entity_extraction.py b/graphrag/query/context_builder/entity_extraction.py index 82a0699cd8..d0179a8b02 100644 --- a/graphrag/query/context_builder/entity_extraction.py +++ b/graphrag/query/context_builder/entity_extraction.py @@ -12,6 +12,7 @@ ) from graphrag.query.llm.base import BaseTextEmbedding from graphrag.vector_stores import BaseVectorStore +from graphrag.index.verbs.text.chunk.text_chunk import ChunkStrategyType class EntityVectorStoreKey(str, Enum): @@ -42,6 +43,7 @@ def map_query_to_entities( exclude_entity_names: list[str] | None = None, k: int = 10, oversample_scaler: int = 2, + chunk_type: str = ChunkStrategyType.tokens, ) -> list[Entity]: """Extract entities that match a given query using semantic similarity of text embeddings of query and entity descriptions.""" if include_entity_names is None: @@ -54,7 +56,7 @@ def map_query_to_entities( # oversample to account for excluded entities search_results = text_embedding_vectorstore.similarity_search_by_text( text=query, - text_embedder=lambda t: text_embedder.embed(t), + text_embedder=lambda t: text_embedder.embed(t, chunk_type=chunk_type), k=k * oversample_scaler, ) for result in search_results: diff --git a/graphrag/query/llm/oai/embedding.py b/graphrag/query/llm/oai/embedding.py index f40372dbce..685f52e98f 100644 --- a/graphrag/query/llm/oai/embedding.py +++ b/graphrag/query/llm/oai/embedding.py @@ -26,6 +26,7 @@ ) from graphrag.query.llm.text_utils import chunk_text from graphrag.query.progress import StatusReporter +from graphrag.index.verbs.text.chunk.text_chunk import ChunkStrategyType class OpenAIEmbedding(BaseTextEmbedding, OpenAILLMImpl): @@ -75,6 +76,18 @@ def embed(self, text: str, **kwargs: Any) -> list[float]: For text longer than max_tokens, chunk texts into max_tokens, embed each chunk, then combine using weighted average. Please refer to: https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb """ + chunk_type = kwargs.get("chunk_type", ChunkStrategyType.tokens) + if chunk_type == ChunkStrategyType.chinese: + try: + embedding, chunk_len = self._embed_with_retry(text, **kwargs) + return embedding + # TODO: catch a more specific exception + except Exception as e: # noqa BLE001 + self._reporter.error( + message="Error embedding chunk", + details={self.__class__.__name__: str(e)}, + ) + token_chunks = chunk_text( text=text, token_encoder=self.token_encoder, max_tokens=self.max_tokens ) From 72f54e405dfa0edc57ac2bc833a180d5af3c4018 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=B2=81=E6=B5=A9?= <14227447+luhao0614@user.noreply.gitee.com> Date: Mon, 29 Jul 2024 19:42:49 +0800 Subject: [PATCH 6/6] =?UTF-8?q?=E8=A7=A3=E5=86=B3embedding=E6=8A=A5?= =?UTF-8?q?=E9=94=99=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- graphrag/query/llm/oai/embedding.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/graphrag/query/llm/oai/embedding.py b/graphrag/query/llm/oai/embedding.py index 685f52e98f..5cb5f60d6b 100644 --- a/graphrag/query/llm/oai/embedding.py +++ b/graphrag/query/llm/oai/embedding.py @@ -104,8 +104,11 @@ def embed(self, text: str, **kwargs: Any) -> list[float]: message="Error embedding chunk", details={self.__class__.__name__: str(e)}, ) - continue + + if sum(chunk_lens) == 0: + raise ValueError("chunk_lens 权重数组的和为零,无法进行平均计算。") + chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens) chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings) return chunk_embeddings.tolist() @@ -127,6 +130,10 @@ async def aembed(self, text: str, **kwargs: Any) -> list[float]: embedding_results = [result for result in embedding_results if result[0]] chunk_embeddings = [result[0] for result in embedding_results] chunk_lens = [result[1] for result in embedding_results] + + if sum(chunk_lens) == 0: + raise ValueError("chunk_lens 权重数组的和为零,无法进行平均计算。") + chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens) # type: ignore chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings) return chunk_embeddings.tolist() @@ -134,6 +141,10 @@ async def aembed(self, text: str, **kwargs: Any) -> list[float]: def _embed_with_retry( self, text: str | tuple, **kwargs: Any ) -> tuple[list[float], int]: + # 确保 chunk_type 参数不传递给 create 函数 + if 'chunk_type' in kwargs: + del kwargs['chunk_type'] + try: retryer = Retrying( stop=stop_after_attempt(self.max_retries), @@ -161,12 +172,15 @@ def _embed_with_retry( ) return ([], 0) else: - # TODO: why not just throw in this case? return ([], 0) async def _aembed_with_retry( self, text: str | tuple, **kwargs: Any ) -> tuple[list[float], int]: + # 确保 chunk_type 参数不传递给 create 函数 + if 'chunk_type' in kwargs: + del kwargs['chunk_type'] + try: retryer = AsyncRetrying( stop=stop_after_attempt(self.max_retries), @@ -191,5 +205,4 @@ async def _aembed_with_retry( ) return ([], 0) else: - # TODO: why not just throw in this case? return ([], 0)