From 75f0abf698004af2aa2e296473171f6807674302 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Mon, 20 Sep 2021 18:33:15 +0200 Subject: [PATCH 1/4] DimensionTag.declare_same_as, fix dyn_size_ext for ctx in some cases This partly addresses the concerns and mostly fixes #672. --- returnn/tf/util/data.py | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 3d5b62109a..0228ece6a1 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -235,6 +235,21 @@ def get_for_batch_ctx(self, batch, ctx): same_base._same_for_batch_ctx[(dim_tag.batch, dim_tag.control_flow_ctx)] = dim_tag return dim_tag + def get_dyn_size_ext_for_batch_ctx(self, batch, ctx): + """ + :param BatchInfo|None batch: + :param ControlFlowContext|None ctx: + :rtype: Data|None + """ + if not batch and self.batch: + # Assume global batch. + batch = self.batch.get_global_base() + if not batch: + # This is usually not valid. However, this case can happen early at initialization. 
+ assert batch == self.batch and ctx == self.control_flow_ctx + return self.dyn_size_ext + return self.get_for_batch_ctx(batch, ctx).dyn_size_ext + @property def dyn_size(self): """ @@ -520,16 +535,18 @@ def declare_same_as(self, other): self_same_as.same_as = other_same_base self_same_as._same_as_tb = traceback.extract_stack() if self_same_as.dyn_size_ext is None: - self_same_as.dyn_size_ext = other_same_base.dyn_size_ext + self_same_as.dyn_size_ext = other_same_base.get_dyn_size_ext_for_batch_ctx( + self_same_as.batch, self_same_as.control_flow_ctx) elif other_same_base.dyn_size_ext is None: - other_same_base.dyn_size_ext = self_same_as.dyn_size_ext + other_same_base.dyn_size_ext = self_same_as.get_dyn_size_ext_for_batch_ctx( + other_same_base.batch, other_same_base.control_flow_ctx) if self.dyn_size_ext is None and self_same_as.dyn_size_ext: - self.dyn_size_ext = self_same_as.dyn_size_ext.copy_extend_with_beam(self.batch.beam if self.batch else None) + self.dyn_size_ext = self_same_as.get_dyn_size_ext_for_batch_ctx(self.batch, self.control_flow_ctx) self.same_as = other_same_base self._same_as_tb = traceback.extract_stack() if self.dyn_size is not None and other_same_base.dyn_size is not None: if self.dyn_size is not other_same_base.dyn_size: - if self.batch == other_same_base.batch: + if self.batch == other_same_base.batch and self.control_flow_ctx == other_same_base.control_flow_ctx: # Note: Instead of making this a warning, we could also enforce this at some point. # The user should be able to fix `extern_data` in the config such that this is correct in the first place. # Also, in addition to this warning, we might want to add some runtime check on the eq of the dyn sizes. @@ -540,17 +557,18 @@ def declare_same_as(self, other): if self.same_as.dyn_size is not None and self.src_data: assert isinstance(self.src_axis, int) # Maybe it changed in the meanwhile, so check. 
- if self.src_data.get_dim_tag(self.src_axis).description == self.description: - self.src_data.size_placeholder[ - self.src_data.get_batch_axis_excluding_batch(self.src_axis)] = self.same_as.dyn_size + tag = self.src_data.get_dim_tag(self.src_axis) + if tag.description == self.description and not tag.dyn_size_ext: + tag.dyn_size_ext = self.get_dyn_size_ext_for_batch_ctx(tag.batch, tag.control_flow_ctx) # If others dyn_size is None but we have a dyn_size, maybe update others dyn_size. if self.dyn_size is not None and self.same_as.dyn_size is not self.dyn_size: # Could be unset if it comes from the config, or from prev graph creation. # This is important such that self.can_compare() is sane. if self.same_as.dyn_size is None or self.same_as.dyn_size.graph is not self.dyn_size.graph: - self.same_as.dyn_size_ext = self.dyn_size_ext + self.same_as.dyn_size_ext = self.get_dyn_size_ext_for_batch_ctx( + self.same_as.batch, self.same_as.control_flow_ctx) if not self.dyn_size_ext and other.dyn_size_ext: - self.dyn_size_ext = other.dyn_size_ext.copy() + self.dyn_size_ext = other.get_dyn_size_ext_for_batch_ctx(self.batch, self.control_flow_ctx) @classmethod def get_existing_tag_from_collection(cls, other, tags, is_equal_opts=None): From 87a116955b4a48895c0a4e2b1d5cdcaa9061bb19 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Mon, 20 Sep 2021 21:09:08 +0200 Subject: [PATCH 2/4] DimensionTag.get_for_batch_ctx validate current graph --- returnn/tf/util/data.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 0228ece6a1..b7a53f36a6 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -149,6 +149,20 @@ def _can_use_in_ctx(self, ctx): return False return True + def _validate_in_current_graph(self): + """ + :rtype: bool + """ + if self.dyn_size_ext and self.dyn_size_ext.placeholder is not None: + g = tf_compat.v1.get_default_graph() + if 
self.dyn_size_ext.placeholder.graph is not g: # maybe from an earlier run which reuses the dim tag + # Reset and cleanup. + self.dyn_size_ext = None + same_base = self.get_same_base() + same_base._same_for_batch_ctx.pop((self.batch, self.control_flow_ctx), None) + return False + return True + def get_for_batch_ctx(self, batch, ctx): """ :param BatchInfo batch: @@ -156,6 +170,7 @@ def get_for_batch_ctx(self, batch, ctx): :rtype: DimensionTag """ if self.batch == batch and self.control_flow_ctx == ctx: + self._validate_in_current_graph() return self if self.is_batch_dim(): # We ignore the ctx for the batch dim currently. @@ -169,6 +184,7 @@ def get_for_batch_ctx(self, batch, ctx): if batch.is_broadcast(): return self # just leave as-is. should not matter. same_base = self.get_same_base() + same_base._validate_in_current_graph() # Might be uninitialized in some cases. Assume batch is global. if not same_base.batch: batch_base = batch.get_global_base() @@ -183,13 +199,13 @@ def get_for_batch_ctx(self, batch, ctx): assert same_base.batch == same_base.dyn_size_ext.batch assert same_base.control_flow_ctx == same_base.dyn_size_ext.control_flow_ctx tag = same_base._same_for_batch_ctx.get((batch, ctx), None) - if tag: + if tag and tag._validate_in_current_graph(): return tag if same_base.batch == batch and same_base._can_use_in_ctx(ctx): return same_base for ctx_ in ControlFlowContext.abs_ctx_stack_with_root(ctx): tag = same_base._same_for_batch_ctx.get((batch, ctx_), None) - if tag and tag._can_use_in_ctx(ctx): + if tag and tag._can_use_in_ctx(ctx) and tag._validate_in_current_graph(): return tag # Ok, nothing matching found. 
dyn_size_ext = None @@ -202,7 +218,7 @@ def get_for_batch_ctx(self, batch, ctx): else: for ctx_ in ControlFlowContext.abs_ctx_stack_with_root(ctx): tag = same_base._same_for_batch_ctx.get((batch_base, ctx_), None) - if tag and tag._can_use_in_ctx(ctx): + if tag and tag._can_use_in_ctx(ctx) and tag._validate_in_current_graph(): base_can_use_in_ctx = tag break if base_can_use_in_ctx and base_can_use_in_ctx.dyn_size_ext: From 1af2fa917ae81da90e2dc65b35dd0a8d2ca69442 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Mon, 20 Sep 2021 23:46:08 +0200 Subject: [PATCH 3/4] DimensionTag.declare_same_as, better warning --- returnn/tf/util/data.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index b7a53f36a6..90ae0002ac 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -567,7 +567,8 @@ def declare_same_as(self, other): # The user should be able to fix `extern_data` in the config such that this is correct in the first place. # Also, in addition to this warning, we might want to add some runtime check on the eq of the dyn sizes. print( - "Warning: assuming dim tags are same with different size placeholders: %r vs %r" % (self, other_same_base)) + "Warning: assuming dim tags are same with different size placeholders: %r vs %r" % ( + self.dyn_size, other_same_base.dyn_size)) # If we have a defined source, and this is a dynamic spatial axis, and it was undefined before, # maybe we can overtake the size_placeholder now. 
if self.same_as.dyn_size is not None and self.src_data: From 5ae214ab1803eec15d432972aa972aaebab9afdb Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 21 Sep 2021 01:39:13 +0200 Subject: [PATCH 4/4] DimensionTag.declare_same_as better logic --- returnn/tf/util/data.py | 106 +++++++++++++++++++++++++++++++--------- 1 file changed, 84 insertions(+), 22 deletions(-) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 90ae0002ac..10e7d494c5 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -82,14 +82,14 @@ def __init__(self, kind=Types.Unspecified, description=None, if dyn_size_ext: assert batch == dyn_size_ext.batch self.dyn_size_ext = dyn_size_ext # type: typing.Optional[Data] - if dyn_size is not None: - assert not dyn_size_ext - self.dyn_size = dyn_size self._dyn_size_same = set() # type: typing.Set[tf.Tensor] self._undefined = undefined # We can have different tag variants per batch info (e.g. with beam), or per control flow ctx. # They each have same_as = self. The same_base should have the base (global) batch info. 
self._same_for_batch_ctx = {} # type: typing.Dict[typing.Tuple[BatchInfo,typing.Optional[ControlFlowContext]],DimensionTag] # nopep8 + if dyn_size is not None: + assert not dyn_size_ext + self.dyn_size = dyn_size def __repr__(self): return "DimensionTag{%s}" % self.short_repr() @@ -153,23 +153,49 @@ def _validate_in_current_graph(self): """ :rtype: bool """ - if self.dyn_size_ext and self.dyn_size_ext.placeholder is not None: + tensor = None + if self.batch: + batch_base = self.batch.get_global_base() + if batch_base.is_global_batch(): + tensor = batch_base.get_global_batch_dim().size + if not isinstance(tensor, tf.Tensor): + if self.dyn_size_ext and self.dyn_size_ext.placeholder is not None: + tensor = self.dyn_size_ext.placeholder + if isinstance(tensor, tf.Tensor): g = tf_compat.v1.get_default_graph() - if self.dyn_size_ext.placeholder.graph is not g: # maybe from an earlier run which reuses the dim tag + if tensor.graph is not g: # maybe from an earlier run which reuses the dim tag # Reset and cleanup. 
self.dyn_size_ext = None same_base = self.get_same_base() same_base._same_for_batch_ctx.pop((self.batch, self.control_flow_ctx), None) + self.batch = None # it is invalid in the new graph + self.control_flow_ctx = None # also invalid return False return True - def get_for_batch_ctx(self, batch, ctx): + def _maybe_update(self): + if self.is_batch_dim(): + return + if isinstance(self.dimension, int): + return + if self.dyn_size_ext: + return + if not self.batch: + return + # Check if we can find more in the tag variant registered for our batch and ctx. + same = self.get_for_batch_ctx(self.batch, self.control_flow_ctx, allow_none=True) + if self is same or not same or not same.dyn_size_ext: + return + self.dyn_size_ext = same.dyn_size_ext + + def get_for_batch_ctx(self, batch, ctx, allow_none=False): """ :param BatchInfo batch: :param ControlFlowContext|None ctx: - :rtype: DimensionTag + :param bool allow_none: + :rtype: DimensionTag|None """ - if self.batch == batch and self.control_flow_ctx == ctx: + if self.batch == batch and self.control_flow_ctx == ctx and self.dyn_size_ext: self._validate_in_current_graph() return self if self.is_batch_dim(): @@ -198,27 +224,24 @@ def get_for_batch_ctx(self, batch, ctx): if same_base.dyn_size_ext: assert same_base.batch == same_base.dyn_size_ext.batch assert same_base.control_flow_ctx == same_base.dyn_size_ext.control_flow_ctx - tag = same_base._same_for_batch_ctx.get((batch, ctx), None) - if tag and tag._validate_in_current_graph(): - return tag - if same_base.batch == batch and same_base._can_use_in_ctx(ctx): - return same_base for ctx_ in ControlFlowContext.abs_ctx_stack_with_root(ctx): tag = same_base._same_for_batch_ctx.get((batch, ctx_), None) if tag and tag._can_use_in_ctx(ctx) and tag._validate_in_current_graph(): return tag + if same_base.batch == batch and same_base._can_use_in_ctx(ctx) and same_base.dyn_size_ext: + return same_base # Ok, nothing matching found. 
if batch.copy_remove_beam() == batch.get_global_base() and batch.beam: batch_base = batch.get_global_base() base_can_use_in_ctx = None - if same_base.batch == batch_base and same_base._can_use_in_ctx(ctx): + if same_base.batch == batch_base and same_base._can_use_in_ctx(ctx) and same_base.dyn_size_ext: base_can_use_in_ctx = same_base else: for ctx_ in ControlFlowContext.abs_ctx_stack_with_root(ctx): tag = same_base._same_for_batch_ctx.get((batch_base, ctx_), None) - if tag and tag._can_use_in_ctx(ctx) and tag._validate_in_current_graph(): + if tag and tag._can_use_in_ctx(ctx) and tag._validate_in_current_graph() and tag.dyn_size_ext: base_can_use_in_ctx = tag break if base_can_use_in_ctx and base_can_use_in_ctx.dyn_size_ext: @@ -240,6 +263,8 @@ def get_for_batch_ctx(self, batch, ctx): name=get_valid_scope_name_from_str("%s_identity_for_beam_%s" % (dyn_size_ext.name, batch.beam.name))) dyn_size_ext.placeholder._RETURNN_dyn_size_beam = batch.beam dyn_size_ext.placeholder._RETURNN_beam_expanded_base_data = beam_expanded_base_data + if not dyn_size_ext and allow_none: + return None dim_tag = DimensionTag( kind=self.kind, description=self.description, dimension=self.dimension, batch=batch, control_flow_ctx=dyn_size_ext.control_flow_ctx if dyn_size_ext else ctx, @@ -251,6 +276,16 @@ def get_for_batch_ctx(self, batch, ctx): same_base._same_for_batch_ctx[(dim_tag.batch, dim_tag.control_flow_ctx)] = dim_tag return dim_tag + def set_dyn_size_ext_for_batch_ctx(self, batch, ctx, dyn_size_ext): + """ + :param BatchInfo batch: + :param ControlFlowContext|None ctx: + :param Data dyn_size_ext: + """ + same = self.get_for_batch_ctx(batch, ctx) + same.dyn_size_ext = dyn_size_ext + self._maybe_update() + def get_dyn_size_ext_for_batch_ctx(self, batch, ctx): """ :param BatchInfo|None batch: @@ -264,7 +299,10 @@ def get_dyn_size_ext_for_batch_ctx(self, batch, ctx): # This is usually not valid. However, this case can happen early at initialization. 
assert batch == self.batch and ctx == self.control_flow_ctx return self.dyn_size_ext - return self.get_for_batch_ctx(batch, ctx).dyn_size_ext + same = self.get_for_batch_ctx(batch, ctx, allow_none=True) + if not same: + return None + return same.dyn_size_ext @property def dyn_size(self): @@ -538,6 +576,8 @@ def declare_same_as(self, other): """ :param DimensionTag other: """ + self._maybe_update() + self._validate_in_current_graph() if self is other: return other_same_base = other.get_same_base() @@ -548,18 +588,21 @@ def declare_same_as(self, other): assert not self_same_as.same_as if self_same_as is other_same_base: return + other_same_base._merge_same_for_batch_ctx_dict(self_same_as) self_same_as.same_as = other_same_base self_same_as._same_as_tb = traceback.extract_stack() - if self_same_as.dyn_size_ext is None: + if self_same_as.dyn_size_ext is None or not self_same_as._validate_in_current_graph(): self_same_as.dyn_size_ext = other_same_base.get_dyn_size_ext_for_batch_ctx( self_same_as.batch, self_same_as.control_flow_ctx) - elif other_same_base.dyn_size_ext is None: + elif other_same_base.dyn_size_ext is None or not other_same_base._validate_in_current_graph(): other_same_base.dyn_size_ext = self_same_as.get_dyn_size_ext_for_batch_ctx( other_same_base.batch, other_same_base.control_flow_ctx) - if self.dyn_size_ext is None and self_same_as.dyn_size_ext: + if (self.dyn_size_ext is None or not self._validate_in_current_graph()) and self_same_as.dyn_size_ext: self.dyn_size_ext = self_same_as.get_dyn_size_ext_for_batch_ctx(self.batch, self.control_flow_ctx) + other_same_base._merge_same_for_batch_ctx_dict(self) self.same_as = other_same_base self._same_as_tb = traceback.extract_stack() + self._maybe_update() if self.dyn_size is not None and other_same_base.dyn_size is not None: if self.dyn_size is not other_same_base.dyn_size: if self.batch == other_same_base.batch and self.control_flow_ctx == other_same_base.control_flow_ctx: @@ -575,18 +618,37 @@ def 
declare_same_as(self, other): assert isinstance(self.src_axis, int) # Maybe it changed in the meanwhile, so check. tag = self.src_data.get_dim_tag(self.src_axis) - if tag.description == self.description and not tag.dyn_size_ext: + if tag.description == self.description and (not tag.dyn_size_ext or not tag._validate_in_current_graph()): tag.dyn_size_ext = self.get_dyn_size_ext_for_batch_ctx(tag.batch, tag.control_flow_ctx) # If others dyn_size is None but we have a dyn_size, maybe update others dyn_size. if self.dyn_size is not None and self.same_as.dyn_size is not self.dyn_size: # Could be unset if it comes from the config, or from prev graph creation. # This is important such that self.can_compare() is sane. - if self.same_as.dyn_size is None or self.same_as.dyn_size.graph is not self.dyn_size.graph: + if self.same_as.dyn_size is None or not self.same_as._validate_in_current_graph(): self.same_as.dyn_size_ext = self.get_dyn_size_ext_for_batch_ctx( self.same_as.batch, self.same_as.control_flow_ctx) - if not self.dyn_size_ext and other.dyn_size_ext: + if (not self.dyn_size_ext or not self._validate_in_current_graph()) and other.dyn_size_ext: self.dyn_size_ext = other.get_dyn_size_ext_for_batch_ctx(self.batch, self.control_flow_ctx) + def _merge_same_for_batch_ctx_dict(self, other): + """ + :param DimensionTag other: + """ + self._validate_in_current_graph() + for _, dim in list(self._same_for_batch_ctx.items()): + assert isinstance(dim, DimensionTag) + dim._validate_in_current_graph() + for key, dim in other._same_for_batch_ctx.items(): + if not dim._validate_in_current_graph(): + continue + self_dim = self._same_for_batch_ctx.get(key, None) + if self_dim and (self_dim.dyn_size_ext or not dim.dyn_size_ext): + continue # keep ours + if not dim.dyn_size_ext: + continue # undefined, do not overtake + self._same_for_batch_ctx[key] = dim + other._same_for_batch_ctx.clear() # we only want to have it once + @classmethod def get_existing_tag_from_collection(cls, other, 
tags, is_equal_opts=None): """