Skip to content
42 changes: 17 additions & 25 deletions app/mcp/route.ts
Original file line number Diff line number Diff line change
@@ -1,29 +1,23 @@
import { registerAllTools } from "@/lib/mcp/tools";
import { createMcpHandler } from "mcp-handler";
import { createMcpHandler, withMcpAuth } from "mcp-handler";
import { verifyApiKey } from "@/lib/mcp/verifyApiKey";

let handler: ReturnType<typeof createMcpHandler> | null = null;
const baseHandler = createMcpHandler(
server => {
registerAllTools(server);
},
{
serverInfo: {
name: "recoup-mcp",
version: "0.0.1",
},
},
);

/**
* Gets the MCP handler for the API.
*
* @returns The MCP handler.
*/
async function getHandler(): Promise<ReturnType<typeof createMcpHandler>> {
if (!handler) {
handler = createMcpHandler(
server => {
registerAllTools(server);
},
{
serverInfo: {
name: "recoup-mcp",
version: "0.0.1",
},
},
);
}
return handler;
}
// Wrap with auth - API key is required for all MCP requests
const handler = withMcpAuth(baseHandler, verifyApiKey, {
required: true,
});

/**
* GET handler for the MCP API.
Expand All @@ -32,7 +26,6 @@ async function getHandler(): Promise<ReturnType<typeof createMcpHandler>> {
* @returns The response from the MCP handler.
*/
export async function GET(req: Request) {
const handler = await getHandler();
return handler(req);
}

Expand All @@ -43,6 +36,5 @@ export async function GET(req: Request) {
* @returns The response from the MCP handler.
*/
export async function POST(req: Request) {
const handler = await getHandler();
return handler(req);
}
52 changes: 52 additions & 0 deletions lib/mcp/resolveAccountId.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import { canAccessAccount } from "@/lib/organizations/canAccessAccount";
import type { McpAuthInfo } from "@/lib/mcp/verifyApiKey";

export interface ResolveAccountIdParams {
authInfo: McpAuthInfo | undefined;
accountIdOverride: string | undefined;
}

export interface ResolveAccountIdResult {
accountId: string | null;
error: string | null;
}

/**
* Resolves the accountId from MCP auth info or an override parameter.
* Validates access when an org API key attempts to use an account_id override.
*
* @param params - The auth info and optional account_id override.
* @returns The resolved accountId or an error message.
*/
export async function resolveAccountId({
authInfo,
accountIdOverride,
}: ResolveAccountIdParams): Promise<ResolveAccountIdResult> {
const authAccountId = authInfo?.extra?.accountId;
const authOrgId = authInfo?.extra?.orgId;

if (authAccountId) {
// If account_id override is provided, validate access (for org API keys)
if (accountIdOverride && accountIdOverride !== authAccountId) {
const hasAccess = await canAccessAccount({
orgId: authOrgId,
targetAccountId: accountIdOverride,
});
if (!hasAccess) {
return { accountId: null, error: "Access denied to specified account_id" };
}
return { accountId: accountIdOverride, error: null };
}
return { accountId: authAccountId, error: null };
}

if (accountIdOverride) {
return { accountId: accountIdOverride, error: null };
}

return {
accountId: null,
error:
"Authentication required. Provide an API key via Authorization: Bearer header, or provide account_id from the system prompt context.",
};
}
189 changes: 166 additions & 23 deletions lib/mcp/tools/artists/__tests__/registerCreateNewArtistTool.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js";
import type { ServerRequest, ServerNotification } from "@modelcontextprotocol/sdk/types.js";

const mockCreateArtistInDb = vi.fn();
const mockCopyRoom = vi.fn();
const mockCanAccessAccount = vi.fn();

vi.mock("@/lib/artists/createArtistInDb", () => ({
createArtistInDb: (...args: unknown[]) => mockCreateArtistInDb(...args),
Expand All @@ -12,11 +15,39 @@ vi.mock("@/lib/rooms/copyRoom", () => ({
copyRoom: (...args: unknown[]) => mockCopyRoom(...args),
}));

vi.mock("@/lib/organizations/canAccessAccount", () => ({
canAccessAccount: (...args: unknown[]) => mockCanAccessAccount(...args),
}));

import { registerCreateNewArtistTool } from "../registerCreateNewArtistTool";

type ServerRequestHandlerExtra = RequestHandlerExtra<ServerRequest, ServerNotification>;

/**
* Creates a mock extra object with optional authInfo.
*/
function createMockExtra(authInfo?: {
accountId?: string;
orgId?: string | null;
}): ServerRequestHandlerExtra {
return {
authInfo: authInfo
? {
token: "test-token",
scopes: ["mcp:tools"],
clientId: authInfo.accountId,
extra: {
accountId: authInfo.accountId,
orgId: authInfo.orgId ?? null,
},
}
: undefined,
} as unknown as ServerRequestHandlerExtra;
}

describe("registerCreateNewArtistTool", () => {
let mockServer: McpServer;
let registeredHandler: (args: unknown) => Promise<unknown>;
let registeredHandler: (args: unknown, extra: ServerRequestHandlerExtra) => Promise<unknown>;

beforeEach(() => {
vi.clearAllMocks();
Expand All @@ -40,7 +71,7 @@ describe("registerCreateNewArtistTool", () => {
);
});

it("creates an artist and returns success", async () => {
it("creates an artist and returns success with account_id", async () => {
const mockArtist = {
id: "artist-123",
account_id: "artist-123",
Expand All @@ -50,10 +81,13 @@ describe("registerCreateNewArtistTool", () => {
};
mockCreateArtistInDb.mockResolvedValue(mockArtist);

const result = await registeredHandler({
name: "Test Artist",
account_id: "owner-456",
});
const result = await registeredHandler(
{
name: "Test Artist",
account_id: "owner-456",
},
createMockExtra(),
);

expect(mockCreateArtistInDb).toHaveBeenCalledWith("Test Artist", "owner-456", undefined);
expect(result).toEqual({
Expand All @@ -66,6 +100,34 @@ describe("registerCreateNewArtistTool", () => {
});
});

it("creates an artist using auth info accountId", async () => {
const mockArtist = {
id: "artist-123",
account_id: "artist-123",
name: "Test Artist",
account_info: [{ image: null }],
account_socials: [],
};
mockCreateArtistInDb.mockResolvedValue(mockArtist);

const result = await registeredHandler(
{
name: "Test Artist",
},
createMockExtra({ accountId: "auth-account-123" }),
);

expect(mockCreateArtistInDb).toHaveBeenCalledWith("Test Artist", "auth-account-123", undefined);
expect(result).toEqual({
content: [
{
type: "text",
text: expect.stringContaining("Successfully created artist"),
},
],
});
});

it("copies room when active_conversation_id is provided", async () => {
const mockArtist = {
id: "artist-123",
Expand All @@ -77,11 +139,14 @@ describe("registerCreateNewArtistTool", () => {
mockCreateArtistInDb.mockResolvedValue(mockArtist);
mockCopyRoom.mockResolvedValue("new-room-789");

const result = await registeredHandler({
name: "Test Artist",
account_id: "owner-456",
active_conversation_id: "source-room-111",
});
const result = await registeredHandler(
{
name: "Test Artist",
account_id: "owner-456",
active_conversation_id: "source-room-111",
},
createMockExtra(),
);

expect(mockCopyRoom).toHaveBeenCalledWith("source-room-111", "artist-123");
expect(result).toEqual({
Expand All @@ -104,22 +169,28 @@ describe("registerCreateNewArtistTool", () => {
};
mockCreateArtistInDb.mockResolvedValue(mockArtist);

await registeredHandler({
name: "Test Artist",
account_id: "owner-456",
organization_id: "org-999",
});
await registeredHandler(
{
name: "Test Artist",
account_id: "owner-456",
organization_id: "org-999",
},
createMockExtra(),
);

expect(mockCreateArtistInDb).toHaveBeenCalledWith("Test Artist", "owner-456", "org-999");
});

it("returns error when artist creation fails", async () => {
mockCreateArtistInDb.mockResolvedValue(null);

const result = await registeredHandler({
name: "Test Artist",
account_id: "owner-456",
});
const result = await registeredHandler(
{
name: "Test Artist",
account_id: "owner-456",
},
createMockExtra(),
);

expect(result).toEqual({
content: [
Expand All @@ -134,16 +205,88 @@ describe("registerCreateNewArtistTool", () => {
it("returns error with message when exception is thrown", async () => {
mockCreateArtistInDb.mockRejectedValue(new Error("Database connection failed"));

const result = await registeredHandler({
const result = await registeredHandler(
{
name: "Test Artist",
account_id: "owner-456",
},
createMockExtra(),
);

expect(result).toEqual({
content: [
{
type: "text",
text: expect.stringContaining("Database connection failed"),
},
],
});
});

it("allows account_id override for org auth with access", async () => {
mockCanAccessAccount.mockResolvedValue(true);
const mockArtist = {
id: "artist-123",
account_id: "artist-123",
name: "Test Artist",
account_id: "owner-456",
account_info: [{ image: null }],
account_socials: [],
};
mockCreateArtistInDb.mockResolvedValue(mockArtist);

await registeredHandler(
{
name: "Test Artist",
account_id: "target-account-456",
},
createMockExtra({ accountId: "org-account-id", orgId: "org-account-id" }),
);

expect(mockCanAccessAccount).toHaveBeenCalledWith({
orgId: "org-account-id",
targetAccountId: "target-account-456",
});
expect(mockCreateArtistInDb).toHaveBeenCalledWith(
"Test Artist",
"target-account-456",
undefined,
);
});

it("returns error when org auth lacks access to account_id", async () => {
mockCanAccessAccount.mockResolvedValue(false);

const result = await registeredHandler(
{
name: "Test Artist",
account_id: "target-account-456",
},
createMockExtra({ accountId: "org-account-id", orgId: "org-account-id" }),
);

expect(result).toEqual({
content: [
{
type: "text",
text: expect.stringContaining("Database connection failed"),
text: expect.stringContaining("Access denied to specified account_id"),
},
],
});
});

it("returns error when neither auth nor account_id is provided", async () => {
const result = await registeredHandler(
{
name: "Test Artist",
},
createMockExtra(),
);

expect(result).toEqual({
content: [
{
type: "text",
text: expect.stringContaining("Authentication required"),
},
],
});
Expand Down
Loading