diff --git a/index.html b/index.html
index 078e31d..b87c54b 100644
--- a/index.html
+++ b/index.html
@@ -137,7 +137,7 @@
Segments (Bank Row):
diff --git a/src/layouts/SharedLayout.test.ts b/src/layouts/SharedLayout.test.ts
index 8418370..4bb3bdd 100644
--- a/src/layouts/SharedLayout.test.ts
+++ b/src/layouts/SharedLayout.test.ts
@@ -4,7 +4,9 @@ import {
computeBankInfo,
computeRowSwizzle,
computeSwizzledColumn,
+ createSharedBankLayout,
createSharedLayout,
+ SHARED_BANK_COUNT,
type SharedLayoutParams,
} from './SharedLayout'
@@ -193,11 +195,48 @@ describe('computeSwizzledColumn', () => {
})
describe('bank calculations', () => {
+ // Walks every element offset of the tensor and checks that the bank layout
+ // reports the same bank and segment as assignBank for the same element width
+ // and bank overrides.
+ const assertBankLayoutMatchesAssign = (
+   tensorShape: [number, number],
+   elementBits: number,
+   overrides?: { bankCount?: number; bankSizeBits?: number }
+ ): void => {
+   const { layout } = createSharedBankLayout(tensorShape, elementBits, overrides)
+   const [dim0, dim1] = tensorShape
+   const elementCount = dim0 * dim1
+   let offset = 0
+   while (offset < elementCount) {
+     const actual = layout.apply({ offset })
+     const expected = assignBank(offset, elementBits, overrides)
+     expect(actual.bank).toBe(expected.bank)
+     expect(actual.segment).toBe(expected.segment)
+     offset += 1
+   }
+ }
+
+ // 64-bit variant of the comparison: for each offset and each of the two
+ // `half` inputs, the layout must land on the linear bank address that
+ // assignBank reports for the element, advanced by `half`.
+ const assertBankLayoutMatchesAssignWithHalves = (
+   tensorShape: [number, number],
+   overrides?: { bankCount?: number; bankSizeBits?: number }
+ ): void => {
+   const elementBits = 64
+   const { layout } = createSharedBankLayout(tensorShape, elementBits, overrides)
+   const elementCount = tensorShape[0] * tensorShape[1]
+   const banks = overrides?.bankCount ?? SHARED_BANK_COUNT
+   for (let offset = 0; offset < elementCount; offset += 1) {
+     const base = assignBank(offset, elementBits, overrides)
+     // Linear bank address of the element's first bank.
+     const baseAddress = base.segment * banks + base.bank
+     for (const half of [0, 1]) {
+       const actual = layout.apply({ offset, half })
+       const address = baseAddress + half
+       const expectedBank = ((address % banks) + banks) % banks
+       const expectedSegment = Math.floor(address / banks)
+       expect(actual.bank).toBe(expectedBank)
+       expect(actual.segment).toBe(expectedSegment)
+     }
+   }
+ }
+
it('reports per-bank statistics derived from tensor shape', () => {
const info = computeBankInfo(baseParams.tensorShape, 16)
expect(info.bankCount).toBe(32)
expect(info.bankSizeBits).toBe(32)
- expect(info.elementsPerBankRow).toBe(2)
+ expect(info.elementsPerBankCell).toBe(2)
expect(info.segmentsPerBankRow).toBe(128)
})
@@ -220,22 +259,22 @@ describe('bank calculations', () => {
}
for (const bits of bitWidths) {
const info = computeBankInfo([64, 64], bits)
- expect(info.elementsPerBankRow).toBe(expectedElements[bits])
+ expect(info.elementsPerBankCell).toBe(expectedElements[bits])
expect(info.segmentsPerBankRow).toBeGreaterThan(0)
}
})
it('derives per-bank capacity directly from the configured bank width', () => {
const largerBanks = computeBankInfo([8, 8], 16, { bankCount: 64, bankSizeBits: 64 })
- expect(largerBanks.elementsPerBankRow).toBe(4)
+ expect(largerBanks.elementsPerBankCell).toBe(4)
const smallerBanks = computeBankInfo([8, 8], 16, { bankCount: 16, bankSizeBits: 16 })
- expect(smallerBanks.elementsPerBankRow).toBe(1)
+ expect(smallerBanks.elementsPerBankCell).toBe(1)
// Ensure bank count alone does not change the per-bank capacity
const moreBanksSameWidth = computeBankInfo([8, 8], 16, { bankCount: 128, bankSizeBits: 32 })
const fewerBanksSameWidth = computeBankInfo([8, 8], 16, { bankCount: 16, bankSizeBits: 32 })
- expect(moreBanksSameWidth.elementsPerBankRow).toBe(fewerBanksSameWidth.elementsPerBankRow)
+ expect(moreBanksSameWidth.elementsPerBankCell).toBe(fewerBanksSameWidth.elementsPerBankCell)
})
it('derives segments for tiny and large tensors', () => {
@@ -259,6 +298,182 @@ describe('bank calculations', () => {
const far = assignBank(4096, 32, overrides)
expect(far.segment).toBeGreaterThan(0)
})
+
+ // Spot-checks the first 32 offsets of an 8x8 tensor of 16-bit elements against assignBank.
+ it('maps offsets to bank layout coordinates for packed elements', () => {
+ const bankLayout = createSharedBankLayout([8, 8], 16)
+ for (let offset = 0; offset < 32; offset++) {
+ const coords = bankLayout.layout.apply({ offset })
+ const assignment = assignBank(offset, 16)
+ expect(coords.bank).toBe(assignment.bank)
+ expect(coords.segment).toBe(assignment.segment)
+ }
+ })
+
+ // Full-tensor comparison against assignBank with the default bank configuration.
+ it('matches assignBank for 8-bit elements', () => {
+ assertBankLayoutMatchesAssign([8, 8], 8)
+ })
+
+ it('matches assignBank for 32-bit elements', () => {
+ assertBankLayoutMatchesAssign([8, 8], 32)
+ })
+
+ // Offsets 0..3 must all resolve to bank 0 and to one shared segment.
+ it('groups four 8-bit elements per bank in the bank layout', () => {
+ const bankLayout = createSharedBankLayout([8, 8], 8)
+ const offsets = [0, 1, 2, 3]
+ const positions = offsets.map((offset) => bankLayout.layout.apply({ offset }))
+ const banks = positions.map((coords) => coords.bank)
+ const segments = positions.map((coords) => coords.segment)
+ expect(new Set(banks).size).toBe(1)
+ expect(banks[0]).toBe(0)
+ expect(new Set(segments).size).toBe(1)
+ })
+
+ // With one slot per cell and a bank span of 1 the layout must not expose a
+ // 'half' input; consecutive offsets are expected in consecutive banks.
+ it('maps 32-bit elements to single banks without packing slots', () => {
+ const bankLayout = createSharedBankLayout([8, 8], 32)
+ expect(bankLayout.slotsPerBankCell).toBe(1)
+ expect(bankLayout.bankSpan).toBe(1)
+ expect(bankLayout.layout.getInDimNames()).not.toContain('half')
+
+ const first = bankLayout.layout.apply({ offset: 0 })
+ const second = bankLayout.layout.apply({ offset: 1 })
+ expect(first.segment).toBe(0)
+ expect(first.bank).toBe(0)
+ expect(second.segment).toBe(0)
+ expect(second.bank).toBe(1)
+ })
+
+ // A 64-bit element exposes a 'half' input: half 0 and half 1 of offset 0 are
+ // expected in banks 0 and 1 of segment 0.
+ it('maps 64-bit elements to bank starts without packing rows', () => {
+ const bankLayout = createSharedBankLayout([8, 8], 64)
+ expect(bankLayout.slotsPerBankCell).toBe(1)
+ expect(bankLayout.layout.getInDimNames()).toContain('half')
+
+ const first = bankLayout.layout.apply({ offset: 0, half: 0 })
+ const second = bankLayout.layout.apply({ offset: 0, half: 1 })
+ expect(first.segment).toBe(0)
+ expect(second.segment).toBe(0)
+ expect(first.bank).toBe(0)
+ expect(second.bank).toBe(1)
+ })
+
+ it('matches assignBank for 64-bit elements across halves', () => {
+ assertBankLayoutMatchesAssignWithHalves([8, 8])
+ })
+
+ // The remaining cases repeat the comparisons with a 16-bank, 32-bit-wide override.
+ it('matches assignBank with bank overrides for 8-bit elements', () => {
+ const overrides = { bankCount: 16, bankSizeBits: 32 }
+ assertBankLayoutMatchesAssign([8, 8], 8, overrides)
+ })
+
+ it('matches assignBank with bank overrides for 32-bit elements', () => {
+ const overrides = { bankCount: 16, bankSizeBits: 32 }
+ assertBankLayoutMatchesAssign([8, 8], 32, overrides)
+ })
+
+ it('matches assignBank with bank overrides for 64-bit elements across halves', () => {
+ const overrides = { bankCount: 16, bankSizeBits: 32 }
+ assertBankLayoutMatchesAssignWithHalves([8, 8], overrides)
+ })
+})
+
+// Cross-checks the logical (swizzled) layout against the bank layout: every
+// sampled offset must map to the expected row/column, survive a round trip
+// through the inverted layout, and land on the bank/segment that assignBank
+// reports for it.
+describe('logical/bank cross-validation', () => {
+  const tensorShapes: Array<[number, number]> = [
+    [8, 8],
+    [128, 64],
+    [256, 128],
+  ]
+  const elementBitWidths = [8, 16, 32, 64]
+  // Only the swizzle-related fields vary between cases; shape and order are
+  // filled in when the full SharedLayoutParams object is assembled below.
+  const swizzleCases: Array<
+    Pick<SharedLayoutParams, 'vec' | 'perPhase' | 'maxPhase' | 'swizzleMode'>
+  > = [
+    { vec: 1, perPhase: 1, maxPhase: 1, swizzleMode: 'swizzled' },
+    { vec: 2, perPhase: 2, maxPhase: 4, swizzleMode: 'swizzled' },
+    { vec: 4, perPhase: 4, maxPhase: 8, swizzleMode: 'swizzled' },
+    { vec: 4, perPhase: 2, maxPhase: 4, swizzleMode: 'amdRotating' },
+  ]
+  const order: [number, number] = [0, 1]
+
+  // Picks a sparse sample of offsets: the first and last element, every power
+  // of two below the element count together with its two neighbours, and the
+  // one-third / two-thirds points. The result is de-duplicated, clipped to
+  // [0, totalElements) and sorted ascending.
+  const buildOffsets = (totalElements: number): number[] => {
+    const offsets = new Set<number>()
+    offsets.add(0)
+    offsets.add(totalElements - 1)
+    for (let power = 1; power < totalElements; power <<= 1) {
+      offsets.add(power)
+      offsets.add(power - 1)
+      offsets.add(power + 1)
+    }
+    if (totalElements > 4) {
+      offsets.add(Math.floor(totalElements / 3))
+      offsets.add(Math.floor((2 * totalElements) / 3))
+    }
+    return Array.from(offsets)
+      .filter((value) => value >= 0 && value < totalElements)
+      .sort((a, b) => a - b)
+  }
+
+  it('keeps logical and bank layouts consistent across offsets, shapes, and swizzle settings', () => {
+    for (const tensorShape of tensorShapes) {
+      const totalElements = tensorShape[0] * tensorShape[1]
+      const offsets = buildOffsets(totalElements)
+
+      for (const swizzle of swizzleCases) {
+        const params: SharedLayoutParams = {
+          tensorShape,
+          order,
+          vec: swizzle.vec,
+          perPhase: swizzle.perPhase,
+          maxPhase: swizzle.maxPhase,
+          swizzleMode: swizzle.swizzleMode,
+        }
+        const logicalLayout = createSharedLayout(params)
+        const inverseLayout = logicalLayout.layout.invert()
+        const rowCount = tensorShape[logicalLayout.rowDimIndex]
+        const colCount = tensorShape[logicalLayout.colDimIndex]
+
+        for (const elementBits of elementBitWidths) {
+          const bankLayout = createSharedBankLayout(tensorShape, elementBits)
+
+          for (const offset of offsets) {
+            const logicalCoords = logicalLayout.layout.apply({ offset })
+            const row = logicalCoords[logicalLayout.rowDimName]
+            const col = logicalCoords[logicalLayout.colDimName]
+
+            expect(typeof row).toBe('number')
+            expect(typeof col).toBe('number')
+
+            // Reference model: row-major position, then the column swizzle.
+            const expectedRow = Math.floor(offset / colCount)
+            const unswizzledCol = offset % colCount
+            const expectedCol = computeSwizzledColumn(expectedRow, unswizzledCol, params)
+
+            expect(row).toBe(expectedRow)
+            expect(row).toBeGreaterThanOrEqual(0)
+            expect(row).toBeLessThan(rowCount)
+            expect(col).toBe(expectedCol)
+            expect(col).toBeGreaterThanOrEqual(0)
+            expect(col).toBeLessThan(colCount)
+
+            const roundTripOffset = inverseLayout.apply(logicalCoords).offset
+            expect(roundTripOffset).toBe(offset)
+
+            const base = assignBank(roundTripOffset, elementBits)
+            if (bankLayout.bankSpan > 1) {
+              // Elements spanning several banks: each `half` must land on the
+              // element's base bank address advanced by `half`.
+              for (let half = 0; half < bankLayout.bankSpan; half++) {
+                const coords = bankLayout.layout.apply({ offset: roundTripOffset, half })
+                const baseAddress = base.segment * bankLayout.bankCount + base.bank
+                const address = baseAddress + half
+                const expectedBank = ((address % bankLayout.bankCount) + bankLayout.bankCount) % bankLayout.bankCount
+                const expectedSegment = Math.floor(address / bankLayout.bankCount)
+                expect(coords.bank).toBe(expectedBank)
+                expect(coords.segment).toBe(expectedSegment)
+              }
+            } else {
+              const coords = bankLayout.layout.apply({ offset: roundTripOffset })
+              expect(coords.bank).toBe(base.bank)
+              expect(coords.segment).toBe(base.segment)
+            }
+          }
+        }
+      }
+    }
+  })
})
describe('validation', () => {
diff --git a/src/layouts/SharedLayout.ts b/src/layouts/SharedLayout.ts
index 62af498..3710274 100644
--- a/src/layouts/SharedLayout.ts
+++ b/src/layouts/SharedLayout.ts
@@ -25,10 +25,20 @@ export interface SharedLayoutBuildResult {
export interface SharedBankInfo {
bankCount: number
bankSizeBits: number
- elementsPerBankRow: number
+ elementsPerBankCell: number
segmentsPerBankRow: number
}
+/**
+ * Value returned by `createSharedBankLayout`: the bank layout itself plus the
+ * grid dimensions it was derived from.
+ */
+export interface SharedBankLayoutResult {
+ /** Layout with an `offset` input (plus `half` when `bankSpan > 1`) and `segment` / `bank` outputs. */
+ layout: LinearLayout
+ /** Number of banks, taken from the overrides or the module default. */
+ bankCount: number
+ /** Width of one bank in bits, taken from the overrides or the module default. */
+ bankSizeBits: number
+ /** Number of bank rows needed to hold the tensor's bytes (at least 1). */
+ segmentCount: number
+ /** Elements that share one bank cell; 1 when an element is at least as wide as a bank. */
+ slotsPerBankCell: number
+ /** Size of the layout's `segment` output dimension; set equal to `segmentCount`. */
+ rowCount: number
+ /** Banks covered by one element: `elementBits / bankSizeBits` when wider than a bank, otherwise 1. */
+ bankSpan: number
+}
+
export interface BankAssignmentOptions {
bankCount?: number
bankSizeBits?: number
@@ -48,10 +58,15 @@ function assertPowerOfTwo(value: number, context: string): void {
throw new Error(`${context} must be a positive integer`)
}
if ((value & (value - 1)) !== 0) {
- throw new Error(`${context} must be a power of two`)
+ throw new Error(`${context} must be a power of two (received ${value})`)
}
}
+/**
+ * Returns the base-2 exponent of a power-of-two value.
+ *
+ * @param value - Number to take the logarithm of; validated with `assertPowerOfTwo`.
+ * @param context - Label used in the validation error message.
+ * @returns The integer `k` such that `2 ** k === value`.
+ * @throws Error (from `assertPowerOfTwo`) when `value` fails validation.
+ */
+function log2PowerOfTwo(value: number, context: string): number {
+  assertPowerOfTwo(value, context)
+  // Math.log2 is only specified as an approximation. Round to the nearest
+  // integer instead of truncating, so a result that lands a hair below the
+  // exact exponent cannot be reported as one less.
+  return Math.round(Math.log2(value))
+}
+
function assertDimIndex(value: number, context: string): asserts value is DimIndex {
if (value !== 0 && value !== 1) {
throw new Error(`${context} must be 0 or 1`)
@@ -200,6 +215,103 @@ export function computeSwizzledColumn(
return normalizedCol ^ rowSwizzle
}
+/**
+ * Builds a linear layout that maps element offsets onto the bank grid,
+ * producing `segment` and `bank` output coordinates.
+ *
+ * When an element is no wider than a bank, `slotsPerBankCell` consecutive
+ * offsets share one bank cell. When an element is wider than a bank it covers
+ * `bankSpan` consecutive banks, and the layout gains a `half` input dimension
+ * that selects the bank within that span.
+ *
+ * @param tensorShape - The two tensor extents; their product (clamped to at
+ *   least 1) must be a power of two.
+ * @param elementBits - Element width in bits; checked with `assertPositive`
+ *   and `assertPowerOfTwo`.
+ * @param overrides - Optional bank count / bank width replacing the defaults.
+ * @returns The layout together with the grid dimensions used to build it.
+ * @throws Error when a derived quantity is not a power of two, or when the
+ *   number of offset bits does not fit the bank grid.
+ */
+export function createSharedBankLayout(
+  tensorShape: [number, number],
+  elementBits: number,
+  overrides?: BankAssignmentOptions
+): SharedBankLayoutResult {
+  assertPositive(elementBits, 'Element bitwidth')
+  assertPowerOfTwo(elementBits, 'Element bitwidth')
+
+  const bankCount = overrides?.bankCount ?? SHARED_BANK_COUNT
+  const bankSizeBits = overrides?.bankSizeBits ?? SHARED_BANK_WIDTH_BITS
+  assertPowerOfTwo(bankCount, 'Bank count')
+  assertPowerOfTwo(bankSizeBits, 'Bank size (bits)')
+
+  const totalElements = Math.max(1, tensorShape[0] * tensorShape[1])
+  assertPowerOfTwo(totalElements, `Total element count (${totalElements})`)
+
+  // Widths below one byte are counted as a full byte when sizing the grid.
+  const bytesPerElement = Math.max(1, Math.trunc(elementBits / 8))
+  const bankWidthBytes = Math.max(1, Math.trunc(bankSizeBits / 8))
+  const bankRowBytes = Math.max(1, bankCount * bankWidthBytes)
+  const totalBytes = totalElements * bytesPerElement
+  const segmentCount = Math.max(1, Math.ceil(totalBytes / bankRowBytes))
+  assertPowerOfTwo(segmentCount, `Bank segment count (${segmentCount})`)
+  const rowCount = segmentCount
+
+  const elementsPerBankCell = bankSizeBits / elementBits
+  const usesPacking = elementsPerBankCell >= 1
+  const slotsPerBankCell = usesPacking ? Math.trunc(elementsPerBankCell) : 1
+  const bankSpan = usesPacking ? 1 : elementBits / bankSizeBits
+
+  // log2PowerOfTwo validates its argument, so no separate assertions are needed.
+  const slotBits = log2PowerOfTwo(slotsPerBankCell, 'Slots per bank cell')
+  const bankBits = log2PowerOfTwo(bankCount, 'Bank count')
+  const segmentBits = log2PowerOfTwo(segmentCount, 'Bank segment count')
+  const offsetBits = log2PowerOfTwo(totalElements, 'Total element count')
+  const spanBits = log2PowerOfTwo(bankSpan, 'Bank span')
+
+  if (usesPacking) {
+    if (offsetBits > slotBits + bankBits + segmentBits) {
+      throw new Error('Offset bit count exceeds packed bank grid dimensions')
+    }
+  } else if (offsetBits !== segmentBits + bankBits - spanBits) {
+    // NOTE(review): this branch demands an exact fit, while the packed branch
+    // only rejects overflow. A wide-element tensor smaller than one bank row
+    // is therefore rejected here - confirm that this is intended.
+    throw new Error('Offset bit count does not match bank grid dimensions')
+  }
+
+  // Splits a linear bank address into [segment, bank] grid coordinates.
+  const toGridCoords = (bankAddress: number): number[] => [
+    Math.floor(bankAddress / bankCount),
+    ((bankAddress % bankCount) + bankCount) % bankCount,
+  ]
+
+  // One basis vector per input bit: the grid position of offset 2^bit.
+  const offsetBases: number[][] = []
+  for (let bit = 0; bit < offsetBits; bit++) {
+    // 2 ** bit instead of 1 << bit: the shift wraps to a negative int32 at bit 31.
+    const offset = 2 ** bit
+    const bankAddress = usesPacking
+      ? Math.floor(offset / slotsPerBankCell)
+      : offset * bankSpan
+    offsetBases.push(toGridCoords(bankAddress))
+  }
+
+  const basesArray: Array<[string, number[][]]> = [['offset', offsetBases]]
+
+  if (bankSpan > 1) {
+    // Each `half` bit advances the bank address by that many banks.
+    const halfBases: number[][] = []
+    for (let bit = 0; bit < spanBits; bit++) {
+      halfBases.push(toGridCoords(2 ** bit))
+    }
+    basesArray.push(['half', halfBases])
+  }
+
+  const layout = new LinearLayout(
+    basesArray,
+    [
+      ['segment', rowCount],
+      ['bank', bankCount],
+    ]
+  )
+
+  return {
+    layout,
+    bankCount,
+    bankSizeBits,
+    segmentCount,
+    slotsPerBankCell,
+    rowCount,
+    bankSpan,
+  }
+}
+
export function computeBankInfo(
tensorShape: [number, number],
elementBits: number,
@@ -213,7 +325,7 @@ export function computeBankInfo(
const bytesPerElement = Math.max(1, Math.trunc(elementBits / 8))
const totalBankRowBits = bankCount * bankSizeBits
- const elementsPerBankRow = bankSizeBits / elementBits
+ const elementsPerBankCell = bankSizeBits / elementBits
const totalElements = Math.max(1, tensorShape[0] * tensorShape[1])
const totalBytes = totalElements * bytesPerElement
const bankRowBytes = Math.max(1, Math.trunc(totalBankRowBits / 8))
@@ -222,7 +334,7 @@ export function computeBankInfo(
return {
bankCount,
bankSizeBits,
- elementsPerBankRow,
+ elementsPerBankCell,
segmentsPerBankRow,
}
}
diff --git a/src/tabs/SharedLayoutTab.test.ts b/src/tabs/SharedLayoutTab.test.ts
index ae33e90..9fb0526 100644
--- a/src/tabs/SharedLayoutTab.test.ts
+++ b/src/tabs/SharedLayoutTab.test.ts
@@ -6,6 +6,7 @@ type RendererStub = {
setCustomColorProvider: ReturnType
screenToGrid: ReturnType
getCellInfo: ReturnType
+ getCellEntries: ReturnType
}
const { rendererInstances, createRendererStub } = vi.hoisted(() => {
@@ -24,6 +25,7 @@ const { rendererInstances, createRendererStub } = vi.hoisted(() => {
inputCoords: { offset: 0 },
outputCoords: { dim0: 0, dim1: 0 },
}),
+ getCellEntries: vi.fn().mockReturnValue(null),
})
return { rendererInstances: instances, createRendererStub: buildStub }
})
@@ -111,7 +113,7 @@ const setupDom = () => {
@@ -312,5 +314,62 @@ describe('SharedLayoutTab', () => {
expectTooltipContains('