diff --git a/autopush-common/src/db/bigtable/bigtable_client/mod.rs b/autopush-common/src/db/bigtable/bigtable_client/mod.rs index 8ca8adc8b..9d1256b7c 100644 --- a/autopush-common/src/db/bigtable/bigtable_client/mod.rs +++ b/autopush-common/src/db/bigtable/bigtable_client/mod.rs @@ -101,6 +101,13 @@ fn message_gc_policy_filter() -> Result, error::BigTableErr Ok(vec![router_gc_policy_filter(), timestamp_filter]) } +/// Return a Column family regex RowFilter +fn family_filter(regex: String) -> data::RowFilter { + let mut filter = data::RowFilter::default(); + filter.set_family_name_regex_filter(regex); + filter +} + /// Escape bytes for RE values /// /// Based off google-re2/perl's quotemeta function @@ -126,16 +133,17 @@ fn escape_bytes(bytes: &[u8]) -> Vec { /// Return a chain of RowFilters limiting to a match of the specified /// `version`'s column value fn version_filter(version: &Uuid) -> Vec { - let mut family_filter = data::RowFilter::default(); - family_filter.set_family_name_regex_filter(format!("^{ROUTER_FAMILY}$")); - let mut cq_filter = data::RowFilter::default(); cq_filter.set_column_qualifier_regex_filter("^version$".as_bytes().to_vec()); let mut value_filter = data::RowFilter::default(); value_filter.set_value_regex_filter(escape_bytes(version.as_bytes())); - vec![family_filter, cq_filter, value_filter] + vec![ + family_filter(format!("^{ROUTER_FAMILY}$")), + cq_filter, + value_filter, + ] } /// Return a newly generated `version` column `Cell` @@ -324,12 +332,13 @@ impl BigTableClientImpl { Ok(()) } - /// Read a given row from the row key. - async fn read_row(&self, row_key: &str) -> Result, error::BigTableError> { - debug!("🉑 Row key: {row_key}"); - let req = self.read_row_request(row_key); + /// Read one row for the [ReadRowsRequest] (assuming only a single row was requested). + async fn read_row( + &self, + req: bigtable::ReadRowsRequest, + ) -> Result, error::BigTableError> { let mut rows = self.read_rows(req).await?; - Ok(rows.remove(row_key)) + Ok(rows.pop_first().map(|(_, v)| v)) } /// Take a big table ReadRowsRequest (containing the keys and filters) and return a set of row data indexed by row key. @@ -717,7 +726,9 @@ impl DbClient for BigTableClientImpl { async fn get_user(&self, uaid: &Uuid) -> DbResult> { let row_key = uaid.as_simple().to_string(); - let Some(mut row) = self.read_row(&row_key).await? else { + let mut req = self.read_row_request(&row_key); + req.set_filter(family_filter(format!("^{ROUTER_FAMILY}$"))); + let Some(mut row) = self.read_row(req).await? else { return Ok(None); }; @@ -819,21 +830,16 @@ impl DbClient for BigTableClientImpl { let row_key = uaid.simple().to_string(); let mut req = self.read_row_request(&row_key); - let mut family_filter = data::RowFilter::default(); - family_filter.set_family_name_regex_filter(format!("^{ROUTER_FAMILY}$")); - let mut cq_filter = data::RowFilter::default(); cq_filter.set_column_qualifier_regex_filter("^chid:.*$".as_bytes().to_vec()); - req.set_filter(filter_chain(vec![ router_gc_policy_filter(), - family_filter, + family_filter(format!("^{ROUTER_FAMILY}$")), cq_filter, ])); - let mut rows = self.read_rows(req).await?; let mut result = HashSet::new(); - if let Some(record) = rows.remove(&row_key) { + if let Some(record) = self.read_row(req).await? { for mut cells in record.cells.into_values() { let Some(cell) = cells.pop() else { continue; @@ -853,7 +859,7 @@ impl DbClient for BigTableClientImpl { /// Delete the channel. Does not delete its associated pending messages. async fn remove_channel(&self, uaid: &Uuid, channel_id: &Uuid) -> DbResult { let row_key = uaid.simple().to_string(); - let mut req = self.mutate_row_request(&row_key); + let mut req = self.check_and_mutate_row_request(&row_key); // Delete the column representing the channel_id let column = format!("chid:{}", channel_id.as_hyphenated()); @@ -865,11 +871,14 @@ impl DbClient for BigTableClientImpl { row.cells .insert(ROUTER_FAMILY.to_owned(), vec![new_version_cell(expiry)]); mutations.extend(self.get_mutations(row.cells)?); - req.set_mutations(mutations); - self.mutate_row(req).await?; - // XXX: this could be check_and_mutate to determine if the channel existed - Ok(true) + // check if the channel existed/was actually removed + let mut cq_filter = data::RowFilter::default(); + cq_filter.set_column_qualifier_regex_filter(format!("^{column}$").into_bytes()); + req.set_predicate_filter(filter_chain(vec![router_gc_policy_filter(), cq_filter])); + req.set_true_mutations(mutations); + + Ok(self.check_and_mutate(req).await?) } /// Remove the node_id @@ -1048,7 +1057,10 @@ impl DbClient for BigTableClientImpl { rows.set_row_ranges(row_ranges); req.set_rows(rows); - req.set_filter(filter_chain(message_gc_policy_filter()?)); + let mut filters = message_gc_policy_filter()?; + filters.push(family_filter(format!("^{MESSAGE_TOPIC_FAMILY}$"))); + + req.set_filter(filter_chain(filters)); if limit > 0 { trace!("🉑 Setting limit to {limit}"); req.set_rows_limit(limit as i64); @@ -1114,7 +1126,10 @@ impl DbClient for BigTableClientImpl { // therefore run two filters, one to fetch the candidate IDs // and another to fetch the content of the messages. */ - req.set_filter(filter_chain(message_gc_policy_filter()?)); + let mut filters = message_gc_policy_filter()?; + filters.push(family_filter(format!("^{MESSAGE_FAMILY}$"))); + + req.set_filter(filter_chain(filters)); if limit > 0 { req.set_rows_limit(limit as i64); } @@ -1300,7 +1315,8 @@ mod tests { assert_eq!(channels, new_channels); // can we remove a channel? - client.remove_channel(&uaid, &chid_to_remove).await?; + assert!(client.remove_channel(&uaid, &chid_to_remove).await?); + assert!(!client.remove_channel(&uaid, &chid_to_remove).await?); new_channels.remove(&chid_to_remove); let channels = client.get_channels(&uaid).await?; assert_eq!(channels, new_channels); @@ -1387,6 +1403,7 @@ mod tests { assert!(client.remove_channel(&uaid, &chid).await.is_ok()); // Now, can we do all that with topic messages + client.add_channel(&uaid, &topic_chid).await?; let test_data = "An_encrypted_pile_of_crap_with_a_topic".to_owned(); let timestamp = now(); let sort_key = now(); @@ -1465,7 +1482,8 @@ mod tests { }], ); client.write_row(row).await.unwrap(); - let Some(row) = client.read_row(&row_key).await.unwrap() else { + let req = client.read_row_request(&row_key); + let Some(row) = client.read_row(req).await.unwrap() else { panic!("Expected row"); }; assert_eq!(row.cells.len(), 1);