Fix/batch length #824

Merged (7 commits) on Oct 17, 2023

Showing changes from all commits.
2 changes: 1 addition & 1 deletion .github/workflows/cassandra.yml
@@ -29,7 +29,7 @@ jobs:
run: cargo build --verbose --tests
- name: Run tests on cassandra
run: |
-CDC='disabled' SCYLLA_URI=172.42.0.2:9042 SCYLLA_URI2=172.42.0.3:9042 SCYLLA_URI3=172.42.0.4:9042 cargo test --verbose -- --skip test_views_in_schema_info
+CDC='disabled' SCYLLA_URI=172.42.0.2:9042 SCYLLA_URI2=172.42.0.3:9042 SCYLLA_URI3=172.42.0.4:9042 cargo test --verbose -- --skip test_views_in_schema_info --skip test_large_batch_statements
- name: Stop the cluster
if: ${{ always() }}
run: docker compose -f test/cluster/cassandra/docker-compose.yml stop
4 changes: 4 additions & 0 deletions scylla-cql/src/errors.rs
@@ -348,6 +348,10 @@ pub enum BadQuery {
#[error("Passed invalid keyspace name to use: {0}")]
BadKeyspaceName(#[from] BadKeyspaceName),

+/// Too many queries in the batch statement
+#[error("Number of Queries in Batch Statement supplied is {0} which has exceeded the max value of 65,535")]
+TooManyQueriesInBatchStatement(usize),

/// Other reasons of bad query
#[error("{0}")]
Other(String),
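
For illustration (not part of the diff), a caller can match the new variant like any other BadQuery. A minimal sketch; explain_error is a hypothetical helper, not driver API:

use scylla_cql::errors::{BadQuery, QueryError};

fn explain_error(err: &QueryError) {
    match err {
        // The new variant carries the offending statement count.
        QueryError::BadQuery(BadQuery::TooManyQueriesInBatchStatement(n)) => {
            eprintln!("batch rejected: {n} statements exceeds the protocol maximum of 65,535")
        }
        other => eprintln!("query failed: {other}"),
    }
}
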
2 changes: 1 addition & 1 deletion scylla-cql/src/frame/frame_errors.rs
@@ -40,7 +40,7 @@ pub enum ParseError {
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error("type not yet implemented, id: {0}")]
-TypeNotImplemented(i16),
+TypeNotImplemented(u16),
#[error(transparent)]
SerializeValuesError(#[from] SerializeValuesError),
#[error(transparent)]
2 changes: 1 addition & 1 deletion scylla-cql/src/frame/request/batch.rs
@@ -190,7 +190,7 @@ impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>, Vec<SerializedV
fn deserialize(buf: &mut &[u8]) -> Result<Self, ParseError> {
let batch_type = buf.get_u8().try_into()?;

-let statements_count: usize = types::read_short(buf)?.try_into()?;
+let statements_count: usize = types::read_short(buf)?.into();
let statements_with_values = (0..statements_count)
.map(|_| {
let batch_statement = BatchStatement::deserialize(buf)?;
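
Since read_short now returns u16 (see scylla-cql/src/frame/types.rs below), widening to usize is infallible, so the fallible try_into and its error path can be dropped. A minimal standalone sketch of the difference:

fn main() {
    // u16 -> usize always fits (usize is at least 16 bits), so Into applies:
    let count: u16 = u16::MAX;
    let n: usize = count.into(); // infallible, no ParseError possible
    assert_eq!(n, 65_535);

    // The old i16-based code needed the fallible conversion, because a
    // negative i16 has no usize representation:
    let old: i16 = -1;
    assert!(usize::try_from(old).is_err());
}
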
4 changes: 2 additions & 2 deletions scylla-cql/src/frame/response/result.rs
@@ -437,7 +437,7 @@ fn deser_type(buf: &mut &[u8]) -> StdResult<ColumnType, ParseError> {
0x0030 => {
let keyspace_name: String = types::read_string(buf)?.to_string();
let type_name: String = types::read_string(buf)?.to_string();
-let fields_size: usize = types::read_short(buf)?.try_into()?;
+let fields_size: usize = types::read_short(buf)?.into();

let mut field_types: Vec<(String, ColumnType)> = Vec::with_capacity(fields_size);

@@ -455,7 +455,7 @@
}
}
0x0031 => {
-let len: usize = types::read_short(buf)?.try_into()?;
+let len: usize = types::read_short(buf)?.into();
let mut types = Vec::with_capacity(len);
for _ in 0..len {
types.push(deser_type(buf)?);
20 changes: 10 additions & 10 deletions scylla-cql/src/frame/types.rs
@@ -16,7 +16,7 @@ use uuid::Uuid;
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, TryFromPrimitive)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "SCREAMING_SNAKE_CASE"))]
-#[repr(i16)]
+#[repr(u16)]
pub enum Consistency {
Any = 0x0000,
One = 0x0001,
@@ -169,30 +169,30 @@ fn type_long() {
}
}

-pub fn read_short(buf: &mut &[u8]) -> Result<i16, ParseError> {
-let v = buf.read_i16::<BigEndian>()?;
+pub fn read_short(buf: &mut &[u8]) -> Result<u16, ParseError> {
+let v = buf.read_u16::<BigEndian>()?;
Ok(v)
}

-pub fn write_short(v: i16, buf: &mut impl BufMut) {
-buf.put_i16(v);
+pub fn write_short(v: u16, buf: &mut impl BufMut) {
+buf.put_u16(v);
}

pub(crate) fn read_short_length(buf: &mut &[u8]) -> Result<usize, ParseError> {
let v = read_short(buf)?;
-let v: usize = v.try_into()?;
+let v: usize = v.into();
Ok(v)
}

fn write_short_length(v: usize, buf: &mut impl BufMut) -> Result<(), ParseError> {
-let v: i16 = v.try_into()?;
+let v: u16 = v.try_into()?;
write_short(v, buf);
Ok(())
}

#[test]
fn type_short() {
-let vals = [i16::MIN, -1, 0, 1, i16::MAX];
+let vals: [u16; 3] = [0, 1, u16::MAX];
for val in vals.iter() {
let mut buf = Vec::new();
write_short(*val, &mut buf);
@@ -464,11 +464,11 @@ pub fn read_consistency(buf: &mut &[u8]) -> Result<Consistency, ParseError> {
}

pub fn write_consistency(c: Consistency, buf: &mut impl BufMut) {
-write_short(c as i16, buf);
+write_short(c as u16, buf);
}

pub fn write_serial_consistency(c: SerialConsistency, buf: &mut impl BufMut) {
-write_short(c as i16, buf);
+write_short(c as u16, buf);
}

#[test]
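
The CQL binary protocol defines [short] as a two-byte unsigned integer, so u16 matches the wire format and the full 0..=65,535 range now round-trips. A minimal sketch using the same byteorder/bytes calls as this file:

use byteorder::{BigEndian, ReadBytesExt};
use bytes::BufMut;

fn main() {
    // Write 65,535 as a big-endian unsigned [short]...
    let mut buf: Vec<u8> = Vec::new();
    buf.put_u16(u16::MAX);
    assert_eq!(buf, [0xFF, 0xFF]);

    // ...and read it back. With the old i16 representation this value came
    // back as -1 and then failed the try_into conversion to usize.
    let mut slice: &[u8] = &buf;
    assert_eq!(slice.read_u16::<BigEndian>().unwrap(), 65_535);
}
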
12 changes: 6 additions & 6 deletions scylla-cql/src/frame/value.rs
@@ -63,7 +63,7 @@ pub struct Time(pub Duration);
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct SerializedValues {
serialized_values: Vec<u8>,
-values_num: i16,
+values_num: u16,
contains_names: bool,
}

Expand All @@ -77,7 +77,7 @@ pub struct CqlDuration {

#[derive(Debug, Error, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SerializeValuesError {
#[error("Too many values to add, max 32 767 values can be sent in a request")]
#[error("Too many values to add, max 65,535 values can be sent in a request")]
TooManyValues,
#[error("Mixing named and not named values is not allowed")]
MixingNamedAndNotNamedValues,
@@ -134,7 +134,7 @@ impl SerializedValues {
if self.contains_names {
return Err(SerializeValuesError::MixingNamedAndNotNamedValues);
}
-if self.values_num == i16::MAX {
+if self.values_num == u16::MAX {
return Err(SerializeValuesError::TooManyValues);
}
Comment on lines +137 to 139 (Contributor):

The message for TooManyValues has to be adjusted as well; it currently still mentions the i16::MAX limit of 32 767 values:

#[error("Too many values to add, max 32 767 values can be sent in a request")]
TooManyValues,

@@ -158,7 +158,7 @@
return Err(SerializeValuesError::MixingNamedAndNotNamedValues);
}
self.contains_names = true;
-if self.values_num == i16::MAX {
+if self.values_num == u16::MAX {
return Err(SerializeValuesError::TooManyValues);
}

@@ -184,15 +184,15 @@
}

pub fn write_to_request(&self, buf: &mut impl BufMut) {
-buf.put_i16(self.values_num);
+buf.put_u16(self.values_num);
buf.put(&self.serialized_values[..]);
}

pub fn is_empty(&self) -> bool {
self.values_num == 0
}

-pub fn len(&self) -> i16 {
+pub fn len(&self) -> u16 {
self.values_num
}

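
A short sketch of the effect on the wire (assuming SerializedValues::new and add_value as defined in this file): the value count is now serialized as an unsigned [short], so a statement can carry up to 65,535 values instead of 32,767.

#[test]
fn values_count_is_unsigned_short() {
    let mut values = SerializedValues::new();
    values.add_value(&7i32).unwrap(); // values_num is tracked as u16 now

    let mut buf: Vec<u8> = Vec::new();
    values.write_to_request(&mut buf);
    assert_eq!(&buf[0..2], &[0x00, 0x01]); // one value, big-endian u16 count
}
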
2 changes: 1 addition & 1 deletion scylla/src/statement/prepared_statement.rs
@@ -339,7 +339,7 @@ impl PreparedStatement {
#[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)]
pub enum PartitionKeyExtractionError {
#[error("No value with given pk_index! pk_index: {0}, values.len(): {1}")]
-NoPkIndexValue(u16, i16),
+NoPkIndexValue(u16, u16),
}

#[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)]
68 changes: 68 additions & 0 deletions scylla/src/transport/large_batch_statements_test.rs
@@ -0,0 +1,68 @@
use assert_matches::assert_matches;

use scylla_cql::errors::{BadQuery, QueryError};

use crate::batch::BatchType;
use crate::query::Query;
use crate::{
batch::Batch,
test_utils::{create_new_session_builder, unique_keyspace_name},
QueryResult, Session,
};

#[tokio::test]
async fn test_large_batch_statements() {
let mut session = create_new_session_builder().build().await.unwrap();

let ks = unique_keyspace_name();
session = create_test_session(session, &ks).await;

let max_queries = u16::MAX as usize;
let batch_insert_result = write_batch(&session, max_queries, &ks).await;

batch_insert_result.unwrap();

let too_many_queries = u16::MAX as usize + 1;
let batch_insert_result = write_batch(&session, too_many_queries, &ks).await;
assert_matches!(
batch_insert_result.unwrap_err(),
QueryError::BadQuery(BadQuery::TooManyQueriesInBatchStatement(_too_many_queries)) if _too_many_queries == too_many_queries
)
}

async fn create_test_session(session: Session, ks: &String) -> Session {
session
.query(
format!("CREATE KEYSPACE {} WITH REPLICATION = {{ 'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1 }}",ks),
&[],
)
.await.unwrap();
session
.query(
format!(
"CREATE TABLE {}.pairs (dummy int, k blob, v blob, primary key (dummy, k))",
ks
),
&[],
)
.await
.unwrap();
session
}

async fn write_batch(session: &Session, n: usize, ks: &String) -> Result<QueryResult, QueryError> {
let mut batch_query = Batch::new(BatchType::Unlogged);
let mut batch_values = Vec::new();
let query = format!("INSERT INTO {}.pairs (dummy, k, v) VALUES (0, ?, ?)", ks);
let query = Query::new(query);
let prepared_statement = session.prepare(query).await.unwrap();
for i in 0..n {
let mut key = vec![0];
key.extend(i.to_be_bytes().as_slice());
let value = key.clone();
let values = vec![key, value];
batch_values.push(values);
batch_query.append_statement(prepared_statement.clone());
}
session.batch(&batch_query, batch_values).await
}
2 changes: 2 additions & 0 deletions scylla/src/transport/mod.rs
@@ -35,6 +35,8 @@ mod silent_prepare_batch_test;
mod cql_types_test;
#[cfg(test)]
mod cql_value_test;
+#[cfg(test)]
+mod large_batch_statements_test;

pub use cluster::ClusterData;
pub use node::{KnownNode, Node, NodeAddr, NodeRef};
8 changes: 8 additions & 0 deletions scylla/src/transport/session.rs
@@ -76,6 +76,7 @@ pub use crate::transport::connection_pool::PoolSize;
use crate::authentication::AuthenticatorProvider;
#[cfg(feature = "ssl")]
use openssl::ssl::SslContext;
+use scylla_cql::errors::BadQuery;

/// Translates IP addresses received from ScyllaDB nodes into locally reachable addresses.
///
@@ -1143,6 +1144,13 @@ impl Session {
// Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard
// If users batch statements by shard, they will be rewarded with full shard awareness

+// check to ensure that we don't send a batch statement with more than u16::MAX queries
+let batch_statements_length = batch.statements.len();
+if batch_statements_length > u16::MAX as usize {
+return Err(QueryError::BadQuery(
+BadQuery::TooManyQueriesInBatchStatement(batch_statements_length),
+));
+}
// Extract first serialized_value
let first_serialized_value = values.batch_values_iter().next_serialized().transpose()?;
let first_serialized_value = first_serialized_value.as_deref();