@@ -14,6 +14,7 @@ def attention_heads(
1414 min_value : Optional [float ] = None ,
1515 negative_color : Optional [str ] = None ,
1616 positive_color : Optional [str ] = None ,
17+ mask_upper_tri : Optional [bool ] = None ,
1718) -> RenderedHTML :
1819 """Attention Heads
1920
@@ -37,6 +38,9 @@ def attention_heads(
3738 positive_color: Color for positive values. This can be any valid CSS
3839 color string. Be mindful of color blindness if not using the default
3940 here.
41+ mask_upper_tri: Whether or not to mask the upper triangular portion of
42+ the attention patterns. Should be true for causal attention, false for
43+ bidirectional attention.
4044
4145 Returns:
4246 Html: Attention pattern visualization
@@ -49,6 +53,7 @@ def attention_heads(
4953 "negativeColor" : negative_color ,
5054 "positiveColor" : positive_color ,
5155 "tokens" : tokens ,
56+ "maskUpperTri" : mask_upper_tri ,
5257 }
5358
5459 return render (
@@ -90,6 +95,7 @@ def attention_pattern(
9095 negative_color : Optional [str ] = None ,
9196 show_axis_labels : Optional [bool ] = None ,
9297 positive_color : Optional [str ] = None ,
98+ mask_upper_tri : Optional [bool ] = None ,
9399) -> RenderedHTML :
94100 """Attention Pattern
95101
@@ -112,6 +118,9 @@ def attention_pattern(
112118 positive_color: Color for positive values. This can be any valid CSS
113119 color string. Be mindful of color blindness if not using the default
114120 here.
121+ mask_upper_tri: Whether or not to mask the upper triangular portion of
122+ the attention patterns. Should be true for causal attention, false for
123+ bidirectional attention.
115124
116125 Returns:
117126 Html: Attention pattern visualization
@@ -124,6 +133,7 @@ def attention_pattern(
124133 "negativeColor" : negative_color ,
125134 "positiveColor" : positive_color ,
126135 "showAxisLabels" : show_axis_labels ,
136+ "maskUpperTri" : mask_upper_tri ,
127137 }
128138
129139 return render (
0 commit comments