diff --git a/fiddle/_src/absl_flags/flags.py b/fiddle/_src/absl_flags/flags.py index ac0f6fe6..da02b545 100644 --- a/fiddle/_src/absl_flags/flags.py +++ b/fiddle/_src/absl_flags/flags.py @@ -15,9 +15,10 @@ """API to use command line flags with Fiddle Buildables.""" +import dataclasses import re import types -from typing import Any, Optional, TypeVar +from typing import Any, List, Optional, Text, TypeVar, Union from absl import flags from etils import epath @@ -81,60 +82,32 @@ def serialize(self, value: config.Buildable) -> str: return f"config_str:{serialized}" -class FiddleFlag(flags.MultiFlag): - """ABSL flag class for a Fiddle config flag. +@dataclasses.dataclass +class _LazyFlagValue: + """Represents a lazily evaluated Fiddle flag value. - This class is used to parse command line flags to construct a Fiddle `Config` - object with certain transformations applied as specified in the command line - flags. + This is separate from FiddleFlag because it is used by both defaults and + provided flags. - Most users should rely on the `DEFINE_fiddle_config()` API below. Using this - class directly provides flexibility to users to parse Fiddle flags themselves - programmatically. Also see the documentation for `DEFINE_fiddle_config()` - below. + Lazy flag values are useful because they allow other parts of the system to + be set up, so things like logging can be configured before a configuration is + loaded. 
+ """ - Example usage where this flag is parsed from existing flag: - ``` - from fiddle import absl_flags as fdl_flags + flag_name: str + remaining_directives: List[str] = dataclasses.field(default_factory=list) + first_command: Optional[str] = None + initial_config_expression: Optional[str] = None - _MY_CONFIG = fdl_flags.DEFINE_multi_string( - "my_config", - "Name of the fiddle config" - ) + default_module: Optional[types.ModuleType] = None + allow_imports: bool = True + pyref_policy: Optional[serialization.PyrefPolicy] = None - fiddle_flag = fdl_flags.FiddleFlag( - name="config", - default_module=my_module, - default=None, - parser=flags.ArgumentParser(), - serializer=None, - help_string="My fiddle flag", - ) - fiddle_flag.parse(_MY_CONFIG.value) - config = fiddle_flag.value - ``` - """ - - def __init__( - self, - *args, - default_module: Optional[types.ModuleType] = None, - allow_imports: bool = True, - pyref_policy: Optional[serialization.PyrefPolicy] = None, - **kwargs, - ): - self.allow_imports = allow_imports - self.default_module = default_module - self._pyref_policy = pyref_policy - self.first_command = None - self._initial_config_expression = None - # A `directive` is a str of the form e.g. 'config:...'. - # Due to the lazy evaluation of `value`, this list is needed to keep - # track of the remaining `directives`. - self._remaining_directives = [] - super().__init__(*args, **kwargs) + # Only set internally, please use get_value() / set_value(). 
+ _value: Optional[Any] = None def _initial_config(self, expression: str): + """Generates the initial config from a config: directive.""" call_expr = utils.CallExpression.parse(expression) base_name = call_expr.func_name base_fn = utils.resolve_function_reference( @@ -150,6 +123,7 @@ def _initial_config(self, expression: str): return base_fn(*call_expr.args, **call_expr.kwargs) def _apply_fiddler(self, cfg: config.Buildable, expression: str): + """Modifies the config from the given CLI flag.""" call_expr = utils.CallExpression.parse(expression) base_name = call_expr.func_name fiddler = utils.resolve_function_reference( @@ -175,57 +149,40 @@ def _apply_fiddler(self, cfg: config.Buildable, expression: str): # `fdl.Buildable` object. return new_cfg if new_cfg is not None else cfg - def parse(self, arguments): - new_parsed = self._parse(arguments) - self._remaining_directives.extend(new_parsed) - self.present += len(new_parsed) - - def unparse(self) -> None: - self.value = self.default - self.using_default_value = True - # Reset it so that all `directives` not being processed yet will be - # discarded. - self._remaining_directives = [] - self.present = 0 - def _parse_config(self, command: str, expression: str) -> None: - if self._initial_config_expression: + """Sets the initial config from the given CLI flag/directive.""" + if self.initial_config_expression: raise ValueError( "Only one base configuration is permitted. Received" f"{command}:{expression} after " - f"{self.first_command}:{self._initial_config_expression} was" + f"{self.first_command}:{self.initial_config_expression} was" " already provided." 
) else: - self._initial_config_expression = expression + self.initial_config_expression = expression if command == "config": - self.value = self._initial_config(expression) + self._value = self._initial_config(expression) elif command == "config_file": with epath.Path(expression).open() as f: - self.value = serialization.load_json( - f.read(), pyref_policy=self._pyref_policy + self._value = serialization.load_json( + f.read(), pyref_policy=self.pyref_policy ) elif command == "config_str": serializer = utils.ZlibJSONSerializer() - self.value = serializer.deserialize( - expression, pyref_policy=self._pyref_policy + self._value = serializer.deserialize( + expression, pyref_policy=self.pyref_policy ) - def _serialize(self, value) -> str: - # Skip MultiFlag serialization as we don't truly have a multi-flag. - # This will invoke Flag._serialize - return super(flags.MultiFlag, self)._serialize(value) - - @property - def value(self): - while self._remaining_directives: + def get_value(self): + """Gets the current value (parsing any directives).""" + while self.remaining_directives: # Pop already processed `directive` so that _value won't be updated twice # by the same argument. - item = self._remaining_directives.pop(0) + item = self.remaining_directives.pop(0) match = _COMMAND_RE.fullmatch(item) if not match: raise ValueError( - f"All flag values to {self.name} must begin with 'config:', " + f"All flag values to {self.flag_name} must begin with 'config:', " "'config_file:', 'config_str:', 'set:', or 'fiddler:'." ) command, expression = match.groups() @@ -235,7 +192,9 @@ def value(self): raise ValueError( "First flag command must specify the input config via either " "config or config_file or config_str commands. " - f"Received command: {command} instead." + f"Received command: {command} instead. If you have a default " + "value set, you must re-provide that on the CLI before setting " + "values or running fiddlers." 
class FiddleFlag(flags.MultiFlag):
  """ABSL flag class for a Fiddle config flag.

  This class is used to parse command line flags to construct a Fiddle `Config`
  object with certain transformations applied as specified in the command line
  flags.

  Most users should rely on the `DEFINE_fiddle_config()` API below. Using this
  class directly provides flexibility to users to parse Fiddle flags themselves
  programmatically. Also see the documentation for `DEFINE_fiddle_config()`
  below.

  Example usage where this flag is parsed from existing flag:
  ```
  from fiddle import absl_flags as fdl_flags

  _MY_CONFIG = fdl_flags.DEFINE_multi_string(
      "my_config",
      "Name of the fiddle config"
  )

  fiddle_flag = fdl_flags.FiddleFlag(
      name="config",
      default_module=my_module,
      default=None,
      parser=flags.ArgumentParser(),
      serializer=None,
      help_string="My fiddle flag",
  )
  fiddle_flag.parse(_MY_CONFIG.value)
  config = fiddle_flag.value
  ```
  """

  def __init__(
      self,
      *args,
      name: Text,
      default_module: Optional[types.ModuleType] = None,
      allow_imports: bool = True,
      pyref_policy: Optional[serialization.PyrefPolicy] = None,
      **kwargs,
  ):
    """Initializes the flag.

    Args:
      *args: Positional arguments forwarded to `flags.MultiFlag`.
      name: Name of the flag; also used in error messages of the lazy values.
      default_module: Module handed to the lazy values for resolving
        references.
      allow_imports: Forwarded to the lazy values.
      pyref_policy: Policy forwarded to the lazy values for deserialization.
      **kwargs: Keyword arguments forwarded to `flags.MultiFlag`.
    """
    self.allow_imports = allow_imports
    self.default_module = default_module
    self._pyref_policy = pyref_policy
    # The `value` / `default` properties below delegate to these holders, so
    # they are created before the base constructor runs (the base class is
    # expected to assign both attributes during construction).
    self._lazy_default = self._new_lazy_value(name)
    self._lazy_value = self._new_lazy_value(name)
    super().__init__(*args, name=name, **kwargs)

  def _new_lazy_value(self, flag_name: Text) -> "_LazyFlagValue":
    """Builds an empty lazy value configured like this flag."""
    return _LazyFlagValue(
        flag_name=flag_name,
        default_module=self.default_module,
        allow_imports=self.allow_imports,
        pyref_policy=self._pyref_policy,
    )

  def parse(self, arguments):
    """Queues the given directives; they are evaluated on `value` access."""
    new_parsed = self._parse(arguments)
    self._lazy_value.remaining_directives.extend(new_parsed)
    self.present += len(new_parsed)

  def _parse_from_default(
      self, value: Union[Text, List[Any]]
  ) -> Optional[List[Any]]:
    """Wraps the default directives in a lazy value instead of evaluating.

    The returned object is a `_LazyFlagValue`, not a list; the `default`
    setter recognizes it and installs it as the lazy default.
    """
    lazy_default_value = self._new_lazy_value(self.name)
    directives = self._parse(value)
    assert isinstance(directives, list)
    lazy_default_value.remaining_directives.extend(directives)
    return lazy_default_value  # pytype: disable=bad-return-type

  def unparse(self) -> None:
    """Resets the flag to its default and forgets all parsing state."""
    self.value = self.default
    self.using_default_value = True
    lazy_value = self._lazy_value
    # Reset it so that all `directives` not being processed yet will be
    # discarded.
    lazy_value.remaining_directives = []
    # Also forget which base config was already consumed. Otherwise a later
    # parse() carrying a new `config:` directive would be rejected with
    # "Only one base configuration is permitted".
    lazy_value.first_command = None
    lazy_value.initial_config_expression = None
    self.present = 0

  def _serialize(self, value) -> str:
    # Skip MultiFlag serialization as we don't truly have a multi-flag.
    # This will invoke Flag._serialize
    return super(flags.MultiFlag, self)._serialize(value)

  @property
  def value(self):
    """The flag value; evaluates any pending directives on access."""
    return self._lazy_value.get_value()

  @value.setter
  def value(self, value):
    self._lazy_value.set_value(value)

  @property
  def default(self):
    """The default value; evaluates any pending default directives."""
    return self._lazy_default.get_value()

  @default.setter
  def default(self, value):
    if isinstance(value, _LazyFlagValue):
      # Note: This is only for _set_default(). We might choose to override that
      # instead of just _parse_from_default(), in which case this branch can be
      # removed.
      self._lazy_default = value
    else:
      self._lazy_default.set_value(value)
- default: default value of the flag. - help_string: help string describing what the flag does. - default_module: the python module where this flag is defined. - pyref_policy: a policy for importing references to Python objects. - flag_values: the ``FlagValues`` instance with which the flag will be + name: Name of the command line flag. + default_flag_str: Default value of the flag. + help_string: Help string describing what the flag does. + default_module: The python module where this flag is defined. + pyref_policy: A policy for importing references to Python objects. + flag_values: The ``FlagValues`` instance with which the flag will be registered. This should almost never need to be overridden. required: bool, is this a required flag. This must be used as a keyword argument. @@ -334,7 +408,7 @@ def main(argv) -> None: FiddleFlag( name=name, default_module=default_module, - default=default, + default=default_flag_str, pyref_policy=pyref_policy, parser=flags.ArgumentParser(), serializer=FiddleFlagSerializer(pyref_policy=pyref_policy),