Skip to content

Commit

Permalink
add external api for tiled diffusion (#1336)
Browse files Browse the repository at this point in the history
add external api for tiled diffusion
  • Loading branch information
lllyasviel authored May 20, 2023
1 parent 1c994ff commit cd98a95
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
2 changes: 1 addition & 1 deletion scripts/controlnet_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version_flag = 'v1.1.179'
version_flag = 'v1.1.180'
print(f'ControlNet {version_flag}')
# A smart trick to know if user has updated as well as if user has restarted terminal.
# Note that in "controlnet.py" we do NOT use "importlib.reload" to reload this "controlnet_version.py"
Expand Down
47 changes: 47 additions & 0 deletions scripts/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,55 @@ def __init__(
self.cfg_scale = cfg_scale
self.soft_injection = soft_injection
self.cfg_injection = cfg_injection
self.override_uc_mask = None

def override_controlnet_cond_uncond_counter(self, counter: list):
'''
Override the cond-uncond counter of a controlnet unit.
If any other extension wants to call the forward function of controlnet, you are
highly recommended to manage the cond/uncond sequence to avoid unexpected behaviors
of CN.
For example, if you call
forward(x, timesteps, context)
and the shape of x is [B, C, H, W]. If you hacked the ``unetHook'' of ControlNet,
and for example, let us say your B = 3 (a batch with 3 instances) where your first and
second instances are cond (instances using positive prompt), while the third instance is uncond
(instance using negative prompt), you can write the sequence as [True, True, False],
where True is cond while False is uncond.
Then in this case, you can call
my_counter = [True, True, False]
for control_param in unetHook.control_params:
control_param.override_controlnet_cond_uncond_counter(my_counter)
.... = forward(x, timesteps, context)
And you will override all future cond-uncond behaviors of this control_param
(this ControlNet unit) until this control_param is disposed by python garbage collection.
If you do not call this ``override_controlnet_cond_uncond_counter'', ControlNet will
count the cond-uncond using A1111's Gradio UI's batchsize, which can be WRONG if you
hacked ControlNet and the actual count is not equivalent to the default A1111 behaviors.
Args:
counter: A list of bool values, the length of list must be same with the B.
Returns:
None
'''
self.override_uc_mask = counter

def generate_uc_mask(self, length, dtype=None, device=None, python_list=False):
if isinstance(self.override_uc_mask, list):
if python_list:
return [i for i, v in enumerate(self.override_uc_mask) if not v]
return torch.tensor(self.override_uc_mask, dtype=dtype, device=device)

if self.is_vanilla_samplers and self.cfg_scale == 1:
if python_list:
return []
Expand Down

0 comments on commit cd98a95

Please sign in to comment.