Skip to content

Commit 876f911

Browse files
committed
fix: aggregate expression args (COUNT(a+b)) no longer collide with COUNT(*)
compileAggregate and rewriteAggregatesAsColumns both fell back to "*" for non-column aggregate arguments, causing COUNT(expr) to produce the same output column name as COUNT(*). Added aggArgKey() in ast.ts that serializes expressions to deterministic keys (e.g. "add(a,b)"), used by both sites so HAVING column lookups always match.
1 parent 9ace9a3 commit 876f911

File tree

4 files changed

+51
-5
lines changed

4 files changed

+51
-5
lines changed

src/sql/ast.ts

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,26 @@ export type SqlStatement =
118118
| { kind: "select"; stmt: SelectStmt }
119119
| { kind: "show_versions"; stmt: ShowVersionsStmt }
120120
| { kind: "diff"; stmt: DiffStmt };
121+
122+
/**
123+
* Deterministic string key for an aggregate argument expression.
124+
* Returns "col" for column refs, "*" for star, or a serialized form
125+
* for complex expressions (e.g. "add(a,b)" for a + b).
126+
* Used by both compileAggregate and rewriteAggregatesAsColumns so
127+
* their output column names always match.
128+
*/
129+
export function aggArgKey(expr: SqlExpr): string {
130+
switch (expr.kind) {
131+
case "star": return "*";
132+
case "column": return expr.name;
133+
case "value": {
134+
const v = expr.value;
135+
return v.type === "null" ? "null" : String(v.value);
136+
}
137+
case "binary": return `${expr.op}(${aggArgKey(expr.left)},${aggArgKey(expr.right)})`;
138+
case "unary": return `${expr.op}(${aggArgKey(expr.operand)})`;
139+
case "call": return `${expr.name}(${expr.args.map(aggArgKey).join(",")})`;
140+
case "cast": return `cast(${aggArgKey(expr.expr)},${expr.targetType})`;
141+
default: return "_expr";
142+
}
143+
}

src/sql/compiler.test.ts

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,19 @@ describe("SQL Compiler", () => {
9494
expect(desc.aggregates).toEqual([{ fn: "count_distinct", column: "user_id" }]);
9595
});
9696

97+
it("compiles SUM(expression) with deterministic column key", () => {
98+
const desc = sql("SELECT SUM(price * quantity) FROM orders");
99+
expect(desc.aggregates).toEqual([{ fn: "sum", column: "multiply(price,quantity)" }]);
100+
});
101+
102+
it("COUNT(expr) does not collide with COUNT(*)", () => {
103+
const desc = sql("SELECT COUNT(*), COUNT(a + b) FROM t");
104+
expect(desc.aggregates).toHaveLength(2);
105+
expect(desc.aggregates![0].column).toBe("*");
106+
expect(desc.aggregates![1].column).toBe("add(a,b)");
107+
expect(desc.aggregates![0].column).not.toBe(desc.aggregates![1].column);
108+
});
109+
97110
it("compiles GROUP BY", () => {
98111
const desc = sql("SELECT region, SUM(sales) FROM data GROUP BY region");
99112
expect(desc.groupBy).toEqual(["region"]);
@@ -234,6 +247,16 @@ describe("SQL Compiler - compileFull", () => {
234247
}
235248
});
236249

250+
it("HAVING with expression-arg aggregate uses matching column key", () => {
251+
const result = sqlFull("SELECT dept, SUM(price * qty) FROM t GROUP BY dept HAVING SUM(price * qty) > 100");
252+
expect(result.havingExpr).toBeDefined();
253+
if (result.havingExpr!.kind === "binary" && result.havingExpr!.left.kind === "column") {
254+
expect(result.havingExpr!.left.name).toBe("sum_multiply(price,qty)");
255+
}
256+
// Must match the aggregate's column key
257+
expect(result.descriptor.aggregates![0].column).toBe("multiply(price,qty)");
258+
});
259+
237260
it("returns allOrderBy for multi-column ORDER BY", () => {
238261
const result = sqlFull("SELECT * FROM t ORDER BY name ASC, age DESC");
239262
expect(result.allOrderBy).toBeDefined();

src/sql/compiler.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import type { QueryDescriptor } from "../client.js";
44
import type { AggregateOp, FilterOp, WindowSpec } from "../types.js";
5-
import type { SelectStmt, SqlExpr, TableRef, SqlOrderBy, CteDef } from "./ast.js";
5+
import { aggArgKey, type SelectStmt, type SqlExpr, type TableRef, type SqlOrderBy, type CteDef } from "./ast.js";
66
import { rewriteAggregatesAsColumns } from "./evaluator.js";
77

88
const AGGREGATE_FNS = new Set(["COUNT", "SUM", "AVG", "MIN", "MAX", "COUNT_DISTINCT", "STDDEV", "VARIANCE", "MEDIAN", "PERCENTILE"]);
@@ -227,6 +227,7 @@ function extractColumnName(expr: SqlExpr): string | undefined {
227227
return undefined;
228228
}
229229

230+
230231
function isAggregateCall(name: string): boolean {
231232
return AGGREGATE_FNS.has(name.toUpperCase());
232233
}
@@ -240,8 +241,7 @@ function compileAggregate(expr: SqlExpr & { kind: "call" }, alias?: string): Agg
240241
}
241242

242243
if (expr.args.length > 0 && expr.args[0].kind !== "star") {
243-
const colName = extractColumnName(expr.args[0]);
244-
if (colName) column = colName;
244+
column = aggArgKey(expr.args[0]);
245245
}
246246

247247
const result: AggregateOp = { fn, column };

src/sql/evaluator.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/** Runtime SQL expression evaluator — evaluates AST nodes against Row data */
22

33
import type { Row } from "../types.js";
4-
import type { SqlExpr } from "./ast.js";
4+
import { aggArgKey, type SqlExpr } from "./ast.js";
55
import { compileLikeRegex } from "../decode.js";
66

77
export function evaluateExpr(expr: SqlExpr, row: Row): unknown {
@@ -208,7 +208,7 @@ export function rewriteAggregatesAsColumns(expr: SqlExpr): SqlExpr {
208208
case "call": {
209209
const fnUpper = expr.name.toUpperCase();
210210
if (isAggregate(fnUpper)) {
211-
const col = expr.args[0]?.kind === "star" ? "*" : (expr.args[0]?.kind === "column" ? expr.args[0].name : "*");
211+
const col = expr.args[0] ? aggArgKey(expr.args[0]) : "*";
212212
const fn = (expr.distinct && fnUpper === "COUNT") ? "count_distinct" : fnUpper.toLowerCase();
213213
return { kind: "column", name: `${fn}_${col}` };
214214
}

0 commit comments

Comments
 (0)