Skip to content

Commit

Permalink
Merge branch 'main' into jelic/feature/vectorize
Browse files Browse the repository at this point in the history
  • Loading branch information
stefmolin authored Nov 25, 2024
2 parents 34be08a + e440ee7 commit 0a25272
Show file tree
Hide file tree
Showing 39 changed files with 1,228 additions and 952 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ repos:
exclude: (\.(svg|png|pdf)$)|(CODE_OF_CONDUCT.md)

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.1
rev: v0.8.0
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --show-fixes]
Expand All @@ -44,7 +44,7 @@ repos:
files: tests/.*

- repo: https://github.com/tox-dev/pyproject-fmt
rev: v2.4.3
rev: v2.5.0
hooks:
- id: pyproject-fmt
args: [--keep-full-version, --no-print-diff]
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import data_morph

sys.path.insert(0, str(Path().absolute()))
from post_build import determine_versions # noqa: E402
from post_build import determine_versions

project = 'Data Morph'
current_year = dt.date.today().year
Expand Down
20 changes: 12 additions & 8 deletions docs/tutorials/shape-creation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,18 @@ to calculate its position and scale:
class XLines(LineCollection):
name = 'x'
def __init__(self, dataset: Dataset) -> None:
xmin, xmax = dataset.morph_bounds.x_bounds
ymin, ymax = dataset.morph_bounds.y_bounds
super().__init__([[xmin, ymin], [xmax, ymax]], [[xmin, ymax], [xmax, ymin]])
def __str__(self) -> str:
return 'x'
Since we inherit from :class:`.LineCollection` here, we don't need to define
the ``distance()`` and ``plot()`` methods (unless we want to override them).
We do override the ``__str__()`` method here since the default will result in
We do set the ``name`` attribute here since the default will result in
a value of ``xlines`` and ``x`` makes more sense for use in the documentation
(see :class:`.ShapeFactory`).

Expand All @@ -82,11 +82,15 @@ Register the shape
For the ``data-morph`` CLI to find your shape, you need to register it with the
:class:`.ShapeFactory`:

1. Add your shape class to the appropriate file inside the ``src/data_morph/shapes/``
directory. Note that the filenames correspond to the type of shape (*e.g.*, use
``src/data_morph/shapes/points.py`` for a new shape inheriting from :class:`.PointCollection`).
2. Add an entry to the ``ShapeFactory._SHAPE_MAPPING`` dictionary in
``src/data_morph/shapes/factory.py``.
1. Add your shape class to the appropriate module inside the ``src/data_morph/shapes/``
directory. Note that these correspond to the type of shape (*e.g.*, use
``src/data_morph/shapes/points/<your_shape>.py`` for a new shape inheriting from
:class:`.PointCollection`).
2. Add your shape to ``__all__`` in that module's ``__init__.py`` (*e.g.*, use
``src/data_morph/shapes/points/__init__.py`` for a new shape inheriting from
:class:`.PointCollection`).
3. Add an entry to the ``ShapeFactory._SHAPE_CLASSES`` tuple in
``src/data_morph/shapes/factory.py``, preserving alphabetical order.

Test out the shape
------------------
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,13 @@ lint.select = [
"I", # isort
"N", # pep8-naming
"PTH", # flake8-use-pathlib
"RUF", # ruff-specific rules
"SIM", # flake8-simplify
"TRY", # tryceratops
"UP", # pyupgrade
"W", # pycodestyle warning
]
lint.ignore = [
"ANN101", # missing type annotation for self (will be removed in future ruff version)
"ANN102", # missing type annotation for cls in classmethod (will be removed in future ruff version)
"E501", # line-too-long
"TRY003", # avoid specifying long messages outside the exception class (revisit later and consider making custom exceptions)
]
Expand Down
2 changes: 1 addition & 1 deletion src/data_morph/bounds/bounding_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __eq__(self, other: BoundingBox) -> bool:
def __repr__(self) -> str:
return '<BoundingBox>\n' f' x={self.x_bounds}' '\n' f' y={self.y_bounds}'

def adjust_bounds(self, x: Number = None, y: Number = None) -> None:
def adjust_bounds(self, x: Number | None = None, y: Number | None = None) -> None:
"""
Adjust bounding box range.
Expand Down
8 changes: 5 additions & 3 deletions src/data_morph/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Class representing a dataset for morphing."""

from __future__ import annotations

from numbers import Number

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -39,13 +41,13 @@ class Dataset:
Utility for creating :class:`Dataset` objects from CSV files.
"""

_REQUIRED_COLUMNS = ['x', 'y']
_REQUIRED_COLUMNS = ('x', 'y')

def __init__(
self,
name: str,
df: pd.DataFrame,
scale: Number = None,
scale: Number | None = None,
) -> None:
self.df: pd.DataFrame = self._validate_data(df).pipe(self._scale_data, scale)
"""pandas.DataFrame: DataFrame containing columns x and y."""
Expand Down Expand Up @@ -184,7 +186,7 @@ def _validate_data(self, data: pd.DataFrame) -> pd.DataFrame:

@plot_with_custom_style
def plot(
self, ax: Axes = None, show_bounds: bool = True, title: str = 'default'
self, ax: Axes | None = None, show_bounds: bool = True, title: str = 'default'
) -> Axes:
"""
Plot the dataset and its bounds.
Expand Down
7 changes: 5 additions & 2 deletions src/data_morph/data/loader.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""Load data for morphing."""

from __future__ import annotations

from importlib.resources import files
from itertools import zip_longest
from numbers import Number
from pathlib import Path
from typing import ClassVar

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -41,7 +44,7 @@ class DataLoader:
"""

_DATA_PATH: str = 'data/starter_shapes/'
_DATASETS: dict = {
_DATASETS: ClassVar[dict[str, str]] = {
'bunny': 'bunny.csv',
'cat': 'cat.csv',
'dino': 'dino.csv',
Expand All @@ -66,7 +69,7 @@ def __init__(self) -> None:
def load_dataset(
cls,
dataset: str,
scale: Number = None,
scale: Number | None = None,
) -> Dataset:
"""
Load dataset.
Expand Down
4 changes: 3 additions & 1 deletion src/data_morph/shapes/bases/line_collection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Base class for shapes that are composed of lines."""

from __future__ import annotations

from collections.abc import Iterable
from numbers import Number

Expand Down Expand Up @@ -97,7 +99,7 @@ def distance(self, x: Number, y: Number) -> float:
)

@plot_with_custom_style
def plot(self, ax: Axes = None) -> Axes:
def plot(self, ax: Axes | None = None) -> Axes:
"""
Plot the shape.
Expand Down
4 changes: 3 additions & 1 deletion src/data_morph/shapes/bases/point_collection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Base class for shapes that are composed of points."""

from __future__ import annotations

from collections.abc import Iterable
from numbers import Number

Expand Down Expand Up @@ -52,7 +54,7 @@ def distance(self, x: Number, y: Number) -> float:
)

@plot_with_custom_style
def plot(self, ax: Axes = None) -> Axes:
def plot(self, ax: Axes | None = None) -> Axes:
"""
Plot the shape.
Expand Down
23 changes: 21 additions & 2 deletions src/data_morph/shapes/bases/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,21 @@
class Shape(ABC):
"""Abstract base class for a shape."""

name: str | None = None
"""The display name for the shape, if the lowercased class name is not desired."""

@classmethod
def get_name(cls) -> str:
"""
Get the name of the shape.
Returns
-------
str
The name of the shape.
"""
return cls.name or cls.__name__.lower()

def __repr__(self) -> str:
"""
Return string representation of the shape.
Expand All @@ -32,8 +47,12 @@ def __str__(self) -> str:
-------
str
The human-readable string representation of the shape.
See Also
--------
get_name : This calls the :meth:`.get_name` class method.
"""
return self.__class__.__name__.lower()
return self.get_name()

@abstractmethod
def distance(self, x: Number, y: Number) -> float:
Expand Down Expand Up @@ -103,7 +122,7 @@ def _recursive_repr(self, attr: str | None = None) -> str:
)

@abstractmethod
def plot(self, ax: Axes = None) -> Axes:
def plot(self, ax: Axes | None = None) -> Axes:
"""
Plot the shape.
Expand Down
8 changes: 5 additions & 3 deletions src/data_morph/shapes/circles.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Shapes that are circular in nature."""

from __future__ import annotations

from numbers import Number

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -33,7 +35,7 @@ class Circle(Shape):
The radius of the circle.
"""

def __init__(self, dataset: Dataset, radius: Number = None) -> None:
def __init__(self, dataset: Dataset, radius: Number | None = None) -> None:
self.center: np.ndarray = dataset.df[['x', 'y']].mean().to_numpy()
"""numpy.ndarray: The (x, y) coordinates of the circle's center."""

Expand Down Expand Up @@ -63,7 +65,7 @@ def distance(self, x: Number, y: Number) -> float:
)

@plot_with_custom_style
def plot(self, ax: Axes = None) -> Axes:
def plot(self, ax: Axes | None = None) -> Axes:
"""
Plot the shape.
Expand Down Expand Up @@ -159,7 +161,7 @@ def distance(self, x: Number, y: Number) -> float:
)

@plot_with_custom_style
def plot(self, ax: Axes = None) -> Axes:
def plot(self, ax: Axes | None = None) -> Axes:
"""
Plot the shape.
Expand Down
80 changes: 55 additions & 25 deletions src/data_morph/shapes/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,39 @@

from itertools import zip_longest
from numbers import Number
from typing import ClassVar

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes

from ..data.dataset import Dataset
from ..plotting.style import plot_with_custom_style
from . import circles, lines, points, polygons
from .bases.shape import Shape
from .circles import Bullseye, Circle, Rings
from .lines import (
Diamond,
HighLines,
HorizontalLines,
Rectangle,
SlantDownLines,
SlantUpLines,
Star,
VerticalLines,
WideLines,
XLines,
)
from .points import (
Club,
DotsGrid,
DownParabola,
Heart,
LeftParabola,
RightParabola,
Scatter,
Spade,
UpParabola,
)


class ShapeFactory:
Expand All @@ -33,33 +57,39 @@ class ShapeFactory:
The starting dataset to morph into other shapes.
"""

_SHAPE_MAPPING: dict = {
'bullseye': circles.Bullseye,
'circle': circles.Circle,
'high_lines': lines.HighLines,
'h_lines': lines.HorizontalLines,
'slant_down': lines.SlantDownLines,
'slant_up': lines.SlantUpLines,
'v_lines': lines.VerticalLines,
'wide_lines': lines.WideLines,
'x': lines.XLines,
'dots': points.DotsGrid,
'down_parab': points.DownParabola,
'heart': points.Heart,
'left_parab': points.LeftParabola,
'scatter': points.Scatter,
'right_parab': points.RightParabola,
'up_parab': points.UpParabola,
'diamond': polygons.Diamond,
'rectangle': polygons.Rectangle,
'rings': circles.Rings,
'star': polygons.Star,
'club': points.Club,
'spade': points.Spade,
_SHAPE_CLASSES: tuple[type[Shape]] = (
Bullseye,
Circle,
Club,
Diamond,
DotsGrid,
DownParabola,
Heart,
HighLines,
HorizontalLines,
LeftParabola,
Rectangle,
RightParabola,
Rings,
Scatter,
SlantDownLines,
SlantUpLines,
Spade,
Star,
UpParabola,
VerticalLines,
WideLines,
XLines,
)
"""New shape classes must be registered here."""

_SHAPE_MAPPING: ClassVar[dict[str, type[Shape]]] = {
shape_cls.get_name(): shape_cls for shape_cls in _SHAPE_CLASSES
}
"""Mapping of shape display names to classes."""

AVAILABLE_SHAPES: list[str] = sorted(_SHAPE_MAPPING.keys())
"""list[str]: The list of available shapes, which can be visualized with
"""The list of available shapes, which can be visualized with
:meth:`.plot_available_shapes`."""

def __init__(self, dataset: Dataset) -> None:
Expand Down
Loading

0 comments on commit 0a25272

Please sign in to comment.