Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/agents/AgentContext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,20 @@ export class AgentContext {
/** Current token type being processed */
currentTokenType: ContentTypes.TEXT | ContentTypes.THINK | 'think_and_text' =
ContentTypes.TEXT;
/**
* State machine for detecting fragmented thinking tags in streaming content.
* States:
* - 'normal': Processing regular text, looking for opening tags
* - 'buffering_open': Accumulating potential opening tag (<think> or <thinking>)
* - 'thinking': Inside thinking block, processing thinking content
* - 'buffering_close': Accumulating potential closing tag (</think> or </thinking>)
*/
thinkingState: 'normal' | 'buffering_open' | 'thinking' | 'buffering_close' =
'normal';
/** Buffer for accumulating potential thinking tags */
tagBuffer: string = '';
/** Current step ID for flushing buffer at stream end */
currentStepId?: string;
/** Whether tools should end the workflow */
toolEnd: boolean = false;
/** Cached system runnable (created lazily) */
Expand Down
8 changes: 6 additions & 2 deletions src/events.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import type { MultiAgentGraph, StandardGraph } from '@/graphs';
import type { Logger } from 'winston';
import type * as t from '@/types';
import { handleToolCalls } from '@/tools/handlers';
import { flushThinkingBuffer } from '@/stream';
import { Constants, Providers } from '@/common';

export class HandlerRegistry {
Expand Down Expand Up @@ -43,6 +44,11 @@ export class ModelEndHandler implements t.EventHandler {
return;
}

const agentContext = graph.getAgentContext(metadata);

// Flush any remaining buffered content from thinking tag state machine
await flushThinkingBuffer(agentContext, graph as StandardGraph);

const usage = data?.output?.usage_metadata;
if (usage != null && this.collectedUsage != null) {
this.collectedUsage.push(usage);
Expand All @@ -60,8 +66,6 @@ export class ModelEndHandler implements t.EventHandler {
{ depth: null }
);

const agentContext = graph.getAgentContext(metadata);

if (
agentContext.provider !== Providers.GOOGLE &&
agentContext.provider !== Providers.BEDROCK
Expand Down
277 changes: 277 additions & 0 deletions src/specs/fragmented-thinking.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
// src/specs/fragmented-thinking.test.ts
// Tests for fragmented <thinking> tag handling in streamed content
// This tests the state machine that detects thinking tags split across chunks

import { HumanMessage, MessageContentText } from '@langchain/core/messages';
import type { RunnableConfig } from '@langchain/core/runnables';
import type * as t from '@/types';
import { ChatModelStreamHandler, createContentAggregator } from '@/stream';
import { GraphEvents, Providers } from '@/common';
import { getLLMConfig } from '@/utils/llmConfig';
import { Run } from '@/run';

describe('Fragmented Thinking Tags Tests', () => {
  jest.setTimeout(30000);
  let run: Run<t.IState>;
  let contentParts: t.MessageContentComplex[];
  let aggregateContent: t.ContentAggregator;
  let runSteps: Set<string>;

  /**
   * Spy on ON_REASONING_DELTA events emitted by the thinking-tag state machine.
   * Declared once for the suite; cleared in `beforeEach` so no test can leak
   * call state into another (previously the first test relied on implicit
   * ordering and tests 2-4 each had to call mockClear() manually).
   */
  const onReasoningDeltaSpy = jest.fn();

  const config: Partial<RunnableConfig> & {
    version: 'v1' | 'v2';
    run_id?: string;
    streamMode: string;
  } = {
    configurable: {
      thread_id: 'fragmented-thinking-test',
    },
    streamMode: 'values',
    version: 'v2' as const,
    callbacks: [
      {
        async handleCustomEvent(event, data): Promise<void> {
          if (event !== GraphEvents.ON_MESSAGE_DELTA) {
            return;
          }
          const messageDeltaData = data as t.MessageDeltaEvent;

          // Message deltas can race ahead of their run-step event; poll
          // (up to ~5s) until the corresponding run step has been recorded
          // before aggregating, so content is attributed to a known step.
          const maxAttempts = 50;
          let attempts = 0;
          while (!runSteps.has(messageDeltaData.id) && attempts < maxAttempts) {
            await new Promise((resolve) => setTimeout(resolve, 100));
            attempts++;
          }

          aggregateContent({ event, data: messageDeltaData });
        },
      },
    ],
  };

  beforeEach(async () => {
    const { contentParts: parts, aggregateContent: ac } =
      createContentAggregator();
    aggregateContent = ac;
    runSteps = new Set();
    contentParts = parts as t.MessageContentComplex[];
    // Guarantee spy isolation for every test, regardless of execution order.
    onReasoningDeltaSpy.mockClear();
  });

  afterEach(() => {
    runSteps.clear();
  });

  afterAll(() => {
    onReasoningDeltaSpy.mockReset();
  });

  /**
   * Builds the event-handler map wired into each test Run. Every handler
   * forwards its event into the shared content aggregator; the reasoning
   * handler additionally records calls on `onReasoningDeltaSpy` so tests can
   * assert whether any thinking content was detected.
   */
  const setupCustomHandlers = (): Record<
    string | GraphEvents,
    t.EventHandler
  > => ({
    [GraphEvents.CHAT_MODEL_STREAM]: new ChatModelStreamHandler(),
    [GraphEvents.ON_RUN_STEP_COMPLETED]: {
      handle: (
        event: GraphEvents.ON_RUN_STEP_COMPLETED,
        data: t.StreamEventData
      ): void => {
        aggregateContent({
          event,
          data: data as unknown as { result: t.ToolEndEvent },
        });
      },
    },
    [GraphEvents.ON_RUN_STEP]: {
      handle: (
        event: GraphEvents.ON_RUN_STEP,
        data: t.StreamEventData
      ): void => {
        const runStepData = data as t.RunStep;
        // Record the step id so the ON_MESSAGE_DELTA callback above can
        // stop polling and aggregate the delta.
        runSteps.add(runStepData.id);
        aggregateContent({ event, data: runStepData });
      },
    },
    [GraphEvents.ON_RUN_STEP_DELTA]: {
      handle: (
        event: GraphEvents.ON_RUN_STEP_DELTA,
        data: t.StreamEventData
      ): void => {
        aggregateContent({ event, data: data as t.RunStepDeltaEvent });
      },
    },
    [GraphEvents.ON_REASONING_DELTA]: {
      handle: (
        event: GraphEvents.ON_REASONING_DELTA,
        data: t.StreamEventData
      ): void => {
        onReasoningDeltaSpy(event, data);
        aggregateContent({ event, data: data as t.ReasoningDeltaEvent });
      },
    },
  });

  /**
   * Creates a fresh standard-graph Run for one test, backed by the Bedrock
   * LLM config. Each test overrides the model output afterwards via
   * `overrideTestModel`, so no live provider call is made.
   */
  const createTestRun = async (
    customHandlers: Record<string | GraphEvents, t.EventHandler>
  ): Promise<Run<t.IState>> => {
    const llmConfig = getLLMConfig(Providers.BEDROCK);
    return Run.create<t.IState>({
      runId: `fragmented-thinking-test-run-${Date.now()}`,
      graphConfig: {
        type: 'standard',
        llmConfig,
        instructions: 'You are a helpful assistant.',
      },
      returnContent: true,
      customHandlers,
    });
  };

  // Test with <thinking> tags
  test('should handle <thinking> tags in streamed content', async () => {
    const customHandlers = setupCustomHandlers();
    run = await createTestRun(customHandlers);

    const responseWithThinkingTag =
      '<thinking> Let me think about this. </thinking> The answer is 42.';
    run.Graph?.overrideTestModel([responseWithThinkingTag], 2);

    const inputs = {
      messages: [new HumanMessage('What is the meaning of life?')],
    };

    await run.processStream(inputs, config);

    expect(contentParts).toBeDefined();
    expect(contentParts.length).toBe(2);

    const thinkingPart = contentParts.find(
      (p) => (p as t.ReasoningContentText).think !== undefined
    ) as t.ReasoningContentText;
    const textPart = contentParts.find(
      (p) => (p as MessageContentText).text !== undefined
    ) as MessageContentText;

    // Tags themselves must be stripped; only the inner content remains.
    expect(thinkingPart).toBeDefined();
    expect(thinkingPart.think).toContain('Let me think about this.');
    expect(thinkingPart.think).not.toContain('<thinking>');
    expect(thinkingPart.think).not.toContain('</thinking>');

    expect(textPart).toBeDefined();
    expect(textPart.text).toContain('The answer is 42.');
    expect(textPart.text).not.toContain('<thinking>');

    expect(onReasoningDeltaSpy).toHaveBeenCalled();
  });

  // Test with <think> tags (shorter variant)
  test('should handle <think> tags in streamed content', async () => {
    const customHandlers = setupCustomHandlers();
    run = await createTestRun(customHandlers);

    const responseWithThinkTag =
      '<think> Processing the question... </think> Here is my response.';
    run.Graph?.overrideTestModel([responseWithThinkTag], 2);

    const inputs = {
      messages: [new HumanMessage('Tell me something.')],
    };

    await run.processStream(inputs, config);

    expect(contentParts).toBeDefined();
    expect(contentParts.length).toBe(2);

    const thinkingPart = contentParts.find(
      (p) => (p as t.ReasoningContentText).think !== undefined
    ) as t.ReasoningContentText;
    const textPart = contentParts.find(
      (p) => (p as MessageContentText).text !== undefined
    ) as MessageContentText;

    expect(thinkingPart).toBeDefined();
    expect(thinkingPart.think).toContain('Processing the question...');
    expect(thinkingPart.think).not.toContain('<think>');
    expect(thinkingPart.think).not.toContain('</think>');

    expect(textPart).toBeDefined();
    expect(textPart.text).toContain('Here is my response.');
    expect(textPart.text).not.toContain('<think>');

    expect(onReasoningDeltaSpy).toHaveBeenCalled();
  });

  // Test with plain text (no thinking tags)
  test('should handle plain text without thinking tags', async () => {
    const customHandlers = setupCustomHandlers();
    run = await createTestRun(customHandlers);

    const responseWithoutTags =
      'This is a simple response without any thinking.';
    run.Graph?.overrideTestModel([responseWithoutTags], 2);

    const inputs = {
      messages: [new HumanMessage('Say something simple.')],
    };

    await run.processStream(inputs, config);

    expect(contentParts).toBeDefined();
    expect(contentParts.length).toBe(1);

    const textPart = contentParts[0] as MessageContentText;
    expect(textPart.text).toBe(
      'This is a simple response without any thinking.'
    );

    // No reasoning delta should be called for plain text
    expect(onReasoningDeltaSpy).not.toHaveBeenCalled();
  });

  // Test with multiple thinking blocks in sequence
  test('should handle multiple thinking blocks in sequence', async () => {
    const customHandlers = setupCustomHandlers();
    run = await createTestRun(customHandlers);

    const responseWithMultipleThinkingTags =
      '<thinking> First thought. </thinking> Response one. <thinking> Second thought. </thinking> Response two.';
    run.Graph?.overrideTestModel([responseWithMultipleThinkingTags], 2);

    const inputs = {
      messages: [new HumanMessage('Give me a complex response.')],
    };

    await run.processStream(inputs, config);

    expect(contentParts).toBeDefined();
    // Should have thinking and text parts (exact count depends on aggregation)
    expect(contentParts.length).toBeGreaterThanOrEqual(2);

    const thinkingPart = contentParts.find(
      (p) => (p as t.ReasoningContentText).think !== undefined
    ) as t.ReasoningContentText;
    const textPart = contentParts.find(
      (p) => (p as MessageContentText).text !== undefined
    ) as MessageContentText;

    // Verify thinking content contains both thoughts (accumulated)
    expect(thinkingPart).toBeDefined();
    expect(thinkingPart.think).toContain('First thought.');
    expect(thinkingPart.think).toContain('Second thought.');
    expect(thinkingPart.think).not.toContain('<thinking>');
    expect(thinkingPart.think).not.toContain('</thinking>');

    // Verify text content contains both responses
    expect(textPart).toBeDefined();
    expect(textPart.text).toContain('Response one.');
    expect(textPart.text).toContain('Response two.');
    expect(textPart.text).not.toContain('<thinking>');

    expect(onReasoningDeltaSpy).toHaveBeenCalled();
  });
});
3 changes: 2 additions & 1 deletion src/specs/reasoning.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ describe(`${capitalizeFirstLetter(provider)} Streaming Tests`, () => {
await run.processStream(inputs, config);
expect(contentParts).toBeDefined();
expect(contentParts.length).toBe(2);
const reasoningContent = reasoningText.match(/<think>(.*)<\/think>/s)?.[0];
// Tags are stripped from thinking content by the state machine
const reasoningContent = reasoningText.match(/<think>(.*)<\/think>/s)?.[1];
const content = reasoningText.split(/<\/think>/)[1];
expect((contentParts[0] as t.ReasoningContentText).think).toBe(
reasoningContent
Expand Down
Loading