diff --git a/returnn/tf/layers/base.py b/returnn/tf/layers/base.py index 9674923436..dd3b579d52 100644 --- a/returnn/tf/layers/base.py +++ b/returnn/tf/layers/base.py @@ -11,7 +11,7 @@ from returnn.util.basic import NotSpecified, CollectionReadCheckCovered, BehaviorVersion import returnn.tf.compat as tf_compat import returnn.tf.util.basic as tf_util -from returnn.tf.util.data import Data, SearchBeam +from returnn.tf.util.data import Data from returnn.tf.util.basic import OutputWithActivation, CustomUpdate, reuse_name_scope from returnn.log import log @@ -255,6 +255,7 @@ def _base_get_out_data_from_opts(cls, network, name, out_type=None, n_out=NotSpe :return: Data template (placeholder not set) :rtype: Data """ + from ..util.data import DimensionTag if callable(out_type): return out_type( network=network, name=name, n_out=n_out, target=target, size_target=size_target, sources=sources, loss=loss, @@ -273,9 +274,8 @@ def _base_get_out_data_from_opts(cls, network, name, out_type=None, n_out=NotSpe network=network, mark_data_key_as_used=False).dim if n_out is not NotSpecified: assert out_type["dim"] == n_out - sources_data = None - if sources and sources[0]: - sources_data = sources[0].output.copy_template() + sources_data_list = [src.output for src in sources if src] + sources_data = Data.get_common_data(sources_data_list, ignore_feature_dim=True) if sources_data_list else None if sources_data and not sources_data.sparse and not out_type.get("sparse", False): out_type.setdefault("dtype", sources_data.dtype) # You are supposed to set self.output.{batch_dim_axis,time_dim_axis} explicitly, @@ -296,38 +296,30 @@ def _base_get_out_data_from_opts(cls, network, name, out_type=None, n_out=NotSpe if "shape" not in out_type and "dim_tags" not in out_type: if sources_data: if out_type.get("sparse", False): - out_type.setdefault("shape", sources_data.shape_sparse) + out_type["dim_tags"] = sources_data.dim_tags_sparse else: # not sparse feature_dim_axis = out_type.get("feature_dim_axis", NotSpecified) - if feature_dim_axis is NotSpecified: - if sources_data.feature_dim_axis is not None: - feature_dim_axis = sources_data.feature_dim_axis - else: - feature_dim_axis = -1 - if sources_data.shape: - default_shape = list(sources_data.shape_dense) - if sources_data.batch_dim_axis is not None: - default_shape.insert(sources_data.batch_dim_axis, None) - default_shape[feature_dim_axis] = out_type.get("dim", None) - if out_type.get("batch_dim_axis") is not None: - default_shape.pop(out_type.get("batch_dim_axis")) - else: # source is scalar - if out_type.get("dim") or out_type.get("feature_dim_axis") is not None: - default_shape = (out_type.get("dim"),) + dim = out_type.get("dim", None) + dim_tags = list(sources_data.dim_tags_sparse) + feature_dim_tag = DimensionTag( + kind=DimensionTag.Types.Feature, description="%s:feature-dense" % name, dimension=dim) + if feature_dim_axis in (NotSpecified, None): + if sources_data.feature_dim_axis is None: + feature_dim_axis = len(dim_tags) else: - default_shape = () - out_type.setdefault("shape", tuple(default_shape)) + feature_dim_axis = sources_data.feature_dim_axis + dim_tags.insert(feature_dim_axis, feature_dim_tag) + out_type["dim_tags"] = dim_tags elif network.is_inside_rec_layer(): if out_type.get("sparse", False): out_type.setdefault("shape", ()) else: out_type.setdefault("shape", (out_type.get("dim", None),)) # Note: No special handling for feature_dim_axis here for now... - beam = None - for src in sources: - if src: # might be None if template construction - beam = SearchBeam.get_combined_beam(beam, src.output.beam) - out_type.setdefault("beam", beam) + if sources_data and sources_data.batch: + out_type.setdefault("batch", sources_data.batch) + if sources_data and sources_data.beam: + out_type.setdefault("beam", sources_data.beam) output = Data(**out_type) cls._post_init_output( output=output, network=network, target=target, size_target=size_target, _target_layers=_target_layers, diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index ffa8e7ca93..8aee5ed70e 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -2343,23 +2343,23 @@ def shape_dense(self): return self.shape @property - def shape_sparse(self): + def batch_shape_dense(self): """ - :return: shape without feature dim axis :rtype: tuple[int|None] """ if self.sparse: - return self.shape - return self.shape[:self.feature_dim_axis] + self.shape[self.feature_dim_axis + 1:] + return self.batch_shape + (self.dim,) + return self.batch_shape @property - def batch_shape_dense(self): + def dim_tags_sparse(self): """ - :rtype: tuple[int|None] + :return: dim tags without feature dim axis + :rtype: tuple[DimensionTag] """ - if self.sparse: - return self.batch_shape + (self.dim,) - return self.batch_shape + if self.sparse or not self.have_feature_axis(): + return self.dim_tags + return self.dim_tags[:self.feature_dim_axis] + self.dim_tags[self.feature_dim_axis + 1:] @property def ndim(self):