-
Notifications
You must be signed in to change notification settings - Fork 19
Improve history sanitization #290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
b159280
Remove only corrupted message during sanitization
brichet 3f119f9
Prevent adding corrupted messages sequence in history instead of fixi…
brichet 331310d
Merge branch 'main' into history-sanitization
brichet dc1dfd0
Merge branch 'main' into history-sanitization
brichet 778ca31
Merge branch 'main' into history-sanitization
brichet File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| import { createMCPClient, type MCPClient } from '@ai-sdk/mcp'; | ||
| import type { IMessageContent } from '@jupyter/chat'; | ||
| import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; | ||
| import { PromiseDelegate } from '@lumino/coreutils'; | ||
| import { ISignal, Signal } from '@lumino/signaling'; | ||
| import { | ||
| ToolLoopAgent, | ||
|
|
@@ -11,7 +12,8 @@ import { | |
| type ToolApprovalRequestOutput, | ||
| type TypedToolError, | ||
| type TypedToolOutputDenied, | ||
| type TypedToolResult | ||
| type TypedToolResult, | ||
| type AssistantModelMessage | ||
| } from 'ai'; | ||
| import { ISecretsManager } from 'jupyter-secrets-manager'; | ||
|
|
||
|
|
@@ -324,6 +326,7 @@ export class AgentManager implements IAgentManager { | |
| this._skills = []; | ||
| this._agentConfig = null; | ||
| this._renderMimeRegistry = options.renderMimeRegistry; | ||
| this._streaming.resolve(); | ||
|
|
||
| this.activeProvider = | ||
| options.activeProvider ?? this._settingsModel.config.defaultProvider; | ||
|
|
@@ -452,19 +455,11 @@ export class AgentManager implements IAgentManager { | |
| /** | ||
| * Clears conversation history and resets agent state. | ||
| */ | ||
| clearHistory(): void { | ||
| async clearHistory(): Promise<void> { | ||
| // Stop any ongoing streaming | ||
| this.stopStreaming(); | ||
| this.stopStreaming('Chat cleared'); | ||
|
|
||
| // Reject any pending approvals | ||
| for (const [approvalId, pending] of this._pendingApprovals) { | ||
| pending.resolve(false, 'Chat cleared'); | ||
| this._agentEvent.emit({ | ||
| type: 'tool_approval_resolved', | ||
| data: { approvalId, approved: false } | ||
| }); | ||
| } | ||
| this._pendingApprovals.clear(); | ||
| await this._streaming.promise; | ||
|
|
||
| // Clear history and token usage | ||
| this._history = []; | ||
|
|
@@ -502,9 +497,20 @@ export class AgentManager implements IAgentManager { | |
|
|
||
| /** | ||
| * Stops the current streaming response by aborting the request. | ||
| * Resolve any pending approval. | ||
| */ | ||
| stopStreaming(): void { | ||
| stopStreaming(reason?: string): void { | ||
| this._controller?.abort(); | ||
|
|
||
| // Reject any pending approvals | ||
| for (const [approvalId, pending] of this._pendingApprovals) { | ||
| pending.resolve(false, reason ?? 'Stream ended by user'); | ||
| this._agentEvent.emit({ | ||
| type: 'tool_approval_resolved', | ||
| data: { approvalId, approved: false } | ||
| }); | ||
| } | ||
| this._pendingApprovals.clear(); | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -547,8 +553,9 @@ export class AgentManager implements IAgentManager { | |
| * @param message The user message to respond to (may include processed attachment content) | ||
| */ | ||
| async generateResponse(message: string): Promise<void> { | ||
| this._streaming = new PromiseDelegate(); | ||
| this._controller = new AbortController(); | ||
|
|
||
| const responseHistory: ModelMessage[] = []; | ||
| try { | ||
| // Ensure we have an agent | ||
| if (!this._agent) { | ||
|
|
@@ -560,15 +567,15 @@ export class AgentManager implements IAgentManager { | |
| } | ||
|
|
||
| // Add user message to history | ||
| this._history.push({ | ||
| responseHistory.push({ | ||
| role: 'user', | ||
| content: message | ||
| }); | ||
|
|
||
| let continueLoop = true; | ||
| while (continueLoop) { | ||
| const result = await this._agent.stream({ | ||
| messages: this._history, | ||
| messages: [...this._history, ...responseHistory], | ||
| abortSignal: this._controller.signal | ||
| }); | ||
|
|
||
|
|
@@ -580,15 +587,13 @@ export class AgentManager implements IAgentManager { | |
|
|
||
| // Add response messages to history | ||
| if (responseMessages.messages?.length) { | ||
| this._history.push( | ||
| ...Private.sanitizeModelMessages(responseMessages.messages) | ||
| ); | ||
| responseHistory.push(...responseMessages.messages); | ||
| } | ||
|
|
||
| // Add approval response if processed | ||
| if (streamResult.approvalResponse) { | ||
| // Check if the last message is a tool message we can append to | ||
| const lastMsg = this._history[this._history.length - 1]; | ||
| const lastMsg = responseHistory[responseHistory.length - 1]; | ||
| if ( | ||
| lastMsg && | ||
| lastMsg.role === 'tool' && | ||
|
|
@@ -599,24 +604,25 @@ export class AgentManager implements IAgentManager { | |
| toolContent.push(...streamResult.approvalResponse.content); | ||
| } else { | ||
| // Add as separate message | ||
| this._history.push(streamResult.approvalResponse); | ||
| responseHistory.push(streamResult.approvalResponse); | ||
| } | ||
| } | ||
|
|
||
| continueLoop = streamResult.approvalProcessed; | ||
| } | ||
|
|
||
| // Add the messages to the history only if the response ended without error. | ||
| this._history.push(...Private.sanitizeModelMessages(responseHistory)); | ||
| } catch (error) { | ||
| if ((error as Error).name !== 'AbortError') { | ||
| this._agentEvent.emit({ | ||
| type: 'error', | ||
| data: { error: error as Error } | ||
| }); | ||
| } | ||
| // After an error (including AbortError), sanitize the history | ||
| // to remove any trailing assistant messages without tool results | ||
| this._sanitizeHistory(); | ||
| } finally { | ||
| this._controller = null; | ||
| this._streaming.resolve(); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -1149,96 +1155,6 @@ WEB RETRIEVAL POLICY: | |
| return `Supported MIME types in this session: ${safeMimeTypes.join(', ')}`; | ||
| } | ||
|
|
||
| /** | ||
| * Sanitizes history to ensure it's in a valid state in case of abort or error. | ||
| */ | ||
| private _sanitizeHistory(): void { | ||
| if (this._history.length === 0) { | ||
| return; | ||
| } | ||
|
|
||
| const newHistory: ModelMessage[] = []; | ||
| for (let i = 0; i < this._history.length; i++) { | ||
| const msg = this._history[i]; | ||
|
|
||
| if (msg.role === 'assistant') { | ||
| const toolCallIds = this._getToolCallIds(msg); | ||
| if (toolCallIds.length > 0) { | ||
| // Find if there's a following tool message with results for these calls | ||
| const nextMsg = this._history[i + 1]; | ||
| if ( | ||
| nextMsg && | ||
| nextMsg.role === 'tool' && | ||
| this._matchesAllToolCalls(nextMsg, toolCallIds) | ||
| ) { | ||
| newHistory.push(msg); | ||
| } else { | ||
| // Message has unmatched tool calls drop it and everything after it | ||
| break; | ||
| } | ||
| } else { | ||
| newHistory.push(msg); | ||
| } | ||
| } else if (msg.role === 'tool') { | ||
| // Tool messages are valid if they were preceded by a valid assistant message | ||
| newHistory.push(msg); | ||
| } else { | ||
| newHistory.push(msg); | ||
| } | ||
| } | ||
|
|
||
| this._history = newHistory; | ||
| } | ||
|
|
||
| /** | ||
| * Extracts tool call IDs from a message | ||
| */ | ||
| private _getToolCallIds(message: ModelMessage): string[] { | ||
| const ids: string[] = []; | ||
|
|
||
| // Check content array for tool-call parts | ||
| if (Array.isArray(message.content)) { | ||
| for (const part of message.content) { | ||
| if ( | ||
| typeof part === 'object' && | ||
| part !== null && | ||
| 'type' in part && | ||
| part.type === 'tool-call' | ||
| ) { | ||
| ids.push(part.toolCallId); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| return ids; | ||
| } | ||
|
|
||
| /** | ||
| * Checks if a tool message contains results for all specified tool call IDs | ||
| */ | ||
| private _matchesAllToolCalls( | ||
| message: ModelMessage, | ||
| callIds: string[] | ||
| ): boolean { | ||
| if (message.role !== 'tool' || !Array.isArray(message.content)) { | ||
| return false; | ||
| } | ||
|
|
||
| const resultIds = new Set<string>(); | ||
| for (const part of message.content) { | ||
| if ( | ||
| typeof part === 'object' && | ||
| part !== null && | ||
| 'type' in part && | ||
| part.type === 'tool-result' | ||
| ) { | ||
| resultIds.add(part.toolCallId); | ||
| } | ||
| } | ||
|
|
||
| return callIds.every(id => resultIds.has(id)); | ||
| } | ||
|
|
||
| // Private attributes | ||
| private _settingsModel: IAISettingsModel; | ||
| private _toolRegistry?: IToolRegistry; | ||
|
|
@@ -1263,25 +1179,123 @@ WEB RETRIEVAL POLICY: | |
| string, | ||
| { resolve: (approved: boolean, reason?: string) => void } | ||
| > = new Map(); | ||
| private _streaming: PromiseDelegate<void> = new PromiseDelegate(); | ||
| } | ||
|
|
||
| namespace Private { | ||
| /** | ||
| * Keep only serializable messages by doing a JSON round-trip. | ||
| * Messages that cannot be serialized are dropped. | ||
| * Sanitize the messages before adding them to the history. | ||
| * | ||
| * 1- Make sure the message sequence is not altered: | ||
| * - tool-call messages should have a corresponding tool-result (and vice-versa) | ||
| * - tool-approval-request should have a tool-approval-response (and vice-versa) | ||
| * | ||
| * 2- Keep only serializable messages by doing a JSON round-trip. | ||
| * Messages that cannot be serialized are dropped. | ||
| */ | ||
| export const sanitizeModelMessages = ( | ||
| messages: ModelMessage[] | ||
| ): ModelMessage[] => { | ||
| const sanitized: ModelMessage[] = []; | ||
| for (const message of messages) { | ||
| try { | ||
| sanitized.push(JSON.parse(JSON.stringify(message))); | ||
| } catch { | ||
| // Drop messages that cannot be serialized | ||
| if (message.role === 'assistant') { | ||
| let newMessage: AssistantModelMessage | undefined; | ||
| if (!Array.isArray(message.content)) { | ||
| newMessage = message; | ||
| } else { | ||
| // Remove assistant message content without a required response. | ||
| const newContent: typeof message.content = []; | ||
| for (const assistantContent of message.content) { | ||
| let isContentValid = true; | ||
| if (assistantContent.type === 'tool-call') { | ||
| const toolCallId = assistantContent.toolCallId; | ||
| isContentValid = !!messages.find( | ||
| msg => | ||
| msg.role === 'tool' && | ||
| Array.isArray(msg.content) && | ||
| msg.content.find( | ||
| content => | ||
| content.type === 'tool-result' && | ||
| content.toolCallId === toolCallId | ||
| ) | ||
| ); | ||
| } else if (assistantContent.type === 'tool-approval-request') { | ||
| const approvalId = assistantContent.approvalId; | ||
| isContentValid = !!messages.find( | ||
| msg => | ||
| msg.role === 'tool' && | ||
| Array.isArray(msg.content) && | ||
| msg.content.find( | ||
| content => | ||
| content.type === 'tool-approval-response' && | ||
| content.approvalId === approvalId | ||
| ) | ||
| ); | ||
| } | ||
| if (isContentValid) { | ||
| newContent.push(assistantContent); | ||
| } | ||
| } | ||
| if (newContent.length) { | ||
| newMessage = { ...message, content: newContent }; | ||
| } | ||
| } | ||
| if (newMessage) { | ||
| try { | ||
| sanitized.push(JSON.parse(JSON.stringify(newMessage))); | ||
| } catch { | ||
| // Drop messages that cannot be serialized | ||
| } | ||
| } | ||
| } else if (message.role === 'tool') { | ||
| // Remove tool message content without request. | ||
| const newContent: typeof message.content = []; | ||
| for (const toolContent of message.content) { | ||
| let isContentValid = true; | ||
| if (toolContent.type === 'tool-result') { | ||
| const toolCallId = toolContent.toolCallId; | ||
| isContentValid = !!sanitized.find( | ||
| msg => | ||
| msg.role === 'assistant' && | ||
| Array.isArray(msg.content) && | ||
| msg.content.find( | ||
| content => | ||
| content.type === 'tool-call' && | ||
| content.toolCallId === toolCallId | ||
| ) | ||
| ); | ||
| } else if (toolContent.type === 'tool-approval-response') { | ||
| const approvalId = toolContent.approvalId; | ||
| isContentValid = !!sanitized.find( | ||
| msg => | ||
| msg.role === 'assistant' && | ||
| Array.isArray(msg.content) && | ||
| msg.content.find( | ||
| content => | ||
| content.type === 'tool-approval-request' && | ||
| content.approvalId === approvalId | ||
| ) | ||
| ); | ||
| } | ||
| if (isContentValid) { | ||
| newContent.push(toolContent); | ||
| } | ||
| } | ||
| if (newContent.length) { | ||
| try { | ||
| sanitized.push( | ||
| JSON.parse(JSON.stringify({ ...message, content: newContent })) | ||
| ); | ||
| } catch { | ||
| // Drop messages that cannot be serialized | ||
| } | ||
| } | ||
| } else { | ||
| // Message is a system or user message. | ||
| sanitized.push(message); | ||
| } | ||
| } | ||
| return sanitized; | ||
| return sanitized.length === messages.length ? sanitized : []; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does this mean all messages are dropped if only one of them is removed during sanitization?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is actually expected:
|
||
| }; | ||
|
|
||
| /** | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Claude found that this change may require updating the following to properly
`await` the function call here: ai/src/chat-model.ts
Line 188 in 62da3ae
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function is called asynchronously, on click.
AFAIK, awaiting it in the chat model would not prevent the user from sending a message... Or we should "lock" the input while waiting for it.
Do you think it is worth it?