diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 2821b11a..bf535dc2 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -163,6 +163,95 @@ arrow::Result> ResultToRecordBatch( return result_record_batch; } +size_t CountSQLPlaceholders(const std::string& query) { /* Returns how many question-mark placeholders appear in query outside single- and double-quoted regions. A backslash makes the scan skip the character after it, whether or not the scan is inside quotes. */ + bool in_single_quote = false; + bool in_double_quote = false; + size_t count = 0; + + for (size_t i = 0; i < query.size(); ++i) { + char c = query[i]; + + if (c == '\'' && !in_double_quote) { + in_single_quote = !in_single_quote; + } else if (c == '"' && !in_single_quote) { + in_double_quote = !in_double_quote; + } else if (c == '?' && !in_single_quote && !in_double_quote) { + ++count; + } else if (c == '\\' && i + 1 < query.size()) { + ++i; // skip the character that follows a backslash + } + } + + return count; +} + +std::shared_ptr HardcodedSchema() { /* Builds a schema with two int64 fields named showing_id and total_quantity. */ + std::vector> fields; + fields.reserve(2); + + std::string field_name1 = "showing_id"; + std::string field_name2 = "total_quantity"; + std::shared_ptr data_type1 = arrow::int64(); + std::shared_ptr data_type2 = arrow::int64(); + fields.push_back(arrow::field(std::move(field_name1), std::move(data_type1))); + fields.push_back(arrow::field(std::move(field_name2), std::move(data_type2))); + + return arrow::schema(std::move(fields)); +} + +std::vector GenerateSQLWithValues( + const std::shared_ptr& batch, + const std::string& sql_template) { /* Produces one SQL string per row of batch by copying sql_template with its first question mark replaced by the column 0 value and its second by the column 1 value. Throws std::runtime_error when the batch does not have exactly two int64 columns, when CountSQLPlaceholders does not report exactly two placeholders, or when a row holds a null. */ + if (batch->num_columns() != 2) { + throw std::runtime_error("RecordBatch must have exactly 2 columns."); + } + + for (int i = 0; i < 2; ++i) { + if (batch->column(i)->type_id() != arrow::Type::INT64) { + throw std::runtime_error("Both columns must be of type int64."); + } + } + + // Count placeholders + const auto placeholder_count = CountSQLPlaceholders(sql_template); + if (placeholder_count != 2) { + throw std::runtime_error("SQL string must contain exactly 2 placeholders."); + } + + auto col0 = std::static_pointer_cast(batch->column(0)); + auto col1 = 
std::static_pointer_cast(batch->column(1)); + + std::vector results; + for (int64_t row = 0; row < batch->num_rows(); ++row) { + if (col0->IsNull(row) || col1->IsNull(row)) { + throw std::runtime_error( + "Null values are not supported in placeholder substitution."); + } + + int64_t val0 = col0->Value(row); + int64_t val1 = col1->Value(row); + + // Replace the placeholders one by one. NOTE(review): this scan replaces the first two question marks regardless of quoting or backslashes, whereas CountSQLPlaceholders skips quoted and escaped ones, so a template with a quoted question mark ahead of a real placeholder is substituted in the wrong place. + std::string result; + size_t pos = 0; /* NOTE(review): pos is never read in this function. */ + int replace_count = 0; + for (char c : sql_template) { + if (c == '?' && replace_count == 0) { + result += std::to_string(val0); + ++replace_count; + } else if (c == '?' && replace_count == 1) { + result += std::to_string(val1); + ++replace_count; + } else { + result += c; + } + } + + results.push_back(result); + } + return results; +} + BradFlightSqlServer::BradFlightSqlServer() : autoincrement_id_(0ULL) {} BradFlightSqlServer::~BradFlightSqlServer() = default; @@ -250,10 +339,17 @@ BradFlightSqlServer::CreatePreparedStatement( const PreparedStatementContext statement_context{request.query, request.transaction_id}; prepared_statements_.insert(id, statement_context); - // std::cerr << "Registered prepared statement " << id << " " << request.query - // << std::endl; - return arrow::flight::sql::ActionCreatePreparedStatementResult{nullptr, - nullptr, id}; + const size_t num_params = CountSQLPlaceholders(request.query); /* Statements with no placeholders are returned with two null schema slots; statements with exactly two get HardcodedSchema() in the second slot; any other count is rejected. NOTE(review): the statement was already inserted into prepared_statements_ above, so it stays registered even when Invalid is returned below. */ + if (num_params == 0) { + return arrow::flight::sql::ActionCreatePreparedStatementResult{nullptr, + nullptr, id}; + } else if (num_params == 2) { + return arrow::flight::sql::ActionCreatePreparedStatementResult{ + nullptr, HardcodedSchema(), id}; + } else { + return arrow::Status::Invalid( + "Unsupported number of parameters in prepared statement: ", num_params); + } } arrow::Status BradFlightSqlServer::ClosePreparedStatement( @@ -289,17 +385,52 @@ BradFlightSqlServer::GetFlightInfoPreparedStatement( return GetFlightInfoImpl(query, transaction_id, descriptor); } -// Currently unimplemented. 
- -arrow::Result> -BradFlightSqlServer::DoGetPreparedStatement( +arrow::Result BradFlightSqlServer::DoPutPreparedStatementUpdate( const arrow::flight::ServerCallContext& context, - const arrow::flight::sql::PreparedStatementQuery& command) { - std::cerr << "DoGetPreparedStatement called " - << command.prepared_statement_handle << std::endl; - return arrow::Result>(); + const arrow::flight::sql::PreparedStatementUpdate& command, + arrow::flight::FlightMessageReader* reader) { /* Looks up the prepared statement for command.prepared_statement_handle, reads every record batch from reader, expands each batch into concrete SQL strings with GenerateSQLWithValues, and runs them one at a time through handle_query_. Returns the total row count of the received batches. Returns Status::Invalid for an unknown handle or when a std::runtime_error is caught while expanding or running queries, and forwards the reader status when reading the batches fails. */ + // std::cerr << "DoPutPreparedStatementUpdate called " + // << command.prepared_statement_handle << std::endl; + const PreparedStatementContext* statement_ctx = nullptr; + prepared_statements_.find_fn( + command.prepared_statement_handle, + [&statement_ctx](const auto& ps_ctx) { statement_ctx = &ps_ctx; }); /* NOTE(review): statement_ctx keeps pointing at the map entry after find_fn returns and is dereferenced below; confirm the entry cannot be erased or relocated concurrently (e.g. by ClosePreparedStatement). */ + if (statement_ctx == nullptr) { + return arrow::Status::Invalid("Invalid prepared statement handle."); + } + + auto record_batches_result = reader->ToRecordBatches(); + if (!record_batches_result.ok()) { + return record_batches_result.status(); + } + + int64_t num_rows = 0; + auto record_batches = record_batches_result.ValueOrDie(); + try { + for (auto& batch : record_batches) { + const auto queries = GenerateSQLWithValues(batch, statement_ctx->query); + { + // NOTE: This is a blocking call to handle_query_. We use a mutex as a + // simple way to ensure only one Flight SQL-sourced query runs at a + // time. This is because concurrent Flight SQL queries need to run in + // different sessions, which requires additional setup. + std::unique_lock lock(mutex_); + py::gil_scoped_acquire guard; + for (const auto& query : queries) { + handle_query_(query); + } + } + num_rows += batch->num_rows(); /* Counts input rows received, one generated query per row. */ + } + } catch (const std::runtime_error& e) { /* GenerateSQLWithValues reports validation failures as std::runtime_error; they are returned to the caller as Invalid. */ + return arrow::Status::Invalid(e.what()); + } + + return arrow::Result(num_rows); } +// Currently unimplemented. 
+ arrow::Status BradFlightSqlServer::DoPutPreparedStatementQuery( const arrow::flight::ServerCallContext& context, const arrow::flight::sql::PreparedStatementQuery& command, @@ -310,13 +441,13 @@ arrow::Status BradFlightSqlServer::DoPutPreparedStatementQuery( return arrow::Status(); } -arrow::Result BradFlightSqlServer::DoPutPreparedStatementUpdate( +arrow::Result> +BradFlightSqlServer::DoGetPreparedStatement( const arrow::flight::ServerCallContext& context, - const arrow::flight::sql::PreparedStatementUpdate& command, - arrow::flight::FlightMessageReader* reader) { - std::cerr << "DoPutPreparedStatementUpdate called " + const arrow::flight::sql::PreparedStatementQuery& command) { /* Stub: writes the statement handle to stderr and returns a default-constructed Result. */ + std::cerr << "DoGetPreparedStatement called " << command.prepared_statement_handle << std::endl; - return arrow::Result(); + return arrow::Result>(); } arrow::Result> diff --git a/cpp/server/brad_server_simple.h b/cpp/server/brad_server_simple.h index 6b9c1742..291e93cb 100644 --- a/cpp/server/brad_server_simple.h +++ b/cpp/server/brad_server_simple.h @@ -12,6 +12,7 @@ #include #include #include +#include #include "brad_statement.h" #include "libcuckoo/cuckoohash_map.hh" @@ -106,6 +107,7 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase { prepared_statements_; std::atomic autoincrement_id_; + std::mutex mutex_; /* Held by DoPutPreparedStatementUpdate while it runs queries through handle_query_. */ }; } // namespace brad diff --git a/src/brad/front_end/vdbe/vdbe_front_end.py b/src/brad/front_end/vdbe/vdbe_front_end.py index a2e03aa5..9dbedce5 100644 --- a/src/brad/front_end/vdbe/vdbe_front_end.py +++ b/src/brad/front_end/vdbe/vdbe_front_end.py @@ -319,6 +319,11 @@ async def _run_query_impl( self._query_latency_sketches[vdbe_id] = self._get_empty_sketch() self._query_latency_sketches[vdbe_id].add(run_time_s_float) + # fetchall() may raise an error if the query does not produce output + # (e.g., INSERT). 
+ if query_rep.is_data_modification_query(): + return ([], Schema.empty() if retrieve_schema else None) + # Extract and return the results, if any. try: result_row_limit = self._config.result_row_limit()