diff --git a/src/mercury_engine_data_structures/cli.py b/src/mercury_engine_data_structures/cli.py index 229aefb..299e463 100644 --- a/src/mercury_engine_data_structures/cli.py +++ b/src/mercury_engine_data_structures/cli.py @@ -11,6 +11,7 @@ from mercury_engine_data_structures.construct_extensions.json import convert_to_raw_python from mercury_engine_data_structures.file_tree_editor import FileTreeEditor, OutputFormat from mercury_engine_data_structures.game_check import Game +from mercury_engine_data_structures.romfs import ExtractedRomFs def game_argument_type(s: str) -> Game: @@ -134,7 +135,7 @@ def do_decode_from_pkg(args): root: Path = args.root asset_name: str = args.asset_name - pkg_editor = FileTreeEditor(root, args.game) + pkg_editor = FileTreeEditor(ExtractedRomFs(root), args.game) asset = pkg_editor.get_parsed_asset(asset_name) print(asset.raw) @@ -173,7 +174,7 @@ def find_pkg_for(args): asset_id: int = args.asset_id asset_name: str = args.asset_name - pkg_editor = FileTreeEditor(root, args.game) + pkg_editor = FileTreeEditor(ExtractedRomFs(root), args.game) if asset_id is not None: items = list(pkg_editor.find_pkgs(asset_id)) else: @@ -251,7 +252,7 @@ def extract_files(args: argparse.Namespace) -> None: root: Path = args.root output_root: Path = args.output - pkg_editor = FileTreeEditor(root, args.game) + pkg_editor = FileTreeEditor(ExtractedRomFs(root), args.game) output_root.mkdir(parents=True, exist_ok=True) for file_name in pkg_editor.all_asset_names(): @@ -267,7 +268,7 @@ def replace_files(args: argparse.Namespace) -> None: new_files: Path = args.new_files output: Path = args.output - pkg_editor = FileTreeEditor(root, args.game) + pkg_editor = FileTreeEditor(ExtractedRomFs(root), args.game) for file_name in new_files.rglob("*"): if file_name.is_file(): diff --git a/src/mercury_engine_data_structures/file_tree_editor.py b/src/mercury_engine_data_structures/file_tree_editor.py index 466c7b0..0f646c8 100644 --- a/src/mercury_engine_data_structures/file_tree_editor.py +++ b/src/mercury_engine_data_structures/file_tree_editor.py @@ -14,6 +14,7 @@ from mercury_engine_data_structures.formats.base_resource import AssetId, BaseResource, NameOrAssetId, resolve_asset_id from mercury_engine_data_structures.formats.pkg import Pkg from mercury_engine_data_structures.game_check import Game +from mercury_engine_data_structures.romfs import RomFs _T = typing.TypeVar("_T", bound=BaseResource) logger = logging.getLogger(__name__) @@ -31,12 +32,6 @@ class OutputFormat(enum.Enum): ROMFS = enum.auto() -def _read_file_with_entry(path: Path, entry): - with path.open("rb") as f: - f.seek(entry.start_offset) - return f.read(entry.end_offset - entry.start_offset) - - def _write_to_path(output: Path, data: bytes): output.parent.mkdir(parents=True, exist_ok=True) output.write_bytes(data) @@ -67,17 +62,14 @@ class FileTreeEditor: _in_memory_pkgs: dict[str, Pkg] _toc: Toc - def __init__(self, root: Path, target_game: Game): - self.root = root + def __init__(self, romfs: RomFs, target_game: Game): + self.romfs = romfs self.target_game = target_game self._modified_resources = {} self._in_memory_pkgs = {} self._update_headers() - def path_for_pkg(self, pkg_name: str) -> Path: - return self.root.joinpath(pkg_name) - def _add_pkg_name_for_asset_id(self, asset_id: AssetId, pkg_name: str | None): self._files_for_asset_id[asset_id] = self._files_for_asset_id.get(asset_id, set()) self._files_for_asset_id[asset_id].add(pkg_name) @@ -89,18 +81,13 @@ def _update_headers(self): self._files_for_asset_id = {} self._name_for_asset_id = copy.copy(_all_asset_id_for_game(self.target_game)) - self._toc = Toc.parse(self.root.joinpath(Toc.system_files_name()).read_bytes(), target_game=self.target_game) - custom_names = self.root.joinpath("custom_names.json") - if custom_names.is_file(): - with custom_names.open() as f: - self._name_for_asset_id.update({asset_id: name for name, asset_id in json.load(f).items()}) + self._toc = Toc.parse(self.romfs.get_file(Toc.system_files_name()), target_game=self.target_game) - for f in self.root.rglob("*.*"): - name = f.relative_to(self.root).as_posix() + for name in self.romfs.all_files(): asset_id = resolve_asset_id(name, self.target_game) self._name_for_asset_id[asset_id] = name - if f.suffix == ".pkg": + if name.endswith(".pkg"): self.all_pkgs.append(name) elif self._toc.get_size_for(asset_id) is None: @@ -111,7 +98,7 @@ def _update_headers(self): self._add_pkg_name_for_asset_id(asset_id, None) for name in self.all_pkgs: - with self.path_for_pkg(name).open("rb") as f: + with self.romfs.get_pkg_stream(name) as f: self.headers[name] = Pkg.header_class(self.target_game).parse_stream(f, target_game=self.target_game) self._ensured_asset_ids[name] = set() @@ -181,11 +168,11 @@ def get_raw_asset(self, asset_id: NameOrAssetId, *, in_pkg: str | None = None) - entry = header.entries_by_id.get(asset_id) if entry is not None: logger.info("Reading asset %s from pkg %s", str(original_name), name) - return _read_file_with_entry(self.path_for_pkg(name), entry) + return self.romfs.read_file_with_entry(name, entry) if in_pkg is None and asset_id in self._name_for_asset_id: name = self._name_for_asset_id[asset_id] - return self.root.joinpath(name).read_bytes() + return self.romfs.get_file(name) raise ValueError(f"Unknown asset_id: {original_name}") @@ -294,7 +281,7 @@ def get_pkg(self, pkg_name: str) -> Pkg: if pkg_name not in self._in_memory_pkgs: logger.info("Reading %s", pkg_name) - with self.path_for_pkg(pkg_name).open("rb") as f: + with self.romfs.get_pkg_stream(pkg_name) as f: self._in_memory_pkgs[pkg_name] = Pkg.parse_stream(f, target_game=self.target_game) return self._in_memory_pkgs[pkg_name] @@ -404,18 +391,6 @@ def save_modifications(self, output_path: Path, output_format: OutputFormat, *, with out_pkg_path.open("wb") as f: pkg.build_stream(f) - custom_names = output_path.joinpath("custom_names.json") - with custom_names.open("w") as f: - json.dump( - { - name: asset_id - for asset_id, name in self._name_for_asset_id.items() - if asset_id not in _all_asset_id_for_game(self.target_game) - }, - f, - indent=4, - ) - self._modified_resources = {} if finalize_editor: # _update_headers has significant runtime costs, so avoid it. diff --git a/src/mercury_engine_data_structures/romfs.py b/src/mercury_engine_data_structures/romfs.py new file mode 100644 index 0000000..d38a68b --- /dev/null +++ b/src/mercury_engine_data_structures/romfs.py @@ -0,0 +1,64 @@ +import io +from abc import ABC, abstractmethod +from collections.abc import Iterator +from contextlib import contextmanager +from pathlib import Path + + +class RomFs(ABC): + @contextmanager + @abstractmethod + def get_pkg_stream(self, file_path: str) -> Iterator[io.BufferedIOBase]: + """Returns a package file stream which should be used in a "with" context. + + :param file_path: File path to the pkg file + """ + pass + + @abstractmethod + def read_file_with_entry(self, file_path: str, entry) -> bytes: + """Reads and returns a file within a pkg file. + + :param file_path: File path to the pkg file + :param entry: An entry object containing the end_offset and start_offset within the pkg + """ + pass + + @abstractmethod + def get_file(self, file_path: str) -> bytes: + """Reads and returns a file. + + :param file_path: Path to the file + """ + pass + + @abstractmethod + def all_files(self) -> Iterator[str]: + """Returns an Iterator for all files within the RomFS""" + pass + + +class ExtractedRomFs(RomFs): + def __init__(self, root: Path): + self.root = root + + @contextmanager + def get_pkg_stream(self, file_path: str) -> Iterator[io.BufferedReader]: + file_stream = self.root.joinpath(file_path).open("rb") + try: + yield file_stream + finally: + file_stream.close() + + def read_file_with_entry(self, file_path: str, entry) -> bytes: + with self.root.joinpath(file_path).open("rb") as f: + f.seek(entry.start_offset) + return f.read(entry.end_offset - entry.start_offset) + + def get_file(self, file_path: str) -> bytes: + return self.root.joinpath(file_path).read_bytes() + + def all_files(self) -> Iterator[str]: + for f in self.root.rglob("*.*"): + name = f.relative_to(self.root).as_posix() + yield name diff --git a/tests/conftest.py b/tests/conftest.py index 8df0f71..e201491 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ from mercury_engine_data_structures.file_tree_editor import FileTreeEditor from mercury_engine_data_structures.game_check import Game +from mercury_engine_data_structures.romfs import ExtractedRomFs _FAIL_INSTEAD_OF_SKIP = False @@ -35,17 +36,17 @@ def dread_path_210(): @pytest.fixture(scope="session") def samus_returns_tree(samus_returns_path): - return FileTreeEditor(samus_returns_path, Game.SAMUS_RETURNS) + return FileTreeEditor(ExtractedRomFs(samus_returns_path), Game.SAMUS_RETURNS) @pytest.fixture(scope="session") def dread_tree_100(dread_path_100): - return FileTreeEditor(dread_path_100, Game.DREAD) + return FileTreeEditor(ExtractedRomFs(dread_path_100), Game.DREAD) @pytest.fixture(scope="session") def dread_tree_210(dread_path_210): - return FileTreeEditor(dread_path_210, Game.DREAD) + return FileTreeEditor(ExtractedRomFs(dread_path_210), Game.DREAD) def pytest_addoption(parser): diff --git a/tools/plot_simple_map.py b/tools/plot_simple_map.py index 4d6f9c6..e3932e7 100644 --- a/tools/plot_simple_map.py +++ b/tools/plot_simple_map.py @@ -14,6 +14,7 @@ from mercury_engine_data_structures.file_tree_editor import FileTreeEditor from mercury_engine_data_structures.formats import Bmscc, Brfld, Brsa from mercury_engine_data_structures.game_check import Game +from mercury_engine_data_structures.romfs import ExtractedRomFs world_names = { "maps/levels/c10_samus/s010_cave/s010_cave.brfld": "Artaria", @@ -3276,7 +3277,7 @@ def decode_world( # noqa: C901 all_names = dread_data.all_asset_id_to_name() game = Game.DREAD - pkg_editor = FileTreeEditor(root, Game.DREAD) + pkg_editor = FileTreeEditor(ExtractedRomFs(root), Game.DREAD) for asset_id, name in all_names.items(): if target_level not in name: diff --git a/tools/sr_export_rdv_database.py b/tools/sr_export_rdv_database.py index 26abcc3..57c63d6 100644 --- a/tools/sr_export_rdv_database.py +++ b/tools/sr_export_rdv_database.py @@ -17,6 +17,7 @@ from mercury_engine_data_structures.file_tree_editor import FileTreeEditor from mercury_engine_data_structures.formats import Bmscc, Bmsld from mercury_engine_data_structures.game_check import Game +from mercury_engine_data_structures.romfs import ExtractedRomFs world_names = { "maps/levels/c10_samus/s000_surface/s000_surface.bmsld": "Surface - East", @@ -549,7 +550,7 @@ def decode_world( all_names = samus_returns_data.all_asset_id_to_name() game = Game.SAMUS_RETURNS - pkg_editor = FileTreeEditor(root, target_game=game) + pkg_editor = FileTreeEditor(ExtractedRomFs(root), target_game=game) for asset_id, name in all_names.items(): if target_level not in name: