diff --git a/sqlx-postgres/src/arguments.rs b/sqlx-postgres/src/arguments.rs index 0d57f068b0..dd6f239f8a 100644 --- a/sqlx-postgres/src/arguments.rs +++ b/sqlx-postgres/src/arguments.rs @@ -64,13 +64,19 @@ impl PgArguments { where T: Encode<'q, Postgres> + Type, { - // remember the type information for this value - self.types - .push(value.produces().unwrap_or_else(T::type_info)); + let type_info = value.produces().unwrap_or_else(T::type_info); + + let buffer_snapshot = self.buffer.snapshot(); // encode the value into our buffer - self.buffer.encode(value)?; + if let Err(error) = self.buffer.encode(value) { + // reset the value buffer to its previous value if encoding failed so we don't leave a half-encoded value behind + self.buffer.reset_to_snapshot(buffer_snapshot); + return Err(error); + }; + // remember the type information for this value + self.types.push(type_info); // increment the number of arguments we are tracking self.buffer.count += 1; @@ -176,6 +182,44 @@ impl PgArgumentBuffer { self.extend_from_slice(&0_u32.to_be_bytes()); self.type_holes.push((offset, type_name.clone())); } + + fn snapshot(&self) -> PgArgumentBufferSnapshot { + let Self { + buffer, + count, + patches, + type_holes, + } = self; + + PgArgumentBufferSnapshot { + buffer_length: buffer.len(), + count: *count, + patches_length: patches.len(), + type_holes_length: type_holes.len(), + } + } + + fn reset_to_snapshot( + &mut self, + PgArgumentBufferSnapshot { + buffer_length, + count, + patches_length, + type_holes_length, + }: PgArgumentBufferSnapshot, + ) { + self.buffer.truncate(buffer_length); + self.count = count; + self.patches.truncate(patches_length); + self.type_holes.truncate(type_holes_length); + } +} + +struct PgArgumentBufferSnapshot { + buffer_length: usize, + count: usize, + patches_length: usize, + type_holes_length: usize, } impl Deref for PgArgumentBuffer {