Skip to content

Commit 0072c7d

Browse files
committed
fix: track compression durations across sessions
1 parent 37ed1dd commit 0072c7d

File tree

7 files changed

+311
-80
lines changed

7 files changed

+311
-80
lines changed

lib/compress/state.ts

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,31 +26,6 @@ export function allocateRunId(state: SessionState): number {
2626
return next
2727
}
2828

29-
export function attachCompressionDuration(
30-
state: SessionState,
31-
callId: string,
32-
messageId: string,
33-
durationMs: number,
34-
): number {
35-
if (typeof durationMs !== "number" || !Number.isFinite(durationMs)) {
36-
return 0
37-
}
38-
39-
let updates = 0
40-
for (const block of state.prune.messages.blocksById.values()) {
41-
const matchesCall = block.compressCallId === callId
42-
const matchesMessage = !block.compressCallId && block.compressMessageId === messageId
43-
if (!matchesCall && !matchesMessage) {
44-
continue
45-
}
46-
47-
block.durationMs = durationMs
48-
updates++
49-
}
50-
51-
return updates
52-
}
53-
5429
export function wrapCompressedSummary(blockId: number, summary: string): string {
5530
const header = COMPRESSED_BLOCK_HEADER
5631
const footer = formatMessageIdTag(formatBlockRef(blockId))

lib/hooks.ts

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import {
1616
} from "./messages"
1717
import { renderSystemPrompt, type PromptStore } from "./prompts"
1818
import { buildProtectedToolsExtension } from "./prompts/extensions/system"
19-
import { attachCompressionDuration } from "./compress/state"
2019
import {
2120
applyPendingManualTrigger,
2221
handleContextCommand,
@@ -30,7 +29,14 @@ import {
3029
} from "./commands"
3130
import { type HostPermissionSnapshot } from "./host-permissions"
3231
import { compressPermission, syncCompressPermissionState } from "./compress-permission"
33-
import { checkSession, ensureSessionInitialized, saveSessionState, syncToolCache } from "./state"
32+
import {
33+
checkSession,
34+
ensureSessionInitialized,
35+
applyPendingCompressionDurations,
36+
queueCompressionDuration,
37+
saveSessionState,
38+
syncToolCache,
39+
} from "./state"
3440
import { cacheSystemPromptTokens } from "./ui/utils"
3541

3642
const INTERNAL_AGENT_SIGNATURES = [
@@ -269,6 +275,12 @@ export function createTextCompleteHandler() {
269275

270276
export function createEventHandler(state: SessionState, logger: Logger) {
271277
return async (input: { event: any }) => {
278+
const eventSessionId =
279+
typeof input.event?.properties?.sessionID === "string"
280+
? input.event.properties.sessionID
281+
: typeof input.event?.properties?.part?.sessionID === "string"
282+
? input.event.properties.part.sessionID
283+
: undefined
272284
const eventTime =
273285
typeof input.event?.time === "number" && Number.isFinite(input.event.time)
274286
? input.event.time
@@ -287,20 +299,26 @@ export function createEventHandler(state: SessionState, logger: Logger) {
287299
}
288300

289301
if (part.state.status === "pending") {
290-
if (typeof part.callID !== "string" || typeof part.messageID !== "string") {
302+
if (
303+
typeof part.callID !== "string" ||
304+
typeof part.messageID !== "string" ||
305+
typeof eventSessionId !== "string"
306+
) {
291307
return
292308
}
293309

294-
if (state.compressionStarts.has(part.callID)) {
310+
if (state.compressionTiming.startsByCallId.has(part.callID)) {
295311
return
296312
}
297313

298314
const startedAt = eventTime ?? Date.now()
299-
state.compressionStarts.set(part.callID, {
315+
state.compressionTiming.startsByCallId.set(part.callID, {
316+
sessionId: eventSessionId,
300317
messageId: part.messageID,
301318
startedAt,
302319
})
303320
logger.debug("Recorded compression start", {
321+
sessionID: eventSessionId,
304322
callID: part.callID,
305323
messageID: part.messageID,
306324
startedAt,
@@ -309,12 +327,16 @@ export function createEventHandler(state: SessionState, logger: Logger) {
309327
}
310328

311329
if (part.state.status === "completed") {
312-
if (typeof part.callID !== "string" || typeof part.messageID !== "string") {
330+
if (
331+
typeof part.callID !== "string" ||
332+
typeof part.messageID !== "string" ||
333+
typeof eventSessionId !== "string"
334+
) {
313335
return
314336
}
315337

316-
const start = state.compressionStarts.get(part.callID)
317-
state.compressionStarts.delete(part.callID)
338+
const start = state.compressionTiming.startsByCallId.get(part.callID)
339+
state.compressionTiming.startsByCallId.delete(part.callID)
318340

319341
const runningAt =
320342
typeof part.state.time?.start === "number" && Number.isFinite(part.state.time.start)
@@ -341,28 +363,25 @@ export function createEventHandler(state: SessionState, logger: Logger) {
341363
return
342364
}
343365

344-
const updates = attachCompressionDuration(
345-
state,
346-
part.callID,
347-
part.messageID,
348-
durationMs,
349-
)
366+
queueCompressionDuration(state, eventSessionId, part.callID, part.messageID, durationMs)
367+
368+
const updates =
369+
state.sessionId === eventSessionId
370+
? applyPendingCompressionDurations(state, eventSessionId)
371+
: 0
350372
if (updates === 0) {
351373
return
352374
}
353375

376+
await saveSessionState(state, logger)
377+
354378
logger.info("Attached compression time to blocks", {
379+
sessionID: eventSessionId,
355380
callID: part.callID,
356381
messageID: part.messageID,
357382
blocks: updates,
358383
durationMs,
359384
})
360-
361-
saveSessionState(state, logger).catch((error) => {
362-
logger.warn("Failed to persist compression time update", {
363-
error: error instanceof Error ? error.message : String(error),
364-
})
365-
})
366385
return
367386
}
368387

@@ -371,7 +390,7 @@ export function createEventHandler(state: SessionState, logger: Logger) {
371390
}
372391

373392
if (typeof part.callID === "string") {
374-
state.compressionStarts.delete(part.callID)
393+
state.compressionTiming.startsByCallId.delete(part.callID)
375394
}
376395
}
377396
}

lib/state/persistence.ts

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import { homedir } from "os"
1010
import { join } from "path"
1111
import type { CompressionBlock, PrunedMessageEntry, SessionState, SessionStats } from "./types"
1212
import type { Logger } from "../logger"
13+
import { serializePruneMessagesState } from "./utils"
1314

1415
/** Prune state as stored on disk */
1516
export interface PersistedPruneMessagesState {
@@ -58,6 +59,23 @@ function getSessionFilePath(sessionId: string): string {
5859
return join(STORAGE_DIR, `${sessionId}.json`)
5960
}
6061

62+
async function writePersistedSessionState(
63+
sessionId: string,
64+
state: PersistedSessionState,
65+
logger: Logger,
66+
): Promise<void> {
67+
await ensureStorageDir()
68+
69+
const filePath = getSessionFilePath(sessionId)
70+
const content = JSON.stringify(state, null, 2)
71+
await fs.writeFile(filePath, content, "utf-8")
72+
73+
logger.info("Saved session state to disk", {
74+
sessionId,
75+
totalTokensSaved: state.stats.totalPruneTokens,
76+
})
77+
}
78+
6179
export async function saveSessionState(
6280
sessionState: SessionState,
6381
logger: Logger,
@@ -68,26 +86,11 @@ export async function saveSessionState(
6886
return
6987
}
7088

71-
await ensureStorageDir()
72-
7389
const state: PersistedSessionState = {
7490
sessionName: sessionName,
7591
prune: {
7692
tools: Object.fromEntries(sessionState.prune.tools),
77-
messages: {
78-
byMessageId: Object.fromEntries(sessionState.prune.messages.byMessageId),
79-
blocksById: Object.fromEntries(
80-
Array.from(sessionState.prune.messages.blocksById.entries()).map(
81-
([blockId, block]) => [String(blockId), block],
82-
),
83-
),
84-
activeBlockIds: Array.from(sessionState.prune.messages.activeBlockIds),
85-
activeByAnchorMessageId: Object.fromEntries(
86-
sessionState.prune.messages.activeByAnchorMessageId,
87-
),
88-
nextBlockId: sessionState.prune.messages.nextBlockId,
89-
nextRunId: sessionState.prune.messages.nextRunId,
90-
},
93+
messages: serializePruneMessagesState(sessionState.prune.messages),
9194
},
9295
nudges: {
9396
contextLimitAnchors: Array.from(sessionState.nudges.contextLimitAnchors),
@@ -98,14 +101,7 @@ export async function saveSessionState(
98101
lastUpdated: new Date().toISOString(),
99102
}
100103

101-
const filePath = getSessionFilePath(sessionState.sessionId)
102-
const content = JSON.stringify(state, null, 2)
103-
await fs.writeFile(filePath, content, "utf-8")
104-
105-
logger.info("Saved session state to disk", {
106-
sessionId: sessionState.sessionId,
107-
totalTokensSaved: state.stats.totalPruneTokens,
108-
})
104+
await writePersistedSessionState(sessionState.sessionId, state, logger)
109105
} catch (error: any) {
110106
logger.error("Failed to save session state", {
111107
sessionId: sessionState.sessionId,

lib/state/state.ts

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
import type { CompressionStart, SessionState, ToolParameterEntry, WithParts } from "./types"
1+
import type { CompressionTimingState, SessionState, ToolParameterEntry, WithParts } from "./types"
22
import type { Logger } from "../logger"
33
import { loadSessionState, saveSessionState } from "./persistence"
44
import {
5+
attachCompressionDuration,
56
isSubAgentSession,
67
findLastCompactionTimestamp,
78
countTurns,
@@ -13,6 +14,13 @@ import {
1314
} from "./utils"
1415
import { getLastUserMessage } from "../messages/query"
1516

17+
function createCompressionTimingState(): CompressionTimingState {
18+
return {
19+
startsByCallId: new Map(),
20+
pendingBySessionId: new Map(),
21+
}
22+
}
23+
1624
export const checkSession = async (
1725
client: any,
1826
state: SessionState,
@@ -43,6 +51,17 @@ export const checkSession = async (
4351
}
4452
}
4553

54+
if (state.sessionId === lastSessionId) {
55+
const applied = applyPendingCompressionDurations(state, lastSessionId)
56+
if (applied > 0) {
57+
saveSessionState(state, logger).catch((error) => {
58+
logger.warn("Failed to persist queued compression time updates", {
59+
error: error instanceof Error ? error.message : String(error),
60+
})
61+
})
62+
}
63+
}
64+
4665
const lastCompactionTimestamp = findLastCompactionTimestamp(messages)
4766
if (lastCompactionTimestamp > state.lastCompaction) {
4867
state.lastCompaction = lastCompactionTimestamp
@@ -81,7 +100,7 @@ export function createSessionState(): SessionState {
81100
pruneTokenCounter: 0,
82101
totalPruneTokens: 0,
83102
},
84-
compressionStarts: new Map<string, CompressionStart>(),
103+
compressionTiming: createCompressionTimingState(),
85104
toolParameters: new Map<string, ToolParameterEntry>(),
86105
subAgentResultCache: new Map<string, string>(),
87106
toolIdList: [],
@@ -178,4 +197,53 @@ export async function ensureSessionInitialized(
178197
pruneTokenCounter: persisted.stats?.pruneTokenCounter || 0,
179198
totalPruneTokens: persisted.stats?.totalPruneTokens || 0,
180199
}
200+
201+
const applied = applyPendingCompressionDurations(state, sessionId)
202+
if (applied > 0) {
203+
await saveSessionState(state, logger)
204+
}
205+
}
206+
207+
export function queueCompressionDuration(
208+
state: SessionState,
209+
sessionId: string,
210+
callId: string,
211+
messageId: string,
212+
durationMs: number,
213+
): void {
214+
const queued = state.compressionTiming.pendingBySessionId.get(sessionId) || []
215+
const filtered = queued.filter((entry) => entry.callId !== callId)
216+
filtered.push({ callId, messageId, durationMs })
217+
state.compressionTiming.pendingBySessionId.set(sessionId, filtered)
218+
}
219+
220+
export function applyPendingCompressionDurations(state: SessionState, sessionId: string): number {
221+
const queued = state.compressionTiming.pendingBySessionId.get(sessionId)
222+
if (!queued || queued.length === 0) {
223+
return 0
224+
}
225+
226+
let updates = 0
227+
const remaining = []
228+
for (const entry of queued) {
229+
const applied = attachCompressionDuration(
230+
state.prune.messages,
231+
entry.callId,
232+
entry.messageId,
233+
entry.durationMs,
234+
)
235+
if (applied > 0) {
236+
updates += applied
237+
continue
238+
}
239+
remaining.push(entry)
240+
}
241+
242+
if (remaining.length > 0) {
243+
state.compressionTiming.pendingBySessionId.set(sessionId, remaining)
244+
} else {
245+
state.compressionTiming.pendingBySessionId.delete(sessionId)
246+
}
247+
248+
return updates
181249
}

lib/state/types.ts

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,22 @@ export interface SessionStats {
2222
}
2323

2424
export interface CompressionStart {
25+
sessionId: string
2526
messageId: string
2627
startedAt: number
2728
}
2829

30+
export interface PendingCompressionDuration {
31+
callId: string
32+
messageId: string
33+
durationMs: number
34+
}
35+
36+
export interface CompressionTimingState {
37+
startsByCallId: Map<string, CompressionStart>
38+
pendingBySessionId: Map<string, PendingCompressionDuration[]>
39+
}
40+
2941
export interface PrunedMessageEntry {
3042
tokenCount: number
3143
allBlockIds: number[]
@@ -103,7 +115,7 @@ export interface SessionState {
103115
prune: Prune
104116
nudges: Nudges
105117
stats: SessionStats
106-
compressionStarts: Map<string, CompressionStart>
118+
compressionTiming: CompressionTimingState
107119
toolParameters: Map<string, ToolParameterEntry>
108120
subAgentResultCache: Map<string, string>
109121
toolIdList: string[]

0 commit comments

Comments
 (0)