Skip to content

Commit 4d5869b

Browse files
authored
Add bidirectional attention support to attention pattern (#75)
1 parent 8605b16 commit 4d5869b

3 files changed

Lines changed: 44 additions & 3 deletions

File tree

python/circuitsvis/attention.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def attention_heads(
1414
min_value: Optional[float] = None,
1515
negative_color: Optional[str] = None,
1616
positive_color: Optional[str] = None,
17+
mask_upper_tri: Optional[bool] = None,
1718
) -> RenderedHTML:
1819
"""Attention Heads
1920
@@ -37,6 +38,9 @@ def attention_heads(
3738
positive_color: Color for positive values. This can be any valid CSS
3839
color string. Be mindful of color blindness if not using the default
3940
here.
41+
mask_upper_tri: Whether or not to mask the upper triangular portion of
42+
the attention patterns. Should be true for causal attention, false for
43+
bidirectional attention.
4044
4145
Returns:
4246
Html: Attention pattern visualization
@@ -49,6 +53,7 @@ def attention_heads(
4953
"negativeColor": negative_color,
5054
"positiveColor": positive_color,
5155
"tokens": tokens,
56+
"maskUpperTri": mask_upper_tri,
5257
}
5358

5459
return render(
@@ -90,6 +95,7 @@ def attention_pattern(
9095
negative_color: Optional[str] = None,
9196
show_axis_labels: Optional[bool] = None,
9297
positive_color: Optional[str] = None,
98+
mask_upper_tri: Optional[bool] = None,
9399
) -> RenderedHTML:
94100
"""Attention Pattern
95101
@@ -112,6 +118,9 @@ def attention_pattern(
112118
positive_color: Color for positive values. This can be any valid CSS
113119
color string. Be mindful of color blindness if not using the default
114120
here.
121+
mask_upper_tri: Whether or not to mask the upper triangular portion of
122+
the attention patterns. Should be true for causal attention, false for
123+
bidirectional attention.
115124
116125
Returns:
117126
Html: Attention pattern visualization
@@ -124,6 +133,7 @@ def attention_pattern(
124133
"negativeColor": negative_color,
125134
"positiveColor": positive_color,
126135
"showAxisLabels": show_axis_labels,
136+
"maskUpperTri": mask_upper_tri,
127137
}
128138

129139
return render(

react/src/attention/AttentionHeads.tsx

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ export function AttentionHeadsSelector({
3434
onMouseEnter,
3535
onMouseLeave,
3636
positiveColor,
37+
maskUpperTri,
3738
tokens
3839
}: AttentionHeadsProps & {
3940
attentionHeadNames: string[];
@@ -89,6 +90,7 @@ export function AttentionHeadsSelector({
8990
minValue={minValue}
9091
negativeColor={negativeColor}
9192
positiveColor={positiveColor}
93+
maskUpperTri={maskUpperTri}
9294
/>
9395
</div>
9496
</div>
@@ -112,6 +114,7 @@ export function AttentionHeads({
112114
minValue,
113115
negativeColor,
114116
positiveColor,
117+
maskUpperTri = true,
115118
tokens
116119
}: AttentionHeadsProps) {
117120
// Attention head focussed state
@@ -137,6 +140,7 @@ export function AttentionHeads({
137140
onMouseEnter={onMouseEnter}
138141
onMouseLeave={onMouseLeave}
139142
positiveColor={positiveColor}
143+
maskUpperTri={maskUpperTri}
140144
tokens={tokens}
141145
/>
142146

@@ -165,6 +169,7 @@ export function AttentionHeads({
165169
negativeColor={negativeColor}
166170
positiveColor={positiveColor}
167171
zoomed={true}
172+
maskUpperTri={maskUpperTri}
168173
tokens={tokens}
169174
/>
170175
</div>
@@ -241,6 +246,17 @@ export interface AttentionHeadsProps {
241246
*/
242247
positiveColor?: string;
243248

249+
/**
250+
* Mask upper triangular
251+
*
252+
* Whether or not to mask the upper triangular portion of the attention patterns.
253+
*
254+
* Should be true for causal attention, false for bidirectional attention.
255+
*
256+
* @default true
257+
*/
258+
maskUpperTri?: boolean;
259+
244260
/**
245261
* Show axis labels
246262
*/

react/src/attention/AttentionPattern.tsx

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ export function AttentionPattern({
6262
upperTriColor = DefaultUpperTriColor,
6363
showAxisLabels = true,
6464
zoomed = false,
65+
maskUpperTri = true,
6566
tokens
6667
}: AttentionPatternProps) {
6768
// Tokens must be unique (for the categories), so we add an index prefix
@@ -96,7 +97,7 @@ export function AttentionPattern({
9697
// Set the background color for each block, based on the attention value
9798
backgroundColor(context: ScriptableContext<"matrix">) {
9899
const block = context.dataset.data[context.dataIndex] as any as Block;
99-
if (block.srcIdx > block.destIdx) {
100+
if (maskUpperTri && block.srcIdx > block.destIdx) {
100101
// Color the upper triangular part separately
101102
return colord(upperTriColor).toRgbString();
102103
}
@@ -130,7 +131,10 @@ export function AttentionPattern({
130131
title: () => "", // Hide the title
131132
label({ raw }: TooltipItem<"matrix">) {
132133
const block = raw as Block;
133-
if (block.destIdx < block.srcIdx) return "N/A"; // Just show N/A for the upper triangular part
134+
if (maskUpperTri && block.destIdx < block.srcIdx) {
135+
// Just show N/A for the upper triangular part
136+
return "N/A";
137+
}
134138
return [
135139
`(${block.destIdx}, ${block.srcIdx})`,
136140
`Src: ${block.srcToken}`,
@@ -259,11 +263,22 @@ export interface AttentionPatternProps {
259263
*/
260264
positiveColor?: string;
261265

266+
/**
267+
* Mask upper triangular
268+
*
269+
* Whether or not to mask the upper triangular portion of the attention patterns.
270+
*
271+
* Should be true for causal attention, false for bidirectional attention.
272+
*
273+
* @default true
274+
*/
275+
maskUpperTri?: boolean;
276+
262277
/**
263278
* Upper triangular color
264279
*
265280
* Color to use for the upper triangular part of the attention pattern to make visualization slightly nicer.
266-
* The upper triangular part is irrelevant because of the causal mask.
281+
* Only applied if maskUpperTri is set to true.
267282
*
268283
* @default rgb(200, 200, 200)
269284
*

0 commit comments

Comments
 (0)