diff --git a/scripts/controlnet.py b/scripts/controlnet.py
index 6d50fa605..6b5d58096 100644
--- a/scripts/controlnet.py
+++ b/scripts/controlnet.py
@@ -1200,8 +1200,6 @@ def process(self, p, *args):
             if unit.module == 'clip_vision':
                 detected_maps.append((processor.clip_vision_visualization(detected_map), unit.module))
 
-            is_vanilla_samplers = p.sampler_name in ["DDIM", "PLMS", "UniPC"]
-
             control_model_type = ControlModelType.ControlNet
 
             if isinstance(model_net, PlugableAdapter):
@@ -1238,10 +1236,6 @@ def process(self, p, *args):
                 control_model_type=control_model_type,
                 global_average_pooling=global_average_pooling,
                 hr_hint_cond=hr_control,
-                batch_size=p.batch_size,
-                instance_counter=0,
-                is_vanilla_samplers=is_vanilla_samplers,
-                cfg_scale=p.cfg_scale,
                 soft_injection=control_mode != external_code.ControlMode.BALANCED,
                 cfg_injection=control_mode == external_code.ControlMode.CONTROL,
             )
@@ -1250,7 +1244,7 @@ def process(self, p, *args):
             del model_net
 
         self.latest_network = UnetHook(lowvram=hook_lowvram)
-        self.latest_network.hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params)
+        self.latest_network.hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params, process=p)
         self.detected_map = detected_maps
 
     def postprocess(self, p, processed, *args):
diff --git a/scripts/controlnet_version.py b/scripts/controlnet_version.py
index 6a0b5b43b..fb5657c47 100644
--- a/scripts/controlnet_version.py
+++ b/scripts/controlnet_version.py
@@ -1,4 +1,4 @@
-version_flag = 'v1.1.180'
+version_flag = 'v1.1.181'
 print(f'ControlNet {version_flag}')
 # A smart trick to know if user has updated as well as if user has restarted terminal.
 # Note that in "controlnet.py" we do NOT use "importlib.reload" to reload this "controlnet_version.py"
diff --git a/scripts/hook.py b/scripts/hook.py
index cea2003e2..92384b143 100644
--- a/scripts/hook.py
+++ b/scripts/hook.py
@@ -11,6 +11,51 @@
 from ldm.modules.diffusionmodules.openaimodel import UNetModel
 from ldm.modules.attention import BasicTransformerBlock
+from modules.prompt_parser import MulticondLearnedConditioning, ComposableScheduledPromptConditioning, ScheduledPromptConditioning
+
+
+POSITIVE_MARK_TOKEN = 1024
+NEGATIVE_MARK_TOKEN = -POSITIVE_MARK_TOKEN
+MARK_EPS = 1e-3
+
+
+def prompt_context_is_marked(x):
+    m = torch.abs(x[0]) - POSITIVE_MARK_TOKEN
+    m = torch.mean(torch.abs(m)).detach().cpu().float().numpy()
+    return float(m) < MARK_EPS
+
+
+def mark_prompt_context(x, mark):
+    if isinstance(x, list):
+        for i in range(len(x)):
+            x[i] = mark_prompt_context(x[i], mark)
+        return x
+    if isinstance(x, MulticondLearnedConditioning):
+        for i in range(len(x.batch)):
+            x.batch[i] = mark_prompt_context(x.batch[i], mark)
+        return x
+    if isinstance(x, ComposableScheduledPromptConditioning):
+        for i in range(len(x.schedules)):
+            x.schedules[i] = mark_prompt_context(x.schedules[i], mark)
+        return x
+    if isinstance(x, ScheduledPromptConditioning):
+        if prompt_context_is_marked(x.cond):
+            return x
+        m = torch.zeros_like(x.cond)[:1] + mark
+        cond = torch.cat([m, x.cond], dim=0)
+        return ScheduledPromptConditioning(end_at_step=x.end_at_step, cond=cond)
+    return x  # pass unrecognized types through unchanged instead of returning None
+
+
+def unmark_prompt_context(x):
+    mark = x[:, 0, :]
+    context = x[:, 1:, :]
+    mark = torch.mean(torch.abs(mark - NEGATIVE_MARK_TOKEN), dim=1)
+    mark = (mark > MARK_EPS).float()
+    mark_batch = mark[:, None, None, None]
+    uc_indices = mark.detach().cpu().numpy().tolist()
+    uc_indices = [i for i, item in enumerate(uc_indices) if item < 0.5]
+    return mark_batch, uc_indices, context
 
 
 class ControlModelType(Enum):
     """
@@ -81,12 +126,9 @@ def __init__(
         control_model_type,
         hr_hint_cond,
         global_average_pooling,
-        batch_size,
-        instance_counter,
-        is_vanilla_samplers,
-        cfg_scale,
         soft_injection,
-        cfg_injection
+        cfg_injection,
+        **kwargs  # absorb the removed legacy arguments so external callers do not break
    ):
         self.control_model = control_model
         self._hint_cond = hint_cond
@@ -100,80 +142,8 @@ def __init__(
         self.hr_hint_cond = hr_hint_cond
         self.used_hint_cond = None
         self.used_hint_cond_latent = None
-        self.batch_size = batch_size
-        self.instance_counter = instance_counter
-        self.is_vanilla_samplers = is_vanilla_samplers
-        self.cfg_scale = cfg_scale
         self.soft_injection = soft_injection
         self.cfg_injection = cfg_injection
-        self.override_uc_mask = None
-
-    def override_controlnet_cond_uncond_counter(self, counter: list):
-        '''
-        Override the cond-uncond counter of a controlnet unit.
-
-        If any other extension wants to call the forward function of controlnet, you are
-        highly recommended to manage the cond/uncond sequence to avoid unexpected behaviors
-        of CN.
-
-        For example, if you call
-
-            forward(x, timesteps, context)
-
-        and the shape of x is [B, C, H, W]. If you hacked the ``unetHook'' of ControlNet,
-        and for example, let us say your B = 3 (a batch with 3 instances) where your first and
-        second instances are cond (instances using positive prompt), while the third instance is uncond
-        (instance using negative prompt), you can write the sequence as [True, True, False],
-        where True is cond while False is uncond.
-
-        Then in this case, you can call
-
-            my_counter = [True, True, False]
-
-            for control_param in unetHook.control_params:
-                control_param.override_controlnet_cond_uncond_counter(my_counter)
-
-            .... = forward(x, timesteps, context)
-
-        And you will override all future cond-uncond behaviors of this control_param
-        (this ControlNet unit) until this control_param is disposed by python garbage collection.
-
-        If you do not call this ``override_controlnet_cond_uncond_counter'', ControlNet will
-        count the cond-uncond using A1111's Gradio UI's batchsize, which can be WRONG if you
-        hacked ControlNet and the actual count is not equivalent to the default A1111 behaviors.
-
-        Args:
-            counter: A list of bool values, the length of list must be same with the B.
-
-        Returns:
-            None
-        '''
-        self.override_uc_mask = counter
-
-    def generate_uc_mask(self, length, dtype=None, device=None, python_list=False):
-        if isinstance(self.override_uc_mask, list):
-            if python_list:
-                return [i for i, v in enumerate(self.override_uc_mask) if not v]
-            return torch.tensor(self.override_uc_mask, dtype=dtype, device=device)
-
-        if self.is_vanilla_samplers and self.cfg_scale == 1:
-            if python_list:
-                return []
-            return torch.tensor([1 for _ in range(length)], dtype=dtype, device=device)
-
-        y = []
-        for i in range(length):
-            p = (self.instance_counter + i) % (self.batch_size * 2)
-            if self.is_vanilla_samplers:
-                y += [0] if p < self.batch_size else [1]
-            else:
-                y += [1] if p < self.batch_size else [0]
-
-        self.instance_counter += length
-
-        if python_list:
-            return [i for i in range(length) if y[i] < 0.5]
-        return torch.tensor(y, dtype=dtype, device=device)
 
     @property
     def hint_cond(self):
@@ -237,19 +207,30 @@ def guidance_schedule_handler(self, x):
             current_sampling_percent = (x.sampling_step / x.total_sampling_steps)
             param.guidance_stopped = current_sampling_percent < param.start_guidance_percent or current_sampling_percent > param.stop_guidance_percent
 
-    def hook(self, model, sd_ldm, control_params):
+    def hook(self, model, sd_ldm, control_params, process):
         self.model = model
         self.sd_ldm = sd_ldm
         self.control_params = control_params
 
         outer = self
 
+        def process_sample(*args, **kwargs):
+            conditioning = kwargs['conditioning']
+            unconditional_conditioning = kwargs['unconditional_conditioning']
+            mark_prompt_context(conditioning, POSITIVE_MARK_TOKEN)
+            mark_prompt_context(unconditional_conditioning, NEGATIVE_MARK_TOKEN)
+            return process.sample_before_CN_hack(*args, **kwargs)
+
         def forward(self, x, timesteps=None, context=None, **kwargs):
            total_controlnet_embedding = [0.0] * 13
            total_t2i_adapter_embedding = [0.0] * 4
            require_inpaint_hijack = False
            is_in_high_res_fix = False
 
+            # Handle cond-uncond marker
+            cond_mark, outer.current_uc_indices, context = unmark_prompt_context(context)
+            # print(str(cond_mark[:, 0, 0, 0].detach().cpu().numpy().tolist()) + ' - ' + str(outer.current_uc_indices))
+
            # High-res fix
            for param in outer.control_params:
                # select which hint_cond to use
@@ -307,10 +288,9 @@ def forward(self, x, timesteps=None, context=None, **kwargs):
                    param.control_model.to(devices.get_device_for("controlnet"))
                query_size = int(x.shape[0])
                control = param.control_model(x=x, hint=param.used_hint_cond, timesteps=timesteps, context=context)
-                uc_mask = param.generate_uc_mask(query_size, dtype=x.dtype, device=x.device)[:, None, None]
                control = torch.cat([control.clone() for _ in range(query_size)], dim=0)
                control *= param.weight
-                control *= uc_mask
+                control *= cond_mark[:, :, :, 0]
                context = torch.cat([context, control.clone()], dim=1)
 
            # handle ControlNet / T2I_Adapter
@@ -343,8 +323,7 @@ def forward(self, x, timesteps=None, context=None, **kwargs):
                query_size = int(x.shape[0])
                if param.control_model_type == ControlModelType.T2I_Adapter:
                    control = [torch.cat([c.clone() for _ in range(query_size)], dim=0) for c in control]
-                    uc_mask = param.generate_uc_mask(query_size, dtype=x.dtype, device=x.device)[:, None, None, None]
-                    control = [c * uc_mask for c in control]
+                    control = [c * cond_mark for c in control]
 
                if param.soft_injection or is_in_high_res_fix:
                    # important! use the soft weights with high-res fix can significantly reduce artifacts.
@@ -397,8 +376,6 @@ def forward(self, x, timesteps=None, context=None, **kwargs):
                if param.control_model_type not in [ControlModelType.AttentionInjection]:
                    continue
 
-                query_size = int(x.shape[0])
-                outer.current_uc_indices = param.generate_uc_mask(query_size, dtype=x.dtype, device=x.device, python_list=True)
                ref_xt = outer.sd_ldm.q_sample(param.used_hint_cond_latent, torch.round(timesteps.float()).long())
 
                # Inpaint Hijack
@@ -541,6 +518,10 @@ def hacked_group_norm_forward(self, *args, **kwargs):
                y = x
            return y.to(x.dtype)
 
+        if getattr(process, 'sample_before_CN_hack', None) is None:
+            process.sample_before_CN_hack = process.sample
+            process.sample = process_sample
+
        model._original_forward = model.forward
        outer.original_forward = model.forward
        model.forward = forward_webui.__get__(model, UNetModel)
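
Note for reviewers and downstream extensions: with `generate_uc_mask` and `override_controlnet_cond_uncond_counter` gone, cond/uncond detection no longer counts batch positions; it travels inside the prompt context itself. Below is a minimal, self-contained sketch of the marker round-trip; the 77-token by 768-dim shape is illustrative, standing in for a real CLIP context:

```python
import torch

POSITIVE_MARK_TOKEN = 1024
NEGATIVE_MARK_TOKEN = -POSITIVE_MARK_TOKEN
MARK_EPS = 1e-3

def mark(x, token):
    # Prepend one synthetic token row filled with +/-1024, as mark_prompt_context does.
    m = torch.zeros_like(x[:, :1, :]) + token
    return torch.cat([m, x], dim=1)

# Fake batch: two cond rows followed by one uncond row.
cond = mark(torch.randn(2, 77, 768), POSITIVE_MARK_TOKEN)
uncond = mark(torch.randn(1, 77, 768), NEGATIVE_MARK_TOKEN)
context = torch.cat([cond, uncond], dim=0)

# Recover the mask exactly as unmark_prompt_context does.
m = torch.mean(torch.abs(context[:, 0, :] - NEGATIVE_MARK_TOKEN), dim=1)
cond_mark = (m > MARK_EPS).float()
uc_indices = [i for i, v in enumerate(cond_mark.tolist()) if v < 0.5]
print(cond_mark.tolist(), uc_indices)  # [1.0, 1.0, 0.0] [2]
```

Because the marker is carried per instance, it survives whatever reordering or re-batching a sampler applies to the context, which is why the old per-step counters are no longer needed.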
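On the shapes used in `forward()`: the returned mask is `[B, 1, 1, 1]`, so it broadcasts directly over `[B, C, H, W]` T2I-Adapter features, while `cond_mark[:, :, :, 0]` yields the `[B, 1, 1]` shape needed for `[B, tokens, dim]` contexts. A sketch with made-up sizes:

```python
import torch

cond_mark = torch.tensor([1.0, 1.0, 0.0])[:, None, None, None]  # [B, 1, 1, 1]

control = torch.randn(3, 320, 64, 64)      # adapter feature map, [B, C, H, W]
control = control * cond_mark              # zeroes the uncond instance (row 2)

context = torch.randn(3, 77, 768)          # prompt context, [B, tokens, dim]
context = context * cond_mark[:, :, :, 0]  # [B, 1, 1] mask, one scale per row
```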
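Finally, the `sample_before_CN_hack` guard makes the `process.sample` wrap idempotent: hooking twice on the same processing object must not wrap the wrapper, or the original `sample` would be lost. A self-contained sketch of the pattern, where `Process` is a stand-in rather than A1111's `StableDiffusionProcessing`:

```python
class Process:
    def sample(self, **kwargs):
        return "original"

p = Process()

def process_sample(**kwargs):
    # mark conditioning here, then defer to the saved original
    return p.sample_before_CN_hack(**kwargs)

for _ in range(2):  # hooking twice is safe
    if getattr(p, 'sample_before_CN_hack', None) is None:
        p.sample_before_CN_hack = p.sample  # bound method, saved exactly once
        p.sample = process_sample
assert p.sample() == "original"
```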