diff --git a/src/data_morph/shapes/circles.py b/src/data_morph/shapes/circles.py index fe444327..0896afc0 100644 --- a/src/data_morph/shapes/circles.py +++ b/src/data_morph/shapes/circles.py @@ -3,6 +3,7 @@ from numbers import Number import matplotlib.pyplot as plt +import numpy as np from matplotlib.axes import Axes from ..data.dataset import Dataset @@ -85,9 +86,9 @@ def plot(self, ax: Axes = None) -> Axes: return ax -class Bullseye(Shape): +class Rings(Shape): """ - Class representing a bullseye shape comprising two concentric circles. + Class representing rings comprising multiple concentric circles. .. plot:: :scale: 75 @@ -95,27 +96,41 @@ class Bullseye(Shape): This shape is generated using the panda dataset. from data_morph.data.loader import DataLoader - from data_morph.shapes.circles import Bullseye + from data_morph.shapes.circles import Rings - _ = Bullseye(DataLoader.load_dataset('panda')).plot() + _ = Rings(DataLoader.load_dataset('panda')).plot() Parameters ---------- dataset : Dataset The starting dataset to morph into other shapes. + num_rings : int, default 4 + The number of rings to include. Must be greater than 1. + + See Also + -------- + Circle : The individual rings are represented as circles. """ - def __init__(self, dataset: Dataset) -> None: + def __init__(self, dataset: Dataset, num_rings: int = 4) -> None: + if not isinstance(num_rings, int): + raise TypeError('num_rings must be an integer') + if num_rings <= 1: + raise ValueError('num_rings must be greater than 1') + stdev = dataset.df.std().mean() - self.circles: list[Circle] = [Circle(dataset, r) for r in [stdev, stdev * 2]] - """list[Circle]: The inner and outer :class:`Circle` objects.""" + self.circles: list[Circle] = [ + Circle(dataset, r) + for r in np.linspace(stdev / num_rings * 2, stdev * 2, num_rings) + ] + """list[Circle]: The individual rings represented by :class:`Circle` objects.""" def __repr__(self) -> str: return self._recursive_repr('circles') def distance(self, x: Number, y: Number) -> float: """ - Calculate the minimum absolute distance between this bullseye's inner and outer + Calculate the minimum absolute distance between any of this shape's circles' edges and a point (x, y). Parameters @@ -126,13 +141,13 @@ def distance(self, x: Number, y: Number) -> float: Returns ------- float - The minimum absolute distance between this bullseye's inner and outer + The minimum absolute distance between any of this shape's circles' edges and the point (x, y). See Also -------- Circle.distance : - A bullseye consists of two circles, so we use the minimum + Rings consists of multiple circles, so we use the minimum distance to one of the circles. """ return min(circle.distance(x, y) for circle in self.circles) @@ -155,3 +170,31 @@ def plot(self, ax: Axes = None) -> Axes: for circle in self.circles: ax = circle.plot(ax) return ax + + +class Bullseye(Rings): + """ + Class representing a bullseye shape comprising two concentric circles. + + .. plot:: + :scale: 75 + :caption: + This shape is generated using the panda dataset. + + from data_morph.data.loader import DataLoader + from data_morph.shapes.circles import Bullseye + + _ = Bullseye(DataLoader.load_dataset('panda')).plot() + + Parameters + ---------- + dataset : Dataset + The starting dataset to morph into other shapes. + + See Also + -------- + Rings : The Bullseye is a special case where we only have 2 rings. + """ + + def __init__(self, dataset: Dataset) -> None: + super().__init__(dataset=dataset, num_rings=2) diff --git a/src/data_morph/shapes/factory.py b/src/data_morph/shapes/factory.py index f6fb08af..2b2db2af 100644 --- a/src/data_morph/shapes/factory.py +++ b/src/data_morph/shapes/factory.py @@ -51,6 +51,7 @@ class ShapeFactory: 'up_parab': points.UpParabola, 'diamond': polygons.Diamond, 'rectangle': polygons.Rectangle, + 'rings': circles.Rings, 'star': polygons.Star, } @@ -61,7 +62,7 @@ class ShapeFactory: def __init__(self, dataset: Dataset) -> None: self._dataset: Dataset = dataset - def generate_shape(self, shape: str) -> Shape: + def generate_shape(self, shape: str, **kwargs) -> Shape: """ Generate the shape object based on the dataset. @@ -69,6 +70,9 @@ def generate_shape(self, shape: str) -> Shape: ---------- shape : str The desired shape. See :attr:`.AVAILABLE_SHAPES`. + **kwargs + Additional keyword arguments to pass down when creating + the shape. Returns ------- @@ -76,7 +80,7 @@ def generate_shape(self, shape: str) -> Shape: An shape object of the requested type. """ try: - return self._SHAPE_MAPPING[shape](self._dataset) + return self._SHAPE_MAPPING[shape](self._dataset, **kwargs) except KeyError as err: raise ValueError(f'No such shape as {shape}.') from err diff --git a/tests/shapes/test_circles.py b/tests/shapes/test_circles.py index e05de2df..1b6c906d 100644 --- a/tests/shapes/test_circles.py +++ b/tests/shapes/test_circles.py @@ -71,3 +71,39 @@ def test_is_circle(self, shape): shape.cy + shape.r * np.sin(angles), ): assert pytest.approx(shape.distance(x, y)) == 0 + + +class TestRings(CirclesModuleTestBase): + """Test the Rings class.""" + + shape_name = 'rings' + distance_test_cases = [[(20, 50), 3.16987], [(10, 25), 9.08004]] + repr_regex = ( + r'^\n' + r' circles=\n' + r' \n' + r' ' + ) + + @pytest.mark.parametrize('num_rings', [3, 5]) + def test_init(self, shape_factory, num_rings): + """Test that the Rings contains multiple concentric circles.""" + shape = shape_factory.generate_shape(self.shape_name, num_rings=num_rings) + + assert len(shape.circles) == num_rings + assert all( + getattr(circle, center_coord) == getattr(shape.circles[0], center_coord) + for circle in shape.circles[1:] + for center_coord in ['cx', 'cy'] + ) + assert len({circle.r for circle in shape.circles}) == num_rings + + @pytest.mark.parametrize('num_rings', ['3', -5, 1, True]) + def test_num_rings_is_valid(self, shape_factory, num_rings): + """Test that num_rings input validation is working.""" + if isinstance(num_rings, int): + with pytest.raises(ValueError, match='num_rings must be greater than 1'): + _ = shape_factory.generate_shape(self.shape_name, num_rings=num_rings) + else: + with pytest.raises(TypeError, match='num_rings must be an integer'): + _ = shape_factory.generate_shape(self.shape_name, num_rings=num_rings)