Skip to content

Commit

Permalink
Add fire_query method to client
Browse files Browse the repository at this point in the history
  • Loading branch information
gloomweaver committed Jan 3, 2024
1 parent dce0d2a commit d81cdac
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 7 deletions.
34 changes: 29 additions & 5 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,49 @@
use crate::{flight::SqlFlightClient, prices::PricesClient, tls::new_tls_flight_channel};
use arrow_flight::decode::FlightRecordBatchStream;
use std::error::Error;
use tokio::join;
use tonic::transport::Channel;

pub struct SpiceClientConfig {
pub https_addr: String,
pub flight_channel: Channel,
pub firecache_channel: Channel,
}

impl SpiceClientConfig {
pub fn new(https_addr: String, flight_channel: Channel) -> Self {
pub fn new(https_addr: String, flight_channel: Channel, firecache_channel: Channel) -> Self {
SpiceClientConfig {
https_addr: https_addr,
flight_channel: flight_channel,
firecache_channel: firecache_channel,
}
}

pub async fn load_from_default() -> Result<SpiceClientConfig, Box<dyn Error>> {
let https_addr = "https://data.spiceai.io".to_string();
let flight_addr = "https://flight.spiceai.io".to_string();
match new_tls_flight_channel(flight_addr.clone()).await {
Err(e) => Err(e.into()),
Ok(flight_chan) => Ok(SpiceClientConfig::new(https_addr, flight_chan)),
let firecache_addr = "https://firecache.spiceai.io".to_string();

match join!(
new_tls_flight_channel(flight_addr.clone()),
new_tls_flight_channel(firecache_addr.clone())
) {
(Err(e), _) => return Err(e.into()),
(_, Err(e)) => return Err(e.into()),
(Ok(flight_chan), Ok(firecache_chan)) => {
return Ok(SpiceClientConfig::new(
https_addr,
flight_chan,
firecache_chan,
));
}
}
}
}

pub struct SpiceClient {
flight: SqlFlightClient,
firecache: SqlFlightClient,
pub prices: PricesClient,
}

Expand All @@ -38,11 +54,19 @@ impl SpiceClient {
.expect("Error Loading Client Config");
Self {
flight: SqlFlightClient::new(config.flight_channel, api_key.to_string()),
firecache: SqlFlightClient::new(config.firecache_channel, api_key.to_string()),
prices: PricesClient::new(Some(config.https_addr), api_key.to_string()),
}
}

pub async fn query(&mut self, query: &str) -> Result<FlightRecordBatchStream, Box<dyn Error>> {
self.flight.query(query).await
self.flight.query(query, false).await
}

pub async fn fire_query(
&mut self,
query: &str,
) -> Result<FlightRecordBatchStream, Box<dyn Error>> {
self.firecache.query(query, true).await
}
}
17 changes: 16 additions & 1 deletion src/flight.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::sql::client::FlightSqlServiceClient;
use arrow_flight::Ticket;

use std::error::Error;
use tonic::transport::Channel;
Expand Down Expand Up @@ -30,6 +31,7 @@ impl SqlFlightClient {
pub async fn query(
&mut self,
query: &str,
firecache: bool,
) -> std::result::Result<FlightRecordBatchStream, Box<dyn Error>> {
match self.authenticate().await {
Err(e) => return Err(e.into()),
Expand All @@ -40,7 +42,20 @@ impl SqlFlightClient {
Ok(resp) => {
for ep in resp.endpoint {
if let Some(tkt) = ep.ticket {
return self.client.do_get(tkt).await.map_err(|e| e.into());
// There seems to be an issue with ticket parsing in arrow-flight crate
// This is a workaround to fix the issue
let fixed_ticket = if firecache {
Ticket::new(
tkt.ticket
.into_iter()
.skip_while(|&x| x != b'}')
.skip(1)
.collect::<Vec<u8>>(),
)
} else {
tkt
};
return self.client.do_get(fixed_ticket).await.map_err(|e| e.into());
}
}
Err("no tickets for flight endpoint".into())
Expand Down
28 changes: 27 additions & 1 deletion tests/client_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ mod tests {
async fn test_query() {
let mut spice_client = new_client().await;
match spice_client.query(
"SELECT number, \"timestamp\", base_fee_per_gas, base_fee_per_gas / 1e9 AS base_fee_per_gas_gwei FROM eth.recent_blocks limit 10",
r#"SELECT number, "timestamp", base_fee_per_gas, base_fee_per_gas / 1e9 AS base_fee_per_gas_gwei FROM eth.recent_blocks limit 10"#,
).await {
Ok(mut flight_data_stream) => {
// Read back RecordBatches
Expand All @@ -43,6 +43,32 @@ mod tests {
};
}

#[tokio::test]
async fn test_fire_query() {
let mut spice_client = new_client().await;
match spice_client
.fire_query(r#"SELECT number, "timestamp", base_fee_per_gas, base_fee_per_gas / 1e9 AS base_fee_per_gas_gwei FROM eth.recent_blocks limit 10"#)
.await
{
Ok(mut flight_data_stream) => {
while let Some(batch) = flight_data_stream.next().await {
match batch {
Ok(batch) => {
assert_eq!(batch.num_columns(), 4);
assert_eq!(batch.num_rows(), 10);
},
Err(e) => {
assert!(false, "Error: {}", e)
}
};
}
}
Err(e) => {
assert!(false, "Error: {}", e);
}
};
}

#[tokio::test]
async fn test_query_streaming() {
let mut spice_client = new_client().await;
Expand Down

0 comments on commit d81cdac

Please sign in to comment.