Skip to content

Commit

Permalink
Fix flight sql do put handling, add bind parameter support to FlightS…
Browse files Browse the repository at this point in the history
…QL cli client (apache#4797)

* change Streaming<FlightData> to Peekable<Streaming<FlightData>>

* add explanatory comment

* working test

* trigger pre-commit hooks?

* Update arrow-flight/src/sql/client.rs

Co-authored-by: Raphael Taylor-Davies <[email protected]>

* remove unnecessary multi-thread annotation

* rework api

---------

Co-authored-by: Raphael Taylor-Davies <[email protected]>
  • Loading branch information
suremarc and tustvold authored Sep 18, 2023
1 parent 33b881d commit 47e8a8d
Show file tree
Hide file tree
Showing 5 changed files with 392 additions and 75 deletions.
9 changes: 5 additions & 4 deletions arrow-flight/examples/flight_sql_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use arrow_flight::sql::server::PeekableFlightDataStream;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use futures::{stream, Stream, TryStreamExt};
Expand Down Expand Up @@ -602,15 +603,15 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn do_put_statement_update(
&self,
_ticket: CommandStatementUpdate,
_request: Request<Streaming<FlightData>>,
_request: Request<PeekableFlightDataStream>,
) -> Result<i64, Status> {
Ok(FAKE_UPDATE_RESULT)
}

async fn do_put_substrait_plan(
&self,
_ticket: CommandStatementSubstraitPlan,
_request: Request<Streaming<FlightData>>,
_request: Request<PeekableFlightDataStream>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_substrait_plan not implemented",
Expand All @@ -620,7 +621,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn do_put_prepared_statement_query(
&self,
_query: CommandPreparedStatementQuery,
_request: Request<Streaming<FlightData>>,
_request: Request<PeekableFlightDataStream>,
) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
Err(Status::unimplemented(
"do_put_prepared_statement_query not implemented",
Expand All @@ -630,7 +631,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn do_put_prepared_statement_update(
&self,
_query: CommandPreparedStatementUpdate,
_request: Request<Streaming<FlightData>>,
_request: Request<PeekableFlightDataStream>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_prepared_statement_update not implemented",
Expand Down
104 changes: 92 additions & 12 deletions arrow-flight/src/bin/flight_sql_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@
// specific language governing permissions and limitations
// under the License.

use std::{sync::Arc, time::Duration};
use std::{error::Error, sync::Arc, time::Duration};

use arrow_array::RecordBatch;
use arrow_cast::pretty::pretty_format_batches;
use arrow_array::{ArrayRef, Datum, RecordBatch, StringArray};
use arrow_cast::{cast_with_options, pretty::pretty_format_batches, CastOptions};
use arrow_flight::{
sql::client::FlightSqlServiceClient, utils::flight_data_to_batches, FlightData,
FlightInfo,
};
use arrow_schema::{ArrowError, Schema};
use clap::Parser;
use clap::{Parser, Subcommand};
use futures::TryStreamExt;
use tonic::transport::{Channel, ClientTlsConfig, Endpoint};
use tracing_log::log::info;
Expand Down Expand Up @@ -98,8 +99,20 @@ struct Args {
#[clap(flatten)]
client_args: ClientArgs,

/// SQL query.
query: String,
#[clap(subcommand)]
cmd: Command,
}

#[derive(Debug, Subcommand)]
enum Command {
StatementQuery {
query: String,
},
PreparedStatementQuery {
query: String,
#[clap(short, value_parser = parse_key_val)]
params: Vec<(String, String)>,
},
}

#[tokio::main]
Expand All @@ -108,12 +121,50 @@ async fn main() {
setup_logging();
let mut client = setup_client(args.client_args).await.expect("setup client");

let info = client
.execute(args.query, None)
let flight_info = match args.cmd {
Command::StatementQuery { query } => client
.execute(query, None)
.await
.expect("execute statement"),
Command::PreparedStatementQuery { query, params } => {
let mut prepared_stmt = client
.prepare(query, None)
.await
.expect("prepare statement");

if !params.is_empty() {
prepared_stmt
.set_parameters(
construct_record_batch_from_params(
&params,
prepared_stmt
.parameter_schema()
.expect("get parameter schema"),
)
.expect("construct parameters"),
)
.expect("bind parameters")
}

prepared_stmt
.execute()
.await
.expect("execute prepared statement")
}
};

let batches = execute_flight(&mut client, flight_info)
.await
.expect("prepare statement");
info!("got flight info");
.expect("read flight data");

let res = pretty_format_batches(batches.as_slice()).expect("format results");
println!("{res}");
}

async fn execute_flight(
client: &mut FlightSqlServiceClient<Channel>,
info: FlightInfo,
) -> Result<Vec<RecordBatch>, ArrowError> {
let schema = Arc::new(Schema::try_from(info.clone()).expect("valid schema"));
let mut batches = Vec::with_capacity(info.endpoint.len() + 1);
batches.push(RecordBatch::new_empty(schema));
Expand All @@ -134,8 +185,27 @@ async fn main() {
}
info!("received data");

let res = pretty_format_batches(batches.as_slice()).expect("format results");
println!("{res}");
Ok(batches)
}

fn construct_record_batch_from_params(
params: &[(String, String)],
parameter_schema: &Schema,
) -> Result<RecordBatch, ArrowError> {
let mut items = Vec::<(&String, ArrayRef)>::new();

for (name, value) in params {
let field = parameter_schema.field_with_name(name)?;
let value_as_array = StringArray::new_scalar(value);
let casted = cast_with_options(
value_as_array.get().0,
field.data_type(),
&CastOptions::default(),
)?;
items.push((name, casted))
}

RecordBatch::try_from_iter(items)
}

fn setup_logging() {
Expand Down Expand Up @@ -203,3 +273,13 @@ async fn setup_client(

Ok(client)
}

/// Parse a single key-value pair
fn parse_key_val(
s: &str,
) -> Result<(String, String), Box<dyn Error + Send + Sync + 'static>> {
let pos = s
.find('=')
.ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?;
Ok((s[..pos].parse()?, s[pos + 1..].parse()?))
}
50 changes: 47 additions & 3 deletions arrow-flight/src/sql/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ use std::collections::HashMap;
use std::str::FromStr;
use tonic::metadata::AsciiMetadataKey;

use crate::encode::FlightDataEncoderBuilder;
use crate::error::FlightError;
use crate::flight_service_client::FlightServiceClient;
use crate::sql::server::{CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT};
use crate::sql::{
Expand All @@ -32,8 +34,8 @@ use crate::sql::{
CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys,
CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo,
CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo,
CommandPreparedStatementQuery, CommandStatementQuery, CommandStatementUpdate,
DoPutUpdateResult, ProstMessageExt, SqlInfo,
CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery,
CommandStatementUpdate, DoPutUpdateResult, ProstMessageExt, SqlInfo,
};
use crate::{
Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest,
Expand Down Expand Up @@ -439,9 +441,12 @@ impl PreparedStatement<Channel> {

/// Executes the prepared statement query on the server.
pub async fn execute(&mut self) -> Result<FlightInfo, ArrowError> {
self.write_bind_params().await?;

let cmd = CommandPreparedStatementQuery {
prepared_statement_handle: self.handle.clone(),
};

let result = self
.flight_sql_client
.get_flight_info_for_command(cmd)
Expand All @@ -451,7 +456,9 @@ impl PreparedStatement<Channel> {

/// Executes the prepared statement update query on the server.
pub async fn execute_update(&mut self) -> Result<i64, ArrowError> {
let cmd = CommandPreparedStatementQuery {
self.write_bind_params().await?;

let cmd = CommandPreparedStatementUpdate {
prepared_statement_handle: self.handle.clone(),
};
let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
Expand Down Expand Up @@ -492,6 +499,36 @@ impl PreparedStatement<Channel> {
Ok(())
}

/// Submit parameters to the server, if any have been set on this prepared statement instance
async fn write_bind_params(&mut self) -> Result<(), ArrowError> {
if let Some(ref params_batch) = self.parameter_binding {
let cmd = CommandPreparedStatementQuery {
prepared_statement_handle: self.handle.clone(),
};

let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
let flight_stream_builder = FlightDataEncoderBuilder::new()
.with_flight_descriptor(Some(descriptor))
.with_schema(params_batch.schema());
let flight_data = flight_stream_builder
.build(futures::stream::iter(
self.parameter_binding.clone().map(Ok),
))
.try_collect::<Vec<_>>()
.await
.map_err(flight_error_to_arrow_error)?;

self.flight_sql_client
.do_put(stream::iter(flight_data))
.await?
.try_collect::<Vec<_>>()
.await
.map_err(status_to_arrow_error)?;
}

Ok(())
}

/// Close the prepared statement, so that this PreparedStatement can not used
/// anymore and server can free up any resources.
pub async fn close(mut self) -> Result<(), ArrowError> {
Expand All @@ -515,6 +552,13 @@ fn status_to_arrow_error(status: tonic::Status) -> ArrowError {
ArrowError::IpcError(format!("{status:?}"))
}

fn flight_error_to_arrow_error(err: FlightError) -> ArrowError {
match err {
FlightError::Arrow(e) => e,
e => ArrowError::ExternalError(Box::new(e)),
}
}

// A polymorphic structure to natively represent different types of data contained in `FlightData`
pub enum ArrowFlightData {
RecordBatch(RecordBatch),
Expand Down
Loading

0 comments on commit 47e8a8d

Please sign in to comment.