diff --git a/src/preproc.rs b/src/preproc.rs
index c8c5827..94e0f28 100644
--- a/src/preproc.rs
+++ b/src/preproc.rs
@@ -127,76 +127,100 @@ fn trim(input: Cow<str>) -> Cow<str> {
 
 // Aggressive preprocessors
 
-fn lcs_substr(f_line: &str, s_line: &str) -> String {
-    // grab character iterators from both strings
-    let f_line_chars = f_line.chars();
-    let s_line_chars = s_line.chars();
-
-    // zip them together and find the common substring from the start
-    f_line_chars
-        .zip(s_line_chars)
-        .take_while(|&(f, s)| f == s)
-        .map(|(f, _s)| f)
-        .collect::<String>()
-        .trim()
-        .into() //TODO: big optimization needed, this is a wasteful conversion
+// Cut a prefix of the string near the given byte index.
+// If the given index doesn't lie on a char boundary,
+// returns the longest prefix whose length does not exceed idx.
+// If the index is bigger than the length of the string, returns the whole string.
+fn trim_byte_adjusted(s: &str, idx: usize) -> &str {
+    if idx >= s.len() {
+        return s
+    }
+
+    if let Some(sub) = s.get(..idx) {
+        sub
+    } else {
+        // Inspect the bytes before the index
+        let trailing_continuation = s.as_bytes()[..idx]
+            .iter()
+            .rev()
+            // Multibyte characters are encoded in UTF-8 in the following manner:
+            //   first byte  |  rest of bytes
+            //   1..10xxxxx     10xxxxxx
+            //   ^^^^ the number of ones equals the number of bytes in the codepoint
+            // A codepoint has at most three 10xxxxxx bytes in a valid UTF-8-encoded string,
+            // so this loop runs at most a few iterations
+            .take_while(|&byte| byte & 0b1100_0000 == 0b1000_0000)
+            .count();
+        // Subtract 1 to take the codepoint's first byte into account
+        &s[..idx - trailing_continuation - 1]
+    }
+}
+
+fn lcs_substr<'a>(f_line: &'a str, s_line: &'a str) -> &'a str {
+    // find the length of the common prefix of the two strings' byte representations
+    let prefix_len = f_line.as_bytes()
+        .iter()
+        .zip(s_line.as_bytes())
+        .take_while(|(&f, &s)| f == s)
+        .count();
+
+    trim_byte_adjusted(f_line, prefix_len).trim()
 }
 
 fn remove_common_tokens(input: Cow<str>) -> Cow<str> {
     let lines: Vec<&str> = input.split('\n').collect();
-    let mut l_iter = lines.iter().peekable();
+    let mut l_iter = lines.iter();
 
     // TODO: consider whether this can all be done in one pass https://github.com/amzn/askalono/issues/36
 
-    let mut prefix_counts: HashMap<String, u32> = HashMap::new();
+    let mut prefix_counts = HashMap::<_, u32>::new();
 
     // pass 1: iterate through the text to record common prefixes
-    while let Some(line) = l_iter.next() {
-        if let Some(next) = l_iter.peek() {
-            let common = lcs_substr(line, next);
+    if let Some(first) = l_iter.next() {
+        let mut pair = ("", first);
+        let line_pairs = std::iter::from_fn(|| {
+            pair = (pair.1, l_iter.next()?);
+            Some(pair)
+        });
+        for (a, b) in line_pairs {
+            let common = lcs_substr(a, b);
 
             // why start at 1, then immediately add 1?
             // lcs_substr compares two lines!
             // this doesn't need to be exact, just consistent.
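+            // Illustrative sketch (hypothetical lines, not from askalono's test data):
+            // adjacent lines "// Copyright 2018" and "// Copyright 2019" produce the
+            // common prefix "// Copyright 201", which is long enough to be counted,
+            // while lines sharing only "//" fall below the length check and are skipped.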
             if common.len() > 3 {
-                *prefix_counts.entry(common.to_owned()).or_insert(1) += 1;
+                *prefix_counts.entry(common).or_insert(1) += 1;
             }
         }
     }
 
     // look at the most common observed prefix
-    let max_prefix = prefix_counts.iter().max_by_key(|&(_k, v)| v);
-    if max_prefix.is_none() {
-        return input;
-    }
-    let (most_common, _) = max_prefix.unwrap();
+    let most_common = match prefix_counts.iter().max_by_key(|&(_k, v)| v) {
+        Some((prefix, _count)) => prefix,
+        None => return input,
+    };
 
     // reconcile the count with other longer prefixes that may be stored
-    let mut final_common_count = 0;
-    for (k, v) in prefix_counts.iter() {
-        if k.starts_with(most_common) {
-            final_common_count += v;
-        }
-    }
+    let common_count = prefix_counts.iter()
+        .filter_map(|(s, count)| Some(count).filter(|_| s.starts_with(most_common)))
+        .sum::<u32>();
 
     // the common string must be at least 80% of the text
-    let prefix_threshold: u32 = (0.8f32 * lines.len() as f32) as u32;
-    if final_common_count < prefix_threshold {
+    let prefix_threshold = (0.8f32 * lines.len() as f32) as _;
+    if common_count < prefix_threshold {
         return input;
     }
 
     // pass 2: remove that substring
-    let prefix_len = most_common.len();
     lines
         .iter()
        .map(|line| {
             if line.starts_with(most_common) {
-                &line[prefix_len..]
+                &line[most_common.len()..]
             } else {
-                &line
-            }
+                line
+            }.trim()
         })
-        .map(|line| line.trim())
         .collect::<Vec<_>>()
         .join("\n")
         .into()
@@ -272,6 +296,40 @@ fn collapse_whitespace(input: Cow<str>) -> Cow<str> {
 mod tests {
     use super::*;
 
+    #[test]
+    fn trim_byte_adjusted_respects_multibyte_characters() {
+        let input = "RustКраб橙蟹🦀";
+        let expected = [
+            "",
+            "R",
+            "Ru",
+            "Rus",
+            "Rust",
+            "Rust",
+            "RustК",
+            "RustК",
+            "RustКр",
+            "RustКр",
+            "RustКра",
+            "RustКра",
+            "RustКраб",
+            "RustКраб",
+            "RustКраб",
+            "RustКраб橙",
+            "RustКраб橙",
+            "RustКраб橙",
+            "RustКраб橙蟹",
+            "RustКраб橙蟹",
+            "RustКраб橙蟹",
+            "RustКраб橙蟹",
+            "RustКраб橙蟹🦀",
+        ];
+
+        for (i, &outcome) in expected.iter().enumerate() {
+            assert_eq!(outcome, trim_byte_adjusted(input, i))
+        }
+    }
+
     #[test]
     fn greatest_substring_removal() {
         // the funky string syntax \n\ is to add a newline but skip the