Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "turms"
version = "0.3.1a0"
version = "0.4.0a1"
description = "graphql-codegen powered by pydantic"
authors = ["jhnnsrs <jhnnsrs@gmail.com>"]
license = "CC BY-NC 3.0"
Expand Down
1 change: 1 addition & 0 deletions turms/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def log(message, level):
)

except Exception as e:
get_console().print_exception()
project_tree.style = "red"
project_tree.label = f"{key} 💥"
project_tree.add(Tree(str(e), style="red"))
Expand Down
7 changes: 5 additions & 2 deletions turms/cli/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ def __call__(self, change, path: str) -> bool:
return False

x = os.path.basename(path)
fileending = x.split(".")[1]
try:
fileending = x.split(".")[1]

return fileending == "graphql"
return fileending == "graphql"
except Exception:
return False


def stream_changes(folder: str): # pragma: no cover
Expand Down
44 changes: 42 additions & 2 deletions turms/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pydantic import AnyHttpUrl, BaseModel, BaseSettings, Field, validator
from typing import Any, Dict, List, Optional, Union, Protocol, Literal, runtime_checkable
from turms.helpers import import_string
from graphql import DirectiveLocation
from enum import Enum


Expand All @@ -11,7 +12,6 @@ class ConfigProxy(BaseModel):
class Config:
extra = "allow"


class ImportableFunctionMixin(Protocol):
@classmethod
def __get_validators__(cls):
Expand Down Expand Up @@ -71,6 +71,10 @@ class LogFunction(Protocol):
def __call__(self, message, level: LogLevel = LogLevel.INFO):
pass

@runtime_checkable
class DirectiveFunction(ImportableFunctionMixin, Protocol):
    """Protocol for a directive resolver callable.

    Resolvers are configured as dotted import paths and loaded through
    ``ImportableFunctionMixin`` (which validates via ``import_string``);
    this marker type lets pydantic fields declare "an importable callable".
    """

    pass


class FreezeConfig(BaseSettings):
"""Configuration for freezing the generated pydantic
Expand Down Expand Up @@ -157,6 +161,28 @@ class OptionsConfig(BaseSettings):
)
"""The types to freeze"""

# Where GraphQL descriptions are emitted in generated code: on the pydantic
# Field, in the class docstring, or in both places (used by
# GeneratorConfig.description below).
DescriptionLocation = Literal["field", "docstring", "both"]
# Scalar kinds accepted for directive argument values in TurmsDirective.args.
ArgType = Literal["string", "int", "float"]

class TurmsDirective(BaseModel):
    """Configuration mapping a GraphQL directive to a resolver.

    ``type`` is the importable resolver callable invoked for the directive,
    ``locations`` restricts where the directive may appear, ``args`` declares
    the directive's argument names and their scalar kinds, and ``trim``
    controls whether the directive is stripped from the document before it
    is sent to the server.
    """

    type: DirectiveFunction
    """The type of the resolver"""
    # default_factory avoids sharing one list instance across model defaults,
    # matching the style of the `args` field below.
    locations: List[DirectiveLocation] = Field(
        default_factory=list, description="The location of the directive"
    )
    args: Optional[Dict[str, ArgType]] = Field(
        default_factory=dict, description="The arguments of the directive"
    )
    trim: bool = Field(
        True, description="This directive will be excluded and not sent to the server"
    )






class GeneratorConfig(BaseSettings):
"""Configuration for the generator
Expand All @@ -183,6 +209,13 @@ class GeneratorConfig(BaseSettings):
allow_introspection: bool = True
"""Allow introspection queries"""

description: DescriptionLocation = "docstring"
turms_directives: Dict[str, TurmsDirective] = Field(
default_factory=dict,
description="Mapping directives to a resolver e.g @upper: path.to.upper_resolver, @lower: path.to.lower_resolver",
)


object_bases: List[str] = ["pydantic.BaseModel"]
"""The base classes for the generated objects. This is useful if you want to change the base class from BaseModel to something else"""

Expand All @@ -192,7 +225,14 @@ class GeneratorConfig(BaseSettings):
"""Always resolve interfaces to concrete types"""
exclude_typenames: bool = False
"""Exclude __typename from generated models when calling dict or json"""

default_factories: Dict[str, PythonType] = Field(
default_factory=dict,
description="Mapping arguments to a default factories e.g ID: uuid.uuid4, Date: datetime.date.today, if you want to generated a default value for a list of a Type you can use standard graphql syntax like [Date]: path.to.list.factory",
)
argument_validators: Dict[str, PythonType] = Field(
default_factory=dict,
description="Will generate a validator for arguments that will always be called. This is useful for validating arguments that are not part of the schema. e.g. ID: path.to.validator",
)
scalar_definitions: Dict[str, PythonType] = Field(
default_factory=dict,
description="Additional config for mapping scalars to python types (e.g. ID: str). Can use dotted paths to import types from other modules.",
Expand Down
67 changes: 67 additions & 0 deletions turms/directive_resolvers/variable/validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import ast
from turms.registry import ClassRegistry
from graphql import VariableDefinitionNode
from ast import AST
from typing import List


def validator(
    v: "VariableDefinitionNode", body: List[AST], registry: "ClassRegistry", path: str
):
    """Append a pydantic ``@validator`` method for one operation variable.

    Builds, as an AST node, a validator named ``<param>_validator`` that
    delegates to the callable imported from ``path`` (its last dotted
    segment is the name used in the generated call).  The validator is
    registered with ``pre=True`` and ``always=True`` so it runs before
    standard validation and even when the field is absent.

    Args:
        v: The variable definition the validator is generated for.
        body: Class-body AST nodes; mutated in place and returned.
        registry: Used to register imports and derive the parameter name.
        path: Dotted import path of the validation callable.

    Returns:
        The ``body`` list with the generated ``ast.FunctionDef`` appended.
    """
    registry.register_import("pydantic.validator")
    registry.register_import(path)

    # Hoist values used twice; the generated call references the callable
    # by its bare (last-segment) name, which register_import made available.
    func_name = path.split(".")[-1]
    param_name = registry.generate_parameter_name(v.variable.name.value)

    # NOTE: ast.Constant has no `ctx` field — the previous code passed
    # ctx=ast.Load(), which is deprecated (an error in future Pythons).
    body.append(
        ast.FunctionDef(
            name=param_name + "_validator",
            args=ast.arguments(
                args=[
                    ast.arg(arg="cls"),
                    ast.arg(arg="value"),
                ],
                defaults=[],
                kw_defaults=[],
                kwarg=None,
                kwonlyargs=[],
                posonlyargs=[],
                vararg=None,
            ),
            body=[
                ast.Return(
                    value=ast.Call(
                        func=ast.Name(id=func_name, ctx=ast.Load()),
                        args=[
                            ast.Name(id="cls", ctx=ast.Load()),
                            ast.Name(id="value", ctx=ast.Load()),
                        ],
                        keywords=[],
                    ),
                ),
            ],
            decorator_list=[
                ast.Call(
                    func=ast.Name(id="validator", ctx=ast.Load()),
                    args=[
                        ast.Constant(param_name),
                    ],
                    keywords=[
                        ast.keyword(arg="pre", value=ast.Constant(True)),
                        ast.keyword(arg="always", value=ast.Constant(True)),
                    ],
                ),
            ],
            returns=None,
        )
    )
    return body
87 changes: 63 additions & 24 deletions turms/plugins/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
FieldNode,
OperationDefinitionNode,
OperationType,
VariableDefinitionNode,
NamedTypeNode,
ListTypeNode,
)
from graphql.utilities.build_client_schema import GraphQLSchema
from graphql.utilities.get_operation_root_type import get_operation_root_type
Expand All @@ -27,7 +30,7 @@
parse_value_node,
)
import logging

from turms.utils import print_operation

logger = logging.getLogger(__name__)
fragment_searcher = re.compile(r"\.\.\.(?P<fragment>[a-zA-Z]*)")
Expand All @@ -52,7 +55,6 @@ def get_query_bases(
plugin_config: OperationsPluginConfig,
registry: ClassRegistry,
):

if plugin_config.query_bases:
for base in plugin_config.query_bases:
registry.register_import(base)
Expand All @@ -76,7 +78,6 @@ def get_mutation_bases(
plugin_config: OperationsPluginConfig,
registry: ClassRegistry,
):

if plugin_config.mutation_bases:
for base in plugin_config.mutation_bases:
registry.register_import(base)
Expand All @@ -100,7 +101,6 @@ def get_arguments_bases(
plugin_config: OperationsPluginConfig,
registry: ClassRegistry,
):

if plugin_config.arguments_bases:
for base in plugin_config.arguments_bases:
registry.register_import(base)
Expand All @@ -124,7 +124,6 @@ def get_subscription_bases(
plugin_config: OperationsPluginConfig,
registry: ClassRegistry,
):

if plugin_config.subscription_bases:
for base in plugin_config.subscription_bases:
registry.register_import(base)
Expand All @@ -143,14 +142,23 @@ def get_subscription_bases(
]


def represent_variable_definition(x: "TypeNode") -> str:
    """Render a GraphQL type-reference AST node back to SDL syntax.

    Handles the three type-reference node kinds — named types (``Int``),
    non-null wrappers (``Int!``) and list wrappers (``[Int]``) — recursing
    through arbitrary nesting (e.g. ``[Int!]!``).

    Note: despite the original annotation, this is called on type nodes
    (the ``.type`` of a variable definition), not on
    ``VariableDefinitionNode`` itself.

    Raises:
        TypeError: if ``x`` is not one of the three type-reference nodes.
    """
    if isinstance(x, NamedTypeNode):
        return x.name.value
    if isinstance(x, NonNullTypeNode):
        return f"{represent_variable_definition(x.type)}!"
    if isinstance(x, ListTypeNode):
        return f"[{represent_variable_definition(x.type)}]"
    raise TypeError(f"Unknown type {type(x)}")


def generate_operation(
o: OperationDefinitionNode,
client_schema: GraphQLSchema,
config: GeneratorConfig,
plugin_config: OperationsPluginConfig,
registry: ClassRegistry,
):

tree = []
assert o.name.value, "Operation names are required"

Expand Down Expand Up @@ -197,46 +205,78 @@ def generate_operation(
),
]

query_document = language.print_ast(o)
merged_document = replace_iteratively(query_document, registry)
merged_document = print_operation(o, config, registry)

if plugin_config.create_arguments:

arguments_body = []
validators_body = []

for v in o.variable_definitions:
if isinstance(v.type, NonNullTypeNode) and not v.default_value:
arguments_body += [
variable_def_body = []

type = v.type

keywords = []

if v.default_value:
keywords.append(
ast.keyword(
arg="default",
value=ast.Constant(v.default_value.value, ctx=ast.Load()),
)
)

if keywords:
registry.register_import("pydantic.Field")
variable_def_body += [
ast.AnnAssign(
target=ast.Name(
id=registry.generate_parameter_name(v.variable.name.value),
ctx=ast.Store(),
),
annotation=recurse_type_annotation(v.type, registry),
annotation=recurse_type_annotation(
v.type,
registry,
),
value=ast.Call(
func=ast.Name(id="Field", ctx=ast.Load()),
args=[],
keywords=keywords,
),
simple=1,
)
]

if not isinstance(v.type, NonNullTypeNode) or v.default_value:
arguments_body += [
else:
variable_def_body += [
ast.AnnAssign(
target=ast.Name(
id=registry.generate_parameter_name(v.variable.name.value),
ctx=ast.Store(),
),
annotation=recurse_type_annotation(
v.type,
registry,
),
value=ast.Constant(
value=parse_value_node(v.default_value)
if v.default_value
else None
),
annotation=recurse_type_annotation(v.type, registry),
simple=1,
)
]

for directive in v.directives:
directive_name = directive.name.value
directive_def = registry.get_directive(directive_name)
if directive_def:
variable_def_body = directive_def.type(
v,
variable_def_body,
registry,
**{
arg.name.value: arg.value.value
for arg in directive.arguments
},
)

arguments_body += variable_def_body

arguments_body += validators_body

class_body_fields += [
ast.ClassDef(
"Arguments",
Expand Down Expand Up @@ -299,7 +339,6 @@ def generate_ast(
config: GeneratorConfig,
registry: ClassRegistry,
) -> List[ast.AST]:

plugin_tree = []

documents = parse_documents(
Expand Down
Loading