From 9a503495ad56fc2f9a95ea9384cba893920e0ba3 Mon Sep 17 00:00:00 2001 From: Michael Schmidt Date: Tue, 4 Jun 2024 15:58:30 +0200 Subject: [PATCH] Improved output types for the Resolutions node (#2937) --- .../nodes/properties/inputs/generic_inputs.py | 26 +++++++--- .../utility/value/resolutions.py | 51 ++++++++++++------- 2 files changed, 52 insertions(+), 25 deletions(-) diff --git a/backend/src/nodes/properties/inputs/generic_inputs.py b/backend/src/nodes/properties/inputs/generic_inputs.py index d172fbff5..a42d0263e 100644 --- a/backend/src/nodes/properties/inputs/generic_inputs.py +++ b/backend/src/nodes/properties/inputs/generic_inputs.py @@ -258,13 +258,11 @@ def __init__( for variant in enum: value = variant.value assert isinstance(value, (int, str)) - assert ( - re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", variant.name) is not None - ), f"Expected the name of {enum.__name__}.{variant.name} to be snake case." - name = split_snake_case(variant.name) - variant_type = f"{type_name}::{join_pascal_case(name)}" - option_label = option_labels.get(variant, join_space_case(name)) + variant_type = EnumInput.get_variant_type(variant, type_name) + option_label = option_labels.get( + variant, join_space_case(split_snake_case(variant.name)) + ) condition = conditions.get(variant) if condition is not None: condition = condition.to_json() @@ -301,6 +299,22 @@ def __init__( self.associated_type = enum + @staticmethod + def get_variant_type(variant: Enum, type_name: str | None = None) -> str: + """ + Returns the full type name of a variant of an enum. + """ + + enum = variant.__class__ + if type_name is None: + type_name = enum.__name__ + + assert ( + re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", variant.name) is not None + ), f"Expected the name of {enum.__name__}.{variant.name} to be snake case." + + return f"{type_name}::{join_pascal_case(split_snake_case(variant.name))}" + def enforce(self, value: object) -> E: value = super().enforce(value) return self.enum(value) diff --git a/backend/src/packages/chaiNNer_standard/utility/value/resolutions.py b/backend/src/packages/chaiNNer_standard/utility/value/resolutions.py index cd708de17..460678f00 100644 --- a/backend/src/packages/chaiNNer_standard/utility/value/resolutions.py +++ b/backend/src/packages/chaiNNer_standard/utility/value/resolutions.py @@ -2,6 +2,7 @@ from enum import Enum +import navi from nodes.groups import if_enum_group from nodes.properties.inputs import EnumInput, NumberInput from nodes.properties.outputs import NumberOutput @@ -96,29 +97,41 @@ class ResList(Enum): ResList.SQ8192: "Square 8192x8192", ResList.CUSTOM: "Custom Resolution", }, - ), + ).with_id(0), if_enum_group(0, ResList.CUSTOM)( - NumberInput( - "Width", - min=1, - max=None, - default=1920, - unit="px", - has_handle=False, - ), - NumberInput( - "Height", - min=1, - max=None, - default=1080, - unit="px", - has_handle=False, - ), + NumberInput("Width", min=1, default=1920, unit="px"), + NumberInput("Height", min=1, default=1080, unit="px"), ), ], outputs=[ - NumberOutput("Width", output_type="int(1..)"), - NumberOutput("Height", output_type="int(1..)"), + NumberOutput( + "Width", + output_type=navi.match( + "Input0", + (EnumInput.get_variant_type(ResList.CUSTOM), None, "Input1"), + default=navi.match( + "Input0", + *( + (EnumInput.get_variant_type(v), None, w) + for v, (w, _) in RESOLUTIONS.items() + ), + ), + ), + ), + NumberOutput( + "Height", + output_type=navi.match( + "Input0", + (EnumInput.get_variant_type(ResList.CUSTOM), None, "Input2"), + default=navi.match( + "Input0", + *( + (EnumInput.get_variant_type(v), None, h) + for v, (_, h) in RESOLUTIONS.items() + ), + ), + ), + ), ], ) def resolutions_node(