From 1d0299c73fd0c6c5f3b4feed42c58d900c515dcd Mon Sep 17 00:00:00 2001 From: Dan Holtmann-Rice Date: Mon, 12 Feb 2024 17:53:14 -0800 Subject: [PATCH] Enable pytype support for (nested) calls to auto_config-decorated functions. PiperOrigin-RevId: 606427622 --- fiddle/_src/experimental/auto_config.py | 158 ++++++++++++++----- fiddle/_src/experimental/auto_config_test.py | 40 ++++- 2 files changed, 150 insertions(+), 48 deletions(-) diff --git a/fiddle/_src/experimental/auto_config.py b/fiddle/_src/experimental/auto_config.py index 47e3936d..7c8fbe33 100644 --- a/fiddle/_src/experimental/auto_config.py +++ b/fiddle/_src/experimental/auto_config.py @@ -30,7 +30,8 @@ import linecache import textwrap import types -from typing import Any, Callable, Optional, Type, cast +import typing +from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar, cast, overload from fiddle._src import arg_factory from fiddle._src import building @@ -49,24 +50,31 @@ _ATTR_SAVE_TEMP_VAR_ID = '_attr_save_temp' _CLOSURE_WRAPPER_ID = '__auto_config_closure_wrapper__' _EMPTY_ARGUMENTS = ast.arguments( - posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]) + posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[] +) _BUILTINS = frozenset([ - builtin for builtin in builtins.__dict__.values() + builtin + for builtin in builtins.__dict__.values() if inspect.isroutine(builtin) or inspect.isclass(builtin) ]) +_GenericCallable = TypeVar('_GenericCallable', bound=Callable[..., Any]) +T = TypeVar('T') + + @dataclasses.dataclass(frozen=True) -class AutoConfig: +class AutoConfig(Generic[T]): """A function wrapper for auto_config'd functions. In order to support auto_config'ing @classmethod's, we need to customize the descriptor protocol for the auto_config'd function. This simple wrapper type - is designed to look like a simple `functool.wraps` wrapper, but implements - custom behavior for bound methods. + is designed to look like a `functool.wraps` wrapper, but implements custom + behavior for bound methods. """ - func: Callable[..., Any] - buildable_func: Callable[..., config.Buildable] + + func: T + buildable_func: Callable[..., Any] always_inline: bool @property @@ -74,10 +82,15 @@ def nowrap(self): return True # Tells Flax not to decorate this object, for classmethods. def __post_init__(self): - # Must copy-over to correctly implement "functools.wraps"-like - # functionality. - for name in ('__module__', '__name__', '__qualname__', '__doc__', - '__annotations__'): + # These attributes must be copied over to in order to correctly implement + # "functools.wraps"-like functionality. + for name in ( + '__module__', + '__name__', + '__qualname__', + '__doc__', + '__annotations__', + ): try: value = getattr(self.func, name) except AttributeError: @@ -85,10 +98,21 @@ def __post_init__(self): else: object.__setattr__(self, name, value) - def __call__(self, *args, **kwargs) -> Any: - return self.func(*args, **kwargs) + if typing.TYPE_CHECKING: + __module__: str + __name__: str + __qualname__: str + __doc__: str + __annotations__: Dict[str, Any] + # The following informs type checkers that the call method has the same + # signature/annotations as `func`. + __call__: T + else: + # Actual implementation of __call__, which forwards parameters to `func`. + def __call__(self, *args, **kwargs) -> Any: + return self.func(*args, **kwargs) - def as_buildable(self, *args, **kwargs) -> config.Buildable: + def as_buildable(self, *args, **kwargs) -> Any: return self.buildable_func(*args, **kwargs) def __get__(self, obj, objtype=None): @@ -96,7 +120,8 @@ def __get__(self, obj, objtype=None): return AutoConfig( func=self.func.__get__(obj, objtype), buildable_func=self.buildable_func.__get__(obj, objtype), - always_inline=self.always_inline) + always_inline=self.always_inline, + ) # pytype: enable=attribute-error @property @@ -116,11 +141,13 @@ class UnsupportedLanguageConstructError(SyntaxError): class _AutoConfigNodeTransformer(ast.NodeTransformer): """A NodeTransformer that adds the auto-config call handler into an AST.""" - def __init__(self, - source: str, - filename: str, - line_number: int, - allow_control_flow=False): + def __init__( + self, + source: str, + filename: str, + line_number: int, + allow_control_flow=False, + ): """Initializes the auto config node transformer instance. Args: @@ -191,7 +218,8 @@ def _validate_decorator_ordering(self, node: ast.FunctionDef): raise AssertionError( f'@{decorator} placed above @auto_config on function {node.name} ' f'at {self._filename}:{self._line_number}. Reorder decorators so ' - f'that @auto_config is placed above @{decorator}.') + f'that @auto_config is placed above @{decorator}.' + ) # pylint: disable=invalid-name def visit_Call(self, node: ast.Call): @@ -432,7 +460,8 @@ def fn(...): # Or some expression involving a lambda. *closure_var_definitions, *module.body, ], - decorator_list=[]) + decorator_list=[], + ) ], type_ignores=[], ) @@ -443,7 +472,8 @@ def fn(...): # Or some expression involving a lambda. def _find_function_code(code: types.CodeType, fn_name: str): """Finds the code object within `code` corresponding to `fn_name`.""" code = [ - const for const in code.co_consts + const + for const in code.co_consts if inspect.iscode(const) and const.co_name == fn_name ] assert len(code) == 1, f"Couldn't find function code for {fn_name!r}." @@ -553,7 +583,7 @@ def _make_partial(partial_cls, buildable_or_callable, *args, **kwargs): return partial_cls(buildable_or_callable, *args, **kwargs) -def exempt(fn_or_cls: Callable[..., Any]) -> Callable[..., Any]: +def exempt(fn_or_cls: _GenericCallable) -> _GenericCallable: """Wrap a callable so that it's exempted from auto_config. This can be used either as a decorator to exempt a function, or used inside @@ -591,8 +621,36 @@ class ConfigTypes: arg_factory_cls: Type[partial.ArgFactory] = partial.ArgFactory +@overload +def auto_config( + fn: _GenericCallable, + *, + experimental_allow_dataclass_attribute_access: bool = False, + experimental_allow_control_flow: bool = False, + experimental_always_inline: Optional[bool] = None, + experimental_exemption_policy: Optional[auto_config_policy.Policy] = None, + experimental_config_types: ConfigTypes = ConfigTypes(), + experimental_result_must_contain_buildable: bool = True, +) -> AutoConfig[_GenericCallable]: + ... + + +@overload +def auto_config( + fn: None = None, + *, + experimental_allow_dataclass_attribute_access: bool = False, + experimental_allow_control_flow: bool = False, + experimental_always_inline: Optional[bool] = None, + experimental_exemption_policy: Optional[auto_config_policy.Policy] = None, + experimental_config_types: ConfigTypes = ConfigTypes(), + experimental_result_must_contain_buildable: bool = True, +) -> Callable[[_GenericCallable], AutoConfig[_GenericCallable]]: + ... + + def auto_config( - fn=None, + fn: Optional[_GenericCallable] = None, *, experimental_allow_dataclass_attribute_access=False, experimental_allow_control_flow: bool = False, @@ -778,9 +836,11 @@ def auto_config_attr_save_handler(obj, attr, value, allow_dataclass=True): def make_auto_config(fn): if not isinstance(fn, (types.FunctionType, classmethod, staticmethod)): - raise ValueError('`auto_config` is only compatible with functions, ' - f'`@classmethod`s, and `@staticmethod`s. Got {fn!r} ' - f'with type {type(fn)!r}.') + raise ValueError( + '`auto_config` is only compatible with functions, ' + f'`@classmethod`s, and `@staticmethod`s. Got {fn!r} ' + f'with type {type(fn)!r}.' + ) if isinstance(fn, (classmethod, staticmethod)): method_type = type(fn) @@ -799,7 +859,8 @@ def make_auto_config(fn): source=source, filename=filename, line_number=line_number, - allow_control_flow=experimental_allow_control_flow) + allow_control_flow=experimental_allow_control_flow, + ) # Parse the AST, and modify it by intercepting all `Call`s with the # `auto_config_call_handler`. Finally, ensure line numbers and code @@ -882,7 +943,8 @@ def as_buildable(*args, **kwargs): fn = method_type(fn) as_buildable = method_type(as_buildable) return AutoConfig( - fn, as_buildable, always_inline=experimental_always_inline) + fn, as_buildable, always_inline=experimental_always_inline + ) # Decorator with empty parenthesis. if fn is None: @@ -951,7 +1013,6 @@ def main(): experimental_always_inline = True def make_unconfig(fn) -> AutoConfig: - @functools.wraps(fn) def python_implementation(*args, **kwargs): previous = building._state.in_build # pytype: disable=module-attr # pylint: disable=protected-access @@ -965,7 +1026,8 @@ def python_implementation(*args, **kwargs): return AutoConfig( func=python_implementation, buildable_func=fn, - always_inline=experimental_always_inline) + always_inline=experimental_always_inline, + ) # We use this pattern to support using the decorator with and without # parenthesis. @@ -1023,20 +1085,27 @@ def make_experiment(): doesn't correspond to an ``auto_config``'d function. """ if not isinstance(buildable, config.Config): - raise ValueError('Cannot `inline` non-Config buildables; ' - f'{type(buildable)} is not compatible.') + raise ValueError( + 'Cannot `inline` non-Config buildables; ' + f'{type(buildable)} is not compatible.' + ) if not is_auto_config(buildable.__fn_or_cls__): - raise ValueError('Cannot `inline` a non-auto_config function; ' - f'`{buildable.__fn_or_cls__}` is not compatible.') + raise ValueError( + 'Cannot `inline` a non-auto_config function; ' + f'`{buildable.__fn_or_cls__}` is not compatible.' + ) # Evaluate the `as_buildable` interpretation. auto_config_fn = cast(AutoConfig, buildable.__fn_or_cls__) tmp_config = auto_config_fn.as_buildable(**buildable.__arguments__) if not isinstance(tmp_config, config.Buildable): - raise ValueError('You cannot currently inline functions that do not return ' - '`fdl.Buildable`s.') + raise ValueError( + 'You cannot currently inline functions that do not return ' + '`fdl.Buildable`s.' + ) mutate_buildable.move_buildable_internals( - source=tmp_config, destination=buildable) + source=tmp_config, destination=buildable + ) def _getsource(fn: Any) -> str: @@ -1056,11 +1125,12 @@ def _is_lambda(fn: Any) -> bool: return False if not (hasattr(fn, '__name__') and hasattr(fn, '__code__')): return False - return ((fn.__name__ == '') or (fn.__code__.co_name == '')) + return (fn.__name__ == '') or (fn.__code__.co_name == '') class _LambdaFinder(cst.CSTVisitor): """CST Visitor that searches for the source code for a given lambda func.""" + METADATA_DEPENDENCIES = (cst.metadata.PositionProvider,) def __init__(self, lambda_fn): @@ -1095,7 +1165,8 @@ def _getsource_for_lambda(fn: Callable[..., Any]) -> str: elif not lambda_finder.candidates: raise ValueError( 'Fiddle auto_config was unable to find the source code for ' - f'{fn}: could not find lambda on line {lambda_finder.lineno}.') + f'{fn}: could not find lambda on line {lambda_finder.lineno}.' + ) else: # TODO(b/258671226): If desired, we could narrow down which lambda is # used based on the signature (or even fancier things like the checking @@ -1103,7 +1174,8 @@ def _getsource_for_lambda(fn: Callable[..., Any]) -> str: raise ValueError( 'Fiddle auto_config was unable to find the source code for ' f'{fn}: multiple lambdas found on line {lambda_finder.lineno}; ' - 'try moving each lambda to its own line.') + 'try moving each lambda to its own line.' + ) def with_buildable_func( diff --git a/fiddle/_src/experimental/auto_config_test.py b/fiddle/_src/experimental/auto_config_test.py index ce138c7d..338630df 100644 --- a/fiddle/_src/experimental/auto_config_test.py +++ b/fiddle/_src/experimental/auto_config_test.py @@ -21,7 +21,7 @@ import functools import inspect import sys -from typing import Any +from typing import Any, TypeVar from absl.testing import absltest from absl.testing import parameterized @@ -31,6 +31,7 @@ from fiddle._src.experimental import auto_config_policy from fiddle._src.experimental import autobuilders as ab from fiddle._src.testing import test_util +import pytype_extensions def basic_fn(arg, kwarg='test'): @@ -87,7 +88,10 @@ def globals_test_fn(): return pass_through(5) -def pass_through(arg): +T = TypeVar('T') + + +def pass_through(arg: T) -> T: return arg @@ -346,12 +350,20 @@ def test_class_config(arg1, *, arg2='default'): test_class_config('positional'), ) + def test_type_inference(self): + @auto_config.auto_config + def test_fn_config(arg: int) -> int: + return pass_through(arg) + + output = test_fn_config(42) + pytype_extensions.assert_type(output, int) + def test_calling_auto_config(self): expected_config = fdl.Config( basic_fn, 1, kwarg=fdl.Config(FrozenSampleClass, 1, 2) ) - @auto_config.auto_config(experimental_always_inline=True) + @auto_config.auto_config def test_class_config(): return FrozenSampleClass(1, arg2=2) @@ -364,6 +376,19 @@ def test_fn_config(): {'arg': 1, 'kwarg': FrozenSampleClass(1, arg2=2)}, test_fn_config() ) + def test_calling_auto_config_type_inference(self): + @auto_config.auto_config + def test_class_config() -> FrozenSampleClass: + return FrozenSampleClass(1, arg2=2) + + @auto_config.auto_config + def test_fn_config(): + output = pass_through(test_class_config()) + pytype_extensions.assert_type(output, FrozenSampleClass) + return output + + del test_fn_config # Unused + def test_nested_calls(self): expected_config = fdl.Config( FrozenSampleClass, 1, arg2=fdl.Config(basic_fn, 2, 'kwarg') @@ -460,7 +485,7 @@ def autobuilder_using_fn(): def test_auto_configuring_non_function(self): with self.assertRaisesRegex(ValueError, 'only compatible with functions'): - auto_config.auto_config(3) + auto_config.auto_config(3) # pytype: disable=wrong-arg-types def test_return_structure(self): expected_config = { @@ -779,6 +804,11 @@ def make_sample(max_value): self.assertEqual(cfg.arg2.arg1, fdl.Config(pass_through, 5)) self.assertEqual(cfg.arg2.arg2, 5) + def test_exemption_type_inference(self): + exempted_func = auto_config.exempt(pass_through) + value = exempted_func(42) + pytype_extensions.assert_type(value, int) + def test_lambda_supported_in_decorator(self): @auto_config.auto_config(experimental_exemption_policy=lambda x: False) def make_sample(): @@ -1044,7 +1074,7 @@ def test_can_pass_self_as_keyword(self): FrozenSampleClass.autoconfig_method(self=x), {'arg': x, 'kwarg': 'test'} ) self.assertDagEqual( - FrozenSampleClass.autoconfig_method.as_buildable(self=x), + FrozenSampleClass.autoconfig_method.as_buildable(self=x), # pytype: disable=duplicate-keyword-argument fdl.Config(basic_fn, x), )