Skip to content

Commit 8605b16

Browse files
authored
Fix AttentionHead resizing bug (#63)
1 parent a4f6a72 commit 8605b16

File tree

2 files changed

+39
-20
lines changed

2 files changed

+39
-20
lines changed

react/src/attention/AttentionHeads.tsx

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -143,17 +143,7 @@ export function AttentionHeads({
143143
<Row>
144144
<Col xs={12}>
145145
<h3 style={{ marginBottom: 10 }}>{headNames[focused]} Zoomed</h3>
146-
147-
<div
148-
style={{
149-
position: "relative",
150-
// Set the maximum width such that a head with just a few tokens
151-
// doesn't have crazy large boxes per token. Note this is the
152-
// width of the full chart (including axis labels) so it also
153-
// needs a sensible lowest maximum.
154-
maxWidth: `${Math.max(Math.round(tokens.length * 2.4), 20)}em`
155-
}}
156-
>
146+
<div>
157147
<h2
158148
style={{
159149
position: "absolute",
@@ -174,6 +164,7 @@ export function AttentionHeads({
174164
minValue={minValue}
175165
negativeColor={negativeColor}
176166
positiveColor={positiveColor}
167+
zoomed={true}
177168
tokens={tokens}
178169
/>
179170
</div>

react/src/attention/AttentionPattern.tsx

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ export function AttentionPattern({
6161
positiveColor,
6262
upperTriColor = DefaultUpperTriColor,
6363
showAxisLabels = true,
64+
zoomed = false,
6465
tokens
6566
}: AttentionPatternProps) {
6667
// Tokens must be unique (for the categories), so we add an index prefix
@@ -164,15 +165,37 @@ export function AttentionPattern({
164165

165166
return (
166167
<Col>
167-
<Row style={{ aspectRatio: showAxisLabels ? undefined : "1/1" }}>
168-
<Chart
169-
type="matrix"
170-
options={options}
171-
data={data}
172-
width={1000}
173-
height={1000}
174-
updateMode="none"
175-
/>
168+
<Row>
169+
<div
170+
style={{
171+
// Chart.js charts resizing is weird.
172+
// Responsive chart elements (which all are by default) require the
173+
// parent element to have position: 'relative' and no sibling elements.
174+
// There were previously issues that only occured at particular display
175+
// sizes and zoom levels. See:
176+
// https://github.com/alan-cooney/CircuitsVis/pull/63
177+
// https://www.chartjs.org/docs/latest/configuration/responsive.html#important-note
178+
// https://stackoverflow.com/a/48770978/7086623
179+
position: "relative",
180+
// Set the maximum width of zoomed heads such that a head with just a
181+
// few tokens doesn't have crazy large boxes per token and the chart
182+
// doesn't overflow the screen. Other heads fill their width.
183+
maxWidth: zoomed
184+
? `min(100%, ${Math.round(tokens.length * 8)}em)`
185+
: "initial",
186+
width: zoomed ? "initial" : "100%",
187+
aspectRatio: "1/1"
188+
}}
189+
>
190+
<Chart
191+
type="matrix"
192+
options={options}
193+
data={data}
194+
width={1000}
195+
height={1000}
196+
updateMode="none"
197+
/>
198+
</div>
176199
</Row>
177200
</Col>
178201
);
@@ -255,6 +278,11 @@ export interface AttentionPatternProps {
255278
*/
256279
showAxisLabels?: boolean;
257280

281+
/**
282+
* Is this a zoomed in view?
283+
*/
284+
zoomed?: boolean;
285+
258286
/**
259287
* List of tokens
260288
*

0 commit comments

Comments
 (0)