diff --git a/AGENTS.md b/AGENTS.md index 08097a77..c6d32f94 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -241,20 +241,20 @@ logger.warn(`field=<${value}> | statement one | statement two`) **Examples**: ```typescript -// ✅ Good: Context fields with message +// Good: Context fields with message logger.warn(`stop_reason=<${stopReason}>, fallback=<${fallback}> | unknown stop reason, converting to camelCase`) logger.warn(`event_type=<${eventType}> | unsupported bedrock event type`) -// ✅ Good: Simple message without context fields +// Good: Simple message without context fields logger.warn('cache points are not supported in openai system prompts, ignoring cache points') -// ✅ Good: Multiple statements separated by pipes +// Good: Multiple statements separated by pipes logger.warn(`request_id=<${id}> | processing request | starting validation`) -// ❌ Bad: Not using angle brackets for values +// Bad: Not using angle brackets for values logger.warn(`stop_reason=${stopReason} | unknown stop reason`) -// ❌ Bad: Using punctuation +// Bad: Using punctuation logger.warn(`event_type=<${eventType}> | Unsupported event type.`) ``` @@ -289,7 +289,7 @@ src/ **Example**: ```typescript -// ✅ Good: Main function first, helpers follow +// Good: Main function first, helpers follow export async function* mainFunction() { const result = await helperFunction1() return helperFunction2(result) @@ -303,7 +303,7 @@ function helperFunction2(input: string) { // Implementation } -// ❌ Bad: Helpers before main function +// Bad: Helpers before main function async function helperFunction1() { // Implementation } @@ -325,10 +325,10 @@ test/integ/ **Optional chaining for null safety**: Prefer optional chaining over verbose `typeof` checks when accessing potentially undefined properties: ```typescript -// ✅ Good: Optional chaining +// Good: Optional chaining return globalThis?.process?.env?.API_KEY -// ❌ Bad: Verbose typeof checks +// Bad: Verbose typeof checks if (typeof process !== 'undefined' && typeof 
process.env !== 'undefined') { return process.env.API_KEY } @@ -369,7 +369,7 @@ export function getData(): any { **Private fields**: Use underscore prefix for private class fields to improve readability and distinguish them from public members. ```typescript -// ✅ Good: Private fields with underscore prefix +// Good: Private fields with underscore prefix export class Example { private readonly _config: Config private _state: State @@ -384,7 +384,7 @@ export class Example { } } -// ❌ Bad: No underscore for private fields +// Bad: No underscore for private fields export class Example { private readonly config: Config // Missing underscore @@ -497,7 +497,7 @@ import type { Options, Config } from '../types' **When defining interfaces or types, organize them so the top-level interface comes first, followed by its dependencies, and then all nested dependencies.** ```typescript -// ✅ Correct - Top-level first, then dependencies +// Correct - Top-level first, then dependencies export interface Message { role: Role content: ContentBlock[] @@ -537,7 +537,7 @@ export class ToolResultBlock { } } -// ❌ Wrong - Dependencies before top-level +// Wrong - Dependencies before top-level export type Role = 'user' | 'assistant' export interface TextBlockData { @@ -557,7 +557,7 @@ export interface Message { // Top-level should come first **When creating discriminated unions with a `type` field, the type value MUST match the interface name with the first letter lowercase.** ```typescript -// ✅ Correct - type matches class name (first letter lowercase) +// Correct - type matches class name (first letter lowercase) export class TextBlock { readonly type = 'textBlock' as const // Matches 'TextBlock' class name readonly text: string @@ -572,7 +572,7 @@ export class CachePointBlock { export type ContentBlock = TextBlock | ToolUseBlock | CachePointBlock -// ❌ Wrong - type doesn't match class name +// Wrong - type doesn't match class name export class CachePointBlock { readonly type = 
'cachePoint' as const // Should be 'cachePointBlock' readonly cacheType: 'default' @@ -581,6 +581,47 @@ export class CachePointBlock { **Rationale**: This consistent naming makes discriminated unions predictable and improves code readability. Developers can easily understand the relationship between the type value and the class. +### API Union Types (Bedrock Pattern) + +When the upstream API (e.g., Bedrock) defines a type as a **UNION** ("only one member can be specified"), model it as a TypeScript `type` union with each variant's field **required** — not an `interface` with optional fields. This allows non-breaking expansion when new variants are added. + +The Bedrock API marks all fields in union types as "Not Required" as a mechanism for future extensibility. In TypeScript, encode the mutual exclusivity using `|` with each variant having its field required. The "not required" from the API docs means "this field won't be present if a different variant is active." + +```typescript +// Correct: type union — each variant has its field required +// Adding a new variant later (e.g., | { image: ImageData }) is non-breaking +export type CitationSourceContent = { text: string } + +// Correct: multi-variant union with object-key discrimination +export type DocumentSourceData = + | { bytes: Uint8Array } + | { text: string } + | { content: DocumentContentBlockData[] } + | { s3Location: S3LocationData } + +// Correct: multi-variant union in Bedrock's wire shape for citation locations (illustrative; the SDK's exported `CitationLocation` uses a `type` field) +export type CitationLocation = + | { documentChar: DocumentCharLocation } + | { documentPage: DocumentPageLocation } + | { web: WebLocation } + +// Wrong: interface with optional fields — cannot expand without breaking +export interface CitationSourceContent { + text?: string +} + +// Wrong: interface with required field — changing to union later is breaking +export interface CitationSourceContent { + text: string +} +``` + +**Key points**: +- Use `type` alias (not `interface`) so it can be expanded to a union later
+- Each variant's field is **required** within that variant +- Use object-key discrimination (`'text' in source`) to narrow variants at runtime +- See `DocumentSourceData` in `src/types/media.ts` and `CitationSourceContent` in `src/types/citations.ts` for reference implementations (the SDK-facing `CitationLocation` there is discriminated by a `type` field; the object-key form exists only on the Bedrock wire) + ### Error Handling ```typescript @@ -614,13 +655,13 @@ export class ValidationError extends Error { When asserting on objects, prefer `toStrictEqual` for full object comparison rather than checking individual fields: ```typescript -// ✅ Good: Full object assertion with toStrictEqual +// Good: Full object assertion with toStrictEqual expect(provider.getConfig()).toStrictEqual({ modelId: 'gemini-2.5-flash', params: { temperature: 0.5 }, }) -// ❌ Bad: Checking individual fields +// Bad: Checking individual fields expect(provider.getConfig().modelId).toBe('gemini-2.5-flash') expect(provider.getConfig().params.temperature).toBe(0.5) ``` @@ -639,7 +680,7 @@ When adding or modifying dependencies, you **MUST** follow the guidelines in [do ## Things to Do -✅ **Do**: +**Do**: - Use relative imports for internal modules - Co-locate unit tests with source under `__tests__` directories - Follow nested describe pattern for test organization @@ -652,7 +693,7 @@ When adding or modifying dependencies, you **MUST** follow the guidelines in [do ## Things NOT to Do -❌ **Don't**: +**Don't**: - Use `any` type (enforced by ESLint) - Put unit tests in separate `tests/` directory (use `src/**/__tests__/**`) - Skip documentation for exported functions diff --git a/src/__fixtures__/mock-message-model.ts b/src/__fixtures__/mock-message-model.ts index fa128aaa..4077cf05 100644 --- a/src/__fixtures__/mock-message-model.ts +++ b/src/__fixtures__/mock-message-model.ts @@ -263,6 +263,19 @@ export class MockMessageModel extends Model { // This is typically used in system prompts or message content for guardrail evaluation break + case 'citationsBlock': + yield { type: 'modelContentBlockStartEvent' } + yield { + type:
'modelContentBlockDeltaEvent', + delta: { + type: 'citationsContentDelta', + citations: block.citations, + content: block.content, + }, + } + yield { type: 'modelContentBlockStopEvent' } + break + case 'imageBlock': case 'videoBlock': case 'documentBlock': diff --git a/src/__fixtures__/slim-types.ts b/src/__fixtures__/slim-types.ts index 9a475893..e4f6684e 100644 --- a/src/__fixtures__/slim-types.ts +++ b/src/__fixtures__/slim-types.ts @@ -14,6 +14,7 @@ import type { JsonBlock, } from '../types/messages.js' import type { ImageBlock, VideoBlock, DocumentBlock } from '../types/media.js' +import type { CitationsBlock } from '../types/citations.js' /** * Strips the toJSON method from a type, allowing plain objects to be used in tests. @@ -42,6 +43,7 @@ export type PlainContentBlock = | NoJSON | NoJSON | NoJSON + | NoJSON /** * Plain system content block without toJSON method. diff --git a/src/index.ts b/src/index.ts index 64d0e41d..31d13bd5 100644 --- a/src/index.ts +++ b/src/index.ts @@ -66,6 +66,18 @@ export { contentBlockFromData, } from './types/messages.js' +// Citation types +export type { + CitationsBlockData, + Citation, + CitationLocation, + CitationSourceContent, + CitationGeneratedContent, +} from './types/citations.js' + +// Citation class +export { CitationsBlock } from './types/citations.js' + // Media classes export { S3Location, ImageBlock, VideoBlock, DocumentBlock } from './types/media.js' @@ -122,6 +134,7 @@ export type { TextDelta, ToolUseInputDelta, ReasoningContentDelta, + CitationsContentDelta, ContentBlockDelta, ModelContentBlockDeltaEventData, ModelContentBlockDeltaEvent, diff --git a/src/models/__tests__/bedrock.test.ts b/src/models/__tests__/bedrock.test.ts index 4fff826d..32e591a0 100644 --- a/src/models/__tests__/bedrock.test.ts +++ b/src/models/__tests__/bedrock.test.ts @@ -6,6 +6,7 @@ import { ContextWindowOverflowError, ModelThrottledError } from '../../errors.js import { Message, ReasoningBlock, ToolUseBlock, ToolResultBlock, JsonBlock 
} from '../../types/messages.js' import type { SystemContentBlock } from '../../types/messages.js' import { TextBlock, GuardContentBlock, CachePointBlock } from '../../types/messages.js' +import { CitationsBlock } from '../../types/citations.js' import type { StreamOptions } from '../model.js' import { collectIterator } from '../../__fixtures__/model-test-helpers.js' @@ -761,6 +762,80 @@ describe('BedrockModel', () => { }) }) + it('yields and validates citationsContent events correctly', async () => { + // Bedrock wire format uses object-key discrimination + const bedrockCitationsData = { + citations: [ + { + location: { documentChar: { documentIndex: 0, start: 10, end: 50 } }, + sourceContent: [{ text: 'source text' }], + source: 'doc-0', + title: 'Test Doc', + }, + ], + content: [{ text: 'generated text' }], + } + + const mockSend = vi.fn(async () => { + if (stream) { + return { + stream: (async function* (): AsyncGenerator { + yield { messageStart: { role: 'assistant' } } + yield { contentBlockStart: {} } + yield { + contentBlockDelta: { + delta: { citationsContent: bedrockCitationsData }, + }, + } + yield { contentBlockStop: {} } + yield { messageStop: { stopReason: 'end_turn' } } + yield { + metadata: { usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, metrics: { latencyMs: 100 } }, + } + })(), + } + } else { + return { + output: { + message: { + role: 'assistant', + content: [{ citationsContent: bedrockCitationsData }], + }, + }, + stopReason: 'end_turn', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + metrics: { latencyMs: 100 }, + } + } + }) + mockBedrockClientImplementation({ send: mockSend }) + + const provider = new BedrockModel({ stream }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Cite this.')] })] + const events = await collectIterator(provider.stream(messages)) + + // SDK events should use type-field discrimination + expect(events).toContainEqual({ role: 'assistant', type: 
'modelMessageStartEvent' }) + expect(events).toContainEqual({ type: 'modelContentBlockStartEvent' }) + expect(events).toContainEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { + type: 'citationsContentDelta', + citations: [ + { + location: { type: 'documentChar', documentIndex: 0, start: 10, end: 50 }, + sourceContent: [{ text: 'source text' }], + source: 'doc-0', + title: 'Test Doc', + }, + ], + content: [{ text: 'generated text' }], + }, + }) + expect(events).toContainEqual({ type: 'modelContentBlockStopEvent' }) + expect(events).toContainEqual({ stopReason: 'endTurn', type: 'modelMessageStopEvent' }) + }) + describe('error handling', async () => { it.each([ { @@ -1475,6 +1550,121 @@ describe('BedrockModel', () => { }) }) + describe('citations content block formatting', () => { + const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) + + it('maps SDK CitationLocation types to Bedrock object-key format through formatting pipeline', async () => { + const provider = new BedrockModel() + // SDK format uses type-field discrimination + const sdkCitations = [ + { + location: { type: 'documentChar' as const, documentIndex: 0, start: 150, end: 300 }, + source: 'doc-0', + sourceContent: [{ text: 'char source' }], + title: 'Text Document', + }, + { + location: { type: 'documentPage' as const, documentIndex: 0, start: 2, end: 3 }, + source: 'doc-0', + sourceContent: [{ text: 'page source' }], + title: 'PDF Document', + }, + { + location: { type: 'documentChunk' as const, documentIndex: 1, start: 5, end: 8 }, + source: 'doc-1', + sourceContent: [{ text: 'chunk source' }], + title: 'Chunked Document', + }, + { + location: { type: 'searchResult' as const, searchResultIndex: 0, start: 25, end: 150 }, + source: 'search-0', + sourceContent: [{ text: 'search source' }], + title: 'Search Result', + }, + { + location: { type: 'web' as const, url: 'https://example.com/doc', domain: 'example.com' }, + source: 'web-0', + sourceContent: [{ text: 'web source' }], + 
title: 'Web Page', + }, + ] + + const messages = [ + new Message({ + role: 'assistant', + content: [ + new CitationsBlock({ + citations: sdkCitations, + content: [{ text: 'generated text with all citation types' }], + }), + ], + }), + new Message({ + role: 'user', + content: [new TextBlock('Follow up')], + }), + ] + + collectIterator(provider.stream(messages)) + + // Bedrock wire format uses object-key discrimination + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'assistant', + content: [ + { + citationsContent: { + citations: [ + { + location: { documentChar: { documentIndex: 0, start: 150, end: 300 } }, + source: 'doc-0', + sourceContent: [{ text: 'char source' }], + title: 'Text Document', + }, + { + location: { documentPage: { documentIndex: 0, start: 2, end: 3 } }, + source: 'doc-0', + sourceContent: [{ text: 'page source' }], + title: 'PDF Document', + }, + { + location: { documentChunk: { documentIndex: 1, start: 5, end: 8 } }, + source: 'doc-1', + sourceContent: [{ text: 'chunk source' }], + title: 'Chunked Document', + }, + { + location: { + searchResultLocation: { searchResultIndex: 0, start: 25, end: 150 }, + }, + source: 'search-0', + sourceContent: [{ text: 'search source' }], + title: 'Search Result', + }, + { + location: { web: { url: 'https://example.com/doc' } }, + source: 'web-0', + sourceContent: [{ text: 'web source' }], + title: 'Web Page', + }, + ], + content: [{ text: 'generated text with all citation types' }], + }, + }, + ], + }, + { + role: 'user', + content: [{ text: 'Follow up' }], + }, + ], + }) + ) + }) + }) + describe('includeToolResultStatus configuration', async () => { const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) diff --git a/src/models/bedrock.ts b/src/models/bedrock.ts index 06705dee..95ed417a 100644 --- a/src/models/bedrock.ts +++ b/src/models/bedrock.ts @@ -35,11 +35,15 @@ import { DocumentFormat, ImageFormat, type 
BedrockRuntimeClientResolvedConfig, + type CitationLocation as BedrockCitationLocation, + type Citation as BedrockCitation, + type CitationsContentBlock as BedrockCitationsContentBlock, } from '@aws-sdk/client-bedrock-runtime' import { type BaseModelConfig, Model, type StreamOptions } from '../models/model.js' import type { ContentBlock, Message, StopReason, ToolUseBlock } from '../types/messages.js' import type { ImageSource, VideoSource, DocumentSource } from '../types/media.js' -import type { ModelStreamEvent, ReasoningContentDelta, Usage } from '../models/streaming.js' +import type { CitationsContentDelta, ModelStreamEvent, ReasoningContentDelta, Usage } from '../models/streaming.js' +import type { Citation, CitationLocation, CitationsBlockData } from '../types/citations.js' import type { JSONValue } from '../types/json.js' import { ContextWindowOverflowError, ModelThrottledError, normalizeError } from '../errors.js' import { ensureDefined } from '../types/validation.js' @@ -632,6 +636,14 @@ export class BedrockModel extends Model { }, } + case 'citationsBlock': + return { + citationsContent: { + citations: block.citations.map((c) => this._mapCitationToBedrock(c)), + content: block.content, + }, + } + case 'guardContentBlock': { if (block.text) { return { @@ -802,6 +814,19 @@ export class BedrockModel extends Model { events.push({ type: 'modelContentBlockStopEvent' }) }, + citationsContent: (block: BedrockCitationsContentBlock): void => { + if (!block) return + events.push({ type: 'modelContentBlockStartEvent' }) + + const mapped = this._mapBedrockCitationsData(block) + const delta: CitationsContentDelta = { + type: 'citationsContentDelta', + citations: mapped.citations, + content: mapped.content, + } + events.push({ type: 'modelContentBlockDeltaEvent', delta }) + events.push({ type: 'modelContentBlockStopEvent' }) + }, } const content = ensureDefined(message.content, 'message.content') @@ -915,6 +940,16 @@ export class BedrockModel extends Model { 
events.push({ type: 'modelContentBlockDeltaEvent', delta: reasoningDelta }) } }, + citationsContent: (block: BedrockCitationsContentBlock): void => { + if (!block) return + const mapped = this._mapBedrockCitationsData(block) + const delta: CitationsContentDelta = { + type: 'citationsContentDelta', + citations: mapped.citations, + content: mapped.content, + } + events.push({ type: 'modelContentBlockDeltaEvent', delta }) + }, } for (const key in delta) { @@ -1049,6 +1084,111 @@ export class BedrockModel extends Model { return mappedStopReason } + + /** + * Maps a Bedrock object-key citation location to the SDK's type-field format. + * + * Bedrock uses object-key discrimination (`{ documentChar: { ... } }`) while the SDK uses + * type-field discrimination (`{ type: 'documentChar', ... }`). Also normalizes Bedrock's + * `searchResultLocation` key to the shorter `searchResult`. + * + * @param bedrockLocation - Bedrock citation location with object-key discrimination + * @returns SDK CitationLocation with type field discrimination, or undefined (after logging a warning) when no known location key is set + */ + private _mapBedrockCitationLocation(bedrockLocation: BedrockCitationLocation): CitationLocation | undefined { + if (bedrockLocation.documentChar) { + const loc = bedrockLocation.documentChar + return { type: 'documentChar', documentIndex: loc.documentIndex!, start: loc.start!, end: loc.end! } + } + if (bedrockLocation.documentPage) { + const loc = bedrockLocation.documentPage + return { type: 'documentPage', documentIndex: loc.documentIndex!, start: loc.start!, end: loc.end! } + } + if (bedrockLocation.documentChunk) { + const loc = bedrockLocation.documentChunk + return { type: 'documentChunk', documentIndex: loc.documentIndex!, start: loc.start!, end: loc.end! } + } + if (bedrockLocation.searchResultLocation) { + const loc = bedrockLocation.searchResultLocation + return { type: 'searchResult', searchResultIndex: loc.searchResultIndex!, start: loc.start!, end: loc.end!
} + } + if (bedrockLocation.web) { + const loc = bedrockLocation.web + return { type: 'web', url: loc.url!, ...(loc.domain && { domain: loc.domain }) } + } + logger.warn(`citation_location=<${JSON.stringify(bedrockLocation)}> | unknown citation location type`) + return undefined + } + + /** + * Maps a Bedrock CitationsContentBlock to SDK CitationsBlockData. + * + * @param bedrockData - Bedrock CitationsContentBlock + * @returns SDK CitationsBlockData with type-field CitationLocations; citations with a missing or unrecognized location are dropped, and an absent source or title becomes '' + */ + private _mapBedrockCitationsData(bedrockData: BedrockCitationsContentBlock): CitationsBlockData { + return { + citations: (bedrockData.citations ?? []) + .map((citation) => { + const location = citation.location ? this._mapBedrockCitationLocation(citation.location) : undefined + if (!location) return undefined + return { + source: citation.source ?? '', + title: citation.title ?? '', + sourceContent: (citation.sourceContent ?? []).map((sc) => ({ text: sc.text! })), + location, + } + }) + .filter((c) => c !== undefined), + content: (bedrockData.content ?? []).map((gc) => ({ text: gc.text! })), + } + } + + /** + * Maps an SDK Citation to Bedrock's Citation format. + * + * @param citation - SDK Citation with type-field location + * @returns Bedrock Citation with object-key location + */ + private _mapCitationToBedrock(citation: Citation): BedrockCitation { + return { + location: this._mapCitationLocationToBedrock(citation.location), + sourceContent: citation.sourceContent.map((sc) => ({ text: sc.text })), + source: citation.source, + title: citation.title, + } + } + + /** + * Maps an SDK CitationLocation to Bedrock's object-key format.
+ * + * @param location - SDK CitationLocation with type field + * @returns Bedrock CitationLocation with object-key discrimination + */ + private _mapCitationLocationToBedrock(location: CitationLocation): BedrockCitationLocation { + switch (location.type) { + case 'documentChar': { + const { type: _, ...fields } = location + return { documentChar: fields } + } + case 'documentPage': { + const { type: _, ...fields } = location + return { documentPage: fields } + } + case 'documentChunk': { + const { type: _, ...fields } = location + return { documentChunk: fields } + } + case 'searchResult': { + const { type: _, ...fields } = location + return { searchResultLocation: fields } + } + case 'web': + return { web: { url: location.url } } + default: + return location as unknown as BedrockCitationLocation + } + } } /** diff --git a/src/models/model.ts b/src/models/model.ts index 96555dc6..19e545c9 100644 --- a/src/models/model.ts +++ b/src/models/model.ts @@ -8,6 +8,8 @@ import { TextBlock, ToolUseBlock, } from '../types/messages.js' +import { CitationsBlock } from '../types/citations.js' +import type { Citation, CitationGeneratedContent } from '../types/citations.js' import type { ToolChoice, ToolSpec } from '../tools/types.js' import { ModelContentBlockDeltaEvent, @@ -203,6 +205,9 @@ export abstract class Model { signature?: string redactedContent?: Uint8Array } = {} + let accumulatedCitationsList: Citation[] = [] + let accumulatedCitationsContent: CitationGeneratedContent[] = [] + let hasCitations = false let errorToThrow: Error | undefined = undefined let stoppedMessage: Message | null = null let finalStopReason: StopReason | null = null @@ -228,23 +233,28 @@ export abstract class Model { accumulatedToolInput = '' accumulatedText = '' accumulatedReasoning = {} + accumulatedCitationsList = [] + accumulatedCitationsContent = [] + hasCitations = false break - case 'modelContentBlockDeltaEvent': - switch (event.delta.type) { - case 'textDelta': - accumulatedText += 
event.delta.text - break - case 'toolUseInputDelta': - accumulatedToolInput += event.delta.input - break - case 'reasoningContentDelta': - if (event.delta.text) accumulatedReasoning.text = (accumulatedReasoning.text ?? '') + event.delta.text - if (event.delta.signature) accumulatedReasoning.signature = event.delta.signature - if (event.delta.redactedContent) accumulatedReasoning.redactedContent = event.delta.redactedContent - break + case 'modelContentBlockDeltaEvent': { + const delta = event.delta + if (delta.type === 'textDelta') { + accumulatedText += delta.text + } else if (delta.type === 'toolUseInputDelta') { + accumulatedToolInput += delta.input + } else if (delta.type === 'reasoningContentDelta') { + if (delta.text) accumulatedReasoning.text = (accumulatedReasoning.text ?? '') + delta.text + if (delta.signature) accumulatedReasoning.signature = delta.signature + if (delta.redactedContent) accumulatedReasoning.redactedContent = delta.redactedContent + } else if (delta.type === 'citationsContentDelta') { + accumulatedCitationsList.push(...delta.citations) + accumulatedCitationsContent.push(...delta.content) + hasCitations = true } break + } case 'modelContentBlockStopEvent': { // Finalize and emit complete ContentBlock @@ -265,6 +275,12 @@ export abstract class Model { ...accumulatedReasoning, }) accumulatedReasoning = {} // Reset after creating reasoning block + } else if (hasCitations) { + block = new CitationsBlock({ + citations: accumulatedCitationsList, + content: accumulatedCitationsContent, + }) + hasCitations = false } else { block = new TextBlock(accumulatedText) } diff --git a/src/models/streaming.ts b/src/models/streaming.ts index 6b67e6e6..c908f88e 100644 --- a/src/models/streaming.ts +++ b/src/models/streaming.ts @@ -1,5 +1,6 @@ import type { Role, StopReason } from '../types/messages.js' import type { JSONValue } from '../types/json.js' +import type { Citation, CitationGeneratedContent } from '../types/citations.js' /** * ModelStreamEvent types 
for Model interactions. @@ -323,7 +324,7 @@ export interface ToolUseStart { * * This is a discriminated union for type-safe delta handling. */ -export type ContentBlockDelta = TextDelta | ToolUseInputDelta | ReasoningContentDelta +export type ContentBlockDelta = TextDelta | ToolUseInputDelta | ReasoningContentDelta | CitationsContentDelta /** * Text delta within a content block. @@ -383,6 +384,27 @@ export interface ReasoningContentDelta { redactedContent?: Uint8Array } +/** + * Citations content delta within a content block. + * Represents a citations content block from the model. + */ +export interface CitationsContentDelta { + /** + * Discriminator for citations content delta. + */ + type: 'citationsContentDelta' + + /** + * Array of citations linking generated content to source locations. + */ + citations: Citation[] + + /** + * The generated content associated with these citations. + */ + content: CitationGeneratedContent[] +} + /** * Token usage statistics for a model invocation. * Tracks input, output, and total tokens, plus cache-related metrics. 
diff --git a/src/types/__tests__/citations.test.ts b/src/types/__tests__/citations.test.ts new file mode 100644 index 00000000..f3ffd91f --- /dev/null +++ b/src/types/__tests__/citations.test.ts @@ -0,0 +1,115 @@ +import { describe, expect, it } from 'vitest' +import { CitationsBlock, type CitationsBlockData } from '../citations.js' + +describe('CitationsBlock', () => { + const singleCitationData: CitationsBlockData = { + citations: [ + { + location: { type: 'documentChar', documentIndex: 0, start: 10, end: 50 }, + source: 'doc-0', + sourceContent: [{ text: 'source text from document' }], + title: 'Test Document', + }, + ], + content: [{ text: 'generated text with citation' }], + } + + const allVariantsData: CitationsBlockData = { + citations: [ + { + location: { type: 'documentChar', documentIndex: 0, start: 150, end: 300 }, + source: 'doc-0', + sourceContent: [{ text: 'char source' }], + title: 'Text Document', + }, + { + location: { type: 'documentPage', documentIndex: 0, start: 2, end: 3 }, + source: 'doc-0', + sourceContent: [{ text: 'page source' }], + title: 'PDF Document', + }, + { + location: { type: 'documentChunk', documentIndex: 1, start: 5, end: 8 }, + source: 'doc-1', + sourceContent: [{ text: 'chunk source' }], + title: 'Chunked Document', + }, + { + location: { type: 'searchResult', searchResultIndex: 0, start: 25, end: 150 }, + source: 'search-0', + sourceContent: [{ text: 'search source' }], + title: 'Search Result', + }, + { + location: { type: 'web', url: 'https://example.com/doc', domain: 'example.com' }, + source: 'web-0', + sourceContent: [{ text: 'web source' }, { text: 'additional source' }], + title: 'Web Page', + }, + ], + content: [{ text: 'first generated' }, { text: 'second generated' }], + } + + it('creates block with correct type discriminator', () => { + const block = new CitationsBlock(singleCitationData) + expect(block.type).toBe('citationsBlock') + }) + + it('stores citations and content', () => { + const block = new 
CitationsBlock(singleCitationData) + expect(block.citations).toStrictEqual(singleCitationData.citations) + expect(block.content).toStrictEqual(singleCitationData.content) + }) + + it('round-trips all CitationLocation variants, multiple citations, and multiple content blocks', () => { + const original = new CitationsBlock(allVariantsData) + const json = original.toJSON() + const restored = CitationsBlock.fromJSON(json) + + expect(restored).toEqual(original) + expect(restored.citations).toHaveLength(5) + + expect(restored.citations[0]!.location.type).toBe('documentChar') + expect(restored.citations[1]!.location.type).toBe('documentPage') + expect(restored.citations[2]!.location.type).toBe('documentChunk') + expect(restored.citations[3]!.location.type).toBe('searchResult') + expect(restored.citations[4]!.location.type).toBe('web') + + // Verify web-specific optional domain field survives round-trip + const webLoc = restored.citations[4]!.location + if (webLoc.type === 'web') { + expect(webLoc.domain).toBe('example.com') + } + }) + + it('handles empty arrays', () => { + const data: CitationsBlockData = { + citations: [], + content: [], + } + const block = new CitationsBlock(data) + expect(block.citations).toStrictEqual([]) + expect(block.content).toStrictEqual([]) + + const restored = CitationsBlock.fromJSON(block.toJSON()) + expect(restored).toEqual(block) + }) + + it('toJSON returns wrapped format', () => { + const block = new CitationsBlock(singleCitationData) + const json = block.toJSON() + expect(json).toStrictEqual({ + citationsContent: { + citations: singleCitationData.citations, + content: singleCitationData.content, + }, + }) + }) + + it('works with JSON.stringify', () => { + const original = new CitationsBlock(allVariantsData) + const jsonString = JSON.stringify(original) + const restored = CitationsBlock.fromJSON(JSON.parse(jsonString)) + expect(restored).toEqual(original) + }) +}) diff --git a/src/types/__tests__/messages.test.ts 
b/src/types/__tests__/messages.test.ts index 15210cbc..b3e8b9e4 100644 --- a/src/types/__tests__/messages.test.ts +++ b/src/types/__tests__/messages.test.ts @@ -14,6 +14,7 @@ import { systemPromptToData, } from '../messages.js' import { ImageBlock, VideoBlock, DocumentBlock, encodeBase64 } from '../media.js' +import { CitationsBlock } from '../citations.js' describe('Message', () => { test('creates message with role and content', () => { @@ -281,6 +282,31 @@ describe('Message.fromMessageData', () => { expect(message.content[0]!.type).toBe('documentBlock') }) + it('converts citations content block data to CitationsBlock', () => { + const messageData: MessageData = { + role: 'assistant', + content: [ + { + citationsContent: { + citations: [ + { + location: { type: 'documentChar', documentIndex: 0, start: 10, end: 50 }, + source: 'doc-0', + sourceContent: [{ text: 'source text' }], + title: 'Test Doc', + }, + ], + content: [{ text: 'generated text' }], + }, + }, + ], + } + const message = Message.fromMessageData(messageData) + expect(message.content).toHaveLength(1) + expect(message.content[0]).toBeInstanceOf(CitationsBlock) + expect(message.content[0]!.type).toBe('citationsBlock') + }) + it('converts multiple content blocks', () => { const messageData: MessageData = { role: 'user', @@ -532,6 +558,7 @@ describe('toJSON/fromJSON round-trips', () => { ['Message with text content', () => new Message({ role: 'user', content: [new TextBlock('Hello')] })], ['Message with multiple content blocks', () => new Message({ role: 'assistant', content: [new TextBlock('Here is the result'), new ToolUseBlock({ name: 'test-tool', toolUseId: '123', input: { key: 'value' } })] })], ['Message with image content', () => new Message({ role: 'user', content: [new TextBlock('Check this image'), new ImageBlock({ format: 'png', source: { bytes: new Uint8Array([1, 2, 3]) } })] })], + ['CitationsBlock', () => new CitationsBlock({ citations: [{ location: { type: 'documentChar', documentIndex: 0, 
start: 0, end: 10 }, source: 'doc-0', sourceContent: [{ text: 'source' }], title: 'Test' }], content: [{ text: 'generated' }] })], ] as const it.each(roundTripCases)('%s', (_name, createBlock) => { diff --git a/src/types/citations.ts b/src/types/citations.ts new file mode 100644 index 00000000..a7553068 --- /dev/null +++ b/src/types/citations.ts @@ -0,0 +1,218 @@ +import type { JSONSerializable, Serialized } from './json.js' + +/** + * Citation types for document citation content blocks. + * + * Citations are returned by models when document citations are enabled. + * They are output-only blocks that appear in conversation history. + */ + +/** + * Discriminated union of citation location types. + * Each variant uses a `type` field to identify the location kind. + */ +export type CitationLocation = + | { + /** + * Location referencing character positions within a document. + */ + type: 'documentChar' + + /** + * Index of the source document. + */ + documentIndex: number + + /** + * Start character position. + */ + start: number + + /** + * End character position. + */ + end: number + } + | { + /** + * Location referencing page positions within a document. + */ + type: 'documentPage' + + /** + * Index of the source document. + */ + documentIndex: number + + /** + * Start page number. + */ + start: number + + /** + * End page number. + */ + end: number + } + | { + /** + * Location referencing chunk positions within a document. + */ + type: 'documentChunk' + + /** + * Index of the source document. + */ + documentIndex: number + + /** + * Start chunk index. + */ + start: number + + /** + * End chunk index. + */ + end: number + } + | { + /** + * Location referencing a search result. + */ + type: 'searchResult' + + /** + * Index of the search result. + */ + searchResultIndex: number + + /** + * Start position within the search result. + */ + start: number + + /** + * End position within the search result. 
+ */ + end: number + } + | { + /** + * Location referencing a web URL. + */ + type: 'web' + + /** + * The URL of the web source. + */ + url: string + + /** + * The domain of the web source. + */ + domain?: string + } + +/** + * Source content referenced by a citation. + * Modeled as a union type for future extensibility. + */ +export type CitationSourceContent = { text: string } + +/** + * Generated content associated with a citation. + * Modeled as a union type for future extensibility. + */ +export type CitationGeneratedContent = { text: string } + +/** + * A single citation linking generated content to a source location. + */ +export interface Citation { + /** + * The location of the cited source. + */ + location: CitationLocation + + /** + * The source identifier string. + */ + source: string + + /** + * The source content referenced by this citation. + */ + sourceContent: CitationSourceContent[] + + /** + * Title of the cited source. + */ + title: string +} + +/** + * Data for a citations content block. + */ +export interface CitationsBlockData { + /** + * Array of citations linking generated content to source locations. + */ + citations: Citation[] + + /** + * The generated content associated with these citations. + */ + content: CitationGeneratedContent[] +} + +/** + * Citations content block within a message. + * Returned by models when document citations are enabled. + * This is an output-only block — users do not construct these directly. + */ +export class CitationsBlock + implements CitationsBlockData, JSONSerializable<{ citationsContent: Serialized<CitationsBlockData> }> +{ + /** + * Discriminator for citations content. + */ + readonly type = 'citationsBlock' as const + + /** + * Array of citations linking generated content to source locations. + */ + readonly citations: Citation[] + + /** + * The generated content associated with these citations.
+ */ + readonly content: CitationGeneratedContent[] + + constructor(data: CitationsBlockData) { + this.citations = data.citations + this.content = data.content + } + + /** + * Serializes the CitationsBlock to a JSON-compatible ContentBlockData object. + * Called automatically by JSON.stringify(). + */ + toJSON(): { citationsContent: Serialized<CitationsBlockData> } { + return { + citationsContent: { + citations: this.citations, + content: this.content, + }, + } + } + + /** + * Creates a CitationsBlock instance from its wrapped data format. + * + * @param data - Wrapped CitationsBlockData to deserialize + * @returns CitationsBlock instance + */ + static fromJSON(data: { citationsContent: Serialized<CitationsBlockData> }): CitationsBlock { + return new CitationsBlock(data.citationsContent) + } +} diff --git a/src/types/messages.ts b/src/types/messages.ts index cf7e3621..b438c6cf 100644 --- a/src/types/messages.ts +++ b/src/types/messages.ts @@ -2,6 +2,8 @@ import type { JSONValue, Serialized, MaybeSerializedInput, JSONSerializable } fr import { omitUndefined } from './json.js' import type { ImageBlockData, VideoBlockData, DocumentBlockData } from './media.js' import { ImageBlock, VideoBlock, DocumentBlock, encodeBase64, decodeBase64 } from './media.js' +import type { CitationsBlockData } from './citations.js' +import { CitationsBlock } from './citations.js' /** * Message types and content blocks for conversational AI interactions. @@ -115,6 +117,7 @@ export type ContentBlockData = | { image: ImageBlockData } | { video: VideoBlockData } | { document: DocumentBlockData } + | { citationsContent: CitationsBlockData } export type ContentBlock = | TextBlock @@ -126,6 +129,7 @@ export type ContentBlock = | ImageBlock | VideoBlock | DocumentBlock + | CitationsBlock /** * Data for a text block.
@@ -875,6 +879,8 @@ export function contentBlockFromData(data: ContentBlockData): ContentBlock { return VideoBlock.fromJSON(data) } else if ('document' in data) { return DocumentBlock.fromJSON(data) + } else if ('citationsContent' in data) { + return CitationsBlock.fromJSON(data) } else { throw new Error('Unknown ContentBlockData type') } diff --git a/test/integ/__fixtures__/model-providers.ts b/test/integ/__fixtures__/model-providers.ts index 20ed42d9..56e5292a 100644 --- a/test/integ/__fixtures__/model-providers.ts +++ b/test/integ/__fixtures__/model-providers.ts @@ -22,6 +22,7 @@ export interface ProviderFeatures { images: boolean documents: boolean video: boolean + citations: boolean } export const bedrock = { @@ -34,6 +35,7 @@ export const bedrock = { images: true, documents: true, video: true, + citations: true, } satisfies ProviderFeatures, models: { default: {}, @@ -68,6 +70,7 @@ export const openai = { images: true, documents: true, video: false, + citations: false, } satisfies ProviderFeatures, models: { default: {}, @@ -100,6 +103,7 @@ export const anthropic = { images: true, documents: true, video: false, + citations: false, } satisfies ProviderFeatures, models: { default: {}, @@ -139,6 +143,7 @@ export const gemini = { images: true, documents: true, video: true, + citations: false, } satisfies ProviderFeatures, models: { default: {}, diff --git a/test/integ/agent.test.ts b/test/integ/agent.test.ts index b38e5f5c..a21dedbe 100644 --- a/test/integ/agent.test.ts +++ b/test/integ/agent.test.ts @@ -1,6 +1,7 @@ import { describe, expect, it } from 'vitest' import { Agent, + CitationsBlock, DocumentBlock, ImageBlock, Message, @@ -262,6 +263,126 @@ describe.each(allProviders)('Agent with $name', ({ name, skip, createModel, mode expect(textContent?.text).toMatch(/yellow/i) }) + describe.skipIf(!supports.citations)('Citations', () => { + const documentText = [ + 'France is a country in Western Europe. 
Its capital is Paris, which is known as the City of Light.', + 'Paris has a population of approximately 2.1 million people in the city proper.', + 'The Eiffel Tower, built in 1889, is the most visited paid monument in the world.', + 'France is the most visited country in the world, with over 89 million tourists annually.', + 'The French Revolution of 1789 was a pivotal event in world history.', + ].join(' ') + + const textDocBlock = new DocumentBlock({ + name: 'test-document', + format: 'txt', + source: { content: [{ text: documentText }] }, + citations: { enabled: true }, + }) + + const textDocPrompt = new TextBlock( + 'Using the document, what is the capital of France and what is it known for? Cite specific details.' + ) + + it('returns documentChunk citations from text document', async () => { + const agent = new Agent({ + model: createModel({ stream: false }), + printer: false, + }) + + const result = await agent.invoke([textDocBlock, textDocPrompt]) + + expect(result.stopReason).toBe('endTurn') + + const citationsBlock = result.lastMessage.content.find( + (block): block is CitationsBlock => block.type === 'citationsBlock' + ) + expect(citationsBlock).toBeDefined() + expect(citationsBlock!.citations).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + location: expect.objectContaining({ type: 'documentChunk' }), + source: expect.any(String), + title: expect.any(String), + sourceContent: expect.arrayContaining([expect.objectContaining({ text: expect.any(String) })]), + }), + ]) + ) + expect(citationsBlock!.content).toEqual( + expect.arrayContaining([expect.objectContaining({ text: expect.any(String) })]) + ) + }) + + it('returns documentPage citations from PDF document and preserves them in multi-turn', async () => { + const pdfBytes = await loadFixture(letterPdfUrl) + + const agent = new Agent({ + model: createModel({ stream: false }), + printer: false, + }) + + const result = await agent.invoke([ + new DocumentBlock({ + name: 'letter', + format: 
'pdf', + source: { bytes: pdfBytes }, + citations: { enabled: true }, + }), + new TextBlock('Summarize this document briefly.'), + ]) + + expect(result.stopReason).toBe('endTurn') + + const citationsBlock = result.lastMessage.content.find( + (block): block is CitationsBlock => block.type === 'citationsBlock' + ) + expect(citationsBlock).toBeDefined() + expect(citationsBlock!.citations).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + location: expect.objectContaining({ type: 'documentPage' }), + source: expect.any(String), + title: expect.any(String), + sourceContent: expect.arrayContaining([expect.objectContaining({ text: expect.any(String) })]), + }), + ]) + ) + expect(citationsBlock!.content).toEqual( + expect.arrayContaining([expect.objectContaining({ text: expect.any(String) })]) + ) + + // Second turn: verify citations survive in conversation history + const followUp = await agent.invoke('What else can you tell me about this document?') + expect(followUp.stopReason).toBe('endTurn') + expect(followUp.lastMessage.role).toBe('assistant') + expect(followUp.lastMessage.content.length).toBeGreaterThan(0) + }) + + it('emits citationsContentDelta events via agent.stream()', async () => { + const agent = new Agent({ + model: createModel({ stream: false }), + printer: false, + }) + + const { items, result } = await collectGenerator(agent.stream([textDocBlock, textDocPrompt])) + + expect(result.stopReason).toBe('endTurn') + + const citationDeltas = items.filter( + (item) => + item.type === 'modelStreamUpdateEvent' && + item.event.type === 'modelContentBlockDeltaEvent' && + item.event.delta.type === 'citationsContentDelta' + ) + expect(citationDeltas.length).toBeGreaterThan(0) + + const citationsBlock = result.lastMessage.content.find( + (block): block is CitationsBlock => block.type === 'citationsBlock' + ) + expect(citationsBlock).toBeDefined() + expect(citationsBlock!.citations.length).toBeGreaterThan(0) + }) + }) + 
describe.skipIf(!supports.images)('multimodal input', () => { it('accepts ContentBlock[] input', async () => { const agent = new Agent({