diff --git a/src/styles.css b/src/styles.css index 44603a8..942d807 100644 --- a/src/styles.css +++ b/src/styles.css @@ -192,6 +192,43 @@ button:active { background-color: #95a5a6; } +.dimension-color-select { + display: flex; + flex-direction: column; + gap: 0.5rem; + margin-top: 1rem; + padding: 0.75rem; + background-color: #f8f9fa; + border: 1px solid #e9ecef; + border-radius: 4px; + font-size: 0.9rem; + color: #555; + font-weight: 500; +} + +.dimension-color-select select { + margin-top: 0.25rem; + padding: 0.5rem; + border: 1px solid #ddd; + border-radius: 4px; + font-size: 0.9rem; + background-color: white; + cursor: pointer; + transition: border-color 0.2s, box-shadow 0.2s; +} + +.dimension-color-select select:focus { + outline: none; + border-color: #3498db; + box-shadow: 0 0 0 3px rgba(52, 152, 219, 0.1); +} + +.dimension-color-select select:disabled { + background-color: #f8f9fa; + cursor: not-allowed; + opacity: 0.6; +} + .dimension-list { display: flex; flex-direction: column; diff --git a/src/tabs/LinearLayoutTab.ts b/src/tabs/LinearLayoutTab.ts index b7ca939..ee21ea3 100644 --- a/src/tabs/LinearLayoutTab.ts +++ b/src/tabs/LinearLayoutTab.ts @@ -39,7 +39,9 @@ export class LinearLayoutTab extends CanvasTab { private readonly dimensionAddButtons: Record private readonly matrixEditor: LinearLayoutMatrixEditor private readonly layoutStatus: HTMLElement + private readonly colorDimensionSelect: HTMLSelectElement private dimensionState: Record + private selectedColorDimensionId: string | null constructor(tabId: string) { const tabContent = document.getElementById(tabId) @@ -102,8 +104,9 @@ export class LinearLayoutTab extends CanvasTab { const inputAdd = this.form.querySelector('[data-add-dimension="input"]') const outputAdd = this.form.querySelector('[data-add-dimension="output"]') const matrixButton = this.form.querySelector('#linear-edit-matrix') + const colorSelect = this.form.querySelector('#linear-color-dimension') - if (!inputList || !outputList || !inputAdd || !outputAdd || !matrixButton) { + if (!inputList || !outputList || !inputAdd || !outputAdd || !matrixButton || !colorSelect) { throw new Error('LinearLayoutTab dimension controls failed to initialize') } @@ -120,6 +123,11 @@ export class LinearLayoutTab extends CanvasTab { this.layoutStatus.className = 'layout-status' this.layoutStatus.setAttribute('role', 'status') matrixButton.insertAdjacentElement('afterend', this.layoutStatus) + this.colorDimensionSelect = colorSelect + this.selectedColorDimensionId = null + this.colorDimensionSelect.addEventListener('change', () => { + this.handleColorDimensionSelectionChange() + }) this.matrixEditor.onVisibilityChange((isOpen) => { this.toggleSidebarInteractivity(isOpen) }) @@ -190,6 +198,10 @@ export class LinearLayoutTab extends CanvasTab {

Input Dimensions

+
@@ -256,6 +268,9 @@ export class LinearLayoutTab extends CanvasTab { }) this.updateAddButtonState(type) + if (type === 'input') { + this.updateColorDimensionOptions() + } this.syncEditorAndLayout() } @@ -269,6 +284,7 @@ export class LinearLayoutTab extends CanvasTab { } private handleDimensionFieldChange(): void { + this.updateColorDimensionOptions() this.syncEditorAndLayout() } @@ -312,10 +328,12 @@ export class LinearLayoutTab extends CanvasTab { private updateRendererFromLayout(): void { const tensorShape = this.deriveTensorShape() this.params.tensorShape = tensorShape + const selectedDimension = this.getSelectedColorDimensionName() const renderer = this.getRenderer() if (renderer) { renderer.updateLayout(this.layout, this.params) + renderer.setColorByInputDimension(selectedDimension) return } @@ -324,7 +342,12 @@ export class LinearLayoutTab extends CanvasTab { this.layout, this.params, undefined, - { colorGrouping: 'thread', traversalMode: 'by-output', showCellText: false } + { + colorGrouping: 'thread', + traversalMode: 'by-output', + showCellText: false, + colorInputDimension: selectedDimension, + } ) this.setRenderer(newRenderer) newRenderer.render() @@ -397,6 +420,95 @@ export class LinearLayoutTab extends CanvasTab { this.sidebar.insertBefore(infoBlock, controlsContainer) } + private updateColorDimensionOptions(): void { + const select = this.colorDimensionSelect + const dimensions = this.dimensionState.input + + select.innerHTML = '' + + if (dimensions.length === 0) { + const placeholder = document.createElement('option') + placeholder.value = '' + placeholder.textContent = 'No input dimensions' + select.appendChild(placeholder) + select.disabled = true + this.selectedColorDimensionId = null + this.applySelectedColorDimension() + return + } + + select.disabled = false + dimensions.forEach((dimension) => { + const option = document.createElement('option') + option.value = dimension.id + option.textContent = this.getDimensionOptionLabel(dimension) + select.appendChild(option) + }) + + const selectionStillValid = this.selectedColorDimensionId && + dimensions.some((dimension) => dimension.id === this.selectedColorDimensionId) + ? this.selectedColorDimensionId + : null + + const defaultSelection = this.getDefaultColorDimensionId(dimensions) + const nextSelection = selectionStillValid ?? defaultSelection ?? dimensions[0]?.id ?? null + + if (nextSelection) { + this.selectedColorDimensionId = nextSelection + select.value = nextSelection + } else { + select.selectedIndex = 0 + this.selectedColorDimensionId = null + } + + this.applySelectedColorDimension() + } + + private getDimensionOptionLabel(dimension: LinearDimension): string { + const trimmed = dimension.name.trim() + if (trimmed.length > 0) { + return trimmed + } + if (dimension.name.length > 0) { + return dimension.name + } + return 'unnamed' + } + + private getDefaultColorDimensionId(dimensions: LinearDimension[]): string | undefined { + const preferred = dimensions.find((dimension) => dimension.name.trim().toLowerCase() === 'thread') + if (preferred) { + return preferred.id + } + return dimensions[0]?.id + } + + private handleColorDimensionSelectionChange(): void { + const value = this.colorDimensionSelect.value + this.selectedColorDimensionId = value || null + this.applySelectedColorDimension() + } + + private getSelectedColorDimensionName(): string | undefined { + if (!this.selectedColorDimensionId) { + return undefined + } + const dimension = this.dimensionState.input.find((item) => item.id === this.selectedColorDimensionId) + if (!dimension) { + return undefined + } + const trimmed = dimension.name.trim() + return trimmed.length > 0 ? trimmed : undefined + } + + private applySelectedColorDimension(): void { + const renderer = this.getRenderer() + if (!renderer) { + return + } + renderer.setColorByInputDimension(this.getSelectedColorDimensionName()) + } + private getDefaultName(type: DimensionType): string { const prefix = type === 'input' ? 'input' : 'output' return `${prefix}${this.dimensionState[type].length + 1}` diff --git a/src/visualization/CanvasRenderer.test.ts b/src/visualization/CanvasRenderer.test.ts index 5c25edf..b704070 100644 --- a/src/visualization/CanvasRenderer.test.ts +++ b/src/visualization/CanvasRenderer.test.ts @@ -1,5 +1,6 @@ import { describe, it, expect, beforeEach, vi } from 'vitest' import { CanvasRenderer, type PositionResolver } from './CanvasRenderer' +import { ColorScheme } from './ColorScheme' import { createBlockLayout } from '../layouts/BlockLayout' import type { BlockLayoutParams } from '../validation/InputValidator' import { LinearLayout } from '../core/LinearLayout' @@ -114,6 +115,43 @@ describe('CanvasRenderer', () => { expect(color0).not.toBe(color1) expect(scheme.getColorForThread(0)).toBe(color0) }) + + it('should color cells by a selected input dimension when provided', () => { + const layout = LinearLayout.identity1D(4, 'thread', 'dim0').multiply( + LinearLayout.identity1D(4, 'lane', 'dim1') + ) + const params: BlockLayoutParams = { + sizePerThread: [1, 1] as [number, number], + threadsPerWarp: [4, 1] as [number, number], + warpsPerCTA: [1, 1] as [number, number], + order: [0, 1] as [number, number], + tensorShape: [4, 4] as [number, number], + } + + const warpSpy = vi.spyOn(ColorScheme.prototype, 'getColorForWarp') + const threadSpy = vi.spyOn(ColorScheme.prototype, 'getColorForThread') + + const renderer = new CanvasRenderer( + canvas, + layout, + params, + undefined, + { + traversalMode: 'by-output', + colorGrouping: 'thread', + showCellText: false, + colorInputDimension: 'lane', + } + ) + + renderer.render() + + expect(warpSpy).toHaveBeenCalled() + expect(threadSpy).not.toHaveBeenCalled() + + warpSpy.mockRestore() + threadSpy.mockRestore() + }) }) describe('viewport integration', () => { diff --git a/src/visualization/CanvasRenderer.ts b/src/visualization/CanvasRenderer.ts index dde98d6..57ee35c 100644 --- a/src/visualization/CanvasRenderer.ts +++ b/src/visualization/CanvasRenderer.ts @@ -11,6 +11,7 @@ interface CanvasRendererOptions { colorGrouping?: ColorGroupingMode traversalMode?: TraversalMode showCellText?: boolean + colorInputDimension?: string } export interface CellInfo { @@ -74,6 +75,8 @@ export class CanvasRenderer { private traversalMode: TraversalMode private colorGrouping: ColorGroupingMode private showCellText: boolean + private customColorDimension?: string + private maxThreadIdObserved = 0 constructor( private canvas: HTMLCanvasElement, @@ -92,6 +95,8 @@ export class CanvasRenderer { this.colorGrouping = options?.colorGrouping ?? 'warp' this.traversalMode = options?.traversalMode ?? 'by-thread' this.showCellText = options?.showCellText ?? true + const trimmedColorDim = options?.colorInputDimension?.trim() + this.customColorDimension = trimmedColorDim && trimmedColorDim.length > 0 ? trimmedColorDim : undefined this.resetColorSchemeFromParams() // Initialize viewport controller @@ -138,6 +143,7 @@ export class CanvasRenderer { this.params.warpsPerCTA[1] const warpSize = this.params.threadsPerWarp[0] * this.params.threadsPerWarp[1] + this.maxThreadIdObserved = Math.max(totalThreads - 1, 0) for (let threadId = 0; threadId < totalThreads; threadId++) { const positions = this.positionResolver(this.layout, threadId) @@ -252,7 +258,12 @@ export class CanvasRenderer { } traverse(0) - this.rebuildColorSchemeFromThreads(maxThreadId) + this.maxThreadIdObserved = Math.max(maxThreadId, 0) + if (this.customColorDimension) { + this.resetColorSchemeForCustomDimension() + } else { + this.rebuildColorSchemeFromThreads(maxThreadId) + } return cache } @@ -277,6 +288,16 @@ export class CanvasRenderer { } } + private resetColorSchemeForCustomDimension(): void { + const dimName = this.customColorDimension + if (!dimName) { + return + } + const hasDimension = this.layout.hasInDim(dimName) + const colorCount = hasDimension ? Math.max(this.layout.getInDimSize(dimName), 1) : 1 + this.colorScheme = new ColorScheme(colorCount, 1) + } + /** * Render the entire visualization */ @@ -347,8 +368,8 @@ export class CanvasRenderer { const y = row * this.cellSize * viewport.scale + viewport.offsetY const size = this.cellSize * viewport.scale - // Fill cell with warp color - this.ctx.fillStyle = this.colorScheme.getColorForThread(cellInfo.threadId) + // Fill cell with color determined by current grouping + this.ctx.fillStyle = this.getCellFillColor(cellInfo) this.ctx.fillRect(x, y, size, size) // Draw text if cell is large enough @@ -359,6 +380,20 @@ export class CanvasRenderer { } } + private getCellFillColor(cellInfo: CellInfo): string { + if (this.customColorDimension) { + const value = cellInfo.inputCoords?.[this.customColorDimension] + if (typeof value === 'number' && Number.isFinite(value)) { + const normalized = Math.max(0, Math.trunc(value)) + return this.colorScheme.getColorForWarp(normalized) + } + } + if (this.colorGrouping === 'thread') { + return this.colorScheme.getColorForThread(cellInfo.threadId) + } + return this.colorScheme.getColorForWarp(cellInfo.warpId) + } + /** * Draw text inside a cell */ @@ -496,6 +531,23 @@ export class CanvasRenderer { return this.colorScheme.getColorForThread(threadId) } + setColorByInputDimension(dimensionName?: string): void { + const normalized = dimensionName?.trim() + const nextDimension = normalized && normalized.length > 0 ? normalized : undefined + if (this.customColorDimension === nextDimension) { + return + } + this.customColorDimension = nextDimension + if (this.customColorDimension) { + this.resetColorSchemeForCustomDimension() + } else if (this.traversalMode === 'by-output') { + this.rebuildColorSchemeFromThreads(this.maxThreadIdObserved) + } else { + this.resetColorSchemeFromParams() + } + this.render() + } + /** * Zoom in */