Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 1 addition & 120 deletions lib/fetch.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { describe, test, expect, spyOn } from "bun:test";
import { describe, test, expect } from "bun:test";
import { Client } from "./fetch.ts";
import * as ssrf from "./ssrf.ts";

describe("Client with SSRF protection", () => {
test("blocks request to localhost", async () => {
Expand Down Expand Up @@ -154,121 +153,3 @@ describe("Client allowedHosts", () => {
await expect(result).rejects.not.toThrow(/host is not allowed/);
});
});

describe("DNS rebinding protection", () => {
test("resolves DNS once and connects to validated IP", async () => {
// Track how many times resolveHostname is called
const resolveHostnameSpy = spyOn(ssrf, "resolveHostname");

const client = new Client({
timeoutMs: 5000,
bodyLimit: 1024 * 1024,
ssrfProtection: true,
});

// This will fail because localhost resolves to a private IP
await expect(client.fetch("http://localhost/")).rejects.toThrow(/private IP/);

// resolveHostname should only be called once
expect(resolveHostnameSpy).toHaveBeenCalledTimes(1);
expect(resolveHostnameSpy).toHaveBeenCalledWith("localhost");

resolveHostnameSpy.mockRestore();
});

test("blocks hostname that resolves to private IP after DNS lookup", async () => {
// Mock resolveHostname to return a private IP for a "safe looking" hostname
const resolveHostnameMock = spyOn(ssrf, "resolveHostname").mockResolvedValue([
"127.0.0.1",
]);

const client = new Client({
timeoutMs: 5000,
bodyLimit: 1024 * 1024,
ssrfProtection: true,
});

// Even though the hostname looks safe, it resolves to a private IP
await expect(client.fetch("http://safe-looking-host.com/")).rejects.toThrow(
/private IP/,
);

resolveHostnameMock.mockRestore();
});

test("uses first public IP when multiple IPs resolved", async () => {
// Mock to return mixed IPs - private first, then public
const resolveHostnameMock = spyOn(ssrf, "resolveHostname").mockResolvedValue([
"127.0.0.1", // private - should be skipped
"10.0.0.1", // private - should be skipped
"8.8.8.8", // public - should be used
]);

const client = new Client({
timeoutMs: 1000, // Short timeout - we just want to verify no SSRF error
bodyLimit: 1024 * 1024,
ssrfProtection: true,
});

// Should succeed past SSRF check and fail with connection error
// (because 8.8.8.8 doesn't serve HTTP on port 80)
// The important thing is it doesn't throw "private IP" error
const result = client.fetch("http://example.com/");
await expect(result).rejects.not.toThrow(/private IP/);

resolveHostnameMock.mockRestore();
});

test("blocks when all resolved IPs are private", async () => {
// Mock to return only private IPs
const resolveHostnameMock = spyOn(ssrf, "resolveHostname").mockResolvedValue([
"127.0.0.1",
"10.0.0.1",
"192.168.1.1",
]);

const client = new Client({
timeoutMs: 5000,
bodyLimit: 1024 * 1024,
ssrfProtection: true,
});

await expect(client.fetch("http://example.com/")).rejects.toThrow(/private IP/);

resolveHostnameMock.mockRestore();
});

test("validates redirect destinations for DNS rebinding", async () => {
// Start a simple server that redirects
const server = Bun.serve({
port: 0,
fetch(req) {
const url = new URL(req.url);
if (url.pathname === "/redirect") {
// Redirect to localhost (should be blocked)
return new Response(null, {
status: 302,
headers: { Location: "http://127.0.0.1/target" },
});
}
return new Response("OK");
},
});

try {
const client = new Client({
timeoutMs: 5000,
bodyLimit: 1024 * 1024,
ssrfProtection: true,
});

// Request to the server's redirect endpoint
// The redirect target (127.0.0.1) should be blocked
await expect(
client.fetch(`http://127.0.0.1:${server.port}/redirect`),
).rejects.toThrow(/private IP/);
} finally {
server.stop();
}
});
});
92 changes: 10 additions & 82 deletions lib/fetch.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { resolveHostname, isPrivateIP, parseIPv4, parseIPv6 } from "./ssrf.ts";
import { validateUrlForSSRF } from "./ssrf.ts";
import { HttpError } from "./types.ts";

type ReadResult<T> = { done: false; value: T } | { done: true; value?: T };
Expand Down Expand Up @@ -64,6 +64,9 @@ export class Client {
let redirectCount = 0;

while (true) {
// Validate host before each request (initial + redirects)
await this.validateHost(currentUrl);

const res = await this.makeRequest(currentUrl);

// Check for redirect responses
Expand Down Expand Up @@ -104,22 +107,12 @@ export class Client {
}

// Performs a single HTTP request with timeout and manual redirect handling.
// DNS resolution and SSRF validation are done atomically to prevent DNS rebinding.
private async makeRequest(url: string): Promise<Response> {
const parsed = this.validateUrl(url);

// Resolve hostname to IP and validate for SSRF in one atomic step.
// The request is then made directly to the validated IP.
const { targetUrl, host } = await this.resolveAndValidate(parsed);

let res: Response;
try {
res = await fetch(targetUrl, {
res = await fetch(url, {
signal: AbortSignal.timeout(this.timeoutMs),
redirect: "manual",
headers: {
Host: host,
},
});
} catch {
throw new HttpError(400, "fetch: unable to make request");
Expand All @@ -128,73 +121,6 @@ export class Client {
return res;
}

// Resolves hostname to IP and validates it for SSRF.
// Returns a URL with the IP address and the original host for the Host header.
// This ensures DNS resolution and validation happen atomically - no TOCTTOU.
private async resolveAndValidate(
parsed: URL,
): Promise<{ targetUrl: string; host: string }> {
const hostname = parsed.hostname;

// If SSRF protection is disabled, just return the original URL
if (!this.ssrfProtection) {
return { targetUrl: parsed.toString(), host: parsed.host };
}

// Check if hostname is already an IP address
const isIPv4 = parseIPv4(hostname) !== null;
const cleanedHostname = hostname.replace(/^\[|\]$/g, "");
const isIPv6 = parseIPv6(cleanedHostname) !== null;

if (isIPv4 || isIPv6) {
// Direct IP address - validate it
const ip = isIPv6 ? cleanedHostname : hostname;
const result = isPrivateIP(ip);
if (result.isPrivate) {
throw new HttpError(
403,
`fetch: host resolves to private IP (${result.reason})`,
);
}
return { targetUrl: parsed.toString(), host: parsed.host };
}

// Resolve hostname to IP addresses
const ips = await resolveHostname(hostname);

// Find the first non-private IP
for (const ip of ips) {
const result = isPrivateIP(ip);
if (!result.isPrivate) {
// Build URL with IP address instead of hostname
const targetUrl = this.buildUrlWithIP(parsed, ip);
return { targetUrl, host: parsed.host };
}
}

// All IPs are private
throw new HttpError(403, "fetch: host resolves to private IP");
}

// Builds a URL replacing the hostname with an IP address.
private buildUrlWithIP(original: URL, ip: string): string {
// Handle IPv6 addresses - they need brackets in URLs
const hostPart = ip.includes(":") ? `[${ip}]` : ip;

// Reconstruct URL with IP
let url = `${original.protocol}//${hostPart}`;

// Add port if non-default
if (original.port) {
url += `:${original.port}`;
}

// Add path, search, and hash
url += original.pathname + original.search + original.hash;

return url;
}

// Reads the response body into a pre-allocated buffer when Content-Length is known.
private async readWithKnownLength(
reader: StreamReader<Uint8Array>,
Expand Down Expand Up @@ -267,8 +193,7 @@ export class Client {
);
}

// Validates URL scheme and allowed hosts (does not do SSRF IP validation).
private validateUrl(url: string): URL {
private async validateHost(url: string): Promise<void> {
let parsed: URL;
try {
parsed = new URL(url);
Expand All @@ -286,7 +211,10 @@ export class Client {
throw new HttpError(403, "fetch: host is not allowed");
}

return parsed;
// SSRF protection: validate IP addresses
if (this.ssrfProtection) {
await validateUrlForSSRF(parsed);
}
}
}

Expand Down
Loading