diff --git a/crates/piston-core/src/gpu/buffer_allocator/allocator.rs b/crates/piston-core/src/gpu/buffer_allocator/allocator.rs index 43160da9..916ab1dd 100644 --- a/crates/piston-core/src/gpu/buffer_allocator/allocator.rs +++ b/crates/piston-core/src/gpu/buffer_allocator/allocator.rs @@ -7,7 +7,7 @@ use crate::{ PooledGPUBuffer, TensorUsageRecords, UNIFORM_ALIGN, WgpuDevice, }, }; -use crate::{HashMap, LazyOp}; +use crate::{HashMap, HashSet, LazyOp}; use parking_lot::RwLock; use std::num::NonZero; use std::{borrow::Cow, collections::BTreeMap}; @@ -281,6 +281,7 @@ impl BufferAllocator { output_tensors: &BTreeMap, assignments: &mut HashMap, gpu_compile_keys: &HashMap, + shareable_ids: &HashSet, use_shared_buffers: bool, device: &WgpuDevice, ) -> Result<(), DeviceError> { @@ -289,9 +290,16 @@ impl BufferAllocator { let mut shared_objects: Vec = Vec::with_capacity(records.0.len()); for record in records.0.iter() { + let is_output = output_tensors.get(&record.last_consumer_id).is_some(); + let is_shareable_id = record + .id + .map(|id| shareable_ids.contains(&id)) + .unwrap_or(false); + let should_be_shared = use_shared_buffers - && !(record.requires_grad.unwrap_or(false) - || output_tensors.get(&record.last_consumer_id).is_some()); + && is_shareable_id + && !is_output + && !record.requires_grad.unwrap_or(false); let mut best_obj = None; @@ -370,6 +378,7 @@ impl BufferAllocator { execution_order: &[&OpTensor], output_tensors: &BTreeMap, gpu_compile_keys: &HashMap, + shareable_ids: &HashSet, use_shared_buffers: bool, device: &WgpuDevice, ) -> Result, DeviceError> { @@ -395,6 +404,7 @@ impl BufferAllocator { output_tensors, &mut assignments, gpu_compile_keys, + shareable_ids, use_shared_buffers, device, )?; diff --git a/crates/piston-core/src/gpu/buffer_allocator/lazy_graph_executor.rs b/crates/piston-core/src/gpu/buffer_allocator/lazy_graph_executor.rs index e8454d47..bb788ec9 100644 --- a/crates/piston-core/src/gpu/buffer_allocator/lazy_graph_executor.rs +++ b/crates/piston-core/src/gpu/buffer_allocator/lazy_graph_executor.rs @@ -1,7 +1,9 @@ use crate::{ Compiled, CpuUniform, DebugSelection, Executable, ExecutionError, ExecutionResult, GPUBuffer, HashMap, HashSet, Hasher as HasherType, Inner, LazyOp, StepLog, StepLogConfig, Storage, - TensorError, WgpuDevice, reset_scope_context, + TensorError, WgpuDevice, + gpu::{TensorOpDesc, trace_sink}, + reset_scope_context, }; #[cfg(feature = "debug")] use crate::{DebugTensor, Device, DeviceStorage}; @@ -392,7 +394,34 @@ impl LazyGraphExecutor { } } - log::debug!("Post-order hash: {:?}", hasher.finish()); + // Determine which tensors are safe candidates for shared-object buffers. + // + // A tensor is considered shareable if: + // - It does not have extra external strong references beyond what we expect from the + // execution graph and (optionally) owned_tensors. + // - It is not a leaf / parameter (requires_grad). + // - It does not retain its gradient. + // + // These tensors are good candidates for hosting shared-object buffers whose contents may + // be reused for other logical tensors once their lifetimes (based on execution_order) + // have ended. 
+ let mut shareable_ids = + HashSet::with_capacity_and_hasher(post_order.len(), Default::default()); + for t in &post_order { + let id = t.id(); + let expected_strong = owned_tensors + .as_ref() + .and_then(|ot| ot.contains(&id).then_some(2)) + .unwrap_or(1); + let strong = t.strong_count(); + + let externally_pinned = + strong > expected_strong || t.requires_grad() || t.retains_grad(); + + if !externally_pinned { + shareable_ids.insert(id); + } + } let output_tensors = tensors .iter() @@ -483,6 +512,7 @@ impl LazyGraphExecutor { &post_order, &output_tensors, &compile_keys, + &shareable_ids, self.shared_object_allocation_enabled, gpu_device, )?) diff --git a/crates/piston-core/src/gpu/device.rs b/crates/piston-core/src/gpu/device.rs index 0fa87804..931a7a71 100644 --- a/crates/piston-core/src/gpu/device.rs +++ b/crates/piston-core/src/gpu/device.rs @@ -304,6 +304,7 @@ impl WgpuDevice { execution_order: &[&OpTensor], output_tensors: &BTreeMap, gpu_compile_keys: &HashMap, + shareable_ids: &crate::HashSet, use_shared_buffers: bool, device: &WgpuDevice, ) -> Result, DeviceError> { @@ -311,6 +312,7 @@ impl WgpuDevice { execution_order, output_tensors, gpu_compile_keys, + shareable_ids, use_shared_buffers, device, ) diff --git a/examples/finetuning/.editorconfig b/examples/finetuning/.editorconfig new file mode 100644 index 00000000..5696c4e8 --- /dev/null +++ b/examples/finetuning/.editorconfig @@ -0,0 +1,23 @@ +root = true + +[*] +end_of_line = lf +insert_final_newline = true +indent_style = tab +indent_size = tab +tab_width = 2 +charset = utf-8 +trim_trailing_whitespace = true + +[*.py] +indent_style = space +indent_size = 4 +tab_width = 4 + +[*.ipynb] +indent_style = space +indent_size = 4 +tab_width = 4 + +[package.json] +indent_style = space diff --git a/examples/finetuning/.gitignore b/examples/finetuning/.gitignore new file mode 100644 index 00000000..85b308c3 --- /dev/null +++ b/examples/finetuning/.gitignore @@ -0,0 +1,28 @@ +node_modules + +# Output +.output +.vercel +.netlify +.wrangler +/.svelte-kit +/build + +# OS +.DS_Store +Thumbs.db + +# Env +*.local +.env* +!.env.production +!.env.development +!.env.example +!.env.test + +# Vite +vite.config.js.timestamp-* +vite.config.ts.timestamp-* + +static/tokenizer +static/tokenized diff --git a/examples/finetuning/.npmrc b/examples/finetuning/.npmrc new file mode 100644 index 00000000..fbb338d9 --- /dev/null +++ b/examples/finetuning/.npmrc @@ -0,0 +1,2 @@ +engine-strict=true +use-node-version=22.17.1 diff --git a/examples/finetuning/.prettierignore b/examples/finetuning/.prettierignore new file mode 100644 index 00000000..ab78a95d --- /dev/null +++ b/examples/finetuning/.prettierignore @@ -0,0 +1,4 @@ +# Package Managers +package-lock.json +pnpm-lock.yaml +yarn.lock diff --git a/examples/finetuning/.prettierrc b/examples/finetuning/.prettierrc new file mode 100644 index 00000000..bd92dfb0 --- /dev/null +++ b/examples/finetuning/.prettierrc @@ -0,0 +1,18 @@ +{ + "useTabs": true, + "singleQuote": true, + "trailingComma": "none", + "printWidth": 100, + "tabWidth": 2, + "plugins": [ + "prettier-plugin-svelte" + ], + "overrides": [ + { + "files": "*.svelte", + "options": { + "parser": "svelte" + } + } + ] +} diff --git a/examples/finetuning/README.md b/examples/finetuning/README.md new file mode 100644 index 00000000..99aaad96 --- /dev/null +++ b/examples/finetuning/README.md @@ -0,0 +1 @@ +This is the source code for the finetuning demo at [finetune.sequence.toys](https://finetune.sequence.toys). It is a static Svelte app. 
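To make the new gating concrete, here is a minimal TypeScript sketch of the shareable-ids pass added in lazy_graph_executor.rs above. The record shape, field names, and `ownedTensors` set are illustrative stand-ins for the Rust types, not part of the actual API.

```ts
// Illustrative model of the shareable-ids pass: a tensor is a sharing
// candidate only when nothing outside the graph can still observe it.
interface TensorRecord {
	id: number;
	strongCount: number; // strong references currently held
	requiresGrad: boolean; // leaves / parameters are never shared
	retainsGrad: boolean; // tensors that keep their gradient stay pinned
}

function collectShareableIds(
	postOrder: TensorRecord[],
	ownedTensors?: Set<number>
): Set<number> {
	const shareable = new Set<number>();
	for (const t of postOrder) {
		// The graph accounts for one expected reference; owned_tensors adds one more.
		const expectedStrong = ownedTensors?.has(t.id) ? 2 : 1;
		const externallyPinned =
			t.strongCount > expectedStrong || t.requiresGrad || t.retainsGrad;
		if (!externallyPinned) shareable.add(t.id);
	}
	return shareable;
}

// A tensor with only the expected reference qualifies:
collectShareableIds([
	{ id: 7, strongCount: 1, requiresGrad: false, retainsGrad: false }
]); // Set { 7 }
```

In the Rust code this set then feeds the allocator, which additionally excludes graph outputs (`!is_output`) before letting a record reuse a shared object.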
diff --git a/examples/finetuning/eslint.config.js b/examples/finetuning/eslint.config.js new file mode 100644 index 00000000..e00e208a --- /dev/null +++ b/examples/finetuning/eslint.config.js @@ -0,0 +1,81 @@ +import { includeIgnoreFile } from '@eslint/compat'; +import js from '@eslint/js'; +import prettier from 'eslint-config-prettier'; +import perfectionist from 'eslint-plugin-perfectionist'; +import svelte from 'eslint-plugin-svelte'; +import { defineConfig } from 'eslint/config'; +import globals from 'globals'; +import { fileURLToPath } from 'node:url'; +import ts from 'typescript-eslint'; +const gitignorePath = fileURLToPath(new URL('./.gitignore', import.meta.url)); +const tsconfigRootDir = fileURLToPath(new URL('.', import.meta.url)); + +export default defineConfig( + includeIgnoreFile(gitignorePath), + js.configs.recommended, + ...ts.configs.recommended, + ...svelte.configs['flat/recommended'], + ...svelte.configs['flat/prettier'], + { + languageOptions: { + globals: { + ...globals.browser, + ...globals.node + } + } + }, + { + files: ['**/*.svelte', '**/*.svelte.ts', '**/*.ts'], + + languageOptions: { + parserOptions: { + parser: ts.parser, + // Ensure the TypeScript parser knows which tsconfig to use in this package + tsconfigRootDir, + projectService: true + } + } + }, + { + plugins: { + prettier, + perfectionist + }, + rules: { + 'no-unused-vars': 'off', + '@typescript-eslint/ban-types': 'off', + '@typescript-eslint/no-explicit-any': 'warn', + '@typescript-eslint/no-unused-vars': [ + 'warn', + { + argsIgnorePattern: '^_', + varsIgnorePattern: '^_', + caughtErrorsIgnorePattern: '^_' + } + ], + '@typescript-eslint/no-var-requires': 'off', + 'perfectionist/sort-imports': [ + 'error', + { + type: 'natural', + order: 'asc', + groups: [ + 'type', + ['builtin', 'external'], + 'internal-type', + 'internal', + ['parent-type', 'sibling-type', 'index-type'], + ['parent', 'sibling', 'index'], + 'side-effect', + 'style', + 'object', + 'unknown' + ] + } + ], + 'perfectionist/sort-exports': ['error', { type: 'natural' }], + 'perfectionist/sort-named-exports': ['error', { type: 'natural' }], + 'perfectionist/sort-named-imports': ['error', { type: 'natural' }] + } + } +); diff --git a/examples/finetuning/package.json b/examples/finetuning/package.json new file mode 100644 index 00000000..47e53ff9 --- /dev/null +++ b/examples/finetuning/package.json @@ -0,0 +1,63 @@ +{ + "name": "piston-finetune-toy", + "private": true, + "version": "0.0.1", + "type": "module", + "scripts": { + "dev": "vite dev", + "build": "vite build", + "preview": "vite preview", + "prepare": "svelte-kit sync || echo ''", + "check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json", + "check:watch": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json --watch", + "lint": "eslint . && prettier --check .", + "format": "prettier --write ." 
+ }, + "dependencies": { + "@codemirror/autocomplete": "^6.19.1", + "@codemirror/language": "^6.11.3", + "@codemirror/lint": "^6.9.1", + "@codemirror/state": "^6.5.2", + "@codemirror/view": "^6.38.6", + "@huggingface/jinja": "^0.5.1", + "@lezer/highlight": "^1.2.3", + "@lucide/svelte": "^0.554.0", + "@piston-ml/piston-web": "workspace:^", + "@types/katex": "^0.16.7", + "codemirror": "^6.0.2", + "echarts": "^6.0.0", + "example-common": "workspace:*", + "katex": "^0.16.23", + "maxrects-packer": "^2.7.3", + "random-js": "^2.1.0", + "svelte-portal": "^2.2.1", + "unique-names-generator": "^4.7.1" + }, + "devDependencies": { + "@eslint/compat": "^1.4.0", + "@eslint/js": "^9.37.0", + "@sveltejs/adapter-static": "^3.0.10", + "@sveltejs/kit": "^2.46.4", + "@sveltejs/vite-plugin-svelte": "^6.2.1", + "@tailwindcss/vite": "^4.1.14", + "@types/glob": "^9.0.0", + "@webgpu/types": "^0.1.65", + "eslint": "^9.37.0", + "eslint-config-prettier": "^10.1.8", + "eslint-plugin-perfectionist": "^4.15.1", + "eslint-plugin-svelte": "^3.12.4", + "glob": "^11.0.3", + "globals": "^16.4.0", + "prettier": "^3.6.2", + "prettier-plugin-svelte": "^3.4.0", + "rollup": "^4.52.4", + "sirv": "^2.0.4", + "svelte": "workspace:^", + "svelte-check": "^4.3.3", + "tailwindcss": "^4.1.14", + "typescript": "^5.9.3", + "typescript-eslint": "^8.46.0", + "vite": "^7.1.9", + "vite-plugin-wasm": "^3.5.0" + } +} diff --git a/examples/finetuning/src/app.css b/examples/finetuning/src/app.css new file mode 100644 index 00000000..aec164aa --- /dev/null +++ b/examples/finetuning/src/app.css @@ -0,0 +1,25 @@ +@import 'tailwindcss' source('./'); +@import 'example-common/theme.css'; + +@source "../../../packages/example-common/src/**/*.{svelte,js,ts}"; + +html { + /* Base font size for fine pointers (desktop/mouse) */ + font-size: 16px; + overscroll-behavior: none; + @apply h-full; + @apply w-full; +} + +@media (pointer: coarse) { + html { + font-size: 18px; + } +} + +body { + @apply h-full; + @apply w-full; + @apply text-base; + overscroll-behavior: none; +} diff --git a/examples/finetuning/src/app.d.ts b/examples/finetuning/src/app.d.ts new file mode 100644 index 00000000..da2a8798 --- /dev/null +++ b/examples/finetuning/src/app.d.ts @@ -0,0 +1,14 @@ +// See https://svelte.dev/docs/kit/types#app.d.ts +// for information about these interfaces +declare global { + const __COMMIT_HASH__: string; + namespace App { + // interface Error {} + // interface Locals {} + // interface PageData {} + // interface PageState {} + // interface Platform {} + } +} + +export {}; diff --git a/examples/finetuning/src/app.html b/examples/finetuning/src/app.html new file mode 100644 index 00000000..f273cc58 --- /dev/null +++ b/examples/finetuning/src/app.html @@ -0,0 +1,11 @@ + + + + + + %sveltekit.head% + + +
%sveltekit.body%
+ + diff --git a/examples/finetuning/src/lib/attachments/echarts.svelte.ts b/examples/finetuning/src/lib/attachments/echarts.svelte.ts new file mode 100644 index 00000000..2ba10bd5 --- /dev/null +++ b/examples/finetuning/src/lib/attachments/echarts.svelte.ts @@ -0,0 +1,236 @@ +import type { Attachment } from 'svelte/attachments'; + +import * as echarts from 'echarts/core'; + +type EChartsAttachmentParams = { + opts: () => echarts.EChartsCoreOption | null; + // Optional setup hook to programmatically configure the instance (e.g., event bridging) + // May return a cleanup function that will be called on detach. + setup?: (chart: echarts.ECharts, getPeers: undefined) => void | (() => void); +}; + +export function setupAxisSync(chart: echarts.ECharts, getPeers: () => echarts.ECharts[]) { + let isRelaying = false; + chart.on('updateAxisPointer', (evt) => { + if (isRelaying) return; + const e = evt as { axesInfo?: Array<{ axisDim: string; value: number }> }; + const axesInfo = e?.axesInfo; + if (!axesInfo || axesInfo.length === 0) return; + const xInfo = axesInfo.find((a) => a.axisDim === 'x'); + if (!xInfo) return; + + const catIndex = Math.max(0, Math.floor(xInfo.value ?? 0)); + const opt = chart.getOption?.() as + | { xAxis?: Array<{ data?: (number | string)[] }> } + | undefined; + const cats = opt?.xAxis?.[0]?.data ?? []; + const step = Number(cats[catIndex] ?? catIndex); + + isRelaying = true; + for (const peer of getPeers()) { + const hit = findPeerDataIndexByStep(peer, step); + peer.dispatchAction({ + type: 'updateAxisPointer', + seriesIndex: hit?.seriesIndex ?? 0, + dataIndex: hit?.dataIndex ?? catIndex + }); + } + Promise.resolve().then(() => { + isRelaying = false; + }); + }); +} + +// Binary-search utilities for sorted numeric step arrays +export function exactIndex(steps: number[], target: number): number { + let lo = 0, + hi = steps.length - 1; + while (lo <= hi) { + const mid = (lo + hi) >> 1; + const v = steps[mid]; + if (v === target) return mid; + if (v < target) lo = mid + 1; + else hi = mid - 1; + } + return -1; +} + +export function nearestIndex(steps: number[], target: number): number { + if (steps.length === 0) return -1; + const first = steps[0], + last = steps[steps.length - 1]; + if (target < first || target > last) return -1; + let lo = 0, + hi = steps.length; + while (lo < hi) { + const mid = (lo + hi) >> 1; + if (steps[mid] < target) lo = mid + 1; + else hi = mid; + } + if (lo === 0) return 0; + if (lo === steps.length) return steps.length - 1; + return target - steps[lo - 1] <= steps[lo] - target ? lo - 1 : lo; +} + +// Extract numeric x/step from an ECharts updateAxisPointer event +export function extractStepFromAxisPointerEvent(evt: unknown): number | null { + const e = evt as { axesInfo?: Array<{ axisDim: string; value: number }> } | undefined; + const xInfo = e?.axesInfo?.find((a) => a.axisDim === 'x'); + const step = Number(xInfo?.value); + return Number.isFinite(step) ? step : null; +} + +function extractXValue(datum: unknown): number | null { + if (Array.isArray(datum)) { + const x = Number(datum[0]); + return Number.isFinite(x) ? x : null; + } + if (datum && typeof datum === 'object') { + const obj = datum as { value?: unknown; x?: unknown; step?: unknown }; + if (Array.isArray(obj.value)) { + const x = Number(obj.value[0]); + return Number.isFinite(x) ? x : null; + } + const xCandidate = obj.x ?? 
obj.step; + if (typeof xCandidate === 'number') return xCandidate; + if (typeof xCandidate === 'string') { + const parsed = Number(xCandidate); + return Number.isFinite(parsed) ? parsed : null; + } + } + return null; +} + +function linearSearchByX(data: unknown[], targetStep: number): number { + for (let i = 0; i < data.length; i++) { + const x = extractXValue(data[i]); + if (x === targetStep) return i; + } + return -1; +} + +function binarySearchByX(data: unknown[], targetStep: number): number { + if (data.length === 0) return -1; + const first = extractXValue(data[0]); + const last = extractXValue(data[data.length - 1]); + if (first === null || last === null) return linearSearchByX(data, targetStep); + + const ascending = last >= first; + let lo = 0; + let hi = data.length - 1; + while (lo <= hi) { + const mid = lo + ((hi - lo) >> 1); + const x = extractXValue(data[mid]); + if (x === null) return linearSearchByX(data, targetStep); + if (x === targetStep) return mid; + if (ascending ? x < targetStep : x > targetStep) lo = mid + 1; + else hi = mid - 1; + } + return -1; +} + +function getSeriesArrayFromOption(opt: unknown): unknown[] { + const o = opt as { series?: unknown[] } | undefined; + return Array.isArray(o?.series) ? (o!.series as unknown[]) : []; +} + +function getDataArrayFromSeries(seriesItem: unknown): unknown[] { + const s = seriesItem as { data?: unknown[] } | undefined; + return Array.isArray(s?.data) ? (s!.data as unknown[]) : []; +} + +export function findPeerDataIndexByStep( + peer: echarts.ECharts, + step: number +): { seriesIndex: number; dataIndex: number } | null { + const opt = peer.getOption() as unknown; + const seriesArr = getSeriesArrayFromOption(opt); + if (seriesArr.length === 0) return null; + const seriesIndex = seriesArr.length - 1; // highest series index + const data = getDataArrayFromSeries(seriesArr[seriesIndex]); + if (data.length === 0) return null; + const dataIndex = binarySearchByX(data, step); + if (dataIndex < 0) return null; + return { seriesIndex, dataIndex }; +} + +export default function createEChartsAttachment(params: EChartsAttachmentParams): Attachment { + return (node: Element) => { + if (!(node instanceof HTMLDivElement)) { + throw new Error('ECharts attachment requires a div element'); + } + + const chart = echarts.init(node); + const getPeers = undefined; + + // Allow caller to set up custom behavior (e.g., axis-pointer mapping) + let setupCleanup: (() => void) | undefined; + const maybeCleanup = params.setup?.(chart, getPeers); + if (typeof maybeCleanup === 'function') setupCleanup = maybeCleanup; + + const resizeObserver = new ResizeObserver(() => { + chart.resize(); + }); + resizeObserver.observe(node); + + $effect(() => { + const options = params.opts?.(); + if (options) { + chart.setOption(options, { + notMerge: false, + replaceMerge: ['series'], + lazyUpdate: false + }); + } + }); + + return () => { + resizeObserver.unobserve(node); + setupCleanup?.(); + chart.dispose(); + }; + }; +} + +type MoveDetail = { + sourceId: string; + runId: string; + step: number; +}; + +type ClearDetail = { + sourceId: string; +}; + +const MOVE_EVENT = 'run-pointer:move'; +const CLEAR_EVENT = 'run-pointer:clear'; + +const bus: EventTarget = new EventTarget(); + +export function publishMove(detail: MoveDetail): void { + bus.dispatchEvent(new CustomEvent(MOVE_EVENT, { detail })); +} + +export function publishClear(detail: ClearDetail): void { + bus.dispatchEvent(new CustomEvent(CLEAR_EVENT, { detail })); +} + +export function subscribe( + onMove?: 
(detail: MoveDetail) => void, + onClear?: (detail: ClearDetail) => void +): () => void { + const moveListener = (e: Event) => { + const ce = e as CustomEvent; + if (onMove) onMove(ce.detail); + }; + const clearListener = (e: Event) => { + const ce = e as CustomEvent; + if (onClear) onClear(ce.detail); + }; + if (onMove) bus.addEventListener(MOVE_EVENT, moveListener as EventListener); + if (onClear) bus.addEventListener(CLEAR_EVENT, clearListener as EventListener); + return () => { + if (onMove) bus.removeEventListener(MOVE_EVENT, moveListener as EventListener); + if (onClear) bus.removeEventListener(CLEAR_EVENT, clearListener as EventListener); + }; +} diff --git a/examples/finetuning/src/lib/components/MetricToggleChips.svelte b/examples/finetuning/src/lib/components/MetricToggleChips.svelte new file mode 100644 index 00000000..b79ebaa3 --- /dev/null +++ b/examples/finetuning/src/lib/components/MetricToggleChips.svelte @@ -0,0 +1,41 @@ + + +
+ {#each items as name (name)} + + {/each} +
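A usage sketch for the run-pointer event bus exported by echarts.svelte.ts above; the chart and run ids are made up, and the `$lib` import path assumes the usual SvelteKit alias.

```ts
import { publishClear, publishMove, subscribe } from '$lib/attachments/echarts.svelte';

// Listen for hover movements broadcast by other charts.
const unsubscribe = subscribe(
	({ sourceId, runId, step }) => {
		if (sourceId === 'chart-a') return; // ignore echoes of our own events
		console.log(`peer hovered ${runId} at step ${step}`);
	},
	({ sourceId }) => console.log(`pointer cleared by ${sourceId}`)
);

// Broadcast a hover from this chart, then clear it.
publishMove({ sourceId: 'chart-a', runId: 'run-1', step: 120 });
publishClear({ sourceId: 'chart-a' });

// Detach both listeners when the component unmounts.
unsubscribe();
```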
diff --git a/examples/finetuning/src/lib/components/controls/Controls.svelte b/examples/finetuning/src/lib/components/controls/Controls.svelte new file mode 100644 index 00000000..5d99ce15 --- /dev/null +++ b/examples/finetuning/src/lib/components/controls/Controls.svelte @@ -0,0 +1,776 @@ + + +
+ {#if runsMap.size >= 1} + toggleControlSection('runs')} + contentClass="w-full" + > + + + {/if} + + toggleControlSection('gpu')} + contentClass={collapsibleSectionClass} + > + + {gpuName || 'Unknown'} + + + (gpuPowerPreference.current = 'high-performance')} + /> + + resetConfigToDefaults('training.vramLimitMb.present')} + > + { + if (value >= 1024) { + const gb = value / 1024; + const gbStr = gb % 1 === 0 ? `${gb}GB` : `${gb.toFixed(1)}GB`; + return gbStr; + } + return `${value}MB`; + }} + hasDefaultValue={equalsConfigDefault('training.vramLimitMb.value')} + onReset={() => resetConfigToDefaults('training.vramLimitMb.value')} + /> + + + + toggleControlSection('task')} + contentClass={collapsibleSectionClass} + > + + + + toggleControlSection('model')} + contentClass={collapsibleSectionClass} + > + + +
+ + {getParameterCount()} + + + {getHiddenSize()} + + + {currentDataset.vocabSize} + +
+
+ + toggleControlSection('training')} + > + resetConfigToDefaults('training.logSteps')} + /> + + resetConfigToDefaults('training.batchSize')} + /> + + resetConfigToDefaults('training.gradNorm.track')} + > + resetConfigToDefaults('training.gradNorm.errorIfNonfinite')} + /> + + resetConfigToDefaults('training.clipGradNorm.present')} + > + resetConfigToDefaults('training.clipGradNorm.value')} + /> + + + + resetConfigToDefaults('training.validation.present')} + > + resetConfigToDefaults('training.validation.valSteps')} + /> + resetConfigToDefaults('training.validation.batchSize')} + /> + resetConfigToDefaults('training.validation.completions.present')} + > + resetConfigToDefaults('training.validation.completions.amount')} + /> + {#if config.training.validation.completions.amount === 'subset'} + resetConfigToDefaults('training.validation.completions.subsetSize')} + /> + {/if} + + resetConfigToDefaults('training.validation.temperature')} + /> + + + + + resetConfigToDefaults('training.limitTraining.present')} + > + resetConfigToDefaults('training.limitTraining.steps')} + /> + + + + resetConfigToDefaults('training.labelSmoothing.present')} + > + resetConfigToDefaults('training.labelSmoothing.value')} + /> + + + resetConfigToDefaults('training.dropout.present')} + > + resetConfigToDefaults('training.dropout.embedding')} + /> + resetConfigToDefaults('training.dropout.transformer.attention')} + /> + resetConfigToDefaults('training.dropout.transformer.residual')} + /> + + + + resetConfigToDefaults('training.randomSeed.present')} + > + resetConfigToDefaults('training.randomSeed.value')} + /> + + + resetConfigToDefaults('training.checkpointEverySteps.present')} + > + resetConfigToDefaults('training.checkpointEverySteps.value')} + /> + + + + toggleControlSection('optimizer')} + contentClass={collapsibleSectionClass} + > + resetConfigToDefaults('optimizer.type')} + /> + + resetConfigToDefaults('optimizer.lr')} + /> + + resetConfigToDefaults('optimizer.warmupSteps.present')} + > + resetConfigToDefaults('optimizer.warmupSteps.value')} + /> + + + resetConfigToDefaults('optimizer.lrScheduler.present')} + > + + + + resetConfigToDefaults('optimizer.weightDecay.present')} + > + resetConfigToDefaults('optimizer.weightDecay.value')} + /> + + resetConfigToDefaults('optimizer.weightDecay.useWeightDecayGroups')} + /> + + +
+ {#if config.optimizer.type === 'Muon'} + +
+ resetConfigToDefaults('optimizer.muon.nsSteps')} + /> + resetConfigToDefaults('optimizer.muon.momentum')} + /> + resetConfigToDefaults('optimizer.muon.nesterov')} + /> +
+
+ {/if} + {#if config.optimizer.type === 'AdamW' || config.optimizer.type === 'Adam' || config.optimizer.type === 'Muon'} + {@const settingsName = config.optimizer.type === 'Adam' ? 'Adam' : 'AdamW'} + +
+ resetConfigToDefaults('optimizer.adam.beta1')} + /> + resetConfigToDefaults('optimizer.adam.beta2')} + /> + resetConfigToDefaults('optimizer.adam.eps')} + /> + resetConfigToDefaults('optimizer.adam.amsgrad')} + /> +
+
+ {/if} + + {#if config.optimizer.type === 'SGD'} +
+ resetConfigToDefaults('optimizer.sgd.momentum')} + /> + resetConfigToDefaults('optimizer.sgd.dampening')} + /> + resetConfigToDefaults('optimizer.sgd.nesterov')} + /> +
+ {/if} +
+
+ + toggleControlSection('advanced')} + contentClass={collapsibleSectionClass} + > + resetConfigToDefaults('training.useWeakTensorReferences')} + /> + resetConfigToDefaults('training.sharedObjectAllocation')} + /> + resetConfigToDefaults('training.inplaceSupport')} + /> + +
+ + diff --git a/examples/finetuning/src/lib/components/controls/DatasetControls.svelte b/examples/finetuning/src/lib/components/controls/DatasetControls.svelte new file mode 100644 index 00000000..a1003070 --- /dev/null +++ b/examples/finetuning/src/lib/components/controls/DatasetControls.svelte @@ -0,0 +1,38 @@ + + + resetConfigToDefaults('data.dataset')} +/> + +

{datasetConfigMetadata?.description}

+ + + +{#if getShowLowDiversityDatasetError()} +
+ +

+ Not enough example diversity in the training dataset for a held-out validation set of size {config + .training.validation.batchSize}. Consider changing dataset parameters or reducing the + validation batch size. +

+
+
+{/if} diff --git a/examples/finetuning/src/lib/components/controls/DatasetSample.svelte b/examples/finetuning/src/lib/components/controls/DatasetSample.svelte new file mode 100644 index 00000000..bd9aad91 --- /dev/null +++ b/examples/finetuning/src/lib/components/controls/DatasetSample.svelte @@ -0,0 +1,158 @@ + + +{#snippet token(value: number, dashed: boolean = false)} + + {#if tokenizer instanceof Promise} + {#await tokenizer then tokenizer} + {decodeSingle(value, tokenizer)} + {/await} + {:else} + {decodeSingle(value, tokenizer ?? null)} + {/if} + +{/snippet} + +{#snippet tokenSequence(values: number[], ignored?: boolean[])} +
+ {#each values as value, i (i)} + {@render token(value, ignored ? ignored[i] : false)} + {/each} +
+{/snippet} + +
0 ? `${lastStableHeight}px` : 'auto'} + style:overflow-y={anyPending ? 'hidden' : 'visible'} +> + {#if sampleData} + {#await sampleData then { collated, hasPrompt }} + {#if collated.length > 0} +
+ + + + {#if hasPrompt} + + {/if} + + + + + {#each collated as { prompt, target, fullSequence } (Array.prototype.concat.call(fullSequence, prompt ?? [], target ?? []))} + {@const pLen = prompt?.length || 0} + {@const targetFlags = maskedFlagsForRange(fullSequence, pLen, target?.length ?? 0)} + + {#if hasPrompt} + {@const promptFlags = maskedFlagsForRange(fullSequence, 0, pLen)} + + {/if} + + + {/each} + + {#if hasPrompt} + + {/if} + + + +
Prompt Target
+ {@render tokenSequence(prompt!, promptFlags)} + + {@render tokenSequence(target ?? fullSequence, targetFlags)} +
... ...
+
+ {/if} + {/await} + {/if} +
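DatasetSample renders each token id through a `decodeSingle` helper whose script section was not captured above. A plausible minimal version, assuming a tokenizer exposing `decode(ids)`; the fallback behavior while the tokenizer loads is a guess, not the actual implementation.

```ts
// Hypothetical sketch: decode one token id for display, falling back to the
// raw id while the tokenizer is still loading.
interface Tokenizer {
	decode(ids: number[]): string;
}

function decodeSingle(id: number, tokenizer: Tokenizer | null): string {
	return tokenizer ? tokenizer.decode([id]) : String(id);
}
```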
diff --git a/examples/finetuning/src/lib/components/controls/LRSchedulePicker.svelte b/examples/finetuning/src/lib/components/controls/LRSchedulePicker.svelte new file mode 100644 index 00000000..531e616f --- /dev/null +++ b/examples/finetuning/src/lib/components/controls/LRSchedulePicker.svelte @@ -0,0 +1,517 @@ + + +
+ + + + + resetConfigToDefaults('optimizer.lrScheduler.type')} + /> + + +
+ + + + + + + {#if points.length > 0} + {@const lrs = points.map(([, y]) => y)} + {@const maxLr = Math.max(...lrs)} + {@const minLr = Math.min(...lrs)} + {#if maxLr === minLr} + + {maxLrLabel} + {:else} + + {maxLrLabel} + {minLrLabel} + {/if} + {/if} + + + steps + {#if points.length > 0} + {Math.max(...points.map(([x]) => x))} + {/if} + + + {#if svgPoints} + + {/if} + + + +
+ +
+
+ + +
+ + (stepsToShow = 1000)} + /> + + + {#if config.optimizer.lrScheduler.type === 'linear'} + resetConfigToDefaults('optimizer.lrScheduler.linearSchedule.startFactor')} + /> + resetConfigToDefaults('optimizer.lrScheduler.linearSchedule.endFactor')} + /> + resetConfigToDefaults('optimizer.lrScheduler.linearSchedule.totalIters')} + /> + {/if} + + + {#if config.optimizer.lrScheduler.type === 'constant'} + resetConfigToDefaults('optimizer.lrScheduler.constantSchedule.factor')} + /> + resetConfigToDefaults('optimizer.lrScheduler.constantSchedule.totalIters')} + /> + {/if} + + + {#if config.optimizer.lrScheduler.type === 'cosine'} + resetConfigToDefaults('optimizer.lrScheduler.cosineAnnealingSchedule.tMax')} + /> + + resetConfigToDefaults('optimizer.lrScheduler.cosineAnnealingSchedule.etaMin')} + /> + {/if} + + + {#if config.optimizer.lrScheduler.type === 'step'} + resetConfigToDefaults('optimizer.lrScheduler.stepSchedule.stepSize')} + /> + resetConfigToDefaults('optimizer.lrScheduler.stepSchedule.gamma')} + /> + {/if} + + + {#if config.optimizer.lrScheduler.type === 'exponential'} + resetConfigToDefaults('optimizer.lrScheduler.exponentialSchedule.gamma')} + /> + {/if} +
+
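The picker previews `[step, lr]` points for each scheduler type. Below is a hedged sketch of how those points could be computed, following the PyTorch schedulers these options mirror (linear, constant, cosine annealing, step, exponential); the flat config shape is a simplification of the nested `optimizer.lrScheduler.*` fields above.

```ts
type SchedulerConfig =
	| { type: 'linear'; startFactor: number; endFactor: number; totalIters: number }
	| { type: 'constant'; factor: number; totalIters: number }
	| { type: 'cosine'; tMax: number; etaMin: number }
	| { type: 'step'; stepSize: number; gamma: number }
	| { type: 'exponential'; gamma: number };

function lrAtStep(baseLr: number, cfg: SchedulerConfig, step: number): number {
	switch (cfg.type) {
		case 'linear': {
			// Interpolate the multiplier from startFactor to endFactor over totalIters.
			const t = Math.min(step, cfg.totalIters) / cfg.totalIters;
			return baseLr * (cfg.startFactor + (cfg.endFactor - cfg.startFactor) * t);
		}
		case 'constant':
			// Scale by a fixed factor until totalIters, then return to the base LR.
			return step < cfg.totalIters ? baseLr * cfg.factor : baseLr;
		case 'cosine':
			// Cosine annealing from baseLr down to etaMin over tMax steps.
			return (
				cfg.etaMin +
				((baseLr - cfg.etaMin) *
					(1 + Math.cos((Math.PI * Math.min(step, cfg.tMax)) / cfg.tMax))) /
					2
			);
		case 'step':
			// Decay by gamma every stepSize steps.
			return baseLr * Math.pow(cfg.gamma, Math.floor(step / cfg.stepSize));
		case 'exponential':
			// Decay by gamma every step.
			return baseLr * Math.pow(cfg.gamma, step);
	}
}

// Preview points for the chart over the first 1000 steps:
const points: [number, number][] = Array.from({ length: 101 }, (_, i) => {
	const step = i * 10;
	return [step, lrAtStep(3e-4, { type: 'cosine', tMax: 1000, etaMin: 1e-5 }, step)];
});
```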
diff --git a/examples/finetuning/src/lib/components/controls/RunsTable.svelte b/examples/finetuning/src/lib/components/controls/RunsTable.svelte new file mode 100644 index 00000000..73d67b06 --- /dev/null +++ b/examples/finetuning/src/lib/components/controls/RunsTable.svelte @@ -0,0 +1,52 @@ + + +
+ {#if runs.length > 1} + + {/if} + +
+
+
+ + + + + + + + + + + + {#each runs as run (run.runId)} + + + + + {/each} + +
{frozenColumn.label} Changes
{run.runId} {run.diffSummary ?? 'initial experiment'}
+
+
+
+
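The Changes column renders `run.diffSummary ?? 'initial experiment'`. A hypothetical sketch of how such a summary could be derived from two flat config snapshots; the real field is computed elsewhere in the app.

```ts
function diffSummary(
	prev: Record<string, unknown>,
	next: Record<string, unknown>
): string {
	const changes: string[] = [];
	// Compare every key that appears in either snapshot.
	for (const key of new Set([...Object.keys(prev), ...Object.keys(next)])) {
		if (JSON.stringify(prev[key]) !== JSON.stringify(next[key])) {
			changes.push(`${key}: ${JSON.stringify(prev[key])} -> ${JSON.stringify(next[key])}`);
		}
	}
	return changes.length > 0 ? changes.join(', ') : 'initial experiment';
}

diffSummary({ lr: 0.0003 }, { lr: 0.001 }); // "lr: 0.0003 -> 0.001"
```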
diff --git a/examples/finetuning/src/lib/components/controls/SelectDataset.svelte b/examples/finetuning/src/lib/components/controls/SelectDataset.svelte new file mode 100644 index 00000000..112eec09 --- /dev/null +++ b/examples/finetuning/src/lib/components/controls/SelectDataset.svelte @@ -0,0 +1,91 @@ + + +{#snippet nameAndBadges(_opt: DatasetOption)} + {@const opt = _opt as DatasetOption} + {@const meta = getMeta(opt.value as DatasetName)} + {@const name = opt.text ?? meta.name} + {@const citations = + 'citations' in meta + ? ((meta as Record).citations as CitationsType) + : undefined} +
+ {name} + {#if citations} + + {/if} +
+{/snippet} + +
+ + {#snippet option(_opt, _selected)} + {@const opt = _opt as DatasetOption} + {@render nameAndBadges(opt)} + {/snippet} + {#snippet trigger(_selected)} + {#if _selected} + {@const opt = _selected as DatasetOption} + {opt.text ?? opt.value} + {:else} + Select... + {/if} + {/snippet} + + {#if 'citations' in selectedMeta} +
+ +
+ {/if} +
diff --git a/examples/finetuning/src/lib/components/controls/select/SelectWithCitations.svelte b/examples/finetuning/src/lib/components/controls/select/SelectWithCitations.svelte new file mode 100644 index 00000000..515e5fa3 --- /dev/null +++ b/examples/finetuning/src/lib/components/controls/select/SelectWithCitations.svelte @@ -0,0 +1,66 @@ + + +{#snippet citationView(_opt: CitationOption)} + {@const label = _opt.title} + {@const citations = _opt.citations} +
+
{label}
+ {#if citations} + + {/if} +
+{/snippet} + +
+ + {#snippet option(_opt, _selected)} + {@const opt = _opt as CitationOption} + {@render citationView(opt)} + {/snippet} + {#snippet trigger(_selected)} + {#if _selected} + {@const opt = _selected as CitationOption} + {opt.title} + {:else} + Select... + {/if} + {/snippet} + + {#if selectedOption?.citations} +
+ {/if} +
diff --git a/examples/finetuning/src/lib/components/metrics/MetricsSection.svelte b/examples/finetuning/src/lib/components/metrics/MetricsSection.svelte new file mode 100644 index 00000000..ca908275 --- /dev/null +++ b/examples/finetuning/src/lib/components/metrics/MetricsSection.svelte @@ -0,0 +1,65 @@ + + +
+
+
{ + if (e.key === 'Enter' || e.key === ' ') { + e.preventDefault(); + handleClick(); + } + }} + > + +

{title}

+
+ {@render chips?.()} +
+ + {#if isOpen} +
+ {@render children?.()} +
+ {/if} +
diff --git a/examples/finetuning/src/lib/components/metrics/RunChart.svelte b/examples/finetuning/src/lib/components/metrics/RunChart.svelte new file mode 100644 index 00000000..8ef57a5e --- /dev/null +++ b/examples/finetuning/src/lib/components/metrics/RunChart.svelte @@ -0,0 +1,323 @@ + + +
+
chartOptions, + setup: (chart) => { + let isRelaying = false; + + const onUpdateAxisPointer = (evt: unknown) => { + if (isRelaying) return; + const s = extractStepFromAxisPointerEvent(evt); + if (s == null) return; + // Choose topmost series that contains this exact step + for (let idx = seriesIndexToRunId.length - 1; idx >= 0; idx--) { + const runId = seriesIndexToRunId[idx]; + if (!runId) continue; + const seriesInfo = runIdToSeries.get(runId); + if (!seriesInfo) continue; + if (exactIndex(seriesInfo.steps, s) !== -1) { + publishMove({ sourceId: chartId, runId, step: s }); + break; + } + } + }; + + const onGlobalOut = () => { + publishClear({ sourceId: chartId }); + }; + + chart.on('updateAxisPointer', onUpdateAxisPointer); + chart.on('globalout', onGlobalOut); + + const unsubscribe = subscribe( + ({ sourceId, runId, step }) => { + if (sourceId === chartId) return; + const info = runIdToSeries.get(runId); + if (!info) return; + const { seriesIndex, steps, first, last } = info; + if (step < first || step > last) return; + const dataIndex = nearestIndex(steps, step); + if (dataIndex < 0) return; + isRelaying = true; + try { + // ECharts accepts a second options argument with { silent: true } + chart.dispatchAction( + { type: 'updateAxisPointer', seriesIndex, dataIndex }, + { silent: true } + ); + } finally { + isRelaying = false; + } + }, + ({ sourceId }) => { + if (sourceId === chartId) return; + chart.dispatchAction({ type: 'hideTip' }, { silent: true }); + } + ); + + return () => { + unsubscribe(); + chart.off('updateAxisPointer', onUpdateAxisPointer); + chart.off('globalout', onGlobalOut); + }; + } + })} + >
+
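A quick check of the binary-search helpers RunChart uses to map a hovered step onto each peer chart's own, possibly sparser, logged steps; the step values here are made up.

```ts
import { exactIndex, nearestIndex } from '$lib/attachments/echarts.svelte';

const steps = [0, 10, 20, 40, 80]; // one run's sorted log steps

exactIndex(steps, 20); //  2: exact hit
exactIndex(steps, 25); // -1: no exact hit
nearestIndex(steps, 25); //  2: 20 is closer than 40
nearestIndex(steps, 100); // -1: outside the logged range, so no snap
```

This is why a chart only relays a hover to peers that actually logged a nearby step: `exactIndex` picks the topmost series containing the hovered step, and `nearestIndex` lets each peer snap to its closest logged point.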
diff --git a/examples/finetuning/src/lib/components/metrics/validationCompletions/CompletionsToken.svelte b/examples/finetuning/src/lib/components/metrics/validationCompletions/CompletionsToken.svelte new file mode 100644 index 00000000..8613d92a --- /dev/null +++ b/examples/finetuning/src/lib/components/metrics/validationCompletions/CompletionsToken.svelte @@ -0,0 +1,124 @@ + + + { + if (e.key === 'Enter' || e.key === ' ') { + e.preventDefault(); + handleClick(e); + } + }} +> + {#if targetText != null} + + + {#if actualText} + + {visualizeToken(actualText)} + + {/if} + + {visualizeToken(targetText)} + + + {:else} + + {/if} + diff --git a/examples/finetuning/src/lib/components/metrics/validationCompletions/ValidationCompletionsViewer.svelte b/examples/finetuning/src/lib/components/metrics/validationCompletions/ValidationCompletionsViewer.svelte new file mode 100644 index 00000000..4a534700 --- /dev/null +++ b/examples/finetuning/src/lib/components/metrics/validationCompletions/ValidationCompletionsViewer.svelte @@ -0,0 +1,824 @@ + + +
+
+ + validation/completions + {#if completionsData} + + of {completionsData.targetStep.completions.length} + {/if} + + {#if completionsData} +

+ Step {completionsData.stepNumber} • Temp: {completionsData.targetStep.samplingParams + .temperature} +

+ {/if} +
+ + {#if !completionsData} +
+

No validation data available. Validation will run during training.

+
+ {:else} + {@const visibleCompletionsNumberWidth = visibleCompletions.length.toString().length} + {@const focus = hoveredFocus} +
+
+ {#await tokenizer then tokenizer} + {#each visibleCompletions as completion, index (index)} + {@const genIds = completion.tokenIds} + {@const hasTargets = Boolean(completionsData.targets)} + {@const tgtIds = hasTargets ? completionsData.targets?.[index] || [] : []} + {@const tokenComparisons = hasTargets ? compareTokenIds(genIds, tgtIds, tokenizer) : []} + {@const matchesRow = completionsData.targetStep.matches?.[index] || []} + {@const prefixLen = completionsData.decoderPromptLengths?.[index] ?? 1} + +
+ (hoveredFocus = { + exampleIndex: index, + tokenIndex: hoveredFocus?.tokenIndex ?? 0 + })} + onmouseleave={() => (hoveredFocus = null)} + > +
+ {(index + 1).toString().padStart(visibleCompletionsNumberWidth, '\u00A0')} +
+ + +
+ {#if hasTargets} + {#each tokenComparisons as item, tIndex (tIndex)} + {@const isPromptItem = item.kind === 'prompt'} + {@const completedBefore = tokenComparisons + .slice(0, tIndex) + .filter((it) => it.kind !== 'prompt').length} + {@const sequenceTokenIndex = isPromptItem + ? prefixLen + : prefixLen + Math.max(0, completedBefore)} + {@const isHighlighted = + !isPromptItem && + focus?.exampleIndex === index && + focus?.tokenIndex === sequenceTokenIndex} + {#if item.kind === 'prompt'} + {}} + onLeave={() => {}} + onSelect={() => {}} + /> + {:else if item.isCorrect} + (hoveredFocus = { exampleIndex: ei, tokenIndex: ti })} + onLeave={() => (hoveredFocus = null)} + /> + {:else} + (hoveredFocus = { exampleIndex: ei, tokenIndex: ti })} + onLeave={() => (hoveredFocus = null)} + /> + {/if} + {/each} + {:else if genIds.length === 0} + [empty] + {:else} + {#each genIds as id, tIndex (tIndex)} + {@const text = decodeSingle(id, tokenizer)} + {@const isPrompt = tIndex < prefixLen} + {@const match = matchesRow[tIndex - prefixLen]} + {@const variant = isPrompt + ? 'prompt' + : !hasMatchData || matchesRow.length === 0 + ? 'generated' + : match === true + ? 'correct' + : match === false + ? 'incorrect' + : 'neutral'} + {@const isHighlighted = + !isPrompt && focus?.exampleIndex === index && focus?.tokenIndex === tIndex} + (hoveredFocus = { exampleIndex: ei, tokenIndex: ti })} + onLeave={() => (hoveredFocus = null)} + /> + {/each} + {/if} +
+
+ {/each} + {/await} +
+ + + {#if chartOptions} + +
+
+ {#if selectedProbsStep} +
+ (selectedProbsStep = null)} + /> +
+ {/if} +
{ + if (!chartOptions) return null; + const baseBar = chartOptions.bar; + const tbc = chartOptions.tokensByCol ?? []; + const last = Math.max(0, tbc.length - 1); + const col = + activeStepCol == null + ? last + : Math.max(0, Math.min(last, Math.floor(activeStepCol))); + const defaultCategories = chartOptions.tokenCategoriesByCol[col]; + const defaultData = tbc[col]; + const categories = selectedProbsStep + ? selectedProbsStep.categories + : defaultCategories; + const data = selectedProbsStep ? selectedProbsStep.data : defaultData; + return { + ...baseBar, + title: selectedProbsStep + ? { ...baseBar.title, text: `probs (step ${selectedProbsStep.step})` } + : baseBar?.title, + xAxis: [ + { + type: 'category', + data: categories, + axisPointer: { show: true, type: 'shadow' }, + triggerEvent: true, + axisTick: { show: false }, + axisLabel: { show: false }, + axisLine: { show: false } + } + ], + series: [{ id: 'token-bars', type: 'bar', data }] + }; + } + })} + >
+
+
+
+ {/if} +
+ {/if} +
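The viewer iterates over `compareTokenIds(genIds, tgtIds, tokenizer)` items that are either prompt tokens or scored generations. The helper's script section was not captured above; this is a hypothetical sketch matching the shape the template consumes (`kind`, `isCorrect`), not the actual implementation.

```ts
interface Tokenizer {
	decode(ids: number[]): string;
}

type ComparisonItem =
	| { kind: 'prompt'; text: string }
	| { kind: 'generated'; text: string; targetText: string; isCorrect: boolean };

function compareTokenIds(
	generated: number[],
	targets: number[],
	tokenizer: Tokenizer,
	promptLen = 1 // assumed default; the template tracks prefix lengths separately
): ComparisonItem[] {
	return generated.map((id, i) => {
		const text = tokenizer.decode([id]);
		// Tokens inside the prompt prefix are rendered but never scored.
		if (i < promptLen) return { kind: 'prompt', text };
		const targetId = targets[i];
		return {
			kind: 'generated',
			text,
			targetText: targetId === undefined ? '' : tokenizer.decode([targetId]),
			isCorrect: targetId === id
		};
	});
}
```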
diff --git a/examples/finetuning/src/lib/dataUtils.ts b/examples/finetuning/src/lib/dataUtils.ts new file mode 100644 index 00000000..23567569 --- /dev/null +++ b/examples/finetuning/src/lib/dataUtils.ts @@ -0,0 +1,31 @@ +export function openDb( + dbName: string, + dbVersion: number, + onupgradeneeded: (db: IDBDatabase) => void +): Promise { + return new Promise((resolve, reject) => { + const req = indexedDB.open(dbName, dbVersion); + req.onupgradeneeded = () => onupgradeneeded(req.result); + req.onsuccess = () => resolve(req.result); + req.onerror = () => reject(req.error); + req.onblocked = () => reject(new Error('IndexedDB upgrade blocked')); + }); +} + +export function promisify(req: IDBRequest): Promise { + return new Promise((resolve, reject) => { + req.onsuccess = () => resolve(req.result); + req.onerror = () => reject(req.error); + }); +} + +export function txRequest( + db: IDBDatabase, + storeName: string, + mode: IDBTransactionMode, + op: (store: IDBObjectStore) => IDBRequest +): Promise { + const tx = db.transaction(storeName, mode); + const store = tx.objectStore(storeName); + return promisify(op(store)); +} diff --git a/examples/finetuning/src/lib/train/generate.ts b/examples/finetuning/src/lib/train/generate.ts new file mode 100644 index 00000000..089c50fe --- /dev/null +++ b/examples/finetuning/src/lib/train/generate.ts @@ -0,0 +1,263 @@ +/** + * @fileoverview Shared text generation utilities for different model types + */ + +import type { Device, Tensor } from '@piston-ml/piston-web'; + +import { int32, tensor, WeakTensorFunctionMode } from '@piston-ml/piston-web'; + +import type { GeneratableModel } from './types'; + +import { createEmptyDecoderKVCache, type DecoderKVCache } from './model/cache'; +import { GPT } from './model/gpt'; + +export interface GenerationConfig { + maxTokens?: number; + stopTokens?: number | number[]; + device?: string | Device; + startToken?: number; + maxTargetLength?: number; + temperature?: number; + useKvCache?: boolean; +} + +export interface GenerationResult { + sequences: number[][]; + // Raw probabilities (post-softmax) for generated tokens + probs?: Tensor; + // Running average throughput since start of generation (tokens/second) + tokensPerSecond?: number; +} + +function normalizeStopTokens(stopTokens: number | number[] = []): Set { + return new Set(Array.isArray(stopTokens) ? stopTokens : [stopTokens]); +} + +/** + * Create tensor from token sequences with proper device and dtype + */ +function createInputTensor(sequences: number[][], device: Device): Tensor { + return tensor(sequences, { device, dtype: int32 }); +} + +/** + * Convert tensor output to array of token IDs + */ +async function tensorToTokens(tokenTensor: Tensor): Promise { + return (await (await tokenTensor.to('cpu')).toVec()) as Int32Array; +} + +/** + * Check if any sequence should continue generating (hasn't hit stop tokens) + */ +function shouldContinueGeneration( + results: number[][], + newTokens: Int32Array, + stopTokenSet: Set +): boolean { + let shouldContinue = false; + for (let i = 0; i < results.length; i++) { + // If this sequence already ended with a stop token, do not append anything further + const sequence = results[i]; + const lastToken = sequence.length > 0 ? 
sequence[sequence.length - 1] : undefined; + const alreadyStopped = lastToken !== undefined && stopTokenSet.has(lastToken); + + if (alreadyStopped) { + continue; + } + + const token = newTokens[i]; + // Always append the newly generated token + sequence.push(token); + // Continue only if we did not just append a stop token + if (!stopTokenSet.has(token)) { + shouldContinue = true; + } + } + return shouldContinue; +} + +/** + * Create a simple running tokens/sec tracker using the Performance API + */ +function startTokensPerSecondTracker(): (tokensProducedThisStep: number) => number { + const now = + typeof performance !== 'undefined' && typeof performance.now === 'function' + ? () => performance.now() + : () => Date.now(); + const startMs = now(); + let totalTokens = 0; + return (tokensProducedThisStep: number) => { + if (tokensProducedThisStep > 0) totalTokens += tokensProducedThisStep; + const elapsedMs = Math.max(1, now() - startMs); + return totalTokens / (elapsedMs / 1000); + }; +} + +/** + * Generate tokens for GPT (decoder-only) model + */ +export async function* generateGPTStream( + model: GPT, + input: number[] | number[][], + config: GenerationConfig = {} +): AsyncGenerator { + const stopTokenSet = normalizeStopTokens(config.stopTokens); + const isBatch = Array.isArray(input[0]); + const sequences = isBatch ? (input as number[][]) : [input as number[]]; + + // Initialize results with copies of input sequences + const results = sequences.map((seq) => [...seq]); + const device = model.lm_head.weight.device; + + const kvCache: DecoderKVCache | null = + (config.useKvCache ?? true) ? createEmptyDecoderKVCache(model.config.nLayer) : null; + + let seeded = false; + let step = 0; + const getTokensPerSecond = startTokensPerSecondTracker(); + while (true) { + const weakMode = new WeakTensorFunctionMode(); + try { + weakMode.markWeak(kvCache); + // Seed cache once with full sequences, then switch to single-token steps + const prevLengths = results.map((seq) => seq.length); + let inputTensor: Tensor; + if (kvCache && seeded) { + const latestTokens = results.map((seq) => [seq[seq.length - 1] ?? 0]); + inputTensor = createInputTensor(latestTokens, device); + } else { + const { padded } = rightPadSequences(results); + inputTensor = createInputTensor(padded, device); + } + + // Forward pass to get logits + const [logits, _] = model.forward(inputTensor, { kvCache }); + weakMode.pin(kvCache); + + // Get logits for the last token in each sequence + const [batchSize, seqLen, vocabSize] = logits.size(); + let lastTokenLogits = logits + .slice([ + [0, batchSize], + [seqLen - 1, seqLen], + [0, vocabSize] + ]) + .view([batchSize, vocabSize]); + + if (config.temperature && config.temperature > 0) { + lastTokenLogits = lastTokenLogits.div(config.temperature); + } + + lastTokenLogits = lastTokenLogits.softmax(-1); + + // Choose next tokens: sample when temperature > 0, else greedy argmax + const nextTokenTensor = + config.temperature && config.temperature > 0 + ? 
lastTokenLogits.multinomial(1, { replacement: false }) + : lastTokenLogits.argmax({ dim: -1 }); + const nextTokensArray = await tensorToTokens(nextTokenTensor); + + // Update sequences and check for stop conditions + const shouldContinue = shouldContinueGeneration(results, nextTokensArray, stopTokenSet); + // Compute tokens appended this step across active sequences + let appendedThisStep = 0; + for (let i = 0; i < results.length; i++) { + appendedThisStep += results[i].length - prevLengths[i]; + } + const tokensPerSecond = getTokensPerSecond(appendedThisStep); + + weakMode.pin([results, lastTokenLogits]); + // Yield current state with sequences and logits + yield { + sequences: isBatch ? results.map((seq) => [...seq]) : [results[0].slice()], + probs: lastTokenLogits, // Provide the softmax'd logits for the last token + tokensPerSecond + }; + + // Mark cache as seeded after first forward with cache + if (kvCache && !seeded) seeded = true; + + // If all sequences hit stop tokens, break + if (!shouldContinue) { + break; + } + + step++; + + if (config.maxTokens !== undefined && step >= config.maxTokens) { + break; + } + } finally { + weakMode[Symbol.dispose](); + } + } +} + +/** + * Standard generate function that collects all tokens (backward compatible) + */ +export async function generate( + model: GeneratableModel, + input: number[] | number[][], + config: GenerationConfig = {} +): Promise { + const results: number[][] = []; + let tokenCount = 0; + const maxTokens = config.maxTokens ?? 50; + + for await (const generationResult of generateGPTStream(model, input, config)) { + results.length = 0; + results.push(...generationResult.sequences); + tokenCount++; + + if (tokenCount >= maxTokens) { + break; + } + } + + return results; +} + +/** + * Right-pad sequences to a uniform length for batched forward passes. + * Returns padded sequences and a padding mask (1 for real tokens, 0 for padding). + * Note: The original sequences array is NOT modified. + */ +export function rightPadSequences( + sequences: number[][], + padToken?: number +): { padded: number[][]; paddingMask: number[][] } { + const maxLen = sequences.reduce((m, s) => Math.max(m, s.length), 0); + const padded: number[][] = new Array(sequences.length); + const paddingMask: number[][] = new Array(sequences.length); + + for (let i = 0; i < sequences.length; i++) { + const seq = sequences[i]; + const realLen = seq.length; + const rowMask: number[] = new Array(maxLen).fill(0); + for (let j = 0; j < realLen; j++) rowMask[j] = 1; + + if (realLen === maxLen) { + padded[i] = [...seq]; + paddingMask[i] = rowMask; + continue; + } + + let padVal: number; + if (padToken !== undefined) { + padVal = padToken; + } else if (realLen > 0) { + padVal = seq[realLen - 1]; + } else { + padVal = 0; // degenerate case + } + + const numPad = maxLen - realLen; + padded[i] = + realLen > 0 ? 
[...seq, ...new Array(numPad).fill(padVal)] : new Array(maxLen).fill(padVal); + paddingMask[i] = rowMask; + } + + return { padded, paddingMask }; +} diff --git a/examples/finetuning/src/lib/train/model/cache.ts b/examples/finetuning/src/lib/train/model/cache.ts new file mode 100644 index 00000000..7dd33524 --- /dev/null +++ b/examples/finetuning/src/lib/train/model/cache.ts @@ -0,0 +1,20 @@ +import type { Tensor } from '@piston-ml/piston-web'; + +export type SelfAttentionCache = { + k: Tensor; + v: Tensor; + length: number; +}; + +export type DecoderLayerCache = { + self?: SelfAttentionCache; + cross?: SelfAttentionCache; +}; + +export type DecoderKVCache = { + layers: DecoderLayerCache[]; +}; + +export function createEmptyDecoderKVCache(nLayers: number): DecoderKVCache { + return { layers: Array.from({ length: nLayers }, () => ({})) }; +} diff --git a/examples/finetuning/src/lib/train/model/config.ts b/examples/finetuning/src/lib/train/model/config.ts new file mode 100644 index 00000000..1fd79161 --- /dev/null +++ b/examples/finetuning/src/lib/train/model/config.ts @@ -0,0 +1,49 @@ +import type { GPT2ModelType } from '$lib/workspace/config'; + +import { GPT2_BLOCK_SIZE, GPT2_VOCAB_SIZE } from './gpt'; + +export interface GPT2Config { + modelType: GPT2ModelType; + nLayer: number; + nHead: number; + nEmbd: number; + vocabSize: number; + blockSize: number; + embdPdrop: number; + residPdrop: number; + attnPdrop: number; +} + +function getModelParametersFromType(modelType: GPT2ModelType): { + nLayer: number; + nHead: number; + nEmbd: number; +} { + switch (modelType) { + case 'distilgpt2': + return { nLayer: 6, nHead: 12, nEmbd: 768 }; + case 'gpt2': + return { nLayer: 12, nHead: 12, nEmbd: 768 }; + case 'gpt2-medium': + return { nLayer: 24, nHead: 16, nEmbd: 1024 }; + case 'gpt2-large': + return { nLayer: 36, nHead: 20, nEmbd: 1280 }; + case 'gpt2-xl': + return { nLayer: 48, nHead: 25, nEmbd: 1600 }; + } +} + +export function buildGPT2Config(modelType: GPT2ModelType): GPT2Config { + const { nLayer, nHead, nEmbd } = getModelParametersFromType(modelType); + return { + vocabSize: GPT2_VOCAB_SIZE, + blockSize: GPT2_BLOCK_SIZE, + modelType, + nLayer, + nHead, + nEmbd, + embdPdrop: 0.1, + residPdrop: 0.1, + attnPdrop: 0.1 + }; +} diff --git a/examples/finetuning/src/lib/train/model/gpt.ts b/examples/finetuning/src/lib/train/model/gpt.ts new file mode 100644 index 00000000..971f5188 --- /dev/null +++ b/examples/finetuning/src/lib/train/model/gpt.ts @@ -0,0 +1,392 @@ +/** + * @fileoverview Implementation of generic encoder-decoder, encoder-only, and decoder-only + * transformer models + */ + +import type { Config } from '$lib/workspace/config'; + +import { cat, CrossEntropyLoss, nn, Parameter, type Tensor } from '@piston-ml/piston-web'; + +import type { DecoderLayerCache, SelfAttentionCache } from './cache'; +import type { DecoderKVCache } from './cache'; + +import { createEmptyDecoderKVCache } from './cache'; +import { type GPT2Config } from './config'; +import { createCausalMask, createPositionIds, maskedFill } from './utils'; + +export const GPT2_VOCAB_SIZE = 1024; +export const GPT2_BLOCK_SIZE = 1024; + +export class CausalSelfAttention extends nn.Module { + private readonly nHeads: number; + private readonly nKvHeads: number; + private readonly embeddingSize: number; + private readonly headDim: number; + private readonly config: GPT2Config; + + private readonly c_attn: nn.Linear; + private readonly c_proj: nn.Linear; + + private attn_dropout?: nn.Dropout; + private resid_dropout?: nn.Dropout; + 
+ constructor(config: GPT2Config) { + super(); + + this.nHeads = config.nHead; + this.nKvHeads = config.nHead; + this.embeddingSize = config.nEmbd; + this.headDim = config.nEmbd / this.nHeads; + this.config = config; + + const kvDim = this.headDim * this.nKvHeads; + const qkvOutDim = this.embeddingSize + 2 * kvDim; + + this.c_attn = new nn.Linear(this.embeddingSize, qkvOutDim); + this.c_proj = new nn.Linear(this.embeddingSize, this.embeddingSize); + + this.attn_dropout = new nn.Dropout(config.attnPdrop); + this.resid_dropout = new nn.Dropout(config.residPdrop); + } + + private projectQkv(input: Tensor): [Tensor, Tensor, Tensor] { + const [B, T, _] = input.size(); + const kvDim = this.headDim * this.nKvHeads; + const keyPos = this.embeddingSize; + const valuePos = this.embeddingSize + kvDim; + const qkv = this.c_attn.forward(input); + let q = qkv.slice([ + [0, B], + [0, T], + [0, this.embeddingSize] + ]); + let k = qkv.slice([ + [0, B], + [0, T], + [keyPos, keyPos + kvDim] + ]); + let v = qkv.slice([ + [0, B], + [0, T], + [valuePos, valuePos + kvDim] + ]); + const qShape = [B, T, this.nHeads, this.headDim]; + const kvShape = [B, T, this.nKvHeads, this.headDim]; + q = q.view(qShape)?.transpose(1, 2); + k = k.view(kvShape)?.transpose(1, 2); + v = v.view(kvShape)?.transpose(1, 2); + [k, v] = this.applyGroupedQueryBroadcast(k, v, B, T); + return [q, k, v]; + } + + private runAttention( + q: Tensor, + k: Tensor, + v: Tensor, + options: { + attentionMask?: Tensor | null; + cache?: SelfAttentionCache | null; + ropeOffsets?: { qOffset: number; kOffset: number } | null; + } + ): [Tensor, SelfAttentionCache | null] { + const B = q.size(0); + const T_q = q.size(2); + + // Concatenate with cache if present + const kCat = options.cache ? cat([options.cache.k, k], { dim: 2 }) : k; + const vCat = options.cache ? cat([options.cache.v, v], { dim: 2 }) : v; + const T_kv = kCat.size(2); + + // Compute attention scores + let att = q.matmul(kCat, { transRhs: true }).div(Math.sqrt(this.headDim)); + + att = this.applyCausalMask(att, T_q, T_kv); + + if (options.attentionMask) { + att = this.applyAttentionMask(att, options.attentionMask); + } + + // Apply softmax over keys plus optional sink column + att = att.softmax(3); + att = this.attn_dropout ? this.attn_dropout.forward(att) : att; + + let y = att.matmul(vCat).transpose(1, 2).view([B, T_q, this.embeddingSize]); + // Apply output projection + y = this.c_proj.forward(y); + + if (this.resid_dropout) { + y = this.resid_dropout.forward(y); + } + + const newCache: SelfAttentionCache | null = { + k: kCat, + v: vCat, + length: T_kv + }; + return [y, newCache]; + } + + private applyAttentionMask(att: Tensor, attentionMask: Tensor) { + // Convert attention mask to appropriate shape [B, 1, 1, T_kv] and broadcast + const mask = attentionMask.unsqueeze(1).unsqueeze(2).broadcastTo(att.size()); + return maskedFill(att, mask, -1e9); + } + + private applyCausalMask(att: Tensor, queryLen: number, keyLen: number) { + // Apply causal mask if needed + return queryLen <= 1 + ? 
att + : (() => { + const mask = createCausalMask(queryLen, keyLen).broadcastTo(att.size()); + return maskedFill(att, mask, -1e9); + })(); + } + + /** + * Apply grouped-query attention broadcasting to key and value tensors + * @param k Key tensor [B, nKvHeads, seqLen, headDim] + * @param v Value tensor [B, nKvHeads, seqLen, headDim] + * @param B Batch size + * @param seqLen Sequence length + * @returns Broadcasted [k, v] tensors [B, nHeads, seqLen, headDim] + */ + private applyGroupedQueryBroadcast( + k: Tensor, + v: Tensor, + B: number, + seqLen: number + ): [Tensor, Tensor] { + if (this.nHeads !== this.nKvHeads) { + const repeatFactor = this.nHeads / this.nKvHeads; + const th = seqLen * this.headDim; + + k = k + .view([B, this.nKvHeads, th]) + .unsqueeze(2) + .broadcastTo([B, this.nKvHeads, repeatFactor, th]) + .view([B, this.nHeads, seqLen, this.headDim]); + v = v + .view([B, this.nKvHeads, th]) + .unsqueeze(2) + .broadcastTo([B, this.nKvHeads, repeatFactor, th]) + .view([B, this.nHeads, seqLen, this.headDim]); + } + return [k, v]; + } + + forward( + input: Tensor, + options: { attentionMask?: Tensor | null; cache?: SelfAttentionCache | null } = {} + ): { output: Tensor; pastKeyValues?: SelfAttentionCache } { + const qkv = this.projectQkv(input); + const [q, k, v] = qkv; + const pastLen = options.cache?.length ?? 0; + const [y, newCache] = this.runAttention(q, k, v, { + attentionMask: options.attentionMask ?? null, + cache: options.cache ?? null, + ropeOffsets: { qOffset: pastLen, kOffset: pastLen } + }); + return { output: y, pastKeyValues: newCache ?? undefined }; + } +} + +type GPTDict = { + drop: nn.Dropout; + wte: nn.Embedding; + wpe: nn.Embedding; + h: nn.ModuleList; + ln_f: nn.Module<[Tensor], Tensor>; +}; + +export class GPT extends nn.Module { + public config: GPT2Config; + public transformer: nn.ModuleDict; + readonly lm_head: nn.Linear; + private readonly criterion: CrossEntropyLoss; + + constructor(modelConfig: GPT2Config, config: Config) { + super(); + + this.config = modelConfig; + const transformerDict: GPTDict = { + drop: new nn.Dropout(this.config.embdPdrop), + wte: new nn.Embedding(this.config.vocabSize, this.config.nEmbd), + wpe: new nn.Embedding(this.config.blockSize, this.config.nEmbd), + h: new nn.ModuleList( + Array.from({ length: this.config.nLayer }).map(() => new Block(this.config)) + ), + ln_f: new nn.LayerNorm(this.config.nEmbd) + }; + + this.transformer = new nn.ModuleDict(transformerDict); + + // Output projection with optional weight tying to token embeddings + this.lm_head = new nn.Linear(this.config.nEmbd, this.config.vocabSize, false); + this.lm_head.weight = new Parameter(this.transformer.dict.wte.weight); + + this.criterion = new CrossEntropyLoss({ + labelSmoothing: config.training.labelSmoothing.present + ? config.training.labelSmoothing.value + : 0.0, + ignoreIndex: -100 + }); + } + + /** + * @param input - Input tensor of token IDs [batch_size, seq_len] + * @param targets - Target tensor of token IDs [batch_size, + * seq_len] + * @returns [logits, loss] + */ + forward( + input: Tensor, + options: { targets?: Tensor | null; kvCache?: DecoderKVCache | null } = {} + ): [Tensor, Tensor | null] { + const targets = options.targets ?? null; + const kvCache = options.kvCache ?? 
null; + const [batchSize, seqLen] = input.size(); + + if (!seqLen) { + throw new Error( + 'Input tensor has no sequence length (did you forget to pass input as batches?)' + ); + } + + // Get token embeddings + let wordEmbeddings = this.transformer.dict.wte.forward(input); + + // Use cache length (if any) as position offset for absolute encodings during incremental decoding + const posOffset = kvCache?.layers?.[0]?.self?.length ?? 0; + + // Add positional embeddings + const positions = createPositionIds(seqLen, batchSize, wordEmbeddings.device, posOffset); + const positionEmbeddingsOutput = this.transformer.dict.wpe.forward(positions); + wordEmbeddings = wordEmbeddings.add(positionEmbeddingsOutput); + // Apply embedding dropout if configured + wordEmbeddings = this.transformer.dict.drop.forward(wordEmbeddings); + + // Pass through each transformer layer + let hiddenStates = wordEmbeddings; + + const useCache = kvCache !== null; + const cacheObj = useCache ? kvCache! : createEmptyDecoderKVCache(this.config.nLayer); + for (let i = 0; i < this.config.nLayer; i++) { + const layerModule = this.transformer.dict.h[i] as Block; + const result = layerModule.forward(hiddenStates, { cache: cacheObj.layers[i] }); + if (useCache) { + cacheObj.layers[i] = result.pastKeyValues!; + hiddenStates = result.output; + } else { + hiddenStates = result.output; + } + } + + // Apply final layer normalization + if (this.transformer.dict.ln_f) { + hiddenStates = this.transformer.dict.ln_f.forward(hiddenStates); + } + + // Project to vocabulary + const logits = this.lm_head.forward(hiddenStates); + + const loss = targets + ? this.criterion.forward(logits.view([-1, logits.size(-1)]), targets.view(-1)) + : null; + + return [logits, loss]; + } +} + +export type DecoderLayerForwardOptions = { + encoderHiddenStates?: Tensor | null; + srcPaddingMask?: Tensor | null; + tgtPaddingMask?: Tensor | null; + cache?: DecoderLayerCache | null; +}; + +export type DecoderLayerForwardResult = { + output: Tensor; + pastKeyValues?: DecoderLayerCache; +}; + +export class Block extends nn.Module { + private readonly lnSelfAttn!: nn.Module<[Tensor], Tensor>; + private readonly lnMlp: nn.Module<[Tensor], Tensor>; + private readonly selfAttn: CausalSelfAttention; + private readonly mlp: MLP; + private readonly dropout?: nn.Dropout; + + constructor(config: GPT2Config) { + super(); + + this.lnSelfAttn = new nn.LayerNorm(config.nEmbd); + this.selfAttn = new CausalSelfAttention(config); + + this.lnMlp = new nn.LayerNorm(config.nEmbd); + this.mlp = new MLP(config.nEmbd); + + if (config.residPdrop > 0) { + this.dropout = new nn.Dropout(config.residPdrop); + } + } + + forward(input: Tensor, options: DecoderLayerForwardOptions = {}): DecoderLayerForwardResult { + const tgtPaddingMask = options.tgtPaddingMask ?? null; + const cache = options.cache ?? null; + let x = input; + let selfCache = cache?.self ?? null; + + const residual = input; + x = this.lnSelfAttn.forward(input); + const selfResult = this.selfAttn.forward(x, { + attentionMask: tgtPaddingMask ?? null, + cache: selfCache ?? null + }); + selfCache = selfResult.pastKeyValues ?? null; + x = residual.add(selfResult.output); + + const residual3 = x; + x = this.lnMlp.forward(x); + x = this.mlp.forward(x); + if (this.dropout) { + x = this.dropout.forward(x); + } + x = residual3.add(x); + + const result: DecoderLayerForwardResult = { output: x }; + if (cache) { + result.pastKeyValues = { self: selfCache ?? undefined, cross: cache.cross ?? 
undefined };
+    }
+    return result;
+  }
+}
+
+export class MLP extends nn.Module {
+  private readonly upProj: nn.Linear;
+  private readonly downProj: nn.Linear;
+  private readonly activation: (x: Tensor) => Tensor;
+
+  /**
+   * @param embeddingSize - Embedding size
+   */
+  constructor(embeddingSize: number) {
+    super();
+
+    const intermediateSize = 4 * embeddingSize;
+    this.upProj = new nn.Linear(embeddingSize, intermediateSize);
+    this.downProj = new nn.Linear(intermediateSize, embeddingSize);
+
+    this.activation = (x: Tensor): Tensor => x.gelu();
+  }
+
+  /**
+   * Forward pass through the MLP
+   * @param input - Input tensor
+   * @returns Output tensor
+   */
+  forward(input: Tensor): Tensor {
+    let h = this.upProj.forward(input);
+    h = this.activation(h);
+    return this.downProj.forward(h);
+  }
+}
diff --git a/examples/finetuning/src/lib/train/model/utils.ts b/examples/finetuning/src/lib/train/model/utils.ts
new file mode 100644
index 00000000..2179a192
--- /dev/null
+++ b/examples/finetuning/src/lib/train/model/utils.ts
@@ -0,0 +1,53 @@
+import { arange, Device, gpu, int32, type Tensor } from '@piston-ml/piston-web';
+
+/**
+ * Create a causal (lower triangular) mask.
+ * @param queryLen - Length of the current query.
+ * @param keyLen - Length of the key (which may include cached tokens).
+ * @returns Boolean mask tensor of shape [queryLen, keyLen]; true marks key positions a query may attend to.
+ */
+export function createCausalMask(queryLen: number, keyLen: number): Tensor {
+  // General causal mask supporting past KV cache where keyLen may exceed queryLen.
+  // We want to mask future positions: for each query i, keys j > pastLen + i are masked.
+  // pastLen is inferred as keyLen - queryLen when using KV cache (else 0).
+  const pastLen = Math.max(0, keyLen - queryLen);
+  const i = arange({ end: queryLen, device: gpu, dtype: int32 })
+    .unsqueeze(1)
+    .broadcastTo([queryLen, keyLen]);
+  const j = arange({ end: keyLen, device: gpu, dtype: int32 })
+    .unsqueeze(0)
+    .broadcastTo([queryLen, keyLen]);
+  // Mask is true where positions are allowed: j <= pastLen + i
+  return j.le(i.add(pastLen));
+}
+
+/**
+ * Create position IDs tensor [offset, offset+1, ..., offset+seqLen-1] and broadcast to batch size
+ * @param seqLen - Sequence length
+ * @param batchSize - Batch size
+ * @param device - Device to place tensor on
+ * @param offset - Starting position (e.g. the length of a cached prefix)
+ * @returns Position IDs tensor
+ */
+export function createPositionIds(
+  seqLen: number,
+  batchSize: number,
+  device: Device,
+  offset: number = 0
+): Tensor {
+  // Create position IDs tensor [offset, offset+1, ..., offset+seqLen-1] and broadcast to batch
+  const positionIds = arange({ end: seqLen, device, dtype: int32 }).add(offset).cast(int32);
+  // Reshape to [1, seqLen] and broadcast to [batchSize, seqLen]
+  return positionIds.unsqueeze(0).broadcastTo([batchSize, seqLen]);
+}
+
+/**
+ * Apply a boolean mask to attention scores.
+ * @param onTrue - Values kept where the mask is true
+ * @param mask - Boolean mask tensor
+ * @param onFalseValue - Fill value used where the mask is false (i.e. at masked positions)
+ * @returns Masked scores
+ */
+export function maskedFill(onTrue: Tensor, mask: Tensor, onFalseValue: number): Tensor {
+  return onTrue.where(mask, onFalseValue);
+}
diff --git a/examples/finetuning/src/lib/train/moduleWorker.ts b/examples/finetuning/src/lib/train/moduleWorker.ts
new file mode 100644
index 00000000..3d35a2ad
--- /dev/null
+++ b/examples/finetuning/src/lib/train/moduleWorker.ts
@@ -0,0 +1,295 @@
+import type { Config } from '$lib/workspace/config';
+import type { Tensor } from '@piston-ml/piston-web';
+
+import * as piston from '@piston-ml/piston-web';
+
+import type { 
WorkerCommand, WorkerEvent } from './protocol';
+
+import { TrainingSession } from './session';
+import { type CheckpointExtra, splitLoadedState } from './utils/checkpoint';
+import { inspectModel } from './utils/model';
+
+let session: TrainingSession | undefined;
+
+// Console Interception
+const originalConsole = {
+  log: console.log.bind(console),
+  error: console.error.bind(console),
+  warn: console.warn.bind(console),
+  info: console.info.bind(console),
+  debug: console.debug.bind(console)
+};
+
+function formatArgs(args: unknown[]) {
+  return args
+    .map((arg) => {
+      if (typeof arg === 'object' && arg !== null) {
+        try {
+          return JSON.stringify(arg);
+        } catch (_: unknown) {
+          return '[Unserializable Object]';
+        }
+      }
+      return String(arg);
+    })
+    .join(' ');
+}
+
+interface LogInfo {
+  level: string;
+  message: string;
+  source: string;
+  lineno?: number;
+  colno?: number;
+}
+
+function sendLog(level: string, message: string, source: string = '[Worker]') {
+  // Check if this is a WASM log and parse it
+  const wasmLogRegex = /^\[WASM ([^:]+):(\d+)(?::(\d+))?\] (.*)$/;
+  const match = message.match(wasmLogRegex);
+
+  if (level === 'error' && message.startsWith('panicked at')) {
+    const lines = message.split('\n');
+    if (lines[1]?.startsWith('VRAM limit exceeded')) {
+      self.postMessage({ type: 'log', level: 'error', message: lines[1] });
+      self.postMessage({
+        type: 'error',
+        runId: session?.runId,
+        message: 'VRAM limit exceeded',
+        name: 'VRAMLimitExceededError'
+      });
+      return;
+    } else {
+      self.postMessage({ type: 'error', runId: session?.runId, message });
+    }
+  }
+
+  if (match) {
+    // Handle WASM logs with parsed source info
+    const [, filepath, lineno, colno, actualMessage] = match;
+    const logInfo: LogInfo = {
+      level,
+      message: actualMessage,
+      source: `[WASM] ${filepath}`,
+      lineno: parseInt(lineno, 10),
+      ...(colno && { colno: parseInt(colno, 10) })
+    };
+    self.postMessage({ type: 'log', ...logInfo });
+  } else {
+    // Handle regular logs
+    const logInfo: LogInfo = {
+      level,
+      message,
+      source
+    };
+    self.postMessage({ type: 'log', ...logInfo });
+  }
+}
+
+// Wrap console methods before importing Piston to catch its logs
+Object.keys(originalConsole).forEach((level) => {
+  (console as unknown as Record<string, (...args: unknown[]) => void>)[level] = (
+    ...args: unknown[]
+  ) => {
+    const message = formatArgs(args);
+
+    // Use sendLog which will handle WASM log parsing internally
+    sendLog(level, message, currentExecutionSource);
+
+    // Also call original console for debugging
+    originalConsole[level as keyof typeof originalConsole](...args);
+  };
+});
+
+// Global error handler - catches unhandled errors
+self.addEventListener('error', (event) => {
+  const errorMessage = `Uncaught Error: ${event.message} at ${event.filename}:${event.lineno}:${event.colno}`;
+  sendLog('error', errorMessage);
+  if (event.error?.stack) {
+    sendLog('error', `${event.error.stack}`);
+  }
+});
+
+// Unhandled promise rejection handler - catches unhandled promise rejections
+self.addEventListener('unhandledrejection', (event) => {
+  const errorMessage = `Unhandled Promise Rejection: ${event.reason}`;
+  sendLog('error', errorMessage);
+  if (event.reason?.stack) {
+    sendLog('error', `${event.reason.stack}`);
+  }
+  // Prevent the default browser behavior (logging to console)
+  event.preventDefault();
+});
+
+// Intercept and override the default error reporting
+const originalOnError = self.onerror;
+self.onerror = (message, source, lineno, colno, error) => {
+  const errorMessage = `Global Error: ${message} at 
${source}:${lineno}:${colno}`; + sendLog('error', errorMessage); + if (error?.stack) { + sendLog('error', `${error.stack}`); + } + // Call original handler if it exists + if (originalOnError) { + return originalOnError(message, source, lineno, colno, error); + } + // Prevent default browser error handling + return true; +}; + +// +// End Console Interception +// + +// Track current execution context for logging (will be reassigned during execution) +let currentExecutionSource = '[Worker]'; + +function postEvent(e: WorkerEvent) { + self.postMessage(e); +} + +function startTraining() { + if (!session) return; + const runId = session.runId; + session.start().catch((error: unknown) => { + console.error('Training error:', error); + self.postMessage({ + type: 'error', + runId, + message: error instanceof Error ? error.message : String(error), + name: error instanceof Error ? error.name : undefined, + stack: error instanceof Error ? error.stack : undefined + }); + }); +} + +// Message handler for worker +self.addEventListener('message', async (event) => { + const raw = event.data as WorkerCommand | { type: string; data?: unknown }; + const type: string = (raw as { type: string }).type; + const data: unknown = (raw as { data?: unknown }).data; + + switch (type) { + case 'save': { + session?.save(); + break; + } + case 'pause': { + session?.pause(); + break; + } + case 'resume': { + session?.resume(); + startTraining(); + break; + } + case 'step': { + if (!session) break; + await session.pause(); + await session.step({ manual: true }); + break; + } + case 'start': + try { + const { + runId: runIdFromData, + config, + resumeFrom, + gpuPowerPreference + } = data as { + runId: string; + config: Config; + resumeFrom?: Uint8Array; + gpuPowerPreference?: 'high-performance' | 'low-power'; + }; + currentExecutionSource = `[Training:${runIdFromData}]`; + + // Apply GPU power preference before any GPU initialization + if (gpuPowerPreference) { + await piston.applyGpuPowerPreference(gpuPowerPreference); + } + + console.info(`Starting training for run ${runIdFromData}`); + session = new TrainingSession(runIdFromData, config, postEvent, resumeFrom); + startTraining(); + } catch (error: unknown) { + console.error('Training error:', error); + self.postMessage({ + type: 'error', + runId: (data as { runId?: string })?.runId, + message: error instanceof Error ? error.message : String(error), + name: error instanceof Error ? error.name : undefined, + stack: error instanceof Error ? 
error.stack : undefined
+        });
+      }
+      break;
+    case 'checkpoint.peekConfig': {
+      const { requestId, buffer } = data as {
+        requestId: string;
+        buffer: Uint8Array;
+      };
+      const loaded = piston.load(buffer, piston.gpu);
+      const split = splitLoadedState(
+        loaded as { state: Record<string, Tensor>; extra?: CheckpointExtra }
+      );
+      self.postMessage({ type: 'checkpoint.config', requestId, config: split.config });
+      break;
+    }
+    case 'inspectModel':
+      try {
+        const { config, requestId, gpuPowerPreference } = data as {
+          config: Config;
+          requestId: string;
+          gpuPowerPreference?: 'high-performance' | 'low-power';
+        };
+        currentExecutionSource = '[ModelInspection]';
+
+        // Apply GPU power preference before any GPU usage
+        if (gpuPowerPreference) {
+          await piston.applyGpuPowerPreference(gpuPowerPreference);
+        }
+        (globalThis as unknown as { piston: typeof piston }).piston = piston; // expose for debugging
+
+        console.debug('Inspecting model...');
+        const result = inspectModel(config);
+
+        self.postMessage({
+          type: 'modelInspection',
+          requestId,
+          parameterCount: result.parameterCount,
+          hiddenSize: result.hiddenSize,
+          mlpIntermediateSize: result.mlpIntermediateSize,
+          modelIndex: result.modelIndex,
+          vocabSize: result.vocabSize,
+          blockSize: result.blockSize
+        });
+      } catch (error: unknown) {
+        console.error('Model inspection error:', error);
+        self.postMessage({
+          type: 'modelInspectionError',
+          requestId: (data as { requestId?: string })?.requestId ?? '',
+          message: error instanceof Error ? error.message : String(error)
+        });
+      }
+      break;
+
+    default:
+      console.warn(`Unknown message type: ${type}`);
+      break;
+  }
+});
+
+// Initialize Piston, then signal that the worker is ready
+piston
+  .init()
+  .then(() => {
+    console.info('Piston initialized');
+    self.postMessage({ type: 'ready' });
+  })
+  .catch((error: unknown) => {
+    console.error('Error initializing Piston:', error);
+    self.postMessage({
+      type: 'error',
+      message: error instanceof Error ? 
error.message : String(error)
+    });
+  });
diff --git a/examples/finetuning/src/lib/train/protocol.ts b/examples/finetuning/src/lib/train/protocol.ts
new file mode 100644
index 00000000..e2a9c225
--- /dev/null
+++ b/examples/finetuning/src/lib/train/protocol.ts
@@ -0,0 +1,85 @@
+import type { Config } from '$lib/workspace/config';
+import type { StepData } from '$lib/workspace/runs.svelte';
+import type { IndexState } from '@piston-ml/piston-web';
+
+type WithRunId = { runId: string };
+type GpuPowerPreference = 'high-performance' | 'low-power';
+
+export type WorkerCommand =
+  | {
+      type: 'start';
+      data: { runId: string; config: Config; resumeFrom?: Uint8Array; gpuPowerPreference?: GpuPowerPreference };
+    }
+  | {
+      type: 'checkpoint.peekConfig';
+      data: { requestId: string; buffer: Uint8Array };
+    }
+  | { type: 'pause' }
+  | { type: 'resume' }
+  | { type: 'step' }
+  | { type: 'save' }
+  | { type: 'stop' }
+  | { type: 'inspectModel'; data: { requestId: string; config: Config; gpuPowerPreference?: GpuPowerPreference } };
+
+type ReadyWorkerEvent = {
+  type: 'ready';
+};
+
+type LogWorkerEvent = {
+  type: 'log';
+  level: 'debug' | 'info' | 'warn' | 'error';
+  message: string;
+  source?: string;
+  lineno?: number;
+  colno?: number;
+};
+
+type ErrorWorkerEvent = {
+  type: 'error';
+  name?: string;
+  message: string;
+  stack?: string;
+};
+
+export type RunWorkerEventWithoutRunId =
+  | {
+      type: 'metrics';
+      data: { [metricName: string]: Omit<StepData, 'step'> };
+      metadata?: { step?: number };
+    }
+  | {
+      type: 'capture';
+      step: number;
+      boxes: unknown[];
+      statsById: Record<string, unknown>;
+      width: number;
+      height: number;
+      queries: unknown[];
+    }
+  | { type: 'checkpoint'; buffer: Uint8Array }
+  | { type: 'restart'; buffer: Uint8Array }
+  | { type: 'paused' }
+  | { type: 'resumed' }
+  | { type: 'complete' }
+  | LogWorkerEvent
+  | ErrorWorkerEvent;
+
+export type RunWorkerEvent = RunWorkerEventWithoutRunId & WithRunId;
+
+export type WorkerEvent =
+  | ReadyWorkerEvent
+  | LogWorkerEvent
+  | ErrorWorkerEvent
+  | RunWorkerEvent
+  | { type: 'checkpoint.config'; requestId: string; config: Config }
+  | {
+      type: 'modelInspection';
+      requestId: string;
+      parameterCount: number;
+      hiddenSize: number;
+      mlpIntermediateSize: number;
+      vocabSize: number;
+      blockSize: number;
+      modelIndex: IndexState;
+    }
+  | { type: 'modelInspectionError'; requestId: string; message: string };
diff --git a/examples/finetuning/src/lib/train/session.ts b/examples/finetuning/src/lib/train/session.ts
new file mode 100644
index 00000000..0f07dd26
--- /dev/null
+++ b/examples/finetuning/src/lib/train/session.ts
@@ -0,0 +1,675 @@
+import type { Config } from '$lib/workspace/config';
+import type { StepData } from '$lib/workspace/runs.svelte';
+
+import {
+  CosineAnnealingLR,
+  ExponentialLR,
+  LinearLR,
+  type LRScheduler,
+  SequentialLR,
+  StepLR,
+  type Tensor
+} from '@piston-ml/piston-web';
+import * as piston from '@piston-ml/piston-web';
+
+import type { NaturalLanguageDataset } from './data/natural';
+import type { BuiltData } from './data/pipeline';
+import type { RunWorkerEvent, RunWorkerEventWithoutRunId, WorkerEvent } from './protocol';
+import type { GeneratableModel, NaturalCollateFnType } from './types';
+
+import { buildDataset, tensorWrap } from './data';
+import { filterDatasetByHeldoutSamples } from './data/filter';
+import { buildDataPipeline } from './data/pipeline';
+import { GPT, GPT2_BLOCK_SIZE, GPT2_VOCAB_SIZE } from './model/gpt';
+import {
+  type AnySchedulerState,
+  buildCheckpoint,
+  type CheckpointDataState,
+  type CheckpointExtra,
+  splitLoadedState
+} from './utils/checkpoint';
+// import { initTransformerParameters } from './utils/init';
+import { calculateParameterSum, createCollateFn, createModel } from './utils/model';
+import { 
MarkStepModeIfEnabled, WeakModeIfEnabled } from './utils/modes';
+import { configureOptimizers } from './utils/optim';
+import {
+  buildValidationExamplesSubset,
+  buildValidationLog,
+  computeLikelihoodMetrics,
+  computeNaturalValidationMetrics,
+  type NaturalValidationExamples,
+  prepareNaturalValidationExamples,
+  type ValidationStep
+} from './validation';
+
+// @ts-expect-error polyfill
+Symbol.dispose ||= Symbol.for('Symbol.dispose');
+
+export class TrainingSession {
+  readonly runId: string;
+  private config: Config;
+  private readonly post: (e: RunWorkerEventWithoutRunId) => void;
+  private readonly resumeFrom?: Uint8Array;
+
+  private paused = false;
+  private resolvePause: (() => void) | null = null;
+
+  private model!: GeneratableModel;
+  private optimizer!: piston.Optimizer;
+  private scheduler: LRScheduler | undefined;
+  private trainDataset!: NaturalLanguageDataset;
+  private blockSize!: number;
+
+  private isSetup: boolean = false;
+
+  private startTimeMs: number | null = null;
+  private lastLogTime: number | null = null;
+  private lastLogStep: number | null = null;
+  private stepCount: number = 0;
+
+  private dataPipeline!: BuiltData;
+
+  private validationExamples: NaturalValidationExamples | null = null;
+  private validationCollateFn: NaturalCollateFnType | null = null;
+  private validationDataset: NaturalLanguageDataset | null = null;
+  // This is a little bit gross, but it's a straightforward way to make sure we have valid targets
+  // when we resume from a checkpoint. If I ever do another pass over this code, this will be first
+  // to go.
+  private includeTargetsOnNextValidation: boolean = false;
+
+  constructor(
+    runId: string,
+    config: Config,
+    post: (e: WorkerEvent) => void,
+    resumeFrom?: Uint8Array
+  ) {
+    this.runId = runId;
+    this.config = config;
+    this.post = (e: RunWorkerEventWithoutRunId) =>
+      // We only post the subset of events that have runId in the payload
+      (post as (e: RunWorkerEvent) => void)({ ...e, runId: this.runId });
+    this.resumeFrom = resumeFrom;
+    if (resumeFrom) {
+      this.includeTargetsOnNextValidation = true;
+    }
+  }
+
+  async pause() {
+    if (this.paused) return;
+    this.paused = true;
+    await new Promise<void>((resolve) => {
+      this.resolvePause = resolve;
+    });
+    this.resolvePause = null;
+    this.post({ type: 'paused' });
+  }
+
+  resume() {
+    this.paused = false;
+    this.post({ type: 'resumed' });
+  }
+
+  async save() {
+    try {
+      if (this.paused) {
+        try {
+          if (!this.model) {
+            // Model not created yet; there is nothing to save
+            this.post({ type: 'log', level: 'info', message: 'Save requested before model ready' });
+            return;
+          }
+          await piston.gpu.markStep();
+          const buffer = await this.saveLatestCheckpoint();
+          this.post({ type: 'checkpoint', buffer });
+        } catch (e) {
+          this.post({
+            type: 'error',
+            message: String(e)
+          });
+        }
+      } else {
+        throw new Error('Saving during training is not supported');
+      }
+    } catch (e) {
+      this.post({ type: 'error', message: String(e) });
+    }
+  }
+
+  async saveLatestCheckpoint(): Promise<Uint8Array> {
+    if (!this.model) throw new Error('No model available to save');
+    await piston.gpu.markStep();
+    // Derive dataset state if available
+    const dataState = {
+      blockSize: this.blockSize,
+      ...this.trainDataset.exportState()
+    };
+    const { tensors, extra } = buildCheckpoint(
+      this.model,
+      this.optimizer!,
+      this.stepCount,
+      this.config ? JSON.parse(JSON.stringify(this.config)) : null,
+      this.scheduler,
+      dataState,
+      this.startTimeMs ?? 
undefined
+    );
+    return piston.save(tensors, extra);
+  }
+
+  private logMetrics(
+    data: { [metricName: string]: Omit<StepData, 'step'> },
+    metadata?: { step?: number }
+  ) {
+    this.post({ type: 'metrics', data, metadata });
+  }
+
+  private async setup() {
+    if (this.isSetup) {
+      return;
+    }
+
+    // Log initial memory
+    const initialMemoryMB = Number(piston.gpu.usageBytes()) / (1024 * 1024);
+    console.debug(`Initial memory: ${initialMemoryMB} MB`);
+
+    // If resuming from a checkpoint, parse and use checkpoint config
+    let resumePayload: {
+      modelState: Record<string, Tensor>;
+      optimizerPacked?: { state: Record<string, Tensor>; paramGroups: piston.ParamGroupConfig[] };
+      schedulerState?: unknown;
+      numSteps: number;
+      config: Config;
+      dataState?: CheckpointDataState;
+      startTimeMs?: number;
+    } | null = null;
+
+    if (this.resumeFrom) {
+      const loaded = piston.load(this.resumeFrom, piston.gpu);
+      const split = splitLoadedState(
+        loaded as { state: Record<string, Tensor>; extra?: CheckpointExtra }
+      );
+      resumePayload = {
+        modelState: split.modelState,
+        optimizerPacked: split.optimizerState as unknown as {
+          state: Record<string, Tensor>;
+          paramGroups: piston.ParamGroupConfig[];
+        },
+        schedulerState: split.schedulerState,
+        numSteps: split.numSteps,
+        config: split.config,
+        dataState: split.dataState,
+        startTimeMs: split.startTimeMs
+      };
+      if (resumePayload.config) {
+        this.config = resumePayload.config as Config;
+      }
+      // If blockSize present in extras, prefer it
+      if (split.dataState && split.dataState.blockSize !== undefined) {
+        this.blockSize = split.dataState.blockSize;
+      }
+    }
+
+    if (this.config.training.vramLimitMb.present) {
+      piston.gpu.setVRAMLimit(BigInt(this.config.training.vramLimitMb.value * 1024 * 1024));
+    }
+
+    // Ensure shared-object allocation is enabled so buffer handles are stable across steps
+    piston.gpu.setSharedObjectAllocationEnabled(this.config.training.sharedObjectAllocation);
+    piston.gpu.setCachingEnabled(this.config.training.cachingEnabled);
+    piston.gpu.setInplaceSupport(this.config.training.inplaceSupport);
+
+    const trainDataset: NaturalLanguageDataset = buildDataset(this.config, 'train');
+    // Restore dataset state if present
+    if (resumePayload && resumePayload.dataState) {
+      const dsState = resumePayload.dataState;
+      await trainDataset.importState(dsState);
+    }
+    this.trainDataset = trainDataset;
+
+    const validationDisabled =
+      ('disableValidation' in this.trainDataset && this.trainDataset.disableValidation) || false;
+
+    this.validationExamples = null;
+    this.validationCollateFn = null;
+    this.validationDataset = null;
+
+    if (this.config.training.validation.present && !validationDisabled) {
+      this.validationDataset = buildDataset(this.config, 'val');
+      this.validationCollateFn = createCollateFn(tensorWrap);
+      this.validationExamples = await prepareNaturalValidationExamples(
+        this.config,
+        this.validationDataset!
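+        // (Reused below: once for validation metrics, and once to filter held-out
+        // sequences out of the training set.)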
+ ); + // Filter training dataset against holdout examples without duplication + const validationSequences: number[][] = this.validationExamples.naturalSequences; + this.trainDataset = filterDatasetByHeldoutSamples( + this.trainDataset, + this.config.data.dataset, + validationSequences + ); + console.debug( + `Prepared ${validationSequences.length} validation examples for batch generation` + ); + } + + if (validationDisabled) { + console.debug('Validation disabled by dataset; skipping validation and holdout filtering.'); + } + + const vocabSize = GPT2_VOCAB_SIZE; + const blockSize = GPT2_BLOCK_SIZE; + this.blockSize = blockSize; + + console.debug( + `Created dataset ${this.trainDataset.name} with vocab size ${vocabSize} and block size ${blockSize}` + ); + + // Create model + this.model = createModel(this.config); + + // If starting from scratch, initialize model parameters + if (!resumePayload) { + // initTransformerParameters(this.model, this.config); + + // We need to flatten down initialization to the constant tensors they're on top of + await piston.gpu.markStep(); + + const parameterSum = new BigUint64Array( + new Float64Array([await (await calculateParameterSum(this.model).to('cpu')).item()]).buffer + ); + console.debug(`Initialization parameter sum: ${parameterSum}`); + } + + // Build and store the training data pipeline (iterator bound to current dataset/collate) + this.dataPipeline = await buildDataPipeline(this.config, this.trainDataset); + + // If resuming, load model state BEFORE creating the optimizer so param identities match + let startStep = 0; + if (resumePayload) { + this.model.loadStateDict(resumePayload.modelState, { strict: false }); + startStep = (resumePayload.numSteps ?? 0) + 1; + this.stepCount = startStep; + // If checkpoint carried a startTimeMs, use it for wall-clock continuity + if (typeof resumePayload.startTimeMs === 'number') { + this.startTimeMs = resumePayload.startTimeMs; + } + } + + // Create optimizer based on model type, using the (possibly restored) model parameters + const optimizer = configureOptimizers( + this.model, + ['transformer.h'], + 'lm_head', + this.config.optimizer, + piston.gpu + ); + this.optimizer = optimizer; + + // If resuming, load optimizer state NOW that groups refer to current model parameters + if (resumePayload && resumePayload.optimizerPacked) { + optimizer.loadStateDict(resumePayload.optimizerPacked as piston.StateDict); + } + + // Create learning rate scheduler if configured + if (this.config.optimizer.lrScheduler.present) { + const lrConfig = this.config.optimizer.lrScheduler; + switch (lrConfig.type) { + case 'step': + this.scheduler = new StepLR( + this.optimizer, + lrConfig.stepSchedule.stepSize, + lrConfig.stepSchedule.gamma + ); + break; + case 'cosine': + this.scheduler = new CosineAnnealingLR( + this.optimizer, + lrConfig.cosineAnnealingSchedule.tMax, + lrConfig.cosineAnnealingSchedule.etaMin + ); + break; + case 'exponential': + this.scheduler = new ExponentialLR(this.optimizer, lrConfig.exponentialSchedule.gamma); + break; + case 'linear': + this.scheduler = new LinearLR( + this.optimizer, + lrConfig.linearSchedule.startFactor, + lrConfig.linearSchedule.endFactor, + lrConfig.linearSchedule.totalIters + ); + break; + default: + throw new Error(`Unknown scheduler type: ${lrConfig.type}`); + } + + if (this.scheduler && this.config.optimizer.warmupSteps.present) { + const n = this.config.optimizer.warmupSteps.value; + if (n > 0) { + const warmup = new LinearLR(optimizer, 1e-8, 1.0, n); + this.scheduler = new 
SequentialLR(optimizer, [warmup, this.scheduler], [n]);
+        }
+      }
+    } else if (this.config.optimizer.warmupSteps.present) {
+      const n = this.config.optimizer.warmupSteps.value;
+      if (n > 0) {
+        this.scheduler = new LinearLR(optimizer, 1e-8, 1.0, n);
+      }
+    }
+
+    // If resuming, load scheduler state after it is created
+    if (resumePayload && this.scheduler && resumePayload.schedulerState) {
+      this.scheduler.loadStateDict(resumePayload.schedulerState as AnySchedulerState);
+    }
+
+    this.model.train();
+
+    this.isSetup = true;
+  }
+
+  async step({ manual = false }: { manual?: boolean } = {}): Promise<
+    IteratorResult<undefined, 'completed' | 'restarted'>
+  > {
+    if (this.startTimeMs == null) {
+      this.startTimeMs = Date.now();
+    }
+    if (this.lastLogStep == null) {
+      this.lastLogStep = this.stepCount;
+    }
+    try {
+      const iterNext = await this.dataPipeline.train.iterator.next();
+      if (iterNext.done) {
+        return { done: true, value: 'completed' };
+      }
+      const batch = iterNext.value;
+      performance.mark('stepStart');
+      // Reset peak GPU memory tracking at the start of the step
+      piston.gpu.markUsageBytesStep();
+
+      let isLastStep = false;
+      if (
+        this.config.training.limitTraining.present &&
+        this.stepCount + 1 >= this.config.training.limitTraining.steps
+      ) {
+        console.log(
+          `Stopping training at step ${this.stepCount} because it reached the limit of ${this.config.training.limitTraining.steps} steps`
+        );
+        isLastStep = true;
+      }
+
+      const loggingStep =
+        manual || isLastStep || this.stepCount % this.config.training.logSteps === 0;
+
+      const weakModeUntilAfterBackward = new WeakModeIfEnabled(
+        this.config.training.useWeakTensorReferences,
+        {
+          label: 'train/forward_through_backward'
+        }
+      );
+
+      let loss: Tensor;
+      try {
+        // For GPT: batch contains [inputs, targets]
+        const { tensors } = batch;
+        const [inputs, gptTargets] = tensors;
+        const [, computedLoss] = (this.model as GPT).forward(await inputs.to('gpu'), {
+          targets: await gptTargets.to('gpu')
+        });
+
+        if (!computedLoss) {
+          throw new Error('No loss tensor returned from decoder-only model');
+        }
+
+        loss = computedLoss;
+
+        weakModeUntilAfterBackward.pin(loss);
+
+        loss.backward();
+      } finally {
+        weakModeUntilAfterBackward[Symbol.dispose]();
+      }
+
+      const weakModeForOptimizerStep = new WeakModeIfEnabled(
+        this.config.training.useWeakTensorReferences,
+        {
+          label: 'train/optimizer_step'
+        }
+      );
+
+      let gradNorm: Tensor | undefined;
+      try {
+        const weakMarkStepMode = new MarkStepModeIfEnabled(
+          this.config.training.useWeakTensorReferences
+        );
+        weakModeForOptimizerStep.pin(loss);
+
+        if (this.config.training.gradNorm.track) {
+          if (this.config.training.clipGradNorm.present) {
+            gradNorm = weakModeForOptimizerStep.pin(
+              piston.clipGradNorm_(this.model.parameters(), this.config.training.clipGradNorm.value)
+            );
+          } else if (loggingStep) {
+            // If we're not clipping gradients, we can just get the total gradient norm
+            gradNorm = weakModeForOptimizerStep.pin(
+              piston.getTotalGradNorm(this.model.parameters())
+            );
+          }
+        }
+
+        try {
+          await this.optimizer.step();
+          await piston.gpu.markStep();
+        } finally {
+          weakMarkStepMode[Symbol.dispose]();
+        }
+      } finally {
+        // TODO: decide if it's okay that we're disposing the mode twice here
+        weakModeForOptimizerStep[Symbol.dispose]();
+      }
+
+      const finalWeakModeForStep = new WeakModeIfEnabled(
+        this.config.training.useWeakTensorReferences,
+        {
+          label: 'train/final'
+        }
+      );
+
+      try {
+        // We've kept loss strong; we'll want to make sure we get rid of it
+        // Batch tensors are created outside of weak mode, so we manually mark 
them as weak
+        finalWeakModeForStep.markWeak([loss, gradNorm, batch.tensors]);
+
+        this.optimizer.zeroGrad(true);
+
+        // Step learning rate scheduler if present
+        if (this.scheduler) {
+          this.scheduler.step();
+        }
+
+        if (
+          this.config.training.validation.present &&
+          (this.stepCount % this.config.training.validation.valSteps === 0 || isLastStep) &&
+          this.validationExamples &&
+          this.validationDataset &&
+          this.validationCollateFn
+        ) {
+          try {
+            let valLoss = Number.NaN;
+            let perplexity = Number.NaN;
+            let validationLog: Record<string, Omit<StepData, 'step'>> = {};
+
+            if (this.validationExamples) {
+              if (this.config.training.validation.completions.present) {
+                let validationExamplesSubset: NaturalValidationExamples | null = null;
+                if (this.config.training.validation.completions.amount === 'subset') {
+                  validationExamplesSubset = buildValidationExamplesSubset(
+                    this.validationExamples,
+                    this.config.training.validation.completions.subsetSize
+                  );
+                } else {
+                  validationExamplesSubset = this.validationExamples;
+                }
+                const validationStepData = await computeNaturalValidationMetrics(
+                  this.model,
+                  this.validationDataset,
+                  validationExamplesSubset as NaturalValidationExamples,
+                  this.config.training.validation
+                );
+                validationLog = buildValidationLog(validationStepData);
+                if (this.includeTargetsOnNextValidation) {
+                  this.includeTargetsOnNextValidation = false;
+                }
+              }
+
+              const result = await computeLikelihoodMetrics(
+                this.model,
+                this.validationExamples!,
+                this.validationCollateFn!
+              );
+
+              valLoss = result.valLoss;
+              perplexity = result.perplexity;
+
+              const logData: Record<string, Omit<StepData, 'step'>> = {
+                ...validationLog,
+                'validation/loss': valLoss,
+                'validation/perplexity': perplexity
+              };
+              this.logMetrics(logData, { step: this.stepCount });
+            }
+          } catch (error) {
+            console.error('Error during batch validation:', error);
+          }
+        }
+
+        if (loggingStep) {
+          const currentTime = Date.now();
+          const totalElapsedSeconds = (currentTime - this.startTimeMs!) / 1000;
+
+          // Calculate delta time and steps since last log
+          const deltaTime = (currentTime - this.lastLogTime!) / 1000;
+          const deltaSteps = this.stepCount - this.lastLogStep!;
+
+          // Calculate steps per second and words per second based on delta
+          const stepsPerSecond = deltaSteps > 0 ? deltaSteps / deltaTime : 0;
+
+          // Calculate words per second (tokens per second)
+          // Get sequence length from the first tensor in the batch
+          let sequenceLength = 0;
+
+          // Encoder-decoder will have three tensors in its batch, but we can just use the first one
+          const [inputs] = batch.tensors;
+          sequenceLength = inputs.shape[1]; // [batch_size, seq_len]
+
+          const tokensPerStep = this.config.training.batchSize * sequenceLength;
+          const tokensPerSecond = deltaSteps > 0 ? 
(deltaSteps * tokensPerStep) / deltaTime : 0;
+
+          const activeMap = piston.__pistonActiveTensors();
+          const activeTensors = Array.from(activeMap.values()).reduce((s, v) => s + v.length, 0);
+
+          let lossItem: number | null = null;
+
+          const lossCpu = await loss.to('cpu');
+          lossItem = await lossCpu.item();
+
+          if (lossItem === null) {
+            throw new Error('Loss item is null?');
+          }
+
+          const peakUsageMb = Number(piston.gpu.peakUsageBytes()) / (1024 * 1024);
+
+          const logData: Record<string, number> = {
+            'train/loss': lossItem,
+            'allocation/active_tensor_count': activeTensors,
+            'allocation/gpu_memory_mb': peakUsageMb,
+            'speed/steps_per_second': stepsPerSecond,
+            'speed/step': this.stepCount,
+            'speed/tokens_per_second': tokensPerSecond,
+            'speed/wall_clock_seconds': totalElapsedSeconds
+          };
+
+          if (gradNorm) {
+            const gradNormCpu = await gradNorm.to('cpu');
+            const gradNormItem = await gradNormCpu.item();
+            if (this.config.training.gradNorm.errorIfNonfinite && !isFinite(gradNormItem)) {
+              throw new Error(`Gradient norm was nonfinite, so it cannot be clipped.`);
+            }
+            logData['train/grad_norm'] = gradNormItem;
+          }
+
+          // Log the current learning rate when available
+          const currentLr = this.optimizer.paramGroups[0].lr;
+          if (currentLr) {
+            logData['optimizer/learning_rate'] = currentLr;
+          }
+
+          this.logMetrics(logData, { step: this.stepCount });
+
+          // Update last log time and step
+          this.lastLogTime = currentTime;
+          this.lastLogStep = this.stepCount;
+        }
+
+        // Trigger periodic checkpoint save (non-restart) if configured
+        if (this.config.training.checkpointEverySteps.present) {
+          const checkpointEvery = this.config.training.checkpointEverySteps.value;
+          if (checkpointEvery > 0 && (this.stepCount + 1) % checkpointEvery === 0) {
+            try {
+              const bytes = await this.saveLatestCheckpoint();
+              this.post({ type: 'checkpoint', buffer: bytes });
+            } catch (e) {
+              // Non-fatal; continue training
+              this.post({ type: 'log', level: 'warn', message: String(e) });
+            }
+          }
+        }
+
+        // Trigger periodic restart if configured
+        const restartEvery = this.config.training.restartEverySteps ?? 
0;
+        const willRestart = restartEvery > 0 && (this.stepCount + 1) % restartEvery === 0;
+        if (willRestart) {
+          console.debug(`Routine restart at step ${this.stepCount}`);
+          await piston.gpu.markStep();
+          const bytes = await this.saveLatestCheckpoint();
+          this.post({ type: 'restart', buffer: bytes });
+          return { done: true, value: 'restarted' };
+        }
+
+        if (isLastStep) {
+          return { done: true, value: 'completed' };
+        }
+      } finally {
+        finalWeakModeForStep[Symbol.dispose]();
+      }
+
+      this.stepCount++;
+
+      performance.mark('stepEnd');
+    } catch (error) {
+      console.error(`Error during training: ${error}`);
+      throw error;
+    }
+    return { done: false, value: undefined };
+  }
+
+  async start(): Promise<void> {
+    await this.setup();
+    while (true) {
+      if (this.paused) {
+        if (this.resolvePause) {
+          this.resolvePause();
+        }
+        return;
+      }
+      const { done, value } = await this.step();
+      if (done) {
+        if (value === 'completed') {
+          this.post({ type: 'complete' });
+          break;
+        }
+        if (value === 'restarted') {
+          return;
+        }
+      }
+    }
+  }
+}
diff --git a/examples/finetuning/src/lib/train/tokenizer.ts b/examples/finetuning/src/lib/train/tokenizer.ts
new file mode 100644
index 00000000..44c8abb7
--- /dev/null
+++ b/examples/finetuning/src/lib/train/tokenizer.ts
@@ -0,0 +1,2457 @@
+/**
+ * @fileoverview Simplified Tokenizer implementation adapted from huggingface/transformers.js
+ */
+
+import { PUBLIC_DATA_URL } from '$env/static/public';
+import { Template } from '@huggingface/jinja';
+import { int32, Tensor, tensor } from '@piston-ml/piston-web';
+
+/* eslint-disable @typescript-eslint/no-unsafe-declaration-merging */
+abstract class Callable<Args extends unknown[], Return> {
+  /**
+   * Creates a new instance of the Callable class.
+   */
+  constructor() {
+    /**
+     * Creates a closure that delegates to a private method 'call' with the given arguments.
+     * @param args Zero or more arguments to pass to the 'call' method.
+     * @returns The result of calling the 'call' method.
+     */
+    const closure = ((...args: Args) => {
+      return (closure as unknown as { call: (...args: Args) => Return }).call(...args);
+    }) as unknown as (...args: Args) => Return;
+    return Object.setPrototypeOf(closure, new.target.prototype) as unknown as this &
+      ((...args: Args) => Return);
+  }
+
+  /**
+   * This method should be implemented in subclasses to provide the
+   * functionality of the callable object.
+   *
+   * @param args Zero or more arguments to pass to the 'call' method.
+   * @throws {Error} If the subclass does not implement the `call` method.
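+   *
+   * Because of the interface merged below, instances are callable both at runtime and
+   * in the type system. Illustrative sketch (hypothetical subclass, not part of this file):
+   * @example
+   * class Doubler extends Callable<[number], number> {
+   *   protected call(x: number): number {
+   *     return x * 2;
+   *   }
+   * }
+   * const double = new Doubler();
+   * double(21); // 42: the instance itself is invocable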
+   */
+  protected abstract call(..._args: Args): Return;
+}
+interface Callable<Args extends unknown[], Return> {
+  (...args: Args): Return;
+}
+
+// Discriminated config helpers
+type WithType<TType extends string> = { type: TType };
+
+// Normalizer configs
+type NormalizerSequenceConfig = WithType<'Sequence'> & { normalizers: NormalizerConfig[] };
+type NFCConfig = WithType<'NFC'>;
+type NFDConfig = WithType<'NFD'>;
+type NFKCConfig = WithType<'NFKC'>;
+type NFKDConfig = WithType<'NFKD'>;
+type StripConfig = WithType<'Strip'> & { stripLeft?: boolean; stripRight?: boolean };
+type LowercaseConfig = WithType<'Lowercase'>;
+type PrependConfig = WithType<'Prepend'> & { prepend: string };
+type NormalizerConfig =
+  | NormalizerSequenceConfig
+  | NFCConfig
+  | NFDConfig
+  | NFKCConfig
+  | NFKDConfig
+  | StripConfig
+  | LowercaseConfig
+  | PrependConfig;
+
+// PreTokenizer configs and options
+type PreTokenizeOptions = { sectionIndex?: number };
+type PreTokenizerSequenceConfig = WithType<'Sequence'> & { pretokenizers: PreTokenizerConfig[] };
+type WhitespacePreTokenizerConfig = WithType<'Whitespace'>;
+type WhitespaceSplitConfig = WithType<'WhitespaceSplit'>;
+type MetaspacePreTokenizerConfig = WithType<'Metaspace'> & {
+  addPrefixSpace: boolean;
+  replacement: string;
+  strRep?: string;
+  prependScheme?: 'first' | 'never' | 'always';
+};
+type ByteLevelPreTokenizerConfig = WithType<'ByteLevel'> & {
+  addPrefixSpace: boolean;
+  trimOffsets: boolean;
+  useRegex?: boolean;
+};
+type PreTokenizerConfig =
+  | PreTokenizerSequenceConfig
+  | WhitespacePreTokenizerConfig
+  | WhitespaceSplitConfig
+  | MetaspacePreTokenizerConfig
+  | ByteLevelPreTokenizerConfig;
+
+// PostProcessor configs and options
+type PostProcessorOptions = { addSpecialTokens?: boolean };
+type PostProcessorResult = { tokens: string[]; tokenTypeIds?: number[] };
+type PostProcessorSequenceConfig = WithType<'Sequence'> & { processors: PostProcessorConfig[] };
+type ByteLevelPostProcessorConfig = WithType<'ByteLevel'>;
+type PostProcessorConfig = PostProcessorSequenceConfig | ByteLevelPostProcessorConfig;
+
+// Decoder configs
+type ByteLevelDecoderConfig = WithType<'ByteLevel'> & { trimOffsets?: boolean };
+type ByteFallbackConfig = WithType<'ByteFallback'>;
+type FuseDecoderConfig = WithType<'Fuse'>;
+type StripDecoderConfig = WithType<'Strip'> & { content: string; start: number; stop: number };
+type DecoderSequenceConfig = WithType<'Sequence'> & { decoders: DecoderConfig[] };
+type BPEDecoderConfig = WithType<'BPEDecoder'> & { suffix: string };
+type DecoderConfig =
+  | ByteLevelDecoderConfig
+  | ByteFallbackConfig
+  | FuseDecoderConfig
+  | StripDecoderConfig
+  | DecoderSequenceConfig
+  | BPEDecoderConfig;
+
+// Model configs
+export interface TokenizerModelConfig {
+  fuseUnk: boolean;
+  byteFallback: boolean;
+  ignoreMerges: boolean;
+}
+
+type BPEConfig = WithType<'BPE'> &
+  TokenizerModelConfig & {
+    vocab: Record<string, number>;
+    merges: string[] | [string, string][];
+    unkToken: string;
+    endOfWordSuffix?: string;
+    continuingSubwordSuffix?: string | null;
+  };
+
+type TokenizerModelFactoryConfig = BPEConfig; // Extend when additional models are added
+
+// Tokenizer JSON and runtime config
+interface TokenizerJSON {
+  normalizer: NormalizerConfig | null;
+  preTokenizer: PreTokenizerConfig | null;
+  model: TokenizerModelFactoryConfig;
+  postProcessor: PostProcessorConfig | null;
+  decoder: DecoderConfig | null;
+  addedTokens: AddedTokenConfig[];
+}
+
+interface TokenizerConfig {
+  [key: string]: unknown;
+  additionalSpecialTokens?: string[];
+  modelMaxLength: number;
+  removeSpace: 
boolean;
+  cleanUpTokenizationSpaces?: boolean;
+  paddingSide?: 'left' | 'right';
+  addBosToken?: boolean;
+  addEosToken?: boolean;
+  chatTemplate?: null | Array<{ name: string; template: string }> | Record<string, string>;
+}
+
+const TOKENIZER_URL = PUBLIC_DATA_URL + 'tokenizer';
+
+/**
+ * Loads a tokenizer from the specified path.
+ * @param tokenizerName The path to the tokenizer directory.
+ * @returns A promise that resolves with tokenizer JSON and config.
+ */
+async function loadTokenizer(tokenizerName: string): Promise<[TokenizerJSON, TokenizerConfig]> {
+  return Promise.all([
+    fetchJSON<TokenizerJSON>(TOKENIZER_URL, `${tokenizerName}/tokenizer.json`).then(
+      camelCaseKeysDeep
+    ),
+    fetchJSON<TokenizerConfig>(TOKENIZER_URL, `${tokenizerName}/tokenizer_config.json`).then(
+      camelCaseKeysDeep
+    )
+  ]);
+}
+
+function isPlainObject(value: unknown): value is Record<string, unknown> {
+  return (
+    value !== null && typeof value === 'object' && Object.getPrototypeOf(value) === Object.prototype
+  );
+}
+
+function toCamelKey(key: string): string {
+  return key.includes('_')
+    ? key.replace(/_+([a-zA-Z0-9])/g, (_m, c: string) => c.toUpperCase())
+    : key;
+}
+
+function camelCaseKeysDeep<T>(input: T): T {
+  if (Array.isArray(input)) {
+    return input.map((item) => camelCaseKeysDeep(item)) as unknown as T;
+  }
+  if (isPlainObject(input)) {
+    const obj = input as Record<string, unknown>;
+    const out: Record<string, unknown> = Object.create(null);
+    for (const [key, value] of Object.entries(obj)) {
+      const transformed = camelCaseKeysDeep(value);
+      // Preserve original snake_case for compatibility
+      out[key] = transformed;
+      const camelKey = toCamelKey(key);
+      if (camelKey !== key && !(camelKey in out)) {
+        out[camelKey] = transformed;
+      }
+    }
+    return out as unknown as T;
+  }
+  return input;
+}
+
+// Minimal fetch wrapper used here; replace with project-util if available
+async function fetchJSON<T>(basePath: string, fileName: string): Promise<T> {
+  const url = `${basePath.replace(/\/$/, '')}/${fileName}`;
+  const res = await fetch(url);
+  if (!res.ok) throw new Error(`Failed to load ${fileName} from ${url}`);
+  return res.json() as Promise<T>;
+}
+
+/**
+ * Helper function to convert an Object to a Map
+ * @param obj The object to convert.
+ * @returns The map.
+ */
+function objectToMap<T>(obj: Record<string, T>): Map<string, T> {
+  return new Map(Object.entries(obj));
+}
+
+/**
+ * Helper function to fuse consecutive unknown tokens.
+ * @param arr The list of input tokens
+ * @param tokensToIds The mapping from tokens to token ids.
+ * @param unkTokenId The value to fuse on.
+ */
+function fuseUnk(arr: string[], tokensToIds: Map<string, number>, unkTokenId: number): string[] {
+  const fused: string[] = [];
+  let i = 0;
+  while (i < arr.length) {
+    fused.push(arr[i]);
+    if ((tokensToIds.get(arr[i]) ?? unkTokenId) !== unkTokenId) {
+      ++i;
+      continue;
+    }
+
+    while (++i < arr.length && (tokensToIds.get(arr[i]) ?? unkTokenId) === unkTokenId) {
+      if (tokensToIds.get(fused[fused.length - 1]) !== unkTokenId) {
+        fused[fused.length - 1] += arr[i];
+      }
+    }
+  }
+
+  return fused;
+}
+
+/**
+ * Split a string on whitespace.
+ * @param text The text to split.
+ * @returns The split string.
+ */
+function whitespaceSplit(text: string): string[] {
+  return text.match(/\S+/g) || [];
+}
+
+/**
+ * Represent a token added by the user on top of the existing Model vocabulary.
+ * AddedToken can be configured to specify the behavior they should have in various situations like:
+ * - Whether they should only match single words
+ * - Whether to include any whitespace on its left or right
+ */
+interface AddedTokenConfig {
+  content: string;
+  id: number;
+  singleWord?: boolean;
+  lstrip?: boolean;
+  rstrip?: boolean;
+  normalized?: boolean;
+  special?: boolean;
+}
+class AddedToken {
+  content: string;
+  id: number;
+  singleWord: boolean;
+  lstrip: boolean;
+  rstrip: boolean;
+  special: boolean;
+  normalized: boolean | null;
+  /**
+   * Creates a new instance of AddedToken.
+   * @param config Added token configuration object.
+   * @param config.content The content of the added token.
+   * @param config.id The id of the added token.
+   * @param config.singleWord Whether this token must be a single word or can break words.
+   * @param config.lstrip Whether this token should strip whitespaces on its left.
+   * @param config.rstrip Whether this token should strip whitespaces on its right.
+   * @param config.normalized Whether this token should be normalized.
+   * @param config.special Whether this token is special.
+   */
+  constructor(config: AddedTokenConfig) {
+    this.content = config.content;
+    this.id = config.id;
+    this.singleWord = config.singleWord ?? false;
+    this.lstrip = config.lstrip ?? false;
+    this.rstrip = config.rstrip ?? false;
+    this.special = config.special ?? false;
+    this.normalized = config.normalized ?? null;
+  }
+}
+
+/**
+ * Abstract base class for tokenizer models.
+ */
+export class TokenizerModel extends Callable<[string[]], string[]> {
+  config: TokenizerModelConfig;
+  vocab: string[];
+  tokensToIds: Map<string, number>;
+  unkTokenId?: number;
+  unkToken?: string;
+  endOfWordSuffix?: string;
+  fuseUnk: boolean;
+  /**
+   * Creates a new instance of TokenizerModel.
+   * @param config The configuration object for the TokenizerModel.
+   */
+  constructor(config: TokenizerModelConfig) {
+    super();
+    this.config = config;
+
+    this.vocab = [];
+
+    this.tokensToIds = new Map();
+
+    this.unkTokenId = undefined;
+    this.unkToken = undefined;
+    this.endOfWordSuffix = undefined;
+
+    this.fuseUnk = this.config.fuseUnk ?? false;
+  }
+
+  /**
+   * Instantiates a new TokenizerModel instance based on the configuration object provided.
+   * @param config The configuration object for the TokenizerModel.
+   * @param _args Optional arguments to pass to the specific TokenizerModel constructor.
+   * @returns A new instance of a TokenizerModel.
+   * @throws Will throw an error if the TokenizerModel type in the config is not recognized.
+   */
+  static fromConfig(config: TokenizerModelFactoryConfig, ..._args: unknown[]): TokenizerModel {
+    switch (config.type) {
+      case 'BPE':
+      default:
+        return new BPE(config);
+    }
+  }
+
+  /**
+   * Internal function to call the TokenizerModel instance.
+   * @param tokens The tokens to encode.
+   * @returns The encoded tokens.
+   */
+  protected call(...[tokens]: [string[]]): string[] {
+    tokens = this.encode(tokens);
+    if (this.fuseUnk) {
+      // Fuse unknown tokens
+      tokens = fuseUnk(tokens, this.tokensToIds, this.unkTokenId as number);
+    }
+    return tokens;
+  }
+
+  /**
+   * Encodes a list of pre-tokenized words into a list of subword tokens.
+   * @param tokens The tokens to encode.
+   * @returns The encoded tokens.
+   * @throws Will throw an error if not implemented in a subclass.
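+   *
+   * Illustrative contract (actual splits depend on the loaded vocabulary and merges):
+   * @example
+   * // bpeModel.encode(['hello', 'Ġunfamiliarword'])
+   * // -> ['hello', 'Ġun', 'familiar', 'word'] (hypothetical subword split)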
+   */
+  encode(_tokens: string[]): string[] {
+    throw Error('encode should be implemented in subclass.');
+  }
+
+  /**
+   * Converts a list of tokens into a list of token IDs.
+   * @param tokens The tokens to convert.
+   * @returns The converted token IDs.
+   */
+  convertTokensToIds(tokens: string[]): number[] {
+    return tokens.map((t) => this.tokensToIds.get(t) ?? (this.unkTokenId as number));
+  }
+
+  /**
+   * Converts a list of token IDs into a list of tokens.
+   * @param ids The token IDs to convert.
+   * @returns The converted tokens.
+   */
+  convertIdsToTokens(ids: number[] | bigint[]): string[] {
+    return ids.map((i) => this.vocab[Number(i)] ?? (this.unkToken as string));
+  }
+}
+
+/**
+ * Returns list of utf-8 byte and a mapping to unicode strings.
+ * Specifically avoids mapping to whitespace/control characters the BPE code barfs on.
+ * @returns Object with utf-8 byte keys and unicode string values.
+ */
+const BYTES_TO_UNICODE = (() => {
+  // Returns list of utf-8 byte and a mapping to unicode strings.
+  // We specifically avoid mapping to whitespace/control characters the bpe code barfs on.
+
+  const bs = [
+    ...Array.from(
+      { length: '~'.charCodeAt(0) - '!'.charCodeAt(0) + 1 },
+      (_, i) => i + '!'.charCodeAt(0)
+    ),
+    ...Array.from(
+      { length: '¬'.charCodeAt(0) - '¡'.charCodeAt(0) + 1 },
+      (_, i) => i + '¡'.charCodeAt(0)
+    ),
+    ...Array.from(
+      { length: 'ÿ'.charCodeAt(0) - '®'.charCodeAt(0) + 1 },
+      (_, i) => i + '®'.charCodeAt(0)
+    )
+  ];
+  const cs = bs.slice();
+  let n = 0;
+  for (let b = 0; b < 256; ++b) {
+    if (!bs.includes(b)) {
+      bs.push(b);
+      cs.push(256 + n);
+      n += 1;
+    }
+  }
+  const ccs = cs.map((n) => String.fromCharCode(n));
+  return Object.fromEntries(bs.map((b, i) => [b, ccs[i]]));
+})();
+
+const UNICODE_TO_BYTES = Object.fromEntries(
+  Object.entries(BYTES_TO_UNICODE).map(([key, value]) => [value, key])
+);
+
+interface BPENode {
+  token: string;
+  bias: number;
+  score?: number;
+  prev?: BPENode;
+  next?: BPENode;
+}
+
+/**
+ * BPE class for encoding text into Byte-Pair-Encoding (BPE) tokens.
+ */
+class BPE extends TokenizerModel {
+  merges!: [string, string][];
+  bpeRanks!: Map<string, number>;
+  continuingSubwordSuffix!: string | null;
+  byteFallback!: boolean;
+  textEncoder!: TextEncoder;
+  ignoreMerges!: boolean;
+  maxLengthToCache!: number;
+  cacheCapacity!: number;
+  cache!: LRUCache<string, string[]>;
+
+  constructor(config: BPEConfig) {
+    super(config);
+    this.tokensToIds = objectToMap(config.vocab);
+    this.unkTokenId = this.tokensToIds.get(config.unkToken) as number;
+    this.unkToken = config.unkToken as string;
+    this.vocab = new Array(this.tokensToIds.size);
+    for (const [key, value] of this.tokensToIds) {
+      this.vocab[value] = key;
+    }
+    const useNewMergeFormat = Array.isArray(config.merges[0]);
+    this.merges = useNewMergeFormat
+      ? (config.merges as [string, string][])
+      : (config.merges as string[]).map((x) => x.split(' ', 2) as [string, string]);
+    this.bpeRanks = new Map(this.merges.map((x, i) => [JSON.stringify(x), i])) as Map<
+      string,
+      number
+    >;
+    this.endOfWordSuffix = config.endOfWordSuffix as string | undefined;
+    this.continuingSubwordSuffix = (config.continuingSubwordSuffix ?? null) as string | null;
+    this.byteFallback = (this.config.byteFallback ?? false) as boolean;
+    if (this.byteFallback) {
+      this.textEncoder = new TextEncoder();
+    }
+    this.ignoreMerges = (this.config.ignoreMerges ?? 
false) as boolean;
+    this.maxLengthToCache = 256;
+    this.cacheCapacity = 10000;
+    this.cache = new LRUCache(this.cacheCapacity);
+  }
+  clearCache() {
+    this.cache.clear();
+  }
+  bpe(token: string): string[] {
+    if (token.length === 0) {
+      return [];
+    }
+    const cached = this.cache.get(token);
+    if (cached !== undefined) {
+      return cached;
+    }
+    const word = Array.from(token);
+    if (this.endOfWordSuffix) {
+      word[word.length - 1] += this.endOfWordSuffix;
+    }
+    let result: string[] = [];
+    if (word.length > 1) {
+      const queue = new PriorityQueue<BPENode>((a, b) => (a.score as number) < (b.score as number));
+      let startingNode: BPENode = {
+        token: word[0],
+        bias: 0,
+        prev: undefined,
+        next: undefined
+      };
+      let previousNode = startingNode;
+      for (let i = 1; i < word.length; ++i) {
+        const currentNode: BPENode = {
+          bias: i / word.length,
+          token: word[i],
+          prev: previousNode,
+          next: undefined
+        };
+        previousNode.next = currentNode;
+        this.addNode(queue, previousNode);
+        previousNode = currentNode;
+      }
+      while (!queue.isEmpty()) {
+        const node = queue.pop() as BPENode & {
+          deleted?: boolean;
+          prev?: BPENode & { deleted?: boolean };
+          next?: BPENode & { deleted?: boolean };
+        };
+        if (node.deleted || !node.next || node.next.deleted) continue;
+        node.deleted = true;
+        node.next.deleted = true;
+        if (node.prev) {
+          const newPreviousNode = { ...(node.prev as BPENode) } as BPENode;
+          node.prev.deleted = true;
+          node.prev = newPreviousNode;
+          if (newPreviousNode.prev) {
+            (newPreviousNode.prev as BPENode).next = newPreviousNode;
+          } else {
+            startingNode = newPreviousNode;
+          }
+        }
+        const merged: BPENode = {
+          token: node.token + (node.next as BPENode).token,
+          bias: node.bias,
+          prev: node.prev,
+          next: (node.next as BPENode).next
+        };
+        if (merged.prev) {
+          (merged.prev as BPENode).next = merged;
+          this.addNode(queue, merged.prev as BPENode);
+        } else {
+          startingNode = merged;
+        }
+        if (merged.next) {
+          (merged.next as BPENode).prev = merged;
+          this.addNode(queue, merged);
+        }
+      }
+      for (
+        let currentNode: BPENode | undefined = startingNode;
+        currentNode !== undefined;
+        currentNode = currentNode.next
+      ) {
+        result.push(currentNode.token);
+      }
+    } else {
+      result = word;
+    }
+    if (this.continuingSubwordSuffix) {
+      for (let i = 0; i < result.length - 1; ++i) {
+        result[i] += this.continuingSubwordSuffix;
+      }
+    }
+    if (token.length < this.maxLengthToCache) {
+      this.cache.put(token, result);
+    }
+    return result;
+  }
+  private addNode(queue: PriorityQueue<BPENode>, node: BPENode) {
+    const rank = this.bpeRanks.get(JSON.stringify([node.token, (node.next as BPENode).token]));
+    if (rank !== undefined) {
+      node.score = rank + node.bias;
+      queue.push(node);
+    }
+  }
+  encode(tokens: string[]): string[] {
+    const outputTokens: string[] = [];
+    for (const token of tokens) {
+      if (this.ignoreMerges && this.tokensToIds.has(token)) {
+        outputTokens.push(token);
+        continue;
+      }
+      const bpeTokenList = this.bpe(token);
+      for (const t of bpeTokenList) {
+        if (this.tokensToIds.has(t)) {
+          outputTokens.push(t);
+        } else if (this.byteFallback) {
+          const byteTokens = Array.from(this.textEncoder.encode(t)).map(
+            (x) => `<0x${x.toString(16).toUpperCase().padStart(2, '0')}>`
+          );
+          if (byteTokens.every((x) => this.tokensToIds.has(x))) {
+            outputTokens.push(...byteTokens);
+          } else {
+            outputTokens.push(this.unkToken as string);
+          }
+        } else {
+          outputTokens.push(this.unkToken as string);
+        }
+      }
+    }
+    return outputTokens;
+  }
+}
+
+/**
+ * A base class for text normalization.
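+ *
+ * Normalizers are built from tokenizer.json config objects, e.g. (illustrative):
+ * @example
+ * Normalizer.fromConfig({ type: 'Lowercase' }).normalize('ABC'); // 'abc'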
+ */
+abstract class Normalizer<
+  TConfig extends NormalizerConfig = NormalizerConfig
+> extends Callable<[string], string> {
+  config: TConfig;
+  /**
+   * @param config The configuration object for the normalizer.
+   */
+  constructor(config: TConfig) {
+    super();
+    this.config = config;
+  }
+  static fromConfig(config: NormalizerConfig): Normalizer {
+    switch (config.type) {
+      case 'Sequence':
+        return new NormalizerSequence(config);
+      case 'NFC':
+        return new NFC(config);
+      case 'NFD':
+        return new NFD(config);
+      case 'NFKC':
+        return new NFKC(config);
+      case 'NFKD':
+        return new NFKD(config);
+      case 'Strip':
+        return new StripNormalizer(config);
+      case 'Lowercase':
+        return new Lowercase(config);
+      case 'Prepend':
+        return new Prepend(config);
+    }
+  }
+
+  normalize(_text: string): string {
+    throw Error('normalize should be implemented in subclass.');
+  }
+
+  protected call(...[text]: [string]): string {
+    return this.normalize(text);
+  }
+}
+
+/**
+ * A normalizer that applies Unicode normalization to the input text.
+ */
+abstract class UnicodeNormalizer extends Normalizer {
+  form: 'NFC' | 'NFD' | 'NFKC' | 'NFKD' | undefined = undefined;
+
+  /**
+   * Normalize the input text by applying Unicode normalization.
+   * @param text The input text to be normalized.
+   * @returns The normalized text.
+   */
+  normalize(text: string) {
+    text = text.normalize(this.form as 'NFC');
+    return text;
+  }
+}
+
+/**
+ * A normalizer that applies Unicode normalization form C (NFC) to the input text.
+ * Canonical Decomposition, followed by Canonical Composition.
+ */
+class NFC extends UnicodeNormalizer {
+  form = 'NFC' as const;
+}
+
+/**
+ * A normalizer that applies Unicode normalization form D (NFD) to the input text.
+ * Canonical Decomposition.
+ */
+class NFD extends UnicodeNormalizer {
+  form = 'NFD' as const;
+}
+
+/**
+ * A normalizer that applies Unicode normalization form KC (NFKC) to the input text.
+ * Compatibility Decomposition, followed by Canonical Composition.
+ */
+class NFKC extends UnicodeNormalizer {
+  form = 'NFKC' as const;
+}
+
+/**
+ * A normalizer that applies Unicode normalization form KD (NFKD) to the input text.
+ * Compatibility Decomposition.
+ */
+class NFKD extends UnicodeNormalizer {
+  form = 'NFKD' as const;
+}
+
+/**
+ * A normalizer that strips leading and/or trailing whitespace from the input text.
+ */
+class StripNormalizer extends Normalizer<StripConfig> {
+  /**
+   * Strip leading and/or trailing whitespace from the input text.
+   * @param text The input text.
+   * @returns The normalized text.
+   */
+  normalize(text: string) {
+    const cfg = this.config;
+    if (cfg.stripLeft && cfg.stripRight) {
+      // Fast path to avoid an extra trim call
+      text = text.trim();
+    } else {
+      if (cfg.stripLeft) {
+        text = text.trimStart();
+      }
+      if (cfg.stripRight) {
+        text = text.trimEnd();
+      }
+    }
+    return text;
+  }
+}
+
+/**
+ * A Normalizer that lowercases the input string.
+ */
+class Lowercase extends Normalizer {
+  /**
+   * Lowercases the input string.
+   * @param text The text to normalize.
+   * @returns The normalized text.
+   */
+  normalize(text: string) {
+    text = text.toLowerCase();
+    return text;
+  }
+}
+
+/**
+ * A Normalizer that prepends a string to the input string.
+ */
+class Prepend extends Normalizer<PrependConfig> {
+  /**
+   * Prepends the input string.
+   * @param text The text to normalize.
+   * @returns The normalized text.
+   */
+  normalize(text: string) {
+    const cfg = this.config;
+    text = cfg.prepend + text;
+    return text;
+  }
+}
+
+/**
+ * A Normalizer that applies a sequence of Normalizers.
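+ *
+ * Sub-normalizers run left to right, e.g. (illustrative):
+ * @example
+ * new NormalizerSequence({
+ *   type: 'Sequence',
+ *   normalizers: [{ type: 'Strip', stripLeft: true, stripRight: true }, { type: 'Lowercase' }]
+ * }).normalize('  ABC  '); // 'abc'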
+ */ + +class NormalizerSequence extends Normalizer { + normalizers: Normalizer[]; + + constructor(config: NormalizerSequenceConfig) { + super(config); + this.normalizers = config.normalizers.map((x) => Normalizer.fromConfig(x)); + } + /** + * Apply a sequence of Normalizers to the input text. + * @param text The text to normalize. + * @returns The normalized text. + */ + normalize(text: string) { + return this.normalizers.reduce((t, normalizer) => { + return normalizer.normalize(t); + }, text); + } +} + +/** + * A callable class representing a pre-tokenizer used in tokenization. Subclasses + * should implement the `preTokenizeText` method to define the specific pre-tokenization logic. + */ +abstract class PreTokenizer extends Callable< + [string | string[], PreTokenizeOptions | undefined], + string[] +> { + /** + * Factory method that returns an instance of a subclass of `PreTokenizer` based on the provided configuration. + * + * @static + * @param config A configuration object for the pre-tokenizer. + * @returns An instance of a subclass of `PreTokenizer`. + * @throws If the provided configuration object does not correspond to any known pre-tokenizer. + */ + static fromConfig(config: PreTokenizerConfig): PreTokenizer { + switch (config.type) { + case 'Sequence': + return new PreTokenizerSequence(config); + case 'Whitespace': + return new WhitespacePreTokenizer(); + case 'WhitespaceSplit': + return new WhitespaceSplit(); + case 'Metaspace': + return new MetaspacePreTokenizer(config); + case 'ByteLevel': + return new ByteLevelPreTokenizer(config); + default: + throw new Error('Unknown PreTokenizer type'); + } + } + + /** + * Method that should be implemented by subclasses to define the specific pre-tokenization logic. + * + * @param text The text to pre-tokenize. + * @param options Additional options for the pre-tokenization logic. + * @returns The pre-tokenized text. + * @throws {Error} If the method is not implemented in the subclass. + */ + abstract preTokenizeText(text: string, options?: PreTokenizeOptions): string[]; + + /** + * Tokenizes the given text into pre-tokens. + * @param text The text or array of texts to pre-tokenize. + * @param options Additional options for the pre-tokenization logic. + * @returns An array of pre-tokens. + */ + preTokenize(text: string | string[], options?: PreTokenizeOptions): string[] { + return ( + Array.isArray(text) + ? (text as string[]).map((x) => this.preTokenizeText(x, options)) + : this.preTokenizeText(text as string, options) + ).flat(); + } + + /** + * Alias for {@link PreTokenizer#preTokenize}. + * @param text The text or array of texts to pre-tokenize. + * @param options Additional options for the pre-tokenization logic. + * @returns An array of pre-tokens. + */ + protected call( + ...[text, options]: [string | string[], PreTokenizeOptions | undefined] + ): string[] { + return this.preTokenize(text, options); + } +} + +/** + * A pre-tokenizer that splits text into Byte-Pair-Encoding (BPE) subwords. + * @extends PreTokenizer + */ + +class ByteLevelPreTokenizer extends PreTokenizer { + config: ByteLevelPreTokenizerConfig; + addPrefixSpace!: boolean; + trimOffsets!: boolean; + useRegex!: boolean; + pattern!: RegExp; + byteEncoder!: Record; + textEncoder!: TextEncoder; + constructor(config: ByteLevelPreTokenizerConfig) { + super(); + this.config = config; + this.addPrefixSpace = this.config.addPrefixSpace; + this.trimOffsets = this.config.trimOffsets; + this.useRegex = this.config.useRegex ?? 
true; + this.pattern = /'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+/gu; + this.byteEncoder = BYTES_TO_UNICODE as Record; + this.textEncoder = new TextEncoder(); + } + preTokenizeText(text: string, _options?: PreTokenizeOptions): string[] { + if (this.addPrefixSpace && !text.startsWith(' ')) { + text = ' ' + text; + } + const tokens = this.useRegex ? text.match(this.pattern) || [] : [text]; + return tokens.map((token) => + Array.from(this.textEncoder.encode(token), (byte) => this.byteEncoder[byte]).join('') + ); + } +} + +type PostProcessorArgs = [string[], (string[] | null | undefined)?, PostProcessorOptions?]; +abstract class PostProcessor extends Callable { + config: PostProcessorConfig; + constructor(config: PostProcessorConfig) { + super(); + this.config = config; + } + static fromConfig(config: PostProcessorConfig): PostProcessor { + switch (config.type) { + case 'ByteLevel': + return new ByteLevelPostProcessor(config); + case 'Sequence': + return new PostProcessorSequence(config); + default: + throw new Error('Unknown PostProcessor type'); + } + } + + abstract postProcess( + tokens: string[], + ...args: [string[] | null | undefined, PostProcessorOptions?] + ): PostProcessorResult; + + protected call( + ...[tokens, ...args]: [string[], (string[] | null | undefined)?, PostProcessorOptions?] + ): PostProcessorResult { + return this.postProcess(tokens, ...args); + } +} + +/** + * A PostProcessor that returns the given tokens as is. + */ +class ByteLevelPostProcessor extends PostProcessor { + postProcess(tokens: string[], tokensPair: string[] | null = null): { tokens: string[] } { + if (tokensPair) { + tokens = mergeArrays(tokens, tokensPair); + } + return { tokens }; + } +} + +/** + * A post-processor that applies multiple post-processors in sequence. + */ +class PostProcessorSequence extends PostProcessor { + processors: PostProcessor[]; + constructor(config: PostProcessorSequenceConfig) { + super(config); + this.processors = config.processors.map((x) => PostProcessor.fromConfig(x)); + } + postProcess( + tokens: string[], + tokensPair: string[] | null = null, + options: PostProcessorOptions = {} + ): { tokens: string[]; tokenTypeIds?: number[] } { + let tokenTypeIds: number[] | undefined; + for (const processor of this.processors) { + if (processor instanceof ByteLevelPostProcessor) { + const output = processor.postProcess(tokens); + tokens = output.tokens; + if (tokensPair) { + const pairOutput = processor.postProcess(tokensPair); + tokensPair = pairOutput.tokens; + } + } else { + const output = processor.postProcess(tokens, tokensPair ?? null, options); + tokens = output.tokens; + if (output.tokenTypeIds) { + tokenTypeIds = output.tokenTypeIds; + } + } + } + return { tokens, tokenTypeIds: tokenTypeIds }; + } +} + +/** + * The base class for token decoders. 
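+ *
+ * A minimal sketch of the factory plus decode flow, using the 'Fuse' variant
+ * defined below (the config literal here is an assumption):
+ * @example
+ * const decoder = Decoder.fromConfig({ type: 'Fuse' });
+ * decoder.decode(['Hel', 'lo']); // => 'Hello'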
+ */ +abstract class Decoder extends Callable< + [string[]], + string +> { + config: TConfig; + addedTokens: AddedToken[]; + endOfWordSuffix?: string; + constructor(config: TConfig) { + super(); + this.config = config; + this.addedTokens = []; + this.endOfWordSuffix = undefined; + } + static fromConfig(config: DecoderConfig): Decoder { + switch (config.type) { + case 'ByteLevel': + return new ByteLevelDecoder(config); + case 'ByteFallback': + return new ByteFallback(config); + case 'Fuse': + return new FuseDecoder(config); + case 'Strip': + return new StripDecoder(config); + case 'Sequence': + return new DecoderSequence(config); + case 'BPEDecoder': + return new BPEDecoder(config); + default: + throw new Error('Unknown Decoder type'); + } + } + protected call(...[tokens]: [string[]]): string { + return this.decode(tokens); + } + decode(tokens: string[]): string { + return this.decodeChain(tokens).join(''); + } + abstract decodeChain(tokens: string[]): string[]; +} + +class ByteFallback extends Decoder { + textDecoder!: TextDecoder; + constructor(config: ByteFallbackConfig) { + super(config); + this.textDecoder = new TextDecoder(); + } + decodeChain(tokens: string[]): string[] { + const newTokens: string[] = []; + let previousByteTokens: number[] = []; + for (const token of tokens) { + let bytes: number | null = null; + if (token.length === 6 && token.startsWith('<0x') && token.endsWith('>')) { + const byte = parseInt(token.slice(3, 5), 16); + if (!isNaN(byte)) { + bytes = byte; + } + } + if (bytes !== null) { + previousByteTokens.push(bytes); + } else { + if (previousByteTokens.length > 0) { + const string = this.textDecoder.decode(Uint8Array.from(previousByteTokens)); + newTokens.push(string); + previousByteTokens = []; + } + newTokens.push(token); + } + } + if (previousByteTokens.length > 0) { + const string = this.textDecoder.decode(Uint8Array.from(previousByteTokens)); + newTokens.push(string); + previousByteTokens = []; + } + return newTokens; + } +} + +/** + * Fuse simply fuses all tokens into one big string. + * It's usually the last decoding step anyway, but this decoder + * exists incase some decoders need to happen after that step + */ +class FuseDecoder extends Decoder { + /** @type {Decoder['decodeChain']} */ + decodeChain(tokens: string[]): string[] { + return [tokens.join('')]; + } +} + +class StripDecoder extends Decoder { + content!: string; + start!: number; + stop!: number; + constructor(config: StripDecoderConfig) { + super(config); + const cfg = this.config; + this.content = cfg.content; + this.start = cfg.start; + this.stop = cfg.stop; + } + /** @type {Decoder['decodeChain']} */ + decodeChain(tokens: string[]): string[] { + return tokens.map((token) => { + let startCut = 0; + for (let i = 0; i < this.start; ++i) { + if (token[i] === this.content) { + startCut = i + 1; + continue; + } else { + break; + } + } + let stopCut = token.length; + for (let i = 0; i < this.stop; ++i) { + const index = token.length - i - 1; + if (token[index] === this.content) { + stopCut = index; + continue; + } else { + break; + } + } + return token.slice(startCut, stopCut); + }); + } +} + +/** + * Byte-level decoder for tokenization output. Inherits from the `Decoder` class. 
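+ *
+ * Each character of a token is mapped back through UNICODE_TO_BYTES to the raw
+ * byte it represents before UTF-8 decoding. Illustrative sketch, assuming the
+ * standard GPT-2 byte-to-unicode table in which 'Ġ' stands for the space byte:
+ * @example
+ * decoder.decode(['Hello', 'Ġworld']); // => 'Hello world'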
+ * @extends Decoder + */ +class ByteLevelDecoder extends Decoder { + byteDecoder!: Record; + textDecoder!: TextDecoder; + constructor(config: ByteLevelDecoderConfig) { + super(config); + this.byteDecoder = UNICODE_TO_BYTES as unknown as Record; + this.textDecoder = new TextDecoder('utf-8', { fatal: false, ignoreBOM: true }); + this.endOfWordSuffix = undefined; + } + convertTokensToString(tokens: string[]): string { + const text = tokens.join(''); + const byteArray = new Uint8Array([...text].map((c) => this.byteDecoder[c])); + const decodedText = this.textDecoder.decode(byteArray); + return decodedText; + } + /** @type {Decoder['decodeChain']} */ + decodeChain(tokens: string[]): string[] { + const subTexts: string[] = []; + let currentSubText: string[] = []; + for (const token of tokens) { + if (this.addedTokens.find((x) => x.content === token) !== undefined) { + if (currentSubText.length > 0) { + subTexts.push(this.convertTokensToString(currentSubText)); + currentSubText = []; + } + subTexts.push(token); + } else { + currentSubText.push(token); + } + } + if (currentSubText.length > 0) { + subTexts.push(this.convertTokensToString(currentSubText)); + } + return subTexts; + } +} + +/** + * Apply a sequence of decoders. + * @extends Decoder + */ +class DecoderSequence extends Decoder { + decoders!: Decoder[]; + constructor(config: DecoderSequenceConfig) { + super(config); + this.decoders = config.decoders.map((x) => Decoder.fromConfig(x)); + } + /** @type {Decoder['decodeChain']} */ + decodeChain(tokens: string[]): string[] { + return this.decoders.reduce((toks: string[], decoder: Decoder) => { + return decoder.decodeChain(toks); + }, tokens); + } +} + +class BPEDecoder extends Decoder { + suffix!: string; + constructor(config: BPEDecoderConfig) { + super(config); + const cfg = this.config; + this.suffix = cfg.suffix; + } + /** @type {Decoder['decodeChain']} */ + decodeChain(tokens: string[]): string[] { + return tokens.map((token, i) => { + return token.replaceAll(this.suffix, i === tokens.length - 1 ? '' : ' '); + }); + } +} + +/** + * This PreTokenizer replaces spaces with the given replacement character, adds a prefix space if requested, + * and returns a list of tokens. + * @extends PreTokenizer + */ +class MetaspacePreTokenizer extends PreTokenizer { + addPrefixSpace: boolean; + replacement: string; + strRep: string; + prependScheme: 'first' | 'never' | 'always'; + /** + * @param {Object} config The configuration object for the MetaspacePreTokenizer. + * @param {boolean} config.addPrefixSpace Whether to add a prefix space to the first token. + * @param {string} config.replacement The character to replace spaces with. + * @param {string} [config.strRep=config.replacement] An optional string representation of the replacement character. + * @param {'first'|'never'|'always'} [config.prependScheme='always'] The metaspace prepending scheme. + */ + constructor(config: MetaspacePreTokenizerConfig) { + super(); + + this.addPrefixSpace = config.addPrefixSpace; + this.replacement = config.replacement; + this.strRep = config.strRep || this.replacement; + this.prependScheme = config.prependScheme ?? 'always'; + } + + /** + * This method takes a string, replaces spaces with the replacement character, + * adds a prefix space if requested, and returns a new list of tokens. + * @param text The text to pre-tokenize. + * @param options The options for the pre-tokenization. + * @param options.sectionIndex The index of the section to pre-tokenize. + * @returns A new list of pre-tokenized tokens. 
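+ * @example
+ * // Illustrative, assuming a pre-tokenizer built with
+ * // { addPrefixSpace: true, replacement: '▁' } and the default 'always' scheme:
+ * pre.preTokenizeText('Hello world'); // => ['▁Hello▁world']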
+ */ + preTokenizeText( + text: string, + { sectionIndex: sectionIndex = undefined }: PreTokenizeOptions = {} + ) { + let normalized = text.replaceAll(' ', this.strRep); + + if ( + // We add a prefix space if: + // (1) The addPrefixSpace option is enabled and the normalized token does not already start + // with the replacement character. + this.addPrefixSpace && + !normalized.startsWith(this.replacement) && + // and (2) either: + // (a) prependScheme is 'always' + // (b) prependScheme is 'first' and this is the first section + (this.prependScheme === 'always' || (this.prependScheme === 'first' && sectionIndex === 0)) + ) { + normalized = this.strRep + normalized; + } + return [normalized]; + } +} + +/** + * A pre-tokenizer that applies a sequence of pre-tokenizers to the input text. + * @extends PreTokenizer + */ +class PreTokenizerSequence extends PreTokenizer { + tokenizers: PreTokenizer[]; + /** + * Creates an instance of PreTokenizerSequence. + * @param {Object} config The configuration object for the pre-tokenizer sequence. + * @param {Object[]} config.pretokenizers An array of pre-tokenizer configurations. + */ + constructor(config: PreTokenizerSequenceConfig) { + super(); + this.tokenizers = config.pretokenizers.map((x) => PreTokenizer.fromConfig(x)); + } + + /** + * Applies each pre-tokenizer in the sequence to the input text in turn. + * @param text The text to pre-tokenize. + * @param options Additional options for the pre-tokenization logic. + * @returns The pre-tokenized text. + */ + preTokenizeText(text: string, options: PreTokenizeOptions) { + // Use reduce to apply each tokenizer to the text + return this.tokenizers.reduce( + (preTokenizedText, tokenizer) => { + return tokenizer.preTokenize(preTokenizedText, options); + }, + [text] + ); + } +} + +/** + * Splits on word boundaries (using the following regular expression: `\w+|[^\w\s]+`). + */ +class WhitespacePreTokenizer extends PreTokenizer { + /** + * Creates an instance of WhitespacePreTokenizer. + * @param config The configuration object for the pre-tokenizer. + */ + constructor() { + super(); + } + /** + * Pre-tokenizes the input text by splitting it on word boundaries. + * @param text The text to be pre-tokenized. + * @param options Additional options for the pre-tokenization logic. + * @returns An array of tokens produced by splitting the input text on whitespace. + */ + preTokenizeText(text: string, _options: unknown) { + return text.match(/\w+|[^\w\s]+/g) || []; + } +} + +/** + * Splits a string of text by whitespace characters into individual tokens. + * @extends PreTokenizer + */ +class WhitespaceSplit extends PreTokenizer { + /** + * Creates an instance of WhitespaceSplit. + * @param config The configuration object for the pre-tokenizer. + */ + constructor() { + super(); + } + /** + * Pre-tokenizes the input text by splitting it on whitespace characters. + * @param text The text to be pre-tokenized. + * @param options Additional options for the pre-tokenization logic. + * @returns An array of tokens produced by splitting the input text on whitespace. + */ + preTokenizeText(text: string, _options: unknown) { + return whitespaceSplit(text); + } +} + +const SPECIAL_TOKEN_ATTRIBUTES = [ + 'bos_token', + 'eos_token', + 'unk_token', + 'sep_token', + 'pad_token', + 'cls_token', + 'mask_token' + // additional_special_tokens (TODO) +]; + +/** + * + * Helper function for padding values of an object, which are each arrays. + * NOTE: No additional checks are made here for validity of arguments. 
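+ *
+ * Illustrative sketch of the in-place mutation (a pad value of 0 is assumed for
+ * every key here):
+ * @example
+ * const item = { inputIds: [5, 6], attentionMask: [1, 1] };
+ * padHelper(item, 4, (_key) => 0, 'right');
+ * // item => { inputIds: [5, 6, 0, 0], attentionMask: [1, 1, 0, 0] }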
+ * @param item The input object. + * @param length The length to pad to. + * @param valueFn Determine the value to fill the array, based on its key. + * @param side Which side to pad the array. + */ +function padHelper( + item: Record, + length: number, + valueFn: (key: string) => T, + side: 'right' | 'left' +) { + for (const key of Object.keys(item)) { + const diff = length - item[key].length; + const value = valueFn(key); + + const padData = new Array(diff).fill(value); + item[key] = + side === 'right' ? mergeArrays(item[key], padData) : mergeArrays(padData, item[key]); + } +} + +/** + * Helper function for truncating values of an object, which are each arrays. + * NOTE: No additional checks are made here for validity of arguments. + * @param item The input object. + * @param length The length to truncate to. + */ +function truncateHelper(item: Record, length: number) { + // Setting .length to a lower value truncates the array in-place: + // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/length + for (const key of Object.keys(item)) { + item[key].length = length; + } +} + +interface DecodeArgs { + skipSpecialTokens?: boolean; + cleanUpTokenizationSpaces?: boolean; +} + +type BatchEncodingItem = number[] | number[][] | Tensor; + +interface BatchEncoding { + inputIds: BatchEncodingItem; + attentionMask: BatchEncodingItem; + tokenTypeIds?: BatchEncodingItem; +} + +interface Message { + role: string; + content: string; +} + +export class PreTrainedTokenizer extends Callable< + [ + string | string[], + { + textPair?: string | null; + addSpecialTokens?: boolean; + padding?: boolean | 'max_length'; + truncation?: boolean | null; + maxLength?: number | null; + returnTensor?: boolean; + returnTokenTypeIds?: boolean | null; + }? + ], + BatchEncoding +> { + config: TokenizerConfig; + normalizer!: ((text: string) => string) | Normalizer | null; + preTokenizer!: ((text: string, options?: PreTokenizeOptions) => string[]) | PreTokenizer | null; + model!: TokenizerModel; + postProcessor!: + | (( + tokens: string[], + tokensPair?: string[] | null, + options?: PostProcessorOptions + ) => PostProcessorResult) + | PostProcessor + | null; + decoder!: ((tokens: string[]) => string) | Decoder | null; + specialTokens: string[]; + allSpecialIds: number[]; + addedTokens: AddedToken[]; + additionalSpecialTokens: string[]; + addedTokensSplitter: DictionarySplitter; + addedTokensMap: Map; + maskToken?: string | null; + maskTokenId?: number; + padToken?: string | null; + padTokenId?: number; + sepToken?: string | null; + sepTokenId?: number; + unkToken?: string | null; + unkTokenId?: number; + bosToken?: string | null; + bosTokenId?: number; + eosToken?: string | null; + eosTokenId?: number; + modelMaxLength!: number; + removeSpace!: boolean; + cleanUpTokenizationSpaces!: boolean; + paddingSide: 'left' | 'right' = 'right'; + addBoxToken?: boolean; + addEosToken?: boolean; + chatTemplate: null | Record | Array<{ name: string; template: string }>; + returnTokenTypeIds = false; + private compiledTemplateCache: Map; + constructor(tokenizerJSON: TokenizerJSON, tokenizerConfig: TokenizerConfig) { + super(); + this.config = tokenizerConfig; + this.normalizer = tokenizerJSON.normalizer + ? Normalizer.fromConfig(tokenizerJSON.normalizer) + : null; + this.preTokenizer = tokenizerJSON.preTokenizer + ? 
PreTokenizer.fromConfig(tokenizerJSON.preTokenizer) + : null; + this.model = TokenizerModel.fromConfig(tokenizerJSON.model, tokenizerConfig); + this.postProcessor = tokenizerJSON.postProcessor + ? PostProcessor.fromConfig(tokenizerJSON.postProcessor) + : null; + this.decoder = tokenizerJSON.decoder ? Decoder.fromConfig(tokenizerJSON.decoder) : null; + this.specialTokens = []; + this.allSpecialIds = []; + this.addedTokens = []; + for (const addedToken of tokenizerJSON.addedTokens) { + const token = new AddedToken(addedToken); + this.addedTokens.push(token); + this.model.tokensToIds.set(token.content, token.id); + this.model.vocab[token.id] = token.content; + if (token.special) { + this.specialTokens.push(token.content); + this.allSpecialIds.push(token.id); + } + } + this.additionalSpecialTokens = tokenizerConfig.additionalSpecialTokens ?? []; + this.specialTokens.push(...this.additionalSpecialTokens); + this.specialTokens = [...new Set(this.specialTokens)]; + if (this.decoder) { + (this.decoder as Decoder).addedTokens = this.addedTokens; + (this.decoder as Decoder).endOfWordSuffix = this.model.endOfWordSuffix; + } + this.addedTokensSplitter = new DictionarySplitter(this.addedTokens.map((x) => x.content)); + this.addedTokensMap = new Map(this.addedTokens.map((x) => [x.content, x])); + this.maskToken = this.getToken('mask_token'); + this.maskTokenId = this.model.tokensToIds.get(this.maskToken as string); + this.padToken = this.getToken('pad_token', 'eos_token'); + this.padTokenId = this.model.tokensToIds.get(this.padToken as string); + this.sepToken = this.getToken('sep_token'); + this.sepTokenId = this.model.tokensToIds.get(this.sepToken as string); + this.unkToken = this.getToken('unk_token'); + this.unkTokenId = this.model.tokensToIds.get(this.unkToken as string); + this.bosToken = this.getToken('bos_token'); + this.bosTokenId = this.model.tokensToIds.get(this.bosToken as string); + this.eosToken = this.getToken('eos_token'); + this.eosTokenId = this.model.tokensToIds.get(this.eosToken as string); + this.modelMaxLength = tokenizerConfig.modelMaxLength as number; + this.removeSpace = tokenizerConfig.removeSpace as boolean; + this.cleanUpTokenizationSpaces = (tokenizerConfig.cleanUpTokenizationSpaces ?? true) as boolean; + if (tokenizerConfig.paddingSide) { + this.paddingSide = tokenizerConfig.paddingSide as 'left' | 'right'; + } + this.addBoxToken = tokenizerConfig.addBosToken as boolean; + this.addEosToken = tokenizerConfig.addEosToken as boolean; + this.chatTemplate = tokenizerConfig.chatTemplate ?? 
null; + if (Array.isArray(this.chatTemplate)) { + const chatTemplate: Record = Object.create(null); + for (const { name, template } of this.chatTemplate) { + if (typeof name !== 'string' || typeof template !== 'string') { + throw new Error( + 'Chat template must be a list of objects with "name" and "template" properties' + ); + } + chatTemplate[name] = template; + } + this.chatTemplate = chatTemplate; + } + this.compiledTemplateCache = new Map(); + } + getToken(...keys: string[]): string | null { + for (const key of keys) { + const item = this.config[key]; + if (!item) continue; + if (typeof item === 'object') { + const maybe = item as { type?: string; content?: string }; + if (maybe.type === 'AddedToken' && typeof maybe.content === 'string') { + return maybe.content; + } + throw Error(`Unknown token: ${String(item)}`); + } else { + return item as string; + } + } + return null; + } + + static async fromPretrained(tokenizerName: string): Promise { + const info = await loadTokenizer(tokenizerName); + return new this(...info); + } + + /** + * Encode/tokenize the given text(s). + * @param text The text to tokenize. + * @param options An optional object containing the following properties: + * @param options.textPair A second sequence to be encoded with the first. + * @param options.padding Whether to pad the input sequences. + * @param options.addSpecialTokens Whether or not to add the special tokens associated with the corresponding model. + * @param options.truncation Whether to truncate the input sequences. + * @param options.maxLength Maximum length of the returned list and optionally padding length. + * @param options.returnTensor Whether to return the results as Tensors or arrays. + * @param options.returnTokenTypeIds Whether to return the token type ids. + * @returns Object to be passed to the model. + */ + protected call( + text: string | string[], + { + textPair = null, + addSpecialTokens = true, + padding = false, + truncation = null, + maxLength = null, + returnTensor = true, + returnTokenTypeIds = null + }: { + textPair?: string | null; + addSpecialTokens?: boolean; + padding?: boolean | 'max_length'; + truncation?: boolean | null; + maxLength?: number | null; + returnTensor?: boolean; + returnTokenTypeIds?: boolean | null; + } = {} + ): BatchEncoding { + const isBatched = Array.isArray(text); + + let encodedTokens; + + if (isBatched) { + if (text.length === 0) { + throw Error('text array must be non-empty'); + } + encodedTokens = text.map((x) => + this.encodePlus(x, { + addSpecialTokens: addSpecialTokens, + returnTokenTypeIds: returnTokenTypeIds + }) + ); + } else { + if (text === null || text === undefined) { + throw Error('text may not be null or undefined'); + } + + if (Array.isArray(textPair)) { + throw Error( + 'When specifying `textPair`, since `text` is a string, `textPair` must also be a string (i.e., not an array).' + ); + } + + // For single input, we just wrap in an array, and then unwrap later. + encodedTokens = [ + this.encodePlus(text, { + addSpecialTokens: addSpecialTokens, + returnTokenTypeIds: returnTokenTypeIds + }) + ]; + } + // At this point, `encodedTokens` is batched, of shape [batchSize, tokens]. + // However, the array may be jagged, so we may need to pad to maxLength. + if (maxLength === null) { + maxLength = this.modelMaxLength; + } else if (truncation === null) { + if (padding === true) { + console.warn( + '`maxLength` is ignored when `padding: true` and there is no truncation strategy. ' + + "To pad to max length, use `padding: 'max_length'`."
+ ); + maxLength = this.modelMaxLength; + } else if (padding === false) { + console.warn( + 'Truncation was not explicitly activated but `maxLength` is provided a specific value, please use `truncation: true` to explicitly truncate examples to max length.' + ); + truncation = true; + } + } + + // padding: 'max_length' doesn't require any additional calculation, + // but padding: true has to calculate maxLength from the sequences + if (padding === true) { + maxLength = Math.min( + max(encodedTokens.map((x) => x.inputIds.length))[0], + maxLength ?? Infinity + ); + } + + // Ensure it does not exceed the model max length + maxLength = Math.min(maxLength, this.modelMaxLength ?? Infinity); + + if (padding || truncation) { + // Perform padding and/or truncation + for (let i = 0; i < encodedTokens.length; ++i) { + if (encodedTokens[i].inputIds.length === maxLength) { + continue; + } else if (encodedTokens[i].inputIds.length > maxLength) { + // possibly truncate + if (truncation) { + truncateHelper(encodedTokens[i], maxLength); + } + } else { + // t.length < maxLength + // possibly pad + if (padding) { + padHelper( + encodedTokens[i], + maxLength, + (key) => (key === 'inputIds' ? this.padTokenId : 0), + this.paddingSide + ); + } + } + } + } + + const result: Record = {}; + + if (returnTensor) { + if (!(padding && truncation)) { + // Not guaranteed that all items have the same length, so + // we perform an additional check + + if ( + encodedTokens.some((x) => { + for (const key of Object.keys(x)) { + if ( + (x as Record)[key].length !== + (encodedTokens[0] as Record)[key]?.length + ) { + return true; + } + } + return false; + }) + ) { + throw Error( + 'Unable to create tensor, you should probably activate truncation and/or padding ' + + "with 'padding=true' and 'truncation=true' to have batched tensors with the same length." + ); + } + } + + // Now we actually convert to tensor + // NOTE: In the same way as the python library, we return a batched tensor, regardless of + // whether we have a single input or multiple inputs. + const dims = [encodedTokens.length, encodedTokens[0].inputIds.length]; + + for (const key of Object.keys(encodedTokens[0])) { + result[key] = tensor( + Int32Array.from( + encodedTokens + .flatMap( + (x) => + (x as Record)[key] as (bigint | boolean | number | string)[] + ) + .map(Number) + ), + { shape: dims, dtype: int32 } + ); + } + } else { + for (const key of Object.keys(encodedTokens[0])) { + result[key] = encodedTokens.map((x) => (x as Record)[key]); + } + + // If not returning a tensor, we match the input type + if (!isBatched) { + // Input was not batched, so we unwrap + for (const key of Object.keys(result)) { + result[key] = (result[key] as unknown[])[0]; + } + } + } + + return result as unknown as BatchEncoding; + } + + /** + * Encodes a single text using the preprocessor pipeline of the tokenizer. + * + * @param {string|null} text The text to encode. + * @returns {string[]|null} The encoded tokens. + */ + private encodeText(text: string | null): string[] | null { + if (text === null) return null; + + // Actual function which does encoding, for a single text + // First, we take care of special tokens. 
Needed to avoid issues arising from + // normalization and/or pretokenization (which may not preserve special tokens) + const sections = this.addedTokensSplitter.split(text); + + // Process left/right stripping of added tokens + for (let i = 0; i < sections.length; ++i) { + const addedToken = this.addedTokensMap.get(sections[i]); + if (addedToken) { + if (addedToken.lstrip && i > 0) { + sections[i - 1] = sections[i - 1].trimEnd(); + } + if (addedToken.rstrip && i < sections.length - 1) { + sections[i + 1] = sections[i + 1].trimStart(); + } + } + } + + const tokens = sections.flatMap((x, sectionIndex) => { + if (x.length === 0) return []; + if (this.addedTokensMap.has(x)) return [x]; // Return added tokens unchanged + + if (this.removeSpace === true) { + x = x.trim().split(/\s+/).join(' '); + } + + if (this.normalizer !== null) { + x = this.normalizer(x); + } + + // If, after normalization, this section is empty (e.g., trimming whitespace), + // we return an empty array + if (x.length === 0) { + return []; + } + + const sectionTokens = + this.preTokenizer !== null + ? this.preTokenizer(x, { + sectionIndex: sectionIndex + }) + : [x]; + + const tokens = this.model(sectionTokens); + + return tokens; + }); + + return tokens; + } + + /** + * Encodes a single text or a pair of texts using the model's tokenizer. + * + * @param text The text to encode. + * @param options An optional object containing the following properties: + * @param options.textPair The optional second text to encode. + * @param options.addSpecialTokens Whether or not to add the special tokens associated with the corresponding model. + * @param options.returnTokenTypeIds Whether to return tokenTypeIds. + * @returns An object containing the encoded text. + */ + private encodePlus( + text: string, + { + textPair = null, + addSpecialTokens = true, + returnTokenTypeIds = null + }: { + textPair?: string | null; + addSpecialTokens?: boolean; + returnTokenTypeIds?: boolean | null; + } = {} + ) { + const { tokens, tokenTypeIds } = this.tokenizeHelper(text, { + pair: textPair, + addSpecialTokens + }); + + const inputIds = this.model.convertTokensToIds(tokens); + + const result = { + inputIds: inputIds, + attentionMask: new Array(inputIds.length).fill(1) + }; + if ((returnTokenTypeIds ?? this.returnTokenTypeIds) && tokenTypeIds) { + (result as { tokenTypeIds?: number[] }).tokenTypeIds = tokenTypeIds; + } + return result; + } + + /** + * Internal helper function to tokenize a text, and optionally a pair of texts. + * @param text The text to tokenize. + * @param options An optional object containing the following properties: + * @param options.pair The optional second text to tokenize. + * @param options.addSpecialTokens Whether or not to add the special tokens associated with the corresponding model. + * @returns An object containing the tokens and optionally the token type IDs. + */ + private tokenizeHelper( + text: string, + { + pair = null, + addSpecialTokens = false + }: { pair?: string | null; addSpecialTokens?: boolean } = {} + ) { + const tokens = this.encodeText(text); + const tokens2 = this.encodeText(pair); + + return this.postProcessor + ? this.postProcessor(tokens ?? [], tokens2 ?? null, { addSpecialTokens }) + : { tokens: mergeArrays(tokens ?? [], tokens2 ?? []) }; + } + + /** + * Converts a string into a sequence of tokens. + * @param text The sequence to be encoded. + * @param options An optional object containing the following properties: + * @param options.pair A second sequence to be encoded with the first. 
+ * @param options.addSpecialTokens Whether or not to add the special tokens associated with the corresponding model. + * @returns The list of tokens. + */ + tokenize(text: string, { pair = null, addSpecialTokens = false } = {}) { + return this.tokenizeHelper(text, { pair, addSpecialTokens }).tokens; + } + + /** + * Encodes a single text or a pair of texts using the model's tokenizer. + * + * @param text The text to encode. + * @param options An optional object containing the following properties: + * @param options.addSpecialTokens Whether or not to add the special tokens associated with the corresponding model. + * @param options.returnTokenTypeIds Whether to return tokenTypeIds. + * @returns An array of token IDs representing the encoded text(s). + */ + encode(text: string, { addSpecialTokens = true, returnTokenTypeIds = null } = {}) { + return this.encodePlus(text, { + addSpecialTokens, + returnTokenTypeIds + }).inputIds; + } + + /** + * Decode a batch of tokenized sequences. + * @param batch List of tokenized input sequences. + * @param decodeArgs (Optional) Object with decoding arguments. + * @returns List of decoded sequences. + */ + batchDecode(batch: number[][], decodeArgs: DecodeArgs = {}) { + return batch.map((x) => this.decode(x, decodeArgs)); + } + + /** + * Decodes a sequence of token IDs back to a string. + * + * @param tokenIds List of token IDs to decode. + * @param decodeArgs (Optional) Object with decoding arguments. + * + * @returns The decoded string. + * @throws If `tokenIds` is not a non-empty array of integers. + */ + decode(tokenIds: number[], decodeArgs: DecodeArgs = {}) { + if ( + !Array.isArray(tokenIds) || + tokenIds.length === 0 || + !(Number.isInteger(tokenIds[0]) || typeof tokenIds[0] === 'bigint') + ) { + throw Error('tokenIds must be a non-empty array of integers.'); + } + + return this.decodeSingle(tokenIds, decodeArgs); + } + + /** + * Decode a single list of token ids to a string. + * @param tokenIds List of token ids to decode + * @param decodeArgs Optional arguments for decoding + * @param [decodeArgs.skipSpecialTokens=false] Whether to skip special tokens during decoding + * @param [decodeArgs.cleanUpTokenizationSpaces=null] Whether to clean up tokenization spaces + * during decoding. If null, the value is set to `this.decoder.cleanup` if it exists, falling + * back to `this.cleanUpTokenizationSpaces` if it exists, falling back to `true`. + * @returns The decoded string + */ + decodeSingle(tokenIds: number[], { skipSpecialTokens = false }: DecodeArgs = {}) { + let tokens = this.model.convertIdsToTokens(tokenIds); + if (skipSpecialTokens) { + tokens = tokens.filter((x) => !this.specialTokens.includes(x)); + } + + // If `this.decoder` is null, we just join tokens with a space: + // https://github.com/huggingface/tokenizers/blob/8edec536a737cb04494b454805be16c020abb14f/tokenizers/src/tokenizer/mod.rs#L835 + let decoded = this.decoder ? this.decoder(tokens) : tokens.join(' '); + + // Slight hack, but prevents having to pass `skipSpecialTokens` to each call to `decode`, which + // would lead to code duplication. + if (this.decoder && 'endOfWordSuffix' in this.decoder && this.decoder.endOfWordSuffix) { + decoded = decoded.replaceAll(this.decoder.endOfWordSuffix, ' '); + if (skipSpecialTokens) { + decoded = decoded.trim(); + } + } + + return decoded; + } + + /** + * Retrieve the chat template string used for tokenizing chat messages. 
This template is used + * internally by the `applyChatTemplate` method and can also be used externally to retrieve the + * model's chat template for better generation tracking. + * + * @param options An optional object containing the following properties: + * @param options.chatTemplate A Jinja template or the name of a template to use for this + * conversion. It is usually not necessary to pass anything to this argument, as the model's + * template will be used by default. + * @param options.tools A list of tools (callable functions) that will be accessible to the model. + * If the template does not support function calling, this argument will have no effect. Each + * tool should be passed as a JSON Schema, giving the name, description and argument types for + * the tool. See our + * [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) + * for more information. + * @returns The chat template string. + */ + getChatTemplate({ + chatTemplate = null, + tools = null + }: { chatTemplate?: string | null; tools?: string[] | null } = {}): string { + // First, handle the cases when the model has a dict of multiple templates + if (this.chatTemplate && typeof this.chatTemplate === 'object') { + const templateDict = this.chatTemplate; + + if (chatTemplate !== null && Object.hasOwn(templateDict, chatTemplate)) { + // The user can pass the name of a template to the chat template argument instead of an + // entire template + chatTemplate = (templateDict as Record)[chatTemplate]; + } else if (chatTemplate === null) { + if (tools !== null && 'toolUse' in templateDict) { + chatTemplate = templateDict['toolUse']; + } else if ('default' in templateDict) { + chatTemplate = templateDict['default']; + } else { + throw Error( + `This model has multiple chat templates with no default specified! Please either pass` + + ` a chat template or the name of the template you wish to use to the 'chatTemplate'` + + ` argument. Available template names are ${Object.keys(templateDict).sort()}.` + ); + } + } + } else if (chatTemplate === null) { + // These are the cases when the model has a single template + // priority: `chatTemplate` argument > `tokenizer.chatTemplate` + if (this.chatTemplate) { + chatTemplate = this.chatTemplate; + } else { + throw Error( + 'Cannot use applyChatTemplate() because tokenizer.chatTemplate is not set and no template ' + + 'argument was passed! For information about writing templates and setting the ' + + 'tokenizer.chatTemplate attribute, please see the documentation at ' + + 'https://huggingface.co/docs/transformers/main/en/chat_templating' + ); + } + } + return chatTemplate; + } + + /** + * Converts a list of message objects with `"role"` and `"content"` keys to a list of token + * ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to + * determine the format and control tokens to use when converting. + * + * See [here](https://huggingface.co/docs/transformers/chat_templating) for more information. + * + * @param conversation A list of message objects with `"role"` and `"content"` keys, + * representing the chat history so far. + * @param options An optional object containing the following properties: + * @param options.chatTemplate A Jinja template to use for this conversion. If + * this is not passed, the model's chat template will be used instead. + * @param options.tools A list of tools (callable functions) that will be accessible to the model. 
+ * If the template does not support function calling, this argument will have no effect. Each + * tool should be passed as a JSON Schema, giving the name, description and argument types for + * the tool. See our + * [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) + * for more information. + * @param options.documents A list of dicts representing documents that will be accessible to the model if it is performing RAG + * (retrieval-augmented generation). If the template does not support RAG, this argument will have no + * effect. We recommend that each document should be a dict containing "title" and "text" keys. Please + * see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG) + * for examples of passing documents with chat templates. + * @param options.addGenerationPrompt Whether to end the prompt with the token(s) that indicate + * the start of an assistant message. This is useful when you want to generate a response from the + * model. Note that this argument will be passed to the chat template, and so it must be supported + * in the template for this argument to have any effect. + * @param options.tokenize Whether to tokenize the output. If false, the output will be a string. + * @param options.padding Whether to pad sequences to the maximum length. Has no effect if tokenize is false. + * @param options.truncation Whether to truncate sequences to the maximum length. Has no effect if tokenize is false. + * @param options.maxLength Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is false. + * If not specified, the tokenizer's `max_length` attribute will be used as a default. + * @param options.returnTensor Whether to return the output as a Tensor or an Array. Has no effect if tokenize is false. + * @param options.returnDict Whether to return a dictionary with named outputs. Has no effect if tokenize is false. + * @param options.tokenizerKwargs Additional options to pass to the tokenizer. + * @returns The tokenized output. 
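+ * @example
+ * // A minimal sketch; the rendered control tokens depend entirely on the
+ * // model's chat template, so the exact output varies per tokenizer:
+ * const ids = tokenizer.applyChatTemplate(
+ *   [{ role: 'user', content: 'Hi!' }],
+ *   { addGenerationPrompt: true, returnTensor: false }
+ * );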
+ */ + applyChatTemplate( + conversation: Message[], + { + tools = null, + documents = null, + chatTemplate = null, + addGenerationPrompt = false, + tokenize = true, + padding = false, + truncation = false, + maxLength = null, + returnTensor = true, + returnDict = false, + tokenizerKwargs = {}, + ...kwargs + }: { + tools?: string[] | null; + documents?: string[] | null; + chatTemplate?: string | null; + addGenerationPrompt?: boolean; + tokenize?: boolean; + padding?: boolean; + truncation?: boolean; + maxLength?: number | null; + returnTensor?: boolean; + returnDict?: boolean; + tokenizerKwargs?: Record; + } = {} + ) { + chatTemplate = this.getChatTemplate({ chatTemplate, tools }); + + if (typeof chatTemplate !== 'string') { + throw Error(`chat_template must be a string, but got ${typeof chatTemplate}`); + } + + // Compilation function uses a cache to avoid recompiling the same template + let compiledTemplate = this.compiledTemplateCache.get(chatTemplate); + if (compiledTemplate === undefined) { + compiledTemplate = new Template(chatTemplate); + this.compiledTemplateCache.set(chatTemplate, compiledTemplate); + } + + const specialTokensMap = Object.create(null); + for (const key of SPECIAL_TOKEN_ATTRIBUTES) { + const value = this.getToken(key); + if (value) { + specialTokensMap[key] = value; + } + } + + const rendered = compiledTemplate.render({ + messages: conversation, + addGenerationPrompt, + tools, + documents, + ...specialTokensMap, + ...kwargs + }); + + if (tokenize) { + const out = this.call(rendered, { + addSpecialTokens: false, + padding, + truncation, + maxLength, + returnTensor, + ...tokenizerKwargs + }); + return returnDict ? out : out.inputIds; + } + + return rendered; + } +} + +export function max(arr: T) { + if (arr.length === 0) throw Error('Array must not be empty'); + let max = arr[0]; + let indexOfMax = 0; + for (let i = 1; i < arr.length; ++i) { + if (arr[i] > max) { + max = arr[i]; + indexOfMax = i; + } + } + return [max, indexOfMax] as T extends bigint[] ? [bigint, number] : [number, number]; +} + +function mergeArrays(...arrs: T[]): T { + return Array.prototype.concat.apply([], arrs) as T; +} + +type TrieNode = { + /** + * If this node marks the end of a word, this property will + * contain the complete word. Otherwise, it's undefined. + */ + end?: string; + + /** + * An index signature to represent child nodes. Each key is a + * character, and each value is the next TrieNode in the sequence. + * The value is a union to satisfy TypeScript's index signature rules. + */ + [key: string]: TrieNode | string | undefined; +}; + +/** + * A data structure which uses a trie to split a string into tokens based on a dictionary. + * It can also use a regular expression to preprocess the input text before splitting. + * + * NOTE: To ensure multi-byte characters are handled correctly, we operate at byte-level instead of character-level. + */ +class DictionarySplitter { + trie: TrieNode; + /** + * @param dictionary The dictionary of words to use for splitting. + */ + constructor(dictionary: string[]) { + this.trie = this.buildTrie(dictionary); + } + + /** + * Builds a trie from the given dictionary. + * @param dictionary The dictionary of words to build the trie from. + * @returns The root node of the trie. 
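+ *
+ * Illustrative shape of the result:
+ * @example
+ * buildTrie(['ab', 'ac']);
+ * // => { a: { b: { end: 'ab' }, c: { end: 'ac' } } }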
+ */ + private buildTrie(dictionary: string[]) { + const trie: TrieNode = Object.create(null); + for (const word of dictionary) { + let node = trie; + for (let i = 0; i < word.length; ++i) { + node = (node[word[i]] ??= Object.create(null)) as TrieNode; + } + node.end = word; + } + return trie; + } + + /** + * Splits the input text into tokens based on the dictionary. + * @param {string} text The input text to split. + * @returns {string[]} An array of tokens. + */ + split(text: string): string[] { + const result = []; + const n = text.length; + let start = 0; + let i = 0; + + while (i < n) { + let node = this.trie; + let match = null; + let j = i; + + while (j < n && (node = node[text[j]] as TrieNode)) { + if (node.end) { + // Always keep the last (i.e., longest) match. + match = node.end; + } + ++j; + } + + if (match) { + if (i > start) { + result.push(text.slice(start, i)); + } + result.push(match); + i += match.length; + start = i; + } else { + ++i; + } + } + if (start < n) { + result.push(text.slice(start)); + } + return result; + } +} + +/** + * Efficient Heap-based Implementation of a Priority Queue. + * It uses an array-based binary heap, where the root is at index `0`, and the + * children of node `i` are located at indices `2i + 1` and `2i + 2`, respectively. + * + * Adapted from the following sources: + * - https://stackoverflow.com/a/42919752/13989043 (original) + * - https://github.com/belladoreai/llama-tokenizer-js (minor improvements) + */ +class PriorityQueue { + private heap: T[]; + private comparator: (a: T, b: T) => boolean; + private maxSize: number; + /** + * Create a new PriorityQueue. + * @param comparator Comparator function to determine priority. Defaults to a MaxHeap. + */ + constructor(comparator = (a: T, b: T) => a > b, maxSize = Infinity) { + this.heap = []; + this.comparator = comparator; + this.maxSize = maxSize; + } + + /** + * The size of the queue + */ + get size() { + return this.heap.length; + } + + /** + * Check if the queue is empty. + * @returns `true` if the queue is empty, `false` otherwise. + */ + isEmpty() { + return this.size === 0; + } + + /** + * Return the element with the highest priority in the queue. + * @returns The highest priority element in the queue. + */ + peek() { + return this.heap[0]; + } + + /** + * Add one or more elements to the queue. + * @param values The values to push into the queue. + * @returns The new size of the queue. + */ + push(...values: T[]) { + return this.extend(values); + } + + /** + * Add multiple elements to the queue. + * @param values The values to push into the queue. + * @returns The new size of the queue. + */ + extend(values: T[]) { + for (const value of values) { + if (this.size < this.maxSize) { + this.heap.push(value); + this.siftUp(); + } else { + // Get index of value with the lowest priority + const smallest = this.smallest(); + + // If the new value has higher priority than the smallest value in the heap + // then replace the smallest value with the new value and update the heap + if (this.comparator(value, this.heap[smallest])) { + this.heap[smallest] = value; + this.siftUpFrom(smallest); + } + } + } + return this.size; + } + + /** + * Remove and return the element with the highest priority in the queue. + * @returns The element with the highest priority in the queue. 
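+ * @example
+ * const pq = new PriorityQueue(); // default comparator behaves as a max-heap
+ * pq.push(3, 1, 2);
+ * pq.pop(); // => 3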
+ */ + pop() { + const poppedValue = this.peek(); + const bottom = this.size - 1; + if (bottom > 0) { + this.swap(0, bottom); + } + this.heap.pop(); + this.siftDown(); + return poppedValue; + } + + /** + * Replace the element with the highest priority in the queue with a new value. + * @param value The new value. + * @returns The replaced value. + */ + replace(value: T) { + const replacedValue = this.peek(); + this.heap[0] = value; + this.siftDown(); + return replacedValue; + } + + /** + * Compute the index for the parent of the node at index `i`. + * @param i The index of the node to get the parent of. + * @returns The index of the parent node. + */ + private parent(i: number) { + return ((i + 1) >>> 1) - 1; + } + + /** + * Compute the index for the left child of the node at index `i`. + * @param i The index of the node to get the left child of. + * @returns The index of the left child. + * + */ + private left(i: number) { + return (i << 1) + 1; + } + + /** + * Compute the index for the right child of the node at index `i`. + * @param i The index of the node to get the right child of. + * @returns The index of the right child. + */ + private right(i: number) { + return (i + 1) << 1; + } + + /** + * Check if the element at index `i` is greater than the element at index `j`. + * @param i The index of the first element to compare. + * @param j The index of the second element to compare. + * @returns `true` if the element at index `i` is greater than the element at index `j`, `false` otherwise. + * + */ + private greater(i: number, j: number) { + return this.comparator(this.heap[i], this.heap[j]); + } + + /** + * Swap the elements at indices `i` and `j`. + * @param i The index of the first element to swap. + * @param j The index of the second element to swap. + * + */ + private swap(i: number, j: number) { + const temp = this.heap[i]; + this.heap[i] = this.heap[j]; + this.heap[j] = temp; + } + + /** + * Maintain the heap property by updating positions in the heap, + * starting at the last element and moving up the heap. + */ + private siftUp() { + this.siftUpFrom(this.size - 1); + } + + /** + * Helper function to sift up from a given node. + * @param node The index of the node to start sifting up from. + */ + private siftUpFrom(node: number) { + while (node > 0 && this.greater(node, this.parent(node))) { + this.swap(node, this.parent(node)); + node = this.parent(node); + } + } + + /** + * Maintain the heap property by updating positions in the heap, + * starting at the first element and moving down the heap. + */ + private siftDown() { + let node = 0; + while ( + (this.left(node) < this.size && this.greater(this.left(node), node)) || + (this.right(node) < this.size && this.greater(this.right(node), node)) + ) { + const maxChild = + this.right(node) < this.size && this.greater(this.right(node), this.left(node)) + ? this.right(node) + : this.left(node); + this.swap(node, maxChild); + node = maxChild; + } + } + + /** + * Get the index of the smallest element in the heap. Since we use an array-based heap, + * the index can be computed without needing to traverse the heap. + */ + private smallest(): number { + return 2 ** Math.floor(Math.log2(this.size)) - 1; + } +} + +/** + * A simple Least Recently Used (LRU) cache implementation in JavaScript. + * This cache stores key-value pairs and evicts the least recently used item + * when the capacity is exceeded. + */ +class LRUCache { + capacity: number; + cache: Map; + /** + * Creates an LRUCache instance. 
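+ *
+ * The design leans on JavaScript Map's insertion-order guarantee: get() and put()
+ * delete and re-insert a key to mark it as most recently used, so the Map's first
+ * key is always the least recently used one and is the one evicted.
+ * @example
+ * const cache = new LRUCache(2);
+ * cache.put('a', 1);
+ * cache.put('b', 2);
+ * cache.get('a');    // marks 'a' as most recently used
+ * cache.put('c', 3); // evicts 'b'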
+ * @param capacity The maximum number of items the cache can hold. + */ + constructor(capacity: number) { + this.capacity = capacity; + this.cache = new Map(); + } + + /** + * Retrieves the value associated with the given key and marks the key as recently used. + * @param key The key to retrieve. + * @returns The value associated with the key, or undefined if the key does not exist. + */ + get(key: Key) { + if (!this.cache.has(key)) return undefined; + const value = this.cache.get(key); + this.cache.delete(key); + this.cache.set(key, value as Value); + return value; + } + + /** + * Inserts or updates the key-value pair in the cache. + * If the key already exists, it is updated and marked as recently used. + * If the cache exceeds its capacity, the least recently used item is evicted. + * @param key The key to add or update. + * @param value The value to associate with the key. + */ + put(key: Key, value: Value) { + if (this.cache.has(key)) { + this.cache.delete(key); + } + this.cache.set(key, value); + if (this.cache.size > this.capacity) { + this.cache.delete(this.cache.keys().next().value as Key); + } + } + + /** + * Clears the cache. + */ + clear() { + this.cache.clear(); + } +} + +export function decodeSingle(value: number, tokenizer: PreTrainedTokenizer | null): string { + if (tokenizer instanceof PreTrainedTokenizer) { + return tokenizer + .decodeSingle([value]) + .replaceAll('<|end_of_text|>', '▶️📄') + .replaceAll('<|im_start|>', '▶️💬') + .replaceAll('<|im_end|>', '⏹️💬'); + } + return `<${value}>`; +} diff --git a/examples/finetuning/src/lib/train/types.ts b/examples/finetuning/src/lib/train/types.ts new file mode 100644 index 00000000..600749c5 --- /dev/null +++ b/examples/finetuning/src/lib/train/types.ts @@ -0,0 +1,27 @@ +import type { DataLoader } from '@piston-ml/piston-web'; + +import type { NaturalLanguageAutoregressiveBatch } from './data/natural'; +import type { GPT } from './model/gpt'; + +type CollateFn = (batch: B) => T; + +type NaturalCollateInput = number[][]; + +export type NaturalBatchType = NaturalLanguageAutoregressiveBatch; + +export type AutoregressiveBatchType = NaturalLanguageAutoregressiveBatch; + +export type NaturalAutoregressiveCollateFnType = CollateFn< + NaturalCollateInput, + NaturalLanguageAutoregressiveBatch +>; + +export type AutoregressiveCollateFnType = NaturalAutoregressiveCollateFnType; + +export type NaturalCollateFnType = CollateFn>; + +export type NaturalDataloaderType = DataLoader>; + +export type AutoregressiveModelType = GPT; + +export type GeneratableModel = AutoregressiveModelType; diff --git a/examples/finetuning/src/lib/train/utils/checkpoint.ts b/examples/finetuning/src/lib/train/utils/checkpoint.ts new file mode 100644 index 00000000..f55c5341 --- /dev/null +++ b/examples/finetuning/src/lib/train/utils/checkpoint.ts @@ -0,0 +1,224 @@ +import type { Config } from '$lib/workspace/config'; + +import * as piston from '@piston-ml/piston-web'; +import { + type ConstantConfig, + type CosineAnnealingConfig, + type ExponentialConfig, + type LinearConfig, + LRScheduler, + Optimizer, + type OptimizerParamState, + type ParamGroupConfig, + type SchedulerStateDict, + type StateDict, + type StepConfig, + Tensor +} from '@piston-ml/piston-web'; + +/** + * Recursively walks an object and extracts any Tensor values into `out` as Buffers. + * Replaces extracted tensors in the returned structure with a small marker object + * containing the tensor storage key that was used in `out`. 
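+ *
+ * Illustrative sketch, where `t` is assumed to be a piston Tensor:
+ * @example
+ * const out = {};
+ * const json = splitTensorsFromObject({ momentum: t }, 'optimizer.state', out);
+ * // json => { momentum: { __tensor__: 'optimizer.state.momentum' } }
+ * // out  => { 'optimizer.state.momentum': <piston.Buffer wrapping t> }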
+ */ +export function splitTensorsFromObject( + value: unknown, + baseKey: string, + out: Record +): unknown { + if (value instanceof Tensor) { + out[baseKey] = new piston.Buffer(value, true); + return { __tensor__: baseKey }; + } + if (Array.isArray(value)) { + return value.map((v, i) => splitTensorsFromObject(v, `${baseKey}.${i}`, out)); + } + if (value && typeof value === 'object') { + const result: Record = {}; + for (const [k, v] of Object.entries(value)) { + result[k] = splitTensorsFromObject(v, `${baseKey}.${k}`, out); + } + return result; + } + return value; +} + +export type AnySchedulerState = SchedulerStateDict< + StepConfig | CosineAnnealingConfig | ExponentialConfig | LinearConfig | ConstantConfig | unknown +>; + +export type CheckpointOptimizerExtra = { + name: string; + // JSON with __tensor__ markers + state: unknown; + paramGroups: ParamGroupConfig[]; +} | null; + +export interface CheckpointDataState { + blockSize: number; + shardIndex: number; + cursor: number; +} + +export interface CheckpointExtra { + config: Config; + optimizer: CheckpointOptimizerExtra; + numSteps: number; + lrScheduler?: { state: AnySchedulerState }; + dataState?: CheckpointDataState; + // Optional wall-clock training start time in ms to persist across restarts + startTimeMs?: number; +} + +/** + * Builds a checkpoint payload by combining model parameters with optimizer state. + * - Model parameters/buffers go into `tensors` directly + * - Any Tensor values found inside optimizer.stateDict().state are lifted into `tensors` + * under keys prefixed with `optimizer/state/...` + * - Extra contains { config, optimizer, numSteps } + */ + +export function buildCheckpoint( + model: piston.Module, + optimizer: Optimizer, + numSteps: number, + configForExtra: Config, + scheduler?: LRScheduler, + dataState?: CheckpointDataState, + startTimeMs?: number +): { tensors: Record; extra: CheckpointExtra } { + const tensors: Record = model.stateDict(); + + let optimizerExtra: CheckpointOptimizerExtra = null; + try { + const name = optimizer.constructor.name ?? 'Optimizer'; + const packed = optimizer.stateDict(); + const tensorSlots: Record = {}; + const jsonState = splitTensorsFromObject(packed.state, 'optimizer.state', tensorSlots); + Object.assign(tensors, tensorSlots); + optimizerExtra = { + name, + state: jsonState, + paramGroups: packed.paramGroups + }; + } catch (e) { + console.warn('Failed to pack optimizer stateDict for checkpoint extra:', e); + } + + const extra: CheckpointExtra = { + config: configForExtra, + optimizer: optimizerExtra, + numSteps, + lrScheduler: scheduler ? { state: scheduler.stateDict() } : undefined, + dataState, + startTimeMs + }; + + return { tensors, extra }; +} + +/** + * Replace any marker objects of the form { __tensor__: key } inside a JSON structure + * with actual Tensors from the provided mapping. 
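+ *
+ * Illustrative inverse of splitTensorsFromObject, where `t` is again assumed to
+ * be a Tensor:
+ * @example
+ * rehydrateTensorsInObject(
+ *   { momentum: { __tensor__: 'optimizer.state.momentum' } },
+ *   { 'optimizer.state.momentum': t }
+ * );
+ * // => { momentum: t }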
+ */
+export function rehydrateTensorsInObject<T>(value: unknown, lifted: Record<string, Tensor>): T {
+	if (value && typeof value === 'object' && !Array.isArray(value)) {
+		const marker = value as { __tensor__?: string };
+		if (typeof marker.__tensor__ === 'string') {
+			const key = marker.__tensor__;
+			if (!(key in lifted)) {
+				throw new Error(`Missing lifted tensor for key '${key}' during optimizer rehydration`);
+			}
+			return lifted[key] as unknown as T;
+		}
+		const out: Record<string, unknown> = {};
+		for (const [k, v] of Object.entries(value)) {
+			out[k] = rehydrateTensorsInObject(v, lifted);
+		}
+		return out as unknown as T;
+	}
+	if (Array.isArray(value)) {
+		return value.map((v) => rehydrateTensorsInObject(v, lifted)) as unknown as T;
+	}
+	return value as T;
+}
+
+export interface SplitLoadedStateResult {
+	modelState: Record<string, Tensor>;
+	schedulerState?: AnySchedulerState;
+	optimizerState: StateDict;
+	numSteps: number;
+	config: Config;
+	dataState?: CheckpointDataState;
+	startTimeMs?: number;
+}
+
+/**
+ * Given loaded state from piston.load, split out model state from lifted optimizer tensors
+ * and rehydrate optimizer and scheduler states from extras.
+ */
+export function splitLoadedState(loaded: {
+	state: Record<string, Tensor>;
+	extra?: CheckpointExtra;
+}): SplitLoadedStateResult {
+	const prefix = 'optimizer.state';
+	const liftedOptimizerTensors: Record<string, Tensor> = {};
+	const modelState: Record<string, Tensor> = {};
+
+	for (const [key, t] of Object.entries(loaded.state)) {
+		if (key.startsWith(prefix)) {
+			liftedOptimizerTensors[key] = t;
+		} else {
+			modelState[key] = t;
+		}
+	}
+
+	let optimizerState: StateDict | undefined;
+	let schedulerState: AnySchedulerState | undefined = undefined;
+	let numSteps = 0;
+	let config: Config | null = null;
+	let dataState: CheckpointDataState | undefined = undefined;
+	let startTimeMs: number | undefined = undefined;
+
+	const { extra } = loaded;
+
+	if (extra) {
+		config = extra.config;
+		numSteps = extra.numSteps;
+		if (extra.optimizer) {
+			const rehydratedState = rehydrateTensorsInObject<Record<string, OptimizerParamState>>(
+				extra.optimizer.state,
+				liftedOptimizerTensors
+			);
+			optimizerState = {
+				state: rehydratedState,
+				paramGroups: extra.optimizer.paramGroups ??
[] + }; + } + if (extra.lrScheduler && extra.lrScheduler.state) { + schedulerState = extra.lrScheduler.state; + } + if (extra.dataState) { + dataState = extra.dataState; + } + if (typeof extra.startTimeMs === 'number') { + startTimeMs = extra.startTimeMs; + } + } + + if (!config) { + throw new Error('No config found in checkpoint'); + } + + if (numSteps == null) { + throw new Error('No numSteps found in checkpoint'); + } + + if (!optimizerState) { + throw new Error('No optimizer state found in checkpoint'); + } + + // Some runs don't use a scheduler, so we don't validate that it's present + + return { modelState, optimizerState, schedulerState, numSteps, config, dataState, startTimeMs }; +} diff --git a/examples/finetuning/src/lib/train/utils/init.ts b/examples/finetuning/src/lib/train/utils/init.ts new file mode 100644 index 00000000..974695f3 --- /dev/null +++ b/examples/finetuning/src/lib/train/utils/init.ts @@ -0,0 +1,54 @@ +// import type { Config, ProjectionInitializationConfig } from '$lib/workspace/config'; + +// import { initNormal_, initOnes_, initZeros_, nn } from '@piston-ml/piston-web'; + +// TODO: Setup initialization for GPT2 lora + +// export function initTransformerParameters(self: nn.Module, config: Config): void { +// const initializationConfig = config.model.transformer.initialization; + +// if (!initializationConfig.present) { +// return; +// } + +// const initTransformerWeights = (module: nn.Module): void => { +// if (module instanceof nn.Linear) { +// initNormal_(module.weight, { mean: 0.0, std: initializationConfig.std }); +// if (module.bias != null) { +// initZeros_(module.bias); +// } +// } else if (module instanceof nn.Embedding) { +// initNormal_(module.weight, { mean: 0.0, std: initializationConfig.std }); +// } else if (module instanceof nn.LayerNorm) { +// if (module.bias) { +// initZeros_(module.bias); +// } +// initOnes_(module.weight); +// } +// }; + +// const initProjection = (p: nn.Parameter, projectionConfig: ProjectionInitializationConfig) => { +// if (!projectionConfig.present) { +// return; +// } +// const nLayers = config.model.layers; +// if (projectionConfig.strategy === 'layer-scaled') { +// initNormal_(p, { mean: 0.0, std: 0.02 / Math.sqrt(2 * nLayers) }); +// } else if (projectionConfig.strategy === 'zero') { +// initZeros_(p); +// } +// }; + +// self.apply(initTransformerWeights); +// for (const [pn, p] of self.namedParameters()) { +// if (pn.endsWith('cProj.weight')) { +// initProjection(p, initializationConfig.projections.attention); +// } +// if (pn.endsWith('downProj.weight')) { +// initProjection(p, initializationConfig.projections.mlp); +// } +// if (pn.endsWith('lmHead.weight')) { +// initProjection(p, initializationConfig.projections.lmHead); +// } +// } +// } diff --git a/examples/finetuning/src/lib/train/utils/model.ts b/examples/finetuning/src/lib/train/utils/model.ts new file mode 100644 index 00000000..289b7a81 --- /dev/null +++ b/examples/finetuning/src/lib/train/utils/model.ts @@ -0,0 +1,134 @@ +import { + CaptureIndexMode, + DataLoader, + type IndexState, + Module, + Tensor, + weak +} from '@piston-ml/piston-web'; +import * as piston from '@piston-ml/piston-web'; + +import type { Config } from '../../workspace/config'; +import type { NaturalBatchType, NaturalCollateFnType, NaturalDataloaderType } from '../types'; + +import { type CollateWrapFunction } from '../data'; +import { naturalLanguageAutoregressiveCollate, NaturalLanguageDataset } from '../data/natural'; +import { buildGPT2Config } from '../model/config'; +import { 
GPT, GPT2_BLOCK_SIZE, GPT2_VOCAB_SIZE } from '../model/gpt';
+import { parseSeed } from './random';
+
+// Builds the collate function for natural-language autoregressive batches
+export function createCollateFn(
+	wrapFunction?: CollateWrapFunction | null
+): NaturalCollateFnType {
+	const collateOptions = wrapFunction !== undefined ? { wrapFunction } : {};
+	return (batch: number[][]) =>
+		naturalLanguageAutoregressiveCollate(batch as number[][], {
+			...collateOptions
+		});
+}
+
+export function createDataloader(
+	config: Config,
+	dataset: NaturalLanguageDataset,
+	wrapFunction?: CollateWrapFunction | null
+): [NaturalDataloaderType, NaturalCollateFnType] {
+	const collateFn = createCollateFn(wrapFunction);
+	return [
+		new DataLoader(dataset, {
+			collateFn,
+			batchSize: config.training.batchSize
+		}),
+		collateFn
+	];
+}
+
+/**
+ * Create a model instance based on the configuration
+ */
+export function createModel(config: Config): GPT {
+	return new GPT(buildGPT2Config(config.model.type), config);
+}
+
+export function calculateParameterSum(model: Module): Tensor {
+	const sums = model.parameters().map((param) => param.sum());
+	return piston.stack(sums).sum();
+}
+
+export function countParameters(model: GPT): number {
+	let totalParams = 0;
+
+	// Walk through all named parameters
+	for (const [_, param] of model.namedParameters()) {
+		if (param && param.shape) {
+			const paramCount = (param.shape as number[]).reduce(
+				(acc: number, dim: number) => acc * dim,
+				1
+			);
+			totalParams += paramCount;
+		}
+	}
+
+	return totalParams;
+}
+
+/**
+ * Inspect model for a given configuration: count the number of parameters and capture an "index"
+ * of the model.
+ */
+export function inspectModel(config: Config): {
+	parameterCount: number;
+	hiddenSize: number;
+	mlpIntermediateSize: number;
+	modelIndex: IndexState;
+	vocabSize: number;
+	blockSize: number;
+} {
+	return weak(
+		() => {
+			const blockSize = GPT2_BLOCK_SIZE;
+			const vocabSize = GPT2_VOCAB_SIZE;
+			const model = createModel(config);
+			const parameterCount = countParameters(model);
+			const hiddenSize = model.config.nEmbd;
+			const mlpIntermediateSize = hiddenSize * 4;
+
+			let indexMode: CaptureIndexMode | null = null;
+			try {
+				indexMode = new CaptureIndexMode(model);
+
+				// Run the model forward with an input from the dataloader
+				model.forward(piston.zeros([1, blockSize], { dtype: piston.int32 }));
+
+				console.debug(`Model has ${parameterCount} parameters with vocab size ${vocabSize}`);
+
+				return {
+					parameterCount,
+					hiddenSize,
+					mlpIntermediateSize,
+					vocabSize,
+					blockSize,
+					modelIndex: indexMode!.index
+				};
+			} finally {
+				// Optional chaining: if the CaptureIndexMode constructor threw, indexMode is
+				// still null and a non-null assertion here would mask the original error.
+				indexMode?.[Symbol.dispose]();
+			}
+		},
+		{
+			label: 'inspectModel'
+		}
+	);
+}
+
+export function seedPiston(config: Config) {
+	// Set up random number generator
+	const seed = parseSeed(
+		config.training.randomSeed.present ?
config.training.randomSeed.value : undefined + ); + + if (seed !== undefined) { + piston.seed(seed); + } + + return seed; +} diff --git a/examples/finetuning/src/lib/train/utils/modes.ts b/examples/finetuning/src/lib/train/utils/modes.ts new file mode 100644 index 00000000..4be5a4e7 --- /dev/null +++ b/examples/finetuning/src/lib/train/utils/modes.ts @@ -0,0 +1,96 @@ +import { + PistonFunctionMode, + PistonMarkStepMode, + WeakMarkStepMode, + WeakTensorFunctionMode, + type WeakTensorFunctionModeOptions +} from '@piston-ml/piston-web'; +import * as piston from '@piston-ml/piston-web'; + +export class DebugMode extends PistonFunctionMode { + constructor(public debugEnabled: boolean) { + super(); + } + + _pistonFunction( + func: (...args: unknown[]) => FT | Promise, + _types: unknown[], + args: unknown[], + kwargs: Record + ): T | Promise { + if (this.debugEnabled) { + console.log( + func.name, + args.reduce((acc: number[], a) => { + if (a instanceof piston.wasm.Tensor_wasm) { + return [...acc, a.id]; + } + return acc; + }, []) + ); + } + + const after = (result: T) => { + if (result instanceof piston.wasm.Tensor_wasm) { + console.log(func.name, 'result', result.id); + } + return result; + }; + + const result = func(...args, kwargs) as T | Promise; + if (result instanceof Promise) { + return result.then(after) as Promise; + } + + return after(result) as T; + } +} + +export class WeakModeIfEnabled { + private mode: WeakTensorFunctionMode | null = null; + + constructor( + public enabled: boolean, + public options: WeakTensorFunctionModeOptions + ) { + if (enabled) { + this.mode = new WeakTensorFunctionMode(options); + } + } + + markWeak(input: T) { + if (this.mode) { + this.mode.markWeak(input); + } + return input; + } + + pin(input: T) { + if (this.mode) { + this.mode.pin(input); + } + return input; + } + + [Symbol.dispose]() { + if (this.mode) { + this.mode[Symbol.dispose](); + } + } +} + +export class MarkStepModeIfEnabled { + private mode: PistonMarkStepMode | null = null; + + constructor(public enabled: boolean) { + if (enabled) { + this.mode = new WeakMarkStepMode(); + } + } + + [Symbol.dispose]() { + if (this.mode) { + this.mode[Symbol.dispose](); + } + } +} diff --git a/examples/finetuning/src/lib/train/utils/optim.ts b/examples/finetuning/src/lib/train/utils/optim.ts new file mode 100644 index 00000000..855f40d8 --- /dev/null +++ b/examples/finetuning/src/lib/train/utils/optim.ts @@ -0,0 +1,381 @@ +import type { OptimizerConfig } from '$lib/workspace/config'; + +import { + AdamW, + type Device, + type Module, + MuonWithAdamW, + type MuonWithAdamWParamGroup, + nn, + type Optimizer, + type Parameter as ParameterType, + type ParamGroup, + SGD +} from '@piston-ml/piston-web'; + +// Deterministic sorting helpers +function compareByName(a: [string, T], b: [string, T]): number { + return a[0] < b[0] ? -1 : a[0] > b[0] ? 1 : 0; +} + +function sortEntriesByName(entries: Array<[string, T]>): Array<[string, T]> { + return entries.sort(compareByName); +} + +function paramsFromNamesSorted( + names: Iterable, + paramDict: Map +): ParameterType[] { + return Array.from(names) + .sort() + .map((name) => paramDict.get(name)!) + .filter((p) => p != null); +} + +/** + * Validates that all model parameters are included in the parameter groups. + * Throws an error if any parameters are missing from the groups. 
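+ * A parameter that ends up in no group would silently receive no optimizer updates,
+ * so this is treated as a hard error rather than a warning.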
+ * @param model - The model to validate
+ * @param paramGroups - The parameter groups to check
+ * @throws Error if any model parameters are not included in the parameter groups
+ */
+function validateParameterGroups(
+	model: Module,
+	paramGroups: ParamGroup[],
+	paramDict: Map<string, ParameterType>
+): void {
+	// Get all parameters from the model
+	const allModelParams = new Set<ParameterType>();
+	for (const [_, param] of model.namedParameters()) {
+		allModelParams.add(param);
+	}
+
+	// Get all parameters from the parameter groups
+	const groupParams = new Set<ParameterType>();
+	for (const group of paramGroups) {
+		for (const param of group.params) {
+			groupParams.add(param);
+		}
+	}
+
+	// Find parameters that are in the model but not in any group
+	const missingParams: ParameterType[] = [];
+	for (const param of allModelParams) {
+		if (!groupParams.has(param)) {
+			missingParams.push(param);
+		}
+	}
+
+	if (missingParams.length > 0) {
+		// Find the names of the missing parameters using paramDict
+		const missingParamNames: string[] = [];
+		for (const [name, param] of paramDict) {
+			if (missingParams.includes(param)) {
+				missingParamNames.push(name);
+			}
+		}
+
+		throw new Error(
+			`Found ${missingParams.length} model parameters that are not included in any parameter group (${groupParams.size} included). ` +
+				`All model parameters must be assigned to a parameter group for training. ` +
+				`Missing parameters: ${missingParamNames.join(', ')}`
+		);
+	}
+}
+
+function getWeightDecayParams(
+	model: Module,
+	useWeightDecayGroups: boolean,
+	// eslint-disable-next-line @typescript-eslint/no-explicit-any
+	whitelistWeightModules: (new (...args: any[]) => Module)[],
+	// eslint-disable-next-line @typescript-eslint/no-explicit-any
+	blacklistWeightModules: (new (...args: any[]) => Module)[]
+): { paramDict: Map<string, ParameterType>; decay: Set<string>; noDecay: Set<string> } {
+	const decay = new Set<string>();
+	const noDecay = new Set<string>();
+
+	const paramDict = new Map<string, ParameterType>();
+
+	for (const [mn, m] of model.namedModules()) {
+		for (const [pn, p] of m.namedParameters()) {
+			const fpn = mn ? `${mn}.${pn}` : pn;
+
+			paramDict.set(fpn, p);
+
+			if (useWeightDecayGroups) {
+				if (pn.endsWith('bias')) {
+					// All biases will not be decayed
+					noDecay.add(fpn);
+				} else if (pn.endsWith('weight')) {
+					if (whitelistWeightModules.some((cls) => m instanceof cls)) {
+						// Weights of whitelist modules will be weight decayed
+						decay.add(fpn);
+					} else if (blacklistWeightModules.some((cls) => m instanceof cls)) {
+						// Weights of blacklist modules will NOT be weight decayed
+						noDecay.add(fpn);
+					}
+				} else {
+					// Parameters that are not weights or biases (shouldn't exist in std models)
+					// Add to decay by default, adjust if necessary for specific models.
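+					// (For example, a hypothetical bare Parameter named `logitScale` registered
+					// directly on a module would fall through to this branch.)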
+ decay.add(fpn); + } + } else { + decay.add(fpn); + } + } + } + + if (useWeightDecayGroups) { + // Validate that we considered every parameter + const allParamNames = new Set( + Array.from(model.namedParameters()).map(([name]) => name) + ); + const interParams = new Set([...decay].filter((x) => noDecay.has(x))); + const unionParams = new Set([...decay, ...noDecay]); + + if (interParams.size !== 0) { + throw new Error( + `Parameters ${JSON.stringify(Array.from(interParams))} made it into both decay/noDecay sets` + ); + } + const missingParams = new Set([...allParamNames].filter((x) => !unionParams.has(x))); + if (missingParams.size !== 0) { + throw new Error( + `Parameters ${JSON.stringify( + Array.from(missingParams) + )} were not separated into either decay/noDecay set` + ); + } + } + + return { paramDict, decay, noDecay }; +} + +/** + * Based on what minGPT does: + * Configures the optimizer based on the training configuration. + * Separates parameters into weight decay and no weight decay groups. + * @param trainConfig - The optimizer + * configuration + * @param device - The computation device + * @returns The configured optimizer + */ +export function configureOptimizers( + model: Module, + moduleLayersPrefixes: string[], + lmHeadPrefix: string, + trainConfig: OptimizerConfig, + device: Device +): Optimizer { + const whitelistWeightModules = [nn.Linear]; + const blacklistWeightModules = [nn.LayerNorm, nn.RMSNorm, nn.Embedding]; + + const effectiveWeightDecay = trainConfig.weightDecay.present + ? trainConfig.weightDecay.value + : 0.0; + + const { paramDict, decay, noDecay } = getWeightDecayParams( + model, + trainConfig.weightDecay.useWeightDecayGroups, + whitelistWeightModules, + blacklistWeightModules + ); + // Deterministic param lists by name + const decayParamsValues = paramsFromNamesSorted(decay, paramDict); + const noDecayParamsValues = paramsFromNamesSorted(noDecay, paramDict); + + if (trainConfig.type === 'AdamW' || trainConfig.type === 'Adam' || trainConfig.type === 'SGD') { + const optimGroups: ParamGroup[] = [ + { + params: decayParamsValues, + weightDecay: effectiveWeightDecay + }, + ...(noDecayParamsValues.length > 0 + ? 
[ + { + params: noDecayParamsValues, + weightDecay: 0.0 // no decay + } + ] + : []) + ]; + + validateParameterGroups(model, optimGroups, paramDict); + + // Create the AdamW optimizer + if (trainConfig.type === 'AdamW' || trainConfig.type === 'Adam') { + return new AdamW(optimGroups, device, { + lr: trainConfig.lr, + betas: [trainConfig.adam.beta1, trainConfig.adam.beta2], + eps: trainConfig.adam.eps, + weightDecay: effectiveWeightDecay, + amsgrad: trainConfig.adam.amsgrad + }); + } else if (trainConfig.type === 'SGD') { + return new SGD(optimGroups, device, { + lr: trainConfig.lr, + momentum: trainConfig.sgd.momentum, + dampening: trainConfig.sgd.dampening, + weightDecay: effectiveWeightDecay, + nesterov: trainConfig.sgd.nesterov + }); + } + } else if (trainConfig.type === 'Muon') { + // Get parameter groups by type + const paramEntries = sortEntriesByName(Array.from(paramDict.entries())); + const moduleLayersParams = paramEntries.filter(([n]) => + moduleLayersPrefixes.some((prefix) => n.startsWith(prefix)) + ); + // Sort each category deterministically by name + const hiddenMatrixParams = sortEntriesByName( + moduleLayersParams.filter(([n, p]) => p.ndim >= 2 && !n.toLowerCase().includes('embed')) + ); + const scalarParams = sortEntriesByName(moduleLayersParams.filter(([_, p]) => p.ndim < 2)); + const embedParams = sortEntriesByName( + paramEntries.filter(([n, _]) => n.toLowerCase().includes('embed')) + ); + const headParams = sortEntriesByName(paramEntries.filter(([n]) => n.startsWith(lmHeadPrefix))); + // Any other params we just throw to AdamW + const filteredParams = new Set([ + ...hiddenMatrixParams.map(([n]) => n), + ...scalarParams.map(([n]) => n), + ...embedParams.map(([n]) => n), + ...headParams.map(([n]) => n) + ]); + const remainingParams = paramEntries.filter(([n]) => !filteredParams.has(n)); + + if (remainingParams.length > 0) { + console.warn( + `Found ${remainingParams.length} parameters that don't fit Muon categorization and will be handled by AdamW:`, + remainingParams.map(([name]) => name) + ); + } + + // Apply weight decay grouping to each parameter type + const paramGroups: MuonWithAdamWParamGroup[] = []; + + // Hidden matrix parameters for Muon optimizer + if (trainConfig.weightDecay.useWeightDecayGroups) { + const hiddenDecay = hiddenMatrixParams.filter(([name]) => decay.has(name)).map(([_, p]) => p); + const hiddenNoDecay = hiddenMatrixParams + .filter(([name]) => noDecay.has(name)) + .map(([_, p]) => p); + + if (hiddenDecay.length > 0) { + paramGroups.push({ + optimizer: 'muon', + lr: trainConfig.lr, + weightDecay: effectiveWeightDecay, + momentum: trainConfig.muon.momentum, + nsSteps: trainConfig.muon.nsSteps, + nesterov: trainConfig.muon.nesterov, + params: hiddenDecay + }); + } + + if (hiddenNoDecay.length > 0) { + paramGroups.push({ + optimizer: 'muon', + lr: trainConfig.lr, + weightDecay: 0.0, // no decay + momentum: trainConfig.muon.momentum, + nsSteps: trainConfig.muon.nsSteps, + nesterov: trainConfig.muon.nesterov, + params: hiddenNoDecay + }); + } + } else { + if (hiddenMatrixParams.length > 0) { + paramGroups.push({ + optimizer: 'muon', + lr: trainConfig.lr, + weightDecay: effectiveWeightDecay, + momentum: trainConfig.muon.momentum, + nsSteps: trainConfig.muon.nsSteps, + nesterov: trainConfig.muon.nesterov, + params: hiddenMatrixParams.map(([_, p]) => p) + }); + } + } + + // Scalar, embedding, and head parameters for AdamW optimizer + const adamwParams = sortEntriesByName([ + ...scalarParams, + ...embedParams, + ...headParams, + ...remainingParams + ]); + 
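+
+		// Every parameter should now be routed to exactly one of the two inner optimizers:
+		// 2D+ hidden-layer matrices to Muon; scalars, embeddings, the LM head, and any
+		// uncategorized leftovers to AdamW. The check below enforces that invariant.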
+ // Check if there is any overlap between the two optimizers getting overlap of adamWparams + const adamwParamSet = new Set(adamwParams.map(([n]) => n)); + const muonParamSet = new Set(hiddenMatrixParams.map(([n]) => n)); + const overlap = adamwParamSet.intersection(muonParamSet); + if (overlap.size > 0) { + throw new Error( + `Overlap between AdamW and Muon parameters: ${Array.from(overlap).join(', ')}` + ); + } + + if (trainConfig.weightDecay.useWeightDecayGroups) { + const adamwDecay = adamwParams.filter(([name]) => decay.has(name)).map(([_, p]) => p); + const adamwNoDecay = adamwParams.filter(([name]) => noDecay.has(name)).map(([_, p]) => p); + + if (adamwDecay.length > 0) { + paramGroups.push({ + optimizer: 'adamw', + lr: trainConfig.lr, + betas: [trainConfig.adam.beta1, trainConfig.adam.beta2], + eps: trainConfig.adam.eps, + weightDecay: effectiveWeightDecay, + amsgrad: trainConfig.adam.amsgrad, + params: adamwDecay + }); + } + + if (adamwNoDecay.length > 0) { + paramGroups.push({ + optimizer: 'adamw', + lr: trainConfig.lr, + betas: [trainConfig.adam.beta1, trainConfig.adam.beta2], + eps: trainConfig.adam.eps, + weightDecay: 0.0, // no decay + amsgrad: trainConfig.adam.amsgrad, + params: adamwNoDecay + }); + } + } else { + if (adamwParams.length > 0) { + paramGroups.push({ + optimizer: 'adamw', + lr: trainConfig.lr, + betas: [trainConfig.adam.beta1, trainConfig.adam.beta2], + eps: trainConfig.adam.eps, + weightDecay: effectiveWeightDecay, + amsgrad: trainConfig.adam.amsgrad, + params: adamwParams.map(([_, p]) => p) + }); + } + } + + validateParameterGroups(model, paramGroups, paramDict); + + return new MuonWithAdamW(paramGroups, device, { + muon: { + lr: trainConfig.lr, + weightDecay: effectiveWeightDecay, + momentum: trainConfig.muon.momentum, + nsSteps: trainConfig.muon.nsSteps, + nesterov: trainConfig.muon.nesterov + }, + adamw: { + lr: trainConfig.lr, + betas: [trainConfig.adam.beta1, trainConfig.adam.beta2], + eps: trainConfig.adam.eps, + weightDecay: effectiveWeightDecay, + amsgrad: trainConfig.adam.amsgrad + } + }); + } + + throw new Error(`Unknown optimizer type: ${trainConfig.type}`); +} diff --git a/examples/finetuning/src/lib/train/utils/random.ts b/examples/finetuning/src/lib/train/utils/random.ts new file mode 100644 index 00000000..03756479 --- /dev/null +++ b/examples/finetuning/src/lib/train/utils/random.ts @@ -0,0 +1,39 @@ +import { MersenneTwister19937, Random } from 'random-js'; + +export function parseSeed(seed?: string): number | undefined { + if (seed === undefined || seed === '') { + return undefined; + } + + const parsed = parseInt(seed); + if (!isNaN(parsed)) { + return parsed; + } + + // Simple hash function for string + let hash = 0; + for (let i = 0; i < seed.length; i++) { + const char = seed.charCodeAt(i); + hash = (hash << 5) - hash + char; + hash = hash & hash; // Convert to 32bit integer + } + return Math.abs(hash); +} + +/** + * Creates a seeded random number generator from a string or undefined seed. + * If seed is undefined, auto-seeds the generator. + * If seed is a string that can be parsed as a number, uses the parsed number. + * Otherwise, uses a simple hash function to convert the string to a number. 
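+ *
+ * String parsing and hashing are handled by `parseSeed` above; this function expects
+ * the already-numeric seed, e.g. `seededRandom(parseSeed('sequence toy'))`.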
+ */ +export function seededRandom(seed?: number): Random { + if (seed === undefined) { + return new Random(MersenneTwister19937.autoSeed()); + } + + return new Random(MersenneTwister19937.seed(seed)); +} + +export function forkRandom(random: Random): Random { + return new Random(MersenneTwister19937.seed(random.int32())); +} diff --git a/examples/finetuning/src/lib/train/validation.ts b/examples/finetuning/src/lib/train/validation.ts new file mode 100644 index 00000000..59335bcf --- /dev/null +++ b/examples/finetuning/src/lib/train/validation.ts @@ -0,0 +1,149 @@ +import type { Config, ValidationConfig } from '$lib/workspace/config'; +import type { BaseStepData, TokenRollout } from '$lib/workspace/runs.svelte'; + +import { type Tensor, weak } from '@piston-ml/piston-web'; + +import type { GeneratableModel, NaturalCollateFnType } from './types'; + +import { NaturalLanguageDataset } from './data/natural'; +import { generateDecoderCompletions } from './validationHelpers'; + +export type ValidationStep = BaseStepData & { + type: 'validation'; + completions: TokenRollout[]; + samplingParams: { + temperature: number; + }; + // Running average throughput across the generation in tokens/second (decoder/generative only) + tokensPerSecond?: number; + targets?: number[][]; // Only present in first step, what tokens should have been + encoderInputs?: number[][]; // Encoder/source inputs used (for encoder-decoder or encoder-only) + decoderPromptLengths?: number[]; // For decoder-only display: prompt token counts per example + matches?: boolean[][]; // Per-example, per-token correctness flags +}; + +export type NaturalValidationExamples = { + naturalSequences: number[][]; +}; + +export function buildValidationExamplesSubset( + examples: NaturalValidationExamples, + subsetSize: number +): NaturalValidationExamples { + return { + naturalSequences: examples.naturalSequences.slice(0, subsetSize) + }; +} + +export async function prepareNaturalValidationExamples( + config: Config, + dataset: NaturalLanguageDataset +): Promise { + const naturalSequences: number[][] = []; + + let count = 0; + for await (const sampleSequence of dataset) { + naturalSequences.push(sampleSequence); + count++; + if (count >= config.training.validation.batchSize) break; + } + + return { naturalSequences }; +} + +export async function computeNaturalValidationMetrics( + model: GeneratableModel, + dataset: NaturalLanguageDataset, + valExamples: NaturalValidationExamples, + valConfig: ValidationConfig +): Promise> { + let promptLen = 0; + + const contextSize = dataset.contextSize; + + // promptLen = Math.max(Math.floor(contextSize / 4), 1); + promptLen = 8; + const eosId = dataset.eosId as number; + const starts = valExamples.naturalSequences.map((seq) => seq.slice(0, promptLen)); + const maxTokens = Math.max(0, contextSize - promptLen); + const result = await generateDecoderCompletions(model, starts, { + maxTokens, + stopTokens: eosId !== null ? 
[eosId] : [], + temperature: valConfig.temperature, + useKvCache: valConfig.useKvCache + }); + const { completions, tokensPerSecond } = result; + + const validationStepData: Omit = { + type: 'validation', + completions, + samplingParams: { temperature: valConfig.temperature }, + decoderPromptLengths: new Array(valExamples.naturalSequences.length).fill(promptLen), + tokensPerSecond + }; + + return validationStepData; +} + +export function buildValidationLog( + validationStepData: Omit +): Record> { + // Aggregate numeric-like metrics from per-example metrics; average arrays per-example; skip 'matches' + const aggregatedNumeric: Record = {}; + + // Compute number of unique completions by hashing tokenIds + const uniqueCompletionsCount = (() => { + const seen = new Set(); + for (const c of validationStepData.completions ?? []) { + const ids = c?.tokenIds ?? []; + seen.add(ids.join(',')); + } + return seen.size; + })(); + + const validationLog: Record = { + 'validation/completions': validationStepData, + 'validation/unique_completions': uniqueCompletionsCount + }; + if (typeof validationStepData.tokensPerSecond === 'number') { + validationLog['validation/tokens_per_second'] = validationStepData.tokensPerSecond; + } + for (const [k, v] of Object.entries(aggregatedNumeric)) { + validationLog[`validation/${k}`] = v; + } + return validationLog; +} + +export async function computeLikelihoodMetrics( + model: GeneratableModel, + sequences: NaturalValidationExamples, + collateFn: NaturalCollateFnType +): Promise<{ valLoss: number; perplexity: number }> { + return await weak(async () => { + model.eval(); + + let valLoss: number | null = null; + try { + const collated = collateFn(sequences.naturalSequences); + + let loss: Tensor | null = null; + const [inputs, targets] = collated.tensors; + [, loss] = model.forward(await inputs.to('gpu'), { + targets: await targets.to('gpu') + }); + + if (!loss) { + throw new Error(`No loss tensor returned from decoder-only model during validation`); + } + valLoss = await (await loss.to('cpu')).item(); + if (valLoss === null) { + throw new Error(`Validation loss item is null for decoder-only model`); + } + } finally { + model.train(); + } + + const perplexity = Math.exp(valLoss); + return { valLoss, perplexity }; + }); +} diff --git a/examples/finetuning/src/lib/train/validationHelpers.ts b/examples/finetuning/src/lib/train/validationHelpers.ts new file mode 100644 index 00000000..3c62ee45 --- /dev/null +++ b/examples/finetuning/src/lib/train/validationHelpers.ts @@ -0,0 +1,60 @@ +import type { TokenRollout } from '$lib/workspace/runs.svelte'; + +import { pin, weak } from '@piston-ml/piston-web'; + +import { generateGPTStream } from './generate'; +import { GPT } from './model/gpt'; + +export type DecoderGenerationOptions = { + maxTokens?: number; + stopTokens?: number[]; + temperature: number; + useKvCache: boolean; +}; + +export async function generateDecoderCompletions( + model: GPT, + startSequences: number[][], + options: DecoderGenerationOptions +): Promise<{ completions: TokenRollout[]; tokensPerSecond?: number }> { + const { maxTokens, stopTokens, temperature, useKvCache } = options; + + model.eval(); + + const completions: TokenRollout[] = []; + let lastTPS: number | undefined; + + for (let bi = 0; bi < startSequences.length; bi++) { + await weak(async () => { + const startTokens = startSequences[bi] ?? 
[]; + let seq: number[] = []; + const perStepProbs: number[][] = []; + let stepIndex = 0; + for await (const generationResult of generateGPTStream(model, startTokens, { + maxTokens, + stopTokens: stopTokens ?? [], + temperature, + useKvCache + })) { + seq = generationResult.sequences[0] ? [...generationResult.sequences[0]] : []; + lastTPS = generationResult.tokensPerSecond ?? lastTPS; + if (generationResult.probs) { + const probsArray = await (await generationResult.probs.to('cpu')).toVec(); + const [_b, v] = generationResult.probs.shape; + const row: number[] = []; + for (let vi = 0; vi < v; vi++) { + row.push(probsArray[0 * v + vi]); + } + perStepProbs.push(row); + } + stepIndex++; + if (typeof maxTokens === 'number' && stepIndex >= maxTokens) break; + } + completions[bi] = { tokenIds: seq, probs: pin(perStepProbs) }; + }); + } + + model.train(); + + return { completions, tokensPerSecond: lastTPS }; +} diff --git a/examples/finetuning/src/lib/workspace/checkpointStore.ts b/examples/finetuning/src/lib/workspace/checkpointStore.ts new file mode 100644 index 00000000..558bec9e --- /dev/null +++ b/examples/finetuning/src/lib/workspace/checkpointStore.ts @@ -0,0 +1,52 @@ +import { openDb, txRequest } from '$lib/dataUtils'; + +const DB_NAME = 'piston-checkpoint-store'; +const DB_VERSION = 1; +const STORE_CHECKPOINTS = 'checkpoints'; + +export class CheckpointStore { + private dbPromise: Promise | null = null; + + private get db(): Promise { + if (!this.dbPromise) + this.dbPromise = openDb(DB_NAME, DB_VERSION, (db) => { + if (!db.objectStoreNames.contains(STORE_CHECKPOINTS)) { + db.createObjectStore(STORE_CHECKPOINTS); + } + }); + return this.dbPromise; + } + + async get(runId: string): Promise { + const db = await this.db; + return txRequest(db, STORE_CHECKPOINTS, 'readonly', (s) => + s.get(runId) + ); + } + + async set(runId: string, bytes: Uint8Array | ArrayBuffer): Promise { + const db = await this.db; + const buf = bytes instanceof Uint8Array ? 
(bytes.buffer as ArrayBuffer) : bytes; + await txRequest(db, STORE_CHECKPOINTS, 'readwrite', (s) => s.put(buf, runId)); + } + + async has(runId: string): Promise { + const db = await this.db; + const res = await txRequest(db, STORE_CHECKPOINTS, 'readonly', (s) => + s.get(runId) + ); + return res != null; + } + + async delete(runId: string): Promise { + const db = await this.db; + await txRequest(db, STORE_CHECKPOINTS, 'readwrite', (s) => s.delete(runId)); + } + + async clear(): Promise { + const db = await this.db; + await txRequest(db, STORE_CHECKPOINTS, 'readwrite', (s) => s.clear()); + } +} + +export const checkpointStore = new CheckpointStore(); diff --git a/examples/finetuning/src/lib/workspace/config.svelte.ts b/examples/finetuning/src/lib/workspace/config.svelte.ts new file mode 100644 index 00000000..738d7fd4 --- /dev/null +++ b/examples/finetuning/src/lib/workspace/config.svelte.ts @@ -0,0 +1,466 @@ +import type { Config } from '$lib/workspace/config'; + +import { browser } from '$app/environment'; +import { buildDataset } from '$lib/train/data'; +import { getCollatedSampleData } from '$lib/train/data/collate'; +import { GPT2_BLOCK_SIZE, GPT2_VOCAB_SIZE } from '$lib/train/model/gpt'; +import { createDataloader } from '$lib/train/utils/model'; +import { SvelteURL, SvelteURLSearchParams } from 'svelte/reactivity'; + +import { getCurrentRun, getLatestRun } from './runs.svelte'; + +const CONFIG_DEFAULTS: Config = { + training: { + logSteps: 5, + batchSize: 1, + validation: { + present: true, + valSteps: 10, + batchSize: 8, + temperature: 0.0, + useKvCache: false, + completions: { + present: true, + decodingBatchSize: 1, + amount: 'subset', + subsetSize: 4 + } + }, + limitTraining: { + present: false, + steps: 50_000 + }, + labelSmoothing: { + present: false, + value: 1e-4 + }, + dropout: { + present: false, + embedding: 0.1, + transformer: { + attention: 0.1, + residual: 0.1 + } + }, + randomSeed: { + present: true, + value: 'sequence toy' + }, + gradNorm: { + track: true, + errorIfNonfinite: true + }, + clipGradNorm: { + present: false, + value: 1.0 + }, + useWeakTensorReferences: true, + sharedObjectAllocation: false, + cachingEnabled: false, + inplaceSupport: true, + vramLimitMb: { + present: true, + value: 4096 + }, + checkpointEverySteps: { + present: true, + value: 200 + }, + restartEverySteps: 1000 + }, + data: { + dataset: 'tinystories' + }, + model: { + type: 'distilgpt2' + }, + optimizer: { + type: 'Muon', + lr: 1e-3, + weightDecay: { + present: true, + value: 1e-2, + useWeightDecayGroups: true + }, + warmupSteps: { + present: true, + value: 100 + }, + lrScheduler: { + present: true, + type: 'cosine', + stepSchedule: { + stepSize: 100, + gamma: 0.8 + }, + constantSchedule: { + factor: 1 / 3, + totalIters: 100 + }, + cosineAnnealingSchedule: { + tMax: 500, + etaMin: 1e-4 + }, + exponentialSchedule: { + gamma: 0.999 + }, + linearSchedule: { + startFactor: 1.0, + endFactor: 1 / 3, + totalIters: 1000 + } + }, + adam: { + beta1: 0.9, + beta2: 0.999, + eps: 1e-8, + amsgrad: false + }, + sgd: { + momentum: 0.9, + dampening: 0, + nesterov: false + }, + muon: { + momentum: 0.95, + nsSteps: 5, + nesterov: true + } + }, + version: 1 +}; + +function computeEffectiveDefaults(): Config { + return JSON.parse(JSON.stringify(CONFIG_DEFAULTS)) as Config; +} + +/** + * Parses a value based on the type of the default value in the config. This is not wildly general, + * but it seems to work for the current config. + * @param valueStr - The value to parse. 
+ * @param defaultValue - The default value. + * @returns The parsed value. + */ +function parseValueBasedOnDefault(valueStr: string, defaultValue: unknown): unknown { + if (typeof defaultValue === 'boolean') { + return valueStr.toLowerCase() === 'true'; + } + if (typeof defaultValue === 'number') { + const num = parseFloat(valueStr); + return isNaN(num) ? defaultValue : num; + } + return valueStr; // Default to string if type is not boolean or number +} + +/** + * Builds a config from URL search params. + * @param params - The URL search params. + * @param defaults - The defaults to use if no URL search params are present. + * @returns The config. + */ +function buildConfigFromUrlParams(params: URLSearchParams, defaults: Config): Partial { + const configFromUrl: Record = {}; + + for (const [path, valueStr] of params) { + const keys = path.split('.'); + let currentLevel = configFromUrl; + let currentDefaultsLevel: unknown = defaults; + + try { + for (let i = 0; i < keys.length; i++) { + const key = keys[i]; + if ( + currentDefaultsLevel === undefined || + typeof currentDefaultsLevel !== 'object' || + currentDefaultsLevel === null + ) { + throw new Error(`Invalid config path from URL: ${path}`); + } + currentDefaultsLevel = (currentDefaultsLevel as Record)[key]; + + if (i < keys.length - 1) { + if ( + !(currentLevel as Record)[key] || + typeof (currentLevel as Record)[key] !== 'object' + ) { + (currentLevel as Record)[key] = {}; + } + currentLevel = (currentLevel as Record)[key] as Record; + } else { + (currentLevel as Record)[key] = parseValueBasedOnDefault( + valueStr, + currentDefaultsLevel + ); + } + } + } catch (e) { + console.warn((e as Error).message); + continue; // Skip this parameter if path is invalid or type mismatch + } + } + return configFromUrl as Partial; +} + +function mergeDeep(target: Record, source: Record) { + for (const key in source) { + if (Object.prototype.hasOwnProperty.call(source, key)) { + const sourceVal = source[key]; + let targetKeyAsObject = target[key] as Record; + + if (sourceVal && typeof sourceVal === 'object' && !Array.isArray(sourceVal)) { + if ( + !targetKeyAsObject || + typeof targetKeyAsObject !== 'object' || + Array.isArray(targetKeyAsObject) + ) { + targetKeyAsObject = {}; + target[key] = targetKeyAsObject; + } + mergeDeep(targetKeyAsObject, sourceVal as Record); + } else if (sourceVal !== undefined) { + target[key] = sourceVal; + } + } + } +} + +/** + * Gets the initial config from the URL search params, or the defaults if no URL search params are + * present. + * @returns The initial config. + */ +function getInitialConfig(): Config { + // Start with effective defaults, possibly from URL 'preset' + let base: Config = JSON.parse(JSON.stringify(CONFIG_DEFAULTS)); + if (typeof window !== 'undefined' && window.location && window.URLSearchParams) { + try { + const params = new URLSearchParams(window.location.search); + base = computeEffectiveDefaults(); + const configOverrides = buildConfigFromUrlParams(params, base); + const initial = JSON.parse(JSON.stringify(base)); + mergeDeep(initial, configOverrides); + return initial; + } catch (e) { + console.error('Error processing config from URL, using defaults:', e); + return JSON.parse(JSON.stringify(CONFIG_DEFAULTS)); + } + } + return base; +} + +export const config = $state(getInitialConfig()); +const configDefaults = $derived(computeEffectiveDefaults()); + +/** + * Resets one or more config values to their defaults using dot-separated paths. 
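+ * e.g. `resetConfigToDefaults('optimizer.lr')` or
+ * `resetConfigToDefaults(['optimizer.lr', 'training.batchSize'])`.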
+ */ +export function resetConfigToDefaults(paths: string | string[]) { + const pathList = Array.isArray(paths) ? paths : [paths]; + + for (const path of pathList) { + const defaultValue = getValueAtPath(configDefaults as unknown as Record, path); + if (defaultValue === undefined) { + console.warn(`resetConfigToDefaults: Unknown config path "${path}"`); + continue; + } + // Deep clone to avoid mutating the CONFIG_DEFAULTS reference + const cloned = deepClone(defaultValue); + const ok = setValueAtPath(config as unknown as Record, path, cloned); + if (!ok) { + console.warn(`resetConfigToDefaults: Failed to set value for path "${path}"`); + } + } +} + +function deepClone(value: T): T { + return JSON.parse(JSON.stringify(value)) as T; +} + +export function getConfigDefaultValue(path: string): unknown { + const val = getValueAtPath(configDefaults as unknown as Record, path); + return deepClone(val); +} + +export function equalsConfigDefault(path: string): boolean { + const current = getValueAtPath(config as unknown as Record, path); + const def = getValueAtPath(configDefaults as unknown as Record, path); + return valuesDeepEqual(current, def); +} + +function valuesDeepEqual(a: unknown, b: unknown): boolean { + try { + return JSON.stringify(a) === JSON.stringify(b); + } catch { + return a === b; + } +} + +function getValueAtPath(obj: Record, path: string): unknown { + const keys = path.split('.'); + let current: unknown = obj; + for (const key of keys) { + if ( + current === null || + current === undefined || + typeof current !== 'object' || + !(key in (current as Record)) + ) { + return undefined; + } + current = (current as Record)[key]; + } + return current; +} + +function setValueAtPath(target: Record, path: string, value: unknown): boolean { + const keys = path.split('.'); + let current: Record = target; + for (let i = 0; i < keys.length - 1; i++) { + const key = keys[i]; + const next = current[key]; + if (next === undefined || next === null || typeof next !== 'object' || Array.isArray(next)) { + // Only create an object if we are not overwriting a non-object path + current[key] = {}; + } + current = current[key] as Record; + } + const lastKey = keys[keys.length - 1]; + current[lastKey] = value as unknown; + return true; +} + +/** + * Flattens only the non-default values from an object by comparing against a defaults object. + * Returns a map of dot-separated paths to stringified values. + */ +function flattenNonDefault( + obj: Record, + defaults: Record, + prefix: string = '' +): Record { + const params: Record = {}; + for (const key in obj) { + if (!Object.prototype.hasOwnProperty.call(obj, key)) continue; + const newPrefix = prefix ? `${prefix}.${key}` : key; + const value = obj[key]; + const defaultValue = (defaults ?? {})[key]; + + if (value !== null && typeof value === 'object' && !Array.isArray(value)) { + const defaultChild = + defaultValue !== null && typeof defaultValue === 'object' && !Array.isArray(defaultValue) + ? 
(defaultValue as Record) + : ({} as Record); + const nested = flattenNonDefault(value as Record, defaultChild, newPrefix); + Object.assign(params, nested); + } else if (value !== undefined) { + if (defaultValue === undefined || !valuesDeepEqual(value, defaultValue)) { + params[newPrefix] = String(value); + } + } + } + return params; +} + +export function initSharedConfigUrlSync() { + if (typeof window !== 'undefined' && window.history && window.URL) { + $effect(() => { + const configSnapshot = $state.snapshot(config); + const flatParams = flattenNonDefault( + configSnapshot, + configDefaults as unknown as Record + ); + // If any parameters are present, also include the current config version + if (Object.keys(flatParams).length > 0) { + flatParams['version'] = String(configSnapshot.version); + } + const searchParamsString = new SvelteURLSearchParams(flatParams).toString(); + + const currentUrl = new SvelteURL(window.location.href); + currentUrl.search = searchParamsString; // This replaces the entire search string + + // Only call replaceState if the URL actually changed to avoid flooding history + if (window.location.href !== currentUrl.href) { + window.history.replaceState({}, '', currentUrl.toString()); + } + }); + } +} + +export function replaceConfig(next: Config) { + config['training'] = deepClone(next.training); + config['data'] = deepClone(next.data); + config['model'] = deepClone(next.model); + config['optimizer'] = deepClone(next.optimizer); + config['version'] = next.version; + validateConfig(); +} + +function datasetFromConfig(config: Config) { + // Only build the full dataset/tokenizer pipeline in the browser. During SSR/prerender + // we return a lightweight placeholder object so that components can render without + // triggering network fetches (which would use Node's global fetch with a relative URL). + if (!browser) { + return { + dataset: null, + tokenizer: null, + sampleData: null, + collated: null + }; + } + + const dataset = buildDataset(config, 'train'); + const [, collateFn] = createDataloader(config, dataset, null); + + const collatedData = getCollatedSampleData(dataset, collateFn, 4); + + return { + dataset, + vocabSize: GPT2_VOCAB_SIZE, + blockSize: GPT2_BLOCK_SIZE, + tokenizer: dataset.tokenizer, + sampleData: collatedData.then((data) => { + const firstSample = data.collated[0]; + return { + hasPrompt: 'prompt' in firstSample && (firstSample.prompt?.length ?? 0) > 0, + samples: data.samples, + collated: data.collated + }; + }) + }; +} + +const currentDataset = $derived(datasetFromConfig(config)); + +export function getCurrentDataset() { + return currentDataset; +} + +const currentRunDataset = $derived.by(() => { + const currentRun = getCurrentRun(); + return currentRun?.config ? datasetFromConfig(currentRun.config) : null; +}); + +export function getCurrentRunDataset() { + return currentRunDataset; +} + +const latestRunDataset = $derived.by(() => { + const latestRun = getLatestRun(); + return latestRun?.config ? datasetFromConfig(latestRun.config) : null; +}); + +export function getLatestRunDataset() { + return latestRunDataset; +} + +export function validateConfig() { + // There are a few things that can still slip through the cracks, so we deal with those here. 
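+	// Currently the only such constraint: the requested completions subset can never
+	// exceed the validation batch it is drawn from, so it is clamped here.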
+ + if ( + config.training.validation.completions.present && + config.training.validation.completions.amount === 'subset' && + config.training.validation.completions.subsetSize > config.training.validation.batchSize + ) { + config.training.validation.completions.subsetSize = config.training.validation.batchSize; + } +} diff --git a/examples/finetuning/src/lib/workspace/config.ts b/examples/finetuning/src/lib/workspace/config.ts new file mode 100644 index 00000000..792e42c3 --- /dev/null +++ b/examples/finetuning/src/lib/workspace/config.ts @@ -0,0 +1,283 @@ +import type { DATASET_CONFIG_DEFAULTS } from '$lib/train/data'; +import type { + ConstantConfig, + CosineAnnealingConfig, + ExponentialConfig, + LinearConfig, + StepConfig +} from '@piston-ml/piston-web'; + +export interface TransformerDropoutConfig { + present: boolean; + embedding: number; + transformer: { + attention: number; + residual: number; + }; +} + +export interface ValidationCompletionsConfig { + present: boolean; + decodingBatchSize: number; + amount: 'all' | 'subset'; + subsetSize: number; +} + +export interface ValidationConfig { + present: boolean; + valSteps: number; + batchSize: number; + temperature: number; + completions: ValidationCompletionsConfig; + useKvCache: boolean; +} + +export interface TrainingConfig { + logSteps: number; + limitTraining: { + present: boolean; + steps: number; + }; + checkpointEverySteps: { + present: boolean; + value: number; + }; + batchSize: number; + dropout: TransformerDropoutConfig; + validation: ValidationConfig; + labelSmoothing: { + present: boolean; + value: number; + }; + randomSeed: { + present: boolean; + value: string; + }; + vramLimitMb: { + present: boolean; + value: number; + }; + gradNorm: { + track: boolean; + errorIfNonfinite: boolean; + }; + clipGradNorm: { + present: boolean; + value: number; + }; + useWeakTensorReferences: boolean; + sharedObjectAllocation: boolean; + cachingEnabled: boolean; + inplaceSupport: boolean; + restartEverySteps: number; +} + +export interface DataConfig { + dataset: keyof typeof DATASET_CONFIG_DEFAULTS; +} + +export type ProjectionInitializationStrategy = 'layer-scaled' | 'zero'; +export interface ProjectionInitializationConfig { + present: boolean; + strategy: ProjectionInitializationStrategy; +} + +export interface TransformerInitializationConfig { + present: boolean; + std: number; + projections: { + attention: ProjectionInitializationConfig; + mlp: ProjectionInitializationConfig; + lmHead: ProjectionInitializationConfig; + }; +} + +export interface LSTMInitializationConfig { + forgetGateBias: { + present: boolean; + value: number; + }; +} + +export type GPT2ModelType = 'distilgpt2' | 'gpt2' | 'gpt2-medium' | 'gpt2-large' | 'gpt2-xl'; + +export interface ModelConfig { + type: GPT2ModelType; +} + +export interface OptimizerConfig { + type: 'AdamW' | 'Adam' | 'SGD' | 'Muon'; + lr: number; + weightDecay: { + present: boolean; + value: number; + useWeightDecayGroups: boolean; + }; + warmupSteps: { present: boolean; value: number }; + lrScheduler: { + present: boolean; + type: string; + stepSchedule: StepConfig; + constantSchedule: ConstantConfig; + cosineAnnealingSchedule: CosineAnnealingConfig; + exponentialSchedule: ExponentialConfig; + linearSchedule: LinearConfig; + }; + adam: { + beta1: number; + beta2: number; + eps: number; + amsgrad: boolean; + }; + sgd: { + dampening: number; + momentum: number; + nesterov: boolean; + }; + muon: { + nsSteps: number; + momentum: number; + nesterov: boolean; + }; +} + +export interface Config { + 
version: number; + training: TrainingConfig; + data: DataConfig; + model: ModelConfig; + optimizer: OptimizerConfig; +} + +export type ConfigItemDescription = + | { + shortName: string; + } + | string + | [string, number] + | null; + +type ReplaceValues = T extends object ? { [K in keyof T]: ReplaceValues } : V; + +export type ConfigValues = ReplaceValues; + +export const CONFIG_DESCRIPTIONS: ConfigValues = { + training: { + logSteps: 'log steps', + batchSize: 'batch', + clipGradNorm: { + present: 'clip grad norm', + value: 'clip grad norm' + }, + validation: { + present: 'val', + valSteps: 'val steps', + batchSize: 'val size', + temperature: 'val temp', + useKvCache: 'val kv cache', + completions: { + present: 'completions', + decodingBatchSize: 'completions batch', + amount: 'completions amount strategy', + subsetSize: 'completions subset' + } + }, + limitTraining: { + present: 'limit train', + steps: 'max steps' + }, + labelSmoothing: { + present: 'smoothing', + value: 'smoothing' + }, + dropout: { + present: 'dropout', + embedding: 'dropout emb', + transformer: { + attention: 'dropout attn', + residual: 'dropout resid' + } + }, + randomSeed: { + present: 'seed', + value: 'seed' + }, + gradNorm: { + track: 'track grad norm', + errorIfNonfinite: 'error nonfinite' + }, + useWeakTensorReferences: 'weak tensor refs', + sharedObjectAllocation: 'shared objs', + cachingEnabled: 'caching', + inplaceSupport: 'inplace', + vramLimitMb: { + present: 'vram lim', + value: 'vram lim' + }, + checkpointEverySteps: { + present: 'checkpointing', + value: 'checkpoint steps' + }, + restartEverySteps: 'restart steps' + }, + data: { + dataset: 'dataset' + }, + model: { + type: 'model' + }, + optimizer: { + type: 'optim', + lr: 'lr', + weightDecay: { + present: 'decay', + value: 'decay', + useWeightDecayGroups: 'decay groups' + }, + warmupSteps: { + present: 'warmup', + value: 'warmup steps' + }, + lrScheduler: { + present: 'lr sched', + type: 'lr sched', + stepSchedule: { + stepSize: 'sched step', + gamma: 'sched gamma' + }, + constantSchedule: { + factor: 'const lr factor', + totalIters: 'const lr total' + }, + cosineAnnealingSchedule: { + tMax: 'cos lr tmax', + etaMin: 'cos lr eta min' + }, + exponentialSchedule: { + gamma: 'exp lr gamma' + }, + linearSchedule: { + startFactor: 'lin lr start', + endFactor: 'lin lr end', + totalIters: 'lin lr total' + } + }, + adam: { + beta1: 'adam beta1', + beta2: 'adam beta2', + eps: 'adam eps', + amsgrad: 'adam ams' + }, + sgd: { + momentum: 'sgd moment', + dampening: 'sgd damp', + nesterov: 'sgd nester' + }, + muon: { + momentum: 'muon moment', + nsSteps: 'muon nssteps', + nesterov: 'muon nester' + } + }, + version: null +}; diff --git a/examples/finetuning/src/lib/workspace/lastSessionStore.ts b/examples/finetuning/src/lib/workspace/lastSessionStore.ts new file mode 100644 index 00000000..b15ef4f4 --- /dev/null +++ b/examples/finetuning/src/lib/workspace/lastSessionStore.ts @@ -0,0 +1,115 @@ +import { openDb, txRequest } from '$lib/dataUtils'; +import { SvelteMap } from 'svelte/reactivity'; + +import type { Config } from './config'; +import type { RunData, StepData } from './runs.svelte'; + +export type SavedRun = Omit & { metrics: Record }; + +const DB_NAME = 'piston-last-session-store'; +const DB_VERSION = 1; +const STORE_SESSION = 'session'; +const STORE_CHECKPOINT = 'checkpoint'; +const STORE_META = 'meta'; +const META_LAST_RUN_ID_KEY = 'lastRunId'; + +export function serializeRun(run: RunData): SavedRun { + const metrics = 
Object.fromEntries([...run.metrics.entries()].map(([k, v]) => [k, v.data])); + return { ...run, metrics }; +} + +export function deserializeRun(saved: SavedRun): RunData { + return { + ...saved, + metrics: new SvelteMap( + Object.entries(saved.metrics).map(([k, v]) => [k, { metricName: k, data: v }]) + ) + }; +} + +class LastSessionStore { + private dbPromise: Promise | null = null; + + private get db(): Promise { + if (!this.dbPromise) + this.dbPromise = openDb(DB_NAME, DB_VERSION, (db) => { + if (!db.objectStoreNames.contains(STORE_SESSION)) db.createObjectStore(STORE_SESSION); + if (!db.objectStoreNames.contains(STORE_CHECKPOINT)) db.createObjectStore(STORE_CHECKPOINT); + if (!db.objectStoreNames.contains(STORE_META)) db.createObjectStore(STORE_META); + }); + return this.dbPromise; + } + + private async getLastRunId(db: IDBDatabase): Promise { + return txRequest(db, STORE_META, 'readonly', (s) => + s.get(META_LAST_RUN_ID_KEY) + ); + } + + async get(): Promise<{ run: RunData; checkpoint: Uint8Array } | null> { + const db = await this.db; + const runId = await this.getLastRunId(db); + if (!runId) return null; + const [run, buffer] = await Promise.all([ + txRequest(db, STORE_SESSION, 'readonly', (s) => s.get(runId)), + txRequest(db, STORE_CHECKPOINT, 'readonly', (s) => s.get(runId)) + ]); + if (!run || !buffer) return null; + return { run: deserializeRun(run), checkpoint: new Uint8Array(buffer) }; + } + + async set(run: RunData, checkpoint: Uint8Array | ArrayBuffer): Promise { + const db = await this.db; + const buf = checkpoint instanceof Uint8Array ? (checkpoint.buffer as ArrayBuffer) : checkpoint; + // Clear all existing sessions; we'll want to remove this if we ever support multiple + // persistence. + await this.delete(); + await Promise.all([ + txRequest(db, STORE_SESSION, 'readwrite', (s) => s.put(serializeRun(run), run.runId)), + txRequest(db, STORE_CHECKPOINT, 'readwrite', (s) => s.put(buf, run.runId)), + txRequest(db, STORE_META, 'readwrite', (s) => s.put(run.runId, META_LAST_RUN_ID_KEY)) + ]); + } + + /** + * Update only the config of the last run in storage without touching other data. + * This avoids fully deserializing metrics and only rewrites the config field. + */ + async updateConfig(mutator: (config: Config) => Config | void): Promise { + const db = await this.db; + const runId = await this.getLastRunId(db); + if (!runId) return; + + const saved = await txRequest(db, STORE_SESSION, 'readonly', (s) => + s.get(runId) + ); + if (!saved) return; + + const maybeNew = mutator(saved.config as Config); + const newConfig = (maybeNew ? 
maybeNew : saved.config) as Config; + const updated: SavedRun = { ...saved, config: newConfig }; + await txRequest(db, STORE_SESSION, 'readwrite', (s) => s.put(updated, runId)); + } + + async exists(): Promise { + const db = await this.db; + const runId = await this.getLastRunId(db); + if (!runId) return false; + const [run, buffer] = await Promise.all([ + txRequest(db, STORE_SESSION, 'readonly', (s) => s.get(runId)), + txRequest(db, STORE_CHECKPOINT, 'readonly', (s) => s.get(runId)) + ]); + return run != null && buffer != null; + } + + async delete(): Promise { + const db = await this.db; + await Promise.all([ + txRequest(db, STORE_SESSION, 'readwrite', (s) => s.clear()), + txRequest(db, STORE_CHECKPOINT, 'readwrite', (s) => s.clear()), + txRequest(db, STORE_META, 'readwrite', (s) => s.delete(META_LAST_RUN_ID_KEY)) + ]); + } +} + +export const lastSessionStore = new LastSessionStore(); diff --git a/examples/finetuning/src/lib/workspace/localStorage.svelte.ts b/examples/finetuning/src/lib/workspace/localStorage.svelte.ts new file mode 100644 index 00000000..5422c14f --- /dev/null +++ b/examples/finetuning/src/lib/workspace/localStorage.svelte.ts @@ -0,0 +1,98 @@ +import { tick } from 'svelte'; + +export class LocalStorage { + #key: string; + #version = $state(0); + #listeners = 0; + #value: T | undefined; + + #handler = (e: StorageEvent) => { + if (e.storageArea !== localStorage) return; + if (e.key !== this.#key) return; + + this.#version += 1; + }; + + constructor(key: string, initial?: T) { + this.#key = key; + this.#value = initial; + + if (typeof localStorage !== 'undefined') { + if (localStorage.getItem(key) == null) { + localStorage.setItem(key, JSON.stringify(initial)); + } + } + } + + get current() { + // eslint-disable-next-line @typescript-eslint/no-unused-expressions + this.#version; + + const localStorageItem = + typeof localStorage !== 'undefined' ? localStorage.getItem(this.#key) : null; + const root = localStorageItem !== null ? 
JSON.parse(localStorageItem) : this.#value; + + const proxies = new WeakMap(); + + const proxy = (value: unknown) => { + if (typeof value !== 'object' || value === null) { + return value; + } + + let p = proxies.get(value); + + if (!p) { + p = new Proxy(value, { + get: (target, property) => { + // eslint-disable-next-line @typescript-eslint/no-unused-expressions + this.#version; + return proxy(Reflect.get(target, property)); + }, + set: (target, property, value) => { + this.#version += 1; + Reflect.set(target, property, value); + + if (typeof localStorage !== 'undefined') { + localStorage.setItem(this.#key, JSON.stringify(root)); + } + + return true; + } + }); + + proxies.set(value, p); + } + + return p; + }; + + if ($effect.tracking()) { + $effect(() => { + if (this.#listeners === 0) { + window.addEventListener('storage', this.#handler); + } + + this.#listeners += 1; + + return () => { + tick().then(() => { + this.#listeners -= 1; + if (this.#listeners === 0) { + window.removeEventListener('storage', this.#handler); + } + }); + }; + }); + } + + return proxy(root); + } + + set current(value) { + if (typeof localStorage !== 'undefined') { + localStorage.setItem(this.#key, JSON.stringify(value)); + } + + this.#version += 1; + } +} diff --git a/examples/finetuning/src/lib/workspace/runs.svelte.ts b/examples/finetuning/src/lib/workspace/runs.svelte.ts new file mode 100644 index 00000000..1b8e9ac5 --- /dev/null +++ b/examples/finetuning/src/lib/workspace/runs.svelte.ts @@ -0,0 +1,443 @@ +import type { ValidationStep } from '$lib/train/validation'; + +import { generateMemorableName } from '$lib/workspace/utils'; +import { SvelteMap, SvelteSet } from 'svelte/reactivity'; + +import { type Config, CONFIG_DESCRIPTIONS } from './config'; + +export type BaseStepData = { step: number }; + +export type Point = BaseStepData & { y: number }; + +export type TokenRollout = { + tokenIds: number[]; + probs: number[][]; +}; + +export type VisualizationStep = BaseStepData & { + type: 'visualization'; +}; + +export type StepData = Point | ValidationStep; + +export type MetricData = { + metricName: string; + data: StepData[]; +}; + +export type RunData = { + runId: string; + color: string; + config: Config; + metrics: SvelteMap; + step: number; + lastUpdated: number; + createdAt: number; + diffSummary: string; +}; + +export type RunMeta = { + runId: string | null; + config: Config | null; +}; + +// A palette of distinct colors for runs +const RUN_COLORS = ['#ab6aea', '#f5493b', '#dfa300', '#3bbc4a', '#00b7c0', '#4475f6']; + +export const PREFIX_BOOSTS: Record = { data: 10 }; + +type DiffItem = { + label: string; + path: string[]; + display: string; + boost: number; +}; + +function pathToString(path: ReadonlyArray): string { + return path.join('.'); +} + +function getAtPath(obj: unknown, path: ReadonlyArray): unknown { + let cur = obj; + for (const key of path) { + if (cur == null || typeof cur !== 'object') return undefined; + cur = (cur as Record)[key]; + } + return cur; +} + +function getDescriptor(path: ReadonlyArray): { label: string | null; itemBoost?: number } { + let cur = CONFIG_DESCRIPTIONS as unknown; + + for (const key of path) { + if (cur == null || typeof cur !== 'object') return { label: null }; + cur = (cur as Record)[key]; + } + if (cur == null) return { label: null }; + if (Array.isArray(cur)) { + const [shortName, boost] = cur; + return { label: shortName, itemBoost: typeof boost === 'number' ? 
diff --git a/examples/finetuning/src/lib/workspace/runs.svelte.ts b/examples/finetuning/src/lib/workspace/runs.svelte.ts
new file mode 100644
index 00000000..1b8e9ac5
--- /dev/null
+++ b/examples/finetuning/src/lib/workspace/runs.svelte.ts
@@ -0,0 +1,443 @@
+import type { ValidationStep } from '$lib/train/validation';
+
+import { generateMemorableName } from '$lib/workspace/utils';
+import { SvelteMap, SvelteSet } from 'svelte/reactivity';
+
+import { type Config, CONFIG_DESCRIPTIONS } from './config';
+
+export type BaseStepData = { step: number };
+
+export type Point = BaseStepData & { y: number };
+
+export type TokenRollout = {
+	tokenIds: number[];
+	probs: number[][];
+};
+
+export type VisualizationStep = BaseStepData & {
+	type: 'visualization';
+};
+
+export type StepData = Point | ValidationStep;
+
+export type MetricData = {
+	metricName: string;
+	data: StepData[];
+};
+
+export type RunData = {
+	runId: string;
+	color: string;
+	config: Config;
+	metrics: SvelteMap<string, MetricData>;
+	step: number;
+	lastUpdated: number;
+	createdAt: number;
+	diffSummary: string;
+};
+
+export type RunMeta = {
+	runId: string | null;
+	config: Config | null;
+};
+
+// A palette of distinct colors for runs
+const RUN_COLORS = ['#ab6aea', '#f5493b', '#dfa300', '#3bbc4a', '#00b7c0', '#4475f6'];
+
+export const PREFIX_BOOSTS: Record<string, number> = { data: 10 };
+
+type DiffItem = {
+	label: string;
+	path: string[];
+	display: string;
+	boost: number;
+};
+
+function pathToString(path: ReadonlyArray<string>): string {
+	return path.join('.');
+}
+
+function getAtPath(obj: unknown, path: ReadonlyArray<string>): unknown {
+	let cur = obj;
+	for (const key of path) {
+		if (cur == null || typeof cur !== 'object') return undefined;
+		cur = (cur as Record<string, unknown>)[key];
+	}
+	return cur;
+}
+
+function getDescriptor(path: ReadonlyArray<string>): { label: string | null; itemBoost?: number } {
+	let cur = CONFIG_DESCRIPTIONS as unknown;
+
+	for (const key of path) {
+		if (cur == null || typeof cur !== 'object') return { label: null };
+		cur = (cur as Record<string, unknown>)[key];
+	}
+	if (cur == null) return { label: null };
+	if (Array.isArray(cur)) {
+		const [shortName, boost] = cur;
+		return { label: shortName, itemBoost: typeof boost === 'number' ? boost : undefined };
+	}
+	if (typeof cur === 'string') return { label: cur };
+	if (typeof cur === 'object' && 'shortName' in cur && typeof cur.shortName === 'string') {
+		return { label: cur.shortName };
+	}
+	return { label: null };
+}
+
+function maxPrefixBoost(path: ReadonlyArray<string>, boosts: Record<string, number>): number {
+	let maxB = 0;
+	for (let i = 1; i <= path.length; i++) {
+		const pref = path.slice(0, i).join('.');
+		const b = boosts[pref];
+		if (typeof b === 'number' && b > maxB) maxB = b;
+	}
+	return maxB;
+}
+
+function orderOfMagnitude(n: number): number {
+	if (n === 0) return 0;
+	return Math.floor(Math.log10(Math.abs(n)));
+}
+
+function formatScientific(n: number, significantDigits = 1): string {
+	if (n === 0) return '0';
+	const s = n.toExponential(Math.max(0, significantDigits - 1));
+	const [mant, expRaw] = s.split('e');
+	const mantTrim = mant
+		.replace(/\.0+$/, '')
+		.replace(/(\.[0-9]*?)0+$/, '$1')
+		.replace(/\.$/, '');
+	const exp = String(parseInt(expRaw, 10));
+	return `${mantTrim}e${exp}`;
+}
+
+function formatNumberDiff(a: number, b: number): string {
+	const bothSmallInts =
+		Number.isInteger(a) && Number.isInteger(b) && Math.abs(a) < 100 && Math.abs(b) < 100;
+	if (bothSmallInts) return `${a}→${b}`;
+
+	// Prefer a compact suffix form for small same-magnitude non-negative integer
+	// changes; otherwise fall back to concise scientific notation
+	const expA = orderOfMagnitude(a);
+	const expB = orderOfMagnitude(b);
+	const bothInts = Number.isInteger(a) && Number.isInteger(b);
+	const base = Math.pow(10, expA);
+	if (
+		bothInts &&
+		Math.sign(a) >= 0 &&
+		Math.sign(b) >= 0 &&
+		expA === expB &&
+		// Only use suffix form for small changes
+		Math.abs(b - a) < base
+	) {
+		const lead = Math.floor(a / base);
+		const suffixA = a - lead * base;
+		const suffixB = b - lead * base;
+		return `${lead}e${expA}+(${suffixA}→${suffixB})`;
+	}
+	const sa = formatScientific(a, 1);
+	const sb = formatScientific(b, 1);
+	return `${sa}⥵${sb}`;
+}
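+
+// Worked examples for the formatter above (derived by hand from the code, so
+// treat them as a sketch of the intent rather than a test fixture):
+//
+//   formatNumberDiff(3, 5)         -> '3→5'          // small integers, exact
+//   formatNumberDiff(1000, 1250)   -> '1e3+(0→250)'  // same magnitude, suffix form
+//   formatNumberDiff(0.001, 0.002) -> '1e-3⥵2e-3'    // rounded scientific fallback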
+
+function formatValue(v: unknown): string {
+	if (typeof v === 'number') return formatScientific(v, 1);
+	if (typeof v === 'string') return v;
+	if (typeof v === 'boolean') return v ? 'true' : 'false';
+	return String(v);
+}
+
+function comparePrimitive(a: unknown, b: unknown): boolean {
+	// Strict equality suffices for primitives we expect here
+	return a === b;
+}
+
+function collectLeafPaths(obj: unknown, prefix: string[] = []): string[][] {
+	const out: string[][] = [];
+	if (obj == null || typeof obj !== 'object') return out;
+	for (const key of Object.keys(obj)) {
+		const nextPath = [...prefix, key];
+		const val = (obj as Record<string, unknown>)[key];
+		if (val != null && typeof val === 'object') {
+			// Only traverse objects that are not arrays
+			if (!Array.isArray(val)) out.push(...collectLeafPaths(val, nextPath));
+		} else {
+			out.push(nextPath);
+		}
+	}
+	return out;
+}
+
+export function describeConfigDiff(
+	prev: Config | null,
+	curr: Config,
+	opts?: { topK?: number; prefixBoosts?: Record<string, number> }
+): string {
+	const topK = opts?.topK ?? 3;
+	const boosts = opts?.prefixBoosts ?? PREFIX_BOOSTS;
+	if (!prev) return 'initial experiment';
+
+	const prevPaths = collectLeafPaths(prev);
+	const currPaths = collectLeafPaths(curr);
+	const allKeys = new SvelteSet<string>();
+	for (const p of prevPaths) allKeys.add(pathToString(p));
+	for (const p of currPaths) allKeys.add(pathToString(p));
+
+	const diffs: DiffItem[] = [];
+	for (const key of allKeys) {
+		const path = key.split('.');
+		const { label, itemBoost } = getDescriptor(path);
+		if (!label) continue;
+		const a = getAtPath(prev, path);
+		const b = getAtPath(curr, path);
+		if (comparePrimitive(a, b)) continue;
+
+		const prefixB = maxPrefixBoost(path, boosts);
+		let effBoost = prefixB;
+		if (typeof itemBoost === 'number') {
+			if (itemBoost < prefixB) {
+				console.warn(
+					`item boost lower than parent prefix; path=${key} itemBoost=${itemBoost} parentMax=${prefixB}`
+				);
+			}
+			effBoost = Math.max(effBoost, itemBoost);
+		}
+
+		let display: string;
+		if (typeof a === 'boolean' && typeof b === 'boolean') {
+			display = b ? `+${label}` : `-${label}`;
+		} else if (typeof a === 'number' && typeof b === 'number') {
+			display = `${label}:${formatNumberDiff(a, b)}`;
+		} else {
+			display = `${label}:${formatValue(a)}→${formatValue(b)}`;
+		}
+
+		diffs.push({ label, path, display, boost: effBoost });
+	}
+
+	diffs.sort((x, y) => {
+		if (y.boost !== x.boost) return y.boost - x.boost;
+		return x.label.localeCompare(y.label);
+	});
+
+	if (diffs.length === 0) return 'no changes';
+	const top = diffs.slice(0, topK).map((d) => d.display);
+	const rest = diffs.length - top.length;
+	return rest > 0 ? `${top.join(', ')}, and ${rest} more` : top.join(', ');
+}
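+
+// Example (hypothetical config fields, both present in CONFIG_DESCRIPTIONS):
+// lowering a learning rate and enabling a boolean flag might produce
+//
+//   describeConfigDiff(prev, curr) -> 'lr:1e-3⥵5e-4, +bias'
+//
+// with higher-boost entries (e.g. anything under `data`, via PREFIX_BOOSTS)
+// sorted to the front.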
+
+export const runsMap = new SvelteMap<string, RunData>();
+export const runCounter = $state({ current: 0 });
+export const currentRun = $state<{ current: RunMeta | null }>({
+	current: null
+});
+
+export function getRuns(): ReadonlyArray<RunData> {
+	return [...runsMap.values()].sort((a, b) => a.runId.localeCompare(b.runId));
+}
+
+export function getAllMetricNames(): ReadonlyArray<string> {
+	const names = new SvelteSet<string>();
+	for (const run of runsMap.values()) {
+		for (const metricName of run.metrics.keys()) {
+			names.add(metricName);
+		}
+	}
+	return [...names].sort();
+}
+
+/**
+ * Gets all metric names from the last n runs.
+ * @param n - The number of recent runs to consider
+ * @returns Array of unique metric names sorted alphabetically
+ */
+export function getMetricNamesFromLastNRuns(n: number): ReadonlyArray<string> {
+	const names = new SvelteSet<string>();
+	const recentRuns = getLastNRuns(n);
+	for (const run of recentRuns) {
+		for (const metricName of run.metrics.keys()) {
+			names.add(metricName);
+		}
+	}
+	return [...names].sort();
+}
+
+export function getCurrentRun(): RunMeta | null {
+	return currentRun.current;
+}
+
+export function getLatestRun(): RunMeta | null {
+	if (currentRun.current !== null) {
+		return currentRun.current;
+	}
+
+	const allRuns = [...runsMap.values()];
+	allRuns.sort((a, b) => b.lastUpdated - a.lastUpdated);
+	return allRuns[0] ?? null;
+}
+
+/**
+ * Gets the last n runs, sorted by most recently updated (most recent first).
+ */
+export function getLastNRuns(n: number): ReadonlyArray<RunData> {
+	const allRuns = [...runsMap.values()];
+
+	// Sort by lastUpdated descending (most recent first)
+	allRuns.sort((a, b) => b.lastUpdated - a.lastUpdated);
+
+	// If we have fewer runs than requested, return all of them
+	if (allRuns.length <= n) {
+		return allRuns;
+	}
+
+	return allRuns.slice(0, n);
+}
+
+export function clearPastRuns(): void {
+	const keepRun = currentRun.current;
+	if (keepRun === null) {
+		// No current run tracked; clear everything
+		runsMap.clear();
+		return;
+	}
+	for (const runId of [...runsMap.keys()]) {
+		if (runId !== keepRun.runId) {
+			runsMap.delete(runId);
+		}
+	}
+}
+
+export function newRun(config: Config, id?: string): RunData {
+	if (id && runsMap.has(id)) {
+		throw new Error(`Run with id ${id} already exists`);
+	}
+
+	const runId = id ?? generateMemorableName(runCounter.current);
+	const color = RUN_COLORS[runCounter.current % RUN_COLORS.length];
+	const now = Date.now();
+	// Find the baseline as the immediately-previous run by creation time
+	const existingRuns = [...runsMap.values()];
+	existingRuns.sort((a, b) => (a.createdAt ?? a.lastUpdated) - (b.createdAt ?? b.lastUpdated));
+	const prevRun = existingRuns.length > 0 ? existingRuns[existingRuns.length - 1] : undefined;
+	const diffSummary = describeConfigDiff(prevRun?.config ?? null, config, {
+		topK: 3,
+		prefixBoosts: PREFIX_BOOSTS
+	});
+
+	const run = {
+		runId: runId,
+		config,
+		color: color,
+		metrics: new SvelteMap<string, MetricData>(),
+		step: 0,
+		lastUpdated: now,
+		createdAt: now,
+		diffSummary
+	};
+	runsMap.set(runId, run);
+	runCounter.current += 1;
+	currentRun.current = { runId: runId, config: config };
+	return run;
+}
+
+export function endRun() {
+	currentRun.current = null;
+}
+
+export function restoreRun(run: RunData): RunData {
+	runsMap.set(run.runId, run);
+	runCounter.current += 1;
+	currentRun.current = { runId: run.runId, config: run.config };
+	return run;
+}
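+
+// Lifecycle sketch (illustrative):
+//
+//   const run = newRun(configSnapshot); // becomes the current run
+//   log(run.runId, { 'train/loss': 2.31 }); // see log() below
+//   endRun(); // detach without deleting data
+//
+// restoreRun() re-registers a previously persisted RunData (e.g. from the
+// last-session store) and makes it current again.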
+
+/**
+ * Logs metric data for a specific step in a run.
+ */
+export function log(
+	runId: string,
+	data: { [metricName: string]: number | Omit<StepData, 'step'> },
+	{ step }: { step?: number } = {}
+): void {
+	const run = runsMap.get(runId);
+	// The run must already exist; log() never creates one implicitly
+	if (!run) {
+		throw new Error(`Run with id ${runId} does not exist`);
+	}
+
+	const currentStep = step !== undefined ? step : run.step;
+
+	// Update metrics for the specified step
+	for (const [metricName, value] of Object.entries(data)) {
+		let metric = run.metrics.get(metricName);
+		if (!metric) {
+			metric = {
+				metricName,
+				data: []
+			};
+			run.metrics.set(metricName, metric);
+		}
+
+		let stepData: StepData;
+		if (typeof value === 'number') {
+			stepData = { step: currentStep, y: value };
+		} else {
+			stepData = { step: currentStep, ...value } as ValidationStep;
+		}
+
+		const updatedMetric = {
+			...metric,
+			data: [...metric.data, stepData].sort((a: StepData, b: StepData) => a.step - b.step)
+		};
+
+		run.metrics.set(metricName, updatedMetric);
+	}
+
+	// Update step counter
+	if (step === undefined) {
+		run.step += 1;
+	} else if (step > run.step) {
+		run.step = step;
+	}
+
+	run.lastUpdated = Date.now();
+	runsMap.set(runId, run);
+}
+
+export type LogFn = typeof log;
+
+export function resetWorkspace(): void {
+	runsMap.clear();
+	runCounter.current = 0;
+	currentRun.current = null;
+	console.debug('[workspaceState] Reset.');
+}
+
+// Group metrics by prefix (everything before the first '/')
+export function getMetricGroups(metricNames: ReadonlyArray<string>): Record<string, string[]> {
+	const groups: Record<string, string[]> = {};
+
+	metricNames.forEach((metricName) => {
+		// For now we're just special-casing this, but we might want to bring it into the metrics view
+		// anyway
+		if (metricName === 'visualization/matches') {
+			return;
+		}
+
+		const parts = metricName.split('/');
+		const groupName = parts.length > 1 ? parts[0] : 'default';
+
+		if (!groups[groupName]) {
+			groups[groupName] = [];
+		}
+		groups[groupName].push(metricName);
+	});
+
+	return groups;
+}
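+
+// Example grouping (hypothetical metric names):
+//
+//   getMetricGroups(['train/loss', 'train/lr', 'validation/loss', 'tokens'])
+//   // -> { train: ['train/loss', 'train/lr'],
+//   //      validation: ['validation/loss'],
+//   //      default: ['tokens'] }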
diff --git a/examples/finetuning/src/lib/workspace/ui.svelte.ts b/examples/finetuning/src/lib/workspace/ui.svelte.ts
new file mode 100644
index 00000000..f8c456a7
--- /dev/null
+++ b/examples/finetuning/src/lib/workspace/ui.svelte.ts
@@ -0,0 +1,317 @@
+import { config } from './config.svelte';
+import { LocalStorage } from './localStorage.svelte';
+import { newRun, restoreRun, type RunData } from './runs.svelte';
+import {
+	trainingState,
+	waitForNextCheckpoint,
+	workerPauseTraining,
+	workerReady,
+	workerRequestSave,
+	workerResumeTraining,
+	workerStartTraining,
+	workerStep,
+	workerStopTraining
+} from './workers.svelte';
+
+export const isMobile = $state({ current: false });
+export const activeTab: { current: 'about' | 'metrics' } = $state({
+	current: 'about'
+});
+export const isVisualizerEditorMinimized = $state({ current: true });
+export const hasWebGPU = $state({ current: false });
+export const browserInfo: {
+	current: {
+		type:
+			| 'chrome'
+			| 'edge'
+			| 'brave'
+			| 'arc'
+			| 'opera'
+			| 'vivaldi'
+			| 'safari'
+			| 'firefox'
+			| 'unknown';
+		platform: 'ios' | 'macos' | 'windows' | 'android' | 'linux' | 'other';
+	};
+} = $state({
+	current: { type: 'unknown', platform: 'other' }
+});
+
+// Local-only GPU preferences/state
+export const gpuPowerPreference = new LocalStorage<'high-performance' | 'low-power'>(
+	'gpuPowerPreference',
+	'high-performance'
+);
+const gpuName = $state<{ current: string | null }>({ current: null });
+export function setGpuName(name: string | null) {
+	gpuName.current = name;
+}
+export function getGpuName() {
+	return gpuName.current;
+}
+
+export const setupUI = () => {
+	// Browser/platform detection (best-effort; UA-CH not universally available yet)
+	const ua = navigator.userAgent.toLowerCase();
+	const vendor = navigator.vendor?.toLowerCase?.() ?? '';
+
+	// Platform
+	let platform: 'ios' | 'macos' | 'windows' | 'android' | 'linux' | 'other' = 'other';
+	if (/iphone|ipad|ipod/.test(ua)) platform = 'ios';
+	else if (/macintosh|mac os x/.test(ua)) platform = 'macos';
+	else if (/windows nt/.test(ua)) platform = 'windows';
+	else if (/android/.test(ua)) platform = 'android';
+	else if (/linux/.test(ua)) platform = 'linux';
+
+	// Chromium-family checks
+	// Distinguish some popular Chromium variants before generic Chrome
+	let type:
+		| 'chrome'
+		| 'edge'
+		| 'brave'
+		| 'arc'
+		| 'opera'
+		| 'vivaldi'
+		| 'safari'
+		| 'firefox'
+		| 'unknown' = 'unknown';
+	if (/edg\//.test(ua)) type = 'edge';
+	else if (/vivaldi/.test(ua)) type = 'vivaldi';
+	else if (/opr\//.test(ua)) type = 'opera';
+	else if (/brave/.test(ua)) type = 'brave';
+	else if (/arc\//.test(ua)) type = 'arc';
+	else if (/firefox/.test(ua)) type = 'firefox';
+	else if (/safari/.test(ua) && /apple/.test(vendor) && !/chrome|crios|android/.test(ua))
+		type = 'safari';
+	else if (/chrome|crios/.test(ua)) type = 'chrome';
+
+	browserInfo.current = { type, platform };
+
+	const mediaQuery = window.matchMedia('(min-width: 40rem)');
+	isMobile.current = !mediaQuery.matches;
+
+	// Set configOpen based on the media query if not already set by the user
+	if (configOpen.current === null) {
+		configOpen.current = mediaQuery.matches;
+	}
+
+	// Listen for changes in screen size
+	const handleMediaChange = (e: MediaQueryListEvent) => {
+		isMobile.current = !e.matches;
+
+		// If switching to mobile and the config panel is open, close it and reset the tab
+		if (isMobile.current && configOpen.current) {
+			configOpen.current = false;
+			activeTab.current = 'about';
+		}
+	};
+
+	mediaQuery.addEventListener('change', handleMediaChange);
+
+	return () => {
+		mediaQuery.removeEventListener('change', handleMediaChange);
+	};
+};
+
+// Function to handle tab selection with mobile behavior
+export function selectTab(tabName: 'about' | 'metrics') {
+	activeTab.current = tabName;
+	if (isMobile.current && configOpen.current) {
+		configOpen.current = false;
+	}
+}
+
+let flashVramLimit = $state(false);
+
+export function triggerVramLimitFlash() {
+	controlSectionsOpen.current.training = true;
+
+	// Scroll to the GPU memory limit after a brief delay to allow the section to open
+	setTimeout(() => {
+		const trainingVramLimitElement = document.getElementById('training-vram-limit');
+		flashVramLimit = true;
+		if (trainingVramLimitElement) {
+			trainingVramLimitElement.scrollIntoView({
+				behavior: 'instant',
+				block: 'center'
+			});
+			trainingVramLimitElement.classList.add('error-flash');
+			setTimeout(() => {
+				trainingVramLimitElement.classList.remove('error-flash');
+				flashVramLimit = false;
+			}, 1000);
+		}
+	}, 100);
+}
+
+export function getFlashVramLimit() {
+	return flashVramLimit;
+}
+
+let showLowDiversityDatasetError = $state(false);
+
+export function triggerLowDiversityDatasetError() {
+	controlSectionsOpen.current.task = true;
+	showLowDiversityDatasetError = true;
+
+	// Scroll to the low-diversity dataset error after a brief delay to allow the section to open
+	setTimeout(() => {
+		const lowDiversityDatasetErrorElement = document.getElementById('low-diversity-dataset-error');
+		if (lowDiversityDatasetErrorElement) {
+			lowDiversityDatasetErrorElement.scrollIntoView({
+				behavior: 'instant',
+				block: 'center'
+			});
+		}
+	}, 100);
+}
+
+export function getShowLowDiversityDatasetError() {
+	return showLowDiversityDatasetError;
+}
+
+export function resetLowDiversityDatasetError() {
+	showLowDiversityDatasetError = false;
+}
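+
+// Mount sketch (illustrative; typically run once from the root component):
+//
+//   $effect(() => setupUI()); // the returned cleanup detaches the media-query listener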
+
+const iconStrokeWidth = $derived(isMobile.current ? 2 : 2.5);
+
+export function getIconStrokeWidth() {
+	return iconStrokeWidth;
+}
+
+// Initialize sectionsOpen from localStorage or use defaults
+export const controlSectionsOpen = new LocalStorage('controlSectionsOpen', {
+	gpu: true,
+	runs: true,
+	training: true,
+	task: true,
+	model: true,
+	optimizer: true,
+	advanced: false
+});
+
+export function toggleControlSection(sectionName: keyof typeof controlSectionsOpen.current) {
+	controlSectionsOpen.current[sectionName] = !controlSectionsOpen.current[sectionName];
+}
+
+export const metricsSectionsOpen = new LocalStorage<Record<string, boolean>>(
+	'metricsSectionsOpen',
+	{}
+);
+
+export function toggleMetricsSection(sectionName: string) {
+	metricsSectionsOpen.current[sectionName] = !(metricsSectionsOpen.current[sectionName] ?? true);
+}
+
+export const maxCompletions = new LocalStorage('maxCompletions', 4);
+
+export function setMaxCompletions(value: number) {
+	maxCompletions.current = value;
+}
+
+// Visibility state for per-metric charts (user overrides only)
+export const metricVisibility = new LocalStorage<Record<string, boolean>>('metricVisibility', {});
+
+// Initialize configOpen from localStorage with no default (null means use the media query)
+export const configOpen = new LocalStorage<boolean | null>('configOpen', null);
+
+export const tourState = new LocalStorage<{
+	startedExperiment: boolean;
+	restartedExperiment: boolean;
+	seenCQLTutorial: boolean;
+}>('tourState', {
+	startedExperiment: false,
+	restartedExperiment: false,
+	seenCQLTutorial: false
+});
+
+export function switchToMetrics() {
+	selectTab('metrics');
+}
+
+export function toggleConfig() {
+	configOpen.current = !configOpen.current;
+}
+
+export async function saveModel() {
+	// Set up the waiter BEFORE causing a save so the auto-save on pause satisfies it
+	const waiter = waitForNextCheckpoint();
+
+	if (trainingState.current === 'training') {
+		workerPauseTraining();
+		// the paused handler will request a save
+	} else if (trainingState.current === 'paused') {
+		workerRequestSave();
+	} else {
+		return;
+	}
+
+	const { runId, buffer } = await waiter;
+	const blob = new Blob([buffer.buffer as ArrayBuffer], {
+		type: 'application/octet-stream'
+	});
+	const url = URL.createObjectURL(blob);
+	const a = document.createElement('a');
+	a.href = url;
+	a.download = `${runId}.safetensors`;
+	document.body.appendChild(a);
+	a.click();
+	document.body.removeChild(a);
+}
+
+// Function to start training
+export function startTraining(
+	options: { run?: RunData; resumeFrom?: Uint8Array } | undefined = undefined
+) {
+	const { run, resumeFrom } = options ?? {};
+
+	if (trainingState.current !== 'stopped' || !workerReady.current) return;
+
+	trainingState.current = 'training';
+	const effectiveRun = run ? restoreRun(run) : newRun(JSON.parse(JSON.stringify(config)));
+
+	if (isMobile.current && configOpen.current) {
+		configOpen.current = false;
+	}
+
+	// We don't want to wrench them away from the visualize tab, but if they're
+	// running an experiment, we want it to look like something is happening.
+	if (!tourState.current.startedExperiment || activeTab.current === 'about') {
+		switchToMetrics();
+	}
+
+	if (!tourState.current.startedExperiment) {
+		tourState.current.startedExperiment = true;
+	}
+
+	if (getShowLowDiversityDatasetError()) {
+		resetLowDiversityDatasetError();
+	}
+
+	workerStartTraining(effectiveRun.runId, resumeFrom);
+}
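+
+// Illustrative calls: a fresh run snapshots the current config, while passing
+// `run` plus `resumeFrom` (e.g. a checkpoint restored from the last-session
+// store) continues a previous run with its saved weights. `savedRun` and
+// `savedCheckpoint` are hypothetical:
+//
+//   startTraining();
+//   startTraining({ run: savedRun, resumeFrom: savedCheckpoint });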
+
+// Function to stop training
+export async function stopTraining() {
+	await workerStopTraining();
+}
+
+export function togglePause() {
+	if (trainingState.current === 'stopped') return;
+	if (trainingState.current === 'training') {
+		workerPauseTraining();
+	} else {
+		workerResumeTraining();
+	}
+}
+
+export function stepForward() {
+	if (trainingState.current === 'stopped') return;
+	workerStep();
+}
+
+export async function restartTraining() {
+	await stopTraining();
+	startTraining();
+	if (!tourState.current.restartedExperiment) {
+		tourState.current.restartedExperiment = true;
+	}
+}
diff --git a/examples/finetuning/src/lib/workspace/utils.ts b/examples/finetuning/src/lib/workspace/utils.ts
new file mode 100644
index 00000000..9ad244c5
--- /dev/null
+++ b/examples/finetuning/src/lib/workspace/utils.ts
@@ -0,0 +1,39 @@
+import { adjectives, animals, uniqueNamesGenerator } from 'unique-names-generator';
+
+export function generateMemorableName(number: number) {
+	// by default uniqueNamesGenerator picks one word per dictionary
+	const twoWord = uniqueNamesGenerator({
+		dictionaries: [adjectives, animals],
+		separator: '-',
+		style: 'lowerCase',
+		length: 2
+	});
+	return `${twoWord}-${number}`;
+}
+
+/**
+ * Returns a new array sorted so that any items whose key is found in `priorityKeys`
+ * appear first in the specified order, followed by the remaining items sorted by
+ * their key alphabetically.
+ *
+ * Example:
+ *   sortWithPriority(['b', 'a', 'c'], (x) => x, ['c']) -> ['c', 'a', 'b']
+ */
+export function sortWithPriority<T>(
+	items: ReadonlyArray<T>,
+	getKey: (item: T) => string,
+	priorityKeys: ReadonlyArray<string>
+): T[] {
+	const priorityIndex = new Map<string, number>();
+	for (let i = 0; i < priorityKeys.length; i++) {
+		priorityIndex.set(priorityKeys[i], i);
+	}
+	return [...items].sort((a, b) => {
+		const ka = getKey(a);
+		const kb = getKey(b);
+		const ia = priorityIndex.has(ka) ? (priorityIndex.get(ka) as number) : Number.POSITIVE_INFINITY;
+		const ib = priorityIndex.has(kb) ? (priorityIndex.get(kb) as number) : Number.POSITIVE_INFINITY;
+		if (ia !== ib) return ia - ib;
+		return ka.localeCompare(kb);
+	});
+}
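+
+// Example output (the exact words depend on the library's dictionaries):
+// generateMemorableName(3) -> something like 'brave-lion-3'.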
diff --git a/examples/finetuning/src/lib/workspace/workers.svelte.ts b/examples/finetuning/src/lib/workspace/workers.svelte.ts
new file mode 100644
index 00000000..26034929
--- /dev/null
+++ b/examples/finetuning/src/lib/workspace/workers.svelte.ts
@@ -0,0 +1,507 @@
+import type { Config } from '$lib/workspace/config';
+import type { IndexState } from '@piston-ml/piston-web';
+
+import { SvelteMap } from 'svelte/reactivity';
+
+import { config } from './config.svelte';
+import { lastSessionStore } from './lastSessionStore';
+import { currentRun, log, runsMap } from './runs.svelte';
+import {
+	gpuPowerPreference,
+	triggerLowDiversityDatasetError,
+	triggerVramLimitFlash
+} from './ui.svelte';
+
+// Train state
+let trainWorker: Worker | null = $state(null);
+export const workerReady = $state({ current: false });
+export const workerVersion = $state({ current: 0 });
+export const trainingState = $state<{ current: 'training' | 'paused' | 'stopped' }>({
+	current: 'stopped'
+});
+
+// UA memory measurement state (main thread only)
+let uaMemoryInterval: ReturnType<typeof setInterval> | null = null;
+let lastUAMemoryBytes: number | null = null;
+
+let screenWakeLock: WakeLockSentinel | null = null;
+
+type CheckpointPayload = { runId: string; buffer: Uint8Array };
+const pendingCheckpointWaiters: Array<(p: CheckpointPayload) => void> = [];
+const pendingPeekResolvers = new SvelteMap<string, (config: Config) => void>();
+
+async function acquireScreenWakeLock() {
+	// Only attempt in browser/secure contexts that support it
+	if (typeof navigator === 'undefined' || !('wakeLock' in navigator)) return;
+	try {
+		// Request a screen wake lock
+		screenWakeLock = await navigator.wakeLock.request('screen');
+		// Ensure our local reference is cleared if the system revokes the lock
+		screenWakeLock?.addEventListener?.('release', () => {
+			screenWakeLock = null;
+		});
+	} catch (err) {
+		console.warn('Screen Wake Lock request failed:', err);
+	}
+}
+
+async function releaseScreenWakeLock() {
+	if (!screenWakeLock) return;
+	try {
+		await screenWakeLock.release();
+	} catch (err) {
+		console.warn('Screen Wake Lock release failed:', err);
+	} finally {
+		screenWakeLock = null;
+	}
+}
+
+export async function initializeWorker() {
+	return new Promise<void>((resolve, reject) => {
+		try {
+			// Create the dedicated module worker
+			// eslint-disable-next-line svelte/prefer-svelte-reactivity
+			trainWorker = new Worker(new URL('$lib/train/moduleWorker.ts', import.meta.url), {
+				type: 'module',
+				name: 'moduleWorker'
+			});
+
+			console.log('[Main] Module worker created successfully.');
+
+			// Set up UA memory measurement (immediate + interval) on main thread (only once)
+			if (!uaMemoryInterval) {
+				const measure = (
+					performance as Performance & {
+						measureUserAgentSpecificMemory?: () => Promise<{ bytes: number }>;
+					}
+				).measureUserAgentSpecificMemory;
+				if (typeof measure === 'function') {
+					const measureAndStore = async () => {
+						try {
+							const { bytes } = await (
+								performance as Performance & {
+									measureUserAgentSpecificMemory?: () => Promise<{ bytes: number }>;
+								}
+							).measureUserAgentSpecificMemory!();
+							if (typeof bytes === 'number' && Number.isFinite(bytes)) {
+								lastUAMemoryBytes = bytes;
+							}
+						} catch (err) {
+							console.warn('Error measuring UA memory:', err);
+							// Ignore measurement errors
+						}
+					};
+					// Immediate measurement so the first log can include it
+					void measureAndStore();
+					uaMemoryInterval = setInterval(() => {
+						void measureAndStore();
+					}, 10_000);
+				} else {
+					console.debug(
+						'performance.measureUserAgentSpecificMemory is not available; skipping UA memory interval'
+					);
+				}
+			}
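+
+			// Note: measureUserAgentSpecificMemory() is only exposed in
+			// cross-origin-isolated contexts; the COOP/COEP headers shipped in
+			// static/_headers (and mirrored in vite.config.ts for dev) are what
+			// make it available on supporting browsers.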
+
+			trainWorker.onmessage = (event) => {
+				const { type, ...data } = event.data;
+
+				switch (type) {
+					case 'ready':
+						console.log('[Main] Worker is ready');
+						resolve();
+						workerReady.current = true;
+						workerVersion.current += 1;
+						break;
+					case 'checkpoint.config': {
+						const { requestId, config: cfg } = data as {
+							requestId: string;
+							config: Config;
+						};
+						const resolver = pendingPeekResolvers.get(requestId);
+						if (resolver) {
+							pendingPeekResolvers.delete(requestId);
+							resolver(cfg);
+						}
+						break;
+					}
+					case 'metrics': {
+						// Handle training metric logs
+						if (!data.runId || !data.data) {
+							console.error('[Main] Invalid metrics data:', data);
+							return;
+						}
+						const step = data.metadata?.step as number | undefined;
+						const combinedMetrics: Record<string, number | Record<string, unknown>> = {};
+						for (const [metricName, value] of Object.entries(data.data)) {
+							combinedMetrics[metricName] = value as number | Record<string, unknown>;
+						}
+						if (lastUAMemoryBytes !== null) {
+							combinedMetrics['allocation/cpu_memory_mb'] = lastUAMemoryBytes / (1024 * 1024);
+							lastUAMemoryBytes = null;
+						}
+						log(data.runId, combinedMetrics, { step });
+						break;
+					}
+					case 'complete':
+						console.log(`[Main] Training completed for run ${data.runId}`);
+						trainingState.current = 'stopped';
+						currentRun.current = null;
+						void releaseScreenWakeLock();
+						break;
+
+					case 'restart': {
+						console.log(`[Main] Worker requested restart for run ${data.runId}`);
+						const buffer = data.buffer as Uint8Array;
+						const runId = data.runId as string;
+						// Persist last session snapshot with checkpoint
+						const run = runsMap.get(runId);
+						if (run) {
+							void lastSessionStore.set(run, buffer);
+						}
+						// Terminate and recreate worker
+						trainWorker?.terminate();
+						workerReady.current = false;
+						// Ensure training state reflects continuity across restart
+						trainingState.current = 'training';
+						initializeWorker().then(() => {
+							// Send start with resumeFrom to resume the same run id
+							trainWorker!.postMessage({
+								type: 'start',
+								data: {
+									runId,
+									config: $state.snapshot(config),
+									resumeFrom: buffer,
+									gpuPowerPreference: gpuPowerPreference.current
+								}
+							});
+						});
+						break;
+					}
+					case 'checkpoint': {
+						const uint8array = data.buffer as Uint8Array | undefined;
+						const runId = data.runId as string | undefined;
+
+						// Persist last session snapshot with checkpoint (always)
+						if (uint8array && runId) {
+							const run = runsMap.get(runId);
+							if (run) {
+								void lastSessionStore.set(run, uint8array);
+							}
+
+							// Fulfill all waiters if present
+							for (const waiter of pendingCheckpointWaiters) {
+								void waiter({ runId, buffer: uint8array });
+							}
+							pendingCheckpointWaiters.length = 0;
+						}
+
+						break;
+					}
+					case 'error':
+						if (data.name === 'VRAMLimitExceededError') {
+							console.error(`[Main] VRAM limit exceeded for run ${data.runId}:`, data.message);
+							triggerVramLimitFlash();
+						} else if (data.name === 'LowDiversityDatasetError') {
+							console.error(
+								`[Main] Low diversity dataset error for run ${data.runId}:`,
+								data.message
+							);
+							triggerLowDiversityDatasetError();
+						} else {
+							console.error(`[Main] Training error for run ${data.runId}:`, data.message);
+						}
+						workerStopTraining();
+						break;
+					case 'paused':
+						console.log('[Main] Training paused');
+						trainingState.current = 'paused';
+						// Request a save to persist a checkpoint; the session will be stored alongside when the checkpoint arrives
+						if (currentRun.current?.runId) {
+							workerRequestSave();
+						}
+						break;
+					case 'resumed':
+						console.log('[Main] Training resumed');
+						trainingState.current = 'training';
+						break;
+				}
+			};
+
+			trainWorker.onerror = (event) => {
+				console.error('[Main] Worker onerror:', event);
+				reject(new Error(event.error));
+			};
+		} catch (error) {
+			console.error('[Main] Failed to create worker:', error);
+			reject(error);
+		}
+	});
+}
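+
+// Message protocol summary (as handled above): the worker sends 'ready',
+// 'checkpoint.config', 'metrics', 'complete', 'restart', 'checkpoint',
+// 'error', 'paused', and 'resumed'; the main thread posts 'start', 'save',
+// 'pause', 'resume', 'step', 'checkpoint.peekConfig', 'visualizer.canvas',
+// and 'visualizer.resize' (plus 'inspectModel' to the inspection worker).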
+
+export function workerStartTraining(runId: string, resumeFrom?: Uint8Array) {
+	if (!trainWorker) {
+		throw new Error('Worker not initialized');
+	}
+
+	trainWorker.postMessage({
+		type: 'start',
+		data: {
+			runId: runId,
+			config: $state.snapshot(config),
+			resumeFrom,
+			gpuPowerPreference: gpuPowerPreference.current
+		}
+	});
+
+	trainingState.current = 'training';
+	void acquireScreenWakeLock();
+}
+
+export function workerRequestSave() {
+	if (!trainWorker) {
+		throw new Error('Worker not initialized');
+	}
+	trainWorker.postMessage({ type: 'save' });
+}
+
+export function waitForNextCheckpoint(): Promise<CheckpointPayload> {
+	return new Promise((resolve) => {
+		pendingCheckpointWaiters.push(resolve);
+	});
+}
+
+export function peekCheckpointConfig(buffer: Uint8Array): Promise<Config> {
+	if (!trainWorker) {
+		return Promise.reject(new Error('Worker not initialized'));
+	}
+	const requestId = crypto.randomUUID();
+	return new Promise((resolve) => {
+		pendingPeekResolvers.set(requestId, resolve);
+		trainWorker!.postMessage({ type: 'checkpoint.peekConfig', data: { requestId, buffer } });
+	});
+}
+
+export async function workerStopTraining() {
+	if (!trainWorker || trainingState.current === 'stopped') return;
+	void releaseScreenWakeLock();
+
+	// For now, we'll just terminate and recreate the worker
+	// In a more sophisticated implementation, we'd send a stop message
+	trainWorker.terminate();
+	workerReady.current = false;
+	trainingState.current = 'stopped';
+	currentRun.current = null;
+
+	// Recreate worker
+	await initializeWorker();
+}
+
+export function workerPauseTraining() {
+	if (!trainWorker || trainingState.current !== 'training') return;
+	trainWorker.postMessage({ type: 'pause' });
+}
+
+export function workerResumeTraining() {
+	if (!trainWorker || trainingState.current !== 'paused') return;
+	trainWorker.postMessage({ type: 'resume' });
+}
+
+export function workerStep() {
+	if (!trainWorker || trainingState.current === 'stopped') return;
+	trainWorker.postMessage({ type: 'step' });
+}
+
+//
+// Model inspection state
+//
+
+let parameterCount = $state<number | null>(null);
+let hiddenSize = $state<number | null>(null);
+let mlpIntermediateSize = $state<number | null>(null);
+let modelIndex = $state<IndexState | null>(null);
+let modelInspectionRequestId = $state<string | null>(null);
+let isInspectingModel = $state(false);
+let modelInspectionWorker: Worker | null = $state(null);
+
+export function getParameterCount() {
+	return parameterCount;
+}
+
+export function getHiddenSize() {
+	return hiddenSize;
+}
+
+export function getMlpIntermediateSize() {
+	return mlpIntermediateSize;
+}
+
+export function getModelIndex() {
+	return modelIndex;
+}
+
+export function getIsInspectingModel() {
+	return isInspectingModel;
+}
+
+export function setModelInspectionWorker(workerInstance: Worker | null) {
+	modelInspectionWorker = workerInstance;
+	// Trigger initial model inspection when the worker is set
+	if (modelInspectionWorker && !isInspectingModel) {
+		setTimeout(() => requestModelInspection(), 0);
+	}
+}
+
+// Export a function to manually trigger model inspection
+export function triggerModelInspection() {
+	if (modelInspectionWorker && !isInspectingModel) {
+		setTimeout(() => requestModelInspection(), 0);
+	}
+}
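+
+// Request/response correlation: each inspection request carries a fresh
+// requestId, and a response is applied only if it echoes the in-flight id,
+// so stale or duplicate replies are dropped.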
+function requestModelInspection() {
+	if (!modelInspectionWorker || isInspectingModel) return;
+
+	isInspectingModel = true;
+	modelInspectionRequestId = crypto.randomUUID();
+
+	try {
+		modelInspectionWorker.postMessage({
+			type: 'inspectModel',
+			data: {
+				config: $state.snapshot(config),
+				requestId: modelInspectionRequestId,
+				gpuPowerPreference: gpuPowerPreference.current
+			}
+		});
+	} catch (error) {
+		console.error('Failed to request model inspection:', error);
+		isInspectingModel = false;
+		modelInspectionRequestId = null;
+	}
+}
+
+export function handleModelInspectionResponse(data: {
+	requestId: string;
+	parameterCount: number;
+	hiddenSize: number;
+	mlpIntermediateSize: number;
+	vocabSize: number;
+	modelIndex: IndexState;
+}) {
+	if (data.requestId === modelInspectionRequestId) {
+		parameterCount = data.parameterCount;
+		hiddenSize = data.hiddenSize;
+		mlpIntermediateSize = data.mlpIntermediateSize;
+		modelIndex = data.modelIndex;
+		isInspectingModel = false;
+		modelInspectionRequestId = null;
+	}
+}
+
+export function handleModelInspectionError(data: { requestId: string; message: string }) {
+	if (data.requestId === modelInspectionRequestId) {
+		console.error('Model inspection error:', data.message);
+		isInspectingModel = false;
+		modelInspectionRequestId = null;
+	}
+}
+
+export async function initializeModelInspectionWorker() {
+	return new Promise<void>((resolve, reject) => {
+		try {
+			// Create the dedicated model inspection worker
+			// eslint-disable-next-line svelte/prefer-svelte-reactivity
+			modelInspectionWorker = new Worker(new URL('$lib/train/moduleWorker.ts', import.meta.url), {
+				type: 'module',
+				name: 'modelInspectionWorker'
+			});
+
+			console.log('[Main] Model inspection worker created successfully.');
+
+			modelInspectionWorker.onmessage = (event) => {
+				const { type, ...data } = event.data;
+
+				switch (type) {
+					case 'ready':
+						console.log('[Main] Model inspection worker is ready');
+						resolve();
+						// Set the worker reference for model inspection
+						setModelInspectionWorker(modelInspectionWorker);
+						break;
+
+					case 'modelInspection':
+						handleModelInspectionResponse(data);
+						break;
+
+					case 'modelInspectionError':
+						handleModelInspectionError(data);
+						break;
+
+					case 'error':
+						console.error('[Main] Model inspection worker error:', data.message);
+						break;
+				}
+			};
+
+			modelInspectionWorker.onerror = (event) => {
+				console.error('[Main] Model inspection worker error:', event.message, event);
+				reject(new Error(event.error));
+			};
+		} catch (error) {
+			console.error('[Main] Failed to create model inspection worker:', error);
+			reject(error);
+		}
+	});
+}
+
+export async function initializeWorkers() {
+	return Promise.all([initializeWorker(), initializeModelInspectionWorker()]);
+}
+
+export function cleanupWorkers() {
+	if (trainWorker) {
+		trainWorker.terminate();
+		trainWorker = null;
+	}
+	if (modelInspectionWorker) {
+		modelInspectionWorker.terminate();
+		modelInspectionWorker = null;
+	}
+	// Clear UA memory interval
+	if (uaMemoryInterval) {
+		clearInterval(uaMemoryInterval);
+		uaMemoryInterval = null;
+	}
+	lastUAMemoryBytes = null;
+
+	void releaseScreenWakeLock();
+}
+
+//
+// Visualizer APIs
+//
+// eslint-disable-next-line svelte/prefer-svelte-reactivity
+const canvasesWithAttemptedInitialization = new Set<HTMLCanvasElement>();
+export function initializeVisualizerCanvas(
+	canvas: HTMLCanvasElement,
+	labelPaddingCssPx: number = 0
+) {
+	if (!trainWorker) throw new Error('Worker not initialized');
+	if (canvasesWithAttemptedInitialization.has(canvas)) return;
+	const offscreen =
canvas.transferControlToOffscreen(); + trainWorker.postMessage( + { type: 'visualizer.canvas', data: { canvas: offscreen, labelPaddingCssPx } }, + [offscreen] + ); + canvasesWithAttemptedInitialization.add(canvas); +} + +export function resizeVisualizer(width: number) { + if (!trainWorker) return; + trainWorker.postMessage({ type: 'visualizer.resize', data: { width } }); +} + +export function getWorkerVersion() { + return workerVersion; +} diff --git a/examples/finetuning/src/routes/+layout.svelte b/examples/finetuning/src/routes/+layout.svelte new file mode 100644 index 00000000..47d9355e --- /dev/null +++ b/examples/finetuning/src/routes/+layout.svelte @@ -0,0 +1,95 @@ + + + + sequence toy + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +{@render children()} diff --git a/examples/finetuning/src/routes/+layout.ts b/examples/finetuning/src/routes/+layout.ts new file mode 100644 index 00000000..189f71e2 --- /dev/null +++ b/examples/finetuning/src/routes/+layout.ts @@ -0,0 +1 @@ +export const prerender = true; diff --git a/examples/finetuning/src/routes/+page.svelte b/examples/finetuning/src/routes/+page.svelte new file mode 100644 index 00000000..a725643e --- /dev/null +++ b/examples/finetuning/src/routes/+page.svelte @@ -0,0 +1,517 @@ + + +{#snippet tabButton( + title: string, + icon: typeof ChartLine, + hideBorder: boolean, + isActive: boolean, + disabled: boolean, + highlighted: boolean, + hasActivity: boolean, + onClick: () => void +)} + {@const Icon = icon} + +{/snippet} + +
+
+ Sequence Toy + +
+
+ + + + {#if configOpen.current !== false} +
+
+ {#if trainingState.current !== 'stopped'} +
+ {currentRun.current?.runId} + +
+ {/if} +
+
+ {#if trainingState.current === 'stopped'} + {#if canResume} + + + + Resume from last session + + + {/if} + startTraining()} + highlighted={workerReady.current && + trainingState.current === 'stopped' && + !tourState.current.startedExperiment} + class="w-full h-7.5" + > + + + Start Training + + + {:else} +
+ + + {#if trainingState.current === 'paused'} + + {:else} + + {/if} + + + + + + + + + + + + + + + + + + {#if shouldSuggestRestart} + New Changes + {/if} + + +
+ {/if} +
+
+
+ +
+
+
+ {/if} + + + {#if shouldShowTabContent} +
+ {#if activeTab.current === 'metrics'} + + {:else if activeTab.current === 'about'} + + {/if} +
+ {/if} +
+
+ +{#if isDragOver} +
+
+ Drop checkpoint to start from it +
+
+{/if} + +{#if showReplaceDialog} +
{ + if (e.key === 'Enter' || e.key === ' ') cancelReplace(); + }} + > + +
+{/if} + + diff --git a/examples/finetuning/src/routes/tabs/About.svelte b/examples/finetuning/src/routes/tabs/About.svelte new file mode 100644 index 00000000..85174799 --- /dev/null +++ b/examples/finetuning/src/routes/tabs/About.svelte @@ -0,0 +1,118 @@ + + +
+
+
+ + {#if !hasWebGPU.current} + {@const isBrowserUnknown = browserInfo.current.type === 'unknown'} + +
+

+ Sequence Toy requires WebGPU support, and your browser doesn't support it yet. + {#if isBrowserUnknown} + You have a few options: + {/if} +

+
+
    + {#if isBrowserUnknown || (browserInfo.current.type !== 'firefox' && browserInfo.current.type !== 'safari')} +
  • + Use the latest version of Chrome, or a + Chromium-based browser (Edge, Arc, Brave, etc.) +
      +
    • + If that doesn't work, try Chrome Canary and ensure WebGPU is enabled. +
    • +
    +
  • + {/if} + {#if isBrowserUnknown || browserInfo.current.type === 'firefox'} +
  • + Use Firefox Nightly + +
  • + {/if} + {#if isBrowserUnknown || browserInfo.current.type === 'safari'} +
  • + Use the latest Safari Technology Preview or enable WebGPU via Feature + Flags. +
      + {#if isBrowserUnknown || browserInfo.current.platform === 'ios'} +
    • + iOS: System Settings > Apps > Safari > Advanced > Feature Flags > Enable + "WebGPU" +
    • + {/if} + {#if isBrowserUnknown || browserInfo.current.platform === 'macos'} +
    • + macOS: Develop menu > Feature Flags > Enable "WebGPU" +
    • + {/if} +
    +
  • + {/if} +
+
+ {/if} +
+

+ Finetune Toy +

+
+ +

Train a language model in your browser with WebGPU

+
+
+ +
+
diff --git a/examples/finetuning/src/routes/tabs/Metrics.svelte b/examples/finetuning/src/routes/tabs/Metrics.svelte new file mode 100644 index 00000000..f8903fab --- /dev/null +++ b/examples/finetuning/src/routes/tabs/Metrics.svelte @@ -0,0 +1,138 @@ + + +
+ {#if !tourState.current.restartedExperiment} + + Tinker with the experiment setup, then click New Changes + to try out your changes. You can probably break it! + Report issues on Github . + + {/if} + {#if Object.keys(metricGroups).length === 0} +
+

No metrics available. Start training to see charts.

+
+ {:else} +
+ {#each Object.entries(metricGroups).sort(([a], [b]) => { + const order = ['validation', 'train', 'optimizer']; + const aPriority = order.indexOf(a); + const bPriority = order.indexOf(b); + return (aPriority === -1 ? 999 : aPriority) - (bPriority === -1 ? 999 : bPriority) || a.localeCompare(b); + }) as [groupName, metrics] (groupName)} + {@const filteredMetrics = getFilteredMetrics(groupName, metrics)} + {@const hasMetrics = filteredMetrics.length > 0} + {@const sectionOpen = (metricsSectionsOpen.current[groupName] ?? true) && hasMetrics} + toggleMetricsSection(groupName) : undefined} + > + {#snippet chips()} + + {/snippet} + {#each filteredMetrics as metricName (metricName)} +
+ {#if metricName === 'validation/completions'} + + {:else} + + {/if} +
+ {/each} +
+ {/each} +
+ {/if} +
diff --git a/examples/finetuning/static/Berkeley Mono Variable.woff2 b/examples/finetuning/static/Berkeley Mono Variable.woff2 new file mode 100644 index 00000000..831f3b18 Binary files /dev/null and b/examples/finetuning/static/Berkeley Mono Variable.woff2 differ diff --git a/examples/finetuning/static/_headers b/examples/finetuning/static/_headers new file mode 100644 index 00000000..4b8e46b9 --- /dev/null +++ b/examples/finetuning/static/_headers @@ -0,0 +1,4 @@ +# This allows us to use performance.measureUserAgentSpecificMemory on browsers that support it (Chrome) +/* + Cross-Origin-Embedder-Policy: require-corp + Cross-Origin-Opener-Policy: same-origin \ No newline at end of file diff --git a/examples/finetuning/svelte.config.js b/examples/finetuning/svelte.config.js new file mode 100644 index 00000000..64694ad1 --- /dev/null +++ b/examples/finetuning/svelte.config.js @@ -0,0 +1,15 @@ +import adapter from '@sveltejs/adapter-static'; +import { vitePreprocess } from '@sveltejs/vite-plugin-svelte'; + +/** @type {import('@sveltejs/kit').Config} */ +const config = { + // Consult https://svelte.dev/docs/kit/integrations + // for more information about preprocessors + preprocess: vitePreprocess(), + + kit: { + adapter: adapter() + } +}; + +export default config; diff --git a/examples/finetuning/tsconfig.json b/examples/finetuning/tsconfig.json new file mode 100644 index 00000000..ad478e1b --- /dev/null +++ b/examples/finetuning/tsconfig.json @@ -0,0 +1,22 @@ +{ + "extends": "./.svelte-kit/tsconfig.json", + "compilerOptions": { + "allowJs": true, + "checkJs": true, + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "resolveJsonModule": true, + "skipLibCheck": true, + "sourceMap": true, + "strict": true, + "moduleResolution": "bundler", + "types": [ + "@webgpu/types" + ], + } + // Path aliases are handled by https://svelte.dev/docs/kit/configuration#alias + // except $lib which is handled by https://svelte.dev/docs/kit/configuration#files + // + // If you want to overwrite includes/excludes, make sure to copy over the relevant includes/excludes + // from the referenced tsconfig.json - TypeScript does not merge them in +} diff --git a/examples/finetuning/vite.config.ts b/examples/finetuning/vite.config.ts new file mode 100644 index 00000000..870a05f3 --- /dev/null +++ b/examples/finetuning/vite.config.ts @@ -0,0 +1,94 @@ +import { sveltekit } from '@sveltejs/kit/vite'; +import tailwindcss from '@tailwindcss/vite'; +import { execSync } from 'node:child_process'; +import fsSync from 'node:fs'; +import path from 'path'; +import sirv from 'sirv'; +import { fileURLToPath } from 'url'; +import { defineConfig, loadEnv, type ViteDevServer } from 'vite'; +import wasm from 'vite-plugin-wasm'; + +const projectRoot = path.resolve(path.dirname(fileURLToPath(import.meta.url)), '../..'); + +// Dev-only mount for tokenizer and tokenized directories via env paths +const devStaticMount = (opts: { tokenizerDir?: string; tokenizedDir?: string }) => ({ + name: 'dev-static-mount', + apply: 'serve', + configureServer(server: ViteDevServer) { + const tokenizerDir = opts.tokenizerDir; + const tokenizedDir = opts.tokenizedDir; + if (tokenizerDir) + server.middlewares.use('/tokenizer', sirv(tokenizerDir, { dev: true, etag: true })); + if (tokenizedDir) + server.middlewares.use('/tokenized', sirv(tokenizedDir, { dev: true, etag: true })); + } +}); + +const commitHash = execSync('git rev-parse --short HEAD').toString().trim(); + +export default defineConfig(({ mode }) => { + const envDir = 
path.dirname(fileURLToPath(import.meta.url)); + const env = loadEnv(mode, envDir, ''); + + return { + define: { + __COMMIT_HASH__: JSON.stringify(commitHash) + }, + plugins: [ + tailwindcss(), + ...(mode === 'development' + ? [ + devStaticMount({ + tokenizerDir: env.VITE_TOKENIZER_DIR, + tokenizedDir: env.VITE_TOKENIZED_DIR + }) + ] + : []), + sveltekit(), + wasm() + ], + worker: { + format: 'es', + plugins: () => [wasm(), sveltekit()] + }, + resolve: { + dedupe: [ + 'svelte', + 'svelte/legacy', + '@codemirror/state', + '@codemirror/view', + '@codemirror/language', + '@codemirror/lang-javascript', + '@codemirror/lint', + 'codemirror', + '@lezer/highlight' + ] + }, + esbuild: { + supported: { 'top-level-await': true }, + keepNames: true + }, + server: { + allowedHosts: ['photon-5.local', 'localhost', '127.0.0.1'], + https: { + key: fsSync.readFileSync('./localhost+5-key.pem'), + cert: fsSync.readFileSync('./localhost+5.pem') + }, + fs: { + // Allow serving files from the project root and one level up + allow: [ + // Allow serving from the Svelte project directory + path.resolve(path.dirname(fileURLToPath(import.meta.url))), + // Allow serving from the entire ratchet project directory + projectRoot, + // Allow serving from the WASM file's directory + path.resolve(projectRoot, 'target', 'pkg', 'piston-web') + ] + }, + headers: { + 'Cross-Origin-Embedder-Policy': 'require-corp', + 'Cross-Origin-Opener-Policy': 'same-origin' + } + } + }; +}); diff --git a/examples/finetuning/vitest-setup-client.ts b/examples/finetuning/vitest-setup-client.ts new file mode 100644 index 00000000..ea18d6a1 --- /dev/null +++ b/examples/finetuning/vitest-setup-client.ts @@ -0,0 +1,19 @@ +import { vi } from 'vitest'; + +import '@testing-library/jest-dom/vitest'; + +// required for svelte5 + jsdom as jsdom does not support matchMedia +Object.defineProperty(window, 'matchMedia', { + writable: true, + enumerable: true, + value: vi.fn().mockImplementation((query) => ({ + matches: false, + media: query, + onchange: null, + addEventListener: vi.fn(), + removeEventListener: vi.fn(), + dispatchEvent: vi.fn() + })) +}); + +// add more mocks here if you need them diff --git a/examples/piston-train-toy/src/lib/train/data/natural/dataset.ts b/examples/piston-train-toy/src/lib/train/data/natural/dataset.ts index 3eb450ac..76a0b04c 100644 --- a/examples/piston-train-toy/src/lib/train/data/natural/dataset.ts +++ b/examples/piston-train-toy/src/lib/train/data/natural/dataset.ts @@ -1,5 +1,5 @@ -import type { CitationEntries } from '$lib/components/controls/Citations.svelte'; import type { Config } from '$lib/workspace/config'; +import type { CitationEntries } from 'example-common'; import { PUBLIC_DATA_URL } from '$env/static/public'; import { PreTrainedTokenizer } from '$lib/train/tokenizer'; diff --git a/examples/piston-train-toy/src/lib/workspace/config.svelte.ts b/examples/piston-train-toy/src/lib/workspace/config.svelte.ts index 42352e11..3f57ddf5 100644 --- a/examples/piston-train-toy/src/lib/workspace/config.svelte.ts +++ b/examples/piston-train-toy/src/lib/workspace/config.svelte.ts @@ -67,7 +67,7 @@ const CONFIG_DEFAULTS: Config = { value: 1.0 }, useWeakTensorReferences: true, - sharedObjectAllocation: false, + sharedObjectAllocation: true, cachingEnabled: false, inplaceSupport: true, enableVisualization: true, diff --git a/examples/piston-train-toy/src/routes/tabs/About.svelte b/examples/piston-train-toy/src/routes/tabs/About.svelte index 7f146b45..68a97718 100644 --- 
a/examples/piston-train-toy/src/routes/tabs/About.svelte +++ b/examples/piston-train-toy/src/routes/tabs/About.svelte @@ -1,8 +1,6 @@