From 6b46b248da286251ba4cc13546389199a474e25f Mon Sep 17 00:00:00 2001 From: Nicholas Clegg Date: Wed, 18 Feb 2026 10:26:35 -0500 Subject: [PATCH] fix: make callback the first parameter for adding hook callback --- docs/TESTING.md | 8 +-- src/__fixtures__/mock-hook-provider.ts | 4 +- src/agent/__tests__/agent.hook.test.ts | 24 +++---- .../sliding-window-conversation-manager.ts | 8 +-- src/hooks/__tests__/registry.test.ts | 72 +++++++++---------- src/hooks/registry.ts | 8 +-- src/hooks/types.ts | 4 +- 7 files changed, 64 insertions(+), 64 deletions(-) diff --git a/docs/TESTING.md b/docs/TESTING.md index f56e7445..88a4d8cc 100644 --- a/docs/TESTING.md +++ b/docs/TESTING.md @@ -321,12 +321,12 @@ When testing hook behavior, you **MUST** use `agent.hooks.addCallback()` for reg // ✅ CORRECT - Use agent.hooks.addCallback() for single callbacks const agent = new Agent({ model, tools: [tool] }) -agent.hooks.addCallback(BeforeToolCallEvent, (event: BeforeToolCallEvent) => { +agent.hooks.addCallback((event: BeforeToolCallEvent) => { event.toolUse = { ...event.toolUse, input: { value: 42 }, } -}) +}, BeforeToolCallEvent) // ✅ CORRECT - Use MockHookProvider to record and verify hook invocations const hookProvider = new MockHookProvider() @@ -337,11 +337,11 @@ expect(hookProvider.invocations).toContainEqual(new BeforeInvocationEvent({ agen // ❌ WRONG - Do NOT create inline HookProvider objects const switchToolHook = { registerCallbacks: (registry: HookRegistry) => { - registry.addCallback(BeforeToolCallEvent, (event: BeforeToolCallEvent) => { + registry.addCallback((event: BeforeToolCallEvent) => { if (event.toolUse.name === 'tool1') { event.tool = tool2 } - }) + }, BeforeToolCallEvent) }, } ``` diff --git a/src/__fixtures__/mock-hook-provider.ts b/src/__fixtures__/mock-hook-provider.ts index 0ce8437b..580eb7b5 100644 --- a/src/__fixtures__/mock-hook-provider.ts +++ b/src/__fixtures__/mock-hook-provider.ts @@ -40,9 +40,9 @@ export class MockHookProvider implements HookProvider { const eventTypes = this.includeModelEvents ? [...lifecycleEvents, ...modelEvents] : lifecycleEvents for (const eventType of eventTypes) { - registry.addCallback(eventType, (e) => { + registry.addCallback((e) => { this.invocations.push(e) - }) + }, eventType) } } diff --git a/src/agent/__tests__/agent.hook.test.ts b/src/agent/__tests__/agent.hook.test.ts index ed9f7cca..ff57de71 100644 --- a/src/agent/__tests__/agent.hook.test.ts +++ b/src/agent/__tests__/agent.hook.test.ts @@ -306,12 +306,12 @@ describe('Agent Hooks Integration', () => { .addTurn({ type: 'textBlock', text: 'Success after retry' }) const agent = new Agent({ model }) - agent.hooks.addCallback(AfterModelCallEvent, (event: AfterModelCallEvent) => { + agent.hooks.addCallback((event: AfterModelCallEvent) => { callCount++ if (callCount === 1 && event.error) { event.retry = true } - }) + }, AfterModelCallEvent) const result = await agent.invoke('Test') @@ -333,12 +333,12 @@ describe('Agent Hooks Integration', () => { .addTurn({ type: 'textBlock', text: 'Second response after retry' }) const agent = new Agent({ model }) - agent.hooks.addCallback(AfterModelCallEvent, (event: AfterModelCallEvent) => { + agent.hooks.addCallback((event: AfterModelCallEvent) => { callCount++ if (callCount === 1 && !event.error) { event.retry = true } - }) + }, AfterModelCallEvent) const result = await agent.invoke('Test') @@ -364,12 +364,12 @@ describe('Agent Hooks Integration', () => { .addTurn({ type: 'textBlock', text: 'Done' }) const agent = new Agent({ model, tools: [tool] }) - agent.hooks.addCallback(AfterToolCallEvent, (event: AfterToolCallEvent) => { + agent.hooks.addCallback((event: AfterToolCallEvent) => { hookCallCount++ if (hookCallCount === 1 && event.error) { event.retry = true } - }) + }, AfterToolCallEvent) const result = await agent.invoke('Test') @@ -414,15 +414,15 @@ describe('Agent Hooks Integration', () => { .addTurn({ type: 'textBlock', text: 'Done' }) const agent = new Agent({ model, tools: [tool] }) - agent.hooks.addCallback(BeforeToolCallEvent, () => { + agent.hooks.addCallback(() => { beforeCount++ - }) - agent.hooks.addCallback(AfterToolCallEvent, (event: AfterToolCallEvent) => { + }, BeforeToolCallEvent) + agent.hooks.addCallback((event: AfterToolCallEvent) => { afterCount++ if (afterCount === 1) { event.retry = true } - }) + }, AfterToolCallEvent) await agent.invoke('Test') @@ -448,12 +448,12 @@ describe('Agent Hooks Integration', () => { .addTurn({ type: 'textBlock', text: 'Done' }) const agent = new Agent({ model, tools: [tool] }) - agent.hooks.addCallback(AfterToolCallEvent, (event: AfterToolCallEvent) => { + agent.hooks.addCallback((event: AfterToolCallEvent) => { hookCallCount++ if (hookCallCount === 1) { event.retry = true } - }) + }, AfterToolCallEvent) const result = await agent.invoke('Test') diff --git a/src/conversation-manager/sliding-window-conversation-manager.ts b/src/conversation-manager/sliding-window-conversation-manager.ts index 3f2b234f..d6a26b3b 100644 --- a/src/conversation-manager/sliding-window-conversation-manager.ts +++ b/src/conversation-manager/sliding-window-conversation-manager.ts @@ -65,17 +65,17 @@ export class SlidingWindowConversationManager implements HookProvider { */ public registerCallbacks(registry: HookRegistry): void { // Apply sliding window management after each invocation - registry.addCallback(AfterInvocationEvent, (event) => { + registry.addCallback((event) => { this.applyManagement(event.agent.messages) - }) + }, AfterInvocationEvent) // Handle context overflow errors - registry.addCallback(AfterModelCallEvent, (event) => { + registry.addCallback((event) => { if (event.error instanceof ContextWindowOverflowError) { this.reduceContext(event.agent.messages, event.error) event.retry = true } - }) + }, AfterModelCallEvent) } /** diff --git a/src/hooks/__tests__/registry.test.ts b/src/hooks/__tests__/registry.test.ts index b3022024..05b91884 100644 --- a/src/hooks/__tests__/registry.test.ts +++ b/src/hooks/__tests__/registry.test.ts @@ -16,7 +16,7 @@ describe('HookRegistryImplementation', () => { describe('addCallback', () => { it('registers callback for event type', async () => { const callback = vi.fn() - registry.addCallback(BeforeInvocationEvent, callback) + registry.addCallback(callback, BeforeInvocationEvent) await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent })) @@ -27,8 +27,8 @@ describe('HookRegistryImplementation', () => { const callback1 = vi.fn() const callback2 = vi.fn() - registry.addCallback(BeforeInvocationEvent, callback1) - registry.addCallback(BeforeInvocationEvent, callback2) + registry.addCallback(callback1, BeforeInvocationEvent) + registry.addCallback(callback2, BeforeInvocationEvent) await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent })) @@ -40,8 +40,8 @@ describe('HookRegistryImplementation', () => { const beforeCallback = vi.fn() const afterCallback = vi.fn() - registry.addCallback(BeforeInvocationEvent, beforeCallback) - registry.addCallback(AfterInvocationEvent, afterCallback) + registry.addCallback(beforeCallback, BeforeInvocationEvent) + registry.addCallback(afterCallback, AfterInvocationEvent) await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent })) @@ -61,8 +61,8 @@ describe('HookRegistryImplementation', () => { const provider: HookProvider = { registerCallbacks: (reg) => { - reg.addCallback(BeforeInvocationEvent, beforeCallback) - reg.addCallback(AfterInvocationEvent, afterCallback) + reg.addCallback(beforeCallback, BeforeInvocationEvent) + reg.addCallback(afterCallback, AfterInvocationEvent) }, } @@ -87,7 +87,7 @@ describe('HookRegistryImplementation', () => { // Verify _currentProvider is cleared by registering another provider successfully const workingProvider: HookProvider = { registerCallbacks: (reg) => { - reg.addCallback(BeforeInvocationEvent, vi.fn()) + reg.addCallback(vi.fn(), BeforeInvocationEvent) }, } @@ -105,8 +105,8 @@ describe('HookRegistryImplementation', () => { callOrder.push(2) }) - registry.addCallback(BeforeInvocationEvent, callback1) - registry.addCallback(BeforeInvocationEvent, callback2) + registry.addCallback(callback1, BeforeInvocationEvent) + registry.addCallback(callback2, BeforeInvocationEvent) await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent })) @@ -122,8 +122,8 @@ describe('HookRegistryImplementation', () => { callOrder.push(2) }) - registry.addCallback(AfterInvocationEvent, callback1) - registry.addCallback(AfterInvocationEvent, callback2) + registry.addCallback(callback1, AfterInvocationEvent) + registry.addCallback(callback2, AfterInvocationEvent) await registry.invokeCallbacks(new AfterInvocationEvent({ agent: mockAgent })) @@ -137,7 +137,7 @@ describe('HookRegistryImplementation', () => { completed = true }) - registry.addCallback(BeforeInvocationEvent, callback) + registry.addCallback(callback, BeforeInvocationEvent) await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent })) @@ -149,7 +149,7 @@ describe('HookRegistryImplementation', () => { throw new Error('Hook failed') }) - registry.addCallback(BeforeInvocationEvent, callback) + registry.addCallback(callback, BeforeInvocationEvent) await expect(registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent }))).rejects.toThrow( 'Hook failed' @@ -162,8 +162,8 @@ describe('HookRegistryImplementation', () => { }) const callback2 = vi.fn() - registry.addCallback(BeforeInvocationEvent, callback1) - registry.addCallback(BeforeInvocationEvent, callback2) + registry.addCallback(callback1, BeforeInvocationEvent) + registry.addCallback(callback2, BeforeInvocationEvent) await expect(registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent }))).rejects.toThrow( 'First callback failed' @@ -182,8 +182,8 @@ describe('HookRegistryImplementation', () => { callOrder.push('async') }) - registry.addCallback(BeforeInvocationEvent, syncCallback) - registry.addCallback(BeforeInvocationEvent, asyncCallback) + registry.addCallback(syncCallback, BeforeInvocationEvent) + registry.addCallback(asyncCallback, BeforeInvocationEvent) await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent })) @@ -201,7 +201,7 @@ describe('HookRegistryImplementation', () => { it('returns cleanup function that removes the callback', async () => { const callback = vi.fn() - const cleanup = registry.addCallback(BeforeInvocationEvent, callback) + const cleanup = registry.addCallback(callback, BeforeInvocationEvent) cleanup() await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent })) @@ -212,7 +212,7 @@ describe('HookRegistryImplementation', () => { it('cleanup function is idempotent', async () => { const callback = vi.fn() - const cleanup = registry.addCallback(BeforeInvocationEvent, callback) + const cleanup = registry.addCallback(callback, BeforeInvocationEvent) cleanup() cleanup() cleanup() @@ -226,8 +226,8 @@ describe('HookRegistryImplementation', () => { const callback1 = vi.fn() const callback2 = vi.fn() - const cleanup1 = registry.addCallback(BeforeInvocationEvent, callback1) - registry.addCallback(BeforeInvocationEvent, callback2) + const cleanup1 = registry.addCallback(callback1, BeforeInvocationEvent) + registry.addCallback(callback2, BeforeInvocationEvent) cleanup1() await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent })) @@ -241,7 +241,7 @@ describe('HookRegistryImplementation', () => { const provider: HookProvider = { registerCallbacks: (reg) => { - reg.addCallback(BeforeInvocationEvent, callback) + reg.addCallback(callback, BeforeInvocationEvent) }, } @@ -261,8 +261,8 @@ describe('HookRegistryImplementation', () => { const provider: HookProvider = { registerCallbacks: (reg) => { - reg.addCallback(BeforeInvocationEvent, beforeCallback) - reg.addCallback(AfterInvocationEvent, afterCallback) + reg.addCallback(beforeCallback, BeforeInvocationEvent) + reg.addCallback(afterCallback, AfterInvocationEvent) }, } @@ -281,7 +281,7 @@ describe('HookRegistryImplementation', () => { const provider: HookProvider = { registerCallbacks: (reg) => { - reg.addCallback(BeforeInvocationEvent, callback) + reg.addCallback(callback, BeforeInvocationEvent) }, } @@ -299,7 +299,7 @@ describe('HookRegistryImplementation', () => { const provider1: HookProvider = { registerCallbacks: (reg) => { - reg.addCallback(BeforeInvocationEvent, callback) + reg.addCallback(callback, BeforeInvocationEvent) }, } @@ -321,13 +321,13 @@ describe('HookRegistryImplementation', () => { const provider1: HookProvider = { registerCallbacks: (reg) => { - reg.addCallback(BeforeInvocationEvent, callback1) + reg.addCallback(callback1, BeforeInvocationEvent) }, } const provider2: HookProvider = { registerCallbacks: (reg) => { - reg.addCallback(BeforeInvocationEvent, callback2) + reg.addCallback(callback2, BeforeInvocationEvent) }, } @@ -347,11 +347,11 @@ describe('HookRegistryImplementation', () => { const provider: HookProvider = { registerCallbacks: (reg) => { - reg.addCallback(BeforeInvocationEvent, providerCallback) + reg.addCallback(providerCallback, BeforeInvocationEvent) }, } - registry.addCallback(BeforeInvocationEvent, directCallback) + registry.addCallback(directCallback, BeforeInvocationEvent) registry.addHook(provider) registry.removeHook(provider) @@ -366,7 +366,7 @@ describe('HookRegistryImplementation', () => { const provider: HookProvider = { registerCallbacks: (reg) => { - reg.addCallback(BeforeInvocationEvent, callback) + reg.addCallback(callback, BeforeInvocationEvent) }, } @@ -391,16 +391,16 @@ describe('HookRegistryImplementation', () => { const provider: HookProvider = { registerCallbacks: (reg) => { - reg.addCallback(BeforeInvocationEvent, callback1) - reg.addCallback(BeforeInvocationEvent, callback2) + reg.addCallback(callback1, BeforeInvocationEvent) + reg.addCallback(callback2, BeforeInvocationEvent) }, } registry.addHook(provider) registry.removeHook(provider) - const cleanup = registry.addCallback(BeforeInvocationEvent, callback1) - registry.addCallback(BeforeInvocationEvent, callback2) + const cleanup = registry.addCallback(callback1, BeforeInvocationEvent) + registry.addCallback(callback2, BeforeInvocationEvent) cleanup() await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent })) diff --git a/src/hooks/registry.ts b/src/hooks/registry.ts index feb83a43..cd9de881 100644 --- a/src/hooks/registry.ts +++ b/src/hooks/registry.ts @@ -17,11 +17,11 @@ export interface HookRegistry { /** * Register a callback function for a specific event type. * - * @param eventType - The event class constructor to register the callback for * @param callback - The callback function to invoke when the event occurs + * @param eventType - The event class constructor to register the callback for * @returns Cleanup function that removes the callback when invoked */ - addCallback(eventType: HookEventConstructor, callback: HookCallback): HookCleanup + addCallback(callback: HookCallback, eventType: HookEventConstructor): HookCleanup /** * Register all callbacks from a hook provider. @@ -54,11 +54,11 @@ export class HookRegistryImplementation implements HookRegistry { /** * Register a callback function for a specific event type. * - * @param eventType - The event class constructor to register the callback for * @param callback - The callback function to invoke when the event occurs + * @param eventType - The event class constructor to register the callback for * @returns Cleanup function that removes the callback when invoked */ - addCallback(eventType: HookEventConstructor, callback: HookCallback): HookCleanup { + addCallback(callback: HookCallback, eventType: HookEventConstructor): HookCleanup { const entry: CallbackEntry = { callback: callback as HookCallback, source: this._currentProvider } const callbacks = this._callbacks.get(eventType) ?? [] callbacks.push(entry) diff --git a/src/hooks/types.ts b/src/hooks/types.ts index fb7d7453..da7d15b6 100644 --- a/src/hooks/types.ts +++ b/src/hooks/types.ts @@ -35,8 +35,8 @@ export type HookCleanup = () => void * ```typescript * class MyHooks implements HookProvider { * registerCallbacks(registry: HookRegistry): void { - * registry.addCallback(BeforeInvocationEvent, this.onStart) - * registry.addCallback(AfterInvocationEvent, this.onEnd) + * registry.addCallback(this.onStart, BeforeInvocationEvent) + * registry.addCallback(this.onEnd, AfterInvocationEvent) * } * * private onStart = (event: BeforeInvocationEvent): void => {