Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 19 additions & 27 deletions returnn/tf/layers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from returnn.util.basic import NotSpecified, CollectionReadCheckCovered, BehaviorVersion
import returnn.tf.compat as tf_compat
import returnn.tf.util.basic as tf_util
from returnn.tf.util.data import Data, SearchBeam
from returnn.tf.util.data import Data
from returnn.tf.util.basic import OutputWithActivation, CustomUpdate, reuse_name_scope
from returnn.log import log

Expand Down Expand Up @@ -255,6 +255,7 @@ def _base_get_out_data_from_opts(cls, network, name, out_type=None, n_out=NotSpe
:return: Data template (placeholder not set)
:rtype: Data
"""
from ..util.data import DimensionTag
if callable(out_type):
return out_type(
network=network, name=name, n_out=n_out, target=target, size_target=size_target, sources=sources, loss=loss,
Expand All @@ -273,9 +274,8 @@ def _base_get_out_data_from_opts(cls, network, name, out_type=None, n_out=NotSpe
network=network, mark_data_key_as_used=False).dim
if n_out is not NotSpecified:
assert out_type["dim"] == n_out
sources_data = None
if sources and sources[0]:
sources_data = sources[0].output.copy_template()
sources_data_list = [src.output for src in sources if src]
sources_data = Data.get_common_data(sources_data_list, ignore_feature_dim=True) if sources_data_list else None
if sources_data and not sources_data.sparse and not out_type.get("sparse", False):
out_type.setdefault("dtype", sources_data.dtype)
# You are supposed to set self.output.{batch_dim_axis,time_dim_axis} explicitly,
Expand All @@ -296,38 +296,30 @@ def _base_get_out_data_from_opts(cls, network, name, out_type=None, n_out=NotSpe
if "shape" not in out_type and "dim_tags" not in out_type:
if sources_data:
if out_type.get("sparse", False):
out_type.setdefault("shape", sources_data.shape_sparse)
out_type["dim_tags"] = sources_data.dim_tags_sparse
else: # not sparse
feature_dim_axis = out_type.get("feature_dim_axis", NotSpecified)
if feature_dim_axis is NotSpecified:
if sources_data.feature_dim_axis is not None:
feature_dim_axis = sources_data.feature_dim_axis
else:
feature_dim_axis = -1
if sources_data.shape:
default_shape = list(sources_data.shape_dense)
if sources_data.batch_dim_axis is not None:
default_shape.insert(sources_data.batch_dim_axis, None)
default_shape[feature_dim_axis] = out_type.get("dim", None)
if out_type.get("batch_dim_axis") is not None:
default_shape.pop(out_type.get("batch_dim_axis"))
else: # source is scalar
if out_type.get("dim") or out_type.get("feature_dim_axis") is not None:
default_shape = (out_type.get("dim"),)
dim = out_type.get("dim", None)
dim_tags = list(sources_data.dim_tags_sparse)
feature_dim_tag = DimensionTag(
kind=DimensionTag.Types.Feature, description="%s:feature-dense" % name, dimension=dim)
if feature_dim_axis in (NotSpecified, None):
if sources_data.feature_dim_axis is None:
feature_dim_axis = len(dim_tags)
else:
default_shape = ()
out_type.setdefault("shape", tuple(default_shape))
feature_dim_axis = sources_data.feature_dim_axis
dim_tags.insert(feature_dim_axis, feature_dim_tag)
out_type["dim_tags"] = dim_tags
elif network.is_inside_rec_layer():
if out_type.get("sparse", False):
out_type.setdefault("shape", ())
else:
out_type.setdefault("shape", (out_type.get("dim", None),))
# Note: No special handling for feature_dim_axis here for now...
beam = None
for src in sources:
if src: # might be None if template construction
beam = SearchBeam.get_combined_beam(beam, src.output.beam)
out_type.setdefault("beam", beam)
if sources_data and sources_data.batch:
out_type.setdefault("batch", sources_data.batch)
if sources_data and sources_data.beam:
out_type.setdefault("beam", sources_data.beam)
output = Data(**out_type)
cls._post_init_output(
output=output, network=network, target=target, size_target=size_target, _target_layers=_target_layers,
Expand Down
18 changes: 9 additions & 9 deletions returnn/tf/util/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2343,23 +2343,23 @@ def shape_dense(self):
return self.shape

@property
def shape_sparse(self):
def batch_shape_dense(self):
"""
:return: shape without feature dim axis
:rtype: tuple[int|None]
"""
if self.sparse:
return self.shape
return self.shape[:self.feature_dim_axis] + self.shape[self.feature_dim_axis + 1:]
return self.batch_shape + (self.dim,)
return self.batch_shape

@property
def batch_shape_dense(self):
def dim_tags_sparse(self):
"""
:rtype: tuple[int|None]
:return: dim tags without feature dim axis
:rtype: tuple[DimensionTag]
"""
if self.sparse:
return self.batch_shape + (self.dim,)
return self.batch_shape
if self.sparse or not self.have_feature_axis():
return self.dim_tags
return self.dim_tags[:self.feature_dim_axis] + self.dim_tags[self.feature_dim_axis + 1:]

@property
def ndim(self):
Expand Down