diff --git a/fiddle/_src/building.py b/fiddle/_src/building.py index 61fc29ce..f62ac015 100644 --- a/fiddle/_src/building.py +++ b/fiddle/_src/building.py @@ -25,6 +25,7 @@ from fiddle._src import daglish from fiddle._src import partial from fiddle._src import reraised_exception +from fiddle._src.mutate_buildable import update_callable T = TypeVar('T') @@ -44,7 +45,8 @@ def _in_build(): """A context manager to ensure fdl.build is not called recursively.""" if _state.in_build: raise ValueError( - 'It is forbidden to call `fdl.build` inside another `fdl.build` call.') + 'It is forbidden to call `fdl.build` inside another `fdl.build` call.' + ) _state.in_build = True try: yield @@ -192,4 +194,11 @@ def _build(value: Any, state: daglish.State) -> Any: type(buildable), ) + # Poison the buildable to prevent future builds. + def _poison(*args, **kwargs): # pylint: disable=unused-argument + raise ValueError('Only call fld.build() once.') + + if isinstance(buildable, config_lib.Buildable): + update_callable(buildable, _poison) + return result diff --git a/fiddle/_src/building_test.py b/fiddle/_src/building_test.py index 15e8cfb1..0ae51a12 100644 --- a/fiddle/_src/building_test.py +++ b/fiddle/_src/building_test.py @@ -14,12 +14,12 @@ # limitations under the License. """Tests for history.""" + import dataclasses import unittest import warnings from absl import logging from absl.testing import absltest - from fiddle._src import building from fiddle._src import config @@ -82,6 +82,12 @@ def test_traversable_w_buildable(self): building.build(value) self.assertEmpty(log_output) + def test_build_poisoning(self): + foo = config.Config(Foo, 1, 2) + with self.assertRaises(ValueError): + building.build(foo) + building.build(foo) + if __name__ == '__main__': unittest.main()