Skip to content

Commit

Permalink
test: Security Policy Plugin: Module Installation
Browse files Browse the repository at this point in the history
  • Loading branch information
ankurdubey521 committed Oct 16, 2023
1 parent 5097985 commit 4f272e9
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,5 @@ interface ISecurityPolicyPlugin {
/// set in the security policy of the smart contract wallet.
/// @param _scw The address of the smart contract wallet
/// @param _plugin The address of the plugin to be installed
function validateSecurityPolicy(
address _scw,
address _plugin
) external view;
function validateSecurityPolicy(address _scw, address _plugin) external;
}
10 changes: 8 additions & 2 deletions contracts/smart-account/modules/SecurityPolicyManagerPlugin.sol
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ contract SecurityPolicyManagerPlugin is ISecurityPolicyManagerPlugin {
address _setupContract,
bytes calldata _setupData
) external override returns (address) {
// The Setup Contract must satisfy all security policies
_validateSecurityPolicies(msg.sender, _setupContract);

// Instruct the SA to install the module and return the address
ISmartAccount sa = ISmartAccount(msg.sender);
(bool success, bytes memory returndata) = sa
Expand All @@ -37,8 +40,11 @@ contract SecurityPolicyManagerPlugin is ISecurityPolicyManagerPlugin {

address module = abi.decode(returndata, (address));

// Validate the security policies
_validateSecurityPolicies(msg.sender, module);
// If the setup contract differs from the installed module,
// Validate the module as well
if (module != _setupContract) {
_validateSecurityPolicies(msg.sender, module);
}

return module;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,42 @@ import {SecurityPolicyManagerPlugin, SENTINEL_MODULE_ADDRESS} from "modules/Secu
import {ISecurityPolicyPlugin} from "interfaces/modules/ISecurityPolicyPlugin.sol";
import {ISecurityPolicyManagerPlugin, ISecurityPolicyManagerPluginEventsErrors} from "interfaces/modules/ISecurityPolicyManagerPlugin.sol";
import {UserOperation} from "aa-core/EntryPoint.sol";
import {MultichainECDSAValidator} from "modules/MultichainECDSAValidator.sol";
import "forge-std/console2.sol";

contract TestSecurityPolicyPlugin is ISecurityPolicyPlugin {
bool public shouldRevert;
bool public wasCalled;

function validateSecurityPolicy(address, address) external view override {
require(!shouldRevert, "TestSecurityPolicyPlugin: shouldRevert");
mapping(address => bool) public blacklist;

constructor() {
blacklist[address(0x2)] = true;
}

error TestSecurityPolicyPluginError(address);

function validateSecurityPolicy(
address,
address _plugin
) external override {
wasCalled = true;
if (shouldRevert || blacklist[_plugin]) {
revert TestSecurityPolicyPluginError(address(this));
}
}

function setShouldRevert(bool _shouldRevert) external {
shouldRevert = _shouldRevert;
}
}

contract TestSetupContractBlacklistReturn {
function initForSmartAccount(address) external view returns (address) {
return address(0x2);
}
}

contract SecurityPolicyManagerPluginModuleInstallationTest is
SATestBase,
ISecurityPolicyManagerPluginEventsErrors
Expand All @@ -33,6 +55,8 @@ contract SecurityPolicyManagerPluginModuleInstallationTest is
TestSecurityPolicyPlugin p3;
TestSecurityPolicyPlugin p4;

MultichainECDSAValidator validator;

function setUp() public virtual override {
super.setUp();

Expand Down Expand Up @@ -95,5 +119,140 @@ contract SecurityPolicyManagerPluginModuleInstallationTest is
alice
);
entryPoint.handleOps(arraifyOps(op), owner.addr);

// Create MultichainValidator
validator = new MultichainECDSAValidator();
}

function testModuleInstallation() external {
bytes memory setupData = abi.encodeCall(
validator.initForSmartAccount,
(alice.addr)
);

UserOperation memory op = makeEcdsaModuleUserOp(
getSmartAccountExecuteCalldata(
address(spmp),
0,
abi.encodeCall(
ISecurityPolicyManagerPlugin.checkSetupAndEnableModule,
(address(validator), setupData)
)
),
sa,
0,
alice
);

vm.expectEmit(true, true, true, true);
emit ModuleValidated(address(sa), address(validator));

entryPoint.handleOps(arraifyOps(op), owner.addr);

assertTrue(p1.wasCalled());
assertTrue(p2.wasCalled());
assertTrue(p3.wasCalled());
assertTrue(p4.wasCalled());
assertTrue(sa.isModuleEnabled(address(validator)));
}

function testShouldRevertModuleInstallationIfSecurityPolicyIsNotSatisifedOnSetupContract()
external
{
TestSetupContractBlacklistReturn blacklistReturn = new TestSetupContractBlacklistReturn();

bytes memory setupData = abi.encodeCall(
validator.initForSmartAccount,
(alice.addr)
);

UserOperation memory op = makeEcdsaModuleUserOp(
getSmartAccountExecuteCalldata(
address(spmp),
0,
abi.encodeCall(
ISecurityPolicyManagerPlugin.checkSetupAndEnableModule,
(address(blacklistReturn), setupData)
)
),
sa,
0,
alice
);

vm.recordLogs();
entryPoint.handleOps(arraifyOps(op), owner.addr);
Vm.Log[] memory logs = vm.getRecordedLogs();
UserOperationEventData memory eventData = getUserOperationEventData(
logs
);
assertFalse(eventData.success);
UserOperationRevertReasonEventData
memory revertReasonEventData = getUserOperationRevertReasonEventData(
logs
);
assertEq(
keccak256(revertReasonEventData.revertReason),
keccak256(
abi.encodeWithSelector(
TestSecurityPolicyPlugin
.TestSecurityPolicyPluginError
.selector,
p4
)
)
);

assertFalse(sa.isModuleEnabled(address(validator)));
}

function testShouldRevertModuleInstallationIfSecurityPolicyIsNotSatisifedOnInstalledPlugin()
external
{
bytes memory setupData = abi.encodeCall(
validator.initForSmartAccount,
(alice.addr)
);

UserOperation memory op = makeEcdsaModuleUserOp(
getSmartAccountExecuteCalldata(
address(spmp),
0,
abi.encodeCall(
ISecurityPolicyManagerPlugin.checkSetupAndEnableModule,
(address(validator), setupData)
)
),
sa,
0,
alice
);

p4.setShouldRevert(true);

vm.recordLogs();
entryPoint.handleOps(arraifyOps(op), owner.addr);
Vm.Log[] memory logs = vm.getRecordedLogs();
UserOperationEventData memory eventData = getUserOperationEventData(
logs
);
assertFalse(eventData.success);
UserOperationRevertReasonEventData
memory revertReasonEventData = getUserOperationRevertReasonEventData(
logs
);
assertEq(
keccak256(revertReasonEventData.revertReason),
keccak256(
abi.encodeWithSelector(
TestSecurityPolicyPlugin
.TestSecurityPolicyPluginError
.selector,
p4
)
)
);

assertFalse(sa.isModuleEnabled(address(validator)));
}
}

0 comments on commit 4f272e9

Please sign in to comment.