Possible bug in DisentangledSelfAttention? #15

@rcalef

Description

Thanks for the great work and open-sourced codebase for your model!

I've been working on a project that involves fine-tuning your model and ran into one issue. When fine-tuning with DDP, I was getting errors about some parameters being unused, and I found that all of the modules that require gradients but never receive any are the ss_q_proj layers in each encoder layer's attention block (i.e. prosst.encoder.layer[0-11].attention.self.ss_q_proj.weight and prosst.encoder.layer[0-11].attention.self.ss_q_proj.bias).

Looking into the code here, I noticed that in this block:

if "ss2aa" in self.pos_att_type:
  assert ss_hidden_states is not None
  ss_query_layer = self.ss_q_proj(ss_hidden_states)
  ss_query_layer = self.transpose_for_scores(ss_query_layer)
  ss_query_layer /= torch.sqrt(
    torch.tensor(ss_query_layer.size(-1), dtype=torch.float)
    * scale_factor
  )
  ss2aa_att = torch.matmul(
    key_layer, query_layer.transpose(-1, -2).to(dtype=key_layer.dtype)
  )
  score += ss2aa_att

it looks like ss_query_layer is created but never used. Should the line key_layer, query_layer.transpose(-1, -2).to(dtype=key_layer.dtype) instead be key_layer, ss_query_layer.transpose(-1, -2).to(dtype=key_layer.dtype)? That would be more in line with what the code appears intended to do, and it would also explain why ss_q_proj never receives gradients.
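To illustrate the symptom, here is a minimal, self-contained sketch (not the ProSST code itself; shapes and layer names are illustrative assumptions) showing that when the matmul uses ss_query_layer, gradients do reach the ss_q_proj parameters, whereas the current code leaves them with no gradient at all:

```python
import torch

torch.manual_seed(0)

# Toy stand-ins for the attention inputs (batch=2, seq=4, dim=8; assumed shapes).
hidden_states = torch.randn(2, 4, 8)
ss_hidden_states = torch.randn(2, 4, 8)

# Toy stand-ins for the projections in the attention block.
key_proj = torch.nn.Linear(8, 8)
ss_q_proj = torch.nn.Linear(8, 8)

key_layer = key_proj(hidden_states)
ss_query_layer = ss_q_proj(ss_hidden_states)
ss_query_layer = ss_query_layer / torch.sqrt(torch.tensor(8.0))

# With the proposed fix, the ss2aa term uses ss_query_layer, so backprop
# reaches ss_q_proj; in the current code, query_layer is used instead and
# ss_q_proj.weight.grad stays None, which is what triggers the DDP error.
ss2aa_att = torch.matmul(key_layer, ss_query_layer.transpose(-1, -2))
ss2aa_att.sum().backward()

assert ss_q_proj.weight.grad is not None
assert ss_q_proj.bias.grad is not None
```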

Apologies if I'm misunderstanding the codebase or method!
