Skip to content

Commit

Permalink
Merge pull request #28 from replicate/flux-schnell-steps
Browse files Browse the repository at this point in the history
Expose num_inference_steps in flux-schnell
  • Loading branch information
andreasjansson authored Oct 1, 2024
2 parents 71e89aa + cc4e30b commit 6a063a2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
17 changes: 8 additions & 9 deletions fp8/flux_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def compile(self):
"final_layer",
"pe_embedder",
]

if self.config.compile_whole_model:
# we're not doing the compilation here in this case
return
Expand Down Expand Up @@ -384,7 +384,7 @@ def into_bytes(self, x: torch.Tensor, jpeg_quality: int = 99) -> io.BytesIO:
im = self.img_encoder.encode_torch(im, quality=jpeg_quality)
images.clear()
return im

@torch.inference_mode()
def as_img_tensor(self, x: torch.Tensor) -> io.BytesIO:
"""Converts the image tensor to bytes."""
Expand Down Expand Up @@ -593,8 +593,6 @@ def generate(
io.BytesIO: Generated image(s) in bytes format.
int: Seed used for generation (only if return_seed is True).
"""
num_steps = 4 if self.name == "flux-schnell" else num_steps

init_image = self.load_init_image_if_needed(init_image)

# allow for packing and conversion to latent space
Expand Down Expand Up @@ -648,7 +646,7 @@ def generate(
)
output_imgs.append(denoised_img)
compiling = False

img = torch.cat(output_imgs)

# offload the model to cpu if needed
Expand All @@ -660,7 +658,7 @@ def generate(
img = self.vae_decode(img, height, width)

return self.as_img_tensor(img)

def denoise_single_item(self,
img,
img_ids,
Expand All @@ -671,7 +669,7 @@ def denoise_single_item(self,
guidance,
compiling
):

img = img.unsqueeze(0)
img_ids = img_ids.unsqueeze(0)
txt = txt.unsqueeze(0)
Expand All @@ -682,6 +680,7 @@ def denoise_single_item(self,
(img.shape[0],), guidance, device=self.device_flux, dtype=self.dtype
)
t_vec = None

for t_curr, t_prev in tqdm(
zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1
):
Expand All @@ -695,11 +694,11 @@ def denoise_single_item(self,
else:
t_vec = t_vec.reshape((img.shape[0],)).fill_(t_curr)
if compiling:
torch._dynamo.mark_dynamic(img, 1, min=256, max=8100)
torch._dynamo.mark_dynamic(img, 1, min=256, max=8100)
torch._dynamo.mark_dynamic(img_ids, 1, min=256, max=8100)
self.model = torch.compile(self.model)
compiling = False

pred = self.model(
img=img,
img_ids=img_ids,
Expand Down
17 changes: 15 additions & 2 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ def fp8_predict(
) -> List[Image]:
"""Run a single prediction on the model"""
print("running quantized prediction")

return self.fp8_pipe.generate(
prompt=prompt,
width=width,
Expand Down Expand Up @@ -464,6 +465,12 @@ def predict(
prompt: str = SHARED_INPUTS.prompt,
aspect_ratio: str = SHARED_INPUTS.aspect_ratio,
num_outputs: int = SHARED_INPUTS.num_outputs,
num_inference_steps: int = Input(
description="Number of denoising steps. Recommended range is 1-4",
ge=1,
le=4,
default=4,
),
seed: int = SHARED_INPUTS.seed,
output_format: str = SHARED_INPUTS.output_format,
output_quality: int = SHARED_INPUTS.output_quality,
Expand All @@ -475,11 +482,17 @@ def predict(

if go_fast:
imgs, np_imgs = self.fp8_predict(
prompt, num_outputs, num_inference_steps=self.num_steps, **hws_kwargs
prompt,
num_outputs,
num_inference_steps=num_inference_steps,
**hws_kwargs,
)
else:
imgs, np_imgs = self.base_predict(
prompt, num_outputs, num_inference_steps=self.num_steps, **hws_kwargs
prompt,
num_outputs,
num_inference_steps=num_inference_steps,
**hws_kwargs,
)

return self.postprocess(
Expand Down

0 comments on commit 6a063a2

Please sign in to comment.