diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs index 7e2a759c5590..aa0d407113d7 100644 --- a/arrow-flight/examples/flight_sql_server.rs +++ b/arrow-flight/examples/flight_sql_server.rs @@ -16,7 +16,7 @@ // under the License. use arrow_flight::sql::{ActionCreatePreparedStatementResult, SqlInfo}; -use arrow_flight::{FlightData, HandshakeRequest, HandshakeResponse}; +use arrow_flight::{Action, FlightData, HandshakeRequest, HandshakeResponse, Ticket}; use futures::Stream; use std::pin::Pin; use tonic::transport::Server; @@ -93,179 +93,253 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn get_flight_info_statement( &self, _query: CommandStatementQuery, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_statement not implemented", + )) } + async fn get_flight_info_prepared_statement( &self, _query: CommandPreparedStatementQuery, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_prepared_statement not implemented", + )) } + async fn get_flight_info_catalogs( &self, _query: CommandGetCatalogs, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_catalogs not implemented", + )) } + async fn get_flight_info_schemas( &self, _query: CommandGetDbSchemas, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_schemas not implemented", + )) } + async fn get_flight_info_tables( &self, _query: CommandGetTables, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_tables not implemented", + )) } + async fn get_flight_info_table_types( &self, _query: CommandGetTableTypes, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_table_types not implemented", + )) } + async fn get_flight_info_sql_info( &self, _query: CommandGetSqlInfo, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_sql_info not implemented", + )) } + async fn get_flight_info_primary_keys( &self, _query: CommandGetPrimaryKeys, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_primary_keys not implemented", + )) } + async fn get_flight_info_exported_keys( &self, _query: CommandGetExportedKeys, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_exported_keys not implemented", + )) } + async fn get_flight_info_imported_keys( &self, _query: CommandGetImportedKeys, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_imported_keys not implemented", + )) } + async fn get_flight_info_cross_reference( &self, _query: CommandGetCrossReference, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_imported_keys not implemented", + )) } + // do_get async fn do_get_statement( &self, _ticket: TicketStatementQuery, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("do_get_statement not implemented")) } async fn do_get_prepared_statement( &self, _query: CommandPreparedStatementQuery, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "do_get_prepared_statement not implemented", + )) } + async fn do_get_catalogs( &self, _query: CommandGetCatalogs, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("do_get_catalogs not implemented")) } + async fn do_get_schemas( &self, _query: CommandGetDbSchemas, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("do_get_schemas not implemented")) } + async fn do_get_tables( &self, _query: CommandGetTables, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("do_get_tables not implemented")) } + async fn do_get_table_types( &self, _query: CommandGetTableTypes, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("do_get_table_types not implemented")) } + async fn do_get_sql_info( &self, _query: CommandGetSqlInfo, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("do_get_sql_info not implemented")) } + async fn do_get_primary_keys( &self, _query: CommandGetPrimaryKeys, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("do_get_primary_keys not implemented")) } + async fn do_get_exported_keys( &self, _query: CommandGetExportedKeys, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "do_get_exported_keys not implemented", + )) } + async fn do_get_imported_keys( &self, _query: CommandGetImportedKeys, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "do_get_imported_keys not implemented", + )) } + async fn do_get_cross_reference( &self, _query: CommandGetCrossReference, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "do_get_cross_reference not implemented", + )) } + // do_put async fn do_put_statement_update( &self, _ticket: CommandStatementUpdate, + _request: Request>, ) -> Result { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "do_put_statement_update not implemented", + )) } + async fn do_put_prepared_statement_query( &self, _query: CommandPreparedStatementQuery, - _request: Streaming, + _request: Request>, ) -> Result::DoPutStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "do_put_prepared_statement_query not implemented", + )) } + async fn do_put_prepared_statement_update( &self, _query: CommandPreparedStatementUpdate, - _request: Streaming, + _request: Request>, ) -> Result { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "do_put_prepared_statement_update not implemented", + )) } + // do_action async fn do_action_create_prepared_statement( &self, _query: ActionCreatePreparedStatementRequest, + _request: Request, ) -> Result { Err(Status::unimplemented("Not yet implemented")) } async fn do_action_close_prepared_statement( &self, _query: ActionClosePreparedStatementRequest, + _request: Request, ) { unimplemented!("Not yet implemented") } diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index 2d9d88638588..74676429faad 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -65,77 +65,77 @@ pub trait FlightSqlService: async fn get_flight_info_statement( &self, query: CommandStatementQuery, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; /// Get a FlightInfo for executing an already created prepared statement. async fn get_flight_info_prepared_statement( &self, query: CommandPreparedStatementQuery, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; /// Get a FlightInfo for listing catalogs. async fn get_flight_info_catalogs( &self, query: CommandGetCatalogs, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; /// Get a FlightInfo for listing schemas. async fn get_flight_info_schemas( &self, query: CommandGetDbSchemas, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; /// Get a FlightInfo for listing tables. async fn get_flight_info_tables( &self, query: CommandGetTables, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; /// Get a FlightInfo to extract information about the table types. async fn get_flight_info_table_types( &self, query: CommandGetTableTypes, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; /// Get a FlightInfo for retrieving other information (See SqlInfo). async fn get_flight_info_sql_info( &self, query: CommandGetSqlInfo, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; /// Get a FlightInfo to extract information about primary and foreign keys. async fn get_flight_info_primary_keys( &self, query: CommandGetPrimaryKeys, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; /// Get a FlightInfo to extract information about exported keys. async fn get_flight_info_exported_keys( &self, query: CommandGetExportedKeys, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; /// Get a FlightInfo to extract information about imported keys. async fn get_flight_info_imported_keys( &self, query: CommandGetImportedKeys, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; /// Get a FlightInfo to extract information about cross reference. async fn get_flight_info_cross_reference( &self, query: CommandGetCrossReference, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; // do_get @@ -144,66 +144,77 @@ pub trait FlightSqlService: async fn do_get_statement( &self, ticket: TicketStatementQuery, + request: Request, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the prepared statement query results. async fn do_get_prepared_statement( &self, query: CommandPreparedStatementQuery, + request: Request, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the list of catalogs. async fn do_get_catalogs( &self, query: CommandGetCatalogs, + request: Request, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the list of schemas. async fn do_get_schemas( &self, query: CommandGetDbSchemas, + request: Request, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the list of tables. async fn do_get_tables( &self, query: CommandGetTables, + request: Request, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the data related to the table types. async fn do_get_table_types( &self, query: CommandGetTableTypes, + request: Request, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the list of SqlInfo results. async fn do_get_sql_info( &self, query: CommandGetSqlInfo, + request: Request, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the data related to the primary and foreign keys. async fn do_get_primary_keys( &self, query: CommandGetPrimaryKeys, + request: Request, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the data related to the exported keys. async fn do_get_exported_keys( &self, query: CommandGetExportedKeys, + request: Request, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the data related to the imported keys. async fn do_get_imported_keys( &self, query: CommandGetImportedKeys, + request: Request, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the data related to the cross reference. async fn do_get_cross_reference( &self, query: CommandGetCrossReference, + request: Request, ) -> Result::DoGetStream>, Status>; // do_put @@ -212,20 +223,21 @@ pub trait FlightSqlService: async fn do_put_statement_update( &self, ticket: CommandStatementUpdate, + request: Request>, ) -> Result; /// Bind parameters to given prepared statement. async fn do_put_prepared_statement_query( &self, query: CommandPreparedStatementQuery, - request: Streaming, + request: Request>, ) -> Result::DoPutStream>, Status>; /// Execute an update SQL prepared statement. async fn do_put_prepared_statement_update( &self, query: CommandPreparedStatementUpdate, - request: Streaming, + request: Request>, ) -> Result; // do_action @@ -234,12 +246,14 @@ pub trait FlightSqlService: async fn do_action_create_prepared_statement( &self, query: ActionCreatePreparedStatementRequest, + request: Request, ) -> Result; /// Close a prepared statement. async fn do_action_close_prepared_statement( &self, query: ActionClosePreparedStatementRequest, + request: Request, ); /// Register a new SqlInfo result, making it available when calling GetSqlInfo. @@ -287,119 +301,87 @@ where &self, request: Request, ) -> Result, Status> { - let request = request.into_inner(); let any: prost_types::Any = - prost::Message::decode(&*request.cmd).map_err(decode_error_to_status)?; + Message::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?; if any.is::() { - return self - .get_flight_info_statement( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_statement(token, request).await; } if any.is::() { + let handle = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); return self - .get_flight_info_prepared_statement( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) + .get_flight_info_prepared_statement(handle, request) .await; } if any.is::() { - return self - .get_flight_info_catalogs( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_catalogs(token, request).await; } if any.is::() { - return self - .get_flight_info_schemas( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_schemas(token, request).await; } if any.is::() { - return self - .get_flight_info_tables( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_tables(token, request).await; } if any.is::() { - return self - .get_flight_info_table_types( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_table_types(token, request).await; } if any.is::() { - return self - .get_flight_info_sql_info( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_sql_info(token, request).await; } if any.is::() { - return self - .get_flight_info_primary_keys( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_primary_keys(token, request).await; } if any.is::() { - return self - .get_flight_info_exported_keys( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_exported_keys(token, request).await; } if any.is::() { - return self - .get_flight_info_imported_keys( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_imported_keys(token, request).await; } if any.is::() { - return self - .get_flight_info_cross_reference( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_cross_reference(token, request).await; } Err(Status::unimplemented(format!( @@ -419,133 +401,107 @@ where &self, request: Request, ) -> Result, Status> { - let request = request.into_inner(); - let any: prost_types::Any = - prost::Message::decode(&*request.ticket).map_err(decode_error_to_status)?; + let any: prost_types::Any = prost::Message::decode(&*request.get_ref().ticket) + .map_err(decode_error_to_status)?; if any.is::() { - return self - .do_get_statement( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_statement(token, request).await; } if any.is::() { - return self - .do_get_prepared_statement( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_prepared_statement(token, request).await; } if any.is::() { - return self - .do_get_catalogs( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_catalogs(token, request).await; } if any.is::() { - return self - .do_get_schemas( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_schemas(token, request).await; } if any.is::() { - return self - .do_get_tables( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_tables(token, request).await; } if any.is::() { - return self - .do_get_table_types( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_table_types(token, request).await; } if any.is::() { - return self - .do_get_sql_info( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_sql_info(token, request).await; } if any.is::() { - return self - .do_get_primary_keys( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_primary_keys(token, request).await; } if any.is::() { - return self - .do_get_exported_keys( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_exported_keys(token, request).await; } if any.is::() { - return self - .do_get_imported_keys( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_imported_keys(token, request).await; } if any.is::() { - return self - .do_get_cross_reference( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_cross_reference(token, request).await; } Err(Status::unimplemented(format!( "do_get: The defined request is invalid: {:?}", - String::from_utf8(request.ticket).unwrap() + String::from_utf8(request.get_ref().ticket.clone()).unwrap() ))) } async fn do_put( &self, - request: Request>, + mut request: Request>, ) -> Result, Status> { - let mut request = request.into_inner(); - let cmd = request.message().await?.unwrap(); + let cmd = request.get_mut().message().await?.unwrap(); let any: prost_types::Any = prost::Message::decode(&*cmd.flight_descriptor.unwrap().cmd) .map_err(decode_error_to_status)?; if any.is::() { - let record_count = self - .do_put_statement_update( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await?; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + let record_count = self.do_put_statement_update(token, request).await?; let result = DoPutUpdateResult { record_count }; let output = futures::stream::iter(vec![Ok(super::super::gen::PutResult { app_metadata: result.as_any().encode_to_vec(), @@ -553,23 +509,19 @@ where return Ok(Response::new(Box::pin(output))); } if any.is::() { - return self - .do_put_prepared_statement_query( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_put_prepared_statement_query(token, request).await; } if any.is::() { + let handle = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); let record_count = self - .do_put_prepared_statement_update( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) + .do_put_prepared_statement_update(handle, request) .await?; let result = DoPutUpdateResult { record_count }; let output = futures::stream::iter(vec![Ok(super::super::gen::PutResult { @@ -614,11 +566,9 @@ where &self, request: Request, ) -> Result, Status> { - let request = request.into_inner(); - - if request.r#type == CREATE_PREPARED_STATEMENT { - let any: prost_types::Any = - prost::Message::decode(&*request.body).map_err(decode_error_to_status)?; + if request.get_ref().r#type == CREATE_PREPARED_STATEMENT { + let any: prost_types::Any = Message::decode(&*request.get_ref().body) + .map_err(decode_error_to_status)?; let cmd: ActionCreatePreparedStatementRequest = any .unpack() @@ -628,15 +578,17 @@ where "Unable to unpack ActionCreatePreparedStatementRequest.", ) })?; - let stmt = self.do_action_create_prepared_statement(cmd).await?; + let stmt = self + .do_action_create_prepared_statement(cmd, request) + .await?; let output = futures::stream::iter(vec![Ok(super::super::gen::Result { body: stmt.as_any().encode_to_vec(), })]); return Ok(Response::new(Box::pin(output))); } - if request.r#type == CLOSE_PREPARED_STATEMENT { - let any: prost_types::Any = - prost::Message::decode(&*request.body).map_err(decode_error_to_status)?; + if request.get_ref().r#type == CLOSE_PREPARED_STATEMENT { + let any: prost_types::Any = Message::decode(&*request.get_ref().body) + .map_err(decode_error_to_status)?; let cmd: ActionClosePreparedStatementRequest = any .unpack() @@ -646,13 +598,13 @@ where "Unable to unpack ActionClosePreparedStatementRequest.", ) })?; - self.do_action_close_prepared_statement(cmd).await; + self.do_action_close_prepared_statement(cmd, request).await; return Ok(Response::new(Box::pin(futures::stream::empty()))); } Err(Status::invalid_argument(format!( "do_action: The defined request is invalid: {:?}", - request.r#type + request.get_ref().r#type ))) }