Skip to content

Commit

Permalink
Merge pull request #5453 from oobabooga/dev
Browse files Browse the repository at this point in the history
Merge dev branch
  • Loading branch information
oobabooga authored Feb 6, 2024
2 parents a329db0 + 775902c commit 0f134bf
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
7 changes: 5 additions & 2 deletions modules/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,12 @@ def generate_preset_yaml(state):
defaults = default_preset()
data = {k: state[k] for k in presets_params()}

# Remove entries that are identical to the defaults
# Remove entries that are identical to the defaults.
# sampler_priority is always saved because it is experimental
# and the default order may change.

for k in list(data.keys()):
if data[k] == defaults[k]:
if data[k] == defaults[k] and k != 'sampler_priority':
del data[k]

return yaml.dump(data, sort_keys=False)
7 changes: 3 additions & 4 deletions modules/sampler_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,16 +428,15 @@ def custom_sort_key(obj):

# Sort the list using the custom key function
warpers = sorted(warpers, key=custom_sort_key)
if shared.args.verbose:
logger.info("WARPERS=")
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint([x.__class__.__name__ for x in warpers])

if normalize is not None:
warpers.append(normalize)

warpers.append(SpyLogitsWarper())
warpers = LogitsProcessorList(warpers)
if shared.args.verbose:
logger.info("WARPERS=")
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint([x.__class__.__name__ for x in warpers])

return warpers


Expand Down

0 comments on commit 0f134bf

Please sign in to comment.