Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 148 additions & 17 deletions cpp/server/brad_server_simple.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,95 @@ arrow::Result<std::shared_ptr<arrow::RecordBatch>> ResultToRecordBatch(
return result_record_batch;
}

size_t CountSQLPlaceholders(const std::string& query) {
bool in_single_quote = false;
bool in_double_quote = false;
size_t count = 0;

for (size_t i = 0; i < query.size(); ++i) {
char c = query[i];

if (c == '\'' && !in_double_quote) {
in_single_quote = !in_single_quote;
} else if (c == '"' && !in_single_quote) {
in_double_quote = !in_double_quote;
} else if (c == '?' && !in_single_quote && !in_double_quote) {
++count;
} else if (c == '\\' && i + 1 < query.size()) {
++i; // skip escaped characters
}
}

return count;
}

std::shared_ptr<arrow::Schema> HardcodedSchema() {
std::vector<std::shared_ptr<arrow::Field>> fields;
fields.reserve(2);

std::string field_name1 = "showing_id";
std::string field_name2 = "total_quantity";
std::shared_ptr<arrow::DataType> data_type1 = arrow::int64();
std::shared_ptr<arrow::DataType> data_type2 = arrow::int64();
fields.push_back(arrow::field(std::move(field_name1), std::move(data_type1)));
fields.push_back(arrow::field(std::move(field_name2), std::move(data_type2)));

return arrow::schema(std::move(fields));
}

std::vector<std::string> GenerateSQLWithValues(
const std::shared_ptr<arrow::RecordBatch>& batch,
const std::string& sql_template) {
if (batch->num_columns() != 2) {
throw std::runtime_error("RecordBatch must have exactly 2 columns.");
}

for (int i = 0; i < 2; ++i) {
if (batch->column(i)->type_id() != arrow::Type::INT64) {
throw std::runtime_error("Both columns must be of type int64.");
}
}

// Count placeholders
const auto placeholder_count = CountSQLPlaceholders(sql_template);
if (placeholder_count != 2) {
throw std::runtime_error("SQL string must contain exactly 2 placeholders.");
}

auto col0 = std::static_pointer_cast<arrow::Int64Array>(batch->column(0));
auto col1 = std::static_pointer_cast<arrow::Int64Array>(batch->column(1));

std::vector<std::string> results;
for (int64_t row = 0; row < batch->num_rows(); ++row) {
if (col0->IsNull(row) || col1->IsNull(row)) {
throw std::runtime_error(
"Null values are not supported in placeholder substitution.");
}

int64_t val0 = col0->Value(row);
int64_t val1 = col1->Value(row);

// Replace the placeholders one by one
std::string result;
size_t pos = 0;
int replace_count = 0;
for (char c : sql_template) {
if (c == '?' && replace_count == 0) {
result += std::to_string(val0);
++replace_count;
} else if (c == '?' && replace_count == 1) {
result += std::to_string(val1);
++replace_count;
} else {
result += c;
}
}

results.push_back(result);
}
return results;
}

BradFlightSqlServer::BradFlightSqlServer() : autoincrement_id_(0ULL) {}

BradFlightSqlServer::~BradFlightSqlServer() = default;
Expand Down Expand Up @@ -250,10 +339,17 @@ BradFlightSqlServer::CreatePreparedStatement(
const PreparedStatementContext statement_context{request.query,
request.transaction_id};
prepared_statements_.insert(id, statement_context);
// std::cerr << "Registered prepared statement " << id << " " << request.query
// << std::endl;
return arrow::flight::sql::ActionCreatePreparedStatementResult{nullptr,
nullptr, id};
const size_t num_params = CountSQLPlaceholders(request.query);
if (num_params == 0) {
return arrow::flight::sql::ActionCreatePreparedStatementResult{nullptr,
nullptr, id};
} else if (num_params == 2) {
return arrow::flight::sql::ActionCreatePreparedStatementResult{
nullptr, HardcodedSchema(), id};
} else {
return arrow::Status::Invalid(
"Unsupported number of parameters in prepared statement: ", num_params);
}
}

arrow::Status BradFlightSqlServer::ClosePreparedStatement(
Expand Down Expand Up @@ -289,17 +385,52 @@ BradFlightSqlServer::GetFlightInfoPreparedStatement(
return GetFlightInfoImpl(query, transaction_id, descriptor);
}

// Currently unimplemented.

arrow::Result<std::unique_ptr<arrow::flight::FlightDataStream>>
BradFlightSqlServer::DoGetPreparedStatement(
arrow::Result<int64_t> BradFlightSqlServer::DoPutPreparedStatementUpdate(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::PreparedStatementQuery& command) {
std::cerr << "DoGetPreparedStatement called "
<< command.prepared_statement_handle << std::endl;
return arrow::Result<std::unique_ptr<arrow::flight::FlightDataStream>>();
const arrow::flight::sql::PreparedStatementUpdate& command,
arrow::flight::FlightMessageReader* reader) {
// std::cerr << "DoPutPreparedStatementUpdate called "
// << command.prepared_statement_handle << std::endl;
const PreparedStatementContext* statement_ctx = nullptr;
prepared_statements_.find_fn(
command.prepared_statement_handle,
[&statement_ctx](const auto& ps_ctx) { statement_ctx = &ps_ctx; });
if (statement_ctx == nullptr) {
return arrow::Status::Invalid("Invalid prepared statement handle.");
}

auto record_batches_result = reader->ToRecordBatches();
if (!record_batches_result.ok()) {
return record_batches_result.status();
}

int64_t num_rows = 0;
auto record_batches = record_batches_result.ValueOrDie();
try {
for (auto& batch : record_batches) {
const auto queries = GenerateSQLWithValues(batch, statement_ctx->query);
{
// NOTE: This is a blocking call to handle_query_. We use a mutex as a
// simple way to ensure only one Flight SQL-sourced query runs at a
// time. This is because concurrent Flight SQL queries need to run in
// different sessions, which requires additional setup.
std::unique_lock<std::mutex> lock(mutex_);
py::gil_scoped_acquire guard;
for (const auto& query : queries) {
handle_query_(query);
}
}
num_rows += batch->num_rows();
}
} catch (const std::runtime_error& e) {
return arrow::Status::Invalid(e.what());
}

return arrow::Result<int64_t>(num_rows);
}

// Currently unimplemented.

arrow::Status BradFlightSqlServer::DoPutPreparedStatementQuery(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::PreparedStatementQuery& command,
Expand All @@ -310,13 +441,13 @@ arrow::Status BradFlightSqlServer::DoPutPreparedStatementQuery(
return arrow::Status();
}

arrow::Result<int64_t> BradFlightSqlServer::DoPutPreparedStatementUpdate(
arrow::Result<std::unique_ptr<arrow::flight::FlightDataStream>>
BradFlightSqlServer::DoGetPreparedStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::PreparedStatementUpdate& command,
arrow::flight::FlightMessageReader* reader) {
std::cerr << "DoPutPreparedStatementUpdate called "
const arrow::flight::sql::PreparedStatementQuery& command) {
std::cerr << "DoGetPreparedStatement called "
<< command.prepared_statement_handle << std::endl;
return arrow::Result<int64_t>();
return arrow::Result<std::unique_ptr<arrow::flight::FlightDataStream>>();
}

arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>>
Expand Down
2 changes: 2 additions & 0 deletions cpp/server/brad_server_simple.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <string>
#include <utility>
#include <vector>
#include <mutex>

#include "brad_statement.h"
#include "libcuckoo/cuckoohash_map.hh"
Expand Down Expand Up @@ -106,6 +107,7 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase {
prepared_statements_;

std::atomic<uint64_t> autoincrement_id_;
std::mutex mutex_;
};

} // namespace brad
5 changes: 5 additions & 0 deletions src/brad/front_end/vdbe/vdbe_front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,11 @@ async def _run_query_impl(
self._query_latency_sketches[vdbe_id] = self._get_empty_sketch()
self._query_latency_sketches[vdbe_id].add(run_time_s_float)

# fetchall() may raise an error if the query does not produce output
# (e.g., INSERT).
if query_rep.is_data_modification_query():
return ([], Schema.empty() if retrieve_schema else None)

# Extract and return the results, if any.
try:
result_row_limit = self._config.result_row_limit()
Expand Down