diff --git a/package.json b/package.json index b5566a7..c82db80 100644 --- a/package.json +++ b/package.json @@ -22,6 +22,7 @@ "typescript": "^5.0.4" }, "dependencies": { + "arktype": "1.0.14-alpha", "zod": "^3.21.4" } } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 13a3f5c..eb79ea1 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -3,6 +3,7 @@ lockfileVersion: 5.4 specifiers: '@types/jest': ^29.5.1 '@types/node': ^20.2.4 + arktype: 1.0.14-alpha jest: ^29.5.0 openai: ^3.2.1 ts-jest: ^29.1.0 @@ -11,6 +12,7 @@ specifiers: zod: ^3.21.4 dependencies: + arktype: 1.0.14-alpha zod: 3.21.4 devDependencies: @@ -769,6 +771,11 @@ packages: sprintf-js: 1.0.3 dev: true + /arktype/1.0.14-alpha: + resolution: {integrity: sha512-theD5K4QrYCWMtQ52Masj169IgtMJ8Argld/MBS4lotEwR3b+GzfBvsqVJ1OIKhJDTdES02FLQTjcfRe00//mA==} + requiresBuild: true + dev: false + /asynckit/0.4.0: resolution: {integrity: sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==} dev: true diff --git a/src/PromptBuilder.ts b/src/PromptBuilder.ts index a6c83c5..0f93e62 100644 --- a/src/PromptBuilder.ts +++ b/src/PromptBuilder.ts @@ -1,4 +1,5 @@ import { z } from "zod"; +import { Type } from "arktype"; import { F } from "ts-toolbelt"; import { Prompt } from "./Prompt"; import { ExtractArgs, ReplaceArgs, TypeToZodShape } from "./types"; @@ -18,23 +19,13 @@ export class PromptBuilder< addZodInputValidation( shape: TypeToZodShape ) { - const zodValidator = z.object(shape as any); - return new (class extends PromptBuilder { - validate(args: Record): args is TShape { - return zodValidator.safeParse(args).success; - } - - get type() { - return this.template as ReplaceArgs; - } + return new ZodPromptBuilder(this.template, shape); + } - build( - args: F.Narrow - ) { - zodValidator.parse(args); - return new Prompt(this.template, args).toString(); - } - })(this.template); + addArkTypeInputValidation( + shape: Type + ) { + return new ArkTypePromptBuilder(this.template, shape); } validate(args: Record): args is TExpectedInput { @@ -52,3 +43,65 @@ export class PromptBuilder< return new Prompt(this.template, args).toString(); } } + +class ZodPromptBuilder< + TPromptTemplate extends string, + TExpectedInput extends ExtractArgs +> extends PromptBuilder { + constructor( + public template: TPromptTemplate, + public shape: TypeToZodShape + ) { + super(template); + } + validate(args: Record): args is TExpectedInput { + const zodValidator = z.object(this.shape as any); + return zodValidator.safeParse(args).success; + } + + get type() { + return this.template as ReplaceArgs; + } + + build( + args: F.Narrow + ) { + const zodValidator = z.object(this.shape as any); + zodValidator.parse(args); + return new Prompt(this.template, args).toString(); + } +} + +class ArkTypePromptBuilder< + TPromptTemplate extends string, + TExpectedInput extends ExtractArgs +> extends PromptBuilder { + constructor( + public template: TPromptTemplate, + public shape: Type + ) { + super(template); + } + validate(args: Record): args is TExpectedInput { + try { + this.shape(args); + return true; + } catch (e) { + return false; + } + } + + get type() { + return this.template as ReplaceArgs; + } + + build( + args: F.Narrow + ) { + const { problems } = this.shape(args); + if (problems?.summary) { + throw new Error(problems.summary); + } + return new Prompt(this.template, args).toString(); + } +} diff --git a/src/__tests__/PromptBuilder.test.ts b/src/__tests__/PromptBuilder.test.ts index d837194..bdcf474 100644 --- a/src/__tests__/PromptBuilder.test.ts +++ b/src/__tests__/PromptBuilder.test.ts @@ -1,5 +1,6 @@ import { strict as assert } from "node:assert"; import { z, ZodError } from "zod"; +import { type } from 'arktype'; import { PromptBuilder } from "../PromptBuilder"; import { Equal, Expect } from "./types.test"; @@ -282,7 +283,9 @@ describe("PromptBuilder with input validation using Zod", () => { ); type BasicType = typeof promptBuilder.type; // ^? - type BasicTest = Expect>; + type BasicTest = Expect< + Equal + >; const tsValidatedPromptBuilder = promptBuilder.addInputValidation<{ jokeType: "funny" | "silly"; @@ -305,34 +308,37 @@ describe("PromptBuilder with input validation using Zod", () => { }); test("Can write a function that accepts the type of a PromptBuilder then accepts any output from that builder", () => { - const promptBuilder = new PromptBuilder("Tell me a {{jokeType}} joke.").addInputValidation<{ - jokeType: "funny" | "silly" + const promptBuilder = new PromptBuilder( + "Tell me a {{jokeType}} joke." + ).addInputValidation<{ + jokeType: "funny" | "silly"; }>(); function exampleFunction(input: typeof promptBuilder.type) {} exampleFunction(promptBuilder.build({ jokeType: "funny" })); - exampleFunction("Tell me a funny joke.") + exampleFunction("Tell me a funny joke."); exampleFunction(promptBuilder.build({ jokeType: "silly" })); - exampleFunction("Tell me a silly joke.") + exampleFunction("Tell me a silly joke."); // @ts-expect-error exampleFunction(promptBuilder.build({ jokeType: "bad" })); // @ts-expect-error - exampleFunction("Tell me a bad joke.") - }) - + exampleFunction("Tell me a bad joke."); + }); test("Can write a function that accepts the type of a PromptBuilder then accepts any output from that builder", () => { - const promptBuilder = new PromptBuilder("Tell me a {{jokeType}} joke.").addZodInputValidation({ + const promptBuilder = new PromptBuilder( + "Tell me a {{jokeType}} joke." + ).addZodInputValidation({ jokeType: z.union([z.literal("funny"), z.literal("silly")]), }); function exampleFunction(input: typeof promptBuilder.type) {} exampleFunction(promptBuilder.build({ jokeType: "funny" })); - exampleFunction("Tell me a funny joke.") + exampleFunction("Tell me a funny joke."); exampleFunction(promptBuilder.build({ jokeType: "silly" })); - exampleFunction("Tell me a silly joke.") + exampleFunction("Tell me a silly joke."); // @ts-expect-error - exampleFunction("Tell me a bad joke.") + exampleFunction("Tell me a bad joke."); assert.throws( () => { // @ts-expect-error @@ -350,5 +356,32 @@ describe("PromptBuilder with input validation using Zod", () => { return true; } ); - }) + }); + + test("Can add arktypes validation", () => { + const promptBuilder = new PromptBuilder( + "Tell me a {{jokeType}} joke." + ).addArkTypeInputValidation( + type({ + jokeType: "'funny' | 'silly'", + }) + ); + function exampleFunction(input: typeof promptBuilder.type) {} + + exampleFunction(promptBuilder.build({ jokeType: "funny" })); + exampleFunction("Tell me a funny joke."); + exampleFunction(promptBuilder.build({ jokeType: "silly" })); + exampleFunction("Tell me a silly joke."); + // @ts-expect-error + exampleFunction("Tell me a bad joke."); + + assert.throws( + () => { + + // @ts-expect-error + exampleFunction(promptBuilder.build({ jokeType: "bad" })); + } + ); + + }); });