From 36131f0a3b70ff0bde5fdc098b31abdc5de6f878 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Wed, 7 Aug 2024 22:22:50 -0700 Subject: [PATCH] Allow kwargs in VALIDATE_INPUTS functions When kwargs are used, validation is skipped for all inputs as if they had been mentioned explicitly. --- execution.py | 11 +++++--- tests/inference/test_execution.py | 16 +++++++++++ .../testing-pack/specific_tests.py | 27 +++++++++++++++++++ 3 files changed, 50 insertions(+), 4 deletions(-) diff --git a/execution.py b/execution.py index 30c0b3834b6..3bffae8e94f 100644 --- a/execution.py +++ b/execution.py @@ -530,8 +530,11 @@ def validate_inputs(prompt, item, validated): valid = True validate_function_inputs = [] + validate_has_kwargs = False if hasattr(obj_class, "VALIDATE_INPUTS"): - validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args + argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS) + validate_function_inputs = argspec.args + validate_has_kwargs = argspec.varkw is not None received_types = {} for x in valid_inputs: @@ -641,7 +644,7 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if x not in validate_function_inputs: + if x not in validate_function_inputs and not validate_has_kwargs: if "min" in extra_info and val < extra_info["min"]: error = { "type": "value_smaller_than_min", @@ -695,11 +698,11 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if len(validate_function_inputs) > 0: + if len(validate_function_inputs) > 0 or validate_has_kwargs: input_data_all, _ = get_input_data(inputs, obj_class, unique_id) input_filtered = {} for x in input_data_all: - if x in validate_function_inputs: + if x in validate_function_inputs or validate_has_kwargs: input_filtered[x] = input_data_all[x] if 'input_types' in validate_function_inputs: input_filtered['input_types'] = [received_types] diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index b9dec659831..8616ca1e8e8 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -312,6 +312,22 @@ def test_validation_error_edge4(self, test_type, test_value, expect_error, clien else: client.run(g) + @pytest.mark.parametrize("test_value1, test_value2, expect_error", [ + (0.0, 0.5, False), + (0.0, 5.0, False), + (0.0, 7.0, True) + ]) + def test_validation_error_kwargs(self, test_value1, test_value2, expect_error, client: ComfyClient, builder: GraphBuilder): + g = builder + validation5 = g.node("TestCustomValidation5", input1=test_value1, input2=test_value2) + g.node("SaveImage", images=validation5.out(0)) + + if expect_error: + with pytest.raises(urllib.error.HTTPError): + client.run(g) + else: + client.run(g) + def test_cycle_error(self, client: ComfyClient, builder: GraphBuilder): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/inference/testing_nodes/testing-pack/specific_tests.py index 8e8ce32ced8..5884cae0c5a 100644 --- a/tests/inference/testing_nodes/testing-pack/specific_tests.py +++ b/tests/inference/testing_nodes/testing-pack/specific_tests.py @@ -221,6 +221,31 @@ def VALIDATE_INPUTS(cls, input1, input2): return True +class TestCustomValidation5: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("FLOAT", {"min": 0.0, "max": 1.0}), + "input2": ("FLOAT", {"min": 0.0, "max": 1.0}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation5" + + CATEGORY = "Testing/Nodes" + + def custom_validation5(self, input1, input2): + value = input1 * input2 + return (torch.ones([1, 512, 512, 3]) * value,) + + @classmethod + def VALIDATE_INPUTS(cls, **kwargs): + if kwargs['input2'] == 7.0: + return "7s are not allowed. I've never liked 7s." + return True + class TestDynamicDependencyCycle: @classmethod def INPUT_TYPES(cls): @@ -291,6 +316,7 @@ def mixed_expansion_returns(self, input1): "TestCustomValidation2": TestCustomValidation2, "TestCustomValidation3": TestCustomValidation3, "TestCustomValidation4": TestCustomValidation4, + "TestCustomValidation5": TestCustomValidation5, "TestDynamicDependencyCycle": TestDynamicDependencyCycle, "TestMixedExpansionReturns": TestMixedExpansionReturns, } @@ -303,6 +329,7 @@ def mixed_expansion_returns(self, input1): "TestCustomValidation2": "Custom Validation 2", "TestCustomValidation3": "Custom Validation 3", "TestCustomValidation4": "Custom Validation 4", + "TestCustomValidation5": "Custom Validation 5", "TestDynamicDependencyCycle": "Dynamic Dependency Cycle", "TestMixedExpansionReturns": "Mixed Expansion Returns", }