Skip to content

Commit

Permalink
Merge pull request #140 from stefmolin/rings-shape
Browse files Browse the repository at this point in the history
Add Rings shape
  • Loading branch information
stefmolin authored Sep 23, 2023
2 parents 6ecd0bf + 847c818 commit 2f9491a
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 12 deletions.
63 changes: 53 additions & 10 deletions src/data_morph/shapes/circles.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from numbers import Number

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes

from ..data.dataset import Dataset
Expand Down Expand Up @@ -85,37 +86,51 @@ def plot(self, ax: Axes = None) -> Axes:
return ax


class Bullseye(Shape):
class Rings(Shape):
"""
Class representing a bullseye shape comprising two concentric circles.
Class representing rings comprising multiple concentric circles.
.. plot::
:scale: 75
:caption:
This shape is generated using the panda dataset.
from data_morph.data.loader import DataLoader
from data_morph.shapes.circles import Bullseye
from data_morph.shapes.circles import Rings
_ = Bullseye(DataLoader.load_dataset('panda')).plot()
_ = Rings(DataLoader.load_dataset('panda')).plot()
Parameters
----------
dataset : Dataset
The starting dataset to morph into other shapes.
num_rings : int, default 4
The number of rings to include. Must be greater than 1.
See Also
--------
Circle : The individual rings are represented as circles.
"""

def __init__(self, dataset: Dataset) -> None:
def __init__(self, dataset: Dataset, num_rings: int = 4) -> None:
if not isinstance(num_rings, int):
raise TypeError('num_rings must be an integer')
if num_rings <= 1:
raise ValueError('num_rings must be greater than 1')

stdev = dataset.df.std().mean()
self.circles: list[Circle] = [Circle(dataset, r) for r in [stdev, stdev * 2]]
"""list[Circle]: The inner and outer :class:`Circle` objects."""
self.circles: list[Circle] = [
Circle(dataset, r)
for r in np.linspace(stdev / num_rings * 2, stdev * 2, num_rings)
]
"""list[Circle]: The individual rings represented by :class:`Circle` objects."""

def __repr__(self) -> str:
return self._recursive_repr('circles')

def distance(self, x: Number, y: Number) -> float:
"""
Calculate the minimum absolute distance between this bullseye's inner and outer
Calculate the minimum absolute distance between any of this shape's
circles' edges and a point (x, y).
Parameters
Expand All @@ -126,13 +141,13 @@ def distance(self, x: Number, y: Number) -> float:
Returns
-------
float
The minimum absolute distance between this bullseye's inner and outer
The minimum absolute distance between any of this shape's
circles' edges and the point (x, y).
See Also
--------
Circle.distance :
A bullseye consists of two circles, so we use the minimum
Rings consists of multiple circles, so we use the minimum
distance to one of the circles.
"""
return min(circle.distance(x, y) for circle in self.circles)
Expand All @@ -155,3 +170,31 @@ def plot(self, ax: Axes = None) -> Axes:
for circle in self.circles:
ax = circle.plot(ax)
return ax


class Bullseye(Rings):
"""
Class representing a bullseye shape comprising two concentric circles.
.. plot::
:scale: 75
:caption:
This shape is generated using the panda dataset.
from data_morph.data.loader import DataLoader
from data_morph.shapes.circles import Bullseye
_ = Bullseye(DataLoader.load_dataset('panda')).plot()
Parameters
----------
dataset : Dataset
The starting dataset to morph into other shapes.
See Also
--------
Rings : The Bullseye is a special case where we only have 2 rings.
"""

def __init__(self, dataset: Dataset) -> None:
super().__init__(dataset=dataset, num_rings=2)
8 changes: 6 additions & 2 deletions src/data_morph/shapes/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class ShapeFactory:
'up_parab': points.UpParabola,
'diamond': polygons.Diamond,
'rectangle': polygons.Rectangle,
'rings': circles.Rings,
'star': polygons.Star,
}

Expand All @@ -61,22 +62,25 @@ class ShapeFactory:
def __init__(self, dataset: Dataset) -> None:
self._dataset: Dataset = dataset

def generate_shape(self, shape: str) -> Shape:
def generate_shape(self, shape: str, **kwargs) -> Shape:
"""
Generate the shape object based on the dataset.
Parameters
----------
shape : str
The desired shape. See :attr:`.AVAILABLE_SHAPES`.
**kwargs
Additional keyword arguments to pass down when creating
the shape.
Returns
-------
Shape
An shape object of the requested type.
"""
try:
return self._SHAPE_MAPPING[shape](self._dataset)
return self._SHAPE_MAPPING[shape](self._dataset, **kwargs)
except KeyError as err:
raise ValueError(f'No such shape as {shape}.') from err

Expand Down
36 changes: 36 additions & 0 deletions tests/shapes/test_circles.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,39 @@ def test_is_circle(self, shape):
shape.cy + shape.r * np.sin(angles),
):
assert pytest.approx(shape.distance(x, y)) == 0


class TestRings(CirclesModuleTestBase):
"""Test the Rings class."""

shape_name = 'rings'
distance_test_cases = [[(20, 50), 3.16987], [(10, 25), 9.08004]]
repr_regex = (
r'^<Rings>\n'
r' circles=\n'
r' <Circle cx=(\d+\.*\d*) cy=(\d+\.*\d*) r=(\d+\.*\d*)>\n'
r' <Circle cx=(\d+\.*\d*) cy=(\d+\.*\d*) r=(\d+\.*\d*)>'
)

@pytest.mark.parametrize('num_rings', [3, 5])
def test_init(self, shape_factory, num_rings):
"""Test that the Rings contains multiple concentric circles."""
shape = shape_factory.generate_shape(self.shape_name, num_rings=num_rings)

assert len(shape.circles) == num_rings
assert all(
getattr(circle, center_coord) == getattr(shape.circles[0], center_coord)
for circle in shape.circles[1:]
for center_coord in ['cx', 'cy']
)
assert len({circle.r for circle in shape.circles}) == num_rings

@pytest.mark.parametrize('num_rings', ['3', -5, 1, True])
def test_num_rings_is_valid(self, shape_factory, num_rings):
"""Test that num_rings input validation is working."""
if isinstance(num_rings, int):
with pytest.raises(ValueError, match='num_rings must be greater than 1'):
_ = shape_factory.generate_shape(self.shape_name, num_rings=num_rings)
else:
with pytest.raises(TypeError, match='num_rings must be an integer'):
_ = shape_factory.generate_shape(self.shape_name, num_rings=num_rings)

0 comments on commit 2f9491a

Please sign in to comment.