From 51e521165184c9937f55b68b6f678182e456b743 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Thu, 28 Jul 2022 10:12:12 -0600 Subject: [PATCH 1/5] Allow FlightSqlService to do authentication (with example) --- arrow-flight/examples/flight_sql_server.rs | 46 +++++++++++++++++++++- arrow-flight/src/sql/server.rs | 12 +++++- 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs index 8b4fe477b868..450136b57fb1 100644 --- a/arrow-flight/examples/flight_sql_server.rs +++ b/arrow-flight/examples/flight_sql_server.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. +use std::pin::Pin; +use futures::Stream; use arrow_flight::sql::{ActionCreatePreparedStatementResult, SqlInfo}; -use arrow_flight::FlightData; +use arrow_flight::{FlightData, HandshakeRequest, HandshakeResponse}; use tonic::transport::Server; -use tonic::{Response, Status, Streaming}; +use tonic::{Request, Response, Status, Streaming}; use arrow_flight::{ flight_service_server::FlightService, @@ -41,6 +43,46 @@ pub struct FlightSqlServiceImpl {} #[tonic::async_trait] impl FlightSqlService for FlightSqlServiceImpl { type FlightService = FlightSqlServiceImpl; + + async fn do_handshake( + &self, + request: Request> + ) -> Result< + Response> + Send>>>, + Status + > { + let basic = "Basic "; + let authorization = request.metadata().get("authorization") + .ok_or(Status::invalid_argument("authorization field not present"))? + .to_str() + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + if !authorization.starts_with(basic) { + Err(Status::invalid_argument(format!("Auth type not implemented: {}", authorization)))?; + } + let base64 = &authorization[basic.len()..]; + let bytes = base64::decode(base64) + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + let str = String::from_utf8(bytes) + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + let parts: Vec<_> = str.split(":").collect(); + if parts.len() != 2 { + Err(Status::invalid_argument(format!("Invalid authorization header")))?; + } + let user = parts[0]; + let pass = parts[1]; + if user != "admin" || pass != "password" { + Err(Status::unauthenticated("Invalid credentials!"))? + } + let result = HandshakeResponse { + protocol_version: 0, + payload: "random_uuid_token".as_bytes().to_vec() + }; + let result = Ok(result); + let output = futures::stream::iter(vec![result]); + return Ok(Response::new(Box::pin(output))); + + } + // get_flight_info async fn get_flight_info_statement( &self, diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index 87e282b103b7..63c90e51b61f 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -47,6 +47,13 @@ pub trait FlightSqlService: /// When impl FlightSqlService, you can always set FlightService to Self type FlightService: FlightService; + /// Accept authentication and return a token + /// https://arrow.apache.org/docs/format/Flight.html#authentication + async fn do_handshake( + &self, + request: Request>, + ) -> Result> + Send>>>, Status>; + /// Get a FlightInfo for executing a SQL query. async fn get_flight_info_statement( &self, @@ -256,9 +263,10 @@ where async fn handshake( &self, - _request: Request>, + request: Request>, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + let res = self.do_handshake(request).await?; + Ok(res) } async fn list_flights( From 3f44901aad51d95a7911063b3b509cc8976952d5 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Thu, 28 Jul 2022 10:19:47 -0600 Subject: [PATCH 2/5] Formatting --- arrow-flight/examples/flight_sql_server.rs | 26 +++++++++++++--------- arrow-flight/src/sql/server.rs | 5 ++++- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs index 450136b57fb1..7e2a759c5590 100644 --- a/arrow-flight/examples/flight_sql_server.rs +++ b/arrow-flight/examples/flight_sql_server.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. -use std::pin::Pin; -use futures::Stream; use arrow_flight::sql::{ActionCreatePreparedStatementResult, SqlInfo}; use arrow_flight::{FlightData, HandshakeRequest, HandshakeResponse}; +use futures::Stream; +use std::pin::Pin; use tonic::transport::Server; use tonic::{Request, Response, Status, Streaming}; @@ -46,18 +46,23 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_handshake( &self, - request: Request> + request: Request>, ) -> Result< - Response> + Send>>>, - Status + Response> + Send>>>, + Status, > { let basic = "Basic "; - let authorization = request.metadata().get("authorization") + let authorization = request + .metadata() + .get("authorization") .ok_or(Status::invalid_argument("authorization field not present"))? .to_str() .map_err(|_| Status::invalid_argument("authorization not parsable"))?; if !authorization.starts_with(basic) { - Err(Status::invalid_argument(format!("Auth type not implemented: {}", authorization)))?; + Err(Status::invalid_argument(format!( + "Auth type not implemented: {}", + authorization + )))?; } let base64 = &authorization[basic.len()..]; let bytes = base64::decode(base64) @@ -66,7 +71,9 @@ impl FlightSqlService for FlightSqlServiceImpl { .map_err(|_| Status::invalid_argument("authorization not parsable"))?; let parts: Vec<_> = str.split(":").collect(); if parts.len() != 2 { - Err(Status::invalid_argument(format!("Invalid authorization header")))?; + Err(Status::invalid_argument(format!( + "Invalid authorization header" + )))?; } let user = parts[0]; let pass = parts[1]; @@ -75,12 +82,11 @@ impl FlightSqlService for FlightSqlServiceImpl { } let result = HandshakeResponse { protocol_version: 0, - payload: "random_uuid_token".as_bytes().to_vec() + payload: "random_uuid_token".as_bytes().to_vec(), }; let result = Ok(result); let output = futures::stream::iter(vec![result]); return Ok(Response::new(Box::pin(output))); - } // get_flight_info diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index 63c90e51b61f..494db8705a00 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -52,7 +52,10 @@ pub trait FlightSqlService: async fn do_handshake( &self, request: Request>, - ) -> Result> + Send>>>, Status>; + ) -> Result< + Response> + Send>>>, + Status, + >; /// Get a FlightInfo for executing a SQL query. async fn get_flight_info_statement( From 0f07d9300349c7cbad3ee35a789ace0c3a790af0 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Thu, 28 Jul 2022 10:54:39 -0600 Subject: [PATCH 3/5] Fix doc --- arrow-flight/src/sql/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index 494db8705a00..9e1523ee0e2d 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -48,7 +48,7 @@ pub trait FlightSqlService: type FlightService: FlightService; /// Accept authentication and return a token - /// https://arrow.apache.org/docs/format/Flight.html#authentication + /// async fn do_handshake( &self, request: Request>, From ba9520b1dbf83f2bbcfea71591188a0383fdfbcd Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Thu, 28 Jul 2022 12:17:21 -0600 Subject: [PATCH 4/5] Default impl --- arrow-flight/src/sql/server.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index 9e1523ee0e2d..89df1a92fc63 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -51,11 +51,13 @@ pub trait FlightSqlService: /// async fn do_handshake( &self, - request: Request>, + _request: Request>, ) -> Result< Response> + Send>>>, Status, - >; + > { + Err(Status::unimplemented("Handshake has no default implementation")) + } /// Get a FlightInfo for executing a SQL query. async fn get_flight_info_statement( From 32a6aed73673b79c57c81e70935083b22ac4e70a Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Thu, 28 Jul 2022 12:19:57 -0600 Subject: [PATCH 5/5] fmt --- arrow-flight/src/sql/server.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index 89df1a92fc63..2d9d88638588 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -56,7 +56,9 @@ pub trait FlightSqlService: Response> + Send>>>, Status, > { - Err(Status::unimplemented("Handshake has no default implementation")) + Err(Status::unimplemented( + "Handshake has no default implementation", + )) } /// Get a FlightInfo for executing a SQL query.