Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pgsrv): Implement start of postgres extended query protocol #117

Merged
merged 12 commits into from
Sep 27, 2022
90 changes: 88 additions & 2 deletions crates/pgsrv/src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,80 @@ impl PgCodec {
})
}

fn decode_parse(buf: &mut Cursor<'_>) -> Result<FrontendMessage> {
let name = buf.read_cstring()?.to_string();
let sql = buf.read_cstring()?.to_string();
let num_params = buf.get_i16() as usize;
let mut param_types = Vec::with_capacity(num_params);
for _ in 0..num_params {
param_types.push(buf.get_i32());
}
Comment on lines +155 to +158
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neat snippet to use when setting up a vec with known size. Not sure how exactly this compares though with respect to performance.

And not a all necessary, just something I thought I would share.

Suggested change
let mut param_types = Vec::with_capacity(num_params);
for _ in 0..num_params {
param_types.push(buf.get_i32());
}
let param_types = (0..num_params).map(|| buf.get_i32()).collect();

Ok(FrontendMessage::Parse {
name,
sql,
param_types,
})
}

fn decode_bind(buf: &mut Cursor<'_>) -> Result<FrontendMessage> {
let portal = buf.read_cstring()?.to_string();
let statement = buf.read_cstring()?.to_string();

let num_params = buf.get_i16() as usize;
let mut param_formats = Vec::with_capacity(num_params);
for _ in 0..num_params {
param_formats.push(buf.get_i16());
}

let num_values = buf.get_i16() as usize; // must match num_params
justinrubek marked this conversation as resolved.
Show resolved Hide resolved
let mut param_values = Vec::with_capacity(num_values);
for _ in 0..num_values {
let len = buf.get_i32();
if len == -1 {
param_values.push(None);
} else {
let mut val = vec![0; len as usize];
buf.copy_to_slice(&mut val);
param_values.push(Some(val));
}
}

let num_params = buf.get_i16() as usize;
let mut result_formats = Vec::with_capacity(num_params);
for _ in 0..num_params {
result_formats.push(buf.get_i16());
}

Ok(FrontendMessage::Bind {
portal,
statement,
param_formats,
param_values,
result_formats,
})
}

fn decode_describe(buf: &mut Cursor<'_>) -> Result<FrontendMessage> {
let object_type = buf.get_u8().try_into()?;
let name = buf.read_cstring()?.to_string();

Ok(FrontendMessage::Describe { object_type, name })
}

fn decode_execute(buf: &mut Cursor<'_>) -> Result<FrontendMessage> {
let portal = buf.read_cstring()?.to_string();
let max_rows = buf.get_i32();
Ok(FrontendMessage::Execute { portal, max_rows })
}

fn decode_sync(_buf: &mut Cursor<'_>) -> Result<FrontendMessage> {
Ok(FrontendMessage::Sync)
}

fn decode_terminate(_buf: &mut Cursor<'_>) -> Result<FrontendMessage> {
Ok(FrontendMessage::Terminate)
}

fn encode_scalar_as_text(scalar: ScalarValue, buf: &mut BytesMut) -> Result<()> {
if scalar.is_null() {
buf.put_i32(-1);
Expand Down Expand Up @@ -187,6 +261,9 @@ impl Encoder<BackendMessage> for PgCodec {
BackendMessage::DataRow(_, _) => b'D',
BackendMessage::ErrorResponse(_) => b'E',
BackendMessage::NoticeResponse(_) => b'N',
BackendMessage::ParseComplete => b'1',
BackendMessage::BindComplete => b'2',
BackendMessage::NoData => b'n',
};
dst.put_u8(byte);

Expand All @@ -198,6 +275,9 @@ impl Encoder<BackendMessage> for PgCodec {
BackendMessage::AuthenticationOk => dst.put_i32(0),
BackendMessage::AuthenticationCleartextPassword => dst.put_i32(3),
BackendMessage::EmptyQueryResponse => (),
BackendMessage::ParseComplete => (),
BackendMessage::BindComplete => (),
BackendMessage::NoData => (),
BackendMessage::ParameterStatus { key, val } => {
dst.put_cstring(&key);
dst.put_cstring(&val);
Expand Down Expand Up @@ -284,8 +364,8 @@ impl Decoder for PgCodec {
let msg_len = i32::from_be_bytes(src[1..5].try_into().unwrap()) as usize;

// Not enough bytes to read the full message yet.
if src.len() < msg_len {
src.reserve(msg_len - src.len());
if src.len() < msg_len + 1 {
src.reserve(msg_len + 1 - src.len());
return Ok(None);
}

Expand All @@ -296,6 +376,12 @@ impl Decoder for PgCodec {
let msg = match msg_type {
b'Q' => Self::decode_query(&mut buf)?,
b'p' => Self::decode_password(&mut buf)?,
b'P' => Self::decode_parse(&mut buf)?,
b'B' => Self::decode_bind(&mut buf)?,
b'D' => Self::decode_describe(&mut buf)?,
b'E' => Self::decode_execute(&mut buf)?,
b'S' => Self::decode_sync(&mut buf)?,
b'X' => return Ok(None),
justinrubek marked this conversation as resolved.
Show resolved Hide resolved
other => return Err(PgSrvError::InvalidMsgType(other)),
};

Expand Down
3 changes: 3 additions & 0 deletions crates/pgsrv/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ pub enum PgSrvError {
#[error("missing null byte")]
MissingNullByte,

#[error("unexpected describe object type: {0}")]
UnexpectedDescribeObjectType(u8),

/// We've received an unexpected message identifier from the frontend.
/// Includes the char representation to allow for easy cross referencing
/// with the Postgres message format documentation.
Expand Down
Loading