diff --git a/.changeset/bright-glasses-happen.md b/.changeset/bright-glasses-happen.md new file mode 100644 index 000000000000..41930e4f56e5 --- /dev/null +++ b/.changeset/bright-glasses-happen.md @@ -0,0 +1,6 @@ +--- +"@ai-sdk/provider-utils": major +"ai": major +--- + +feat(ai): change type of experimental_context from unknown to generic diff --git a/examples/ai-e2e-next/app/api/chat/human-in-the-loop/utils.ts b/examples/ai-e2e-next/app/api/chat/human-in-the-loop/utils.ts index 561ecf49a683..752f4d99216c 100644 --- a/examples/ai-e2e-next/app/api/chat/human-in-the-loop/utils.ts +++ b/examples/ai-e2e-next/app/api/chat/human-in-the-loop/utils.ts @@ -51,7 +51,7 @@ export async function processToolCalls< executeFunctions: { [K in keyof Tools & keyof ExecutableTools]?: ( args: ExecutableTools[K] extends Tool ? P : never, - context: ToolExecutionOptions, + context: ToolExecutionOptions<{}>, ) => Promise; }, ): Promise { @@ -86,6 +86,7 @@ export async function processToolCalls< result = await toolInstance(part.input, { messages: await convertToModelMessages(messages), toolCallId: part.toolCallId, + experimental_context: {}, }); } else { result = 'Error: No execute function found on tool'; diff --git a/examples/ai-functions/src/generate-text/openai/tool-call-with-context.ts b/examples/ai-functions/src/generate-text/openai/tool-call-with-context.ts index 743ca0c8b363..2ac1cb6dbb06 100644 --- a/examples/ai-functions/src/generate-text/openai/tool-call-with-context.ts +++ b/examples/ai-functions/src/generate-text/openai/tool-call-with-context.ts @@ -5,17 +5,21 @@ import { run } from '../../lib/run'; run(async () => { const result = await generateText({ - model: openai('gpt-4o'), + model: openai('gpt-5-mini'), tools: { weather: tool({ description: 'Get the weather in a location', inputSchema: z.object({ location: z.string().describe('The location to get the weather for'), }), - execute: async ({ location }, { experimental_context: context }) => { - const typedContext = context as { weatherApiKey: string }; // or use type validation library - - console.log(typedContext); + contextSchema: z.object({ + weatherApiKey: z.string().describe('The API key for the weather API'), + }), + execute: async ( + { location }, + { experimental_context: { weatherApiKey } }, + ) => { + console.log('weather tool api key:', weatherApiKey); return { location, @@ -23,8 +27,39 @@ run(async () => { }; }, }), + calculator: tool({ + description: 'Calculate mathematical expressions', + inputSchema: z.object({ + expression: z + .string() + .describe('The mathematical expression to calculate'), + }), + contextSchema: z.object({ + calculatorApiKey: z + .string() + .describe('The API key for the calculator API'), + }), + execute: async ( + { expression }, + { experimental_context: { calculatorApiKey } }, + ) => { + console.log('calculator tool api key:', calculatorApiKey); + return { + expression, + result: eval(expression), + }; + }, + }), + }, + experimental_context: { + weatherApiKey: 'weather-123', + calculatorApiKey: 'calculator-456', + somethingElse: 'other-context', + }, + prepareStep: async ({ experimental_context: context }) => { + console.log('prepareStep context:', context); + return {}; }, - experimental_context: { weatherApiKey: '123' }, prompt: 'What is the weather in San Francisco?', }); diff --git a/examples/ai-functions/src/lib/print-full-stream.ts b/examples/ai-functions/src/lib/print-full-stream.ts index dc63f0b30c55..a8994ba4e493 100644 --- a/examples/ai-functions/src/lib/print-full-stream.ts +++ b/examples/ai-functions/src/lib/print-full-stream.ts @@ -3,7 +3,7 @@ import { StreamTextResult } from 'ai'; export async function printFullStream({ result, }: { - result: StreamTextResult; + result: StreamTextResult; }) { for await (const chunk of result.fullStream) { switch (chunk.type) { diff --git a/examples/ai-functions/src/lib/save-raw-chunks.ts b/examples/ai-functions/src/lib/save-raw-chunks.ts index 2bb0a0f8662f..57661394945f 100644 --- a/examples/ai-functions/src/lib/save-raw-chunks.ts +++ b/examples/ai-functions/src/lib/save-raw-chunks.ts @@ -5,7 +5,7 @@ export async function saveRawChunks({ result, filename, }: { - result: StreamTextResult; + result: StreamTextResult; filename: string; }) { const rawChunks: unknown[] = []; diff --git a/examples/ai-functions/src/stream-text/openai/tool-call-with-context.ts b/examples/ai-functions/src/stream-text/openai/tool-call-with-context.ts new file mode 100644 index 000000000000..41fe8dca9508 --- /dev/null +++ b/examples/ai-functions/src/stream-text/openai/tool-call-with-context.ts @@ -0,0 +1,68 @@ +import { openai } from '@ai-sdk/openai'; +import { streamText, tool } from 'ai'; +import { z } from 'zod'; +import { run } from '../../lib/run'; +import { printFullStream } from '../../lib/print-full-stream'; + +run(async () => { + const result = streamText({ + model: openai('gpt-5-mini'), + tools: { + weather: tool({ + description: 'Get the weather in a location', + inputSchema: z.object({ + location: z.string().describe('The location to get the weather for'), + }), + contextSchema: z.object({ + weatherApiKey: z.string().describe('The API key for the weather API'), + }), + execute: async ( + { location }, + { experimental_context: { weatherApiKey } }, + ) => { + console.log('weather tool api key:', weatherApiKey); + + return { + location, + temperature: 72 + Math.floor(Math.random() * 21) - 10, + }; + }, + }), + calculator: tool({ + description: 'Calculate mathematical expressions', + inputSchema: z.object({ + expression: z + .string() + .describe('The mathematical expression to calculate'), + }), + contextSchema: z.object({ + calculatorApiKey: z + .string() + .describe('The API key for the calculator API'), + }), + execute: async ( + { expression }, + { experimental_context: { calculatorApiKey } }, + ) => { + console.log('calculator tool api key:', calculatorApiKey); + return { + expression, + result: eval(expression), + }; + }, + }), + }, + experimental_context: { + weatherApiKey: 'weather-123', + calculatorApiKey: 'calculator-456', + somethingElse: 'other-context', + }, + prepareStep: async ({ experimental_context: context }) => { + console.log('prepareStep context:', context); + return {}; + }, + prompt: 'What is the weather in San Francisco?', + }); + + await printFullStream({ result }); +}); diff --git a/examples/ai-functions/src/test/typed-context-2.ts b/examples/ai-functions/src/test/typed-context-2.ts new file mode 100644 index 000000000000..e34507f0d7ea --- /dev/null +++ b/examples/ai-functions/src/test/typed-context-2.ts @@ -0,0 +1,92 @@ +import { FlexibleSchema } from 'ai'; +import { run } from '../lib/run'; +import { z } from 'zod'; + +type Context = Record; + +interface Tool { + inputSchema: FlexibleSchema; + contextSchema: FlexibleSchema; + execute: (input: NoInfer, context: NoInfer) => unknown; +} + +function tool(options: { + inputSchema: FlexibleSchema; + contextSchema: FlexibleSchema; + execute: (input: NoInfer, context: NoInfer) => unknown; +}) { + return options; +} + +type InferToolInput = + TOOL extends Tool ? INPUT : never; +type InferToolContext = + TOOL extends Tool ? CONTEXT : never; + +export type ToolSet = Record>; + +type UnionToIntersection = ( + U extends unknown ? (arg: U) => void : never +) extends (arg: infer I) => void + ? I + : never; + +// should be a union of all the context types of the tools +type InferToolSetContext = UnionToIntersection< + { + [K in keyof TOOLS]: InferToolContext; + }[keyof TOOLS] +>; + +type ExpandedContext = InferToolSetContext & + Context; + +function executeTool< + TOOLS extends ToolSet, + CONTEXT extends ExpandedContext, +>({ + tools, + toolName, + input, + context, +}: { + tools: TOOLS; + toolName: keyof TOOLS; + input: InferToolInput; + context: CONTEXT; + prepareStep: (context: CONTEXT) => void; +}) { + const tool = tools[toolName]; + return tool.execute(input, context); +} + +run(async () => { + const tool1 = tool({ + inputSchema: z.object({ input1: z.string() }), + contextSchema: z.object({ context1: z.number() }), + execute: async ({ input1 }, { context1 }) => { + console.log(input1, context1); + }, + }); + + const tool2 = tool({ + inputSchema: z.object({ input2: z.number() }), + contextSchema: z.object({ context2: z.string() }), + execute: async ({ input2 }, { context2 }) => { + console.log(input2, context2); + }, + }); + + executeTool({ + tools: { + tool1, + tool2, + }, + toolName: 'tool1', + input: { input1: 'Hello' }, + context: { context1: 1, context2: 'world', somethingElse: 'context' }, + prepareStep: context => { + console.log(context); + }, + }); +}); diff --git a/examples/ai-functions/src/test/typed-context.ts b/examples/ai-functions/src/test/typed-context.ts new file mode 100644 index 000000000000..4446294f8c8d --- /dev/null +++ b/examples/ai-functions/src/test/typed-context.ts @@ -0,0 +1,42 @@ +import { run } from '../lib/run'; + +interface Tool { + execute: (context: CONTEXT) => Promise; +} + +export type ToolSet = Record>; + +function executeTool>({ + tools, + toolName, + context, +}: { + tools: TOOLS; + toolName: keyof TOOLS; + context: CONTEXT; +}) { + return tools[toolName].execute(context); +} + +run(async () => { + const tool1: Tool<{ name: string }> = { + execute: async context => { + console.log(context); + }, + }; + + const tool2: Tool<{ age: number }> = { + execute: async context => { + console.log(context); + }, + }; + + executeTool({ + tools: { + tool1, + tool2, + }, + toolName: 'tool1', + context: { name: 'John', age: 30 }, + }); +}); diff --git a/examples/mcp/src/image-content/client.ts b/examples/mcp/src/image-content/client.ts index c7955957c56e..e36da502074e 100644 --- a/examples/mcp/src/image-content/client.ts +++ b/examples/mcp/src/image-content/client.ts @@ -20,7 +20,10 @@ async function main() { const tool = tools['get-image']; console.log('Calling get-image tool...\n'); - const result = await tool.execute!({}, { messages: [], toolCallId: '1' }); + const result = await tool.execute!( + {}, + { messages: [], toolCallId: '1', experimental_context: {} }, + ); console.log('Raw execute() result (MCP format):'); console.log(JSON.stringify(result, null, 2)); diff --git a/examples/mcp/src/output-schema/client.ts b/examples/mcp/src/output-schema/client.ts index f7ec49b6a7da..9832ae124a6a 100644 --- a/examples/mcp/src/output-schema/client.ts +++ b/examples/mcp/src/output-schema/client.ts @@ -59,7 +59,7 @@ async function main() { const weatherTool = tools['get-weather']; const weatherResult = await weatherTool.execute( { location: 'New York' }, - { messages: [], toolCallId: 'weather-1' }, + { messages: [], toolCallId: 'weather-1', experimental_context: {} }, ); const weather = weatherResult as { @@ -78,7 +78,7 @@ async function main() { const usersTool = tools['list-users']; const usersResult = await usersTool.execute( {}, - { messages: [], toolCallId: 'users-1' }, + { messages: [], toolCallId: 'users-1', experimental_context: {} }, ); const users = usersResult as { @@ -96,7 +96,7 @@ async function main() { const echoTool = tools['echo']; const echoResult = await echoTool.execute( { message: 'Hello, MCP!' }, - { messages: [], toolCallId: 'echo-1' }, + { messages: [], toolCallId: 'echo-1', experimental_context: {} }, ); console.log('Raw result:', JSON.stringify(echoResult, null, 2)); diff --git a/packages/ai/src/agent/agent.ts b/packages/ai/src/agent/agent.ts index 6ae0916ddd02..714496e2a259 100644 --- a/packages/ai/src/agent/agent.ts +++ b/packages/ai/src/agent/agent.ts @@ -3,7 +3,8 @@ import { GenerateTextResult } from '../generate-text/generate-text-result'; import { Output } from '../generate-text/output'; import { StreamTextTransform } from '../generate-text/stream-text'; import { StreamTextResult } from '../generate-text/stream-text-result'; -import { ToolSet } from '../generate-text/tool-set'; +import type { GenerationContext } from '../generate-text/generation-context'; +import type { ToolSet } from '../generate-text/tool-set'; import { TimeoutConfiguration } from '../prompt/call-settings'; import type { ToolLoopAgentOnFinishCallback, @@ -17,9 +18,11 @@ import type { /** * Parameters for calling an agent. */ -export type AgentCallParameters = ([ +export type AgentCallParameters< CALL_OPTIONS, -] extends [never] + TOOLS extends ToolSet = {}, + CONTEXT extends GenerationContext = GenerationContext, +> = ([CALL_OPTIONS] extends [never] ? { options?: never } : { options: CALL_OPTIONS }) & ( @@ -67,12 +70,12 @@ export type AgentCallParameters = ([ /** * Callback that is called when the agent operation begins, before any LLM calls. */ - experimental_onStart?: ToolLoopAgentOnStartCallback; + experimental_onStart?: ToolLoopAgentOnStartCallback; /** * Callback that is called when a step (LLM call) begins, before the provider is called. */ - experimental_onStepStart?: ToolLoopAgentOnStepStartCallback; + experimental_onStepStart?: ToolLoopAgentOnStepStartCallback; /** * Callback that is called before each tool execution begins. @@ -122,6 +125,7 @@ export type AgentStreamParameters< export interface Agent< CALL_OPTIONS = never, TOOLS extends ToolSet = {}, + CONTEXT extends GenerationContext = GenerationContext, OUTPUT extends Output = never, > { /** @@ -145,12 +149,12 @@ export interface Agent< */ generate( options: AgentCallParameters, - ): PromiseLike>; + ): PromiseLike>; /** * Streams an output from the agent (streaming). */ stream( options: AgentStreamParameters, - ): PromiseLike>; + ): PromiseLike>; } diff --git a/packages/ai/src/agent/create-agent-ui-stream-response.ts b/packages/ai/src/agent/create-agent-ui-stream-response.ts index b40a798e550a..957bb104111f 100644 --- a/packages/ai/src/agent/create-agent-ui-stream-response.ts +++ b/packages/ai/src/agent/create-agent-ui-stream-response.ts @@ -1,6 +1,7 @@ import { StreamTextTransform, UIMessageStreamOptions } from '../generate-text'; import { Output } from '../generate-text/output'; -import { ToolSet } from '../generate-text/tool-set'; +import type { GenerationContext } from '../generate-text/generation-context'; +import type { ToolSet } from '../generate-text/tool-set'; import { TimeoutConfiguration } from '../prompt/call-settings'; import { createUIMessageStreamResponse } from '../ui-message-stream'; import { UIMessageStreamResponseInit } from '../ui-message-stream/ui-message-stream-response-init'; @@ -29,7 +30,8 @@ import type { ToolLoopAgentOnStepFinishCallback } from './tool-loop-agent-settin export async function createAgentUIStreamResponse< CALL_OPTIONS = never, TOOLS extends ToolSet = {}, - OUTPUT extends Output = never, + CONTEXT extends GenerationContext = GenerationContext, + OUTPUT extends Output = never, MESSAGE_METADATA = unknown, >({ headers, @@ -38,7 +40,7 @@ export async function createAgentUIStreamResponse< consumeSseStream, ...options }: { - agent: Agent; + agent: Agent; uiMessages: unknown[]; abortSignal?: AbortSignal; timeout?: TimeoutConfiguration; diff --git a/packages/ai/src/agent/create-agent-ui-stream.ts b/packages/ai/src/agent/create-agent-ui-stream.ts index b0bead5412cf..7ceb950a99ec 100644 --- a/packages/ai/src/agent/create-agent-ui-stream.ts +++ b/packages/ai/src/agent/create-agent-ui-stream.ts @@ -1,6 +1,7 @@ import { StreamTextTransform, UIMessageStreamOptions } from '../generate-text'; import { Output } from '../generate-text/output'; -import { ToolSet } from '../generate-text/tool-set'; +import type { GenerationContext } from '../generate-text/generation-context'; +import type { ToolSet } from '../generate-text/tool-set'; import { TimeoutConfiguration } from '../prompt/call-settings'; import { InferUIMessageChunk } from '../ui-message-stream'; import { convertToModelMessages } from '../ui/convert-to-model-messages'; @@ -26,6 +27,7 @@ import type { ToolLoopAgentOnStepFinishCallback } from './tool-loop-agent-settin export async function createAgentUIStream< CALL_OPTIONS = never, TOOLS extends ToolSet = {}, + CONTEXT extends GenerationContext = GenerationContext, OUTPUT extends Output = never, MESSAGE_METADATA = unknown, >({ @@ -38,7 +40,7 @@ export async function createAgentUIStream< onStepFinish, ...uiMessageStreamOptions }: { - agent: Agent; + agent: Agent; uiMessages: unknown[]; abortSignal?: AbortSignal; timeout?: TimeoutConfiguration; diff --git a/packages/ai/src/agent/pipe-agent-ui-stream-to-response.ts b/packages/ai/src/agent/pipe-agent-ui-stream-to-response.ts index 6fcc21a0e48e..942c0871fbdc 100644 --- a/packages/ai/src/agent/pipe-agent-ui-stream-to-response.ts +++ b/packages/ai/src/agent/pipe-agent-ui-stream-to-response.ts @@ -1,7 +1,8 @@ import { ServerResponse } from 'node:http'; import { StreamTextTransform, UIMessageStreamOptions } from '../generate-text'; import { Output } from '../generate-text/output'; -import { ToolSet } from '../generate-text/tool-set'; +import type { GenerationContext } from '../generate-text/generation-context'; +import type { ToolSet } from '../generate-text/tool-set'; import { TimeoutConfiguration } from '../prompt/call-settings'; import { pipeUIMessageStreamToResponse } from '../ui-message-stream'; import { UIMessageStreamResponseInit } from '../ui-message-stream/ui-message-stream-response-init'; @@ -29,6 +30,7 @@ import type { ToolLoopAgentOnStepFinishCallback } from './tool-loop-agent-settin export async function pipeAgentUIStreamToResponse< CALL_OPTIONS = never, TOOLS extends ToolSet = {}, + CONTEXT extends GenerationContext = GenerationContext, OUTPUT extends Output = never, MESSAGE_METADATA = unknown, >({ @@ -40,7 +42,7 @@ export async function pipeAgentUIStreamToResponse< ...options }: { response: ServerResponse; - agent: Agent; + agent: Agent; uiMessages: unknown[]; abortSignal?: AbortSignal; timeout?: TimeoutConfiguration; diff --git a/packages/ai/src/agent/tool-loop-agent-settings.ts b/packages/ai/src/agent/tool-loop-agent-settings.ts index 1b456dae0af9..5f6cf1cab0db 100644 --- a/packages/ai/src/agent/tool-loop-agent-settings.ts +++ b/packages/ai/src/agent/tool-loop-agent-settings.ts @@ -16,7 +16,8 @@ import { Output } from '../generate-text/output'; import { PrepareStepFunction } from '../generate-text/prepare-step'; import { StopCondition } from '../generate-text/stop-condition'; import { ToolCallRepairFunction } from '../generate-text/tool-call-repair-function'; -import { ToolSet } from '../generate-text/tool-set'; +import type { GenerationContext } from '../generate-text/generation-context'; +import type { ToolSet } from '../generate-text/tool-set'; import { CallSettings, TimeoutConfiguration } from '../prompt/call-settings'; import { Prompt } from '../prompt/prompt'; import { TelemetrySettings } from '../telemetry/telemetry-settings'; @@ -26,13 +27,17 @@ import { AgentCallParameters } from './agent'; export type ToolLoopAgentOnStartCallback< TOOLS extends ToolSet = ToolSet, + CONTEXT extends GenerationContext = GenerationContext, OUTPUT extends Output = Output, -> = (event: OnStartEvent) => PromiseLike | void; +> = (event: OnStartEvent) => PromiseLike | void; export type ToolLoopAgentOnStepStartCallback< TOOLS extends ToolSet = ToolSet, + CONTEXT extends GenerationContext = GenerationContext, OUTPUT extends Output = Output, -> = (event: OnStepStartEvent) => PromiseLike | void; +> = ( + event: OnStepStartEvent, +) => PromiseLike | void; export type ToolLoopAgentOnToolCallStartCallback< TOOLS extends ToolSet = ToolSet, @@ -42,13 +47,15 @@ export type ToolLoopAgentOnToolCallFinishCallback< TOOLS extends ToolSet = ToolSet, > = (event: OnToolCallFinishEvent) => PromiseLike | void; -export type ToolLoopAgentOnStepFinishCallback = ( - stepResult: OnStepFinishEvent, -) => Promise | void; +export type ToolLoopAgentOnStepFinishCallback< + TOOLS extends ToolSet = ToolSet, + CONTEXT extends GenerationContext = GenerationContext, +> = (stepResult: OnStepFinishEvent) => Promise | void; -export type ToolLoopAgentOnFinishCallback = ( - event: OnFinishEvent, -) => PromiseLike | void; +export type ToolLoopAgentOnFinishCallback< + TOOLS extends ToolSet = ToolSet, + CONTEXT extends GenerationContext = GenerationContext, +> = (event: OnFinishEvent) => PromiseLike | void; /** * Configuration options for an agent. @@ -56,6 +63,7 @@ export type ToolLoopAgentOnFinishCallback = ( export type ToolLoopAgentSettings< CALL_OPTIONS = never, TOOLS extends ToolSet = {}, + CONTEXT extends GenerationContext = GenerationContext, OUTPUT extends Output = never, > = Omit & { /** @@ -100,8 +108,8 @@ export type ToolLoopAgentSettings< * @default isStepCount(20) */ stopWhen?: - | StopCondition> - | Array>>; + | StopCondition, CONTEXT> + | Array, CONTEXT>>; /** * Optional telemetry configuration (experimental). @@ -122,7 +130,7 @@ export type ToolLoopAgentSettings< /** * Optional function that you can use to provide different settings for a step. */ - prepareStep?: PrepareStepFunction>; + prepareStep?: PrepareStepFunction, CONTEXT>; /** * A function that attempts to repair a tool call that failed to parse. @@ -132,14 +140,19 @@ export type ToolLoopAgentSettings< /** * Callback that is called when the agent operation begins, before any LLM calls. */ - experimental_onStart?: ToolLoopAgentOnStartCallback, OUTPUT>; + experimental_onStart?: ToolLoopAgentOnStartCallback< + NoInfer, + CONTEXT, + NoInfer + >; /** * Callback that is called when a step (LLM call) begins, before the provider is called. */ experimental_onStepStart?: ToolLoopAgentOnStepStartCallback< NoInfer, - OUTPUT + NoInfer, + NoInfer >; /** @@ -177,10 +190,8 @@ export type ToolLoopAgentSettings< * Context that is passed into tool calls. * * Experimental (can break in patch releases). - * - * @default undefined */ - experimental_context?: unknown; + experimental_context?: CONTEXT; /** * Custom download function to use for URLs. @@ -205,7 +216,7 @@ export type ToolLoopAgentSettings< 'onStepFinish' > & Pick< - ToolLoopAgentSettings, + ToolLoopAgentSettings>, | 'model' | 'tools' | 'maxOutputTokens' @@ -227,7 +238,7 @@ export type ToolLoopAgentSettings< >, ) => MaybePromiseLike< Pick< - ToolLoopAgentSettings, + ToolLoopAgentSettings>, | 'model' | 'tools' | 'maxOutputTokens' diff --git a/packages/ai/src/agent/tool-loop-agent.test-d.ts b/packages/ai/src/agent/tool-loop-agent.test-d.ts index 86be0b7c0109..c2169b1e02aa 100644 --- a/packages/ai/src/agent/tool-loop-agent.test-d.ts +++ b/packages/ai/src/agent/tool-loop-agent.test-d.ts @@ -11,11 +11,13 @@ import type { ToolLoopAgentOnFinishCallback } from './tool-loop-agent-settings'; describe('ToolLoopAgent', () => { describe('onFinish callback type compatibility', () => { it('should allow StreamTextOnFinishCallback where ToolLoopAgentOnFinishCallback is expected', () => { - const streamTextCallback: StreamTextOnFinishCallback<{}> = - async event => { - const context: unknown = event.experimental_context; - context; - }; + const streamTextCallback: StreamTextOnFinishCallback< + {}, + {} + > = async event => { + const context: unknown = event.experimental_context; + context; + }; expectTypeOf(streamTextCallback).toMatchTypeOf< ToolLoopAgentOnFinishCallback<{}> @@ -29,7 +31,7 @@ describe('ToolLoopAgent', () => { }; expectTypeOf(agentCallback).toMatchTypeOf< - StreamTextOnFinishCallback<{}> + StreamTextOnFinishCallback<{}, {}> >(); }); }); diff --git a/packages/ai/src/agent/tool-loop-agent.ts b/packages/ai/src/agent/tool-loop-agent.ts index 766c1c2b57e7..1c6b33569f1e 100644 --- a/packages/ai/src/agent/tool-loop-agent.ts +++ b/packages/ai/src/agent/tool-loop-agent.ts @@ -4,10 +4,15 @@ import { Output } from '../generate-text/output'; import { isStepCount } from '../generate-text/stop-condition'; import { streamText } from '../generate-text/stream-text'; import { StreamTextResult } from '../generate-text/stream-text-result'; -import { ToolSet } from '../generate-text/tool-set'; +import type { GenerationContext } from '../generate-text/generation-context'; +import type { ToolSet } from '../generate-text/tool-set'; import { Prompt } from '../prompt'; import { Agent, AgentCallParameters, AgentStreamParameters } from './agent'; -import { ToolLoopAgentSettings } from './tool-loop-agent-settings'; +import { + ToolLoopAgentOnStartCallback, + ToolLoopAgentOnStepStartCallback, + ToolLoopAgentSettings, +} from './tool-loop-agent-settings'; /** * A tool loop agent is an agent that runs tools in a loop. In each step, @@ -23,13 +28,21 @@ import { ToolLoopAgentSettings } from './tool-loop-agent-settings'; export class ToolLoopAgent< CALL_OPTIONS = never, TOOLS extends ToolSet = {}, + CONTEXT extends GenerationContext = GenerationContext, OUTPUT extends Output = never, -> implements Agent { +> implements Agent { readonly version = 'agent-v1'; - private readonly settings: ToolLoopAgentSettings; + private readonly settings: ToolLoopAgentSettings< + CALL_OPTIONS, + TOOLS, + CONTEXT, + OUTPUT + >; - constructor(settings: ToolLoopAgentSettings) { + constructor( + settings: ToolLoopAgentSettings, + ) { this.settings = settings; } @@ -53,7 +66,7 @@ export class ToolLoopAgent< options?: CALL_OPTIONS; }): Promise< Omit< - ToolLoopAgentSettings, + ToolLoopAgentSettings, | 'prepareCall' | 'instructions' | 'experimental_onStart' @@ -85,7 +98,12 @@ export class ToolLoopAgent< (await this.settings.prepareCall?.( baseCallArgs as Parameters< NonNullable< - ToolLoopAgentSettings['prepareCall'] + ToolLoopAgentSettings< + CALL_OPTIONS, + TOOLS, + CONTEXT, + OUTPUT + >['prepareCall'] > >[0], )) ?? baseCallArgs; @@ -127,7 +145,7 @@ export class ToolLoopAgent< onFinish, ...options }: AgentCallParameters): Promise< - GenerateTextResult + GenerateTextResult > { return generateText({ ...(await this.prepareCall(options)), @@ -135,11 +153,15 @@ export class ToolLoopAgent< timeout, experimental_onStart: this.mergeCallbacks( this.settings.experimental_onStart, - experimental_onStart, + experimental_onStart as + | ToolLoopAgentOnStartCallback + | undefined, ), experimental_onStepStart: this.mergeCallbacks( this.settings.experimental_onStepStart, - experimental_onStepStart, + experimental_onStepStart as + | ToolLoopAgentOnStepStartCallback + | undefined, ), experimental_onToolCallStart: this.mergeCallbacks( this.settings.experimental_onToolCallStart, @@ -172,7 +194,7 @@ export class ToolLoopAgent< onFinish, ...options }: AgentStreamParameters): Promise< - StreamTextResult + StreamTextResult > { return streamText({ ...(await this.prepareCall(options)), @@ -181,11 +203,15 @@ export class ToolLoopAgent< experimental_transform, experimental_onStart: this.mergeCallbacks( this.settings.experimental_onStart, - experimental_onStart, + experimental_onStart as + | ToolLoopAgentOnStartCallback + | undefined, ), experimental_onStepStart: this.mergeCallbacks( this.settings.experimental_onStepStart, - experimental_onStepStart, + experimental_onStepStart as + | ToolLoopAgentOnStepStartCallback + | undefined, ), experimental_onToolCallStart: this.mergeCallbacks( this.settings.experimental_onToolCallStart, diff --git a/packages/ai/src/generate-text/__snapshots__/generate-text.test.ts.snap b/packages/ai/src/generate-text/__snapshots__/generate-text.test.ts.snap index a67498cf5668..30103ffbd5d0 100644 --- a/packages/ai/src/generate-text/__snapshots__/generate-text.test.ts.snap +++ b/packages/ai/src/generate-text/__snapshots__/generate-text.test.ts.snap @@ -5,7 +5,7 @@ exports[`generateText > options.experimental_onStart > should send correct infor "abortSignal": undefined, "activeTools": undefined, "callId": "test-telemetry-call-id", - "experimental_context": undefined, + "experimental_context": {}, "frequencyPenalty": undefined, "functionId": "test-function", "headers": { @@ -46,7 +46,7 @@ exports[`generateText > options.experimental_onToolCallStart > should be called { "abortSignal": undefined, "callId": "test-telemetry-call-id", - "experimental_context": undefined, + "experimental_context": {}, "functionId": undefined, "messages": [ { @@ -83,7 +83,7 @@ exports[`generateText > options.stopWhen > 2 steps: initial, tool-result > callb ], "dynamicToolCalls": [], "dynamicToolResults": [], - "experimental_context": undefined, + "experimental_context": {}, "files": [], "finishReason": "stop", "functionId": undefined, @@ -177,7 +177,7 @@ exports[`generateText > options.stopWhen > 2 steps: initial, tool-result > callb "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "tool-calls", "functionId": undefined, "metadata": undefined, @@ -254,7 +254,7 @@ exports[`generateText > options.stopWhen > 2 steps: initial, tool-result > callb "type": "text", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -403,7 +403,7 @@ exports[`generateText > options.stopWhen > 2 steps: initial, tool-result > callb "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "tool-calls", "functionId": undefined, "metadata": undefined, @@ -480,7 +480,7 @@ exports[`generateText > options.stopWhen > 2 steps: initial, tool-result > callb "type": "text", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -635,7 +635,7 @@ exports[`generateText > options.stopWhen > 2 steps: initial, tool-result > resul "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "tool-calls", "functionId": undefined, "metadata": undefined, @@ -712,7 +712,7 @@ exports[`generateText > options.stopWhen > 2 steps: initial, tool-result > resul "type": "text", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -1673,7 +1673,7 @@ exports[`generateText > result.steps > should add the reasoning from the model r "type": "text", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -1774,7 +1774,7 @@ exports[`generateText > result.steps > should contain files 1`] = ` "type": "file", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -1873,7 +1873,7 @@ exports[`generateText > result.steps > should contain sources 1`] = ` "url": "https://example.com/2", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, diff --git a/packages/ai/src/generate-text/__snapshots__/stream-text.test.ts.snap b/packages/ai/src/generate-text/__snapshots__/stream-text.test.ts.snap index c32c3bf645a1..b5603beb4166 100644 --- a/packages/ai/src/generate-text/__snapshots__/stream-text.test.ts.snap +++ b/packages/ai/src/generate-text/__snapshots__/stream-text.test.ts.snap @@ -7,7 +7,7 @@ exports[`streamText > options.experimental_onStart > should send correct informa "abortSignal": undefined, "activeTools": undefined, "callId": "test-telemetry-call-id", - "experimental_context": undefined, + "experimental_context": {}, "frequencyPenalty": undefined, "functionId": "test-function", "headers": undefined, @@ -46,7 +46,7 @@ exports[`streamText > options.experimental_onToolCallStart > should be called wi { "abortSignal": undefined, "callId": "test-telemetry-call-id", - "experimental_context": undefined, + "experimental_context": {}, "functionId": undefined, "messages": [ { diff --git a/packages/ai/src/generate-text/content-part.ts b/packages/ai/src/generate-text/content-part.ts index 5e862e844abb..58ec81f1d440 100644 --- a/packages/ai/src/generate-text/content-part.ts +++ b/packages/ai/src/generate-text/content-part.ts @@ -1,14 +1,14 @@ import { ProviderMetadata } from '../types'; import { Source } from '../types/language-model'; import { GeneratedFile } from './generated-file'; +import { ReasoningFileOutput, ReasoningOutput } from './reasoning-output'; import { ToolApprovalRequestOutput } from './tool-approval-request-output'; -import { ReasoningOutput, ReasoningFileOutput } from './reasoning-output'; import { TypedToolCall } from './tool-call'; import { TypedToolError } from './tool-error'; import { TypedToolResult } from './tool-result'; import { ToolSet } from './tool-set'; -export type ContentPart = +export type ContentPart = | { type: 'text'; text: string; providerMetadata?: ProviderMetadata } | { type: 'custom'; diff --git a/packages/ai/src/generate-text/core-events.ts b/packages/ai/src/generate-text/core-events.ts index 4f2c93aae7ab..cd30610e6beb 100644 --- a/packages/ai/src/generate-text/core-events.ts +++ b/packages/ai/src/generate-text/core-events.ts @@ -10,9 +10,20 @@ import type { LanguageModelUsage } from '../types/usage'; import type { Output } from './output'; import type { StepResult } from './step-result'; import type { StopCondition } from './stop-condition'; +import { TextStreamPart } from './stream-text-result'; import type { TypedToolCall } from './tool-call'; +import type { GenerationContext } from './generation-context'; import type { ToolSet } from './tool-set'; -import { TextStreamPart } from './stream-text-result'; + +/** + * Common model information used across callback events. + */ +export interface CallbackModelInfo { + /** The provider identifier (e.g., 'openai', 'anthropic'). */ + readonly provider: string; + /** The specific model identifier (e.g., 'gpt-4o'). */ + readonly modelId: string; +} /** * Event passed to the `onStart` callback. @@ -21,6 +32,7 @@ import { TextStreamPart } from './stream-text-result'; */ export interface OnStartEvent< TOOLS extends ToolSet = ToolSet, + CONTEXT extends GenerationContext = GenerationContext, OUTPUT extends Output = Output, INCLUDE = { requestBody?: boolean; responseBody?: boolean }, > { @@ -94,8 +106,8 @@ export interface OnStartEvent< * When the condition is an array, any of the conditions can be met to stop. */ readonly stopWhen: - | StopCondition - | Array> + | StopCondition, CONTEXT> + | Array, CONTEXT>> | undefined; /** The output specification for structured outputs, if configured. */ @@ -139,6 +151,7 @@ export interface OnStartEvent< */ export interface OnStepStartEvent< TOOLS extends ToolSet = ToolSet, + CONTEXT extends GenerationContext = GenerationContext, OUTPUT extends Output = Output, INCLUDE = { requestBody?: boolean; responseBody?: boolean }, > { @@ -180,7 +193,7 @@ export interface OnStepStartEvent< readonly activeTools: Array | undefined; /** Array of results from previous steps (empty for first step). */ - readonly steps: ReadonlyArray>; + readonly steps: ReadonlyArray>; /** Additional provider-specific options for this step. */ readonly providerOptions: ProviderOptions | undefined; @@ -199,8 +212,8 @@ export interface OnStepStartEvent< * When the condition is an array, any of the conditions can be met to stop. */ readonly stopWhen: - | StopCondition - | Array> + | StopCondition + | Array> | undefined; /** The output specification for structured outputs, if configured. */ @@ -356,8 +369,10 @@ export interface OnChunkEvent { * Called when a step (LLM call) completes. * Includes the StepResult for that step along with the call identifier. */ -export type OnStepFinishEvent = - StepResult; +export type OnStepFinishEvent< + TOOLS extends ToolSet = ToolSet, + CONTEXT extends GenerationContext = GenerationContext, +> = StepResult; /** * Event passed to the `onFinish` callback. @@ -365,26 +380,28 @@ export type OnStepFinishEvent = * Called when the entire generation completes (all steps finished). * Includes the final step's result along with aggregated data from all steps. */ -export type OnFinishEvent = - StepResult & { - /** Array containing results from all steps in the generation. */ - readonly steps: StepResult[]; - - /** Aggregated token usage across all steps. */ - readonly totalUsage: LanguageModelUsage; - - /** - * The final state of the user-defined context object. - * - * Experimental (can break in patch releases). - * - * @default undefined - */ - experimental_context: unknown; - - /** Identifier from telemetry settings for grouping related operations. */ - readonly functionId: string | undefined; - - /** Additional metadata from telemetry settings. */ - readonly metadata: Record | undefined; - }; +export type OnFinishEvent< + TOOLS extends ToolSet = ToolSet, + CONTEXT extends GenerationContext = GenerationContext, +> = StepResult & { + /** Array containing results from all steps in the generation. */ + readonly steps: StepResult[]; + + /** Aggregated token usage across all steps. */ + readonly totalUsage: LanguageModelUsage; + + /** + * The final state of the user-defined context object. + * + * Experimental (can break in patch releases). + * + * @default undefined + */ + experimental_context: CONTEXT; + + /** Identifier from telemetry settings for grouping related operations. */ + readonly functionId: string | undefined; + + /** Additional metadata from telemetry settings. */ + readonly metadata: Record | undefined; +}; diff --git a/packages/ai/src/generate-text/create-execute-tools-transformation.test.ts b/packages/ai/src/generate-text/create-execute-tools-transformation.test.ts index b68a63d76ae6..69d56e8a194d 100644 --- a/packages/ai/src/generate-text/create-execute-tools-transformation.test.ts +++ b/packages/ai/src/generate-text/create-execute-tools-transformation.test.ts @@ -58,7 +58,7 @@ describe('createExecuteToolsTransformation', () => { messages: [], timeout: undefined, abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, }), ); @@ -136,7 +136,7 @@ describe('createExecuteToolsTransformation', () => { messages: [], abortSignal: undefined, timeout: undefined, - experimental_context: undefined, + experimental_context: {}, }); expect( @@ -231,7 +231,7 @@ describe('createExecuteToolsTransformation', () => { messages: [], abortSignal: undefined, timeout: undefined, - experimental_context: undefined, + experimental_context: {}, }), ); @@ -274,7 +274,7 @@ describe('createExecuteToolsTransformation', () => { messages: [], timeout: undefined, abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, onToolCallStart: async () => { callOrder.push('onToolCallStart'); }, @@ -324,7 +324,7 @@ describe('createExecuteToolsTransformation', () => { messages: [], timeout: undefined, abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, stepNumber: 2, provider: 'test-provider', modelId: 'test-model', @@ -397,7 +397,7 @@ describe('createExecuteToolsTransformation', () => { messages: [], timeout: undefined, abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, onToolCallFinish: async event => { finishEvents.push(event); }, @@ -454,7 +454,7 @@ describe('createExecuteToolsTransformation', () => { messages: [], timeout: undefined, abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, onToolCallFinish: async event => { finishEvents.push(event); }, @@ -504,7 +504,7 @@ describe('createExecuteToolsTransformation', () => { messages: [], timeout: undefined, abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, onToolCallStart: async event => { startEvents.push(event); }, @@ -557,7 +557,7 @@ describe('createExecuteToolsTransformation', () => { messages: [], timeout: undefined, abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, onToolCallStart: async event => { startEvents.push(event.toolCall.toolCallId); }, @@ -613,7 +613,7 @@ describe('createExecuteToolsTransformation', () => { messages: [], timeout: undefined, abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, onToolCallStart: async event => { startEvents.push(event); }, @@ -667,7 +667,7 @@ describe('createExecuteToolsTransformation', () => { messages: [], abortSignal: undefined, timeout: undefined, - experimental_context: undefined, + experimental_context: {}, }), ); @@ -753,7 +753,7 @@ describe('createExecuteToolsTransformation', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, }), ); diff --git a/packages/ai/src/generate-text/create-execute-tools-transformation.ts b/packages/ai/src/generate-text/create-execute-tools-transformation.ts index f2271ecfad78..a1e3e967b43e 100644 --- a/packages/ai/src/generate-text/create-execute-tools-transformation.ts +++ b/packages/ai/src/generate-text/create-execute-tools-transformation.ts @@ -9,10 +9,14 @@ import { StreamTextOnToolCallStartCallback, } from './stream-text'; import { TypedToolCall } from './tool-call'; -import { ToolSet } from './tool-set'; +import type { GenerationContext } from './generation-context'; +import type { ToolSet } from './tool-set'; import { ModelCallStreamPart } from './stream-model-call'; -export function createExecuteToolsTransformation({ +export function createExecuteToolsTransformation< + TOOLS extends ToolSet, + CONTEXT extends GenerationContext, +>({ tools, telemetry, callId, @@ -34,7 +38,7 @@ export function createExecuteToolsTransformation({ messages: ModelMessage[]; abortSignal: AbortSignal | undefined; timeout?: TimeoutConfiguration; - experimental_context: unknown; + experimental_context: CONTEXT; generateId: IdGenerator; stepNumber?: number; provider?: string; diff --git a/packages/ai/src/generate-text/execute-tool-call.test.ts b/packages/ai/src/generate-text/execute-tool-call.test.ts index f19ce9febd65..87c29f8dcd74 100644 --- a/packages/ai/src/generate-text/execute-tool-call.test.ts +++ b/packages/ai/src/generate-text/execute-tool-call.test.ts @@ -48,7 +48,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, }); expect(result).toBeUndefined(); @@ -69,7 +69,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, }); expect(result).toEqual({ @@ -97,7 +97,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, }); expect(result).toMatchObject({ @@ -125,7 +125,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, }); expect(result).toEqual({ @@ -155,7 +155,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, }); expect(result).toMatchObject({ @@ -229,7 +229,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, onToolCallStart: async () => { throw new Error('callback error'); }, @@ -357,7 +357,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, onToolCallFinish: async () => { throw new Error('callback error'); }, @@ -386,7 +386,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, onToolCallFinish: async () => { throw new Error('callback error'); }, @@ -419,7 +419,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, onToolCallFinish: async event => { finishEvents.push(event); }, @@ -449,7 +449,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, onToolCallFinish: async event => { finishEvents.push(event); }, @@ -479,7 +479,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, onPreliminaryToolResult: result => { preliminaryResults.push(result); }, @@ -516,7 +516,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, onPreliminaryToolResult: result => { preliminaryResults.push(result); }, @@ -611,7 +611,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, onToolCallFinish: async event => { finishEvents.push(event); }, @@ -645,7 +645,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, executeToolInTelemetryContext, }); @@ -669,7 +669,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, }); expect(result).toMatchObject({ @@ -694,7 +694,7 @@ describe('executeToolCall', () => { messages: [], abortSignal: undefined, timeout: { toolMs: 5000 }, - experimental_context: undefined, + experimental_context: {}, }); expect(result).toMatchObject({ @@ -722,7 +722,7 @@ describe('executeToolCall', () => { messages: [], abortSignal: undefined, timeout: { toolMs: 5000 }, - experimental_context: undefined, + experimental_context: {}, }); expect(receivedSignal).toBeDefined(); @@ -747,7 +747,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, }); expect(receivedSignal).toBeUndefined(); @@ -773,7 +773,7 @@ describe('executeToolCall', () => { messages: [], abortSignal: controller.signal, timeout: { toolMs: 5000 }, - experimental_context: undefined, + experimental_context: {}, }); expect(receivedSignal).toBeDefined(); @@ -800,7 +800,7 @@ describe('executeToolCall', () => { messages: [], abortSignal: undefined, timeout: { toolMs: 10000, tools: { testToolMs: 2000 } }, - experimental_context: undefined, + experimental_context: {}, }); expect(receivedSignal).toBeDefined(); @@ -826,7 +826,7 @@ describe('executeToolCall', () => { messages: [], abortSignal: undefined, timeout: { toolMs: 5000, tools: { otherToolMs: 2000 } }, - experimental_context: undefined, + experimental_context: {}, }); expect(receivedSignal).toBeDefined(); @@ -852,7 +852,7 @@ describe('executeToolCall', () => { messages: [], abortSignal: undefined, timeout: { tools: { otherToolMs: 2000 } }, - experimental_context: undefined, + experimental_context: {}, }); expect(receivedSignal).toBeUndefined(); @@ -875,7 +875,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, }); expect(result).toMatchObject({ @@ -901,7 +901,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, }); expect(result).toMatchObject({ @@ -920,7 +920,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, }); expect(result).toBeUndefined(); @@ -941,7 +941,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, }); expect(result).toBeUndefined(); @@ -964,7 +964,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, onToolCallStart: [ async () => { calls.push('first'); @@ -993,7 +993,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, onToolCallFinish: [ async () => { calls.push('first'); @@ -1022,7 +1022,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, onToolCallStart: [ undefined, async () => { @@ -1057,7 +1057,7 @@ describe('executeToolCall', () => { callId: 'test-telemetry-call-id', messages: [], abortSignal: undefined, - experimental_context: undefined, + experimental_context: {}, onToolCallStart: [ async () => { throw new Error('listener error'); diff --git a/packages/ai/src/generate-text/execute-tool-call.ts b/packages/ai/src/generate-text/execute-tool-call.ts index 0002f07401f4..409b75c43478 100644 --- a/packages/ai/src/generate-text/execute-tool-call.ts +++ b/packages/ai/src/generate-text/execute-tool-call.ts @@ -1,20 +1,21 @@ import { executeTool, ModelMessage } from '@ai-sdk/provider-utils'; -import { notify } from '../util/notify'; import { getToolTimeoutMs, TimeoutConfiguration, } from '../prompt/call-settings'; import { TelemetrySettings } from '../telemetry/telemetry-settings'; +import { notify } from '../util/notify'; import { now } from '../util/now'; import { GenerateTextOnToolCallFinishCallback, GenerateTextOnToolCallStartCallback, } from './generate-text'; import { TypedToolCall } from './tool-call'; +import { TypedToolError } from './tool-error'; import { ToolOutput } from './tool-output'; -import { ToolSet } from './tool-set'; import { TypedToolResult } from './tool-result'; -import { TypedToolError } from './tool-error'; +import type { GenerationContext } from './generation-context'; +import type { ToolSet } from './tool-set'; /** * Executes a single tool call and manages its lifecycle callbacks. @@ -27,7 +28,10 @@ import { TypedToolError } from './tool-error'; * * @returns The tool output (result or error), or undefined if the tool has no execute function. */ -export async function executeToolCall({ +export async function executeToolCall< + TOOLS extends ToolSet, + CONTEXT extends GenerationContext, +>({ toolCall, tools, telemetry, @@ -50,8 +54,8 @@ export async function executeToolCall({ callId: string; messages: ModelMessage[]; abortSignal: AbortSignal | undefined; + experimental_context: CONTEXT; timeout?: TimeoutConfiguration; - experimental_context: unknown; stepNumber?: number; provider?: string; modelId?: string; diff --git a/packages/ai/src/generate-text/generate-text-result.ts b/packages/ai/src/generate-text/generate-text-result.ts index c3ceb5869435..a80bb66df7ab 100644 --- a/packages/ai/src/generate-text/generate-text-result.ts +++ b/packages/ai/src/generate-text/generate-text-result.ts @@ -16,7 +16,8 @@ import { StaticToolResult, TypedToolResult, } from './tool-result'; -import { ToolSet } from './tool-set'; +import type { GenerationContext } from './generation-context'; +import type { ToolSet } from './tool-set'; /** * The result of a `generateText` call. @@ -24,6 +25,7 @@ import { ToolSet } from './tool-set'; */ export interface GenerateTextResult< TOOLS extends ToolSet, + CONTEXT extends GenerationContext, OUTPUT extends Output, > { /** @@ -151,7 +153,7 @@ export interface GenerateTextResult< * You can use this to get information about intermediate steps, * such as the tool calls or the response headers. */ - readonly steps: Array>; + readonly steps: Array>; /** * The generated structured output. It uses the `output` specification. diff --git a/packages/ai/src/generate-text/generate-text.test-d.ts b/packages/ai/src/generate-text/generate-text.test-d.ts index fd7ccd47fabd..3bb3dc08a1b0 100644 --- a/packages/ai/src/generate-text/generate-text.test-d.ts +++ b/packages/ai/src/generate-text/generate-text.test-d.ts @@ -1,4 +1,5 @@ import { JSONValue } from '@ai-sdk/provider'; +import { tool } from '@ai-sdk/provider-utils'; import { describe, expectTypeOf, it } from 'vitest'; import { z } from 'zod'; import { generateText, Output } from '../generate-text'; @@ -65,4 +66,47 @@ describe('generateText types', () => { expectTypeOf().toEqualTypeOf(); }); }); + + describe('experimental_context', () => { + it('should infer typed experimental_context with one tool context and prepareStep', async () => { + generateText({ + model: new MockLanguageModelV4(), + prompt: 'Hello, world!', + tools: { + weather: tool({ + inputSchema: z.object({ + city: z.string(), + }), + contextSchema: z.object({ + userId: z.string(), + }), + execute: async (_input, { experimental_context }) => { + expectTypeOf(experimental_context).toMatchObjectType<{ + userId: string; + }>(); + + return 'sunny'; + }, + }), + }, + experimental_context: { + userId: 'test-user', + role: 'admin', + }, + prepareStep: ({ experimental_context }) => { + expectTypeOf(experimental_context).toMatchObjectType<{ + userId: string; + role: string; + }>(); + + return { + experimental_context: { + userId: experimental_context.userId, + role: experimental_context.role, + }, + }; + }, + }); + }); + }); }); diff --git a/packages/ai/src/generate-text/generate-text.test.ts b/packages/ai/src/generate-text/generate-text.test.ts index f2e10805eb71..042631852e0d 100644 --- a/packages/ai/src/generate-text/generate-text.test.ts +++ b/packages/ai/src/generate-text/generate-text.test.ts @@ -812,7 +812,9 @@ describe('generateText', () => { describe('options.experimental_onStart', () => { it('should send correct information with text prompt', async () => { - let startEvent!: Parameters[0]; + let startEvent!: Parameters< + GenerateTextOnStartCallback + >[0]; await generateText({ model: new MockLanguageModelV4({ @@ -839,7 +841,9 @@ describe('generateText', () => { }); it('should pass experimental_context', async () => { - let startEvent!: Parameters[0]; + let startEvent!: Parameters< + GenerateTextOnStartCallback + >[0]; await generateText({ model: new MockLanguageModelV4({ @@ -862,7 +866,9 @@ describe('generateText', () => { }); it('should send correct information with system and messages', async () => { - let startEvent!: Parameters[0]; + let startEvent!: Parameters< + GenerateTextOnStartCallback + >[0]; await generateText({ model: new MockLanguageModelV4({ @@ -1365,7 +1371,7 @@ describe('generateText', () => { describe('options.onStepFinish stepNumber', () => { it('should pass stepNumber 0 for a single step', async () => { let stepFinishEvent!: Parameters< - GenerateTextOnStepFinishCallback + GenerateTextOnStepFinishCallback >[0]; await generateText({ @@ -2178,7 +2184,7 @@ describe('generateText', () => { describe('options.onFinish', () => { it('should send correct information', async () => { - let result!: Parameters>[0]; + let result!: Parameters>[0]; await generateText({ model: new MockLanguageModelV4({ @@ -2254,7 +2260,7 @@ describe('generateText', () => { ], "dynamicToolCalls": [], "dynamicToolResults": [], - "experimental_context": undefined, + "experimental_context": {}, "files": [], "finishReason": "stop", "functionId": undefined, @@ -2370,7 +2376,7 @@ describe('generateText', () => { "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -2513,9 +2519,9 @@ describe('generateText', () => { }); describe('options.stopWhen', () => { - let result: GenerateTextResult; - let onFinishResult: Parameters>[0]; - let onStepFinishResults: StepResult[]; + let result: GenerateTextResult; + let onFinishResult: Parameters>[0]; + let onStepFinishResults: StepResult[]; beforeEach(() => { result = undefined as any; @@ -2711,13 +2717,13 @@ describe('generateText', () => { }); describe('2 steps: initial, tool-result with prepareStep', () => { - let result: GenerateTextResult; - let onStepFinishResults: StepResult[]; + let result: GenerateTextResult; + let onStepFinishResults: StepResult[]; let doGenerateCalls: Array; let prepareStepCalls: Array<{ modelId: string; stepNumber: number; - steps: Array>; + steps: Array>; messages: Array; experimental_context: unknown; }>; @@ -3481,10 +3487,10 @@ describe('generateText', () => { }); describe('2 stop conditions', () => { - let result: GenerateTextResult; + let result: GenerateTextResult; let stopConditionCalls: Array<{ number: number; - steps: StepResult[]; + steps: StepResult[]; }>; beforeEach(async () => { @@ -3599,7 +3605,7 @@ describe('generateText', () => { "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "tool-calls", "functionId": undefined, "metadata": undefined, @@ -3698,7 +3704,7 @@ describe('generateText', () => { "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "tool-calls", "functionId": undefined, "metadata": undefined, @@ -3945,6 +3951,7 @@ describe('generateText', () => { abortSignal: abortController.signal, toolCallId: 'call-1', messages: expect.any(Array), + experimental_context: {}, }, ); }); @@ -4048,6 +4055,7 @@ describe('generateText', () => { abortSignal: expect.any(AbortSignal), toolCallId: 'call-1', messages: expect.any(Array), + experimental_context: {}, }, ); }); @@ -4346,7 +4354,7 @@ describe('generateText', () => { { "options": { "abortSignal": undefined, - "experimental_context": undefined, + "experimental_context": {}, "input": { "value": "value", }, @@ -4476,7 +4484,7 @@ describe('generateText', () => { describe('provider-executed tools', () => { describe('two provider-executed tool calls and results', () => { - let result: GenerateTextResult; + let result: GenerateTextResult; beforeEach(async () => { result = await generateText({ @@ -5069,7 +5077,7 @@ describe('generateText', () => { }); describe('tool execution errors', () => { - let result: GenerateTextResult; + let result: GenerateTextResult; beforeEach(async () => { result = await generateText({ @@ -5260,14 +5268,14 @@ describe('generateText', () => { describe('programmatic tool calling', () => { describe('5 steps: code_execution triggers client tool across multiple turns (dice game fixture)', () => { - let result: GenerateTextResult; - let onFinishResult: Parameters>[0]; - let onStepFinishResults: StepResult[]; + let result: GenerateTextResult; + let onFinishResult: Parameters>[0]; + let onStepFinishResults: StepResult[]; let doGenerateCalls: Array; let prepareStepCalls: Array<{ modelId: string; stepNumber: number; - steps: Array>; + steps: Array>; messages: Array; }>; let rollDieExecutions: Array<{ player: string }>; @@ -6598,7 +6606,7 @@ describe('generateText', () => { describe('invalid tool calls', () => { describe('single invalid tool call', () => { - let result: GenerateTextResult; + let result: GenerateTextResult; beforeEach(async () => { result = await generateText({ @@ -6741,7 +6749,7 @@ describe('generateText', () => { describe('tools with preliminary results', () => { describe('single tool with preliminary results', () => { - let result: GenerateTextResult; + let result: GenerateTextResult; beforeEach(async () => { result = await generateText({ @@ -6862,7 +6870,7 @@ describe('generateText', () => { "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "tool-calls", "functionId": undefined, "metadata": undefined, @@ -7062,7 +7070,7 @@ describe('generateText', () => { describe('tool execution approval', () => { describe('when a single tool needs approval', () => { - let result: GenerateTextResult; + let result: GenerateTextResult; beforeEach(async () => { result = await generateText({ @@ -7167,7 +7175,7 @@ describe('generateText', () => { }); describe('when a single tool has a needsApproval function', () => { - let result: GenerateTextResult; + let result: GenerateTextResult; let needsApprovalCalls: Array<{ input: any; options: any }> = []; beforeEach(async () => { @@ -7335,7 +7343,7 @@ describe('generateText', () => { "value": "value-needs-approval", }, "options": { - "experimental_context": undefined, + "experimental_context": {}, "messages": [ { "content": "test-input", @@ -7350,7 +7358,7 @@ describe('generateText', () => { "value": "value-no-approval", }, "options": { - "experimental_context": undefined, + "experimental_context": {}, "messages": [ { "content": "test-input", @@ -7366,9 +7374,9 @@ describe('generateText', () => { }); describe('when a call from a single tool that needs approval is approved', () => { - let result: GenerateTextResult; + let result: GenerateTextResult; let prompts: LanguageModelV4Prompt[]; - let executeFunction: ToolExecuteFunction; + let executeFunction: ToolExecuteFunction; beforeEach(async () => { prompts = []; @@ -7532,7 +7540,7 @@ describe('generateText', () => { }); describe('when a call from a single tool that needs approval is approved and the tool throws', () => { - let result: GenerateTextResult; + let result: GenerateTextResult; let prompts: LanguageModelV4Prompt[]; beforeEach(async () => { @@ -7646,9 +7654,9 @@ describe('generateText', () => { }); describe('when a call from a single tool that needs approval is denied', () => { - let result: GenerateTextResult; + let result: GenerateTextResult; let prompts: LanguageModelV4Prompt[]; - let executeFunction: ToolExecuteFunction; + let executeFunction: ToolExecuteFunction; beforeEach(async () => { prompts = []; @@ -7805,9 +7813,9 @@ describe('generateText', () => { }); describe('when two calls from a single tool that needs approval are approved', () => { - let result: GenerateTextResult; + let result: GenerateTextResult; let prompts: LanguageModelV4Prompt[]; - let executeFunction: ToolExecuteFunction; + let executeFunction: ToolExecuteFunction; beforeEach(async () => { prompts = []; @@ -8010,7 +8018,7 @@ describe('generateText', () => { describe('provider-executed tool (MCP) approval', () => { describe('when a provider-executed tool emits tool-approval-request', () => { - let result: GenerateTextResult; + let result: GenerateTextResult; beforeEach(async () => { result = await generateText({ @@ -8121,7 +8129,7 @@ describe('generateText', () => { }); describe('when a provider-executed tool approval is approved', () => { - let result: GenerateTextResult; + let result: GenerateTextResult; let prompts: LanguageModelV4Prompt[]; beforeEach(async () => { @@ -8307,7 +8315,7 @@ describe('generateText', () => { }); describe('when a provider-executed tool approval is denied', () => { - let result: GenerateTextResult; + let result: GenerateTextResult; let prompts: LanguageModelV4Prompt[]; beforeEach(async () => { diff --git a/packages/ai/src/generate-text/generate-text.ts b/packages/ai/src/generate-text/generate-text.ts index f369c82d38ee..317b2b7ff977 100644 --- a/packages/ai/src/generate-text/generate-text.ts +++ b/packages/ai/src/generate-text/generate-text.ts @@ -64,6 +64,7 @@ import { executeToolCall } from './execute-tool-call'; import { filterActiveTools } from './filter-active-tool'; import { GenerateTextResult } from './generate-text-result'; import { DefaultGeneratedFile } from './generated-file'; +import type { GenerationContext } from './generation-context'; import { isApprovalNeeded } from './is-approval-needed'; import { Output, text } from './output'; import { InferCompleteOutput } from './output-utils'; @@ -73,8 +74,8 @@ import { convertToReasoningOutputs } from './reasoning-output'; import { ResponseMessage } from './response-message'; import { DefaultStepResult, StepResult } from './step-result'; import { - isStopConditionMet, isStepCount, + isStopConditionMet, StopCondition, } from './stop-condition'; import { toResponseMessages } from './to-response-messages'; @@ -84,7 +85,7 @@ import { ToolCallRepairFunction } from './tool-call-repair-function'; import { TypedToolError } from './tool-error'; import { ToolOutput } from './tool-output'; import { TypedToolResult } from './tool-result'; -import { ToolSet } from './tool-set'; +import type { ToolSet } from './tool-set'; const originalGenerateId = createIdGenerator({ prefix: 'aitxt', @@ -114,10 +115,11 @@ type GenerateTextIncludeSettings = { * @param event - The event object containing generation configuration. */ export type GenerateTextOnStartCallback< - TOOLS extends ToolSet = ToolSet, + TOOLS extends ToolSet, + CONTEXT extends GenerationContext, OUTPUT extends Output = Output, > = ( - event: OnStartEvent, + event: OnStartEvent, ) => PromiseLike | void; /** @@ -130,10 +132,11 @@ export type GenerateTextOnStartCallback< * @param event - The event object containing step configuration. */ export type GenerateTextOnStepStartCallback< - TOOLS extends ToolSet = ToolSet, + TOOLS extends ToolSet, + CONTEXT extends GenerationContext, OUTPUT extends Output = Output, > = ( - event: OnStepStartEvent, + event: OnStepStartEvent, ) => PromiseLike | void; /** @@ -144,9 +147,9 @@ export type GenerateTextOnStepStartCallback< * * @param event - The event object containing tool call information. */ -export type GenerateTextOnToolCallStartCallback< - TOOLS extends ToolSet = ToolSet, -> = (event: OnToolCallStartEvent) => PromiseLike | void; +export type GenerateTextOnToolCallStartCallback = ( + event: OnToolCallStartEvent, +) => PromiseLike | void; /** * Callback that is set using the `experimental_onToolCallFinish` option. @@ -160,9 +163,9 @@ export type GenerateTextOnToolCallStartCallback< * * @param event - The event object containing tool call result information. */ -export type GenerateTextOnToolCallFinishCallback< - TOOLS extends ToolSet = ToolSet, -> = (event: OnToolCallFinishEvent) => PromiseLike | void; +export type GenerateTextOnToolCallFinishCallback = ( + event: OnToolCallFinishEvent, +) => PromiseLike | void; /** * Callback that is set using the `onStepFinish` option. @@ -172,9 +175,10 @@ export type GenerateTextOnToolCallFinishCallback< * * @param stepResult - The result of the step. */ -export type GenerateTextOnStepFinishCallback = ( - event: OnStepFinishEvent, -) => Promise | void; +export type GenerateTextOnStepFinishCallback< + TOOLS extends ToolSet, + CONTEXT extends GenerationContext, +> = (event: OnStepFinishEvent) => Promise | void; /** * Callback that is set using the `onFinish` option. @@ -185,9 +189,10 @@ export type GenerateTextOnStepFinishCallback = ( * * @param event - The final result along with aggregated step data. */ -export type GenerateTextOnFinishCallback = ( - event: OnFinishEvent, -) => PromiseLike | void; +export type GenerateTextOnFinishCallback< + TOOLS extends ToolSet, + CONTEXT extends GenerationContext, +> = (event: OnFinishEvent) => PromiseLike | void; /** * Generate a text and call tools for a given prompt using a language model. @@ -245,6 +250,7 @@ export type GenerateTextOnFinishCallback = ( */ export async function generateText< TOOLS extends ToolSet, + CONTEXT extends GenerationContext, OUTPUT extends Output = Output, >({ model: modelArg, @@ -268,7 +274,7 @@ export async function generateText< prepareStep = experimental_prepareStep, experimental_repairToolCall: repairToolCall, experimental_download: download, - experimental_context, + experimental_context = {} as CONTEXT, experimental_include: include, _internal: { generateId = originalGenerateId, @@ -313,8 +319,8 @@ export async function generateText< * @default isStepCount(1) */ stopWhen?: - | StopCondition> - | Array>>; + | StopCondition, CONTEXT> + | Array, CONTEXT>>; /** * Optional telemetry configuration (experimental). @@ -361,12 +367,12 @@ export async function generateText< /** * @deprecated Use `prepareStep` instead. */ - experimental_prepareStep?: PrepareStepFunction>; + experimental_prepareStep?: PrepareStepFunction, CONTEXT>; /** * Optional function that you can use to provide different settings for a step. */ - prepareStep?: PrepareStepFunction>; + prepareStep?: PrepareStepFunction, CONTEXT>; /** * A function that attempts to repair a tool call that failed to parse. @@ -377,7 +383,11 @@ export async function generateText< * Callback that is called when the generateText operation begins, * before any LLM calls are made. */ - experimental_onStart?: GenerateTextOnStartCallback, OUTPUT>; + experimental_onStart?: GenerateTextOnStartCallback< + NoInfer, + NoInfer, + NoInfer + >; /** * Callback that is called when a step (LLM call) begins, @@ -385,7 +395,8 @@ export async function generateText< */ experimental_onStepStart?: GenerateTextOnStepStartCallback< NoInfer, - OUTPUT + NoInfer, + NoInfer >; /** @@ -405,12 +416,15 @@ export async function generateText< /** * Callback that is called when each step (LLM call) is finished, including intermediate steps. */ - onStepFinish?: GenerateTextOnStepFinishCallback>; + onStepFinish?: GenerateTextOnStepFinishCallback< + NoInfer, + NoInfer + >; /** * Callback that is called when all steps are finished and the response is complete. */ - onFinish?: GenerateTextOnFinishCallback>; + onFinish?: GenerateTextOnFinishCallback, NoInfer>; /** * Context that is passed into tool execution. @@ -419,7 +433,7 @@ export async function generateText< * * @default undefined */ - experimental_context?: unknown; + experimental_context?: CONTEXT; /** * Settings for controlling what data is included in step results. @@ -450,9 +464,9 @@ export async function generateText< generateId?: IdGenerator; generateCallId?: IdGenerator; }; - }): Promise> { + }): Promise> { const model = resolveLanguageModel(modelArg); - const createGlobalTelemetry = getGlobalTelemetryIntegration(); + const createGlobalTelemetry = getGlobalTelemetryIntegration(); const stopConditions = asArray(stopWhen); const totalTimeoutMs = getTotalTimeoutMs(timeout); @@ -528,7 +542,7 @@ export async function generateText< onStart, globalTelemetry.onStart as | undefined - | GenerateTextOnStartCallback, + | GenerateTextOnStartCallback, ], }); @@ -660,7 +674,7 @@ export async function generateText< > & { response: { id: string; timestamp: Date; modelId: string } }; let clientToolCalls: Array> = []; let clientToolOutputs: Array> = []; - const steps: GenerateTextResult['steps'] = []; + const steps: GenerateTextResult['steps'] = []; // Track provider-executed tool calls that support deferred results // (e.g., code_execution in programmatic tool calling scenarios). @@ -755,7 +769,7 @@ export async function generateText< onStepStart, globalTelemetry.onStepStart as | undefined - | GenerateTextOnStepStartCallback, + | GenerateTextOnStepStartCallback, ], }); @@ -978,23 +992,26 @@ export async function generateText< const stepNumber = steps.length; - const currentStepResult: StepResult = new DefaultStepResult({ - callId, - stepNumber, - provider: stepModel.provider, - modelId: stepModel.modelId, - functionId: telemetry?.functionId, - metadata: telemetry?.metadata as Record | undefined, - experimental_context, - content: stepContent, - finishReason: currentModelResponse.finishReason.unified, - rawFinishReason: currentModelResponse.finishReason.raw, - usage: asLanguageModelUsage(currentModelResponse.usage), - warnings: currentModelResponse.warnings, - providerMetadata: currentModelResponse.providerMetadata, - request: stepRequest, - response: stepResponse, - }); + const currentStepResult: StepResult = + new DefaultStepResult({ + callId, + stepNumber, + provider: stepModel.provider, + modelId: stepModel.modelId, + functionId: telemetry?.functionId, + metadata: telemetry?.metadata as + | Record + | undefined, + experimental_context, + content: stepContent, + finishReason: currentModelResponse.finishReason.unified, + rawFinishReason: currentModelResponse.finishReason.raw, + usage: asLanguageModelUsage(currentModelResponse.usage), + warnings: currentModelResponse.warnings, + providerMetadata: currentModelResponse.providerMetadata, + request: stepRequest, + response: stepResponse, + }); logWarnings({ warnings: currentModelResponse.warnings ?? [], @@ -1010,7 +1027,7 @@ export async function generateText< onStepFinish, globalTelemetry.onStepFinish as | undefined - | GenerateTextOnStepFinishCallback, + | GenerateTextOnStepFinishCallback, ], }); } finally { @@ -1080,7 +1097,7 @@ export async function generateText< onFinish, globalTelemetry.onFinish as | undefined - | GenerateTextOnFinishCallback, + | GenerateTextOnFinishCallback, ], }); @@ -1109,7 +1126,10 @@ export async function generateText< } } -async function executeTools({ +async function executeTools< + TOOLS extends ToolSet, + CONTEXT extends GenerationContext, +>({ toolCalls, tools, telemetry, @@ -1132,7 +1152,7 @@ async function executeTools({ messages: ModelMessage[]; abortSignal: AbortSignal | undefined; timeout?: TimeoutConfiguration; - experimental_context: unknown; + experimental_context: CONTEXT; stepNumber: number; provider: string; modelId: string; @@ -1168,14 +1188,15 @@ async function executeTools({ class DefaultGenerateTextResult< TOOLS extends ToolSet, + CONTEXT extends GenerationContext, OUTPUT extends Output, -> implements GenerateTextResult { - readonly steps: GenerateTextResult['steps']; +> implements GenerateTextResult { + readonly steps: GenerateTextResult['steps']; readonly totalUsage: LanguageModelUsage; private readonly _output: InferCompleteOutput | undefined; constructor(options: { - steps: GenerateTextResult['steps']; + steps: GenerateTextResult['steps']; output: InferCompleteOutput | undefined; totalUsage: LanguageModelUsage; }) { diff --git a/packages/ai/src/generate-text/generation-context.test-d.ts b/packages/ai/src/generate-text/generation-context.test-d.ts new file mode 100644 index 000000000000..942581e206a6 --- /dev/null +++ b/packages/ai/src/generate-text/generation-context.test-d.ts @@ -0,0 +1,53 @@ +import { Context, tool } from '@ai-sdk/provider-utils'; +import { describe, expectTypeOf, it } from 'vitest'; +import { z } from 'zod/v4'; +import type { GenerationContext } from './generation-context'; + +describe('GenerationContext', () => { + it('combines inferred tool context with the generic context type', () => { + const tools = { + weather: tool({ + inputSchema: z.object({ + city: z.string(), + }), + contextSchema: z.object({ + userId: z.string(), + }), + }), + forecast: tool({ + inputSchema: z.object({ + days: z.number(), + }), + contextSchema: z.object({ + role: z.string(), + }), + }), + }; + + expectTypeOf>().toEqualTypeOf< + { + userId: string; + } & { + role: string; + } & Context + >(); + + expectTypeOf< + GenerationContext + >().toMatchObjectType(); + }); + + it('falls back to the generic context type when tools have no contextSchema', () => { + const tools = { + weather: tool({ + inputSchema: z.object({ + city: z.string(), + }), + }), + }; + + expectTypeOf< + GenerationContext + >().toMatchObjectType(); + }); +}); diff --git a/packages/ai/src/generate-text/generation-context.ts b/packages/ai/src/generate-text/generation-context.ts new file mode 100644 index 000000000000..7b9a0dce877a --- /dev/null +++ b/packages/ai/src/generate-text/generation-context.ts @@ -0,0 +1,15 @@ +import type { Context } from '@ai-sdk/provider-utils'; +import type { InferToolSetContext } from './infer-tool-set-context'; +import type { ToolSet } from './tool-set'; + +/** + * The context type for a generation call. + * + * It expands the tool set context with the generic context type for + * e.g. prepareStep or telemetry, + * while keeping the inferred tool set context for autocompletion. + */ +export type GenerationContext = InferToolSetContext< + NoInfer +> & + Context; diff --git a/packages/ai/src/generate-text/infer-tool-set-context.test-d.ts b/packages/ai/src/generate-text/infer-tool-set-context.test-d.ts new file mode 100644 index 000000000000..58d92f6dba55 --- /dev/null +++ b/packages/ai/src/generate-text/infer-tool-set-context.test-d.ts @@ -0,0 +1,65 @@ +import { Context, tool } from '@ai-sdk/provider-utils'; +import { describe, expectTypeOf, it } from 'vitest'; +import { z } from 'zod/v4'; +import type { InferToolSetContext } from './infer-tool-set-context'; + +describe('InferToolSetContext', () => { + it('infers the intersection of context types across a tool set', () => { + const tools = { + weather: tool({ + inputSchema: z.object({ + city: z.string(), + }), + contextSchema: z.object({ + userId: z.string(), + }), + }), + forecast: tool({ + inputSchema: z.object({ + days: z.number(), + }), + contextSchema: z.object({ + role: z.string(), + }), + }), + }; + + expectTypeOf>().toMatchObjectType<{ + userId: string; + role: string; + }>(); + }); + + it('infers a single tool context type from a tool set', () => { + const tools = { + weather: tool({ + inputSchema: z.object({ + city: z.string(), + }), + contextSchema: z.object({ + userId: z.string(), + role: z.string(), + }), + }), + }; + + expectTypeOf>().toMatchObjectType<{ + userId: string; + role: string; + }>(); + }); + + it('falls back to the generic context type for tools without contextSchema', () => { + const tools = { + weather: tool({ + inputSchema: z.object({ + city: z.string(), + }), + }), + }; + + expectTypeOf< + InferToolSetContext + >().toMatchObjectType(); + }); +}); diff --git a/packages/ai/src/generate-text/infer-tool-set-context.ts b/packages/ai/src/generate-text/infer-tool-set-context.ts new file mode 100644 index 000000000000..8c549e2d8afc --- /dev/null +++ b/packages/ai/src/generate-text/infer-tool-set-context.ts @@ -0,0 +1,12 @@ +import { InferToolContext } from '@ai-sdk/provider-utils'; +import { UnionToIntersection } from '../util/union-to-intersection'; +import type { ToolSet } from './tool-set'; + +/** + * Infer the context type of a tool set. + */ +export type InferToolSetContext = UnionToIntersection< + { + [K in keyof TOOLS]: InferToolContext>; + }[keyof TOOLS] +>; diff --git a/packages/ai/src/generate-text/prepare-step.ts b/packages/ai/src/generate-text/prepare-step.ts index f3f367094aee..890f934de36c 100644 --- a/packages/ai/src/generate-text/prepare-step.ts +++ b/packages/ai/src/generate-text/prepare-step.ts @@ -1,11 +1,13 @@ import { + Context, ModelMessage, ProviderOptions, SystemModelMessage, - Tool, } from '@ai-sdk/provider-utils'; import { LanguageModel, ToolChoice } from '../types/language-model'; import { StepResult } from './step-result'; +import type { GenerationContext } from './generation-context'; +import type { ToolSet } from './tool-set'; /** * Function that you can use to provide different settings for a step. @@ -21,12 +23,13 @@ import { StepResult } from './step-result'; * If you return undefined (or for undefined settings), the settings from the outer level will be used. */ export type PrepareStepFunction< - TOOLS extends Record = Record, + TOOLS extends ToolSet, + CONTEXT extends GenerationContext, > = (options: { /** * The steps that have been executed so far. */ - steps: Array>>; + steps: Array>; /** * The number of the step that is being executed. @@ -46,15 +49,18 @@ export type PrepareStepFunction< /** * The context passed via the experimental_context setting (experimental). */ - experimental_context: unknown; -}) => PromiseLike> | PrepareStepResult; + experimental_context: CONTEXT; +}) => + | PromiseLike> + | PrepareStepResult; /** * The result type returned by a {@link PrepareStepFunction}, * allowing per-step overrides of model, tools, or messages. */ export type PrepareStepResult< - TOOLS extends Record = Record, + TOOLS extends ToolSet, + CONTEXT extends GenerationContext, > = | { /** @@ -90,7 +96,7 @@ export type PrepareStepResult< * Changing the context will affect the context in this step * and all subsequent steps. */ - experimental_context?: unknown; + experimental_context?: CONTEXT; /** * Additional provider-specific options for this step. diff --git a/packages/ai/src/generate-text/step-result.ts b/packages/ai/src/generate-text/step-result.ts index 04c8adee2ced..dddfe648f61f 100644 --- a/packages/ai/src/generate-text/step-result.ts +++ b/packages/ai/src/generate-text/step-result.ts @@ -1,10 +1,4 @@ -import { ReasoningPart, ReasoningFilePart } from '@ai-sdk/provider-utils'; -import { asReasoningText } from './reasoning'; -import { - ReasoningOutput, - ReasoningFileOutput, - convertFromReasoningOutputs, -} from './reasoning-output'; +import { ReasoningFilePart, ReasoningPart } from '@ai-sdk/provider-utils'; import { CallWarning, FinishReason, @@ -16,6 +10,12 @@ import { Source } from '../types/language-model'; import { LanguageModelUsage } from '../types/usage'; import { ContentPart } from './content-part'; import { GeneratedFile } from './generated-file'; +import { asReasoningText } from './reasoning'; +import { + ReasoningFileOutput, + ReasoningOutput, + convertFromReasoningOutputs, +} from './reasoning-output'; import { ResponseMessage } from './response-message'; import { DynamicToolCall, StaticToolCall, TypedToolCall } from './tool-call'; import { @@ -23,12 +23,16 @@ import { StaticToolResult, TypedToolResult, } from './tool-result'; -import { ToolSet } from './tool-set'; +import type { GenerationContext } from './generation-context'; +import type { ToolSet } from './tool-set'; /** * The result of a single step in the generation process. */ -export type StepResult = { +export type StepResult< + TOOLS extends ToolSet, + CONTEXT extends GenerationContext, +> = { /** * Unique identifier for the generation call this step belongs to. */ @@ -64,7 +68,7 @@ export type StepResult = { * * Experimental (can break in patch releases). */ - readonly experimental_context: unknown; + readonly experimental_context: CONTEXT; /** * The content that was generated in the last step. @@ -178,21 +182,25 @@ export type StepResult = { export class DefaultStepResult< TOOLS extends ToolSet, -> implements StepResult { - readonly callId: StepResult['callId']; - readonly stepNumber: StepResult['stepNumber']; - readonly model: StepResult['model']; - readonly functionId: StepResult['functionId']; - readonly metadata: StepResult['metadata']; - readonly experimental_context: StepResult['experimental_context']; - readonly content: StepResult['content']; - readonly finishReason: StepResult['finishReason']; - readonly rawFinishReason: StepResult['rawFinishReason']; - readonly usage: StepResult['usage']; - readonly warnings: StepResult['warnings']; - readonly request: StepResult['request']; - readonly response: StepResult['response']; - readonly providerMetadata: StepResult['providerMetadata']; + CONTEXT extends GenerationContext, +> implements StepResult { + readonly callId: StepResult['callId']; + readonly stepNumber: StepResult['stepNumber']; + readonly model: StepResult['model']; + readonly functionId: StepResult['functionId']; + readonly metadata: StepResult['metadata']; + readonly experimental_context: StepResult< + TOOLS, + CONTEXT + >['experimental_context']; + readonly content: StepResult['content']; + readonly finishReason: StepResult['finishReason']; + readonly rawFinishReason: StepResult['rawFinishReason']; + readonly usage: StepResult['usage']; + readonly warnings: StepResult['warnings']; + readonly request: StepResult['request']; + readonly response: StepResult['response']; + readonly providerMetadata: StepResult['providerMetadata']; constructor({ callId, @@ -211,21 +219,21 @@ export class DefaultStepResult< response, providerMetadata, }: { - callId: StepResult['callId']; - stepNumber: StepResult['stepNumber']; - provider: string; - modelId: string; - functionId: StepResult['functionId']; - metadata: StepResult['metadata']; - experimental_context: StepResult['experimental_context']; - content: StepResult['content']; - finishReason: StepResult['finishReason']; - rawFinishReason: StepResult['rawFinishReason']; - usage: StepResult['usage']; - warnings: StepResult['warnings']; - request: StepResult['request']; - response: StepResult['response']; - providerMetadata: StepResult['providerMetadata']; + callId: StepResult['callId']; + stepNumber: StepResult['stepNumber']; + provider: StepResult['model']['provider']; + modelId: StepResult['model']['modelId']; + functionId: StepResult['functionId']; + metadata: StepResult['metadata']; + experimental_context: StepResult['experimental_context']; + content: StepResult['content']; + finishReason: StepResult['finishReason']; + rawFinishReason: StepResult['rawFinishReason']; + usage: StepResult['usage']; + warnings: StepResult['warnings']; + request: StepResult['request']; + response: StepResult['response']; + providerMetadata: StepResult['providerMetadata']; }) { this.callId = callId; this.stepNumber = stepNumber; diff --git a/packages/ai/src/generate-text/stop-condition.ts b/packages/ai/src/generate-text/stop-condition.ts index 2690e1bb6290..6527b7f73218 100644 --- a/packages/ai/src/generate-text/stop-condition.ts +++ b/packages/ai/src/generate-text/stop-condition.ts @@ -1,31 +1,38 @@ import { StepResult } from './step-result'; -import { ToolSet } from './tool-set'; +import type { GenerationContext } from './generation-context'; +import type { ToolSet } from './tool-set'; -export type StopCondition = (options: { - steps: Array>; +export type StopCondition< + TOOLS extends ToolSet, + CONTEXT extends GenerationContext, +> = (options: { + steps: Array>; }) => PromiseLike | boolean; -export function isStepCount(stepCount: number): StopCondition { +export function isStepCount(stepCount: number): StopCondition { return ({ steps }) => steps.length === stepCount; } -export function isLoopFinished(): StopCondition { +export function isLoopFinished(): StopCondition { return () => false; } -export function hasToolCall(toolName: string): StopCondition { +export function hasToolCall(toolName: string): StopCondition { return ({ steps }) => steps[steps.length - 1]?.toolCalls?.some( toolCall => toolCall.toolName === toolName, ) ?? false; } -export async function isStopConditionMet({ +export async function isStopConditionMet< + TOOLS extends ToolSet, + CONTEXT extends GenerationContext, +>({ stopConditions, steps, }: { - stopConditions: Array>; - steps: Array>; + stopConditions: Array>; + steps: Array>; }): Promise { return ( await Promise.all(stopConditions.map(condition => condition({ steps }))) diff --git a/packages/ai/src/generate-text/stream-text-result.ts b/packages/ai/src/generate-text/stream-text-result.ts index a2fc408defa3..cbef38b52777 100644 --- a/packages/ai/src/generate-text/stream-text-result.ts +++ b/packages/ai/src/generate-text/stream-text-result.ts @@ -35,7 +35,8 @@ import { StaticToolResult, TypedToolResult, } from './tool-result'; -import { ToolSet } from './tool-set'; +import type { GenerationContext } from './generation-context'; +import type { ToolSet } from './tool-set'; export type UIMessageStreamOptions = { /** @@ -108,6 +109,7 @@ export type ConsumeStreamOptions = { */ export interface StreamTextResult< TOOLS extends ToolSet, + CONTEXT extends GenerationContext, OUTPUT extends Output, > { /** @@ -237,7 +239,7 @@ export interface StreamTextResult< * * Automatically consumes the stream. */ - readonly steps: PromiseLike>>; + readonly steps: PromiseLike>>; /** * Additional request information from the last step. diff --git a/packages/ai/src/generate-text/stream-text.test-d.ts b/packages/ai/src/generate-text/stream-text.test-d.ts index 3a9afb5228af..824d3a8b0909 100644 --- a/packages/ai/src/generate-text/stream-text.test-d.ts +++ b/packages/ai/src/generate-text/stream-text.test-d.ts @@ -1,4 +1,5 @@ import { JSONValue } from '@ai-sdk/provider'; +import { tool } from '@ai-sdk/provider-utils'; import { describe, expectTypeOf, it } from 'vitest'; import { z } from 'zod'; import { Output, streamText } from '../generate-text'; @@ -197,4 +198,47 @@ describe('streamText types', () => { >(); }); }); + + describe('experimental_context', () => { + it('should infer typed experimental_context with one tool context and prepareStep', async () => { + streamText({ + model: new MockLanguageModelV4(), + prompt: 'Hello, world!', + tools: { + weather: tool({ + inputSchema: z.object({ + city: z.string(), + }), + contextSchema: z.object({ + userId: z.string(), + }), + execute: async (_input, { experimental_context }) => { + expectTypeOf(experimental_context).toMatchObjectType<{ + userId: string; + }>(); + + return 'sunny'; + }, + }), + }, + experimental_context: { + userId: 'test-user', + role: 'admin', + }, + prepareStep: ({ experimental_context }) => { + expectTypeOf(experimental_context).toMatchObjectType<{ + userId: string; + role: string; + }>(); + + return { + experimental_context: { + userId: experimental_context.userId, + role: experimental_context.role, + }, + }; + }, + }); + }); + }); }); diff --git a/packages/ai/src/generate-text/stream-text.test.ts b/packages/ai/src/generate-text/stream-text.test.ts index ce7a7ed67e87..344e8cd28303 100644 --- a/packages/ai/src/generate-text/stream-text.test.ts +++ b/packages/ai/src/generate-text/stream-text.test.ts @@ -4968,7 +4968,7 @@ describe('streamText', () => { "type": "text", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -5111,7 +5111,7 @@ describe('streamText', () => { "url": "https://example.com/2", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -5199,7 +5199,7 @@ describe('streamText', () => { "type": "file", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -5309,7 +5309,7 @@ describe('streamText', () => { "type": "text", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -7041,7 +7041,7 @@ describe('streamText', () => { describe('options.onFinish', () => { it('should send correct information', async () => { - let result!: Parameters>[0]; + let result!: Parameters>[0]; const resultObject = streamText({ model: createTestModel({ @@ -7121,7 +7121,7 @@ describe('streamText', () => { ], "dynamicToolCalls": [], "dynamicToolResults": [], - "experimental_context": undefined, + "experimental_context": {}, "files": [], "finishReason": "stop", "functionId": undefined, @@ -7241,7 +7241,7 @@ describe('streamText', () => { "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -7436,7 +7436,7 @@ describe('streamText', () => { ], "dynamicToolCalls": [], "dynamicToolResults": [], - "experimental_context": undefined, + "experimental_context": {}, "files": [], "finishReason": "stop", "functionId": undefined, @@ -7531,7 +7531,7 @@ describe('streamText', () => { "url": "https://example.com/2", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -7698,7 +7698,7 @@ describe('streamText', () => { ], "dynamicToolCalls": [], "dynamicToolResults": [], - "experimental_context": undefined, + "experimental_context": {}, "files": [ DefaultGeneratedFileWithType { "base64Data": "Hello World", @@ -7787,7 +7787,7 @@ describe('streamText', () => { "type": "file", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -8023,9 +8023,9 @@ describe('streamText', () => { }); describe('options.stopWhen', () => { - let result: StreamTextResult; - let onFinishResult: Parameters>[0]; - let onStepFinishResults: StepResult[]; + let result: StreamTextResult; + let onFinishResult: Parameters>[0]; + let onStepFinishResults: StepResult[]; let tracer: MockTracer; let stepInputs: Array; @@ -8428,7 +8428,7 @@ describe('streamText', () => { ], "dynamicToolCalls": [], "dynamicToolResults": [], - "experimental_context": undefined, + "experimental_context": {}, "files": [], "finishReason": "stop", "functionId": undefined, @@ -8531,7 +8531,7 @@ describe('streamText', () => { "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "tool-calls", "functionId": undefined, "metadata": undefined, @@ -8615,7 +8615,7 @@ describe('streamText', () => { "type": "text", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -8775,7 +8775,7 @@ describe('streamText', () => { "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "tool-calls", "functionId": undefined, "metadata": undefined, @@ -8859,7 +8859,7 @@ describe('streamText', () => { "type": "text", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -9037,7 +9037,7 @@ describe('streamText', () => { "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "tool-calls", "functionId": undefined, "metadata": undefined, @@ -9121,7 +9121,7 @@ describe('streamText', () => { "type": "text", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -9344,7 +9344,7 @@ describe('streamText', () => { let prepareStepCalls: Array<{ modelId: string; stepNumber: number; - steps: Array>; + steps: Array>; messages: Array; experimental_context: unknown; }>; @@ -10332,7 +10332,7 @@ describe('streamText', () => { ], "dynamicToolCalls": [], "dynamicToolResults": [], - "experimental_context": undefined, + "experimental_context": {}, "files": [], "finishReason": "stop", "functionId": undefined, @@ -10435,7 +10435,7 @@ describe('streamText', () => { "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "tool-calls", "functionId": undefined, "metadata": undefined, @@ -10519,7 +10519,7 @@ describe('streamText', () => { "type": "text", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -10679,7 +10679,7 @@ describe('streamText', () => { "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "tool-calls", "functionId": undefined, "metadata": undefined, @@ -10763,7 +10763,7 @@ describe('streamText', () => { "type": "text", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -10937,7 +10937,7 @@ describe('streamText', () => { "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "tool-calls", "functionId": undefined, "metadata": undefined, @@ -11021,7 +11021,7 @@ describe('streamText', () => { "type": "text", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -11392,7 +11392,7 @@ describe('streamText', () => { describe('2 stop conditions', () => { let stopConditionCalls: Array<{ number: number; - steps: StepResult[]; + steps: StepResult[]; }>; beforeEach(async () => { @@ -11518,7 +11518,7 @@ describe('streamText', () => { "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "tool-calls", "functionId": undefined, "metadata": undefined, @@ -11628,7 +11628,7 @@ describe('streamText', () => { "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "tool-calls", "functionId": undefined, "metadata": undefined, @@ -11820,7 +11820,7 @@ describe('streamText', () => { describe('provider-executed tools', () => { describe('single provider-executed tool call and result', () => { - let result: StreamTextResult; + let result: StreamTextResult; beforeEach(async () => { result = streamText({ @@ -12455,7 +12455,7 @@ describe('streamText', () => { describe('dynamic tools', () => { describe('single dynamic tool call and result', () => { - let result: StreamTextResult; + let result: StreamTextResult; beforeEach(async () => { result = streamText({ @@ -12824,6 +12824,7 @@ describe('streamText', () => { { abortSignal: abortController.signal, toolCallId: 'call-1', + experimental_context: {}, messages: expect.any(Array), }, ); @@ -13451,7 +13452,7 @@ describe('streamText', () => { { "options": { "abortSignal": undefined, - "experimental_context": undefined, + "experimental_context": {}, "messages": [ { "content": "test-input", @@ -13465,7 +13466,7 @@ describe('streamText', () => { { "options": { "abortSignal": undefined, - "experimental_context": undefined, + "experimental_context": {}, "inputTextDelta": "{"", "messages": [ { @@ -13480,7 +13481,7 @@ describe('streamText', () => { { "options": { "abortSignal": undefined, - "experimental_context": undefined, + "experimental_context": {}, "inputTextDelta": "value", "messages": [ { @@ -13495,7 +13496,7 @@ describe('streamText', () => { { "options": { "abortSignal": undefined, - "experimental_context": undefined, + "experimental_context": {}, "inputTextDelta": "":"", "messages": [ { @@ -13510,7 +13511,7 @@ describe('streamText', () => { { "options": { "abortSignal": undefined, - "experimental_context": undefined, + "experimental_context": {}, "inputTextDelta": "Spark", "messages": [ { @@ -13525,7 +13526,7 @@ describe('streamText', () => { { "options": { "abortSignal": undefined, - "experimental_context": undefined, + "experimental_context": {}, "inputTextDelta": "le", "messages": [ { @@ -13540,7 +13541,7 @@ describe('streamText', () => { { "options": { "abortSignal": undefined, - "experimental_context": undefined, + "experimental_context": {}, "inputTextDelta": " Day", "messages": [ { @@ -13555,7 +13556,7 @@ describe('streamText', () => { { "options": { "abortSignal": undefined, - "experimental_context": undefined, + "experimental_context": {}, "inputTextDelta": ""}", "messages": [ { @@ -13570,7 +13571,7 @@ describe('streamText', () => { { "options": { "abortSignal": undefined, - "experimental_context": undefined, + "experimental_context": {}, "input": { "value": "Sparkle Day", }, @@ -13713,7 +13714,7 @@ describe('streamText', () => { }); describe('tool execution errors', () => { - let result: StreamTextResult; + let result: StreamTextResult; beforeEach(async () => { result = streamText({ @@ -13865,7 +13866,7 @@ describe('streamText', () => { "type": "tool-error", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -14370,7 +14371,7 @@ describe('streamText', () => { "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -14604,7 +14605,7 @@ describe('streamText', () => { ], "dynamicToolCalls": [], "dynamicToolResults": [], - "experimental_context": undefined, + "experimental_context": {}, "files": [], "finishReason": "stop", "functionId": undefined, @@ -14724,7 +14725,7 @@ describe('streamText', () => { "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -14954,7 +14955,7 @@ describe('streamText', () => { "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -15461,7 +15462,7 @@ describe('streamText', () => { "type": "text", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -15906,7 +15907,7 @@ describe('streamText', () => { ], "dynamicToolCalls": [], "dynamicToolResults": [], - "experimental_context": undefined, + "experimental_context": {}, "files": [], "finishReason": "stop", "functionId": undefined, @@ -15952,7 +15953,7 @@ describe('streamText', () => { "type": "text", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -16045,7 +16046,7 @@ describe('streamText', () => { }); describe('array output', () => { - let result: StreamTextResult | undefined; + let result: StreamTextResult | undefined; let onFinishResult: | Parameters[0]>['onFinish']>[0] @@ -16796,7 +16797,7 @@ describe('streamText', () => { describe('mixed multi content streaming with interleaving parts', () => { describe('mixed text and reasoning blocks', () => { - let result: StreamTextResult; + let result: StreamTextResult; beforeEach(async () => { result = streamText({ @@ -17048,7 +17049,7 @@ describe('streamText', () => { "type": "reasoning", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -17120,9 +17121,9 @@ describe('streamText', () => { describe('abort signal', () => { describe('basic abort', () => { - let result: StreamTextResult; + let result: StreamTextResult; let onErrorCalls: Array<{ error: unknown }> = []; - let onAbortCalls: Array<{ steps: StepResult[] }> = []; + let onAbortCalls: Array<{ steps: StepResult[] }> = []; beforeEach(() => { onErrorCalls = []; @@ -17301,9 +17302,9 @@ describe('streamText', () => { }); describe('abort in 2nd step', () => { - let result: StreamTextResult; + let result: StreamTextResult; let onErrorCalls: Array<{ error: unknown }> = []; - let onAbortCalls: Array<{ steps: StepResult[] }> = []; + let onAbortCalls: Array<{ steps: StepResult[] }> = []; beforeEach(() => { onErrorCalls = []; @@ -17439,7 +17440,7 @@ describe('streamText', () => { "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "tool-calls", "functionId": undefined, "metadata": undefined, @@ -17619,9 +17620,9 @@ describe('streamText', () => { }); describe('abort during tool execution', () => { - let result: StreamTextResult; + let result: StreamTextResult; let onErrorCalls: Array<{ error: unknown }> = []; - let onAbortCalls: Array<{ steps: StepResult[] }> = []; + let onAbortCalls: Array<{ steps: StepResult[] }> = []; beforeEach(() => { onErrorCalls = []; @@ -17906,7 +17907,7 @@ describe('streamText', () => { describe('invalid tool calls', () => { describe('single invalid tool call', () => { - let result: StreamTextResult; + let result: StreamTextResult; beforeEach(async () => { result = streamText({ @@ -18203,7 +18204,7 @@ describe('streamText', () => { describe('tools with preliminary results', () => { describe('single tool with preliminary results', () => { - let result: StreamTextResult; + let result: StreamTextResult; beforeEach(async () => { result = streamText({ @@ -18494,7 +18495,7 @@ describe('streamText', () => { "type": "tool-result", }, ], - "experimental_context": undefined, + "experimental_context": {}, "finishReason": "stop", "functionId": undefined, "metadata": undefined, @@ -18574,7 +18575,7 @@ describe('streamText', () => { describe('provider-executed dynamic tools', () => { describe('single provider-executed dynamic tool with input streaming', () => { - let result: StreamTextResult; + let result: StreamTextResult; beforeEach(async () => { result = streamText({ @@ -18882,14 +18883,14 @@ describe('streamText', () => { describe('programmatic tool calling', () => { describe('5 steps: code_execution triggers client tool across multiple turns (dice game fixture)', () => { - let result: StreamTextResult; - let onFinishResult: Parameters>[0]; - let onStepFinishResults: StepResult[]; + let result: StreamTextResult; + let onFinishResult: Parameters>[0]; + let onStepFinishResults: StepResult[]; let doStreamCalls: Array; let prepareStepCalls: Array<{ modelId: string; stepNumber: number; - steps: Array>; + steps: Array>; messages: Array; }>; let rollDieExecutions: Array<{ player: string }>; @@ -20788,7 +20789,7 @@ describe('streamText', () => { describe('tool execution approval', () => { describe('when a single tool needs approval', () => { - let result: StreamTextResult; + let result: StreamTextResult; beforeEach(async () => { result = streamText({ @@ -21012,7 +21013,7 @@ describe('streamText', () => { }); describe('when a single tool has a needsApproval function', () => { - let result: StreamTextResult; + let result: StreamTextResult; let needsApprovalCalls: Array<{ input: any; options: any }> = []; beforeEach(async () => { @@ -21340,9 +21341,9 @@ describe('streamText', () => { }); describe('when a call from a single tool that needs approval is approved', () => { - let result: StreamTextResult; + let result: StreamTextResult; let prompts: LanguageModelV4Prompt[]; - let executeFunction: ToolExecuteFunction; + let executeFunction: ToolExecuteFunction; beforeEach(async () => { prompts = []; @@ -21644,7 +21645,7 @@ describe('streamText', () => { }); describe('when a call from a single tool that needs approval is approved and the tool throws', () => { - let result: StreamTextResult; + let result: StreamTextResult; let prompts: LanguageModelV4Prompt[]; beforeEach(async () => { @@ -21767,7 +21768,7 @@ describe('streamText', () => { }); describe('when a call from a single tool with preliminary results that needs approval is approved', () => { - let result: StreamTextResult; + let result: StreamTextResult; let prompts: LanguageModelV4Prompt[]; beforeEach(async () => { @@ -22097,9 +22098,9 @@ describe('streamText', () => { }); describe('when a call from a single tool that needs approval is denied', () => { - let result: StreamTextResult; + let result: StreamTextResult; let prompts: LanguageModelV4Prompt[]; - let executeFunction: ToolExecuteFunction; + let executeFunction: ToolExecuteFunction; beforeEach(async () => { prompts = []; @@ -22389,7 +22390,7 @@ describe('streamText', () => { describe('provider-executed tool (MCP) approval', () => { describe('when a provider-executed tool emits tool-approval-request', () => { - let result: StreamTextResult; + let result: StreamTextResult; beforeEach(async () => { result = streamText({ @@ -22624,7 +22625,7 @@ describe('streamText', () => { }); describe('when a provider-executed tool approval is approved', () => { - let result: StreamTextResult; + let result: StreamTextResult; let prompts: LanguageModelV4Prompt[]; beforeEach(async () => { @@ -22813,7 +22814,7 @@ describe('streamText', () => { }); describe('when a provider-executed tool approval is denied', () => { - let result: StreamTextResult; + let result: StreamTextResult; let prompts: LanguageModelV4Prompt[]; beforeEach(async () => { diff --git a/packages/ai/src/generate-text/stream-text.ts b/packages/ai/src/generate-text/stream-text.ts index e190b365500c..455bbefd6b9f 100644 --- a/packages/ai/src/generate-text/stream-text.ts +++ b/packages/ai/src/generate-text/stream-text.ts @@ -84,6 +84,7 @@ import type { } from './core-events'; import { createExecuteToolsTransformation } from './create-execute-tools-transformation'; import { executeToolCall } from './execute-tool-call'; +import { filterActiveTools } from './filter-active-tool'; import { invokeToolCallbacksFromStream } from './invoke-tool-callbacks-from-stream'; import { Output, text } from './output'; import { @@ -96,8 +97,8 @@ import { convertToReasoningOutputs } from './reasoning-output'; import { ResponseMessage } from './response-message'; import { DefaultStepResult, StepResult } from './step-result'; import { - isStopConditionMet, isStepCount, + isStopConditionMet, StopCondition, } from './stop-condition'; import { ModelCallStreamPart, streamModelCall } from './stream-model-call'; @@ -112,8 +113,8 @@ import { TypedToolCall } from './tool-call'; import { ToolCallRepairFunction } from './tool-call-repair-function'; import { ToolOutput } from './tool-output'; import { StaticToolOutputDenied } from './tool-output-denied'; -import { ToolSet } from './tool-set'; -import { filterActiveTools } from './filter-active-tool'; +import type { GenerationContext } from './generation-context'; +import type { ToolSet } from './tool-set'; const originalGenerateId = createIdGenerator({ prefix: 'aitxt', @@ -150,9 +151,10 @@ export type StreamTextOnErrorCallback = (event: { * * @param stepResult - The result of the step. */ -export type StreamTextOnStepFinishCallback = ( - event: OnStepFinishEvent, -) => PromiseLike | void; +export type StreamTextOnStepFinishCallback< + TOOLS extends ToolSet, + CONTEXT extends GenerationContext, +> = (event: OnStepFinishEvent) => PromiseLike | void; /** * Callback that is set using the `onChunk` option. @@ -182,20 +184,24 @@ export type StreamTextOnChunkCallback = (event: { * * @param event - The event that is passed to the callback. */ -export type StreamTextOnFinishCallback = ( - event: OnFinishEvent, -) => PromiseLike | void; +export type StreamTextOnFinishCallback< + TOOLS extends ToolSet, + CONTEXT extends GenerationContext, +> = (event: OnFinishEvent) => PromiseLike | void; /** * Callback that is set using the `onAbort` option. * * @param event - The event that is passed to the callback. */ -export type StreamTextOnAbortCallback = (event: { +export type StreamTextOnAbortCallback< + TOOLS extends ToolSet, + CONTEXT extends GenerationContext, +> = (event: { /** * Details for all previously finished steps. */ - readonly steps: StepResult[]; + readonly steps: StepResult[]; }) => PromiseLike | void; /** @@ -214,9 +220,10 @@ type StreamTextIncludeSettings = { requestBody?: boolean }; */ export type StreamTextOnStartCallback< TOOLS extends ToolSet = ToolSet, + CONTEXT extends GenerationContext = GenerationContext, OUTPUT extends Output = Output, > = ( - event: OnStartEvent, + event: OnStartEvent, ) => PromiseLike | void; /** @@ -230,9 +237,10 @@ export type StreamTextOnStartCallback< */ export type StreamTextOnStepStartCallback< TOOLS extends ToolSet = ToolSet, + CONTEXT extends GenerationContext = GenerationContext, OUTPUT extends Output = Output, > = ( - event: OnStepStartEvent, + event: OnStepStartEvent, ) => PromiseLike | void; export type StreamTextOnToolCallStartCallback = @@ -290,6 +298,7 @@ export type StreamTextOnToolCallFinishCallback< */ export function streamText< TOOLS extends ToolSet, + CONTEXT extends GenerationContext = GenerationContext, OUTPUT extends Output = Output, >({ model, @@ -325,7 +334,7 @@ export function streamText< experimental_onStepStart: onStepStart, experimental_onToolCallStart: onToolCallStart, experimental_onToolCallFinish: onToolCallFinish, - experimental_context, + experimental_context = {} as CONTEXT, experimental_include: include, _internal: { now = originalNow, @@ -365,8 +374,8 @@ export function streamText< * @default isStepCount(1) */ stopWhen?: - | StopCondition> - | Array>>; + | StopCondition, CONTEXT> + | Array, CONTEXT>>; /** * Optional telemetry configuration (experimental). @@ -414,7 +423,7 @@ export function streamText< * @returns An object that contains the settings for the step. * If you return undefined (or for undefined settings), the settings from the outer level will be used. */ - prepareStep?: PrepareStepFunction>; + prepareStep?: PrepareStepFunction, CONTEXT>; /** * A function that attempts to repair a tool call that failed to parse. @@ -464,20 +473,27 @@ export function streamText< * * The usage is the combined usage of all steps. */ - onFinish?: StreamTextOnFinishCallback; + onFinish?: StreamTextOnFinishCallback, NoInfer>; - onAbort?: StreamTextOnAbortCallback; + onAbort?: StreamTextOnAbortCallback, NoInfer>; /** * Callback that is called when each step (LLM call) is finished, including intermediate steps. */ - onStepFinish?: StreamTextOnStepFinishCallback; + onStepFinish?: StreamTextOnStepFinishCallback< + NoInfer, + NoInfer + >; /** * Callback that is called when the streamText operation begins, * before any LLM calls are made. */ - experimental_onStart?: StreamTextOnStartCallback, OUTPUT>; + experimental_onStart?: StreamTextOnStartCallback< + NoInfer, + NoInfer, + NoInfer + >; /** * Callback that is called when a step (LLM call) begins, @@ -485,7 +501,8 @@ export function streamText< */ experimental_onStepStart?: StreamTextOnStepStartCallback< NoInfer, - OUTPUT + NoInfer, + NoInfer >; /** @@ -509,7 +526,7 @@ export function streamText< * * @default undefined */ - experimental_context?: unknown; + experimental_context?: CONTEXT; /** * Settings for controlling what data is included in step results. @@ -535,7 +552,7 @@ export function streamText< generateId?: IdGenerator; generateCallId?: IdGenerator; }; - }): StreamTextResult { + }): StreamTextResult { const totalTimeoutMs = getTotalTimeoutMs(timeout); const stepTimeoutMs = getStepTimeoutMs(timeout); const chunkTimeoutMs = getChunkTimeoutMs(timeout); @@ -543,7 +560,7 @@ export function streamText< stepTimeoutMs != null ? new AbortController() : undefined; const chunkAbortController = chunkTimeoutMs != null ? new AbortController() : undefined; - return new DefaultStreamTextResult({ + return new DefaultStreamTextResult({ model: resolveLanguageModel(model), telemetry, headers, @@ -697,19 +714,20 @@ function createOutputTransformStream< class DefaultStreamTextResult< TOOLS extends ToolSet, + CONTEXT extends GenerationContext, OUTPUT extends Output, -> implements StreamTextResult { +> implements StreamTextResult { private readonly _totalUsage = new DelayedPromise< - Awaited['usage']> + Awaited['usage']> >(); private readonly _finishReason = new DelayedPromise< - Awaited['finishReason']> + Awaited['finishReason']> >(); private readonly _rawFinishReason = new DelayedPromise< - Awaited['rawFinishReason']> + Awaited['rawFinishReason']> >(); private readonly _steps = new DelayedPromise< - Awaited['steps']> + Awaited['steps']> >(); private readonly addStream: ( @@ -789,32 +807,52 @@ class DefaultStreamTextResult< transforms: Array>; activeTools: Array | undefined; repairToolCall: ToolCallRepairFunction | undefined; - stopConditions: Array>>; + stopConditions: Array, NoInfer>>; output: OUTPUT | undefined; providerOptions: ProviderOptions | undefined; - prepareStep: PrepareStepFunction> | undefined; + prepareStep: + | PrepareStepFunction, NoInfer> + | undefined; includeRawChunks: boolean; now: () => number; generateId: () => string; generateCallId: () => string; timeout: TimeoutConfiguration | undefined; stopWhen: - | StopCondition> - | Array>> + | StopCondition, NoInfer> + | Array, NoInfer>> | undefined; originalAbortSignal: AbortSignal | undefined; - experimental_context: unknown; + experimental_context: CONTEXT; download: DownloadFunction | undefined; include: { requestBody?: boolean } | undefined; // callbacks: onChunk: undefined | StreamTextOnChunkCallback; onError: StreamTextOnErrorCallback; - onFinish: undefined | StreamTextOnFinishCallback; - onAbort: undefined | StreamTextOnAbortCallback; - onStepFinish: undefined | StreamTextOnStepFinishCallback; - onStart: undefined | StreamTextOnStartCallback; - onStepStart: undefined | StreamTextOnStepStartCallback; + onFinish: + | undefined + | StreamTextOnFinishCallback, NoInfer>; + onAbort: + | undefined + | StreamTextOnAbortCallback, NoInfer>; + onStepFinish: + | undefined + | StreamTextOnStepFinishCallback, NoInfer>; + onStart: + | undefined + | StreamTextOnStartCallback< + NoInfer, + NoInfer, + NoInfer + >; + onStepStart: + | undefined + | StreamTextOnStepStartCallback< + NoInfer, + NoInfer, + NoInfer + >; onToolCallStart: undefined | StreamTextOnToolCallStartCallback; onToolCallFinish: undefined | StreamTextOnToolCallFinishCallback; }) { @@ -842,7 +880,7 @@ class DefaultStreamTextResult< let recordedTotalUsage: LanguageModelUsage | undefined = undefined; let recordedRequest: LanguageModelRequestMetadata = {}; let recordedWarnings: Array = []; - const recordedSteps: StepResult[] = []; + const recordedSteps: StepResult[] = []; // Track provider-executed tool calls that support deferred results // (e.g., code_execution in programmatic tool calling scenarios). @@ -1043,25 +1081,26 @@ class DefaultStreamTextResult< }); // Add step information (after response messages are updated): - const currentStepResult: StepResult = new DefaultStepResult({ - callId, - stepNumber: recordedSteps.length, - provider: model.provider, - modelId: model.modelId, - ...callbackTelemetryProps, - experimental_context, - content: recordedContent, - finishReason: part.finishReason, - rawFinishReason: part.rawFinishReason, - usage: part.usage, - warnings: recordedWarnings, - request: recordedRequest, - response: { - ...part.response, - messages: [...recordedResponseMessages, ...stepMessages], - }, - providerMetadata: part.providerMetadata, - }); + const currentStepResult: StepResult = + new DefaultStepResult({ + callId, + stepNumber: recordedSteps.length, + provider: model.provider, + modelId: model.modelId, + ...callbackTelemetryProps, + experimental_context, + content: recordedContent, + finishReason: part.finishReason, + rawFinishReason: part.rawFinishReason, + usage: part.usage, + warnings: recordedWarnings, + request: recordedRequest, + response: { + ...part.response, + messages: [...recordedResponseMessages, ...stepMessages], + }, + providerMetadata: part.providerMetadata, + }); await notify({ event: currentStepResult, @@ -1157,7 +1196,7 @@ class DefaultStreamTextResult< onFinish, globalTelemetry.onFinish as | undefined - | StreamTextOnFinishCallback, + | StreamTextOnFinishCallback, NoInfer>, ], }); } catch (error) { @@ -1319,7 +1358,7 @@ class DefaultStreamTextResult< onStart, globalTelemetry.onStart as | undefined - | StreamTextOnStartCallback, + | StreamTextOnStartCallback, ], }); @@ -1614,7 +1653,7 @@ class DefaultStreamTextResult< onStepStart, globalTelemetry.onStepStart as | undefined - | StreamTextOnStepStartCallback, + | StreamTextOnStepStartCallback, ], }); }, diff --git a/packages/ai/src/generate-text/tool-call.ts b/packages/ai/src/generate-text/tool-call.ts index b6f087a675f9..1d72085b9091 100644 --- a/packages/ai/src/generate-text/tool-call.ts +++ b/packages/ai/src/generate-text/tool-call.ts @@ -1,4 +1,4 @@ -import { Tool } from '@ai-sdk/provider-utils'; +import { InferToolInput } from '@ai-sdk/provider-utils'; import { ProviderMetadata } from '../types'; import { ValueOf } from '../util/value-of'; import { ToolSet } from './tool-set'; @@ -13,7 +13,7 @@ type BaseToolCall = { export type StaticToolCall = ValueOf<{ [NAME in keyof TOOLS]: BaseToolCall & { toolName: NAME & string; - input: TOOLS[NAME] extends Tool ? PARAMETERS : never; + input: InferToolInput; dynamic?: false | undefined; invalid?: false | undefined; error?: never; diff --git a/packages/ai/src/generate-text/tool-set.ts b/packages/ai/src/generate-text/tool-set.ts index 09e64774d48e..57f40d946f9c 100644 --- a/packages/ai/src/generate-text/tool-set.ts +++ b/packages/ai/src/generate-text/tool-set.ts @@ -1,10 +1,15 @@ -import { Tool } from '@ai-sdk/provider-utils'; +import type { Tool } from '@ai-sdk/provider-utils'; export type ToolSet = Record< string, - (Tool | Tool | Tool | Tool) & + ( + | Tool + | Tool + | Tool + | Tool + ) & Pick< - Tool, + Tool, | 'execute' | 'onInputAvailable' | 'onInputStart' diff --git a/packages/ai/src/telemetry/open-telemetry-integration.ts b/packages/ai/src/telemetry/open-telemetry-integration.ts index 3b4fd00ea77e..f9a156937511 100644 --- a/packages/ai/src/telemetry/open-telemetry-integration.ts +++ b/packages/ai/src/telemetry/open-telemetry-integration.ts @@ -31,6 +31,7 @@ import type { OnToolCallStartEvent, } from '../generate-text/core-events'; import type { Output } from '../generate-text/output'; +import type { GenerationContext } from '../generate-text/generation-context'; import type { ToolSet } from '../generate-text/tool-set'; import type { ObjectOnStartEvent, @@ -115,8 +116,9 @@ function selectAttributes( interface OtelStepStartEvent< TOOLS extends ToolSet = ToolSet, + CONTEXT extends GenerationContext = GenerationContext, OUTPUT extends Output = Output, -> extends OnStepStartEvent { +> extends OnStepStartEvent { readonly promptMessages?: LanguageModelV4Prompt; readonly stepTools?: ReadonlyArray>; readonly stepToolChoice?: unknown; diff --git a/packages/ai/src/telemetry/telemetry-integration.ts b/packages/ai/src/telemetry/telemetry-integration.ts index 1d7703ec0a0d..ff720a90a44f 100644 --- a/packages/ai/src/telemetry/telemetry-integration.ts +++ b/packages/ai/src/telemetry/telemetry-integration.ts @@ -1,3 +1,15 @@ +import type { + EmbedFinishEvent, + EmbedOnFinishEvent, + EmbedOnStartEvent, + EmbedStartEvent, +} from '../embed/embed-events'; +import type { + ObjectOnFinishEvent, + ObjectOnStartEvent, + ObjectOnStepFinishEvent, + ObjectOnStepStartEvent, +} from '../generate-object/structured-output-events'; import type { OnChunkEvent, OnFinishEvent, @@ -10,22 +22,10 @@ import type { import type { Output } from '../generate-text/output'; import type { ToolSet } from '../generate-text/tool-set'; import type { - EmbedOnStartEvent, - EmbedOnFinishEvent, - EmbedStartEvent, - EmbedFinishEvent, -} from '../embed/embed-events'; -import type { - ObjectOnStartEvent, - ObjectOnFinishEvent, - ObjectOnStepStartEvent, - ObjectOnStepFinishEvent, -} from '../generate-object/structured-output-events'; -import type { - RerankOnStartEvent, + RerankFinishEvent, RerankOnFinishEvent, + RerankOnStartEvent, RerankStartEvent, - RerankFinishEvent, } from '../rerank/rerank-events'; import { Listener } from '../util/notify'; diff --git a/packages/ai/src/ui/direct-chat-transport.ts b/packages/ai/src/ui/direct-chat-transport.ts index a795eadbda7d..52a999944f8a 100644 --- a/packages/ai/src/ui/direct-chat-transport.ts +++ b/packages/ai/src/ui/direct-chat-transport.ts @@ -1,8 +1,9 @@ +import { Agent } from '../agent/agent'; import { Output } from '../generate-text/output'; import { UIMessageStreamOptions } from '../generate-text/stream-text-result'; -import { ToolSet } from '../generate-text/tool-set'; +import type { GenerationContext } from '../generate-text/generation-context'; +import type { ToolSet } from '../generate-text/tool-set'; import { UIMessageChunk } from '../ui-message-stream/ui-message-chunks'; -import { Agent } from '../agent/agent'; import { ChatTransport } from './chat-transport'; import { convertToModelMessages } from './convert-to-model-messages'; import { InferUITools, UIMessage } from './ui-messages'; @@ -14,13 +15,14 @@ import { validateUIMessages } from './validate-ui-messages'; export type DirectChatTransportOptions< CALL_OPTIONS, TOOLS extends ToolSet, - OUTPUT extends Output, + CONTEXT extends GenerationContext, + OUTPUT extends Output, UI_MESSAGE extends UIMessage>, > = { /** * The agent to use for generating responses. */ - agent: Agent; + agent: Agent; /** * Options to pass to the agent when calling it. @@ -49,14 +51,15 @@ export type DirectChatTransportOptions< export class DirectChatTransport< CALL_OPTIONS = never, TOOLS extends ToolSet = {}, - OUTPUT extends Output = never, + CONTEXT extends GenerationContext = GenerationContext, + OUTPUT extends Output = never, UI_MESSAGE extends UIMessage> = UIMessage< unknown, never, InferUITools >, > implements ChatTransport { - private readonly agent: Agent; + private readonly agent: Agent; private readonly agentOptions: CALL_OPTIONS | undefined; private readonly uiMessageStreamOptions: Omit< UIMessageStreamOptions, @@ -67,7 +70,13 @@ export class DirectChatTransport< agent, options, ...uiMessageStreamOptions - }: DirectChatTransportOptions) { + }: DirectChatTransportOptions< + CALL_OPTIONS, + TOOLS, + CONTEXT, + OUTPUT, + UI_MESSAGE + >) { this.agent = agent; this.agentOptions = options; this.uiMessageStreamOptions = uiMessageStreamOptions; @@ -97,7 +106,7 @@ export class DirectChatTransport< ...(this.agentOptions !== undefined ? { options: this.agentOptions } : {}), - } as Parameters['stream']>[0]); + } as Parameters['stream']>[0]); // Return the UI message stream return result.toUIMessageStream(this.uiMessageStreamOptions); diff --git a/packages/ai/src/util/union-to-intersection.test-d.ts b/packages/ai/src/util/union-to-intersection.test-d.ts new file mode 100644 index 000000000000..3a94464b5c58 --- /dev/null +++ b/packages/ai/src/util/union-to-intersection.test-d.ts @@ -0,0 +1,32 @@ +import { describe, expectTypeOf, it } from 'vitest'; +import type { UnionToIntersection } from './union-to-intersection'; + +describe('UnionToIntersection', () => { + it('returns never when given no input', () => { + type Result = UnionToIntersection; + + expectTypeOf().toEqualTypeOf(); + }); + + it('returns the same type for a single input', () => { + type Result = UnionToIntersection<{ city: string }>; + + expectTypeOf().toEqualTypeOf<{ + city: string; + }>(); + }); + + it('converts a union of object types into an intersection', () => { + type Result = UnionToIntersection< + { city: string } | { countryCode: string } + >; + + expectTypeOf().toEqualTypeOf< + { + city: string; + } & { + countryCode: string; + } + >(); + }); +}); diff --git a/packages/ai/src/util/union-to-intersection.ts b/packages/ai/src/util/union-to-intersection.ts new file mode 100644 index 000000000000..63186a2bde79 --- /dev/null +++ b/packages/ai/src/util/union-to-intersection.ts @@ -0,0 +1,17 @@ +/** + * Converts a union type `U` into an intersection type. + * + * For example: + * type A = { a: number }; + * type B = { b: string }; + * type Union = A | B; + * type Intersection = UnionToIntersection; + * // Intersection is: { a: number } & { b: string } + * + * This is useful when you have a union of object types and need a type with all possible properties. + */ +export type UnionToIntersection = ( + U extends unknown ? (arg: U) => void : never +) extends (arg: infer I) => void + ? I + : never; diff --git a/packages/mcp/src/tool/mcp-client.test.ts b/packages/mcp/src/tool/mcp-client.test.ts index 629d785b00cc..7747b0f72377 100644 --- a/packages/mcp/src/tool/mcp-client.test.ts +++ b/packages/mcp/src/tool/mcp-client.test.ts @@ -76,6 +76,7 @@ describe('MCPClient', () => { { messages: [], toolCallId: '1', + experimental_context: {}, }, ), ).toMatchInlineSnapshot(` @@ -135,7 +136,7 @@ describe('MCPClient', () => { // Verify the execute function works const result = await tool.execute( { foo: 'bar' }, - { messages: [], toolCallId: '1' }, + { messages: [], toolCallId: '1', experimental_context: {} }, ); expect(result).toMatchObject({ content: [{ type: 'text', text: 'Mock tool call result' }], @@ -192,8 +193,12 @@ describe('MCPClient', () => { const tools = await client.tools(); const tool = tools['get-image']; - expect(await tool.execute!({}, { messages: [], toolCallId: '1' })) - .toMatchInlineSnapshot(` + expect( + await tool.execute!( + {}, + { messages: [], toolCallId: '1', experimental_context: {} }, + ), + ).toMatchInlineSnapshot(` { "content": [ { @@ -268,8 +273,12 @@ describe('MCPClient', () => { const tools = await client.tools(); const tool = tools['get-text']; - expect(await tool.execute!({}, { messages: [], toolCallId: '1' })) - .toMatchInlineSnapshot(` + expect( + await tool.execute!( + {}, + { messages: [], toolCallId: '1', experimental_context: {} }, + ), + ).toMatchInlineSnapshot(` { "content": [ { @@ -334,8 +343,12 @@ describe('MCPClient', () => { const tools = await client.tools(); const tool = tools['get-mixed']; - expect(await tool.execute!({}, { messages: [], toolCallId: '1' })) - .toMatchInlineSnapshot(` + expect( + await tool.execute!( + {}, + { messages: [], toolCallId: '1', experimental_context: {} }, + ), + ).toMatchInlineSnapshot(` { "content": [ { @@ -411,8 +424,12 @@ describe('MCPClient', () => { const tools = await client.tools(); const tool = tools['get-unknown']; - expect(await tool.execute!({}, { messages: [], toolCallId: '1' })) - .toMatchInlineSnapshot(` + expect( + await tool.execute!( + {}, + { messages: [], toolCallId: '1', experimental_context: {} }, + ), + ).toMatchInlineSnapshot(` { "content": [ { @@ -477,8 +494,12 @@ describe('MCPClient', () => { const tools = await client.tools(); const tool = tools['get-raw']; - expect(await tool.execute!({}, { messages: [], toolCallId: '1' })) - .toMatchInlineSnapshot(` + expect( + await tool.execute!( + {}, + { messages: [], toolCallId: '1', experimental_context: {} }, + ), + ).toMatchInlineSnapshot(` { "isError": false, "toolResult": undefined, @@ -716,6 +737,7 @@ describe('MCPClient', () => { { messages: [], toolCallId: '1', + experimental_context: {}, }, ); @@ -758,7 +780,10 @@ describe('MCPClient', () => { }); const toolCall = tools['mock-tool'].execute; await expect( - toolCall({ bar: 'bar' }, { messages: [], toolCallId: '1' }), + toolCall( + { bar: 'bar' }, + { messages: [], toolCallId: '1', experimental_context: {} }, + ), ).rejects.toThrow(MCPClientError); }); @@ -782,7 +807,10 @@ describe('MCPClient', () => { const toolCall = tools['mock-tool'].execute; try { - await toolCall({ bar: 'bar' }, { messages: [], toolCallId: '1' }); + await toolCall( + { bar: 'bar' }, + { messages: [], toolCallId: '1', experimental_context: {} }, + ); throw new Error('Expected error to be thrown'); } catch (error) { expect(MCPClientError.isInstance(error)).toBe(true); @@ -940,6 +968,7 @@ describe('MCPClient', () => { messages: [], toolCallId: '1', abortSignal: abortController.signal, + experimental_context: {}, }, ), ).rejects.toSatisfy( @@ -1054,6 +1083,7 @@ describe('MCPClient', () => { { messages: [], toolCallId: '1', + experimental_context: {}, }, ); @@ -1091,7 +1121,10 @@ describe('MCPClient', () => { }, }); - const result = await tool.execute({}, { messages: [], toolCallId: '1' }); + const result = await tool.execute( + {}, + { messages: [], toolCallId: '1', experimental_context: {} }, + ); expect(result).toMatchInlineSnapshot(` { "content": [ @@ -1209,7 +1242,7 @@ describe('MCPClient', () => { const result = await tool.execute( { location: 'New York' }, - { messages: [], toolCallId: '1' }, + { messages: [], toolCallId: '1', experimental_context: {} }, ); expectTypeOf>>().toEqualTypeOf<{ @@ -1265,7 +1298,7 @@ describe('MCPClient', () => { const result = await tools['json-tool'].execute( {}, - { messages: [], toolCallId: '1' }, + { messages: [], toolCallId: '1', experimental_context: {} }, ); expect(result).toEqual({ @@ -1319,7 +1352,7 @@ describe('MCPClient', () => { const result = await tool.execute( { input: 'test' }, - { messages: [], toolCallId: '1' }, + { messages: [], toolCallId: '1', experimental_context: {} }, ); expectTypeOf< @@ -1380,7 +1413,10 @@ describe('MCPClient', () => { }); await expect( - tools['bad-output-tool'].execute({}, { messages: [], toolCallId: '1' }), + tools['bad-output-tool'].execute( + {}, + { messages: [], toolCallId: '1', experimental_context: {} }, + ), ).rejects.toThrow(MCPClientError); }); @@ -1426,7 +1462,7 @@ describe('MCPClient', () => { await expect( tools['invalid-json-tool'].execute( {}, - { messages: [], toolCallId: '1' }, + { messages: [], toolCallId: '1', experimental_context: {} }, ), ).rejects.toThrow(MCPClientError); }); @@ -1473,7 +1509,7 @@ describe('MCPClient', () => { await expect( tools['mismatched-json-tool'].execute( {}, - { messages: [], toolCallId: '1' }, + { messages: [], toolCallId: '1', experimental_context: {} }, ), ).rejects.toThrow(MCPClientError); }); @@ -1487,7 +1523,7 @@ describe('MCPClient', () => { const result = await tools['mock-tool'].execute( { foo: 'bar' }, - { messages: [], toolCallId: '1' }, + { messages: [], toolCallId: '1', experimental_context: {} }, ); // With automatic discovery, result is CallToolResult @@ -1568,7 +1604,7 @@ describe('MCPClient', () => { const result = await tools['complex-tool'].execute( {}, - { messages: [], toolCallId: '1' }, + { messages: [], toolCallId: '1', experimental_context: {} }, ); expect(result).toEqual({ diff --git a/packages/mcp/src/tool/mcp-client.ts b/packages/mcp/src/tool/mcp-client.ts index debcaf497b0c..978fb44ed9a7 100644 --- a/packages/mcp/src/tool/mcp-client.ts +++ b/packages/mcp/src/tool/mcp-client.ts @@ -420,7 +420,7 @@ class DefaultMCPClient implements MCPClient { }: { name: string; args: Record; - options?: ToolExecutionOptions; + options?: ToolExecutionOptions<{}>; }): Promise { try { return this.request({ @@ -579,7 +579,7 @@ class DefaultMCPClient implements MCPClient { const execute = async ( args: any, - options: ToolExecutionOptions, + options: ToolExecutionOptions<{}>, ): Promise => { options?.abortSignal?.throwIfAborted(); const result = await self.callTool({ name, args, options }); diff --git a/packages/openai/src/tool/local-shell.test-d.ts b/packages/openai/src/tool/local-shell.test-d.ts index 45b198d660a4..8dc6ef798de7 100644 --- a/packages/openai/src/tool/local-shell.test-d.ts +++ b/packages/openai/src/tool/local-shell.test-d.ts @@ -13,7 +13,8 @@ describe('local-shell tool type', () => { expectTypeOf(localShellTool).toEqualTypeOf< Tool< InferSchema, - InferSchema + InferSchema, + {} > >(); }); diff --git a/packages/openai/src/tool/web-search.test-d.ts b/packages/openai/src/tool/web-search.test-d.ts index fc78329bf87f..8c0b8e86dc62 100644 --- a/packages/openai/src/tool/web-search.test-d.ts +++ b/packages/openai/src/tool/web-search.test-d.ts @@ -7,7 +7,7 @@ describe('web-search tool type', () => { const webSearchTool = webSearch(); expectTypeOf(webSearchTool).toEqualTypeOf< - Tool<{}, InferSchema> + Tool<{}, InferSchema, {}> >(); }); }); diff --git a/packages/provider-utils/src/provider-tool-factory.ts b/packages/provider-utils/src/provider-tool-factory.ts index 2c29130e0a51..77e63ec5bfd1 100644 --- a/packages/provider-utils/src/provider-tool-factory.ts +++ b/packages/provider-utils/src/provider-tool-factory.ts @@ -1,24 +1,33 @@ import { tool, Tool, ToolExecuteFunction } from './types/tool'; import { FlexibleSchema } from './schema'; +import { Context } from './types/context'; -export type ProviderToolFactory = ( +export type ProviderToolFactory< + INPUT, + ARGS extends object, + CONTEXT extends Context = {}, +> = ( options: ARGS & { - execute?: ToolExecuteFunction; - needsApproval?: Tool['needsApproval']; - toModelOutput?: Tool['toModelOutput']; - onInputStart?: Tool['onInputStart']; - onInputDelta?: Tool['onInputDelta']; - onInputAvailable?: Tool['onInputAvailable']; + execute?: ToolExecuteFunction; + needsApproval?: Tool['needsApproval']; + toModelOutput?: Tool['toModelOutput']; + onInputStart?: Tool['onInputStart']; + onInputDelta?: Tool['onInputDelta']; + onInputAvailable?: Tool['onInputAvailable']; }, -) => Tool; +) => Tool; -export function createProviderToolFactory({ +export function createProviderToolFactory< + INPUT, + ARGS extends object, + CONTEXT extends Context = {}, +>({ id, inputSchema, }: { id: `${string}.${string}`; inputSchema: FlexibleSchema; -}): ProviderToolFactory { +}): ProviderToolFactory { return ({ execute, outputSchema, @@ -29,14 +38,14 @@ export function createProviderToolFactory({ onInputAvailable, ...args }: ARGS & { - execute?: ToolExecuteFunction; + execute?: ToolExecuteFunction; outputSchema?: FlexibleSchema; - needsApproval?: Tool['needsApproval']; - toModelOutput?: Tool['toModelOutput']; - onInputStart?: Tool['onInputStart']; - onInputDelta?: Tool['onInputDelta']; - onInputAvailable?: Tool['onInputAvailable']; - }): Tool => + needsApproval?: Tool['needsApproval']; + toModelOutput?: Tool['toModelOutput']; + onInputStart?: Tool['onInputStart']; + onInputDelta?: Tool['onInputDelta']; + onInputAvailable?: Tool['onInputAvailable']; + }): Tool => tool({ type: 'provider', id, @@ -56,21 +65,23 @@ export type ProviderToolFactoryWithOutputSchema< INPUT, OUTPUT, ARGS extends object, + CONTEXT extends Context = {}, > = ( options: ARGS & { - execute?: ToolExecuteFunction; - needsApproval?: Tool['needsApproval']; - toModelOutput?: Tool['toModelOutput']; - onInputStart?: Tool['onInputStart']; - onInputDelta?: Tool['onInputDelta']; - onInputAvailable?: Tool['onInputAvailable']; + execute?: ToolExecuteFunction; + needsApproval?: Tool['needsApproval']; + toModelOutput?: Tool['toModelOutput']; + onInputStart?: Tool['onInputStart']; + onInputDelta?: Tool['onInputDelta']; + onInputAvailable?: Tool['onInputAvailable']; }, -) => Tool; +) => Tool; export function createProviderToolFactoryWithOutputSchema< INPUT, OUTPUT, ARGS extends object, + CONTEXT extends Context = {}, >({ id, inputSchema, @@ -91,7 +102,7 @@ export function createProviderToolFactoryWithOutputSchema< * @default false */ supportsDeferredResults?: boolean; -}): ProviderToolFactoryWithOutputSchema { +}): ProviderToolFactoryWithOutputSchema { return ({ execute, needsApproval, @@ -101,13 +112,13 @@ export function createProviderToolFactoryWithOutputSchema< onInputAvailable, ...args }: ARGS & { - execute?: ToolExecuteFunction; - needsApproval?: Tool['needsApproval']; - toModelOutput?: Tool['toModelOutput']; - onInputStart?: Tool['onInputStart']; - onInputDelta?: Tool['onInputDelta']; - onInputAvailable?: Tool['onInputAvailable']; - }): Tool => + execute?: ToolExecuteFunction; + needsApproval?: Tool['needsApproval']; + toModelOutput?: Tool['toModelOutput']; + onInputStart?: Tool['onInputStart']; + onInputDelta?: Tool['onInputDelta']; + onInputAvailable?: Tool['onInputAvailable']; + }): Tool => tool({ type: 'provider', id, diff --git a/packages/provider-utils/src/types/context.ts b/packages/provider-utils/src/types/context.ts new file mode 100644 index 000000000000..d595a8e7d3a3 --- /dev/null +++ b/packages/provider-utils/src/types/context.ts @@ -0,0 +1,4 @@ +/** + * A context object that is passed into tool execution. + */ +export type Context = Record; diff --git a/packages/provider-utils/src/types/execute-tool.ts b/packages/provider-utils/src/types/execute-tool.ts index 7780d582f4ce..41f83de3721c 100644 --- a/packages/provider-utils/src/types/execute-tool.ts +++ b/packages/provider-utils/src/types/execute-tool.ts @@ -1,14 +1,15 @@ import { isAsyncIterable } from '../is-async-iterable'; -import { ToolExecutionOptions, ToolExecuteFunction } from './tool'; +import { Context } from './context'; +import { ToolExecuteFunction, ToolExecutionOptions } from './tool'; -export async function* executeTool({ +export async function* executeTool({ execute, input, options, }: { - execute: ToolExecuteFunction; + execute: ToolExecuteFunction; input: INPUT; - options: ToolExecutionOptions; + options: ToolExecutionOptions>; }): AsyncGenerator< { type: 'preliminary'; output: OUTPUT } | { type: 'final'; output: OUTPUT } > { diff --git a/packages/provider-utils/src/types/index.ts b/packages/provider-utils/src/types/index.ts index 8f74b0d80e3b..e896a6e14521 100644 --- a/packages/provider-utils/src/types/index.ts +++ b/packages/provider-utils/src/types/index.ts @@ -13,19 +13,21 @@ export type { ToolResultOutput, ToolResultPart, } from './content-part'; +export type { Context } from './context'; export type { DataContent } from './data-content'; export { executeTool } from './execute-tool'; +export type { InferToolContext } from './infer-tool-context'; +export type { InferToolInput } from './infer-tool-input'; +export type { InferToolOutput } from './infer-tool-output'; export type { ModelMessage } from './model-message'; export type { ProviderOptions } from './provider-options'; export type { SystemModelMessage } from './system-model-message'; export { dynamicTool, tool, - type InferToolInput, - type InferToolOutput, type Tool, - type ToolExecutionOptions, type ToolExecuteFunction, + type ToolExecutionOptions, type ToolNeedsApprovalFunction, } from './tool'; export type { ToolApprovalRequest } from './tool-approval-request'; @@ -34,9 +36,12 @@ export type { ToolCall } from './tool-call'; export type { ToolContent, ToolModelMessage } from './tool-model-message'; export type { ToolResult } from './tool-result'; export type { UserContent, UserModelMessage } from './user-model-message'; + +import type { Context } from './context'; import type { ToolExecutionOptions } from './tool'; /** * @deprecated Use ToolExecutionOptions instead. */ -export type ToolCallOptions = ToolExecutionOptions; +export type ToolCallOptions = + ToolExecutionOptions; diff --git a/packages/provider-utils/src/types/infer-tool-context.test-d.ts b/packages/provider-utils/src/types/infer-tool-context.test-d.ts new file mode 100644 index 000000000000..e32503d7328e --- /dev/null +++ b/packages/provider-utils/src/types/infer-tool-context.test-d.ts @@ -0,0 +1,38 @@ +import { describe, expectTypeOf, it } from 'vitest'; +import { z } from 'zod/v4'; +import type { InferToolContext } from './infer-tool-context'; +import { tool } from './tool'; +import { Context } from './context'; + +describe('InferToolContext', () => { + it('infers the context type from a tool with contextSchema', () => { + const weatherTool = tool({ + inputSchema: z.object({ + city: z.string(), + }), + contextSchema: z.object({ + userId: z.string(), + role: z.string(), + }), + execute: async () => ({ temperature: 72 }), + }); + + expectTypeOf>().toEqualTypeOf<{ + userId: string; + role: string; + }>(); + }); + + it('infers the generic context type from a tool without contextSchema', () => { + const weatherTool = tool({ + inputSchema: z.object({ + city: z.string(), + }), + execute: async () => ({ temperature: 72 }), + }); + + expectTypeOf< + InferToolContext + >().toEqualTypeOf(); + }); +}); diff --git a/packages/provider-utils/src/types/infer-tool-context.ts b/packages/provider-utils/src/types/infer-tool-context.ts new file mode 100644 index 000000000000..294074e3f096 --- /dev/null +++ b/packages/provider-utils/src/types/infer-tool-context.ts @@ -0,0 +1,7 @@ +import type { Tool } from './tool'; + +/** + * Infer the context type of a tool. + */ +export type InferToolContext> = + TOOL extends Tool ? CONTEXT : never; diff --git a/packages/provider-utils/src/types/infer-tool-input.test-d.ts b/packages/provider-utils/src/types/infer-tool-input.test-d.ts new file mode 100644 index 000000000000..0045b7ac8779 --- /dev/null +++ b/packages/provider-utils/src/types/infer-tool-input.test-d.ts @@ -0,0 +1,21 @@ +import { describe, expectTypeOf, it } from 'vitest'; +import { z } from 'zod/v4'; +import type { InferToolInput } from './infer-tool-input'; +import { tool } from './tool'; + +describe('InferToolInput', () => { + it('infers the input type from a tool with inputSchema', () => { + const weatherTool = tool({ + inputSchema: z.object({ + city: z.string(), + countryCode: z.string().length(2), + }), + execute: async () => ({ temperature: 72 }), + }); + + expectTypeOf>().toEqualTypeOf<{ + city: string; + countryCode: string; + }>(); + }); +}); diff --git a/packages/provider-utils/src/types/infer-tool-input.ts b/packages/provider-utils/src/types/infer-tool-input.ts new file mode 100644 index 000000000000..8ae9c8678f8f --- /dev/null +++ b/packages/provider-utils/src/types/infer-tool-input.ts @@ -0,0 +1,7 @@ +import type { Tool } from './tool'; + +/** + * Infer the input type of a tool. + */ +export type InferToolInput> = + TOOL extends Tool ? INPUT : never; diff --git a/packages/provider-utils/src/types/infer-tool-output.test-d.ts b/packages/provider-utils/src/types/infer-tool-output.test-d.ts new file mode 100644 index 000000000000..10c2a08c1ecd --- /dev/null +++ b/packages/provider-utils/src/types/infer-tool-output.test-d.ts @@ -0,0 +1,40 @@ +import { describe, expectTypeOf, it } from 'vitest'; +import { z } from 'zod/v4'; +import type { InferToolOutput } from './infer-tool-output'; +import { tool } from './tool'; + +describe('InferToolOutput', () => { + it('infers the output type from a tool with execute function', () => { + const weatherTool = tool({ + inputSchema: z.object({ + city: z.string(), + }), + execute: async () => ({ + temperature: 72, + conditions: 'sunny' as const, + }), + }); + + expectTypeOf>().toEqualTypeOf<{ + temperature: number; + conditions: 'sunny'; + }>(); + }); + + it('infers the output type from a tool with outputSchema', () => { + const weatherTool = tool({ + inputSchema: z.object({ + city: z.string(), + }), + outputSchema: z.object({ + temperature: z.number(), + conditions: z.literal('sunny'), + }), + }); + + expectTypeOf>().toEqualTypeOf<{ + temperature: number; + conditions: 'sunny'; + }>(); + }); +}); diff --git a/packages/provider-utils/src/types/infer-tool-output.ts b/packages/provider-utils/src/types/infer-tool-output.ts new file mode 100644 index 000000000000..6401e1e9407d --- /dev/null +++ b/packages/provider-utils/src/types/infer-tool-output.ts @@ -0,0 +1,7 @@ +import type { Tool } from './tool'; + +/** + * Infer the output type of a tool. + */ +export type InferToolOutput> = + TOOL extends Tool ? OUTPUT : never; diff --git a/packages/provider-utils/src/types/tool.test-d.ts b/packages/provider-utils/src/types/tool.test-d.ts index 33b313457381..cc80184f7c7d 100644 --- a/packages/provider-utils/src/types/tool.test-d.ts +++ b/packages/provider-utils/src/types/tool.test-d.ts @@ -1,10 +1,10 @@ -import { LanguageModelV4ToolResultPart } from '@ai-sdk/provider'; import { describe, expectTypeOf, it } from 'vitest'; import { z } from 'zod/v4'; import { FlexibleSchema } from '../schema'; +import { ToolResultOutput } from './content-part'; +import { Context } from './context'; import { ModelMessage } from './model-message'; import { Tool, tool, ToolExecuteFunction } from './tool'; -import { ToolResultOutput } from './content-part'; describe('tool type', () => { describe('input type', () => { @@ -13,7 +13,9 @@ describe('tool type', () => { inputSchema: z.object({ number: z.number() }), }); - expectTypeOf(aTool).toEqualTypeOf>(); + expectTypeOf(aTool).toEqualTypeOf< + Tool<{ number: number }, never, Context> + >(); expectTypeOf(aTool.execute).toEqualTypeOf(); expectTypeOf(aTool.execute).not.toEqualTypeOf(); expectTypeOf(aTool.inputSchema).toEqualTypeOf< @@ -26,7 +28,7 @@ describe('tool type', () => { inputSchema: null as unknown as FlexibleSchema, }); - expectTypeOf(aTool).toEqualTypeOf>(); + expectTypeOf(aTool).toEqualTypeOf>(); expectTypeOf(aTool.execute).toEqualTypeOf(); expectTypeOf(aTool.execute).not.toEqualTypeOf(); expectTypeOf(aTool.inputSchema).toEqualTypeOf>(); @@ -78,9 +80,11 @@ describe('tool type', () => { }, }); - expectTypeOf(aTool).toEqualTypeOf>(); + expectTypeOf(aTool).toEqualTypeOf< + Tool<{ number: number }, 'test', Context> + >(); expectTypeOf(aTool.execute).toMatchTypeOf< - ToolExecuteFunction<{ number: number }, 'test'> | undefined + ToolExecuteFunction<{ number: number }, 'test', Context> | undefined >(); expectTypeOf(aTool.execute).not.toEqualTypeOf(); expectTypeOf(aTool.inputSchema).toEqualTypeOf< @@ -96,9 +100,11 @@ describe('tool type', () => { }, }); - expectTypeOf(aTool).toEqualTypeOf>(); + expectTypeOf(aTool).toEqualTypeOf< + Tool<{ number: number }, 'test', Context> + >(); expectTypeOf(aTool.execute).toEqualTypeOf< - ToolExecuteFunction<{ number: number }, 'test'> | undefined + ToolExecuteFunction<{ number: number }, 'test', Context> | undefined >(); expectTypeOf(aTool.inputSchema).toEqualTypeOf< FlexibleSchema<{ number: number }> @@ -176,7 +182,7 @@ describe('tool type', () => { expectTypeOf(options).toEqualTypeOf<{ toolCallId: string; messages: ModelMessage[]; - experimental_context?: unknown; + experimental_context: Context; }>(); return true; }, @@ -189,7 +195,7 @@ describe('tool type', () => { options: { toolCallId: string; messages: ModelMessage[]; - experimental_context: unknown; + experimental_context: Context; }, ) => boolean | PromiseLike) | undefined @@ -205,7 +211,7 @@ describe('tool type', () => { expectTypeOf(options).toEqualTypeOf<{ toolCallId: string; messages: ModelMessage[]; - experimental_context?: unknown; + experimental_context: Context; }>(); return true; }, @@ -218,7 +224,7 @@ describe('tool type', () => { options: { toolCallId: string; messages: ModelMessage[]; - experimental_context: unknown; + experimental_context: Context; }, ) => boolean | PromiseLike) | undefined diff --git a/packages/provider-utils/src/types/tool.ts b/packages/provider-utils/src/types/tool.ts index ea795973d9a5..ebbe371ffe43 100644 --- a/packages/provider-utils/src/types/tool.ts +++ b/packages/provider-utils/src/types/tool.ts @@ -3,11 +3,12 @@ import { FlexibleSchema } from '../schema'; import { ToolResultOutput } from './content-part'; import { ModelMessage } from './model-message'; import { ProviderOptions } from './provider-options'; +import { Context } from './context'; /** * Additional options that are sent into each tool call. */ -export interface ToolExecutionOptions { +export interface ToolExecutionOptions { /** * The ID of the tool call. You can use it e.g. when sending tool-call related information with stream data. */ @@ -36,13 +37,13 @@ export interface ToolExecutionOptions { * * Experimental (can break in patch releases). */ - experimental_context?: unknown; + experimental_context: CONTEXT; } /** * Function that is called to determine if the tool needs approval before it can be executed. */ -export type ToolNeedsApprovalFunction = ( +export type ToolNeedsApprovalFunction = ( input: INPUT, options: { /** @@ -61,13 +62,13 @@ export type ToolNeedsApprovalFunction = ( * * Experimental (can break in patch releases). */ - experimental_context?: unknown; + experimental_context: CONTEXT; }, ) => boolean | PromiseLike; -export type ToolExecuteFunction = ( +export type ToolExecuteFunction = ( input: INPUT, - options: ToolExecutionOptions, + options: ToolExecutionOptions, ) => AsyncIterable | PromiseLike | OUTPUT; // 0 extends 1 & N checks for any @@ -78,7 +79,11 @@ type NeverOptional = 0 extends 1 & N ? Partial> : T; -type ToolOutputProperties = NeverOptional< +type ToolOutputProperties< + INPUT, + OUTPUT, + CONTEXT extends Context, +> = NeverOptional< OUTPUT, | { /** @@ -88,7 +93,7 @@ type ToolOutputProperties = NeverOptional< * @args is the input of the tool call. * @options.abortSignal is a signal that can be used to abort the tool call. */ - execute: ToolExecuteFunction; + execute: ToolExecuteFunction; outputSchema?: FlexibleSchema; } @@ -108,6 +113,7 @@ type ToolOutputProperties = NeverOptional< export type Tool< INPUT extends JSONValue | unknown | never = any, OUTPUT extends JSONValue | unknown | never = any, + CONTEXT extends Context = Context, > = { /** * An optional description of what the tool does. @@ -143,12 +149,17 @@ export type Tool< */ inputExamples?: Array<{ input: NoInfer }>; + contextSchema?: FlexibleSchema; + /** * Whether the tool needs approval before it can be executed. */ needsApproval?: | boolean - | ToolNeedsApprovalFunction<[INPUT] extends [never] ? unknown : INPUT>; + | ToolNeedsApprovalFunction< + [INPUT] extends [never] ? unknown : INPUT, + NoInfer + >; /** * Strict mode setting for the tool. @@ -163,14 +174,18 @@ export type Tool< * Optional function that is called when the argument streaming starts. * Only called when the tool is used in a streaming context. */ - onInputStart?: (options: ToolExecutionOptions) => void | PromiseLike; + onInputStart?: ( + options: ToolExecutionOptions>, + ) => void | PromiseLike; /** * Optional function that is called when an argument streaming delta is available. * Only called when the tool is used in a streaming context. */ onInputDelta?: ( - options: { inputTextDelta: string } & ToolExecutionOptions, + options: { inputTextDelta: string } & ToolExecutionOptions< + NoInfer + >, ) => void | PromiseLike; /** @@ -180,9 +195,9 @@ export type Tool< onInputAvailable?: ( options: { input: [INPUT] extends [never] ? unknown : INPUT; - } & ToolExecutionOptions, + } & ToolExecutionOptions>, ) => void | PromiseLike; -} & ToolOutputProperties & { +} & ToolOutputProperties> & { /** * Optional conversion function that maps the tool result to an output that can be used by the language model. * @@ -257,28 +272,22 @@ export type Tool< } ); -/** - * Infer the input type of a tool. - */ -export type InferToolInput = - TOOL extends Tool ? INPUT : never; - -/** - * Infer the output type of a tool. - */ -export type InferToolOutput = - TOOL extends Tool ? OUTPUT : never; - /** * Helper function for inferring the execute args of a tool. */ // Note: overload order is important for auto-completion -export function tool( - tool: Tool, -): Tool; -export function tool(tool: Tool): Tool; -export function tool(tool: Tool): Tool; -export function tool(tool: Tool): Tool; +export function tool( + tool: Tool, +): Tool; +export function tool( + tool: Tool, +): Tool; +export function tool( + tool: Tool, +): Tool; +export function tool( + tool: Tool, +): Tool; export function tool(tool: any): any { return tool; } @@ -291,7 +300,7 @@ export function dynamicTool(tool: { title?: string; providerOptions?: ProviderOptions; inputSchema: FlexibleSchema; - execute: ToolExecuteFunction; + execute: ToolExecuteFunction; /** * Optional conversion function that maps the tool result to an output that can be used by the language model. @@ -318,8 +327,8 @@ export function dynamicTool(tool: { /** * Whether the tool needs approval before it can be executed. */ - needsApproval?: boolean | ToolNeedsApprovalFunction; -}): Tool & { + needsApproval?: boolean | ToolNeedsApprovalFunction; +}): Tool & { type: 'dynamic'; } { return { ...tool, type: 'dynamic' };