From c5d099e57322b6fbb293dc4c3603090d1c191465 Mon Sep 17 00:00:00 2001 From: Ryan Swart Date: Wed, 15 Jan 2025 15:10:31 +1100 Subject: [PATCH 1/2] user-specified validation rule skips This feature allows for the user to specify which validation rules they want to skip. Currently we only support 2 rules (the implicit no unused fragments rule and the UniqueFragmentNames rule). --- ariadne_codegen/main.py | 4 +- ariadne_codegen/schema.py | 6 ++- ariadne_codegen/settings.py | 15 ++++++ tests/test_schema.py | 96 +++++++++++++++++++++++++++++++++++++ 4 files changed, 117 insertions(+), 4 deletions(-) diff --git a/ariadne_codegen/main.py b/ariadne_codegen/main.py index 2dcd603c..5f3ab67b 100644 --- a/ariadne_codegen/main.py +++ b/ariadne_codegen/main.py @@ -19,7 +19,7 @@ get_graphql_schema_from_path, get_graphql_schema_from_url, ) -from .settings import Strategy +from .settings import Strategy, get_validation_rule @click.command() # type: ignore @@ -64,7 +64,7 @@ def client(config_dict): fragments = [] queries = [] if settings.queries_path: - definitions = get_graphql_queries(settings.queries_path, schema) + definitions = get_graphql_queries(settings.queries_path, schema, [get_validation_rule(e) for e in settings.skip_validation_rules]) queries = filter_operations_definitions(definitions) fragments = filter_fragments_definitions(definitions) diff --git a/ariadne_codegen/schema.py b/ariadne_codegen/schema.py index 1ce36268..8ee34789 100644 --- a/ariadne_codegen/schema.py +++ b/ariadne_codegen/schema.py @@ -1,5 +1,6 @@ from pathlib import Path from typing import Dict, Generator, List, Optional, Tuple, cast +from typing_extensions import Any, Sequence import httpx from graphql import ( @@ -14,6 +15,7 @@ IntrospectionQuery, NoUnusedFragmentsRule, OperationDefinitionNode, + UniqueFragmentNamesRule, build_ast_schema, build_client_schema, get_introspection_query, @@ -45,7 +47,7 @@ def filter_fragments_definitions( def get_graphql_queries( - queries_path: str, schema: GraphQLSchema + queries_path: str, schema: GraphQLSchema, skip_rules: Sequence[Any] = (NoUnusedFragmentsRule,) ) -> Tuple[DefinitionNode, ...]: """Get graphql queries definitions build from provided path.""" queries_str = load_graphql_files_from_path(Path(queries_path)) @@ -53,7 +55,7 @@ def get_graphql_queries( validation_errors = validate( schema=schema, document_ast=queries_ast, - rules=[r for r in specified_rules if r is not NoUnusedFragmentsRule], + rules=[r for r in specified_rules if r not in skip_rules], ) if validation_errors: raise InvalidOperationForSchema( diff --git a/ariadne_codegen/settings.py b/ariadne_codegen/settings.py index 808397ba..48929c27 100644 --- a/ariadne_codegen/settings.py +++ b/ariadne_codegen/settings.py @@ -6,6 +6,8 @@ from textwrap import dedent from typing import Dict, List +from graphql.validation import UniqueFragmentNamesRule, NoUnusedFragmentsRule + from .client_generators.constants import ( DEFAULT_ASYNC_BASE_CLIENT_NAME, DEFAULT_ASYNC_BASE_CLIENT_OPEN_TELEMETRY_NAME, @@ -25,6 +27,18 @@ class CommentsStrategy(str, enum.Enum): STABLE = "stable" TIMESTAMP = "timestamp" +class ValidationRuleSkips(str, enum.Enum): + UniqueFragmentNames = "UniqueFragmentNames" + NoUnusedFragments = "NoUnusedFragments" + +def get_validation_rule(rule: ValidationRuleSkips): + if rule == ValidationRuleSkips.UniqueFragmentNames: + return UniqueFragmentNamesRule + elif rule == ValidationRuleSkips.NoUnusedFragments: + return NoUnusedFragmentsRule + else: + raise ValueError(f"Unknown validation rule: {rule}") + class Strategy(str, enum.Enum): CLIENT = "client" @@ -70,6 +84,7 @@ class ClientSettings(BaseSettings): include_all_enums: bool = True async_client: bool = True opentelemetry_client: bool = False + skip_validation_rules: List[ValidationRuleSkips] = field(default_factory=lambda: [ValidationRuleSkips.UniqueFragmentNames,]) files_to_include: List[str] = field(default_factory=list) scalars: Dict[str, ScalarData] = field(default_factory=dict) diff --git a/tests/test_schema.py b/tests/test_schema.py index 42873867..367164da 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,3 +1,4 @@ +from ariadne_codegen.settings import get_validation_rule import httpx import pytest from graphql import GraphQLSchema, OperationDefinitionNode, build_schema @@ -15,6 +16,7 @@ read_graphql_file, walk_graphql_files, ) +from ariadne_codegen.settings import ValidationRuleSkips @pytest.fixture @@ -63,6 +65,49 @@ def test_query_2_str(): } """ +@pytest.fixture +def test_fragment_str(): + return """ + fragment fragmentA on Custom { + node + } + query testQuery2 { + test { + default + ...fragmentA + } + } + """ + +@pytest.fixture +def test_duplicate_fragment_str(): + return """ + fragment fragmentA on Custom { + node + } + fragment fragmentA on Custom { + node + } + query testQuery2 { + test { + default + ...fragmentA + } + } + """ + +@pytest.fixture +def test_unused_fragment_str(): + return """ + fragment fragmentA on Custom { + node + } + query testQuery2 { + test { + default + } + } + """ @pytest.fixture def single_file_schema(tmp_path_factory, schema_str): @@ -132,6 +177,24 @@ def single_file_query(tmp_path_factory, test_query_str): file_.write_text(test_query_str, encoding="utf-8") return file_ +@pytest.fixture +def single_file_query_with_fragment(tmp_path_factory, test_query_str, test_fragment_str): + file_ = tmp_path_factory.mktemp("queries").joinpath("query1_fragment.graphql") + file_.write_text(test_query_str + test_fragment_str, encoding="utf-8") + return file_ + +@pytest.fixture +def single_file_query_with_duplicate_fragment(tmp_path_factory, test_query_str, test_duplicate_fragment_str): + file_ = tmp_path_factory.mktemp("queries").joinpath("query1_duplicate_fragment.graphql") + file_.write_text(test_query_str + test_duplicate_fragment_str, encoding="utf-8") + return file_ + +@pytest.fixture +def single_file_query_with_unused_fragment(tmp_path_factory, test_query_str, test_unused_fragment_str): + file_ = tmp_path_factory.mktemp("queries").joinpath("query1_unused_fragment.graphql") + file_.write_text(test_query_str + test_unused_fragment_str, encoding="utf-8") + return file_ + @pytest.fixture def invalid_syntax_query_file(tmp_path_factory): @@ -434,3 +497,36 @@ def test_get_graphql_queries_with_invalid_query_for_schema_raises_invalid_operat get_graphql_queries( invalid_query_for_schema_file.as_posix(), build_schema(schema_str) ) + + +def test_get_graphql_queries_with_fragment_returns_schema_definitions( + single_file_query_with_fragment, schema_str +): + queries = get_graphql_queries( + single_file_query_with_fragment.as_posix(), build_schema(schema_str) + ) + + assert len(queries) == 3 + +def test_get_graphql_queries_with_duplicate_fragment_raises_invalid_operation( + single_file_query_with_duplicate_fragment, schema_str +): + with pytest.raises(InvalidOperationForSchema): + get_graphql_queries( + single_file_query_with_duplicate_fragment.as_posix(), build_schema(schema_str) + ) + +def test_get_graphql_queries_with_unused_fragment_and_no_skip_rules_raises_invalid_operation( + single_file_query_with_unused_fragment, schema_str +): + with pytest.raises(InvalidOperationForSchema): + get_graphql_queries( + single_file_query_with_unused_fragment.as_posix(), build_schema(schema_str), [] + ) + +def test_get_graphql_queries_with_skip_unique_fragment_names_and_duplicate_fragment_returns_schema_definition( + single_file_query_with_duplicate_fragment, schema_str +): + get_graphql_queries( + single_file_query_with_duplicate_fragment.as_posix(), build_schema(schema_str), [get_validation_rule(ValidationRuleSkips.NoUnusedFragments),get_validation_rule(ValidationRuleSkips.UniqueFragmentNames)] + ) From f83a0d65c529e18af0ecf8e5bbb994a8cd1ed1cf Mon Sep 17 00:00:00 2001 From: Damian Czajkowski Date: Thu, 12 Mar 2026 15:20:50 +0100 Subject: [PATCH 2/2] standarize validation rules configuration to be spread to all possible not only to two --- ariadne_codegen/main.py | 6 +++++- ariadne_codegen/settings.py | 27 ++++++++++++++------------- tests/test_schema.py | 26 +++++++++++++++++++------- 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/ariadne_codegen/main.py b/ariadne_codegen/main.py index 14489765..683c8eea 100644 --- a/ariadne_codegen/main.py +++ b/ariadne_codegen/main.py @@ -65,7 +65,11 @@ def client(config_dict): fragments = [] queries = [] if settings.queries_path: - definitions = get_graphql_queries(settings.queries_path, schema, [get_validation_rule(e) for e in settings.skip_validation_rules]) + definitions = get_graphql_queries( + settings.queries_path, + schema, + [get_validation_rule(e) for e in settings.skip_validation_rules], + ) queries = filter_operations_definitions(definitions) fragments = filter_fragments_definitions(definitions) diff --git a/ariadne_codegen/settings.py b/ariadne_codegen/settings.py index 45151bee..b79e7962 100644 --- a/ariadne_codegen/settings.py +++ b/ariadne_codegen/settings.py @@ -6,7 +6,7 @@ from pathlib import Path from textwrap import dedent -from graphql.validation import UniqueFragmentNamesRule, NoUnusedFragmentsRule +from graphql.validation import specified_rules from .client_generators.constants import ( DEFAULT_ASYNC_BASE_CLIENT_NAME, @@ -28,18 +28,19 @@ class CommentsStrategy(str, enum.Enum): TIMESTAMP = "timestamp" -class ValidationRuleSkips(str, enum.Enum): - UniqueFragmentNames = "UniqueFragmentNames" - NoUnusedFragments = "NoUnusedFragments" +VALIDATION_RULES_MAP = { + rule.__name__.removesuffix("Rule"): rule for rule in specified_rules +} -def get_validation_rule(rule: ValidationRuleSkips): - if rule == ValidationRuleSkips.UniqueFragmentNames: - return UniqueFragmentNamesRule - elif rule == ValidationRuleSkips.NoUnusedFragments: - return NoUnusedFragmentsRule - else: - raise ValueError(f"Unknown validation rule: {rule}") +def get_validation_rule(rule: str): + try: + return VALIDATION_RULES_MAP[rule] + except KeyError as exc: + supported_rules = ", ".join(sorted(VALIDATION_RULES_MAP)) + raise ValueError( + f"Unknown validation rule: {rule}. Supported values are: {supported_rules}" + ) from exc class Strategy(str, enum.Enum): @@ -89,9 +90,9 @@ class ClientSettings(BaseSettings): include_all_enums: bool = True async_client: bool = True opentelemetry_client: bool = False - skip_validation_rules: list[ValidationRuleSkips] = field( + skip_validation_rules: list[str] = field( default_factory=lambda: [ - ValidationRuleSkips.UniqueFragmentNames, + "NoUnusedFragments", ] ) files_to_include: list[str] = field(default_factory=list) diff --git a/tests/test_schema.py b/tests/test_schema.py index 6bdbad96..63a039f3 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -17,7 +17,7 @@ read_graphql_file, walk_graphql_files, ) -from ariadne_codegen.settings import ValidationRuleSkips, get_validation_rule +from ariadne_codegen.settings import get_validation_rule @pytest.fixture @@ -541,8 +541,9 @@ def test_get_graphql_queries_with_duplicate_fragment_raises_invalid_operation( ) -def test_get_graphql_queries_with_unused_fragment_and_no_skip_rules_raises_invalid_operation( - single_file_query_with_unused_fragment, schema_str +def test_unused_fragment_without_skips_raises_invalid_operation( + single_file_query_with_unused_fragment, + schema_str, ): with pytest.raises(InvalidOperationForSchema): get_graphql_queries( @@ -552,14 +553,25 @@ def test_get_graphql_queries_with_unused_fragment_and_no_skip_rules_raises_inval ) -def test_get_graphql_queries_with_skip_unique_fragment_names_and_duplicate_fragment_returns_schema_definition( - single_file_query_with_duplicate_fragment, schema_str +def test_duplicate_fragment_passes_when_skip_rule_enabled( + single_file_query_with_duplicate_fragment, + schema_str, ): get_graphql_queries( single_file_query_with_duplicate_fragment.as_posix(), build_schema(schema_str), [ - get_validation_rule(ValidationRuleSkips.NoUnusedFragments), - get_validation_rule(ValidationRuleSkips.UniqueFragmentNames), + get_validation_rule("NoUnusedFragments"), + get_validation_rule("UniqueFragmentNames"), ], ) + + +def test_get_validation_rule_accepts_all_specified_rule_names(): + rule = get_validation_rule("NoUnusedVariables") + assert rule.__name__ == "NoUnusedVariablesRule" + + +def test_get_validation_rule_with_unknown_rule_raises_value_error(): + with pytest.raises(ValueError): + get_validation_rule("UnknownRule")