From fa90463596fea1c311e350b07014718b42a1edfd Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 17 Mar 2022 09:30:07 +0100 Subject: [PATCH 01/10] Always enable flat net construction Fix #992 --- returnn/tf/network.py | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/returnn/tf/network.py b/returnn/tf/network.py index 68f578ab55..5a53b7f01f 100644 --- a/returnn/tf/network.py +++ b/returnn/tf/network.py @@ -786,17 +786,6 @@ def get_layer(src_name): self.used_data_keys.update(extra_net.used_data_keys) return created_layers - def _flat_construction_enabled(self): - """ - :return: whether to use flat construction algorithm in :func:`construct_layer`. - Use this if you get stack overflow errors, such as: - ``Fatal Python error: Cannot recover from stack overflow`` - or - ``RuntimeError: maximum recursion depth exceeded``. - :rtype: bool - """ - return self.get_config().bool("flat_net_construction", False) - def construct_layer(self, net_dict, name, get_layer=None, add_layer=None, check_existing=True): """ This triggers the construction of the layer `name` if it is not constructed yet. 
@@ -918,14 +907,13 @@ def construct_layer(self, net_dict, name, get_layer=None, add_layer=None, check_ layer_name=full_name, network=self) return sub_layer - if self._flat_construction_enabled(): - delayed_exc = _DelayedConstructionException( - network=self, layer_name=name, - other_kwargs=dict(net_dict=net_dict, get_layer=get_layer, add_layer=add_layer, check_existing=check_existing)) - if not self._construction_stack.in_flat_construct_count: - return self._construction_stack.flat_construct(delayed_exc) - if self._construction_stack.layers: - raise delayed_exc + delayed_exc = _DelayedConstructionException( + network=self, layer_name=name, + other_kwargs=dict(net_dict=net_dict, get_layer=get_layer, add_layer=add_layer, check_existing=check_existing)) + if not self._construction_stack.in_flat_construct_count: + return self._construction_stack.flat_construct(delayed_exc) + if self._construction_stack.layers: + raise delayed_exc layer_desc = layer_desc.copy() layer_desc.pop("class") From 19aa9b3d110972b4fda74bfd27636e06b2fe9f19 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 17 Mar 2022 09:31:08 +0100 Subject: [PATCH 02/10] test_flat_net_construction, no flat_net_construction option --- tests/test_TFNetworkLayer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_TFNetworkLayer.py b/tests/test_TFNetworkLayer.py index dc724adaa2..b02f992448 100644 --- a/tests/test_TFNetworkLayer.py +++ b/tests/test_TFNetworkLayer.py @@ -5640,7 +5640,6 @@ def test_flat_net_construction(): "data": (n_in, 2), "classes": (n_out, 1), }, - "flat_net_construction": True, "debug_print_layer_output_template": True, }) print("Creating network...") From 12102ba3d2349136078182965a26cf2a697f8f5c Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 17 Mar 2022 10:31:48 +0100 Subject: [PATCH 03/10] _DelayedConstructionException, nicer repr --- returnn/tf/network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/returnn/tf/network.py b/returnn/tf/network.py index 
5a53b7f01f..5e6cdfe843 100644 --- a/returnn/tf/network.py +++ b/returnn/tf/network.py @@ -3612,7 +3612,7 @@ def __init__(self, network, layer_name, other_kwargs): self.other_kwargs = other_kwargs def __repr__(self): - return "%s(layer_name=%r)" % (self.__class__.__name__, self.layer_name) + return "<%s %r/%r>" % (self.__class__.__name__, self.network.name, self.layer_name) def delayed_construction(self): """ From e3b567fd994b6d99f404149d2d9d79cc52b6aa56 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 17 Mar 2022 10:32:25 +0100 Subject: [PATCH 04/10] flat construction, small fix --- returnn/tf/network.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/returnn/tf/network.py b/returnn/tf/network.py index 5e6cdfe843..e795c08be6 100644 --- a/returnn/tf/network.py +++ b/returnn/tf/network.py @@ -818,6 +818,9 @@ def construct_layer(self, net_dict, name, get_layer=None, add_layer=None, check_ return self.get_layer(name) except (LayerNotFound, DataNotFound): pass # ok, we will try to construct it then + delayed_exc = _DelayedConstructionException( + network=self, layer_name=name, # make sure that we have all the original args + other_kwargs=dict(net_dict=net_dict, get_layer=get_layer, add_layer=add_layer, check_existing=check_existing)) if not get_layer: get_layer = GetLayer(network=self, add_layer_func=add_layer) full_name = name @@ -907,12 +910,13 @@ def construct_layer(self, net_dict, name, get_layer=None, add_layer=None, check_ layer_name=full_name, network=self) return sub_layer - delayed_exc = _DelayedConstructionException( - network=self, layer_name=name, - other_kwargs=dict(net_dict=net_dict, get_layer=get_layer, add_layer=add_layer, check_existing=check_existing)) if not self._construction_stack.in_flat_construct_count: return self._construction_stack.flat_construct(delayed_exc) if self._construction_stack.layers: + # Note: We don't want to raise this earlier here in this function + # because certain exceptions such as 
LayerNotFound should directly be raised + # because some other code tests for this + # (e.g. checking the loss checking for layer "classes" and then layer "data:classes"). raise delayed_exc layer_desc = layer_desc.copy() From e372d3ba3af82f10058b1f4e9cc72dae40c6ea2a Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 17 Mar 2022 11:35:05 +0100 Subject: [PATCH 05/10] flat construction, refactor, cleanup --- returnn/tf/network.py | 78 +++++++++++++++++++++++++++++++------------ 1 file changed, 56 insertions(+), 22 deletions(-) diff --git a/returnn/tf/network.py b/returnn/tf/network.py index e795c08be6..c4b2de144f 100644 --- a/returnn/tf/network.py +++ b/returnn/tf/network.py @@ -323,9 +323,16 @@ class _NetworkConstructionStack: Used to keep the recursive construction state of :function:`TFNetwork.construct_layer`. """ + # This assumes that we do single-threaded net construction. + # For multi-threading (if this would ever be realistic for net construction), + # we would need this to be a thread local. + # We still have a stack here for flat_construction() because we need to be nested + # for things like CondLayer or RecLayer. + _flat_construction_stack = [] # type: typing.List[_NetworkConstructionStack] + def __init__(self): self.layers = [] # type: typing.List[str] - self.in_flat_construct_count = 0 + self.flat_construct_stack = [] # type: typing.List[typing.Tuple[TFNetwork, str, typing.Dict[str, typing.Any]]] def append(self, layer_name): """ @@ -340,23 +347,59 @@ def remove(self, layer_name): """ self.layers.remove(layer_name) - def flat_construct(self, initial): + def is_active_flat_construction(self): + """ + :return: whether this is called inside self.flat_construct(...) 
+ :rtype: bool + """ + cls = self.__class__ + if not cls._flat_construction_stack: + return False + return cls._flat_construction_stack[-1] is self + + def should_continue_construction(self): """ - :param _DelayedConstructionException initial: + We assume that we are inside self.flat_construct(), and currently doing a construct_layer(). + + :return: whether construct_layer() should continue. otherwise, it would throw _DelayedConstructionException. + :rtype: bool + """ + assert self.is_active_flat_construction() + return not self.layers + + def flat_construct(self, initial_exc): """ - self.in_flat_construct_count += 1 - queue = [initial] # type: typing.List[_DelayedConstructionException] + :param _DelayedConstructionException initial_exc: + :rtype: LayerBase + """ + cls = self.__class__ + assert not self.flat_construct_stack + stack = self.flat_construct_stack + initial = (initial_exc.network, initial_exc.layer_name, initial_exc.other_kwargs) + stack.append(initial) + cls._flat_construction_stack.append(self) try: - while queue: + while stack: try: - res = queue[-1].delayed_construction() - if queue[-1] is initial: + top = stack[-1] + network, layer_name, other_kwargs = top + res = network.construct_layer(name=layer_name, **other_kwargs) + stack.pop(-1) + if top is initial: + assert not stack return res - queue.pop(-1) except _DelayedConstructionException as delayed_exc: - queue.append(delayed_exc) + stack.append((delayed_exc.network, delayed_exc.layer_name, delayed_exc.other_kwargs)) + except Exception as exc: + attr = "_RETURNN_layer_construction_stack" + if not hasattr(exc, attr): + setattr(exc, attr, []) + getattr(exc, attr).extend(stack) + raise finally: - self.in_flat_construct_count -= 1 + top_stack = cls._flat_construction_stack.pop(-1) + assert top_stack is self + stack.clear() assert False, "we should not get here" @@ -910,9 +953,9 @@ def construct_layer(self, net_dict, name, get_layer=None, add_layer=None, check_ layer_name=full_name, network=self) return 
sub_layer - if not self._construction_stack.in_flat_construct_count: + if not self._construction_stack.is_active_flat_construction(): return self._construction_stack.flat_construct(delayed_exc) - if self._construction_stack.layers: + if not self._construction_stack.should_continue_construction(): # Note: We don't want to raise this earlier here in this function # because certain exceptions such as LayerNotFound should directly be raised # because some other code tests for this @@ -3618,15 +3661,6 @@ def __init__(self, network, layer_name, other_kwargs): def __repr__(self): return "<%s %r/%r>" % (self.__class__.__name__, self.network.name, self.layer_name) - def delayed_construction(self): - """ - Call :func:`TFNetwork.construct_layer` again now. - - :rtype: LayerBase - """ - print("Delayed flat layer construction:", self.layer_name, file=log.v5) - return self.network.construct_layer(name=self.layer_name, **self.other_kwargs) - class LayerNotFound(NetworkLayerException): """ From 5c4b1bcb5ab9fdd09d61eefb554b69bcaf8e0fda Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 17 Mar 2022 11:40:07 +0100 Subject: [PATCH 06/10] flat construction, small fix --- returnn/tf/network.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/returnn/tf/network.py b/returnn/tf/network.py index c4b2de144f..c74dd01cfc 100644 --- a/returnn/tf/network.py +++ b/returnn/tf/network.py @@ -330,7 +330,11 @@ class _NetworkConstructionStack: # for things like CondLayer or RecLayer. 
_flat_construction_stack = [] # type: typing.List[_NetworkConstructionStack] - def __init__(self): + def __init__(self, network): + """ + :param TFNetwork network: + """ + self.network = network self.layers = [] # type: typing.List[str] self.flat_construct_stack = [] # type: typing.List[typing.Tuple[TFNetwork, str, typing.Dict[str, typing.Any]]] @@ -389,6 +393,8 @@ def flat_construct(self, initial_exc): assert not stack return res except _DelayedConstructionException as delayed_exc: + if delayed_exc.network is not self.network: + raise # some parent flat_construct() should handle this stack.append((delayed_exc.network, delayed_exc.layer_name, delayed_exc.other_kwargs)) except Exception as exc: attr = "_RETURNN_layer_construction_stack" @@ -501,7 +507,7 @@ def __init__(self, config=None, extern_data=None, rnd_seed=None, self.extra_nets = {} # type: typing.Dict[str,TFNetwork] self.subnets = {} # type: typing.Dict[str,Subnetwork] self._selected_train_layers = None - self._construction_stack = _NetworkConstructionStack() + self._construction_stack = _NetworkConstructionStack(self) self.layers_desc = {} # type: typing.Dict[str,typing.Dict[str]] self.layers = {} # type: typing.Dict[str,LayerBase] self.losses_dict = {} # type: typing.Dict[str,LossHolder] From fbde04dfb9a5c9809a2499fc87c3db1bca2b0483 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 17 Mar 2022 11:42:47 +0100 Subject: [PATCH 07/10] flat construction, small fix --- returnn/tf/network.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/returnn/tf/network.py b/returnn/tf/network.py index c74dd01cfc..50bfb28591 100644 --- a/returnn/tf/network.py +++ b/returnn/tf/network.py @@ -357,9 +357,7 @@ def is_active_flat_construction(self): :rtype: bool """ cls = self.__class__ - if not cls._flat_construction_stack: - return False - return cls._flat_construction_stack[-1] is self + return self in cls._flat_construction_stack def should_continue_construction(self): """ From 
b6f6c3225a41733c1102fc206b724b8e559835f2 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 17 Mar 2022 11:54:20 +0100 Subject: [PATCH 08/10] flat construction, cleanup --- returnn/tf/network.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/returnn/tf/network.py b/returnn/tf/network.py index 50bfb28591..19366b85ec 100644 --- a/returnn/tf/network.py +++ b/returnn/tf/network.py @@ -336,7 +336,11 @@ def __init__(self, network): """ self.network = network self.layers = [] # type: typing.List[str] - self.flat_construct_stack = [] # type: typing.List[typing.Tuple[TFNetwork, str, typing.Dict[str, typing.Any]]] + self.flat_construct_stack = [] # type: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any]]] + + def __repr__(self): + return "<%s %r (cur stack size: %i)>" % ( + self.__class__.__name__, self.network.name, len(self.flat_construct_stack)) def append(self, layer_name): """ @@ -375,17 +379,18 @@ def flat_construct(self, initial_exc): :rtype: LayerBase """ cls = self.__class__ + assert initial_exc.network is self.network assert not self.flat_construct_stack stack = self.flat_construct_stack - initial = (initial_exc.network, initial_exc.layer_name, initial_exc.other_kwargs) + initial = (initial_exc.layer_name, initial_exc.other_kwargs) stack.append(initial) cls._flat_construction_stack.append(self) try: while stack: try: top = stack[-1] - network, layer_name, other_kwargs = top - res = network.construct_layer(name=layer_name, **other_kwargs) + layer_name, other_kwargs = top + res = self.network.construct_layer(name=layer_name, **other_kwargs) stack.pop(-1) if top is initial: assert not stack @@ -393,12 +398,12 @@ def flat_construct(self, initial_exc): except _DelayedConstructionException as delayed_exc: if delayed_exc.network is not self.network: raise # some parent flat_construct() should handle this - stack.append((delayed_exc.network, delayed_exc.layer_name, delayed_exc.other_kwargs)) + 
stack.append((delayed_exc.layer_name, delayed_exc.other_kwargs)) except Exception as exc: attr = "_RETURNN_layer_construction_stack" if not hasattr(exc, attr): setattr(exc, attr, []) - getattr(exc, attr).extend(stack) + getattr(exc, attr).extend([(self.network, layer_name) for (layer_name, _) in stack]) raise finally: top_stack = cls._flat_construction_stack.pop(-1) From 0318dcb5ac4b51e22996a63ae1ce89f47ccc5853 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 17 Mar 2022 12:52:06 +0100 Subject: [PATCH 09/10] flat construction, refactor, cleanup, fix --- returnn/tf/layers/rec.py | 12 ++++++-- returnn/tf/network.py | 59 ++++++++++++++++++++++------------------ 2 files changed, 42 insertions(+), 29 deletions(-) diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py index 353d5695f7..06ce3366cd 100644 --- a/returnn/tf/layers/rec.py +++ b/returnn/tf/layers/rec.py @@ -1222,12 +1222,15 @@ def _construct_template(self, parent_get_layer): from collections import OrderedDict from returnn.util.basic import StringIO, BehaviorVersion from returnn.tf.network import NetworkConstructionDependencyLoopException, DataNotFound + from returnn.tf.network import _DelayedConstructionException # The stack trace is not so interesting for these exceptions. skip_stack_trace_exception_types = ( NetworkConstructionDependencyLoopException,) # These Exceptions always indicate incorrect construction, so fail directly instead of collecting them - fail_directly_exception_types = (DataNotFound, LayerNotFound, BehaviorVersion.RequirementNotSatisfied) + fail_directly_exception_types = ( + DataNotFound, LayerNotFound, BehaviorVersion.RequirementNotSatisfied, + _DelayedConstructionException) # noinspection PyShadowingNames def _parent_get_layer(layer_name): @@ -1644,8 +1647,11 @@ def __call__(lself, name, is_prev_time_frame=False): # And keep the remaining ones for potential later reports. 
self._template_construction_exceptions = [s.text for s in ConstructCtx.collected_exceptions.values()] - except Exception: - print("%r: exception constructing template network (for deps and data shapes)" % self) + except _DelayedConstructionException: + raise + + except Exception as exc: + print("%r: %s while constructing template network (for deps and data shapes)" % (self, type(exc).__name__)) from pprint import pprint print("Most recent construction stack:") if ConstructCtx.most_recent: diff --git a/returnn/tf/network.py b/returnn/tf/network.py index 19366b85ec..be167f1b39 100644 --- a/returnn/tf/network.py +++ b/returnn/tf/network.py @@ -355,25 +355,34 @@ def remove(self, layer_name): """ self.layers.remove(layer_name) - def is_active_flat_construction(self): + def on_construct_layer_call(self, exc): """ - :return: whether this is called inside self.flat_construct(...) - :rtype: bool + This covers the whole flat construction logic. + If this returns None, it means that the normal construction should follow. + If a layer is returned, this can directly be returned. + Otherwise, this will not return but throw the exception which is handled outside. + + :param _DelayedConstructionException exc: + :rtype: LayerBase|None """ cls = self.__class__ - return self in cls._flat_construction_stack + if self not in cls._flat_construction_stack: + return self._flat_construct(exc) - def should_continue_construction(self): - """ - We assume that we are inside self.flat_construct(), and currently doing a construct_layer(). + assert exc.network is self.network + if self.flat_construct_stack: + if exc.layer_name == self.flat_construct_stack[-1][0]: + return None # continue with construction - :return: whether construct_layer() should continue. otherwise, it would throw _DelayedConstructionException. 
- :rtype: bool - """ - assert self.is_active_flat_construction() - return not self.layers + existing_in_stack = [entry for entry in self.flat_construct_stack if exc.layer_name == entry[0]] + if existing_in_stack: + raise NetworkConstructionDependencyLoopException( + layer_name=exc.layer_name, constructing_layers=[entry[0] for entry in self.flat_construct_stack], + net_dict=existing_in_stack[0][1]["net_dict"], network=self.network) - def flat_construct(self, initial_exc): + raise exc + + def _flat_construct(self, initial_exc): """ :param _DelayedConstructionException initial_exc: :rtype: LayerBase @@ -962,15 +971,6 @@ def construct_layer(self, net_dict, name, get_layer=None, add_layer=None, check_ layer_name=full_name, network=self) return sub_layer - if not self._construction_stack.is_active_flat_construction(): - return self._construction_stack.flat_construct(delayed_exc) - if not self._construction_stack.should_continue_construction(): - # Note: We don't want to raise this earlier here in this function - # because certain exceptions such as LayerNotFound should directly be raised - # because some other code tests for this - # (e.g. checking the loss checking for layer "classes" and then layer "data:classes"). - raise delayed_exc - layer_desc = layer_desc.copy() layer_desc.pop("class") # Note about name: @@ -980,10 +980,14 @@ def construct_layer(self, net_dict, name, get_layer=None, add_layer=None, check_ layer_desc["_network"] = net layer_desc["_name"] = base_name name_with_prefix = ("%s:%s" % (extra_prefix, name)) if extra_prefix else name - if name_with_prefix in self._construction_stack.layers: - raise NetworkConstructionDependencyLoopException( - layer_name=name_with_prefix, constructing_layers=self._construction_stack.layers, - net_dict=net_dict, network=self) + + # Note: We don't want to raise this earlier here in this function + # because certain exceptions such as LayerNotFound should directly be raised + # because some other code tests for this + # (e.g. 
the loss code checking for layer "classes" and then for layer "data:classes"). + _constructed_layer = self._construction_stack.on_construct_layer_call(delayed_exc) + if _constructed_layer: + return _constructed_layer self._construction_stack.append(name_with_prefix) try: # This call would also resolve dependencies, and e.g. recursively then create them (via get_layer calls). @@ -3146,6 +3150,9 @@ def add_templated_layer(name, layer_class, **layer_desc): if layer.get("is_output_layer"): get_templated_layer(layer_name) + except _DelayedConstructionException: + raise + except Exception as exc: # Merge the exception message + further debug information all together into a single exception, # which we will raise. From b7e3e9b038f9255cc049fa4db004432c0d104d9c Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 17 Mar 2022 14:05:33 +0100 Subject: [PATCH 10/10] flat construction, allow nested --- returnn/tf/network.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/returnn/tf/network.py b/returnn/tf/network.py index be167f1b39..a41362732d 100644 --- a/returnn/tf/network.py +++ b/returnn/tf/network.py @@ -366,7 +366,7 @@ def on_construct_layer_call(self, exc): :rtype: LayerBase|None """ cls = self.__class__ - if self not in cls._flat_construction_stack: + if not cls._flat_construction_stack or cls._flat_construction_stack[-1] is not self: return self._flat_construct(exc) assert exc.network is self.network @@ -389,24 +389,25 @@ def _flat_construct(self, initial_exc): """ cls = self.__class__ assert initial_exc.network is self.network - assert not self.flat_construct_stack stack = self.flat_construct_stack initial = (initial_exc.layer_name, initial_exc.other_kwargs) stack.append(initial) + stack_init_idx = len(stack) - 1 cls._flat_construction_stack.append(self) try: while stack: try: - top = stack[-1] + stack_top_idx = len(stack) - 1 + top = stack[stack_top_idx] layer_name, other_kwargs = top res = 
self.network.construct_layer(name=layer_name, **other_kwargs) + assert stack_top_idx == len(stack) - 1 stack.pop(-1) if top is initial: - assert not stack return res except _DelayedConstructionException as delayed_exc: - if delayed_exc.network is not self.network: - raise # some parent flat_construct() should handle this + # See on_construct_layer_call(). + assert delayed_exc.network is self.network # otherwise that network would have started its own nested _flat_construct() stack.append((delayed_exc.layer_name, delayed_exc.other_kwargs)) except Exception as exc: attr = "_RETURNN_layer_construction_stack" @@ -417,7 +418,7 @@ def _flat_construct(self, initial_exc): finally: top_stack = cls._flat_construction_stack.pop(-1) assert top_stack is self - stack.clear() + del stack[stack_init_idx:] assert False, "we should not get here"