Skip to content

Commit

Permalink
Merge pull request #1137 from nrxus/fix-non-ident-column-names
Browse files Browse the repository at this point in the history
Allow serialization of "columns" that are not valid rust identifiers
  • Loading branch information
Lorak-mmk authored Dec 4, 2024
2 parents 62f96b3 + 8b579e6 commit 4b6ad84
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 10 deletions.
17 changes: 17 additions & 0 deletions scylla-cql/src/types/serialize/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1568,6 +1568,23 @@ pub(crate) mod tests {
assert_eq!(reference, row);
}

#[test]
fn test_row_serialization_with_not_rust_idents() {
#[derive(SerializeRow, Debug)]
#[scylla(crate = crate)]
struct RowWithTTL {
#[scylla(rename = "[ttl]")]
ttl: i32,
}

let spec = [col("[ttl]", ColumnType::Int)];

let reference = do_serialize((42i32,), &spec);
let row = do_serialize(RowWithTTL { ttl: 42 }, &spec);

assert_eq!(reference, row);
}

#[derive(SerializeRow, Debug)]
#[scylla(crate = crate)]
struct TestRowWithSkippedFields {
Expand Down
28 changes: 28 additions & 0 deletions scylla-cql/src/types/serialize/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2824,4 +2824,32 @@ pub(crate) mod tests {

assert_eq!(reference, row);
}

#[test]
fn test_udt_with_non_rust_ident() {
#[derive(SerializeValue, Debug)]
#[scylla(crate = crate)]
struct UdtWithNonRustIdent {
#[scylla(rename = "a$a")]
a: i32,
}

let typ = ColumnType::UserDefinedType {
type_name: "typ".into(),
keyspace: "ks".into(),
field_types: vec![("a$a".into(), ColumnType::Int)],
};
let value = UdtWithNonRustIdent { a: 42 };

let mut reference = Vec::new();
// Total length of the struct
reference.extend_from_slice(&8i32.to_be_bytes());
// Field 'a'
reference.extend_from_slice(&(std::mem::size_of_val(&value.a) as i32).to_be_bytes());
reference.extend_from_slice(&value.a.to_be_bytes());

let udt = do_serialize(value, &typ);

assert_eq!(reference, udt);
}
}
2 changes: 1 addition & 1 deletion scylla-macros/src/serialize/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ impl Generator for ColumnSortingGenerator<'_> {
statements.push(self.ctx.generate_mk_ser_err());

// Generate a "visited" flag for each field
let visited_flag_names = rust_field_names
let visited_flag_names = rust_field_idents
.iter()
.map(|s| syn::Ident::new(&format!("visited_flag_{}", s), Span::call_site()))
.collect::<Vec<_>>();
Expand Down
13 changes: 6 additions & 7 deletions scylla-macros/src/serialize/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::collections::HashMap;

use darling::FromAttributes;
use proc_macro::TokenStream;
use proc_macro2::Span;
use syn::parse_quote;

use crate::Flavor;
Expand Down Expand Up @@ -327,14 +326,14 @@ impl Generator for FieldSortingGenerator<'_> {
.generate_udt_type_match(parse_quote!(#crate_path::UdtTypeCheckErrorKind::NotUdt)),
);

fn make_visited_flag_ident(field_name: &str) -> syn::Ident {
syn::Ident::new(&format!("visited_flag_{}", field_name), Span::call_site())
fn make_visited_flag_ident(field_name: &syn::Ident) -> syn::Ident {
syn::Ident::new(&format!("visited_flag_{}", field_name), field_name.span())
}

// Generate a "visited" flag for each field
let visited_flag_names = rust_field_names
let visited_flag_names = rust_field_idents
.iter()
.map(|s| make_visited_flag_ident(s))
.map(make_visited_flag_ident)
.collect::<Vec<_>>();
statements.extend::<Vec<_>>(parse_quote! {
#(let mut #visited_flag_names = false;)*
Expand All @@ -347,11 +346,11 @@ impl Generator for FieldSortingGenerator<'_> {
.fields
.iter()
.filter(|f| !f.attrs.ignore_missing)
.map(|f| f.field_name());
.map(|f| &f.ident);
// An iterator over visited flags of Rust fields that can't be ignored
// (i.e., if UDT misses a corresponding field, an error should be raised).
let nonignorable_visited_flag_names =
nonignorable_rust_field_names.map(|s| make_visited_flag_ident(&s));
nonignorable_rust_field_names.map(make_visited_flag_ident);

// Generate a variable that counts down visited fields.
let field_count = self.ctx.fields.len();
Expand Down
4 changes: 2 additions & 2 deletions scylla/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ pub use scylla_cql::macros::SerializeRow;
/// If the value of the field received from DB is null, the field will be
/// initialized with `Default::default()`.
///
/// `#[scylla(rename = "field_name")`
/// `#[scylla(rename = "field_name")]`
///
/// By default, the generated implementation will try to match the Rust field
/// to a UDT field with the same name. This attribute instead allows to match
Expand Down Expand Up @@ -475,7 +475,7 @@ pub use scylla_macros::DeserializeValue;
/// The field will be completely ignored during deserialization and will
/// be initialized with `Default::default()`.
///
/// `#[scylla(rename = "field_name")`
/// `#[scylla(rename = "field_name")]`
///
/// By default, the generated implementation will try to match the Rust field
/// to a column with the same name. This attribute allows to match to a column
Expand Down

0 comments on commit 4b6ad84

Please sign in to comment.