Skip to content

Commit

Permalink
Agent: Pass Agent ID to credentials collectors
Browse files Browse the repository at this point in the history
Issue #3119
PR #3157
  • Loading branch information
cakekoa committed Mar 28, 2023
1 parent 59fed74 commit 3d40568
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from common.credentials import Credentials, LMHash, NTHash, Password, Username
from common.event_queue import IAgentEventQueue
from common.tags import T1003_ATTACK_TECHNIQUE_TAG, T1005_ATTACK_TECHNIQUE_TAG
from common.types import AgentID
from infection_monkey.i_puppet import ICredentialCollector
from infection_monkey.model import USERNAME_PREFIX
from infection_monkey.utils.ids import get_agent_id

from . import pypykatz_handler
from .windows_credentials import WindowsCredentials
Expand All @@ -27,8 +27,9 @@


class MimikatzCredentialCollector(ICredentialCollector):
def __init__(self, agent_event_queue: IAgentEventQueue):
def __init__(self, agent_event_queue: IAgentEventQueue, agent_id: AgentID):
self._agent_event_queue = agent_event_queue
self._agent_id = agent_id

def collect_credentials(self, options=None) -> Sequence[Credentials]:
logger.info("Attempting to collect windows credentials with pypykatz.")
Expand Down Expand Up @@ -76,7 +77,7 @@ def _to_credentials(windows_credentials: Sequence[WindowsCredentials]) -> Sequen

def _publish_credentials_stolen_event(self, collected_credentials: Sequence[Credentials]):
credentials_stolen_event = CredentialsStolenEvent(
source=get_agent_id(),
source=self._agent_id,
tags=MIMIKATZ_EVENT_TAGS,
stolen_credentials=collected_credentials,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from common.credentials import Credentials
from common.event_queue import IAgentEventQueue
from common.types import AgentID
from infection_monkey.credential_collectors.ssh_collector import ssh_handler
from infection_monkey.i_puppet import ICredentialCollector

Expand All @@ -14,12 +15,13 @@ class SSHCredentialCollector(ICredentialCollector):
SSH keys credential collector
"""

def __init__(self, agent_event_queue: IAgentEventQueue):
def __init__(self, agent_event_queue: IAgentEventQueue, agent_id: AgentID):
self._agent_event_queue = agent_event_queue
self._agent_id = agent_id

def collect_credentials(self, _options=None) -> Sequence[Credentials]:
logger.info("Started scanning for SSH credentials")
ssh_info = ssh_handler.get_ssh_info(self._agent_event_queue)
ssh_info = ssh_handler.get_ssh_info(self._agent_event_queue, self._agent_id)
logger.info("Finished scanning for SSH credentials")

return ssh_handler.to_credentials(ssh_info)
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
T1005_ATTACK_TECHNIQUE_TAG,
T1145_ATTACK_TECHNIQUE_TAG,
)
from common.types import AgentID
from common.utils.environment import is_windows_os
from infection_monkey.utils.ids import get_agent_id

logger = logging.getLogger(__name__)

Expand All @@ -29,7 +29,7 @@
)


def get_ssh_info(agent_event_queue: IAgentEventQueue) -> Iterable[Dict]:
def get_ssh_info(agent_event_queue: IAgentEventQueue, agent_id: AgentID) -> Iterable[Dict]:
# TODO: Remove this check when this is turned into a plugin.
if is_windows_os():
logger.debug(
Expand All @@ -38,7 +38,7 @@ def get_ssh_info(agent_event_queue: IAgentEventQueue) -> Iterable[Dict]:
return []

home_dirs = _get_home_dirs()
ssh_info = _get_ssh_files(home_dirs, agent_event_queue)
ssh_info = _get_ssh_files(home_dirs, agent_event_queue, agent_id)

return ssh_info

Expand Down Expand Up @@ -79,6 +79,7 @@ def _get_ssh_struct(name: str, home_dir: str) -> Dict:
def _get_ssh_files(
user_info: Iterable[Dict],
agent_event_queue: IAgentEventQueue,
agent_id: AgentID,
) -> Iterable[Dict]:
for info in user_info:
path = info["home_dir"]
Expand Down Expand Up @@ -109,7 +110,7 @@ def _get_ssh_files(
logger.info("Found private key in %s" % private)
collected_credentials = to_credentials([info])
_publish_credentials_stolen_event(
collected_credentials, agent_event_queue
collected_credentials, agent_event_queue, agent_id
)
else:
continue
Expand Down Expand Up @@ -154,10 +155,12 @@ def to_credentials(ssh_info: Iterable[Dict]) -> Sequence[Credentials]:


def _publish_credentials_stolen_event(
collected_credentials: Sequence[Credentials], agent_event_queue: IAgentEventQueue
collected_credentials: Sequence[Credentials],
agent_event_queue: IAgentEventQueue,
agent_id: AgentID,
):
credentials_stolen_event = CredentialsStolenEvent(
source=get_agent_id(),
source=agent_id,
tags=SSH_COLLECTOR_EVENT_TAGS,
stolen_credentials=collected_credentials,
)
Expand Down
4 changes: 2 additions & 2 deletions monkey/infection_monkey/monkey.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,12 +412,12 @@ def _build_puppet(self, operating_system: OperatingSystem) -> IPuppet:
puppet.load_plugin(
AgentPluginType.CREDENTIAL_COLLECTOR,
"MimikatzCollector",
MimikatzCredentialCollector(self._agent_event_queue),
MimikatzCredentialCollector(self._agent_event_queue, self._agent_id),
)
puppet.load_plugin(
AgentPluginType.CREDENTIAL_COLLECTOR,
"SSHCollector",
SSHCredentialCollector(self._agent_event_queue),
SSHCredentialCollector(self._agent_event_queue, self._agent_id),
)

puppet.load_plugin(AgentPluginType.FINGERPRINTER, "http", HTTPFingerprinter())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from common.agent_events import CredentialsStolenEvent
from common.credentials import Credentials, LMHash, NTHash, Password, Username
from common.event_queue import IAgentEventQueue
from common.types import AgentID
from infection_monkey.credential_collectors import MimikatzCredentialCollector
from infection_monkey.credential_collectors.mimikatz_collector.mimikatz_credential_collector import ( # noqa: E501
MIMIKATZ_EVENT_TAGS,
Expand All @@ -14,8 +15,10 @@
WindowsCredentials,
)

AGENT_ID = AgentID("be11ad56-995d-45fd-be03-e7806a47b56b")

def patch_pypykatz(win_creds: [WindowsCredentials], monkeypatch):

def patch_pypykatz(win_creds: Sequence[WindowsCredentials], monkeypatch):
monkeypatch.setattr(
"infection_monkey.credential_collectors"
".mimikatz_collector.pypykatz_handler.get_windows_creds",
Expand All @@ -25,7 +28,7 @@ def patch_pypykatz(win_creds: [WindowsCredentials], monkeypatch):

def collect_credentials() -> Sequence[Credentials]:
mock_event_queue = MagicMock(spec=IAgentEventQueue)
return MimikatzCredentialCollector(mock_event_queue).collect_credentials()
return MimikatzCredentialCollector(mock_event_queue, AGENT_ID).collect_credentials()


@pytest.mark.parametrize(
Expand Down Expand Up @@ -129,7 +132,7 @@ def test_mimikatz_credentials_stolen_event_published(monkeypatch):
mock_event_queue = MagicMock(spec=IAgentEventQueue)
patch_pypykatz([], monkeypatch)

mimikatz_credential_collector = MimikatzCredentialCollector(mock_event_queue)
mimikatz_credential_collector = MimikatzCredentialCollector(mock_event_queue, AGENT_ID)
mimikatz_credential_collector.collect_credentials()

mock_event_queue.publish.assert_called_once()
Expand All @@ -143,7 +146,7 @@ def test_mimikatz_credentials_stolen_event_tags(monkeypatch):
mock_event_queue = MagicMock(spec=IAgentEventQueue)
patch_pypykatz([], monkeypatch)

mimikatz_credential_collector = MimikatzCredentialCollector(mock_event_queue)
mimikatz_credential_collector = MimikatzCredentialCollector(mock_event_queue, AGENT_ID)
mimikatz_credential_collector.collect_credentials()

mock_event_queue_call_args = mock_event_queue.publish.call_args[0][0]
Expand All @@ -160,7 +163,7 @@ def test_mimikatz_credentials_stolen_event_stolen_credentials(monkeypatch):
]
patch_pypykatz(win_creds, monkeypatch)

mimikatz_credential_collector = MimikatzCredentialCollector(mock_event_queue)
mimikatz_credential_collector = MimikatzCredentialCollector(mock_event_queue, AGENT_ID)
collected_credentials = mimikatz_credential_collector.collect_credentials()

mock_event_queue_call_args = mock_event_queue.publish.call_args[0][0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@

from common.credentials import Credentials, SSHKeypair, Username
from common.event_queue import IAgentEventQueue
from common.types import AgentID
from infection_monkey.credential_collectors import SSHCredentialCollector

AGENT_ID = AgentID("ed077054-a316-479a-a99d-75bb378c0a6e")


def patch_ssh_handler(ssh_creds, monkeypatch):
monkeypatch.setattr(
"infection_monkey.credential_collectors.ssh_collector.ssh_handler.get_ssh_info",
lambda _: ssh_creds,
lambda _, __: ssh_creds,
)


Expand All @@ -19,7 +22,9 @@ def patch_ssh_handler(ssh_creds, monkeypatch):
)
def test_ssh_credentials_empty_results(monkeypatch, ssh_creds):
patch_ssh_handler(ssh_creds, monkeypatch)
collected = SSHCredentialCollector(MagicMock(spec=IAgentEventQueue)).collect_credentials()
collected = SSHCredentialCollector(
MagicMock(spec=IAgentEventQueue), AGENT_ID
).collect_credentials()
assert not collected


Expand Down Expand Up @@ -64,5 +69,7 @@ def test_ssh_info_result_parsing(monkeypatch):
Credentials(identity=username3, secret=None),
Credentials(identity=None, secret=ssh_keypair3),
]
collected = SSHCredentialCollector(MagicMock(spec=IAgentEventQueue)).collect_credentials()
collected = SSHCredentialCollector(
MagicMock(spec=IAgentEventQueue), AGENT_ID
).collect_credentials()
assert expected == collected

0 comments on commit 3d40568

Please sign in to comment.