diff --git a/src/bindings.d.ts b/src/bindings.d.ts index 60e7905..bfb6c4b 100644 --- a/src/bindings.d.ts +++ b/src/bindings.d.ts @@ -3,7 +3,7 @@ import type { Context } from '@web3-storage/gateway-lib' import type { CARLink } from 'cardex/api' import type { R2Bucket, KVNamespace } from '@cloudflare/workers-types' import type { MemoryBudget } from './lib/mem-budget' -import { CID } from '@web3-storage/gateway-lib/handlers' +import { CID } from 'multiformats' export {} @@ -13,8 +13,9 @@ export interface Environment { CONTENT_CLAIMS_SERVICE_URL?: string RATE_LIMITS_SERVICE_URL?: string ACCOUNTING_SERVICE_URL: string - MY_RATE_LIMITER: RateLimit + MY_RATE_LIMITER: KVNamespace AUTH_TOKEN_METADATA: KVNamespace + FF_RATE_LIMITER_ENABLED: string } export type GetCIDRequestData = Pick @@ -25,13 +26,24 @@ export interface RateLimitsService { check: (cid: CID, options: GetCIDRequestOptions) => Promise } +export interface RateLimitConfig { + requests: number + window: number + concurrent: number +} + export interface TokenMetadata { - locationClaim?: unknown // TODO: figure out the right type to use for this - we probably need it for the private data case to verify auth + id: string invalid?: boolean + rateLimits?: RateLimitConfig + origins?: string[] + expiresAt?: number } export interface RateLimits { - create: ({ env }: { env: Environment }) => RateLimitsService + create: (options: { env: Environment }) => { + check: (cid: CID, request: Request) => Promise + } } export interface AccountingService { @@ -40,6 +52,18 @@ export interface AccountingService { } export interface Accounting { - create: ({ serviceURL }: { serviceURL?: string }) => AccountingService + create: (options: { serviceURL: string }) => { + record: (cid: CID, options: any) => Promise + getTokenMetadata: (token: string) => Promise + } +} + +export enum RATE_LIMIT_EXCEEDED { + YES = 'yes', + NO = 'no' +} + +export interface ExecutionContext extends EventContext { + waitUntil(promise: Promise): voidUU } diff --git a/src/middleware.js b/src/middleware.js index 78ccf4b..c3ee5da 100644 --- a/src/middleware.js +++ b/src/middleware.js @@ -16,61 +16,172 @@ import { handleCarBlock } from './handlers/car-block.js' */ /** - * - * @param {string} s - * @returns {import('./bindings.js').TokenMetadata} + * Cache configuration + * @type {{ + * DEFAULT_TTL: number, + * STALE_TTL: number, + * REVALIDATE_AFTER: number + * }} */ -function deserializeTokenMetadata(s) { - // TODO should this be dag-json? - return JSON.parse(s) +const CACHE_CONFIG = { + DEFAULT_TTL: 3600, // 1 hour default TTL + STALE_TTL: 300, // 5 minutes before considered stale + REVALIDATE_AFTER: 3300 // Revalidate after 55 minutes } /** - * - * @param {import('./bindings.js').TokenMetadata} m - * @returns string + * Token metadata with cache control + * @typedef {Object} CachedTokenMetadata + * @property {import('./bindings.js').TokenMetadata} data + * @property {number} timestamp + * @property {number} expiresAt */ -function serializeTokenMetadata(m) { - // TODO should this be dag-json? - return JSON.stringify(m) + +/** + * Serialize token metadata with cache control + * @param {import('./bindings.js').TokenMetadata} metadata + * @returns {string} + */ +function serializeTokenMetadata(metadata) { + const cached = { + data: metadata, + timestamp: Date.now(), + expiresAt: Date.now() + CACHE_CONFIG.DEFAULT_TTL * 1000 + } + return JSON.stringify(cached) } /** - * + * Deserialize cached token metadata + * @param {string} cached + * @returns {CachedTokenMetadata} + */ +function deserializeTokenMetadata(cached) { + return JSON.parse(cached) +} + +/** + * Fetch fresh token metadata from the accounting service * @param {Environment} env - * @param {import('@web3-storage/gateway-lib/handlers').CID} cid + * @param {string} authToken + * @returns {Promise} */ -async function checkRateLimitForCID(env, cid) { - const rateLimitResponse = await env.MY_RATE_LIMITER.limit({ key: cid.toString() }) - if (rateLimitResponse.success) { - return RATE_LIMIT_EXCEEDED.NO - } else { - console.log(`limiting CID ${cid}`) - return RATE_LIMIT_EXCEEDED.YES - } +async function fetchTokenMetadata(env, authToken) { + const accounting = Accounting.create({ serviceURL: env.ACCOUNTING_SERVICE_URL }) + return await accounting.getTokenMetadata(authToken) } /** - * + * Get token metadata with SWR caching pattern * @param {Environment} env - * @param {string} authToken - * @returns TokenMetadata + * @param {string} authToken + * @param {ExecutionContext} ctx + * @returns {Promise} */ -async function getTokenMetadata(env, authToken) { +async function getTokenMetadata(env, authToken, ctx) { const cachedValue = await env.AUTH_TOKEN_METADATA.get(authToken) - // TODO: we should implement an SWR pattern here - record an expiry in the metadata and if the expiry has passed, re-validate the cache after - // returning the value + if (cachedValue) { - return deserializeTokenMetadata(cachedValue) - } else { - const accounting = Accounting.create({ serviceURL: env.ACCOUNTING_SERVICE_URL }) - const tokenMetadata = await accounting.getTokenMetadata(authToken) - if (tokenMetadata) { - await env.AUTH_TOKEN_METADATA.put(authToken, serializeTokenMetadata(tokenMetadata)) - return tokenMetadata - } else { - return null + const cached = deserializeTokenMetadata(cachedValue) + const now = Date.now() + + // Return cached data immediately if not expired + if (now < cached.expiresAt) { + // If approaching expiration, trigger background refresh + if (now > cached.timestamp + CACHE_CONFIG.REVALIDATE_AFTER * 1000) { + ctx.waitUntil(refreshTokenMetadata(env, authToken)) + } + return cached.data + } + + // If expired but within stale window, return stale data and trigger refresh + if (now < cached.expiresAt + CACHE_CONFIG.STALE_TTL * 1000) { + ctx.waitUntil(refreshTokenMetadata(env, authToken)) + return cached.data + } + } + + // No cache or expired beyond stale window - fetch fresh data + return await refreshTokenMetadata(env, authToken) +} + +/** + * Refresh token metadata in cache + * @param {Environment} env + * @param {string} authToken + * @returns {Promise} + */ +async function refreshTokenMetadata(env, authToken) { + try { + const freshData = await fetchTokenMetadata(env, authToken) + if (freshData) { + await env.AUTH_TOKEN_METADATA.put( + authToken, + serializeTokenMetadata(freshData) + ) + return freshData } + return null + } catch (error) { + console.error('Error refreshing token metadata:', error) + return null + } +} + +/** + * Default rate limits for anonymous users + * @type {import('./bindings.js').RateLimitConfig} + */ +const DEFAULT_RATE_LIMITS = { + requests: 100, // requests per window + window: 60, // window size in seconds + concurrent: 5 // max concurrent requests +} + +/** + * Check rate limits for a given CID and token + * @param {Environment} env + * @param {import('multiformats').CID} cid + * @param {string | null} token + * @param {import('./bindings.js').TokenMetadata | null} tokenMetadata + * @returns {Promise} + */ +async function checkRateLimitForRequest(env, cid, token, tokenMetadata) { + // Get appropriate limits based on token metadata or defaults + const limits = tokenMetadata?.rateLimits || DEFAULT_RATE_LIMITS + + // Create a unique key that includes token (if present) and CID + const key = token ? `${token}:${cid.toString()}` : cid.toString() + + // Check concurrent requests first + const concurrentKey = `concurrent:${key}` + const concurrent = parseInt(await env.MY_RATE_LIMITER.get(concurrentKey) || '0', 10) + + if (concurrent >= limits.concurrent) { + console.warn(`Concurrent limit exceeded for ${key}`) + return RATE_LIMIT_EXCEEDED.YES + } + + // Increment concurrent requests + await env.MY_RATE_LIMITER.put(concurrentKey, (concurrent + 1).toString(), { expirationTtl: 60 }) + + try { + // Check rate limits + const rateLimitResponse = await env.MY_RATE_LIMITER.limit({ + key, + requests: limits.requests, + window: limits.window + }) + + if (!rateLimitResponse.success) { + console.warn(`Rate limit exceeded for ${key}`) + return RATE_LIMIT_EXCEEDED.YES + } + + return RATE_LIMIT_EXCEEDED.NO + } finally { + // Decrement concurrent requests count + await env.MY_RATE_LIMITER.put(concurrentKey, concurrent.toString()) } } @@ -79,30 +190,33 @@ async function getTokenMetadata(env, authToken) { */ const RateLimits = { create: ({ env }) => ({ - check: async (cid, options) => { - const authToken = await getAuthorizationTokenFromRequest(options) + check: async (cid, request) => { + const authToken = await getAuthorizationTokenFromRequest(request) + let tokenMetadata = null + if (authToken) { - console.log(`found token ${authToken}, looking for content commitment`) - const tokenMetadata = await getTokenMetadata(env, authToken) - - if (tokenMetadata) { - if (tokenMetadata.invalid) { - // this means we know about the token and we know it's invalid, so we should just use the CID rate limit - return checkRateLimitForCID(env, cid) - } else { - // TODO at some point we should enforce user configurable rate limits and origin matching - // but for now we just serve all valid token requests - return RATE_LIMIT_EXCEEDED.NO + console.log(`Found token ${authToken}, checking metadata`) + // Create an execution context for background tasks + const executionCtx = { + waitUntil: (promise) => { + // In browser environment, we need to handle this differently + if (typeof WorkerGlobalScope !== 'undefined') { + return self.waitUntil(promise) + } + // For other environments, we'll just await the promise + return promise } - } else { - // we didn't get any metadata - for now just use the top level rate limit - // this means token based requests will be subject to normal rate limits until the data propagates - return checkRateLimitForCID(env, cid) } - } else { - // no token, use normal rate limit - return checkRateLimitForCID(env, cid) + + tokenMetadata = await getTokenMetadata(env, authToken, executionCtx) + + if (tokenMetadata?.invalid) { + console.warn(`Invalid token ${authToken} attempting access`) + return RATE_LIMIT_EXCEEDED.YES + } } + + return checkRateLimitForRequest(env, cid, authToken, tokenMetadata) } }) } @@ -125,18 +239,45 @@ const Accounting = { } /** - * + * Validates the token format and structure + * @param {string} token - The token to validate + * @returns {boolean} + */ +function isValidTokenFormat(token) { + if (!token || typeof token !== 'string') return false + + // Token should be at least 32 characters long for security + if (token.length < 32) return false + + // Token should be base64url encoded + const base64urlRegex = /^[A-Za-z0-9_-]+$/ + if (!base64urlRegex.test(token)) return false + + return true +} + +/** + * Gets and validates the authorization token from the request * @param {Pick} request - * @returns string + * @returns {Promise} */ async function getAuthorizationTokenFromRequest(request) { - // TODO this is probably wrong - const authToken = request.headers.get('Authorization') - return authToken + const authHeader = request.headers.get('Authorization') + if (!authHeader) return null + + // Validate Bearer token format + if (!authHeader.startsWith('Bearer ')) return null + + const token = authHeader.slice(7).trim() + if (!isValidTokenFormat(token)) { + console.warn('Invalid token format detected') + return null + } + + return token } /** - * * @type {import('@web3-storage/gateway-lib').Middleware} */ export function withRateLimits(handler) {