diff --git a/lib_lanmt_modules.py b/lib_lanmt_modules.py index 18235f5..f5c689a 100644 --- a/lib_lanmt_modules.py +++ b/lib_lanmt_modules.py @@ -68,7 +68,7 @@ def forward(self, x, x_mask, y, y_mask): h1 = residual_connect(h1, x) # Cross-attention h2 = self.layer_norm2(h1) - h2, _ = self.attention(h2, y, y, mask=y_mask) + h2, _ = self.cross_attention(h2, y, y, mask=y_mask) h2 = self.dropout(h2) h2 = residual_connect(h2, h1) # Feed-forward layer