Skip to content

Commit

Permalink
use clap 3 style args parsing for datafusion cli (#1749)
Browse files Browse the repository at this point in the history
* use clap 3 style args parsing for datafusion cli

* upgrade cli version
  • Loading branch information
jimexist authored Feb 5, 2022
1 parent 15cfcbc commit 40df55f
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 179 deletions.
3 changes: 2 additions & 1 deletion datafusion-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

[package]
name = "datafusion-cli"
version = "5.1.0"
description = "DataFusion is an in-memory query engine that uses Apache Arrow as the memory model. It supports executing SQL queries against CSV and Parquet files as well as querying directly against in-memory data."
version = "6.0.0"
authors = ["Apache Arrow <[email protected]>"]
edition = "2021"
keywords = [ "arrow", "datafusion", "ballista", "query", "sql" ]
Expand Down
11 changes: 8 additions & 3 deletions datafusion-cli/src/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
use crate::context::Context;
use crate::functions::{display_all_functions, Function};
use crate::print_format::PrintFormat;
use crate::print_options::{self, PrintOptions};
use crate::print_options::PrintOptions;
use clap::ArgEnum;
use datafusion::arrow::array::{ArrayRef, StringArray};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
Expand Down Expand Up @@ -206,10 +207,14 @@ impl OutputFormat {
Self::ChangeFormat(format) => {
if let Ok(format) = format.parse::<PrintFormat>() {
print_options.format = format;
println!("Output format is {}.", print_options.format);
println!("Output format is {:?}.", print_options.format);
Ok(())
} else {
Err(DataFusionError::Execution(format!("{} is not a valid format type [possible values: csv, tsv, table, json, ndjson]", format)))
Err(DataFusionError::Execution(format!(
"{:?} is not a valid format type [possible values: {:?}]",
format,
PrintFormat::value_variants()
)))
}
}
}
Expand Down
10 changes: 2 additions & 8 deletions datafusion-cli/src/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,14 @@ use crate::{
command::{Command, OutputFormat},
context::Context,
helper::CliHelper,
print_format::{all_print_formats, PrintFormat},
print_options::PrintOptions,
};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::arrow::util::pretty;
use datafusion::error::{DataFusionError, Result};
use rustyline::config::Config;
use datafusion::error::Result;
use rustyline::error::ReadlineError;
use rustyline::Editor;
use std::fs::File;
use std::io::prelude::*;
use std::io::BufReader;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Instant;

/// run and execute SQL statements and commands from a file, against a context with the given print options
Expand Down Expand Up @@ -109,7 +103,7 @@ pub async fn exec_from_repl(ctx: &mut Context, print_options: &mut PrintOptions)
);
}
} else {
println!("Output format is {}.", print_options.format);
println!("Output format is {:?}.", print_options.format);
}
}
_ => {
Expand Down
2 changes: 1 addition & 1 deletion datafusion-cli/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use arrow::array::StringArray;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow::util::pretty::pretty_format_batches;
use datafusion::error::{DataFusionError, Result};
use datafusion::error::Result;
use std::fmt;
use std::str::FromStr;
use std::sync::Arc;
Expand Down
1 change: 0 additions & 1 deletion datafusion-cli/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
// under the License.

#![doc = include_str!("../README.md")]
#![allow(unused_imports)]
pub const DATAFUSION_CLI_VERSION: &str = env!("CARGO_PKG_VERSION");

pub mod command;
Expand Down
162 changes: 63 additions & 99 deletions datafusion-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,132 +15,96 @@
// specific language governing permissions and limitations
// under the License.

use clap::{crate_version, App, Arg};
use clap::Parser;
use datafusion::error::Result;
use datafusion::execution::context::ExecutionConfig;
use datafusion_cli::{
context::Context,
exec,
print_format::{all_print_formats, PrintFormat},
print_options::PrintOptions,
context::Context, exec, print_format::PrintFormat, print_options::PrintOptions,
DATAFUSION_CLI_VERSION,
};
use std::env;
use std::fs::File;
use std::io::BufReader;
use std::path::Path;

// Command-line arguments accepted by the DataFusion CLI, parsed via clap's
// derive API (`Args::parse()` in `main`).
//
// NOTE: plain `//` comments are used here on purpose. clap's derive turns `///`
// doc comments into about/help text, which would change the `--help` output.
#[derive(Debug, Parser, PartialEq)]
#[clap(author, version, about, long_about = None)]
struct Args {
    // Directory that `main` switches into (via `env::set_current_dir`) before
    // running anything; checked up front by `is_valid_data_dir`.
    #[clap(
        short = 'p',
        long,
        help = "Path to your data, default to current directory",
        validator(is_valid_data_dir)
    )]
    data_path: Option<String>,

    // Optional batch size applied to the execution config; when absent the
    // config is left untouched. Checked up front by `is_valid_batch_size`.
    #[clap(
        short = 'c',
        long,
        help = "The batch size of each query, or use DataFusion default",
        validator(is_valid_batch_size)
    )]
    batch_size: Option<usize>,

    // Script files to run; `main` only takes the file-execution path when
    // this list is non-empty. `multiple_values` lets a single `-f`/`--file`
    // take several paths. Each path is checked by `is_valid_file`.
    #[clap(
        short,
        long,
        multiple_values = true,
        help = "Execute commands from file(s), then exit",
        validator(is_valid_file)
    )]
    file: Vec<String>,

    // Result output format, forwarded to `PrintOptions`. The help text restores
    // the description the previous builder-style definition gave this flag.
    #[clap(
        long,
        arg_enum,
        default_value_t = PrintFormat::Table,
        help = "Output format"
    )]
    format: PrintFormat,

    // Remote scheduler host; a remote context is created only when both
    // `host` and `port` are supplied, otherwise a local context is used.
    #[clap(long, help = "Ballista scheduler host")]
    host: Option<String>,

    // Remote scheduler port; see `host`.
    #[clap(long, help = "Ballista scheduler port")]
    port: Option<u16>,

    // Suppresses the version banner in `main` and is forwarded to `PrintOptions`.
    #[clap(
        short,
        long,
        help = "Reduce printing other than the results and work quietly"
    )]
    quiet: bool,
}

#[tokio::main]
pub async fn main() -> Result<()> {
let matches = App::new("DataFusion")
.version(crate_version!())
.about(
"DataFusion is an in-memory query engine that uses Apache Arrow \
as the memory model. It supports executing SQL queries against CSV and \
Parquet files as well as querying directly against in-memory data.",
)
.arg(
Arg::new("data-path")
.help("Path to your data, default to current directory")
.short('p')
.long("data-path")
.validator(is_valid_data_dir)
.takes_value(true),
)
.arg(
Arg::new("batch-size")
.help("The batch size of each query, or use DataFusion default")
.short('c')
.long("batch-size")
.validator(is_valid_batch_size)
.takes_value(true),
)
.arg(
Arg::new("file")
.help("Execute commands from file(s), then exit")
.short('f')
.long("file")
.multiple_occurrences(true)
.validator(is_valid_file)
.takes_value(true),
)
.arg(
Arg::new("format")
.help("Output format")
.long("format")
.default_value("table")
.possible_values(
&all_print_formats()
.iter()
.map(|format| format.to_string())
.collect::<Vec<_>>()
.iter()
.map(|i| i.as_str())
.collect::<Vec<_>>(),
)
.takes_value(true),
)
.arg(
Arg::new("host")
.help("Ballista scheduler host")
.long("host")
.takes_value(true),
)
.arg(
Arg::new("port")
.help("Ballista scheduler port")
.long("port")
.takes_value(true),
)
.arg(
Arg::new("quiet")
.help("Reduce printing other than the results and work quietly")
.short('q')
.long("quiet")
.takes_value(false),
)
.get_matches();

let quiet = matches.is_present("quiet");

if !quiet {
println!("DataFusion CLI v{}\n", DATAFUSION_CLI_VERSION);
}
let args = Args::parse();

let host = matches.value_of("host");
let port = matches
.value_of("port")
.and_then(|port| port.parse::<u16>().ok());
if !args.quiet {
println!("DataFusion CLI v{}", DATAFUSION_CLI_VERSION);
}

if let Some(path) = matches.value_of("data-path") {
if let Some(ref path) = args.data_path {
let p = Path::new(path);
env::set_current_dir(&p).unwrap();
};

let mut execution_config = ExecutionConfig::new().with_information_schema(true);

if let Some(batch_size) = matches
.value_of("batch-size")
.and_then(|size| size.parse::<usize>().ok())
{
if let Some(batch_size) = args.batch_size {
execution_config = execution_config.with_batch_size(batch_size);
};

let mut ctx: Context = match (host, port) {
(Some(h), Some(p)) => Context::new_remote(h, p)?,
let mut ctx: Context = match (args.host, args.port) {
(Some(ref h), Some(p)) => Context::new_remote(h, p)?,
_ => Context::new_local(&execution_config),
};

let format = matches
.value_of("format")
.expect("No format is specified")
.parse::<PrintFormat>()
.expect("Invalid format");

let mut print_options = PrintOptions { format, quiet };
let mut print_options = PrintOptions {
format: args.format,
quiet: args.quiet,
};

if let Some(file_paths) = matches.values_of("file") {
let files = file_paths
let files = args.file;
if !files.is_empty() {
let files = files
.into_iter()
.map(|file_path| File::open(file_path).unwrap())
.collect::<Vec<_>>();
for file in files {
Expand Down
70 changes: 4 additions & 66 deletions datafusion-cli/src/print_format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@ use arrow::json::{ArrayWriter, LineDelimitedWriter};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::arrow::util::pretty;
use datafusion::error::{DataFusionError, Result};
use std::fmt;
use std::str::FromStr;

/// Allow records to be printed in different formats
#[derive(Debug, PartialEq, Eq, Clone)]
#[derive(Debug, PartialEq, Eq, clap::ArgEnum, Clone)]
pub enum PrintFormat {
Csv,
Tsv,
Expand All @@ -34,40 +33,11 @@ pub enum PrintFormat {
NdJson,
}

/// returns all print formats
pub fn all_print_formats() -> Vec<PrintFormat> {
vec![
PrintFormat::Csv,
PrintFormat::Tsv,
PrintFormat::Table,
PrintFormat::Json,
PrintFormat::NdJson,
]
}

impl FromStr for PrintFormat {
type Err = ();
fn from_str(s: &str) -> std::result::Result<Self, ()> {
match s.to_lowercase().as_str() {
"csv" => Ok(Self::Csv),
"tsv" => Ok(Self::Tsv),
"table" => Ok(Self::Table),
"json" => Ok(Self::Json),
"ndjson" => Ok(Self::NdJson),
_ => Err(()),
}
}
}
type Err = String;

impl fmt::Display for PrintFormat {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Self::Csv => write!(f, "csv"),
Self::Tsv => write!(f, "tsv"),
Self::Table => write!(f, "table"),
Self::Json => write!(f, "json"),
Self::NdJson => write!(f, "ndjson"),
}
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
clap::ArgEnum::from_str(s, true)
}
}

Expand Down Expand Up @@ -123,38 +93,6 @@ mod tests {
use datafusion::from_slice::FromSlice;
use std::sync::Arc;

#[test]
fn test_from_str() {
let format = "csv".parse::<PrintFormat>().unwrap();
assert_eq!(PrintFormat::Csv, format);

let format = "tsv".parse::<PrintFormat>().unwrap();
assert_eq!(PrintFormat::Tsv, format);

let format = "json".parse::<PrintFormat>().unwrap();
assert_eq!(PrintFormat::Json, format);

let format = "ndjson".parse::<PrintFormat>().unwrap();
assert_eq!(PrintFormat::NdJson, format);

let format = "table".parse::<PrintFormat>().unwrap();
assert_eq!(PrintFormat::Table, format);
}

#[test]
fn test_to_str() {
assert_eq!("csv", PrintFormat::Csv.to_string());
assert_eq!("table", PrintFormat::Table.to_string());
assert_eq!("tsv", PrintFormat::Tsv.to_string());
assert_eq!("json", PrintFormat::Json.to_string());
assert_eq!("ndjson", PrintFormat::NdJson.to_string());
}

#[test]
fn test_from_str_failure() {
assert!("pretty".parse::<PrintFormat>().is_err());
}

#[test]
fn test_print_batches_with_sep() {
let batches = vec![];
Expand Down

0 comments on commit 40df55f

Please sign in to comment.