diff --git a/convex/httpApiV1.ts b/convex/httpApiV1.ts index 993ea28..bc2b193 100644 --- a/convex/httpApiV1.ts +++ b/convex/httpApiV1.ts @@ -1,18 +1,12 @@ import { CliPublishRequestSchema, parseArk } from 'clawdhub-schema' import { api, internal } from './_generated/api' import type { Doc, Id } from './_generated/dataModel' -import type { ActionCtx } from './_generated/server' -import { httpAction } from './_generated/server' +import { type ActionCtx, httpAction } from './_generated/server' import { requireApiTokenUser } from './lib/apiTokenAuth' -import { hashToken } from './lib/tokens' +import { applyRateLimit, parseBearerToken } from './lib/httpRateLimit' import { publishVersionForUser } from './skills' import { publishSoulVersionForUser } from './souls' -const RATE_LIMIT_WINDOW_MS = 60_000 -const RATE_LIMITS = { - read: { ip: 120, key: 600 }, - write: { ip: 30, key: 120 }, -} as const const MAX_RAW_FILE_BYTES = 200 * 1024 type SearchSkillEntry = { @@ -631,87 +625,6 @@ async function resolveTags( return resolved } -async function applyRateLimit( - ctx: ActionCtx, - request: Request, - kind: 'read' | 'write', -): Promise<{ ok: true; headers: HeadersInit } | { ok: false; response: Response }> { - const ip = getClientIp(request) ?? 'unknown' - const ipResult = await checkRateLimit(ctx, `ip:${ip}`, RATE_LIMITS[kind].ip) - const token = parseBearerToken(request) - const keyResult = token - ? await checkRateLimit(ctx, `key:${await hashToken(token)}`, RATE_LIMITS[kind].key) - : null - - const chosen = pickMostRestrictive(ipResult, keyResult) - const headers = rateHeaders(chosen) - - if (!ipResult.allowed || (keyResult && !keyResult.allowed)) { - return { - ok: false, - response: text('Rate limit exceeded', 429, headers), - } - } - - return { ok: true, headers } -} - -type RateLimitResult = { - allowed: boolean - remaining: number - limit: number - resetAt: number -} - -async function checkRateLimit( - ctx: ActionCtx, - key: string, - limit: number, -): Promise { - return (await ctx.runMutation(internal.rateLimits.checkRateLimitInternal, { - key, - limit, - windowMs: RATE_LIMIT_WINDOW_MS, - })) as RateLimitResult -} - -function pickMostRestrictive(primary: RateLimitResult, secondary: RateLimitResult | null) { - if (!secondary) return primary - if (!primary.allowed) return primary - if (!secondary.allowed) return secondary - return secondary.remaining < primary.remaining ? secondary : primary -} - -function rateHeaders(result: RateLimitResult): HeadersInit { - const resetSeconds = Math.ceil(result.resetAt / 1000) - return { - 'X-RateLimit-Limit': String(result.limit), - 'X-RateLimit-Remaining': String(result.remaining), - 'X-RateLimit-Reset': String(resetSeconds), - ...(result.allowed ? {} : { 'Retry-After': String(resetSeconds) }), - } -} - -function getClientIp(request: Request) { - const header = - request.headers.get('cf-connecting-ip') ?? - request.headers.get('x-real-ip') ?? - request.headers.get('x-forwarded-for') ?? - request.headers.get('fly-client-ip') - if (!header) return null - if (header.includes(',')) return header.split(',')[0]?.trim() || null - return header.trim() -} - -function parseBearerToken(request: Request) { - const header = request.headers.get('authorization') ?? request.headers.get('Authorization') - if (!header) return null - const trimmed = header.trim() - if (!trimmed.toLowerCase().startsWith('bearer ')) return null - const token = trimmed.slice(7).trim() - return token || null -} - function json(value: unknown, status = 200, headers?: HeadersInit) { return new Response(JSON.stringify(value), { status, diff --git a/convex/lib/httpRateLimit.test.ts b/convex/lib/httpRateLimit.test.ts new file mode 100644 index 0000000..c5d1ee9 --- /dev/null +++ b/convex/lib/httpRateLimit.test.ts @@ -0,0 +1,35 @@ +/* @vitest-environment node */ +import { describe, expect, it } from 'vitest' +import { getClientIp } from './httpRateLimit' + +describe('getClientIp', () => { + it('returns null when cf-connecting-ip missing', () => { + const request = new Request('https://example.com', { + headers: { + 'x-forwarded-for': '203.0.113.9', + }, + }) + process.env.TRUST_FORWARDED_IPS = '' + expect(getClientIp(request)).toBeNull() + }) + + it('returns first ip from cf-connecting-ip', () => { + const request = new Request('https://example.com', { + headers: { + 'cf-connecting-ip': '203.0.113.1, 198.51.100.2', + }, + }) + expect(getClientIp(request)).toBe('203.0.113.1') + }) + + it('uses forwarded headers when opt-in enabled', () => { + const request = new Request('https://example.com', { + headers: { + 'x-forwarded-for': '203.0.113.9, 198.51.100.2', + }, + }) + process.env.TRUST_FORWARDED_IPS = 'true' + expect(getClientIp(request)).toBe('203.0.113.9') + process.env.TRUST_FORWARDED_IPS = '' + }) +}) diff --git a/convex/lib/httpRateLimit.ts b/convex/lib/httpRateLimit.ts new file mode 100644 index 0000000..5615c23 --- /dev/null +++ b/convex/lib/httpRateLimit.ts @@ -0,0 +1,112 @@ +import { internal } from '../_generated/api' +import type { ActionCtx } from '../_generated/server' +import { hashToken } from './tokens' + +const RATE_LIMIT_WINDOW_MS = 60_000 +export const RATE_LIMITS = { + read: { ip: 120, key: 600 }, + write: { ip: 30, key: 120 }, +} as const + +type RateLimitResult = { + allowed: boolean + remaining: number + limit: number + resetAt: number +} + +export async function applyRateLimit( + ctx: ActionCtx, + request: Request, + kind: keyof typeof RATE_LIMITS, +): Promise<{ ok: true; headers: HeadersInit } | { ok: false; response: Response }> { + const ip = getClientIp(request) ?? 'unknown' + const ipResult = await checkRateLimit(ctx, `ip:${ip}`, RATE_LIMITS[kind].ip) + const token = parseBearerToken(request) + const keyResult = token + ? await checkRateLimit(ctx, `key:${await hashToken(token)}`, RATE_LIMITS[kind].key) + : null + + const chosen = pickMostRestrictive(ipResult, keyResult) + const headers = rateHeaders(chosen) + + if (!ipResult.allowed || (keyResult && !keyResult.allowed)) { + return { + ok: false, + response: new Response('Rate limit exceeded', { + status: 429, + headers: mergeHeaders( + { + 'Content-Type': 'text/plain; charset=utf-8', + 'Cache-Control': 'no-store', + }, + headers, + ), + }), + } + } + + return { ok: true, headers } +} + +export function getClientIp(request: Request) { + const header = request.headers.get('cf-connecting-ip') + if (!header) { + if (!shouldTrustForwardedIps()) return null + const forwarded = + request.headers.get('x-real-ip') ?? + request.headers.get('x-forwarded-for') ?? + request.headers.get('fly-client-ip') + if (!forwarded) return null + if (forwarded.includes(',')) return forwarded.split(',')[0]?.trim() || null + return forwarded.trim() + } + if (header.includes(',')) return header.split(',')[0]?.trim() || null + return header.trim() +} + +async function checkRateLimit( + ctx: ActionCtx, + key: string, + limit: number, +): Promise { + return (await ctx.runMutation(internal.rateLimits.checkRateLimitInternal, { + key, + limit, + windowMs: RATE_LIMIT_WINDOW_MS, + })) as RateLimitResult +} + +function pickMostRestrictive(primary: RateLimitResult, secondary: RateLimitResult | null) { + if (!secondary) return primary + if (!primary.allowed) return primary + if (!secondary.allowed) return secondary + return secondary.remaining < primary.remaining ? secondary : primary +} + +function rateHeaders(result: RateLimitResult): HeadersInit { + const resetSeconds = Math.ceil(result.resetAt / 1000) + return { + 'X-RateLimit-Limit': String(result.limit), + 'X-RateLimit-Remaining': String(result.remaining), + 'X-RateLimit-Reset': String(resetSeconds), + ...(result.allowed ? {} : { 'Retry-After': String(resetSeconds) }), + } +} + +export function parseBearerToken(request: Request) { + const header = request.headers.get('authorization') ?? request.headers.get('Authorization') + if (!header) return null + const trimmed = header.trim() + if (!trimmed.toLowerCase().startsWith('bearer ')) return null + const token = trimmed.slice(7).trim() + return token || null +} + +function mergeHeaders(base: HeadersInit, extra?: HeadersInit) { + return { ...(base as Record), ...(extra as Record) } +} + +function shouldTrustForwardedIps() { + return String(process.env.TRUST_FORWARDED_IPS ?? '').toLowerCase() === 'true' +}