Skip to content

Commit

Permalink
Improve ChecksummedBytes::extend and clarify data integrity guarantee
Browse files Browse the repository at this point in the history
ChecksummedBytes maintains a data buffer and a checksum and guarantees that only validated data can be accessed. Transformations such as `split_off`, `extend`, or `slice` (introduced in this change), may trigger a validation (and return an IntegrityError on failure), or propagate existing checksum(s) if possible, allowing for later validation.

This change clarifies the data integrity guarantee in the docs and optimizes the extend method to avoid re-validation when the checksums for both slices can be combined. It also avoid a redundant buffer allocation.

Signed-off-by: Alessandro Passaro <[email protected]>
  • Loading branch information
passaro committed Oct 25, 2023
1 parent 1ea6143 commit 711d37d
Showing 1 changed file with 189 additions and 76 deletions.
265 changes: 189 additions & 76 deletions mountpoint-s3/src/checksums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@ use std::ops::RangeBounds;

use bytes::{Bytes, BytesMut};
use mountpoint_s3_crt::checksums::crc32c::{self, Crc32c};

use thiserror::Error;

/// A `ChecksummedBytes` is a bytes buffer that carries its checksum.
/// The implementation guarantees that integrity will be validated before the data can be accessed.
/// Data transformations will either fail returning an [IntegrityError], or propagate the checksum
/// so that it can be validated on access.
#[derive(Clone, Debug)]
pub struct ChecksummedBytes {
orig_bytes: Bytes,
Expand All @@ -29,21 +33,21 @@ impl ChecksummedBytes {
Self::new(bytes, checksum)
}

/// Convert the `ChecksummedBytes` into `Bytes`, data integrity will be validated before converting.
/// Convert the [ChecksummedBytes] into [Bytes], data integrity will be validated before converting.
///
/// Return `IntegrityError` on data corruption.
/// Return [IntegrityError] on data corruption.
pub fn into_bytes(self) -> Result<Bytes, IntegrityError> {
self.validate()?;

Ok(self.curr_slice)
}

/// Returns the number of bytes contained in this `ChecksummedBytes`.
/// Returns the number of bytes contained in this [ChecksummedBytes].
pub fn len(&self) -> usize {
self.curr_slice.len()
}

/// Returns true if the `ChecksummedBytes` has a length of 0.
/// Returns true if the [ChecksummedBytes] has a length of 0.
pub fn is_empty(&self) -> bool {
self.curr_slice.is_empty()
}
Expand All @@ -63,47 +67,77 @@ impl ChecksummedBytes {
}
}

/// Append the given checksummed bytes to current `ChecksummedBytes`, ensure that data integrity will
/// be validated.
/// Returns a slice of self for the provided range.
///
/// This operation just increases the reference count and sets a few indices,
/// so there will be no validation and the checksum will not be recomputed.
pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
Self {
orig_bytes: self.orig_bytes.clone(),
curr_slice: self.curr_slice.slice(range),
checksum: self.checksum,
}
}

/// Returns a copy of this slice, with the guarantee that the checksum is computed exactly
/// on the slice, rather than on a larger containing buffer.
///
/// Return [IntegrityError] if data corruption is detected.
pub fn shrink_to_fit(&self) -> Result<Self, IntegrityError> {
if self.curr_slice.len() == self.orig_bytes.len() {
return Ok(self.clone());
}

let result = Self::from_bytes(self.curr_slice.clone());
self.validate()?;
Ok(result)
}

/// Append the given checksummed bytes to current [ChecksummedBytes]. Will combine the
/// existing checksums if possible, or compute a new one and validate data integrity.
///
/// Return `IntegrityError` on data corruption.
/// Return [IntegrityError] if data corruption is detected.
pub fn extend(&mut self, extend: ChecksummedBytes) -> Result<(), IntegrityError> {
let curr_len = self.curr_slice.len();
let total_len = curr_len + extend.len();

let mut bytes_mut = BytesMut::with_capacity(total_len);
bytes_mut.extend_from_slice(&self.curr_slice);
bytes_mut.extend_from_slice(&extend.curr_slice);
let new_bytes = bytes_mut.freeze();
let new_checksum = crc32c::checksum(&new_bytes);
let new_checksummed_bytes = ChecksummedBytes::new(new_bytes, new_checksum);

// Validate data integrity with checksum bracketing.
{
// 1. repeat the operation, which means copying into a new buffer in this case.
let mut bytes_mut_dup = BytesMut::with_capacity(total_len);
bytes_mut_dup.extend_from_slice(&self.curr_slice);
bytes_mut_dup.extend_from_slice(&extend.curr_slice);
let new_bytes_dup = bytes_mut_dup.freeze();
let new_checksum_dup = crc32c::checksum(&new_bytes_dup);

// 2. compare the checksum between the two transformations.
if new_checksum != new_checksum_dup {
return Err(IntegrityError::ChecksumMismatch(new_checksum, new_checksum_dup));
}

// 3. validate original buffers to make sure that the data we have copied are still valid.
self.validate()?;
if extend.is_empty() {
// No op, but check that `extend` was not corrupted
extend.validate()?;
return Ok(());
}

if self.is_empty() {
// Replace with `extend`, but check that `self` was not corrupted
self.validate()?;
*self = extend;
return Ok(());
}

*self = new_checksummed_bytes;
// When appending two slices, we can combine their checksums and obtain the new checksum
// without having to recompute it from the data.
// However, since a `ChecksummedBytes` potentially holds the checksum of some larger buffer,
// rather than the exact one for the slice, we need to first invoke `shrink_to_fit` on each
// slice and use the resulting exact checksums.
let prefix = self.shrink_to_fit()?;
assert_eq!(prefix.orig_bytes.len(), prefix.curr_slice.len());
let suffix = extend.shrink_to_fit()?;
assert_eq!(suffix.orig_bytes.len(), suffix.curr_slice.len());

// Combine the checksums.
let new_checksum = combine_checksums(prefix.checksum, suffix.checksum, suffix.len());

// Combine the slices.
let new_bytes = {
let mut bytes_mut = BytesMut::with_capacity(prefix.len() + suffix.len());
bytes_mut.extend_from_slice(&prefix.curr_slice);
bytes_mut.extend_from_slice(&suffix.curr_slice);
bytes_mut.freeze()
};
*self = ChecksummedBytes::new(new_bytes, new_checksum);
Ok(())
}

/// Validate data integrity in this `ChecksummedBytes`.
/// Validate data integrity in this [ChecksummedBytes].
///
/// Return `IntegrityError` on data corruption.
/// Return [IntegrityError] on data corruption.
pub fn validate(&self) -> Result<(), IntegrityError> {
let checksum = crc32c::checksum(&self.orig_bytes);
if self.checksum != checksum {
Expand Down Expand Up @@ -161,14 +195,10 @@ impl PartialEq for ChecksummedBytes {
return false;
}

if self.orig_bytes == other.orig_bytes && self.checksum == other.checksum {
return true;
}

let result = self.orig_bytes == other.orig_bytes && self.checksum == other.checksum;
self.validate().expect("should be valid");
other.validate().expect("should be valid");

true
result
}
}

Expand All @@ -192,8 +222,7 @@ mod tests {
fn test_into_bytes_integrity_error() {
let bytes = Bytes::from_static(b"some bytes");
let checksum = crc32c::checksum(&bytes);
let mut checksummed_bytes = ChecksummedBytes::new(bytes, checksum);
checksummed_bytes.orig_bytes = Bytes::from_static(b"new bytes");
let checksummed_bytes = ChecksummedBytes::new(Bytes::from_static(b"new bytes"), checksum);

let actual = checksummed_bytes.into_bytes();
assert!(matches!(actual, Err(IntegrityError::ChecksumMismatch(_, _))));
Expand All @@ -220,31 +249,80 @@ mod tests {
}

#[test]
fn test_extend() {
fn test_slice() {
let range = 3..7;
let bytes = Bytes::from_static(b"some bytes");
let expected = Bytes::from_static(b"some bytes extended");
let expected = bytes.clone();
let expected_slice = bytes.slice(range.clone());
let checksum = crc32c::checksum(&bytes);
let mut checksummed_bytes = ChecksummedBytes::new(bytes, checksum);
let original = ChecksummedBytes::new(bytes, checksum);
let slice = original.slice(range);

assert_eq!(expected, original.orig_bytes);
assert_eq!(expected, original.curr_slice);
assert_eq!(expected, slice.orig_bytes);
assert_eq!(expected_slice, slice.curr_slice);
assert_eq!(checksum, original.checksum);
assert_eq!(checksum, slice.checksum);
}

let extend = Bytes::from_static(b" extended");
let extend_checksum = crc32c::checksum(&extend);
let extend = ChecksummedBytes::new(extend, extend_checksum);
checksummed_bytes.extend(extend).unwrap();
#[test]
fn test_shrink_to_fit() {
let original = ChecksummedBytes::from_bytes(Bytes::from_static(b"some bytes"));
let unchanged = original.shrink_to_fit().unwrap();
assert_eq!(original.curr_slice, unchanged.curr_slice);
assert_eq!(original.orig_bytes, unchanged.orig_bytes);
assert_eq!(original.checksum, unchanged.checksum);

let slice = original.clone().split_off(5);
let shrunken = slice.shrink_to_fit().unwrap();
assert_eq!(slice.curr_slice, shrunken.curr_slice);
assert_ne!(slice.orig_bytes, shrunken.orig_bytes);
assert_ne!(slice.checksum, shrunken.checksum);
}

#[test]
fn test_shrink_to_fit_corrupted() {
let checksum = crc32c::checksum(b"some bytes");
let original = ChecksummedBytes::new(Bytes::from_static(b"other bytes"), checksum);
assert!(matches!(
original.validate(),
Err(IntegrityError::ChecksumMismatch(_, _))
));

let unchanged = original.shrink_to_fit().unwrap();
assert_eq!(original.curr_slice, unchanged.curr_slice);
assert_eq!(original.orig_bytes, unchanged.orig_bytes);
assert_eq!(original.checksum, unchanged.checksum);
assert!(matches!(
unchanged.validate(),
Err(IntegrityError::ChecksumMismatch(_, _))
));

let slice = original.clone().split_off(5);
assert!(matches!(slice.validate(), Err(IntegrityError::ChecksumMismatch(_, _))));

let result = slice.shrink_to_fit();
assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
}

#[test]
fn test_extend() {
let expected = Bytes::from_static(b"some bytes extended");
let mut checksummed_bytes = ChecksummedBytes::from_bytes(Bytes::from_static(b"some bytes"));
let extend_bytes = ChecksummedBytes::from_bytes(Bytes::from_static(b" extended"));
checksummed_bytes.extend(extend_bytes).unwrap();
let actual = checksummed_bytes.curr_slice;
assert_eq!(expected, actual);
}

#[test]
fn test_extend_after_split() {
let split_off_at = 4;
let bytes = Bytes::from_static(b"some bytes");
let expected = Bytes::from_static(b"some ext");
let checksum = crc32c::checksum(&bytes);
let mut checksummed_bytes = ChecksummedBytes::new(bytes, checksum);

let extend = Bytes::from_static(b" extended");
let extend_checksum = crc32c::checksum(&extend);
let mut extend = ChecksummedBytes::new(extend, extend_checksum);
let expected = Bytes::from_static(b"some ext");
let mut checksummed_bytes = ChecksummedBytes::from_bytes(Bytes::from_static(b"some bytes"));
let mut extend = ChecksummedBytes::from_bytes(Bytes::from_static(b" extended"));
checksummed_bytes.split_off(split_off_at);
extend.split_off(split_off_at);
checksummed_bytes.extend(extend).unwrap();
Expand All @@ -254,34 +332,69 @@ mod tests {

#[test]
fn test_extend_self_corrupted() {
let bytes = Bytes::from_static(b"some bytes");
let checksum = crc32c::checksum(&bytes);
let mut checksummed_bytes = ChecksummedBytes::new(bytes, checksum);
let corrupted_bytes = Bytes::from_static(b"corrupted data");
let checksum = crc32c::checksum(b"some bytes");
let mut checksummed_bytes = ChecksummedBytes::new(corrupted_bytes, checksum);
assert!(matches!(
checksummed_bytes.validate(),
Err(IntegrityError::ChecksumMismatch(_, _))
));

let extend = ChecksummedBytes::from_bytes(Bytes::from_static(b" extended"));
assert!(matches!(extend.validate(), Ok(())));

let currupted_bytes = Bytes::from_static(b"corrupted data");
checksummed_bytes.orig_bytes = currupted_bytes.clone();
checksummed_bytes.curr_slice = currupted_bytes;
checksummed_bytes.extend(extend).unwrap();
assert!(matches!(
checksummed_bytes.validate(),
Err(IntegrityError::ChecksumMismatch(_, _))
));
}

#[test]
fn test_extend_after_split_self_corrupted() {
let corrupted_bytes = Bytes::from_static(b"corrupted data");
let checksum = crc32c::checksum(b"some bytes");
let mut checksummed_bytes = ChecksummedBytes::new(corrupted_bytes, checksum);
assert!(matches!(
checksummed_bytes.validate(),
Err(IntegrityError::ChecksumMismatch(_, _))
));
checksummed_bytes.split_off(4);

let extend = ChecksummedBytes::from_bytes(Bytes::from_static(b" extended"));
assert!(matches!(extend.validate(), Ok(())));

let extend = Bytes::from_static(b" extended");
let extend_checksum = crc32c::checksum(&extend);
let extend = ChecksummedBytes::new(extend, extend_checksum);
let result = checksummed_bytes.extend(extend);
assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
}

#[test]
fn test_extend_other_corrupted() {
let bytes = Bytes::from_static(b"some bytes");
let checksum = crc32c::checksum(&bytes);
let mut checksummed_bytes = ChecksummedBytes::new(bytes, checksum);
let mut checksummed_bytes = ChecksummedBytes::from_bytes(Bytes::from_static(b"some bytes"));
assert!(matches!(checksummed_bytes.validate(), Ok(())));

let extend = Bytes::from_static(b" extended");
let extend_checksum = crc32c::checksum(&extend);
let mut extend = ChecksummedBytes::new(extend, extend_checksum);
let corrupted_bytes = Bytes::from_static(b"corrupted data");
let extend_checksum = crc32c::checksum(b" extended");
let extend = ChecksummedBytes::new(corrupted_bytes, extend_checksum);
assert!(matches!(extend.validate(), Err(IntegrityError::ChecksumMismatch(_, _))));

let currupted_bytes = Bytes::from_static(b"corrupted data");
extend.orig_bytes = currupted_bytes.clone();
extend.curr_slice = currupted_bytes;
checksummed_bytes.extend(extend).unwrap();
assert!(matches!(
checksummed_bytes.validate(),
Err(IntegrityError::ChecksumMismatch(_, _))
));
}

#[test]
fn test_extend_after_split_other_corrupted() {
let mut checksummed_bytes = ChecksummedBytes::from_bytes(Bytes::from_static(b"some bytes"));
assert!(matches!(checksummed_bytes.validate(), Ok(())));

let corrupted_bytes = Bytes::from_static(b"corrupted data");
let extend_checksum = crc32c::checksum(b" extended");
let mut extend = ChecksummedBytes::new(corrupted_bytes, extend_checksum);
extend.split_off(4);
assert!(matches!(extend.validate(), Err(IntegrityError::ChecksumMismatch(_, _))));

let result = checksummed_bytes.extend(extend);
assert!(matches!(result, Err(IntegrityError::ChecksumMismatch(_, _))));
Expand Down

0 comments on commit 711d37d

Please sign in to comment.