From 735ffc1c9fc047a76879c97dd4c1153b3651214a Mon Sep 17 00:00:00 2001 From: leeliu103 Date: Tue, 25 Nov 2025 22:43:18 +0000 Subject: [PATCH] Add LinearLayout inversion and live matrix visualization with basis calculation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit implements output-to-input coordinate mapping for LinearLayout visualization by leveraging Triton's basis-based approach: Core Features: - Add fromBitMatrix() and invert() methods using binary Gaussian elimination - Implement output-traversal mode that applies inverted layout to recover input coordinates - Add live matrix updates: canvas refreshes immediately as user edits matrix - Simplify tooltip to show input/output coordinates with correct bit widths (log2 of dimension size) - Add vertical matrix basis calculation display matching Triton's format - Fix basis column and row ordering to match matrix editor UI/UX Improvements: - Add spacing between basis columns from different inputs in tooltip for visual clarity - Remove text labels from LinearLayout canvas cells (color-coded only) to avoid dimension name mismatch - Keep text labels for WMMA/MFMA tabs which use canonical dimension names Performance & Correctness: - Fix double layout computation: eliminate redundant rebuild when dimension edits trigger matrix-change event - Align invert()/isInvertible() with Triton's "square and surjective" requirement - Remove unused columnIndex field from BasisColumnDescriptor Default parameters: output 16×16, input reg(8) + thread(32) All 115 tests passing. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/core/LinearLayout.test.ts | 48 +++ src/core/LinearLayout.ts | 405 ++++++++++++++++----- src/styles.css | 91 ++++- src/tabs/LinearLayoutTab.ts | 438 +++++++++++++++++------ src/ui/LinearLayoutMatrixEditor.ts | 40 ++- src/visualization/CanvasRenderer.test.ts | 59 +++ src/visualization/CanvasRenderer.ts | 157 +++++++- 7 files changed, 1024 insertions(+), 214 deletions(-) diff --git a/src/core/LinearLayout.test.ts b/src/core/LinearLayout.test.ts index 210ff2d..77a3341 100644 --- a/src/core/LinearLayout.test.ts +++ b/src/core/LinearLayout.test.ts @@ -33,6 +33,24 @@ describe('LinearLayout', () => { }) }) + describe('fromBitMatrix', () => { + it('should reconstruct a 1D identity layout', () => { + const matrix = [ + [1, 0], + [0, 1], + ] + const layout = LinearLayout.fromBitMatrix( + matrix, + [{ name: 'in', size: 4 }], + [{ name: 'out', size: 4 }] + ) + + for (let i = 0; i < 4; i++) { + expect(layout.apply({ in: i }).out).toBe(i) + } + }) + }) + describe('XOR linearity', () => { it('should satisfy L(a⊕b) = L(a)⊕L(b)', () => { const layout = LinearLayout.identity1D(8, 'in', 'out') @@ -215,6 +233,36 @@ describe('LinearLayout', () => { } } }) + + it('should report non-invertible layouts', () => { + const nonSquare = new LinearLayout( + [ + ['in', [[1, 0]]], + ], + [ + ['out0', 4], + ['out1', 4], + ] + ) + + expect(nonSquare.isInvertible()).toBe(false) + expect(() => nonSquare.invert()).toThrowError() + }) + + it('should reject layouts where output bits are zero but inputs are not', () => { + const broadcast = LinearLayout.zeros1D(8, 'reg', 'zero') + expect(broadcast.isInvertible()).toBe(false) + expect(() => broadcast.invert()).toThrowError(/square and surjective/i) + }) + + it('should invert layouts whose input/output spaces collapse to a single value', () => { + const trivial = LinearLayout.identity1D(1, 'in', 'out') + expect(trivial.isInvertible()).toBe(true) + const inverted = trivial.invert() + expect(inverted.apply({ out: 0 }).in).toBe(0) + expect(inverted.getOutDimNames()).toEqual(['in']) + expect(inverted.getInDimNames()).toEqual(['out']) + }) }) describe('ensureNotSmallerThan', () => { diff --git a/src/core/LinearLayout.ts b/src/core/LinearLayout.ts index dfb106e..8151946 100644 --- a/src/core/LinearLayout.ts +++ b/src/core/LinearLayout.ts @@ -107,6 +107,215 @@ function supremum(x: string[], y: string[]): string[] { return result } +interface DimensionSpec { + name: string + size: number +} + +interface BitDescriptor { + name: string + bit: number +} + +function assertPowerOfTwo(value: number, context: string): void { + if (!Number.isInteger(value) || value <= 0 || (value & (value - 1)) !== 0) { + throw new Error(`${context} must be a positive power of two, got ${value}`) + } +} + +function bitCountForSize(size: number): number { + assertPowerOfTwo(size, 'Dimension size') + return Math.trunc(Math.log2(size)) +} + +function totalBitCount(specs: DimensionSpec[]): number { + return specs.reduce((sum, spec) => sum + bitCountForSize(spec.size), 0) +} + +function normalizeDimensions( + dimensions: Array<{ name: string; size: number }>, + context: string +): DimensionSpec[] { + const seen = new Set() + return dimensions.map((dim, index) => { + const trimmedName = dim.name.trim() + if (!trimmedName) { + throw new Error(`${context} dimension at index ${index} must have a name`) + } + assertPowerOfTwo(dim.size, `${context} dimension "${trimmedName}" size`) + if (seen.has(trimmedName)) { + throw new Error(`${context} dimension "${trimmedName}" is duplicated`) + } + seen.add(trimmedName) + return { name: trimmedName, size: dim.size } + }) +} + +function buildBitDescriptorsFromSpecs(specs: DimensionSpec[]): BitDescriptor[] { + const descriptors: BitDescriptor[] = [] + for (const spec of specs) { + const bitCount = bitCountForSize(spec.size) + for (let bit = 0; bit < bitCount; bit++) { + descriptors.push({ name: spec.name, bit }) + } + } + return descriptors +} + +function matrixToBases( + matrix: number[][], + inputSpecs: DimensionSpec[], + outputSpecs: DimensionSpec[] +): Array<[string, number[][]]> { + const rowDescriptors = buildBitDescriptorsFromSpecs(outputSpecs) + const columnDescriptors = buildBitDescriptorsFromSpecs(inputSpecs) + + if (matrix.length !== rowDescriptors.length) { + throw new Error( + `Matrix row count (${matrix.length}) does not match output bit count (${rowDescriptors.length})` + ) + } + const expectedCols = columnDescriptors.length + matrix.forEach((row, rowIdx) => { + if (row.length !== expectedCols) { + throw new Error( + `Matrix column count mismatch on row ${rowIdx} (expected ${expectedCols}, got ${row.length})` + ) + } + }) + + const rowIndexMap = new Map() + rowDescriptors.forEach((descriptor, index) => { + const existing = rowIndexMap.get(descriptor.name) ?? [] + existing[descriptor.bit] = index + rowIndexMap.set(descriptor.name, existing) + }) + + const bases: Array<[string, number[][]]> = [] + let columnCursor = 0 + for (const spec of inputSpecs) { + const bitCount = bitCountForSize(spec.size) + const dimBases: number[][] = [] + for (let bit = 0; bit < bitCount; bit++) { + const columnIndex = columnCursor + bit + const basisVector: number[] = [] + for (const output of outputSpecs) { + const rowIndexes = rowIndexMap.get(output.name) ?? [] + let value = 0 + for (let bitPos = 0; bitPos < rowIndexes.length; bitPos++) { + const rowIndex = rowIndexes[bitPos] + if (rowIndex === undefined) continue + const cell = matrix[rowIndex]?.[columnIndex] ?? 0 + if (cell & 1) { + value |= 1 << bitPos + } + } + basisVector.push(value) + } + dimBases.push(basisVector) + } + bases.push([spec.name, dimBases]) + columnCursor += bitCount + } + + return bases +} + +function basesToMatrix( + bases: Map, + inputSpecs: DimensionSpec[], + outputSpecs: DimensionSpec[] +): number[][] { + const rowDescriptors = buildBitDescriptorsFromSpecs(outputSpecs) + const columnDescriptors = buildBitDescriptorsFromSpecs(inputSpecs) + const matrix = rowDescriptors.map(() => new Array(columnDescriptors.length).fill(0)) + + const rowIndexMap = new Map() + rowDescriptors.forEach((descriptor, index) => { + const existing = rowIndexMap.get(descriptor.name) ?? [] + existing[descriptor.bit] = index + rowIndexMap.set(descriptor.name, existing) + }) + + let columnCursor = 0 + for (const spec of inputSpecs) { + const dimBases = bases.get(spec.name) ?? [] + const bitCount = bitCountForSize(spec.size) + if (dimBases.length !== bitCount) { + throw new Error( + `Input dimension ${spec.name} expected ${bitCount} bases but got ${dimBases.length}` + ) + } + + for (let bit = 0; bit < bitCount; bit++) { + const columnIndex = columnCursor + bit + const basisVector = dimBases[bit] ?? [] + for (let outIdx = 0; outIdx < outputSpecs.length; outIdx++) { + const output = outputSpecs[outIdx] + if (!output) continue + const rowIndexes = rowIndexMap.get(output.name) ?? [] + const basisValue = basisVector[outIdx] ?? 0 + for (let bitPos = 0; bitPos < rowIndexes.length; bitPos++) { + const rowIndex = rowIndexes[bitPos] + if (rowIndex === undefined) continue + const matrixRow = matrix[rowIndex] + if (matrixRow) { + matrixRow[columnIndex] = (basisValue >> bitPos) & 1 + } + } + } + } + columnCursor += bitCount + } + + return matrix +} + +function invertBinaryMatrix(matrix: number[][]): number[][] { + const n = matrix.length + if (n === 0) { + return [] + } + const width = matrix[0]?.length ?? 0 + if (width !== n) { + throw new Error('Layout matrix must be square to invert') + } + + const augmented: number[][] = matrix.map((row, i) => { + if (row.length !== width) { + throw new Error('Layout matrix rows must have consistent width') + } + const left = row.map((value) => (value & 1 ? 1 : 0)) + const right = new Array(n).fill(0) + right[i] = 1 + return [...left, ...right] + }) + + for (let col = 0; col < n; col++) { + let pivot = col + while (pivot < n && augmented[pivot]?.[col] !== 1) { + pivot++ + } + if (pivot === n) { + throw new Error('Layout matrix is not invertible') + } + if (pivot !== col) { + const temp = augmented[col]! + augmented[col] = augmented[pivot]! + augmented[pivot] = temp + } + for (let row = 0; row < n; row++) { + if (row !== col && augmented[row]?.[col] === 1) { + for (let k = col; k < 2 * n; k++) { + augmented[row]![k]! ^= augmented[col]![k]! + } + } + } + } + + return augmented.map((row) => row.slice(n)) +} + /** * LinearLayout - A function mapping tuples of integers to tuples of integers * using linear algebra over GF(2) (the two-element field with XOR as addition) @@ -114,6 +323,8 @@ function supremum(x: string[], y: string[]): string[] { * Based on Triton's LinearLayout implementation in triton/include/triton/Tools/LinearLayout.h */ +type OutDimInit = string[] | Array<[string, number]> + export class LinearLayout { // bases[inDim][i] = L(0, ..., inDim=2^i, ..., 0) // Each basis is a vector of size numOutDims @@ -122,18 +333,45 @@ export class LinearLayout { // Output dimension names and their sizes private outDims: Map - constructor( - basesArray: Array<[string, number[][]]>, - outDimNames: string[] - ) { + constructor(basesArray: Array<[string, number[][]]>, outDimInit: OutDimInit) { this.bases = new Map(basesArray) this.outDims = new Map() - // Infer out-dim sizes from bases - // Find max value for each output dimension by XORing all bases - const maxValues: number[] = new Array(outDimNames.length).fill(0) + const hasExplicitSizes = Array.isArray(outDimInit[0]) + if (hasExplicitSizes) { + this.initializeExplicitOutDims(outDimInit as Array<[string, number]>) + } else { + this.initializeOutDimsFromNames(outDimInit as string[]) + } + } + + private initializeExplicitOutDims(outDims: Array<[string, number]>): void { + for (const [rawName, size] of outDims) { + const name = rawName.trim() + if (!name) { + throw new Error('Output dimension name is required') + } + if (this.outDims.has(name)) { + throw new Error(`Duplicate output dimension "${name}"`) + } + assertPowerOfTwo(size, `Output dimension "${name}" size`) + this.outDims.set(name, size) + } + } + + private initializeOutDimsFromNames(outDimNames: string[]): void { + const trimmedNames = outDimNames.map((name) => name.trim()).filter((name) => name.length > 0) + const seen = new Set() + trimmedNames.forEach((name) => { + if (seen.has(name)) { + throw new Error(`Duplicate output dimension "${name}"`) + } + seen.add(name) + this.outDims.set(name, 1) + }) - for (const [_, inDimBases] of this.bases) { + const maxValues: number[] = new Array(trimmedNames.length).fill(0) + for (const [, inDimBases] of this.bases) { for (const basis of inDimBases) { for (let i = 0; i < basis.length; i++) { const currentMax = maxValues[i] @@ -145,13 +383,14 @@ export class LinearLayout { } } - for (let i = 0; i < outDimNames.length; i++) { - const outDim = outDimNames[i] - const maxValue = maxValues[i] - if (!outDim || maxValue === undefined) continue - // Round up to next power of 2 after the max value - this.outDims.set(outDim, this.nextPowerOf2(maxValue + 1)) - } + trimmedNames.forEach((name, index) => { + const maxValue = maxValues[index] + if (maxValue === undefined) { + return + } + const inferredSize = this.nextPowerOf2(Math.max(1, maxValue + 1)) + this.outDims.set(name, inferredSize) + }) } static empty(): LinearLayout { @@ -198,6 +437,21 @@ export class LinearLayout { return new LinearLayout([[inDim, bases]], [outDim]) } + static fromBitMatrix( + matrix: number[][], + inputDimensions: Array<{ name: string; size: number }>, + outputDimensions: Array<{ name: string; size: number }> + ): LinearLayout { + const normalizedInputs = normalizeDimensions(inputDimensions, 'Input') + const normalizedOutputs = normalizeDimensions(outputDimensions, 'Output') + if (normalizedOutputs.length === 0 || normalizedInputs.length === 0) { + throw new Error('Both input and output dimensions are required to build a layout') + } + const basesArray = matrixToBases(matrix, normalizedInputs, normalizedOutputs) + const outDimPairs = normalizedOutputs.map((spec) => [spec.name, spec.size] as [string, number]) + return new LinearLayout(basesArray, outDimPairs) + } + /** * Apply the linear layout: compute L(inputs) * Uses XOR to combine bases according to input bit patterns @@ -322,78 +576,52 @@ export class LinearLayout { * Returns a layout where apply(invert().apply(x)) = x */ invert(): LinearLayout { - // For simplicity, we'll implement this for square layouts - // In a full implementation, this would use Gaussian elimination - - const inDimNames = Array.from(this.bases.keys()) - const outDimNames = Array.from(this.outDims.keys()) - - // Build the matrix - const totalInBits = inDimNames.reduce( - (sum, dim) => sum + (this.bases.get(dim)?.length ?? 0), - 0 - ) - const totalOutBits = outDimNames.reduce( - (sum, dim) => sum + Math.log2(this.outDims.get(dim) ?? 1), - 0 - ) - - if (totalInBits !== totalOutBits) { - throw new Error('Cannot invert non-square layout') + const inputSpecs = this.getInputDimensionSpecs() + const outputSpecs = this.getOutputDimensionSpecs() + if (inputSpecs.length === 0 || outputSpecs.length === 0) { + throw new Error('Cannot invert layout: input and output dimensions are required') } - // Create inverse mapping by brute force for now - // (In production, would use Gaussian elimination) - const newBases: Array<[string, number[][]]> = [] - - for (const outDim of outDimNames) { - const outSize = this.outDims.get(outDim) ?? 1 - const numBases = Math.log2(outSize) - const bases: number[][] = [] - - for (let bitIdx = 0; bitIdx < numBases; bitIdx++) { - const testOut: Record = {} - for (const od of outDimNames) { - testOut[od] = od === outDim ? (1 << bitIdx) : 0 - } - - // Find input that produces this output - const basis: number[] = [] - for (const inDim of inDimNames) { - const inSize = this.getInDimSize(inDim) - let found = 0 - - for (let testIn = 0; testIn < inSize; testIn++) { - const inputs: Record = {} - for (const id of inDimNames) { - inputs[id] = id === inDim ? testIn : 0 - } - - const result = this.apply(inputs) - let matches = true - for (const od of outDimNames) { - if (result[od] !== testOut[od]) { - matches = false - break - } - } + const inputBitCount = totalBitCount(inputSpecs) + const outputBitCount = totalBitCount(outputSpecs) + if (inputBitCount !== outputBitCount) { + throw new Error('Cannot invert layout: layout must be square and surjective (input/output bit counts differ)') + } - if (matches) { - found = testIn - break - } - } + let inverseMatrix: number[][] + if (outputBitCount === 0) { + inverseMatrix = [] + } else { + const matrix = basesToMatrix(this.bases, inputSpecs, outputSpecs) + inverseMatrix = invertBinaryMatrix(matrix) + } - basis.push(found) - } + const basesArray = matrixToBases(inverseMatrix, outputSpecs, inputSpecs) + const outDimPairs = inputSpecs.map((spec) => [spec.name, spec.size] as [string, number]) + return new LinearLayout(basesArray, outDimPairs) + } - bases.push(basis) + isInvertible(): boolean { + try { + const inputSpecs = this.getInputDimensionSpecs() + const outputSpecs = this.getOutputDimensionSpecs() + if (inputSpecs.length === 0 || outputSpecs.length === 0) { + return false } - - newBases.push([outDim, bases]) + const inputBitCount = totalBitCount(inputSpecs) + const outputBitCount = totalBitCount(outputSpecs) + if (inputBitCount !== outputBitCount) { + return false + } + if (inputBitCount === 0) { + return true + } + const matrix = basesToMatrix(this.bases, inputSpecs, outputSpecs) + invertBinaryMatrix(matrix) + return true + } catch { + return false } - - return new LinearLayout(newBases, inDimNames) } /** @@ -431,6 +659,17 @@ export class LinearLayout { return result } + private getInputDimensionSpecs(): DimensionSpec[] { + return this.getInDimNames().map((name) => ({ + name, + size: this.getInDimSize(name), + })) + } + + private getOutputDimensionSpecs(): DimensionSpec[] { + return this.getOutDims().map(([name, size]) => ({ name, size })) + } + getInDimNames(): string[] { return Array.from(this.bases.keys()) } @@ -459,11 +698,7 @@ export class LinearLayout { getOutDimSizeLog2(dim: string): number { const size = this.getOutDimSize(dim) if (size <= 0) return 0 - const log2 = Math.log2(size) - if (!Number.isInteger(log2)) { - throw new Error(`Output dimension ${dim} has non power-of-two size ${size}`) - } - return Math.trunc(log2) + return bitCountForSize(size) } /** diff --git a/src/styles.css b/src/styles.css index 7f37d9c..44603a8 100644 --- a/src/styles.css +++ b/src/styles.css @@ -417,7 +417,9 @@ button:active { border-radius: 4px; font-size: 0.8rem; line-height: 1.3; - max-width: 240px; + width: max-content; + max-width: 960px; + max-width: min(960px, calc(100vw - 32px)); box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15); z-index: 1000; } @@ -440,6 +442,93 @@ button:active { flex-shrink: 0; } +.layout-tooltip .basis-tooltip { + margin-top: 0.25rem; +} + +.layout-tooltip .basis-equation { + display: flex; + flex-direction: column; + gap: 0.4rem; + margin-top: 0.25rem; +} + +.layout-tooltip .basis-columns { + display: flex; + flex-wrap: nowrap; + align-items: flex-start; + gap: 0.35rem; + overflow-x: auto; +} + +.layout-tooltip .basis-column { + display: inline-flex; + flex-direction: column; + align-items: center; + font-family: 'JetBrains Mono', 'Fira Code', 'Roboto Mono', Consolas, monospace; + font-size: 0.75rem; + line-height: 1.15; + padding: 0.1rem 0.2rem; + border-radius: 2px; + background-color: rgba(255, 255, 255, 0.06); + min-width: 32px; +} + +.layout-tooltip .basis-column.basis-column-input-gap { + margin-left: 0.5rem; +} + +.layout-tooltip .basis-bit { + display: block; +} + +.layout-tooltip .basis-dimension-gap { + display: block; + height: 0.45rem; + width: 100%; +} + +.layout-tooltip .basis-input-bit, +.layout-tooltip .basis-multiply-symbol { + font-weight: 600; + font-size: 0.75rem; + line-height: 1.1; + margin-bottom: 0.05rem; +} + +.layout-tooltip .basis-multiply-symbol { + display: flex; + align-items: center; + justify-content: center; + font-size: 0.8rem; + margin-bottom: 0.1rem; +} + +.layout-tooltip .basis-placeholder { + visibility: hidden; +} + +.layout-tooltip .basis-result-column { + background-color: rgba(255, 255, 255, 0.12); + align-self: flex-end; +} + +.layout-tooltip .basis-column-operator { + display: inline-flex; + align-items: center; + justify-content: center; + font-weight: 600; + font-size: 0.85rem; + min-height: 100%; + padding: 0 0.1rem; + align-self: stretch; +} + +.layout-tooltip .basis-column-operator.basis-equals { + margin-left: 0.15rem; + margin-right: 0.15rem; +} + .validation-errors li, .validation-warnings li { margin: 0.25rem 0; diff --git a/src/tabs/LinearLayoutTab.ts b/src/tabs/LinearLayoutTab.ts index 5ffbfb0..b7ca939 100644 --- a/src/tabs/LinearLayoutTab.ts +++ b/src/tabs/LinearLayoutTab.ts @@ -1,6 +1,6 @@ import { LinearLayout } from '../core/LinearLayout' import type { BlockLayoutParams } from '../validation/InputValidator' -import { CanvasRenderer, type PositionResolver } from '../visualization/CanvasRenderer' +import { CanvasRenderer, type CellInfo } from '../visualization/CanvasRenderer' import { LinearLayoutMatrixEditor, type MatrixEditorDimensions } from '../ui/LinearLayoutMatrixEditor' import { renderSharedControls } from '../ui/renderSharedControls' import { CanvasTab, type CanvasTabElements } from './CanvasTab' @@ -13,19 +13,32 @@ interface LinearDimension { size: number } +interface BasisColumnDescriptor { + dimName: string + bitIndex: number + values: number[] +} + +interface OutputBitGroup { + name: string + bitWidth: number +} + +const BASIS_HEADING = 'Basis Calculation (all weighted bases are added via ⊕)' + /** * Restored Linear Layout tab that re-introduces the dimension controls while * continuing to render the simplified visualization until matrix wiring is required. */ export class LinearLayoutTab extends CanvasTab { private readonly params: BlockLayoutParams - private readonly layout: LinearLayout + private layout: LinearLayout private readonly form: HTMLFormElement private readonly sidebar: HTMLElement private readonly dimensionLists: Record private readonly dimensionAddButtons: Record - private readonly matrixButton: HTMLButtonElement private readonly matrixEditor: LinearLayoutMatrixEditor + private readonly layoutStatus: HTMLElement private dimensionState: Record constructor(tabId: string) { @@ -73,12 +86,12 @@ export class LinearLayoutTab extends CanvasTab { this.form = form this.dimensionState = { input: [ - { id: this.createDimensionId(), name: 'register', size: 4 }, + { id: this.createDimensionId(), name: 'reg', size: 8 }, { id: this.createDimensionId(), name: 'thread', size: 32 }, ], output: [ - { id: this.createDimensionId(), name: 'outdim1', size: 16 }, - { id: this.createDimensionId(), name: 'outdim2', size: 256 }, + { id: this.createDimensionId(), name: 'dim0', size: 16 }, + { id: this.createDimensionId(), name: 'dim1', size: 16 }, ], } @@ -102,59 +115,39 @@ export class LinearLayoutTab extends CanvasTab { input: inputAdd, output: outputAdd, } - this.matrixButton = matrixButton this.matrixEditor = new LinearLayoutMatrixEditor() + this.layoutStatus = document.createElement('div') + this.layoutStatus.className = 'layout-status' + this.layoutStatus.setAttribute('role', 'status') + matrixButton.insertAdjacentElement('afterend', this.layoutStatus) this.matrixEditor.onVisibilityChange((isOpen) => { this.toggleSidebarInteractivity(isOpen) }) - - inputAdd.addEventListener('click', () => this.addDimension('input')) - outputAdd.addEventListener('click', () => this.addDimension('output')) - matrixButton.addEventListener('click', () => this.handleMatrixEditorClick()) - - this.renderDimensionRows('input', { showErrors: false }) - this.renderDimensionRows('output', { showErrors: false }) - this.renderOperationsInfo(controlsContainer) + this.matrixEditor.onMatrixChange((matrix) => { + this.rebuildLayoutFromMatrix(matrix) + }) this.params = { sizePerThread: [1, 1], - threadsPerWarp: [8, 4], - warpsPerCTA: [2, 1], + threadsPerWarp: [32, 1], + warpsPerCTA: [1, 1], order: [0, 1], - tensorShape: [8, 8], + tensorShape: [1, 1], } - const totalThreads = - this.params.threadsPerWarp[0] * - this.params.threadsPerWarp[1] * - this.params.warpsPerCTA[0] * - this.params.warpsPerCTA[1] + this.layout = LinearLayout.empty() - this.layout = LinearLayout.identity1D(totalThreads, 'thread', 'logical') + this.resizeCanvas() - this.initializeRenderer() - } + inputAdd.addEventListener('click', () => this.addDimension('input')) + outputAdd.addEventListener('click', () => this.addDimension('output')) + matrixButton.addEventListener('click', () => this.handleMatrixEditorClick()) - private initializeRenderer(): void { - const positionResolver: PositionResolver = (_layout, threadId) => { - const columns = this.params.tensorShape[1] - const row = Math.floor(threadId / columns) - const column = threadId % columns - return [ - { - pos: [row, column], - registerId: threadId, - sourcePos: [row, column], - }, - ] - } + this.renderDimensionRows('input') + this.renderDimensionRows('output') + this.renderOperationsInfo(controlsContainer) - this.resizeCanvas() - const renderer = new CanvasRenderer(this.canvas, this.layout, this.params, positionResolver, { - colorGrouping: 'thread', - }) - this.setRenderer(renderer) - renderer.render() + this.updateRendererFromLayout() } protected handleHover(event: MouseEvent): void { @@ -175,15 +168,13 @@ export class LinearLayoutTab extends CanvasTab { return } - const warpColor = renderer.getWarpColor(cellInfo.warpId) + const inputSection = this.formatCoordinateSection(cellInfo.inputCoords, 'input') + const outputSection = this.formatCoordinateSection(cellInfo.outputCoords, 'output') + const basisSection = this.buildBasisSection(cellInfo) const tooltipContent = ` -
Position: (${cellInfo.position[0]}, ${cellInfo.position[1]})
-
- Warp: ${cellInfo.warpId} - -
-
Thread: ${cellInfo.threadId}
-
Register: ${cellInfo.registerId}
+
Input${inputSection}
+
Output${outputSection}
+ ${basisSection} ` this.tooltip.show(tooltipContent, event.clientX, event.clientY) @@ -212,10 +203,7 @@ export class LinearLayoutTab extends CanvasTab { ` } - private renderDimensionRows( - type: DimensionType, - options: { showErrors?: boolean } = { showErrors: true } - ): void { + private renderDimensionRows(type: DimensionType): void { const list = this.dimensionLists[type] list.innerHTML = '' @@ -232,7 +220,7 @@ export class LinearLayoutTab extends CanvasTab { nameInput.value = dimension.name nameInput.addEventListener('input', (event) => { dimension.name = (event.target as HTMLInputElement).value - this.refreshValidation(true) + this.handleDimensionFieldChange() }) nameLabel.appendChild(nameInput) @@ -245,7 +233,7 @@ export class LinearLayoutTab extends CanvasTab { sizeInput.addEventListener('input', (event) => { const value = Number((event.target as HTMLInputElement).value) dimension.size = Number.isFinite(value) ? value : 0 - this.refreshValidation(true) + this.handleDimensionFieldChange() }) sizeLabel.appendChild(sizeInput) @@ -260,20 +248,15 @@ export class LinearLayoutTab extends CanvasTab { this.removeDimension(type, dimension.id) }) - const error = document.createElement('div') - error.className = 'dimension-error' - error.dataset.errorFor = dimension.id - row.appendChild(nameLabel) row.appendChild(sizeLabel) row.appendChild(removeButton) - row.appendChild(error) list.appendChild(row) }) this.updateAddButtonState(type) - this.refreshValidation(options.showErrors ?? true) + this.syncEditorAndLayout() } private toggleSidebarInteractivity(isLocked: boolean): void { @@ -285,60 +268,82 @@ export class LinearLayoutTab extends CanvasTab { } } - private refreshValidation(showErrors: boolean): void { - const inputsValid = this.dimensionState.input.every((dimension) => - this.validateDimension(dimension, showErrors) - ) - const outputsValid = this.dimensionState.output.every((dimension) => - this.validateDimension(dimension, showErrors) - ) + private handleDimensionFieldChange(): void { + this.syncEditorAndLayout() + } - const canEditMatrix = inputsValid && outputsValid && this.dimensionState.output.length > 0 - this.matrixButton.disabled = !canEditMatrix + private syncEditorAndLayout(): void { + const emittedMatrixChange = this.matrixEditor.updateDimensions(this.getMatrixDimensions()) + if (!emittedMatrixChange) { + this.rebuildLayoutFromMatrix() + } } - private validateDimension(dimension: LinearDimension, showErrors: boolean): boolean { - let errorMessage = '' - const trimmedName = dimension.name.trim() - if (!trimmedName) { - errorMessage = 'Name is required.' - } else if (!Number.isInteger(dimension.size)) { - errorMessage = 'Size must be an integer.' - } else if (dimension.size < 2) { - errorMessage = 'Size must be at least 2.' - } else if (!this.isPowerOfTwo(dimension.size)) { - errorMessage = 'Size must be a power of two.' + private rebuildLayoutFromMatrix(matrixSnapshot?: number[][]): void { + if (this.dimensionState.input.length === 0 || this.dimensionState.output.length === 0) { + this.setLayoutStatus('Add at least one input and one output dimension.') + return } - if (showErrors) { - this.setRowError(dimension.id, errorMessage) - } else if (errorMessage === '') { - this.clearRowError(dimension.id) + const matrix = matrixSnapshot ?? this.matrixEditor.getMatrix() + if (matrix.length === 0 || (matrix[0]?.length ?? 0) === 0) { + this.setLayoutStatus('Open the matrix editor to configure the layout matrix.') + return } - return errorMessage === '' + const dimensions = this.getMatrixDimensions() + const inputs = dimensions.input.map(({ name, size }) => ({ name: name.trim(), size })) + const outputs = dimensions.output.map(({ name, size }) => ({ name: name.trim(), size })) + + try { + const layout = LinearLayout.fromBitMatrix(matrix, inputs, outputs) + if (!layout.isInvertible()) { + throw new Error('Layout matrix is not invertible. Ensure input/output bit counts match and the matrix has full rank.') + } + this.layout = layout + this.setLayoutStatus('') + this.updateRendererFromLayout() + } catch (error) { + const message = error instanceof Error ? error.message : String(error) + this.setLayoutStatus(message) + } } - private setRowError(dimensionId: string, message: string): void { - const errorElement = this.sidebar.querySelector( - `.dimension-error[data-error-for="${dimensionId}"]` - ) - if (!errorElement) { + private updateRendererFromLayout(): void { + const tensorShape = this.deriveTensorShape() + this.params.tensorShape = tensorShape + + const renderer = this.getRenderer() + if (renderer) { + renderer.updateLayout(this.layout, this.params) return } - errorElement.textContent = message - errorElement.classList.toggle('visible', Boolean(message)) - } - private clearRowError(dimensionId: string): void { - const errorElement = this.sidebar.querySelector( - `.dimension-error[data-error-for="${dimensionId}"]` + const newRenderer = new CanvasRenderer( + this.canvas, + this.layout, + this.params, + undefined, + { colorGrouping: 'thread', traversalMode: 'by-output', showCellText: false } ) - if (!errorElement) { - return + this.setRenderer(newRenderer) + newRenderer.render() + } + + private deriveTensorShape(): [number, number] { + const [first, second] = this.layout.getOutDims() + if (first && second) { + return [first[1], second[1]] + } + if (first) { + return [1, first[1]] } - errorElement.textContent = '' - errorElement.classList.remove('visible') + return this.params.tensorShape + } + + private setLayoutStatus(message: string): void { + this.layoutStatus.textContent = message + this.layoutStatus.classList.toggle('visible', Boolean(message)) } private addDimension(type: DimensionType): void { @@ -352,7 +357,7 @@ export class LinearLayoutTab extends CanvasTab { size: 2, } this.dimensionState[type] = [...this.dimensionState[type], newDimension] - this.renderDimensionRows(type, { showErrors: false }) + this.renderDimensionRows(type) } private removeDimension(type: DimensionType, id: string): void { @@ -360,7 +365,7 @@ export class LinearLayoutTab extends CanvasTab { return } this.dimensionState[type] = this.dimensionState[type].filter((dimension) => dimension.id !== id) - this.renderDimensionRows(type, { showErrors: true }) + this.renderDimensionRows(type) } private updateAddButtonState(type: DimensionType): void { @@ -401,18 +406,225 @@ export class LinearLayoutTab extends CanvasTab { return `dim-${Math.random().toString(36).slice(2, 10)}` } - private isPowerOfTwo(value: number): boolean { - return value > 0 && (value & (value - 1)) === 0 - } - private handleMatrixEditorClick(): void { this.matrixEditor.open(this.getMatrixDimensions()) } private getMatrixDimensions(): MatrixEditorDimensions { return { - input: this.dimensionState.input.map((dimension) => ({ ...dimension })), - output: this.dimensionState.output.map((dimension) => ({ ...dimension })), + input: this.dimensionState.input.map((dimension) => ({ + ...dimension, + name: dimension.name.trim(), + })), + output: this.dimensionState.output.map((dimension) => ({ + ...dimension, + name: dimension.name.trim(), + })), + } + } + + private formatCoordinateSection( + coords: Record | undefined, + type: DimensionType + ): string { + const dimensions = this.dimensionState[type] + if (!coords || dimensions.length === 0) { + return '
n/a
' + } + + return dimensions + .map((dimension) => { + const key = dimension.name.trim() || dimension.name + const value = coords[key] ?? 0 + const bitWidth = Math.max(1, this.getBitWidthFromSize(dimension.size)) + return `
${key || 'unnamed'}: ${this.formatValueWithBinary(value, bitWidth)}
` + }) + .join('') + } + + private buildBasisSection(cellInfo: CellInfo): string { + if (!cellInfo.inputCoords || !cellInfo.outputCoords) { + return this.renderEmptyBasisSection() + } + + const outDims = this.layout.getOutDims() + if (outDims.length === 0) { + return this.renderEmptyBasisSection() + } + + const bases = this.layout.getBases() + const columns = this.buildOrderedBasisColumns(bases) + + if (columns.length === 0) { + return this.renderEmptyBasisSection() + } + + const outputGroups: OutputBitGroup[] = outDims.map(([name, size]) => ({ + name, + bitWidth: Math.max(1, this.getBitWidthFromSize(size)), + })) + + const totalRows = outputGroups.reduce((total, group) => total + group.bitWidth, 0) + if (totalRows === 0) { + return this.renderEmptyBasisSection() + } + + const basisColumns = columns.map((column, index) => { + const prevDimName = index > 0 ? columns[index - 1]?.dimName : undefined + const hasInputGap = index > 0 && column.dimName !== prevDimName + return this.renderBasisColumn( + column, + outputGroups, + cellInfo.inputCoords as Record, + hasInputGap + ) + }) + + if (basisColumns.length === 0) { + return this.renderEmptyBasisSection() + } + + const resultColumn = this.renderBasisResultColumn( + cellInfo.outputCoords as Record, + outputGroups + ) + const columnMarkup = basisColumns.join('') + + return ` +
+ ${BASIS_HEADING} +
+
+ ${columnMarkup} + = + ${resultColumn} +
+
+
+ ` + } + + private renderEmptyBasisSection(): string { + return `
${BASIS_HEADING}
n/a
` + } + + private buildOrderedBasisColumns( + bases: ReadonlyMap + ): BasisColumnDescriptor[] { + const orderedColumns: BasisColumnDescriptor[] = [] + const processed = new Set() + + const appendColumns = (rawName: string | undefined): void => { + const dimName = rawName?.trim() || rawName || '' + if (!dimName || processed.has(dimName)) { + return + } + const dimBases = bases.get(dimName) + if (!dimBases || dimBases.length === 0) { + processed.add(dimName) + return + } + dimBases.forEach((values, bitIndex) => { + orderedColumns.push({ + dimName, + bitIndex, + values: values ?? [], + }) + }) + processed.add(dimName) + } + + this.dimensionState.input + .map((dimension) => { + const trimmed = dimension.name.trim() + return trimmed || dimension.name + }) + .filter((name): name is string => Boolean(name)) + .forEach((name) => appendColumns(name)) + + this.layout.getInDimNames().forEach((name) => { + appendColumns(name) + }) + + return orderedColumns + } + + private renderBasisColumn( + column: BasisColumnDescriptor, + groups: OutputBitGroup[], + inputCoords: Record, + hasInputGap = false + ): string { + const inputValue = inputCoords[column.dimName] ?? 0 + const bitValue = (inputValue >> column.bitIndex) & 1 + const label = `${column.dimName || 'input'} bit ${column.bitIndex}` + const bitRows = this.renderBitRows(column.values, groups) + const columnClasses = ['basis-column'] + if (hasInputGap) { + columnClasses.push('basis-column-input-gap') + } + return ` + + ${bitValue} + × + ${bitRows} + + ` + } + + private renderBasisResultColumn( + outputCoords: Record, + groups: OutputBitGroup[] + ): string { + const values = groups.map(({ name }) => outputCoords[name] ?? 0) + const bitRows = this.renderBitRows(values, groups) + return ` + + + + ${bitRows} + + ` + } + + private renderBitRows(values: number[], groups: OutputBitGroup[]): string { + const rows: string[] = [] + groups.forEach(({ bitWidth }, idx) => { + const rawValue = values[idx] ?? 0 + const normalized = Number.isFinite(rawValue) ? Math.max(0, Math.trunc(rawValue)) : 0 + for (let bit = 0; bit < Math.max(1, bitWidth); bit++) { + const bitValue = (normalized >> bit) & 1 + rows.push(`${bitValue}`) + } + if (idx < groups.length - 1) { + rows.push('') + } + }) + return rows.join('') + } + + private formatValueWithBinary(value: number, bitWidth: number): string { + const normalized = Number.isFinite(value) ? Math.max(0, Math.trunc(value)) : 0 + const binary = this.formatBinaryValue(normalized, bitWidth) + return `${normalized}(0b${binary})` + } + + private formatBinaryValue(value: number, bitWidth: number): string { + const normalized = Number.isFinite(value) ? Math.max(0, Math.trunc(value)) : 0 + const digits = normalized.toString(2) + const width = Math.max(bitWidth, digits.length, 1) + return digits.padStart(width, '0') + } + + private getBitWidthFromSize(size: number): number { + if (!Number.isFinite(size) || size < 2) { + return 1 + } + const log2 = Math.log2(size) + if (!Number.isFinite(log2) || log2 <= 0) { + return 1 } + const exactWidth = Number.isInteger(log2) ? log2 : Math.ceil(log2) + return Math.max(1, exactWidth) } } diff --git a/src/ui/LinearLayoutMatrixEditor.ts b/src/ui/LinearLayoutMatrixEditor.ts index 802e6e4..ad3283b 100644 --- a/src/ui/LinearLayoutMatrixEditor.ts +++ b/src/ui/LinearLayoutMatrixEditor.ts @@ -77,6 +77,7 @@ export class LinearLayoutMatrixEditor { private readonly keydownHandler: (event: KeyboardEvent) => void private readonly viewportResizeHandler: () => void private readonly visibilityListeners = new Set<(isOpen: boolean) => void>() + private readonly matrixListeners = new Set<(matrix: number[][]) => void>() private readonly autoFitClass = 'matrix-auto-fit' private forwardedHoverTarget: HTMLCanvasElement | null = null @@ -272,12 +273,12 @@ export class LinearLayoutMatrixEditor { /** * Update the internal dimension snapshot without showing the modal. */ - public updateDimensions(dimensions: MatrixEditorDimensions): void { + public updateDimensions(dimensions: MatrixEditorDimensions): boolean { this.currentDimensions = { input: dimensions.input.map((dim) => ({ ...dim })), output: dimensions.output.map((dim) => ({ ...dim })), } - const shapeChanged = this.rebuildMatrixIfNeeded() + const { shapeChanged, matrixUpdated } = this.rebuildMatrixIfNeeded() if (shapeChanged) { this.autoFitToMatrix() } @@ -286,6 +287,7 @@ export class LinearLayoutMatrixEditor { } else { this.needsRender = true } + return matrixUpdated } /** @@ -340,6 +342,13 @@ export class LinearLayoutMatrixEditor { return this.matrixValues.map((row) => [...row]) } + public onMatrixChange(listener: (matrix: number[][]) => void): () => void { + this.matrixListeners.add(listener) + return () => { + this.matrixListeners.delete(listener) + } + } + private ensureMatrixRendered(): void { if (!this.needsRender) { this.scheduleCellSizeUpdate() @@ -351,7 +360,7 @@ export class LinearLayoutMatrixEditor { this.scheduleCellSizeUpdate() } - private rebuildMatrixIfNeeded(): boolean { + private rebuildMatrixIfNeeded(): { shapeChanged: boolean; matrixUpdated: boolean } { const rowBits = this.buildBitDescriptors(this.currentDimensions.output) const columnBits = this.buildBitDescriptors(this.currentDimensions.input) const signature = this.computeSignature(this.currentDimensions) @@ -359,16 +368,20 @@ export class LinearLayoutMatrixEditor { rowBits.length !== this.matrixValues.length || (this.matrixValues[0]?.length ?? 0) !== columnBits.length + let matrixUpdated = false if (signature !== this.signature || shapeChanged) { this.matrixValues = Array.from({ length: rowBits.length }, () => Array.from({ length: columnBits.length }, () => 0) ) this.signature = signature + this.seedDefaultMatrix(rowBits.length, columnBits.length) + this.notifyMatrixChange() + matrixUpdated = true } this.rowBits = rowBits this.columnBits = columnBits - return shapeChanged + return { shapeChanged, matrixUpdated } } private renderMatrix(): void { @@ -450,6 +463,7 @@ export class LinearLayoutMatrixEditor { button.textContent = next.toString() button.setAttribute('aria-pressed', next === 1 ? 'true' : 'false') button.classList.toggle('active', next === 1) + this.notifyMatrixChange() } private handleOverlayMouseMove(event: MouseEvent): void { @@ -642,6 +656,24 @@ export class LinearLayoutMatrixEditor { return `${serialize(dimensions.input)}->${serialize(dimensions.output)}` } + private seedDefaultMatrix(rows: number, cols: number): void { + const limit = Math.min(rows, cols) + for (let idx = 0; idx < limit; idx++) { + const row = this.matrixValues[idx] + if (row) { + row[idx] = 1 + } + } + } + + private notifyMatrixChange(): void { + if (this.matrixListeners.size === 0) { + return + } + const snapshot = this.getMatrix() + this.matrixListeners.forEach((listener) => listener(snapshot)) + } + private initializeDialogFrame(): void { this.autoFitToMatrix({ recenter: true }) } diff --git a/src/visualization/CanvasRenderer.test.ts b/src/visualization/CanvasRenderer.test.ts index 8e839af..5c25edf 100644 --- a/src/visualization/CanvasRenderer.test.ts +++ b/src/visualization/CanvasRenderer.test.ts @@ -2,6 +2,7 @@ import { describe, it, expect, beforeEach, vi } from 'vitest' import { CanvasRenderer, type PositionResolver } from './CanvasRenderer' import { createBlockLayout } from '../layouts/BlockLayout' import type { BlockLayoutParams } from '../validation/InputValidator' +import { LinearLayout } from '../core/LinearLayout' describe('CanvasRenderer', () => { let canvas: HTMLCanvasElement @@ -66,6 +67,24 @@ describe('CanvasRenderer', () => { }) }) + describe('cell text labels', () => { + it('should omit text when disabled', () => { + const ctx = canvas.getContext('2d')! + const fillTextSpy = vi.spyOn(ctx, 'fillText') + const textlessRenderer = new CanvasRenderer( + canvas, + layout, + params, + undefined, + { showCellText: false } + ) + + textlessRenderer.render() + expect(fillTextSpy).not.toHaveBeenCalled() + fillTextSpy.mockRestore() + }) + }) + describe('color grouping', () => { it('should assign unique colors per thread when requested', () => { const threadParams: BlockLayoutParams = { @@ -321,4 +340,44 @@ describe('CanvasRenderer', () => { expect(() => extendedRenderer.render()).not.toThrow() }) }) + + describe('output traversal mode', () => { + it('resolves cells directly from output coordinates', () => { + const tensorParams: BlockLayoutParams = { + sizePerThread: [1, 1] as [number, number], + threadsPerWarp: [4, 1] as [number, number], + warpsPerCTA: [1, 1] as [number, number], + order: [0, 1] as [number, number], + tensorShape: [4, 4] as [number, number], + } + + const rowLayout = LinearLayout.identity1D(4, 'thread', 'row') + const colLayout = LinearLayout.identity1D(4, 'register', 'col') + const layout = rowLayout.multiply(colLayout) + + const outputRenderer = new CanvasRenderer( + canvas, + layout, + tensorParams, + undefined, + { traversalMode: 'by-output', colorGrouping: 'thread' } + ) + + const cell = outputRenderer.getCellInfo(0, 0) + expect(cell).not.toBeNull() + expect(cell?.outputCoords?.row).toBe(0) + expect(cell?.outputCoords?.col).toBe(0) + expect(cell?.inputCoords?.thread).toBe(0) + expect(cell?.inputCoords?.register).toBe(0) + expect(typeof outputRenderer.getThreadColor(cell!.threadId)).toBe('string') + + const farCell = outputRenderer.getCellInfo(3, 3) + expect(farCell).not.toBeNull() + expect(farCell?.outputCoords?.row).toBe(3) + expect(farCell?.outputCoords?.col).toBe(3) + + const cacheSize = (outputRenderer as unknown as { cellDataCache: Map }).cellDataCache.size + expect(cacheSize).toBe(tensorParams.tensorShape[0] * tensorParams.tensorShape[1]) + }) + }) }) diff --git a/src/visualization/CanvasRenderer.ts b/src/visualization/CanvasRenderer.ts index db25db0..dde98d6 100644 --- a/src/visualization/CanvasRenderer.ts +++ b/src/visualization/CanvasRenderer.ts @@ -5,6 +5,13 @@ import { ColorScheme } from './ColorScheme' import { ViewportController } from './ViewportController' type ColorGroupingMode = 'warp' | 'thread' +type TraversalMode = 'by-thread' | 'by-output' + +interface CanvasRendererOptions { + colorGrouping?: ColorGroupingMode + traversalMode?: TraversalMode + showCellText?: boolean +} export interface CellInfo { threadId: number @@ -17,6 +24,8 @@ export interface CellInfo { warpId: number position: [number, number] sourcePosition: [number, number] + inputCoords?: Record + outputCoords?: Record } export interface ResolvedPosition { @@ -49,7 +58,7 @@ export type PositionResolver = ( export class CanvasRenderer { private ctx: CanvasRenderingContext2D private viewportController: ViewportController - private colorScheme: ColorScheme + private colorScheme: ColorScheme = new ColorScheme(1, 1) private cellSize = 50 // Base size of each cell in pixels private isDragging = false private lastMouseX = 0 @@ -62,13 +71,16 @@ export class CanvasRenderer { // Position resolver function private positionResolver: PositionResolver + private traversalMode: TraversalMode + private colorGrouping: ColorGroupingMode + private showCellText: boolean constructor( private canvas: HTMLCanvasElement, private layout: LinearLayout, private params: BlockLayoutParams, positionResolver?: PositionResolver, - options?: { colorGrouping?: ColorGroupingMode } + options?: CanvasRendererOptions ) { const ctx = canvas.getContext('2d') if (!ctx) { @@ -77,14 +89,10 @@ export class CanvasRenderer { this.ctx = ctx // Calculate total number of warps and warp size - const totalWarps = params.warpsPerCTA[0] * params.warpsPerCTA[1] - const warpSize = params.threadsPerWarp[0] * params.threadsPerWarp[1] - const totalThreads = totalWarps * warpSize - const colorGrouping: ColorGroupingMode = options?.colorGrouping ?? 'warp' - this.colorScheme = - colorGrouping === 'thread' - ? new ColorScheme(totalThreads, 1) - : new ColorScheme(totalWarps, warpSize) + this.colorGrouping = options?.colorGrouping ?? 'warp' + this.traversalMode = options?.traversalMode ?? 'by-thread' + this.showCellText = options?.showCellText ?? true + this.resetColorSchemeFromParams() // Initialize viewport controller this.viewportController = new ViewportController( @@ -115,6 +123,13 @@ export class CanvasRenderer { * Build a cache mapping each cell position to its thread/register info */ private buildCellDataCache(): Map { + if (this.traversalMode === 'by-output') { + return this.buildCellDataCacheByOutput() + } + return this.buildCellDataCacheByThread() + } + + private buildCellDataCacheByThread(): Map { const cache = new Map() this.sourcePositionIndex = new Map>() const totalThreads = this.params.threadsPerWarp[0] * @@ -165,6 +180,103 @@ export class CanvasRenderer { return cache } + private buildCellDataCacheByOutput(): Map { + const cache = new Map() + this.sourcePositionIndex = new Map>() + const outputDims = this.layout.getOutDims() + if (outputDims.length === 0) { + return cache + } + + const inverse = this.layout.invert() + const coords: Record = {} + let maxThreadId = 0 + + const traverse = (index: number) => { + if (index === outputDims.length) { + const outputCoords = { ...coords } + const inputCoords = inverse.apply(outputCoords) + const threadId = inputCoords.thread ?? 0 + const registerId = inputCoords.register ?? 0 + const warpSize = Math.max(this.params.threadsPerWarp[0] * this.params.threadsPerWarp[1], 1) + const warpId = this.colorGrouping === 'thread' + ? threadId + : Math.floor(threadId / warpSize) + maxThreadId = Math.max(maxThreadId, threadId) + + const primaryDim = outputDims[0]?.[0] + const secondaryDim = outputDims[1]?.[0] + const row = secondaryDim ? (outputCoords[primaryDim ?? ''] ?? 0) : 0 + const col = secondaryDim + ? outputCoords[secondaryDim] ?? 0 + : primaryDim + ? outputCoords[primaryDim] ?? 0 + : 0 + + const key = `${row},${col}` + const sourcePosition: [number, number] = [row, col] + const cellInfo: CellInfo = { + threadId, + registerId, + warpId, + position: [row, col], + sourcePosition, + inputCoords, + outputCoords, + } + + const existing = cache.get(key) + if (existing) { + existing.push(cellInfo) + } else { + cache.set(key, [cellInfo]) + } + + const sourceKey = `${sourcePosition[0]},${sourcePosition[1]}` + const sourceEntries = this.sourcePositionIndex.get(sourceKey) + if (sourceEntries) { + sourceEntries.push([row, col]) + } else { + this.sourcePositionIndex.set(sourceKey, [[row, col]]) + } + return + } + + const dim = outputDims[index] + if (!dim) return + const [name, size] = dim + for (let value = 0; value < size; value++) { + coords[name] = value + traverse(index + 1) + } + } + + traverse(0) + this.rebuildColorSchemeFromThreads(maxThreadId) + return cache + } + + private rebuildColorSchemeFromThreads(maxThreadId: number): void { + const warpSize = Math.max(this.params.threadsPerWarp[0] * this.params.threadsPerWarp[1], 1) + if (this.colorGrouping === 'thread') { + const totalThreads = Math.max(maxThreadId + 1, 1) + this.colorScheme = new ColorScheme(totalThreads, 1) + } else { + const totalWarps = Math.max(Math.ceil((maxThreadId + 1) / warpSize), 1) + this.colorScheme = new ColorScheme(totalWarps, warpSize) + } + } + + private resetColorSchemeFromParams(): void { + const totalWarps = Math.max(this.params.warpsPerCTA[0] * this.params.warpsPerCTA[1], 1) + const warpSize = Math.max(this.params.threadsPerWarp[0] * this.params.threadsPerWarp[1], 1) + if (this.colorGrouping === 'thread') { + this.colorScheme = new ColorScheme(Math.max(totalWarps * warpSize, 1), 1) + } else { + this.colorScheme = new ColorScheme(totalWarps, warpSize) + } + } + /** * Render the entire visualization */ @@ -240,7 +352,7 @@ export class CanvasRenderer { this.ctx.fillRect(x, y, size, size) // Draw text if cell is large enough - if (size > 20) { + if (this.showCellText && size > 20) { this.drawCellText(cellInfo, x, y, size) } } @@ -380,6 +492,10 @@ export class CanvasRenderer { return this.colorScheme.getColorForWarp(warpId) } + getThreadColor(threadId: number): string { + return this.colorScheme.getColorForThread(threadId) + } + /** * Zoom in */ @@ -418,4 +534,23 @@ export class CanvasRenderer { screenToGrid(x: number, y: number): { row: number; col: number } { return this.viewportController.screenToGrid(x, y) } + + updateLayout(layout: LinearLayout, params?: BlockLayoutParams): void { + this.layout = layout + if (params) { + this.params = params + this.viewportController = new ViewportController( + this.canvas, + this.params.tensorShape[1], + this.params.tensorShape[0] + ) + } + + if (this.traversalMode !== 'by-output') { + this.resetColorSchemeFromParams() + } + + this.cellDataCache = this.buildCellDataCache() + this.render() + } }