From bcebc309238b2b73425dcc46632a4bb5e2134ac2 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Mon, 20 Sep 2021 12:32:19 +0200 Subject: [PATCH 1/2] LengthLayer, support dyn_size_ext --- returnn/tf/layers/basic.py | 42 +++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index de621dd7df..4379094b2e 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -1637,42 +1637,50 @@ class LengthLayer(LayerBase): layer_class = "length" # noinspection PyUnusedLocal - def __init__(self, add_time_axis=False, dtype="int32", sparse=False, **kwargs): + def __init__(self, axis="T", add_time_axis=False, dtype="int32", sparse=False, **kwargs): """ + :param str|DimensionTag axis: :param bool add_time_axis: :param str dtype: :param bool sparse: """ super(LengthLayer, self).__init__(**kwargs) assert len(self.sources) == 1, "%s: expects one source" % self - out = tf.cast(self.sources[0].output.get_sequence_lengths(), dtype) + source = self.sources[0].output + axis = source.get_axis_from_description(axis, allow_int=False) + dim = source.dim_tags[axis] if add_time_axis: - out = tf.expand_dims(out, axis=self.output.time_dim_axis) - self.output.placeholder = out + self.output.placeholder = tf.expand_dims(dim.dyn_size, axis=self.output.time_dim_axis) + else: + self.output.placeholder = dim.dyn_size_ext.placeholder @classmethod - def get_out_data_from_opts(cls, name, sources, add_time_axis=False, dtype="int32", sparse=False, **kwargs): + def get_out_data_from_opts(cls, name, sources, axis="T", add_time_axis=False, dtype="int32", sparse=False, **kwargs): """ :param str name: :param list[LayerBase] sources: + :param str|DimensionTag axis: :param bool add_time_axis: :param str dtype: :param bool sparse: :rtype: Data """ + assert len(sources) == 1 + source = sources[0].output + axis = source.get_axis_from_description(axis, allow_int=False) + dim = source.dim_tags[axis] if add_time_axis: - shape = (1,) - time_dim_axis = 1 - else: - shape = () - time_dim_axis = None - return Data( - name="%s_length" % name, - shape=shape, - batch_dim_axis=0, - time_dim_axis=time_dim_axis, - dtype=dtype, - sparse=sparse, dim=None if sparse else NotSpecified) + assert dim.dyn_size_ext and dim.dyn_size_ext.have_batch_axis() and dim.dyn_size_ext.batch_ndim == 1 # [B] + return Data( + name="%s_length" % name, + shape=[1], batch_dim_axis=0, time_dim_axis=1, + dtype=dtype, sparse=sparse, dim=None if sparse else NotSpecified) + if not dim.dyn_size_ext: # yet undefined + return Data( + name="%s_length" % name, + shape=(), batch_dim_axis=0, time_dim_axis=None, + dtype=dtype, sparse=sparse, dim=None if sparse else NotSpecified) + return dim.dyn_size_ext class SoftmaxOverSpatialLayer(_ConcatInputLayer): From 1194a001ed85187d33f9cb41eb0844009e94936a Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Mon, 20 Sep 2021 16:25:33 +0200 Subject: [PATCH 2/2] LengthLayer, add dim_tag attrib --- returnn/tf/layers/basic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 4379094b2e..5a77fac928 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -1649,6 +1649,7 @@ def __init__(self, axis="T", add_time_axis=False, dtype="int32", sparse=False, * source = self.sources[0].output axis = source.get_axis_from_description(axis, allow_int=False) dim = source.dim_tags[axis] + self.dim_tag = dim if add_time_axis: self.output.placeholder = tf.expand_dims(dim.dyn_size, axis=self.output.time_dim_axis) else: