diff --git a/src/agents/createFaceDetectionAgent.ts b/src/agents/createFaceDetectionAgent.ts new file mode 100644 index 0000000..9bad54f --- /dev/null +++ b/src/agents/createFaceDetectionAgent.ts @@ -0,0 +1,35 @@ +import { ToolLoopAgent, Output, stepCountIs } from "ai"; +import { z } from "zod"; + +const faceDetectionSchema = z.object({ + hasFace: z.boolean(), +}); + +const instructions = `You classify images as face guides or not. + +A face guide is a headshot or portrait photo on a plain or white background, used for face-swapping in AI image generation. It shows a person's face clearly as the primary subject. + +These are NOT face guides: +- Playlist covers or album art (even if they show a person) +- Promotional graphics with text overlays +- Concert photos or action shots +- Logos or branded images +- Any image where the face is not the sole focus on a clean background + +Return hasFace: true ONLY for face guide images (headshots on plain backgrounds). +Return hasFace: false for everything else.`; + +/** + * Creates a ToolLoopAgent configured for face guide detection in images. + * Uses Output.object with a Zod schema for structured boolean response. + * + * @returns A configured ToolLoopAgent using Google Gemini via AI Gateway. 
+ */ +export function createFaceDetectionAgent() { + return new ToolLoopAgent({ + model: "google/gemini-3.1-flash-lite-preview", + instructions, + output: Output.object({ schema: faceDetectionSchema }), + stopWhen: stepCountIs(1), + }); +} diff --git a/src/content/__tests__/detectFace.test.ts b/src/content/__tests__/detectFace.test.ts index 6a8bf5b..0b4c84d 100644 --- a/src/content/__tests__/detectFace.test.ts +++ b/src/content/__tests__/detectFace.test.ts @@ -4,9 +4,11 @@ vi.mock("../../sandboxes/logStep", () => ({ logStep: vi.fn(), })); -const mockFalSubscribe = vi.fn(); -vi.mock("../falSubscribe", () => ({ - falSubscribe: (...args: unknown[]) => mockFalSubscribe(...args), +const mockGenerate = vi.fn(); +vi.mock("../../agents/createFaceDetectionAgent", () => ({ + createFaceDetectionAgent: () => ({ + generate: mockGenerate, + }), })); import { detectFace } from "../detectFace"; @@ -16,96 +18,56 @@ describe("detectFace", () => { vi.clearAllMocks(); }); - it("returns true when a person label is detected", async () => { - mockFalSubscribe.mockResolvedValue({ - data: { - results: { - bboxes: [[10, 20, 100, 200]], - labels: ["person"], - }, - }, - }); + it("returns true when the agent detects a face guide", async () => { + mockGenerate.mockResolvedValue({ output: { hasFace: true } }); const result = await detectFace("https://example.com/headshot.png"); expect(result).toBe(true); - expect(mockFalSubscribe).toHaveBeenCalledWith( - "fal-ai/florence-2-large/object-detection", - { image_url: "https://example.com/headshot.png" }, - ); }); - it("returns true when a face label is detected among other objects", async () => { - mockFalSubscribe.mockResolvedValue({ - data: { - results: { - bboxes: [[0, 0, 50, 50], [10, 20, 100, 200]], - labels: ["chair", "human face"], - }, - }, - }); - - const result = await detectFace("https://example.com/photo.png"); - - expect(result).toBe(true); - }); - - it("returns false when no person or face labels are detected", async () => { - 
mockFalSubscribe.mockResolvedValue({ - data: { - results: { - bboxes: [[0, 0, 300, 300]], - labels: ["album cover"], - }, - }, - }); + it("returns false when the agent detects no face guide", async () => { + mockGenerate.mockResolvedValue({ output: { hasFace: false } }); const result = await detectFace("https://example.com/album-cover.png"); expect(result).toBe(false); }); - it("returns false when results are empty", async () => { - mockFalSubscribe.mockResolvedValue({ - data: { - results: { - bboxes: [], - labels: [], - }, - }, - }); + it("sends a few-shot example with the face guide reference image", async () => { + mockGenerate.mockResolvedValue({ output: { hasFace: true } }); - const result = await detectFace("https://example.com/blank.png"); + await detectFace("https://example.com/photo.png"); - expect(result).toBe(false); - }); + const callArgs = mockGenerate.mock.calls[0][0]; + const messages = callArgs.messages; - it("returns false when detection fails", async () => { - mockFalSubscribe.mockRejectedValue(new Error("Detection failed")); + // First message: example face guide image URL + question + expect(messages[0].role).toBe("user"); + const exampleImagePart = messages[0].content.find((p: { type: string }) => p.type === "image"); + expect(exampleImagePart).toBeDefined(); + expect(exampleImagePart.image).toContain("face-guide-example.png"); - const result = await detectFace("https://example.com/broken.png"); + // Second message: assistant answer for the example + expect(messages[1].role).toBe("assistant"); - expect(result).toBe(false); + // Third message: actual image to classify + expect(messages[2].role).toBe("user"); + const targetImagePart = messages[2].content.find((p: { type: string }) => p.type === "image"); + expect(targetImagePart.image).toBe("https://example.com/photo.png"); }); - it("does not false-positive on labels containing face words as substrings", async () => { - mockFalSubscribe.mockResolvedValue({ - data: { - results: { - bboxes: [[0, 0, 
200, 200]], - labels: ["ottoman", "mannequin", "womanizer"], - }, - }, - }); + it("returns false when the agent throws", async () => { + mockGenerate.mockRejectedValue(new Error("Model error")); - const result = await detectFace("https://example.com/furniture.png"); + const result = await detectFace("https://example.com/broken.png"); expect(result).toBe(false); }); it("logs the error when detection fails", async () => { const { logStep } = await import("../../sandboxes/logStep"); - mockFalSubscribe.mockRejectedValue(new Error("Rate limit exceeded")); + mockGenerate.mockRejectedValue(new Error("Rate limit exceeded")); await detectFace("https://example.com/broken.png"); @@ -115,4 +77,12 @@ describe("detectFace", () => { expect.objectContaining({ error: "Rate limit exceeded" }), ); }); + + it("returns false when output is null", async () => { + mockGenerate.mockResolvedValue({ output: null }); + + const result = await detectFace("https://example.com/broken.png"); + + expect(result).toBe(false); + }); }); diff --git a/src/content/detectFace.ts b/src/content/detectFace.ts index 4715686..a37ebb4 100644 --- a/src/content/detectFace.ts +++ b/src/content/detectFace.ts @@ -1,34 +1,47 @@ import { logStep } from "../sandboxes/logStep"; -import { falSubscribe } from "./falSubscribe"; +import { createFaceDetectionAgent } from "../agents/createFaceDetectionAgent"; -const DETECTION_MODEL = "fal-ai/florence-2-large/object-detection"; - -/** Labels that indicate a human face or person is present in the image. */ -const FACE_LABELS = ["person", "face", "human face", "man", "woman", "boy", "girl"]; +const FACE_GUIDE_EXAMPLE_URL = + "https://dxfamqbi5zyezrs5.public.blob.vercel-storage.com/content-attachments/image/1775671967694-face-guide-example.png"; /** - * Detects whether an image contains a human face using Florence-2 object detection. 
+ * Detects whether an image is a face guide (headshot/portrait on a plain background) + * rather than a playlist cover, album art, or other image that may incidentally contain a face. + * + * Uses a few-shot approach: shows the model an example face guide first, then asks + * it to classify the target image. * * @param imageUrl - URL of the image to analyze - * @returns true if at least one face/person is detected, false otherwise + * @returns true if the image is a face guide, false otherwise */ export async function detectFace(imageUrl: string): Promise<boolean> { try { - const result = await falSubscribe(DETECTION_MODEL, { - image_url: imageUrl, + const agent = createFaceDetectionAgent(); + const { output } = await agent.generate({ + messages: [ + { + role: "user", + content: [ + { type: "image", image: FACE_GUIDE_EXAMPLE_URL }, + { type: "text", text: "This is an example of a face guide — a headshot or portrait on a plain/white background used for face-swapping. Is this a face guide?" }, + ], + }, + { + role: "assistant", + content: [{ type: "text", text: '{"hasFace":true}' }], + }, + { + role: "user", + content: [ + { type: "image", image: imageUrl }, + { type: "text", text: "Is this image a face guide like the example above? A face guide is a headshot or portrait on a plain background. Playlist covers, album art, promotional graphics, and other images that happen to show a face are NOT face guides." }, + ], + }, + ], }); - const data = result.data as Record; - const results = data.results as { labels?: string[] } | undefined; - const labels = results?.labels ?? []; - - const hasFace = labels.some((label) => { - const lower = label.toLowerCase(); - return FACE_LABELS.some( - (faceLabel) => lower === faceLabel || lower.split(" ").includes(faceLabel), - ); - }); - logStep("Face detection result", false, { imageUrl: imageUrl.slice(0, 80), hasFace, labels }); + const hasFace = output?.hasFace ?? 
false; + logStep("Face detection result", false, { imageUrl: imageUrl.slice(0, 80), hasFace }); return hasFace; } catch (err) { logStep("Face detection failed, assuming no face", false, {