Skip to content

Commit

Permalink
feat: add "isclose" method for Vector, Box3D and LabeledBox3D
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-000 committed Jul 8, 2021
1 parent a4f2b66 commit d430f9e
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 13 deletions.
21 changes: 21 additions & 0 deletions tensorbay/geometry/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,3 +564,24 @@ def dumps(self) -> Dict[str, Dict[str, float]]:
contents = self._transform.dumps()
contents["size"] = self.size.dumps()
return contents

def isclose(self, other: object, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool:
"""Determine whether this 3D box is close to another in value.
Arguments:
other: The other object to compare.
rel_tol: Maximum difference for being considered "close", relative to the
magnitude of the input values
abs_tol: Maximum difference for being considered "close", regardless of the
magnitude of the input values
Returns:
A bool value indicating whether this vector is close to another.
"""
if not isinstance(other, self.__class__):
return False

return self._size.isclose(
other.size, rel_tol=rel_tol, abs_tol=abs_tol
) and self._transform.isclose(other.transform, rel_tol=rel_tol, abs_tol=abs_tol)
10 changes: 6 additions & 4 deletions tensorbay/geometry/tests/test_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,12 @@ def test_rmul(self):
assert box3d.__rmul__(transform) == Box3D(
size=(1, 1, 1), translation=[2, 0, 0], rotation=quaternion(-1, 0, 0, 0)
)
assert box3d.__rmul__(quaternion_1) == Box3D(
size=(1, 1, 1),
translation=[1.7999999999999996, 2, 2.6],
rotation=quaternion(-2, 1, 4, -3),
assert box3d.__rmul__(quaternion_1).isclose(
Box3D(
size=(1, 1, 1),
translation=[1.7999999999999996, 2, 2.6],
rotation=quaternion(-2, 1, 4, -3),
)
)
assert box3d.__rmul__(1) == NotImplemented

Expand Down
5 changes: 5 additions & 0 deletions tensorbay/geometry/tests/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ def test_abs(self):
assert abs(Vector(1, 1)) == 1.4142135623730951
assert abs(Vector(1, 1, 1)) == 1.7320508075688772

def test_isclose(self):
assert Vector(1, 2).isclose(Vector2D(1.000000000001, 2))
assert Vector(1, 2, 3).isclose(Vector3D(1.000000000001, 2, 2.999999999996))
assert not Vector(1, 2, 3).isclose(Vector3D(1.100000000001, 2, 2.999999999996))

def test_repr_head(self):
vector = Vector(1, 2)
assert vector._repr_head() == "Vector2D(1, 2)"
Expand Down
26 changes: 25 additions & 1 deletion tensorbay/geometry/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""

import warnings
from math import isclose as math_isclose
from typing import Dict, Iterable, Optional, Type, TypeVar, Union, overload

import numpy as np
Expand All @@ -23,7 +24,9 @@

with warnings.catch_warnings():
warnings.simplefilter("ignore")
from quaternion import as_rotation_matrix, from_rotation_matrix, quaternion, rotate_vectors
from quaternion import as_rotation_matrix, from_rotation_matrix
from quaternion import isclose as quaternion_isclose
from quaternion import quaternion, rotate_vectors

_T = TypeVar("_T", bound="Transform3D")

Expand Down Expand Up @@ -323,3 +326,24 @@ def inverse(self: _T) -> _T:
rotation = self._rotation.inverse()
translation = Vector3D(*rotate_vectors(rotation, -self._translation))
return self._create(translation, rotation)

def isclose(self, other: object, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool:
"""Determine whether this 3D transform is close to another in value.
Arguments:
other: The other object to compare.
rel_tol: Maximum difference for being considered "close", relative to the
magnitude of the input values
abs_tol: Maximum difference for being considered "close", regardless of the
magnitude of the input values
Returns:
A bool value indicating whether this vector is close to another.
"""
if not isinstance(other, self.__class__):
return False

return self._translation.isclose(
other.translation, rel_tol=rel_tol, abs_tol=abs_tol
) and all(quaternion_isclose(self._rotation, other.rotation))
33 changes: 32 additions & 1 deletion tensorbay/geometry/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
"""

from itertools import zip_longest
from math import hypot, sqrt
from math import hypot
from math import isclose as math_isclose
from math import sqrt
from sys import version_info
from typing import Dict, Iterable, Optional, Sequence, Tuple, Type, TypeVar, Union

Expand Down Expand Up @@ -204,6 +206,35 @@ def loads(contents: Dict[str, float]) -> _T:
return Vector3D.loads(contents)
return Vector2D.loads(contents)

def isclose(
self, other: Iterable[float], *, rel_tol: float = 1e-09, abs_tol: float = 0.0
) -> bool:
"""Determine whether this vector is close to another in value.
Arguments:
other: The other object to compare.
rel_tol: Maximum difference for being considered "close", relative to the
magnitude of the input values
abs_tol: Maximum difference for being considered "close", regardless of the
magnitude of the input values
Raises:
TypeError: When other have inconsistent dimension.
Returns:
A bool value indicating whether this vector is close to another.
"""
try:
return all(
math_isclose(i, j, rel_tol=rel_tol, abs_tol=abs_tol)
for i, j in zip_longest(self._data, other)
)
except TypeError as error:
raise TypeError(
f"The other object must have the dimension of {self._DIMENSION}"
) from error


class Vector2D(Vector):
"""This class defines the concept of Vector2D.
Expand Down
16 changes: 9 additions & 7 deletions tensorbay/label/tests/test_label_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,15 @@ def test_rmul(self):
instance="12345",
)

assert labeledbox3d.__rmul__(quaternion_1) == LabeledBox3D(
size=size,
translation=[1.7999999999999996, 2, 2.6],
rotation=[-2, 1, 4, -3],
category="cat",
attributes={"gender": "male"},
instance="12345",
assert labeledbox3d.__rmul__(quaternion_1).isclose(
LabeledBox3D(
size=size,
translation=[1.7999999999999996, 2, 2.6],
rotation=[-2, 1, 4, -3],
category="cat",
attributes={"gender": "male"},
instance="12345",
)
)

assert labeledbox3d.__rmul__(1) == NotImplemented
Expand Down
79 changes: 79 additions & 0 deletions tensorbay/utility/attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
:class:`Field` is a class describing the attr related fields.
"""
from math import isclose as math_isclose
from sys import version_info
from typing import (
Any,
Expand Down Expand Up @@ -65,6 +66,7 @@ def __init__(
) -> None:
self.loader: _Callable
self.dumper: _Callable
self.is_close: Callable[[Any, Any], bool]

self.is_dynamic = is_dynamic
self.default = default
Expand Down Expand Up @@ -95,6 +97,7 @@ def __init__(self, key: Optional[str]) -> None:
self.loader: _Callable
self.dumper: _Callable
self.key = key
self.is_close: _Callable


class AttrsMixin:
Expand Down Expand Up @@ -123,6 +126,8 @@ def __init_subclass__(cls) -> None:
field = getattr(cls, name, None)
if isinstance(field, Field):
field.loader, field.dumper = _get_operators(type_)
if cls.__name__ == "_LabelBase":
field.is_close = _get_isclose(type_)
if hasattr(field, "key_converter"):
field.key = field.key_converter(name)
attrs_fields[name] = field
Expand Down Expand Up @@ -201,6 +206,45 @@ def _dumps(self) -> Dict[str, Any]:
_key_dumper(field.key, contents, field.dumper(value))
return contents

def isclose(self, other: object, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool:
"""Determine whether this instance is close to another in value.
Arguments:
other: The other instance to compare.
rel_tol: Maximum difference for being considered "close", relative to the
magnitude of the input values
abs_tol: Maximum difference for being considered "close", regardless of the
magnitude of the input values
Raises:
NotImplementedError: When the class type is not LabeledBox3D.
Returns:
A bool value indicating whether this instance is close to another.
"""
if self.__class__.__name__ != "LabeledBox3D":
raise NotImplementedError

if not isinstance(other, self.__class__) or self.__dict__.keys() != other.__dict__.keys():
return False

base = getattr(self, _ATTRS_BASE, None)
result = True
if base and hasattr(base, "is_close"):
result = result and base.is_close(self, other, rel_tol=rel_tol, abs_tol=abs_tol)

if not result:
return result

for name, field in self._attrs_fields.items():
if not hasattr(self, name):
continue

result = result and field.is_close(getattr(self, name), getattr(other, name))

return result


def attr(
*,
Expand Down Expand Up @@ -291,6 +335,38 @@ def _get_origin_in_3_6(annotation: Any) -> Any:
_get_origin = _get_origin_in_3_6 if version_info < (3, 7) else _get_origin_in_3_7


def _get_isclose(annotation: Any) -> Callable[[Any, Any], Any]:
"""Get attr isclose methods by annotations.
AttrsMixin has three operating types which are classified by attr annotation.
1. builtin types, like str, int, None
2. tensorbay custom class, like tensorbay.label.Classification
3. tensorbay custom class list or NameList, like List[tensorbay.label.LabeledBox2D]
Arguments:
annotation: Type of the attr.
Returns:
The ``isclose`` methods of the annotation.
"""
origin = _get_origin(annotation)
if isinstance(origin, type) and issubclass(origin, Sequence):
type_ = annotation.__args__[0]
return lambda self, other: all(_get_isclose(type_)(i, j) for i, j in zip(self, other))

type_ = annotation

mod = getattr(type_, "__module__", None)
if mod in _BUILTINS:
if type_ in (int, float):
return math_isclose
if mod == "typing":
return origin.__eq__
return type_.__eq__ # type: ignore[no-any-return]
return type_.isclose # type: ignore[no-any-return]


def _get_operators(annotation: Any) -> Tuple[_Callable, _Callable]:
"""Get attr operating methods by annotations.
Expand All @@ -315,6 +391,9 @@ def _get_operators(annotation: Any) -> Tuple[_Callable, _Callable]:
sequence = None
type_ = annotation

if annotation == Dict[str, Union[str, int, float, bool, List[Union[str, int, float, bool]]]]:
print("yes", origin, type_, getattr(type_, "__module__", None))

if {getattr(sequence, "__module__", None), getattr(type_, "__module__", None)} < _BUILTINS:
return _builtin_operator, _builtin_operator

Expand Down

0 comments on commit d430f9e

Please sign in to comment.