diff --git a/package/PartSegCore/algorithm_describe_base.py b/package/PartSegCore/algorithm_describe_base.py index 5c859f8c0..c47a15cf0 100644 --- a/package/PartSegCore/algorithm_describe_base.py +++ b/package/PartSegCore/algorithm_describe_base.py @@ -313,18 +313,18 @@ def check_function(ob, function_name, is_class): def __setitem__(self, key: str, value: AlgorithmType): if not issubclass(value, AlgorithmDescribeBase): raise ValueError( - f"Class {value} need to inherit from {AlgorithmDescribeBase.__module__}.AlgorithmDescribeBase" + f"Class {value} need to be subclass of {AlgorithmDescribeBase.__module__}.AlgorithmDescribeBase" ) self.check_function(value, "get_name", True) self.check_function(value, "get_fields", True) try: val = value.get_name() - except NotImplementedError: - raise ValueError(f"Method get_name of class {value} need to be implemented") from None + except (NotImplementedError, AttributeError): + raise ValueError(f"Class {value} need to implement classmethod 'get_name'") from None if not isinstance(val, str): raise ValueError(f"Function get_name of class {value} need return string not {type(val)}") if key != val: - raise ValueError("Object need to be registered under name returned by gey_name function") + raise ValueError("Object need to be registered under name returned by get_name function") if not value.__new_style__: try: val = value.get_fields() diff --git a/package/tests/test_PartSegCore/test_algorithm_describe_base.py b/package/tests/test_PartSegCore/test_algorithm_describe_base.py index e7230726b..528690108 100644 --- a/package/tests/test_PartSegCore/test_algorithm_describe_base.py +++ b/package/tests/test_PartSegCore/test_algorithm_describe_base.py @@ -130,11 +130,14 @@ def get_fields(cls): with pytest.raises(ValueError, match="Class .* need to implement classmethod 'get_name'"): TestSelection.register(Alg2) + with pytest.raises(ValueError, match="Class .* need to implement classmethod 'get_name'"): + TestSelection.__register__["test1"] = Alg2 + with pytest.raises(ValueError, match="Function get_name of class .* need return string not .*int"): TestSelection.register(Alg3) -def test_register_name_collsion(): +def test_register_name_collision(): class TestSelection(AlgorithmSelection): pass @@ -154,7 +157,7 @@ def get_name(cls): @classmethod def get_fields(cls): - return [] + return [] # pragma: no cover class Alg3(AlgorithmDescribeBase): @classmethod @@ -177,6 +180,80 @@ def get_fields(cls): TestSelection.register(Alg3, old_names=["0"]) +def test_register_not_subclass(): + class TestSelection(AlgorithmSelection): + pass + + class Alg1: + @classmethod + def get_name(cls): + return "1" + + @classmethod + def get_fields(cls): + return [] + + with pytest.raises(ValueError, match="Class .* need to be subclass of .*AlgorithmDescribeBase"): + TestSelection.register(Alg1) + + +def test_register_validate_name_assignment(): + class TestSelection(AlgorithmSelection): + pass + + class Alg1(AlgorithmDescribeBase): + @classmethod + def get_name(cls): + return "1" + + @classmethod + def get_fields(cls): + return [] + + class Alg2(Alg1): + @classmethod + def get_name(cls): + return 2 + + with pytest.raises(ValueError, match="need return string"): + TestSelection.__register__["1"] = Alg2 + + with pytest.raises(ValueError, match="under name returned by get_name function"): + TestSelection.__register__["2"] = Alg1 + + +def test_register_get_fields_validity(): + class TestSelection(AlgorithmSelection): + pass + + class Alg1(AlgorithmDescribeBase): + @classmethod + def get_name(cls): + return "1" + + @classmethod + def get_fields(cls): + raise NotImplementedError + + class Alg2(Alg1): + @classmethod + def get_fields(cls): + return () + + with pytest.raises(ValueError, match="need to be implemented"): + TestSelection.register(Alg1) + with pytest.raises(ValueError, match="need return list not"): + TestSelection.register(Alg2) + + +def test_register_no_default_value(): + class TestSelection(AlgorithmSelection): + pass + + with pytest.raises(ValueError, match="Register does not contain any algorithm"): + TestSelection.get_default() + + def test_algorithm_selection_convert_subclass(clean_register): class TestSelection(AlgorithmSelection): pass