Skip to content

Commit

Permalink
Merge pull request #1147 from jqnatividad/joinp-validation
Browse files Browse the repository at this point in the history
`joinp`: add --validate option
  • Loading branch information
jqnatividad authored Jul 16, 2023
2 parents 3c8819e + dcdc42e commit a5147c1
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 9 deletions.
55 changes: 46 additions & 9 deletions src/cmd/joinp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ joinp options:
--filter-right <arg> Filter the right CSV data set by the given Polars SQL
expression BEFORE the join. Only rows where it
evaluates to true are kept.
--validate <arg> Validate the join keys BEFORE performing the join.
Valid values are:
none - No validation is performed.
onetomany - join keys are unique in the left data set.
manytoone - join keys are unique in the right data set.
onetoone - join keys are unique in both left & right data sets.
[default: none]
--nulls When set, joins will work on empty fields.
Otherwise, empty fields are completely ignored.
--try-parsedates When set, the join will attempt to parse the columns
Expand Down Expand Up @@ -147,7 +154,7 @@ use std::{
use polars::{
chunked_array::object::{AsOfOptions, AsofStrategy},
datatypes::AnyValue,
frame::hash_join::JoinType,
frame::hash_join::{JoinType, JoinValidation},
prelude::{CsvWriter, LazyCsvReader, LazyFileListReader, LazyFrame, SerWriter, SortOptions},
};
use serde::Deserialize;
Expand All @@ -168,6 +175,7 @@ struct Args {
flag_cross: bool,
flag_filter_left: Option<String>,
flag_filter_right: Option<String>,
flag_validate: Option<String>,
flag_nulls: bool,
flag_try_parsedates: bool,
flag_streaming: bool,
Expand Down Expand Up @@ -203,6 +211,16 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
}
let join = args.new_join(args.flag_try_parsedates)?;

// safety: flag_validate is always Some, as the option has a default value
args.flag_validate = Some(args.flag_validate.unwrap().to_lowercase());
let validation = match args.flag_validate.as_deref() {
Some("none") | None => JoinValidation::ManyToMany,
Some("onetomany") => JoinValidation::OneToMany,
Some("manytoone") => JoinValidation::ManyToOne,
Some("onetoone") => JoinValidation::OneToOne,
Some(s) => return fail_clierror!("Invalid join validation: {s}"),
};

let join_shape = match (
args.flag_left,
args.flag_left_anti,
Expand All @@ -211,12 +229,24 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
args.flag_cross,
args.flag_asof,
) {
(false, false, false, false, false, false) => join.polars_join(JoinType::Inner, false),
(true, false, false, false, false, false) => join.polars_join(JoinType::Left, false),
(false, true, false, false, false, false) => join.polars_join(JoinType::Anti, false),
(false, false, true, false, false, false) => join.polars_join(JoinType::Semi, false),
(false, false, false, true, false, false) => join.polars_join(JoinType::Outer, false),
(false, false, false, false, true, false) => join.polars_join(JoinType::Cross, false),
(false, false, false, false, false, false) => {
join.polars_join(JoinType::Inner, validation, false)
}
(true, false, false, false, false, false) => {
join.polars_join(JoinType::Left, validation, false)
}
(false, true, false, false, false, false) => {
join.polars_join(JoinType::Anti, validation, false)
}
(false, false, true, false, false, false) => {
join.polars_join(JoinType::Semi, validation, false)
}
(false, false, false, true, false, false) => {
join.polars_join(JoinType::Outer, validation, false)
}
(false, false, false, false, true, false) => {
join.polars_join(JoinType::Cross, validation, false)
}
(false, false, false, false, false, true) => {
// safety: flag_strategy is always Some, as the option has a default value
args.flag_strategy = Some(args.flag_strategy.unwrap().to_lowercase());
Expand Down Expand Up @@ -261,7 +291,7 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
.collect(),
);
}
join.polars_join(JoinType::AsOf(asof_options), true)
join.polars_join(JoinType::AsOf(asof_options), validation, true)
}
_ => fail!("Please pick exactly one join operation."),
}?;
Expand Down Expand Up @@ -289,7 +319,12 @@ struct JoinStruct {
}

impl JoinStruct {
fn polars_join(mut self, jointype: JoinType, asof_join: bool) -> CliResult<(usize, usize)> {
fn polars_join(
mut self,
jointype: JoinType,
validation: JoinValidation,
asof_join: bool,
) -> CliResult<(usize, usize)> {
let selcols1: Vec<_> = self.sel1.split(',').map(polars::lazy::dsl::col).collect();
let selcols2: Vec<_> = self.sel2.split(',').map(polars::lazy::dsl::col).collect();

Expand Down Expand Up @@ -320,6 +355,7 @@ impl JoinStruct {
.join_builder()
.with(self.lf2.with_optimizations(optimize_all))
.how(JoinType::Cross)
.validate(validation)
.force_parallel(true)
.finish()
.collect()?
Expand All @@ -337,6 +373,7 @@ impl JoinStruct {
.left_on(selcols1)
.right_on(selcols2)
.how(jointype)
.validate(validation)
.force_parallel(true)
.finish()
.collect()?
Expand Down
34 changes: 34 additions & 0 deletions tests/test_joinp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ macro_rules! joinp_test {
mod $name {
use std::process;

#[allow(unused_imports)]
use super::{make_rows, setup};
use crate::workdir::Workdir;

Expand Down Expand Up @@ -110,6 +111,39 @@ joinp_test!(
}
);

joinp_test!(
    joinp_outer_left_validate_none,
    |wrk: Workdir, mut cmd: process::Command| {
        // Left join with validation explicitly set to "none"; the command is
        // expected to succeed and emit the rows listed below.
        cmd.arg("--left");
        cmd.args(["--validate", "none"]);

        let expected = make_rows(
            false,
            vec![
                svec!["Boston", "MA", "Logan Airport"],
                svec!["Boston", "MA", "Boston Garden"],
                svec!["New York", "NY", ""],
                svec!["San Francisco", "CA", ""],
                svec!["Buffalo", "NY", "Ralph Wilson Stadium"],
            ],
        );

        let got: Vec<Vec<String>> = wrk.read_stdout(&mut cmd);
        assert_eq!(got, expected);
    }
);

joinp_test!(
    joinp_outer_left_validate_onetomany,
    |wrk: Workdir, mut cmd: process::Command| {
        // Left join with "onetomany" validation; against the test fixture this
        // is expected to be rejected with the Polars 1:m validation error.
        cmd.arg("--left");
        cmd.args(["--validate", "onetomany"]);

        // First check the exact message written to stderr ...
        let stderr_text: String = wrk.output_stderr(&mut cmd);
        assert_eq!(
            stderr_text,
            "Polars error: ComputeError(ErrString(\"the join keys did not fulfil 1:m \
             validation\"))\n"
        );

        // ... then confirm the command itself reports failure.
        wrk.assert_err(&mut cmd);
    }
);

joinp_test!(joinp_full, |wrk: Workdir, mut cmd: process::Command| {
cmd.arg("--full");
let got: Vec<Vec<String>> = wrk.read_stdout(&mut cmd);
Expand Down

0 comments on commit a5147c1

Please sign in to comment.