Skip to content

Commit 97c3d0b

Browse files
committed
refactor: simplify compression timing state
1 parent d9a43c2 commit 97c3d0b

6 files changed

Lines changed: 25 additions & 86 deletions

File tree

lib/compress/pipeline.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import { sendCompressNotification } from "../ui/notification"
99
import type { ToolContext } from "./types"
1010
import { buildSearchContext, fetchSessionMessages } from "./search"
1111
import type { SearchContext } from "./types"
12+
import { applyPendingCompressionDurations } from "./timing"
1213

1314
interface RunContext {
1415
ask(input: {
@@ -83,6 +84,7 @@ export async function finalizeSession(
8384
batchTopic: string | undefined,
8485
): Promise<void> {
8586
ctx.state.manualMode = ctx.state.manualMode ? "active" : false
87+
applyPendingCompressionDurations(ctx.state)
8688
await saveSessionState(ctx.state, ctx.logger)
8789

8890
const params = getCurrentParams(ctx.state, rawMessages, ctx.logger)

lib/compress/state.ts

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ export function allocateRunId(state: SessionState): number {
2929
export function attachCompressionDuration(
3030
messagesState: PruneMessagesState,
3131
callId: string,
32-
messageId: string,
3332
durationMs: number,
3433
): number {
3534
if (typeof durationMs !== "number" || !Number.isFinite(durationMs)) {
@@ -38,9 +37,7 @@ export function attachCompressionDuration(
3837

3938
let updates = 0
4039
for (const block of messagesState.blocksById.values()) {
41-
const matchesCall = block.compressCallId === callId
42-
const matchesMessage = !block.compressCallId && block.compressMessageId === messageId
43-
if (!matchesCall && !matchesMessage) {
40+
if (block.compressCallId !== callId) {
4441
continue
4542
}
4643

lib/compress/timing.ts

Lines changed: 13 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,37 @@
11
import type { SessionState } from "../state/types"
22
import { attachCompressionDuration } from "./state"
33

4-
export interface CompressionStart {
5-
sessionId: string
6-
messageId: string
7-
startedAt: number
8-
}
9-
104
export interface PendingCompressionDuration {
115
callId: string
12-
messageId: string
136
durationMs: number
147
}
158

169
export interface CompressionTimingState {
17-
startsByCallId: Map<string, CompressionStart>
18-
pendingBySessionId: Map<string, PendingCompressionDuration[]>
10+
startsByCallId: Map<string, number>
11+
pendingByCallId: Map<string, PendingCompressionDuration>
1912
}
2013

2114
export function createCompressionTimingState(): CompressionTimingState {
2215
return {
2316
startsByCallId: new Map(),
24-
pendingBySessionId: new Map(),
17+
pendingByCallId: new Map(),
2518
}
2619
}
2720

2821
export function recordCompressionStart(
2922
state: SessionState,
3023
callId: string,
31-
sessionId: string,
32-
messageId: string,
3324
startedAt: number,
3425
): boolean {
3526
if (state.compressionTiming.startsByCallId.has(callId)) {
3627
return false
3728
}
3829

39-
state.compressionTiming.startsByCallId.set(callId, {
40-
sessionId,
41-
messageId,
42-
startedAt,
43-
})
30+
state.compressionTiming.startsByCallId.set(callId, startedAt)
4431
return true
4532
}
4633

47-
export function consumeCompressionStart(
48-
state: SessionState,
49-
callId: string,
50-
): CompressionStart | undefined {
34+
export function consumeCompressionStart(state: SessionState, callId: string): number | undefined {
5135
const start = state.compressionTiming.startsByCallId.get(callId)
5236
state.compressionTiming.startsByCallId.delete(callId)
5337
return start
@@ -58,7 +42,7 @@ export function clearCompressionStart(state: SessionState, callId: string): void
5842
}
5943

6044
export function resolveCompressionDuration(
61-
start: CompressionStart | undefined,
45+
startedAt: number | undefined,
6246
eventTime: number | undefined,
6347
partTime: { start?: unknown; end?: unknown } | undefined,
6448
): number | undefined {
@@ -67,8 +51,8 @@ export function resolveCompressionDuration(
6751
? partTime.start
6852
: eventTime
6953
const pendingToRunningMs =
70-
start && typeof runningAt === "number"
71-
? Math.max(0, runningAt - start.startedAt)
54+
typeof startedAt === "number" && typeof runningAt === "number"
55+
? Math.max(0, runningAt - startedAt)
7256
: undefined
7357

7458
const toolStart = partTime?.start
@@ -86,43 +70,28 @@ export function resolveCompressionDuration(
8670

8771
export function queueCompressionDuration(
8872
state: SessionState,
89-
sessionId: string,
9073
callId: string,
91-
messageId: string,
9274
durationMs: number,
9375
): void {
94-
const queued = state.compressionTiming.pendingBySessionId.get(sessionId) || []
95-
const filtered = queued.filter((entry) => entry.callId !== callId)
96-
filtered.push({ callId, messageId, durationMs })
97-
state.compressionTiming.pendingBySessionId.set(sessionId, filtered)
76+
state.compressionTiming.pendingByCallId.set(callId, { callId, durationMs })
9877
}
9978

100-
export function applyPendingCompressionDurations(state: SessionState, sessionId: string): number {
101-
const queued = state.compressionTiming.pendingBySessionId.get(sessionId)
102-
if (!queued || queued.length === 0) {
79+
export function applyPendingCompressionDurations(state: SessionState): number {
80+
if (state.compressionTiming.pendingByCallId.size === 0) {
10381
return 0
10482
}
10583

10684
let updates = 0
107-
const remaining = []
108-
for (const entry of queued) {
85+
for (const [callId, entry] of state.compressionTiming.pendingByCallId) {
10986
const applied = attachCompressionDuration(
11087
state.prune.messages,
11188
entry.callId,
112-
entry.messageId,
11389
entry.durationMs,
11490
)
11591
if (applied > 0) {
11692
updates += applied
117-
continue
93+
state.compressionTiming.pendingByCallId.delete(callId)
11894
}
119-
remaining.push(entry)
120-
}
121-
122-
if (remaining.length > 0) {
123-
state.compressionTiming.pendingBySessionId.set(sessionId, remaining)
124-
} else {
125-
state.compressionTiming.pendingBySessionId.delete(sessionId)
12695
}
12796

12897
return updates

lib/hooks.ts

Lines changed: 5 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -276,12 +276,6 @@ export function createTextCompleteHandler() {
276276

277277
export function createEventHandler(state: SessionState, logger: Logger) {
278278
return async (input: { event: any }) => {
279-
const eventSessionId =
280-
typeof input.event?.properties?.sessionID === "string"
281-
? input.event.properties.sessionID
282-
: typeof input.event?.properties?.part?.sessionID === "string"
283-
? input.event.properties.part.sessionID
284-
: undefined
285279
const eventTime =
286280
typeof input.event?.time === "number" && Number.isFinite(input.event.time)
287281
? input.event.time
@@ -300,41 +294,23 @@ export function createEventHandler(state: SessionState, logger: Logger) {
300294
}
301295

302296
if (part.state.status === "pending") {
303-
if (
304-
typeof part.callID !== "string" ||
305-
typeof part.messageID !== "string" ||
306-
typeof eventSessionId !== "string"
307-
) {
297+
if (typeof part.callID !== "string") {
308298
return
309299
}
310300

311301
const startedAt = eventTime ?? Date.now()
312-
if (
313-
!recordCompressionStart(
314-
state,
315-
part.callID,
316-
eventSessionId,
317-
part.messageID,
318-
startedAt,
319-
)
320-
) {
302+
if (!recordCompressionStart(state, part.callID, startedAt)) {
321303
return
322304
}
323305
logger.debug("Recorded compression start", {
324-
sessionID: eventSessionId,
325306
callID: part.callID,
326-
messageID: part.messageID,
327307
startedAt,
328308
})
329309
return
330310
}
331311

332312
if (part.state.status === "completed") {
333-
if (
334-
typeof part.callID !== "string" ||
335-
typeof part.messageID !== "string" ||
336-
typeof eventSessionId !== "string"
337-
) {
313+
if (typeof part.callID !== "string") {
338314
return
339315
}
340316

@@ -344,22 +320,17 @@ export function createEventHandler(state: SessionState, logger: Logger) {
344320
return
345321
}
346322

347-
queueCompressionDuration(state, eventSessionId, part.callID, part.messageID, durationMs)
323+
queueCompressionDuration(state, part.callID, durationMs)
348324

349-
const updates =
350-
state.sessionId === eventSessionId
351-
? applyPendingCompressionDurations(state, eventSessionId)
352-
: 0
325+
const updates = applyPendingCompressionDurations(state)
353326
if (updates === 0) {
354327
return
355328
}
356329

357330
await saveSessionState(state, logger)
358331

359332
logger.info("Attached compression time to blocks", {
360-
sessionID: eventSessionId,
361333
callID: part.callID,
362-
messageID: part.messageID,
363334
blocks: updates,
364335
durationMs,
365336
})

lib/state/state.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ export async function ensureSessionInitialized(
180180
totalPruneTokens: persisted.stats?.totalPruneTokens || 0,
181181
}
182182

183-
const applied = applyPendingCompressionDurations(state, sessionId)
183+
const applied = applyPendingCompressionDurations(state)
184184
if (applied > 0) {
185185
await saveSessionState(state, logger)
186186
}

tests/hooks-permission.test.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ test("event hook falls back to completed runtime when running duration missing",
372372
endId: "m0001",
373373
anchorMessageId: "msg-a",
374374
compressMessageId: "message-1",
375-
compressCallId: undefined,
375+
compressCallId: "call-3",
376376
includedBlockIds: [],
377377
consumedBlockIds: [],
378378
parentBlockIds: [],
@@ -492,7 +492,7 @@ test("event hook queues duration updates until the matching session is loaded",
492492
},
493493
})
494494

495-
assert.equal(liveState.compressionTiming.pendingBySessionId.get(targetSessionId)?.length, 1)
495+
assert.equal(liveState.compressionTiming.pendingByCallId.has("call-remote"), true)
496496
assert.equal(liveState.compressionTiming.startsByCallId.has("call-remote"), false)
497497

498498
await ensureSessionInitialized(
@@ -520,5 +520,5 @@ test("event hook queues duration updates until the matching session is loaded",
520520
)
521521

522522
assert.equal(liveState.prune.messages.blocksById.get(1)?.durationMs, 250)
523-
assert.equal(liveState.compressionTiming.pendingBySessionId.has(targetSessionId), false)
523+
assert.equal(liveState.compressionTiming.pendingByCallId.has("call-remote"), false)
524524
})

0 commit comments

Comments
 (0)