From a541aee9db04757e77314a14c928b40d8c5f452e Mon Sep 17 00:00:00 2001 From: Lingxuan Zuo Date: Mon, 30 Mar 2026 13:25:13 +0800 Subject: [PATCH 1/4] perf(vector): remove float precision loss in codecs and tighten benchmark path --- BUILD.bazel | 20 ++ README-zh.md | 37 +++ README.md | 37 +++ docs/local_vector_search_v01.md | 67 ++++ python_api/BUILD.bazel | 18 ++ python_api/README.md | 30 ++ python_api/tests/test_vector_search.py | 67 ++++ python_api/velaria_cli.py | 189 ++++++++++++ scripts/BUILD.bazel | 1 + scripts/build_py_cli_executable.sh | 39 +++ src/dataflow/api/dataframe.cc | 73 +++++ src/dataflow/api/dataframe.h | 9 + src/dataflow/api/session.cc | 13 + src/dataflow/api/session.h | 6 + src/dataflow/core/value.h | 52 +++- .../examples/vector_search_benchmark.cc | 75 +++++ src/dataflow/examples/velaria_cli.cc | 287 ++++++++++++++++++ src/dataflow/planner/plan.cc | 52 +++- src/dataflow/python/python_module.cc | 174 +++++++++-- src/dataflow/runtime/vector_index.cc | 161 ++++++++++ src/dataflow/runtime/vector_index.h | 36 +++ src/dataflow/serial/serializer.cc | 114 +++++-- src/dataflow/stream/binary_row_batch.cc | 32 +- src/dataflow/tests/vector_runtime_test.cc | 90 ++++++ 24 files changed, 1626 insertions(+), 53 deletions(-) create mode 100644 docs/local_vector_search_v01.md create mode 100644 python_api/tests/test_vector_search.py create mode 100644 python_api/velaria_cli.py create mode 100755 scripts/build_py_cli_executable.sh create mode 100644 src/dataflow/examples/vector_search_benchmark.cc create mode 100644 src/dataflow/examples/velaria_cli.cc create mode 100644 src/dataflow/runtime/vector_index.cc create mode 100644 src/dataflow/runtime/vector_index.h create mode 100644 src/dataflow/tests/vector_runtime_test.cc diff --git a/BUILD.bazel b/BUILD.bazel index 8113d80..fff8027 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -15,6 +15,7 @@ cc_library( "src/dataflow/runtime/job_master.cc", "src/dataflow/runtime/byte_transport.cc", 
"src/dataflow/runtime/rpc_runner.cc", + "src/dataflow/runtime/vector_index.cc", "src/dataflow/rpc/rpc_codec.cc", "src/dataflow/transport/ipc_transport.cc", "src/dataflow/ai/plugin_runtime.cc", @@ -43,6 +44,7 @@ cc_library( "src/dataflow/runtime/job_master.h", "src/dataflow/runtime/observability.h", "src/dataflow/runtime/rpc_runner.h", + "src/dataflow/runtime/vector_index.h", "src/dataflow/ai/plugin_runtime.h", "src/dataflow/rpc/rpc_codec.h", "src/dataflow/transport/ipc_transport.h", @@ -128,6 +130,12 @@ cc_binary( deps = [":dataflow_core"], ) +cc_binary( + name = "velaria_cli", + srcs = ["src/dataflow/examples/velaria_cli.cc"], + deps = [":dataflow_core"], +) + cc_library( name = "dataflow_actor_rpc_codec", srcs = [ @@ -235,6 +243,12 @@ cc_binary( deps = [":dataflow_stream_actor_runtime"], ) +cc_binary( + name = "vector_search_benchmark", + srcs = ["src/dataflow/examples/vector_search_benchmark.cc"], + deps = [":dataflow_core"], +) + cc_test( name = "sql_regression_test", srcs = ["src/dataflow/tests/sql_regression_test.cc"], @@ -270,3 +284,9 @@ cc_test( srcs = ["src/dataflow/tests/stream_strategy_explain_test.cc"], deps = [":dataflow_core"], ) + +cc_test( + name = "vector_runtime_test", + srcs = ["src/dataflow/tests/vector_runtime_test.cc"], + deps = [":dataflow_core"], +) diff --git a/README-zh.md b/README-zh.md index e3d0e83..47cf699 100644 --- a/README-zh.md +++ b/README-zh.md @@ -229,6 +229,42 @@ uv run --project python_api python python_api/demo_batch_sql_arrow.py uv run --project python_api python python_api/demo_stream_sql.py ``` +同时在 Session 侧新增了向量查询入口:`Session.vectorQuery(table, vector_column, query_vector, top_k, metric)`(metric 支持 cosine/dot/l2),以及 explain 接口 `Session.explainVectorQuery(...)`。 + +支持打包单文件 CLI 可执行产物(内含 Python 运行时依赖 + native `_velaria.so`): + +```bash +./scripts/build_py_cli_executable.sh +./dist/velaria-cli csv-sql \ + --csv /path/to/input.csv \ + --query "SELECT * FROM input_table LIMIT 5" +``` + +额外支持直接编译 native CLI 二进制(运行时不依赖 Python 
环境): + +```bash +bazel build //:velaria_cli +./bazel-bin/velaria_cli \ + --csv /path/to/input.csv \ + --query "SELECT * FROM input_table LIMIT 5" +``` + +native CLI 向量查询(fixed length vector,支持 cosine/cosin、dot 与 l2): + +```bash +./bazel-bin/velaria_cli \ + --csv /path/to/vectors.csv \ + --vector-column embedding \ + --query-vector "0.1,0.2,0.3" \ + --metric cosine \ + --top-k 5 +``` + +runtime 传输层现已在 proto-like 与 binary row batch codec 中保留 `FixedVector` 类型,跨进程传输时不会丢失向量维度语义。 +FixedVector 在内部 codec 里改为 raw float bit payload 编码,避免文本往返造成的精度损耗。 +当前向量检索范围为本地 exact scan(`mode=exact-scan`)+ 固定维度 float 向量;v0.1 不包含 ANN 与分布式执行路径。 +Arrow ingestion 已增加 `FixedSizeList` 的 native 快路径,可减少向量列的 Python 对象转换开销。 + ## 同机多进程实验路径 同机路径刻意保持最小: @@ -270,6 +306,7 @@ Dashboard: - `//:stream_benchmark` - `//:stream_actor_benchmark` - `//:tpch_q1_style_benchmark` +- `//:vector_search_benchmark` 同机 observability regression: diff --git a/README.md b/README.md index 3e19589..bb41c4a 100644 --- a/README.md +++ b/README.md @@ -190,6 +190,8 @@ Main API: - `Session.stream_sql(...)` - `Session.explain_stream_sql(...)` - `Session.start_stream_sql(...)` +- `Session.vector_search(table, vector_column, query_vector, top_k, metric)` (`metric`: cosine/dot/l2) +- `Session.explain_vector_search(table, vector_column, query_vector, top_k, metric)` Arrow ingestion accepts: @@ -208,6 +210,40 @@ uv run --project python_api python python_api/demo_batch_sql_arrow.py uv run --project python_api python python_api/demo_stream_sql.py ``` +Build a single-file CLI executable (bundles Python runtime deps + native `_velaria.so`): + +```bash +./scripts/build_py_cli_executable.sh +./dist/velaria-cli csv-sql \ + --csv /path/to/input.csv \ + --query "SELECT * FROM input_table LIMIT 5" +``` + +Build a native CLI binary (no Python runtime required): + +```bash +bazel build //:velaria_cli +./bazel-bin/velaria_cli \ + --csv /path/to/input.csv \ + --query "SELECT * FROM input_table LIMIT 5" +``` + +Vector query 
(fixed-length vector, cosine/dot/l2) via native CLI: + +```bash +./bazel-bin/velaria_cli \ + --csv /path/to/vectors.csv \ + --vector-column embedding \ + --query-vector "0.1,0.2,0.3" \ + --metric cosine \ + --top-k 5 +``` + +Runtime-level vector transport now preserves `FixedVector` through proto-like and binary row batch codecs, so cross-process payloads keep vector type and dimensions. +FixedVector serialization now uses raw float bit payload encoding in internal codecs to avoid text round-trip precision loss. +Current vector search scope is local-only exact scan (`mode=exact-scan`) with fixed-dimension float vectors; no ANN/distributed path in v0.1. +Arrow ingestion now includes a direct `FixedSizeList` fast path in the native bridge, reducing Python object conversion overhead on vector columns. + ## Same-Host Multi-Process Experiment The same-host path is intentionally minimal: @@ -249,6 +285,7 @@ Useful local targets: - `//:stream_benchmark` - `//:stream_actor_benchmark` - `//:tpch_q1_style_benchmark` +- `//:vector_search_benchmark` Same-host observability regression: diff --git a/docs/local_vector_search_v01.md b/docs/local_vector_search_v01.md new file mode 100644 index 0000000..b841fca --- /dev/null +++ b/docs/local_vector_search_v01.md @@ -0,0 +1,67 @@ +# Local Vector Search v0.1 (Velaria) + +## Scope + +This document defines a minimal local-first vector search path for Velaria. + +### Goals + +- Fixed-dimension `float32` vector column support. +- Exact scan backend only. +- Metrics: `cosine`, `dot`, `l2`. +- `top-k` query support. +- C++ API via `DataFrame` / `DataflowSession`. +- Python front-end API for invoking vector search. +- Explain text that mirrors actual runtime behavior. +- Keep ingestion/query path zero-copy-oriented where possible. + +### Non-goals (v0.1) + +- No ANN index (HNSW/IVF/PQ). +- No distributed vector execution. +- No standalone vector database subsystem. +- No new SQL grammar for vector search in this phase. 
+ +## Minimal abstractions + +- `DataType::FixedVector` tags a `Value` that holds a fixed-dimension float vector. +- `VectorIndex` runtime interface with an `ExactScanVectorIndex` implementation. +- `ExactScanVectorIndex` uses flat contiguous buffers and heap top-k selection for scan acceleration. +- Internal vector transport codecs use raw float bit payloads to avoid text precision loss. +- `VectorSearchMetric`: cosine/dot/l2 (the public C++ API exposes the same choices as `VectorDistanceMetric` and maps them onto this runtime enum). +- `VectorSearchResult`: `{row_id, score}`. + +## Public API draft + +### C++ + +- `DataFrame::vectorQuery(vector_column, query_vector, top_k, metric)` +- `DataFrame::explainVectorQuery(vector_column, query_vector, top_k, metric)` +- `DataflowSession::vectorQuery(table, vector_column, query_vector, top_k, metric)` +- `DataflowSession::explainVectorQuery(table, vector_column, query_vector, top_k, metric)` + +### Python + +- `Session.vector_search(table, vector_column, query_vector, top_k=10, metric="cosine")` +- `Session.explain_vector_search(table, vector_column, query_vector, top_k=10, metric="cosine")` + +## Explain fields + +Current explain output contains: + +- `mode=exact-scan` +- `metric=<cosine|dot|l2>` +- `dimension=<int>` +- `top_k=<int>` +- `candidate_rows=<int>` +- `filter_pushdown=false` +- `acceleration=flat-buffer+heap-topk` + +## Test matrix + +- Vector value roundtrip in proto-like serializer. +- Vector value roundtrip in binary row batch codec. +- Runtime query correctness for cosine/l2/dot top-k. +- Dimension mismatch rejection. +- Python API shape and argument validation. +- Arrow `FixedSizeList` ingestion fast path coverage. 
diff --git a/python_api/BUILD.bazel b/python_api/BUILD.bazel index 3e7b369..9992fb8 100644 --- a/python_api/BUILD.bazel +++ b/python_api/BUILD.bazel @@ -45,6 +45,13 @@ py_binary( deps = [":velaria_py_pkg"], ) +py_binary( + name = "velaria_cli", + srcs = ["velaria_cli.py"], + main = "velaria_cli.py", + deps = [":velaria_py_pkg"], +) + py_package( name = "velaria_pkg", packages = ["velaria"], @@ -157,3 +164,14 @@ py_test( ":velaria_py_pkg", ], ) + +py_test( + name = "vector_search_test", + srcs = ["tests/test_vector_search.py"], + main = "tests/test_vector_search.py", + imports = ["."], + deps = [ + ":velaria_py_pkg", + requirement("pyarrow"), + ], +) diff --git a/python_api/README.md b/python_api/README.md index 1c336c4..30715f9 100644 --- a/python_api/README.md +++ b/python_api/README.md @@ -21,6 +21,36 @@ uv run --project python_api python python_api/demo_batch_sql_arrow.py uv run --project python_api python python_api/demo_stream_sql.py ``` +Single-file CLI packaging (Python deps + native `_velaria.so`): + +```bash +./scripts/build_py_cli_executable.sh +./dist/velaria-cli csv-sql --csv /path/to/input.csv --query "SELECT * FROM input_table LIMIT 5" +./dist/velaria-cli vector-search --csv /path/to/vectors.csv --vector-column embedding --query-vector "0.1,0.2,0.3" --metric cosine --top-k 5 +``` + +Python Session API for local vector search: + +```python +from velaria import Session + +session = Session() +# assume a temp view named "vec_src" already exists +out = session.vector_search("vec_src", "embedding", [0.1, 0.2, 0.3], top_k=5, metric="dot") +print(out.to_rows()) +print(session.explain_vector_search("vec_src", "embedding", [0.1, 0.2, 0.3], top_k=5, metric="dot")) +``` + +Current vector search scope is local exact scan only (`cosine`/`dot`/`l2`) on fixed-dimension float vectors. 
+ +Native binary CLI alternative (runtime does not require Python environment): + +```bash +bazel build //:velaria_cli +./bazel-bin/velaria_cli --csv /path/to/input.csv --query "SELECT * FROM input_table LIMIT 5" +./bazel-bin/velaria_cli --csv /path/to/vectors.csv --vector-column embedding --query-vector "0.1,0.2,0.3" --metric l2 --top-k 5 +``` + ## CI packaging PR CI builds and uploads two native wheel variants: diff --git a/python_api/tests/test_vector_search.py b/python_api/tests/test_vector_search.py new file mode 100644 index 0000000..0d67740 --- /dev/null +++ b/python_api/tests/test_vector_search.py @@ -0,0 +1,67 @@ +import unittest + +import pyarrow as pa + +from velaria import Session + + +class VectorSearchTest(unittest.TestCase): + def test_vector_search_metrics_and_explain(self): + session = Session() + vectors = pa.array( + [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.0, 1.0, 0.0]], + type=pa.list_(pa.float32(), 3), + ) + table = pa.table({"id": [1, 2, 3], "embedding": vectors}) + df = session.create_dataframe_from_arrow(table) + session.create_temp_view("vec_src_py", df) + + cosine = session.vector_search( + table="vec_src_py", + vector_column="embedding", + query_vector=[1.0, 0.0, 0.0], + top_k=2, + metric="cosine", + ).to_rows() + self.assertEqual(cosine["rows"][0][0], 0) + + dot = session.vector_search( + table="vec_src_py", + vector_column="embedding", + query_vector=[1.0, 0.0, 0.0], + top_k=1, + metric="dot", + ).to_rows() + self.assertEqual(dot["rows"][0][0], 0) + + explain = session.explain_vector_search( + table="vec_src_py", + vector_column="embedding", + query_vector=[1.0, 0.0, 0.0], + top_k=2, + metric="cosine", + ) + self.assertIn("mode=exact-scan", explain) + self.assertIn("metric=cosine", explain) + self.assertIn("top_k=2", explain) + self.assertIn("acceleration=flat-buffer+heap-topk", explain) + + def test_vector_dimension_mismatch(self): + session = Session() + vectors = pa.array([[1.0, 0.0], [0.0, 1.0]], type=pa.list_(pa.float32(), 2)) + 
table = pa.table({"embedding": vectors}) + df = session.create_dataframe_from_arrow(table) + session.create_temp_view("vec_mismatch", df) + + with self.assertRaises(RuntimeError): + session.vector_search( + table="vec_mismatch", + vector_column="embedding", + query_vector=[1.0, 0.0, 0.0], + top_k=1, + metric="l2", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/python_api/velaria_cli.py b/python_api/velaria_cli.py new file mode 100644 index 0000000..f57fcd5 --- /dev/null +++ b/python_api/velaria_cli.py @@ -0,0 +1,189 @@ +import argparse +import json +import math +import pathlib +from typing import Iterable + +from velaria import Session + + +def _run_csv_sql(csv_path: pathlib.Path, table: str, query: str) -> int: + session = Session() + df = session.read_csv(str(csv_path)) + session.create_temp_view(table, df) + result = session.sql(query).to_arrow() + print( + json.dumps( + { + "table": table, + "query": query, + "schema": result.schema.names, + "rows": result.to_pylist(), + }, + indent=2, + ensure_ascii=False, + ) + ) + return 0 + + +def _parse_vector_text(text: str) -> list[float]: + value = text.strip() + if value.startswith("[") and value.endswith("]"): + value = value[1:-1].strip() + if not value: + return [] + return [float(part.strip()) for part in value.split(",") if part.strip()] + + +def _extract_row_vector(raw_value) -> list[float]: + if isinstance(raw_value, (list, tuple)): + return [float(v) for v in raw_value] + return _parse_vector_text(str(raw_value)) + + +def _cosine_distance(lhs: Iterable[float], rhs: Iterable[float]) -> float: + lhs_values = list(lhs) + rhs_values = list(rhs) + dot = sum(a * b for a, b in zip(lhs_values, rhs_values)) + lhs_norm = math.sqrt(sum(a * a for a in lhs_values)) + rhs_norm = math.sqrt(sum(b * b for b in rhs_values)) + if lhs_norm == 0.0 or rhs_norm == 0.0: + return 1.0 + similarity = dot / (lhs_norm * rhs_norm) + similarity = max(-1.0, min(1.0, similarity)) + return 1.0 - similarity + + +def 
_l2_distance(lhs: Iterable[float], rhs: Iterable[float]) -> float: + lhs_values = list(lhs) + rhs_values = list(rhs) + return math.sqrt(sum((a - b) * (a - b) for a, b in zip(lhs_values, rhs_values))) + + +def _dot_score(lhs: Iterable[float], rhs: Iterable[float]) -> float: + lhs_values = list(lhs) + rhs_values = list(rhs) + return sum(a * b for a, b in zip(lhs_values, rhs_values)) + + +def _run_vector_search( + csv_path: pathlib.Path, + vector_column: str, + query_vector: str, + metric: str, + top_k: int, +) -> int: + session = Session() + table = session.read_csv(str(csv_path)).to_arrow() + rows = table.to_pylist() + needle = _parse_vector_text(query_vector) + if not needle: + raise ValueError("--query-vector must not be empty") + + scored = [] + expected_dim = len(needle) + for row_index, row in enumerate(rows): + if vector_column not in row: + raise KeyError(f"vector column not found: {vector_column}") + vector = _extract_row_vector(row[vector_column]) + if len(vector) != expected_dim: + raise ValueError( + f"fixed length vector mismatch at row {row_index}: expect {expected_dim}, got {len(vector)}" + ) + if metric in ("cosine", "cosin"): + distance = _cosine_distance(vector, needle) + elif metric == "dot": + distance = _dot_score(vector, needle) + else: + distance = _l2_distance(vector, needle) + scored.append({"row_index": row_index, "distance": distance, "row": row}) + + scored.sort(key=lambda item: item["distance"], reverse=(metric == "dot")) + payload = { + "metric": "cosine" if metric in ("cosine", "cosin") else metric, + "top_k": top_k, + "rows": scored[:top_k], + } + print(json.dumps(payload, indent=2, ensure_ascii=False)) + return 0 + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="velaria-cli", + description="Velaria CLI for SQL query execution.", + ) + subparsers = parser.add_subparsers(dest="command", required=True) + + csv_sql = subparsers.add_parser( + "csv-sql", + help="Read CSV and run a SQL query 
through DataflowSession.", + ) + csv_sql.add_argument( + "--csv", + required=True, + help="CSV file path.", + ) + csv_sql.add_argument( + "--table", + default="input_table", + help="Temporary table name exposed to session.sql(...).", + ) + csv_sql.add_argument( + "--query", + required=True, + help="SQL query text.", + ) + + vector_search = subparsers.add_parser( + "vector-search", + help="Read CSV and run fixed-length vector nearest search.", + ) + vector_search.add_argument("--csv", required=True, help="CSV file path.") + vector_search.add_argument( + "--vector-column", + required=True, + help="Vector column name. Row value format supports '1,2,3' or '[1,2,3]'.", + ) + vector_search.add_argument( + "--query-vector", + required=True, + help="Query vector, e.g. '0.1,0.2,0.3'.", + ) + vector_search.add_argument( + "--metric", + default="cosine", + choices=["cosine", "cosin", "dot", "l2"], + help="Distance metric.", + ) + vector_search.add_argument( + "--top-k", + type=int, + default=5, + help="Return top-k nearest rows.", + ) + return parser + + +def main() -> int: + parser = _build_parser() + args = parser.parse_args() + + if args.command == "csv-sql": + return _run_csv_sql(pathlib.Path(args.csv), args.table, args.query) + if args.command == "vector-search": + return _run_vector_search( + csv_path=pathlib.Path(args.csv), + vector_column=args.vector_column, + query_vector=args.query_vector, + metric=args.metric, + top_k=args.top_k, + ) + + parser.error(f"unsupported command: {args.command}") + return 2 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/BUILD.bazel b/scripts/BUILD.bazel index b83c574..0afaec1 100644 --- a/scripts/BUILD.bazel +++ b/scripts/BUILD.bazel @@ -1,4 +1,5 @@ exports_files([ + "build_py_cli_executable.sh", "build_dashboard_frontend.sh", "build_native_wheel.py", "run_actor_rpc_e2e.sh", diff --git a/scripts/build_py_cli_executable.sh b/scripts/build_py_cli_executable.sh new file mode 100755 index 0000000..890948d --- 
/dev/null +++ b/scripts/build_py_cli_executable.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +PY_DIR="${ROOT_DIR}/python_api" +OUT_DIR="${1:-${ROOT_DIR}/dist}" +EXE_NAME="${2:-velaria-cli}" + +mkdir -p "${OUT_DIR}" + +NATIVE_SO="${ROOT_DIR}/bazel-bin/_velaria.so" +if [[ ! -f "${NATIVE_SO}" ]]; then + echo "[build-cli] building native extension //:velaria_pyext" + bazel build //:velaria_pyext + NATIVE_SO="$(bazel info bazel-bin)/_velaria.so" +fi + +if [[ ! -f "${NATIVE_SO}" ]]; then + echo "[build-cli] native extension not found: ${NATIVE_SO}" >&2 + exit 1 +fi + +cp "${NATIVE_SO}" "${PY_DIR}/velaria/_velaria.so" + +cleanup() { + rm -f "${PY_DIR}/velaria/_velaria.so" +} +trap cleanup EXIT + +echo "[build-cli] packaging one-file executable with uv + pyinstaller" +uv run --project "${PY_DIR}" --with pyinstaller pyinstaller \ + --onefile \ + --name "${EXE_NAME}" \ + --distpath "${OUT_DIR}" \ + --workpath "${OUT_DIR}/.build" \ + --specpath "${OUT_DIR}/.spec" \ + "${PY_DIR}/velaria_cli.py" + +echo "[build-cli] done: ${OUT_DIR}/${EXE_NAME}" diff --git a/src/dataflow/api/dataframe.cc b/src/dataflow/api/dataframe.cc index 562580d..d9aa757 100644 --- a/src/dataflow/api/dataframe.cc +++ b/src/dataflow/api/dataframe.cc @@ -7,6 +7,7 @@ #include #include "src/dataflow/ai/plugin_runtime.h" +#include "src/dataflow/runtime/vector_index.h" namespace dataflow { @@ -270,6 +271,78 @@ DataFrame DataFrame::aggregate(const std::vector& keys, return DataFrame(node, executor_); } +DataFrame DataFrame::vectorQuery(const std::string& vectorColumn, + const std::vector& queryVector, + size_t top_k, + VectorDistanceMetric metric) const { + if (queryVector.empty()) { + throw std::invalid_argument("query vector cannot be empty"); + } + const auto source = materialize(); + const size_t vector_index = source.schema.indexOf(vectorColumn); + std::vector> vectors; + vectors.reserve(source.rows.size()); + for (size_t i = 0; i < 
source.rows.size(); ++i) { + if (vector_index >= source.rows[i].size()) continue; + const auto& cell = source.rows[i][vector_index]; + std::vector vec; + if (cell.type() == DataType::FixedVector) { + vec = cell.asFixedVector(); + } else { + vec = Value::parseFixedVector(cell.toString()); + } + if (vec.size() != queryVector.size()) { + throw std::invalid_argument("fixed vector length mismatch in vectorQuery"); + } + vectors.push_back(std::move(vec)); + } + VectorSearchOptions options; + options.top_k = top_k; + options.metric = metric == VectorDistanceMetric::L2 + ? VectorSearchMetric::L2 + : (metric == VectorDistanceMetric::Dot ? VectorSearchMetric::Dot + : VectorSearchMetric::Cosine); + const auto index = makeExactScanVectorIndex(std::move(vectors)); + const auto scored = index->search(queryVector, options); + const size_t take = scored.size(); + Table out; + out.schema = Schema({"row_id", "score"}); + out.rows.reserve(take); + for (size_t i = 0; i < take; ++i) { + Row row; + row.emplace_back(static_cast(scored[i].row_id)); + row.emplace_back(scored[i].score); + out.rows.push_back(std::move(row)); + } + return DataFrame(std::move(out)); +} + +std::string DataFrame::explainVectorQuery(const std::string& vectorColumn, + const std::vector& queryVector, size_t top_k, + VectorDistanceMetric metric) const { + const auto source = materialize(); + const size_t vector_index = source.schema.indexOf(vectorColumn); + std::vector> vectors; + vectors.reserve(source.rows.size()); + for (const auto& row : source.rows) { + if (vector_index >= row.size()) continue; + vectors.push_back(row[vector_index].type() == DataType::FixedVector + ? row[vector_index].asFixedVector() + : Value::parseFixedVector(row[vector_index].toString())); + } + VectorSearchOptions options; + options.top_k = top_k; + options.metric = metric == VectorDistanceMetric::L2 + ? VectorSearchMetric::L2 + : (metric == VectorDistanceMetric::Dot ? 
VectorSearchMetric::Dot + : VectorSearchMetric::Cosine); + const auto index = makeExactScanVectorIndex(std::move(vectors)); + if (!queryVector.empty() && index->dimension() != 0 && queryVector.size() != index->dimension()) { + throw std::invalid_argument("query vector dimension mismatch in explainVectorQuery"); + } + return index->explain(options); +} + GroupedDataFrame DataFrame::groupBy(const std::vector& keys) const { const auto source = materialize(); std::vector idxs; diff --git a/src/dataflow/api/dataframe.h b/src/dataflow/api/dataframe.h index f47f388..937212b 100644 --- a/src/dataflow/api/dataframe.h +++ b/src/dataflow/api/dataframe.h @@ -13,6 +13,7 @@ namespace dataflow { class DataFrame; +enum class VectorDistanceMetric { Cosine, Dot, L2 }; class GroupedDataFrame { public: @@ -45,6 +46,14 @@ class DataFrame { DataFrame cache() const; DataFrame aggregate(const std::vector& keys, const std::vector& aggs) const; + DataFrame vectorQuery(const std::string& vectorColumn, + const std::vector& queryVector, + size_t top_k, + VectorDistanceMetric metric = VectorDistanceMetric::Cosine) const; + std::string explainVectorQuery(const std::string& vectorColumn, + const std::vector& queryVector, + size_t top_k, + VectorDistanceMetric metric = VectorDistanceMetric::Cosine) const; GroupedDataFrame groupBy(const std::vector& keys) const; DataFrame join(const DataFrame& right, const std::string& leftOn, const std::string& rightOn, diff --git a/src/dataflow/api/session.cc b/src/dataflow/api/session.cc index a9778f2..66f980e 100644 --- a/src/dataflow/api/session.cc +++ b/src/dataflow/api/session.cc @@ -670,6 +670,19 @@ DataFrame DataflowSession::sql(const std::string& sql) { return result; } +DataFrame DataflowSession::vectorQuery(const std::string& table, const std::string& vector_column, + const std::vector& query_vector, size_t top_k, + VectorDistanceMetric metric) { + return catalog_.getView(table).vectorQuery(vector_column, query_vector, top_k, metric); +} + +std::string 
DataflowSession::explainVectorQuery(const std::string& table, + const std::string& vector_column, + const std::vector& query_vector, + size_t top_k, VectorDistanceMetric metric) { + return catalog_.getView(table).explainVectorQuery(vector_column, query_vector, top_k, metric); +} + StreamingDataFrame DataflowSession::streamSql(const std::string& sql) { const auto statement = sql::SqlParser::parse(sql); if (statement.kind != sql::SqlStatementKind::Select) { diff --git a/src/dataflow/api/session.h b/src/dataflow/api/session.h index 2232824..b6ce22a 100644 --- a/src/dataflow/api/session.h +++ b/src/dataflow/api/session.h @@ -26,6 +26,12 @@ class DataflowSession { void createTempView(const std::string& name, const StreamingDataFrame& df); void registerStreamSink(const std::string& name, std::shared_ptr sink); DataFrame sql(const std::string& sql); + DataFrame vectorQuery(const std::string& table, const std::string& vector_column, + const std::vector& query_vector, size_t top_k, + VectorDistanceMetric metric = VectorDistanceMetric::Cosine); + std::string explainVectorQuery(const std::string& table, const std::string& vector_column, + const std::vector& query_vector, size_t top_k, + VectorDistanceMetric metric = VectorDistanceMetric::Cosine); StreamingDataFrame streamSql(const std::string& sql); std::string explainStreamSql(const std::string& sql, const StreamingQueryOptions& options = {}); diff --git a/src/dataflow/core/value.h b/src/dataflow/core/value.h index 9c3ef7b..e3bf3b2 100644 --- a/src/dataflow/core/value.h +++ b/src/dataflow/core/value.h @@ -5,10 +5,11 @@ #include #include #include +#include namespace dataflow { -enum class DataType { Nil = 0, Int64 = 1, Double = 2, String = 3 }; +enum class DataType { Nil = 0, Int64 = 1, Double = 2, String = 3, FixedVector = 4 }; class Value { public: @@ -17,6 +18,8 @@ class Value { Value(double v) : type_(DataType::Double), i64_(0), d_(v), s_("") {} Value(const char* s) : type_(DataType::String), i64_(0), d_(0.0), s_(s) {} 
Value(std::string s) : type_(DataType::String), i64_(0), d_(0.0), s_(std::move(s)) {} + Value(std::vector v) + : type_(DataType::FixedVector), i64_(0), d_(0.0), s_(""), vec_(std::move(v)) {} DataType type() const { return type_; } @@ -52,6 +55,32 @@ class Value { return s_; } + const std::vector& asFixedVector() const { + if (type_ != DataType::FixedVector) { + throw std::runtime_error("value is not fixed vector"); + } + return vec_; + } + + static std::vector parseFixedVector(const std::string& raw) { + std::string text = raw; + if (!text.empty() && text.front() == '[' && text.back() == ']') { + text = text.substr(1, text.size() - 2); + } + std::vector out; + std::stringstream ss(text); + std::string token; + while (std::getline(ss, token, ',')) { + if (token.empty()) continue; + std::stringstream trim(token); + std::string cleaned; + trim >> cleaned; + if (cleaned.empty()) continue; + out.push_back(std::stof(cleaned)); + } + return out; + } + std::string toString() const { switch (type_) { case DataType::Nil: @@ -65,6 +94,16 @@ class Value { } case DataType::String: return s_; + case DataType::FixedVector: { + std::ostringstream oss; + oss << "["; + for (std::size_t i = 0; i < vec_.size(); ++i) { + if (i > 0) oss << ","; + oss << std::fixed << std::setprecision(6) << vec_[i]; + } + oss << "]"; + return oss.str(); + } } return ""; } @@ -79,6 +118,7 @@ class Value { int64_t i64_; double d_; std::string s_; + std::vector vec_; int compare(const Value& rhs) const { if (type_ != rhs.type_) { @@ -98,6 +138,16 @@ class Value { return (d_ < rhs.d_) ? -1 : (d_ > rhs.d_ ? 1 : 0); case DataType::String: return (s_ < rhs.s_) ? -1 : (s_ > rhs.s_ ? 1 : 0); + case DataType::FixedVector: { + if (vec_.size() != rhs.vec_.size()) { + return vec_.size() < rhs.vec_.size() ? 
-1 : 1; + } + for (std::size_t i = 0; i < vec_.size(); ++i) { + if (vec_[i] < rhs.vec_[i]) return -1; + if (vec_[i] > rhs.vec_[i]) return 1; + } + return 0; + } } return 0; } diff --git a/src/dataflow/examples/vector_search_benchmark.cc b/src/dataflow/examples/vector_search_benchmark.cc new file mode 100644 index 0000000..03bb782 --- /dev/null +++ b/src/dataflow/examples/vector_search_benchmark.cc @@ -0,0 +1,75 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "src/dataflow/api/session.h" + +namespace { + +dataflow::Table makeSyntheticTable(std::size_t rows, std::size_t dim, uint32_t seed) { + std::mt19937 rng(seed); + std::uniform_real_distribution dist(-1.0f, 1.0f); + + dataflow::Table table; + table.schema = dataflow::Schema({"id", "embedding"}); + table.rows.reserve(rows); + for (std::size_t i = 0; i < rows; ++i) { + std::vector vec(dim); + for (std::size_t d = 0; d < dim; ++d) vec[d] = dist(rng); + dataflow::Row row; + row.emplace_back(static_cast(i)); + row.emplace_back(dataflow::Value(std::move(vec))); + table.rows.push_back(std::move(row)); + } + return table; +} + +std::vector makeQuery(std::size_t dim, uint32_t seed) { + std::mt19937 rng(seed); + std::uniform_real_distribution dist(-1.0f, 1.0f); + std::vector q(dim); + for (std::size_t i = 0; i < dim; ++i) q[i] = dist(rng); + return q; +} + +void runCase(std::size_t rows, std::size_t dim, dataflow::VectorDistanceMetric metric, + const std::string& metric_name) { + auto table = makeSyntheticTable(rows, dim, static_cast(rows + dim)); + auto query = makeQuery(dim, static_cast(dim)); + + auto& session = dataflow::DataflowSession::builder(); + const std::string view_name = "vec_bench_" + std::to_string(rows) + "_" + std::to_string(dim); + session.createTempView(view_name, session.createDataFrame(table)); + + const auto begin = std::chrono::steady_clock::now(); + auto out = session.vectorQuery(view_name, "embedding", query, 10, metric).toTable(); + const auto elapsed = 
std::chrono::duration_cast( + std::chrono::steady_clock::now() - begin); + + std::cout << "{" + << "\"rows\":" << rows << "," + << "\"dimension\":" << dim << "," + << "\"top_k\":10," + << "\"metric\":\"" << metric_name << "\"," + << "\"elapsed_ms\":" << elapsed.count() << "," + << "\"result_rows\":" << out.rows.size() << "" + << "}" << std::endl; +} + +} // namespace + +int main() { + std::cout << "[vector-benchmark] exact scan regression baseline" << std::endl; + for (std::size_t rows : {10000ULL, 100000ULL}) { + for (std::size_t dim : {128ULL, 768ULL}) { + runCase(rows, dim, dataflow::VectorDistanceMetric::Cosine, "cosine"); + runCase(rows, dim, dataflow::VectorDistanceMetric::Dot, "dot"); + runCase(rows, dim, dataflow::VectorDistanceMetric::L2, "l2"); + } + } + return 0; +} diff --git a/src/dataflow/examples/velaria_cli.cc b/src/dataflow/examples/velaria_cli.cc new file mode 100644 index 0000000..0cd64ee --- /dev/null +++ b/src/dataflow/examples/velaria_cli.cc @@ -0,0 +1,287 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "src/dataflow/api/session.h" +#include "src/dataflow/core/value.h" + +namespace { + +std::string escapeJson(const std::string& input) { + std::string out; + out.reserve(input.size()); + for (char c : input) { + switch (c) { + case '\\': + out += "\\\\"; + break; + case '"': + out += "\\\""; + break; + case '\n': + out += "\\n"; + break; + case '\r': + out += "\\r"; + break; + case '\t': + out += "\\t"; + break; + default: + out += c; + break; + } + } + return out; +} + +std::string valueToJson(const dataflow::Value& value) { + switch (value.type()) { + case dataflow::DataType::Nil: + return "null"; + case dataflow::DataType::Int64: + return std::to_string(value.asInt64()); + case dataflow::DataType::Double: + return value.toString(); + case dataflow::DataType::String: + return "\"" + escapeJson(value.asString()) + "\""; + } + return "null"; +} + +void printUsage(const char* program) { + std::cerr << "Usage: 
" << program + << " --csv [--query ] [--table ] [--delimiter ]\n" + << " " << program + << " --csv --vector-column --query-vector \n" + << " [--metric cosine|cosin|dot|l2] [--top-k ]\n"; +} + +std::vector parseVectorText(const std::string& raw) { + std::string input = raw; + if (!input.empty() && input.front() == '[' && input.back() == ']') { + input = input.substr(1, input.size() - 2); + } + std::vector out; + std::stringstream ss(input); + std::string item; + while (std::getline(ss, item, ',')) { + if (item.empty()) continue; + out.push_back(std::stod(item)); + } + return out; +} + +std::vector parseVectorValue(const dataflow::Value& value) { + if (value.type() == dataflow::DataType::String) { + return parseVectorText(value.asString()); + } + return parseVectorText(value.toString()); +} + +double l2Distance(const std::vector& lhs, const std::vector& rhs) { + double sum = 0.0; + for (std::size_t i = 0; i < lhs.size(); ++i) { + const double diff = lhs[i] - rhs[i]; + sum += diff * diff; + } + return std::sqrt(sum); +} + +double cosineDistance(const std::vector& lhs, const std::vector& rhs) { + double dot = 0.0; + double lhs_norm = 0.0; + double rhs_norm = 0.0; + for (std::size_t i = 0; i < lhs.size(); ++i) { + dot += lhs[i] * rhs[i]; + lhs_norm += lhs[i] * lhs[i]; + rhs_norm += rhs[i] * rhs[i]; + } + if (lhs_norm == 0.0 || rhs_norm == 0.0) { + return 1.0; + } + double similarity = dot / (std::sqrt(lhs_norm) * std::sqrt(rhs_norm)); + if (similarity > 1.0) similarity = 1.0; + if (similarity < -1.0) similarity = -1.0; + return 1.0 - similarity; +} + +double dotScore(const std::vector& lhs, const std::vector& rhs) { + double dot = 0.0; + for (std::size_t i = 0; i < lhs.size(); ++i) { + dot += lhs[i] * rhs[i]; + } + return dot; +} + +} // namespace + +int main(int argc, char** argv) { + std::string csv_path; + std::string query; + std::string table = "input_table"; + std::string vector_column; + std::string query_vector; + std::string metric = "cosine"; + std::size_t 
top_k = 5; + char delimiter = ','; + + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg == "--csv" && i + 1 < argc) { + csv_path = argv[++i]; + } else if (arg == "--query" && i + 1 < argc) { + query = argv[++i]; + } else if (arg == "--table" && i + 1 < argc) { + table = argv[++i]; + } else if (arg == "--delimiter" && i + 1 < argc) { + std::string value = argv[++i]; + if (value.size() != 1) { + std::cerr << "--delimiter expects a single character\n"; + return 2; + } + delimiter = value[0]; + } else if (arg == "--vector-column" && i + 1 < argc) { + vector_column = argv[++i]; + } else if (arg == "--query-vector" && i + 1 < argc) { + query_vector = argv[++i]; + } else if (arg == "--metric" && i + 1 < argc) { + metric = argv[++i]; + } else if (arg == "--top-k" && i + 1 < argc) { + top_k = static_cast(std::stoull(argv[++i])); + } else if (arg == "-h" || arg == "--help") { + printUsage(argv[0]); + return 0; + } else { + std::cerr << "Unknown argument: " << arg << "\n"; + printUsage(argv[0]); + return 2; + } + } + + if (csv_path.empty()) { + printUsage(argv[0]); + return 2; + } + + const bool vector_mode = !vector_column.empty() || !query_vector.empty(); + if (!vector_mode && query.empty()) { + printUsage(argv[0]); + return 2; + } + + try { + auto& session = dataflow::DataflowSession::builder(); + auto df = session.read_csv(csv_path, delimiter); + + if (vector_mode) { + auto result = df.toTable(); + if (!result.schema.has(vector_column)) { + throw std::runtime_error("vector column not found: " + vector_column); + } + const auto vector_index = result.schema.indexOf(vector_column); + const auto needle = parseVectorText(query_vector); + if (needle.empty()) { + throw std::runtime_error("query vector cannot be empty"); + } + + struct Candidate { + std::size_t row_index; + double distance; + }; + std::vector candidates; + candidates.reserve(result.rows.size()); + + for (std::size_t i = 0; i < result.rows.size(); ++i) { + auto vec = 
parseVectorValue(result.rows[i][vector_index]); + if (vec.size() != needle.size()) { + throw std::runtime_error("fixed length vector mismatch at row " + std::to_string(i)); + } + double distance = 0.0; + if (metric == "cosine" || metric == "cosin") { + distance = cosineDistance(vec, needle); + } else if (metric == "dot") { + distance = dotScore(vec, needle); + } else if (metric == "l2") { + distance = l2Distance(vec, needle); + } else { + throw std::runtime_error("unsupported metric: " + metric); + } + candidates.push_back(Candidate{i, distance}); + } + + if (metric == "dot") { + std::sort(candidates.begin(), candidates.end(), [](const Candidate& lhs, const Candidate& rhs) { + return lhs.distance > rhs.distance; + }); + } else { + std::sort(candidates.begin(), candidates.end(), [](const Candidate& lhs, const Candidate& rhs) { + return lhs.distance < rhs.distance; + }); + } + + const std::size_t emit = std::min(top_k, candidates.size()); + std::cout << "{\n"; + std::cout << " \"metric\": \"" << (metric == "cosin" ? 
"cosine" : metric) << "\",\n"; + std::cout << " \"top_k\": " << top_k << ",\n"; + std::cout << " \"rows\": [\n"; + for (std::size_t i = 0; i < emit; ++i) { + const auto& c = candidates[i]; + std::cout << " {\"row_index\": " << c.row_index << ", \"distance\": " << c.distance + << ", \"row\": ["; + const auto& row = result.rows[c.row_index]; + for (std::size_t j = 0; j < row.size(); ++j) { + if (j > 0) std::cout << ", "; + std::cout << valueToJson(row[j]); + } + std::cout << "]}"; + if (i + 1 < emit) std::cout << ","; + std::cout << "\n"; + } + std::cout << " ]\n"; + std::cout << "}\n"; + return 0; + } + + session.createTempView(table, df); + auto result = session.sql(query).toTable(); + + std::cout << "{\n"; + std::cout << " \"table\": \"" << escapeJson(table) << "\",\n"; + std::cout << " \"query\": \"" << escapeJson(query) << "\",\n"; + std::cout << " \"schema\": ["; + for (std::size_t i = 0; i < result.schema.fields.size(); ++i) { + if (i > 0) { + std::cout << ", "; + } + std::cout << "\"" << escapeJson(result.schema.fields[i]) << "\""; + } + std::cout << "],\n"; + + std::cout << " \"rows\": [\n"; + for (std::size_t r = 0; r < result.rows.size(); ++r) { + std::cout << " ["; + for (std::size_t c = 0; c < result.rows[r].size(); ++c) { + if (c > 0) { + std::cout << ", "; + } + std::cout << valueToJson(result.rows[r][c]); + } + std::cout << "]"; + if (r + 1 < result.rows.size()) { + std::cout << ","; + } + std::cout << "\n"; + } + std::cout << " ]\n"; + std::cout << "}\n"; + return 0; + } catch (const std::exception& ex) { + std::cerr << "velaria_cli failed: " << ex.what() << "\n"; + return 1; + } +} diff --git a/src/dataflow/planner/plan.cc b/src/dataflow/planner/plan.cc index 662a8c0..7f92bda 100644 --- a/src/dataflow/planner/plan.cc +++ b/src/dataflow/planner/plan.cc @@ -1,6 +1,9 @@ #include "src/dataflow/planner/plan.h" #include +#include +#include +#include #include #include @@ -71,7 +74,36 @@ bool (*predicateForOp(const std::string& op))(const Value&, const 
Value&) { std::string serializeValue(const Value& value) { std::string out; appendInt(&out, static_cast(value.type())); - appendToken(&out, value.toString()); + switch (value.type()) { + case DataType::Nil: + appendToken(&out, ""); + break; + case DataType::Int64: + appendToken(&out, std::to_string(value.asInt64())); + break; + case DataType::Double: { + std::ostringstream oss; + oss.precision(std::numeric_limits::max_digits10); + oss << value.asDouble(); + appendToken(&out, oss.str()); + break; + } + case DataType::String: + appendToken(&out, value.asString()); + break; + case DataType::FixedVector: { + std::ostringstream oss; + const auto& vec = value.asFixedVector(); + oss << vec.size(); + for (float v : vec) { + uint32_t bits = 0; + std::memcpy(&bits, &v, sizeof(bits)); + oss << ";" << bits; + } + appendToken(&out, oss.str()); + break; + } + } return out; } @@ -88,6 +120,24 @@ Value deserializeValue(const std::string& payload) { return Value(std::stod(raw)); case DataType::String: return Value(raw); + case DataType::FixedVector: { + std::vector vec; + std::stringstream ss(raw); + std::string token; + if (!std::getline(ss, token, ';')) return Value(vec); + const std::size_t n = static_cast(std::stoull(token)); + vec.reserve(n); + for (std::size_t i = 0; i < n; ++i) { + if (!std::getline(ss, token, ';')) { + throw std::runtime_error("plan decode: invalid fixed vector payload"); + } + const uint32_t bits = static_cast(std::stoul(token)); + float v = 0.0f; + std::memcpy(&v, &bits, sizeof(v)); + vec.push_back(v); + } + return Value(std::move(vec)); + } } throw std::runtime_error("plan decode: unsupported value type"); } diff --git a/src/dataflow/python/python_module.cc b/src/dataflow/python/python_module.cc index 65e3d01..eec42f3 100644 --- a/src/dataflow/python/python_module.cc +++ b/src/dataflow/python/python_module.cc @@ -174,6 +174,29 @@ std::vector parseStringList(PyObject* obj, const char* arg_name) { return values; } +std::vector parseFloatVector(PyObject* 
obj, const char* arg_name) { + if (!PySequence_Check(obj)) { + throw std::runtime_error(std::string(arg_name) + " must be a sequence of floats"); + } + PyObject* seq = PySequence_Fast(obj, arg_name); + if (seq == nullptr) { + throw std::runtime_error(std::string("failed to read ") + arg_name); + } + std::vector values; + const Py_ssize_t count = PySequence_Fast_GET_SIZE(seq); + values.reserve(static_cast(count)); + PyObject** items = PySequence_Fast_ITEMS(seq); + for (Py_ssize_t i = 0; i < count; ++i) { + if (!PyFloat_Check(items[i]) && !PyLong_Check(items[i])) { + Py_DECREF(seq); + throw std::runtime_error(std::string(arg_name) + " must contain only numbers"); + } + values.push_back(static_cast(PyFloat_AsDouble(items[i]))); + } + Py_DECREF(seq); + return values; +} + df::Value valueFromPy(PyObject* obj) { if (obj == Py_None) { return df::Value(); @@ -190,6 +213,9 @@ df::Value valueFromPy(PyObject* obj) { if (PyUnicode_Check(obj)) { return df::Value(std::string(PyUnicode_AsUTF8(obj))); } + if (PyList_Check(obj) || PyTuple_Check(obj)) { + return df::Value(parseFloatVector(obj, "value")); + } throw std::runtime_error("value must be None, int, float, bool, or string"); } @@ -204,6 +230,12 @@ enum class FastArrowColumnKind { Float64, Utf8, LargeUtf8, + FixedSizeListFloat32, +}; + +struct FastArrowColumnSpec { + FastArrowColumnKind kind = FastArrowColumnKind::Unsupported; + int32_t fixed_list_size = 0; }; bool bitmapHasValue(const uint8_t* bitmap, int64_t index) { @@ -213,31 +245,42 @@ bool bitmapHasValue(const uint8_t* bitmap, int64_t index) { return ((bitmap[index >> 3] >> (index & 7)) & 0x01u) != 0; } -FastArrowColumnKind fastArrowColumnKind(const ArrowSchema* schema) { +FastArrowColumnSpec fastArrowColumnSpec(const ArrowSchema* schema) { + FastArrowColumnSpec spec; if (schema == nullptr || schema->format == nullptr) { - return FastArrowColumnKind::Unsupported; + return spec; } const std::string format(schema->format); - if (format == "b") return 
FastArrowColumnKind::Bool; - if (format == "i") return FastArrowColumnKind::Int32; - if (format == "I") return FastArrowColumnKind::UInt32; - if (format == "l") return FastArrowColumnKind::Int64; - if (format == "L") return FastArrowColumnKind::UInt64; - if (format == "f") return FastArrowColumnKind::Float32; - if (format == "g") return FastArrowColumnKind::Float64; - if (format == "u") return FastArrowColumnKind::Utf8; - if (format == "U") return FastArrowColumnKind::LargeUtf8; - return FastArrowColumnKind::Unsupported; -} - -df::Value valueFromFastArrowColumn(FastArrowColumnKind kind, const ArrowArray* array, int64_t row_index) { + if (format == "b") spec.kind = FastArrowColumnKind::Bool; + if (format == "i") spec.kind = FastArrowColumnKind::Int32; + if (format == "I") spec.kind = FastArrowColumnKind::UInt32; + if (format == "l") spec.kind = FastArrowColumnKind::Int64; + if (format == "L") spec.kind = FastArrowColumnKind::UInt64; + if (format == "f") spec.kind = FastArrowColumnKind::Float32; + if (format == "g") spec.kind = FastArrowColumnKind::Float64; + if (format == "u") spec.kind = FastArrowColumnKind::Utf8; + if (format == "U") spec.kind = FastArrowColumnKind::LargeUtf8; + if (spec.kind != FastArrowColumnKind::Unsupported) return spec; + + // Arrow C data interface fixed-size list format: +w: + if (format.rfind("+w:", 0) == 0 && schema->n_children == 1 && schema->children != nullptr && + schema->children[0] != nullptr && std::string(schema->children[0]->format) == "f") { + spec.kind = FastArrowColumnKind::FixedSizeListFloat32; + spec.fixed_list_size = static_cast(std::stoi(format.substr(3))); + return spec; + } + return spec; +} + +df::Value valueFromFastArrowColumn(const FastArrowColumnSpec& spec, const ArrowArray* array, + int64_t row_index) { const int64_t offset_row = row_index + array->offset; const auto* validity = reinterpret_cast(array->n_buffers > 0 ? 
array->buffers[0] : nullptr); if (array->null_count != 0 && !bitmapHasValue(validity, offset_row)) { return df::Value(); } - switch (kind) { + switch (spec.kind) { case FastArrowColumnKind::Bool: { const auto* bits = reinterpret_cast(array->buffers[1]); return df::Value(static_cast(((bits[offset_row >> 3] >> (offset_row & 7)) & 0x01u) != 0)); @@ -268,6 +311,23 @@ df::Value valueFromFastArrowColumn(FastArrowColumnKind kind, const ArrowArray* a const auto end = offsets[offset_row + 1]; return df::Value(std::string(data + begin, data + end)); } + case FastArrowColumnKind::FixedSizeListFloat32: { + if (array->n_children != 1 || array->children == nullptr || array->children[0] == nullptr || + spec.fixed_list_size <= 0) { + throw std::runtime_error("invalid Arrow fixed-size-list column"); + } + const ArrowArray* child = array->children[0]; + const int64_t base_index = (offset_row * static_cast(spec.fixed_list_size)) + child->offset; + const auto* values = reinterpret_cast(child->buffers[1]); + if (values == nullptr) { + throw std::runtime_error("missing child value buffer for fixed-size-list"); + } + std::vector out(static_cast(spec.fixed_list_size)); + for (int32_t i = 0; i < spec.fixed_list_size; ++i) { + out[static_cast(i)] = values[base_index + i]; + } + return df::Value(std::move(out)); + } case FastArrowColumnKind::Unsupported: break; } @@ -301,19 +361,19 @@ bool tryTableFromArrowCapsules(PyObject* obj, df::Table* table) { std::vector names; names.reserve(static_cast(schema->n_children)); - std::vector kinds; - kinds.reserve(static_cast(schema->n_children)); + std::vector specs; + specs.reserve(static_cast(schema->n_children)); for (int64_t i = 0; i < schema->n_children; ++i) { const auto* child_schema = schema->children[i]; - const auto kind = fastArrowColumnKind(child_schema); - if (kind == FastArrowColumnKind::Unsupported) { + const auto spec = fastArrowColumnSpec(child_schema); + if (spec.kind == FastArrowColumnKind::Unsupported) { Py_DECREF(pair); return 
false; } names.emplace_back(child_schema != nullptr && child_schema->name != nullptr ? child_schema->name : ("col_" + std::to_string(static_cast(i)))); - kinds.push_back(kind); + specs.push_back(spec); } std::vector rows(static_cast(array->length), @@ -326,7 +386,7 @@ bool tryTableFromArrowCapsules(PyObject* obj, df::Table* table) { } for (int64_t row = 0; row < array->length; ++row) { rows[static_cast(row)][static_cast(column)] = - valueFromFastArrowColumn(kinds[static_cast(column)], child_array, row); + valueFromFastArrowColumn(specs[static_cast(column)], child_array, row); } } @@ -540,6 +600,14 @@ PyObject* pyFromValue(const df::Value& value) { return PyFloat_FromDouble(value.asDouble()); case df::DataType::String: return PyUnicode_FromString(value.asString().c_str()); + case df::DataType::FixedVector: { + const auto& vec = value.asFixedVector(); + PyObject* out = PyList_New(static_cast(vec.size())); + for (size_t i = 0; i < vec.size(); ++i) { + PyList_SET_ITEM(out, static_cast(i), PyFloat_FromDouble(vec[i])); + } + return out; + } } Py_RETURN_NONE; } @@ -577,6 +645,8 @@ std::string arrowFormatForValue(const df::Value& value) { return "l"; case df::DataType::Double: return "g"; + case df::DataType::FixedVector: + return "u"; } return "u"; } @@ -1083,6 +1153,62 @@ PyObject* sessionExplainStreamSql(PyVelariaSession* self, PyObject* args, PyObje }); } +df::VectorDistanceMetric parseVectorMetric(const std::string& metric) { + if (metric == "cosine" || metric == "cosin") return df::VectorDistanceMetric::Cosine; + if (metric == "dot") return df::VectorDistanceMetric::Dot; + if (metric == "l2") return df::VectorDistanceMetric::L2; + throw std::runtime_error("metric must be one of: cosine/cosin, dot, l2"); +} + +PyObject* sessionVectorSearch(PyVelariaSession* self, PyObject* args, PyObject* kwargs) { + return withExceptionTranslation([&]() -> PyObject* { + const char* table = nullptr; + const char* vector_column = nullptr; + PyObject* query_vector = nullptr; + unsigned 
long long top_k = 10; + const char* metric = "cosine"; + static const char* kwlist[] = {"table", "vector_column", "query_vector", "top_k", "metric", + nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "ssO|Ks", const_cast(kwlist), &table, + &vector_column, &query_vector, &top_k, &metric)) { + return nullptr; + } + const auto query = parseFloatVector(query_vector, "query_vector"); + df::DataFrame result; + { + AllowThreads allow; + result = self->session_ptr->vectorQuery(table, vector_column, query, static_cast(top_k), + parseVectorMetric(metric)); + } + return wrapDataFrame(result); + }); +} + +PyObject* sessionExplainVectorSearch(PyVelariaSession* self, PyObject* args, PyObject* kwargs) { + return withExceptionTranslation([&]() -> PyObject* { + const char* table = nullptr; + const char* vector_column = nullptr; + PyObject* query_vector = nullptr; + unsigned long long top_k = 10; + const char* metric = "cosine"; + static const char* kwlist[] = {"table", "vector_column", "query_vector", "top_k", "metric", + nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "ssO|Ks", const_cast(kwlist), &table, + &vector_column, &query_vector, &top_k, &metric)) { + return nullptr; + } + const auto query = parseFloatVector(query_vector, "query_vector"); + std::string explain; + { + AllowThreads allow; + explain = self->session_ptr->explainVectorQuery(table, vector_column, query, + static_cast(top_k), + parseVectorMetric(metric)); + } + return PyUnicode_FromString(explain.c_str()); + }); +} + PyObject* dataFrameToRows(PyVelariaDataFrame* self, PyObject*) { return withExceptionTranslation([&]() -> PyObject* { std::unique_ptr table; @@ -1379,6 +1505,10 @@ PyMethodDef sessionMethods[] = { {"explain_stream_sql", reinterpret_cast(sessionExplainStreamSql), METH_VARARGS | METH_KEYWORDS, "Explain a SELECT or INSERT INTO ... SELECT ... 
streaming SQL query."}, + {"vector_search", reinterpret_cast(sessionVectorSearch), + METH_VARARGS | METH_KEYWORDS, "Run exact local vector search on a temp view."}, + {"explain_vector_search", reinterpret_cast(sessionExplainVectorSearch), + METH_VARARGS | METH_KEYWORDS, "Explain exact local vector search strategy."}, {nullptr, nullptr, 0, nullptr}, }; diff --git a/src/dataflow/runtime/vector_index.cc b/src/dataflow/runtime/vector_index.cc new file mode 100644 index 0000000..f5c8be1 --- /dev/null +++ b/src/dataflow/runtime/vector_index.cc @@ -0,0 +1,161 @@ +#include "src/dataflow/runtime/vector_index.h" + +#include +#include +#include +#include +#include + +namespace dataflow { +namespace { + +const char* metricName(VectorSearchMetric metric) { + switch (metric) { + case VectorSearchMetric::Cosine: + return "cosine"; + case VectorSearchMetric::Dot: + return "dot"; + case VectorSearchMetric::L2: + return "l2"; + } + return "cosine"; +} + +class ExactScanVectorIndex : public VectorIndex { + public: + explicit ExactScanVectorIndex(std::vector> vectors) + : row_count_(vectors.size()), dimension_(vectors.empty() ? 
0 : vectors.front().size()) { + if (vectors.empty()) return; + for (const auto& row : vectors) { + if (row.size() != dimension_) { + throw std::invalid_argument("exact scan vector index requires fixed-length vectors"); + } + } + data_.resize(row_count_ * dimension_); + norms_.resize(row_count_, 0.0); + for (std::size_t row = 0; row < row_count_; ++row) { + double norm = 0.0; + for (std::size_t col = 0; col < dimension_; ++col) { + const float v = vectors[row][col]; + data_[row * dimension_ + col] = v; + norm += static_cast(v) * static_cast(v); + } + norms_[row] = std::sqrt(norm); + } + } + + std::size_t dimension() const override { return dimension_; } + std::size_t size() const override { return row_count_; } + + std::vector search(const std::vector& query, + const VectorSearchOptions& options) const override { + if (row_count_ == 0) return {}; + if (query.size() != dimension()) { + throw std::invalid_argument("query vector dimension mismatch"); + } + + const std::size_t k = std::max(1, options.top_k); + auto cmp_desc = [](const VectorSearchResult& lhs, const VectorSearchResult& rhs) { + return lhs.score < rhs.score; + }; + auto cmp_asc = [](const VectorSearchResult& lhs, const VectorSearchResult& rhs) { + return lhs.score > rhs.score; + }; + std::priority_queue, decltype(cmp_desc)> + max_heap(cmp_desc); + std::priority_queue, decltype(cmp_asc)> + min_heap(cmp_asc); + + double query_norm = 0.0; + for (float v : query) query_norm += static_cast(v) * static_cast(v); + query_norm = std::sqrt(query_norm); + + for (std::size_t row = 0; row < row_count_; ++row) { + const float* base = data_.data() + row * dimension_; + double score = 0.0; + if (options.metric == VectorSearchMetric::L2) { + double sum = 0.0; + for (std::size_t col = 0; col < dimension_; ++col) { + const double diff = static_cast(base[col]) - static_cast(query[col]); + sum += diff * diff; + } + score = std::sqrt(sum); + } else { + double dot = 0.0; + for (std::size_t col = 0; col < dimension_; ++col) { + 
dot += static_cast(base[col]) * static_cast(query[col]); + } + if (options.metric == VectorSearchMetric::Dot) { + score = dot; + } else { + const double denom = norms_[row] * query_norm; + score = denom == 0.0 ? 1.0 : (1.0 - std::max(-1.0, std::min(1.0, dot / denom))); + } + } + VectorSearchResult current{row, score}; + if (options.metric == VectorSearchMetric::Dot) { + if (min_heap.size() < k) { + min_heap.push(current); + } else if (current.score > min_heap.top().score) { + min_heap.pop(); + min_heap.push(current); + } + } else { + if (max_heap.size() < k) { + max_heap.push(current); + } else if (current.score < max_heap.top().score) { + max_heap.pop(); + max_heap.push(current); + } + } + } + + std::vector scored; + if (options.metric == VectorSearchMetric::Dot) { + while (!min_heap.empty()) { + scored.push_back(min_heap.top()); + min_heap.pop(); + } + std::sort(scored.begin(), scored.end(), + [](const VectorSearchResult& lhs, const VectorSearchResult& rhs) { + return lhs.score > rhs.score; + }); + } else { + while (!max_heap.empty()) { + scored.push_back(max_heap.top()); + max_heap.pop(); + } + std::sort(scored.begin(), scored.end(), + [](const VectorSearchResult& lhs, const VectorSearchResult& rhs) { + return lhs.score < rhs.score; + }); + } + return scored; + } + + std::string explain(const VectorSearchOptions& options) const override { + std::ostringstream out; + out << "mode=exact-scan\n"; + out << "metric=" << metricName(options.metric) << "\n"; + out << "dimension=" << dimension() << "\n"; + out << "top_k=" << options.top_k << "\n"; + out << "candidate_rows=" << size() << "\n"; + out << "acceleration=flat-buffer+heap-topk\n"; + out << "filter_pushdown=false\n"; + return out.str(); + } + + private: + std::size_t row_count_ = 0; + std::size_t dimension_ = 0; + std::vector data_; + std::vector norms_; +}; + +} // namespace + +std::unique_ptr makeExactScanVectorIndex(std::vector> vectors) { + return std::unique_ptr(new 
ExactScanVectorIndex(std::move(vectors))); +} + +} // namespace dataflow diff --git a/src/dataflow/runtime/vector_index.h b/src/dataflow/runtime/vector_index.h new file mode 100644 index 0000000..d743a33 --- /dev/null +++ b/src/dataflow/runtime/vector_index.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include +#include +#include + +namespace dataflow { + +enum class VectorSearchMetric { Cosine, Dot, L2 }; + +struct VectorSearchOptions { + VectorSearchMetric metric = VectorSearchMetric::Cosine; + std::size_t top_k = 10; +}; + +struct VectorSearchResult { + std::size_t row_id = 0; + double score = 0.0; +}; + +class VectorIndex { + public: + virtual ~VectorIndex() = default; + + virtual std::size_t dimension() const = 0; + virtual std::size_t size() const = 0; + virtual std::vector search( + const std::vector& query, + const VectorSearchOptions& options) const = 0; + virtual std::string explain(const VectorSearchOptions& options) const = 0; +}; + +std::unique_ptr makeExactScanVectorIndex(std::vector> vectors); + +} // namespace dataflow diff --git a/src/dataflow/serial/serializer.cc b/src/dataflow/serial/serializer.cc index 2e03993..359f199 100644 --- a/src/dataflow/serial/serializer.cc +++ b/src/dataflow/serial/serializer.cc @@ -2,10 +2,75 @@ #include #include +#include +#include #include namespace dataflow { +namespace { + +std::string encodeValuePayload(const Value& value) { + std::ostringstream out; + switch (value.type()) { + case DataType::Nil: + return ""; + case DataType::Int64: + return std::to_string(value.asInt64()); + case DataType::Double: + out.precision(std::numeric_limits::max_digits10); + out << value.asDouble(); + return out.str(); + case DataType::String: + return value.asString(); + case DataType::FixedVector: { + const auto& vec = value.asFixedVector(); + out << vec.size(); + for (float v : vec) { + uint32_t bits = 0; + std::memcpy(&bits, &v, sizeof(bits)); + out << ";" << bits; + } + return out.str(); + } + } + return ""; +} + +Value 
decodeValuePayload(DataType type, const std::string& payload) { + switch (type) { + case DataType::Nil: + return Value(); + case DataType::Int64: + return Value(static_cast(std::stoll(payload))); + case DataType::Double: + return Value(std::stod(payload)); + case DataType::String: + return Value(payload); + case DataType::FixedVector: { + std::vector vec; + std::stringstream ss(payload); + std::string token; + if (!std::getline(ss, token, ';')) return Value(vec); + const std::size_t n = static_cast(std::stoull(token)); + vec.reserve(n); + for (std::size_t i = 0; i < n; ++i) { + if (!std::getline(ss, token, ';')) { + throw std::runtime_error("invalid fixed vector payload"); + } + const uint32_t bits = static_cast(std::stoul(token)); + float v = 0.0f; + std::memcpy(&v, &bits, sizeof(v)); + vec.push_back(v); + } + return Value(std::move(vec)); + } + } + return Value(); +} + +} // namespace + std::string ProtoLikeSerializer::name() const { return "proto-like"; } std::string ProtoLikeSerializer::serialize(const Table& table) const { @@ -16,11 +81,11 @@ std::string ProtoLikeSerializer::serialize(const Table& table) const { } out << table.rows.size() << '\n'; for (const auto& row : table.rows) { - out << row.size() << ':'; + out << row.size(); for (size_t i = 0; i < row.size(); ++i) { const auto& v = row[i]; - out << static_cast(v.type()) << '|' << v.toString(); - if (i + 1 < row.size()) out << ','; + const std::string payload = encodeValuePayload(v); + out << "|" << static_cast(v.type()) << ":" << payload.size() << ":" << payload; } out << '\n'; } @@ -47,34 +112,29 @@ Table ProtoLikeSerializer::deserialize(const std::string& payload) const { for (size_t r = 0; r < rowCount; ++r) { if (!std::getline(in, token)) return Table(); - const auto colon = token.find(':'); - if (colon == std::string::npos) return Table(); - const size_t itemCount = static_cast(std::stoull(token.substr(0, colon))); - std::string payloadRest = token.substr(colon + 1); + const auto first_sep = 
token.find('|'); + const size_t itemCount = static_cast(std::stoull(first_sep == std::string::npos + ? token + : token.substr(0, first_sep))); Row row; row.reserve(itemCount); - size_t offset = 0; + size_t offset = first_sep == std::string::npos ? token.size() : first_sep + 1; for (size_t i = 0; i < itemCount; ++i) { - size_t sep = payloadRest.find(',', offset); - std::string item = (sep == std::string::npos) ? payloadRest.substr(offset) - : payloadRest.substr(offset, sep - offset); - offset = (sep == std::string::npos) ? payloadRest.size() : (sep + 1); - - const size_t bar = item.find('|'); - if (bar == std::string::npos) return Table(); - const auto typeCode = static_cast(std::stoi(item.substr(0, bar))); - const std::string value = item.substr(bar + 1); - - if (typeCode == DataType::Nil) { - row.emplace_back(); - } else if (typeCode == DataType::Int64) { - row.emplace_back(static_cast(std::stoll(value))); - } else if (typeCode == DataType::Double) { - row.emplace_back(std::stod(value)); - } else { - row.emplace_back(Value(value)); - } + const size_t type_sep = token.find(':', offset); + if (type_sep == std::string::npos) return Table(); + const auto typeCode = static_cast(std::stoi(token.substr(offset, type_sep - offset))); + const size_t len_sep = token.find(':', type_sep + 1); + if (len_sep == std::string::npos) return Table(); + const size_t payload_len = + static_cast(std::stoull(token.substr(type_sep + 1, len_sep - type_sep - 1))); + const size_t payload_begin = len_sep + 1; + const size_t payload_end = payload_begin + payload_len; + if (payload_end > token.size()) return Table(); + const std::string payload = token.substr(payload_begin, payload_len); + row.emplace_back(decodeValuePayload(typeCode, payload)); + offset = payload_end; + if (offset < token.size() && token[offset] == '|') ++offset; } table.rows.push_back(std::move(row)); } diff --git a/src/dataflow/stream/binary_row_batch.cc b/src/dataflow/stream/binary_row_batch.cc index acb8fef..e3c6774 100644 
--- a/src/dataflow/stream/binary_row_batch.cc +++ b/src/dataflow/stream/binary_row_batch.cc @@ -298,12 +298,17 @@ size_t estimateValueSize(const Value& value, const PreparedBinaryRowColumn& colu case DataType::Double: return sizeof(uint64_t); case DataType::String: - case DataType::Nil: - default: if (column.encoding == kEncodingDictionary) { return varintSize(column.dictionary_index.at(value.toString())); } return stringEncodedSize(value.toString()); + case DataType::FixedVector: { + const auto& vec = value.asFixedVector(); + return varintSize(vec.size()) + vec.size() * sizeof(uint32_t); + } + case DataType::Nil: + default: + return stringEncodedSize(value.toString()); } } @@ -384,6 +389,14 @@ size_t serializeRangeInternal(const Table& table, size_t row_begin, size_t row_e uint64_t raw = 0; std::memcpy(&raw, &d, sizeof(raw)); writeU64(writer, raw); + } else if (columns[i].type == DataType::FixedVector) { + const auto& vec = value.asFixedVector(); + writeVarint(writer, vec.size()); + for (float item : vec) { + uint32_t bits = 0; + std::memcpy(&bits, &item, sizeof(bits)); + writeU32(writer, bits); + } } else { writeString(writer, value.toString()); } @@ -417,6 +430,21 @@ bool readValue(const BufferCursor& src, size_t* offset, DataType type, Value* ou *out = Value(std::move(text)); return true; } + case DataType::FixedVector: { + uint64_t dim = 0; + if (!readVarint(src, offset, &dim)) return false; + std::vector vec; + vec.reserve(static_cast(dim)); + for (uint64_t i = 0; i < dim; ++i) { + uint32_t bits = 0; + if (!readU32(src, offset, &bits)) return false; + float v = 0.0f; + std::memcpy(&v, &bits, sizeof(v)); + vec.push_back(v); + } + *out = Value(std::move(vec)); + return true; + } } } diff --git a/src/dataflow/tests/vector_runtime_test.cc b/src/dataflow/tests/vector_runtime_test.cc new file mode 100644 index 0000000..14c1584 --- /dev/null +++ b/src/dataflow/tests/vector_runtime_test.cc @@ -0,0 +1,90 @@ +#include +#include +#include +#include +#include +#include + 
+#include "src/dataflow/api/session.h" +#include "src/dataflow/serial/serializer.h" +#include "src/dataflow/stream/binary_row_batch.h" + +namespace { + +void expect(bool cond, const std::string& msg) { + if (!cond) throw std::runtime_error(msg); +} + +bool nearlyEqual(double lhs, double rhs, double eps = 1e-6) { + return std::fabs(lhs - rhs) <= eps; +} + +} // namespace + +int main() { + try { + dataflow::Table table; + table.schema = dataflow::Schema({"id", "embedding"}); + table.rows = { + {dataflow::Value(int64_t(1)), dataflow::Value(std::vector{1.0f, 0.0f, 0.0f})}, + {dataflow::Value(int64_t(2)), dataflow::Value(std::vector{0.9f, 0.1f, 0.0f})}, + {dataflow::Value(int64_t(3)), dataflow::Value(std::vector{0.0f, 1.0f, 0.0f})}, + }; + + dataflow::ProtoLikeSerializer text_codec; + auto text_payload = text_codec.serialize(table); + auto text_roundtrip = text_codec.deserialize(text_payload); + expect(text_roundtrip.rows.size() == 3, "proto-like row count mismatch"); + expect(text_roundtrip.rows[0][1].type() == dataflow::DataType::FixedVector, + "proto-like should keep vector type"); + expect(text_roundtrip.rows[0][1].asFixedVector().size() == 3, + "proto-like vector length mismatch"); + + dataflow::BinaryRowBatchCodec batch_codec; + std::vector binary_payload; + batch_codec.serialize(table, &binary_payload); + auto binary_roundtrip = batch_codec.deserialize(binary_payload); + expect(binary_roundtrip.rows.size() == 3, "binary row batch row count mismatch"); + expect(binary_roundtrip.rows[1][1].type() == dataflow::DataType::FixedVector, + "binary row batch should keep vector type"); + expect(binary_roundtrip.rows[1][1].asFixedVector()[1] == 0.1f, + "binary row batch vector content mismatch"); + + auto& session = dataflow::DataflowSession::builder(); + auto df = session.createDataFrame(table); + session.createTempView("vec_src", df); + + auto cosine = session.vectorQuery("vec_src", "embedding", {1.0f, 0.0f, 0.0f}, 2, + dataflow::VectorDistanceMetric::Cosine) + 
.toTable(); + expect(cosine.rows.size() == 2, "cosine vector query top-k mismatch"); + expect(cosine.schema.fields[0] == "row_id", "cosine result schema mismatch"); + expect(cosine.rows[0][0].asInt64() == 0, "cosine nearest should be row 0"); + + auto l2 = session.vectorQuery("vec_src", "embedding", {0.0f, 1.0f, 0.0f}, 1, + dataflow::VectorDistanceMetric::L2) + .toTable(); + expect(l2.rows.size() == 1, "l2 vector query top-k mismatch"); + expect(l2.rows[0][0].asInt64() == 2, "l2 nearest should be row 2"); + + auto dot = session.vectorQuery("vec_src", "embedding", {1.0f, 0.0f, 0.0f}, 1, + dataflow::VectorDistanceMetric::Dot) + .toTable(); + expect(dot.rows[0][0].asInt64() == 0, "dot nearest should be row 0"); + + const auto explain = session.explainVectorQuery("vec_src", "embedding", {1.0f, 0.0f, 0.0f}, 2, + dataflow::VectorDistanceMetric::Cosine); + expect(explain.find("mode=exact-scan") != std::string::npos, "explain mode missing"); + expect(explain.find("metric=cosine") != std::string::npos, "explain metric missing"); + expect(explain.find("dimension=3") != std::string::npos, "explain dimension missing"); + expect(explain.find("top_k=2") != std::string::npos, "explain top_k missing"); + expect(explain.find("acceleration=flat-buffer+heap-topk") != std::string::npos, + "explain acceleration hint missing"); + + std::cout << "[test] vector runtime query and transport ok" << std::endl; + return 0; + } catch (const std::exception& ex) { + std::cerr << "[test] vector runtime query and transport failed: " << ex.what() << std::endl; + return 1; + } +} From c00135d7a5e0abbc630256460a8731dfee863d54 Mon Sep 17 00:00:00 2001 From: zuolingxuan Date: Mon, 30 Mar 2026 13:33:55 +0800 Subject: [PATCH 2/4] fix(python): use session field in vector bindings --- src/dataflow/python/python_module.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/dataflow/python/python_module.cc b/src/dataflow/python/python_module.cc index eec42f3..3a47cda 100644 --- 
a/src/dataflow/python/python_module.cc +++ b/src/dataflow/python/python_module.cc @@ -1177,8 +1177,8 @@ PyObject* sessionVectorSearch(PyVelariaSession* self, PyObject* args, PyObject* df::DataFrame result; { AllowThreads allow; - result = self->session_ptr->vectorQuery(table, vector_column, query, static_cast(top_k), - parseVectorMetric(metric)); + result = self->session->vectorQuery(table, vector_column, query, static_cast(top_k), + parseVectorMetric(metric)); } return wrapDataFrame(result); }); @@ -1201,9 +1201,9 @@ PyObject* sessionExplainVectorSearch(PyVelariaSession* self, PyObject* args, PyO std::string explain; { AllowThreads allow; - explain = self->session_ptr->explainVectorQuery(table, vector_column, query, - static_cast(top_k), - parseVectorMetric(metric)); + explain = self->session->explainVectorQuery(table, vector_column, query, + static_cast(top_k), + parseVectorMetric(metric)); } return PyUnicode_FromString(explain.c_str()); }); From c17732eda74f2236ff72f97d02f0b1b237a52f69 Mon Sep 17 00:00:00 2001 From: zuolingxuan Date: Mon, 30 Mar 2026 14:19:42 +0800 Subject: [PATCH 3/4] perf(runtime): split actor result data from control --- BUILD.bazel | 10 +- README-zh.md | 14 + README.md | 14 + src/dataflow/api/dataframe.cc | 97 +++--- src/dataflow/api/dataframe.h | 9 + .../examples/vector_search_benchmark.cc | 111 ++++++- src/dataflow/rpc/rpc_codec.cc | 15 +- src/dataflow/runner/actor_runtime.cc | 309 +++++++++++++++--- src/dataflow/runtime/vector_index.cc | 111 ++++--- src/dataflow/tests/vector_runtime_test.cc | 97 +++++- src/dataflow/transport/ipc_transport.cc | 8 +- 11 files changed, 643 insertions(+), 152 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index fff8027..6d5b51c 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -246,7 +246,10 @@ cc_binary( cc_binary( name = "vector_search_benchmark", srcs = ["src/dataflow/examples/vector_search_benchmark.cc"], - deps = [":dataflow_core"], + deps = [ + ":dataflow_actor_rpc_codec", + ":dataflow_core", + ], ) 
cc_test( @@ -288,5 +291,8 @@ cc_test( cc_test( name = "vector_runtime_test", srcs = ["src/dataflow/tests/vector_runtime_test.cc"], - deps = [":dataflow_core"], + deps = [ + ":dataflow_actor_rpc_codec", + ":dataflow_core", + ], ) diff --git a/README-zh.md b/README-zh.md index 47cf699..08152d5 100644 --- a/README-zh.md +++ b/README-zh.md @@ -264,6 +264,7 @@ runtime 传输层现已在 proto-like 与 binary row batch codec 中保留 `Fixe FixedVector 在内部 codec 里改为 raw float bit payload 编码,避免文本往返造成的精度损耗。 当前向量检索范围为本地 exact scan(`mode=exact-scan`)+ 固定维度 float 向量;v0.1 不包含 ANN 与分布式执行路径。 Arrow ingestion 已增加 `FixedSizeList` 的 native 快路径,可减少向量列的 Python 对象转换开销。 +同机 actor runtime 的结果回传现在采用“双帧”模型:控制消息继续走 `actor-rpc-v1`,结果表单独走 `table-bin-v1` 的 `DataBatch` 帧,并通过 `correlation_id` 关联;热路径不再把整张结果表塞进 actor JSON body。 ## 同机多进程实验路径 @@ -285,6 +286,8 @@ smoke: bazel run //:actor_rpc_smoke ``` +该 smoke 现会同时校验 actor 控制消息和关联的二进制 `DataBatch` 结果帧。 + 三进程本地运行: ```bash @@ -308,6 +311,17 @@ Dashboard: - `//:tpch_q1_style_benchmark` - `//:vector_search_benchmark` +向量 benchmark: + +```bash +bazel run //:vector_search_benchmark +``` + +会输出两类 JSON 行: + +- `vector-query`:cold query、warm query、warm explain 延迟 +- `vector-transport`:proto-like 与 `BinaryRowBatch` 的编解码耗时、payload 大小,以及 actor 控制帧开销 + 同机 observability regression: ```bash diff --git a/README.md b/README.md index bb41c4a..8c93c6f 100644 --- a/README.md +++ b/README.md @@ -243,6 +243,7 @@ Runtime-level vector transport now preserves `FixedVector` through proto-like an FixedVector serialization now uses raw float bit payload encoding in internal codecs to avoid text round-trip precision loss. Current vector search scope is local-only exact scan (`mode=exact-scan`) with fixed-dimension float vectors; no ANN/distributed path in v0.1. Arrow ingestion now includes a direct `FixedSizeList` fast path in the native bridge, reducing Python object conversion overhead on vector columns. 
+For same-host actor runtime results, the control message stays on `actor-rpc-v1`, while the result table is forwarded as a separate `table-bin-v1` `DataBatch` frame linked by `correlation_id`. The hot result path no longer puts row payloads inside the actor JSON body. ## Same-Host Multi-Process Experiment @@ -264,6 +265,8 @@ Smoke: bazel run //:actor_rpc_smoke ``` +The smoke target now verifies both the actor control message and the correlated binary `DataBatch` result frame. + Three-process local run: ```bash @@ -287,6 +290,17 @@ Useful local targets: - `//:tpch_q1_style_benchmark` - `//:vector_search_benchmark` +Vector benchmark: + +```bash +bazel run //:vector_search_benchmark +``` + +It emits JSON lines for: + +- `vector-query`: cold query, warm query, and warm explain latency +- `vector-transport`: proto-like vs `BinaryRowBatch` serialize/deserialize cost and payload size, plus actor control-frame overhead + Same-host observability regression: ```bash diff --git a/src/dataflow/api/dataframe.cc b/src/dataflow/api/dataframe.cc index d9aa757..d2d253a 100644 --- a/src/dataflow/api/dataframe.cc +++ b/src/dataflow/api/dataframe.cc @@ -22,6 +22,13 @@ bool gtePred(const Value& lhs, const Value& rhs) { return lhs > rhs || lhs == rh std::shared_ptr defaultExecutor() { return std::make_shared(); } +VectorSearchMetric toRuntimeMetric(VectorDistanceMetric metric) { + return metric == VectorDistanceMetric::L2 + ? VectorSearchMetric::L2 + : (metric == VectorDistanceMetric::Dot ? 
VectorSearchMetric::Dot + : VectorSearchMetric::Cosine); +} + std::string planKindName(PlanKind kind) { switch (kind) { case PlanKind::Source: @@ -179,6 +186,44 @@ const Table& DataFrame::materialize() const { return cached_table_; } +const CachedVectorColumn& DataFrame::vectorColumnCache(const std::string& vectorColumn) const { + const auto& source = materialize(); + const size_t vector_index = source.schema.indexOf(vectorColumn); + const auto it = vector_cache_.find(vector_index); + if (it != vector_cache_.end()) { + return it->second; + } + + CachedVectorColumn cache; + cache.row_ids.reserve(source.rows.size()); + std::vector> vectors; + vectors.reserve(source.rows.size()); + + std::size_t expected_dim = 0; + bool has_dimension = false; + for (size_t row_id = 0; row_id < source.rows.size(); ++row_id) { + if (vector_index >= source.rows[row_id].size()) continue; + const auto& cell = source.rows[row_id][vector_index]; + if (cell.isNull()) continue; + + std::vector vec = cell.type() == DataType::FixedVector + ? 
cell.asFixedVector() + : Value::parseFixedVector(cell.toString()); + if (!has_dimension) { + expected_dim = vec.size(); + has_dimension = true; + } else if (vec.size() != expected_dim) { + throw std::invalid_argument("fixed vector length mismatch in vector column cache"); + } + cache.row_ids.push_back(row_id); + vectors.push_back(std::move(vec)); + } + + cache.index = std::shared_ptr(makeExactScanVectorIndex(std::move(vectors)).release()); + const auto inserted = vector_cache_.emplace(vector_index, std::move(cache)); + return inserted.first->second; +} + std::string DataFrame::explain() const { std::ostringstream out; explainPlan(plan_, out, 0); @@ -278,39 +323,21 @@ DataFrame DataFrame::vectorQuery(const std::string& vectorColumn, if (queryVector.empty()) { throw std::invalid_argument("query vector cannot be empty"); } - const auto source = materialize(); - const size_t vector_index = source.schema.indexOf(vectorColumn); - std::vector> vectors; - vectors.reserve(source.rows.size()); - for (size_t i = 0; i < source.rows.size(); ++i) { - if (vector_index >= source.rows[i].size()) continue; - const auto& cell = source.rows[i][vector_index]; - std::vector vec; - if (cell.type() == DataType::FixedVector) { - vec = cell.asFixedVector(); - } else { - vec = Value::parseFixedVector(cell.toString()); - } - if (vec.size() != queryVector.size()) { - throw std::invalid_argument("fixed vector length mismatch in vectorQuery"); - } - vectors.push_back(std::move(vec)); + const auto& cache = vectorColumnCache(vectorColumn); + if (cache.index->dimension() != 0 && cache.index->dimension() != queryVector.size()) { + throw std::invalid_argument("fixed vector length mismatch in vectorQuery"); } VectorSearchOptions options; options.top_k = top_k; - options.metric = metric == VectorDistanceMetric::L2 - ? VectorSearchMetric::L2 - : (metric == VectorDistanceMetric::Dot ? 
VectorSearchMetric::Dot - : VectorSearchMetric::Cosine); - const auto index = makeExactScanVectorIndex(std::move(vectors)); - const auto scored = index->search(queryVector, options); + options.metric = toRuntimeMetric(metric); + const auto scored = cache.index->search(queryVector, options); const size_t take = scored.size(); Table out; out.schema = Schema({"row_id", "score"}); out.rows.reserve(take); for (size_t i = 0; i < take; ++i) { Row row; - row.emplace_back(static_cast(scored[i].row_id)); + row.emplace_back(static_cast(cache.row_ids.at(scored[i].row_id))); row.emplace_back(scored[i].score); out.rows.push_back(std::move(row)); } @@ -320,27 +347,15 @@ DataFrame DataFrame::vectorQuery(const std::string& vectorColumn, std::string DataFrame::explainVectorQuery(const std::string& vectorColumn, const std::vector& queryVector, size_t top_k, VectorDistanceMetric metric) const { - const auto source = materialize(); - const size_t vector_index = source.schema.indexOf(vectorColumn); - std::vector> vectors; - vectors.reserve(source.rows.size()); - for (const auto& row : source.rows) { - if (vector_index >= row.size()) continue; - vectors.push_back(row[vector_index].type() == DataType::FixedVector - ? row[vector_index].asFixedVector() - : Value::parseFixedVector(row[vector_index].toString())); - } + const auto& cache = vectorColumnCache(vectorColumn); VectorSearchOptions options; options.top_k = top_k; - options.metric = metric == VectorDistanceMetric::L2 - ? VectorSearchMetric::L2 - : (metric == VectorDistanceMetric::Dot ? 
VectorSearchMetric::Dot - : VectorSearchMetric::Cosine); - const auto index = makeExactScanVectorIndex(std::move(vectors)); - if (!queryVector.empty() && index->dimension() != 0 && queryVector.size() != index->dimension()) { + options.metric = toRuntimeMetric(metric); + if (!queryVector.empty() && cache.index->dimension() != 0 && + queryVector.size() != cache.index->dimension()) { throw std::invalid_argument("query vector dimension mismatch in explainVectorQuery"); } - return index->explain(options); + return cache.index->explain(options); } GroupedDataFrame DataFrame::groupBy(const std::vector& keys) const { diff --git a/src/dataflow/api/dataframe.h b/src/dataflow/api/dataframe.h index 937212b..f31615d 100644 --- a/src/dataflow/api/dataframe.h +++ b/src/dataflow/api/dataframe.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include "src/dataflow/core/table.h" @@ -13,8 +14,14 @@ namespace dataflow { class DataFrame; +class VectorIndex; enum class VectorDistanceMetric { Cosine, Dot, L2 }; +struct CachedVectorColumn { + std::shared_ptr index; + std::vector row_ids; +}; + class GroupedDataFrame { public: GroupedDataFrame(PlanNodePtr plan, std::vector keys, std::shared_ptr exec) @@ -72,11 +79,13 @@ class DataFrame { private: const Table& materialize() const; + const CachedVectorColumn& vectorColumnCache(const std::string& vectorColumn) const; PlanNodePtr plan_; std::shared_ptr executor_; mutable bool cached_ = false; mutable Table cached_table_; + mutable std::unordered_map vector_cache_; }; } // namespace dataflow diff --git a/src/dataflow/examples/vector_search_benchmark.cc b/src/dataflow/examples/vector_search_benchmark.cc index 03bb782..47a508f 100644 --- a/src/dataflow/examples/vector_search_benchmark.cc +++ b/src/dataflow/examples/vector_search_benchmark.cc @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -7,6 +8,9 @@ #include #include "src/dataflow/api/session.h" +#include "src/dataflow/rpc/actor_rpc_codec.h" +#include 
"src/dataflow/serial/serializer.h" +#include "src/dataflow/stream/binary_row_batch.h" namespace { @@ -36,6 +40,11 @@ std::vector makeQuery(std::size_t dim, uint32_t seed) { return q; } +long long microsBetween(std::chrono::steady_clock::time_point begin, + std::chrono::steady_clock::time_point end) { + return std::chrono::duration_cast(end - begin).count(); +} + void runCase(std::size_t rows, std::size_t dim, dataflow::VectorDistanceMetric metric, const std::string& metric_name) { auto table = makeSyntheticTable(rows, dim, static_cast(rows + dim)); @@ -45,21 +54,114 @@ void runCase(std::size_t rows, std::size_t dim, dataflow::VectorDistanceMetric m const std::string view_name = "vec_bench_" + std::to_string(rows) + "_" + std::to_string(dim); session.createTempView(view_name, session.createDataFrame(table)); - const auto begin = std::chrono::steady_clock::now(); + const auto cold_begin = std::chrono::steady_clock::now(); auto out = session.vectorQuery(view_name, "embedding", query, 10, metric).toTable(); - const auto elapsed = std::chrono::duration_cast( - std::chrono::steady_clock::now() - begin); + const auto cold_end = std::chrono::steady_clock::now(); + + constexpr std::size_t kWarmIterations = 5; + long long warm_query_us = 0; + for (std::size_t i = 0; i < kWarmIterations; ++i) { + const auto begin = std::chrono::steady_clock::now(); + auto warm = session.vectorQuery(view_name, "embedding", query, 10, metric).toTable(); + const auto end = std::chrono::steady_clock::now(); + warm_query_us += microsBetween(begin, end); + if (warm.rows.size() != out.rows.size()) { + std::cerr << "[vector-benchmark] warm query cardinality mismatch" << std::endl; + std::exit(1); + } + } + + long long explain_us = 0; + for (std::size_t i = 0; i < kWarmIterations; ++i) { + const auto begin = std::chrono::steady_clock::now(); + const auto explain = session.explainVectorQuery(view_name, "embedding", query, 10, metric); + const auto end = std::chrono::steady_clock::now(); + explain_us += 
microsBetween(begin, end); + if (explain.find("mode=exact-scan") == std::string::npos) { + std::cerr << "[vector-benchmark] explain output mismatch" << std::endl; + std::exit(1); + } + } std::cout << "{" + << "\"bench\":\"vector-query\"," << "\"rows\":" << rows << "," << "\"dimension\":" << dim << "," << "\"top_k\":10," << "\"metric\":\"" << metric_name << "\"," - << "\"elapsed_ms\":" << elapsed.count() << "," + << "\"cold_query_us\":" << microsBetween(cold_begin, cold_end) << "," + << "\"warm_query_avg_us\":" << (warm_query_us / static_cast(kWarmIterations)) << "," + << "\"warm_explain_avg_us\":" << (explain_us / static_cast(kWarmIterations)) << "," << "\"result_rows\":" << out.rows.size() << "" << "}" << std::endl; } +void runTransportCase(std::size_t rows, std::size_t dim) { + const auto table = makeSyntheticTable(rows, dim, static_cast(rows * 17 + dim)); + + dataflow::ProtoLikeSerializer proto_codec; + const auto proto_serialize_begin = std::chrono::steady_clock::now(); + const auto proto_payload = proto_codec.serialize(table); + const auto proto_serialize_end = std::chrono::steady_clock::now(); + const auto proto_deserialize_begin = std::chrono::steady_clock::now(); + const auto proto_roundtrip = proto_codec.deserialize(proto_payload); + const auto proto_deserialize_end = std::chrono::steady_clock::now(); + + dataflow::BinaryRowBatchCodec batch_codec; + std::vector binary_payload; + const auto binary_serialize_begin = std::chrono::steady_clock::now(); + batch_codec.serialize(table, &binary_payload); + const auto binary_serialize_end = std::chrono::steady_clock::now(); + const auto binary_deserialize_begin = std::chrono::steady_clock::now(); + const auto binary_roundtrip = batch_codec.deserialize(binary_payload); + const auto binary_deserialize_end = std::chrono::steady_clock::now(); + + dataflow::ActorRpcMessage actor_payload; + actor_payload.action = dataflow::ActorRpcAction::Result; + actor_payload.job_id = "bench-job"; + actor_payload.chain_id = 
"bench-chain"; + actor_payload.task_id = "bench-task"; + actor_payload.node_id = "bench-node"; + actor_payload.ok = true; + actor_payload.state = "FINISHED"; + actor_payload.summary = "vector transport benchmark"; + actor_payload.result_location = "inline://bench-job/bench-task"; + actor_payload.payload = actor_payload.summary; + + const auto actor_encode_begin = std::chrono::steady_clock::now(); + const auto actor_wire = encodeActorRpcMessage(actor_payload); + const auto actor_encode_end = std::chrono::steady_clock::now(); + const auto actor_decode_begin = std::chrono::steady_clock::now(); + dataflow::ActorRpcMessage actor_roundtrip; + const bool actor_decode_ok = decodeActorRpcMessage(actor_wire, &actor_roundtrip); + const auto actor_decode_end = std::chrono::steady_clock::now(); + + if (proto_roundtrip.rowCount() != table.rowCount() || binary_roundtrip.rowCount() != table.rowCount()) { + std::cerr << "[vector-benchmark] transport roundtrip row count mismatch" << std::endl; + std::exit(1); + } + if (!actor_decode_ok || actor_roundtrip.summary != actor_payload.summary || + actor_roundtrip.result_location != actor_payload.result_location) { + std::cerr << "[vector-benchmark] actor rpc control roundtrip mismatch" << std::endl; + std::exit(1); + } + + std::cout << "{" + << "\"bench\":\"vector-transport\"," + << "\"rows\":" << rows << "," + << "\"dimension\":" << dim << "," + << "\"proto_serialize_us\":" << microsBetween(proto_serialize_begin, proto_serialize_end) << "," + << "\"proto_deserialize_us\":" << microsBetween(proto_deserialize_begin, proto_deserialize_end) << "," + << "\"proto_payload_bytes\":" << proto_payload.size() << "," + << "\"binary_serialize_us\":" << microsBetween(binary_serialize_begin, binary_serialize_end) << "," + << "\"binary_deserialize_us\":" << microsBetween(binary_deserialize_begin, binary_deserialize_end) << "," + << "\"binary_payload_bytes\":" << binary_payload.size() << "," + << "\"actor_rpc_encode_us\":" << 
microsBetween(actor_encode_begin, actor_encode_end) << "," + << "\"actor_rpc_decode_us\":" << microsBetween(actor_decode_begin, actor_decode_end) << "," + << "\"actor_rpc_control_bytes\":" << actor_wire.size() + << "}" << std::endl; +} + } // namespace int main() { @@ -69,6 +171,7 @@ int main() { runCase(rows, dim, dataflow::VectorDistanceMetric::Cosine, "cosine"); runCase(rows, dim, dataflow::VectorDistanceMetric::Dot, "dot"); runCase(rows, dim, dataflow::VectorDistanceMetric::L2, "l2"); + runTransportCase(rows, dim); } } return 0; diff --git a/src/dataflow/rpc/rpc_codec.cc b/src/dataflow/rpc/rpc_codec.cc index 7e51a52..0f7c26a 100644 --- a/src/dataflow/rpc/rpc_codec.cc +++ b/src/dataflow/rpc/rpc_codec.cc @@ -5,7 +5,7 @@ #include #include -#include "src/dataflow/serial/serializer.h" +#include "src/dataflow/stream/binary_row_batch.h" namespace dataflow { @@ -180,9 +180,10 @@ class TableBatchRpcSerializer : public IRpcSerializer { if (envelope.type != RpcMessageType::DataBatch) return {}; const auto* batch = static_cast(message); if (batch == nullptr) return {}; - const ProtoLikeSerializer serializer; - const std::string payload = serializer.serialize(batch->table); - return std::vector(payload.begin(), payload.end()); + BinaryRowBatchCodec serializer; + std::vector payload; + serializer.serialize(batch->table, &payload); + return payload; } bool deserialize(const RpcEnvelope& envelope, @@ -191,10 +192,8 @@ class TableBatchRpcSerializer : public IRpcSerializer { if (envelope.type != RpcMessageType::DataBatch) return false; auto* batch = static_cast(out_message); if (batch == nullptr) return false; - const ProtoLikeSerializer serializer; - const std::string payload_str(reinterpret_cast(payload.data()), - payload.size()); - batch->table = serializer.deserialize(payload_str); + BinaryRowBatchCodec serializer; + batch->table = serializer.deserialize(payload); return true; } }; diff --git a/src/dataflow/runner/actor_runtime.cc b/src/dataflow/runner/actor_runtime.cc index 
a661430..7c51ebc 100644 --- a/src/dataflow/runner/actor_runtime.cc +++ b/src/dataflow/runner/actor_runtime.cc @@ -28,7 +28,7 @@ #include "src/dataflow/rpc/rpc_codec.h" #include "src/dataflow/runtime/job_master.h" #include "src/dataflow/runtime/observability.h" -#include "src/dataflow/serial/serializer.h" +#include "src/dataflow/stream/binary_row_batch.h" #include "src/dataflow/sql/sql_parser.h" #include "src/dataflow/transport/ipc_transport.h" @@ -516,6 +516,7 @@ std::string buildJobsJson(const std::unordered_map& } RpcFrame makeFrameFromMessage(uint64_t msg_id, + uint64_t correlation_id, const std::string& source, const std::string& target, const ActorRpcMessage& message) { @@ -523,7 +524,7 @@ RpcFrame makeFrameFromMessage(uint64_t msg_id, frame.header.protocol_version = 1; frame.header.type = RpcMessageType::Control; frame.header.message_id = msg_id; - frame.header.correlation_id = 0; + frame.header.correlation_id = correlation_id; frame.header.codec_id = "actor-rpc-v1"; frame.header.source = source; frame.header.target = target; @@ -531,6 +532,41 @@ RpcFrame makeFrameFromMessage(uint64_t msg_id, return frame; } +RpcFrame makeFrameFromMessage(uint64_t msg_id, + const std::string& source, + const std::string& target, + const ActorRpcMessage& message) { + return makeFrameFromMessage(msg_id, 0, source, target, message); +} + +RpcFrame makeDataBatchFrame(uint64_t msg_id, + uint64_t correlation_id, + const std::string& source, + const std::string& target, + const Table& table) { + RpcFrame frame; + frame.header.protocol_version = 1; + frame.header.type = RpcMessageType::DataBatch; + frame.header.message_id = msg_id; + frame.header.correlation_id = correlation_id; + frame.header.codec_id = "table-bin-v1"; + frame.header.source = source; + frame.header.target = target; + BinaryRowBatchCodec codec; + codec.serialize(table, &frame.payload); + return frame; +} + +bool decodeDataBatchFrame(const RpcFrame& frame, Table* out) { + if (out == nullptr || frame.header.type != 
RpcMessageType::DataBatch || + frame.header.codec_id != "table-bin-v1") { + return false; + } + BinaryRowBatchCodec codec; + *out = codec.deserialize(frame.payload); + return true; +} + void removeWorker(std::vector* workers, int fd) { workers->erase(std::remove(workers->begin(), workers->end(), fd), workers->end()); } @@ -612,6 +648,7 @@ int runActorScheduler(const ActorRuntimeConfig& config) { std::unordered_set idle_workers; std::unordered_map job_to_client; std::unordered_map task_to_worker; + std::unordered_map pending_worker_result_msgs; std::unordered_map job_snapshots; std::vector dashboard_conns; uint64_t next_message_id = 1; @@ -929,6 +966,7 @@ int runActorScheduler(const ActorRuntimeConfig& config) { conn_role.erase(fd); conn_node.erase(fd); worker_node_by_fd.erase(fd); + pending_worker_result_msgs.erase(fd); if (role == "worker" && isLocalWorkerNodeId(worker_node, config.node_id)) { local_worker_nodes.erase(worker_node); } @@ -1114,6 +1152,89 @@ int runActorScheduler(const ActorRuntimeConfig& config) { cleanup(fd); continue; } + if (frame.header.type == RpcMessageType::DataBatch) { + auto pending_it = pending_worker_result_msgs.find(fd); + if (pending_it == pending_worker_result_msgs.end()) { + emitRpcEvent("actor_scheduler", "unexpected_data_batch", config.node_id, + "RPC_UNEXPECTED_DATA_BATCH", "", + {observability::field("fd", fd), + observability::field("codec_id", frame.header.codec_id)}); + ++i; + continue; + } + Table result_table; + if (!decodeDataBatchFrame(frame, &result_table)) { + emitRpcEvent("actor_scheduler", "data_batch_decode_error", config.node_id, + "RPC_DATA_BATCH_DECODE_ERROR", pending_it->second.job_id, + {observability::field("fd", fd), + observability::field("task_id", pending_it->second.task_id)}); + pending_worker_result_msgs.erase(pending_it); + ++i; + continue; + } + + const ActorRpcMessage& result_msg = pending_it->second; + RemoteTaskCompletion completion; + completion.job_id = result_msg.job_id; + completion.chain_id = 
result_msg.chain_id; + completion.task_id = result_msg.task_id; + completion.attempt = static_cast(result_msg.attempt); + completion.ok = true; + completion.payload.clear(); + completion.error_message = result_msg.reason; + completion.output_rows = static_cast(result_msg.output_rows); + completion.worker_id = result_msg.node_id; + completion.result_table = result_table; + completion.has_result_table = true; + if (completion.output_rows == 0) { + completion.output_rows = completion.result_table.rowCount(); + } + JobMaster::instance().completeRemoteTask(completion); + + const auto worker_it = task_to_worker.find(result_msg.task_id); + if (worker_it != task_to_worker.end()) { + idle_workers.insert(worker_it->second); + auto snap_it = job_snapshots.find(result_msg.job_id); + if (snap_it != job_snapshots.end()) { + snap_it->second.worker_node = result_msg.node_id; + snap_it->second.result_payload = + result_msg.summary.empty() ? result_msg.reason : result_msg.summary; + snap_it->second.state = "FINISHED"; + snap_it->second.status_code = "JOB_FINISHED"; + snap_it->second.chain.state = "FINISHED"; + snap_it->second.chain.status_code = "CHAIN_FINISHED"; + snap_it->second.task.state = "FINISHED"; + snap_it->second.task.status_code = "TASK_FINISHED"; + snap_it->second.task.worker_id = result_msg.node_id; + snap_it->second.task.fail_reason.clear(); + snap_it->second.task.worker_finished_at = std::chrono::steady_clock::now(); + snap_it->second.task.result_returned_at = snap_it->second.task.worker_finished_at; + emitSnapshotEvent(snap_it->second); + } + task_to_worker.erase(worker_it); + } + + const auto client_it = job_to_client.find(result_msg.job_id); + if (client_it != job_to_client.end()) { + const uint64_t control_id = next_message_id++; + ActorRpcMessage forward = result_msg; + forward.payload.clear(); + if (sendFrameOverSocket(client_it->second, codec, + makeFrameFromMessage(control_id, 0, config.node_id, + std::to_string(client_it->second), + forward))) { + 
sendFrameOverSocket(client_it->second, codec, + makeDataBatchFrame(next_message_id++, control_id, config.node_id, + std::to_string(client_it->second), + completion.result_table)); + } + job_to_client.erase(client_it); + } + pending_worker_result_msgs.erase(pending_it); + dispatchPending(); + ++i; + continue; + } ActorRpcMessage msg; if (!decodeActorRpcMessage(frame.payload, &msg)) { emitRpcEvent("actor_scheduler", "decode_error", config.node_id, "RPC_DECODE_ERROR", "", @@ -1199,6 +1320,12 @@ int runActorScheduler(const ActorRuntimeConfig& config) { {observability::field("client_fd", fd)}); dispatchPending(); } else if (msg.action == ActorRpcAction::Result) { + if (msg.ok && !msg.result_location.empty()) { + pending_worker_result_msgs[fd] = msg; + ++i; + continue; + } + RemoteTaskCompletion completion; completion.job_id = msg.job_id; completion.chain_id = msg.chain_id; @@ -1209,37 +1336,29 @@ int runActorScheduler(const ActorRuntimeConfig& config) { completion.error_message = msg.reason; completion.output_rows = static_cast(msg.output_rows); completion.worker_id = msg.node_id; - if (msg.ok) { - const ProtoLikeSerializer serializer; - completion.result_table = serializer.deserialize(msg.payload); - completion.has_result_table = true; - if (completion.output_rows == 0) { - completion.output_rows = completion.result_table.rowCount(); - } - } JobMaster::instance().completeRemoteTask(completion); const auto worker_it = task_to_worker.find(msg.task_id); if (worker_it != task_to_worker.end()) { idle_workers.insert(worker_it->second); auto snap_it = job_snapshots.find(msg.job_id); - if (snap_it != job_snapshots.end()) { - snap_it->second.worker_node = msg.node_id; - snap_it->second.result_payload = msg.summary.empty() ? msg.reason : msg.summary; - snap_it->second.state = msg.ok ? "FINISHED" : "FAILED"; + if (snap_it != job_snapshots.end()) { + snap_it->second.worker_node = msg.node_id; + snap_it->second.result_payload = msg.summary.empty() ? 
msg.reason : msg.summary; + snap_it->second.state = msg.ok ? "FINISHED" : "FAILED"; snap_it->second.status_code = msg.ok ? "JOB_FINISHED" : "JOB_FAILED"; snap_it->second.chain.state = msg.ok ? "FINISHED" : "FAILED"; snap_it->second.chain.status_code = msg.ok ? "CHAIN_FINISHED" : "CHAIN_FAILED"; - snap_it->second.task.state = msg.ok ? "FINISHED" : "FAILED"; - snap_it->second.task.status_code = msg.ok ? "TASK_FINISHED" : "TASK_FAILED"; - snap_it->second.task.worker_id = msg.node_id; - snap_it->second.task.fail_reason = msg.ok ? "" : msg.reason; - snap_it->second.task.worker_finished_at = std::chrono::steady_clock::now(); - snap_it->second.task.result_returned_at = snap_it->second.task.worker_finished_at; - emitSnapshotEvent(snap_it->second); + snap_it->second.task.state = msg.ok ? "FINISHED" : "FAILED"; + snap_it->second.task.status_code = msg.ok ? "TASK_FINISHED" : "TASK_FAILED"; + snap_it->second.task.worker_id = msg.node_id; + snap_it->second.task.fail_reason = msg.ok ? "" : msg.reason; + snap_it->second.task.worker_finished_at = std::chrono::steady_clock::now(); + snap_it->second.task.result_returned_at = snap_it->second.task.worker_finished_at; + emitSnapshotEvent(snap_it->second); + } + task_to_worker.erase(worker_it); } - task_to_worker.erase(worker_it); - } const auto client_it = job_to_client.find(msg.job_id); if (client_it == job_to_client.end()) { @@ -1250,7 +1369,7 @@ int runActorScheduler(const ActorRuntimeConfig& config) { ActorRpcMessage forward = msg; forward.payload = msg.ok ? msg.summary : (msg.summary.empty() ? 
msg.reason : msg.summary); sendFrameOverSocket(client_it->second, codec, - makeFrameFromMessage(next_message_id++, config.node_id, + makeFrameFromMessage(next_message_id++, 0, config.node_id, std::to_string(client_it->second), forward)); job_to_client.erase(client_it); dispatchPending(); @@ -1406,6 +1525,11 @@ int runActorWorker(const ActorRuntimeConfig& config) { }); ActorRpcMessage result; + Table output_table; + bool has_output_table = false; + std::size_t result_batch_bytes = 0; + RpcFrame result_batch_frame; + bool has_result_batch_frame = false; result.action = ActorRpcAction::Result; result.job_id = msg.job_id; result.chain_id = msg.chain_id; @@ -1415,24 +1539,18 @@ int runActorWorker(const ActorRuntimeConfig& config) { try { const auto plan = deserializePlan(msg.payload); const LocalExecutor executor; - const Table output = executor.execute(plan); + output_table = executor.execute(plan); + has_output_table = true; emitRpcEvent("actor_worker", "task_executed", config.node_id, "TASK_EXECUTED", msg.job_id, {observability::field("task_id", msg.task_id), observability::field("chain_id", msg.chain_id), - observability::field("output_rows", output.rowCount())}); - const ProtoLikeSerializer serializer; - const std::string serialized_output = serializer.serialize(output); + observability::field("output_rows", output_table.rowCount())}); result.ok = true; result.state = "FINISHED"; - result.output_rows = output.rowCount(); + result.output_rows = output_table.rowCount(); result.result_location = "inline://" + msg.job_id + "/" + msg.task_id; - result.summary = summarizeTable(output); - result.payload = serialized_output; - emitRpcEvent("actor_worker", "result_serialized", config.node_id, "TASK_RESULT_SERIALIZED", msg.job_id, - {observability::field("task_id", msg.task_id), - observability::field("chain_id", msg.chain_id), - observability::field("payload_bytes", serialized_output.size()), - observability::field("summary", result.summary)}); + result.summary = 
summarizeTable(output_table); + result.payload.clear(); } catch (const std::exception& e) { result.ok = false; result.state = "FAILED"; @@ -1442,9 +1560,26 @@ int runActorWorker(const ActorRuntimeConfig& config) { stop_heartbeat.store(true); if (heartbeat_thread.joinable()) heartbeat_thread.join(); - sendFrameOverSocket(fd, codec, - makeFrameFromMessage(next_message_id.fetch_add(1), config.node_id, - "scheduler", result)); + const uint64_t control_id = next_message_id.fetch_add(1); + if (result.ok && has_output_table) { + result_batch_frame = + makeDataBatchFrame(next_message_id.fetch_add(1), control_id, config.node_id, + "scheduler", output_table); + result_batch_bytes = result_batch_frame.payload.size(); + has_result_batch_frame = true; + } + RpcFrame control_frame = + makeFrameFromMessage(control_id, 0, config.node_id, "scheduler", result); + emitRpcEvent("actor_worker", "result_serialized", config.node_id, "TASK_RESULT_SERIALIZED", msg.job_id, + {observability::field("task_id", msg.task_id), + observability::field("chain_id", msg.chain_id), + observability::field("control_payload_bytes", control_frame.payload.size()), + observability::field("result_batch_bytes", result_batch_bytes), + observability::field("summary", result.summary)}); + sendFrameOverSocket(fd, codec, std::move(control_frame)); + if (has_result_batch_frame) { + sendFrameOverSocket(fd, codec, std::move(result_batch_frame)); + } emitRpcEvent("actor_worker", "job_completed", config.node_id, result.ok ? 
"TASK_FINISHED" : "TASK_FAILED", msg.job_id, {observability::field("task_id", msg.task_id), @@ -1519,6 +1654,9 @@ int runActorClient(const ActorRuntimeConfig& config, const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(10); bool awaiting_result = true; + bool awaiting_result_table = false; + uint64_t expected_result_correlation = 0; + ActorRpcMessage pending_result; while (awaiting_result) { const auto now = std::chrono::steady_clock::now(); if (now >= deadline) { @@ -1537,6 +1675,34 @@ int runActorClient(const ActorRuntimeConfig& config, emitRpcEvent("actor_client", "connection_closed", config.node_id, "RPC_CONNECTION_CLOSED", submit.job_id); return 1; } + if (frame.header.type == RpcMessageType::DataBatch) { + if (!awaiting_result_table || frame.header.correlation_id != expected_result_correlation) { + continue; + } + Table result_table; + if (!decodeDataBatchFrame(frame, &result_table)) { + emitRpcEvent("actor_client", "result_decode_failed", config.node_id, + "JOB_RESULT_DECODE_FAILED", pending_result.job_id); + std::cerr << "[client] job result decode failed\n"; + return 1; + } + const std::string summary = summarizeTable(result_table); + emitRpcEvent("actor_client", "job_result", config.node_id, "JOB_RESULT_RECEIVED", + pending_result.job_id, + {observability::field("payload", summary), + observability::field("task_id", pending_result.task_id), + observability::field("chain_id", pending_result.chain_id), + observability::field("result_location", pending_result.result_location), + observability::field("summary", pending_result.summary)}); + std::cout << "[client] job result: " << summary; + if (!pending_result.result_location.empty()) { + std::cout << " @ " << pending_result.result_location; + } + std::cout << "\n"; + awaiting_result_table = false; + awaiting_result = false; + continue; + } ActorRpcMessage msg; if (!decodeActorRpcMessage(frame.payload, &msg)) { continue; @@ -1551,19 +1717,25 @@ int runActorClient(const 
ActorRuntimeConfig& config, {observability::field("reason", msg.reason)}); std::cout << "[client] job accepted: " << msg.job_id << "\n"; } else if (msg.action == ActorRpcAction::Result) { - emitRpcEvent("actor_client", "job_result", config.node_id, - msg.ok ? "JOB_RESULT_RECEIVED" : "JOB_RESULT_FAILED", msg.job_id, - {observability::field("payload", msg.payload), - observability::field("task_id", msg.task_id), - observability::field("chain_id", msg.chain_id), - observability::field("result_location", msg.result_location), - observability::field("summary", msg.summary)}); if (!msg.ok) { std::cerr << "[client] job failed: " << (msg.payload.empty() ? msg.reason : msg.payload) << "\n"; return 1; } - std::cout << "[client] job result: " << msg.payload; + if (!msg.result_location.empty()) { + pending_result = msg; + awaiting_result_table = true; + expected_result_correlation = frame.header.message_id; + continue; + } + emitRpcEvent("actor_client", "job_result", config.node_id, "JOB_RESULT_RECEIVED", msg.job_id, + {observability::field("payload", msg.summary.empty() ? msg.payload : msg.summary), + observability::field("task_id", msg.task_id), + observability::field("chain_id", msg.chain_id), + observability::field("result_location", msg.result_location), + observability::field("summary", msg.summary)}); + std::cout << "[client] job result: " + << (msg.summary.empty() ? 
msg.payload : msg.summary); if (!msg.result_location.empty()) { std::cout << " @ " << msg.result_location; } @@ -1606,7 +1778,48 @@ int runActorSmoke() { std::cerr << "[smoke] codec mismatch\n"; return 1; } - std::cout << "[smoke] actor rpc codec roundtrip ok\n"; + RpcFrame control_frame = makeFrameFromMessage(17, "smoke-worker", "smoke-client", origin); + Table result_table; + result_table.schema = Schema({"token", "score"}); + result_table.rows = { + {Value("smoke"), Value(42.0)}, + {Value("payload"), Value(7.0)}, + }; + RpcFrame batch_frame = + makeDataBatchFrame(18, control_frame.header.message_id, "smoke-worker", "smoke-client", + result_table); + + LengthPrefixedFrameCodec frame_codec; + std::size_t consumed = 0; + RpcFrame decoded_control; + const auto control_wire = frame_codec.encode(control_frame); + if (!frame_codec.decode(control_wire, &decoded_control, &consumed)) { + std::cerr << "[smoke] control frame roundtrip failed\n"; + return 1; + } + ActorRpcMessage framed_copy; + if (!decodeActorRpcMessage(decoded_control.payload, &framed_copy) || + framed_copy.summary != origin.summary) { + std::cerr << "[smoke] control payload decode mismatch\n"; + return 1; + } + + RpcFrame decoded_batch; + const auto batch_wire = frame_codec.encode(batch_frame); + if (!frame_codec.decode(batch_wire, &decoded_batch, &consumed)) { + std::cerr << "[smoke] data batch frame roundtrip failed\n"; + return 1; + } + if (decoded_batch.header.correlation_id != decoded_control.header.message_id) { + std::cerr << "[smoke] data batch correlation mismatch\n"; + return 1; + } + Table batch_copy; + if (!decodeDataBatchFrame(decoded_batch, &batch_copy) || batch_copy.rowCount() != result_table.rowCount()) { + std::cerr << "[smoke] data batch decode mismatch\n"; + return 1; + } + std::cout << "[smoke] actor rpc control/data-batch roundtrip ok\n"; return 0; } diff --git a/src/dataflow/runtime/vector_index.cc b/src/dataflow/runtime/vector_index.cc index f5c8be1..d23a03d 100644 --- 
a/src/dataflow/runtime/vector_index.cc +++ b/src/dataflow/runtime/vector_index.cc @@ -55,63 +55,33 @@ class ExactScanVectorIndex : public VectorIndex { } const std::size_t k = std::max(1, options.top_k); - auto cmp_desc = [](const VectorSearchResult& lhs, const VectorSearchResult& rhs) { + auto cmp_max = [](const VectorSearchResult& lhs, const VectorSearchResult& rhs) { return lhs.score < rhs.score; }; - auto cmp_asc = [](const VectorSearchResult& lhs, const VectorSearchResult& rhs) { + auto cmp_min = [](const VectorSearchResult& lhs, const VectorSearchResult& rhs) { return lhs.score > rhs.score; }; - std::priority_queue, decltype(cmp_desc)> - max_heap(cmp_desc); - std::priority_queue, decltype(cmp_asc)> - min_heap(cmp_asc); - - double query_norm = 0.0; - for (float v : query) query_norm += static_cast(v) * static_cast(v); - query_norm = std::sqrt(query_norm); - - for (std::size_t row = 0; row < row_count_; ++row) { - const float* base = data_.data() + row * dimension_; - double score = 0.0; - if (options.metric == VectorSearchMetric::L2) { - double sum = 0.0; - for (std::size_t col = 0; col < dimension_; ++col) { - const double diff = static_cast(base[col]) - static_cast(query[col]); - sum += diff * diff; - } - score = std::sqrt(sum); - } else { + std::vector scored; + scored.reserve(std::min(k, row_count_)); + if (options.metric == VectorSearchMetric::Dot) { + std::vector heap_storage; + heap_storage.reserve(k); + std::priority_queue, decltype(cmp_min)> + min_heap(cmp_min, std::move(heap_storage)); + for (std::size_t row = 0; row < row_count_; ++row) { + const float* base = data_.data() + row * dimension_; double dot = 0.0; for (std::size_t col = 0; col < dimension_; ++col) { dot += static_cast(base[col]) * static_cast(query[col]); } - if (options.metric == VectorSearchMetric::Dot) { - score = dot; - } else { - const double denom = norms_[row] * query_norm; - score = denom == 0.0 ? 
1.0 : (1.0 - std::max(-1.0, std::min(1.0, dot / denom))); - } - } - VectorSearchResult current{row, score}; - if (options.metric == VectorSearchMetric::Dot) { + VectorSearchResult current{row, dot}; if (min_heap.size() < k) { min_heap.push(current); } else if (current.score > min_heap.top().score) { min_heap.pop(); min_heap.push(current); } - } else { - if (max_heap.size() < k) { - max_heap.push(current); - } else if (current.score < max_heap.top().score) { - max_heap.pop(); - max_heap.push(current); - } } - } - - std::vector scored; - if (options.metric == VectorSearchMetric::Dot) { while (!min_heap.empty()) { scored.push_back(min_heap.top()); min_heap.pop(); @@ -120,7 +90,29 @@ class ExactScanVectorIndex : public VectorIndex { [](const VectorSearchResult& lhs, const VectorSearchResult& rhs) { return lhs.score > rhs.score; }); - } else { + return scored; + } + + std::vector heap_storage; + heap_storage.reserve(k); + std::priority_queue, decltype(cmp_max)> + max_heap(cmp_max, std::move(heap_storage)); + if (options.metric == VectorSearchMetric::L2) { + for (std::size_t row = 0; row < row_count_; ++row) { + const float* base = data_.data() + row * dimension_; + double squared_l2 = 0.0; + for (std::size_t col = 0; col < dimension_; ++col) { + const double diff = static_cast(base[col]) - static_cast(query[col]); + squared_l2 += diff * diff; + } + VectorSearchResult current{row, squared_l2}; + if (max_heap.size() < k) { + max_heap.push(current); + } else if (current.score < max_heap.top().score) { + max_heap.pop(); + max_heap.push(current); + } + } while (!max_heap.empty()) { scored.push_back(max_heap.top()); max_heap.pop(); @@ -129,7 +121,40 @@ class ExactScanVectorIndex : public VectorIndex { [](const VectorSearchResult& lhs, const VectorSearchResult& rhs) { return lhs.score < rhs.score; }); + for (auto& item : scored) { + item.score = std::sqrt(item.score); + } + return scored; + } + + double query_norm = 0.0; + for (float v : query) query_norm += static_cast(v) * 
static_cast(v); + query_norm = std::sqrt(query_norm); + for (std::size_t row = 0; row < row_count_; ++row) { + const float* base = data_.data() + row * dimension_; + double dot = 0.0; + for (std::size_t col = 0; col < dimension_; ++col) { + dot += static_cast(base[col]) * static_cast(query[col]); + } + const double denom = norms_[row] * query_norm; + const double cosine_distance = + denom == 0.0 ? 1.0 : (1.0 - std::max(-1.0, std::min(1.0, dot / denom))); + VectorSearchResult current{row, cosine_distance}; + if (max_heap.size() < k) { + max_heap.push(current); + } else if (current.score < max_heap.top().score) { + max_heap.pop(); + max_heap.push(current); + } + } + while (!max_heap.empty()) { + scored.push_back(max_heap.top()); + max_heap.pop(); } + std::sort(scored.begin(), scored.end(), + [](const VectorSearchResult& lhs, const VectorSearchResult& rhs) { + return lhs.score < rhs.score; + }); return scored; } diff --git a/src/dataflow/tests/vector_runtime_test.cc b/src/dataflow/tests/vector_runtime_test.cc index 14c1584..666d39b 100644 --- a/src/dataflow/tests/vector_runtime_test.cc +++ b/src/dataflow/tests/vector_runtime_test.cc @@ -1,4 +1,3 @@ -#include #include #include #include @@ -6,6 +5,8 @@ #include #include "src/dataflow/api/session.h" +#include "src/dataflow/rpc/actor_rpc_codec.h" +#include "src/dataflow/rpc/rpc_codec.h" #include "src/dataflow/serial/serializer.h" #include "src/dataflow/stream/binary_row_batch.h" @@ -15,10 +16,6 @@ void expect(bool cond, const std::string& msg) { if (!cond) throw std::runtime_error(msg); } -bool nearlyEqual(double lhs, double rhs, double eps = 1e-6) { - return std::fabs(lhs - rhs) <= eps; -} - } // namespace int main() { @@ -50,6 +47,82 @@ int main() { expect(binary_roundtrip.rows[1][1].asFixedVector()[1] == 0.1f, "binary row batch vector content mismatch"); + dataflow::RpcEnvelope rpc_envelope; + rpc_envelope.type = dataflow::RpcMessageType::DataBatch; + rpc_envelope.codec_id = "table-bin-v1"; + auto table_serializer = 
dataflow::makeTableRpcSerializer(); + dataflow::RpcDataBatchMessage batch_message{table}; + const auto rpc_payload = table_serializer->serialize(rpc_envelope, &batch_message); + expect(!rpc_payload.empty(), "rpc table payload should not be empty"); + dataflow::RpcDataBatchMessage decoded_batch; + expect(table_serializer->deserialize(rpc_envelope, rpc_payload, &decoded_batch), + "rpc table batch deserialize failed"); + expect(decoded_batch.table.rows.size() == 3, "rpc table batch row count mismatch"); + expect(decoded_batch.table.rows[2][1].asFixedVector()[1] == 1.0f, + "rpc table batch vector content mismatch"); + + dataflow::ActorRpcMessage actor_origin; + actor_origin.action = dataflow::ActorRpcAction::Result; + actor_origin.job_id = "vector-job"; + actor_origin.chain_id = "vector-chain"; + actor_origin.task_id = "vector-task"; + actor_origin.node_id = "vector-worker"; + actor_origin.ok = true; + actor_origin.state = "FINISHED"; + actor_origin.summary = "vector payload"; + actor_origin.result_location = "inline://vector-job/vector-task"; + actor_origin.payload = actor_origin.summary; + const auto actor_wire = dataflow::encodeActorRpcMessage(actor_origin); + dataflow::ActorRpcMessage actor_copy; + expect(dataflow::decodeActorRpcMessage(actor_wire, &actor_copy), + "actor rpc roundtrip failed for control payload"); + expect(actor_copy.summary == actor_origin.summary, + "actor rpc should preserve control summary"); + + dataflow::LengthPrefixedFrameCodec frame_codec; + dataflow::RpcFrame control_frame; + control_frame.header.protocol_version = 1; + control_frame.header.type = dataflow::RpcMessageType::Control; + control_frame.header.message_id = 17; + control_frame.header.correlation_id = 0; + control_frame.header.codec_id = "actor-rpc-v1"; + control_frame.header.source = "vector-worker"; + control_frame.header.target = "vector-client"; + control_frame.payload = actor_wire; + + std::size_t control_consumed = 0; + dataflow::RpcFrame decoded_control_frame; + 
expect(frame_codec.decode(frame_codec.encode(control_frame), &decoded_control_frame, &control_consumed), + "frame codec should decode control frame"); + expect(decoded_control_frame.header.message_id == control_frame.header.message_id, + "control frame message id mismatch"); + expect(decoded_control_frame.header.codec_id == "actor-rpc-v1", + "control frame codec id mismatch"); + + dataflow::RpcFrame data_frame; + data_frame.header.protocol_version = 1; + data_frame.header.type = dataflow::RpcMessageType::DataBatch; + data_frame.header.message_id = 18; + data_frame.header.correlation_id = control_frame.header.message_id; + data_frame.header.codec_id = "table-bin-v1"; + data_frame.header.source = "vector-worker"; + data_frame.header.target = "vector-client"; + data_frame.payload = rpc_payload; + + std::size_t data_consumed = 0; + dataflow::RpcFrame decoded_data_frame; + expect(frame_codec.decode(frame_codec.encode(data_frame), &decoded_data_frame, &data_consumed), + "frame codec should decode data frame"); + expect(decoded_data_frame.header.correlation_id == control_frame.header.message_id, + "data frame correlation id mismatch"); + dataflow::RpcDataBatchMessage framed_batch; + expect(table_serializer->deserialize(decoded_data_frame.header, decoded_data_frame.payload, &framed_batch), + "framed data batch deserialize failed"); + expect(framed_batch.table.rows.size() == table.rows.size(), + "framed data batch row count mismatch"); + expect(framed_batch.table.rows[0][1].asFixedVector()[0] == 1.0f, + "framed data batch vector content mismatch"); + auto& session = dataflow::DataflowSession::builder(); auto df = session.createDataFrame(table); session.createTempView("vec_src", df); @@ -81,6 +154,20 @@ int main() { expect(explain.find("acceleration=flat-buffer+heap-topk") != std::string::npos, "explain acceleration hint missing"); + dataflow::Table sparse_table; + sparse_table.schema = dataflow::Schema({"id", "embedding"}); + sparse_table.rows = { + 
{dataflow::Value(int64_t(10))}, + {dataflow::Value(int64_t(20)), dataflow::Value(std::vector{0.0f, 1.0f, 0.0f})}, + {dataflow::Value(int64_t(30)), dataflow::Value(std::vector{1.0f, 0.0f, 0.0f})}, + }; + session.createTempView("vec_sparse", session.createDataFrame(sparse_table)); + auto sparse = session.vectorQuery("vec_sparse", "embedding", {1.0f, 0.0f, 0.0f}, 1, + dataflow::VectorDistanceMetric::Cosine) + .toTable(); + expect(sparse.rows.size() == 1, "sparse vector query top-k mismatch"); + expect(sparse.rows[0][0].asInt64() == 2, "sparse vector query should preserve source row id"); + std::cout << "[test] vector runtime query and transport ok" << std::endl; return 0; } catch (const std::exception& ex) { diff --git a/src/dataflow/transport/ipc_transport.cc b/src/dataflow/transport/ipc_transport.cc index 68ce6f1..a580aef 100644 --- a/src/dataflow/transport/ipc_transport.cc +++ b/src/dataflow/transport/ipc_transport.cc @@ -15,6 +15,12 @@ namespace dataflow { +namespace { + +constexpr uint32_t kMaxFramePayloadBytes = 512u * 1024u * 1024u; + +} // namespace + bool parseEndpoint(const std::string& endpoint, std::string* host, uint16_t* port) { if (host == nullptr || port == nullptr) return false; const std::string::size_type sep = endpoint.find(':'); @@ -141,7 +147,7 @@ bool recvFrameOverSocket(int fd, (static_cast(length_prefix[1]) << 8) | (static_cast(length_prefix[2]) << 16) | (static_cast(length_prefix[3]) << 24); - if (payload_bytes == 0 || payload_bytes > 32u * 1024u * 1024u) return false; + if (payload_bytes == 0 || payload_bytes > kMaxFramePayloadBytes) return false; std::vector bytes(4 + payload_bytes); memcpy(&bytes[0], length_prefix, 4); From 1f68785c7d94205d260ca3120d23804d7ab4a34b Mon Sep 17 00:00:00 2001 From: zuolingxuan Date: Mon, 30 Mar 2026 14:29:37 +0800 Subject: [PATCH 4/4] fix(ci): keep actor rpc smoke success marker stable --- src/dataflow/runner/actor_runtime.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/dataflow/runner/actor_runtime.cc b/src/dataflow/runner/actor_runtime.cc index 7c51ebc..852b0bf 100644 --- a/src/dataflow/runner/actor_runtime.cc +++ b/src/dataflow/runner/actor_runtime.cc @@ -1819,7 +1819,7 @@ int runActorSmoke() { std::cerr << "[smoke] data batch decode mismatch\n"; return 1; } - std::cout << "[smoke] actor rpc control/data-batch roundtrip ok\n"; + std::cout << "[smoke] actor rpc codec roundtrip ok (control/data-batch)\n"; return 0; }