Skip to content

Commit

Permalink
refactor: migrate gateway to use NewMessage
Browse files Browse the repository at this point in the history
  • Loading branch information
cgorenflo committed Sep 21, 2023
1 parent 45fcc06 commit b4675ca
Show file tree
Hide file tree
Showing 9 changed files with 284 additions and 603 deletions.
1 change: 0 additions & 1 deletion contracts/aggregate-verifier/tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use connection_router::state::{CrossChainId, NewMessage, ID_SEPARATOR};
use cosmwasm_std::from_binary;
use cosmwasm_std::Addr;
use cw_multi_test::{App, ContractWrapper, Executor};
use voting_verifier::msg as voting_msg;

use crate::mock::{make_mock_voting_verifier, mark_messages_as_verified};
pub mod mock;
Expand Down
19 changes: 19 additions & 0 deletions contracts/connection-router/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ impl TryFrom<String> for Address {

#[cw_serde]
#[serde(try_from = "String")]
#[derive(Eq, Hash)]
pub struct MessageId(String);

impl FromStr for MessageId {
Expand Down Expand Up @@ -333,11 +334,28 @@ impl KeyDeserialize for MessageId {
}

#[cw_serde]
#[derive(Eq, Hash)]
pub struct CrossChainId {
pub chain: ChainName,
pub id: MessageId,
}

/// todo: remove this when state::NewMessage is used
impl FromStr for CrossChainId {
type Err = ContractError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
let parts = s.split_once(ID_SEPARATOR);
let (chain, id) = parts
.map(|(chain, id)| (chain.parse::<ChainName>(), id.parse::<MessageId>()))
.ok_or(ContractError::InvalidMessageId)?;
Ok(CrossChainId {
chain: chain?,
id: id?,
})
}
}

impl Display for CrossChainId {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "{}{}{}", &self.chain, ID_SEPARATOR, &self.id)
Expand Down Expand Up @@ -368,6 +386,7 @@ impl KeyDeserialize for CrossChainId {

#[cw_serde]
#[serde(try_from = "String")]
#[derive(Eq, Hash)]
pub struct ChainName(String);

impl FromStr for ChainName {
Expand Down
76 changes: 25 additions & 51 deletions contracts/gateway/src/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
state::{Config, CONFIG, OUTGOING_MESSAGES},
};

use connection_router::state::Message;
use connection_router::state::NewMessage;

use self::execute::{route_incoming_messages, route_outgoing_messages, verify_messages};

Expand All @@ -35,19 +35,8 @@ pub fn execute(
msg: ExecuteMsg,
) -> Result<Response, axelar_wasm_std::ContractError> {
match msg {
ExecuteMsg::VerifyMessages(messages) => {
let msgs = messages
.into_iter()
.map(Message::try_from)
.collect::<Result<Vec<Message>, _>>()?;
verify_messages(deps, msgs)
}
ExecuteMsg::RouteMessages(messages) => {
let msgs = messages
.into_iter()
.map(Message::try_from)
.collect::<Result<Vec<Message>, _>>()?;

ExecuteMsg::VerifyMessages(msgs) => verify_messages(deps, msgs),
ExecuteMsg::RouteMessages(msgs) => {
let router = CONFIG.load(deps.storage)?.router;
if info.sender == router {
route_outgoing_messages(deps, msgs)
Expand All @@ -61,40 +50,37 @@ pub fn execute(

pub mod execute {

use cosmwasm_std::{to_binary, QueryRequest, StdError, WasmMsg, WasmQuery};
use connection_router::state::CrossChainId;
use cosmwasm_std::{to_binary, QueryRequest, WasmMsg, WasmQuery};

use crate::{events::GatewayEvent, state::OUTGOING_MESSAGES};

use super::*;

fn contains_duplicates(msgs: &mut Vec<Message>) -> bool {
fn contains_duplicates(msgs: &mut Vec<NewMessage>) -> bool {
let orig_len = msgs.len();
msgs.sort_unstable_by_key(|a| a.id.to_string());
msgs.dedup_by(|a, b| a.id == b.id);
msgs.sort_unstable_by_key(|a| a.cc_id.to_string());
msgs.dedup_by(|a, b| a.cc_id == b.cc_id);
orig_len != msgs.len()
}

fn partition_by_verified(
deps: DepsMut,
msgs: Vec<Message>,
) -> Result<(Vec<Message>, Vec<Message>), ContractError> {
msgs: Vec<NewMessage>,
) -> Result<(Vec<NewMessage>, Vec<NewMessage>), ContractError> {
let verifier = CONFIG.load(deps.storage)?.verifier;

let query_msg = aggregate_verifier::msg::QueryMsg::IsVerified {
messages: msgs
.iter()
.map(|m| m.clone().try_into())
.collect::<Result<_, _>>()
.map_err(|_| StdError::generic_err("invalid messages"))?, //todo: error mapping needs to get removed when gateway is mirated to NewMessage
messages: msgs.clone(),
};
let query_response: Vec<(String, bool)> =
let query_response: Vec<(CrossChainId, bool)> =
deps.querier.query(&QueryRequest::Wasm(WasmQuery::Smart {
contract_addr: verifier.to_string(),
msg: to_binary(&query_msg)?,
}))?;

Ok(msgs.into_iter().partition(|m| -> bool {
match query_response.iter().find(|r| m.id.to_string() == r.0) {
match query_response.iter().find(|r| m.cc_id == r.0) {
Some((_, v)) => *v,
None => false,
}
Expand All @@ -103,7 +89,7 @@ pub mod execute {

pub fn verify_messages(
deps: DepsMut,
mut msgs: Vec<Message>,
mut msgs: Vec<NewMessage>,
) -> Result<Response, ContractError> {
let config = CONFIG.load(deps.storage)?;
let verifier = config.verifier;
Expand All @@ -117,38 +103,29 @@ pub mod execute {
Ok(Response::new().add_message(WasmMsg::Execute {
contract_addr: verifier.to_string(),
msg: to_binary(&aggregate_verifier::msg::ExecuteMsg::VerifyMessages {
messages: unverified
.into_iter()
.map(|m| m.clone().try_into())
.collect::<Result<_, _>>()
.map_err(|_| StdError::generic_err("invalid messages"))?, //todo: error mapping needs to get removed when gateway is mirated to NewMessage
messages: unverified,
})?,
funds: vec![],
}))
}

pub fn route_incoming_messages(
deps: DepsMut,
mut msgs: Vec<Message>,
mut msgs: Vec<NewMessage>,
) -> Result<Response, ContractError> {
let router = CONFIG.load(deps.storage)?.router;

if contains_duplicates(&mut msgs) {
return Err(ContractError::DuplicateMessageID);
}

let (verified, unverified) = partition_by_verified(deps, msgs)?;
let (verified, unverified) = partition_by_verified(deps, msgs.clone())?;

Ok(Response::new()
.add_message(WasmMsg::Execute {
contract_addr: router.to_string(),
msg: to_binary(&connection_router::msg::ExecuteMsg::RouteMessages(
verified
.clone()
.into_iter()
.map(connection_router::state::NewMessage::try_from)
.collect::<Result<Vec<_>, _>>()
.map_err(|_| connection_router::error::ContractError::InvalidMessageId)?, //todo: remove when integrating error stack into the gateway
verified.clone(),
))?,
funds: vec![],
})
Expand All @@ -166,10 +143,10 @@ pub mod execute {

pub fn route_outgoing_messages(
deps: DepsMut,
msgs: Vec<Message>,
msgs: Vec<NewMessage>,
) -> Result<Response, ContractError> {
for m in &msgs {
OUTGOING_MESSAGES.save(deps.storage, m.id.to_string(), m)?;
OUTGOING_MESSAGES.save(deps.storage, m.cc_id.clone(), m)?;
}

Ok(Response::new().add_events(
Expand All @@ -182,19 +159,16 @@ pub mod execute {
#[cfg_attr(not(feature = "library"), entry_point)]
pub fn query(deps: Deps, _env: Env, msg: QueryMsg) -> StdResult<Binary> {
match msg {
QueryMsg::GetMessages { message_ids } => {
QueryMsg::GetMessages {
message_ids: cross_chain_ids,
} => {
let mut msgs = vec![];

for id in message_ids {
for id in cross_chain_ids {
msgs.push(OUTGOING_MESSAGES.load(deps.storage, id)?);
}

to_binary(
&msgs
.into_iter()
.map(|m| m.into())
.collect::<Vec<connection_router::msg::Message>>(),
)
to_binary(&msgs)
}
}
}
22 changes: 12 additions & 10 deletions contracts/gateway/src/events.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
use connection_router::events::make_message_event;
use connection_router::state::Message;
use connection_router::events::make_message_event_new;
use connection_router::state::NewMessage;
use cosmwasm_std::Event;

pub enum GatewayEvent {
MessageVerified { msg: Message },
MessageVerificationFailed { msg: Message },
MessageRouted { msg: Message },
MessageRoutingFailed { msg: Message },
MessageVerified { msg: NewMessage },
MessageVerificationFailed { msg: NewMessage },
MessageRouted { msg: NewMessage },
MessageRoutingFailed { msg: NewMessage },
}

impl From<GatewayEvent> for Event {
fn from(other: GatewayEvent) -> Self {
match other {
GatewayEvent::MessageVerified { msg } => make_message_event("message_verified", msg),
GatewayEvent::MessageRouted { msg } => make_message_event("message_routed", msg),
GatewayEvent::MessageVerified { msg } => {
make_message_event_new("message_verified", msg)

Check warning on line 16 in contracts/gateway/src/events.rs

View check run for this annotation

Codecov / codecov/patch

contracts/gateway/src/events.rs#L15-L16

Added lines #L15 - L16 were not covered by tests
}
GatewayEvent::MessageRouted { msg } => make_message_event_new("message_routed", msg),
GatewayEvent::MessageVerificationFailed { msg } => {
make_message_event("message_verification_failed", msg)
make_message_event_new("message_verification_failed", msg)

Check warning on line 20 in contracts/gateway/src/events.rs

View check run for this annotation

Codecov / codecov/patch

contracts/gateway/src/events.rs#L20

Added line #L20 was not covered by tests
}
GatewayEvent::MessageRoutingFailed { msg } => {
make_message_event("message_routing_failed", msg)
make_message_event_new("message_routing_failed", msg)
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions contracts/gateway/src/msg.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use connection_router::msg::Message;
use connection_router::state::{CrossChainId, NewMessage};
use cosmwasm_schema::{cw_serde, QueryResponses};

#[cw_serde]
Expand All @@ -10,15 +10,15 @@ pub struct InstantiateMsg {
#[cw_serde]
pub enum ExecuteMsg {
// Permissionless
VerifyMessages(Vec<Message>),
VerifyMessages(Vec<NewMessage>),

// Permissionless
RouteMessages(Vec<Message>),
RouteMessages(Vec<NewMessage>),
}

#[cw_serde]
#[derive(QueryResponses)]
pub enum QueryMsg {
#[returns(Vec<Message>)]
GetMessages { message_ids: Vec<String> },
#[returns(Vec<NewMessage>)]
GetMessages { message_ids: Vec<CrossChainId> },
}
4 changes: 2 additions & 2 deletions contracts/gateway/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use connection_router::state::Message;
use connection_router::state::{CrossChainId, NewMessage};
use cosmwasm_schema::cw_serde;
use cosmwasm_std::Addr;
use cw_storage_plus::{Item, Map};
Expand All @@ -11,4 +11,4 @@ pub struct Config {

pub const CONFIG: Item<Config> = Item::new("config");

pub const OUTGOING_MESSAGES: Map<String, Message> = Map::new("outgoing_messages");
pub const OUTGOING_MESSAGES: Map<CrossChainId, NewMessage> = Map::new("outgoing_messages");
30 changes: 12 additions & 18 deletions contracts/gateway/tests/mock.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use connection_router::msg::Message;
use connection_router::state::NewMessage;
use connection_router::state::{CrossChainId, NewMessage};
use cosmwasm_schema::cw_serde;
use cosmwasm_std::{to_binary, Addr, Binary, Deps, DepsMut, Env, MessageInfo, Response, StdResult};
use cw_multi_test::{App, ContractWrapper, Executor};
Expand All @@ -24,19 +23,17 @@ pub fn mock_verifier_execute(
MockVerifierExecuteMsg::VerifyMessages { messages } => {
let mut res = vec![];
for m in messages {
let m = connection_router::state::Message::try_from(m).unwrap();
match MOCK_VERIFIER_MESSAGES
.may_load(deps.storage, serde_json::to_string(&m).unwrap())?
{
Some(b) => res.push((m.id, b)),
None => res.push((m.id, false)),
Some(b) => res.push((m.cc_id, b)),
None => res.push((m.cc_id, false)),
}
}
Ok(Response::new().set_data(to_binary(&res)?))
}
MockVerifierExecuteMsg::MessagesVerified { messages } => {
for m in messages {
let m = connection_router::state::Message::try_from(m).unwrap();
MOCK_VERIFIER_MESSAGES.save(
deps.storage,
serde_json::to_string(&m).unwrap(),
Expand All @@ -58,12 +55,11 @@ pub fn mock_verifier_query(deps: Deps, _env: Env, msg: MockVerifierQueryMsg) ->
match msg {
MockVerifierQueryMsg::IsVerified { messages } => {
for m in messages {
let m = connection_router::state::Message::try_from(m).unwrap();
match MOCK_VERIFIER_MESSAGES
.may_load(deps.storage, serde_json::to_string(&m).unwrap())?
{
Some(v) => res.push((m.id.to_string(), v)),
None => res.push((m.id.to_string(), false)),
Some(v) => res.push((m.cc_id, v)),
None => res.push((m.cc_id, false)),
}
}
}
Expand All @@ -75,7 +71,7 @@ pub fn is_verified(
app: &mut App,
verifier_address: Addr,
msgs: Vec<NewMessage>,
) -> Vec<(String, bool)> {
) -> Vec<(CrossChainId, bool)> {
app.wrap()
.query_wasm_smart(
verifier_address,
Expand All @@ -94,8 +90,7 @@ pub fn mark_messages_as_verified(app: &mut App, verifier_address: Addr, msgs: Ve
.unwrap();
}

const MOCK_ROUTER_MESSAGES: Map<String, connection_router::state::Message> =
Map::new("router_messages");
const MOCK_ROUTER_MESSAGES: Map<CrossChainId, NewMessage> = Map::new("router_messages");

pub fn mock_router_execute(
deps: DepsMut,
Expand All @@ -106,8 +101,7 @@ pub fn mock_router_execute(
match msg {
connection_router::msg::ExecuteMsg::RouteMessages(msgs) => {
for msg in msgs {
let msg = connection_router::state::Message::try_from(msg)?;
MOCK_ROUTER_MESSAGES.save(deps.storage, msg.id.to_string(), &msg)?;
MOCK_ROUTER_MESSAGES.save(deps.storage, msg.cc_id.clone(), &msg)?;
}
}
_ => (),
Expand All @@ -117,7 +111,7 @@ pub fn mock_router_execute(

#[cw_serde]
pub enum MockRouterQueryMsg {
GetMessages { ids: Vec<String> },
GetMessages { ids: Vec<CrossChainId> },
}
pub fn mock_router_query(deps: Deps, _env: Env, msg: MockRouterQueryMsg) -> StdResult<Binary> {
let mut msgs = vec![];
Expand All @@ -138,13 +132,13 @@ pub fn mock_router_query(deps: Deps, _env: Env, msg: MockRouterQueryMsg) -> StdR
pub fn get_router_messages(
app: &mut App,
router_address: Addr,
msgs: Vec<connection_router::msg::Message>,
) -> Vec<connection_router::state::Message> {
msgs: Vec<NewMessage>,
) -> Vec<NewMessage> {
app.wrap()
.query_wasm_smart(
router_address,
&MockRouterQueryMsg::GetMessages {
ids: msgs.iter().map(|m| m.id.to_string()).collect(),
ids: msgs.into_iter().map(|m| m.cc_id).collect(),
},
)
.unwrap()
Expand Down
Loading

0 comments on commit b4675ca

Please sign in to comment.