diff --git a/lib/fetch.test.ts b/lib/fetch.test.ts index 7cdc6312..9e6e1965 100644 --- a/lib/fetch.test.ts +++ b/lib/fetch.test.ts @@ -1,5 +1,6 @@ -import { describe, test, expect } from "bun:test"; +import { describe, test, expect, spyOn } from "bun:test"; import { Client } from "./fetch.ts"; +import * as ssrf from "./ssrf.ts"; describe("Client with SSRF protection", () => { test("blocks request to localhost", async () => { @@ -153,3 +154,121 @@ describe("Client allowedHosts", () => { await expect(result).rejects.not.toThrow(/host is not allowed/); }); }); + +describe("DNS rebinding protection", () => { + test("resolves DNS once and connects to validated IP", async () => { + // Track how many times resolveHostname is called + const resolveHostnameSpy = spyOn(ssrf, "resolveHostname"); + + const client = new Client({ + timeoutMs: 5000, + bodyLimit: 1024 * 1024, + ssrfProtection: true, + }); + + // This will fail because localhost resolves to a private IP + await expect(client.fetch("http://localhost/")).rejects.toThrow(/private IP/); + + // resolveHostname should only be called once + expect(resolveHostnameSpy).toHaveBeenCalledTimes(1); + expect(resolveHostnameSpy).toHaveBeenCalledWith("localhost"); + + resolveHostnameSpy.mockRestore(); + }); + + test("blocks hostname that resolves to private IP after DNS lookup", async () => { + // Mock resolveHostname to return a private IP for a "safe looking" hostname + const resolveHostnameMock = spyOn(ssrf, "resolveHostname").mockResolvedValue([ + "127.0.0.1", + ]); + + const client = new Client({ + timeoutMs: 5000, + bodyLimit: 1024 * 1024, + ssrfProtection: true, + }); + + // Even though the hostname looks safe, it resolves to a private IP + await expect(client.fetch("http://safe-looking-host.com/")).rejects.toThrow( + /private IP/, + ); + + resolveHostnameMock.mockRestore(); + }); + + test("uses first public IP when multiple IPs resolved", async () => { + // Mock to return mixed IPs - private first, then public + const resolveHostnameMock = spyOn(ssrf, "resolveHostname").mockResolvedValue([ + "127.0.0.1", // private - should be skipped + "10.0.0.1", // private - should be skipped + "8.8.8.8", // public - should be used + ]); + + const client = new Client({ + timeoutMs: 1000, // Short timeout - we just want to verify no SSRF error + bodyLimit: 1024 * 1024, + ssrfProtection: true, + }); + + // Should succeed past SSRF check and fail with connection error + // (because 8.8.8.8 doesn't serve HTTP on port 80) + // The important thing is it doesn't throw "private IP" error + const result = client.fetch("http://example.com/"); + await expect(result).rejects.not.toThrow(/private IP/); + + resolveHostnameMock.mockRestore(); + }); + + test("blocks when all resolved IPs are private", async () => { + // Mock to return only private IPs + const resolveHostnameMock = spyOn(ssrf, "resolveHostname").mockResolvedValue([ + "127.0.0.1", + "10.0.0.1", + "192.168.1.1", + ]); + + const client = new Client({ + timeoutMs: 5000, + bodyLimit: 1024 * 1024, + ssrfProtection: true, + }); + + await expect(client.fetch("http://example.com/")).rejects.toThrow(/private IP/); + + resolveHostnameMock.mockRestore(); + }); + + test("validates redirect destinations for DNS rebinding", async () => { + // Start a simple server that redirects + const server = Bun.serve({ + port: 0, + fetch(req) { + const url = new URL(req.url); + if (url.pathname === "/redirect") { + // Redirect to localhost (should be blocked) + return new Response(null, { + status: 302, + headers: { Location: "http://127.0.0.1/target" }, + }); + } + return new Response("OK"); + }, + }); + + try { + const client = new Client({ + timeoutMs: 5000, + bodyLimit: 1024 * 1024, + ssrfProtection: true, + }); + + // Request to the server's redirect endpoint + // The redirect target (127.0.0.1) should be blocked + await expect( + client.fetch(`http://127.0.0.1:${server.port}/redirect`), + ).rejects.toThrow(/private IP/); + } finally { + server.stop(); + } + }); +}); diff --git a/lib/fetch.ts b/lib/fetch.ts index ee555725..65ddb56c 100644 --- a/lib/fetch.ts +++ b/lib/fetch.ts @@ -1,4 +1,4 @@ -import { validateUrlForSSRF } from "./ssrf.ts"; +import { resolveHostname, isPrivateIP, parseIPv4, parseIPv6 } from "./ssrf.ts"; import { HttpError } from "./types.ts"; type ReadResult = { done: false; value: T } | { done: true; value?: T }; @@ -64,9 +64,6 @@ export class Client { let redirectCount = 0; while (true) { - // Validate host before each request (initial + redirects) - await this.validateHost(currentUrl); - const res = await this.makeRequest(currentUrl); // Check for redirect responses @@ -107,12 +104,22 @@ export class Client { } // Performs a single HTTP request with timeout and manual redirect handling. + // DNS resolution and SSRF validation are done atomically to prevent DNS rebinding. private async makeRequest(url: string): Promise { + const parsed = this.validateUrl(url); + + // Resolve hostname to IP and validate for SSRF in one atomic step. + // The request is then made directly to the validated IP. + const { targetUrl, host } = await this.resolveAndValidate(parsed); + let res: Response; try { - res = await fetch(url, { + res = await fetch(targetUrl, { signal: AbortSignal.timeout(this.timeoutMs), redirect: "manual", + headers: { + Host: host, + }, }); } catch { throw new HttpError(400, "fetch: unable to make request"); @@ -121,6 +128,73 @@ export class Client { return res; } + // Resolves hostname to IP and validates it for SSRF. + // Returns a URL with the IP address and the original host for the Host header. + // This ensures DNS resolution and validation happen atomically - no TOCTTOU. + private async resolveAndValidate( + parsed: URL, + ): Promise<{ targetUrl: string; host: string }> { + const hostname = parsed.hostname; + + // If SSRF protection is disabled, just return the original URL + if (!this.ssrfProtection) { + return { targetUrl: parsed.toString(), host: parsed.host }; + } + + // Check if hostname is already an IP address + const isIPv4 = parseIPv4(hostname) !== null; + const cleanedHostname = hostname.replace(/^\[|\]$/g, ""); + const isIPv6 = parseIPv6(cleanedHostname) !== null; + + if (isIPv4 || isIPv6) { + // Direct IP address - validate it + const ip = isIPv6 ? cleanedHostname : hostname; + const result = isPrivateIP(ip); + if (result.isPrivate) { + throw new HttpError( + 403, + `fetch: host resolves to private IP (${result.reason})`, + ); + } + return { targetUrl: parsed.toString(), host: parsed.host }; + } + + // Resolve hostname to IP addresses + const ips = await resolveHostname(hostname); + + // Find the first non-private IP + for (const ip of ips) { + const result = isPrivateIP(ip); + if (!result.isPrivate) { + // Build URL with IP address instead of hostname + const targetUrl = this.buildUrlWithIP(parsed, ip); + return { targetUrl, host: parsed.host }; + } + } + + // All IPs are private + throw new HttpError(403, "fetch: host resolves to private IP"); + } + + // Builds a URL replacing the hostname with an IP address. + private buildUrlWithIP(original: URL, ip: string): string { + // Handle IPv6 addresses - they need brackets in URLs + const hostPart = ip.includes(":") ? `[${ip}]` : ip; + + // Reconstruct URL with IP + let url = `${original.protocol}//${hostPart}`; + + // Add port if non-default + if (original.port) { + url += `:${original.port}`; + } + + // Add path, search, and hash + url += original.pathname + original.search + original.hash; + + return url; + } + // Reads the response body into a pre-allocated buffer when Content-Length is known. private async readWithKnownLength( reader: StreamReader, @@ -193,7 +267,8 @@ export class Client { ); } - private async validateHost(url: string): Promise { + // Validates URL scheme and allowed hosts (does not do SSRF IP validation). + private validateUrl(url: string): URL { let parsed: URL; try { parsed = new URL(url); @@ -211,10 +286,7 @@ export class Client { throw new HttpError(403, "fetch: host is not allowed"); } - // SSRF protection: validate IP addresses - if (this.ssrfProtection) { - await validateUrlForSSRF(parsed); - } + return parsed; } }