Skip to content

Commit

Permalink
Add type annotations (#381)
Browse files Browse the repository at this point in the history
Add type annotations to the pywinrm library. This can help reduce type
related bugs inside the library and also help callers to figure out the
correct values that can be used in the public API.
  • Loading branch information
jborean93 authored Jun 6, 2024
1 parent 6f2ef68 commit fbb05e8
Show file tree
Hide file tree
Showing 10 changed files with 207 additions and 130 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ jobs:
run: |
python -m black . --check
python -m isort . --check-only
python -m mypy .
pytest -v --cov=winrm --cov-report=term-missing winrm/tests/
- name: upload coverage data
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ __pycache__
/winrm/tests/config.json
.pytest_cache
venv
.mypy_cache
37 changes: 36 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ include-package-data = true
packages = ["winrm"]

[tool.setuptools.package-data]
"winrm" = ["py.typed"]
"winrm.tests" = ["*.ps1"]

[tool.setuptools.dynamic]
Expand All @@ -77,4 +78,38 @@ exclude = '''
'''

[tool.isort]
profile = "black"
profile = "black"

[tool.mypy]
exclude = "build/|winrm/tests/|winrm/vendor/"
mypy_path = "$MYPY_CONFIG_FILE_DIR"
python_version = "3.8"
show_error_codes = true
show_column_numbers = true
disallow_any_unimported = true
disallow_untyped_calls = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
check_untyped_defs = true
disallow_untyped_decorators = true
no_implicit_reexport = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_no_return = true
warn_unreachable = true

[[tool.mypy.overrides]]
module = "winrm.vendor.*"
follow_imports = "skip"

[[tool.mypy.overrides]]
module = "requests.packages.urllib3.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "requests_credssp"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "requests_ntlm"
ignore_missing_imports = true
3 changes: 3 additions & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# this assumes the base requirements have been satisfied via setup.py
black == 24.4.2
isort == 5.13.2
mypy == 1.10.0
pytest
pytest-cov
mock
types-requests
types-xmltodict
26 changes: 16 additions & 10 deletions winrm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import collections.abc
import re
import typing as t
import warnings
import xml.etree.ElementTree as ET
from base64 import b64encode
Expand All @@ -22,22 +24,22 @@
class Response(object):
"""Response from a remote command execution"""

def __init__(self, args):
def __init__(self, args: tuple[bytes, bytes, int]) -> None:
self.std_out, self.std_err, self.status_code = args

def __repr__(self):
def __repr__(self) -> str:
# TODO put tree dots at the end if out/err was truncated
return '<Response code {0}, out "{1}", err "{2}">'.format(self.status_code, self.std_out[:20], self.std_err[:20])
return '<Response code {0}, out "{1!r}", err "{2!r}">'.format(self.status_code, self.std_out[:20], self.std_err[:20])


class Session(object):
# TODO implement context manager methods
def __init__(self, target, auth, **kwargs):
def __init__(self, target: str, auth: tuple[str, str], **kwargs: t.Any) -> None:
username, password = auth
self.url = self._build_url(target, kwargs.get("transport", "plaintext"))
self.protocol = Protocol(self.url, username=username, password=password, **kwargs)

def run_cmd(self, command, args=()):
def run_cmd(self, command: str, args: collections.abc.Iterable[str | bytes] = ()) -> Response:
# TODO optimize perf. Do not call open/close shell every time
shell_id = self.protocol.open_shell()
command_id = self.protocol.run_command(shell_id, command, args)
Expand All @@ -46,7 +48,7 @@ def run_cmd(self, command, args=()):
self.protocol.close_shell(shell_id)
return rs

def run_ps(self, script):
def run_ps(self, script: str) -> Response:
"""base64 encodes a Powershell script and executes the powershell
encoded script command
"""
Expand All @@ -59,7 +61,7 @@ def run_ps(self, script):
rs.std_err = self._clean_error_msg(rs.std_err)
return rs

def _clean_error_msg(self, msg):
def _clean_error_msg(self, msg: bytes) -> bytes:
"""converts a Powershell CLIXML message to a more human readable string"""
# TODO prepare unit test, beautify code
# if the msg does not start with this, return it as is
Expand All @@ -77,7 +79,8 @@ def _clean_error_msg(self, msg):
for s in nodes:
# append error msg string to result, also
# the hex chars represent CRLF so we replace with newline
new_msg += s.text.replace("_x000D__x000A_", "\n")
if s.text:
new_msg += s.text.replace("_x000D__x000A_", "\n")
except Exception as e:
# if any of the above fails, the msg was not true xml
# print a warning and return the original string
Expand All @@ -93,7 +96,7 @@ def _clean_error_msg(self, msg):
# just return the original message
return msg

def _strip_namespace(self, xml):
def _strip_namespace(self, xml: bytes) -> bytes:
"""strips any namespaces from an xml string"""
p = re.compile(b'xmlns=*[""][^""]*[""]')
allmatches = p.finditer(xml)
Expand All @@ -102,8 +105,11 @@ def _strip_namespace(self, xml):
return xml

@staticmethod
def _build_url(target, transport):
def _build_url(target: str, transport: str) -> str:
match = re.match(r"(?i)^((?P<scheme>http[s]?)://)?(?P<host>[0-9a-z-_.]+)(:(?P<port>\d+))?(?P<path>(/)?(wsman)?)?", target) # NOQA
if not match:
raise ValueError("Invalid target URL: {0}".format(target))

scheme = match.group("scheme")
if not scheme:
# TODO do we have anything other than HTTP/HTTPS
Expand Down
43 changes: 23 additions & 20 deletions winrm/encryption.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import re
import struct
from urllib.parse import urlsplit
Expand All @@ -12,7 +14,7 @@ class Encryption(object):
SIXTEN_KB = 16384
MIME_BOUNDARY = b"--Encrypted Boundary"

def __init__(self, session, protocol):
def __init__(self, session: requests.Session, protocol: str) -> None:
"""
[MS-WSMV] v30.0 2016-07-14
Expand Down Expand Up @@ -51,7 +53,7 @@ def __init__(self, session, protocol):
else:
raise WinRMError("Encryption for protocol '%s' not supported in pywinrm" % protocol)

def prepare_encrypted_request(self, session, endpoint, message):
def prepare_encrypted_request(self, session: requests.Session, endpoint: str | bytes, message: bytes) -> requests.PreparedRequest:
"""
Creates a prepared request to send to the server with an encrypted message
and correct headers
Expand All @@ -77,28 +79,29 @@ def prepare_encrypted_request(self, session, endpoint, message):

request = requests.Request("POST", endpoint, data=encrypted_message)
prepared_request = session.prepare_request(request)
prepared_request.headers["Content-Length"] = str(len(prepared_request.body))
prepared_request.headers["Content-Length"] = str(len(prepared_request.body)) if prepared_request.body else "0"
prepared_request.headers["Content-Type"] = '{0};protocol="{1}";boundary="Encrypted Boundary"'.format(content_type, self.protocol_string.decode())

return prepared_request

def parse_encrypted_response(self, response):
def parse_encrypted_response(self, response: requests.Response) -> bytes:
"""
Takes in the encrypted response from the server and decrypts it
:param response: The response that needs to be decrypted
:return: The unencrypted message from the server
"""
content_type = response.headers["Content-Type"]

if 'protocol="{0}"'.format(self.protocol_string.decode()) in content_type:
host = urlsplit(response.request.url).hostname
msg = self._decrypt_response(response, host)
else:
msg = response.text
msg = response.content

return msg

def _encrypt_message(self, message, host):
def _encrypt_message(self, message: bytes, host: str | bytes | None) -> bytes:
message_length = str(len(message)).encode()
encrypted_stream = self._build_message(message, host)

Expand All @@ -111,7 +114,7 @@ def _encrypt_message(self, message, host):

return message_payload

def _decrypt_response(self, response, host):
def _decrypt_response(self, response: requests.Response, host: str | bytes | None) -> bytes:
parts = response.content.split(self.MIME_BOUNDARY + b"\r\n")
parts = list(filter(None, parts)) # filter out empty parts of the split
message = b""
Expand Down Expand Up @@ -139,55 +142,55 @@ def _decrypt_response(self, response, host):

return message

def _decrypt_ntlm_message(self, encrypted_data, host):
def _decrypt_ntlm_message(self, encrypted_data: bytes, host: str | bytes | None) -> bytes:
signature_length = struct.unpack("<i", encrypted_data[:4])[0]
signature = encrypted_data[4 : signature_length + 4]
encrypted_message = encrypted_data[signature_length + 4 :]

message = self.session.auth.session_security.unwrap(encrypted_message, signature)
message = self.session.auth.session_security.unwrap(encrypted_message, signature) # type: ignore[union-attr]

return message

def _decrypt_credssp_message(self, encrypted_data, host):
def _decrypt_credssp_message(self, encrypted_data: bytes, host: str | bytes | None) -> bytes:
# trailer_length = struct.unpack("<i", encrypted_data[:4])[0]
encrypted_message = encrypted_data[4:]

credssp_context = self.session.auth.contexts[host]
credssp_context = self.session.auth.contexts[host] # type: ignore[union-attr]
message = credssp_context.unwrap(encrypted_message)

return message

def _decrypt_kerberos_message(self, encrypted_data, host):
def _decrypt_kerberos_message(self, encrypted_data: bytes, host: str | bytes | None) -> bytes:
signature_length = struct.unpack("<i", encrypted_data[:4])[0]
signature = encrypted_data[4 : signature_length + 4]
encrypted_message = encrypted_data[signature_length + 4 :]

message = self.session.auth.unwrap_winrm(host, encrypted_message, signature)
message = self.session.auth.unwrap_winrm(host, encrypted_message, signature) # type: ignore[union-attr]

return message

def _build_ntlm_message(self, message, host):
sealed_message, signature = self.session.auth.session_security.wrap(message)
def _build_ntlm_message(self, message: bytes, host: str | bytes | None) -> bytes:
sealed_message, signature = self.session.auth.session_security.wrap(message) # type: ignore[union-attr]
signature_length = struct.pack("<i", len(signature))

return signature_length + signature + sealed_message

def _build_credssp_message(self, message, host):
credssp_context = self.session.auth.contexts[host]
def _build_credssp_message(self, message: bytes, host: str | bytes | None) -> bytes:
credssp_context = self.session.auth.contexts[host] # type: ignore[union-attr]
sealed_message = credssp_context.wrap(message)

cipher_negotiated = credssp_context.tls_connection.get_cipher_name()
trailer_length = self._get_credssp_trailer_length(len(message), cipher_negotiated)

return struct.pack("<i", trailer_length) + sealed_message

def _build_kerberos_message(self, message, host):
sealed_message, signature = self.session.auth.wrap_winrm(host, message)
def _build_kerberos_message(self, message: bytes, host: str | bytes | None) -> bytes:
sealed_message, signature = self.session.auth.wrap_winrm(host, message) # type: ignore[union-attr]
signature_length = struct.pack("<i", len(signature))

return signature_length + signature + sealed_message

def _get_credssp_trailer_length(self, message_length, cipher_suite):
def _get_credssp_trailer_length(self, message_length: int, cipher_suite: str) -> int:
# I really don't like the way this works but can't find a better way, MS
# allows you to get this info through the struct SecPkgContext_StreamSizes
# but there is no GSSAPI/OpenSSL equivalent so we need to calculate it
Expand Down
10 changes: 5 additions & 5 deletions winrm/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,22 @@ class WinRMTransportError(Exception):
"""WinRM errors specific to transport-level problems (unexpected HTTP error codes, etc)"""

@property
def protocol(self):
def protocol(self) -> str:
return self.args[0]

@property
def code(self):
def code(self) -> int:
return self.args[1]

@property
def message(self):
def message(self) -> str:
return "Bad HTTP response returned from server. Code {0}".format(self.code)

@property
def response_text(self):
def response_text(self) -> str:
return self.args[2]

def __str__(self):
def __str__(self) -> str:
return self.message


Expand Down
Loading

0 comments on commit fbb05e8

Please sign in to comment.