diff --git a/ariadne_codegen/main.py b/ariadne_codegen/main.py index 59df36be..910bd474 100644 --- a/ariadne_codegen/main.py +++ b/ariadne_codegen/main.py @@ -19,7 +19,7 @@ get_graphql_schema_from_path, get_graphql_schema_from_url, ) -from .settings import Strategy +from .settings import Strategy, get_validation_rule @click.command() @@ -66,7 +66,11 @@ def client(config_dict): fragments = [] queries = [] if settings.queries_path: - definitions = get_graphql_queries(settings.queries_path, schema) + definitions = get_graphql_queries( + settings.queries_path, + schema, + [get_validation_rule(e) for e in settings.skip_validation_rules], + ) queries = filter_operations_definitions(definitions) fragments = filter_fragments_definitions(definitions) diff --git a/ariadne_codegen/schema.py b/ariadne_codegen/schema.py index f5ed291c..4556fa55 100644 --- a/ariadne_codegen/schema.py +++ b/ariadne_codegen/schema.py @@ -1,4 +1,4 @@ -from collections.abc import Generator +from collections.abc import Generator, Sequence from dataclasses import asdict from pathlib import Path from typing import Optional, cast @@ -23,6 +23,7 @@ specified_rules, validate, ) +from typing_extensions import Any from .client_generators.constants import MIXIN_FROM_NAME, MIXIN_IMPORT_NAME, MIXIN_NAME from .exceptions import ( @@ -48,7 +49,9 @@ def filter_fragments_definitions( def get_graphql_queries( - queries_path: str, schema: GraphQLSchema + queries_path: str, + schema: GraphQLSchema, + skip_rules: Sequence[Any] = (NoUnusedFragmentsRule,), ) -> tuple[DefinitionNode, ...]: """Get graphql queries definitions build from provided path.""" queries_str = load_graphql_files_from_path(Path(queries_path)) @@ -56,7 +59,7 @@ def get_graphql_queries( validation_errors = validate( schema=schema, document_ast=queries_ast, - rules=[r for r in specified_rules if r is not NoUnusedFragmentsRule], + rules=[r for r in specified_rules if r not in skip_rules], ) if validation_errors: raise InvalidOperationForSchema( diff --git a/ariadne_codegen/settings.py b/ariadne_codegen/settings.py index d87877b6..709bfa2b 100644 --- a/ariadne_codegen/settings.py +++ b/ariadne_codegen/settings.py @@ -6,6 +6,8 @@ from pathlib import Path from textwrap import dedent +from graphql.validation import specified_rules + from .client_generators.constants import ( DEFAULT_ASYNC_BASE_CLIENT_NAME, DEFAULT_ASYNC_BASE_CLIENT_OPEN_TELEMETRY_NAME, @@ -26,6 +28,21 @@ class CommentsStrategy(str, enum.Enum): TIMESTAMP = "timestamp" +VALIDATION_RULES_MAP = { + rule.__name__.removesuffix("Rule"): rule for rule in specified_rules +} + + +def get_validation_rule(rule: str): + try: + return VALIDATION_RULES_MAP[rule] + except KeyError as exc: + supported_rules = ", ".join(sorted(VALIDATION_RULES_MAP)) + raise ValueError( + f"Unknown validation rule: {rule}. Supported values are: {supported_rules}" + ) from exc + + class Strategy(str, enum.Enum): CLIENT = "client" GRAPHQL_SCHEMA = "graphqlschema" @@ -125,6 +142,11 @@ class ClientSettings(BaseSettings): include_all_enums: bool = True async_client: bool = True opentelemetry_client: bool = False + skip_validation_rules: list[str] = field( + default_factory=lambda: [ + "NoUnusedFragments", + ] + ) files_to_include: list[str] = field(default_factory=list) scalars: dict[str, ScalarData] = field(default_factory=dict) default_optional_fields_to_none: bool = False diff --git a/tests/test_schema.py b/tests/test_schema.py index 7a0b3973..fd130e0f 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -17,7 +17,7 @@ read_graphql_file, walk_graphql_files, ) -from ariadne_codegen.settings import IntrospectionSettings +from ariadne_codegen.settings import IntrospectionSettings, get_validation_rule @pytest.fixture @@ -67,6 +67,53 @@ def test_query_2_str(): """ +@pytest.fixture +def test_fragment_str(): + return """ + fragment fragmentA on Custom { + node + } + query testQuery2 { + test { + default + ...fragmentA + } + } + """ + + +@pytest.fixture +def test_duplicate_fragment_str(): + return """ + fragment fragmentA on Custom { + node + } + fragment fragmentA on Custom { + node + } + query testQuery2 { + test { + default + ...fragmentA + } + } + """ + + +@pytest.fixture +def test_unused_fragment_str(): + return """ + fragment fragmentA on Custom { + node + } + query testQuery2 { + test { + default + } + } + """ + + @pytest.fixture def single_file_schema(tmp_path_factory, schema_str): file_ = tmp_path_factory.mktemp("schema").joinpath("schema.graphql") @@ -136,6 +183,37 @@ def single_file_query(tmp_path_factory, test_query_str): return file_ +@pytest.fixture +def single_file_query_with_fragment( + tmp_path_factory, test_query_str, test_fragment_str +): + file_ = tmp_path_factory.mktemp("queries").joinpath("query1_fragment.graphql") + file_.write_text(test_query_str + test_fragment_str, encoding="utf-8") + return file_ + + +@pytest.fixture +def single_file_query_with_duplicate_fragment( + tmp_path_factory, test_query_str, test_duplicate_fragment_str +): + file_ = tmp_path_factory.mktemp("queries").joinpath( + "query1_duplicate_fragment.graphql" + ) + file_.write_text(test_query_str + test_duplicate_fragment_str, encoding="utf-8") + return file_ + + +@pytest.fixture +def single_file_query_with_unused_fragment( + tmp_path_factory, test_query_str, test_unused_fragment_str +): + file_ = tmp_path_factory.mktemp("queries").joinpath( + "query1_unused_fragment.graphql" + ) + file_.write_text(test_query_str + test_unused_fragment_str, encoding="utf-8") + return file_ + + @pytest.fixture def invalid_syntax_query_file(tmp_path_factory): file_ = tmp_path_factory.mktemp("queries").joinpath("query.graphql") @@ -449,6 +527,62 @@ def test_get_graphql_queries_with_invalid_query_for_schema_raises_invalid_operat ) +def test_get_graphql_queries_with_fragment_returns_schema_definitions( + single_file_query_with_fragment, schema_str +): + queries = get_graphql_queries( + single_file_query_with_fragment.as_posix(), build_schema(schema_str) + ) + + assert len(queries) == 3 + + +def test_get_graphql_queries_with_duplicate_fragment_raises_invalid_operation( + single_file_query_with_duplicate_fragment, schema_str +): + with pytest.raises(InvalidOperationForSchema): + get_graphql_queries( + single_file_query_with_duplicate_fragment.as_posix(), + build_schema(schema_str), + ) + + +def test_unused_fragment_without_skips_raises_invalid_operation( + single_file_query_with_unused_fragment, + schema_str, +): + with pytest.raises(InvalidOperationForSchema): + get_graphql_queries( + single_file_query_with_unused_fragment.as_posix(), + build_schema(schema_str), + [], + ) + + +def test_duplicate_fragment_passes_when_skip_rule_enabled( + single_file_query_with_duplicate_fragment, + schema_str, +): + get_graphql_queries( + single_file_query_with_duplicate_fragment.as_posix(), + build_schema(schema_str), + [ + get_validation_rule("NoUnusedFragments"), + get_validation_rule("UniqueFragmentNames"), + ], + ) + + +def test_get_validation_rule_accepts_all_specified_rule_names(): + rule = get_validation_rule("NoUnusedVariables") + assert rule.__name__ == "NoUnusedVariablesRule" + + +def test_get_validation_rule_with_unknown_rule_raises_value_error(): + with pytest.raises(ValueError): + get_validation_rule("UnknownRule") + + def test_introspect_remote_schema_passes_introspection_settings_to_introspection_query( mocker, ):