From 87326f7313e851a603ef430baa33823e4813d977 Mon Sep 17 00:00:00 2001 From: Anton Bukov Date: Thu, 17 Sep 2020 22:19:11 +0300 Subject: [PATCH] Add functionStaticCall and functionDelegateCall methods to Address library (#2333) Co-authored-by: Francisco Giordano --- CHANGELOG.md | 4 + contracts/mocks/AddressImpl.sol | 12 ++- contracts/mocks/CallReceiverMock.sol | 10 +++ contracts/utils/Address.sol | 58 +++++++++++++-- test/utils/Address.test.js | 106 ++++++++++++++++++++++++++- 5 files changed, 183 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a2cbb489ed1..0fb3e224a2f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## 3.3.0 (unreleased) + + * `Address`: added `functionStaticCall` and `functionDelegateCall`, similar to the existing `functionCall`. ([#2333](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2333)) + ## 3.2.0 (2020-09-10) ### New features diff --git a/contracts/mocks/AddressImpl.sol b/contracts/mocks/AddressImpl.sol index 19dcb15b440..8a80a84bc52 100644 --- a/contracts/mocks/AddressImpl.sol +++ b/contracts/mocks/AddressImpl.sol @@ -5,6 +5,8 @@ pragma solidity ^0.6.0; import "../utils/Address.sol"; contract AddressImpl { + string public sharedAnswer; + event CallReturnValue(string data); function isContract(address account) external view returns (bool) { @@ -17,13 +19,21 @@ contract AddressImpl { function functionCall(address target, bytes calldata data) external { bytes memory returnData = Address.functionCall(target, data); - emit CallReturnValue(abi.decode(returnData, (string))); } function functionCallWithValue(address target, bytes calldata data, uint256 value) external payable { bytes memory returnData = Address.functionCallWithValue(target, data, value); + emit CallReturnValue(abi.decode(returnData, (string))); + } + + function functionStaticCall(address target, bytes calldata data) external { + bytes memory returnData = Address.functionStaticCall(target, data); + emit CallReturnValue(abi.decode(returnData, (string))); + } + function functionDelegateCall(address target, bytes calldata data) external { + bytes memory returnData = Address.functionDelegateCall(target, data); emit CallReturnValue(abi.decode(returnData, (string))); } diff --git a/contracts/mocks/CallReceiverMock.sol b/contracts/mocks/CallReceiverMock.sol index 2e3297617b9..b76175214ba 100644 --- a/contracts/mocks/CallReceiverMock.sol +++ b/contracts/mocks/CallReceiverMock.sol @@ -3,6 +3,7 @@ pragma solidity ^0.6.0; contract CallReceiverMock { + string public sharedAnswer; event MockFunctionCalled(); @@ -20,6 +21,10 @@ contract CallReceiverMock { return "0x1234"; } + function mockStaticFunction() public pure returns (string memory) { + return "0x1234"; + } + function mockFunctionRevertsNoReason() public payable { revert(); } @@ -37,4 +42,9 @@ contract CallReceiverMock { _array.push(i); } } + + function mockFunctionWritesStorage() public returns (string memory) { + sharedAnswer = "42"; + return "0x1234"; + } } diff --git a/contracts/utils/Address.sol b/contracts/utils/Address.sol index e93111dd0be..a3a98d5d494 100644 --- a/contracts/utils/Address.sol +++ b/contracts/utils/Address.sol @@ -87,7 +87,7 @@ library Address { * _Available since v3.1._ */ function functionCall(address target, bytes memory data, string memory errorMessage) internal returns (bytes memory) { - return _functionCallWithValue(target, data, 0, errorMessage); + return functionCallWithValue(target, data, 0, errorMessage); } /** @@ -113,14 +113,62 @@ library Address { */ function functionCallWithValue(address target, bytes memory data, uint256 value, string memory errorMessage) internal returns (bytes memory) { require(address(this).balance >= value, "Address: insufficient balance for call"); - return _functionCallWithValue(target, data, value, errorMessage); + require(isContract(target), "Address: call to non-contract"); + + // solhint-disable-next-line avoid-low-level-calls + (bool success, bytes memory returndata) = target.call{ value: value }(data); + return _verifyCallResult(success, returndata, errorMessage); } - function _functionCallWithValue(address target, bytes memory data, uint256 weiValue, string memory errorMessage) private returns (bytes memory) { - require(isContract(target), "Address: call to non-contract"); + /** + * @dev Same as {xref-Address-functionCall-address-bytes-}[`functionCall`], + * but performing a static call. + * + * _Available since v3.3._ + */ + function functionStaticCall(address target, bytes memory data) internal view returns (bytes memory) { + return functionStaticCall(target, data, "Address: low-level static call failed"); + } + + /** + * @dev Same as {xref-Address-functionCall-address-bytes-string-}[`functionCall`], + * but performing a static call. + * + * _Available since v3.3._ + */ + function functionStaticCall(address target, bytes memory data, string memory errorMessage) internal view returns (bytes memory) { + require(isContract(target), "Address: static call to non-contract"); + + // solhint-disable-next-line avoid-low-level-calls + (bool success, bytes memory returndata) = target.staticcall(data); + return _verifyCallResult(success, returndata, errorMessage); + } + + /** + * @dev Same as {xref-Address-functionCall-address-bytes-}[`functionCall`], + * but performing a delegate call. + * + * _Available since v3.3._ + */ + function functionDelegateCall(address target, bytes memory data) internal returns (bytes memory) { + return functionDelegateCall(target, data, "Address: low-level delegate call failed"); + } + + /** + * @dev Same as {xref-Address-functionCall-address-bytes-string-}[`functionCall`], + * but performing a delegate call. + * + * _Available since v3.3._ + */ + function functionDelegateCall(address target, bytes memory data, string memory errorMessage) internal returns (bytes memory) { + require(isContract(target), "Address: delegate call to non-contract"); // solhint-disable-next-line avoid-low-level-calls - (bool success, bytes memory returndata) = target.call{ value: weiValue }(data); + (bool success, bytes memory returndata) = target.delegatecall(data); + return _verifyCallResult(success, returndata, errorMessage); + } + + function _verifyCallResult(bool success, bytes memory returndata, string memory errorMessage) private pure returns(bytes memory) { if (success) { return returndata; } else { diff --git a/test/utils/Address.test.js b/test/utils/Address.test.js index 792e5e4215a..e6333e987cb 100644 --- a/test/utils/Address.test.js +++ b/test/utils/Address.test.js @@ -143,6 +143,7 @@ describe('Address', function () { // which cause a mockFunctionOutOfGas function to crash Ganache and the // subsequent tests before running out of gas. it('reverts when the called function runs out of gas', async function () { + this.timeout(10000); if (coverage) { return this.skip(); } const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ name: 'mockFunctionOutOfGas', @@ -154,7 +155,7 @@ describe('Address', function () { this.mock.functionCall(this.contractRecipient.address, abiEncodedCall), 'Address: low-level call failed', ); - }).timeout(5000); + }); it('reverts when the called function throws', async function () { const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ @@ -285,4 +286,107 @@ describe('Address', function () { }); }); }); + + describe('functionStaticCall', function () { + beforeEach(async function () { + this.contractRecipient = await CallReceiverMock.new(); + }); + + it('calls the requested function', async function () { + const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ + name: 'mockStaticFunction', + type: 'function', + inputs: [], + }, []); + + const receipt = await this.mock.functionStaticCall(this.contractRecipient.address, abiEncodedCall); + + expectEvent(receipt, 'CallReturnValue', { data: '0x1234' }); + }); + + it('reverts on a non-static function', async function () { + const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ + name: 'mockFunction', + type: 'function', + inputs: [], + }, []); + + await expectRevert( + this.mock.functionStaticCall(this.contractRecipient.address, abiEncodedCall), + 'Address: low-level static call failed', + ); + }); + + it('bubbles up revert reason', async function () { + const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ + name: 'mockFunctionRevertsReason', + type: 'function', + inputs: [], + }, []); + + await expectRevert( + this.mock.functionStaticCall(this.contractRecipient.address, abiEncodedCall), + 'CallReceiverMock: reverting', + ); + }); + + it('reverts when address is not a contract', async function () { + const [ recipient ] = accounts; + const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ + name: 'mockFunction', + type: 'function', + inputs: [], + }, []); + await expectRevert( + this.mock.functionStaticCall(recipient, abiEncodedCall), + 'Address: static call to non-contract', + ); + }); + }); + + describe('functionDelegateCall', function () { + beforeEach(async function () { + this.contractRecipient = await CallReceiverMock.new(); + }); + + it('delegate calls the requested function', async function () { + const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ + name: 'mockFunctionWritesStorage', + type: 'function', + inputs: [], + }, []); + + const receipt = await this.mock.functionDelegateCall(this.contractRecipient.address, abiEncodedCall); + + expectEvent(receipt, 'CallReturnValue', { data: '0x1234' }); + + expect(await this.mock.sharedAnswer()).to.equal('42'); + }); + + it('bubbles up revert reason', async function () { + const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ + name: 'mockFunctionRevertsReason', + type: 'function', + inputs: [], + }, []); + + await expectRevert( + this.mock.functionDelegateCall(this.contractRecipient.address, abiEncodedCall), + 'CallReceiverMock: reverting', + ); + }); + + it('reverts when address is not a contract', async function () { + const [ recipient ] = accounts; + const abiEncodedCall = web3.eth.abi.encodeFunctionCall({ + name: 'mockFunction', + type: 'function', + inputs: [], + }, []); + await expectRevert( + this.mock.functionDelegateCall(recipient, abiEncodedCall), + 'Address: delegate call to non-contract', + ); + }); + }); });