Skip to content

Commit

Permalink
Update FlightSqlService trait to proxy handshake (#2211)
Browse files Browse the repository at this point in the history
* Allow FlightSqlService to do authentication (with example)

* Formatting

* Fix doc

* Default impl

* fmt
  • Loading branch information
avantgardnerio authored Jul 28, 2022
1 parent 48cc6c3 commit acd8042
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 4 deletions.
52 changes: 50 additions & 2 deletions arrow-flight/examples/flight_sql_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
// under the License.

use arrow_flight::sql::{ActionCreatePreparedStatementResult, SqlInfo};
use arrow_flight::FlightData;
use arrow_flight::{FlightData, HandshakeRequest, HandshakeResponse};
use futures::Stream;
use std::pin::Pin;
use tonic::transport::Server;
use tonic::{Response, Status, Streaming};
use tonic::{Request, Response, Status, Streaming};

use arrow_flight::{
flight_service_server::FlightService,
Expand All @@ -41,6 +43,52 @@ pub struct FlightSqlServiceImpl {}
#[tonic::async_trait]
impl FlightSqlService for FlightSqlServiceImpl {
type FlightService = FlightSqlServiceImpl;

async fn do_handshake(
&self,
request: Request<Streaming<HandshakeRequest>>,
) -> Result<
Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send>>>,
Status,
> {
let basic = "Basic ";
let authorization = request
.metadata()
.get("authorization")
.ok_or(Status::invalid_argument("authorization field not present"))?
.to_str()
.map_err(|_| Status::invalid_argument("authorization not parsable"))?;
if !authorization.starts_with(basic) {
Err(Status::invalid_argument(format!(
"Auth type not implemented: {}",
authorization
)))?;
}
let base64 = &authorization[basic.len()..];
let bytes = base64::decode(base64)
.map_err(|_| Status::invalid_argument("authorization not parsable"))?;
let str = String::from_utf8(bytes)
.map_err(|_| Status::invalid_argument("authorization not parsable"))?;
let parts: Vec<_> = str.split(":").collect();
if parts.len() != 2 {
Err(Status::invalid_argument(format!(
"Invalid authorization header"
)))?;
}
let user = parts[0];
let pass = parts[1];
if user != "admin" || pass != "password" {
Err(Status::unauthenticated("Invalid credentials!"))?
}
let result = HandshakeResponse {
protocol_version: 0,
payload: "random_uuid_token".as_bytes().to_vec(),
};
let result = Ok(result);
let output = futures::stream::iter(vec![result]);
return Ok(Response::new(Box::pin(output)));
}

// get_flight_info
async fn get_flight_info_statement(
&self,
Expand Down
19 changes: 17 additions & 2 deletions arrow-flight/src/sql/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,20 @@ pub trait FlightSqlService:
/// When impl FlightSqlService, you can always set FlightService to Self
type FlightService: FlightService;

/// Accept authentication and return a token
/// <https://arrow.apache.org/docs/format/Flight.html#authentication>
async fn do_handshake(
&self,
_request: Request<Streaming<HandshakeRequest>>,
) -> Result<
Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send>>>,
Status,
> {
Err(Status::unimplemented(
"Handshake has no default implementation",
))
}

/// Get a FlightInfo for executing a SQL query.
async fn get_flight_info_statement(
&self,
Expand Down Expand Up @@ -256,9 +270,10 @@ where

async fn handshake(
&self,
_request: Request<Streaming<HandshakeRequest>>,
request: Request<Streaming<HandshakeRequest>>,
) -> Result<Response<Self::HandshakeStream>, Status> {
Err(Status::unimplemented("Not yet implemented"))
let res = self.do_handshake(request).await?;
Ok(res)
}

async fn list_flights(
Expand Down

0 comments on commit acd8042

Please sign in to comment.