diff --git a/fiddle/_src/codegen/auto_config/code_ir.py b/fiddle/_src/codegen/auto_config/code_ir.py index 858982b4..489f11b4 100644 --- a/fiddle/_src/codegen/auto_config/code_ir.py +++ b/fiddle/_src/codegen/auto_config/code_ir.py @@ -109,6 +109,11 @@ class ModuleReference(BaseNameReference): """Reference to an imported module.""" +@dataclasses.dataclass +class BuiltinReference(BaseNameReference): + """Reference to an imported module.""" + + @dataclasses.dataclass class FixtureReference(BaseNameReference): """Reference to another fixture.""" @@ -129,6 +134,14 @@ def __hash__(self): return id(self) +@dataclasses.dataclass +class ParameterizedTypeExpression(CodegenNode): + """Reference to a parameterized type like list[int].""" + + base_expression: Any # Expression like BuiltinReference(Name("list")) + param_expressions: List[Any] # List of (positional) argument expressions + + @dataclasses.dataclass class ArgFactoryExpr(CodegenNode): """Represents a factory that should be interpreted as an argument factory. diff --git a/fiddle/_src/codegen/auto_config/import_manager_wrapper.py b/fiddle/_src/codegen/auto_config/import_manager_wrapper.py new file mode 100644 index 00000000..d8b4c377 --- /dev/null +++ b/fiddle/_src/codegen/auto_config/import_manager_wrapper.py @@ -0,0 +1,60 @@ +# coding=utf-8 +# Copyright 2022 The Fiddle-Config Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Small helper functions around the ImportManager. + +This is a bit of cruft and should eventually be cleaned up. + +Context: The import manager predates modern (auto_config) codegen, and is used +by legacy codegen and diff codegen. The latter is still pretty important and +needs to be supported. +""" + +import logging +import typing +from typing import Any + +from fiddle._src.codegen import import_manager as import_manager_lib +from fiddle._src.codegen.auto_config import code_ir + + +def _name_to_attribute_expression(name: str) -> code_ir.CodegenNode: + """Converts a fully-qualified name to a code_ir node. + + Args: + name: Output from the import manager. + + Returns: + Codegen node. + """ + if "." not in name: + logging.warning( + "Expected to find a module in %s, but found none. This might be because" + " your module is from __main__, so we'll still emit code, but you might" + " need to fix imports for this symbol.", + name, + ) + return code_ir.BaseNameReference(code_ir.Name(name)) + base, *parts = name.split(".") + value = code_ir.ModuleReference(code_ir.Name(base)) + for part in parts: + value = code_ir.AttributeExpression(value, part) + return typing.cast(code_ir.AttributeExpression, value) + + +def add( + value: Any, import_manager: import_manager_lib.ImportManager +) -> code_ir.CodegenNode: + return _name_to_attribute_expression(import_manager.add(value)) diff --git a/fiddle/_src/codegen/auto_config/make_symbolic_references.py b/fiddle/_src/codegen/auto_config/make_symbolic_references.py index 915c2853..d3170d45 100644 --- a/fiddle/_src/codegen/auto_config/make_symbolic_references.py +++ b/fiddle/_src/codegen/auto_config/make_symbolic_references.py @@ -18,13 +18,13 @@ import enum import functools import inspect -import typing from typing import Any, Callable from fiddle import arg_factory from fiddle import daglish from fiddle._src import config as config_lib from fiddle._src.codegen.auto_config import code_ir +from fiddle._src.codegen.auto_config import import_manager_wrapper def is_plain_symbol_or_enum_value(value: Any) -> bool: @@ -67,16 +67,6 @@ def noop_history_comments(unused_buildable): return code_ir.HistoryComments() -def _name_to_attribute_expression(name: str) -> code_ir.AttributeExpression: - if "." not in name: - raise ValueError(f"Could not parse symbol import {name}") - base, *parts = name.split(".") - value = code_ir.ModuleReference(code_ir.Name(base)) - for part in parts: - value = code_ir.AttributeExpression(value, part) - return typing.cast(code_ir.AttributeExpression, value) - - def replace_callables_and_configs_with_symbols( task: code_ir.CodegenTask, *, @@ -98,7 +88,7 @@ def replace_callables_and_configs_with_symbols( def _handle_partial( value: config_lib.Partial, state: daglish.State, - ir_for_symbol: code_ir.AttributeExpression, + ir_for_symbol: code_ir.CodegenNode, ): """Split-out helper method to handle Partial() nodes.""" arguments = config_lib.ordered_arguments(value) @@ -131,9 +121,7 @@ def _handle_partial( def _arg_factory_partial(): return code_ir.SymbolOrFixtureCall( - _name_to_attribute_expression( - task.import_manager.add(arg_factory.partial) - ), + import_manager_wrapper.add(arg_factory.partial, task.import_manager), positional_arg_expressions=[ir_for_symbol], arg_expressions=arg_factory_args, history_comments=format_history(value), @@ -146,9 +134,7 @@ def _arg_factory_partial(): # the auto_config fixture's as_buildable() method. If we got rid of the # functools.partial, then we couldn't configure any attributes. return code_ir.SymbolOrFixtureCall( - _name_to_attribute_expression( - task.import_manager.add(functools.partial) - ), + import_manager_wrapper.add(functools.partial, task.import_manager), positional_arg_expressions=[ir_for_symbol], arg_expressions=regular_args, history_comments=format_history(value), @@ -161,9 +147,7 @@ def _arg_factory_partial(): # which order, but we need to emit both decorators. Go with functools # on the outer level. return code_ir.SymbolOrFixtureCall( - _name_to_attribute_expression( - task.import_manager.add(functools.partial) - ), + import_manager_wrapper.add(functools.partial, task.import_manager), positional_arg_expressions=[_arg_factory_partial()], arg_expressions=regular_args, history_comments=format_history(value), @@ -171,8 +155,8 @@ def _arg_factory_partial(): def traverse(value, state: daglish.State): if isinstance(value, config_lib.Buildable): - ir_for_symbol = _name_to_attribute_expression( - task.import_manager.add(config_lib.get_callable(value)) + ir_for_symbol = import_manager_wrapper.add( + config_lib.get_callable(value), task.import_manager ) if isinstance(value, config_lib.Config): all_tags = value.__argument_tags__ @@ -214,7 +198,7 @@ def traverse(value, state: daglish.State): else: raise TypeError(f"Unsupported Buildable {type(value)}") elif is_plain_symbol_or_enum_value(value): - return _name_to_attribute_expression(task.import_manager.add(value)) + return import_manager_wrapper.add(value, task.import_manager) else: return state.map_children(value) diff --git a/fiddle/_src/codegen/newcg_symbolic_references.py b/fiddle/_src/codegen/newcg_symbolic_references.py index b75a0e25..031f40b8 100644 --- a/fiddle/_src/codegen/newcg_symbolic_references.py +++ b/fiddle/_src/codegen/newcg_symbolic_references.py @@ -18,12 +18,12 @@ N.B. Please see codegen/auto_config for the auto_config version!! """ -import typing from typing import Callable from fiddle import daglish from fiddle._src import config as config_lib from fiddle._src.codegen.auto_config import code_ir +from fiddle._src.codegen.auto_config import import_manager_wrapper from fiddle._src.codegen.auto_config import make_symbolic_references as ac_make_symbolic_references is_plain_symbol_or_enum_value = ( @@ -56,16 +56,6 @@ def import_symbols(task: code_ir.CodegenTask) -> None: task.import_manager.add(value) -def _name_to_attribute_expression(name: str) -> code_ir.AttributeExpression: - if "." not in name: - raise ValueError(f"Could not parse symbol import {name}") - base, *parts = name.split(".") - value = code_ir.ModuleReference(code_ir.Name(base)) - for part in parts: - value = code_ir.AttributeExpression(value, part) - return typing.cast(code_ir.AttributeExpression, value) - - def replace_callables_and_configs_with_symbols( task: code_ir.CodegenTask, *, @@ -84,11 +74,11 @@ def replace_callables_and_configs_with_symbols( def traverse(value, state: daglish.State): if isinstance(value, config_lib.Buildable): - ir_for_buildable_type = _name_to_attribute_expression( - task.import_manager.add(type(value)) + ir_for_buildable_type = import_manager_wrapper.add( + type(value), task.import_manager ) - ir_for_symbol = _name_to_attribute_expression( - task.import_manager.add(config_lib.get_callable(value)) + ir_for_symbol = import_manager_wrapper.add( + config_lib.get_callable(value), task.import_manager ) all_tags = value.__argument_tags__ value = state.map_children(value) @@ -113,7 +103,7 @@ def traverse(value, state: daglish.State): history_comments=format_history(value), ) elif is_plain_symbol_or_enum_value(value): - return _name_to_attribute_expression(task.import_manager.add(value)) + return import_manager_wrapper.add(value, task.import_manager) else: return state.map_children(value)