diff --git a/index.html b/index.html
index b84b60f..078e31d 100644
--- a/index.html
+++ b/index.html
@@ -113,7 +113,143 @@
diff --git a/src/layouts/SharedLayout.bank.test.ts b/src/layouts/SharedLayout.bank.test.ts
new file mode 100644
index 0000000..25b0ddd
--- /dev/null
+++ b/src/layouts/SharedLayout.bank.test.ts
@@ -0,0 +1,129 @@
+import { describe, expect, it } from 'vitest'
+import { assignBank, createSharedLayout, type SharedLayoutParams } from './SharedLayout'
+
+interface LayoutAccessor {
+ numCols: number
+ numRows: number
+ toOffset(row: number, col: number): number
+}
+
+function buildAccessor(params: SharedLayoutParams): LayoutAccessor {
+ const { layout, rowDimName, colDimName } = createSharedLayout(params)
+ const inverse = layout.invert()
+ const colIndex = params.order[0]
+ const rowIndex = params.order[1]
+ const numCols = params.tensorShape[colIndex]
+ const numRows = params.tensorShape[rowIndex]
+ return {
+ numCols,
+ numRows,
+ toOffset(row: number, col: number): number {
+ const inputs: Record
= {
+ [rowDimName]: row,
+ [colDimName]: col,
+ }
+ const { offset = 0 } = inverse.apply(inputs)
+ return offset
+ },
+ }
+}
+
+function rowMajorOffset(numCols: number, row: number, col: number): number {
+ return row * numCols + col
+}
+
+describe('SharedLayout bank distribution', () => {
+ it('reduces column-stripe bank conflicts compared to row-major', () => {
+ const params: SharedLayoutParams = {
+ tensorShape: [64, 64],
+ order: [0, 1],
+ vec: 8,
+ perPhase: 4,
+ maxPhase: 8,
+ swizzleMode: 'swizzled',
+ }
+ const accessor = buildAccessor(params)
+ const column = 0
+ const rows = Array.from({ length: 8 }, (_, idx) => idx)
+ const rowMajorBanks = rows.map((row) =>
+ assignBank(rowMajorOffset(accessor.numCols, row, column), 16).bank
+ )
+ const swizzledBanks = rows.map((row) => assignBank(accessor.toOffset(row, column), 16).bank)
+ expect(new Set(rowMajorBanks).size).toBe(1)
+ expect(new Set(swizzledBanks).size).toBeGreaterThan(new Set(rowMajorBanks).size)
+ })
+
+ it('assigns distinct banks to consecutive rows for 32-bit accesses', () => {
+ const params: SharedLayoutParams = {
+ tensorShape: [64, 64],
+ order: [0, 1],
+ vec: 4,
+ perPhase: 2,
+ maxPhase: 8,
+ swizzleMode: 'swizzled',
+ }
+ const accessor = buildAccessor(params)
+ const column = 5
+ const rows = [0, 1, 2, 3]
+ const rowMajorBanks = rows.map((row) =>
+ assignBank(rowMajorOffset(accessor.numCols, row, column), 32).bank
+ )
+ const banks = rows.map((row) => assignBank(accessor.toOffset(row, column), 32).bank)
+ expect(new Set(banks).size).toBeGreaterThan(new Set(rowMajorBanks).size)
+ })
+
+ it('spans multiple bank segments for large tensors', () => {
+ const params: SharedLayoutParams = {
+ tensorShape: [128, 128],
+ order: [0, 1],
+ vec: 16,
+ perPhase: 8,
+ maxPhase: 16,
+ swizzleMode: 'swizzled',
+ }
+ const accessor = buildAccessor(params)
+ const start = assignBank(accessor.toOffset(0, 0), 32)
+ const end = assignBank(accessor.toOffset(accessor.numRows - 1, accessor.numCols - 1), 32)
+ expect(end.segment).toBeGreaterThan(start.segment)
+ })
+
+ it('responds to custom bank topologies', () => {
+ const params: SharedLayoutParams = {
+ tensorShape: [64, 64],
+ order: [0, 1],
+ vec: 8,
+ perPhase: 4,
+ maxPhase: 8,
+ swizzleMode: 'swizzled',
+ }
+ const accessor = buildAccessor(params)
+ const columns = Array.from({ length: 64 }, (_, idx) => idx % accessor.numCols)
+ const offsets = columns.map((col) => accessor.toOffset(12, col))
+ const sixteenBanks = offsets.map((value) =>
+ assignBank(value, 32, { bankCount: 16, bankSizeBits: 32 }).bank
+ )
+ const sixtyFourBanks = offsets.map((value) =>
+ assignBank(value, 32, { bankCount: 64, bankSizeBits: 32 }).bank
+ )
+ expect(Math.max(...sixteenBanks)).toBeLessThan(16)
+ expect(Math.max(...sixtyFourBanks)).toBeLessThan(64)
+ expect(new Set(sixtyFourBanks).size).toBeGreaterThan(new Set(sixteenBanks).size)
+ })
+
+ it('varies bank coverage with element size', () => {
+ const params: SharedLayoutParams = {
+ tensorShape: [32, 128],
+ order: [0, 1],
+ vec: 8,
+ perPhase: 4,
+ maxPhase: 16,
+ swizzleMode: 'swizzled',
+ }
+ const accessor = buildAccessor(params)
+ const row = 10
+ const columns = Array.from({ length: 32 }, (_, idx) => idx)
+ const banks8 = columns.map((col) => assignBank(accessor.toOffset(row, col), 8).bank)
+ const banks64 = columns.map((col) => assignBank(accessor.toOffset(row, col), 64).bank)
+ expect(new Set(banks64).size).toBeGreaterThan(new Set(banks8).size)
+ })
+})
diff --git a/src/layouts/SharedLayout.integration.test.ts b/src/layouts/SharedLayout.integration.test.ts
new file mode 100644
index 0000000..ab182d5
--- /dev/null
+++ b/src/layouts/SharedLayout.integration.test.ts
@@ -0,0 +1,179 @@
+import { describe, expect, it } from 'vitest'
+import { assignBank, createSharedLayout, type SharedLayoutParams } from './SharedLayout'
+
+function enumerateOffsets(params: SharedLayoutParams): number[] {
+ const { layout, rowDimName, colDimName } = createSharedLayout(params)
+ const inverse = layout.invert()
+ const numCols = params.tensorShape[params.order[0]]
+ const numRows = params.tensorShape[params.order[1]]
+ const totalElements = numCols * numRows
+ const offsets = new Set()
+
+ for (let row = 0; row < numRows; row++) {
+ for (let col = 0; col < numCols; col++) {
+ const inverseInputs: Record = {
+ [rowDimName]: row,
+ [colDimName]: col,
+ }
+ const { offset = 0 } = inverse.apply(inverseInputs)
+ expect(offset).toBeGreaterThanOrEqual(0)
+ expect(offset).toBeLessThan(totalElements)
+ expect(offsets.has(offset)).toBe(false)
+ offsets.add(offset)
+
+ const coords = layout.apply({ offset })
+ expect(coords[rowDimName]).toBe(row)
+ expect(coords[colDimName]).toBe(col)
+ }
+ }
+
+ expect(offsets.size).toBe(totalElements)
+ const sorted = Array.from(offsets).sort((a, b) => a - b)
+ expect(sorted[0]).toBe(0)
+ expect(sorted[sorted.length - 1]).toBe(totalElements - 1)
+ return sorted
+}
+
+describe('SharedLayout integration', () => {
+ describe('round-trip enumerations', () => {
+ it('round-trips every element for a 2x2 tensor', () => {
+ const params: SharedLayoutParams = {
+ tensorShape: [2, 2],
+ order: [0, 1],
+ vec: 1,
+ perPhase: 1,
+ maxPhase: 1,
+ swizzleMode: 'swizzled',
+ }
+ enumerateOffsets(params)
+ })
+
+ it('round-trips every element for a 4x4 tensor', () => {
+ const params: SharedLayoutParams = {
+ tensorShape: [4, 4],
+ order: [0, 1],
+ vec: 2,
+ perPhase: 2,
+ maxPhase: 2,
+ swizzleMode: 'swizzled',
+ }
+ enumerateOffsets(params)
+ })
+
+ it('round-trips every element for an 8x8 tensor', () => {
+ const params: SharedLayoutParams = {
+ tensorShape: [8, 8],
+ order: [0, 1],
+ vec: 4,
+ perPhase: 2,
+ maxPhase: 4,
+ swizzleMode: 'swizzled',
+ }
+ enumerateOffsets(params)
+ })
+
+ it('round-trips every element for a 16x16 tensor', () => {
+ const params: SharedLayoutParams = {
+ tensorShape: [16, 16],
+ order: [0, 1],
+ vec: 8,
+ perPhase: 4,
+ maxPhase: 8,
+ swizzleMode: 'swizzled',
+ }
+ enumerateOffsets(params)
+ })
+ })
+
+ describe('Triton parameter tuples', () => {
+ const tuples: Array<[number, number, number]> = [
+ [1, 1, 1],
+ [4, 2, 4],
+ [8, 1, 8],
+ [16, 1, 16],
+ ]
+ it.each(tuples)(
+ 'round-trips for vec=%i, perPhase=%i, maxPhase=%i',
+ (vec, perPhase, maxPhase) => {
+ const params: SharedLayoutParams = {
+ tensorShape: [16, 16],
+ order: [0, 1],
+ vec,
+ perPhase,
+ maxPhase,
+ swizzleMode: 'swizzled',
+ }
+ enumerateOffsets(params)
+ }
+ )
+ })
+
+ it('round-trips AMD rotating 64x64 layouts without collisions', () => {
+ const params: SharedLayoutParams = {
+ tensorShape: [64, 64],
+ order: [0, 1],
+ vec: 8,
+ perPhase: 4,
+ maxPhase: 8,
+ swizzleMode: 'amdRotating',
+ }
+ const offsets = enumerateOffsets(params)
+ const blockPeriod = params.perPhase * params.maxPhase
+ const columnZeroOffsets: number[] = []
+ const { layout, rowDimName, colDimName } = createSharedLayout(params)
+ const inverse = layout.invert()
+ for (let row = 0; row < params.tensorShape[params.order[1]]; row++) {
+ const inputs: Record = {
+ [rowDimName]: row,
+ [colDimName]: 0,
+ }
+ const { offset = 0 } = inverse.apply(inputs)
+ columnZeroOffsets.push(offset)
+ }
+ expect(columnZeroOffsets.slice(0, blockPeriod)).not.toEqual(
+ columnZeroOffsets.slice(blockPeriod, blockPeriod * 2)
+ )
+ expect(offsets.length).toBe(params.tensorShape[0] * params.tensorShape[1])
+ })
+
+ it('supports non-square tensors and tracks element-size bank coverage', () => {
+ const params: SharedLayoutParams = {
+ tensorShape: [32, 128],
+ order: [0, 1],
+ vec: 8,
+ perPhase: 4,
+ maxPhase: 16,
+ swizzleMode: 'swizzled',
+ }
+ const offsets = enumerateOffsets(params)
+ const elementBits = [8, 16, 32, 64]
+ const coverage = elementBits.map((bits) => {
+ const banks = new Set()
+ for (const offset of offsets) {
+ banks.add(assignBank(offset, bits).bank)
+ }
+ return { bits, banks: banks.size }
+ })
+ expect(coverage.find((entry) => entry.bits === 64)?.banks).toBeLessThan(
+ coverage.find((entry) => entry.bits === 8)?.banks ?? 0
+ )
+ })
+
+ it('keeps AMD and swizzled offsets aligned to prevent collisions', () => {
+ const swizzled: SharedLayoutParams = {
+ tensorShape: [32, 32],
+ order: [0, 1],
+ vec: 8,
+ perPhase: 4,
+ maxPhase: 8,
+ swizzleMode: 'swizzled',
+ }
+ const amd: SharedLayoutParams = {
+ ...swizzled,
+ swizzleMode: 'amdRotating',
+ }
+ const swizzledOffsets = enumerateOffsets(swizzled)
+ const amdOffsets = enumerateOffsets(amd)
+ expect(amdOffsets).toEqual(swizzledOffsets)
+ })
+})
diff --git a/src/layouts/SharedLayout.test.ts b/src/layouts/SharedLayout.test.ts
new file mode 100644
index 0000000..8418370
--- /dev/null
+++ b/src/layouts/SharedLayout.test.ts
@@ -0,0 +1,283 @@
+import { describe, expect, it } from 'vitest'
+import {
+ assignBank,
+ computeBankInfo,
+ computeRowSwizzle,
+ computeSwizzledColumn,
+ createSharedLayout,
+ type SharedLayoutParams,
+} from './SharedLayout'
+
+const baseParams: SharedLayoutParams = {
+ tensorShape: [64, 128],
+ order: [0, 1],
+ vec: 2,
+ perPhase: 2,
+ maxPhase: 4,
+ swizzleMode: 'swizzled',
+}
+
+describe('createSharedLayout', () => {
+ it('reorders outputs into canonical dim order regardless of layout order', () => {
+ const direct = createSharedLayout(baseParams)
+ expect(direct.layout.getOutDims()).toEqual([
+ ['dim0', 64],
+ ['dim1', 128],
+ ])
+ expect(direct.layout.getInDimNames()).toEqual(['offset'])
+
+ const reversed = createSharedLayout({
+ ...baseParams,
+ order: [1, 0],
+ })
+ expect(reversed.layout.getOutDims()).toEqual([
+ ['dim0', 64],
+ ['dim1', 128],
+ ])
+ expect(reversed.colDimIndex).toBe(1)
+ expect(reversed.rowDimIndex).toBe(0)
+ })
+
+ it('builds layouts for tiny tensors when every parameter is 1', () => {
+ const tinyParams: SharedLayoutParams = {
+ tensorShape: [2, 2],
+ order: [0, 1],
+ vec: 1,
+ perPhase: 1,
+ maxPhase: 1,
+ swizzleMode: 'swizzled',
+ }
+ const result = createSharedLayout(tinyParams)
+ expect(result.layout.getOutDims()).toEqual([
+ ['dim0', 2],
+ ['dim1', 2],
+ ])
+ })
+
+ it('supports large tensors where vec equals columns and perPhase equals rows', () => {
+ const largeParams: SharedLayoutParams = {
+ tensorShape: [128, 128],
+ order: [0, 1],
+ vec: 128,
+ perPhase: 128,
+ maxPhase: 128,
+ swizzleMode: 'swizzled',
+ }
+ const result = createSharedLayout(largeParams)
+ expect(result.layout.getOutDims()).toEqual([
+ ['dim0', 128],
+ ['dim1', 128],
+ ])
+ })
+})
+
+describe('row swizzle helpers', () => {
+ it('computes swizzle for standard shared layouts', () => {
+ const params = { ...baseParams, vec: 4, perPhase: 2, maxPhase: 8 } satisfies SharedLayoutParams
+ expect(computeRowSwizzle(0, params)).toBe(0)
+ expect(computeRowSwizzle(1, params)).toBe(0)
+ expect(computeRowSwizzle(2, params)).toBe(4)
+ expect(computeRowSwizzle(16, params)).toBe(0)
+ })
+
+ it('computes swizzle for AMD rotating layouts', () => {
+ const params: SharedLayoutParams = {
+ ...baseParams,
+ swizzleMode: 'amdRotating',
+ vec: 2,
+ perPhase: 2,
+ maxPhase: 4,
+ }
+ expect(computeRowSwizzle(0, params)).toBe(0)
+ expect(computeRowSwizzle(2, params)).toBe(2)
+ expect(computeRowSwizzle(4, params)).toBe(4)
+ })
+
+ it('XORs contributions when multiple row bits are set', () => {
+ const params: SharedLayoutParams = {
+ ...baseParams,
+ vec: 8,
+ perPhase: 2,
+ maxPhase: 8,
+ }
+ expect(computeRowSwizzle(6, params)).toBe(
+ computeRowSwizzle(2, params) ^ computeRowSwizzle(4, params)
+ )
+ expect(computeRowSwizzle(10, params)).toBe(
+ computeRowSwizzle(2, params) ^ computeRowSwizzle(8, params)
+ )
+ })
+
+ it('returns zero swizzle for degenerate tensor sizes', () => {
+ const singleRow: SharedLayoutParams = {
+ ...baseParams,
+ tensorShape: [64, 1],
+ }
+ expect(computeRowSwizzle(5, singleRow)).toBe(0)
+
+ const zeroCols: SharedLayoutParams = {
+ ...baseParams,
+ tensorShape: [0, 64],
+ }
+ expect(computeRowSwizzle(8, zeroCols)).toBe(0)
+ })
+
+ it('applies AMD rotating XOR when crossing block boundaries', () => {
+ const amdParams: SharedLayoutParams = {
+ tensorShape: [64, 64],
+ order: [0, 1],
+ vec: 4,
+ perPhase: 4,
+ maxPhase: 4,
+ swizzleMode: 'amdRotating',
+ }
+ const standard = { ...amdParams, swizzleMode: 'swizzled' as const }
+ expect(computeRowSwizzle(32, standard)).toBe(0)
+ expect(computeRowSwizzle(32, amdParams)).toBe(8)
+ expect(computeRowSwizzle(48, amdParams)).toBe(12)
+ })
+
+ it('behaves like an identity layout when all tunable parameters equal 1', () => {
+ const params: SharedLayoutParams = {
+ tensorShape: [1, 1],
+ order: [0, 1],
+ vec: 1,
+ perPhase: 1,
+ maxPhase: 1,
+ swizzleMode: 'swizzled',
+ }
+ expect(computeRowSwizzle(0, params)).toBe(0)
+ expect(computeSwizzledColumn(0, 3, params)).toBe(0)
+ })
+})
+
+describe('computeSwizzledColumn', () => {
+ it('matches XOR-based swizzle for logical columns', () => {
+ const params = { ...baseParams, vec: 4, perPhase: 2, maxPhase: 4 } satisfies SharedLayoutParams
+ const row = 8
+ const rowSwizzle = computeRowSwizzle(row, params)
+ const originalCol = 21
+ const swizzled = computeSwizzledColumn(row, originalCol, params)
+ expect(swizzled).toBe((originalCol ^ rowSwizzle) % params.tensorShape[params.order[0]])
+ })
+
+ it('normalizes logical columns that exceed the tensor width', () => {
+ const params = { ...baseParams, vec: 2, perPhase: 2, maxPhase: 4 } satisfies SharedLayoutParams
+ const numCols = params.tensorShape[params.order[0]]
+ const row = 5
+ const rowSwizzle = computeRowSwizzle(row, params)
+ const wrappedCol = numCols * 3 + 7
+ expect(computeSwizzledColumn(row, wrappedCol, params)).toBe(
+ (wrappedCol % numCols) ^ rowSwizzle
+ )
+ })
+
+ it('propagates negative logical columns through the XOR normalization', () => {
+ const params = { ...baseParams, vec: 2, perPhase: 2, maxPhase: 4 } satisfies SharedLayoutParams
+ const numCols = params.tensorShape[params.order[0]]
+ const row = 3
+ const rowSwizzle = computeRowSwizzle(row, params)
+ const negativeCol = -9
+ expect(computeSwizzledColumn(row, negativeCol, params)).toBe(
+ (negativeCol % numCols) ^ rowSwizzle
+ )
+ })
+
+ it('returns the original column when there are zero logical columns', () => {
+ const params: SharedLayoutParams = {
+ ...baseParams,
+ tensorShape: [0, 64],
+ }
+ expect(computeSwizzledColumn(3, 99, params)).toBe(99)
+ })
+})
+
+describe('bank calculations', () => {
+ it('reports per-bank statistics derived from tensor shape', () => {
+ const info = computeBankInfo(baseParams.tensorShape, 16)
+ expect(info.bankCount).toBe(32)
+ expect(info.bankSizeBits).toBe(32)
+ expect(info.elementsPerBankRow).toBe(2)
+ expect(info.segmentsPerBankRow).toBe(128)
+ })
+
+ it('assigns banks and segments based on offsets', () => {
+ const first = assignBank(0, 16)
+ const next = assignBank(8, 16)
+ expect(first.bank).toBe(0)
+ expect(first.segment).toBe(0)
+ expect(next.bank).toBeGreaterThanOrEqual(0)
+ expect(next.segment).toBeGreaterThanOrEqual(0)
+ })
+
+ it('scales bank information with different element bitwidths', () => {
+ const bitWidths = [8, 16, 32, 64]
+ const expectedElements: Record = {
+ 8: 4,
+ 16: 2,
+ 32: 1,
+ 64: 0.5,
+ }
+ for (const bits of bitWidths) {
+ const info = computeBankInfo([64, 64], bits)
+ expect(info.elementsPerBankRow).toBe(expectedElements[bits])
+ expect(info.segmentsPerBankRow).toBeGreaterThan(0)
+ }
+ })
+
+ it('derives per-bank capacity directly from the configured bank width', () => {
+ const largerBanks = computeBankInfo([8, 8], 16, { bankCount: 64, bankSizeBits: 64 })
+ expect(largerBanks.elementsPerBankRow).toBe(4)
+
+ const smallerBanks = computeBankInfo([8, 8], 16, { bankCount: 16, bankSizeBits: 16 })
+ expect(smallerBanks.elementsPerBankRow).toBe(1)
+
+ // Ensure bank count alone does not change the per-bank capacity
+ const moreBanksSameWidth = computeBankInfo([8, 8], 16, { bankCount: 128, bankSizeBits: 32 })
+ const fewerBanksSameWidth = computeBankInfo([8, 8], 16, { bankCount: 16, bankSizeBits: 32 })
+ expect(moreBanksSameWidth.elementsPerBankRow).toBe(fewerBanksSameWidth.elementsPerBankRow)
+ })
+
+ it('derives segments for tiny and large tensors', () => {
+ const tiny = computeBankInfo([2, 2], 32)
+ expect(tiny.segmentsPerBankRow).toBe(1)
+
+ const large = computeBankInfo([128, 128], 16)
+ expect(large.segmentsPerBankRow).toBeGreaterThan(tiny.segmentsPerBankRow)
+ })
+
+ it('assigns banks using overrides and clamps edge cases', () => {
+ const overrides = { bankCount: 16, bankSizeBits: 64 }
+ const regular = assignBank(5, 32, overrides)
+ expect(regular.bank).toBeLessThan(16)
+ expect(regular.segment).toBe(0)
+
+ const negative = assignBank(-10, 32, overrides)
+ expect(negative.bank).toBe(0)
+ expect(negative.segment).toBe(0)
+
+ const far = assignBank(4096, 32, overrides)
+ expect(far.segment).toBeGreaterThan(0)
+ })
+})
+
+describe('validation', () => {
+ it('rejects tensor shapes that are not powers of two', () => {
+ expect(() =>
+ createSharedLayout({ ...baseParams, tensorShape: [12, 128] })
+ ).toThrow(/Column count must be a power of two/)
+ expect(() =>
+ createSharedLayout({ ...baseParams, tensorShape: [64, 96] })
+ ).toThrow(/Row count must be a power of two/)
+ })
+
+ it('rejects non power-of-two parameters', () => {
+ expect(() => createSharedLayout({ ...baseParams, vec: 3 })).toThrow(/vec must be a power of two/)
+ expect(() => createSharedLayout({ ...baseParams, perPhase: 6 })).toThrow(
+ /perPhase must be a power of two/
+ )
+ expect(() => createSharedLayout({ ...baseParams, maxPhase: 10 })).toThrow(
+ /maxPhase must be a power of two/
+ )
+ })
+})
diff --git a/src/layouts/SharedLayout.ts b/src/layouts/SharedLayout.ts
new file mode 100644
index 0000000..62af498
--- /dev/null
+++ b/src/layouts/SharedLayout.ts
@@ -0,0 +1,248 @@
+import { LinearLayout } from '../core/LinearLayout'
+
+export const SHARED_BANK_COUNT = 32
+export const SHARED_BANK_WIDTH_BITS = 32
+
+export type SharedSwizzleMode = 'swizzled' | 'amdRotating'
+
+export interface SharedLayoutParams {
+ tensorShape: [number, number]
+ order: [number, number]
+ vec: number
+ perPhase: number
+ maxPhase: number
+ swizzleMode: SharedSwizzleMode
+}
+
+export interface SharedLayoutBuildResult {
+ layout: LinearLayout
+ rowDimName: string
+ colDimName: string
+ rowDimIndex: number
+ colDimIndex: number
+}
+
+export interface SharedBankInfo {
+ bankCount: number
+ bankSizeBits: number
+ elementsPerBankRow: number
+ segmentsPerBankRow: number
+}
+
+export interface BankAssignmentOptions {
+ bankCount?: number
+ bankSizeBits?: number
+}
+
+type DimIndex = 0 | 1
+
+function assertPositive(value: number, context: string): void {
+ if (!Number.isFinite(value) || value <= 0) {
+ throw new Error(`${context} must be a positive number`)
+ }
+}
+
+function assertPowerOfTwo(value: number, context: string): void {
+ assertPositive(value, context)
+ if (!Number.isInteger(value)) {
+ throw new Error(`${context} must be a positive integer`)
+ }
+ if ((value & (value - 1)) !== 0) {
+ throw new Error(`${context} must be a power of two`)
+ }
+}
+
+function assertDimIndex(value: number, context: string): asserts value is DimIndex {
+ if (value !== 0 && value !== 1) {
+ throw new Error(`${context} must be 0 or 1`)
+ }
+}
+
+function assertDefined(value: T | undefined, context: string): T {
+ if (value === undefined) {
+ throw new Error(`${context} is missing`)
+ }
+ return value
+}
+
+function normalizeOrder(order: [number, number]): [DimIndex, DimIndex] {
+ const [first, second] = order
+ assertDimIndex(first, 'Order[0]')
+ assertDimIndex(second, 'Order[1]')
+ if ((first === 0 && second === 1) || (first === 1 && second === 0)) {
+ return [first, second]
+ }
+ throw new Error('Order must be [0,1] or [1,0]')
+}
+
+export function createSharedLayout(params: SharedLayoutParams): SharedLayoutBuildResult {
+ const dims: [string, string] = ['dim0', 'dim1']
+ const shape = params.tensorShape
+ if (shape.length !== 2) {
+ throw new Error('Shared layout currently only supports 2D tensors')
+ }
+
+ const [colDimIndex, rowDimIndex] = normalizeOrder(params.order)
+ const numCols = assertDefined(shape[colDimIndex], 'Column count')
+ const numRows = assertDefined(shape[rowDimIndex], 'Row count')
+
+ assertPowerOfTwo(numCols, 'Column count')
+ assertPowerOfTwo(numRows, 'Row count')
+ assertPowerOfTwo(params.vec, 'vec')
+ assertPowerOfTwo(params.perPhase, 'perPhase')
+ assertPowerOfTwo(params.maxPhase, 'maxPhase')
+
+ const bases = buildSharedBases(numRows, numCols, params)
+ const rowDimName = assertDefined(dims[rowDimIndex], 'Row dimension name')
+ const colDimName = assertDefined(dims[colDimIndex], 'Column dimension name')
+
+ const layout = new LinearLayout(
+ [['offset', bases]],
+ [rowDimName, colDimName]
+ ).reorderOutputs(dims)
+
+ return {
+ layout,
+ rowDimName,
+ colDimName,
+ rowDimIndex,
+ colDimIndex,
+ }
+}
+
+function buildSharedBases(
+ numRows: number,
+ numCols: number,
+ params: SharedLayoutParams
+): number[][] {
+ const bases: number[][] = []
+
+ for (let col = 1; col < numCols; col <<= 1) {
+ bases.push([0, col])
+ }
+
+ for (let row = 1; row < numRows; row <<= 1) {
+ const swizzle = params.swizzleMode === 'amdRotating'
+ ? computeAmdRotatingContribution(row, params.vec, params.perPhase, params.maxPhase, numCols)
+ : computeSwizzledContribution(row, params.vec, params.perPhase, params.maxPhase, numCols)
+ bases.push([row, swizzle])
+ }
+
+ return bases
+}
+
+function computeSwizzledContribution(
+ rowBitValue: number,
+ vec: number,
+ perPhase: number,
+ maxPhase: number,
+ numCols: number
+): number {
+ if (perPhase <= 0 || maxPhase <= 0 || numCols <= 0) {
+ return 0
+ }
+ const phase = Math.floor(rowBitValue / perPhase) % maxPhase
+ const contribution = vec * phase
+ return contribution % numCols
+}
+
+function computeAmdRotatingContribution(
+ rowBitValue: number,
+ vec: number,
+ perPhase: number,
+ maxPhase: number,
+ numCols: number
+): number {
+ if (perPhase <= 0 || maxPhase <= 0 || numCols <= 0) {
+ return 0
+ }
+ const phase = Math.floor(rowBitValue / perPhase) % maxPhase
+ const blockNo = Math.floor(rowBitValue / perPhase / maxPhase) % maxPhase
+ const combinedPhase = phase ^ blockNo
+ const contribution = vec * combinedPhase
+ return contribution % numCols
+}
+
+export function computeRowSwizzle(row: number, params: SharedLayoutParams): number {
+ const [colDimIndex, rowDimIndex] = normalizeOrder(params.order)
+ const numRows = assertDefined(params.tensorShape[rowDimIndex], 'Row count')
+ const numCols = assertDefined(params.tensorShape[colDimIndex], 'Column count')
+ if (numRows <= 1 || numCols <= 0) {
+ return 0
+ }
+
+ let swizzle = 0
+ for (let bit = 1; bit < numRows; bit <<= 1) {
+ if ((row & bit) === 0) {
+ continue
+ }
+ const contribution = params.swizzleMode === 'amdRotating'
+ ? computeAmdRotatingContribution(bit, params.vec, params.perPhase, params.maxPhase, numCols)
+ : computeSwizzledContribution(bit, params.vec, params.perPhase, params.maxPhase, numCols)
+ swizzle ^= contribution
+ }
+
+ return swizzle % numCols
+}
+
+export function computeSwizzledColumn(
+ row: number,
+ column: number,
+ params: SharedLayoutParams
+): number {
+ const [colDimIndex] = normalizeOrder(params.order)
+ const numCols = assertDefined(params.tensorShape[colDimIndex], 'Column count')
+ if (numCols <= 0) {
+ return column
+ }
+ const rowSwizzle = computeRowSwizzle(row, params)
+ const normalizedCol = column % numCols
+ return normalizedCol ^ rowSwizzle
+}
+
+export function computeBankInfo(
+ tensorShape: [number, number],
+ elementBits: number,
+ overrides?: BankAssignmentOptions
+): SharedBankInfo {
+ assertPositive(elementBits, 'Element bitwidth')
+ const bankCount = overrides?.bankCount ?? SHARED_BANK_COUNT
+ const bankSizeBits = overrides?.bankSizeBits ?? SHARED_BANK_WIDTH_BITS
+ assertPositive(bankCount, 'Bank count')
+ assertPositive(bankSizeBits, 'Bank size (bits)')
+
+ const bytesPerElement = Math.max(1, Math.trunc(elementBits / 8))
+ const totalBankRowBits = bankCount * bankSizeBits
+ const elementsPerBankRow = bankSizeBits / elementBits
+ const totalElements = Math.max(1, tensorShape[0] * tensorShape[1])
+ const totalBytes = totalElements * bytesPerElement
+ const bankRowBytes = Math.max(1, Math.trunc(totalBankRowBits / 8))
+ const segmentsPerBankRow = Math.max(1, Math.ceil(totalBytes / bankRowBytes))
+
+ return {
+ bankCount,
+ bankSizeBits,
+ elementsPerBankRow,
+ segmentsPerBankRow,
+ }
+}
+
+export function assignBank(
+ offset: number,
+ elementBits: number,
+ overrides?: BankAssignmentOptions
+): { bank: number; segment: number } {
+ assertPositive(elementBits, 'Element bitwidth')
+ const bankCount = overrides?.bankCount ?? SHARED_BANK_COUNT
+ const bankSizeBits = overrides?.bankSizeBits ?? SHARED_BANK_WIDTH_BITS
+ assertPositive(bankCount, 'Bank count')
+ assertPositive(bankSizeBits, 'Bank size (bits)')
+
+ const bytesPerElement = Math.max(1, Math.trunc(elementBits / 8))
+ const bankWidthBytes = Math.max(1, Math.trunc(bankSizeBits / 8))
+ const byteOffset = Math.max(0, offset) * bytesPerElement
+ const bankAddress = Math.floor(byteOffset / bankWidthBytes)
+ const bank = ((bankAddress % bankCount) + bankCount) % bankCount
+ const segment = Math.floor(bankAddress / bankCount)
+ return { bank, segment }
+}
diff --git a/src/main.tabs.test.ts b/src/main.tabs.test.ts
index b5736c0..6f7f60e 100644
--- a/src/main.tabs.test.ts
+++ b/src/main.tabs.test.ts
@@ -73,6 +73,16 @@ vi.mock('./tabs/WMMALayoutTab', () => ({
WMMALayoutTab: WMMALayoutTabMock,
}))
+const SharedLayoutTabMock = vi.fn().mockImplementation(() => ({
+ activate: vi.fn(),
+ deactivate: vi.fn(),
+ resize: vi.fn(),
+}))
+
+vi.mock('./tabs/SharedLayoutTab', () => ({
+ SharedLayoutTab: SharedLayoutTabMock,
+}))
+
const MFMALayoutTabMock = vi.fn().mockImplementation(() => ({
activate: vi.fn(),
deactivate: vi.fn(),
diff --git a/src/main.ts b/src/main.ts
index fc12206..2cec6b2 100644
--- a/src/main.ts
+++ b/src/main.ts
@@ -3,6 +3,7 @@ import { WMMALayoutTab } from './tabs/WMMALayoutTab'
import { MFMALayoutTab } from './tabs/MFMALayoutTab'
import { LinearLayoutTab } from './tabs/LinearLayoutTab'
import { layoutProjectionBus } from './integration/LayoutProjectionBus'
+import { SharedLayoutTab } from './tabs/SharedLayoutTab'
type TabController = {
activate(): void
@@ -54,9 +55,11 @@ const setActiveTab = (tabId: string): void => {
layoutProjectionBus.setTabActivationHandler(setActiveTab)
const linearLayoutTab = new LinearLayoutTab('linear-layout')
+const sharedLayoutTab = new SharedLayoutTab('shared-layout')
const blockLayoutTab = new BlockLayoutTab('block-layout')
controllers.set('block-layout', blockLayoutTab)
+controllers.set('shared-layout', sharedLayoutTab)
controllers.set('wmma-layout', new WMMALayoutTab('wmma-layout'))
controllers.set('mfma-layout', new MFMALayoutTab('mfma-layout'))
controllers.set('linear-layout', linearLayoutTab)
diff --git a/src/styles.css b/src/styles.css
index a80e1f0..9dd153d 100644
--- a/src/styles.css
+++ b/src/styles.css
@@ -430,6 +430,32 @@ button:disabled:hover {
padding-bottom: 0.75rem;
}
+.bank-info {
+ padding: 1rem;
+ background-color: #f8f9fa;
+ border-radius: 4px;
+ border: 1px solid #dee2e6;
+}
+
+.bank-info > div {
+ padding: 0.5rem 0;
+ font-size: 0.85rem;
+ line-height: 1.6;
+}
+
+.bank-info > div:not(:last-child) {
+ border-bottom: 1px solid #e9ecef;
+ padding-bottom: 0.75rem;
+}
+
+.bank-info strong {
+ font-weight: 600;
+}
+
+.bank-info span {
+ margin-left: 0.5rem;
+}
+
.validation-errors,
.validation-warnings {
margin-bottom: 1rem;
diff --git a/src/tabs/BlockLayoutTab.ts b/src/tabs/BlockLayoutTab.ts
index 3d1d23f..6c74d27 100644
--- a/src/tabs/BlockLayoutTab.ts
+++ b/src/tabs/BlockLayoutTab.ts
@@ -1,6 +1,6 @@
import type { LinearLayout } from '../core/LinearLayout'
import { createBlockLayout } from '../layouts/BlockLayout'
-import type { BlockLayoutParams } from '../validation/InputValidator'
+import { InputValidator, type BlockLayoutParams } from '../validation/InputValidator'
import { CanvasRenderer } from '../visualization/CanvasRenderer'
import { ParameterForm } from '../ui/ParameterForm'
import { renderSharedControls } from '../ui/renderSharedControls'
@@ -12,8 +12,9 @@ import type { SnapshotFilterResult } from '../core/filterSnapshotDimensions'
* updates, and hover tooltips while reusing the shared CanvasTab interactions.
*/
export class BlockLayoutTab extends CanvasTab {
- private readonly form: ParameterForm
+ private readonly form: ParameterForm
private readonly tabId: string
+ private static readonly validator = new InputValidator()
constructor(tabId: string) {
const tabContent = document.getElementById(tabId)
@@ -47,7 +48,14 @@ export class BlockLayoutTab extends CanvasTab {
super(elements)
this.tabId = tabId
- this.form = new ParameterForm('paramForm')
+ this.form = new ParameterForm({
+ formId: 'paramForm',
+ errorsContainerId: 'validation-errors',
+ warningsContainerId: 'validation-warnings',
+ readParams: readBlockLayoutParams,
+ validateParams: (params) =>
+ BlockLayoutTab.validator.validateBlockLayoutParams(params),
+ })
const linearLayoutButton = this.setupShowInLinearLayoutButton()
if (linearLayoutButton) {
this.monitorShowInLinearLayoutButton(linearLayoutButton)
@@ -169,5 +177,49 @@ export class BlockLayoutTab extends CanvasTab {
private getColorDimensionPreference(): string {
return 'warp'
}
+}
+
+const readBlockLayoutParams = (form: HTMLFormElement): BlockLayoutParams => ({
+ sizePerThread: [
+ getNumericValue(form, 'sizePerThread0'),
+ getNumericValue(form, 'sizePerThread1'),
+ ] as [number, number],
+ threadsPerWarp: [
+ getNumericValue(form, 'threadsPerWarp0'),
+ getNumericValue(form, 'threadsPerWarp1'),
+ ] as [number, number],
+ warpsPerCTA: [
+ getNumericValue(form, 'warpsPerCTA0'),
+ getNumericValue(form, 'warpsPerCTA1'),
+ ] as [number, number],
+ order: getOrderValue(form),
+ tensorShape: [
+ getNumericValue(form, 'tensorShape0'),
+ getNumericValue(form, 'tensorShape1'),
+ ] as [number, number],
+})
+
+const getNumericValue = (form: HTMLFormElement, elementId: string): number => {
+ const input = form.querySelector(`#${elementId}`)
+ if (!input) {
+ throw new Error(`BlockLayoutTab input not found: #${elementId}`)
+ }
+ const rawValue = input.value.trim()
+ if (rawValue === '') {
+ return Number.NaN
+ }
+ const numericValue = Number(rawValue)
+ return Number.isFinite(numericValue) ? numericValue : Number.NaN
+}
+const getOrderValue = (form: HTMLFormElement): [number, number] => {
+ const select = form.querySelector('#order')
+ if (!select) {
+ throw new Error('BlockLayoutTab order select not found')
+ }
+ const [dim0, dim1] = select.value.split(',').map((value) => Number(value.trim()))
+ if (!Number.isInteger(dim0) || !Number.isInteger(dim1)) {
+ throw new Error(`Invalid order value: ${select.value}`)
+ }
+ return [dim0, dim1] as [number, number]
}
diff --git a/src/tabs/SharedLayoutTab.test.ts b/src/tabs/SharedLayoutTab.test.ts
new file mode 100644
index 0000000..ae33e90
--- /dev/null
+++ b/src/tabs/SharedLayoutTab.test.ts
@@ -0,0 +1,316 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+type RendererStub = {
+ render: ReturnType
+ reset: ReturnType
+ setCustomColorProvider: ReturnType
+ screenToGrid: ReturnType
+ getCellInfo: ReturnType
+}
+
+const { rendererInstances, createRendererStub } = vi.hoisted(() => {
+ const instances: RendererStub[] = []
+ const buildStub = (): RendererStub => ({
+ render: vi.fn(),
+ reset: vi.fn(),
+ setCustomColorProvider: vi.fn(),
+ screenToGrid: vi.fn().mockReturnValue({ row: 0, col: 0 }),
+ getCellInfo: vi.fn().mockReturnValue({
+ threadId: 0,
+ registerId: 0,
+ warpId: 0,
+ position: [0, 0],
+ sourcePosition: [0, 0],
+ inputCoords: { offset: 0 },
+ outputCoords: { dim0: 0, dim1: 0 },
+ }),
+ })
+ return { rendererInstances: instances, createRendererStub: buildStub }
+})
+
+const { tooltipInstances, createTooltipStub } = vi.hoisted(() => {
+ const instances: Array<{ show: ReturnType; hide: ReturnType }> = []
+ const buildStub = () => {
+ const stub = {
+ show: vi.fn(),
+ hide: vi.fn(),
+ }
+ instances.push(stub)
+ return stub
+ }
+ return { tooltipInstances: instances, createTooltipStub: buildStub }
+})
+
+vi.mock('../visualization/CanvasRenderer', () => ({
+ CanvasRenderer: vi.fn().mockImplementation(() => {
+ const stub = createRendererStub()
+ rendererInstances.push(stub)
+ return stub
+ }),
+}))
+
+vi.mock('../ui/Tooltip', () => ({
+ Tooltip: vi.fn().mockImplementation(() => createTooltipStub()),
+}))
+
+import { SharedLayoutTab } from './SharedLayoutTab'
+
+const buildCanvas = (): HTMLCanvasElement => {
+ const canvas = document.createElement('canvas')
+ canvas.id = 'shared-canvas'
+ canvas.width = 600
+ canvas.height = 400
+ Object.defineProperty(canvas, 'getContext', {
+ value: vi.fn(() => ({
+ fillRect: vi.fn(),
+ clearRect: vi.fn(),
+ beginPath: vi.fn(),
+ moveTo: vi.fn(),
+ lineTo: vi.fn(),
+ stroke: vi.fn(),
+ fillText: vi.fn(),
+ save: vi.fn(),
+ restore: vi.fn(),
+ translate: vi.fn(),
+ scale: vi.fn(),
+ })),
+ })
+ Object.defineProperty(canvas, 'getBoundingClientRect', {
+ value: () => ({
+ width: 600,
+ height: 400,
+ left: 0,
+ top: 0,
+ right: 600,
+ bottom: 400,
+ x: 0,
+ y: 0,
+ toJSON: () => ({}),
+ }),
+ })
+ return canvas
+}
+
+const setupDom = () => {
+ document.body.innerHTML = `
+
+ `
+
+ const visualization = document.querySelector('.visualization')
+ Object.defineProperty(visualization, 'getBoundingClientRect', {
+ value: () => ({
+ width: 800,
+ height: 600,
+ left: 0,
+ top: 0,
+ right: 800,
+ bottom: 600,
+ x: 0,
+ y: 0,
+ toJSON: () => ({}),
+ }),
+ })
+ visualization?.appendChild(buildCanvas())
+}
+
+describe('SharedLayoutTab', () => {
+ beforeEach(() => {
+ vi.resetModules()
+ rendererInstances.length = 0
+ tooltipInstances.length = 0
+ document.body.innerHTML = ''
+ setupDom()
+ vi.stubGlobal('alert', vi.fn())
+ })
+
+ it('initializes the renderer with a bank color provider', () => {
+ new SharedLayoutTab('shared-layout')
+ expect(rendererInstances.length).toBe(1)
+ const instance = rendererInstances[0]
+ expect(instance?.setCustomColorProvider).toHaveBeenCalled()
+ })
+
+ it('re-renders when tensor shape changes', () => {
+ const tab = new SharedLayoutTab('shared-layout')
+ expect(tab).toBeDefined()
+ const rowsInput = document.getElementById('shared-rows') as HTMLInputElement
+ rowsInput.value = '256'
+ rowsInput.dispatchEvent(new Event('input'))
+
+ expect(rendererInstances.length).toBe(2)
+ const segmentsText = document.getElementById('shared-bank-segments')?.textContent
+ expect(segmentsText).toBe('256')
+ })
+
+ it('shows validation errors when inputs are invalid', () => {
+ const tab = new SharedLayoutTab('shared-layout')
+ expect(tab).toBeDefined()
+ const rowsInput = document.getElementById('shared-rows') as HTMLInputElement
+ rowsInput.value = '63'
+ rowsInput.dispatchEvent(new Event('input'))
+
+ const errors = document.getElementById('shared-validation-errors')
+ expect(errors?.textContent).toContain('power of two')
+ })
+
+ it('clears bank info when validation fails', () => {
+ const tab = new SharedLayoutTab('shared-layout')
+ expect(tab).toBeDefined()
+ const rowsInput = document.getElementById('shared-rows') as HTMLInputElement
+ rowsInput.value = '63'
+ rowsInput.dispatchEvent(new Event('input'))
+
+ const segmentsText = document.getElementById('shared-bank-segments')?.textContent
+ expect(segmentsText).toBe('--')
+ })
+
+ describe('hover tooltips', () => {
+ const triggerHover = (
+ tab: SharedLayoutTab,
+ cellData: { row: number; col: number; offset: number }
+ ): void => {
+ const renderer = rendererInstances[rendererInstances.length - 1]
+ renderer?.getCellInfo.mockReturnValue({
+ threadId: 0,
+ registerId: 0,
+ warpId: 0,
+ position: [cellData.row, cellData.col],
+ sourcePosition: [cellData.row, cellData.col],
+ inputCoords: { offset: cellData.offset },
+ outputCoords: { dim0: cellData.row, dim1: cellData.col },
+ })
+ const hoverable = tab as unknown as { handleHover: (event: MouseEvent) => void }
+ hoverable.handleHover(new MouseEvent('mousemove', { clientX: 200, clientY: 150 }))
+ }
+
+ const expectTooltipContains = (text: string): void => {
+ const tooltip = tooltipInstances[tooltipInstances.length - 1]
+ expect(tooltip).toBeDefined()
+ const calls = tooltip?.show.mock.calls ?? []
+ expect(calls.length).toBeGreaterThan(0)
+ const content = calls[calls.length - 1]?.[0] as string
+ expect(content).toContain(text)
+ }
+
+ it('shows correct tooltip values for the first cell', () => {
+ const tab = new SharedLayoutTab('shared-layout')
+ triggerHover(tab, { row: 0, col: 0, offset: 0 })
+
+ expectTooltipContains('Logical Index: (0, 0)')
+ expectTooltipContains('Offset: 0')
+ expectTooltipContains('Segment (Bank Row): 0')
+ expectTooltipContains('Bank: 0')
+ })
+
+ it('shows tooltip values for a middle cell with a non-zero bank', () => {
+ const tab = new SharedLayoutTab('shared-layout')
+ triggerHover(tab, { row: 12, col: 28, offset: 3 })
+
+ expectTooltipContains('Logical Index: (12, 28)')
+ expectTooltipContains('Offset: 3')
+ expectTooltipContains('Segment (Bank Row): 0')
+ expectTooltipContains('Bank: 1')
+ })
+
+ it('shows tooltip values for a high offset cell spanning multiple segments', () => {
+ const tab = new SharedLayoutTab('shared-layout')
+ triggerHover(tab, { row: 127, col: 63, offset: 515 })
+
+ expectTooltipContains('Logical Index: (127, 63)')
+ expectTooltipContains('Offset: 515')
+ expectTooltipContains('Segment (Bank Row): 8')
+ expectTooltipContains('Bank: 1')
+ })
+ })
+})
diff --git a/src/tabs/SharedLayoutTab.ts b/src/tabs/SharedLayoutTab.ts
new file mode 100644
index 0000000..8d5bd59
--- /dev/null
+++ b/src/tabs/SharedLayoutTab.ts
@@ -0,0 +1,299 @@
+import { CanvasTab, type CanvasTabElements } from './CanvasTab'
+import { renderSharedControls } from '../ui/renderSharedControls'
+import { CanvasRenderer, type CellInfo } from '../visualization/CanvasRenderer'
+import { ParameterForm } from '../ui/ParameterForm'
+import type { BlockLayoutParams } from '../validation/InputValidator'
+import {
+ assignBank,
+ computeBankInfo,
+ createSharedLayout,
+ type SharedLayoutBuildResult,
+ type SharedLayoutParams,
+} from '../layouts/SharedLayout'
+import {
+ SharedLayoutValidator,
+ type SharedLayoutUiParams,
+ type SharedLayoutViewMode,
+} from '../validation/SharedLayoutValidator'
+
+interface BankInfoElements {
+ bankCount: HTMLElement
+ bankSize: HTMLElement
+ elemsPerBank: HTMLElement
+ segmentsPerBank: HTMLElement
+}
+
+export class SharedLayoutTab extends CanvasTab {
+ private readonly form: ParameterForm
+ private readonly validator = new SharedLayoutValidator()
+ private readonly bankInfoElements: BankInfoElements
+
+ private currentParams: SharedLayoutUiParams | null = null
+ private currentLayout: SharedLayoutBuildResult | null = null
+ private currentElementBits = 16
+
+ constructor(tabId: string) {
+ const tabContent = document.getElementById(tabId)
+ if (!tabContent) {
+ throw new Error(`SharedLayoutTab container not found: ${tabId}`)
+ }
+
+ const visualizationContainer = tabContent.querySelector('.visualization')
+ if (!(visualizationContainer instanceof HTMLElement)) {
+ throw new Error('SharedLayoutTab visualization container not found')
+ }
+
+ const canvas = visualizationContainer.querySelector('canvas')
+ if (!(canvas instanceof HTMLCanvasElement)) {
+ throw new Error('SharedLayoutTab canvas element not found')
+ }
+
+ const controlsContainer = tabContent.querySelector('[data-controls]')
+ if (!(controlsContainer instanceof HTMLElement)) {
+ throw new Error('SharedLayoutTab controls container not found')
+ }
+ const resetButton = renderSharedControls(controlsContainer, { resetButtonId: 'shared-reset' })
+
+ if (!(tabContent.querySelector('#sharedForm') instanceof HTMLFormElement)) {
+ throw new Error('SharedLayoutTab form element not found')
+ }
+ if (
+ !(tabContent.querySelector('#shared-validation-errors') instanceof HTMLElement) ||
+ !(tabContent.querySelector('#shared-validation-warnings') instanceof HTMLElement)
+ ) {
+ throw new Error('SharedLayoutTab validation containers not found')
+ }
+
+ const bankCount = tabContent.querySelector('#shared-bank-count')
+ const bankSize = tabContent.querySelector('#shared-bank-size')
+ const elemsPerBank = tabContent.querySelector('#shared-elems-per-bank')
+ const segmentsPerBank = tabContent.querySelector('#shared-bank-segments')
+ if (
+ !(bankCount instanceof HTMLElement) ||
+ !(bankSize instanceof HTMLElement) ||
+ !(elemsPerBank instanceof HTMLElement) ||
+ !(segmentsPerBank instanceof HTMLElement)
+ ) {
+ throw new Error('SharedLayoutTab bank info elements missing')
+ }
+
+ const elements: CanvasTabElements = {
+ root: tabContent,
+ canvas,
+ visualizationContainer,
+ resetButton,
+ }
+
+ super(elements)
+
+ this.bankInfoElements = {
+ bankCount,
+ bankSize,
+ elemsPerBank,
+ segmentsPerBank,
+ }
+
+ this.form = new ParameterForm({
+ formId: 'sharedForm',
+ errorsContainerId: 'shared-validation-errors',
+ warningsContainerId: 'shared-validation-warnings',
+ readParams: readSharedLayoutParams,
+ validateParams: (params) => this.validator.validate(params),
+ })
+
+ this.form.onParamsChange((params) => {
+ this.refreshVisualization(params)
+ })
+ this.form.onValidationChange((isValid) => {
+ if (!isValid) {
+ this.currentParams = null
+ this.currentLayout = null
+ this.currentElementBits = 16
+ this.clearBankInfo()
+ this.hideTooltip()
+ }
+ })
+
+ const initialParams = this.form.getParams()
+ if (this.form.validate()) {
+ this.refreshVisualization(initialParams)
+ } else {
+ this.currentParams = null
+ this.currentLayout = null
+ this.currentElementBits = 16
+ this.clearBankInfo()
+ }
+ }
+
+ protected handleHover(event: MouseEvent): void {
+ const renderer = this.getRenderer()
+ if (!renderer || !this.currentParams || !this.currentLayout) {
+ this.hideTooltip()
+ return
+ }
+
+ const rect = this.canvas.getBoundingClientRect()
+ const x = event.clientX - rect.left
+ const y = event.clientY - rect.top
+ const gridPos = renderer.screenToGrid(x, y)
+ const cellInfo = renderer.getCellInfo(gridPos.row, gridPos.col)
+
+ if (!cellInfo) {
+ this.hideTooltip()
+ return
+ }
+
+ const row = this.readOutputCoordinate(
+ cellInfo,
+ this.currentLayout.rowDimName,
+ this.currentLayout.rowDimIndex
+ )
+ const col = this.readOutputCoordinate(
+ cellInfo,
+ this.currentLayout.colDimName,
+ this.currentLayout.colDimIndex
+ )
+ const offset = this.getCellOffset(cellInfo)
+ const { bank, segment } = assignBank(offset, this.currentElementBits)
+ const tooltipLines = [
+ `Logical Index: (${row}, ${col})
`,
+ `Offset: ${offset}
`,
+ `Segment (Bank Row): ${segment}
`,
+ `Bank: ${bank}
`,
+ ]
+
+ this.tooltip.show(tooltipLines.join(''), event.clientX, event.clientY)
+ }
+
+ protected resetHover(): void {
+ this.hideTooltip()
+ }
+
+ private refreshVisualization(params: SharedLayoutUiParams): void {
+ this.currentParams = params
+ this.currentElementBits = params.elementBits
+ this.updateBankInfo(params)
+
+ if (params.viewMode !== 'logical') {
+ // Bank view not implemented yet; fallback to logical rendering.
+ this.currentLayout = null
+ return
+ }
+
+ try {
+ this.updateVisualization(params)
+ } catch (error) {
+ console.error('Failed to render shared layout', error)
+ const errorMessage = error instanceof Error ? error.message : String(error)
+ alert(`Failed to render shared layout: ${errorMessage}`)
+ }
+ }
+
+ private updateVisualization(params: SharedLayoutUiParams): void {
+ const layoutResult = createSharedLayout(params)
+ const rendererParams: BlockLayoutParams = {
+ sizePerThread: [1, 1],
+ threadsPerWarp: [1, 1],
+ warpsPerCTA: [1, 1],
+ order: [0, 1],
+ tensorShape: params.tensorShape,
+ }
+
+ const renderer = new CanvasRenderer(
+ this.canvas,
+ layoutResult.layout,
+ rendererParams,
+ undefined,
+ {
+ traversalMode: 'by-output',
+ colorGrouping: 'warp',
+ showCellText: false,
+ }
+ )
+ this.setRenderer(renderer)
+ renderer.setCustomColorProvider((cell) => this.resolveBankForCell(cell, params.elementBits))
+ this.currentLayout = layoutResult
+ }
+
+ private resolveBankForCell(cell: CellInfo, elementBits: number): number {
+ const { bank } = assignBank(this.getCellOffset(cell), elementBits)
+ return bank
+ }
+
+ private getCellOffset(cellInfo: CellInfo): number {
+ const offset = cellInfo.inputCoords?.offset
+ if (typeof offset === 'number' && Number.isFinite(offset)) {
+ return offset
+ }
+ return 0
+ }
+
+ private updateBankInfo(params: SharedLayoutUiParams): void {
+ const info = computeBankInfo(params.tensorShape, params.elementBits)
+ this.bankInfoElements.bankCount.textContent = info.bankCount.toString()
+ this.bankInfoElements.bankSize.textContent = `${info.bankSizeBits} bits`
+ this.bankInfoElements.elemsPerBank.textContent = info.elementsPerBankRow.toString()
+ this.bankInfoElements.segmentsPerBank.textContent = info.segmentsPerBankRow.toString()
+ }
+
+ private clearBankInfo(): void {
+ Object.values(this.bankInfoElements).forEach((element) => {
+ element.textContent = '--'
+ })
+ }
+
+ private readOutputCoordinate(
+ cellInfo: CellInfo,
+ dimName: string,
+ fallbackIndex: number
+ ): number {
+ const fromCoords = cellInfo.outputCoords?.[dimName]
+ if (typeof fromCoords === 'number' && Number.isFinite(fromCoords)) {
+ return fromCoords
+ }
+ return cellInfo.position[fallbackIndex] ?? 0
+ }
+}
+
+const readSharedLayoutParams = (form: HTMLFormElement): SharedLayoutUiParams => ({
+ tensorShape: [
+ getNumericValue(form, '#shared-rows'),
+ getNumericValue(form, '#shared-cols'),
+ ] as [number, number],
+ order: parseOrder(form, '#shared-order'),
+ vec: getNumericValue(form, '#shared-vec'),
+ perPhase: getNumericValue(form, '#shared-per-phase'),
+ maxPhase: getNumericValue(form, '#shared-max-phase'),
+ swizzleMode: getSelectValue(form, '#shared-swizzle-mode') as SharedLayoutParams['swizzleMode'],
+ elementBits: getNumericValue(form, '#shared-element-bits'),
+ viewMode: getSelectValue(form, '#shared-view-mode') as SharedLayoutViewMode,
+})
+
+const getNumericValue = (form: HTMLFormElement, selector: string): number => {
+ const element = form.querySelector(selector)
+ if (!element) {
+ throw new Error(`SharedLayoutTab field not found: ${selector}`)
+ }
+
+ const rawValue = element.value.trim()
+ if (rawValue === '') {
+ return Number.NaN
+ }
+
+ const numericValue = Number(rawValue)
+ return Number.isFinite(numericValue) ? numericValue : Number.NaN
+}
+
+const getSelectValue = (form: HTMLFormElement, selector: string): string => {
+ const select = form.querySelector(selector)
+ if (!select) {
+ throw new Error(`SharedLayoutTab select not found: ${selector}`)
+ }
+ return select.value
+}
+
+const parseOrder = (form: HTMLFormElement, selector: string): [number, number] => {
+ const value = getSelectValue(form, selector)
+ const [first, second] = value.split(',').map((part) => Number(part.trim()))
+ return [first ?? Number.NaN, second ?? Number.NaN]
+}
diff --git a/src/ui/ParameterForm.test.ts b/src/ui/ParameterForm.test.ts
index 3ed5abb..886646e 100644
--- a/src/ui/ParameterForm.test.ts
+++ b/src/ui/ParameterForm.test.ts
@@ -1,5 +1,6 @@
import { describe, it, beforeEach, expect, vi } from 'vitest'
import { ParameterForm } from './ParameterForm'
+import { InputValidator, type BlockLayoutParams } from '../validation/InputValidator'
const buildFormDom = (): void => {
document.body.innerHTML = `
@@ -28,13 +29,64 @@ const buildFormDom = (): void => {
`
}
+const getNumberValue = (form: HTMLFormElement, id: string): number => {
+ const input = form.querySelector(`#${id}`)
+ if (!input) {
+ throw new Error(`Input not found: ${id}`)
+ }
+ const rawValue = input.value.trim()
+ if (rawValue === '') {
+ return Number.NaN
+ }
+ const numericValue = Number(rawValue)
+ return Number.isFinite(numericValue) ? numericValue : Number.NaN
+}
+
+const getOrderValue = (form: HTMLFormElement): [number, number] => {
+ const select = form.querySelector('#order')
+ if (!select) {
+ throw new Error('Order select not found')
+ }
+ const [dim0, dim1] = select.value.split(',').map(Number) as [number, number]
+ return [dim0, dim1]
+}
+
+const buildForm = (): ParameterForm => {
+ const validator = new InputValidator()
+ return new ParameterForm({
+ formId: 'paramForm',
+ errorsContainerId: 'validation-errors',
+ warningsContainerId: 'validation-warnings',
+ readParams: (form) => ({
+ sizePerThread: [
+ getNumberValue(form, 'sizePerThread0'),
+ getNumberValue(form, 'sizePerThread1'),
+ ] as [number, number],
+ threadsPerWarp: [
+ getNumberValue(form, 'threadsPerWarp0'),
+ getNumberValue(form, 'threadsPerWarp1'),
+ ] as [number, number],
+ warpsPerCTA: [
+ getNumberValue(form, 'warpsPerCTA0'),
+ getNumberValue(form, 'warpsPerCTA1'),
+ ] as [number, number],
+ order: getOrderValue(form),
+ tensorShape: [
+ getNumberValue(form, 'tensorShape0'),
+ getNumberValue(form, 'tensorShape1'),
+ ] as [number, number],
+ }),
+ validateParams: (params) => validator.validateBlockLayoutParams(params),
+ })
+}
+
describe('ParameterForm', () => {
beforeEach(() => {
buildFormDom()
})
it('shows inline errors for decimal input instead of throwing', () => {
- const form = new ParameterForm('paramForm')
+ const form = buildForm()
const sizeInput = document.getElementById('sizePerThread0') as HTMLInputElement
sizeInput.value = '2.5'
@@ -47,7 +99,7 @@ describe('ParameterForm', () => {
})
it('suppresses callbacks and flips validation state when parse errors occur', () => {
- const form = new ParameterForm('paramForm')
+ const form = buildForm()
const changeSpy = vi.fn()
const validationSpy = vi.fn()
form.onParamsChange(changeSpy)
diff --git a/src/ui/ParameterForm.ts b/src/ui/ParameterForm.ts
index 2bcacac..1ed9bed 100644
--- a/src/ui/ParameterForm.ts
+++ b/src/ui/ParameterForm.ts
@@ -1,118 +1,120 @@
-import { InputValidator, type BlockLayoutParams } from '../validation/InputValidator'
-
-export class ParameterForm {
- private form: HTMLFormElement
- private errorsDiv: HTMLElement
- private warningsDiv: HTMLElement
- private validator: InputValidator
- private validationListeners = new Set<(isValid: boolean) => void>()
+import type { ValidationResult } from '../validation/InputValidator'
+
+export interface ParameterFormConfig {
+ formId: string
+ errorsContainerId: string
+ warningsContainerId: string
+ readParams: (form: HTMLFormElement) => TParams
+ validateParams: (params: TParams) => ValidationResult
+}
+
+export class ParameterForm {
+ private readonly form: HTMLFormElement
+ private readonly errorsDiv: HTMLElement
+ private readonly warningsDiv: HTMLElement
+ private readonly readParamsFn: (form: HTMLFormElement) => TParams
+ private readonly validateParamsFn: (params: TParams) => ValidationResult
+ private readonly validationListeners = new Set<(isValid: boolean) => void>()
private lastValidationResult: boolean | null = null
- constructor(formId: string) {
- const form = document.getElementById(formId)
- if (!form || !(form instanceof HTMLFormElement)) {
- throw new Error(`Form not found: ${formId}`)
+ constructor(config: ParameterFormConfig) {
+ const form = document.getElementById(config.formId)
+ if (!(form instanceof HTMLFormElement)) {
+ throw new Error(`Form not found: ${config.formId}`)
}
this.form = form
- const errorsDiv = document.getElementById('validation-errors')
- const warningsDiv = document.getElementById('validation-warnings')
-
- if (!errorsDiv || !warningsDiv) {
+ const errorsDiv = document.getElementById(config.errorsContainerId)
+ const warningsDiv = document.getElementById(config.warningsContainerId)
+ if (!(errorsDiv instanceof HTMLElement) || !(warningsDiv instanceof HTMLElement)) {
throw new Error('Validation message divs not found')
}
this.errorsDiv = errorsDiv
this.warningsDiv = warningsDiv
- this.validator = new InputValidator()
+ this.readParamsFn = config.readParams
+ this.validateParamsFn = config.validateParams
}
- /**
- * Get current parameter values from the form
- */
- getParams(): BlockLayoutParams {
- return {
- sizePerThread: [
- this.getNumberValue('sizePerThread0'),
- this.getNumberValue('sizePerThread1'),
- ] as [number, number],
- threadsPerWarp: [
- this.getNumberValue('threadsPerWarp0'),
- this.getNumberValue('threadsPerWarp1'),
- ] as [number, number],
- warpsPerCTA: [
- this.getNumberValue('warpsPerCTA0'),
- this.getNumberValue('warpsPerCTA1'),
- ] as [number, number],
- order: this.getOrderValue(),
- tensorShape: [
- this.getNumberValue('tensorShape0'),
- this.getNumberValue('tensorShape1'),
- ] as [number, number],
- }
+ getParams(): TParams {
+ return this.readParamsFn(this.form)
}
- /**
- * Validate current parameters and display errors/warnings
- * Returns true if valid (no errors), false otherwise
- */
validate(): boolean {
const params = this.getParams()
- const result = this.validator.validateBlockLayoutParams(params)
+ const result = this.runValidation(params)
+ return result.valid
+ }
+
+ onParamsChange(callback: (params: TParams) => void): void {
+ const handleChange = (): void => {
+ const params = this.getParams()
+ const result = this.runValidation(params)
+ if (result.valid) {
+ callback(params)
+ }
+ }
+
+ const fields = this.form.querySelectorAll('input, select')
+ fields.forEach((field) => {
+ const eventName = field instanceof HTMLSelectElement ? 'change' : 'input'
+ field.addEventListener(eventName, handleChange)
+ })
+
+ this.form.addEventListener('submit', (event) => event.preventDefault())
+ }
+
+ onValidationChange(callback: (isValid: boolean) => void): void {
+ this.validationListeners.add(callback)
+
+ if (this.lastValidationResult === null) {
+ this.validate()
+ return
+ }
+
+ callback(this.lastValidationResult)
+ }
+
+ private runValidation(params: TParams): ValidationResult {
+ const result = this.validateParamsFn(params)
- // Display errors
if (result.errors.size > 0) {
- const errorList = Array.from(result.errors.values())
- this.showErrors(errorList)
+ this.showErrors(Array.from(result.errors.values()))
} else {
this.hideErrors()
}
- // Display warnings
if (result.warnings.size > 0) {
- const warningList = Array.from(result.warnings.values())
- this.showWarnings(warningList)
+ this.showWarnings(Array.from(result.warnings.values()))
} else {
this.hideWarnings()
}
this.updateValidationState(result.valid)
- return result.valid
+ return result
}
- /**
- * Show error messages
- */
private showErrors(errors: string[]): void {
this.errorsDiv.innerHTML = `
Errors:
- ${errors.map(err => `- ${err}
`).join('')}
+ ${errors.map((err) => `- ${err}
`).join('')}
`
this.errorsDiv.classList.add('visible')
}
- /**
- * Hide error messages
- */
private hideErrors(): void {
this.errorsDiv.classList.remove('visible')
this.errorsDiv.innerHTML = ''
}
- /**
- * Show warning messages
- */
private showWarnings(warnings: string[]): void {
this.warningsDiv.innerHTML = `
Warnings:
- ${warnings.map(warn => `- ${warn}
`).join('')}
+ ${warnings.map((warn) => `- ${warn}
`).join('')}
`
this.warningsDiv.classList.add('visible')
}
- /**
- * Hide warning messages
- */
private hideWarnings(): void {
this.warningsDiv.classList.remove('visible')
this.warningsDiv.innerHTML = ''
@@ -125,68 +127,4 @@ export class ParameterForm {
this.lastValidationResult = nextState
this.validationListeners.forEach((listener) => listener(nextState))
}
-
- /**
- * Invoke the callback whenever any form field changes and the inputs validate.
- */
- onParamsChange(callback: (params: BlockLayoutParams) => void): void {
- const handleChange = (): void => {
- if (this.validate()) {
- callback(this.getParams())
- }
- }
-
- const fields = this.form.querySelectorAll('input, select')
- fields.forEach((field) => {
- const eventName = field instanceof HTMLSelectElement ? 'change' : 'input'
- field.addEventListener(eventName, handleChange)
- })
-
- this.form.addEventListener('submit', (event) => event.preventDefault())
- }
-
- /**
- * Notify listeners whenever the validation state changes.
- */
- onValidationChange(callback: (isValid: boolean) => void): void {
- this.validationListeners.add(callback)
-
- if (this.lastValidationResult === null) {
- this.validate()
- return
- }
-
- callback(this.lastValidationResult)
- }
-
- /**
- * Get number value from input
- */
- private getNumberValue(id: string): number {
- const input = document.getElementById(id) as HTMLInputElement
- if (!input) {
- throw new Error(`Input not found: ${id}`)
- }
- const rawValue = input.value.trim()
- if (rawValue === '') {
- return Number.NaN
- }
- const numericValue = Number(rawValue)
- return Number.isFinite(numericValue) ? numericValue : Number.NaN
- }
-
- /**
- * Get order value from dropdown
- */
- private getOrderValue(): [number, number] {
- const select = document.getElementById('order') as HTMLSelectElement
- if (!select) {
- throw new Error('Order select not found')
- }
- const [dim0, dim1] = select.value.split(',').map(Number) as [number, number]
- if (!Number.isInteger(dim0) || !Number.isInteger(dim1)) {
- throw new Error(`Invalid order value: ${select.value}`)
- }
- return [dim0, dim1]
- }
}
diff --git a/src/validation/SharedLayoutValidator.test.ts b/src/validation/SharedLayoutValidator.test.ts
new file mode 100644
index 0000000..6366143
--- /dev/null
+++ b/src/validation/SharedLayoutValidator.test.ts
@@ -0,0 +1,124 @@
+import { describe, expect, it } from 'vitest'
+import type { SharedLayoutUiParams } from './SharedLayoutValidator'
+import { SharedLayoutValidator } from './SharedLayoutValidator'
+
+const baseParams: SharedLayoutUiParams = {
+ tensorShape: [128, 64],
+ order: [1, 0],
+ vec: 4,
+ perPhase: 2,
+ maxPhase: 4,
+ swizzleMode: 'swizzled',
+ elementBits: 16,
+ viewMode: 'logical',
+}
+
+describe('SharedLayoutValidator', () => {
+ const validator = new SharedLayoutValidator()
+
+ it('accepts the default configuration', () => {
+ const result = validator.validate(baseParams)
+ expect(result.valid).toBe(true)
+ expect(result.errors.size).toBe(0)
+ })
+
+ it('rejects non power-of-two tensor shapes', () => {
+ const result = validator.validate({
+ ...baseParams,
+ tensorShape: [63, 128],
+ })
+ expect(result.valid).toBe(false)
+ expect(result.errors.get('tensorShape')).toContain('power of two')
+ })
+
+ it('rejects unsupported element bit widths', () => {
+ const result = validator.validate({
+ ...baseParams,
+ elementBits: 12,
+ })
+ expect(result.valid).toBe(false)
+ expect(result.errors.get('elementBits')).toContain('8, 16, 32, or 64')
+ })
+
+ it('enforces valid order combinations', () => {
+ const result = validator.validate({
+ ...baseParams,
+ order: [0, 0],
+ })
+ expect(result.valid).toBe(false)
+ expect(result.errors.get('order')).toContain('[0,1] or [1,0]')
+ })
+
+ it('accepts reversed order', () => {
+ const result = validator.validate({
+ ...baseParams,
+ order: [0, 1],
+ })
+ expect(result.valid).toBe(true)
+ })
+
+ it('requires a supported swizzle mode', () => {
+ const result = validator.validate({
+ ...baseParams,
+ swizzleMode: 'unknown' as any,
+ })
+ expect(result.valid).toBe(false)
+ expect(result.errors.get('swizzleMode')).toContain('"swizzled" or "amdRotating"')
+ })
+
+ it('caps vec at the hardware limit', () => {
+ const result = validator.validate({
+ ...baseParams,
+ vec: 32,
+ })
+ expect(result.valid).toBe(false)
+ expect(result.errors.get('vec')).toContain('must not exceed 16')
+ })
+
+ it('caps perPhase at the hardware limit', () => {
+ const result = validator.validate({
+ ...baseParams,
+ perPhase: 32,
+ })
+ expect(result.valid).toBe(false)
+ expect(result.errors.get('perPhase')).toContain('must not exceed 16')
+ })
+
+ it('caps maxPhase at the hardware limit', () => {
+ const result = validator.validate({
+ ...baseParams,
+ maxPhase: 32,
+ })
+ expect(result.valid).toBe(false)
+ expect(result.errors.get('maxPhase')).toContain('must not exceed 16')
+ })
+
+ it('flags vec values that exceed the selected column dimension', () => {
+ const result = validator.validate({
+ ...baseParams,
+ tensorShape: [128, 8],
+ vec: 16,
+ })
+ expect(result.valid).toBe(false)
+ expect(result.errors.get('vec')).toContain('must not exceed the number of columns')
+ })
+
+ it('flags perPhase values larger than the row dimension', () => {
+ const result = validator.validate({
+ ...baseParams,
+ tensorShape: [8, 64],
+ perPhase: 16,
+ })
+ expect(result.valid).toBe(false)
+ expect(result.errors.get('perPhase')).toContain('must not exceed the number of rows')
+ })
+
+ it('emits a warning when bank view is requested', () => {
+ const result = validator.validate({
+ ...baseParams,
+ viewMode: 'bank',
+ })
+ expect(result.valid).toBe(true)
+ expect(result.warnings.get('viewMode')).toContain('Only logical view is implemented')
+ })
+})
diff --git a/src/validation/SharedLayoutValidator.ts b/src/validation/SharedLayoutValidator.ts
new file mode 100644
index 0000000..67c2c5e
--- /dev/null
+++ b/src/validation/SharedLayoutValidator.ts
@@ -0,0 +1,139 @@
+import type { SharedLayoutParams } from '../layouts/SharedLayout'
+import type { ValidationResult } from './InputValidator'
+
+export type SharedLayoutViewMode = 'logical' | 'bank'
+
+export interface SharedLayoutUiParams extends SharedLayoutParams {
+ elementBits: number
+ viewMode: SharedLayoutViewMode
+}
+
+const SWIZZLE_PARAM_LIMIT = 16
+
+export class SharedLayoutValidator {
+ validate(params: SharedLayoutUiParams): ValidationResult {
+ const errors = new Map()
+ const warnings = new Map()
+
+ const shapeValid = this.validateShape(params.tensorShape, errors)
+ const orderValid = this.validateOrder(params.order, errors)
+ this.validateSwizzleMode(params.swizzleMode, errors)
+ this.validateSwizzleParams(params, errors, shapeValid, orderValid)
+ this.validateElementBits(params.elementBits, errors)
+ this.validateViewMode(params.viewMode, warnings)
+
+ return {
+ valid: errors.size === 0,
+ errors,
+ warnings,
+ }
+ }
+
+ private validateShape(shape: [number, number], errors: Map): boolean {
+ let valid = true
+ const labels = ['Rows', 'Cols']
+ shape.forEach((value, index) => {
+ if (!Number.isInteger(value) || value <= 0) {
+ errors.set('tensorShape', `${labels[index]} must be a positive integer`)
+ valid = false
+ return
+ }
+ if (!this.isPowerOfTwo(value)) {
+ errors.set('tensorShape', `${labels[index]} must be a power of two`)
+ valid = false
+ }
+ })
+ return valid
+ }
+
+ private validateOrder(order: [number, number], errors: Map): boolean {
+ const valid =
+ (order[0] === 0 && order[1] === 1) ||
+ (order[0] === 1 && order[1] === 0)
+ if (!valid) {
+ errors.set('order', 'Order must be [0,1] or [1,0]')
+ }
+ return valid
+ }
+
+ private validateSwizzleMode(swizzleMode: string, errors: Map): void {
+ if (swizzleMode !== 'swizzled' && swizzleMode !== 'amdRotating') {
+ errors.set('swizzleMode', 'swizzleMode must be "swizzled" or "amdRotating"')
+ }
+ }
+
+ private validateSwizzleParams(
+ params: SharedLayoutParams,
+ errors: Map,
+ shapeValid: boolean,
+ orderValid: boolean
+ ): void {
+ const { vec, perPhase, maxPhase, tensorShape, order } = params
+ const validations: Array<[number, string]> = [
+ [vec, 'vec'],
+ [perPhase, 'perPhase'],
+ [maxPhase, 'maxPhase'],
+ ]
+
+ const invalidFields = new Set()
+
+ validations.forEach(([value, label]) => {
+ if (!Number.isInteger(value) || value <= 0) {
+ errors.set(label, `${label} must be a positive integer`)
+ invalidFields.add(label)
+ } else if (!this.isPowerOfTwo(value)) {
+ errors.set(label, `${label} must be a power of two`)
+ invalidFields.add(label)
+ }
+ })
+
+ const enforceHardwareLimit = (value: number, label: string): void => {
+ if (!invalidFields.has(label) && value > SWIZZLE_PARAM_LIMIT) {
+ errors.set(label, `${label} must not exceed ${SWIZZLE_PARAM_LIMIT} (AMD/NVIDIA hardware limit)`)
+ invalidFields.add(label)
+ }
+ }
+
+ enforceHardwareLimit(vec, 'vec')
+ enforceHardwareLimit(perPhase, 'perPhase')
+ enforceHardwareLimit(maxPhase, 'maxPhase')
+
+ if (!invalidFields.has('maxPhase') && maxPhase <= 0) {
+ errors.set('maxPhase', 'maxPhase must be positive')
+ invalidFields.add('maxPhase')
+ }
+
+ if (!shapeValid || !orderValid) {
+ return
+ }
+
+ const colDimIndex = order[0]
+ const rowDimIndex = order[1]
+ const colSize = tensorShape[colDimIndex] ?? 0
+ const rowSize = tensorShape[rowDimIndex] ?? 0
+
+ if (!invalidFields.has('vec') && vec > colSize) {
+ errors.set('vec', 'vec must not exceed the number of columns')
+ }
+ if (!invalidFields.has('perPhase') && perPhase > rowSize) {
+ errors.set('perPhase', 'perPhase must not exceed the number of rows')
+ }
+ }
+
+ private validateElementBits(elementBits: number, errors: Map): void {
+ const allowed = new Set([8, 16, 32, 64])
+ if (!allowed.has(elementBits)) {
+ errors.set('elementBits', 'Element bits must be 8, 16, 32, or 64')
+ }
+ }
+
+ private validateViewMode(viewMode: SharedLayoutViewMode, warnings: Map): void {
+ if (viewMode !== 'logical') {
+ warnings.set('viewMode', 'Only logical view is implemented')
+ }
+ }
+
+ private isPowerOfTwo(value: number): boolean {
+ return value > 0 && (value & (value - 1)) === 0
+ }
+}
diff --git a/src/visualization/CanvasRenderer.test.ts b/src/visualization/CanvasRenderer.test.ts
index 438e53e..a93f74f 100644
--- a/src/visualization/CanvasRenderer.test.ts
+++ b/src/visualization/CanvasRenderer.test.ts
@@ -178,6 +178,30 @@ describe('CanvasRenderer', () => {
})
})
+ describe('custom color provider', () => {
+ it('applies custom color indices when provided', () => {
+ const provider = vi.fn((cell: CellInfo) => cell.position[0])
+ const warpSpy = vi.spyOn(ColorScheme.prototype, 'getColorForWarp')
+
+ try {
+ renderer.setCustomColorProvider(provider)
+ renderer.render()
+ expect(provider).toHaveBeenCalled()
+ expect(warpSpy).toHaveBeenCalled()
+ } finally {
+ warpSpy.mockRestore()
+ }
+ })
+
+ it('falls back to default coloring when provider is cleared', () => {
+ const provider = vi.fn(() => 0)
+ renderer.setCustomColorProvider(provider)
+ renderer.setCustomColorProvider(undefined)
+ const color = renderer.getWarpColor(0)
+ expect(typeof color).toBe('string')
+ })
+ })
+
describe('viewport integration', () => {
it('should support zoom in', () => {
expect(() => renderer.zoomIn()).not.toThrow()
diff --git a/src/visualization/CanvasRenderer.ts b/src/visualization/CanvasRenderer.ts
index beab541..ba73bea 100644
--- a/src/visualization/CanvasRenderer.ts
+++ b/src/visualization/CanvasRenderer.ts
@@ -12,6 +12,7 @@ interface CanvasRendererOptions {
traversalMode?: TraversalMode
showCellText?: boolean
colorInputDimension?: string
+ customColorProvider?: (cell: CellInfo) => number | null | undefined
}
export interface CellInfo {
@@ -81,6 +82,8 @@ export class CanvasRenderer {
private colorGrouping: ColorGroupingMode
private showCellText: boolean
private customColorDimension?: string
+ private customColorProvider?: (cell: CellInfo) => number | null | undefined
+ private customColorGroupCount = 0
private maxThreadIdObserved = 0
constructor(
@@ -102,6 +105,7 @@ export class CanvasRenderer {
this.showCellText = options?.showCellText ?? true
const trimmedColorDim = options?.colorInputDimension?.trim()
this.customColorDimension = trimmedColorDim && trimmedColorDim.length > 0 ? trimmedColorDim : undefined
+ this.customColorProvider = options?.customColorProvider
this.resetColorSchemeFromParams()
// Initialize viewport controller
@@ -127,6 +131,7 @@ export class CanvasRenderer {
// Build cell data cache
this.cellDataCache = this.buildCellDataCache()
+ this.applyCustomColorProvider()
}
/**
@@ -404,6 +409,39 @@ export class CanvasRenderer {
this.colorScheme = new ColorScheme(colorCount, 1)
}
+ private applyCustomColorProvider(): void {
+ if (!this.customColorProvider) {
+ if (this.customColorGroupCount > 0) {
+ this.customColorGroupCount = 0
+ if (this.traversalMode === 'by-output') {
+ this.rebuildColorSchemeFromThreads(this.maxThreadIdObserved)
+ } else {
+ this.resetColorSchemeFromParams()
+ }
+ }
+ return
+ }
+
+ let maxGroup = 0
+ let observed = false
+ for (const entries of this.cellDataCache.values()) {
+ for (const cell of entries) {
+ const value = this.customColorProvider(cell)
+ if (typeof value === 'number' && Number.isFinite(value)) {
+ const normalized = Math.max(0, Math.trunc(value))
+ maxGroup = Math.max(maxGroup, normalized)
+ observed = true
+ }
+ }
+ }
+
+ const desiredGroups = Math.max(1, observed ? maxGroup + 1 : 1)
+ if (desiredGroups !== this.customColorGroupCount) {
+ this.customColorGroupCount = desiredGroups
+ this.colorScheme = new ColorScheme(desiredGroups, 1)
+ }
+ }
+
/**
* Render the entire visualization
*/
@@ -487,6 +525,13 @@ export class CanvasRenderer {
}
private getCellFillColor(cellInfo: CellInfo): string {
+ if (this.customColorProvider) {
+ const value = this.customColorProvider(cellInfo)
+ if (typeof value === 'number' && Number.isFinite(value)) {
+ const normalized = Math.max(0, Math.trunc(value))
+ return this.colorScheme.getColorForWarp(normalized)
+ }
+ }
if (this.customColorDimension) {
const value = cellInfo.inputCoords?.[this.customColorDimension]
if (typeof value === 'number' && Number.isFinite(value)) {
@@ -638,6 +683,8 @@ export class CanvasRenderer {
}
setColorByInputDimension(dimensionName?: string): void {
+ this.customColorProvider = undefined
+ this.customColorGroupCount = 0
const normalized = dimensionName?.trim()
const nextDimension = normalized && normalized.length > 0 ? normalized : undefined
if (this.customColorDimension === nextDimension) {
@@ -654,6 +701,17 @@ export class CanvasRenderer {
this.render()
}
+ setCustomColorProvider(provider?: (cell: CellInfo) => number | null | undefined): void {
+ this.customColorProvider = provider
+ if (provider) {
+ this.customColorDimension = undefined
+ } else {
+ this.customColorGroupCount = 0
+ }
+ this.applyCustomColorProvider()
+ this.render()
+ }
+
/**
* Zoom in
*/
@@ -709,6 +767,7 @@ export class CanvasRenderer {
}
this.cellDataCache = this.buildCellDataCache()
+ this.applyCustomColorProvider()
this.render()
}
}
diff --git a/tests/visualization.spec.ts b/tests/visualization.spec.ts
index b28a9c8..ab7f552 100644
--- a/tests/visualization.spec.ts
+++ b/tests/visualization.spec.ts
@@ -1,14 +1,42 @@
import { test, expect } from '@playwright/test'
+import type { Page } from '@playwright/test'
+
+const waitForSharedCanvas = async (page: Page): Promise => {
+ await page.waitForFunction(() => {
+ const canvas = document.querySelector('#shared-canvas')
+ return Boolean(canvas && canvas.width > 0 && canvas.height > 0)
+ })
+}
+
+const openSharedLayoutTab = async (page: Page): Promise => {
+ await page.goto('/')
+ await page.locator('#tab-shared-layout').click()
+ await expect(page.locator('#shared-layout')).toHaveClass(/active/)
+ await waitForSharedCanvas(page)
+}
+
+const captureConsoleErrors = (page: Page): (() => string[]) => {
+ const errors: string[] = []
+ page.on('console', (message) => {
+ if (message.type() === 'error') {
+ errors.push(message.text())
+ }
+ })
+ page.on('pageerror', (error) => {
+ errors.push(error.message ?? String(error))
+ })
+ return () => errors
+}
test.describe('Triton Layout Visualizer', () => {
test('should load the application', async ({ page }) => {
await page.goto('/')
// Check that the title is correct
- await expect(page).toHaveTitle(/Triton Layout Visualizer/)
+ await expect(page).toHaveTitle(/GPU Tensor Layout Visualizer/)
// Check that main heading is visible
- await expect(page.locator('h1')).toContainText('Triton Block Layout Visualizer')
+ await expect(page.locator('h1')).toContainText('GPU Tensor Layout Visualizer')
})
test('should render the canvas', async ({ page }) => {
@@ -61,12 +89,9 @@ test.describe('Triton Layout Visualizer', () => {
// Enter invalid value (not power of 2)
await page.locator('#sizePerThread0').fill('3')
- // Submit form
- await page.locator('button[type="submit"]').click()
-
// Check that error message is displayed
+ await page.waitForSelector('#validation-errors.visible')
const errors = page.locator('#validation-errors')
- await expect(errors).toBeVisible()
await expect(errors).toContainText('power')
})
@@ -81,35 +106,12 @@ test.describe('Triton Layout Visualizer', () => {
await page.locator('#tensorShape0').fill('32')
await page.locator('#tensorShape1').fill('32')
- // Submit form
- await page.locator('button[type="submit"]').click()
-
- // Wait a bit for rendering
- await page.waitForTimeout(100)
-
- // Take another screenshot
- const newScreenshot = await canvas.screenshot()
-
// Screenshots should be different
- expect(Buffer.compare(initialScreenshot, newScreenshot)).not.toBe(0)
- })
-
- test('should have zoom controls', async ({ page }) => {
- await page.goto('/')
-
- // Check that zoom buttons exist
- await expect(page.locator('#zoomIn')).toBeVisible()
- await expect(page.locator('#zoomOut')).toBeVisible()
- await expect(page.locator('#reset')).toBeVisible()
-
- // Click zoom in button
- await page.locator('#zoomIn').click()
-
- // Click zoom out button
- await page.locator('#zoomOut').click()
-
- // Click reset button
- await page.locator('#reset').click()
+ await page.waitForTimeout(100)
+ await expect.poll(async () => {
+ const newScreenshot = await canvas.screenshot()
+ return Buffer.compare(initialScreenshot, newScreenshot)
+ }).not.toBe(0)
})
test('should show tooltip when hovering WMMA cells', async ({ page }) => {
@@ -147,3 +149,111 @@ test.describe('Triton Layout Visualizer', () => {
})
})
})
+
+test.describe('Shared Layout', () => {
+ test('Tab switching activates Shared Layout without console errors', async ({ page }) => {
+ const getConsoleErrors = captureConsoleErrors(page)
+ await openSharedLayoutTab(page)
+
+ const sharedContent = page.locator('#shared-layout')
+ await expect(sharedContent).toHaveClass(/active/)
+ await expect(sharedContent).toBeVisible()
+ await expect(page.locator('#block-layout')).toBeHidden()
+ await expect(page.locator('#shared-canvas')).toBeVisible()
+
+ expect(getConsoleErrors()).toEqual([])
+ })
+
+ test('Form controls are visible with default Shared Layout values', async ({ page }) => {
+ await openSharedLayoutTab(page)
+
+ const controlSelectors = [
+ '#shared-swizzle-mode',
+ '#shared-view-mode',
+ '#shared-vec',
+ '#shared-per-phase',
+ '#shared-max-phase',
+ '#shared-order',
+ '#shared-rows',
+ '#shared-cols',
+ '#shared-element-bits',
+ ]
+
+ for (const selector of controlSelectors) {
+ await expect(page.locator(selector)).toBeVisible()
+ }
+
+ await expect(page.locator('#shared-rows')).toHaveValue('128')
+ await expect(page.locator('#shared-cols')).toHaveValue('64')
+ await expect(page.locator('#shared-vec')).toHaveValue('4')
+ await expect(page.locator('#shared-per-phase')).toHaveValue('2')
+ await expect(page.locator('#shared-max-phase')).toHaveValue('4')
+ await expect(page.locator('#shared-order')).toHaveValue('1,0')
+ await expect(page.locator('#shared-element-bits')).toHaveValue('16')
+ await expect(page.locator('#shared-swizzle-mode')).toHaveValue('swizzled')
+ await expect(page.locator('#shared-view-mode')).toHaveValue('logical')
+ })
+
+ test('Validation errors appear and bank info clears for invalid tensors', async ({ page }) => {
+ await openSharedLayoutTab(page)
+
+ await page.locator('#shared-rows').fill('3')
+
+ const errors = page.locator('#shared-validation-errors')
+ await expect(errors).toHaveClass(/visible/)
+ await expect(errors).toContainText(/power/i)
+
+ const bankSelectors = [
+ '#shared-bank-count',
+ '#shared-bank-size',
+ '#shared-elems-per-bank',
+ '#shared-bank-segments',
+ ]
+
+ for (const selector of bankSelectors) {
+ await expect(page.locator(selector)).toHaveText('--')
+ }
+ })
+
+ test('Shared canvas renders and rerenders after tensor updates', async ({ page }) => {
+ await openSharedLayoutTab(page)
+
+ const canvas = page.locator('#shared-canvas')
+ await expect(canvas).toBeVisible()
+
+ const boundingBox = await canvas.boundingBox()
+ expect(boundingBox).not.toBeNull()
+ expect(boundingBox!.width).toBeGreaterThan(0)
+ expect(boundingBox!.height).toBeGreaterThan(0)
+
+ const initialScreenshot = await canvas.screenshot()
+
+ await page.locator('#shared-rows').fill('256')
+ await expect(page.locator('#shared-bank-segments')).toHaveText('256')
+
+ const updatedScreenshot = await canvas.screenshot()
+ expect(Buffer.compare(initialScreenshot, updatedScreenshot)).not.toBe(0)
+ })
+
+ test('Tooltip shows Shared Layout metadata without legacy fields', async ({ page }) => {
+ await openSharedLayoutTab(page)
+
+ const canvas = page.locator('#shared-canvas')
+ await expect(canvas).toBeVisible()
+
+ const box = await canvas.boundingBox()
+ expect(box).not.toBeNull()
+ await page.mouse.move(box!.x + box!.width / 2, box!.y + box!.height / 2)
+
+ const tooltip = page.locator('.layout-tooltip')
+ await expect(tooltip).toBeVisible()
+ await expect(tooltip).toContainText('Logical Index:')
+ await expect(tooltip).toContainText('Offset:')
+ await expect(tooltip).toContainText('Bank:')
+
+ const tooltipText = ((await tooltip.textContent()) ?? '').trim()
+ expect(tooltipText).toMatch(/Segment (?:\(Bank Row\):|:)/)
+ await expect(tooltip).not.toContainText('Original Col')
+ await expect(tooltip).not.toContainText('Swizzled Col')
+ })
+})