diff --git a/mountpoint-s3-crt/CHANGELOG.md b/mountpoint-s3-crt/CHANGELOG.md index e98ca69a7..d49b44c11 100644 --- a/mountpoint-s3-crt/CHANGELOG.md +++ b/mountpoint-s3-crt/CHANGELOG.md @@ -4,6 +4,9 @@ * Checksum hashers no longer implement `std::hash::Hasher`. ([#1082](https://github.com/awslabs/mountpoint-s3/pull/1082)) * Add bindings to remaining checksum types CRC64, SHA1, and SHA256. ([#1082](https://github.com/awslabs/mountpoint-s3/pull/1082)) * Add wrapping type `ByteBuf` for `aws_byte_buf`. ([#1082](https://github.com/awslabs/mountpoint-s3/pull/1082)) +* `HeadersError::HeaderNotFound` and `HeadersError::Invalid` variants now include the name of the header. + Despite the new field being private, this may impact any code that was pattern matching on these variants. + ([#1205](https://github.com/awslabs/mountpoint-s3/pull/1205)) ## v0.10.0 (October 17, 2024) diff --git a/mountpoint-s3-crt/src/http/request_response.rs b/mountpoint-s3-crt/src/http/request_response.rs index 06cce753b..5c9e7d966 100644 --- a/mountpoint-s3-crt/src/http/request_response.rs +++ b/mountpoint-s3-crt/src/http/request_response.rs @@ -62,29 +62,38 @@ unsafe impl Send for Headers {} /// allow threads to simultaneously modify it. unsafe impl Sync for Headers {} -/// Errors returned by operations on [Headers] +/// Errors returned by operations on [Headers]. +/// +/// TODO: Where the variant contains an [OsString] for the header name, +/// we could explore using a static [OsStr] to avoid unnecessary memory copies +/// since we know the values at compilation time. #[derive(Debug, Error, PartialEq, Eq)] pub enum HeadersError { /// The header was not found - #[error("Header not found")] - HeaderNotFound, + #[error("Header {0:?} not found")] + HeaderNotFound(OsString), /// Internal CRT error #[error("CRT error: {0}")] CrtError(#[source] Error), /// Header value could not be converted to String - #[error("Header string was not valid: {0:?}")] - Invalid(OsString), + #[error("Header {name:?} had invalid string value: {value:?}")] + Invalid { + /// Name of the header + name: OsString, + /// Value of the header, which was not valid to convert to [String] + value: OsString, + }, } -// Convert CRT error into HeadersError, mapping the HEADER_NOT_FOUND to HeadersError::HeaderNotFound. -impl From for HeadersError { - fn from(err: Error) -> Self { - if err == (aws_http_errors::AWS_ERROR_HTTP_HEADER_NOT_FOUND as i32).into() { - Self::HeaderNotFound +impl HeadersError { + /// Try to convert the CRT [Error] into [HeadersError::HeaderNotFound], or return [HeadersError::CrtError]. + fn try_convert(err: Error, header_name: &OsStr) -> HeadersError { + if err.raw_error() == (aws_http_errors::AWS_ERROR_HTTP_HEADER_NOT_FOUND as i32) { + HeadersError::HeaderNotFound(header_name.to_owned()) } else { - Self::CrtError(err) + HeadersError::CrtError(err) } } } @@ -105,7 +114,11 @@ impl Headers { /// Create a new [Headers] object in the given allocator. pub fn new(allocator: &Allocator) -> Result { // SAFETY: allocator is a valid aws_allocator, and we check the return is non-null. - let inner = unsafe { aws_http_headers_new(allocator.inner.as_ptr()).ok_or_last_error()? }; + let inner = unsafe { + aws_http_headers_new(allocator.inner.as_ptr()) + .ok_or_last_error() + .map_err(HeadersError::CrtError)? + }; Ok(Self { inner }) } @@ -118,12 +131,14 @@ impl Headers { } /// Get the header at the specified index. - pub fn get_index(&self, index: usize) -> Result, HeadersError> { + fn get_index(&self, index: usize) -> Result, HeadersError> { // SAFETY: `self.inner` is a valid aws_http_headers, and `aws_http_headers_get_index` // promises to initialize the output `struct aws_http_header *out_header` on success. let header = unsafe { let mut header: MaybeUninit = MaybeUninit::uninit(); - aws_http_headers_get_index(self.inner.as_ptr(), index, header.as_mut_ptr()).ok_or_last_error()?; + aws_http_headers_get_index(self.inner.as_ptr(), index, header.as_mut_ptr()) + .ok_or_last_error() + .map_err(HeadersError::CrtError)?; header.assume_init() }; @@ -153,7 +168,9 @@ impl Headers { // SAFETY: `aws_http_headers_add_header` makes a copy of the underlying strings. // Also, this function takes a mut reference to `self`, since this function modifies the headers. unsafe { - aws_http_headers_add_header(self.inner.as_ptr(), &header.inner).ok_or_last_error()?; + aws_http_headers_add_header(self.inner.as_ptr(), &header.inner) + .ok_or_last_error() + .map_err(HeadersError::CrtError)?; } Ok(()) @@ -171,7 +188,9 @@ impl Headers { // SAFETY: `aws_http_headers_erase` doesn't hold on to a copy of the name we pass in, so it's // okay to call with with an `aws_byte_cursor` that may not outlive this `Headers`. unsafe { - aws_http_headers_erase(self.inner.as_ptr(), name.as_ref().as_aws_byte_cursor()).ok_or_last_error()?; + aws_http_headers_erase(self.inner.as_ptr(), name.as_ref().as_aws_byte_cursor()) + .ok_or_last_error() + .map_err(|err| HeadersError::try_convert(err, name.as_ref()))?; } Ok(()) @@ -183,12 +202,15 @@ impl Headers { // initialize the output `struct aws_byte_cursor *out_value` on success. let value = unsafe { let mut value: MaybeUninit = MaybeUninit::uninit(); + aws_http_headers_get( self.inner.as_ptr(), name.as_ref().as_aws_byte_cursor(), value.as_mut_ptr(), ) - .ok_or_last_error()?; + .ok_or_last_error() + .map_err(|err| HeadersError::try_convert(err, name.as_ref()))?; + value.assume_init() }; @@ -203,12 +225,17 @@ impl Headers { /// Get a single header by name as a [String]. pub fn get_as_string>(&self, name: H) -> Result { + let name = name.as_ref(); let header = self.get(name)?; let value = header.value(); if let Some(s) = value.to_str() { Ok(s.to_string()) } else { - Err(HeadersError::Invalid(value.clone())) + let err = HeadersError::Invalid { + name: name.to_owned(), + value: value.clone(), + }; + Err(err) } } @@ -263,7 +290,7 @@ impl Iterator for HeadersIterator<'_> { let header = self .headers .get_index(self.offset) - .expect("HeadersIterator: failed to get next header"); + .expect("headers at any offset smaller than original count should always exist given mut access"); self.offset += 1; Some((header.name, header.value)) @@ -417,14 +444,29 @@ mod test { #[test] fn test_header_not_present() { let headers = Headers::new(&Allocator::default()).expect("failed to create headers"); - assert!(!headers.has_header("a")); + + assert!(!headers.has_header("a"), "header should not be present"); + let error = headers.get("a").expect_err("should fail because header is not present"); - assert_eq!(error, HeadersError::HeaderNotFound, "should fail with HeaderNotFound"); + assert_eq!( + error.to_string(), + "Header \"a\" not found", + "header error display should match expected output", + ); + if let HeadersError::HeaderNotFound(name) = error { + assert_eq!(name, "a", "header name should match original argument"); + } else { + panic!("should fail with HeaderNotFound"); + } let error = headers .get_as_string("a") .expect_err("should fail because header is not present"); - assert_eq!(error, HeadersError::HeaderNotFound, "should fail with HeaderNotFound"); + if let HeadersError::HeaderNotFound(name) = error { + assert_eq!(name, "a", "header name should match original argument"); + } else { + panic!("should fail with HeaderNotFound"); + } let header = headers .get_as_optional_string("a")