Lower resolution option #26

Merged · 4 commits · Sep 30, 2024
Changes from all commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 41 additions & 32 deletions predict.py
@@ -1,6 +1,6 @@
import os
import time
from typing import Any, Optional
from typing import Any, Dict, Optional

import torch

@@ -96,6 +96,11 @@ class SharedInputs:
description="Run faster predictions with model optimized for speed (currently fp8 quantized); disable to run in original bf16",
default=True,
)
megapixels: Input = Input(
description="Approximate number of megapixels for generated image",
choices=["1", "0.25"],
default="1",
)


SHARED_INPUTS = SharedInputs()
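
Note: the new `megapixels` input is exposed as a string choice ("1" or "0.25") rather than a free-form number, so downstream code only has to branch on two known values. A minimal sketch of the mapping this PR applies, using a hypothetical helper name (`target_resolution` is not part of the diff):

```python
# Hypothetical helper illustrating the mapping introduced by this PR:
# "1"    -> keep the aspect-ratio resolution as-is (~1 MP, e.g. 1024x1024)
# "0.25" -> halve both dimensions (~0.25 MP, e.g. 512x512)
def target_resolution(width: int, height: int, megapixels: str) -> tuple:
    if megapixels == "0.25":
        return width // 2, height // 2
    return width, height

assert target_resolution(1024, 1024, "0.25") == (512, 512)
assert target_resolution(1024, 1024, "1") == (1024, 1024)
```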
@@ -191,6 +196,13 @@ def compile_fp8(self):
self.fp8_pipe.generate(
prompt="godzilla!", width=width, height=height, num_steps=4, guidance=3
)
self.fp8_pipe.generate(
prompt="godzilla!",
width=width // 2,
height=height // 2,
num_steps=4,
guidance=3,
)

print("compiled in ", time.time() - st)
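
Note: the extra `generate` call warms up the compiled fp8 pipeline at half resolution as well, so the first request with `megapixels="0.25"` does not pay compilation latency. The same warmup could be written as a loop; a sketch under the assumption that `generate` accepts exactly the arguments shown in this diff:

```python
# Sketch: warm up every resolution the compiled pipeline may be asked to serve.
# Assumes fp8_pipe.generate(prompt, width, height, num_steps, guidance) as above.
for scale in (1, 2):
    self.fp8_pipe.generate(
        prompt="godzilla!",
        width=width // scale,
        height=height // scale,
        num_steps=4,
        guidance=3,
    )
```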

@@ -228,26 +240,35 @@ def get_image(self, image: str):
def predict():
raise Exception("You need to instantiate a predictor for a specific flux model")

def preprocess(
self, aspect_ratio: str, seed: Optional[int], megapixels: str
) -> Dict:
width, height = ASPECT_RATIOS.get(aspect_ratio)
if megapixels == "0.25":
width, height = width // 2, height // 2

if not seed:
seed = int.from_bytes(os.urandom(2), "big")
print(f"Using seed: {seed}")

return {"width": width, "height": height, "seed": seed}
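
Note: `preprocess` now centralizes width/height selection and seed generation for both backends and returns a dict that is splatted straight into them. A hedged example of its output, assuming `ASPECT_RATIOS["1:1"]` is `(1024, 1024)` as in the existing predictors:

```python
# Assumed behaviour, given ASPECT_RATIOS["1:1"] == (1024, 1024):
kwargs = predictor.preprocess(aspect_ratio="1:1", seed=1234, megapixels="0.25")
# kwargs == {"width": 512, "height": 512, "seed": 1234}

# With seed=None, a random 16-bit seed is drawn from os.urandom(2) and printed
# before being returned, so both backends log the seed they actually used.
```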

@torch.inference_mode()
def base_predict(
self,
prompt: str,
aspect_ratio: str,
num_outputs: int,
num_inference_steps: int,
guidance: float = 3.5, # schnell ignores guidance within the model, fine to have default
image: Path = None, # img2img for flux-dev
prompt_strength: float = 0.8,
seed: Optional[int] = None,
seed: int = None,
width: int = 1024,
height: int = 1024,
) -> List[Path]:
"""Run a single prediction on the model"""
torch_device = torch.device("cuda")
init_image = None
width, height = self.aspect_ratio_to_width_height(aspect_ratio)

if not seed:
seed = int.from_bytes(os.urandom(2), "big")
print(f"Using seed: {seed}")

# img2img only works for flux-dev
if image:
@@ -345,24 +366,17 @@ def base_predict(
def fp8_predict(
self,
prompt: str,
aspect_ratio: str,
num_outputs: int,
num_inference_steps: int,
guidance: float = 3.5, # schnell ignores guidance within the model, fine to have default
image: Path = None, # img2img for flux-dev
prompt_strength: float = 0.8,
seed: Optional[int] = None,
seed: int = None,
width: int = 1024,
height: int = 1024,
) -> List[Image]:
"""Run a single prediction on the model"""
print("running quantized prediction")
if seed is None:
seed = np.random.randint(1, 100000)

width, height = self.aspect_ratio_to_width_height(aspect_ratio)
if image:
image = Image.open(image).convert("RGB")
print("generating")

return self.fp8_pipe.generate(
prompt=prompt,
width=width,
@@ -455,22 +469,17 @@ def predict(
output_quality: int = SHARED_INPUTS.output_quality,
disable_safety_checker: bool = SHARED_INPUTS.disable_safety_checker,
go_fast: bool = SHARED_INPUTS.go_fast,
megapixels: str = SHARED_INPUTS.megapixels,
) -> List[Path]:
hws_kwargs = self.preprocess(aspect_ratio, seed, megapixels)

if go_fast:
imgs, np_imgs = self.fp8_predict(
prompt,
aspect_ratio,
num_outputs,
num_inference_steps=self.num_steps,
seed=seed,
prompt, num_outputs, num_inference_steps=self.num_steps, **hws_kwargs
)
else:
imgs, np_imgs = self.base_predict(
prompt,
aspect_ratio,
num_outputs,
num_inference_steps=self.num_steps,
seed=seed,
prompt, num_outputs, num_inference_steps=self.num_steps, **hws_kwargs
)

return self.postprocess(
@@ -515,32 +524,32 @@ def predict(
output_quality: int = SHARED_INPUTS.output_quality,
disable_safety_checker: bool = SHARED_INPUTS.disable_safety_checker,
go_fast: bool = SHARED_INPUTS.go_fast,
megapixels: str = SHARED_INPUTS.megapixels,
) -> List[Path]:
if image and go_fast:
print("img2img not supported with fp8 quantization; running with bf16")
go_fast = False
hws_kwargs = self.preprocess(aspect_ratio, seed, megapixels)

if go_fast:
imgs, np_imgs = self.fp8_predict(
prompt,
aspect_ratio,
num_outputs,
num_inference_steps,
guidance=guidance,
image=image,
prompt_strength=prompt_strength,
seed=seed,
**hws_kwargs,
)
else:
imgs, np_imgs = self.base_predict(
prompt,
aspect_ratio,
num_outputs,
num_inference_steps,
guidance=guidance,
image=image,
prompt_strength=prompt_strength,
seed=seed,
**hws_kwargs,
)

return self.postprocess(
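
Note: taken together, both `predict` entry points now delegate resolution and seed handling to `preprocess` and spread its result into whichever backend runs; `aspect_ratio` no longer reaches `fp8_predict` or `base_predict` directly. A simplified sketch of the resulting call pattern (other arguments trimmed):

```python
# Condensed from the schnell predictor above; guidance, image, etc. omitted.
hws_kwargs = self.preprocess(aspect_ratio, seed, megapixels)
backend = self.fp8_predict if go_fast else self.base_predict
imgs, np_imgs = backend(
    prompt,
    num_outputs,
    num_inference_steps=self.num_steps,
    **hws_kwargs,  # width, height, seed
)
```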