Skip to content

Commit

Permalink
feat: add command line args for duckdb-server-rust (#647)
Browse files Browse the repository at this point in the history
* feat: add command line args for duckdb-server-rust

* use Option

* format code
  • Loading branch information
kwonoh authored Jan 7, 2025
1 parent 662e971 commit 08bec1c
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 14 deletions.
95 changes: 89 additions & 6 deletions packages/duckdb-server-rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions packages/duckdb-server-rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ async-trait = "0.1.83"
axum = { version = "0.7", features = ["http1", "http2", "ws", "json", "tokio", "tracing", "macros"] }
axum-server = { version = "0.7", features = ["tls-rustls"] }
axum-server-dual-protocol = "0.7"
clap = { version = "4.5.23", features = ["derive"] }
duckdb = { version = "1", features = ["bundled", "csv", "json", "parquet", "url", "r2d2"] }
futures = "0.3"
listenfd = "1.0"
Expand Down
13 changes: 10 additions & 3 deletions packages/duckdb-server-rust/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,17 @@ async fn handle_post(
query::handle(&state, params).await
}

pub fn app() -> Result<Router> {
pub fn app(
dp_path: Option<&str>,
connection_pool_size: Option<u32>,
cache_size: Option<usize>,
) -> Result<Router> {
// Database and state setup
let db = ConnectionPool::new(":memory:", 10)?;
let cache = lru::LruCache::new(1000.try_into()?);
let db = ConnectionPool::new(
dp_path.unwrap_or(":memory:"),
connection_pool_size.unwrap_or(10),
)?;
let cache = lru::LruCache::new(cache_size.unwrap_or(1000).try_into()?);

let state = Arc::new(AppState {
db: Box::new(db),
Expand Down
31 changes: 29 additions & 2 deletions packages/duckdb-server-rust/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use anyhow::Result;
use axum_server::tls_rustls::RustlsConfig;
use clap::Parser;
use listenfd::ListenFd;
use std::net::TcpListener;
use std::{net::Ipv4Addr, net::SocketAddr, path::PathBuf};
Expand All @@ -14,8 +15,30 @@ mod interfaces;
mod query;
mod websocket;

#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
/// Path of database file (e.g., "database.db". ":memory:" for in-memory database)
#[arg(default_value = ":memory:")]
database: String,

/// HTTP Port
#[arg(short, long, default_value_t = 3000)]
port: u16,

/// Max connection pool size
#[arg(long, default_value_t = 10)]
connection_pool_size: u32,

/// Max number of cache entries
#[arg(long, default_value_t = 1000)]
cache_size: usize,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();

// Tracing setup
tracing_subscriber::registry()
.with(
Expand All @@ -27,7 +50,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.init();

// App setup
let app = app::app()?;
let app = app::app(
Some(&args.database),
Some(args.connection_pool_size),
Some(args.cache_size),
)?;

// TLS configuration
let mut config = RustlsConfig::from_pem_file(
Expand All @@ -42,7 +69,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}

// Listenfd setup
let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 3000);
let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), args.port);
let mut listenfd = ListenFd::from_env();
let listener = match listenfd.take_tcp_listener(0)? {
// if we are given a tcp listener on listen fd 0, we use that one
Expand Down
6 changes: 3 additions & 3 deletions packages/duckdb-server-rust/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ async fn get_arrow() -> Result<()> {

#[tokio::test]
async fn select_1_get() -> Result<()> {
let app = app::app()?;
let app = app::app(None, None, None)?;

let response = app
.oneshot(
Expand All @@ -112,7 +112,7 @@ async fn select_1_get() -> Result<()> {

#[tokio::test]
async fn select_1_post() -> Result<()> {
let app = app::app()?;
let app = app::app(None, None, None)?;

let response = app
.oneshot(
Expand All @@ -136,7 +136,7 @@ async fn select_1_post() -> Result<()> {

#[tokio::test]
async fn query_arrow() -> Result<()> {
let app = app::app()?;
let app = app::app(None, None, None)?;

let response = app
.oneshot(
Expand Down

0 comments on commit 08bec1c

Please sign in to comment.