diff --git a/contracts/smart-account/interfaces/modules/ISecurityPolicyManagerPlugin.sol b/contracts/smart-account/interfaces/modules/ISecurityPolicyManagerPlugin.sol index 25236c3e..aa01d570 100644 --- a/contracts/smart-account/interfaces/modules/ISecurityPolicyManagerPlugin.sol +++ b/contracts/smart-account/interfaces/modules/ISecurityPolicyManagerPlugin.sol @@ -8,11 +8,13 @@ address constant SENTINEL_MODULE_ADDRESS = address(0x1); interface ISecurityPolicyManagerPluginEventsErrors { event SecurityPolicyEnabled(address indexed scw, address indexed policy); event SecurityPolicyDisabled(address indexed scw, address indexed policy); + event ModuleValidated(address indexed scw, address indexed module); error SecurityPolicyAlreadyEnabled(address policy); error SecurityPolicyAlreadyDisabled(address policy); error InvalidSecurityPolicyAddress(address policy); error InvalidPointerAddress(address pointer); + error ModuleInstallationFailed(); error EmptyPolicyList(); } diff --git a/contracts/smart-account/modules/SecurityPolicyManagerPlugin.sol b/contracts/smart-account/modules/SecurityPolicyManagerPlugin.sol index edb12f40..1d3e8bf3 100644 --- a/contracts/smart-account/modules/SecurityPolicyManagerPlugin.sol +++ b/contracts/smart-account/modules/SecurityPolicyManagerPlugin.sol @@ -2,6 +2,8 @@ pragma solidity 0.8.17; import {ISecurityPolicyManagerPlugin, ISecurityPolicyPlugin, SENTINEL_MODULE_ADDRESS} from "contracts/smart-account/interfaces/modules/ISecurityPolicyManagerPlugin.sol"; +import {ISmartAccount} from "contracts/smart-account/interfaces/ISmartAccount.sol"; +import {Enum} from "contracts/smart-account/common/Enum.sol"; /// @title Security Policy Manager Plugin /// @author @ankurdubey521 @@ -14,10 +16,31 @@ contract SecurityPolicyManagerPlugin is ISecurityPolicyManagerPlugin { /// @inheritdoc ISecurityPolicyManagerPlugin function checkSetupAndEnableModule( - address, - bytes calldata + address _setupContract, + bytes calldata _setupData ) external override returns (address) { - revert("Not implemented"); + // Instruct the SA to install the module and return the address + ISmartAccount sa = ISmartAccount(msg.sender); + (bool success, bytes memory returndata) = sa + .execTransactionFromModuleReturnData( + msg.sender, + 0, + abi.encodeCall( + sa.setupAndEnableModule, + (_setupContract, _setupData) + ), + Enum.Operation.Call + ); + if (!success) { + revert ModuleInstallationFailed(); + } + + address module = abi.decode(returndata, (address)); + + // Validate the security policies + _validateSecurityPolicies(msg.sender, module); + + return module; } ////////////////////////// SECURITY POLICY MANAGEMENT FUNCTIONS ////////////////////////// @@ -60,7 +83,6 @@ contract SecurityPolicyManagerPlugin is ISecurityPolicyManagerPlugin { msg.sender ]; - // TODO: Verify if this reduces gas uint256 length = _policies.length; if (length == 0) { @@ -192,7 +214,12 @@ contract SecurityPolicyManagerPlugin is ISecurityPolicyManagerPlugin { address _scw, address _start, uint256 _pageSize - ) external view returns (ISecurityPolicyPlugin[] memory enabledPolicies) { + ) + external + view + override + returns (ISecurityPolicyPlugin[] memory enabledPolicies) + { enabledPolicies = new ISecurityPolicyPlugin[](_pageSize); uint256 actualEnabledPoliciesLength; @@ -226,4 +253,21 @@ contract SecurityPolicyManagerPlugin is ISecurityPolicyManagerPlugin { mstore(enabledPolicies, actualEnabledPoliciesLength) } } + + ////////////////////////// PLUGIN INSTALLATION FUNCTIONS HELPERS ////////////////////////// + + function _validateSecurityPolicies(address _sa, address _module) internal { + mapping(address => address) + storage enabledSecurityPolicies = enabledSecurityPoliciesLinkedList[ + _sa + ]; + + address current = enabledSecurityPolicies[SENTINEL_MODULE_ADDRESS]; + while (current != address(0) && current != SENTINEL_MODULE_ADDRESS) { + ISecurityPolicyPlugin(current).validateSecurityPolicy(_sa, _module); + current = enabledSecurityPolicies[current]; + } + + emit ModuleValidated(_sa, _module); + } } diff --git a/test/foundry/module/SecurityPolicy/SecurityPolicyManagerPlugin.ModuleInstallation.t.sol b/test/foundry/module/SecurityPolicy/SecurityPolicyManagerPlugin.ModuleInstallation.t.sol new file mode 100644 index 00000000..c8b7fd2a --- /dev/null +++ b/test/foundry/module/SecurityPolicy/SecurityPolicyManagerPlugin.ModuleInstallation.t.sol @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +pragma solidity 0.8.17; + +import {Vm} from "forge-std/Test.sol"; +import {SATestBase} from "../../base/SATestBase.sol"; +import {SmartAccount} from "sa/SmartAccount.sol"; +import {SecurityPolicyManagerPlugin, SENTINEL_MODULE_ADDRESS} from "modules/SecurityPolicyManagerPlugin.sol"; +import {ISecurityPolicyPlugin} from "interfaces/modules/ISecurityPolicyPlugin.sol"; +import {ISecurityPolicyManagerPlugin, ISecurityPolicyManagerPluginEventsErrors} from "interfaces/modules/ISecurityPolicyManagerPlugin.sol"; +import {UserOperation} from "aa-core/EntryPoint.sol"; +import "forge-std/console2.sol"; + +contract TestSecurityPolicyPlugin is ISecurityPolicyPlugin { + bool public shouldRevert; + + function validateSecurityPolicy(address, address) external view override { + require(!shouldRevert, "TestSecurityPolicyPlugin: shouldRevert"); + } + + function setShouldRevert(bool _shouldRevert) external { + shouldRevert = _shouldRevert; + } +} + +contract SecurityPolicyManagerPluginModuleInstallationTest is + SATestBase, + ISecurityPolicyManagerPluginEventsErrors +{ + SmartAccount sa; + SecurityPolicyManagerPlugin spmp; + TestSecurityPolicyPlugin p1; + TestSecurityPolicyPlugin p2; + TestSecurityPolicyPlugin p3; + TestSecurityPolicyPlugin p4; + + function setUp() public virtual override { + super.setUp(); + + // Deploy Smart Account with default module + uint256 smartAccountDeploymentIndex = 0; + bytes memory moduleSetupData = getEcdsaOwnershipRegistryModuleSetupData( + alice.addr + ); + sa = getSmartAccountWithModule( + address(ecdsaOwnershipRegistryModule), + moduleSetupData, + smartAccountDeploymentIndex, + "aliceSA" + ); + + // Deploy SecurityPolicyManagerPlugin + spmp = new SecurityPolicyManagerPlugin(); + vm.label(address(spmp), "SecurityPolicyManagerPlugin"); + p1 = new TestSecurityPolicyPlugin(); + vm.label(address(p1), "p1"); + p2 = new TestSecurityPolicyPlugin(); + vm.label(address(p2), "p2"); + p3 = new TestSecurityPolicyPlugin(); + vm.label(address(p3), "p3"); + p4 = new TestSecurityPolicyPlugin(); + vm.label(address(p4), "p4"); + + // Enable SecurityPolicy Manager Plugin + UserOperation memory op = makeEcdsaModuleUserOp( + getSmartAccountExecuteCalldata( + address(sa), + 0, + abi.encodeCall(sa.enableModule, address(spmp)) + ), + sa, + 0, + alice + ); + entryPoint.handleOps(arraifyOps(op), owner.addr); + + // Enable p1, p2, p3, p4 + ISecurityPolicyPlugin[] memory policies = new ISecurityPolicyPlugin[]( + 4 + ); + policies[0] = p1; + policies[1] = p2; + policies[2] = p3; + policies[3] = p4; + op = makeEcdsaModuleUserOp( + getSmartAccountExecuteCalldata( + address(spmp), + 0, + abi.encodeCall( + ISecurityPolicyManagerPlugin.enableSecurityPolicies, + (policies) + ) + ), + sa, + 0, + alice + ); + entryPoint.handleOps(arraifyOps(op), owner.addr); + } +} diff --git a/test/foundry/module/SecurityPolicy/SecurityPolicyManagerPlugin.PluginManagementTest.t.sol b/test/foundry/module/SecurityPolicy/SecurityPolicyManagerPlugin.PluginManagementTest.t.sol index 53e33a58..04c86cad 100644 --- a/test/foundry/module/SecurityPolicy/SecurityPolicyManagerPlugin.PluginManagementTest.t.sol +++ b/test/foundry/module/SecurityPolicy/SecurityPolicyManagerPlugin.PluginManagementTest.t.sol @@ -1041,4 +1041,129 @@ contract SecurityPolicyManagerPluginPluginManagementTest is assertEq(address(enabledSecurityPolicies[2]), address(p2)); assertEq(address(enabledSecurityPolicies[3]), address(p1)); } + + function testShouldNotAllowEnablingAlreadyEnabledPolicySingleEnable() + external + { + bytes memory data = getSmartAccountExecuteCalldata( + address(spmp), + 0, + abi.encodeCall( + ISecurityPolicyManagerPlugin.enableSecurityPolicy, + (p1) + ) + ); + + UserOperation memory op = makeEcdsaModuleUserOp(data, sa, 0, alice); + entryPoint.handleOps(arraifyOps(op), owner.addr); + + op = makeEcdsaModuleUserOp(data, sa, 0, alice); + + vm.recordLogs(); + entryPoint.handleOps(arraifyOps(op), owner.addr); + Vm.Log[] memory logs = vm.getRecordedLogs(); + UserOperationEventData memory eventData = getUserOperationEventData( + logs + ); + assertFalse(eventData.success); + UserOperationRevertReasonEventData + memory revertReasonEventData = getUserOperationRevertReasonEventData( + logs + ); + assertEq( + keccak256(revertReasonEventData.revertReason), + keccak256( + abi.encodeWithSelector( + SecurityPolicyAlreadyEnabled.selector, + p1 + ) + ) + ); + } + + function testShouldNotAllowEnablingAlreadyEnabledPolicySMultiEnable() + external + { + bytes memory data = getSmartAccountExecuteCalldata( + address(spmp), + 0, + abi.encodeCall( + ISecurityPolicyManagerPlugin.enableSecurityPolicy, + (p1) + ) + ); + UserOperation memory op = makeEcdsaModuleUserOp(data, sa, 0, alice); + entryPoint.handleOps(arraifyOps(op), owner.addr); + + ISecurityPolicyPlugin[] memory policies = new ISecurityPolicyPlugin[]( + 1 + ); + policies[0] = p1; + data = getSmartAccountExecuteCalldata( + address(spmp), + 0, + abi.encodeCall( + ISecurityPolicyManagerPlugin.enableSecurityPolicies, + (policies) + ) + ); + op = makeEcdsaModuleUserOp(data, sa, 0, alice); + + vm.recordLogs(); + entryPoint.handleOps(arraifyOps(op), owner.addr); + Vm.Log[] memory logs = vm.getRecordedLogs(); + UserOperationEventData memory eventData = getUserOperationEventData( + logs + ); + assertFalse(eventData.success); + UserOperationRevertReasonEventData + memory revertReasonEventData = getUserOperationRevertReasonEventData( + logs + ); + assertEq( + keccak256(revertReasonEventData.revertReason), + keccak256( + abi.encodeWithSelector( + SecurityPolicyAlreadyEnabled.selector, + p1 + ) + ) + ); + } + + function testShouldAllowDisablingAlreadyEnabledPolicySingleDisable() + external + { + bytes memory data = getSmartAccountExecuteCalldata( + address(spmp), + 0, + abi.encodeCall( + ISecurityPolicyManagerPlugin.disableSecurityPolicy, + (p1, p2) + ) + ); + + UserOperation memory op = makeEcdsaModuleUserOp(data, sa, 0, alice); + + vm.recordLogs(); + entryPoint.handleOps(arraifyOps(op), owner.addr); + Vm.Log[] memory logs = vm.getRecordedLogs(); + UserOperationEventData memory eventData = getUserOperationEventData( + logs + ); + assertFalse(eventData.success); + UserOperationRevertReasonEventData + memory revertReasonEventData = getUserOperationRevertReasonEventData( + logs + ); + assertEq( + keccak256(revertReasonEventData.revertReason), + keccak256( + abi.encodeWithSelector( + SecurityPolicyAlreadyDisabled.selector, + p1 + ) + ) + ); + } }