Skip to content

Commit

Permalink
feat: join add --keys-output option
Browse files Browse the repository at this point in the history
  • Loading branch information
jqnatividad committed Jan 4, 2025
1 parent ca1b5f4 commit 6354689
Showing 1 changed file with 84 additions and 11 deletions.
95 changes: 84 additions & 11 deletions src/cmd/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ join options:
Otherwise, empty fields are completely ignored.
(In fact, any row that has an empty field in the
key specified is ignored.)
--keys-output <file> Write successfully joined keys to <file>.
This means that the keys are written to the output
file when a match is found, with the exception of
anti joins, where keys are written when NO match
is found.
Cross joins do not write keys.
Common options:
-h, --help Display this message
Expand Down Expand Up @@ -112,6 +118,7 @@ struct Args {
flag_ignore_case: bool,
flag_nulls: bool,
flag_delimiter: Option<Delimiter>,
flag_keys_output: Option<String>,
}

pub fn run(argv: &[&str]) -> CliResult<()> {
Expand Down Expand Up @@ -193,6 +200,7 @@ struct IoState<R, W: io::Write> {
no_headers: bool,
casei: bool,
nulls: bool,
keys_wtr: KeysWriter,
}

impl<R: io::Read + io::Seek, W: io::Write> IoState<R, W> {
Expand All @@ -219,9 +227,12 @@ impl<R: io::Read + io::Seek, W: io::Write> IoState<R, W> {
let mut row = csv::ByteRecord::new();
let mut key: Vec<ByteString>;
let mut output = csv::ByteRecord::new();

while self.rdr1.read_byte_record(&mut row)? {
key = get_row_key(&self.sel1, &row, self.casei);
if let Some(rows) = validx.values.get(&key) {
self.keys_wtr.write_key(&key)?;

for &rowi in rows {
validx.idx.seek(rowi as u64)?;

Expand All @@ -234,7 +245,9 @@ impl<R: io::Read + io::Seek, W: io::Write> IoState<R, W> {
}
}
}
Ok(self.wtr.flush()?)
self.wtr.flush()?;
self.keys_wtr.flush()?;
Ok(())
}

fn outer_join(mut self, right: bool) -> CliResult<()> {
Expand All @@ -253,16 +266,17 @@ impl<R: io::Read + io::Seek, W: io::Write> IoState<R, W> {
while self.rdr1.read_byte_record(&mut row)? {
key = get_row_key(&self.sel1, &row, self.casei);
if let Some(rows) = validx.values.get(&key) {
self.keys_wtr.write_key(&key)?;

for &rowi in rows {
validx.idx.seek(rowi as u64)?;
let mut row1 = row.iter();
validx.idx.read_byte_record(&mut scratch)?;
output.clear();
if right {
output.extend(&scratch);
output.extend(&mut row1);
output.extend(&row);
} else {
output.extend(&mut row1);
output.extend(&row);
output.extend(&scratch);
}
self.wtr.write_record(&output)?;
Expand All @@ -279,24 +293,31 @@ impl<R: io::Read + io::Seek, W: io::Write> IoState<R, W> {
self.wtr.write_record(&output)?;
}
}
Ok(self.wtr.flush()?)
self.wtr.flush()?;
self.keys_wtr.flush()?;
Ok(())
}

fn left_join(mut self, anti: bool) -> CliResult<()> {
let validx = ValueIndex::new(self.rdr2, &self.sel2, self.casei, self.nulls)?;
let mut row = csv::ByteRecord::new();
let mut key: Vec<ByteString>;

while self.rdr1.read_byte_record(&mut row)? {
key = get_row_key(&self.sel1, &row, self.casei);
if validx.values.get(&key).is_none() {
if anti {
self.keys_wtr.write_key(&key)?;
self.wtr.write_record(&row)?;
}
} else if !anti {
self.keys_wtr.write_key(&key)?;
self.wtr.write_record(&row)?;
}
}
Ok(self.wtr.flush()?)
self.wtr.flush()?;
self.keys_wtr.flush()?;
Ok(())
}

fn full_outer_join(mut self) -> CliResult<()> {
Expand All @@ -309,9 +330,12 @@ impl<R: io::Read + io::Seek, W: io::Write> IoState<R, W> {
let mut rdr2_written: Vec<_> = repeat(false).take(validx.num_rows).collect();
let mut row1 = csv::ByteRecord::new();
let mut key: Vec<ByteString>;

while self.rdr1.read_byte_record(&mut row1)? {
key = get_row_key(&self.sel1, &row1, self.casei);
if let Some(rows) = validx.values.get(&key) {
self.keys_wtr.write_key(&key)?;

for &rowi in rows {
rdr2_written[rowi] = true;

Expand Down Expand Up @@ -342,7 +366,9 @@ impl<R: io::Read + io::Seek, W: io::Write> IoState<R, W> {
self.wtr.write_record(&output)?;
}
}
Ok(self.wtr.flush()?)
self.wtr.flush()?;
self.keys_wtr.flush()?;
Ok(())
}

fn cross_join(mut self) -> CliResult<()> {
Expand Down Expand Up @@ -392,9 +418,22 @@ impl Args {
.no_headers(self.flag_no_headers)
.select(self.arg_columns2.clone());

let mut rdr1 = rconf1.reader_file_stdin()?;
let mut rdr2 = rconf2.reader_file_stdin()?;
let mut rdr1 = match rconf1.reader_file_stdin() {
Ok(rdr1) => rdr1,
Err(e) => return fail_clierror!("Failed to read input1: {e}"),
};
let mut rdr2 = match rconf2.reader_file_stdin() {
Ok(rdr2) => rdr2,
Err(e) => return fail_clierror!("Failed to read input2: {e}"),
};
let (sel1, sel2) = self.get_selections(&rconf1, &mut rdr1, &rconf2, &mut rdr2)?;

let keys_wtr = if self.flag_cross {
KeysWriter::new(None)?
} else {
KeysWriter::new(self.flag_keys_output.as_ref())?
};

Ok(IoState {
wtr: Config::new(self.flag_output.as_ref()).writer()?,
rdr1,
Expand All @@ -404,6 +443,7 @@ impl Args {
no_headers: rconf1.no_headers,
casei: self.flag_ignore_case,
nulls: self.flag_nulls,
keys_wtr,
})
}

Expand Down Expand Up @@ -445,8 +485,8 @@ impl<R: io::Read + io::Seek> ValueIndex<R> {
casei: bool,
nulls: bool,
) -> CliResult<ValueIndex<R>> {
let mut val_idx = AHashMap::with_capacity(10000);
let mut row_idx = io::Cursor::new(Vec::with_capacity(8 * 10000));
let mut val_idx = AHashMap::with_capacity(20_000);
let mut row_idx = io::Cursor::new(Vec::with_capacity(8 * 20_000));
let (mut rowi, mut count) = (0_usize, 0_usize);

// This logic is kind of tricky. Basically, we want to include
Expand Down Expand Up @@ -524,3 +564,36 @@ impl<R> fmt::Debug for ValueIndex<R> {
fn get_row_key(sel: &Selection, row: &csv::ByteRecord, casei: bool) -> Vec<ByteString> {
sel.select(row).map(|v| util::transform(v, casei)).collect()
}

struct KeysWriter {
writer: csv::Writer<Box<dyn io::Write>>,
enabled: bool,
}

impl KeysWriter {
fn new(keys_path: Option<&String>) -> CliResult<Self> {
let (writer, enabled) = if let Some(path) = keys_path {
(Config::new(Some(path)).writer()?, true)
} else {
let sink: Box<dyn io::Write> = Box::new(std::io::sink());
(csv::WriterBuilder::new().from_writer(sink), false)
};

Ok(Self { writer, enabled })
}

#[inline]
fn write_key(&mut self, key: &[ByteString]) -> CliResult<()> {
if self.enabled {
self.writer.write_record(key)?;
}
Ok(())
}

fn flush(&mut self) -> CliResult<()> {
if self.enabled {
self.writer.flush()?;
}
Ok(())
}
}

0 comments on commit 6354689

Please sign in to comment.