Skip to content

Commit

Permalink
Improved output types for the Resolutions node (#2937)
Browse files Browse the repository at this point in the history
  • Loading branch information
RunDevelopment authored Jun 4, 2024
1 parent cf65086 commit 9a50349
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 25 deletions.
26 changes: 20 additions & 6 deletions backend/src/nodes/properties/inputs/generic_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,11 @@ def __init__(
for variant in enum:
value = variant.value
assert isinstance(value, (int, str))
assert (
re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", variant.name) is not None
), f"Expected the name of {enum.__name__}.{variant.name} to be snake case."

name = split_snake_case(variant.name)
variant_type = f"{type_name}::{join_pascal_case(name)}"
option_label = option_labels.get(variant, join_space_case(name))
variant_type = EnumInput.get_variant_type(variant, type_name)
option_label = option_labels.get(
variant, join_space_case(split_snake_case(variant.name))
)
condition = conditions.get(variant)
if condition is not None:
condition = condition.to_json()
Expand Down Expand Up @@ -301,6 +299,22 @@ def __init__(

self.associated_type = enum

@staticmethod
def get_variant_type(variant: Enum, type_name: str | None = None) -> str:
"""
Returns the full type name of a variant of an enum.
"""

enum = variant.__class__
if type_name is None:
type_name = enum.__name__

assert (
re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", variant.name) is not None
), f"Expected the name of {enum.__name__}.{variant.name} to be snake case."

return f"{type_name}::{join_pascal_case(split_snake_case(variant.name))}"

def enforce(self, value: object) -> E:
value = super().enforce(value)
return self.enum(value)
Expand Down
51 changes: 32 additions & 19 deletions backend/src/packages/chaiNNer_standard/utility/value/resolutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from enum import Enum

import navi
from nodes.groups import if_enum_group
from nodes.properties.inputs import EnumInput, NumberInput
from nodes.properties.outputs import NumberOutput
Expand Down Expand Up @@ -96,29 +97,41 @@ class ResList(Enum):
ResList.SQ8192: "Square 8192x8192",
ResList.CUSTOM: "Custom Resolution",
},
),
).with_id(0),
if_enum_group(0, ResList.CUSTOM)(
NumberInput(
"Width",
min=1,
max=None,
default=1920,
unit="px",
has_handle=False,
),
NumberInput(
"Height",
min=1,
max=None,
default=1080,
unit="px",
has_handle=False,
),
NumberInput("Width", min=1, default=1920, unit="px"),
NumberInput("Height", min=1, default=1080, unit="px"),
),
],
outputs=[
NumberOutput("Width", output_type="int(1..)"),
NumberOutput("Height", output_type="int(1..)"),
NumberOutput(
"Width",
output_type=navi.match(
"Input0",
(EnumInput.get_variant_type(ResList.CUSTOM), None, "Input1"),
default=navi.match(
"Input0",
*(
(EnumInput.get_variant_type(v), None, w)
for v, (w, _) in RESOLUTIONS.items()
),
),
),
),
NumberOutput(
"Height",
output_type=navi.match(
"Input0",
(EnumInput.get_variant_type(ResList.CUSTOM), None, "Input2"),
default=navi.match(
"Input0",
*(
(EnumInput.get_variant_type(v), None, h)
for v, (_, h) in RESOLUTIONS.items()
),
),
),
),
],
)
def resolutions_node(
Expand Down

0 comments on commit 9a50349

Please sign in to comment.