Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: Improve test suite for PartSegCore #1077

Merged
merged 16 commits into from
Feb 14, 2024
Merged
14 changes: 0 additions & 14 deletions package/PartSegCore/algorithm_describe_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import typing
import warnings
from abc import ABC, ABCMeta, abstractmethod
from enum import Enum
from functools import wraps

from local_migrator import REGISTER, class_to_str
Comment on lines 3 to 8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📝 NOTE
This review was outside the diff hunks, and no overlapping diff hunk was found. Original lines [1-1]

The file contains complex classes with multiple methods and properties. Consider adding comprehensive docstrings to public methods and classes to improve maintainability and understandability.

+ Add comprehensive docstrings to public methods and classes.

📝 NOTE
This review was outside the diff hunks, and no overlapping diff hunk was found. Original lines [1-1]

The dynamic creation of models using create_model in _GetDescriptionClass is an advanced feature. Ensure that this dynamic behavior is well-documented and tested to avoid confusion and potential bugs.

+ Improve documentation and testing for dynamic model creation in _GetDescriptionClass.

📝 NOTE
This review was outside the diff hunks, and no overlapping diff hunk was found. Original lines [1-1]

The implementation of custom metaclasses, such as AlgorithmDescribeBaseMeta and AddRegisterMeta, introduces complexity. Verify that their use is justified and consider simplifying the design if possible.

+ Evaluate the necessity of custom metaclasses and consider simplifying the design.

Expand Down Expand Up @@ -535,19 +534,6 @@ def _pretty_print(
res += "\n"
return res[:-1]

@classmethod
def print_dict(cls, dkt, indent=0, name: str = "") -> str:
if isinstance(dkt, Enum):
return dkt.name
if not isinstance(dkt, typing.MutableMapping):
# FIXME update in future method of proper printing channel number
if name.startswith("channel") and isinstance(dkt, int):
return str(dkt + 1)
return str(dkt)
return "\n" + "\n".join(
" " * indent + f"{k.replace('_', ' ')}: {cls.print_dict(v, indent + 2, k)}" for k, v in dkt.items()
)

def __eq__(self, other):
return (
isinstance(other, self.__class__)
Expand Down
61 changes: 25 additions & 36 deletions package/PartSegCore/io_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import json
import os
import re
Expand Down Expand Up @@ -42,7 +43,7 @@ def check_segmentation_type(tar_file: TarFile) -> SegmentationType:
return SegmentationType.analysis
if "metadata.json" in names:
return SegmentationType.mask
raise WrongFileTypeException
raise WrongFileTypeException # pragma: no cover
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (llm): Consider providing a specific error message to WrongFileTypeException to enhance error traceability and debugging experience.



def get_tarinfo(name, buffer: typing.Union[BytesIO, StringIO]):
Expand All @@ -56,7 +57,23 @@ def get_tarinfo(name, buffer: typing.Union[BytesIO, StringIO]):
return tar_info


class SaveBase(AlgorithmDescribeBase, ABC):
class _IOBase(AlgorithmDescribeBase, ABC):
@classmethod
def get_name_with_suffix(cls):
return cls.get_name()

@classmethod
def get_extensions(cls) -> typing.List[str]:
match = re.match(r".*\((.*)\)", cls.get_name())
if match is None:
raise ValueError(f"No extensions found in {cls.get_name()}")
extensions = match[1].split(" ")
if not all(x.startswith("*.") for x in extensions):
raise ValueError(f"Error with parsing extensions in {cls.get_name()}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (llm): The error message for parsing extensions is clear, but consider including the incorrect extension format in the message to aid in debugging.

return [x[1:] for x in extensions]
Comment on lines +60 to +73
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The methods get_name_with_suffix and get_extensions in _IOBase introduce new functionality for handling file extensions. Ensure that the regular expression used in get_extensions is robust and correctly handles all expected formats.



class SaveBase(_IOBase, ABC):
need_functions: typing.ClassVar[typing.List[str]] = [
"save",
"get_short_name",
Expand Down Expand Up @@ -91,15 +108,6 @@ def save(
"""
raise NotImplementedError

@classmethod
def get_name_with_suffix(cls):
return cls.get_name()

@classmethod
def get_default_extension(cls):
match = re.search(r"\(\*(\.\w+)", cls.get_name_with_suffix())
return match[1] if match else ""

@classmethod
def need_segmentation(cls):
return True
Expand All @@ -109,17 +117,12 @@ def need_mask(cls):
return False

@classmethod
def get_extensions(cls) -> typing.List[str]:
match = re.match(r".*\((.*)\)", cls.get_name())
if match is None:
raise ValueError(f"No extensions found in {cls.get_name()}")
extensions = match[1].split(" ")
if not all(x.startswith("*.") for x in extensions):
raise ValueError(f"Error with parsing extensions in {cls.get_name()}")
return [x[1:] for x in extensions]
def get_default_extension(cls):
match = re.search(r"\(\*(\.\w+)", cls.get_name_with_suffix())
return match[1] if match else ""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question (llm): The use of regex for extracting the default extension is efficient. However, ensure that the regex pattern is robust enough to handle all expected formats of get_name_with_suffix.

Comment on lines +120 to +122
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method get_default_extension in SaveBase uses a regular expression to extract the default file extension. Verify that this method correctly handles edge cases and document any assumptions or limitations.



class LoadBase(AlgorithmDescribeBase, ABC):
class LoadBase(_IOBase, ABC):
need_functions: typing.ClassVar[typing.List[str]] = [
"load",
"get_short_name",
Expand Down Expand Up @@ -155,20 +158,6 @@ def load(
"""
raise NotImplementedError

@classmethod
def get_name_with_suffix(cls):
return cls.get_name()

@classmethod
def get_extensions(cls) -> typing.List[str]:
match = re.match(r".*\((.*)\)", cls.get_name())
if match is None:
raise ValueError(f"No extensions found in {cls.get_name()}")
extensions = match[1].split(" ")
if not all(x.startswith("*.") for x in extensions):
raise ValueError(f"Error with parsing extensions in {cls.get_name()}")
return [x[1:] for x in extensions]

@classmethod
def get_fields(cls):
return []
Expand All @@ -192,9 +181,9 @@ def partial(cls):
return False


def load_metadata_base(data: typing.Union[str, Path]):
def load_metadata_base(data: typing.Union[str, Path, typing.TextIO]):
try:
if isinstance(data, typing.TextIO):
if isinstance(data, io.TextIOBase):
decoded_data = json.load(data, object_hook=partseg_object_hook)
elif os.path.exists(data):
with open(data, encoding="utf-8") as ff:
Expand Down
34 changes: 32 additions & 2 deletions package/tests/test_PartSegCore/test_algorithm_describe_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,28 @@
from PartSegImage import Channel


def test_algorithm_property():
ap = AlgorithmProperty("test", "Test", 1)
assert ap.name == "test"
assert "user_name='Test'" in repr(ap)

Comment on lines +24 to +28
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure the test test_algorithm_property includes assertions for all relevant properties of AlgorithmProperty to fully validate its behavior.


def test_algorithm_property_warn():
with pytest.warns(DeprecationWarning, match="use value_type instead"):
ap = AlgorithmProperty("test", "Test", 1, property_type=int)
assert ap.value_type == int
Comment on lines +30 to +33
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test test_algorithm_property_warn correctly checks for a deprecation warning. Consider adding a comment explaining the context of deprecation for clarity.



def test_algorithm_property_no_kwargs():
with pytest.raises(ValueError, match="are not expected"):
AlgorithmProperty("test", "Test", 1, a=1)


def test_algorithm_property_list_exc():
with pytest.raises(ValueError, match="should be one of possible values"):
AlgorithmProperty("test", "Test", 1, possible_values=[2, 3], value_type=list)
Comment on lines +41 to +43
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In test_algorithm_property_list_exc, validate the exception message more precisely to ensure it matches the expected error scenario closely.



def test_get_description_class():
class SampleClass:
__test_class__ = _GetDescriptionClass()
Expand Down Expand Up @@ -77,6 +99,11 @@ def get_fields(cls) -> typing.List[typing.Union[AlgorithmProperty, str]]:

assert TestSelection["test1"] is Class1

assert TestSelection.__register__ != TestSelection2.__register__

ts = TestSelection(name="test1", values={})
assert ts.algorithm() == Class1

Comment on lines +102 to +106
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The assertion in lines 102 and 105 seems redundant since the uniqueness of __register__ between TestSelection and TestSelection2 is implicitly tested by the assertion on line 102. Consider removing line 105 to streamline the test.

-    ts = TestSelection(name="test1", values={})
-    assert ts.algorithm() == Class1

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
assert TestSelection.__register__ != TestSelection2.__register__
ts = TestSelection(name="test1", values={})
assert ts.algorithm() == Class1
assert TestSelection.__register__ != TestSelection2.__register__


def test_algorithm_selection_convert_subclass(clean_register):
class TestSelection(AlgorithmSelection):
Expand Down Expand Up @@ -362,8 +389,11 @@ def test_roi_extraction_profile(self):
ROIExtractionProfile("aaa", "aaa", {})

def test_pretty_print(self):
prof1 = ROIExtractionProfile(name="aaa", algorithm="aaa", values={})
assert f"{prof1}\n " == prof1.pretty_print(AnalysisAlgorithmSelection)

prof1 = ROIExtractionProfile(name="aaa", algorithm="Lower threshold", values={})
assert prof1.pretty_print(AnalysisAlgorithmSelection).startswith("ROI extraction profile name:")
prof1 = ROIExtractionProfile(name="", algorithm="Lower threshold", values={})
assert prof1.pretty_print(AnalysisAlgorithmSelection).startswith("ROI extraction profile\n")
Comment on lines +541 to +545
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test test_pretty_print effectively checks the output format of pretty_print but consider adding more detailed assertions to verify the content of the output, especially for prof2 where multiple lines are expected.

+    assert "Lower threshold" in prof2.pretty_print(AnalysisAlgorithmSelection)
+    assert "default values" in prof2.pretty_print(AnalysisAlgorithmSelection)

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
prof1 = ROIExtractionProfile(name="aaa", algorithm="Lower threshold", values={})
assert prof1.pretty_print(AnalysisAlgorithmSelection).startswith("ROI extraction profile name:")
prof1 = ROIExtractionProfile(name="", algorithm="Lower threshold", values={})
assert prof1.pretty_print(AnalysisAlgorithmSelection).startswith("ROI extraction profile\n")
prof1 = ROIExtractionProfile(name="aaa", algorithm="Lower threshold", values={})
assert prof1.pretty_print(AnalysisAlgorithmSelection).startswith("ROI extraction profile name:")
prof1 = ROIExtractionProfile(name="", algorithm="Lower threshold", values={})
assert prof1.pretty_print(AnalysisAlgorithmSelection).startswith("ROI extraction profile\n")
assert "Lower threshold" in prof2.pretty_print(AnalysisAlgorithmSelection)
assert "default values" in prof2.pretty_print(AnalysisAlgorithmSelection)

prof2 = ROIExtractionProfile(
name="aaa",
algorithm=LowerThresholdAlgorithm.get_name(),
Expand Down
80 changes: 78 additions & 2 deletions package/tests/test_PartSegCore/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
LoadBase,
LoadPlanExcel,
LoadPlanJson,
LoadPoints,
SaveBase,
SaveROIAsNumpy,
find_problematic_entries,
Expand Down Expand Up @@ -303,7 +304,7 @@ def test_modernize_0_9_2_3(self, bundle_test_dir):
data = load_metadata_base(file_path)

def test_update_name(self):
data = load_metadata_base(update_name_json)
data = load_metadata_base(UPDATE_NAME_JSON)
mp = data["problematic set"]
assert isinstance(mp, MeasurementProfile)
assert isinstance(mp.chosen_fields[0], MeasurementEntry)
Expand All @@ -327,6 +328,13 @@ def test_load_workflow(self, bundle_test_dir):
for entry in measurement_profile.chosen_fields:
assert entry.calculation_tree.name in MEASUREMENT_DICT

def test_load_workflow_from_text(self, bundle_test_dir):
with open(os.path.join(bundle_test_dir, "workflow.json")) as ff:
data_text = ff.read()
assert isinstance(load_metadata_base(data_text)["workflow"], CalculationPlan)
with open(os.path.join(bundle_test_dir, "workflow.json")) as ff:
isinstance(load_metadata_base(ff)["workflow"], CalculationPlan)
Comment on lines +340 to +345
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is loading same file twice, but in a different way. First by passing file content, second by passing TextIOWrapper. It validates if a function handles all input types properly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Czaki: Thank you for the clarification. I understand now that the method is intentionally designed to test the load_metadata_base function with different types of inputs, ensuring it handles both file content and TextIOWrapper correctly. This is a thoughtful approach to testing the function's flexibility.

Comment on lines +340 to +345
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method test_load_workflow_from_text reads a workflow JSON file twice but only asserts the type of loaded data in one instance. Ensure consistency in testing and consider removing redundant file reading or adding relevant assertions.

-        with open(os.path.join(bundle_test_dir, "workflow.json")) as ff:
-            isinstance(load_metadata_base(ff)["workflow"], CalculationPlan)

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
def test_load_workflow_from_text(self, bundle_test_dir):
with open(os.path.join(bundle_test_dir, "workflow.json")) as ff:
data_text = ff.read()
assert isinstance(load_metadata_base(data_text)["workflow"], CalculationPlan)
with open(os.path.join(bundle_test_dir, "workflow.json")) as ff:
isinstance(load_metadata_base(ff)["workflow"], CalculationPlan)
def test_load_workflow_from_text(self, bundle_test_dir):
with open(os.path.join(bundle_test_dir, "workflow.json")) as ff:
data_text = ff.read()
assert isinstance(load_metadata_base(data_text)["workflow"], CalculationPlan)



class TestSegmentationMask:
def test_load_seg(self, data_test_dir):
Expand Down Expand Up @@ -475,6 +483,32 @@ def test_load_project_with_history(self, tmp_path, stack_segmentation1, mask_pro
cmp_dict = {str(k): v for k, v in stack_segmentation1.roi_extraction_parameters.items()}
assert str(res.history[0].roi_extraction_parameters["parameters"]) == str(cmp_dict)

def test_mask_project_tuple(self):
mask = np.zeros((10, 10), dtype=np.uint8)
mask[1:-1, 1:-1] = 2
mask2 = np.copy(mask)
mask2[2:-2, 2:-2] = 4
roi_info = ROIInfo(mask2)
mask_prop = MaskProperty.simple_mask()
elem = HistoryElement.create(roi_info, mask, {}, mask_prop)
proj = MaskProjectTuple(
file_path="test_data.tiff",
image=Image(np.zeros((10, 10), dtype=np.uint8), (1, 1), "", axes_order="YX"),
mask=mask,
roi_info=roi_info,
history=[elem],
selected_components=[1, 2],
roi_extraction_parameters={},
)
assert not proj.is_raw()
assert proj.is_masked()
raw_proj = proj.get_raw_copy()
assert raw_proj.is_raw()
assert not raw_proj.is_masked()
raw_masked_proj = proj.get_raw_mask_copy()
assert raw_masked_proj.is_masked()
assert raw_masked_proj.is_raw()


class TestSaveFunctions:
@staticmethod
Expand Down Expand Up @@ -643,6 +677,7 @@ def test_json_parameters_mask_2(stack_segmentation1, tmp_path):

@pytest.mark.parametrize("file_path", (Path(__file__).parent.parent / "test_data" / "notebook").glob("*.json"))
def test_load_notebook_json(file_path):
"""Check if all notebook files can be loaded"""
load_metadata_base(file_path)


Expand Down Expand Up @@ -755,7 +790,39 @@ def test_load_image_for_batch(data_test_dir):
assert proj.mask is None


update_name_json = """
def test_save_base_extension_parse_no_ext():
class Save(SaveBase):
@classmethod
def get_name(cls) -> str:
return "Sample save"

with pytest.raises(ValueError, match="No extensions"):
Save.get_extensions()


def test_save_base_extension_parse_nmalformated_ext():
class Save(SaveBase):
@classmethod
def get_name(cls) -> str:
return "Sample save (a.txt)"

with pytest.raises(ValueError, match="Error with parsing"):
Save.get_extensions()
Czaki marked this conversation as resolved.
Show resolved Hide resolved


def test_load_points(tmp_path):
data_path = tmp_path / "sample.csv"
with data_path.open("w") as fp:
fp.write(POINTS_DATA)

res = LoadPoints.load([data_path])
assert res.file_path == data_path
assert res.points.shape == (5, 4)

assert LoadPoints.get_short_name() == "point_csv"


UPDATE_NAME_JSON = """
{"problematic set": {
"__MeasurementProfile__": true,
"name": "problematic set",
Expand Down Expand Up @@ -834,3 +901,12 @@ def test_load_image_for_batch(data_test_dir):
}
}
"""


POINTS_DATA = """index,axis-0,axis-1,axis-2,axis-3
0.0,0.0,22.0,227.90873370570543,65.07832834070409
1.0,0.0,22.0,91.94021739981048,276.7482348060973
2.0,0.0,22.0,194.83531082048773,380.3782931797794
3.0,0.0,22.0,391.07095327277943,268.6636259303818
4.0,0.0,22.0,152.20734354620717,256.1692217292996
"""
Loading