diff --git a/backend/src/api/api.py b/backend/src/api/api.py index 8ee41ad3f..a6cce94d2 100644 --- a/backend/src/api/api.py +++ b/backend/src/api/api.py @@ -516,18 +516,17 @@ class Generator(Generic[I]): expected_length: int fail_fast: bool = True + def with_fail_fast(self, fail_fast: bool) -> Generator[I]: + return Generator(self.supplier, self.expected_length, fail_fast=fail_fast) + @staticmethod def from_iter( - supplier: Callable[[], Iterable[I | Exception]], - expected_length: int, - fail_fast: bool = True, + supplier: Callable[[], Iterable[I | Exception]], expected_length: int ) -> Generator[I]: - return Generator(supplier, expected_length, fail_fast=fail_fast) + return Generator(supplier, expected_length) @staticmethod - def from_list( - l: list[L], map_fn: Callable[[L, int], I], fail_fast: bool = True - ) -> Generator[I]: + def from_list(l: list[L], map_fn: Callable[[L, int], I]) -> Generator[I]: """ Creates a new generator from a list that is mapped using the given function. The iterable will be equivalent to `map(map_fn, l)`. @@ -540,12 +539,10 @@ def supplier(): except Exception as e: yield e - return Generator(supplier, len(l), fail_fast=fail_fast) + return Generator(supplier, len(l)) @staticmethod - def from_range( - count: int, map_fn: Callable[[int], I], fail_fast: bool = True - ) -> Generator[I]: + def from_range(count: int, map_fn: Callable[[int], I]) -> Generator[I]: """ Creates a new generator the given number of items where each item is lazily evaluated. The iterable will be equivalent to `map(map_fn, range(count))`. @@ -559,7 +556,7 @@ def supplier(): except Exception as e: yield e - return Generator(supplier, count, fail_fast=fail_fast) + return Generator(supplier, count) N = TypeVar("N") diff --git a/backend/src/packages/chaiNNer_ncnn/ncnn/batch_processing/load_models.py b/backend/src/packages/chaiNNer_ncnn/ncnn/batch_processing/load_models.py index 417a4bbbe..e157c2b25 100644 --- a/backend/src/packages/chaiNNer_ncnn/ncnn/batch_processing/load_models.py +++ b/backend/src/packages/chaiNNer_ncnn/ncnn/batch_processing/load_models.py @@ -82,4 +82,7 @@ def load_model(filepath_pairs: tuple[Path, Path], index: int): model_files = list(zip(param_files, bin_files)) - return Generator.from_list(model_files, load_model, fail_fast), directory + return ( + Generator.from_list(model_files, load_model).with_fail_fast(fail_fast), + directory, + ) diff --git a/backend/src/packages/chaiNNer_onnx/onnx/batch_processing/load_models.py b/backend/src/packages/chaiNNer_onnx/onnx/batch_processing/load_models.py index 20750d0ec..29c2678ef 100644 --- a/backend/src/packages/chaiNNer_onnx/onnx/batch_processing/load_models.py +++ b/backend/src/packages/chaiNNer_onnx/onnx/batch_processing/load_models.py @@ -63,4 +63,7 @@ def load_model(path: Path, index: int): supported_filetypes = [".onnx"] model_files = list_all_files_sorted(directory, supported_filetypes) - return Generator.from_list(model_files, load_model, fail_fast), directory + return ( + Generator.from_list(model_files, load_model).with_fail_fast(fail_fast), + directory, + ) diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/iteration/load_models.py b/backend/src/packages/chaiNNer_pytorch/pytorch/iteration/load_models.py index 4e616fe67..0b4b96f53 100644 --- a/backend/src/packages/chaiNNer_pytorch/pytorch/iteration/load_models.py +++ b/backend/src/packages/chaiNNer_pytorch/pytorch/iteration/load_models.py @@ -62,4 +62,7 @@ def load_model(path: Path, index: int): supported_filetypes = [".pt", ".pth", ".ckpt", ".safetensors"] model_files: list[Path] = list_all_files_sorted(directory, supported_filetypes) - return Generator.from_list(model_files, load_model, fail_fast), directory + return ( + Generator.from_list(model_files, load_model).with_fail_fast(fail_fast), + directory, + ) diff --git a/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py b/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py index e6e5bbdee..3c4c00016 100644 --- a/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py +++ b/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py @@ -130,4 +130,7 @@ def load_image(path: Path, index: int): if use_limit: just_image_files = just_image_files[:limit] - return Generator.from_list(just_image_files, load_image, fail_fast), directory + return ( + Generator.from_list(just_image_files, load_image).with_fail_fast(fail_fast), + directory, + )