
Commit

fix merge conflicts
andompesta committed Nov 26, 2024
1 parent 079778f commit c7fdb64
Showing 1 changed file with 6 additions and 18 deletions.
24 changes: 6 additions & 18 deletions src/flux/cli.py
@@ -5,15 +5,13 @@
 from glob import iglob
 
 import torch
+from cuda import cudart
 from fire import Fire
 from transformers import pipeline
 
 from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
 from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
 
 from flux.trt.trt_manager import TRTManager
-from cuda import cudart
-
-from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
 
 NSFW_THRESHOLD = 0.85
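
Note: this hunk moves the cudart import up with the rest of the imports and drops the duplicate import block left over from the merge. cudart comes from NVIDIA's cuda-python package, presumably needed by the TensorRT backend imported just below. A minimal sanity check of that binding (an illustrative snippet, not part of this commit):

import sys

from cuda import cudart

# cuda-python calls return a (status, result) tuple rather than raising
err, version = cudart.cudaRuntimeGetVersion()
if err != cudart.cudaError_t.cudaSuccess:
    sys.exit(f"CUDA runtime not available: {err}")
print(f"CUDA runtime version: {version}")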

@@ -29,9 +27,7 @@ class SamplingOptions:
 
 
 def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
-    user_question = (
-        "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
-    )
+    user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
     usage = (
         "Usage: Either write your prompt directly, leave this field empty "
         "to repeat the prompt or write a command starting with a slash:\n"
@@ -137,9 +133,7 @@ def main(
         trt: use TensorRT backend for optimized inference
         kwargs: additional arguments for TensorRT support
     """
-    nsfw_classifier = pipeline(
-        "image-classification", model="Falconsai/nsfw_image_detection", device=device
-    )
+    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
 
     if name not in configs:
         available = ", ".join(configs.keys())
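
Note: the condensed call above builds a standard Hugging Face image-classification pipeline. As an illustration of how such a classifier is typically applied against the NSFW_THRESHOLD defined earlier (the file names and the "nsfw" label here are assumptions, not taken from this diff):

from PIL import Image
from transformers import pipeline

NSFW_THRESHOLD = 0.85
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device="cpu")

img = Image.open("img_0.jpg")  # hypothetical input image
scores = nsfw_classifier(img)  # list of {"label": ..., "score": ...} dicts
nsfw_score = next((s["score"] for s in scores if s["label"] == "nsfw"), 0.0)
if nsfw_score < NSFW_THRESHOLD:
    img.save("output.jpg")
else:
    print("image flagged as NSFW; skipping save")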
@@ -158,11 +152,7 @@
         os.makedirs(output_dir)
         idx = 0
     else:
-        fns = [
-            fn
-            for fn in iglob(output_name.format(idx="*"))
-            if re.search(r"img_[0-9]+\.jpg$", fn)
-        ]
+        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
         if len(fns) > 0:
             idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
         else:
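
Note: the single-line comprehension above is behavior-preserving; it scans existing outputs matching img_<n>.jpg and picks the next free index. A self-contained sketch of that logic (the output directory is an assumed example):

import re
from glob import iglob

output_name = "output/img_{idx}.jpg"
fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
# e.g. with output/img_0.jpg and output/img_3.jpg present, idx becomes 4
idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 if fns else 0
print(output_name.format(idx=idx))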
@@ -260,9 +250,7 @@ def main(
             torch.cuda.empty_cache()
             t5, clip = t5.to(torch_device), clip.to(torch_device)
         inp = prepare(t5, clip, x, prompt=opts.prompt)
-        timesteps = get_schedule(
-            opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")
-        )
+        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
 
         # offload TEs to CPU, load model to gpu
         if offload:
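
Note: the reflowed get_schedule call disables the timestep shift only for flux-schnell. As a rough sketch of what a resolution-dependent shift of a linear 1-to-0 schedule can look like (the constants and shift formula below are assumptions for illustration, not necessarily flux.sampling's exact implementation):

import math

import torch

def shifted_schedule(num_steps: int, image_seq_len: int, shift: bool = True) -> list[float]:
    ts = torch.linspace(1, 0, num_steps + 1)
    if shift:
        # assumed linear interpolation of the shift strength with sequence length,
        # giving larger images more steps in the high-noise region
        mu = 0.5 + (1.15 - 0.5) * (image_seq_len - 256) / (4096 - 256)
        ts = math.exp(mu) / (math.exp(mu) + (1 / ts - 1))
    return ts.tolist()

print(shifted_schedule(4, 1024, shift=False))  # plain linear schedule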
