From e667264d24a966b5bb7ccbadefa7cb4e82f03430 Mon Sep 17 00:00:00 2001 From: leeliu103 Date: Mon, 8 Dec 2025 17:32:06 +0000 Subject: [PATCH] Add cross-tab layout projection from Block Layout to Linear Layout MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements a projection system that allows users to export Block Layout configurations into the Linear Layout tab for further analysis and visualization. Key additions: - LayoutProjectionBus: Event-driven communication system for cross-tab layout transfer - "Show in Linear Layout" button in Block Layout tab with automatic dimension filtering - filterSnapshotDimensions: Filters out size-1 dimensions during projection - toMatrixSnapshot/reorderOutputs methods in LinearLayout for snapshot export - importLayoutSnapshot in LinearLayoutTab with metadata-based defaults - Comprehensive test coverage for projection flow, dimension filtering, and round-trip conversions The projection automatically normalizes dimension names (e.g., "lane"), filters size-1 dimensions, and preserves metadata like source type, tensor shape, and color dimension preferences. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- index.html | 1 + src/core/LinearLayout.ts | 120 ++++++++++ src/core/filterSnapshotDimensions.test.ts | 96 ++++++++ src/core/filterSnapshotDimensions.ts | 111 +++++++++ src/integration/LayoutProjectionBus.test.ts | 110 +++++++++ src/integration/LayoutProjectionBus.ts | 59 +++++ src/layouts/BlockLayout.test.ts | 120 ++++++++++ src/layouts/BlockLayout.ts | 3 +- src/main.tabs.test.ts | 133 ++++++++++- src/main.ts | 15 +- src/tabs/BlockLayoutTab.test.ts | 237 +++++++++++++++++++ src/tabs/BlockLayoutTab.ts | 84 +++++++ src/tabs/LinearLayoutTab.test.ts | 233 ++++++++++++++++++ src/tabs/LinearLayoutTab.ts | 121 +++++++++- src/ui/LinearLayoutMatrixEditor.ts | 66 +++++- src/validation/LinearLayoutValidator.test.ts | 28 ++- 16 files changed, 1516 insertions(+), 21 deletions(-) create mode 100644 src/core/filterSnapshotDimensions.test.ts create mode 100644 src/core/filterSnapshotDimensions.ts create mode 100644 src/integration/LayoutProjectionBus.test.ts create mode 100644 src/integration/LayoutProjectionBus.ts create mode 100644 src/tabs/BlockLayoutTab.test.ts create mode 100644 src/tabs/LinearLayoutTab.test.ts diff --git a/index.html b/index.html index 2e83c08..78ac43f 100644 --- a/index.html +++ b/index.html @@ -100,6 +100,7 @@

Tensor Shape

+
diff --git a/src/core/LinearLayout.ts b/src/core/LinearLayout.ts index 35a793f..603ea72 100644 --- a/src/core/LinearLayout.ts +++ b/src/core/LinearLayout.ts @@ -117,6 +117,35 @@ interface BitDescriptor { bit: number } +export interface LayoutSnapshotMetadata { + sourceLayoutType: string + sourceTabId?: string + tensorShape?: number[] + description?: string + generatedAt?: string + colorInputDimension?: string +} + +export interface LayoutMatrixSnapshot { + inputDimensions: Array<{ name: string; size: number }> + outputDimensions: Array<{ name: string; size: number }> + matrix: number[][] + metadata?: LayoutSnapshotMetadata +} + +function cloneMetadata(metadata?: LayoutSnapshotMetadata): LayoutSnapshotMetadata | undefined { + if (!metadata) { + return undefined + } + const cloned: LayoutSnapshotMetadata = { + ...metadata, + } + if (metadata.tensorShape) { + cloned.tensorShape = [...metadata.tensorShape] + } + return cloned +} + function assertPowerOfTwo(value: number, context: string): void { if (!Number.isInteger(value) || value <= 0 || (value & (value - 1)) !== 0) { throw new Error(`${context} must be a positive power of two, got ${value}`) @@ -475,6 +504,26 @@ export class LinearLayout { return new LinearLayout(basesArray, outDimPairs) } + toMatrixSnapshot(metadata?: LayoutSnapshotMetadata): LayoutMatrixSnapshot { + const inputSpecs = this.getInputDimensionSpecs() + const outputSpecs = this.getOutputDimensionSpecs() + if (inputSpecs.length === 0 || outputSpecs.length === 0) { + return { + inputDimensions: [], + outputDimensions: [], + matrix: [], + metadata: cloneMetadata(metadata), + } + } + const matrix = basesToMatrix(this.bases, inputSpecs, outputSpecs) + return { + inputDimensions: inputSpecs.map((spec) => ({ name: spec.name, size: spec.size })), + outputDimensions: outputSpecs.map((spec) => ({ name: spec.name, size: spec.size })), + matrix, + metadata: cloneMetadata(metadata), + } + } + /** * Apply the linear layout: compute L(inputs) * Uses XOR to combine bases according to input bit patterns @@ -682,6 +731,77 @@ export class LinearLayout { return result } + /** + * Return a layout whose output dimensions follow the preferred order. + * Any dimensions not mentioned in preferredOrder retain their existing + * relative ordering and are appended to the end. + */ + reorderOutputs(preferredOrder: string[]): LinearLayout { + const currentOrder = this.getOutDimNames() + if (currentOrder.length <= 1) { + return this + } + + const normalizedPreference = preferredOrder + .map((name) => name.trim()) + .filter((name) => name.length > 0) + + const finalOrder: string[] = [] + const seen = new Set() + const appendIfPresent = (name: string): void => { + if (seen.has(name)) { + return + } + if (!currentOrder.includes(name)) { + return + } + seen.add(name) + finalOrder.push(name) + } + + normalizedPreference.forEach(appendIfPresent) + currentOrder.forEach(appendIfPresent) + + if (finalOrder.length !== currentOrder.length) { + // Should not happen because we append all remaining dimensions. + return this + } + + let hasChanges = false + for (let i = 0; i < finalOrder.length; i++) { + if (finalOrder[i] !== currentOrder[i]) { + hasChanges = true + break + } + } + if (!hasChanges) { + return this + } + + const currentIndex = new Map() + currentOrder.forEach((name, index) => currentIndex.set(name, index)) + + const reorderedBases: Array<[string, number[][]]> = [] + for (const [inDim, dimBases] of this.bases.entries()) { + const remappedBases = dimBases.map((basisVector) => { + return finalOrder.map((outName) => { + const originalIndex = currentIndex.get(outName) + if (originalIndex === undefined) { + return 0 + } + return basisVector[originalIndex] ?? 0 + }) + }) + reorderedBases.push([inDim, remappedBases]) + } + + const reorderedOutDims = finalOrder.map( + (name) => [name, this.getOutDimSize(name)] as [string, number] + ) + + return new LinearLayout(reorderedBases, reorderedOutDims) + } + private getInputDimensionSpecs(): DimensionSpec[] { return this.getInDimNames().map((name) => ({ name, diff --git a/src/core/filterSnapshotDimensions.test.ts b/src/core/filterSnapshotDimensions.test.ts new file mode 100644 index 0000000..ba5e961 --- /dev/null +++ b/src/core/filterSnapshotDimensions.test.ts @@ -0,0 +1,96 @@ +import { describe, it, expect } from 'vitest' +import type { LayoutMatrixSnapshot } from './LinearLayout' +import { filterSnapshotDimensions } from './filterSnapshotDimensions' + +const buildSnapshot = (): LayoutMatrixSnapshot => ({ + inputDimensions: [ + { name: 'reg', size: 4 }, + { name: 'lane', size: 2 }, + ], + outputDimensions: [ + { name: 'dim0', size: 4 }, + { name: 'dim1', size: 2 }, + ], + matrix: [ + [1, 0, 1], + [0, 1, 0], + [1, 1, 0], + [0, 0, 1], + ], +}) + +describe('filterSnapshotDimensions', () => { + it('clones the snapshot when no dimensions are removed', () => { + const snapshot = buildSnapshot() + const { snapshot: filtered, removedInputDimensions, removedOutputDimensions } = + filterSnapshotDimensions(snapshot) + + expect(removedInputDimensions).toHaveLength(0) + expect(removedOutputDimensions).toHaveLength(0) + expect(filtered).not.toBe(snapshot) + expect(filtered.inputDimensions).toEqual(snapshot.inputDimensions) + expect(filtered.outputDimensions).toEqual(snapshot.outputDimensions) + expect(filtered.matrix).toEqual(snapshot.matrix) + }) + + it('removes size-1 dimensions and leaves remaining matrix intact', () => { + const snapshot: LayoutMatrixSnapshot = { + inputDimensions: [ + { name: 'reg', size: 2 }, + { name: 'warp', size: 1 }, + ], + outputDimensions: [ + { name: 'dim0', size: 4 }, + { name: 'dim1', size: 1 }, + ], + matrix: [ + [1], + [0], + ], + } + + const { snapshot: filtered, removedInputDimensions, removedOutputDimensions } = + filterSnapshotDimensions(snapshot) + + expect(filtered.inputDimensions).toEqual([{ name: 'reg', size: 2 }]) + expect(filtered.outputDimensions).toEqual([{ name: 'dim0', size: 4 }]) + expect(filtered.matrix).toEqual([ + [1], + [0], + ]) + expect(removedInputDimensions).toEqual(['warp']) + expect(removedOutputDimensions).toEqual(['dim1']) + }) + + it('drops matrix rows and columns when removing larger dimensions via custom min size', () => { + const snapshot: LayoutMatrixSnapshot = { + inputDimensions: [ + { name: 'reg', size: 4 }, + { name: 'lane', size: 2 }, + ], + outputDimensions: [ + { name: 'dim0', size: 8 }, + { name: 'dim1', size: 2 }, + ], + matrix: [ + [1, 0, 0], + [0, 1, 0], + [1, 1, 0], + [0, 0, 1], + ], + } + + const { snapshot: filtered, removedInputDimensions, removedOutputDimensions } = + filterSnapshotDimensions(snapshot, 4) + + expect(filtered.inputDimensions).toEqual([{ name: 'reg', size: 4 }]) + expect(filtered.outputDimensions).toEqual([{ name: 'dim0', size: 8 }]) + expect(filtered.matrix).toEqual([ + [1, 0], + [0, 1], + [1, 1], + ]) + expect(removedInputDimensions).toEqual(['lane']) + expect(removedOutputDimensions).toEqual(['dim1']) + }) +}) diff --git a/src/core/filterSnapshotDimensions.ts b/src/core/filterSnapshotDimensions.ts new file mode 100644 index 0000000..4605292 --- /dev/null +++ b/src/core/filterSnapshotDimensions.ts @@ -0,0 +1,111 @@ +import type { LayoutMatrixSnapshot } from './LinearLayout' + +export interface SnapshotFilterResult { + snapshot: LayoutMatrixSnapshot + removedInputDimensions: string[] + removedOutputDimensions: string[] +} + +/** + * Remove dimensions smaller than the provided minimum size from a layout snapshot. + * Returns the filtered snapshot along with the dimension names that were removed. + */ +export function filterSnapshotDimensions( + snapshot: LayoutMatrixSnapshot, + minSize = 2 +): SnapshotFilterResult { + if (minSize <= 1) { + return { + snapshot: cloneSnapshot(snapshot), + removedInputDimensions: [], + removedOutputDimensions: [], + } + } + + const keepDimension = (dimension: { size: number }): boolean => dimension.size >= minSize + const removedInputDimensions = snapshot.inputDimensions + .filter((dimension) => !keepDimension(dimension)) + .map((dimension) => dimension.name) + const removedOutputDimensions = snapshot.outputDimensions + .filter((dimension) => !keepDimension(dimension)) + .map((dimension) => dimension.name) + + if (removedInputDimensions.length === 0 && removedOutputDimensions.length === 0) { + return { + snapshot: cloneSnapshot(snapshot), + removedInputDimensions, + removedOutputDimensions, + } + } + + const columnIndexesToKeep = collectBitIndexes(snapshot.inputDimensions, keepDimension) + const rowIndexesToKeep = collectBitIndexes(snapshot.outputDimensions, keepDimension) + const filteredMatrix = buildFilteredMatrix(snapshot.matrix, rowIndexesToKeep, columnIndexesToKeep) + + return { + snapshot: { + metadata: snapshot.metadata, + inputDimensions: snapshot.inputDimensions.filter(keepDimension).map(cloneDimension), + outputDimensions: snapshot.outputDimensions.filter(keepDimension).map(cloneDimension), + matrix: filteredMatrix, + }, + removedInputDimensions, + removedOutputDimensions, + } +} + +function cloneSnapshot(snapshot: LayoutMatrixSnapshot): LayoutMatrixSnapshot { + return { + metadata: snapshot.metadata, + inputDimensions: snapshot.inputDimensions.map(cloneDimension), + outputDimensions: snapshot.outputDimensions.map(cloneDimension), + matrix: snapshot.matrix.map((row) => [...row]), + } +} + +function cloneDimension(dimension: T): T { + return { ...dimension } +} + +function collectBitIndexes( + dimensions: Array<{ size: number }>, + keepDimension: (dimension: { size: number }) => boolean +): number[] { + const indexes: number[] = [] + let cursor = 0 + for (const dimension of dimensions) { + const bitWidth = getBitWidth(dimension.size) + if (keepDimension(dimension)) { + for (let i = 0; i < bitWidth; i++) { + indexes.push(cursor + i) + } + } + cursor += bitWidth + } + return indexes +} + +function getBitWidth(size: number): number { + if (!Number.isFinite(size) || size <= 0) { + return 0 + } + const width = Math.log2(size) + return Number.isInteger(width) ? width : 0 +} + +function buildFilteredMatrix( + matrix: number[][], + rowIndexes: number[], + columnIndexes: number[] +): number[][] { + if (rowIndexes.length === matrix.length && matrix.length > 0) { + if (columnIndexes.length === (matrix[0]?.length ?? 0)) { + return matrix.map((row) => [...row]) + } + } + + return rowIndexes.map((rowIndex) => { + const sourceRow = matrix[rowIndex] ?? [] + return columnIndexes.map((columnIndex) => sourceRow[columnIndex] ?? 0) + }) +} diff --git a/src/integration/LayoutProjectionBus.test.ts b/src/integration/LayoutProjectionBus.test.ts new file mode 100644 index 0000000..2eecc2f --- /dev/null +++ b/src/integration/LayoutProjectionBus.test.ts @@ -0,0 +1,110 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { LayoutProjectionBus } from './LayoutProjectionBus' + +const sampleSnapshot = { + inputDimensions: [], + outputDimensions: [], + matrix: [], +} + +describe('LayoutProjectionBus', () => { + let bus: LayoutProjectionBus + + beforeEach(() => { + bus = new LayoutProjectionBus() + }) + + it('warns when publishing to a tab with no subscribers', () => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + bus.publish({ + sourceTabId: 'block-layout', + targetTabId: 'linear-layout', + snapshot: sampleSnapshot, + }) + expect(warnSpy).toHaveBeenCalledWith( + '[LayoutProjectionBus] No listeners registered for target "linear-layout"' + ) + warnSpy.mockRestore() + }) + + it('delivers requests to subscribers and supports unsubscribe', () => { + const listener = vi.fn() + const unsubscribe = bus.subscribe('linear-layout', listener) + + const request = { + sourceTabId: 'block-layout', + targetTabId: 'linear-layout', + snapshot: sampleSnapshot, + } + + bus.publish(request) + expect(listener).toHaveBeenCalledWith(request) + + unsubscribe() + listener.mockClear() + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + bus.publish(request) + expect(listener).not.toHaveBeenCalled() + warnSpy.mockRestore() + }) + + it('handles errors thrown by listeners without crashing', () => { + const errorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + const failingListener = vi.fn(() => { + throw new Error('listener failed') + }) + const succeedingListener = vi.fn() + + bus.subscribe('linear-layout', failingListener) + bus.subscribe('linear-layout', succeedingListener) + + expect(() => + bus.publish({ + sourceTabId: 'block-layout', + targetTabId: 'linear-layout', + snapshot: sampleSnapshot, + }) + ).not.toThrow() + + expect(failingListener).toHaveBeenCalled() + expect(succeedingListener).toHaveBeenCalled() + expect(errorSpy).toHaveBeenCalledWith( + '[LayoutProjectionBus] Listener failed to process projection request', + expect.any(Error) + ) + errorSpy.mockRestore() + }) + + it('invokes the tab activation handler after delivering a request', () => { + const activationSpy = vi.fn() + const listener = vi.fn() + bus.setTabActivationHandler(activationSpy) + bus.subscribe('linear-layout', listener) + + const request = { + sourceTabId: 'block-layout', + targetTabId: 'linear-layout', + snapshot: sampleSnapshot, + } + + bus.publish(request) + expect(listener).toHaveBeenCalledWith(request) + expect(activationSpy).toHaveBeenCalledWith('linear-layout') + }) + + it('supports multiple listeners per tab', () => { + const first = vi.fn() + const second = vi.fn() + bus.subscribe('linear-layout', first) + bus.subscribe('linear-layout', second) + + bus.publish({ + sourceTabId: 'block-layout', + targetTabId: 'linear-layout', + snapshot: sampleSnapshot, + }) + + expect(first).toHaveBeenCalled() + expect(second).toHaveBeenCalled() + }) +}) diff --git a/src/integration/LayoutProjectionBus.ts b/src/integration/LayoutProjectionBus.ts new file mode 100644 index 0000000..48dbb62 --- /dev/null +++ b/src/integration/LayoutProjectionBus.ts @@ -0,0 +1,59 @@ +import type { LayoutMatrixSnapshot } from '../core/LinearLayout' + +export interface LayoutProjectionRequest { + readonly sourceTabId: string + readonly targetTabId: string + readonly snapshot: LayoutMatrixSnapshot +} + +type ProjectionListener = (request: LayoutProjectionRequest) => void + +/** + * Lightweight event bus for routing layout projection requests between tabs. + * Tabs publish projection requests targeting another tab without needing + * any compile-time knowledge about the destination implementation. + */ +export class LayoutProjectionBus { + private readonly listeners = new Map>() + private tabActivationHandler: ((tabId: string) => void) | null = null + + public setTabActivationHandler(handler: (tabId: string) => void): void { + this.tabActivationHandler = handler + } + + public subscribe(targetTabId: string, listener: ProjectionListener): () => void { + const listenersForTarget = this.listeners.get(targetTabId) ?? new Set() + listenersForTarget.add(listener) + this.listeners.set(targetTabId, listenersForTarget) + return () => { + listenersForTarget.delete(listener) + if (listenersForTarget.size === 0) { + this.listeners.delete(targetTabId) + } + } + } + + public publish(request: LayoutProjectionRequest): void { + const listeners = this.listeners.get(request.targetTabId) + if (!listeners || listeners.size === 0) { + console.warn( + `[LayoutProjectionBus] No listeners registered for target "${request.targetTabId}"` + ) + return + } + + listeners.forEach((listener) => { + try { + listener(request) + } catch (error) { + console.error('[LayoutProjectionBus] Listener failed to process projection request', error) + } + }) + + this.tabActivationHandler?.(request.targetTabId) + } +} + +export const layoutProjectionBus = new LayoutProjectionBus() + +export const LINEAR_LAYOUT_TAB_ID = 'linear-layout' diff --git a/src/layouts/BlockLayout.test.ts b/src/layouts/BlockLayout.test.ts index 271e306..c3ec8c1 100644 --- a/src/layouts/BlockLayout.test.ts +++ b/src/layouts/BlockLayout.test.ts @@ -5,6 +5,7 @@ import { createBlockLayout, getPositionsForThread, } from './BlockLayout' +import { LinearLayout } from '../core/LinearLayout' describe('BlockLayout', () => { describe('identityStandardND', () => { @@ -160,6 +161,125 @@ describe('BlockLayout', () => { // Number of registers should be greater than base sizePerThread product expect(numRegisters).toBeGreaterThan(params.sizePerThread[0] * params.sizePerThread[1]) }) + + it('keeps output dimension ordering consistent regardless of layout order', () => { + const columnMajor = createBlockLayout({ + sizePerThread: [2, 2], + threadsPerWarp: [8, 4], + warpsPerCTA: [1, 2], + order: [0, 1], + tensorShape: [16, 16], + }) + const rowMajor = createBlockLayout({ + sizePerThread: [2, 2], + threadsPerWarp: [8, 4], + warpsPerCTA: [1, 2], + order: [1, 0], + tensorShape: [16, 16], + }) + expect(columnMajor.getOutDims().map(([name]) => name)).toEqual(['dim0', 'dim1']) + expect(rowMajor.getOutDims().map(([name]) => name)).toEqual(['dim0', 'dim1']) + }) + }) + + describe('layout projection snapshots', () => { + const expectSnapshotRoundTrip = (params: { + sizePerThread: [number, number] + threadsPerWarp: [number, number] + warpsPerCTA: [number, number] + order: [number, number] + tensorShape: [number, number] + }) => { + const layout = createBlockLayout(params) + const snapshot = layout.toMatrixSnapshot({ + sourceLayoutType: 'block-layout', + }) + const imported = LinearLayout.fromBitMatrix( + snapshot.matrix, + snapshot.inputDimensions, + snapshot.outputDimensions + ) + + const registerSize = layout.getInDimSize('register') + const laneSize = layout.getInDimSize('lane') + const warpSize = layout.getInDimSize('warp') + const laneDimensionName = + snapshot.inputDimensions.find((dim) => dim.name === 'lane')?.name ?? 'lane' + + for (let register = 0; register < registerSize; register++) { + for (let lane = 0; lane < laneSize; lane++) { + for (let warp = 0; warp < warpSize; warp++) { + const original = layout.apply({ register, lane, warp }) + const importedResult = imported.apply({ + register, + warp, + [laneDimensionName]: lane, + }) + expect(importedResult).toEqual(original) + } + } + } + } + + it('round-trips column-major layouts through matrix snapshots', () => { + expectSnapshotRoundTrip({ + sizePerThread: [2, 2], + threadsPerWarp: [2, 2], + warpsPerCTA: [1, 1], + order: [0, 1], + tensorShape: [4, 4], + }) + }) + + it('round-trips complex row-major layouts through matrix snapshots', () => { + expectSnapshotRoundTrip({ + sizePerThread: [2, 2], + threadsPerWarp: [8, 4], + warpsPerCTA: [1, 2], + order: [1, 0], + tensorShape: [16, 16], + }) + }) + + it('preserves size-1 dimensions when exporting and importing', () => { + const params = { + sizePerThread: [1, 2] as [number, number], + threadsPerWarp: [1, 32] as [number, number], + warpsPerCTA: [1, 1] as [number, number], + order: [0, 1] as [number, number], + tensorShape: [1, 64] as [number, number], + } + + const layout = createBlockLayout(params) + const snapshot = layout.toMatrixSnapshot({ + sourceLayoutType: 'block-layout', + }) + + const warpDim = snapshot.inputDimensions.find((dim) => dim.name === 'warp') + expect(warpDim?.size).toBe(1) + const dim0 = snapshot.outputDimensions.find((dim) => dim.name === 'dim0') + expect(dim0?.size).toBe(1) + + const imported = LinearLayout.fromBitMatrix( + snapshot.matrix, + snapshot.inputDimensions, + snapshot.outputDimensions + ) + + const registerSize = layout.getInDimSize('register') + const laneSize = layout.getInDimSize('lane') + const warpSize = layout.getInDimSize('warp') + + for (let register = 0; register < registerSize; register++) { + for (let lane = 0; lane < laneSize; lane++) { + for (let warp = 0; warp < warpSize; warp++) { + const original = layout.apply({ register, lane, warp }) + const importedResult = imported.apply({ register, lane, warp }) + expect(importedResult).toEqual(original) + } + } + } + }) }) describe('getPositionsForThread', () => { diff --git a/src/layouts/BlockLayout.ts b/src/layouts/BlockLayout.ts index 6f064c7..6fd1633 100644 --- a/src/layouts/BlockLayout.ts +++ b/src/layouts/BlockLayout.ts @@ -89,7 +89,8 @@ export function createBlockLayout(params: BlockLayoutParams): LinearLayout { dim1: tensorShape[1], }) - return extended + const canonicalDimOrder = tensorShape.map((_, index) => `dim${index}`) + return extended.reorderOutputs(canonicalDimOrder) } /** diff --git a/src/main.tabs.test.ts b/src/main.tabs.test.ts index ffd8dc5..98381ca 100644 --- a/src/main.tabs.test.ts +++ b/src/main.tabs.test.ts @@ -84,12 +84,59 @@ vi.mock('./tabs/MFMALayoutTab', () => ({ MFMALayoutTab: MFMALayoutTabMock, })) -const LinearLayoutTabMock = vi.fn().mockImplementation(() => ({ - activate: vi.fn(), - deactivate: vi.fn(), - resize: vi.fn(), +type ProjectionRequest = { + sourceTabId: string + targetTabId: string + snapshot: unknown +} + +const LINEAR_LAYOUT_TARGET = 'linear-layout' +const projectionListeners = new Map void>() + +const subscribeMock = vi.fn( + (targetTabId: string, listener: (request: ProjectionRequest) => void) => { + projectionListeners.set(targetTabId, listener) + return () => { + const current = projectionListeners.get(targetTabId) + if (current === listener) { + projectionListeners.delete(targetTabId) + } + } + } +) + +let tabActivationHandler: ((tabId: string) => void) | null = null +const setTabActivationHandlerMock = vi.fn((handler: (tabId: string) => void) => { + tabActivationHandler = handler +}) + +const publishMock = vi.fn((request: ProjectionRequest) => { + projectionListeners.get(request.targetTabId)?.(request) + tabActivationHandler?.(request.targetTabId) +}) + +vi.mock('./integration/LayoutProjectionBus', () => ({ + layoutProjectionBus: { + subscribe: subscribeMock, + publish: publishMock, + setTabActivationHandler: setTabActivationHandlerMock, + }, + LINEAR_LAYOUT_TAB_ID: LINEAR_LAYOUT_TARGET, })) +const LinearLayoutTabMock = vi.fn().mockImplementation(() => { + const instance = { + activate: vi.fn(), + deactivate: vi.fn(), + resize: vi.fn(), + importLayoutSnapshot: vi.fn(), + } + subscribeMock(LINEAR_LAYOUT_TARGET, (request: ProjectionRequest) => { + instance.importLayoutSnapshot(request.snapshot) + }) + return instance +}) + vi.mock('./tabs/LinearLayoutTab', () => ({ LinearLayoutTab: LinearLayoutTabMock, })) @@ -187,6 +234,11 @@ const setupDom = () => { const form = document.createElement('form') form.id = 'paramForm' + const showLinearButton = document.createElement('button') + showLinearButton.type = 'button' + showLinearButton.id = 'show-linear-layout' + showLinearButton.textContent = 'Show in Linear Layout' + form.appendChild(showLinearButton) sidebar.appendChild(form) const validationErrors = document.createElement('div') @@ -368,6 +420,12 @@ describe('main tab switching', () => { onSubmitMock.mockReset() onInputMock.mockReset() + projectionListeners.clear() + tabActivationHandler = null + subscribeMock.mockClear() + publishMock.mockClear() + setTabActivationHandlerMock.mockClear() + setupDom() LinearLayoutTabMock.mockClear() @@ -465,6 +523,73 @@ describe('main tab switching', () => { expect(ckContent?.classList.contains('active')).toBe(true) }) + it('projects the current block layout into the linear layout tab when the shortcut button is used', async () => { + const snapshot = { + inputDimensions: [ + { name: 'register', size: 4 }, + { name: ' lane ', size: 32 }, + { name: 'warp', size: 1 }, + ], + outputDimensions: [ + { name: 'dim0', size: 8 }, + { name: 'dim1', size: 8 }, + ], + matrix: [[1]], + } + const exportMock = vi.fn().mockImplementation((metadata) => ({ + ...snapshot, + metadata, + })) + createBlockLayoutMock.mockImplementation(() => ({ + toMatrixSnapshot: exportMock, + })) + + await bootstrapMain() + + const showButton = document.getElementById('show-linear-layout') + showButton?.dispatchEvent(new MouseEvent('click', { bubbles: true })) + + const expectedMetadata = expect.objectContaining({ + sourceLayoutType: 'block-layout', + sourceTabId: 'block-layout', + tensorShape: [...mockParams.tensorShape], + colorInputDimension: 'warp', + }) + + expect(exportMock).toHaveBeenCalledWith(expectedMetadata) + expect(publishMock).toHaveBeenCalledWith( + expect.objectContaining({ + sourceTabId: 'block-layout', + targetTabId: LINEAR_LAYOUT_TARGET, + }) + ) + + const linearTabInstance = LinearLayoutTabMock.mock.results[0]?.value as { + importLayoutSnapshot: ReturnType + } | undefined + expect(linearTabInstance?.importLayoutSnapshot).toHaveBeenCalledWith( + expect.objectContaining({ + metadata: expectedMetadata, + }) + ) + + const publishedSnapshot = publishMock.mock.calls[0]?.[0]?.snapshot + expect(publishedSnapshot?.inputDimensions.some((dim) => dim.name === 'lane')).toBe(true) + // Warp dimension (size 1) should be filtered out + const warpDim = publishedSnapshot?.inputDimensions.find((dim) => dim.name === 'warp') + expect(warpDim).toBeUndefined() + + const importedArg = linearTabInstance?.importLayoutSnapshot.mock.calls[0]?.[0] + expect(importedArg?.inputDimensions.some((dim: { name: string }) => dim.name === 'lane')).toBe( + true + ) + + const linearButton = document.querySelector('[data-tab="linear-layout"]') + const linearContent = document.getElementById('linear-layout') + expect(linearButton?.classList.contains('active')).toBe(true) + expect(linearContent?.classList.contains('active')).toBe(true) + }) + it('can switch between multiple tabs sequentially', async () => { await bootstrapMain() diff --git a/src/main.ts b/src/main.ts index 939f4bf..fc12206 100644 --- a/src/main.ts +++ b/src/main.ts @@ -2,6 +2,7 @@ import { BlockLayoutTab } from './tabs/BlockLayoutTab' import { WMMALayoutTab } from './tabs/WMMALayoutTab' import { MFMALayoutTab } from './tabs/MFMALayoutTab' import { LinearLayoutTab } from './tabs/LinearLayoutTab' +import { layoutProjectionBus } from './integration/LayoutProjectionBus' type TabController = { activate(): void @@ -20,10 +21,6 @@ if (tabContents.length === 0) { } const controllers: Map = new Map() -controllers.set('block-layout', new BlockLayoutTab('block-layout')) -controllers.set('wmma-layout', new WMMALayoutTab('wmma-layout')) -controllers.set('mfma-layout', new MFMALayoutTab('mfma-layout')) -controllers.set('linear-layout', new LinearLayoutTab('linear-layout')) let currentTabId: string | null = null @@ -54,6 +51,16 @@ const setActiveTab = (tabId: string): void => { currentTabId = tabId } +layoutProjectionBus.setTabActivationHandler(setActiveTab) + +const linearLayoutTab = new LinearLayoutTab('linear-layout') +const blockLayoutTab = new BlockLayoutTab('block-layout') + +controllers.set('block-layout', blockLayoutTab) +controllers.set('wmma-layout', new WMMALayoutTab('wmma-layout')) +controllers.set('mfma-layout', new MFMALayoutTab('mfma-layout')) +controllers.set('linear-layout', linearLayoutTab) + tabButtons.forEach((button) => { button.addEventListener('click', () => { const targetTab = button.getAttribute('data-tab') diff --git a/src/tabs/BlockLayoutTab.test.ts b/src/tabs/BlockLayoutTab.test.ts new file mode 100644 index 0000000..a2f0fb2 --- /dev/null +++ b/src/tabs/BlockLayoutTab.test.ts @@ -0,0 +1,237 @@ +import { describe, it, beforeEach, afterEach, expect, vi } from 'vitest' +import type { BlockLayoutParams } from '../validation/InputValidator' +import { BlockLayoutTab } from './BlockLayoutTab' +import { createBlockLayout } from '../layouts/BlockLayout' +import { LinearLayout } from '../core/LinearLayout' +import { layoutProjectionBus, LINEAR_LAYOUT_TAB_ID } from '../integration/LayoutProjectionBus' +import type { SnapshotFilterResult } from '../core/filterSnapshotDimensions' + +type ParameterFormStub = { + getParams: ReturnType + validate: ReturnType + onSubmit: ReturnType + onInput: ReturnType +} + +const defaultParams: BlockLayoutParams = { + sizePerThread: [2, 2], + threadsPerWarp: [8, 4], + warpsPerCTA: [1, 2], + order: [0, 1], + tensorShape: [16, 16], +} + +const cloneParams = (params: BlockLayoutParams): BlockLayoutParams => ({ + sizePerThread: [...params.sizePerThread] as [number, number], + threadsPerWarp: [...params.threadsPerWarp] as [number, number], + warpsPerCTA: [...params.warpsPerCTA] as [number, number], + order: [...params.order] as [number, number], + tensorShape: [...params.tensorShape] as [number, number], +}) + +let currentParams: BlockLayoutParams = cloneParams(defaultParams) +let validateResult = true +const parameterFormInstances: ParameterFormStub[] = [] + +vi.mock('../ui/ParameterForm', () => ({ + ParameterForm: vi.fn().mockImplementation(() => { + const instance: ParameterFormStub = { + getParams: vi.fn(() => currentParams), + validate: vi.fn(() => validateResult), + onSubmit: vi.fn(), + onInput: vi.fn(), + } + parameterFormInstances.push(instance) + return instance + }), +})) + +type CanvasRendererStub = { + render: ReturnType + reset: ReturnType + updateLayout: ReturnType + setColorByInputDimension: ReturnType + handleMouseDown: ReturnType + handleMouseUp: ReturnType + handleMouseMove: ReturnType + handleWheel: ReturnType + screenToGrid: ReturnType + getCellInfo: ReturnType + getWarpColor: ReturnType +} + +const canvasRendererInstances: CanvasRendererStub[] = [] + +vi.mock('../visualization/CanvasRenderer', () => ({ + CanvasRenderer: vi.fn().mockImplementation(() => { + const instance: CanvasRendererStub = { + render: vi.fn(), + reset: vi.fn(), + updateLayout: vi.fn(), + setColorByInputDimension: vi.fn(), + handleMouseDown: vi.fn(), + handleMouseUp: vi.fn(), + handleMouseMove: vi.fn(), + handleWheel: vi.fn(), + screenToGrid: vi.fn().mockReturnValue({ row: 0, col: 0 }), + getCellInfo: vi.fn().mockReturnValue(null), + getWarpColor: vi.fn().mockReturnValue('#000000'), + } + canvasRendererInstances.push(instance) + return instance + }), +})) + +const buildBlockLayoutDom = (): void => { + document.body.innerHTML = ` +
+
+ +
+ +
+ ` + + const visualization = document.querySelector('.visualization') as HTMLElement + Object.defineProperty(visualization, 'getBoundingClientRect', { + value: () => ({ + width: 800, + height: 600, + left: 0, + top: 0, + right: 800, + bottom: 600, + x: 0, + y: 0, + toJSON: () => ({}), + }), + }) +} + +const instantiateTab = (): BlockLayoutTab => { + buildBlockLayoutDom() + return new BlockLayoutTab('block-layout') +} + +describe('BlockLayoutTab', () => { + beforeEach(() => { + currentParams = cloneParams(defaultParams) + validateResult = true + parameterFormInstances.length = 0 + canvasRendererInstances.length = 0 + document.body.innerHTML = '' + vi.clearAllMocks() + }) + + afterEach(() => { + document.body.innerHTML = '' + }) + + it('prepares snapshots that normalize lane naming and filter size-1 dimensions', () => { + const tab = instantiateTab() + const params: BlockLayoutParams = { + sizePerThread: [1, 2], + threadsPerWarp: [1, 32], + warpsPerCTA: [1, 1], + order: [0, 1], + tensorShape: [1, 64], + } + const layout = createBlockLayout(params) + const { snapshot, removedInputDimensions } = (tab as unknown as { + prepareLinearLayoutSnapshot( + l: LinearLayout, + p: BlockLayoutParams + ): SnapshotFilterResult + }).prepareLinearLayoutSnapshot(layout, params) + + const inputNames = snapshot.inputDimensions.map((dim) => dim.name) + expect(inputNames).toContain('lane') + expect(inputNames).not.toContain('thread') + + const warpDim = snapshot.inputDimensions.find((dim) => dim.name === 'warp') + expect(warpDim).toBeUndefined() + expect(removedInputDimensions).toContain('warp') + + expect(snapshot.metadata?.colorInputDimension).toBe('warp') + expect(snapshot.metadata?.description).toContain('Projection') + expect(snapshot.metadata?.tensorShape).toEqual([...params.tensorShape]) + }) + + const projectionCases: Array<{ name: string; order: [number, number] }> = [ + { name: 'row-major', order: [0, 1] }, + { name: 'column-major', order: [1, 0] }, + ] + + projectionCases.forEach(({ name, order }) => { + it(`publishes snapshots that round-trip data for ${name} layouts`, () => { + const tab = instantiateTab() + expect(tab).toBeDefined() + + currentParams = { + sizePerThread: [2, 2], + threadsPerWarp: [4, 2], + warpsPerCTA: [1, 2], + order, + tensorShape: [8, 16], + } + + const publishSpy = vi.spyOn(layoutProjectionBus, 'publish') + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + + const button = document.getElementById('show-linear-layout') as HTMLButtonElement + button?.click() + + expect(publishSpy).toHaveBeenCalledWith( + expect.objectContaining({ + targetTabId: LINEAR_LAYOUT_TAB_ID, + snapshot: expect.any(Object), + }) + ) + + const snapshot = publishSpy.mock.calls[0]?.[0]?.snapshot as ReturnType< + LinearLayout['toMatrixSnapshot'] + > + expect(snapshot).toBeDefined() + if (!snapshot) { + throw new Error('Missing snapshot in projection payload') + } + const imported = LinearLayout.fromBitMatrix( + snapshot.matrix, + snapshot.inputDimensions, + snapshot.outputDimensions + ) + const expected = createBlockLayout(currentParams) + + const registerSize = expected.getInDimSize('register') + const laneSize = expected.getInDimSize('lane') + const warpSize = expected.getInDimSize('warp') + const hasLaneDimension = snapshot.inputDimensions.some((dim) => dim.name === 'lane') + expect(hasLaneDimension).toBe(true) + expect(snapshot.outputDimensions.map((dim) => dim.name)).toEqual(['dim0', 'dim1']) + + for (let register = 0; register < registerSize; register++) { + for (let lane = 0; lane < laneSize; lane++) { + for (let warp = 0; warp < warpSize; warp++) { + const expectedResult = expected.apply({ register, lane, warp }) + const importedResult = imported.apply({ + register, + warp, + lane, + }) + expect(importedResult).toEqual(expectedResult) + } + } + } + + warnSpy.mockRestore() + publishSpy.mockRestore() + }) + }) +}) diff --git a/src/tabs/BlockLayoutTab.ts b/src/tabs/BlockLayoutTab.ts index 9ade2c3..d8e7f95 100644 --- a/src/tabs/BlockLayoutTab.ts +++ b/src/tabs/BlockLayoutTab.ts @@ -1,9 +1,15 @@ +import type { LayoutMatrixSnapshot, LinearLayout } from '../core/LinearLayout' import { createBlockLayout } from '../layouts/BlockLayout' import type { BlockLayoutParams } from '../validation/InputValidator' import { CanvasRenderer } from '../visualization/CanvasRenderer' import { ParameterForm } from '../ui/ParameterForm' import { renderSharedControls } from '../ui/renderSharedControls' +import { layoutProjectionBus, LINEAR_LAYOUT_TAB_ID } from '../integration/LayoutProjectionBus' import { CanvasTab, type CanvasTabElements } from './CanvasTab' +import { + filterSnapshotDimensions, + type SnapshotFilterResult, +} from '../core/filterSnapshotDimensions' /** * Controller for the Block Layout tab. Handles form input, visualization @@ -11,6 +17,7 @@ import { CanvasTab, type CanvasTabElements } from './CanvasTab' */ export class BlockLayoutTab extends CanvasTab { private readonly form: ParameterForm + private readonly tabId: string constructor(tabId: string) { const tabContent = document.getElementById(tabId) @@ -43,8 +50,10 @@ export class BlockLayoutTab extends CanvasTab { super(elements) + this.tabId = tabId this.form = new ParameterForm('paramForm') this.setupFormHandlers() + this.setupShowInLinearLayoutButton() } /** @@ -61,6 +70,51 @@ export class BlockLayoutTab extends CanvasTab { } } + private setupShowInLinearLayoutButton(): void { + const button = this.root.querySelector('#show-linear-layout') + if (!button) { + console.warn('Show in Linear Layout button not found in BlockLayoutTab.') + return + } + + button.addEventListener('click', () => { + if (!this.form.validate()) { + return + } + const params = this.form.getParams() + try { + const layout = createBlockLayout(params) + const { snapshot, removedInputDimensions, removedOutputDimensions } = + this.prepareLinearLayoutSnapshot(layout, params) + if (snapshot.inputDimensions.length === 0 || snapshot.outputDimensions.length === 0) { + const inputCollapsed = removedInputDimensions.length > 0 + const outputCollapsed = removedOutputDimensions.length > 0 + let message = 'Unable to derive a linear layout from the current configuration.' + if (inputCollapsed || outputCollapsed) { + if (inputCollapsed && outputCollapsed) { + message = + 'Unable to derive a linear layout because all input and output dimensions are size 1.' + } else if (inputCollapsed) { + message = 'Unable to derive a linear layout because all input dimensions are size 1.' + } else { + message = 'Unable to derive a linear layout because all output dimensions are size 1.' + } + } + alert(message) + return + } + layoutProjectionBus.publish({ + sourceTabId: this.tabId, + targetTabId: LINEAR_LAYOUT_TAB_ID, + snapshot, + }) + } catch (error) { + console.error('Failed to project Block Layout into Linear Layout:', error) + alert('Failed to convert Block Layout into Linear Layout. See console for details.') + } + }) + } + /** * Update the rendered visualization with the latest form parameters. */ @@ -119,4 +173,34 @@ export class BlockLayoutTab extends CanvasTab { protected resetHover(): void { this.hideTooltip() } + + private prepareLinearLayoutSnapshot( + layout: LinearLayout, + params: BlockLayoutParams + ): SnapshotFilterResult { + const snapshot = layout.toMatrixSnapshot({ + sourceLayoutType: 'block-layout', + sourceTabId: this.tabId, + tensorShape: [...params.tensorShape], + generatedAt: new Date().toISOString(), + description: 'Projection from Block Layout controls', + colorInputDimension: this.getColorDimensionPreference(), + }) + this.normalizeLaneDimensionName(snapshot) + return filterSnapshotDimensions(snapshot) + } + + private getColorDimensionPreference(): string { + return 'warp' + } + + private normalizeLaneDimensionName(snapshot: LayoutMatrixSnapshot): void { + snapshot.inputDimensions = snapshot.inputDimensions.map((dimension) => { + const trimmed = dimension.name.trim() + if (trimmed.toLowerCase() !== 'lane') { + return dimension + } + return { ...dimension, name: 'lane' } + }) + } } diff --git a/src/tabs/LinearLayoutTab.test.ts b/src/tabs/LinearLayoutTab.test.ts new file mode 100644 index 0000000..a1f1991 --- /dev/null +++ b/src/tabs/LinearLayoutTab.test.ts @@ -0,0 +1,233 @@ +import { describe, it, beforeEach, afterEach, expect, vi } from 'vitest' +import type { MatrixEditorDimensions } from '../ui/LinearLayoutMatrixEditor' +import { LinearLayoutTab } from './LinearLayoutTab' +import { LinearLayout } from '../core/LinearLayout' + +class MatrixEditorStub { + public latestDimensions: MatrixEditorDimensions = { input: [], output: [] } + private matrix: number[][] = [] + private matrixListeners = new Set<(matrix: number[][]) => void>() + private visibilityListeners = new Set<(isOpen: boolean) => void>() + + onVisibilityChange(listener: (isOpen: boolean) => void): () => void { + this.visibilityListeners.add(listener) + return () => this.visibilityListeners.delete(listener) + } + + onMatrixChange(listener: (matrix: number[][]) => void): () => void { + this.matrixListeners.add(listener) + return () => this.matrixListeners.delete(listener) + } + + updateDimensions(dimensions: MatrixEditorDimensions, options: { emitChange?: boolean } = {}): boolean { + this.latestDimensions = this.cloneDimensions(dimensions) + this.matrix = this.createZeroMatrix(dimensions) + if (options.emitChange === false) { + return false + } + this.notifyMatrixChange() + return true + } + + replaceMatrix( + values: number[][], + dimensions?: MatrixEditorDimensions, + options: { emitChange?: boolean } = {} + ): void { + if (dimensions) { + this.latestDimensions = this.cloneDimensions(dimensions) + } + this.matrix = values.map((row) => [...row]) + if (options.emitChange === false) { + return + } + this.notifyMatrixChange() + } + + getMatrix(): number[][] { + return this.matrix.map((row) => [...row]) + } + + open(dimensions: MatrixEditorDimensions): void { + this.latestDimensions = this.cloneDimensions(dimensions) + this.visibilityListeners.forEach((listener) => listener(true)) + } + + private notifyMatrixChange(): void { + const snapshot = this.getMatrix() + this.matrixListeners.forEach((listener) => listener(snapshot)) + } + + private createZeroMatrix(dimensions: MatrixEditorDimensions): number[][] { + const rows = dimensions.output.reduce((sum, dim) => { + const bits = Math.max(0, Math.round(Math.log2(Math.max(dim.size, 1)))) + return sum + bits + }, 0) + const cols = dimensions.input.reduce((sum, dim) => { + const bits = Math.max(0, Math.round(Math.log2(Math.max(dim.size, 1)))) + return sum + bits + }, 0) + return Array.from({ length: rows }, () => Array.from({ length: cols }, () => 0)) + } + + private cloneDimensions(dimensions: MatrixEditorDimensions): MatrixEditorDimensions { + return { + input: dimensions.input.map((dim) => ({ ...dim })), + output: dimensions.output.map((dim) => ({ ...dim })), + } + } +} + +const matrixEditorInstances: MatrixEditorStub[] = [] + +vi.mock('../ui/LinearLayoutMatrixEditor', () => ({ + LinearLayoutMatrixEditor: vi.fn().mockImplementation(() => { + const instance = new MatrixEditorStub() + matrixEditorInstances.push(instance) + return instance + }), +})) + +type CanvasRendererStub = { + render: ReturnType + reset: ReturnType + updateLayout: ReturnType + setColorByInputDimension: ReturnType + handleMouseDown: ReturnType + handleMouseUp: ReturnType + handleMouseMove: ReturnType + handleWheel: ReturnType + screenToGrid: ReturnType + getCellInfo: ReturnType + getWarpColor: ReturnType +} + +const canvasRendererInstances: CanvasRendererStub[] = [] + +vi.mock('../visualization/CanvasRenderer', () => ({ + CanvasRenderer: vi.fn().mockImplementation(() => { + const instance: CanvasRendererStub = { + render: vi.fn(), + reset: vi.fn(), + updateLayout: vi.fn(), + setColorByInputDimension: vi.fn(), + handleMouseDown: vi.fn(), + handleMouseUp: vi.fn(), + handleMouseMove: vi.fn(), + handleWheel: vi.fn(), + screenToGrid: vi.fn().mockReturnValue({ row: 0, col: 0 }), + getCellInfo: vi.fn().mockReturnValue(null), + getWarpColor: vi.fn().mockReturnValue('#000000'), + } + canvasRendererInstances.push(instance) + return instance + }), +})) + +const subscribeMock = vi.hoisted(() => vi.fn().mockReturnValue(() => {})) + +vi.mock('../integration/LayoutProjectionBus', () => ({ + layoutProjectionBus: { + subscribe: subscribeMock, + }, + LINEAR_LAYOUT_TAB_ID: 'linear-layout', +})) + +const buildLinearLayoutDom = (): void => { + document.body.innerHTML = ` +
+
+ +
+ +
+ ` + + const visualization = document.querySelector('.visualization') as HTMLElement + Object.defineProperty(visualization, 'getBoundingClientRect', { + value: () => ({ + width: 800, + height: 600, + left: 0, + top: 0, + right: 800, + bottom: 600, + x: 0, + y: 0, + toJSON: () => ({}), + }), + }) +} + +const instantiateTab = (): LinearLayoutTab => { + buildLinearLayoutDom() + matrixEditorInstances.length = 0 + canvasRendererInstances.length = 0 + return new LinearLayoutTab('linear-layout') +} + +describe('LinearLayoutTab', () => { + beforeEach(() => { + document.body.innerHTML = '' + vi.clearAllMocks() + }) + + afterEach(() => { + document.body.innerHTML = '' + }) + + it('applies color metadata, truncates extra outputs, and ignores size-1 inputs', () => { + const tab = instantiateTab() + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const renderer = canvasRendererInstances[0] + expect(renderer).toBeDefined() + + const baseLayout = LinearLayout.identity1D(8, 'register', 'dim0').multiply( + LinearLayout.identity1D(8, 'lane', 'dim1') + ) + const snapshot = baseLayout.toMatrixSnapshot({ + colorInputDimension: 'register', + sourceLayoutType: 'block-layout', + }) + snapshot.inputDimensions = [ + ...snapshot.inputDimensions, + { name: 'warp', size: 1 }, + ] + snapshot.outputDimensions = [ + ...snapshot.outputDimensions, + { name: 'dim2', size: 4 }, + ] + + tab.importLayoutSnapshot(snapshot) + + expect(warnSpy).toHaveBeenCalledWith('Truncated imported layout to the first two output dimensions.') + warnSpy.mockRestore() + + const dimensionState = (tab as unknown as { dimensionState: { input: Array<{ id: string; name: string }> } }).dimensionState + const colorSelect = document.getElementById('linear-color-dimension') as HTMLSelectElement + const selectedId = colorSelect.value + const selectedDimension = dimensionState.input.find((dim) => dim.id === selectedId) + expect(selectedDimension?.name).toBe('register') + expect(renderer?.setColorByInputDimension).toHaveBeenCalledWith('register') + const inputNames = dimensionState.input.map((dim) => dim.name) + expect(inputNames).not.toContain('warp') + + const editor = matrixEditorInstances[0] + expect(editor.latestDimensions.output).toHaveLength(2) + expect(editor.latestDimensions.output.map((dim) => dim.name)).toEqual(['dim0', 'dim1']) + }) + + it('reports an error message when importing an empty snapshot', () => { + const tab = instantiateTab() + const statusElement = (tab as unknown as { layoutStatus: HTMLElement }).layoutStatus + tab.importLayoutSnapshot({ + inputDimensions: [], + outputDimensions: [], + matrix: [], + }) + expect(statusElement?.textContent).toContain('Cannot import layout') + }) +}) diff --git a/src/tabs/LinearLayoutTab.ts b/src/tabs/LinearLayoutTab.ts index 88514f3..1815577 100644 --- a/src/tabs/LinearLayoutTab.ts +++ b/src/tabs/LinearLayoutTab.ts @@ -1,4 +1,9 @@ -import { LinearLayout } from '../core/LinearLayout' +import { + LinearLayout, + type LayoutMatrixSnapshot, + type LayoutSnapshotMetadata, +} from '../core/LinearLayout' +import { filterSnapshotDimensions } from '../core/filterSnapshotDimensions' import type { BlockLayoutParams } from '../validation/InputValidator' import { LinearLayoutValidator, @@ -8,6 +13,7 @@ import { import { CanvasRenderer, type CellInfo } from '../visualization/CanvasRenderer' import { LinearLayoutMatrixEditor, type MatrixEditorDimensions } from '../ui/LinearLayoutMatrixEditor' import { renderSharedControls } from '../ui/renderSharedControls' +import { layoutProjectionBus, LINEAR_LAYOUT_TAB_ID } from '../integration/LayoutProjectionBus' import { CanvasTab, type CanvasTabElements } from './CanvasTab' type DimensionType = 'input' | 'output' @@ -23,6 +29,10 @@ interface OutputBitGroup { bitWidth: number } +interface RenderDimensionOptions { + suppressMatrixSync?: boolean +} + const BASIS_HEADING = 'Basis Calculation (all weighted bases are added via ⊕)' /** @@ -44,6 +54,7 @@ export class LinearLayoutTab extends CanvasTab { private selectedColorDimensionId: string | null private hasDimensionValidationErrors: boolean private showingValidationStatus: boolean + private suppressMatrixChangeOnce: boolean constructor(tabId: string) { const tabContent = document.getElementById(tabId) @@ -92,7 +103,7 @@ export class LinearLayoutTab extends CanvasTab { this.dimensionState = { input: [ { id: this.createDimensionId(), name: 'reg', size: 8 }, - { id: this.createDimensionId(), name: 'thread', size: 32 }, + { id: this.createDimensionId(), name: 'lane', size: 32 }, ], output: [ { id: this.createDimensionId(), name: 'dim0', size: 16 }, @@ -130,6 +141,7 @@ export class LinearLayoutTab extends CanvasTab { this.selectedColorDimensionId = null this.hasDimensionValidationErrors = false this.showingValidationStatus = false + this.suppressMatrixChangeOnce = false this.colorDimensionSelect.addEventListener('change', () => { this.handleColorDimensionSelectionChange() }) @@ -165,6 +177,10 @@ export class LinearLayoutTab extends CanvasTab { this.renderOperationsInfo(controlsContainer) this.updateRendererFromLayout() + + layoutProjectionBus.subscribe(LINEAR_LAYOUT_TAB_ID, ({ snapshot }) => { + this.importLayoutSnapshot(snapshot) + }) } protected handleHover(event: MouseEvent): void { @@ -201,6 +217,60 @@ export class LinearLayoutTab extends CanvasTab { this.hideTooltip() } + public importLayoutSnapshot(snapshot: LayoutMatrixSnapshot): void { + const { snapshot: filteredSnapshot, removedInputDimensions, removedOutputDimensions } = + filterSnapshotDimensions(snapshot) + + if ( + filteredSnapshot.inputDimensions.length === 0 || + filteredSnapshot.outputDimensions.length === 0 + ) { + const hadFilteredDimensions = + removedInputDimensions.length > 0 || removedOutputDimensions.length > 0 + this.setLayoutStatus( + hadFilteredDimensions + ? 'Cannot import layout because every dimension is size 1.' + : 'Cannot import layout without both input and output dimensions.' + ) + return + } + + const inputs = filteredSnapshot.inputDimensions.map((dimension) => ({ + id: this.createDimensionId(), + name: dimension.name.trim() || dimension.name, + size: dimension.size, + })) + + const outputs = filteredSnapshot.outputDimensions.slice(0, 2).map((dimension) => ({ + id: this.createDimensionId(), + name: dimension.name.trim() || dimension.name, + size: dimension.size, + })) + + if (filteredSnapshot.outputDimensions.length > 2) { + console.warn('Truncated imported layout to the first two output dimensions.') + } + + this.dimensionState = { + input: inputs, + output: outputs, + } + this.selectedColorDimensionId = this.getMetadataPreferredColorDimensionId( + inputs, + filteredSnapshot.metadata + ) + this.hasDimensionValidationErrors = false + this.updateValidationStatus(false) + + const suppressOptions: RenderDimensionOptions = { suppressMatrixSync: true } + this.renderDimensionRows('input', suppressOptions) + this.renderDimensionRows('output', suppressOptions) + + const dimensions = this.getMatrixDimensions() + this.matrixEditor.replaceMatrix(filteredSnapshot.matrix, dimensions) + this.updateProjectionStatus(filteredSnapshot.metadata) + } + private buildFormMarkup(): string { return `
@@ -224,7 +294,7 @@ export class LinearLayoutTab extends CanvasTab { ` } - private renderDimensionRows(type: DimensionType): void { + private renderDimensionRows(type: DimensionType, options: RenderDimensionOptions = {}): void { const list = this.dimensionLists[type] list.innerHTML = '' @@ -292,6 +362,9 @@ export class LinearLayoutTab extends CanvasTab { this.updateColorDimensionOptions() } this.applyDimensionValidationState() + if (options.suppressMatrixSync) { + this.suppressMatrixChangeOnce = true + } this.syncEditorAndLayout() } @@ -316,8 +389,13 @@ export class LinearLayoutTab extends CanvasTab { if (this.hasDimensionValidationErrors) { return } - const emittedMatrixChange = this.matrixEditor.updateDimensions(this.getMatrixDimensions()) - if (!emittedMatrixChange) { + const emitChange = !this.suppressMatrixChangeOnce + this.suppressMatrixChangeOnce = false + const emittedMatrixChange = this.matrixEditor.updateDimensions( + this.getMatrixDimensions(), + { emitChange } + ) + if (emitChange && !emittedMatrixChange) { this.rebuildLayoutFromMatrix() } } @@ -402,6 +480,24 @@ export class LinearLayoutTab extends CanvasTab { this.layoutStatus.classList.toggle('visible', Boolean(message)) } + private updateProjectionStatus(metadata?: LayoutSnapshotMetadata): void { + if (!metadata) { + this.setLayoutStatus('') + return + } + const summaryParts: string[] = [] + if (metadata.sourceLayoutType) { + summaryParts.push(`Imported from ${metadata.sourceLayoutType}`) + } + if (metadata.tensorShape?.length) { + summaryParts.push(`tensor ${metadata.tensorShape.join('×')}`) + } + if (metadata.description) { + summaryParts.push(metadata.description) + } + this.setLayoutStatus(summaryParts.join(' · ')) + } + private addDimension(type: DimensionType): void { if (type === 'output' && this.dimensionState.output.length >= 2) { return @@ -509,13 +605,26 @@ export class LinearLayoutTab extends CanvasTab { } private getDefaultColorDimensionId(dimensions: LinearDimension[]): string | undefined { - const preferred = dimensions.find((dimension) => dimension.name.trim().toLowerCase() === 'thread') + const preferred = dimensions.find((dimension) => dimension.name.trim().toLowerCase() === 'lane') if (preferred) { return preferred.id } return dimensions[0]?.id } + private getMetadataPreferredColorDimensionId( + inputs: LinearDimension[], + metadata?: LayoutSnapshotMetadata + ): string | null { + const preferred = metadata?.colorInputDimension?.trim() + if (!preferred) { + return null + } + const normalized = preferred.toLowerCase() + const match = inputs.find((dimension) => dimension.name.trim().toLowerCase() === normalized) + return match ? match.id : null + } + private handleColorDimensionSelectionChange(): void { const value = this.colorDimensionSelect.value this.selectedColorDimensionId = value || null diff --git a/src/ui/LinearLayoutMatrixEditor.ts b/src/ui/LinearLayoutMatrixEditor.ts index ad3283b..88f8f3c 100644 --- a/src/ui/LinearLayoutMatrixEditor.ts +++ b/src/ui/LinearLayoutMatrixEditor.ts @@ -20,6 +20,14 @@ interface DialogSize { height: number } +interface UpdateDimensionsOptions { + emitChange?: boolean +} + +interface ReplaceMatrixOptions { + emitChange?: boolean +} + type ForwardableMouseEventType = | 'mousemove' | 'mouseleave' @@ -273,12 +281,16 @@ export class LinearLayoutMatrixEditor { /** * Update the internal dimension snapshot without showing the modal. */ - public updateDimensions(dimensions: MatrixEditorDimensions): boolean { + public updateDimensions( + dimensions: MatrixEditorDimensions, + options: UpdateDimensionsOptions = {} + ): boolean { + const emitChange = options.emitChange !== false this.currentDimensions = { input: dimensions.input.map((dim) => ({ ...dim })), output: dimensions.output.map((dim) => ({ ...dim })), } - const { shapeChanged, matrixUpdated } = this.rebuildMatrixIfNeeded() + const { shapeChanged, matrixUpdated } = this.rebuildMatrixIfNeeded(emitChange) if (shapeChanged) { this.autoFitToMatrix() } @@ -287,7 +299,7 @@ export class LinearLayoutMatrixEditor { } else { this.needsRender = true } - return matrixUpdated + return emitChange && matrixUpdated } /** @@ -342,6 +354,33 @@ export class LinearLayoutMatrixEditor { return this.matrixValues.map((row) => [...row]) } + public replaceMatrix( + values: number[][], + dimensions?: MatrixEditorDimensions, + options: ReplaceMatrixOptions = {} + ): void { + if (dimensions) { + this.currentDimensions = { + input: dimensions.input.map((dim) => ({ ...dim })), + output: dimensions.output.map((dim) => ({ ...dim })), + } + this.rowBits = this.buildBitDescriptors(this.currentDimensions.output) + this.columnBits = this.buildBitDescriptors(this.currentDimensions.input) + this.signature = this.computeSignature(this.currentDimensions) + } + const rowCount = this.rowBits.length + const colCount = this.columnBits.length + this.matrixValues = this.normalizeMatrix(values, rowCount, colCount) + if (this.isOpen) { + this.renderMatrix() + } else { + this.needsRender = true + } + if (options.emitChange !== false) { + this.notifyMatrixChange() + } + } + public onMatrixChange(listener: (matrix: number[][]) => void): () => void { this.matrixListeners.add(listener) return () => { @@ -360,7 +399,9 @@ export class LinearLayoutMatrixEditor { this.scheduleCellSizeUpdate() } - private rebuildMatrixIfNeeded(): { shapeChanged: boolean; matrixUpdated: boolean } { + private rebuildMatrixIfNeeded( + emitChange: boolean + ): { shapeChanged: boolean; matrixUpdated: boolean } { const rowBits = this.buildBitDescriptors(this.currentDimensions.output) const columnBits = this.buildBitDescriptors(this.currentDimensions.input) const signature = this.computeSignature(this.currentDimensions) @@ -375,7 +416,9 @@ export class LinearLayoutMatrixEditor { ) this.signature = signature this.seedDefaultMatrix(rowBits.length, columnBits.length) - this.notifyMatrixChange() + if (emitChange) { + this.notifyMatrixChange() + } matrixUpdated = true } @@ -656,6 +699,19 @@ export class LinearLayoutMatrixEditor { return `${serialize(dimensions.input)}->${serialize(dimensions.output)}` } + private normalizeMatrix(values: number[][], rows: number, cols: number): number[][] { + if (rows === 0 || cols === 0) { + return Array.from({ length: rows }, () => []) + } + return Array.from({ length: rows }, (_, rowIdx) => { + const sourceRow = values[rowIdx] ?? [] + return Array.from({ length: cols }, (_, colIdx) => { + const cell = sourceRow[colIdx] ?? 0 + return cell & 1 + }) + }) + } + private seedDefaultMatrix(rows: number, cols: number): void { const limit = Math.min(rows, cols) for (let idx = 0; idx < limit; idx++) { diff --git a/src/validation/LinearLayoutValidator.test.ts b/src/validation/LinearLayoutValidator.test.ts index 5b89a70..b573850 100644 --- a/src/validation/LinearLayoutValidator.test.ts +++ b/src/validation/LinearLayoutValidator.test.ts @@ -63,6 +63,32 @@ describe('LinearLayoutValidator', () => { const result = validator.validateDimension({ id: 'dim-6', name: 'warp', + size: 0, + }) + + expect(result.size).toBe('Size must be at least 2') + }) + + it('rejects zero or negative sizes explicitly', () => { + const zeroResult = validator.validateDimension({ + id: 'dim-7', + name: 'warp', + size: 0, + }) + expect(zeroResult.size).toBe('Size must be at least 2') + + const negativeResult = validator.validateDimension({ + id: 'dim-8', + name: 'warp', + size: -4, + }) + expect(negativeResult.size).toBe('Size must be at least 2') + }) + + it('rejects size-1 dimensions explicitly', () => { + const result = validator.validateDimension({ + id: 'dim-10', + name: 'warp', size: 1, }) @@ -71,7 +97,7 @@ describe('LinearLayoutValidator', () => { it('rejects sizes that are not powers of two', () => { const result = validator.validateDimension({ - id: 'dim-7', + id: 'dim-9', name: 'warp', size: 6, })