Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions ariadne_codegen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
get_graphql_schema_from_path,
get_graphql_schema_from_url,
)
from .settings import Strategy
from .settings import Strategy, get_validation_rule


@click.command()
Expand Down Expand Up @@ -66,7 +66,11 @@ def client(config_dict):
fragments = []
queries = []
if settings.queries_path:
definitions = get_graphql_queries(settings.queries_path, schema)
definitions = get_graphql_queries(
settings.queries_path,
schema,
[get_validation_rule(e) for e in settings.skip_validation_rules],
)
queries = filter_operations_definitions(definitions)
fragments = filter_fragments_definitions(definitions)

Expand Down
9 changes: 6 additions & 3 deletions ariadne_codegen/schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Sequence
from dataclasses import asdict
from pathlib import Path
from typing import Optional, cast
Expand All @@ -23,6 +23,7 @@
specified_rules,
validate,
)
from typing_extensions import Any

from .client_generators.constants import MIXIN_FROM_NAME, MIXIN_IMPORT_NAME, MIXIN_NAME
from .exceptions import (
Expand All @@ -48,15 +49,17 @@ def filter_fragments_definitions(


def get_graphql_queries(
queries_path: str, schema: GraphQLSchema
queries_path: str,
schema: GraphQLSchema,
skip_rules: Sequence[Any] = (NoUnusedFragmentsRule,),
) -> tuple[DefinitionNode, ...]:
"""Get graphql queries definitions build from provided path."""
queries_str = load_graphql_files_from_path(Path(queries_path))
queries_ast = parse(queries_str)
validation_errors = validate(
schema=schema,
document_ast=queries_ast,
rules=[r for r in specified_rules if r is not NoUnusedFragmentsRule],
rules=[r for r in specified_rules if r not in skip_rules],
)
if validation_errors:
raise InvalidOperationForSchema(
Expand Down
22 changes: 22 additions & 0 deletions ariadne_codegen/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from pathlib import Path
from textwrap import dedent

from graphql.validation import specified_rules

from .client_generators.constants import (
DEFAULT_ASYNC_BASE_CLIENT_NAME,
DEFAULT_ASYNC_BASE_CLIENT_OPEN_TELEMETRY_NAME,
Expand All @@ -26,6 +28,21 @@ class CommentsStrategy(str, enum.Enum):
TIMESTAMP = "timestamp"


VALIDATION_RULES_MAP = {
rule.__name__.removesuffix("Rule"): rule for rule in specified_rules
}


def get_validation_rule(rule: str):
try:
return VALIDATION_RULES_MAP[rule]
except KeyError as exc:
supported_rules = ", ".join(sorted(VALIDATION_RULES_MAP))
raise ValueError(
f"Unknown validation rule: {rule}. Supported values are: {supported_rules}"
) from exc


class Strategy(str, enum.Enum):
CLIENT = "client"
GRAPHQL_SCHEMA = "graphqlschema"
Expand Down Expand Up @@ -125,6 +142,11 @@ class ClientSettings(BaseSettings):
include_all_enums: bool = True
async_client: bool = True
opentelemetry_client: bool = False
skip_validation_rules: list[str] = field(
default_factory=lambda: [
"NoUnusedFragments",
]
)
files_to_include: list[str] = field(default_factory=list)
scalars: dict[str, ScalarData] = field(default_factory=dict)
default_optional_fields_to_none: bool = False
Expand Down
136 changes: 135 additions & 1 deletion tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
read_graphql_file,
walk_graphql_files,
)
from ariadne_codegen.settings import IntrospectionSettings
from ariadne_codegen.settings import IntrospectionSettings, get_validation_rule


@pytest.fixture
Expand Down Expand Up @@ -67,6 +67,53 @@ def test_query_2_str():
"""


@pytest.fixture
def test_fragment_str():
return """
fragment fragmentA on Custom {
node
}
query testQuery2 {
test {
default
...fragmentA
}
}
"""


@pytest.fixture
def test_duplicate_fragment_str():
return """
fragment fragmentA on Custom {
node
}
fragment fragmentA on Custom {
node
}
query testQuery2 {
test {
default
...fragmentA
}
}
"""


@pytest.fixture
def test_unused_fragment_str():
return """
fragment fragmentA on Custom {
node
}
query testQuery2 {
test {
default
}
}
"""


@pytest.fixture
def single_file_schema(tmp_path_factory, schema_str):
file_ = tmp_path_factory.mktemp("schema").joinpath("schema.graphql")
Expand Down Expand Up @@ -136,6 +183,37 @@ def single_file_query(tmp_path_factory, test_query_str):
return file_


@pytest.fixture
def single_file_query_with_fragment(
tmp_path_factory, test_query_str, test_fragment_str
):
file_ = tmp_path_factory.mktemp("queries").joinpath("query1_fragment.graphql")
file_.write_text(test_query_str + test_fragment_str, encoding="utf-8")
return file_


@pytest.fixture
def single_file_query_with_duplicate_fragment(
tmp_path_factory, test_query_str, test_duplicate_fragment_str
):
file_ = tmp_path_factory.mktemp("queries").joinpath(
"query1_duplicate_fragment.graphql"
)
file_.write_text(test_query_str + test_duplicate_fragment_str, encoding="utf-8")
return file_


@pytest.fixture
def single_file_query_with_unused_fragment(
tmp_path_factory, test_query_str, test_unused_fragment_str
):
file_ = tmp_path_factory.mktemp("queries").joinpath(
"query1_unused_fragment.graphql"
)
file_.write_text(test_query_str + test_unused_fragment_str, encoding="utf-8")
return file_


@pytest.fixture
def invalid_syntax_query_file(tmp_path_factory):
file_ = tmp_path_factory.mktemp("queries").joinpath("query.graphql")
Expand Down Expand Up @@ -449,6 +527,62 @@ def test_get_graphql_queries_with_invalid_query_for_schema_raises_invalid_operat
)


def test_get_graphql_queries_with_fragment_returns_schema_definitions(
single_file_query_with_fragment, schema_str
):
queries = get_graphql_queries(
single_file_query_with_fragment.as_posix(), build_schema(schema_str)
)

assert len(queries) == 3


def test_get_graphql_queries_with_duplicate_fragment_raises_invalid_operation(
single_file_query_with_duplicate_fragment, schema_str
):
with pytest.raises(InvalidOperationForSchema):
get_graphql_queries(
single_file_query_with_duplicate_fragment.as_posix(),
build_schema(schema_str),
)


def test_unused_fragment_without_skips_raises_invalid_operation(
single_file_query_with_unused_fragment,
schema_str,
):
with pytest.raises(InvalidOperationForSchema):
get_graphql_queries(
single_file_query_with_unused_fragment.as_posix(),
build_schema(schema_str),
[],
)


def test_duplicate_fragment_passes_when_skip_rule_enabled(
single_file_query_with_duplicate_fragment,
schema_str,
):
get_graphql_queries(
single_file_query_with_duplicate_fragment.as_posix(),
build_schema(schema_str),
[
get_validation_rule("NoUnusedFragments"),
get_validation_rule("UniqueFragmentNames"),
],
)


def test_get_validation_rule_accepts_all_specified_rule_names():
rule = get_validation_rule("NoUnusedVariables")
assert rule.__name__ == "NoUnusedVariablesRule"


def test_get_validation_rule_with_unknown_rule_raises_value_error():
with pytest.raises(ValueError):
get_validation_rule("UnknownRule")


def test_introspect_remote_schema_passes_introspection_settings_to_introspection_query(
mocker,
):
Expand Down
Loading