-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpenalties.ts
More file actions
312 lines (279 loc) · 8.58 KB
/
penalties.ts
File metadata and controls
312 lines (279 loc) · 8.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
/**
* Repetition Penalties Implementation
*
* Ported from llama.cpp/src/llama-sampling.cpp (lines 1747-1778)
* Exact formulas match native implementation for token-by-token parity
*
* References:
* - Penalty apply: llama_sampler_penalties_apply (line 1747)
* - Token tracking: llama_sampler_penalties_accept (line 1716)
*/
/**
 * Token history tracker with frequency map.
 *
 * Mirrors llama.cpp's ring_buffer + unordered_map pattern: a circular buffer
 * holds the last N accepted tokens and a Map holds the occurrence count of
 * each token inside that window. Once the window is full, a new token
 * overwrites the oldest slot in place, so `accept` does O(1) work instead of
 * the O(N) `Array.prototype.shift()` a plain array queue would need.
 */
export class TokenHistoryTracker {
  /** Window storage: appended to until full, then reused as a circular buffer. */
  private ringBuffer: number[] = [];
  /** Slot holding the oldest token once the window is full (stays 0 until then). */
  private head = 0;
  /** Occurrence count per token in the window; entries are deleted at zero. */
  private tokenCount: Map<number, number> = new Map();
  /** Window capacity; 0 disables tracking entirely. */
  private readonly maxSize: number;

  /**
   * @param penaltyLastN Number of most recent tokens to track. Fractional
   *   values are truncated toward zero (llama.cpp takes an int32); negative
   *   values and NaN become 0, which disables tracking.
   */
  constructor(penaltyLastN: number) {
    const truncated = Math.trunc(penaltyLastN);
    // Math.max propagates NaN; without this guard the window would never be
    // considered full and would grow without bound.
    this.maxSize = Number.isNaN(truncated) ? 0 : Math.max(truncated, 0);
  }

  /**
   * Add a token to the history (matches llama_sampler_penalties_accept).
   *
   * No-op when the window size is 0. When the window is full, the oldest
   * token is evicted and its count decremented.
   */
  accept(token: number): void {
    if (this.maxSize === 0) {
      return;
    }

    // Count the new token before evicting (same order as llama.cpp), so
    // evicting an identical token never transiently deletes its entry.
    this.tokenCount.set(token, this.getCount(token) + 1);

    if (this.ringBuffer.length < this.maxSize) {
      // Window still filling: append.
      this.ringBuffer.push(token);
      return;
    }

    // Window full: overwrite the oldest slot and advance the head.
    const oldest = this.ringBuffer[this.head];
    if (oldest !== undefined) {
      this.decrementCount(oldest);
    }
    this.ringBuffer[this.head] = token;
    this.head = (this.head + 1) % this.ringBuffer.length;
  }

  /** Decrement a token's count, removing its entry when it reaches zero. */
  private decrementCount(token: number): void {
    const remaining = this.getCount(token) - 1;
    if (remaining <= 0) {
      this.tokenCount.delete(token);
    } else {
      this.tokenCount.set(token, remaining);
    }
  }

  /**
   * Get occurrence count for a token in the current window (0 if absent).
   */
  getCount(token: number): number {
    return this.tokenCount.get(token) ?? 0;
  }

  /**
   * Check whether a token occurs at least once in the current window.
   */
  hasToken(token: number): boolean {
    return this.tokenCount.has(token);
  }

  /**
   * Reset history (matches llama_sampler_penalties_reset).
   */
  reset(): void {
    this.ringBuffer = [];
    this.head = 0;
    this.tokenCount.clear();
  }

  /**
   * Get the number of tokens currently in the window.
   */
  size(): number {
    return this.ringBuffer.length;
  }

  /**
   * Get the unique token IDs in the current window.
   *
   * Used for sparse penalty updates: callers iterate H unique history tokens
   * instead of all V vocabulary logits (H is normally far smaller than V).
   */
  getUniqueTokens(): number[] {
    return Array.from(this.tokenCount.keys());
  }

  /**
   * Check whether the given parameters would modify any logits.
   *
   * An omitted parameter counts as neutral (repeat 1, frequency 0, presence 0).
   */
  static hasPenalties(params: {
    repeat?: number;
    frequency?: number;
    presence?: number;
  }): boolean {
    return (
      (params.repeat !== undefined && params.repeat !== 1.0) ||
      (params.frequency !== undefined && params.frequency !== 0.0) ||
      (params.presence !== undefined && params.presence !== 0.0)
    );
  }

  /**
   * Compute the penalized logit for a single token without mutating anything.
   *
   * This is the "virtual penalty": an O(1) lookup applying, in order,
   * repetition (multiply non-positive / divide positive), frequency
   * (`-= count * frequency`) and presence (`-= presence`). Arithmetic is done
   * in float64, so results can differ from the Float32Array path by rounding.
   *
   * @param tokenId Token to check
   * @param baseLogit Original logit value
   * @param params Penalty parameters (omitted = neutral)
   * @returns The adjusted logit; `baseLogit` unchanged if the token is not in the window
   */
  computeAdjustment(
    tokenId: number,
    baseLogit: number,
    params: {
      repeat?: number;
      frequency?: number;
      presence?: number;
    }
  ): number {
    const count = this.getCount(tokenId);
    if (count === 0) {
      return baseLogit; // No penalty if token not in history
    }
    let adjusted = baseLogit;
    // 1. Repetition penalty (multiplicative, sign-aware)
    if (params.repeat !== undefined && params.repeat !== 1.0) {
      if (adjusted <= 0) {
        adjusted *= params.repeat;
      } else {
        adjusted /= params.repeat;
      }
    }
    // 2. Frequency penalty (subtractive, scaled by count)
    if (params.frequency !== undefined && params.frequency !== 0.0) {
      adjusted -= count * params.frequency;
    }
    // 3. Presence penalty (subtractive, binary)
    if (params.presence !== undefined && params.presence !== 0.0) {
      adjusted -= params.presence;
    }
    return adjusted;
  }
}
/**
 * Apply the repetition penalty in place (llama-sampling.cpp lines 1768-1772).
 *
 * Each token present in the history window is made less likely:
 * - non-positive logits are multiplied by `penalty`;
 * - positive logits are divided by `penalty`.
 *
 * From the llama.cpp comment (line 1766): the academic publication only
 * divided, which would make tokens with negative logits *more* likely. The
 * common fix, used here, is to multiply in that case instead.
 *
 * Only the unique history tokens are visited: O(H), not O(V).
 *
 * @param logits Logits array (modified in-place)
 * @param tokenHistory Token history tracker
 * @param penalty Repetition penalty multiplier; 1.0 is a no-op (typical: 1.0-1.5)
 */
export function applyRepetitionPenalty(
  logits: Float32Array,
  tokenHistory: TokenHistoryTracker,
  penalty: number
): void {
  // Neutral penalty: nothing to do.
  if (penalty === 1.0) return;

  tokenHistory.getUniqueTokens().forEach((id) => {
    const current = logits[id];
    logits[id] = current <= 0 ? current * penalty : current / penalty;
  });
}
/**
 * Apply the frequency penalty in place (llama-sampling.cpp line 1774, first term).
 *
 * The penalty grows linearly with how often a token occurs in the window:
 *
 *   logit -= count * penalty_freq
 *
 * Only the unique history tokens are visited: O(H), not O(V).
 *
 * @param logits Logits array (modified in-place)
 * @param tokenHistory Token history tracker
 * @param penalty Frequency penalty; 0.0 is a no-op (typical: 0.0-2.0)
 */
export function applyFrequencyPenalty(
  logits: Float32Array,
  tokenHistory: TokenHistoryTracker,
  penalty: number
): void {
  // Zero penalty: nothing to do.
  if (penalty === 0.0) return;

  tokenHistory.getUniqueTokens().forEach((id) => {
    const occurrences = tokenHistory.getCount(id);
    logits[id] = logits[id] - occurrences * penalty;
  });
}
/**
 * Apply the presence penalty in place (llama-sampling.cpp line 1774, second term).
 *
 * A flat amount is subtracted from every token that occurs at least once in
 * the window, regardless of how many times it occurs:
 *
 *   logit -= (count > 0 ? 1 : 0) * penalty_present
 *
 * Every ID returned by `getUniqueTokens()` has count > 0, so the condition
 * is always true inside the loop. Only history tokens are visited: O(H), not O(V).
 *
 * @param logits Logits array (modified in-place)
 * @param tokenHistory Token history tracker
 * @param penalty Presence penalty; 0.0 is a no-op (typical: 0.0-1.0)
 */
export function applyPresencePenalty(
  logits: Float32Array,
  tokenHistory: TokenHistoryTracker,
  penalty: number
): void {
  // Zero penalty: nothing to do.
  if (penalty === 0.0) return;

  tokenHistory.getUniqueTokens().forEach((id) => {
    logits[id] = logits[id] - penalty;
  });
}
/**
 * Apply all penalties in one pass, matching llama_sampler_penalties_apply
 * (llama-sampling.cpp lines 1747-1778).
 *
 * For each token in the history window, in this order:
 * 1. Repetition: non-positive logits are multiplied by `repeat`, positive
 *    logits are divided by it.
 * 2. Frequency + presence: `logit -= count * frequency + presence`, evaluated
 *    as ONE float32 expression, as llama.cpp does (line 1774).
 *
 * Why fused: writing the frequency and presence terms back to the
 * Float32Array separately rounds to float32 twice, which can differ from the
 * native result by one ULP. Intermediate results are rounded with
 * `Math.fround`, and the penalty parameters are rounded to float32 because
 * llama.cpp stores them as `float`.
 *
 * An omitted parameter is neutral (`repeat` = 1, `frequency` = 0,
 * `presence` = 0). Token IDs outside `logits` are ignored.
 *
 * @param logits Logits array (modified in-place)
 * @param tokenHistory Token history tracker
 * @param params Penalty parameters
 */
export function applyPenalties(
  logits: Float32Array,
  tokenHistory: TokenHistoryTracker,
  params: {
    repeat?: number;
    frequency?: number;
    presence?: number;
  }
): void {
  const repeat = Math.fround(params.repeat ?? 1.0);
  const frequency = Math.fround(params.frequency ?? 0.0);
  const presence = Math.fround(params.presence ?? 0.0);

  // Early return (matches llama.cpp lines 1750-1752): empty window or all
  // penalties neutral (omitted parameters count as neutral).
  if (
    tokenHistory.size() === 0 ||
    (repeat === 1.0 && frequency === 0.0 && presence === 0.0)
  ) {
    return;
  }

  for (const tokenId of tokenHistory.getUniqueTokens()) {
    const base = logits[tokenId];
    if (base === undefined) {
      continue; // token id outside the logits array
    }
    const count = tokenHistory.getCount(tokenId);

    // Repetition penalty. Multiplying non-positive logits is llama.cpp's fix
    // for the paper's divide-only rule, which would make negative logits more likely.
    const scaled =
      base <= 0 ? Math.fround(base * repeat) : Math.fround(base / repeat);

    // float(count) * penalty_freq + float(count > 0) * penalty_present, in float32.
    const subtract = Math.fround(
      Math.fround(Math.fround(count) * frequency) + (count > 0 ? presence : 0)
    );

    // The Float32Array store performs the final float32 rounding.
    logits[tokenId] = scaled - subtract;
  }
}