
Commit

fix non-square issue
ShenZhang-Shin authored Jun 16, 2024
1 parent e6dba21 commit defa26b
Showing 1 changed file with 24 additions and 22 deletions: hidiffusion/hidiffusion.py
@@ -12,7 +12,7 @@
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.models import ControlNetModel

import warnings
diffusers_version = diffusers.__version__
if diffusers_version < "0.27.0":
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
@@ -1284,15 +1284,19 @@ def window_partition(x, window_size, shift_size, H, W):
windows: (num_windows*B, window_size, window_size, C)
"""
B, N, C = x.shape
# H, W = int(N**0.5), int(N**0.5)
x = x.view(B,H,W,C)
if H % 2 != 0 or W % 2 != 0:
warnings.warn(
f"HiDiffusion Warning: The feature size is {(H,W)} and cannot be directly partitioned into windows. We interpolate the size to {(window_size[0]*2, window_size[1]*2)} to enable the window partition. Even though the generation is OK, the image quality would be largely decreased. We sugget removing window attention by setting apply_hidiffusion(pipe, apply_window_attn=False) for better image quality."
)
x = F.interpolate(x.permute(0,3,1,2).contiguous(), size=(window_size[0]*2, window_size[1]*2), mode='bicubic').permute(0,2,3,1).contiguous()
if type(shift_size) == list or type(shift_size) == tuple:
if shift_size[0] > 0:
x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
else:
if shift_size > 0:
x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
x = x.view(B, 2, window_size[0], 2, window_size[1], C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
windows = windows.view(-1, window_size[0] * window_size[1], C)
return windows
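The core of the fix: window_partition no longer assumes the token grid is square and evenly divisible. A minimal standalone sketch of the new partition path (editor's illustration, not part of the commit; the odd 21x15 feature size is assumed for demonstration):

import torch
import torch.nn.functional as F

B, C = 1, 4
H, W = 21, 15                        # odd, non-square feature map
window_size = (H // 2, W // 2)       # (10, 7)

x = torch.randn(B, H * W, C).view(B, H, W, C)
if H % 2 != 0 or W % 2 != 0:
    # resize so the map splits exactly into a 2x2 grid of windows,
    # mirroring the warning branch added above
    x = F.interpolate(x.permute(0, 3, 1, 2).contiguous(),
                      size=(window_size[0] * 2, window_size[1] * 2),
                      mode='bicubic').permute(0, 2, 3, 1).contiguous()
x = x.view(B, 2, window_size[0], 2, window_size[1], C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], C)
print(windows.shape)  # torch.Size([4, 70, 4])

Note the hard-coded 2x2 grid: the rewritten view(B, 2, window_size[0], 2, window_size[1], C) always yields four windows per sample, which is what makes the interpolation target size well defined.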
@@ -1311,15 +1315,17 @@ def window_reverse(windows, window_size, H, W, shift_size):
"""
B, N, C = windows.shape
windows = windows.view(-1, window_size[0], window_size[1], C)
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
B = int(windows.shape[0] / 4) # 2x2
x = windows.view(B, 2, 2, window_size[0], window_size[1], -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, window_size[0]*2, window_size[1]*2, -1)
if type(shift_size) == list or type(shift_size) == tuple:
if shift_size[0] > 0:
x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
else:
if shift_size > 0:
x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2))
if H % 2 != 0 or W % 2 != 0:
x = F.interpolate(x.permute(0,3,1,2).contiguous(), size=(H, W), mode='bicubic').permute(0,2,3,1).contiguous()
x = x.view(B, H*W, C)
return x
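window_reverse mirrors this: it now inverts the fixed 2x2 split and, when the original (H, W) was odd, interpolates back to it. A matching sketch under the same assumptions as above (illustration only, not part of the commit):

import torch
import torch.nn.functional as F

H, W, C = 21, 15, 4
window_size = (H // 2, W // 2)                                # (10, 7)
windows = torch.randn(4, window_size[0] * window_size[1], C)  # as produced by window_partition
B = windows.shape[0] // 4                                     # four windows per sample
x = windows.view(B, 2, 2, window_size[0], window_size[1], -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, window_size[0] * 2, window_size[1] * 2, -1)
if H % 2 != 0 or W % 2 != 0:
    # undo the partition-time resize
    x = F.interpolate(x.permute(0, 3, 1, 2).contiguous(), size=(H, W),
                      mode='bicubic').permute(0, 2, 3, 1).contiguous()
x = x.view(B, H * W, C)
print(x.shape)  # torch.Size([1, 315, 4])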

@@ -1353,9 +1359,9 @@ def window_reverse(windows, window_size, H, W, shift_size):
rand_num = torch.rand(1)
B, N, C = hidden_states.shape
ori_H, ori_W = self.info['size']
downsample_ratio = int(((ori_H*ori_W) // N)**0.5)
H, W = (ori_H//downsample_ratio, ori_W//downsample_ratio)
widow_size = (H//2, W//2)
downsample_ratio = round(((ori_H*ori_W) / N)**0.5)
H, W = (math.ceil(ori_H/downsample_ratio), math.ceil(ori_W/downsample_ratio))
widow_size = (math.ceil(H/2), math.ceil(W/2))
if rand_num <= 0.25:
shift_size = (0,0)
if rand_num > 0.25 and rand_num <= 0.5:
@@ -1365,7 +1371,6 @@ def window_reverse(windows, window_size, H, W, shift_size):
if rand_num > 0.75 and rand_num <= 1:
shift_size = (widow_size[0]//4*3, widow_size[1]//4*3)
norm_hidden_states = window_partition(norm_hidden_states, widow_size, shift_size, H, W)
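The size computation above is what the fix hinges on: int(N**0.5) recovers H and W only for square maps, while deriving the downsample ratio from the stored original size (and rounding up with math.ceil) handles any aspect ratio. A small sketch of the arithmetic, with a 96x64 latent size assumed for illustration:

import math

ori_H, ori_W = 96, 64     # original latent size, as stored in self.info['size']
N = 48 * 32               # token count at the current UNet level
downsample_ratio = round(((ori_H * ori_W) / N) ** 0.5)                           # 2
H, W = math.ceil(ori_H / downsample_ratio), math.ceil(ori_W / downsample_ratio)  # (48, 32)
window_size = (math.ceil(H / 2), math.ceil(W / 2))                               # (24, 16)
print(int(N ** 0.5))      # 39: the old square-map assumption gets both H and W wrong
print(downsample_ratio, (H, W), window_size)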

# 1. Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

@@ -1570,9 +1575,11 @@ def custom_forward(*inputs):

if i == 0:
if self.aggressive_raunet and self.timestep >= self.T1_start and self.timestep < self.T1_end:
hidden_states = F.avg_pool2d(hidden_states, kernel_size=(2,2))
self.info["upsample_size"] = (hidden_states.shape[2], hidden_states.shape[3])
hidden_states = F.avg_pool2d(hidden_states, kernel_size=(2,2),ceil_mode=True)
elif self.timestep < self.T1:
hidden_states = F.avg_pool2d(hidden_states, kernel_size=(2,2))
self.info["upsample_size"] = (hidden_states.shape[2], hidden_states.shape[3])
hidden_states = F.avg_pool2d(hidden_states, kernel_size=(2,2),ceil_mode=True)
output_states = output_states + (hidden_states,)

if self.downsamplers is not None:
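Two changes in this down block: the pre-pool spatial size is recorded in self.info["upsample_size"], and the pooling uses ceil_mode=True so an odd trailing row or column is kept rather than dropped. A sketch of the shape behavior (the 27x19 size is assumed for illustration, and info stands in for self.info):

import torch
import torch.nn.functional as F

h = torch.randn(1, 4, 27, 19)                       # odd H and W
info = {"upsample_size": (h.shape[2], h.shape[3])}  # record (27, 19) before pooling
print(F.avg_pool2d(h, kernel_size=(2, 2)).shape[-2:])                  # (13, 9): floors
print(F.avg_pool2d(h, kernel_size=(2, 2), ceil_mode=True).shape[-2:])  # (14, 10): ceils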
@@ -1719,12 +1726,9 @@ def custom_forward(*inputs):

if i == 1:
if self.aggressive_raunet and self.timestep >= self.T1_start and self.timestep < self.T1_end:
re_size = (int(hidden_states.shape[-2] * 2), int(hidden_states.shape[-1] * 2))
hidden_states = F.interpolate(hidden_states, size=re_size, mode='bicubic')
hidden_states = F.interpolate(hidden_states, size=self.info["upsample_size"], mode='bicubic')
elif self.timestep < self.T1:
re_size = (int(hidden_states.shape[-2] * 2), int(hidden_states.shape[-1] * 2))
hidden_states = F.interpolate(hidden_states, size=re_size, mode='bicubic')

hidden_states = F.interpolate(hidden_states, size=self.info["upsample_size"], mode='bicubic')
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
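On the way back up, doubling the pooled size no longer matches an odd original (2 * 14 = 28, not 27), which is why the hard-coded 2x upsample is replaced by an interpolation to the recorded size. Continuing the sketch above (same assumed shapes):

import torch
import torch.nn.functional as F

info = {"upsample_size": (27, 19)}            # as recorded by the down block
h = torch.randn(1, 4, 14, 10)                 # ceil-pooled from (27, 19)
naive = F.interpolate(h, size=(h.shape[-2] * 2, h.shape[-1] * 2), mode='bicubic')
exact = F.interpolate(h, size=info["upsample_size"], mode='bicubic')
print(naive.shape[-2:], exact.shape[-2:])     # (28, 20) vs (27, 19)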
@@ -1874,9 +1878,6 @@ def forward(self, hidden_states: torch.Tensor, scale = 1.0) -> torch.Tensor:
self.T1 = int(aggressive_step/50 * self.max_timestep)
else:
self.T1 = int(self.max_timestep * self.T1_ratio)
if self.timestep < self.T1:
if ori_H != hidden_states.shape[2] and ori_W != hidden_states.shape[3]:
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode='bicubic')
self.timestep += 1
if self.timestep == self.max_timestep:
self.timestep = 0
@@ -1939,8 +1940,8 @@ def apply_hidiffusion(
model.unet.__class__ = make_block_fn(model.unet.__class__)
diffusion_model = model.unet if hasattr(model, "unet") else model

# force forward_upsample_size=True, see unet_2d_condition.py in diffusers
diffusion_model.num_upsamplers += 2
# Hack, avoid non-square problem. See unet_2d_condition.py in diffusers
diffusion_model.num_upsamplers += 12
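Why the counter is inflated: diffusers' UNet2DConditionModel forwards explicit per-block upsample sizes through the up blocks only when the input resolution is not a multiple of 2**num_upsamplers. Bumping the counter makes that check trip for every input, so the sizes recorded above are always honored. A simplified paraphrase of the gate (the helper name is ours; see forward_upsample_size in unet_2d_condition.py, which varies by diffusers version):

def needs_explicit_upsample_size(sample_hw, num_upsamplers):
    # explicit sizes are forwarded only if H or W is not a multiple
    # of the overall up factor 2**num_upsamplers
    factor = 2 ** num_upsamplers
    return any(s % factor != 0 for s in sample_hw)

print(needs_explicit_upsample_size((64, 64), 3))       # False: sizes inferred by doubling
print(needs_explicit_upsample_size((64, 64), 3 + 12))  # True: the hack forces explicit sizes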

name_or_path = model.name_or_path
diffusion_model_module_key = []
@@ -1954,6 +1955,7 @@

diffusion_model.info = {
'size': None,
'upsample_size': None,
'hooks': [],
'text_to_img_controlnet': hasattr(model, 'controlnet'),
'is_inpainting_task': 'inpainting' in model.name_or_path,
