Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion config/system_config_demo_s3.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# listen on successive ports (e.g., 6584, 6585, etc.).
front_end_interface: "0.0.0.0"
front_end_port: 6783
num_front_ends: 3
num_front_ends: 2

# If installed and enabled, BRAD will serve its UI from a webserver that listens
# for connections on this network interface and port.
Expand Down Expand Up @@ -127,6 +127,7 @@ std_datasets:
bootstrap_vdbe_path: config/vdbe_demo/imdb_etl_vdbes.json
disable_query_logging: true
vdbe_start_port: 10076
flight_sql_mode: "vdbe"

aurora_max_query_factor: 4.0
aurora_max_query_factor_replace: 10000.0
Expand Down
9 changes: 7 additions & 2 deletions cpp/server/brad_server_simple.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ arrow::Result<std::shared_ptr<arrow::RecordBatch>> ResultToRecordBatch(
columns.push_back(values);

} else if (field_type->Equals(
arrow::decimal(/*precision=*/10, /*scale=*/2))) {
arrow::decimal128(/*precision=*/10, /*scale=*/2))) {
arrow::Decimal128Builder decimalbuilder(
arrow::decimal(/*precision=*/10, /*scale=*/2));
arrow::decimal128(/*precision=*/10, /*scale=*/2));
for (int row_ix = 0; row_ix < num_rows; ++row_ix) {
const std::optional<std::string> val =
py::cast<std::optional<std::string>>(
Expand Down Expand Up @@ -149,6 +149,11 @@ arrow::Result<std::shared_ptr<arrow::RecordBatch>> ResultToRecordBatch(
std::shared_ptr<arrow::Array> values;
ARROW_ASSIGN_OR_RAISE(values, nullbuilder.Finish());
columns.push_back(values);
} else {
std::cerr << "ERROR: Unsupported field type: " << field_type->ToString()
<< std::endl;
return arrow::Status::NotImplemented("Unsupported field type: ",
field_type->ToString());
}
}

Expand Down
7 changes: 7 additions & 0 deletions src/brad/config/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,13 @@ def vdbe_start_port(self) -> int:
return 9876 # Default
return int(self._raw["vdbe_start_port"])

def flight_sql_mode(self) -> Optional[str]:
try:
return self._raw["flight_sql_mode"]
except KeyError:
# FlightSQL mode is not set.
return None

def _extract_log_path(self, config_key: str) -> Optional[pathlib.Path]:
if config_key not in self._raw:
return None
Expand Down
8 changes: 6 additions & 2 deletions src/brad/connection/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ async def connect_to(
return cls.connect_to_stub(config)

# HACK: Schema aliasing for convenience.
if schema_name is not None and schema_name == "imdb_editable_100g":
if schema_name is not None and (
schema_name == "imdb_editable_100g" or schema_name == "imdb_etl_100g"
):
schema_name = "imdb_extended_100g"

connection_details = config.get_connection_details(engine)
Expand Down Expand Up @@ -158,7 +160,9 @@ async def connect_to_sidecar(
return cls.connect_to_stub(config)

# HACK: Schema aliasing for convenience.
if schema_name is not None and schema_name == "imdb_editable_100g":
if schema_name is not None and (
schema_name == "imdb_editable_100g" or schema_name == "imdb_etl_100g"
):
schema_name = "imdb_extended_100g"

connection_details = config.get_sidecar_db_details()
Expand Down
11 changes: 10 additions & 1 deletion src/brad/exec/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pathlib
import readline
import time
import pyodbc
from typing import List, Tuple
from tabulate import tabulate

Expand Down Expand Up @@ -80,6 +81,11 @@ def run_query(client: BradGrpcClient | BradFlightSqlClientOdbc, query: str) -> N
print("Query resulted in an error:")
print(ex.message())
print()
except pyodbc.Error as ex:
print()
print("Query resulted in an error:")
print(repr(ex))
print()


class BradShell(cmd.Cmd):
Expand Down Expand Up @@ -145,7 +151,10 @@ def main(args) -> None:
host, port = parse_endpoint(args.endpoint)
print("BRAD Interactive Shell v{}".format(brad.__version__))
print()
print("Connecting to BRAD VDBE at {}:{}...".format(host, port))
if args.use_odbc:
print("Connecting to BRAD VDBE at {}:{} (using ODBC)...".format(host, port))
else:
print("Connecting to BRAD VDBE at {}:{}...".format(host, port))

def run_shell(client: BradGrpcClient | BradFlightSqlClientOdbc) -> None:
print("Connected!")
Expand Down
9 changes: 8 additions & 1 deletion src/brad/front_end/front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,10 @@ def __init__(
input_queue: mp.Queue,
output_queue: mp.Queue,
):
if BradFrontEnd.native_server_is_supported():
if (
BradFrontEnd.native_server_is_supported()
and config.flight_sql_mode() == "front_end"
):
from brad.front_end.flight_sql_server import BradFlightSqlServer

self._flight_sql_server: Optional[BradFlightSqlServer] = (
Expand All @@ -98,8 +101,12 @@ def __init__(
)
)
self._flight_sql_server_session_id: Optional[SessionId] = None
logger.info(
"FlightSQL server is enabled for the front end. Will listen on port 31337."
)
else:
self._flight_sql_server = None
logger.info("FlightSQL server is disabled for the front end.")

self._main_thread_loop: Optional[asyncio.AbstractEventLoop] = None

Expand Down
166 changes: 158 additions & 8 deletions src/brad/front_end/vdbe/vdbe_endpoint_manager.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
import asyncio
import grpc
import json
import logging
from typing import Callable, Optional, Tuple, Dict, AsyncIterable, Any, Set, Awaitable
import threading
from typing import (
Callable,
Optional,
Tuple,
Dict,
AsyncIterable,
Any,
Set,
Awaitable,
)

import brad.proto_gen.brad_pb2_grpc as brad_grpc
from brad.connection.schema import Schema
from brad.config.file import ConfigFile
from brad.connection.schema import Schema, DataType
from brad.front_end.brad_interface import BradInterface
from brad.front_end.grpc import BradGrpc
from brad.front_end.session import SessionManager, SessionId
Expand All @@ -15,9 +27,10 @@
logger = logging.getLogger(__name__)


# (query_string, vdbe_id, session_id, debug_info) -> (rows, schema)
# (query_string, vdbe_id, session_id, debug_info, retrieve_schema) -> (rows, schema)
QueryHandler = Callable[
[str, int, SessionId, Dict[str, Any]], Awaitable[Tuple[RowList, Optional[Schema]]]
[str, int, SessionId, Dict[str, Any], bool],
Awaitable[Tuple[RowList, Optional[Schema]]],
]


Expand All @@ -33,11 +46,30 @@ def __init__(
vdbe_mgr: VdbeFrontEndManager,
session_mgr: SessionManager,
handler: QueryHandler,
config: ConfigFile,
) -> None:
self._vdbe_mgr = vdbe_mgr
self._session_mgr = session_mgr
self._handler = handler
self._config = config
self._endpoints: Dict[int, Tuple[int, grpc.aio.Server, VdbeGrpcInterface]] = {}
self._flight_sql_endpoints: Dict[int, Tuple[int, VdbeFlightSqlServer]] = {}

try:
# pylint: disable-next=import-error,no-name-in-module,unused-import
import brad.native.pybind_brad_server as brad_server

self._use_flight_sql = self._config.flight_sql_mode() == "vdbe"
except ImportError:
self._use_flight_sql = False

if self._use_flight_sql:
logger.info("Will start Flight SQL endpoints for VDBEs.")
else:
logger.info(
"Flight SQL endpoints for VDBEs are not available. "
"Using gRPC endpoints only."
)

async def initialize(self) -> None:
for engine in self._vdbe_mgr.engines():
Expand Down Expand Up @@ -66,22 +98,59 @@ async def add_vdbe_endpoint(self, port: int, vdbe_id: int) -> None:
grpc_server.add_insecure_port(f"0.0.0.0:{port}")
await grpc_server.start()
logger.info(
"Added VDBE endpoint for ID %d. Listening on port %d.", vdbe_id, port
"Added gRPC VDBE endpoint for ID %d. Listening on port %d.", vdbe_id, port
)
self._endpoints[vdbe_id] = (port, grpc_server, query_service)

if self._use_flight_sql:
session_id, _ = await self._session_mgr.create_new_session()
# The flight SQL port is offset by 10,000 from the gRPC port.
flight_sql_port = port + 10_000
flight_sql_server = VdbeFlightSqlServer(
vdbe_id=vdbe_id,
port=flight_sql_port,
main_loop=asyncio.get_event_loop(),
handler=self._handler,
session_id=session_id,
)
flight_sql_server.start()
self._flight_sql_endpoints[vdbe_id] = (flight_sql_port, flight_sql_server)
logger.info(
"Added Flight SQL VDBE endpoint for ID %d. Listening on port %d.",
vdbe_id,
flight_sql_port,
)

async def remove_vdbe_endpoint(self, vdbe_id: int) -> None:
try:
port, grpc_server, query_service = self._endpoints[vdbe_id]
await query_service.end_all_sessions()
# See `brad.front_end.BradFrontEnd.serve_forever`.
grpc_server.__del__()
del self._endpoints[vdbe_id]
logger.info("Removed VDBE endpoint for ID %d (was port %d).", vdbe_id, port)
logger.info(
"Removed gRPC VDBE endpoint for ID %d (was port %d).", vdbe_id, port
)

except KeyError:
logger.error(
"Tried to remove gRPC VDBE endpoint for ID %d, but it was not found.",
vdbe_id,
)

try:
port, flight_sql_server = self._flight_sql_endpoints[vdbe_id]
flight_sql_server.stop()
del self._flight_sql_endpoints[vdbe_id]
await self._session_mgr.end_session(flight_sql_server.session_id)
logger.info(
"Removed Flight SQL VDBE endpoint for ID %d (was port %d).",
vdbe_id,
port,
)
except KeyError:
logger.error(
"Tried to remove VDBE endpoint for ID %d, but it was not found.",
"Tried to remove Flight SQL VDBE endpoint for ID %d, but it was not found.",
vdbe_id,
)

Expand Down Expand Up @@ -144,7 +213,9 @@ async def run_query_json(

This method may throw an error to indicate a problem with the query.
"""
results, _ = await self._handler(query, self._vdbe_id, session_id, debug_info)
results, _ = await self._handler(
query, self._vdbe_id, session_id, debug_info, False
)
return json.dumps(results, cls=DecimalEncoder, default=str)

async def end_session(self, session_id: SessionId) -> None:
Expand All @@ -156,3 +227,82 @@ async def end_all_sessions(self) -> None:
self._our_sessions.clear()
for session_id in our_sessions:
await self._session_mgr.end_session(session_id)


class VdbeFlightSqlServer:
def __init__(
self,
*,
vdbe_id: int,
port: int,
main_loop: asyncio.AbstractEventLoop,
handler: QueryHandler,
session_id: SessionId,
) -> None:
# pylint: disable-next=import-error,no-name-in-module
import brad.native.pybind_brad_server as brad_server

# pylint: disable-next=c-extension-no-member
self._flight_sql_server = brad_server.BradFlightSqlServer()
self._flight_sql_server.init("0.0.0.0", port, self._handle_query)
self._thread = threading.Thread(
name=f"FlightSqlServer-{vdbe_id}", target=self._serve
)
self._vdbe_id = vdbe_id
self._port = port
# Important: The endpoint manager is responsible for creating and
# terminating the session.
self.session_id = session_id

self._main_loop = main_loop
self._handler = handler

def start(self) -> None:
self._thread.start()

def stop(self) -> None:
logger.info(
"BRAD FlightSQL server stopping (port %d, VDBE %d)...",
self._port,
self._vdbe_id,
)
self._flight_sql_server.shutdown()
self._thread.join()
logger.info(
"BRAD FlightSQL server stopped (port %d, VDBE %d).",
self._port,
self._vdbe_id,
)

def _serve(self) -> None:
self._flight_sql_server.serve()

def _handle_query(self, query: str) -> Tuple[RowList, Schema]:
# This method is called from a separate thread. So it's very important
# to schedule the handler on the main event loop thread.
debug_info: Dict[str, Any] = {}
future = asyncio.run_coroutine_threadsafe( # type: ignore
self._handler(query, self._vdbe_id, self.session_id, debug_info, True), # type: ignore
self._main_loop,
)
row_result, schema = future.result()
assert schema is not None

# We need to do extra processing for decimal fields since our C++
# backend expects them as strings.
decimal_fields = []
for idx, field in enumerate(schema.fields):
if field.data_type == DataType.Decimal:
decimal_fields.append(idx)

if len(decimal_fields) > 0:
new_rows = []
for row in row_result:
new_row = tuple(
str(value) if idx in decimal_fields else value
for idx, value in enumerate(row)
)
new_rows.append(new_row)
row_result = new_rows

return row_result, schema
1 change: 1 addition & 0 deletions src/brad/front_end/vdbe/vdbe_front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def __init__(
vdbe_mgr=self._vdbe_mgr,
session_mgr=self._sessions,
handler=self._run_query_impl,
config=self._config,
)
self._shutdown_event = asyncio.Event()

Expand Down