-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.ts
More file actions
304 lines (276 loc) · 9.41 KB
/
utils.ts
File metadata and controls
304 lines (276 loc) · 9.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
/**
* Shared sampling utilities
*/
/**
 * Default candidate cap (K) used when truncating the vocabulary for sampling.
 *
 * Used by: topP, minP, topNSigma
 * Rationale: nucleus / min-p sampling rarely needs more than 256 candidates,
 * and capping the pool avoids O(V) work per token.
 *
 * Written in hex to make the power-of-two size explicit (0x100 === 256).
 */
export const DEFAULT_KCAP = 0x100;
/**
 * Larger candidate cap for Typical-P sampling.
 *
 * Typical-P ranks tokens by how close their surprisal is to the entropy of the
 * distribution, so it benefits from a larger pool when estimating that entropy.
 * Background: Meister et al., "Locally Typical Sampling", TACL 2023
 * (arXiv preprint 2022).
 * NOTE(review): the specific 512 value appears to be a local heuristic rather
 * than a number prescribed by the paper — confirm before tuning.
 *
 * Written in hex to make the power-of-two size explicit (0x200 === 512).
 */
export const TYPICAL_P_KCAP = 0x200;
/**
 * Compute the workspace capacity needed for a given set of sampling params.
 *
 * Intended for Test-Time Alignment (TTA), where params may change between
 * tokens (adaptive top-K, toggling typical-P, temporary exploratory widening).
 *
 * Rules:
 * - Floor is TYPICAL_P_KCAP when typical-P is active (typicalP < 1),
 *   otherwise DEFAULT_KCAP.
 * - A positive topK (itself clamped to vocabSize) can raise the requirement
 *   above that floor, never lower it.
 * - The final answer never exceeds vocabSize.
 *
 * @param topK Top-K value (undefined or <= 0 means disabled)
 * @param typicalP Typical-P value (undefined or >= 1.0 means disabled)
 * @param vocabSize Vocabulary size (upper bound on the result)
 * @returns Capacity to pass to workspace.ensureCapacity()
 */
export function requiredKcap(
  topK: number | undefined,
  typicalP: number | undefined,
  vocabSize: number
): number {
  const typicalActive = typicalP !== undefined && typicalP < 1.0;
  const floor = typicalActive ? TYPICAL_P_KCAP : DEFAULT_KCAP;

  let requested = DEFAULT_KCAP;
  if (topK !== undefined && topK > 0) {
    requested = Math.min(topK, vocabSize);
  }

  const withFloor = Math.max(requested, floor);
  return Math.min(withFloor, vocabSize);
}
/**
 * Divide every logit by the temperature, in place.
 *
 * Effect on the resulting distribution:
 * - temp > 1.0: flatter (more random)
 * - temp = 1.0: exact no-op; the loop is skipped entirely
 * - temp < 1.0: sharper (more deterministic)
 * - temp → 0: approaches greedy selection
 *
 * No validation is done. temp = 0 produces ±Infinity, and NaN for zero logits.
 * A negative temp inverts the ranking.
 * NOTE(review): presumably callers guard against non-positive temperatures —
 * confirm.
 *
 * @param logits Logits array (modified in-place)
 * @param temp Temperature value
 */
export function applyTemperature(logits: Float32Array, temp: number): void {
  if (temp !== 1.0) {
    const count = logits.length;
    for (let j = 0; j < count; j += 1) {
      logits[j] = logits[j] / temp;
    }
  }
}
/**
 * Preallocated workspace for sampling operations.
 * Reuse one instance across tokens to avoid GC churn.
 *
 * INVARIANTS:
 * - The six K-sized buffers always share one capacity (`kcap`). It is a power
 *   of two and only ever grows, never shrinks.
 * - No allocations per token once the needed capacity has been reached.
 * - workingLogits is allocated up front when the constructor receives a
 *   positive `vocabSize`; otherwise it is allocated on demand by
 *   ensureWorkingLogits().
 * - Growth uses power-of-two sizing to amortize reallocations.
 *
 * Test-Time Alignment (TTA):
 * - Supports adaptive top-K (e.g. 40→320→128 during inference)
 * - Supports typical-P toggling (e.g. 256→512 when enabled)
 * - Handles exploratory bursts without reinitialization
 * - Never downsizes (monotonic growth for bursty workloads)
 *
 * Buffers:
 * - idxK, valK, probK: for heap-based Top-K selection (length kcap)
 * - tmpIdx, tmpLogits, tmpProbs: scratch space, presumably for top-p
 *   re-sorting (length kcap)
 * - workingLogits: optional V-sized scratch buffer (presumably for penalty
 *   application — confirm with callers)
 * - version: increments whenever the K-sized buffers are reallocated, so
 *   holders of old buffer references can detect that they are stale
 *
 * Buffer contents are NOT preserved across growth. ensureCapacity() replaces
 * the K-sized arrays with fresh zero-filled ones.
 */
export class SamplerWorkspace {
  // Private fields - use getters for external access
  #kcap: number;
  #idxK: Uint32Array;
  #valK: Float32Array;
  #probK: Float32Array;
  #tmpIdx: Uint32Array;
  #tmpLogits: Float32Array;
  #tmpProbs: Float32Array;
  #workingLogits?: Float32Array;
  #version = 0;

  /**
   * @param maxK Initial K capacity, rounded up to a power of two. Non-positive
   *   or NaN values fall back to 256.
   * @param vocabSize When positive, preallocates workingLogits with this length
   *   so it can be reused across tokens (e.g. 65536 × 4 bytes ≈ 262KB).
   * @throws RangeError if maxK is +Infinity
   */
  constructor(maxK: number = 256, vocabSize?: number) {
    // `!(maxK > 0)` also catches NaN. Previously NaN produced a capacity of 1.
    this.#kcap = !(maxK > 0) ? 256 : roundUpToPowerOfTwo(maxK);
    this.#idxK = new Uint32Array(this.#kcap);
    this.#valK = new Float32Array(this.#kcap);
    this.#probK = new Float32Array(this.#kcap);
    this.#tmpIdx = new Uint32Array(this.#kcap);
    this.#tmpLogits = new Float32Array(this.#kcap);
    this.#tmpProbs = new Float32Array(this.#kcap);
    if (vocabSize !== undefined && vocabSize > 0) {
      this.#workingLogits = new Float32Array(vocabSize);
    }
  }

  // Read-only accessors (external immutability, internal flexibility)
  get kcap(): number { return this.#kcap; }
  get idxK(): Uint32Array { return this.#idxK; }
  get valK(): Float32Array { return this.#valK; }
  get probK(): Float32Array { return this.#probK; }
  get tmpIdx(): Uint32Array { return this.#tmpIdx; }
  get tmpLogits(): Float32Array { return this.#tmpLogits; }
  get tmpProbs(): Float32Array { return this.#tmpProbs; }
  get workingLogits(): Float32Array | undefined { return this.#workingLogits; }
  get version(): number { return this.#version; }

  /**
   * Ensure the K-sized buffers can hold at least `needed` entries.
   *
   * Growth strategy:
   * - Round up to the next power of two (40→64, 320→512)
   * - Amortizes reallocations when params change during TTA
   * - Monotonic: never shrinks. A request <= kcap, or NaN, is a no-op.
   * - On growth all six buffers are replaced (contents discarded) and
   *   `version` is incremented.
   *
   * @param needed Required capacity
   * @throws RangeError if `needed` is +Infinity or too large to allocate; the
   *   workspace is left unchanged in that case
   */
  ensureCapacity(needed: number): void {
    // Negated `>` so NaN is treated as "already satisfied" instead of
    // reaching the sizing code (which previously shrank the buffers to 1).
    if (!(needed > this.#kcap)) return;
    const newCap = roundUpToPowerOfTwo(needed);
    // Allocate everything before touching state, so a failed allocation
    // cannot leave kcap out of sync with the buffers.
    const idxK = new Uint32Array(newCap);
    const valK = new Float32Array(newCap);
    const probK = new Float32Array(newCap);
    const tmpIdx = new Uint32Array(newCap);
    const tmpLogits = new Float32Array(newCap);
    const tmpProbs = new Float32Array(newCap);
    this.#kcap = newCap;
    this.#idxK = idxK;
    this.#valK = valK;
    this.#probK = probK;
    this.#tmpIdx = tmpIdx;
    this.#tmpLogits = tmpLogits;
    this.#tmpProbs = tmpProbs;
    this.#version++;
  }

  /**
   * Ensure workingLogits exists and has length >= vocabSize.
   *
   * Allocates lazily and reallocates only when the current buffer is too
   * short, so it never shrinks. The buffer may therefore be LONGER than
   * vocabSize; callers should use `subarray(0, vocabSize)` when the exact
   * length matters. Reallocating workingLogits does not bump `version`.
   *
   * @param vocabSize Required vocabulary size
   */
  ensureWorkingLogits(vocabSize: number): void {
    if (!this.#workingLogits || this.#workingLogits.length < vocabSize) {
      this.#workingLogits = new Float32Array(vocabSize);
    }
  }
}

/**
 * Smallest power of two that is >= n (1 for any n <= 1).
 *
 * Uses exact doubling instead of `1 << Math.ceil(Math.log2(n))`. That form
 * takes shift counts modulo 32, so 0 < n < 1 produced a negative array length
 * and n > 2^30 overflowed. Math.log2 is also only implementation-approximated
 * per the ECMAScript spec. Only called on construction or growth, so the loop
 * (at most ~31 iterations for realistic sizes) is not on a hot path.
 *
 * @throws RangeError if n is not finite (the loop would never terminate)
 */
function roundUpToPowerOfTwo(n: number): number {
  if (!Number.isFinite(n)) {
    throw new RangeError(`SamplerWorkspace: capacity must be finite, got ${n}`);
  }
  let cap = 1;
  while (cap < n) cap *= 2;
  return cap;
}
/**
 * Top-K selection using a size-K min-heap.
 * O(V log K) to scan plus O(K log K) for the final descending sort.
 *
 * Zero-copy: reads from `logits` and writes results into caller-owned buffers.
 * Virtual penalties: the optional `adjustFn` is applied to each value as it is
 * read, so penalties affect ranking without mutating `logits`.
 *
 * On return, idxK[0..K) and valK[0..K) hold the selected token ids and their
 * (adjusted) values, sorted in descending order of value.
 *
 * Input handling:
 * - k is truncated to an integer. k < 1 (or NaN) selects nothing and returns 0.
 * - A k larger than logits.length selects every entry.
 * - NaN values (from logits or adjustFn) are ranked as -Infinity. NaN never
 *   compares greater than anything, so a NaN sitting at the heap root would
 *   otherwise stop every later candidate from entering the heap.
 *
 * @param logits Input logits (zero-copy view, NOT copied, NOT modified)
 * @param k Number of top elements to select
 * @param idxK Output buffer for indices (length >= min(k, logits.length))
 * @param valK Output buffer for values (length >= min(k, logits.length))
 * @param adjustFn Optional penalty accessor: (tokenId, baseLogit) => adjustedLogit
 * @returns Actual K selected: min(floor(k), logits.length), or 0
 * @throws RangeError if idxK or valK is shorter than K. Typed arrays silently
 *   ignore out-of-range writes, so an undersized buffer would otherwise yield
 *   corrupt output rather than an error.
 */
export function selectTopK(
  logits: Float32Array,
  k: number,
  idxK: Uint32Array,
  valK: Float32Array,
  adjustFn?: (tokenId: number, baseLogit: number) => number
): number {
  const V = logits.length;
  // Fractional k would make the scan index non-integer (logits[2.5] is undefined).
  const kInt = Math.floor(k);
  if (!(kInt >= 1)) return 0; // also rejects NaN
  const K = Math.min(kInt, V);
  if (idxK.length < K || valK.length < K) {
    throw new RangeError(
      `selectTopK: output buffers too small for K=${K} (idxK.length=${idxK.length}, valK.length=${valK.length})`
    );
  }
  // Value accessor: applies virtual penalties and ranks NaN as -Infinity.
  const adjust = adjustFn;
  const getValue = adjust
    ? (i: number): number => {
        const x = adjust(i, logits[i]);
        return Number.isNaN(x) ? -Infinity : x;
      }
    : (i: number): number => {
        const x = logits[i];
        return Number.isNaN(x) ? -Infinity : x;
      };
  // Seed heap with first K entries
  for (let i = 0; i < K; i++) {
    idxK[i] = i;
    valK[i] = getValue(i);
  }
  // Build min-heap (by valK)
  for (let i = (K >>> 1) - 1; i >= 0; i--) {
    siftDown(idxK, valK, i, K);
  }
  // Scan remainder - the heap keeps the K largest seen so far, minimum at root
  for (let i = K; i < V; i++) {
    const x = getValue(i);
    if (x > valK[0]) { // Replace minimum if found larger
      idxK[0] = i;
      valK[0] = x;
      siftDown(idxK, valK, 0, K);
    }
  }
  // Sort top-K descending by value
  heapSortDesc(idxK, valK, K);
  return K;
}
/**
 * Push the entry at position `i` down until the min-heap property holds
 * for the heap region [0, n).
 *
 * `idx` and `val` are parallel arrays and are always moved together. Ordering
 * uses strict `<` on `val`, checking the left child before the right, so an
 * entry equal to its smallest child stays put.
 * Internal helper for selectTopK / heapSortDesc.
 */
function siftDown(
  idx: Uint32Array,
  val: Float32Array,
  i: number,
  n: number
): void {
  let pos = i;
  for (;;) {
    const left = 2 * pos + 1;
    const right = left + 1;
    let smallest = pos;
    if (left < n && val[left] < val[smallest]) smallest = left;
    if (right < n && val[right] < val[smallest]) smallest = right;
    if (smallest === pos) return;
    // Exchange pos <-> smallest in both parallel arrays
    const heldIdx = idx[pos];
    idx[pos] = idx[smallest];
    idx[smallest] = heldIdx;
    const heldVal = val[pos];
    val[pos] = val[smallest];
    val[smallest] = heldVal;
    pos = smallest;
  }
}
/**
 * Turn a min-heap occupying [0, n) into an array sorted in descending order,
 * in place.
 *
 * Each pass moves the current minimum (the root) to the back of the shrinking
 * heap region and re-heapifies the rest. The back therefore fills with
 * successively larger minima, leaving the largest value at [0] and the
 * smallest at [n-1]. No final reverse is needed.
 * Internal helper for selectTopK.
 */
function heapSortDesc(
  idx: Uint32Array,
  val: Float32Array,
  n: number
): void {
  let last = n - 1;
  while (last > 0) {
    swap(idx, val, 0, last);
    siftDown(idx, val, 0, last);
    last -= 1;
  }
}
/**
 * Exchange positions `a` and `b` in both parallel arrays (`idx` and `val`),
 * keeping each index paired with its value.
 * Internal helper for heap operations.
 */
function swap(
  idx: Uint32Array,
  val: Float32Array,
  a: number,
  b: number
): void {
  const idxAtA = idx[a];
  idx[a] = idx[b];
  idx[b] = idxAtA;

  const valAtA = val[a];
  val[a] = val[b];
  val[b] = valAtA;
}