Skip to content

Commit

Permalink
Merge branch 'generator-broadcasts' into gen-out-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
RunDevelopment committed Jul 1, 2024
2 parents f23a262 + 307109c commit a6e4bce
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 16 deletions.
21 changes: 9 additions & 12 deletions backend/src/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,18 +516,17 @@ class Generator(Generic[I]):
expected_length: int
fail_fast: bool = True

def with_fail_fast(self, fail_fast: bool) -> Generator[I]:
return Generator(self.supplier, self.expected_length, fail_fast=fail_fast)

@staticmethod
def from_iter(
supplier: Callable[[], Iterable[I | Exception]],
expected_length: int,
fail_fast: bool = True,
supplier: Callable[[], Iterable[I | Exception]], expected_length: int
) -> Generator[I]:
return Generator(supplier, expected_length, fail_fast=fail_fast)
return Generator(supplier, expected_length)

@staticmethod
def from_list(
l: list[L], map_fn: Callable[[L, int], I], fail_fast: bool = True
) -> Generator[I]:
def from_list(l: list[L], map_fn: Callable[[L, int], I]) -> Generator[I]:
"""
Creates a new generator from a list that is mapped using the given
function. The iterable will be equivalent to `map(map_fn, l)`.
Expand All @@ -540,12 +539,10 @@ def supplier():
except Exception as e:
yield e

return Generator(supplier, len(l), fail_fast=fail_fast)
return Generator(supplier, len(l))

@staticmethod
def from_range(
count: int, map_fn: Callable[[int], I], fail_fast: bool = True
) -> Generator[I]:
def from_range(count: int, map_fn: Callable[[int], I]) -> Generator[I]:
"""
Creates a new generator the given number of items where each item is
lazily evaluated. The iterable will be equivalent to `map(map_fn, range(count))`.
Expand All @@ -559,7 +556,7 @@ def supplier():
except Exception as e:
yield e

return Generator(supplier, count, fail_fast=fail_fast)
return Generator(supplier, count)


N = TypeVar("N")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,7 @@ def load_model(filepath_pairs: tuple[Path, Path], index: int):

model_files = list(zip(param_files, bin_files))

return Generator.from_list(model_files, load_model, fail_fast), directory
return (
Generator.from_list(model_files, load_model).with_fail_fast(fail_fast),
directory,
)
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,7 @@ def load_model(path: Path, index: int):
supported_filetypes = [".onnx"]
model_files = list_all_files_sorted(directory, supported_filetypes)

return Generator.from_list(model_files, load_model, fail_fast), directory
return (
Generator.from_list(model_files, load_model).with_fail_fast(fail_fast),
directory,
)
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,7 @@ def load_model(path: Path, index: int):
supported_filetypes = [".pt", ".pth", ".ckpt", ".safetensors"]
model_files: list[Path] = list_all_files_sorted(directory, supported_filetypes)

return Generator.from_list(model_files, load_model, fail_fast), directory
return (
Generator.from_list(model_files, load_model).with_fail_fast(fail_fast),
directory,
)
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,7 @@ def load_image(path: Path, index: int):
if use_limit:
just_image_files = just_image_files[:limit]

return Generator.from_list(just_image_files, load_image, fail_fast), directory
return (
Generator.from_list(just_image_files, load_image).with_fail_fast(fail_fast),
directory,
)

0 comments on commit a6e4bce

Please sign in to comment.