diff --git a/index.html b/index.html
index 72247cc..b84b60f 100644
--- a/index.html
+++ b/index.html
@@ -99,7 +99,6 @@
Tensor Shape
-
diff --git a/src/main.tabs.test.ts b/src/main.tabs.test.ts
index 58e0351..b5736c0 100644
--- a/src/main.tabs.test.ts
+++ b/src/main.tabs.test.ts
@@ -11,12 +11,14 @@ const mockParams: BlockLayoutParams = {
const getParamsMock = vi.fn()
const validateMock = vi.fn()
-const onSubmitMock = vi.fn()
+const onParamsChangeMock = vi.fn()
+const onValidationChangeMock = vi.fn()
vi.mock('./ui/ParameterForm', () => ({
ParameterForm: vi.fn().mockImplementation(() => ({
getParams: getParamsMock,
validate: validateMock,
- onSubmit: onSubmitMock,
+ onParamsChange: onParamsChangeMock,
+ onValidationChange: onValidationChangeMock,
})),
}))
@@ -426,7 +428,11 @@ describe('main tab switching', () => {
validateMock.mockReset()
validateMock.mockReturnValue(true)
- onSubmitMock.mockReset()
+ onParamsChangeMock.mockReset()
+ onValidationChangeMock.mockReset()
+ onValidationChangeMock.mockImplementation((callback: (isValid: boolean) => void) => {
+ callback(true)
+ })
projectionListeners.clear()
tabActivationHandler = null
subscribeMock.mockClear()
diff --git a/src/styles.css b/src/styles.css
index 247873e..a80e1f0 100644
--- a/src/styles.css
+++ b/src/styles.css
@@ -175,6 +175,16 @@ button:active {
background-color: #21618c;
}
+button:disabled {
+ background-color: #95a5a6;
+ cursor: not-allowed;
+ opacity: 0.65;
+}
+
+button:disabled:hover {
+ background-color: #95a5a6;
+}
+
.dimension-add {
width: 100%;
padding: 0.5rem;
diff --git a/src/tabs/BlockLayoutTab.test.ts b/src/tabs/BlockLayoutTab.test.ts
index 4413312..bbe43dd 100644
--- a/src/tabs/BlockLayoutTab.test.ts
+++ b/src/tabs/BlockLayoutTab.test.ts
@@ -8,7 +8,10 @@ import { layoutProjectionBus, LINEAR_LAYOUT_TAB_ID } from '../integration/Layout
type ParameterFormStub = {
getParams: ReturnType
validate: ReturnType
- onSubmit: ReturnType
+ onParamsChange: ReturnType
+ onValidationChange: ReturnType
+ triggerParamsChange?: (params: BlockLayoutParams) => void
+ triggerValidationChange?: (isValid: boolean) => void
}
const defaultParams: BlockLayoutParams = {
@@ -33,11 +36,35 @@ const parameterFormInstances: ParameterFormStub[] = []
vi.mock('../ui/ParameterForm', () => ({
ParameterForm: vi.fn().mockImplementation(() => {
+ let lastValidationNotification: boolean | null = null
const instance: ParameterFormStub = {
getParams: vi.fn(() => currentParams),
- validate: vi.fn(() => validateResult),
- onSubmit: vi.fn(),
+ validate: vi.fn(),
+ onParamsChange: vi.fn(),
+ onValidationChange: vi.fn(),
+ triggerParamsChange: undefined,
+ triggerValidationChange: undefined,
}
+ instance.validate.mockImplementation(() => {
+ if (
+ typeof instance.triggerValidationChange === 'function' &&
+ lastValidationNotification !== validateResult
+ ) {
+ instance.triggerValidationChange(validateResult)
+ }
+ lastValidationNotification = validateResult
+ return validateResult
+ })
+ instance.onParamsChange.mockImplementation((callback: (params: BlockLayoutParams) => void) => {
+ instance.triggerParamsChange = callback
+ })
+ instance.onValidationChange.mockImplementation((callback: (isValid: boolean) => void) => {
+ instance.triggerValidationChange = (state: boolean) => {
+ lastValidationNotification = state
+ callback(state)
+ }
+ instance.triggerValidationChange(validateResult)
+ })
parameterFormInstances.push(instance)
return instance
}),
@@ -135,6 +162,46 @@ describe('BlockLayoutTab', () => {
document.body.innerHTML = ''
})
+ it('re-renders automatically when parameters change', () => {
+ instantiateTab()
+ expect(canvasRendererInstances.length).toBe(1)
+
+ const formInstance = parameterFormInstances[0]
+ expect(formInstance?.triggerParamsChange).toBeDefined()
+
+ currentParams = {
+ sizePerThread: [4, 1],
+ threadsPerWarp: [2, 16],
+ warpsPerCTA: [1, 1],
+ order: [1, 0],
+ tensorShape: [32, 8],
+ }
+
+ formInstance?.triggerParamsChange?.(currentParams)
+
+ expect(canvasRendererInstances.length).toBe(2)
+ })
+
+ it('toggles the Show in Linear Layout button based on validation state changes', () => {
+ instantiateTab()
+ const formInstance = parameterFormInstances[0]
+ const button = document.getElementById('show-linear-layout') as HTMLButtonElement
+ expect(button.disabled).toBe(false)
+
+ formInstance?.triggerValidationChange?.(false)
+ expect(button.disabled).toBe(true)
+
+ formInstance?.triggerValidationChange?.(true)
+ expect(button.disabled).toBe(false)
+ })
+
+ it('initializes with the button disabled when the form is invalid', () => {
+ validateResult = false
+ instantiateTab()
+ const button = document.getElementById('show-linear-layout') as HTMLButtonElement
+ expect(button.disabled).toBe(true)
+ })
+
it('publishes snapshots that normalize lane naming and filter size-1 dimensions', () => {
instantiateTab()
currentParams = {
diff --git a/src/tabs/BlockLayoutTab.ts b/src/tabs/BlockLayoutTab.ts
index 8fe5e76..3d1d23f 100644
--- a/src/tabs/BlockLayoutTab.ts
+++ b/src/tabs/BlockLayoutTab.ts
@@ -48,15 +48,18 @@ export class BlockLayoutTab extends CanvasTab {
this.tabId = tabId
this.form = new ParameterForm('paramForm')
+ const linearLayoutButton = this.setupShowInLinearLayoutButton()
+ if (linearLayoutButton) {
+ this.monitorShowInLinearLayoutButton(linearLayoutButton)
+ }
this.setupFormHandlers()
- this.setupShowInLinearLayoutButton()
}
/**
* Initialize form listeners and kick off the first render.
*/
private setupFormHandlers(): void {
- this.form.onSubmit((params) => {
+ this.form.onParamsChange((params) => {
this.updateVisualization(params)
})
@@ -66,7 +69,7 @@ export class BlockLayoutTab extends CanvasTab {
}
}
- private setupShowInLinearLayoutButton(): void {
+ private setupShowInLinearLayoutButton(): HTMLButtonElement | null {
this.setupLinearLayoutProjection({
buttonSelector: '#show-linear-layout',
missingButtonWarning: 'Show in Linear Layout button not found in BlockLayoutTab.',
@@ -80,6 +83,14 @@ export class BlockLayoutTab extends CanvasTab {
errorLogContext: 'Block Layout',
errorAlertMessage: 'Failed to convert Block Layout into Linear Layout. See console for details.',
})
+
+ return this.root.querySelector('#show-linear-layout')
+ }
+
+ private monitorShowInLinearLayoutButton(button: HTMLButtonElement): void {
+ this.form.onValidationChange((isValid) => {
+ button.disabled = !isValid
+ })
}
/**
diff --git a/src/ui/ParameterForm.test.ts b/src/ui/ParameterForm.test.ts
new file mode 100644
index 0000000..3ed5abb
--- /dev/null
+++ b/src/ui/ParameterForm.test.ts
@@ -0,0 +1,76 @@
+import { describe, it, beforeEach, expect, vi } from 'vitest'
+import { ParameterForm } from './ParameterForm'
+
+const buildFormDom = (): void => {
+ document.body.innerHTML = `
+
+ `
+}
+
+describe('ParameterForm', () => {
+ beforeEach(() => {
+ buildFormDom()
+ })
+
+ it('shows inline errors for decimal input instead of throwing', () => {
+ const form = new ParameterForm('paramForm')
+ const sizeInput = document.getElementById('sizePerThread0') as HTMLInputElement
+ sizeInput.value = '2.5'
+
+ const isValid = form.validate()
+ expect(isValid).toBe(false)
+
+ const errorsDiv = document.getElementById('validation-errors') as HTMLElement
+ expect(errorsDiv.classList.contains('visible')).toBe(true)
+ expect(errorsDiv.textContent).toContain('Size per thread')
+ })
+
+ it('suppresses callbacks and flips validation state when parse errors occur', () => {
+ const form = new ParameterForm('paramForm')
+ const changeSpy = vi.fn()
+ const validationSpy = vi.fn()
+ form.onParamsChange(changeSpy)
+ form.onValidationChange(validationSpy)
+ validationSpy.mockClear()
+
+ const tpwInput = document.getElementById('threadsPerWarp0') as HTMLInputElement
+ tpwInput.value = 'abc'
+
+ const invalidEvent = new Event('input', { bubbles: true })
+ tpwInput.dispatchEvent(invalidEvent)
+
+ expect(changeSpy).not.toHaveBeenCalled()
+ expect(validationSpy).toHaveBeenCalledWith(false)
+
+ const params = form.getParams()
+ expect(Number.isNaN(params.threadsPerWarp[0])).toBe(true)
+
+ validationSpy.mockClear()
+ tpwInput.value = '8'
+ tpwInput.dispatchEvent(new Event('input', { bubbles: true }))
+
+ expect(changeSpy).toHaveBeenCalledTimes(1)
+ expect(validationSpy).toHaveBeenCalledWith(true)
+ })
+})
diff --git a/src/ui/ParameterForm.ts b/src/ui/ParameterForm.ts
index 4fca92e..2bcacac 100644
--- a/src/ui/ParameterForm.ts
+++ b/src/ui/ParameterForm.ts
@@ -5,6 +5,8 @@ export class ParameterForm {
private errorsDiv: HTMLElement
private warningsDiv: HTMLElement
private validator: InputValidator
+ private validationListeners = new Set<(isValid: boolean) => void>()
+ private lastValidationResult: boolean | null = null
constructor(formId: string) {
const form = document.getElementById(formId)
@@ -74,6 +76,7 @@ export class ParameterForm {
this.hideWarnings()
}
+ this.updateValidationState(result.valid)
return result.valid
}
@@ -115,17 +118,45 @@ export class ParameterForm {
this.warningsDiv.innerHTML = ''
}
+ private updateValidationState(nextState: boolean): void {
+ if (this.lastValidationResult === nextState) {
+ return
+ }
+ this.lastValidationResult = nextState
+ this.validationListeners.forEach((listener) => listener(nextState))
+ }
+
/**
- * Add event listener for form submission
+ * Invoke the callback whenever any form field changes and the inputs validate.
*/
- onSubmit(callback: (params: BlockLayoutParams) => void): void {
- this.form.addEventListener('submit', (event) => {
- event.preventDefault()
-
+ onParamsChange(callback: (params: BlockLayoutParams) => void): void {
+ const handleChange = (): void => {
if (this.validate()) {
callback(this.getParams())
}
+ }
+
+ const fields = this.form.querySelectorAll('input, select')
+ fields.forEach((field) => {
+ const eventName = field instanceof HTMLSelectElement ? 'change' : 'input'
+ field.addEventListener(eventName, handleChange)
})
+
+ this.form.addEventListener('submit', (event) => event.preventDefault())
+ }
+
+ /**
+ * Notify listeners whenever the validation state changes.
+ */
+ onValidationChange(callback: (isValid: boolean) => void): void {
+ this.validationListeners.add(callback)
+
+ if (this.lastValidationResult === null) {
+ this.validate()
+ return
+ }
+
+ callback(this.lastValidationResult)
}
/**
@@ -136,11 +167,12 @@ export class ParameterForm {
if (!input) {
throw new Error(`Input not found: ${id}`)
}
- const value = Number(input.value)
- if (!Number.isFinite(value) || !Number.isInteger(value)) {
- throw new Error(`Invalid number value for ${id}: ${input.value}`)
+ const rawValue = input.value.trim()
+ if (rawValue === '') {
+ return Number.NaN
}
- return value
+ const numericValue = Number(rawValue)
+ return Number.isFinite(numericValue) ? numericValue : Number.NaN
}
/**
diff --git a/src/validation/InputValidator.test.ts b/src/validation/InputValidator.test.ts
index 82a8406..aeb5252 100644
--- a/src/validation/InputValidator.test.ts
+++ b/src/validation/InputValidator.test.ts
@@ -326,6 +326,49 @@ describe('InputValidator', () => {
})
})
+ describe('decimal and NaN handling', () => {
+ it('should flag decimal inputs without throwing', () => {
+ const params = {
+ sizePerThread: [2.5, 2],
+ threadsPerWarp: [8, 4],
+ warpsPerCTA: [1, 2],
+ order: [0, 1],
+ tensorShape: [16, 16],
+ } as const
+
+ expect(() => validator.validateBlockLayoutParams(params)).not.toThrow()
+ const result = validator.validateBlockLayoutParams(params)
+ expect(result.valid).toBe(false)
+ expect(result.errors.get('sizePerThread')).toContain('positive integers')
+ })
+
+ it('should treat NaN inputs as validation errors and skip derived checks', () => {
+ const nanResult = validator.validateBlockLayoutParams({
+ sizePerThread: [Number.NaN, 2],
+ threadsPerWarp: [8, 4],
+ warpsPerCTA: [1, 2],
+ order: [0, 1],
+ tensorShape: [8, 16],
+ })
+
+ expect(nanResult.valid).toBe(false)
+ expect(nanResult.errors.get('sizePerThread')).toContain('positive integers')
+ expect(nanResult.errors.has('tensorShape')).toBe(false)
+
+ const tpwResult = validator.validateBlockLayoutParams({
+ sizePerThread: [2, 2],
+ threadsPerWarp: [Number.NaN, 4],
+ warpsPerCTA: [1, 2],
+ order: [0, 1],
+ tensorShape: [16, 16],
+ })
+
+ expect(tpwResult.valid).toBe(false)
+ expect(tpwResult.errors.get('threadsPerWarp')).toContain('positive integers')
+ expect(tpwResult.errors.get('threadsPerWarp')).not.toContain('32 or 64')
+ })
+ })
+
describe('helper methods', () => {
it('isPowerOfTwo should work correctly', () => {
expect(validator['isPowerOfTwo'](1)).toBe(true)
diff --git a/src/validation/InputValidator.ts b/src/validation/InputValidator.ts
index 2b7af01..c03b368 100644
--- a/src/validation/InputValidator.ts
+++ b/src/validation/InputValidator.ts
@@ -30,7 +30,7 @@ export class InputValidator {
}
// Validate threadsPerWarp product (must be 32 or 64)
- const tpwProductError = this.validateThreadsPerWarpProduct(params.threadsPerWarp)
+ const tpwProductError = tpwError ? null : this.validateThreadsPerWarpProduct(params.threadsPerWarp)
if (tpwProductError) {
errors.set('threadsPerWarp', tpwProductError)
}
@@ -53,8 +53,9 @@ export class InputValidator {
errors.set('tensorShape', tensorShapeError)
}
- // Validate tensor shape vs coverage (only if basic tensor shape validation passed)
- if (!tensorShapeError) {
+ // Validate tensor shape vs coverage (only if all dependent values validated)
+ const canValidateCoverage = !tensorShapeError && !sptError && !tpwError && !wpcError
+ if (canValidateCoverage) {
const { error: tsError, warning: tsWarning } = this.validateTensorShapeVsCoverage(
params.sizePerThread,
params.threadsPerWarp,