diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs index 8b4fe477b868..7e2a759c5590 100644 --- a/arrow-flight/examples/flight_sql_server.rs +++ b/arrow-flight/examples/flight_sql_server.rs @@ -16,9 +16,11 @@ // under the License. use arrow_flight::sql::{ActionCreatePreparedStatementResult, SqlInfo}; -use arrow_flight::FlightData; +use arrow_flight::{FlightData, HandshakeRequest, HandshakeResponse}; +use futures::Stream; +use std::pin::Pin; use tonic::transport::Server; -use tonic::{Response, Status, Streaming}; +use tonic::{Request, Response, Status, Streaming}; use arrow_flight::{ flight_service_server::FlightService, @@ -41,6 +43,52 @@ pub struct FlightSqlServiceImpl {} #[tonic::async_trait] impl FlightSqlService for FlightSqlServiceImpl { type FlightService = FlightSqlServiceImpl; + + async fn do_handshake( + &self, + request: Request>, + ) -> Result< + Response> + Send>>>, + Status, + > { + let basic = "Basic "; + let authorization = request + .metadata() + .get("authorization") + .ok_or(Status::invalid_argument("authorization field not present"))? + .to_str() + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + if !authorization.starts_with(basic) { + Err(Status::invalid_argument(format!( + "Auth type not implemented: {}", + authorization + )))?; + } + let base64 = &authorization[basic.len()..]; + let bytes = base64::decode(base64) + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + let str = String::from_utf8(bytes) + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + let parts: Vec<_> = str.split(":").collect(); + if parts.len() != 2 { + Err(Status::invalid_argument(format!( + "Invalid authorization header" + )))?; + } + let user = parts[0]; + let pass = parts[1]; + if user != "admin" || pass != "password" { + Err(Status::unauthenticated("Invalid credentials!"))? + } + let result = HandshakeResponse { + protocol_version: 0, + payload: "random_uuid_token".as_bytes().to_vec(), + }; + let result = Ok(result); + let output = futures::stream::iter(vec![result]); + return Ok(Response::new(Box::pin(output))); + } + // get_flight_info async fn get_flight_info_statement( &self, diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index 87e282b103b7..2d9d88638588 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -47,6 +47,20 @@ pub trait FlightSqlService: /// When impl FlightSqlService, you can always set FlightService to Self type FlightService: FlightService; + /// Accept authentication and return a token + /// + async fn do_handshake( + &self, + _request: Request>, + ) -> Result< + Response> + Send>>>, + Status, + > { + Err(Status::unimplemented( + "Handshake has no default implementation", + )) + } + /// Get a FlightInfo for executing a SQL query. async fn get_flight_info_statement( &self, @@ -256,9 +270,10 @@ where async fn handshake( &self, - _request: Request>, + request: Request>, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + let res = self.do_handshake(request).await?; + Ok(res) } async fn list_flights(