Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2216 use tcprelay in agent #2251

Merged
merged 14 commits into from
Sep 7, 2022
7 changes: 7 additions & 0 deletions monkey/common/network/network_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@


def address_to_ip_port(address: str) -> Tuple[str, Optional[str]]:
"""
Split a string containing an IP address (and optionally a port) into IP and Port components.
Currently only works for IPv4 addresses.

:param address: The address string.
:return: Tuple of IP and port strings. The port may be None if no port was in the address.
"""
if ":" in address:
ip, port = address.split(":")
return ip, port or None
Expand Down
2 changes: 1 addition & 1 deletion monkey/infection_monkey/master/automated_master.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def _run_simulation(self):
current_depth = self._current_depth if self._current_depth is not None else 0
logger.info(f"Current depth is {current_depth}")

if maximum_depth_reached(config.propagation.maximum_depth, current_depth):
if not maximum_depth_reached(config.propagation.maximum_depth, current_depth):
self._propagator.propagate(config.propagation, current_depth, self._servers, self._stop)
else:
logger.info("Skipping propagation: maximum depth reached")
Expand Down
37 changes: 19 additions & 18 deletions monkey/infection_monkey/monkey.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import subprocess
import sys
from ipaddress import IPv4Interface
from ipaddress import IPv4Address, IPv4Interface
from pathlib import Path, WindowsPath
from typing import List

Expand Down Expand Up @@ -41,7 +41,8 @@
from infection_monkey.master.control_channel import ControlChannel
from infection_monkey.model import VictimHostFactory
from infection_monkey.network.firewall import app as firewall
from infection_monkey.network.info import get_network_interfaces
from infection_monkey.network.info import get_free_tcp_port, get_network_interfaces
from infection_monkey.network.relay import TCPRelay
from infection_monkey.network_scanning.elasticsearch_fingerprinter import ElasticSearchFingerprinter
from infection_monkey.network_scanning.http_fingerprinter import HTTPFingerprinter
from infection_monkey.network_scanning.mssql_fingerprinter import MSSQLFingerprinter
Expand Down Expand Up @@ -100,11 +101,10 @@ def __init__(self, args):
# TODO Refactor the telemetry messengers to accept control client
# and remove control_client_object
ControlClient.control_client_object = self._control_client
self._monkey_inbound_tunnel = None
self._telemetry_messenger = LegacyTelemetryMessengerAdapter()
self._current_depth = self._opts.depth
self._master = None
self._inbound_tunnel_opened = False
self._relay: TCPRelay

@staticmethod
def _get_arguments(args):
Expand Down Expand Up @@ -180,14 +180,17 @@ def _setup(self):
control_channel.register_agent(self._opts.parent)

config = control_channel.get_config()
self._monkey_inbound_tunnel = self._control_client.create_control_tunnel(
config.keep_tunnel_open_time

relay_port = get_free_tcp_port()
self._relay = TCPRelay(
relay_port,
IPv4Address(self._cmd_island_ip),
self._cmd_island_port,
client_disconnect_timeout=config.keep_tunnel_open_time,
)
if self._monkey_inbound_tunnel and maximum_depth_reached(
config.propagation.maximum_depth, self._current_depth
):
self._inbound_tunnel_opened = True
self._monkey_inbound_tunnel.start()

if not maximum_depth_reached(config.propagation.maximum_depth, self._current_depth):
self._relay.start()

StateTelem(is_done=False, version=get_version()).send()
TunnelTelem(self._control_client.proxies).send()
Expand Down Expand Up @@ -215,7 +218,7 @@ def _build_master(self):
victim_host_factory = self._build_victim_host_factory(local_network_interfaces)

telemetry_messenger = ExploitInterceptingTelemetryMessenger(
self._telemetry_messenger, self._monkey_inbound_tunnel
self._telemetry_messenger, self._relay
)

self._master = AutomatedMaster(
Expand Down Expand Up @@ -374,9 +377,7 @@ def _build_victim_host_factory(
on_island = self._running_on_island(local_network_interfaces)
logger.debug(f"This agent is running on the island: {on_island}")

return VictimHostFactory(
self._monkey_inbound_tunnel, self._cmd_island_ip, self._cmd_island_port, on_island
)
return VictimHostFactory(None, self._cmd_island_ip, self._cmd_island_port, on_island)
mssalvatore marked this conversation as resolved.
Show resolved Hide resolved

def _running_on_island(self, local_network_interfaces: List[IPv4Interface]) -> bool:
server_ip, _ = address_to_ip_port(self._control_client.server_address)
Expand All @@ -394,9 +395,9 @@ def cleanup(self):

reset_signal_handlers()

if self._inbound_tunnel_opened:
self._monkey_inbound_tunnel.stop()
self._monkey_inbound_tunnel.join()
if self._relay and self._relay.is_alive():
self._relay.stop()
self._relay.join(timeout=60)

if firewall.is_enabled():
firewall.remove_firewall_rule()
Expand Down
18 changes: 12 additions & 6 deletions monkey/infection_monkey/network/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
import struct
from dataclasses import dataclass
from ipaddress import IPv4Interface
from random import randint # noqa: DUO102
from random import shuffle # noqa: DUO102
from typing import List

import netifaces
import psutil

from infection_monkey.utils.environment import is_windows_os

from .ports import COMMON_PORTS

# Timeout for monkey connections
LOOPBACK_NAME = b"lo"
SIOCGIFADDR = 0x8915 # get PA address
Expand Down Expand Up @@ -119,14 +121,18 @@ def get_routes(): # based on scapy implementation for route parsing


def get_free_tcp_port(min_range=1024, max_range=65535):
min_range = max(1, min_range)
max_range = min(65535, max_range)

in_use = [conn.laddr[1] for conn in psutil.net_connections()]
in_use = {conn.laddr[1] for conn in psutil.net_connections()}

for i in range(min_range, max_range):
port = randint(min_range, max_range)
for port in COMMON_PORTS:
if port not in in_use:
return port

min_range = max(1, min_range)
max_range = min(65535, max_range)
ports = list(range(min_range, max_range))
shuffle(ports)
for port in ports:
if port not in in_use:
return port

Expand Down
15 changes: 15 additions & 0 deletions monkey/infection_monkey/network/ports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import List

COMMON_PORTS: List[int] = [
1025, # NFS, IIS
1433, # Microsoft SQL Server
1434, # Microsoft SQL Monitor
1720, # h323q931
1723, # Microsoft PPTP VPN
3306, # mysql
3389, # Windows Terminal Server (RDP)
5900, # vnc
6001, # X11:1
8080, # http-proxy
8888, # sun-answerbook
]
1 change: 1 addition & 0 deletions monkey/infection_monkey/network/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .sockets_pipe import SocketsPipe
from .tcp_connection_handler import TCPConnectionHandler
from .tcp_pipe_spawner import TCPPipeSpawner
from .tcp_relay import TCPRelay
36 changes: 29 additions & 7 deletions monkey/infection_monkey/network/relay/tcp_relay.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from ipaddress import IPv4Address
from threading import Lock, Thread
from time import sleep

from infection_monkey.network.relay import RelayUserHandler, TCPConnectionHandler, TCPPipeSpawner
from infection_monkey.network.relay import (
RelayConnectionHandler,
RelayUserHandler,
TCPConnectionHandler,
TCPPipeSpawner,
)
from infection_monkey.utils.threading import InterruptableThreadMixin


Expand All @@ -12,13 +18,21 @@ class TCPRelay(Thread, InterruptableThreadMixin):

def __init__(
self,
relay_user_handler: RelayUserHandler,
connection_handler: TCPConnectionHandler,
pipe_spawner: TCPPipeSpawner,
relay_port: int,
dest_addr: IPv4Address,
dest_port: int,
client_disconnect_timeout: float,
):
self._user_handler = relay_user_handler
self._connection_handler = connection_handler
self._pipe_spawner = pipe_spawner
self._user_handler = RelayUserHandler(client_disconnect_timeout=client_disconnect_timeout)
self._pipe_spawner = TCPPipeSpawner(dest_addr, dest_port)
relay_filter = RelayConnectionHandler(self._pipe_spawner, self._user_handler)
self._connection_handler = TCPConnectionHandler(
bind_host="",
bind_port=relay_port,
client_connected=[
relay_filter.handle_new_connection,
],
)
super().__init__(name="MonkeyTcpRelayThread", daemon=True)
self._lock = Lock()

Expand All @@ -32,6 +46,14 @@ def run(self):
self._connection_handler.join()
self._wait_for_pipes_to_close()

def add_potential_user(self, user_address: IPv4Address):
"""
Notify TCPRelay of a user that may try to connect.

:param user_address: The address of the potential new user.
"""
self._user_handler.add_potential_user(user_address)

def _wait_for_users_to_disconnect(self):
"""
Blocks until the users disconnect or the timeout has elapsed.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
from functools import singledispatch
from ipaddress import IPv4Address

from infection_monkey.network.relay.tcp_relay import TCPRelay
from infection_monkey.network.relay import TCPRelay
from infection_monkey.telemetry.exploit_telem import ExploitTelem
from infection_monkey.telemetry.i_telem import ITelem
from infection_monkey.telemetry.messengers.i_telemetry_messenger import ITelemetryMessenger
from infection_monkey.tunnel import MonkeyTunnel


class ExploitInterceptingTelemetryMessenger(ITelemetryMessenger):
def __init__(
self, telemetry_messenger: ITelemetryMessenger, tunnel: MonkeyTunnel, relay: TCPRelay
):
def __init__(self, telemetry_messenger: ITelemetryMessenger, relay: TCPRelay):
self._telemetry_messenger = telemetry_messenger
self._tunnel = tunnel
self._relay = relay

def send_telemetry(self, telemetry: ITelem):
_send_telemetry(telemetry, self._telemetry_messenger, self._tunnel, self._relay)
_send_telemetry(telemetry, self._telemetry_messenger, self._relay)


# Note: We can use @singledispatchmethod instead of @singledispatch if we migrate to Python 3.8 or
Expand All @@ -26,7 +22,6 @@ def send_telemetry(self, telemetry: ITelem):
def _send_telemetry(
telemetry: ITelem,
telemetry_messenger: ITelemetryMessenger,
tunnel: MonkeyTunnel,
relay: TCPRelay,
):
telemetry_messenger.send_telemetry(telemetry)
Expand All @@ -36,12 +31,11 @@ def _send_telemetry(
def _(
telemetry: ExploitTelem,
telemetry_messenger: ITelemetryMessenger,
tunnel: MonkeyTunnel,
relay: TCPRelay,
):
if telemetry.propagation_result is True:
tunnel.set_wait_for_exploited_machines()
if relay:
relay.add_potential_user(IPv4Address(telemetry.host["ip_addr"]))
address = IPv4Address(str(telemetry.host["ip_addr"]))
relay.add_potential_user(address)

telemetry_messenger.send_telemetry(telemetry)
10 changes: 9 additions & 1 deletion monkey/infection_monkey/utils/propagation.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,10 @@
def maximum_depth_reached(maximum_depth: int, current_depth: int) -> bool:
return maximum_depth > current_depth
"""
Return whether or not the current depth has eclipsed the maximum depth.
Values are nonnegative. Depth should increase from zero.

:param maximum_depth: The maximum depth.
:param current_depth: The current depth.
:return: True if the current depth has reached the maximum depth, otherwise False.
"""
return current_depth >= maximum_depth
40 changes: 40 additions & 0 deletions monkey/tests/unit_tests/infection_monkey/network/test_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from dataclasses import dataclass
from typing import Tuple

import pytest

from infection_monkey.network.info import get_free_tcp_port
from infection_monkey.network.ports import COMMON_PORTS


@dataclass
class Connection:
laddr: Tuple[str, int]


@pytest.mark.parametrize("port", COMMON_PORTS)
def test_get_free_tcp_port__checks_common_ports(port: int, monkeypatch):
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS if p is not port]

monkeypatch.setattr(
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
)
assert get_free_tcp_port() is port


def test_get_free_tcp_port__checks_other_ports_if_common_ports_unavailable(monkeypatch):
unavailable_ports = [Connection(("", p)) for p in COMMON_PORTS]
monkeypatch.setattr(
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
)

assert get_free_tcp_port() is not None


def test_get_free_tcp_port__none_if_no_available_ports(monkeypatch):
unavailable_ports = [Connection(("", p)) for p in range(65535)]
monkeypatch.setattr(
"infection_monkey.network.info.psutil.net_connections", lambda: unavailable_ports
)

assert get_free_tcp_port() is None
Original file line number Diff line number Diff line change
Expand Up @@ -20,49 +20,43 @@ def get_data(self):

def test_generic_telemetry(TestTelem):
mock_telemetry_messenger = MagicMock()
mock_tunnel = MagicMock()
mock_relay = MagicMock()

telemetry_messenger = ExploitInterceptingTelemetryMessenger(
mock_telemetry_messenger, mock_tunnel, mock_relay
mock_telemetry_messenger, mock_relay
)

telemetry_messenger.send_telemetry(TestTelem())

assert mock_telemetry_messenger.send_telemetry.called
assert not mock_tunnel.set_wait_for_exploited_machines.called
assert not mock_relay.add_potential_user.called


def test_propagation_successful_exploit_telemetry():
mock_telemetry_messenger = MagicMock()
mock_tunnel = MagicMock()
mock_relay = MagicMock()
mock_exploit_telem = MockExploitTelem(True)

telemetry_messenger = ExploitInterceptingTelemetryMessenger(
mock_telemetry_messenger, mock_tunnel, mock_relay
mock_telemetry_messenger, mock_relay
)

telemetry_messenger.send_telemetry(mock_exploit_telem)

assert mock_telemetry_messenger.send_telemetry.called
assert mock_tunnel.set_wait_for_exploited_machines.called
assert mock_relay.add_potential_user.called


def test_propagation_failed_exploit_telemetry():
mock_telemetry_messenger = MagicMock()
mock_tunnel = MagicMock()
mock_relay = MagicMock()
mock_exploit_telem = MockExploitTelem(False)

telemetry_messenger = ExploitInterceptingTelemetryMessenger(
mock_telemetry_messenger, mock_tunnel, mock_relay
mock_telemetry_messenger, mock_relay
)

telemetry_messenger.send_telemetry(mock_exploit_telem)

assert mock_telemetry_messenger.send_telemetry.called
assert not mock_tunnel.set_wait_for_exploited_machines.called
assert not mock_relay.add_potential_user.called
Loading