diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9b2acbd..94c09ca 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -79,6 +79,7 @@ jobs: run: | python -m black . --check python -m isort . --check-only + python -m mypy . pytest -v --cov=winrm --cov-report=term-missing winrm/tests/ - name: upload coverage data diff --git a/.gitignore b/.gitignore index 7d6a3ec..c7019aa 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,4 @@ __pycache__ /winrm/tests/config.json .pytest_cache venv +.mypy_cache \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 3c0b4e8..df422e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ include-package-data = true packages = ["winrm"] [tool.setuptools.package-data] +"winrm" = ["py.typed"] "winrm.tests" = ["*.ps1"] [tool.setuptools.dynamic] @@ -77,4 +78,38 @@ exclude = ''' ''' [tool.isort] -profile = "black" \ No newline at end of file +profile = "black" + +[tool.mypy] +exclude = "build/|winrm/tests/|winrm/vendor/" +mypy_path = "$MYPY_CONFIG_FILE_DIR" +python_version = "3.8" +show_error_codes = true +show_column_numbers = true +disallow_any_unimported = true +disallow_untyped_calls = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +disallow_untyped_decorators = true +no_implicit_reexport = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = true +warn_unreachable = true + +[[tool.mypy.overrides]] +module = "winrm.vendor.*" +follow_imports = "skip" + +[[tool.mypy.overrides]] +module = "requests.packages.urllib3.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "requests_credssp" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "requests_ntlm" +ignore_missing_imports = true \ No newline at end of file diff --git a/requirements-test.txt b/requirements-test.txt index efe27f0..8f54e8b 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,6 +1,9 @@ # this assumes the base requirements have been satisfied via setup.py black == 24.4.2 isort == 5.13.2 +mypy == 1.10.0 pytest pytest-cov mock +types-requests +types-xmltodict \ No newline at end of file diff --git a/winrm/__init__.py b/winrm/__init__.py index 6b76e4f..ec31a32 100644 --- a/winrm/__init__.py +++ b/winrm/__init__.py @@ -1,6 +1,8 @@ from __future__ import annotations +import collections.abc import re +import typing as t import warnings import xml.etree.ElementTree as ET from base64 import b64encode @@ -22,22 +24,22 @@ class Response(object): """Response from a remote command execution""" - def __init__(self, args): + def __init__(self, args: tuple[bytes, bytes, int]) -> None: self.std_out, self.std_err, self.status_code = args - def __repr__(self): + def __repr__(self) -> str: # TODO put tree dots at the end if out/err was truncated - return ''.format(self.status_code, self.std_out[:20], self.std_err[:20]) + return ''.format(self.status_code, self.std_out[:20], self.std_err[:20]) class Session(object): # TODO implement context manager methods - def __init__(self, target, auth, **kwargs): + def __init__(self, target: str, auth: tuple[str, str], **kwargs: t.Any) -> None: username, password = auth self.url = self._build_url(target, kwargs.get("transport", "plaintext")) self.protocol = Protocol(self.url, username=username, password=password, **kwargs) - def run_cmd(self, command, args=()): + def run_cmd(self, command: str, args: collections.abc.Iterable[str | bytes] = ()) -> Response: # TODO optimize perf. Do not call open/close shell every time shell_id = self.protocol.open_shell() command_id = self.protocol.run_command(shell_id, command, args) @@ -46,7 +48,7 @@ def run_cmd(self, command, args=()): self.protocol.close_shell(shell_id) return rs - def run_ps(self, script): + def run_ps(self, script: str) -> Response: """base64 encodes a Powershell script and executes the powershell encoded script command """ @@ -59,7 +61,7 @@ def run_ps(self, script): rs.std_err = self._clean_error_msg(rs.std_err) return rs - def _clean_error_msg(self, msg): + def _clean_error_msg(self, msg: bytes) -> bytes: """converts a Powershell CLIXML message to a more human readable string""" # TODO prepare unit test, beautify code # if the msg does not start with this, return it as is @@ -77,7 +79,8 @@ def _clean_error_msg(self, msg): for s in nodes: # append error msg string to result, also # the hex chars represent CRLF so we replace with newline - new_msg += s.text.replace("_x000D__x000A_", "\n") + if s.text: + new_msg += s.text.replace("_x000D__x000A_", "\n") except Exception as e: # if any of the above fails, the msg was not true xml # print a warning and return the original string @@ -93,7 +96,7 @@ def _clean_error_msg(self, msg): # just return the original message return msg - def _strip_namespace(self, xml): + def _strip_namespace(self, xml: bytes) -> bytes: """strips any namespaces from an xml string""" p = re.compile(b'xmlns=*[""][^""]*[""]') allmatches = p.finditer(xml) @@ -102,8 +105,11 @@ def _strip_namespace(self, xml): return xml @staticmethod - def _build_url(target, transport): + def _build_url(target: str, transport: str) -> str: match = re.match(r"(?i)^((?Phttp[s]?)://)?(?P[0-9a-z-_.]+)(:(?P\d+))?(?P(/)?(wsman)?)?", target) # NOQA + if not match: + raise ValueError("Invalid target URL: {0}".format(target)) + scheme = match.group("scheme") if not scheme: # TODO do we have anything other than HTTP/HTTPS diff --git a/winrm/encryption.py b/winrm/encryption.py index 0cf02f6..50ceb0e 100644 --- a/winrm/encryption.py +++ b/winrm/encryption.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re import struct from urllib.parse import urlsplit @@ -12,7 +14,7 @@ class Encryption(object): SIXTEN_KB = 16384 MIME_BOUNDARY = b"--Encrypted Boundary" - def __init__(self, session, protocol): + def __init__(self, session: requests.Session, protocol: str) -> None: """ [MS-WSMV] v30.0 2016-07-14 @@ -51,7 +53,7 @@ def __init__(self, session, protocol): else: raise WinRMError("Encryption for protocol '%s' not supported in pywinrm" % protocol) - def prepare_encrypted_request(self, session, endpoint, message): + def prepare_encrypted_request(self, session: requests.Session, endpoint: str | bytes, message: bytes) -> requests.PreparedRequest: """ Creates a prepared request to send to the server with an encrypted message and correct headers @@ -77,12 +79,12 @@ def prepare_encrypted_request(self, session, endpoint, message): request = requests.Request("POST", endpoint, data=encrypted_message) prepared_request = session.prepare_request(request) - prepared_request.headers["Content-Length"] = str(len(prepared_request.body)) + prepared_request.headers["Content-Length"] = str(len(prepared_request.body)) if prepared_request.body else "0" prepared_request.headers["Content-Type"] = '{0};protocol="{1}";boundary="Encrypted Boundary"'.format(content_type, self.protocol_string.decode()) return prepared_request - def parse_encrypted_response(self, response): + def parse_encrypted_response(self, response: requests.Response) -> bytes: """ Takes in the encrypted response from the server and decrypts it @@ -90,15 +92,16 @@ def parse_encrypted_response(self, response): :return: The unencrypted message from the server """ content_type = response.headers["Content-Type"] + if 'protocol="{0}"'.format(self.protocol_string.decode()) in content_type: host = urlsplit(response.request.url).hostname msg = self._decrypt_response(response, host) else: - msg = response.text + msg = response.content return msg - def _encrypt_message(self, message, host): + def _encrypt_message(self, message: bytes, host: str | bytes | None) -> bytes: message_length = str(len(message)).encode() encrypted_stream = self._build_message(message, host) @@ -111,7 +114,7 @@ def _encrypt_message(self, message, host): return message_payload - def _decrypt_response(self, response, host): + def _decrypt_response(self, response: requests.Response, host: str | bytes | None) -> bytes: parts = response.content.split(self.MIME_BOUNDARY + b"\r\n") parts = list(filter(None, parts)) # filter out empty parts of the split message = b"" @@ -139,41 +142,41 @@ def _decrypt_response(self, response, host): return message - def _decrypt_ntlm_message(self, encrypted_data, host): + def _decrypt_ntlm_message(self, encrypted_data: bytes, host: str | bytes | None) -> bytes: signature_length = struct.unpack(" bytes: # trailer_length = struct.unpack(" bytes: signature_length = struct.unpack(" bytes: + sealed_message, signature = self.session.auth.session_security.wrap(message) # type: ignore[union-attr] signature_length = struct.pack(" bytes: + credssp_context = self.session.auth.contexts[host] # type: ignore[union-attr] sealed_message = credssp_context.wrap(message) cipher_negotiated = credssp_context.tls_connection.get_cipher_name() @@ -181,13 +184,13 @@ def _build_credssp_message(self, message, host): return struct.pack(" bytes: + sealed_message, signature = self.session.auth.wrap_winrm(host, message) # type: ignore[union-attr] signature_length = struct.pack(" int: # I really don't like the way this works but can't find a better way, MS # allows you to get this info through the struct SecPkgContext_StreamSizes # but there is no GSSAPI/OpenSSL equivalent so we need to calculate it diff --git a/winrm/exceptions.py b/winrm/exceptions.py index 8e3e54f..6282b20 100644 --- a/winrm/exceptions.py +++ b/winrm/exceptions.py @@ -11,22 +11,22 @@ class WinRMTransportError(Exception): """WinRM errors specific to transport-level problems (unexpected HTTP error codes, etc)""" @property - def protocol(self): + def protocol(self) -> str: return self.args[0] @property - def code(self): + def code(self) -> int: return self.args[1] @property - def message(self): + def message(self) -> str: return "Bad HTTP response returned from server. Code {0}".format(self.code) @property - def response_text(self): + def response_text(self) -> str: return self.args[2] - def __str__(self): + def __str__(self) -> str: return self.message diff --git a/winrm/protocol.py b/winrm/protocol.py index 6d9b1f9..d096e29 100644 --- a/winrm/protocol.py +++ b/winrm/protocol.py @@ -3,6 +3,8 @@ from __future__ import annotations import base64 +import collections.abc +import typing as t import uuid import xml.etree.ElementTree as ET @@ -31,25 +33,25 @@ class Protocol(object): def __init__( self, - endpoint, - transport="plaintext", - username=None, - password=None, - realm=None, - service="HTTP", - keytab=None, - ca_trust_path="legacy_requests", - cert_pem=None, - cert_key_pem=None, - server_cert_validation="validate", - kerberos_delegation=False, - read_timeout_sec=DEFAULT_READ_TIMEOUT_SEC, - operation_timeout_sec=DEFAULT_OPERATION_TIMEOUT_SEC, - kerberos_hostname_override=None, - message_encryption="auto", - credssp_disable_tlsv1_2=False, - send_cbt=True, - proxy="legacy_requests", + endpoint: str, + transport: t.Literal["auto", "basic", "certificate", "ntlm", "kerberos", "credssp", "plaintext", "ssl"] = "plaintext", + username: str | None = None, + password: str | None = None, + realm: None = None, + service: str = "HTTP", + keytab: None = None, + ca_trust_path: t.Literal["legacy_requests"] | str = "legacy_requests", + cert_pem: str | None = None, + cert_key_pem: str | None = None, + server_cert_validation: t.Literal["validate", "ignore"] | None = "validate", + kerberos_delegation: bool = False, + read_timeout_sec: str | int = DEFAULT_READ_TIMEOUT_SEC, + operation_timeout_sec: str | int = DEFAULT_OPERATION_TIMEOUT_SEC, + kerberos_hostname_override: str | None = None, + message_encryption: t.Literal["auto", "always", "never"] = "auto", + credssp_disable_tlsv1_2: bool = False, + send_cbt: bool = True, + proxy: t.Literal["legacy_requests"] | str | None = "legacy_requests", ): """ @param string endpoint: the WinRM webservice endpoint @@ -58,7 +60,7 @@ def __init__( @param string password: password @param string realm: unused @param string service: the service name, default is HTTP - @param string keytab: the path to a keytab file if you are using one + @param string keytab: unused @param string ca_trust_path: Certification Authority trust path. If server_cert_validation is set to 'validate': 'legacy_requests'(default) to use environment variables, None to explicitly disallow any additional CA trust path @@ -125,15 +127,15 @@ def __init__( def open_shell( self, - i_stream="stdin", - o_stream="stdout stderr", - working_directory=None, - env_vars=None, - noprofile=False, - codepage=437, - lifetime=None, - idle_timeout=None, - ): + i_stream: str = "stdin", + o_stream: str = "stdout stderr", + working_directory: str | None = None, + env_vars: dict[str, str] | None = None, + noprofile: bool = False, + codepage: int = 437, + lifetime: None = None, + idle_timeout: str | int | None = None, + ) -> str: """ Create a Shell on the destination host @param string i_stream: Which input stream to open. Leave this alone @@ -187,13 +189,19 @@ def open_shell( # res = xmltodict.parse(res) # return res['s:Envelope']['s:Body']['x:ResourceCreated']['a:ReferenceParameters']['w:SelectorSet']['w:Selector']['#text'] root = ET.fromstring(res) - return next(node for node in root.findall(".//*") if node.get("Name") == "ShellId").text + return t.cast(str, next(node for node in root.findall(".//*") if node.get("Name") == "ShellId").text) # Helper method for building SOAP Header - def _get_soap_header(self, action=None, resource_uri=None, shell_id=None, message_id=None): + def _get_soap_header( + self, + action: str | None = None, + resource_uri: str | None = None, + shell_id: str | None = None, + message_id: uuid.UUID | None = None, + ) -> dict[str, t.Any]: if not message_id: message_id = uuid.uuid4() - header = { + header: dict[str, t.Any] = { "@xmlns:xsd": "http://www.w3.org/2001/XMLSchema", "@xmlns:xsi": "http://www.w3.org/2001/XMLSchema-instance", "@xmlns:env": xmlns["soapenv"], @@ -224,7 +232,7 @@ def _get_soap_header(self, action=None, resource_uri=None, shell_id=None, messag header["env:Header"]["w:SelectorSet"] = {"w:Selector": {"@Name": "ShellId", "#text": shell_id}} return header - def send_message(self, message): + def send_message(self, message: str) -> bytes: # TODO add message_id vs relates_to checking # TODO port error handling code try: @@ -257,15 +265,17 @@ def send_message(self, message): if fault_subcode is not None: fault_data["fault_subcode"] = fault_subcode.text - error_message = fault.find("soapenv:Reason/soapenv:Text", xmlns) - if error_message is not None: - error_message = error_message.text + error_message_node = fault.find("soapenv:Reason/soapenv:Text", xmlns) + if error_message_node is not None: + error_message = error_message_node.text else: error_message = "(no error message in fault)" raise WinRMError("{0} (extended fault data: {1})".format(error_message, fault_data)) - def close_shell(self, shell_id, close_session=True): + raise + + def close_shell(self, shell_id: str, close_session: bool = True) -> None: """ Close the shell @param string shell_id: The shell id on the remote machine. @@ -292,7 +302,7 @@ def close_shell(self, shell_id, close_session=True): res = self.send_message(xmltodict.unparse(req)) root = ET.fromstring(res) - relates_to = next(node for node in root.findall(".//*") if node.tag.endswith("RelatesTo")).text + relates_to = t.cast(str, next(node for node in root.findall(".//*") if node.tag.endswith("RelatesTo")).text) finally: # Close the transport if we are done with the shell. # This will ensure no lingering TCP connections are thrown back into a requests' connection pool. @@ -302,7 +312,14 @@ def close_shell(self, shell_id, close_session=True): # TODO change assert into user-friendly exception assert uuid.UUID(relates_to.replace("uuid:", "")) == message_id - def run_command(self, shell_id, command, arguments=(), console_mode_stdin=True, skip_cmd_shell=False): + def run_command( + self, + shell_id: str, + command: str, + arguments: collections.abc.Iterable[str | bytes] = (), + console_mode_stdin: bool = True, + skip_cmd_shell: bool = False, + ) -> str: """ Run a command on a machine with an open shell @param string shell_id: The shell id on the remote machine. @@ -339,9 +356,9 @@ def run_command(self, shell_id, command, arguments=(), console_mode_stdin=True, res = self.send_message(xmltodict.unparse(req)) root = ET.fromstring(res) command_id = next(node for node in root.findall(".//*") if node.tag.endswith("CommandId")).text - return command_id + return t.cast(str, command_id) - def cleanup_command(self, shell_id, command_id): + def cleanup_command(self, shell_id: str, command_id: str) -> None: """ Clean-up after a command. @see #run_command @param string shell_id: The shell id on the remote machine. @@ -369,11 +386,11 @@ def cleanup_command(self, shell_id, command_id): res = self.send_message(xmltodict.unparse(req)) root = ET.fromstring(res) - relates_to = next(node for node in root.findall(".//*") if node.tag.endswith("RelatesTo")).text + relates_to = t.cast(str, next(node for node in root.findall(".//*") if node.tag.endswith("RelatesTo")).text) # TODO change assert into user-friendly exception assert uuid.UUID(relates_to.replace("uuid:", "")) == message_id - def send_command_input(self, shell_id, command_id, stdin_input, end=False): + def send_command_input(self, shell_id: str, command_id: str, stdin_input: str | bytes, end: bool = False) -> None: """ Send input to the given shell and command. @param string shell_id: The shell id on the remote machine. @@ -408,7 +425,7 @@ def send_command_input(self, shell_id, command_id, stdin_input, end=False): stdin_envelope["#text"] = base64.b64encode(stdin_input) self.send_message(xmltodict.unparse(req)) - def get_command_output(self, shell_id, command_id): + def get_command_output(self, shell_id: str, command_id: str) -> tuple[bytes, bytes, int]: """ Get the Output of the given shell and command @param string shell_id: The shell id on the remote machine. @@ -433,7 +450,7 @@ def get_command_output(self, shell_id, command_id): pass return b"".join(stdout_buffer), b"".join(stderr_buffer), return_code - def _raw_get_command_output(self, shell_id, command_id): + def _raw_get_command_output(self, shell_id: str, command_id: str) -> tuple[bytes, bytes, int, bool]: req = { "env:Envelope": self._get_soap_header( resource_uri="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd", # NOQA @@ -470,6 +487,6 @@ def _raw_get_command_output(self, shell_id, command_id): # command_done = len([node for node in root.findall(".//*") if node.get("State", "").endswith("CommandState/Done")]) == 1 if command_done: - return_code = int(next(node for node in root.findall(".//*") if node.tag.endswith("ExitCode")).text) + return_code = int(next(node for node in root.findall(".//*") if node.tag.endswith("ExitCode")).text or -1) return stdout, stderr, return_code, command_done diff --git a/winrm/py.typed b/winrm/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/winrm/transport.py b/winrm/transport.py index d432d22..fa2e781 100644 --- a/winrm/transport.py +++ b/winrm/transport.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import typing as t import warnings import requests @@ -40,7 +41,7 @@ __all__ = ["Transport"] -def strtobool(value): +def strtobool(value: str) -> bool: value = value.lower() if value in ("true", "t", "yes", "y", "on", "1"): return True @@ -59,27 +60,27 @@ class UnsupportedAuthArgument(Warning): class Transport(object): def __init__( self, - endpoint, - username=None, - password=None, - realm=None, - service=None, - keytab=None, - ca_trust_path="legacy_requests", - cert_pem=None, - cert_key_pem=None, - read_timeout_sec=None, - server_cert_validation="validate", - kerberos_delegation=False, - kerberos_hostname_override=None, - auth_method="auto", - message_encryption="auto", - credssp_disable_tlsv1_2=False, - credssp_auth_mechanism="auto", - credssp_minimum_version=2, - send_cbt=True, - proxy="legacy_requests", - ): + endpoint: str, + username: str | None = None, + password: str | None = None, + realm: None = None, + service: str | None = None, + keytab: None = None, + ca_trust_path: t.Literal["legacy_requests"] | str = "legacy_requests", + cert_pem: str | None = None, + cert_key_pem: str | None = None, + read_timeout_sec: int | None = None, + server_cert_validation: t.Literal["validate", "ignore"] | None = "validate", + kerberos_delegation: bool | str = False, + kerberos_hostname_override: str | None = None, + auth_method: t.Literal["auto", "basic", "certificate", "ntlm", "kerberos", "credssp", "plaintext", "ssl"] = "auto", + message_encryption: t.Literal["auto", "always", "never"] = "auto", + credssp_disable_tlsv1_2: bool = False, + credssp_auth_mechanism: t.Literal["auto", "ntlm", "kerberos"] = "auto", + credssp_minimum_version: int = 2, + send_cbt: bool = True, + proxy: t.Literal["legacy_requests"] | str | None = "legacy_requests", + ) -> None: self.endpoint = endpoint self.username = username self.password = password @@ -161,14 +162,17 @@ def __init__( if self.password is None: raise InvalidCredentialsError("auth method %s requires a password" % self.auth_method) - self.session = None + self.session: requests.Session | None = None # Used for encrypting messages - self.encryption = None # The Pywinrm Encryption class used to encrypt/decrypt messages + self.encryption: Encryption | None = None # The Pywinrm Encryption class used to encrypt/decrypt messages if self.message_encryption not in ["auto", "always", "never"]: raise WinRMError("invalid message_encryption arg: %s. Should be 'auto', 'always', or 'never'" % self.message_encryption) - def build_session(self): + def build_session(self) -> requests.Session: + if self.session: + return self.session + session = requests.Session() proxies = dict() @@ -234,7 +238,7 @@ def build_session(self): if not HAVE_KERBEROS: raise WinRMError("requested auth method is kerberos, but pykerberos is not installed") - session.auth = HTTPKerberosAuth( + kerb_auth = session.auth = HTTPKerberosAuth( mutual_authentication=REQUIRED, delegate=self.kerberos_delegation, force_preemptive=True, @@ -244,14 +248,17 @@ def build_session(self): service=self.service, send_cbt=self.send_cbt, ) - encryption_available = hasattr(session.auth, "winrm_encryption_available") and session.auth.winrm_encryption_available + encryption_available = hasattr(session.auth, "winrm_encryption_available") and kerb_auth.winrm_encryption_available elif self.auth_method in ["certificate", "ssl"]: if self.auth_method == "ssl" and not self.cert_pem and not self.cert_key_pem: # 'ssl' was overloaded for HTTPS with optional certificate auth, # fall back to basic auth if no cert specified - session.auth = requests.auth.HTTPBasicAuth(username=self.username, password=self.password) + session.auth = requests.auth.HTTPBasicAuth( + username=self.username or "", + password=self.password or "", + ) else: - session.cert = (self.cert_pem, self.cert_key_pem) + session.cert = (self.cert_pem or "", self.cert_key_pem or "") session.headers["Authorization"] = "http://schemas.dmtf.org/wbem/wsman/1/wsman/secprofile/https/mutual" elif self.auth_method == "ntlm": if not HAVE_NTLM: @@ -266,7 +273,10 @@ def build_session(self): encryption_available = hasattr(session.auth, "session_security") # TODO: ssl is not exactly right here- should really be client_cert elif self.auth_method in ["basic", "plaintext"]: - session.auth = requests.auth.HTTPBasicAuth(username=self.username, password=self.password) + session.auth = requests.auth.HTTPBasicAuth( + username=self.username or "", + password=self.password or "", + ) elif self.auth_method == "credssp": if not HAVE_CREDSSP: raise WinRMError("requests auth method is credssp, but requests-credssp is not installed") @@ -290,26 +300,27 @@ def build_session(self): raise WinRMError("message encryption is set to 'always' but the selected auth method %s does not support it" % self.auth_method) elif encryption_available: if self.message_encryption == "always": - self.setup_encryption() + self.setup_encryption(session) elif self.message_encryption == "auto" and not self.endpoint.lower().startswith("https"): - self.setup_encryption() + self.setup_encryption(session) - def setup_encryption(self): + return session + + def setup_encryption(self, session: requests.Session) -> None: # Security context doesn't exist, sending blank message to initialise context request = requests.Request("POST", self.endpoint, data=None) - prepared_request = self.session.prepare_request(request) - self._send_message_request(prepared_request, "") - self.encryption = Encryption(self.session, self.auth_method) + prepared_request = session.prepare_request(request) + self._send_message_request(session, prepared_request) + self.encryption = Encryption(session, self.auth_method) - def close_session(self): + def close_session(self) -> None: if not self.session: return self.session.close() self.session = None - def send_message(self, message): - if not self.session: - self.build_session() + def send_message(self, message: str | bytes) -> bytes: + session = self.build_session() # urllib3 fails on SSL retries with unicode buffers- must send it a byte string # see https://github.com/shazow/urllib3/issues/717 @@ -317,17 +328,17 @@ def send_message(self, message): message = message.encode("utf-8") if self.encryption: - prepared_request = self.encryption.prepare_encrypted_request(self.session, self.endpoint, message) + prepared_request = self.encryption.prepare_encrypted_request(session, self.endpoint, message) else: request = requests.Request("POST", self.endpoint, data=message) - prepared_request = self.session.prepare_request(request) + prepared_request = session.prepare_request(request) - response = self._send_message_request(prepared_request, message) + response = self._send_message_request(session, prepared_request) return self._get_message_response_text(response) - def _send_message_request(self, prepared_request, message): + def _send_message_request(self, session: requests.Session, prepared_request: requests.PreparedRequest) -> requests.Response: try: - response = self.session.send(prepared_request, timeout=self.read_timeout_sec) + response = session.send(prepared_request, timeout=self.read_timeout_sec) response.raise_for_status() return response except requests.HTTPError as ex: @@ -336,11 +347,11 @@ def _send_message_request(self, prepared_request, message): if ex.response.content: response_text = self._get_message_response_text(ex.response) else: - response_text = "" + response_text = b"" - raise WinRMTransportError("http", ex.response.status_code, response_text) + raise WinRMTransportError("http", ex.response.status_code, response_text.decode()) - def _get_message_response_text(self, response): + def _get_message_response_text(self, response: requests.Response) -> bytes: if self.encryption: response_text = self.encryption.parse_encrypted_response(response) else: