Skip to content

Commit

Permalink
Replace NonZeroU32 type lookup ids with u32 (#90)
Browse files Browse the repository at this point in the history
* Replace NonZeroU32 type ids with u32

* Add UntrackedSymbol constructor

* Fmt

* Remove UntrackedSymbol constructor

* Update src/interner.rs

Co-authored-by: David <[email protected]>

* Update src/interner.rs

Co-authored-by: David <[email protected]>

Co-authored-by: David <[email protected]>
  • Loading branch information
ascjones and dvdplm authored May 10, 2021
1 parent 3b542f4 commit ac731b0
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 121 deletions.
63 changes: 22 additions & 41 deletions src/interner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ use crate::prelude::{
Entry,
},
marker::PhantomData,
num::NonZeroU32,
vec::Vec,
};

Expand All @@ -42,39 +41,21 @@ use serde::{
///
/// This can be used by self-referential types but
/// can no longer be used to resolve instances.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[derive(
Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, scale::Encode, scale::Decode,
)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct UntrackedSymbol<T> {
/// The index to the symbol in the interner table.
id: NonZeroU32,
id: u32,
#[cfg_attr(feature = "serde", serde(skip))]
marker: PhantomData<fn() -> T>,
}

impl<T> scale::Encode for UntrackedSymbol<T> {
fn encode_to<W: scale::Output + ?Sized>(&self, dest: &mut W) {
self.id.get().encode_to(dest)
}
}

impl<T> scale::Decode for UntrackedSymbol<T> {
fn decode<I: scale::Input>(value: &mut I) -> Result<Self, scale::Error> {
let id = <u32 as scale::Decode>::decode(value)?;
if id < 1 {
return Err("UntrackedSymbol::id should be a non-zero unsigned integer".into())
}
let id = NonZeroU32::new(id).expect("ID is non zero");
Ok(UntrackedSymbol {
id,
marker: Default::default(),
})
}
}

impl<T> UntrackedSymbol<T> {
/// Returns the index to the symbol in the interner table.
pub fn id(&self) -> NonZeroU32 {
pub fn id(&self) -> u32 {
self.id
}
}
Expand All @@ -86,7 +67,7 @@ impl<T> UntrackedSymbol<T> {
#[cfg_attr(feature = "serde", derive(Serialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct Symbol<'a, T> {
id: NonZeroU32,
id: u32,
#[cfg_attr(feature = "serde", serde(skip))]
marker: PhantomData<fn() -> &'a T>,
}
Expand Down Expand Up @@ -181,18 +162,18 @@ where
(
inserted,
Symbol {
id: NonZeroU32::new((sym_id + 1) as u32).unwrap(),
id: sym_id as u32,
marker: PhantomData,
},
)
}

/// Returns the symbol of the given element or `None` if it hasn't been
/// interned already.
pub fn get(&self, s: &T) -> Option<Symbol<T>> {
self.map.get(s).map(|&id| {
pub fn get(&self, sym: &T) -> Option<Symbol<T>> {
self.map.get(sym).map(|&id| {
Symbol {
id: NonZeroU32::new(id as u32).unwrap(),
id: id as u32,
marker: PhantomData,
}
})
Expand All @@ -201,7 +182,7 @@ where
/// Resolves the original element given its associated symbol or
/// returns `None` if it has not been interned yet.
pub fn resolve(&self, sym: Symbol<T>) -> Option<&T> {
let idx = (sym.id.get() - 1) as usize;
let idx = sym.id as usize;
if idx >= self.vec.len() {
return None
}
Expand All @@ -220,7 +201,7 @@ mod tests {
new_symbol: &'static str,
expected_id: u32,
) {
let actual_id = interner.intern_or_get(new_symbol).1.id.get();
let actual_id = interner.intern_or_get(new_symbol).1.id;
assert_eq!(actual_id, expected_id,);
}

Expand All @@ -229,7 +210,7 @@ mod tests {
E: Into<Option<&'static str>>,
{
let actual_str = interner.resolve(Symbol {
id: NonZeroU32::new(symbol_id).unwrap(),
id: symbol_id,
marker: PhantomData,
});
assert_eq!(actual_str.cloned(), expected_str.into(),);
Expand All @@ -238,14 +219,14 @@ mod tests {
#[test]
fn simple() {
let mut interner = StringInterner::new();
assert_id(&mut interner, "Hello", 1);
assert_id(&mut interner, ", World!", 2);
assert_id(&mut interner, "1 2 3", 3);
assert_id(&mut interner, "Hello", 1);

assert_resolve(&mut interner, 1, "Hello");
assert_resolve(&mut interner, 2, ", World!");
assert_resolve(&mut interner, 3, "1 2 3");
assert_resolve(&mut interner, 4, None);
assert_id(&mut interner, "Hello", 0);
assert_id(&mut interner, ", World!", 1);
assert_id(&mut interner, "1 2 3", 2);
assert_id(&mut interner, "Hello", 0);

assert_resolve(&mut interner, 0, "Hello");
assert_resolve(&mut interner, 1, ", World!");
assert_resolve(&mut interner, 2, "1 2 3");
assert_resolve(&mut interner, 3, None);
}
}
15 changes: 7 additions & 8 deletions src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ use crate::prelude::{
any::TypeId,
collections::BTreeMap,
fmt::Debug,
num::NonZeroU32,
vec::Vec,
};

Expand Down Expand Up @@ -178,14 +177,14 @@ impl From<Registry> for PortableRegistry {

impl PortableRegistry {
/// Returns the type definition for the given identifier, `None` if no type found for that ID.
pub fn resolve(&self, id: NonZeroU32) -> Option<&Type<PortableForm>> {
self.types.get((id.get() - 1) as usize)
pub fn resolve(&self, id: u32) -> Option<&Type<PortableForm>> {
self.types.get(id as usize)
}

/// Returns an iterator for all types paired with their associated NonZeroU32 identifier.
pub fn enumerate(&self) -> impl Iterator<Item = (NonZeroU32, &Type<PortableForm>)> {
/// Returns an iterator for all types paired with their associated u32 identifier.
pub fn enumerate(&self) -> impl Iterator<Item = (u32, &Type<PortableForm>)> {
self.types.iter().enumerate().map(|(i, ty)| {
let id = NonZeroU32::new(i as u32 + 1).expect("i + 1 > 0; qed");
let id = i as u32;
(id, ty)
})
}
Expand Down Expand Up @@ -213,9 +212,9 @@ mod tests {

assert_eq!(4, readonly.enumerate().count());

let mut expected = 1;
let mut expected = 0;
for (i, _) in readonly.enumerate() {
assert_eq!(NonZeroU32::new(expected).unwrap(), i);
assert_eq!(expected, i);
expected += 1;
}
}
Expand Down
8 changes: 2 additions & 6 deletions test_suite/tests/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ fn scale_encode_then_decode_to_readonly() {
let original_serialized = serde_json::to_value(registry).unwrap();

let readonly_decoded = PortableRegistry::decode(&mut &encoded[..]).unwrap();
assert!(readonly_decoded
.resolve(NonZeroU32::new(1).unwrap())
.is_some());
assert!(readonly_decoded.resolve(0).is_some());
let decoded_serialized = serde_json::to_value(readonly_decoded).unwrap();

assert_eq!(decoded_serialized, original_serialized);
Expand All @@ -81,9 +79,7 @@ fn json_serialize_then_deserialize_to_readonly() {
// assert_eq!(original_serialized, serde_json::Value::Null);
let readonly_deserialized: PortableRegistry =
serde_json::from_value(original_serialized.clone()).unwrap();
assert!(readonly_deserialized
.resolve(NonZeroU32::new(1).unwrap())
.is_some());
assert!(readonly_deserialized.resolve(0).is_some());
let readonly_serialized = serde_json::to_value(readonly_deserialized).unwrap();

assert_eq!(readonly_serialized, original_serialized);
Expand Down
Loading

0 comments on commit ac731b0

Please sign in to comment.