Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion ariadne_graphql_proxy/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,11 @@ def interfaces_thunk():
name=graphql_type.name,
fields=fields_thunk,
interfaces=interfaces_thunk,
is_type_of=graphql_type.is_type_of,
description=graphql_type.description,
extensions=graphql_type.extensions,
ast_node=graphql_type.ast_node,
extension_ast_nodes=graphql_type.extension_ast_nodes,
)


Expand Down Expand Up @@ -759,6 +764,11 @@ def interfaces_thunk():
name=interface_type.name,
fields=fields_thunk,
interfaces=interfaces_thunk,
resolve_type=interface_type.resolve_type,
description=interface_type.description,
extensions=interface_type.extensions,
ast_node=interface_type.ast_node,
extension_ast_nodes=interface_type.extension_ast_nodes,
)


Expand All @@ -775,7 +785,15 @@ def thunk():
if subtype.name not in types_to_exclude
)

return GraphQLUnionType(name=union_type.name, types=thunk)
return GraphQLUnionType(
name=union_type.name,
types=thunk,
resolve_type=union_type.resolve_type,
description=union_type.description,
extensions=union_type.extensions,
ast_node=union_type.ast_node,
extension_ast_nodes=union_type.extension_ast_nodes,
)


def copy_directives(
Expand Down
53 changes: 53 additions & 0 deletions ariadne_graphql_proxy/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,28 @@ def interfaces_thunk():
}
return [merged_types[name] for name in interfaces_names]

if (
object1.is_type_of
and object2.is_type_of
and object1.is_type_of != object2.is_type_of
):
raise TypeError(
f"{object1.name} is_type_of functions don't match: "
f"{repr(object1.is_type_of)} != {repr(object2.is_type_of)}"
)

extensions = object1.extensions.copy()
extensions.update(**object2.extensions.copy())

return GraphQLObjectType(
name=object1.name,
fields=fields_thunk,
interfaces=interfaces_thunk,
is_type_of=object1.is_type_of or object2.is_type_of,
description=object1.description or object2.description,
extensions=extensions,
ast_node=object1.ast_node or object2.ast_node,
extension_ast_nodes=object1.extension_ast_nodes or object2.extension_ast_nodes,
)


Expand Down Expand Up @@ -583,11 +600,30 @@ def interfaces_thunk():
}
return [merged_types[name] for name in interfaces_names]

if (
interface1.resolve_type
and interface2.resolve_type
and interface1.resolve_type != interface2.resolve_type
):
raise TypeError(
f"{interface1.name} resolve_type functions don't match: "
f"{repr(interface1.resolve_type)} != {repr(interface2.resolve_type)}"
)

extensions = interface1.extensions.copy()
extensions.update(**interface2.extensions.copy())

return GraphQLInterfaceType(
name=interface1.name,
fields=fields_thunk,
interfaces=interfaces_thunk,
resolve_type=interface1.resolve_type or interface2.resolve_type,
description=interface1.description or interface2.description,
extensions=extensions,
ast_node=interface1.ast_node or interface2.ast_node,
extension_ast_nodes=(
interface1.extension_ast_nodes or interface2.extension_ast_nodes
),
)


Expand All @@ -608,8 +644,25 @@ def thunk():
names = {t.name for t in chain(union1.types, union2.types)}
return tuple(merged_types[subtype_name] for subtype_name in names)

if (
union1.resolve_type
and union2.resolve_type
and union1.resolve_type != union2.resolve_type
):
raise TypeError(
f"{union1.name} resolve_type functions don't match: "
f"{repr(union1.resolve_type)} != {repr(union2.resolve_type)}"
)

extensions = union1.extensions.copy()
extensions.update(**union2.extensions.copy())

return GraphQLUnionType(
name=union1.name,
types=thunk,
resolve_type=union1.resolve_type or union2.resolve_type,
description=union1.description or union2.description,
extensions=extensions,
ast_node=union1.ast_node or union2.ast_node,
extension_ast_nodes=union1.extension_ast_nodes or union2.extension_ast_nodes,
)
63 changes: 47 additions & 16 deletions ariadne_graphql_proxy/query_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@

from .selections import merge_selections

# Present in every GraphQL response when requested, but not listed on types in
# introspection `fields` — without an explicit allow-list, query splitting drops
# them and abstract types cannot be resolved at execution time.
_COMPOSITE_META_FIELDS = frozenset({"__typename"})

# Valid only on the root Query type; same introspection gap as __typename.
_ROOT_QUERY_INTROSPECTION_FIELDS = frozenset({"__schema", "__type"})


class QueryFilterContext:
schema_id: int
Expand Down Expand Up @@ -49,6 +57,23 @@ def __init__(
self.foreign_keys = foreign_keys
self.dependencies = dependencies

def _field_passes_type_filter(
self,
field_name: str,
type_fields: Dict[str, Set[int]],
schema_id: int,
*,
root_graphql_type: str | None = None,
) -> bool:
if field_name in _COMPOSITE_META_FIELDS:
return True
if (
root_graphql_type == "Query"
and field_name in _ROOT_QUERY_INTROSPECTION_FIELDS
):
return True
return field_name in type_fields and schema_id in type_fields[field_name]

def split_query(
self, document: DocumentNode
) -> List[Tuple[int, DocumentNode, Set[str]]]:
Expand Down Expand Up @@ -101,13 +126,16 @@ def filter_operation_node(
return None

type_fields = self.fields_map[type_name]
self.update_context_variables(operation_node, context)
new_selections: List[SelectionNode] = []

for selection in operation_node.selection_set.selections:
if isinstance(selection, FieldNode):
if (
selection.name.value not in type_fields
or context.schema_id not in type_fields[selection.name.value]
if not self._field_passes_type_filter(
selection.name.value,
type_fields,
context.schema_id,
root_graphql_type=type_name,
):
continue

Expand Down Expand Up @@ -182,6 +210,9 @@ def filter_field_node( # noqa: C901
selection_set=SelectionSetNode(selections=foreign_key),
)

if schema_obj == "Query" and field_name in _ROOT_QUERY_INTROSPECTION_FIELDS:
return field_node

type_name = self.fields_types[schema_obj][field_name]
type_is_union = type_name in self.unions

Expand All @@ -203,9 +234,8 @@ def filter_field_node( # noqa: C901
new_selections, fields_dependencies[field_name].selections
)

if (
field_name not in type_fields
or context.schema_id not in type_fields[field_name]
if not self._field_passes_type_filter(
field_name, type_fields, context.schema_id
):
continue

Expand Down Expand Up @@ -252,6 +282,7 @@ def filter_inline_fragment_node(
schema_obj: str,
context: QueryFilterContext,
) -> InlineFragmentNode | None:
self.update_context_variables(fragment_node, context)
type_name = fragment_node.type_condition.name.value
type_fields = self.fields_map[type_name]

Expand All @@ -268,9 +299,8 @@ def filter_inline_fragment_node(
new_selections, fields_dependencies[field_name].selections
)

if (
field_name not in type_fields
or context.schema_id not in type_fields[field_name]
if not self._field_passes_type_filter(
field_name, type_fields, context.schema_id
):
continue

Expand Down Expand Up @@ -306,6 +336,7 @@ def filter_fragment_spread_node(
schema_obj: str,
context: QueryFilterContext,
) -> List[SelectionNode]:
self.update_context_variables(fragment_node, context)
fragment_name = fragment_node.name.value
fragment = context.fragments.get(fragment_name)

Expand All @@ -328,9 +359,8 @@ def filter_fragment_spread_node(
new_selections, fields_dependencies[field_name].selections
)

if (
field_name not in type_fields
or context.schema_id not in type_fields[field_name]
if not self._field_passes_type_filter(
field_name, type_fields, context.schema_id
):
continue

Expand Down Expand Up @@ -389,11 +419,12 @@ def get_type_fields_dependencies(

return None

def update_context_variables(
self, field_node: FieldNode, context: QueryFilterContext
):
for argument in field_node.arguments:
def update_context_variables(self, node, context: QueryFilterContext):
for argument in getattr(node, "arguments", ()) or ():
self.extract_variables(argument.value, context) # type: ignore
for directive in getattr(node, "directives", ()) or ():
for argument in directive.arguments:
self.extract_variables(argument.value, context) # type: ignore

def extract_variables(
self,
Expand Down
26 changes: 26 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,32 @@ def search_schema_json(search_schema):
return {"data": schema_data}


@pytest.fixture
def interface_entry_schema():
return make_executable_schema(
"""
type Query {
entry: Result!
}

interface Result {
id: ID!
}

type Thing implements Result {
id: ID!
title: String!
}
"""
)


@pytest.fixture
def interface_entry_schema_json(interface_entry_schema):
schema_data = graphql_sync(interface_entry_schema, get_introspection_query()).data
return {"data": schema_data}


@pytest.fixture
def search_root_value():
return {
Expand Down
15 changes: 15 additions & 0 deletions tests/test_copy_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,18 @@ def test_copy_interface_returns_new_interface_without_excluded_arg():
assert isinstance(copied_type, GraphQLInterfaceType)
assert copied_type is not graphql_type
assert "arg1" not in copied_type.fields["fieldA"].args


def test_copy_interface_preserves_resolve_type():
def resolve_type(*_):
return "TypeName"

graphql_type = GraphQLInterfaceType(
name="TypeName",
fields={"fieldA": GraphQLField(type_=GraphQLString)},
resolve_type=resolve_type,
)

copied_type = copy_interface({}, graphql_type)

assert copied_type.resolve_type is resolve_type
15 changes: 15 additions & 0 deletions tests/test_copy_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,18 @@ def test_copy_object_returns_new_object_without_excluded_argument():
assert isinstance(copied_type, GraphQLObjectType)
assert copied_type is not graphql_type
assert "arg1" not in copied_type.fields["fieldA"].args


def test_copy_object_preserves_is_type_of():
def is_type_of(*_):
return True

graphql_type = GraphQLObjectType(
name="TypeName",
fields={"fieldA": GraphQLField(type_=GraphQLString)},
is_type_of=is_type_of,
)

copied_type = copy_object({}, graphql_type)

assert copied_type.is_type_of is is_type_of
19 changes: 19 additions & 0 deletions tests/test_copy_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,22 @@ def test_copy_union_returns_copy_of_union_with_copies_of_subtypes():
assert isinstance(copied_union_type, GraphQLUnionType)
assert copied_union_type is not union_type
assert copied_union_type.types == (duplicated_subtype1, duplicated_subtype2)


def test_copy_union_preserves_resolve_type():
def resolve_type(*_):
return "TypeA"

subtype = GraphQLObjectType(
name="TypeA", fields={"valA": GraphQLField(type_=GraphQLString)}
)
duplicated_subtype = GraphQLObjectType(
name="TypeA", fields={"valA": GraphQLField(type_=GraphQLString)}
)
union_type = GraphQLUnionType(
name="UnionType", types=[subtype], resolve_type=resolve_type
)

copied_union_type = copy_union({"TypeA": duplicated_subtype}, union_type)

assert copied_union_type.resolve_type is resolve_type
23 changes: 23 additions & 0 deletions tests/test_merge_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,26 @@ def test_merge_interfaces_raises_type_error_for_not_matching_descriptions():
interface1=interface1,
interface2=interface2,
)


def test_merge_interfaces_preserves_resolve_type():
def resolve_type(*_):
return "TestType"

interface1 = GraphQLInterfaceType(
name="TestInterface",
fields={"fieldA": GraphQLField(type_=GraphQLString)},
resolve_type=resolve_type,
)
interface2 = GraphQLInterfaceType(
name="TestInterface",
fields={"fieldA": GraphQLField(type_=GraphQLString)},
)

merged_interface = merge_interfaces(
merged_types={},
interface1=interface1,
interface2=interface2,
)

assert merged_interface.resolve_type is resolve_type
Loading
Loading