diff --git a/src/tabs/LinearLayoutTab.ts b/src/tabs/LinearLayoutTab.ts index ee21ea3..de7689d 100644 --- a/src/tabs/LinearLayoutTab.ts +++ b/src/tabs/LinearLayoutTab.ts @@ -313,11 +313,13 @@ export class LinearLayoutTab extends CanvasTab { try { const layout = LinearLayout.fromBitMatrix(matrix, inputs, outputs) - if (!layout.isInvertible()) { - throw new Error('Layout matrix is not invertible. Ensure input/output bit counts match and the matrix has full rank.') - } + const isInvertible = layout.isInvertible() this.layout = layout - this.setLayoutStatus('') + if (isInvertible) { + this.setLayoutStatus('') + } else { + this.setLayoutStatus('Layout matrix is not invertible. Showing partial mapping for valid outputs.') + } this.updateRendererFromLayout() } catch (error) { const message = error instanceof Error ? error.message : String(error) diff --git a/src/visualization/CanvasRenderer.test.ts b/src/visualization/CanvasRenderer.test.ts index b704070..438e53e 100644 --- a/src/visualization/CanvasRenderer.test.ts +++ b/src/visualization/CanvasRenderer.test.ts @@ -1,10 +1,32 @@ import { describe, it, expect, beforeEach, vi } from 'vitest' -import { CanvasRenderer, type PositionResolver } from './CanvasRenderer' +import { CanvasRenderer, type PositionResolver, type CellInfo } from './CanvasRenderer' import { ColorScheme } from './ColorScheme' import { createBlockLayout } from '../layouts/BlockLayout' import type { BlockLayoutParams } from '../validation/InputValidator' import { LinearLayout } from '../core/LinearLayout' +const getRendererCellSize = (renderer: CanvasRenderer): number => + (renderer as unknown as { cellSize: number }).cellSize + +const getRendererViewport = ( + renderer: CanvasRenderer +): { offsetX: number; offsetY: number; scale: number } => + (renderer as unknown as { + viewportController: { getViewport(): { offsetX: number; offsetY: number; scale: number } } + }).viewportController.getViewport() + +const getCellCenterCoordinates = (renderer: CanvasRenderer, row: number, col: number): { + x: number + y: number +} => { + const viewport = getRendererViewport(renderer) + const scaledCellSize = getRendererCellSize(renderer) * viewport.scale + return { + x: viewport.offsetX + col * scaledCellSize + scaledCellSize / 2, + y: viewport.offsetY + row * scaledCellSize + scaledCellSize / 2, + } +} + describe('CanvasRenderer', () => { let canvas: HTMLCanvasElement let renderer: CanvasRenderer @@ -131,26 +153,28 @@ describe('CanvasRenderer', () => { const warpSpy = vi.spyOn(ColorScheme.prototype, 'getColorForWarp') const threadSpy = vi.spyOn(ColorScheme.prototype, 'getColorForThread') - const renderer = new CanvasRenderer( - canvas, - layout, - params, - undefined, - { - traversalMode: 'by-output', - colorGrouping: 'thread', - showCellText: false, - colorInputDimension: 'lane', - } - ) - - renderer.render() + try { + const renderer = new CanvasRenderer( + canvas, + layout, + params, + undefined, + { + traversalMode: 'by-output', + colorGrouping: 'thread', + showCellText: false, + colorInputDimension: 'lane', + } + ) - expect(warpSpy).toHaveBeenCalled() - expect(threadSpy).not.toHaveBeenCalled() + renderer.render() - warpSpy.mockRestore() - threadSpy.mockRestore() + expect(warpSpy).toHaveBeenCalled() + expect(threadSpy).not.toHaveBeenCalled() + } finally { + warpSpy.mockRestore() + threadSpy.mockRestore() + } }) }) @@ -258,12 +282,10 @@ describe('CanvasRenderer', () => { replicaRenderer.render() strokeRectSpy.mockClear() - const cellSize = 50 - const offsetX = (canvas.width - params.tensorShape[1] * cellSize) / 2 - const offsetY = (canvas.height - params.tensorShape[0] * cellSize) / 2 + const firstCellCenter = getCellCenterCoordinates(replicaRenderer, 0, 0) const event = new MouseEvent('mousemove', { - clientX: offsetX + cellSize / 2, - clientY: offsetY + cellSize / 2, + clientX: firstCellCenter.x, + clientY: firstCellCenter.y, }) replicaRenderer.handleMouseMove(event) @@ -380,6 +402,31 @@ describe('CanvasRenderer', () => { }) describe('output traversal mode', () => { + const createRankDeficientLayout = () => + LinearLayout.fromBitMatrix( + [ + [1, 0, 0], + [1, 0, 0], + [0, 0, 1], + ], + [ + { name: 'thread', size: 4 }, + { name: 'register', size: 2 }, + ], + [ + { name: 'row', size: 4 }, + { name: 'col', size: 2 }, + ] + ) + + const createRankDeficientParams = (): BlockLayoutParams => ({ + sizePerThread: [1, 1] as [number, number], + threadsPerWarp: [4, 1] as [number, number], + warpsPerCTA: [1, 1] as [number, number], + order: [0, 1] as [number, number], + tensorShape: [4, 2] as [number, number], + }) + it('resolves cells directly from output coordinates', () => { const tensorParams: BlockLayoutParams = { sizePerThread: [1, 1] as [number, number], @@ -417,5 +464,267 @@ describe('CanvasRenderer', () => { const cacheSize = (outputRenderer as unknown as { cellDataCache: Map }).cellDataCache.size expect(cacheSize).toBe(tensorParams.tensorShape[0] * tensorParams.tensorShape[1]) }) + + it('falls back to partial mappings when the layout is not invertible', () => { + const nonInvertibleLayout = LinearLayout.fromBitMatrix( + [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [1, 1, 1], + ], + [ + { name: 'thread', size: 4 }, + { name: 'register', size: 2 }, + ], + [ + { name: 'row', size: 8 }, + { name: 'col', size: 2 }, + ] + ) + expect(nonInvertibleLayout.isInvertible()).toBe(false) + + const params: BlockLayoutParams = { + sizePerThread: [1, 1] as [number, number], + threadsPerWarp: [4, 1] as [number, number], + warpsPerCTA: [1, 1] as [number, number], + order: [0, 1] as [number, number], + tensorShape: [8, 2] as [number, number], + } + + const renderer = new CanvasRenderer( + canvas, + nonInvertibleLayout, + params, + undefined, + { traversalMode: 'by-output', colorGrouping: 'thread', showCellText: false } + ) + + const populatedCell = renderer.getCellInfo(0, 0) + expect(populatedCell).not.toBeNull() + expect(populatedCell?.outputCoords?.row).toBe(0) + expect(populatedCell?.outputCoords?.col).toBe(0) + + const unmappedCell = renderer.getCellInfo(7, 0) + expect(unmappedCell).toBeNull() + + const cacheSize = (renderer as unknown as { cellDataCache: Map }).cellDataCache.size + expect(cacheSize).toBe(8) + }) + + it('handles rank-deficient matrices where multiple inputs map to the same output cell', () => { + const layout = createRankDeficientLayout() + const params = createRankDeficientParams() + expect(layout.isInvertible()).toBe(false) + + const renderer = new CanvasRenderer( + canvas, + layout, + params, + undefined, + { traversalMode: 'by-output', colorGrouping: 'thread', showCellText: false } + ) + + const cache = (renderer as unknown as { cellDataCache: Map }).cellDataCache + const registerCount = layout.getInDimSize('register') + const totalThreads = + params.threadsPerWarp[0] * + params.threadsPerWarp[1] * + params.warpsPerCTA[0] * + params.warpsPerCTA[1] + + const expectedEntries = new Map< + string, + Array<{ + threadId: number + registerId: number + inputCoords: { thread: number; register: number } + outputCoords: { row: number; col: number } + }> + >() + + for (let threadId = 0; threadId < totalThreads; threadId++) { + for (let registerId = 0; registerId < registerCount; registerId++) { + const outputCoords = layout.apply({ thread: threadId, register: registerId }) + const row = outputCoords.row + const col = outputCoords.col + expect(typeof row).toBe('number') + expect(typeof col).toBe('number') + if (typeof row !== 'number' || typeof col !== 'number') { + continue + } + const key = `${row},${col}` + const entry = { + threadId, + registerId, + inputCoords: { thread: threadId, register: registerId }, + outputCoords: { row, col }, + } + const bucket = expectedEntries.get(key) + if (bucket) { + bucket.push(entry) + } else { + expectedEntries.set(key, [entry]) + } + } + } + + const expectedKeys = [...expectedEntries.keys()].sort() + const actualKeys = [...cache.keys()].sort() + + expect(cache.size).toBe(expectedEntries.size) + expect(actualKeys).toEqual(expectedKeys) + + for (const [key, expectedCells] of expectedEntries) { + const collision = cache.get(key) + expect(collision).toBeDefined() + if (!collision) { + continue + } + expect(collision).toHaveLength(expectedCells.length) + const matchers = expectedCells.map(entry => + expect.objectContaining({ + threadId: entry.threadId, + registerId: entry.registerId, + inputCoords: expect.objectContaining(entry.inputCoords), + outputCoords: expect.objectContaining(entry.outputCoords), + }) + ) + expect(collision).toEqual(expect.arrayContaining(matchers)) + } + }) + + it('provides accurate input coordinates in the fallback enumeration path', () => { + const renderer = new CanvasRenderer( + canvas, + createRankDeficientLayout(), + createRankDeficientParams(), + undefined, + { traversalMode: 'by-output', colorGrouping: 'thread', showCellText: false } + ) + + const cell00 = renderer.getCellInfo(0, 0) + expect(cell00).not.toBeNull() + expect([0, 2]).toContain(cell00?.threadId) + expect([0]).toContain(cell00?.registerId) + expect(cell00?.inputCoords).toMatchObject({ thread: 0, register: 0 }) + expect(cell00?.outputCoords).toMatchObject({ row: 0, col: 0 }) + + const cell31 = renderer.getCellInfo(3, 1) + expect(cell31).not.toBeNull() + expect([1, 3]).toContain(cell31?.threadId) + expect([1]).toContain(cell31?.registerId) + expect(cell31?.inputCoords).toMatchObject({ thread: 1, register: 1 }) + expect(cell31?.outputCoords).toMatchObject({ row: 3, col: 1 }) + }) + + it('supports color-by-input dimension when using the fallback path', () => { + const layout = createRankDeficientLayout() + const params = createRankDeficientParams() + const warpSpy = vi.spyOn(ColorScheme.prototype, 'getColorForWarp') + const threadSpy = vi.spyOn(ColorScheme.prototype, 'getColorForThread') + + const renderer = new CanvasRenderer( + canvas, + layout, + params, + undefined, + { + traversalMode: 'by-output', + colorGrouping: 'thread', + showCellText: false, + colorInputDimension: 'register', + } + ) + + try { + renderer.render() + const colorCalls = warpSpy.mock.calls.map(call => call[0] as number) + expect(colorCalls).toContain(0) + expect(colorCalls).toContain(1) + expect(threadSpy).not.toHaveBeenCalled() + } finally { + warpSpy.mockRestore() + threadSpy.mockRestore() + } + }) + + it('highlights overlapping cells when hovering in by-output traversal mode', () => { + const layout = createRankDeficientLayout() + const params = createRankDeficientParams() + const renderer = new CanvasRenderer( + canvas, + layout, + params, + undefined, + { traversalMode: 'by-output', colorGrouping: 'thread', showCellText: false } + ) + + const ctx = canvas.getContext('2d')! + const strokeRectSpy = vi.spyOn(ctx, 'strokeRect') + const rectSpy = vi.spyOn(canvas, 'getBoundingClientRect').mockReturnValue({ + width: canvas.width, + height: canvas.height, + top: 0, + left: 0, + right: canvas.width, + bottom: canvas.height, + x: 0, + y: 0, + toJSON: () => ({}), + }) + + try { + renderer.render() + + const sourceIndex = (renderer as unknown as { + sourcePositionIndex: Map> + }).sourcePositionIndex + + const firstEntry = sourceIndex.get('0,0') + const secondEntry = sourceIndex.get('3,1') + + expect(firstEntry).toBeDefined() + expect(firstEntry).toHaveLength(2) + expect(secondEntry).toBeDefined() + expect(secondEntry).toHaveLength(2) + + strokeRectSpy.mockClear() + const hoverAndAssert = (targetCell: [number, number], expectedCells: Array<[number, number]>) => { + const cellCenter = getCellCenterCoordinates(renderer, targetCell[0], targetCell[1]) + const event = new MouseEvent('mousemove', { + clientX: cellCenter.x, + clientY: cellCenter.y, + }) + + strokeRectSpy.mockClear() + renderer.handleMouseMove(event) + + expect(strokeRectSpy).toHaveBeenCalledTimes(expectedCells.length) + + const viewport = getRendererViewport(renderer) + const scaledCellSize = getRendererCellSize(renderer) * viewport.scale + const highlightedCells = strokeRectSpy.mock.calls + .map(call => { + const [x, y] = call + const col = Math.round((x - viewport.offsetX) / scaledCellSize) + const row = Math.round((y - viewport.offsetY) / scaledCellSize) + return `${row},${col}` + }) + .sort() + + const expected = expectedCells.map(([row, col]) => `${row},${col}`).sort() + expect(highlightedCells).toEqual(expected) + } + + if (firstEntry && secondEntry) { + hoverAndAssert(firstEntry[0]!, firstEntry) + hoverAndAssert(secondEntry[0]!, secondEntry) + } + } finally { + strokeRectSpy.mockRestore() + rectSpy.mockRestore() + } + }) }) }) diff --git a/src/visualization/CanvasRenderer.ts b/src/visualization/CanvasRenderer.ts index 57ee35c..beab541 100644 --- a/src/visualization/CanvasRenderer.ts +++ b/src/visualization/CanvasRenderer.ts @@ -35,6 +35,11 @@ export interface ResolvedPosition { sourcePos?: [number, number] } +interface OutputTraversalBuildResult { + cache: Map + maxThreadId: number +} + /** * Position resolver function type - returns positions for a given thread. * @@ -187,16 +192,39 @@ export class CanvasRenderer { } private buildCellDataCacheByOutput(): Map { - const cache = new Map() - this.sourcePositionIndex = new Map>() const outputDims = this.layout.getOutDims() if (outputDims.length === 0) { - return cache + this.sourcePositionIndex = new Map>() + this.maxThreadIdObserved = 0 + return new Map() + } + + let result: OutputTraversalBuildResult + try { + const inverse = this.layout.invert() + result = this.buildOutputCacheUsingInverse(outputDims, inverse) + } catch { + result = this.buildOutputCacheFromInputs(outputDims) + } + + this.maxThreadIdObserved = Math.max(result.maxThreadId, 0) + if (this.customColorDimension) { + this.resetColorSchemeForCustomDimension() + } else { + this.rebuildColorSchemeFromThreads(result.maxThreadId) } + return result.cache + } - const inverse = this.layout.invert() + private buildOutputCacheUsingInverse( + outputDims: Array<[string, number]>, + inverse: LinearLayout + ): OutputTraversalBuildResult { + const cache = new Map() + this.sourcePositionIndex = new Map>() const coords: Record = {} let maxThreadId = 0 + const warpSize = Math.max(this.params.threadsPerWarp[0] * this.params.threadsPerWarp[1], 1) const traverse = (index: number) => { if (index === outputDims.length) { @@ -204,21 +232,12 @@ export class CanvasRenderer { const inputCoords = inverse.apply(outputCoords) const threadId = inputCoords.thread ?? 0 const registerId = inputCoords.register ?? 0 - const warpSize = Math.max(this.params.threadsPerWarp[0] * this.params.threadsPerWarp[1], 1) const warpId = this.colorGrouping === 'thread' ? threadId : Math.floor(threadId / warpSize) maxThreadId = Math.max(maxThreadId, threadId) - const primaryDim = outputDims[0]?.[0] - const secondaryDim = outputDims[1]?.[0] - const row = secondaryDim ? (outputCoords[primaryDim ?? ''] ?? 0) : 0 - const col = secondaryDim - ? outputCoords[secondaryDim] ?? 0 - : primaryDim - ? outputCoords[primaryDim] ?? 0 - : 0 - + const [row, col] = this.resolveDisplayPosition(outputDims, outputCoords) const key = `${row},${col}` const sourcePosition: [number, number] = [row, col] const cellInfo: CellInfo = { @@ -258,13 +277,100 @@ export class CanvasRenderer { } traverse(0) - this.maxThreadIdObserved = Math.max(maxThreadId, 0) - if (this.customColorDimension) { - this.resetColorSchemeForCustomDimension() - } else { - this.rebuildColorSchemeFromThreads(maxThreadId) + return { cache, maxThreadId } + } + + private buildOutputCacheFromInputs( + outputDims: Array<[string, number]> + ): OutputTraversalBuildResult { + const cache = new Map() + this.sourcePositionIndex = new Map>() + const inputDims = this.layout.getInDimNames().map((name) => [name, this.layout.getInDimSize(name)] as [string, number]) + const coords: Record = {} + const warpSize = Math.max(this.params.threadsPerWarp[0] * this.params.threadsPerWarp[1], 1) + let maxThreadId = 0 + + const traverse = (index: number) => { + if (index === inputDims.length) { + const inputCoords = { ...coords } + const outputCoords = this.layout.apply(inputCoords) + if (!this.isOutputWithinBounds(outputCoords, outputDims)) { + return + } + + const threadId = inputCoords.thread ?? 0 + const registerId = inputCoords.register ?? 0 + const warpId = this.colorGrouping === 'thread' + ? threadId + : Math.floor(threadId / warpSize) + maxThreadId = Math.max(maxThreadId, threadId) + + const [row, col] = this.resolveDisplayPosition(outputDims, outputCoords) + const key = `${row},${col}` + const sourcePosition: [number, number] = [row, col] + const cellInfo: CellInfo = { + threadId, + registerId, + warpId, + position: [row, col], + sourcePosition, + inputCoords, + outputCoords, + } + + const existing = cache.get(key) + if (existing) { + existing.push(cellInfo) + } else { + cache.set(key, [cellInfo]) + } + + const sourceKey = `${sourcePosition[0]},${sourcePosition[1]}` + const sourceEntries = this.sourcePositionIndex.get(sourceKey) + if (sourceEntries) { + sourceEntries.push([row, col]) + } else { + this.sourcePositionIndex.set(sourceKey, [[row, col]]) + } + return + } + + const dim = inputDims[index] + if (!dim) return + const [name, size] = dim + for (let value = 0; value < size; value++) { + coords[name] = value + traverse(index + 1) + } } - return cache + + traverse(0) + return { cache, maxThreadId } + } + + private resolveDisplayPosition( + outputDims: Array<[string, number]>, + outputCoords: Record + ): [number, number] { + const primaryDim = outputDims[0]?.[0] + const secondaryDim = outputDims[1]?.[0] + const row = secondaryDim ? (outputCoords[primaryDim ?? ''] ?? 0) : 0 + const col = secondaryDim + ? outputCoords[secondaryDim] ?? 0 + : primaryDim + ? outputCoords[primaryDim] ?? 0 + : 0 + return [row, col] + } + + private isOutputWithinBounds( + outputCoords: Record, + outputDims: Array<[string, number]> + ): boolean { + return outputDims.every(([name, size]) => { + const value = outputCoords[name] + return typeof value === 'number' && value >= 0 && value < size + }) } private rebuildColorSchemeFromThreads(maxThreadId: number): void {