Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 52 additions & 35 deletions ariadne_codegen/client_generators/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(
default_optional_fields_to_none: bool = False,
include_typename: bool = True,
ignore_extra_fields: bool = True,
models_only: bool = False,
) -> None:
self.package_path = Path(target_path) / package_name

Expand Down Expand Up @@ -150,21 +151,23 @@ def __init__(
self._unpacked_fragments: set[str] = set()
self._used_enums: list[str] = []

self.models_only = models_only
self.enable_custom_operations = enable_custom_operations
if self.enable_custom_operations:
self.files_to_include.append(self.base_schema_root_file_path)

def generate(self) -> list[str]:
"""Generate package with graphql client."""
self._include_exceptions()
if not self.models_only:
self._include_exceptions()
self._validate_unique_file_names()
if not self.package_path.exists():
self.package_path.mkdir()
self._generate_input_types()
self._generate_result_types()
self._generate_fragments()
self._copy_files()
if self.enable_custom_operations:
if not self.models_only and self.enable_custom_operations:
self._generate_custom_fields_typing()
self._generate_custom_fields()
self.client_generator.add_execute_custom_operation_method(self.async_client)
Expand All @@ -179,7 +182,8 @@ def generate(self) -> list[str]:
"mutation", OperationType.MUTATION.value.upper(), self.async_client
)

self._generate_client()
if not self.models_only:
self._generate_client()
self._generate_enums()
self._generate_init()

Expand Down Expand Up @@ -223,14 +227,15 @@ def add_operation(self, definition: OperationDefinitionNode):
query_types_generator.get_generated_public_names(), module_name, 1
)

self.client_generator.add_method(
definition=definition,
name=method_name,
return_type=return_type_name,
return_type_module=module_name,
operation_str=operation_str,
async_=self.async_client,
)
if not self.models_only:
self.client_generator.add_method(
definition=definition,
name=method_name,
return_type=return_type_name,
return_type_module=module_name,
operation_str=operation_str,
async_=self.async_client,
)

def _include_exceptions(self):
if self.base_client_file_path in (
Expand All @@ -247,18 +252,19 @@ def _include_exceptions(self):
)

def _validate_unique_file_names(self):
file_names = (
[
file_names = [
self.base_model_file_path.name,
f"{self.enums_module_name}.py",
f"{self.input_types_module_name}.py",
f"{self.fragments_module_name}.py",
]
if not self.models_only:
file_names += [
f"{self.client_file_name}.py",
self.base_client_file_path.name,
self.base_model_file_path.name,
f"{self.enums_module_name}.py",
f"{self.input_types_module_name}.py",
f"{self.fragments_module_name}.py",
]
+ list(self._result_types_files.keys())
+ [f.name for f in self.files_to_include]
)
file_names += list(self._result_types_files.keys())
file_names += [f.name for f in self.files_to_include]

if len(file_names) != len(set(file_names)):
seen = set()
Expand Down Expand Up @@ -310,7 +316,7 @@ def _generate_enums(self):
)

def _generate_input_types(self):
if self.include_all_inputs:
if self.include_all_inputs or self.models_only:
module = self.input_types_generator.generate()
else:
used_inputs = self.client_generator.arguments_generator.get_used_inputs()
Expand Down Expand Up @@ -359,10 +365,9 @@ def _generate_fragments(self):
)

def _copy_files(self):
files_to_copy = self.files_to_include + [
self.base_client_file_path,
self.base_model_file_path,
]
files_to_copy = self.files_to_include + [self.base_model_file_path]
if not self.models_only:
files_to_copy.append(self.base_client_file_path)
for source_path in files_to_copy:
code = self._add_comments_to_code(source_path.read_text(encoding="utf-8"))
if not self.ignore_extra_fields and source_path.name == "base_model.py":
Expand All @@ -373,11 +378,12 @@ def _copy_files(self):
target_path.write_text(code)
self._generated_files.append(target_path.name)

self.init_generator.add_import(
names=[self.base_client_name],
from_=self.base_client_file_path.stem,
level=1,
)
if not self.models_only:
self.init_generator.add_import(
names=[self.base_client_name],
from_=self.base_client_file_path.stem,
level=1,
)
self.init_generator.add_import(
names=[BASE_MODEL_CLASS_NAME, UPLOAD_CLASS_NAME],
from_=self.base_model_file_path.stem,
Expand Down Expand Up @@ -429,20 +435,30 @@ def get_package_generator(
plugin_manager: PluginManager,
) -> PackageGenerator:
init_generator = InitFileGenerator(plugin_manager=plugin_manager)
client_generator = ClientGenerator(
base_client_import=generate_import_from(
if settings.models_only:
base_client_import = generate_import_from(
names=["object"], from_="builtins", level=0
)
client_name = "Client"
base_client = "object"
else:
base_client_import = generate_import_from(
names=[settings.base_client_name],
from_=Path(settings.base_client_file_path).stem,
level=1,
),
)
client_name = settings.client_name
base_client = settings.base_client_name
client_generator = ClientGenerator(
base_client_import=base_client_import,
arguments_generator=ArgumentsGenerator(
schema=schema,
convert_to_snake_case=settings.convert_to_snake_case,
custom_scalars=settings.scalars,
plugin_manager=plugin_manager,
),
name=settings.client_name,
base_client=settings.base_client_name,
name=client_name,
base_client=base_client,
enums_module_name=settings.enums_module_name,
input_types_module_name=settings.input_types_module_name,
unset_import=UNSET_IMPORT,
Expand Down Expand Up @@ -553,4 +569,5 @@ def get_package_generator(
default_optional_fields_to_none=settings.default_optional_fields_to_none,
include_typename=settings.include_typename,
ignore_extra_fields=settings.ignore_extra_fields,
models_only=settings.models_only,
)
47 changes: 36 additions & 11 deletions ariadne_codegen/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,14 @@ class ClientSettings(BaseSettings):
default_optional_fields_to_none: bool = False
include_typename: bool = True
ignore_extra_fields: bool = True
models_only: bool = False

def __post_init__(self):
if not self.queries_path and not self.enable_custom_operations:
if (
not self.queries_path
and not self.enable_custom_operations
and not self.models_only
):
raise TypeError("__init__ missing 1 required argument: 'queries_path'")
super().__post_init__()

Expand All @@ -93,24 +98,27 @@ def __post_init__(self):
f"Valid options are: {valid_options}"
) from exc

self._set_default_base_client_data()
if not self.models_only:
self._set_default_base_client_data()

for name, data in self.scalars.items():
data.graphql_name = name

assert_path_exists(self.queries_path)
if self.queries_path:
assert_path_exists(self.queries_path)

assert_string_is_valid_python_identifier(self.target_package_name)
assert_path_is_valid_directory(self.target_package_path)

assert_string_is_valid_python_identifier(self.client_name)
assert_string_is_valid_python_identifier(self.client_file_name)
assert_string_is_valid_python_identifier(self.base_client_name)
assert_path_exists(self.base_client_file_path)
assert_path_is_valid_file(self.base_client_file_path)
assert_class_is_defined_in_file(
Path(self.base_client_file_path), self.base_client_name
)
if not self.models_only:
assert_string_is_valid_python_identifier(self.client_name)
assert_string_is_valid_python_identifier(self.client_file_name)
assert_string_is_valid_python_identifier(self.base_client_name)
assert_path_exists(self.base_client_file_path)
assert_path_is_valid_file(self.base_client_file_path)
assert_class_is_defined_in_file(
Path(self.base_client_file_path), self.base_client_name
)

assert_string_is_valid_python_identifier(self.enums_module_name)
assert_string_is_valid_python_identifier(self.input_types_module_name)
Expand Down Expand Up @@ -177,6 +185,23 @@ def used_settings_message(self) -> str:
if self.include_typename
else "Not including __typename fields in generated queries."
)
if self.models_only:
return dedent(
f"""\
Selected strategy: {Strategy.CLIENT}
Generating models only.
Using schema from '{self.schema_path or self.remote_schema_url}'.
Using '{self.target_package_name}' as package name.
Generating package into '{self.target_package_path}'.
Generating enums into '{self.enums_module_name}.py'.
Generating inputs into '{self.input_types_module_name}.py'.
Generating fragments into '{self.fragments_module_name}.py'.
Comments type: {self.include_comments.value}
{snake_case_msg}
{files_to_include_msg}
{plugins_msg}
"""
)
return dedent(
f"""\
Selected strategy: {Strategy.CLIENT}
Expand Down
49 changes: 49 additions & 0 deletions tests/client_generators/package_generator/test_generated_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,3 +764,52 @@ def test_generate_creates_client_with_custom_scalars_imports(
f"{generator.client_file_name}.py"
).open() as client_file:
assert "from .abc import ScalarABC" in client_file.read()


def test_generate_models_only(tmp_path, schema, async_base_client_import):
package_name = "test_graphql_client"
generator = PackageGenerator(
package_name=package_name,
target_path=tmp_path.as_posix(),
schema=schema,
init_generator=InitFileGenerator(),
client_generator=ClientGenerator(
base_client_import=async_base_client_import,
arguments_generator=ArgumentsGenerator(schema=schema),
),
enums_generator=EnumsGenerator(schema=schema),
input_types_generator=InputTypesGenerator(schema=schema),
fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}),
models_only=True,
)
query_str = """
query CustomQuery($id: ID!) {
query1(id: $id) {
field1
}
}
"""
generator.add_operation(parse(query_str).definitions[0])
generated_files = generator.generate()

package_path = tmp_path / package_name
# Model files should exist
assert (package_path / "__init__.py").exists()
assert (package_path / "base_model.py").exists()
assert (package_path / f"{generator.enums_module_name}.py").exists()
assert (package_path / f"{generator.input_types_module_name}.py").exists()
# Result types from operations should still be generated
assert (package_path / "custom_query.py").exists()
assert "custom_query.py" in generated_files
# Client runtime files should NOT exist
assert not (package_path / "client.py").exists()
assert not (package_path / generator.base_client_file_path.name).exists()
assert not (package_path / EXCEPTIONS_FILE_PATH.name).exists()
assert "client.py" not in generated_files
assert EXCEPTIONS_FILE_PATH.name not in generated_files
# __init__.py should not import client classes
init_content = (package_path / "__init__.py").read_text()
assert "from .base_model import BaseModel, Upload" in init_content
assert "Client" not in init_content
assert "AsyncBaseClient" not in init_content
assert "GraphQLClientError" not in init_content
19 changes: 19 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,3 +479,22 @@ def test_client_settings_include_typename_can_be_set_to_true(tmp_path):
)

assert settings.include_typename is True


def test_client_settings_models_only(tmp_path):
schema_path = tmp_path / "schema.graphql"
schema_path.touch()

settings = ClientSettings(
schema_path=schema_path.as_posix(),
models_only=True,
)

assert settings.models_only is True
assert settings.queries_path == ""
assert settings.base_client_name == ""
assert settings.base_client_file_path == ""

result = settings.used_settings_message
assert "Generating models only." in result
assert settings.schema_path in result