Skip to content

Commit

Permalink
remove Train/Preprocessing tab and put all its functionality into ext…
Browse files Browse the repository at this point in the history
…ras batch images mode
  • Loading branch information
AUTOMATIC1111 authored and ruchej committed Sep 30, 2024
1 parent 318c3ab commit 88c73ee
Show file tree
Hide file tree
Showing 19 changed files with 460 additions and 414 deletions.
17 changes: 17 additions & 0 deletions javascript/ui.js
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,23 @@ function submit_img2img() {
return res;
}

function submit_extras() {
showSubmitButtons('extras', false);

var id = randomId();

requestProgress(id, gradioApp().getElementById('extras_gallery_container'), gradioApp().getElementById('extras_gallery'), function() {
showSubmitButtons('extras', true);
});

var res = create_submit_args(arguments);

res[0] = id;

console.log(res);
return res;
}

function restoreProgressTxt2img() {
showRestoreProgressButton("txt2img", false);
var id = localGet("txt2img_task_id");
Expand Down
15 changes: 0 additions & 15 deletions modules/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin, Image
from modules.sd_models_config import find_checkpoint_config_near_filename
Expand Down Expand Up @@ -235,7 +234,6 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
Expand Down Expand Up @@ -675,19 +673,6 @@ def create_hypernetwork(self, args: dict):
finally:
shared.state.end()

def preprocess(self, args: dict):
try:
shared.state.begin(job="preprocess")
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
shared.state.end()
return models.PreprocessResponse(info='preprocess complete')
except KeyError as e:
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
except Exception as e:
return models.PreprocessResponse(info=f"preprocess error: {e}")
finally:
shared.state.end()

def train_embedding(self, args: dict):
try:
shared.state.begin(job="train_embedding")
Expand Down
3 changes: 0 additions & 3 deletions modules/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,6 @@ class TrainResponse(BaseModel):
class CreateResponse(BaseModel):
info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")

class PreprocessResponse(BaseModel):
info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")

fields = {}
for key, metadata in opts.data_labels.items():
value = opts.data.get(key)
Expand Down
92 changes: 69 additions & 23 deletions modules/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from modules.shared import opts


def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
def run_postprocessing(id_task, extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
devices.torch_gc()

shared.state.begin(job="extras")
Expand All @@ -29,11 +29,7 @@ def get_images(extras_mode, image, image_folder, input_dir):

image_list = shared.listfiles(input_dir)
for filename in image_list:
try:
image = Image.open(filename)
except Exception:
continue
yield image, filename
yield filename, filename
else:
assert image, 'image not selected'
yield image, None
Expand All @@ -45,37 +41,85 @@ def get_images(extras_mode, image, image_folder, input_dir):

infotext = ''

for image_data, name in get_images(extras_mode, image, image_folder, input_dir):
data_to_process = list(get_images(extras_mode, image, image_folder, input_dir))
shared.state.job_count = len(data_to_process)

for image_placeholder, name in data_to_process:
image_data: Image.Image

shared.state.nextjob()
shared.state.textinfo = name
shared.state.skipped = False

if shared.state.interrupted:
break

if isinstance(image_placeholder, str):
try:
image_data = Image.open(image_placeholder)
except Exception:
continue
else:
image_data = image_placeholder

shared.state.assign_current_image(image_data)

parameters, existing_pnginfo = images.read_info_from_image(image_data)
if parameters:
existing_pnginfo["parameters"] = parameters

pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
initial_pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))

scripts.scripts_postproc.run(pp, args)
scripts.scripts_postproc.run(initial_pp, args)

if opts.use_original_name_batch and name is not None:
basename = os.path.splitext(os.path.basename(name))[0]
forced_filename = basename
else:
basename = ''
forced_filename = None
if shared.state.skipped:
continue

used_suffixes = {}
for pp in [initial_pp, *initial_pp.extra_images]:
suffix = pp.get_suffix(used_suffixes)

if opts.use_original_name_batch and name is not None:
basename = os.path.splitext(os.path.basename(name))[0]
forced_filename = basename + suffix
else:
basename = ''
forced_filename = None

infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])

if opts.enable_pnginfo:
pp.image.info = existing_pnginfo
pp.image.info["postprocessing"] = infotext

if save_output:
fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix)

infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
if pp.caption:
caption_filename = os.path.splitext(fullfn)[0] + ".txt"
if os.path.isfile(caption_filename):
with open(caption_filename, encoding="utf8") as file:
existing_caption = file.read().strip()
else:
existing_caption = ""

if opts.enable_pnginfo:
pp.image.info = existing_pnginfo
pp.image.info["postprocessing"] = infotext
action = shared.opts.postprocessing_existing_caption_action
if action == 'Prepend' and existing_caption:
caption = f"{existing_caption} {pp.caption}"
elif action == 'Append' and existing_caption:
caption = f"{pp.caption} {existing_caption}"
elif action == 'Keep' and existing_caption:
caption = existing_caption
else:
caption = pp.caption

if save_output:
images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename)
caption = caption.strip()
if caption:
with open(caption_filename, "w", encoding="utf8") as file:
file.write(caption)

if extras_mode != 2 or show_extras_results:
outputs.append(pp.image)
if extras_mode != 2 or show_extras_results:
outputs.append(pp.image)

image_data.close()

Expand All @@ -99,9 +143,11 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
"upscaler_2_visibility": extras_upscaler_2_visibility,
},
"GFPGAN": {
"enable": True,
"gfpgan_visibility": gfpgan_visibility,
},
"CodeFormer": {
"enable": True,
"codeformer_visibility": codeformer_visibility,
"codeformer_weight": codeformer_weight,
},
Expand Down
86 changes: 81 additions & 5 deletions modules/scripts_postprocessing.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,56 @@
import dataclasses
import os
import gradio as gr

from modules import errors, shared


@dataclasses.dataclass
class PostprocessedImageSharedInfo:
target_width: int = None
target_height: int = None


class PostprocessedImage:
def __init__(self, image):
self.image = image
self.info = {}
self.shared = PostprocessedImageSharedInfo()
self.extra_images = []
self.nametags = []
self.disable_processing = False
self.caption = None

def get_suffix(self, used_suffixes=None):
used_suffixes = {} if used_suffixes is None else used_suffixes
suffix = "-".join(self.nametags)
if suffix:
suffix = "-" + suffix

if suffix not in used_suffixes:
used_suffixes[suffix] = 1
return suffix

for i in range(1, 100):
proposed_suffix = suffix + "-" + str(i)

if proposed_suffix not in used_suffixes:
used_suffixes[proposed_suffix] = 1
return proposed_suffix

return suffix

def create_copy(self, new_image, *, nametags=None, disable_processing=False):
pp = PostprocessedImage(new_image)
pp.shared = self.shared
pp.nametags = self.nametags.copy()
pp.info = self.info.copy()
pp.disable_processing = disable_processing

if nametags is not None:
pp.nametags += nametags

return pp


class ScriptPostprocessing:
Expand Down Expand Up @@ -42,10 +85,17 @@ def process(self, pp: PostprocessedImage, **args):

pass

def image_changed(self):
pass
def process_firstpass(self, pp: PostprocessedImage, **args):
"""
Called for all scripts before calling process(). Scripts can examine the image here and set fields
of the pp object to communicate things to other scripts.
args contains a dictionary with all values returned by components from ui()
"""

pass

def image_changed(self):
pass


def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
Expand Down Expand Up @@ -118,16 +168,42 @@ def setup_ui(self):
return inputs

def run(self, pp: PostprocessedImage, args):
for script in self.scripts_in_preferred_order():
shared.state.job = script.name
scripts = []

for script in self.scripts_in_preferred_order():
script_args = args[script.args_from:script.args_to]

process_args = {}
for (name, _component), value in zip(script.controls.items(), script_args):
process_args[name] = value

script.process(pp, **process_args)
scripts.append((script, process_args))

for script, process_args in scripts:
script.process_firstpass(pp, **process_args)

all_images = [pp]

for script, process_args in scripts:
if shared.state.skipped:
break

shared.state.job = script.name

for single_image in all_images.copy():

if not single_image.disable_processing:
script.process(single_image, **process_args)

for extra_image in single_image.extra_images:
if not isinstance(extra_image, PostprocessedImage):
extra_image = single_image.create_copy(extra_image)

all_images.append(extra_image)

single_image.extra_images.clear()

pp.extra_images = all_images[1:]

def create_args_for_run(self, scripts_args):
if not self.ui_created:
Expand Down
1 change: 1 addition & 0 deletions modules/shared_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
'postprocessing_existing_caption_action': OptionInfo("Ignore", "Action for existing captions", gr.Radio, {"choices": ["Ignore", "Keep", "Prepend", "Append"]}).info("when generating captions using postprocessing; Ignore = use generated; Keep = use original; Prepend/Append = combine both"),
}))

options_templates.update(options_section((None, "Hidden options"), {
Expand Down
Loading

0 comments on commit 88c73ee

Please sign in to comment.