Skip to content

Commit

Permalink
Add heart target shape.
Browse files Browse the repository at this point in the history
  • Loading branch information
stefmolin committed May 14, 2023
1 parent b620d03 commit 9d421c4
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/data_morph/shapes/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class ShapeFactory:
'x': lines.XLines,
'dots': points.DotsGrid,
'down_parab': points.DownParabola,
'heart': points.Heart,
'left_parab': points.LeftParabola,
'scatter': points.Scatter,
'right_parab': points.RightParabola,
Expand Down
48 changes: 48 additions & 0 deletions src/data_morph/shapes/points.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,54 @@ def __str__(self) -> str:
return 'down_parab'


class Heart(PointCollection):
"""
Class for the heart shape.
.. plot::
:scale: 75
:caption:
This shape is generated using the panda dataset.
from data_morph.data.loader import DataLoader
from data_morph.shapes.points import Heart
_ = Heart(DataLoader.load_dataset('panda')).plot()
Parameters
----------
dataset : Dataset
The starting dataset to morph into other shapes.
Notes
-----
The formula for the heart shape is inspired by
`Heart Curve <https://mathworld.wolfram.com/HeartCurve.html>`_:
Weisstein, Eric W. "Heart Curve." From `MathWorld <https://mathworld.wolfram.com/>`_
--A Wolfram Web Resource. https://mathworld.wolfram.com/HeartCurve.html
"""

def __init__(self, dataset: Dataset) -> None:
x_bounds = dataset.data_bounds.x_bounds
y_bounds = dataset.data_bounds.y_bounds

x_shift = sum(x_bounds) / 2
y_shift = sum(y_bounds) / 2

t = np.linspace(-3, 3, num=80)

x = 16 * np.sin(t) ** 3
y = 13 * np.cos(t) - 5 * np.cos(2 * t) - 2 * np.cos(3 * t) - np.cos(4 * t)

# scale by the half the widest width of the heart
scale_factor = (x_bounds[1] - x_shift) / 16

super().__init__(
*np.stack([x * scale_factor + x_shift, y * scale_factor + y_shift], axis=1)
)


class LeftParabola(PointCollection):
"""
Class for the left parabola shape.
Expand Down
16 changes: 15 additions & 1 deletion tests/shapes/test_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_distance(self, shape, test_point, expected_distance):
Test the distance() method parametrized by distance_test_cases
(see conftest.py).
"""
assert pytest.approx(shape.distance(*test_point)) == expected_distance
assert pytest.approx(shape.distance(*test_point), abs=1e-5) == expected_distance


class TestDotsGrid(PointsModuleTestBase):
Expand Down Expand Up @@ -67,6 +67,20 @@ def test_points_form_symmetric_grid(self, shape):
assert row_midpoint == middle_row[point][1]


class TestHeart(PointsModuleTestBase):
"""Test the Heart class."""

shape_name = 'heart'
distance_test_cases = [
[(19.89946048, 54.82281916), 0.0],
[(10.84680454, 70.18556376), 0.0],
[(29.9971295, 67.66402445), 0.0],
[(27.38657942, 62.417184), 0.0],
[(20, 50), 4.567369],
[(10, 80), 8.564365],
]


class TestScatter(PointsModuleTestBase):
"""Test the Scatter class."""

Expand Down

0 comments on commit 9d421c4

Please sign in to comment.