Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions index.html
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ <h3>Operand</h3>

<div id="wmma-validation-errors" class="validation-errors"></div>
<div id="wmma-validation-warnings" class="validation-warnings"></div>
<button type="button" id="wmma-show-linear-layout">Show in Linear Layout</button>
</form>

<div class="controls" data-controls></div>
Expand Down Expand Up @@ -220,6 +221,7 @@ <h3>Operand</h3>

<div id="mfma-validation-errors" class="validation-errors"></div>
<div id="mfma-validation-warnings" class="validation-warnings"></div>
<button type="button" id="mfma-show-linear-layout">Show in Linear Layout</button>
</form>

<div class="controls" data-controls></div>
Expand Down
12 changes: 12 additions & 0 deletions src/main.tabs.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,12 @@ const setupDom = () => {
`
form.appendChild(operandSelect)

const showLinearButton = document.createElement('button')
showLinearButton.type = 'button'
showLinearButton.id = 'wmma-show-linear-layout'
showLinearButton.textContent = 'Show in Linear Layout'
form.appendChild(showLinearButton)

sidebar.appendChild(form)

const controls = document.createElement('div')
Expand Down Expand Up @@ -346,6 +352,12 @@ const setupDom = () => {
`
form.appendChild(operandSelect)

const showLinearButton = document.createElement('button')
showLinearButton.type = 'button'
showLinearButton.id = 'mfma-show-linear-layout'
showLinearButton.textContent = 'Show in Linear Layout'
form.appendChild(showLinearButton)

sidebar.appendChild(form)

const controls = document.createElement('div')
Expand Down
76 changes: 56 additions & 20 deletions src/tabs/BlockLayoutTab.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import { BlockLayoutTab } from './BlockLayoutTab'
import { createBlockLayout } from '../layouts/BlockLayout'
import { LinearLayout } from '../core/LinearLayout'
import { layoutProjectionBus, LINEAR_LAYOUT_TAB_ID } from '../integration/LayoutProjectionBus'
import type { SnapshotFilterResult } from '../core/filterSnapshotDimensions'

type ParameterFormStub = {
getParams: ReturnType<typeof vi.fn>
Expand Down Expand Up @@ -82,15 +81,19 @@ vi.mock('../visualization/CanvasRenderer', () => ({
}),
}))

const buildBlockLayoutDom = (): void => {
const buildBlockLayoutDom = (options: { includeButton?: boolean } = {}): void => {
const { includeButton = true } = options
const buttonMarkup = includeButton
? '<button type="button" id="show-linear-layout">Show in Linear Layout</button>'
: ''
document.body.innerHTML = `
<div id="block-layout">
<div class="visualization">
<canvas id="canvas"></canvas>
</div>
<aside class="sidebar">
<form id="paramForm">
<button type="button" id="show-linear-layout">Show in Linear Layout</button>
${buttonMarkup}
</form>
<div id="validation-errors"></div>
<div id="validation-warnings"></div>
Expand Down Expand Up @@ -134,34 +137,33 @@ describe('BlockLayoutTab', () => {
document.body.innerHTML = ''
})

it('prepares snapshots that normalize lane naming and filter size-1 dimensions', () => {
const tab = instantiateTab()
const params: BlockLayoutParams = {
it('publishes snapshots that normalize lane naming and filter size-1 dimensions', () => {
instantiateTab()
currentParams = {
sizePerThread: [1, 2],
threadsPerWarp: [1, 32],
warpsPerCTA: [1, 1],
order: [0, 1],
tensorShape: [1, 64],
}
const layout = createBlockLayout(params)
const { snapshot, removedInputDimensions } = (tab as unknown as {
prepareLinearLayoutSnapshot(
l: LinearLayout,
p: BlockLayoutParams
): SnapshotFilterResult
}).prepareLinearLayoutSnapshot(layout, params)

const inputNames = snapshot.inputDimensions.map((dim) => dim.name)

const publishSpy = vi.spyOn(layoutProjectionBus, 'publish')
document.getElementById('show-linear-layout')?.dispatchEvent(new MouseEvent('click'))

const snapshot = publishSpy.mock.calls[0]?.[0]?.snapshot
expect(snapshot).toBeDefined()
const inputNames = snapshot?.inputDimensions.map((dim) => dim.name) ?? []
expect(inputNames).toContain('lane')
expect(inputNames).not.toContain('thread')

const warpDim = snapshot.inputDimensions.find((dim) => dim.name === 'warp')
const warpDim = snapshot?.inputDimensions.find((dim) => dim.name === 'warp')
expect(warpDim).toBeUndefined()
expect(removedInputDimensions).toContain('warp')

expect(snapshot.metadata?.colorInputDimension).toBe('warp')
expect(snapshot.metadata?.description).toContain('Projection')
expect(snapshot.metadata?.tensorShape).toEqual([...params.tensorShape])
expect(snapshot?.metadata?.colorInputDimension).toBe('warp')
expect(snapshot?.metadata?.description).toContain('Projection')
expect(snapshot?.metadata?.tensorShape).toEqual([...currentParams.tensorShape])

publishSpy.mockRestore()
})

const projectionCases: Array<{ name: string; order: [number, number] }> = [
Expand Down Expand Up @@ -234,4 +236,38 @@ describe('BlockLayoutTab', () => {
publishSpy.mockRestore()
})
})

it('alerts when all dimensions collapse during projection', () => {
currentParams = {
sizePerThread: [1, 1],
threadsPerWarp: [1, 1],
warpsPerCTA: [1, 1],
order: [0, 1],
tensorShape: [1, 1],
}

const alertSpy = vi.spyOn(window, 'alert').mockImplementation(() => {})
const publishSpy = vi.spyOn(layoutProjectionBus, 'publish')

instantiateTab()
document.getElementById('show-linear-layout')?.dispatchEvent(new MouseEvent('click'))

expect(alertSpy).toHaveBeenCalledWith(
'Unable to derive a linear layout because all input and output dimensions are size 1.'
)
expect(publishSpy).not.toHaveBeenCalled()

alertSpy.mockRestore()
publishSpy.mockRestore()
})

it('warns when the Show in Linear Layout button is missing', () => {
buildBlockLayoutDom({ includeButton: false })
const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {})
expect(() => new BlockLayoutTab('block-layout')).not.toThrow()
expect(warnSpy).toHaveBeenCalledWith(
'Show in Linear Layout button not found in BlockLayoutTab.'
)
warnSpy.mockRestore()
})
})
74 changes: 15 additions & 59 deletions src/tabs/BlockLayoutTab.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import type { LayoutMatrixSnapshot, LinearLayout } from '../core/LinearLayout'
import type { LinearLayout } from '../core/LinearLayout'
import { createBlockLayout } from '../layouts/BlockLayout'
import type { BlockLayoutParams } from '../validation/InputValidator'
import { CanvasRenderer } from '../visualization/CanvasRenderer'
import { ParameterForm } from '../ui/ParameterForm'
import { renderSharedControls } from '../ui/renderSharedControls'
import { layoutProjectionBus, LINEAR_LAYOUT_TAB_ID } from '../integration/LayoutProjectionBus'
import { CanvasTab, type CanvasTabElements } from './CanvasTab'
import {
filterSnapshotDimensions,
type SnapshotFilterResult,
} from '../core/filterSnapshotDimensions'
import type { SnapshotFilterResult } from '../core/filterSnapshotDimensions'

/**
* Controller for the Block Layout tab. Handles form input, visualization
Expand Down Expand Up @@ -71,47 +67,18 @@ export class BlockLayoutTab extends CanvasTab {
}

private setupShowInLinearLayoutButton(): void {
const button = this.root.querySelector<HTMLButtonElement>('#show-linear-layout')
if (!button) {
console.warn('Show in Linear Layout button not found in BlockLayoutTab.')
return
}

button.addEventListener('click', () => {
if (!this.form.validate()) {
return
}
const params = this.form.getParams()
try {
this.setupLinearLayoutProjection({
buttonSelector: '#show-linear-layout',
missingButtonWarning: 'Show in Linear Layout button not found in BlockLayoutTab.',
sourceTabId: this.tabId,
shouldProject: () => this.form.validate(),
buildSnapshot: () => {
const params = this.form.getParams()
const layout = createBlockLayout(params)
const { snapshot, removedInputDimensions, removedOutputDimensions } =
this.prepareLinearLayoutSnapshot(layout, params)
if (snapshot.inputDimensions.length === 0 || snapshot.outputDimensions.length === 0) {
const inputCollapsed = removedInputDimensions.length > 0
const outputCollapsed = removedOutputDimensions.length > 0
let message = 'Unable to derive a linear layout from the current configuration.'
if (inputCollapsed || outputCollapsed) {
if (inputCollapsed && outputCollapsed) {
message =
'Unable to derive a linear layout because all input and output dimensions are size 1.'
} else if (inputCollapsed) {
message = 'Unable to derive a linear layout because all input dimensions are size 1.'
} else {
message = 'Unable to derive a linear layout because all output dimensions are size 1.'
}
}
alert(message)
return
}
layoutProjectionBus.publish({
sourceTabId: this.tabId,
targetTabId: LINEAR_LAYOUT_TAB_ID,
snapshot,
})
} catch (error) {
console.error('Failed to project Block Layout into Linear Layout:', error)
alert('Failed to convert Block Layout into Linear Layout. See console for details.')
}
return this.buildBlockProjectionSnapshot(layout, params)
},
errorLogContext: 'Block Layout',
errorAlertMessage: 'Failed to convert Block Layout into Linear Layout. See console for details.',
})
}

Expand Down Expand Up @@ -174,33 +141,22 @@ export class BlockLayoutTab extends CanvasTab {
this.hideTooltip()
}

private prepareLinearLayoutSnapshot(
private buildBlockProjectionSnapshot(
layout: LinearLayout,
params: BlockLayoutParams
): SnapshotFilterResult {
const snapshot = layout.toMatrixSnapshot({
return this.prepareLinearLayoutSnapshot(layout, {
sourceLayoutType: 'block-layout',
sourceTabId: this.tabId,
tensorShape: [...params.tensorShape],
generatedAt: new Date().toISOString(),
description: 'Projection from Block Layout controls',
colorInputDimension: this.getColorDimensionPreference(),
})
this.normalizeLaneDimensionName(snapshot)
return filterSnapshotDimensions(snapshot)
}

private getColorDimensionPreference(): string {
return 'warp'
}

private normalizeLaneDimensionName(snapshot: LayoutMatrixSnapshot): void {
snapshot.inputDimensions = snapshot.inputDimensions.map((dimension) => {
const trimmed = dimension.name.trim()
if (trimmed.toLowerCase() !== 'lane') {
return dimension
}
return { ...dimension, name: 'lane' }
})
}
}
Loading