From fbb05e87dc8c7ad79bc5dc1d334c39f28bc80b73 Mon Sep 17 00:00:00 2001 From: Jordan Borean Date: Thu, 6 Jun 2024 10:37:33 +1000 Subject: [PATCH] Add type annotations (#381) Add type annotations to the pywinrm library. This can help reduce type related bugs inside the library and also help callers to figure out the correct values that can be used in the public API. --- .github/workflows/ci.yml | 1 + .gitignore | 1 + pyproject.toml | 37 ++++++++++++- requirements-test.txt | 3 ++ winrm/__init__.py | 26 ++++++---- winrm/encryption.py | 43 ++++++++------- winrm/exceptions.py | 10 ++-- winrm/protocol.py | 109 ++++++++++++++++++++++----------------- winrm/py.typed | 0 winrm/transport.py | 107 +++++++++++++++++++++----------------- 10 files changed, 207 insertions(+), 130 deletions(-) create mode 100644 winrm/py.typed diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9b2acbd8..94c09caa 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -79,6 +79,7 @@ jobs: run: | python -m black . --check python -m isort . --check-only + python -m mypy . pytest -v --cov=winrm --cov-report=term-missing winrm/tests/ - name: upload coverage data diff --git a/.gitignore b/.gitignore index 7d6a3ec9..c7019aac 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,4 @@ __pycache__ /winrm/tests/config.json .pytest_cache venv +.mypy_cache \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 3c0b4e83..df422e72 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ include-package-data = true packages = ["winrm"] [tool.setuptools.package-data] +"winrm" = ["py.typed"] "winrm.tests" = ["*.ps1"] [tool.setuptools.dynamic] @@ -77,4 +78,38 @@ exclude = ''' ''' [tool.isort] -profile = "black" \ No newline at end of file +profile = "black" + +[tool.mypy] +exclude = "build/|winrm/tests/|winrm/vendor/" +mypy_path = "$MYPY_CONFIG_FILE_DIR" +python_version = "3.8" +show_error_codes = true +show_column_numbers = true +disallow_any_unimported = true +disallow_untyped_calls = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +disallow_untyped_decorators = true +no_implicit_reexport = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = true +warn_unreachable = true + +[[tool.mypy.overrides]] +module = "winrm.vendor.*" +follow_imports = "skip" + +[[tool.mypy.overrides]] +module = "requests.packages.urllib3.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "requests_credssp" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "requests_ntlm" +ignore_missing_imports = true \ No newline at end of file diff --git a/requirements-test.txt b/requirements-test.txt index efe27f03..8f54e8b1 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,6 +1,9 @@ # this assumes the base requirements have been satisfied via setup.py black == 24.4.2 isort == 5.13.2 +mypy == 1.10.0 pytest pytest-cov mock +types-requests +types-xmltodict \ No newline at end of file diff --git a/winrm/__init__.py b/winrm/__init__.py index 6b76e4fa..ec31a328 100644 --- a/winrm/__init__.py +++ b/winrm/__init__.py @@ -1,6 +1,8 @@ from __future__ import annotations +import collections.abc import re +import typing as t import warnings import xml.etree.ElementTree as ET from base64 import b64encode @@ -22,22 +24,22 @@ class Response(object): """Response from a remote command execution""" - def __init__(self, args): + def __init__(self, args: tuple[bytes, bytes, int]) -> None: self.std_out, self.std_err, self.status_code = args - def __repr__(self): + def __repr__(self) -> str: # TODO put tree dots at the end if out/err was truncated - return ''.format(self.status_code, self.std_out[:20], self.std_err[:20]) + return ''.format(self.status_code, self.std_out[:20], self.std_err[:20]) class Session(object): # TODO implement context manager methods - def __init__(self, target, auth, **kwargs): + def __init__(self, target: str, auth: tuple[str, str], **kwargs: t.Any) -> None: username, password = auth self.url = self._build_url(target, kwargs.get("transport", "plaintext")) self.protocol = Protocol(self.url, username=username, password=password, **kwargs) - def run_cmd(self, command, args=()): + def run_cmd(self, command: str, args: collections.abc.Iterable[str | bytes] = ()) -> Response: # TODO optimize perf. Do not call open/close shell every time shell_id = self.protocol.open_shell() command_id = self.protocol.run_command(shell_id, command, args) @@ -46,7 +48,7 @@ def run_cmd(self, command, args=()): self.protocol.close_shell(shell_id) return rs - def run_ps(self, script): + def run_ps(self, script: str) -> Response: """base64 encodes a Powershell script and executes the powershell encoded script command """ @@ -59,7 +61,7 @@ def run_ps(self, script): rs.std_err = self._clean_error_msg(rs.std_err) return rs - def _clean_error_msg(self, msg): + def _clean_error_msg(self, msg: bytes) -> bytes: """converts a Powershell CLIXML message to a more human readable string""" # TODO prepare unit test, beautify code # if the msg does not start with this, return it as is @@ -77,7 +79,8 @@ def _clean_error_msg(self, msg): for s in nodes: # append error msg string to result, also # the hex chars represent CRLF so we replace with newline - new_msg += s.text.replace("_x000D__x000A_", "\n") + if s.text: + new_msg += s.text.replace("_x000D__x000A_", "\n") except Exception as e: # if any of the above fails, the msg was not true xml # print a warning and return the original string @@ -93,7 +96,7 @@ def _clean_error_msg(self, msg): # just return the original message return msg - def _strip_namespace(self, xml): + def _strip_namespace(self, xml: bytes) -> bytes: """strips any namespaces from an xml string""" p = re.compile(b'xmlns=*[""][^""]*[""]') allmatches = p.finditer(xml) @@ -102,8 +105,11 @@ def _strip_namespace(self, xml): return xml @staticmethod - def _build_url(target, transport): + def _build_url(target: str, transport: str) -> str: match = re.match(r"(?i)^((?Phttp[s]?)://)?(?P[0-9a-z-_.]+)(:(?P\d+))?(?P(/)?(wsman)?)?", target) # NOQA + if not match: + raise ValueError("Invalid target URL: {0}".format(target)) + scheme = match.group("scheme") if not scheme: # TODO do we have anything other than HTTP/HTTPS diff --git a/winrm/encryption.py b/winrm/encryption.py index 0cf02f6b..50ceb0e3 100644 --- a/winrm/encryption.py +++ b/winrm/encryption.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re import struct from urllib.parse import urlsplit @@ -12,7 +14,7 @@ class Encryption(object): SIXTEN_KB = 16384 MIME_BOUNDARY = b"--Encrypted Boundary" - def __init__(self, session, protocol): + def __init__(self, session: requests.Session, protocol: str) -> None: """ [MS-WSMV] v30.0 2016-07-14 @@ -51,7 +53,7 @@ def __init__(self, session, protocol): else: raise WinRMError("Encryption for protocol '%s' not supported in pywinrm" % protocol) - def prepare_encrypted_request(self, session, endpoint, message): + def prepare_encrypted_request(self, session: requests.Session, endpoint: str | bytes, message: bytes) -> requests.PreparedRequest: """ Creates a prepared request to send to the server with an encrypted message and correct headers @@ -77,12 +79,12 @@ def prepare_encrypted_request(self, session, endpoint, message): request = requests.Request("POST", endpoint, data=encrypted_message) prepared_request = session.prepare_request(request) - prepared_request.headers["Content-Length"] = str(len(prepared_request.body)) + prepared_request.headers["Content-Length"] = str(len(prepared_request.body)) if prepared_request.body else "0" prepared_request.headers["Content-Type"] = '{0};protocol="{1}";boundary="Encrypted Boundary"'.format(content_type, self.protocol_string.decode()) return prepared_request - def parse_encrypted_response(self, response): + def parse_encrypted_response(self, response: requests.Response) -> bytes: """ Takes in the encrypted response from the server and decrypts it @@ -90,15 +92,16 @@ def parse_encrypted_response(self, response): :return: The unencrypted message from the server """ content_type = response.headers["Content-Type"] + if 'protocol="{0}"'.format(self.protocol_string.decode()) in content_type: host = urlsplit(response.request.url).hostname msg = self._decrypt_response(response, host) else: - msg = response.text + msg = response.content return msg - def _encrypt_message(self, message, host): + def _encrypt_message(self, message: bytes, host: str | bytes | None) -> bytes: message_length = str(len(message)).encode() encrypted_stream = self._build_message(message, host) @@ -111,7 +114,7 @@ def _encrypt_message(self, message, host): return message_payload - def _decrypt_response(self, response, host): + def _decrypt_response(self, response: requests.Response, host: str | bytes | None) -> bytes: parts = response.content.split(self.MIME_BOUNDARY + b"\r\n") parts = list(filter(None, parts)) # filter out empty parts of the split message = b"" @@ -139,41 +142,41 @@ def _decrypt_response(self, response, host): return message - def _decrypt_ntlm_message(self, encrypted_data, host): + def _decrypt_ntlm_message(self, encrypted_data: bytes, host: str | bytes | None) -> bytes: signature_length = struct.unpack(" bytes: # trailer_length = struct.unpack(" bytes: signature_length = struct.unpack(" bytes: + sealed_message, signature = self.session.auth.session_security.wrap(message) # type: ignore[union-attr] signature_length = struct.pack(" bytes: + credssp_context = self.session.auth.contexts[host] # type: ignore[union-attr] sealed_message = credssp_context.wrap(message) cipher_negotiated = credssp_context.tls_connection.get_cipher_name() @@ -181,13 +184,13 @@ def _build_credssp_message(self, message, host): return struct.pack(" bytes: + sealed_message, signature = self.session.auth.wrap_winrm(host, message) # type: ignore[union-attr] signature_length = struct.pack(" int: # I really don't like the way this works but can't find a better way, MS # allows you to get this info through the struct SecPkgContext_StreamSizes # but there is no GSSAPI/OpenSSL equivalent so we need to calculate it diff --git a/winrm/exceptions.py b/winrm/exceptions.py index 8e3e54f9..6282b20e 100644 --- a/winrm/exceptions.py +++ b/winrm/exceptions.py @@ -11,22 +11,22 @@ class WinRMTransportError(Exception): """WinRM errors specific to transport-level problems (unexpected HTTP error codes, etc)""" @property - def protocol(self): + def protocol(self) -> str: return self.args[0] @property - def code(self): + def code(self) -> int: return self.args[1] @property - def message(self): + def message(self) -> str: return "Bad HTTP response returned from server. Code {0}".format(self.code) @property - def response_text(self): + def response_text(self) -> str: return self.args[2] - def __str__(self): + def __str__(self) -> str: return self.message diff --git a/winrm/protocol.py b/winrm/protocol.py index 6d9b1f92..d096e29d 100644 --- a/winrm/protocol.py +++ b/winrm/protocol.py @@ -3,6 +3,8 @@ from __future__ import annotations import base64 +import collections.abc +import typing as t import uuid import xml.etree.ElementTree as ET @@ -31,25 +33,25 @@ class Protocol(object): def __init__( self, - endpoint, - transport="plaintext", - username=None, - password=None, - realm=None, - service="HTTP", - keytab=None, - ca_trust_path="legacy_requests", - cert_pem=None, - cert_key_pem=None, - server_cert_validation="validate", - kerberos_delegation=False, - read_timeout_sec=DEFAULT_READ_TIMEOUT_SEC, - operation_timeout_sec=DEFAULT_OPERATION_TIMEOUT_SEC, - kerberos_hostname_override=None, - message_encryption="auto", - credssp_disable_tlsv1_2=False, - send_cbt=True, - proxy="legacy_requests", + endpoint: str, + transport: t.Literal["auto", "basic", "certificate", "ntlm", "kerberos", "credssp", "plaintext", "ssl"] = "plaintext", + username: str | None = None, + password: str | None = None, + realm: None = None, + service: str = "HTTP", + keytab: None = None, + ca_trust_path: t.Literal["legacy_requests"] | str = "legacy_requests", + cert_pem: str | None = None, + cert_key_pem: str | None = None, + server_cert_validation: t.Literal["validate", "ignore"] | None = "validate", + kerberos_delegation: bool = False, + read_timeout_sec: str | int = DEFAULT_READ_TIMEOUT_SEC, + operation_timeout_sec: str | int = DEFAULT_OPERATION_TIMEOUT_SEC, + kerberos_hostname_override: str | None = None, + message_encryption: t.Literal["auto", "always", "never"] = "auto", + credssp_disable_tlsv1_2: bool = False, + send_cbt: bool = True, + proxy: t.Literal["legacy_requests"] | str | None = "legacy_requests", ): """ @param string endpoint: the WinRM webservice endpoint @@ -58,7 +60,7 @@ def __init__( @param string password: password @param string realm: unused @param string service: the service name, default is HTTP - @param string keytab: the path to a keytab file if you are using one + @param string keytab: unused @param string ca_trust_path: Certification Authority trust path. If server_cert_validation is set to 'validate': 'legacy_requests'(default) to use environment variables, None to explicitly disallow any additional CA trust path @@ -125,15 +127,15 @@ def __init__( def open_shell( self, - i_stream="stdin", - o_stream="stdout stderr", - working_directory=None, - env_vars=None, - noprofile=False, - codepage=437, - lifetime=None, - idle_timeout=None, - ): + i_stream: str = "stdin", + o_stream: str = "stdout stderr", + working_directory: str | None = None, + env_vars: dict[str, str] | None = None, + noprofile: bool = False, + codepage: int = 437, + lifetime: None = None, + idle_timeout: str | int | None = None, + ) -> str: """ Create a Shell on the destination host @param string i_stream: Which input stream to open. Leave this alone @@ -187,13 +189,19 @@ def open_shell( # res = xmltodict.parse(res) # return res['s:Envelope']['s:Body']['x:ResourceCreated']['a:ReferenceParameters']['w:SelectorSet']['w:Selector']['#text'] root = ET.fromstring(res) - return next(node for node in root.findall(".//*") if node.get("Name") == "ShellId").text + return t.cast(str, next(node for node in root.findall(".//*") if node.get("Name") == "ShellId").text) # Helper method for building SOAP Header - def _get_soap_header(self, action=None, resource_uri=None, shell_id=None, message_id=None): + def _get_soap_header( + self, + action: str | None = None, + resource_uri: str | None = None, + shell_id: str | None = None, + message_id: uuid.UUID | None = None, + ) -> dict[str, t.Any]: if not message_id: message_id = uuid.uuid4() - header = { + header: dict[str, t.Any] = { "@xmlns:xsd": "http://www.w3.org/2001/XMLSchema", "@xmlns:xsi": "http://www.w3.org/2001/XMLSchema-instance", "@xmlns:env": xmlns["soapenv"], @@ -224,7 +232,7 @@ def _get_soap_header(self, action=None, resource_uri=None, shell_id=None, messag header["env:Header"]["w:SelectorSet"] = {"w:Selector": {"@Name": "ShellId", "#text": shell_id}} return header - def send_message(self, message): + def send_message(self, message: str) -> bytes: # TODO add message_id vs relates_to checking # TODO port error handling code try: @@ -257,15 +265,17 @@ def send_message(self, message): if fault_subcode is not None: fault_data["fault_subcode"] = fault_subcode.text - error_message = fault.find("soapenv:Reason/soapenv:Text", xmlns) - if error_message is not None: - error_message = error_message.text + error_message_node = fault.find("soapenv:Reason/soapenv:Text", xmlns) + if error_message_node is not None: + error_message = error_message_node.text else: error_message = "(no error message in fault)" raise WinRMError("{0} (extended fault data: {1})".format(error_message, fault_data)) - def close_shell(self, shell_id, close_session=True): + raise + + def close_shell(self, shell_id: str, close_session: bool = True) -> None: """ Close the shell @param string shell_id: The shell id on the remote machine. @@ -292,7 +302,7 @@ def close_shell(self, shell_id, close_session=True): res = self.send_message(xmltodict.unparse(req)) root = ET.fromstring(res) - relates_to = next(node for node in root.findall(".//*") if node.tag.endswith("RelatesTo")).text + relates_to = t.cast(str, next(node for node in root.findall(".//*") if node.tag.endswith("RelatesTo")).text) finally: # Close the transport if we are done with the shell. # This will ensure no lingering TCP connections are thrown back into a requests' connection pool. @@ -302,7 +312,14 @@ def close_shell(self, shell_id, close_session=True): # TODO change assert into user-friendly exception assert uuid.UUID(relates_to.replace("uuid:", "")) == message_id - def run_command(self, shell_id, command, arguments=(), console_mode_stdin=True, skip_cmd_shell=False): + def run_command( + self, + shell_id: str, + command: str, + arguments: collections.abc.Iterable[str | bytes] = (), + console_mode_stdin: bool = True, + skip_cmd_shell: bool = False, + ) -> str: """ Run a command on a machine with an open shell @param string shell_id: The shell id on the remote machine. @@ -339,9 +356,9 @@ def run_command(self, shell_id, command, arguments=(), console_mode_stdin=True, res = self.send_message(xmltodict.unparse(req)) root = ET.fromstring(res) command_id = next(node for node in root.findall(".//*") if node.tag.endswith("CommandId")).text - return command_id + return t.cast(str, command_id) - def cleanup_command(self, shell_id, command_id): + def cleanup_command(self, shell_id: str, command_id: str) -> None: """ Clean-up after a command. @see #run_command @param string shell_id: The shell id on the remote machine. @@ -369,11 +386,11 @@ def cleanup_command(self, shell_id, command_id): res = self.send_message(xmltodict.unparse(req)) root = ET.fromstring(res) - relates_to = next(node for node in root.findall(".//*") if node.tag.endswith("RelatesTo")).text + relates_to = t.cast(str, next(node for node in root.findall(".//*") if node.tag.endswith("RelatesTo")).text) # TODO change assert into user-friendly exception assert uuid.UUID(relates_to.replace("uuid:", "")) == message_id - def send_command_input(self, shell_id, command_id, stdin_input, end=False): + def send_command_input(self, shell_id: str, command_id: str, stdin_input: str | bytes, end: bool = False) -> None: """ Send input to the given shell and command. @param string shell_id: The shell id on the remote machine. @@ -408,7 +425,7 @@ def send_command_input(self, shell_id, command_id, stdin_input, end=False): stdin_envelope["#text"] = base64.b64encode(stdin_input) self.send_message(xmltodict.unparse(req)) - def get_command_output(self, shell_id, command_id): + def get_command_output(self, shell_id: str, command_id: str) -> tuple[bytes, bytes, int]: """ Get the Output of the given shell and command @param string shell_id: The shell id on the remote machine. @@ -433,7 +450,7 @@ def get_command_output(self, shell_id, command_id): pass return b"".join(stdout_buffer), b"".join(stderr_buffer), return_code - def _raw_get_command_output(self, shell_id, command_id): + def _raw_get_command_output(self, shell_id: str, command_id: str) -> tuple[bytes, bytes, int, bool]: req = { "env:Envelope": self._get_soap_header( resource_uri="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd", # NOQA @@ -470,6 +487,6 @@ def _raw_get_command_output(self, shell_id, command_id): # command_done = len([node for node in root.findall(".//*") if node.get("State", "").endswith("CommandState/Done")]) == 1 if command_done: - return_code = int(next(node for node in root.findall(".//*") if node.tag.endswith("ExitCode")).text) + return_code = int(next(node for node in root.findall(".//*") if node.tag.endswith("ExitCode")).text or -1) return stdout, stderr, return_code, command_done diff --git a/winrm/py.typed b/winrm/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/winrm/transport.py b/winrm/transport.py index d432d22f..fa2e781f 100644 --- a/winrm/transport.py +++ b/winrm/transport.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import typing as t import warnings import requests @@ -40,7 +41,7 @@ __all__ = ["Transport"] -def strtobool(value): +def strtobool(value: str) -> bool: value = value.lower() if value in ("true", "t", "yes", "y", "on", "1"): return True @@ -59,27 +60,27 @@ class UnsupportedAuthArgument(Warning): class Transport(object): def __init__( self, - endpoint, - username=None, - password=None, - realm=None, - service=None, - keytab=None, - ca_trust_path="legacy_requests", - cert_pem=None, - cert_key_pem=None, - read_timeout_sec=None, - server_cert_validation="validate", - kerberos_delegation=False, - kerberos_hostname_override=None, - auth_method="auto", - message_encryption="auto", - credssp_disable_tlsv1_2=False, - credssp_auth_mechanism="auto", - credssp_minimum_version=2, - send_cbt=True, - proxy="legacy_requests", - ): + endpoint: str, + username: str | None = None, + password: str | None = None, + realm: None = None, + service: str | None = None, + keytab: None = None, + ca_trust_path: t.Literal["legacy_requests"] | str = "legacy_requests", + cert_pem: str | None = None, + cert_key_pem: str | None = None, + read_timeout_sec: int | None = None, + server_cert_validation: t.Literal["validate", "ignore"] | None = "validate", + kerberos_delegation: bool | str = False, + kerberos_hostname_override: str | None = None, + auth_method: t.Literal["auto", "basic", "certificate", "ntlm", "kerberos", "credssp", "plaintext", "ssl"] = "auto", + message_encryption: t.Literal["auto", "always", "never"] = "auto", + credssp_disable_tlsv1_2: bool = False, + credssp_auth_mechanism: t.Literal["auto", "ntlm", "kerberos"] = "auto", + credssp_minimum_version: int = 2, + send_cbt: bool = True, + proxy: t.Literal["legacy_requests"] | str | None = "legacy_requests", + ) -> None: self.endpoint = endpoint self.username = username self.password = password @@ -161,14 +162,17 @@ def __init__( if self.password is None: raise InvalidCredentialsError("auth method %s requires a password" % self.auth_method) - self.session = None + self.session: requests.Session | None = None # Used for encrypting messages - self.encryption = None # The Pywinrm Encryption class used to encrypt/decrypt messages + self.encryption: Encryption | None = None # The Pywinrm Encryption class used to encrypt/decrypt messages if self.message_encryption not in ["auto", "always", "never"]: raise WinRMError("invalid message_encryption arg: %s. Should be 'auto', 'always', or 'never'" % self.message_encryption) - def build_session(self): + def build_session(self) -> requests.Session: + if self.session: + return self.session + session = requests.Session() proxies = dict() @@ -234,7 +238,7 @@ def build_session(self): if not HAVE_KERBEROS: raise WinRMError("requested auth method is kerberos, but pykerberos is not installed") - session.auth = HTTPKerberosAuth( + kerb_auth = session.auth = HTTPKerberosAuth( mutual_authentication=REQUIRED, delegate=self.kerberos_delegation, force_preemptive=True, @@ -244,14 +248,17 @@ def build_session(self): service=self.service, send_cbt=self.send_cbt, ) - encryption_available = hasattr(session.auth, "winrm_encryption_available") and session.auth.winrm_encryption_available + encryption_available = hasattr(session.auth, "winrm_encryption_available") and kerb_auth.winrm_encryption_available elif self.auth_method in ["certificate", "ssl"]: if self.auth_method == "ssl" and not self.cert_pem and not self.cert_key_pem: # 'ssl' was overloaded for HTTPS with optional certificate auth, # fall back to basic auth if no cert specified - session.auth = requests.auth.HTTPBasicAuth(username=self.username, password=self.password) + session.auth = requests.auth.HTTPBasicAuth( + username=self.username or "", + password=self.password or "", + ) else: - session.cert = (self.cert_pem, self.cert_key_pem) + session.cert = (self.cert_pem or "", self.cert_key_pem or "") session.headers["Authorization"] = "http://schemas.dmtf.org/wbem/wsman/1/wsman/secprofile/https/mutual" elif self.auth_method == "ntlm": if not HAVE_NTLM: @@ -266,7 +273,10 @@ def build_session(self): encryption_available = hasattr(session.auth, "session_security") # TODO: ssl is not exactly right here- should really be client_cert elif self.auth_method in ["basic", "plaintext"]: - session.auth = requests.auth.HTTPBasicAuth(username=self.username, password=self.password) + session.auth = requests.auth.HTTPBasicAuth( + username=self.username or "", + password=self.password or "", + ) elif self.auth_method == "credssp": if not HAVE_CREDSSP: raise WinRMError("requests auth method is credssp, but requests-credssp is not installed") @@ -290,26 +300,27 @@ def build_session(self): raise WinRMError("message encryption is set to 'always' but the selected auth method %s does not support it" % self.auth_method) elif encryption_available: if self.message_encryption == "always": - self.setup_encryption() + self.setup_encryption(session) elif self.message_encryption == "auto" and not self.endpoint.lower().startswith("https"): - self.setup_encryption() + self.setup_encryption(session) - def setup_encryption(self): + return session + + def setup_encryption(self, session: requests.Session) -> None: # Security context doesn't exist, sending blank message to initialise context request = requests.Request("POST", self.endpoint, data=None) - prepared_request = self.session.prepare_request(request) - self._send_message_request(prepared_request, "") - self.encryption = Encryption(self.session, self.auth_method) + prepared_request = session.prepare_request(request) + self._send_message_request(session, prepared_request) + self.encryption = Encryption(session, self.auth_method) - def close_session(self): + def close_session(self) -> None: if not self.session: return self.session.close() self.session = None - def send_message(self, message): - if not self.session: - self.build_session() + def send_message(self, message: str | bytes) -> bytes: + session = self.build_session() # urllib3 fails on SSL retries with unicode buffers- must send it a byte string # see https://github.com/shazow/urllib3/issues/717 @@ -317,17 +328,17 @@ def send_message(self, message): message = message.encode("utf-8") if self.encryption: - prepared_request = self.encryption.prepare_encrypted_request(self.session, self.endpoint, message) + prepared_request = self.encryption.prepare_encrypted_request(session, self.endpoint, message) else: request = requests.Request("POST", self.endpoint, data=message) - prepared_request = self.session.prepare_request(request) + prepared_request = session.prepare_request(request) - response = self._send_message_request(prepared_request, message) + response = self._send_message_request(session, prepared_request) return self._get_message_response_text(response) - def _send_message_request(self, prepared_request, message): + def _send_message_request(self, session: requests.Session, prepared_request: requests.PreparedRequest) -> requests.Response: try: - response = self.session.send(prepared_request, timeout=self.read_timeout_sec) + response = session.send(prepared_request, timeout=self.read_timeout_sec) response.raise_for_status() return response except requests.HTTPError as ex: @@ -336,11 +347,11 @@ def _send_message_request(self, prepared_request, message): if ex.response.content: response_text = self._get_message_response_text(ex.response) else: - response_text = "" + response_text = b"" - raise WinRMTransportError("http", ex.response.status_code, response_text) + raise WinRMTransportError("http", ex.response.status_code, response_text.decode()) - def _get_message_response_text(self, response): + def _get_message_response_text(self, response: requests.Response) -> bytes: if self.encryption: response_text = self.encryption.parse_encrypted_response(response) else: