From e5c72c38e5bd7417a6b747590cba702d96544d07 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Mon, 20 Sep 2021 17:51:08 +0200
Subject: [PATCH 1/2] Rec subnet output, fix out dim tag in some cases, declare_same_as

---
 returnn/tf/layers/rec.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py
index 9fd6e49564..e243b4ab41 100644
--- a/returnn/tf/layers/rec.py
+++ b/returnn/tf/layers/rec.py
@@ -2631,8 +2631,15 @@ def cond(i, net_vars, acc_tas, seq_len_info=None):
     if output_layer:
       assert isinstance(output_layer, LayerBase)
       output_data = output_layer.output.copy_as_time_major()
-      assert 0 in output_data.size_placeholder
-      rec_layer.output.size_placeholder = output_data.size_placeholder.copy()
+      self.time_dim_tag.declare_same_as(output_data.get_time_dim_tag())
+      assert len(rec_layer.output.dim_tags) == len(output_data.dim_tags)
+      for tag1, tag2 in zip(rec_layer.output.dim_tags, output_data.dim_tags):
+        assert tag1.is_equal(tag2, allow_same_feature_dim=True)
+        # Make sure they are the same.
+        # It can happen that they are not when the dim tag is created inside,
+        # and then created once for the template layer, and again for the real layer.
+        # Make sure they are really the same such that we get all information like dyn sizes.
+        tag1.declare_same_as(tag2)
       output = output_data.placeholder
     else:
       assert seq_len is not None
@@ -2641,12 +2648,6 @@
       output = tensor_array_stack(
         self.final_acc_tas_dict["output_output"], stop=max_seq_len, name="output_stack")  # e.g. (time, batch, dim)

-    existing_time_dim_tag = DimensionTag.get_tag_from_size_tensor(rec_layer.output.size_placeholder[0])
-    if existing_time_dim_tag:
-      self.time_dim_tag.declare_same_as(existing_time_dim_tag)
-    else:
-      self.time_dim_tag.set_tag_on_size_tensor(rec_layer.output.size_placeholder[0], batch=rec_layer.output.batch)
-
     for key in (
         self.net.used_data_keys |
         (self.input_layers_net.used_data_keys if self.input_layers_net else set()) |

From cfdae650304d55b9dbbe07852831c38364f3cf2a Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Tue, 21 Sep 2021 01:39:44 +0200
Subject: [PATCH 2/2] Rec subnet relax out fixed time dim tag check

---
 returnn/tf/layers/rec.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py
index e243b4ab41..2fb0826a50 100644
--- a/returnn/tf/layers/rec.py
+++ b/returnn/tf/layers/rec.py
@@ -1988,7 +1988,7 @@ def get_output(self):
       fixed_seq_len = input_seq_len
     if fixed_seq_len is not None:
       time_dim_tag = DimensionTag.get_tag_from_size_tensor(fixed_seq_len)
-      assert time_dim_tag is self.time_dim_tag
+      assert time_dim_tag == self.time_dim_tag
       with tf.name_scope("check_seq_len_batch_size"):
         fixed_seq_len = check_input_dim(
           fixed_seq_len, axis=0, dim=batch_dim * (input_beam.beam_size if input_beam else 1))
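
Note on [PATCH 1/2]: the comments added in the first hunk explain the motivation. The same logical dimension can end up represented by two distinct tag objects, one created for the template layer and one for the real layer, and declare_same_as unifies them so that information such as the dynamic sizes is shared. The following is a minimal standalone sketch of that unification idea, not RETURNN's actual DimensionTag implementation; the class and attribute names (DimTagSketch, same_as, dyn_size) are invented for illustration.

class DimTagSketch:
  """Toy stand-in for a dimension tag; not RETURNN's DimensionTag."""

  def __init__(self, description):
    self.description = description
    self.same_as = None   # canonical tag after unification, if any
    self.dyn_size = None  # dynamic seq lens, e.g. a [batch]-shaped tensor

  def get_same_base(self):
    # Follow the same_as chain to the canonical tag.
    tag = self
    while tag.same_as is not None:
      tag = tag.same_as
    return tag

  def declare_same_as(self, other):
    # Unify self with other: afterwards both resolve to one canonical tag,
    # so info like dyn_size lives in one place instead of being duplicated.
    base_self = self.get_same_base()
    base_other = other.get_same_base()
    if base_self is base_other:
      return  # already unified
    base_self.same_as = base_other
    if base_other.dyn_size is None:
      base_other.dyn_size = base_self.dyn_size


# One tag created for the template layer, one for the real layer:
template_tag = DimTagSketch("time:template")
real_tag = DimTagSketch("time:real")
template_tag.dyn_size = [7, 5, 3]  # stands in for a [batch] tensor of seq lens
real_tag.declare_same_as(template_tag)
assert real_tag.get_same_base() is template_tag.get_same_base()
assert real_tag.get_same_base().dyn_size == [7, 5, 3]

In the patch itself, this is what tag1.declare_same_as(tag2) achieves per output axis, after is_equal has checked that the tags are compatible.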
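
Note on [PATCH 2/2]: the assertion is relaxed from object identity (is) to equality (==). A plausible reading, given the first patch: once dim tags can be unified via declare_same_as, the tag recovered from the size tensor may be a different Python object that nevertheless denotes the same dimension, so an identity check is too strict. A standalone sketch of the distinction (invented names again; RETURNN's real DimensionTag equality handles more cases):

class Tag:
  """Toy tag with equality via the canonical base; not RETURNN code."""

  def __init__(self):
    self.same_as = None

  def base(self):
    # Resolve to the canonical tag.
    return self.same_as.base() if self.same_as is not None else self

  def declare_same_as(self, other):
    if self.base() is other.base():
      return  # already unified; avoid creating a same_as self-loop
    self.base().same_as = other.base()

  def __eq__(self, other):
    # Two tags are equal if they resolve to the same canonical base.
    return self.base() is other.base()

  __hash__ = None  # hashing is not needed for this sketch


a = Tag()
b = Tag()
b.declare_same_as(a)
assert a == b      # equality via the canonical base: passes
assert a is not b  # distinct objects, so the old `is` check was too strict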