From 6a7fa2bba143392817e3595011c006c33077c9cf Mon Sep 17 00:00:00 2001 From: Paul Noel Date: Fri, 20 Sep 2024 13:46:01 -0500 Subject: [PATCH 1/5] evm: add TransceiverRegistry with delegate --- evm/src/Router.sol | 194 +++++++- evm/src/TransceiverRegistry.sol | 708 ++++++++++++++++++++++++++++ evm/src/interfaces/IRouter.sol | 27 +- evm/src/interfaces/ITransceiver.sol | 32 ++ evm/test/Router.t.sol | 200 +++++++- evm/test/TransceiverRegistry.t.sol | 315 +++++++++++++ 6 files changed, 1452 insertions(+), 24 deletions(-) create mode 100644 evm/src/TransceiverRegistry.sol create mode 100644 evm/src/interfaces/ITransceiver.sol create mode 100644 evm/test/TransceiverRegistry.t.sol diff --git a/evm/src/Router.sol b/evm/src/Router.sol index 520663ef..1d64cb90 100644 --- a/evm/src/Router.sol +++ b/evm/src/Router.sol @@ -3,31 +3,207 @@ pragma solidity ^0.8.13; import "./interfaces/IRouter.sol"; import "./MessageSequence.sol"; +import "./TransceiverRegistry.sol"; +import "./interfaces/ITransceiver.sol"; -contract Router is IRouter, MessageSequence { +contract Router is IRouter, MessageSequence, TransceiverRegistry { string public constant ROUTER_VERSION = "0.0.1"; + // =============== Events ================================================================ + + /// @notice Emitted when an integrator registers a delegate. + /// @param integrator The address of the integrator. + /// @param delegate The address of the delegate. + event DelegateRegistered(address integrator, address delegate); + + /// @notice Emitted when a message has been attested to. + /// @param integrator The address of the integrator. + /// @param transceiver The address of the transceiver. + /// @param digest The digest of the message. + event MessageAttestedTo(address integrator, address transceiver, bytes32 digest); + + /// @notice Emitted when a message has been sent. + /// @param sender The address of the sender. + /// @param recipient The address of the recipient. + /// @param recipientChain The chainId of the recipient. + /// @param digest The digest of the message. + event MessageSent(address sender, address recipient, uint16 recipientChain, bytes32 digest); + + // =============== Errors ================================================================ + + /// @notice Error when the transceiver is disabled. + error TransceiverNotEnabled(); + + /// @notice Error when the admin is the zero address. + error InvalidAdminZeroAddress(); + + /// @notice Error when the integrator did not register an admin. + error IntegratorNotRegistered(); + + /// @notice Error when the caller is not the registered admin. + error CallerNotAdmin(); + + // =============== Storage =============================================================== + + /// @dev Holds the integrator address to IntegratorConfig mapping. + /// mapping(address => IntegratorConfig) + bytes32 private constant INTEGRATOR_CONFIGS_SLOT = bytes32(uint256(keccak256("registry.integratorConfigs")) - 1); + + /// @dev Integrator address => IntegratorConfig mapping. + function _getIntegratorConfigsStorage() internal pure returns (mapping(address => IntegratorConfig) storage $) { + uint256 slot = uint256(INTEGRATOR_CONFIGS_SLOT); + assembly ("memory-safe") { + $.slot := slot + } + } + // =============== External ============================================================== + /// @notice This is the first thing an integrator should do to register the admin address. + /// The admin address is used to manage the transceivers. + /// @dev The msg.sender needs to be the integrator contract. + /// @param admin The address of the admin. Pass in msg.sender, if you want the integrator to be the admin. + function registerAdmin(address admin) external { + // Get the storage for this integrator contract + mapping(address => IntegratorConfig) storage integratorConfigs = _getIntegratorConfigsStorage(); + // Do some checks? Should address(0) mean something special? + if (admin == address(0)) { + revert InvalidAdminZeroAddress(); + } + // Update the storage. + integratorConfigs[msg.sender] = IntegratorConfig({isInitialized: true, admin: admin}); + } + + /// @notice The admin contract calls this function. + /// @param integrator The address of the integrator contract. + /// @param transceiver The address of the Transceiver contract. + /// @param chain The chain ID of the Transceiver contract. + function setSendTransceiver(address integrator, address transceiver, uint16 chain) external { + _checkIntegratorAdmin(integrator, msg.sender); + _setSendTransceiver(integrator, transceiver, chain); + } + + /// @notice The admin contract uses this function. + /// @param integrator The address of the integrator contract. + /// @param transceiver The address of the Transceiver contract. + /// @param chain The chain ID of the Transceiver contract. + function setRecvTransceiver(address integrator, address transceiver, uint16 chain) external { + _checkIntegratorAdmin(integrator, msg.sender); + _setRecvTransceiver(integrator, transceiver, chain); + } + /// @inheritdoc IRouter function sendMessage( uint16 recipientChain, UniversalAddress recipientAddress, address refundAddress, - bytes memory message + bytes32 payloadHash ) external payable returns (uint64) { - return _sendMessage(recipientChain, recipientAddress, refundAddress, message, msg.sender); + return _sendMessage(recipientChain, recipientAddress, refundAddress, payloadHash, msg.sender); + } + + /// @dev Receive a message from another chain called by integrator. + /// @param sourceChain The Wormhole chain ID of the recipient. + /// @param senderAddress The universal address of the peer on the recipient chain. + /// @param refundAddress The source chain refund address passed to the Transceiver. + /// @param messageHash The hash of the message. + /// @return uint128 The bitmap + function receiveMessage( + uint16 sourceChain, + UniversalAddress senderAddress, + address refundAddress, + bytes32 messageHash + ) external payable returns (uint128) { + // Find the transceiver for the source chain. + // address transceiver = this.getRecvTransceiverByChain(msg.sender, sourceChain); + // Receive the message. + } + + /// @inheritdoc IRouter + function attestMessage( + uint16 sourceChain, // Wormhole Chain ID + UniversalAddress sourceAddress, // UniversalAddress of the message sender (integrator) + uint64 sequence, // Next sequence number for that integrator (consuming the sequence number) + uint16 destinationChainId, // Wormhole Chain ID + UniversalAddress destinationAddress, // UniversalAddress of the messsage recipient (integrator on destination chain) + bytes32 payloadHash // keccak256 of arbitrary payload from the integrator + ) external { + _attestMessage(sourceChain, sourceAddress, sequence, destinationChainId, destinationAddress, payloadHash); } // =============== Internal ============================================================== + /// @notice This function checks that the integrator is registered. + /// @dev This function will revert under the following conditions: + /// - The integrator is not registered + /// @param integrator The integrator address + function _checkIntegrator(address integrator) internal view { + IntegratorConfig storage config = _getIntegratorConfigsStorage()[integrator]; + if (!config.isInitialized) { + revert IntegratorNotRegistered(); + } + } + + /// @notice This function checks that the integrator is registered and the admin is valid. + /// @dev This function will revert under the following conditions: + /// - The integrator is not registered + /// - The admin is not configured for this integrator + /// @param integrator The integrator address + /// @param admin The admin address for this integrator + function _checkIntegratorAdmin(address integrator, address admin) internal view { + IntegratorConfig storage config = _getIntegratorConfigsStorage()[integrator]; + if (!config.isInitialized) { + revert IntegratorNotRegistered(); + } + + if (config.admin != admin) { + revert CallerNotAdmin(); + } + } + function _sendMessage( - uint16, // recipientChain, - UniversalAddress, // recipientAddress, - address, // refundAddress, - bytes memory, // _message, - address sender + uint16 chainId, + UniversalAddress recipientAddress, + address refundAddress, + bytes32 messageHash, + address // sender ) internal returns (uint64 sequence) { - sequence = _useMessageSequence(sender); + _checkIntegrator(msg.sender); + // get the next sequence number for msg.sender + sequence = _useMessageSequence(msg.sender); + // get the enabled send transceivers for [msg.sender][recipientChain] + address[] memory sendTransceivers = this.getSendTransceiversByChain(msg.sender, chainId); + if (sendTransceivers.length == 0) { + revert TransceiverNotEnabled(); + } + for (uint256 i = 0; i < sendTransceivers.length; i++) { + // quote the delivery price + uint256 deliveryPrice = ITransceiver(sendTransceivers[i]).quoteDeliveryPrice(chainId); + // call sendMessage + ITransceiver(sendTransceivers[i]).sendMessage{value: deliveryPrice}( + chainId, messageHash, recipientAddress, UniversalAddressLibrary.fromAddress(refundAddress).toBytes32() + ); + } + // for each enabled transceiver + // quote the delivery price + // see https://github.com/wormhole-foundation/example-native-token-transfers/blob/68a7ca4132c74e838ac23e54752e8c0bc02bb4a2/evm/src/NttManager/ManagerBase.sol#L113 + // call sendMessage + } + + function _attestMessage( + uint16 sourceChain, // Wormhole Chain ID + UniversalAddress sourceAddress, // UniversalAddress of the message sender (integrator) + uint64 sequence, // Next sequence number for that integrator (consuming the sequence number) + uint16 destinationChainId, // Wormhole Chain ID + UniversalAddress destinationAddress, // UniversalAddress of the messsage recipient (integrator on destination chain) + bytes32 payloadHash // keccak256 of arbitrary payload from the integrator + ) internal { + _checkIntegrator(msg.sender); + // sanity check that destinationChainId is this chain + // get enabled recv transceivers for [destinationAddress][sourceChain] + // address transceiver = this.getRecvTransceiverByChain(sourceChain); + // check that msg.sender is one of those transceivers + // compute the message digest + // set the bit in perIntegratorAttestations[destinationAddress][digest] corresponding to msg.sender } } diff --git a/evm/src/TransceiverRegistry.sol b/evm/src/TransceiverRegistry.sol new file mode 100644 index 00000000..f47bd121 --- /dev/null +++ b/evm/src/TransceiverRegistry.sol @@ -0,0 +1,708 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.13; + +/// @title TransceiverRegistry +/// @notice This contract is responsible for handling the registration of Transceivers. +/// @dev This contract checks that a few critical invariants hold when transceivers are added or removed, +/// including: +/// 1. If a transceiver is not registered, it should be enabled. +/// 2. The value set in the bitmap of transceivers +/// should directly correspond to the whether the transceiver is enabled +abstract contract TransceiverRegistry { + constructor() {} + + /// @dev Information about registered transceivers. + struct TransceiverInfo { + // whether this transceiver is registered + bool registered; + uint8 index; // the index into the integrator's transceivers array + } + + // TODO: Does this need to be a struct? + /// @dev Bitmap encoding the enabled transceivers. + /// invariant: forall (i: uint8), enabledTransceiverBitmap & i == 1 <=> transceiverInfos[i].enabled + struct _EnabledTransceiverBitmap { + uint128 bitmap; // MAX_TRANSCEIVERS = 128 + } + + /// @dev Total number of registered transceivers. This number can only increase. + /// invariant: numRegisteredTransceivers <= MAX_TRANSCEIVERS + /// invariant: forall (i: uint8), + /// i < numRegisteredTransceivers <=> exists (a: address), transceiverInfos[a].index == i + struct _NumTransceivers { + uint8 registered; + } + + struct IntegratorConfig { + bool isInitialized; + address admin; + } + + uint8 constant MAX_TRANSCEIVERS = 128; + + // =============== Events =============================================== + + /// @notice Emitted when a send side transceiver is added. + /// @param integrator The address of the integrator. + /// @param transceiver The address of the transceiver. + /// @param chainId The chain to which the threshold applies. + /// @param transceiversNum The current number of transceivers. + event SendTransceiverAdded(address integrator, address transceiver, uint16 chainId, uint64 transceiversNum); + + /// @notice Emitted when a receive side transceiver is added. + /// @param integrator The address of the integrator. + /// @param transceiver The address of the transceiver. + /// @param chainId The chain to which the threshold applies. + /// @param transceiversNum The current number of transceivers. + event RecvTransceiverAdded(address integrator, address transceiver, uint16 chainId, uint64 transceiversNum); + + /// @notice Emitted when a send side transceiver is enabled for a chain. + /// @param integrator The address of the integrator. + /// @param transceiver The address of the transceiver. + /// @param chainId The chain to which the threshold applies. + event SendTransceiverEnabledForChain(address integrator, address transceiver, uint16 chainId); + + /// @notice Emitted when a receive side transceiver is enabled for a chain. + /// @param integrator The address of the integrator. + /// @param transceiver The address of the transceiver. + /// @param chainId The chain to which the threshold applies. + event RecvTransceiverEnabledForChain(address integrator, address transceiver, uint16 chainId); + + /// @notice Emitted when a send side transceiver is removed from the nttManager. + /// @param integrator The address of the integrator. + /// @param transceiver The address of the transceiver. + /// @param chainId The chain to which the threshold applies. + event SendTransceiverDisabled(address integrator, address transceiver, uint16 chainId); + + /// @notice Emitted when a receive side transceiver is removed from the nttManager. + /// @param integrator The address of the integrator. + /// @param transceiver The address of the transceiver. + /// @param chainId The chain to which the threshold applies. + event RecvTransceiverDisabled(address integrator, address transceiver, uint16 chainId); + + // =============== Errors =============================================== + + /// @notice Error when the caller is not the transceiver. + /// @param caller The address of the caller. + error CallerNotTransceiver(address caller); + + /// @notice Error when the transceiver is the zero address. + error InvalidTransceiverZeroAddress(); + + /// @notice Error when the transceiver is disabled. + error DisabledTransceiver(address transceiver); + + /// @notice Error when the number of registered transceivers + /// exceeeds (MAX_TRANSCEIVERS = 64). + error TooManyTransceivers(); + + /// @notice Error when attempting to remove a transceiver + /// that is not registered. + /// @param transceiver The address of the transceiver. + error NonRegisteredTransceiver(address transceiver); + + /// @notice Error when attempting to use an incorrect chain + /// @param chain The id of the incorrect chain + error InvalidChain(uint16 chain); + + /// @notice Error when attempting to enable a transceiver that is already enabled. + /// @param transceiver The address of the transceiver. + error TransceiverAlreadyEnabled(address transceiver); + + // TODO: Not sure if I need this, yet. Will add if Router.sol needs it. + // modifier onlyTransceiver() { + // if (!_getTransceiverInfosStorage()[msg.sender].enabled) { + // revert CallerNotTransceiver(msg.sender); + // } + // _; + // } + + // =============== Storage =============================================== + + /// @dev Holds the integrator address to transceiver address to TransceiverInfo mapping. + /// mapping(address => mapping(address => TransceiverInfo)) + bytes32 private constant TRANSCEIVER_INFOS_SLOT = bytes32(uint256(keccak256("registry.transceiverInfos")) - 1); + + /// @dev Holds send side Integrator address => Transceiver addresses mapping. + /// mapping(address => address[]) across all chains + bytes32 private constant REGISTERED_TRANSCEIVERS_SLOT = + bytes32(uint256(keccak256("registry.registeredTransceivers")) - 1); + + /// @dev Holds send side Integrator address => NumTransceivers mapping. + /// mapping(address => _NumTransceivers) + bytes32 private constant NUM_REGISTERED_TRANSCEIVERS_SLOT = + bytes32(uint256(keccak256("registry.numRegisteredTransceivers")) - 1); + + // =============== Send side ============================================= + + /// @dev Holds send side integrator address => Chain ID => Enabled transceiver bitmap mapping. + /// mapping(address => mapping(uint16 => uint128)) + bytes32 private constant ENABLED_SEND_TRANSCEIVER_BITMAP_SLOT = + bytes32(uint256(keccak256("registry.sendTransceiverBitmap")) - 1); + + /// @dev Holds send side Integrator address => Transceiver addresses mapping. + /// mapping(address => address[]) across all chains + bytes32 private constant REGISTERED_SEND_TRANSCEIVERS_SLOT = + bytes32(uint256(keccak256("registry.registeredSendTransceivers")) - 1); + + // =============== Recv side ============================================= + + /// @dev Holds receive side integrator address => Chain ID => Enabled transceiver bitmap mapping. + /// mapping(address => mapping(uint16 => uint128)) + bytes32 private constant ENABLED_RECV_TRANSCEIVER_BITMAP_SLOT = + bytes32(uint256(keccak256("registry.recvTransceiverBitmap")) - 1); + + /// @dev Holds receive side Integrator address => Transceiver addresses mapping. + /// mapping(address => address[]) across all chains + bytes32 private constant REGISTERED_RECV_TRANSCEIVERS_SLOT = + bytes32(uint256(keccak256("registry.registeredRecvTransceivers")) - 1); + + // =============== Mappings =============================================== + + /// @dev Integrator address => transceiver address => TransceiverInfo mapping. + function _getTransceiverInfosStorage() + internal + pure + returns (mapping(address => mapping(address => TransceiverInfo)) storage $) + { + uint256 slot = uint256(TRANSCEIVER_INFOS_SLOT); + assembly ("memory-safe") { + $.slot := slot + } + } + + /// @dev Integrator address => Chain ID => Enabled transceiver bitmap mapping. + function _getPerChainSendTransceiverBitmapStorage() + private + pure + returns (mapping(address => mapping(uint16 => _EnabledTransceiverBitmap)) storage $) + { + uint256 slot = uint256(ENABLED_SEND_TRANSCEIVER_BITMAP_SLOT); + assembly ("memory-safe") { + $.slot := slot + } + } + + /// @dev Integrator address => Chain ID => Enabled transceiver bitmap mapping. + function _getPerChainRecvTransceiverBitmapStorage() + private + pure + returns (mapping(address => mapping(uint16 => _EnabledTransceiverBitmap)) storage $) + { + uint256 slot = uint256(ENABLED_RECV_TRANSCEIVER_BITMAP_SLOT); + assembly ("memory-safe") { + $.slot := slot + } + } + + /// @dev Integrator address => Transceiver address[] mapping. + /// Contains all registered transceivers for this integrator. + function _getRegisteredTransceiversStorage() internal pure returns (mapping(address => address[]) storage $) { + uint256 slot = uint256(REGISTERED_TRANSCEIVERS_SLOT); + assembly ("memory-safe") { + $.slot := slot + } + } + + /// @dev Integrator address => NumTransceivers mapping. + /// Contains number of registered transceivers for this integrator. + /// The transceivers may or may not be enabled. + function _getNumTransceiversStorage() internal pure returns (mapping(address => _NumTransceivers) storage $) { + uint256 slot = uint256(NUM_REGISTERED_TRANSCEIVERS_SLOT); + assembly ("memory-safe") { + $.slot := slot + } + } + + // =============== Storage Getters/Setters ======================================== + + /// @dev Returns if the send side transceiver is enabled for the given integrator and chain. + /// @param integrator The integrator address + /// @param transceiver The transceiver address + /// @param chainId The chain ID + /// @return true if the transceiver is enabled, false otherwise. + function _isSendTransceiverEnabledForChain(address integrator, address transceiver, uint16 chainId) + internal + view + returns (bool) + { + uint128 bitmap = _getEnabledSendTransceiversBitmapForChain(integrator, chainId); + return _isTransceiverEnabledForChain(integrator, transceiver, bitmap); + } + + /// @dev Returns if the receive side transceiver is enabled for the given integrator and chain. + /// @param integrator The integrator address + /// @param transceiver The transceiver address + /// @param chainId The chain ID + /// @return true if the transceiver is enabled, false otherwise. + function _isRecvTransceiverEnabledForChain(address integrator, address transceiver, uint16 chainId) + internal + view + returns (bool) + { + uint128 bitmap = _getEnabledRecvTransceiversBitmapForChain(integrator, chainId); + return _isTransceiverEnabledForChain(integrator, transceiver, bitmap); + } + + /// @dev This is a common function between send/receive transceivers. + /// @param integrator The integrator address + /// @param transceiver The transceiver address + /// @return true if the transceiver is enabled, false otherwise. + function _isTransceiverEnabledForChain(address integrator, address transceiver, uint128 bitmap) + internal + view + returns (bool) + { + if (transceiver == address(0)) { + revert InvalidTransceiverZeroAddress(); + } + uint8 index = _getTransceiverInfosStorage()[integrator][transceiver].index; + return (bitmap & uint128(1 << index)) != 0; + } + + /// @dev This function will revert if the transceiver is an invalid address or not registered. + /// @param integrator The integrator address + /// @param transceiver The transceiver address + function _checkTransceiver(address integrator, address transceiver) internal view { + if (transceiver == address(0)) { + revert InvalidTransceiverZeroAddress(); + } + + if (!_getTransceiverInfosStorage()[integrator][transceiver].registered) { + revert NonRegisteredTransceiver(transceiver); + } + } + + /// @dev It is assumed that the integrator address is already validated (and not 0) + /// This just enables the send side transceiver. It does not register it. + /// @param integrator The integrator address + /// @param transceiver The transceiver address + /// @param chainId The chain ID + function _enableSendTransceiverForChain(address integrator, address transceiver, uint16 chainId) internal { + _checkTransceiver(integrator, transceiver); + + uint8 index = _getTransceiverInfosStorage()[integrator][transceiver].index; + mapping(address => mapping(uint16 => _EnabledTransceiverBitmap)) storage _bitmaps = + _getPerChainSendTransceiverBitmapStorage(); + _bitmaps[integrator][chainId].bitmap |= uint128(1 << index); + } + + /// @dev It is assumed that the integrator address is already validated (and not 0) + /// This just enables the receive side transceiver. It does not register it. + /// @param integrator The integrator address + /// @param transceiver The transceiver address + /// @param chainId The chain ID + function _enableRecvTransceiverForChain(address integrator, address transceiver, uint16 chainId) internal { + _checkTransceiver(integrator, transceiver); + + uint8 index = _getTransceiverInfosStorage()[integrator][transceiver].index; + mapping(address => mapping(uint16 => _EnabledTransceiverBitmap)) storage _bitmaps = + _getPerChainRecvTransceiverBitmapStorage(); + _bitmaps[integrator][chainId].bitmap |= uint128(1 << index); + } + + /// @dev This function enables a send side transceiver. If it is not registered, it will register it. + /// @param integrator The integrator address + /// @param transceiver The transceiver address + /// @param chainId The chain ID + /// @return index The index of this newly enabled send side transceiver + function _setSendTransceiver(address integrator, address transceiver, uint16 chainId) + internal + returns (uint8 index) + { + // These are everything for an integrator. + mapping(address => mapping(address => TransceiverInfo)) storage transceiverInfos = _getTransceiverInfosStorage(); + mapping(address => _NumTransceivers) storage _numTransceivers = _getNumTransceiversStorage(); + // This is send side for a specific chain. + mapping(address => mapping(uint16 => _EnabledTransceiverBitmap)) storage _enabledTransceiverBitmap = + _getPerChainSendTransceiverBitmapStorage(); + + if (transceiver == address(0)) { + revert InvalidTransceiverZeroAddress(); + } + + if (chainId == 0) { + revert InvalidChain(chainId); + } + + if (!transceiverInfos[integrator][transceiver].registered) { + if (_numTransceivers[integrator].registered >= MAX_TRANSCEIVERS) { + revert TooManyTransceivers(); + } + + // Create the TransceiverInfo + transceiverInfos[integrator][transceiver] = + TransceiverInfo({registered: true, index: _numTransceivers[integrator].registered}); + // Add this transceiver to the integrator => address[] mapping + _getRegisteredTransceiversStorage()[integrator].push(transceiver); + // Increment count of transceivers + _numTransceivers[integrator].registered++; + // Emit an event + emit SendTransceiverAdded(integrator, transceiver, chainId, _numTransceivers[integrator].registered); + } + + // _numTransceivers[integrator].enabled++; + + // Add this transceiver to the per chain list of transceivers by updating the bitmap + uint128 updatedEnabledTransceiverBitmap = _enabledTransceiverBitmap[integrator][chainId].bitmap + | uint128(1 << transceiverInfos[integrator][transceiver].index); + // ensure that this actually changed the bitmap + if (updatedEnabledTransceiverBitmap == _enabledTransceiverBitmap[integrator][chainId].bitmap) { + revert TransceiverAlreadyEnabled(transceiver); + } + _enabledTransceiverBitmap[integrator][chainId].bitmap = updatedEnabledTransceiverBitmap; + + _checkSendTransceiversInvariants(integrator); + emit SendTransceiverEnabledForChain(integrator, transceiver, chainId); + + return transceiverInfos[integrator][transceiver].index; + } + + /// @dev This function enables a transceiver. If it is not registered, it will register it. + /// @param integrator The integrator address + /// @param transceiver The transceiver address + /// @param chainId The chain ID + /// @return index The index of this newly enabled receive side transceiver + function _setRecvTransceiver(address integrator, address transceiver, uint16 chainId) + internal + returns (uint8 index) + { + // These are everything for an integrator. + mapping(address => mapping(address => TransceiverInfo)) storage transceiverInfos = _getTransceiverInfosStorage(); + mapping(address => _NumTransceivers) storage _numTransceivers = _getNumTransceiversStorage(); + // This is send side for a specific chain. + mapping(address => mapping(uint16 => _EnabledTransceiverBitmap)) storage _enabledTransceiverBitmap = + _getPerChainRecvTransceiverBitmapStorage(); + + if (transceiver == address(0)) { + revert InvalidTransceiverZeroAddress(); + } + + if (chainId == 0) { + revert InvalidChain(chainId); + } + + if (!transceiverInfos[integrator][transceiver].registered) { + if (_numTransceivers[integrator].registered >= MAX_TRANSCEIVERS) { + revert TooManyTransceivers(); + } + + // Create the TransceiverInfo + transceiverInfos[integrator][transceiver] = + TransceiverInfo({registered: true, index: _numTransceivers[integrator].registered}); + // Add this transceiver to the integrator => address[] mapping + _getRegisteredTransceiversStorage()[integrator].push(transceiver); + // Increment count of transceivers + _numTransceivers[integrator].registered++; + // Emit an event + emit RecvTransceiverAdded(integrator, transceiver, chainId, _numTransceivers[integrator].registered); + } + + // _numTransceivers[integrator].enabled++; + + // Add this transceiver to the per chain list of transceivers by updating the bitmap + uint128 updatedEnabledTransceiverBitmap = _enabledTransceiverBitmap[integrator][chainId].bitmap + | uint128(1 << transceiverInfos[integrator][transceiver].index); + // ensure that this actually changed the bitmap + if (updatedEnabledTransceiverBitmap == _enabledTransceiverBitmap[integrator][chainId].bitmap) { + revert TransceiverAlreadyEnabled(transceiver); + } + _enabledTransceiverBitmap[integrator][chainId].bitmap = updatedEnabledTransceiverBitmap; + + _checkRecvTransceiversInvariants(integrator); + emit RecvTransceiverEnabledForChain(integrator, transceiver, chainId); + + return transceiverInfos[integrator][transceiver].index; + } + + /// @dev This function disables a send side transceiver by chain. + /// @notice This function will revert under the following conditions: + /// - The transceiver is the zero address + /// - The transceiver is not registered + /// @param integrator The integrator address + /// @param transceiver The transceiver address + /// @param chainId The chain ID + function _disableSendTransceiver(address integrator, address transceiver, uint16 chainId) internal { + mapping(address => mapping(address => TransceiverInfo)) storage transceiverInfos = _getTransceiverInfosStorage(); + mapping(address => mapping(uint16 => _EnabledTransceiverBitmap)) storage _enabledTransceiverBitmap = + _getPerChainSendTransceiverBitmapStorage(); + + if (transceiver == address(0)) { + revert InvalidTransceiverZeroAddress(); + } + + if (chainId == 0) { + revert InvalidChain(chainId); + } + + TransceiverInfo storage info = transceiverInfos[integrator][transceiver]; + + if (!info.registered) { + revert NonRegisteredTransceiver(transceiver); + } + + uint128 updatedEnabledTransceiverBitmap = _enabledTransceiverBitmap[integrator][chainId].bitmap + & uint128(~(1 << transceiverInfos[integrator][transceiver].index)); + // ensure that this actually changed the bitmap + if (updatedEnabledTransceiverBitmap >= _enabledTransceiverBitmap[integrator][chainId].bitmap) { + revert DisabledTransceiver(transceiver); + } + _enabledTransceiverBitmap[integrator][chainId].bitmap = updatedEnabledTransceiverBitmap; + + _checkSendTransceiversInvariants(integrator); + // we call the invariant check on the transceiver here as well, since + // the above check only iterates through the enabled transceivers. + _checkSendTransceiverInvariants(integrator, transceiver); + emit SendTransceiverDisabled(integrator, transceiver, chainId); + } + + /// @dev This function disables a receive side transceiver by chain. + /// @notice This function will revert under the following conditions: + /// - The transceiver is the zero address + /// - The transceiver is not registered + /// @param integrator The integrator address + /// @param transceiver The transceiver address + /// @param chainId The chain ID + function _disableRecvTransceiver(address integrator, address transceiver, uint16 chainId) internal { + mapping(address => mapping(address => TransceiverInfo)) storage transceiverInfos = _getTransceiverInfosStorage(); + mapping(address => mapping(uint16 => _EnabledTransceiverBitmap)) storage _enabledTransceiverBitmap = + _getPerChainRecvTransceiverBitmapStorage(); + + if (transceiver == address(0)) { + revert InvalidTransceiverZeroAddress(); + } + + if (chainId == 0) { + revert InvalidChain(chainId); + } + + TransceiverInfo storage info = transceiverInfos[integrator][transceiver]; + + if (!info.registered) { + revert NonRegisteredTransceiver(transceiver); + } + + uint128 updatedEnabledTransceiverBitmap = _enabledTransceiverBitmap[integrator][chainId].bitmap + & uint128(~(1 << transceiverInfos[integrator][transceiver].index)); + // ensure that this actually changed the bitmap + if (updatedEnabledTransceiverBitmap >= _enabledTransceiverBitmap[integrator][chainId].bitmap) { + revert DisabledTransceiver(transceiver); + } + _enabledTransceiverBitmap[integrator][chainId].bitmap = updatedEnabledTransceiverBitmap; + + _checkRecvTransceiversInvariants(integrator); + // we call the invariant check on the transceiver here as well, since + // the above check only iterates through the enabled transceivers. + _checkRecvTransceiverInvariants(integrator, transceiver); + emit RecvTransceiverDisabled(integrator, transceiver, chainId); + } + + /// @param integrator The integrator address + /// @param forChainId The chain ID + /// @return bitmap The bitmap of the send side transceivers enabled for this integrator and chain + function _getEnabledSendTransceiversBitmapForChain(address integrator, uint16 forChainId) + internal + view + virtual + returns (uint128 bitmap) + { + if (forChainId == 0) { + revert InvalidChain(forChainId); + } + bitmap = _getPerChainSendTransceiverBitmapStorage()[integrator][forChainId].bitmap; + } + + /// @param integrator The integrator address + /// @param forChainId The chain ID + /// @return bitmap The bitmap of the send side transceivers enabled for this integrator and chain + function _getEnabledRecvTransceiversBitmapForChain(address integrator, uint16 forChainId) + internal + view + virtual + returns (uint128 bitmap) + { + if (forChainId == 0) { + revert InvalidChain(forChainId); + } + bitmap = _getPerChainRecvTransceiverBitmapStorage()[integrator][forChainId].bitmap; + } + + // =============== EXTERNAL FUNCTIONS ======================================== + + /// @notice Returns the registered send side transceiver addresses for the given integrator. + /// @param integrator The integrator address + /// @return result The registered transceivers for the given integrator. + function getTransceivers(address integrator) external view returns (address[] memory result) { + result = _getRegisteredTransceiversStorage()[integrator]; + } + + /// @notice Returns the enabled send side transceiver addresses for the given integrator. + /// @param integrator The integrator address + /// @param chainId The chainId for the desired transceivers + /// @return result The enabled send side transceivers for the given integrator and chain. + function getSendTransceiversByChain(address integrator, uint16 chainId) + external + view + returns (address[] memory result) + { + address[] memory allTransceivers = _getRegisteredTransceiversStorage()[integrator]; + address[] memory tempResult = new address[](allTransceivers.length); + for (uint256 i = 0; i < allTransceivers.length; i++) { + if (_isSendTransceiverEnabledForChain(integrator, allTransceivers[i], chainId)) { + tempResult[i] = allTransceivers[i]; + } + } + result = new address[](tempResult.length); + for (uint256 i = 0; i < tempResult.length; i++) { + result[i] = tempResult[i]; + } + } + + /// @notice Returns the enabled send side transceiver addresses for the given integrator. + /// @param integrator The integrator address + /// @param chainId The chainId for the desired transceivers + /// @return result The enabled send side transceivers for the given integrator. + function getRecvTransceiversByChain(address integrator, uint16 chainId) + external + view + returns (address[] memory result) + { + address[] memory allTransceivers = _getRegisteredTransceiversStorage()[integrator]; + address[] memory tempResult = new address[](allTransceivers.length); + for (uint256 i = 0; i < allTransceivers.length; i++) { + if (_isRecvTransceiverEnabledForChain(integrator, allTransceivers[i], chainId)) { + tempResult[i] = allTransceivers[i]; + } + } + result = new address[](tempResult.length); + for (uint256 i = 0; i < tempResult.length; i++) { + result[i] = tempResult[i]; + } + } + + // ============== Invariants ============================================= + + /// @dev Check that the transceiver is in a valid state. + /// Checking these invariants is somewhat costly, but we only need to do it + /// when modifying the transceivers, which happens infrequently. + function _checkSendTransceiversInvariants(address integrator) internal view { + // _NumTransceivers storage _numTransceivers = _getNumSendTransceiversStorage()[integrator]; + // address[] storage _enabledTransceivers = _getRegisteredSendTransceiversStorage()[integrator]; + + // uint256 numTransceiversEnabled = _numTransceivers.enabled; + // assert(numTransceiversEnabled == _enabledTransceivers.length); + + // for (uint256 i = 0; i < numTransceiversEnabled; i++) { + // _checkSendTransceiverInvariants(integrator, _enabledTransceivers[i]); + // } + + // // invariant: each transceiver is only enabled once + // for (uint256 i = 0; i < numTransceiversEnabled; i++) { + // for (uint256 j = i + 1; j < numTransceiversEnabled; j++) { + // assert(_enabledTransceivers[i] != _enabledTransceivers[j]); + // } + // } + + // // invariant: numRegisteredTransceivers <= MAX_TRANSCEIVERS + // assert(_numTransceivers.registered <= MAX_TRANSCEIVERS); + } + + /// @dev Check that the transceiver is in a valid state. + function _checkSendTransceiverInvariants(address integrator, address transceiver) private view { + // mapping(address => mapping(address => TransceiverInfo)) storage transceiverInfos = _getTransceiverInfosStorage(); + // mapping(address => mapping(uint16 => _EnabledTransceiverBitmap)) storage _enabledTransceiverBitmap = + // _getPerChainSendTransceiverBitmapStorage(); + // mapping(address => _NumTransceivers) storage _numTransceivers = _getNumSendTransceiversStorage(); + // mapping(address => address[]) storage _enabledTransceivers = _getRegisteredSendTransceiversStorage(); + + // TransceiverInfo memory transceiverInfo = transceiverInfos[integrator][transceiver]; + + // // if an transceiver is not registered, it should not be enabled + // assert(transceiverInfo.registered || (!transceiverInfo.enabled && transceiverInfo.index == 0)); + + // bool transceiverInEnabledBitmap = ( + // _enabledTransceiverBitmap[integrator][transceiverInfo.chainId].bitmap & uint128(1 << transceiverInfo.index) + // ) != 0; + // bool transceiverEnabled = transceiverInfo.enabled; + + // bool transceiverInEnabledTransceivers = false; + + // for (uint256 i = 0; i < _numTransceivers[integrator].enabled; i++) { + // if (_enabledTransceivers[integrator][i] == transceiver) { + // transceiverInEnabledTransceivers = true; + // break; + // } + // } + + // // invariant: transceiverInfos[integrator][transceiver].enabled + // // <=> enabledTransceiverBitmap & (1 << transceiverInfos[integrator][transceiver].index) != 0 + // assert(transceiverInEnabledBitmap == transceiverEnabled); + + // // invariant: transceiverInfos[integrator][transceiver].enabled <=> transceiver in _enabledTransceivers + // assert(transceiverInEnabledTransceivers == transceiverEnabled); + + // assert(transceiverInfo.index < _numTransceivers[integrator].registered); + } + + /// @dev Check that the transceiver is in a valid state. + /// Checking these invariants is somewhat costly, but we only need to do it + /// when modifying the transceivers, which happens infrequently. + function _checkRecvTransceiversInvariants(address integrator) internal view { + // _NumTransceivers storage _numTransceivers = _getNumRecvTransceiversStorage()[integrator]; + // address[] storage _enabledTransceivers = _getRegisteredRecvTransceiversStorage()[integrator]; + + // uint256 numTransceiversEnabled = _numTransceivers.enabled; + // assert(numTransceiversEnabled == _enabledTransceivers.length); + + // for (uint256 i = 0; i < numTransceiversEnabled; i++) { + // _checkRecvTransceiverInvariants(integrator, _enabledTransceivers[i]); + // } + + // // invariant: each transceiver is only enabled once + // for (uint256 i = 0; i < numTransceiversEnabled; i++) { + // for (uint256 j = i + 1; j < numTransceiversEnabled; j++) { + // assert(_enabledTransceivers[i] != _enabledTransceivers[j]); + // } + // } + + // // invariant: numRegisteredTransceivers <= MAX_TRANSCEIVERS + // assert(_numTransceivers.registered <= MAX_TRANSCEIVERS); + } + + /// @dev Check that the transceiver is in a valid state. + function _checkRecvTransceiverInvariants(address integrator, address transceiver) private view { + // mapping(address => mapping(address => TransceiverInfo)) storage transceiverInfos = _getTransceiverInfosStorage(); + // mapping(address => mapping(uint16 => _EnabledTransceiverBitmap)) storage _enabledTransceiverBitmap = + // _getPerChainRecvTransceiverBitmapStorage(); + // mapping(address => _NumTransceivers) storage _numTransceivers = _getNumRecvTransceiversStorage(); + // mapping(address => address[]) storage _enabledTransceivers = _getRegisteredRecvTransceiversStorage(); + + // TransceiverInfo memory transceiverInfo = transceiverInfos[integrator][transceiver]; + + // // if an transceiver is not registered, it should not be enabled + // assert(transceiverInfo.registered || (!transceiverInfo.enabled && transceiverInfo.index == 0)); + + // bool transceiverInEnabledBitmap = ( + // _enabledTransceiverBitmap[integrator][transceiverInfo.chainId].bitmap & uint128(1 << transceiverInfo.index) + // ) != 0; + // bool transceiverEnabled = transceiverInfo.enabled; + + // bool transceiverInEnabledTransceivers = false; + + // for (uint256 i = 0; i < _numTransceivers[integrator].enabled; i++) { + // if (_enabledTransceivers[integrator][i] == transceiver) { + // transceiverInEnabledTransceivers = true; + // break; + // } + // } + + // // invariant: transceiverInfos[integrator][transceiver].enabled + // // <=> enabledTransceiverBitmap & (1 << transceiverInfos[integrator][transceiver].index) != 0 + // assert(transceiverInEnabledBitmap == transceiverEnabled); + + // // invariant: transceiverInfos[integrator][transceiver].enabled <=> transceiver in _enabledTransceivers + // assert(transceiverInEnabledTransceivers == transceiverEnabled); + + // assert(transceiverInfo.index < _numTransceivers[integrator].registered); + } +} diff --git a/evm/src/interfaces/IRouter.sol b/evm/src/interfaces/IRouter.sol index 76759a77..567db858 100644 --- a/evm/src/interfaces/IRouter.sol +++ b/evm/src/interfaces/IRouter.sol @@ -9,12 +9,35 @@ interface IRouter is IMessageSequence { /// @param recipientChain The Wormhole chain ID of the recipient. /// @param recipientAddress The universal address of the peer on the recipient chain. /// @param refundAddress The source chain refund address passed to the Transceiver. - /// @param message A message to be sent to the recipient chain. + /// @param payloadHash keccak256 of a message to be sent to the recipient chain. /// @return uint64 The sequence number of the message. function sendMessage( uint16 recipientChain, UniversalAddress recipientAddress, address refundAddress, - bytes memory message + bytes32 payloadHash ) external payable returns (uint64); + + // /// @dev Receive a message from another chain called by integrator. + // /// @param sourceChain The Wormhole chain ID of the recipient. + // /// @param senderAddress The universal address of the peer on the recipient chain. + // /// @param refundAddress The source chain refund address passed to the Transceiver. + // /// @param message A message to be sent to the recipient chain. + // /// @return uint128 The bitmap + function receiveMessage( + uint16 sourceChain, + UniversalAddress senderAddress, + address refundAddress, + bytes32 messageHash + ) external payable returns (uint128); + + /// @notice Called by a Transceiver contract to deliver a verified attestation. + function attestMessage( + uint16 sourceChain, // Wormhole Chain ID + UniversalAddress sourceAddress, // UniversalAddress of the message sender (integrator) + uint64 sequence, // Next sequence number for that integrator (consuming the sequence number) + uint16 destinationChainId, // Wormhole Chain ID + UniversalAddress destinationAddress, // UniversalAddress of the messsage recipient (integrator on destination chain) + bytes32 payloadHash // keccak256 of arbitrary payload from the integrator + ) external; } diff --git a/evm/src/interfaces/ITransceiver.sol b/evm/src/interfaces/ITransceiver.sol new file mode 100644 index 00000000..a605790a --- /dev/null +++ b/evm/src/interfaces/ITransceiver.sol @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.13; + +import "../libraries/UniversalAddress.sol"; + +interface ITransceiver { + /// @notice The caller is not the NttManager. + /// @dev Selector: 0xc5aa6153. + /// @param caller The address of the caller. + error CallerNotRouter(address caller); + + /// @notice Returns the string type of the transceiver. E.g. "wormhole", "axelar", etc. + function getTransceiverType() external view returns (string memory); + + /// @notice Fetch the delivery price for a given recipient chain transfer. + /// @param recipientChain The Wormhole chain ID of the target chain. + /// @return deliveryPrice The cost of delivering a message to the recipient chain, + /// in this chain's native token. + function quoteDeliveryPrice(uint16 recipientChain) external view returns (uint256); + + /// @dev Send a message to another chain. + /// @param recipientChain The Wormhole chain ID of the recipient. + /// @param messageHash The hash of the message to be sent to the recipient chain. + /// @param recipientAddress The Wormhole formatted address of the recipient chain. + /// @param refundAddress The address of the refund recipient + function sendMessage( + uint16 recipientChain, + bytes32 messageHash, + UniversalAddress recipientAddress, + bytes32 refundAddress + ) external payable; +} diff --git a/evm/test/Router.t.sol b/evm/test/Router.t.sol index f0d25b2a..020c1de2 100644 --- a/evm/test/Router.t.sol +++ b/evm/test/Router.t.sol @@ -4,34 +4,208 @@ pragma solidity ^0.8.13; import {Test, console} from "forge-std/Test.sol"; import "../src/libraries/UniversalAddress.sol"; import {Router} from "../src/Router.sol"; +import {TransceiverRegistry} from "../src/TransceiverRegistry.sol"; +import {ITransceiver} from "../src/interfaces/ITransceiver.sol"; + +contract RouterImpl is Router { +// function getDelegate(address delegator) public view returns (address) { +// return _getDelegateStorage()[delegator]; +// } + +// function registerDelegate(address delegate) public { +// _registerDelegate(delegate); +// } +} + +// This contract does send/receive operations +contract Integrator { + RouterImpl public router; + address myAdmin; + + constructor(address _router) { + router = RouterImpl(_router); + } + + function setMeAsAdmin(address admin) public { + myAdmin = admin; + } + + function registerWithRouter() public { + router.registerAdmin(myAdmin); + } + + function sendMessage( + uint16 recipientChain, + UniversalAddress recipientAddress, + address refundAddress, + bytes32 payloadHash + ) public payable returns (uint64) { + return router.sendMessage(recipientChain, recipientAddress, refundAddress, payloadHash); + } +} + +// This contract can only do transceiver operations +contract Admin { + address public integrator; + RouterImpl public router; + + constructor(address _integrator, address _router) { + integrator = _integrator; + router = RouterImpl(_router); + } + + function requestAdmin() public { + Integrator(integrator).setMeAsAdmin(address(this)); + } + + function setSendTransceiver(address transceiver, uint16 chain) public { + router.setSendTransceiver(integrator, transceiver, chain); + } + + function setRecvTransceiver(address transceiver, uint16 chain) public { + router.setRecvTransceiver(integrator, transceiver, chain); + } +} + +contract TransceiverImpl is ITransceiver { + function getTransceiverType() public pure override returns (string memory) { + return "test"; + } + + function quoteDeliveryPrice(uint16 /*recipientChain*/ ) public pure override returns (uint256) { + return 0; + } + + function sendMessage( + uint16 recipientChain, + bytes32 messageHash, + UniversalAddress recipientAddress, + bytes32 refundAddress + ) public payable override { + // Do nothing + } +} contract RouterTest is Test { - Router public router; + RouterImpl public router; + TransceiverImpl public transceiverImpl; address userA = address(0x123); address userB = address(0x456); address refundAddr = address(0x789); - bytes message = "hello, world"; + bytes32 messageHash = keccak256("hello, world"); function setUp() public { - router = new Router(); + router = new RouterImpl(); + transceiverImpl = new TransceiverImpl(); + } + + function test_setSendTransceiver() public { + Integrator integrator = new Integrator(address(router)); + Admin admin = new Admin(address(integrator), address(router)); + Admin imposter = new Admin(address(integrator), address(router)); + address transceiver1 = address(0x111); + uint16 chain = 2; + + admin.requestAdmin(); + integrator.registerWithRouter(); + admin.setSendTransceiver(transceiver1, chain); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyEnabled.selector, transceiver1)); + admin.setSendTransceiver(transceiver1, chain); + + vm.expectRevert(abi.encodeWithSelector(Router.CallerNotAdmin.selector)); + imposter.setSendTransceiver(transceiver1, chain); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyEnabled.selector, transceiver1)); + admin.setSendTransceiver(transceiver1, chain); + } + + function test_setRecvTransceiver() public { + Integrator integrator = new Integrator(address(router)); + Admin admin = new Admin(address(integrator), address(router)); + Admin imposter = new Admin(address(integrator), address(router)); + address transceiver1 = address(0x111); + uint16 chain = 2; + + admin.requestAdmin(); + integrator.registerWithRouter(); + admin.setRecvTransceiver(transceiver1, chain); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyEnabled.selector, transceiver1)); + admin.setRecvTransceiver(transceiver1, chain); + + vm.expectRevert(abi.encodeWithSelector(Router.CallerNotAdmin.selector)); + imposter.setRecvTransceiver(transceiver1, chain); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyEnabled.selector, transceiver1)); + admin.setRecvTransceiver(transceiver1, chain); } function test_sendMessageIncrementsSequence() public { - assertEq(router.nextMessageSequence(userA), 0); + Integrator integrator = new Integrator(address(router)); + Admin admin = new Admin(address(integrator), address(router)); + address transceiver1 = address(0x111); + uint16 chain = 2; + admin.requestAdmin(); + integrator.registerWithRouter(); + admin.setSendTransceiver(transceiver1, chain); + assertEq(router.nextMessageSequence(address(integrator)), 0); // Send inital message from userA, going from unset to 1 - vm.startPrank(userA); - router.sendMessage(1, UniversalAddressLibrary.fromAddress(userB), refundAddr, message); - assertEq(router.nextMessageSequence(userA), 1); - // Send additional message from userA, incrementing the existing sequence - router.sendMessage(1, UniversalAddressLibrary.fromAddress(userB), refundAddr, message); - assertEq(router.nextMessageSequence(userA), 2); + // vm.startPrank(userA); + // vm.expectRevert(abi.encodeWithSelector(Router.TransceiverNotEnabled.selector)); + // router.sendMessage(1, UniversalAddressLibrary.fromAddress(userB), refundAddr, messageHash); + // assertEq(router.nextMessageSequence(userA), 0); + // address me = address(this); + // transceiverRegistry.registerAdmin(me); + // // Send additional message from userA, incrementing the existing sequence + // router.sendMessage(1, UniversalAddressLibrary.fromAddress(userB), refundAddr, messageHash); + // assertEq(router.nextMessageSequence(userA), 2); } function testFuzz_sendMessage(address user) public { - uint64 beforeSequence = router.nextMessageSequence(user); + // uint16 chainId = 2; vm.startPrank(user); - router.sendMessage(1, UniversalAddressLibrary.fromAddress(user), refundAddr, message); - assertEq(router.nextMessageSequence(user), beforeSequence + 1); + // Register a transceiver + // Integrator integrator = new Integrator(address(router)); + // uint64 beforeSequence = router.nextMessageSequence(address(this)); + // address transceiver = address(transceiverImpl); + // integrator.setMeAsDelegate(); + // uint8 index = router.setSendTransceiver(transceiver, chainId); + // assert(index == 0); + // address[] memory enabledTransceivers = router.getSendTransceiversByChain(address(this), chainId); + // assert(enabledTransceivers.length == 1); + // router.sendMessage(chainId, UniversalAddressLibrary.fromAddress(user), refundAddr, messageHash); + // assertEq(router.nextMessageSequence(address(this)), beforeSequence + 1); + } + + function testFuzz_receiveMessage(address user) public { + // uint16 chainId = 2; + vm.startPrank(user); + // Integrator integrator = new Integrator(address(router)); + // address transceiver = address(0x111); + // integrator.setRecvTransceiver(transceiver, chainId); + // address[] memory enabledTransceivers = router.getRecvTransceiversByChain(address(integrator), chainId); + // assert(enabledTransceivers.length == 1); + // router.receiveMessage(1, UniversalAddressLibrary.fromAddress(user), refundAddr, messageHash); + } + + function testFuzz_attestMessage(address user) public { + // uint16 srcChain = 2; + // uint16 dstChain = 3; + // uint64 sequence = 1; + // bytes32 payloadHash = keccak256("hello, world"); + // sourceAddress = UniversalAddressLibrary.fromAddress(user); + // destinationAddress = UniversalAddressLibrary.fromAddress(user); + // vm.startPrank(user); + // Integrator integrator = new Integrator(address(router)); + // address transceiver = address(0x111); + // integrator.setRecvTransceiver(transceiver, dstChain); + // address[] memory enabledTransceivers = router.getRecvTransceiversByChain(address(integrator), dstChain); + // assert(enabledTransceivers.length == 1); + // router.attestMessage( + // srcChain, + // UniversalAddressLibrary.fromAddress(user), + // sequence, + // dstChain, + // UniversalAddressLibrary.fromAddress(user), + // payloadHash + // ); } } diff --git a/evm/test/TransceiverRegistry.t.sol b/evm/test/TransceiverRegistry.t.sol new file mode 100644 index 00000000..0a9ab406 --- /dev/null +++ b/evm/test/TransceiverRegistry.t.sol @@ -0,0 +1,315 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.13; + +import {Test, console} from "forge-std/Test.sol"; +import "../src/TransceiverRegistry.sol"; + +contract ConcreteTransceiverRegistry is TransceiverRegistry { + function rmvSendTransceiver(address integrator, address transceiver, uint16 chain) public { + _disableSendTransceiver(integrator, transceiver, chain); + } + + function rmvRecvTransceiver(address integrator, address transceiver, uint16 chain) public { + _disableRecvTransceiver(integrator, transceiver, chain); + } + + function setSendTransceiver(address integrator, address transceiver, uint16 chain) public returns (uint8 index) { + return _setSendTransceiver(integrator, transceiver, chain); + } + + function setRecvTransceiver(address integrator, address transceiver, uint16 chain) public returns (uint8 index) { + return _setRecvTransceiver(integrator, transceiver, chain); + } + + function disableSendTransceiver(address integrator, address transceiver, uint16 chain) public { + _disableSendTransceiver(integrator, transceiver, chain); + } + + function disableRecvTransceiver(address integrator, address transceiver, uint16 chain) public { + _disableRecvTransceiver(integrator, transceiver, chain); + } + + function getRegisteredTransceiversStorage(address integrator) public view returns (address[] memory $) { + return _getRegisteredTransceiversStorage()[integrator]; + } + + function getNumTransceiversStorage(address integrator) public view returns (_NumTransceivers memory $) { + return _getNumTransceiversStorage()[integrator]; + } + + function getEnabledSendTransceiversBitmapForChain(address integrator, uint16 chain) + public + view + returns (uint128 bitmap) + { + return _getEnabledSendTransceiversBitmapForChain(integrator, chain); + } + + function getEnabledRecvTransceiversBitmapForChain(address integrator, uint16 chain) + public + view + returns (uint128 bitmap) + { + return _getEnabledRecvTransceiversBitmapForChain(integrator, chain); + } + + function enableSendTransceiverForChain(address integrator, address transceiver, uint16 chainId) public { + _enableSendTransceiverForChain(integrator, transceiver, chainId); + } + + function enableRecvTransceiverForChain(address integrator, address transceiver, uint16 chainId) public { + _enableRecvTransceiverForChain(integrator, transceiver, chainId); + } + + function isSendTransceiverEnabledForChain(address integrator, address transceiver, uint16 chainId) + public + view + returns (bool) + { + return _isSendTransceiverEnabledForChain(integrator, transceiver, chainId); + } + + function isRecvTransceiverEnabledForChain(address integrator, address transceiver, uint16 chainId) + public + view + returns (bool) + { + return _isRecvTransceiverEnabledForChain(integrator, transceiver, chainId); + } + + function getMaxTransceivers() public pure returns (uint8) { + return MAX_TRANSCEIVERS; + } +} + +contract TransceiverRegistryTest is Test { + ConcreteTransceiverRegistry public transceiverRegistry; + address integrator1 = address(0x1); + address integrator2 = address(0x2); + address zeroTransceiver = address(0); + address sendTransceiver = address(0x123); + address recvTransceiver = address(0x234); + uint16 zeroChain = 0; + uint16 chain = 2; + uint16 wrongChain = 3; + + function setUp() public { + transceiverRegistry = new ConcreteTransceiverRegistry(); + } + + function test1() public view { + assertEq(transceiverRegistry.getTransceivers(integrator1).length, 0); + assertEq(transceiverRegistry.getTransceivers(integrator2).length, 0); + } + + function test2() public { + address me = address(this); + // Send side + assertEq(transceiverRegistry.getTransceivers(me).length, 0); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); + transceiverRegistry.setSendTransceiver(me, sendTransceiver, zeroChain); + transceiverRegistry.setSendTransceiver(me, sendTransceiver, chain); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); + transceiverRegistry.disableSendTransceiver(me, sendTransceiver, zeroChain); + + // Recv side + // Transceiver was registered on the send side + assertEq(transceiverRegistry.getTransceivers(me).length, 1); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); + transceiverRegistry.setRecvTransceiver(me, sendTransceiver, zeroChain); + transceiverRegistry.setRecvTransceiver(me, recvTransceiver, chain); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); + transceiverRegistry.disableRecvTransceiver(me, recvTransceiver, zeroChain); + } + + function test3() public { + address me = address(this); + // Send side + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.NonRegisteredTransceiver.selector, sendTransceiver)); + transceiverRegistry.rmvSendTransceiver(me, sendTransceiver, chain); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector)); + transceiverRegistry.setSendTransceiver(me, zeroTransceiver, chain); + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, 0); + assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, 0); + transceiverRegistry.setSendTransceiver(me, sendTransceiver, chain); + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, 1); + assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, 1); + // assertEq(transceiverRegistry.getSendTransceiverInfos(integrator1).length, 1); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); + transceiverRegistry.disableSendTransceiver(me, sendTransceiver, zeroChain); + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, 1); + assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, 1); + transceiverRegistry.disableSendTransceiver(me, sendTransceiver, chain); + // disabled, but stays registered + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, 1); + assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, 1); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.DisabledTransceiver.selector, sendTransceiver)); + transceiverRegistry.disableSendTransceiver(me, sendTransceiver, chain); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector)); + transceiverRegistry.disableSendTransceiver(me, zeroTransceiver, chain); + // assertEq(transceiverRegistry.getSendTransceiverInfos(integrator1).length, 0); + + // Recv side + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.NonRegisteredTransceiver.selector, recvTransceiver)); + transceiverRegistry.rmvRecvTransceiver(me, recvTransceiver, chain); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector)); + transceiverRegistry.setRecvTransceiver(me, zeroTransceiver, chain); + // Carry over from send side test + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, 1); + assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, 1); + transceiverRegistry.setRecvTransceiver(me, recvTransceiver, chain); + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, 2); + assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, 2); + // assertEq(transceiverRegistry.getRecvTransceiverInfos(integrator1).length, 1); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); + transceiverRegistry.disableRecvTransceiver(me, recvTransceiver, zeroChain); + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, 2); + assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, 2); + transceiverRegistry.disableRecvTransceiver(me, recvTransceiver, chain); + // disabled, but stays registered + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, 2); + assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, 2); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.DisabledTransceiver.selector, recvTransceiver)); + transceiverRegistry.disableRecvTransceiver(me, recvTransceiver, chain); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector)); + transceiverRegistry.disableRecvTransceiver(me, zeroTransceiver, chain); + // assertEq(transceiverRegistry.getRecvTransceiverInfos(integrator1).length, 0); + } + + function test4() public { + // Send side + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(integrator1).length, 0); + assertEq(transceiverRegistry.getEnabledSendTransceiversBitmapForChain(integrator1, chain), 0); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); + assertEq(transceiverRegistry.getEnabledSendTransceiversBitmapForChain(integrator1, zeroChain), 0); + + // Recv side + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(integrator1).length, 0); + assertEq(transceiverRegistry.getEnabledRecvTransceiversBitmapForChain(integrator1, chain), 0); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); + assertEq(transceiverRegistry.getEnabledRecvTransceiversBitmapForChain(integrator1, zeroChain), 0); + } + + // This is a redudant test, as the previous tests already cover this + function test5() public view { + // Send side + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(integrator1).length, 0); + + // Recv side + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(integrator1).length, 0); + } + + // This is a redudant test, as the previous tests already cover this + function test6() public view { + // Send side + TransceiverRegistry._NumTransceivers memory numSendTransceivers = + transceiverRegistry.getNumTransceiversStorage(integrator1); + assertEq(numSendTransceivers.registered, 0); + + // Recv side + TransceiverRegistry._NumTransceivers memory numRecvTransceivers = + transceiverRegistry.getNumTransceiversStorage(integrator1); + assertEq(numRecvTransceivers.registered, 0); + } + + function test7() public { + address me = address(this); + // Send side + address sTransceiver = address(0x456); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.NonRegisteredTransceiver.selector, sTransceiver)); + transceiverRegistry.enableSendTransceiverForChain(me, sTransceiver, chain); + + // Recv side + address rTransceiver = address(0x567); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.NonRegisteredTransceiver.selector, rTransceiver)); + transceiverRegistry.enableRecvTransceiverForChain(me, rTransceiver, chain); + } + + function test8() public { + uint16 chainId = 3; + address me = address(this); + + // Send side + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector)); + transceiverRegistry.enableSendTransceiverForChain(me, zeroTransceiver, chainId); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector)); + transceiverRegistry.isSendTransceiverEnabledForChain(me, zeroTransceiver, chainId); + + // Recv side + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector)); + transceiverRegistry.enableRecvTransceiverForChain(me, zeroTransceiver, chainId); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector)); + transceiverRegistry.isRecvTransceiverEnabledForChain(me, zeroTransceiver, chainId); + } + + function test9() public { + uint16 chainId = 4; + address me = address(this); + + // Send side + address sTransceiver = address(0x345); + assertEq(transceiverRegistry.isSendTransceiverEnabledForChain(me, sTransceiver, chainId), false); + transceiverRegistry.setSendTransceiver(me, sTransceiver, chain); + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, 1); + assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, 1); + transceiverRegistry.enableSendTransceiverForChain(me, sTransceiver, chainId); + bool enabled = transceiverRegistry.isSendTransceiverEnabledForChain(me, sTransceiver, chainId); + assertEq(enabled, true); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyEnabled.selector, sTransceiver)); + transceiverRegistry.setSendTransceiver(me, sTransceiver, chain); + + // Recv side + address rTransceiver = address(0x453); + assertEq(transceiverRegistry.isRecvTransceiverEnabledForChain(me, rTransceiver, chainId), false); + transceiverRegistry.setRecvTransceiver(me, rTransceiver, chain); + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, 2); + assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, 2); + transceiverRegistry.enableRecvTransceiverForChain(me, rTransceiver, chainId); + enabled = transceiverRegistry.isRecvTransceiverEnabledForChain(me, rTransceiver, chainId); + assertEq(enabled, true); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyEnabled.selector, rTransceiver)); + transceiverRegistry.setRecvTransceiver(me, rTransceiver, chain); + } + + function test10() public { + address me = address(this); + uint8 maxTransceivers = transceiverRegistry.getMaxTransceivers(); + + // Send side + for (uint8 i = 0; i < maxTransceivers; i++) { + transceiverRegistry.setSendTransceiver(me, address(uint160(i + 1)), chain); + } + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, maxTransceivers); + assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, maxTransceivers); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TooManyTransceivers.selector)); + transceiverRegistry.setSendTransceiver(me, address(0x111), chain); + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, maxTransceivers); + assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, maxTransceivers); + transceiverRegistry.disableSendTransceiver(me, address(0x1), chain); + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, maxTransceivers); + assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, maxTransceivers); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TooManyTransceivers.selector)); + transceiverRegistry.setSendTransceiver(me, address(0x111), chain); + } + + function test11() public { + address me = address(this); + uint8 maxTransceivers = transceiverRegistry.getMaxTransceivers(); + + // Recv side + for (uint8 i = 0; i < maxTransceivers; i++) { + transceiverRegistry.setRecvTransceiver(me, address(uint160(i + 1)), chain); + } + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, maxTransceivers); + assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, maxTransceivers); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TooManyTransceivers.selector)); + transceiverRegistry.setRecvTransceiver(me, address(0x111), chain); + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, maxTransceivers); + assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, maxTransceivers); + transceiverRegistry.disableRecvTransceiver(me, address(0x1), chain); + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, maxTransceivers); + assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, maxTransceivers); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TooManyTransceivers.selector)); + transceiverRegistry.setRecvTransceiver(me, address(0x111), chain); + } +} From 94470b5adf9ef4f944ac60d60680d8e3356521e7 Mon Sep 17 00:00:00 2001 From: Bruce Riley Date: Fri, 11 Oct 2024 15:10:47 -0500 Subject: [PATCH 2/5] evm: add router --- evm/src/Router.sol | 606 +++++++++++++++----- evm/src/TransceiverRegistry.sol | 657 ++++++++-------------- evm/src/interfaces/IRouter.sol | 43 -- evm/src/interfaces/IRouterAdmin.sol | 58 ++ evm/src/interfaces/IRouterIntegrator.sol | 74 +++ evm/src/interfaces/IRouterTransceiver.sol | 22 + evm/src/interfaces/ITransceiver.sol | 24 +- evm/test/Router.t.sol | 650 ++++++++++++++++----- evm/test/TransceiverRegistry.t.sol | 275 +++++---- 9 files changed, 1564 insertions(+), 845 deletions(-) delete mode 100644 evm/src/interfaces/IRouter.sol create mode 100644 evm/src/interfaces/IRouterAdmin.sol create mode 100644 evm/src/interfaces/IRouterIntegrator.sol create mode 100644 evm/src/interfaces/IRouterTransceiver.sol diff --git a/evm/src/Router.sol b/evm/src/Router.sol index 1d64cb90..ec10291a 100644 --- a/evm/src/Router.sol +++ b/evm/src/Router.sol @@ -1,53 +1,165 @@ // SPDX-License-Identifier: Apache-2.0 pragma solidity ^0.8.13; -import "./interfaces/IRouter.sol"; +import "./interfaces/IRouterAdmin.sol"; +import "./interfaces/IRouterIntegrator.sol"; +import "./interfaces/IRouterTransceiver.sol"; import "./MessageSequence.sol"; import "./TransceiverRegistry.sol"; import "./interfaces/ITransceiver.sol"; -contract Router is IRouter, MessageSequence, TransceiverRegistry { +contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageSequence, TransceiverRegistry { string public constant ROUTER_VERSION = "0.0.1"; + struct IntegratorConfig { + bool isInitialized; + address admin; + address transfer; + } + + // =============== Immutables ============================================================ + + /// @dev Wormhole chain ID that the Router is deployed on. + /// This chain ID is formatted Wormhole Chain IDs -- https://docs.wormhole.com/wormhole/reference/constants + uint16 public immutable ourChainId; + + // =============== Setup ================================================================= + + constructor(uint16 _ourChainId) { + ourChainId = _ourChainId; + } + // =============== Events ================================================================ - /// @notice Emitted when an integrator registers a delegate. - /// @param integrator The address of the integrator. - /// @param delegate The address of the delegate. - event DelegateRegistered(address integrator, address delegate); + /// @notice Emitted when an integrator registers with the router. + /// @dev Topic0 + /// 0x582a4322c684b4cdebf273e2be5090d5f21476be5566a98d6a224a450447c5b4 + /// @param integrator The address of the integrator contract. + /// @param admin The address of the admin contract. + event IntegratorRegistered(address integrator, address admin); - /// @notice Emitted when a message has been attested to. - /// @param integrator The address of the integrator. - /// @param transceiver The address of the transceiver. - /// @param digest The digest of the message. - event MessageAttestedTo(address integrator, address transceiver, bytes32 digest); + /// @notice Emitted when the admin is changed for an integrator. + /// @dev Topic0 + /// 0x9f6130d220a6021d90d78c7ed17b7cfb79f530974405b174fef75f671205513c + /// @param integrator The address of the integrator contract. + /// @param oldAdmin The address of the old admin contract. + /// @param newAdmin The address of the new admin contract. + event AdminUpdated(address integrator, address oldAdmin, address newAdmin); + + /// @notice Emitted when an admin change request is received for an integrator. + /// @dev Topic0 + /// 0xcdeb0d05a920666dfd2822eb51628fff963ba0b1672f984a8b60017ed83939e4 + /// @param integrator The address of the integrator contract. + /// @param oldAdmin The address of the old admin contract. + /// @param newAdmin The address of the new admin contract. + event AdminUpdateRequested(address integrator, address oldAdmin, address newAdmin); /// @notice Emitted when a message has been sent. + /// @param messageHash The keccak256 of the message. It is, also, indexed. /// @param sender The address of the sender. /// @param recipient The address of the recipient. /// @param recipientChain The chainId of the recipient. + /// @param sequence The sequence of the message. /// @param digest The digest of the message. - event MessageSent(address sender, address recipient, uint16 recipientChain, bytes32 digest); + /// @dev Topic0 0x1c170583317700fb71bc583fa6fdd8ff893f6c3a15a79104f1681d6d9eb708ee + event MessageSent( + bytes32 indexed messageHash, + UniversalAddress sender, + UniversalAddress recipient, + uint16 recipientChain, + uint64 sequence, + bytes32 digest + ); + + /// @notice Emitted when a message has been attested to. + /// @param messageHash The keccak256 of the message. It is, also, indexed. + /// @param srcChain The Wormhole chain ID of the sender. + /// @param srcAddr The universal address of the peer on the sending chain. + /// @param sequence The sequence number of the message (per integrator). + /// @param dstChain The Wormhole chain ID of the destination. + /// @param dstAddr The destination address of the message. + /// @param payloadHash The keccak256 of payload from the integrator. + /// @param attestedBitmap Bitmap of transceivers that have attested the message. + /// @param attestingTransceiver The address of the transceiver that attested the message. + /// @dev Topic0 0xb2328f51e669b73cf1831e232716eec9959360a52818a63bb1d82d900de667d8 + event MessageAttestedTo( + bytes32 indexed messageHash, + uint16 srcChain, + UniversalAddress srcAddr, + uint64 sequence, + uint16 dstChain, + UniversalAddress dstAddr, + bytes32 payloadHash, + uint128 attestedBitmap, + UniversalAddress attestingTransceiver + ); + + /// @notice Emitted when a message has been received. + /// @param messageHash The keccak256 of the message. It is, also, indexed. + /// @param srcChain The Wormhole chain ID of the sender. + /// @param srcAddr The universal address of the peer on the sending chain. + /// @param sequence The sequence number of the message (per integrator). + /// @param dstChain The Wormhole chain ID of the destination. + /// @param dstAddr The destination address of the message. + /// @param payloadHash The keccak256 of payload from the integrator. + /// @param enabledBitmap Bitmap of transceivers enabled for the source chain. + /// @param attestedBitmap Bitmap of transceivers that have attested the message. + /// @dev Topic0 0xae4f20b00e13c9f1eec6c3c72ba3146c9538ca60f28c3eb57538b14965905e7d + event MessageReceived( + bytes32 indexed messageHash, + uint16 srcChain, + UniversalAddress srcAddr, + uint64 sequence, + uint16 dstChain, + UniversalAddress dstAddr, + bytes32 payloadHash, + uint128 enabledBitmap, + uint128 attestedBitmap + ); // =============== Errors ================================================================ - /// @notice Error when the transceiver is disabled. + /// @notice Error when the destination chain ID doesn't match this chain. + /// @dev Selector: 0xb86ac1ef. + error InvalidDestinationChain(); + + /// @notice Error when the transceiver is being used as if it is enabled but it is disabled. + /// @dev Selector: 0x424afc23. error TransceiverNotEnabled(); /// @notice Error when the admin is the zero address. + /// @dev Selector: 0x554ff5d7. error InvalidAdminZeroAddress(); + /// @notice Error when there was an attempt to change the admin while a transfer is in progress. + /// @dev Selector: 0xc78a581c. + error AdminTransferInProgress(); + + /// @notice Error when the integrator tries to re-register. + /// @dev Selector: 0x626bb491. + error IntegratorAlreadyRegistered(); + /// @notice Error when the integrator did not register an admin. + /// @dev Selector: 0x815255c7. error IntegratorNotRegistered(); /// @notice Error when the caller is not the registered admin. - error CallerNotAdmin(); + /// @dev Selector: 0xc183bcef. + error CallerNotAuthorized(); + + /// @notice Error when message attestation not found in store. + /// @dev Selector: 0x1547aa01. + error UnknownMessageAttestation(); + + /// @notice Error when message is already marked as executed. + /// @dev Selector: 0x0dc10197. + error AlreadyExecuted(); // =============== Storage =============================================================== /// @dev Holds the integrator address to IntegratorConfig mapping. /// mapping(address => IntegratorConfig) - bytes32 private constant INTEGRATOR_CONFIGS_SLOT = bytes32(uint256(keccak256("registry.integratorConfigs")) - 1); + bytes32 private constant INTEGRATOR_CONFIGS_SLOT = bytes32(uint256(keccak256("router.integratorConfigs")) - 1); /// @dev Integrator address => IntegratorConfig mapping. function _getIntegratorConfigsStorage() internal pure returns (mapping(address => IntegratorConfig) storage $) { @@ -57,153 +169,393 @@ contract Router is IRouter, MessageSequence, TransceiverRegistry { } } + struct AttestationInfo { + bool executed; // replay protection + uint128 attestedTransceivers; // bitmap corresponding to perIntegratorTransceivers + } + + /// @dev Holds the integrator address to message digest to attestation info mapping. + /// mapping(address => IntegratorConfig) + bytes32 private constant ATTESTATION_INFO_SLOT = bytes32(uint256(keccak256("router.attestationInfo")) - 1); + + /// @dev Integrator address => message digest -> attestation info mapping. + function _getAttestationInfoStorage() + internal + pure + returns (mapping(address => mapping(bytes32 => AttestationInfo)) storage $) + { + uint256 slot = uint256(ATTESTATION_INFO_SLOT); + assembly ("memory-safe") { + $.slot := slot + } + } + + // =============== Getters =============================================================== + + /// @notice Returns the admin for a given integrator. + /// @param integrator The address of the integrator contract. + /// @return address The address of the administrator contract + function getAdmin(address integrator) public view returns (address) { + mapping(address => IntegratorConfig) storage integratorConfigs = _getIntegratorConfigsStorage(); + return integratorConfigs[integrator].admin; + } + // =============== External ============================================================== - /// @notice This is the first thing an integrator should do to register the admin address. - /// The admin address is used to manage the transceivers. - /// @dev The msg.sender needs to be the integrator contract. - /// @param admin The address of the admin. Pass in msg.sender, if you want the integrator to be the admin. - function registerAdmin(address admin) external { + // =============== Admin functions ======================================================= + + /// @inheritdoc IRouterIntegrator + function register(address initialAdmin) external { + if (initialAdmin == address(0)) { + revert InvalidAdminZeroAddress(); + } + + address integrator = msg.sender; + // Get the storage for this integrator contract mapping(address => IntegratorConfig) storage integratorConfigs = _getIntegratorConfigsStorage(); - // Do some checks? Should address(0) mean something special? - if (admin == address(0)) { + + if (integratorConfigs[integrator].isInitialized) { + revert IntegratorAlreadyRegistered(); + } + + // Update the storage. + integratorConfigs[integrator] = + IntegratorConfig({isInitialized: true, admin: initialAdmin, transfer: address(0)}); + emit IntegratorRegistered(integrator, initialAdmin); + } + + /// @inheritdoc IRouterAdmin + function updateAdmin(address integrator, address newAdmin) external onlyAdmin(integrator) { + if (newAdmin == address(0)) { + // Use discardAdmin() instead. revert InvalidAdminZeroAddress(); } + // Get the storage for this integrator contract + mapping(address => IntegratorConfig) storage integratorConfigs = _getIntegratorConfigsStorage(); + + if (integratorConfigs[integrator].transfer != address(0)) { + revert AdminTransferInProgress(); + } + // Update the storage. - integratorConfigs[msg.sender] = IntegratorConfig({isInitialized: true, admin: admin}); + integratorConfigs[integrator].admin = newAdmin; + emit AdminUpdated(integrator, msg.sender, newAdmin); } - /// @notice The admin contract calls this function. - /// @param integrator The address of the integrator contract. - /// @param transceiver The address of the Transceiver contract. - /// @param chain The chain ID of the Transceiver contract. - function setSendTransceiver(address integrator, address transceiver, uint16 chain) external { - _checkIntegratorAdmin(integrator, msg.sender); - _setSendTransceiver(integrator, transceiver, chain); + /// @inheritdoc IRouterAdmin + function transferAdmin(address integrator, address newAdmin) external onlyAdmin(integrator) { + if (newAdmin == address(0)) { + // Use discardAdmin() instead. + revert InvalidAdminZeroAddress(); + } + // Get the storage for this integrator contract + mapping(address => IntegratorConfig) storage integratorConfigs = _getIntegratorConfigsStorage(); + + if (integratorConfigs[integrator].transfer != address(0)) { + revert AdminTransferInProgress(); + } + + // Update the storage with this request. + integratorConfigs[integrator].transfer = newAdmin; + emit AdminUpdateRequested(integrator, msg.sender, newAdmin); } - /// @notice The admin contract uses this function. - /// @param integrator The address of the integrator contract. - /// @param transceiver The address of the Transceiver contract. - /// @param chain The chain ID of the Transceiver contract. - function setRecvTransceiver(address integrator, address transceiver, uint16 chain) external { - _checkIntegratorAdmin(integrator, msg.sender); - _setRecvTransceiver(integrator, transceiver, chain); + /// @inheritdoc IRouterAdmin + function claimAdmin(address integrator) external { + // Get the storage for this integrator contract + mapping(address => IntegratorConfig) storage integratorConfigs = _getIntegratorConfigsStorage(); + + address oldAdmin = integratorConfigs[integrator].admin; + address newAdmin = integratorConfigs[integrator].transfer; + if (msg.sender == oldAdmin) { + // This is the cancel case. + integratorConfigs[integrator].transfer = address(0); + } else if (msg.sender == newAdmin) { + // Update the storage with this request. + integratorConfigs[integrator].admin = newAdmin; + integratorConfigs[integrator].transfer = address(0); + } else { + revert CallerNotAuthorized(); + } + emit AdminUpdated(integrator, oldAdmin, newAdmin); } - /// @inheritdoc IRouter - function sendMessage( - uint16 recipientChain, - UniversalAddress recipientAddress, - address refundAddress, - bytes32 payloadHash - ) external payable returns (uint64) { - return _sendMessage(recipientChain, recipientAddress, refundAddress, payloadHash, msg.sender); + /// @inheritdoc IRouterAdmin + function discardAdmin(address integrator) external onlyAdmin(integrator) { + // Get the storage for this integrator contract + mapping(address => IntegratorConfig) storage integratorConfigs = _getIntegratorConfigsStorage(); + + if (integratorConfigs[integrator].transfer != address(0)) { + revert AdminTransferInProgress(); + } + + // Update the storage. + integratorConfigs[integrator].admin = address(0); + emit AdminUpdated(integrator, msg.sender, address(0)); } - /// @dev Receive a message from another chain called by integrator. - /// @param sourceChain The Wormhole chain ID of the recipient. - /// @param senderAddress The universal address of the peer on the recipient chain. - /// @param refundAddress The source chain refund address passed to the Transceiver. - /// @param messageHash The hash of the message. - /// @return uint128 The bitmap - function receiveMessage( - uint16 sourceChain, - UniversalAddress senderAddress, - address refundAddress, - bytes32 messageHash - ) external payable returns (uint128) { - // Find the transceiver for the source chain. - // address transceiver = this.getRecvTransceiverByChain(msg.sender, sourceChain); - // Receive the message. + // =============== Transceiver functions ======================================================= + + /// @inheritdoc IRouterAdmin + function addTransceiver(address integrator, uint16 chainId, address transceiver) + external + onlyAdmin(integrator) + returns (uint8 index) + { + // Call the TransceiverRegistry version. + return _addTransceiver(integrator, chainId, transceiver); + } + + /// @inheritdoc IRouterAdmin + function enableSendTransceiver(address integrator, uint16 chain, address transceiver) + external + onlyAdmin(integrator) + { + // Call the TransceiverRegistry version. + _enableSendTransceiver(integrator, chain, transceiver); + } + + /// @inheritdoc IRouterAdmin + function enableRecvTransceiver(address integrator, uint16 chain, address transceiver) + external + onlyAdmin(integrator) + { + // Call the TransceiverRegistry version. + _enableRecvTransceiver(integrator, chain, transceiver); + } + + /// @inheritdoc IRouterAdmin + function disableSendTransceiver(address integrator, uint16 chain, address transceiver) + external + onlyAdmin(integrator) + { + // Call the TransceiverRegistry version. + _disableSendTransceiver(integrator, chain, transceiver); + } + + /// @inheritdoc IRouterAdmin + function disableRecvTransceiver(address integrator, uint16 chain, address transceiver) + external + onlyAdmin(integrator) + { + // Call the TransceiverRegistry version. + _disableRecvTransceiver(integrator, chain, transceiver); + } + + // =============== Message functions ======================================================= + + /// @inheritdoc IRouterIntegrator + function sendMessage(uint16 dstChain, UniversalAddress dstAddr, address refundAddress, bytes32 payloadHash) + external + payable + onlyIntegrator + returns (uint64 sequence) + { + // get the enabled send transceivers for [msg.sender][dstChain] + address[] memory sendTransceivers = getSendTransceiversByChain(msg.sender, dstChain); + uint256 len = sendTransceivers.length; + if (len == 0) { + revert TransceiverNotEnabled(); + } + UniversalAddress sender = UniversalAddressLibrary.fromAddress(msg.sender); + // get the next sequence number for msg.sender + sequence = _useMessageSequence(msg.sender); + UniversalAddress refundUA = UniversalAddressLibrary.fromAddress(refundAddress); + for (uint256 i = 0; i < len;) { + // quote the delivery price + uint256 deliveryPrice = ITransceiver(sendTransceivers[i]).quoteDeliveryPrice(dstChain); + // call sendMessage + ITransceiver(sendTransceivers[i]).sendMessage{value: deliveryPrice}( + sender, dstChain, dstAddr, sequence, payloadHash, UniversalAddressLibrary.toBytes32(refundUA) + ); + unchecked { + ++i; + } + } + + emit MessageSent( + _computeMessageDigest(ourChainId, sender, sequence, dstChain, dstAddr, payloadHash), + sender, + dstAddr, + dstChain, + sequence, + payloadHash + ); } - /// @inheritdoc IRouter + /// @inheritdoc IRouterTransceiver function attestMessage( - uint16 sourceChain, // Wormhole Chain ID - UniversalAddress sourceAddress, // UniversalAddress of the message sender (integrator) - uint64 sequence, // Next sequence number for that integrator (consuming the sequence number) - uint16 destinationChainId, // Wormhole Chain ID - UniversalAddress destinationAddress, // UniversalAddress of the messsage recipient (integrator on destination chain) - bytes32 payloadHash // keccak256 of arbitrary payload from the integrator + uint16 srcChain, + UniversalAddress srcAddr, + uint64 sequence, + uint16 dstChain, + UniversalAddress dstAddr, + bytes32 payloadHash ) external { - _attestMessage(sourceChain, sourceAddress, sequence, destinationChainId, destinationAddress, payloadHash); + // This is called by the transceiver so we don't check onlyIntegrator. + address integrator = dstAddr.toAddress(); + + // sanity check that destinationChain is this chain + if (dstChain != ourChainId) { + revert InvalidDestinationChain(); + } + + TransceiverInfo storage tsInfo = _getTransceiverInfosStorage()[integrator][msg.sender]; + if (!tsInfo.registered) { + revert TransceiverNotEnabled(); + } + + // Make sure it's enabled on the receive. + if (!_isRecvTransceiverEnabledForChain(integrator, srcChain, msg.sender)) { + revert TransceiverNotEnabled(); + } + + // compute the message digest + bytes32 messageDigest = _computeMessageDigest(srcChain, srcAddr, sequence, dstChain, dstAddr, payloadHash); + + AttestationInfo storage attestationInfo = _getAttestationInfoStorage()[integrator][messageDigest]; + + // It's okay to mark it as attested if it has already been executed. + + // set the bit in perIntegratorAttestations[dstAddr][digest] corresponding to msg.sender + attestationInfo.attestedTransceivers |= uint128(1 << tsInfo.index); + emit MessageAttestedTo( + _computeMessageDigest(srcChain, srcAddr, sequence, dstChain, dstAddr, payloadHash), + srcChain, + srcAddr, + sequence, + dstChain, + dstAddr, + payloadHash, + attestationInfo.attestedTransceivers, + UniversalAddressLibrary.fromAddress(msg.sender) + ); + } + + /// @inheritdoc IRouterIntegrator + function recvMessage( + uint16 srcChain, + UniversalAddress srcAddr, + uint64 sequence, + uint16 dstChain, + UniversalAddress dstAddr, + bytes32 payloadHash + ) external payable onlyIntegrator returns (uint128 enabledBitmap, uint128 attestedBitmap) { + // sanity check that dstChain is this chain + if (dstChain != ourChainId) { + revert InvalidDestinationChain(); + } + + // sanity check that msg.sender is integrator: The caller has onlyIntegrator. + + enabledBitmap = _getEnabledRecvTransceiversBitmapForChain(msg.sender, srcChain); + if (enabledBitmap == 0) { + revert TransceiverNotEnabled(); + } + + // compute the message digest + bytes32 messageDigest = _computeMessageDigest(srcChain, srcAddr, sequence, dstChain, dstAddr, payloadHash); + + AttestationInfo storage attestationInfo = _getAttestationInfoStorage()[msg.sender][messageDigest]; + + // revert if not in perIntegratorAttestations map + if ((attestationInfo.attestedTransceivers == 0) && (!attestationInfo.executed)) { + revert UnknownMessageAttestation(); + } + + // revert if already executed + if (attestationInfo.executed) { + revert AlreadyExecuted(); + } + + // set the executed flag in perIntegratorAttestations[dstAddr][digest] + attestationInfo.executed = true; + attestedBitmap = attestationInfo.attestedTransceivers; + + emit MessageReceived( + messageDigest, srcChain, srcAddr, sequence, dstChain, dstAddr, payloadHash, enabledBitmap, attestedBitmap + ); + } + + /// @inheritdoc IRouterIntegrator + function getMessageStatus( + uint16 srcChain, + UniversalAddress srcAddr, + uint64 sequence, + uint16 dstChain, + UniversalAddress dstAddr, + bytes32 payloadHash + ) external view onlyIntegrator returns (uint128 enabledBitmap, uint128 attestedBitmap, bool executed) { + // sanity check that dstChain is this chain + if (dstChain != ourChainId) { + revert InvalidDestinationChain(); + } + enabledBitmap = _getEnabledRecvTransceiversBitmapForChain(msg.sender, srcChain); + // compute the message digest + bytes32 messageDigest = _computeMessageDigest(srcChain, srcAddr, sequence, dstChain, dstAddr, payloadHash); + + AttestationInfo storage attestationInfo = _getAttestationInfoStorage()[msg.sender][messageDigest]; + + attestedBitmap = attestationInfo.attestedTransceivers; + + executed = attestationInfo.executed; + } + + /// @inheritdoc IRouterIntegrator + function execMessage( + uint16 srcChain, + UniversalAddress srcAddr, + uint64 sequence, + uint16 dstChain, + UniversalAddress dstAddr, + bytes32 payloadHash + ) external onlyIntegrator { + if (dstChain != ourChainId) { + revert InvalidDestinationChain(); + } + // compute the message digest + bytes32 messageDigest = _computeMessageDigest(srcChain, srcAddr, sequence, dstChain, dstAddr, payloadHash); + + AttestationInfo storage attestationInfo = _getAttestationInfoStorage()[msg.sender][messageDigest]; + + bool executed = attestationInfo.executed; + if (executed) { + revert AlreadyExecuted(); + } + attestationInfo.executed = true; } // =============== Internal ============================================================== - /// @notice This function checks that the integrator is registered. - /// @dev This function will revert under the following conditions: - /// - The integrator is not registered - /// @param integrator The integrator address - function _checkIntegrator(address integrator) internal view { - IntegratorConfig storage config = _getIntegratorConfigsStorage()[integrator]; + modifier onlyIntegrator() { + IntegratorConfig storage config = _getIntegratorConfigsStorage()[msg.sender]; if (!config.isInitialized) { revert IntegratorNotRegistered(); } + _; } - /// @notice This function checks that the integrator is registered and the admin is valid. - /// @dev This function will revert under the following conditions: - /// - The integrator is not registered - /// - The admin is not configured for this integrator - /// @param integrator The integrator address - /// @param admin The admin address for this integrator - function _checkIntegratorAdmin(address integrator, address admin) internal view { + modifier onlyAdmin(address integrator) { IntegratorConfig storage config = _getIntegratorConfigsStorage()[integrator]; if (!config.isInitialized) { revert IntegratorNotRegistered(); } - if (config.admin != admin) { - revert CallerNotAdmin(); + if (config.admin != msg.sender) { + revert CallerNotAuthorized(); } + _; } - function _sendMessage( - uint16 chainId, - UniversalAddress recipientAddress, - address refundAddress, - bytes32 messageHash, - address // sender - ) internal returns (uint64 sequence) { - _checkIntegrator(msg.sender); - // get the next sequence number for msg.sender - sequence = _useMessageSequence(msg.sender); - // get the enabled send transceivers for [msg.sender][recipientChain] - address[] memory sendTransceivers = this.getSendTransceiversByChain(msg.sender, chainId); - if (sendTransceivers.length == 0) { - revert TransceiverNotEnabled(); - } - for (uint256 i = 0; i < sendTransceivers.length; i++) { - // quote the delivery price - uint256 deliveryPrice = ITransceiver(sendTransceivers[i]).quoteDeliveryPrice(chainId); - // call sendMessage - ITransceiver(sendTransceivers[i]).sendMessage{value: deliveryPrice}( - chainId, messageHash, recipientAddress, UniversalAddressLibrary.fromAddress(refundAddress).toBytes32() - ); - } - // for each enabled transceiver - // quote the delivery price - // see https://github.com/wormhole-foundation/example-native-token-transfers/blob/68a7ca4132c74e838ac23e54752e8c0bc02bb4a2/evm/src/NttManager/ManagerBase.sol#L113 - // call sendMessage - } - - function _attestMessage( - uint16 sourceChain, // Wormhole Chain ID - UniversalAddress sourceAddress, // UniversalAddress of the message sender (integrator) - uint64 sequence, // Next sequence number for that integrator (consuming the sequence number) - uint16 destinationChainId, // Wormhole Chain ID - UniversalAddress destinationAddress, // UniversalAddress of the messsage recipient (integrator on destination chain) - bytes32 payloadHash // keccak256 of arbitrary payload from the integrator - ) internal { - _checkIntegrator(msg.sender); - // sanity check that destinationChainId is this chain - // get enabled recv transceivers for [destinationAddress][sourceChain] - // address transceiver = this.getRecvTransceiverByChain(sourceChain); - // check that msg.sender is one of those transceivers - // compute the message digest - // set the bit in perIntegratorAttestations[destinationAddress][digest] corresponding to msg.sender + function _computeMessageDigest( + uint16 sourceChain, + UniversalAddress srcAddr, + uint64 sequence, + uint16 destinationChain, + UniversalAddress dstAddr, + bytes32 payloadHash + ) internal pure returns (bytes32) { + return keccak256(abi.encodePacked(sourceChain, srcAddr, sequence, destinationChain, dstAddr, payloadHash)); } } diff --git a/evm/src/TransceiverRegistry.sol b/evm/src/TransceiverRegistry.sol index f47bd121..0f6a6a91 100644 --- a/evm/src/TransceiverRegistry.sol +++ b/evm/src/TransceiverRegistry.sol @@ -5,12 +5,10 @@ pragma solidity ^0.8.13; /// @notice This contract is responsible for handling the registration of Transceivers. /// @dev This contract checks that a few critical invariants hold when transceivers are added or removed, /// including: -/// 1. If a transceiver is not registered, it should be enabled. +/// 1. If a transceiver is not registered, it should be not enabled. /// 2. The value set in the bitmap of transceivers /// should directly correspond to the whether the transceiver is enabled abstract contract TransceiverRegistry { - constructor() {} - /// @dev Information about registered transceivers. struct TransceiverInfo { // whether this transceiver is registered @@ -18,7 +16,6 @@ abstract contract TransceiverRegistry { uint8 index; // the index into the integrator's transceivers array } - // TODO: Does this need to be a struct? /// @dev Bitmap encoding the enabled transceivers. /// invariant: forall (i: uint8), enabledTransceiverBitmap & i == 1 <=> transceiverInfos[i].enabled struct _EnabledTransceiverBitmap { @@ -33,90 +30,92 @@ abstract contract TransceiverRegistry { uint8 registered; } - struct IntegratorConfig { - bool isInitialized; - address admin; - } - uint8 constant MAX_TRANSCEIVERS = 128; // =============== Events =============================================== - /// @notice Emitted when a send side transceiver is added. - /// @param integrator The address of the integrator. - /// @param transceiver The address of the transceiver. - /// @param chainId The chain to which the threshold applies. - /// @param transceiversNum The current number of transceivers. - event SendTransceiverAdded(address integrator, address transceiver, uint16 chainId, uint64 transceiversNum); - - /// @notice Emitted when a receive side transceiver is added. + /// @notice Emitted when a transceiver is added. + /// @dev Topic0 + /// 0x11b137fe0ddc829607ddb73c998c8792af425d23c3d44235e97ba9d0ded66e58 /// @param integrator The address of the integrator. + /// @param chain The Wormhole chain ID on which this transceiver is added. /// @param transceiver The address of the transceiver. - /// @param chainId The chain to which the threshold applies. /// @param transceiversNum The current number of transceivers. - event RecvTransceiverAdded(address integrator, address transceiver, uint16 chainId, uint64 transceiversNum); + event TransceiverAdded(address integrator, uint16 chain, address transceiver, uint8 transceiversNum); /// @notice Emitted when a send side transceiver is enabled for a chain. + /// @dev Topic0 + /// 0x1e8617217e121e5aee2e06d784ac4dab35309adecb2a18f98eaf8c430e19a5c3 /// @param integrator The address of the integrator. + /// @param chain The Wormhole chain ID on which this transceiver is enabled. /// @param transceiver The address of the transceiver. - /// @param chainId The chain to which the threshold applies. - event SendTransceiverEnabledForChain(address integrator, address transceiver, uint16 chainId); + event SendTransceiverEnabledForChain(address integrator, uint16 chain, address transceiver); /// @notice Emitted when a receive side transceiver is enabled for a chain. + /// @dev Topic0 + /// 0x3e9ae7f2b6957091d9e99a42a88cf2e8da98b142a61811ac0e9e41f2f9778fbc /// @param integrator The address of the integrator. + /// @param chain The Wormhole chain ID on which this transceiver is enabled. /// @param transceiver The address of the transceiver. - /// @param chainId The chain to which the threshold applies. - event RecvTransceiverEnabledForChain(address integrator, address transceiver, uint16 chainId); + event RecvTransceiverEnabledForChain(address integrator, uint16 chain, address transceiver); - /// @notice Emitted when a send side transceiver is removed from the nttManager. + /// @notice Emitted when a send side transceiver is removed from the router. + /// @dev Topic0 + /// 0xb8844d856d7b255f06b1c28ae0324984a00923b2e98616302766622c20e37fac /// @param integrator The address of the integrator. + /// @param chain The Wormhole chain ID on which this transceiver is disabled. /// @param transceiver The address of the transceiver. - /// @param chainId The chain to which the threshold applies. - event SendTransceiverDisabled(address integrator, address transceiver, uint16 chainId); + event SendTransceiverDisabledForChain(address integrator, uint16 chain, address transceiver); - /// @notice Emitted when a receive side transceiver is removed from the nttManager. + /// @notice Emitted when a receive side transceiver is removed from the router. + /// @dev Topic0 + /// 0x205d0d0e655937210435fc177252accf5845b3c05787d7374023e44970730d33 /// @param integrator The address of the integrator. + /// @param chain The Wormhole chain ID on which this transceiver is disabled. /// @param transceiver The address of the transceiver. - /// @param chainId The chain to which the threshold applies. - event RecvTransceiverDisabled(address integrator, address transceiver, uint16 chainId); + event RecvTransceiverDisabledForChain(address integrator, uint16 chain, address transceiver); // =============== Errors =============================================== /// @notice Error when the caller is not the transceiver. + /// @dev Selector: 0xa0ae911d. /// @param caller The address of the caller. error CallerNotTransceiver(address caller); /// @notice Error when the transceiver is the zero address. + /// @dev Selector: 0x2f44bd77. error InvalidTransceiverZeroAddress(); /// @notice Error when the transceiver is disabled. - error DisabledTransceiver(address transceiver); + /// @dev Selector: 0xa64030ff. + error TransceiverAlreadyDisabled(address transceiver); /// @notice Error when the number of registered transceivers - /// exceeeds (MAX_TRANSCEIVERS = 64). + /// exceeeds (MAX_TRANSCEIVERS = 128). + /// @dev Selector: 0x891684c3. error TooManyTransceivers(); - /// @notice Error when attempting to remove a transceiver + /// @notice Error when attempting to use an unregistered transceiver /// that is not registered. + /// @dev Selector: 0x891684c3. /// @param transceiver The address of the transceiver. error NonRegisteredTransceiver(address transceiver); /// @notice Error when attempting to use an incorrect chain + /// @dev Selector: 0x587c94c3. /// @param chain The id of the incorrect chain error InvalidChain(uint16 chain); + /// @notice Error when attempting to register a transceiver that is already register. + /// @dev Selector: 0xeaac8f97. + /// @param transceiver The address of the transceiver. + error TransceiverAlreadyRegistered(address transceiver); + /// @notice Error when attempting to enable a transceiver that is already enabled. + /// @dev Selector: 0x8d68f84d. /// @param transceiver The address of the transceiver. error TransceiverAlreadyEnabled(address transceiver); - // TODO: Not sure if I need this, yet. Will add if Router.sol needs it. - // modifier onlyTransceiver() { - // if (!_getTransceiverInfosStorage()[msg.sender].enabled) { - // revert CallerNotTransceiver(msg.sender); - // } - // _; - // } - // =============== Storage =============================================== /// @dev Holds the integrator address to transceiver address to TransceiverInfo mapping. @@ -135,10 +134,10 @@ abstract contract TransceiverRegistry { // =============== Send side ============================================= - /// @dev Holds send side integrator address => Chain ID => Enabled transceiver bitmap mapping. - /// mapping(address => mapping(uint16 => uint128)) - bytes32 private constant ENABLED_SEND_TRANSCEIVER_BITMAP_SLOT = - bytes32(uint256(keccak256("registry.sendTransceiverBitmap")) - 1); + /// @dev Holds integrator address => Chain ID => Enabled send side transceiver address[] mapping. + /// mapping(address => mapping(uint16 => address[])) + bytes32 private constant ENABLED_SEND_TRANSCEIVER_ARRAY_SLOT = + bytes32(uint256(keccak256("registry.sendTransceiverArray")) - 1); /// @dev Holds send side Integrator address => Transceiver addresses mapping. /// mapping(address => address[]) across all chains @@ -147,7 +146,7 @@ abstract contract TransceiverRegistry { // =============== Recv side ============================================= - /// @dev Holds receive side integrator address => Chain ID => Enabled transceiver bitmap mapping. + /// @dev Holds integrator address => Chain ID => Enabled transceiver receive side bitmap mapping. /// mapping(address => mapping(uint16 => uint128)) bytes32 private constant ENABLED_RECV_TRANSCEIVER_BITMAP_SLOT = bytes32(uint256(keccak256("registry.recvTransceiverBitmap")) - 1); @@ -172,12 +171,12 @@ abstract contract TransceiverRegistry { } /// @dev Integrator address => Chain ID => Enabled transceiver bitmap mapping. - function _getPerChainSendTransceiverBitmapStorage() + function _getPerChainSendTransceiverArrayStorage() private pure - returns (mapping(address => mapping(uint16 => _EnabledTransceiverBitmap)) storage $) + returns (mapping(address => mapping(uint16 => address[])) storage $) { - uint256 slot = uint256(ENABLED_SEND_TRANSCEIVER_BITMAP_SLOT); + uint256 slot = uint256(ENABLED_SEND_TRANSCEIVER_ARRAY_SLOT); assembly ("memory-safe") { $.slot := slot } @@ -214,246 +213,147 @@ abstract contract TransceiverRegistry { } } - // =============== Storage Getters/Setters ======================================== - - /// @dev Returns if the send side transceiver is enabled for the given integrator and chain. - /// @param integrator The integrator address - /// @param transceiver The transceiver address - /// @param chainId The chain ID - /// @return true if the transceiver is enabled, false otherwise. - function _isSendTransceiverEnabledForChain(address integrator, address transceiver, uint16 chainId) - internal - view - returns (bool) - { - uint128 bitmap = _getEnabledSendTransceiversBitmapForChain(integrator, chainId); - return _isTransceiverEnabledForChain(integrator, transceiver, bitmap); - } - - /// @dev Returns if the receive side transceiver is enabled for the given integrator and chain. - /// @param integrator The integrator address - /// @param transceiver The transceiver address - /// @param chainId The chain ID - /// @return true if the transceiver is enabled, false otherwise. - function _isRecvTransceiverEnabledForChain(address integrator, address transceiver, uint16 chainId) - internal - view - returns (bool) - { - uint128 bitmap = _getEnabledRecvTransceiversBitmapForChain(integrator, chainId); - return _isTransceiverEnabledForChain(integrator, transceiver, bitmap); - } + // =============== Modifiers ====================================================== - /// @dev This is a common function between send/receive transceivers. - /// @param integrator The integrator address + /// @notice This modifier will revert if the transceiver is a zero address or the chain is invalid /// @param transceiver The transceiver address - /// @return true if the transceiver is enabled, false otherwise. - function _isTransceiverEnabledForChain(address integrator, address transceiver, uint128 bitmap) - internal - view - returns (bool) - { + /// @param chain The Wormhole chain ID + modifier onlyValidTransceiverAndChain(address transceiver, uint16 chain) { if (transceiver == address(0)) { revert InvalidTransceiverZeroAddress(); } - uint8 index = _getTransceiverInfosStorage()[integrator][transceiver].index; - return (bitmap & uint128(1 << index)) != 0; + if (chain == 0) { + revert InvalidChain(chain); + } + _; } - /// @dev This function will revert if the transceiver is an invalid address or not registered. + /// @notice This modifier will revert if the transceiver is an invalid address or not registered. + /// Or the chain is invalid /// @param integrator The integrator address + /// @param chain The Wormhole chain ID /// @param transceiver The transceiver address - function _checkTransceiver(address integrator, address transceiver) internal view { + modifier onlyRegisteredTransceiver(address integrator, uint16 chain, address transceiver) { if (transceiver == address(0)) { revert InvalidTransceiverZeroAddress(); } + if (chain == 0) { + revert InvalidChain(chain); + } + if (!_getTransceiverInfosStorage()[integrator][transceiver].registered) { revert NonRegisteredTransceiver(transceiver); } + _; } - /// @dev It is assumed that the integrator address is already validated (and not 0) - /// This just enables the send side transceiver. It does not register it. - /// @param integrator The integrator address - /// @param transceiver The transceiver address - /// @param chainId The chain ID - function _enableSendTransceiverForChain(address integrator, address transceiver, uint16 chainId) internal { - _checkTransceiver(integrator, transceiver); - - uint8 index = _getTransceiverInfosStorage()[integrator][transceiver].index; - mapping(address => mapping(uint16 => _EnabledTransceiverBitmap)) storage _bitmaps = - _getPerChainSendTransceiverBitmapStorage(); - _bitmaps[integrator][chainId].bitmap |= uint128(1 << index); - } - - /// @dev It is assumed that the integrator address is already validated (and not 0) - /// This just enables the receive side transceiver. It does not register it. - /// @param integrator The integrator address - /// @param transceiver The transceiver address - /// @param chainId The chain ID - function _enableRecvTransceiverForChain(address integrator, address transceiver, uint16 chainId) internal { - _checkTransceiver(integrator, transceiver); - - uint8 index = _getTransceiverInfosStorage()[integrator][transceiver].index; - mapping(address => mapping(uint16 => _EnabledTransceiverBitmap)) storage _bitmaps = - _getPerChainRecvTransceiverBitmapStorage(); - _bitmaps[integrator][chainId].bitmap |= uint128(1 << index); - } + // =============== Storage Getters/Setters ======================================== - /// @dev This function enables a send side transceiver. If it is not registered, it will register it. + /// @dev This function adds a transceiver. /// @param integrator The integrator address + /// @param chain The The Wormhole chain ID /// @param transceiver The transceiver address - /// @param chainId The chain ID - /// @return index The index of this newly enabled send side transceiver - function _setSendTransceiver(address integrator, address transceiver, uint16 chainId) + /// @return index The index of this newly enabled transceiver + function _addTransceiver(address integrator, uint16 chain, address transceiver) internal + onlyValidTransceiverAndChain(transceiver, chain) returns (uint8 index) { - // These are everything for an integrator. mapping(address => mapping(address => TransceiverInfo)) storage transceiverInfos = _getTransceiverInfosStorage(); mapping(address => _NumTransceivers) storage _numTransceivers = _getNumTransceiversStorage(); - // This is send side for a specific chain. - mapping(address => mapping(uint16 => _EnabledTransceiverBitmap)) storage _enabledTransceiverBitmap = - _getPerChainSendTransceiverBitmapStorage(); - - if (transceiver == address(0)) { - revert InvalidTransceiverZeroAddress(); - } - if (chainId == 0) { - revert InvalidChain(chainId); + if (transceiverInfos[integrator][transceiver].registered) { + revert TransceiverAlreadyRegistered(transceiver); } - - if (!transceiverInfos[integrator][transceiver].registered) { - if (_numTransceivers[integrator].registered >= MAX_TRANSCEIVERS) { - revert TooManyTransceivers(); - } - - // Create the TransceiverInfo - transceiverInfos[integrator][transceiver] = - TransceiverInfo({registered: true, index: _numTransceivers[integrator].registered}); - // Add this transceiver to the integrator => address[] mapping - _getRegisteredTransceiversStorage()[integrator].push(transceiver); - // Increment count of transceivers - _numTransceivers[integrator].registered++; - // Emit an event - emit SendTransceiverAdded(integrator, transceiver, chainId, _numTransceivers[integrator].registered); + if (_numTransceivers[integrator].registered >= MAX_TRANSCEIVERS) { + revert TooManyTransceivers(); } - - // _numTransceivers[integrator].enabled++; - - // Add this transceiver to the per chain list of transceivers by updating the bitmap - uint128 updatedEnabledTransceiverBitmap = _enabledTransceiverBitmap[integrator][chainId].bitmap - | uint128(1 << transceiverInfos[integrator][transceiver].index); - // ensure that this actually changed the bitmap - if (updatedEnabledTransceiverBitmap == _enabledTransceiverBitmap[integrator][chainId].bitmap) { - revert TransceiverAlreadyEnabled(transceiver); - } - _enabledTransceiverBitmap[integrator][chainId].bitmap = updatedEnabledTransceiverBitmap; - - _checkSendTransceiversInvariants(integrator); - emit SendTransceiverEnabledForChain(integrator, transceiver, chainId); + // Create the TransceiverInfo + transceiverInfos[integrator][transceiver] = + TransceiverInfo({registered: true, index: _numTransceivers[integrator].registered}); + // Add this transceiver to the integrator => address[] mapping + _getRegisteredTransceiversStorage()[integrator].push(transceiver); + // Increment count of transceivers + _numTransceivers[integrator].registered++; + // Emit an event + emit TransceiverAdded(integrator, chain, transceiver, _numTransceivers[integrator].registered); return transceiverInfos[integrator][transceiver].index; } - /// @dev This function enables a transceiver. If it is not registered, it will register it. + /// @dev It is assumed that the integrator address is already validated (and not 0) + /// This just enables the send side transceiver. It does not register it. /// @param integrator The integrator address + /// @param chain The Wormhole chain ID /// @param transceiver The transceiver address - /// @param chainId The chain ID - /// @return index The index of this newly enabled receive side transceiver - function _setRecvTransceiver(address integrator, address transceiver, uint16 chainId) + function _enableSendTransceiver(address integrator, uint16 chain, address transceiver) internal - returns (uint8 index) + onlyRegisteredTransceiver(integrator, chain, transceiver) { - // These are everything for an integrator. - mapping(address => mapping(address => TransceiverInfo)) storage transceiverInfos = _getTransceiverInfosStorage(); - mapping(address => _NumTransceivers) storage _numTransceivers = _getNumTransceiversStorage(); - // This is send side for a specific chain. - mapping(address => mapping(uint16 => _EnabledTransceiverBitmap)) storage _enabledTransceiverBitmap = - _getPerChainRecvTransceiverBitmapStorage(); - - if (transceiver == address(0)) { - revert InvalidTransceiverZeroAddress(); - } - - if (chainId == 0) { - revert InvalidChain(chainId); - } - - if (!transceiverInfos[integrator][transceiver].registered) { - if (_numTransceivers[integrator].registered >= MAX_TRANSCEIVERS) { - revert TooManyTransceivers(); - } - - // Create the TransceiverInfo - transceiverInfos[integrator][transceiver] = - TransceiverInfo({registered: true, index: _numTransceivers[integrator].registered}); - // Add this transceiver to the integrator => address[] mapping - _getRegisteredTransceiversStorage()[integrator].push(transceiver); - // Increment count of transceivers - _numTransceivers[integrator].registered++; - // Emit an event - emit RecvTransceiverAdded(integrator, transceiver, chainId, _numTransceivers[integrator].registered); - } - - // _numTransceivers[integrator].enabled++; - - // Add this transceiver to the per chain list of transceivers by updating the bitmap - uint128 updatedEnabledTransceiverBitmap = _enabledTransceiverBitmap[integrator][chainId].bitmap - | uint128(1 << transceiverInfos[integrator][transceiver].index); - // ensure that this actually changed the bitmap - if (updatedEnabledTransceiverBitmap == _enabledTransceiverBitmap[integrator][chainId].bitmap) { + if (_isSendTransceiverEnabledForChain(integrator, chain, transceiver)) { revert TransceiverAlreadyEnabled(transceiver); } - _enabledTransceiverBitmap[integrator][chainId].bitmap = updatedEnabledTransceiverBitmap; - - _checkRecvTransceiversInvariants(integrator); - emit RecvTransceiverEnabledForChain(integrator, transceiver, chainId); - - return transceiverInfos[integrator][transceiver].index; + mapping(address => mapping(uint16 => address[])) storage sendTransceiverArray = + _getPerChainSendTransceiverArrayStorage(); + sendTransceiverArray[integrator][chain].push(transceiver); + emit SendTransceiverEnabledForChain(integrator, chain, transceiver); } - /// @dev This function disables a send side transceiver by chain. - /// @notice This function will revert under the following conditions: - /// - The transceiver is the zero address - /// - The transceiver is not registered + /// @dev It is assumed that the integrator address is already validated (and not 0) + /// This just enables the receive side transceiver. It does not register it. /// @param integrator The integrator address + /// @param chain The Wormhole chain ID /// @param transceiver The transceiver address - /// @param chainId The chain ID - function _disableSendTransceiver(address integrator, address transceiver, uint16 chainId) internal { - mapping(address => mapping(address => TransceiverInfo)) storage transceiverInfos = _getTransceiverInfosStorage(); - mapping(address => mapping(uint16 => _EnabledTransceiverBitmap)) storage _enabledTransceiverBitmap = - _getPerChainSendTransceiverBitmapStorage(); - - if (transceiver == address(0)) { - revert InvalidTransceiverZeroAddress(); - } - - if (chainId == 0) { - revert InvalidChain(chainId); + function _enableRecvTransceiver(address integrator, uint16 chain, address transceiver) + internal + onlyRegisteredTransceiver(integrator, chain, transceiver) + { + if (_isRecvTransceiverEnabledForChain(integrator, chain, transceiver)) { + revert TransceiverAlreadyEnabled(transceiver); } + uint8 index = _getTransceiverInfosStorage()[integrator][transceiver].index; + mapping(address => mapping(uint16 => _EnabledTransceiverBitmap)) storage _bitmaps = + _getPerChainRecvTransceiverBitmapStorage(); + _bitmaps[integrator][chain].bitmap |= uint128(1 << index); + emit RecvTransceiverEnabledForChain(integrator, chain, transceiver); + } - TransceiverInfo storage info = transceiverInfos[integrator][transceiver]; - - if (!info.registered) { - revert NonRegisteredTransceiver(transceiver); + /// @notice This function disables a send side transceiver by chain. + /// @param integrator The integrator address + /// @param chain The chain ID + /// @param transceiver The transceiver address + function _disableSendTransceiver(address integrator, uint16 chain, address transceiver) + internal + onlyRegisteredTransceiver(integrator, chain, transceiver) + { + // mapping(address => mapping(address => TransceiverInfo)) storage transceiverInfos = _getTransceiverInfosStorage(); + mapping(address => mapping(uint16 => address[])) storage enabledSendTransceivers = + _getPerChainSendTransceiverArrayStorage(); + address[] storage transceivers = enabledSendTransceivers[integrator][chain]; + + // Get the index of the disabled transceiver in the enabled transceivers array + // and replace it with the last element in the array + uint256 len = transceivers.length; + bool found = false; + for (uint256 i = 0; i < len;) { + if (transceivers[i] == transceiver) { + // Swap the last element with the element to be removed + transceivers[i] = transceivers[len - 1]; + // Remove the last element + transceivers.pop(); + found = true; + break; + } + unchecked { + ++i; + } } - - uint128 updatedEnabledTransceiverBitmap = _enabledTransceiverBitmap[integrator][chainId].bitmap - & uint128(~(1 << transceiverInfos[integrator][transceiver].index)); - // ensure that this actually changed the bitmap - if (updatedEnabledTransceiverBitmap >= _enabledTransceiverBitmap[integrator][chainId].bitmap) { - revert DisabledTransceiver(transceiver); + if (!found) { + revert TransceiverAlreadyDisabled(transceiver); } - _enabledTransceiverBitmap[integrator][chainId].bitmap = updatedEnabledTransceiverBitmap; - _checkSendTransceiversInvariants(integrator); - // we call the invariant check on the transceiver here as well, since - // the above check only iterates through the enabled transceivers. - _checkSendTransceiverInvariants(integrator, transceiver); - emit SendTransceiverDisabled(integrator, transceiver, chainId); + emit SendTransceiverDisabledForChain(integrator, chain, transceiver); } /// @dev This function disables a receive side transceiver by chain. @@ -461,70 +361,95 @@ abstract contract TransceiverRegistry { /// - The transceiver is the zero address /// - The transceiver is not registered /// @param integrator The integrator address + /// @param chain The Wormhole chain ID /// @param transceiver The transceiver address - /// @param chainId The chain ID - function _disableRecvTransceiver(address integrator, address transceiver, uint16 chainId) internal { + function _disableRecvTransceiver(address integrator, uint16 chain, address transceiver) + internal + onlyRegisteredTransceiver(integrator, chain, transceiver) + { mapping(address => mapping(address => TransceiverInfo)) storage transceiverInfos = _getTransceiverInfosStorage(); mapping(address => mapping(uint16 => _EnabledTransceiverBitmap)) storage _enabledTransceiverBitmap = _getPerChainRecvTransceiverBitmapStorage(); - if (transceiver == address(0)) { - revert InvalidTransceiverZeroAddress(); - } - - if (chainId == 0) { - revert InvalidChain(chainId); + uint128 updatedEnabledTransceiverBitmap = _enabledTransceiverBitmap[integrator][chain].bitmap + & uint128(~(1 << transceiverInfos[integrator][transceiver].index)); + // ensure that this actually changed the bitmap + if (updatedEnabledTransceiverBitmap >= _enabledTransceiverBitmap[integrator][chain].bitmap) { + revert TransceiverAlreadyDisabled(transceiver); } + _enabledTransceiverBitmap[integrator][chain].bitmap = updatedEnabledTransceiverBitmap; - TransceiverInfo storage info = transceiverInfos[integrator][transceiver]; - - if (!info.registered) { - revert NonRegisteredTransceiver(transceiver); - } + emit RecvTransceiverDisabledForChain(integrator, chain, transceiver); + } - uint128 updatedEnabledTransceiverBitmap = _enabledTransceiverBitmap[integrator][chainId].bitmap - & uint128(~(1 << transceiverInfos[integrator][transceiver].index)); - // ensure that this actually changed the bitmap - if (updatedEnabledTransceiverBitmap >= _enabledTransceiverBitmap[integrator][chainId].bitmap) { - revert DisabledTransceiver(transceiver); + /// @dev Returns if the send side transceiver is enabled for the given integrator and chain. + /// @param integrator The integrator address + /// @param chain The Wormhole chain ID + /// @param transceiver The transceiver address + /// @return true if the transceiver is enabled, false otherwise. + function _isSendTransceiverEnabledForChain(address integrator, uint16 chain, address transceiver) + internal + view + onlyRegisteredTransceiver(integrator, chain, transceiver) + returns (bool) + { + address[] storage transceivers = _getPerChainSendTransceiverArrayStorage()[integrator][chain]; + uint256 length = transceivers.length; + for (uint256 i = 0; i < length;) { + if (transceivers[i] == transceiver) { + return true; + } + unchecked { + ++i; + } } - _enabledTransceiverBitmap[integrator][chainId].bitmap = updatedEnabledTransceiverBitmap; + return false; + } - _checkRecvTransceiversInvariants(integrator); - // we call the invariant check on the transceiver here as well, since - // the above check only iterates through the enabled transceivers. - _checkRecvTransceiverInvariants(integrator, transceiver); - emit RecvTransceiverDisabled(integrator, transceiver, chainId); + /// @dev Returns if the receive side transceiver is enabled for the given integrator and chain. + /// @param integrator The integrator address + /// @param chain The Wormhole chain ID + /// @param transceiver The transceiver address + /// @return true if the transceiver is enabled, false otherwise. + function _isRecvTransceiverEnabledForChain(address integrator, uint16 chain, address transceiver) + internal + view + onlyRegisteredTransceiver(integrator, chain, transceiver) + returns (bool) + { + uint128 bitmap = _getEnabledRecvTransceiversBitmapForChain(integrator, chain); + uint8 index = _getTransceiverInfosStorage()[integrator][transceiver].index; + return (bitmap & uint128(1 << index)) != 0; } /// @param integrator The integrator address - /// @param forChainId The chain ID - /// @return bitmap The bitmap of the send side transceivers enabled for this integrator and chain - function _getEnabledSendTransceiversBitmapForChain(address integrator, uint16 forChainId) + /// @param chain The Wormhole chain ID + /// @return array The array of the send side transceivers enabled for this integrator and chain + function _getEnabledSendTransceiversArrayForChain(address integrator, uint16 chain) internal view virtual - returns (uint128 bitmap) + returns (address[] storage array) { - if (forChainId == 0) { - revert InvalidChain(forChainId); + if (chain == 0) { + revert InvalidChain(chain); } - bitmap = _getPerChainSendTransceiverBitmapStorage()[integrator][forChainId].bitmap; + array = _getPerChainSendTransceiverArrayStorage()[integrator][chain]; } /// @param integrator The integrator address - /// @param forChainId The chain ID + /// @param chain The Wormhole chain ID /// @return bitmap The bitmap of the send side transceivers enabled for this integrator and chain - function _getEnabledRecvTransceiversBitmapForChain(address integrator, uint16 forChainId) + function _getEnabledRecvTransceiversBitmapForChain(address integrator, uint16 chain) internal view virtual returns (uint128 bitmap) { - if (forChainId == 0) { - revert InvalidChain(forChainId); + if (chain == 0) { + revert InvalidChain(chain); } - bitmap = _getPerChainRecvTransceiverBitmapStorage()[integrator][forChainId].bitmap; + bitmap = _getPerChainRecvTransceiverBitmapStorage()[integrator][chain].bitmap; } // =============== EXTERNAL FUNCTIONS ======================================== @@ -538,171 +463,59 @@ abstract contract TransceiverRegistry { /// @notice Returns the enabled send side transceiver addresses for the given integrator. /// @param integrator The integrator address - /// @param chainId The chainId for the desired transceivers + /// @param chain The Wormhole chain ID for the desired transceivers /// @return result The enabled send side transceivers for the given integrator and chain. - function getSendTransceiversByChain(address integrator, uint16 chainId) - external + function getSendTransceiversByChain(address integrator, uint16 chain) + public view returns (address[] memory result) { + if (chain == 0) { + revert InvalidChain(chain); + } address[] memory allTransceivers = _getRegisteredTransceiversStorage()[integrator]; address[] memory tempResult = new address[](allTransceivers.length); - for (uint256 i = 0; i < allTransceivers.length; i++) { - if (_isSendTransceiverEnabledForChain(integrator, allTransceivers[i], chainId)) { - tempResult[i] = allTransceivers[i]; + uint8 len = 0; + uint256 allTransceiversLength = allTransceivers.length; + for (uint256 i = 0; i < allTransceiversLength;) { + if (_isSendTransceiverEnabledForChain(integrator, chain, allTransceivers[i])) { + tempResult[len] = allTransceivers[i]; + ++len; + } + unchecked { + ++i; } } - result = new address[](tempResult.length); - for (uint256 i = 0; i < tempResult.length; i++) { + result = new address[](len); + for (uint8 i = 0; i < len; i++) { result[i] = tempResult[i]; } } - /// @notice Returns the enabled send side transceiver addresses for the given integrator. + /// @notice Returns the enabled receive side transceiver addresses for the given integrator. /// @param integrator The integrator address - /// @param chainId The chainId for the desired transceivers - /// @return result The enabled send side transceivers for the given integrator. - function getRecvTransceiversByChain(address integrator, uint16 chainId) - external + /// @param chain The Wormhole chain ID for the desired transceivers + /// @return result The enabled receive side transceivers for the given integrator. + function getRecvTransceiversByChain(address integrator, uint16 chain) + public view returns (address[] memory result) { address[] memory allTransceivers = _getRegisteredTransceiversStorage()[integrator]; - address[] memory tempResult = new address[](allTransceivers.length); + // Count number of bits set in the bitmap so we can calculate the size of the result array. + uint128 bitmap = _getEnabledRecvTransceiversBitmapForChain(integrator, chain); + uint8 count = 0; + while (bitmap != 0) { + count += uint8(bitmap & 1); + bitmap >>= 1; + } + result = new address[](count); + uint256 len = 0; for (uint256 i = 0; i < allTransceivers.length; i++) { - if (_isRecvTransceiverEnabledForChain(integrator, allTransceivers[i], chainId)) { - tempResult[i] = allTransceivers[i]; + if (_isRecvTransceiverEnabledForChain(integrator, chain, allTransceivers[i])) { + result[len] = allTransceivers[i]; + ++len; } } - result = new address[](tempResult.length); - for (uint256 i = 0; i < tempResult.length; i++) { - result[i] = tempResult[i]; - } - } - - // ============== Invariants ============================================= - - /// @dev Check that the transceiver is in a valid state. - /// Checking these invariants is somewhat costly, but we only need to do it - /// when modifying the transceivers, which happens infrequently. - function _checkSendTransceiversInvariants(address integrator) internal view { - // _NumTransceivers storage _numTransceivers = _getNumSendTransceiversStorage()[integrator]; - // address[] storage _enabledTransceivers = _getRegisteredSendTransceiversStorage()[integrator]; - - // uint256 numTransceiversEnabled = _numTransceivers.enabled; - // assert(numTransceiversEnabled == _enabledTransceivers.length); - - // for (uint256 i = 0; i < numTransceiversEnabled; i++) { - // _checkSendTransceiverInvariants(integrator, _enabledTransceivers[i]); - // } - - // // invariant: each transceiver is only enabled once - // for (uint256 i = 0; i < numTransceiversEnabled; i++) { - // for (uint256 j = i + 1; j < numTransceiversEnabled; j++) { - // assert(_enabledTransceivers[i] != _enabledTransceivers[j]); - // } - // } - - // // invariant: numRegisteredTransceivers <= MAX_TRANSCEIVERS - // assert(_numTransceivers.registered <= MAX_TRANSCEIVERS); - } - - /// @dev Check that the transceiver is in a valid state. - function _checkSendTransceiverInvariants(address integrator, address transceiver) private view { - // mapping(address => mapping(address => TransceiverInfo)) storage transceiverInfos = _getTransceiverInfosStorage(); - // mapping(address => mapping(uint16 => _EnabledTransceiverBitmap)) storage _enabledTransceiverBitmap = - // _getPerChainSendTransceiverBitmapStorage(); - // mapping(address => _NumTransceivers) storage _numTransceivers = _getNumSendTransceiversStorage(); - // mapping(address => address[]) storage _enabledTransceivers = _getRegisteredSendTransceiversStorage(); - - // TransceiverInfo memory transceiverInfo = transceiverInfos[integrator][transceiver]; - - // // if an transceiver is not registered, it should not be enabled - // assert(transceiverInfo.registered || (!transceiverInfo.enabled && transceiverInfo.index == 0)); - - // bool transceiverInEnabledBitmap = ( - // _enabledTransceiverBitmap[integrator][transceiverInfo.chainId].bitmap & uint128(1 << transceiverInfo.index) - // ) != 0; - // bool transceiverEnabled = transceiverInfo.enabled; - - // bool transceiverInEnabledTransceivers = false; - - // for (uint256 i = 0; i < _numTransceivers[integrator].enabled; i++) { - // if (_enabledTransceivers[integrator][i] == transceiver) { - // transceiverInEnabledTransceivers = true; - // break; - // } - // } - - // // invariant: transceiverInfos[integrator][transceiver].enabled - // // <=> enabledTransceiverBitmap & (1 << transceiverInfos[integrator][transceiver].index) != 0 - // assert(transceiverInEnabledBitmap == transceiverEnabled); - - // // invariant: transceiverInfos[integrator][transceiver].enabled <=> transceiver in _enabledTransceivers - // assert(transceiverInEnabledTransceivers == transceiverEnabled); - - // assert(transceiverInfo.index < _numTransceivers[integrator].registered); - } - - /// @dev Check that the transceiver is in a valid state. - /// Checking these invariants is somewhat costly, but we only need to do it - /// when modifying the transceivers, which happens infrequently. - function _checkRecvTransceiversInvariants(address integrator) internal view { - // _NumTransceivers storage _numTransceivers = _getNumRecvTransceiversStorage()[integrator]; - // address[] storage _enabledTransceivers = _getRegisteredRecvTransceiversStorage()[integrator]; - - // uint256 numTransceiversEnabled = _numTransceivers.enabled; - // assert(numTransceiversEnabled == _enabledTransceivers.length); - - // for (uint256 i = 0; i < numTransceiversEnabled; i++) { - // _checkRecvTransceiverInvariants(integrator, _enabledTransceivers[i]); - // } - - // // invariant: each transceiver is only enabled once - // for (uint256 i = 0; i < numTransceiversEnabled; i++) { - // for (uint256 j = i + 1; j < numTransceiversEnabled; j++) { - // assert(_enabledTransceivers[i] != _enabledTransceivers[j]); - // } - // } - - // // invariant: numRegisteredTransceivers <= MAX_TRANSCEIVERS - // assert(_numTransceivers.registered <= MAX_TRANSCEIVERS); - } - - /// @dev Check that the transceiver is in a valid state. - function _checkRecvTransceiverInvariants(address integrator, address transceiver) private view { - // mapping(address => mapping(address => TransceiverInfo)) storage transceiverInfos = _getTransceiverInfosStorage(); - // mapping(address => mapping(uint16 => _EnabledTransceiverBitmap)) storage _enabledTransceiverBitmap = - // _getPerChainRecvTransceiverBitmapStorage(); - // mapping(address => _NumTransceivers) storage _numTransceivers = _getNumRecvTransceiversStorage(); - // mapping(address => address[]) storage _enabledTransceivers = _getRegisteredRecvTransceiversStorage(); - - // TransceiverInfo memory transceiverInfo = transceiverInfos[integrator][transceiver]; - - // // if an transceiver is not registered, it should not be enabled - // assert(transceiverInfo.registered || (!transceiverInfo.enabled && transceiverInfo.index == 0)); - - // bool transceiverInEnabledBitmap = ( - // _enabledTransceiverBitmap[integrator][transceiverInfo.chainId].bitmap & uint128(1 << transceiverInfo.index) - // ) != 0; - // bool transceiverEnabled = transceiverInfo.enabled; - - // bool transceiverInEnabledTransceivers = false; - - // for (uint256 i = 0; i < _numTransceivers[integrator].enabled; i++) { - // if (_enabledTransceivers[integrator][i] == transceiver) { - // transceiverInEnabledTransceivers = true; - // break; - // } - // } - - // // invariant: transceiverInfos[integrator][transceiver].enabled - // // <=> enabledTransceiverBitmap & (1 << transceiverInfos[integrator][transceiver].index) != 0 - // assert(transceiverInEnabledBitmap == transceiverEnabled); - - // // invariant: transceiverInfos[integrator][transceiver].enabled <=> transceiver in _enabledTransceivers - // assert(transceiverInEnabledTransceivers == transceiverEnabled); - - // assert(transceiverInfo.index < _numTransceivers[integrator].registered); } } diff --git a/evm/src/interfaces/IRouter.sol b/evm/src/interfaces/IRouter.sol deleted file mode 100644 index 567db858..00000000 --- a/evm/src/interfaces/IRouter.sol +++ /dev/null @@ -1,43 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -pragma solidity ^0.8.13; - -import "./IMessageSequence.sol"; -import "../libraries/UniversalAddress.sol"; - -interface IRouter is IMessageSequence { - /// @dev Send a message to another chain. - /// @param recipientChain The Wormhole chain ID of the recipient. - /// @param recipientAddress The universal address of the peer on the recipient chain. - /// @param refundAddress The source chain refund address passed to the Transceiver. - /// @param payloadHash keccak256 of a message to be sent to the recipient chain. - /// @return uint64 The sequence number of the message. - function sendMessage( - uint16 recipientChain, - UniversalAddress recipientAddress, - address refundAddress, - bytes32 payloadHash - ) external payable returns (uint64); - - // /// @dev Receive a message from another chain called by integrator. - // /// @param sourceChain The Wormhole chain ID of the recipient. - // /// @param senderAddress The universal address of the peer on the recipient chain. - // /// @param refundAddress The source chain refund address passed to the Transceiver. - // /// @param message A message to be sent to the recipient chain. - // /// @return uint128 The bitmap - function receiveMessage( - uint16 sourceChain, - UniversalAddress senderAddress, - address refundAddress, - bytes32 messageHash - ) external payable returns (uint128); - - /// @notice Called by a Transceiver contract to deliver a verified attestation. - function attestMessage( - uint16 sourceChain, // Wormhole Chain ID - UniversalAddress sourceAddress, // UniversalAddress of the message sender (integrator) - uint64 sequence, // Next sequence number for that integrator (consuming the sequence number) - uint16 destinationChainId, // Wormhole Chain ID - UniversalAddress destinationAddress, // UniversalAddress of the messsage recipient (integrator on destination chain) - bytes32 payloadHash // keccak256 of arbitrary payload from the integrator - ) external; -} diff --git a/evm/src/interfaces/IRouterAdmin.sol b/evm/src/interfaces/IRouterAdmin.sol new file mode 100644 index 00000000..3d2388ce --- /dev/null +++ b/evm/src/interfaces/IRouterAdmin.sol @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.13; + +interface IRouterAdmin { + /// @notice Transfers admin privileges from the current admin to another contract. + /// @dev The msg.sender must be the current admin contract. + /// @param integrator The address of the integrator contract. + /// @param newAdmin The address of the new admin. + function updateAdmin(address integrator, address newAdmin) external; + + /// @notice Starts the two step process of transferring admin privileges from the current admin to another contract. + /// @dev The msg.sender must be the current admin contract. + /// @param integrator The address of the integrator contract. + /// @param newAdmin The address of the new admin. + function transferAdmin(address integrator, address newAdmin) external; + + /// @notice Starts the two step process of transferring admin privileges from the current admin to another contract. + /// @dev The msg.sender must be the current admin contract. + /// @param integrator The address of the integrator contract. + function claimAdmin(address integrator) external; + + /// @notice Clears the current admin. THIS IS NOT REVERSIBLE. + /// This ensures that the Integrator configuration becomes immutable. + /// @dev The msg.sender must be the current admin contract. + /// @param integrator The address of the integrator contract. + function discardAdmin(address integrator) external; + + /// @notice Adds the given transceiver to the given chain for the integrator's list of transceivers. + /// This does NOT enable the transceiver for sending or receiving. + /// @param integrator The address of the integrator contract. + /// @param transceiver The address of the Transceiver contract. + /// @param chainId The chain ID of the Transceiver contract. + function addTransceiver(address integrator, uint16 chainId, address transceiver) external returns (uint8 index); + + /// @notice This enables the sending of messages from the given transceiver on the given chain. + /// @param integrator The address of the integrator contract. + /// @param transceiver The address of the Transceiver contract. + /// @param chain The chain ID of the Transceiver contract. + function enableSendTransceiver(address integrator, uint16 chain, address transceiver) external; + + /// @notice This enables the receiving of messages by the given transceiver on the given chain. + /// @param integrator The address of the integrator contract. + /// @param transceiver The address of the Transceiver contract. + /// @param chain The chain ID of the Transceiver contract. + function enableRecvTransceiver(address integrator, uint16 chain, address transceiver) external; + + /// @notice This disables the sending of messages from the given transceiver on the given chain. + /// @param integrator The address of the integrator contract. + /// @param transceiver The address of the Transceiver contract. + /// @param chain The chain ID of the Transceiver contract. + function disableSendTransceiver(address integrator, uint16 chain, address transceiver) external; + + /// @notice This disables the receiving of messages by the given transceiver on the given chain. + /// @param integrator The address of the integrator contract. + /// @param transceiver The address of the Transceiver contract. + /// @param chain The chain ID of the Transceiver contract. + function disableRecvTransceiver(address integrator, uint16 chain, address transceiver) external; +} diff --git a/evm/src/interfaces/IRouterIntegrator.sol b/evm/src/interfaces/IRouterIntegrator.sol new file mode 100644 index 00000000..4e9a8709 --- /dev/null +++ b/evm/src/interfaces/IRouterIntegrator.sol @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.13; + +import "./IMessageSequence.sol"; +import "../libraries/UniversalAddress.sol"; + +interface IRouterIntegrator is IMessageSequence { + /// @notice This is the first thing an integrator should do. It registers the integrator with the router + /// and sets the administrator contract for that integrator. The admin address is used to manage the transceivers. + /// @dev The msg.sender needs to be the integrator contract. + /// @param initialAdmin The address of the admin. Pass in msg.sender, if you want the integrator to be the admin. + function register(address initialAdmin) external; + + /// @notice Send a message to another chain. + /// @param dstChain The Wormhole chain ID of the recipient. + /// @param dstAddr The universal address of the peer on the recipient chain. + /// @param refundAddress The source chain refund address passed to the Transceiver. + /// @param payloadHash keccak256 of a message to be sent to the recipient chain. + /// @return uint64 The sequence number of the message. + function sendMessage(uint16 dstChain, UniversalAddress dstAddr, address refundAddress, bytes32 payloadHash) + external + payable + returns (uint64); + + /// @notice Receive a message and mark it executed. + /// @param srcChain The Wormhole chain ID of the sender. + /// @param srcAddr The universal address of the peer on the sending chain. + /// @param sequence The sequence number of the message (per integrator). + /// @param dstChain The Wormhole chain ID of the destination. + /// @param dstAddr The destination address of the message. + /// @param payloadHash The keccak256 of payload from the integrator. + /// @return (uint128, uint128) The enabled bitmap, and the attested bitmap, respectively. + function recvMessage( + uint16 srcChain, + UniversalAddress srcAddr, + uint64 sequence, + uint16 dstChain, + UniversalAddress dstAddr, + bytes32 payloadHash + ) external payable returns (uint128, uint128); + + /// @notice Execute a message without requiring any attestations. + /// @param srcChain The Wormhole chain ID of the sender. + /// @param srcAddr The universal address of the peer on the sending chain. + /// @param sequence The sequence number of the message (per integrator). + /// @param dstChain The Wormhole chain ID of the destination. + /// @param dstAddr The destination address of the message. + /// @param payloadHash The keccak256 of payload from the integrator. + function execMessage( + uint16 srcChain, + UniversalAddress srcAddr, + uint64 sequence, + uint16 dstChain, + UniversalAddress dstAddr, + bytes32 payloadHash + ) external; + + /// @notice Retrieve the status of a message. + /// @param srcChain The Wormhole chain ID of the sender. + /// @param srcAddr The universal address of the message. + /// @param sequence The sequence number of the message. + /// @param dstChain The Wormhole chain ID of the destination. + /// @param dstAddr The destination address of the message. + /// @param payloadHash The keccak256 of payload from the integrator. + /// @return (uint128, uint128, bool) The enabled bitmap, the attested bitmap, if the message was executed. + function getMessageStatus( + uint16 srcChain, + UniversalAddress srcAddr, + uint64 sequence, + uint16 dstChain, + UniversalAddress dstAddr, + bytes32 payloadHash + ) external returns (uint128, uint128, bool); +} diff --git a/evm/src/interfaces/IRouterTransceiver.sol b/evm/src/interfaces/IRouterTransceiver.sol new file mode 100644 index 00000000..548661e2 --- /dev/null +++ b/evm/src/interfaces/IRouterTransceiver.sol @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.13; + +import "../libraries/UniversalAddress.sol"; + +interface IRouterTransceiver { + /// @notice Called by a Transceiver contract to attest to a message. + /// @param srcChain The Wormhole chain ID of the sender. + /// @param srcAddr The universal address of the peer on the sending chain. + /// @param sequence The sequence number of the message (per integrator). + /// @param dstChain The Wormhole chain ID of the destination. + /// @param dstAddr The destination address of the message. + /// @param payloadHash The keccak256 of payload from the integrator. + function attestMessage( + uint16 srcChain, + UniversalAddress srcAddr, + uint64 sequence, + uint16 dstChain, + UniversalAddress dstAddr, + bytes32 payloadHash + ) external; +} diff --git a/evm/src/interfaces/ITransceiver.sol b/evm/src/interfaces/ITransceiver.sol index a605790a..f0a3590e 100644 --- a/evm/src/interfaces/ITransceiver.sol +++ b/evm/src/interfaces/ITransceiver.sol @@ -4,8 +4,8 @@ pragma solidity ^0.8.13; import "../libraries/UniversalAddress.sol"; interface ITransceiver { - /// @notice The caller is not the NttManager. - /// @dev Selector: 0xc5aa6153. + /// @notice The caller is not the Router. + /// @dev Selector: 0xfb217bcd. /// @param caller The address of the caller. error CallerNotRouter(address caller); @@ -19,14 +19,18 @@ interface ITransceiver { function quoteDeliveryPrice(uint16 recipientChain) external view returns (uint256); /// @dev Send a message to another chain. - /// @param recipientChain The Wormhole chain ID of the recipient. - /// @param messageHash The hash of the message to be sent to the recipient chain. - /// @param recipientAddress The Wormhole formatted address of the recipient chain. - /// @param refundAddress The address of the refund recipient + /// @param srcAddr The universal address of the sender. + /// @param dstChain The Wormhole chain ID of the recipient. + /// @param dstAddr The universal address of the recipient. + /// @param sequence The per-integrator sequence number associated with the message. + /// @param payloadHash The hash of the message to be sent to the recipient chain. + /// @param refundAddr The address of the refund recipient function sendMessage( - uint16 recipientChain, - bytes32 messageHash, - UniversalAddress recipientAddress, - bytes32 refundAddress + UniversalAddress srcAddr, + uint16 dstChain, + UniversalAddress dstAddr, + uint64 sequence, + bytes32 payloadHash, + bytes32 refundAddr ) external payable; } diff --git a/evm/test/Router.t.sol b/evm/test/Router.t.sol index 020c1de2..f764427c 100644 --- a/evm/test/Router.t.sol +++ b/evm/test/Router.t.sol @@ -2,19 +2,27 @@ pragma solidity ^0.8.13; import {Test, console} from "forge-std/Test.sol"; +import "forge-std/console.sol"; import "../src/libraries/UniversalAddress.sol"; import {Router} from "../src/Router.sol"; import {TransceiverRegistry} from "../src/TransceiverRegistry.sol"; import {ITransceiver} from "../src/interfaces/ITransceiver.sol"; contract RouterImpl is Router { -// function getDelegate(address delegator) public view returns (address) { -// return _getDelegateStorage()[delegator]; -// } + uint16 public constant OurChainId = 42; -// function registerDelegate(address delegate) public { -// _registerDelegate(delegate); -// } + constructor() Router(OurChainId) {} + + function myComputeMessageHash( + uint16 sourceChain, + UniversalAddress srcAddr, + uint64 sequence, + uint16 destinationChain, + UniversalAddress dstAddr, + bytes32 payloadHash + ) public pure returns (bytes32) { + return _computeMessageDigest(sourceChain, srcAddr, sequence, destinationChain, dstAddr, payloadHash); + } } // This contract does send/receive operations @@ -25,23 +33,6 @@ contract Integrator { constructor(address _router) { router = RouterImpl(_router); } - - function setMeAsAdmin(address admin) public { - myAdmin = admin; - } - - function registerWithRouter() public { - router.registerAdmin(myAdmin); - } - - function sendMessage( - uint16 recipientChain, - UniversalAddress recipientAddress, - address refundAddress, - bytes32 payloadHash - ) public payable returns (uint64) { - return router.sendMessage(recipientChain, recipientAddress, refundAddress, payloadHash); - } } // This contract can only do transceiver operations @@ -53,21 +44,13 @@ contract Admin { integrator = _integrator; router = RouterImpl(_router); } - - function requestAdmin() public { - Integrator(integrator).setMeAsAdmin(address(this)); - } - - function setSendTransceiver(address transceiver, uint16 chain) public { - router.setSendTransceiver(integrator, transceiver, chain); - } - - function setRecvTransceiver(address transceiver, uint16 chain) public { - router.setRecvTransceiver(integrator, transceiver, chain); - } } contract TransceiverImpl is ITransceiver { + //======================= Interface ================================================= + // add this to be excluded from coverage report + function test() public {} + function getTransceiverType() public pure override returns (string memory) { return "test"; } @@ -77,16 +60,28 @@ contract TransceiverImpl is ITransceiver { } function sendMessage( - uint16 recipientChain, - bytes32 messageHash, - UniversalAddress recipientAddress, - bytes32 refundAddress + UniversalAddress, // sourceAddress, + uint16, // recipientChain, + UniversalAddress, // recipientAddress, + uint64, // sequence, + bytes32, // payloadHash, + bytes32 // refundAddress ) public payable override { - // Do nothing + messagesSent += 1; + } + + //======================= Implementation ================================================= + + uint256 public messagesSent; + + function getMessagesSent() public view returns (uint256) { + return messagesSent; } } contract RouterTest is Test { + uint16 constant OurChainId = 42; + RouterImpl public router; TransceiverImpl public transceiverImpl; @@ -100,112 +95,495 @@ contract RouterTest is Test { transceiverImpl = new TransceiverImpl(); } - function test_setSendTransceiver() public { - Integrator integrator = new Integrator(address(router)); - Admin admin = new Admin(address(integrator), address(router)); - Admin imposter = new Admin(address(integrator), address(router)); - address transceiver1 = address(0x111); - uint16 chain = 2; + function test_register() public { + address integrator = address(new Integrator(address(router))); + vm.startPrank(integrator); + + // Can't update the admin until we've set it. + vm.expectRevert(abi.encodeWithSelector(Router.IntegratorNotRegistered.selector)); + router.updateAdmin(integrator, address(0)); + + // Can't register admin of zero. + vm.expectRevert(abi.encodeWithSelector(Router.InvalidAdminZeroAddress.selector)); + router.register(address(0)); + + // But a real address should work. + Admin admin = new Admin(integrator, address(router)); + router.register(address(admin)); + require(router.getAdmin(integrator) == address(admin), "admin address is wrong"); + + // Can't register twice. + vm.expectRevert(abi.encodeWithSelector(Router.IntegratorAlreadyRegistered.selector)); + router.register(address(admin)); + + // Test updateAdmin(). + Admin newAdmin = new Admin(integrator, address(router)); + + // Only the admin can update. The integrator can't. + vm.expectRevert(abi.encodeWithSelector(Router.CallerNotAuthorized.selector)); + router.updateAdmin(integrator, address(newAdmin)); + + vm.startPrank(address(admin)); + + // We can set the admin to a new address. + router.updateAdmin(integrator, address(newAdmin)); + require(router.getAdmin(integrator) == address(newAdmin), "failed to update admin address"); + + // And the old admin should no longer be able to update. + vm.expectRevert(abi.encodeWithSelector(Router.CallerNotAuthorized.selector)); + router.updateAdmin(integrator, address(admin)); + + // But the new admin can. + vm.startPrank(address(newAdmin)); + Admin newerAdmin = new Admin(integrator, address(router)); + router.updateAdmin(integrator, address(newerAdmin)); + require(router.getAdmin(integrator) == address(newerAdmin), "failed to update admin address"); + + vm.startPrank(address(newerAdmin)); + // Two step update to first admin. + vm.expectRevert(abi.encodeWithSelector(Router.InvalidAdminZeroAddress.selector)); + router.transferAdmin(integrator, address(0)); + router.transferAdmin(integrator, address(admin)); + require(router.getAdmin(integrator) == address(newerAdmin), "updated admin address too early"); + vm.expectRevert(abi.encodeWithSelector(Router.AdminTransferInProgress.selector)); + router.transferAdmin(integrator, address(admin)); + vm.expectRevert(abi.encodeWithSelector(Router.AdminTransferInProgress.selector)); + router.updateAdmin(integrator, address(newerAdmin)); + vm.expectRevert(abi.encodeWithSelector(Router.AdminTransferInProgress.selector)); + router.discardAdmin(integrator); + // Test that the initiator can cancel the update. + router.claimAdmin(integrator); + require(router.getAdmin(integrator) == address(newerAdmin), "failed to update to first admin address"); + + // Two step update to new admin. + router.transferAdmin(integrator, address(newAdmin)); + require(router.getAdmin(integrator) == address(newerAdmin), "updated admin address too early"); + vm.startPrank(address(admin)); + vm.expectRevert(abi.encodeWithSelector(Router.CallerNotAuthorized.selector)); + router.claimAdmin(integrator); + // Test that the new admin can claim the update. + vm.startPrank(address(newAdmin)); + router.claimAdmin(integrator); + require(router.getAdmin(integrator) == address(newAdmin), "failed to update to new admin address"); + + // One step update to zero. + vm.expectRevert(abi.encodeWithSelector(Router.InvalidAdminZeroAddress.selector)); + router.updateAdmin(integrator, address(0)); + + router.discardAdmin(integrator); + } + + function test_addSendTransceiver() public { + address integrator = address(new Integrator(address(router))); + address admin = address(new Admin(integrator, address(router))); + address imposter = address(new Admin(integrator, address(router))); + TransceiverImpl transceiver1 = new TransceiverImpl(); + TransceiverImpl transceiver2 = new TransceiverImpl(); + TransceiverImpl transceiver3 = new TransceiverImpl(); + address taddr1 = address(transceiver1); + address taddr2 = address(transceiver2); + address taddr3 = address(transceiver3); + vm.startPrank(integrator); + + // Can't enable a transceiver until we've set the admin. + vm.expectRevert(abi.encodeWithSelector(Router.IntegratorNotRegistered.selector)); + router.addTransceiver(integrator, 1, taddr1); + + // Register the integrator and set the admin. + router.register(admin); + + // The admin can add a transceiver. + vm.startPrank(admin); + router.addTransceiver(integrator, 1, taddr1); + + // Others cannot add a transceiver. + vm.startPrank(imposter); + vm.expectRevert(abi.encodeWithSelector(Router.CallerNotAuthorized.selector)); + router.addTransceiver(integrator, 1, taddr1); + + // Can't register the transceiver twice. + vm.startPrank(admin); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyRegistered.selector, taddr1)); + router.addTransceiver(integrator, 1, taddr1); + // Can't enable the transceiver twice. + router.enableSendTransceiver(integrator, 1, taddr1); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyEnabled.selector, taddr1)); + router.enableSendTransceiver(integrator, 1, taddr1); + + router.addTransceiver(integrator, 1, taddr2); + address[] memory transceivers = router.getSendTransceiversByChain(integrator, 1); + require(transceivers.length == 1, "Wrong number of transceivers enabled on chain one, should be 1"); + // Enable another transceiver on chain one and one on chain two. + router.enableSendTransceiver(integrator, 1, taddr2); + router.addTransceiver(integrator, 2, taddr3); + router.enableSendTransceiver(integrator, 2, taddr3); + + // And verify they got set properly. + transceivers = router.getSendTransceiversByChain(integrator, 1); + require(transceivers.length == 2, "Wrong number of transceivers enabled on chain one"); + require(transceivers[0] == taddr1, "Wrong transceiver one on chain one"); + require(transceivers[1] == taddr2, "Wrong transceiver two on chain one"); + transceivers = router.getSendTransceiversByChain(integrator, 2); + require(transceivers.length == 1, "Wrong number of transceivers enabled on chain two"); + require(transceivers[0] == taddr3, "Wrong transceiver one on chain two"); + router.disableSendTransceiver(integrator, 2, taddr3); + require(transceivers.length == 1, "Wrong number of transceivers enabled on chain two"); + } + + function test_addRecvTransceiver() public { + address integrator = address(new Integrator(address(router))); + address admin = address(new Admin(integrator, address(router))); + address imposter = address(new Admin(integrator, address(router))); + TransceiverImpl transceiver1 = new TransceiverImpl(); + TransceiverImpl transceiver2 = new TransceiverImpl(); + TransceiverImpl transceiver3 = new TransceiverImpl(); + address taddr1 = address(transceiver1); + address taddr2 = address(transceiver2); + address taddr3 = address(transceiver3); + vm.startPrank(integrator); + + // Can't enable a transceiver until we've set the admin. + vm.expectRevert(abi.encodeWithSelector(Router.IntegratorNotRegistered.selector)); + router.addTransceiver(integrator, 1, taddr1); + + // Register the integrator and set the admin. + router.register(admin); + + // The admin can add a transceiver. + vm.startPrank(admin); + router.addTransceiver(integrator, 1, taddr1); + + // Others cannot add a transceiver. + vm.startPrank(imposter); + vm.expectRevert(abi.encodeWithSelector(Router.CallerNotAuthorized.selector)); + router.addTransceiver(integrator, 1, taddr1); - admin.requestAdmin(); - integrator.registerWithRouter(); - admin.setSendTransceiver(transceiver1, chain); - vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyEnabled.selector, transceiver1)); - admin.setSendTransceiver(transceiver1, chain); + // Can't register the transceiver twice. + vm.startPrank(admin); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyRegistered.selector, taddr1)); + router.addTransceiver(integrator, 1, taddr1); + // Can't enable the transceiver twice. + router.enableRecvTransceiver(integrator, 1, taddr1); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyEnabled.selector, taddr1)); + router.enableRecvTransceiver(integrator, 1, taddr1); - vm.expectRevert(abi.encodeWithSelector(Router.CallerNotAdmin.selector)); - imposter.setSendTransceiver(transceiver1, chain); - vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyEnabled.selector, transceiver1)); - admin.setSendTransceiver(transceiver1, chain); + router.addTransceiver(integrator, 1, taddr2); + address[] memory transceivers = router.getRecvTransceiversByChain(integrator, 1); + require(transceivers.length == 1, "Wrong number of transceivers enabled on chain one, should be 1"); + // Enable another transceiver on chain one and one on chain two. + router.enableRecvTransceiver(integrator, 1, taddr2); + router.addTransceiver(integrator, 2, taddr3); + router.enableRecvTransceiver(integrator, 2, taddr3); + + // And verify they got set properly. + transceivers = router.getRecvTransceiversByChain(integrator, 1); + require(transceivers.length == 2, "Wrong number of transceivers enabled on chain one"); + require(transceivers[0] == taddr1, "Wrong transceiver one on chain one"); + require(transceivers[1] == taddr2, "Wrong transceiver two on chain one"); + transceivers = router.getRecvTransceiversByChain(integrator, 2); + require(transceivers.length == 1, "Wrong number of transceivers enabled on chain two"); + require(transceivers[0] == taddr3, "Wrong transceiver one on chain two"); } - function test_setRecvTransceiver() public { - Integrator integrator = new Integrator(address(router)); - Admin admin = new Admin(address(integrator), address(router)); - Admin imposter = new Admin(address(integrator), address(router)); - address transceiver1 = address(0x111); + function test_sendMessage() public { + address integrator = address(new Integrator(address(router))); + address admin = address(new Admin(integrator, address(router))); uint16 chain = 2; + uint16 zeroChain = 0; + TransceiverImpl transceiver1 = new TransceiverImpl(); + TransceiverImpl transceiver2 = new TransceiverImpl(); + TransceiverImpl transceiver3 = new TransceiverImpl(); + vm.startPrank(integrator); + router.register(admin); + + // Sending with no transceivers should revert. + vm.startPrank(integrator); + vm.expectRevert(abi.encodeWithSelector(Router.TransceiverNotEnabled.selector)); + uint64 sequence = router.sendMessage(2, UniversalAddressLibrary.fromAddress(userA), refundAddr, messageHash); + + // Now enable some transceivers. + vm.startPrank(admin); + router.addTransceiver(integrator, 2, address(transceiver1)); + router.enableSendTransceiver(integrator, 2, address(transceiver1)); + router.addTransceiver(integrator, 2, address(transceiver2)); + router.enableSendTransceiver(integrator, 2, address(transceiver2)); + router.addTransceiver(integrator, 3, address(transceiver3)); + router.enableSendTransceiver(integrator, 3, address(transceiver3)); + + // Only an integrator can call send. + vm.startPrank(userA); + vm.expectRevert(abi.encodeWithSelector(Router.IntegratorNotRegistered.selector)); + sequence = router.sendMessage(chain, UniversalAddressLibrary.fromAddress(userA), refundAddr, messageHash); + + // Send a message on chain two. It should go out on the first two transceivers, but not the third one. + vm.startPrank(integrator); + sequence = router.sendMessage(chain, UniversalAddressLibrary.fromAddress(userA), refundAddr, messageHash); + require(sequence == 0, "Sequence number is wrong"); + require(transceiver1.getMessagesSent() == 1, "Failed to send a message on transceiver 1"); + require(transceiver2.getMessagesSent() == 1, "Failed to send a message on transceiver 2"); + require(transceiver3.getMessagesSent() == 0, "Should not have sent a message on transceiver 3"); + + sequence = router.sendMessage(chain, UniversalAddressLibrary.fromAddress(userA), refundAddr, messageHash); + require(sequence == 1, "Second sequence number is wrong"); + require(transceiver1.getMessagesSent() == 2, "Failed to send second message on transceiver 1"); + require(transceiver2.getMessagesSent() == 2, "Failed to send second message on transceiver 2"); + require(transceiver3.getMessagesSent() == 0, "Should not have sent second message on transceiver 3"); - admin.requestAdmin(); - integrator.registerWithRouter(); - admin.setRecvTransceiver(transceiver1, chain); - vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyEnabled.selector, transceiver1)); - admin.setRecvTransceiver(transceiver1, chain); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); + sequence = router.sendMessage(zeroChain, UniversalAddressLibrary.fromAddress(userA), refundAddr, messageHash); + require(sequence == 0, "Failed sequence number is wrong"); // 0 because of the revert - vm.expectRevert(abi.encodeWithSelector(Router.CallerNotAdmin.selector)); - imposter.setRecvTransceiver(transceiver1, chain); - vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyEnabled.selector, transceiver1)); - admin.setRecvTransceiver(transceiver1, chain); + sequence = router.sendMessage(chain, UniversalAddressLibrary.fromAddress(userA), refundAddr, messageHash); + require(sequence == 2, "Third sequence number is wrong"); } - function test_sendMessageIncrementsSequence() public { - Integrator integrator = new Integrator(address(router)); - Admin admin = new Admin(address(integrator), address(router)); - address transceiver1 = address(0x111); + function test_attestMessage() public { + UniversalAddress sourceIntegrator = UniversalAddressLibrary.fromAddress(address(userA)); + address integrator = address(new Integrator(address(router))); + UniversalAddress destIntegrator = UniversalAddressLibrary.fromAddress(address(integrator)); + address admin = address(new Admin(integrator, address(router))); + TransceiverImpl transceiver1 = new TransceiverImpl(); + TransceiverImpl transceiver2 = new TransceiverImpl(); + TransceiverImpl transceiver3 = new TransceiverImpl(); uint16 chain = 2; - admin.requestAdmin(); - integrator.registerWithRouter(); - admin.setSendTransceiver(transceiver1, chain); - assertEq(router.nextMessageSequence(address(integrator)), 0); - // Send inital message from userA, going from unset to 1 - // vm.startPrank(userA); + uint16 anotherChain = 1; + vm.startPrank(integrator); + router.register(admin); + + // Attesting with no transceivers should revert. + vm.startPrank(integrator); + vm.expectRevert(abi.encodeWithSelector(Router.TransceiverNotEnabled.selector)); + router.attestMessage(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); + + // Now enable some transceivers. + vm.startPrank(admin); + router.addTransceiver(integrator, chain, address(transceiver1)); + router.enableRecvTransceiver(integrator, chain, address(transceiver1)); + router.addTransceiver(integrator, chain, address(transceiver2)); + router.enableRecvTransceiver(integrator, chain, address(transceiver2)); + router.addTransceiver(integrator, chain + 1, address(transceiver3)); + router.enableRecvTransceiver(integrator, chain + 1, address(transceiver3)); + + // Only a transceiver can call attest. + vm.startPrank(userB); + vm.expectRevert(abi.encodeWithSelector(Router.TransceiverNotEnabled.selector)); + router.attestMessage(chain, sourceIntegrator, anotherChain, OurChainId, destIntegrator, messageHash); + + // Attestinging a message destined for the wrong chain should revert. + vm.startPrank(address(transceiver2)); + vm.expectRevert(abi.encodeWithSelector(Router.InvalidDestinationChain.selector)); + router.attestMessage(chain, sourceIntegrator, anotherChain, OurChainId + 1, destIntegrator, messageHash); + + // This attest should work. + vm.startPrank(address(transceiver2)); + router.attestMessage(chain, sourceIntegrator, anotherChain, OurChainId, destIntegrator, messageHash); + + // Receive what we just attested to mark it executed. + vm.startPrank(integrator); + router.recvMessage(chain, sourceIntegrator, anotherChain, OurChainId, destIntegrator, messageHash); + + // Attesting after receive should still work. + vm.startPrank(address(transceiver2)); + router.attestMessage(chain, sourceIntegrator, anotherChain, OurChainId, destIntegrator, messageHash); + + // Attesting on a disabled transceiver should revert. + vm.startPrank(admin); + router.disableRecvTransceiver(integrator, 2, address(transceiver1)); + vm.startPrank(address(transceiver1)); + vm.expectRevert(abi.encodeWithSelector(Router.TransceiverNotEnabled.selector)); + router.attestMessage(chain, sourceIntegrator, anotherChain, OurChainId, destIntegrator, messageHash); + } + + function test_recvMessage() public { + UniversalAddress sourceIntegrator = UniversalAddressLibrary.fromAddress(address(userA)); + address integrator = address(new Integrator(address(router))); + UniversalAddress destIntegrator = UniversalAddressLibrary.fromAddress(address(integrator)); + address admin = address(new Admin(integrator, address(router))); + TransceiverImpl transceiver1 = new TransceiverImpl(); + TransceiverImpl transceiver2 = new TransceiverImpl(); + TransceiverImpl transceiver3 = new TransceiverImpl(); + vm.startPrank(integrator); + router.register(admin); + + // Receiving with no transceivers should revert. + vm.startPrank(integrator); + vm.expectRevert(abi.encodeWithSelector(Router.TransceiverNotEnabled.selector)); + router.recvMessage(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); + + // Now enable some transceivers so we can attest. Receive doesn't use the transceivers. + vm.startPrank(admin); + router.addTransceiver(integrator, 2, address(transceiver1)); + router.enableRecvTransceiver(integrator, 2, address(transceiver1)); + router.addTransceiver(integrator, 2, address(transceiver2)); + router.enableRecvTransceiver(integrator, 2, address(transceiver2)); + router.addTransceiver(integrator, 3, address(transceiver3)); + router.enableRecvTransceiver(integrator, 3, address(transceiver3)); + + // Only an integrator can call receive. + vm.startPrank(userB); + vm.expectRevert(abi.encodeWithSelector(Router.IntegratorNotRegistered.selector)); + router.recvMessage(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); + + // Receiving a message destine for the wrong chain should revert. + vm.startPrank(integrator); + vm.expectRevert(abi.encodeWithSelector(Router.InvalidDestinationChain.selector)); + router.recvMessage(2, sourceIntegrator, 1, OurChainId + 1, destIntegrator, messageHash); + + // Receiving before there are any attestations should revert. + vm.startPrank(integrator); + vm.expectRevert(abi.encodeWithSelector(Router.UnknownMessageAttestation.selector)); + router.recvMessage(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); + + // Attest so we can receive. + vm.startPrank(address(transceiver2)); + router.attestMessage(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); + + // This receive should work. + vm.startPrank(integrator); + (uint128 enabledBitmap, uint128 attestedBitmap) = + router.recvMessage(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); + + // Make sure it return the right bitmaps. + require(enabledBitmap == 0x03, "Enabled bitmap is wrong"); + require(attestedBitmap == 0x02, "Attested bitmap is wrong"); + + // But doing it again should revert. + vm.expectRevert(abi.encodeWithSelector(Router.AlreadyExecuted.selector)); + router.recvMessage(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); + } + + function test_getMessageStatus() public { + UniversalAddress sourceIntegrator = UniversalAddressLibrary.fromAddress(address(userA)); + address integrator = address(new Integrator(address(router))); + UniversalAddress destIntegrator = UniversalAddressLibrary.fromAddress(address(integrator)); + address admin = address(new Admin(integrator, address(router))); + TransceiverImpl transceiver1 = new TransceiverImpl(); + TransceiverImpl transceiver2 = new TransceiverImpl(); + TransceiverImpl transceiver3 = new TransceiverImpl(); + vm.startPrank(integrator); + router.register(admin); + + // No transceivers should revert. + vm.startPrank(integrator); // vm.expectRevert(abi.encodeWithSelector(Router.TransceiverNotEnabled.selector)); - // router.sendMessage(1, UniversalAddressLibrary.fromAddress(userB), refundAddr, messageHash); - // assertEq(router.nextMessageSequence(userA), 0); - // address me = address(this); - // transceiverRegistry.registerAdmin(me); - // // Send additional message from userA, incrementing the existing sequence - // router.sendMessage(1, UniversalAddressLibrary.fromAddress(userB), refundAddr, messageHash); - // assertEq(router.nextMessageSequence(userA), 2); - } - - function testFuzz_sendMessage(address user) public { - // uint16 chainId = 2; - vm.startPrank(user); - // Register a transceiver - // Integrator integrator = new Integrator(address(router)); - // uint64 beforeSequence = router.nextMessageSequence(address(this)); - // address transceiver = address(transceiverImpl); - // integrator.setMeAsDelegate(); - // uint8 index = router.setSendTransceiver(transceiver, chainId); - // assert(index == 0); - // address[] memory enabledTransceivers = router.getSendTransceiversByChain(address(this), chainId); - // assert(enabledTransceivers.length == 1); - // router.sendMessage(chainId, UniversalAddressLibrary.fromAddress(user), refundAddr, messageHash); - // assertEq(router.nextMessageSequence(address(this)), beforeSequence + 1); - } - - function testFuzz_receiveMessage(address user) public { - // uint16 chainId = 2; - vm.startPrank(user); - // Integrator integrator = new Integrator(address(router)); - // address transceiver = address(0x111); - // integrator.setRecvTransceiver(transceiver, chainId); - // address[] memory enabledTransceivers = router.getRecvTransceiversByChain(address(integrator), chainId); - // assert(enabledTransceivers.length == 1); - // router.receiveMessage(1, UniversalAddressLibrary.fromAddress(user), refundAddr, messageHash); - } - - function testFuzz_attestMessage(address user) public { - // uint16 srcChain = 2; - // uint16 dstChain = 3; - // uint64 sequence = 1; - // bytes32 payloadHash = keccak256("hello, world"); - // sourceAddress = UniversalAddressLibrary.fromAddress(user); - // destinationAddress = UniversalAddressLibrary.fromAddress(user); - // vm.startPrank(user); - // Integrator integrator = new Integrator(address(router)); - // address transceiver = address(0x111); - // integrator.setRecvTransceiver(transceiver, dstChain); - // address[] memory enabledTransceivers = router.getRecvTransceiversByChain(address(integrator), dstChain); - // assert(enabledTransceivers.length == 1); - // router.attestMessage( - // srcChain, - // UniversalAddressLibrary.fromAddress(user), - // sequence, - // dstChain, - // UniversalAddressLibrary.fromAddress(user), - // payloadHash - // ); + router.getMessageStatus(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); + + // Now enable some transceivers so we can attest. Receive doesn't use the transceivers. + vm.startPrank(admin); + router.addTransceiver(integrator, 2, address(transceiver1)); + router.enableRecvTransceiver(integrator, 2, address(transceiver1)); + router.addTransceiver(integrator, 2, address(transceiver2)); + router.enableRecvTransceiver(integrator, 2, address(transceiver2)); + router.addTransceiver(integrator, 3, address(transceiver3)); + router.enableRecvTransceiver(integrator, 3, address(transceiver3)); + + // Only an integrator can call receive. + vm.startPrank(userB); + vm.expectRevert(abi.encodeWithSelector(Router.IntegratorNotRegistered.selector)); + router.getMessageStatus(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); + + // Receiving a message destine for the wrong chain should revert. + vm.startPrank(integrator); + vm.expectRevert(abi.encodeWithSelector(Router.InvalidDestinationChain.selector)); + router.getMessageStatus(2, sourceIntegrator, 1, OurChainId + 1, destIntegrator, messageHash); + + // Receiving before there are any attestations should revert. + vm.startPrank(integrator); + // vm.expectRevert(abi.encodeWithSelector(Router.UnknownMessageAttestation.selector)); + router.getMessageStatus(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); + + // Attest so we can receive. + vm.startPrank(address(transceiver2)); + router.attestMessage(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); + + // This receive should work. + vm.startPrank(integrator); + (uint128 enabledBitmap, uint128 attestedBitmap, bool executed) = + router.getMessageStatus(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); + + // Make sure it return the right bitmaps. + require(enabledBitmap == 0x03, "Enabled bitmap is wrong"); + require(attestedBitmap == 0x02, "Attested bitmap is wrong"); + require(executed == false, "executed flag is wrong"); + } + + function test_execMessage() public { + address integrator = address(new Integrator(address(router))); + address admin = address(new Admin(integrator, address(router))); + TransceiverImpl transceiver1 = new TransceiverImpl(); + TransceiverImpl transceiver2 = new TransceiverImpl(); + uint16 chain1 = 1; + uint64 sequence = 1; + uint128 enabledBitmap; + uint128 attestedBitmap; + bool executed; + + // Register the integrator and set the admin. + vm.startPrank(integrator); + router.register(admin); + + (enabledBitmap, attestedBitmap, executed) = router.getMessageStatus( + chain1, + UniversalAddressLibrary.fromAddress(address(userA)), + sequence, + OurChainId, + UniversalAddressLibrary.fromAddress(address(integrator)), + messageHash + ); + require(executed == false, "executed flag should be false before execMessage"); + vm.expectRevert(abi.encodeWithSelector(Router.InvalidDestinationChain.selector)); + router.execMessage( + chain1, + UniversalAddressLibrary.fromAddress(address(transceiver1)), + sequence, + chain1, + UniversalAddressLibrary.fromAddress(address(transceiver2)), + messageHash + ); + router.execMessage( + chain1, + UniversalAddressLibrary.fromAddress(address(transceiver1)), + sequence, + OurChainId, + UniversalAddressLibrary.fromAddress(address(transceiver2)), + messageHash + ); + (enabledBitmap, attestedBitmap, executed) = router.getMessageStatus( + chain1, + UniversalAddressLibrary.fromAddress(address(userA)), + sequence, + OurChainId, + UniversalAddressLibrary.fromAddress(address(integrator)), + messageHash + ); + // require(executed == true, "executed flag should be true after execMessage"); + // Second execMessage should revert. + vm.expectRevert(abi.encodeWithSelector(Router.AlreadyExecuted.selector)); + router.execMessage( + chain1, + UniversalAddressLibrary.fromAddress(address(transceiver1)), + sequence, + OurChainId, + UniversalAddressLibrary.fromAddress(address(transceiver2)), + messageHash + ); + } + + function test_computeMessageHash() public view { + UniversalAddress sourceIntegrator = UniversalAddressLibrary.fromAddress(address(userA)); + UniversalAddress destIntegrator = UniversalAddressLibrary.fromAddress(address(userB)); + uint16 srcChain = 2; + uint16 dstChain = 42; + uint64 sequence = 3; + bytes32 payloadHash = keccak256("hello, world"); + bytes32 myMessageHash = + router.myComputeMessageHash(srcChain, sourceIntegrator, sequence, dstChain, destIntegrator, payloadHash); + bytes32 expectedHash = + keccak256(abi.encodePacked(srcChain, sourceIntegrator, sequence, dstChain, destIntegrator, payloadHash)); + require(myMessageHash == expectedHash, "Message hash is wrong"); + require( + myMessageHash == 0xf589999616054a74b876390c4eb6e067da272da5cd313a9657d33ec3cab06760, + "Message hash literal is wrong" + ); } } diff --git a/evm/test/TransceiverRegistry.t.sol b/evm/test/TransceiverRegistry.t.sol index 0a9ab406..5d56c97e 100644 --- a/evm/test/TransceiverRegistry.t.sol +++ b/evm/test/TransceiverRegistry.t.sol @@ -5,28 +5,16 @@ import {Test, console} from "forge-std/Test.sol"; import "../src/TransceiverRegistry.sol"; contract ConcreteTransceiverRegistry is TransceiverRegistry { - function rmvSendTransceiver(address integrator, address transceiver, uint16 chain) public { - _disableSendTransceiver(integrator, transceiver, chain); + function addTransceiver(address integrator, uint16 chain, address transceiver) public returns (uint8 index) { + return _addTransceiver(integrator, chain, transceiver); } - function rmvRecvTransceiver(address integrator, address transceiver, uint16 chain) public { - _disableRecvTransceiver(integrator, transceiver, chain); + function disableSendTransceiver(address integrator, uint16 chain, address transceiver) public { + _disableSendTransceiver(integrator, chain, transceiver); } - function setSendTransceiver(address integrator, address transceiver, uint16 chain) public returns (uint8 index) { - return _setSendTransceiver(integrator, transceiver, chain); - } - - function setRecvTransceiver(address integrator, address transceiver, uint16 chain) public returns (uint8 index) { - return _setRecvTransceiver(integrator, transceiver, chain); - } - - function disableSendTransceiver(address integrator, address transceiver, uint16 chain) public { - _disableSendTransceiver(integrator, transceiver, chain); - } - - function disableRecvTransceiver(address integrator, address transceiver, uint16 chain) public { - _disableRecvTransceiver(integrator, transceiver, chain); + function disableRecvTransceiver(address integrator, uint16 chain, address transceiver) public { + _disableRecvTransceiver(integrator, chain, transceiver); } function getRegisteredTransceiversStorage(address integrator) public view returns (address[] memory $) { @@ -40,9 +28,9 @@ contract ConcreteTransceiverRegistry is TransceiverRegistry { function getEnabledSendTransceiversBitmapForChain(address integrator, uint16 chain) public view - returns (uint128 bitmap) + returns (address[] memory transceivers) { - return _getEnabledSendTransceiversBitmapForChain(integrator, chain); + return _getEnabledSendTransceiversArrayForChain(integrator, chain); } function getEnabledRecvTransceiversBitmapForChain(address integrator, uint16 chain) @@ -53,28 +41,28 @@ contract ConcreteTransceiverRegistry is TransceiverRegistry { return _getEnabledRecvTransceiversBitmapForChain(integrator, chain); } - function enableSendTransceiverForChain(address integrator, address transceiver, uint16 chainId) public { - _enableSendTransceiverForChain(integrator, transceiver, chainId); + function enableSendTransceiver(address integrator, uint16 chainId, address transceiver) public { + _enableSendTransceiver(integrator, chainId, transceiver); } - function enableRecvTransceiverForChain(address integrator, address transceiver, uint16 chainId) public { - _enableRecvTransceiverForChain(integrator, transceiver, chainId); + function enableRecvTransceiver(address integrator, uint16 chainId, address transceiver) public { + _enableRecvTransceiver(integrator, chainId, transceiver); } - function isSendTransceiverEnabledForChain(address integrator, address transceiver, uint16 chainId) + function isSendTransceiverEnabledForChain(address integrator, uint16 chainId, address transceiver) public view returns (bool) { - return _isSendTransceiverEnabledForChain(integrator, transceiver, chainId); + return _isSendTransceiverEnabledForChain(integrator, chainId, transceiver); } - function isRecvTransceiverEnabledForChain(address integrator, address transceiver, uint16 chainId) + function isRecvTransceiverEnabledForChain(address integrator, uint16 chainId, address transceiver) public view returns (bool) { - return _isRecvTransceiverEnabledForChain(integrator, transceiver, chainId); + return _isRecvTransceiverEnabledForChain(integrator, chainId, transceiver); } function getMaxTransceivers() public pure returns (uint8) { @@ -107,81 +95,90 @@ contract TransceiverRegistryTest is Test { // Send side assertEq(transceiverRegistry.getTransceivers(me).length, 0); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); - transceiverRegistry.setSendTransceiver(me, sendTransceiver, zeroChain); - transceiverRegistry.setSendTransceiver(me, sendTransceiver, chain); - vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); - transceiverRegistry.disableSendTransceiver(me, sendTransceiver, zeroChain); + transceiverRegistry.addTransceiver(me, zeroChain, sendTransceiver); + transceiverRegistry.addTransceiver(me, chain, sendTransceiver); // Recv side - // Transceiver was registered on the send side + // A transceiver was registered on the send side assertEq(transceiverRegistry.getTransceivers(me).length, 1); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); - transceiverRegistry.setRecvTransceiver(me, sendTransceiver, zeroChain); - transceiverRegistry.setRecvTransceiver(me, recvTransceiver, chain); - vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); - transceiverRegistry.disableRecvTransceiver(me, recvTransceiver, zeroChain); + transceiverRegistry.addTransceiver(me, zeroChain, recvTransceiver); + transceiverRegistry.addTransceiver(me, chain, recvTransceiver); } function test3() public { + // Need to add transceiver, then enable it, then disable it address me = address(this); // Send side vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.NonRegisteredTransceiver.selector, sendTransceiver)); - transceiverRegistry.rmvSendTransceiver(me, sendTransceiver, chain); + transceiverRegistry.disableSendTransceiver(me, chain, sendTransceiver); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector)); - transceiverRegistry.setSendTransceiver(me, zeroTransceiver, chain); - assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, 0); - assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, 0); - transceiverRegistry.setSendTransceiver(me, sendTransceiver, chain); - assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, 1); - assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, 1); + transceiverRegistry.addTransceiver(me, chain, zeroTransceiver); + require(transceiverRegistry.getRegisteredTransceiversStorage(me).length == 0, "S1"); + require(transceiverRegistry.getNumTransceiversStorage(me).registered == 0, "S2"); + transceiverRegistry.addTransceiver(me, chain, sendTransceiver); + require(transceiverRegistry.getRegisteredTransceiversStorage(me).length == 1, "S3"); + require(transceiverRegistry.getNumTransceiversStorage(me).registered == 1, "S4"); // assertEq(transceiverRegistry.getSendTransceiverInfos(integrator1).length, 1); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); - transceiverRegistry.disableSendTransceiver(me, sendTransceiver, zeroChain); - assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, 1); - assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, 1); - transceiverRegistry.disableSendTransceiver(me, sendTransceiver, chain); - // disabled, but stays registered - assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, 1); - assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, 1); - vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.DisabledTransceiver.selector, sendTransceiver)); - transceiverRegistry.disableSendTransceiver(me, sendTransceiver, chain); + transceiverRegistry.disableSendTransceiver(me, zeroChain, sendTransceiver); + require(transceiverRegistry.getRegisteredTransceiversStorage(me).length == 1, "S5"); + require(transceiverRegistry.getNumTransceiversStorage(me).registered == 1, "S6"); + vm.expectRevert( + abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyDisabled.selector, sendTransceiver) + ); + transceiverRegistry.disableSendTransceiver(me, chain, sendTransceiver); + transceiverRegistry.enableSendTransceiver(me, chain, sendTransceiver); + require(transceiverRegistry.getRegisteredTransceiversStorage(me).length == 1, "S7"); + require(transceiverRegistry.getNumTransceiversStorage(me).registered == 1, "S8"); + transceiverRegistry.disableSendTransceiver(me, chain, sendTransceiver); + vm.expectRevert( + abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyDisabled.selector, sendTransceiver) + ); + transceiverRegistry.disableSendTransceiver(me, chain, sendTransceiver); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector)); - transceiverRegistry.disableSendTransceiver(me, zeroTransceiver, chain); + transceiverRegistry.disableSendTransceiver(me, chain, zeroTransceiver); // assertEq(transceiverRegistry.getSendTransceiverInfos(integrator1).length, 0); // Recv side vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.NonRegisteredTransceiver.selector, recvTransceiver)); - transceiverRegistry.rmvRecvTransceiver(me, recvTransceiver, chain); + transceiverRegistry.disableRecvTransceiver(me, chain, recvTransceiver); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector)); - transceiverRegistry.setRecvTransceiver(me, zeroTransceiver, chain); + transceiverRegistry.addTransceiver(me, chain, zeroTransceiver); // Carry over from send side test - assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, 1); - assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, 1); - transceiverRegistry.setRecvTransceiver(me, recvTransceiver, chain); - assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, 2); - assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, 2); + require(transceiverRegistry.getRegisteredTransceiversStorage(me).length == 1, "R1"); + require(transceiverRegistry.getNumTransceiversStorage(me).registered == 1, "R2"); + transceiverRegistry.addTransceiver(me, chain, recvTransceiver); + require(transceiverRegistry.getRegisteredTransceiversStorage(me).length == 2, "R3"); + require(transceiverRegistry.getNumTransceiversStorage(me).registered == 2, "R4"); // assertEq(transceiverRegistry.getRecvTransceiverInfos(integrator1).length, 1); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); - transceiverRegistry.disableRecvTransceiver(me, recvTransceiver, zeroChain); - assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, 2); - assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, 2); - transceiverRegistry.disableRecvTransceiver(me, recvTransceiver, chain); - // disabled, but stays registered - assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, 2); - assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, 2); - vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.DisabledTransceiver.selector, recvTransceiver)); - transceiverRegistry.disableRecvTransceiver(me, recvTransceiver, chain); + transceiverRegistry.disableRecvTransceiver(me, zeroChain, recvTransceiver); + require(transceiverRegistry.getRegisteredTransceiversStorage(me).length == 2, "R5"); + require(transceiverRegistry.getNumTransceiversStorage(me).registered == 2, "R6"); + vm.expectRevert( + abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyDisabled.selector, recvTransceiver) + ); + transceiverRegistry.disableRecvTransceiver(me, chain, recvTransceiver); + transceiverRegistry.enableRecvTransceiver(me, chain, recvTransceiver); + require(transceiverRegistry.getRegisteredTransceiversStorage(me).length == 2, "R7"); + require(transceiverRegistry.getNumTransceiversStorage(me).registered == 2, "R8"); + transceiverRegistry.disableRecvTransceiver(me, chain, recvTransceiver); + vm.expectRevert( + abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyDisabled.selector, recvTransceiver) + ); + transceiverRegistry.disableRecvTransceiver(me, chain, recvTransceiver); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector)); - transceiverRegistry.disableRecvTransceiver(me, zeroTransceiver, chain); + transceiverRegistry.disableRecvTransceiver(me, chain, zeroTransceiver); // assertEq(transceiverRegistry.getRecvTransceiverInfos(integrator1).length, 0); } function test4() public { // Send side assertEq(transceiverRegistry.getRegisteredTransceiversStorage(integrator1).length, 0); - assertEq(transceiverRegistry.getEnabledSendTransceiversBitmapForChain(integrator1, chain), 0); + assertEq(transceiverRegistry.getEnabledSendTransceiversBitmapForChain(integrator1, chain).length, 0); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); - assertEq(transceiverRegistry.getEnabledSendTransceiversBitmapForChain(integrator1, zeroChain), 0); + assertEq(transceiverRegistry.getEnabledSendTransceiversBitmapForChain(integrator1, zeroChain).length, 0); // Recv side assertEq(transceiverRegistry.getRegisteredTransceiversStorage(integrator1).length, 0); @@ -217,29 +214,32 @@ contract TransceiverRegistryTest is Test { // Send side address sTransceiver = address(0x456); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.NonRegisteredTransceiver.selector, sTransceiver)); - transceiverRegistry.enableSendTransceiverForChain(me, sTransceiver, chain); + transceiverRegistry.enableSendTransceiver(me, chain, sTransceiver); // Recv side address rTransceiver = address(0x567); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.NonRegisteredTransceiver.selector, rTransceiver)); - transceiverRegistry.enableRecvTransceiverForChain(me, rTransceiver, chain); + transceiverRegistry.enableRecvTransceiver(me, chain, rTransceiver); } function test8() public { + uint16 zeroChainId = 0; uint16 chainId = 3; address me = address(this); // Send side vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector)); - transceiverRegistry.enableSendTransceiverForChain(me, zeroTransceiver, chainId); + transceiverRegistry.enableSendTransceiver(me, chainId, zeroTransceiver); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector)); - transceiverRegistry.isSendTransceiverEnabledForChain(me, zeroTransceiver, chainId); + transceiverRegistry.isSendTransceiverEnabledForChain(me, chainId, zeroTransceiver); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChainId)); + transceiverRegistry.isSendTransceiverEnabledForChain(me, zeroChainId, me); - // Recv side + // // Recv side vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector)); - transceiverRegistry.enableRecvTransceiverForChain(me, zeroTransceiver, chainId); + transceiverRegistry.enableRecvTransceiver(me, chainId, zeroTransceiver); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector)); - transceiverRegistry.isRecvTransceiverEnabledForChain(me, zeroTransceiver, chainId); + transceiverRegistry.isRecvTransceiverEnabledForChain(me, chainId, zeroTransceiver); } function test9() public { @@ -248,27 +248,27 @@ contract TransceiverRegistryTest is Test { // Send side address sTransceiver = address(0x345); - assertEq(transceiverRegistry.isSendTransceiverEnabledForChain(me, sTransceiver, chainId), false); - transceiverRegistry.setSendTransceiver(me, sTransceiver, chain); - assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, 1); - assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, 1); - transceiverRegistry.enableSendTransceiverForChain(me, sTransceiver, chainId); - bool enabled = transceiverRegistry.isSendTransceiverEnabledForChain(me, sTransceiver, chainId); - assertEq(enabled, true); - vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyEnabled.selector, sTransceiver)); - transceiverRegistry.setSendTransceiver(me, sTransceiver, chain); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.NonRegisteredTransceiver.selector, sTransceiver)); + require(transceiverRegistry.isSendTransceiverEnabledForChain(me, chainId, sTransceiver) == false, "S1"); + transceiverRegistry.addTransceiver(me, chain, sTransceiver); + require(transceiverRegistry.getRegisteredTransceiversStorage(me).length == 1, "S2"); + require(transceiverRegistry.getNumTransceiversStorage(me).registered == 1, "S3"); + transceiverRegistry.enableSendTransceiver(me, chainId, sTransceiver); + bool enabled = transceiverRegistry.isSendTransceiverEnabledForChain(me, chainId, sTransceiver); + require(enabled == true, "S4"); + transceiverRegistry.enableSendTransceiver(me, chain, sTransceiver); // Recv side address rTransceiver = address(0x453); - assertEq(transceiverRegistry.isRecvTransceiverEnabledForChain(me, rTransceiver, chainId), false); - transceiverRegistry.setRecvTransceiver(me, rTransceiver, chain); - assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, 2); - assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, 2); - transceiverRegistry.enableRecvTransceiverForChain(me, rTransceiver, chainId); - enabled = transceiverRegistry.isRecvTransceiverEnabledForChain(me, rTransceiver, chainId); - assertEq(enabled, true); - vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyEnabled.selector, rTransceiver)); - transceiverRegistry.setRecvTransceiver(me, rTransceiver, chain); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.NonRegisteredTransceiver.selector, rTransceiver)); + require(transceiverRegistry.isRecvTransceiverEnabledForChain(me, chainId, rTransceiver) == false, "R1"); + transceiverRegistry.addTransceiver(me, chain, rTransceiver); + require(transceiverRegistry.getRegisteredTransceiversStorage(me).length == 2, "R2"); + require(transceiverRegistry.getNumTransceiversStorage(me).registered == 2, "R3"); + transceiverRegistry.enableRecvTransceiver(me, chainId, rTransceiver); + enabled = transceiverRegistry.isRecvTransceiverEnabledForChain(me, chainId, rTransceiver); + require(enabled == true, "R4"); + transceiverRegistry.enableRecvTransceiver(me, chain, rTransceiver); } function test10() public { @@ -277,19 +277,25 @@ contract TransceiverRegistryTest is Test { // Send side for (uint8 i = 0; i < maxTransceivers; i++) { - transceiverRegistry.setSendTransceiver(me, address(uint160(i + 1)), chain); + transceiverRegistry.addTransceiver(me, chain, address(uint160(i + 1))); } assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, maxTransceivers); assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, maxTransceivers); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TooManyTransceivers.selector)); - transceiverRegistry.setSendTransceiver(me, address(0x111), chain); + transceiverRegistry.addTransceiver(me, chain, address(0x111)); assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, maxTransceivers); assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, maxTransceivers); - transceiverRegistry.disableSendTransceiver(me, address(0x1), chain); assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, maxTransceivers); assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, maxTransceivers); - vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TooManyTransceivers.selector)); - transceiverRegistry.setSendTransceiver(me, address(0x111), chain); + for (uint8 i = 0; i < maxTransceivers; i++) { + transceiverRegistry.enableSendTransceiver(me, chain, address(uint160(i + 1))); + } + transceiverRegistry.disableSendTransceiver(me, chain, address(uint160(30))); + vm.expectRevert( + abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyDisabled.selector, address(uint160(30))) + ); + transceiverRegistry.disableSendTransceiver(me, chain, address(uint160(30))); + transceiverRegistry.getSendTransceiversByChain(me, chain); } function test11() public { @@ -298,18 +304,73 @@ contract TransceiverRegistryTest is Test { // Recv side for (uint8 i = 0; i < maxTransceivers; i++) { - transceiverRegistry.setRecvTransceiver(me, address(uint160(i + 1)), chain); + transceiverRegistry.addTransceiver(me, chain, address(uint160(i + 1))); } assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, maxTransceivers); assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, maxTransceivers); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TooManyTransceivers.selector)); - transceiverRegistry.setRecvTransceiver(me, address(0x111), chain); + transceiverRegistry.addTransceiver(me, chain, address(0x111)); assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, maxTransceivers); assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, maxTransceivers); - transceiverRegistry.disableRecvTransceiver(me, address(0x1), chain); + transceiverRegistry.enableRecvTransceiver(me, chain, address(0x1)); + transceiverRegistry.enableRecvTransceiver(me, chain, address(0x2)); + transceiverRegistry.disableRecvTransceiver(me, chain, address(0x2)); + transceiverRegistry.disableRecvTransceiver(me, chain, address(0x1)); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyDisabled.selector, address(0x1))); + transceiverRegistry.disableSendTransceiver(me, chain, address(0x1)); assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, maxTransceivers); assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, maxTransceivers); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TooManyTransceivers.selector)); - transceiverRegistry.setRecvTransceiver(me, address(0x111), chain); + transceiverRegistry.addTransceiver(me, chain, address(0x111)); + } + + function test_getSendTransceiversByChain() public { + address me = address(this); + uint16 chain1 = 1; + uint16 chain2 = 2; + address transceiver1 = address(0x1); // enabled, chain 1 + address transceiver2 = address(0x2); // enabled, chain 1 + address transceiver3 = address(0x3); // enabled, chain 2 + address transceiver4 = address(0x4); // disabled, chain 2 + + transceiverRegistry.addTransceiver(me, chain1, transceiver1); + transceiverRegistry.enableSendTransceiver(me, chain1, transceiver1); + transceiverRegistry.addTransceiver(me, chain1, transceiver2); + transceiverRegistry.enableSendTransceiver(me, chain1, transceiver2); + transceiverRegistry.addTransceiver(me, chain2, transceiver3); + transceiverRegistry.enableSendTransceiver(me, chain2, transceiver3); + transceiverRegistry.addTransceiver(me, chain2, transceiver4); + address[] memory chain1Addrs = transceiverRegistry.getSendTransceiversByChain(me, chain1); + require(chain1Addrs.length == 2, "Wrong number of transceivers enabled on chain one"); + address[] memory chain2Addrs = transceiverRegistry.getSendTransceiversByChain(me, chain2); + require(chain2Addrs.length == 1, "Wrong number of transceivers enabled on chain two"); + transceiverRegistry.enableSendTransceiver(me, chain2, transceiver4); + transceiverRegistry.disableSendTransceiver(me, chain2, transceiver3); + require(chain2Addrs.length == 1, "Wrong number of transceivers enabled on chain two"); + } + + function test_getRecvTransceiversByChain() public { + address me = address(this); + uint16 chain1 = 1; + uint16 chain2 = 2; + address transceiver1 = address(0x1); // enabled, chain 1 + address transceiver2 = address(0x2); // enabled, chain 1 + address transceiver3 = address(0x3); // enabled, chain 2 + address transceiver4 = address(0x4); // disabled, chain 2 + + transceiverRegistry.addTransceiver(me, chain1, transceiver1); + transceiverRegistry.enableRecvTransceiver(me, chain1, transceiver1); + transceiverRegistry.addTransceiver(me, chain1, transceiver2); + transceiverRegistry.enableRecvTransceiver(me, chain1, transceiver2); + transceiverRegistry.addTransceiver(me, chain2, transceiver3); + transceiverRegistry.enableRecvTransceiver(me, chain2, transceiver3); + transceiverRegistry.addTransceiver(me, chain2, transceiver4); + address[] memory chain1Addrs = transceiverRegistry.getRecvTransceiversByChain(me, chain1); + require(chain1Addrs.length == 2, "Wrong number of transceivers enabled on chain one"); + address[] memory chain2Addrs = transceiverRegistry.getRecvTransceiversByChain(me, chain2); + require(chain2Addrs.length == 1, "Wrong number of transceivers enabled on chain two"); + transceiverRegistry.enableRecvTransceiver(me, chain2, transceiver4); + transceiverRegistry.disableRecvTransceiver(me, chain2, transceiver3); + require(chain2Addrs.length == 1, "Wrong number of transceivers enabled on chain two"); } } From 6b3a938f0ad7e20bee6be4e11e041154f8ac8bfd Mon Sep 17 00:00:00 2001 From: Paul Noel Date: Thu, 24 Oct 2024 08:24:48 -0400 Subject: [PATCH 3/5] evm: changes from PR comments --- evm/README.md | 10 ++ evm/src/Router.sol | 100 +++++++++--------- evm/src/TransceiverRegistry.sol | 128 +++++++---------------- evm/src/interfaces/IRouterAdmin.sol | 5 +- evm/src/interfaces/IRouterIntegrator.sol | 6 +- evm/src/interfaces/ITransceiver.sol | 8 +- evm/test/Router.t.sol | 84 ++++++++------- evm/test/TransceiverRegistry.t.sol | 107 ++++++++----------- 8 files changed, 196 insertions(+), 252 deletions(-) diff --git a/evm/README.md b/evm/README.md index 45911f2a..1c713f20 100644 --- a/evm/README.md +++ b/evm/README.md @@ -122,6 +122,16 @@ struct AttestationInfo { // Integrator (message recipient) => message digest -> attestation info mapping(address => mapping(bytes32 => AttestationInfo)) perIntegratorAttestations; + +struct IntegratorConfig { + bool isInitialized; + address admin; + address pending_admin; +} + +// Integrator address => configuration information +// Used by Router to maintain admin information +mapping(address => IntegratorConfig) integratorConfigs ``` ## Development diff --git a/evm/src/Router.sol b/evm/src/Router.sol index ec10291a..986eb79b 100644 --- a/evm/src/Router.sol +++ b/evm/src/Router.sol @@ -14,7 +14,7 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS struct IntegratorConfig { bool isInitialized; address admin; - address transfer; + address pending_admin; } // =============== Immutables ============================================================ @@ -55,20 +55,20 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS event AdminUpdateRequested(address integrator, address oldAdmin, address newAdmin); /// @notice Emitted when a message has been sent. - /// @param messageHash The keccak256 of the message. It is, also, indexed. + /// @param messageDigest The keccak256 of the provided fields. It is, also, indexed. /// @param sender The address of the sender. + /// @param sequence The sequence of the message. /// @param recipient The address of the recipient. /// @param recipientChain The chainId of the recipient. - /// @param sequence The sequence of the message. - /// @param digest The digest of the message. + /// @param payloadDigest The digest of the payload (from the integrator). /// @dev Topic0 0x1c170583317700fb71bc583fa6fdd8ff893f6c3a15a79104f1681d6d9eb708ee event MessageSent( - bytes32 indexed messageHash, + bytes32 indexed messageDigest, UniversalAddress sender, + uint64 sequence, UniversalAddress recipient, uint16 recipientChain, - uint64 sequence, - bytes32 digest + bytes32 payloadDigest ); /// @notice Emitted when a message has been attested to. @@ -135,6 +135,10 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS /// @dev Selector: 0xc78a581c. error AdminTransferInProgress(); + /// @notice Error when there was an attempt to claim the admin while no transfer was in progress. + /// @dev Selector: 0xe8aba8ca. + error NoAdminTransferInProgress(); + /// @notice Error when the integrator tries to re-register. /// @dev Selector: 0x626bb491. error IntegratorAlreadyRegistered(); @@ -151,6 +155,10 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS /// @dev Selector: 0x1547aa01. error UnknownMessageAttestation(); + /// @notice Error when message is attempted to be attested multiple times. + /// @dev Selector: 0x833e2681. + error DuplicateMessageAttestation(); + /// @notice Error when message is already marked as executed. /// @dev Selector: 0x0dc10197. error AlreadyExecuted(); @@ -175,7 +183,7 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS } /// @dev Holds the integrator address to message digest to attestation info mapping. - /// mapping(address => IntegratorConfig) + /// mapping(address => mapping(bytes32 => AttestationInfo) bytes32 private constant ATTESTATION_INFO_SLOT = bytes32(uint256(keccak256("router.attestationInfo")) - 1); /// @dev Integrator address => message digest -> attestation info mapping. @@ -221,7 +229,7 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS // Update the storage. integratorConfigs[integrator] = - IntegratorConfig({isInitialized: true, admin: initialAdmin, transfer: address(0)}); + IntegratorConfig({isInitialized: true, admin: initialAdmin, pending_admin: address(0)}); emit IntegratorRegistered(integrator, initialAdmin); } @@ -234,7 +242,7 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS // Get the storage for this integrator contract mapping(address => IntegratorConfig) storage integratorConfigs = _getIntegratorConfigsStorage(); - if (integratorConfigs[integrator].transfer != address(0)) { + if (integratorConfigs[integrator].pending_admin != address(0)) { revert AdminTransferInProgress(); } @@ -252,12 +260,12 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS // Get the storage for this integrator contract mapping(address => IntegratorConfig) storage integratorConfigs = _getIntegratorConfigsStorage(); - if (integratorConfigs[integrator].transfer != address(0)) { + if (integratorConfigs[integrator].pending_admin != address(0)) { revert AdminTransferInProgress(); } // Update the storage with this request. - integratorConfigs[integrator].transfer = newAdmin; + integratorConfigs[integrator].pending_admin = newAdmin; emit AdminUpdateRequested(integrator, msg.sender, newAdmin); } @@ -266,18 +274,19 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS // Get the storage for this integrator contract mapping(address => IntegratorConfig) storage integratorConfigs = _getIntegratorConfigsStorage(); + if (integratorConfigs[integrator].pending_admin == address(0)) { + revert NoAdminTransferInProgress(); + } + address oldAdmin = integratorConfigs[integrator].admin; - address newAdmin = integratorConfigs[integrator].transfer; - if (msg.sender == oldAdmin) { - // This is the cancel case. - integratorConfigs[integrator].transfer = address(0); - } else if (msg.sender == newAdmin) { - // Update the storage with this request. - integratorConfigs[integrator].admin = newAdmin; - integratorConfigs[integrator].transfer = address(0); - } else { + address pendingAdmin = integratorConfigs[integrator].pending_admin; + address newAdmin = msg.sender; + if (newAdmin != oldAdmin && newAdmin != pendingAdmin) { revert CallerNotAuthorized(); } + // Update the storage with this request. + integratorConfigs[integrator].admin = newAdmin; + integratorConfigs[integrator].pending_admin = address(0); emit AdminUpdated(integrator, oldAdmin, newAdmin); } @@ -286,7 +295,7 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS // Get the storage for this integrator contract mapping(address => IntegratorConfig) storage integratorConfigs = _getIntegratorConfigsStorage(); - if (integratorConfigs[integrator].transfer != address(0)) { + if (integratorConfigs[integrator].pending_admin != address(0)) { revert AdminTransferInProgress(); } @@ -298,13 +307,13 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS // =============== Transceiver functions ======================================================= /// @inheritdoc IRouterAdmin - function addTransceiver(address integrator, uint16 chainId, address transceiver) + function addTransceiver(address integrator, address transceiver) external onlyAdmin(integrator) returns (uint8 index) { // Call the TransceiverRegistry version. - return _addTransceiver(integrator, chainId, transceiver); + return _addTransceiver(integrator, transceiver); } /// @inheritdoc IRouterAdmin @@ -346,10 +355,9 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS // =============== Message functions ======================================================= /// @inheritdoc IRouterIntegrator - function sendMessage(uint16 dstChain, UniversalAddress dstAddr, address refundAddress, bytes32 payloadHash) + function sendMessage(uint16 dstChain, UniversalAddress dstAddr, bytes32 payloadHash, address refundAddress) external payable - onlyIntegrator returns (uint64 sequence) { // get the enabled send transceivers for [msg.sender][dstChain] @@ -361,13 +369,12 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS UniversalAddress sender = UniversalAddressLibrary.fromAddress(msg.sender); // get the next sequence number for msg.sender sequence = _useMessageSequence(msg.sender); - UniversalAddress refundUA = UniversalAddressLibrary.fromAddress(refundAddress); for (uint256 i = 0; i < len;) { // quote the delivery price uint256 deliveryPrice = ITransceiver(sendTransceivers[i]).quoteDeliveryPrice(dstChain); // call sendMessage ITransceiver(sendTransceivers[i]).sendMessage{value: deliveryPrice}( - sender, dstChain, dstAddr, sequence, payloadHash, UniversalAddressLibrary.toBytes32(refundUA) + sender, sequence, dstChain, dstAddr, payloadHash, refundAddress ); unchecked { ++i; @@ -377,9 +384,9 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS emit MessageSent( _computeMessageDigest(ourChainId, sender, sequence, dstChain, dstAddr, payloadHash), sender, + sequence, dstAddr, dstChain, - sequence, payloadHash ); } @@ -393,7 +400,6 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS UniversalAddress dstAddr, bytes32 payloadHash ) external { - // This is called by the transceiver so we don't check onlyIntegrator. address integrator = dstAddr.toAddress(); // sanity check that destinationChain is this chain @@ -407,7 +413,7 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS } // Make sure it's enabled on the receive. - if (!_isRecvTransceiverEnabledForChain(integrator, srcChain, msg.sender)) { + if (!_isRecvTransceiverEnabledForChainWithCheck(integrator, srcChain, msg.sender)) { revert TransceiverNotEnabled(); } @@ -415,11 +421,16 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS bytes32 messageDigest = _computeMessageDigest(srcChain, srcAddr, sequence, dstChain, dstAddr, payloadHash); AttestationInfo storage attestationInfo = _getAttestationInfoStorage()[integrator][messageDigest]; + uint128 updatedTransceivers = attestationInfo.attestedTransceivers | uint128(1 << tsInfo.index); + // Check that this message has not already been attested. + if (updatedTransceivers == attestationInfo.attestedTransceivers) { + revert DuplicateMessageAttestation(); + } // It's okay to mark it as attested if it has already been executed. // set the bit in perIntegratorAttestations[dstAddr][digest] corresponding to msg.sender - attestationInfo.attestedTransceivers |= uint128(1 << tsInfo.index); + attestationInfo.attestedTransceivers = updatedTransceivers; emit MessageAttestedTo( _computeMessageDigest(srcChain, srcAddr, sequence, dstChain, dstAddr, payloadHash), srcChain, @@ -441,14 +452,12 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS uint16 dstChain, UniversalAddress dstAddr, bytes32 payloadHash - ) external payable onlyIntegrator returns (uint128 enabledBitmap, uint128 attestedBitmap) { + ) external payable returns (uint128 enabledBitmap, uint128 attestedBitmap) { // sanity check that dstChain is this chain if (dstChain != ourChainId) { revert InvalidDestinationChain(); } - // sanity check that msg.sender is integrator: The caller has onlyIntegrator. - enabledBitmap = _getEnabledRecvTransceiversBitmapForChain(msg.sender, srcChain); if (enabledBitmap == 0) { revert TransceiverNotEnabled(); @@ -486,7 +495,7 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS uint16 dstChain, UniversalAddress dstAddr, bytes32 payloadHash - ) external view onlyIntegrator returns (uint128 enabledBitmap, uint128 attestedBitmap, bool executed) { + ) external view returns (uint128 enabledBitmap, uint128 attestedBitmap, bool executed) { // sanity check that dstChain is this chain if (dstChain != ourChainId) { revert InvalidDestinationChain(); @@ -510,7 +519,7 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS uint16 dstChain, UniversalAddress dstAddr, bytes32 payloadHash - ) external onlyIntegrator { + ) external { if (dstChain != ourChainId) { revert InvalidDestinationChain(); } @@ -519,8 +528,7 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS AttestationInfo storage attestationInfo = _getAttestationInfoStorage()[msg.sender][messageDigest]; - bool executed = attestationInfo.executed; - if (executed) { + if (attestationInfo.executed) { revert AlreadyExecuted(); } attestationInfo.executed = true; @@ -528,14 +536,6 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS // =============== Internal ============================================================== - modifier onlyIntegrator() { - IntegratorConfig storage config = _getIntegratorConfigsStorage()[msg.sender]; - if (!config.isInitialized) { - revert IntegratorNotRegistered(); - } - _; - } - modifier onlyAdmin(address integrator) { IntegratorConfig storage config = _getIntegratorConfigsStorage()[integrator]; if (!config.isInitialized) { @@ -549,13 +549,13 @@ contract Router is IRouterAdmin, IRouterIntegrator, IRouterTransceiver, MessageS } function _computeMessageDigest( - uint16 sourceChain, + uint16 srcChain, UniversalAddress srcAddr, uint64 sequence, - uint16 destinationChain, + uint16 dstChain, UniversalAddress dstAddr, bytes32 payloadHash ) internal pure returns (bytes32) { - return keccak256(abi.encodePacked(sourceChain, srcAddr, sequence, destinationChain, dstAddr, payloadHash)); + return keccak256(abi.encodePacked(srcChain, srcAddr, sequence, dstChain, dstAddr, payloadHash)); } } diff --git a/evm/src/TransceiverRegistry.sol b/evm/src/TransceiverRegistry.sol index 0f6a6a91..987f6cbc 100644 --- a/evm/src/TransceiverRegistry.sol +++ b/evm/src/TransceiverRegistry.sol @@ -3,11 +3,6 @@ pragma solidity ^0.8.13; /// @title TransceiverRegistry /// @notice This contract is responsible for handling the registration of Transceivers. -/// @dev This contract checks that a few critical invariants hold when transceivers are added or removed, -/// including: -/// 1. If a transceiver is not registered, it should be not enabled. -/// 2. The value set in the bitmap of transceivers -/// should directly correspond to the whether the transceiver is enabled abstract contract TransceiverRegistry { /// @dev Information about registered transceivers. struct TransceiverInfo { @@ -22,26 +17,17 @@ abstract contract TransceiverRegistry { uint128 bitmap; // MAX_TRANSCEIVERS = 128 } - /// @dev Total number of registered transceivers. This number can only increase. - /// invariant: numRegisteredTransceivers <= MAX_TRANSCEIVERS - /// invariant: forall (i: uint8), - /// i < numRegisteredTransceivers <=> exists (a: address), transceiverInfos[a].index == i - struct _NumTransceivers { - uint8 registered; - } - uint8 constant MAX_TRANSCEIVERS = 128; // =============== Events =============================================== /// @notice Emitted when a transceiver is added. /// @dev Topic0 - /// 0x11b137fe0ddc829607ddb73c998c8792af425d23c3d44235e97ba9d0ded66e58 + /// 0x21bd18575f35e922dfe885784f1c36fe1c055f9a74fec0e9d113930f47e14bf2 /// @param integrator The address of the integrator. - /// @param chain The Wormhole chain ID on which this transceiver is added. /// @param transceiver The address of the transceiver. /// @param transceiversNum The current number of transceivers. - event TransceiverAdded(address integrator, uint16 chain, address transceiver, uint8 transceiversNum); + event TransceiverAdded(address integrator, address transceiver, uint8 transceiversNum); /// @notice Emitted when a send side transceiver is enabled for a chain. /// @dev Topic0 @@ -127,11 +113,6 @@ abstract contract TransceiverRegistry { bytes32 private constant REGISTERED_TRANSCEIVERS_SLOT = bytes32(uint256(keccak256("registry.registeredTransceivers")) - 1); - /// @dev Holds send side Integrator address => NumTransceivers mapping. - /// mapping(address => _NumTransceivers) - bytes32 private constant NUM_REGISTERED_TRANSCEIVERS_SLOT = - bytes32(uint256(keccak256("registry.numRegisteredTransceivers")) - 1); - // =============== Send side ============================================= /// @dev Holds integrator address => Chain ID => Enabled send side transceiver address[] mapping. @@ -139,11 +120,6 @@ abstract contract TransceiverRegistry { bytes32 private constant ENABLED_SEND_TRANSCEIVER_ARRAY_SLOT = bytes32(uint256(keccak256("registry.sendTransceiverArray")) - 1); - /// @dev Holds send side Integrator address => Transceiver addresses mapping. - /// mapping(address => address[]) across all chains - bytes32 private constant REGISTERED_SEND_TRANSCEIVERS_SLOT = - bytes32(uint256(keccak256("registry.registeredSendTransceivers")) - 1); - // =============== Recv side ============================================= /// @dev Holds integrator address => Chain ID => Enabled transceiver receive side bitmap mapping. @@ -151,11 +127,6 @@ abstract contract TransceiverRegistry { bytes32 private constant ENABLED_RECV_TRANSCEIVER_BITMAP_SLOT = bytes32(uint256(keccak256("registry.recvTransceiverBitmap")) - 1); - /// @dev Holds receive side Integrator address => Transceiver addresses mapping. - /// mapping(address => address[]) across all chains - bytes32 private constant REGISTERED_RECV_TRANSCEIVERS_SLOT = - bytes32(uint256(keccak256("registry.registeredRecvTransceivers")) - 1); - // =============== Mappings =============================================== /// @dev Integrator address => transceiver address => TransceiverInfo mapping. @@ -203,31 +174,8 @@ abstract contract TransceiverRegistry { } } - /// @dev Integrator address => NumTransceivers mapping. - /// Contains number of registered transceivers for this integrator. - /// The transceivers may or may not be enabled. - function _getNumTransceiversStorage() internal pure returns (mapping(address => _NumTransceivers) storage $) { - uint256 slot = uint256(NUM_REGISTERED_TRANSCEIVERS_SLOT); - assembly ("memory-safe") { - $.slot := slot - } - } - // =============== Modifiers ====================================================== - /// @notice This modifier will revert if the transceiver is a zero address or the chain is invalid - /// @param transceiver The transceiver address - /// @param chain The Wormhole chain ID - modifier onlyValidTransceiverAndChain(address transceiver, uint16 chain) { - if (transceiver == address(0)) { - revert InvalidTransceiverZeroAddress(); - } - if (chain == 0) { - revert InvalidChain(chain); - } - _; - } - /// @notice This modifier will revert if the transceiver is an invalid address or not registered. /// Or the chain is invalid /// @param integrator The integrator address @@ -252,32 +200,29 @@ abstract contract TransceiverRegistry { /// @dev This function adds a transceiver. /// @param integrator The integrator address - /// @param chain The The Wormhole chain ID /// @param transceiver The transceiver address /// @return index The index of this newly enabled transceiver - function _addTransceiver(address integrator, uint16 chain, address transceiver) - internal - onlyValidTransceiverAndChain(transceiver, chain) - returns (uint8 index) - { + function _addTransceiver(address integrator, address transceiver) internal returns (uint8 index) { + if (transceiver == address(0)) { + revert InvalidTransceiverZeroAddress(); + } mapping(address => mapping(address => TransceiverInfo)) storage transceiverInfos = _getTransceiverInfosStorage(); - mapping(address => _NumTransceivers) storage _numTransceivers = _getNumTransceiversStorage(); + mapping(address => address[]) storage registeredTransceivers = _getRegisteredTransceiversStorage(); if (transceiverInfos[integrator][transceiver].registered) { revert TransceiverAlreadyRegistered(transceiver); } - if (_numTransceivers[integrator].registered >= MAX_TRANSCEIVERS) { + uint8 registeredTransceiversLength = uint8(registeredTransceivers[integrator].length); + if (registeredTransceiversLength >= MAX_TRANSCEIVERS) { revert TooManyTransceivers(); } // Create the TransceiverInfo transceiverInfos[integrator][transceiver] = - TransceiverInfo({registered: true, index: _numTransceivers[integrator].registered}); + TransceiverInfo({registered: true, index: registeredTransceiversLength}); // Add this transceiver to the integrator => address[] mapping _getRegisteredTransceiversStorage()[integrator].push(transceiver); - // Increment count of transceivers - _numTransceivers[integrator].registered++; // Emit an event - emit TransceiverAdded(integrator, chain, transceiver, _numTransceivers[integrator].registered); + emit TransceiverAdded(integrator, transceiver, registeredTransceiversLength); return transceiverInfos[integrator][transceiver].index; } @@ -327,7 +272,6 @@ abstract contract TransceiverRegistry { internal onlyRegisteredTransceiver(integrator, chain, transceiver) { - // mapping(address => mapping(address => TransceiverInfo)) storage transceiverInfos = _getTransceiverInfosStorage(); mapping(address => mapping(uint16 => address[])) storage enabledSendTransceivers = _getPerChainSendTransceiverArrayStorage(); address[] storage transceivers = enabledSendTransceivers[integrator][chain]; @@ -382,15 +326,24 @@ abstract contract TransceiverRegistry { emit RecvTransceiverDisabledForChain(integrator, chain, transceiver); } - /// @dev Returns if the send side transceiver is enabled for the given integrator and chain. + function _isSendTransceiverEnabledForChainWithCheck(address integrator, uint16 chain, address transceiver) + internal + view + onlyRegisteredTransceiver(integrator, chain, transceiver) + returns (bool) + { + return _isSendTransceiverEnabledForChain(integrator, chain, transceiver); + } + + /// @notice Returns if the send side transceiver is enabled for the given integrator and chain. + /// @dev This function is private and should only be called by a function that checks the validity of chain and transceiver. /// @param integrator The integrator address /// @param chain The Wormhole chain ID /// @param transceiver The transceiver address /// @return true if the transceiver is enabled, false otherwise. function _isSendTransceiverEnabledForChain(address integrator, uint16 chain, address transceiver) - internal + private view - onlyRegisteredTransceiver(integrator, chain, transceiver) returns (bool) { address[] storage transceivers = _getPerChainSendTransceiverArrayStorage()[integrator][chain]; @@ -406,20 +359,29 @@ abstract contract TransceiverRegistry { return false; } - /// @dev Returns if the receive side transceiver is enabled for the given integrator and chain. + function _isRecvTransceiverEnabledForChainWithCheck(address integrator, uint16 chain, address transceiver) + internal + view + onlyRegisteredTransceiver(integrator, chain, transceiver) + returns (bool) + { + return _isRecvTransceiverEnabledForChain(integrator, chain, transceiver); + } + + /// @notice Returns if the receive side transceiver is enabled for the given integrator and chain. + /// @dev This function is private and should only be called by a function that checks the validity of chain and transceiver. /// @param integrator The integrator address /// @param chain The Wormhole chain ID /// @param transceiver The transceiver address /// @return true if the transceiver is enabled, false otherwise. function _isRecvTransceiverEnabledForChain(address integrator, uint16 chain, address transceiver) - internal + private view - onlyRegisteredTransceiver(integrator, chain, transceiver) returns (bool) { uint128 bitmap = _getEnabledRecvTransceiversBitmapForChain(integrator, chain); uint8 index = _getTransceiverInfosStorage()[integrator][transceiver].index; - return (bitmap & uint128(1 << index)) != 0; + return (bitmap & uint128(1 << index)) > 0; } /// @param integrator The integrator address @@ -473,23 +435,7 @@ abstract contract TransceiverRegistry { if (chain == 0) { revert InvalidChain(chain); } - address[] memory allTransceivers = _getRegisteredTransceiversStorage()[integrator]; - address[] memory tempResult = new address[](allTransceivers.length); - uint8 len = 0; - uint256 allTransceiversLength = allTransceivers.length; - for (uint256 i = 0; i < allTransceiversLength;) { - if (_isSendTransceiverEnabledForChain(integrator, chain, allTransceivers[i])) { - tempResult[len] = allTransceivers[i]; - ++len; - } - unchecked { - ++i; - } - } - result = new address[](len); - for (uint8 i = 0; i < len; i++) { - result[i] = tempResult[i]; - } + result = _getEnabledSendTransceiversArrayForChain(integrator, chain); } /// @notice Returns the enabled receive side transceiver addresses for the given integrator. diff --git a/evm/src/interfaces/IRouterAdmin.sol b/evm/src/interfaces/IRouterAdmin.sol index 3d2388ce..607b63fc 100644 --- a/evm/src/interfaces/IRouterAdmin.sol +++ b/evm/src/interfaces/IRouterAdmin.sol @@ -14,7 +14,7 @@ interface IRouterAdmin { /// @param newAdmin The address of the new admin. function transferAdmin(address integrator, address newAdmin) external; - /// @notice Starts the two step process of transferring admin privileges from the current admin to another contract. + /// @notice Completes the two step process of transferring admin privileges from the current admin to another contract. /// @dev The msg.sender must be the current admin contract. /// @param integrator The address of the integrator contract. function claimAdmin(address integrator) external; @@ -29,8 +29,7 @@ interface IRouterAdmin { /// This does NOT enable the transceiver for sending or receiving. /// @param integrator The address of the integrator contract. /// @param transceiver The address of the Transceiver contract. - /// @param chainId The chain ID of the Transceiver contract. - function addTransceiver(address integrator, uint16 chainId, address transceiver) external returns (uint8 index); + function addTransceiver(address integrator, address transceiver) external returns (uint8 index); /// @notice This enables the sending of messages from the given transceiver on the given chain. /// @param integrator The address of the integrator contract. diff --git a/evm/src/interfaces/IRouterIntegrator.sol b/evm/src/interfaces/IRouterIntegrator.sol index 4e9a8709..7c766add 100644 --- a/evm/src/interfaces/IRouterIntegrator.sol +++ b/evm/src/interfaces/IRouterIntegrator.sol @@ -8,16 +8,16 @@ interface IRouterIntegrator is IMessageSequence { /// @notice This is the first thing an integrator should do. It registers the integrator with the router /// and sets the administrator contract for that integrator. The admin address is used to manage the transceivers. /// @dev The msg.sender needs to be the integrator contract. - /// @param initialAdmin The address of the admin. Pass in msg.sender, if you want the integrator to be the admin. + /// @param initialAdmin The address of the admin. function register(address initialAdmin) external; /// @notice Send a message to another chain. /// @param dstChain The Wormhole chain ID of the recipient. /// @param dstAddr The universal address of the peer on the recipient chain. - /// @param refundAddress The source chain refund address passed to the Transceiver. /// @param payloadHash keccak256 of a message to be sent to the recipient chain. /// @return uint64 The sequence number of the message. - function sendMessage(uint16 dstChain, UniversalAddress dstAddr, address refundAddress, bytes32 payloadHash) + /// @param refundAddress The source chain refund address passed to the Transceiver. + function sendMessage(uint16 dstChain, UniversalAddress dstAddr, bytes32 payloadHash, address refundAddress) external payable returns (uint64); diff --git a/evm/src/interfaces/ITransceiver.sol b/evm/src/interfaces/ITransceiver.sol index f0a3590e..bd1b6763 100644 --- a/evm/src/interfaces/ITransceiver.sol +++ b/evm/src/interfaces/ITransceiver.sol @@ -20,17 +20,17 @@ interface ITransceiver { /// @dev Send a message to another chain. /// @param srcAddr The universal address of the sender. + /// @param sequence The per-integrator sequence number associated with the message. /// @param dstChain The Wormhole chain ID of the recipient. /// @param dstAddr The universal address of the recipient. - /// @param sequence The per-integrator sequence number associated with the message. /// @param payloadHash The hash of the message to be sent to the recipient chain. - /// @param refundAddr The address of the refund recipient + /// @param refundAddr The address of the refund recipient. function sendMessage( UniversalAddress srcAddr, + uint64 sequence, uint16 dstChain, UniversalAddress dstAddr, - uint64 sequence, bytes32 payloadHash, - bytes32 refundAddr + address refundAddr ) external payable; } diff --git a/evm/test/Router.t.sol b/evm/test/Router.t.sol index f764427c..f54831c6 100644 --- a/evm/test/Router.t.sol +++ b/evm/test/Router.t.sol @@ -61,11 +61,11 @@ contract TransceiverImpl is ITransceiver { function sendMessage( UniversalAddress, // sourceAddress, + uint64, // sequence, uint16, // recipientChain, UniversalAddress, // recipientAddress, - uint64, // sequence, bytes32, // payloadHash, - bytes32 // refundAddress + address // refundAddress ) public payable override { messagesSent += 1; } @@ -139,6 +139,10 @@ contract RouterTest is Test { router.updateAdmin(integrator, address(newerAdmin)); require(router.getAdmin(integrator) == address(newerAdmin), "failed to update admin address"); + // Cannot claim if there is no transfer in progress. + vm.expectRevert(abi.encodeWithSelector(Router.NoAdminTransferInProgress.selector)); + router.claimAdmin(integrator); + vm.startPrank(address(newerAdmin)); // Two step update to first admin. vm.expectRevert(abi.encodeWithSelector(Router.InvalidAdminZeroAddress.selector)); @@ -187,35 +191,35 @@ contract RouterTest is Test { // Can't enable a transceiver until we've set the admin. vm.expectRevert(abi.encodeWithSelector(Router.IntegratorNotRegistered.selector)); - router.addTransceiver(integrator, 1, taddr1); + router.addTransceiver(integrator, taddr1); // Register the integrator and set the admin. router.register(admin); // The admin can add a transceiver. vm.startPrank(admin); - router.addTransceiver(integrator, 1, taddr1); + router.addTransceiver(integrator, taddr1); // Others cannot add a transceiver. vm.startPrank(imposter); vm.expectRevert(abi.encodeWithSelector(Router.CallerNotAuthorized.selector)); - router.addTransceiver(integrator, 1, taddr1); + router.addTransceiver(integrator, taddr1); // Can't register the transceiver twice. vm.startPrank(admin); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyRegistered.selector, taddr1)); - router.addTransceiver(integrator, 1, taddr1); + router.addTransceiver(integrator, taddr1); // Can't enable the transceiver twice. router.enableSendTransceiver(integrator, 1, taddr1); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyEnabled.selector, taddr1)); router.enableSendTransceiver(integrator, 1, taddr1); - router.addTransceiver(integrator, 1, taddr2); + router.addTransceiver(integrator, taddr2); address[] memory transceivers = router.getSendTransceiversByChain(integrator, 1); require(transceivers.length == 1, "Wrong number of transceivers enabled on chain one, should be 1"); // Enable another transceiver on chain one and one on chain two. router.enableSendTransceiver(integrator, 1, taddr2); - router.addTransceiver(integrator, 2, taddr3); + router.addTransceiver(integrator, taddr3); router.enableSendTransceiver(integrator, 2, taddr3); // And verify they got set properly. @@ -244,35 +248,35 @@ contract RouterTest is Test { // Can't enable a transceiver until we've set the admin. vm.expectRevert(abi.encodeWithSelector(Router.IntegratorNotRegistered.selector)); - router.addTransceiver(integrator, 1, taddr1); + router.addTransceiver(integrator, taddr1); // Register the integrator and set the admin. router.register(admin); // The admin can add a transceiver. vm.startPrank(admin); - router.addTransceiver(integrator, 1, taddr1); + router.addTransceiver(integrator, taddr1); // Others cannot add a transceiver. vm.startPrank(imposter); vm.expectRevert(abi.encodeWithSelector(Router.CallerNotAuthorized.selector)); - router.addTransceiver(integrator, 1, taddr1); + router.addTransceiver(integrator, taddr1); // Can't register the transceiver twice. vm.startPrank(admin); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyRegistered.selector, taddr1)); - router.addTransceiver(integrator, 1, taddr1); + router.addTransceiver(integrator, taddr1); // Can't enable the transceiver twice. router.enableRecvTransceiver(integrator, 1, taddr1); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyEnabled.selector, taddr1)); router.enableRecvTransceiver(integrator, 1, taddr1); - router.addTransceiver(integrator, 1, taddr2); + router.addTransceiver(integrator, taddr2); address[] memory transceivers = router.getRecvTransceiversByChain(integrator, 1); require(transceivers.length == 1, "Wrong number of transceivers enabled on chain one, should be 1"); // Enable another transceiver on chain one and one on chain two. router.enableRecvTransceiver(integrator, 1, taddr2); - router.addTransceiver(integrator, 2, taddr3); + router.addTransceiver(integrator, taddr3); router.enableRecvTransceiver(integrator, 2, taddr3); // And verify they got set properly. @@ -299,41 +303,41 @@ contract RouterTest is Test { // Sending with no transceivers should revert. vm.startPrank(integrator); vm.expectRevert(abi.encodeWithSelector(Router.TransceiverNotEnabled.selector)); - uint64 sequence = router.sendMessage(2, UniversalAddressLibrary.fromAddress(userA), refundAddr, messageHash); + uint64 sequence = router.sendMessage(2, UniversalAddressLibrary.fromAddress(userA), messageHash, refundAddr); // Now enable some transceivers. vm.startPrank(admin); - router.addTransceiver(integrator, 2, address(transceiver1)); + router.addTransceiver(integrator, address(transceiver1)); router.enableSendTransceiver(integrator, 2, address(transceiver1)); - router.addTransceiver(integrator, 2, address(transceiver2)); + router.addTransceiver(integrator, address(transceiver2)); router.enableSendTransceiver(integrator, 2, address(transceiver2)); - router.addTransceiver(integrator, 3, address(transceiver3)); + router.addTransceiver(integrator, address(transceiver3)); router.enableSendTransceiver(integrator, 3, address(transceiver3)); // Only an integrator can call send. vm.startPrank(userA); - vm.expectRevert(abi.encodeWithSelector(Router.IntegratorNotRegistered.selector)); - sequence = router.sendMessage(chain, UniversalAddressLibrary.fromAddress(userA), refundAddr, messageHash); + vm.expectRevert(abi.encodeWithSelector(Router.TransceiverNotEnabled.selector)); + sequence = router.sendMessage(chain, UniversalAddressLibrary.fromAddress(userA), messageHash, refundAddr); // Send a message on chain two. It should go out on the first two transceivers, but not the third one. vm.startPrank(integrator); - sequence = router.sendMessage(chain, UniversalAddressLibrary.fromAddress(userA), refundAddr, messageHash); + sequence = router.sendMessage(chain, UniversalAddressLibrary.fromAddress(userA), messageHash, refundAddr); require(sequence == 0, "Sequence number is wrong"); require(transceiver1.getMessagesSent() == 1, "Failed to send a message on transceiver 1"); require(transceiver2.getMessagesSent() == 1, "Failed to send a message on transceiver 2"); require(transceiver3.getMessagesSent() == 0, "Should not have sent a message on transceiver 3"); - sequence = router.sendMessage(chain, UniversalAddressLibrary.fromAddress(userA), refundAddr, messageHash); + sequence = router.sendMessage(chain, UniversalAddressLibrary.fromAddress(userA), messageHash, refundAddr); require(sequence == 1, "Second sequence number is wrong"); require(transceiver1.getMessagesSent() == 2, "Failed to send second message on transceiver 1"); require(transceiver2.getMessagesSent() == 2, "Failed to send second message on transceiver 2"); require(transceiver3.getMessagesSent() == 0, "Should not have sent second message on transceiver 3"); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); - sequence = router.sendMessage(zeroChain, UniversalAddressLibrary.fromAddress(userA), refundAddr, messageHash); + sequence = router.sendMessage(zeroChain, UniversalAddressLibrary.fromAddress(userA), messageHash, refundAddr); require(sequence == 0, "Failed sequence number is wrong"); // 0 because of the revert - sequence = router.sendMessage(chain, UniversalAddressLibrary.fromAddress(userA), refundAddr, messageHash); + sequence = router.sendMessage(chain, UniversalAddressLibrary.fromAddress(userA), messageHash, refundAddr); require(sequence == 2, "Third sequence number is wrong"); } @@ -357,11 +361,11 @@ contract RouterTest is Test { // Now enable some transceivers. vm.startPrank(admin); - router.addTransceiver(integrator, chain, address(transceiver1)); + router.addTransceiver(integrator, address(transceiver1)); router.enableRecvTransceiver(integrator, chain, address(transceiver1)); - router.addTransceiver(integrator, chain, address(transceiver2)); + router.addTransceiver(integrator, address(transceiver2)); router.enableRecvTransceiver(integrator, chain, address(transceiver2)); - router.addTransceiver(integrator, chain + 1, address(transceiver3)); + router.addTransceiver(integrator, address(transceiver3)); router.enableRecvTransceiver(integrator, chain + 1, address(transceiver3)); // Only a transceiver can call attest. @@ -378,12 +382,16 @@ contract RouterTest is Test { vm.startPrank(address(transceiver2)); router.attestMessage(chain, sourceIntegrator, anotherChain, OurChainId, destIntegrator, messageHash); + // Multiple Attests from same transceiver should revert. + vm.expectRevert(abi.encodeWithSelector(Router.DuplicateMessageAttestation.selector)); + router.attestMessage(chain, sourceIntegrator, anotherChain, OurChainId, destIntegrator, messageHash); + // Receive what we just attested to mark it executed. vm.startPrank(integrator); router.recvMessage(chain, sourceIntegrator, anotherChain, OurChainId, destIntegrator, messageHash); - // Attesting after receive should still work. - vm.startPrank(address(transceiver2)); + // Attesting after receive should still work on a different transceiver. + vm.startPrank(address(transceiver1)); router.attestMessage(chain, sourceIntegrator, anotherChain, OurChainId, destIntegrator, messageHash); // Attesting on a disabled transceiver should revert. @@ -412,16 +420,16 @@ contract RouterTest is Test { // Now enable some transceivers so we can attest. Receive doesn't use the transceivers. vm.startPrank(admin); - router.addTransceiver(integrator, 2, address(transceiver1)); + router.addTransceiver(integrator, address(transceiver1)); router.enableRecvTransceiver(integrator, 2, address(transceiver1)); - router.addTransceiver(integrator, 2, address(transceiver2)); + router.addTransceiver(integrator, address(transceiver2)); router.enableRecvTransceiver(integrator, 2, address(transceiver2)); - router.addTransceiver(integrator, 3, address(transceiver3)); + router.addTransceiver(integrator, address(transceiver3)); router.enableRecvTransceiver(integrator, 3, address(transceiver3)); // Only an integrator can call receive. vm.startPrank(userB); - vm.expectRevert(abi.encodeWithSelector(Router.IntegratorNotRegistered.selector)); + vm.expectRevert(abi.encodeWithSelector(Router.TransceiverNotEnabled.selector)); router.recvMessage(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); // Receiving a message destine for the wrong chain should revert. @@ -465,24 +473,22 @@ contract RouterTest is Test { // No transceivers should revert. vm.startPrank(integrator); - // vm.expectRevert(abi.encodeWithSelector(Router.TransceiverNotEnabled.selector)); router.getMessageStatus(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); // Now enable some transceivers so we can attest. Receive doesn't use the transceivers. vm.startPrank(admin); - router.addTransceiver(integrator, 2, address(transceiver1)); + router.addTransceiver(integrator, address(transceiver1)); router.enableRecvTransceiver(integrator, 2, address(transceiver1)); - router.addTransceiver(integrator, 2, address(transceiver2)); + router.addTransceiver(integrator, address(transceiver2)); router.enableRecvTransceiver(integrator, 2, address(transceiver2)); - router.addTransceiver(integrator, 3, address(transceiver3)); + router.addTransceiver(integrator, address(transceiver3)); router.enableRecvTransceiver(integrator, 3, address(transceiver3)); // Only an integrator can call receive. vm.startPrank(userB); - vm.expectRevert(abi.encodeWithSelector(Router.IntegratorNotRegistered.selector)); router.getMessageStatus(2, sourceIntegrator, 1, OurChainId, destIntegrator, messageHash); - // Receiving a message destine for the wrong chain should revert. + // Receiving a message destined for the wrong chain should revert. vm.startPrank(integrator); vm.expectRevert(abi.encodeWithSelector(Router.InvalidDestinationChain.selector)); router.getMessageStatus(2, sourceIntegrator, 1, OurChainId + 1, destIntegrator, messageHash); diff --git a/evm/test/TransceiverRegistry.t.sol b/evm/test/TransceiverRegistry.t.sol index 5d56c97e..d0644f70 100644 --- a/evm/test/TransceiverRegistry.t.sol +++ b/evm/test/TransceiverRegistry.t.sol @@ -5,8 +5,8 @@ import {Test, console} from "forge-std/Test.sol"; import "../src/TransceiverRegistry.sol"; contract ConcreteTransceiverRegistry is TransceiverRegistry { - function addTransceiver(address integrator, uint16 chain, address transceiver) public returns (uint8 index) { - return _addTransceiver(integrator, chain, transceiver); + function addTransceiver(address integrator, address transceiver) public returns (uint8 index) { + return _addTransceiver(integrator, transceiver); } function disableSendTransceiver(address integrator, uint16 chain, address transceiver) public { @@ -21,10 +21,6 @@ contract ConcreteTransceiverRegistry is TransceiverRegistry { return _getRegisteredTransceiversStorage()[integrator]; } - function getNumTransceiversStorage(address integrator) public view returns (_NumTransceivers memory $) { - return _getNumTransceiversStorage()[integrator]; - } - function getEnabledSendTransceiversBitmapForChain(address integrator, uint16 chain) public view @@ -54,7 +50,7 @@ contract ConcreteTransceiverRegistry is TransceiverRegistry { view returns (bool) { - return _isSendTransceiverEnabledForChain(integrator, chainId, transceiver); + return _isSendTransceiverEnabledForChainWithCheck(integrator, chainId, transceiver); } function isRecvTransceiverEnabledForChain(address integrator, uint16 chainId, address transceiver) @@ -62,7 +58,7 @@ contract ConcreteTransceiverRegistry is TransceiverRegistry { view returns (bool) { - return _isRecvTransceiverEnabledForChain(integrator, chainId, transceiver); + return _isRecvTransceiverEnabledForChainWithCheck(integrator, chainId, transceiver); } function getMaxTransceivers() public pure returns (uint8) { @@ -94,16 +90,12 @@ contract TransceiverRegistryTest is Test { address me = address(this); // Send side assertEq(transceiverRegistry.getTransceivers(me).length, 0); - vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); - transceiverRegistry.addTransceiver(me, zeroChain, sendTransceiver); - transceiverRegistry.addTransceiver(me, chain, sendTransceiver); + transceiverRegistry.addTransceiver(me, sendTransceiver); // Recv side // A transceiver was registered on the send side assertEq(transceiverRegistry.getTransceivers(me).length, 1); - vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); - transceiverRegistry.addTransceiver(me, zeroChain, recvTransceiver); - transceiverRegistry.addTransceiver(me, chain, recvTransceiver); + transceiverRegistry.addTransceiver(me, recvTransceiver); } function test3() public { @@ -113,24 +105,22 @@ contract TransceiverRegistryTest is Test { vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.NonRegisteredTransceiver.selector, sendTransceiver)); transceiverRegistry.disableSendTransceiver(me, chain, sendTransceiver); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector)); - transceiverRegistry.addTransceiver(me, chain, zeroTransceiver); + transceiverRegistry.addTransceiver(me, zeroTransceiver); require(transceiverRegistry.getRegisteredTransceiversStorage(me).length == 0, "S1"); - require(transceiverRegistry.getNumTransceiversStorage(me).registered == 0, "S2"); - transceiverRegistry.addTransceiver(me, chain, sendTransceiver); + transceiverRegistry.addTransceiver(me, sendTransceiver); require(transceiverRegistry.getRegisteredTransceiversStorage(me).length == 1, "S3"); - require(transceiverRegistry.getNumTransceiversStorage(me).registered == 1, "S4"); // assertEq(transceiverRegistry.getSendTransceiverInfos(integrator1).length, 1); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); transceiverRegistry.disableSendTransceiver(me, zeroChain, sendTransceiver); require(transceiverRegistry.getRegisteredTransceiversStorage(me).length == 1, "S5"); - require(transceiverRegistry.getNumTransceiversStorage(me).registered == 1, "S6"); vm.expectRevert( abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyDisabled.selector, sendTransceiver) ); transceiverRegistry.disableSendTransceiver(me, chain, sendTransceiver); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); + transceiverRegistry.enableSendTransceiver(me, zeroChain, sendTransceiver); transceiverRegistry.enableSendTransceiver(me, chain, sendTransceiver); require(transceiverRegistry.getRegisteredTransceiversStorage(me).length == 1, "S7"); - require(transceiverRegistry.getNumTransceiversStorage(me).registered == 1, "S8"); transceiverRegistry.disableSendTransceiver(me, chain, sendTransceiver); vm.expectRevert( abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyDisabled.selector, sendTransceiver) @@ -144,25 +134,21 @@ contract TransceiverRegistryTest is Test { vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.NonRegisteredTransceiver.selector, recvTransceiver)); transceiverRegistry.disableRecvTransceiver(me, chain, recvTransceiver); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector)); - transceiverRegistry.addTransceiver(me, chain, zeroTransceiver); + transceiverRegistry.addTransceiver(me, zeroTransceiver); // Carry over from send side test require(transceiverRegistry.getRegisteredTransceiversStorage(me).length == 1, "R1"); - require(transceiverRegistry.getNumTransceiversStorage(me).registered == 1, "R2"); - transceiverRegistry.addTransceiver(me, chain, recvTransceiver); + transceiverRegistry.addTransceiver(me, recvTransceiver); require(transceiverRegistry.getRegisteredTransceiversStorage(me).length == 2, "R3"); - require(transceiverRegistry.getNumTransceiversStorage(me).registered == 2, "R4"); // assertEq(transceiverRegistry.getRecvTransceiverInfos(integrator1).length, 1); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChain)); transceiverRegistry.disableRecvTransceiver(me, zeroChain, recvTransceiver); require(transceiverRegistry.getRegisteredTransceiversStorage(me).length == 2, "R5"); - require(transceiverRegistry.getNumTransceiversStorage(me).registered == 2, "R6"); vm.expectRevert( abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyDisabled.selector, recvTransceiver) ); transceiverRegistry.disableRecvTransceiver(me, chain, recvTransceiver); transceiverRegistry.enableRecvTransceiver(me, chain, recvTransceiver); require(transceiverRegistry.getRegisteredTransceiversStorage(me).length == 2, "R7"); - require(transceiverRegistry.getNumTransceiversStorage(me).registered == 2, "R8"); transceiverRegistry.disableRecvTransceiver(me, chain, recvTransceiver); vm.expectRevert( abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyDisabled.selector, recvTransceiver) @@ -196,19 +182,6 @@ contract TransceiverRegistryTest is Test { assertEq(transceiverRegistry.getRegisteredTransceiversStorage(integrator1).length, 0); } - // This is a redudant test, as the previous tests already cover this - function test6() public view { - // Send side - TransceiverRegistry._NumTransceivers memory numSendTransceivers = - transceiverRegistry.getNumTransceiversStorage(integrator1); - assertEq(numSendTransceivers.registered, 0); - - // Recv side - TransceiverRegistry._NumTransceivers memory numRecvTransceivers = - transceiverRegistry.getNumTransceiversStorage(integrator1); - assertEq(numRecvTransceivers.registered, 0); - } - function test7() public { address me = address(this); // Send side @@ -250,9 +223,8 @@ contract TransceiverRegistryTest is Test { address sTransceiver = address(0x345); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.NonRegisteredTransceiver.selector, sTransceiver)); require(transceiverRegistry.isSendTransceiverEnabledForChain(me, chainId, sTransceiver) == false, "S1"); - transceiverRegistry.addTransceiver(me, chain, sTransceiver); + transceiverRegistry.addTransceiver(me, sTransceiver); require(transceiverRegistry.getRegisteredTransceiversStorage(me).length == 1, "S2"); - require(transceiverRegistry.getNumTransceiversStorage(me).registered == 1, "S3"); transceiverRegistry.enableSendTransceiver(me, chainId, sTransceiver); bool enabled = transceiverRegistry.isSendTransceiverEnabledForChain(me, chainId, sTransceiver); require(enabled == true, "S4"); @@ -262,9 +234,8 @@ contract TransceiverRegistryTest is Test { address rTransceiver = address(0x453); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.NonRegisteredTransceiver.selector, rTransceiver)); require(transceiverRegistry.isRecvTransceiverEnabledForChain(me, chainId, rTransceiver) == false, "R1"); - transceiverRegistry.addTransceiver(me, chain, rTransceiver); + transceiverRegistry.addTransceiver(me, rTransceiver); require(transceiverRegistry.getRegisteredTransceiversStorage(me).length == 2, "R2"); - require(transceiverRegistry.getNumTransceiversStorage(me).registered == 2, "R3"); transceiverRegistry.enableRecvTransceiver(me, chainId, rTransceiver); enabled = transceiverRegistry.isRecvTransceiverEnabledForChain(me, chainId, rTransceiver); require(enabled == true, "R4"); @@ -277,16 +248,13 @@ contract TransceiverRegistryTest is Test { // Send side for (uint8 i = 0; i < maxTransceivers; i++) { - transceiverRegistry.addTransceiver(me, chain, address(uint160(i + 1))); + transceiverRegistry.addTransceiver(me, address(uint160(i + 1))); } assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, maxTransceivers); - assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, maxTransceivers); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TooManyTransceivers.selector)); - transceiverRegistry.addTransceiver(me, chain, address(0x111)); + transceiverRegistry.addTransceiver(me, address(0x111)); assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, maxTransceivers); - assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, maxTransceivers); assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, maxTransceivers); - assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, maxTransceivers); for (uint8 i = 0; i < maxTransceivers; i++) { transceiverRegistry.enableSendTransceiver(me, chain, address(uint160(i + 1))); } @@ -304,14 +272,12 @@ contract TransceiverRegistryTest is Test { // Recv side for (uint8 i = 0; i < maxTransceivers; i++) { - transceiverRegistry.addTransceiver(me, chain, address(uint160(i + 1))); + transceiverRegistry.addTransceiver(me, address(uint160(i + 1))); } assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, maxTransceivers); - assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, maxTransceivers); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TooManyTransceivers.selector)); - transceiverRegistry.addTransceiver(me, chain, address(0x111)); + transceiverRegistry.addTransceiver(me, address(0x111)); assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, maxTransceivers); - assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, maxTransceivers); transceiverRegistry.enableRecvTransceiver(me, chain, address(0x1)); transceiverRegistry.enableRecvTransceiver(me, chain, address(0x2)); transceiverRegistry.disableRecvTransceiver(me, chain, address(0x2)); @@ -319,9 +285,8 @@ contract TransceiverRegistryTest is Test { vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TransceiverAlreadyDisabled.selector, address(0x1))); transceiverRegistry.disableSendTransceiver(me, chain, address(0x1)); assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, maxTransceivers); - assertEq(transceiverRegistry.getNumTransceiversStorage(me).registered, maxTransceivers); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.TooManyTransceivers.selector)); - transceiverRegistry.addTransceiver(me, chain, address(0x111)); + transceiverRegistry.addTransceiver(me, address(0x111)); } function test_getSendTransceiversByChain() public { @@ -333,13 +298,13 @@ contract TransceiverRegistryTest is Test { address transceiver3 = address(0x3); // enabled, chain 2 address transceiver4 = address(0x4); // disabled, chain 2 - transceiverRegistry.addTransceiver(me, chain1, transceiver1); + transceiverRegistry.addTransceiver(me, transceiver1); transceiverRegistry.enableSendTransceiver(me, chain1, transceiver1); - transceiverRegistry.addTransceiver(me, chain1, transceiver2); + transceiverRegistry.addTransceiver(me, transceiver2); transceiverRegistry.enableSendTransceiver(me, chain1, transceiver2); - transceiverRegistry.addTransceiver(me, chain2, transceiver3); + transceiverRegistry.addTransceiver(me, transceiver3); transceiverRegistry.enableSendTransceiver(me, chain2, transceiver3); - transceiverRegistry.addTransceiver(me, chain2, transceiver4); + transceiverRegistry.addTransceiver(me, transceiver4); address[] memory chain1Addrs = transceiverRegistry.getSendTransceiversByChain(me, chain1); require(chain1Addrs.length == 2, "Wrong number of transceivers enabled on chain one"); address[] memory chain2Addrs = transceiverRegistry.getSendTransceiversByChain(me, chain2); @@ -358,13 +323,13 @@ contract TransceiverRegistryTest is Test { address transceiver3 = address(0x3); // enabled, chain 2 address transceiver4 = address(0x4); // disabled, chain 2 - transceiverRegistry.addTransceiver(me, chain1, transceiver1); + transceiverRegistry.addTransceiver(me, transceiver1); transceiverRegistry.enableRecvTransceiver(me, chain1, transceiver1); - transceiverRegistry.addTransceiver(me, chain1, transceiver2); + transceiverRegistry.addTransceiver(me, transceiver2); transceiverRegistry.enableRecvTransceiver(me, chain1, transceiver2); - transceiverRegistry.addTransceiver(me, chain2, transceiver3); + transceiverRegistry.addTransceiver(me, transceiver3); transceiverRegistry.enableRecvTransceiver(me, chain2, transceiver3); - transceiverRegistry.addTransceiver(me, chain2, transceiver4); + transceiverRegistry.addTransceiver(me, transceiver4); address[] memory chain1Addrs = transceiverRegistry.getRecvTransceiversByChain(me, chain1); require(chain1Addrs.length == 2, "Wrong number of transceivers enabled on chain one"); address[] memory chain2Addrs = transceiverRegistry.getRecvTransceiversByChain(me, chain2); @@ -373,4 +338,22 @@ contract TransceiverRegistryTest is Test { transceiverRegistry.disableRecvTransceiver(me, chain2, transceiver3); require(chain2Addrs.length == 1, "Wrong number of transceivers enabled on chain two"); } + + function test_recvPerformance() public { + address me = address(this); + uint8 maxTransceivers = transceiverRegistry.getMaxTransceivers(); + + // Recv side + for (uint8 i = 0; i < maxTransceivers; i++) { + transceiverRegistry.addTransceiver(me, address(uint160(i + 1))); + } + assertEq(transceiverRegistry.getRegisteredTransceiversStorage(me).length, maxTransceivers); + for (uint8 i = 0; i < maxTransceivers; i++) { + transceiverRegistry.enableRecvTransceiver(me, chain, address(uint160(i + 1))); + } + address[] memory chainAddrs = transceiverRegistry.getRecvTransceiversByChain(me, chain); + require(chainAddrs.length == maxTransceivers, "Wrong number of transceivers enabled on chain one"); + address[] memory chain2Addrs = transceiverRegistry.getRecvTransceiversByChain(me, wrongChain); + require(chain2Addrs.length == 0, "Wrong number of transceivers enabled on chain two"); + } } From 843b1098d46d1ab30755d4658d918b52c568569c Mon Sep 17 00:00:00 2001 From: Paul Noel Date: Thu, 24 Oct 2024 08:28:07 -0400 Subject: [PATCH 4/5] evm: ran forge fmt --- evm/test/Router.t.sol | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evm/test/Router.t.sol b/evm/test/Router.t.sol index f54831c6..0565c475 100644 --- a/evm/test/Router.t.sol +++ b/evm/test/Router.t.sol @@ -139,7 +139,7 @@ contract RouterTest is Test { router.updateAdmin(integrator, address(newerAdmin)); require(router.getAdmin(integrator) == address(newerAdmin), "failed to update admin address"); - // Cannot claim if there is no transfer in progress. + // Cannot claim if there is no transfer in progress. vm.expectRevert(abi.encodeWithSelector(Router.NoAdminTransferInProgress.selector)); router.claimAdmin(integrator); From bf7ed8701a274ddfa5f73d7dccd4314b84b2a162 Mon Sep 17 00:00:00 2001 From: Paul Noel Date: Thu, 24 Oct 2024 09:12:13 -0400 Subject: [PATCH 5/5] evm: fix coverage --- evm/test/TransceiverRegistry.t.sol | 2 ++ 1 file changed, 2 insertions(+) diff --git a/evm/test/TransceiverRegistry.t.sol b/evm/test/TransceiverRegistry.t.sol index d0644f70..67b2606f 100644 --- a/evm/test/TransceiverRegistry.t.sol +++ b/evm/test/TransceiverRegistry.t.sol @@ -213,6 +213,8 @@ contract TransceiverRegistryTest is Test { transceiverRegistry.enableRecvTransceiver(me, chainId, zeroTransceiver); vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector)); transceiverRegistry.isRecvTransceiverEnabledForChain(me, chainId, zeroTransceiver); + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, zeroChainId)); + transceiverRegistry.isRecvTransceiverEnabledForChain(me, zeroChainId, me); } function test9() public {