From cb84c4900cc98b8c78cbc9cacf3e3f0c08fe93ed Mon Sep 17 00:00:00 2001 From: Irrelevant Date: Thu, 26 Feb 2026 09:11:09 -0600 Subject: [PATCH] feat: add reusable library dataset generation API --- package-lock.json | 4 +- src/index.ts | 483 +++++++++++++++++++++---------------------- test/datagen.test.ts | 61 +++++- 3 files changed, 294 insertions(+), 254 deletions(-) diff --git a/package-lock.json b/package-lock.json index 959240e..ea76db6 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@teichai/datagen", - "version": "0.1.11", + "version": "0.1.12", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@teichai/datagen", - "version": "0.1.11", + "version": "0.1.12", "license": "Apache-2.0", "bin": { "datagen": "dist/cli.js" diff --git a/src/index.ts b/src/index.ts index 0a2b63f..88b4de1 100644 --- a/src/index.ts +++ b/src/index.ts @@ -399,32 +399,20 @@ export async function ensureReadableFile(filePath: string) { await fs.access(filePath); } -export async function main(argv = process.argv.slice(2)) { - await maybeNotifyNewVersion({ - cliName: CLI_NAME, - packageName: CLI_PACKAGE_NAME, - currentVersion: CLI_VERSION - }); - - if (hasFlag(argv, FLAG_ALIASES.help)) { - printHelp(); - return; - } - - if (hasFlag(argv, FLAG_ALIASES.version)) { - printVersion(); - return; - } +export type GenerateDatasetOptions = Args & { + apiKey: string; + stderr?: NodeJS.WriteStream; +}; - let parsed: Args; - try { - parsed = parseArgs(argv); - } catch (err: any) { - console.error(err?.message ?? String(err)); - process.exit(1); - return; - } +export type GenerateDatasetResult = { + outPath: string; + completed: number; + okCount: number; + errCount: number; + spentUsd: number; +}; +export async function generateDataset(options: GenerateDatasetOptions): Promise { const { model, promptsPath, @@ -438,27 +426,16 @@ export async function main(argv = process.argv.slice(2)) { openrouterProviderSort, openrouterIsFree, reasoningEffort, - timeout - } = parsed; - - const apiKey = process.env.API_KEY; - if (!apiKey) { - console.error('Missing env var "API_KEY".'); - process.exit(1); - } + timeout, + apiKey, + stderr = process.stderr + } = options; const absPromptsPath = resolve(promptsPath); const absOutPath = resolve(outPath); + await ensureReadableFile(absPromptsPath); - try { - await ensureReadableFile(absPromptsPath); - } catch (err: any) { - console.error(err?.message ?? String(err)); - process.exit(1); - return; - } - - const useProgress = progress && Boolean(process.stderr.isTTY); + const useProgress = progress && Boolean(stderr.isTTY); let totalPrompts = 0; if (useProgress) { try { @@ -467,45 +444,35 @@ export async function main(argv = process.argv.slice(2)) { totalPrompts = 0; } } - const bar = - useProgress && totalPrompts > 0 ? new ProgressBar(totalPrompts, process.stderr) : null; + const bar = useProgress && totalPrompts > 0 ? new ProgressBar(totalPrompts, stderr) : null; const writeLine = (msg: string) => { if (bar) bar.writeLine(msg); - else process.stderr.write(msg + "\n"); + else stderr.write(msg + "\n"); }; const isOpenRouter = isOpenRouterApiBase(apiBase); const useFreeKeys = isOpenRouter && openrouterIsFree; const maxConcurrent = Math.max(1, concurrent); if (useFreeKeys) { - const msg = - "INFO: openrouter.isFree is enabled. API_KEY must be an OpenRouter management key to create per-request keys."; - writeLine(msg); + writeLine( + "INFO: openrouter.isFree is enabled. API_KEY must be an OpenRouter management key to create per-request keys." + ); } let requestKeys = [apiKey]; let spawnedKeyHashes: string[] = []; if (useFreeKeys) { const baseName = `datagen-${Date.now()}`; - try { - const createdKeys = await Promise.all( - Array.from({ length: maxConcurrent }, (_, idx) => - createOpenRouterApiKey(apiBase, apiKey, `${baseName}-${idx + 1}`) - ) - ); - requestKeys = createdKeys.map((item) => item.key); - spawnedKeyHashes = createdKeys - .map((item) => item.hash) - .filter((hash): hash is string => typeof hash === "string"); - } catch (err: any) { - const details = err?.message ?? String(err); - console.error( - `Key creation failed. Please make sure the API_KEY you provided is a management key. ${details}` - ); - console.error("Create one at https://openrouter.ai/settings/management-keys"); - process.exit(1); - return; - } + const createdKeys = await Promise.all( + Array.from({ length: maxConcurrent }, (_, idx) => + createOpenRouterApiKey(apiBase, apiKey, `${baseName}-${idx + 1}`) + ) + ); + requestKeys = createdKeys.map((item) => item.key); + spawnedKeyHashes = createdKeys + .map((item) => item.hash) + .filter((hash): hash is string => typeof hash === "string"); } + const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms)); const deleteKeyWithRetry = async (hash: string) => { const delays = [250, 500, 1000, 2000]; @@ -521,12 +488,9 @@ export async function main(argv = process.argv.slice(2)) { } } } - const msg = `WARN: Failed to delete key ${hash}: ${ - (lastError as any)?.message ?? String(lastError) - }`; - writeLine(msg); + writeLine(`WARN: Failed to delete key ${hash}: ${(lastError as any)?.message ?? String(lastError)}`); }; - let cleanupStarted = false; + const cleanupKeys = async () => { if (!useFreeKeys) return; const hashes = new Set(spawnedKeyHashes); @@ -535,208 +499,225 @@ export async function main(argv = process.argv.slice(2)) { await deleteKeyWithRetry(hash); } }; - const handleTermination = () => { - if (cleanupStarted) return; - cleanupStarted = true; - writeLine("Cleaning up processes and exiting."); - void cleanupKeys().finally(() => { - process.exit(1); - }); - }; - process.once("SIGINT", handleTermination); - process.once("SIGTERM", handleTermination); - process.once("SIGQUIT", handleTermination); - const pricingKey = requestKeys[0]; - const providerPref = - isOpenRouter && (openrouterProviderOrder || openrouterProviderSort) - ? { - ...(openrouterProviderOrder ? { order: openrouterProviderOrder } : {}), - ...(openrouterProviderSort ? { sort: openrouterProviderSort } : {}) + + try { + const pricingKey = requestKeys[0]; + const providerPref = + isOpenRouter && (openrouterProviderOrder || openrouterProviderSort) + ? { + ...(openrouterProviderOrder ? { order: openrouterProviderOrder } : {}), + ...(openrouterProviderSort ? { sort: openrouterProviderSort } : {}) + } + : undefined; + let pricing: OpenRouterModelPricing | null = null; + if (isOpenRouter) { + try { + pricing = await getOpenRouterModelPricing(apiBase, pricingKey, model); + if (!pricing) { + writeLine(`WARN: Could not find pricing for model "${model}" on OpenRouter.`); } - : undefined; - let pricing: OpenRouterModelPricing | null = null; - if (isOpenRouter) { - try { - pricing = await getOpenRouterModelPricing(apiBase, pricingKey, model); - if (!pricing) { - const msg = `WARN: Could not find pricing for model "${model}" on OpenRouter.`; - if (bar) bar.writeLine(msg); - else process.stderr.write(msg + "\n"); + } catch (err: any) { + writeLine(`WARN: Failed to fetch OpenRouter models/pricing: ${err?.message ?? String(err)}`); } - } catch (err: any) { - const msg = `WARN: Failed to fetch OpenRouter models/pricing: ${err?.message ?? String(err)}`; - if (bar) bar.writeLine(msg); - else process.stderr.write(msg + "\n"); } - } - if (pricing) { - const lines = [ - `Model: ${model}`, - `API: ${apiBase}`, - pricing.modelId !== model - ? `OpenRouter pricing model: ${pricing.modelId}` - : null, - providerPref - ? `OpenRouter provider prefs: ${JSON.stringify(providerPref)}` - : null, - reasoningEffort ? `Reasoning effort: ${reasoningEffort}` : null, - `Pricing (USD per 1M tokens): prompt=${formatUsdPerMillionTokens(pricing.known.prompt, pricing.raw.prompt, pricing.promptPerTokenUSD)} completion=${formatUsdPerMillionTokens(pricing.known.completion, pricing.raw.completion, pricing.completionPerTokenUSD)}`, - `Pricing (USD per token): prompt=${formatUsdOrUnknown(pricing.known.prompt, pricing.raw.prompt, pricing.promptPerTokenUSD)}/token completion=${formatUsdOrUnknown(pricing.known.completion, pricing.raw.completion, pricing.completionPerTokenUSD)}/token`, - `Pricing (USD per request): request=${formatUsdOrUnknown(pricing.known.request, pricing.raw.request, pricing.requestUSD)}/request` - ].filter((l): l is string => Boolean(l)); - for (const l of lines) { - if (bar) bar.writeLine(l); - else process.stderr.write(l + "\n"); - } - - if (!pricing.known.prompt || !pricing.known.completion) { - const msg = "WARN: OpenRouter did not provide token pricing for this model; spent total will be omitted."; - if (bar) bar.writeLine(msg); - else process.stderr.write(msg + "\n"); + if (pricing) { + const lines = [ + `Model: ${model}`, + `API: ${apiBase}`, + pricing.modelId !== model ? `OpenRouter pricing model: ${pricing.modelId}` : null, + providerPref ? `OpenRouter provider prefs: ${JSON.stringify(providerPref)}` : null, + reasoningEffort ? `Reasoning effort: ${reasoningEffort}` : null, + `Pricing (USD per 1M tokens): prompt=${formatUsdPerMillionTokens(pricing.known.prompt, pricing.raw.prompt, pricing.promptPerTokenUSD)} completion=${formatUsdPerMillionTokens(pricing.known.completion, pricing.raw.completion, pricing.completionPerTokenUSD)}`, + `Pricing (USD per token): prompt=${formatUsdOrUnknown(pricing.known.prompt, pricing.raw.prompt, pricing.promptPerTokenUSD)}/token completion=${formatUsdOrUnknown(pricing.known.completion, pricing.raw.completion, pricing.completionPerTokenUSD)}/token`, + `Pricing (USD per request): request=${formatUsdOrUnknown(pricing.known.request, pricing.raw.request, pricing.requestUSD)}/request` + ].filter((l): l is string => Boolean(l)); + for (const l of lines) { + writeLine(l); + } + if (!pricing.known.prompt || !pricing.known.completion) { + writeLine( + "WARN: OpenRouter did not provide token pricing for this model; spent total will be omitted." + ); + } } - } - - const rl = createInterface({ - input: createReadStream(absPromptsPath), - crlfDelay: Infinity - }); - const out = createWriteStream(absOutPath, { flags: "w" }); - - let lineNum = 0; - let completed = 0; - let okCount = 0; - let errCount = 0; - let spentUsd = 0; - const canTrackSpend = Boolean( - pricing && pricing.known.prompt && pricing.known.completion && pricing.known.request - ); - if (bar) - bar.render(0, { - ok: 0, - err: 0, - spentUsd: canTrackSpend ? spentUsd : undefined + const rl = createInterface({ + input: createReadStream(absPromptsPath), + crlfDelay: Infinity }); - - const inFlight = new Set>(); - - let writeQueue = Promise.resolve(); - const writeJsonlLine = (line: string) => { - writeQueue = writeQueue.then( - () => - new Promise((resolve, reject) => { - out.write(line, (err) => { - if (err) reject(err); - else resolve(); - }); - }) + const out = createWriteStream(absOutPath, { flags: "w" }); + + let lineNum = 0; + let completed = 0; + let okCount = 0; + let errCount = 0; + let spentUsd = 0; + const canTrackSpend = Boolean( + pricing && pricing.known.prompt && pricing.known.completion && pricing.known.request ); - return writeQueue; - }; - - const renderProgress = () => { - if (!bar) return; - bar.render(completed, { - ok: okCount, - err: errCount, - spentUsd: canTrackSpend ? spentUsd : undefined - }); - }; + if (bar) bar.render(0, { ok: 0, err: 0, spentUsd: canTrackSpend ? spentUsd : undefined }); + + const inFlight = new Set>(); + let writeQueue = Promise.resolve(); + const writeJsonlLine = (line: string) => { + writeQueue = writeQueue.then( + () => + new Promise((resolve, reject) => { + out.write(line, (err) => { + if (err) reject(err); + else resolve(); + }); + }) + ); + return writeQueue; + }; + + const renderProgress = () => { + if (!bar) return; + bar.render(completed, { + ok: okCount, + err: errCount, + spentUsd: canTrackSpend ? spentUsd : undefined + }); + }; + + const useKeyPool = useFreeKeys && requestKeys.length > 0; + const availableKeys = useKeyPool ? [...requestKeys] : []; + const keyWaiters: Array<(key: string) => void> = []; + const acquireKey = async () => { + if (!useKeyPool) return requestKeys[0]; + if (availableKeys.length > 0) return availableKeys.shift() as string; + return await new Promise((resolve) => { + keyWaiters.push(resolve); + }); + }; + const releaseKey = (key: string) => { + if (!useKeyPool) return; + const waiter = keyWaiters.shift(); + if (waiter) waiter(key); + else availableKeys.push(key); + }; + + const schedule = (line: number, prompt: string) => { + const p = (async () => { + let requestKey = requestKeys[0]; + if (useKeyPool) { + requestKey = await acquireKey(); + } + try { + const { content, reasoning, usage } = await callOpenRouter( + apiBase, + requestKey, + model, + systemPrompt, + prompt, + providerPref, + reasoningEffort, + timeout + ); + + if (canTrackSpend && pricing) { + spentUsd += calculateOpenRouterSpendUSD(pricing, usage); + } + + const assistantContent = formatAssistantContent(content, reasoning); + const messages = buildOutputMessages(systemPrompt, prompt, assistantContent, storeSystem); + await writeJsonlLine(JSON.stringify({ messages }) + "\n"); + okCount++; + } catch (err: any) { + writeLine(`ERR line ${line}: ${err?.message ?? String(err)}`); + errCount++; + } finally { + releaseKey(requestKey); + completed++; + renderProgress(); + } + })(); - const useKeyPool = useFreeKeys && requestKeys.length > 0; - const availableKeys = useKeyPool ? [...requestKeys] : []; - const keyWaiters: Array<(key: string) => void> = []; - const acquireKey = async () => { - if (!useKeyPool) return requestKeys[0]; - if (availableKeys.length > 0) return availableKeys.shift() as string; - return await new Promise((resolve) => { - keyWaiters.push(resolve); - }); - }; - const releaseKey = (key: string) => { - if (!useKeyPool) return; - const waiter = keyWaiters.shift(); - if (waiter) waiter(key); - else availableKeys.push(key); - }; + inFlight.add(p); + p.finally(() => inFlight.delete(p)); + }; - const schedule = (index: number, line: number, prompt: string) => { - const p = (async () => { - let requestKey = requestKeys[0]; - if (useKeyPool) { - requestKey = await acquireKey(); + const waitForSlot = async () => { + while (inFlight.size >= maxConcurrent) { + await Promise.race(inFlight); } - try { - const { content, reasoning, usage } = await callOpenRouter( - apiBase, - requestKey, - model, - systemPrompt, - prompt, - providerPref, - reasoningEffort, - timeout - ); + }; + + for await (const line of rl) { + lineNum++; + const prompt = line.trim(); + if (!prompt) continue; + await waitForSlot(); + schedule(lineNum, prompt); + } - if (canTrackSpend && pricing) { - spentUsd += calculateOpenRouterSpendUSD(pricing, usage); - } + while (inFlight.size > 0) { + await Promise.race(inFlight); + } - const assistantContent = formatAssistantContent(content, reasoning); - const messages = buildOutputMessages( - systemPrompt, - prompt, - assistantContent, - storeSystem - ); + await writeQueue; + await new Promise((resolve, reject) => { + out.end((err?: Error | null) => { + if (err) reject(err); + else resolve(); + }); + }); - await writeJsonlLine(JSON.stringify({ messages }) + "\n"); - okCount++; - } catch (err: any) { - const msg = `ERR line ${line}: ${err?.message ?? String(err)}`; - if (bar) bar.writeLine(msg); - else process.stderr.write(msg + "\n"); - errCount++; - } finally { - releaseKey(requestKey); - completed++; - renderProgress(); - } - })(); + if (bar) { + bar.finish(completed, { + ok: okCount, + err: errCount, + spentUsd: canTrackSpend ? spentUsd : undefined + }); + } - inFlight.add(p); - p.finally(() => inFlight.delete(p)); - }; + return { outPath: absOutPath, completed, okCount, errCount, spentUsd }; + } finally { + await cleanupKeys(); + } +} - const waitForSlot = async () => { - while (inFlight.size >= maxConcurrent) { - await Promise.race(inFlight); - } - }; +export async function main(argv = process.argv.slice(2)) { + await maybeNotifyNewVersion({ + cliName: CLI_NAME, + packageName: CLI_PACKAGE_NAME, + currentVersion: CLI_VERSION + }); - let promptIndex = 0; - for await (const line of rl) { - lineNum++; - const prompt = line.trim(); - if (!prompt) continue; + if (hasFlag(argv, FLAG_ALIASES.help)) { + printHelp(); + return; + } - await waitForSlot(); - schedule(promptIndex, lineNum, prompt); - promptIndex++; + if (hasFlag(argv, FLAG_ALIASES.version)) { + printVersion(); + return; } - while (inFlight.size > 0) { - await Promise.race(inFlight); + let parsed: Args; + try { + parsed = parseArgs(argv); + } catch (err: any) { + console.error(err?.message ?? String(err)); + process.exit(1); + return; } - await cleanupKeys(); - out.end(); - if (bar) { - bar.finish(completed, { - ok: okCount, - err: errCount, - spentUsd: canTrackSpend ? spentUsd : undefined + const apiKey = process.env.API_KEY; + if (!apiKey) { + console.error('Missing env var "API_KEY".'); + process.exit(1); + } + + try { + await generateDataset({ + ...parsed, + apiKey }); + } catch (err: any) { + console.error(err?.message ?? String(err)); + process.exit(1); } } diff --git a/test/datagen.test.ts b/test/datagen.test.ts index 4ca1fbb..625d010 100644 --- a/test/datagen.test.ts +++ b/test/datagen.test.ts @@ -9,8 +9,10 @@ import { buildOutputMessages, formatAssistantContent, callOpenRouter, - ensureReadableFile + ensureReadableFile, + generateDataset } from "../src/index.js"; +import { readFile } from "node:fs/promises"; test("parseArgs requires model and prompts", () => { assert.throws(() => parseArgs([]), /Usage:/); @@ -321,3 +323,60 @@ test("ensureReadableFile throws if missing or not a file", async () => { await writeFile(file, "hello\n"); await ensureReadableFile(file); }); + +test("generateDataset writes JSONL output for prompts", async () => { + const dir = await mkdtemp(join(tmpdir(), "datagen-")); + const promptsPath = join(dir, "prompts.txt"); + const outPath = join(dir, "dataset.jsonl"); + await writeFile(promptsPath, "first\n\nsecond\n"); + + globalThis.fetch = async () => { + return { + ok: true, + status: 200, + async json() { + return { + choices: [ + { + message: { + content: "assistant" + } + } + ] + }; + } + } as any; + }; + + const result = await generateDataset({ + model: "model-x", + promptsPath, + outPath, + apiBase: "https://example.com/api/v1", + systemPrompt: "", + storeSystem: true, + progress: false, + concurrent: 1, + openrouterProviderOrder: null, + openrouterProviderSort: null, + openrouterIsFree: false, + reasoningEffort: null, + timeout: null, + apiKey: "KEY" + }); + + assert.equal(result.completed, 2); + assert.equal(result.okCount, 2); + assert.equal(result.errCount, 0); + + const content = await readFile(outPath, "utf8"); + const lines = content.trim().split("\n"); + assert.equal(lines.length, 2); + const parsed = lines.map((line) => JSON.parse(line)); + assert.deepEqual(parsed[0], { + messages: [ + { role: "user", content: "first" }, + { role: "assistant", content: "assistant" } + ] + }); +});