diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 70281cb8f1..d0a549461e 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -3298,6 +3298,18 @@ def get_sequence_mask_broadcast(self, axis=None): assert seq_mask.get_shape().ndims == self.batch_ndim return seq_mask + def copy_masked(self, mask_value): + """ + :param float|int mask_value: + :rtype: Data + """ + assert self.placeholder is not None + from .basic import mask_dyn_seq_len_nd + dyn_axes = [axis for axis, dim in enumerate(self.dim_tags) if not dim.is_batch_dim() and dim.dimension is None] + res = self.copy() + res.placeholder = mask_dyn_seq_len_nd(self, pad_value=mask_value, axes=dyn_axes) + return res + def get_batch_dim(self): """ :rtype: tf.Tensor|int