From 0b4eccf93a0bc85964e8205c5d76b26489aae2b9 Mon Sep 17 00:00:00 2001 From: leeliu103 Date: Tue, 27 Jan 2026 15:53:51 +0000 Subject: [PATCH] Add Shared Layout tab with GPU memory swizzle visualization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements comprehensive Shared Layout feature for visualizing GPU shared memory swizzle patterns with bank conflict analysis. Core Implementation: - SharedLayout: Swizzle math supporting standard phase-based and AMD rotating modes - SharedLayoutValidator: Comprehensive validation for swizzle parameters and tensor shapes - SharedLayoutTab: Interactive tab with real-time bank info and tooltips - 56 new tests achieving full coverage (unit, integration, bank distribution, E2E) Architecture Improvements: - Refactored ParameterForm to generic component for reusability across layout types - Added custom color provider support to CanvasRenderer for bank-based coloring - Fractional bank capacity reporting for accurate multi-bank element modeling - Single-render optimization matching other tab patterns UI Features: - AMD-focused defaults (128×64, vec=4, perPhase=2, maxPhase=4, row-major order) - Real-time bank info display (32 banks, element capacity, segments) - Interactive tooltips showing logical index, offset, segment, and bank assignment - Auto-updating form with comprehensive validation Test Coverage: - 21 unit tests for swizzle formulas and bank calculations - 11 integration tests for round-trip correctness - 5 bank distribution tests - 12 validator tests - 7 tab controller tests - 5 E2E tests for UI interactions Co-Authored-By: Claude --- index.html | 138 +++++++- src/layouts/SharedLayout.bank.test.ts | 129 ++++++++ src/layouts/SharedLayout.integration.test.ts | 179 +++++++++++ src/layouts/SharedLayout.test.ts | 283 +++++++++++++++++ src/layouts/SharedLayout.ts | 248 +++++++++++++++ src/main.tabs.test.ts | 10 + src/main.ts | 3 + src/styles.css | 26 ++ src/tabs/BlockLayoutTab.ts | 58 +++- src/tabs/SharedLayoutTab.test.ts | 316 +++++++++++++++++++ src/tabs/SharedLayoutTab.ts | 299 ++++++++++++++++++ src/ui/ParameterForm.test.ts | 56 +++- src/ui/ParameterForm.ts | 198 ++++-------- src/validation/SharedLayoutValidator.test.ts | 124 ++++++++ src/validation/SharedLayoutValidator.ts | 139 ++++++++ src/visualization/CanvasRenderer.test.ts | 24 ++ src/visualization/CanvasRenderer.ts | 59 ++++ tests/visualization.spec.ts | 178 +++++++++-- 18 files changed, 2297 insertions(+), 170 deletions(-) create mode 100644 src/layouts/SharedLayout.bank.test.ts create mode 100644 src/layouts/SharedLayout.integration.test.ts create mode 100644 src/layouts/SharedLayout.test.ts create mode 100644 src/layouts/SharedLayout.ts create mode 100644 src/tabs/SharedLayoutTab.test.ts create mode 100644 src/tabs/SharedLayoutTab.ts create mode 100644 src/validation/SharedLayoutValidator.test.ts create mode 100644 src/validation/SharedLayoutValidator.ts diff --git a/index.html b/index.html index b84b60f..078e31d 100644 --- a/index.html +++ b/index.html @@ -113,7 +113,143 @@

Tensor Shape

-
+
+ + +
+ +
+
diff --git a/src/layouts/SharedLayout.bank.test.ts b/src/layouts/SharedLayout.bank.test.ts new file mode 100644 index 0000000..25b0ddd --- /dev/null +++ b/src/layouts/SharedLayout.bank.test.ts @@ -0,0 +1,129 @@ +import { describe, expect, it } from 'vitest' +import { assignBank, createSharedLayout, type SharedLayoutParams } from './SharedLayout' + +interface LayoutAccessor { + numCols: number + numRows: number + toOffset(row: number, col: number): number +} + +function buildAccessor(params: SharedLayoutParams): LayoutAccessor { + const { layout, rowDimName, colDimName } = createSharedLayout(params) + const inverse = layout.invert() + const colIndex = params.order[0] + const rowIndex = params.order[1] + const numCols = params.tensorShape[colIndex] + const numRows = params.tensorShape[rowIndex] + return { + numCols, + numRows, + toOffset(row: number, col: number): number { + const inputs: Record = { + [rowDimName]: row, + [colDimName]: col, + } + const { offset = 0 } = inverse.apply(inputs) + return offset + }, + } +} + +function rowMajorOffset(numCols: number, row: number, col: number): number { + return row * numCols + col +} + +describe('SharedLayout bank distribution', () => { + it('reduces column-stripe bank conflicts compared to row-major', () => { + const params: SharedLayoutParams = { + tensorShape: [64, 64], + order: [0, 1], + vec: 8, + perPhase: 4, + maxPhase: 8, + swizzleMode: 'swizzled', + } + const accessor = buildAccessor(params) + const column = 0 + const rows = Array.from({ length: 8 }, (_, idx) => idx) + const rowMajorBanks = rows.map((row) => + assignBank(rowMajorOffset(accessor.numCols, row, column), 16).bank + ) + const swizzledBanks = rows.map((row) => assignBank(accessor.toOffset(row, column), 16).bank) + expect(new Set(rowMajorBanks).size).toBe(1) + expect(new Set(swizzledBanks).size).toBeGreaterThan(new Set(rowMajorBanks).size) + }) + + it('assigns distinct banks to consecutive rows for 32-bit accesses', () => { + const params: SharedLayoutParams = { + tensorShape: [64, 64], + order: [0, 1], + vec: 4, + perPhase: 2, + maxPhase: 8, + swizzleMode: 'swizzled', + } + const accessor = buildAccessor(params) + const column = 5 + const rows = [0, 1, 2, 3] + const rowMajorBanks = rows.map((row) => + assignBank(rowMajorOffset(accessor.numCols, row, column), 32).bank + ) + const banks = rows.map((row) => assignBank(accessor.toOffset(row, column), 32).bank) + expect(new Set(banks).size).toBeGreaterThan(new Set(rowMajorBanks).size) + }) + + it('spans multiple bank segments for large tensors', () => { + const params: SharedLayoutParams = { + tensorShape: [128, 128], + order: [0, 1], + vec: 16, + perPhase: 8, + maxPhase: 16, + swizzleMode: 'swizzled', + } + const accessor = buildAccessor(params) + const start = assignBank(accessor.toOffset(0, 0), 32) + const end = assignBank(accessor.toOffset(accessor.numRows - 1, accessor.numCols - 1), 32) + expect(end.segment).toBeGreaterThan(start.segment) + }) + + it('responds to custom bank topologies', () => { + const params: SharedLayoutParams = { + tensorShape: [64, 64], + order: [0, 1], + vec: 8, + perPhase: 4, + maxPhase: 8, + swizzleMode: 'swizzled', + } + const accessor = buildAccessor(params) + const columns = Array.from({ length: 64 }, (_, idx) => idx % accessor.numCols) + const offsets = columns.map((col) => accessor.toOffset(12, col)) + const sixteenBanks = offsets.map((value) => + assignBank(value, 32, { bankCount: 16, bankSizeBits: 32 }).bank + ) + const sixtyFourBanks = offsets.map((value) => + assignBank(value, 32, { bankCount: 64, bankSizeBits: 32 }).bank + ) + expect(Math.max(...sixteenBanks)).toBeLessThan(16) + expect(Math.max(...sixtyFourBanks)).toBeLessThan(64) + expect(new Set(sixtyFourBanks).size).toBeGreaterThan(new Set(sixteenBanks).size) + }) + + it('varies bank coverage with element size', () => { + const params: SharedLayoutParams = { + tensorShape: [32, 128], + order: [0, 1], + vec: 8, + perPhase: 4, + maxPhase: 16, + swizzleMode: 'swizzled', + } + const accessor = buildAccessor(params) + const row = 10 + const columns = Array.from({ length: 32 }, (_, idx) => idx) + const banks8 = columns.map((col) => assignBank(accessor.toOffset(row, col), 8).bank) + const banks64 = columns.map((col) => assignBank(accessor.toOffset(row, col), 64).bank) + expect(new Set(banks64).size).toBeGreaterThan(new Set(banks8).size) + }) +}) diff --git a/src/layouts/SharedLayout.integration.test.ts b/src/layouts/SharedLayout.integration.test.ts new file mode 100644 index 0000000..ab182d5 --- /dev/null +++ b/src/layouts/SharedLayout.integration.test.ts @@ -0,0 +1,179 @@ +import { describe, expect, it } from 'vitest' +import { assignBank, createSharedLayout, type SharedLayoutParams } from './SharedLayout' + +function enumerateOffsets(params: SharedLayoutParams): number[] { + const { layout, rowDimName, colDimName } = createSharedLayout(params) + const inverse = layout.invert() + const numCols = params.tensorShape[params.order[0]] + const numRows = params.tensorShape[params.order[1]] + const totalElements = numCols * numRows + const offsets = new Set() + + for (let row = 0; row < numRows; row++) { + for (let col = 0; col < numCols; col++) { + const inverseInputs: Record = { + [rowDimName]: row, + [colDimName]: col, + } + const { offset = 0 } = inverse.apply(inverseInputs) + expect(offset).toBeGreaterThanOrEqual(0) + expect(offset).toBeLessThan(totalElements) + expect(offsets.has(offset)).toBe(false) + offsets.add(offset) + + const coords = layout.apply({ offset }) + expect(coords[rowDimName]).toBe(row) + expect(coords[colDimName]).toBe(col) + } + } + + expect(offsets.size).toBe(totalElements) + const sorted = Array.from(offsets).sort((a, b) => a - b) + expect(sorted[0]).toBe(0) + expect(sorted[sorted.length - 1]).toBe(totalElements - 1) + return sorted +} + +describe('SharedLayout integration', () => { + describe('round-trip enumerations', () => { + it('round-trips every element for a 2x2 tensor', () => { + const params: SharedLayoutParams = { + tensorShape: [2, 2], + order: [0, 1], + vec: 1, + perPhase: 1, + maxPhase: 1, + swizzleMode: 'swizzled', + } + enumerateOffsets(params) + }) + + it('round-trips every element for a 4x4 tensor', () => { + const params: SharedLayoutParams = { + tensorShape: [4, 4], + order: [0, 1], + vec: 2, + perPhase: 2, + maxPhase: 2, + swizzleMode: 'swizzled', + } + enumerateOffsets(params) + }) + + it('round-trips every element for an 8x8 tensor', () => { + const params: SharedLayoutParams = { + tensorShape: [8, 8], + order: [0, 1], + vec: 4, + perPhase: 2, + maxPhase: 4, + swizzleMode: 'swizzled', + } + enumerateOffsets(params) + }) + + it('round-trips every element for a 16x16 tensor', () => { + const params: SharedLayoutParams = { + tensorShape: [16, 16], + order: [0, 1], + vec: 8, + perPhase: 4, + maxPhase: 8, + swizzleMode: 'swizzled', + } + enumerateOffsets(params) + }) + }) + + describe('Triton parameter tuples', () => { + const tuples: Array<[number, number, number]> = [ + [1, 1, 1], + [4, 2, 4], + [8, 1, 8], + [16, 1, 16], + ] + it.each(tuples)( + 'round-trips for vec=%i, perPhase=%i, maxPhase=%i', + (vec, perPhase, maxPhase) => { + const params: SharedLayoutParams = { + tensorShape: [16, 16], + order: [0, 1], + vec, + perPhase, + maxPhase, + swizzleMode: 'swizzled', + } + enumerateOffsets(params) + } + ) + }) + + it('round-trips AMD rotating 64x64 layouts without collisions', () => { + const params: SharedLayoutParams = { + tensorShape: [64, 64], + order: [0, 1], + vec: 8, + perPhase: 4, + maxPhase: 8, + swizzleMode: 'amdRotating', + } + const offsets = enumerateOffsets(params) + const blockPeriod = params.perPhase * params.maxPhase + const columnZeroOffsets: number[] = [] + const { layout, rowDimName, colDimName } = createSharedLayout(params) + const inverse = layout.invert() + for (let row = 0; row < params.tensorShape[params.order[1]]; row++) { + const inputs: Record = { + [rowDimName]: row, + [colDimName]: 0, + } + const { offset = 0 } = inverse.apply(inputs) + columnZeroOffsets.push(offset) + } + expect(columnZeroOffsets.slice(0, blockPeriod)).not.toEqual( + columnZeroOffsets.slice(blockPeriod, blockPeriod * 2) + ) + expect(offsets.length).toBe(params.tensorShape[0] * params.tensorShape[1]) + }) + + it('supports non-square tensors and tracks element-size bank coverage', () => { + const params: SharedLayoutParams = { + tensorShape: [32, 128], + order: [0, 1], + vec: 8, + perPhase: 4, + maxPhase: 16, + swizzleMode: 'swizzled', + } + const offsets = enumerateOffsets(params) + const elementBits = [8, 16, 32, 64] + const coverage = elementBits.map((bits) => { + const banks = new Set() + for (const offset of offsets) { + banks.add(assignBank(offset, bits).bank) + } + return { bits, banks: banks.size } + }) + expect(coverage.find((entry) => entry.bits === 64)?.banks).toBeLessThan( + coverage.find((entry) => entry.bits === 8)?.banks ?? 0 + ) + }) + + it('keeps AMD and swizzled offsets aligned to prevent collisions', () => { + const swizzled: SharedLayoutParams = { + tensorShape: [32, 32], + order: [0, 1], + vec: 8, + perPhase: 4, + maxPhase: 8, + swizzleMode: 'swizzled', + } + const amd: SharedLayoutParams = { + ...swizzled, + swizzleMode: 'amdRotating', + } + const swizzledOffsets = enumerateOffsets(swizzled) + const amdOffsets = enumerateOffsets(amd) + expect(amdOffsets).toEqual(swizzledOffsets) + }) +}) diff --git a/src/layouts/SharedLayout.test.ts b/src/layouts/SharedLayout.test.ts new file mode 100644 index 0000000..8418370 --- /dev/null +++ b/src/layouts/SharedLayout.test.ts @@ -0,0 +1,283 @@ +import { describe, expect, it } from 'vitest' +import { + assignBank, + computeBankInfo, + computeRowSwizzle, + computeSwizzledColumn, + createSharedLayout, + type SharedLayoutParams, +} from './SharedLayout' + +const baseParams: SharedLayoutParams = { + tensorShape: [64, 128], + order: [0, 1], + vec: 2, + perPhase: 2, + maxPhase: 4, + swizzleMode: 'swizzled', +} + +describe('createSharedLayout', () => { + it('reorders outputs into canonical dim order regardless of layout order', () => { + const direct = createSharedLayout(baseParams) + expect(direct.layout.getOutDims()).toEqual([ + ['dim0', 64], + ['dim1', 128], + ]) + expect(direct.layout.getInDimNames()).toEqual(['offset']) + + const reversed = createSharedLayout({ + ...baseParams, + order: [1, 0], + }) + expect(reversed.layout.getOutDims()).toEqual([ + ['dim0', 64], + ['dim1', 128], + ]) + expect(reversed.colDimIndex).toBe(1) + expect(reversed.rowDimIndex).toBe(0) + }) + + it('builds layouts for tiny tensors when every parameter is 1', () => { + const tinyParams: SharedLayoutParams = { + tensorShape: [2, 2], + order: [0, 1], + vec: 1, + perPhase: 1, + maxPhase: 1, + swizzleMode: 'swizzled', + } + const result = createSharedLayout(tinyParams) + expect(result.layout.getOutDims()).toEqual([ + ['dim0', 2], + ['dim1', 2], + ]) + }) + + it('supports large tensors where vec equals columns and perPhase equals rows', () => { + const largeParams: SharedLayoutParams = { + tensorShape: [128, 128], + order: [0, 1], + vec: 128, + perPhase: 128, + maxPhase: 128, + swizzleMode: 'swizzled', + } + const result = createSharedLayout(largeParams) + expect(result.layout.getOutDims()).toEqual([ + ['dim0', 128], + ['dim1', 128], + ]) + }) +}) + +describe('row swizzle helpers', () => { + it('computes swizzle for standard shared layouts', () => { + const params = { ...baseParams, vec: 4, perPhase: 2, maxPhase: 8 } satisfies SharedLayoutParams + expect(computeRowSwizzle(0, params)).toBe(0) + expect(computeRowSwizzle(1, params)).toBe(0) + expect(computeRowSwizzle(2, params)).toBe(4) + expect(computeRowSwizzle(16, params)).toBe(0) + }) + + it('computes swizzle for AMD rotating layouts', () => { + const params: SharedLayoutParams = { + ...baseParams, + swizzleMode: 'amdRotating', + vec: 2, + perPhase: 2, + maxPhase: 4, + } + expect(computeRowSwizzle(0, params)).toBe(0) + expect(computeRowSwizzle(2, params)).toBe(2) + expect(computeRowSwizzle(4, params)).toBe(4) + }) + + it('XORs contributions when multiple row bits are set', () => { + const params: SharedLayoutParams = { + ...baseParams, + vec: 8, + perPhase: 2, + maxPhase: 8, + } + expect(computeRowSwizzle(6, params)).toBe( + computeRowSwizzle(2, params) ^ computeRowSwizzle(4, params) + ) + expect(computeRowSwizzle(10, params)).toBe( + computeRowSwizzle(2, params) ^ computeRowSwizzle(8, params) + ) + }) + + it('returns zero swizzle for degenerate tensor sizes', () => { + const singleRow: SharedLayoutParams = { + ...baseParams, + tensorShape: [64, 1], + } + expect(computeRowSwizzle(5, singleRow)).toBe(0) + + const zeroCols: SharedLayoutParams = { + ...baseParams, + tensorShape: [0, 64], + } + expect(computeRowSwizzle(8, zeroCols)).toBe(0) + }) + + it('applies AMD rotating XOR when crossing block boundaries', () => { + const amdParams: SharedLayoutParams = { + tensorShape: [64, 64], + order: [0, 1], + vec: 4, + perPhase: 4, + maxPhase: 4, + swizzleMode: 'amdRotating', + } + const standard = { ...amdParams, swizzleMode: 'swizzled' as const } + expect(computeRowSwizzle(32, standard)).toBe(0) + expect(computeRowSwizzle(32, amdParams)).toBe(8) + expect(computeRowSwizzle(48, amdParams)).toBe(12) + }) + + it('behaves like an identity layout when all tunable parameters equal 1', () => { + const params: SharedLayoutParams = { + tensorShape: [1, 1], + order: [0, 1], + vec: 1, + perPhase: 1, + maxPhase: 1, + swizzleMode: 'swizzled', + } + expect(computeRowSwizzle(0, params)).toBe(0) + expect(computeSwizzledColumn(0, 3, params)).toBe(0) + }) +}) + +describe('computeSwizzledColumn', () => { + it('matches XOR-based swizzle for logical columns', () => { + const params = { ...baseParams, vec: 4, perPhase: 2, maxPhase: 4 } satisfies SharedLayoutParams + const row = 8 + const rowSwizzle = computeRowSwizzle(row, params) + const originalCol = 21 + const swizzled = computeSwizzledColumn(row, originalCol, params) + expect(swizzled).toBe((originalCol ^ rowSwizzle) % params.tensorShape[params.order[0]]) + }) + + it('normalizes logical columns that exceed the tensor width', () => { + const params = { ...baseParams, vec: 2, perPhase: 2, maxPhase: 4 } satisfies SharedLayoutParams + const numCols = params.tensorShape[params.order[0]] + const row = 5 + const rowSwizzle = computeRowSwizzle(row, params) + const wrappedCol = numCols * 3 + 7 + expect(computeSwizzledColumn(row, wrappedCol, params)).toBe( + (wrappedCol % numCols) ^ rowSwizzle + ) + }) + + it('propagates negative logical columns through the XOR normalization', () => { + const params = { ...baseParams, vec: 2, perPhase: 2, maxPhase: 4 } satisfies SharedLayoutParams + const numCols = params.tensorShape[params.order[0]] + const row = 3 + const rowSwizzle = computeRowSwizzle(row, params) + const negativeCol = -9 + expect(computeSwizzledColumn(row, negativeCol, params)).toBe( + (negativeCol % numCols) ^ rowSwizzle + ) + }) + + it('returns the original column when there are zero logical columns', () => { + const params: SharedLayoutParams = { + ...baseParams, + tensorShape: [0, 64], + } + expect(computeSwizzledColumn(3, 99, params)).toBe(99) + }) +}) + +describe('bank calculations', () => { + it('reports per-bank statistics derived from tensor shape', () => { + const info = computeBankInfo(baseParams.tensorShape, 16) + expect(info.bankCount).toBe(32) + expect(info.bankSizeBits).toBe(32) + expect(info.elementsPerBankRow).toBe(2) + expect(info.segmentsPerBankRow).toBe(128) + }) + + it('assigns banks and segments based on offsets', () => { + const first = assignBank(0, 16) + const next = assignBank(8, 16) + expect(first.bank).toBe(0) + expect(first.segment).toBe(0) + expect(next.bank).toBeGreaterThanOrEqual(0) + expect(next.segment).toBeGreaterThanOrEqual(0) + }) + + it('scales bank information with different element bitwidths', () => { + const bitWidths = [8, 16, 32, 64] + const expectedElements: Record = { + 8: 4, + 16: 2, + 32: 1, + 64: 0.5, + } + for (const bits of bitWidths) { + const info = computeBankInfo([64, 64], bits) + expect(info.elementsPerBankRow).toBe(expectedElements[bits]) + expect(info.segmentsPerBankRow).toBeGreaterThan(0) + } + }) + + it('derives per-bank capacity directly from the configured bank width', () => { + const largerBanks = computeBankInfo([8, 8], 16, { bankCount: 64, bankSizeBits: 64 }) + expect(largerBanks.elementsPerBankRow).toBe(4) + + const smallerBanks = computeBankInfo([8, 8], 16, { bankCount: 16, bankSizeBits: 16 }) + expect(smallerBanks.elementsPerBankRow).toBe(1) + + // Ensure bank count alone does not change the per-bank capacity + const moreBanksSameWidth = computeBankInfo([8, 8], 16, { bankCount: 128, bankSizeBits: 32 }) + const fewerBanksSameWidth = computeBankInfo([8, 8], 16, { bankCount: 16, bankSizeBits: 32 }) + expect(moreBanksSameWidth.elementsPerBankRow).toBe(fewerBanksSameWidth.elementsPerBankRow) + }) + + it('derives segments for tiny and large tensors', () => { + const tiny = computeBankInfo([2, 2], 32) + expect(tiny.segmentsPerBankRow).toBe(1) + + const large = computeBankInfo([128, 128], 16) + expect(large.segmentsPerBankRow).toBeGreaterThan(tiny.segmentsPerBankRow) + }) + + it('assigns banks using overrides and clamps edge cases', () => { + const overrides = { bankCount: 16, bankSizeBits: 64 } + const regular = assignBank(5, 32, overrides) + expect(regular.bank).toBeLessThan(16) + expect(regular.segment).toBe(0) + + const negative = assignBank(-10, 32, overrides) + expect(negative.bank).toBe(0) + expect(negative.segment).toBe(0) + + const far = assignBank(4096, 32, overrides) + expect(far.segment).toBeGreaterThan(0) + }) +}) + +describe('validation', () => { + it('rejects tensor shapes that are not powers of two', () => { + expect(() => + createSharedLayout({ ...baseParams, tensorShape: [12, 128] }) + ).toThrow(/Column count must be a power of two/) + expect(() => + createSharedLayout({ ...baseParams, tensorShape: [64, 96] }) + ).toThrow(/Row count must be a power of two/) + }) + + it('rejects non power-of-two parameters', () => { + expect(() => createSharedLayout({ ...baseParams, vec: 3 })).toThrow(/vec must be a power of two/) + expect(() => createSharedLayout({ ...baseParams, perPhase: 6 })).toThrow( + /perPhase must be a power of two/ + ) + expect(() => createSharedLayout({ ...baseParams, maxPhase: 10 })).toThrow( + /maxPhase must be a power of two/ + ) + }) +}) diff --git a/src/layouts/SharedLayout.ts b/src/layouts/SharedLayout.ts new file mode 100644 index 0000000..62af498 --- /dev/null +++ b/src/layouts/SharedLayout.ts @@ -0,0 +1,248 @@ +import { LinearLayout } from '../core/LinearLayout' + +export const SHARED_BANK_COUNT = 32 +export const SHARED_BANK_WIDTH_BITS = 32 + +export type SharedSwizzleMode = 'swizzled' | 'amdRotating' + +export interface SharedLayoutParams { + tensorShape: [number, number] + order: [number, number] + vec: number + perPhase: number + maxPhase: number + swizzleMode: SharedSwizzleMode +} + +export interface SharedLayoutBuildResult { + layout: LinearLayout + rowDimName: string + colDimName: string + rowDimIndex: number + colDimIndex: number +} + +export interface SharedBankInfo { + bankCount: number + bankSizeBits: number + elementsPerBankRow: number + segmentsPerBankRow: number +} + +export interface BankAssignmentOptions { + bankCount?: number + bankSizeBits?: number +} + +type DimIndex = 0 | 1 + +function assertPositive(value: number, context: string): void { + if (!Number.isFinite(value) || value <= 0) { + throw new Error(`${context} must be a positive number`) + } +} + +function assertPowerOfTwo(value: number, context: string): void { + assertPositive(value, context) + if (!Number.isInteger(value)) { + throw new Error(`${context} must be a positive integer`) + } + if ((value & (value - 1)) !== 0) { + throw new Error(`${context} must be a power of two`) + } +} + +function assertDimIndex(value: number, context: string): asserts value is DimIndex { + if (value !== 0 && value !== 1) { + throw new Error(`${context} must be 0 or 1`) + } +} + +function assertDefined(value: T | undefined, context: string): T { + if (value === undefined) { + throw new Error(`${context} is missing`) + } + return value +} + +function normalizeOrder(order: [number, number]): [DimIndex, DimIndex] { + const [first, second] = order + assertDimIndex(first, 'Order[0]') + assertDimIndex(second, 'Order[1]') + if ((first === 0 && second === 1) || (first === 1 && second === 0)) { + return [first, second] + } + throw new Error('Order must be [0,1] or [1,0]') +} + +export function createSharedLayout(params: SharedLayoutParams): SharedLayoutBuildResult { + const dims: [string, string] = ['dim0', 'dim1'] + const shape = params.tensorShape + if (shape.length !== 2) { + throw new Error('Shared layout currently only supports 2D tensors') + } + + const [colDimIndex, rowDimIndex] = normalizeOrder(params.order) + const numCols = assertDefined(shape[colDimIndex], 'Column count') + const numRows = assertDefined(shape[rowDimIndex], 'Row count') + + assertPowerOfTwo(numCols, 'Column count') + assertPowerOfTwo(numRows, 'Row count') + assertPowerOfTwo(params.vec, 'vec') + assertPowerOfTwo(params.perPhase, 'perPhase') + assertPowerOfTwo(params.maxPhase, 'maxPhase') + + const bases = buildSharedBases(numRows, numCols, params) + const rowDimName = assertDefined(dims[rowDimIndex], 'Row dimension name') + const colDimName = assertDefined(dims[colDimIndex], 'Column dimension name') + + const layout = new LinearLayout( + [['offset', bases]], + [rowDimName, colDimName] + ).reorderOutputs(dims) + + return { + layout, + rowDimName, + colDimName, + rowDimIndex, + colDimIndex, + } +} + +function buildSharedBases( + numRows: number, + numCols: number, + params: SharedLayoutParams +): number[][] { + const bases: number[][] = [] + + for (let col = 1; col < numCols; col <<= 1) { + bases.push([0, col]) + } + + for (let row = 1; row < numRows; row <<= 1) { + const swizzle = params.swizzleMode === 'amdRotating' + ? computeAmdRotatingContribution(row, params.vec, params.perPhase, params.maxPhase, numCols) + : computeSwizzledContribution(row, params.vec, params.perPhase, params.maxPhase, numCols) + bases.push([row, swizzle]) + } + + return bases +} + +function computeSwizzledContribution( + rowBitValue: number, + vec: number, + perPhase: number, + maxPhase: number, + numCols: number +): number { + if (perPhase <= 0 || maxPhase <= 0 || numCols <= 0) { + return 0 + } + const phase = Math.floor(rowBitValue / perPhase) % maxPhase + const contribution = vec * phase + return contribution % numCols +} + +function computeAmdRotatingContribution( + rowBitValue: number, + vec: number, + perPhase: number, + maxPhase: number, + numCols: number +): number { + if (perPhase <= 0 || maxPhase <= 0 || numCols <= 0) { + return 0 + } + const phase = Math.floor(rowBitValue / perPhase) % maxPhase + const blockNo = Math.floor(rowBitValue / perPhase / maxPhase) % maxPhase + const combinedPhase = phase ^ blockNo + const contribution = vec * combinedPhase + return contribution % numCols +} + +export function computeRowSwizzle(row: number, params: SharedLayoutParams): number { + const [colDimIndex, rowDimIndex] = normalizeOrder(params.order) + const numRows = assertDefined(params.tensorShape[rowDimIndex], 'Row count') + const numCols = assertDefined(params.tensorShape[colDimIndex], 'Column count') + if (numRows <= 1 || numCols <= 0) { + return 0 + } + + let swizzle = 0 + for (let bit = 1; bit < numRows; bit <<= 1) { + if ((row & bit) === 0) { + continue + } + const contribution = params.swizzleMode === 'amdRotating' + ? computeAmdRotatingContribution(bit, params.vec, params.perPhase, params.maxPhase, numCols) + : computeSwizzledContribution(bit, params.vec, params.perPhase, params.maxPhase, numCols) + swizzle ^= contribution + } + + return swizzle % numCols +} + +export function computeSwizzledColumn( + row: number, + column: number, + params: SharedLayoutParams +): number { + const [colDimIndex] = normalizeOrder(params.order) + const numCols = assertDefined(params.tensorShape[colDimIndex], 'Column count') + if (numCols <= 0) { + return column + } + const rowSwizzle = computeRowSwizzle(row, params) + const normalizedCol = column % numCols + return normalizedCol ^ rowSwizzle +} + +export function computeBankInfo( + tensorShape: [number, number], + elementBits: number, + overrides?: BankAssignmentOptions +): SharedBankInfo { + assertPositive(elementBits, 'Element bitwidth') + const bankCount = overrides?.bankCount ?? SHARED_BANK_COUNT + const bankSizeBits = overrides?.bankSizeBits ?? SHARED_BANK_WIDTH_BITS + assertPositive(bankCount, 'Bank count') + assertPositive(bankSizeBits, 'Bank size (bits)') + + const bytesPerElement = Math.max(1, Math.trunc(elementBits / 8)) + const totalBankRowBits = bankCount * bankSizeBits + const elementsPerBankRow = bankSizeBits / elementBits + const totalElements = Math.max(1, tensorShape[0] * tensorShape[1]) + const totalBytes = totalElements * bytesPerElement + const bankRowBytes = Math.max(1, Math.trunc(totalBankRowBits / 8)) + const segmentsPerBankRow = Math.max(1, Math.ceil(totalBytes / bankRowBytes)) + + return { + bankCount, + bankSizeBits, + elementsPerBankRow, + segmentsPerBankRow, + } +} + +export function assignBank( + offset: number, + elementBits: number, + overrides?: BankAssignmentOptions +): { bank: number; segment: number } { + assertPositive(elementBits, 'Element bitwidth') + const bankCount = overrides?.bankCount ?? SHARED_BANK_COUNT + const bankSizeBits = overrides?.bankSizeBits ?? SHARED_BANK_WIDTH_BITS + assertPositive(bankCount, 'Bank count') + assertPositive(bankSizeBits, 'Bank size (bits)') + + const bytesPerElement = Math.max(1, Math.trunc(elementBits / 8)) + const bankWidthBytes = Math.max(1, Math.trunc(bankSizeBits / 8)) + const byteOffset = Math.max(0, offset) * bytesPerElement + const bankAddress = Math.floor(byteOffset / bankWidthBytes) + const bank = ((bankAddress % bankCount) + bankCount) % bankCount + const segment = Math.floor(bankAddress / bankCount) + return { bank, segment } +} diff --git a/src/main.tabs.test.ts b/src/main.tabs.test.ts index b5736c0..6f7f60e 100644 --- a/src/main.tabs.test.ts +++ b/src/main.tabs.test.ts @@ -73,6 +73,16 @@ vi.mock('./tabs/WMMALayoutTab', () => ({ WMMALayoutTab: WMMALayoutTabMock, })) +const SharedLayoutTabMock = vi.fn().mockImplementation(() => ({ + activate: vi.fn(), + deactivate: vi.fn(), + resize: vi.fn(), +})) + +vi.mock('./tabs/SharedLayoutTab', () => ({ + SharedLayoutTab: SharedLayoutTabMock, +})) + const MFMALayoutTabMock = vi.fn().mockImplementation(() => ({ activate: vi.fn(), deactivate: vi.fn(), diff --git a/src/main.ts b/src/main.ts index fc12206..2cec6b2 100644 --- a/src/main.ts +++ b/src/main.ts @@ -3,6 +3,7 @@ import { WMMALayoutTab } from './tabs/WMMALayoutTab' import { MFMALayoutTab } from './tabs/MFMALayoutTab' import { LinearLayoutTab } from './tabs/LinearLayoutTab' import { layoutProjectionBus } from './integration/LayoutProjectionBus' +import { SharedLayoutTab } from './tabs/SharedLayoutTab' type TabController = { activate(): void @@ -54,9 +55,11 @@ const setActiveTab = (tabId: string): void => { layoutProjectionBus.setTabActivationHandler(setActiveTab) const linearLayoutTab = new LinearLayoutTab('linear-layout') +const sharedLayoutTab = new SharedLayoutTab('shared-layout') const blockLayoutTab = new BlockLayoutTab('block-layout') controllers.set('block-layout', blockLayoutTab) +controllers.set('shared-layout', sharedLayoutTab) controllers.set('wmma-layout', new WMMALayoutTab('wmma-layout')) controllers.set('mfma-layout', new MFMALayoutTab('mfma-layout')) controllers.set('linear-layout', linearLayoutTab) diff --git a/src/styles.css b/src/styles.css index a80e1f0..9dd153d 100644 --- a/src/styles.css +++ b/src/styles.css @@ -430,6 +430,32 @@ button:disabled:hover { padding-bottom: 0.75rem; } +.bank-info { + padding: 1rem; + background-color: #f8f9fa; + border-radius: 4px; + border: 1px solid #dee2e6; +} + +.bank-info > div { + padding: 0.5rem 0; + font-size: 0.85rem; + line-height: 1.6; +} + +.bank-info > div:not(:last-child) { + border-bottom: 1px solid #e9ecef; + padding-bottom: 0.75rem; +} + +.bank-info strong { + font-weight: 600; +} + +.bank-info span { + margin-left: 0.5rem; +} + .validation-errors, .validation-warnings { margin-bottom: 1rem; diff --git a/src/tabs/BlockLayoutTab.ts b/src/tabs/BlockLayoutTab.ts index 3d1d23f..6c74d27 100644 --- a/src/tabs/BlockLayoutTab.ts +++ b/src/tabs/BlockLayoutTab.ts @@ -1,6 +1,6 @@ import type { LinearLayout } from '../core/LinearLayout' import { createBlockLayout } from '../layouts/BlockLayout' -import type { BlockLayoutParams } from '../validation/InputValidator' +import { InputValidator, type BlockLayoutParams } from '../validation/InputValidator' import { CanvasRenderer } from '../visualization/CanvasRenderer' import { ParameterForm } from '../ui/ParameterForm' import { renderSharedControls } from '../ui/renderSharedControls' @@ -12,8 +12,9 @@ import type { SnapshotFilterResult } from '../core/filterSnapshotDimensions' * updates, and hover tooltips while reusing the shared CanvasTab interactions. */ export class BlockLayoutTab extends CanvasTab { - private readonly form: ParameterForm + private readonly form: ParameterForm private readonly tabId: string + private static readonly validator = new InputValidator() constructor(tabId: string) { const tabContent = document.getElementById(tabId) @@ -47,7 +48,14 @@ export class BlockLayoutTab extends CanvasTab { super(elements) this.tabId = tabId - this.form = new ParameterForm('paramForm') + this.form = new ParameterForm({ + formId: 'paramForm', + errorsContainerId: 'validation-errors', + warningsContainerId: 'validation-warnings', + readParams: readBlockLayoutParams, + validateParams: (params) => + BlockLayoutTab.validator.validateBlockLayoutParams(params), + }) const linearLayoutButton = this.setupShowInLinearLayoutButton() if (linearLayoutButton) { this.monitorShowInLinearLayoutButton(linearLayoutButton) @@ -169,5 +177,49 @@ export class BlockLayoutTab extends CanvasTab { private getColorDimensionPreference(): string { return 'warp' } +} + +const readBlockLayoutParams = (form: HTMLFormElement): BlockLayoutParams => ({ + sizePerThread: [ + getNumericValue(form, 'sizePerThread0'), + getNumericValue(form, 'sizePerThread1'), + ] as [number, number], + threadsPerWarp: [ + getNumericValue(form, 'threadsPerWarp0'), + getNumericValue(form, 'threadsPerWarp1'), + ] as [number, number], + warpsPerCTA: [ + getNumericValue(form, 'warpsPerCTA0'), + getNumericValue(form, 'warpsPerCTA1'), + ] as [number, number], + order: getOrderValue(form), + tensorShape: [ + getNumericValue(form, 'tensorShape0'), + getNumericValue(form, 'tensorShape1'), + ] as [number, number], +}) + +const getNumericValue = (form: HTMLFormElement, elementId: string): number => { + const input = form.querySelector(`#${elementId}`) + if (!input) { + throw new Error(`BlockLayoutTab input not found: #${elementId}`) + } + const rawValue = input.value.trim() + if (rawValue === '') { + return Number.NaN + } + const numericValue = Number(rawValue) + return Number.isFinite(numericValue) ? numericValue : Number.NaN +} +const getOrderValue = (form: HTMLFormElement): [number, number] => { + const select = form.querySelector('#order') + if (!select) { + throw new Error('BlockLayoutTab order select not found') + } + const [dim0, dim1] = select.value.split(',').map((value) => Number(value.trim())) + if (!Number.isInteger(dim0) || !Number.isInteger(dim1)) { + throw new Error(`Invalid order value: ${select.value}`) + } + return [dim0, dim1] as [number, number] } diff --git a/src/tabs/SharedLayoutTab.test.ts b/src/tabs/SharedLayoutTab.test.ts new file mode 100644 index 0000000..ae33e90 --- /dev/null +++ b/src/tabs/SharedLayoutTab.test.ts @@ -0,0 +1,316 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +type RendererStub = { + render: ReturnType + reset: ReturnType + setCustomColorProvider: ReturnType + screenToGrid: ReturnType + getCellInfo: ReturnType +} + +const { rendererInstances, createRendererStub } = vi.hoisted(() => { + const instances: RendererStub[] = [] + const buildStub = (): RendererStub => ({ + render: vi.fn(), + reset: vi.fn(), + setCustomColorProvider: vi.fn(), + screenToGrid: vi.fn().mockReturnValue({ row: 0, col: 0 }), + getCellInfo: vi.fn().mockReturnValue({ + threadId: 0, + registerId: 0, + warpId: 0, + position: [0, 0], + sourcePosition: [0, 0], + inputCoords: { offset: 0 }, + outputCoords: { dim0: 0, dim1: 0 }, + }), + }) + return { rendererInstances: instances, createRendererStub: buildStub } +}) + +const { tooltipInstances, createTooltipStub } = vi.hoisted(() => { + const instances: Array<{ show: ReturnType; hide: ReturnType }> = [] + const buildStub = () => { + const stub = { + show: vi.fn(), + hide: vi.fn(), + } + instances.push(stub) + return stub + } + return { tooltipInstances: instances, createTooltipStub: buildStub } +}) + +vi.mock('../visualization/CanvasRenderer', () => ({ + CanvasRenderer: vi.fn().mockImplementation(() => { + const stub = createRendererStub() + rendererInstances.push(stub) + return stub + }), +})) + +vi.mock('../ui/Tooltip', () => ({ + Tooltip: vi.fn().mockImplementation(() => createTooltipStub()), +})) + +import { SharedLayoutTab } from './SharedLayoutTab' + +const buildCanvas = (): HTMLCanvasElement => { + const canvas = document.createElement('canvas') + canvas.id = 'shared-canvas' + canvas.width = 600 + canvas.height = 400 + Object.defineProperty(canvas, 'getContext', { + value: vi.fn(() => ({ + fillRect: vi.fn(), + clearRect: vi.fn(), + beginPath: vi.fn(), + moveTo: vi.fn(), + lineTo: vi.fn(), + stroke: vi.fn(), + fillText: vi.fn(), + save: vi.fn(), + restore: vi.fn(), + translate: vi.fn(), + scale: vi.fn(), + })), + }) + Object.defineProperty(canvas, 'getBoundingClientRect', { + value: () => ({ + width: 600, + height: 400, + left: 0, + top: 0, + right: 600, + bottom: 400, + x: 0, + y: 0, + toJSON: () => ({}), + }), + }) + return canvas +} + +const setupDom = () => { + document.body.innerHTML = ` +
+
+ +
+
+
+ ` + + const visualization = document.querySelector('.visualization') + Object.defineProperty(visualization, 'getBoundingClientRect', { + value: () => ({ + width: 800, + height: 600, + left: 0, + top: 0, + right: 800, + bottom: 600, + x: 0, + y: 0, + toJSON: () => ({}), + }), + }) + visualization?.appendChild(buildCanvas()) +} + +describe('SharedLayoutTab', () => { + beforeEach(() => { + vi.resetModules() + rendererInstances.length = 0 + tooltipInstances.length = 0 + document.body.innerHTML = '' + setupDom() + vi.stubGlobal('alert', vi.fn()) + }) + + it('initializes the renderer with a bank color provider', () => { + new SharedLayoutTab('shared-layout') + expect(rendererInstances.length).toBe(1) + const instance = rendererInstances[0] + expect(instance?.setCustomColorProvider).toHaveBeenCalled() + }) + + it('re-renders when tensor shape changes', () => { + const tab = new SharedLayoutTab('shared-layout') + expect(tab).toBeDefined() + const rowsInput = document.getElementById('shared-rows') as HTMLInputElement + rowsInput.value = '256' + rowsInput.dispatchEvent(new Event('input')) + + expect(rendererInstances.length).toBe(2) + const segmentsText = document.getElementById('shared-bank-segments')?.textContent + expect(segmentsText).toBe('256') + }) + + it('shows validation errors when inputs are invalid', () => { + const tab = new SharedLayoutTab('shared-layout') + expect(tab).toBeDefined() + const rowsInput = document.getElementById('shared-rows') as HTMLInputElement + rowsInput.value = '63' + rowsInput.dispatchEvent(new Event('input')) + + const errors = document.getElementById('shared-validation-errors') + expect(errors?.textContent).toContain('power of two') + }) + + it('clears bank info when validation fails', () => { + const tab = new SharedLayoutTab('shared-layout') + expect(tab).toBeDefined() + const rowsInput = document.getElementById('shared-rows') as HTMLInputElement + rowsInput.value = '63' + rowsInput.dispatchEvent(new Event('input')) + + const segmentsText = document.getElementById('shared-bank-segments')?.textContent + expect(segmentsText).toBe('--') + }) + + describe('hover tooltips', () => { + const triggerHover = ( + tab: SharedLayoutTab, + cellData: { row: number; col: number; offset: number } + ): void => { + const renderer = rendererInstances[rendererInstances.length - 1] + renderer?.getCellInfo.mockReturnValue({ + threadId: 0, + registerId: 0, + warpId: 0, + position: [cellData.row, cellData.col], + sourcePosition: [cellData.row, cellData.col], + inputCoords: { offset: cellData.offset }, + outputCoords: { dim0: cellData.row, dim1: cellData.col }, + }) + const hoverable = tab as unknown as { handleHover: (event: MouseEvent) => void } + hoverable.handleHover(new MouseEvent('mousemove', { clientX: 200, clientY: 150 })) + } + + const expectTooltipContains = (text: string): void => { + const tooltip = tooltipInstances[tooltipInstances.length - 1] + expect(tooltip).toBeDefined() + const calls = tooltip?.show.mock.calls ?? [] + expect(calls.length).toBeGreaterThan(0) + const content = calls[calls.length - 1]?.[0] as string + expect(content).toContain(text) + } + + it('shows correct tooltip values for the first cell', () => { + const tab = new SharedLayoutTab('shared-layout') + triggerHover(tab, { row: 0, col: 0, offset: 0 }) + + expectTooltipContains('Logical Index: (0, 0)') + expectTooltipContains('Offset: 0') + expectTooltipContains('Segment (Bank Row): 0') + expectTooltipContains('Bank: 0') + }) + + it('shows tooltip values for a middle cell with a non-zero bank', () => { + const tab = new SharedLayoutTab('shared-layout') + triggerHover(tab, { row: 12, col: 28, offset: 3 }) + + expectTooltipContains('Logical Index: (12, 28)') + expectTooltipContains('Offset: 3') + expectTooltipContains('Segment (Bank Row): 0') + expectTooltipContains('Bank: 1') + }) + + it('shows tooltip values for a high offset cell spanning multiple segments', () => { + const tab = new SharedLayoutTab('shared-layout') + triggerHover(tab, { row: 127, col: 63, offset: 515 }) + + expectTooltipContains('Logical Index: (127, 63)') + expectTooltipContains('Offset: 515') + expectTooltipContains('Segment (Bank Row): 8') + expectTooltipContains('Bank: 1') + }) + }) +}) diff --git a/src/tabs/SharedLayoutTab.ts b/src/tabs/SharedLayoutTab.ts new file mode 100644 index 0000000..8d5bd59 --- /dev/null +++ b/src/tabs/SharedLayoutTab.ts @@ -0,0 +1,299 @@ +import { CanvasTab, type CanvasTabElements } from './CanvasTab' +import { renderSharedControls } from '../ui/renderSharedControls' +import { CanvasRenderer, type CellInfo } from '../visualization/CanvasRenderer' +import { ParameterForm } from '../ui/ParameterForm' +import type { BlockLayoutParams } from '../validation/InputValidator' +import { + assignBank, + computeBankInfo, + createSharedLayout, + type SharedLayoutBuildResult, + type SharedLayoutParams, +} from '../layouts/SharedLayout' +import { + SharedLayoutValidator, + type SharedLayoutUiParams, + type SharedLayoutViewMode, +} from '../validation/SharedLayoutValidator' + +interface BankInfoElements { + bankCount: HTMLElement + bankSize: HTMLElement + elemsPerBank: HTMLElement + segmentsPerBank: HTMLElement +} + +export class SharedLayoutTab extends CanvasTab { + private readonly form: ParameterForm + private readonly validator = new SharedLayoutValidator() + private readonly bankInfoElements: BankInfoElements + + private currentParams: SharedLayoutUiParams | null = null + private currentLayout: SharedLayoutBuildResult | null = null + private currentElementBits = 16 + + constructor(tabId: string) { + const tabContent = document.getElementById(tabId) + if (!tabContent) { + throw new Error(`SharedLayoutTab container not found: ${tabId}`) + } + + const visualizationContainer = tabContent.querySelector('.visualization') + if (!(visualizationContainer instanceof HTMLElement)) { + throw new Error('SharedLayoutTab visualization container not found') + } + + const canvas = visualizationContainer.querySelector('canvas') + if (!(canvas instanceof HTMLCanvasElement)) { + throw new Error('SharedLayoutTab canvas element not found') + } + + const controlsContainer = tabContent.querySelector('[data-controls]') + if (!(controlsContainer instanceof HTMLElement)) { + throw new Error('SharedLayoutTab controls container not found') + } + const resetButton = renderSharedControls(controlsContainer, { resetButtonId: 'shared-reset' }) + + if (!(tabContent.querySelector('#sharedForm') instanceof HTMLFormElement)) { + throw new Error('SharedLayoutTab form element not found') + } + if ( + !(tabContent.querySelector('#shared-validation-errors') instanceof HTMLElement) || + !(tabContent.querySelector('#shared-validation-warnings') instanceof HTMLElement) + ) { + throw new Error('SharedLayoutTab validation containers not found') + } + + const bankCount = tabContent.querySelector('#shared-bank-count') + const bankSize = tabContent.querySelector('#shared-bank-size') + const elemsPerBank = tabContent.querySelector('#shared-elems-per-bank') + const segmentsPerBank = tabContent.querySelector('#shared-bank-segments') + if ( + !(bankCount instanceof HTMLElement) || + !(bankSize instanceof HTMLElement) || + !(elemsPerBank instanceof HTMLElement) || + !(segmentsPerBank instanceof HTMLElement) + ) { + throw new Error('SharedLayoutTab bank info elements missing') + } + + const elements: CanvasTabElements = { + root: tabContent, + canvas, + visualizationContainer, + resetButton, + } + + super(elements) + + this.bankInfoElements = { + bankCount, + bankSize, + elemsPerBank, + segmentsPerBank, + } + + this.form = new ParameterForm({ + formId: 'sharedForm', + errorsContainerId: 'shared-validation-errors', + warningsContainerId: 'shared-validation-warnings', + readParams: readSharedLayoutParams, + validateParams: (params) => this.validator.validate(params), + }) + + this.form.onParamsChange((params) => { + this.refreshVisualization(params) + }) + this.form.onValidationChange((isValid) => { + if (!isValid) { + this.currentParams = null + this.currentLayout = null + this.currentElementBits = 16 + this.clearBankInfo() + this.hideTooltip() + } + }) + + const initialParams = this.form.getParams() + if (this.form.validate()) { + this.refreshVisualization(initialParams) + } else { + this.currentParams = null + this.currentLayout = null + this.currentElementBits = 16 + this.clearBankInfo() + } + } + + protected handleHover(event: MouseEvent): void { + const renderer = this.getRenderer() + if (!renderer || !this.currentParams || !this.currentLayout) { + this.hideTooltip() + return + } + + const rect = this.canvas.getBoundingClientRect() + const x = event.clientX - rect.left + const y = event.clientY - rect.top + const gridPos = renderer.screenToGrid(x, y) + const cellInfo = renderer.getCellInfo(gridPos.row, gridPos.col) + + if (!cellInfo) { + this.hideTooltip() + return + } + + const row = this.readOutputCoordinate( + cellInfo, + this.currentLayout.rowDimName, + this.currentLayout.rowDimIndex + ) + const col = this.readOutputCoordinate( + cellInfo, + this.currentLayout.colDimName, + this.currentLayout.colDimIndex + ) + const offset = this.getCellOffset(cellInfo) + const { bank, segment } = assignBank(offset, this.currentElementBits) + const tooltipLines = [ + `
Logical Index: (${row}, ${col})
`, + `
Offset: ${offset}
`, + `
Segment (Bank Row): ${segment}
`, + `
Bank: ${bank}
`, + ] + + this.tooltip.show(tooltipLines.join(''), event.clientX, event.clientY) + } + + protected resetHover(): void { + this.hideTooltip() + } + + private refreshVisualization(params: SharedLayoutUiParams): void { + this.currentParams = params + this.currentElementBits = params.elementBits + this.updateBankInfo(params) + + if (params.viewMode !== 'logical') { + // Bank view not implemented yet; fallback to logical rendering. + this.currentLayout = null + return + } + + try { + this.updateVisualization(params) + } catch (error) { + console.error('Failed to render shared layout', error) + const errorMessage = error instanceof Error ? error.message : String(error) + alert(`Failed to render shared layout: ${errorMessage}`) + } + } + + private updateVisualization(params: SharedLayoutUiParams): void { + const layoutResult = createSharedLayout(params) + const rendererParams: BlockLayoutParams = { + sizePerThread: [1, 1], + threadsPerWarp: [1, 1], + warpsPerCTA: [1, 1], + order: [0, 1], + tensorShape: params.tensorShape, + } + + const renderer = new CanvasRenderer( + this.canvas, + layoutResult.layout, + rendererParams, + undefined, + { + traversalMode: 'by-output', + colorGrouping: 'warp', + showCellText: false, + } + ) + this.setRenderer(renderer) + renderer.setCustomColorProvider((cell) => this.resolveBankForCell(cell, params.elementBits)) + this.currentLayout = layoutResult + } + + private resolveBankForCell(cell: CellInfo, elementBits: number): number { + const { bank } = assignBank(this.getCellOffset(cell), elementBits) + return bank + } + + private getCellOffset(cellInfo: CellInfo): number { + const offset = cellInfo.inputCoords?.offset + if (typeof offset === 'number' && Number.isFinite(offset)) { + return offset + } + return 0 + } + + private updateBankInfo(params: SharedLayoutUiParams): void { + const info = computeBankInfo(params.tensorShape, params.elementBits) + this.bankInfoElements.bankCount.textContent = info.bankCount.toString() + this.bankInfoElements.bankSize.textContent = `${info.bankSizeBits} bits` + this.bankInfoElements.elemsPerBank.textContent = info.elementsPerBankRow.toString() + this.bankInfoElements.segmentsPerBank.textContent = info.segmentsPerBankRow.toString() + } + + private clearBankInfo(): void { + Object.values(this.bankInfoElements).forEach((element) => { + element.textContent = '--' + }) + } + + private readOutputCoordinate( + cellInfo: CellInfo, + dimName: string, + fallbackIndex: number + ): number { + const fromCoords = cellInfo.outputCoords?.[dimName] + if (typeof fromCoords === 'number' && Number.isFinite(fromCoords)) { + return fromCoords + } + return cellInfo.position[fallbackIndex] ?? 0 + } +} + +const readSharedLayoutParams = (form: HTMLFormElement): SharedLayoutUiParams => ({ + tensorShape: [ + getNumericValue(form, '#shared-rows'), + getNumericValue(form, '#shared-cols'), + ] as [number, number], + order: parseOrder(form, '#shared-order'), + vec: getNumericValue(form, '#shared-vec'), + perPhase: getNumericValue(form, '#shared-per-phase'), + maxPhase: getNumericValue(form, '#shared-max-phase'), + swizzleMode: getSelectValue(form, '#shared-swizzle-mode') as SharedLayoutParams['swizzleMode'], + elementBits: getNumericValue(form, '#shared-element-bits'), + viewMode: getSelectValue(form, '#shared-view-mode') as SharedLayoutViewMode, +}) + +const getNumericValue = (form: HTMLFormElement, selector: string): number => { + const element = form.querySelector(selector) + if (!element) { + throw new Error(`SharedLayoutTab field not found: ${selector}`) + } + + const rawValue = element.value.trim() + if (rawValue === '') { + return Number.NaN + } + + const numericValue = Number(rawValue) + return Number.isFinite(numericValue) ? numericValue : Number.NaN +} + +const getSelectValue = (form: HTMLFormElement, selector: string): string => { + const select = form.querySelector(selector) + if (!select) { + throw new Error(`SharedLayoutTab select not found: ${selector}`) + } + return select.value +} + +const parseOrder = (form: HTMLFormElement, selector: string): [number, number] => { + const value = getSelectValue(form, selector) + const [first, second] = value.split(',').map((part) => Number(part.trim())) + return [first ?? Number.NaN, second ?? Number.NaN] +} diff --git a/src/ui/ParameterForm.test.ts b/src/ui/ParameterForm.test.ts index 3ed5abb..886646e 100644 --- a/src/ui/ParameterForm.test.ts +++ b/src/ui/ParameterForm.test.ts @@ -1,5 +1,6 @@ import { describe, it, beforeEach, expect, vi } from 'vitest' import { ParameterForm } from './ParameterForm' +import { InputValidator, type BlockLayoutParams } from '../validation/InputValidator' const buildFormDom = (): void => { document.body.innerHTML = ` @@ -28,13 +29,64 @@ const buildFormDom = (): void => { ` } +const getNumberValue = (form: HTMLFormElement, id: string): number => { + const input = form.querySelector(`#${id}`) + if (!input) { + throw new Error(`Input not found: ${id}`) + } + const rawValue = input.value.trim() + if (rawValue === '') { + return Number.NaN + } + const numericValue = Number(rawValue) + return Number.isFinite(numericValue) ? numericValue : Number.NaN +} + +const getOrderValue = (form: HTMLFormElement): [number, number] => { + const select = form.querySelector('#order') + if (!select) { + throw new Error('Order select not found') + } + const [dim0, dim1] = select.value.split(',').map(Number) as [number, number] + return [dim0, dim1] +} + +const buildForm = (): ParameterForm => { + const validator = new InputValidator() + return new ParameterForm({ + formId: 'paramForm', + errorsContainerId: 'validation-errors', + warningsContainerId: 'validation-warnings', + readParams: (form) => ({ + sizePerThread: [ + getNumberValue(form, 'sizePerThread0'), + getNumberValue(form, 'sizePerThread1'), + ] as [number, number], + threadsPerWarp: [ + getNumberValue(form, 'threadsPerWarp0'), + getNumberValue(form, 'threadsPerWarp1'), + ] as [number, number], + warpsPerCTA: [ + getNumberValue(form, 'warpsPerCTA0'), + getNumberValue(form, 'warpsPerCTA1'), + ] as [number, number], + order: getOrderValue(form), + tensorShape: [ + getNumberValue(form, 'tensorShape0'), + getNumberValue(form, 'tensorShape1'), + ] as [number, number], + }), + validateParams: (params) => validator.validateBlockLayoutParams(params), + }) +} + describe('ParameterForm', () => { beforeEach(() => { buildFormDom() }) it('shows inline errors for decimal input instead of throwing', () => { - const form = new ParameterForm('paramForm') + const form = buildForm() const sizeInput = document.getElementById('sizePerThread0') as HTMLInputElement sizeInput.value = '2.5' @@ -47,7 +99,7 @@ describe('ParameterForm', () => { }) it('suppresses callbacks and flips validation state when parse errors occur', () => { - const form = new ParameterForm('paramForm') + const form = buildForm() const changeSpy = vi.fn() const validationSpy = vi.fn() form.onParamsChange(changeSpy) diff --git a/src/ui/ParameterForm.ts b/src/ui/ParameterForm.ts index 2bcacac..1ed9bed 100644 --- a/src/ui/ParameterForm.ts +++ b/src/ui/ParameterForm.ts @@ -1,118 +1,120 @@ -import { InputValidator, type BlockLayoutParams } from '../validation/InputValidator' - -export class ParameterForm { - private form: HTMLFormElement - private errorsDiv: HTMLElement - private warningsDiv: HTMLElement - private validator: InputValidator - private validationListeners = new Set<(isValid: boolean) => void>() +import type { ValidationResult } from '../validation/InputValidator' + +export interface ParameterFormConfig { + formId: string + errorsContainerId: string + warningsContainerId: string + readParams: (form: HTMLFormElement) => TParams + validateParams: (params: TParams) => ValidationResult +} + +export class ParameterForm { + private readonly form: HTMLFormElement + private readonly errorsDiv: HTMLElement + private readonly warningsDiv: HTMLElement + private readonly readParamsFn: (form: HTMLFormElement) => TParams + private readonly validateParamsFn: (params: TParams) => ValidationResult + private readonly validationListeners = new Set<(isValid: boolean) => void>() private lastValidationResult: boolean | null = null - constructor(formId: string) { - const form = document.getElementById(formId) - if (!form || !(form instanceof HTMLFormElement)) { - throw new Error(`Form not found: ${formId}`) + constructor(config: ParameterFormConfig) { + const form = document.getElementById(config.formId) + if (!(form instanceof HTMLFormElement)) { + throw new Error(`Form not found: ${config.formId}`) } this.form = form - const errorsDiv = document.getElementById('validation-errors') - const warningsDiv = document.getElementById('validation-warnings') - - if (!errorsDiv || !warningsDiv) { + const errorsDiv = document.getElementById(config.errorsContainerId) + const warningsDiv = document.getElementById(config.warningsContainerId) + if (!(errorsDiv instanceof HTMLElement) || !(warningsDiv instanceof HTMLElement)) { throw new Error('Validation message divs not found') } this.errorsDiv = errorsDiv this.warningsDiv = warningsDiv - this.validator = new InputValidator() + this.readParamsFn = config.readParams + this.validateParamsFn = config.validateParams } - /** - * Get current parameter values from the form - */ - getParams(): BlockLayoutParams { - return { - sizePerThread: [ - this.getNumberValue('sizePerThread0'), - this.getNumberValue('sizePerThread1'), - ] as [number, number], - threadsPerWarp: [ - this.getNumberValue('threadsPerWarp0'), - this.getNumberValue('threadsPerWarp1'), - ] as [number, number], - warpsPerCTA: [ - this.getNumberValue('warpsPerCTA0'), - this.getNumberValue('warpsPerCTA1'), - ] as [number, number], - order: this.getOrderValue(), - tensorShape: [ - this.getNumberValue('tensorShape0'), - this.getNumberValue('tensorShape1'), - ] as [number, number], - } + getParams(): TParams { + return this.readParamsFn(this.form) } - /** - * Validate current parameters and display errors/warnings - * Returns true if valid (no errors), false otherwise - */ validate(): boolean { const params = this.getParams() - const result = this.validator.validateBlockLayoutParams(params) + const result = this.runValidation(params) + return result.valid + } + + onParamsChange(callback: (params: TParams) => void): void { + const handleChange = (): void => { + const params = this.getParams() + const result = this.runValidation(params) + if (result.valid) { + callback(params) + } + } + + const fields = this.form.querySelectorAll('input, select') + fields.forEach((field) => { + const eventName = field instanceof HTMLSelectElement ? 'change' : 'input' + field.addEventListener(eventName, handleChange) + }) + + this.form.addEventListener('submit', (event) => event.preventDefault()) + } + + onValidationChange(callback: (isValid: boolean) => void): void { + this.validationListeners.add(callback) + + if (this.lastValidationResult === null) { + this.validate() + return + } + + callback(this.lastValidationResult) + } + + private runValidation(params: TParams): ValidationResult { + const result = this.validateParamsFn(params) - // Display errors if (result.errors.size > 0) { - const errorList = Array.from(result.errors.values()) - this.showErrors(errorList) + this.showErrors(Array.from(result.errors.values())) } else { this.hideErrors() } - // Display warnings if (result.warnings.size > 0) { - const warningList = Array.from(result.warnings.values()) - this.showWarnings(warningList) + this.showWarnings(Array.from(result.warnings.values())) } else { this.hideWarnings() } this.updateValidationState(result.valid) - return result.valid + return result } - /** - * Show error messages - */ private showErrors(errors: string[]): void { this.errorsDiv.innerHTML = ` Errors: -
    ${errors.map(err => `
  • ${err}
  • `).join('')}
+
    ${errors.map((err) => `
  • ${err}
  • `).join('')}
` this.errorsDiv.classList.add('visible') } - /** - * Hide error messages - */ private hideErrors(): void { this.errorsDiv.classList.remove('visible') this.errorsDiv.innerHTML = '' } - /** - * Show warning messages - */ private showWarnings(warnings: string[]): void { this.warningsDiv.innerHTML = ` Warnings: -
    ${warnings.map(warn => `
  • ${warn}
  • `).join('')}
+
    ${warnings.map((warn) => `
  • ${warn}
  • `).join('')}
` this.warningsDiv.classList.add('visible') } - /** - * Hide warning messages - */ private hideWarnings(): void { this.warningsDiv.classList.remove('visible') this.warningsDiv.innerHTML = '' @@ -125,68 +127,4 @@ export class ParameterForm { this.lastValidationResult = nextState this.validationListeners.forEach((listener) => listener(nextState)) } - - /** - * Invoke the callback whenever any form field changes and the inputs validate. - */ - onParamsChange(callback: (params: BlockLayoutParams) => void): void { - const handleChange = (): void => { - if (this.validate()) { - callback(this.getParams()) - } - } - - const fields = this.form.querySelectorAll('input, select') - fields.forEach((field) => { - const eventName = field instanceof HTMLSelectElement ? 'change' : 'input' - field.addEventListener(eventName, handleChange) - }) - - this.form.addEventListener('submit', (event) => event.preventDefault()) - } - - /** - * Notify listeners whenever the validation state changes. - */ - onValidationChange(callback: (isValid: boolean) => void): void { - this.validationListeners.add(callback) - - if (this.lastValidationResult === null) { - this.validate() - return - } - - callback(this.lastValidationResult) - } - - /** - * Get number value from input - */ - private getNumberValue(id: string): number { - const input = document.getElementById(id) as HTMLInputElement - if (!input) { - throw new Error(`Input not found: ${id}`) - } - const rawValue = input.value.trim() - if (rawValue === '') { - return Number.NaN - } - const numericValue = Number(rawValue) - return Number.isFinite(numericValue) ? numericValue : Number.NaN - } - - /** - * Get order value from dropdown - */ - private getOrderValue(): [number, number] { - const select = document.getElementById('order') as HTMLSelectElement - if (!select) { - throw new Error('Order select not found') - } - const [dim0, dim1] = select.value.split(',').map(Number) as [number, number] - if (!Number.isInteger(dim0) || !Number.isInteger(dim1)) { - throw new Error(`Invalid order value: ${select.value}`) - } - return [dim0, dim1] - } } diff --git a/src/validation/SharedLayoutValidator.test.ts b/src/validation/SharedLayoutValidator.test.ts new file mode 100644 index 0000000..6366143 --- /dev/null +++ b/src/validation/SharedLayoutValidator.test.ts @@ -0,0 +1,124 @@ +import { describe, expect, it } from 'vitest' +import type { SharedLayoutUiParams } from './SharedLayoutValidator' +import { SharedLayoutValidator } from './SharedLayoutValidator' + +const baseParams: SharedLayoutUiParams = { + tensorShape: [128, 64], + order: [1, 0], + vec: 4, + perPhase: 2, + maxPhase: 4, + swizzleMode: 'swizzled', + elementBits: 16, + viewMode: 'logical', +} + +describe('SharedLayoutValidator', () => { + const validator = new SharedLayoutValidator() + + it('accepts the default configuration', () => { + const result = validator.validate(baseParams) + expect(result.valid).toBe(true) + expect(result.errors.size).toBe(0) + }) + + it('rejects non power-of-two tensor shapes', () => { + const result = validator.validate({ + ...baseParams, + tensorShape: [63, 128], + }) + expect(result.valid).toBe(false) + expect(result.errors.get('tensorShape')).toContain('power of two') + }) + + it('rejects unsupported element bit widths', () => { + const result = validator.validate({ + ...baseParams, + elementBits: 12, + }) + expect(result.valid).toBe(false) + expect(result.errors.get('elementBits')).toContain('8, 16, 32, or 64') + }) + + it('enforces valid order combinations', () => { + const result = validator.validate({ + ...baseParams, + order: [0, 0], + }) + expect(result.valid).toBe(false) + expect(result.errors.get('order')).toContain('[0,1] or [1,0]') + }) + + it('accepts reversed order', () => { + const result = validator.validate({ + ...baseParams, + order: [0, 1], + }) + expect(result.valid).toBe(true) + }) + + it('requires a supported swizzle mode', () => { + const result = validator.validate({ + ...baseParams, + swizzleMode: 'unknown' as any, + }) + expect(result.valid).toBe(false) + expect(result.errors.get('swizzleMode')).toContain('"swizzled" or "amdRotating"') + }) + + it('caps vec at the hardware limit', () => { + const result = validator.validate({ + ...baseParams, + vec: 32, + }) + expect(result.valid).toBe(false) + expect(result.errors.get('vec')).toContain('must not exceed 16') + }) + + it('caps perPhase at the hardware limit', () => { + const result = validator.validate({ + ...baseParams, + perPhase: 32, + }) + expect(result.valid).toBe(false) + expect(result.errors.get('perPhase')).toContain('must not exceed 16') + }) + + it('caps maxPhase at the hardware limit', () => { + const result = validator.validate({ + ...baseParams, + maxPhase: 32, + }) + expect(result.valid).toBe(false) + expect(result.errors.get('maxPhase')).toContain('must not exceed 16') + }) + + it('flags vec values that exceed the selected column dimension', () => { + const result = validator.validate({ + ...baseParams, + tensorShape: [128, 8], + vec: 16, + }) + expect(result.valid).toBe(false) + expect(result.errors.get('vec')).toContain('must not exceed the number of columns') + }) + + it('flags perPhase values larger than the row dimension', () => { + const result = validator.validate({ + ...baseParams, + tensorShape: [8, 64], + perPhase: 16, + }) + expect(result.valid).toBe(false) + expect(result.errors.get('perPhase')).toContain('must not exceed the number of rows') + }) + + it('emits a warning when bank view is requested', () => { + const result = validator.validate({ + ...baseParams, + viewMode: 'bank', + }) + expect(result.valid).toBe(true) + expect(result.warnings.get('viewMode')).toContain('Only logical view is implemented') + }) +}) diff --git a/src/validation/SharedLayoutValidator.ts b/src/validation/SharedLayoutValidator.ts new file mode 100644 index 0000000..67c2c5e --- /dev/null +++ b/src/validation/SharedLayoutValidator.ts @@ -0,0 +1,139 @@ +import type { SharedLayoutParams } from '../layouts/SharedLayout' +import type { ValidationResult } from './InputValidator' + +export type SharedLayoutViewMode = 'logical' | 'bank' + +export interface SharedLayoutUiParams extends SharedLayoutParams { + elementBits: number + viewMode: SharedLayoutViewMode +} + +const SWIZZLE_PARAM_LIMIT = 16 + +export class SharedLayoutValidator { + validate(params: SharedLayoutUiParams): ValidationResult { + const errors = new Map() + const warnings = new Map() + + const shapeValid = this.validateShape(params.tensorShape, errors) + const orderValid = this.validateOrder(params.order, errors) + this.validateSwizzleMode(params.swizzleMode, errors) + this.validateSwizzleParams(params, errors, shapeValid, orderValid) + this.validateElementBits(params.elementBits, errors) + this.validateViewMode(params.viewMode, warnings) + + return { + valid: errors.size === 0, + errors, + warnings, + } + } + + private validateShape(shape: [number, number], errors: Map): boolean { + let valid = true + const labels = ['Rows', 'Cols'] + shape.forEach((value, index) => { + if (!Number.isInteger(value) || value <= 0) { + errors.set('tensorShape', `${labels[index]} must be a positive integer`) + valid = false + return + } + if (!this.isPowerOfTwo(value)) { + errors.set('tensorShape', `${labels[index]} must be a power of two`) + valid = false + } + }) + return valid + } + + private validateOrder(order: [number, number], errors: Map): boolean { + const valid = + (order[0] === 0 && order[1] === 1) || + (order[0] === 1 && order[1] === 0) + if (!valid) { + errors.set('order', 'Order must be [0,1] or [1,0]') + } + return valid + } + + private validateSwizzleMode(swizzleMode: string, errors: Map): void { + if (swizzleMode !== 'swizzled' && swizzleMode !== 'amdRotating') { + errors.set('swizzleMode', 'swizzleMode must be "swizzled" or "amdRotating"') + } + } + + private validateSwizzleParams( + params: SharedLayoutParams, + errors: Map, + shapeValid: boolean, + orderValid: boolean + ): void { + const { vec, perPhase, maxPhase, tensorShape, order } = params + const validations: Array<[number, string]> = [ + [vec, 'vec'], + [perPhase, 'perPhase'], + [maxPhase, 'maxPhase'], + ] + + const invalidFields = new Set() + + validations.forEach(([value, label]) => { + if (!Number.isInteger(value) || value <= 0) { + errors.set(label, `${label} must be a positive integer`) + invalidFields.add(label) + } else if (!this.isPowerOfTwo(value)) { + errors.set(label, `${label} must be a power of two`) + invalidFields.add(label) + } + }) + + const enforceHardwareLimit = (value: number, label: string): void => { + if (!invalidFields.has(label) && value > SWIZZLE_PARAM_LIMIT) { + errors.set(label, `${label} must not exceed ${SWIZZLE_PARAM_LIMIT} (AMD/NVIDIA hardware limit)`) + invalidFields.add(label) + } + } + + enforceHardwareLimit(vec, 'vec') + enforceHardwareLimit(perPhase, 'perPhase') + enforceHardwareLimit(maxPhase, 'maxPhase') + + if (!invalidFields.has('maxPhase') && maxPhase <= 0) { + errors.set('maxPhase', 'maxPhase must be positive') + invalidFields.add('maxPhase') + } + + if (!shapeValid || !orderValid) { + return + } + + const colDimIndex = order[0] + const rowDimIndex = order[1] + const colSize = tensorShape[colDimIndex] ?? 0 + const rowSize = tensorShape[rowDimIndex] ?? 0 + + if (!invalidFields.has('vec') && vec > colSize) { + errors.set('vec', 'vec must not exceed the number of columns') + } + if (!invalidFields.has('perPhase') && perPhase > rowSize) { + errors.set('perPhase', 'perPhase must not exceed the number of rows') + } + } + + private validateElementBits(elementBits: number, errors: Map): void { + const allowed = new Set([8, 16, 32, 64]) + if (!allowed.has(elementBits)) { + errors.set('elementBits', 'Element bits must be 8, 16, 32, or 64') + } + } + + private validateViewMode(viewMode: SharedLayoutViewMode, warnings: Map): void { + if (viewMode !== 'logical') { + warnings.set('viewMode', 'Only logical view is implemented') + } + } + + private isPowerOfTwo(value: number): boolean { + return value > 0 && (value & (value - 1)) === 0 + } +} diff --git a/src/visualization/CanvasRenderer.test.ts b/src/visualization/CanvasRenderer.test.ts index 438e53e..a93f74f 100644 --- a/src/visualization/CanvasRenderer.test.ts +++ b/src/visualization/CanvasRenderer.test.ts @@ -178,6 +178,30 @@ describe('CanvasRenderer', () => { }) }) + describe('custom color provider', () => { + it('applies custom color indices when provided', () => { + const provider = vi.fn((cell: CellInfo) => cell.position[0]) + const warpSpy = vi.spyOn(ColorScheme.prototype, 'getColorForWarp') + + try { + renderer.setCustomColorProvider(provider) + renderer.render() + expect(provider).toHaveBeenCalled() + expect(warpSpy).toHaveBeenCalled() + } finally { + warpSpy.mockRestore() + } + }) + + it('falls back to default coloring when provider is cleared', () => { + const provider = vi.fn(() => 0) + renderer.setCustomColorProvider(provider) + renderer.setCustomColorProvider(undefined) + const color = renderer.getWarpColor(0) + expect(typeof color).toBe('string') + }) + }) + describe('viewport integration', () => { it('should support zoom in', () => { expect(() => renderer.zoomIn()).not.toThrow() diff --git a/src/visualization/CanvasRenderer.ts b/src/visualization/CanvasRenderer.ts index beab541..ba73bea 100644 --- a/src/visualization/CanvasRenderer.ts +++ b/src/visualization/CanvasRenderer.ts @@ -12,6 +12,7 @@ interface CanvasRendererOptions { traversalMode?: TraversalMode showCellText?: boolean colorInputDimension?: string + customColorProvider?: (cell: CellInfo) => number | null | undefined } export interface CellInfo { @@ -81,6 +82,8 @@ export class CanvasRenderer { private colorGrouping: ColorGroupingMode private showCellText: boolean private customColorDimension?: string + private customColorProvider?: (cell: CellInfo) => number | null | undefined + private customColorGroupCount = 0 private maxThreadIdObserved = 0 constructor( @@ -102,6 +105,7 @@ export class CanvasRenderer { this.showCellText = options?.showCellText ?? true const trimmedColorDim = options?.colorInputDimension?.trim() this.customColorDimension = trimmedColorDim && trimmedColorDim.length > 0 ? trimmedColorDim : undefined + this.customColorProvider = options?.customColorProvider this.resetColorSchemeFromParams() // Initialize viewport controller @@ -127,6 +131,7 @@ export class CanvasRenderer { // Build cell data cache this.cellDataCache = this.buildCellDataCache() + this.applyCustomColorProvider() } /** @@ -404,6 +409,39 @@ export class CanvasRenderer { this.colorScheme = new ColorScheme(colorCount, 1) } + private applyCustomColorProvider(): void { + if (!this.customColorProvider) { + if (this.customColorGroupCount > 0) { + this.customColorGroupCount = 0 + if (this.traversalMode === 'by-output') { + this.rebuildColorSchemeFromThreads(this.maxThreadIdObserved) + } else { + this.resetColorSchemeFromParams() + } + } + return + } + + let maxGroup = 0 + let observed = false + for (const entries of this.cellDataCache.values()) { + for (const cell of entries) { + const value = this.customColorProvider(cell) + if (typeof value === 'number' && Number.isFinite(value)) { + const normalized = Math.max(0, Math.trunc(value)) + maxGroup = Math.max(maxGroup, normalized) + observed = true + } + } + } + + const desiredGroups = Math.max(1, observed ? maxGroup + 1 : 1) + if (desiredGroups !== this.customColorGroupCount) { + this.customColorGroupCount = desiredGroups + this.colorScheme = new ColorScheme(desiredGroups, 1) + } + } + /** * Render the entire visualization */ @@ -487,6 +525,13 @@ export class CanvasRenderer { } private getCellFillColor(cellInfo: CellInfo): string { + if (this.customColorProvider) { + const value = this.customColorProvider(cellInfo) + if (typeof value === 'number' && Number.isFinite(value)) { + const normalized = Math.max(0, Math.trunc(value)) + return this.colorScheme.getColorForWarp(normalized) + } + } if (this.customColorDimension) { const value = cellInfo.inputCoords?.[this.customColorDimension] if (typeof value === 'number' && Number.isFinite(value)) { @@ -638,6 +683,8 @@ export class CanvasRenderer { } setColorByInputDimension(dimensionName?: string): void { + this.customColorProvider = undefined + this.customColorGroupCount = 0 const normalized = dimensionName?.trim() const nextDimension = normalized && normalized.length > 0 ? normalized : undefined if (this.customColorDimension === nextDimension) { @@ -654,6 +701,17 @@ export class CanvasRenderer { this.render() } + setCustomColorProvider(provider?: (cell: CellInfo) => number | null | undefined): void { + this.customColorProvider = provider + if (provider) { + this.customColorDimension = undefined + } else { + this.customColorGroupCount = 0 + } + this.applyCustomColorProvider() + this.render() + } + /** * Zoom in */ @@ -709,6 +767,7 @@ export class CanvasRenderer { } this.cellDataCache = this.buildCellDataCache() + this.applyCustomColorProvider() this.render() } } diff --git a/tests/visualization.spec.ts b/tests/visualization.spec.ts index b28a9c8..ab7f552 100644 --- a/tests/visualization.spec.ts +++ b/tests/visualization.spec.ts @@ -1,14 +1,42 @@ import { test, expect } from '@playwright/test' +import type { Page } from '@playwright/test' + +const waitForSharedCanvas = async (page: Page): Promise => { + await page.waitForFunction(() => { + const canvas = document.querySelector('#shared-canvas') + return Boolean(canvas && canvas.width > 0 && canvas.height > 0) + }) +} + +const openSharedLayoutTab = async (page: Page): Promise => { + await page.goto('/') + await page.locator('#tab-shared-layout').click() + await expect(page.locator('#shared-layout')).toHaveClass(/active/) + await waitForSharedCanvas(page) +} + +const captureConsoleErrors = (page: Page): (() => string[]) => { + const errors: string[] = [] + page.on('console', (message) => { + if (message.type() === 'error') { + errors.push(message.text()) + } + }) + page.on('pageerror', (error) => { + errors.push(error.message ?? String(error)) + }) + return () => errors +} test.describe('Triton Layout Visualizer', () => { test('should load the application', async ({ page }) => { await page.goto('/') // Check that the title is correct - await expect(page).toHaveTitle(/Triton Layout Visualizer/) + await expect(page).toHaveTitle(/GPU Tensor Layout Visualizer/) // Check that main heading is visible - await expect(page.locator('h1')).toContainText('Triton Block Layout Visualizer') + await expect(page.locator('h1')).toContainText('GPU Tensor Layout Visualizer') }) test('should render the canvas', async ({ page }) => { @@ -61,12 +89,9 @@ test.describe('Triton Layout Visualizer', () => { // Enter invalid value (not power of 2) await page.locator('#sizePerThread0').fill('3') - // Submit form - await page.locator('button[type="submit"]').click() - // Check that error message is displayed + await page.waitForSelector('#validation-errors.visible') const errors = page.locator('#validation-errors') - await expect(errors).toBeVisible() await expect(errors).toContainText('power') }) @@ -81,35 +106,12 @@ test.describe('Triton Layout Visualizer', () => { await page.locator('#tensorShape0').fill('32') await page.locator('#tensorShape1').fill('32') - // Submit form - await page.locator('button[type="submit"]').click() - - // Wait a bit for rendering - await page.waitForTimeout(100) - - // Take another screenshot - const newScreenshot = await canvas.screenshot() - // Screenshots should be different - expect(Buffer.compare(initialScreenshot, newScreenshot)).not.toBe(0) - }) - - test('should have zoom controls', async ({ page }) => { - await page.goto('/') - - // Check that zoom buttons exist - await expect(page.locator('#zoomIn')).toBeVisible() - await expect(page.locator('#zoomOut')).toBeVisible() - await expect(page.locator('#reset')).toBeVisible() - - // Click zoom in button - await page.locator('#zoomIn').click() - - // Click zoom out button - await page.locator('#zoomOut').click() - - // Click reset button - await page.locator('#reset').click() + await page.waitForTimeout(100) + await expect.poll(async () => { + const newScreenshot = await canvas.screenshot() + return Buffer.compare(initialScreenshot, newScreenshot) + }).not.toBe(0) }) test('should show tooltip when hovering WMMA cells', async ({ page }) => { @@ -147,3 +149,111 @@ test.describe('Triton Layout Visualizer', () => { }) }) }) + +test.describe('Shared Layout', () => { + test('Tab switching activates Shared Layout without console errors', async ({ page }) => { + const getConsoleErrors = captureConsoleErrors(page) + await openSharedLayoutTab(page) + + const sharedContent = page.locator('#shared-layout') + await expect(sharedContent).toHaveClass(/active/) + await expect(sharedContent).toBeVisible() + await expect(page.locator('#block-layout')).toBeHidden() + await expect(page.locator('#shared-canvas')).toBeVisible() + + expect(getConsoleErrors()).toEqual([]) + }) + + test('Form controls are visible with default Shared Layout values', async ({ page }) => { + await openSharedLayoutTab(page) + + const controlSelectors = [ + '#shared-swizzle-mode', + '#shared-view-mode', + '#shared-vec', + '#shared-per-phase', + '#shared-max-phase', + '#shared-order', + '#shared-rows', + '#shared-cols', + '#shared-element-bits', + ] + + for (const selector of controlSelectors) { + await expect(page.locator(selector)).toBeVisible() + } + + await expect(page.locator('#shared-rows')).toHaveValue('128') + await expect(page.locator('#shared-cols')).toHaveValue('64') + await expect(page.locator('#shared-vec')).toHaveValue('4') + await expect(page.locator('#shared-per-phase')).toHaveValue('2') + await expect(page.locator('#shared-max-phase')).toHaveValue('4') + await expect(page.locator('#shared-order')).toHaveValue('1,0') + await expect(page.locator('#shared-element-bits')).toHaveValue('16') + await expect(page.locator('#shared-swizzle-mode')).toHaveValue('swizzled') + await expect(page.locator('#shared-view-mode')).toHaveValue('logical') + }) + + test('Validation errors appear and bank info clears for invalid tensors', async ({ page }) => { + await openSharedLayoutTab(page) + + await page.locator('#shared-rows').fill('3') + + const errors = page.locator('#shared-validation-errors') + await expect(errors).toHaveClass(/visible/) + await expect(errors).toContainText(/power/i) + + const bankSelectors = [ + '#shared-bank-count', + '#shared-bank-size', + '#shared-elems-per-bank', + '#shared-bank-segments', + ] + + for (const selector of bankSelectors) { + await expect(page.locator(selector)).toHaveText('--') + } + }) + + test('Shared canvas renders and rerenders after tensor updates', async ({ page }) => { + await openSharedLayoutTab(page) + + const canvas = page.locator('#shared-canvas') + await expect(canvas).toBeVisible() + + const boundingBox = await canvas.boundingBox() + expect(boundingBox).not.toBeNull() + expect(boundingBox!.width).toBeGreaterThan(0) + expect(boundingBox!.height).toBeGreaterThan(0) + + const initialScreenshot = await canvas.screenshot() + + await page.locator('#shared-rows').fill('256') + await expect(page.locator('#shared-bank-segments')).toHaveText('256') + + const updatedScreenshot = await canvas.screenshot() + expect(Buffer.compare(initialScreenshot, updatedScreenshot)).not.toBe(0) + }) + + test('Tooltip shows Shared Layout metadata without legacy fields', async ({ page }) => { + await openSharedLayoutTab(page) + + const canvas = page.locator('#shared-canvas') + await expect(canvas).toBeVisible() + + const box = await canvas.boundingBox() + expect(box).not.toBeNull() + await page.mouse.move(box!.x + box!.width / 2, box!.y + box!.height / 2) + + const tooltip = page.locator('.layout-tooltip') + await expect(tooltip).toBeVisible() + await expect(tooltip).toContainText('Logical Index:') + await expect(tooltip).toContainText('Offset:') + await expect(tooltip).toContainText('Bank:') + + const tooltipText = ((await tooltip.textContent()) ?? '').trim() + expect(tooltipText).toMatch(/Segment (?:\(Bank Row\):|:)/) + await expect(tooltip).not.toContainText('Original Col') + await expect(tooltip).not.toContainText('Swizzled Col') + }) +})