diff --git a/src/cmd/joinp.rs b/src/cmd/joinp.rs index a00264feb..075d5a4c7 100644 --- a/src/cmd/joinp.rs +++ b/src/cmd/joinp.rs @@ -55,6 +55,13 @@ joinp options: --filter-right Filter the right CSV data set by the given Polars SQL expression BEFORE the join. Only rows where it evaluates to true are kept. + --validate Validate the join keys BEFORE performing the join. + Valid values are: + none - No validation is performed. + onetomany - join keys are unique in the left data set. + manytoone - join keys are unique in the right data set. + onetoone - join keys are unique in both left & right data sets. + [default: none] --nulls When set, joins will work on empty fields. Otherwise, empty fields are completely ignored. --try-parsedates When set, the join will attempt to parse the columns @@ -147,7 +154,7 @@ use std::{ use polars::{ chunked_array::object::{AsOfOptions, AsofStrategy}, datatypes::AnyValue, - frame::hash_join::JoinType, + frame::hash_join::{JoinType, JoinValidation}, prelude::{CsvWriter, LazyCsvReader, LazyFileListReader, LazyFrame, SerWriter, SortOptions}, }; use serde::Deserialize; @@ -168,6 +175,7 @@ struct Args { flag_cross: bool, flag_filter_left: Option, flag_filter_right: Option, + flag_validate: Option, flag_nulls: bool, flag_try_parsedates: bool, flag_streaming: bool, @@ -203,6 +211,16 @@ pub fn run(argv: &[&str]) -> CliResult<()> { } let join = args.new_join(args.flag_try_parsedates)?; + // safety: flag_validate is always is_some() as it has a default value + args.flag_validate = Some(args.flag_validate.unwrap().to_lowercase()); + let validation = match args.flag_validate.as_deref() { + Some("none") | None => JoinValidation::ManyToMany, + Some("onetomany") => JoinValidation::OneToMany, + Some("manytoone") => JoinValidation::ManyToOne, + Some("onetoone") => JoinValidation::OneToOne, + Some(s) => return fail_clierror!("Invalid join validation: {s}"), + }; + let join_shape = match ( args.flag_left, args.flag_left_anti, @@ -211,12 +229,24 @@ pub fn run(argv: &[&str]) -> CliResult<()> { args.flag_cross, args.flag_asof, ) { - (false, false, false, false, false, false) => join.polars_join(JoinType::Inner, false), - (true, false, false, false, false, false) => join.polars_join(JoinType::Left, false), - (false, true, false, false, false, false) => join.polars_join(JoinType::Anti, false), - (false, false, true, false, false, false) => join.polars_join(JoinType::Semi, false), - (false, false, false, true, false, false) => join.polars_join(JoinType::Outer, false), - (false, false, false, false, true, false) => join.polars_join(JoinType::Cross, false), + (false, false, false, false, false, false) => { + join.polars_join(JoinType::Inner, validation, false) + } + (true, false, false, false, false, false) => { + join.polars_join(JoinType::Left, validation, false) + } + (false, true, false, false, false, false) => { + join.polars_join(JoinType::Anti, validation, false) + } + (false, false, true, false, false, false) => { + join.polars_join(JoinType::Semi, validation, false) + } + (false, false, false, true, false, false) => { + join.polars_join(JoinType::Outer, validation, false) + } + (false, false, false, false, true, false) => { + join.polars_join(JoinType::Cross, validation, false) + } (false, false, false, false, false, true) => { // safety: flag_strategy is always is_some() as it has a default value args.flag_strategy = Some(args.flag_strategy.unwrap().to_lowercase()); @@ -261,7 +291,7 @@ pub fn run(argv: &[&str]) -> CliResult<()> { .collect(), ); } - join.polars_join(JoinType::AsOf(asof_options), true) + join.polars_join(JoinType::AsOf(asof_options), validation, true) } _ => fail!("Please pick exactly one join operation."), }?; @@ -289,7 +319,12 @@ struct JoinStruct { } impl JoinStruct { - fn polars_join(mut self, jointype: JoinType, asof_join: bool) -> CliResult<(usize, usize)> { + fn polars_join( + mut self, + jointype: JoinType, + validation: JoinValidation, + asof_join: bool, + ) -> CliResult<(usize, usize)> { let selcols1: Vec<_> = self.sel1.split(',').map(polars::lazy::dsl::col).collect(); let selcols2: Vec<_> = self.sel2.split(',').map(polars::lazy::dsl::col).collect(); @@ -320,6 +355,7 @@ impl JoinStruct { .join_builder() .with(self.lf2.with_optimizations(optimize_all)) .how(JoinType::Cross) + .validate(validation) .force_parallel(true) .finish() .collect()? @@ -337,6 +373,7 @@ impl JoinStruct { .left_on(selcols1) .right_on(selcols2) .how(jointype) + .validate(validation) .force_parallel(true) .finish() .collect()? diff --git a/tests/test_joinp.rs b/tests/test_joinp.rs index 668882cb8..018614778 100644 --- a/tests/test_joinp.rs +++ b/tests/test_joinp.rs @@ -5,6 +5,7 @@ macro_rules! joinp_test { mod $name { use std::process; + #[allow(unused_imports)] use super::{make_rows, setup}; use crate::workdir::Workdir; @@ -110,6 +111,39 @@ joinp_test!( } ); +joinp_test!( + joinp_outer_left_validate_none, + |wrk: Workdir, mut cmd: process::Command| { + cmd.arg("--left").args(["--validate", "none"]); + let got: Vec> = wrk.read_stdout(&mut cmd); + let expected = make_rows( + false, + vec![ + svec!["Boston", "MA", "Logan Airport"], + svec!["Boston", "MA", "Boston Garden"], + svec!["New York", "NY", ""], + svec!["San Francisco", "CA", ""], + svec!["Buffalo", "NY", "Ralph Wilson Stadium"], + ], + ); + assert_eq!(got, expected); + } +); + +joinp_test!( + joinp_outer_left_validate_onetomany, + |wrk: Workdir, mut cmd: process::Command| { + cmd.arg("--left").args(["--validate", "onetomany"]); + let got: String = wrk.output_stderr(&mut cmd); + assert_eq!( + got, + "Polars error: ComputeError(ErrString(\"the join keys did not fulfil 1:m \ + validation\"))\n" + ); + wrk.assert_err(&mut cmd); + } +); + joinp_test!(joinp_full, |wrk: Workdir, mut cmd: process::Command| { cmd.arg("--full"); let got: Vec> = wrk.read_stdout(&mut cmd);