Skip to content

Commit

Permalink
Fix result set terminator handling
Browse files Browse the repository at this point in the history
  • Loading branch information
blackbeam committed Dec 10, 2022
1 parent bac4b8d commit 7ee5307
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 34 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ crossbeam = "0.8.1"
io-enum = "1.0.0"
flate2 = { version = "1.0", default-features = false }
lru = "0.8.1"
mysql_common = { version = "0.29.1", default-features = false }
mysql_common = { version = "0.29.2", default-features = false }
socket2 = "0.4"
once_cell = "1.7.2"
pem = "1.0.1"
Expand Down
54 changes: 24 additions & 30 deletions src/conn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use mysql_common::{
binlog_request::BinlogRequest, AuthPlugin, AuthSwitchRequest, Column, ComStmtClose,
ComStmtExecuteRequestBuilder, ComStmtSendLongData, CommonOkPacket, ErrPacket,
HandshakePacket, HandshakeResponse, OkPacket, OkPacketDeserializer, OkPacketKind,
OldAuthSwitchRequest, ResultSetTerminator, SessionStateInfo,
OldAuthSwitchRequest, OldEofPacket, ResultSetTerminator, SessionStateInfo,
},
proto::{codec::Compression, sync_framed::MySyncFramed, MySerialize},
row::{Row, RowDeserializer},
Expand Down Expand Up @@ -204,6 +204,11 @@ impl ConnInner {
pub struct Conn(Box<ConnInner>);

impl Conn {
/// Must not be called before handle_handshake.
const fn has_capability(&self, flag: CapabilityFlags) -> bool {
self.0.capability_flags.contains(flag)
}

/// Returns version number reported by the server.
pub fn server_version(&self) -> (u16, u16, u16) {
self.0
Expand Down Expand Up @@ -562,10 +567,7 @@ impl Conn {

if self.is_insecure() {
if let Some(ssl_opts) = self.0.opts.get_ssl_opts().cloned() {
if !handshake
.capabilities()
.contains(CapabilityFlags::CLIENT_SSL)
{
if !self.has_capability(CapabilityFlags::CLIENT_SSL) {
return Err(DriverError(TlsNotSupported));
} else {
self.do_ssl_request()?;
Expand Down Expand Up @@ -596,11 +598,7 @@ impl Conn {
self.write_handshake_response(&auth_plugin, auth_data.as_deref())?;
self.continue_auth(&auth_plugin, &*nonce, false)?;

if self
.0
.capability_flags
.contains(CapabilityFlags::CLIENT_COMPRESS)
{
if self.has_capability(CapabilityFlags::CLIENT_COMPRESS) {
self.switch_to_compressed();
}

Expand Down Expand Up @@ -1080,32 +1078,28 @@ impl Conn {
self.query_first(format!("SELECT @@{}", name))
}

fn next_bin(&mut self, columns: Arc<[Column]>) -> Result<Option<Row>> {
fn next_row_packet(&mut self) -> Result<Option<Buffer>> {
if !self.0.has_results {
return Ok(None);
}
let pld = self.read_packet()?;
if pld[0] == 0xfe && pld.len() < 0xfe {
self.0.has_results = false;
self.handle_ok::<ResultSetTerminator>(&pld)?;
return Ok(None);
}
let row = ParseBuf(&*pld).parse::<RowDeserializer<ServerSide, Binary>>(columns)?;
Ok(Some(row.into()))
}

fn next_text(&mut self, columns: Arc<[Column]>) -> Result<Option<Row>> {
if !self.0.has_results {
return Ok(None);
}
let pld = self.read_packet()?;
if pld[0] == 0xfe && pld.len() < 0xfe {
self.0.has_results = false;
self.handle_ok::<ResultSetTerminator>(&pld)?;
return Ok(None);

if self.has_capability(CapabilityFlags::CLIENT_DEPRECATE_EOF) {
if pld[0] == 0xfe && pld.len() < MAX_PAYLOAD_LEN {
self.0.has_results = false;
self.handle_ok::<ResultSetTerminator>(&pld)?;
return Ok(None);
}
} else {
if pld[0] == 0xfe && pld.len() < 8 {
self.0.has_results = false;
self.handle_ok::<OldEofPacket>(&pld)?;
return Ok(None);
}
}
let row = ParseBuf(&*pld).parse::<RowDeserializer<(), Text>>(columns)?;
Ok(Some(row.into()))

Ok(Some(pld))
}

fn has_stmt(&self, query: &[u8]) -> bool {
Expand Down
18 changes: 15 additions & 3 deletions src/conn/query_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

pub use mysql_common::proto::{Binary, Text};

use mysql_common::packets::OkPacket;
use mysql_common::{io::ParseBuf, packets::OkPacket, row::RowDeserializer, value::ServerSide};

use std::{borrow::Cow, marker::PhantomData, sync::Arc};

Expand All @@ -27,13 +27,25 @@ pub trait Protocol: 'static + Send + Sync {

impl Protocol for Text {
fn next(conn: &mut Conn, columns: Arc<[Column]>) -> Result<Option<Row>> {
conn.next_text(columns)
match conn.next_row_packet()? {
Some(pld) => {
let row = ParseBuf(&*pld).parse::<RowDeserializer<(), Text>>(columns)?;
Ok(Some(row.into()))
}
None => Ok(None),
}
}
}

impl Protocol for Binary {
fn next(conn: &mut Conn, columns: Arc<[Column]>) -> Result<Option<Row>> {
conn.next_bin(columns)
match conn.next_row_packet()? {
Some(pld) => {
let row = ParseBuf(&*pld).parse::<RowDeserializer<ServerSide, Binary>>(columns)?;
Ok(Some(row.into()))
}
None => Ok(None),
}
}
}

Expand Down

0 comments on commit 7ee5307

Please sign in to comment.