Fix the compiled pipeline compilation issue #798

Status: Open. Wants to merge 1 commit into base: main.
models/turbine_models/custom_models/pipeline_base.py (36 changes: 30 additions & 6 deletions)
@@ -437,6 +437,7 @@ def prepare_all(
vmfbs: dict = {},
weights: dict = {},
interactive: bool = False,
num_steps: int = 20,
):
ready = self.is_prepared(vmfbs, weights)
match ready:
@@ -463,7 +464,9 @@ if not self.map[submodel].get("weights") and self.map[submodel][
if not self.map[submodel].get("weights") and self.map[submodel][
"export_args"
].get("external_weights"):
self.export_submodel(submodel, weights_only=True)
self.export_submodel(
submodel, weights_only=True, num_steps=num_steps
)
return self.prepare_all(mlirs, vmfbs, weights, interactive)

def is_prepared(self, vmfbs, weights):
@@ -581,6 +584,7 @@ def export_submodel(
submodel: str,
input_mlir: str = None,
weights_only: bool = False,
num_steps: int = 20,
):
if not os.path.exists(self.pipeline_dir):
os.makedirs(self.pipeline_dir)
@@ -670,7 +674,9 @@ self.map[submodel]["export_args"]["precision"],
self.map[submodel]["export_args"]["precision"],
self.map[submodel]["export_args"]["batch_size"],
self.map[submodel]["export_args"]["max_length"],
"tokens_to_image",
"produce_img_split",
unet_module_name=self.map["unet"]["module_name"],
num_steps=num_steps,
)
dims = [
self.map[submodel]["export_args"]["width"],
@@ -699,8 +705,8 @@ return_path=True,
return_path=True,
mlir_source="str",
)
self.map[submodel]["vmfb"] = vmfb_path
self.map[submodel]["weights"] = None
self.map[submodel]["vmfb"] = [vmfb_path]
self.map[submodel]["weights"] = []
case _:
export_args = self.map[submodel].get("export_args", {})
if weights_only:
@@ -721,10 +727,24 @@

# LOAD
def load_map(self):
for submodel in self.map.keys():
# Make sure fullpipeline is imported last
submodels = list(self.map.keys() - {"fullpipeline"})
submodels += ["fullpipeline"] if "fullpipeline" in self.map.keys() else []
for submodel in submodels:
if not self.map[submodel]["load"]:
self.printer.print("Skipping load for ", submodel)
self.printer.print(f"Skipping load for {submodel}")
continue
elif self.map[submodel].get("wraps"):
vmfbs = []
weights = []
for wrapped in self.map[submodel]["wraps"]:
vmfbs.append(self.map[wrapped]["vmfb"])
if "weights" in self.map[wrapped]:
weights.append(self.map[wrapped]["weights"])
self.map[submodel]["vmfb"] = vmfbs + self.map[submodel]["vmfb"]
self.map[submodel]["weights"] = weights + self.map[submodel]["weights"]

print(f"Loading {submodel}")
self.load_submodel(submodel)

def load_submodel(self, submodel):
@@ -751,6 +771,10 @@ def load_submodel(self, submodel):

def unload_submodel(self, submodel):
self.map[submodel]["runner"].unload()
self.map[submodel]["vmfb"] = None
self.map[submodel]["mlir"] = None
self.map[submodel]["weights"] = None
self.map[submodel]["export_args"]["input_mlir"] = None
setattr(self, submodel, None)


models/turbine_models/custom_models/sd_inference/sd_pipeline.py (187 changes: 133 additions & 54 deletions)
@@ -118,23 +118,11 @@
"decomp_attn": None,
},
},
"unetloop": {
"module_name": "sdxl_compiled_pipeline",
"load": False,
"keywords": ["unetloop"],
"wraps": ["unet", "scheduler"],
"export_args": {
"batch_size": 1,
"height": 1024,
"width": 1024,
"max_length": 64,
},
},
"fullpipeline": {
"module_name": "sdxl_compiled_pipeline",
"load": False,
"load": True,
"keywords": ["fullpipeline"],
"wraps": ["text_encoder", "unet", "scheduler", "vae"],
"wraps": ["unet", "scheduler", "vae"],
"export_args": {
"batch_size": 1,
"height": 1024,
@@ -190,6 +178,7 @@ def get_sd_model_map(hf_model_name):
"stabilityai/sdxl-turbo",
"stabilityai/stable-diffusion-xl-base-1.0",
"/models/SDXL/official_pytorch/fp16/stable_diffusion_fp16/checkpoint_pipe",
"/models/SDXL/official_pytorch/fp16/stable_diffusion_fp16//checkpoint_pipe",
]:
return sdxl_model_map
elif "stabilityai/stable-diffusion-3" in name:
@@ -233,6 +222,7 @@ def __init__(
benchmark: bool | dict[bool] = False,
verbose: bool = False,
batch_prompts: bool = False,
compiled_pipeline: bool = False,
):
common_export_args = {
"hf_model_name": None,
@@ -243,11 +233,11 @@
"exit_on_vmfb": False,
"pipeline_dir": pipeline_dir,
"input_mlir": None,
"attn_spec": None,
"attn_spec": attn_spec,
"external_weights": None,
"external_weight_path": None,
}
sd_model_map = get_sd_model_map(hf_model_name)
sd_model_map = copy.deepcopy(get_sd_model_map(hf_model_name))
for submodel in sd_model_map:
if "load" not in sd_model_map[submodel]:
sd_model_map[submodel]["load"] = True
@@ -311,6 +301,7 @@ def __init__(
self.scheduler = None

self.split_scheduler = True
self.compiled_pipeline = compiled_pipeline

self.base_model_name = (
hf_model_name
@@ -321,11 +312,6 @@
self.is_sdxl = "xl" in self.base_model_name.lower()
self.is_sd3 = "stable-diffusion-3" in self.base_model_name
if self.is_sdxl:
if self.split_scheduler:
if self.map.get("unetloop"):
self.map.pop("unetloop")
if self.map.get("fullpipeline"):
self.map.pop("fullpipeline")
self.tokenizers = [
CLIPTokenizer.from_pretrained(
self.base_model_name, subfolder="tokenizer"
@@ -339,6 +325,20 @@
self.scheduler_device = self.map["unet"]["device"]
self.scheduler_driver = self.map["unet"]["driver"]
self.scheduler_target = self.map["unet"]["target"]
if not self.compiled_pipeline:
if self.map.get("unetloop"):
self.map.pop("unetloop")
if self.map.get("fullpipeline"):
self.map.pop("fullpipeline")
elif self.compiled_pipeline:
self.map["unet"]["load"] = False
self.map["vae"]["load"] = False
self.load_scheduler(
scheduler_id,
num_inference_steps,
)
self.map["scheduler"]["runner"].unload()
self.map["scheduler"]["load"] = False
elif not self.is_sd3:
self.tokenizer = CLIPTokenizer.from_pretrained(
self.base_model_name, subfolder="tokenizer"
@@ -351,23 +351,27 @@

self.latents_dtype = torch_dtypes[self.latents_precision]
self.use_i8_punet = self.use_punet = use_i8_punet
if self.use_punet:
self.setup_punet()
else:
self.map["unet"]["keywords"].append("!punet")
self.map["unet"]["function_name"] = "run_forward"

def setup_punet(self):
if self.use_i8_punet:
self.map["unet"]["export_args"]["precision"] = "i8"
self.map["unet"]["export_args"]["use_punet"] = True
self.map["unet"]["use_weights_for_export"] = True
self.map["unet"]["keywords"].append("punet")
self.map["unet"]["module_name"] = "compiled_punet"
self.map["unet"]["function_name"] = "main"
self.map["unet"]["export_args"]["external_weight_path"] = (
utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa"
)
for idx, word in enumerate(self.map["unet"]["keywords"]):
if word in ["fp32", "fp16"]:
self.map["unet"]["keywords"][idx] = "i8"
break
else:
self.map["unet"]["keywords"].append("!punet")
self.map["unet"]["function_name"] = "run_forward"
self.map["unet"]["export_args"]["use_punet"] = True
self.map["unet"]["use_weights_for_export"] = True
self.map["unet"]["keywords"].append("punet")
self.map["unet"]["module_name"] = "compiled_punet"
self.map["unet"]["function_name"] = "main"

# LOAD

@@ -376,10 +380,6 @@ def load_scheduler(
scheduler_id: str,
steps: int = 30,
):
if self.is_sd3:
scheduler_device = self.mmdit.device
else:
scheduler_device = self.unet.device
if not self.cpu_scheduling:
self.map["scheduler"] = {
"module_name": "compiled_scheduler",
@@ -425,7 +425,11 @@
except:
print("JIT export of scheduler failed. Loading CPU scheduler.")
self.cpu_scheduling = True
if self.cpu_scheduling:
elif self.cpu_scheduling:
if self.is_sd3:
scheduler_device = self.mmdit.device
else:
scheduler_device = self.unet.device
scheduler = schedulers.get_scheduler(self.base_model_name, scheduler_id)
self.scheduler = schedulers.SharkSchedulerCPUWrapper(
scheduler,
@@ -461,13 +465,10 @@ def encode_prompts_sdxl(self, prompt, negative_prompt):
text_input_ids_list += text_inputs.input_ids.unsqueeze(0)
uncond_input_ids_list += uncond_input.input_ids.unsqueeze(0)

if self.compiled_pipeline:
Contributor: Is there any particular reason you are removing the non compiled_pipeline option here?

Contributor Author: It's removing the compiled_pipeline version - I think this is a relic of previous iterations when the text encoder was going to be part of the pipeline, but since it isn't we can run the same code for compiled and python pipelines.

return text_input_ids_list, uncond_input_ids_list
else:
prompt_embeds, add_text_embeds = self.text_encoder(
"encode_prompts", [*text_input_ids_list, *uncond_input_ids_list]
)
return prompt_embeds, add_text_embeds
prompt_embeds, add_text_embeds = self.text_encoder(
"encode_prompts", [*text_input_ids_list, *uncond_input_ids_list]
)
return prompt_embeds, add_text_embeds
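
Editor's note on the exchange above: a minimal usage sketch, not part of this diff, of how the single encode path now feeds both execution modes. It assumes an already-constructed pipeline object named pipe and hypothetical sample, steps, and guidance_scale values, and mirrors the generate_images change further down.

# Illustrative only: both modes consume the same encoder outputs.
prompt_embeds, add_text_embeds = pipe.encode_prompts_sdxl(prompt, negative_prompt)

if pipe.compiled_pipeline:
    # Compiled path: a single call into the fullpipeline VMFB produces the image.
    image = pipe.produce_images_compiled(
        sample, prompt_embeds, add_text_embeds, guidance_scale
    )
else:
    # Python path: run the scheduler/UNet loop, then decode the latents with the VAE.
    latents = pipe._produce_latents_sdxl(
        sample, prompt_embeds, add_text_embeds, steps, guidance_scale
    )
    image = pipe.vae("decode", [latents])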

def prepare_latents(
self,
@@ -565,9 +566,11 @@ def _produce_latents_sdxl(
[guidance_scale],
dtype=self.map["unet"]["np_dtype"],
)
# Disable progress bar if we aren't in verbose mode or if we're printing
# benchmark latencies for unet.
for i, t in tqdm(
enumerate(timesteps),
disable=(self.map["unet"].get("benchmark") and self.verbose),
disable=(self.map["unet"].get("benchmark") or not self.verbose),
):
if self.cpu_scheduling:
latent_model_input, t = self.scheduler.scale_model_input(
@@ -608,6 +611,75 @@
latents = self.scheduler("run_step", [noise_pred, t, latents])
return latents

def produce_images_compiled(
self,
sample,
prompt_embeds,
text_embeds,
guidance_scale,
):
pipe_inputs = [
sample,
prompt_embeds,
text_embeds,
torch.as_tensor([guidance_scale], dtype=sample.dtype),
Contributor: Is this necessary?

Contributor Author (@gpetters-amd, Aug 29, 2024): Yep, the compiled pipeline expects an fp16 tensor while passing the raw value produces an fp64 one.

]
# image = self.compiled_pipeline("produce_img_latents", pipe_inputs)
image = self.map["fullpipeline"]["runner"]("produce_image_latents", pipe_inputs)
return image
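
Editor's note on the dtype question above: a minimal sketch, not part of this diff, of why the explicit cast to sample.dtype is needed. It assumes the guidance scale reaches this point as a NumPy float64 value (for example from argument parsing); the tensor shape and names are illustrative only.

import numpy as np
import torch

# Hypothetical fp16 latent tensor and a guidance scale held as NumPy float64.
sample = torch.zeros(1, 4, 128, 128, dtype=torch.float16)
guidance = np.array([7.5])  # NumPy defaults to float64

raw = torch.as_tensor(guidance)                       # stays float64: mismatches an fp16 vmfb input
cast = torch.as_tensor(guidance, dtype=sample.dtype)  # float16: matches the compiled module's signature

assert raw.dtype == torch.float64
assert cast.dtype == torch.float16

Casting to sample.dtype rather than hard-coding float16 also keeps the call correct for fp32 builds.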

def prepare_sampling_inputs(
self,
prompt: str,
negative_prompt: str = "",
steps: int = 30,
batch_count: int = 1,
guidance_scale: float = 7.5,
seed: float = -1,
cpu_scheduling: bool = True,
scheduler_id: str = "EulerDiscrete",
return_imgs: bool = False,
):
needs_new_scheduler = (
(steps and steps != self.num_inference_steps)
or (cpu_scheduling != self.cpu_scheduling)
and self.split_scheduler
)
if not self.scheduler and not self.compiled_pipeline:
needs_new_scheduler = True

if guidance_scale == 0:
negative_prompt = prompt
prompt = ""

self.cpu_scheduling = cpu_scheduling
if steps and needs_new_scheduler:
self.num_inference_steps = steps
self.load_scheduler(scheduler_id, steps)

pipe_start = time.time()
numpy_images = []

samples = self.get_rand_latents(seed, batch_count)

# Tokenize prompt and negative prompt.
if self.is_sdxl:
prompt_embeds, negative_embeds = self.encode_prompts_sdxl(
prompt, negative_prompt
)
else:
prompt_embeds, negative_embeds = encode_prompt(
self, prompt, negative_prompt
)
produce_latents_input = [
samples[0],
prompt_embeds,
negative_embeds,
steps,
guidance_scale,
]
return produce_latents_input

def generate_images(
self,
prompt: str,
@@ -653,18 +725,23 @@ def generate_images(
)

for i in range(batch_count):
produce_latents_input = [
samples[i],
prompt_embeds,
negative_embeds,
steps,
guidance_scale,
]
if self.is_sdxl:
latents = self._produce_latents_sdxl(*produce_latents_input)
if self.compiled_pipeline:
image = self.produce_images_compiled(
samples[i], prompt_embeds, negative_embeds, guidance_scale
).to_host()
else:
latents = self._produce_latents_sd(*produce_latents_input)
image = self.vae("decode", [latents])
produce_latents_input = [
samples[i],
prompt_embeds,
negative_embeds,
steps,
guidance_scale,
]
if self.is_sdxl:
latents = self._produce_latents_sdxl(*produce_latents_input)
else:
latents = self._produce_latents_sd(*produce_latents_input)
image = self.vae("decode", [latents])
numpy_images.append(image)
pipe_end = time.time()

@@ -750,8 +827,10 @@ def numpy_to_pil_image(images):
args.use_i8_punet,
benchmark,
args.verbose,
False,
args.compiled_pipeline,
)
sd_pipe.prepare_all()
sd_pipe.prepare_all(num_steps=args.num_inference_steps)
sd_pipe.load_map()
sd_pipe.generate_images(
args.prompt,