Skip to content

Commit

Permalink
bring back short hashes to sd checkpoint selection
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Jan 19, 2023
1 parent d1ea518 commit c1928cd
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 15 deletions.
15 changes: 11 additions & 4 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,28 +41,32 @@ def __init__(self, filename):
if name.startswith("\\") or name.startswith("/"):
name = name[1:]

self.title = name
self.name = name
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
self.hash = model_hash(filename)

self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title)
self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
self.shorthash = self.sha256[0:10] if self.sha256 else None

self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else [])
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'

self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])

def register(self):
checkpoints_list[self.title] = self
for id in self.ids:
checkpoint_alisases[id] = self

def calculate_shorthash(self):
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title)
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
self.shorthash = self.sha256[0:10]

if self.shorthash not in self.ids:
self.ids += [self.shorthash, self.sha256]
self.register()

self.title = f'{self.name} [{self.shorthash}]'

return self.shorthash


Expand Down Expand Up @@ -225,7 +229,10 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None


def load_model_weights(model, checkpoint_info: CheckpointInfo):
title = checkpoint_info.title
sd_model_hash = checkpoint_info.calculate_shorthash()
if checkpoint_info.title != title:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

cache_enabled = shared.opts.sd_checkpoint_cache > 0

Expand Down
23 changes: 12 additions & 11 deletions modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def apply_setting(key, value):
opts.data_labels[key].onchange()

opts.save(shared.config_filename)
return value
return getattr(opts, key)


def update_generation_info(generation_info, html_info, img_index):
Expand Down Expand Up @@ -597,6 +597,16 @@ def ordered_ui_categories():
yield category


def get_value_for_setting(key):
value = getattr(opts, key)

info = opts.data_labels[key]
args = info.component_args() if callable(info.component_args) else info.component_args or {}
args = {k: v for k, v in args.items() if k not in {'precision'}}

return gr.update(value=value, **args)


def create_ui():
import modules.img2img
import modules.txt2img
Expand Down Expand Up @@ -1600,7 +1610,7 @@ def run_settings_single(value, key):

opts.save(shared.config_filename)

return gr.update(value=value), opts.dumpjson()
return get_value_for_setting(key), opts.dumpjson()

with gr.Blocks(analytics_enabled=False) as settings_interface:
with gr.Row():
Expand Down Expand Up @@ -1771,15 +1781,6 @@ def request_restart():

component_keys = [k for k in opts.data_labels.keys() if k in component_dict]

def get_value_for_setting(key):
value = getattr(opts, key)

info = opts.data_labels[key]
args = info.component_args() if callable(info.component_args) else info.component_args or {}
args = {k: v for k, v in args.items() if k not in {'precision'}}

return gr.update(value=value, **args)

def get_settings_values():
return [get_value_for_setting(key) for key in component_keys]

Expand Down

0 comments on commit c1928cd

Please sign in to comment.