From 9d4f657cdd82fd9494a5d95f7bfd7faff14a2ac7 Mon Sep 17 00:00:00 2001 From: leeliu103 Date: Thu, 29 Jan 2026 19:09:41 +0000 Subject: [PATCH] Add Bank View for GPU shared memory visualization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements a new Bank View mode in the Shared Layout tab that visualizes how tensor elements are distributed across GPU memory banks. This helps identify bank conflicts and understand physical memory layout patterns. Key features: - Bank layout mapping using standard GPU memory model (32 banks × 32-bit width) - Sub-cell rendering for packed elements (8-bit: 4/bank, 16-bit: 2/bank) - 64-bit element support spanning two banks with half-labeling - Rich tooltips showing segment, bank, and logical index for each element - Performance optimized: fast inverse traversal for 32/64-bit elements - Comprehensive cross-validation tests ensuring consistency between logical and bank views Changes: - Added createSharedBankLayout() for bank coordinate mapping - Extended CanvasRenderer with sub-cell support and rectangular cells - Added 32 new tests including cross-view consistency validation - UI updates: bank view selector and bank info panel Co-Authored-By: Claude (claude-sonnet-4.5) --- index.html | 6 +- src/layouts/SharedLayout.test.ts | 225 ++++++++++++++++- src/layouts/SharedLayout.ts | 120 ++++++++- src/tabs/SharedLayoutTab.test.ts | 61 ++++- src/tabs/SharedLayoutTab.ts | 249 ++++++++++++++++++- src/validation/SharedLayoutValidator.test.ts | 4 +- src/validation/SharedLayoutValidator.ts | 8 +- src/visualization/CanvasRenderer.test.ts | 36 ++- src/visualization/CanvasRenderer.ts | 190 +++++++++++--- src/visualization/ViewportController.ts | 20 +- tests/visualization.spec.ts | 18 ++ 11 files changed, 860 insertions(+), 77 deletions(-) diff --git a/index.html b/index.html index 078e31d..b87c54b 100644 --- a/index.html +++ b/index.html @@ -137,7 +137,7 @@

View Mode

View: @@ -232,8 +232,8 @@

Bank Info

32 bits
- Elems/Bank: - 64 + Elems/Bank Cell: + 2
Segments (Bank Row): diff --git a/src/layouts/SharedLayout.test.ts b/src/layouts/SharedLayout.test.ts index 8418370..4bb3bdd 100644 --- a/src/layouts/SharedLayout.test.ts +++ b/src/layouts/SharedLayout.test.ts @@ -4,7 +4,9 @@ import { computeBankInfo, computeRowSwizzle, computeSwizzledColumn, + createSharedBankLayout, createSharedLayout, + SHARED_BANK_COUNT, type SharedLayoutParams, } from './SharedLayout' @@ -193,11 +195,48 @@ describe('computeSwizzledColumn', () => { }) describe('bank calculations', () => { + const assertBankLayoutMatchesAssign = ( + tensorShape: [number, number], + elementBits: number, + overrides?: { bankCount?: number; bankSizeBits?: number } + ): void => { + const bankLayout = createSharedBankLayout(tensorShape, elementBits, overrides) + const totalElements = tensorShape[0] * tensorShape[1] + for (let offset = 0; offset < totalElements; offset++) { + const coords = bankLayout.layout.apply({ offset }) + const assignment = assignBank(offset, elementBits, overrides) + expect(coords.bank).toBe(assignment.bank) + expect(coords.segment).toBe(assignment.segment) + } + } + + const assertBankLayoutMatchesAssignWithHalves = ( + tensorShape: [number, number], + overrides?: { bankCount?: number; bankSizeBits?: number } + ): void => { + const elementBits = 64 + const bankLayout = createSharedBankLayout(tensorShape, elementBits, overrides) + const totalElements = tensorShape[0] * tensorShape[1] + const bankCount = overrides?.bankCount ?? SHARED_BANK_COUNT + for (let offset = 0; offset < totalElements; offset++) { + const base = assignBank(offset, elementBits, overrides) + for (let half = 0; half < 2; half++) { + const coords = bankLayout.layout.apply({ offset, half }) + const baseAddress = base.segment * bankCount + base.bank + const address = baseAddress + half + const expectedBank = ((address % bankCount) + bankCount) % bankCount + const expectedSegment = Math.floor(address / bankCount) + expect(coords.bank).toBe(expectedBank) + expect(coords.segment).toBe(expectedSegment) + } + } + } + it('reports per-bank statistics derived from tensor shape', () => { const info = computeBankInfo(baseParams.tensorShape, 16) expect(info.bankCount).toBe(32) expect(info.bankSizeBits).toBe(32) - expect(info.elementsPerBankRow).toBe(2) + expect(info.elementsPerBankCell).toBe(2) expect(info.segmentsPerBankRow).toBe(128) }) @@ -220,22 +259,22 @@ describe('bank calculations', () => { } for (const bits of bitWidths) { const info = computeBankInfo([64, 64], bits) - expect(info.elementsPerBankRow).toBe(expectedElements[bits]) + expect(info.elementsPerBankCell).toBe(expectedElements[bits]) expect(info.segmentsPerBankRow).toBeGreaterThan(0) } }) it('derives per-bank capacity directly from the configured bank width', () => { const largerBanks = computeBankInfo([8, 8], 16, { bankCount: 64, bankSizeBits: 64 }) - expect(largerBanks.elementsPerBankRow).toBe(4) + expect(largerBanks.elementsPerBankCell).toBe(4) const smallerBanks = computeBankInfo([8, 8], 16, { bankCount: 16, bankSizeBits: 16 }) - expect(smallerBanks.elementsPerBankRow).toBe(1) + expect(smallerBanks.elementsPerBankCell).toBe(1) // Ensure bank count alone does not change the per-bank capacity const moreBanksSameWidth = computeBankInfo([8, 8], 16, { bankCount: 128, bankSizeBits: 32 }) const fewerBanksSameWidth = computeBankInfo([8, 8], 16, { bankCount: 16, bankSizeBits: 32 }) - expect(moreBanksSameWidth.elementsPerBankRow).toBe(fewerBanksSameWidth.elementsPerBankRow) + expect(moreBanksSameWidth.elementsPerBankCell).toBe(fewerBanksSameWidth.elementsPerBankCell) }) it('derives segments for tiny and large tensors', () => { @@ -259,6 +298,182 @@ describe('bank calculations', () => { const far = assignBank(4096, 32, overrides) expect(far.segment).toBeGreaterThan(0) }) + + it('maps offsets to bank layout coordinates for packed elements', () => { + const bankLayout = createSharedBankLayout([8, 8], 16) + for (let offset = 0; offset < 32; offset++) { + const coords = bankLayout.layout.apply({ offset }) + const assignment = assignBank(offset, 16) + expect(coords.bank).toBe(assignment.bank) + expect(coords.segment).toBe(assignment.segment) + } + }) + + it('matches assignBank for 8-bit elements', () => { + assertBankLayoutMatchesAssign([8, 8], 8) + }) + + it('matches assignBank for 32-bit elements', () => { + assertBankLayoutMatchesAssign([8, 8], 32) + }) + + it('groups four 8-bit elements per bank in the bank layout', () => { + const bankLayout = createSharedBankLayout([8, 8], 8) + const offsets = [0, 1, 2, 3] + const positions = offsets.map((offset) => bankLayout.layout.apply({ offset })) + const banks = positions.map((coords) => coords.bank) + const segments = positions.map((coords) => coords.segment) + expect(new Set(banks).size).toBe(1) + expect(banks[0]).toBe(0) + expect(new Set(segments).size).toBe(1) + }) + + it('maps 32-bit elements to single banks without packing slots', () => { + const bankLayout = createSharedBankLayout([8, 8], 32) + expect(bankLayout.slotsPerBankCell).toBe(1) + expect(bankLayout.bankSpan).toBe(1) + expect(bankLayout.layout.getInDimNames()).not.toContain('half') + + const first = bankLayout.layout.apply({ offset: 0 }) + const second = bankLayout.layout.apply({ offset: 1 }) + expect(first.segment).toBe(0) + expect(first.bank).toBe(0) + expect(second.segment).toBe(0) + expect(second.bank).toBe(1) + }) + + it('maps 64-bit elements to bank starts without packing rows', () => { + const bankLayout = createSharedBankLayout([8, 8], 64) + expect(bankLayout.slotsPerBankCell).toBe(1) + expect(bankLayout.layout.getInDimNames()).toContain('half') + + const first = bankLayout.layout.apply({ offset: 0, half: 0 }) + const second = bankLayout.layout.apply({ offset: 0, half: 1 }) + expect(first.segment).toBe(0) + expect(second.segment).toBe(0) + expect(first.bank).toBe(0) + expect(second.bank).toBe(1) + }) + + it('matches assignBank for 64-bit elements across halves', () => { + assertBankLayoutMatchesAssignWithHalves([8, 8]) + }) + + it('matches assignBank with bank overrides for 8-bit elements', () => { + const overrides = { bankCount: 16, bankSizeBits: 32 } + assertBankLayoutMatchesAssign([8, 8], 8, overrides) + }) + + it('matches assignBank with bank overrides for 32-bit elements', () => { + const overrides = { bankCount: 16, bankSizeBits: 32 } + assertBankLayoutMatchesAssign([8, 8], 32, overrides) + }) + + it('matches assignBank with bank overrides for 64-bit elements across halves', () => { + const overrides = { bankCount: 16, bankSizeBits: 32 } + assertBankLayoutMatchesAssignWithHalves([8, 8], overrides) + }) +}) + +describe('logical/bank cross-validation', () => { + const tensorShapes: Array<[number, number]> = [ + [8, 8], + [128, 64], + [256, 128], + ] + const elementBitWidths = [8, 16, 32, 64] + const swizzleCases: Array> = + [ + { vec: 1, perPhase: 1, maxPhase: 1, swizzleMode: 'swizzled' }, + { vec: 2, perPhase: 2, maxPhase: 4, swizzleMode: 'swizzled' }, + { vec: 4, perPhase: 4, maxPhase: 8, swizzleMode: 'swizzled' }, + { vec: 4, perPhase: 2, maxPhase: 4, swizzleMode: 'amdRotating' }, + ] + const order: [number, number] = [0, 1] + + const buildOffsets = (totalElements: number): number[] => { + const offsets = new Set() + offsets.add(0) + offsets.add(totalElements - 1) + for (let power = 1; power < totalElements; power <<= 1) { + offsets.add(power) + offsets.add(power - 1) + offsets.add(power + 1) + } + if (totalElements > 4) { + offsets.add(Math.floor(totalElements / 3)) + offsets.add(Math.floor((2 * totalElements) / 3)) + } + return Array.from(offsets) + .filter((value) => value >= 0 && value < totalElements) + .sort((a, b) => a - b) + } + + it('keeps logical and bank layouts consistent across offsets, shapes, and swizzle settings', () => { + for (const tensorShape of tensorShapes) { + const totalElements = tensorShape[0] * tensorShape[1] + const offsets = buildOffsets(totalElements) + + for (const swizzle of swizzleCases) { + const params: SharedLayoutParams = { + tensorShape, + order, + vec: swizzle.vec, + perPhase: swizzle.perPhase, + maxPhase: swizzle.maxPhase, + swizzleMode: swizzle.swizzleMode, + } + const logicalLayout = createSharedLayout(params) + const inverseLayout = logicalLayout.layout.invert() + const rowCount = tensorShape[logicalLayout.rowDimIndex] + const colCount = tensorShape[logicalLayout.colDimIndex] + + for (const elementBits of elementBitWidths) { + const bankLayout = createSharedBankLayout(tensorShape, elementBits) + + for (const offset of offsets) { + const logicalCoords = logicalLayout.layout.apply({ offset }) + const row = logicalCoords[logicalLayout.rowDimName] + const col = logicalCoords[logicalLayout.colDimName] + + expect(typeof row).toBe('number') + expect(typeof col).toBe('number') + + const expectedRow = Math.floor(offset / colCount) + const unswizzledCol = offset % colCount + const expectedCol = computeSwizzledColumn(expectedRow, unswizzledCol, params) + + expect(row).toBe(expectedRow) + expect(row).toBeGreaterThanOrEqual(0) + expect(row).toBeLessThan(rowCount) + expect(col).toBe(expectedCol) + expect(col).toBeGreaterThanOrEqual(0) + expect(col).toBeLessThan(colCount) + + const roundTripOffset = inverseLayout.apply(logicalCoords).offset + expect(roundTripOffset).toBe(offset) + + const base = assignBank(roundTripOffset, elementBits) + if (bankLayout.bankSpan > 1) { + for (let half = 0; half < bankLayout.bankSpan; half++) { + const coords = bankLayout.layout.apply({ offset: roundTripOffset, half }) + const baseAddress = base.segment * bankLayout.bankCount + base.bank + const address = baseAddress + half + const expectedBank = ((address % bankLayout.bankCount) + bankLayout.bankCount) % bankLayout.bankCount + const expectedSegment = Math.floor(address / bankLayout.bankCount) + expect(coords.bank).toBe(expectedBank) + expect(coords.segment).toBe(expectedSegment) + } + } else { + const coords = bankLayout.layout.apply({ offset: roundTripOffset }) + expect(coords.bank).toBe(base.bank) + expect(coords.segment).toBe(base.segment) + } + } + } + } + } + }) }) describe('validation', () => { diff --git a/src/layouts/SharedLayout.ts b/src/layouts/SharedLayout.ts index 62af498..3710274 100644 --- a/src/layouts/SharedLayout.ts +++ b/src/layouts/SharedLayout.ts @@ -25,10 +25,20 @@ export interface SharedLayoutBuildResult { export interface SharedBankInfo { bankCount: number bankSizeBits: number - elementsPerBankRow: number + elementsPerBankCell: number segmentsPerBankRow: number } +export interface SharedBankLayoutResult { + layout: LinearLayout + bankCount: number + bankSizeBits: number + segmentCount: number + slotsPerBankCell: number + rowCount: number + bankSpan: number +} + export interface BankAssignmentOptions { bankCount?: number bankSizeBits?: number @@ -48,10 +58,15 @@ function assertPowerOfTwo(value: number, context: string): void { throw new Error(`${context} must be a positive integer`) } if ((value & (value - 1)) !== 0) { - throw new Error(`${context} must be a power of two`) + throw new Error(`${context} must be a power of two (received ${value})`) } } +function log2PowerOfTwo(value: number, context: string): number { + assertPowerOfTwo(value, context) + return Math.trunc(Math.log2(value)) +} + function assertDimIndex(value: number, context: string): asserts value is DimIndex { if (value !== 0 && value !== 1) { throw new Error(`${context} must be 0 or 1`) @@ -200,6 +215,103 @@ export function computeSwizzledColumn( return normalizedCol ^ rowSwizzle } +export function createSharedBankLayout( + tensorShape: [number, number], + elementBits: number, + overrides?: BankAssignmentOptions +): SharedBankLayoutResult { + assertPositive(elementBits, 'Element bitwidth') + assertPowerOfTwo(elementBits, 'Element bitwidth') + + const bankCount = overrides?.bankCount ?? SHARED_BANK_COUNT + const bankSizeBits = overrides?.bankSizeBits ?? SHARED_BANK_WIDTH_BITS + assertPowerOfTwo(bankCount, 'Bank count') + assertPowerOfTwo(bankSizeBits, 'Bank size (bits)') + + const totalElements = Math.max(1, tensorShape[0] * tensorShape[1]) + assertPowerOfTwo(totalElements, `Total element count (${totalElements})`) + + const bytesPerElement = Math.max(1, Math.trunc(elementBits / 8)) + const bankWidthBytes = Math.max(1, Math.trunc(bankSizeBits / 8)) + const bankRowBytes = Math.max(1, bankCount * bankWidthBytes) + const totalBytes = totalElements * bytesPerElement + const segmentCount = Math.max(1, Math.ceil(totalBytes / bankRowBytes)) + assertPowerOfTwo(segmentCount, `Bank segment count (${segmentCount})`) + + const elementsPerBankCell = bankSizeBits / elementBits + const usesPacking = elementsPerBankCell >= 1 + const slotsPerBankCell = usesPacking ? Math.trunc(elementsPerBankCell) : 1 + assertPowerOfTwo(slotsPerBankCell, 'Slots per bank cell') + + const rowCount = segmentCount + assertPowerOfTwo(rowCount, 'Bank row count') + + const slotBits = log2PowerOfTwo(slotsPerBankCell, 'Slots per bank cell') + const bankBits = log2PowerOfTwo(bankCount, 'Bank count') + const segmentBits = log2PowerOfTwo(segmentCount, 'Bank segment count') + const offsetBits = log2PowerOfTwo(totalElements, 'Total element count') + + let bankSpan = 1 + if (!usesPacking) { + bankSpan = elementBits / bankSizeBits + assertPowerOfTwo(bankSpan, 'Bank span') + const shift = log2PowerOfTwo(bankSpan, 'Bank span') + const expectedOffsetBits = segmentBits + bankBits - shift + if (offsetBits !== expectedOffsetBits) { + throw new Error('Offset bit count does not match bank grid dimensions') + } + } else { + const maxOffsetBits = slotBits + bankBits + segmentBits + if (offsetBits > maxOffsetBits) { + throw new Error('Offset bit count exceeds packed bank grid dimensions') + } + } + + const offsetBases: number[][] = [] + for (let bit = 0; bit < offsetBits; bit++) { + const offset = 1 << bit + const bankAddress = usesPacking + ? Math.floor(offset / slotsPerBankCell) + : offset * bankSpan + const bank = ((bankAddress % bankCount) + bankCount) % bankCount + const segment = Math.floor(bankAddress / bankCount) + offsetBases.push([segment, bank]) + } + + const basesArray: Array<[string, number[][]]> = [['offset', offsetBases]] + + if (bankSpan > 1) { + const halfBits = log2PowerOfTwo(bankSpan, 'Bank span') + const halfBases: number[][] = [] + for (let bit = 0; bit < halfBits; bit++) { + const half = 1 << bit + const bankAddress = half + const bank = ((bankAddress % bankCount) + bankCount) % bankCount + const segment = Math.floor(bankAddress / bankCount) + halfBases.push([segment, bank]) + } + basesArray.push(['half', halfBases]) + } + + const layout = new LinearLayout( + basesArray, + [ + ['segment', rowCount], + ['bank', bankCount], + ] + ) + + return { + layout, + bankCount, + bankSizeBits, + segmentCount, + slotsPerBankCell, + rowCount, + bankSpan, + } +} + export function computeBankInfo( tensorShape: [number, number], elementBits: number, @@ -213,7 +325,7 @@ export function computeBankInfo( const bytesPerElement = Math.max(1, Math.trunc(elementBits / 8)) const totalBankRowBits = bankCount * bankSizeBits - const elementsPerBankRow = bankSizeBits / elementBits + const elementsPerBankCell = bankSizeBits / elementBits const totalElements = Math.max(1, tensorShape[0] * tensorShape[1]) const totalBytes = totalElements * bytesPerElement const bankRowBytes = Math.max(1, Math.trunc(totalBankRowBits / 8)) @@ -222,7 +334,7 @@ export function computeBankInfo( return { bankCount, bankSizeBits, - elementsPerBankRow, + elementsPerBankCell, segmentsPerBankRow, } } diff --git a/src/tabs/SharedLayoutTab.test.ts b/src/tabs/SharedLayoutTab.test.ts index ae33e90..9fb0526 100644 --- a/src/tabs/SharedLayoutTab.test.ts +++ b/src/tabs/SharedLayoutTab.test.ts @@ -6,6 +6,7 @@ type RendererStub = { setCustomColorProvider: ReturnType screenToGrid: ReturnType getCellInfo: ReturnType + getCellEntries: ReturnType } const { rendererInstances, createRendererStub } = vi.hoisted(() => { @@ -24,6 +25,7 @@ const { rendererInstances, createRendererStub } = vi.hoisted(() => { inputCoords: { offset: 0 }, outputCoords: { dim0: 0, dim1: 0 }, }), + getCellEntries: vi.fn().mockReturnValue(null), }) return { rendererInstances: instances, createRendererStub: buildStub } }) @@ -111,7 +113,7 @@ const setupDom = () => {
@@ -312,5 +314,62 @@ describe('SharedLayoutTab', () => { expectTooltipContains('Segment (Bank Row): 8') expectTooltipContains('Bank: 1') }) + + it('shows bank view tooltip entries for multiple offsets', () => { + const tab = new SharedLayoutTab('shared-layout') + const viewSelect = document.getElementById('shared-view-mode') as HTMLSelectElement + viewSelect.value = 'bank' + viewSelect.dispatchEvent(new Event('change')) + + const renderer = rendererInstances[rendererInstances.length - 1] + renderer?.getCellEntries.mockReturnValue([ + { + threadId: 0, + registerId: 0, + warpId: 0, + position: [0, 0], + sourcePosition: [0, 0], + inputCoords: { offset: 0 }, + outputCoords: { segment: 0, bank: 0 }, + }, + { + threadId: 0, + registerId: 1, + warpId: 0, + position: [0, 0], + sourcePosition: [0, 0], + inputCoords: { offset: 1 }, + outputCoords: { segment: 0, bank: 0 }, + }, + ]) + + const hoverable = tab as unknown as { handleHover: (event: MouseEvent) => void } + hoverable.handleHover(new MouseEvent('mousemove', { clientX: 200, clientY: 150 })) + + expectTooltipContains('Segment (Bank Row): 0') + expectTooltipContains('Bank: 0') + expectTooltipContains('Elements in bank:') + expectTooltipContains('Element 1:') + expectTooltipContains('Offset: 0') + expectTooltipContains('Element 2:') + expectTooltipContains('Offset: 1') + }) + + it('keeps both halves when labeling 64-bit elements in bank cells', () => { + const tab = new SharedLayoutTab('shared-layout') + const buildLabels = tab as unknown as { + buildBankCellLabels: (entries: Array<{ inputCoords?: { offset?: number; half?: number } }>, bankSpan: number) => string[] | null + } + + const labels = buildLabels.buildBankCellLabels( + [ + { inputCoords: { offset: 0, half: 0 } }, + { inputCoords: { offset: 0, half: 1 } }, + ], + 2 + ) + + expect(labels).toEqual(['0:H1', '0:H2']) + }) }) }) diff --git a/src/tabs/SharedLayoutTab.ts b/src/tabs/SharedLayoutTab.ts index 8d5bd59..db8e6a9 100644 --- a/src/tabs/SharedLayoutTab.ts +++ b/src/tabs/SharedLayoutTab.ts @@ -1,12 +1,15 @@ import { CanvasTab, type CanvasTabElements } from './CanvasTab' import { renderSharedControls } from '../ui/renderSharedControls' -import { CanvasRenderer, type CellInfo } from '../visualization/CanvasRenderer' +import { CanvasRenderer, type CellInfo, type SubCellLayout } from '../visualization/CanvasRenderer' import { ParameterForm } from '../ui/ParameterForm' import type { BlockLayoutParams } from '../validation/InputValidator' import { assignBank, computeBankInfo, + createSharedBankLayout, createSharedLayout, + SHARED_BANK_COUNT, + type SharedBankLayoutResult, type SharedLayoutBuildResult, type SharedLayoutParams, } from '../layouts/SharedLayout' @@ -30,6 +33,7 @@ export class SharedLayoutTab extends CanvasTab { private currentParams: SharedLayoutUiParams | null = null private currentLayout: SharedLayoutBuildResult | null = null + private currentBankLayout: SharedBankLayoutResult | null = null private currentElementBits = 16 constructor(tabId: string) { @@ -108,6 +112,7 @@ export class SharedLayoutTab extends CanvasTab { if (!isValid) { this.currentParams = null this.currentLayout = null + this.currentBankLayout = null this.currentElementBits = 16 this.clearBankInfo() this.hideTooltip() @@ -120,6 +125,7 @@ export class SharedLayoutTab extends CanvasTab { } else { this.currentParams = null this.currentLayout = null + this.currentBankLayout = null this.currentElementBits = 16 this.clearBankInfo() } @@ -136,8 +142,31 @@ export class SharedLayoutTab extends CanvasTab { const x = event.clientX - rect.left const y = event.clientY - rect.top const gridPos = renderer.screenToGrid(x, y) - const cellInfo = renderer.getCellInfo(gridPos.row, gridPos.col) + if (this.currentParams.viewMode === 'bank') { + const entries = renderer.getCellEntries(gridPos.row, gridPos.col) + if (!entries || entries.length === 0) { + this.hideTooltip() + return + } + const primary = entries[0] + if (!primary) { + this.hideTooltip() + return + } + const { bank, segment } = this.resolveBankSegment(primary) + const tooltipLines = [ + `
Segment (Bank Row): ${segment}
`, + `
Bank: ${bank}
`, + ] + + const entryLines = this.buildBankTooltipEntries(entries) + tooltipLines.push(...entryLines) + + this.tooltip.show(tooltipLines.join(''), event.clientX, event.clientY) + return + } + const cellInfo = renderer.getCellInfo(gridPos.row, gridPos.col) if (!cellInfo) { this.hideTooltip() return @@ -169,19 +198,38 @@ export class SharedLayoutTab extends CanvasTab { this.hideTooltip() } + private resolveBankSegment(cellInfo: CellInfo): { bank: number; segment: number } { + const offset = this.getCellOffset(cellInfo) + const base = assignBank(offset, this.currentElementBits) + const half = cellInfo.inputCoords?.half + if (typeof half !== 'number' || !Number.isFinite(half)) { + return base + } + + const bankCount = this.currentBankLayout?.bankCount + ?? (this.currentParams + ? computeBankInfo(this.currentParams.tensorShape, this.currentParams.elementBits).bankCount + : SHARED_BANK_COUNT) + const baseAddress = base.segment * bankCount + base.bank + const address = baseAddress + Math.trunc(half) + const bank = ((address % bankCount) + bankCount) % bankCount + const segment = Math.floor(address / bankCount) + return { bank, segment } + } + private refreshVisualization(params: SharedLayoutUiParams): void { this.currentParams = params this.currentElementBits = params.elementBits this.updateBankInfo(params) - if (params.viewMode !== 'logical') { - // Bank view not implemented yet; fallback to logical rendering. - this.currentLayout = null - return - } - try { - this.updateVisualization(params) + const layoutResult = createSharedLayout(params) + this.currentLayout = layoutResult + if (params.viewMode === 'bank') { + this.updateBankVisualization(params, layoutResult) + } else { + this.updateLogicalVisualization(params, layoutResult) + } } catch (error) { console.error('Failed to render shared layout', error) const errorMessage = error instanceof Error ? error.message : String(error) @@ -189,8 +237,11 @@ export class SharedLayoutTab extends CanvasTab { } } - private updateVisualization(params: SharedLayoutUiParams): void { - const layoutResult = createSharedLayout(params) + private updateLogicalVisualization( + params: SharedLayoutUiParams, + layoutResult: SharedLayoutBuildResult + ): void { + this.currentBankLayout = null const rendererParams: BlockLayoutParams = { sizePerThread: [1, 1], threadsPerWarp: [1, 1], @@ -212,6 +263,63 @@ export class SharedLayoutTab extends CanvasTab { ) this.setRenderer(renderer) renderer.setCustomColorProvider((cell) => this.resolveBankForCell(cell, params.elementBits)) + } + + private updateBankVisualization( + params: SharedLayoutUiParams, + layoutResult: SharedLayoutBuildResult + ): void { + const bankLayout = createSharedBankLayout(params.tensorShape, params.elementBits) + this.currentBankLayout = bankLayout + const rendererParams: BlockLayoutParams = { + sizePerThread: [1, 1], + threadsPerWarp: [1, 1], + warpsPerCTA: [1, 1], + order: [0, 1], + tensorShape: [bankLayout.rowCount, bankLayout.bankCount], + } + + const slotsPerBankCell = bankLayout.slotsPerBankCell + const bankSpan = bankLayout.bankSpan + const subCellProvider = + slotsPerBankCell > 1 + ? (entries: CellInfo[]): SubCellLayout | null => + this.buildBankSubCells(entries, slotsPerBankCell) + : undefined + + const baseCellSize = 50 + const cellWidth = slotsPerBankCell > 1 ? baseCellSize * slotsPerBankCell : baseCellSize + const cellHeight = baseCellSize + + const renderer = new CanvasRenderer( + this.canvas, + bankLayout.layout, + rendererParams, + undefined, + { + traversalMode: 'by-output', + colorGrouping: 'warp', + showCellText: true, + customCellTextProvider: + slotsPerBankCell > 1 + ? undefined + : (_cell, entries) => this.buildBankCellLabels(entries, bankSpan), + subCellProvider, + cellWidth, + cellHeight, + forceInputTraversal: slotsPerBankCell > 1, + } + ) + + this.setRenderer(renderer) + renderer.setCustomColorProvider((cell) => { + const bank = cell.outputCoords?.bank + if (typeof bank === 'number' && Number.isFinite(bank)) { + return bank + } + return cell.position[1] + }) + this.currentLayout = layoutResult } @@ -228,11 +336,128 @@ export class SharedLayoutTab extends CanvasTab { return 0 } + private buildBankCellLabels(entries: CellInfo[], bankSpan: number): string[] | null { + const offsets = entries + .map((entry) => { + const offset = entry.inputCoords?.offset + if (typeof offset !== 'number' || !Number.isFinite(offset)) { + return null + } + const half = entry.inputCoords?.half + const halfLabel = + bankSpan > 1 && typeof half === 'number' + ? `:H${half + 1}` + : '' + const halfKey = bankSpan > 1 && typeof half === 'number' ? half : null + const key = halfKey === null ? `${offset}` : `${offset}:${halfKey}` + return { offset, label: `${offset}${halfLabel}`, key } + }) + .filter((value): value is { offset: number; label: string; key: string } => Boolean(value)) + .sort((a, b) => a.offset - b.offset || a.key.localeCompare(b.key)) + + if (offsets.length === 0) { + return null + } + + const uniqueLabels: string[] = [] + const seen = new Set() + offsets.forEach(({ key, label }) => { + if (seen.has(key)) { + return + } + seen.add(key) + uniqueLabels.push(label) + }) + + return uniqueLabels + } + + private buildBankSubCells(entries: CellInfo[], slotsPerBankCell: number): SubCellLayout | null { + if (slotsPerBankCell <= 1) { + return null + } + const columns = slotsPerBankCell + const rows = 1 + const items = entries + .map((entry) => { + const offset = entry.inputCoords?.offset + if (typeof offset !== 'number' || !Number.isFinite(offset)) { + return null + } + const slot = Math.trunc(offset) % slotsPerBankCell + return { index: slot, label: offset.toString() } + }) + .filter((item): item is { index: number; label: string } => Boolean(item)) + + return { rows, cols: columns, items } + } + + private buildBankTooltipEntries(entries: CellInfo[]): string[] { + const currentLayout = this.currentLayout + if (!currentLayout) { + return [] + } + + const logicalEntries = entries + .map((entry) => { + const offset = entry.inputCoords?.offset + if (typeof offset !== 'number' || !Number.isFinite(offset)) { + return null + } + const logicalCoords = currentLayout.layout.apply({ offset }) + const row = this.readOutputCoordinate( + { ...entry, outputCoords: logicalCoords }, + currentLayout.rowDimName, + currentLayout.rowDimIndex + ) + const col = this.readOutputCoordinate( + { ...entry, outputCoords: logicalCoords }, + currentLayout.colDimName, + currentLayout.colDimIndex + ) + const half = entry.inputCoords?.half + const halfLabel = + typeof half === 'number' && Number.isFinite(half) + ? half === 0 + ? 'First half of element:' + : 'Second half of element:' + : null + return { offset, row, col, halfLabel } + }) + .filter((entry): entry is { offset: number; row: number; col: number; halfLabel: string | null } => Boolean(entry)) + .sort((a, b) => a.offset - b.offset || a.row - b.row || a.col - b.col || (a.halfLabel ?? '').localeCompare(b.halfLabel ?? '')) + + if (logicalEntries.length === 0) { + return [] + } + + const lines = ['
Elements in bank:
'] + const useHalfLabels = logicalEntries.some((entry) => entry.halfLabel !== null) + let elementCounter = 1 + logicalEntries.forEach((entry) => { + const header = useHalfLabels + ? entry.halfLabel ?? `Element ${elementCounter}:` + : `Element ${elementCounter}:` + lines.push(`
${header}
`) + lines.push( + `
Offset: ${entry.offset}
` + ) + lines.push( + `
Logical Index: (${entry.row}, ${entry.col})
` + ) + if (!useHalfLabels || entry.halfLabel === null) { + elementCounter += 1 + } + }) + + return lines + } + private updateBankInfo(params: SharedLayoutUiParams): void { const info = computeBankInfo(params.tensorShape, params.elementBits) this.bankInfoElements.bankCount.textContent = info.bankCount.toString() this.bankInfoElements.bankSize.textContent = `${info.bankSizeBits} bits` - this.bankInfoElements.elemsPerBank.textContent = info.elementsPerBankRow.toString() + this.bankInfoElements.elemsPerBank.textContent = info.elementsPerBankCell.toString() this.bankInfoElements.segmentsPerBank.textContent = info.segmentsPerBankRow.toString() } diff --git a/src/validation/SharedLayoutValidator.test.ts b/src/validation/SharedLayoutValidator.test.ts index 6366143..7e91a2d 100644 --- a/src/validation/SharedLayoutValidator.test.ts +++ b/src/validation/SharedLayoutValidator.test.ts @@ -113,12 +113,12 @@ describe('SharedLayoutValidator', () => { expect(result.errors.get('perPhase')).toContain('must not exceed the number of rows') }) - it('emits a warning when bank view is requested', () => { + it('accepts bank view mode', () => { const result = validator.validate({ ...baseParams, viewMode: 'bank', }) expect(result.valid).toBe(true) - expect(result.warnings.get('viewMode')).toContain('Only logical view is implemented') + expect(result.warnings.size).toBe(0) }) }) diff --git a/src/validation/SharedLayoutValidator.ts b/src/validation/SharedLayoutValidator.ts index 67c2c5e..b893bde 100644 --- a/src/validation/SharedLayoutValidator.ts +++ b/src/validation/SharedLayoutValidator.ts @@ -20,7 +20,7 @@ export class SharedLayoutValidator { this.validateSwizzleMode(params.swizzleMode, errors) this.validateSwizzleParams(params, errors, shapeValid, orderValid) this.validateElementBits(params.elementBits, errors) - this.validateViewMode(params.viewMode, warnings) + this.validateViewMode(params.viewMode, errors) return { valid: errors.size === 0, @@ -127,9 +127,9 @@ export class SharedLayoutValidator { } } - private validateViewMode(viewMode: SharedLayoutViewMode, warnings: Map): void { - if (viewMode !== 'logical') { - warnings.set('viewMode', 'Only logical view is implemented') + private validateViewMode(viewMode: SharedLayoutViewMode, errors: Map): void { + if (viewMode !== 'logical' && viewMode !== 'bank') { + errors.set('viewMode', 'View mode must be "logical" or "bank"') } } diff --git a/src/visualization/CanvasRenderer.test.ts b/src/visualization/CanvasRenderer.test.ts index a93f74f..c8bf8a9 100644 --- a/src/visualization/CanvasRenderer.test.ts +++ b/src/visualization/CanvasRenderer.test.ts @@ -5,8 +5,10 @@ import { createBlockLayout } from '../layouts/BlockLayout' import type { BlockLayoutParams } from '../validation/InputValidator' import { LinearLayout } from '../core/LinearLayout' -const getRendererCellSize = (renderer: CanvasRenderer): number => - (renderer as unknown as { cellSize: number }).cellSize +const getRendererCellWidth = (renderer: CanvasRenderer): number => + (renderer as unknown as { cellWidth: number }).cellWidth +const getRendererCellHeight = (renderer: CanvasRenderer): number => + (renderer as unknown as { cellHeight: number }).cellHeight const getRendererViewport = ( renderer: CanvasRenderer @@ -20,10 +22,11 @@ const getCellCenterCoordinates = (renderer: CanvasRenderer, row: number, col: nu y: number } => { const viewport = getRendererViewport(renderer) - const scaledCellSize = getRendererCellSize(renderer) * viewport.scale + const scaledCellWidth = getRendererCellWidth(renderer) * viewport.scale + const scaledCellHeight = getRendererCellHeight(renderer) * viewport.scale return { - x: viewport.offsetX + col * scaledCellSize + scaledCellSize / 2, - y: viewport.offsetY + row * scaledCellSize + scaledCellSize / 2, + x: viewport.offsetX + col * scaledCellWidth + scaledCellWidth / 2, + y: viewport.offsetY + row * scaledCellHeight + scaledCellHeight / 2, } } @@ -618,6 +621,22 @@ describe('CanvasRenderer', () => { } }) + it('returns all entries for a cell with overlapping mappings', () => { + const layout = createRankDeficientLayout() + const params = createRankDeficientParams() + const renderer = new CanvasRenderer( + canvas, + layout, + params, + undefined, + { traversalMode: 'by-output', colorGrouping: 'thread', showCellText: false } + ) + + const entries = renderer.getCellEntries(0, 0) + expect(entries).not.toBeNull() + expect(entries && entries.length).toBeGreaterThan(1) + }) + it('provides accurate input coordinates in the fallback enumeration path', () => { const renderer = new CanvasRenderer( canvas, @@ -727,12 +746,13 @@ describe('CanvasRenderer', () => { expect(strokeRectSpy).toHaveBeenCalledTimes(expectedCells.length) const viewport = getRendererViewport(renderer) - const scaledCellSize = getRendererCellSize(renderer) * viewport.scale + const scaledCellWidth = getRendererCellWidth(renderer) * viewport.scale + const scaledCellHeight = getRendererCellHeight(renderer) * viewport.scale const highlightedCells = strokeRectSpy.mock.calls .map(call => { const [x, y] = call - const col = Math.round((x - viewport.offsetX) / scaledCellSize) - const row = Math.round((y - viewport.offsetY) / scaledCellSize) + const col = Math.round((x - viewport.offsetX) / scaledCellWidth) + const row = Math.round((y - viewport.offsetY) / scaledCellHeight) return `${row},${col}` }) .sort() diff --git a/src/visualization/CanvasRenderer.ts b/src/visualization/CanvasRenderer.ts index ba73bea..dbeee2c 100644 --- a/src/visualization/CanvasRenderer.ts +++ b/src/visualization/CanvasRenderer.ts @@ -13,6 +13,11 @@ interface CanvasRendererOptions { showCellText?: boolean colorInputDimension?: string customColorProvider?: (cell: CellInfo) => number | null | undefined + customCellTextProvider?: (cell: CellInfo, entries: CellInfo[]) => string | string[] | null | undefined + subCellProvider?: (entries: CellInfo[]) => SubCellLayout | null | undefined + cellWidth?: number + cellHeight?: number + forceInputTraversal?: boolean } export interface CellInfo { @@ -30,6 +35,12 @@ export interface CellInfo { outputCoords?: Record } +export interface SubCellLayout { + rows: number + cols: number + items: Array<{ index: number; label: string }> +} + export interface ResolvedPosition { pos: [number, number] registerId: number @@ -66,7 +77,8 @@ export class CanvasRenderer { private ctx: CanvasRenderingContext2D private viewportController: ViewportController private colorScheme: ColorScheme = new ColorScheme(1, 1) - private cellSize = 50 // Base size of each cell in pixels + private cellWidth = 50 // Base width of each cell in pixels + private cellHeight = 50 // Base height of each cell in pixels private isDragging = false private lastMouseX = 0 private lastMouseY = 0 @@ -85,6 +97,9 @@ export class CanvasRenderer { private customColorProvider?: (cell: CellInfo) => number | null | undefined private customColorGroupCount = 0 private maxThreadIdObserved = 0 + private customCellTextProvider?: (cell: CellInfo, entries: CellInfo[]) => string | string[] | null | undefined + private subCellProvider?: (entries: CellInfo[]) => SubCellLayout | null | undefined + private forceInputTraversal: boolean constructor( private canvas: HTMLCanvasElement, @@ -106,13 +121,20 @@ export class CanvasRenderer { const trimmedColorDim = options?.colorInputDimension?.trim() this.customColorDimension = trimmedColorDim && trimmedColorDim.length > 0 ? trimmedColorDim : undefined this.customColorProvider = options?.customColorProvider + this.customCellTextProvider = options?.customCellTextProvider + this.subCellProvider = options?.subCellProvider + this.cellWidth = options?.cellWidth ?? 50 + this.cellHeight = options?.cellHeight ?? 50 + this.forceInputTraversal = options?.forceInputTraversal ?? false this.resetColorSchemeFromParams() // Initialize viewport controller this.viewportController = new ViewportController( canvas, params.tensorShape[1], - params.tensorShape[0] + params.tensorShape[0], + this.cellWidth, + this.cellHeight ) // Set position resolver (default to block layout resolver) @@ -205,11 +227,15 @@ export class CanvasRenderer { } let result: OutputTraversalBuildResult - try { - const inverse = this.layout.invert() - result = this.buildOutputCacheUsingInverse(outputDims, inverse) - } catch { + if (this.forceInputTraversal) { result = this.buildOutputCacheFromInputs(outputDims) + } else { + try { + const inverse = this.layout.invert() + result = this.buildOutputCacheUsingInverse(outputDims, inverse) + } catch { + result = this.buildOutputCacheFromInputs(outputDims) + } } this.maxThreadIdObserved = Math.max(result.maxThreadId, 0) @@ -477,19 +503,19 @@ export class CanvasRenderer { // Draw vertical lines for (let col = 0; col <= cols; col++) { - const x = col * this.cellSize * viewport.scale + viewport.offsetX + const x = col * this.cellWidth * viewport.scale + viewport.offsetX this.ctx.beginPath() this.ctx.moveTo(x, viewport.offsetY) - this.ctx.lineTo(x, rows * this.cellSize * viewport.scale + viewport.offsetY) + this.ctx.lineTo(x, rows * this.cellHeight * viewport.scale + viewport.offsetY) this.ctx.stroke() } // Draw horizontal lines for (let row = 0; row <= rows; row++) { - const y = row * this.cellSize * viewport.scale + viewport.offsetY + const y = row * this.cellHeight * viewport.scale + viewport.offsetY this.ctx.beginPath() this.ctx.moveTo(viewport.offsetX, y) - this.ctx.lineTo(cols * this.cellSize * viewport.scale + viewport.offsetX, y) + this.ctx.lineTo(cols * this.cellWidth * viewport.scale + viewport.offsetX, y) this.ctx.stroke() } } @@ -508,17 +534,27 @@ export class CanvasRenderer { const cellInfo = cellEntries[0]! - const x = col * this.cellSize * viewport.scale + viewport.offsetX - const y = row * this.cellSize * viewport.scale + viewport.offsetY - const size = this.cellSize * viewport.scale + const x = col * this.cellWidth * viewport.scale + viewport.offsetX + const y = row * this.cellHeight * viewport.scale + viewport.offsetY + const width = this.cellWidth * viewport.scale + const height = this.cellHeight * viewport.scale + const minSize = Math.min(width, height) // Fill cell with color determined by current grouping this.ctx.fillStyle = this.getCellFillColor(cellInfo) - this.ctx.fillRect(x, y, size, size) + this.ctx.fillRect(x, y, width, height) // Draw text if cell is large enough - if (this.showCellText && size > 20) { - this.drawCellText(cellInfo, x, y, size) + if (this.showCellText && minSize > 20) { + const subCells = this.subCellProvider?.(cellEntries) + if (subCells && subCells.rows * subCells.cols > 1) { + this.drawSubCells(subCells, x, y, width, height) + } else { + const lines = this.getCellTextLines(cellInfo, cellEntries) + if (lines && lines.length > 0) { + this.drawCellText(lines, x, y, width, height) + } + } } } } @@ -548,24 +584,106 @@ export class CanvasRenderer { /** * Draw text inside a cell */ + private getCellTextLines(cellInfo: CellInfo, entries: CellInfo[]): string[] | null { + if (this.customCellTextProvider) { + const provided = this.customCellTextProvider(cellInfo, entries) + if (Array.isArray(provided)) { + const lines = provided.map((line) => line.trim()).filter((line) => line.length > 0) + return lines.length > 0 ? lines : null + } + if (typeof provided === 'string') { + const trimmed = provided.trim() + return trimmed.length > 0 ? [trimmed] : null + } + return null + } + + const threadLabel = cellInfo.threadId.toString().padStart(2, '0') + const registerLabel = cellInfo.registerId.toString().padStart(2, '0') + return [`T${threadLabel}::${registerLabel}`] + } + private drawCellText( - cellInfo: CellInfo, + lines: string[], x: number, y: number, - size: number + width: number, + height: number ): void { - const threadLabel = cellInfo.threadId.toString().padStart(2, '0') - const registerLabel = cellInfo.registerId.toString().padStart(2, '0') - const text = `T${threadLabel}::${registerLabel}` + const lineCount = Math.max(lines.length, 1) + const minDim = Math.min(width, height) + const maxFont = Math.min(minDim / 4, 16) + const minFont = 8 + const usableHeight = height * 0.8 + const lineHeight = Math.max(minFont, Math.min(maxFont, usableHeight / lineCount)) + const fontSize = Math.max(minFont, lineHeight) + + this.ctx.font = `${fontSize}px monospace` + this.ctx.fillStyle = '#000000' + this.ctx.textAlign = 'center' + this.ctx.textBaseline = 'middle' + + const totalHeight = lineHeight * lineCount + const startY = y + height / 2 - totalHeight / 2 + lineHeight / 2 + + lines.forEach((line, index) => { + const lineY = startY + index * lineHeight + this.ctx.fillText(line, x + width / 2, lineY) + }) + } + + private drawSubCells(layout: SubCellLayout, x: number, y: number, width: number, height: number): void { + const rows = Math.max(1, layout.rows) + const cols = Math.max(1, layout.cols) + const subWidth = width / cols + const subHeight = height / rows + + this.ctx.strokeStyle = 'rgba(0, 0, 0, 0.2)' + this.ctx.lineWidth = 1 - // Calculate font size based on cell size - const fontSize = Math.max(8, Math.min(size / 4, 16)) + for (let col = 1; col < cols; col++) { + const lineX = x + col * subWidth + this.ctx.beginPath() + this.ctx.moveTo(lineX, y) + this.ctx.lineTo(lineX, y + height) + this.ctx.stroke() + } + + for (let row = 1; row < rows; row++) { + const lineY = y + row * subHeight + this.ctx.beginPath() + this.ctx.moveTo(x, lineY) + this.ctx.lineTo(x + width, lineY) + this.ctx.stroke() + } + + if (layout.items.length === 0) { + return + } + + const fontSize = Math.max(6, Math.min(subWidth, subHeight) * 0.45) this.ctx.font = `${fontSize}px monospace` this.ctx.fillStyle = '#000000' this.ctx.textAlign = 'center' this.ctx.textBaseline = 'middle' - this.ctx.fillText(text, x + size / 2, y + size / 2) + layout.items.forEach(({ index, label }) => { + if (!Number.isFinite(index)) { + return + } + const normalizedIndex = Math.trunc(index) + if (normalizedIndex < 0) { + return + } + const subRow = Math.floor(normalizedIndex / cols) + const subCol = normalizedIndex % cols + if (subRow >= rows || subCol >= cols) { + return + } + const centerX = x + subCol * subWidth + subWidth / 2 + const centerY = y + subRow * subHeight + subHeight / 2 + this.ctx.fillText(label, centerX, centerY) + }) } /** @@ -576,13 +694,14 @@ export class CanvasRenderer { viewport: { offsetX: number; offsetY: number; scale: number } ): void { const [row, col] = cell - const x = col * this.cellSize * viewport.scale + viewport.offsetX - const y = row * this.cellSize * viewport.scale + viewport.offsetY - const size = this.cellSize * viewport.scale + const x = col * this.cellWidth * viewport.scale + viewport.offsetX + const y = row * this.cellHeight * viewport.scale + viewport.offsetY + const width = this.cellWidth * viewport.scale + const height = this.cellHeight * viewport.scale this.ctx.strokeStyle = '#ff0000' this.ctx.lineWidth = 3 - this.ctx.strokeRect(x, y, size, size) + this.ctx.strokeRect(x, y, width, height) } /** @@ -599,6 +718,17 @@ export class CanvasRenderer { return entries && entries.length > 0 ? entries[0]! : null } + getCellEntries(row: number, col: number): CellInfo[] | null { + if (row < 0 || row >= this.params.tensorShape[0] || + col < 0 || col >= this.params.tensorShape[1]) { + return null + } + + const key = `${row},${col}` + const entries = this.cellDataCache.get(key) + return entries && entries.length > 0 ? entries : null + } + /** * Handle mouse down event (start dragging) */ @@ -758,7 +888,9 @@ export class CanvasRenderer { this.viewportController = new ViewportController( this.canvas, this.params.tensorShape[1], - this.params.tensorShape[0] + this.params.tensorShape[0], + this.cellWidth, + this.cellHeight ) } diff --git a/src/visualization/ViewportController.ts b/src/visualization/ViewportController.ts index fa4f028..ea54d9f 100644 --- a/src/visualization/ViewportController.ts +++ b/src/visualization/ViewportController.ts @@ -12,7 +12,9 @@ export class ViewportController { constructor( private canvas: HTMLCanvasElement, private gridWidth: number, - private gridHeight: number + private gridHeight: number, + private cellWidth = 50, + private cellHeight = 50 ) { // Start centered and scaled to fit this.viewport = { @@ -53,13 +55,13 @@ export class ViewportController { reset(): void { // Calculate scale to fit grid in canvas with some padding const padding = 40 - const scaleX = (this.canvas.width - padding * 2) / (this.gridWidth * 50) - const scaleY = (this.canvas.height - padding * 2) / (this.gridHeight * 50) + const scaleX = (this.canvas.width - padding * 2) / (this.gridWidth * this.cellWidth) + const scaleY = (this.canvas.height - padding * 2) / (this.gridHeight * this.cellHeight) const scale = Math.min(scaleX, scaleY, 1) this.viewport.scale = scale - this.viewport.offsetX = (this.canvas.width - this.gridWidth * 50 * scale) / 2 - this.viewport.offsetY = (this.canvas.height - this.gridHeight * 50 * scale) / 2 + this.viewport.offsetX = (this.canvas.width - this.gridWidth * this.cellWidth * scale) / 2 + this.viewport.offsetY = (this.canvas.height - this.gridHeight * this.cellHeight * scale) / 2 } getViewport(): Readonly { @@ -67,8 +69,8 @@ export class ViewportController { } screenToGrid(screenX: number, screenY: number): { col: number; row: number } { - const gridX = (screenX - this.viewport.offsetX) / this.viewport.scale / 50 - const gridY = (screenY - this.viewport.offsetY) / this.viewport.scale / 50 + const gridX = (screenX - this.viewport.offsetX) / this.viewport.scale / this.cellWidth + const gridY = (screenY - this.viewport.offsetY) / this.viewport.scale / this.cellHeight return { col: Math.floor(gridX), row: Math.floor(gridY), @@ -77,8 +79,8 @@ export class ViewportController { gridToScreen(col: number, row: number): { x: number; y: number } { return { - x: col * 50 * this.viewport.scale + this.viewport.offsetX, - y: row * 50 * this.viewport.scale + this.viewport.offsetY, + x: col * this.cellWidth * this.viewport.scale + this.viewport.offsetX, + y: row * this.cellHeight * this.viewport.scale + this.viewport.offsetY, } } } diff --git a/tests/visualization.spec.ts b/tests/visualization.spec.ts index ab7f552..6f5be3a 100644 --- a/tests/visualization.spec.ts +++ b/tests/visualization.spec.ts @@ -235,6 +235,24 @@ test.describe('Shared Layout', () => { expect(Buffer.compare(initialScreenshot, updatedScreenshot)).not.toBe(0) }) + test('Bank view shows offsets in the tooltip', async ({ page }) => { + await openSharedLayoutTab(page) + + await page.locator('#shared-view-mode').selectOption('bank') + + const canvas = page.locator('#shared-canvas') + await expect(canvas).toBeVisible() + + const box = await canvas.boundingBox() + expect(box).not.toBeNull() + await page.mouse.move(box!.x + box!.width / 2, box!.y + box!.height / 2) + + const tooltip = page.locator('.layout-tooltip') + await expect(tooltip).toBeVisible() + await expect(tooltip).toContainText('Bank:') + await expect(tooltip).toContainText('Elements in bank:') + }) + test('Tooltip shows Shared Layout metadata without legacy fields', async ({ page }) => { await openSharedLayoutTab(page)