Skip to content

Commit 4e16df2

Browse files
committed
feat: add indexed aggregates and filtered WASM aggregate path
- Add sumFloat64Indexed, minFloat64Indexed, maxFloat64Indexed (and int64 variants) to the Zig engine
- Extend WasmAggregateOperator to handle filtered aggregates: SIMD filter → indexed aggregate, bypassing Row[] materialization
- Fix int64 filter bug: restrict the WASM filter path to f64/i32 columns (no filterInt64Buffer exists)
- Upgrade unionIndices to an O(n+m) sorted merge
1 parent 0ce62a5 commit 4e16df2

File tree

3 files changed

+193
-42
lines changed

3 files changed

+193
-42
lines changed

src/operators.ts

Lines changed: 122 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,9 +1184,19 @@ export class InMemorySortOperator implements Operator {
11841184
export function canUseWasmAggregate(query: QueryDescriptor, columns: ColumnMeta[]): boolean {
11851185
if (!query.aggregates || query.aggregates.length === 0) return false;
11861186
if (query.groupBy && query.groupBy.length > 0) return false;
1187-
if (query.filters.length > 0) return false;
11881187

11891188
const colMap = new Map(columns.map(c => [c.name, c]));
1189+
1190+
// Filters are allowed if they are on numeric columns with scalar ops (no "in")
1191+
for (const f of query.filters) {
1192+
if (f.op === "in") return false;
1193+
const fc = colMap.get(f.column);
1194+
if (!fc) return false;
1195+
if (fc.dtype !== "float64" && fc.dtype !== "int32") return false;
1196+
if (fc.pages.some(p => p.encoding)) return false;
1197+
if (typeof f.value !== "number") return false;
1198+
}
1199+
11901200
for (const agg of query.aggregates) {
11911201
if (agg.column === "*") continue; // count(*) is always fine
11921202
const col = colMap.get(agg.column);
@@ -1226,6 +1236,8 @@ export class WasmAggregateOperator implements Operator {
12261236
this.consumed = true;
12271237

12281238
const aggregates = this.query.aggregates ?? [];
1239+
const filters = this.query.filters;
1240+
const hasFilters = filters.length > 0;
12291241
// Accumulator per aggregate
12301242
const acc: { sum: number; count: number; min: number; max: number }[] =
12311243
aggregates.map(() => ({ sum: 0, count: 0, min: Infinity, max: -Infinity }));
@@ -1234,58 +1246,126 @@ export class WasmAggregateOperator implements Operator {
12341246
for (const frag of this.fragments) {
12351247
const colMap = new Map(frag.columns.map(c => [c.name, c]));
12361248

1237-
for (let ai = 0; ai < aggregates.length; ai++) {
1238-
const agg = aggregates[ai];
1239-
if (agg.column === "*") {
1240-
// count(*): just sum row counts from page metadata
1241-
const firstCol = frag.columns[0];
1242-
if (firstCol) {
1243-
for (const page of firstCol.pages) {
1244-
acc[ai].count += page.rowCount;
1245-
}
1246-
}
1247-
continue;
1248-
}
1249+
// Collect all columns needed (aggregate + filter)
1250+
const neededCols = new Set<string>();
1251+
for (const agg of aggregates) if (agg.column !== "*") neededCols.add(agg.column);
1252+
for (const f of filters) neededCols.add(f.column);
12491253

1250-
const col = colMap.get(agg.column);
1251-
if (!col) continue;
1254+
// Process page by page
1255+
const firstCol = frag.columns[0];
1256+
if (!firstCol) continue;
1257+
const pageCount = firstCol.pages.length;
12521258

1253-
for (const page of col.pages) {
1254-
if (canSkipPage(page, this.query.filters, col.name)) {
1255-
this.pagesSkipped++;
1256-
continue;
1257-
}
1259+
for (let pi = 0; pi < pageCount; pi++) {
1260+
// Page-level skip via min/max stats
1261+
if (canSkipPage(firstCol.pages[pi], filters, firstCol.name)) {
1262+
this.pagesSkipped++;
1263+
continue;
1264+
}
12581265

1259-
const buf = await frag.readPage(col, page);
1266+
// Read needed columns for this page
1267+
const pageBuffers = new Map<string, ArrayBuffer>();
1268+
for (const colName of neededCols) {
1269+
const col = colMap.get(colName);
1270+
if (!col || !col.pages[pi]) continue;
1271+
const buf = await frag.readPage(col, col.pages[pi]);
12601272
this.bytesRead += buf.byteLength;
1273+
pageBuffers.set(colName, buf);
1274+
}
12611275

1262-
if (col.dtype === "float64") {
1263-
const count = buf.byteLength >> 3;
1264-
acc[ai].count += count;
1265-
if (agg.fn === "sum" || agg.fn === "avg") {
1266-
acc[ai].sum += this.wasm.sumFloat64(buf);
1276+
// Compute matching indices if filters exist
1277+
let matchCount = -1; // -1 = no filter, use full buffer
1278+
let indicesPtr = 0;
1279+
if (hasFilters) {
1280+
this.wasm.exports.resetHeap();
1281+
let currentIndices: Uint32Array | null = null;
1282+
1283+
for (const f of filters) {
1284+
const col = colMap.get(f.column);
1285+
const buf = pageBuffers.get(f.column);
1286+
if (!col || !buf) { currentIndices = new Uint32Array(0); break; }
1287+
const wasmOp = filterOpToWasm(f.op);
1288+
if (wasmOp < 0) { currentIndices = new Uint32Array(0); break; }
1289+
if (col.dtype !== "float64" && col.dtype !== "int32") {
1290+
// Only f64 and i32 have WASM SIMD filter support; skip unsupported types
1291+
currentIndices = new Uint32Array(0); break;
12671292
}
1268-
if (agg.fn === "min") {
1269-
const v = this.wasm.minFloat64(buf);
1270-
if (v < acc[ai].min) acc[ai].min = v;
1293+
const elemSize = col.dtype === "float64" ? 8 : 4;
1294+
const rowCount = buf.byteLength / elemSize;
1295+
1296+
// Copy data to WASM
1297+
const dataPtr = this.wasm.exports.alloc(buf.byteLength);
1298+
if (!dataPtr) { currentIndices = new Uint32Array(0); break; }
1299+
new Uint8Array(this.wasm.exports.memory.buffer, dataPtr, buf.byteLength).set(new Uint8Array(buf));
1300+
const outPtr = this.wasm.exports.alloc(rowCount * 4);
1301+
if (!outPtr) { currentIndices = new Uint32Array(0); break; }
1302+
1303+
let count: number;
1304+
if (col.dtype === "float64") {
1305+
count = this.wasm.exports.filterFloat64Buffer(dataPtr, rowCount, wasmOp, f.value as number, outPtr, rowCount);
1306+
} else {
1307+
count = this.wasm.exports.filterInt32Buffer(dataPtr, rowCount, wasmOp, f.value as number, outPtr, rowCount);
12711308
}
1272-
if (agg.fn === "max") {
1273-
const v = this.wasm.maxFloat64(buf);
1274-
if (v > acc[ai].max) acc[ai].max = v;
1309+
const filterResult = new Uint32Array(this.wasm.exports.memory.buffer.slice(outPtr, outPtr + count * 4));
1310+
1311+
if (currentIndices) {
1312+
currentIndices = wasmIntersect(currentIndices, filterResult, this.wasm);
1313+
} else {
1314+
currentIndices = filterResult;
12751315
}
1276-
} else if (col.dtype === "int64") {
1316+
}
1317+
1318+
if (!currentIndices || currentIndices.length === 0) continue; // all filtered out
1319+
1320+
// Copy indices to WASM for indexed aggregates
1321+
this.wasm.exports.resetHeap();
1322+
indicesPtr = this.wasm.exports.alloc(currentIndices.byteLength);
1323+
if (indicesPtr) {
1324+
new Uint32Array(this.wasm.exports.memory.buffer, indicesPtr, currentIndices.length).set(currentIndices);
1325+
}
1326+
matchCount = currentIndices.length;
1327+
}
1328+
1329+
// Aggregate per column
1330+
for (let ai = 0; ai < aggregates.length; ai++) {
1331+
const agg = aggregates[ai];
1332+
if (agg.column === "*") {
1333+
acc[ai].count += hasFilters ? matchCount : firstCol.pages[pi].rowCount;
1334+
continue;
1335+
}
1336+
1337+
const col = colMap.get(agg.column);
1338+
const buf = pageBuffers.get(agg.column);
1339+
if (!col || !buf) continue;
1340+
1341+
if (!hasFilters) {
1342+
// Unfiltered: use full-buffer SIMD aggregates
12771343
const count = buf.byteLength >> 3;
12781344
acc[ai].count += count;
1279-
if (agg.fn === "sum" || agg.fn === "avg") {
1280-
acc[ai].sum += Number(this.wasm.sumInt64(buf));
1281-
}
1282-
if (agg.fn === "min") {
1283-
const v = Number(this.wasm.minInt64(buf));
1284-
if (v < acc[ai].min) acc[ai].min = v;
1345+
if (col.dtype === "float64") {
1346+
if (agg.fn === "sum" || agg.fn === "avg") acc[ai].sum += this.wasm.sumFloat64(buf);
1347+
if (agg.fn === "min") { const v = this.wasm.minFloat64(buf); if (v < acc[ai].min) acc[ai].min = v; }
1348+
if (agg.fn === "max") { const v = this.wasm.maxFloat64(buf); if (v > acc[ai].max) acc[ai].max = v; }
1349+
} else if (col.dtype === "int64") {
1350+
if (agg.fn === "sum" || agg.fn === "avg") acc[ai].sum += Number(this.wasm.sumInt64(buf));
1351+
if (agg.fn === "min") { const v = Number(this.wasm.minInt64(buf)); if (v < acc[ai].min) acc[ai].min = v; }
1352+
if (agg.fn === "max") { const v = Number(this.wasm.maxInt64(buf)); if (v > acc[ai].max) acc[ai].max = v; }
12851353
}
1286-
if (agg.fn === "max") {
1287-
const v = Number(this.wasm.maxInt64(buf));
1288-
if (v > acc[ai].max) acc[ai].max = v;
1354+
} else {
1355+
// Filtered: use indexed aggregates on matching rows only
1356+
acc[ai].count += matchCount;
1357+
const dataPtr = this.wasm.exports.alloc(buf.byteLength);
1358+
if (!dataPtr) continue;
1359+
new Uint8Array(this.wasm.exports.memory.buffer, dataPtr, buf.byteLength).set(new Uint8Array(buf));
1360+
1361+
if (col.dtype === "float64") {
1362+
if (agg.fn === "sum" || agg.fn === "avg") acc[ai].sum += this.wasm.exports.sumFloat64Indexed(dataPtr, indicesPtr, matchCount);
1363+
if (agg.fn === "min") { const v = this.wasm.exports.minFloat64Indexed(dataPtr, indicesPtr, matchCount); if (v < acc[ai].min) acc[ai].min = v; }
1364+
if (agg.fn === "max") { const v = this.wasm.exports.maxFloat64Indexed(dataPtr, indicesPtr, matchCount); if (v > acc[ai].max) acc[ai].max = v; }
1365+
} else if (col.dtype === "int64") {
1366+
if (agg.fn === "sum" || agg.fn === "avg") acc[ai].sum += Number(this.wasm.exports.sumInt64Indexed(dataPtr, indicesPtr, matchCount));
1367+
if (agg.fn === "min") { const v = Number(this.wasm.exports.minInt64Indexed(dataPtr, indicesPtr, matchCount)); if (v < acc[ai].min) acc[ai].min = v; }
1368+
if (agg.fn === "max") { const v = Number(this.wasm.exports.maxInt64Indexed(dataPtr, indicesPtr, matchCount)); if (v > acc[ai].max) acc[ai].max = v; }
12891369
}
12901370
}
12911371
}

src/wasm-engine.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,14 @@ export interface WasmExports {
7878
minInt64Buffer(ptr: number, len: number): bigint;
7979
maxInt64Buffer(ptr: number, len: number): bigint;
8080

81+
// Indexed aggregates — operate on filtered subsets (aggregates.zig)
82+
sumFloat64Indexed(dataPtr: number, indicesPtr: number, count: number): number;
83+
minFloat64Indexed(dataPtr: number, indicesPtr: number, count: number): number;
84+
maxFloat64Indexed(dataPtr: number, indicesPtr: number, count: number): number;
85+
sumInt64Indexed(dataPtr: number, indicesPtr: number, count: number): bigint;
86+
minInt64Indexed(dataPtr: number, indicesPtr: number, count: number): bigint;
87+
maxInt64Indexed(dataPtr: number, indicesPtr: number, count: number): bigint;
88+
8189
// Compression (compression.zig)
8290
zstd_decompress(compressedPtr: number, compressedLen: number, decompressedPtr: number, decompressedCapacity: number): number;
8391
zstd_get_decompressed_size(compressedPtr: number, compressedLen: number): number;

wasm/src/wasm/aggregates.zig

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,3 +822,66 @@ export fn unionIndices(
822822

823823
return out_count;
824824
}
825+
826+
// ============================================================================
827+
// Indexed Aggregates — operate on a subset of rows identified by an index array.
828+
// Used for filtered aggregates: filter → get indices → aggregate only matches.
829+
// ============================================================================
830+
831+
/// Sum the float64 values of `data` selected by an index array.
///
/// Reads `indices[0..count]` and accumulates `data[indices[i]]` in that order.
/// Returns 0 when `count` is 0 (the loop body never runs).
/// Both arguments are many-item pointers with no length attached, so nothing
/// in this function checks that an index lies inside `data`; the caller must
/// guarantee that.
832+
export fn sumFloat64Indexed(data: [*]const f64, indices: [*]const u32, count: usize) f64 {
833+
var sum: f64 = 0;
834+
for (0..count) |i| sum += data[indices[i]];
835+
return sum;
836+
}
837+
838+
/// Smallest float64 value of `data` among the rows selected by an index array.
///
/// Seeds the running minimum with `data[indices[0]]` and scans
/// `indices[1..count]`.
/// Returns 0 when `count` is 0, which cannot be told apart from a genuine
/// minimum of 0; callers that need the distinction must check `count` first.
/// The comparison is a plain `<`, so a NaN met after the first element never
/// replaces the running minimum, while a NaN at `indices[0]` is returned as-is.
/// No check is made that an index lies inside `data`.
839+
export fn minFloat64Indexed(data: [*]const f64, indices: [*]const u32, count: usize) f64 {
840+
if (count == 0) return 0;
841+
var min_val = data[indices[0]];
842+
for (1..count) |i| {
843+
const v = data[indices[i]];
844+
if (v < min_val) min_val = v;
845+
}
846+
return min_val;
847+
}
848+
849+
/// Largest float64 value of `data` among the rows selected by an index array.
///
/// Seeds the running maximum with `data[indices[0]]` and scans
/// `indices[1..count]`.
/// Returns 0 when `count` is 0, which cannot be told apart from a genuine
/// maximum of 0; callers that need the distinction must check `count` first.
/// The comparison is a plain `>`, so a NaN met after the first element never
/// replaces the running maximum, while a NaN at `indices[0]` is returned as-is.
/// No check is made that an index lies inside `data`.
850+
export fn maxFloat64Indexed(data: [*]const f64, indices: [*]const u32, count: usize) f64 {
851+
if (count == 0) return 0;
852+
var max_val = data[indices[0]];
853+
for (1..count) |i| {
854+
const v = data[indices[i]];
855+
if (v > max_val) max_val = v;
856+
}
857+
return max_val;
858+
}
859+
860+
/// Sum the int64 values of `data` selected by an index array.
///
/// Reads `indices[0..count]` and accumulates `data[indices[i]]` in that order.
/// Returns 0 when `count` is 0 (the loop body never runs).
/// No check is made that an index lies inside `data`.
/// NOTE(review): the accumulation uses plain `+=` on i64, not a wrapping or
/// saturating operator, so what happens on overflow depends on the build
/// mode this module is compiled with — confirm against the build settings.
861+
export fn sumInt64Indexed(data: [*]const i64, indices: [*]const u32, count: usize) i64 {
862+
var sum: i64 = 0;
863+
for (0..count) |i| sum += data[indices[i]];
864+
return sum;
865+
}
866+
867+
/// Smallest int64 value of `data` among the rows selected by an index array.
///
/// Seeds the running minimum with `data[indices[0]]` and scans
/// `indices[1..count]`.
/// Returns 0 when `count` is 0, which cannot be told apart from a genuine
/// minimum of 0; callers that need the distinction must check `count` first.
/// No check is made that an index lies inside `data`.
868+
export fn minInt64Indexed(data: [*]const i64, indices: [*]const u32, count: usize) i64 {
869+
if (count == 0) return 0;
870+
var min_val = data[indices[0]];
871+
for (1..count) |i| {
872+
const v = data[indices[i]];
873+
if (v < min_val) min_val = v;
874+
}
875+
return min_val;
876+
}
877+
878+
/// Largest int64 value of `data` among the rows selected by an index array.
///
/// Seeds the running maximum with `data[indices[0]]` and scans
/// `indices[1..count]`.
/// Returns 0 when `count` is 0, which cannot be told apart from a genuine
/// maximum of 0; callers that need the distinction must check `count` first.
/// No check is made that an index lies inside `data`.
879+
export fn maxInt64Indexed(data: [*]const i64, indices: [*]const u32, count: usize) i64 {
880+
if (count == 0) return 0;
881+
var max_val = data[indices[0]];
882+
for (1..count) |i| {
883+
const v = data[indices[i]];
884+
if (v > max_val) max_val = v;
885+
}
886+
return max_val;
887+
}

0 commit comments

Comments
 (0)