From 8837048bcc38f7434887140870748b92db0f5386 Mon Sep 17 00:00:00 2001 From: Michael Schmidt Date: Mon, 1 Jul 2024 16:48:15 +0200 Subject: [PATCH] Slightly improved generator API (#2973) --- backend/src/api/api.py | 24 ++++++++----------- .../ncnn/batch_processing/load_models.py | 5 +++- .../onnx/batch_processing/load_models.py | 5 +++- .../pytorch/iteration/load_models.py | 5 +++- .../image/batch_processing/load_images.py | 5 +++- backend/src/process.py | 4 +++- 6 files changed, 29 insertions(+), 19 deletions(-) diff --git a/backend/src/api/api.py b/backend/src/api/api.py index e46df90be..a6cce94d2 100644 --- a/backend/src/api/api.py +++ b/backend/src/api/api.py @@ -2,7 +2,6 @@ import importlib import os -import typing from dataclasses import asdict, dataclass, field from typing import ( Any, @@ -513,22 +512,21 @@ def add_package( @dataclass class Generator(Generic[I]): - supplier: Callable[[], typing.Iterator[I | Exception]] + supplier: Callable[[], Iterable[I | Exception]] expected_length: int fail_fast: bool = True + def with_fail_fast(self, fail_fast: bool) -> Generator[I]: + return Generator(self.supplier, self.expected_length, fail_fast=fail_fast) + @staticmethod def from_iter( - supplier: Callable[[], typing.Iterator[I | Exception]], - expected_length: int, - fail_fast: bool = True, + supplier: Callable[[], Iterable[I | Exception]], expected_length: int ) -> Generator[I]: - return Generator(supplier, expected_length, fail_fast=fail_fast) + return Generator(supplier, expected_length) @staticmethod - def from_list( - l: list[L], map_fn: Callable[[L, int], I], fail_fast: bool = True - ) -> Generator[I]: + def from_list(l: list[L], map_fn: Callable[[L, int], I]) -> Generator[I]: """ Creates a new generator from a list that is mapped using the given function. The iterable will be equivalent to `map(map_fn, l)`. @@ -541,12 +539,10 @@ def supplier(): except Exception as e: yield e - return Generator(supplier, len(l), fail_fast=fail_fast) + return Generator(supplier, len(l)) @staticmethod - def from_range( - count: int, map_fn: Callable[[int], I], fail_fast: bool = True - ) -> Generator[I]: + def from_range(count: int, map_fn: Callable[[int], I]) -> Generator[I]: """ Creates a new generator the given number of items where each item is lazily evaluated. The iterable will be equivalent to `map(map_fn, range(count))`. @@ -560,7 +556,7 @@ def supplier(): except Exception as e: yield e - return Generator(supplier, count, fail_fast=fail_fast) + return Generator(supplier, count) N = TypeVar("N") diff --git a/backend/src/packages/chaiNNer_ncnn/ncnn/batch_processing/load_models.py b/backend/src/packages/chaiNNer_ncnn/ncnn/batch_processing/load_models.py index 417a4bbbe..e157c2b25 100644 --- a/backend/src/packages/chaiNNer_ncnn/ncnn/batch_processing/load_models.py +++ b/backend/src/packages/chaiNNer_ncnn/ncnn/batch_processing/load_models.py @@ -82,4 +82,7 @@ def load_model(filepath_pairs: tuple[Path, Path], index: int): model_files = list(zip(param_files, bin_files)) - return Generator.from_list(model_files, load_model, fail_fast), directory + return ( + Generator.from_list(model_files, load_model).with_fail_fast(fail_fast), + directory, + ) diff --git a/backend/src/packages/chaiNNer_onnx/onnx/batch_processing/load_models.py b/backend/src/packages/chaiNNer_onnx/onnx/batch_processing/load_models.py index 20750d0ec..29c2678ef 100644 --- a/backend/src/packages/chaiNNer_onnx/onnx/batch_processing/load_models.py +++ b/backend/src/packages/chaiNNer_onnx/onnx/batch_processing/load_models.py @@ -63,4 +63,7 @@ def load_model(path: Path, index: int): supported_filetypes = [".onnx"] model_files = list_all_files_sorted(directory, supported_filetypes) - return Generator.from_list(model_files, load_model, fail_fast), directory + return ( + Generator.from_list(model_files, load_model).with_fail_fast(fail_fast), + directory, + ) diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/iteration/load_models.py b/backend/src/packages/chaiNNer_pytorch/pytorch/iteration/load_models.py index 4e616fe67..0b4b96f53 100644 --- a/backend/src/packages/chaiNNer_pytorch/pytorch/iteration/load_models.py +++ b/backend/src/packages/chaiNNer_pytorch/pytorch/iteration/load_models.py @@ -62,4 +62,7 @@ def load_model(path: Path, index: int): supported_filetypes = [".pt", ".pth", ".ckpt", ".safetensors"] model_files: list[Path] = list_all_files_sorted(directory, supported_filetypes) - return Generator.from_list(model_files, load_model, fail_fast), directory + return ( + Generator.from_list(model_files, load_model).with_fail_fast(fail_fast), + directory, + ) diff --git a/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py b/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py index e66579658..4ab28ff85 100644 --- a/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py +++ b/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py @@ -132,4 +132,7 @@ def load_image(path: Path, index: int): if use_limit: just_image_files = just_image_files[:limit] - return Generator.from_list(just_image_files, load_image, fail_fast), directory + return ( + Generator.from_list(just_image_files, load_image).with_fail_fast(fail_fast), + directory, + ) diff --git a/backend/src/process.py b/backend/src/process.py index d12f042e0..4b0b427a9 100644 --- a/backend/src/process.py +++ b/backend/src/process.py @@ -706,7 +706,9 @@ async def __iterate_generator_nodes(self, generator_nodes: list[GeneratorNode]): # run the generator nodes before anything else for node in generator_nodes: generator_output = await self.process_generator_node(node) - generator_suppliers[node.id] = generator_output.generator.supplier() + generator_suppliers[node.id] = ( + generator_output.generator.supplier().__iter__() + ) collector_nodes, __output_nodes, __all_iterated_nodes = ( self.__get_iterated_nodes(node)