diff --git a/lexical-graph/src/graphrag_toolkit/lexical_graph/indexing/build/build_pipeline.py b/lexical-graph/src/graphrag_toolkit/lexical_graph/indexing/build/build_pipeline.py index c9b9e8b1..ff8fe251 100644 --- a/lexical-graph/src/graphrag_toolkit/lexical_graph/indexing/build/build_pipeline.py +++ b/lexical-graph/src/graphrag_toolkit/lexical_graph/indexing/build/build_pipeline.py @@ -18,6 +18,7 @@ from graphrag_toolkit.lexical_graph.indexing.build.checkpoint import Checkpoint, CheckpointWriter from graphrag_toolkit.lexical_graph.indexing.build.node_builders import NodeBuilders from graphrag_toolkit.lexical_graph.indexing.build.build_filters import BuildFilters +from graphrag_toolkit.lexical_graph.utils.arg_utils import first_non_none from llama_index.core.utils import iter_batch from llama_index.core.ingestion import IngestionPipeline @@ -190,11 +191,11 @@ def __init__(self, components = components or [] num_workers = num_workers or GraphRAGConfig.build_num_workers batch_size = batch_size or GraphRAGConfig.build_batch_size - batch_writes_enabled = batch_writes_enabled or GraphRAGConfig.batch_writes_enabled + batch_writes_enabled = first_non_none([batch_writes_enabled, GraphRAGConfig.batch_writes_enabled]) batch_write_size = batch_write_size or GraphRAGConfig.build_batch_write_size - include_domain_labels = include_domain_labels or GraphRAGConfig.include_domain_labels - include_local_entities = include_local_entities or GraphRAGConfig.include_local_entities - include_classification_in_entity_id = include_classification_in_entity_id or GraphRAGConfig.include_classification_in_entity_id + include_domain_labels = first_non_none([include_domain_labels, GraphRAGConfig.include_domain_labels]) + include_local_entities = first_non_none([include_local_entities, GraphRAGConfig.include_local_entities]) + include_classification_in_entity_id = first_non_none([include_classification_in_entity_id, GraphRAGConfig.include_classification_in_entity_id]) source_metadata_formatter = source_metadata_formatter or DefaultSourceMetadataFormatter() for c in components: diff --git a/lexical-graph/src/graphrag_toolkit/lexical_graph/indexing/extract/extraction_pipeline.py b/lexical-graph/src/graphrag_toolkit/lexical_graph/indexing/extract/extraction_pipeline.py index 4485d2c0..65917332 100644 --- a/lexical-graph/src/graphrag_toolkit/lexical_graph/indexing/extract/extraction_pipeline.py +++ b/lexical-graph/src/graphrag_toolkit/lexical_graph/indexing/extract/extraction_pipeline.py @@ -19,6 +19,7 @@ from graphrag_toolkit.lexical_graph.indexing.build.checkpoint import Checkpoint from graphrag_toolkit.lexical_graph.indexing.extract.docs_to_nodes import DocsToNodes from graphrag_toolkit.lexical_graph.indexing.extract.id_rewriter import IdRewriter +from graphrag_toolkit.lexical_graph.utils.arg_utils import first_non_none from llama_index.core.node_parser import NodeParser from llama_index.core.utils import iter_batch @@ -231,7 +232,7 @@ def __init__(self, components = components or [] num_workers = num_workers or GraphRAGConfig.extraction_num_workers batch_size = batch_size or GraphRAGConfig.extraction_batch_size - include_classification_in_entity_id = include_classification_in_entity_id or GraphRAGConfig.include_classification_in_entity_id + include_classification_in_entity_id = first_non_none([include_classification_in_entity_id, GraphRAGConfig.include_classification_in_entity_id]) extract_timestamp = kwargs.pop('extract_timestamp', None) if num_workers > multiprocessing.cpu_count(): diff --git a/lexical-graph/src/graphrag_toolkit/lexical_graph/indexing/id_generator.py b/lexical-graph/src/graphrag_toolkit/lexical_graph/indexing/id_generator.py index 84e67b1b..a86b8c79 100644 --- a/lexical-graph/src/graphrag_toolkit/lexical_graph/indexing/id_generator.py +++ b/lexical-graph/src/graphrag_toolkit/lexical_graph/indexing/id_generator.py @@ -5,6 +5,7 @@ from graphrag_toolkit.lexical_graph import TenantId, GraphRAGConfig from graphrag_toolkit.lexical_graph.indexing.utils.hash_utils import get_hash +from graphrag_toolkit.lexical_graph.utils.arg_utils import first_non_none from llama_index.core.bridge.pydantic import BaseModel @@ -41,7 +42,7 @@ def __init__(self, tenant_id:TenantId=None, include_classification_in_entity_id: """ super().__init__( tenant_id=tenant_id or TenantId(), - include_classification_in_entity_id=include_classification_in_entity_id or GraphRAGConfig.include_classification_in_entity_id, + include_classification_in_entity_id=first_non_none([include_classification_in_entity_id, GraphRAGConfig.include_classification_in_entity_id]), use_chunk_id_delimiter=use_chunk_id_delimiter ) diff --git a/lexical-graph/src/graphrag_toolkit/lexical_graph/lexical_graph_index.py b/lexical-graph/src/graphrag_toolkit/lexical_graph/lexical_graph_index.py index 42cc7d7d..8e95409e 100644 --- a/lexical-graph/src/graphrag_toolkit/lexical_graph/lexical_graph_index.py +++ b/lexical-graph/src/graphrag_toolkit/lexical_graph/lexical_graph_index.py @@ -33,6 +33,7 @@ from graphrag_toolkit.lexical_graph.indexing.build import BuildFilters from graphrag_toolkit.lexical_graph.indexing.build.null_builder import NullBuilder from graphrag_toolkit.lexical_graph.indexing.build.delete_sources import DeleteSources +from graphrag_toolkit.lexical_graph.utils.arg_utils import first_non_none from llama_index.core.node_parser import SentenceSplitter, NodeParser from llama_index.core.schema import BaseNode @@ -502,7 +503,7 @@ def build( build_config = self.indexing_config.build - enable_versioning = kwargs.get('enable_versioning', None) or build_config.enable_versioning or GraphRAGConfig.enable_versioning + enable_versioning = first_non_none([kwargs.get('enable_versioning', None), build_config.enable_versioning, GraphRAGConfig.enable_versioning]) components = [] @@ -566,7 +567,7 @@ def extract_and_build( build_config = self.indexing_config.build - enable_versioning = kwargs.get('enable_versioning', None) or build_config.enable_versioning or GraphRAGConfig.enable_versioning + enable_versioning = first_non_none([kwargs.get('enable_versioning', None), build_config.enable_versioning, GraphRAGConfig.enable_versioning]) build_components = [] diff --git a/lexical-graph/src/graphrag_toolkit/lexical_graph/utils/arg_utils.py b/lexical-graph/src/graphrag_toolkit/lexical_graph/utils/arg_utils.py new file mode 100644 index 00000000..4a74c4dc --- /dev/null +++ b/lexical-graph/src/graphrag_toolkit/lexical_graph/utils/arg_utils.py @@ -0,0 +1,7 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Any + +def first_non_none(items:List[Any]): + return next((item for item in items if item is not None), None) \ No newline at end of file diff --git a/lexical-graph/tests/unit/test_id_generator.py b/lexical-graph/tests/unit/test_id_generator.py index bfa69131..eb7965ca 100644 --- a/lexical-graph/tests/unit/test_id_generator.py +++ b/lexical-graph/tests/unit/test_id_generator.py @@ -107,14 +107,8 @@ def test_create_entity_id_classification_matters_when_enabled(default_id_gen): def test_create_entity_id_classification_ignored_when_disabled(default_tenant): """With classification disabled, entity identity depends only on value — so 'Amazon/Company' and 'Amazon/River' collapse to the same node. - - SOURCE BUG: IdGenerator.__init__ uses - `include_classification_in_entity_id or GraphRAGConfig.include_classification_in_entity_id` - Because False is falsy, passing False is silently overridden by the config default (True). - Workaround: set the field directly on the instance after construction. """ - gen = IdGenerator(tenant_id=default_tenant, include_classification_in_entity_id=True) - gen.include_classification_in_entity_id = False + gen = IdGenerator(tenant_id=default_tenant, include_classification_in_entity_id=False) assert gen.create_entity_id("Amazon", "Company") == gen.create_entity_id("Amazon", "River") diff --git a/lexical-graph/tests/unit/utils/test_arg_utils.py b/lexical-graph/tests/unit/utils/test_arg_utils.py new file mode 100644 index 00000000..0e9dd6c3 --- /dev/null +++ b/lexical-graph/tests/unit/utils/test_arg_utils.py @@ -0,0 +1,11 @@ +import pytest + +from graphrag_toolkit.lexical_graph.utils.arg_utils import first_non_none + +def test_first_non_none(): + assert first_non_none([None, None, 3]) == 3 + assert first_non_none([None, 2, 3]) == 2 + assert first_non_none([1, 2, 3]) == 1 + assert first_non_none([None, False, True]) == False + assert first_non_none([None, True, False]) == True + assert first_non_none([None, None, None]) is None \ No newline at end of file