From 49bc358bf948ba9c94703dad987b818474d56e71 Mon Sep 17 00:00:00 2001 From: JP-Ellis Date: Thu, 26 Oct 2023 14:40:00 +1100 Subject: [PATCH] feat(ffi): add OwnedString class A number of FFI functions return strings that are owned by the library and must be freed manually. Unfortunately, returning a `str` from a function would result in a memory leak, as the Python runtime would not know to free the string. This commit adds an `OwnedString` class that wraps a `str` and a function that frees the string. The `__del__` method of the class calls the free function, ensuring that the string is freed when the object is garbage collected. Signed-off-by: JP-Ellis --- pact/v3/ffi.py | 112 +++++++++++++++++++++++++++++++++---------- tests/v3/test_ffi.py | 16 +++++++ 2 files changed, 104 insertions(+), 24 deletions(-) diff --git a/pact/v3/ffi.py b/pact/v3/ffi.py index c2cf575998..fd190ee1b4 100644 --- a/pact/v3/ffi.py +++ b/pact/v3/ffi.py @@ -92,6 +92,7 @@ if TYPE_CHECKING: import cffi from pathlib import Path + from typing_extensions import Self # The follow types are classes defined in the Rust code. Ultimately, a Python # alternative should be implemented, but for now, the follow lines only serve @@ -613,6 +614,75 @@ def raise_exception(self) -> None: raise RuntimeError(self.text) +class OwnedString(str): + """ + A string that owns its own memory. + + This is used to ensure that the memory is freed when the string is + destroyed. + + As this is subclassed from `str`, it can be used in place of a normal string + in most cases. + """ + + def __new__(cls, ptr: cffi.FFI.CData) -> Self: + """ + Create a new Owned String. + + As this is a subclass of the immutable type `str`, we need to override + the `__new__` method to ensure that the string is initialised correctly. + """ + s = ffi.string(ptr) + return super().__new__(cls, s if isinstance(s, str) else s.decode("utf-8")) + + def __init__(self, ptr: cffi.FFI.CData) -> None: + """ + Initialise a new Owned String. + + Args: + ptr: + CFFI data structure. + """ + self._ptr = ptr + s = ffi.string(ptr) + self._string = s if isinstance(s, str) else s.decode("utf-8") + + def __str__(self) -> str: + """ + String representation of the Owned String. + """ + return self._string + + def __repr__(self) -> str: + """ + Debugging string representation of the Owned String. + """ + return f"" + + def __del__(self) -> None: + """ + Destructor for the Owned String. + """ + string_delete(self) + + def __eq__(self, other: object) -> bool: + """ + Equality comparison. + + Args: + other: + The object to compare to. + + Returns: + Whether the two objects are equal. + """ + if isinstance(other, OwnedString): + return self._ptr == other._ptr + if isinstance(other, str): + return self._string == other + return super().__eq__(other) + + def version() -> str: """ Return the version of the pact_ffi library. @@ -3000,7 +3070,7 @@ def message_delete(message: Message) -> None: raise NotImplementedError -def message_get_contents(message: Message) -> str: +def message_get_contents(message: Message) -> OwnedString | None: """ Get the contents of a `Message` in string form. @@ -3112,7 +3182,7 @@ def message_set_contents_bin( raise NotImplementedError -def message_get_description(message: Message) -> str: +def message_get_description(message: Message) -> OwnedString: r""" Get a copy of the description. @@ -4196,20 +4266,14 @@ def sync_message_get_provider_state_iter( raise NotImplementedError -def string_delete(string: str) -> None: +def string_delete(string: OwnedString) -> None: """ Delete a string previously returned by this FFI. [Rust `pactffi_string_delete`](https://docs.rs/pact_ffi/0.4.9/pact_ffi/?search=pactffi_string_delete) - - It is explicitly allowed to pass a null pointer to this function; in that - case the function will do nothing. - - # Safety Passing an invalid pointer, or one that was not returned by a FFI - function can result in undefined behaviour. """ - raise NotImplementedError + lib.pactffi_string_delete(string._ptr) def create_mock_server(pact_str: str, addr_str: str, *, tls: bool) -> int: @@ -4253,7 +4317,7 @@ def create_mock_server(pact_str: str, addr_str: str, *, tls: bool) -> int: raise NotImplementedError -def get_tls_ca_certificate() -> str: +def get_tls_ca_certificate() -> OwnedString: """ Fetch the CA Certificate used to generate the self-signed certificate. @@ -4267,7 +4331,7 @@ def get_tls_ca_certificate() -> str: An empty string indicates an error reading the pem file. """ - raise NotImplementedError + return OwnedString(lib.pactffi_get_tls_ca_certificate()) def create_mock_server_for_pact(pact: PactHandle, addr_str: str, *, tls: bool) -> int: @@ -5624,7 +5688,7 @@ def message_with_metadata(message_handle: MessageHandle, key: str, value: str) - raise NotImplementedError -def message_reify(message_handle: MessageHandle) -> str: +def message_reify(message_handle: MessageHandle) -> OwnedString: """ Reifies the given message. @@ -6320,7 +6384,7 @@ def verifier_cli_args() -> str: raise NotImplementedError -def verifier_logs(handle: VerifierHandle) -> str: +def verifier_logs(handle: VerifierHandle) -> OwnedString: """ Extracts the logs for the verification run. @@ -6337,7 +6401,7 @@ def verifier_logs(handle: VerifierHandle) -> str: raise NotImplementedError -def verifier_logs_for_provider(provider_name: str) -> str: +def verifier_logs_for_provider(provider_name: str) -> OwnedString: """ Extracts the logs for the verification run for the provider name. @@ -6354,7 +6418,7 @@ def verifier_logs_for_provider(provider_name: str) -> str: raise NotImplementedError -def verifier_output(handle: VerifierHandle, strip_ansi: int) -> str: +def verifier_output(handle: VerifierHandle, strip_ansi: int) -> OwnedString: """ Extracts the standard output for the verification run. @@ -6373,7 +6437,7 @@ def verifier_output(handle: VerifierHandle, strip_ansi: int) -> str: raise NotImplementedError -def verifier_json(handle: VerifierHandle) -> str: +def verifier_json(handle: VerifierHandle) -> OwnedString: """ Extracts the verification result as a JSON document. @@ -6498,7 +6562,7 @@ def matches_string_value( expected_value: str, actual_value: str, cascaded: int, -) -> str: +) -> OwnedString: """ Determines if the string value matches the given matching rule. @@ -6529,7 +6593,7 @@ def matches_u64_value( expected_value: int, actual_value: int, cascaded: int, -) -> str: +) -> OwnedString: """ Determines if the unsigned integer value matches the given matching rule. @@ -6559,7 +6623,7 @@ def matches_i64_value( expected_value: int, actual_value: int, cascaded: int, -) -> str: +) -> OwnedString: """ Determines if the signed integer value matches the given matching rule. @@ -6589,7 +6653,7 @@ def matches_f64_value( expected_value: float, actual_value: float, cascaded: int, -) -> str: +) -> OwnedString: """ Determines if the floating point value matches the given matching rule. @@ -6619,7 +6683,7 @@ def matches_bool_value( expected_value: int, actual_value: int, cascaded: int, -) -> str: +) -> OwnedString: """ Determines if the boolean value matches the given matching rule. @@ -6651,7 +6715,7 @@ def matches_binary_value( # noqa: PLR0913 actual_value: str, actual_value_len: int, cascaded: int, -) -> str: +) -> OwnedString: """ Determines if the binary value matches the given matching rule. @@ -6686,7 +6750,7 @@ def matches_json_value( expected_value: str, actual_value: str, cascaded: int, -) -> str: +) -> OwnedString: """ Determines if the JSON value matches the given matching rule. diff --git a/tests/v3/test_ffi.py b/tests/v3/test_ffi.py index f899144db2..912411c53d 100644 --- a/tests/v3/test_ffi.py +++ b/tests/v3/test_ffi.py @@ -51,3 +51,19 @@ def test_get_error_message() -> None: ret: int = ffi.lib.pactffi_validate_datetime(invalid_utf8, invalid_utf8) assert ret == 2 assert ffi.get_error_message() == "error parsing value as UTF-8" + + +def test_owned_string() -> None: + string = ffi.get_tls_ca_certificate() + assert isinstance(string, str) + assert len(string) > 0 + assert str(string) == string + assert repr(string).startswith("") + assert string.startswith("-----BEGIN CERTIFICATE-----") + assert string.endswith( + ( + "-----END CERTIFICATE-----\n", + "-----END CERTIFICATE-----\r\n", + ) + )