From 3d3741edf2a2ee71b27267cf67bada9c5a228982 Mon Sep 17 00:00:00 2001
From: Sandro Cavallari
Date: Mon, 18 Nov 2024 11:50:40 +0100
Subject: [PATCH 1/4] simplify dependencies

---
 trt_requirements.txt | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/trt_requirements.txt b/trt_requirements.txt
index 7295b317..b068a057 100644
--- a/trt_requirements.txt
+++ b/trt_requirements.txt
@@ -1,19 +1,12 @@
 tensorrt-cu12 >= 10.5.0
-accelerate
 colored
 cuda-python
-diffusers==0.31.0
-ftfy
-matplotlib
-nvtx
+transformers==4.42.2
+opencv-python==4.8.0.74
 onnx==1.17.0
 onnxruntime==1.19.2
-opencv-python==4.8.0.74
-scipy
-transformers==4.42.2
 --extra-index-url https://pypi.nvidia.com
 nvidia-modelopt[torch,onnx]==0.19.0
 onnx-graphsurgeon
 polygraphy==0.49.9
-peft==0.13.0
-sentencepiece
+sentencepiece
\ No newline at end of file

From f31ffd40a559313d81729740462facc0a2e392e0 Mon Sep 17 00:00:00 2001
From: Sandro Cavallari
Date: Tue, 26 Nov 2024 13:45:58 +0100
Subject: [PATCH 2/4] solve vae quality issue

---
 src/flux/trt/engine/vae_engine.py     | 4 ++--
 src/flux/trt/exporter/vae_exporter.py | 6 +-----
 2 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/src/flux/trt/engine/vae_engine.py b/src/flux/trt/engine/vae_engine.py
index 874596f3..a5d6cd68 100644
--- a/src/flux/trt/engine/vae_engine.py
+++ b/src/flux/trt/engine/vae_engine.py
@@ -39,9 +39,9 @@ def __init__(

     def __call__(
         self,
-        x: torch.Tensor,
+        latent: torch.Tensor,
     ) -> torch.Tensor:
-        return self.decode(x)
+        return self.decode(latent)

     def decode(self, z: torch.Tensor) -> torch.Tensor:
         z = z.to(dtype=self.tensors["latent"].dtype)
diff --git a/src/flux/trt/exporter/vae_exporter.py b/src/flux/trt/exporter/vae_exporter.py
index 4cf24082..2599bbfe 100644
--- a/src/flux/trt/exporter/vae_exporter.py
+++ b/src/flux/trt/exporter/vae_exporter.py
@@ -39,7 +39,7 @@ def __init__(
             compression_factor=compression_factor,
             scale_factor=model.params.scale_factor,
             shift_factor=model.params.shift_factor,
-            model=model,
+            model=model.decoder,  # we need to trace only the decoder
             fp16=fp16,
             tf32=tf32,
             bf16=bf16,
@@ -55,10 +55,6 @@
         # set proper dtype
         self.prepare_model()

-    def get_model(self) -> torch.nn.Module:
-        self.model.forward = self.model.decode
-        return self.model
-
     def get_input_names(self):
         return ["latent"]


From f80058f3a1d3d119ba5ef37a39b26c4310feb446 Mon Sep 17 00:00:00 2001
From: Sandro Cavallari
Date: Tue, 26 Nov 2024 14:04:30 +0100
Subject: [PATCH 3/4] fix ruff format

---
 demo_gr.py      |  2 ++
 src/flux/api.py | 12 +++---------
 2 files changed, 5 insertions(+), 9 deletions(-)

diff --git a/demo_gr.py b/demo_gr.py
index 83bbe8a4..626fdf12 100644
--- a/demo_gr.py
+++ b/demo_gr.py
@@ -15,6 +15,7 @@

 NSFW_THRESHOLD = 0.85

+
 def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
     t5 = load_t5(device, max_length=256 if is_schnell else 512)
     clip = load_clip(device)
@@ -23,6 +24,7 @@ def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool)
     nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
     return model, ae, t5, clip, nsfw_classifier

+
 class FluxGenerator:
     def __init__(self, model_name: str, device: str, offload: bool):
         self.device = torch.device(device)
diff --git a/src/flux/api.py b/src/flux/api.py
index 1c0b207e..6a608840 100644
--- a/src/flux/api.py
+++ b/src/flux/api.py
@@ -146,9 +146,7 @@ def request(self):
         )
         result = response.json()
         if response.status_code != 200:
-            raise ApiException(
-                status_code=response.status_code, detail=result.get("detail")
-            )
+            raise ApiException(status_code=response.status_code, detail=result.get("detail"))
         self.request_id = response.json()["id"]

     def retrieve(self) -> dict:
@@ -170,17 +168,13 @@
         )
         result = response.json()
         if "status" not in result:
-            raise ApiException(
-                status_code=response.status_code, detail=result.get("detail")
-            )
+            raise ApiException(status_code=response.status_code, detail=result.get("detail"))
         elif result["status"] == "Ready":
             self.result = result["result"]
         elif result["status"] == "Pending":
             time.sleep(0.5)
         else:
-            raise ApiException(
-                status_code=200, detail=f"API returned status '{result['status']}'"
-            )
+            raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
         return self.result

     @property

From a5986b52099a1a17291239bc82ec5087035218dc Mon Sep 17 00:00:00 2001
From: Sandro Cavallari
Date: Tue, 26 Nov 2024 14:39:07 +0100
Subject: [PATCH 4/4] format and sort src/flux/cli

---
 src/flux/cli.py | 24 ++++++------------------
 1 file changed, 6 insertions(+), 18 deletions(-)

diff --git a/src/flux/cli.py b/src/flux/cli.py
index 454a335e..7ee2372a 100644
--- a/src/flux/cli.py
+++ b/src/flux/cli.py
@@ -5,15 +5,13 @@
 from glob import iglob

 import torch
+from cuda import cudart
 from fire import Fire
 from transformers import pipeline

 from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
-from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
-
 from flux.trt.trt_manager import TRTManager
-from cuda import cudart
-
+from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image

 NSFW_THRESHOLD = 0.85

@@ -29,9 +27,7 @@ class SamplingOptions:


 def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
-    user_question = (
-        "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
-    )
+    user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
     usage = (
         "Usage: Either write your prompt directly, leave this field empty "
         "to repeat the prompt or write a command starting with a slash:\n"
@@ -137,9 +133,7 @@
         trt: use TensorRT backend for optimized inference
         kwargs: additional arguments for TensorRT support
     """
-    nsfw_classifier = pipeline(
-        "image-classification", model="Falconsai/nsfw_image_detection", device=device
-    )
+    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

     if name not in configs:
         available = ", ".join(configs.keys())
@@ -158,11 +152,7 @@
         os.makedirs(output_dir)
         idx = 0
     else:
-        fns = [
-            fn
-            for fn in iglob(output_name.format(idx="*"))
-            if re.search(r"img_[0-9]+\.jpg$", fn)
-        ]
+        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
         if len(fns) > 0:
             idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
         else:
@@ -260,9 +250,7 @@
             torch.cuda.empty_cache()
             t5, clip = t5.to(torch_device), clip.to(torch_device)
         inp = prepare(t5, clip, x, prompt=opts.prompt)
-        timesteps = get_schedule(
-            opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")
-        )
+        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

         # offload TEs to CPU, load model to gpu
         if offload:
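
Note on PATCH 2: the quality fix replaces the deleted get_model hack, which monkey-patched the full
autoencoder's forward onto decode, with passing model.decoder to the exporter, so the traced ONNX graph
contains only the decoder; since scale_factor and shift_factor are handed to the exporter separately,
the latent de-scaling presumably happens outside the traced graph. The sketch below illustrates the
pattern; the DecoderWrapper class, the example latent shape, the output name, and the export arguments
are illustrative assumptions, not the repository's actual exporter code.

import torch

from flux.util import load_ae


class DecoderWrapper(torch.nn.Module):
    """Expose only the VAE decoder to torch.onnx.export, mirroring AutoEncoder.decode."""

    def __init__(self, ae: torch.nn.Module):
        super().__init__()
        self.decoder = ae.decoder  # trace only the decoder, as in the patch
        self.scale_factor = ae.params.scale_factor
        self.shift_factor = ae.params.shift_factor

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        # Undo the latent scaling before decoding. It is folded into the graph here
        # for self-containment; the patch appears to leave this step to the engine.
        z = latent / self.scale_factor + self.shift_factor
        return self.decoder(z)


ae = load_ae("flux-dev", device="cpu")
wrapper = DecoderWrapper(ae).eval()
torch.onnx.export(
    wrapper,
    torch.randn(1, 16, 128, 128),  # example latent for a 1024x1024 image (16 channels, 8x compression)
    "vae_decoder.onnx",
    input_names=["latent"],  # matches the exporter's get_input_names()
    output_names=["images"],
    dynamic_axes={"latent": {0: "batch", 2: "height", 3: "width"}},
)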