From aca964222b140001d2c07ee5bffe9ae13803f34a Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Tue, 24 Jun 2025 19:03:19 -0400 Subject: [PATCH 1/4] Return placeholder schema --- cpp/server/brad_server_simple.cc | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 2821b11a..d5e3331e 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -163,6 +163,20 @@ arrow::Result> ResultToRecordBatch( return result_record_batch; } +std::shared_ptr SimpleSchema() { + std::vector> fields; + fields.reserve(2); + + std::string field_name1 = "showing_id"; + std::string field_name2 = "total_quantity"; + std::shared_ptr data_type1 = arrow::int64(); + std::shared_ptr data_type2 = arrow::int64(); + fields.push_back(arrow::field(std::move(field_name1), std::move(data_type1))); + fields.push_back(arrow::field(std::move(field_name2), std::move(data_type2))); + + return arrow::schema(std::move(fields)); +} + BradFlightSqlServer::BradFlightSqlServer() : autoincrement_id_(0ULL) {} BradFlightSqlServer::~BradFlightSqlServer() = default; @@ -253,7 +267,7 @@ BradFlightSqlServer::CreatePreparedStatement( // std::cerr << "Registered prepared statement " << id << " " << request.query // << std::endl; return arrow::flight::sql::ActionCreatePreparedStatementResult{nullptr, - nullptr, id}; + SimpleSchema(), id}; } arrow::Status BradFlightSqlServer::ClosePreparedStatement( From e38b27e498ca49dccd20f350e75efadf66c0fd63 Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Wed, 25 Jun 2025 07:53:30 -0400 Subject: [PATCH 2/4] Implementation updates to support writes --- cpp/server/brad_server_simple.cc | 144 +++++++++++++++++++++++++++---- 1 file changed, 127 insertions(+), 17 deletions(-) diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index d5e3331e..9e4d1e2f 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -163,7 +163,29 @@ arrow::Result> ResultToRecordBatch( return result_record_batch; } -std::shared_ptr SimpleSchema() { +size_t CountSQLPlaceholders(const std::string& query) { + bool in_single_quote = false; + bool in_double_quote = false; + size_t count = 0; + + for (size_t i = 0; i < query.size(); ++i) { + char c = query[i]; + + if (c == '\'' && !in_double_quote) { + in_single_quote = !in_single_quote; + } else if (c == '"' && !in_single_quote) { + in_double_quote = !in_double_quote; + } else if (c == '?' && !in_single_quote && !in_double_quote) { + ++count; + } else if (c == '\\' && i + 1 < query.size()) { + ++i; // skip escaped characters + } + } + + return count; +} + +std::shared_ptr HardcodedSchema() { std::vector> fields; fields.reserve(2); @@ -177,6 +199,59 @@ std::shared_ptr SimpleSchema() { return arrow::schema(std::move(fields)); } +std::vector GenerateSQLWithValues( + const std::shared_ptr& batch, + const std::string& sql_template) { + if (batch->num_columns() != 2) { + throw std::runtime_error("RecordBatch must have exactly 2 columns."); + } + + for (int i = 0; i < 2; ++i) { + if (batch->column(i)->type_id() != arrow::Type::INT64) { + throw std::runtime_error("Both columns must be of type int64."); + } + } + + // Count placeholders + const auto placeholder_count = CountSQLPlaceholders(sql_template); + if (placeholder_count != 2) { + throw std::runtime_error("SQL string must contain exactly 2 placeholders."); + } + + auto col0 = std::static_pointer_cast(batch->column(0)); + auto col1 = std::static_pointer_cast(batch->column(1)); + + std::vector results; + for (int64_t row = 0; row < batch->num_rows(); ++row) { + if (col0->IsNull(row) || col1->IsNull(row)) { + throw std::runtime_error( + "Null values are not supported in placeholder substitution."); + } + + int64_t val0 = col0->Value(row); + int64_t val1 = col1->Value(row); + + // Replace the placeholders one by one + std::string result; + size_t pos = 0; + int replace_count = 0; + for (char c : sql_template) { + if (c == '?' && replace_count == 0) { + result += std::to_string(val0); + ++replace_count; + } else if (c == '?' && replace_count == 1) { + result += std::to_string(val1); + ++replace_count; + } else { + result += c; + } + } + + results.push_back(result); + } + return results; +} + BradFlightSqlServer::BradFlightSqlServer() : autoincrement_id_(0ULL) {} BradFlightSqlServer::~BradFlightSqlServer() = default; @@ -264,10 +339,17 @@ BradFlightSqlServer::CreatePreparedStatement( const PreparedStatementContext statement_context{request.query, request.transaction_id}; prepared_statements_.insert(id, statement_context); - // std::cerr << "Registered prepared statement " << id << " " << request.query - // << std::endl; - return arrow::flight::sql::ActionCreatePreparedStatementResult{nullptr, - SimpleSchema(), id}; + const size_t num_params = CountSQLPlaceholders(request.query); + if (num_params == 0) { + return arrow::flight::sql::ActionCreatePreparedStatementResult{nullptr, + nullptr, id}; + } else if (num_params == 2) { + return arrow::flight::sql::ActionCreatePreparedStatementResult{ + nullptr, HardcodedSchema(), id}; + } else { + return arrow::Status::Invalid( + "Unsupported number of parameters in prepared statement: ", num_params); + } } arrow::Status BradFlightSqlServer::ClosePreparedStatement( @@ -303,17 +385,45 @@ BradFlightSqlServer::GetFlightInfoPreparedStatement( return GetFlightInfoImpl(query, transaction_id, descriptor); } -// Currently unimplemented. - -arrow::Result> -BradFlightSqlServer::DoGetPreparedStatement( +arrow::Result BradFlightSqlServer::DoPutPreparedStatementUpdate( const arrow::flight::ServerCallContext& context, - const arrow::flight::sql::PreparedStatementQuery& command) { - std::cerr << "DoGetPreparedStatement called " + const arrow::flight::sql::PreparedStatementUpdate& command, + arrow::flight::FlightMessageReader* reader) { + std::cerr << "DoPutPreparedStatementUpdate called " << command.prepared_statement_handle << std::endl; - return arrow::Result>(); + const PreparedStatementContext* statement_ctx = nullptr; + prepared_statements_.find_fn( + command.prepared_statement_handle, + [&statement_ctx](const auto& ps_ctx) { statement_ctx = &ps_ctx; }); + if (statement_ctx == nullptr) { + return arrow::Status::Invalid("Invalid prepared statement handle."); + } + + auto record_batches_result = reader->ToRecordBatches(); + if (!record_batches_result.ok()) { + return record_batches_result.status(); + } + + int64_t num_rows = 0; + auto record_batches = record_batches_result.ValueOrDie(); + try { + for (auto& batch : record_batches) { + const auto queries = GenerateSQLWithValues(batch, statement_ctx->query); + std::cerr << "Would run queries: " << std::endl; + for (const auto& query : queries) { + std::cerr << " " << query << std::endl; + } + num_rows += batch->num_rows(); + } + } catch (const std::runtime_error& e) { + return arrow::Status::Invalid(e.what()); + } + + return arrow::Result(num_rows); } +// Currently unimplemented. + arrow::Status BradFlightSqlServer::DoPutPreparedStatementQuery( const arrow::flight::ServerCallContext& context, const arrow::flight::sql::PreparedStatementQuery& command, @@ -324,13 +434,13 @@ arrow::Status BradFlightSqlServer::DoPutPreparedStatementQuery( return arrow::Status(); } -arrow::Result BradFlightSqlServer::DoPutPreparedStatementUpdate( +arrow::Result> +BradFlightSqlServer::DoGetPreparedStatement( const arrow::flight::ServerCallContext& context, - const arrow::flight::sql::PreparedStatementUpdate& command, - arrow::flight::FlightMessageReader* reader) { - std::cerr << "DoPutPreparedStatementUpdate called " + const arrow::flight::sql::PreparedStatementQuery& command) { + std::cerr << "DoGetPreparedStatement called " << command.prepared_statement_handle << std::endl; - return arrow::Result(); + return arrow::Result>(); } arrow::Result> From d44e3930a743e451ccfb110d91c13fe6a006d175 Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Wed, 25 Jun 2025 08:38:58 -0400 Subject: [PATCH 3/4] Pass the query to the front end --- cpp/server/brad_server_simple.cc | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 9e4d1e2f..da5774bc 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -389,8 +389,8 @@ arrow::Result BradFlightSqlServer::DoPutPreparedStatementUpdate( const arrow::flight::ServerCallContext& context, const arrow::flight::sql::PreparedStatementUpdate& command, arrow::flight::FlightMessageReader* reader) { - std::cerr << "DoPutPreparedStatementUpdate called " - << command.prepared_statement_handle << std::endl; + // std::cerr << "DoPutPreparedStatementUpdate called " + // << command.prepared_statement_handle << std::endl; const PreparedStatementContext* statement_ctx = nullptr; prepared_statements_.find_fn( command.prepared_statement_handle, @@ -409,9 +409,11 @@ arrow::Result BradFlightSqlServer::DoPutPreparedStatementUpdate( try { for (auto& batch : record_batches) { const auto queries = GenerateSQLWithValues(batch, statement_ctx->query); - std::cerr << "Would run queries: " << std::endl; - for (const auto& query : queries) { - std::cerr << " " << query << std::endl; + { + py::gil_scoped_acquire guard; + for (const auto& query : queries) { + handle_query_(query); + } } num_rows += batch->num_rows(); } From fdd610e5d789118e3b4b06a3fd4c88aeb75beba4 Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Wed, 25 Jun 2025 09:36:32 -0400 Subject: [PATCH 4/4] Additional fixes --- cpp/server/brad_server_simple.cc | 5 +++++ cpp/server/brad_server_simple.h | 2 ++ src/brad/front_end/vdbe/vdbe_front_end.py | 5 +++++ 3 files changed, 12 insertions(+) diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index da5774bc..bf535dc2 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -410,6 +410,11 @@ arrow::Result BradFlightSqlServer::DoPutPreparedStatementUpdate( for (auto& batch : record_batches) { const auto queries = GenerateSQLWithValues(batch, statement_ctx->query); { + // NOTE: This is a blocking call to handle_query_. We use a mutex as a + // simple way to ensure only one Flight SQL-sourced query runs at a + // time. This is because concurrent Flight SQL queries need to run in + // different sessions, which requires additional setup. + std::unique_lock lock(mutex_); py::gil_scoped_acquire guard; for (const auto& query : queries) { handle_query_(query); diff --git a/cpp/server/brad_server_simple.h b/cpp/server/brad_server_simple.h index 6b9c1742..291e93cb 100644 --- a/cpp/server/brad_server_simple.h +++ b/cpp/server/brad_server_simple.h @@ -12,6 +12,7 @@ #include #include #include +#include #include "brad_statement.h" #include "libcuckoo/cuckoohash_map.hh" @@ -106,6 +107,7 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase { prepared_statements_; std::atomic autoincrement_id_; + std::mutex mutex_; }; } // namespace brad diff --git a/src/brad/front_end/vdbe/vdbe_front_end.py b/src/brad/front_end/vdbe/vdbe_front_end.py index a2e03aa5..9dbedce5 100644 --- a/src/brad/front_end/vdbe/vdbe_front_end.py +++ b/src/brad/front_end/vdbe/vdbe_front_end.py @@ -319,6 +319,11 @@ async def _run_query_impl( self._query_latency_sketches[vdbe_id] = self._get_empty_sketch() self._query_latency_sketches[vdbe_id].add(run_time_s_float) + # fetchall() may raise an error if the query does not produce output + # (e.g., INSERT). + if query_rep.is_data_modification_query(): + return ([], Schema.empty() if retrieve_schema else None) + # Extract and return the results, if any. try: result_row_limit = self._config.result_row_limit()