Skip to content

Commit

Permalink
Merge pull request #124 from stefmolin/diamond
Browse files Browse the repository at this point in the history
Add diamond target shape.
  • Loading branch information
stefmolin authored May 14, 2023
2 parents 3ed1b5f + f99740e commit ad24e18
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/data_morph/shapes/bases/point_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,5 @@ def plot(self, ax: Axes = None) -> Axes:
fig, ax = plt.subplots(layout='constrained')
fig.get_layout_engine().set(w_pad=0.2, h_pad=0.2)
_ = ax.axis('equal')
_ = ax.scatter(*self.points.T, s=2, color='k', alpha=self._alpha)
_ = ax.scatter(*self.points.T, s=5, color='k', alpha=self._alpha)
return ax
1 change: 1 addition & 0 deletions src/data_morph/shapes/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class ShapeFactory:
'scatter': points.Scatter,
'right_parab': points.RightParabola,
'up_parab': points.UpParabola,
'diamond': polygons.Diamond,
'rectangle': polygons.Rectangle,
'star': polygons.Star,
}
Expand Down
36 changes: 36 additions & 0 deletions src/data_morph/shapes/polygons.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,42 @@
from .bases.line_collection import LineCollection


class Diamond(LineCollection):
"""
Class for the diamond shape.
.. plot::
:scale: 75
:caption:
This shape is generated using the panda dataset.
import matplotlib.pyplot as plt
from data_morph.data.loader import DataLoader
from data_morph.shapes.polygons import Diamond
_ = Diamond(DataLoader.load_dataset('panda')).plot()
Parameters
----------
dataset : Dataset
The starting dataset to morph into other shapes.
"""

def __init__(self, dataset: Dataset) -> None:
xmin, xmax = dataset.df.x.quantile([0.05, 0.95])
ymin, ymax = dataset.df.y.quantile([0.05, 0.95])

xmid = (xmax + xmin) / 2
ymid = (ymax + ymin) / 2

super().__init__(
[[xmin, ymid], [xmid, ymax]],
[[xmid, ymax], [xmax, ymid]],
[[xmax, ymid], [xmid, ymin]],
[[xmid, ymin], [xmin, ymid]],
)


class Rectangle(LineCollection):
"""
Class for the rectangle shape.
Expand Down
12 changes: 12 additions & 0 deletions tests/shapes/test_polygons.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ def test_lines_form_polygon(self, shape):
assert np.unique(endpoints, axis=0).shape[0] == self.expected_line_count


class TestDiamond(PolygonsModuleTestBase):
"""Test the Diamond class."""

shape_name = 'diamond'
distance_test_cases = [[(20, 50), 0.0], [(30, 60), 2.773501]]
expected_line_count = 4

def test_slopes(self, slopes):
"""Test that the slopes are as expected."""
np.testing.assert_array_equal(np.sort(slopes).flatten(), [-1.5, -1.5, 1.5, 1.5])


class TestRectangle(PolygonsModuleTestBase):
"""Test the Rectangle class."""

Expand Down

0 comments on commit ad24e18

Please sign in to comment.