shift to context marker for cond-ucond (#1337)
lllyasviel authored May 20, 2023
1 parent cd98a95 commit 539d2fc
Showing 3 changed files with 67 additions and 92 deletions.
8 changes: 1 addition & 7 deletions scripts/controlnet.py
@@ -1200,8 +1200,6 @@ def process(self, p, *args):
             if unit.module == 'clip_vision':
                 detected_maps.append((processor.clip_vision_visualization(detected_map), unit.module))
 
-            is_vanilla_samplers = p.sampler_name in ["DDIM", "PLMS", "UniPC"]
-
             control_model_type = ControlModelType.ControlNet
 
             if isinstance(model_net, PlugableAdapter):
@@ -1238,10 +1236,6 @@ def process(self, p, *args):
                 control_model_type=control_model_type,
                 global_average_pooling=global_average_pooling,
                 hr_hint_cond=hr_control,
-                batch_size=p.batch_size,
-                instance_counter=0,
-                is_vanilla_samplers=is_vanilla_samplers,
-                cfg_scale=p.cfg_scale,
                 soft_injection=control_mode != external_code.ControlMode.BALANCED,
                 cfg_injection=control_mode == external_code.ControlMode.CONTROL,
             )
@@ -1250,7 +1244,7 @@ def process(self, p, *args):
         del model_net
 
         self.latest_network = UnetHook(lowvram=hook_lowvram)
-        self.latest_network.hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params)
+        self.latest_network.hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params, process=p)
         self.detected_map = detected_maps
 
     def postprocess(self, p, processed, *args):
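The new process=p argument is what lets hook.py (below) stash and wrap the processing object's sample method, so the conditioning tensors can be marked before sampling begins. A minimal sketch of that wrap-once pattern, with hypothetical names (Process, install_hook) standing in for the real A1111 objects:

    class Process:
        def sample(self, conditioning=None, unconditional_conditioning=None):
            return conditioning, unconditional_conditioning

    def install_hook(process):
        # stash the original exactly once, so hooking repeatedly stays safe
        if getattr(process, 'sample_before_CN_hack', None) is None:
            process.sample_before_CN_hack = process.sample

        def process_sample(*args, **kwargs):
            # the real hook calls mark_prompt_context() on both tensors here
            kwargs['conditioning'] = ('marked', kwargs['conditioning'])
            kwargs['unconditional_conditioning'] = ('marked', kwargs['unconditional_conditioning'])
            return process.sample_before_CN_hack(*args, **kwargs)

        process.sample = process_sample

    p = Process()
    install_hook(p)
    install_hook(p)  # idempotent: the original sample is still reachable
    print(p.sample(conditioning='c', unconditional_conditioning='uc'))
    # (('marked', 'c'), ('marked', 'uc'))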
2 changes: 1 addition & 1 deletion scripts/controlnet_version.py
@@ -1,4 +1,4 @@
-version_flag = 'v1.1.180'
+version_flag = 'v1.1.181'
 print(f'ControlNet {version_flag}')
 # A smart trick to know if user has updated as well as if user has restarted terminal.
 # Note that in "controlnet.py" we do NOT use "importlib.reload" to reload this "controlnet_version.py"
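The two comment lines kept above describe a deliberate module-level side effect: the banner prints only when Python first imports the file, and because controlnet.py never reloads it, a stale version string after a git pull means the terminal was never restarted. A quick illustration of why a repeat import stays silent (assuming the scripts directory is importable):

    import importlib
    import controlnet_version                      # first import executes the module: prints the banner
    importlib.import_module('controlnet_version')  # already cached in sys.modules: prints nothing
    # Only importlib.reload(controlnet_version) would re-run the print,
    # which is exactly what the extension avoids doing.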
149 changes: 65 additions & 84 deletions scripts/hook.py
@@ -11,6 +11,51 @@
 from ldm.modules.diffusionmodules.openaimodel import UNetModel
 from ldm.modules.attention import BasicTransformerBlock
 
+from modules.prompt_parser import MulticondLearnedConditioning, ComposableScheduledPromptConditioning, ScheduledPromptConditioning
+
+
+POSITIVE_MARK_TOKEN = 1024
+NEGATIVE_MARK_TOKEN = -POSITIVE_MARK_TOKEN
+MARK_EPS = 1e-3
+
+
+def prompt_context_is_marked(x):
+    m = torch.abs(x[0]) - POSITIVE_MARK_TOKEN
+    m = torch.mean(torch.abs(m)).detach().cpu().float().numpy()
+    return float(m) < MARK_EPS
+
+
+def mark_prompt_context(x, mark):
+    if isinstance(x, list):
+        for i in range(len(x)):
+            x[i] = mark_prompt_context(x[i], mark)
+        return x
+    if isinstance(x, MulticondLearnedConditioning):
+        for i in range(len(x.batch)):
+            x.batch[i] = mark_prompt_context(x.batch[i], mark)
+        return x
+    if isinstance(x, ComposableScheduledPromptConditioning):
+        for i in range(len(x.schedules)):
+            x.schedules[i] = mark_prompt_context(x.schedules[i], mark)
+        return x
+    if isinstance(x, ScheduledPromptConditioning):
+        if prompt_context_is_marked(x.cond):
+            return x
+        m = torch.zeros_like(x.cond)[:1] + mark
+        cond = torch.cat([m, x.cond], dim=0)
+        return ScheduledPromptConditioning(end_at_step=x.end_at_step, cond=cond)
+    return x
+
+
+def unmark_prompt_context(x):
+    mark = x[:, 0, :]
+    context = x[:, 1:, :]
+    mark = torch.mean(torch.abs(mark - NEGATIVE_MARK_TOKEN), dim=1)
+    mark = (mark > MARK_EPS).float()
+    mark_batch = mark[:, None, None, None]
+    uc_indices = mark.detach().cpu().numpy().tolist()
+    uc_indices = [i for i, item in enumerate(uc_indices) if item < 0.5]
+    return mark_batch, uc_indices, context
 
 
 class ControlModelType(Enum):
     """
@@ -81,12 +126,9 @@ def __init__(
             control_model_type,
             hr_hint_cond,
             global_average_pooling,
-            batch_size,
-            instance_counter,
-            is_vanilla_samplers,
-            cfg_scale,
             soft_injection,
-            cfg_injection
+            cfg_injection,
+            **kwargs  # To avoid errors
     ):
         self.control_model = control_model
         self._hint_cond = hint_cond
@@ -100,80 +142,8 @@ def __init__(
         self.hr_hint_cond = hr_hint_cond
         self.used_hint_cond = None
         self.used_hint_cond_latent = None
-        self.batch_size = batch_size
-        self.instance_counter = instance_counter
-        self.is_vanilla_samplers = is_vanilla_samplers
-        self.cfg_scale = cfg_scale
         self.soft_injection = soft_injection
         self.cfg_injection = cfg_injection
-        self.override_uc_mask = None
-
-    def override_controlnet_cond_uncond_counter(self, counter: list):
-        '''
-        Override the cond-uncond counter of a ControlNet unit.
-        If any other extension wants to call ControlNet's forward function, it is
-        highly recommended to manage the cond/uncond sequence explicitly to avoid
-        unexpected CN behaviors.
-        For example, suppose you call
-            forward(x, timesteps, context)
-        where x has shape [B, C, H, W]. If you hacked the ``UnetHook'' of ControlNet
-        and, say, B = 3 (a batch of 3 instances) where the first and second instances
-        are cond (using the positive prompt) and the third is uncond (using the
-        negative prompt), you can write the sequence as [True, True, False],
-        where True means cond and False means uncond.
-        In that case, you can call
-            my_counter = [True, True, False]
-            for control_param in unetHook.control_params:
-                control_param.override_controlnet_cond_uncond_counter(my_counter)
-            .... = forward(x, timesteps, context)
-        and all future cond-uncond behaviors of this control_param (this ControlNet
-        unit) are overridden until the control_param is disposed of by Python
-        garbage collection.
-        If you do not call ``override_controlnet_cond_uncond_counter'', ControlNet
-        counts cond-uncond using the batch size from A1111's Gradio UI, which can be
-        WRONG if you hacked ControlNet and the actual count differs from the default
-        A1111 behavior.
-        Args:
-            counter: A list of bool values; its length must equal B.
-        Returns:
-            None
-        '''
-        self.override_uc_mask = counter
-
-    def generate_uc_mask(self, length, dtype=None, device=None, python_list=False):
-        if isinstance(self.override_uc_mask, list):
-            if python_list:
-                return [i for i, v in enumerate(self.override_uc_mask) if not v]
-            return torch.tensor(self.override_uc_mask, dtype=dtype, device=device)
-
-        if self.is_vanilla_samplers and self.cfg_scale == 1:
-            if python_list:
-                return []
-            return torch.tensor([1 for _ in range(length)], dtype=dtype, device=device)
-
-        y = []
-
-        for i in range(length):
-            p = (self.instance_counter + i) % (self.batch_size * 2)
-            if self.is_vanilla_samplers:
-                y += [0] if p < self.batch_size else [1]
-            else:
-                y += [1] if p < self.batch_size else [0]
-
-        self.instance_counter += length
-
-        if python_list:
-            return [i for i in range(length) if y[i] < 0.5]
-        return torch.tensor(y, dtype=dtype, device=device)
 
     @property
     def hint_cond(self):
@@ -237,19 +207,30 @@ def guidance_schedule_handler(self, x):
             current_sampling_percent = (x.sampling_step / x.total_sampling_steps)
             param.guidance_stopped = current_sampling_percent < param.start_guidance_percent or current_sampling_percent > param.stop_guidance_percent
 
-    def hook(self, model, sd_ldm, control_params):
+    def hook(self, model, sd_ldm, control_params, process):
         self.model = model
         self.sd_ldm = sd_ldm
         self.control_params = control_params
 
         outer = self
 
+        def process_sample(*args, **kwargs):
+            conditioning = kwargs['conditioning']
+            unconditional_conditioning = kwargs['unconditional_conditioning']
+            mark_prompt_context(conditioning, POSITIVE_MARK_TOKEN)
+            mark_prompt_context(unconditional_conditioning, NEGATIVE_MARK_TOKEN)
+            return process.sample_before_CN_hack(*args, **kwargs)
+
         def forward(self, x, timesteps=None, context=None, **kwargs):
            total_controlnet_embedding = [0.0] * 13
            total_t2i_adapter_embedding = [0.0] * 4
            require_inpaint_hijack = False
            is_in_high_res_fix = False
 
+            # Handle cond-uncond marker
+            cond_mark, outer.current_uc_indices, context = unmark_prompt_context(context)
+            # print(str(cond_mark[:, 0, 0, 0].detach().cpu().numpy().tolist()) + ' - ' + str(outer.current_uc_indices))
+
             # High-res fix
             for param in outer.control_params:
                 # select which hint_cond to use
@@ -307,10 +288,9 @@ def forward(self, x, timesteps=None, context=None, **kwargs):
                     param.control_model.to(devices.get_device_for("controlnet"))
                     query_size = int(x.shape[0])
                     control = param.control_model(x=x, hint=param.used_hint_cond, timesteps=timesteps, context=context)
-                    uc_mask = param.generate_uc_mask(query_size, dtype=x.dtype, device=x.device)[:, None, None]
                     control = torch.cat([control.clone() for _ in range(query_size)], dim=0)
                     control *= param.weight
-                    control *= uc_mask
+                    control *= cond_mark[:, :, :, 0]
                     context = torch.cat([context, control.clone()], dim=1)
 
             # handle ControlNet / T2I_Adapter
@@ -343,8 +323,7 @@ def forward(self, x, timesteps=None, context=None, **kwargs):
                 query_size = int(x.shape[0])
                 if param.control_model_type == ControlModelType.T2I_Adapter:
                     control = [torch.cat([c.clone() for _ in range(query_size)], dim=0) for c in control]
-                    uc_mask = param.generate_uc_mask(query_size, dtype=x.dtype, device=x.device)[:, None, None, None]
-                    control = [c * uc_mask for c in control]
+                    control = [c * cond_mark for c in control]
 
                 if param.soft_injection or is_in_high_res_fix:
                     # important! using the soft weights with high-res fix can significantly reduce artifacts.
@@ -397,8 +376,6 @@ def forward(self, x, timesteps=None, context=None, **kwargs):
                 if param.control_model_type not in [ControlModelType.AttentionInjection]:
                     continue
 
-                query_size = int(x.shape[0])
-                outer.current_uc_indices = param.generate_uc_mask(query_size, dtype=x.dtype, device=x.device, python_list=True)
                 ref_xt = outer.sd_ldm.q_sample(param.used_hint_cond_latent, torch.round(timesteps.float()).long())
 
                 # Inpaint Hijack
@@ -541,6 +518,10 @@ def hacked_group_norm_forward(self, *args, **kwargs):
                 y = x
             return y.to(x.dtype)
 
+        if getattr(process, 'sample_before_CN_hack', None) is None:
+            process.sample_before_CN_hack = process.sample
+        process.sample = process_sample
+
         model._original_forward = model.forward
         outer.original_forward = model.forward
         model.forward = forward_webui.__get__(model, UNetModel)
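Taken together, hook.py now smuggles a one-token mark into every prompt's conditioning tensor before sampling and strips it back out inside the UNet forward, so each batch instance carries its own cond/uncond flag instead of it being inferred from batch size and sampler type. A self-contained sketch of that round trip (the 77x768 context shape is an illustrative assumption, not the real CLIP dimension):

    import torch

    POSITIVE_MARK_TOKEN = 1024
    NEGATIVE_MARK_TOKEN = -POSITIVE_MARK_TOKEN
    MARK_EPS = 1e-3

    def unmark_prompt_context(x):
        # token 0 of every instance is the mark; the rest is the real context
        mark = torch.mean(torch.abs(x[:, 0, :] - NEGATIVE_MARK_TOKEN), dim=1)
        mark = (mark > MARK_EPS).float()          # 1.0 = cond, 0.0 = uncond
        mark_batch = mark[:, None, None, None]
        uc_indices = [i for i, m in enumerate(mark.tolist()) if m < 0.5]
        return mark_batch, uc_indices, x[:, 1:, :]

    # a batch of 3 instances: two cond, one uncond
    cond = torch.randn(2, 77, 768)
    uncond = torch.randn(1, 77, 768)
    marked = torch.cat([
        torch.cat([torch.full((2, 1, 768), float(POSITIVE_MARK_TOKEN)), cond], dim=1),
        torch.cat([torch.full((1, 1, 768), float(NEGATIVE_MARK_TOKEN)), uncond], dim=1),
    ], dim=0)

    mark_batch, uc_indices, context = unmark_prompt_context(marked)
    print(uc_indices)        # [2] -> the uncond instance
    print(context.shape)     # torch.Size([3, 77, 768])

    # as in forward(): multiplying by the mark zeroes control for uncond instances
    control = torch.randn(3, 320, 8, 8)
    control = control * mark_batch
    print(control[2].abs().sum().item())   # 0.0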
