From 8c5f23810f796d61cac03edd9f83bb1f2d69f554 Mon Sep 17 00:00:00 2001 From: Bruce Riley Date: Mon, 14 Oct 2024 14:24:53 -0500 Subject: [PATCH] EVM: Implement attest and receive --- evm/src/Router.sol | 244 +++++++++++++++++++++++++-------- evm/src/interfaces/IRouter.sol | 52 +++---- evm/test/Router.t.sol | 106 +++++++------- 3 files changed, 270 insertions(+), 132 deletions(-) diff --git a/evm/src/Router.sol b/evm/src/Router.sol index f92d117c..0cf715e4 100644 --- a/evm/src/Router.sol +++ b/evm/src/Router.sol @@ -42,12 +42,6 @@ contract Router is IRouter, MessageSequence, TransceiverRegistry { /// @param newAdmin The address of the new admin contract. event AdminUpdated(address integrator, address oldAdmin, address newAdmin); - /// @notice Emitted when a message has been attested to. - /// @param integrator The address of the integrator. - /// @param transceiver The address of the transceiver. - /// @param digest The digest of the message. - event MessageAttestedTo(address integrator, address transceiver, bytes32 digest); - /// @notice Emitted when a message has been sent. /// @param sender The address of the sender. /// @param recipient The address of the recipient. @@ -55,6 +49,44 @@ contract Router is IRouter, MessageSequence, TransceiverRegistry { /// @param digest The digest of the message. event MessageSent(address sender, address recipient, uint16 recipientChain, bytes32 digest); + /// @notice Emitted when a message has been attested to. + /// @param sourceChainId The Wormhole chain ID of the sender. + /// @param sourceAddress The universal address of the peer on the sending chain. + /// @param sequence The sequence number of the message (per integrator). + /// @param destinationChainId The Wormhole chain ID of the destination. + /// @param destinationAddress The destination address of the message. + /// @param payloadHash The keccak256 of payload from the integrator. + /// @param attestedBitmap Bitmap of transceivers that have attested the message. + event MessageAttestedTo( + uint16 sourceChainId, + UniversalAddress sourceAddress, + uint64 sequence, + uint16 destinationChainId, + UniversalAddress destinationAddress, + bytes32 payloadHash, + uint128 attestedBitmap + ); + + /// @notice Emitted when a message has been received. + /// @param sourceChainId The Wormhole chain ID of the sender. + /// @param sourceAddress The universal address of the peer on the sending chain. + /// @param sequence The sequence number of the message (per integrator). + /// @param destinationChainId The Wormhole chain ID of the destination. + /// @param destinationAddress The destination address of the message. + /// @param payloadHash The keccak256 of payload from the integrator. + /// @param enabledBitmap Bitmap of transceivers enabled for the source chain. + /// @param attestedBitmap Bitmap of transceivers that have attested the message. + event MessageReceived( + uint16 sourceChainId, + UniversalAddress sourceAddress, + uint64 sequence, + uint16 destinationChainId, + UniversalAddress destinationAddress, + bytes32 payloadHash, + uint128 enabledBitmap, + uint128 attestedBitmap + ); + // =============== Errors ================================================================ /// @notice Error when the destination chain ID doesn't match this chain. @@ -75,6 +107,12 @@ contract Router is IRouter, MessageSequence, TransceiverRegistry { /// @notice Error when the caller is not the registered admin. error CallerNotAdmin(); + /// @notice Message attestation not found in store. + error UnknownMessageAttestation(); + + /// @notice Message is already marked as executed. + error AlreadyExecuted(); + // =============== Storage =============================================================== /// @dev Holds the integrator address to IntegratorConfig mapping. @@ -89,6 +127,27 @@ contract Router is IRouter, MessageSequence, TransceiverRegistry { } } + struct AttestationInfo { + bool executed; // replay protection + uint128 attestedTransceivers; // bitmap corresponding to perIntegratorTransceivers + } + + /// @dev Holds the integrator address to message digest to attestation info mapping. + /// mapping(address => IntegratorConfig) + bytes32 private constant ATTESTATION_INFO_SLOT = bytes32(uint256(keccak256("registry.attestationInfo")) - 1); + + /// @dev Integrator address => message digest -> attestation info mapping. + function _getAttestationInfoStorage() + internal + pure + returns (mapping(address => mapping(bytes32 => AttestationInfo)) storage $) + { + uint256 slot = uint256(ATTESTATION_INFO_SLOT); + assembly ("memory-safe") { + $.slot := slot + } + } + // =============== Getters =============================================================== // TODO: Can this be public? @@ -167,26 +226,29 @@ contract Router is IRouter, MessageSequence, TransceiverRegistry { } /// @inheritdoc IRouter - function receiveMessage( - uint16 sourceChain, - UniversalAddress senderAddress, - uint16 destinationChain, - address refundAddress, - bytes32 messageHash - ) external payable isRegistered returns (uint128) { - _receiveMessage(sourceChain, senderAddress, destinationChain, refundAddress, messageHash); + function attestMessage( + uint16 sourceChainId, + UniversalAddress sourceAddress, + uint64 sequence, + uint16 destinationChainId, + UniversalAddress destinationAddress, + bytes32 payloadHash + ) external { + // This is called by the transceiver so we don't check isRegistered. + _attestMessage(sourceChainId, sourceAddress, sequence, destinationChainId, destinationAddress, payloadHash); } /// @inheritdoc IRouter - function attestMessage( - uint16 sourceChain, // Wormhole Chain ID - UniversalAddress sourceAddress, // UniversalAddress of the message sender (integrator) - uint64 sequence, // Next sequence number for that integrator (consuming the sequence number) - uint16 destinationChain, // Wormhole Chain ID - UniversalAddress destinationAddress, // UniversalAddress of the messsage recipient (integrator on destination chain) - bytes32 payloadHash // keccak256 of arbitrary payload from the integrator - ) external isRegistered { - _attestMessage(sourceChain, sourceAddress, sequence, destinationChain, destinationAddress, payloadHash); + function receiveMessage( + uint16 sourceChainId, + UniversalAddress sourceAddress, + uint64 sequence, + uint16 destinationChainId, + UniversalAddress destinationAddress, + bytes32 payloadHash + ) external payable isRegistered returns (uint128 enabledBitmap, uint128 attestedBitmap) { + return + _receiveMessage(sourceChainId, sourceAddress, sequence, destinationChainId, destinationAddress, payloadHash); } // =============== Internal ============================================================== @@ -211,6 +273,19 @@ contract Router is IRouter, MessageSequence, TransceiverRegistry { _; } + function computeMessageDigest( + uint16 sourceChain, + UniversalAddress sourceAddress, + uint64 sequence, + uint16 destinationChain, + UniversalAddress destinationAddress, + bytes32 payloadHash + ) internal returns (bytes32) { + return keccak256( + abi.encodePacked(sourceChain, sourceAddress, sequence, destinationChain, destinationAddress, payloadHash) + ); + } + function _sendMessage( uint16 destinationChainId, UniversalAddress recipientAddress, @@ -249,55 +324,116 @@ contract Router is IRouter, MessageSequence, TransceiverRegistry { // call sendMessage } - function _receiveMessage( - uint16 sourceChain, - UniversalAddress senderAddress, + function _attestMessage( + uint16 sourceChainId, + UniversalAddress sourceAddress, + uint64 sequence, uint16 destinationChainId, - address refundAddress, - bytes32 messageHash - ) internal returns (uint128) { + UniversalAddress destinationAddress, + bytes32 payloadHash + ) internal { + address integrator = destinationAddress.toAddress(); + + // sanity check that destinationChain is this chain if (destinationChainId != chainId) { revert InvalidDestinationChainId(); } - // get the enabled receive transceivers for [msg.sender][recipientChain] - address[] memory recvTransceivers = getRecvTransceiversByChain(msg.sender, sourceChain); + // get enabled recv transceivers for [destinationAddress][sourceChain] + address[] memory recvTransceivers = getRecvTransceiversByChain(destinationAddress.toAddress(), sourceChainId); uint256 len = recvTransceivers.length; - if (len == 0) { + // Don't need to revert on zero len because the check below will fail if it's empty. + + // get the index of this receive transceiver + uint8 index = type(uint8).max; + for (uint256 idx = 0; idx < len;) { + if (recvTransceivers[idx] == msg.sender) { + index = _getTransceiverInfosStorage()[integrator][msg.sender].index; + break; + } + unchecked { + ++idx; + } + } + + if (index == type(uint8).max) { revert TransceiverNotEnabled(); } - // Find the transceiver for the source chain. - // address transceiver = this.getRecvTransceiverByChain(msg.sender, sourceChain); - // Receive the message. + // compute the message digest + bytes32 messageDigest = computeMessageDigest( + sourceChainId, sourceAddress, sequence, destinationChainId, destinationAddress, payloadHash + ); + + AttestationInfo storage attestationInfo = _getAttestationInfoStorage()[integrator][messageDigest]; + + // do not revert if already attested or executed + if (attestationInfo.executed) { + return; + } + + // set the bit in perIntegratorAttestations[destinationAddress][digest] corresponding to msg.sender + attestationInfo.attestedTransceivers |= uint128(1 << index); + emit MessageAttestedTo( + sourceChainId, + sourceAddress, + sequence, + destinationChainId, + destinationAddress, + payloadHash, + attestationInfo.attestedTransceivers + ); } - function _attestMessage( - uint16 sourceChain, // Wormhole Chain ID - UniversalAddress sourceAddress, // UniversalAddress of the message sender (integrator) - uint64 sequence, // Next sequence number for that integrator (consuming the sequence number) - uint16 destinationChainId, // Wormhole Chain ID - UniversalAddress destinationAddress, // UniversalAddress of the messsage recipient (integrator on destination chain) - bytes32 payloadHash // keccak256 of arbitrary payload from the integrator - ) internal { + function _receiveMessage( + uint16 sourceChainId, + UniversalAddress sourceAddress, + uint64 sequence, + uint16 destinationChainId, + UniversalAddress destinationAddress, + bytes32 payloadHash + ) internal returns (uint128 enabledBitmap, uint128 attestedBitmap) { + // sanity check that destinationChainId is this chain if (destinationChainId != chainId) { revert InvalidDestinationChainId(); } - // get the enabled receive transceivers for [msg.sender][recipientChain] - address[] memory recvTransceivers = getRecvTransceiversByChain(msg.sender, sourceChain); - uint256 len = recvTransceivers.length; - if (len == 0) { + // sanity check that msg.sender is integrator: The caller has isRegistered. + + enabledBitmap = _getEnabledRecvTransceiversBitmapForChain(msg.sender, sourceChainId); + if (enabledBitmap == 0) { revert TransceiverNotEnabled(); } - address transceiver; - emit MessageAttestedTo(msg.sender, transceiver, payloadHash); - - // get enabled recv transceivers for [destinationAddress][sourceChain] - // address transceiver = this.getRecvTransceiverByChain(sourceChain); - // check that msg.sender is one of those transceivers // compute the message digest - // set the bit in perIntegratorAttestations[destinationAddress][digest] corresponding to msg.sender + bytes32 messageDigest = computeMessageDigest( + sourceChainId, sourceAddress, sequence, destinationChainId, destinationAddress, payloadHash + ); + + AttestationInfo storage attestationInfo = _getAttestationInfoStorage()[msg.sender][messageDigest]; + + // revert if not in perIntegratorAttestations map + if ((attestationInfo.attestedTransceivers == 0) && (!attestationInfo.executed)) { + revert UnknownMessageAttestation(); + } + + // revert if already executed + if (attestationInfo.executed) { + revert AlreadyExecuted(); + } + + // set the executed flag in perIntegratorAttestations[destinationAddress][digest] + attestationInfo.executed = true; + attestedBitmap = attestationInfo.attestedTransceivers; + emit MessageReceived( + sourceChainId, + sourceAddress, + sequence, + destinationChainId, + destinationAddress, + payloadHash, + enabledBitmap, + attestedBitmap + ); } } diff --git a/evm/src/interfaces/IRouter.sol b/evm/src/interfaces/IRouter.sol index d42c19b5..8f2fde4a 100644 --- a/evm/src/interfaces/IRouter.sol +++ b/evm/src/interfaces/IRouter.sol @@ -18,34 +18,36 @@ interface IRouter is IMessageSequence { bytes32 payloadHash ) external payable returns (uint64); - /// @notice Receive a message from another chain called by integrator. - /// @param sourceChain The Wormhole chain ID of the recipient. - /// @param senderAddress The universal address of the peer on the recipient chain. - /// @param destinationChain The Wormhole chain ID of the destination. - /// @param refundAddress The source chain refund address passed to the Transceiver. - /// @param messageHash hash A message to be sent to the recipient chain. - /// @return uint128 The bitmap - function receiveMessage( - uint16 sourceChain, - UniversalAddress senderAddress, - uint16 destinationChain, - address refundAddress, - bytes32 messageHash - ) external payable returns (uint128); - /// @notice Called by a Transceiver contract to attest to a message. - /// @param sourceChain The Wormhole chain ID of the recipient. - /// @param sourceAddress The universal address of the peer on the recipient chain. + /// @param sourceChainId The Wormhole chain ID of the sender. + /// @param sourceAddress The universal address of the peer on the sending chain. /// @param sequence The sequence number of the message (per integrator). - /// @param destinationChain The Wormhole chain ID of the destination. + /// @param destinationChainId The Wormhole chain ID of the destination. /// @param destinationAddress The destination address of the message. - /// @param payloadHash A message to be sent to the recipient chain. + /// @param payloadHash The keccak256 of payload from the integrator. function attestMessage( - uint16 sourceChain, // Wormhole Chain ID - UniversalAddress sourceAddress, // UniversalAddress of the message sender (integrator) - uint64 sequence, // Next sequence number for that integrator (consuming the sequence number) - uint16 destinationChain, // Wormhole Chain ID - UniversalAddress destinationAddress, // UniversalAddress of the messsage recipient (integrator on destination chain) - bytes32 payloadHash // keccak256 of arbitrary payload from the integrator + uint16 sourceChainId, + UniversalAddress sourceAddress, + uint64 sequence, + uint16 destinationChainId, + UniversalAddress destinationAddress, + bytes32 payloadHash ) external; + + /// @notice Called by a integrator contract to receive a message and mark it executed. + /// @param sourceChainId The Wormhole chain ID of the sender. + /// @param sourceAddress The universal address of the peer on the sending chain. + /// @param sequence The sequence number of the message (per integrator). + /// @param destinationChainId The Wormhole chain ID of the destination. + /// @param destinationAddress The destination address of the message. + /// @param payloadHash The keccak256 of payload from the integrator. + /// @return uint128 The bitmap + function receiveMessage( + uint16 sourceChainId, + UniversalAddress sourceAddress, + uint64 sequence, + uint16 destinationChainId, + UniversalAddress destinationAddress, + bytes32 payloadHash + ) external payable returns (uint128, uint128); } diff --git a/evm/test/Router.t.sol b/evm/test/Router.t.sol index 41177c2d..231f27c6 100644 --- a/evm/test/Router.t.sol +++ b/evm/test/Router.t.sol @@ -296,8 +296,10 @@ contract RouterTest is Test { require(transceiver3.getMessagesSent() == 0, "Should not have sent second message on transceiver 3"); } - function test_receiveMessage() public { + function test_attestMessage() public { + UniversalAddress sourceIntegrator = UniversalAddressLibrary.fromAddress(address(userA)); address integrator = address(new Integrator(address(router))); + UniversalAddress destIntegrator = UniversalAddressLibrary.fromAddress(address(integrator)); address admin = address(new Admin(integrator, address(router))); TransceiverImpl transceiver1 = new TransceiverImpl(); TransceiverImpl transceiver2 = new TransceiverImpl(); @@ -305,10 +307,10 @@ contract RouterTest is Test { vm.startPrank(integrator); router.register(admin); - // Receiving with no transceivers should revert. + // Attesting with no transceivers should revert. vm.startPrank(integrator); vm.expectRevert(abi.encodeWithSelector(Router.TransceiverNotEnabled.selector)); - router.receiveMessage(2, UniversalAddressLibrary.fromAddress(userA), OurChainId, refundAddr, messageHash); + router.attestMessage(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); // Now enable some transceivers. vm.startPrank(admin); @@ -316,24 +318,33 @@ contract RouterTest is Test { router.setRecvTransceiver(integrator, address(transceiver2), 2); router.setRecvTransceiver(integrator, address(transceiver3), 3); - // Only an integrator can call receive. - vm.startPrank(userA); - vm.expectRevert(abi.encodeWithSelector(Router.IntegratorNotRegistered.selector)); - router.receiveMessage(2, UniversalAddressLibrary.fromAddress(userA), OurChainId + 1, refundAddr, messageHash); + // Only a transceiver can call attest. + vm.startPrank(userB); + vm.expectRevert(abi.encodeWithSelector(Router.TransceiverNotEnabled.selector)); + router.attestMessage(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); - // Receiving a message destine for the wrong chain should revert. - vm.startPrank(integrator); + // Attestinging a message destine for the wrong chain should revert. + vm.startPrank(address(transceiver2)); vm.expectRevert(abi.encodeWithSelector(Router.InvalidDestinationChainId.selector)); - router.receiveMessage(2, UniversalAddressLibrary.fromAddress(userA), OurChainId + 1, refundAddr, messageHash); + router.attestMessage(2, sourceIntegrator, 1, OurChainId + 1, destIntegrator, messageHash); - // Receive a message on chain two. It should go out on the first two transceivers, but not the third one. + // This attest should work. + vm.startPrank(address(transceiver2)); + router.attestMessage(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); + + // Receive what we just attested to mark it executed. vm.startPrank(integrator); - router.receiveMessage(2, UniversalAddressLibrary.fromAddress(userA), OurChainId, refundAddr, messageHash); - // TODO: What should we validate here?? + router.receiveMessage(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); + + // Attesting after receive should still work. + vm.startPrank(address(transceiver2)); + router.attestMessage(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); } - function test_attestMessage() public { + function test_receiveMessage() public { + UniversalAddress sourceIntegrator = UniversalAddressLibrary.fromAddress(address(userA)); address integrator = address(new Integrator(address(router))); + UniversalAddress destIntegrator = UniversalAddressLibrary.fromAddress(address(integrator)); address admin = address(new Admin(integrator, address(router))); TransceiverImpl transceiver1 = new TransceiverImpl(); TransceiverImpl transceiver2 = new TransceiverImpl(); @@ -344,56 +355,45 @@ contract RouterTest is Test { // Receiving with no transceivers should revert. vm.startPrank(integrator); vm.expectRevert(abi.encodeWithSelector(Router.TransceiverNotEnabled.selector)); - router.attestMessage( - 2, - UniversalAddressLibrary.fromAddress(userA), - 1, - OurChainId, - UniversalAddressLibrary.fromAddress(userB), - messageHash - ); + router.receiveMessage(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); - // Now enable some transceivers. + // Now enable some transceivers so we can attest. Receive doesn't use the transceivers. vm.startPrank(admin); router.setRecvTransceiver(integrator, address(transceiver1), 2); router.setRecvTransceiver(integrator, address(transceiver2), 2); router.setRecvTransceiver(integrator, address(transceiver3), 3); - // Only an integrator can call attest. - vm.startPrank(userA); + // Only an integrator can call receive. + vm.startPrank(userB); vm.expectRevert(abi.encodeWithSelector(Router.IntegratorNotRegistered.selector)); - router.attestMessage( - 2, - UniversalAddressLibrary.fromAddress(userA), - 1, - OurChainId + 1, - UniversalAddressLibrary.fromAddress(userB), - messageHash - ); + router.receiveMessage(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); - // Attestinging a message destine for the wrong chain should revert. + // Receiving a message destine for the wrong chain should revert. vm.startPrank(integrator); vm.expectRevert(abi.encodeWithSelector(Router.InvalidDestinationChainId.selector)); - router.attestMessage( - 2, - UniversalAddressLibrary.fromAddress(userA), - 1, - OurChainId + 1, - UniversalAddressLibrary.fromAddress(userB), - messageHash - ); - - // Receive a message on chain two. It should go out on the first two transceivers, but not the third one. + router.receiveMessage(2, sourceIntegrator, 1, OurChainId + 1, destIntegrator, messageHash); + + // Receiving before there are any attestations should revert. + vm.startPrank(integrator); + vm.expectRevert(abi.encodeWithSelector(Router.UnknownMessageAttestation.selector)); + router.receiveMessage(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); + + // Attest so we can receive. + vm.startPrank(address(transceiver2)); + router.attestMessage(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); + + // This receive should work. vm.startPrank(integrator); - router.attestMessage( - 2, - UniversalAddressLibrary.fromAddress(userA), - 1, - OurChainId, - UniversalAddressLibrary.fromAddress(userB), - messageHash - ); - // TODO: What should we validate here?? + (uint128 enabledBitmap, uint128 attestedBitmap) = + router.receiveMessage(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); + + // Make sure it return the right bitmaps. + require(enabledBitmap == 0x03, "Enabled bitmap is wrong"); + require(attestedBitmap == 0x02, "Attested bitmap is wrong"); + + // But doing it again should revert. + vm.expectRevert(abi.encodeWithSelector(Router.AlreadyExecuted.selector)); + router.receiveMessage(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); } function test_sendMessageIncrementsSequence() public {