Adding Longformer Encoder Decoder support for T5 #10432

@huu4ontocord

Description

🚀 Adding Longformer Encoder Decoder support for T5

LED is great for long-form encoder-decoder processing of documents, but it is based only on BART. T5 has certain advantages, such as being designed for multiple tasks (QA, summarization, etc.) and using relative positioning.

T5 uses relative positioning, which maps well to sliding chunks and should not require additional training to learn new relative position buckets. Adding LED support would allow any already trained T5 model to be used efficiently on long documents.

I've started incorporating LED features into the encoder portion of T5, but I have some questions about the position_bias and the implementation details of T5 and LED. With some help understanding how the sliding-window multiplication works in LED and how the relative position bias is organized, I think I can finish the implementation.

In particular, T5 passes a position_bias that, along with the mask, is added in each layer. This bias is added to each attention score before the softmax.
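
For reference, here's a rough sketch of that scoring pattern (toy tensors and shapes assumed for illustration, not the actual modeling_t5 code): the bias and the mask are simply added to the raw query-key scores before the softmax.

    import torch
    import torch.nn.functional as F

    def t5_style_scores(q, k, position_bias, mask):
        # q, k: (batch_size, n_heads, seq_length, key_value_proj_dim)
        # position_bias, mask: broadcastable to (batch_size, n_heads, seq_length, key_length);
        # the mask uses large negative values at padded positions
        scores = torch.matmul(q, k.transpose(-1, -2))  # raw scores, (batch, n_heads, seq, key)
        scores = scores + position_bias + mask         # additive bias and mask, applied pre-softmax
        return F.softmax(scores.float(), dim=-1).type_as(scores)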

I've surmised that I can add the position_bias to the mask in the Longformer self-attention, and the result should then mostly match the original T5 self-attention.

T5's position_bias has the shape (batch_size, n_heads, seq_length, key_length). The mask used for LED has the shape (batch_size, seq_length), which is then mapped to n_heads and run through the sliding multiplication to stack the mask. I permute the position_bias and run it through the same sliding multiplication to stack the bias, so that the position bias can be added to the mask (see the shape sketch below).
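
As a shape sanity check, here is a small sketch of the layout changes described above (toy sizes; the tensors are placeholders, not the real model state):

    import torch

    batch_size, n_heads, seq_length = 2, 8, 512
    window = 128  # one-sided attention window

    # T5 layout: (batch_size, n_heads, seq_length, key_length); the bias is broadcast over the batch
    position_bias = torch.randn(1, n_heads, seq_length, seq_length)

    # LED layout for the mask: (batch_size, seq_length), later expanded per head
    mask = torch.zeros(batch_size, seq_length)

    # permute the bias into the (batch, seq, n_heads, key) layout the sliding-chunk helpers expect
    position_bias = position_bias.permute(0, 2, 1, 3)
    assert position_bias.shape == (1, seq_length, n_heads, seq_length)

    # after sliding-chunk stacking, windowed scores / masks / bias live in (batch, seq, n_heads, 2*window + 1)
    windowed_shape = (batch_size, seq_length, n_heads, 2 * window + 1)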

I tried a test with an attention_window of 512 and an input of exactly 512 tokens, which should make it equivalent to the standard T5 self-attention. But something seems to be off.

The encoder produces a tensor that, surprisingly, can be decoded by the decoder, which is encouraging, but it isn't producing an answer for QA, for example.
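
The debugging harness I've been using looks roughly like the following. It assumes the two code paths live on the same attention layer and take similar arguments, which may not match the real signatures exactly; the idea is that with attention_window equal to the sequence length, the windowed path should reduce to full attention and the outputs should agree.

    import torch

    def check_equivalence(attn_layer, hidden_states, t5_mask, led_mask, position_bias, atol=1e-4):
        # hidden_states: (batch_size, seq_length, d_model) with seq_length == attention_window
        # t5_mask: additive mask in T5's (batch, 1, 1, key_length) layout
        # led_mask: LED's (batch, seq_length) mask with 0 / -10000 / +10000 entries
        with torch.no_grad():
            ref_out = attn_layer.forward(hidden_states, mask=t5_mask, position_bias=position_bias)[0]
            led_out = attn_layer.forward_long(hidden_states, mask=led_mask, position_bias=position_bias)[0]
        max_diff = (ref_out - led_out).abs().max().item()
        print(f"max abs difference between full and windowed attention: {max_diff:.6f}")
        return torch.allclose(ref_out, led_out, atol=atol)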

I noticed that T5 doesn't use sqrt(key_value_proj_dim) normalization and has an extra mapping through the tensor o. I tried with and without the sqrt, but no good either way.
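
To make the comparison concrete, here is a toy illustration of those two differences (sizes assumed): the 1/sqrt(d_kv) scaling used in scaled dot-product attention is absent from T5's forward pass (it is effectively folded into the weight initialization), and T5 applies its output projection o inside the attention module.

    import math
    import torch
    from torch import nn

    d_model, n_heads, d_kv, seq = 512, 8, 64, 16
    q = torch.randn(1, n_heads, seq, d_kv)
    k = torch.randn(1, n_heads, seq, d_kv)

    # scaled dot-product (what LED's self-attention does) vs. T5's unscaled scores
    scores_scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_kv)
    scores_unscaled = torch.matmul(q, k.transpose(-1, -2))

    # T5 recombines the heads and then maps the result through o inside the attention module
    o = nn.Linear(n_heads * d_kv, d_model, bias=False)
    attn_output = torch.randn(1, seq, n_heads * d_kv)  # stand-in for the recombined heads
    projected = o(attn_output)                         # (1, seq, d_model)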

Am I getting something mixed up with the position_bias?

@ibeltagy @patrickvonplaten @sgugger any help would be much appreciated. Happy to contribute this as a PR when completed.

Current code: https://github.com/ontocord/t5_led/blob/main/t5_ext.py

Relevant portion:

    def forward_long(
        self,
        hidden_states,
        mask=None,
        position_bias=None,
        layer_head_mask=None,
        is_index_masked=None,
        is_index_global_attn=None,
        is_global_attn=None,
        output_attentions=False,
        compute_relative_attention_bias=False,
        query_states = None,
        query_mask = None,
        layer_id=0,
    ):
        """
        :class:`LEDEncoderSelfAttention` expects `len(hidden_states)` to be multiple of `attention_window`. Padding to
        `attention_window` happens in :meth:`LEDEncoderModel.forward` to avoid redoing the padding on each layer.
        The `mask` is changed in :meth:`LEDEncoderModel.forward` from 0, 1, 2 to:
            * -10000: no attention
            * 0: local attention
            * +10000: global attention
        """

        batch_size, seq_length = hidden_states.shape[:2]

        if position_bias is None:
            if not self.has_relative_attention_bias or not compute_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, seq_length, seq_length), device=hidden_states.device, dtype=hidden_states.dtype
                )
            else:
                position_bias = self.compute_bias(seq_length, seq_length,  False)  # (batch_size, n_heads, seq_length, key_length) 
            position_bias = position_bias.permute(0, 2, 1, 3) 
            print ("ccompute bias 2", position_bias.size())

        hidden_states = hidden_states.transpose(0, 1)
        if query_states is None:
            query_states = hidden_states
        # project hidden states
        query_vectors = self.q(query_states)
        if query_mask is not None:
            # zero out masked query positions; the mask broadcasts over the hidden dimension
            query_vectors = query_vectors * query_mask.unsqueeze(-1)

        key_vectors = self.k(hidden_states)
        value_vectors = self.v(hidden_states)

        seq_len, batch_size, embed_dim = hidden_states.size()
        assert (
            embed_dim == self.embed_dim
        ), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}"

        # normalize query - T5 does not do the sqrt???
        query_vectors /= math.sqrt(self.key_value_proj_dim)

        query_vectors = query_vectors.view(seq_len, batch_size, self.n_heads, self.key_value_proj_dim).transpose(0, 1)
        key_vectors = key_vectors.view(seq_len, batch_size, self.n_heads, self.key_value_proj_dim).transpose(0, 1)

        attn_scores = self._sliding_chunks_query_key_matmul(
            query_vectors, key_vectors, self.one_sided_attn_window_size
        )

        # values to pad for attention probs
        remove_from_windowed_mask = (mask != 0)[:, :, None, None]

        # cast to fp32/fp16 then replace 1's with -inf
        float_mask = remove_from_windowed_mask.type_as(query_vectors).masked_fill(
            remove_from_windowed_mask, -10000.0
        )

        # POSITION_BIAS here: stack 2*one_sided_attn_window_size+1 worth of bias in the last dimension
        position_bias2 = self._sliding_chunks_query_key_matmul(
            position_bias.new_ones(size=position_bias.size()), position_bias, self.one_sided_attn_window_size
        )
        
        # diagonal mask with zeros everywhere and -inf in place of padding
        diagonal_mask = self._sliding_chunks_query_key_matmul(
            float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
        )

        # pad local attention probs and add the position bias
        attn_scores += diagonal_mask + position_bias2

        assert list(attn_scores.size()) == [
            batch_size,
            seq_len,
            self.n_heads,
            self.one_sided_attn_window_size * 2 + 1,
        ], f"local_attn_probs should be of size ({batch_size}, {seq_len}, {self.n_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"

        # compute local attention probs from global attention keys and concat over window dim
        if is_global_attn:
            # compute global attn indices required through out forward fn
            (
                max_num_global_attn_indices,
                is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero,
            ) = self._get_global_attn_indices(is_index_global_attn)
            # calculate global attn probs from global key

            global_key_attn_scores = self._concat_with_global_key_attn_probs(
                query_vectors=query_vectors,
                key_vectors=key_vectors,
                max_num_global_attn_indices=max_num_global_attn_indices,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
            )
            # concat to local_attn_probs
            # (batch_size, seq_len, n_heads, extra attention count + 2*window+1)
            attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1)

            # free memory
            del global_key_attn_scores


        attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32)  # use fp32 for numerical stability

        if layer_head_mask is not None:
            assert layer_head_mask.size() == (
                self.n_heads,
            ), f"Head mask for a single layer should be of size {(self.n_heads,)}, but is {layer_head_mask.size()}"
            attn_probs = layer_head_mask.view(1, 1, -1, 1) * attn_probs

        # softmax sometimes inserts NaN if all positions are masked, replace them with 0

        attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :attn_probs.size()[1], None, None], 0.0)
        attn_probs = attn_probs.type_as(attn_scores)

        # free memory
        del attn_scores

        # apply dropout
        attn_probs = F.dropout(attn_probs, p=self.dropout, training=self.training)
  
        value_vectors = value_vectors.view(seq_len, batch_size, self.n_heads, self.key_value_proj_dim).transpose(0, 1)

        # compute local attention output with global attention value and add
        if is_global_attn:
            # compute sum of global and local attn
            attn_output = self._compute_attn_output_with_global_indices(
                value_vectors=value_vectors,
                attn_probs=attn_probs,
                max_num_global_attn_indices=max_num_global_attn_indices,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
            )
        else:
            # compute local attn only
            attn_output = self._sliding_chunks_matmul_attn_probs_value(
                attn_probs, value_vectors, self.one_sided_attn_window_size
            )

        assert attn_output.size() == (batch_size, seq_len, self.n_heads, self.key_value_proj_dim), "Unexpected size"
        attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()

        # compute value for global attention and overwrite to attention output
        # TODO: remove the redundant computation
        if is_global_attn:

            global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
                hidden_states=hidden_states,
                max_num_global_attn_indices=max_num_global_attn_indices,
                layer_head_mask=layer_head_mask,
                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
                is_index_masked=is_index_masked,
            )

            # get only non zero global attn output
            nonzero_global_attn_output = global_attn_output[
                is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1]
            ]

            # overwrite values with global attention
            attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view(
                len(is_local_index_global_attn_nonzero[0]), -1
            )
            # The attention weights for tokens with global attention are
            # just filler values, they were never used to compute the output.
            # Fill with 0 now, the correct values are in 'global_attn_probs'.
            attn_probs[is_index_global_attn_nonzero] = 0

        attn_output = attn_output.transpose(0, 1)
        # t5 runs the attn_output through o, and expects attn_output to be (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)

        present_key_value_state = None
        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_probs,)

        return outputs + (global_attn_probs,) if (is_global_attn and output_attentions) else outputs
