From a1275a62895eacd79590315e5847d048fa3edc0c Mon Sep 17 00:00:00 2001 From: leeliu103 Date: Fri, 5 Dec 2025 19:33:45 +0000 Subject: [PATCH] Add partial mapping visualization for non-invertible LinearLayout matrices MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When LinearLayout matrices are not invertible, enumerate all possible inputs and display valid output mappings on the canvas instead of showing an error. This allows users to visualize partial mappings for rank-deficient or rectangular matrices. - Implement input enumeration fallback when matrix inversion fails - Show status message: "Layout matrix is not invertible. Showing partial mapping for valid outputs." - Add comprehensive test coverage for collision scenarios, hover highlighting, and color-by-dimension - Handle multiple inputs mapping to same output (rank-deficient case) - Validate coordinates and ensure proper highlighting in by-output traversal mode 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/tabs/LinearLayoutTab.ts | 10 +- src/visualization/CanvasRenderer.test.ts | 357 +++++++++++++++++++++-- src/visualization/CanvasRenderer.ts | 146 +++++++-- 3 files changed, 465 insertions(+), 48 deletions(-) diff --git a/src/tabs/LinearLayoutTab.ts b/src/tabs/LinearLayoutTab.ts index ee21ea3..de7689d 100644 --- a/src/tabs/LinearLayoutTab.ts +++ b/src/tabs/LinearLayoutTab.ts @@ -313,11 +313,13 @@ export class LinearLayoutTab extends CanvasTab { try { const layout = LinearLayout.fromBitMatrix(matrix, inputs, outputs) - if (!layout.isInvertible()) { - throw new Error('Layout matrix is not invertible. Ensure input/output bit counts match and the matrix has full rank.') - } + const isInvertible = layout.isInvertible() this.layout = layout - this.setLayoutStatus('') + if (isInvertible) { + this.setLayoutStatus('') + } else { + this.setLayoutStatus('Layout matrix is not invertible. Showing partial mapping for valid outputs.') + } this.updateRendererFromLayout() } catch (error) { const message = error instanceof Error ? error.message : String(error) diff --git a/src/visualization/CanvasRenderer.test.ts b/src/visualization/CanvasRenderer.test.ts index b704070..438e53e 100644 --- a/src/visualization/CanvasRenderer.test.ts +++ b/src/visualization/CanvasRenderer.test.ts @@ -1,10 +1,32 @@ import { describe, it, expect, beforeEach, vi } from 'vitest' -import { CanvasRenderer, type PositionResolver } from './CanvasRenderer' +import { CanvasRenderer, type PositionResolver, type CellInfo } from './CanvasRenderer' import { ColorScheme } from './ColorScheme' import { createBlockLayout } from '../layouts/BlockLayout' import type { BlockLayoutParams } from '../validation/InputValidator' import { LinearLayout } from '../core/LinearLayout' +const getRendererCellSize = (renderer: CanvasRenderer): number => + (renderer as unknown as { cellSize: number }).cellSize + +const getRendererViewport = ( + renderer: CanvasRenderer +): { offsetX: number; offsetY: number; scale: number } => + (renderer as unknown as { + viewportController: { getViewport(): { offsetX: number; offsetY: number; scale: number } } + }).viewportController.getViewport() + +const getCellCenterCoordinates = (renderer: CanvasRenderer, row: number, col: number): { + x: number + y: number +} => { + const viewport = getRendererViewport(renderer) + const scaledCellSize = getRendererCellSize(renderer) * viewport.scale + return { + x: viewport.offsetX + col * scaledCellSize + scaledCellSize / 2, + y: viewport.offsetY + row * scaledCellSize + scaledCellSize / 2, + } +} + describe('CanvasRenderer', () => { let canvas: HTMLCanvasElement let renderer: CanvasRenderer @@ -131,26 +153,28 @@ describe('CanvasRenderer', () => { const warpSpy = vi.spyOn(ColorScheme.prototype, 'getColorForWarp') const threadSpy = vi.spyOn(ColorScheme.prototype, 'getColorForThread') - const renderer = new CanvasRenderer( - canvas, - layout, - params, - undefined, - { - traversalMode: 'by-output', - colorGrouping: 'thread', - showCellText: false, - colorInputDimension: 'lane', - } - ) - - renderer.render() + try { + const renderer = new CanvasRenderer( + canvas, + layout, + params, + undefined, + { + traversalMode: 'by-output', + colorGrouping: 'thread', + showCellText: false, + colorInputDimension: 'lane', + } + ) - expect(warpSpy).toHaveBeenCalled() - expect(threadSpy).not.toHaveBeenCalled() + renderer.render() - warpSpy.mockRestore() - threadSpy.mockRestore() + expect(warpSpy).toHaveBeenCalled() + expect(threadSpy).not.toHaveBeenCalled() + } finally { + warpSpy.mockRestore() + threadSpy.mockRestore() + } }) }) @@ -258,12 +282,10 @@ describe('CanvasRenderer', () => { replicaRenderer.render() strokeRectSpy.mockClear() - const cellSize = 50 - const offsetX = (canvas.width - params.tensorShape[1] * cellSize) / 2 - const offsetY = (canvas.height - params.tensorShape[0] * cellSize) / 2 + const firstCellCenter = getCellCenterCoordinates(replicaRenderer, 0, 0) const event = new MouseEvent('mousemove', { - clientX: offsetX + cellSize / 2, - clientY: offsetY + cellSize / 2, + clientX: firstCellCenter.x, + clientY: firstCellCenter.y, }) replicaRenderer.handleMouseMove(event) @@ -380,6 +402,31 @@ describe('CanvasRenderer', () => { }) describe('output traversal mode', () => { + const createRankDeficientLayout = () => + LinearLayout.fromBitMatrix( + [ + [1, 0, 0], + [1, 0, 0], + [0, 0, 1], + ], + [ + { name: 'thread', size: 4 }, + { name: 'register', size: 2 }, + ], + [ + { name: 'row', size: 4 }, + { name: 'col', size: 2 }, + ] + ) + + const createRankDeficientParams = (): BlockLayoutParams => ({ + sizePerThread: [1, 1] as [number, number], + threadsPerWarp: [4, 1] as [number, number], + warpsPerCTA: [1, 1] as [number, number], + order: [0, 1] as [number, number], + tensorShape: [4, 2] as [number, number], + }) + it('resolves cells directly from output coordinates', () => { const tensorParams: BlockLayoutParams = { sizePerThread: [1, 1] as [number, number], @@ -417,5 +464,267 @@ describe('CanvasRenderer', () => { const cacheSize = (outputRenderer as unknown as { cellDataCache: Map }).cellDataCache.size expect(cacheSize).toBe(tensorParams.tensorShape[0] * tensorParams.tensorShape[1]) }) + + it('falls back to partial mappings when the layout is not invertible', () => { + const nonInvertibleLayout = LinearLayout.fromBitMatrix( + [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [1, 1, 1], + ], + [ + { name: 'thread', size: 4 }, + { name: 'register', size: 2 }, + ], + [ + { name: 'row', size: 8 }, + { name: 'col', size: 2 }, + ] + ) + expect(nonInvertibleLayout.isInvertible()).toBe(false) + + const params: BlockLayoutParams = { + sizePerThread: [1, 1] as [number, number], + threadsPerWarp: [4, 1] as [number, number], + warpsPerCTA: [1, 1] as [number, number], + order: [0, 1] as [number, number], + tensorShape: [8, 2] as [number, number], + } + + const renderer = new CanvasRenderer( + canvas, + nonInvertibleLayout, + params, + undefined, + { traversalMode: 'by-output', colorGrouping: 'thread', showCellText: false } + ) + + const populatedCell = renderer.getCellInfo(0, 0) + expect(populatedCell).not.toBeNull() + expect(populatedCell?.outputCoords?.row).toBe(0) + expect(populatedCell?.outputCoords?.col).toBe(0) + + const unmappedCell = renderer.getCellInfo(7, 0) + expect(unmappedCell).toBeNull() + + const cacheSize = (renderer as unknown as { cellDataCache: Map }).cellDataCache.size + expect(cacheSize).toBe(8) + }) + + it('handles rank-deficient matrices where multiple inputs map to the same output cell', () => { + const layout = createRankDeficientLayout() + const params = createRankDeficientParams() + expect(layout.isInvertible()).toBe(false) + + const renderer = new CanvasRenderer( + canvas, + layout, + params, + undefined, + { traversalMode: 'by-output', colorGrouping: 'thread', showCellText: false } + ) + + const cache = (renderer as unknown as { cellDataCache: Map }).cellDataCache + const registerCount = layout.getInDimSize('register') + const totalThreads = + params.threadsPerWarp[0] * + params.threadsPerWarp[1] * + params.warpsPerCTA[0] * + params.warpsPerCTA[1] + + const expectedEntries = new Map< + string, + Array<{ + threadId: number + registerId: number + inputCoords: { thread: number; register: number } + outputCoords: { row: number; col: number } + }> + >() + + for (let threadId = 0; threadId < totalThreads; threadId++) { + for (let registerId = 0; registerId < registerCount; registerId++) { + const outputCoords = layout.apply({ thread: threadId, register: registerId }) + const row = outputCoords.row + const col = outputCoords.col + expect(typeof row).toBe('number') + expect(typeof col).toBe('number') + if (typeof row !== 'number' || typeof col !== 'number') { + continue + } + const key = `${row},${col}` + const entry = { + threadId, + registerId, + inputCoords: { thread: threadId, register: registerId }, + outputCoords: { row, col }, + } + const bucket = expectedEntries.get(key) + if (bucket) { + bucket.push(entry) + } else { + expectedEntries.set(key, [entry]) + } + } + } + + const expectedKeys = [...expectedEntries.keys()].sort() + const actualKeys = [...cache.keys()].sort() + + expect(cache.size).toBe(expectedEntries.size) + expect(actualKeys).toEqual(expectedKeys) + + for (const [key, expectedCells] of expectedEntries) { + const collision = cache.get(key) + expect(collision).toBeDefined() + if (!collision) { + continue + } + expect(collision).toHaveLength(expectedCells.length) + const matchers = expectedCells.map(entry => + expect.objectContaining({ + threadId: entry.threadId, + registerId: entry.registerId, + inputCoords: expect.objectContaining(entry.inputCoords), + outputCoords: expect.objectContaining(entry.outputCoords), + }) + ) + expect(collision).toEqual(expect.arrayContaining(matchers)) + } + }) + + it('provides accurate input coordinates in the fallback enumeration path', () => { + const renderer = new CanvasRenderer( + canvas, + createRankDeficientLayout(), + createRankDeficientParams(), + undefined, + { traversalMode: 'by-output', colorGrouping: 'thread', showCellText: false } + ) + + const cell00 = renderer.getCellInfo(0, 0) + expect(cell00).not.toBeNull() + expect([0, 2]).toContain(cell00?.threadId) + expect([0]).toContain(cell00?.registerId) + expect(cell00?.inputCoords).toMatchObject({ thread: 0, register: 0 }) + expect(cell00?.outputCoords).toMatchObject({ row: 0, col: 0 }) + + const cell31 = renderer.getCellInfo(3, 1) + expect(cell31).not.toBeNull() + expect([1, 3]).toContain(cell31?.threadId) + expect([1]).toContain(cell31?.registerId) + expect(cell31?.inputCoords).toMatchObject({ thread: 1, register: 1 }) + expect(cell31?.outputCoords).toMatchObject({ row: 3, col: 1 }) + }) + + it('supports color-by-input dimension when using the fallback path', () => { + const layout = createRankDeficientLayout() + const params = createRankDeficientParams() + const warpSpy = vi.spyOn(ColorScheme.prototype, 'getColorForWarp') + const threadSpy = vi.spyOn(ColorScheme.prototype, 'getColorForThread') + + const renderer = new CanvasRenderer( + canvas, + layout, + params, + undefined, + { + traversalMode: 'by-output', + colorGrouping: 'thread', + showCellText: false, + colorInputDimension: 'register', + } + ) + + try { + renderer.render() + const colorCalls = warpSpy.mock.calls.map(call => call[0] as number) + expect(colorCalls).toContain(0) + expect(colorCalls).toContain(1) + expect(threadSpy).not.toHaveBeenCalled() + } finally { + warpSpy.mockRestore() + threadSpy.mockRestore() + } + }) + + it('highlights overlapping cells when hovering in by-output traversal mode', () => { + const layout = createRankDeficientLayout() + const params = createRankDeficientParams() + const renderer = new CanvasRenderer( + canvas, + layout, + params, + undefined, + { traversalMode: 'by-output', colorGrouping: 'thread', showCellText: false } + ) + + const ctx = canvas.getContext('2d')! + const strokeRectSpy = vi.spyOn(ctx, 'strokeRect') + const rectSpy = vi.spyOn(canvas, 'getBoundingClientRect').mockReturnValue({ + width: canvas.width, + height: canvas.height, + top: 0, + left: 0, + right: canvas.width, + bottom: canvas.height, + x: 0, + y: 0, + toJSON: () => ({}), + }) + + try { + renderer.render() + + const sourceIndex = (renderer as unknown as { + sourcePositionIndex: Map> + }).sourcePositionIndex + + const firstEntry = sourceIndex.get('0,0') + const secondEntry = sourceIndex.get('3,1') + + expect(firstEntry).toBeDefined() + expect(firstEntry).toHaveLength(2) + expect(secondEntry).toBeDefined() + expect(secondEntry).toHaveLength(2) + + strokeRectSpy.mockClear() + const hoverAndAssert = (targetCell: [number, number], expectedCells: Array<[number, number]>) => { + const cellCenter = getCellCenterCoordinates(renderer, targetCell[0], targetCell[1]) + const event = new MouseEvent('mousemove', { + clientX: cellCenter.x, + clientY: cellCenter.y, + }) + + strokeRectSpy.mockClear() + renderer.handleMouseMove(event) + + expect(strokeRectSpy).toHaveBeenCalledTimes(expectedCells.length) + + const viewport = getRendererViewport(renderer) + const scaledCellSize = getRendererCellSize(renderer) * viewport.scale + const highlightedCells = strokeRectSpy.mock.calls + .map(call => { + const [x, y] = call + const col = Math.round((x - viewport.offsetX) / scaledCellSize) + const row = Math.round((y - viewport.offsetY) / scaledCellSize) + return `${row},${col}` + }) + .sort() + + const expected = expectedCells.map(([row, col]) => `${row},${col}`).sort() + expect(highlightedCells).toEqual(expected) + } + + if (firstEntry && secondEntry) { + hoverAndAssert(firstEntry[0]!, firstEntry) + hoverAndAssert(secondEntry[0]!, secondEntry) + } + } finally { + strokeRectSpy.mockRestore() + rectSpy.mockRestore() + } + }) }) }) diff --git a/src/visualization/CanvasRenderer.ts b/src/visualization/CanvasRenderer.ts index 57ee35c..beab541 100644 --- a/src/visualization/CanvasRenderer.ts +++ b/src/visualization/CanvasRenderer.ts @@ -35,6 +35,11 @@ export interface ResolvedPosition { sourcePos?: [number, number] } +interface OutputTraversalBuildResult { + cache: Map + maxThreadId: number +} + /** * Position resolver function type - returns positions for a given thread. * @@ -187,16 +192,39 @@ export class CanvasRenderer { } private buildCellDataCacheByOutput(): Map { - const cache = new Map() - this.sourcePositionIndex = new Map>() const outputDims = this.layout.getOutDims() if (outputDims.length === 0) { - return cache + this.sourcePositionIndex = new Map>() + this.maxThreadIdObserved = 0 + return new Map() + } + + let result: OutputTraversalBuildResult + try { + const inverse = this.layout.invert() + result = this.buildOutputCacheUsingInverse(outputDims, inverse) + } catch { + result = this.buildOutputCacheFromInputs(outputDims) + } + + this.maxThreadIdObserved = Math.max(result.maxThreadId, 0) + if (this.customColorDimension) { + this.resetColorSchemeForCustomDimension() + } else { + this.rebuildColorSchemeFromThreads(result.maxThreadId) } + return result.cache + } - const inverse = this.layout.invert() + private buildOutputCacheUsingInverse( + outputDims: Array<[string, number]>, + inverse: LinearLayout + ): OutputTraversalBuildResult { + const cache = new Map() + this.sourcePositionIndex = new Map>() const coords: Record = {} let maxThreadId = 0 + const warpSize = Math.max(this.params.threadsPerWarp[0] * this.params.threadsPerWarp[1], 1) const traverse = (index: number) => { if (index === outputDims.length) { @@ -204,21 +232,12 @@ export class CanvasRenderer { const inputCoords = inverse.apply(outputCoords) const threadId = inputCoords.thread ?? 0 const registerId = inputCoords.register ?? 0 - const warpSize = Math.max(this.params.threadsPerWarp[0] * this.params.threadsPerWarp[1], 1) const warpId = this.colorGrouping === 'thread' ? threadId : Math.floor(threadId / warpSize) maxThreadId = Math.max(maxThreadId, threadId) - const primaryDim = outputDims[0]?.[0] - const secondaryDim = outputDims[1]?.[0] - const row = secondaryDim ? (outputCoords[primaryDim ?? ''] ?? 0) : 0 - const col = secondaryDim - ? outputCoords[secondaryDim] ?? 0 - : primaryDim - ? outputCoords[primaryDim] ?? 0 - : 0 - + const [row, col] = this.resolveDisplayPosition(outputDims, outputCoords) const key = `${row},${col}` const sourcePosition: [number, number] = [row, col] const cellInfo: CellInfo = { @@ -258,13 +277,100 @@ export class CanvasRenderer { } traverse(0) - this.maxThreadIdObserved = Math.max(maxThreadId, 0) - if (this.customColorDimension) { - this.resetColorSchemeForCustomDimension() - } else { - this.rebuildColorSchemeFromThreads(maxThreadId) + return { cache, maxThreadId } + } + + private buildOutputCacheFromInputs( + outputDims: Array<[string, number]> + ): OutputTraversalBuildResult { + const cache = new Map() + this.sourcePositionIndex = new Map>() + const inputDims = this.layout.getInDimNames().map((name) => [name, this.layout.getInDimSize(name)] as [string, number]) + const coords: Record = {} + const warpSize = Math.max(this.params.threadsPerWarp[0] * this.params.threadsPerWarp[1], 1) + let maxThreadId = 0 + + const traverse = (index: number) => { + if (index === inputDims.length) { + const inputCoords = { ...coords } + const outputCoords = this.layout.apply(inputCoords) + if (!this.isOutputWithinBounds(outputCoords, outputDims)) { + return + } + + const threadId = inputCoords.thread ?? 0 + const registerId = inputCoords.register ?? 0 + const warpId = this.colorGrouping === 'thread' + ? threadId + : Math.floor(threadId / warpSize) + maxThreadId = Math.max(maxThreadId, threadId) + + const [row, col] = this.resolveDisplayPosition(outputDims, outputCoords) + const key = `${row},${col}` + const sourcePosition: [number, number] = [row, col] + const cellInfo: CellInfo = { + threadId, + registerId, + warpId, + position: [row, col], + sourcePosition, + inputCoords, + outputCoords, + } + + const existing = cache.get(key) + if (existing) { + existing.push(cellInfo) + } else { + cache.set(key, [cellInfo]) + } + + const sourceKey = `${sourcePosition[0]},${sourcePosition[1]}` + const sourceEntries = this.sourcePositionIndex.get(sourceKey) + if (sourceEntries) { + sourceEntries.push([row, col]) + } else { + this.sourcePositionIndex.set(sourceKey, [[row, col]]) + } + return + } + + const dim = inputDims[index] + if (!dim) return + const [name, size] = dim + for (let value = 0; value < size; value++) { + coords[name] = value + traverse(index + 1) + } } - return cache + + traverse(0) + return { cache, maxThreadId } + } + + private resolveDisplayPosition( + outputDims: Array<[string, number]>, + outputCoords: Record + ): [number, number] { + const primaryDim = outputDims[0]?.[0] + const secondaryDim = outputDims[1]?.[0] + const row = secondaryDim ? (outputCoords[primaryDim ?? ''] ?? 0) : 0 + const col = secondaryDim + ? outputCoords[secondaryDim] ?? 0 + : primaryDim + ? outputCoords[primaryDim] ?? 0 + : 0 + return [row, col] + } + + private isOutputWithinBounds( + outputCoords: Record, + outputDims: Array<[string, number]> + ): boolean { + return outputDims.every(([name, size]) => { + const value = outputCoords[name] + return typeof value === 'number' && value >= 0 && value < size + }) } private rebuildColorSchemeFromThreads(maxThreadId: number): void {