Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructure pack writer hierarchy, add output_dir property #6

Merged
merged 3 commits into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ Source = "https://github.com/kmontag/alpax"
[tool.hatch.version]
path = "src/alpax/__about__.py"

# [tool.hatch.envs.hatch-test]
# extra-dependencies = [
# "pytest-asyncio~=0.23.7",
# ]
[tool.hatch.envs.hatch-test]
extra-dependencies = [
"pytest-asyncio~=0.23.7",
]

[tool.hatch.envs.coverage]
detached = true
Expand All @@ -49,6 +49,7 @@ combine = "coverage combine {args}"
report-xml = "coverage xml"

[tool.hatch.envs.types]
template = "hatch-test"
extra-dependencies = [
"mypy>=1.0.0",
]
Expand Down Expand Up @@ -99,6 +100,9 @@ extend-select = ["ARG", "B", "E", "I", "W"]
# line length in some cases.
extend-ignore = ["E501"]

[tool.ruff.lint.isort]
known-first-party = ["alpax"]

[tool.semantic_release]
assets = []
# Work around not being able to install packages preemptively in the
Expand Down
102 changes: 52 additions & 50 deletions src/alpax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import textwrap
from collections.abc import Collection, Sequence
from contextlib import AsyncExitStack
from typing import (
TYPE_CHECKING,
Generic,
Expand Down Expand Up @@ -45,7 +46,7 @@ def override(func):
# element is the tag and subtag values.
Tag: TypeAlias = tuple[str, Sequence[str]]

_Context = TypeVar("_Context")
_PackWriterAsyncType = TypeVar("_PackWriterAsyncType", bound="PackWriterAsync")


class PackProperties(TypedDict):
Expand All @@ -60,7 +61,7 @@ class PackProperties(TypedDict):
is_hidden_in_browse_groups: NotRequired[bool]


class PackWriterAsync(Generic[_Context]):
class PackWriterAsync:
def __init__(self, **k: Unpack[PackProperties]):
self._name: str = k["name"]
self._unique_id: str = k["unique_id"]
Expand All @@ -75,10 +76,7 @@ def __init__(self, **k: Unpack[PackProperties]):

self._is_hidden_in_browse_groups = k.get("is_hidden_in_browse_groups", False)

self.__context: _Context | None = None
# The `open()` method might return None, so we need a separate
# tracker to check the open status.
self.__has_context: bool = False
self.__exit_stack: AsyncExitStack | None = None

# Propagate unexpected keys up to `object`, so that errors
# will be thrown if appropriate.
Expand Down Expand Up @@ -108,38 +106,47 @@ async def commit(self) -> None:
raise NotImplementedError

# Open any resources necessary to start adding content, e.g. a
# temp directory to stage files.
async def open(self) -> _Context:
# temp directory to stage files. If any resources need to be
# cleaned up after all content has been added/committed, add them
# to the exit stack.
async def _create_context(self, exit_stack: AsyncExitStack) -> None:
raise NotImplementedError

# Close any resources opened by `_open`.
async def close(self, context: _Context) -> None:
raise NotImplementedError
async def open(self) -> None:
if self.__exit_stack is not None:
msg = f"{self} is already open"
raise RuntimeError(msg)

self.__exit_stack = AsyncExitStack()
await self._create_context(self.__exit_stack)

async def close(self) -> None:
exit_stack = self.__exit_stack
if exit_stack is None:
msg = f"{self} is not open"
raise RuntimeError(msg)
self.__exit_stack = None
await exit_stack.aclose()

# Allow usage like:
#
# async with await PackWriter(**args) as p:
# p.set_file(...)
# p.set_preview(...)
# async with await PackWriterAsync(**args) as p:
# await p.set_file(...)
# await p.set_preview(...)
#
# Which is equivalent to:
#
# p = PackWriter(**args)
# context = p.open()
# p = PackWriterAsync(**args)
# await p.open()
# try:
# p.set_file(...)
# p.set_preview(...)
# p.commit()
# await p.set_file(...)
# await p.set_preview(...)
# await p.commit()
# finally:
# p.close(context)
# await p.close(context)
#
async def __aenter__(self) -> Self:
if self.__has_context:
msg = f"{self} is already open"
raise ValueError(msg)

self.__context = await self.open()
self.__has_context = True
await self.open()
return self

async def __aexit__(
Expand All @@ -148,26 +155,17 @@ async def __aexit__(
exc_inst: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if not self.__has_context:
msg = f"{self} is not open"
raise ValueError(msg)

try:
if exc_type is None:
await self.commit()
finally:
await self.close(
# The context type is allowed to be `None`, so we
# can't just assert that this is present.
self.__context, # type: ignore
)
self.__has_context = False
await self.close()


# For synchronous writes, just wrap an async writer.
class PackWriter(Generic[_Context]):
def __init__(self, pack_writer_async: PackWriterAsync[_Context]) -> None:
self._pack_writer_async = pack_writer_async
class PackWriter(Generic[_PackWriterAsyncType]):
def __init__(self, pack_writer_async: _PackWriterAsyncType) -> None:
self._pack_writer_async: _PackWriterAsyncType = pack_writer_async

def set_file(self, path: str, file: str) -> None:
asyncio.run(self._pack_writer_async.set_file(path, file))
Expand All @@ -187,11 +185,11 @@ def set_preview_content(self, path: str, ogg_content: bytes) -> None:
def commit(self) -> None:
asyncio.run(self._pack_writer_async.commit())

def open(self) -> _Context:
return asyncio.run(self._pack_writer_async.open())
def open(self) -> None:
asyncio.run(self._pack_writer_async.open())

def close(self, context: _Context) -> None:
asyncio.run(self._pack_writer_async.close(context))
def close(self) -> None:
asyncio.run(self._pack_writer_async.close())

def __enter__(self) -> Self:
asyncio.run(self._pack_writer_async.__aenter__())
Expand All @@ -206,7 +204,7 @@ def __exit__(
asyncio.run(self._pack_writer_async.__aexit__(exc_type, exc_val, exc_tb))


class DirectoryPackWriterAsync(PackWriterAsync[None]):
class DirectoryPackWriterAsync(PackWriterAsync):
def __init__(self, output_dir: str | os.PathLike, **k: Unpack[PackProperties]):
super().__init__(**k)

Expand All @@ -221,6 +219,10 @@ def __init__(self, output_dir: str | os.PathLike, **k: Unpack[PackProperties]):
# Keys are paths within the pack.
self._tags: dict[str, Collection[Tag]] = {}

@property
def output_dir(self) -> str | os.PathLike:
return self._output_dir

@override
async def set_file(self, path: str, file: str) -> None:
await self._copy_to_path(path, file)
Expand All @@ -242,11 +244,7 @@ async def set_preview_content(self, path: str, ogg_content: bytes) -> None:
await self._write_to_path(self._preview_path(path), ogg_content)

@override
async def open(self) -> None:
return None

@override
async def close(self, context: None) -> None:
async def _create_context(self, exit_stack: AsyncExitStack) -> None:
pass

@override
Expand Down Expand Up @@ -357,7 +355,11 @@ def do_write_file(absolute_path: str, content: bytes) -> None:
await asyncio.to_thread(do_write_file, absolute_path, content)


class DirectoryPackWriter(PackWriter):
class DirectoryPackWriter(PackWriter[DirectoryPackWriterAsync]):
def __init__(self, output_dir: str | os.PathLike, **k: Unpack[PackProperties]):
pack_writer_async = DirectoryPackWriterAsync(output_dir, **k)
super().__init__(pack_writer_async)

@property
def output_dir(self) -> str | os.PathLike:
return self._pack_writer_async.output_dir
29 changes: 21 additions & 8 deletions tests/test_alpax.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
import os
import tempfile

import pytest

import alpax


Expand Down Expand Up @@ -51,15 +52,27 @@ def test_directory() -> None:
assert "<rdf:li>Tag Name|Tag Value|Subtag Value</rdf:li>" in xmp_content


def test_simple_directory_async() -> None:
@pytest.mark.asyncio
async def test_simple_directory_async() -> None:
with tempfile.TemporaryDirectory() as output_dir:
async with alpax.DirectoryPackWriterAsync(output_dir, name="Test", unique_id="test.id") as pack_writer:
# Simple test, just make sure the write can happen
# without errors.
await pack_writer.set_file_content("path.adg", b"content")

with open(os.path.join(output_dir, "path.adg")) as f: # noqa: ASYNC101
assert f.read() == "content"


async def run() -> None:
@pytest.mark.asyncio
async def test_exceptions_propagated() -> None:
with tempfile.TemporaryDirectory() as output_dir:
exception_msg = "test exception"

async def raise_exc() -> None:
async with alpax.DirectoryPackWriterAsync(output_dir, name="Test", unique_id="test.id") as pack_writer:
# Simple test, just make sure the write can happen
# without errors.
await pack_writer.set_file_content("path.adg", b"content")
raise RuntimeError(exception_msg)

asyncio.run(run())
with open(os.path.join(output_dir, "path.adg")) as f:
assert f.read() == "content"
with pytest.raises(RuntimeError, match=exception_msg):
await raise_exc()
Loading