diff --git a/fiddle/_src/absl_flags/utils.py b/fiddle/_src/absl_flags/utils.py index 1e75e4a6..a6d5a74d 100644 --- a/fiddle/_src/absl_flags/utils.py +++ b/fiddle/_src/absl_flags/utils.py @@ -46,6 +46,10 @@ def error_prefix(self, name: str) -> str: return f'Could not load fiddler {name!r}' +def _is_dotted_name(name: str) -> bool: + return len(name.split('.')) >= 2 + + def _import_dotted_name( name: str, mode: ImportDottedNameDebugContext, @@ -69,16 +73,17 @@ def _import_dotted_name( AttributeError: If the imported module does not contain a value with the indicated name. """ - name_pieces = name.split('.') - if module is not None: - name_pieces = [module.__name__] + name_pieces - if len(name_pieces) < 2: + if not _is_dotted_name(name): raise ValueError( f'{mode.error_prefix(name)}: Expected a dotted name including the ' 'module name.' ) + name_pieces = name.split('.') + if module is not None: + name_pieces = [module.__name__] + name_pieces + # We don't know where the module ends and the name begins; so we need to # try different split points. Longer module names take precedence. for i in range(len(name_pieces) - 1, 0, -1): @@ -245,7 +250,7 @@ def resolve_function_reference( """ if hasattr(module, function_name): return getattr(module, function_name) - elif allow_imports: + elif allow_imports and _is_dotted_name(function_name): # Try a relative import first. if module is not None: try: @@ -271,8 +276,8 @@ def resolve_function_reference( else: available_names = module_reflection.find_base_config_like_things(module) raise ValueError( - f'{failure_msg_prefix} {function_name!r}; ' - f'available names: {", ".join(available_names)}.' + f'{failure_msg_prefix} {function_name!r}: Could not resolve reference ' + f'to named function, available names: {", ".join(available_names)}.' ) diff --git a/fiddle/_src/absl_flags/utils_test.py b/fiddle/_src/absl_flags/utils_test.py index a0866f33..a0b48b0d 100644 --- a/fiddle/_src/absl_flags/utils_test.py +++ b/fiddle/_src/absl_flags/utils_test.py @@ -62,7 +62,7 @@ def test_module_relative_resolution_falls_back_to_absolute(self): def test_raises_without_resolvable_name(self): with self.assertRaisesRegex( - ValueError, "Could not init a buildable from 'config_bar'" + ValueError, "'config_bar': Could not resolve reference" ): utils.resolve_function_reference( function_name='config_bar',