add per-user queue size messages
KubaBir committed Mar 27, 2024
1 parent 941bc02 commit 23a822c
Showing 2 changed files with 93 additions and 32 deletions.
111 changes: 79 additions & 32 deletions execution.py
@@ -12,6 +12,7 @@

import comfy.model_management


def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
@@ -41,6 +42,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all[x] = [unique_id]
return input_data_all


def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
# check if node wants the lists
input_is_list = False
@@ -51,14 +53,14 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
max_len_input = 0
else:
max_len_input = max([len(x) for x in input_data_all.values()])

# get a slice of inputs, repeat last input when list isn't long enough
def slice_dict(d, i):
d_new = dict()
for k,v in d.items():
for k, v in d.items():
d_new[k] = v[i if len(v) > i else -1]
return d_new

results = []
if input_is_list:
if allow_interrupt:
@@ -75,11 +77,13 @@ def slice_dict(d, i):
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
return results


def get_output_data(obj, input_data_all):

results = []
uis = []
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
return_values = map_node_over_list(
obj, input_data_all, obj.FUNCTION, allow_interrupt=True)

for r in return_values:
if isinstance(r, dict):
@@ -89,7 +93,7 @@ def get_output_data(obj, input_data_all):
results.append(r['result'])
else:
results.append(r)

output = []
if len(results) > 0:
# check which outputs need concatenating
@@ -104,11 +108,12 @@ def get_output_data(obj, input_data_all):
else:
output.append([o[i] for o in results])

ui = dict()
if len(uis) > 0:
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
return output, ui


def format_value(x):
if x is None:
return None
@@ -117,6 +122,7 @@ def format_value(x):
else:
return str(x)


def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
@@ -132,17 +138,20 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
result = recursive_execute(server, prompt, outputs, input_unique_id,
extra_data, executed, prompt_id, outputs_ui, object_storage)
if result[0] is not True:
# Another node failed further upstream
return result

input_data_all = None
try:
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
input_data_all = get_input_data(
inputs, class_def, unique_id, outputs, prompt, extra_data)
if server.client_id is not None:
server.last_node_id = unique_id
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
server.send_sync(
"executing", {"node": unique_id, "prompt_id": prompt_id}, server.client_id)

obj = object_storage.get((unique_id, class_type), None)
if obj is None:
@@ -154,7 +163,8 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
if len(output_ui) > 0:
outputs_ui[unique_id] = output_ui
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
server.send_sync("executed", {
"node": unique_id, "output": output_ui, "prompt_id": prompt_id}, server.client_id)
except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted")

@@ -175,7 +185,8 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute

output_data_formatted = {}
for node_id, node_outputs in outputs.items():
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
output_data_formatted[node_id] = [
[format_value(x) for x in l] for l in node_outputs]

logging.error("!!! Exception during processing !!!")
logging.error(traceback.format_exc())
@@ -194,6 +205,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute

return (True, None, None)


def recursive_will_execute(prompt, outputs, current_item, memo={}):
unique_id = current_item

@@ -211,11 +223,13 @@ def recursive_will_execute(prompt, outputs, current_item, memo={}):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo)
will_execute += recursive_will_execute(
prompt, outputs, input_unique_id, memo)

memo[unique_id] = will_execute + [unique_id]
return memo[unique_id]


def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
@@ -229,11 +243,13 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
is_changed_old = old_prompt[unique_id]['is_changed']
if 'is_changed' not in prompt[unique_id]:
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
input_data_all = get_input_data(
inputs, class_def, unique_id, outputs)
if input_data_all is not None:
try:
#is_changed = class_def.IS_CHANGED(**input_data_all)
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
# is_changed = class_def.IS_CHANGED(**input_data_all)
is_changed = map_node_over_list(
class_def, input_data_all, "IS_CHANGED")
prompt[unique_id]['is_changed'] = is_changed
except:
to_delete = True
@@ -256,7 +272,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id in outputs:
to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
to_delete = recursive_output_delete_if_changed(
prompt, old_prompt, outputs, input_unique_id)
else:
to_delete = True
if to_delete:
@@ -269,6 +286,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
del d
return to_delete


class PromptExecutor:
def __init__(self, server):
self.server = server
@@ -315,7 +333,7 @@ def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, e
"current_outputs": error["current_outputs"],
}
self.add_message("execution_error", mes, broadcast=False)

# Next, remove the subsequent outputs since they will not be executed
to_delete = []
for o in self.outputs:
@@ -337,10 +355,11 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
self.server.client_id = None

self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
self.add_message("execution_start", {
"prompt_id": prompt_id}, broadcast=False)

with torch.inference_mode():
#delete cached outputs if nodes don't exist for them
# delete cached outputs if nodes don't exist for them
to_delete = []
for o in self.outputs:
if o not in prompt:
@@ -361,7 +380,8 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
del d

for x in prompt:
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
recursive_output_delete_if_changed(
prompt, self.old_prompt, self.outputs, x)

current_outputs = set(self.outputs.keys())
for x in list(self.outputs_ui.keys()):
@@ -371,8 +391,9 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):

comfy.model_management.cleanup_models()
self.add_message("execution_cached",
{ "nodes": list(current_outputs) , "prompt_id": prompt_id},
broadcast=False)
{"nodes": list(current_outputs),
"prompt_id": prompt_id},
broadcast=False)
executed = set()
output_node_id = None
to_execute = []
@@ -381,17 +402,20 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
to_execute += [(0, node_id)]

while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first
# always execute the output that depends on the least amount of unexecuted nodes first
memo = {}
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute)))
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(
prompt, self.outputs, a[-1], memo)), a[-1]), to_execute)))
output_node_id = to_execute.pop(0)[-1]

# This call shouldn't raise anything if there's an error deep in
# the actual SD code, instead it will report the node where the
# error was raised
self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
self.success, error, ex = recursive_execute(
self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
if self.success is not True:
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
self.handle_execution_error(
prompt_id, prompt, current_outputs, executed, error, ex)
break

for x in executed:
@@ -401,7 +425,6 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
comfy.model_management.unload_all_models()



def validate_inputs(prompt, item, validated):
unique_id = item
if unique_id in validated:
@@ -419,7 +442,8 @@ def validate_inputs(prompt, item, validated):

validate_function_inputs = []
if hasattr(obj_class, "VALIDATE_INPUTS"):
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
validate_function_inputs = inspect.getfullargspec(
obj_class.VALIDATE_INPUTS).args

for x in required_inputs:
if x not in inputs:
@@ -584,7 +608,7 @@ def validate_inputs(prompt, item, validated):
if x in validate_function_inputs:
input_filtered[x] = input_data_all[x]

#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
# ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
for x in input_filtered:
for i, r in enumerate(ret):
@@ -614,12 +638,14 @@ def validate_inputs(prompt, item, validated):
validated[unique_id] = ret
return ret


def full_type_name(klass):
module = klass.__module__
if module == 'builtins':
return klass.__qualname__
return module + '.' + klass.__qualname__


def validate_prompt(prompt):
outputs = set()
for x in prompt:
@@ -669,7 +695,8 @@ def validate_prompt(prompt):
if len(reasons) > 0:
logging.error("* (prompt):")
for reason in reasons:
logging.error(f" - {reason['message']}: {reason['details']}")
logging.error(
f" - {reason['message']}: {reason['details']}")
errors += [(o, reasons)]
for node_id, result in validated.items():
valid = result[0]
@@ -687,7 +714,8 @@ def validate_prompt(prompt):
}
logging.error(f"* {class_type} {node_id}:")
for reason in reasons:
logging.error(f" - {reason['message']}: {reason['details']}")
logging.error(
f" - {reason['message']}: {reason['details']}")
node_errors[node_id]["dependent_outputs"].append(o)
logging.error("Output will be ignored")

@@ -709,8 +737,10 @@ def validate_prompt(prompt):

return (True, None, list(good_outputs), node_errors)


MAXIMUM_HISTORY_SIZE = 10000


class PromptQueue:
def __init__(self, server):
self.server = server
@@ -776,6 +806,23 @@ def get_tasks_remaining(self):
with self.mutex:
return len(self.queue) + len(self.currently_running)

def get_queue_position_per_client(self):
pending_clients = []
with self.mutex:
for item in self.queue:
pending_clients.append([item[3]['client_id'], item[0]])

pending_clients.sort(key=lambda i: i[1])

positions = dict()
for index, item in enumerate(pending_clients):
if item[0] not in positions:
positions[item[0]] = [index + 1]
else:
positions[item[0]].append(index + 1)

return positions

def wipe_queue(self):
with self.mutex:
self.queue = []
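
For reference, the new PromptQueue.get_queue_position_per_client pairs each pending queue entry's client_id with its queue number, sorts by number, and groups 1-based queue positions by client. Below is a minimal standalone sketch of that mapping (not part of the commit), assuming queue items are (number, prompt_id, prompt, extra_data, outputs_to_execute) tuples, as the indexing item[3]['client_id'] and item[0] suggests; the sample client ids are placeholders.

# Standalone sketch: assumed queue item shape, placeholder client ids.
sample_queue = [
    (0, "prompt-a", {}, {"client_id": "alice"}, []),
    (1, "prompt-b", {}, {"client_id": "bob"}, []),
    (2, "prompt-c", {}, {"client_id": "alice"}, []),
]

# Pair each entry's client id with its queue number, then order by number,
# mirroring the sort in get_queue_position_per_client.
pending_clients = sorted(
    ([item[3]["client_id"], item[0]] for item in sample_queue),
    key=lambda i: i[1],
)

# Group 1-based positions per client; a client's first entry is what the
# new per-client status message reports as queue_remaining.
positions = {}
for index, (client_id, _number) in enumerate(pending_clients):
    positions.setdefault(client_id, []).append(index + 1)

print(positions)  # {'alice': [1, 3], 'bob': [2]}
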
14 changes: 14 additions & 0 deletions server.py
@@ -688,8 +688,22 @@ def send_sync(self, event, data, sid=None):
self.messages.put_nowait, (event, data, sid))

def queue_updated(self):
# Send global queue info
self.send_sync("status", {"status": self.get_queue_info()})

# Per-client queue info (how long until this client's request is processed)
pending_clients = self.prompt_queue.get_queue_position_per_client()
if not len(pending_clients):
return

for sid, positions in pending_clients.items():
prompt_info = {}
exec_info = {}
exec_info['queue_remaining'] = positions[0]
prompt_info['exec_info'] = exec_info

self.send_sync("status", {"remaining": prompt_info}, sid)

async def publish_loop(self):
while True:
msg = await self.messages.get()
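
On the wire, each client with pending prompts now receives a targeted "status" event whose data is {"remaining": {"exec_info": {"queue_remaining": N}}}, where N is the queue position of that client's next pending prompt; the global "status" broadcast with the overall count is unchanged. A hedged client-side sketch follows, assuming ComfyUI's conventional websocket endpoint (ws://host:port/ws?clientId=...) and its {"type": ..., "data": ...} message envelope; the host, port, and the third-party websocket-client package are assumptions to verify against your setup.

# Hypothetical listener; endpoint and envelope follow ComfyUI's usual
# websocket API, but verify both against your server version.
import json
import uuid

import websocket  # pip install websocket-client

client_id = str(uuid.uuid4())
ws = websocket.WebSocket()
ws.connect(f"ws://127.0.0.1:8188/ws?clientId={client_id}")

while True:
    raw = ws.recv()
    if not isinstance(raw, str):
        continue  # skip binary frames (e.g. image previews)
    msg = json.loads(raw)
    if msg.get("type") != "status":
        continue
    data = msg.get("data", {})
    if "remaining" in data:
        # Per-client message added by this commit: position of this
        # client's next pending prompt in the global queue.
        pos = data["remaining"]["exec_info"]["queue_remaining"]
        print(f"next prompt for this client is #{pos} in the queue")
    elif "status" in data:
        # Existing global broadcast: total prompts queued or running.
        total = data["status"]["exec_info"]["queue_remaining"]
        print(f"{total} prompt(s) in the queue overall")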
