Skip to content

Commit

Permalink
imp: Implement custom JSON and Borsh deserialization for ChainId (#…
Browse files Browse the repository at this point in the history
…1013)

* Add ChainId json deserialize test

* Add BorshDeserialization test

* Add some debugging to ChainId deserialize impl

* Change assertion to unwrap

* First stab at custom Visitor and Deserialize impls

* Implement custom Deserialize for ChainId

* Remove unnecessary ChainId::from_str call

* Add some additional assertions to test valid ChainIds

* Stub out custom BorshDeserialize impl

* Get custom borshDeserialize impl compiling

* Add rstest test case testing invalid borsh deserialization

* Add changelog entry

* Cargo fmt

* Incorporate some PR feedback

* Remove expanded.rs file

* Remove expanded.rs file

* Clean up borshDeserialize impl

* Add test the verifies valid borsh deserialization

* Update BorshDeserialize test

* test: add test_valid_borsh_ser_de_roundtrip

* Clean up

* nit

* Move `use core::fmt` statement into serde::Deserialize impl

* fix: core::fmt::Result

---------

Co-authored-by: Farhad Shabani <[email protected]>
  • Loading branch information
seanchen1991 and Farhad-Shabani authored Dec 21, 2023
1 parent 8f21923 commit 7481311
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- `[ibc-core-host-types]` Implement custom JSON and Borsh deserialization for `ChainId` ([#996](https://github.com/cosmos/ibc-rs/pull/1013))
2 changes: 1 addition & 1 deletion ibc-core/ics24-host/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,4 @@ parity-scale-codec = [
"ibc-core-host-types/parity-scale-codec",
"ibc-core-handler-types/parity-scale-codec",
"ibc-primitives/parity-scale-codec",
]
]
3 changes: 2 additions & 1 deletion ibc-core/ics24-host/types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ parity-scale-codec = { workspace = true, optional = true }
scale-info = { workspace = true, optional = true }

[dev-dependencies]
rstest = { workspace = true }
rstest = { workspace = true }
serde_json = { workspace = true }

[features]
default = ["std"]
Expand Down
201 changes: 195 additions & 6 deletions ibc-core/ics24-host/types/src/identifiers/chain_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use core::fmt::{Debug, Display, Error as FmtError, Formatter};
use core::str::FromStr;

use ibc_primitives::prelude::*;
#[cfg(feature = "serde")]
use serde::de::{Deserialize, Deserializer, Error, MapAccess, Visitor};

use crate::error::IdentifierError;
use crate::validate::{
Expand All @@ -24,11 +26,8 @@ use crate::validate::{
scale_info::TypeInfo
)
)]
#[cfg_attr(
feature = "borsh",
derive(borsh::BorshSerialize, borsh::BorshDeserialize)
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "borsh", derive(borsh::BorshSerialize))]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ChainId {
Expand Down Expand Up @@ -116,6 +115,146 @@ impl ChainId {
}
}

#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for ChainId {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
const FIELDS: &[&str] = &["id", "revision_number"];

enum Field {
Id,
RevisionNumber,
}

impl<'de> Deserialize<'de> for Field {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct FieldVisitor;

impl<'de> Visitor<'de> for FieldVisitor {
type Value = Field;

fn expecting(&self, formatter: &mut Formatter<'_>) -> core::fmt::Result {
write!(formatter, "expected one of: {:?}", &FIELDS)
}

fn visit_str<E>(self, value: &str) -> Result<Field, E>
where
E: Error,
{
match value {
"id" => Ok(Field::Id),
"revisionNumber" | "revision_number" => Ok(Field::RevisionNumber),
_ => Err(Error::unknown_field(value, FIELDS)),
}
}
}

deserializer.deserialize_identifier(FieldVisitor)
}
}

struct ChainIdVisitor;

impl<'de> Visitor<'de> for ChainIdVisitor {
type Value = ChainId;

fn expecting(&self, formatter: &mut Formatter<'_>) -> core::fmt::Result {
formatter.write_str("struct ChainId")
}

fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
where
V: MapAccess<'de>,
{
let mut id = None;
let mut revision_number = None;

while let Some(key) = map.next_key()? {
match key {
Field::Id => {
if id.is_some() {
return Err(Error::duplicate_field("id"));
}

let next_value = map.next_value::<&str>()?;

let chain_id = ChainId::from_str(next_value)
.map_err(|_| Error::custom("failed to parse ChainId `id` field"))?;

id = Some(chain_id.id);
revision_number = Some(chain_id.revision_number);
}
Field::RevisionNumber => {
let next_value = map.next_value::<&str>()?;
let rev = u64::from_str(next_value).unwrap_or(0);

if let Some(rn) = revision_number {
if rev != 0 && rn != rev {
return Err(Error::custom(format_args!(
"chain ID revision numbers do not match; got `{}` and `{}`",
rn, rev,
)));
}
} else {
revision_number = Some(rev);
}
}
}
}

let id = id.ok_or_else(|| Error::missing_field("id"))?;

Ok(ChainId {
id,
revision_number: revision_number.unwrap_or(0),
})
}
}

deserializer.deserialize_struct("ChainId", FIELDS, ChainIdVisitor)
}
}

#[cfg(feature = "borsh")]
mod borsh_impls {
use borsh::maybestd::io::{self, Error, ErrorKind, Read};
use borsh::BorshDeserialize;

use super::*;

impl BorshDeserialize for ChainId {
fn deserialize_reader<R: Read>(reader: &mut R) -> io::Result<Self> {
let (id, revision_number) = <(String, u64)>::deserialize_reader(reader)?;

match parse_chain_id_string(&id) {
Ok((_, rn)) => {
if revision_number != 0 && rn != revision_number {
return Err(Error::new(
ErrorKind::Other,
"chain ID revision numbers do not match",
));
}
}
_ => {
if revision_number != 0 {
return Err(Error::new(ErrorKind::Other, "failed to parse chain ID"));
}
}
}

Ok(ChainId {
id,
revision_number,
})
}
}
}

/// Construct a `ChainId` from a string literal only if it forms a valid
/// identifier.
impl FromStr for ChainId {
Expand Down Expand Up @@ -250,7 +389,7 @@ mod tests {
#[case(" -")]
#[case(" -1")]
#[case("/chainA-1")]
fn test_invalid_chain_id(#[case] chain_id_str: &str) {
fn test_invalid_chain_id_from_str(#[case] chain_id_str: &str) {
assert!(ChainId::new(chain_id_str).is_err());
}

Expand All @@ -275,4 +414,54 @@ mod tests {
assert_eq!(chain_id.revision_number(), 0);
assert_eq!(chain_id.as_str(), "chainA");
}

#[cfg(feature = "serde")]
#[rstest]
#[case(r#"{"id":"foo-42","revision_number":"42"}"#)]
#[case(r#"{"id":"foo-42","revision_number":"0"}"#)]
#[case(r#"{"id":"foo-bar-42","revision_number":"0"}"#)]
fn test_valid_chain_id_json_deserialization(#[case] chain_id_json: &str) {
let chain_id = serde_json::from_str::<ChainId>(chain_id_json);
assert!(chain_id.is_ok());

let chain_id = chain_id.unwrap();

let (_id, rev_num) = chain_id.split_chain_id().unwrap();

assert_eq!(rev_num, chain_id.revision_number());
}

#[cfg(feature = "serde")]
#[rstest]
#[case(r#"{"id":"foo-42","revision_number":"69"}"#)]
#[case(r#"{"id":"foo-0","revision_number":"69"}"#)]
#[case(r#"{"id":"/foo-42","revision_number":"0"}"#)]
fn test_invalid_chain_id_json_deserialization(#[case] chain_id_json: &str) {
assert!(serde_json::from_str::<ChainId>(chain_id_json).is_err())
}

#[cfg(feature = "borsh")]
#[rstest]
#[case(b"\x06\0\0\0foo-42\x45\0\0\0\0\0\0\0")]
fn test_invalid_chain_id_borsh_deserialization(#[case] chain_id_bytes: &[u8]) {
use borsh::BorshDeserialize;

assert!(ChainId::try_from_slice(chain_id_bytes).is_err())
}

#[cfg(feature = "borsh")]
fn borsh_ser_de_roundtrip(chain_id: ChainId) {
use borsh::{BorshDeserialize, BorshSerialize};

let chain_id_bytes = chain_id.try_to_vec().unwrap();
let res = ChainId::try_from_slice(&chain_id_bytes).unwrap();
assert_eq!(chain_id, res);
}

#[cfg(feature = "borsh")]
#[test]
fn test_valid_borsh_ser_de_roundtrip() {
borsh_ser_de_roundtrip(ChainId::new("foo-42").unwrap());
borsh_ser_de_roundtrip(ChainId::new("foo").unwrap());
}
}

0 comments on commit 7481311

Please sign in to comment.