Skip to content

Commit

Permalink
feat(ffi): add OwnedString class
Browse files Browse the repository at this point in the history
A number of FFI functions return strings that are owned by the library
and must be freed manually. Unfortunately, returning a `str` from a
function would result in a memory leak, as the Python runtime would not
know to free the string.

This commit adds an `OwnedString` class that wraps a `str` and a
function that frees the string. The `__del__` method of the class calls
the free function, ensuring that the string is freed when the object is
garbage collected.

Signed-off-by: JP-Ellis <[email protected]>
  • Loading branch information
JP-Ellis committed Oct 26, 2023
1 parent db6374a commit 49bc358
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 24 deletions.
112 changes: 88 additions & 24 deletions pact/v3/ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
if TYPE_CHECKING:
import cffi
from pathlib import Path
from typing_extensions import Self

# The follow types are classes defined in the Rust code. Ultimately, a Python
# alternative should be implemented, but for now, the follow lines only serve
Expand Down Expand Up @@ -613,6 +614,75 @@ def raise_exception(self) -> None:
raise RuntimeError(self.text)


class OwnedString(str):
"""
A string that owns its own memory.
This is used to ensure that the memory is freed when the string is
destroyed.
As this is subclassed from `str`, it can be used in place of a normal string
in most cases.
"""

def __new__(cls, ptr: cffi.FFI.CData) -> Self:
"""
Create a new Owned String.
As this is a subclass of the immutable type `str`, we need to override
the `__new__` method to ensure that the string is initialised correctly.
"""
s = ffi.string(ptr)
return super().__new__(cls, s if isinstance(s, str) else s.decode("utf-8"))

def __init__(self, ptr: cffi.FFI.CData) -> None:
"""
Initialise a new Owned String.
Args:
ptr:
CFFI data structure.
"""
self._ptr = ptr
s = ffi.string(ptr)
self._string = s if isinstance(s, str) else s.decode("utf-8")

def __str__(self) -> str:
"""
String representation of the Owned String.
"""
return self._string

def __repr__(self) -> str:
"""
Debugging string representation of the Owned String.
"""
return f"<OwnedString: {self._string!r}, ptr={self._ptr!r}>"

def __del__(self) -> None:
"""
Destructor for the Owned String.
"""
string_delete(self)

def __eq__(self, other: object) -> bool:
"""
Equality comparison.
Args:
other:
The object to compare to.
Returns:
Whether the two objects are equal.
"""
if isinstance(other, OwnedString):
return self._ptr == other._ptr
if isinstance(other, str):
return self._string == other
return super().__eq__(other)


def version() -> str:
"""
Return the version of the pact_ffi library.
Expand Down Expand Up @@ -3000,7 +3070,7 @@ def message_delete(message: Message) -> None:
raise NotImplementedError


def message_get_contents(message: Message) -> str:
def message_get_contents(message: Message) -> OwnedString | None:
"""
Get the contents of a `Message` in string form.
Expand Down Expand Up @@ -3112,7 +3182,7 @@ def message_set_contents_bin(
raise NotImplementedError


def message_get_description(message: Message) -> str:
def message_get_description(message: Message) -> OwnedString:
r"""
Get a copy of the description.
Expand Down Expand Up @@ -4196,20 +4266,14 @@ def sync_message_get_provider_state_iter(
raise NotImplementedError


def string_delete(string: str) -> None:
def string_delete(string: OwnedString) -> None:
"""
Delete a string previously returned by this FFI.
[Rust
`pactffi_string_delete`](https://docs.rs/pact_ffi/0.4.9/pact_ffi/?search=pactffi_string_delete)
It is explicitly allowed to pass a null pointer to this function; in that
case the function will do nothing.
# Safety Passing an invalid pointer, or one that was not returned by a FFI
function can result in undefined behaviour.
"""
raise NotImplementedError
lib.pactffi_string_delete(string._ptr)


def create_mock_server(pact_str: str, addr_str: str, *, tls: bool) -> int:
Expand Down Expand Up @@ -4253,7 +4317,7 @@ def create_mock_server(pact_str: str, addr_str: str, *, tls: bool) -> int:
raise NotImplementedError


def get_tls_ca_certificate() -> str:
def get_tls_ca_certificate() -> OwnedString:
"""
Fetch the CA Certificate used to generate the self-signed certificate.
Expand All @@ -4267,7 +4331,7 @@ def get_tls_ca_certificate() -> str:
An empty string indicates an error reading the pem file.
"""
raise NotImplementedError
return OwnedString(lib.pactffi_get_tls_ca_certificate())


def create_mock_server_for_pact(pact: PactHandle, addr_str: str, *, tls: bool) -> int:
Expand Down Expand Up @@ -5624,7 +5688,7 @@ def message_with_metadata(message_handle: MessageHandle, key: str, value: str) -
raise NotImplementedError


def message_reify(message_handle: MessageHandle) -> str:
def message_reify(message_handle: MessageHandle) -> OwnedString:
"""
Reifies the given message.
Expand Down Expand Up @@ -6320,7 +6384,7 @@ def verifier_cli_args() -> str:
raise NotImplementedError


def verifier_logs(handle: VerifierHandle) -> str:
def verifier_logs(handle: VerifierHandle) -> OwnedString:
"""
Extracts the logs for the verification run.
Expand All @@ -6337,7 +6401,7 @@ def verifier_logs(handle: VerifierHandle) -> str:
raise NotImplementedError


def verifier_logs_for_provider(provider_name: str) -> str:
def verifier_logs_for_provider(provider_name: str) -> OwnedString:
"""
Extracts the logs for the verification run for the provider name.
Expand All @@ -6354,7 +6418,7 @@ def verifier_logs_for_provider(provider_name: str) -> str:
raise NotImplementedError


def verifier_output(handle: VerifierHandle, strip_ansi: int) -> str:
def verifier_output(handle: VerifierHandle, strip_ansi: int) -> OwnedString:
"""
Extracts the standard output for the verification run.
Expand All @@ -6373,7 +6437,7 @@ def verifier_output(handle: VerifierHandle, strip_ansi: int) -> str:
raise NotImplementedError


def verifier_json(handle: VerifierHandle) -> str:
def verifier_json(handle: VerifierHandle) -> OwnedString:
"""
Extracts the verification result as a JSON document.
Expand Down Expand Up @@ -6498,7 +6562,7 @@ def matches_string_value(
expected_value: str,
actual_value: str,
cascaded: int,
) -> str:
) -> OwnedString:
"""
Determines if the string value matches the given matching rule.
Expand Down Expand Up @@ -6529,7 +6593,7 @@ def matches_u64_value(
expected_value: int,
actual_value: int,
cascaded: int,
) -> str:
) -> OwnedString:
"""
Determines if the unsigned integer value matches the given matching rule.
Expand Down Expand Up @@ -6559,7 +6623,7 @@ def matches_i64_value(
expected_value: int,
actual_value: int,
cascaded: int,
) -> str:
) -> OwnedString:
"""
Determines if the signed integer value matches the given matching rule.
Expand Down Expand Up @@ -6589,7 +6653,7 @@ def matches_f64_value(
expected_value: float,
actual_value: float,
cascaded: int,
) -> str:
) -> OwnedString:
"""
Determines if the floating point value matches the given matching rule.
Expand Down Expand Up @@ -6619,7 +6683,7 @@ def matches_bool_value(
expected_value: int,
actual_value: int,
cascaded: int,
) -> str:
) -> OwnedString:
"""
Determines if the boolean value matches the given matching rule.
Expand Down Expand Up @@ -6651,7 +6715,7 @@ def matches_binary_value( # noqa: PLR0913
actual_value: str,
actual_value_len: int,
cascaded: int,
) -> str:
) -> OwnedString:
"""
Determines if the binary value matches the given matching rule.
Expand Down Expand Up @@ -6686,7 +6750,7 @@ def matches_json_value(
expected_value: str,
actual_value: str,
cascaded: int,
) -> str:
) -> OwnedString:
"""
Determines if the JSON value matches the given matching rule.
Expand Down
16 changes: 16 additions & 0 deletions tests/v3/test_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,19 @@ def test_get_error_message() -> None:
ret: int = ffi.lib.pactffi_validate_datetime(invalid_utf8, invalid_utf8)
assert ret == 2
assert ffi.get_error_message() == "error parsing value as UTF-8"


def test_owned_string() -> None:
string = ffi.get_tls_ca_certificate()
assert isinstance(string, str)
assert len(string) > 0
assert str(string) == string
assert repr(string).startswith("<OwnedString: ")
assert repr(string).endswith(">")
assert string.startswith("-----BEGIN CERTIFICATE-----")
assert string.endswith(
(
"-----END CERTIFICATE-----\n",
"-----END CERTIFICATE-----\r\n",
)
)

0 comments on commit 49bc358

Please sign in to comment.