Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fire_query method to client #22

Merged
merged 1 commit into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,49 @@
use crate::{flight::SqlFlightClient, prices::PricesClient, tls::new_tls_flight_channel};
use arrow_flight::decode::FlightRecordBatchStream;
use std::error::Error;
use tokio::join;
use tonic::transport::Channel;

pub struct SpiceClientConfig {
pub https_addr: String,
pub flight_channel: Channel,
pub firecache_channel: Channel,
}

impl SpiceClientConfig {
pub fn new(https_addr: String, flight_channel: Channel) -> Self {
pub fn new(https_addr: String, flight_channel: Channel, firecache_channel: Channel) -> Self {
SpiceClientConfig {
https_addr: https_addr,
flight_channel: flight_channel,
firecache_channel: firecache_channel,
}
}

pub async fn load_from_default() -> Result<SpiceClientConfig, Box<dyn Error>> {
let https_addr = "https://data.spiceai.io".to_string();
let flight_addr = "https://flight.spiceai.io".to_string();
match new_tls_flight_channel(flight_addr.clone()).await {
Err(e) => Err(e.into()),
Ok(flight_chan) => Ok(SpiceClientConfig::new(https_addr, flight_chan)),
let firecache_addr = "https://firecache.spiceai.io".to_string();

match join!(
gloomweaver marked this conversation as resolved.
Show resolved Hide resolved
new_tls_flight_channel(flight_addr.clone()),
new_tls_flight_channel(firecache_addr.clone())
) {
(Err(e), _) => return Err(e.into()),
(_, Err(e)) => return Err(e.into()),
(Ok(flight_chan), Ok(firecache_chan)) => {
return Ok(SpiceClientConfig::new(
https_addr,
flight_chan,
firecache_chan,
));
}
}
}
}

pub struct SpiceClient {
flight: SqlFlightClient,
firecache: SqlFlightClient,
pub prices: PricesClient,
}

Expand All @@ -38,11 +54,19 @@ impl SpiceClient {
.expect("Error Loading Client Config");
Self {
flight: SqlFlightClient::new(config.flight_channel, api_key.to_string()),
firecache: SqlFlightClient::new(config.firecache_channel, api_key.to_string()),
prices: PricesClient::new(Some(config.https_addr), api_key.to_string()),
}
}

pub async fn query(&mut self, query: &str) -> Result<FlightRecordBatchStream, Box<dyn Error>> {
self.flight.query(query).await
self.flight.query(query, false).await
}

pub async fn fire_query(
&mut self,
query: &str,
) -> Result<FlightRecordBatchStream, Box<dyn Error>> {
self.firecache.query(query, true).await
}
}
17 changes: 16 additions & 1 deletion src/flight.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::sql::client::FlightSqlServiceClient;
use arrow_flight::Ticket;

use std::error::Error;
use tonic::transport::Channel;
Expand Down Expand Up @@ -30,6 +31,7 @@ impl SqlFlightClient {
pub async fn query(
&mut self,
query: &str,
firecache: bool,
) -> std::result::Result<FlightRecordBatchStream, Box<dyn Error>> {
match self.authenticate().await {
Err(e) => return Err(e.into()),
Expand All @@ -40,7 +42,20 @@ impl SqlFlightClient {
Ok(resp) => {
for ep in resp.endpoint {
if let Some(tkt) = ep.ticket {
return self.client.do_get(tkt).await.map_err(|e| e.into());
// There seems to be an issue with ticket parsing in arrow-flight crate
// This is a workaround to fix the issue
let fixed_ticket = if firecache {
Ticket::new(
tkt.ticket
.into_iter()
.skip_while(|&x| x != b'}')
.skip(1)
.collect::<Vec<u8>>(),
)
} else {
tkt
};
return self.client.do_get(fixed_ticket).await.map_err(|e| e.into());
}
}
Err("no tickets for flight endpoint".into())
Expand Down
28 changes: 27 additions & 1 deletion tests/client_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ mod tests {
async fn test_query() {
let mut spice_client = new_client().await;
match spice_client.query(
"SELECT number, \"timestamp\", base_fee_per_gas, base_fee_per_gas / 1e9 AS base_fee_per_gas_gwei FROM eth.recent_blocks limit 10",
r#"SELECT number, "timestamp", base_fee_per_gas, base_fee_per_gas / 1e9 AS base_fee_per_gas_gwei FROM eth.recent_blocks limit 10"#,
).await {
Ok(mut flight_data_stream) => {
// Read back RecordBatches
Expand All @@ -43,6 +43,32 @@ mod tests {
};
}

#[tokio::test]
async fn test_fire_query() {
let mut spice_client = new_client().await;
match spice_client
.fire_query(r#"SELECT number, "timestamp", base_fee_per_gas, base_fee_per_gas / 1e9 AS base_fee_per_gas_gwei FROM eth.recent_blocks limit 10"#)
.await
{
Ok(mut flight_data_stream) => {
while let Some(batch) = flight_data_stream.next().await {
match batch {
Ok(batch) => {
assert_eq!(batch.num_columns(), 4);
assert_eq!(batch.num_rows(), 10);
},
Err(e) => {
assert!(false, "Error: {}", e)
}
};
}
}
Err(e) => {
assert!(false, "Error: {}", e);
}
};
}

#[tokio::test]
async fn test_query_streaming() {
let mut spice_client = new_client().await;
Expand Down