Skip to content

Commit

Permalink
[SlashIndicator, RoninValidatorSet]: DelegateGuard: restrict delegate…
Browse files Browse the repository at this point in the history
…call to abuse pcu contracts (#234)
  • Loading branch information
TuDo1403 authored Jun 12, 2023
1 parent ede2cee commit d553583
Show file tree
Hide file tree
Showing 12 changed files with 357 additions and 5 deletions.
5 changes: 5 additions & 0 deletions contracts/libraries/Guards.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;

import { DelegateGuard } from "./guards/DelegateGuard.sol";
import { ReentrancyGuard } from "./guards/ReentrancyGuard.sol";
89 changes: 89 additions & 0 deletions contracts/libraries/guards/DelegateGuard.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.17;

import { GuardCore } from "./GuardCore.sol";
import { LibBitSlot, BitSlot } from "../udvts/LibBitSlot.sol";

/**
* @title DelegateGuard
* @dev A contract that provides delegatecall restriction functionality.
*/
abstract contract DelegateGuard is GuardCore {
using LibBitSlot for BitSlot;

/**
* @dev Error message for already initialized contract.
*/
error ErrAlreadyInitialized();

/**
* @dev Error message for restricted call type.
*/
error ErrCallTypeRestricted();

modifier onlyDelegate() virtual {
_checkDelegate(true);
_;
}

modifier nonDelegate() virtual {
_checkDelegate(false);
_;
}

/**
* @dev Initializes the origin address and turns on the initialized bit.
*
* Note:
* - Must be called during initialization if using an Upgradeable Proxy.
* - Must be called in the constructor if using immutable contracts.
*
*/
function _initOriginAddress() internal virtual {
bytes32 _slot = _guardSlot();
uint8 _initializedBitPos = _getBitPos(GuardType.INITIALIZE_BIT);
// Check if the contract is already initialized
if (BitSlot.wrap(_slot).get(_initializedBitPos)) revert ErrAlreadyInitialized();

uint8 _originBitPos = _getBitPos(GuardType.ORIGIN_ADDRESS);
assembly {
// Load the full slot
let _data := sload(_slot)
// Shift << address(this) to `_originBitPos`
let _originMask := shl(_originBitPos, address())
// Shift << initialized bit to `initializedBitPos`
let _initializedMask := shl(_initializedBitPos, 1)
// _data = _data | _originMask | _initializedMask
sstore(_slot, or(_data, or(_originMask, _initializedMask)))
}
}

/**
* @dev Internal function to restrict delegatecall based on the current context and the `_mustDelegate` flag.
* @notice When `_mustDelegate` is true, it enforces that `address(this) != originAddress`, otherwise reverts.
* When `_mustDelegate` is false, it enforces that `address(this) == originAddress`, otherwise reverts.
*/
function _checkDelegate(bool _mustDelegate) private view {
bytes32 _slot = _guardSlot();
uint8 _originBitPos = _getBitPos(GuardType.ORIGIN_ADDRESS);
bytes4 _callTypeRestricted = ErrCallTypeRestricted.selector;

assembly {
let _data := sload(_slot)
// Shift >> address(this) to `_originBitPos`
let _origin := shr(_originBitPos, _data)
// Dirty bytes removal
_origin := and(_origin, 0xffffffffffffffffffffffffffffffffffffffff)

// Check the current context and restrict based on the `_mustDelegate` flag
// If the current context differs from the origin address and `_mustDelegate` flag is false, restrict only normal calls and revert
// If the current context differs from the origin address and `_mustDelegate` flag is true, restrict only delegatecall and pass
// If the current context equals the origin address and `_mustDelegate` flag is false, restrict only normal calls and pass
// If the current context equals the origin address and `_mustDelegate` flag is true, restrict only delegatecall and revert
if iszero(xor(eq(_origin, address()), _mustDelegate)) {
mstore(0x00, _callTypeRestricted)
revert(0x1c, 0x04)
}
}
}
}
38 changes: 38 additions & 0 deletions contracts/libraries/guards/GuardCore.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;

abstract contract GuardCore {
/**
* | GuardType | Bit position | Bit length | Explanation |
* |----------------|--------------|------------|-----------------------------------------------------------------|
* | PAUSER_BIT | 0 | 1 | Flag indicating whether the contract is paused or not. |
* | REENTRANCY_BIT | 1 | 1 | Flag indicating whether the reentrancy guard is on or off. |
* | INITIALIZE_BIT | 2 | 1 | Flag indicating whether the origin guard is initialized or not. |
* | ORIGIN_ADDRESS | 3 | 160 | Bits to store the origin address for delegate call guard. |
*/
enum GuardType {
PAUSER_BIT,
REENTRANCY_BIT,
INITIALIZE_BIT,
ORIGIN_ADDRESS
}

/**
* @dev Returns the bit position of a guard type.
*
* Requirement:
* - Each guard type must have a different bit position.
* - The position of the guard types must not collide with each other.
*
*/
function _getBitPos(GuardType _type) internal pure virtual returns (uint8 _pos) {
assembly {
_pos := _type
}
}

/**
* @dev Returns the guard slot to store all of the guard types.
*/
function _guardSlot() internal pure virtual returns (bytes32);
}
50 changes: 50 additions & 0 deletions contracts/libraries/guards/ReentrancyGuard.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.17;

import { GuardCore } from "./GuardCore.sol";
import { LibBitSlot, BitSlot } from "../udvts/LibBitSlot.sol";

/**
* @title ReentrancyGuard
* @dev A contract that provides protection against reentrancy attacks.
*/
abstract contract ReentrancyGuard is GuardCore {
using LibBitSlot for BitSlot;

/**
* @dev Error message for reentrant function call.
*/
error ErrNonReentrancy();

/**
* @dev Modifier to prevent reentrancy attacks.
*/
modifier nonReentrant() virtual {
_beforeEnter();
_;
_afterEnter();
}

/**
* @dev Internal function to check and set the reentrancy bit before entering the protected function.
* @dev Throws an error if the reentrancy bit is already set.
*/
function _beforeEnter() internal virtual {
BitSlot _slot = BitSlot.wrap(_guardSlot());
uint8 _reentrancyBitPos = _getBitPos(GuardType.REENTRANCY_BIT);

// Check if the reentrancy bit is already set, and revert if it is
if (_slot.get(_reentrancyBitPos)) revert ErrNonReentrancy();

// Set the reentrancy bit to true
_slot.set({ pos: _reentrancyBitPos, bitOn: true });
}

/**
* @dev Internal function to reset the reentrancy bit after exiting the protected function.
*/
function _afterEnter() internal virtual {
// Reset the reentrancy bit to false
BitSlot.wrap(_guardSlot()).set({ pos: _getBitPos(GuardType.REENTRANCY_BIT), bitOn: false });
}
}
37 changes: 37 additions & 0 deletions contracts/libraries/udvts/LibBitSlot.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.17;

type BitSlot is bytes32;

library LibBitSlot {
/**
* @dev Returns whether the bit at position `pos`-th of a slot `slot` is set or not.
*/
function get(BitSlot slot, uint8 pos) internal view returns (bool bitOn) {
assembly {
bitOn := and(shr(pos, sload(slot)), 1)
}
}

/**
* @dev Sets the bit at position `pos`-th of a slot `slot` to be on or off.
*/
function set(
BitSlot slot,
uint8 pos,
bool bitOn
) internal {
assembly {
let value := sload(slot)
let shift := and(pos, 0xff)
// Isolate the bit at `shift`.
let bit := and(shr(shift, value), 1)
// Xor it with `_bitOn`. Results in 1 if both are different, else 0.
bit := xor(bit, bitOn)
// Shifts the bit back. Then, xor with value.
// Only the bit at `shift` will be flipped if they differ.
// Every other bit will stay the same, as they are xor'ed with zeroes.
sstore(slot, xor(value, shl(shift, bit)))
}
}
}
63 changes: 63 additions & 0 deletions contracts/mocks/MockProxyDelegate.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// SPDX-License-Identifier: MIT

pragma solidity ^0.8.9;

library ErrorHandler {
function handleRevert(bool status, bytes memory returnOrRevertData) internal pure {
assembly {
if iszero(status) {
let revertLength := mload(returnOrRevertData)
if iszero(iszero(revertLength)) {
// Start of revert data bytes. The 0x20 offset is always the same.
revert(add(returnOrRevertData, 0x20), revertLength)
}

/// @dev equivalent to revert ExecutionFailed()
mstore(0x00, 0xacfdb444)
revert(0x1c, 0x04)
}
}
}
}

contract MockProxyDelegate {
error ExecutionFailed();

using ErrorHandler for bool;

/// @dev value is equal to keccak256("MockProxyDelegate.slot") - 1
bytes32 public constant SLOT = 0xd7e37bb02f38a001dc6dc288698347e84408fb1c25d8015413a6203a79da346f;

constructor(
address target_,
address admin_,
address implement_
) {
assembly {
sstore(SLOT, target_)
/// @dev value is equal to keccak256("eip1967.proxy.admin") - 1
sstore(0xb53127684a568b3173ae13b9f8a6016e243e63b6e8ee1178d6a717850b5d6103, admin_)
/// @dev value is equal to keccak256("eip1967.proxy.implementation") - 1
sstore(0x360894a13ba1a3210667c828492db98dca3e2076cc3735a920a3ca505d382bbc, implement_)
}
}

function slashDoubleSign(
address,
bytes calldata,
bytes calldata
) external {
address target;
assembly {
target := sload(SLOT)
}
(bool success, bytes memory returnOrRevertData) = target.delegatecall(
abi.encodeWithSelector(
/// @dev value is equal to bytes4(keccak256(functionDelegate(bytes)))
0x4bb5274a,
msg.data
)
);
success.handleRevert(returnOrRevertData);
}
}
5 changes: 3 additions & 2 deletions contracts/ronin/slash-indicator/SlashDoubleSign.sol
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ pragma solidity ^0.8.9;
import "../../interfaces/slash-indicator/ISlashDoubleSign.sol";
import "../../precompile-usages/PCUValidateDoubleSign.sol";
import "../../extensions/collections/HasValidatorContract.sol";
import { DelegateGuard } from "../../libraries/Guards.sol";

abstract contract SlashDoubleSign is ISlashDoubleSign, HasValidatorContract, PCUValidateDoubleSign {
abstract contract SlashDoubleSign is ISlashDoubleSign, HasValidatorContract, PCUValidateDoubleSign, DelegateGuard {
/// @dev The amount of RON to slash double sign.
uint256 internal _slashDoubleSignAmount;
/// @dev The block number that the punished validator will be jailed until, due to double signing.
Expand All @@ -31,7 +32,7 @@ abstract contract SlashDoubleSign is ISlashDoubleSign, HasValidatorContract, PCU
address _consensusAddr,
bytes calldata _header1,
bytes calldata _header2
) external override onlyAdmin {
) external override nonDelegate onlyAdmin {
bytes32 _header1Checksum = keccak256(_header1);
bytes32 _header2Checksum = keccak256(_header2);

Expand Down
14 changes: 14 additions & 0 deletions contracts/ronin/slash-indicator/SlashIndicator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ contract SlashIndicator is
_creditScoreConfigs[2],
_creditScoreConfigs[3]
);
_initOriginAddress();
}

function initializeV2() external reinitializer(2) {
_initOriginAddress();
}

/**
Expand Down Expand Up @@ -128,4 +133,13 @@ contract SlashIndicator is
_validatorContract.isBlockProducer(_addr) &&
!_maintenanceContract.checkMaintained(_addr, block.number);
}

/**
* @dev Returns the guard slot for the contract.
* @return The guard slot value.
*/
function _guardSlot() internal pure override returns (bytes32) {
/// @dev value is equal to keccak256("@ronin.dpos.slash-indicator.SlashIndicator.guard.slot") - 1
return 0x155bb4ed7f6246483709b1cbe37e46dd176a81efb7ed6f314e99eb0dd07a7fa7;
}
}
6 changes: 4 additions & 2 deletions contracts/ronin/validator/CoinbaseExecution.sol
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import "../../precompile-usages/PCUPickValidatorSet.sol";
import "./storage-fragments/CommonStorage.sol";
import "./CandidateManager.sol";
import "./EmergencyExit.sol";
import { DelegateGuard } from "../../libraries/Guards.sol";

abstract contract CoinbaseExecution is
ICoinbaseExecution,
Expand All @@ -25,7 +26,8 @@ abstract contract CoinbaseExecution is
HasBridgeTrackingContract,
HasMaintenanceContract,
HasSlashIndicatorContract,
EmergencyExit
EmergencyExit,
DelegateGuard
{
using EnumFlags for EnumFlags.ValidatorFlag;

Expand Down Expand Up @@ -92,7 +94,7 @@ abstract contract CoinbaseExecution is
/**
* @inheritdoc ICoinbaseExecution
*/
function wrapUpEpoch() external payable virtual override onlyCoinbase whenEpochEnding oncePerEpoch {
function wrapUpEpoch() external payable virtual override onlyCoinbase whenEpochEnding oncePerEpoch nonDelegate {
uint256 _newPeriod = _computePeriod(block.timestamp);
bool _periodEnding = _isPeriodEnding(_newPeriod);

Expand Down
10 changes: 10 additions & 0 deletions contracts/ronin/validator/RoninValidatorSet.sol
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ contract RoninValidatorSet is Initializable, CoinbaseExecution, SlashingExecutio
_setEmergencyExitLockedAmount(__emergencyExitConfigs[0]);
_setEmergencyExpiryDuration(__emergencyExitConfigs[1]);
_numberOfBlocksInEpoch = __numberOfBlocksInEpoch;
_initOriginAddress();
}

function initializeV2() external reinitializer(2) {
_initOriginAddress();
}

/**
Expand All @@ -73,4 +78,9 @@ contract RoninValidatorSet is Initializable, CoinbaseExecution, SlashingExecutio
{
return super._bridgeOperatorOf(_consensusAddr);
}

function _guardSlot() internal pure override returns (bytes32) {
/// @dev value is equal to keccak256("@ronin.dpos.validator.RoninValidator.guard.slot") - 1
return 0x597ac532456d03c8d4b2b6e1822ad69ce16d14b57d469414341a60a75681a805;
}
}
Loading

0 comments on commit d553583

Please sign in to comment.