diff --git a/src/data_morph/shapes/bases/line_collection.py b/src/data_morph/shapes/bases/line_collection.py index 42476083..2f91fc2f 100644 --- a/src/data_morph/shapes/bases/line_collection.py +++ b/src/data_morph/shapes/bases/line_collection.py @@ -23,7 +23,13 @@ class LineCollection(Shape): """ def __init__(self, *lines: Iterable[Iterable[Number]]) -> None: - self.lines = lines + # check that lines with the same starting and ending points raise an error + for line in lines: + start, end = line + if np.allclose(start, end): + raise ValueError(f'Line {line} has the same start and end point') + + self.lines = np.array(lines) """Iterable[Iterable[numbers.Number]]: An iterable of two (x, y) pairs representing the endpoints of a line.""" @@ -45,66 +51,48 @@ def distance(self, x: Number, y: Number) -> float: float The minimum distance from the lines of this shape to the point (x, y). - """ - return min( - self._distance_point_to_line(point=(x, y), line=line) for line in self.lines - ) - def _distance_point_to_line( - self, - point: Iterable[Number], - line: Iterable[Iterable[Number]], - ) -> float: + Notes + ----- + Implementation based on `this Stack Overflow answer`_. + + .. _this Stack Overflow answer: https://stackoverflow.com/a/58781995 """ - Calculate the minimum distance between a point and a line. + point = np.array([x, y]) - Parameters - ---------- - point : Iterable[numbers.Number] - Coordinates of a point in 2D space. - line : Iterable[Iterable[numbers.Number]] - Coordinates of the endpoints of a line in 2D space. + start_points = self.lines[:, 0, :] + end_points = self.lines[:, 1, :] - Returns - ------- - float - The minimum distance between the point and the line. + tangent_vector = end_points - start_points + normalized_tangent_vectors = np.divide( + tangent_vector, + np.hypot(tangent_vector[:, 0], tangent_vector[:, 1]).reshape(-1, 1), + ) - Notes - ----- - Implementation based on `this VBA code`_. + # row-wise dot products of 2D vectors + signed_parallel_distance_start = np.multiply( + start_points - point, normalized_tangent_vectors + ).sum(axis=1) + signed_parallel_distance_end = np.multiply( + point - end_points, normalized_tangent_vectors + ).sum(axis=1) + + clamped_parallel_distance = np.maximum.reduce( + [ + signed_parallel_distance_start, + signed_parallel_distance_end, + np.zeros(signed_parallel_distance_start.shape[0]), + ] + ) - .. _this VBA code: http://local.wasp.uwa.edu.au/~pbourke/geometry/pointline/source.vba - """ - start, end = np.array(line) - line_mag = self._euclidean_distance(start, end) - point = np.array(point) - - if line_mag < 0.00000001: - # Arbitrarily large value - return 9999 - - px, py = point - x1, y1 = start - x2, y2 = end - - u1 = ((px - x1) * (x2 - x1)) + ((py - y1) * (y2 - y1)) - u = u1 / (line_mag * line_mag) - - if (u < 0.00001) or (u > 1): - # closest point does not fall within the line segment, take the shorter - # distance to an endpoint - distance = min( - self._euclidean_distance(point, start), - self._euclidean_distance(point, end), - ) - else: - # Intersecting point is on the line, use the formula - ix = x1 + u * (x2 - x1) - iy = y1 + u * (y2 - y1) - distance = self._euclidean_distance(point, np.array((ix, iy))) - - return distance + # row-wise cross products of 2D vectors + perpendicular_distance_component = np.cross( + point - start_points, normalized_tangent_vectors + ) + + return np.min( + np.hypot(clamped_parallel_distance, perpendicular_distance_component) + ) @plot_with_custom_style def plot(self, ax: Axes = None) -> Axes: diff --git a/tests/shapes/bases/test_line_collection.py b/tests/shapes/bases/test_line_collection.py index 6b356552..dbda853d 100644 --- a/tests/shapes/bases/test_line_collection.py +++ b/tests/shapes/bases/test_line_collection.py @@ -35,20 +35,16 @@ def test_distance_nonzero(self, line_collection, point, expected_distance): assert pytest.approx(line_collection.distance(*point)) == expected_distance @pytest.mark.parametrize('line', [[(0, 0), (0, 0)], [(-1, -1), (-1, -1)]], ids=str) - def test_distance_to_small_line_magnitude(self, line_collection, line): - """Test _distance_point_to_line() for small line magnitudes.""" - distance = line_collection._distance_point_to_line((30, 50), line) - assert distance == 9999 + def test_line_as_point(self, line): + """Test LineCollection raises a ValueError for small line magnitudes.""" + with pytest.raises(ValueError): + LineCollection(line) def test_repr(self, line_collection): """Test that the __repr__() method is working.""" - lines = r'\n '.join( - [r'\[\[\d+\.*\d*, \d+\.*\d*\], \[\d+\.*\d*, \d+\.*\d*\]\]'] - * len(line_collection.lines) - ) assert ( re.match( - (r'^\n lines=\n ' + lines), + r"""\n lines=\n {8}array\(\[\[\d+""", repr(line_collection), ) is not None