From 8384730603cec3e7b984f9ed51b6fa39fbdaa79c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Sun, 10 Nov 2024 16:02:42 +0100 Subject: [PATCH] codewide: introduce DeserializeOwned{Row,Value} As noted in a review comment, parsing `for<'r> DeserializeValue<'r, 'r>` to understand it's requiring an owned type is nontrivial and could be replaced with a subtrait with an informative name. Therefore, this commit introduces DeserializeOwnedRow and DeserializeOwnedValue (to be used by the `scylla` crate itself only). --- scylla/src/lib.rs | 7 +++++++ scylla/src/transport/cql_collections_test.rs | 4 ++-- scylla/src/transport/cql_types_test.rs | 6 +++--- scylla/src/transport/iterator.rs | 3 ++- scylla/src/transport/session_test.rs | 4 ++-- scylla/src/transport/topology.rs | 4 ++-- 6 files changed, 18 insertions(+), 10 deletions(-) diff --git a/scylla/src/lib.rs b/scylla/src/lib.rs index bac3fd3f9..715fe8d4d 100644 --- a/scylla/src/lib.rs +++ b/scylla/src/lib.rs @@ -230,6 +230,13 @@ pub mod deserialize { UdtIterator, UdtTypeCheckErrorKind, }; } + + // Shorthands for better readability. + #[cfg_attr(not(test), allow(unused))] + pub(crate) trait DeserializeOwnedValue: for<'r> DeserializeValue<'r, 'r> {} + impl DeserializeOwnedValue for T where T: for<'r> DeserializeValue<'r, 'r> {} + pub(crate) trait DeserializeOwnedRow: for<'r> DeserializeRow<'r, 'r> {} + impl DeserializeOwnedRow for T where T: for<'r> DeserializeRow<'r, 'r> {} } pub mod authentication; diff --git a/scylla/src/transport/cql_collections_test.rs b/scylla/src/transport/cql_collections_test.rs index f37d28a8f..475bd47ee 100644 --- a/scylla/src/transport/cql_collections_test.rs +++ b/scylla/src/transport/cql_collections_test.rs @@ -1,5 +1,5 @@ +use crate::deserialize::DeserializeOwnedValue; use crate::transport::session::Session; -use scylla_cql::types::deserialize::value::DeserializeValue; use crate::frame::response::result::CqlValue; use crate::test_utils::{create_new_session_builder, setup_tracing}; @@ -36,7 +36,7 @@ async fn insert_and_select( expected: &SelectT, ) where InsertT: SerializeValue, - SelectT: for<'r> DeserializeValue<'r, 'r> + PartialEq + std::fmt::Debug, + SelectT: DeserializeOwnedValue + PartialEq + std::fmt::Debug, { session .query_unpaged( diff --git a/scylla/src/transport/cql_types_test.rs b/scylla/src/transport/cql_types_test.rs index 0a1833fd7..2863df76c 100644 --- a/scylla/src/transport/cql_types_test.rs +++ b/scylla/src/transport/cql_types_test.rs @@ -1,4 +1,5 @@ use crate as scylla; +use crate::deserialize::DeserializeOwnedValue; use crate::frame::response::result::CqlValue; use crate::frame::value::{Counter, CqlDate, CqlTime, CqlTimestamp}; use crate::test_utils::{create_new_session_builder, scylla_supports_tablets, setup_tracing}; @@ -6,7 +7,6 @@ use crate::transport::session::Session; use crate::utils::test_utils::unique_keyspace_name; use itertools::Itertools; use scylla_cql::frame::value::{CqlTimeuuid, CqlVarint}; -use scylla_cql::types::deserialize::value::DeserializeValue; use scylla_cql::types::serialize::value::SerializeValue; use scylla_macros::{DeserializeValue, SerializeValue}; use std::cmp::PartialEq; @@ -74,7 +74,7 @@ async fn init_test(table_name: &str, type_name: &str) -> Session { // Expected values and bound values are computed using T::from_str async fn run_tests(tests: &[&str], type_name: &str) where - T: SerializeValue + for<'r> DeserializeValue<'r, 'r> + FromStr + Debug + Clone + PartialEq, + T: SerializeValue + DeserializeOwnedValue + FromStr + Debug + Clone + PartialEq, { let session: Session = init_test(type_name, type_name).await; session.await_schema_agreement().await.unwrap(); @@ -1799,7 +1799,7 @@ async fn test_udt_with_missing_field() { expected: TR, ) where TQ: SerializeValue, - TR: for<'r> DeserializeValue<'r, 'r> + PartialEq + Debug, + TR: DeserializeOwnedValue + PartialEq + Debug, { session .query_unpaged( diff --git a/scylla/src/transport/iterator.rs b/scylla/src/transport/iterator.rs index 100fafe2e..92918281d 100644 --- a/scylla/src/transport/iterator.rs +++ b/scylla/src/transport/iterator.rs @@ -24,6 +24,7 @@ use super::query_result::ColumnSpecs; use super::session::RequestSpan; use crate::cql_to_rust::{FromRow, FromRowError}; +use crate::deserialize::DeserializeOwnedRow; use crate::frame::response::{ result, result::{ColumnSpec, Row}, @@ -1076,7 +1077,7 @@ impl TypedRowStream { /// It only works with owned types! For example, &str is not supported. impl Stream for TypedRowStream where - RowT: for<'r> DeserializeRow<'r, 'r>, + RowT: DeserializeOwnedRow, { type Item = Result; diff --git a/scylla/src/transport/session_test.rs b/scylla/src/transport/session_test.rs index 7e767941a..4de38f52c 100644 --- a/scylla/src/transport/session_test.rs +++ b/scylla/src/transport/session_test.rs @@ -1,5 +1,5 @@ use crate::batch::{Batch, BatchStatement}; -use crate::deserialize::DeserializeValue; +use crate::deserialize::DeserializeOwnedValue; use crate::prepared_statement::PreparedStatement; use crate::query::Query; use crate::retry_policy::{QueryInfo, RetryDecision, RetryPolicy, RetrySession}; @@ -3100,7 +3100,7 @@ async fn test_deserialize_empty_collections() { session.use_keyspace(&ks, true).await.unwrap(); async fn deserialize_empty_collection< - Collection: Default + for<'frame> DeserializeValue<'frame, 'frame> + SerializeValue, + Collection: Default + DeserializeOwnedValue + SerializeValue, >( session: &Session, collection_name: &str, diff --git a/scylla/src/transport/topology.rs b/scylla/src/transport/topology.rs index 7f3f6e41f..ab29cd46b 100644 --- a/scylla/src/transport/topology.rs +++ b/scylla/src/transport/topology.rs @@ -1,3 +1,4 @@ +use crate::deserialize::DeserializeOwnedRow; use crate::frame::response::event::Event; use crate::routing::Token; use crate::statement::query::Query; @@ -15,7 +16,6 @@ use futures::Stream; use rand::seq::SliceRandom; use rand::{thread_rng, Rng}; use scylla_cql::frame::frame_errors::RowsParseError; -use scylla_cql::types::deserialize::row::DeserializeRow; use scylla_cql::types::deserialize::TypeCheckError; use scylla_macros::DeserializeRow; use std::borrow::BorrowMut; @@ -930,7 +930,7 @@ fn query_filter_keyspace_name<'a, R>( convert_typecheck_error: impl FnOnce(TypeCheckError) -> MetadataError + 'a, ) -> impl Stream> + 'a where - R: for<'r> DeserializeRow<'r, 'r> + 'static, + R: DeserializeOwnedRow + 'static, { let conn = conn.clone();