Compiling StableDiffusionXL unet(torch.float16) failed. #1035

Open
newgrit1004 opened this issue Nov 18, 2024 · 1 comment
Labels
bug (Something isn't working), Inf2

Comments

@newgrit1004

Hi, I tried compiling the unet (torch.float16) of StableDiffusionXLPipeline on an inf2.8xlarge instance, and it failed.

When the latent size of the unet is (64, 64), compilation succeeds.
However, when the latent size is (128, 128), it fails.

[Error message]

(aws_neuron_venv_pytorch) [ec2-user@ip-172-31-32-56 ~]$ python compile_neuron_sdxl_base_fp16.py 
/home/ec2-user/aws_neuron_venv_pytorch/lib64/python3.9/site-packages/huggingface_hub/file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:05<00:00,  1.33it/s]
/home/ec2-user/compile_neuron_sdxl_base_fp16.py:116: FutureWarning: Accessing config attribute `in_channels` directly via 'UNet2DConditionModel' object attribute is deprecated. Please access 'in_channels' over 'UNet2DConditionModel's config object instead, e.g. 'unet.config.in_channels'.
  self.in_channels = unetwrap.unet.in_channels
.....................................
Compiler status PASS
783.619141103
.....................root = neuronxcc/starfish/penguin/targets/codegen/BirCodeGenLoop.py
root = neuronxcc/starfish/penguin/targets/codegen
root = neuronxcc/starfish/penguin/targets
root = neuronxcc/starfish/penguin
root = neuronxcc/starfish

[TEN404] (_add.23504) Internal tensorizer error: BirCodeGenLoop:too many partition dims! {{0,+,6}[4],+,48}[16] - Please open a support ticket at https://github.com/aws-neuron/aws-neuron-sdk/issues/new. You may also be able to obtain more information using the 'XLA_IR_DEBUG' and 'XLA_HLO_DEBUG' environment variables.
Traceback (most recent call last):
  File "/home/ec2-user/compile_neuron_sdxl_base_fp16.py", line 183, in <module>
    unet_neuron = torch_neuronx.trace(
  File "/home/ec2-user/aws_neuron_venv_pytorch/lib64/python3.9/site-packages/torch_neuronx/xla_impl/trace.py", line 574, in trace
    neff_filename, metaneff, flattener, packer, weights = _trace(
  File "/home/ec2-user/aws_neuron_venv_pytorch/lib64/python3.9/site-packages/torch_neuronx/xla_impl/trace.py", line 639, in _trace
    neff_artifacts = generate_neff(
  File "/home/ec2-user/aws_neuron_venv_pytorch/lib64/python3.9/site-packages/torch_neuronx/xla_impl/trace.py", line 492, in generate_neff
    neff_filename = hlo_compile(
  File "/home/ec2-user/aws_neuron_venv_pytorch/lib64/python3.9/site-packages/torch_neuronx/xla_impl/trace.py", line 394, in hlo_compile
    raise RuntimeError(f"neuronx-cc failed with {status}")
RuntimeError: neuronx-cc failed with 70
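As the compiler message suggests, more detail about the failing HLO may be obtainable by setting the XLA_IR_DEBUG and XLA_HLO_DEBUG environment variables before tracing. A minimal sketch (for illustration; I have not confirmed how much extra output they produce for this error):

import os
os.environ["XLA_IR_DEBUG"] = "1"
os.environ["XLA_HLO_DEBUG"] = "1"
# ... then run torch_neuronx.trace as in the script below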

Environment

AMI : Deep Learning AMI Neuron (Amazon Linux 2023)

# Configure Linux for Neuron repository updates
sudo tee /etc/yum.repos.d/neuron.repo > /dev/null <<EOF
[neuron]
name=Neuron YUM Repository
baseurl=https://yum.repos.neuron.amazonaws.com
enabled=1
metadata_expire=0
EOF
sudo rpm --import https://yum.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB

# Update OS packages 
sudo yum update -y

# Install OS headers 
sudo yum install kernel-devel-$(uname -r) kernel-headers-$(uname -r) -y

# Install git 
sudo yum install git -y

# install Neuron Driver
sudo yum install aws-neuronx-dkms-2.* -y

# Install Neuron Runtime 
sudo yum install aws-neuronx-collectives-2.* -y
sudo yum install aws-neuronx-runtime-lib-2.* -y

# Install Neuron Tools 
sudo yum install aws-neuronx-tools-2.* -y

# Add PATH
export PATH=/opt/aws/neuron/bin:$PATH

# Install PyTorch Neuron (choose the commands for your PyTorch version from the linked docs and run them)
# PyTorch 2.1.2 is selected here
# Install External Dependency
sudo yum install -y libxcrypt-compat

# Install Python venv 
sudo yum install -y gcc-c++ 

# Create Python venv
python3.9 -m venv aws_neuron_venv_pytorch 

# Activate Python venv 
source aws_neuron_venv_pytorch/bin/activate 
python -m pip install -U pip 

# Install Jupyter notebook kernel
pip install ipykernel 
python3.9 -m ipykernel install --user --name aws_neuron_venv_pytorch --display-name "Python (torch-neuronx)"
pip install jupyter notebook
pip install environment_kernels

# Set pip repository pointing to the Neuron repository 
python -m pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com

# Install wget, awscli 
python -m pip install wget 
python -m pip install awscli 

# Install Neuron Compiler and Framework
python -m pip install neuronx-cc==2.* torch-neuronx torchvision

pip install diffusers==0.20.0 transformers==4.26.1 accelerate==0.16.0 matplotlib
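
To make the report reproducible, the exact installed versions can be printed with a short snippet (for illustration; the distribution names are assumed to match the pip packages above):

from importlib.metadata import version
for pkg in ("torch", "torch-neuronx", "neuronx-cc", "diffusers", "transformers"):
    print(pkg, version(pkg))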

Python Script

import os
import diffusers
import math
import time
import torch_neuronx
import copy
import torch
import torch.nn as nn
from diffusers.models.attention_processor import Attention
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
import torch.nn.functional as F
from diffusers import DiffusionPipeline


def apply_neuron_attn_override(
    diffusers_pkg, get_attn_scores_func, neuron_scaled_dot_product_attention
):
    # The original example checked the diffusers version here; diffusers==0.20.0
    # is pinned above (>= 0.18.0), so the new attention_processor layout applies.
    use_new_diffusers = True
    if use_new_diffusers:
        diffusers_pkg.models.attention_processor.Attention.get_attention_scores = (
            get_attn_scores_func
        )
    else:
        diffusers_pkg.models.cross_attention.CrossAttention.get_attention_scores = (
            get_attn_scores_func
        )

    # If PyTorch 2 is available, F.scaled_dot_product_attention will be used, so we
    # need to monkey-patch that too with the Neuron-optimized attention
    if hasattr(F, "scaled_dot_product_attention"):
        F.scaled_dot_product_attention = neuron_scaled_dot_product_attention

def get_attention_scores_neuron(self, query, key, attn_mask=None):
    if query.size() == key.size():
        attention_scores = custom_badbmm(
            key,
            query.transpose(-1, -2),
            self.scale
        )
        attention_probs = attention_scores.softmax(dim=1).permute(0, 2, 1)
    else:
        attention_scores = custom_badbmm(
            query,
            key.transpose(-1, -2),
            self.scale
        )
        attention_probs = attention_scores.softmax(dim=-1)
    return attention_probs

def custom_badbmm(a, b, scale):
    # torch.baddbmm without the bias term: a batched matmul followed by an
    # explicit scale, which the original example uses in place of baddbmm
    bmm = torch.bmm(a, b)
    scaled = bmm * scale
    return scaled

def neuron_scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
):
    orig_shape = None
    if len(query.shape) == 4:
        orig_shape = query.shape

        def to3d(x):
            return x.reshape(-1, x.shape[2], x.shape[3])

        query, key, value = map(to3d, [query, key, value])

    if query.size() == key.size():
        attention_scores = torch.bmm(key, query.transpose(-1, -2)) * (
            1 / math.sqrt(query.size(-1))
        )
        attention_probs = attention_scores.softmax(dim=1).permute(0, 2, 1)

    else:
        attention_scores = torch.bmm(query, key.transpose(-1, -2)) * (
            1 / math.sqrt(query.size(-1))
        )
        attention_probs = attention_scores.softmax(dim=-1)

    attn_out = torch.bmm(attention_probs, value)

    if orig_shape:
        attn_out = attn_out.reshape(
            orig_shape[0], orig_shape[1], attn_out.shape[1], attn_out.shape[2]
        )

    return attn_out

# Replace original cross-attention module with custom cross-attention module for better performance
apply_neuron_attn_override(
    diffusers, get_attention_scores_neuron, neuron_scaled_dot_product_attention
)
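
# Sanity check (added for illustration, not part of the original script):
# confirm both monkey patches took effect before anything is traced.
assert Attention.get_attention_scores is get_attention_scores_neuron
assert F.scaled_dot_product_attention is neuron_scaled_dot_product_attention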

class UNetWrap(nn.Module):
    def __init__(self, unet):
        super().__init__()
        self.unet = unet

    def forward(self, sample, timestep, encoder_hidden_states, text_embeds=None, time_ids=None):
        out_tuple = self.unet(
            sample,
            timestep,
            encoder_hidden_states,
            added_cond_kwargs={"text_embeds": text_embeds, "time_ids": time_ids},
            return_dict=False
        )
        return out_tuple

class NeuronUNet(nn.Module):
    def __init__(self, unetwrap):
        super().__init__()
        self.unetwrap = unetwrap
        self.config = unetwrap.unet.config
        self.in_channels = unetwrap.unet.config.in_channels  # via config, avoiding the diffusers deprecation warning
        self.add_embedding = unetwrap.unet.add_embedding
        self.device = unetwrap.unet.device

    def forward(self, sample, timestep, encoder_hidden_states, added_cond_kwargs=None, return_dict=False, cross_attention_kwargs=None):
        sample = self.unetwrap(
            sample,
            timestep.float().expand((sample.shape[0],)),
            encoder_hidden_states,
            added_cond_kwargs["text_embeds"],
            added_cond_kwargs["time_ids"]
        )[0]
        return UNet2DConditionOutput(sample=sample)


class TextEncoderOutputWrapper(nn.Module):
    def __init__(self, traceable_text_encoder, original_text_encoder):
        super().__init__()
        self.traceable_text_encoder = traceable_text_encoder
        self.config = original_text_encoder.config
        self.dtype = original_text_encoder.dtype
        self.device = original_text_encoder.device

    def forward(self, text_input_ids, output_hidden_states=True):
        out_tuple = self.traceable_text_encoder(text_input_ids)
        return CLIPTextModelOutput(text_embeds=out_tuple[0], last_hidden_state=out_tuple[1], hidden_states=out_tuple[2])

class TraceableTextEncoder(nn.Module):
    def __init__(self, text_encoder):
        super().__init__()
        self.text_encoder = text_encoder

    def forward(self, text_input_ids):
        out_tuple = self.text_encoder(text_input_ids, output_hidden_states=True, return_dict=False)
        return out_tuple


# For saving compiler artifacts
COMPILER_WORKDIR_ROOT = 'sdxl_base_compile_dir_1024_fp16'
DTYPE = torch.float16  # the failing dtype; torch.float32 and torch.bfloat16 compile fine
log_file = "compile_times_base.txt"
with open(log_file, "w") as f:
    f.write("model compile_time\n")

model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=DTYPE, low_cpu_mem_usage=True)

# Replace original cross-attention module with custom cross-attention module for better performance
Attention.get_attention_scores = get_attention_scores_neuron
pipe.unet = NeuronUNet(UNetWrap(pipe.unet))
unet = copy.deepcopy(pipe.unet.unetwrap)
del pipe


bucket_sizes = [(64, 64), (128, 128)]  # latent = image size / 8 (VAE downscale): 512/8=64, 1024/8=128
for bucket_size in bucket_sizes:
    start_time = time.perf_counter()
    h, w = bucket_size
    sample_1b = torch.randn([1, 4, h, w], dtype=DTYPE)
    timestep_1b = torch.tensor(999).float().expand((1,))
    encoder_hidden_states_1b = torch.randn([1, 77, 2048], dtype=DTYPE)
    added_cond_kwargs_1b = {"text_embeds": torch.randn([1, 1280], dtype=DTYPE),
                            "time_ids": torch.randn([1, 6], dtype=DTYPE)}
    example_inputs = (sample_1b, timestep_1b, encoder_hidden_states_1b, added_cond_kwargs_1b["text_embeds"], added_cond_kwargs_1b["time_ids"],)

    unet_neuron = torch_neuronx.trace(
        unet,
        example_inputs,
        compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, str(h)+'_unet'),
        compiler_args=["--model-type=unet-inference"]
    )
    # Enable asynchronous and lazy loading to speed up model load
    torch_neuronx.async_load(unet_neuron)
    torch_neuronx.lazy_load(unet_neuron)

    unet_filename = os.path.join(COMPILER_WORKDIR_ROOT, str(h)+'_unet/model.pt')
    torch.jit.save(unet_neuron, unet_filename)
    elapsed = time.perf_counter() - start_time
    print(elapsed)
    with open(log_file, "a") as f:
        f.write(f"UNet compile - bucket size: {bucket_size}, compile time: {elapsed:.2f}s\n")
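
Once saved, the artifact can be loaded back with torch.jit.load for inference (a sketch for illustration; the inputs must match the traced bucket shapes):

neuron_unet = torch.jit.load(unet_filename)
sample_out = neuron_unet(*example_inputs)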

Most of the code is the same as this code.

I was able to compile the unet with torch.float32, but not with torch.float16 when the latent size is (128, 128).

@newgrit1004
Author

Additionally, I can compile the unet when the dtype is torch.bfloat16 and the latent size is (128, 128).
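
Since bf16 compiles, a possible workaround to experiment with is tracing the unet in fp32 and letting the compiler auto-cast matmuls to bf16 (an untested sketch reusing the variables from the script above; --auto-cast and --auto-cast-type are standard neuronx-cc options):

unet_neuron = torch_neuronx.trace(
    unet.float(),  # weights in fp32 at trace time
    tuple(x.float() if torch.is_tensor(x) else x for x in example_inputs),
    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, str(h) + '_unet'),
    compiler_args=["--model-type=unet-inference",
                   "--auto-cast=matmult", "--auto-cast-type=bf16"],
)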

@devesr-amzn added the bug (Something isn't working) and Inf2 labels on Nov 18, 2024