diff --git a/monkey/common/configuration/__init__.py b/monkey/common/configuration/__init__.py index 2246e27f22a..c7fefc11beb 100644 --- a/monkey/common/configuration/__init__.py +++ b/monkey/common/configuration/__init__.py @@ -1,8 +1,16 @@ -from .agent_configuration import ( - AgentConfiguration, - AgentConfigurationSchema, +from .agent_configuration import AgentConfiguration, InvalidConfigurationError +from .agent_sub_configurations import ( + CustomPBAConfiguration, + PluginConfiguration, + ScanTargetConfiguration, + ICMPScanConfiguration, + TCPScanConfiguration, + NetworkScanConfiguration, + ExploitationOptionsConfiguration, + ExploiterConfiguration, + ExploitationConfiguration, + PropagationConfiguration, ) from .default_agent_configuration import ( - DEFAULT_AGENT_CONFIGURATION_JSON, - build_default_agent_configuration, + DEFAULT_AGENT_CONFIGURATION, ) diff --git a/monkey/common/configuration/agent_configuration.py b/monkey/common/configuration/agent_configuration.py index 15b338fe028..097b8638244 100644 --- a/monkey/common/configuration/agent_configuration.py +++ b/monkey/common/configuration/agent_configuration.py @@ -1,7 +1,10 @@ +from __future__ import annotations + from dataclasses import dataclass -from typing import List +from typing import Any, List, Mapping -from marshmallow import Schema, fields, post_load +from marshmallow import Schema, fields +from marshmallow.exceptions import MarshmallowError from .agent_sub_configuration_schemas import ( CustomPBAConfigurationSchema, @@ -15,6 +18,15 @@ ) +class InvalidConfigurationError(Exception): + pass + + +INVALID_CONFIGURATION_ERROR_MESSAGE = ( + "Cannot construct an AgentConfiguration object with the supplied, invalid data:" +) + + @dataclass(frozen=True) class AgentConfiguration: keep_tunnel_open_time: float @@ -24,6 +36,57 @@ class AgentConfiguration: payloads: List[PluginConfiguration] propagation: PropagationConfiguration + def __post_init__(self): + # This will raise an exception if the object is invalid. Calling this in __post__init() + # makes it impossible to construct an invalid object + try: + AgentConfigurationSchema().dump(self) + except Exception as err: + raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}") + + @staticmethod + def from_mapping(config_mapping: Mapping[str, Any]) -> AgentConfiguration: + """ + Construct an AgentConfiguration from a Mapping + + :param config_mapping: A Mapping that represents an AgentConfiguration + :return: An AgentConfiguration + :raises: InvalidConfigurationError if the provided Mapping does not represent a valid + AgentConfiguration + """ + + try: + config_dict = AgentConfigurationSchema().load(config_mapping) + return AgentConfiguration(**config_dict) + except MarshmallowError as err: + raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}") + + @staticmethod + def from_json(config_json: str) -> AgentConfiguration: + """ + Construct an AgentConfiguration from a JSON string + + :param config_json: A JSON string that represents an AgentConfiguration + :return: An AgentConfiguration + :raises: InvalidConfigurationError if the provided JSON does not represent a valid + AgentConfiguration + """ + try: + config_dict = AgentConfigurationSchema().loads(config_json) + return AgentConfiguration(**config_dict) + except MarshmallowError as err: + raise InvalidConfigurationError(f"{INVALID_CONFIGURATION_ERROR_MESSAGE}: {err}") + + @staticmethod + def to_json(config: AgentConfiguration) -> str: + """ + Serialize an AgentConfiguration to JSON + + :param config: An AgentConfiguration + :return: A JSON string representing the AgentConfiguration + """ + return AgentConfigurationSchema().dumps(config) + class AgentConfigurationSchema(Schema): keep_tunnel_open_time = fields.Float() @@ -32,7 +95,3 @@ class AgentConfigurationSchema(Schema): credential_collectors = fields.List(fields.Nested(PluginConfigurationSchema)) payloads = fields.List(fields.Nested(PluginConfigurationSchema)) propagation = fields.Nested(PropagationConfigurationSchema) - - @post_load - def _make_agent_configuration(self, data, **kwargs): - return AgentConfiguration(**data) diff --git a/monkey/common/configuration/default_agent_configuration.py b/monkey/common/configuration/default_agent_configuration.py index c831695664c..4eb8496a2fc 100644 --- a/monkey/common/configuration/default_agent_configuration.py +++ b/monkey/common/configuration/default_agent_configuration.py @@ -1,208 +1,115 @@ -from . import AgentConfiguration, AgentConfigurationSchema +from . import AgentConfiguration +from .agent_sub_configurations import ( + CustomPBAConfiguration, + ExploitationConfiguration, + ExploitationOptionsConfiguration, + ExploiterConfiguration, + ICMPScanConfiguration, + NetworkScanConfiguration, + PluginConfiguration, + PropagationConfiguration, + ScanTargetConfiguration, + TCPScanConfiguration, +) -DEFAULT_AGENT_CONFIGURATION_JSON = """{ - "keep_tunnel_open_time": 30, - "post_breach_actions": [ - { - "name": "CommunicateAsBackdoorUser", - "options": {} - }, - { - "name": "ModifyShellStartupFiles", - "options": {} - }, - { - "name": "HiddenFiles", - "options": {} - }, - { - "name": "TrapCommand", - "options": {} - }, - { - "name": "ChangeSetuidSetgid", - "options": {} - }, - { - "name": "ScheduleJobs", - "options": {} - }, - { - "name": "Timestomping", - "options": {} - }, - { - "name": "AccountDiscovery", - "options": {} - }, - { - "name": "ProcessListCollection", - "options": {} - } - ], - "credential_collectors": [ - { - "name": "MimikatzCollector", - "options": {} - }, - { - "name": "SSHCollector", - "options": {} - } - ], - "payloads": [ - { - "name": "ransomware", - "options": { - "encryption": { - "enabled": true, - "directories": { - "linux_target_dir": "", - "windows_target_dir": "" - } - }, - "other_behaviors": { - "readme": true - } - } - } - ], - "custom_pbas": { - "linux_command": "", - "linux_filename": "", - "windows_command": "", - "windows_filename": "" - }, - "propagation": { - "maximum_depth": 2, - "network_scan": { - "tcp": { - "timeout": 3000, - "ports": [ - 22, - 80, - 135, - 443, - 445, - 2222, - 3306, - 3389, - 5985, - 5986, - 7001, - 8008, - 8080, - 8088, - 8983, - 9200, - 9600 - ] - }, - "icmp": { - "timeout": 1000 - }, - "fingerprinters": [ - { - "name": "elastic", - "options": {} - }, - { - "name": "http", - "options": { - "http_ports": [ - 80, - 443, - 7001, - 8008, - 8080, - 8983, - 9200, - 9600 - ] - } - }, - { - "name": "mssql", - "options": {} - }, - { - "name": "smb", - "options": {} - }, - { - "name": "ssh", - "options": {} - } - ], - "targets": { - "blocked_ips": [], - "inaccessible_subnets": [], - "local_network_scan": true, - "subnets": [] - } - }, - "exploitation": { - "options": { - "http_ports": [ - 80, - 443, - 7001, - 8008, - 8080, - 8983, - 9200, - 9600 - ] - }, - "brute_force": [ - { - "name": "MSSQLExploiter", - "options": {} +PBAS = [ + "CommunicateAsBackdoorUser", + "ModifyShellStartupFiles", + "HiddenFiles", + "TrapCommand", + "ChangeSetuidSetgid", + "ScheduleJobs", + "Timestomping", + "AccountDiscovery", + "ProcessListCollection", +] - }, - { - "name": "PowerShellExploiter", - "options": {} +CREDENTIAL_COLLECTORS = ["MimikatzCollector", "SSHCollector"] - }, - { - "name": "SSHExploiter", - "options": {} +PBA_CONFIGURATION = [PluginConfiguration(pba, {}) for pba in PBAS] +CREDENTIAL_COLLECTOR_CONFIGURATION = [ + PluginConfiguration(collector, {}) for collector in CREDENTIAL_COLLECTORS +] - }, - { - "name": "SmbExploiter", - "options": { - "smb_download_timeout": 30 - } +RANSOMWARE_OPTIONS = { + "encryption": { + "enabled": True, + "directories": {"linux_target_dir": "", "windows_target_dir": ""}, + }, + "other_behaviors": {"readme": True}, +} - }, - { - "name": "WmiExploiter", - "options": { - "smb_download_timeout": 30 - } +PAYLOAD_CONFIGURATION = [PluginConfiguration("ransomware", RANSOMWARE_OPTIONS)] - } - ], - "vulnerability": [ - { - "name": "HadoopExploiter", - "options": {} +CUSTOM_PBA_CONFIGURATION = CustomPBAConfiguration( + linux_command="", linux_filename="", windows_command="", windows_filename="" +) - }, - { - "name": "Log4ShellExploiter", - "options": {} +TCP_PORTS = [ + 22, + 80, + 135, + 443, + 445, + 2222, + 3306, + 3389, + 5985, + 5986, + 7001, + 8008, + 8080, + 8088, + 8983, + 9200, + 9600, +] - } - ] - } - } - } -""" +TCP_SCAN_CONFIGURATION = TCPScanConfiguration(timeout=3.0, ports=TCP_PORTS) +ICMP_CONFIGURATION = ICMPScanConfiguration(timeout=1.0) +HTTP_PORTS = [80, 443, 7001, 8008, 8080, 8983, 9200, 9600] +FINGERPRINTERS = [ + PluginConfiguration("elastic", {}), + PluginConfiguration("http", {"http_ports": HTTP_PORTS}), + PluginConfiguration("mssql", {}), + PluginConfiguration("smb", {}), + PluginConfiguration("ssh", {}), +] +SCAN_TARGET_CONFIGURATION = ScanTargetConfiguration([], [], True, []) +NETWORK_SCAN_CONFIGURATION = NetworkScanConfiguration( + TCP_SCAN_CONFIGURATION, ICMP_CONFIGURATION, FINGERPRINTERS, SCAN_TARGET_CONFIGURATION +) -def build_default_agent_configuration() -> AgentConfiguration: - schema = AgentConfigurationSchema() - return schema.loads(DEFAULT_AGENT_CONFIGURATION_JSON) +EXPLOITATION_OPTIONS_CONFIGURATION = ExploitationOptionsConfiguration(HTTP_PORTS) +BRUTE_FORCE_EXPLOITERS = [ + ExploiterConfiguration("MSSQLExploiter", {}), + ExploiterConfiguration("PowerShellExploiter", {}), + ExploiterConfiguration("SSHExploiter", {}), + ExploiterConfiguration("SmbExploiter", {"smb_download_timeout": 30}), + ExploiterConfiguration("WmiExploiter", {"smb_download_timeout": 30}), +] + +VULNERABILITY_EXPLOITERS = [ + ExploiterConfiguration("Log4ShellExploiter", {}), + ExploiterConfiguration("HadoopExploiter", {}), +] + +EXPLOITATION_CONFIGURATION = ExploitationConfiguration( + EXPLOITATION_OPTIONS_CONFIGURATION, BRUTE_FORCE_EXPLOITERS, VULNERABILITY_EXPLOITERS +) + +PROPAGATION_CONFIGURATION = PropagationConfiguration( + maximum_depth=2, + network_scan=NETWORK_SCAN_CONFIGURATION, + exploitation=EXPLOITATION_CONFIGURATION, +) + +DEFAULT_AGENT_CONFIGURATION = AgentConfiguration( + keep_tunnel_open_time=30, + custom_pbas=CUSTOM_PBA_CONFIGURATION, + post_breach_actions=PBA_CONFIGURATION, + credential_collectors=CREDENTIAL_COLLECTOR_CONFIGURATION, + payloads=PAYLOAD_CONFIGURATION, + propagation=PROPAGATION_CONFIGURATION, +) diff --git a/monkey/common/operating_systems.py b/monkey/common/operating_systems.py index 67f67da8133..2ac2f64b359 100644 --- a/monkey/common/operating_systems.py +++ b/monkey/common/operating_systems.py @@ -2,5 +2,12 @@ class OperatingSystems(Enum): + """ + An Enum representing all supported operating systems + + This Enum represents all operating systems that Infection Monkey supports. The value of each + member is the member's name in all lower-case characters. + """ + LINUX = "linux" WINDOWS = "windows" diff --git a/monkey/common/utils/exceptions.py b/monkey/common/utils/exceptions.py index 5935145e730..31cebca3299 100644 --- a/monkey/common/utils/exceptions.py +++ b/monkey/common/utils/exceptions.py @@ -38,5 +38,7 @@ class DomainControllerNameFetchError(FailedExploitationError): """Raise on failed attempt to extract domain controller's name""" +# TODO: This has been replaced by common.configuration.InvalidConfigurationError. Use that error +# instead and remove this one. class InvalidConfigurationError(Exception): """Raise when configuration is invalid""" diff --git a/monkey/infection_monkey/exploit/caching_agent_repository.py b/monkey/infection_monkey/exploit/caching_agent_repository.py index 0f86bbd9d65..7d358025844 100644 --- a/monkey/infection_monkey/exploit/caching_agent_repository.py +++ b/monkey/infection_monkey/exploit/caching_agent_repository.py @@ -5,6 +5,7 @@ import requests +from common import OperatingSystems from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT from . import IAgentRepository @@ -22,18 +23,22 @@ def __init__(self, island_url: str, proxies: Mapping[str, str]): self._proxies = proxies self._lock = threading.Lock() - def get_agent_binary(self, os: str, architecture: str = None) -> io.BytesIO: + def get_agent_binary( + self, operating_system: OperatingSystems, architecture: str = None + ) -> io.BytesIO: # If multiple calls to get_agent_binary() are made simultaneously before the result of # _download_binary_from_island() is cached, then multiple requests will be sent to the # island. Add a mutex in front of the call to _download_agent_binary_from_island() so # that only one request per OS will be sent to the island. with self._lock: - return io.BytesIO(self._download_binary_from_island(os)) + return io.BytesIO(self._download_binary_from_island(operating_system)) @lru_cache(maxsize=None) - def _download_binary_from_island(self, os: str) -> bytes: + def _download_binary_from_island(self, operating_system: OperatingSystems) -> bytes: + os_name = operating_system.value + response = requests.get( # noqa: DUO123 - f"{self._island_url}/api/agent-binaries/{os}", + f"{self._island_url}/api/agent-binaries/{os_name}", verify=False, proxies=self._proxies, timeout=MEDIUM_REQUEST_TIMEOUT, diff --git a/monkey/infection_monkey/exploit/i_agent_repository.py b/monkey/infection_monkey/exploit/i_agent_repository.py index d825772a0d3..308cf541832 100644 --- a/monkey/infection_monkey/exploit/i_agent_repository.py +++ b/monkey/infection_monkey/exploit/i_agent_repository.py @@ -1,6 +1,8 @@ import abc import io +from common import OperatingSystems + # TODO: The Island also has an IAgentRepository with a totally different interface. At the moment, # the Island and Agent have different needs, but at some point we should unify these. @@ -13,12 +15,13 @@ class IAgentRepository(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def get_agent_binary(self, os: str, architecture: str = None) -> io.BytesIO: + def get_agent_binary( + self, operating_system: OperatingSystems, architecture: str = None + ) -> io.BytesIO: """ Retrieve the appropriate agent binary from the repository. - :param str os: The name of the operating system on which the agent binary will run - :param str architecture: Reserved + :param operating_system: The name of the operating system on which the agent binary will run + :param architecture: Reserved :return: A file-like object for the requested agent binary - :rtype: io.BytesIO """ pass diff --git a/monkey/infection_monkey/exploit/log4shell.py b/monkey/infection_monkey/exploit/log4shell.py index ffbcdd0d62a..cab4ed54854 100644 --- a/monkey/infection_monkey/exploit/log4shell.py +++ b/monkey/infection_monkey/exploit/log4shell.py @@ -129,7 +129,7 @@ def _build_command(self, path: PurePath, http_path) -> str: } def _build_java_class(self, exploit_command: str) -> bytes: - if OperatingSystems.LINUX in self.host.os["type"]: + if OperatingSystems.LINUX == self.host.os["type"]: return build_exploit_bytecode(exploit_command, LINUX_EXPLOIT_TEMPLATE_PATH) else: return build_exploit_bytecode(exploit_command, WINDOWS_EXPLOIT_TEMPLATE_PATH) diff --git a/monkey/infection_monkey/exploit/powershell.py b/monkey/infection_monkey/exploit/powershell.py index 6ef72963e8b..c9991b0b405 100644 --- a/monkey/infection_monkey/exploit/powershell.py +++ b/monkey/infection_monkey/exploit/powershell.py @@ -2,6 +2,7 @@ from pathlib import Path, PurePath from typing import List, Optional +from common import OperatingSystems from infection_monkey.exploit.HostExploiter import HostExploiter from infection_monkey.exploit.powershell_utils.auth_options import AuthOptions, get_auth_options from infection_monkey.exploit.powershell_utils.credentials import ( @@ -162,7 +163,7 @@ def _copy_monkey_binary_to_victim(self, monkey_path_on_victim: PurePath): temp_monkey_binary_filepath.unlink() def _create_local_agent_file(self, binary_path): - agent_binary_bytes = self.agent_repository.get_agent_binary("windows") + agent_binary_bytes = self.agent_repository.get_agent_binary(OperatingSystems.WINDOWS) with open(binary_path, "wb") as f: f.write(agent_binary_bytes.getvalue()) diff --git a/monkey/infection_monkey/exploit/tools/http_tools.py b/monkey/infection_monkey/exploit/tools/http_tools.py index 92696a5b797..b27f6cf6fb1 100644 --- a/monkey/infection_monkey/exploit/tools/http_tools.py +++ b/monkey/infection_monkey/exploit/tools/http_tools.py @@ -57,6 +57,6 @@ def create_locked_transfer( httpd.start() lock.acquire() return ( - "http://%s:%s/%s" % (local_ip, local_port, urllib.parse.quote(host.os["type"])), + f"http://{local_ip}:{local_port}/{urllib.parse.quote(host.os['type'].value)}", httpd, ) diff --git a/monkey/infection_monkey/i_control_channel.py b/monkey/infection_monkey/i_control_channel.py index 33539417c42..b903080188e 100644 --- a/monkey/infection_monkey/i_control_channel.py +++ b/monkey/infection_monkey/i_control_channel.py @@ -1,5 +1,7 @@ import abc +from common.configuration import AgentConfiguration + class IControlChannel(metaclass=abc.ABCMeta): @abc.abstractmethod @@ -11,10 +13,10 @@ def should_agent_stop(self) -> bool: """ @abc.abstractmethod - def get_config(self) -> dict: + def get_config(self) -> AgentConfiguration: """ - :return: A dictionary containing Agent Configuration - :rtype: dict + :return: An AgentConfiguration object + :rtype: AgentConfiguration """ pass diff --git a/monkey/infection_monkey/master/automated_master.py b/monkey/infection_monkey/master/automated_master.py index f79fd5f1220..ad73b50f437 100644 --- a/monkey/infection_monkey/master/automated_master.py +++ b/monkey/infection_monkey/master/automated_master.py @@ -1,8 +1,9 @@ import logging import threading import time -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple +from typing import Any, Callable, Iterable, List, Optional +from common.configuration import CustomPBAConfiguration, PluginConfiguration from common.utils import Timer from infection_monkey.credential_store import ICredentialsStore from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError @@ -13,7 +14,7 @@ from infection_monkey.telemetry.credentials_telem import CredentialsTelem from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger from infection_monkey.telemetry.post_breach_telem import PostBreachTelem -from infection_monkey.utils.propagation import should_propagate +from infection_monkey.utils.propagation import maximum_depth_reached from infection_monkey.utils.threading import create_daemon_thread, interruptible_iter from . import Exploiter, IPScanner, Propagator @@ -111,7 +112,7 @@ def _wait_for_master_stop_condition(self): time.sleep(CHECK_FOR_TERMINATE_INTERVAL_SEC) @staticmethod - def _try_communicate_with_island(fn: Callable[[], Any], max_tries: int): + def _try_communicate_with_island(fn: Callable[[], Any], max_tries: int) -> Any: tries = 0 while tries < max_tries: try: @@ -141,7 +142,7 @@ def _run_simulation(self): try: config = AutomatedMaster._try_communicate_with_island( self._control_channel.get_config, CHECK_FOR_CONFIG_COUNT - )["config"] + ) except IslandCommunicationError as e: logger.error(f"An error occurred while fetching configuration: {e}") return @@ -150,7 +151,7 @@ def _run_simulation(self): target=self._run_plugins, name="CredentialCollectorThread", args=( - config["credential_collectors"], + config.credential_collectors, "credential collector", self._collect_credentials, ), @@ -158,7 +159,7 @@ def _run_simulation(self): pba_thread = create_daemon_thread( target=self._run_pbas, name="PBAThread", - args=(config["post_breach_actions"].items(), self._run_pba, config["custom_pbas"]), + args=(config.post_breach_actions, self._run_pba, config.custom_pbas), ) credential_collector_thread.start() @@ -173,52 +174,56 @@ def _run_simulation(self): current_depth = self._current_depth if self._current_depth is not None else 0 logger.info(f"Current depth is {current_depth}") - if should_propagate(self._control_channel.get_config(), self._current_depth): - self._propagator.propagate(config["propagation"], current_depth, self._stop) + if maximum_depth_reached(config.propagation.maximum_depth, self._current_depth): + self._propagator.propagate(config.propagation, current_depth, self._stop) else: logger.info("Skipping propagation: maximum depth reached") payload_thread = create_daemon_thread( target=self._run_plugins, name="PayloadThread", - args=(config["payloads"].items(), "payload", self._run_payload), + args=(config.payloads, "payload", self._run_payload), ) payload_thread.start() payload_thread.join() pba_thread.join() - def _collect_credentials(self, collector: str): - credentials = self._puppet.run_credential_collector(collector, {}) + def _collect_credentials(self, collector: PluginConfiguration): + credentials = self._puppet.run_credential_collector(collector.name, collector.options) if credentials: self._telemetry_messenger.send_telemetry(CredentialsTelem(credentials)) else: logger.debug(f"No credentials were collected by {collector}") - def _run_pba(self, pba: Tuple[str, Dict]): - name = pba[0] - options = pba[1] - - for pba_data in self._puppet.run_pba(name, options): + def _run_pba(self, pba: PluginConfiguration): + for pba_data in self._puppet.run_pba(pba.name, pba.options): self._telemetry_messenger.send_telemetry(PostBreachTelem(pba_data)) - def _run_payload(self, payload: Tuple[str, Dict]): - name = payload[0] - options = payload[1] - - self._puppet.run_payload(name, options, self._stop) + def _run_payload(self, payload: PluginConfiguration): + self._puppet.run_payload(payload.name, payload.options, self._stop) def _run_pbas( - self, plugins: Iterable[Any], callback: Callable[[Any], None], custom_pba_options: Mapping + self, + plugins: Iterable[PluginConfiguration], + callback: Callable[[Any], None], + custom_pba_options: CustomPBAConfiguration, ): self._run_plugins(plugins, "post-breach action", callback) if custom_pba_is_enabled(custom_pba_options): - self._run_plugins([("CustomPBA", custom_pba_options)], "post-breach action", callback) + self._run_plugins( + [PluginConfiguration(name="CustomPBA", options=custom_pba_options.__dict__)], + "post-breach action", + callback, + ) def _run_plugins( - self, plugins: Iterable[Any], plugin_type: str, callback: Callable[[Any], None] + self, + plugins: Iterable[PluginConfiguration], + plugin_type: str, + callback: Callable[[Any], None], ): logger.info(f"Running {plugin_type}s") logger.debug(f"Found {len(plugins)} {plugin_type}(s) to run") diff --git a/monkey/infection_monkey/master/control_channel.py b/monkey/infection_monkey/master/control_channel.py index c93af43e046..d68f42bda72 100644 --- a/monkey/infection_monkey/master/control_channel.py +++ b/monkey/infection_monkey/master/control_channel.py @@ -1,10 +1,12 @@ import json import logging +from pprint import pformat from typing import Mapping import requests from common.common_consts.timeouts import SHORT_REQUEST_TIMEOUT +from common.configuration import AgentConfiguration from infection_monkey.custom_types import PropagationCredentials from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError @@ -47,7 +49,7 @@ def should_agent_stop(self) -> bool: ) as e: raise IslandCommunicationError(e) - def get_config(self) -> dict: + def get_config(self) -> AgentConfiguration: try: response = requests.get( # noqa: DUO123 f"https://{self._control_channel_server}/api/agent", @@ -57,7 +59,10 @@ def get_config(self) -> dict: ) response.raise_for_status() - return json.loads(response.content.decode()) + config_dict = json.loads(response.text)["config"] + logger.debug(f"Received configuration:\n{pformat(json.loads(response.text))}") + + return AgentConfiguration.from_mapping(config_dict) except ( json.JSONDecodeError, requests.exceptions.ConnectionError, diff --git a/monkey/infection_monkey/master/exploiter.py b/monkey/infection_monkey/master/exploiter.py index 35a212482fb..53665da38ff 100644 --- a/monkey/infection_monkey/master/exploiter.py +++ b/monkey/infection_monkey/master/exploiter.py @@ -5,9 +5,13 @@ from itertools import chain from queue import Queue from threading import Event -from typing import Callable, Dict, List, Mapping +from typing import Callable, Dict, Sequence from common import OperatingSystems +from common.configuration.agent_sub_configurations import ( + ExploitationConfiguration, + ExploiterConfiguration, +) from infection_monkey.custom_types import PropagationCredentials from infection_monkey.i_puppet import ExploiterResultData, IPuppet from infection_monkey.model import VictimHost @@ -46,7 +50,7 @@ def __init__( def exploit_hosts( self, - exploiter_config: Dict, + exploiter_config: ExploitationConfiguration, hosts_to_exploit: Queue, current_depth: int, results_callback: Callback, @@ -56,7 +60,7 @@ def exploit_hosts( exploiters_to_run = self._process_exploiter_config(exploiter_config) logger.debug( "Agent is configured to run the following exploiters in order: " - f"{', '.join([e['name'] for e in exploiters_to_run])}" + f"{', '.join([e.name for e in exploiters_to_run])}" ) exploit_args = ( @@ -75,24 +79,26 @@ def exploit_hosts( ) @staticmethod - def _process_exploiter_config(exploiter_config: Mapping) -> List[Mapping]: + def _process_exploiter_config( + exploiter_config: ExploitationConfiguration, + ) -> Sequence[ExploiterConfiguration]: # Run vulnerability exploiters before brute force exploiters to minimize the effect of # account lockout due to invalid credentials - ordered_exploiters = chain( - exploiter_config["vulnerability"], exploiter_config["brute_force"] - ) + ordered_exploiters = chain(exploiter_config.vulnerability, exploiter_config.brute_force) exploiters_to_run = list(deepcopy(ordered_exploiters)) + extended_exploiters = [] for exploiter in exploiters_to_run: # This order allows exploiter-specific options to # override general options for all exploiters. - exploiter["options"] = {**exploiter_config["options"], **exploiter["options"]} + options = {**exploiter_config.options.__dict__, **exploiter.options} + extended_exploiters.append(ExploiterConfiguration(exploiter.name, options)) - return exploiters_to_run + return extended_exploiters def _exploit_hosts_on_queue( self, - exploiters_to_run: List[Dict], + exploiters_to_run: Sequence[ExploiterConfiguration], hosts_to_exploit: Queue, current_depth: int, results_callback: Callback, @@ -119,7 +125,7 @@ def _exploit_hosts_on_queue( def _run_all_exploiters( self, - exploiters_to_run: List[Dict], + exploiters_to_run: Sequence[ExploiterConfiguration], victim_host: VictimHost, current_depth: int, results_callback: Callback, @@ -127,11 +133,10 @@ def _run_all_exploiters( ): for exploiter in interruptible_iter(exploiters_to_run, stop): - exploiter_name = exploiter["name"] + exploiter_name = exploiter.name victim_os = victim_host.os.get("type") # We want to try all exploiters if the victim's OS is unknown - print(victim_os) if victim_os is not None and victim_os not in SUPPORTED_OS[exploiter_name]: logger.debug( f"Skipping {exploiter_name} because it does not support " @@ -140,7 +145,7 @@ def _run_all_exploiters( continue exploiter_results = self._run_exploiter( - exploiter_name, exploiter["options"], victim_host, current_depth, stop + exploiter_name, exploiter.options, victim_host, current_depth, stop ) results_callback(exploiter_name, victim_host, exploiter_results) diff --git a/monkey/infection_monkey/master/ip_scanner.py b/monkey/infection_monkey/master/ip_scanner.py index 8c0ea5caaa5..071abe95a8a 100644 --- a/monkey/infection_monkey/master/ip_scanner.py +++ b/monkey/infection_monkey/master/ip_scanner.py @@ -3,8 +3,13 @@ import threading from queue import Queue from threading import Event -from typing import Any, Callable, Dict, List +from typing import Callable, Dict, Sequence +from common.configuration.agent_sub_configurations import ( + NetworkScanConfiguration, + PluginConfiguration, + ScanTargetConfiguration, +) from infection_monkey.i_puppet import ( FingerprintData, IPuppet, @@ -29,8 +34,8 @@ def __init__(self, puppet: IPuppet, num_workers: int): def scan( self, - addresses_to_scan: List[NetworkAddress], - options: Dict, + addresses_to_scan: Sequence[NetworkAddress], + options: ScanTargetConfiguration, results_callback: Callback, stop: Event, ): @@ -49,12 +54,16 @@ def scan( ) def _scan_addresses( - self, addresses: Queue, options: Dict, results_callback: Callback, stop: Event + self, + addresses: Queue, + options: NetworkScanConfiguration, + results_callback: Callback, + stop: Event, ): - logger.debug(f"Starting scan thread -- Thread ID: {threading.get_ident()}") - icmp_timeout = options["icmp"]["timeout_ms"] / 1000 - tcp_timeout = options["tcp"]["timeout_ms"] / 1000 - tcp_ports = options["tcp"]["ports"] + logger.debug(f"Starting scan .read -- Thread ID: {threading.get_ident()}") + icmp_timeout = options.icmp.timeout + tcp_timeout = options.tcp.timeout + tcp_ports = options.tcp.ports try: while not stop.is_set(): @@ -66,7 +75,7 @@ def _scan_addresses( fingerprint_data = {} if IPScanner.port_scan_found_open_port(port_scan_data): - fingerprinters = options["fingerprinters"] + fingerprinters = options.fingerprinters fingerprint_data = self._run_fingerprinters( address.ip, fingerprinters, ping_scan_data, port_scan_data, stop ) @@ -90,7 +99,7 @@ def port_scan_found_open_port(port_scan_data: Dict[int, PortScanData]): def _run_fingerprinters( self, ip: str, - fingerprinters: List[Dict[str, Any]], + fingerprinters: Sequence[PluginConfiguration], ping_scan_data: PingScanData, port_scan_data: Dict[int, PortScanData], stop: Event, @@ -98,8 +107,8 @@ def _run_fingerprinters( fingerprint_data = {} for f in interruptible_iter(fingerprinters, stop): - fingerprint_data[f["name"]] = self._puppet.fingerprint( - f["name"], ip, ping_scan_data, port_scan_data, f["options"] + fingerprint_data[f.name] = self._puppet.fingerprint( + f.name, ip, ping_scan_data, port_scan_data, f.options ) return fingerprint_data diff --git a/monkey/infection_monkey/master/option_parsing.py b/monkey/infection_monkey/master/option_parsing.py index c35bf6303c0..c9262c5c989 100644 --- a/monkey/infection_monkey/master/option_parsing.py +++ b/monkey/infection_monkey/master/option_parsing.py @@ -1,13 +1,12 @@ -from typing import Dict - +from common.configuration import CustomPBAConfiguration from infection_monkey.utils.environment import is_windows_os -def custom_pba_is_enabled(pba_options: Dict) -> bool: +def custom_pba_is_enabled(pba_options: CustomPBAConfiguration) -> bool: if not is_windows_os(): - if pba_options["linux_command"]: + if pba_options.linux_command: return True else: - if pba_options["windows_command"]: + if pba_options.windows_command: return True return False diff --git a/monkey/infection_monkey/master/propagator.py b/monkey/infection_monkey/master/propagator.py index be4d6caf292..64edae2ec86 100644 --- a/monkey/infection_monkey/master/propagator.py +++ b/monkey/infection_monkey/master/propagator.py @@ -1,8 +1,13 @@ import logging from queue import Queue from threading import Event -from typing import Dict, List +from typing import List +from common.configuration import ( + NetworkScanConfiguration, + PropagationConfiguration, + ScanTargetConfiguration, +) from infection_monkey.i_puppet import ( ExploiterResultData, FingerprintData, @@ -39,14 +44,18 @@ def __init__( self._local_network_interfaces = local_network_interfaces self._hosts_to_exploit = None - def propagate(self, propagation_config: Dict, current_depth: int, stop: Event): + def propagate( + self, propagation_config: PropagationConfiguration, current_depth: int, stop: Event + ): logger.info("Attempting to propagate") network_scan_completed = Event() self._hosts_to_exploit = Queue() scan_thread = create_daemon_thread( - target=self._scan_network, name="PropagatorScanThread", args=(propagation_config, stop) + target=self._scan_network, + name="PropagatorScanThread", + args=(propagation_config.network_scan, stop), ) exploit_thread = create_daemon_thread( target=self._exploit_hosts, @@ -64,22 +73,21 @@ def propagate(self, propagation_config: Dict, current_depth: int, stop: Event): logger.info("Finished attempting to propagate") - def _scan_network(self, propagation_config: Dict, stop: Event): + def _scan_network(self, scan_config: NetworkScanConfiguration, stop: Event): logger.info("Starting network scan") - target_config = propagation_config["targets"] - scan_config = propagation_config["network_scan"] - - addresses_to_scan = self._compile_scan_target_list(target_config) + addresses_to_scan = self._compile_scan_target_list(scan_config.targets) self._ip_scanner.scan(addresses_to_scan, scan_config, self._process_scan_results, stop) logger.info("Finished network scan") - def _compile_scan_target_list(self, target_config: Dict) -> List[NetworkAddress]: - ranges_to_scan = target_config["subnet_scan_list"] - inaccessible_subnets = target_config["inaccessible_subnets"] - blocklisted_ips = target_config["blocked_ips"] - enable_local_network_scan = target_config["local_network_scan"] + def _compile_scan_target_list( + self, target_config: ScanTargetConfiguration + ) -> List[NetworkAddress]: + ranges_to_scan = target_config.subnets + inaccessible_subnets = target_config.inaccessible_subnets + blocklisted_ips = target_config.blocked_ips + enable_local_network_scan = target_config.local_network_scan return compile_scan_target_list( self._local_network_interfaces, @@ -134,14 +142,14 @@ def _process_fingerprinter_results(victim_host: VictimHost, fingerprint_data: Fi def _exploit_hosts( self, - propagation_config: Dict, + propagation_config: PropagationConfiguration, current_depth: int, network_scan_completed: Event, stop: Event, ): logger.info("Exploiting victims") - exploiter_config = propagation_config["exploiters"] + exploiter_config = propagation_config.exploitation self._exploiter.exploit_hosts( exploiter_config, self._hosts_to_exploit, diff --git a/monkey/infection_monkey/model/host.py b/monkey/infection_monkey/model/host.py index 6a1295e58f2..167bef246e3 100644 --- a/monkey/infection_monkey/model/host.py +++ b/monkey/infection_monkey/model/host.py @@ -17,7 +17,7 @@ def as_dict(self): return self.__dict__ def is_windows(self) -> bool: - return OperatingSystems.WINDOWS in self.os["type"] + return OperatingSystems.WINDOWS == self.os["type"] def __hash__(self): return hash(self.ip_addr) diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index 466de666499..749803c7b98 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -78,7 +78,7 @@ remove_monkey_dir, ) from infection_monkey.utils.monkey_log_path import get_agent_log_path -from infection_monkey.utils.propagation import should_propagate +from infection_monkey.utils.propagation import maximum_depth_reached from infection_monkey.utils.signal_handler import register_signal_handlers, reset_signal_handlers logger = logging.getLogger(__name__) @@ -173,9 +173,11 @@ def _setup(self): config = control_channel.get_config() self._monkey_inbound_tunnel = self._control_client.create_control_tunnel( - config["config"]["keep_tunnel_open_time"] + config.keep_tunnel_open_time ) - if self._monkey_inbound_tunnel and should_propagate(config, self._current_depth): + if self._monkey_inbound_tunnel and maximum_depth_reached( + config.propagation.maximum_depth, self._current_depth + ): self._inbound_tunnel_opened = True self._monkey_inbound_tunnel.start() diff --git a/monkey/infection_monkey/post_breach/actions/clear_command_history.py b/monkey/infection_monkey/post_breach/actions/clear_command_history.py index 2641051cca0..7a5a350f5a6 100644 --- a/monkey/infection_monkey/post_breach/actions/clear_command_history.py +++ b/monkey/infection_monkey/post_breach/actions/clear_command_history.py @@ -16,7 +16,7 @@ def __init__(self, telemetry_messenger: ITelemetryMessenger): super().__init__(telemetry_messenger, name=POST_BREACH_CLEAR_CMD_HISTORY) def run(self, options: Dict) -> Iterable[PostBreachData]: - results = [pba.run() for pba in self.clear_command_history_pba_list()] + results = [pba.run(options) for pba in self.clear_command_history_pba_list()] if results: # `self.command` is empty here self.pba_data.append(PostBreachData(self.name, self.command, results)) @@ -53,7 +53,7 @@ def __init__(self, linux_cmds): linux_cmd=linux_cmds, ) - def run(self) -> Tuple[str, bool]: + def run(self, options: Dict) -> Tuple[str, bool]: if self.command: try: output = subprocess.check_output( # noqa: DUO116 diff --git a/monkey/infection_monkey/transport/http.py b/monkey/infection_monkey/transport/http.py index 63aaa0b36ae..7bcbcd87d0a 100644 --- a/monkey/infection_monkey/transport/http.py +++ b/monkey/infection_monkey/transport/http.py @@ -62,7 +62,7 @@ def do_HEAD(self): f.close() def send_head(self): - if self.path != "/" + urllib.parse.quote(self.victim_os): + if self.path != "/" + urllib.parse.quote(self.victim_os.value): self.send_error(500, "") return None, 0, 0 try: diff --git a/monkey/infection_monkey/utils/propagation.py b/monkey/infection_monkey/utils/propagation.py index 004bafdd2cc..2da2e7beeef 100644 --- a/monkey/infection_monkey/utils/propagation.py +++ b/monkey/infection_monkey/utils/propagation.py @@ -1,2 +1,2 @@ -def should_propagate(config: dict, current_depth: int) -> bool: - return config["config"]["depth"] > current_depth +def maximum_depth_reached(maximum_depth: int, current_depth: int) -> bool: + return maximum_depth > current_depth diff --git a/monkey/monkey_island/cc/repository/file_agent_configuration_repository.py b/monkey/monkey_island/cc/repository/file_agent_configuration_repository.py index b63ee817c0b..312e3921e55 100644 --- a/monkey/monkey_island/cc/repository/file_agent_configuration_repository.py +++ b/monkey/monkey_island/cc/repository/file_agent_configuration_repository.py @@ -1,6 +1,6 @@ import io -from common.configuration import AgentConfiguration, AgentConfigurationSchema +from common.configuration import AgentConfiguration from monkey_island.cc import repository from monkey_island.cc.repository import ( IAgentConfigurationRepository, @@ -17,21 +17,20 @@ def __init__( ): self._default_agent_configuration = default_agent_configuration self._file_repository = file_repository - self._schema = AgentConfigurationSchema() def get_configuration(self) -> AgentConfiguration: try: with self._file_repository.open_file(AGENT_CONFIGURATION_FILE_NAME) as f: configuration_json = f.read().decode() - return self._schema.loads(configuration_json) + return AgentConfiguration.from_json(configuration_json) except repository.FileNotFoundError: return self._default_agent_configuration except Exception as err: raise RetrievalError(f"Error retrieving the agent configuration: {err}") def store_configuration(self, agent_configuration: AgentConfiguration): - configuration_json = self._schema.dumps(agent_configuration) + configuration_json = AgentConfiguration.to_json(agent_configuration) self._file_repository.save_file( AGENT_CONFIGURATION_FILE_NAME, io.BytesIO(configuration_json.encode()) diff --git a/monkey/monkey_island/cc/resources/agent_configuration.py b/monkey/monkey_island/cc/resources/agent_configuration.py index f0ad73cb8f1..1bf3115640d 100644 --- a/monkey/monkey_island/cc/resources/agent_configuration.py +++ b/monkey/monkey_island/cc/resources/agent_configuration.py @@ -1,9 +1,9 @@ import json -import marshmallow from flask import make_response, request -from common.configuration.agent_configuration import AgentConfigurationSchema +from common.configuration.agent_configuration import AgentConfiguration as AgentConfigurationObject +from common.configuration.agent_configuration import InvalidConfigurationError from monkey_island.cc.repository import IAgentConfigurationRepository from monkey_island.cc.resources.AbstractResource import AbstractResource from monkey_island.cc.resources.request_authentication import jwt_required @@ -14,22 +14,21 @@ class AgentConfiguration(AbstractResource): def __init__(self, agent_configuration_repository: IAgentConfigurationRepository): self._agent_configuration_repository = agent_configuration_repository - self._schema = AgentConfigurationSchema() @jwt_required def get(self): configuration = self._agent_configuration_repository.get_configuration() - configuration_json = self._schema.dumps(configuration) + configuration_json = AgentConfigurationObject.to_json(configuration) return make_response(configuration_json, 200) @jwt_required def post(self): try: - configuration_object = self._schema.loads(request.data) + configuration_object = AgentConfigurationObject.from_json(request.data) self._agent_configuration_repository.store_configuration(configuration_object) return make_response({}, 200) - except (marshmallow.exceptions.ValidationError, json.JSONDecodeError) as err: + except (InvalidConfigurationError, json.JSONDecodeError) as err: return make_response( {"message": f"Invalid configuration supplied: {err}"}, 400, diff --git a/monkey/monkey_island/cc/services/config.py b/monkey/monkey_island/cc/services/config.py index 69db1c2e182..6d14505c34f 100644 --- a/monkey/monkey_island/cc/services/config.py +++ b/monkey/monkey_island/cc/services/config.py @@ -179,6 +179,7 @@ def ssh_add_keys(public_key, private_key): should_encrypt=True, ) + @staticmethod def _filter_none_values(data): if isinstance(data, dict): return { @@ -460,7 +461,7 @@ def _format_tcp_scan_from_flat_config(config: Dict) -> Dict[str, Any]: formatted_tcp_scan_config = {} - formatted_tcp_scan_config["timeout"] = config[flat_tcp_timeout_field] + formatted_tcp_scan_config["timeout"] = config[flat_tcp_timeout_field] / 1000 ports = ConfigService._union_tcp_and_http_ports( config[flat_tcp_ports_field], config[flat_http_ports_field] @@ -484,7 +485,7 @@ def _format_icmp_scan_from_flat_config(config: Dict) -> Dict[str, Any]: flat_ping_timeout_field = "ping_scan_timeout" formatted_icmp_scan_config = {} - formatted_icmp_scan_config["timeout"] = config[flat_ping_timeout_field] + formatted_icmp_scan_config["timeout"] = config[flat_ping_timeout_field] / 1000 config.pop(flat_ping_timeout_field, None) diff --git a/monkey/monkey_island/cc/services/initialize.py b/monkey/monkey_island/cc/services/initialize.py index 52343bbf58f..922c3654bc1 100644 --- a/monkey/monkey_island/cc/services/initialize.py +++ b/monkey/monkey_island/cc/services/initialize.py @@ -3,7 +3,7 @@ from common import DIContainer from common.aws import AWSInstance -from common.configuration import AgentConfiguration, build_default_agent_configuration +from common.configuration import DEFAULT_AGENT_CONFIGURATION, AgentConfiguration from common.utils.file_utils import get_binary_io_sha256_hash from monkey_island.cc.repository import ( AgentBinaryRepository, @@ -32,7 +32,7 @@ def initialize_services(data_dir: Path) -> DIContainer: container.register_convention(Path, "data_dir", data_dir) container.register_convention( - AgentConfiguration, "default_agent_configuration", build_default_agent_configuration() + AgentConfiguration, "default_agent_configuration", DEFAULT_AGENT_CONFIGURATION ) container.register_instance(AWSInstance, AWSInstance()) diff --git a/monkey/tests/monkey_island/in_memory_agent_configuration_repository.py b/monkey/tests/monkey_island/in_memory_agent_configuration_repository.py index e737d645c86..e9bcbae6279 100644 --- a/monkey/tests/monkey_island/in_memory_agent_configuration_repository.py +++ b/monkey/tests/monkey_island/in_memory_agent_configuration_repository.py @@ -1,12 +1,12 @@ from tests.common.example_agent_configuration import AGENT_CONFIGURATION -from common.configuration.agent_configuration import AgentConfigurationSchema +from common.configuration.agent_configuration import AgentConfiguration from monkey_island.cc.repository import IAgentConfigurationRepository class InMemoryAgentConfigurationRepository(IAgentConfigurationRepository): def __init__(self): - self._configuration = AgentConfigurationSchema().load(AGENT_CONFIGURATION) + self._configuration = AgentConfiguration.from_mapping(AGENT_CONFIGURATION) def get_configuration(self): return self._configuration diff --git a/monkey/tests/unit_tests/common/test_agent_configuration.py b/monkey/tests/unit_tests/common/configuration/test_agent_configuration.py similarity index 75% rename from monkey/tests/unit_tests/common/test_agent_configuration.py rename to monkey/tests/unit_tests/common/configuration/test_agent_configuration.py index 7ea80cfc549..e06a4cf3e3e 100644 --- a/monkey/tests/unit_tests/common/test_agent_configuration.py +++ b/monkey/tests/unit_tests/common/configuration/test_agent_configuration.py @@ -1,3 +1,7 @@ +import json +from copy import deepcopy + +import pytest from tests.common.example_agent_configuration import ( AGENT_CONFIGURATION, BLOCKED_IPS, @@ -23,11 +27,8 @@ WINDOWS_FILENAME, ) -from common.configuration import ( - DEFAULT_AGENT_CONFIGURATION_JSON, - AgentConfiguration, - AgentConfigurationSchema, -) +from common.configuration import AgentConfiguration, InvalidConfigurationError +from common.configuration.agent_configuration import AgentConfigurationSchema from common.configuration.agent_sub_configuration_schemas import ( CustomPBAConfigurationSchema, ExploitationConfigurationSchema, @@ -157,10 +158,8 @@ def test_propagation_configuration(): def test_agent_configuration(): - schema = AgentConfigurationSchema() - - config = schema.load(AGENT_CONFIGURATION) - config_dict = schema.dump(config) + config = AgentConfiguration.from_mapping(AGENT_CONFIGURATION) + config_json = AgentConfiguration.to_json(config) assert isinstance(config, AgentConfiguration) assert config.keep_tunnel_open_time == 30 @@ -169,12 +168,53 @@ def test_agent_configuration(): assert isinstance(config.credential_collectors[0], PluginConfiguration) assert isinstance(config.payloads[0], PluginConfiguration) assert isinstance(config.propagation, PropagationConfiguration) - assert config_dict == AGENT_CONFIGURATION + assert json.loads(config_json) == AGENT_CONFIGURATION + +def test_incorrect_type(): + valid_config = AgentConfiguration.from_mapping(AGENT_CONFIGURATION) + with pytest.raises(InvalidConfigurationError): + valid_config_dict = valid_config.__dict__ + valid_config_dict["keep_tunnel_open_time"] = "not_a_float" + AgentConfiguration(**valid_config_dict) -def test_default_agent_configuration(): + +def test_from_dict(): schema = AgentConfigurationSchema() + dict_ = deepcopy(AGENT_CONFIGURATION) + + config = AgentConfiguration.from_mapping(dict_) + + assert schema.dump(config) == dict_ + + +def test_from_dict__invalid_data(): + dict_ = deepcopy(AGENT_CONFIGURATION) + dict_["payloads"] = "payloads" - config = schema.loads(DEFAULT_AGENT_CONFIGURATION_JSON) + with pytest.raises(InvalidConfigurationError): + AgentConfiguration.from_mapping(dict_) + + +def test_from_json(): + schema = AgentConfigurationSchema() + dict_ = deepcopy(AGENT_CONFIGURATION) + + config = AgentConfiguration.from_json(json.dumps(dict_)) assert isinstance(config, AgentConfiguration) + assert schema.dump(config) == dict_ + + +def test_from_json__invalid_data(): + invalid_dict = deepcopy(AGENT_CONFIGURATION) + invalid_dict["payloads"] = "payloads" + + with pytest.raises(InvalidConfigurationError): + AgentConfiguration.from_json(json.dumps(invalid_dict)) + + +def test_to_json(): + config = deepcopy(AGENT_CONFIGURATION) + + assert json.loads(AgentConfiguration.to_json(config)) == AGENT_CONFIGURATION diff --git a/monkey/tests/unit_tests/conftest.py b/monkey/tests/unit_tests/conftest.py index 3634e52b9bf..51528ba00af 100644 --- a/monkey/tests/unit_tests/conftest.py +++ b/monkey/tests/unit_tests/conftest.py @@ -9,7 +9,7 @@ MONKEY_BASE_PATH = str(Path(__file__).parent.parent.parent) sys.path.insert(0, MONKEY_BASE_PATH) -from common.configuration import AgentConfiguration, build_default_agent_configuration # noqa: E402 +from common.configuration import DEFAULT_AGENT_CONFIGURATION, AgentConfiguration # noqa: E402 @pytest.fixture(scope="session") @@ -60,4 +60,4 @@ def inner(filename: str) -> Dict: @pytest.fixture def default_agent_configuration() -> AgentConfiguration: - return build_default_agent_configuration() + return DEFAULT_AGENT_CONFIGURATION diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_exploiter.py b/monkey/tests/unit_tests/infection_monkey/master/test_exploiter.py index f5d9c2da428..cc7e497b6c3 100644 --- a/monkey/tests/unit_tests/infection_monkey/master/test_exploiter.py +++ b/monkey/tests/unit_tests/infection_monkey/master/test_exploiter.py @@ -8,6 +8,10 @@ from tests.unit_tests.infection_monkey.master.mock_puppet import MockPuppet from common import OperatingSystems +from common.configuration.agent_sub_configurations import ( + ExploitationConfiguration, + ExploiterConfiguration, +) from infection_monkey.master import Exploiter from infection_monkey.model import VictimHost @@ -35,18 +39,18 @@ def callback(): @pytest.fixture -def exploiter_config(): - return { - "options": {"dropper_path_linux": "/tmp/monkey"}, - "brute_force": [ - {"name": "MSSQLExploiter", "options": {"timeout": 10}}, - {"name": "SSHExploiter", "options": {}}, - {"name": "WmiExploiter", "options": {"timeout": 10}}, - ], - "vulnerability": [ - {"name": "ZerologonExploiter", "options": {}}, - ], - } +def exploiter_config(default_agent_configuration): + brute_force = [ + ExploiterConfiguration(name="MSSQLExploiter", options={"timeout": 10}), + ExploiterConfiguration(name="SSHExploiter", options={}), + ExploiterConfiguration(name="WmiExploiter", options={"timeout": 10}), + ] + vulnerability = [ExploiterConfiguration(name="ZerologonExploiter", options={})] + return ExploitationConfiguration( + options=default_agent_configuration.propagation.exploitation.options, + brute_force=brute_force, + vulnerability=vulnerability, + ) @pytest.fixture diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_ip_scanner.py b/monkey/tests/unit_tests/infection_monkey/master/test_ip_scanner.py index 9fafdaee2e4..bf026510fc1 100644 --- a/monkey/tests/unit_tests/infection_monkey/master/test_ip_scanner.py +++ b/monkey/tests/unit_tests/infection_monkey/master/test_ip_scanner.py @@ -5,6 +5,12 @@ import pytest from tests.unit_tests.infection_monkey.master.mock_puppet import MockPuppet +from common.configuration.agent_sub_configurations import ( + ICMPScanConfiguration, + NetworkScanConfiguration, + PluginConfiguration, + TCPScanConfiguration, +) from infection_monkey.i_puppet import FingerprintData, PortScanData, PortStatus from infection_monkey.master import IPScanner from infection_monkey.network import NetworkAddress @@ -14,28 +20,31 @@ @pytest.fixture -def scan_config(): - return { - "tcp": { - "timeout_ms": 3000, - "ports": [ - 22, - 445, - 3389, - 443, - 8008, - 3306, - ], - }, - "icmp": { - "timeout_ms": 1000, - }, - "fingerprinters": [ - {"name": "HTTPFinger", "options": {}}, - {"name": "SMBFinger", "options": {}}, - {"name": "SSHFinger", "options": {}}, +def scan_config(default_agent_configuration): + tcp_config = TCPScanConfiguration( + timeout=3, + ports=[ + 22, + 445, + 3389, + 443, + 8008, + 3306, ], - } + ) + icmp_config = ICMPScanConfiguration(timeout=1) + fingerprinter_config = [ + PluginConfiguration(name="HTTPFinger", options={}), + PluginConfiguration(name="SMBFinger", options={}), + PluginConfiguration(name="SSHFinger", options={}), + ] + scan_config = NetworkScanConfiguration( + tcp_config, + icmp_config, + fingerprinter_config, + default_agent_configuration.propagation.network_scan.targets, + ) + return scan_config @pytest.fixture diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_propagator.py b/monkey/tests/unit_tests/infection_monkey/master/test_propagator.py index 3746e65ebe6..2ebdcd84ea6 100644 --- a/monkey/tests/unit_tests/infection_monkey/master/test_propagator.py +++ b/monkey/tests/unit_tests/infection_monkey/master/test_propagator.py @@ -3,6 +3,11 @@ import pytest +from common.configuration.agent_sub_configurations import ( + NetworkScanConfiguration, + PropagationConfiguration, + ScanTargetConfiguration, +) from infection_monkey.i_puppet import ( ExploiterResultData, FingerprintData, @@ -135,24 +140,37 @@ def exploit_hosts( pass -def test_scan_result_processing(telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory): +def get_propagation_config( + default_agent_configuration, scan_target_config: ScanTargetConfiguration +): + network_scan = NetworkScanConfiguration( + default_agent_configuration.propagation.network_scan.tcp, + default_agent_configuration.propagation.network_scan.icmp, + default_agent_configuration.propagation.network_scan.fingerprinters, + scan_target_config, + ) + propagation_config = PropagationConfiguration( + default_agent_configuration.propagation.maximum_depth, + network_scan, + default_agent_configuration.propagation.exploitation, + ) + return propagation_config + + +def test_scan_result_processing( + telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory, default_agent_configuration +): p = Propagator( telemetry_messenger_spy, mock_ip_scanner, StubExploiter(), mock_victim_host_factory, [] ) - p.propagate( - { - "targets": { - "subnet_scan_list": ["10.0.0.1", "10.0.0.2", "10.0.0.3"], - "local_network_scan": False, - "inaccessible_subnets": [], - "blocked_ips": [], - }, - "network_scan": {}, # This is empty since MockIPscanner ignores it - "exploiters": {}, # This is empty since StubExploiter ignores it - }, - 1, - Event(), + targets = ScanTargetConfiguration( + blocked_ips=[], + inaccessible_subnets=[], + local_network_scan=False, + subnets=["10.0.0.1", "10.0.0.2", "10.0.0.3"], ) + propagation_config = get_propagation_config(default_agent_configuration, targets) + p.propagate(propagation_config, 1, Event()) assert len(telemetry_messenger_spy.telemetries) == 3 @@ -237,25 +255,20 @@ def exploit_hosts( def test_exploiter_result_processing( - telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory + telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory, default_agent_configuration ): p = Propagator( telemetry_messenger_spy, mock_ip_scanner, MockExploiter(), mock_victim_host_factory, [] ) - p.propagate( - { - "targets": { - "subnet_scan_list": ["10.0.0.1", "10.0.0.2", "10.0.0.3"], - "local_network_scan": False, - "inaccessible_subnets": [], - "blocked_ips": [], - }, - "network_scan": {}, # This is empty since MockIPscanner ignores it - "exploiters": {}, # This is empty since MockExploiter ignores it - }, - 1, - Event(), + + targets = ScanTargetConfiguration( + blocked_ips=[], + inaccessible_subnets=[], + local_network_scan=False, + subnets=["10.0.0.1", "10.0.0.2", "10.0.0.3"], ) + propagation_config = get_propagation_config(default_agent_configuration, targets) + p.propagate(propagation_config, 1, Event()) exploit_telems = [t for t in telemetry_messenger_spy.telemetries if isinstance(t, ExploitTelem)] assert len(exploit_telems) == 4 @@ -278,7 +291,9 @@ def test_exploiter_result_processing( assert data["propagation_result"] -def test_scan_target_generation(telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory): +def test_scan_target_generation( + telemetry_messenger_spy, mock_ip_scanner, mock_victim_host_factory, default_agent_configuration +): local_network_interfaces = [NetworkInterface("10.0.0.9", "/29")] p = Propagator( telemetry_messenger_spy, @@ -287,20 +302,15 @@ def test_scan_target_generation(telemetry_messenger_spy, mock_ip_scanner, mock_v mock_victim_host_factory, local_network_interfaces, ) - p.propagate( - { - "targets": { - "subnet_scan_list": ["10.0.0.0/29", "172.10.20.30"], - "local_network_scan": True, - "blocked_ips": ["10.0.0.3"], - "inaccessible_subnets": ["10.0.0.128/30", "10.0.0.8/29"], - }, - "network_scan": {}, # This is empty since MockIPscanner ignores it - "exploiters": {}, # This is empty since MockExploiter ignores it - }, - 1, - Event(), + targets = ScanTargetConfiguration( + blocked_ips=["10.0.0.3"], + inaccessible_subnets=["10.0.0.128/30", "10.0.0.8/29"], + local_network_scan=True, + subnets=["10.0.0.0/29", "172.10.20.30"], ) + propagation_config = get_propagation_config(default_agent_configuration, targets) + p.propagate(propagation_config, 1, Event()) + expected_ip_scan_list = [ "10.0.0.0", "10.0.0.1", diff --git a/monkey/tests/unit_tests/infection_monkey/utils/test_propagation.py b/monkey/tests/unit_tests/infection_monkey/utils/test_propagation.py index 37f7194a663..19b2c18b574 100644 --- a/monkey/tests/unit_tests/infection_monkey/utils/test_propagation.py +++ b/monkey/tests/unit_tests/infection_monkey/utils/test_propagation.py @@ -1,32 +1,22 @@ -from infection_monkey.utils.propagation import should_propagate +from infection_monkey.utils.propagation import maximum_depth_reached -def get_config(max_depth): - return {"config": {"depth": max_depth}} - - -def test_should_propagate_current_less_than_max(): - max_depth = 2 +def test_maximum_depth_reached__current_less_than_max(): + maximum_depth = 2 current_depth = 1 - config = get_config(max_depth) - - assert should_propagate(config, current_depth) is True + assert maximum_depth_reached(maximum_depth, current_depth) is True -def test_should_propagate_current_greater_than_max(): - max_depth = 2 +def test_maximum_depth_reached__current_greater_than_max(): + maximum_depth = 2 current_depth = 3 - config = get_config(max_depth) - - assert should_propagate(config, current_depth) is False - + assert maximum_depth_reached(maximum_depth, current_depth) is False -def test_should_propagate_current_equal_to_max(): - max_depth = 2 - current_depth = max_depth - config = get_config(max_depth) +def test_maximum_depth_reached__current_equal_to_max(): + maximum_depth = 2 + current_depth = maximum_depth - assert should_propagate(config, current_depth) is False + assert maximum_depth_reached(maximum_depth, current_depth) is False diff --git a/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_agent_configuration_repository.py b/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_agent_configuration_repository.py index 4ab111606c2..fb7863dc321 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_agent_configuration_repository.py +++ b/monkey/tests/unit_tests/monkey_island/cc/repository/test_file_agent_configuration_repository.py @@ -2,7 +2,7 @@ from tests.common.example_agent_configuration import AGENT_CONFIGURATION from tests.monkey_island import OpenErrorFileRepository, SingleFileRepository -from common.configuration import AgentConfigurationSchema +from common.configuration import AgentConfiguration from monkey_island.cc.repository import FileAgentConfigurationRepository, RetrievalError @@ -12,8 +12,7 @@ def repository(default_agent_configuration): def test_store_agent_config(repository): - schema = AgentConfigurationSchema() - agent_configuration = schema.load(AGENT_CONFIGURATION) + agent_configuration = AgentConfiguration.from_mapping(AGENT_CONFIGURATION) repository.store_configuration(agent_configuration) retrieved_agent_configuration = repository.get_configuration() diff --git a/monkey/tests/unit_tests/monkey_island/cc/services/test_config.py b/monkey/tests/unit_tests/monkey_island/cc/services/test_config.py index 85f3f4823b4..f170b08652c 100644 --- a/monkey/tests/unit_tests/monkey_island/cc/services/test_config.py +++ b/monkey/tests/unit_tests/monkey_island/cc/services/test_config.py @@ -99,7 +99,7 @@ def test_format_config_for_agent__propagation(): def test_format_config_for_agent__network_scan(): expected_network_scan_config = { "tcp": { - "timeout": 3000, + "timeout": 3.0, "ports": [ 22, 80, @@ -117,7 +117,7 @@ def test_format_config_for_agent__network_scan(): ], }, "icmp": { - "timeout": 1000, + "timeout": 1.0, }, "targets": { "blocked_ips": ["192.168.1.1", "192.168.1.100"],