Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 187 additions & 38 deletions components/task-heatmap.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
'use client'

import { useEffect, useMemo, useState } from 'react'
import { useEffect, useMemo, useRef, useState } from 'react'
import type { LeaderboardEntry } from '@/lib/types'
import { PROVIDER_COLORS, CATEGORY_ICONS } from '@/lib/types'
import { fetchSubmissionClient } from '@/lib/api'
Expand All @@ -14,13 +14,63 @@ interface TaskHeatmapProps {
onCategoriesChange: (categories: string[]) => void
}

/**
 * Per-task result for a single submission, as stored in the heatmap cache.
 * Plain-data shape (no Map/class instances) so it survives JSON
 * serialization to sessionStorage.
 */
interface TaskInfo {
// Points earned on this task.
score: number
// Maximum attainable points; score / maxScore drives the heatmap color.
maxScore: number
// Human-readable task name shown in the grid.
taskName: string
// Category label used for the category filter chips.
category: string
}

/**
 * Aggregated task-level results for one model's best submission.
 * Fully JSON-serializable so instances can be cached in sessionStorage.
 */
interface ModelTaskData {
  /** Display name of the model. */
  model: string
  /** Provider key (used elsewhere to look up provider colors). */
  provider: string
  /** Overall score percentage for this submission. */
  percentage: number
  /**
   * Task results keyed by task_id. A plain object (not a Map) so the whole
   * structure can be serialized to sessionStorage.
   */
  tasks: Record<string, TaskInfo>
}

/**
 * In-memory cache of submission details keyed by submission_id.
 * Serialized to sessionStorage so it survives page refreshes.
 * Values are plain-data ModelTaskData objects (JSON round-trip safe).
 */
interface SerializedCache {
[submissionId: string]: ModelTaskData
}

/** sessionStorage key under which the heatmap submission cache is persisted. */
const SESSION_STORAGE_KEY = 'pinchbench_heatmap_cache'

/**
 * Load the submission cache from sessionStorage.
 *
 * Returns an empty object when sessionStorage is unavailable (SSR, storage
 * disabled), the key is absent, the payload is not valid JSON, or the payload
 * parses to something other than a plain object (e.g. `"null"`, an array, or
 * a primitive) — so callers can always index the result safely.
 */
function loadCacheFromSession(): SerializedCache {
  try {
    // Module-level call sites may execute during SSR where sessionStorage
    // does not exist; bail out explicitly rather than relying on the catch.
    if (typeof sessionStorage === 'undefined') return {}
    const raw = sessionStorage.getItem(SESSION_STORAGE_KEY)
    if (raw) {
      const parsed: unknown = JSON.parse(raw)
      // Validate the top-level shape before trusting the cast: JSON.parse can
      // legitimately yield null, arrays, or primitives from a corrupted or
      // foreign entry, and returning those would break later indexing.
      if (parsed !== null && typeof parsed === 'object' && !Array.isArray(parsed)) {
        return parsed as SerializedCache
      }
    }
  } catch {
    // sessionStorage unavailable or payload corrupted — start fresh
  }
  return {}
}

/**
 * Best-effort persistence of the submission cache to sessionStorage.
 * Failures (quota exceeded, private browsing, storage disabled) are
 * deliberately swallowed: the in-memory cache remains authoritative and
 * persistence is only an optimization.
 */
function saveCacheToSession(cache: SerializedCache): void {
  try {
    const payload = JSON.stringify(cache)
    sessionStorage.setItem(SESSION_STORAGE_KEY, payload)
  } catch {
    // Ignore write failures — next page load simply starts with a cold cache.
  }
}

/**
 * Module-level cache (plain in-memory object for fast access during a
 * session; shared by every TaskHeatmap instance in this page load).
 * Backed by sessionStorage for persistence across page refreshes.
 * Structure: { [submission_id]: ModelTaskData }
 */
const submissionCache: SerializedCache = loadCacheFromSession()

/** Maximum number of concurrent API requests when fetching submission details */
const CONCURRENCY_LIMIT = 10

function getScoreColor(ratio: number): string {
// Red (0%) -> Yellow (50%) -> Green (100%)
if (ratio >= 0.85) return 'hsl(142, 71%, 35%)'
Expand All @@ -40,80 +90,166 @@ function getScoreTextColor(ratio: number): string {

export function TaskHeatmap({ entries, selectedCategories, onCategoriesChange }: TaskHeatmapProps) {
const [modelData, setModelData] = useState<ModelTaskData[]>([])
const [loading, setLoading] = useState(true)
// loadingState tracks the load phase:
// - 'idle': no load in progress
// - 'initial': first load, no cached data (full page spinner)
// - 'incremental': have some cached data, fetching remaining entries (chips stay interactive)
// - 'done': all loaded successfully
// - 'error': failed
const [loadingState, setLoadingState] = useState<'idle' | 'initial' | 'incremental' | 'done' | 'error'>('idle')
const [error, setError] = useState<string | null>(null)
const [sortBy, setSortBy] = useState<'score' | 'name'>('score')
const [hoveredCell, setHoveredCell] = useState<{ model: string; taskId: string } | null>(null)
// Track how many entries have been loaded for progress display
const [loadedCount, setLoadedCount] = useState(0)

// Ref to track whether the current effect has been cancelled.
// Using ref instead of closure variable to avoid stale references after await.
const cancelledRef = useRef(false)

// Track the previous entries' submission IDs to skip re-fetching when
// the entries array reference changes but the underlying models are the same
// (e.g., when URL changes trigger a parent re-render with the same data).
const prevSubmissionIdsRef = useRef<string[] | null>(null)

// Fetch task-level data for each model's best submission
useEffect(() => {
let cancelled = false
cancelledRef.current = false
const currentCache = submissionCache

async function loadData() {
setLoading(true)
// Extract current submission IDs from entries
const currentIds = entries.map(e => e.submission_id)
const prevIds = prevSubmissionIdsRef.current

// If the underlying submission IDs are the same, skip re-fetching entirely.
// This prevents redundant API calls when parent re-renders with a new
// entries array reference but the same model list (e.g., URL param changes).
if (prevIds && prevIds.length === currentIds.length && prevIds.every((id, i) => id === currentIds[i])) {
return
}

// Update ref BEFORE any early returns or cache lookups so it's always in sync
prevSubmissionIdsRef.current = currentIds

// Separate entries into cached and uncached
const uncached: LeaderboardEntry[] = []
const initialData: ModelTaskData[] = []

for (const entry of entries) {
const cached = currentCache[entry.submission_id]
if (cached) {
initialData.push(cached)
} else {
uncached.push(entry)
}
}

// If all entries are cached, apply data immediately
if (uncached.length === 0) {
if (cancelledRef.current) return
setModelData(initialData)
setLoadingState('done')
setLoadedCount(initialData.length)
return
}

// Check cancelled before making any state updates, in case the old effect
// resolved its await after a new effect has already started.
if (cancelledRef.current) return

// We have some cached data — show it immediately while fetching the rest.
// Set 'incremental' so chips stay interactive; if there is no cached
// data yet, use 'initial' to show the full spinner.
const hasCachedData = initialData.length > 0
setModelData(hasCachedData ? initialData : [])
setLoadingState(hasCachedData ? 'incremental' : 'initial')
setError(null)

let totalLoaded = initialData.length
setLoadedCount(totalLoaded)

try {
// Fetch submissions in batches of 5 to avoid overwhelming the API
const results: ModelTaskData[] = []
const batchSize = 5
// Fetch uncached entries in controlled concurrency batches
const results: ModelTaskData[] = [...initialData]

for (let i = 0; i < entries.length; i += batchSize) {
if (cancelled) return
for (let i = 0; i < uncached.length; i += CONCURRENCY_LIMIT) {
if (cancelledRef.current) return

const batch = entries.slice(i, i + batchSize)
const batch = uncached.slice(i, i + CONCURRENCY_LIMIT)
const batchResults = await Promise.all(
batch.map(async (entry) => {
batch.map(async (entry): Promise<ModelTaskData | null> => {
try {
const response = await fetchSubmissionClient(entry.submission_id)
if (cancelledRef.current) return null
const submission = transformSubmission(response.submission)
const taskMap = new Map<string, { score: number; maxScore: number; taskName: string; category: string }>()

const taskRecord: Record<string, TaskInfo> = {}
for (const task of submission.task_results) {
taskMap.set(task.task_id, {
taskRecord[task.task_id] = {
score: task.score,
maxScore: task.max_score,
taskName: task.task_name,
category: task.category,
})
}
}

return {
const result: ModelTaskData = {
model: entry.model,
provider: entry.provider,
percentage: entry.percentage,
tasks: taskMap,
} as ModelTaskData
tasks: taskRecord,
}

// Store in module-level cache
currentCache[entry.submission_id] = result
return result
} catch {
return null
}
})
)

results.push(...batchResults.filter((r): r is ModelTaskData => r !== null))
if (cancelledRef.current) return

const validBatchResults = batchResults.filter((r): r is ModelTaskData => r !== null)
totalLoaded += validBatchResults.length

// Only persist to sessionStorage if this batch yielded new entries
if (validBatchResults.length > 0) {
saveCacheToSession(currentCache)
}

// Functional update to avoid stale closure
setModelData(prev => {
const prevModels = new Set(prev.map(d => d.model))
const newItems = validBatchResults.filter(r => !prevModels.has(r.model))
if (newItems.length === 0) return prev
return [...prev, ...newItems]
})
setLoadedCount(totalLoaded)
}

if (!cancelled) {
setModelData(results)
setLoading(false)
if (!cancelledRef.current) {
setLoadingState('done')
}
} catch {
if (!cancelled) {
if (!cancelledRef.current) {
setError('Failed to load task data')
setLoading(false)
setLoadingState('error')
}
}
}

loadData()
return () => { cancelled = true }
return () => { cancelledRef.current = true }
}, [entries])

// Collect all unique tasks and sort by category
const allTasks = useMemo(() => {
const taskMap = new Map<string, { taskName: string; category: string }>()
for (const model of modelData) {
for (const [taskId, task] of model.tasks) {
for (const [taskId, task] of Object.entries(model.tasks)) {
if (!taskMap.has(taskId)) {
taskMap.set(taskId, { taskName: task.taskName, category: task.category })
}
Expand Down Expand Up @@ -156,7 +292,7 @@ export function TaskHeatmap({ entries, selectedCategories, onCategoriesChange }:
let sumScore = 0
let sumMax = 0
for (const task of filteredTasks) {
const td = m.tasks.get(task.taskId)
const td = m.tasks[task.taskId]
if (td) {
sumScore += td.score
sumMax += td.maxScore
Expand Down Expand Up @@ -211,38 +347,42 @@ export function TaskHeatmap({ entries, selectedCategories, onCategoriesChange }:
onCategoriesChange(next)
}

if (loading) {
const isInitialLoad = loadingState === 'initial'
const isIncrementalLoad = loadingState === 'incremental'
const hasError = loadingState === 'error'
const hasAnyData = modelData.length > 0

// Early return for initial load — full page spinner
if (isInitialLoad) {
return (
<div className="flex flex-col items-center justify-center h-64 rounded-lg border border-border bg-muted/30">
<div className="animate-spin rounded-full h-8 w-8 border-b-2 border-primary mb-3" />
<p className="text-sm text-muted-foreground">
Loading task-level data for {entries.length} models...
</p>
{modelData.length > 0 && (
<p className="text-xs text-muted-foreground/60 mt-1">
{modelData.length} of {entries.length} loaded
</p>
)}
</div>
)
}

if (error) {
// Early return for error state
if (hasError) {
return (
<div className="flex items-center justify-center h-64 rounded-lg border border-border bg-muted/30">
<div className="flex flex-col items-center justify-center h-64 rounded-lg border border-border bg-muted/30">
<p className="text-sm text-destructive">{error}</p>
</div>
)
}

if (modelData.length === 0 || allTasks.length === 0) {
// Early return for empty / no-data state
if (!hasAnyData || allTasks.length === 0) {
return (
<div className="flex items-center justify-center h-64 rounded-lg border border-border bg-muted/30">
<div className="flex flex-col items-center justify-center h-64 rounded-lg border border-border bg-muted/30">
<p className="text-sm text-muted-foreground">No task data available.</p>
</div>
)
}

// Early return for no matching categories
if (categoryFilterActive && filteredTasks.length === 0) {
return (
<div className="rounded-lg border border-border bg-muted/30 px-4 py-8 text-center">
Expand All @@ -260,6 +400,7 @@ export function TaskHeatmap({ entries, selectedCategories, onCategoriesChange }:
)
}

// Main render — chips and table are always accessible (even during incremental loading)
return (
<div>
<h2 className="text-lg font-semibold text-foreground mb-1">
Expand Down Expand Up @@ -320,6 +461,14 @@ export function TaskHeatmap({ entries, selectedCategories, onCategoriesChange }:
})}
</div>

{/* Incremental loading progress — shown while fetching remaining entries */}
{isIncrementalLoad && (
<div className="mb-2 flex items-center gap-2 text-xs text-muted-foreground">
<div className="animate-spin rounded-full h-3 w-3 border-b border-primary" />
<span>Caching remaining models: {loadedCount} of {entries.length}</span>
</div>
)}

{/* Controls */}
<div className="flex flex-col gap-2 sm:flex-row sm:items-center sm:gap-4 mb-4">
<div className="flex flex-col gap-1.5">
Expand Down Expand Up @@ -462,7 +611,7 @@ export function TaskHeatmap({ entries, selectedCategories, onCategoriesChange }:
</div>
</td>
{filteredTasks.map((task) => {
const taskData = model.tasks.get(task.taskId)
const taskData = model.tasks[task.taskId]
const ratio = taskData ? taskData.score / taskData.maxScore : 0
const hasData = !!taskData
const isHovered = hoveredCell?.model === model.model && hoveredCell?.taskId === task.taskId
Expand Down