Skip to content

Commit

Permalink
SSH Exploiter: Improve code quality
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyamalviya committed Jul 6, 2023
1 parent 24d5ed5 commit 3dae412
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 35 deletions.
20 changes: 9 additions & 11 deletions monkey/agent_plugins/exploiters/ssh/src/plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
from functools import partial
from pathlib import PurePath
from pprint import pformat
from typing import Any, Dict, Sequence

Expand Down Expand Up @@ -91,18 +90,17 @@ def run(
exploitation_success=False, propagation_success=False, error_message=msg
)

def build_command(path: PurePath):
return build_ssh_command(
agent_id=self._agent_id,
target_host=host,
servers=servers,
current_depth=current_depth,
remote_agent_binary_destination_path=path,
otp_provider=self._otp_provider,
)
command_builder = partial(
build_ssh_command,
agent_id=self._agent_id,
target_host=host,
servers=servers,
current_depth=current_depth,
otp_provider=self._otp_provider,
)

ssh_exploit_client_factory = SSHRemoteAccessClientFactory(
host=host, options=ssh_options, command_builder=build_command
host=host, options=ssh_options, command_builder=command_builder
)

brute_force_exploiter = BruteForceExploiter(
Expand Down
49 changes: 26 additions & 23 deletions monkey/agent_plugins/exploiters/ssh/src/ssh_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import paramiko

from common.common_consts.timeouts import LONG_REQUEST_TIMEOUT, MEDIUM_REQUEST_TIMEOUT
from common.credentials import Credentials, Password, Username, get_plaintext
from common.credentials import Credentials, Password, SSHKeypair, Username, get_plaintext
from common.types import NetworkPort
from infection_monkey.i_puppet import TargetHost

Expand Down Expand Up @@ -37,17 +37,25 @@ def connect(
:param timeout: Timeout for the connection, in seconds
:raises Exception: If the connection could not be established
"""

if isinstance(credentials.secret, SSHKeypair):
connect_function = self._connect_with_private_key
elif isinstance(credentials.secret, Password):
connect_function = self._connect_with_login_credentials
else:
message = "Unrecognised credential secret type"
logger.debug(message)
raise Exception(message)

client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.WarningPolicy())

try:
self._connect_with_private_key(client, host, credentials, port, timeout)
connect_function(client, host, credentials, port, timeout)
logger.debug(f"Successfully authenticated using SSH on host: {host.ip}")
except Exception:
try:
self._connect_with_login_credentials(client, host, credentials, port, timeout)
except Exception as err:
client.close()
raise err
except Exception as err:
client.close()
raise err

self._client = client
self._authenticated = True
Expand All @@ -61,16 +69,14 @@ def _connect_with_private_key(
timeout: float,
):
username = (
credentials.identity.username
if isinstance(credentials.identity, Username)
else credentials.identity
credentials.identity.username if isinstance(credentials.identity, Username) else None
)

try:
private_key_buffer = io.StringIO(get_plaintext(credentials.secret.private_key))
private_key = paramiko.RSAKey.from_private_key(private_key_buffer)
except (IOError, paramiko.SSHException, paramiko.PasswordRequiredException) as err:
logger.error("Failed reading ssh key")
logger.error("Failed reading SSH key")
raise err

try:
Expand All @@ -86,18 +92,18 @@ def _connect_with_private_key(
allow_agent=False,
)
logger.debug(
f"Successfully logged in {host.ip} using {username}@{host.ip} user's private key"
f"Successfully logged into {host.ip} using {username}@{host.ip} user's private key"
)
except paramiko.AuthenticationException as err:
error_message = (
f"Failed logging into victim {host.ip} with {username}@{host.ip}"
f"Failed logging into victim {host.ip} with {username}@{host.ip} user's"
f"private key: {err}"
)
logger.info(error_message)
raise err
except Exception as err:
error_message = (
f"Unexpected error while attempting to login to {username}@{host.ip} with ssh key: "
f"Unexpected error while attempting to login to {username}@{host.ip} with SSH key: "
f"{err}"
)
logger.error(error_message)
Expand All @@ -112,14 +118,9 @@ def _connect_with_login_credentials(
timeout: float,
):
username = (
credentials.identity.username
if isinstance(credentials.identity, Username)
else credentials.identity
credentials.identity.username if isinstance(credentials.identity, Username) else None
)

if not isinstance(credentials.secret, Password):
raise Exception(f"No suitable credentials found for SSH login on host: {host.ip}")

try:
client.connect(
str(host.ip),
Expand Down Expand Up @@ -156,7 +157,7 @@ def copy_file(
:raises Exception: If the file copy failed
"""
try:
with self._client.open_sftp() as sftp:
with self._client.open_sftp() as sftp: # type: ignore [union-attr]
sftp.putfo(
io.BytesIO(file),
str(destination_path),
Expand All @@ -180,7 +181,9 @@ def execute_command(self, command: str) -> bytes:
:raises Exception: If the command execution failed
:return: The command output
"""
_, stdout, _ = self._client.exec_command(command, timeout=SSH_EXEC_TIMEOUT)
_, stdout, _ = self._client.exec_command( # type: ignore [union-attr]
command=command, timeout=SSH_EXEC_TIMEOUT
)
return stdout

def connected(self) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def execute_agent(self, agent_binary_path: PurePath, tags: Set[str]):
self._raise_if_not_authenticated(RemoteCommandExecutionError)
try:
tags.update(EXECUTION_TAGS)
self._ssh_client.execute_command(self._build_command(agent_binary_path))
self._ssh_client.execute_command(
self._build_command(remote_agent_binary_destination_path=agent_binary_path)
)
except Exception as err:
raise RemoteCommandExecutionError(err)

Expand Down

0 comments on commit 3dae412

Please sign in to comment.