diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index 5c02649a..2bd6ad06 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -437,6 +437,7 @@ def prepare_all( vmfbs: dict = {}, weights: dict = {}, interactive: bool = False, + num_steps: int = 20, ): ready = self.is_prepared(vmfbs, weights) match ready: @@ -463,7 +464,9 @@ def prepare_all( if not self.map[submodel].get("weights") and self.map[submodel][ "export_args" ].get("external_weights"): - self.export_submodel(submodel, weights_only=True) + self.export_submodel( + submodel, weights_only=True, num_steps=num_steps + ) return self.prepare_all(mlirs, vmfbs, weights, interactive) def is_prepared(self, vmfbs, weights): @@ -581,6 +584,7 @@ def export_submodel( submodel: str, input_mlir: str = None, weights_only: bool = False, + num_steps: int = 20, ): if not os.path.exists(self.pipeline_dir): os.makedirs(self.pipeline_dir) @@ -670,7 +674,9 @@ def export_submodel( self.map[submodel]["export_args"]["precision"], self.map[submodel]["export_args"]["batch_size"], self.map[submodel]["export_args"]["max_length"], - "tokens_to_image", + "produce_img_split", + unet_module_name=self.map["unet"]["module_name"], + num_steps=num_steps, ) dims = [ self.map[submodel]["export_args"]["width"], @@ -699,8 +705,8 @@ def export_submodel( return_path=True, mlir_source="str", ) - self.map[submodel]["vmfb"] = vmfb_path - self.map[submodel]["weights"] = None + self.map[submodel]["vmfb"] = [vmfb_path] + self.map[submodel]["weights"] = [] case _: export_args = self.map[submodel].get("export_args", {}) if weights_only: @@ -721,10 +727,24 @@ def export_submodel( # LOAD def load_map(self): - for submodel in self.map.keys(): + # Make sure fullpipeline is imported last + submodels = list(self.map.keys() - {"fullpipeline"}) + submodels += ["fullpipeline"] if "fullpipeline" in self.map.keys() else [] + for submodel in submodels: if not self.map[submodel]["load"]: - self.printer.print("Skipping load for ", submodel) + self.printer.print(f"Skipping load for {submodel}") continue + elif self.map[submodel].get("wraps"): + vmfbs = [] + weights = [] + for wrapped in self.map[submodel]["wraps"]: + vmfbs.append(self.map[wrapped]["vmfb"]) + if "weights" in self.map[wrapped]: + weights.append(self.map[wrapped]["weights"]) + self.map[submodel]["vmfb"] = vmfbs + self.map[submodel]["vmfb"] + self.map[submodel]["weights"] = weights + self.map[submodel]["weights"] + + print(f"Loading {submodel}") self.load_submodel(submodel) def load_submodel(self, submodel): @@ -751,6 +771,10 @@ def load_submodel(self, submodel): def unload_submodel(self, submodel): self.map[submodel]["runner"].unload() + self.map[submodel]["vmfb"] = None + self.map[submodel]["mlir"] = None + self.map[submodel]["weights"] = None + self.map[submodel]["export_args"]["input_mlir"] = None setattr(self, submodel, None) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 277f74cb..aab9de77 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -118,23 +118,11 @@ "decomp_attn": None, }, }, - "unetloop": { - "module_name": "sdxl_compiled_pipeline", - "load": False, - "keywords": ["unetloop"], - "wraps": ["unet", "scheduler"], - "export_args": { - "batch_size": 1, - "height": 1024, - "width": 
1024, - "max_length": 64, - }, - }, "fullpipeline": { "module_name": "sdxl_compiled_pipeline", - "load": False, + "load": True, "keywords": ["fullpipeline"], - "wraps": ["text_encoder", "unet", "scheduler", "vae"], + "wraps": ["unet", "scheduler", "vae"], "export_args": { "batch_size": 1, "height": 1024, @@ -190,6 +178,7 @@ def get_sd_model_map(hf_model_name): "stabilityai/sdxl-turbo", "stabilityai/stable-diffusion-xl-base-1.0", "/models/SDXL/official_pytorch/fp16/stable_diffusion_fp16/checkpoint_pipe", + "/models/SDXL/official_pytorch/fp16/stable_diffusion_fp16//checkpoint_pipe", ]: return sdxl_model_map elif "stabilityai/stable-diffusion-3" in name: @@ -233,6 +222,7 @@ def __init__( benchmark: bool | dict[bool] = False, verbose: bool = False, batch_prompts: bool = False, + compiled_pipeline: bool = False, ): common_export_args = { "hf_model_name": None, @@ -243,11 +233,11 @@ def __init__( "exit_on_vmfb": False, "pipeline_dir": pipeline_dir, "input_mlir": None, - "attn_spec": None, + "attn_spec": attn_spec, "external_weights": None, "external_weight_path": None, } - sd_model_map = get_sd_model_map(hf_model_name) + sd_model_map = copy.deepcopy(get_sd_model_map(hf_model_name)) for submodel in sd_model_map: if "load" not in sd_model_map[submodel]: sd_model_map[submodel]["load"] = True @@ -311,6 +301,7 @@ def __init__( self.scheduler = None self.split_scheduler = True + self.compiled_pipeline = compiled_pipeline self.base_model_name = ( hf_model_name @@ -321,11 +312,6 @@ def __init__( self.is_sdxl = "xl" in self.base_model_name.lower() self.is_sd3 = "stable-diffusion-3" in self.base_model_name if self.is_sdxl: - if self.split_scheduler: - if self.map.get("unetloop"): - self.map.pop("unetloop") - if self.map.get("fullpipeline"): - self.map.pop("fullpipeline") self.tokenizers = [ CLIPTokenizer.from_pretrained( self.base_model_name, subfolder="tokenizer" @@ -339,6 +325,20 @@ def __init__( self.scheduler_device = self.map["unet"]["device"] self.scheduler_driver = self.map["unet"]["driver"] self.scheduler_target = self.map["unet"]["target"] + if not self.compiled_pipeline: + if self.map.get("unetloop"): + self.map.pop("unetloop") + if self.map.get("fullpipeline"): + self.map.pop("fullpipeline") + elif self.compiled_pipeline: + self.map["unet"]["load"] = False + self.map["vae"]["load"] = False + self.load_scheduler( + scheduler_id, + num_inference_steps, + ) + self.map["scheduler"]["runner"].unload() + self.map["scheduler"]["load"] = False elif not self.is_sd3: self.tokenizer = CLIPTokenizer.from_pretrained( self.base_model_name, subfolder="tokenizer" @@ -351,13 +351,15 @@ def __init__( self.latents_dtype = torch_dtypes[self.latents_precision] self.use_i8_punet = self.use_punet = use_i8_punet + if self.use_punet: + self.setup_punet() + else: + self.map["unet"]["keywords"].append("!punet") + self.map["unet"]["function_name"] = "run_forward" + + def setup_punet(self): if self.use_i8_punet: self.map["unet"]["export_args"]["precision"] = "i8" - self.map["unet"]["export_args"]["use_punet"] = True - self.map["unet"]["use_weights_for_export"] = True - self.map["unet"]["keywords"].append("punet") - self.map["unet"]["module_name"] = "compiled_punet" - self.map["unet"]["function_name"] = "main" self.map["unet"]["export_args"]["external_weight_path"] = ( utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa" ) @@ -365,9 +367,11 @@ def __init__( if word in ["fp32", "fp16"]: self.map["unet"]["keywords"][idx] = "i8" break - else: - self.map["unet"]["keywords"].append("!punet") - 
self.map["unet"]["function_name"] = "run_forward" + self.map["unet"]["export_args"]["use_punet"] = True + self.map["unet"]["use_weights_for_export"] = True + self.map["unet"]["keywords"].append("punet") + self.map["unet"]["module_name"] = "compiled_punet" + self.map["unet"]["function_name"] = "main" # LOAD @@ -376,10 +380,6 @@ def load_scheduler( scheduler_id: str, steps: int = 30, ): - if self.is_sd3: - scheduler_device = self.mmdit.device - else: - scheduler_device = self.unet.device if not self.cpu_scheduling: self.map["scheduler"] = { "module_name": "compiled_scheduler", @@ -425,7 +425,11 @@ def load_scheduler( except: print("JIT export of scheduler failed. Loading CPU scheduler.") self.cpu_scheduling = True - if self.cpu_scheduling: + elif self.cpu_scheduling: + if self.is_sd3: + scheduler_device = self.mmdit.device + else: + scheduler_device = self.unet.device scheduler = schedulers.get_scheduler(self.base_model_name, scheduler_id) self.scheduler = schedulers.SharkSchedulerCPUWrapper( scheduler, @@ -461,13 +465,10 @@ def encode_prompts_sdxl(self, prompt, negative_prompt): text_input_ids_list += text_inputs.input_ids.unsqueeze(0) uncond_input_ids_list += uncond_input.input_ids.unsqueeze(0) - if self.compiled_pipeline: - return text_input_ids_list, uncond_input_ids_list - else: - prompt_embeds, add_text_embeds = self.text_encoder( - "encode_prompts", [*text_input_ids_list, *uncond_input_ids_list] - ) - return prompt_embeds, add_text_embeds + prompt_embeds, add_text_embeds = self.text_encoder( + "encode_prompts", [*text_input_ids_list, *uncond_input_ids_list] + ) + return prompt_embeds, add_text_embeds def prepare_latents( self, @@ -565,9 +566,11 @@ def _produce_latents_sdxl( [guidance_scale], dtype=self.map["unet"]["np_dtype"], ) + # Disable progress bar if we aren't in verbose mode or if we're printing + # benchmark latencies for unet. for i, t in tqdm( enumerate(timesteps), - disable=(self.map["unet"].get("benchmark") and self.verbose), + disable=(self.map["unet"].get("benchmark") or not self.verbose), ): if self.cpu_scheduling: latent_model_input, t = self.scheduler.scale_model_input( @@ -608,6 +611,75 @@ def _produce_latents_sdxl( latents = self.scheduler("run_step", [noise_pred, t, latents]) return latents + def produce_images_compiled( + self, + sample, + prompt_embeds, + text_embeds, + guidance_scale, + ): + pipe_inputs = [ + sample, + prompt_embeds, + text_embeds, + torch.as_tensor([guidance_scale], dtype=sample.dtype), + ] + # image = self.compiled_pipeline("produce_img_latents", pipe_inputs) + image = self.map["fullpipeline"]["runner"]("produce_image_latents", pipe_inputs) + return image + + def prepare_sampling_inputs( + self, + prompt: str, + negative_prompt: str = "", + steps: int = 30, + batch_count: int = 1, + guidance_scale: float = 7.5, + seed: float = -1, + cpu_scheduling: bool = True, + scheduler_id: str = "EulerDiscrete", + return_imgs: bool = False, + ): + needs_new_scheduler = ( + (steps and steps != self.num_inference_steps) + or (cpu_scheduling != self.cpu_scheduling) + and self.split_scheduler + ) + if not self.scheduler and not self.compiled_pipeline: + needs_new_scheduler = True + + if guidance_scale == 0: + negative_prompt = prompt + prompt = "" + + self.cpu_scheduling = cpu_scheduling + if steps and needs_new_scheduler: + self.num_inference_steps = steps + self.load_scheduler(scheduler_id, steps) + + pipe_start = time.time() + numpy_images = [] + + samples = self.get_rand_latents(seed, batch_count) + + # Tokenize prompt and negative prompt. 
+ if self.is_sdxl: + prompt_embeds, negative_embeds = self.encode_prompts_sdxl( + prompt, negative_prompt + ) + else: + prompt_embeds, negative_embeds = encode_prompt( + self, prompt, negative_prompt + ) + produce_latents_input = [ + samples[0], + prompt_embeds, + negative_embeds, + steps, + guidance_scale, + ] + return produce_latents_input + def generate_images( self, prompt: str, @@ -653,18 +725,23 @@ def generate_images( ) for i in range(batch_count): - produce_latents_input = [ - samples[i], - prompt_embeds, - negative_embeds, - steps, - guidance_scale, - ] - if self.is_sdxl: - latents = self._produce_latents_sdxl(*produce_latents_input) + if self.compiled_pipeline: + image = self.produce_images_compiled( + samples[i], prompt_embeds, negative_embeds, guidance_scale + ).to_host() else: - latents = self._produce_latents_sd(*produce_latents_input) - image = self.vae("decode", [latents]) + produce_latents_input = [ + samples[i], + prompt_embeds, + negative_embeds, + steps, + guidance_scale, + ] + if self.is_sdxl: + latents = self._produce_latents_sdxl(*produce_latents_input) + else: + latents = self._produce_latents_sd(*produce_latents_input) + image = self.vae("decode", [latents]) numpy_images.append(image) pipe_end = time.time() @@ -750,8 +827,10 @@ def numpy_to_pil_image(images): args.use_i8_punet, benchmark, args.verbose, + False, + args.compiled_pipeline, ) - sd_pipe.prepare_all() + sd_pipe.prepare_all(num_steps=args.num_inference_steps) sd_pipe.load_map() sd_pipe.generate_images( args.prompt, diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index cc8591b9..78f20a12 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -28,18 +28,21 @@ "pad_attention": [ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,128,0,32,0}))", ], + "punet": [ + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics), util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" + ], "preprocess_default": [ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))", ], "unet": [ "--iree-flow-enable-aggressive-fusion", - "--iree-flow-enable-fuse-horizontal-contractions=true", + "--iree-global-opt-enable-fuse-horizontal-contractions=true", "--iree-opt-aggressively-propagate-transposes=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", ], "clip": [ "--iree-flow-enable-aggressive-fusion", - "--iree-flow-enable-fuse-horizontal-contractions=true", + "--iree-global-opt-enable-fuse-horizontal-contractions=true", "--iree-opt-aggressively-propagate-transposes=true", ], "vae": [ @@ -58,7 +61,7 @@ "--iree-opt-const-eval=false", "--iree-opt-aggressively-propagate-transposes=true", "--iree-flow-enable-aggressive-fusion", - "--iree-flow-enable-fuse-horizontal-contractions=true", + "--iree-global-opt-enable-fuse-horizontal-contractions=true", "--iree-codegen-gpu-native-math-precision=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", 
"--iree-codegen-llvmgpu-enable-transform-dialect-jit=false", @@ -66,6 +69,9 @@ "pad_attention": [ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))", ], + "punet": [ + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics), util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" + ], "preprocess_default": [ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))", ], @@ -153,7 +159,7 @@ def compile_to_vmfb( save_mlir=True, attn_spec=None, winograd=False, - masked_attention=False, + flagset_keywords=[], debug=False, ): flags = [] @@ -235,15 +241,19 @@ def compile_to_vmfb( elif "vae" in safe_name: flags.extend(MI_flags["vae"]) flags.extend(MI_flags["all"]) - if masked_attention: - flags.extend(GFX11_flags["pad_attention"]) + if "masked_attention" in flagset_keywords: + flags.extend(MI_flags["pad_attention"]) + elif "punet" in flagset_keywords: + flags.extend(MI_flags["punet"]) else: - flags.extend(GFX11_flags["preprocess_default"]) + flags.extend(MI_flags["preprocess_default"]) if "gfx11" in target_triple: flags.extend(GFX11_flags["all"]) - if masked_attention: + if "masked_attention" in flagset_keywords: flags.extend(GFX11_flags["pad_attention"]) + elif "punet" in flagset_keywords: + flags.extend(GFX11_flags["punet"]) else: flags.extend(GFX11_flags["preprocess_default"]) @@ -257,15 +267,12 @@ def compile_to_vmfb( attn_spec = get_mfma_spec_path( target_triple, os.path.dirname(safe_name), - masked_attention, use_punet=use_punet, ) flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec): - attn_spec = get_wmma_spec_path( - target_triple, os.path.dirname(safe_name), masked_attention - ) + attn_spec = get_wmma_spec_path(target_triple, os.path.dirname(safe_name)) if attn_spec: flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) elif attn_spec and attn_spec != "None": diff --git a/models/turbine_models/custom_models/sdxl_inference/pipeline_ir.py b/models/turbine_models/custom_models/sdxl_inference/pipeline_ir.py index cb2b62be..f0cec20b 100644 --- a/models/turbine_models/custom_models/sdxl_inference/pipeline_ir.py +++ b/models/turbine_models/custom_models/sdxl_inference/pipeline_ir.py @@ -1,9 +1,9 @@ tokens_to_image = r""" module @sdxl_compiled_pipeline {{ - func.func private @compiled_scheduled_unet.run_initialize(%arg0: tensor<{batch_size}x4x{lw}x{lh}x{precision}>) -> (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<{bd}x6x{precision}>, tensor) attributes {{torch.args_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}, {{\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}}]}}]", torch.return_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, 
{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}]"}} + func.func private @compiled_scheduled_unet.run_initialize(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<{bd}x6x{precision}>, tensor) attributes {{torch.args_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}, {{\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}}]}}]", torch.return_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}]"}} func.func private @compiled_scheduled_unet.run_forward(%arg0: tensor<{batch_size}x4x{lw}x{lh}x{precision}>, %arg1: tensor<{bd}x{max_length}x2048x{precision}>, %arg2: tensor<{bd}x1280x{precision}>, %arg3: tensor<{bd}x6x{precision}>, %arg4: tensor<1x{precision}>, %arg5: tensor<1xi64>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}> attributes {{torch.args_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}, {{\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}}]}}]", torch.return_schema = "[1, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]"}} func.func private @compiled_clip.encode_prompts(%arg0: tensor<{batch_size}x{max_length}xi64>, %arg1: tensor<{batch_size}x{max_length}xi64>, %arg2: tensor<{batch_size}x{max_length}xi64>, %arg3: tensor<{batch_size}x{max_length}xi64>) -> (tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>) attributes {{torch.args_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}, {{\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}}]}}]", torch.return_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}]"}} - func.func private @{vae_fn_name}.main(%arg0: tensor<{batch_size}x4x{lw}x{lh}x{precision}>) -> tensor<{batch_size}x3x{width}x{height}x{precision}> attributes {{torch.args_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: 
\22null\22, \22children_spec\22: [{{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}, {{\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}}]}}]", torch.return_schema = "[1, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]"}} + func.func private @{vae_module}.main(%arg0: tensor<{batch_size}x4x{lw}x{lh}x{precision}>) -> tensor<{batch_size}x3x{width}x{height}x{precision}> attributes {{torch.args_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}, {{\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}}]}}]", torch.return_schema = "[1, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]"}} func.func @tokens_to_image(%sample: tensor<{batch_size}x4x{lw}x{lh}x{precision}>, %guidance_scale: tensor<1x{precision}>, %t_ids_1: tensor<{batch_size}x{max_length}xi64>, %t_ids_2: tensor<{batch_size}x{max_length}xi64>, %u_ids_1: tensor<{batch_size}x{max_length}xi64>, %u_ids_2: tensor<{batch_size}x{max_length}xi64>) -> tensor<{batch_size}x3x{width}x{height}x{precision}> {{ %p_embeds, %t_embeds = func.call @compiled_clip.encode_prompts(%t_ids_1, %t_ids_2, %u_ids_1, %u_ids_2) : (tensor<{batch_size}x{max_length}xi64>, tensor<{batch_size}x{max_length}xi64>, tensor<{batch_size}x{max_length}xi64>, tensor<{batch_size}x{max_length}xi64>) -> (tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>) @@ -18,7 +18,7 @@ %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %p_embeds, %t_embeds, %time_ids, %guidance_scale, %this_step) : (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>, tensor<1xi64>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}> scf.yield %inner : tensor<{batch_size}x4x{lw}x{lh}x{precision}> }} - %image = func.call @{vae_fn_name}.main(%res): (tensor<{batch_size}x4x{lw}x{lh}x{precision}>) -> tensor<{batch_size}x3x{width}x{height}x{precision}> + %image = func.call @{vae_module}.main(%res): (tensor<{batch_size}x4x{lw}x{lh}x{precision}>) -> tensor<{batch_size}x3x{width}x{height}x{precision}> return %image : tensor<{batch_size}x3x{width}x{height}x{precision}> }} }} @@ -38,7 +38,7 @@ %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<{batch_size}x4x{lw}x{lh}x{precision}>) {{ %step_64 = arith.index_cast %arg0 : index to i64 %this_step = tensor.from_elements %step_64 : tensor<1xi64> - %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %p_embeds, %t_embeds, %time_ids, %guidance_scale, %this_step) : (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>, tensor<1xi64>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}> + %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %this_step, %p_embeds, %t_embeds, %time_ids, %guidance_scale) : (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<1x{precision}>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}> scf.yield %inner : 
tensor<{batch_size}x4x{lw}x{lh}x{precision}> }} return %res : tensor<{batch_size}x4x{lw}x{lh}x{precision}> @@ -46,6 +46,33 @@ }} """ +produce_img_split = r""" +module @sdxl_compiled_pipeline {{ + func.func private @{scheduler_module}.run_initialize(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<{bd}x6x{precision}>, tensor<1xf16>, tensor<{num_steps}xf32>) attributes {{torch.assume_strict_symbolic_shapes}} + func.func private @{scheduler_module}.run_scale(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %arg1: tensor<1xi64>, %arg2: tensor<{num_steps}xf32>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>) attributes {{torch.assume_strict_symbolic_shapes}} + func.func private @{scheduler_module}.run_step(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %arg1: tensor<1x{precision}>, %arg2: tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> attributes {{torch.assume_strict_symbolic_shapes}} + func.func private @{unet_module}.{unet_function}(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %arg1: tensor<1x{precision}>, %arg2: tensor<{bd}x{max_length}x2048x{precision}>, %arg3: tensor<{bd}x1280x{precision}>, %arg4: tensor<{bd}x6x{precision}>, %arg5: tensor<1x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> attributes {{torch.assume_strict_symbolic_shapes}} + func.func private @{vae_module}.decode(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x3x{height}x{width}x{precision}> attributes {{torch.assume_strict_symbolic_shapes}} + + func.func @produce_image_latents(%sample: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %p_embeds: tensor<{bd}x{max_length}x2048x{precision}>, %t_embeds: tensor<{bd}x1280x{precision}>, %guidance_scale: tensor<1x{precision}>) -> tensor<{batch_size}x3x{height}x{width}x{precision}> {{ + %noisy_sample, %time_ids, %delete, %timesteps = func.call @{scheduler_module}.run_initialize(%sample) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>, tensor<{num_steps}xf32>) + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %n_steps = arith.constant {num_steps} : index + %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>) {{ + %step_64 = arith.index_cast %arg0 : index to i64 + %this_step = tensor.from_elements %step_64 : tensor<1xi64> + %scaled, %timestep = func.call @{scheduler_module}.run_scale(%arg, %this_step, %timesteps) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1xi64>, tensor<{num_steps}xf32>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>) + %inner = func.call @{unet_module}.{unet_function}(%scaled, %timestep, %p_embeds, %t_embeds, %time_ids, %guidance_scale) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> + %pred = func.call @{scheduler_module}.run_step(%inner, %timestep, %arg) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>, tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> + scf.yield %pred : tensor<{batch_size}x4x{lh}x{lw}x{precision}> + }} + %image = func.call @{vae_module}.decode(%res): 
(tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x3x{height}x{width}x{precision}> + return %image : tensor<{batch_size}x3x{height}x{width}x{precision}> + }} +}} +""" + def get_pipeline_ir( width: int, @@ -54,7 +81,9 @@ def get_pipeline_ir( batch_size: int, max_length: int, type: str, - vae_fn_name: str = "compiled_vae", + num_steps: int = 20, + vae_module: str = "compiled_vae", + unet_module_name: str = "compiled_punet", ): precision = "f32" if precision == "fp32" else "f16" if type == "tokens_to_image": @@ -67,7 +96,7 @@ def get_pipeline_ir( precision=precision, batch_size=batch_size, max_length=max_length, - vae_fn_name=vae_fn_name, + vae_module=vae_module, ) elif type == "unet_loop": return unet_loop.format( @@ -80,3 +109,22 @@ def get_pipeline_ir( batch_size=batch_size, max_length=max_length, ) + elif type == "produce_img_split": + unet_fn_name = "run_forward" + scheduler_module_name = "compiled_scheduler" + vae_module_name = "compiled_vae" + return produce_img_split.format( + width=width, + height=height, + lw=int(width / 8), + lh=int(height / 8), + bd=int(batch_size * 2), + precision=precision, + batch_size=batch_size, + max_length=max_length, + unet_module=unet_module_name, + unet_function=unet_fn_name, + scheduler_module=scheduler_module_name, + vae_module=vae_module_name, + num_steps=num_steps, + ) diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py index ec88c525..3f0aaf7e 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py @@ -480,6 +480,7 @@ def export_submodel( self.hf_model_name, None, self.max_length, + self.batch_size, self.precision, "vmfb", self.external_weights, @@ -494,7 +495,6 @@ def export_submodel( input_mlir=input_mlir["prompt_encoder"], attn_spec=self.attn_spec, weights_only=weights_only, - batchsize=self.batch_size, batch_input=self.batch_prompt_input, ) return prompt_encoder_vmfb, prompt_encoder_external_weight_path diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py index 40ce6c2e..8f7668d5 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py @@ -231,7 +231,7 @@ def export_prompt_encoder( ) if weights_only: - return None, external_weight_path + return external_weight_path class CompiledClip(CompiledModule): if external_weights: @@ -277,7 +277,7 @@ def encode_prompts_turbo( module_str = str(module) if compile_to != "vmfb": - return module_str + return module_str, None else: vmfb_path = utils.compile_to_vmfb( module_str,
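For reference, a minimal driver sketch exercising the compiled-pipeline path this diff enables. Only `compiled_pipeline`, `prepare_all(num_steps=...)`, `load_map()`, and `generate_images()` are interfaces touched by the change; every other constructor/keyword argument below is an illustrative assumption, not part of this diff.

```python
# Sketch only: drives the fused "fullpipeline" route end to end.
# Constructor kwargs other than compiled_pipeline are assumed/illustrative.
from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline

num_steps = 20

pipe = SharkSDPipeline(
    "stabilityai/stable-diffusion-xl-base-1.0",  # hf_model_name
    height=1024,          # assumed kwargs; match your export_args
    width=1024,
    batch_size=1,
    max_length=64,
    precision="fp16",
    device="rocm",
    target="gfx942",
    compiled_pipeline=True,  # new flag: run scheduler+unet+vae through one fused vmfb
)

# num_steps is now threaded through prepare_all/export_submodel so the scf.for
# trip count baked into the produce_img_split module matches the scheduler.
pipe.prepare_all(num_steps=num_steps)
pipe.load_map()  # "fullpipeline" is loaded last, after the modules it wraps
images = pipe.generate_images(
    "a photo of a lighthouse at dawn",
    steps=num_steps,  # assumed keyword, mirroring prepare_sampling_inputs()
)
```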
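The new `produce_img_split` template can also be rendered directly for inspection with the updated `get_pipeline_ir()` signature. A small sketch follows; the concrete values (1024x1024, fp16, 20 steps) are examples only.

```python
# Renders the split-scheduler pipeline module as MLIR text.
from turbine_models.custom_models.sdxl_inference.pipeline_ir import get_pipeline_ir

mlir_str = get_pipeline_ir(
    width=1024,
    height=1024,
    precision="fp16",            # mapped to f16 element types in the IR
    batch_size=1,
    max_length=64,
    type="produce_img_split",
    num_steps=20,                # baked into the scf.for bound and the timesteps tensor shape
    unet_module_name="compiled_punet",
)

# The emitted module declares private funcs for
# compiled_scheduler.run_initialize / run_scale / run_step,
# {unet_module}.run_forward and compiled_vae.decode, plus the public
# @produce_image_latents entry point invoked by produce_images_compiled().
print(mlir_str)
```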
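Likewise, the `masked_attention` boolean in `utils.compile_to_vmfb()` is replaced by a `flagset_keywords` list, so callers opt into the new punet preprocessing pipeline by keyword. A hedged sketch of such a call; the argument layout is inferred from the surrounding call sites and the input `.mlir` path is hypothetical.

```python
from turbine_models.custom_models.sd_inference import utils

# Hypothetical path to exported (p)unet MLIR produced by an earlier export step.
module_str = open("stable_diffusion_xl_base_1_0_punet.mlir").read()

vmfb_path = utils.compile_to_vmfb(
    module_str,
    device="rocm",
    target_triple="gfx942",        # MI-series triple -> MI_flags branch
    ireec_flags=[],
    safe_name="stable_diffusion_xl_base_1_0_punet",
    return_path=True,
    mlir_source="str",
    flagset_keywords=["punet"],    # selects MI_flags["punet"] / GFX11_flags["punet"]
)
```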