From 12ace6450063ad516ff06f3fa2e291f46669f73e Mon Sep 17 00:00:00 2001
From: Vin Howe <24789592+vinhowe@users.noreply.github.com>
Date: Fri, 28 Nov 2025 11:16:47 -0700
Subject: [PATCH 1/5] First pass at finetune demo
---
examples/finetuning/.editorconfig | 23 +
examples/finetuning/.gitignore | 28 +
examples/finetuning/.npmrc | 2 +
examples/finetuning/.prettierignore | 4 +
examples/finetuning/.prettierrc | 18 +
examples/finetuning/README.md | 1 +
examples/finetuning/eslint.config.js | 81 +
examples/finetuning/package.json | 63 +
examples/finetuning/src/app.css | 25 +
examples/finetuning/src/app.d.ts | 14 +
examples/finetuning/src/app.html | 11 +
.../src/lib/attachments/echarts.svelte.ts | 236 ++
.../lib/components/MetricToggleChips.svelte | 41 +
.../lib/components/controls/Controls.svelte | 776 ++++++
.../controls/DatasetControls.svelte | 38 +
.../components/controls/DatasetSample.svelte | 158 ++
.../controls/LRSchedulePicker.svelte | 517 ++++
.../lib/components/controls/RunsTable.svelte | 52 +
.../components/controls/SelectDataset.svelte | 91 +
.../select/SelectWithCitations.svelte | 66 +
.../components/metrics/MetricsSection.svelte | 65 +
.../lib/components/metrics/RunChart.svelte | 323 +++
.../CompletionsToken.svelte | 124 +
.../ValidationCompletionsViewer.svelte | 824 ++++++
examples/finetuning/src/lib/dataUtils.ts | 31 +
examples/finetuning/src/lib/train/generate.ts | 263 ++
.../finetuning/src/lib/train/model/cache.ts | 20 +
.../finetuning/src/lib/train/model/config.ts | 49 +
.../finetuning/src/lib/train/model/gpt.ts | 392 +++
.../finetuning/src/lib/train/model/utils.ts | 52 +
.../finetuning/src/lib/train/moduleWorker.ts | 295 ++
examples/finetuning/src/lib/train/protocol.ts | 81 +
examples/finetuning/src/lib/train/session.ts | 675 +++++
.../finetuning/src/lib/train/tokenizer.ts | 2457 +++++++++++++++++
examples/finetuning/src/lib/train/types.ts | 27 +
.../src/lib/train/utils/checkpoint.ts | 224 ++
.../finetuning/src/lib/train/utils/init.ts | 54 +
.../finetuning/src/lib/train/utils/model.ts | 134 +
.../finetuning/src/lib/train/utils/modes.ts | 96 +
.../finetuning/src/lib/train/utils/optim.ts | 381 +++
.../finetuning/src/lib/train/utils/random.ts | 39 +
.../finetuning/src/lib/train/validation.ts | 149 +
.../src/lib/train/validationHelpers.ts | 60 +
.../src/lib/workspace/checkpointStore.ts | 52 +
.../src/lib/workspace/config.svelte.ts | 466 ++++
.../finetuning/src/lib/workspace/config.ts | 283 ++
.../src/lib/workspace/lastSessionStore.ts | 115 +
.../src/lib/workspace/localStorage.svelte.ts | 98 +
.../src/lib/workspace/runs.svelte.ts | 443 +++
.../finetuning/src/lib/workspace/ui.svelte.ts | 317 +++
.../finetuning/src/lib/workspace/utils.ts | 39 +
.../src/lib/workspace/workers.svelte.ts | 507 ++++
examples/finetuning/src/routes/+layout.svelte | 95 +
examples/finetuning/src/routes/+layout.ts | 1 +
examples/finetuning/src/routes/+page.svelte | 517 ++++
.../finetuning/src/routes/tabs/About.svelte | 118 +
.../finetuning/src/routes/tabs/Metrics.svelte | 138 +
.../static/Berkeley Mono Variable.woff2 | Bin 0 -> 46428 bytes
examples/finetuning/static/_headers | 4 +
examples/finetuning/svelte.config.js | 15 +
examples/finetuning/tsconfig.json | 22 +
examples/finetuning/vite.config.ts | 94 +
examples/finetuning/vitest-setup-client.ts | 19 +
63 files changed, 12403 insertions(+)
create mode 100644 examples/finetuning/.editorconfig
create mode 100644 examples/finetuning/.gitignore
create mode 100644 examples/finetuning/.npmrc
create mode 100644 examples/finetuning/.prettierignore
create mode 100644 examples/finetuning/.prettierrc
create mode 100644 examples/finetuning/README.md
create mode 100644 examples/finetuning/eslint.config.js
create mode 100644 examples/finetuning/package.json
create mode 100644 examples/finetuning/src/app.css
create mode 100644 examples/finetuning/src/app.d.ts
create mode 100644 examples/finetuning/src/app.html
create mode 100644 examples/finetuning/src/lib/attachments/echarts.svelte.ts
create mode 100644 examples/finetuning/src/lib/components/MetricToggleChips.svelte
create mode 100644 examples/finetuning/src/lib/components/controls/Controls.svelte
create mode 100644 examples/finetuning/src/lib/components/controls/DatasetControls.svelte
create mode 100644 examples/finetuning/src/lib/components/controls/DatasetSample.svelte
create mode 100644 examples/finetuning/src/lib/components/controls/LRSchedulePicker.svelte
create mode 100644 examples/finetuning/src/lib/components/controls/RunsTable.svelte
create mode 100644 examples/finetuning/src/lib/components/controls/SelectDataset.svelte
create mode 100644 examples/finetuning/src/lib/components/controls/select/SelectWithCitations.svelte
create mode 100644 examples/finetuning/src/lib/components/metrics/MetricsSection.svelte
create mode 100644 examples/finetuning/src/lib/components/metrics/RunChart.svelte
create mode 100644 examples/finetuning/src/lib/components/metrics/validationCompletions/CompletionsToken.svelte
create mode 100644 examples/finetuning/src/lib/components/metrics/validationCompletions/ValidationCompletionsViewer.svelte
create mode 100644 examples/finetuning/src/lib/dataUtils.ts
create mode 100644 examples/finetuning/src/lib/train/generate.ts
create mode 100644 examples/finetuning/src/lib/train/model/cache.ts
create mode 100644 examples/finetuning/src/lib/train/model/config.ts
create mode 100644 examples/finetuning/src/lib/train/model/gpt.ts
create mode 100644 examples/finetuning/src/lib/train/model/utils.ts
create mode 100644 examples/finetuning/src/lib/train/moduleWorker.ts
create mode 100644 examples/finetuning/src/lib/train/protocol.ts
create mode 100644 examples/finetuning/src/lib/train/session.ts
create mode 100644 examples/finetuning/src/lib/train/tokenizer.ts
create mode 100644 examples/finetuning/src/lib/train/types.ts
create mode 100644 examples/finetuning/src/lib/train/utils/checkpoint.ts
create mode 100644 examples/finetuning/src/lib/train/utils/init.ts
create mode 100644 examples/finetuning/src/lib/train/utils/model.ts
create mode 100644 examples/finetuning/src/lib/train/utils/modes.ts
create mode 100644 examples/finetuning/src/lib/train/utils/optim.ts
create mode 100644 examples/finetuning/src/lib/train/utils/random.ts
create mode 100644 examples/finetuning/src/lib/train/validation.ts
create mode 100644 examples/finetuning/src/lib/train/validationHelpers.ts
create mode 100644 examples/finetuning/src/lib/workspace/checkpointStore.ts
create mode 100644 examples/finetuning/src/lib/workspace/config.svelte.ts
create mode 100644 examples/finetuning/src/lib/workspace/config.ts
create mode 100644 examples/finetuning/src/lib/workspace/lastSessionStore.ts
create mode 100644 examples/finetuning/src/lib/workspace/localStorage.svelte.ts
create mode 100644 examples/finetuning/src/lib/workspace/runs.svelte.ts
create mode 100644 examples/finetuning/src/lib/workspace/ui.svelte.ts
create mode 100644 examples/finetuning/src/lib/workspace/utils.ts
create mode 100644 examples/finetuning/src/lib/workspace/workers.svelte.ts
create mode 100644 examples/finetuning/src/routes/+layout.svelte
create mode 100644 examples/finetuning/src/routes/+layout.ts
create mode 100644 examples/finetuning/src/routes/+page.svelte
create mode 100644 examples/finetuning/src/routes/tabs/About.svelte
create mode 100644 examples/finetuning/src/routes/tabs/Metrics.svelte
create mode 100644 examples/finetuning/static/Berkeley Mono Variable.woff2
create mode 100644 examples/finetuning/static/_headers
create mode 100644 examples/finetuning/svelte.config.js
create mode 100644 examples/finetuning/tsconfig.json
create mode 100644 examples/finetuning/vite.config.ts
create mode 100644 examples/finetuning/vitest-setup-client.ts
diff --git a/examples/finetuning/.editorconfig b/examples/finetuning/.editorconfig
new file mode 100644
index 00000000..5696c4e8
--- /dev/null
+++ b/examples/finetuning/.editorconfig
@@ -0,0 +1,23 @@
+root = true
+
+[*]
+end_of_line = lf
+insert_final_newline = true
+indent_style = tab
+indent_size = tab
+tab_width = 2
+charset = utf-8
+trim_trailing_whitespace = true
+
+[*.py]
+indent_style = space
+indent_size = 4
+tab_width = 4
+
+[*.ipynb]
+indent_style = space
+indent_size = 4
+tab_width = 4
+
+[package.json]
+indent_style = space
diff --git a/examples/finetuning/.gitignore b/examples/finetuning/.gitignore
new file mode 100644
index 00000000..85b308c3
--- /dev/null
+++ b/examples/finetuning/.gitignore
@@ -0,0 +1,28 @@
+node_modules
+
+# Output
+.output
+.vercel
+.netlify
+.wrangler
+/.svelte-kit
+/build
+
+# OS
+.DS_Store
+Thumbs.db
+
+# Env
+*.local
+.env*
+!.env.production
+!.env.development
+!.env.example
+!.env.test
+
+# Vite
+vite.config.js.timestamp-*
+vite.config.ts.timestamp-*
+
+static/tokenizer
+static/tokenized
diff --git a/examples/finetuning/.npmrc b/examples/finetuning/.npmrc
new file mode 100644
index 00000000..fbb338d9
--- /dev/null
+++ b/examples/finetuning/.npmrc
@@ -0,0 +1,2 @@
+engine-strict=true
+use-node-version=22.17.1
diff --git a/examples/finetuning/.prettierignore b/examples/finetuning/.prettierignore
new file mode 100644
index 00000000..ab78a95d
--- /dev/null
+++ b/examples/finetuning/.prettierignore
@@ -0,0 +1,4 @@
+# Package Managers
+package-lock.json
+pnpm-lock.yaml
+yarn.lock
diff --git a/examples/finetuning/.prettierrc b/examples/finetuning/.prettierrc
new file mode 100644
index 00000000..bd92dfb0
--- /dev/null
+++ b/examples/finetuning/.prettierrc
@@ -0,0 +1,18 @@
+{
+ "useTabs": true,
+ "singleQuote": true,
+ "trailingComma": "none",
+ "printWidth": 100,
+ "tabWidth": 2,
+ "plugins": [
+ "prettier-plugin-svelte"
+ ],
+ "overrides": [
+ {
+ "files": "*.svelte",
+ "options": {
+ "parser": "svelte"
+ }
+ }
+ ]
+}
diff --git a/examples/finetuning/README.md b/examples/finetuning/README.md
new file mode 100644
index 00000000..99aaad96
--- /dev/null
+++ b/examples/finetuning/README.md
@@ -0,0 +1 @@
+This is the source code for the finetuning demo at [finetune.sequence.toys](https://finetune.sequence.toys). It is a static Svelte app.
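+
+A minimal local-dev sketch, assuming the repository's workspace setup (the `workspace:*` dependencies in `package.json` imply a workspace-aware package manager such as pnpm; the scripts below come from `package.json`):
+
+```sh
+pnpm install
+pnpm dev      # vite dev server
+pnpm build    # static build (adapter-static)
+```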
diff --git a/examples/finetuning/eslint.config.js b/examples/finetuning/eslint.config.js
new file mode 100644
index 00000000..e00e208a
--- /dev/null
+++ b/examples/finetuning/eslint.config.js
@@ -0,0 +1,81 @@
+import { includeIgnoreFile } from '@eslint/compat';
+import js from '@eslint/js';
+import prettier from 'eslint-config-prettier';
+import perfectionist from 'eslint-plugin-perfectionist';
+import svelte from 'eslint-plugin-svelte';
+import { defineConfig } from 'eslint/config';
+import globals from 'globals';
+import { fileURLToPath } from 'node:url';
+import ts from 'typescript-eslint';
+const gitignorePath = fileURLToPath(new URL('./.gitignore', import.meta.url));
+const tsconfigRootDir = fileURLToPath(new URL('.', import.meta.url));
+
+export default defineConfig(
+ includeIgnoreFile(gitignorePath),
+ js.configs.recommended,
+ ...ts.configs.recommended,
+ ...svelte.configs['flat/recommended'],
+ ...svelte.configs['flat/prettier'],
+ {
+ languageOptions: {
+ globals: {
+ ...globals.browser,
+ ...globals.node
+ }
+ }
+ },
+ {
+ files: ['**/*.svelte', '**/*.svelte.ts', '**/*.ts'],
+
+ languageOptions: {
+ parserOptions: {
+ parser: ts.parser,
+ // Ensure the TypeScript parser knows which tsconfig to use in this package
+ tsconfigRootDir,
+ projectService: true
+ }
+ }
+ },
+ {
+ plugins: {
+ prettier,
+ perfectionist
+ },
+ rules: {
+ 'no-unused-vars': 'off',
+ '@typescript-eslint/ban-types': 'off',
+ '@typescript-eslint/no-explicit-any': 'warn',
+ '@typescript-eslint/no-unused-vars': [
+ 'warn',
+ {
+ argsIgnorePattern: '^_',
+ varsIgnorePattern: '^_',
+ caughtErrorsIgnorePattern: '^_'
+ }
+ ],
+ '@typescript-eslint/no-var-requires': 'off',
+ 'perfectionist/sort-imports': [
+ 'error',
+ {
+ type: 'natural',
+ order: 'asc',
+ groups: [
+ 'type',
+ ['builtin', 'external'],
+ 'internal-type',
+ 'internal',
+ ['parent-type', 'sibling-type', 'index-type'],
+ ['parent', 'sibling', 'index'],
+ 'side-effect',
+ 'style',
+ 'object',
+ 'unknown'
+ ]
+ }
+ ],
+ 'perfectionist/sort-exports': ['error', { type: 'natural' }],
+ 'perfectionist/sort-named-exports': ['error', { type: 'natural' }],
+ 'perfectionist/sort-named-imports': ['error', { type: 'natural' }]
+ }
+ }
+);
diff --git a/examples/finetuning/package.json b/examples/finetuning/package.json
new file mode 100644
index 00000000..47e53ff9
--- /dev/null
+++ b/examples/finetuning/package.json
@@ -0,0 +1,63 @@
+{
+ "name": "piston-finetune-toy",
+ "private": true,
+ "version": "0.0.1",
+ "type": "module",
+ "scripts": {
+ "dev": "vite dev",
+ "build": "vite build",
+ "preview": "vite preview",
+ "prepare": "svelte-kit sync || echo ''",
+ "check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json",
+ "check:watch": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json --watch",
+ "lint": "eslint . && prettier --check .",
+ "format": "prettier --write ."
+ },
+ "dependencies": {
+ "@codemirror/autocomplete": "^6.19.1",
+ "@codemirror/language": "^6.11.3",
+ "@codemirror/lint": "^6.9.1",
+ "@codemirror/state": "^6.5.2",
+ "@codemirror/view": "^6.38.6",
+ "@huggingface/jinja": "^0.5.1",
+ "@lezer/highlight": "^1.2.3",
+ "@lucide/svelte": "^0.554.0",
+ "@piston-ml/piston-web": "workspace:^",
+ "@types/katex": "^0.16.7",
+ "codemirror": "^6.0.2",
+ "echarts": "^6.0.0",
+ "example-common": "workspace:*",
+ "katex": "^0.16.23",
+ "maxrects-packer": "^2.7.3",
+ "random-js": "^2.1.0",
+ "svelte-portal": "^2.2.1",
+ "unique-names-generator": "^4.7.1"
+ },
+ "devDependencies": {
+ "@eslint/compat": "^1.4.0",
+ "@eslint/js": "^9.37.0",
+ "@sveltejs/adapter-static": "^3.0.10",
+ "@sveltejs/kit": "^2.46.4",
+ "@sveltejs/vite-plugin-svelte": "^6.2.1",
+ "@tailwindcss/vite": "^4.1.14",
+ "@types/glob": "^9.0.0",
+ "@webgpu/types": "^0.1.65",
+ "eslint": "^9.37.0",
+ "eslint-config-prettier": "^10.1.8",
+ "eslint-plugin-perfectionist": "^4.15.1",
+ "eslint-plugin-svelte": "^3.12.4",
+ "glob": "^11.0.3",
+ "globals": "^16.4.0",
+ "prettier": "^3.6.2",
+ "prettier-plugin-svelte": "^3.4.0",
+ "rollup": "^4.52.4",
+ "sirv": "^2.0.4",
+ "svelte": "workspace:^",
+ "svelte-check": "^4.3.3",
+ "tailwindcss": "^4.1.14",
+ "typescript": "^5.9.3",
+ "typescript-eslint": "^8.46.0",
+ "vite": "^7.1.9",
+ "vite-plugin-wasm": "^3.5.0"
+ }
+}
diff --git a/examples/finetuning/src/app.css b/examples/finetuning/src/app.css
new file mode 100644
index 00000000..aec164aa
--- /dev/null
+++ b/examples/finetuning/src/app.css
@@ -0,0 +1,25 @@
+@import 'tailwindcss' source('./');
+@import 'example-common/theme.css';
+
+@source "../../../packages/example-common/src/**/*.{svelte,js,ts}";
+
+html {
+ /* Base font size for fine pointers (desktop/mouse) */
+ font-size: 16px;
+ overscroll-behavior: none;
+ @apply h-full;
+ @apply w-full;
+}
+
+@media (pointer: coarse) {
+ html {
+ font-size: 18px;
+ }
+}
+
+body {
+ @apply h-full;
+ @apply w-full;
+ @apply text-base;
+ overscroll-behavior: none;
+}
diff --git a/examples/finetuning/src/app.d.ts b/examples/finetuning/src/app.d.ts
new file mode 100644
index 00000000..da2a8798
--- /dev/null
+++ b/examples/finetuning/src/app.d.ts
@@ -0,0 +1,14 @@
+// See https://svelte.dev/docs/kit/types#app.d.ts
+// for information about these interfaces
+declare global {
+ const __COMMIT_HASH__: string;
+ namespace App {
+ // interface Error {}
+ // interface Locals {}
+ // interface PageData {}
+ // interface PageState {}
+ // interface Platform {}
+ }
+}
+
+export {};
diff --git a/examples/finetuning/src/app.html b/examples/finetuning/src/app.html
new file mode 100644
index 00000000..f273cc58
--- /dev/null
+++ b/examples/finetuning/src/app.html
@@ -0,0 +1,11 @@
+<!doctype html>
+<html lang="en">
+	<head>
+		<meta charset="utf-8" />
+		<meta name="viewport" content="width=device-width, initial-scale=1" />
+		%sveltekit.head%
+	</head>
+	<body data-sveltekit-preload-data="hover">
+		<div style="display: contents">%sveltekit.body%</div>
+	</body>
+</html>
diff --git a/examples/finetuning/src/lib/attachments/echarts.svelte.ts b/examples/finetuning/src/lib/attachments/echarts.svelte.ts
new file mode 100644
index 00000000..2ba10bd5
--- /dev/null
+++ b/examples/finetuning/src/lib/attachments/echarts.svelte.ts
@@ -0,0 +1,236 @@
+import type { Attachment } from 'svelte/attachments';
+
+import * as echarts from 'echarts/core';
+
+type EChartsAttachmentParams = {
+ opts: () => echarts.EChartsCoreOption | null;
+ // Optional setup hook to programmatically configure the instance (e.g., event bridging)
+ // May return a cleanup function that will be called on detach.
+ setup?: (chart: echarts.ECharts, getPeers: undefined) => void | (() => void);
+};
+
+export function setupAxisSync(chart: echarts.ECharts, getPeers: () => echarts.ECharts[]) {
+ let isRelaying = false;
+ chart.on('updateAxisPointer', (evt) => {
+ if (isRelaying) return;
+ const e = evt as { axesInfo?: Array<{ axisDim: string; value: number }> };
+ const axesInfo = e?.axesInfo;
+ if (!axesInfo || axesInfo.length === 0) return;
+ const xInfo = axesInfo.find((a) => a.axisDim === 'x');
+ if (!xInfo) return;
+
+ const catIndex = Math.max(0, Math.floor(xInfo.value ?? 0));
+ const opt = chart.getOption?.() as
+ | { xAxis?: Array<{ data?: (number | string)[] }> }
+ | undefined;
+ const cats = opt?.xAxis?.[0]?.data ?? [];
+ const step = Number(cats[catIndex] ?? catIndex);
+
+ isRelaying = true;
+ for (const peer of getPeers()) {
+ const hit = findPeerDataIndexByStep(peer, step);
+ peer.dispatchAction({
+ type: 'updateAxisPointer',
+ seriesIndex: hit?.seriesIndex ?? 0,
+ dataIndex: hit?.dataIndex ?? catIndex
+ });
+ }
+ Promise.resolve().then(() => {
+ isRelaying = false;
+ });
+ });
+}
+
+// Binary-search utilities for sorted numeric step arrays
+export function exactIndex(steps: number[], target: number): number {
+ let lo = 0,
+ hi = steps.length - 1;
+ while (lo <= hi) {
+ const mid = (lo + hi) >> 1;
+ const v = steps[mid];
+ if (v === target) return mid;
+ if (v < target) lo = mid + 1;
+ else hi = mid - 1;
+ }
+ return -1;
+}
+
+export function nearestIndex(steps: number[], target: number): number {
+ if (steps.length === 0) return -1;
+ const first = steps[0],
+ last = steps[steps.length - 1];
+ if (target < first || target > last) return -1;
+ let lo = 0,
+ hi = steps.length;
+ while (lo < hi) {
+ const mid = (lo + hi) >> 1;
+ if (steps[mid] < target) lo = mid + 1;
+ else hi = mid;
+ }
+ if (lo === 0) return 0;
+ if (lo === steps.length) return steps.length - 1;
+ return target - steps[lo - 1] <= steps[lo] - target ? lo - 1 : lo;
+}
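+
+// Hedged usage sketch for the two search helpers above (values are illustrative):
+//   exactIndex([0, 10, 20, 30], 20)   // -> 2 (exact match only, else -1)
+//   nearestIndex([0, 10, 20, 30], 14) // -> 1 (10 is closer than 20)
+//   nearestIndex([0, 10, 20, 30], 35) // -> -1 (outside [first, last])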
+
+// Extract numeric x/step from an ECharts updateAxisPointer event
+export function extractStepFromAxisPointerEvent(evt: unknown): number | null {
+ const e = evt as { axesInfo?: Array<{ axisDim: string; value: number }> } | undefined;
+ const xInfo = e?.axesInfo?.find((a) => a.axisDim === 'x');
+ const step = Number(xInfo?.value);
+ return Number.isFinite(step) ? step : null;
+}
+
+function extractXValue(datum: unknown): number | null {
+ if (Array.isArray(datum)) {
+ const x = Number(datum[0]);
+ return Number.isFinite(x) ? x : null;
+ }
+ if (datum && typeof datum === 'object') {
+ const obj = datum as { value?: unknown; x?: unknown; step?: unknown };
+ if (Array.isArray(obj.value)) {
+ const x = Number(obj.value[0]);
+ return Number.isFinite(x) ? x : null;
+ }
+ const xCandidate = obj.x ?? obj.step;
+ if (typeof xCandidate === 'number') return xCandidate;
+ if (typeof xCandidate === 'string') {
+ const parsed = Number(xCandidate);
+ return Number.isFinite(parsed) ? parsed : null;
+ }
+ }
+ return null;
+}
+
+function linearSearchByX(data: unknown[], targetStep: number): number {
+ for (let i = 0; i < data.length; i++) {
+ const x = extractXValue(data[i]);
+ if (x === targetStep) return i;
+ }
+ return -1;
+}
+
+function binarySearchByX(data: unknown[], targetStep: number): number {
+ if (data.length === 0) return -1;
+ const first = extractXValue(data[0]);
+ const last = extractXValue(data[data.length - 1]);
+ if (first === null || last === null) return linearSearchByX(data, targetStep);
+
+ const ascending = last >= first;
+ let lo = 0;
+ let hi = data.length - 1;
+ while (lo <= hi) {
+ const mid = lo + ((hi - lo) >> 1);
+ const x = extractXValue(data[mid]);
+ if (x === null) return linearSearchByX(data, targetStep);
+ if (x === targetStep) return mid;
+ if (ascending ? x < targetStep : x > targetStep) lo = mid + 1;
+ else hi = mid - 1;
+ }
+ return -1;
+}
+
+function getSeriesArrayFromOption(opt: unknown): unknown[] {
+ const o = opt as { series?: unknown[] } | undefined;
+ return Array.isArray(o?.series) ? (o!.series as unknown[]) : [];
+}
+
+function getDataArrayFromSeries(seriesItem: unknown): unknown[] {
+ const s = seriesItem as { data?: unknown[] } | undefined;
+ return Array.isArray(s?.data) ? (s!.data as unknown[]) : [];
+}
+
+export function findPeerDataIndexByStep(
+ peer: echarts.ECharts,
+ step: number
+): { seriesIndex: number; dataIndex: number } | null {
+ const opt = peer.getOption() as unknown;
+ const seriesArr = getSeriesArrayFromOption(opt);
+ if (seriesArr.length === 0) return null;
+ const seriesIndex = seriesArr.length - 1; // highest series index
+ const data = getDataArrayFromSeries(seriesArr[seriesIndex]);
+ if (data.length === 0) return null;
+ const dataIndex = binarySearchByX(data, step);
+ if (dataIndex < 0) return null;
+ return { seriesIndex, dataIndex };
+}
+
+export default function createEChartsAttachment(params: EChartsAttachmentParams): Attachment {
+ return (node: Element) => {
+ if (!(node instanceof HTMLDivElement)) {
+ throw new Error('ECharts attachment requires a div element');
+ }
+
+ const chart = echarts.init(node);
+ const getPeers = undefined;
+
+ // Allow caller to set up custom behavior (e.g., axis-pointer mapping)
+ let setupCleanup: (() => void) | undefined;
+ const maybeCleanup = params.setup?.(chart, getPeers);
+ if (typeof maybeCleanup === 'function') setupCleanup = maybeCleanup;
+
+ const resizeObserver = new ResizeObserver(() => {
+ chart.resize();
+ });
+ resizeObserver.observe(node);
+
+ $effect(() => {
+ const options = params.opts?.();
+ if (options) {
+ chart.setOption(options, {
+ notMerge: false,
+ replaceMerge: ['series'],
+ lazyUpdate: false
+ });
+ }
+ });
+
+ return () => {
+ resizeObserver.unobserve(node);
+ setupCleanup?.();
+ chart.dispose();
+ };
+ };
+}
+
+type MoveDetail = {
+ sourceId: string;
+ runId: string;
+ step: number;
+};
+
+type ClearDetail = {
+ sourceId: string;
+};
+
+const MOVE_EVENT = 'run-pointer:move';
+const CLEAR_EVENT = 'run-pointer:clear';
+
+const bus: EventTarget = new EventTarget();
+
+export function publishMove(detail: MoveDetail): void {
+ bus.dispatchEvent(new CustomEvent(MOVE_EVENT, { detail }));
+}
+
+export function publishClear(detail: ClearDetail): void {
+ bus.dispatchEvent(new CustomEvent(CLEAR_EVENT, { detail }));
+}
+
+export function subscribe(
+ onMove?: (detail: MoveDetail) => void,
+ onClear?: (detail: ClearDetail) => void
+): () => void {
+ const moveListener = (e: Event) => {
+ const ce = e as CustomEvent;
+ if (onMove) onMove(ce.detail);
+ };
+ const clearListener = (e: Event) => {
+ const ce = e as CustomEvent;
+ if (onClear) onClear(ce.detail);
+ };
+ if (onMove) bus.addEventListener(MOVE_EVENT, moveListener as EventListener);
+ if (onClear) bus.addEventListener(CLEAR_EVENT, clearListener as EventListener);
+ return () => {
+ if (onMove) bus.removeEventListener(MOVE_EVENT, moveListener as EventListener);
+ if (onClear) bus.removeEventListener(CLEAR_EVENT, clearListener as EventListener);
+ };
+}
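+
+// Hedged usage sketch of the pointer bus above (ids are placeholders):
+//   const unsubscribe = subscribe(
+//     ({ sourceId, runId, step }) => console.log('move', sourceId, runId, step),
+//     ({ sourceId }) => console.log('clear', sourceId)
+//   );
+//   publishMove({ sourceId: 'chart-a', runId: 'run-1', step: 42 });
+//   publishClear({ sourceId: 'chart-a' });
+//   unsubscribe();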
diff --git a/examples/finetuning/src/lib/components/MetricToggleChips.svelte b/examples/finetuning/src/lib/components/MetricToggleChips.svelte
new file mode 100644
index 00000000..b79ebaa3
--- /dev/null
+++ b/examples/finetuning/src/lib/components/MetricToggleChips.svelte
@@ -0,0 +1,41 @@
+
+
+
+ {#each items as name (name)}
+
+ {/each}
+
diff --git a/examples/finetuning/src/lib/components/controls/Controls.svelte b/examples/finetuning/src/lib/components/controls/Controls.svelte
new file mode 100644
index 00000000..5d99ce15
--- /dev/null
+++ b/examples/finetuning/src/lib/components/controls/Controls.svelte
@@ -0,0 +1,776 @@
+
+
+
+ {#if runsMap.size >= 1}
+
toggleControlSection('runs')}
+ contentClass="w-full"
+ >
+
+
+ {/if}
+
+
toggleControlSection('gpu')}
+ contentClass={collapsibleSectionClass}
+ >
+
+ {gpuName || 'Unknown'}
+
+
+ (gpuPowerPreference.current = 'high-performance')}
+ />
+
+ resetConfigToDefaults('training.vramLimitMb.present')}
+ >
+ {
+ if (value >= 1024) {
+ const gb = value / 1024;
+ const gbStr = gb % 1 === 0 ? `${gb}GB` : `${gb.toFixed(1)}GB`;
+ return gbStr;
+ }
+ return `${value}MB`;
+ }}
+ hasDefaultValue={equalsConfigDefault('training.vramLimitMb.value')}
+ onReset={() => resetConfigToDefaults('training.vramLimitMb.value')}
+ />
+
+
+
+
toggleControlSection('task')}
+ contentClass={collapsibleSectionClass}
+ >
+
+
+
+
toggleControlSection('model')}
+ contentClass={collapsibleSectionClass}
+ >
+
+
+
+
+ {getParameterCount()}
+
+
+ {getHiddenSize()}
+
+
+ {currentDataset.vocabSize}
+
+
+
+
+
toggleControlSection('training')}
+ >
+ resetConfigToDefaults('training.logSteps')}
+ />
+
+ resetConfigToDefaults('training.batchSize')}
+ />
+
+ resetConfigToDefaults('training.gradNorm.track')}
+ >
+ resetConfigToDefaults('training.gradNorm.errorIfNonfinite')}
+ />
+
+ resetConfigToDefaults('training.clipGradNorm.present')}
+ >
+ resetConfigToDefaults('training.clipGradNorm.value')}
+ />
+
+
+
+ resetConfigToDefaults('training.validation.present')}
+ >
+ resetConfigToDefaults('training.validation.valSteps')}
+ />
+ resetConfigToDefaults('training.validation.batchSize')}
+ />
+ resetConfigToDefaults('training.validation.completions.present')}
+ >
+ resetConfigToDefaults('training.validation.completions.amount')}
+ />
+ {#if config.training.validation.completions.amount === 'subset'}
+ resetConfigToDefaults('training.validation.completions.subsetSize')}
+ />
+ {/if}
+
+ resetConfigToDefaults('training.validation.temperature')}
+ />
+
+
+
+
+ resetConfigToDefaults('training.limitTraining.present')}
+ >
+ resetConfigToDefaults('training.limitTraining.steps')}
+ />
+
+
+
+ resetConfigToDefaults('training.labelSmoothing.present')}
+ >
+ resetConfigToDefaults('training.labelSmoothing.value')}
+ />
+
+
+ resetConfigToDefaults('training.dropout.present')}
+ >
+ resetConfigToDefaults('training.dropout.embedding')}
+ />
+ resetConfigToDefaults('training.dropout.transformer.attention')}
+ />
+ resetConfigToDefaults('training.dropout.transformer.residual')}
+ />
+
+
+
+ resetConfigToDefaults('training.randomSeed.present')}
+ >
+ resetConfigToDefaults('training.randomSeed.value')}
+ />
+
+
+ resetConfigToDefaults('training.checkpointEverySteps.present')}
+ >
+ resetConfigToDefaults('training.checkpointEverySteps.value')}
+ />
+
+
+
+
toggleControlSection('optimizer')}
+ contentClass={collapsibleSectionClass}
+ >
+ resetConfigToDefaults('optimizer.type')}
+ />
+
+ resetConfigToDefaults('optimizer.lr')}
+ />
+
+ resetConfigToDefaults('optimizer.warmupSteps.present')}
+ >
+ resetConfigToDefaults('optimizer.warmupSteps.value')}
+ />
+
+
+ resetConfigToDefaults('optimizer.lrScheduler.present')}
+ >
+
+
+
+ resetConfigToDefaults('optimizer.weightDecay.present')}
+ >
+ resetConfigToDefaults('optimizer.weightDecay.value')}
+ />
+
+ resetConfigToDefaults('optimizer.weightDecay.useWeightDecayGroups')}
+ />
+
+
+
+ {#if config.optimizer.type === 'Muon'}
+
+
+ resetConfigToDefaults('optimizer.muon.nsSteps')}
+ />
+ resetConfigToDefaults('optimizer.muon.momentum')}
+ />
+ resetConfigToDefaults('optimizer.muon.nesterov')}
+ />
+
+
+ {/if}
+ {#if config.optimizer.type === 'AdamW' || config.optimizer.type === 'Adam' || config.optimizer.type === 'Muon'}
+ {@const settingsName = config.optimizer.type === 'Adam' ? 'Adam' : 'AdamW'}
+
+
+ resetConfigToDefaults('optimizer.adam.beta1')}
+ />
+ resetConfigToDefaults('optimizer.adam.beta2')}
+ />
+ resetConfigToDefaults('optimizer.adam.eps')}
+ />
+ resetConfigToDefaults('optimizer.adam.amsgrad')}
+ />
+
+
+ {/if}
+
+ {#if config.optimizer.type === 'SGD'}
+
+ resetConfigToDefaults('optimizer.sgd.momentum')}
+ />
+ resetConfigToDefaults('optimizer.sgd.dampening')}
+ />
+ resetConfigToDefaults('optimizer.sgd.nesterov')}
+ />
+
+ {/if}
+
+
+
+
toggleControlSection('advanced')}
+ contentClass={collapsibleSectionClass}
+ >
+ resetConfigToDefaults('training.useWeakTensorReferences')}
+ />
+ resetConfigToDefaults('training.sharedObjectAllocation')}
+ />
+ resetConfigToDefaults('training.inplaceSupport')}
+ />
+
+
+
+
diff --git a/examples/finetuning/src/lib/components/controls/DatasetControls.svelte b/examples/finetuning/src/lib/components/controls/DatasetControls.svelte
new file mode 100644
index 00000000..a1003070
--- /dev/null
+++ b/examples/finetuning/src/lib/components/controls/DatasetControls.svelte
@@ -0,0 +1,38 @@
+
+
+ resetConfigToDefaults('data.dataset')}
+/>
+
+{datasetConfigMetadata?.description}
+
+
+
+{#if getShowLowDiversityDatasetError()}
+
+
+
+ Not enough example diversity in the training dataset for a held-out validation set of size {config
+ .training.validation.batchSize}. Consider changing dataset parameters or reducing the
+ validation batch size.
+
+
+
+{/if}
diff --git a/examples/finetuning/src/lib/components/controls/DatasetSample.svelte b/examples/finetuning/src/lib/components/controls/DatasetSample.svelte
new file mode 100644
index 00000000..bd9aad91
--- /dev/null
+++ b/examples/finetuning/src/lib/components/controls/DatasetSample.svelte
@@ -0,0 +1,158 @@
+
+
+{#snippet token(value: number, dashed: boolean = false)}
+
+ {#if tokenizer instanceof Promise}
+ {#await tokenizer then tokenizer}
+ {decodeSingle(value, tokenizer)}
+ {/await}
+ {:else}
+ {decodeSingle(value, tokenizer ?? null)}
+ {/if}
+
+{/snippet}
+
+{#snippet tokenSequence(values: number[], ignored?: boolean[])}
+
+ {#each values as value, i (i)}
+ {@render token(value, ignored ? ignored[i] : false)}
+ {/each}
+
+{/snippet}
+
+ 0 ? `${lastStableHeight}px` : 'auto'}
+ style:overflow-y={anyPending ? 'hidden' : 'visible'}
+>
+ {#if sampleData}
+ {#await sampleData then { collated, hasPrompt }}
+ {#if collated.length > 0}
+
+
+
+
+ {#if hasPrompt}
+ | Prompt |
+ {/if}
+ Target |
+
+
+
+ {#each collated as { prompt, target, fullSequence } (Array.prototype.concat.call(fullSequence, prompt ?? [], target ?? []))}
+ {@const pLen = prompt?.length || 0}
+ {@const targetFlags = maskedFlagsForRange(fullSequence, pLen, target?.length ?? 0)}
+
+ {#if hasPrompt}
+ {@const promptFlags = maskedFlagsForRange(fullSequence, 0, pLen)}
+ |
+ {@render tokenSequence(prompt!, promptFlags)}
+ |
+ {/if}
+
+ {@render tokenSequence(target ?? fullSequence, targetFlags)}
+ |
+
+ {/each}
+
+ {#if hasPrompt}
+ | ... |
+ {/if}
+ ... |
+
+
+
+
+ {/if}
+ {/await}
+ {/if}
+
diff --git a/examples/finetuning/src/lib/components/controls/LRSchedulePicker.svelte b/examples/finetuning/src/lib/components/controls/LRSchedulePicker.svelte
new file mode 100644
index 00000000..531e616f
--- /dev/null
+++ b/examples/finetuning/src/lib/components/controls/LRSchedulePicker.svelte
@@ -0,0 +1,517 @@
+
+
+
+
+
+
+
+
resetConfigToDefaults('optimizer.lrScheduler.type')}
+ />
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ (stepsToShow = 1000)}
+ />
+
+
+ {#if config.optimizer.lrScheduler.type === 'linear'}
+ resetConfigToDefaults('optimizer.lrScheduler.linearSchedule.startFactor')}
+ />
+ resetConfigToDefaults('optimizer.lrScheduler.linearSchedule.endFactor')}
+ />
+ resetConfigToDefaults('optimizer.lrScheduler.linearSchedule.totalIters')}
+ />
+ {/if}
+
+
+ {#if config.optimizer.lrScheduler.type === 'constant'}
+ resetConfigToDefaults('optimizer.lrScheduler.constantSchedule.factor')}
+ />
+ resetConfigToDefaults('optimizer.lrScheduler.constantSchedule.totalIters')}
+ />
+ {/if}
+
+
+ {#if config.optimizer.lrScheduler.type === 'cosine'}
+ resetConfigToDefaults('optimizer.lrScheduler.cosineAnnealingSchedule.tMax')}
+ />
+
+ resetConfigToDefaults('optimizer.lrScheduler.cosineAnnealingSchedule.etaMin')}
+ />
+ {/if}
+
+
+ {#if config.optimizer.lrScheduler.type === 'step'}
+ resetConfigToDefaults('optimizer.lrScheduler.stepSchedule.stepSize')}
+ />
+ resetConfigToDefaults('optimizer.lrScheduler.stepSchedule.gamma')}
+ />
+ {/if}
+
+
+ {#if config.optimizer.lrScheduler.type === 'exponential'}
+ resetConfigToDefaults('optimizer.lrScheduler.exponentialSchedule.gamma')}
+ />
+ {/if}
+
+
diff --git a/examples/finetuning/src/lib/components/controls/RunsTable.svelte b/examples/finetuning/src/lib/components/controls/RunsTable.svelte
new file mode 100644
index 00000000..73d67b06
--- /dev/null
+++ b/examples/finetuning/src/lib/components/controls/RunsTable.svelte
@@ -0,0 +1,52 @@
+
+
+
+ {#if runs.length > 1}
+
+ {/if}
+
+
+
+
+
+
+
+
+
+
+ | {frozenColumn.label} |
+ Changes |
+
+
+
+ {#each runs as run (run.runId)}
+
+ | {run.runId} |
+ {run.diffSummary ?? 'initial experiment'} |
+
+ {/each}
+
+
+
+
+
+
diff --git a/examples/finetuning/src/lib/components/controls/SelectDataset.svelte b/examples/finetuning/src/lib/components/controls/SelectDataset.svelte
new file mode 100644
index 00000000..112eec09
--- /dev/null
+++ b/examples/finetuning/src/lib/components/controls/SelectDataset.svelte
@@ -0,0 +1,91 @@
+
+
+{#snippet nameAndBadges(_opt: DatasetOption)}
+ {@const opt = _opt as DatasetOption}
+ {@const meta = getMeta(opt.value as DatasetName)}
+ {@const name = opt.text ?? meta.name}
+ {@const citations =
+ 'citations' in meta
+ ? ((meta as Record).citations as CitationsType)
+ : undefined}
+
+ {name}
+ {#if citations}
+
+ {/if}
+
+{/snippet}
+
+
+
+ {#snippet option(_opt, _selected)}
+ {@const opt = _opt as DatasetOption}
+ {@render nameAndBadges(opt)}
+ {/snippet}
+ {#snippet trigger(_selected)}
+ {#if _selected}
+ {@const opt = _selected as DatasetOption}
+ {opt.text ?? opt.value}
+ {:else}
+ Select...
+ {/if}
+ {/snippet}
+
+ {#if 'citations' in selectedMeta}
+
+
+
+ {/if}
+
diff --git a/examples/finetuning/src/lib/components/controls/select/SelectWithCitations.svelte b/examples/finetuning/src/lib/components/controls/select/SelectWithCitations.svelte
new file mode 100644
index 00000000..515e5fa3
--- /dev/null
+++ b/examples/finetuning/src/lib/components/controls/select/SelectWithCitations.svelte
@@ -0,0 +1,66 @@
+
+
+{#snippet citationView(_opt: CitationOption)}
+ {@const label = _opt.title}
+ {@const citations = _opt.citations}
+
+
{label}
+ {#if citations}
+
+ {/if}
+
+{/snippet}
+
+
+
+ {#snippet option(_opt, _selected)}
+ {@const opt = _opt as CitationOption}
+ {@render citationView(opt)}
+ {/snippet}
+ {#snippet trigger(_selected)}
+ {#if _selected}
+ {@const opt = _selected as CitationOption}
+ {opt.title}
+ {:else}
+ Select...
+ {/if}
+ {/snippet}
+
+ {#if selectedOption?.citations}
+
+ {/if}
+
diff --git a/examples/finetuning/src/lib/components/metrics/MetricsSection.svelte b/examples/finetuning/src/lib/components/metrics/MetricsSection.svelte
new file mode 100644
index 00000000..ca908275
--- /dev/null
+++ b/examples/finetuning/src/lib/components/metrics/MetricsSection.svelte
@@ -0,0 +1,65 @@
+
+
+
+
+
{
+ if (e.key === 'Enter' || e.key === ' ') {
+ e.preventDefault();
+ handleClick();
+ }
+ }}
+ >
+
+
{title}
+
+ {@render chips?.()}
+
+
+ {#if isOpen}
+
+ {@render children?.()}
+
+ {/if}
+
diff --git a/examples/finetuning/src/lib/components/metrics/RunChart.svelte b/examples/finetuning/src/lib/components/metrics/RunChart.svelte
new file mode 100644
index 00000000..8ef57a5e
--- /dev/null
+++ b/examples/finetuning/src/lib/components/metrics/RunChart.svelte
@@ -0,0 +1,323 @@
+
+
+
+
chartOptions,
+ setup: (chart) => {
+ let isRelaying = false;
+
+ const onUpdateAxisPointer = (evt: unknown) => {
+ if (isRelaying) return;
+ const s = extractStepFromAxisPointerEvent(evt);
+ if (s == null) return;
+ // Choose topmost series that contains this exact step
+ for (let idx = seriesIndexToRunId.length - 1; idx >= 0; idx--) {
+ const runId = seriesIndexToRunId[idx];
+ if (!runId) continue;
+ const seriesInfo = runIdToSeries.get(runId);
+ if (!seriesInfo) continue;
+ if (exactIndex(seriesInfo.steps, s) !== -1) {
+ publishMove({ sourceId: chartId, runId, step: s });
+ break;
+ }
+ }
+ };
+
+ const onGlobalOut = () => {
+ publishClear({ sourceId: chartId });
+ };
+
+ chart.on('updateAxisPointer', onUpdateAxisPointer);
+ chart.on('globalout', onGlobalOut);
+
+ const unsubscribe = subscribe(
+ ({ sourceId, runId, step }) => {
+ if (sourceId === chartId) return;
+ const info = runIdToSeries.get(runId);
+ if (!info) return;
+ const { seriesIndex, steps, first, last } = info;
+ if (step < first || step > last) return;
+ const dataIndex = nearestIndex(steps, step);
+ if (dataIndex < 0) return;
+ isRelaying = true;
+ try {
+ // ECharts accepts a second options argument with { silent: true }
+ chart.dispatchAction(
+ { type: 'updateAxisPointer', seriesIndex, dataIndex },
+ { silent: true }
+ );
+ } finally {
+ isRelaying = false;
+ }
+ },
+ ({ sourceId }) => {
+ if (sourceId === chartId) return;
+ chart.dispatchAction({ type: 'hideTip' }, { silent: true });
+ }
+ );
+
+ return () => {
+ unsubscribe();
+ chart.off('updateAxisPointer', onUpdateAxisPointer);
+ chart.off('globalout', onGlobalOut);
+ };
+ }
+ })}
+ >
+
diff --git a/examples/finetuning/src/lib/components/metrics/validationCompletions/CompletionsToken.svelte b/examples/finetuning/src/lib/components/metrics/validationCompletions/CompletionsToken.svelte
new file mode 100644
index 00000000..8613d92a
--- /dev/null
+++ b/examples/finetuning/src/lib/components/metrics/validationCompletions/CompletionsToken.svelte
@@ -0,0 +1,124 @@
+
+
+ {
+ if (e.key === 'Enter' || e.key === ' ') {
+ e.preventDefault();
+ handleClick(e);
+ }
+ }}
+>
+ {#if targetText != null}
+
+ {visualizeToken(targetText)}
+
+
+ {#if actualText}
+
+ {visualizeToken(actualText)}
+
+ {/if}
+
+ {visualizeToken(targetText)}
+
+
+ {:else}
+
+ {visualizeToken(actualText)}
+
+ {/if}
+
diff --git a/examples/finetuning/src/lib/components/metrics/validationCompletions/ValidationCompletionsViewer.svelte b/examples/finetuning/src/lib/components/metrics/validationCompletions/ValidationCompletionsViewer.svelte
new file mode 100644
index 00000000..4a534700
--- /dev/null
+++ b/examples/finetuning/src/lib/components/metrics/validationCompletions/ValidationCompletionsViewer.svelte
@@ -0,0 +1,824 @@
+
+
+
+
+
+ validation/completions
+ {#if completionsData}
+
+ of {completionsData.targetStep.completions.length}
+ {/if}
+
+ {#if completionsData}
+
+ Step {completionsData.stepNumber} • Temp: {completionsData.targetStep.samplingParams
+ .temperature}
+
+ {/if}
+
+
+ {#if !completionsData}
+
+
No validation data available. Validation will run during training.
+
+ {:else}
+ {@const visibleCompletionsNumberWidth = visibleCompletions.length.toString().length}
+ {@const focus = hoveredFocus}
+
+
+ {#await tokenizer then tokenizer}
+ {#each visibleCompletions as completion, index (index)}
+ {@const genIds = completion.tokenIds}
+ {@const hasTargets = Boolean(completionsData.targets)}
+ {@const tgtIds = hasTargets ? completionsData.targets?.[index] || [] : []}
+ {@const tokenComparisons = hasTargets ? compareTokenIds(genIds, tgtIds, tokenizer) : []}
+ {@const matchesRow = completionsData.targetStep.matches?.[index] || []}
+ {@const prefixLen = completionsData.decoderPromptLengths?.[index] ?? 1}
+
+
+ (hoveredFocus = {
+ exampleIndex: index,
+ tokenIndex: hoveredFocus?.tokenIndex ?? 0
+ })}
+ onmouseleave={() => (hoveredFocus = null)}
+ >
+
+ {(index + 1).toString().padStart(visibleCompletionsNumberWidth, '\u00A0')}
+
+
+
+
+ {#if hasTargets}
+ {#each tokenComparisons as item, tIndex (tIndex)}
+ {@const isPromptItem = item.kind === 'prompt'}
+ {@const completedBefore = tokenComparisons
+ .slice(0, tIndex)
+ .filter((it) => it.kind !== 'prompt').length}
+ {@const sequenceTokenIndex = isPromptItem
+ ? prefixLen
+ : prefixLen + Math.max(0, completedBefore)}
+ {@const isHighlighted =
+ !isPromptItem &&
+ focus?.exampleIndex === index &&
+ focus?.tokenIndex === sequenceTokenIndex}
+ {#if item.kind === 'prompt'}
+ {}}
+ onLeave={() => {}}
+ onSelect={() => {}}
+ />
+ {:else if item.isCorrect}
+ (hoveredFocus = { exampleIndex: ei, tokenIndex: ti })}
+ onLeave={() => (hoveredFocus = null)}
+ />
+ {:else}
+ (hoveredFocus = { exampleIndex: ei, tokenIndex: ti })}
+ onLeave={() => (hoveredFocus = null)}
+ />
+ {/if}
+ {/each}
+ {:else if genIds.length === 0}
+ [empty]
+ {:else}
+ {#each genIds as id, tIndex (tIndex)}
+ {@const text = decodeSingle(id, tokenizer)}
+ {@const isPrompt = tIndex < prefixLen}
+ {@const match = matchesRow[tIndex - prefixLen]}
+ {@const variant = isPrompt
+ ? 'prompt'
+ : !hasMatchData || matchesRow.length === 0
+ ? 'generated'
+ : match === true
+ ? 'correct'
+ : match === false
+ ? 'incorrect'
+ : 'neutral'}
+ {@const isHighlighted =
+ !isPrompt && focus?.exampleIndex === index && focus?.tokenIndex === tIndex}
+ (hoveredFocus = { exampleIndex: ei, tokenIndex: ti })}
+ onLeave={() => (hoveredFocus = null)}
+ />
+ {/each}
+ {/if}
+
+
+ {/each}
+ {/await}
+
+
+
+ {#if chartOptions}
+
+
+
+ {#if selectedProbsStep}
+
+ (selectedProbsStep = null)}
+ />
+
+ {/if}
+
{
+ if (!chartOptions) return null;
+ const baseBar = chartOptions.bar;
+ const tbc = chartOptions.tokensByCol ?? [];
+ const last = Math.max(0, tbc.length - 1);
+ const col =
+ activeStepCol == null
+ ? last
+ : Math.max(0, Math.min(last, Math.floor(activeStepCol)));
+ const defaultCategories = chartOptions.tokenCategoriesByCol[col];
+ const defaultData = tbc[col];
+ const categories = selectedProbsStep
+ ? selectedProbsStep.categories
+ : defaultCategories;
+ const data = selectedProbsStep ? selectedProbsStep.data : defaultData;
+ return {
+ ...baseBar,
+ title: selectedProbsStep
+ ? { ...baseBar.title, text: `probs (step ${selectedProbsStep.step})` }
+ : baseBar?.title,
+ xAxis: [
+ {
+ type: 'category',
+ data: categories,
+ axisPointer: { show: true, type: 'shadow' },
+ triggerEvent: true,
+ axisTick: { show: false },
+ axisLabel: { show: false },
+ axisLine: { show: false }
+ }
+ ],
+ series: [{ id: 'token-bars', type: 'bar', data }]
+ };
+ }
+ })}
+ >
+
+
+
+ {/if}
+
+ {/if}
+
diff --git a/examples/finetuning/src/lib/dataUtils.ts b/examples/finetuning/src/lib/dataUtils.ts
new file mode 100644
index 00000000..23567569
--- /dev/null
+++ b/examples/finetuning/src/lib/dataUtils.ts
@@ -0,0 +1,31 @@
+export function openDb(
+ dbName: string,
+ dbVersion: number,
+ onupgradeneeded: (db: IDBDatabase) => void
+): Promise<IDBDatabase> {
+ return new Promise((resolve, reject) => {
+ const req = indexedDB.open(dbName, dbVersion);
+ req.onupgradeneeded = () => onupgradeneeded(req.result);
+ req.onsuccess = () => resolve(req.result);
+ req.onerror = () => reject(req.error);
+ req.onblocked = () => reject(new Error('IndexedDB upgrade blocked'));
+ });
+}
+
+export function promisify<T>(req: IDBRequest<T>): Promise<T> {
+ return new Promise((resolve, reject) => {
+ req.onsuccess = () => resolve(req.result);
+ req.onerror = () => reject(req.error);
+ });
+}
+
+export function txRequest<T>(
+ db: IDBDatabase,
+ storeName: string,
+ mode: IDBTransactionMode,
+ op: (store: IDBObjectStore) => IDBRequest<T>
+): Promise<T> {
+ const tx = db.transaction(storeName, mode);
+ const store = tx.objectStore(storeName);
+ return promisify(op(store));
+}
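+
+// Hedged usage sketch (database and store names are placeholders):
+//   const db = await openDb('demo-db', 1, (db) => {
+//     if (!db.objectStoreNames.contains('runs')) db.createObjectStore('runs');
+//   });
+//   await txRequest(db, 'runs', 'readwrite', (store) => store.put({ step: 0 }, 'latest'));
+//   const latest = await txRequest(db, 'runs', 'readonly', (store) => store.get('latest'));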
diff --git a/examples/finetuning/src/lib/train/generate.ts b/examples/finetuning/src/lib/train/generate.ts
new file mode 100644
index 00000000..089c50fe
--- /dev/null
+++ b/examples/finetuning/src/lib/train/generate.ts
@@ -0,0 +1,263 @@
+/**
+ * @fileoverview Shared text generation utilities for different model types
+ */
+
+import type { Device, Tensor } from '@piston-ml/piston-web';
+
+import { int32, tensor, WeakTensorFunctionMode } from '@piston-ml/piston-web';
+
+import type { GeneratableModel } from './types';
+
+import { createEmptyDecoderKVCache, type DecoderKVCache } from './model/cache';
+import { GPT } from './model/gpt';
+
+export interface GenerationConfig {
+ maxTokens?: number;
+ stopTokens?: number | number[];
+ device?: string | Device;
+ startToken?: number;
+ maxTargetLength?: number;
+ temperature?: number;
+ useKvCache?: boolean;
+}
+
+export interface GenerationResult {
+ sequences: number[][];
+ // Raw probabilities (post-softmax) for generated tokens
+ probs?: Tensor;
+ // Running average throughput since start of generation (tokens/second)
+ tokensPerSecond?: number;
+}
+
+function normalizeStopTokens(stopTokens: number | number[] = []): Set<number> {
+ return new Set(Array.isArray(stopTokens) ? stopTokens : [stopTokens]);
+}
+
+/**
+ * Create tensor from token sequences with proper device and dtype
+ */
+function createInputTensor(sequences: number[][], device: Device): Tensor {
+ return tensor(sequences, { device, dtype: int32 });
+}
+
+/**
+ * Convert tensor output to array of token IDs
+ */
+async function tensorToTokens(tokenTensor: Tensor): Promise<Int32Array> {
+ return (await (await tokenTensor.to('cpu')).toVec()) as Int32Array;
+}
+
+/**
+ * Check if any sequence should continue generating (hasn't hit stop tokens)
+ */
+function shouldContinueGeneration(
+ results: number[][],
+ newTokens: Int32Array,
+ stopTokenSet: Set<number>
+): boolean {
+ let shouldContinue = false;
+ for (let i = 0; i < results.length; i++) {
+ // If this sequence already ended with a stop token, do not append anything further
+ const sequence = results[i];
+ const lastToken = sequence.length > 0 ? sequence[sequence.length - 1] : undefined;
+ const alreadyStopped = lastToken !== undefined && stopTokenSet.has(lastToken);
+
+ if (alreadyStopped) {
+ continue;
+ }
+
+ const token = newTokens[i];
+ // Always append the newly generated token
+ sequence.push(token);
+ // Continue only if we did not just append a stop token
+ if (!stopTokenSet.has(token)) {
+ shouldContinue = true;
+ }
+ }
+ return shouldContinue;
+}
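+
+// Worked example of the stop logic above: with stopTokenSet = {9},
+// results = [[5], [9]] and newTokens = [7, 8] mutate results to
+// [[5, 7], [9]] (the second sequence already ended with a stop token,
+// so nothing is appended) and return true, because 7 is not a stop token.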
+
+/**
+ * Create a simple running tokens/sec tracker using the Performance API
+ */
+function startTokensPerSecondTracker(): (tokensProducedThisStep: number) => number {
+ const now =
+ typeof performance !== 'undefined' && typeof performance.now === 'function'
+ ? () => performance.now()
+ : () => Date.now();
+ const startMs = now();
+ let totalTokens = 0;
+ return (tokensProducedThisStep: number) => {
+ if (tokensProducedThisStep > 0) totalTokens += tokensProducedThisStep;
+ const elapsedMs = Math.max(1, now() - startMs);
+ return totalTokens / (elapsedMs / 1000);
+ };
+}
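+
+// Hedged usage sketch: call once per step with the tokens appended that step.
+//   const tick = startTokensPerSecondTracker();
+//   const tps = tick(4); // running average tokens/sec since tracker creation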
+
+/**
+ * Generate tokens for GPT (decoder-only) model
+ */
+export async function* generateGPTStream(
+ model: GPT,
+ input: number[] | number[][],
+ config: GenerationConfig = {}
+): AsyncGenerator<GenerationResult> {
+ const stopTokenSet = normalizeStopTokens(config.stopTokens);
+ const isBatch = Array.isArray(input[0]);
+ const sequences = isBatch ? (input as number[][]) : [input as number[]];
+
+ // Initialize results with copies of input sequences
+ const results = sequences.map((seq) => [...seq]);
+ const device = model.lm_head.weight.device;
+
+ const kvCache: DecoderKVCache | null =
+ (config.useKvCache ?? true) ? createEmptyDecoderKVCache(model.config.nLayer) : null;
+
+ let seeded = false;
+ let step = 0;
+ const getTokensPerSecond = startTokensPerSecondTracker();
+ while (true) {
+ const weakMode = new WeakTensorFunctionMode();
+ try {
+ weakMode.markWeak(kvCache);
+ // Seed cache once with full sequences, then switch to single-token steps
+ const prevLengths = results.map((seq) => seq.length);
+ let inputTensor: Tensor;
+ if (kvCache && seeded) {
+ const latestTokens = results.map((seq) => [seq[seq.length - 1] ?? 0]);
+ inputTensor = createInputTensor(latestTokens, device);
+ } else {
+ const { padded } = rightPadSequences(results);
+ inputTensor = createInputTensor(padded, device);
+ }
+
+ // Forward pass to get logits
+ const [logits, _] = model.forward(inputTensor, { kvCache });
+ weakMode.pin(kvCache);
+
+ // Get logits for the last token in each sequence
+ const [batchSize, seqLen, vocabSize] = logits.size();
+ let lastTokenLogits = logits
+ .slice([
+ [0, batchSize],
+ [seqLen - 1, seqLen],
+ [0, vocabSize]
+ ])
+ .view([batchSize, vocabSize]);
+
+ if (config.temperature && config.temperature > 0) {
+ lastTokenLogits = lastTokenLogits.div(config.temperature);
+ }
+
+ lastTokenLogits = lastTokenLogits.softmax(-1);
+
+ // Choose next tokens: sample when temperature > 0, else greedy argmax
+ const nextTokenTensor =
+ config.temperature && config.temperature > 0
+ ? lastTokenLogits.multinomial(1, { replacement: false })
+ : lastTokenLogits.argmax({ dim: -1 });
+ const nextTokensArray = await tensorToTokens(nextTokenTensor);
+
+ // Update sequences and check for stop conditions
+ const shouldContinue = shouldContinueGeneration(results, nextTokensArray, stopTokenSet);
+ // Compute tokens appended this step across active sequences
+ let appendedThisStep = 0;
+ for (let i = 0; i < results.length; i++) {
+ appendedThisStep += results[i].length - prevLengths[i];
+ }
+ const tokensPerSecond = getTokensPerSecond(appendedThisStep);
+
+ weakMode.pin([results, lastTokenLogits]);
+ // Yield current state with sequences and logits
+ yield {
+ sequences: isBatch ? results.map((seq) => [...seq]) : [results[0].slice()],
+ probs: lastTokenLogits, // Provide the softmax'd logits for the last token
+ tokensPerSecond
+ };
+
+ // Mark cache as seeded after first forward with cache
+ if (kvCache && !seeded) seeded = true;
+
+ // If all sequences hit stop tokens, break
+ if (!shouldContinue) {
+ break;
+ }
+
+ step++;
+
+ if (config.maxTokens !== undefined && step >= config.maxTokens) {
+ break;
+ }
+ } finally {
+ weakMode[Symbol.dispose]();
+ }
+ }
+}
+
+/**
+ * Standard generate function that collects all tokens (backward compatible)
+ */
+export async function generate(
+ model: GeneratableModel,
+ input: number[] | number[][],
+ config: GenerationConfig = {}
+): Promise<number[][]> {
+ const results: number[][] = [];
+ let tokenCount = 0;
+ const maxTokens = config.maxTokens ?? 50;
+
+ for await (const generationResult of generateGPTStream(model, input, config)) {
+ results.length = 0;
+ results.push(...generationResult.sequences);
+ tokenCount++;
+
+ if (tokenCount >= maxTokens) {
+ break;
+ }
+ }
+
+ return results;
+}
+
+/**
+ * Right-pad sequences to a uniform length for batched forward passes.
+ * Returns padded sequences and a padding mask (1 for real tokens, 0 for padding).
+ * Note: The original sequences array is NOT modified.
+ */
+export function rightPadSequences(
+ sequences: number[][],
+ padToken?: number
+): { padded: number[][]; paddingMask: number[][] } {
+ const maxLen = sequences.reduce((m, s) => Math.max(m, s.length), 0);
+ const padded: number[][] = new Array(sequences.length);
+ const paddingMask: number[][] = new Array(sequences.length);
+
+ for (let i = 0; i < sequences.length; i++) {
+ const seq = sequences[i];
+ const realLen = seq.length;
+ const rowMask: number[] = new Array(maxLen).fill(0);
+ for (let j = 0; j < realLen; j++) rowMask[j] = 1;
+
+ if (realLen === maxLen) {
+ padded[i] = [...seq];
+ paddingMask[i] = rowMask;
+ continue;
+ }
+
+ let padVal: number;
+ if (padToken !== undefined) {
+ padVal = padToken;
+ } else if (realLen > 0) {
+ padVal = seq[realLen - 1];
+ } else {
+ padVal = 0; // degenerate case
+ }
+
+ const numPad = maxLen - realLen;
+ padded[i] =
+ realLen > 0 ? [...seq, ...new Array(numPad).fill(padVal)] : new Array(maxLen).fill(padVal);
+ paddingMask[i] = rowMask;
+ }
+
+ return { padded, paddingMask };
+}
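+
+// Worked example of the padding rule above (no padToken supplied, so the
+// last real token of each short sequence is repeated):
+//   rightPadSequences([[1, 2, 3], [4]])
+//   -> padded:      [[1, 2, 3], [4, 4, 4]]
+//      paddingMask: [[1, 1, 1], [1, 0, 0]]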
diff --git a/examples/finetuning/src/lib/train/model/cache.ts b/examples/finetuning/src/lib/train/model/cache.ts
new file mode 100644
index 00000000..7dd33524
--- /dev/null
+++ b/examples/finetuning/src/lib/train/model/cache.ts
@@ -0,0 +1,20 @@
+import type { Tensor } from '@piston-ml/piston-web';
+
+export type SelfAttentionCache = {
+ k: Tensor;
+ v: Tensor;
+ length: number;
+};
+
+export type DecoderLayerCache = {
+ self?: SelfAttentionCache;
+ cross?: SelfAttentionCache;
+};
+
+export type DecoderKVCache = {
+ layers: DecoderLayerCache[];
+};
+
+export function createEmptyDecoderKVCache(nLayers: number): DecoderKVCache {
+ return { layers: Array.from({ length: nLayers }, () => ({})) };
+}
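+
+// Hedged usage sketch: one empty cache slot per decoder layer; entries are
+// filled in by the attention modules during forward passes.
+//   const cache = createEmptyDecoderKVCache(2); // cache.layers -> [{}, {}]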
diff --git a/examples/finetuning/src/lib/train/model/config.ts b/examples/finetuning/src/lib/train/model/config.ts
new file mode 100644
index 00000000..1fd79161
--- /dev/null
+++ b/examples/finetuning/src/lib/train/model/config.ts
@@ -0,0 +1,49 @@
+import type { GPT2ModelType } from '$lib/workspace/config';
+
+import { GPT2_BLOCK_SIZE, GPT2_VOCAB_SIZE } from './gpt';
+
+export interface GPT2Config {
+ modelType: GPT2ModelType;
+ nLayer: number;
+ nHead: number;
+ nEmbd: number;
+ vocabSize: number;
+ blockSize: number;
+ embdPdrop: number;
+ residPdrop: number;
+ attnPdrop: number;
+}
+
+function getModelParametersFromType(modelType: GPT2ModelType): {
+ nLayer: number;
+ nHead: number;
+ nEmbd: number;
+} {
+ switch (modelType) {
+ case 'distilgpt2':
+ return { nLayer: 6, nHead: 12, nEmbd: 768 };
+ case 'gpt2':
+ return { nLayer: 12, nHead: 12, nEmbd: 768 };
+ case 'gpt2-medium':
+ return { nLayer: 24, nHead: 16, nEmbd: 1024 };
+ case 'gpt2-large':
+ return { nLayer: 36, nHead: 20, nEmbd: 1280 };
+ case 'gpt2-xl':
+ return { nLayer: 48, nHead: 25, nEmbd: 1600 };
+ }
+}
+
+export function buildGPT2Config(modelType: GPT2ModelType): GPT2Config {
+ const { nLayer, nHead, nEmbd } = getModelParametersFromType(modelType);
+ return {
+ vocabSize: GPT2_VOCAB_SIZE,
+ blockSize: GPT2_BLOCK_SIZE,
+ modelType,
+ nLayer,
+ nHead,
+ nEmbd,
+ embdPdrop: 0.1,
+ residPdrop: 0.1,
+ attnPdrop: 0.1
+ };
+}
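+
+// Example of the mapping above:
+//   buildGPT2Config('gpt2') -> { modelType: 'gpt2', nLayer: 12, nHead: 12,
+//     nEmbd: 768, vocabSize: GPT2_VOCAB_SIZE, blockSize: GPT2_BLOCK_SIZE,
+//     embdPdrop: 0.1, residPdrop: 0.1, attnPdrop: 0.1 }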
diff --git a/examples/finetuning/src/lib/train/model/gpt.ts b/examples/finetuning/src/lib/train/model/gpt.ts
new file mode 100644
index 00000000..971f5188
--- /dev/null
+++ b/examples/finetuning/src/lib/train/model/gpt.ts
@@ -0,0 +1,392 @@
+/**
+ * @fileoverview Implementation of a decoder-only (GPT-2-style)
+ * transformer model
+ */
+
+import type { Config } from '$lib/workspace/config';
+
+import { cat, CrossEntropyLoss, nn, Parameter, type Tensor } from '@piston-ml/piston-web';
+
+import type { DecoderLayerCache, SelfAttentionCache } from './cache';
+import type { DecoderKVCache } from './cache';
+
+import { createEmptyDecoderKVCache } from './cache';
+import { type GPT2Config } from './config';
+import { createCausalMask, createPositionIds, maskedFill } from './utils';
+
+export const GPT2_VOCAB_SIZE = 50257;
+export const GPT2_BLOCK_SIZE = 1024;
+
+export class CausalSelfAttention extends nn.Module {
+ private readonly nHeads: number;
+ private readonly nKvHeads: number;
+ private readonly embeddingSize: number;
+ private readonly headDim: number;
+ private readonly config: GPT2Config;
+
+ private readonly c_attn: nn.Linear;
+ private readonly c_proj: nn.Linear;
+
+ private attn_dropout?: nn.Dropout;
+ private resid_dropout?: nn.Dropout;
+
+ constructor(config: GPT2Config) {
+ super();
+
+ this.nHeads = config.nHead;
+ this.nKvHeads = config.nHead;
+ this.embeddingSize = config.nEmbd;
+ this.headDim = config.nEmbd / this.nHeads;
+ this.config = config;
+
+ const kvDim = this.headDim * this.nKvHeads;
+ const qkvOutDim = this.embeddingSize + 2 * kvDim;
+
+ this.c_attn = new nn.Linear(this.embeddingSize, qkvOutDim);
+ this.c_proj = new nn.Linear(this.embeddingSize, this.embeddingSize);
+
+ this.attn_dropout = new nn.Dropout(config.attnPdrop);
+ this.resid_dropout = new nn.Dropout(config.residPdrop);
+ }
+
+ private projectQkv(input: Tensor): [Tensor, Tensor, Tensor] {
+ const [B, T, _] = input.size();
+ const kvDim = this.headDim * this.nKvHeads;
+ const keyPos = this.embeddingSize;
+ const valuePos = this.embeddingSize + kvDim;
+ const qkv = this.c_attn.forward(input);
+ let q = qkv.slice([
+ [0, B],
+ [0, T],
+ [0, this.embeddingSize]
+ ]);
+ let k = qkv.slice([
+ [0, B],
+ [0, T],
+ [keyPos, keyPos + kvDim]
+ ]);
+ let v = qkv.slice([
+ [0, B],
+ [0, T],
+ [valuePos, valuePos + kvDim]
+ ]);
+ const qShape = [B, T, this.nHeads, this.headDim];
+ const kvShape = [B, T, this.nKvHeads, this.headDim];
+ q = q.view(qShape)?.transpose(1, 2);
+ k = k.view(kvShape)?.transpose(1, 2);
+ v = v.view(kvShape)?.transpose(1, 2);
+ [k, v] = this.applyGroupedQueryBroadcast(k, v, B, T);
+ return [q, k, v];
+ }
+
+ private runAttention(
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ options: {
+ attentionMask?: Tensor | null;
+ cache?: SelfAttentionCache | null;
+ ropeOffsets?: { qOffset: number; kOffset: number } | null;
+ }
+ ): [Tensor, SelfAttentionCache | null] {
+ const B = q.size(0);
+ const T_q = q.size(2);
+
+ // Concatenate with cache if present
+ const kCat = options.cache ? cat([options.cache.k, k], { dim: 2 }) : k;
+ const vCat = options.cache ? cat([options.cache.v, v], { dim: 2 }) : v;
+ const T_kv = kCat.size(2);
+
+ // Compute attention scores
+ let att = q.matmul(kCat, { transRhs: true }).div(Math.sqrt(this.headDim));
+
+ att = this.applyCausalMask(att, T_q, T_kv);
+
+ if (options.attentionMask) {
+ att = this.applyAttentionMask(att, options.attentionMask);
+ }
+
+ // Apply softmax over the key dimension
+ att = att.softmax(3);
+ att = this.attn_dropout ? this.attn_dropout.forward(att) : att;
+
+ let y = att.matmul(vCat).transpose(1, 2).view([B, T_q, this.embeddingSize]);
+ // Apply output projection
+ y = this.c_proj.forward(y);
+
+ if (this.resid_dropout) {
+ y = this.resid_dropout.forward(y);
+ }
+
+ const newCache: SelfAttentionCache | null = {
+ k: kCat,
+ v: vCat,
+ length: T_kv
+ };
+ return [y, newCache];
+ }
+
+ private applyAttentionMask(att: Tensor, attentionMask: Tensor) {
+ // Convert attention mask to appropriate shape [B, 1, 1, T_kv] and broadcast
+ const mask = attentionMask.unsqueeze(1).unsqueeze(2).broadcastTo(att.size());
+ return maskedFill(att, mask, -1e9);
+ }
+
+  private applyCausalMask(att: Tensor, queryLen: number, keyLen: number) {
+    // Apply causal mask only when there is more than one query position;
+    // a single query may already attend to every cached key
+    if (queryLen <= 1) {
+      return att;
+    }
+    const mask = createCausalMask(queryLen, keyLen).broadcastTo(att.size());
+    return maskedFill(att, mask, -1e9);
+  }
+
+ /**
+ * Apply grouped-query attention broadcasting to key and value tensors
+ * @param k Key tensor [B, nKvHeads, seqLen, headDim]
+ * @param v Value tensor [B, nKvHeads, seqLen, headDim]
+ * @param B Batch size
+ * @param seqLen Sequence length
+ * @returns Broadcasted [k, v] tensors [B, nHeads, seqLen, headDim]
+ */
+ private applyGroupedQueryBroadcast(
+ k: Tensor,
+ v: Tensor,
+ B: number,
+ seqLen: number
+ ): [Tensor, Tensor] {
+ if (this.nHeads !== this.nKvHeads) {
+ const repeatFactor = this.nHeads / this.nKvHeads;
+ const th = seqLen * this.headDim;
+
+ k = k
+ .view([B, this.nKvHeads, th])
+ .unsqueeze(2)
+ .broadcastTo([B, this.nKvHeads, repeatFactor, th])
+ .view([B, this.nHeads, seqLen, this.headDim]);
+ v = v
+ .view([B, this.nKvHeads, th])
+ .unsqueeze(2)
+ .broadcastTo([B, this.nKvHeads, repeatFactor, th])
+ .view([B, this.nHeads, seqLen, this.headDim]);
+ }
+ return [k, v];
+ }
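+
+  // Illustrative shape walk-through (hypothetical head counts, not taken from any
+  // config here): with nHeads = 8 and nKvHeads = 2, repeatFactor = 4, so a key
+  // tensor [B, 2, T, headDim] is flattened to [B, 2, T*headDim], expanded to
+  // [B, 2, 4, T*headDim], and reshaped to [B, 8, T, headDim].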
+
+ forward(
+ input: Tensor,
+ options: { attentionMask?: Tensor | null; cache?: SelfAttentionCache | null } = {}
+ ): { output: Tensor; pastKeyValues?: SelfAttentionCache } {
+ const qkv = this.projectQkv(input);
+ const [q, k, v] = qkv;
+ const pastLen = options.cache?.length ?? 0;
+ const [y, newCache] = this.runAttention(q, k, v, {
+ attentionMask: options.attentionMask ?? null,
+ cache: options.cache ?? null,
+ ropeOffsets: { qOffset: pastLen, kOffset: pastLen }
+ });
+ return { output: y, pastKeyValues: newCache ?? undefined };
+ }
+}
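+
+// A minimal usage sketch (illustrative only; `config`, `x`, and `xNext` are assumed
+// to exist with the shapes noted in the comments):
+//
+//   const attn = new CausalSelfAttention(config);
+//   // Full-sequence pass: x is [B, T, nEmbd]
+//   const { output, pastKeyValues } = attn.forward(x);
+//   // Incremental decoding: feed one new token and reuse the returned cache
+//   const next = attn.forward(xNext, { cache: pastKeyValues }); // xNext: [B, 1, nEmbd]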
+
+type GPTDict = {
+ drop: nn.Dropout;
+ wte: nn.Embedding;
+ wpe: nn.Embedding;
+ h: nn.ModuleList;
+ ln_f: nn.Module<[Tensor], Tensor>;
+};
+
+export class GPT extends nn.Module {
+ public config: GPT2Config;
+  public transformer: nn.ModuleDict<GPTDict>;
+ readonly lm_head: nn.Linear;
+ private readonly criterion: CrossEntropyLoss;
+
+ constructor(modelConfig: GPT2Config, config: Config) {
+ super();
+
+ this.config = modelConfig;
+ const transformerDict: GPTDict = {
+ drop: new nn.Dropout(this.config.embdPdrop),
+ wte: new nn.Embedding(this.config.vocabSize, this.config.nEmbd),
+ wpe: new nn.Embedding(this.config.blockSize, this.config.nEmbd),
+ h: new nn.ModuleList(
+ Array.from({ length: this.config.nLayer }).map(() => new Block(this.config))
+ ),
+ ln_f: new nn.LayerNorm(this.config.nEmbd)
+ };
+
+ this.transformer = new nn.ModuleDict(transformerDict);
+
+    // Output projection with its weight tied to the token embeddings
+ this.lm_head = new nn.Linear(this.config.nEmbd, this.config.vocabSize, false);
+ this.lm_head.weight = new Parameter(this.transformer.dict.wte.weight);
+
+ this.criterion = new CrossEntropyLoss({
+ labelSmoothing: config.training.labelSmoothing.present
+ ? config.training.labelSmoothing.value
+ : 0.0,
+ ignoreIndex: -100
+ });
+ }
+
+  /**
+   * @param input - Input tensor of token IDs [batch_size, seq_len]
+   * @param options - Optional targets [batch_size, seq_len] for loss computation and a
+   *   KV cache for incremental decoding
+   * @returns [logits, loss]
+   */
+ forward(
+ input: Tensor,
+ options: { targets?: Tensor | null; kvCache?: DecoderKVCache | null } = {}
+ ): [Tensor, Tensor | null] {
+ const targets = options.targets ?? null;
+ const kvCache = options.kvCache ?? null;
+ const [batchSize, seqLen] = input.size();
+
+ if (!seqLen) {
+ throw new Error(
+ 'Input tensor has no sequence length (did you forget to pass input as batches?)'
+ );
+ }
+
+ // Get token embeddings
+ let wordEmbeddings = this.transformer.dict.wte.forward(input);
+
+ // Use cache length (if any) as position offset for absolute encodings during incremental decoding
+ const posOffset = kvCache?.layers?.[0]?.self?.length ?? 0;
+
+ // Add positional embeddings
+ const positions = createPositionIds(seqLen, batchSize, wordEmbeddings.device, posOffset);
+ const positionEmbeddingsOutput = this.transformer.dict.wpe.forward(positions);
+ wordEmbeddings = wordEmbeddings.add(positionEmbeddingsOutput);
+ // Apply embedding dropout if configured
+ wordEmbeddings = this.transformer.dict.drop.forward(wordEmbeddings);
+
+ // Pass through each transformer layer
+ let hiddenStates = wordEmbeddings;
+
+    const useCache = kvCache !== null;
+    const cacheObj = useCache ? kvCache! : createEmptyDecoderKVCache(this.config.nLayer);
+    for (let i = 0; i < this.config.nLayer; i++) {
+      const layerModule = this.transformer.dict.h[i] as Block;
+      const result = layerModule.forward(hiddenStates, { cache: cacheObj.layers[i] });
+      if (useCache) {
+        cacheObj.layers[i] = result.pastKeyValues!;
+      }
+      hiddenStates = result.output;
+    }
+
+ // Apply final layer normalization
+ if (this.transformer.dict.ln_f) {
+ hiddenStates = this.transformer.dict.ln_f.forward(hiddenStates);
+ }
+
+ // Project to vocabulary
+ const logits = this.lm_head.forward(hiddenStates);
+
+ const loss = targets
+ ? this.criterion.forward(logits.view([-1, logits.size(-1)]), targets.view(-1))
+ : null;
+
+ return [logits, loss];
+ }
+}
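+
+// A minimal usage sketch (illustrative only; `modelConfig`, `appConfig`, `inputs`,
+// and `targets` are assumed to exist):
+//
+//   const model = new GPT(modelConfig, appConfig);
+//   // Training: passing targets yields a scalar cross-entropy loss
+//   const [logits, loss] = model.forward(inputs, { targets });
+//   // Incremental decoding: pass a KV cache instead and read the last position's logits
+//   const cache = createEmptyDecoderKVCache(modelConfig.nLayer);
+//   const [nextLogits] = model.forward(inputs, { kvCache: cache });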
+
+export type DecoderLayerForwardOptions = {
+ encoderHiddenStates?: Tensor | null;
+ srcPaddingMask?: Tensor | null;
+ tgtPaddingMask?: Tensor | null;
+ cache?: DecoderLayerCache | null;
+};
+
+export type DecoderLayerForwardResult = {
+ output: Tensor;
+ pastKeyValues?: DecoderLayerCache;
+};
+
+export class Block extends nn.Module {
+ private readonly lnSelfAttn!: nn.Module<[Tensor], Tensor>;
+ private readonly lnMlp: nn.Module<[Tensor], Tensor>;
+ private readonly selfAttn: CausalSelfAttention;
+ private readonly mlp: MLP;
+ private readonly dropout?: nn.Dropout;
+
+ constructor(config: GPT2Config) {
+ super();
+
+ this.lnSelfAttn = new nn.LayerNorm(config.nEmbd);
+ this.selfAttn = new CausalSelfAttention(config);
+
+ this.lnMlp = new nn.LayerNorm(config.nEmbd);
+ this.mlp = new MLP(config.nEmbd);
+
+ if (config.residPdrop > 0) {
+ this.dropout = new nn.Dropout(config.residPdrop);
+ }
+ }
+
+ forward(input: Tensor, options: DecoderLayerForwardOptions = {}): DecoderLayerForwardResult {
+ const tgtPaddingMask = options.tgtPaddingMask ?? null;
+ const cache = options.cache ?? null;
+ let x = input;
+ let selfCache = cache?.self ?? null;
+
+ const residual = input;
+ x = this.lnSelfAttn.forward(input);
+ const selfResult = this.selfAttn.forward(x, {
+ attentionMask: tgtPaddingMask ?? null,
+ cache: selfCache ?? null
+ });
+ selfCache = selfResult.pastKeyValues ?? null;
+ x = residual.add(selfResult.output);
+
+ const residual3 = x;
+ x = this.lnMlp.forward(x);
+ x = this.mlp.forward(x);
+ if (this.dropout) {
+ x = this.dropout.forward(x);
+ }
+ x = residual3.add(x);
+
+ const result: DecoderLayerForwardResult = { output: x };
+ if (cache) {
+ result.pastKeyValues = { self: selfCache ?? undefined, cross: cache.cross ?? undefined };
+ }
+ return result;
+ }
+}
+
+export class MLP extends nn.Module {
+ private readonly upProj: nn.Linear;
+ private readonly downProj: nn.Linear;
+ private readonly activation: (x: Tensor) => Tensor;
+
+ /**
+ * @param embeddingSize - Embedding size
+ */
+ constructor(embeddingSize: number) {
+ super();
+
+ const intermediateSize = 4 * embeddingSize;
+ this.upProj = new nn.Linear(embeddingSize, intermediateSize);
+ this.downProj = new nn.Linear(intermediateSize, embeddingSize);
+
+ this.activation = (x: Tensor): Tensor => x.gelu();
+ }
+
+ /**
+ * Forward pass through the MLP
+ * @param input - Input tensor
+ * @returns Output tensor
+ */
+ forward(input: Tensor): Tensor {
+ let h = this.upProj.forward(input);
+ h = this.activation(h);
+ return this.downProj.forward(h);
+ }
+}
diff --git a/examples/finetuning/src/lib/train/model/utils.ts b/examples/finetuning/src/lib/train/model/utils.ts
new file mode 100644
index 00000000..2179a192
--- /dev/null
+++ b/examples/finetuning/src/lib/train/model/utils.ts
@@ -0,0 +1,52 @@
+import { arange, Device, gpu, int32, type Tensor } from '@piston-ml/piston-web';
+
+/**
+ * Create a causal (lower triangular) mask.
+ * @param queryLen - Length of the current query.
+ * @param keyLen - Length of the key (which may include cached tokens).
+ * @returns Boolean mask tensor of shape [queryLen, keyLen], true where attention is allowed.
+ */
+export function createCausalMask(queryLen: number, keyLen: number): Tensor {
+ // General causal mask supporting past KV cache where keyLen may exceed queryLen.
+ // We want to mask future positions: for each query i, keys j > pastLen + i are masked.
+ // pastLen is inferred as keyLen - queryLen when using KV cache (else 0).
+ const pastLen = Math.max(0, keyLen - queryLen);
+ const i = arange({ end: queryLen, device: gpu, dtype: int32 })
+ .unsqueeze(1)
+ .broadcastTo([queryLen, keyLen]);
+ const j = arange({ end: keyLen, device: gpu, dtype: int32 })
+ .unsqueeze(0)
+ .broadcastTo([queryLen, keyLen]);
+ // Mask is true where positions are allowed: j <= pastLen + i
+ return j.le(i.add(pastLen));
+}
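+
+// Worked example (hypothetical sizes): with queryLen = 2 and keyLen = 4, pastLen = 2
+// (two cached tokens), and the allowed (true) positions are:
+//   j:       0     1     2     3
+//   i = 0  true  true  true  false
+//   i = 1  true  true  true  true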
+
+/**
+ * Create position IDs tensor [offset, offset+1, ..., offset+seqLen-1] and broadcast to batch size
+ * @param seqLen - Sequence length
+ * @param batchSize - Batch size
+ * @param device - Device to place tensor on
+ * @param offset - Starting position (e.g. the cached sequence length during incremental decoding)
+ * @returns Position IDs tensor of shape [batchSize, seqLen]
+ */
+export function createPositionIds(
+ seqLen: number,
+ batchSize: number,
+ device: Device,
+ offset: number = 0
+): Tensor {
+ // Create position IDs tensor [offset, offset+1, ..., offset+seqLen-1] and broadcast to batch
+ const positionIds = arange({ end: seqLen, device, dtype: int32 }).add(offset).cast(int32);
+ // Reshape to [1, seqLen] and broadcast to [batchSize, seqLen]
+ return positionIds.unsqueeze(0).broadcastTo([batchSize, seqLen]);
+}
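+
+// Example (illustrative): createPositionIds(3, 2, gpu, 5) yields
+//   [[5, 6, 7],
+//    [5, 6, 7]]
+// where an offset of 5 would correspond to five tokens already in the KV cache.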
+
+/**
+ * Apply mask to attention scores
+ * @param onTrue - Attention scores, kept where the mask is true
+ * @param mask - Boolean mask tensor
+ * @param onFalseValue - Value to fill positions where the mask is false
+ * @returns Masked scores
+ */
+export function maskedFill(onTrue: Tensor, mask: Tensor, onFalseValue: number): Tensor {
+ return onTrue.where(mask, onFalseValue);
+}
diff --git a/examples/finetuning/src/lib/train/moduleWorker.ts b/examples/finetuning/src/lib/train/moduleWorker.ts
new file mode 100644
index 00000000..3d35a2ad
--- /dev/null
+++ b/examples/finetuning/src/lib/train/moduleWorker.ts
@@ -0,0 +1,295 @@
+import type { Config } from '$lib/workspace/config';
+import type { Tensor } from '@piston-ml/piston-web';
+
+import * as piston from '@piston-ml/piston-web';
+
+import type { WorkerCommand, WorkerEvent } from './protocol';
+
+import { TrainingSession } from './session';
+import { type CheckpointExtra, splitLoadedState } from './utils/checkpoint';
+import { inspectModel } from './utils/model';
+
+let session: TrainingSession | undefined;
+
+// Console Interception
+const originalConsole = {
+ log: console.log.bind(console),
+ error: console.error.bind(console),
+ warn: console.warn.bind(console),
+ info: console.info.bind(console),
+ debug: console.debug.bind(console)
+};
+
+function formatArgs(args: unknown[]) {
+ return args
+ .map((arg) => {
+ if (typeof arg === 'object' && arg !== null) {
+ try {
+ return JSON.stringify(arg);
+ } catch (_: unknown) {
+ return '[Unserializable Object]';
+ }
+ }
+ return String(arg);
+ })
+ .join(' ');
+}
+
+interface LogInfo {
+ level: string;
+ message: string;
+ source: string;
+ lineno?: number;
+ colno?: number;
+}
+
+function sendLog(level: string, message: string, source: string = '[Worker]') {
+ // Check if this is a WASM log and parse it
+ const wasmLogRegex = /^\[WASM ([^:]+):(\d+)(?::(\d+))?\] (.*)$/;
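+  // e.g. "[WASM src/lib.rs:42:7] allocating buffer" (hypothetical) parses to
+  // source "[WASM] src/lib.rs", lineno 42, colno 7, message "allocating buffer"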
+ const match = message.match(wasmLogRegex);
+
+ if (level === 'error' && message.startsWith('panicked at')) {
+ const lines = message.split('\n');
+    if (lines[1]?.startsWith('VRAM limit exceeded')) {
+ self.postMessage({ type: 'log', level: 'error', message: lines[1] });
+ self.postMessage({
+ type: 'error',
+ runId: session?.runId,
+ message: 'VRAM limit exceeded',
+ name: 'VRAMLimitExceededError'
+ });
+ return;
+ } else {
+ self.postMessage({ type: 'error', runId: session?.runId, message });
+ }
+ }
+
+ if (match) {
+ // Handle WASM logs with parsed source info
+ const [, filepath, lineno, colno, actualMessage] = match;
+ const logInfo: LogInfo = {
+ level,
+ message: actualMessage,
+ source: `[WASM] ${filepath}`,
+ lineno: parseInt(lineno, 10),
+ ...(colno && { colno: parseInt(colno, 10) })
+ };
+ self.postMessage({ type: 'log', ...logInfo });
+ } else {
+ // Handle regular logs
+ const logInfo: LogInfo = {
+ level,
+ message,
+ source
+ };
+ self.postMessage({ type: 'log', ...logInfo });
+ }
+}
+
+// Wrap console methods so logs from Piston and training code are forwarded to the main thread
+Object.keys(originalConsole).forEach((level) => {
+  (console as unknown as Record<string, (...args: unknown[]) => void>)[level] = (
+    ...args: unknown[]
+  ) => {
+ const message = formatArgs(args);
+
+ // Use sendLog which will handle WASM log parsing internally
+ sendLog(level, message, currentExecutionSource);
+
+ // Also call original console for debugging
+ originalConsole[level as keyof typeof originalConsole](...args);
+ };
+});
+
+// Global error handler - catches unhandled errors
+self.addEventListener('error', (event) => {
+ const errorMessage = `Uncaught Error: ${event.message} at ${event.filename}:${event.lineno}:${event.colno}`;
+ sendLog('error', errorMessage);
+ if (event.error?.stack) {
+ sendLog('error', `${event.error.stack}`);
+ }
+});
+
+// Unhandled promise rejection handler - catches unhandled promise rejections
+self.addEventListener('unhandledrejection', (event) => {
+ const errorMessage = `Unhandled Promise Rejection: ${event.reason}`;
+ sendLog('error', errorMessage);
+ if (event.reason?.stack) {
+ sendLog('error', `${event.reason.stack}`);
+ }
+ // Prevent the default browser behavior (logging to console)
+ event.preventDefault();
+});
+
+// Intercept and override the default error reporting
+const originalOnError = self.onerror;
+self.onerror = (message, source, lineno, colno, error) => {
+ const errorMessage = `Global Error: ${message} at ${source}:${lineno}:${colno}`;
+ sendLog('error', errorMessage);
+ if (error?.stack) {
+ sendLog('error', `${error.stack}`);
+ }
+ // Call original handler if it exists
+ if (originalOnError) {
+ return originalOnError(message, source, lineno, colno, error);
+ }
+ // Prevent default browser error handling
+ return true;
+};
+
+//
+// End Console Interception
+//
+
+// Track current execution context for logging (will be reassigned during execution)
+let currentExecutionSource = '[Worker]';
+
+function postEvent(e: WorkerEvent) {
+ self.postMessage(e);
+}
+
+function startTraining() {
+ if (!session) return;
+ const runId = session.runId;
+ session.start().catch((error: unknown) => {
+ console.error('Training error:', error);
+ self.postMessage({
+ type: 'error',
+ runId,
+ message: error instanceof Error ? error.message : String(error),
+ name: error instanceof Error ? error.name : undefined,
+ stack: error instanceof Error ? error.stack : undefined
+ });
+ });
+}
+
+// Message handler for worker
+self.addEventListener('message', async (event) => {
+ const raw = event.data as WorkerCommand | { type: string; data?: unknown };
+ const type: string = (raw as { type: string }).type;
+ const data: unknown = (raw as { data?: unknown }).data;
+
+ switch (type) {
+ case 'save': {
+ session?.save();
+ break;
+ }
+ case 'pause': {
+ session?.pause();
+ break;
+ }
+ case 'resume': {
+ session?.resume();
+ startTraining();
+ break;
+ }
+ case 'step': {
+ if (!session) break;
+ await session.pause();
+ await session.step({ manual: true });
+ break;
+ }
+ case 'start':
+ try {
+ const {
+ runId: runIdFromData,
+ config,
+ resumeFrom,
+ gpuPowerPreference
+ } = data as {
+ runId: string;
+ config: Config;
+ resumeFrom?: Uint8Array;
+ gpuPowerPreference?: 'high-performance' | 'low-power';
+ };
+ currentExecutionSource = `[Training:${runIdFromData}]`;
+
+ // Apply GPU power preference before any GPU initialization
+ if (gpuPowerPreference) {
+ await piston.applyGpuPowerPreference(gpuPowerPreference);
+ }
+
+ console.info(`Starting training for run ${runIdFromData}`);
+ session = new TrainingSession(runIdFromData, config, postEvent, resumeFrom);
+ startTraining();
+ } catch (error: unknown) {
+ console.error('Training error:', error);
+ self.postMessage({
+ type: 'error',
+ runId: (data as { runId?: string })?.runId,
+ message: error instanceof Error ? error.message : String(error),
+ name: error instanceof Error ? error.name : undefined,
+ stack: error instanceof Error ? error.stack : undefined
+ });
+ }
+ break;
+ case 'checkpoint.peekConfig': {
+ const { requestId, buffer } = data as {
+ requestId: string;
+ buffer: Uint8Array;
+ };
+ const loaded = piston.load(buffer, piston.gpu);
+ const split = splitLoadedState(
+        loaded as { state: Record<string, Tensor>; extra?: CheckpointExtra }
+ );
+ self.postMessage({ type: 'checkpoint.config', requestId, config: split.config });
+ break;
+ }
+ case 'inspectModel':
+ try {
+ const { config, requestId, gpuPowerPreference } = data as {
+ config: Config;
+ requestId: string;
+ gpuPowerPreference?: 'high-performance' | 'low-power';
+ };
+ currentExecutionSource = '[ModelInspection]';
+
+ // Apply GPU power preference before any GPU usage
+ if (gpuPowerPreference) {
+ await piston.applyGpuPowerPreference(gpuPowerPreference);
+ (globalThis as unknown as { piston: typeof piston }).piston = piston;
+ }
+
+ console.debug('Inspecting model...');
+ const result = inspectModel(config);
+
+ self.postMessage({
+ type: 'modelInspection',
+ requestId,
+ parameterCount: result.parameterCount,
+ hiddenSize: result.hiddenSize,
+ mlpIntermediateSize: result.mlpIntermediateSize,
+ modelIndex: result.modelIndex,
+ vocabSize: result.vocabSize,
+ blockSize: result.blockSize
+ });
+ } catch (error: unknown) {
+ console.error('Model inspection error:', error);
+ self.postMessage({
+ type: 'modelInspectionError',
+ requestId: (data as { requestId?: string })?.requestId ?? '',
+ message: error instanceof Error ? error.message : String(error)
+ });
+ }
+ break;
+
+ default:
+ console.warn(`Unknown message type: ${type}`);
+ break;
+ }
+});
+
+// Initialize Piston, then signal that the worker is ready
+piston
+ .init()
+ .then(() => {
+ console.info('Piston initialized');
+ self.postMessage({ type: 'ready' });
+ })
+ .catch((error: unknown) => {
+ console.error('Error initializing Piston:', error);
+ self.postMessage({
+ type: 'error',
+ message: error instanceof Error ? error.message : String(error)
+ });
+ });
diff --git a/examples/finetuning/src/lib/train/protocol.ts b/examples/finetuning/src/lib/train/protocol.ts
new file mode 100644
index 00000000..e2a9c225
--- /dev/null
+++ b/examples/finetuning/src/lib/train/protocol.ts
@@ -0,0 +1,81 @@
+import type { Config } from '$lib/workspace/config';
+import type { StepData } from '$lib/workspace/runs.svelte';
+import type { IndexState } from '@piston-ml/piston-web';
+
+type WithRunId = { runId: string };
+
+export type WorkerCommand =
+ | {
+ type: 'start';
+      data: {
+        runId: string;
+        config: Config;
+        resumeFrom?: Uint8Array;
+        gpuPowerPreference?: 'high-performance' | 'low-power';
+      };
+ }
+ | {
+ type: 'checkpoint.peekConfig';
+ data: { requestId: string; buffer: Uint8Array };
+ }
+ | { type: 'pause' }
+ | { type: 'resume' }
+ | { type: 'step' }
+ | { type: 'save' }
+ | { type: 'stop' }
+  | {
+      type: 'inspectModel';
+      data: {
+        requestId: string;
+        config: Config;
+        gpuPowerPreference?: 'high-performance' | 'low-power';
+      };
+    };
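+
+// A minimal sketch of how the main thread might drive this protocol (illustrative
+// only; `runId` and `config` are assumed to exist, and the worker URL is hypothetical):
+//
+//   const worker = new Worker(new URL('./moduleWorker.ts', import.meta.url), { type: 'module' });
+//   worker.postMessage({ type: 'start', data: { runId, config } } satisfies WorkerCommand);
+//   worker.addEventListener('message', (e) => {
+//     const event = e.data as WorkerEvent;
+//     if (event.type === 'metrics') {
+//       // plot event.data, keyed by metric name
+//     }
+//   });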
+
+type ReadyWorkerEvent = {
+ type: 'ready';
+};
+
+type LogWorkerEvent = {
+ type: 'log';
+ level: 'debug' | 'info' | 'warn' | 'error';
+ message: string;
+ source?: string;
+ lineno?: number;
+ colno?: number;
+};
+
+type ErrorWorkerEvent = {
+ type: 'error';
+ name?: string;
+ message: string;
+ stack?: string;
+};
+
+export type RunWorkerEventWithoutRunId =
+ | {
+ type: 'metrics';
+      data: { [metricName: string]: Omit<StepData, 'step'> };
+ metadata?: { step?: number };
+ }
+ | {
+ type: 'capture';
+ step: number;
+ boxes: unknown[];
+      statsById: Record<string, unknown>;
+ width: number;
+ height: number;
+ queries: unknown[];
+ }
+ | { type: 'checkpoint'; buffer: Uint8Array }
+ | { type: 'restart'; buffer: Uint8Array }
+ | { type: 'paused' }
+ | { type: 'resumed' }
+ | { type: 'complete' }
+ | LogWorkerEvent
+ | ErrorWorkerEvent;
+
+export type RunWorkerEvent = RunWorkerEventWithoutRunId & WithRunId;
+
+export type WorkerEvent =
+ | ReadyWorkerEvent
+ | LogWorkerEvent
+ | ErrorWorkerEvent
+ | RunWorkerEvent
+ | { type: 'checkpoint.config'; requestId: string; config: Config }
+  | {
+      type: 'modelInspection';
+      requestId: string;
+      parameterCount: number;
+      hiddenSize: number;
+      mlpIntermediateSize: number;
+      modelIndex: IndexState;
+      vocabSize: number;
+      blockSize: number;
+    }
+ | { type: 'modelInspectionError'; requestId: string; message: string };
diff --git a/examples/finetuning/src/lib/train/session.ts b/examples/finetuning/src/lib/train/session.ts
new file mode 100644
index 00000000..0f07dd26
--- /dev/null
+++ b/examples/finetuning/src/lib/train/session.ts
@@ -0,0 +1,675 @@
+import type { Config } from '$lib/workspace/config';
+import type { StepData } from '$lib/workspace/runs.svelte';
+
+import {
+ CosineAnnealingLR,
+ ExponentialLR,
+ LinearLR,
+ type LRScheduler,
+ SequentialLR,
+ StepLR,
+ type Tensor
+} from '@piston-ml/piston-web';
+import * as piston from '@piston-ml/piston-web';
+
+import type { NaturalLanguageDataset } from './data/natural';
+import type { BuiltData } from './data/pipeline';
+import type { RunWorkerEvent, RunWorkerEventWithoutRunId, WorkerEvent } from './protocol';
+import type { GeneratableModel, NaturalCollateFnType } from './types';
+
+import { buildDataset, tensorWrap } from './data';
+import { filterDatasetByHeldoutSamples } from './data/filter';
+import { buildDataPipeline } from './data/pipeline';
+import { GPT, GPT2_BLOCK_SIZE, GPT2_VOCAB_SIZE } from './model/gpt';
+import {
+ type AnySchedulerState,
+ buildCheckpoint,
+ type CheckpointDataState,
+ type CheckpointExtra,
+ splitLoadedState
+} from './utils/checkpoint';
+// import { initTransformerParameters } from './utils/init';
+import { calculateParameterSum, createCollateFn, createModel } from './utils/model';
+import { MarkStepModeIfEnabled, WeakModeIfEnabled } from './utils/modes';
+import { configureOptimizers } from './utils/optim';
+import {
+ buildValidationExamplesSubset,
+ buildValidationLog,
+ computeLikelihoodMetrics,
+ computeNaturalValidationMetrics,
+ type NaturalValidationExamples,
+ prepareNaturalValidationExamples,
+ type ValidationStep
+} from './validation';
+
+// @ts-expect-error polyfill
+Symbol.dispose ||= Symbol.for('Symbol.dispose');
+
+export class TrainingSession {
+ readonly runId: string;
+ private config: Config;
+ private readonly post: (e: RunWorkerEventWithoutRunId) => void;
+ private readonly resumeFrom?: Uint8Array;
+
+ private paused = false;
+ private resolvePause: (() => void) | null = null;
+
+ private model!: GeneratableModel;
+ private optimizer!: piston.Optimizer;
+ private scheduler: LRScheduler | undefined;
+ private trainDataset!: NaturalLanguageDataset;
+ private blockSize!: number;
+
+ private isSetup: boolean = false;
+
+ private startTimeMs: number | null = null;
+ private lastLogTime: number | null = null;
+ private lastLogStep: number | null = null;
+ private stepCount: number = 0;
+
+ private dataPipeline!: BuiltData;
+
+ private validationExamples: NaturalValidationExamples | null = null;
+ private validationCollateFn: NaturalCollateFnType | null = null;
+ private validationDataset: NaturalLanguageDataset | null = null;
+ // This is a little bit gross, but it's a straightforward way to make sure we have valid targets
+ // when we resume from a checkpoint. If I ever do another pass over this code, this will be first
+ // to go.
+ private includeTargetsOnNextValidation: boolean = false;
+
+ constructor(
+ runId: string,
+ config: Config,
+ post: (e: WorkerEvent) => void,
+ resumeFrom?: Uint8Array
+ ) {
+ this.runId = runId;
+ this.config = config;
+ this.post = (e: RunWorkerEventWithoutRunId) =>
+ // We only post the subset of events that have runId in the payload
+ (post as (e: RunWorkerEvent) => void)({ ...e, runId: this.runId });
+ this.resumeFrom = resumeFrom;
+ if (resumeFrom) {
+ this.includeTargetsOnNextValidation = true;
+ }
+ }
+
+ async pause() {
+ if (this.paused) return;
+ this.paused = true;
+    await new Promise<void>((resolve) => {
+ this.resolvePause = resolve;
+ });
+ this.resolvePause = null;
+ this.post({ type: 'paused' });
+ }
+
+ resume() {
+ this.paused = false;
+ this.post({ type: 'resumed' });
+ }
+
+  async save() {
+    try {
+      if (!this.paused) {
+        throw new Error('Saving during training is not supported');
+      }
+      if (!this.model) {
+        // Defer save until model is ready
+        this.post({ type: 'log', level: 'info', message: 'Save requested before model ready' });
+        return;
+      }
+      await piston.gpu.markStep();
+      const buffer = await this.saveLatestCheckpoint();
+      this.post({ type: 'checkpoint', buffer });
+    } catch (e) {
+      this.post({ type: 'error', message: String(e) });
+    }
+  }
+
+  async saveLatestCheckpoint(): Promise<Uint8Array> {
+ if (!this.model) throw new Error('No model available to save');
+ await piston.gpu.markStep();
+ // Derive dataset state if available
+ const dataState = {
+ blockSize: this.blockSize,
+ ...this.trainDataset.exportState()
+ };
+ const { tensors, extra } = buildCheckpoint(
+ this.model,
+ this.optimizer!,
+ this.stepCount,
+ this.config ? JSON.parse(JSON.stringify(this.config)) : null,
+ this.scheduler,
+ dataState,
+ this.startTimeMs ?? undefined
+ );
+ return piston.save(tensors, extra);
+ }
+
+ private logMetrics(
+    data: { [metricName: string]: Omit<StepData, 'step'> },
+ metadata?: { step?: number }
+ ) {
+ this.post({ type: 'metrics', data, metadata });
+ }
+
+ private async setup() {
+ if (this.isSetup) {
+ return;
+ }
+
+ // Log initial memory
+ const initialMemoryMB = Number(piston.gpu.usageBytes()) / (1024 * 1024);
+ console.debug(`Initial memory: ${initialMemoryMB} MB`);
+
+ // If resuming from a checkpoint, parse and use checkpoint config
+ let resumePayload: {
+      modelState: Record<string, Tensor>;
+      optimizerPacked?: { state: Record<string, Tensor>; paramGroups: piston.ParamGroupConfig[] };
+ schedulerState?: unknown;
+ numSteps: number;
+ config: Config;
+ dataState?: CheckpointDataState;
+ startTimeMs?: number;
+ } | null = null;
+
+ if (this.resumeFrom) {
+ const loaded = piston.load(this.resumeFrom, piston.gpu);
+ const split = splitLoadedState(
+        loaded as { state: Record<string, Tensor>; extra?: CheckpointExtra }
+ );
+ resumePayload = {
+ modelState: split.modelState,
+ optimizerPacked: split.optimizerState as unknown as {
+          state: Record<string, Tensor>;
+ paramGroups: piston.ParamGroupConfig[];
+ },
+ schedulerState: split.schedulerState,
+ numSteps: split.numSteps,
+ config: split.config,
+ dataState: split.dataState,
+ startTimeMs: split.startTimeMs
+ };
+ if (resumePayload.config) {
+ this.config = resumePayload.config as Config;
+ }
+ // If blockSize present in extras, prefer it
+ if (split.dataState && split.dataState.blockSize !== undefined) {
+ this.blockSize = split.dataState.blockSize;
+ }
+ }
+
+ if (this.config.training.vramLimitMb.present) {
+ piston.gpu.setVRAMLimit(BigInt(this.config.training.vramLimitMb.value * 1024 * 1024));
+ }
+
+ // Ensure shared-object allocation is enabled so buffer handles are stable across steps
+ piston.gpu.setSharedObjectAllocationEnabled(this.config.training.sharedObjectAllocation);
+ piston.gpu.setCachingEnabled(this.config.training.cachingEnabled);
+ piston.gpu.setInplaceSupport(this.config.training.inplaceSupport);
+
+ const trainDataset: NaturalLanguageDataset = buildDataset(this.config, 'train');
+ // Restore dataset state if present
+ if (resumePayload && resumePayload.dataState) {
+ const dsState = resumePayload.dataState;
+ await trainDataset.importState(dsState);
+ }
+ this.trainDataset = trainDataset;
+
+ const validationDisabled =
+ ('disableValidation' in this.trainDataset && this.trainDataset.disableValidation) || false;
+
+ this.validationExamples = null;
+ this.validationCollateFn = null;
+ this.validationDataset = null;
+
+ if (this.config.training.validation.present && !validationDisabled) {
+ this.validationDataset = buildDataset(this.config, 'val');
+ this.validationCollateFn = createCollateFn(tensorWrap);
+ this.validationExamples = await prepareNaturalValidationExamples(
+ this.config,
+ this.validationDataset!
+ );
+ // Filter training dataset against holdout examples without duplication
+ const validationSequences: number[][] = this.validationExamples.naturalSequences;
+ this.trainDataset = filterDatasetByHeldoutSamples(
+ this.trainDataset,
+ this.config.data.dataset,
+ validationSequences
+ );
+ console.debug(
+ `Prepared ${validationSequences.length} validation examples for batch generation`
+ );
+ }
+
+ if (validationDisabled) {
+ console.debug('Validation disabled by dataset; skipping validation and holdout filtering.');
+ }
+
+ const vocabSize = GPT2_VOCAB_SIZE;
+ const blockSize = GPT2_BLOCK_SIZE;
+ this.blockSize = blockSize;
+
+ console.debug(
+ `Created dataset ${this.trainDataset.name} with vocab size ${vocabSize} and block size ${blockSize}`
+ );
+
+ // Create model
+ this.model = createModel(this.config);
+
+ // If starting from scratch, initialize model parameters
+ if (!resumePayload) {
+ // initTransformerParameters(this.model, this.config);
+
+ // We need to flatten down initialization to the constant tensors they're on top of
+ await piston.gpu.markStep();
+
+ const parameterSum = new BigUint64Array(
+ new Float64Array([await (await calculateParameterSum(this.model).to('cpu')).item()]).buffer
+ );
+ console.debug(`Initialization parameter sum: ${parameterSum}`);
+ }
+
+ // Build and store the training data pipeline (iterator bound to current dataset/collate)
+ this.dataPipeline = await buildDataPipeline(this.config, this.trainDataset);
+
+ // If resuming, load model state BEFORE creating the optimizer so param identities match
+ let startStep = 0;
+ if (resumePayload) {
+ this.model.loadStateDict(resumePayload.modelState, { strict: false });
+ startStep = (resumePayload.numSteps ?? 0) + 1;
+ this.stepCount = startStep;
+ // If checkpoint carried a startTimeMs, use it for wall-clock continuity
+ if (typeof resumePayload.startTimeMs === 'number') {
+ this.startTimeMs = resumePayload.startTimeMs;
+ }
+ }
+
+ // Create optimizer based on model type, using the (possibly restored) model parameters
+ const optimizer = configureOptimizers(
+ this.model,
+ ['transformer.h'],
+ 'lm_head',
+ this.config.optimizer,
+ piston.gpu
+ );
+ this.optimizer = optimizer;
+
+ // If resuming, load optimizer state NOW that groups refer to current model parameters
+ if (resumePayload && resumePayload.optimizerPacked) {
+ optimizer.loadStateDict(resumePayload.optimizerPacked as piston.StateDict);
+ }
+
+ // Create learning rate scheduler if configured
+ if (this.config.optimizer.lrScheduler.present) {
+ const lrConfig = this.config.optimizer.lrScheduler;
+ switch (lrConfig.type) {
+ case 'step':
+ this.scheduler = new StepLR(
+ this.optimizer,
+ lrConfig.stepSchedule.stepSize,
+ lrConfig.stepSchedule.gamma
+ );
+ break;
+ case 'cosine':
+ this.scheduler = new CosineAnnealingLR(
+ this.optimizer,
+ lrConfig.cosineAnnealingSchedule.tMax,
+ lrConfig.cosineAnnealingSchedule.etaMin
+ );
+ break;
+ case 'exponential':
+ this.scheduler = new ExponentialLR(this.optimizer, lrConfig.exponentialSchedule.gamma);
+ break;
+ case 'linear':
+ this.scheduler = new LinearLR(
+ this.optimizer,
+ lrConfig.linearSchedule.startFactor,
+ lrConfig.linearSchedule.endFactor,
+ lrConfig.linearSchedule.totalIters
+ );
+ break;
+ default:
+ throw new Error(`Unknown scheduler type: ${lrConfig.type}`);
+ }
+
+ if (this.scheduler && this.config.optimizer.warmupSteps.present) {
+ const n = this.config.optimizer.warmupSteps.value;
+ if (n > 0) {
+ const warmup = new LinearLR(optimizer, 1e-8, 1.0, n);
+ this.scheduler = new SequentialLR(optimizer, [warmup, this.scheduler], [n]);
+ }
+ }
+ } else if (this.config.optimizer.warmupSteps.present) {
+ const n = this.config.optimizer.warmupSteps.value;
+ if (n > 0) {
+ this.scheduler = new LinearLR(optimizer, 1e-8, 1.0, n);
+ }
+ }
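+
+    // Example (illustrative): with a cosine schedule and warmupSteps = 100, the
+    // SequentialLR built above ramps the LR linearly for the first 100 steps and
+    // then hands off to CosineAnnealingLR at the milestone.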
+
+ // If resuming, load scheduler state after it is created
+ if (resumePayload && this.scheduler && resumePayload.schedulerState) {
+ this.scheduler.loadStateDict(resumePayload.schedulerState as AnySchedulerState);
+ }
+
+ this.model.train();
+
+ this.isSetup = true;
+ }
+
+ async step({ manual = false }: { manual?: boolean } = {}): Promise<
+    IteratorResult<undefined, 'completed' | 'restarted'>
+ > {
+ if (this.startTimeMs == null) {
+ this.startTimeMs = Date.now();
+ }
+    if (this.lastLogTime == null) {
+      this.lastLogTime = Date.now();
+    }
+    if (this.lastLogStep == null) {
+      this.lastLogStep = this.stepCount;
+    }
+ try {
+ const iterNext = await this.dataPipeline.train.iterator.next();
+ if (iterNext.done) {
+ return { done: true, value: 'completed' };
+ }
+ const batch = iterNext.value;
+ performance.mark('stepStart');
+ // Reset peak GPU memory tracking at the start of the step
+ piston.gpu.markUsageBytesStep();
+
+ let isLastStep = false;
+ if (
+ this.config.training.limitTraining.present &&
+ this.stepCount + 1 >= this.config.training.limitTraining.steps
+ ) {
+ console.log(
+ `Stopping training at step ${this.stepCount} because it reached the limit of ${this.config.training.limitTraining.steps} steps`
+ );
+ isLastStep = true;
+ }
+
+ const loggingStep =
+ manual || isLastStep || this.stepCount % this.config.training.logSteps === 0;
+
+ const weakModeUntilAfterBackward = new WeakModeIfEnabled(
+ this.config.training.useWeakTensorReferences,
+ {
+ label: 'train/forward_through_backward'
+ }
+ );
+
+ let loss: Tensor;
+ try {
+ // For GPT: batch contains [inputs, targets]
+ const { tensors } = batch;
+ const [inputs, gptTargets] = tensors;
+ const [, computedLoss] = (this.model as GPT).forward(await inputs.to('gpu'), {
+ targets: await gptTargets.to('gpu')
+ });
+
+ if (!computedLoss) {
+ throw new Error('No loss tensor returned from decoder-only model');
+ }
+
+ loss = computedLoss;
+
+ weakModeUntilAfterBackward.pin(loss);
+
+ loss.backward();
+ } finally {
+ weakModeUntilAfterBackward[Symbol.dispose]();
+ }
+
+ const weakModeForOptimizerStep = new WeakModeIfEnabled(
+ this.config.training.useWeakTensorReferences,
+ {
+ label: 'train/optimizer_step'
+ }
+ );
+
+ let gradNorm: Tensor | undefined;
+ try {
+ const weakMarkStepMode = new MarkStepModeIfEnabled(
+ this.config.training.useWeakTensorReferences
+ );
+ weakModeForOptimizerStep.pin(loss);
+
+ if (this.config.training.gradNorm.track) {
+ if (this.config.training.clipGradNorm.present) {
+ gradNorm = weakModeForOptimizerStep.pin(
+ piston.clipGradNorm_(this.model.parameters(), this.config.training.clipGradNorm.value)
+ );
+ } else if (loggingStep) {
+ // If we're not clipping gradients, we can just get the total gradient norm
+ gradNorm = weakModeForOptimizerStep.pin(
+ piston.getTotalGradNorm(this.model.parameters())
+ );
+ }
+ }
+
+ try {
+          await this.optimizer.step();
+          await piston.gpu.markStep();
+ } finally {
+ weakMarkStepMode[Symbol.dispose]();
+ }
+ } finally {
+ // TODO: decide if it's okay that we're disposing the mode twice here
+ weakModeForOptimizerStep[Symbol.dispose]();
+ }
+
+ const finalWeakModeForStep = new WeakModeIfEnabled(
+ this.config.training.useWeakTensorReferences,
+ {
+ label: 'train/final'
+ }
+ );
+
+ try {
+ // We've kept loss strong; we'll want to make sure we get rid of it
+ // Batch tensors are created outside of weak mode, so we manually mark them as weak
+ finalWeakModeForStep.markWeak([loss, gradNorm, batch.tensors]);
+
+ this.optimizer.zeroGrad(true);
+
+ // Step learning rate scheduler if present
+ if (this.scheduler) {
+ this.scheduler.step();
+ }
+
+ if (
+ this.config.training.validation.present &&
+ (this.stepCount % this.config.training.validation.valSteps === 0 || isLastStep) &&
+ this.validationExamples &&
+ this.validationDataset &&
+ this.validationCollateFn
+ ) {
+ try {
+ let valLoss = Number.NaN;
+ let perplexity = Number.NaN;
+            let validationLog: Record<string, Omit<StepData, 'step'>> = {};
+
+ if (this.validationExamples) {
+ if (this.config.training.validation.completions.present) {
+ let validationExamplesSubset: NaturalValidationExamples | null = null;
+ if (this.config.training.validation.completions.amount === 'subset') {
+ validationExamplesSubset = buildValidationExamplesSubset(
+ this.validationExamples,
+ this.config.training.validation.completions.subsetSize
+ );
+ } else {
+ validationExamplesSubset = this.validationExamples;
+ }
+ const validationStepData = await computeNaturalValidationMetrics(
+ this.model,
+ this.validationDataset,
+ validationExamplesSubset as NaturalValidationExamples,
+ this.config.training.validation
+ );
+ validationLog = buildValidationLog(validationStepData);
+ if (this.includeTargetsOnNextValidation) {
+ this.includeTargetsOnNextValidation = false;
+ }
+ }
+
+ const result = await computeLikelihoodMetrics(
+ this.model,
+ this.validationExamples!,
+ this.validationCollateFn!
+ );
+
+ valLoss = result.valLoss;
+ perplexity = result.perplexity;
+
+              const logData: Record<string, Omit<StepData, 'step'>> = {
+ ...validationLog,
+ 'validation/loss': valLoss,
+ 'validation/perplexity': perplexity
+ };
+ this.logMetrics(logData, { step: this.stepCount });
+ }
+ } catch (error) {
+ console.error('Error during batch validation:', error);
+ }
+ }
+
+ if (loggingStep) {
+ const currentTime = Date.now();
+ const totalElapsedSeconds = (currentTime - this.startTimeMs!) / 1000;
+
+ // Calculate delta time and steps since last log
+ const deltaTime = (currentTime - this.lastLogTime!) / 1000;
+ const deltaSteps = this.stepCount - this.lastLogStep!;
+
+ // Calculate steps per second and words per second based on delta
+ const stepsPerSecond = deltaSteps > 0 ? deltaSteps / deltaTime : 0;
+
+          // Calculate tokens per second from the sequence length of the first batch tensor
+          const [inputs] = batch.tensors;
+          const sequenceLength = inputs.shape[1]; // [batch_size, seq_len]
+
+ const tokensPerStep = this.config.training.batchSize * sequenceLength;
+ const tokensPerSecond = deltaSteps > 0 ? (deltaSteps * tokensPerStep) / deltaTime : 0;
+
+ const activeMap = piston.__pistonActiveTensors();
+ const activeTensors = Array.from(activeMap.values()).reduce((s, v) => s + v.length, 0);
+
+          const lossCpu = await loss.to('cpu');
+          const lossItem = await lossCpu.item();
+
+          if (lossItem == null) {
+            throw new Error('Loss item is null?');
+          }
+
+ const peakUsageMb = Number(piston.gpu.peakUsageBytes()) / (1024 * 1024);
+
+          const logData: Record<string, number> = {
+ 'train/loss': lossItem,
+ 'allocation/active_tensor_count': activeTensors,
+ 'allocation/gpu_memory_mb': peakUsageMb,
+ 'speed/steps_per_second': stepsPerSecond,
+ 'speed/step': this.stepCount,
+ 'speed/tokens_per_second': tokensPerSecond,
+ 'speed/wall_clock_seconds': totalElapsedSeconds
+ };
+
+ if (gradNorm) {
+ const gradNormCpu = await gradNorm.to('cpu');
+ const gradNormItem = await gradNormCpu.item();
+ if (this.config.training.gradNorm.errorIfNonfinite && !isFinite(gradNormItem)) {
+ throw new Error(`Gradient norm was nonfinite, so it cannot be clipped.`);
+ }
+ logData['train/grad_norm'] = gradNormItem;
+ }
+
+          // Log the current learning rate if present (a value of 0 is still logged)
+          const currentLr = this.optimizer.paramGroups[0].lr;
+          if (currentLr != null) {
+ logData['optimizer/learning_rate'] = currentLr;
+ }
+
+ this.logMetrics(logData, { step: this.stepCount });
+
+ // Update last log time and step
+ this.lastLogTime = currentTime;
+ this.lastLogStep = this.stepCount;
+ }
+
+ // Trigger periodic checkpoint save (non-restart) if configured
+ if (this.config.training.checkpointEverySteps.present) {
+ const checkpointEvery = this.config.training.checkpointEverySteps.value;
+ if (checkpointEvery > 0 && (this.stepCount + 1) % checkpointEvery === 0) {
+ try {
+ const bytes = await this.saveLatestCheckpoint();
+ this.post({ type: 'checkpoint', buffer: bytes });
+ } catch (e) {
+ // Non-fatal; continue training
+ this.post({ type: 'log', level: 'warn', message: String(e) });
+ }
+ }
+ }
+
+ // Trigger periodic restart if configured
+ const restartEvery = this.config.training.restartEverySteps ?? 0;
+ const willRestart = restartEvery > 0 && (this.stepCount + 1) % restartEvery === 0;
+ if (willRestart) {
+ console.debug(`Routine restart at step ${this.stepCount}`);
+ await piston.gpu.markStep();
+ const bytes = await this.saveLatestCheckpoint();
+ this.post({ type: 'restart', buffer: bytes });
+ return { done: true, value: 'restarted' };
+ }
+
+ if (isLastStep) {
+ return { done: true, value: 'completed' };
+ }
+ } finally {
+ finalWeakModeForStep[Symbol.dispose]();
+ }
+
+ this.stepCount++;
+
+ performance.mark('stepEnd');
+ } catch (error) {
+ console.error(`Error during training: ${error}`);
+ throw error;
+ }
+ return { done: false, value: undefined };
+ }
+
+ async start(): Promise {
+ await this.setup();
+ while (true) {
+ if (this.paused) {
+ if (this.resolvePause) {
+ this.resolvePause();
+ }
+ return;
+ }
+ const { done, value } = await this.step();
+ if (done) {
+ if (value === 'completed') {
+ this.post({ type: 'complete' });
+ break;
+ }
+ if (value === 'restarted') {
+ return;
+ }
+ }
+ }
+ }
+}
diff --git a/examples/finetuning/src/lib/train/tokenizer.ts b/examples/finetuning/src/lib/train/tokenizer.ts
new file mode 100644
index 00000000..44c8abb7
--- /dev/null
+++ b/examples/finetuning/src/lib/train/tokenizer.ts
@@ -0,0 +1,2457 @@
+/**
+ * @fileoverview Simplified Tokenizer implementation adapted from huggingface/transformers.js
+ */
+
+import { PUBLIC_DATA_URL } from '$env/static/public';
+import { Template } from '@huggingface/jinja';
+import { int32, Tensor, tensor } from '@piston-ml/piston-web';
+
+/* eslint-disable @typescript-eslint/no-unsafe-declaration-merging */
+abstract class Callable<Args extends unknown[], Return> {
+ /**
+ * Creates a new instance of the Callable class.
+ */
+ constructor() {
+ /**
+ * Creates a closure that delegates to a private method 'call' with the given arguments.
+ * @param args Zero or more arguments to pass to the 'call' method.
+ * @returns The result of calling the 'call' method.
+ */
+ const closure = ((...args: Args) => {
+ return (closure as unknown as { call: (...args: Args) => Return }).call(...args);
+ }) as unknown as (...args: Args) => Return;
+ return Object.setPrototypeOf(closure, new.target.prototype) as unknown as this &
+ ((...args: Args) => Return);
+ }
+
+ /**
+ * This method should be implemented in subclasses to provide the
+ * functionality of the callable object.
+ *
+ * @param args Zero or more arguments to pass to the 'call' method.
+ * @throws {Error} If the subclass does not implement the `call` method.
+ */
+ protected abstract call(..._args: Args): Return;
+}
+interface Callable<Args extends unknown[], Return> {
+ (...args: Args): Return;
+}
+
+// Discriminated config helpers
+type WithType = { type: TType };
+
+// Normalizer configs
+type NormalizerSequenceConfig = WithType<'Sequence'> & { normalizers: NormalizerConfig[] };
+type NFCConfig = WithType<'NFC'>;
+type NFDConfig = WithType<'NFD'>;
+type NFKCConfig = WithType<'NFKC'>;
+type NFKDConfig = WithType<'NFKD'>;
+type StripConfig = WithType<'Strip'> & { stripLeft?: boolean; stripRight?: boolean };
+type LowercaseConfig = WithType<'Lowercase'>;
+type PrependConfig = WithType<'Prepend'> & { prepend: string };
+type NormalizerConfig =
+ | NormalizerSequenceConfig
+ | NFCConfig
+ | NFDConfig
+ | NFKCConfig
+ | NFKDConfig
+ | StripConfig
+ | LowercaseConfig
+ | PrependConfig;
+
+// PreTokenizer configs and options
+type PreTokenizeOptions = { sectionIndex?: number };
+type PreTokenizerSequenceConfig = WithType<'Sequence'> & { pretokenizers: PreTokenizerConfig[] };
+type WhitespacePreTokenizerConfig = WithType<'Whitespace'>;
+type WhitespaceSplitConfig = WithType<'WhitespaceSplit'>;
+type MetaspacePreTokenizerConfig = WithType<'Metaspace'> & {
+ addPrefixSpace: boolean;
+ replacement: string;
+ strRep?: string;
+ prependScheme?: 'first' | 'never' | 'always';
+};
+type ByteLevelPreTokenizerConfig = WithType<'ByteLevel'> & {
+ addPrefixSpace: boolean;
+ trimOffsets: boolean;
+ useRegex?: boolean;
+};
+type PreTokenizerConfig =
+ | PreTokenizerSequenceConfig
+ | WhitespacePreTokenizerConfig
+ | WhitespaceSplitConfig
+ | MetaspacePreTokenizerConfig
+ | ByteLevelPreTokenizerConfig;
+
+// PostProcessor configs and options
+type PostProcessorOptions = { addSpecialTokens?: boolean };
+type PostProcessorResult = { tokens: string[]; tokenTypeIds?: number[] };
+type PostProcessorSequenceConfig = WithType<'Sequence'> & { processors: PostProcessorConfig[] };
+type ByteLevelPostProcessorConfig = WithType<'ByteLevel'>;
+type PostProcessorConfig = PostProcessorSequenceConfig | ByteLevelPostProcessorConfig;
+
+// Decoder configs
+type ByteLevelDecoderConfig = WithType<'ByteLevel'> & { trimOffsets?: boolean };
+type ByteFallbackConfig = WithType<'ByteFallback'>;
+type FuseDecoderConfig = WithType<'Fuse'>;
+type StripDecoderConfig = WithType<'Strip'> & { content: string; start: number; stop: number };
+type DecoderSequenceConfig = WithType<'Sequence'> & { decoders: DecoderConfig[] };
+type BPEDecoderConfig = WithType<'BPEDecoder'> & { suffix: string };
+type DecoderConfig =
+ | ByteLevelDecoderConfig
+ | ByteFallbackConfig
+ | FuseDecoderConfig
+ | StripDecoderConfig
+ | DecoderSequenceConfig
+ | BPEDecoderConfig;
+
+// Model configs
+export interface TokenizerModelConfig {
+ fuseUnk: boolean;
+ byteFallback: boolean;
+ ignoreMerges: boolean;
+}
+
+type BPEConfig = WithType<'BPE'> &
+ TokenizerModelConfig & {
+    vocab: Record<string, number>;
+ merges: string[] | [string, string][];
+ unkToken: string;
+ endOfWordSuffix?: string;
+ continuingSubwordSuffix?: string | null;
+ };
+
+type TokenizerModelFactoryConfig = BPEConfig; // Extend when additional models are added
+
+// Tokenizer JSON and runtime config
+interface TokenizerJSON {
+ normalizer: NormalizerConfig | null;
+ preTokenizer: PreTokenizerConfig | null;
+ model: TokenizerModelFactoryConfig;
+ postProcessor: PostProcessorConfig | null;
+ decoder: DecoderConfig | null;
+ addedTokens: AddedTokenConfig[];
+}
+
+interface TokenizerConfig {
+ [key: string]: unknown;
+ additionalSpecialTokens?: string[];
+ modelMaxLength: number;
+ removeSpace: boolean;
+ cleanUpTokenizationSpaces?: boolean;
+ paddingSide?: 'left' | 'right';
+ addBosToken?: boolean;
+ addEosToken?: boolean;
+  chatTemplate?: null | Array<{ name: string; template: string }> | Record<string, string>;
+}
+
+const TOKENIZER_URL = PUBLIC_DATA_URL + 'tokenizer';
+
+/**
+ * Loads a tokenizer from the specified path.
+ * @param tokenizerName The path to the tokenizer directory.
+ * @returns A promise that resolves with tokenizer JSON and config.
+ */
+async function loadTokenizer(tokenizerName: string): Promise<[TokenizerJSON, TokenizerConfig]> {
+ return Promise.all([
+    fetchJSON<TokenizerJSON>(TOKENIZER_URL, `${tokenizerName}/tokenizer.json`).then(
+ camelCaseKeysDeep
+ ),
+    fetchJSON<TokenizerConfig>(TOKENIZER_URL, `${tokenizerName}/tokenizer_config.json`).then(
+ camelCaseKeysDeep
+ )
+ ]);
+}
+
+function isPlainObject(value: unknown): value is Record<string, unknown> {
+ return (
+ value !== null && typeof value === 'object' && Object.getPrototypeOf(value) === Object.prototype
+ );
+}
+
+function toCamelKey(key: string): string {
+ return key.includes('_')
+ ? key.replace(/_+([a-zA-Z0-9])/g, (_m, c: string) => c.toUpperCase())
+ : key;
+}
+
+function camelCaseKeysDeep<T>(input: T): T {
+ if (Array.isArray(input)) {
+ return input.map((item) => camelCaseKeysDeep(item)) as unknown as T;
+ }
+ if (isPlainObject(input)) {
+    const obj = input as Record<string, unknown>;
+    const out: Record<string, unknown> = Object.create(null);
+ for (const [key, value] of Object.entries(obj)) {
+ const transformed = camelCaseKeysDeep(value);
+ // Preserve original snake_case for compatibility
+ out[key] = transformed;
+ const camelKey = toCamelKey(key);
+ if (camelKey !== key && !(camelKey in out)) {
+ out[camelKey] = transformed;
+ }
+ }
+ return out as unknown as T;
+ }
+ return input;
+}
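+
+// Example (illustrative): camelCaseKeysDeep({ add_prefix_space: true }) returns
+// { add_prefix_space: true, addPrefixSpace: true }; the original snake_case key is
+// kept alongside the camelCase alias for compatibility.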
+
+// Minimal fetch wrapper used here; replace with project-util if available
+async function fetchJSON<T>(basePath: string, fileName: string): Promise<T> {
+ const url = `${basePath.replace(/\/$/, '')}/${fileName}`;
+ const res = await fetch(url);
+ if (!res.ok) throw new Error(`Failed to load ${fileName} from ${url}`);
+  return res.json() as Promise<T>;
+}
+
+/**
+ * Helper function to convert an Object to a Map
+ * @param obj The object to convert.
+ * @returns The map.
+ */
+function objectToMap<T>(obj: Record<string, T>): Map<string, T> {
+ return new Map(Object.entries(obj));
+}
+
+/**
+ * Helper function to fuse consecutive unknown tokens.
+ * @param arr The list of input tokens
+ * @param tokensToIds The mapping from tokens to token ids.
+ * @param unkTokenId The value to fuse on.
+ */
+function fuseUnk(arr: string[], tokensToIds: Map<string, number>, unkTokenId: number): string[] {
+ const fused = [];
+ let i = 0;
+ while (i < arr.length) {
+ fused.push(arr[i]);
+ if ((tokensToIds.get(arr[i]) ?? unkTokenId) !== unkTokenId) {
+ ++i;
+ continue;
+ }
+
+ while (++i < arr.length && (tokensToIds.get(arr[i]) ?? unkTokenId) === unkTokenId) {
+ if (tokensToIds.get(fused[fused.length - 1]) !== unkTokenId) {
+ fused[fused.length - 1] += arr[i];
+ }
+ }
+ }
+
+ return fused;
+}
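+
+// Example (illustrative): with a vocab containing 'a' and 'b' but not 'xy' or 'zw',
+// fuseUnk(['a', 'xy', 'zw', 'b'], tokensToIds, unkTokenId) returns ['a', 'xyzw', 'b'];
+// consecutive out-of-vocabulary tokens are fused into a single unknown span.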
+
+/**
+ * Split a string on whitespace.
+ * @param text The text to split.
+ * @returns The split string.
+ */
+function whitespaceSplit(text: string): string[] {
+ return text.match(/\S+/g) || [];
+}
+
+/**
+ * Represent a token added by the user on top of the existing Model vocabulary.
+ * AddedToken can be configured to specify the behavior they should have in various situations like:
+ * - Whether they should only match single words
+ * - Whether to include any whitespace on its left or right
+ */
+interface AddedTokenConfig {
+ content: string;
+ id: number;
+ singleWord?: boolean;
+ lstrip?: boolean;
+ rstrip?: boolean;
+ normalized?: boolean;
+ special?: boolean;
+}
+class AddedToken {
+ content: string;
+ id: number;
+ singleWord: boolean;
+ lstrip: boolean;
+ rstrip: boolean;
+ special: boolean;
+ normalized: boolean | null;
+ /**
+ * Creates a new instance of AddedToken.
+ * @param config Added token configuration object.
+ * @param config.content The content of the added token.
+ * @param config.id The id of the added token.
+ * @param config.singleWord Whether this token must be a single word or can break words.
+ * @param config.lstrip Whether this token should strip whitespaces on its left.
+ * @param config.rstrip Whether this token should strip whitespaces on its right.
+ * @param config.normalized Whether this token should be normalized.
+ * @param config.special Whether this token is special.
+ */
+ constructor(config: AddedTokenConfig) {
+ this.content = config.content;
+ this.id = config.id;
+ this.singleWord = config.singleWord ?? false;
+ this.lstrip = config.lstrip ?? false;
+ this.rstrip = config.rstrip ?? false;
+ this.special = config.special ?? false;
+ this.normalized = config.normalized ?? null;
+ }
+}
+
+/**
+ * Abstract base class for tokenizer models.
+ */
+export class TokenizerModel extends Callable<[string[]], string[]> {
+ config: TokenizerModelConfig;
+ vocab: string[];
+  tokensToIds: Map<string, number>;
+ unkTokenId?: number;
+ unkToken?: string;
+ endOfWordSuffix?: string;
+ fuseUnk: boolean;
+ /**
+ * Creates a new instance of TokenizerModel.
+ * @param config The configuration object for the TokenizerModel.
+ */
+ constructor(config: TokenizerModelConfig) {
+ super();
+ this.config = config;
+
+ this.vocab = [];
+
+ this.tokensToIds = new Map();
+
+ this.unkTokenId = undefined;
+ this.unkToken = undefined;
+ this.endOfWordSuffix = undefined;
+
+ this.fuseUnk = this.config.fuseUnk ?? false;
+ }
+
+ /**
+ * Instantiates a new TokenizerModel instance based on the configuration object provided.
+ * @param config The configuration object for the TokenizerModel.
+ * @param _args Optional arguments to pass to the specific TokenizerModel constructor.
+   * @returns A new instance of a TokenizerModel (currently always a BPE model).
+ */
+ static fromConfig(config: TokenizerModelFactoryConfig, ..._args: unknown[]): TokenizerModel {
+ switch (config.type) {
+ case 'BPE':
+ default:
+ return new BPE(config);
+ }
+ }
+
+ /**
+ * Internal function to call the TokenizerModel instance.
+ * @param tokens The tokens to encode.
+ * @returns The encoded tokens.
+ */
+ protected call(...[tokens]: [string[]]): string[] {
+ tokens = this.encode(tokens);
+ if (this.fuseUnk) {
+ // Fuse unknown tokens
+ tokens = fuseUnk(tokens, this.tokensToIds, this.unkTokenId as number);
+ }
+ return tokens;
+ }
+
+ /**
+ * Encodes a list of tokens into a list of token IDs.
+ * @param tokens The tokens to encode.
+ * @returns The encoded tokens.
+ * @throws Will throw an error if not implemented in a subclass.
+ */
+ encode(_tokens: string[]): string[] {
+ throw Error('encode should be implemented in subclass.');
+ }
+
+ /**
+ * Converts a list of tokens into a list of token IDs.
+ * @param tokens The tokens to convert.
+ * @returns The converted token IDs.
+ */
+ convertTokensToIds(tokens: string[]): number[] {
+ return tokens.map((t) => this.tokensToIds.get(t) ?? (this.unkTokenId as number));
+ }
+
+ /**
+ * Converts a list of token IDs into a list of tokens.
+ * @param ids The token IDs to convert.
+ * @returns The converted tokens.
+ */
+ convertIdsToTokens(ids: number[] | bigint[]): string[] {
+ return ids.map((i) => this.vocab[Number(i)] ?? (this.unkToken as string));
+ }
+}
+
+/**
+ * Returns a mapping from utf-8 bytes to unicode strings.
+ * Specifically avoids mapping to whitespace/control characters the BPE code barfs on.
+ * @returns Object with utf-8 byte keys and unicode string values.
+ */
+const BYTES_TO_UNICODE = (() => {
+  // We specifically avoid mapping to whitespace/control characters the BPE code barfs on.
+
+ const bs = [
+ ...Array.from(
+ { length: '~'.charCodeAt(0) - '!'.charCodeAt(0) + 1 },
+ (_, i) => i + '!'.charCodeAt(0)
+ ),
+ ...Array.from(
+ { length: '¬'.charCodeAt(0) - '¡'.charCodeAt(0) + 1 },
+ (_, i) => i + '¡'.charCodeAt(0)
+ ),
+ ...Array.from(
+ { length: 'ÿ'.charCodeAt(0) - '®'.charCodeAt(0) + 1 },
+ (_, i) => i + '®'.charCodeAt(0)
+ )
+ ];
+ const cs = bs.slice();
+ let n = 0;
+ for (let b = 0; b < 256; ++b) {
+ if (!bs.includes(b)) {
+ bs.push(b);
+ cs.push(256 + n);
+ n += 1;
+ }
+ }
+ const ccs = cs.map((n) => String.fromCharCode(n));
+ return Object.fromEntries(bs.map((b, i) => [b, ccs[i]]));
+})();
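+
+// For example, the space byte (0x20) is not in the printable ranges above, so it is
+// remapped into the 256+ range: BYTES_TO_UNICODE[32] === 'Ġ' (U+0120), the familiar
+// GPT-2 space marker.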
+
+const UNICODE_TO_BYTES = Object.fromEntries(
+ Object.entries(BYTES_TO_UNICODE).map(([key, value]) => [value, key])
+);
+
+interface BPENode {
+ token: string;
+ bias: number;
+ score?: number;
+ prev?: BPENode;
+ next?: BPENode;
+}
+
+/**
+ * BPE class for encoding text into Byte-Pair-Encoding (BPE) tokens.
+ */
+class BPE extends TokenizerModel {
+ merges!: [string, string][];
+ bpeRanks!: Map;
+ continuingSubwordSuffix!: string | null;
+ byteFallback!: boolean;
+ textEncoder!: TextEncoder;
+ ignoreMerges!: boolean;
+ maxLengthToCache!: number;
+ cacheCapacity!: number;
+  cache!: LRUCache<string, string[]>;
+
+ constructor(config: BPEConfig) {
+ super(config);
+ this.tokensToIds = objectToMap(config.vocab);
+ this.unkTokenId = this.tokensToIds.get(config.unkToken) as number;
+ this.unkToken = config.unkToken as string;
+ this.vocab = new Array(this.tokensToIds.size);
+ for (const [key, value] of this.tokensToIds) {
+ this.vocab[value] = key;
+ }
+ const useNewMergeFormat = Array.isArray(config.merges[0]);
+ this.merges = useNewMergeFormat
+ ? (config.merges as [string, string][])
+ : (config.merges as string[]).map((x) => x.split(' ', 2) as [string, string]);
+ this.bpeRanks = new Map(this.merges.map((x, i) => [JSON.stringify(x), i])) as Map<
+ string,
+ number
+ >;
+ this.endOfWordSuffix = config.endOfWordSuffix as string | undefined;
+ this.continuingSubwordSuffix = (config.continuingSubwordSuffix ?? null) as string | null;
+ this.byteFallback = (this.config.byteFallback ?? false) as boolean;
+ if (this.byteFallback) {
+ this.textEncoder = new TextEncoder();
+ }
+ this.ignoreMerges = (this.config.ignoreMerges ?? false) as boolean;
+ this.maxLengthToCache = 256;
+ this.cacheCapacity = 10000;
+ this.cache = new LRUCache(this.cacheCapacity);
+ }
+ clearCache() {
+ this.cache.clear();
+ }
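+ /**
+ * Applies the ranked BPE merges to a single pre-tokenized word and returns its
+ * subword pieces. Results for short tokens are memoized in an LRU cache.
+ * @param token The pre-tokenized word to segment.
+ * @returns The BPE subword tokens for the word.
+ */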
+ bpe(token: string): string[] {
+ if (token.length === 0) {
+ return [];
+ }
+ const cached = this.cache.get(token);
+ if (cached !== undefined) {
+ return cached;
+ }
+ const word = Array.from(token);
+ if (this.endOfWordSuffix) {
+ word[word.length - 1] += this.endOfWordSuffix;
+ }
+ let result: string[] = [];
+ if (word.length > 1) {
+ const queue = new PriorityQueue<BPENode>((a, b) => (a.score as number) < (b.score as number));
+ let startingNode: BPENode = {
+ token: word[0],
+ bias: 0,
+ prev: undefined,
+ next: undefined
+ };
+ let previousNode = startingNode;
+ for (let i = 1; i < word.length; ++i) {
+ const currentNode: BPENode = {
+ bias: i / word.length,
+ token: word[i],
+ prev: previousNode,
+ next: undefined
+ };
+ previousNode.next = currentNode;
+ this.addNode(queue, previousNode);
+ previousNode = currentNode;
+ }
+ while (!queue.isEmpty()) {
+ const node = queue.pop() as BPENode & {
+ deleted?: boolean;
+ prev?: BPENode & { deleted?: boolean };
+ next?: BPENode & { deleted?: boolean };
+ };
+ if (node.deleted || !node.next || node.next.deleted) continue;
+ node.deleted = true;
+ node.next.deleted = true;
+ if (node.prev) {
+ const newPreviousNode = { ...(node.prev as BPENode) } as BPENode;
+ node.prev.deleted = true;
+ node.prev = newPreviousNode;
+ if (newPreviousNode.prev) {
+ (newPreviousNode.prev as BPENode).next = newPreviousNode;
+ } else {
+ startingNode = newPreviousNode;
+ }
+ }
+ const merged: BPENode = {
+ token: node.token + (node.next as BPENode).token,
+ bias: node.bias,
+ prev: node.prev,
+ next: (node.next as BPENode).next
+ };
+ if (merged.prev) {
+ (merged.prev as BPENode).next = merged;
+ this.addNode(queue, merged.prev as BPENode);
+ } else {
+ startingNode = merged;
+ }
+ if (merged.next) {
+ (merged.next as BPENode).prev = merged;
+ this.addNode(queue, merged);
+ }
+ }
+ for (
+ let currentNode: BPENode | undefined = startingNode;
+ currentNode !== undefined;
+ currentNode = currentNode.next
+ ) {
+ result.push(currentNode.token);
+ }
+ } else {
+ result = word;
+ }
+ if (this.continuingSubwordSuffix) {
+ for (let i = 0; i < result.length - 1; ++i) {
+ result[i] += this.continuingSubwordSuffix;
+ }
+ }
+ if (token.length < this.maxLengthToCache) {
+ this.cache.put(token, result);
+ }
+ return result;
+ }
+ private addNode(queue: PriorityQueue<BPENode>, node: BPENode) {
+ const rank = this.bpeRanks.get(JSON.stringify([node.token, (node.next as BPENode).token]));
+ if (rank !== undefined) {
+ node.score = rank + node.bias;
+ queue.push(node);
+ }
+ }
+ encode(tokens: string[]): string[] {
+ const outputTokens: string[] = [];
+ for (const token of tokens) {
+ if (this.ignoreMerges && this.tokensToIds.has(token)) {
+ outputTokens.push(token);
+ continue;
+ }
+ const bpeTokenList = this.bpe(token);
+ for (const t of bpeTokenList) {
+ if (this.tokensToIds.has(t)) {
+ outputTokens.push(t);
+ } else if (this.byteFallback) {
+ const byteTokens = Array.from(this.textEncoder.encode(t)).map(
+ (x) => `<0x${x.toString(16).toUpperCase().padStart(2, '0')}>`
+ );
+ if (byteTokens.every((x) => this.tokensToIds.has(x))) {
+ outputTokens.push(...byteTokens);
+ } else {
+ outputTokens.push(this.unkToken as string);
+ }
+ } else {
+ outputTokens.push(this.unkToken as string);
+ }
+ }
+ }
+ return outputTokens;
+ }
+}
+
+/**
+ * A base class for text normalization.
+ */
+abstract class Normalizer<
+ TConfig extends NormalizerConfig = NormalizerConfig
+> extends Callable<[string], string> {
+ config: TConfig;
+ /**
+ * @param config The configuration object for the normalizer.
+ */
+ constructor(config: TConfig) {
+ super();
+ this.config = config;
+ }
+ static fromConfig(config: NormalizerConfig): Normalizer {
+ switch (config.type) {
+ case 'Sequence':
+ return new NormalizerSequence(config);
+ case 'NFC':
+ return new NFC(config);
+ case 'NFD':
+ return new NFD(config);
+ case 'NFKC':
+ return new NFKC(config);
+ case 'NFKD':
+ return new NFKD(config);
+ case 'Strip':
+ return new StripNormalizer(config);
+ case 'Lowercase':
+ return new Lowercase(config);
+ case 'Prepend':
+ return new Prepend(config);
+ }
+ }
+
+ normalize(_text: string): string {
+ throw Error('normalize should be implemented in subclass.');
+ }
+
+ protected call(...[text]: [string]): string {
+ return this.normalize(text);
+ }
+}
+
+/**
+ * A normalizer that applies Unicode normalization to the input text.
+ */
+abstract class UnicodeNormalizer extends Normalizer {
+ form: 'NFC' | 'NFD' | 'NFKC' | 'NFKD' | undefined = undefined;
+
+ /**
+ * Normalize the input text by applying Unicode normalization.
+ * @param text The input text to be normalized.
+ * @returns The normalized text.
+ */
+ normalize(text: string) {
+ text = text.normalize(this.form as 'NFC');
+ return text;
+ }
+}
+
+/**
+ * A normalizer that applies Unicode normalization form C (NFC) to the input text.
+ * Canonical Decomposition, followed by Canonical Composition.
+ */
+class NFC extends UnicodeNormalizer {
+ form = 'NFC' as const;
+}
+
+/**
+ * A normalizer that applies Unicode normalization form D (NFD) to the input text.
+ * Canonical Decomposition.
+ */
+class NFD extends UnicodeNormalizer {
+ form = 'NFD' as const;
+}
+
+/**
+ * A normalizer that applies Unicode normalization form KC (NFKC) to the input text.
+ * Compatibility Decomposition, followed by Canonical Composition.
+ */
+class NFKC extends UnicodeNormalizer {
+ form = 'NFKC' as const;
+}
+
+/**
+ * A normalizer that applies Unicode normalization form KD (NFKD) to the input text.
+ * Compatibility Decomposition.
+ */
+class NFKD extends UnicodeNormalizer {
+ form = 'NFKD' as const;
+}
+
+/**
+ * A normalizer that strips leading and/or trailing whitespace from the input text.
+ */
+class StripNormalizer extends Normalizer {
+ /**
+ * Strip leading and/or trailing whitespace from the input text.
+ * @param text The input text.
+ * @returns The normalized text.
+ */
+ normalize(text: string) {
+ const cfg = this.config as { stripLeft?: boolean; stripRight?: boolean };
+ if (cfg.stripLeft && cfg.stripRight) {
+ // Fast path to avoid an extra trim call
+ text = text.trim();
+ } else {
+ if (cfg.stripLeft) {
+ text = text.trimStart();
+ }
+ if (cfg.stripRight) {
+ text = text.trimEnd();
+ }
+ }
+ return text;
+ }
+}
+
+/**
+ * A Normalizer that lowercases the input string.
+ */
+class Lowercase extends Normalizer {
+ /**
+ * Lowercases the input string.
+ * @param text The text to normalize.
+ * @returns The normalized text.
+ */
+ normalize(text: string) {
+ text = text.toLowerCase();
+ return text;
+ }
+}
+
+/**
+ * A Normalizer that prepends a string to the input string.
+ */
+class Prepend extends Normalizer {
+ /**
+ * Prepends the input string.
+ * @param text The text to normalize.
+ * @returns The normalized text.
+ */
+ normalize(text: string) {
+ const cfg = this.config as { prepend: string };
+ text = cfg.prepend + text;
+ return text;
+ }
+}
+
+/**
+ * A Normalizer that applies a sequence of Normalizers.
+ */
+class NormalizerSequence extends Normalizer<NormalizerSequenceConfig> {
+ normalizers: Normalizer[];
+
+ constructor(config: NormalizerSequenceConfig) {
+ super(config);
+ this.normalizers = config.normalizers.map((x) => Normalizer.fromConfig(x));
+ }
+ /**
+ * Apply a sequence of Normalizers to the input text.
+ * @param text The text to normalize.
+ * @returns The normalized text.
+ */
+ normalize(text: string) {
+ return this.normalizers.reduce((t, normalizer) => {
+ return normalizer.normalize(t);
+ }, text);
+ }
+}
+
+/**
+ * A callable class representing a pre-tokenizer used in tokenization. Subclasses
+ * should implement the `preTokenizeText` method to define the specific pre-tokenization logic.
+ */
+abstract class PreTokenizer extends Callable<
+ [string | string[], PreTokenizeOptions | undefined],
+ string[]
+> {
+ /**
+ * Factory method that returns an instance of a subclass of `PreTokenizer` based on the provided configuration.
+ *
+ * @static
+ * @param config A configuration object for the pre-tokenizer.
+ * @returns An instance of a subclass of `PreTokenizer`.
+ * @throws If the provided configuration object does not correspond to any known pre-tokenizer.
+ */
+ static fromConfig(config: PreTokenizerConfig): PreTokenizer {
+ switch (config.type) {
+ case 'Sequence':
+ return new PreTokenizerSequence(config);
+ case 'Whitespace':
+ return new WhitespacePreTokenizer();
+ case 'WhitespaceSplit':
+ return new WhitespaceSplit();
+ case 'Metaspace':
+ return new MetaspacePreTokenizer(config);
+ case 'ByteLevel':
+ return new ByteLevelPreTokenizer(config);
+ default:
+ throw new Error('Unknown PreTokenizer type');
+ }
+ }
+
+ /**
+ * Method that should be implemented by subclasses to define the specific pre-tokenization logic.
+ *
+ * @param text The text to pre-tokenize.
+ * @param options Additional options for the pre-tokenization logic.
+ * @returns The pre-tokenized text.
+ * @throws {Error} If the method is not implemented in the subclass.
+ */
+ abstract preTokenizeText(text: string, options?: PreTokenizeOptions): string[];
+
+ /**
+ * Tokenizes the given text into pre-tokens.
+ * @param text The text or array of texts to pre-tokenize.
+ * @param options Additional options for the pre-tokenization logic.
+ * @returns An array of pre-tokens.
+ */
+ preTokenize(text: string | string[], options?: PreTokenizeOptions): string[] {
+ return (
+ Array.isArray(text)
+ ? (text as string[]).map((x) => this.preTokenizeText(x, options))
+ : this.preTokenizeText(text as string, options)
+ ).flat();
+ }
+
+ /**
+ * Alias for {@link PreTokenizer#preTokenize}.
+ * @param text The text or array of texts to pre-tokenize.
+ * @param options Additional options for the pre-tokenization logic.
+ * @returns An array of pre-tokens.
+ */
+ protected call(
+ ...[text, options]: [string | string[], PreTokenizeOptions | undefined]
+ ): string[] {
+ return this.preTokenize(text, options);
+ }
+}
+
+/**
+ * A pre-tokenizer that splits text into Byte-Pair-Encoding (BPE) subwords.
+ * @extends PreTokenizer
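+ *
+ * @example
+ * // Illustrative, assuming addPrefixSpace is false: the regex keeps ' world'
+ * // as one piece and its leading space becomes the byte-level marker:
+ * // preTokenizeText('Hello world') → ['Hello', 'Ġworld']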
+ */
+class ByteLevelPreTokenizer extends PreTokenizer {
+ config: ByteLevelPreTokenizerConfig;
+ addPrefixSpace!: boolean;
+ trimOffsets!: boolean;
+ useRegex!: boolean;
+ pattern!: RegExp;
+ byteEncoder!: Record<number, string>;
+ textEncoder!: TextEncoder;
+ constructor(config: ByteLevelPreTokenizerConfig) {
+ super();
+ this.config = config;
+ this.addPrefixSpace = this.config.addPrefixSpace;
+ this.trimOffsets = this.config.trimOffsets;
+ this.useRegex = this.config.useRegex ?? true;
+ this.pattern = /'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+/gu;
+ this.byteEncoder = BYTES_TO_UNICODE as Record<number, string>;
+ this.textEncoder = new TextEncoder();
+ }
+ preTokenizeText(text: string, _options?: PreTokenizeOptions): string[] {
+ if (this.addPrefixSpace && !text.startsWith(' ')) {
+ text = ' ' + text;
+ }
+ const tokens = this.useRegex ? text.match(this.pattern) || [] : [text];
+ return tokens.map((token) =>
+ Array.from(this.textEncoder.encode(token), (byte) => this.byteEncoder[byte]).join('')
+ );
+ }
+}
+
+type PostProcessorArgs = [string[], (string[] | null | undefined)?, PostProcessorOptions?];
+abstract class PostProcessor extends Callable<PostProcessorArgs, PostProcessorResult> {
+ config: PostProcessorConfig;
+ constructor(config: PostProcessorConfig) {
+ super();
+ this.config = config;
+ }
+ static fromConfig(config: PostProcessorConfig): PostProcessor {
+ switch (config.type) {
+ case 'ByteLevel':
+ return new ByteLevelPostProcessor(config);
+ case 'Sequence':
+ return new PostProcessorSequence(config);
+ default:
+ throw new Error('Unknown PostProcessor type');
+ }
+ }
+
+ abstract postProcess(
+ tokens: string[],
+ ...args: [string[] | null | undefined, PostProcessorOptions?]
+ ): PostProcessorResult;
+
+ protected call(
+ ...[tokens, ...args]: [string[], (string[] | null | undefined)?, PostProcessorOptions?]
+ ): PostProcessorResult {
+ return this.postProcess(tokens, ...args);
+ }
+}
+
+/**
+ * A PostProcessor that returns the given tokens as is.
+ */
+class ByteLevelPostProcessor extends PostProcessor {
+ postProcess(tokens: string[], tokensPair: string[] | null = null): { tokens: string[] } {
+ if (tokensPair) {
+ tokens = mergeArrays(tokens, tokensPair);
+ }
+ return { tokens };
+ }
+}
+
+/**
+ * A post-processor that applies multiple post-processors in sequence.
+ */
+class PostProcessorSequence extends PostProcessor {
+ processors: PostProcessor[];
+ constructor(config: PostProcessorSequenceConfig) {
+ super(config);
+ this.processors = config.processors.map((x) => PostProcessor.fromConfig(x));
+ }
+ postProcess(
+ tokens: string[],
+ tokensPair: string[] | null = null,
+ options: PostProcessorOptions = {}
+ ): { tokens: string[]; tokenTypeIds?: number[] } {
+ let tokenTypeIds: number[] | undefined;
+ for (const processor of this.processors) {
+ if (processor instanceof ByteLevelPostProcessor) {
+ const output = processor.postProcess(tokens);
+ tokens = output.tokens;
+ if (tokensPair) {
+ const pairOutput = processor.postProcess(tokensPair);
+ tokensPair = pairOutput.tokens;
+ }
+ } else {
+ const output = processor.postProcess(tokens, tokensPair ?? null, options);
+ tokens = output.tokens;
+ if (output.tokenTypeIds) {
+ tokenTypeIds = output.tokenTypeIds;
+ }
+ }
+ }
+ return { tokens, tokenTypeIds };
+ }
+}
+
+/**
+ * The base class for token decoders.
+ */
+abstract class Decoder<
+ TConfig extends DecoderConfig = DecoderConfig
+> extends Callable<[string[]], string> {
+ config: TConfig;
+ addedTokens: AddedToken[];
+ endOfWordSuffix?: string;
+ constructor(config: TConfig) {
+ super();
+ this.config = config;
+ this.addedTokens = [];
+ this.endOfWordSuffix = undefined;
+ }
+ static fromConfig(config: DecoderConfig): Decoder {
+ switch (config.type) {
+ case 'ByteLevel':
+ return new ByteLevelDecoder(config);
+ case 'ByteFallback':
+ return new ByteFallback(config);
+ case 'Fuse':
+ return new FuseDecoder(config);
+ case 'Strip':
+ return new StripDecoder(config);
+ case 'Sequence':
+ return new DecoderSequence(config);
+ case 'BPEDecoder':
+ return new BPEDecoder(config);
+ default:
+ throw new Error('Unknown Decoder type');
+ }
+ }
+ protected call(...[tokens]: [string[]]): string {
+ return this.decode(tokens);
+ }
+ decode(tokens: string[]): string {
+ return this.decodeChain(tokens).join('');
+ }
+ abstract decodeChain(tokens: string[]): string[];
+}
+
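+/**
+ * A decoder that folds runs of byte-fallback tokens of the form `<0xAB>` back
+ * into UTF-8 decoded text, passing all other tokens through unchanged.
+ *
+ * @example
+ * // Illustrative: ['<0xE2>', '<0x82>', '<0xAC>'] decodes to ['€'] (UTF-8 for U+20AC).
+ */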
+class ByteFallback extends Decoder {
+ textDecoder!: TextDecoder;
+ constructor(config: ByteFallbackConfig) {
+ super(config);
+ this.textDecoder = new TextDecoder();
+ }
+ decodeChain(tokens: string[]): string[] {
+ const newTokens: string[] = [];
+ let previousByteTokens: number[] = [];
+ for (const token of tokens) {
+ let bytes: number | null = null;
+ if (token.length === 6 && token.startsWith('<0x') && token.endsWith('>')) {
+ const byte = parseInt(token.slice(3, 5), 16);
+ if (!isNaN(byte)) {
+ bytes = byte;
+ }
+ }
+ if (bytes !== null) {
+ previousByteTokens.push(bytes);
+ } else {
+ if (previousByteTokens.length > 0) {
+ const string = this.textDecoder.decode(Uint8Array.from(previousByteTokens));
+ newTokens.push(string);
+ previousByteTokens = [];
+ }
+ newTokens.push(token);
+ }
+ }
+ if (previousByteTokens.length > 0) {
+ const string = this.textDecoder.decode(Uint8Array.from(previousByteTokens));
+ newTokens.push(string);
+ previousByteTokens = [];
+ }
+ return newTokens;
+ }
+}
+
+/**
+ * Fuse simply fuses all tokens into one big string.
+ * It's usually the last decoding step anyway, but this decoder
+ * exists in case some decoders need to happen after that step.
+ */
+class FuseDecoder extends Decoder {
+ /** @type {Decoder['decodeChain']} */
+ decodeChain(tokens: string[]): string[] {
+ return [tokens.join('')];
+ }
+}
+
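+/**
+ * A decoder that strips up to `start` leading and `stop` trailing occurrences
+ * of `content` from each token.
+ *
+ * @example
+ * // Illustrative, with content '▁', start 1, stop 0:
+ * // decodeChain(['▁Hello']) → ['Hello']
+ */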
+class StripDecoder extends Decoder<StripDecoderConfig> {
+ content!: string;
+ start!: number;
+ stop!: number;
+ constructor(config: StripDecoderConfig) {
+ super(config);
+ const cfg = this.config;
+ this.content = cfg.content;
+ this.start = cfg.start;
+ this.stop = cfg.stop;
+ }
+ /** @type {Decoder['decodeChain']} */
+ decodeChain(tokens: string[]): string[] {
+ return tokens.map((token) => {
+ let startCut = 0;
+ for (let i = 0; i < this.start; ++i) {
+ if (token[i] === this.content) {
+ startCut = i + 1;
+ continue;
+ } else {
+ break;
+ }
+ }
+ let stopCut = token.length;
+ for (let i = 0; i < this.stop; ++i) {
+ const index = token.length - i - 1;
+ if (token[index] === this.content) {
+ stopCut = index;
+ continue;
+ } else {
+ break;
+ }
+ }
+ return token.slice(startCut, stopCut);
+ });
+ }
+}
+
+/**
+ * Byte-level decoder for tokenization output. Inherits from the `Decoder` class.
+ * @extends Decoder
+ */
+class ByteLevelDecoder extends Decoder {
+ byteDecoder!: Record<string, number>;
+ textDecoder!: TextDecoder;
+ constructor(config: ByteLevelDecoderConfig) {
+ super(config);
+ this.byteDecoder = UNICODE_TO_BYTES as unknown as Record<string, number>;
+ this.textDecoder = new TextDecoder('utf-8', { fatal: false, ignoreBOM: true });
+ this.endOfWordSuffix = undefined;
+ }
+ convertTokensToString(tokens: string[]): string {
+ const text = tokens.join('');
+ const byteArray = new Uint8Array([...text].map((c) => this.byteDecoder[c]));
+ const decodedText = this.textDecoder.decode(byteArray);
+ return decodedText;
+ }
+ /** @type {Decoder['decodeChain']} */
+ decodeChain(tokens: string[]): string[] {
+ const subTexts: string[] = [];
+ let currentSubText: string[] = [];
+ for (const token of tokens) {
+ if (this.addedTokens.find((x) => x.content === token) !== undefined) {
+ if (currentSubText.length > 0) {
+ subTexts.push(this.convertTokensToString(currentSubText));
+ currentSubText = [];
+ }
+ subTexts.push(token);
+ } else {
+ currentSubText.push(token);
+ }
+ }
+ if (currentSubText.length > 0) {
+ subTexts.push(this.convertTokensToString(currentSubText));
+ }
+ return subTexts;
+ }
+}
+
+/**
+ * Apply a sequence of decoders.
+ * @extends Decoder
+ */
+class DecoderSequence extends Decoder {
+ decoders!: Decoder[];
+ constructor(config: DecoderSequenceConfig) {
+ super(config);
+ this.decoders = config.decoders.map((x) => Decoder.fromConfig(x));
+ }
+ /** @type {Decoder['decodeChain']} */
+ decodeChain(tokens: string[]): string[] {
+ return this.decoders.reduce((toks: string[], decoder: Decoder) => {
+ return decoder.decodeChain(toks);
+ }, tokens);
+ }
+}
+
+class BPEDecoder extends Decoder<BPEDecoderConfig> {
+ suffix!: string;
+ constructor(config: BPEDecoderConfig) {
+ super(config);
+ const cfg = this.config;
+ this.suffix = cfg.suffix;
+ }
+ /** @type {Decoder['decodeChain']} */
+ decodeChain(tokens: string[]): string[] {
+ return tokens.map((token, i) => {
+ return token.replaceAll(this.suffix, i === tokens.length - 1 ? '' : ' ');
+ });
+ }
+}
+
+/**
+ * This PreTokenizer replaces spaces with the given replacement character, adds a prefix space if requested,
+ * and returns a list of tokens.
+ * @extends PreTokenizer
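+ *
+ * @example
+ * // Illustrative, assuming replacement '▁', addPrefixSpace true, prependScheme 'always':
+ * // preTokenizeText('Hello world') → ['▁Hello▁world']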
+ */
+class MetaspacePreTokenizer extends PreTokenizer {
+ addPrefixSpace: boolean;
+ replacement: string;
+ strRep: string;
+ prependScheme: 'first' | 'never' | 'always';
+ /**
+ * @param config The configuration object for the MetaspacePreTokenizer.
+ * @param config.addPrefixSpace Whether to add a prefix space to the first token.
+ * @param config.replacement The character to replace spaces with.
+ * @param config.strRep An optional string representation of the replacement character. Defaults to `config.replacement`.
+ * @param config.prependScheme The metaspace prepending scheme: 'first', 'never', or 'always'. Defaults to 'always'.
+ */
+ constructor(config: MetaspacePreTokenizerConfig) {
+ super();
+
+ this.addPrefixSpace = config.addPrefixSpace;
+ this.replacement = config.replacement;
+ this.strRep = config.strRep || this.replacement;
+ this.prependScheme = config.prependScheme ?? 'always';
+ }
+
+ /**
+ * This method takes a string, replaces spaces with the replacement character,
+ * adds a prefix space if requested, and returns a new list of tokens.
+ * @param text The text to pre-tokenize.
+ * @param options The options for the pre-tokenization.
+ * @param options.sectionIndex The index of the section to pre-tokenize.
+ * @returns A new list of pre-tokenized tokens.
+ */
+ preTokenizeText(text: string, { sectionIndex }: PreTokenizeOptions = {}) {
+ let normalized = text.replaceAll(' ', this.strRep);
+
+ if (
+ // We add a prefix space if:
+ // (1) The addPrefixSpace option is enabled and the normalized token does not already start
+ // with the replacement character.
+ this.addPrefixSpace &&
+ !normalized.startsWith(this.replacement) &&
+ // and (2) either:
+ // (a) prependScheme is 'always'
+ // (b) prependScheme is 'first' and this is the first section
+ (this.prependScheme === 'always' || (this.prependScheme === 'first' && sectionIndex === 0))
+ ) {
+ normalized = this.strRep + normalized;
+ }
+ return [normalized];
+ }
+}
+
+/**
+ * A pre-tokenizer that applies a sequence of pre-tokenizers to the input text.
+ * @extends PreTokenizer
+ */
+class PreTokenizerSequence extends PreTokenizer {
+ tokenizers: PreTokenizer[];
+ /**
+ * Creates an instance of PreTokenizerSequence.
+ * @param config The configuration object for the pre-tokenizer sequence.
+ * @param config.pretokenizers An array of pre-tokenizer configurations.
+ */
+ constructor(config: PreTokenizerSequenceConfig) {
+ super();
+ this.tokenizers = config.pretokenizers.map((x) => PreTokenizer.fromConfig(x));
+ }
+
+ /**
+ * Applies each pre-tokenizer in the sequence to the input text in turn.
+ * @param text The text to pre-tokenize.
+ * @param options Additional options for the pre-tokenization logic.
+ * @returns The pre-tokenized text.
+ */
+ preTokenizeText(text: string, options: PreTokenizeOptions) {
+ // Use reduce to apply each tokenizer to the text
+ return this.tokenizers.reduce(
+ (preTokenizedText, tokenizer) => {
+ return tokenizer.preTokenize(preTokenizedText, options);
+ },
+ [text]
+ );
+ }
+}
+
+/**
+ * Splits on word boundaries (using the following regular expression: `\w+|[^\w\s]+`).
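+ *
+ * @example
+ * // preTokenizeText('Hello, world!') → ['Hello', ',', 'world', '!']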
+ */
+class WhitespacePreTokenizer extends PreTokenizer {
+ /**
+ * Creates an instance of WhitespacePreTokenizer.
+ */
+ constructor() {
+ super();
+ }
+ /**
+ * Pre-tokenizes the input text by splitting it on word boundaries.
+ * @param text The text to be pre-tokenized.
+ * @param options Additional options for the pre-tokenization logic.
+ * @returns An array of tokens produced by splitting the input text on word boundaries.
+ */
+ preTokenizeText(text: string, _options: unknown) {
+ return text.match(/\w+|[^\w\s]+/g) || [];
+ }
+}
+
+/**
+ * Splits a string of text by whitespace characters into individual tokens.
+ * @extends PreTokenizer
+ */
+class WhitespaceSplit extends PreTokenizer {
+ /**
+ * Creates an instance of WhitespaceSplit.
+ */
+ constructor() {
+ super();
+ }
+ /**
+ * Pre-tokenizes the input text by splitting it on whitespace characters.
+ * @param text The text to be pre-tokenized.
+ * @param options Additional options for the pre-tokenization logic.
+ * @returns An array of tokens produced by splitting the input text on whitespace.
+ */
+ preTokenizeText(text: string, _options: unknown) {
+ return whitespaceSplit(text);
+ }
+}
+
+const SPECIAL_TOKEN_ATTRIBUTES = [
+ 'bos_token',
+ 'eos_token',
+ 'unk_token',
+ 'sep_token',
+ 'pad_token',
+ 'cls_token',
+ 'mask_token'
+ // additional_special_tokens (TODO)
+];
+
+/**
+ *
+ * Helper function for padding values of an object, which are each arrays.
+ * NOTE: No additional checks are made here for validity of arguments.
+ * @param item The input object.
+ * @param length The length to pad to.
+ * @param valueFn Determine the value to fill the array, based on its key.
+ * @param side Which side to pad the array.
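+ *
+ * @example
+ * // Illustrative: pad `inputIds` with 0 on the right to length 4.
+ * // const item = { inputIds: [5, 6] };
+ * // padHelper(item, 4, () => 0, 'right'); // item.inputIds → [5, 6, 0, 0]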
+ */
+function padHelper<T>(
+ item: Record<string, T[]>,
+ length: number,
+ valueFn: (key: string) => T,
+ side: 'right' | 'left'
+) {
+ for (const key of Object.keys(item)) {
+ const diff = length - item[key].length;
+ const value = valueFn(key);
+
+ const padData = new Array(diff).fill(value);
+ item[key] =
+ side === 'right' ? mergeArrays(item[key], padData) : mergeArrays(padData, item[key]);
+ }
+}
+
+/**
+ * Helper function for truncating values of an object, which are each arrays.
+ * NOTE: No additional checks are made here for validity of arguments.
+ * @param item The input object.
+ * @param length The length to truncate to.
+ */
+function truncateHelper(item: Record<string, unknown[]>, length: number) {
+ // Setting .length to a lower value truncates the array in-place:
+ // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/length
+ for (const key of Object.keys(item)) {
+ item[key].length = length;
+ }
+}
+
+interface DecodeArgs {
+ skipSpecialTokens?: boolean;
+ cleanUpTokenizationSpaces?: boolean;
+}
+
+type BatchEncodingItem = number[] | number[][] | Tensor;
+
+interface BatchEncoding {
+ inputIds: BatchEncodingItem;
+ attentionMask: BatchEncodingItem;
+ tokenTypeIds?: BatchEncodingItem;
+}
+
+interface Message {
+ role: string;
+ content: string;
+}
+
+export class PreTrainedTokenizer extends Callable<
+ [
+ string | string[],
+ {
+ textPair?: string | null;
+ addSpecialTokens?: boolean;
+ padding?: boolean | 'max_length';
+ truncation?: boolean | null;
+ maxLength?: number | null;
+ returnTensor?: boolean;
+ returnTokenTypeIds?: boolean | null;
+ }?
+ ],
+ BatchEncoding
+> {
+ config: TokenizerConfig;
+ normalizer!: ((text: string) => string) | Normalizer | null;
+ preTokenizer!: ((text: string, options?: PreTokenizeOptions) => string[]) | PreTokenizer | null;
+ model!: TokenizerModel;
+ postProcessor!:
+ | ((
+ tokens: string[],
+ tokensPair?: string[] | null,
+ options?: PostProcessorOptions
+ ) => PostProcessorResult)
+ | PostProcessor
+ | null;
+ decoder!: ((tokens: string[]) => string) | Decoder | null;
+ specialTokens: string[];
+ allSpecialIds: number[];
+ addedTokens: AddedToken[];
+ additionalSpecialTokens: string[];
+ addedTokensSplitter: DictionarySplitter;
+ addedTokensMap: Map<string, AddedToken>;
+ maskToken?: string | null;
+ maskTokenId?: number;
+ padToken?: string | null;
+ padTokenId?: number;
+ sepToken?: string | null;
+ sepTokenId?: number;
+ unkToken?: string | null;
+ unkTokenId?: number;
+ bosToken?: string | null;
+ bosTokenId?: number;
+ eosToken?: string | null;
+ eosTokenId?: number;
+ modelMaxLength!: number;
+ removeSpace!: boolean;
+ cleanUpTokenizationSpaces!: boolean;
+ paddingSide: 'left' | 'right' = 'right';
+ addBosToken?: boolean;
+ addEosToken?: boolean;
+ chatTemplate: null | string | Record<string, string> | Array<{ name: string; template: string }>;
+ returnTokenTypeIds = false;
+ private compiledTemplateCache: Map<string, Template>;
+ constructor(tokenizerJSON: TokenizerJSON, tokenizerConfig: TokenizerConfig) {
+ super();
+ this.config = tokenizerConfig;
+ this.normalizer = tokenizerJSON.normalizer
+ ? Normalizer.fromConfig(tokenizerJSON.normalizer)
+ : null;
+ this.preTokenizer = tokenizerJSON.preTokenizer
+ ? PreTokenizer.fromConfig(tokenizerJSON.preTokenizer)
+ : null;
+ this.model = TokenizerModel.fromConfig(tokenizerJSON.model, tokenizerConfig);
+ this.postProcessor = tokenizerJSON.postProcessor
+ ? PostProcessor.fromConfig(tokenizerJSON.postProcessor)
+ : null;
+ this.decoder = tokenizerJSON.decoder ? Decoder.fromConfig(tokenizerJSON.decoder) : null;
+ this.specialTokens = [];
+ this.allSpecialIds = [];
+ this.addedTokens = [];
+ for (const addedToken of tokenizerJSON.addedTokens) {
+ const token = new AddedToken(addedToken);
+ this.addedTokens.push(token);
+ this.model.tokensToIds.set(token.content, token.id);
+ this.model.vocab[token.id] = token.content;
+ if (token.special) {
+ this.specialTokens.push(token.content);
+ this.allSpecialIds.push(token.id);
+ }
+ }
+ this.additionalSpecialTokens = tokenizerConfig.additionalSpecialTokens ?? [];
+ this.specialTokens.push(...this.additionalSpecialTokens);
+ this.specialTokens = [...new Set(this.specialTokens)];
+ if (this.decoder) {
+ (this.decoder as Decoder).addedTokens = this.addedTokens;
+ (this.decoder as Decoder).endOfWordSuffix = this.model.endOfWordSuffix;
+ }
+ this.addedTokensSplitter = new DictionarySplitter(this.addedTokens.map((x) => x.content));
+ this.addedTokensMap = new Map(this.addedTokens.map((x) => [x.content, x]));
+ this.maskToken = this.getToken('mask_token');
+ this.maskTokenId = this.model.tokensToIds.get(this.maskToken as string);
+ this.padToken = this.getToken('pad_token', 'eos_token');
+ this.padTokenId = this.model.tokensToIds.get(this.padToken as string);
+ this.sepToken = this.getToken('sep_token');
+ this.sepTokenId = this.model.tokensToIds.get(this.sepToken as string);
+ this.unkToken = this.getToken('unk_token');
+ this.unkTokenId = this.model.tokensToIds.get(this.unkToken as string);
+ this.bosToken = this.getToken('bos_token');
+ this.bosTokenId = this.model.tokensToIds.get(this.bosToken as string);
+ this.eosToken = this.getToken('eos_token');
+ this.eosTokenId = this.model.tokensToIds.get(this.eosToken as string);
+ this.modelMaxLength = tokenizerConfig.modelMaxLength as number;
+ this.removeSpace = tokenizerConfig.removeSpace as boolean;
+ this.cleanUpTokenizationSpaces = (tokenizerConfig.cleanUpTokenizationSpaces ?? true) as boolean;
+ if (tokenizerConfig.paddingSide) {
+ this.paddingSide = tokenizerConfig.paddingSide as 'left' | 'right';
+ }
+ this.addBosToken = tokenizerConfig.addBosToken as boolean;
+ this.addEosToken = tokenizerConfig.addEosToken as boolean;
+ this.chatTemplate = tokenizerConfig.chatTemplate ?? null;
+ if (Array.isArray(this.chatTemplate)) {
+ const chatTemplate: Record<string, string> = Object.create(null);
+ for (const { name, template } of this.chatTemplate) {
+ if (typeof name !== 'string' || typeof template !== 'string') {
+ throw new Error(
+ 'Chat template must be a list of objects with "name" and "template" properties'
+ );
+ }
+ chatTemplate[name] = template;
+ }
+ this.chatTemplate = chatTemplate;
+ }
+ this.compiledTemplateCache = new Map();
+ }
+ getToken(...keys: string[]): string | null {
+ for (const key of keys) {
+ const item = this.config[key];
+ if (!item) continue;
+ if (typeof item === 'object') {
+ const maybe = item as { type?: string; content?: string };
+ if (maybe.type === 'AddedToken' && typeof maybe.content === 'string') {
+ return maybe.content;
+ }
+ throw Error(`Unknown token: ${String(item)}`);
+ } else {
+ return item as string;
+ }
+ }
+ return null;
+ }
+
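+ /**
+ * Loads the tokenizer definition and configuration for the given name (via
+ * `loadTokenizer`) and constructs a `PreTrainedTokenizer` from them.
+ * @param tokenizerName The name of the tokenizer to load.
+ * @returns The constructed tokenizer.
+ */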
+ static async fromPretrained(tokenizerName: string): Promise<PreTrainedTokenizer> {
+ const info = await loadTokenizer(tokenizerName);
+ return new this(...info);
+ }
+
+ /**
+ * Encode/tokenize the given text(s).
+ * @param text The text to tokenize.
+ * @param options An optional object containing the following properties:
+ * @param options.textPair A second sequence to be encoded with the first.
+ * @param options.padding Whether to pad the input sequences.
+ * @param options.addSpecialTokens Whether or not to add the special tokens associated with the corresponding model.
+ * @param options.truncation Whether to truncate the input sequences.
+ * @param options.maxLength Maximum length of the returned list and optionally padding length.
+ * @param options.returnTensor Whether to return the results as Tensors or arrays.
+ * @param options.returnTokenTypeIds Whether to return the token type ids.
+ * @returns Object to be passed to the model.
+ */
+ protected call(
+ text: string | string[],
+ {
+ textPair = null,
+ addSpecialTokens = true,
+ padding = false,
+ truncation = null,
+ maxLength = null,
+ returnTensor = true,
+ returnTokenTypeIds = null
+ }: {
+ textPair?: string | null;
+ addSpecialTokens?: boolean;
+ padding?: boolean | 'max_length';
+ truncation?: boolean | null;
+ maxLength?: number | null;
+ returnTensor?: boolean;
+ returnTokenTypeIds?: boolean | null;
+ } = {}
+ ): BatchEncoding {
+ const isBatched = Array.isArray(text);
+
+ let encodedTokens;
+
+ if (isBatched) {
+ if (text.length === 0) {
+ throw Error('text array must be non-empty');
+ }
+ encodedTokens = text.map((x) =>
+ this.encodePlus(x, {
+ addSpecialTokens: addSpecialTokens,
+ returnTokenTypeIds: returnTokenTypeIds
+ })
+ );
+ } else {
+ if (text === null || text === undefined) {
+ throw Error('text may not be null or undefined');
+ }
+
+ if (Array.isArray(textPair)) {
+ throw Error(
+ 'When specifying `textPair`, since `text` is a string, `textPair` must also be a string (i.e., not an array).'
+ );
+ }
+
+ // For single input, we just wrap in an array, and then unwrap later.
+ encodedTokens = [
+ this.encodePlus(text, {
+ addSpecialTokens: addSpecialTokens,
+ returnTokenTypeIds: returnTokenTypeIds
+ })
+ ];
+ }
+ // At this point, `encodedTokens` is batched, of shape [batchSize, tokens].
+ // However, the array may be jagged, so we may need to pad to `maxLength`.
+ if (maxLength === null) {
+ maxLength = this.modelMaxLength;
+ } else if (truncation === null) {
+ if (padding === true) {
+ console.warn(
+ '`maxLength` is ignored when `padding: true` and there is no truncation strategy. ' +
+ "To pad to max length, use `padding: 'maxLength'`."
+ );
+ maxLength = this.modelMaxLength;
+ } else if (padding === false) {
+ console.warn(
+ 'Truncation was not explicitly activated but `maxLength` is provided a specific value, please use `truncation: true` to explicitly truncate examples to max length.'
+ );
+ truncation = true;
+ }
+ }
+
+ // padding: 'max_length' doesn't require any additional calculation
+ // but padding: true has to calculate maxLength from the sequences
+ if (padding === true) {
+ maxLength = Math.min(
+ max(encodedTokens.map((x) => x.inputIds.length))[0],
+ maxLength ?? Infinity
+ );
+ }
+
+ // Ensure it is less than model max length
+ maxLength = Math.min(maxLength, this.modelMaxLength ?? Infinity);
+
+ if (padding || truncation) {
+ // Perform padding and/or truncation
+ for (let i = 0; i < encodedTokens.length; ++i) {
+ if (encodedTokens[i].inputIds.length === maxLength) {
+ continue;
+ } else if (encodedTokens[i].inputIds.length > maxLength) {
+ // possibly truncate
+ if (truncation) {
+ truncateHelper(encodedTokens[i], maxLength);
+ }
+ } else {
+ // t.length < maxLength
+ // possibly pad
+ if (padding) {
+ padHelper(
+ encodedTokens[i],
+ maxLength,
+ (key) => (key === 'inputIds' ? this.padTokenId : 0),
+ this.paddingSide
+ );
+ }
+ }
+ }
+ }
+
+ const result: Record<string, unknown> = {};
+
+ if (returnTensor) {
+ if (!(padding && truncation)) {
+ // It's not guaranteed that all items have the same length, so
+ // we perform an additional check
+
+ if (
+ encodedTokens.some((x) => {
+ for (const key of Object.keys(x)) {
+ if (
+ (x as Record<string, unknown[]>)[key].length !==
+ (encodedTokens[0] as Record<string, unknown[]>)[key]?.length
+ ) {
+ return true;
+ }
+ }
+ return false;
+ })
+ ) {
+ throw Error(
+ 'Unable to create tensor, you should probably activate truncation and/or padding ' +
+ "with 'padding=true' and 'truncation=true' to have batched tensors with the same length."
+ );
+ }
+ }
+
+ // Now we actually convert to tensor
+ // NOTE: In the same way as the python library, we return a batched tensor, regardless of
+ // whether we have a single input or multiple inputs.
+ const dims = [encodedTokens.length, encodedTokens[0].inputIds.length];
+
+ for (const key of Object.keys(encodedTokens[0])) {
+ result[key] = tensor(
+ Int32Array.from(
+ encodedTokens
+ .flatMap(
+ (x) =>
+ (x as Record<string, unknown>)[key] as (bigint | boolean | number | string)[]
+ )
+ .map(Number)
+ ),
+ { shape: dims, dtype: int32 }
+ );
+ }
+ } else {
+ for (const key of Object.keys(encodedTokens[0])) {
+ result[key] = encodedTokens.map((x) => (x as Record<string, unknown>)[key]);
+ }
+
+ // If not returning a tensor, we match the input type
+ if (!isBatched) {
+ // Input was not batched, so we unwrap
+ for (const key of Object.keys(result)) {
+ result[key] = (result[key] as unknown[])[0];
+ }
+ }
+ }
+
+ return result as unknown as BatchEncoding;
+ }
+
+ /**
+ * Encodes a single text using the preprocessor pipeline of the tokenizer.
+ *
+ * @param text The text to encode.
+ * @returns The encoded tokens, or `null` if the input is `null`.
+ */
+ private encodeText(text: string | null): string[] | null {
+ if (text === null) return null;
+
+ // Actual function which does encoding, for a single text
+ // First, we take care of special tokens. Needed to avoid issues arising from
+ // normalization and/or pretokenization (which may not preserve special tokens)
+ const sections = this.addedTokensSplitter.split(text);
+
+ // Process left/right stripping of added tokens
+ for (let i = 0; i < sections.length; ++i) {
+ const addedToken = this.addedTokensMap.get(sections[i]);
+ if (addedToken) {
+ if (addedToken.lstrip && i > 0) {
+ sections[i - 1] = sections[i - 1].trimEnd();
+ }
+ if (addedToken.rstrip && i < sections.length - 1) {
+ sections[i + 1] = sections[i + 1].trimStart();
+ }
+ }
+ }
+
+ const tokens = sections.flatMap((x, sectionIndex) => {
+ if (x.length === 0) return [];
+ if (this.addedTokensMap.has(x)) return [x]; // Return added tokens unchanged
+
+ if (this.removeSpace === true) {
+ x = x.trim().split(/\s+/).join(' ');
+ }
+
+ if (this.normalizer !== null) {
+ x = this.normalizer(x);
+ }
+
+ // If, after normalization, this section is empty (e.g., trimming whitespace),
+ // we return an empty array
+ if (x.length === 0) {
+ return [];
+ }
+
+ const sectionTokens =
+ this.preTokenizer !== null
+ ? this.preTokenizer(x, {
+ sectionIndex: sectionIndex
+ })
+ : [x];
+
+ return this.model(sectionTokens);
+ });
+
+ return tokens;
+ }
+
+ /**
+ * Encodes a single text or a pair of texts using the model's tokenizer.
+ *
+ * @param text The text to encode.
+ * @param options An optional object containing the following properties:
+ * @param options.textPair The optional second text to encode.
+ * @param options.addSpecialTokens Whether or not to add the special tokens associated with the corresponding model.
+ * @param options.returnTokenTypeIds Whether to return tokenTypeIds.
+ * @returns An object containing the encoded text.
+ */
+ private encodePlus(
+ text: string,
+ {
+ textPair = null,
+ addSpecialTokens = true,
+ returnTokenTypeIds = null
+ }: {
+ textPair?: string | null;
+ addSpecialTokens?: boolean;
+ returnTokenTypeIds?: boolean | null;
+ } = {}
+ ) {
+ const { tokens, tokenTypeIds } = this.tokenizeHelper(text, {
+ pair: textPair,
+ addSpecialTokens
+ });
+
+ const inputIds = this.model.convertTokensToIds(tokens);
+
+ const result = {
+ inputIds: inputIds,
+ attentionMask: new Array(inputIds.length).fill(1)
+ };
+ if ((returnTokenTypeIds ?? this.returnTokenTypeIds) && tokenTypeIds) {
+ (result as { tokenTypeIds?: number[] }).tokenTypeIds = tokenTypeIds;
+ }
+ return result;
+ }
+
+ /**
+ * Internal helper function to tokenize a text, and optionally a pair of texts.
+ * @param text The text to tokenize.
+ * @param options An optional object containing the following properties:
+ * @param options.pair The optional second text to tokenize.
+ * @param options.addSpecialTokens Whether or not to add the special tokens associated with the corresponding model.
+ * @returns An object containing the tokens and optionally the token type IDs.
+ */
+ private tokenizeHelper(
+ text: string,
+ {
+ pair = null,
+ addSpecialTokens = false
+ }: { pair?: string | null; addSpecialTokens?: boolean } = {}
+ ) {
+ const tokens = this.encodeText(text);
+ const tokens2 = this.encodeText(pair);
+
+ return this.postProcessor
+ ? this.postProcessor(tokens ?? [], tokens2 ?? null, { addSpecialTokens })
+ : { tokens: mergeArrays(tokens ?? [], tokens2 ?? []) };
+ }
+
+ /**
+ * Converts a string into a sequence of tokens.
+ * @param text The sequence to be encoded.
+ * @param options An optional object containing the following properties:
+ * @param options.pair A second sequence to be encoded with the first.
+ * @param options.addSpecialTokens Whether or not to add the special tokens associated with the corresponding model.
+ * @returns The list of tokens.
+ */
+ tokenize(text: string, { pair = null, addSpecialTokens = false } = {}) {
+ return this.tokenizeHelper(text, { pair, addSpecialTokens }).tokens;
+ }
+
+ /**
+ * Encodes a single text or a pair of texts using the model's tokenizer.
+ *
+ * @param text The text to encode.
+ * @param options An optional object containing the following properties:
+ * @param options.addSpecialTokens Whether or not to add the special tokens associated with the corresponding model.
+ * @param options.returnTokenTypeIds Whether to return tokenTypeIds.
+ * @returns An array of token IDs representing the encoded text(s).
+ */
+ encode(text: string, { addSpecialTokens = true, returnTokenTypeIds = null } = {}) {
+ return this.encodePlus(text, {
+ addSpecialTokens,
+ returnTokenTypeIds
+ }).inputIds;
+ }
+
+ /**
+ * Decode a batch of tokenized sequences.
+ * @param batch List of tokenized input sequences.
+ * @param decodeArgs (Optional) Object with decoding arguments.
+ * @returns List of decoded sequences.
+ */
+ batchDecode(batch: number[][], decodeArgs: DecodeArgs = {}) {
+ return batch.map((x) => this.decode(x, decodeArgs));
+ }
+
+ /**
+ * Decodes a sequence of token IDs back to a string.
+ *
+ * @param tokenIds List of token IDs to decode.
+ * @param decodeArgs (Optional) Object with decoding arguments.
+ *
+ * @returns The decoded string.
+ * @throws If `tokenIds` is not a non-empty array of integers.
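+ *
+ * @example
+ * // Illustrative, assuming a GPT-2-style vocabulary:
+ * // tokenizer.decode([15496, 995]) === 'Hello world'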
+ */
+ decode(tokenIds: number[], decodeArgs: DecodeArgs = {}) {
+ if (
+ !Array.isArray(tokenIds) ||
+ tokenIds.length === 0 ||
+ !(Number.isInteger(tokenIds[0]) || typeof tokenIds[0] === 'bigint')
+ ) {
+ throw Error('tokenIds must be a non-empty array of integers.');
+ }
+
+ return this.decodeSingle(tokenIds, decodeArgs);
+ }
+
+ /**
+ * Decode a single list of token ids to a string.
+ * @param tokenIds List of token ids to decode
+ * @param decodeArgs Optional arguments for decoding
+ * @param [decodeArgs.skipSpecialTokens=false] Whether to skip special tokens during decoding
+ * @returns The decoded string
+ */
+ decodeSingle(tokenIds: number[], { skipSpecialTokens = false }: DecodeArgs = {}) {
+ let tokens = this.model.convertIdsToTokens(tokenIds);
+ if (skipSpecialTokens) {
+ tokens = tokens.filter((x) => !this.specialTokens.includes(x));
+ }
+
+ // If `this.decoder` is null, we just join tokens with a space:
+ // https://github.com/huggingface/tokenizers/blob/8edec536a737cb04494b454805be16c020abb14f/tokenizers/src/tokenizer/mod.rs#L835
+ let decoded = this.decoder ? this.decoder(tokens) : tokens.join(' ');
+
+ // Slight hack, but prevents having to pass `skipSpecialTokens` to each call to `decode`, which
+ // would lead to code duplication.
+ if (this.decoder && 'endOfWordSuffix' in this.decoder && this.decoder.endOfWordSuffix) {
+ decoded = decoded.replaceAll(this.decoder.endOfWordSuffix, ' ');
+ if (skipSpecialTokens) {
+ decoded = decoded.trim();
+ }
+ }
+
+ return decoded;
+ }
+
+ /**
+ * Retrieve the chat template string used for tokenizing chat messages. This template is used
+ * internally by the `applyChatTemplate` method and can also be used externally to retrieve the
+ * model's chat template for better generation tracking.
+ *
+ * @param options An optional object containing the following properties:
+ * @param options.chatTemplate A Jinja template or the name of a template to use for this
+ * conversion. It is usually not necessary to pass anything to this argument, as the model's
+ * template will be used by default.
+ * @param options.tools A list of tools (callable functions) that will be accessible to the model.
+ * If the template does not support function calling, this argument will have no effect. Each
+ * tool should be passed as a JSON Schema, giving the name, description and argument types for
+ * the tool. See our
+ * [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
+ * for more information.
+ * @returns The chat template string.
+ */
+ getChatTemplate({
+ chatTemplate = null,
+ tools = null
+ }: { chatTemplate?: string | null; tools?: string[] | null } = {}): string {
+ // First, handle the cases when the model has a dict of multiple templates
+ if (this.chatTemplate && typeof this.chatTemplate === 'object') {
+ const templateDict = this.chatTemplate as Record<string, string>;
+
+ if (chatTemplate !== null && Object.hasOwn(templateDict, chatTemplate)) {
+ // The user can pass the name of a template to the chat template argument instead of an
+ // entire template
+ chatTemplate = templateDict[chatTemplate];
+ } else if (chatTemplate === null) {
+ if (tools !== null && 'toolUse' in templateDict) {
+ chatTemplate = templateDict['toolUse'];
+ } else if ('default' in templateDict) {
+ chatTemplate = templateDict['default'];
+ } else {
+ throw Error(
+ `This model has multiple chat templates with no default specified! Please either pass` +
+ ` a chat template or the name of the template you wish to use to the 'chatTemplate'` +
+ ` argument. Available template names are ${Object.keys(templateDict).sort()}.`
+ );
+ }
+ }
+ } else if (chatTemplate === null) {
+ // These are the cases when the model has a single template
+ // priority: `chatTemplate` argument > `tokenizer.chatTemplate`
+ if (this.chatTemplate) {
+ chatTemplate = this.chatTemplate;
+ } else {
+ throw Error(
+ 'Cannot use applyChatTemplate() because tokenizer.chatTemplate is not set and no template ' +
+ 'argument was passed! For information about writing templates and setting the ' +
+ 'tokenizer.chatTemplate attribute, please see the documentation at ' +
+ 'https://huggingface.co/docs/transformers/main/en/chat_templating'
+ );
+ }
+ }
+ return chatTemplate;
+ }
+
+ /**
+ * Converts a list of message objects with `"role"` and `"content"` keys to a list of token
+ * ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to
+ * determine the format and control tokens to use when converting.
+ *
+ * See [here](https://huggingface.co/docs/transformers/chat_templating) for more information.
+ *
+ * @param conversation A list of message objects with `"role"` and `"content"` keys,
+ * representing the chat history so far.
+ * @param options An optional object containing the following properties:
+ * @param options.chatTemplate A Jinja template to use for this conversion. If
+ * this is not passed, the model's chat template will be used instead.
+ * @param options.tools A list of tools (callable functions) that will be accessible to the model.
+ * If the template does not support function calling, this argument will have no effect. Each
+ * tool should be passed as a JSON Schema, giving the name, description and argument types for
+ * the tool. See our
+ * [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
+ * for more information.
+ * @param options.documents A list of dicts representing documents that will be accessible to the model if it is performing RAG
+ * (retrieval-augmented generation). If the template does not support RAG, this argument will have no
+ * effect. We recommend that each document should be a dict containing "title" and "text" keys. Please
+ * see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG)
+ * for examples of passing documents with chat templates.
+ * @param options.addGenerationPrompt Whether to end the prompt with the token(s) that indicate
+ * the start of an assistant message. This is useful when you want to generate a response from the
+ * model. Note that this argument will be passed to the chat template, and so it must be supported
+ * in the template for this argument to have any effect.
+ * @param options.tokenize Whether to tokenize the output. If false, the output will be a string.
+ * @param options.padding Whether to pad sequences to the maximum length. Has no effect if tokenize is false.
+ * @param options.truncation Whether to truncate sequences to the maximum length. Has no effect if tokenize is false.
+ * @param options.maxLength Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is false.
+ * If not specified, the tokenizer's `max_length` attribute will be used as a default.
+ * @param options.returnTensor Whether to return the output as a Tensor or an Array. Has no effect if tokenize is false.
+ * @param options.returnDict Whether to return a dictionary with named outputs. Has no effect if tokenize is false.
+ * @param options.tokenizerKwargs Additional options to pass to the tokenizer.
+ * @returns The tokenized output.
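+ *
+ * @example
+ * // Illustrative sketch; the rendered prompt depends on the model's template.
+ * // const ids = tokenizer.applyChatTemplate(
+ * //   [{ role: 'user', content: 'Hi!' }],
+ * //   { addGenerationPrompt: true, returnTensor: false }
+ * // );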
+ */
+ applyChatTemplate(
+ conversation: Message[],
+ {
+ tools = null,
+ documents = null,
+ chatTemplate = null,
+ addGenerationPrompt = false,
+ tokenize = true,
+ padding = false,
+ truncation = false,
+ maxLength = null,
+ returnTensor = true,
+ returnDict = false,
+ tokenizerKwargs = {},
+ ...kwargs
+ }: {
+ tools?: string[] | null;
+ documents?: string[] | null;
+ chatTemplate?: string | null;
+ addGenerationPrompt?: boolean;
+ tokenize?: boolean;
+ padding?: boolean;
+ truncation?: boolean;
+ maxLength?: number | null;
+ returnTensor?: boolean;
+ returnDict?: boolean;
+ tokenizerKwargs?: Record<string, unknown>;
+ } = {}
+ ) {
+ chatTemplate = this.getChatTemplate({ chatTemplate, tools });
+
+ if (typeof chatTemplate !== 'string') {
+ throw Error(`chatTemplate must be a string, but got ${typeof chatTemplate}`);
+ }
+
+ // Compilation function uses a cache to avoid recompiling the same template
+ let compiledTemplate = this.compiledTemplateCache.get(chatTemplate);
+ if (compiledTemplate === undefined) {
+ compiledTemplate = new Template(chatTemplate);
+ this.compiledTemplateCache.set(chatTemplate, compiledTemplate);
+ }
+
+ const specialTokensMap = Object.create(null);
+ for (const key of SPECIAL_TOKEN_ATTRIBUTES) {
+ const value = this.getToken(key);
+ if (value) {
+ specialTokensMap[key] = value;
+ }
+ }
+
+ const rendered = compiledTemplate.render({
+ messages: conversation,
+ addGenerationPrompt,
+ tools,
+ documents,
+ ...specialTokensMap,
+ ...kwargs
+ });
+
+ if (tokenize) {
+ const out = this.call(rendered, {
+ addSpecialTokens: false,
+ padding,
+ truncation,
+ maxLength,
+ returnTensor,
+ ...tokenizerKwargs
+ });
+ return returnDict ? out : out.inputIds;
+ }
+
+ return rendered;
+ }
+}
+
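+/**
+ * Returns the maximum value of an array together with its index.
+ * @param arr The (non-empty) array to search.
+ * @returns A `[max, indexOfMax]` pair.
+ * @throws If the array is empty.
+ */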
+export function max<T extends number[] | bigint[]>(arr: T) {
+ if (arr.length === 0) throw Error('Array must not be empty');
+ let max = arr[0];
+ let indexOfMax = 0;
+ for (let i = 1; i < arr.length; ++i) {
+ if (arr[i] > max) {
+ max = arr[i];
+ indexOfMax = i;
+ }
+ }
+ return [max, indexOfMax] as T extends bigint[] ? [bigint, number] : [number, number];
+}
+
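+/**
+ * Concatenates the given arrays into a single flat array.
+ * @param arrs The arrays to merge.
+ * @returns The merged array.
+ */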
+function mergeArrays<T extends unknown[]>(...arrs: T[]): T {
+ return Array.prototype.concat.apply([], arrs) as T;
+}
+
+type TrieNode = {
+ /**
+ * If this node marks the end of a word, this property will
+ * contain the complete word. Otherwise, it's undefined.
+ */
+ end?: string;
+
+ /**
+ * An index signature to represent child nodes. Each key is a
+ * character, and each value is the next TrieNode in the sequence.
+ * The value is a union to satisfy TypeScript's index signature rules.
+ */
+ [key: string]: TrieNode | string | undefined;
+};
+
+/**
+ * A data structure which uses a trie to split a string into tokens based on a dictionary.
+ * At each position, the longest matching dictionary entry wins.
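+ *
+ * @example
+ * // Illustrative: new DictionarySplitter(['<|im_start|>']).split('<|im_start|>hi')
+ * // → ['<|im_start|>', 'hi']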
+ */
+class DictionarySplitter {
+ trie: TrieNode;
+ /**
+ * @param dictionary The dictionary of words to use for splitting.
+ */
+ constructor(dictionary: string[]) {
+ this.trie = this.buildTrie(dictionary);
+ }
+
+ /**
+ * Builds a trie from the given dictionary.
+ * @param dictionary The dictionary of words to build the trie from.
+ * @returns The root node of the trie.
+ */
+ private buildTrie(dictionary: string[]) {
+ const trie: TrieNode = Object.create(null);
+ for (const word of dictionary) {
+ let node = trie;
+ for (let i = 0; i < word.length; ++i) {
+ node = (node[word[i]] ??= Object.create(null)) as TrieNode;
+ }
+ node.end = word;
+ }
+ return trie;
+ }
+
+ /**
+ * Splits the input text into tokens based on the dictionary.
+ * @param text The input text to split.
+ * @returns An array of tokens.
+ */
+ split(text: string): string[] {
+ const result = [];
+ const n = text.length;
+ let start = 0;
+ let i = 0;
+
+ while (i < n) {
+ let node = this.trie;
+ let match = null;
+ let j = i;
+
+ while (j < n && (node = node[text[j]] as TrieNode)) {
+ if (node.end) {
+ // Always keep the last (i.e., longest) match.
+ match = node.end;
+ }
+ ++j;
+ }
+
+ if (match) {
+ if (i > start) {
+ result.push(text.slice(start, i));
+ }
+ result.push(match);
+ i += match.length;
+ start = i;
+ } else {
+ ++i;
+ }
+ }
+ if (start < n) {
+ result.push(text.slice(start));
+ }
+ return result;
+ }
+}
+
+/**
+ * Efficient Heap-based Implementation of a Priority Queue.
+ * It uses an array-based binary heap, where the root is at index `0`, and the
+ * children of node `i` are located at indices `2i + 1` and `2i + 2`, respectively.
+ *
+ * Adapted from the following sources:
+ * - https://stackoverflow.com/a/42919752/13989043 (original)
+ * - https://github.com/belladoreai/llama-tokenizer-js (minor improvements)
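+ *
+ * @example
+ * // Illustrative: the default comparator behaves as a max-heap.
+ * // const pq = new PriorityQueue<number>();
+ * // pq.push(3, 1, 2);
+ * // pq.pop() === 3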
+ */
+class PriorityQueue<T> {
+ private heap: T[];
+ private comparator: (a: T, b: T) => boolean;
+ private maxSize: number;
+ /**
+ * Create a new PriorityQueue.
+ * @param comparator Comparator function to determine priority. Defaults to a MaxHeap.
+ */
+ constructor(
+ comparator: (a: T, b: T) => boolean = (a, b) => (a as unknown as number) > (b as unknown as number),
+ maxSize = Infinity
+ ) {
+ this.heap = [];
+ this.comparator = comparator;
+ this.maxSize = maxSize;
+ }
+
+ /**
+ * The size of the queue
+ */
+ get size() {
+ return this.heap.length;
+ }
+
+ /**
+ * Check if the queue is empty.
+ * @returns `true` if the queue is empty, `false` otherwise.
+ */
+ isEmpty() {
+ return this.size === 0;
+ }
+
+ /**
+ * Return the element with the highest priority in the queue.
+ * @returns The highest priority element in the queue.
+ */
+ peek() {
+ return this.heap[0];
+ }
+
+ /**
+ * Add one or more elements to the queue.
+ * @param values The values to push into the queue.
+ * @returns The new size of the queue.
+ */
+ push(...values: T[]) {
+ return this.extend(values);
+ }
+
+ /**
+ * Add multiple elements to the queue.
+ * @param values The values to push into the queue.
+ * @returns The new size of the queue.
+ */
+ extend(values: T[]) {
+ for (const value of values) {
+ if (this.size < this.maxSize) {
+ this.heap.push(value);
+ this.siftUp();
+ } else {
+ // Get index of value with the lowest priority
+ const smallest = this.smallest();
+
+ // If the new value has higher priority than the smallest value in the heap
+ // then replace the smallest value with the new value and update the heap
+ if (this.comparator(value, this.heap[smallest])) {
+ this.heap[smallest] = value;
+ this.siftUpFrom(smallest);
+ }
+ }
+ }
+ return this.size;
+ }
+
+ /**
+ * Remove and return the element with the highest priority in the queue.
+ * @returns The element with the highest priority in the queue.
+ */
+ pop() {
+ const poppedValue = this.peek();
+ const bottom = this.size - 1;
+ if (bottom > 0) {
+ this.swap(0, bottom);
+ }
+ this.heap.pop();
+ this.siftDown();
+ return poppedValue;
+ }
+
+ /**
+ * Replace the element with the highest priority in the queue with a new value.
+ * @param value The new value.
+ * @returns The replaced value.
+ */
+ replace(value: T) {
+ const replacedValue = this.peek();
+ this.heap[0] = value;
+ this.siftDown();
+ return replacedValue;
+ }
+
+ /**
+ * Compute the index for the parent of the node at index `i`.
+ * @param i The index of the node to get the parent of.
+ * @returns The index of the parent node.
+ */
+ private parent(i: number) {
+ return ((i + 1) >>> 1) - 1;
+ }
+
+ /**
+ * Compute the index for the left child of the node at index `i`.
+ * @param i The index of the node to get the left child of.
+ * @returns The index of the left child.
+ *
+ */
+ private left(i: number) {
+ return (i << 1) + 1;
+ }
+
+ /**
+ * Compute the index for the right child of the node at index `i`.
+ * @param i The index of the node to get the right child of.
+ * @returns The index of the right child.
+ */
+ private right(i: number) {
+ return (i + 1) << 1;
+ }
+
+ /**
+ * Check if the element at index `i` is greater than the element at index `j`.
+ * @param i The index of the first element to compare.
+ * @param j The index of the second element to compare.
+ * @returns `true` if the element at index `i` is greater than the element at index `j`, `false` otherwise.
+ *
+ */
+ private greater(i: number, j: number) {
+ return this.comparator(this.heap[i], this.heap[j]);
+ }
+
+ /**
+ * Swap the elements at indices `i` and `j`.
+ * @param i The index of the first element to swap.
+ * @param j The index of the second element to swap.
+ *
+ */
+ private swap(i: number, j: number) {
+ const temp = this.heap[i];
+ this.heap[i] = this.heap[j];
+ this.heap[j] = temp;
+ }
+
+ /**
+ * Maintain the heap property by updating positions in the heap,
+ * starting at the last element and moving up the heap.
+ */
+ private siftUp() {
+ this.siftUpFrom(this.size - 1);
+ }
+
+ /**
+ * Helper function to sift up from a given node.
+ * @param node The index of the node to start sifting up from.
+ */
+ private siftUpFrom(node: number) {
+ while (node > 0 && this.greater(node, this.parent(node))) {
+ this.swap(node, this.parent(node));
+ node = this.parent(node);
+ }
+ }
+
+ /**
+ * Maintain the heap property by updating positions in the heap,
+ * starting at the first element and moving down the heap.
+ */
+ private siftDown() {
+ let node = 0;
+ while (
+ (this.left(node) < this.size && this.greater(this.left(node), node)) ||
+ (this.right(node) < this.size && this.greater(this.right(node), node))
+ ) {
+ const maxChild =
+ this.right(node) < this.size && this.greater(this.right(node), this.left(node))
+ ? this.right(node)
+ : this.left(node);
+ this.swap(node, maxChild);
+ node = maxChild;
+ }
+ }
+
+ /**
+ * Get the index of the smallest element in the heap. Since we use an array-based heap,
+ * the index can be computed without needing to traverse the heap.
+ */
+ private smallest(): number {
+ return 2 ** Math.floor(Math.log2(this.size)) - 1;
+ }
+}
+
+/**
+ * A simple Least Recently Used (LRU) cache implementation in JavaScript.
+ * This cache stores key-value pairs and evicts the least recently used item
+ * when the capacity is exceeded.
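+ *
+ * @example
+ * // A minimal sketch (not from this demo's call sites): a capacity-2 cache
+ * const cache = new LRUCache<string, number>(2);
+ * cache.put('a', 1);
+ * cache.put('b', 2);
+ * cache.get('a'); // 1, and marks 'a' as recently used
+ * cache.put('c', 3); // evicts 'b', the least recently used key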
+ */
+class LRUCache<Key, Value> {
+ capacity: number;
+ cache: Map<Key, Value>;
+ /**
+ * Creates an LRUCache instance.
+ * @param capacity The maximum number of items the cache can hold.
+ */
+ constructor(capacity: number) {
+ this.capacity = capacity;
+ this.cache = new Map();
+ }
+
+ /**
+ * Retrieves the value associated with the given key and marks the key as recently used.
+ * @param key The key to retrieve.
+ * @returns The value associated with the key, or undefined if the key does not exist.
+ */
+ get(key: Key) {
+ if (!this.cache.has(key)) return undefined;
+ const value = this.cache.get(key);
+ this.cache.delete(key);
+ this.cache.set(key, value as Value);
+ return value;
+ }
+
+ /**
+ * Inserts or updates the key-value pair in the cache.
+ * If the key already exists, it is updated and marked as recently used.
+ * If the cache exceeds its capacity, the least recently used item is evicted.
+ * @param key The key to add or update.
+ * @param value The value to associate with the key.
+ */
+ put(key: Key, value: Value) {
+ if (this.cache.has(key)) {
+ this.cache.delete(key);
+ }
+ this.cache.set(key, value);
+ if (this.cache.size > this.capacity) {
+ this.cache.delete(this.cache.keys().next().value as Key);
+ }
+ }
+
+ /**
+ * Clears the cache.
+ */
+ clear() {
+ this.cache.clear();
+ }
+}
+
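+/**
+ * Decodes a single token id to display text, swapping special tokens for emoji
+ * markers so they stay visible in the UI. Falls back to `<id>` when no
+ * tokenizer is available.
+ */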
+export function decodeSingle(value: number, tokenizer: PreTrainedTokenizer | null): string {
+ if (tokenizer instanceof PreTrainedTokenizer) {
+ return tokenizer
+ .decodeSingle([value])
+ .replaceAll('<|end_of_text|>', '▶️📄')
+ .replaceAll('<|im_start|>', '▶️💬')
+ .replaceAll('<|im_end|>', '⏹️💬');
+ }
+ return `<${value}>`;
+}
diff --git a/examples/finetuning/src/lib/train/types.ts b/examples/finetuning/src/lib/train/types.ts
new file mode 100644
index 00000000..600749c5
--- /dev/null
+++ b/examples/finetuning/src/lib/train/types.ts
@@ -0,0 +1,27 @@
+import type { DataLoader } from '@piston-ml/piston-web';
+
+import type { NaturalLanguageAutoregressiveBatch } from './data/natural';
+import type { GPT } from './model/gpt';
+
+type CollateFn<B, T> = (batch: B) => T;
+
+type NaturalCollateInput = number[][];
+
+export type NaturalBatchType = NaturalLanguageAutoregressiveBatch;
+
+export type AutoregressiveBatchType = NaturalLanguageAutoregressiveBatch;
+
+export type NaturalAutoregressiveCollateFnType = CollateFn<
+ NaturalCollateInput,
+ NaturalLanguageAutoregressiveBatch
+>;
+
+export type AutoregressiveCollateFnType = NaturalAutoregressiveCollateFnType;
+
+export type NaturalCollateFnType = CollateFn<NaturalCollateInput, NaturalBatchType>;
+
+export type NaturalDataloaderType = DataLoader<number[], NaturalBatchType>;
+
+export type AutoregressiveModelType = GPT;
+
+export type GeneratableModel = AutoregressiveModelType;
diff --git a/examples/finetuning/src/lib/train/utils/checkpoint.ts b/examples/finetuning/src/lib/train/utils/checkpoint.ts
new file mode 100644
index 00000000..f55c5341
--- /dev/null
+++ b/examples/finetuning/src/lib/train/utils/checkpoint.ts
@@ -0,0 +1,224 @@
+import type { Config } from '$lib/workspace/config';
+
+import * as piston from '@piston-ml/piston-web';
+import {
+ type ConstantConfig,
+ type CosineAnnealingConfig,
+ type ExponentialConfig,
+ type LinearConfig,
+ LRScheduler,
+ Optimizer,
+ type OptimizerParamState,
+ type ParamGroupConfig,
+ type SchedulerStateDict,
+ type StateDict,
+ type StepConfig,
+ Tensor
+} from '@piston-ml/piston-web';
+
+/**
+ * Recursively walks an object and extracts any Tensor values into `out` as Buffers.
+ * Replaces extracted tensors in the returned structure with a small marker object
+ * containing the tensor storage key that was used in `out`.
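+ *
+ * @example
+ * // Sketch of the transformation (key names are illustrative):
+ * // in:  { expAvg: Tensor }  with baseKey 'optimizer.state.0'
+ * // out: { expAvg: { __tensor__: 'optimizer.state.0.expAvg' } },
+ * //      and out['optimizer.state.0.expAvg'] now holds a piston.Buffer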
+ */
+export function splitTensorsFromObject(
+ value: unknown,
+ baseKey: string,
+ out: Record<string, piston.Buffer>
+): unknown {
+ if (value instanceof Tensor) {
+ out[baseKey] = new piston.Buffer(value, true);
+ return { __tensor__: baseKey };
+ }
+ if (Array.isArray(value)) {
+ return value.map((v, i) => splitTensorsFromObject(v, `${baseKey}.${i}`, out));
+ }
+ if (value && typeof value === 'object') {
+ const result: Record<string, unknown> = {};
+ for (const [k, v] of Object.entries(value)) {
+ result[k] = splitTensorsFromObject(v, `${baseKey}.${k}`, out);
+ }
+ return result;
+ }
+ return value;
+}
+
+export type AnySchedulerState = SchedulerStateDict<
+ StepConfig | CosineAnnealingConfig | ExponentialConfig | LinearConfig | ConstantConfig | unknown
+>;
+
+export type CheckpointOptimizerExtra = {
+ name: string;
+ // JSON with __tensor__ markers
+ state: unknown;
+ paramGroups: ParamGroupConfig[];
+} | null;
+
+export interface CheckpointDataState {
+ blockSize: number;
+ shardIndex: number;
+ cursor: number;
+}
+
+export interface CheckpointExtra {
+ config: Config;
+ optimizer: CheckpointOptimizerExtra;
+ numSteps: number;
+ lrScheduler?: { state: AnySchedulerState };
+ dataState?: CheckpointDataState;
+ // Optional wall-clock training start time in ms to persist across restarts
+ startTimeMs?: number;
+}
+
+/**
+ * Builds a checkpoint payload by combining model parameters with optimizer state.
+ * - Model parameters/buffers go into `tensors` directly
+ * - Any Tensor values found inside optimizer.stateDict().state are lifted into `tensors`
+ * under keys prefixed with `optimizer/state/...`
+ * - Extra contains { config, optimizer, numSteps }
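+ *
+ * @example
+ * // Sketch of the resulting layout (key names are illustrative):
+ * // tensors: { 'wte.weight': Tensor, ..., 'optimizer.state.0.expAvg': Buffer }
+ * // extra:   { config, optimizer: { name, state, paramGroups }, numSteps, ... }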
+ */
+export function buildCheckpoint(
+ model: piston.Module,
+ optimizer: Optimizer,
+ numSteps: number,
+ configForExtra: Config,
+ scheduler?: LRScheduler,
+ dataState?: CheckpointDataState,
+ startTimeMs?: number
+): { tensors: Record<string, Tensor | piston.Buffer>; extra: CheckpointExtra } {
+ const tensors: Record<string, Tensor | piston.Buffer> = model.stateDict();
+
+ let optimizerExtra: CheckpointOptimizerExtra = null;
+ try {
+ const name = optimizer.constructor.name ?? 'Optimizer';
+ const packed = optimizer.stateDict();
+ const tensorSlots: Record<string, piston.Buffer> = {};
+ const jsonState = splitTensorsFromObject(packed.state, 'optimizer.state', tensorSlots);
+ Object.assign(tensors, tensorSlots);
+ optimizerExtra = {
+ name,
+ state: jsonState,
+ paramGroups: packed.paramGroups
+ };
+ } catch (e) {
+ console.warn('Failed to pack optimizer stateDict for checkpoint extra:', e);
+ }
+
+ const extra: CheckpointExtra = {
+ config: configForExtra,
+ optimizer: optimizerExtra,
+ numSteps,
+ lrScheduler: scheduler ? { state: scheduler.stateDict() } : undefined,
+ dataState,
+ startTimeMs
+ };
+
+ return { tensors, extra };
+}
+
+/**
+ * Replace any marker objects of the form { __tensor__: key } inside a JSON structure
+ * with actual Tensors from the provided mapping.
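+ *
+ * @example
+ * // Sketch: the inverse of `splitTensorsFromObject` (key name illustrative):
+ * // { expAvg: { __tensor__: 'optimizer.state.0.expAvg' } } -> { expAvg: Tensor }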
+ */
+export function rehydrateTensorsInObject<T>(value: unknown, lifted: Record<string, Tensor>): T {
+ if (value && typeof value === 'object' && !Array.isArray(value)) {
+ const marker = value as { __tensor__?: string };
+ if (typeof marker.__tensor__ === 'string') {
+ const key = marker.__tensor__;
+ if (!(key in lifted)) {
+ throw new Error(`Missing lifted tensor for key '${key}' during optimizer rehydration`);
+ }
+ return lifted[key] as unknown as T;
+ }
+ const out: Record<string, unknown> = {};
+ for (const [k, v] of Object.entries(value)) {
+ out[k] = rehydrateTensorsInObject(v, lifted);
+ }
+ return out as unknown as T;
+ }
+ if (Array.isArray(value)) {
+ return value.map((v) => rehydrateTensorsInObject(v, lifted)) as unknown as T;
+ }
+ return value as T;
+}
+
+export interface SplitLoadedStateResult {
+ modelState: Record<string, Tensor>;
+ schedulerState?: AnySchedulerState;
+ optimizerState: StateDict;
+ numSteps: number;
+ config: Config;
+ dataState?: CheckpointDataState;
+ startTimeMs?: number;
+}
+
+/**
+ * Given loaded state from piston.load, split out model state from lifted optimizer tensors
+ * and rehydrate optimizer and scheduler states from extras.
+ */
+export function splitLoadedState(loaded: {
+ state: Record<string, Tensor>;
+ extra?: CheckpointExtra;
+}): SplitLoadedStateResult {
+ const prefix = 'optimizer.state';
+ const liftedOptimizerTensors: Record<string, Tensor> = {};
+ const modelState: Record<string, Tensor> = {};
+
+ for (const [key, t] of Object.entries(loaded.state)) {
+ if (key.startsWith(prefix)) {
+ liftedOptimizerTensors[key] = t;
+ } else {
+ modelState[key] = t;
+ }
+ }
+
+ let optimizerState: StateDict | undefined;
+ let schedulerState: AnySchedulerState | undefined = undefined;
+ let numSteps = 0;
+ let config: Config | null = null;
+ let dataState: CheckpointDataState | undefined = undefined;
+ let startTimeMs: number | undefined = undefined;
+
+ const { extra } = loaded;
+
+ if (extra) {
+ config = extra.config;
+ numSteps = extra.numSteps;
+ if (extra.optimizer) {
+ const rehydratedState = rehydrateTensorsInObject<Record<string, OptimizerParamState>>(
+ extra.optimizer.state,
+ liftedOptimizerTensors
+ );
+ optimizerState = {
+ state: rehydratedState,
+ paramGroups: extra.optimizer.paramGroups ?? []
+ };
+ }
+ if (extra.lrScheduler && extra.lrScheduler.state) {
+ schedulerState = extra.lrScheduler.state;
+ }
+ if (extra.dataState) {
+ dataState = extra.dataState;
+ }
+ if (typeof extra.startTimeMs === 'number') {
+ startTimeMs = extra.startTimeMs;
+ }
+ }
+
+ if (!config) {
+ throw new Error('No config found in checkpoint');
+ }
+
+ if (numSteps == null) {
+ throw new Error('No numSteps found in checkpoint');
+ }
+
+ if (!optimizerState) {
+ throw new Error('No optimizer state found in checkpoint');
+ }
+
+ // Some runs don't use a scheduler, so we don't validate that it's present
+
+ return { modelState, optimizerState, schedulerState, numSteps, config, dataState, startTimeMs };
+}
diff --git a/examples/finetuning/src/lib/train/utils/init.ts b/examples/finetuning/src/lib/train/utils/init.ts
new file mode 100644
index 00000000..974695f3
--- /dev/null
+++ b/examples/finetuning/src/lib/train/utils/init.ts
@@ -0,0 +1,54 @@
+// import type { Config, ProjectionInitializationConfig } from '$lib/workspace/config';
+
+// import { initNormal_, initOnes_, initZeros_, nn } from '@piston-ml/piston-web';
+
+// TODO: Setup initialization for GPT2 lora
+
+// export function initTransformerParameters(self: nn.Module, config: Config): void {
+// const initializationConfig = config.model.transformer.initialization;
+
+// if (!initializationConfig.present) {
+// return;
+// }
+
+// const initTransformerWeights = (module: nn.Module): void => {
+// if (module instanceof nn.Linear) {
+// initNormal_(module.weight, { mean: 0.0, std: initializationConfig.std });
+// if (module.bias != null) {
+// initZeros_(module.bias);
+// }
+// } else if (module instanceof nn.Embedding) {
+// initNormal_(module.weight, { mean: 0.0, std: initializationConfig.std });
+// } else if (module instanceof nn.LayerNorm) {
+// if (module.bias) {
+// initZeros_(module.bias);
+// }
+// initOnes_(module.weight);
+// }
+// };
+
+// const initProjection = (p: nn.Parameter, projectionConfig: ProjectionInitializationConfig) => {
+// if (!projectionConfig.present) {
+// return;
+// }
+// const nLayers = config.model.layers;
+// if (projectionConfig.strategy === 'layer-scaled') {
+// initNormal_(p, { mean: 0.0, std: 0.02 / Math.sqrt(2 * nLayers) });
+// } else if (projectionConfig.strategy === 'zero') {
+// initZeros_(p);
+// }
+// };
+
+// self.apply(initTransformerWeights);
+// for (const [pn, p] of self.namedParameters()) {
+// if (pn.endsWith('cProj.weight')) {
+// initProjection(p, initializationConfig.projections.attention);
+// }
+// if (pn.endsWith('downProj.weight')) {
+// initProjection(p, initializationConfig.projections.mlp);
+// }
+// if (pn.endsWith('lmHead.weight')) {
+// initProjection(p, initializationConfig.projections.lmHead);
+// }
+// }
+// }
diff --git a/examples/finetuning/src/lib/train/utils/model.ts b/examples/finetuning/src/lib/train/utils/model.ts
new file mode 100644
index 00000000..289b7a81
--- /dev/null
+++ b/examples/finetuning/src/lib/train/utils/model.ts
@@ -0,0 +1,134 @@
+import {
+ CaptureIndexMode,
+ DataLoader,
+ type IndexState,
+ Module,
+ Tensor,
+ weak
+} from '@piston-ml/piston-web';
+import * as piston from '@piston-ml/piston-web';
+
+import type { Config } from '../../workspace/config';
+import type { NaturalBatchType, NaturalCollateFnType, NaturalDataloaderType } from '../types';
+
+import { type CollateWrapFunction } from '../data';
+import { naturalLanguageAutoregressiveCollate, NaturalLanguageDataset } from '../data/natural';
+import { buildGPT2Config } from '../model/config';
+import { GPT, GPT2_BLOCK_SIZE, GPT2_VOCAB_SIZE } from '../model/gpt';
+import { parseSeed } from './random';
+
+// Overloads for strong typing based on dataset kind
+export function createCollateFn(
+ wrapFunction?: CollateWrapFunction | null
+): NaturalCollateFnType {
+ const collateOptions = wrapFunction !== undefined ? { wrapFunction } : {};
+ return (batch: number[][]) =>
+ naturalLanguageAutoregressiveCollate(batch as number[][], {
+ ...collateOptions
+ });
+}
+
+export function createDataloader(
+ config: Config,
+ dataset: NaturalLanguageDataset,
+ wrapFunction?: CollateWrapFunction | null
+): [NaturalDataloaderType, NaturalCollateFnType] {
+ const collateFn = createCollateFn(wrapFunction);
+ return [
+ new DataLoader<number[], NaturalBatchType>(dataset, {
+ collateFn,
+ batchSize: config.training.batchSize
+ }),
+ collateFn
+ ];
+}
+
+/**
+ * Create a model instance based on the configuration
+ */
+export function createModel(config: Config): GPT {
+ return new GPT(buildGPT2Config(config.model.type), config);
+}
+
+export function calculateParameterSum(model: Module): Tensor {
+ const sums = model.parameters().map((param) => param.sum());
+ return piston.stack(sums).sum();
+}
+
+export function countParameters(model: GPT): number {
+ let totalParams = 0;
+
+ // Walk through all named parameters
+ for (const [_, param] of model.namedParameters()) {
+ if (param && param.shape) {
+ const paramCount = (param.shape as number[]).reduce(
+ (acc: number, dim: number) => acc * dim,
+ 1
+ );
+ totalParams += paramCount;
+ }
+ }
+
+ return totalParams;
+}
+
+/**
+ * Inspect model for a given configuration: count the number of parameters and capture an "index"
+ * of the model.
+ */
+export function inspectModel(config: Config): {
+ parameterCount: number;
+ hiddenSize: number;
+ mlpIntermediateSize: number;
+ modelIndex: IndexState;
+ vocabSize: number;
+ blockSize: number;
+} {
+ return weak(
+ () => {
+ const blockSize = GPT2_BLOCK_SIZE;
+ const vocabSize = GPT2_VOCAB_SIZE;
+ const model = createModel(config);
+ const parameterCount = countParameters(model);
+ const hiddenSize = model.config.nEmbd;
+ const mlpIntermediateSize = hiddenSize * 4;
+
+ let indexMode: CaptureIndexMode | null = null;
+ try {
+ indexMode = new CaptureIndexMode(model);
+
+ // Run the model forward with an input from the dataloader
+ model.forward(piston.zeros([1, blockSize], { dtype: piston.int32 }));
+
+ console.debug(`Model has ${parameterCount} parameters with vocab size ${vocabSize}`);
+
+ return {
+ parameterCount,
+ hiddenSize,
+ mlpIntermediateSize,
+ vocabSize,
+ blockSize,
+ modelIndex: indexMode!.index
+ };
+ } finally {
+ indexMode?.[Symbol.dispose]();
+ }
+ },
+ {
+ label: 'inspectModel'
+ }
+ );
+}
+
+export function seedPiston(config: Config) {
+ // Set up random number generator
+ const seed = parseSeed(
+ config.training.randomSeed.present ? config.training.randomSeed.value : undefined
+ );
+
+ if (seed !== undefined) {
+ piston.seed(seed);
+ }
+
+ return seed;
+}
diff --git a/examples/finetuning/src/lib/train/utils/modes.ts b/examples/finetuning/src/lib/train/utils/modes.ts
new file mode 100644
index 00000000..4be5a4e7
--- /dev/null
+++ b/examples/finetuning/src/lib/train/utils/modes.ts
@@ -0,0 +1,96 @@
+import {
+ PistonFunctionMode,
+ PistonMarkStepMode,
+ WeakMarkStepMode,
+ WeakTensorFunctionMode,
+ type WeakTensorFunctionModeOptions
+} from '@piston-ml/piston-web';
+import * as piston from '@piston-ml/piston-web';
+
+export class DebugMode extends PistonFunctionMode {
+ constructor(public debugEnabled: boolean) {
+ super();
+ }
+
+ _pistonFunction<T, FT extends T>(
+ func: (...args: unknown[]) => FT | Promise<FT>,
+ _types: unknown[],
+ args: unknown[],
+ kwargs: Record<string, unknown>
+ ): T | Promise<T> {
+ if (this.debugEnabled) {
+ console.log(
+ func.name,
+ args.reduce((acc: number[], a) => {
+ if (a instanceof piston.wasm.Tensor_wasm) {
+ return [...acc, a.id];
+ }
+ return acc;
+ }, [])
+ );
+ }
+
+ const after = (result: T) => {
+ if (result instanceof piston.wasm.Tensor_wasm) {
+ console.log(func.name, 'result', result.id);
+ }
+ return result;
+ };
+
+ const result = func(...args, kwargs) as T | Promise<T>;
+ if (result instanceof Promise) {
+ return result.then(after) as Promise<T>;
+ }
+
+ return after(result) as T;
+ }
+}
+
+export class WeakModeIfEnabled {
+ private mode: WeakTensorFunctionMode | null = null;
+
+ constructor(
+ public enabled: boolean,
+ public options: WeakTensorFunctionModeOptions
+ ) {
+ if (enabled) {
+ this.mode = new WeakTensorFunctionMode(options);
+ }
+ }
+
+ markWeak<T>(input: T) {
+ if (this.mode) {
+ this.mode.markWeak(input);
+ }
+ return input;
+ }
+
+ pin<T>(input: T) {
+ if (this.mode) {
+ this.mode.pin(input);
+ }
+ return input;
+ }
+
+ [Symbol.dispose]() {
+ if (this.mode) {
+ this.mode[Symbol.dispose]();
+ }
+ }
+}
+
+export class MarkStepModeIfEnabled {
+ private mode: PistonMarkStepMode | null = null;
+
+ constructor(public enabled: boolean) {
+ if (enabled) {
+ this.mode = new WeakMarkStepMode();
+ }
+ }
+
+ [Symbol.dispose]() {
+ if (this.mode) {
+ this.mode[Symbol.dispose]();
+ }
+ }
+}
diff --git a/examples/finetuning/src/lib/train/utils/optim.ts b/examples/finetuning/src/lib/train/utils/optim.ts
new file mode 100644
index 00000000..855f40d8
--- /dev/null
+++ b/examples/finetuning/src/lib/train/utils/optim.ts
@@ -0,0 +1,381 @@
+import type { OptimizerConfig } from '$lib/workspace/config';
+
+import {
+ AdamW,
+ type Device,
+ type Module,
+ MuonWithAdamW,
+ type MuonWithAdamWParamGroup,
+ nn,
+ type Optimizer,
+ type Parameter as ParameterType,
+ type ParamGroup,
+ SGD
+} from '@piston-ml/piston-web';
+
+// Deterministic sorting helpers
+function compareByName<T>(a: [string, T], b: [string, T]): number {
+ return a[0] < b[0] ? -1 : a[0] > b[0] ? 1 : 0;
+}
+
+function sortEntriesByName<T>(entries: Array<[string, T]>): Array<[string, T]> {
+ return entries.sort(compareByName);
+}
+
+function paramsFromNamesSorted(
+ names: Iterable<string>,
+ paramDict: Map<string, ParameterType>
+): ParameterType[] {
+ return Array.from(names)
+ .sort()
+ .map((name) => paramDict.get(name)!)
+ .filter((p) => p != null);
+}
+
+/**
+ * Validates that all model parameters are included in the parameter groups.
+ * Throws an error if any parameters are missing from the groups.
+ * @param model - The model to validate
+ * @param paramGroups - The parameter groups to check
+ * @throws Error if any model parameters are not included in the parameter groups
+ */
+function validateParameterGroups(
+ model: Module,
+ paramGroups: ParamGroup[],
+ paramDict: Map<string, ParameterType>
+): void {
+ // Get all parameters from the model
+ const allModelParams = new Set<ParameterType>();
+ for (const [_, param] of model.namedParameters()) {
+ allModelParams.add(param);
+ }
+
+ // Get all parameters from the parameter groups
+ const groupParams = new Set<ParameterType>();
+ for (const group of paramGroups) {
+ for (const param of group.params) {
+ groupParams.add(param);
+ }
+ }
+
+ // Find parameters that are in the model but not in any group
+ const missingParams: ParameterType[] = [];
+ for (const param of allModelParams) {
+ if (!groupParams.has(param)) {
+ missingParams.push(param);
+ }
+ }
+
+ if (missingParams.length > 0) {
+ // Find the names of the missing parameters using paramDict
+ const missingParamNames: string[] = [];
+ for (const [name, param] of paramDict) {
+ if (missingParams.includes(param)) {
+ missingParamNames.push(name);
+ }
+ }
+
+ throw new Error(
+ `Found ${missingParams.length} model parameters that are not included in any parameter group (${groupParams.size} included). ` +
+ `All model parameters must be assigned to a parameter group for training. ` +
+ `Missing parameters: ${missingParamNames.join(', ')}`
+ );
+ }
+}
+
+function getWeightDecayParams(
+ model: Module,
+ useWeightDecayGroups: boolean,
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ whitelistWeightModules: (new (...args: any[]) => Module)[],
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ blacklistWeightModules: (new (...args: any[]) => Module)[]
+): { paramDict: Map<string, ParameterType>; decay: Set<string>; noDecay: Set<string> } {
+ const decay = new Set<string>();
+ const noDecay = new Set<string>();
+
+ const paramDict = new Map<string, ParameterType>();
+
+ for (const [mn, m] of model.namedModules()) {
+ for (const [pn, p] of m.namedParameters()) {
+ const fpn = mn ? `${mn}.${pn}` : pn;
+
+ paramDict.set(fpn, p);
+
+ if (useWeightDecayGroups) {
+ if (pn.endsWith('bias')) {
+ // All biases will not be decayed
+ noDecay.add(fpn);
+ } else if (pn.endsWith('weight')) {
+ if (whitelistWeightModules.some((cls) => m instanceof cls)) {
+ // Weights of whitelist modules will be weight decayed
+ decay.add(fpn);
+ } else if (blacklistWeightModules.some((cls) => m instanceof cls)) {
+ // Weights of blacklist modules will NOT be weight decayed
+ noDecay.add(fpn);
+ }
+ } else {
+ // Parameters that are not weights or biases (shouldn't exist in std models)
+ // Add to decay by default, adjust if necessary for specific models.
+ decay.add(fpn);
+ }
+ } else {
+ decay.add(fpn);
+ }
+ }
+ }
+
+ if (useWeightDecayGroups) {
+ // Validate that we considered every parameter
+ const allParamNames = new Set(
+ Array.from(model.namedParameters()).map(([name]) => name)
+ );
+ const interParams = new Set([...decay].filter((x) => noDecay.has(x)));
+ const unionParams = new Set([...decay, ...noDecay]);
+
+ if (interParams.size !== 0) {
+ throw new Error(
+ `Parameters ${JSON.stringify(Array.from(interParams))} made it into both decay/noDecay sets`
+ );
+ }
+ const missingParams = new Set([...allParamNames].filter((x) => !unionParams.has(x)));
+ if (missingParams.size !== 0) {
+ throw new Error(
+ `Parameters ${JSON.stringify(
+ Array.from(missingParams)
+ )} were not separated into either decay/noDecay set`
+ );
+ }
+ }
+
+ return { paramDict, decay, noDecay };
+}
+
+/**
+ * Configures the optimizer based on the training configuration, following minGPT:
+ * parameters are separated into weight-decay and no-weight-decay groups.
+ * @param model - The model whose parameters are being optimized
+ * @param moduleLayersPrefixes - Name prefixes of the transformer-block parameters (used for Muon)
+ * @param lmHeadPrefix - Name prefix of the language-model head parameters (used for Muon)
+ * @param trainConfig - The optimizer configuration
+ * @param device - The computation device
+ * @returns The configured optimizer
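+ *
+ * @example
+ * // Sketch; the prefix arguments are illustrative, not this demo's actual parameter names:
+ * // const optimizer = configureOptimizers(model, ['transformer.h.'], 'lmHead', config.optimizer, device);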
+ */
+export function configureOptimizers(
+ model: Module,
+ moduleLayersPrefixes: string[],
+ lmHeadPrefix: string,
+ trainConfig: OptimizerConfig,
+ device: Device
+): Optimizer {
+ const whitelistWeightModules = [nn.Linear];
+ const blacklistWeightModules = [nn.LayerNorm, nn.RMSNorm, nn.Embedding];
+
+ const effectiveWeightDecay = trainConfig.weightDecay.present
+ ? trainConfig.weightDecay.value
+ : 0.0;
+
+ const { paramDict, decay, noDecay } = getWeightDecayParams(
+ model,
+ trainConfig.weightDecay.useWeightDecayGroups,
+ whitelistWeightModules,
+ blacklistWeightModules
+ );
+ // Deterministic param lists by name
+ const decayParamsValues = paramsFromNamesSorted(decay, paramDict);
+ const noDecayParamsValues = paramsFromNamesSorted(noDecay, paramDict);
+
+ if (trainConfig.type === 'AdamW' || trainConfig.type === 'Adam' || trainConfig.type === 'SGD') {
+ const optimGroups: ParamGroup[] = [
+ {
+ params: decayParamsValues,
+ weightDecay: effectiveWeightDecay
+ },
+ ...(noDecayParamsValues.length > 0
+ ? [
+ {
+ params: noDecayParamsValues,
+ weightDecay: 0.0 // no decay
+ }
+ ]
+ : [])
+ ];
+
+ validateParameterGroups(model, optimGroups, paramDict);
+
+ // Create the AdamW optimizer
+ if (trainConfig.type === 'AdamW' || trainConfig.type === 'Adam') {
+ return new AdamW(optimGroups, device, {
+ lr: trainConfig.lr,
+ betas: [trainConfig.adam.beta1, trainConfig.adam.beta2],
+ eps: trainConfig.adam.eps,
+ weightDecay: effectiveWeightDecay,
+ amsgrad: trainConfig.adam.amsgrad
+ });
+ } else if (trainConfig.type === 'SGD') {
+ return new SGD(optimGroups, device, {
+ lr: trainConfig.lr,
+ momentum: trainConfig.sgd.momentum,
+ dampening: trainConfig.sgd.dampening,
+ weightDecay: effectiveWeightDecay,
+ nesterov: trainConfig.sgd.nesterov
+ });
+ }
+ } else if (trainConfig.type === 'Muon') {
+ // Get parameter groups by type
+ const paramEntries = sortEntriesByName(Array.from(paramDict.entries()));
+ const moduleLayersParams = paramEntries.filter(([n]) =>
+ moduleLayersPrefixes.some((prefix) => n.startsWith(prefix))
+ );
+ // Sort each category deterministically by name
+ const hiddenMatrixParams = sortEntriesByName(
+ moduleLayersParams.filter(([n, p]) => p.ndim >= 2 && !n.toLowerCase().includes('embed'))
+ );
+ const scalarParams = sortEntriesByName(moduleLayersParams.filter(([_, p]) => p.ndim < 2));
+ const embedParams = sortEntriesByName(
+ paramEntries.filter(([n, _]) => n.toLowerCase().includes('embed'))
+ );
+ const headParams = sortEntriesByName(paramEntries.filter(([n]) => n.startsWith(lmHeadPrefix)));
+ // Any other params we just throw to AdamW
+ const filteredParams = new Set([
+ ...hiddenMatrixParams.map(([n]) => n),
+ ...scalarParams.map(([n]) => n),
+ ...embedParams.map(([n]) => n),
+ ...headParams.map(([n]) => n)
+ ]);
+ const remainingParams = paramEntries.filter(([n]) => !filteredParams.has(n));
+
+ if (remainingParams.length > 0) {
+ console.warn(
+ `Found ${remainingParams.length} parameters that don't fit Muon categorization and will be handled by AdamW:`,
+ remainingParams.map(([name]) => name)
+ );
+ }
+
+ // Apply weight decay grouping to each parameter type
+ const paramGroups: MuonWithAdamWParamGroup[] = [];
+
+ // Hidden matrix parameters for Muon optimizer
+ if (trainConfig.weightDecay.useWeightDecayGroups) {
+ const hiddenDecay = hiddenMatrixParams.filter(([name]) => decay.has(name)).map(([_, p]) => p);
+ const hiddenNoDecay = hiddenMatrixParams
+ .filter(([name]) => noDecay.has(name))
+ .map(([_, p]) => p);
+
+ if (hiddenDecay.length > 0) {
+ paramGroups.push({
+ optimizer: 'muon',
+ lr: trainConfig.lr,
+ weightDecay: effectiveWeightDecay,
+ momentum: trainConfig.muon.momentum,
+ nsSteps: trainConfig.muon.nsSteps,
+ nesterov: trainConfig.muon.nesterov,
+ params: hiddenDecay
+ });
+ }
+
+ if (hiddenNoDecay.length > 0) {
+ paramGroups.push({
+ optimizer: 'muon',
+ lr: trainConfig.lr,
+ weightDecay: 0.0, // no decay
+ momentum: trainConfig.muon.momentum,
+ nsSteps: trainConfig.muon.nsSteps,
+ nesterov: trainConfig.muon.nesterov,
+ params: hiddenNoDecay
+ });
+ }
+ } else {
+ if (hiddenMatrixParams.length > 0) {
+ paramGroups.push({
+ optimizer: 'muon',
+ lr: trainConfig.lr,
+ weightDecay: effectiveWeightDecay,
+ momentum: trainConfig.muon.momentum,
+ nsSteps: trainConfig.muon.nsSteps,
+ nesterov: trainConfig.muon.nesterov,
+ params: hiddenMatrixParams.map(([_, p]) => p)
+ });
+ }
+ }
+
+ // Scalar, embedding, and head parameters for AdamW optimizer
+ const adamwParams = sortEntriesByName([
+ ...scalarParams,
+ ...embedParams,
+ ...headParams,
+ ...remainingParams
+ ]);
+
+ // Ensure no parameter ended up in both the AdamW and Muon groups
+ const adamwParamSet = new Set(adamwParams.map(([n]) => n));
+ const muonParamSet = new Set(hiddenMatrixParams.map(([n]) => n));
+ const overlap = adamwParamSet.intersection(muonParamSet);
+ if (overlap.size > 0) {
+ throw new Error(
+ `Overlap between AdamW and Muon parameters: ${Array.from(overlap).join(', ')}`
+ );
+ }
+
+ if (trainConfig.weightDecay.useWeightDecayGroups) {
+ const adamwDecay = adamwParams.filter(([name]) => decay.has(name)).map(([_, p]) => p);
+ const adamwNoDecay = adamwParams.filter(([name]) => noDecay.has(name)).map(([_, p]) => p);
+
+ if (adamwDecay.length > 0) {
+ paramGroups.push({
+ optimizer: 'adamw',
+ lr: trainConfig.lr,
+ betas: [trainConfig.adam.beta1, trainConfig.adam.beta2],
+ eps: trainConfig.adam.eps,
+ weightDecay: effectiveWeightDecay,
+ amsgrad: trainConfig.adam.amsgrad,
+ params: adamwDecay
+ });
+ }
+
+ if (adamwNoDecay.length > 0) {
+ paramGroups.push({
+ optimizer: 'adamw',
+ lr: trainConfig.lr,
+ betas: [trainConfig.adam.beta1, trainConfig.adam.beta2],
+ eps: trainConfig.adam.eps,
+ weightDecay: 0.0, // no decay
+ amsgrad: trainConfig.adam.amsgrad,
+ params: adamwNoDecay
+ });
+ }
+ } else {
+ if (adamwParams.length > 0) {
+ paramGroups.push({
+ optimizer: 'adamw',
+ lr: trainConfig.lr,
+ betas: [trainConfig.adam.beta1, trainConfig.adam.beta2],
+ eps: trainConfig.adam.eps,
+ weightDecay: effectiveWeightDecay,
+ amsgrad: trainConfig.adam.amsgrad,
+ params: adamwParams.map(([_, p]) => p)
+ });
+ }
+ }
+
+ validateParameterGroups(model, paramGroups, paramDict);
+
+ return new MuonWithAdamW(paramGroups, device, {
+ muon: {
+ lr: trainConfig.lr,
+ weightDecay: effectiveWeightDecay,
+ momentum: trainConfig.muon.momentum,
+ nsSteps: trainConfig.muon.nsSteps,
+ nesterov: trainConfig.muon.nesterov
+ },
+ adamw: {
+ lr: trainConfig.lr,
+ betas: [trainConfig.adam.beta1, trainConfig.adam.beta2],
+ eps: trainConfig.adam.eps,
+ weightDecay: effectiveWeightDecay,
+ amsgrad: trainConfig.adam.amsgrad
+ }
+ });
+ }
+
+ throw new Error(`Unknown optimizer type: ${trainConfig.type}`);
+}
diff --git a/examples/finetuning/src/lib/train/utils/random.ts b/examples/finetuning/src/lib/train/utils/random.ts
new file mode 100644
index 00000000..03756479
--- /dev/null
+++ b/examples/finetuning/src/lib/train/utils/random.ts
@@ -0,0 +1,39 @@
+import { MersenneTwister19937, Random } from 'random-js';
+
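+/**
+ * Parses a seed string into a number. Numeric strings are parsed directly;
+ * any other non-empty string is reduced with a simple 32-bit hash.
+ * Returns undefined for an empty or missing seed.
+ */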
+export function parseSeed(seed?: string): number | undefined {
+ if (seed === undefined || seed === '') {
+ return undefined;
+ }
+
+ const parsed = parseInt(seed, 10);
+ if (!isNaN(parsed)) {
+ return parsed;
+ }
+
+ // Simple hash function for string
+ let hash = 0;
+ for (let i = 0; i < seed.length; i++) {
+ const char = seed.charCodeAt(i);
+ hash = (hash << 5) - hash + char;
+ hash = hash & hash; // Convert to 32bit integer
+ }
+ return Math.abs(hash);
+}
+
+/**
+ * Creates a seeded random number generator.
+ * If `seed` is undefined, auto-seeds the generator; otherwise seeds it with the
+ * given number. Use `parseSeed` to turn a seed string into a number first.
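+ *
+ * @example
+ * // Sketch: a deterministic generator from the config's seed string
+ * const rng = seededRandom(parseSeed('sequence toy'));
+ * rng.int32(); // same value on every run with this seed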
+ */
+export function seededRandom(seed?: number): Random {
+ if (seed === undefined) {
+ return new Random(MersenneTwister19937.autoSeed());
+ }
+
+ return new Random(MersenneTwister19937.seed(seed));
+}
+
+export function forkRandom(random: Random): Random {
+ return new Random(MersenneTwister19937.seed(random.int32()));
+}
diff --git a/examples/finetuning/src/lib/train/validation.ts b/examples/finetuning/src/lib/train/validation.ts
new file mode 100644
index 00000000..59335bcf
--- /dev/null
+++ b/examples/finetuning/src/lib/train/validation.ts
@@ -0,0 +1,149 @@
+import type { Config, ValidationConfig } from '$lib/workspace/config';
+import type { BaseStepData, TokenRollout } from '$lib/workspace/runs.svelte';
+
+import { type Tensor, weak } from '@piston-ml/piston-web';
+
+import type { GeneratableModel, NaturalCollateFnType } from './types';
+
+import { NaturalLanguageDataset } from './data/natural';
+import { generateDecoderCompletions } from './validationHelpers';
+
+export type ValidationStep = BaseStepData & {
+ type: 'validation';
+ completions: TokenRollout[];
+ samplingParams: {
+ temperature: number;
+ };
+ // Running average throughput across the generation in tokens/second (decoder/generative only)
+ tokensPerSecond?: number;
+ targets?: number[][]; // Only present in first step, what tokens should have been
+ encoderInputs?: number[][]; // Encoder/source inputs used (for encoder-decoder or encoder-only)
+ decoderPromptLengths?: number[]; // For decoder-only display: prompt token counts per example
+ matches?: boolean[][]; // Per-example, per-token correctness flags
+};
+
+export type NaturalValidationExamples = {
+ naturalSequences: number[][];
+};
+
+export function buildValidationExamplesSubset(
+ examples: NaturalValidationExamples,
+ subsetSize: number
+): NaturalValidationExamples {
+ return {
+ naturalSequences: examples.naturalSequences.slice(0, subsetSize)
+ };
+}
+
+export async function prepareNaturalValidationExamples(
+ config: Config,
+ dataset: NaturalLanguageDataset
+): Promise<NaturalValidationExamples> {
+ const naturalSequences: number[][] = [];
+
+ let count = 0;
+ for await (const sampleSequence of dataset) {
+ naturalSequences.push(sampleSequence);
+ count++;
+ if (count >= config.training.validation.batchSize) break;
+ }
+
+ return { naturalSequences };
+}
+
+export async function computeNaturalValidationMetrics(
+ model: GeneratableModel,
+ dataset: NaturalLanguageDataset,
+ valExamples: NaturalValidationExamples,
+ valConfig: ValidationConfig
+): Promise<Omit<ValidationStep, keyof BaseStepData>> {
+ const contextSize = dataset.contextSize;
+
+ // Fixed prompt length; this could instead scale with the context size,
+ // e.g. Math.max(Math.floor(contextSize / 4), 1)
+ const promptLen = 8;
+ const eosId = dataset.eosId as number | null;
+ const starts = valExamples.naturalSequences.map((seq) => seq.slice(0, promptLen));
+ const maxTokens = Math.max(0, contextSize - promptLen);
+ const result = await generateDecoderCompletions(model, starts, {
+ maxTokens,
+ stopTokens: eosId !== null ? [eosId] : [],
+ temperature: valConfig.temperature,
+ useKvCache: valConfig.useKvCache
+ });
+ const { completions, tokensPerSecond } = result;
+
+ const validationStepData: Omit<ValidationStep, keyof BaseStepData> = {
+ type: 'validation',
+ completions,
+ samplingParams: { temperature: valConfig.temperature },
+ decoderPromptLengths: new Array(valExamples.naturalSequences.length).fill(promptLen),
+ tokensPerSecond
+ };
+
+ return validationStepData;
+}
+
+export function buildValidationLog(
+ validationStepData: Omit<ValidationStep, keyof BaseStepData>
+): Record<string, number | Omit<ValidationStep, keyof BaseStepData>> {
+ // Aggregated numeric metrics from per-example metrics (currently none are collected)
+ const aggregatedNumeric: Record<string, number> = {};
+
+ // Compute number of unique completions by hashing tokenIds
+ const uniqueCompletionsCount = (() => {
+ const seen = new Set();
+ for (const c of validationStepData.completions ?? []) {
+ const ids = c?.tokenIds ?? [];
+ seen.add(ids.join(','));
+ }
+ return seen.size;
+ })();
+
+ const validationLog: Record<string, number | Omit<ValidationStep, keyof BaseStepData>> = {
+ 'validation/completions': validationStepData,
+ 'validation/unique_completions': uniqueCompletionsCount
+ };
+ if (typeof validationStepData.tokensPerSecond === 'number') {
+ validationLog['validation/tokens_per_second'] = validationStepData.tokensPerSecond;
+ }
+ for (const [k, v] of Object.entries(aggregatedNumeric)) {
+ validationLog[`validation/${k}`] = v;
+ }
+ return validationLog;
+}
+
+export async function computeLikelihoodMetrics(
+ model: GeneratableModel,
+ sequences: NaturalValidationExamples,
+ collateFn: NaturalCollateFnType
+): Promise<{ valLoss: number; perplexity: number }> {
+ return await weak(async () => {
+ model.eval();
+
+ let valLoss: number | null = null;
+ try {
+ const collated = collateFn(sequences.naturalSequences);
+
+ let loss: Tensor | null = null;
+ const [inputs, targets] = collated.tensors;
+ [, loss] = model.forward(await inputs.to('gpu'), {
+ targets: await targets.to('gpu')
+ });
+
+ if (!loss) {
+ throw new Error(`No loss tensor returned from decoder-only model during validation`);
+ }
+ valLoss = await (await loss.to('cpu')).item();
+ if (valLoss === null) {
+ throw new Error(`Validation loss item is null for decoder-only model`);
+ }
+ } finally {
+ model.train();
+ }
+
+ const perplexity = Math.exp(valLoss);
+ return { valLoss, perplexity };
+ });
+}
diff --git a/examples/finetuning/src/lib/train/validationHelpers.ts b/examples/finetuning/src/lib/train/validationHelpers.ts
new file mode 100644
index 00000000..3c62ee45
--- /dev/null
+++ b/examples/finetuning/src/lib/train/validationHelpers.ts
@@ -0,0 +1,60 @@
+import type { TokenRollout } from '$lib/workspace/runs.svelte';
+
+import { pin, weak } from '@piston-ml/piston-web';
+
+import { generateGPTStream } from './generate';
+import { GPT } from './model/gpt';
+
+export type DecoderGenerationOptions = {
+ maxTokens?: number;
+ stopTokens?: number[];
+ temperature: number;
+ useKvCache: boolean;
+};
+
+export async function generateDecoderCompletions(
+ model: GPT,
+ startSequences: number[][],
+ options: DecoderGenerationOptions
+): Promise<{ completions: TokenRollout[]; tokensPerSecond?: number }> {
+ const { maxTokens, stopTokens, temperature, useKvCache } = options;
+
+ model.eval();
+
+ const completions: TokenRollout[] = [];
+ let lastTPS: number | undefined;
+
+ for (let bi = 0; bi < startSequences.length; bi++) {
+ await weak(async () => {
+ const startTokens = startSequences[bi] ?? [];
+ let seq: number[] = [];
+ const perStepProbs: number[][] = [];
+ let stepIndex = 0;
+ for await (const generationResult of generateGPTStream(model, startTokens, {
+ maxTokens,
+ stopTokens: stopTokens ?? [],
+ temperature,
+ useKvCache
+ })) {
+ seq = generationResult.sequences[0] ? [...generationResult.sequences[0]] : [];
+ lastTPS = generationResult.tokensPerSecond ?? lastTPS;
+ if (generationResult.probs) {
+ const probsArray = await (await generationResult.probs.to('cpu')).toVec();
+ const [_b, v] = generationResult.probs.shape;
+ const row: number[] = [];
+ for (let vi = 0; vi < v; vi++) {
+ row.push(probsArray[0 * v + vi]);
+ }
+ perStepProbs.push(row);
+ }
+ stepIndex++;
+ if (typeof maxTokens === 'number' && stepIndex >= maxTokens) break;
+ }
+ completions[bi] = { tokenIds: seq, probs: pin(perStepProbs) };
+ });
+ }
+
+ model.train();
+
+ return { completions, tokensPerSecond: lastTPS };
+}
diff --git a/examples/finetuning/src/lib/workspace/checkpointStore.ts b/examples/finetuning/src/lib/workspace/checkpointStore.ts
new file mode 100644
index 00000000..558bec9e
--- /dev/null
+++ b/examples/finetuning/src/lib/workspace/checkpointStore.ts
@@ -0,0 +1,52 @@
+import { openDb, txRequest } from '$lib/dataUtils';
+
+const DB_NAME = 'piston-checkpoint-store';
+const DB_VERSION = 1;
+const STORE_CHECKPOINTS = 'checkpoints';
+
+export class CheckpointStore {
+ private dbPromise: Promise<IDBDatabase> | null = null;
+
+ private get db(): Promise<IDBDatabase> {
+ if (!this.dbPromise)
+ this.dbPromise = openDb(DB_NAME, DB_VERSION, (db) => {
+ if (!db.objectStoreNames.contains(STORE_CHECKPOINTS)) {
+ db.createObjectStore(STORE_CHECKPOINTS);
+ }
+ });
+ return this.dbPromise;
+ }
+
+ async get(runId: string): Promise<ArrayBuffer | undefined> {
+ const db = await this.db;
+ return txRequest<ArrayBuffer | undefined>(db, STORE_CHECKPOINTS, 'readonly', (s) =>
+ s.get(runId)
+ );
+ }
+
+ async set(runId: string, bytes: Uint8Array | ArrayBuffer): Promise<void> {
+ const db = await this.db;
+ // Copy the exact byte range, since `bytes` may be a view into a larger buffer
+ const buf =
+ bytes instanceof Uint8Array
+ ? bytes.buffer.slice(bytes.byteOffset, bytes.byteOffset + bytes.byteLength)
+ : bytes;
+ await txRequest(db, STORE_CHECKPOINTS, 'readwrite', (s) => s.put(buf, runId));
+ }
+
+ async has(runId: string): Promise<boolean> {
+ const db = await this.db;
+ const res = await txRequest<ArrayBuffer | undefined>(db, STORE_CHECKPOINTS, 'readonly', (s) =>
+ s.get(runId)
+ );
+ return res != null;
+ }
+
+ async delete(runId: string): Promise<void> {
+ const db = await this.db;
+ await txRequest(db, STORE_CHECKPOINTS, 'readwrite', (s) => s.delete(runId));
+ }
+
+ async clear(): Promise<void> {
+ const db = await this.db;
+ await txRequest(db, STORE_CHECKPOINTS, 'readwrite', (s) => s.clear());
+ }
+}
+
+export const checkpointStore = new CheckpointStore();
diff --git a/examples/finetuning/src/lib/workspace/config.svelte.ts b/examples/finetuning/src/lib/workspace/config.svelte.ts
new file mode 100644
index 00000000..738d7fd4
--- /dev/null
+++ b/examples/finetuning/src/lib/workspace/config.svelte.ts
@@ -0,0 +1,466 @@
+import type { Config } from '$lib/workspace/config';
+
+import { browser } from '$app/environment';
+import { buildDataset } from '$lib/train/data';
+import { getCollatedSampleData } from '$lib/train/data/collate';
+import { GPT2_BLOCK_SIZE, GPT2_VOCAB_SIZE } from '$lib/train/model/gpt';
+import { createDataloader } from '$lib/train/utils/model';
+import { SvelteURL, SvelteURLSearchParams } from 'svelte/reactivity';
+
+import { getCurrentRun, getLatestRun } from './runs.svelte';
+
+const CONFIG_DEFAULTS: Config = {
+ training: {
+ logSteps: 5,
+ batchSize: 1,
+ validation: {
+ present: true,
+ valSteps: 10,
+ batchSize: 8,
+ temperature: 0.0,
+ useKvCache: false,
+ completions: {
+ present: true,
+ decodingBatchSize: 1,
+ amount: 'subset',
+ subsetSize: 4
+ }
+ },
+ limitTraining: {
+ present: false,
+ steps: 50_000
+ },
+ labelSmoothing: {
+ present: false,
+ value: 1e-4
+ },
+ dropout: {
+ present: false,
+ embedding: 0.1,
+ transformer: {
+ attention: 0.1,
+ residual: 0.1
+ }
+ },
+ randomSeed: {
+ present: true,
+ value: 'sequence toy'
+ },
+ gradNorm: {
+ track: true,
+ errorIfNonfinite: true
+ },
+ clipGradNorm: {
+ present: false,
+ value: 1.0
+ },
+ useWeakTensorReferences: true,
+ sharedObjectAllocation: false,
+ cachingEnabled: false,
+ inplaceSupport: true,
+ vramLimitMb: {
+ present: true,
+ value: 4096
+ },
+ checkpointEverySteps: {
+ present: true,
+ value: 200
+ },
+ restartEverySteps: 1000
+ },
+ data: {
+ dataset: 'tinystories'
+ },
+ model: {
+ type: 'distilgpt2'
+ },
+ optimizer: {
+ type: 'Muon',
+ lr: 1e-3,
+ weightDecay: {
+ present: true,
+ value: 1e-2,
+ useWeightDecayGroups: true
+ },
+ warmupSteps: {
+ present: true,
+ value: 100
+ },
+ lrScheduler: {
+ present: true,
+ type: 'cosine',
+ stepSchedule: {
+ stepSize: 100,
+ gamma: 0.8
+ },
+ constantSchedule: {
+ factor: 1 / 3,
+ totalIters: 100
+ },
+ cosineAnnealingSchedule: {
+ tMax: 500,
+ etaMin: 1e-4
+ },
+ exponentialSchedule: {
+ gamma: 0.999
+ },
+ linearSchedule: {
+ startFactor: 1.0,
+ endFactor: 1 / 3,
+ totalIters: 1000
+ }
+ },
+ adam: {
+ beta1: 0.9,
+ beta2: 0.999,
+ eps: 1e-8,
+ amsgrad: false
+ },
+ sgd: {
+ momentum: 0.9,
+ dampening: 0,
+ nesterov: false
+ },
+ muon: {
+ momentum: 0.95,
+ nsSteps: 5,
+ nesterov: true
+ }
+ },
+ version: 1
+};
+
+function computeEffectiveDefaults(): Config {
+ return JSON.parse(JSON.stringify(CONFIG_DEFAULTS)) as Config;
+}
+
+/**
+ * Parses a value based on the type of the default value in the config. This is not wildly general,
+ * but it seems to work for the current config.
+ * @param valueStr - The value to parse.
+ * @param defaultValue - The default value.
+ * @returns The parsed value.
+ */
+function parseValueBasedOnDefault(valueStr: string, defaultValue: unknown): unknown {
+ if (typeof defaultValue === 'boolean') {
+ return valueStr.toLowerCase() === 'true';
+ }
+ if (typeof defaultValue === 'number') {
+ const num = parseFloat(valueStr);
+ return isNaN(num) ? defaultValue : num;
+ }
+ return valueStr; // Default to string if type is not boolean or number
+}
+
+/**
+ * Builds a partial config override object from URL search params, using the defaults
+ * object to resolve each path and infer each value's type.
+ * @param params - The URL search params.
+ * @param defaults - The defaults used to resolve paths and parse value types.
+ * @returns The partial config parsed from the URL.
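+ *
+ * @example
+ * // Sketch: '?optimizer.lr=0.01&training.batchSize=4' becomes
+ * // { optimizer: { lr: 0.01 }, training: { batchSize: 4 } }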
+ */
+function buildConfigFromUrlParams(params: URLSearchParams, defaults: Config): Partial<Config> {
+ const configFromUrl: Record<string, unknown> = {};
+
+ for (const [path, valueStr] of params) {
+ const keys = path.split('.');
+ let currentLevel = configFromUrl;
+ let currentDefaultsLevel: unknown = defaults;
+
+ try {
+ for (let i = 0; i < keys.length; i++) {
+ const key = keys[i];
+ if (
+ currentDefaultsLevel === undefined ||
+ typeof currentDefaultsLevel !== 'object' ||
+ currentDefaultsLevel === null
+ ) {
+ throw new Error(`Invalid config path from URL: ${path}`);
+ }
+ currentDefaultsLevel = (currentDefaultsLevel as Record<string, unknown>)[key];
+
+ if (i < keys.length - 1) {
+ if (
+ !(currentLevel as Record<string, unknown>)[key] ||
+ typeof (currentLevel as Record<string, unknown>)[key] !== 'object'
+ ) {
+ (currentLevel as Record<string, unknown>)[key] = {};
+ }
+ currentLevel = (currentLevel as Record<string, unknown>)[key] as Record<string, unknown>;
+ } else {
+ (currentLevel as Record<string, unknown>)[key] = parseValueBasedOnDefault(
+ valueStr,
+ currentDefaultsLevel
+ );
+ }
+ }
+ } catch (e) {
+ console.warn((e as Error).message);
+ continue; // Skip this parameter if path is invalid or type mismatch
+ }
+ }
+ return configFromUrl as Partial<Config>;
+}
+
+function mergeDeep(target: Record<string, unknown>, source: Record<string, unknown>) {
+ for (const key in source) {
+ if (Object.prototype.hasOwnProperty.call(source, key)) {
+ const sourceVal = source[key];
+ let targetKeyAsObject = target[key] as Record<string, unknown>;
+
+ if (sourceVal && typeof sourceVal === 'object' && !Array.isArray(sourceVal)) {
+ if (
+ !targetKeyAsObject ||
+ typeof targetKeyAsObject !== 'object' ||
+ Array.isArray(targetKeyAsObject)
+ ) {
+ targetKeyAsObject = {};
+ target[key] = targetKeyAsObject;
+ }
+ mergeDeep(targetKeyAsObject, sourceVal as Record<string, unknown>);
+ } else if (sourceVal !== undefined) {
+ target[key] = sourceVal;
+ }
+ }
+ }
+}
+
+/**
+ * Gets the initial config from the URL search params, or the defaults if no URL search params are
+ * present.
+ * @returns The initial config.
+ */
+function getInitialConfig(): Config {
+ // Start with effective defaults, possibly from URL 'preset'
+ let base: Config = JSON.parse(JSON.stringify(CONFIG_DEFAULTS));
+ if (typeof window !== 'undefined' && window.location && window.URLSearchParams) {
+ try {
+ const params = new URLSearchParams(window.location.search);
+ base = computeEffectiveDefaults();
+ const configOverrides = buildConfigFromUrlParams(params, base);
+ const initial = JSON.parse(JSON.stringify(base));
+ mergeDeep(initial, configOverrides);
+ return initial;
+ } catch (e) {
+ console.error('Error processing config from URL, using defaults:', e);
+ return JSON.parse(JSON.stringify(CONFIG_DEFAULTS));
+ }
+ }
+ return base;
+}
+
+export const config = $state(getInitialConfig());
+const configDefaults = $derived(computeEffectiveDefaults());
+
+/**
+ * Resets one or more config values to their defaults using dot-separated paths.
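+ *
+ * @example
+ * resetConfigToDefaults(['optimizer.lr', 'training.batchSize']);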
+ */
+export function resetConfigToDefaults(paths: string | string[]) {
+ const pathList = Array.isArray(paths) ? paths : [paths];
+
+ for (const path of pathList) {
+ const defaultValue = getValueAtPath(configDefaults as unknown as Record<string, unknown>, path);
+ if (defaultValue === undefined) {
+ console.warn(`resetConfigToDefaults: Unknown config path "${path}"`);
+ continue;
+ }
+ // Deep clone to avoid mutating the CONFIG_DEFAULTS reference
+ const cloned = deepClone(defaultValue);
+ const ok = setValueAtPath(config as unknown as Record<string, unknown>, path, cloned);
+ if (!ok) {
+ console.warn(`resetConfigToDefaults: Failed to set value for path "${path}"`);
+ }
+ }
+}
+
+function deepClone<T>(value: T): T {
+ return JSON.parse(JSON.stringify(value)) as T;
+}
+
+export function getConfigDefaultValue(path: string): unknown {
+ const val = getValueAtPath(configDefaults as unknown as Record<string, unknown>, path);
+ return deepClone(val);
+}
+
+export function equalsConfigDefault(path: string): boolean {
+ const current = getValueAtPath(config as unknown as Record<string, unknown>, path);
+ const def = getValueAtPath(configDefaults as unknown as Record<string, unknown>, path);
+ return valuesDeepEqual(current, def);
+}
+
+function valuesDeepEqual(a: unknown, b: unknown): boolean {
+ try {
+ return JSON.stringify(a) === JSON.stringify(b);
+ } catch {
+ return a === b;
+ }
+}
+
+function getValueAtPath(obj: Record<string, unknown>, path: string): unknown {
+ const keys = path.split('.');
+ let current: unknown = obj;
+ for (const key of keys) {
+ if (
+ current === null ||
+ current === undefined ||
+ typeof current !== 'object' ||
+ !(key in (current as Record<string, unknown>))
+ ) {
+ return undefined;
+ }
+ current = (current as Record<string, unknown>)[key];
+ }
+ return current;
+}
+
+function setValueAtPath(target: Record<string, unknown>, path: string, value: unknown): boolean {
+ const keys = path.split('.');
+ let current: Record<string, unknown> = target;
+ for (let i = 0; i < keys.length - 1; i++) {
+ const key = keys[i];
+ const next = current[key];
+ if (next === undefined || next === null || typeof next !== 'object' || Array.isArray(next)) {
+ // Only create an object if we are not overwriting a non-object path
+ current[key] = {};
+ }
+ current = current[key] as Record<string, unknown>;
+ }
+ const lastKey = keys[keys.length - 1];
+ current[lastKey] = value as unknown;
+ return true;
+}
+
+/**
+ * Flattens only the non-default values from an object by comparing against a defaults object.
+ * Returns a map of dot-separated paths to stringified values.
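+ *
+ * @example
+ * // Sketch: if only the learning rate differs from the defaults, the result
+ * // is { 'optimizer.lr': '0.01' }, with values stringified for URL params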
+ */
+function flattenNonDefault(
+ obj: Record<string, unknown>,
+ defaults: Record<string, unknown>,
+ prefix: string = ''
+): Record<string, string> {
+ const params: Record<string, string> = {};
+ for (const key in obj) {
+ if (!Object.prototype.hasOwnProperty.call(obj, key)) continue;
+ const newPrefix = prefix ? `${prefix}.${key}` : key;
+ const value = obj[key];
+ const defaultValue = (defaults ?? {})[key];
+
+ if (value !== null && typeof value === 'object' && !Array.isArray(value)) {
+ const defaultChild =
+ defaultValue !== null && typeof defaultValue === 'object' && !Array.isArray(defaultValue)
+ ? (defaultValue as Record<string, unknown>)
+ : ({} as Record<string, unknown>);
+ const nested = flattenNonDefault(value as Record<string, unknown>, defaultChild, newPrefix);
+ Object.assign(params, nested);
+ } else if (value !== undefined) {
+ if (defaultValue === undefined || !valuesDeepEqual(value, defaultValue)) {
+ params[newPrefix] = String(value);
+ }
+ }
+ }
+ return params;
+}
+
+export function initSharedConfigUrlSync() {
+ if (typeof window !== 'undefined' && window.history && window.URL) {
+ $effect(() => {
+ const configSnapshot = $state.snapshot(config);
+ const flatParams = flattenNonDefault(
+ configSnapshot,
+ configDefaults as unknown as Record<string, unknown>
+ );
+ // If any parameters are present, also include the current config version
+ if (Object.keys(flatParams).length > 0) {
+ flatParams['version'] = String(configSnapshot.version);
+ }
+ const searchParamsString = new SvelteURLSearchParams(flatParams).toString();
+
+ const currentUrl = new SvelteURL(window.location.href);
+ currentUrl.search = searchParamsString; // This replaces the entire search string
+
+ // Only call replaceState if the URL actually changed to avoid flooding history
+ if (window.location.href !== currentUrl.href) {
+ window.history.replaceState({}, '', currentUrl.toString());
+ }
+ });
+ }
+}
+
+export function replaceConfig(next: Config) {
+ config['training'] = deepClone(next.training);
+ config['data'] = deepClone(next.data);
+ config['model'] = deepClone(next.model);
+ config['optimizer'] = deepClone(next.optimizer);
+ config['version'] = next.version;
+ validateConfig();
+}
+
+function datasetFromConfig(config: Config) {
+ // Only build the full dataset/tokenizer pipeline in the browser. During SSR/prerender
+ // we return a lightweight placeholder object so that components can render without
+ // triggering network fetches (which would use Node's global fetch with a relative URL).
+ if (!browser) {
+ return {
+ dataset: null,
+ tokenizer: null,
+ sampleData: null,
+ collated: null
+ };
+ }
+
+ const dataset = buildDataset(config, 'train');
+ const [, collateFn] = createDataloader(config, dataset, null);
+
+ const collatedData = getCollatedSampleData(dataset, collateFn, 4);
+
+ return {
+ dataset,
+ vocabSize: GPT2_VOCAB_SIZE,
+ blockSize: GPT2_BLOCK_SIZE,
+ tokenizer: dataset.tokenizer,
+ sampleData: collatedData.then((data) => {
+ const firstSample = data.collated[0];
+ return {
+ hasPrompt: 'prompt' in firstSample && (firstSample.prompt?.length ?? 0) > 0,
+ samples: data.samples,
+ collated: data.collated
+ };
+ })
+ };
+}
+
+const currentDataset = $derived(datasetFromConfig(config));
+
+export function getCurrentDataset() {
+ return currentDataset;
+}
+
+const currentRunDataset = $derived.by(() => {
+ const currentRun = getCurrentRun();
+ return currentRun?.config ? datasetFromConfig(currentRun.config) : null;
+});
+
+export function getCurrentRunDataset() {
+ return currentRunDataset;
+}
+
+const latestRunDataset = $derived.by(() => {
+ const latestRun = getLatestRun();
+ return latestRun?.config ? datasetFromConfig(latestRun.config) : null;
+});
+
+export function getLatestRunDataset() {
+ return latestRunDataset;
+}
+
+export function validateConfig() {
+ // There are a few things that can still slip through the cracks, so we deal with those here.
+
+ if (
+ config.training.validation.completions.present &&
+ config.training.validation.completions.amount === 'subset' &&
+ config.training.validation.completions.subsetSize > config.training.validation.batchSize
+ ) {
+ config.training.validation.completions.subsetSize = config.training.validation.batchSize;
+ }
+}
diff --git a/examples/finetuning/src/lib/workspace/config.ts b/examples/finetuning/src/lib/workspace/config.ts
new file mode 100644
index 00000000..792e42c3
--- /dev/null
+++ b/examples/finetuning/src/lib/workspace/config.ts
@@ -0,0 +1,283 @@
+import type { DATASET_CONFIG_DEFAULTS } from '$lib/train/data';
+import type {
+ ConstantConfig,
+ CosineAnnealingConfig,
+ ExponentialConfig,
+ LinearConfig,
+ StepConfig
+} from '@piston-ml/piston-web';
+
+export interface TransformerDropoutConfig {
+ present: boolean;
+ embedding: number;
+ transformer: {
+ attention: number;
+ residual: number;
+ };
+}
+
+export interface ValidationCompletionsConfig {
+ present: boolean;
+ decodingBatchSize: number;
+ amount: 'all' | 'subset';
+ subsetSize: number;
+}
+
+export interface ValidationConfig {
+ present: boolean;
+ valSteps: number;
+ batchSize: number;
+ temperature: number;
+ completions: ValidationCompletionsConfig;
+ useKvCache: boolean;
+}
+
+export interface TrainingConfig {
+ logSteps: number;
+ limitTraining: {
+ present: boolean;
+ steps: number;
+ };
+ checkpointEverySteps: {
+ present: boolean;
+ value: number;
+ };
+ batchSize: number;
+ dropout: TransformerDropoutConfig;
+ validation: ValidationConfig;
+ labelSmoothing: {
+ present: boolean;
+ value: number;
+ };
+ randomSeed: {
+ present: boolean;
+ value: string;
+ };
+ vramLimitMb: {
+ present: boolean;
+ value: number;
+ };
+ gradNorm: {
+ track: boolean;
+ errorIfNonfinite: boolean;
+ };
+ clipGradNorm: {
+ present: boolean;
+ value: number;
+ };
+ useWeakTensorReferences: boolean;
+ sharedObjectAllocation: boolean;
+ cachingEnabled: boolean;
+ inplaceSupport: boolean;
+ restartEverySteps: number;
+}
+
+export interface DataConfig {
+ dataset: keyof typeof DATASET_CONFIG_DEFAULTS;
+}
+
+export type ProjectionInitializationStrategy = 'layer-scaled' | 'zero';
+export interface ProjectionInitializationConfig {
+ present: boolean;
+ strategy: ProjectionInitializationStrategy;
+}
+
+export interface TransformerInitializationConfig {
+ present: boolean;
+ std: number;
+ projections: {
+ attention: ProjectionInitializationConfig;
+ mlp: ProjectionInitializationConfig;
+ lmHead: ProjectionInitializationConfig;
+ };
+}
+
+export interface LSTMInitializationConfig {
+ forgetGateBias: {
+ present: boolean;
+ value: number;
+ };
+}
+
+export type GPT2ModelType = 'distilgpt2' | 'gpt2' | 'gpt2-medium' | 'gpt2-large' | 'gpt2-xl';
+
+export interface ModelConfig {
+ type: GPT2ModelType;
+}
+
+export interface OptimizerConfig {
+ type: 'AdamW' | 'Adam' | 'SGD' | 'Muon';
+ lr: number;
+ weightDecay: {
+ present: boolean;
+ value: number;
+ useWeightDecayGroups: boolean;
+ };
+ warmupSteps: { present: boolean; value: number };
+ lrScheduler: {
+ present: boolean;
+ type: string;
+ stepSchedule: StepConfig;
+ constantSchedule: ConstantConfig;
+ cosineAnnealingSchedule: CosineAnnealingConfig;
+ exponentialSchedule: ExponentialConfig;
+ linearSchedule: LinearConfig;
+ };
+ adam: {
+ beta1: number;
+ beta2: number;
+ eps: number;
+ amsgrad: boolean;
+ };
+ sgd: {
+ dampening: number;
+ momentum: number;
+ nesterov: boolean;
+ };
+ muon: {
+ nsSteps: number;
+ momentum: number;
+ nesterov: boolean;
+ };
+}
+
+export interface Config {
+ version: number;
+ training: TrainingConfig;
+ data: DataConfig;
+ model: ModelConfig;
+ optimizer: OptimizerConfig;
+}
+
+export type ConfigItemDescription =
+ | {
+ shortName: string;
+ }
+ | string
+ | [string, number]
+ | null;
+
+type ReplaceValues<T, V> = T extends object ? { [K in keyof T]: ReplaceValues<T[K], V> } : V;
+
+export type ConfigValues = ReplaceValues<Config, ConfigItemDescription>;
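+
+// A hypothetical illustration of how ReplaceValues recurses (not part of the config itself):
+// every leaf of the source type is swapped for the replacement type, so
+//   type Example = ReplaceValues<{ a: { b: number }; c: string }, boolean>;
+//   //   ≅ { a: { b: boolean }; c: boolean }
+// which is why ConfigValues can mirror Config's nesting with description leaves.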
+
+export const CONFIG_DESCRIPTIONS: ConfigValues = {
+ training: {
+ logSteps: 'log steps',
+ batchSize: 'batch',
+ clipGradNorm: {
+ present: 'clip grad norm',
+ value: 'clip grad norm'
+ },
+ validation: {
+ present: 'val',
+ valSteps: 'val steps',
+ batchSize: 'val size',
+ temperature: 'val temp',
+ useKvCache: 'val kv cache',
+ completions: {
+ present: 'completions',
+ decodingBatchSize: 'completions batch',
+ amount: 'completions amount strategy',
+ subsetSize: 'completions subset'
+ }
+ },
+ limitTraining: {
+ present: 'limit train',
+ steps: 'max steps'
+ },
+ labelSmoothing: {
+ present: 'smoothing',
+ value: 'smoothing'
+ },
+ dropout: {
+ present: 'dropout',
+ embedding: 'dropout emb',
+ transformer: {
+ attention: 'dropout attn',
+ residual: 'dropout resid'
+ }
+ },
+ randomSeed: {
+ present: 'seed',
+ value: 'seed'
+ },
+ gradNorm: {
+ track: 'track grad norm',
+ errorIfNonfinite: 'error nonfinite'
+ },
+ useWeakTensorReferences: 'weak tensor refs',
+ sharedObjectAllocation: 'shared objs',
+ cachingEnabled: 'caching',
+ inplaceSupport: 'inplace',
+ vramLimitMb: {
+ present: 'vram lim',
+ value: 'vram lim'
+ },
+ checkpointEverySteps: {
+ present: 'checkpointing',
+ value: 'checkpoint steps'
+ },
+ restartEverySteps: 'restart steps'
+ },
+ data: {
+ dataset: 'dataset'
+ },
+ model: {
+ type: 'model'
+ },
+ optimizer: {
+ type: 'optim',
+ lr: 'lr',
+ weightDecay: {
+ present: 'decay',
+ value: 'decay',
+ useWeightDecayGroups: 'decay groups'
+ },
+ warmupSteps: {
+ present: 'warmup',
+ value: 'warmup steps'
+ },
+ lrScheduler: {
+ present: 'lr sched',
+ type: 'lr sched',
+ stepSchedule: {
+ stepSize: 'sched step',
+ gamma: 'sched gamma'
+ },
+ constantSchedule: {
+ factor: 'const lr factor',
+ totalIters: 'const lr total'
+ },
+ cosineAnnealingSchedule: {
+ tMax: 'cos lr tmax',
+ etaMin: 'cos lr eta min'
+ },
+ exponentialSchedule: {
+ gamma: 'exp lr gamma'
+ },
+ linearSchedule: {
+ startFactor: 'lin lr start',
+ endFactor: 'lin lr end',
+ totalIters: 'lin lr total'
+ }
+ },
+ adam: {
+ beta1: 'adam beta1',
+ beta2: 'adam beta2',
+ eps: 'adam eps',
+ amsgrad: 'adam ams'
+ },
+ sgd: {
+ momentum: 'sgd moment',
+ dampening: 'sgd damp',
+ nesterov: 'sgd nester'
+ },
+ muon: {
+ momentum: 'muon moment',
+ nsSteps: 'muon nssteps',
+ nesterov: 'muon nester'
+ }
+ },
+ version: null
+};
diff --git a/examples/finetuning/src/lib/workspace/lastSessionStore.ts b/examples/finetuning/src/lib/workspace/lastSessionStore.ts
new file mode 100644
index 00000000..b15ef4f4
--- /dev/null
+++ b/examples/finetuning/src/lib/workspace/lastSessionStore.ts
@@ -0,0 +1,115 @@
+import { openDb, txRequest } from '$lib/dataUtils';
+import { SvelteMap } from 'svelte/reactivity';
+
+import type { Config } from './config';
+import type { MetricData, RunData, StepData } from './runs.svelte';
+
+export type SavedRun = Omit<RunData, 'metrics'> & { metrics: Record<string, StepData[]> };
+
+const DB_NAME = 'piston-last-session-store';
+const DB_VERSION = 1;
+const STORE_SESSION = 'session';
+const STORE_CHECKPOINT = 'checkpoint';
+const STORE_META = 'meta';
+const META_LAST_RUN_ID_KEY = 'lastRunId';
+
+export function serializeRun(run: RunData): SavedRun {
+ const metrics = Object.fromEntries([...run.metrics.entries()].map(([k, v]) => [k, v.data]));
+ return { ...run, metrics };
+}
+
+export function deserializeRun(saved: SavedRun): RunData {
+ return {
+ ...saved,
+    metrics: new SvelteMap<string, MetricData>(
+ Object.entries(saved.metrics).map(([k, v]) => [k, { metricName: k, data: v }])
+ )
+ };
+}
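+
+// Sketch of the intended round trip (assumes a RunData value `run` from elsewhere):
+//   const saved = serializeRun(run);          // SvelteMap flattened to a plain Record
+//   const restored = deserializeRun(saved);   // Record re-wrapped in a reactive SvelteMap
+// Only `metrics` changes shape; all other RunData fields pass through untouched.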
+
+class LastSessionStore {
+  private dbPromise: Promise<IDBDatabase> | null = null;
+
+  private get db(): Promise<IDBDatabase> {
+ if (!this.dbPromise)
+ this.dbPromise = openDb(DB_NAME, DB_VERSION, (db) => {
+ if (!db.objectStoreNames.contains(STORE_SESSION)) db.createObjectStore(STORE_SESSION);
+ if (!db.objectStoreNames.contains(STORE_CHECKPOINT)) db.createObjectStore(STORE_CHECKPOINT);
+ if (!db.objectStoreNames.contains(STORE_META)) db.createObjectStore(STORE_META);
+ });
+ return this.dbPromise;
+ }
+
+  private async getLastRunId(db: IDBDatabase): Promise<string | undefined> {
+    return txRequest<string | undefined>(db, STORE_META, 'readonly', (s) =>
+      s.get(META_LAST_RUN_ID_KEY)
+    );
+ }
+
+ async get(): Promise<{ run: RunData; checkpoint: Uint8Array } | null> {
+ const db = await this.db;
+ const runId = await this.getLastRunId(db);
+ if (!runId) return null;
+ const [run, buffer] = await Promise.all([
+      txRequest<SavedRun | undefined>(db, STORE_SESSION, 'readonly', (s) => s.get(runId)),
+      txRequest<ArrayBuffer | undefined>(db, STORE_CHECKPOINT, 'readonly', (s) => s.get(runId))
+ ]);
+ if (!run || !buffer) return null;
+ return { run: deserializeRun(run), checkpoint: new Uint8Array(buffer) };
+ }
+
+  async set(run: RunData, checkpoint: Uint8Array | ArrayBuffer): Promise<void> {
+ const db = await this.db;
+ const buf = checkpoint instanceof Uint8Array ? (checkpoint.buffer as ArrayBuffer) : checkpoint;
+    // Clear any existing session first; we'll want to remove this if we ever support
+    // persisting multiple sessions.
+ await this.delete();
+ await Promise.all([
+ txRequest(db, STORE_SESSION, 'readwrite', (s) => s.put(serializeRun(run), run.runId)),
+ txRequest(db, STORE_CHECKPOINT, 'readwrite', (s) => s.put(buf, run.runId)),
+ txRequest(db, STORE_META, 'readwrite', (s) => s.put(run.runId, META_LAST_RUN_ID_KEY))
+ ]);
+ }
+
+ /**
+ * Update only the config of the last run in storage without touching other data.
+ * This avoids fully deserializing metrics and only rewrites the config field.
+ */
+  async updateConfig(mutator: (config: Config) => Config | void): Promise<void> {
+ const db = await this.db;
+ const runId = await this.getLastRunId(db);
+ if (!runId) return;
+
+    const saved = await txRequest<SavedRun | undefined>(db, STORE_SESSION, 'readonly', (s) =>
+      s.get(runId)
+    );
+ if (!saved) return;
+
+ const maybeNew = mutator(saved.config as Config);
+ const newConfig = (maybeNew ? maybeNew : saved.config) as Config;
+ const updated: SavedRun = { ...saved, config: newConfig };
+ await txRequest(db, STORE_SESSION, 'readwrite', (s) => s.put(updated, runId));
+ }
+
+  async exists(): Promise<boolean> {
+ const db = await this.db;
+ const runId = await this.getLastRunId(db);
+ if (!runId) return false;
+    const [run, buffer] = await Promise.all([
+      txRequest<SavedRun | undefined>(db, STORE_SESSION, 'readonly', (s) => s.get(runId)),
+      txRequest<ArrayBuffer | undefined>(db, STORE_CHECKPOINT, 'readonly', (s) => s.get(runId))
+ ]);
+ return run != null && buffer != null;
+ }
+
+  async delete(): Promise<void> {
+ const db = await this.db;
+ await Promise.all([
+ txRequest(db, STORE_SESSION, 'readwrite', (s) => s.clear()),
+ txRequest(db, STORE_CHECKPOINT, 'readwrite', (s) => s.clear()),
+ txRequest(db, STORE_META, 'readwrite', (s) => s.delete(META_LAST_RUN_ID_KEY))
+ ]);
+ }
+}
+
+export const lastSessionStore = new LastSessionStore();
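+
+// Typical usage, as a sketch (`run` and `checkpointBytes` are assumed to come from a
+// training session elsewhere in the app):
+//   await lastSessionStore.set(run, checkpointBytes);
+//   const last = await lastSessionStore.get(); // { run, checkpoint } | null
+//   await lastSessionStore.updateConfig((c) => {
+//     c.training.batchSize = 8; // mutate in place, or return a new Config
+//   });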
diff --git a/examples/finetuning/src/lib/workspace/localStorage.svelte.ts b/examples/finetuning/src/lib/workspace/localStorage.svelte.ts
new file mode 100644
index 00000000..5422c14f
--- /dev/null
+++ b/examples/finetuning/src/lib/workspace/localStorage.svelte.ts
@@ -0,0 +1,98 @@
+import { tick } from 'svelte';
+
+export class LocalStorage<T> {
+ #key: string;
+ #version = $state(0);
+ #listeners = 0;
+ #value: T | undefined;
+
+ #handler = (e: StorageEvent) => {
+ if (e.storageArea !== localStorage) return;
+ if (e.key !== this.#key) return;
+
+ this.#version += 1;
+ };
+
+ constructor(key: string, initial?: T) {
+ this.#key = key;
+ this.#value = initial;
+
+    // Skip seeding when no initial value is given; JSON.stringify(undefined) would
+    // otherwise store the literal string "undefined", which JSON.parse cannot read back.
+    if (typeof localStorage !== 'undefined' && initial !== undefined) {
+ if (localStorage.getItem(key) == null) {
+ localStorage.setItem(key, JSON.stringify(initial));
+ }
+ }
+ }
+
+  get current() {
+    // Read #version so reactive consumers are re-run when the stored value changes
+    // eslint-disable-next-line @typescript-eslint/no-unused-expressions
+    this.#version;
+
+ const localStorageItem =
+ typeof localStorage !== 'undefined' ? localStorage.getItem(this.#key) : null;
+ const root = localStorageItem !== null ? JSON.parse(localStorageItem) : this.#value;
+
+ const proxies = new WeakMap();
+
+ const proxy = (value: unknown) => {
+ if (typeof value !== 'object' || value === null) {
+ return value;
+ }
+
+ let p = proxies.get(value);
+
+ if (!p) {
+ p = new Proxy(value, {
+ get: (target, property) => {
+ // eslint-disable-next-line @typescript-eslint/no-unused-expressions
+ this.#version;
+ return proxy(Reflect.get(target, property));
+ },
+ set: (target, property, value) => {
+ this.#version += 1;
+ Reflect.set(target, property, value);
+
+ if (typeof localStorage !== 'undefined') {
+ localStorage.setItem(this.#key, JSON.stringify(root));
+ }
+
+ return true;
+ }
+ });
+
+ proxies.set(value, p);
+ }
+
+ return p;
+ };
+
+ if ($effect.tracking()) {
+ $effect(() => {
+ if (this.#listeners === 0) {
+ window.addEventListener('storage', this.#handler);
+ }
+
+ this.#listeners += 1;
+
+ return () => {
+ tick().then(() => {
+ this.#listeners -= 1;
+ if (this.#listeners === 0) {
+ window.removeEventListener('storage', this.#handler);
+ }
+ });
+ };
+ });
+ }
+
+ return proxy(root);
+ }
+
+ set current(value) {
+ if (typeof localStorage !== 'undefined') {
+ localStorage.setItem(this.#key, JSON.stringify(value));
+ }
+
+ this.#version += 1;
+ }
+}
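+
+// Minimal usage sketch (the 'theme' key is a hypothetical example):
+//   const theme = new LocalStorage<string>('theme', 'dark');
+//   theme.current;           // reactive read; re-runs effects on cross-tab changes
+//   theme.current = 'light'; // writes through to localStorage and bumps #version
+// Nested objects read via `current` are wrapped in Proxies, so property writes
+// also persist automatically.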
diff --git a/examples/finetuning/src/lib/workspace/runs.svelte.ts b/examples/finetuning/src/lib/workspace/runs.svelte.ts
new file mode 100644
index 00000000..1b8e9ac5
--- /dev/null
+++ b/examples/finetuning/src/lib/workspace/runs.svelte.ts
@@ -0,0 +1,443 @@
+import type { ValidationStep } from '$lib/train/validation';
+
+import { generateMemorableName } from '$lib/workspace/utils';
+import { SvelteMap, SvelteSet } from 'svelte/reactivity';
+
+import { type Config, CONFIG_DESCRIPTIONS } from './config';
+
+export type BaseStepData = { step: number };
+
+export type Point = BaseStepData & { y: number };
+
+export type TokenRollout = {
+ tokenIds: number[];
+ probs: number[][];
+};
+
+export type VisualizationStep = BaseStepData & {
+ type: 'visualization';
+};
+
+export type StepData = Point | ValidationStep;
+
+export type MetricData = {
+ metricName: string;
+ data: StepData[];
+};
+
+export type RunData = {
+ runId: string;
+ color: string;
+ config: Config;
+  metrics: SvelteMap<string, MetricData>;
+ step: number;
+ lastUpdated: number;
+ createdAt: number;
+ diffSummary: string;
+};
+
+export type RunMeta = {
+ runId: string | null;
+ config: Config | null;
+};
+
+// A palette of distinct colors for runs
+const RUN_COLORS = ['#ab6aea', '#f5493b', '#dfa300', '#3bbc4a', '#00b7c0', '#4475f6'];
+
+export const PREFIX_BOOSTS: Record<string, number> = { data: 10 };
+
+type DiffItem = {
+ label: string;
+ path: string[];
+ display: string;
+ boost: number;
+};
+
+function pathToString(path: ReadonlyArray<string>): string {
+ return path.join('.');
+}
+
+function getAtPath(obj: unknown, path: ReadonlyArray<string>): unknown {
+ let cur = obj;
+ for (const key of path) {
+ if (cur == null || typeof cur !== 'object') return undefined;
+    cur = (cur as Record<string, unknown>)[key];
+ }
+ return cur;
+}
+
+function getDescriptor(path: ReadonlyArray<string>): { label: string | null; itemBoost?: number } {
+ let cur = CONFIG_DESCRIPTIONS as unknown;
+
+ for (const key of path) {
+ if (cur == null || typeof cur !== 'object') return { label: null };
+    cur = (cur as Record<string, unknown>)[key];
+ }
+ if (cur == null) return { label: null };
+ if (Array.isArray(cur)) {
+ const [shortName, boost] = cur;
+ return { label: shortName, itemBoost: typeof boost === 'number' ? boost : undefined };
+ }
+ if (typeof cur === 'string') return { label: cur };
+ if (typeof cur === 'object' && 'shortName' in cur && typeof cur.shortName === 'string') {
+ return { label: cur.shortName };
+ }
+ return { label: null };
+}
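+
+// For example, getDescriptor(['optimizer', 'lr']) resolves against CONFIG_DESCRIPTIONS
+// to { label: 'lr' }, and getDescriptor(['training', 'validation', 'batchSize'])
+// to { label: 'val size' }; unknown paths yield { label: null }.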
+
+function maxPrefixBoost(path: ReadonlyArray<string>, boosts: Record<string, number>): number {
+ let maxB = 0;
+ for (let i = 1; i <= path.length; i++) {
+ const pref = path.slice(0, i).join('.');
+ const b = boosts[pref];
+ if (typeof b === 'number' && b > maxB) maxB = b;
+ }
+ return maxB;
+}
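+
+// With the default PREFIX_BOOSTS of { data: 10 }, maxPrefixBoost(['data', 'dataset'], boosts)
+// returns 10, since the 'data' prefix matches; paths outside any boosted prefix return 0.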
+
+function orderOfMagnitude(n: number): number {
+ if (n === 0) return 0;
+ return Math.floor(Math.log10(Math.abs(n)));
+}
+
+function formatScientific(n: number, significantDigits = 1): string {
+ if (n === 0) return '0';
+ const s = n.toExponential(Math.max(0, significantDigits - 1));
+ const [mant, expRaw] = s.split('e');
+ const mantTrim = mant
+ .replace(/\.0+$/, '')
+ .replace(/(\.[0-9]*?)0+$/, '$1')
+ .replace(/\.$/, '');
+ const exp = String(parseInt(expRaw, 10));
+ return `${mantTrim}e${exp}`;
+}
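+
+// A few expected outputs, for reference:
+//   formatScientific(0.00012)  === '1e-4'
+//   formatScientific(12345, 2) === '1.2e4'  (trailing zeros in the mantissa are trimmed)
+//   formatScientific(0)        === '0'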
+
+function formatNumberDiff(a: number, b: number): string {
+ const bothSmallInts =
+ Number.isInteger(a) && Number.isInteger(b) && Math.abs(a) < 100 && Math.abs(b) < 100;
+ if (bothSmallInts) return `${a}→${b}`;
+
+  // Otherwise prefer a concise scientific form, unless the integer "suffix" case below applies
+ const expA = orderOfMagnitude(a);
+ const expB = orderOfMagnitude(b);
+ const bothInts = Number.isInteger(a) && Number.isInteger(b);
+ const base = Math.pow(10, expA);
+ if (
+ bothInts &&
+ Math.sign(a) >= 0 &&
+ Math.sign(b) >= 0 &&
+ expA === expB &&
+ // Only use suffix form for small changes
+ Math.abs(b - a) < base
+ ) {
+ const lead = Math.floor(a / base);
+ const suffixA = a - lead * base;
+ const suffixB = b - lead * base;
+ return `${lead}e${expA}+(${suffixA}→${suffixB})`;
+ }
+ const sa = formatScientific(a, 1);
+ const sb = formatScientific(b, 1);
+ return `${sa}⥵${sb}`;
+}
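+
+// Expected behavior, for reference:
+//   formatNumberDiff(3, 7)        === '3→7'           (both small integers)
+//   formatNumberDiff(1000, 1024)  === '1e3+(0→24)'    (same magnitude, small integer delta)
+//   formatNumberDiff(0.001, 0.01) === '1e-3⥵1e-2'     (falls back to scientific form)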
+
+function formatValue(v: unknown): string {
+ if (typeof v === 'number') return formatScientific(v, 1);
+ if (typeof v === 'string') return v;
+ if (typeof v === 'boolean') return v ? 'true' : 'false';
+ return String(v);
+}
+
+function comparePrimitive(a: unknown, b: unknown): boolean {
+ // Strict equality suffices for primitives we expect here
+ return a === b;
+}
+
+function collectLeafPaths(obj: unknown, prefix: string[] = []): string[][] {
+ const out: string[][] = [];
+ if (obj == null || typeof obj !== 'object') return out;
+ for (const key of Object.keys(obj)) {
+ const nextPath = [...prefix, key];
+    const val = (obj as Record<string, unknown>)[key];
+ if (val != null && typeof val === 'object') {
+ // Only traverse objects that are not arrays
+ if (!Array.isArray(val)) out.push(...collectLeafPaths(val, nextPath));
+ } else {
+ out.push(nextPath);
+ }
+ }
+ return out;
+}
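+
+// For example, collectLeafPaths({ a: { b: 1 }, c: 'x' }) yields [['a', 'b'], ['c']];
+// array-valued keys are skipped entirely (neither traversed nor emitted as leaves).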
+
+export function describeConfigDiff(
+ prev: Config | null,
+ curr: Config,
+  opts?: { topK?: number; prefixBoosts?: Record<string, number> }
+): string {
+ const topK = opts?.topK ?? 3;
+ const boosts = opts?.prefixBoosts ?? PREFIX_BOOSTS;
+ if (!prev) return 'initial experiment';
+
+ const prevPaths = collectLeafPaths(prev);
+ const currPaths = collectLeafPaths(curr);
+ const allKey = new SvelteSet