From 7d31efa4251b2999eab34019030dee8b06d0da7b Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 21 Sep 2021 15:06:49 +0200 Subject: [PATCH] Data.copy_masked --- returnn/tf/util/data.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 70281cb8f1..d0a549461e 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -3298,6 +3298,18 @@ def get_sequence_mask_broadcast(self, axis=None): assert seq_mask.get_shape().ndims == self.batch_ndim return seq_mask + def copy_masked(self, mask_value): + """ + :param float|int mask_value: + :rtype: Data + """ + assert self.placeholder is not None + from .basic import mask_dyn_seq_len_nd + dyn_axes = [axis for axis, dim in enumerate(self.dim_tags) if not dim.is_batch_dim() and dim.dimension is None] + res = self.copy() + res.placeholder = mask_dyn_seq_len_nd(self, pad_value=mask_value, axes=dyn_axes) + return res + def get_batch_dim(self): """ :rtype: tf.Tensor|int