Skip to content

Commit

Permalink
Add reset button for custom workflow parameters #1374
Browse files Browse the repository at this point in the history
  • Loading branch information
Acly committed Nov 12, 2024
1 parent 0741995 commit 611f1e3
Show file tree
Hide file tree
Showing 3 changed files with 215 additions and 14 deletions.
82 changes: 82 additions & 0 deletions ai_diffusion/icons/reset-dark.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
82 changes: 82 additions & 0 deletions ai_diffusion/icons/reset-light.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
65 changes: 51 additions & 14 deletions ai_diffusion/ui/custom_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class LayerSelect(QComboBox):

def __init__(self, filter: str | None = None, parent: QWidget | None = None):
super().__init__(parent)
self.param = None
self.filter = filter

self.setContentsMargins(0, 0, 0, 0)
Expand Down Expand Up @@ -84,6 +85,7 @@ class IntParamWidget(QWidget):

def __init__(self, param: CustomParam, parent: QWidget | None = None):
super().__init__(parent)
self.param = param
self.setContentsMargins(0, 0, 0, 0)

layout = QHBoxLayout(self)
Expand Down Expand Up @@ -131,6 +133,7 @@ class FloatParamWidget(QWidget):

def __init__(self, param: CustomParam, parent: QWidget | None = None):
super().__init__(parent)
self.param = param
self.setContentsMargins(0, 0, 0, 0)

layout = QHBoxLayout(self)
Expand Down Expand Up @@ -185,6 +188,7 @@ class BoolParamWidget(QWidget):

def __init__(self, param: CustomParam, parent: QWidget | None = None):
super().__init__(parent)
self.param = param
self.setContentsMargins(0, 0, 0, 0)

layout = QHBoxLayout(self)
Expand Down Expand Up @@ -221,6 +225,7 @@ class TextParamWidget(QLineEdit):
def __init__(self, param: CustomParam, parent: QWidget | None = None):
super().__init__(parent)
assert isinstance(param.default, str)
self.param = param

self.value = param.default
self.textChanged.connect(self._notify)
Expand All @@ -243,6 +248,7 @@ class PromptParamWidget(TextPromptWidget):
def __init__(self, param: CustomParam, parent: QWidget | None = None):
super().__init__(is_negative=param.kind is ParamKind.prompt_negative, parent=parent)
assert isinstance(param.default, str)
self.param = param

self.setObjectName("PromptParam")
self.setFrameStyle(QFrame.Shape.StyledPanel)
Expand All @@ -266,6 +272,7 @@ class ChoiceParamWidget(QComboBox):

def __init__(self, param: CustomParam, parent: QWidget | None = None):
super().__init__(parent)
self.param = param
self.setMinimumContentsLength(20)
self.setSizeAdjustPolicy(QComboBox.SizeAdjustPolicy.AdjustToMinimumContentsLength)

Expand Down Expand Up @@ -294,6 +301,7 @@ class StyleParamWidget(QWidget):

def __init__(self, parent: QWidget):
super().__init__(parent)
self.param = None
self._style_select = StyleSelectWidget(self)
self._style_select.value_changed.connect(self._notify)
layout = QHBoxLayout()
Expand Down Expand Up @@ -350,6 +358,44 @@ def _create_param_widget(param: CustomParam, parent: QWidget) -> CustomParamWidg
assert False, f"Unknown param kind: {param.kind}"


class GroupHeader(QWidget):
def __init__(self, text: str, parent: QWidget | None = None):
super().__init__(parent)
self._group_widgets: list[CustomParamWidget] = []

self._expander = ExpanderButton(text, self)
self._expander.toggled.connect(self._show_group)

fh = self.fontMetrics().height()
self._reset_button = QToolButton(self)
self._reset_button.setFixedSize(fh + 2, fh + 2)
self._reset_button.setIcon(theme.icon("reset"))
self._reset_button.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonIconOnly)
self._reset_button.setAutoRaise(True)
self._reset_button.setToolTip(_("Reset all parameters in this group"))
self._reset_button.clicked.connect(self._reset_group)

layout = QHBoxLayout(self)
layout.setContentsMargins(0, 0, 0, 0)
layout.addWidget(self._expander, stretch=1, alignment=Qt.AlignmentFlag.AlignLeft)
layout.addWidget(self._reset_button, alignment=Qt.AlignmentFlag.AlignRight)

def set_group_widgets(self, widgets: list[CustomParamWidget], show_group: bool):
self._group_widgets = widgets
self._expander.setChecked(show_group)
self._show_group(show_group)

def _show_group(self, checked: bool):
for w in self._group_widgets:
w.setVisible(checked)
self._reset_button.setVisible(checked)

def _reset_group(self):
for w in self._group_widgets:
if not isinstance(w, QLabel) and w.param is not None and w.param.default is not None:
w.value = w.param.default


class WorkflowParamsWidget(QWidget):
value_changed = pyqtSignal()

Expand All @@ -358,20 +404,20 @@ def __init__(self, params: list[CustomParam], parent: QWidget | None = None):
self._widgets: dict[str, CustomParamWidget] = {}

layout = QGridLayout(self)
layout.setContentsMargins(0, 0, 0, 0)
layout.setContentsMargins(0, 0, 2, 0)
layout.setColumnMinimumWidth(0, 10)
layout.setColumnMinimumWidth(2, 10)
layout.setColumnStretch(3, 1)
self.setLayout(layout)

params = sorted(params)
current_group: tuple[str, ExpanderButton | None, list[CustomParamWidget]] = ("", None, [])
current_group: tuple[str, GroupHeader | None, list[CustomParamWidget]] = ("", None, [])

for p in params:
group, expander, group_widgets = current_group
if p.group != group:
self._create_group(expander, group_widgets)
expander = ExpanderButton(p.group, self)
expander = GroupHeader(p.group, self)
group_widgets = []
current_group = (p.group, expander, group_widgets)
layout.addWidget(expander, layout.rowCount(), 0, 1, 4)
Expand All @@ -391,18 +437,9 @@ def __init__(self, params: list[CustomParam], parent: QWidget | None = None):
def _notify(self):
self.value_changed.emit()

def _create_group(self, expander: ExpanderButton | None, widgets: list[CustomParamWidget]):
def _create_group(self, expander: GroupHeader | None, widgets: list[CustomParamWidget]):
if expander is not None:
expander.setChecked(len(self._widgets) < 7)
expander.toggled.connect(self._show_group(widgets))
self._show_group(widgets)(expander.isChecked())

def _show_group(self, widgets: list[CustomParamWidget]):
def set_visible(checked: bool):
for w in widgets:
w.setVisible(checked)

return set_visible
expander.set_group_widgets(widgets, len(self._widgets) < 7)

@property
def value(self):
Expand Down

0 comments on commit 611f1e3

Please sign in to comment.