diff --git a/src/cmd/join.rs b/src/cmd/join.rs index b7b4f9931..efac848a6 100644 --- a/src/cmd/join.rs +++ b/src/cmd/join.rs @@ -67,6 +67,12 @@ join options: Otherwise, empty fields are completely ignored. (In fact, any row that has an empty field in the key specified is ignored.) + --keys-output Write successfully joined keys to . + This means that the keys are written to the output + file when a match is found, with the exception of + anti joins, where keys are written when NO match + is found. + Cross joins do not write keys. Common options: -h, --help Display this message @@ -112,6 +118,7 @@ struct Args { flag_ignore_case: bool, flag_nulls: bool, flag_delimiter: Option, + flag_keys_output: Option, } pub fn run(argv: &[&str]) -> CliResult<()> { @@ -193,6 +200,7 @@ struct IoState { no_headers: bool, casei: bool, nulls: bool, + keys_wtr: KeysWriter, } impl IoState { @@ -219,9 +227,12 @@ impl IoState { let mut row = csv::ByteRecord::new(); let mut key: Vec; let mut output = csv::ByteRecord::new(); + while self.rdr1.read_byte_record(&mut row)? { key = get_row_key(&self.sel1, &row, self.casei); if let Some(rows) = validx.values.get(&key) { + self.keys_wtr.write_key(&key)?; + for &rowi in rows { validx.idx.seek(rowi as u64)?; @@ -234,7 +245,9 @@ impl IoState { } } } - Ok(self.wtr.flush()?) + self.wtr.flush()?; + self.keys_wtr.flush()?; + Ok(()) } fn outer_join(mut self, right: bool) -> CliResult<()> { @@ -253,16 +266,17 @@ impl IoState { while self.rdr1.read_byte_record(&mut row)? { key = get_row_key(&self.sel1, &row, self.casei); if let Some(rows) = validx.values.get(&key) { + self.keys_wtr.write_key(&key)?; + for &rowi in rows { validx.idx.seek(rowi as u64)?; - let mut row1 = row.iter(); validx.idx.read_byte_record(&mut scratch)?; output.clear(); if right { output.extend(&scratch); - output.extend(&mut row1); + output.extend(&row); } else { - output.extend(&mut row1); + output.extend(&row); output.extend(&scratch); } self.wtr.write_record(&output)?; @@ -279,24 +293,31 @@ impl IoState { self.wtr.write_record(&output)?; } } - Ok(self.wtr.flush()?) + self.wtr.flush()?; + self.keys_wtr.flush()?; + Ok(()) } fn left_join(mut self, anti: bool) -> CliResult<()> { let validx = ValueIndex::new(self.rdr2, &self.sel2, self.casei, self.nulls)?; let mut row = csv::ByteRecord::new(); let mut key: Vec; + while self.rdr1.read_byte_record(&mut row)? { key = get_row_key(&self.sel1, &row, self.casei); if validx.values.get(&key).is_none() { if anti { + self.keys_wtr.write_key(&key)?; self.wtr.write_record(&row)?; } } else if !anti { + self.keys_wtr.write_key(&key)?; self.wtr.write_record(&row)?; } } - Ok(self.wtr.flush()?) + self.wtr.flush()?; + self.keys_wtr.flush()?; + Ok(()) } fn full_outer_join(mut self) -> CliResult<()> { @@ -309,9 +330,12 @@ impl IoState { let mut rdr2_written: Vec<_> = repeat(false).take(validx.num_rows).collect(); let mut row1 = csv::ByteRecord::new(); let mut key: Vec; + while self.rdr1.read_byte_record(&mut row1)? { key = get_row_key(&self.sel1, &row1, self.casei); if let Some(rows) = validx.values.get(&key) { + self.keys_wtr.write_key(&key)?; + for &rowi in rows { rdr2_written[rowi] = true; @@ -342,7 +366,9 @@ impl IoState { self.wtr.write_record(&output)?; } } - Ok(self.wtr.flush()?) + self.wtr.flush()?; + self.keys_wtr.flush()?; + Ok(()) } fn cross_join(mut self) -> CliResult<()> { @@ -392,9 +418,22 @@ impl Args { .no_headers(self.flag_no_headers) .select(self.arg_columns2.clone()); - let mut rdr1 = rconf1.reader_file_stdin()?; - let mut rdr2 = rconf2.reader_file_stdin()?; + let mut rdr1 = match rconf1.reader_file_stdin() { + Ok(rdr1) => rdr1, + Err(e) => return fail_clierror!("Failed to read input1: {e}"), + }; + let mut rdr2 = match rconf2.reader_file_stdin() { + Ok(rdr2) => rdr2, + Err(e) => return fail_clierror!("Failed to read input2: {e}"), + }; let (sel1, sel2) = self.get_selections(&rconf1, &mut rdr1, &rconf2, &mut rdr2)?; + + let keys_wtr = if self.flag_cross { + KeysWriter::new(None)? + } else { + KeysWriter::new(self.flag_keys_output.as_ref())? + }; + Ok(IoState { wtr: Config::new(self.flag_output.as_ref()).writer()?, rdr1, @@ -404,6 +443,7 @@ impl Args { no_headers: rconf1.no_headers, casei: self.flag_ignore_case, nulls: self.flag_nulls, + keys_wtr, }) } @@ -445,8 +485,8 @@ impl ValueIndex { casei: bool, nulls: bool, ) -> CliResult> { - let mut val_idx = AHashMap::with_capacity(10000); - let mut row_idx = io::Cursor::new(Vec::with_capacity(8 * 10000)); + let mut val_idx = AHashMap::with_capacity(20_000); + let mut row_idx = io::Cursor::new(Vec::with_capacity(8 * 20_000)); let (mut rowi, mut count) = (0_usize, 0_usize); // This logic is kind of tricky. Basically, we want to include @@ -524,3 +564,36 @@ impl fmt::Debug for ValueIndex { fn get_row_key(sel: &Selection, row: &csv::ByteRecord, casei: bool) -> Vec { sel.select(row).map(|v| util::transform(v, casei)).collect() } + +struct KeysWriter { + writer: csv::Writer>, + enabled: bool, +} + +impl KeysWriter { + fn new(keys_path: Option<&String>) -> CliResult { + let (writer, enabled) = if let Some(path) = keys_path { + (Config::new(Some(path)).writer()?, true) + } else { + let sink: Box = Box::new(std::io::sink()); + (csv::WriterBuilder::new().from_writer(sink), false) + }; + + Ok(Self { writer, enabled }) + } + + #[inline] + fn write_key(&mut self, key: &[ByteString]) -> CliResult<()> { + if self.enabled { + self.writer.write_record(key)?; + } + Ok(()) + } + + fn flush(&mut self) -> CliResult<()> { + if self.enabled { + self.writer.flush()?; + } + Ok(()) + } +}