Skip to content

Commit 78b18eb

Browse files
committed
feat: add Zod validation for query descriptors at API boundary
Replace unsafe `as` casts in query-do.ts parseQuery() with Zod schema validation. All POST endpoints (/query, /query/count, /query/exists, /query/first, /query/explain, /query/stream) now validate request bodies and return 400 with clear error messages on malformed input. Validated fields: table (required string), filters (op enum, column non-empty), aggregates (fn enum), sortDirection (asc|desc), limit (positive int), offset (non-negative int), cacheTTL (positive int), vectorSearch (topK positive int).
1 parent 832eafc commit 78b18eb

File tree

4 files changed

+134
-31
lines changed

4 files changed

+134
-31
lines changed

package.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,8 @@
4040
"src/wasm/querymode.wasm"
4141
],
4242
"license": "MIT",
43-
"packageManager": "pnpm@10.21.0"
43+
"packageManager": "pnpm@10.21.0",
44+
"dependencies": {
45+
"zod": "^4.3.6"
46+
}
4447
}

pnpm-lock.yaml

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/query-do.ts

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import { mergeQueryResults } from "./merge.js";
1111
import { coalesceRanges, fetchBounded, withRetry, withTimeout } from "./coalesce.js";
1212
import { VipCache } from "./vip-cache.js";
1313
import { parseLanceV2Columns } from "./lance-v2.js";
14+
import { parseAndValidateQuery } from "./query-schema.js";
1415
import wasmModule from "./wasm-module.js";
1516

1617
const FRAGMENT_POOL_MAX = 20; // Max Fragment DO slots per datacenter (idle slots cost nothing)
@@ -246,11 +247,13 @@ export class QueryDO implements DurableObject {
246247

247248
private async handleQuery(request: Request): Promise<Response> {
248249
const requestId = request.headers.get("x-querymode-request-id") ?? crypto.randomUUID();
249-
const body = await request.json() as Record<string, unknown>;
250-
if (!body.table || typeof body.table !== "string") {
251-
return this.json({ error: "Missing or invalid 'table' field" }, 400);
250+
const body = await request.json();
251+
let query: QueryDescriptor;
252+
try {
253+
query = this.parseQuery(body);
254+
} catch (err) {
255+
return this.json({ error: (err as Error).message }, 400);
252256
}
253-
const query = this.parseQuery(body);
254257
try {
255258
const result = await this.executeQuery(query);
256259
result.requestId = requestId;
@@ -291,25 +294,14 @@ export class QueryDO implements DurableObject {
291294
return `qr:${query.table}:${(h >>> 0).toString(36)}`;
292295
}
293296

294-
private parseQuery(request_body: Record<string, unknown>): QueryDescriptor {
295-
return {
296-
table: request_body.table as string,
297-
filters: (request_body.filters ?? []) as QueryDescriptor["filters"],
298-
projections: (request_body.projections ?? request_body.select ?? []) as string[],
299-
sortColumn: request_body.sortColumn as string | undefined,
300-
sortDirection: request_body.sortDirection as "asc" | "desc" | undefined,
301-
limit: request_body.limit as number | undefined,
302-
vectorSearch: request_body.vectorSearch as QueryDescriptor["vectorSearch"],
303-
aggregates: request_body.aggregates as QueryDescriptor["aggregates"],
304-
groupBy: request_body.groupBy as string[] | undefined,
305-
cacheTTL: request_body.cacheTTL as number | undefined,
306-
};
297+
private parseQuery(body: unknown): QueryDescriptor {
298+
return parseAndValidateQuery(body) as QueryDescriptor;
307299
}
308300

309301
private async handleCount(request: Request): Promise<Response> {
310-
const body = await request.json() as Record<string, unknown>;
311-
if (!body.table || typeof body.table !== "string") return this.json({ error: "Missing 'table'" }, 400);
312-
const query = this.parseQuery(body);
302+
const body = await request.json();
303+
let query: QueryDescriptor;
304+
try { query = this.parseQuery(body); } catch (err) { return this.json({ error: (err as Error).message }, 400); }
313305
try {
314306
// Fast path: no filters — sum page rowCounts from cached metadata
315307
if (query.filters.length === 0) {
@@ -331,9 +323,9 @@ export class QueryDO implements DurableObject {
331323
}
332324

333325
private async handleExists(request: Request): Promise<Response> {
334-
const body = await request.json() as Record<string, unknown>;
335-
if (!body.table || typeof body.table !== "string") return this.json({ error: "Missing 'table'" }, 400);
336-
const query = this.parseQuery(body);
326+
const body = await request.json();
327+
let query: QueryDescriptor;
328+
try { query = this.parseQuery(body); } catch (err) { return this.json({ error: (err as Error).message }, 400); }
337329
query.limit = 1;
338330
try {
339331
const result = await this.executeQuery(query);
@@ -344,9 +336,9 @@ export class QueryDO implements DurableObject {
344336
}
345337

346338
private async handleFirst(request: Request): Promise<Response> {
347-
const body = await request.json() as Record<string, unknown>;
348-
if (!body.table || typeof body.table !== "string") return this.json({ error: "Missing 'table'" }, 400);
349-
const query = this.parseQuery(body);
339+
const body = await request.json();
340+
let query: QueryDescriptor;
341+
try { query = this.parseQuery(body); } catch (err) { return this.json({ error: (err as Error).message }, 400); }
350342
query.limit = 1;
351343
try {
352344
const result = await this.executeQuery(query);
@@ -357,9 +349,9 @@ export class QueryDO implements DurableObject {
357349
}
358350

359351
private async handleExplain(request: Request): Promise<Response> {
360-
const body = await request.json() as Record<string, unknown>;
361-
if (!body.table || typeof body.table !== "string") return this.json({ error: "Missing 'table'" }, 400);
362-
const query = this.parseQuery(body);
352+
const body = await request.json();
353+
let query: QueryDescriptor;
354+
try { query = this.parseQuery(body); } catch (err) { return this.json({ error: (err as Error).message }, 400); }
363355
try {
364356
let meta: TableMeta | undefined = this.footerCache.get(query.table);
365357
const metaCached = !!meta;
@@ -1358,7 +1350,9 @@ export class QueryDO implements DurableObject {
13581350

13591351
/** Stream query results as NDJSON. */
13601352
private async handleQueryStream(request: Request): Promise<Response> {
1361-
const query = (await request.json()) as QueryDescriptor;
1353+
const body = await request.json();
1354+
let query: QueryDescriptor;
1355+
try { query = this.parseQuery(body); } catch (err) { return this.json({ error: (err as Error).message }, 400); }
13621356
const result = await this.executeQuery(query);
13631357

13641358
const { readable, writable } = new TransformStream<Uint8Array>();

src/query-schema.ts

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/**
2+
* Zod schemas for query descriptor validation.
3+
*
4+
* Validates POST bodies at the API boundary (query-do.ts) before they
5+
* reach the execution engine. Catches malformed filters, invalid ops,
6+
* bad aggregate fns, etc. with clear error messages.
7+
*/
8+
import { z } from "zod/v4";
9+
10+
/** Comparison operators accepted in a filter clause. */
const FILTER_OPS = ["eq", "neq", "gt", "gte", "lt", "lte", "in"] as const;

/** Element type allowed inside an array-valued filter operand. */
const filterListItemSchema = z.union([z.number(), z.string()]);

/**
 * Schema for one filter clause of a query.
 *
 * - `column`: non-empty string.
 * - `op`: one of {@link FILTER_OPS}.
 * - `value`: a number, a string, or an array whose items are numbers and/or strings.
 *
 * The schema does not tie the shape of `value` to `op`; any accepted value
 * shape is allowed with any operator.
 */
const filterOpSchema = z.object({
  column: z.string().min(1, "Filter column name cannot be empty"),
  op: z.enum(FILTER_OPS),
  value: z.union([z.number(), z.string(), z.array(filterListItemSchema)]),
});
19+
20+
/** Aggregate function names accepted in an aggregate clause. */
const AGGREGATE_FNS = ["sum", "avg", "min", "max", "count"] as const;

/**
 * Schema for one aggregate clause of a query.
 *
 * - `fn`: one of {@link AGGREGATE_FNS}.
 * - `column`: non-empty string.
 * - `alias`: optional string.
 */
const aggregateOpSchema = z.object({
  fn: z.enum(AGGREGATE_FNS),
  column: z.string().min(1, "Aggregate column name cannot be empty"),
  alias: z.string().optional(),
});
25+
26+
/** A query vector is either a plain array of numbers or a `Float32Array` instance. */
const queryVectorSchema = z.union([
  z.array(z.number()),
  z.instanceof(Float32Array),
]);

/**
 * Schema for the optional vector-search part of a query.
 *
 * - `column`: non-empty string.
 * - `queryVector`: see {@link queryVectorSchema}.
 * - `topK`: positive integer.
 */
const vectorSearchSchema = z.object({
  column: z.string().min(1),
  queryVector: queryVectorSchema,
  topK: z.number().int().positive(),
});
34+
35+
/** Array of strings; shared by the column-list fields below. */
const stringListSchema = z.array(z.string());

/** Integer strictly greater than zero; shared by `limit` and `cacheTTL`. */
const positiveIntSchema = z.number().int().positive();

/**
 * Schema for a complete query request body.
 *
 * Only `table` is required. `filters` and `projections` fall back to an
 * empty array when absent; every other field is optional and stays
 * `undefined` when omitted.
 */
export const queryDescriptorSchema = z.object({
  table: z.string().min(1, "Table name is required"),
  filters: z.array(filterOpSchema).default([]),
  projections: stringListSchema.default([]),
  // `select` is accepted as an alias for `projections`.
  select: stringListSchema.optional(),
  sortColumn: z.string().optional(),
  sortDirection: z.enum(["asc", "desc"]).optional(),
  limit: positiveIntSchema.optional(),
  offset: z.number().int().nonnegative().optional(),
  vectorSearch: vectorSearchSchema.optional(),
  aggregates: z.array(aggregateOpSchema).optional(),
  groupBy: stringListSchema.optional(),
  cacheTTL: positiveIntSchema.optional(),
});

/** Output type produced by a successful parse with {@link queryDescriptorSchema}. */
export type ValidatedQuery = z.infer<typeof queryDescriptorSchema>;
51+
52+
/**
 * Parse and validate a raw request body into a query descriptor.
 *
 * The body is checked against `queryDescriptorSchema`. On success the
 * validated fields are returned with the `select` alias folded into
 * `projections`: `select` is used only when `projections` is empty or
 * absent, and `select` itself is not part of the result.
 *
 * @param body - Untrusted value to validate (typically the decoded JSON body).
 * @returns The validated query fields. Optional fields that were not
 *   supplied are present with the value `undefined`.
 * @throws Error whose message is `Invalid query: ` followed by one
 *   `<path>: <message>` entry per validation issue, joined with `"; "`.
 *   Issues that concern the body as a whole (empty path, e.g. the body is
 *   `null` or not an object) are labelled `(body)`.
 */
export function parseAndValidateQuery(body: unknown): {
  table: string;
  filters: { column: string; op: "eq" | "neq" | "gt" | "gte" | "lt" | "lte" | "in"; value: number | string | (number | string)[] }[];
  projections: string[];
  sortColumn?: string;
  sortDirection?: "asc" | "desc";
  limit?: number;
  offset?: number;
  vectorSearch?: { column: string; queryVector: number[] | Float32Array; topK: number };
  aggregates?: { fn: "sum" | "avg" | "min" | "max" | "count"; column: string; alias?: string }[];
  groupBy?: string[];
  cacheTTL?: number;
} {
  const result = queryDescriptorSchema.safeParse(body);
  if (!result.success) {
    const issues = result.error.issues
      .map((issue) => {
        // A root-level failure has an empty path; without a label the
        // message would start with a bare ": ", which reads as garbage.
        // Segments are stringified explicitly so that a non-string key can
        // never make `join` throw.
        const where = issue.path.length > 0
          ? issue.path.map((segment) => String(segment)).join(".")
          : "(body)";
        return `${where}: ${issue.message}`;
      })
      .join("; ");
    throw new Error(`Invalid query: ${issues}`);
  }

  const data = result.data;
  // Merge the `select` alias into `projections`; explicit projections win.
  const projections = data.projections.length > 0
    ? data.projections
    : (data.select ?? []);

  return {
    table: data.table,
    filters: data.filters,
    projections,
    sortColumn: data.sortColumn,
    sortDirection: data.sortDirection,
    limit: data.limit,
    offset: data.offset,
    vectorSearch: data.vectorSearch,
    aggregates: data.aggregates,
    groupBy: data.groupBy,
    cacheTTL: data.cacheTTL,
  };
}

0 commit comments

Comments
 (0)