From 70ba71688b74dcf718fb4dd9496a59b5617243bb Mon Sep 17 00:00:00 2001 From: "Alina (Xi) Li" Date: Fri, 28 Nov 2025 14:23:12 -0800 Subject: [PATCH 1/2] Extract implementation of gh-46574 Address more feedback Avoid using "using" in Headers Add `server->Wait` call Co-Authored-By: justing-bq --- cpp/src/arrow/flight/sql/odbc/odbc_api.cc | 8 +- .../flight/sql/odbc/odbc_impl/CMakeLists.txt | 4 +- .../sql/odbc/odbc_impl/attribute_utils.h | 34 ++--- .../sql/odbc/odbc_impl/encoding_utils.h | 27 ++-- .../odbc/odbc_impl/flight_sql_auth_method.cc | 4 +- .../odbc/odbc_impl/flight_sql_connection.cc | 21 ++- .../odbc_impl/flight_sql_get_tables_reader.cc | 4 +- .../odbc_impl/flight_sql_get_tables_reader.h | 6 +- .../flight_sql_get_type_info_reader.cc | 1 + .../flight_sql_get_type_info_reader.h | 26 ++-- .../odbc/odbc_impl/flight_sql_result_set.cc | 6 +- .../odbc/odbc_impl/flight_sql_result_set.h | 1 + .../odbc/odbc_impl/flight_sql_statement.cc | 78 ++++++----- .../sql/odbc/odbc_impl/flight_sql_statement.h | 3 +- .../flight_sql_statement_get_tables.cc | 47 ++++--- .../flight_sql_statement_get_tables.h | 27 ++-- .../flight_sql_stream_chunk_buffer.cc | 51 +++++-- .../flight_sql_stream_chunk_buffer.h | 5 +- .../flight_sql_stream_chunk_buffer_test.cc | 131 ++++++++++++++++++ .../sql/odbc/odbc_impl/get_info_cache.cc | 15 +- .../sql/odbc/odbc_impl/get_info_cache.h | 4 +- .../flight/sql/odbc/odbc_impl/spi/statement.h | 4 +- .../flight/sql/odbc/odbc_impl/system_dsn.h | 8 +- .../arrow/flight/sql/odbc/odbc_impl/util.cc | 1 + .../arrow/flight/sql/odbc/odbc_impl/util.h | 33 +++-- .../flight/sql/odbc/tests/odbc_test_suite.cc | 8 +- 26 files changed, 370 insertions(+), 187 deletions(-) create mode 100644 cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc index 19c6e62eba4..336d7d4f217 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc @@ -738,7 +738,7 @@ SQLRETURN SQLSetConnectAttr(SQLHDBC conn, SQLINTEGER attr, SQLPOINTER value_ptr, // entries in the properties. void LoadPropertiesFromDSN(const std::string& dsn, Connection::ConnPropertyMap& properties) { - arrow::flight::sql::odbc::config::Configuration config; + config::Configuration config; config.LoadDsn(dsn); Connection::ConnPropertyMap dsn_properties = config.GetProperties(); for (auto& [key, value] : dsn_properties) { @@ -796,7 +796,7 @@ SQLRETURN SQLDriverConnect(SQLHDBC conn, SQLHWND window_handle, // Load the DSN window according to driver_completion if (driver_completion == SQL_DRIVER_PROMPT) { // Load DSN window before first attempt to connect - arrow::flight::sql::odbc::config::Configuration config; + config::Configuration config; if (!DisplayConnectionWindow(window_handle, config, properties)) { return static_cast(SQL_NO_DATA); } @@ -809,7 +809,7 @@ SQLRETURN SQLDriverConnect(SQLHDBC conn, SQLHWND window_handle, // If first connection fails due to missing attributes, load // the DSN window and try to connect again if (!missing_properties.empty()) { - arrow::flight::sql::odbc::config::Configuration config; + config::Configuration config; missing_properties.clear(); if (!DisplayConnectionWindow(window_handle, config, properties)) { @@ -855,7 +855,7 @@ SQLRETURN SQLConnect(SQLHDBC conn, SQLWCHAR* dsn_name, SQLSMALLINT dsn_name_len, ODBCConnection* connection = reinterpret_cast(conn); std::string dsn = SqlWcharToString(dsn_name, dsn_name_len); - Configuration config; + config::Configuration config; config.LoadDsn(dsn); if (user_name) { diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt b/cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt index 8f09fccd71d..71a315660bf 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt @@ -164,9 +164,11 @@ add_arrow_test(odbc_spi_impl_test accessors/time_array_accessor_test.cc accessors/timestamp_array_accessor_test.cc flight_sql_connection_test.cc + flight_sql_stream_chunk_buffer_test.cc parse_table_types_test.cc json_converter_test.cc record_batch_transformer_test.cc util_test.cc EXTRA_LINK_LIBS - arrow_odbc_spi_impl) + arrow_odbc_spi_impl + arrow_flight_testing_shared) diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/attribute_utils.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/attribute_utils.h index 8c5eae59f7e..315d854b60d 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/attribute_utils.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/attribute_utils.h @@ -30,10 +30,6 @@ // GH-48083 TODO: replace `namespace ODBC` with `namespace arrow::flight::sql::odbc` namespace ODBC { -using arrow::flight::sql::odbc::Diagnostics; -using arrow::flight::sql::odbc::DriverException; -using arrow::flight::sql::odbc::WcsToUtf8; - template inline void GetAttribute(T attribute_value, SQLPOINTER output, O output_size, O* output_len_ptr) { @@ -70,7 +66,7 @@ inline SQLRETURN GetAttributeUTF8(std::string_view attribute_value, SQLPOINTER o template inline SQLRETURN GetAttributeUTF8(std::string_view attribute_value, SQLPOINTER output, O output_size, O* output_len_ptr, - Diagnostics& diagnostics) { + arrow::flight::sql::odbc::Diagnostics& diagnostics) { SQLRETURN result = GetAttributeUTF8(attribute_value, output, output_size, output_len_ptr); if (SQL_SUCCESS_WITH_INFO == result) { @@ -85,10 +81,11 @@ inline SQLRETURN GetAttributeSQLWCHAR(std::string_view attribute_value, O output_size, O* output_len_ptr) { size_t length = ConvertToSqlWChar( attribute_value, reinterpret_cast(output), - is_length_in_bytes ? output_size : output_size * GetSqlWCharSize()); + is_length_in_bytes ? output_size + : output_size * arrow::flight::sql::odbc::GetSqlWCharSize()); if (!is_length_in_bytes) { - length = length / GetSqlWCharSize(); + length = length / arrow::flight::sql::odbc::GetSqlWCharSize(); } if (output_len_ptr) { @@ -97,17 +94,19 @@ inline SQLRETURN GetAttributeSQLWCHAR(std::string_view attribute_value, if (output && output_size < - static_cast(length + (is_length_in_bytes ? GetSqlWCharSize() : 1))) { + static_cast(length + (is_length_in_bytes + ? arrow::flight::sql::odbc::GetSqlWCharSize() + : 1))) { return SQL_SUCCESS_WITH_INFO; } return SQL_SUCCESS; } template -inline SQLRETURN GetAttributeSQLWCHAR(std::string_view attribute_value, - bool is_length_in_bytes, SQLPOINTER output, - O output_size, O* output_len_ptr, - Diagnostics& diagnostics) { +inline SQLRETURN GetAttributeSQLWCHAR( + const std::string& attribute_value, bool is_length_in_bytes, SQLPOINTER output, + O output_size, O* output_len_ptr, + arrow::flight::sql::odbc::Diagnostics& diagnostics) { SQLRETURN result = GetAttributeSQLWCHAR(attribute_value, is_length_in_bytes, output, output_size, output_len_ptr); if (SQL_SUCCESS_WITH_INFO == result) { @@ -120,7 +119,7 @@ template inline SQLRETURN GetStringAttribute(bool is_unicode, std::string_view attribute_value, bool is_length_in_bytes, SQLPOINTER output, O output_size, O* output_len_ptr, - Diagnostics& diagnostics) { + arrow::flight::sql::odbc::Diagnostics& diagnostics) { SQLRETURN result = SQL_SUCCESS; if (is_unicode) { result = GetAttributeSQLWCHAR(attribute_value, is_length_in_bytes, output, @@ -158,9 +157,11 @@ inline void SetAttributeSQLWCHAR(SQLPOINTER new_value, SQLINTEGER input_length_i std::string& attribute_to_write) { thread_local std::vector utf8_str; if (input_length_in_bytes == SQL_NTS) { - WcsToUtf8(new_value, &utf8_str); + arrow::flight::sql::odbc::WcsToUtf8(new_value, &utf8_str); } else { - WcsToUtf8(new_value, input_length_in_bytes / GetSqlWCharSize(), &utf8_str); + arrow::flight::sql::odbc::WcsToUtf8( + new_value, input_length_in_bytes / arrow::flight::sql::odbc::GetSqlWCharSize(), + &utf8_str); } attribute_to_write.assign((char*)utf8_str.data()); } @@ -168,7 +169,8 @@ inline void SetAttributeSQLWCHAR(SQLPOINTER new_value, SQLINTEGER input_length_i template void CheckIfAttributeIsSetToOnlyValidValue(SQLPOINTER value, T allowed_value) { if (static_cast(reinterpret_cast(value)) != allowed_value) { - throw DriverException("Optional feature not implemented", "HYC00"); + throw arrow::flight::sql::odbc::DriverException("Optional feature not implemented", + "HYC00"); } } } // namespace ODBC diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h index 66e5c3bf0d8..7f8a4a7ef85 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h @@ -32,17 +32,12 @@ namespace ODBC { -using arrow::flight::sql::odbc::DriverException; -using arrow::flight::sql::odbc::GetSqlWCharSize; -using arrow::flight::sql::odbc::Utf8ToWcs; -using arrow::flight::sql::odbc::WcsToUtf8; - // Return the number of bytes required for the conversion. template inline size_t ConvertToSqlWChar(std::string_view str, SQLWCHAR* buffer, SQLLEN buffer_size_in_bytes) { thread_local std::vector wstr; - Utf8ToWcs(str.data(), str.size(), &wstr); + arrow::flight::sql::odbc::Utf8ToWcs(str.data(), str.size(), &wstr); SQLLEN value_length_in_bytes = wstr.size(); if (buffer) { @@ -51,11 +46,14 @@ inline size_t ConvertToSqlWChar(std::string_view str, SQLWCHAR* buffer, // Write a NUL terminator if (buffer_size_in_bytes >= - value_length_in_bytes + static_cast(GetSqlWCharSize())) { - reinterpret_cast(buffer)[value_length_in_bytes / GetSqlWCharSize()] = + value_length_in_bytes + + static_cast(arrow::flight::sql::odbc::GetSqlWCharSize())) { + reinterpret_cast( + buffer)[value_length_in_bytes / arrow::flight::sql::odbc::GetSqlWCharSize()] = '\0'; } else { - SQLLEN num_chars_written = buffer_size_in_bytes / GetSqlWCharSize(); + SQLLEN num_chars_written = + buffer_size_in_bytes / arrow::flight::sql::odbc::GetSqlWCharSize(); // If we failed to even write one char, the buffer is too small to hold a // NUL-terminator. if (num_chars_written > 0) { @@ -68,15 +66,16 @@ inline size_t ConvertToSqlWChar(std::string_view str, SQLWCHAR* buffer, inline size_t ConvertToSqlWChar(std::string_view str, SQLWCHAR* buffer, SQLLEN buffer_size_in_bytes) { - switch (GetSqlWCharSize()) { + switch (arrow::flight::sql::odbc::GetSqlWCharSize()) { case sizeof(char16_t): return ConvertToSqlWChar(str, buffer, buffer_size_in_bytes); case sizeof(char32_t): return ConvertToSqlWChar(str, buffer, buffer_size_in_bytes); default: assert(false); - throw DriverException("Encoding is unsupported, SQLWCHAR size: " + - std::to_string(GetSqlWCharSize())); + throw arrow::flight::sql::odbc::DriverException( + "Encoding is unsupported, SQLWCHAR size: " + + std::to_string(arrow::flight::sql::odbc::GetSqlWCharSize())); } } @@ -92,9 +91,9 @@ inline std::string SqlWcharToString(SQLWCHAR* wchar_msg, SQLINTEGER msg_len = SQ thread_local std::vector utf8_str; if (msg_len == SQL_NTS) { - WcsToUtf8((void*)wchar_msg, &utf8_str); + arrow::flight::sql::odbc::WcsToUtf8((void*)wchar_msg, &utf8_str); } else { - WcsToUtf8((void*)wchar_msg, msg_len, &utf8_str); + arrow::flight::sql::odbc::WcsToUtf8((void*)wchar_msg, msg_len, &utf8_str); } return std::string(utf8_str.begin(), utf8_str.end()); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_auth_method.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_auth_method.cc index bdf7f71589c..b0090a8cf74 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_auth_method.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_auth_method.cc @@ -44,8 +44,8 @@ class NoOpClientAuthHandler : public ClientAuthHandler { NoOpClientAuthHandler() {} Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override { - // Write a blank string. The server should ignore this and just accept any Handshake - // request. + // The server should ignore this and just accept any Handshake + // request. Some servers do not allow authentication with no handshakes. return outgoing->Write(std::string()); } diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc index e18a58d069f..0a00afd7f5e 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc @@ -99,14 +99,9 @@ inline std::string GetCerts() { return ""; } #endif -// Case insensitive comparator that takes string_view -struct CaseInsensitiveComparatorStrView { - bool operator()(std::string_view s1, std::string_view s2) const { - return boost::lexicographical_compare(s1, s2, boost::is_iless()); - } -}; - -const std::set BUILT_IN_PROPERTIES = { +const std::set BUILT_IN_PROPERTIES = { + FlightSqlConnection::DRIVER, + FlightSqlConnection::DSN, FlightSqlConnection::HOST, FlightSqlConnection::PORT, FlightSqlConnection::USER, @@ -160,14 +155,14 @@ void FlightSqlConnection::Connect(const ConnPropertyMap& properties, auto flight_ssl_configs = LoadFlightSslConfigs(properties); Location location = BuildLocation(properties, missing_attr, flight_ssl_configs); - FlightClientOptions client_options = + client_options_ = BuildFlightClientOptions(properties, missing_attr, flight_ssl_configs); const std::shared_ptr& cookie_factory = GetCookieFactory(); - client_options.middleware.push_back(cookie_factory); + client_options_.middleware.push_back(cookie_factory); std::unique_ptr flight_client; - ThrowIfNotOK(FlightClient::Connect(location, client_options).Value(&flight_client)); + ThrowIfNotOK(FlightClient::Connect(location, client_options_).Value(&flight_client)); PopulateMetadataSettings(properties); PopulateCallOptions(properties); @@ -370,7 +365,7 @@ void FlightSqlConnection::Close() { std::shared_ptr FlightSqlConnection::CreateStatement() { return std::shared_ptr(new FlightSqlStatement( - diagnostics_, *sql_client_, call_options_, metadata_settings_)); + diagnostics_, *sql_client_, client_options_, call_options_, metadata_settings_)); } bool FlightSqlConnection::SetAttribute(Connection::AttributeId attribute, @@ -416,7 +411,7 @@ FlightSqlConnection::FlightSqlConnection(OdbcVersion odbc_version, const std::string& driver_version) : diagnostics_("Apache Arrow", "Flight SQL", odbc_version), odbc_version_(odbc_version), - info_(call_options_, sql_client_, driver_version), + info_(client_options_, call_options_, sql_client_, driver_version), closed_(true) { attribute_[CONNECTION_DEAD] = static_cast(SQL_TRUE); attribute_[LOGIN_TIMEOUT] = static_cast(0); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.cc index 5fe6069648f..ebff8c40f2c 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.cc @@ -36,7 +36,7 @@ GetTablesReader::GetTablesReader(std::shared_ptr record_batch) bool GetTablesReader::Next() { return ++current_row_ < record_batch_->num_rows(); } -optional GetTablesReader::GetCatalogName() { +std::optional GetTablesReader::GetCatalogName() { const auto& array = checked_pointer_cast(record_batch_->column(0)); if (array->IsNull(current_row_)) return nullopt; @@ -44,7 +44,7 @@ optional GetTablesReader::GetCatalogName() { return array->GetString(current_row_); } -optional GetTablesReader::GetDbSchemaName() { +std::optional GetTablesReader::GetDbSchemaName() { const auto& array = checked_pointer_cast(record_batch_->column(1)); if (array->IsNull(current_row_)) return nullopt; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.h index 6cc464d072b..ad9739d87bb 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.h @@ -20,8 +20,6 @@ namespace arrow::flight::sql::odbc { -using std::optional; - class GetTablesReader { private: std::shared_ptr record_batch_; @@ -32,9 +30,9 @@ class GetTablesReader { bool Next(); - optional GetCatalogName(); + std::optional GetCatalogName(); - optional GetDbSchemaName(); + std::optional GetDbSchemaName(); std::string GetTableName(); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_type_info_reader.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_type_info_reader.cc index 7f290096e5a..13115c88dbd 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_type_info_reader.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_type_info_reader.cc @@ -28,6 +28,7 @@ namespace arrow::flight::sql::odbc { using arrow::internal::checked_pointer_cast; using std::nullopt; +using std::optional; GetTypeInfoReader::GetTypeInfoReader(std::shared_ptr record_batch) : record_batch_(std::move(record_batch)), current_row_(-1) {} diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_type_info_reader.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_type_info_reader.h index a7c1d51182f..ce38a925ae1 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_type_info_reader.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_type_info_reader.h @@ -20,8 +20,6 @@ namespace arrow::flight::sql::odbc { -using std::optional; - class GetTypeInfoReader { private: std::shared_ptr record_batch_; @@ -36,13 +34,13 @@ class GetTypeInfoReader { int32_t GetDataType(); - optional GetColumnSize(); + std::optional GetColumnSize(); - optional GetLiteralPrefix(); + std::optional GetLiteralPrefix(); - optional GetLiteralSuffix(); + std::optional GetLiteralSuffix(); - optional> GetCreateParams(); + std::optional> GetCreateParams(); int32_t GetNullable(); @@ -50,25 +48,25 @@ class GetTypeInfoReader { int32_t GetSearchable(); - optional GetUnsignedAttribute(); + std::optional GetUnsignedAttribute(); bool GetFixedPrecScale(); - optional GetAutoIncrement(); + std::optional GetAutoIncrement(); - optional GetLocalTypeName(); + std::optional GetLocalTypeName(); - optional GetMinimumScale(); + std::optional GetMinimumScale(); - optional GetMaximumScale(); + std::optional GetMaximumScale(); int32_t GetSqlDataType(); - optional GetDatetimeSubcode(); + std::optional GetDatetimeSubcode(); - optional GetNumPrecRadix(); + std::optional GetNumPrecRadix(); - optional GetIntervalPrecision(); + std::optional GetIntervalPrecision(); }; } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.cc index 19149b3c48d..80967b9f200 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.cc @@ -29,12 +29,12 @@ namespace arrow::flight::sql::odbc { FlightSqlResultSet::FlightSqlResultSet( - FlightSqlClient& flight_sql_client, const FlightCallOptions& call_options, - const std::shared_ptr& flight_info, + FlightSqlClient& flight_sql_client, const FlightClientOptions& client_options, + const FlightCallOptions& call_options, const std::shared_ptr& flight_info, const std::shared_ptr& transformer, Diagnostics& diagnostics, const MetadataSettings& metadata_settings) : metadata_settings_(metadata_settings), - chunk_buffer_(flight_sql_client, call_options, flight_info, + chunk_buffer_(flight_sql_client, client_options, call_options, flight_info, metadata_settings_.chunk_buffer_capacity), transformer_(transformer), metadata_(transformer diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.h index 6083b332824..ac2ae80e010 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set.h @@ -51,6 +51,7 @@ class FlightSqlResultSet : public ResultSet { ~FlightSqlResultSet() override; FlightSqlResultSet(FlightSqlClient& flight_sql_client, + const FlightClientOptions& client_options, const FlightCallOptions& call_options, const std::shared_ptr& flight_info, const std::shared_ptr& transformer, diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc index 30eb1fdf44a..785a04c7b0e 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc @@ -41,9 +41,10 @@ using util::ThrowIfNotOK; namespace { -void ClosePreparedStatementIfAny(std::shared_ptr& prepared_statement) { +void ClosePreparedStatementIfAny(std::shared_ptr& prepared_statement, + const FlightCallOptions& options) { if (prepared_statement != nullptr) { - ThrowIfNotOK(prepared_statement->Close()); + ThrowIfNotOK(prepared_statement->Close(options)); prepared_statement.reset(); } } @@ -52,11 +53,13 @@ void ClosePreparedStatementIfAny(std::shared_ptr& prepared_st FlightSqlStatement::FlightSqlStatement(const Diagnostics& diagnostics, FlightSqlClient& sql_client, + FlightClientOptions client_options, FlightCallOptions call_options, const MetadataSettings& metadata_settings) : diagnostics_("Apache Arrow", diagnostics.GetDataSourceComponent(), diagnostics.GetOdbcVersion()), sql_client_(sql_client), + client_options_(std::move(client_options)), call_options_(std::move(call_options)), metadata_settings_(metadata_settings) { attribute_[METADATA_ID] = static_cast(SQL_FALSE); @@ -97,7 +100,7 @@ boost::optional FlightSqlStatement::GetAttribute( boost::optional> FlightSqlStatement::Prepare( const std::string& query) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.Prepare(call_options_, query); @@ -111,27 +114,30 @@ boost::optional> FlightSqlStatement::Prepare( } bool FlightSqlStatement::ExecutePrepared() { + // GH-47990 TODO: use DCHECK instead of assert assert(prepared_statement_.get() != nullptr); - Result> result = prepared_statement_->Execute(); + Result> result = + prepared_statement_->Execute(call_options_); + ThrowIfNotOK(result.status()); current_result_set_ = std::make_shared( - sql_client_, call_options_, result.ValueOrDie(), nullptr, diagnostics_, - metadata_settings_); + sql_client_, client_options_, call_options_, result.ValueOrDie(), nullptr, + diagnostics_, metadata_settings_); return true; } bool FlightSqlStatement::Execute(const std::string& query) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.Execute(call_options_, query); ThrowIfNotOK(result.status()); current_result_set_ = std::make_shared( - sql_client_, call_options_, result.ValueOrDie(), nullptr, diagnostics_, - metadata_settings_); + sql_client_, client_options_, call_options_, result.ValueOrDie(), nullptr, + diagnostics_, metadata_settings_); return true; } @@ -146,33 +152,35 @@ std::shared_ptr FlightSqlStatement::GetTables( const std::string* catalog_name, const std::string* schema_name, const std::string* table_name, const std::string* table_type, const ColumnNames& column_names) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); std::vector table_types; if ((catalog_name && *catalog_name == "%") && (schema_name && schema_name->empty()) && (table_name && table_name->empty())) { - current_result_set_ = GetTablesForSQLAllCatalogs( - column_names, call_options_, sql_client_, diagnostics_, metadata_settings_); + current_result_set_ = + GetTablesForSQLAllCatalogs(column_names, client_options_, call_options_, + sql_client_, diagnostics_, metadata_settings_); } else if ((catalog_name && catalog_name->empty()) && (schema_name && *schema_name == "%") && (table_name && table_name->empty())) { - current_result_set_ = - GetTablesForSQLAllDbSchemas(column_names, call_options_, sql_client_, schema_name, - diagnostics_, metadata_settings_); + current_result_set_ = GetTablesForSQLAllDbSchemas( + column_names, client_options_, call_options_, sql_client_, schema_name, + diagnostics_, metadata_settings_); } else if ((catalog_name && catalog_name->empty()) && (schema_name && schema_name->empty()) && (table_name && table_name->empty()) && (table_type && *table_type == "%")) { - current_result_set_ = GetTablesForSQLAllTableTypes( - column_names, call_options_, sql_client_, diagnostics_, metadata_settings_); + current_result_set_ = + GetTablesForSQLAllTableTypes(column_names, client_options_, call_options_, + sql_client_, diagnostics_, metadata_settings_); } else { if (table_type) { ParseTableTypes(*table_type, table_types); } current_result_set_ = GetTablesForGenericUse( - column_names, call_options_, sql_client_, catalog_name, schema_name, table_name, - table_types, diagnostics_, metadata_settings_); + column_names, client_options_, call_options_, sql_client_, catalog_name, + schema_name, table_name, table_types, diagnostics_, metadata_settings_); } return current_result_set_; @@ -199,7 +207,7 @@ std::shared_ptr FlightSqlStatement::GetTables_V3( std::shared_ptr FlightSqlStatement::GetColumns_V2( const std::string* catalog_name, const std::string* schema_name, const std::string* table_name, const std::string* column_name) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.GetTables( call_options_, catalog_name, schema_name, table_name, true, nullptr); @@ -210,9 +218,9 @@ std::shared_ptr FlightSqlStatement::GetColumns_V2( auto transformer = std::make_shared( metadata_settings_, OdbcVersion::V_2, column_name); - current_result_set_ = - std::make_shared(sql_client_, call_options_, flight_info, - transformer, diagnostics_, metadata_settings_); + current_result_set_ = std::make_shared( + sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_, + metadata_settings_); return current_result_set_; } @@ -220,7 +228,7 @@ std::shared_ptr FlightSqlStatement::GetColumns_V2( std::shared_ptr FlightSqlStatement::GetColumns_V3( const std::string* catalog_name, const std::string* schema_name, const std::string* table_name, const std::string* column_name) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.GetTables( call_options_, catalog_name, schema_name, table_name, true, nullptr); @@ -231,15 +239,15 @@ std::shared_ptr FlightSqlStatement::GetColumns_V3( auto transformer = std::make_shared( metadata_settings_, OdbcVersion::V_3, column_name); - current_result_set_ = - std::make_shared(sql_client_, call_options_, flight_info, - transformer, diagnostics_, metadata_settings_); + current_result_set_ = std::make_shared( + sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_, + metadata_settings_); return current_result_set_; } std::shared_ptr FlightSqlStatement::GetTypeInfo_V2(int16_t data_type) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.GetXdbcTypeInfo(call_options_); ThrowIfNotOK(result.status()); @@ -249,15 +257,15 @@ std::shared_ptr FlightSqlStatement::GetTypeInfo_V2(int16_t data_type) auto transformer = std::make_shared( metadata_settings_, OdbcVersion::V_2, data_type); - current_result_set_ = - std::make_shared(sql_client_, call_options_, flight_info, - transformer, diagnostics_, metadata_settings_); + current_result_set_ = std::make_shared( + sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_, + metadata_settings_); return current_result_set_; } std::shared_ptr FlightSqlStatement::GetTypeInfo_V3(int16_t data_type) { - ClosePreparedStatementIfAny(prepared_statement_); + ClosePreparedStatementIfAny(prepared_statement_, call_options_); Result> result = sql_client_.GetXdbcTypeInfo(call_options_); ThrowIfNotOK(result.status()); @@ -267,9 +275,9 @@ std::shared_ptr FlightSqlStatement::GetTypeInfo_V3(int16_t data_type) auto transformer = std::make_shared( metadata_settings_, OdbcVersion::V_3, data_type); - current_result_set_ = - std::make_shared(sql_client_, call_options_, flight_info, - transformer, diagnostics_, metadata_settings_); + current_result_set_ = std::make_shared( + sql_client_, client_options_, call_options_, flight_info, transformer, diagnostics_, + metadata_settings_); return current_result_set_; } diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h index 36dc245c1d7..3593b2f774d 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h @@ -32,6 +32,7 @@ class FlightSqlStatement : public Statement { private: Diagnostics diagnostics_; std::map attribute_; + FlightClientOptions client_options_; FlightCallOptions call_options_; FlightSqlClient& sql_client_; std::shared_ptr current_result_set_; @@ -46,7 +47,7 @@ class FlightSqlStatement : public Statement { public: FlightSqlStatement(const Diagnostics& diagnostics, FlightSqlClient& sql_client, - FlightCallOptions call_options, + FlightClientOptions client_options, FlightCallOptions call_options, const MetadataSettings& metadata_settings); bool SetAttribute(StatementAttributeId attribute, const Attribute& value) override; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.cc index 1af2ab42bff..87c7ac0f532 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.cc @@ -66,9 +66,9 @@ void ParseTableTypes(const std::string& table_type, } std::shared_ptr GetTablesForSQLAllCatalogs( - const ColumnNames& names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings) { + const ColumnNames& names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings) { Result> result = sql_client.GetCatalogs(call_options); std::shared_ptr schema; @@ -86,13 +86,15 @@ std::shared_ptr GetTablesForSQLAllCatalogs( .AddFieldOfNulls(names.remarks_column, arrow::utf8()) .Build(); - return std::make_shared( - sql_client, call_options, flight_info, transformer, diagnostics, metadata_settings); + return std::make_shared(sql_client, client_options, call_options, + flight_info, transformer, diagnostics, + metadata_settings); } std::shared_ptr GetTablesForSQLAllDbSchemas( - const ColumnNames& names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, const std::string* schema_name, Diagnostics& diagnostics, + const ColumnNames& names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + const std::string* schema_name, Diagnostics& diagnostics, const MetadataSettings& metadata_settings) { Result> result = sql_client.GetDbSchemas(call_options, nullptr, schema_name); @@ -112,14 +114,15 @@ std::shared_ptr GetTablesForSQLAllDbSchemas( .AddFieldOfNulls(names.remarks_column, arrow::utf8()) .Build(); - return std::make_shared( - sql_client, call_options, flight_info, transformer, diagnostics, metadata_settings); + return std::make_shared(sql_client, client_options, call_options, + flight_info, transformer, diagnostics, + metadata_settings); } std::shared_ptr GetTablesForSQLAllTableTypes( - const ColumnNames& names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings) { + const ColumnNames& names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings) { Result> result = sql_client.GetTableTypes(call_options); std::shared_ptr schema; @@ -137,16 +140,17 @@ std::shared_ptr GetTablesForSQLAllTableTypes( .AddFieldOfNulls(names.remarks_column, arrow::utf8()) .Build(); - return std::make_shared( - sql_client, call_options, flight_info, transformer, diagnostics, metadata_settings); + return std::make_shared(sql_client, client_options, call_options, + flight_info, transformer, diagnostics, + metadata_settings); } std::shared_ptr GetTablesForGenericUse( - const ColumnNames& names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, const std::string* catalog_name, - const std::string* schema_name, const std::string* table_name, - const std::vector& table_types, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings) { + const ColumnNames& names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + const std::string* catalog_name, const std::string* schema_name, + const std::string* table_name, const std::vector& table_types, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings) { Result> result = sql_client.GetTables( call_options, catalog_name, schema_name, table_name, false, &table_types); @@ -165,8 +169,9 @@ std::shared_ptr GetTablesForGenericUse( .AddFieldOfNulls(names.remarks_column, arrow::utf8()) .Build(); - return std::make_shared( - sql_client, call_options, flight_info, transformer, diagnostics, metadata_settings); + return std::make_shared(sql_client, client_options, call_options, + flight_info, transformer, diagnostics, + metadata_settings); } } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.h index 31abab91cb5..0c3ad10f97b 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement_get_tables.h @@ -40,25 +40,26 @@ void ParseTableTypes(const std::string& table_type, std::vector& table_types); std::shared_ptr GetTablesForSQLAllCatalogs( - const ColumnNames& column_names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings); + const ColumnNames& column_names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings); std::shared_ptr GetTablesForSQLAllDbSchemas( - const ColumnNames& column_names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, const std::string* schema_name, Diagnostics& diagnostics, + const ColumnNames& column_names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + const std::string* schema_name, Diagnostics& diagnostics, const MetadataSettings& metadata_settings); std::shared_ptr GetTablesForSQLAllTableTypes( - const ColumnNames& column_names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings); + const ColumnNames& column_names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings); std::shared_ptr GetTablesForGenericUse( - const ColumnNames& column_names, FlightCallOptions& call_options, - FlightSqlClient& sql_client, const std::string* catalog_name, - const std::string* schema_name, const std::string* table_name, - const std::vector& table_types, Diagnostics& diagnostics, - const MetadataSettings& metadata_settings); + const ColumnNames& column_names, FlightClientOptions& client_options, + FlightCallOptions& call_options, FlightSqlClient& sql_client, + const std::string* catalog_name, const std::string* schema_name, + const std::string* table_name, const std::vector& table_types, + Diagnostics& diagnostics, const MetadataSettings& metadata_settings); } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc index 25bf04ea507..a01a0c2407d 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.cc @@ -23,34 +23,69 @@ namespace arrow::flight::sql::odbc { using arrow::Result; FlightStreamChunkBuffer::FlightStreamChunkBuffer( - FlightSqlClient& flight_sql_client, const FlightCallOptions& call_options, - const std::shared_ptr& flight_info, size_t queue_capacity) + FlightSqlClient& flight_sql_client, const FlightClientOptions& client_options, + const FlightCallOptions& call_options, const std::shared_ptr& flight_info, + size_t queue_capacity) : queue_(queue_capacity) { - // FIXME: Endpoint iteration should consider endpoints may be at different hosts for (const auto& endpoint : flight_info->endpoints()) { const Ticket& ticket = endpoint.ticket; - auto result = flight_sql_client.DoGet(call_options, ticket); + arrow::Result> result; + std::shared_ptr temp_flight_sql_client; + auto endpoint_locations = endpoint.locations; + if (endpoint_locations.empty()) { + // list of Locations needs to be empty to proceed + result = flight_sql_client.DoGet(call_options, ticket); + } else { + // If it is non-empty, the driver should create a FlightSqlClient to connect to one + // of the specified Locations directly. + + // GH-47117: Currently a new FlightClient will be made for each partition that + // returns a non-empty Location, which is then disposed of. It may be better to + // cache clients because a server may report the same Locations. It would also be + // good to identify when the reported Location is the same as the original + // connection's Location and skip creating a FlightClient in that scenario. + + std::unique_ptr temp_flight_client; + util::ThrowIfNotOK(FlightClient::Connect(endpoint_locations[0], client_options) + .Value(&temp_flight_client)); + temp_flight_sql_client = + std::make_shared(std::move(temp_flight_client)); + + result = temp_flight_sql_client->DoGet(call_options, ticket); + } + util::ThrowIfNotOK(result.status()); std::shared_ptr stream_reader_ptr(std::move(result.ValueOrDie())); - BlockingQueue>::Supplier supplier = [=] { + BlockingQueue, + std::shared_ptr>>::Supplier supplier = [=] { auto result = stream_reader_ptr->Next(); bool is_not_ok = !result.ok(); bool is_not_empty = result.ok() && (result.ValueOrDie().data != nullptr); - return boost::make_optional(is_not_ok || is_not_empty, std::move(result)); + // If result is valid, save the temp Flight SQL Client for future stream reader + // call. temp_flight_sql_client is intentionally null if the list of endpoint + // Locations is empty. + // After all data is fetched from reader, the temp client is closed. + + // gh-48084 Replace boost::optional with std::optional + return boost::make_optional( + is_not_ok || is_not_empty, + std::make_pair(std::move(result), temp_flight_sql_client)); }; queue_.AddProducer(std::move(supplier)); } } bool FlightStreamChunkBuffer::GetNext(FlightStreamChunk* chunk) { - Result result; - if (!queue_.Pop(&result)) { + std::pair, std::shared_ptr> + closeable_endpoint_stream_pair; + if (!queue_.Pop(&closeable_endpoint_stream_pair)) { return false; } + Result result = closeable_endpoint_stream_pair.first; if (!result.status().ok()) { Close(); throw DriverException(result.status().message()); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h index f59336c984d..696e67e5aa7 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h @@ -24,10 +24,13 @@ namespace arrow::flight::sql::odbc { class FlightStreamChunkBuffer { - BlockingQueue> queue_; + BlockingQueue< + std::pair, std::shared_ptr>> + queue_; public: FlightStreamChunkBuffer(FlightSqlClient& flight_sql_client, + const FlightClientOptions& client_options, const FlightCallOptions& call_options, const std::shared_ptr& flight_info, size_t queue_capacity = 5); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc new file mode 100644 index 00000000000..cbe5cd8f7e5 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer_test.cc @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/array.h" + +#include "arrow/testing/gtest_util.h" + +#include "arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h" +#include "arrow/flight/sql/odbc/odbc_impl/json_converter.h" +#include "arrow/flight/test_flight_server.h" +#include "arrow/flight/test_util.h" + +#include +#include + +namespace arrow::flight::sql::odbc { + +using arrow::Array; +using arrow::flight::FlightCallOptions; +using arrow::flight::FlightClientOptions; +using arrow::flight::FlightDescriptor; +using arrow::flight::FlightEndpoint; +using arrow::flight::Location; +using arrow::flight::Ticket; +using arrow::flight::sql::FlightSqlClient; + +class FlightStreamChunkBufferTest : public ::testing::Test { + // Sets up two mock servers for each test case. + // This is for testing endpoint iteration only. + + protected: + void SetUp() override { + // Set up server 1 + server1 = std::make_shared(); + ASSERT_OK_AND_ASSIGN(auto location1, Location::ForGrpcTcp("0.0.0.0", 0)); + arrow::flight::FlightServerOptions options1(location1); + ASSERT_OK(server1->Init(options1)); + ASSERT_OK_AND_ASSIGN(server_location1, + Location::ForGrpcTcp("localhost", server1->port())); + + // Set up server 2 + server2 = std::make_shared(); + ASSERT_OK_AND_ASSIGN(auto location2, Location::ForGrpcTcp("0.0.0.0", 0)); + arrow::flight::FlightServerOptions options2(location2); + ASSERT_OK(server2->Init(options2)); + ASSERT_OK_AND_ASSIGN(server_location2, + Location::ForGrpcTcp("localhost", server2->port())); + + // Make SQL Client that is connected to server 1 + ASSERT_OK_AND_ASSIGN(auto client, arrow::flight::FlightClient::Connect(location1)); + sql_client.reset(new FlightSqlClient(std::move(client))); + } + + void TearDown() override { + ASSERT_OK(server1->Shutdown()); + ASSERT_OK(server1->Wait()); + ASSERT_OK(server2->Shutdown()); + ASSERT_OK(server1->Wait()); + } + + public: + arrow::flight::Location server_location1; + std::shared_ptr server1; + arrow::flight::Location server_location2; + std::shared_ptr server2; + std::shared_ptr sql_client; +}; + +FlightInfo MultipleEndpointsFlightInfo(Location location1, Location location2) { + // Sever will generate random data for `ticket-ints-1` + FlightEndpoint endpoint1({Ticket{"ticket-ints-1"}, {location1}, std::nullopt, {}}); + FlightEndpoint endpoint2({Ticket{"ticket-ints-1"}, {location2}, std::nullopt, {}}); + + FlightDescriptor descr1{FlightDescriptor::PATH, "", {"examples", "ints"}}; + + auto schema1 = arrow::flight::ExampleIntSchema(); + + return arrow::flight::MakeFlightInfo(*schema1, descr1, {endpoint1, endpoint2}, 1000, + 100000, false, ""); +} + +TEST_F(FlightStreamChunkBufferTest, TestMultipleEndpointsInt) { + FlightClientOptions client_options = FlightClientOptions::Defaults(); + FlightCallOptions options; + FlightInfo info = MultipleEndpointsFlightInfo(server_location1, server_location2); + std::shared_ptr info_ptr = std::make_shared(info); + + FlightStreamChunkBuffer chunk_buffer(*sql_client, client_options, options, info_ptr); + + FlightStreamChunk current_chunk; + + // Server returns 5 batch of results from each endpoints. + // Each batch contains 8 columns + int num_chunks = 0; + while (chunk_buffer.GetNext(¤t_chunk)) { + num_chunks++; + + int num_cols = current_chunk.data->num_columns(); + ASSERT_EQ(8, num_cols); + + for (int i = 0; i < num_cols; i++) { + auto array = current_chunk.data->column(i); + // Each array has random length + ASSERT_GT(array->length(), 0); + + std::vector int_types = { + Type::type::INT8, Type::type::UINT8, Type::type::INT16, Type::type::UINT16, + Type::type::INT32, Type::type::UINT32, Type::type::INT64, Type::type::UINT64}; + ASSERT_THAT(int_types, testing::Contains(array->type_id())); + } + } + + // Verify 5 batches of data is returned by each of the two endpoints. + // In total 10 batches should be returned. + ASSERT_EQ(10, num_chunks); +} +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.cc index bf2f6b6eca2..7f6ba8042de 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.cc @@ -199,10 +199,14 @@ inline void SetDefaultIfMissing(std::unordered_map& } // namespace -GetInfoCache::GetInfoCache(FlightCallOptions& call_options, +GetInfoCache::GetInfoCache(FlightClientOptions& client_options, + FlightCallOptions& call_options, std::unique_ptr& client, const std::string& driver_version) - : call_options_(call_options), sql_client_(client), has_server_info_(false) { + : client_options_(client_options), + call_options_(call_options), + sql_client_(client), + has_server_info_(false) { info_[SQL_DRIVER_NAME] = "Arrow Flight ODBC Driver"; info_[SQL_DRIVER_VER] = util::ConvertToDBMSVer(driver_version); @@ -283,7 +287,8 @@ bool GetInfoCache::LoadInfoFromServer() { arrow::Result> result = sql_client_->GetSqlInfo(call_options_, {}); util::ThrowIfNotOK(result.status()); - FlightStreamChunkBuffer chunk_iter(*sql_client_, call_options_, result.ValueOrDie()); + FlightStreamChunkBuffer chunk_iter(*sql_client_, client_options_, call_options_, + result.ValueOrDie()); FlightStreamChunk chunk; bool supports_correlation_name = false; @@ -311,8 +316,8 @@ bool GetInfoCache::LoadInfoFromServer() { std::string server_name( reinterpret_cast(scalar->child_value().get())->view()); - // TODO: Consider creating different properties in GetSqlInfo. - // TODO: Investigate if SQL_SERVER_NAME should just be the host + // GH-47855 TODO: Consider creating different properties in GetSqlInfo. + // GH-47856 TODO: Investigate if SQL_SERVER_NAME should just be the host // address as well. In JDBC, FLIGHT_SQL_SERVER_NAME is only used for // the DatabaseProductName. info_[SQL_SERVER_NAME] = server_name; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.h index d0e0efd159f..a1452e4b466 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.h @@ -30,13 +30,15 @@ namespace arrow::flight::sql::odbc { class GetInfoCache { private: std::unordered_map info_; + FlightClientOptions& client_options_; FlightCallOptions& call_options_; std::unique_ptr& sql_client_; std::mutex mutex_; std::atomic has_server_info_; public: - GetInfoCache(FlightCallOptions& call_options, std::unique_ptr& client, + GetInfoCache(FlightClientOptions& client_options, FlightCallOptions& call_options, + std::unique_ptr& client, const std::string& driver_version); void SetProperty(uint16_t property, Connection::Info value); Connection::Info GetInfo(uint16_t info_type); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/statement.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/statement.h index 970e447dfdc..390950e7413 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/statement.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/statement.h @@ -24,8 +24,6 @@ namespace arrow::flight::sql::odbc { -using boost::optional; - class ResultSet; class ResultSetMetadata; @@ -69,7 +67,7 @@ class Statement { /// /// \param attribute Attribute identifier to be retrieved. /// \return Value associated with the attribute. - virtual optional GetAttribute( + virtual boost::optional GetAttribute( Statement::StatementAttributeId attribute) = 0; /// \brief Prepares the statement. diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.h index 5d23c3dfcaf..f1fee84fbd4 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.h @@ -23,8 +23,6 @@ namespace arrow::flight::sql::odbc { -using config::Configuration; - #if defined _WIN32 /** * Display connection window for user to configure connection parameters. @@ -33,7 +31,7 @@ using config::Configuration; * @param config Output configuration. * @return True on success and false on fail. */ -bool DisplayConnectionWindow(void* window_parent, Configuration& config); +bool DisplayConnectionWindow(void* window_parent, config::Configuration& config); /** * For SQLDriverConnect. @@ -45,7 +43,7 @@ bool DisplayConnectionWindow(void* window_parent, Configuration& config); * @param properties Output properties. * @return True on success and false on fail. */ -bool DisplayConnectionWindow(void* window_parent, Configuration& config, +bool DisplayConnectionWindow(void* window_parent, config::Configuration& config, Connection::ConnPropertyMap& properties); #endif @@ -56,7 +54,7 @@ bool DisplayConnectionWindow(void* window_parent, Configuration& config, * @param driver Driver. * @return True on success and false on fail. */ -bool RegisterDsn(const Configuration& config, LPCWSTR driver); +bool RegisterDsn(const config::Configuration& config, LPCWSTR driver); /** * Unregister specified DSN. diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.cc index 59ee7dda565..b951fa999a3 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.cc @@ -64,6 +64,7 @@ CDataType GetDefaultCCharType(bool use_wide_char) { using std::make_optional; using std::nullopt; +using std::optional; /// \brief Returns the mapping from Arrow type to SqlDataType /// \param field the field to return the SqlDataType for diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.h index c17e77e7de8..d8097328501 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.h @@ -49,8 +49,6 @@ namespace util { typedef std::function(const std::shared_ptr&)> ArrayConvertTask; -using std::optional; - inline void ThrowIfNotOK(const Status& status) { if (!status.ok()) { throw DriverException(status.message()); @@ -63,7 +61,7 @@ inline bool CheckIfSetToOnlyValidValue(const AttributeTypeT& value, T allowed_va } template -Status AppendToBuilder(BUILDER& builder, optional opt_value) { +Status AppendToBuilder(BUILDER& builder, std::optional opt_value) { if (opt_value) { return builder.Append(*opt_value); } else { @@ -87,29 +85,30 @@ CDataType ConvertCDataTypeFromV2ToV3(int16_t data_type_v2); std::string GetTypeNameFromSqlDataType(int16_t data_type); -optional GetRadixFromSqlDataType(SqlDataType data_type); +std::optional GetRadixFromSqlDataType(SqlDataType data_type); int16_t GetNonConciseDataType(SqlDataType data_type); -optional GetSqlDateTimeSubCode(SqlDataType data_type); +std::optional GetSqlDateTimeSubCode(SqlDataType data_type); -optional GetCharOctetLength(SqlDataType data_type, - const arrow::Result& column_size, - const int32_t decimal_precison = 0); +std::optional GetCharOctetLength(SqlDataType data_type, + const arrow::Result& column_size, + const int32_t decimal_precison = 0); -optional GetBufferLength(SqlDataType data_type, - const optional& column_size); +std::optional GetBufferLength(SqlDataType data_type, + const std::optional& column_size); -optional GetLength(SqlDataType data_type, const optional& column_size); +std::optional GetLength(SqlDataType data_type, + const std::optional& column_size); -optional GetTypeScale(SqlDataType data_type, - const optional& type_scale); +std::optional GetTypeScale(SqlDataType data_type, + const std::optional& type_scale); -optional GetColumnSize(SqlDataType data_type, - const optional& column_size); +std::optional GetColumnSize(SqlDataType data_type, + const std::optional& column_size); -optional GetDisplaySize(SqlDataType data_type, - const optional& column_size); +std::optional GetDisplaySize(SqlDataType data_type, + const std::optional& column_size); std::string ConvertSqlPatternToRegexString(const std::string& pattern); diff --git a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc index dfef9dcd1d0..16766acb04c 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc +++ b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc @@ -490,22 +490,22 @@ std::wstring ConvertToWString(const std::vector& str_val, SQLSMALLINT } else { EXPECT_GT(str_len, 0); EXPECT_LE(str_len, static_cast(kOdbcBufferSize)); - attr_str = std::wstring(str_val.begin(), - str_val.begin() + str_len / ODBC::GetSqlWCharSize()); + attr_str = + std::wstring(str_val.begin(), str_val.begin() + str_len / GetSqlWCharSize()); } return attr_str; } void CheckStringColumnW(SQLHSTMT stmt, int col_id, const std::wstring& expected) { SQLWCHAR buf[1024]; - SQLLEN buf_len = sizeof(buf) * ODBC::GetSqlWCharSize(); + SQLLEN buf_len = sizeof(buf) * GetSqlWCharSize(); ASSERT_EQ(SQL_SUCCESS, SQLGetData(stmt, col_id, SQL_C_WCHAR, buf, buf_len, &buf_len)); EXPECT_GT(buf_len, 0); // returned buf_len is in bytes so convert to length in characters - size_t char_count = static_cast(buf_len) / ODBC::GetSqlWCharSize(); + size_t char_count = static_cast(buf_len) / GetSqlWCharSize(); std::wstring returned(buf, buf + char_count); EXPECT_EQ(expected, returned); From f3872a577970c0e69bc737d466a0b0e1b2aa6ed7 Mon Sep 17 00:00:00 2001 From: "Alina (Xi) Li" Date: Fri, 28 Nov 2025 15:12:44 -0800 Subject: [PATCH 2/2] trigger CI