diff --git a/config/system_config_demo_s3.yml b/config/system_config_demo_s3.yml index 545df619..fffc73b8 100644 --- a/config/system_config_demo_s3.yml +++ b/config/system_config_demo_s3.yml @@ -6,7 +6,7 @@ # listen on successive ports (e.g., 6584, 6585, etc.). front_end_interface: "0.0.0.0" front_end_port: 6783 -num_front_ends: 3 +num_front_ends: 2 # If installed and enabled, BRAD will serve its UI from a webserver that listens # for connections on this network interface and port. @@ -127,6 +127,7 @@ std_datasets: bootstrap_vdbe_path: config/vdbe_demo/imdb_etl_vdbes.json disable_query_logging: true vdbe_start_port: 10076 +flight_sql_mode: "vdbe" aurora_max_query_factor: 4.0 aurora_max_query_factor_replace: 10000.0 diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index a7fff57f..2821b11a 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -92,9 +92,9 @@ arrow::Result> ResultToRecordBatch( columns.push_back(values); } else if (field_type->Equals( - arrow::decimal(/*precision=*/10, /*scale=*/2))) { + arrow::decimal128(/*precision=*/10, /*scale=*/2))) { arrow::Decimal128Builder decimalbuilder( - arrow::decimal(/*precision=*/10, /*scale=*/2)); + arrow::decimal128(/*precision=*/10, /*scale=*/2)); for (int row_ix = 0; row_ix < num_rows; ++row_ix) { const std::optional val = py::cast>( @@ -149,6 +149,11 @@ arrow::Result> ResultToRecordBatch( std::shared_ptr values; ARROW_ASSIGN_OR_RAISE(values, nullbuilder.Finish()); columns.push_back(values); + } else { + std::cerr << "ERROR: Unsupported field type: " << field_type->ToString() + << std::endl; + return arrow::Status::NotImplemented("Unsupported field type: ", + field_type->ToString()); } } diff --git a/src/brad/config/file.py b/src/brad/config/file.py index 1bc9f551..43593245 100644 --- a/src/brad/config/file.py +++ b/src/brad/config/file.py @@ -315,6 +315,13 @@ def vdbe_start_port(self) -> int: return 9876 # Default return int(self._raw["vdbe_start_port"]) + def flight_sql_mode(self) -> Optional[str]: + try: + return self._raw["flight_sql_mode"] + except KeyError: + # FlightSQL mode is not set. + return None + def _extract_log_path(self, config_key: str) -> Optional[pathlib.Path]: if config_key not in self._raw: return None diff --git a/src/brad/connection/factory.py b/src/brad/connection/factory.py index 90a621c4..fe1424d1 100644 --- a/src/brad/connection/factory.py +++ b/src/brad/connection/factory.py @@ -29,7 +29,9 @@ async def connect_to( return cls.connect_to_stub(config) # HACK: Schema aliasing for convenience. - if schema_name is not None and schema_name == "imdb_editable_100g": + if schema_name is not None and ( + schema_name == "imdb_editable_100g" or schema_name == "imdb_etl_100g" + ): schema_name = "imdb_extended_100g" connection_details = config.get_connection_details(engine) @@ -158,7 +160,9 @@ async def connect_to_sidecar( return cls.connect_to_stub(config) # HACK: Schema aliasing for convenience. - if schema_name is not None and schema_name == "imdb_editable_100g": + if schema_name is not None and ( + schema_name == "imdb_editable_100g" or schema_name == "imdb_etl_100g" + ): schema_name = "imdb_extended_100g" connection_details = config.get_sidecar_db_details() diff --git a/src/brad/exec/cli.py b/src/brad/exec/cli.py index 9e5b8fa5..b7cc9df8 100644 --- a/src/brad/exec/cli.py +++ b/src/brad/exec/cli.py @@ -2,6 +2,7 @@ import pathlib import readline import time +import pyodbc from typing import List, Tuple from tabulate import tabulate @@ -80,6 +81,11 @@ def run_query(client: BradGrpcClient | BradFlightSqlClientOdbc, query: str) -> N print("Query resulted in an error:") print(ex.message()) print() + except pyodbc.Error as ex: + print() + print("Query resulted in an error:") + print(repr(ex)) + print() class BradShell(cmd.Cmd): @@ -145,7 +151,10 @@ def main(args) -> None: host, port = parse_endpoint(args.endpoint) print("BRAD Interactive Shell v{}".format(brad.__version__)) print() - print("Connecting to BRAD VDBE at {}:{}...".format(host, port)) + if args.use_odbc: + print("Connecting to BRAD VDBE at {}:{} (using ODBC)...".format(host, port)) + else: + print("Connecting to BRAD VDBE at {}:{}...".format(host, port)) def run_shell(client: BradGrpcClient | BradFlightSqlClientOdbc) -> None: print("Connected!") diff --git a/src/brad/front_end/front_end.py b/src/brad/front_end/front_end.py index 9ed2aea2..7cae88c3 100644 --- a/src/brad/front_end/front_end.py +++ b/src/brad/front_end/front_end.py @@ -87,7 +87,10 @@ def __init__( input_queue: mp.Queue, output_queue: mp.Queue, ): - if BradFrontEnd.native_server_is_supported(): + if ( + BradFrontEnd.native_server_is_supported() + and config.flight_sql_mode() == "front_end" + ): from brad.front_end.flight_sql_server import BradFlightSqlServer self._flight_sql_server: Optional[BradFlightSqlServer] = ( @@ -98,8 +101,12 @@ def __init__( ) ) self._flight_sql_server_session_id: Optional[SessionId] = None + logger.info( + "FlightSQL server is enabled for the front end. Will listen on port 31337." + ) else: self._flight_sql_server = None + logger.info("FlightSQL server is disabled for the front end.") self._main_thread_loop: Optional[asyncio.AbstractEventLoop] = None diff --git a/src/brad/front_end/vdbe/vdbe_endpoint_manager.py b/src/brad/front_end/vdbe/vdbe_endpoint_manager.py index 16c4814e..e085a0c7 100644 --- a/src/brad/front_end/vdbe/vdbe_endpoint_manager.py +++ b/src/brad/front_end/vdbe/vdbe_endpoint_manager.py @@ -1,10 +1,22 @@ +import asyncio import grpc import json import logging -from typing import Callable, Optional, Tuple, Dict, AsyncIterable, Any, Set, Awaitable +import threading +from typing import ( + Callable, + Optional, + Tuple, + Dict, + AsyncIterable, + Any, + Set, + Awaitable, +) import brad.proto_gen.brad_pb2_grpc as brad_grpc -from brad.connection.schema import Schema +from brad.config.file import ConfigFile +from brad.connection.schema import Schema, DataType from brad.front_end.brad_interface import BradInterface from brad.front_end.grpc import BradGrpc from brad.front_end.session import SessionManager, SessionId @@ -15,9 +27,10 @@ logger = logging.getLogger(__name__) -# (query_string, vdbe_id, session_id, debug_info) -> (rows, schema) +# (query_string, vdbe_id, session_id, debug_info, retrieve_schema) -> (rows, schema) QueryHandler = Callable[ - [str, int, SessionId, Dict[str, Any]], Awaitable[Tuple[RowList, Optional[Schema]]] + [str, int, SessionId, Dict[str, Any], bool], + Awaitable[Tuple[RowList, Optional[Schema]]], ] @@ -33,11 +46,30 @@ def __init__( vdbe_mgr: VdbeFrontEndManager, session_mgr: SessionManager, handler: QueryHandler, + config: ConfigFile, ) -> None: self._vdbe_mgr = vdbe_mgr self._session_mgr = session_mgr self._handler = handler + self._config = config self._endpoints: Dict[int, Tuple[int, grpc.aio.Server, VdbeGrpcInterface]] = {} + self._flight_sql_endpoints: Dict[int, Tuple[int, VdbeFlightSqlServer]] = {} + + try: + # pylint: disable-next=import-error,no-name-in-module,unused-import + import brad.native.pybind_brad_server as brad_server + + self._use_flight_sql = self._config.flight_sql_mode() == "vdbe" + except ImportError: + self._use_flight_sql = False + + if self._use_flight_sql: + logger.info("Will start Flight SQL endpoints for VDBEs.") + else: + logger.info( + "Flight SQL endpoints for VDBEs are not available. " + "Using gRPC endpoints only." + ) async def initialize(self) -> None: for engine in self._vdbe_mgr.engines(): @@ -66,10 +98,29 @@ async def add_vdbe_endpoint(self, port: int, vdbe_id: int) -> None: grpc_server.add_insecure_port(f"0.0.0.0:{port}") await grpc_server.start() logger.info( - "Added VDBE endpoint for ID %d. Listening on port %d.", vdbe_id, port + "Added gRPC VDBE endpoint for ID %d. Listening on port %d.", vdbe_id, port ) self._endpoints[vdbe_id] = (port, grpc_server, query_service) + if self._use_flight_sql: + session_id, _ = await self._session_mgr.create_new_session() + # The flight SQL port is offset by 10,000 from the gRPC port. + flight_sql_port = port + 10_000 + flight_sql_server = VdbeFlightSqlServer( + vdbe_id=vdbe_id, + port=flight_sql_port, + main_loop=asyncio.get_event_loop(), + handler=self._handler, + session_id=session_id, + ) + flight_sql_server.start() + self._flight_sql_endpoints[vdbe_id] = (flight_sql_port, flight_sql_server) + logger.info( + "Added Flight SQL VDBE endpoint for ID %d. Listening on port %d.", + vdbe_id, + flight_sql_port, + ) + async def remove_vdbe_endpoint(self, vdbe_id: int) -> None: try: port, grpc_server, query_service = self._endpoints[vdbe_id] @@ -77,11 +128,29 @@ async def remove_vdbe_endpoint(self, vdbe_id: int) -> None: # See `brad.front_end.BradFrontEnd.serve_forever`. grpc_server.__del__() del self._endpoints[vdbe_id] - logger.info("Removed VDBE endpoint for ID %d (was port %d).", vdbe_id, port) + logger.info( + "Removed gRPC VDBE endpoint for ID %d (was port %d).", vdbe_id, port + ) + + except KeyError: + logger.error( + "Tried to remove gRPC VDBE endpoint for ID %d, but it was not found.", + vdbe_id, + ) + try: + port, flight_sql_server = self._flight_sql_endpoints[vdbe_id] + flight_sql_server.stop() + del self._flight_sql_endpoints[vdbe_id] + await self._session_mgr.end_session(flight_sql_server.session_id) + logger.info( + "Removed Flight SQL VDBE endpoint for ID %d (was port %d).", + vdbe_id, + port, + ) except KeyError: logger.error( - "Tried to remove VDBE endpoint for ID %d, but it was not found.", + "Tried to remove Flight SQL VDBE endpoint for ID %d, but it was not found.", vdbe_id, ) @@ -144,7 +213,9 @@ async def run_query_json( This method may throw an error to indicate a problem with the query. """ - results, _ = await self._handler(query, self._vdbe_id, session_id, debug_info) + results, _ = await self._handler( + query, self._vdbe_id, session_id, debug_info, False + ) return json.dumps(results, cls=DecimalEncoder, default=str) async def end_session(self, session_id: SessionId) -> None: @@ -156,3 +227,82 @@ async def end_all_sessions(self) -> None: self._our_sessions.clear() for session_id in our_sessions: await self._session_mgr.end_session(session_id) + + +class VdbeFlightSqlServer: + def __init__( + self, + *, + vdbe_id: int, + port: int, + main_loop: asyncio.AbstractEventLoop, + handler: QueryHandler, + session_id: SessionId, + ) -> None: + # pylint: disable-next=import-error,no-name-in-module + import brad.native.pybind_brad_server as brad_server + + # pylint: disable-next=c-extension-no-member + self._flight_sql_server = brad_server.BradFlightSqlServer() + self._flight_sql_server.init("0.0.0.0", port, self._handle_query) + self._thread = threading.Thread( + name=f"FlightSqlServer-{vdbe_id}", target=self._serve + ) + self._vdbe_id = vdbe_id + self._port = port + # Important: The endpoint manager is responsible for creating and + # terminating the session. + self.session_id = session_id + + self._main_loop = main_loop + self._handler = handler + + def start(self) -> None: + self._thread.start() + + def stop(self) -> None: + logger.info( + "BRAD FlightSQL server stopping (port %d, VDBE %d)...", + self._port, + self._vdbe_id, + ) + self._flight_sql_server.shutdown() + self._thread.join() + logger.info( + "BRAD FlightSQL server stopped (port %d, VDBE %d).", + self._port, + self._vdbe_id, + ) + + def _serve(self) -> None: + self._flight_sql_server.serve() + + def _handle_query(self, query: str) -> Tuple[RowList, Schema]: + # This method is called from a separate thread. So it's very important + # to schedule the handler on the main event loop thread. + debug_info: Dict[str, Any] = {} + future = asyncio.run_coroutine_threadsafe( # type: ignore + self._handler(query, self._vdbe_id, self.session_id, debug_info, True), # type: ignore + self._main_loop, + ) + row_result, schema = future.result() + assert schema is not None + + # We need to do extra processing for decimal fields since our C++ + # backend expects them as strings. + decimal_fields = [] + for idx, field in enumerate(schema.fields): + if field.data_type == DataType.Decimal: + decimal_fields.append(idx) + + if len(decimal_fields) > 0: + new_rows = [] + for row in row_result: + new_row = tuple( + str(value) if idx in decimal_fields else value + for idx, value in enumerate(row) + ) + new_rows.append(new_row) + row_result = new_rows + + return row_result, schema diff --git a/src/brad/front_end/vdbe/vdbe_front_end.py b/src/brad/front_end/vdbe/vdbe_front_end.py index 4aac2e98..a2e03aa5 100644 --- a/src/brad/front_end/vdbe/vdbe_front_end.py +++ b/src/brad/front_end/vdbe/vdbe_front_end.py @@ -130,6 +130,7 @@ def __init__( vdbe_mgr=self._vdbe_mgr, session_mgr=self._sessions, handler=self._run_query_impl, + config=self._config, ) self._shutdown_event = asyncio.Event()