From d430f9ea12e91e6359944f0b07acea16d6be556e Mon Sep 17 00:00:00 2001 From: "yexuan.li" Date: Wed, 23 Jun 2021 15:52:26 +0800 Subject: [PATCH] feat: add "isclose" method for Vector, Box3D and LabeledBox3D --- tensorbay/geometry/box.py | 21 +++++++ tensorbay/geometry/tests/test_box.py | 10 ++-- tensorbay/geometry/tests/test_vector.py | 5 ++ tensorbay/geometry/transform.py | 26 +++++++- tensorbay/geometry/vector.py | 33 ++++++++++- tensorbay/label/tests/test_label_box.py | 16 ++--- tensorbay/utility/attr.py | 79 +++++++++++++++++++++++++ 7 files changed, 177 insertions(+), 13 deletions(-) diff --git a/tensorbay/geometry/box.py b/tensorbay/geometry/box.py index 849e47e84..1ddd38a8a 100644 --- a/tensorbay/geometry/box.py +++ b/tensorbay/geometry/box.py @@ -564,3 +564,24 @@ def dumps(self) -> Dict[str, Dict[str, float]]: contents = self._transform.dumps() contents["size"] = self.size.dumps() return contents + + def isclose(self, other: object, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool: + """Determine whether this 3D box is close to another in value. + + Arguments: + other: The other object to compare. + rel_tol: Maximum difference for being considered "close", relative to the + magnitude of the input values + abs_tol: Maximum difference for being considered "close", regardless of the + magnitude of the input values + + Returns: + A bool value indicating whether this vector is close to another. + + """ + if not isinstance(other, self.__class__): + return False + + return self._size.isclose( + other.size, rel_tol=rel_tol, abs_tol=abs_tol + ) and self._transform.isclose(other.transform, rel_tol=rel_tol, abs_tol=abs_tol) diff --git a/tensorbay/geometry/tests/test_box.py b/tensorbay/geometry/tests/test_box.py index 947603bc2..2140535db 100644 --- a/tensorbay/geometry/tests/test_box.py +++ b/tensorbay/geometry/tests/test_box.py @@ -116,10 +116,12 @@ def test_rmul(self): assert box3d.__rmul__(transform) == Box3D( size=(1, 1, 1), translation=[2, 0, 0], rotation=quaternion(-1, 0, 0, 0) ) - assert box3d.__rmul__(quaternion_1) == Box3D( - size=(1, 1, 1), - translation=[1.7999999999999996, 2, 2.6], - rotation=quaternion(-2, 1, 4, -3), + assert box3d.__rmul__(quaternion_1).isclose( + Box3D( + size=(1, 1, 1), + translation=[1.7999999999999996, 2, 2.6], + rotation=quaternion(-2, 1, 4, -3), + ) ) assert box3d.__rmul__(1) == NotImplemented diff --git a/tensorbay/geometry/tests/test_vector.py b/tensorbay/geometry/tests/test_vector.py index 5f3480e49..b29f7027d 100644 --- a/tensorbay/geometry/tests/test_vector.py +++ b/tensorbay/geometry/tests/test_vector.py @@ -125,6 +125,11 @@ def test_abs(self): assert abs(Vector(1, 1)) == 1.4142135623730951 assert abs(Vector(1, 1, 1)) == 1.7320508075688772 + def test_isclose(self): + assert Vector(1, 2).isclose(Vector2D(1.000000000001, 2)) + assert Vector(1, 2, 3).isclose(Vector3D(1.000000000001, 2, 2.999999999996)) + assert not Vector(1, 2, 3).isclose(Vector3D(1.100000000001, 2, 2.999999999996)) + def test_repr_head(self): vector = Vector(1, 2) assert vector._repr_head() == "Vector2D(1, 2)" diff --git a/tensorbay/geometry/transform.py b/tensorbay/geometry/transform.py index 1d08b5aa7..2deac6cdb 100644 --- a/tensorbay/geometry/transform.py +++ b/tensorbay/geometry/transform.py @@ -14,6 +14,7 @@ """ import warnings +from math import isclose as math_isclose from typing import Dict, Iterable, Optional, Type, TypeVar, Union, overload import numpy as np @@ -23,7 +24,9 @@ with warnings.catch_warnings(): warnings.simplefilter("ignore") - from quaternion import as_rotation_matrix, from_rotation_matrix, quaternion, rotate_vectors + from quaternion import as_rotation_matrix, from_rotation_matrix + from quaternion import isclose as quaternion_isclose + from quaternion import quaternion, rotate_vectors _T = TypeVar("_T", bound="Transform3D") @@ -323,3 +326,24 @@ def inverse(self: _T) -> _T: rotation = self._rotation.inverse() translation = Vector3D(*rotate_vectors(rotation, -self._translation)) return self._create(translation, rotation) + + def isclose(self, other: object, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool: + """Determine whether this 3D transform is close to another in value. + + Arguments: + other: The other object to compare. + rel_tol: Maximum difference for being considered "close", relative to the + magnitude of the input values + abs_tol: Maximum difference for being considered "close", regardless of the + magnitude of the input values + + Returns: + A bool value indicating whether this vector is close to another. + + """ + if not isinstance(other, self.__class__): + return False + + return self._translation.isclose( + other.translation, rel_tol=rel_tol, abs_tol=abs_tol + ) and all(quaternion_isclose(self._rotation, other.rotation)) diff --git a/tensorbay/geometry/vector.py b/tensorbay/geometry/vector.py index 756c63633..c016ab8f4 100644 --- a/tensorbay/geometry/vector.py +++ b/tensorbay/geometry/vector.py @@ -15,7 +15,9 @@ """ from itertools import zip_longest -from math import hypot, sqrt +from math import hypot +from math import isclose as math_isclose +from math import sqrt from sys import version_info from typing import Dict, Iterable, Optional, Sequence, Tuple, Type, TypeVar, Union @@ -204,6 +206,35 @@ def loads(contents: Dict[str, float]) -> _T: return Vector3D.loads(contents) return Vector2D.loads(contents) + def isclose( + self, other: Iterable[float], *, rel_tol: float = 1e-09, abs_tol: float = 0.0 + ) -> bool: + """Determine whether this vector is close to another in value. + + Arguments: + other: The other object to compare. + rel_tol: Maximum difference for being considered "close", relative to the + magnitude of the input values + abs_tol: Maximum difference for being considered "close", regardless of the + magnitude of the input values + + Raises: + TypeError: When other have inconsistent dimension. + + Returns: + A bool value indicating whether this vector is close to another. + + """ + try: + return all( + math_isclose(i, j, rel_tol=rel_tol, abs_tol=abs_tol) + for i, j in zip_longest(self._data, other) + ) + except TypeError as error: + raise TypeError( + f"The other object must have the dimension of {self._DIMENSION}" + ) from error + class Vector2D(Vector): """This class defines the concept of Vector2D. diff --git a/tensorbay/label/tests/test_label_box.py b/tensorbay/label/tests/test_label_box.py index d36889a82..1b9dc098f 100644 --- a/tensorbay/label/tests/test_label_box.py +++ b/tensorbay/label/tests/test_label_box.py @@ -156,13 +156,15 @@ def test_rmul(self): instance="12345", ) - assert labeledbox3d.__rmul__(quaternion_1) == LabeledBox3D( - size=size, - translation=[1.7999999999999996, 2, 2.6], - rotation=[-2, 1, 4, -3], - category="cat", - attributes={"gender": "male"}, - instance="12345", + assert labeledbox3d.__rmul__(quaternion_1).isclose( + LabeledBox3D( + size=size, + translation=[1.7999999999999996, 2, 2.6], + rotation=[-2, 1, 4, -3], + category="cat", + attributes={"gender": "male"}, + instance="12345", + ) ) assert labeledbox3d.__rmul__(1) == NotImplemented diff --git a/tensorbay/utility/attr.py b/tensorbay/utility/attr.py index 29b2c52fc..a6c5c8e7b 100644 --- a/tensorbay/utility/attr.py +++ b/tensorbay/utility/attr.py @@ -10,6 +10,7 @@ :class:`Field` is a class describing the attr related fields. """ +from math import isclose as math_isclose from sys import version_info from typing import ( Any, @@ -65,6 +66,7 @@ def __init__( ) -> None: self.loader: _Callable self.dumper: _Callable + self.is_close: Callable[[Any, Any], bool] self.is_dynamic = is_dynamic self.default = default @@ -95,6 +97,7 @@ def __init__(self, key: Optional[str]) -> None: self.loader: _Callable self.dumper: _Callable self.key = key + self.is_close: _Callable class AttrsMixin: @@ -123,6 +126,8 @@ def __init_subclass__(cls) -> None: field = getattr(cls, name, None) if isinstance(field, Field): field.loader, field.dumper = _get_operators(type_) + if cls.__name__ == "_LabelBase": + field.is_close = _get_isclose(type_) if hasattr(field, "key_converter"): field.key = field.key_converter(name) attrs_fields[name] = field @@ -201,6 +206,45 @@ def _dumps(self) -> Dict[str, Any]: _key_dumper(field.key, contents, field.dumper(value)) return contents + def isclose(self, other: object, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool: + """Determine whether this instance is close to another in value. + + Arguments: + other: The other instance to compare. + rel_tol: Maximum difference for being considered "close", relative to the + magnitude of the input values + abs_tol: Maximum difference for being considered "close", regardless of the + magnitude of the input values + + Raises: + NotImplementedError: When the class type is not LabeledBox3D. + + Returns: + A bool value indicating whether this instance is close to another. + + """ + if self.__class__.__name__ != "LabeledBox3D": + raise NotImplementedError + + if not isinstance(other, self.__class__) or self.__dict__.keys() != other.__dict__.keys(): + return False + + base = getattr(self, _ATTRS_BASE, None) + result = True + if base and hasattr(base, "is_close"): + result = result and base.is_close(self, other, rel_tol=rel_tol, abs_tol=abs_tol) + + if not result: + return result + + for name, field in self._attrs_fields.items(): + if not hasattr(self, name): + continue + + result = result and field.is_close(getattr(self, name), getattr(other, name)) + + return result + def attr( *, @@ -291,6 +335,38 @@ def _get_origin_in_3_6(annotation: Any) -> Any: _get_origin = _get_origin_in_3_6 if version_info < (3, 7) else _get_origin_in_3_7 +def _get_isclose(annotation: Any) -> Callable[[Any, Any], Any]: + """Get attr isclose methods by annotations. + + AttrsMixin has three operating types which are classified by attr annotation. + 1. builtin types, like str, int, None + 2. tensorbay custom class, like tensorbay.label.Classification + 3. tensorbay custom class list or NameList, like List[tensorbay.label.LabeledBox2D] + + Arguments: + annotation: Type of the attr. + + Returns: + The ``isclose`` methods of the annotation. + + """ + origin = _get_origin(annotation) + if isinstance(origin, type) and issubclass(origin, Sequence): + type_ = annotation.__args__[0] + return lambda self, other: all(_get_isclose(type_)(i, j) for i, j in zip(self, other)) + + type_ = annotation + + mod = getattr(type_, "__module__", None) + if mod in _BUILTINS: + if type_ in (int, float): + return math_isclose + if mod == "typing": + return origin.__eq__ + return type_.__eq__ # type: ignore[no-any-return] + return type_.isclose # type: ignore[no-any-return] + + def _get_operators(annotation: Any) -> Tuple[_Callable, _Callable]: """Get attr operating methods by annotations. @@ -315,6 +391,9 @@ def _get_operators(annotation: Any) -> Tuple[_Callable, _Callable]: sequence = None type_ = annotation + if annotation == Dict[str, Union[str, int, float, bool, List[Union[str, int, float, bool]]]]: + print("yes", origin, type_, getattr(type_, "__module__", None)) + if {getattr(sequence, "__module__", None), getattr(type_, "__module__", None)} < _BUILTINS: return _builtin_operator, _builtin_operator