Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 39 additions & 14 deletions arrow-flight/src/sql/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ use arrow_ipc::{MessageHeader, root_as_message};
use arrow_schema::{ArrowError, Schema, SchemaRef};
use futures::{Stream, TryStreamExt, stream};
use prost::Message;
use tonic::transport::Channel;
use tonic::codegen::{Body, StdError};
use tonic::{IntoRequest, IntoStreamingRequest, Streaming};

/// A FlightSQLServiceClient is an endpoint for retrieving or storing Arrow data
/// by FlightSQL protocol.
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct FlightSqlServiceClient<T> {
token: Option<String>,
headers: HashMap<String, String>,
Expand All @@ -71,14 +71,20 @@ pub struct FlightSqlServiceClient<T> {
/// A FlightSql protocol client that can run queries against FlightSql servers
/// This client is in the "experimental" stage. It is not guaranteed to follow the spec in all instances.
/// Github issues are welcomed.
impl FlightSqlServiceClient<Channel> {
impl<T> FlightSqlServiceClient<T>
where
T: tonic::client::GrpcService<tonic::body::Body>,
T::Error: Into<StdError>,
T::ResponseBody: Body<Data = Bytes> + Send + 'static,
<T::ResponseBody as Body>::Error: Into<StdError> + Send,
{
/// Creates a new FlightSql client that connects to a server over an arbitrary tonic `Channel`
pub fn new(channel: Channel) -> Self {
pub fn new(channel: T) -> Self {
Self::new_from_inner(FlightServiceClient::new(channel))
}

/// Creates a new higher level client with the provided lower level client
pub fn new_from_inner(inner: FlightServiceClient<Channel>) -> Self {
pub fn new_from_inner(inner: FlightServiceClient<T>) -> Self {
Self {
token: None,
flight_client: inner,
Expand All @@ -87,17 +93,17 @@ impl FlightSqlServiceClient<Channel> {
}

/// Return a reference to the underlying [`FlightServiceClient`]
pub fn inner(&self) -> &FlightServiceClient<Channel> {
pub fn inner(&self) -> &FlightServiceClient<T> {
&self.flight_client
}

/// Return a mutable reference to the underlying [`FlightServiceClient`]
pub fn inner_mut(&mut self) -> &mut FlightServiceClient<Channel> {
pub fn inner_mut(&mut self) -> &mut FlightServiceClient<T> {
&mut self.flight_client
}

/// Consume this client and return the underlying [`FlightServiceClient`]
pub fn into_inner(self) -> FlightServiceClient<Channel> {
pub fn into_inner(self) -> FlightServiceClient<T> {
self.flight_client
}

Expand Down Expand Up @@ -416,7 +422,10 @@ impl FlightSqlServiceClient<Channel> {
&mut self,
query: String,
transaction_id: Option<Bytes>,
) -> Result<PreparedStatement<Channel>, ArrowError> {
) -> Result<PreparedStatement<T>, ArrowError>
where
T: Clone,
{
let cmd = ActionCreatePreparedStatementRequest {
query,
transaction_id,
Expand Down Expand Up @@ -509,10 +518,10 @@ impl FlightSqlServiceClient<Channel> {
Ok(())
}

fn set_request_headers<T>(
fn set_request_headers<M>(
&self,
mut req: tonic::Request<T>,
) -> Result<tonic::Request<T>, ArrowError> {
mut req: tonic::Request<M>,
) -> Result<tonic::Request<M>, ArrowError> {
for (k, v) in &self.headers {
let k = AsciiMetadataKey::from_str(k.as_str()).map_err(|e| {
ArrowError::ParseError(format!("Cannot convert header key \"{k}\": {e}"))
Expand All @@ -532,6 +541,16 @@ impl FlightSqlServiceClient<Channel> {
}
}

impl<T: Clone> Clone for FlightSqlServiceClient<T> {
fn clone(&self) -> Self {
Self {
headers: self.headers.clone(),
token: self.token.clone(),
flight_client: self.flight_client.clone(),
}
}
}

/// A PreparedStatement
#[derive(Debug, Clone)]
pub struct PreparedStatement<T> {
Expand All @@ -542,9 +561,15 @@ pub struct PreparedStatement<T> {
parameter_schema: Schema,
}

impl PreparedStatement<Channel> {
impl<T> PreparedStatement<T>
where
T: tonic::client::GrpcService<tonic::body::Body>,
T::Error: Into<StdError>,
T::ResponseBody: Body<Data = Bytes> + Send + 'static,
<T::ResponseBody as Body>::Error: Into<StdError> + Send,
{
pub(crate) fn new(
flight_client: FlightSqlServiceClient<Channel>,
flight_client: FlightSqlServiceClient<T>,
handle: impl Into<Bytes>,
dataset_schema: Schema,
parameter_schema: Schema,
Expand Down
Loading