Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: Improve test suite for PartSegCore #1077

Merged
merged 16 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 16 additions & 25 deletions package/PartSegCore/algorithm_describe_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import typing
import warnings
from abc import ABC, ABCMeta, abstractmethod
from enum import Enum
from functools import wraps

from local_migrator import REGISTER, class_to_str
Comment on lines 3 to 8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📝 NOTE
This review was outside the diff hunks, and no overlapping diff hunk was found. Original lines [1-1]

The file contains complex classes with multiple methods and properties. Consider adding comprehensive docstrings to public methods and classes to improve maintainability and understandability.

+ Add comprehensive docstrings to public methods and classes.

📝 NOTE
This review was outside the diff hunks, and no overlapping diff hunk was found. Original lines [1-1]

The dynamic creation of models using create_model in _GetDescriptionClass is an advanced feature. Ensure that this dynamic behavior is well-documented and tested to avoid confusion and potential bugs.

+ Improve documentation and testing for dynamic model creation in _GetDescriptionClass.

📝 NOTE
This review was outside the diff hunks, and no overlapping diff hunk was found. Original lines [1-1]

The implementation of custom metaclasses, such as AlgorithmDescribeBaseMeta and AddRegisterMeta, introduces complexity. Verify that their use is justified and consider simplifying the design if possible.

+ Evaluate the necessity of custom metaclasses and consider simplifying the design.

Expand Down Expand Up @@ -211,6 +210,8 @@ def get_fields_from_algorithm(ald_desc: AlgorithmDescribeBase) -> typing.List[ty


def is_static(fun):
if fun is None:
return False
args = inspect.getfullargspec(fun).args
return True if len(args) == 0 else args[0] != "self"

Expand Down Expand Up @@ -256,6 +257,9 @@ def __eq__(self, other):
and self.suggested_base_class == other.suggested_base_class
)

def __ne__(self, other):
return not self.__eq__(other)

def __getitem__(self, item) -> AlgorithmType:
# FIXME add better strategy to get proper class when there is conflict of names
try:
Expand All @@ -278,11 +282,11 @@ def register(
self.check_function(value, "get_name", True)
try:
name = value.get_name()
except NotImplementedError:
raise ValueError(f"Class {value} need to implement get_name class method") from None
except (NotImplementedError, AttributeError):
raise ValueError(f"Class {value} need to implement classmethod 'get_name'") from None
if name in self and not replace:
raise ValueError(
f"Object {self[name]} with this name: {name} already exist and register is not in replace mode"
f"Object {self[name]} with this name: '{name}' already exist and register is not in replace mode"
)
if not isinstance(name, str):
raise ValueError(f"Function get_name of class {value} need return string not {type(name)}")
Expand All @@ -292,8 +296,8 @@ def register(
for old_name in old_names:
if old_name in self._old_mapping and not replace:
raise ValueError(
f"Old value mapping for name {old_name} already registered."
f" Currently pointing to {self._old_mapping[name]}"
f"Old value mapping for name '{old_name}' already registered."
f" Currently pointing to {self._old_mapping[old_name]}"
)
self._old_mapping[old_name] = name
return value
Expand All @@ -304,23 +308,23 @@ def check_function(ob, function_name, is_class):
if not is_class and not inspect.isfunction(fun):
raise ValueError(f"Class {ob} need to define method {function_name}")
if is_class and not inspect.ismethod(fun) and not is_static(fun):
raise ValueError(f"Class {ob} need to define classmethod {function_name}")
raise ValueError(f"Class {ob} need to define classmethod '{function_name}'")

def __setitem__(self, key: str, value: AlgorithmType):
if not issubclass(value, AlgorithmDescribeBase):
raise ValueError(
f"Class {value} need to inherit from {AlgorithmDescribeBase.__module__}.AlgorithmDescribeBase"
f"Class {value} need to be subclass of {AlgorithmDescribeBase.__module__}.AlgorithmDescribeBase"
)
self.check_function(value, "get_name", True)
self.check_function(value, "get_fields", True)
try:
val = value.get_name()
except NotImplementedError:
raise ValueError(f"Method get_name of class {value} need to be implemented") from None
except (NotImplementedError, AttributeError):
raise ValueError(f"Class {value} need to implement classmethod 'get_name'") from None
if not isinstance(val, str):
raise ValueError(f"Function get_name of class {value} need return string not {type(val)}")
if key != val:
raise ValueError("Object need to be registered under name returned by gey_name function")
raise ValueError("Object need to be registered under name returned by get_name function")
if not value.__new_style__:
try:
val = value.get_fields()
Expand Down Expand Up @@ -415,7 +419,7 @@ def register(
:param replace: replace existing algorithm, be patient with
:param old_names: list of old names for registered class
"""
return cls.__register__.register(value, replace, old_names)
return cls.__register__.register(value, replace=replace, old_names=old_names)

@classmethod
def get_default(cls):
Expand Down Expand Up @@ -535,19 +539,6 @@ def _pretty_print(
res += "\n"
return res[:-1]

@classmethod
def print_dict(cls, dkt, indent=0, name: str = "") -> str:
if isinstance(dkt, Enum):
return dkt.name
if not isinstance(dkt, typing.MutableMapping):
# FIXME update in future method of proper printing channel number
if name.startswith("channel") and isinstance(dkt, int):
return str(dkt + 1)
return str(dkt)
return "\n" + "\n".join(
" " * indent + f"{k.replace('_', ' ')}: {cls.print_dict(v, indent + 2, k)}" for k, v in dkt.items()
)

def __eq__(self, other):
return (
isinstance(other, self.__class__)
Expand Down
5 changes: 0 additions & 5 deletions package/PartSegCore/analysis/load_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,6 @@ def get_short_name(cls):
def number_of_files(cls):
return 2

@classmethod
def correct_files_order(cls, paths):
name1, name2 = (os.path.basename(os.path.splitext(x)[0]) for x in paths)
return [name1, name2] if name2.endswith("_mask") else paths

@classmethod
def load(
cls,
Expand Down
6 changes: 4 additions & 2 deletions package/PartSegCore/autofit.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from math import acos, pi, sqrt

import numpy as np
Expand Down Expand Up @@ -41,14 +42,15 @@ def find_density_orientation(img, voxel_size, cutoff=1):
return vectors, w_n


def get_rotation_parameters(isometric_matrix):
def get_rotation_parameters(isometric_matrix): # pragma: no cover
"""
If 3x3 isometric matrix is not rotation matrix
function transform it into rotation matrix
then calculate rotation axis and angel
:param isometric_matrix: 3x3 np.ndarray with determinant equal 1 or -1
:return: rotation_matrix, rotation axis, rotation angel
"""
warnings.warn("This function is deprecated", FutureWarning, stacklevel=2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding a deprecation warning to get_rotation_parameters is a good practice when phasing out old functionality. Ensure that the deprecation timeline and alternative methods (if any) are communicated to the users.

if np.linalg.det(isometric_matrix) < 0:
isometric_matrix = np.dot(np.diag([-1, 1, 1]), isometric_matrix)
angel = acos((np.sum(np.diag(isometric_matrix)) - 1) / 2) * 180 / pi
Expand Down Expand Up @@ -78,7 +80,7 @@ def density_mass_center(image, voxel_size=(1.0, 1.0, 1.0)):

if len(voxel_size) != image.ndim:
if len(voxel_size) != len(iter_dim):
raise ValueError("Cannot fit voxel size to array")
raise ValueError("Cannot fit voxel size to array") # pragma: no cover
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excluding the exception handling from test coverage with # pragma: no cover should be justified. If this error condition is unlikely or difficult to simulate in a test environment, consider adding a comment explaining the reasoning.

voxel_size_array = [0] * image.ndim
for i, item in enumerate(iter_dim):
voxel_size_array[item] = voxel_size[i]
Expand Down
79 changes: 32 additions & 47 deletions package/PartSegCore/io_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import json
import os
import re
Expand Down Expand Up @@ -42,7 +43,7 @@ def check_segmentation_type(tar_file: TarFile) -> SegmentationType:
return SegmentationType.analysis
if "metadata.json" in names:
return SegmentationType.mask
raise WrongFileTypeException
raise WrongFileTypeException # pragma: no cover
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (llm): Consider providing a specific error message to WrongFileTypeException to enhance error traceability and debugging experience.



def get_tarinfo(name, buffer: typing.Union[BytesIO, StringIO]):
Expand All @@ -56,7 +57,23 @@ def get_tarinfo(name, buffer: typing.Union[BytesIO, StringIO]):
return tar_info


class SaveBase(AlgorithmDescribeBase, ABC):
class _IOBase(AlgorithmDescribeBase, ABC):
@classmethod
def get_name_with_suffix(cls):
return cls.get_name()

@classmethod
def get_extensions(cls) -> typing.List[str]:
match = re.match(r".*\((.*)\)", cls.get_name())
if match is None:
raise ValueError(f"No extensions found in {cls.get_name()}")
extensions = match[1].split(" ")
if not all(x.startswith("*.") for x in extensions):
raise ValueError(f"Error with parsing extensions in {cls.get_name()}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (llm): The error message for parsing extensions is clear, but consider including the incorrect extension format in the message to aid in debugging.

return [x[1:] for x in extensions]
Comment on lines +60 to +73
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The methods get_name_with_suffix and get_extensions in _IOBase introduce new functionality for handling file extensions. Ensure that the regular expression used in get_extensions is robust and correctly handles all expected formats.



class SaveBase(_IOBase, ABC):
need_functions: typing.ClassVar[typing.List[str]] = [
"save",
"get_short_name",
Expand Down Expand Up @@ -91,15 +108,6 @@ def save(
"""
raise NotImplementedError

@classmethod
def get_name_with_suffix(cls):
return cls.get_name()

@classmethod
def get_default_extension(cls):
match = re.search(r"\(\*(\.\w+)", cls.get_name_with_suffix())
return match[1] if match else ""

@classmethod
def need_segmentation(cls):
return True
Expand All @@ -109,23 +117,17 @@ def need_mask(cls):
return False

@classmethod
def get_extensions(cls) -> typing.List[str]:
match = re.match(r".*\((.*)\)", cls.get_name())
if match is None:
raise ValueError(f"No extensions found in {cls.get_name()}")
extensions = match[1].split(" ")
if not all(x.startswith("*.") for x in extensions):
raise ValueError(f"Error with parsing extensions in {cls.get_name()}")
return [x[1:] for x in extensions]
def get_default_extension(cls):
match = re.search(r"\(\*(\.\w+)", cls.get_name_with_suffix())
return match[1] if match else ""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question (llm): The use of regex for extracting the default extension is efficient. However, ensure that the regex pattern is robust enough to handle all expected formats of get_name_with_suffix.

Comment on lines +120 to +122
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method get_default_extension in SaveBase uses a regular expression to extract the default file extension. Verify that this method correctly handles edge cases and document any assumptions or limitations.



class LoadBase(AlgorithmDescribeBase, ABC):
class LoadBase(_IOBase, ABC):
need_functions: typing.ClassVar[typing.List[str]] = [
"load",
"get_short_name",
"get_name_with_suffix",
"number_of_files",
"correct_files_order",
"get_next_file",
"partial",
]
Expand Down Expand Up @@ -155,20 +157,6 @@ def load(
"""
raise NotImplementedError

@classmethod
def get_name_with_suffix(cls):
return cls.get_name()

@classmethod
def get_extensions(cls) -> typing.List[str]:
match = re.match(r".*\((.*)\)", cls.get_name())
if match is None:
raise ValueError(f"No extensions found in {cls.get_name()}")
extensions = match[1].split(" ")
if not all(x.startswith("*.") for x in extensions):
raise ValueError(f"Error with parsing extensions in {cls.get_name()}")
return [x[1:] for x in extensions]

@classmethod
def get_fields(cls):
return []
Expand All @@ -178,10 +166,6 @@ def number_of_files(cls):
"""Number of files required for load method"""
return 1

@classmethod
def correct_files_order(cls, paths):
return paths

@classmethod
def get_next_file(cls, file_paths: typing.List[str]):
return file_paths[0]
Expand All @@ -192,19 +176,19 @@ def partial(cls):
return False


def load_metadata_base(data: typing.Union[str, Path]):
def load_metadata_base(data: typing.Union[str, Path, typing.TextIO]):
try:
if isinstance(data, typing.TextIO):
if isinstance(data, io.TextIOBase):
decoded_data = json.load(data, object_hook=partseg_object_hook)
elif os.path.exists(data):
with open(data, encoding="utf-8") as ff:
decoded_data = json.load(ff, object_hook=partseg_object_hook)
else:
decoded_data = json.loads(data, object_hook=partseg_object_hook)
except ValueError as e:
except ValueError as e: # pragma: no cover
try:
decoded_data = json.loads(str(data), object_hook=partseg_object_hook)
except Exception: # pragma: no cover
except Exception:
raise e # noqa: B904

return decoded_data
Expand Down Expand Up @@ -299,7 +283,7 @@ def open_tar_file(
tar_file = TarFile.open(fileobj=file_data)
file_path = ""
else:
raise ValueError(f"wrong type of file_ argument: {type(file_data)}")
raise ValueError(f"wrong type of file_data argument: {type(file_data)}")
return tar_file, file_path


Expand All @@ -325,13 +309,14 @@ def save(
cls,
save_location: typing.Union[str, BytesIO, Path],
project_info,
parameters: dict,
parameters: typing.Optional[dict] = None,
range_changed=None,
step_changed=None,
):
if project_info.image.mask is None and project_info.mask is not None:
ImageWriter.save_mask(project_info.image.substitute(mask=project_info.mask), save_location)
ImageWriter.save_mask(project_info.image, save_location)
else:
ImageWriter.save_mask(project_info.image, save_location)


def tar_to_buff(tar_file, member_name) -> BytesIO:
Expand All @@ -352,7 +337,7 @@ def save(
cls,
save_location: typing.Union[str, BytesIO, Path],
project_info,
parameters: dict,
parameters: typing.Optional[dict] = None,
range_changed=None,
step_changed=None,
):
Expand Down
2 changes: 1 addition & 1 deletion package/PartSegCore/mask/io_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def get_raw_mask_copy(self):
return MaskProjectTuple(file_path=self.file_path, image=self.image.substitute(), mask=self.mask)

@property
def roi(self):
def roi(self): # pragma: no cover
warnings.warn("roi is deprecated", DeprecationWarning, stacklevel=2)
return self.roi_info.roi

Expand Down
4 changes: 2 additions & 2 deletions package/PartSegCore/napari_plugins/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ def partseg_loader(loader: typing.Type[LoadBase], path: str):

try:
project_info = loader.load(load_locations)
except WrongFileTypeException:
except WrongFileTypeException: # pragma: no cover
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use of # pragma: no cover to exclude exception handling from test coverage should be justified. If this exception is unlikely or difficult to trigger in tests, explain why in a comment.

return None

if isinstance(project_info, (ProjectTuple, MaskProjectTuple)):
return project_to_layers(project_info)
return None
return None # pragma: no cover
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return statement at the end of the function is excluded from test coverage with # pragma: no cover. Ensure that this is intentional and justified, as it might indicate untested logic paths in the function.

Loading
Loading