Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bytes memory version of Math.modExp #4893

Merged
merged 19 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .changeset/shiny-poets-whisper.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
'openzeppelin-solidity': minor
---

`Math`: Add `modExp` function that exposes the `EIP-198` precompile.
`Math`: Add `modExp` function that exposes the `EIP-198` precompile. Includes `uint256` and `bytes memory` versions.
60 changes: 54 additions & 6 deletions contracts/utils/math/Math.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

pragma solidity ^0.8.20;

import {Address} from "../Address.sol";
import {Panic} from "../Panic.sol";
import {SafeCast} from "./SafeCast.sol";

Expand Down Expand Up @@ -289,11 +288,7 @@ library Math {
function modExp(uint256 b, uint256 e, uint256 m) internal view returns (uint256) {
(bool success, uint256 result) = tryModExp(b, e, m);
if (!success) {
if (m == 0) {
Panic.panic(Panic.DIVISION_BY_ZERO);
} else {
revert Address.FailedInnerCall();
}
Panic.panic(Panic.DIVISION_BY_ZERO);
}
return result;
}
Expand Down Expand Up @@ -335,6 +330,59 @@ library Math {
}
}

/**
* @dev Variant of {modExp} that supports inputs of arbitrary length.
*/
function modExp(bytes memory b, bytes memory e, bytes memory m) internal view returns (bytes memory) {
(bool success, bytes memory result) = tryModExp(b, e, m);
if (!success) {
Panic.panic(Panic.DIVISION_BY_ZERO);
}
return result;
}

/**
* @dev Variant of {tryModExp} that supports inputs of arbitrary length.
*/
function tryModExp(
bytes memory b,
bytes memory e,
bytes memory m
) internal view returns (bool success, bytes memory result) {
if (_zeroBytes(m)) return (false, result);
ernestognw marked this conversation as resolved.
Show resolved Hide resolved

uint256 mLen = m.length;

// Encode call args in result and move the free memory pointer
result = abi.encodePacked(b.length, e.length, mLen, b, e, m);

/// @solidity memory-safe-assembly
assembly {
// Write result on top of args to avoid allocating extra memory.
// | Offset | Content | Content (Hex) |
// |-----------|--------------|--------------------------------------------------------------------|
// | 0x00:0x1f | args length | 0x<.......................................20+20+20+bLen+eLen+mLen> |
// | 0x20+mLen | result | 0x<........................................................result> |
// | 0x..:0x.. | dirty bytes | 0x<............................................20+20+20+bLen+eLen> |
ernestognw marked this conversation as resolved.
Show resolved Hide resolved
success := staticcall(gas(), 0x05, add(result, 0x20), mload(result), add(result, 0x20), mLen)
// Overwrite the length.
// result.length > returndatasize() is guaranteed because returndatasize() == m.length
mstore(result, mLen)
}
ernestognw marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* @dev Returns whether the provided byte array is zero.
*/
function _zeroBytes(bytes memory byteArray) private pure returns (bool) {
for (uint256 i; i < byteArray.length; ++i) {
ernestognw marked this conversation as resolved.
Show resolved Hide resolved
if (byteArray[i] != 0) {
return false;
}
}
return true;
}

/**
* @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
* towards zero.
Expand Down
26 changes: 26 additions & 0 deletions test/utils/math/Math.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,32 @@ contract MathTest is Test {
}
}

function testModExpMemory(uint256 b, uint256 e, uint256 m) public {
if (m == 0) {
vm.expectRevert(stdError.divisionError);
}
bytes memory result = Math.modExp(abi.encodePacked(b), abi.encodePacked(e), abi.encodePacked(m));
uint256 res = abi.decode(result, (uint256));
assertLt(res, m);
assertEq(result.length, 32);
assertEq(res, _nativeModExp(b, e, m));
}

function testTryModExpMemory(uint256 b, uint256 e, uint256 m) public {
(bool success, bytes memory result) = Math.tryModExp(
abi.encodePacked(b),
abi.encodePacked(e),
abi.encodePacked(m)
);
if (success) {
uint256 res = abi.decode(result, (uint256));
assertLt(res, m);
assertEq(res, _nativeModExp(b, e, m));
} else {
assertEq(result.length, 0);
}
}

function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) {
if (m == 1) return 0;
uint256 r = 1;
Expand Down
147 changes: 119 additions & 28 deletions test/utils/math/Math.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic');
const { Rounding } = require('../../helpers/enums');
const { min, max } = require('../../helpers/math');
const { generators } = require('../../helpers/random');
const { range } = require('../../../scripts/helpers');
const { toBeHex, dataLength } = require('ethers');
const { product } = require('../../helpers/iterate');

const RoundingDown = [Rounding.Floor, Rounding.Trunc];
const RoundingUp = [Rounding.Ceil, Rounding.Expand];
Expand Down Expand Up @@ -141,24 +144,6 @@ describe('Math', function () {
});
});

describe('tryModExp', function () {
it('is correctly returning true and calculating modulus', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 50n;

expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([true, base ** exponent % modulus]);
});

it('is correctly returning false when modulus is 0', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 0n;

expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([false, 0n]);
});
});

describe('max', function () {
it('is correctly detected in both position', async function () {
await testCommutative(this.mock.$max, 1234n, 5678n, max(1234n, 5678n));
Expand Down Expand Up @@ -354,20 +339,126 @@ describe('Math', function () {
});

describe('modExp', function () {
it('is correctly calculating modulus', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 50n;
describe('with uint256 inputs', function () {
before(function () {
this.fn = '$modExp(uint256,uint256,uint256)';
});

it('is correctly calculating modulus', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 50n;

expect(await this.mock[this.fn](base, exponent, modulus)).to.equal(base ** exponent % modulus);
});

it('is correctly reverting when modulus is zero', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 0n;

await expect(this.mock[this.fn](base, exponent, modulus)).to.be.revertedWithPanic(PANIC_CODES.DIVISION_BY_ZERO);
});
});

describe('with bytes memory inputs', function () {
before(function () {
this.fn = '$modExp(bytes,bytes,bytes)';
});

it('is correctly calculating modulus', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 50n;

expect(await this.mock[this.fn](toBeHex(base), toBeHex(exponent), toBeHex(modulus))).to.equal(
toBeHex(base ** exponent % modulus),
);
});

it('is correctly reverting when modulus is zero', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 0n;

await expect(this.mock[this.fn](toBeHex(base), toBeHex(exponent), toBeHex(modulus))).to.be.revertedWithPanic(
PANIC_CODES.DIVISION_BY_ZERO,
);
});

expect(await this.mock.$modExp(base, exponent, modulus)).to.equal(base ** exponent % modulus);
for (const [baseExp, exponentExp, modulusExp] of product(range(0, 24, 4), range(0, 24, 4), range(0, 256, 64))) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ranges were selected to fit within the max BigInt.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These values are actually pretty small. The max value they take is

b: 1048576
e: 1048576
m: 6277101735386680763835789423207666416102355444464034512896

In comparaison, type(uint256).max is

115792089237316195423570985008687907853269984665640564039457584007913129639935

Copy link
Collaborator

@Amxx Amxx Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

which makes me wonder:

  • if all these fit in a uint256, why only test them by the bytes version, and not with the uint256 version
  • This looks like fuzzing, but with very specific values (only powers of 2).

Since the values used here are not bigger than then one supported by the foundry fuzzing, are way slower to run, and do not cover a bigger space ... do we need that at all ? I'd remove this part of the tests.

const b = 2n ** BigInt(baseExp);
const e = 2n ** BigInt(exponentExp);
const m = 2n ** BigInt(modulusExp);

it(`calculates b ** e % m (b=${b}) (e=${e}) (m=${m})`, async function () {
const result = await this.mock[this.fn](toBeHex(b), toBeHex(e), toBeHex(m));
expect(result).to.equal(toBeHex(b ** e % m, dataLength(toBeHex(m))));
});
}
});
});

describe('tryModExp', function () {
describe('with uint256 inputs', function () {
before(function () {
this.fn = '$tryModExp(uint256,uint256,uint256)';
});

it('is correctly returning true and calculating modulus', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 50n;

expect(await this.mock[this.fn](base, exponent, modulus)).to.deep.equal([true, base ** exponent % modulus]);
});

it('is correctly returning false when modulus is 0', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 0n;

expect(await this.mock[this.fn](base, exponent, modulus)).to.deep.equal([false, 0n]);
});
});

it('is correctly reverting when modulus is zero', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 0n;
describe('with bytes memory inputs', function () {
before(function () {
this.fn = '$tryModExp(bytes,bytes,bytes)';
});

it('is correctly returning true and calculating modulus', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 50n;

expect(await this.mock[this.fn](toBeHex(base), toBeHex(exponent), toBeHex(modulus))).to.deep.equal([
true,
toBeHex(base ** exponent % modulus),
]);
});

it('is correctly returning false when modulus is 0', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 0n;

expect(await this.mock[this.fn](toBeHex(base), toBeHex(exponent), toBeHex(modulus))).to.deep.equal([
false,
'0x',
]);
});

for (const [baseExp, exponentExp, modulusExp] of product(range(0, 24, 4), range(0, 24, 4), range(0, 256, 64))) {
const b = 2n ** BigInt(baseExp) + 1n;
const e = 2n ** BigInt(exponentExp) + 1n;
const m = 2n ** BigInt(modulusExp) + 1n;

await expect(this.mock.$modExp(base, exponent, modulus)).to.be.revertedWithPanic(PANIC_CODES.DIVISION_BY_ZERO);
it(`calculates b ** e % m (b=${b}) (e=${e}) (m=${m})`, async function () {
const result = await this.mock[this.fn](toBeHex(b), toBeHex(e), toBeHex(m));
expect(result).to.deep.equal([true, toBeHex(b ** e % m, dataLength(toBeHex(m)))]);
});
}
});
});

Expand Down
Loading