diff --git a/ariadne_codegen/client_generators/custom_operation.py b/ariadne_codegen/client_generators/custom_operation.py index 805ac32b..da9d8aae 100644 --- a/ariadne_codegen/client_generators/custom_operation.py +++ b/ariadne_codegen/client_generators/custom_operation.py @@ -103,6 +103,8 @@ def generate(self) -> ast.Module: + cast(list[ast.stmt], self._type_imports) + [self._class_def], ) + if self.plugin_manager: + return self.plugin_manager.generate_custom_module(module) return module def _add_import(self, import_: Optional[ast.ImportFrom] = None): @@ -155,7 +157,7 @@ def _generate_method( return_arguments_keys, return_arguments_values ) - return generate_method_definition( + method = generate_method_definition( name=process_name( operation_name, convert_to_snake_case=self.convert_to_snake_case, @@ -183,6 +185,10 @@ def _generate_method( decorator_list=[generate_name("classmethod")], ) + if self.plugin_manager: + return self.plugin_manager.generate_custom_method(method) + return method + def _get_return_type_and_from(self, final_type): """ Determines the return type name and its import path based on the final type. diff --git a/ariadne_codegen/plugins/base.py b/ariadne_codegen/plugins/base.py index ad29b540..7258118e 100644 --- a/ariadne_codegen/plugins/base.py +++ b/ariadne_codegen/plugins/base.py @@ -145,3 +145,11 @@ def get_file_comment( self, comment: str, code: str, source: Optional[str] = None ) -> str: return comment + + def generate_custom_module( + self, module + ) -> ast.Module: + return module + + def generate_custom_method(self, method_def: ast.FunctionDef) -> ast.FunctionDef: + return method_def diff --git a/ariadne_codegen/plugins/manager.py b/ariadne_codegen/plugins/manager.py index 243e60bb..bdd26075 100644 --- a/ariadne_codegen/plugins/manager.py +++ b/ariadne_codegen/plugins/manager.py @@ -213,3 +213,15 @@ def get_file_comment( return self._apply_plugins_on_object( "get_file_comment", comment, code=code, source=source ) + + def generate_custom_module(self, module: ast.Module) -> ast.Module: + return self._apply_plugins_on_object( + "generate_custom_module", + module + ) + + def generate_custom_method(self, method_def: ast.FunctionDef) -> ast.FunctionDef: + return self._apply_plugins_on_object( + "generate_custom_method", + method_def + ) diff --git a/tests/plugins/test_manager.py b/tests/plugins/test_manager.py index 8ee4d431..ead72f32 100644 --- a/tests/plugins/test_manager.py +++ b/tests/plugins/test_manager.py @@ -411,3 +411,38 @@ def test_get_file_comment_calls_plugins_get_file_comment( plugin1, plugin2 = plugin_manager_with_mocked_plugins.plugins assert plugin1.get_file_comment.called assert plugin2.get_file_comment.called + +def test_generate_custom_module_calls_plugins_generate_custom_module( + plugin_manager_with_mocked_plugins, +): + plugin_manager_with_mocked_plugins.generate_custom_module( + ast.Module(body=[], type_ignores=[]), + ) + + plugin1, plugin2 = plugin_manager_with_mocked_plugins.plugins + assert plugin1.generate_custom_module.called + assert plugin2.generate_custom_module.called + +def test_generate_custom_method_calls_plugins_generate_custom_method( + plugin_manager_with_mocked_plugins, +): + plugin_manager_with_mocked_plugins.generate_custom_method( + ast.FunctionDef( + name="", + args=ast.arguments( + posonlyargs=[], + args=[], + vararg=None, + kwonlyargs=[], + kw_defaults=[], + kwarg=None, + defaults=[], + ), + body=[], + decorator_list=[], + ) + ) + + plugin1, plugin2 = plugin_manager_with_mocked_plugins.plugins + assert plugin1.generate_custom_method.called + assert plugin2.generate_custom_method.called