diff --git a/ariadne_codegen/client_generators/package.py b/ariadne_codegen/client_generators/package.py index c84e48f4..ceac9704 100644 --- a/ariadne_codegen/client_generators/package.py +++ b/ariadne_codegen/client_generators/package.py @@ -92,6 +92,7 @@ def __init__( default_optional_fields_to_none: bool = False, include_typename: bool = True, ignore_extra_fields: bool = True, + models_only: bool = False, ) -> None: self.package_path = Path(target_path) / package_name @@ -150,13 +151,15 @@ def __init__( self._unpacked_fragments: set[str] = set() self._used_enums: list[str] = [] + self.models_only = models_only self.enable_custom_operations = enable_custom_operations if self.enable_custom_operations: self.files_to_include.append(self.base_schema_root_file_path) def generate(self) -> list[str]: """Generate package with graphql client.""" - self._include_exceptions() + if not self.models_only: + self._include_exceptions() self._validate_unique_file_names() if not self.package_path.exists(): self.package_path.mkdir() @@ -164,7 +167,7 @@ def generate(self) -> list[str]: self._generate_result_types() self._generate_fragments() self._copy_files() - if self.enable_custom_operations: + if not self.models_only and self.enable_custom_operations: self._generate_custom_fields_typing() self._generate_custom_fields() self.client_generator.add_execute_custom_operation_method(self.async_client) @@ -179,7 +182,8 @@ def generate(self) -> list[str]: "mutation", OperationType.MUTATION.value.upper(), self.async_client ) - self._generate_client() + if not self.models_only: + self._generate_client() self._generate_enums() self._generate_init() @@ -223,14 +227,15 @@ def add_operation(self, definition: OperationDefinitionNode): query_types_generator.get_generated_public_names(), module_name, 1 ) - self.client_generator.add_method( - definition=definition, - name=method_name, - return_type=return_type_name, - return_type_module=module_name, - operation_str=operation_str, - async_=self.async_client, - ) + if not self.models_only: + self.client_generator.add_method( + definition=definition, + name=method_name, + return_type=return_type_name, + return_type_module=module_name, + operation_str=operation_str, + async_=self.async_client, + ) def _include_exceptions(self): if self.base_client_file_path in ( @@ -247,18 +252,19 @@ def _include_exceptions(self): ) def _validate_unique_file_names(self): - file_names = ( - [ + file_names = [ + self.base_model_file_path.name, + f"{self.enums_module_name}.py", + f"{self.input_types_module_name}.py", + f"{self.fragments_module_name}.py", + ] + if not self.models_only: + file_names += [ f"{self.client_file_name}.py", self.base_client_file_path.name, - self.base_model_file_path.name, - f"{self.enums_module_name}.py", - f"{self.input_types_module_name}.py", - f"{self.fragments_module_name}.py", ] - + list(self._result_types_files.keys()) - + [f.name for f in self.files_to_include] - ) + file_names += list(self._result_types_files.keys()) + file_names += [f.name for f in self.files_to_include] if len(file_names) != len(set(file_names)): seen = set() @@ -310,7 +316,7 @@ def _generate_enums(self): ) def _generate_input_types(self): - if self.include_all_inputs: + if self.include_all_inputs or self.models_only: module = self.input_types_generator.generate() else: used_inputs = self.client_generator.arguments_generator.get_used_inputs() @@ -359,10 +365,9 @@ def _generate_fragments(self): ) def _copy_files(self): - files_to_copy = self.files_to_include + [ - self.base_client_file_path, - self.base_model_file_path, - ] + files_to_copy = self.files_to_include + [self.base_model_file_path] + if not self.models_only: + files_to_copy.append(self.base_client_file_path) for source_path in files_to_copy: code = self._add_comments_to_code(source_path.read_text(encoding="utf-8")) if not self.ignore_extra_fields and source_path.name == "base_model.py": @@ -373,11 +378,12 @@ def _copy_files(self): target_path.write_text(code) self._generated_files.append(target_path.name) - self.init_generator.add_import( - names=[self.base_client_name], - from_=self.base_client_file_path.stem, - level=1, - ) + if not self.models_only: + self.init_generator.add_import( + names=[self.base_client_name], + from_=self.base_client_file_path.stem, + level=1, + ) self.init_generator.add_import( names=[BASE_MODEL_CLASS_NAME, UPLOAD_CLASS_NAME], from_=self.base_model_file_path.stem, @@ -429,20 +435,30 @@ def get_package_generator( plugin_manager: PluginManager, ) -> PackageGenerator: init_generator = InitFileGenerator(plugin_manager=plugin_manager) - client_generator = ClientGenerator( - base_client_import=generate_import_from( + if settings.models_only: + base_client_import = generate_import_from( + names=["object"], from_="builtins", level=0 + ) + client_name = "Client" + base_client = "object" + else: + base_client_import = generate_import_from( names=[settings.base_client_name], from_=Path(settings.base_client_file_path).stem, level=1, - ), + ) + client_name = settings.client_name + base_client = settings.base_client_name + client_generator = ClientGenerator( + base_client_import=base_client_import, arguments_generator=ArgumentsGenerator( schema=schema, convert_to_snake_case=settings.convert_to_snake_case, custom_scalars=settings.scalars, plugin_manager=plugin_manager, ), - name=settings.client_name, - base_client=settings.base_client_name, + name=client_name, + base_client=base_client, enums_module_name=settings.enums_module_name, input_types_module_name=settings.input_types_module_name, unset_import=UNSET_IMPORT, @@ -553,4 +569,5 @@ def get_package_generator( default_optional_fields_to_none=settings.default_optional_fields_to_none, include_typename=settings.include_typename, ignore_extra_fields=settings.ignore_extra_fields, + models_only=settings.models_only, ) diff --git a/ariadne_codegen/settings.py b/ariadne_codegen/settings.py index 74d29e17..1862d343 100644 --- a/ariadne_codegen/settings.py +++ b/ariadne_codegen/settings.py @@ -78,9 +78,14 @@ class ClientSettings(BaseSettings): default_optional_fields_to_none: bool = False include_typename: bool = True ignore_extra_fields: bool = True + models_only: bool = False def __post_init__(self): - if not self.queries_path and not self.enable_custom_operations: + if ( + not self.queries_path + and not self.enable_custom_operations + and not self.models_only + ): raise TypeError("__init__ missing 1 required argument: 'queries_path'") super().__post_init__() @@ -93,24 +98,27 @@ def __post_init__(self): f"Valid options are: {valid_options}" ) from exc - self._set_default_base_client_data() + if not self.models_only: + self._set_default_base_client_data() for name, data in self.scalars.items(): data.graphql_name = name - assert_path_exists(self.queries_path) + if self.queries_path: + assert_path_exists(self.queries_path) assert_string_is_valid_python_identifier(self.target_package_name) assert_path_is_valid_directory(self.target_package_path) - assert_string_is_valid_python_identifier(self.client_name) - assert_string_is_valid_python_identifier(self.client_file_name) - assert_string_is_valid_python_identifier(self.base_client_name) - assert_path_exists(self.base_client_file_path) - assert_path_is_valid_file(self.base_client_file_path) - assert_class_is_defined_in_file( - Path(self.base_client_file_path), self.base_client_name - ) + if not self.models_only: + assert_string_is_valid_python_identifier(self.client_name) + assert_string_is_valid_python_identifier(self.client_file_name) + assert_string_is_valid_python_identifier(self.base_client_name) + assert_path_exists(self.base_client_file_path) + assert_path_is_valid_file(self.base_client_file_path) + assert_class_is_defined_in_file( + Path(self.base_client_file_path), self.base_client_name + ) assert_string_is_valid_python_identifier(self.enums_module_name) assert_string_is_valid_python_identifier(self.input_types_module_name) @@ -177,6 +185,23 @@ def used_settings_message(self) -> str: if self.include_typename else "Not including __typename fields in generated queries." ) + if self.models_only: + return dedent( + f"""\ + Selected strategy: {Strategy.CLIENT} + Generating models only. + Using schema from '{self.schema_path or self.remote_schema_url}'. + Using '{self.target_package_name}' as package name. + Generating package into '{self.target_package_path}'. + Generating enums into '{self.enums_module_name}.py'. + Generating inputs into '{self.input_types_module_name}.py'. + Generating fragments into '{self.fragments_module_name}.py'. + Comments type: {self.include_comments.value} + {snake_case_msg} + {files_to_include_msg} + {plugins_msg} + """ + ) return dedent( f"""\ Selected strategy: {Strategy.CLIENT} diff --git a/tests/client_generators/package_generator/test_generated_files.py b/tests/client_generators/package_generator/test_generated_files.py index d0e12932..01349769 100644 --- a/tests/client_generators/package_generator/test_generated_files.py +++ b/tests/client_generators/package_generator/test_generated_files.py @@ -764,3 +764,52 @@ def test_generate_creates_client_with_custom_scalars_imports( f"{generator.client_file_name}.py" ).open() as client_file: assert "from .abc import ScalarABC" in client_file.read() + + +def test_generate_models_only(tmp_path, schema, async_base_client_import): + package_name = "test_graphql_client" + generator = PackageGenerator( + package_name=package_name, + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), + models_only=True, + ) + query_str = """ + query CustomQuery($id: ID!) { + query1(id: $id) { + field1 + } + } + """ + generator.add_operation(parse(query_str).definitions[0]) + generated_files = generator.generate() + + package_path = tmp_path / package_name + # Model files should exist + assert (package_path / "__init__.py").exists() + assert (package_path / "base_model.py").exists() + assert (package_path / f"{generator.enums_module_name}.py").exists() + assert (package_path / f"{generator.input_types_module_name}.py").exists() + # Result types from operations should still be generated + assert (package_path / "custom_query.py").exists() + assert "custom_query.py" in generated_files + # Client runtime files should NOT exist + assert not (package_path / "client.py").exists() + assert not (package_path / generator.base_client_file_path.name).exists() + assert not (package_path / EXCEPTIONS_FILE_PATH.name).exists() + assert "client.py" not in generated_files + assert EXCEPTIONS_FILE_PATH.name not in generated_files + # __init__.py should not import client classes + init_content = (package_path / "__init__.py").read_text() + assert "from .base_model import BaseModel, Upload" in init_content + assert "Client" not in init_content + assert "AsyncBaseClient" not in init_content + assert "GraphQLClientError" not in init_content diff --git a/tests/test_settings.py b/tests/test_settings.py index 047bd1e1..b822920e 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -479,3 +479,22 @@ def test_client_settings_include_typename_can_be_set_to_true(tmp_path): ) assert settings.include_typename is True + + +def test_client_settings_models_only(tmp_path): + schema_path = tmp_path / "schema.graphql" + schema_path.touch() + + settings = ClientSettings( + schema_path=schema_path.as_posix(), + models_only=True, + ) + + assert settings.models_only is True + assert settings.queries_path == "" + assert settings.base_client_name == "" + assert settings.base_client_file_path == "" + + result = settings.used_settings_message + assert "Generating models only." in result + assert settings.schema_path in result