From e10dce795fc53e6b3919135eb7a259b7dc3b8cda Mon Sep 17 00:00:00 2001 From: leeliu103 Date: Fri, 5 Dec 2025 17:38:47 +0000 Subject: [PATCH] Add color dimension selector for Linear Layout visualization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Users can now select which input dimension determines cell colors in the Linear Layout visualization. The dropdown automatically updates when dimensions are added, removed, or renamed, and defaults to the 'thread' dimension when available. Changes: - Add color dimension dropdown to Linear Layout tab - Implement dynamic color rendering based on selected input dimension - Add CSS styling matching existing design system - Add test coverage for color dimension functionality - Feature is isolated to Linear Layout with no impact on other layouts 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/styles.css | 37 ++++++++ src/tabs/LinearLayoutTab.ts | 116 ++++++++++++++++++++++- src/visualization/CanvasRenderer.test.ts | 38 ++++++++ src/visualization/CanvasRenderer.ts | 58 +++++++++++- 4 files changed, 244 insertions(+), 5 deletions(-) diff --git a/src/styles.css b/src/styles.css index 44603a8..942d807 100644 --- a/src/styles.css +++ b/src/styles.css @@ -192,6 +192,43 @@ button:active { background-color: #95a5a6; } +.dimension-color-select { + display: flex; + flex-direction: column; + gap: 0.5rem; + margin-top: 1rem; + padding: 0.75rem; + background-color: #f8f9fa; + border: 1px solid #e9ecef; + border-radius: 4px; + font-size: 0.9rem; + color: #555; + font-weight: 500; +} + +.dimension-color-select select { + margin-top: 0.25rem; + padding: 0.5rem; + border: 1px solid #ddd; + border-radius: 4px; + font-size: 0.9rem; + background-color: white; + cursor: pointer; + transition: border-color 0.2s, box-shadow 0.2s; +} + +.dimension-color-select select:focus { + outline: none; + border-color: #3498db; + box-shadow: 0 0 0 3px rgba(52, 152, 219, 0.1); +} + +.dimension-color-select select:disabled { + background-color: #f8f9fa; + cursor: not-allowed; + opacity: 0.6; +} + .dimension-list { display: flex; flex-direction: column; diff --git a/src/tabs/LinearLayoutTab.ts b/src/tabs/LinearLayoutTab.ts index b7ca939..ee21ea3 100644 --- a/src/tabs/LinearLayoutTab.ts +++ b/src/tabs/LinearLayoutTab.ts @@ -39,7 +39,9 @@ export class LinearLayoutTab extends CanvasTab { private readonly dimensionAddButtons: Record private readonly matrixEditor: LinearLayoutMatrixEditor private readonly layoutStatus: HTMLElement + private readonly colorDimensionSelect: HTMLSelectElement private dimensionState: Record + private selectedColorDimensionId: string | null constructor(tabId: string) { const tabContent = document.getElementById(tabId) @@ -102,8 +104,9 @@ export class LinearLayoutTab extends CanvasTab { const inputAdd = this.form.querySelector('[data-add-dimension="input"]') const outputAdd = this.form.querySelector('[data-add-dimension="output"]') const matrixButton = this.form.querySelector('#linear-edit-matrix') + const colorSelect = this.form.querySelector('#linear-color-dimension') - if (!inputList || !outputList || !inputAdd || !outputAdd || !matrixButton) { + if (!inputList || !outputList || !inputAdd || !outputAdd || !matrixButton || !colorSelect) { throw new Error('LinearLayoutTab dimension controls failed to initialize') } @@ -120,6 +123,11 @@ export class LinearLayoutTab extends CanvasTab { this.layoutStatus.className = 'layout-status' this.layoutStatus.setAttribute('role', 'status') matrixButton.insertAdjacentElement('afterend', this.layoutStatus) + this.colorDimensionSelect = colorSelect + this.selectedColorDimensionId = null + this.colorDimensionSelect.addEventListener('change', () => { + this.handleColorDimensionSelectionChange() + }) this.matrixEditor.onVisibilityChange((isOpen) => { this.toggleSidebarInteractivity(isOpen) }) @@ -190,6 +198,10 @@ export class LinearLayoutTab extends CanvasTab {

Input Dimensions

+
@@ -256,6 +268,9 @@ export class LinearLayoutTab extends CanvasTab { }) this.updateAddButtonState(type) + if (type === 'input') { + this.updateColorDimensionOptions() + } this.syncEditorAndLayout() } @@ -269,6 +284,7 @@ export class LinearLayoutTab extends CanvasTab { } private handleDimensionFieldChange(): void { + this.updateColorDimensionOptions() this.syncEditorAndLayout() } @@ -312,10 +328,12 @@ export class LinearLayoutTab extends CanvasTab { private updateRendererFromLayout(): void { const tensorShape = this.deriveTensorShape() this.params.tensorShape = tensorShape + const selectedDimension = this.getSelectedColorDimensionName() const renderer = this.getRenderer() if (renderer) { renderer.updateLayout(this.layout, this.params) + renderer.setColorByInputDimension(selectedDimension) return } @@ -324,7 +342,12 @@ export class LinearLayoutTab extends CanvasTab { this.layout, this.params, undefined, - { colorGrouping: 'thread', traversalMode: 'by-output', showCellText: false } + { + colorGrouping: 'thread', + traversalMode: 'by-output', + showCellText: false, + colorInputDimension: selectedDimension, + } ) this.setRenderer(newRenderer) newRenderer.render() @@ -397,6 +420,95 @@ export class LinearLayoutTab extends CanvasTab { this.sidebar.insertBefore(infoBlock, controlsContainer) } + private updateColorDimensionOptions(): void { + const select = this.colorDimensionSelect + const dimensions = this.dimensionState.input + + select.innerHTML = '' + + if (dimensions.length === 0) { + const placeholder = document.createElement('option') + placeholder.value = '' + placeholder.textContent = 'No input dimensions' + select.appendChild(placeholder) + select.disabled = true + this.selectedColorDimensionId = null + this.applySelectedColorDimension() + return + } + + select.disabled = false + dimensions.forEach((dimension) => { + const option = document.createElement('option') + option.value = dimension.id + option.textContent = this.getDimensionOptionLabel(dimension) + select.appendChild(option) + }) + + const selectionStillValid = this.selectedColorDimensionId && + dimensions.some((dimension) => dimension.id === this.selectedColorDimensionId) + ? this.selectedColorDimensionId + : null + + const defaultSelection = this.getDefaultColorDimensionId(dimensions) + const nextSelection = selectionStillValid ?? defaultSelection ?? dimensions[0]?.id ?? null + + if (nextSelection) { + this.selectedColorDimensionId = nextSelection + select.value = nextSelection + } else { + select.selectedIndex = 0 + this.selectedColorDimensionId = null + } + + this.applySelectedColorDimension() + } + + private getDimensionOptionLabel(dimension: LinearDimension): string { + const trimmed = dimension.name.trim() + if (trimmed.length > 0) { + return trimmed + } + if (dimension.name.length > 0) { + return dimension.name + } + return 'unnamed' + } + + private getDefaultColorDimensionId(dimensions: LinearDimension[]): string | undefined { + const preferred = dimensions.find((dimension) => dimension.name.trim().toLowerCase() === 'thread') + if (preferred) { + return preferred.id + } + return dimensions[0]?.id + } + + private handleColorDimensionSelectionChange(): void { + const value = this.colorDimensionSelect.value + this.selectedColorDimensionId = value || null + this.applySelectedColorDimension() + } + + private getSelectedColorDimensionName(): string | undefined { + if (!this.selectedColorDimensionId) { + return undefined + } + const dimension = this.dimensionState.input.find((item) => item.id === this.selectedColorDimensionId) + if (!dimension) { + return undefined + } + const trimmed = dimension.name.trim() + return trimmed.length > 0 ? trimmed : undefined + } + + private applySelectedColorDimension(): void { + const renderer = this.getRenderer() + if (!renderer) { + return + } + renderer.setColorByInputDimension(this.getSelectedColorDimensionName()) + } + private getDefaultName(type: DimensionType): string { const prefix = type === 'input' ? 'input' : 'output' return `${prefix}${this.dimensionState[type].length + 1}` diff --git a/src/visualization/CanvasRenderer.test.ts b/src/visualization/CanvasRenderer.test.ts index 5c25edf..b704070 100644 --- a/src/visualization/CanvasRenderer.test.ts +++ b/src/visualization/CanvasRenderer.test.ts @@ -1,5 +1,6 @@ import { describe, it, expect, beforeEach, vi } from 'vitest' import { CanvasRenderer, type PositionResolver } from './CanvasRenderer' +import { ColorScheme } from './ColorScheme' import { createBlockLayout } from '../layouts/BlockLayout' import type { BlockLayoutParams } from '../validation/InputValidator' import { LinearLayout } from '../core/LinearLayout' @@ -114,6 +115,43 @@ describe('CanvasRenderer', () => { expect(color0).not.toBe(color1) expect(scheme.getColorForThread(0)).toBe(color0) }) + + it('should color cells by a selected input dimension when provided', () => { + const layout = LinearLayout.identity1D(4, 'thread', 'dim0').multiply( + LinearLayout.identity1D(4, 'lane', 'dim1') + ) + const params: BlockLayoutParams = { + sizePerThread: [1, 1] as [number, number], + threadsPerWarp: [4, 1] as [number, number], + warpsPerCTA: [1, 1] as [number, number], + order: [0, 1] as [number, number], + tensorShape: [4, 4] as [number, number], + } + + const warpSpy = vi.spyOn(ColorScheme.prototype, 'getColorForWarp') + const threadSpy = vi.spyOn(ColorScheme.prototype, 'getColorForThread') + + const renderer = new CanvasRenderer( + canvas, + layout, + params, + undefined, + { + traversalMode: 'by-output', + colorGrouping: 'thread', + showCellText: false, + colorInputDimension: 'lane', + } + ) + + renderer.render() + + expect(warpSpy).toHaveBeenCalled() + expect(threadSpy).not.toHaveBeenCalled() + + warpSpy.mockRestore() + threadSpy.mockRestore() + }) }) describe('viewport integration', () => { diff --git a/src/visualization/CanvasRenderer.ts b/src/visualization/CanvasRenderer.ts index dde98d6..57ee35c 100644 --- a/src/visualization/CanvasRenderer.ts +++ b/src/visualization/CanvasRenderer.ts @@ -11,6 +11,7 @@ interface CanvasRendererOptions { colorGrouping?: ColorGroupingMode traversalMode?: TraversalMode showCellText?: boolean + colorInputDimension?: string } export interface CellInfo { @@ -74,6 +75,8 @@ export class CanvasRenderer { private traversalMode: TraversalMode private colorGrouping: ColorGroupingMode private showCellText: boolean + private customColorDimension?: string + private maxThreadIdObserved = 0 constructor( private canvas: HTMLCanvasElement, @@ -92,6 +95,8 @@ export class CanvasRenderer { this.colorGrouping = options?.colorGrouping ?? 'warp' this.traversalMode = options?.traversalMode ?? 'by-thread' this.showCellText = options?.showCellText ?? true + const trimmedColorDim = options?.colorInputDimension?.trim() + this.customColorDimension = trimmedColorDim && trimmedColorDim.length > 0 ? trimmedColorDim : undefined this.resetColorSchemeFromParams() // Initialize viewport controller @@ -138,6 +143,7 @@ export class CanvasRenderer { this.params.warpsPerCTA[1] const warpSize = this.params.threadsPerWarp[0] * this.params.threadsPerWarp[1] + this.maxThreadIdObserved = Math.max(totalThreads - 1, 0) for (let threadId = 0; threadId < totalThreads; threadId++) { const positions = this.positionResolver(this.layout, threadId) @@ -252,7 +258,12 @@ export class CanvasRenderer { } traverse(0) - this.rebuildColorSchemeFromThreads(maxThreadId) + this.maxThreadIdObserved = Math.max(maxThreadId, 0) + if (this.customColorDimension) { + this.resetColorSchemeForCustomDimension() + } else { + this.rebuildColorSchemeFromThreads(maxThreadId) + } return cache } @@ -277,6 +288,16 @@ export class CanvasRenderer { } } + private resetColorSchemeForCustomDimension(): void { + const dimName = this.customColorDimension + if (!dimName) { + return + } + const hasDimension = this.layout.hasInDim(dimName) + const colorCount = hasDimension ? Math.max(this.layout.getInDimSize(dimName), 1) : 1 + this.colorScheme = new ColorScheme(colorCount, 1) + } + /** * Render the entire visualization */ @@ -347,8 +368,8 @@ export class CanvasRenderer { const y = row * this.cellSize * viewport.scale + viewport.offsetY const size = this.cellSize * viewport.scale - // Fill cell with warp color - this.ctx.fillStyle = this.colorScheme.getColorForThread(cellInfo.threadId) + // Fill cell with color determined by current grouping + this.ctx.fillStyle = this.getCellFillColor(cellInfo) this.ctx.fillRect(x, y, size, size) // Draw text if cell is large enough @@ -359,6 +380,20 @@ export class CanvasRenderer { } } + private getCellFillColor(cellInfo: CellInfo): string { + if (this.customColorDimension) { + const value = cellInfo.inputCoords?.[this.customColorDimension] + if (typeof value === 'number' && Number.isFinite(value)) { + const normalized = Math.max(0, Math.trunc(value)) + return this.colorScheme.getColorForWarp(normalized) + } + } + if (this.colorGrouping === 'thread') { + return this.colorScheme.getColorForThread(cellInfo.threadId) + } + return this.colorScheme.getColorForWarp(cellInfo.warpId) + } + /** * Draw text inside a cell */ @@ -496,6 +531,23 @@ export class CanvasRenderer { return this.colorScheme.getColorForThread(threadId) } + setColorByInputDimension(dimensionName?: string): void { + const normalized = dimensionName?.trim() + const nextDimension = normalized && normalized.length > 0 ? normalized : undefined + if (this.customColorDimension === nextDimension) { + return + } + this.customColorDimension = nextDimension + if (this.customColorDimension) { + this.resetColorSchemeForCustomDimension() + } else if (this.traversalMode === 'by-output') { + this.rebuildColorSchemeFromThreads(this.maxThreadIdObserved) + } else { + this.resetColorSchemeFromParams() + } + this.render() + } + /** * Zoom in */