From cace32f17da1de20a0aaf7b3cad3d6471ae021ae Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Wed, 20 Sep 2023 13:16:45 +0100 Subject: [PATCH 01/35] Work towards fraction binning --- sasdata/data_util/geometry.py | 0 sasdata/data_util/meshmerge.py | 89 ++++++++++++++++++++++++++++ sasdata/data_util/sample_polygons.py | 31 ++++++++++ sasdata/data_util/transforms.py | 58 ++++++++++++++++++ 4 files changed, 178 insertions(+) create mode 100644 sasdata/data_util/geometry.py create mode 100644 sasdata/data_util/meshmerge.py create mode 100644 sasdata/data_util/sample_polygons.py create mode 100644 sasdata/data_util/transforms.py diff --git a/sasdata/data_util/geometry.py b/sasdata/data_util/geometry.py new file mode 100644 index 0000000..e69de29 diff --git a/sasdata/data_util/meshmerge.py b/sasdata/data_util/meshmerge.py new file mode 100644 index 0000000..eef12bf --- /dev/null +++ b/sasdata/data_util/meshmerge.py @@ -0,0 +1,89 @@ +from typing import Sequence +from scipy.spatial import Delaunay + +import numpy as np + +from dataclasses import dataclass + +@dataclass +class Mesh: + points: np.ndarray + edges: Sequence[Sequence[int]] # List of pairs of points forming edges + cells: Sequence[Sequence[int]] # List of edges constituting a cell + + +def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray]: + """ Take two lists of polygons and find their intersections + + Polygons in each of the input variables should not overlap i.e. a point in space should be assignable to + at most one polygon in mesh_a and at most one polygon in mesh_b + + Mesh topology should be sensible, otherwise bad things might happen + + :returns: + 1) A triangulated mesh based on both sets of polygons together + 2) The indices of the mesh_a polygon that corresponds to each triangle, -1 for nothing + 3) The indices of the mesh_b polygon that corresponds to each triangle, -1 for nothing + + """ + + # Find intersections of all edges in mesh one with edges in mesh two + + new_points = [] + for edge_a in mesh_a.edges: + for edge_b in mesh_b.edges: + # + # Parametric description of intersection in terms of position along lines + # + # Simultaneous eqns (to reflect current wiki notation) + # s(x2 - x1) - t(x4 - x3) = x3 - x1 + # s(y2 - y1) - t(y4 - y3) = y3 - y1 + # + # in matrix form: + # m.(s,t) = v + # + + p1 = mesh_a.points[edge_a[0]] + p2 = mesh_a.points[edge_a[1]] + p3 = mesh_b.points[edge_b[0]] + p4 = mesh_b.points[edge_b[1]] + + m = np.array([ + [p2[0] - p1[0], p3[0] - p4[0]], + [p2[1] - p1[1], p3[1] - p4[1]]]) + + v = np.array([p3[0] - p1[0], p3[1] - p1[1]]) + + if np.linalg.det(m) == 0: + # Lines don't intersect + break + + st = np.linalg.solve(m, v) + + # As the purpose of this is finding new points for the merged mesh, we don't + # want new points if they are right at the end of the lines, hence non strict + # inequalities here + if np.any(st <= 0) or np.any(st >= 1): + # Exclude intection points, that are not on the *segments* + break + + x = p1[0] + (p2[0] - p1[1])*st[0] + y = p1[1] + (p2[1] - p1[1])*st[1] + + new_points.append((x, y)) + + # Build list of all input points, in a way that we can check for coincident points + + + + # Remove coincident points + + + # Triangulate based on these intersections + + # Find centroids of all output triangles, and find which source cells they belong to + + ## Assign -1 to all cells + ## Find centroids + ## Check whether within bounding box + ## If in bounding box, check cell properly using winding number, if inside, assign diff --git 
a/sasdata/data_util/sample_polygons.py b/sasdata/data_util/sample_polygons.py new file mode 100644 index 0000000..e12fb1e --- /dev/null +++ b/sasdata/data_util/sample_polygons.py @@ -0,0 +1,31 @@ +import numpy as np + +def wedge(q0, q1, theta0, theta1, clockwise=False, n_points_per_degree=2): + + # Traverse a rectangle in curvilinear coordinates (q0, theta0), (q0, theta1), (q1, theta1), (q1, theta0) + if clockwise: + if theta1 > theta0: + theta0 += 2*np.pi + + else: + if theta0 > theta1: + theta1 += 2*np.pi + + subtended_angle = np.abs(theta1 - theta0) + n_points = int(subtended_angle*180*n_points_per_degree/np.pi)+1 + + angles = np.linspace(theta0, theta1, n_points) + + xs = np.concatenate((q0*np.cos(angles), q1*np.cos(angles[::-1]))) + ys = np.concatenate((q0*np.sin(angles), q1*np.sin(angles[::-1]))) + + return np.array((xs, ys)).T + + +if __name__ == "__main__": + import matplotlib.pyplot as plt + xy = wedge(0.3, 0.6, 2, 3) + + plt.plot(xy[:,0], xy[:,1]) + plt.show() + diff --git a/sasdata/data_util/transforms.py b/sasdata/data_util/transforms.py new file mode 100644 index 0000000..d04742d --- /dev/null +++ b/sasdata/data_util/transforms.py @@ -0,0 +1,58 @@ +import numpy as np +from scipy.spatial import Voronoi, Delaunay +import matplotlib.pyplot as plt +from matplotlib import cm + + +# Some test data + +qx_base_values = np.linspace(-10, 10, 21) +qy_base_values = np.linspace(-10, 10, 21) + +qx, qy = np.meshgrid(qx_base_values, qy_base_values) + +include = np.logical_not((np.abs(qx) < 2) & (np.abs(qy) < 2)) + +qx = qx[include] +qy = qy[include] + +r = np.sqrt(qx**2 + qy**2) + +data = np.log((1+np.cos(3*r))*np.exp(-r*r)) + +colormap = cm.get_cmap('winter', 256) + +def get_data_mesh(x, y, data): + + input_data = np.array((x, y)).T + voronoi = Voronoi(input_data) + + # plt.scatter(voronoi.vertices[:,0], voronoi.vertices[:,1]) + # plt.scatter(voronoi.points[:,0], voronoi.points[:,1]) + + cmin = np.min(data) + cmax = np.max(data) + + color_index_map = np.array(255 * (data - cmin) / (cmax - cmin), dtype=int) + + for point_index, points in enumerate(voronoi.points): + + region_index = voronoi.point_region[point_index] + region = voronoi.regions[region_index] + + if len(region) > 0: + + if -1 in region: + + pass + + else: + + color = colormap(color_index_map[point_index]) + + circly = region + [region[0]] + plt.fill(voronoi.vertices[circly, 0], voronoi.vertices[circly, 1], color=color, edgecolor="white") + + plt.show() + +get_data_mesh(qx.reshape(-1), qy.reshape(-1), data) \ No newline at end of file From 459341d3c16837a0eee56326b919c94d886c3a23 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Fri, 22 Sep 2023 13:50:25 +0100 Subject: [PATCH 02/35] Mesh merging and some refactoring --- sasdata/data_util/meshmerge.py | 89 --------- .../{geometry.py => slicing/__init__.py} | 0 sasdata/data_util/slicing/geometry.py | 0 sasdata/data_util/slicing/mesh.py | 28 +++ sasdata/data_util/slicing/meshmerge.py | 170 ++++++++++++++++++ .../{ => slicing}/sample_polygons.py | 0 sasdata/data_util/{ => slicing}/transforms.py | 0 sasdata/data_util/slicing/voronoi_mesh.py | 37 ++++ 8 files changed, 235 insertions(+), 89 deletions(-) delete mode 100644 sasdata/data_util/meshmerge.py rename sasdata/data_util/{geometry.py => slicing/__init__.py} (100%) create mode 100644 sasdata/data_util/slicing/geometry.py create mode 100644 sasdata/data_util/slicing/mesh.py create mode 100644 sasdata/data_util/slicing/meshmerge.py rename sasdata/data_util/{ => slicing}/sample_polygons.py (100%) rename sasdata/data_util/{ => 
slicing}/transforms.py (100%) create mode 100644 sasdata/data_util/slicing/voronoi_mesh.py diff --git a/sasdata/data_util/meshmerge.py b/sasdata/data_util/meshmerge.py deleted file mode 100644 index eef12bf..0000000 --- a/sasdata/data_util/meshmerge.py +++ /dev/null @@ -1,89 +0,0 @@ -from typing import Sequence -from scipy.spatial import Delaunay - -import numpy as np - -from dataclasses import dataclass - -@dataclass -class Mesh: - points: np.ndarray - edges: Sequence[Sequence[int]] # List of pairs of points forming edges - cells: Sequence[Sequence[int]] # List of edges constituting a cell - - -def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray]: - """ Take two lists of polygons and find their intersections - - Polygons in each of the input variables should not overlap i.e. a point in space should be assignable to - at most one polygon in mesh_a and at most one polygon in mesh_b - - Mesh topology should be sensible, otherwise bad things might happen - - :returns: - 1) A triangulated mesh based on both sets of polygons together - 2) The indices of the mesh_a polygon that corresponds to each triangle, -1 for nothing - 3) The indices of the mesh_b polygon that corresponds to each triangle, -1 for nothing - - """ - - # Find intersections of all edges in mesh one with edges in mesh two - - new_points = [] - for edge_a in mesh_a.edges: - for edge_b in mesh_b.edges: - # - # Parametric description of intersection in terms of position along lines - # - # Simultaneous eqns (to reflect current wiki notation) - # s(x2 - x1) - t(x4 - x3) = x3 - x1 - # s(y2 - y1) - t(y4 - y3) = y3 - y1 - # - # in matrix form: - # m.(s,t) = v - # - - p1 = mesh_a.points[edge_a[0]] - p2 = mesh_a.points[edge_a[1]] - p3 = mesh_b.points[edge_b[0]] - p4 = mesh_b.points[edge_b[1]] - - m = np.array([ - [p2[0] - p1[0], p3[0] - p4[0]], - [p2[1] - p1[1], p3[1] - p4[1]]]) - - v = np.array([p3[0] - p1[0], p3[1] - p1[1]]) - - if np.linalg.det(m) == 0: - # Lines don't intersect - break - - st = np.linalg.solve(m, v) - - # As the purpose of this is finding new points for the merged mesh, we don't - # want new points if they are right at the end of the lines, hence non strict - # inequalities here - if np.any(st <= 0) or np.any(st >= 1): - # Exclude intection points, that are not on the *segments* - break - - x = p1[0] + (p2[0] - p1[1])*st[0] - y = p1[1] + (p2[1] - p1[1])*st[1] - - new_points.append((x, y)) - - # Build list of all input points, in a way that we can check for coincident points - - - - # Remove coincident points - - - # Triangulate based on these intersections - - # Find centroids of all output triangles, and find which source cells they belong to - - ## Assign -1 to all cells - ## Find centroids - ## Check whether within bounding box - ## If in bounding box, check cell properly using winding number, if inside, assign diff --git a/sasdata/data_util/geometry.py b/sasdata/data_util/slicing/__init__.py similarity index 100% rename from sasdata/data_util/geometry.py rename to sasdata/data_util/slicing/__init__.py diff --git a/sasdata/data_util/slicing/geometry.py b/sasdata/data_util/slicing/geometry.py new file mode 100644 index 0000000..e69de29 diff --git a/sasdata/data_util/slicing/mesh.py b/sasdata/data_util/slicing/mesh.py new file mode 100644 index 0000000..c27be60 --- /dev/null +++ b/sasdata/data_util/slicing/mesh.py @@ -0,0 +1,28 @@ +from typing import Sequence + +import numpy as np + +import matplotlib.pyplot as plt +from matplotlib.collections import LineCollection + +class Mesh: + def 
__init__(self, points: np.ndarray, edges: Sequence[Sequence[int]], cells: Sequence[Sequence[int]]): + self.points = points + self.edges = edges + self.cells = cells + + self._cells_to_points = None + + + def show(self, actually_show=True, **kwargs): + + ax = plt.gca() + segments = [[self.points[edge[0]], self.points[edge[1]]] for edge in self.edges] + line_collection = LineCollection(segments=segments, **kwargs) + ax.add_collection(line_collection) + + if actually_show: + plt.show() + + def show_data(self, data: np.ndarray): + raise NotImplementedError("Show data not implemented") \ No newline at end of file diff --git a/sasdata/data_util/slicing/meshmerge.py b/sasdata/data_util/slicing/meshmerge.py new file mode 100644 index 0000000..32cd8e1 --- /dev/null +++ b/sasdata/data_util/slicing/meshmerge.py @@ -0,0 +1,170 @@ +from typing import Sequence +from scipy.spatial import Delaunay + +import numpy as np + +from dataclasses import dataclass + +from sasdata.data_util.slicing.mesh import Mesh + +import matplotlib.pyplot as plt + +def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray]: + """ Take two lists of polygons and find their intersections + + Polygons in each of the input variables should not overlap i.e. a point in space should be assignable to + at most one polygon in mesh_a and at most one polygon in mesh_b + + Mesh topology should be sensible, otherwise bad things might happen + + :returns: + 1) A triangulated mesh based on both sets of polygons together + 2) The indices of the mesh_a polygon that corresponds to each triangle, -1 for nothing + 3) The indices of the mesh_b polygon that corresponds to each triangle, -1 for nothing + + """ + + # Find intersections of all edges in mesh one with edges in mesh two + + new_x = [] + new_y = [] + for edge_a in mesh_a.edges: + for edge_b in mesh_b.edges: + + p1 = mesh_a.points[edge_a[0]] + p2 = mesh_a.points[edge_a[1]] + p3 = mesh_b.points[edge_b[0]] + p4 = mesh_b.points[edge_b[1]] + + # Bounding box check + + # First edge entirely to left of other + if max((p1[0], p2[0])) < min((p3[0], p4[0])): + continue + + # First edge entirely below other + if max((p1[1], p2[1])) < min((p3[1], p4[1])): + continue + + # First edge entirely to right of other + if min((p1[0], p2[0])) > max((p3[0], p4[0])): + continue + + # First edge entirely above other + if min((p1[1], p2[1])) > max((p3[1], p4[1])): + continue + + # + # Parametric description of intersection in terms of position along lines + # + # Simultaneous eqns (to reflect current wiki notation) + # s(x2 - x1) - t(x4 - x3) = x3 - x1 + # s(y2 - y1) - t(y4 - y3) = y3 - y1 + # + # in matrix form: + # m.(s,t) = v + # + + + m = np.array([ + [p2[0] - p1[0], p3[0] - p4[0]], + [p2[1] - p1[1], p3[1] - p4[1]]]) + + v = np.array([p3[0] - p1[0], p3[1] - p1[1]]) + + if np.linalg.det(m) == 0: + # Lines don't intersect, or are colinear in a way that doesn't matter + continue + + st = np.linalg.solve(m, v) + + # As the purpose of this is finding new points for the merged mesh, we don't + # want new points if they are right at the end of the lines, hence non-strict + # inequalities here + if np.any(st <= 0) or np.any(st >= 1): + # Exclude intection points, that are not on the *segments* + continue + + x = p1[0] + (p2[0] - p1[0])*st[0] + y = p1[1] + (p2[1] - p1[1])*st[0] + + new_x.append(x) + new_y.append(y) + + + + # Build list of all input points, in a way that we can check for coincident points + + # plt.scatter(mesh_a.points[:,0], mesh_a.points[:,1]) + # plt.scatter(mesh_b.points[:,0], 
mesh_b.points[:,1]) + # plt.scatter(new_x, new_y) + # + # mesh_a.show(False) + # mesh_b.show(False, color=(.8, .5, 0)) + # + # plt.xlim([0,1]) + # plt.ylim([0,1]) + # + # plt.show() + + points = np.concatenate(( + mesh_a.points, + mesh_b.points, + np.array((new_x, new_y)).T + )) + + # plt.scatter(points[:,0], points[:,1]) + # plt.show() + + # Remove coincident points + + points = np.unique(points, axis=0) + + # Triangulate based on these intersections + + # Find centroids of all output triangles, and find which source cells they belong to + + ## Assign -1 to all cells + ## Find centroids - they're just the closed voronoi cells? + ## Check whether within bounding box + ## If in bounding box, check cell properly using winding number, if inside, assign + + +def simple_intersection(): + mesh_a = Mesh( + np.array([[0, 0.5],[1,0.5]], dtype=float), + [[0, 1]], []) + + mesh_b = Mesh( + np.array([[0.5, 0], [0.5, 1]], dtype=float), + [[0, 1]], []) + + meshmerge(mesh_a, mesh_b) + + + +def simple_intersection_2(): + mesh_a = Mesh( + np.array([[4,3],[1,3]], dtype=float), + [[0, 1]], []) + + mesh_b = Mesh( + np.array([[3, 4], [3, 1]], dtype=float), + [[0, 1]], []) + + meshmerge(mesh_a, mesh_b) +def main(): + from voronoi_mesh import voronoi_mesh + + n1 = 100 + n2 = 100 + + m1 = voronoi_mesh(np.random.random(n1), np.random.random(n1)) + m2 = voronoi_mesh(np.random.random(n2), np.random.random(n2)) + + + meshmerge(m1, m2) + +if __name__ == "__main__": + main() + # simple_intersection() \ No newline at end of file diff --git a/sasdata/data_util/sample_polygons.py b/sasdata/data_util/slicing/sample_polygons.py similarity index 100% rename from sasdata/data_util/sample_polygons.py rename to sasdata/data_util/slicing/sample_polygons.py diff --git a/sasdata/data_util/transforms.py b/sasdata/data_util/slicing/transforms.py similarity index 100% rename from sasdata/data_util/transforms.py rename to sasdata/data_util/slicing/transforms.py diff --git a/sasdata/data_util/slicing/voronoi_mesh.py b/sasdata/data_util/slicing/voronoi_mesh.py new file mode 100644 index 0000000..34a8fd7 --- /dev/null +++ b/sasdata/data_util/slicing/voronoi_mesh.py @@ -0,0 +1,37 @@ +import numpy as np +from scipy.spatial import Voronoi + + +from sasdata.data_util.slicing.mesh import Mesh + +def voronoi_mesh(x, y) -> Mesh: + + input_data = np.array((x, y)).T + voronoi = Voronoi(input_data) + + edges = set() + + for point_index, points in enumerate(voronoi.points): + + region_index = voronoi.point_region[point_index] + region = voronoi.regions[region_index] + + wrapped = region + [region[0]] + for a, b in zip(wrapped[:-1], wrapped[1:]): + if not a == -1 and not b == -1: + + # make sure the representation is unique + if a > b: + edges.add((a, b)) + else: + edges.add((b, a)) + + edges = list(edges) + + return Mesh(points=voronoi.vertices, edges=edges, cells=[]) + + +if __name__ == "__main__": + points = np.random.random((100, 2)) + mesh = voronoi_mesh(points[:,0], points[:,1]) + mesh.show() \ No newline at end of file From 1290f31b176e0f6b99e1b0bf3bd3ba3a33ce6555 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Sun, 24 Sep 2023 12:54:29 +0100 Subject: [PATCH 03/35] Triangulated mesh --- sasdata/data_util/slicing/delaunay_mesh.py | 34 ++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 sasdata/data_util/slicing/delaunay_mesh.py diff --git a/sasdata/data_util/slicing/delaunay_mesh.py b/sasdata/data_util/slicing/delaunay_mesh.py new file mode 100644 index 0000000..ef90e44 --- /dev/null +++ 
b/sasdata/data_util/slicing/delaunay_mesh.py @@ -0,0 +1,34 @@ +import numpy as np + +from scipy.spatial import Delaunay + +from sasdata.data_util.slicing.mesh import Mesh + + +def delaunay_mesh(x, y) -> Mesh: + + input_data = np.array((x, y)).T + delaunay = Delaunay(input_data) + + edges = set() + + for simplex_index, simplex in enumerate(delaunay.simplices): + + wrapped = list(simplex) + [simplex[0]] + + for a, b in zip(wrapped[:-1], wrapped[1:]): + # make sure the representation is unique + if a > b: + edges.add((a, b)) + else: + edges.add((b, a)) + + edges = list(edges) + + return Mesh(points=input_data, edges=edges, cells=[]) + + +if __name__ == "__main__": + points = np.random.random((100, 2)) + mesh = delaunay_mesh(points[:,0], points[:,1]) + mesh.show() \ No newline at end of file From 5e809eac30598f2a105e05762194e8c8d116fd49 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Mon, 25 Sep 2023 01:28:00 +0100 Subject: [PATCH 04/35] Mesh merging works --- sasdata/data_util/slicing/delaunay_mesh.py | 34 ------ sasdata/data_util/slicing/mesh.py | 28 ----- sasdata/data_util/slicing/meshes/__init__.py | 0 .../data_util/slicing/meshes/delaunay_mesh.py | 32 +++++ sasdata/data_util/slicing/meshes/mesh.py | 96 +++++++++++++++ .../slicing/{ => meshes}/meshmerge.py | 110 +++++++++++------- sasdata/data_util/slicing/meshes/util.py | 10 ++ .../data_util/slicing/meshes/voronoi_mesh.py | 20 ++++ sasdata/data_util/slicing/voronoi_mesh.py | 37 ------ test/slicers/__init__.py | 0 test/slicers/meshes_for_testing.py | 75 ++++++++++++ test/slicers/utest_meshmerge.py | 21 ++++ 12 files changed, 321 insertions(+), 142 deletions(-) delete mode 100644 sasdata/data_util/slicing/delaunay_mesh.py delete mode 100644 sasdata/data_util/slicing/mesh.py create mode 100644 sasdata/data_util/slicing/meshes/__init__.py create mode 100644 sasdata/data_util/slicing/meshes/delaunay_mesh.py create mode 100644 sasdata/data_util/slicing/meshes/mesh.py rename sasdata/data_util/slicing/{ => meshes}/meshmerge.py (51%) create mode 100644 sasdata/data_util/slicing/meshes/util.py create mode 100644 sasdata/data_util/slicing/meshes/voronoi_mesh.py delete mode 100644 sasdata/data_util/slicing/voronoi_mesh.py create mode 100644 test/slicers/__init__.py create mode 100644 test/slicers/meshes_for_testing.py create mode 100644 test/slicers/utest_meshmerge.py diff --git a/sasdata/data_util/slicing/delaunay_mesh.py b/sasdata/data_util/slicing/delaunay_mesh.py deleted file mode 100644 index ef90e44..0000000 --- a/sasdata/data_util/slicing/delaunay_mesh.py +++ /dev/null @@ -1,34 +0,0 @@ -import numpy as np - -from scipy.spatial import Delaunay - -from sasdata.data_util.slicing.mesh import Mesh - - -def delaunay_mesh(x, y) -> Mesh: - - input_data = np.array((x, y)).T - delaunay = Delaunay(input_data) - - edges = set() - - for simplex_index, simplex in enumerate(delaunay.simplices): - - wrapped = list(simplex) + [simplex[0]] - - for a, b in zip(wrapped[:-1], wrapped[1:]): - # make sure the representation is unique - if a > b: - edges.add((a, b)) - else: - edges.add((b, a)) - - edges = list(edges) - - return Mesh(points=input_data, edges=edges, cells=[]) - - -if __name__ == "__main__": - points = np.random.random((100, 2)) - mesh = delaunay_mesh(points[:,0], points[:,1]) - mesh.show() \ No newline at end of file diff --git a/sasdata/data_util/slicing/mesh.py b/sasdata/data_util/slicing/mesh.py deleted file mode 100644 index c27be60..0000000 --- a/sasdata/data_util/slicing/mesh.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import Sequence - 
-import numpy as np - -import matplotlib.pyplot as plt -from matplotlib.collections import LineCollection - -class Mesh: - def __init__(self, points: np.ndarray, edges: Sequence[Sequence[int]], cells: Sequence[Sequence[int]]): - self.points = points - self.edges = edges - self.cells = cells - - self._cells_to_points = None - - - def show(self, actually_show=True, **kwargs): - - ax = plt.gca() - segments = [[self.points[edge[0]], self.points[edge[1]]] for edge in self.edges] - line_collection = LineCollection(segments=segments, **kwargs) - ax.add_collection(line_collection) - - if actually_show: - plt.show() - - def show_data(self, data: np.ndarray): - raise NotImplementedError("Show data not implemented") \ No newline at end of file diff --git a/sasdata/data_util/slicing/meshes/__init__.py b/sasdata/data_util/slicing/meshes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sasdata/data_util/slicing/meshes/delaunay_mesh.py b/sasdata/data_util/slicing/meshes/delaunay_mesh.py new file mode 100644 index 0000000..45e2087 --- /dev/null +++ b/sasdata/data_util/slicing/meshes/delaunay_mesh.py @@ -0,0 +1,32 @@ +import numpy as np +from scipy.spatial import Delaunay + +from sasdata.data_util.slicing.meshes.mesh import Mesh + +def delaunay_mesh(x, y) -> Mesh: + """ Create a triangulated mesh based on input points """ + + input_data = np.array((x, y)).T + delaunay = Delaunay(input_data) + + return Mesh(points=input_data, cells=delaunay.simplices) + + +if __name__ == "__main__": + import matplotlib.pyplot as plt + + points = np.random.random((100, 2)) + mesh = delaunay_mesh(points[:,0], points[:,1]) + mesh.show(actually_show=False) + + print(mesh.cells[50]) + + # pick random cell to show + for cell in mesh.cells_to_edges[10]: + a, b = mesh.edges[cell] + plt.plot( + [mesh.points[a][0], mesh.points[b][0]], + [mesh.points[a][1], mesh.points[b][1]], + color='r') + + plt.show() diff --git a/sasdata/data_util/slicing/meshes/mesh.py b/sasdata/data_util/slicing/meshes/mesh.py new file mode 100644 index 0000000..b52b3e8 --- /dev/null +++ b/sasdata/data_util/slicing/meshes/mesh.py @@ -0,0 +1,96 @@ +from typing import Sequence + +import numpy as np + +import matplotlib.pyplot as plt +from matplotlib.collections import LineCollection + +from sasdata.data_util.slicing.meshes.util import closed_loop_edges + +class Mesh: + def __init__(self, + points: np.ndarray, + cells: Sequence[Sequence[int]]): + + """ + Object representing a mesh. + + Parameters are the values: + mesh points + map from edge to points + map from cells to edges + + it is done this way to ensure a non-redundant representation of cells and edges, + however there are no checks for the topology of the mesh, this is assumed to be done by + whatever creates it. There are also no checks for ordering of cells. 
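+
+        As a minimal sketch (hypothetical values), a single unit-square cell would
+        be built as:
+
+            points = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=float)
+            square = Mesh(points, cells=[[0, 1, 2, 3]])
+            # square.edges is then derived as the four unique point-index pairs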
+ + :param points: points in 2D forming vertices of the mesh + :param cells: ordered lists of indices of points forming each cell (face) + + """ + + self.points = points + self.cells = cells + + # Get edges + + edges = set() + for cell_index, cell in enumerate(cells): + + for a, b in closed_loop_edges(cell): + # make sure the representation is unique + if a > b: + edges.add((a, b)) + else: + edges.add((b, a)) + + self.edges = list(edges) + + # Associate edges with faces + + edge_lookup = {edge: i for i, edge in enumerate(self.edges)} + self.cells_to_edges = [] + + for cell in cells: + + this_cell_data = [] + + for a, b in closed_loop_edges(cell): + # make sure the representation is unique + if a > b: + this_cell_data.append(edge_lookup[(a, b)]) + else: + this_cell_data.append(edge_lookup[(b, a)]) + + self.cells_to_edges.append(this_cell_data) + + # Counts for elements + self.n_points = self.points.shape[0] + self.n_edges = len(self.edges) + self.n_cells = len(self.cells) + + def show(self, actually_show=True, show_labels=False, **kwargs): + """ Show on a plot """ + ax = plt.gca() + segments = [[self.points[edge[0]], self.points[edge[1]]] for edge in self.edges] + line_collection = LineCollection(segments=segments, **kwargs) + ax.add_collection(line_collection) + + if show_labels: + text_color = kwargs["color"] if "color" in kwargs else 'k' + for i, cell in enumerate(self.cells): + xy = np.sum(self.points[cell, :], axis=0)/len(cell) + ax.text(xy[0], xy[1], str(i), horizontalalignment="center", verticalalignment="center", color=text_color) + + x_limits = [np.min(self.points[:,0]), np.max(self.points[:,0])] + y_limits = [np.min(self.points[:,1]), np.max(self.points[:,1])] + + plt.xlim(x_limits) + plt.ylim(y_limits) + + if actually_show: + plt.show() + + def show_data(self, data: np.ndarray, show_mesh=True): + """ Show with data """ + raise NotImplementedError("Show data not implemented") \ No newline at end of file diff --git a/sasdata/data_util/slicing/meshmerge.py b/sasdata/data_util/slicing/meshes/meshmerge.py similarity index 51% rename from sasdata/data_util/slicing/meshmerge.py rename to sasdata/data_util/slicing/meshes/meshmerge.py index 32cd8e1..3ce52ba 100644 --- a/sasdata/data_util/slicing/meshmerge.py +++ b/sasdata/data_util/slicing/meshes/meshmerge.py @@ -1,13 +1,9 @@ -from typing import Sequence -from scipy.spatial import Delaunay - import numpy as np -from dataclasses import dataclass - -from sasdata.data_util.slicing.mesh import Mesh +from sasdata.data_util.slicing.meshes.mesh import Mesh +from sasdata.data_util.slicing.meshes.delaunay_mesh import delaunay_mesh +from sasdata.data_util.slicing.meshes.util import closed_loop_edges -import matplotlib.pyplot as plt def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray]: """ Take two lists of polygons and find their intersections @@ -15,7 +11,8 @@ def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray] Polygons in each of the input variables should not overlap i.e. a point in space should be assignable to at most one polygon in mesh_a and at most one polygon in mesh_b - Mesh topology should be sensible, otherwise bad things might happen + Mesh topology should be sensible, otherwise bad things might happen, also, the cells of the input meshes + must be in order (which is assumed by the mesh class constructor anyway). 
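+
+    As a sketch, a typical call (mesh_a and mesh_b being any two well-formed meshes):
+
+        merged, to_a, to_b = meshmerge(mesh_a, mesh_b)
+        # to_a[i] is the index of the mesh_a polygon containing triangle i, or -1
+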
:returns: 1) A triangulated mesh based on both sets of polygons together @@ -95,17 +92,6 @@ def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray] # Build list of all input points, in a way that we can check for coincident points - # plt.scatter(mesh_a.points[:,0], mesh_a.points[:,1]) - # plt.scatter(mesh_b.points[:,0], mesh_b.points[:,1]) - # plt.scatter(new_x, new_y) - # - # mesh_a.show(False) - # mesh_b.show(False, color=(.8, .5, 0)) - # - # plt.xlim([0,1]) - # plt.ylim([0,1]) - # - # plt.show() points = np.concatenate(( mesh_a.points, @@ -113,8 +99,6 @@ def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray] np.array((new_x, new_y)).T )) - # plt.scatter(points[:,0], points[:,1]) - # plt.show() # Remove coincident points @@ -122,37 +106,75 @@ def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray] # Triangulate based on these intersections + output_mesh = delaunay_mesh(points[:, 0], points[:, 1]) + # Find centroids of all output triangles, and find which source cells they belong to - ## Assign -1 to all cells - ## Find centroids - they're just the closed voronoi cells? - ## Check whether within bounding box - ## If in bounding box, check cell properly using winding number, if inside, assign + ## step 1) Assign -1 to all cells of original meshes + assignments_a = -np.ones(output_mesh.n_cells, dtype=int) + assignments_b = -np.ones(output_mesh.n_cells, dtype=int) + + ## step 2) Find centroids of triangulated mesh (just needs to be a point inside, but this is a good one) + centroids = [] + for cell in output_mesh.cells: + centroid = np.sum(output_mesh.points[cell, :]/3, axis=0) + centroids.append(centroid) + + ## step 3) Perform checks based on winding number method (see wikipedia Point in Polygon). + for mesh, assignments in [ + (mesh_a, assignments_a), + (mesh_b, assignments_b)]: + + for centroid_index, centroid in enumerate(centroids): + for cell_index, cell in enumerate(mesh.cells): + # Bounding box check + points = mesh.points[cell, :] + if np.any(centroid < np.min(points, axis=0)): # x or y less than any in polygon + continue -def simple_intersection(): - mesh_a = Mesh( - np.array([[0, 0.5],[1,0.5]], dtype=float), - [[0, 1]], []) + if np.any(centroid > np.max(points, axis=0)): # x or y greater than any in polygon + continue - mesh_b = Mesh( - np.array([[0.5, 0], [0.5, 1]], dtype=float), - [[0, 1]], []) + # Winding number check - count directional crossings of vertical half line from centroid + winding_number = 0 + for i1, i2 in closed_loop_edges(cell): + p1 = mesh.points[i1, :] + p2 = mesh.points[i2, :] - meshmerge(mesh_a, mesh_b) + # if the section xs do not straddle the x=centroid_x coordinate, then the + # edge cannot cross the half line. 
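+                    # (e.g. with centroid x = 0.5, an edge from x = 0.2 to x = 0.9 straddles
+                    # the half line, while one from x = 0.6 to x = 0.9 cannot cross it)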
+ # If it does, then remember which way it was + # * Careful about ends + # * Also, note that the p1[0] == p2[0] -> (no contribution) case is covered by the strict inequality + if p1[0] > centroid[0] >= p2[0]: + left_right = -1 + elif p2[0] > centroid[0] >= p1[0]: + left_right = 1 + else: + continue + # Find the y point that it crosses x=centroid at + # note: denominator cannot be zero because of strict inequality above + gradient = (p2[1] - p1[1]) / (p2[0] - p1[0]) + x_delta = centroid[0] - p1[0] + y = p1[1] + x_delta * gradient + if y > centroid[1]: + winding_number += left_right -def simple_intersection_2(): - mesh_a = Mesh( - np.array([[4,3],[1,3]], dtype=float), - [[0, 1]], []) - mesh_b = Mesh( - np.array([[3, 4], [3, 1]], dtype=float), - [[0, 1]], []) + if abs(winding_number) > 0: + # Do assignment of input cell to output triangle index + assignments[centroid_index] = cell_index + + # end cell loop + + # end centroid loop + + return output_mesh, assignments_a, assignments_b + - meshmerge(mesh_a, mesh_b) def main(): from voronoi_mesh import voronoi_mesh @@ -163,8 +185,10 @@ def main(): m2 = voronoi_mesh(np.random.random(n2), np.random.random(n2)) - meshmerge(m1, m2) + mesh, _, _ = meshmerge(m1, m2) + + mesh.show() + if __name__ == "__main__": main() - # simple_intersection() \ No newline at end of file diff --git a/sasdata/data_util/slicing/meshes/util.py b/sasdata/data_util/slicing/meshes/util.py new file mode 100644 index 0000000..b78a9e0 --- /dev/null +++ b/sasdata/data_util/slicing/meshes/util.py @@ -0,0 +1,10 @@ +from typing import Sequence, TypeVar + +T = TypeVar("T") + +def closed_loop_edges(values: Sequence[T]) -> tuple[T, T]: + """ Generator for a closed loop of edge pairs """ + for pair in zip(values, values[1:]): + yield pair + + yield values[-1], values[0] \ No newline at end of file diff --git a/sasdata/data_util/slicing/meshes/voronoi_mesh.py b/sasdata/data_util/slicing/meshes/voronoi_mesh.py new file mode 100644 index 0000000..77db2a6 --- /dev/null +++ b/sasdata/data_util/slicing/meshes/voronoi_mesh.py @@ -0,0 +1,20 @@ +import numpy as np +from scipy.spatial import Voronoi + + +from sasdata.data_util.slicing.meshes.mesh import Mesh + +def voronoi_mesh(x, y) -> Mesh: + + input_data = np.array((x.reshape(-1), y.reshape(-1))).T + voronoi = Voronoi(input_data) + + finite_cells = [region for region in voronoi.regions if -1 not in region and len(region) > 0] + + return Mesh(points=voronoi.vertices, cells=finite_cells) + + +if __name__ == "__main__": + points = np.random.random((100, 2)) + mesh = voronoi_mesh(points[:,0], points[:,1]) + mesh.show() \ No newline at end of file diff --git a/sasdata/data_util/slicing/voronoi_mesh.py b/sasdata/data_util/slicing/voronoi_mesh.py deleted file mode 100644 index 34a8fd7..0000000 --- a/sasdata/data_util/slicing/voronoi_mesh.py +++ /dev/null @@ -1,37 +0,0 @@ -import numpy as np -from scipy.spatial import Voronoi - - -from sasdata.data_util.slicing.mesh import Mesh - -def voronoi_mesh(x, y) -> Mesh: - - input_data = np.array((x, y)).T - voronoi = Voronoi(input_data) - - edges = set() - - for point_index, points in enumerate(voronoi.points): - - region_index = voronoi.point_region[point_index] - region = voronoi.regions[region_index] - - wrapped = region + [region[0]] - for a, b in zip(wrapped[:-1], wrapped[1:]): - if not a == -1 and not b == -1: - - # make sure the representation is unique - if a > b: - edges.add((a, b)) - else: - edges.add((b, a)) - - edges = list(edges) - - return Mesh(points=voronoi.vertices, edges=edges, cells=[]) - - 
-if __name__ == "__main__": - points = np.random.random((100, 2)) - mesh = voronoi_mesh(points[:,0], points[:,1]) - mesh.show() \ No newline at end of file diff --git a/test/slicers/__init__.py b/test/slicers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/slicers/meshes_for_testing.py b/test/slicers/meshes_for_testing.py new file mode 100644 index 0000000..ff87dc8 --- /dev/null +++ b/test/slicers/meshes_for_testing.py @@ -0,0 +1,75 @@ +""" +Meshes used in testing along with some expected values +""" + +import numpy as np + +from sasdata.data_util.slicing.meshes.voronoi_mesh import voronoi_mesh +from sasdata.data_util.slicing.meshes.mesh import Mesh +from sasdata.data_util.slicing.meshes.meshmerge import meshmerge + +coords = np.arange(-4, 5) +grid_mesh = voronoi_mesh(*np.meshgrid(coords, coords)) + + +item_1 = np.array([ + [-3.5, -0.5], + [-0.5, 3.5], + [ 0.5, 3.5], + [ 3.5, -0.5], + [ 0.0, 1.5]], dtype=float) + +item_2 = np.array([ + [-1.0, -2.0], + [-2.0, -2.0], + [-2.0, -1.0], + [-1.0, -1.0]], dtype=float) + +mesh_points = np.concatenate((item_1, item_2), axis=0) +cells = [[0,1,2,3,4],[5,6,7,8]] + +shape_mesh = Mesh(mesh_points, cells) + +# Subset of the mappings that meshmerge should include +# This can be read off the plots generated below +expected_shape_mappings = [ + (98, -1), + (99, -1), + (12, 0), + (1, -1), + (148, 1), + (149, 1), + (110, 1), + (144, -1), + (123, -1)] + + +expected_grid_mappings = [ + (89, 1), + (146, 29), + (66, 34), + (112, 45) +] + + +if __name__ == "__main__": + + import matplotlib.pyplot as plt + + combined_mesh, _, _ = meshmerge(grid_mesh, shape_mesh) + + plt.figure() + combined_mesh.show(actually_show=False, show_labels=True, color='k') + grid_mesh.show(actually_show=False, show_labels=True, color='r') + + plt.xlim([-4, 4]) + plt.ylim([-4, 4]) + + plt.figure() + combined_mesh.show(actually_show=False, show_labels=True, color='k') + shape_mesh.show(actually_show=False, show_labels=True, color='r') + + plt.xlim([-4, 4]) + plt.ylim([-4, 4]) + + plt.show() diff --git a/test/slicers/utest_meshmerge.py b/test/slicers/utest_meshmerge.py new file mode 100644 index 0000000..d1e16f2 --- /dev/null +++ b/test/slicers/utest_meshmerge.py @@ -0,0 +1,21 @@ +""" +Tests for mesh merging operations. 
+ +It's pretty hard to test componentwise, but we can do some tests of the general behaviour +""" + +from sasdata.data_util.slicing.meshes.meshmerge import meshmerge +from test.slicers.meshes_for_testing import ( + grid_mesh, shape_mesh, expected_grid_mappings, expected_shape_mappings) + + +def test_meshmerge_mappings(): + + combined_mesh, grid_mappings, shape_mappings = meshmerge(grid_mesh, shape_mesh) + + for triangle_cell, grid_cell in expected_grid_mappings: + assert grid_mappings[triangle_cell] == grid_cell + + for triangle_cell, shape_cell in expected_shape_mappings: + assert shape_mappings[triangle_cell] == shape_cell + From f7fc0a5272c1ab8f5ccda48a4e657359ad17c97e Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Thu, 28 Sep 2023 02:56:00 +0100 Subject: [PATCH 05/35] Implementation of Rebinner base class --- sasdata/data_util/slicing/meshes/mesh.py | 28 +++++ sasdata/data_util/slicing/rebinning.py | 128 +++++++++++++++++++++++ test/slicers/utest_meshmerge.py | 7 ++ 3 files changed, 163 insertions(+) create mode 100644 sasdata/data_util/slicing/rebinning.py diff --git a/sasdata/data_util/slicing/meshes/mesh.py b/sasdata/data_util/slicing/meshes/mesh.py index b52b3e8..0f12102 100644 --- a/sasdata/data_util/slicing/meshes/mesh.py +++ b/sasdata/data_util/slicing/meshes/mesh.py @@ -69,6 +69,34 @@ def __init__(self, self.n_edges = len(self.edges) self.n_cells = len(self.cells) + # Areas + self._areas = None + + @property + def areas(self): + """ Areas of cells """ + + if self._areas is None: + # Calculate areas + areas = [] + for cell in self.cells: + # Use triangle shoelace formula, basically calculate the + # determinant based on of triangles with one point at 0,0 + a_times_2 = 0.0 + for i1, i2 in closed_loop_edges(cell): + p1 = self.points[i1, :] + p2 = self.points[i2, :] + a_times_2 += p1[0]*p2[1] - p1[1]*p2[0] + + areas.append(0.5*np.abs(a_times_2)) + + # Save in cache + self._areas = np.ndarray(areas) + + # Return cache + return self._areas + + def show(self, actually_show=True, show_labels=False, **kwargs): """ Show on a plot """ ax = plt.gca() diff --git a/sasdata/data_util/slicing/rebinning.py b/sasdata/data_util/slicing/rebinning.py new file mode 100644 index 0000000..c6ba607 --- /dev/null +++ b/sasdata/data_util/slicing/rebinning.py @@ -0,0 +1,128 @@ +from abc import ABC, abstractmethod +from typing import Optional +from dataclasses import dataclass + +import numpy as np + +from sasdata.data_util.slicing.meshes.mesh import Mesh +from sasdata.data_util.slicing.meshes.voronoi_mesh import voronoi_mesh +from sasdata.data_util.slicing.meshes.meshmerge import meshmerge + + +@dataclass +class CacheData: + """ Data cached for repeated calculations with the same coordinates """ + input_coordinates: np.ndarray # Input data + input_coordinates_mesh: Mesh # Mesh of the input data + merged_mesh_data: tuple[Mesh, np.ndarray, np.ndarray] # mesh information about the merging + + +class Rebinner(): + + allowable_orders = [-1,0,1] + + def __init__(self, order): + """ Base class for rebinning methods""" + + self._order = order + self._bin_mesh_cache: Optional[Mesh] = None # cached version of the output bin mesh + + # Output dependent caching + self._input_cache: Optional[CacheData] = None + + if order not in Rebinner.allowable_orders: + raise ValueError(f"Expected order to be in {Rebinner.allowable_orders}, got {order}") + + @abstractmethod + def _bin_coordinates(self) -> np.ndarray: + """ Coordinates for the output bins """ + + @abstractmethod + def _bin_mesh(self) -> Mesh: + """ Get the 
meshes used for binning """
+
+    @property
+    def bin_mesh(self):
+        if self._bin_mesh_cache is None:
+            bin_mesh = self._bin_mesh()
+            self._data_mesh_cache = bin_mesh
+        else:
+            return self._bin_mesh_cache
+
+    def _post_processing(self, coordinates, values) -> tuple[np.ndarray, np.ndarray]:
+        """ Perform post-processing on the mesh binned values """
+        # Default is to do nothing, override if needed
+        return coordinates, values
+
+    def _do_binning(self, data):
+        """ Main binning algorithm """
+
+    def _calculate(self, input_coordinates: np.ndarray, input_data: np.ndarray) -> np.ndarray:
+        """ Main calculation """
+
+        if self._order == -1:
+            # Construct the input output mapping just based on input points being the output cells,
+            # Equivalent to the original binning method
+
+            pass
+
+        else:
+            # Use a mapping based on meshes
+
+            # Either create the appropriate mesh, or retrieve it from the cache.
+            # Why not use a hash? Hashing takes time, and equality checks are pretty fast. We need to
+            # check equality when there is a hit anyway (for the very rare chance of a collision), and
+            # hits are the most common case: we want it to work 100% of the time, not 99.9999%
+            if self._input_cache is not None and np.all(self._input_cache.input_coordinates == input_coordinates):
+
+                input_coordinate_mesh = self._input_cache.input_coordinates_mesh
+                merge_data = self._input_cache.merged_mesh_data
+
+            else:
+                # Calculate mesh data
+                input_coordinate_mesh = voronoi_mesh(input_coordinates[:,0], input_coordinates[:, 1])
+                self._data_mesh_cahce = input_coordinate_mesh
+
+                merge_data = meshmerge(self.bin_mesh, input_coordinate_mesh)
+
+                # Cache mesh data
+                self._input_cache = CacheData(
+                    input_coordinates=input_coordinates,
+                    input_coordinates_mesh=input_coordinate_mesh,
+                    merged_mesh_data=merge_data)
+
+            merged_mesh, merged_to_input, merged_to_output = merge_data
+
+            # Calculate values according to the order parameter
+
+            if self._order == 0:
+                # Based on the overlap of cells only
+
+                input_areas = input_coordinate_mesh.areas
+                output = np.zeros(self.bin_mesh.n_cells, dtype=float)
+
+                for input_index, output_index, area in zip(merged_to_input, merged_to_output, merged_mesh.areas):
+                    output[output_index] += input_data[input_index] * area / input_areas[input_data]
+
+                return output
+
+            elif self._order == 1:
+                raise NotImplementedError("1st order (linear) interpolation currently not implemented")
+
+            else:
+                raise ValueError(f"Expected order to be in {Rebinner.allowable_orders}, got {self._order}")
+
+    def sum(self, input_coordinates: np.ndarray, data: np.ndarray) -> np.ndarray:
+        """ Return the summed data in the output bins """
+        return self._calculate(input_coordinates, data)
+
+    def error_propagate(self, input_coordinates: np.ndarray, data: np.ndarray, errors) -> np.ndarray:
+        raise NotImplementedError("Error propagation not implemented yet")
+
+    def resolution_propagate(self, input_coordinates: np.ndarray, data: np.ndarray, errors) -> np.ndarray:
+        raise NotImplementedError("Resolution propagation not implemented yet")
+
+    def average(self, input_coordinates: np.ndarray, data: np.ndarray) -> np.ndarray:
+        """ Return the averaged data in the output bins """
+        return self._calculate(input_coordinates, data) / self.bin_mesh.areas
+
diff --git a/test/slicers/utest_meshmerge.py b/test/slicers/utest_meshmerge.py
index d1e16f2..f745d02 100644
--- a/test/slicers/utest_meshmerge.py
+++ b/test/slicers/utest_meshmerge.py
@@ -10,6 +10,13 @@ def test_meshmerge_mappings():
+    """ Test the output of meshmerge is correct
+
+    IMPORTANT IF TESTS FAIL!!!...
The docs for scipy.spatial.Voronoi and Delaunay
+    say that the ordering of faces might depend on machine precision. Thus, these
+    tests might not be reliable... we'll see how they play out
+    """
+
     combined_mesh, grid_mappings, shape_mappings = meshmerge(grid_mesh, shape_mesh)
 

From b8f06946564f9c4cc56da53397455fd5965f7de8 Mon Sep 17 00:00:00 2001
From: lucas-wilkins
Date: Thu, 28 Sep 2023 10:06:44 +0100
Subject: [PATCH 06/35] Work towards demo

---
 sasdata/data_util/slicing/meshes/mesh.py      | 29 +++++++++++++--
 .../data_util/slicing/meshes/voronoi_mesh.py  | 11 ++++++
 sasdata/data_util/slicing/rebinning.py        | 21 ++++++++---
 sasdata/data_util/slicing/slicer_demo.py      | 19 ++++++++++
 .../data_util/slicing/slicers/AnularSector.py | 35 +++++++++++++++++++
 5 files changed, 109 insertions(+), 6 deletions(-)
 create mode 100644 sasdata/data_util/slicing/slicer_demo.py
 create mode 100644 sasdata/data_util/slicing/slicers/AnularSector.py

diff --git a/sasdata/data_util/slicing/meshes/mesh.py b/sasdata/data_util/slicing/meshes/mesh.py
index 0f12102..cad7b5f 100644
--- a/sasdata/data_util/slicing/meshes/mesh.py
+++ b/sasdata/data_util/slicing/meshes/mesh.py
@@ -3,6 +3,7 @@
 import numpy as np
 
 import matplotlib.pyplot as plt
+from matplotlib import cm
 from matplotlib.collections import LineCollection
 
 from sasdata.data_util.slicing.meshes.util import closed_loop_edges
@@ -72,6 +73,12 @@ def __init__(self,
         # Areas
         self._areas = None
 
+    def find_locations(self, points):
+        """ Find indices of cells containing the input points """
+
+
+
+
     @property
     def areas(self):
         """ Areas of cells """
@@ -119,6 +126,24 @@ def show(self, actually_show=True, show_labels=False, **kwargs):
         if actually_show:
             plt.show()
 
-    def show_data(self, data: np.ndarray, show_mesh=True):
+    def show_data(self, data: np.ndarray, cmap='winter', mesh_color='white', show_mesh=True, actually_show=True):
         """ Show with data """
-        raise NotImplementedError("Show data not implemented")
\ No newline at end of file
+
+        colormap = cm.get_cmap(cmap, 256)
+
+        cmin = np.min(data)
+        cmax = np.max(data)
+
+        color_index_map = np.array(255 * (data - cmin) / (cmax - cmin), dtype=int)
+
+        for cell, color_index in zip(self.cells, color_index_map):
+
+            color = colormap(color_index)
+
+            plt.fill(self.points[cell, 0], self.points[cell, 1], color=color)
+
+        if show_mesh:
+            self.show(actually_show=False, color=mesh_color)
+
+        if actually_show:
+            self.show()
\ No newline at end of file
diff --git a/sasdata/data_util/slicing/meshes/voronoi_mesh.py b/sasdata/data_util/slicing/meshes/voronoi_mesh.py
index 77db2a6..9754880 100644
--- a/sasdata/data_util/slicing/meshes/voronoi_mesh.py
+++ b/sasdata/data_util/slicing/meshes/voronoi_mesh.py
@@ -7,10 +7,21 @@ def voronoi_mesh(x, y) -> Mesh:
 
     input_data = np.array((x.reshape(-1), y.reshape(-1))).T
+
+    # Need to make sure mesh covers a finite region, probably not important for
+    # much data stuff, but is important for plotting
+    # To do this first need to find an appropriate region
+    # Then we need to adjust the mesh to deal with these points
+
     voronoi = Voronoi(input_data)
 
+
+
     finite_cells = [region for region in voronoi.regions if -1 not in region and len(region) > 0]
 
+
+
     return Mesh(points=voronoi.vertices, cells=finite_cells)
diff --git a/sasdata/data_util/slicing/rebinning.py b/sasdata/data_util/slicing/rebinning.py
index c6ba607..06ace87 100644
--- a/sasdata/data_util/slicing/rebinning.py
+++ b/sasdata/data_util/slicing/rebinning.py
@@ -19,7 +19,6 @@ class CacheData:
 
 class Rebinner():
-    allowable_orders = [-1,0,1]
 
     def
__init__(self, order): """ Base class for rebinning methods""" @@ -30,8 +29,9 @@ def __init__(self, order): # Output dependent caching self._input_cache: Optional[CacheData] = None - if order not in Rebinner.allowable_orders: - raise ValueError(f"Expected order to be in {Rebinner.allowable_orders}, got {order}") + if order not in self.allowable_orders: + raise ValueError(f"Expected order to be in {self.allowable_orders}, got {order}") + @abstractmethod def _bin_coordinates(self) -> np.ndarray: @@ -41,6 +41,10 @@ def _bin_coordinates(self) -> np.ndarray: def _bin_mesh(self) -> Mesh: """ Get the meshes used for binning """ + @property + def allowable_orders(self) -> list[int]: + return [-1, 0, 1] + @property def bin_mesh(self): if self._bin_mesh_cache is None: @@ -104,13 +108,22 @@ def _calculate(self, input_coordinates: np.ndarray, input_data: np.ndarray) -> n for input_index, output_index, area in zip(merged_to_input, merged_to_output, merged_mesh.areas): output[output_index] += input_data[input_index] * area / input_areas[input_data] + return output elif self._order == 1: + # Linear interpolation requires the following relationship with the data, + # as the input data is the total over the whole input cell, the linear + # interpolation requires continuity at the vertices, and a constraint on the + # integral. + # + # We can take each of the input points, and the associated values, and solve a system + # of linear equations that gives a total value. + raise NotImplementedError("1st order (linear) interpolation currently not implemented") else: - raise ValueError(f"Expected order to be in {Rebinner.allowable_orders}, got {self._order}") + raise ValueError(f"Expected order to be in {self.allowable_orders}, got {self._order}") def sum(self, input_coordinates: np.ndarray, data: np.ndarray) -> np.ndarray: """ Return the summed data in the output bins """ diff --git a/sasdata/data_util/slicing/slicer_demo.py b/sasdata/data_util/slicing/slicer_demo.py new file mode 100644 index 0000000..775c1d9 --- /dev/null +++ b/sasdata/data_util/slicing/slicer_demo.py @@ -0,0 +1,19 @@ +""" Dev docs: """ + +import numpy as np + +from sasdata.data_util.slicing.slicers import AnularSector +from sasdata.data_util.slicing.meshes.mesh import Mesh +from sasdata.data_util.slicing.meshes.voronoi_mesh import voronoi_mesh + + + +if __name__ == "__main__": + + # Demo of sums, annular sector over some not very circular data + + q_range = 1.5 + + test_coordinates = (2*q_range)*(np.random.random((100, 2))-0.5) + + # Demo of averaging, annular sector over ring shaped data \ No newline at end of file diff --git a/sasdata/data_util/slicing/slicers/AnularSector.py b/sasdata/data_util/slicing/slicers/AnularSector.py new file mode 100644 index 0000000..bf3021d --- /dev/null +++ b/sasdata/data_util/slicing/slicers/AnularSector.py @@ -0,0 +1,35 @@ +import numpy as np + +from sasdata.data_util.slicing.rebinning import Rebinner +from sasdata.data_util.slicing.meshes.mesh import Mesh + +class AnularSector(Rebinner): + """ A single annular sector (wedge sum)""" + def __init__(self, q0: float, q1: float, phi0: float, phi1: float, order: int=1, points_per_degree: int=2): + super().__init__(order) + + self.q0 = q0 + self.q1 = q1 + self.phi0 = phi0 + self.phi1 = phi1 + + self.points_per_degree = points_per_degree + + def _bin_mesh(self) -> Mesh: + + n_points = 1 + 180*self.points_per_degree*(self.phi1 - self.phi0) / np.pi + + angles = np.linspace(self.phi0, self.phi1, n_points) + + row1 = self.q0 * np.array([np.cos(angles), np.sin(angles)]) 
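+        # row1 traces the q0 arc forwards and row2 (next line) traces the q1 arc in
+        # reverse, so concatenating them walks the sector boundary as one closed loop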
+ row2 = self.q1 * np.array([np.cos(angles), np.sin(angles)])[:, ::-1] + + points = np.concatenate((row1, row2), axis=1) + + cells = [i for i in range(2*n_points)] + + return Mesh(points=points, cells=cells) + + def _bin_coordinates(self) -> np.ndarray: + return np.array([], dtype=float) + From 583c8b48e2c4f280ccd71813a3f77d15ac7d6857 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Thu, 28 Sep 2023 12:59:02 +0100 Subject: [PATCH 07/35] Voronoi mesh edges and ordering --- sasdata/data_util/slicing/meshes/mesh.py | 16 +++- .../data_util/slicing/meshes/voronoi_mesh.py | 80 +++++++++++++++++-- test/slicers/meshes_for_testing.py | 46 +++++++---- 3 files changed, 114 insertions(+), 28 deletions(-) diff --git a/sasdata/data_util/slicing/meshes/mesh.py b/sasdata/data_util/slicing/meshes/mesh.py index cad7b5f..05f4d33 100644 --- a/sasdata/data_util/slicing/meshes/mesh.py +++ b/sasdata/data_util/slicing/meshes/mesh.py @@ -76,7 +76,7 @@ def __init__(self, def find_locations(self, points): """ Find indices of cells containing the input points """ - + @property @@ -98,7 +98,7 @@ def areas(self): areas.append(0.5*np.abs(a_times_2)) # Save in cache - self._areas = np.ndarray(areas) + self._areas = np.array(areas) # Return cache return self._areas @@ -126,11 +126,21 @@ def show(self, actually_show=True, show_labels=False, **kwargs): if actually_show: plt.show() - def show_data(self, data: np.ndarray, cmap='winter', mesh_color='white', show_mesh=True, actually_show=True): + def show_data(self, + data: np.ndarray, + cmap='winter', + mesh_color='white', + show_mesh=True, + actually_show=True, + density=False): + """ Show with data """ colormap = cm.get_cmap(cmap, 256) + if density: + data = data / self.areas + cmin = np.min(data) cmax = np.max(data) diff --git a/sasdata/data_util/slicing/meshes/voronoi_mesh.py b/sasdata/data_util/slicing/meshes/voronoi_mesh.py index 9754880..3497fbb 100644 --- a/sasdata/data_util/slicing/meshes/voronoi_mesh.py +++ b/sasdata/data_util/slicing/meshes/voronoi_mesh.py @@ -4,28 +4,92 @@ from sasdata.data_util.slicing.meshes.mesh import Mesh -def voronoi_mesh(x, y) -> Mesh: +def voronoi_mesh(x, y, debug_plot=False) -> Mesh: + """ Create a mesh based on a voronoi diagram of points """ input_data = np.array((x.reshape(-1), y.reshape(-1))).T # Need to make sure mesh covers a finite region, probably not important for # much data stuff, but is important for plotting - # To do this first need to find an appropriate region - # Then we need to adjust the mesh to deal with these points + # + # * We want the cells at the edge of the mesh to have a reasonable size, definitely not infinite + # * The exact size doesn't matter that much + # * It should work well with a grid, but also + # * ...it should be robust so that if the data isn't on a grid, it doesn't cause any serious problems + # + # Plan: Create a square border of points that are totally around the points, this is + # at the distance it would be if it was an extra row of grid points + # to do this we'll need + # 1) an estimate of the grid spacing + # 2) the bounding box of the grid + # + # Use the median area of finite voronoi cells as an estimate voronoi = Voronoi(input_data) + finite_cells = [region for region in voronoi.regions if -1 not in region and len(region) > 0] + premesh = Mesh(points=voronoi.vertices, cells=finite_cells) + area_spacing = np.median(premesh.areas) + gap = np.sqrt(area_spacing) + # Bounding box is easy + x_min, y_min = np.min(input_data, axis=0) + x_max, y_max = np.max(input_data, axis=0) + # Create a 
border + n_x = np.round((x_max - x_min)/gap).astype(int) + n_y = np.round((y_max - y_min)/gap).astype(int) - finite_cells = [region for region in voronoi.regions if -1 not in region and len(region) > 0] + top_bottom_xs = np.linspace(x_min - gap, x_max + gap, n_x + 3) + left_right_ys = np.linspace(y_min, y_max, n_y + 1) + top = np.array([top_bottom_xs, (y_max + gap) * np.ones_like(top_bottom_xs)]) + bottom = np.array([top_bottom_xs, (y_min - gap) * np.ones_like(top_bottom_xs)]) + left = np.array([(x_min - gap) * np.ones_like(left_right_ys), left_right_ys]) + right = np.array([(x_max + gap) * np.ones_like(left_right_ys), left_right_ys]) + added_points = np.concatenate((top, bottom, left, right), axis=1).T - return Mesh(points=voronoi.vertices, cells=finite_cells) + if debug_plot: + import matplotlib.pyplot as plt + plt.scatter(x, y) + plt.scatter(added_points[:, 0], added_points[:, 1]) + plt.show() + new_points = np.concatenate((input_data, added_points), axis=0) + voronoi = Voronoi(new_points) -if __name__ == "__main__": + # Remove the cells that correspond to the added edge points, + # Because the points on the edge of the square are (weakly) convex, these + # regions be infinite + + # finite_cells = [region for region in voronoi.regions if -1 not in region and len(region) > 0] + + # ... however, we can just use .region_points + input_regions = voronoi.point_region[:input_data.shape[0]] + cells = [voronoi.regions[region_index] for region_index in input_regions] + + return Mesh(points=voronoi.vertices, cells=cells) + + +def square_grid_check(): + values = np.linspace(-10, 10, 21) + x, y = np.meshgrid(values, values) + + mesh = voronoi_mesh(x, y) + + mesh.show(show_labels=True) + +def random_grid_check(): + import matplotlib.pyplot as plt points = np.random.random((100, 2)) - mesh = voronoi_mesh(points[:,0], points[:,1]) - mesh.show() \ No newline at end of file + mesh = voronoi_mesh(points[:, 0], points[:, 1], True) + mesh.show(actually_show=False) + plt.scatter(points[:, 0], points[:, 1]) + plt.show() + + +if __name__ == "__main__": + square_grid_check() + # random_grid_check() + diff --git a/test/slicers/meshes_for_testing.py b/test/slicers/meshes_for_testing.py index ff87dc8..c742624 100644 --- a/test/slicers/meshes_for_testing.py +++ b/test/slicers/meshes_for_testing.py @@ -32,26 +32,38 @@ # Subset of the mappings that meshmerge should include # This can be read off the plots generated below + + expected_shape_mappings = [ - (98, -1), - (99, -1), - (12, 0), + (100, -1), + (152, -1), + (141, -1), + (172, -1), + (170, -1), + (0, -1), (1, -1), - (148, 1), - (149, 1), - (110, 1), - (144, -1), - (123, -1)] - + (8, 0), + (9, 0), + (37, 0), + (83, 0), + (190, 1), + (186, 1), + (189, 1), + (193, 1) +] expected_grid_mappings = [ - (89, 1), - (146, 29), - (66, 34), - (112, 45) + (89, 0), + (90, 1), + (148, 16), + (175, 35), + (60, 47), + (44, 47), + (80, 60) ] + if __name__ == "__main__": import matplotlib.pyplot as plt @@ -62,14 +74,14 @@ combined_mesh.show(actually_show=False, show_labels=True, color='k') grid_mesh.show(actually_show=False, show_labels=True, color='r') - plt.xlim([-4, 4]) - plt.ylim([-4, 4]) + plt.xlim([-5, 5]) + plt.ylim([-5, 5]) plt.figure() combined_mesh.show(actually_show=False, show_labels=True, color='k') shape_mesh.show(actually_show=False, show_labels=True, color='r') - plt.xlim([-4, 4]) - plt.ylim([-4, 4]) + plt.xlim([-5, 5]) + plt.ylim([-5, 5]) plt.show() From bf653adfae021f2f4924b76dc7663bed17de6218 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Thu, 28 Sep 2023 
13:54:43 +0100 Subject: [PATCH 08/35] It works, needs benchmarking --- sasdata/data_util/slicing/meshes/mesh.py | 4 +- sasdata/data_util/slicing/rebinning.py | 28 ++++++++----- sasdata/data_util/slicing/slicer_demo.py | 42 +++++++++++++++++-- .../data_util/slicing/slicers/AnularSector.py | 14 +++++-- 4 files changed, 69 insertions(+), 19 deletions(-) diff --git a/sasdata/data_util/slicing/meshes/mesh.py b/sasdata/data_util/slicing/meshes/mesh.py index 05f4d33..ba31c51 100644 --- a/sasdata/data_util/slicing/meshes/mesh.py +++ b/sasdata/data_util/slicing/meshes/mesh.py @@ -130,7 +130,7 @@ def show_data(self, data: np.ndarray, cmap='winter', mesh_color='white', - show_mesh=True, + show_mesh=False, actually_show=True, density=False): @@ -150,7 +150,7 @@ def show_data(self, color = colormap(color_index) - plt.fill(self.points[cell, 0], self.points[cell, 1], color=color) + plt.fill(self.points[cell, 0], self.points[cell, 1], color=color, edgecolor=None) if show_mesh: self.show(actually_show=False, color=mesh_color) diff --git a/sasdata/data_util/slicing/rebinning.py b/sasdata/data_util/slicing/rebinning.py index 06ace87..86818f7 100644 --- a/sasdata/data_util/slicing/rebinning.py +++ b/sasdata/data_util/slicing/rebinning.py @@ -46,12 +46,13 @@ def allowable_orders(self) -> list[int]: return [-1, 0, 1] @property - def bin_mesh(self): + def bin_mesh(self) -> Mesh: + if self._bin_mesh_cache is None: bin_mesh = self._bin_mesh() - self._data_mesh_cache = bin_mesh - else: - return self._bin_mesh_cache + self._bin_mesh_cache = bin_mesh + + return self._bin_mesh_cache def _post_processing(self, coordinates, values) -> tuple[np.ndarray, np.ndarray]: """ Perform post-processing on the mesh binned values """ @@ -95,7 +96,7 @@ def _calculate(self, input_coordinates: np.ndarray, input_data: np.ndarray) -> n input_coordinates_mesh=input_coordinate_mesh, merged_mesh_data=merge_data) - merged_mesh, merged_to_input, merged_to_output = merge_data + merged_mesh, merged_to_output, merged_to_input = merge_data # Calculate values according to the order parameter @@ -105,8 +106,15 @@ def _calculate(self, input_coordinates: np.ndarray, input_data: np.ndarray) -> n input_areas = input_coordinate_mesh.areas output = np.zeros(self.bin_mesh.n_cells, dtype=float) + print(np.max(merged_to_input)) + print(np.max(merged_to_output)) + for input_index, output_index, area in zip(merged_to_input, merged_to_output, merged_mesh.areas): - output[output_index] += input_data[input_index] * area / input_areas[input_data] + if input_index == -1 or output_index == -1: + # merged region does not correspond to anything of interest + continue + + output[output_index] += input_data[input_index] * area / input_areas[input_index] return output @@ -125,9 +133,9 @@ def _calculate(self, input_coordinates: np.ndarray, input_data: np.ndarray) -> n else: raise ValueError(f"Expected order to be in {self.allowable_orders}, got {self._order}") - def sum(self, input_coordinates: np.ndarray, data: np.ndarray) -> np.ndarray: + def sum(self, x: np.ndarray, y: np.ndarray, data: np.ndarray) -> np.ndarray: """ Return the summed data in the output bins """ - return self._calculate(input_coordinates, data) + return self._calculate(np.array((x.reshape(-1), y.reshape(-1))).T, data) def error_propagate(self, input_coordinates: np.ndarray, data: np.ndarray, errors) -> np.ndarray: raise NotImplementedError("Error propagation not implemented yet") @@ -135,7 +143,7 @@ def error_propagate(self, input_coordinates: np.ndarray, data: np.ndarray, error def 
resolution_propagate(self, input_coordinates: np.ndarray, data: np.ndarray, errors) -> np.ndarray: raise NotImplementedError("Resolution propagation not implemented yet") - def average(self, input_coordinates: np.ndarray, data: np.ndarray) -> np.ndarray: + def average(self, x: np.ndarray, y: np.ndarray, data: np.ndarray) -> np.ndarray: """ Return the averaged data in the output bins """ - return self._calculate(input_coordinates, data) / self.bin_mesh.areas + return self._calculate(np.array((x, y)).T, data) / self.bin_mesh.areas diff --git a/sasdata/data_util/slicing/slicer_demo.py b/sasdata/data_util/slicing/slicer_demo.py index 775c1d9..e76e1c4 100644 --- a/sasdata/data_util/slicing/slicer_demo.py +++ b/sasdata/data_util/slicing/slicer_demo.py @@ -1,19 +1,53 @@ -""" Dev docs: """ +""" Dev docs: Demo to show the behaviour of the re-binning methods """ import numpy as np -from sasdata.data_util.slicing.slicers import AnularSector +import matplotlib.pyplot as plt + +from sasdata.data_util.slicing.slicers.AnularSector import AnularSector from sasdata.data_util.slicing.meshes.mesh import Mesh from sasdata.data_util.slicing.meshes.voronoi_mesh import voronoi_mesh if __name__ == "__main__": + q_range = 1.5 + + + x = (2*q_range)*(np.random.random(400)-0.5) + y = (2*q_range)*(np.random.random(400)-0.5) + + display_mesh = voronoi_mesh(x, y) # Demo of sums, annular sector over some not very circular data - q_range = 1.5 - test_coordinates = (2*q_range)*(np.random.random((100, 2))-0.5) + def lobe_test_function(x, y): + return 1 + np.sin(x*np.pi/q_range)*np.sin(y*np.pi/q_range) + + + random_lobe_data = lobe_test_function(x, y) + + plt.figure("Input Dataset 1") + display_mesh.show_data(random_lobe_data, actually_show=False) + + data_order_0 = [] + + for index, size in enumerate(np.linspace(0.1, 1, 100)): + q0 = 0.75 - 0.6*size + q1 = 0.75 + 0.6*size + phi0 = np.pi/2 - size + phi1 = np.pi/2 + size + + rebinner = AnularSector(q0, q1, phi0, phi1, order=0) + + data_order_0.append(rebinner.sum(x, y, random_lobe_data)) + + if index % 10 == 0: + plt.figure("Regions") + rebinner.bin_mesh.show(actually_show=False) + + plt.show() + # Demo of averaging, annular sector over ring shaped data \ No newline at end of file diff --git a/sasdata/data_util/slicing/slicers/AnularSector.py b/sasdata/data_util/slicing/slicers/AnularSector.py index bf3021d..e9f1377 100644 --- a/sasdata/data_util/slicing/slicers/AnularSector.py +++ b/sasdata/data_util/slicing/slicers/AnularSector.py @@ -17,19 +17,27 @@ def __init__(self, q0: float, q1: float, phi0: float, phi1: float, order: int=1, def _bin_mesh(self) -> Mesh: - n_points = 1 + 180*self.points_per_degree*(self.phi1 - self.phi0) / np.pi + n_points = int(1 + 180*self.points_per_degree*(self.phi1 - self.phi0) / np.pi) angles = np.linspace(self.phi0, self.phi1, n_points) row1 = self.q0 * np.array([np.cos(angles), np.sin(angles)]) row2 = self.q1 * np.array([np.cos(angles), np.sin(angles)])[:, ::-1] - points = np.concatenate((row1, row2), axis=1) + points = np.concatenate((row1, row2), axis=1).T - cells = [i for i in range(2*n_points)] + cells = [[i for i in range(2*n_points)]] return Mesh(points=points, cells=cells) def _bin_coordinates(self) -> np.ndarray: return np.array([], dtype=float) + +def main(): + """ Just show a random example""" + AnularSector(1, 2, 1, 2).bin_mesh.show() + + +if __name__ == "__main__": + main() \ No newline at end of file From 8cc300aad29c27e7079a0fd884ce19f161a4092d Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Fri, 29 Sep 2023 14:47:01 +0100 
Subject: [PATCH 09/35] Much faster assignment/merge method --- sasdata/data_util/slicing/meshes/mesh.py | 93 ++++++++++++- sasdata/data_util/slicing/meshes/meshmerge.py | 124 +++++++++++------- sasdata/data_util/slicing/rebinning.py | 7 +- sasdata/data_util/slicing/slicer_demo.py | 8 +- test/slicers/meshes_for_testing.py | 30 ++++- test/slicers/utest_point_assignment.py | 5 + 6 files changed, 206 insertions(+), 61 deletions(-) create mode 100644 test/slicers/utest_point_assignment.py diff --git a/sasdata/data_util/slicing/meshes/mesh.py b/sasdata/data_util/slicing/meshes/mesh.py index ba31c51..6b4df93 100644 --- a/sasdata/data_util/slicing/meshes/mesh.py +++ b/sasdata/data_util/slicing/meshes/mesh.py @@ -51,19 +51,24 @@ def __init__(self, edge_lookup = {edge: i for i, edge in enumerate(self.edges)} self.cells_to_edges = [] + self.cells_to_edges_signs = [] for cell in cells: this_cell_data = [] + this_sign_data = [] for a, b in closed_loop_edges(cell): # make sure the representation is unique if a > b: this_cell_data.append(edge_lookup[(a, b)]) + this_sign_data.append(1) else: this_cell_data.append(edge_lookup[(b, a)]) + this_sign_data.append(-1) self.cells_to_edges.append(this_cell_data) + self.cells_to_edges_signs.append(this_sign_data) # Counts for elements self.n_points = self.points.shape[0] @@ -73,11 +78,6 @@ def __init__(self, # Areas self._areas = None - def find_locations(self, points): - """ Find indices of cells containing the input points """ - - - @property def areas(self): @@ -126,6 +126,71 @@ def show(self, actually_show=True, show_labels=False, **kwargs): if actually_show: plt.show() + def locate_points(self, x: np.ndarray, y: np.ndarray): + """ Find the cells that contain the specified points""" + + x = x.reshape(-1) + y = y.reshape(-1) + + xy = np.concatenate(([x], [y]), axis=1) + + # The most simple implementation is not particularly fast, especially in python + # + # Less obvious, but hopefully faster strategy + # + # Ultimately, checking the inclusion of a point within a polygon + # requires checking the crossings of a half line with the polygon's + # edges. + # + # A fairly efficient thing to do is to check every edge for crossing + # the axis parallel lines x=point_x. 
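+        # (that is, the vertical half-line extending upwards from each query point; counting signed crossings above a point amounts to a winding-number test)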
+ # Then these edges that cross can map back to the polygons they're in + # and a final check for inclusion can be done with the edge sign property + # and some explicit checking of the + # + # Basic idea is: + # 1) build a matrix for each point-edge pair + # True if the edge crosses the half-line above a point + # 2) for each cell get the winding number by evaluating the + # sum of the component edges, weighted 1/-1 according to direction + + + edges = np.array(self.edges) + + edge_xy_1 = self.points[edges[:, 0], :] + edge_xy_2 = self.points[edges[:, 1], :] + + edge_x_1 = edge_xy_1[:, 0] + edge_x_2 = edge_xy_2[:, 0] + + + + # Make an n_edges-by-n_inputs boolean matrix that indicates which of the + # edges cross x=points_x line + crossers = np.logical_xor( + edge_x_1.reshape(-1, 1) < x.reshape(1, -1), + edge_x_2.reshape(-1, 1) < x.reshape(1, -1)) + + # Calculate the gradients, some might be infs, but none that matter will be + # TODO: Disable warnings + gradients = (edge_xy_2[:, 1] - edge_xy_1[:, 1]) / (edge_xy_2[:, 0] - edge_xy_1[:, 0]) + + # Distance to crossing points edge 0 + delta_x = x.reshape(1, -1) - edge_x_1.reshape(-1, 1) + + # Signed distance from point to y (doesn't really matter which sign) + delta_y = gradients.reshape(-1, 1) * delta_x + edge_xy_1[:, 1:] - y.reshape(1, -1) + + score_matrix = np.logical_and(delta_y > 0, crossers) + + output = -np.ones(len(x), dtype=int) + for cell_index, (cell_edges, sign) in enumerate(zip(self.cells_to_edges, self.cells_to_edges_signs)): + cell_score = np.sum(score_matrix[cell_edges, :] * np.array(sign).reshape(-1, 1), axis=0) + points_in_cell = np.abs(cell_score) == 1 + output[points_in_cell] = cell_index + + return output + def show_data(self, data: np.ndarray, cmap='winter', @@ -156,4 +221,20 @@ def show_data(self, self.show(actually_show=False, color=mesh_color) if actually_show: - self.show() \ No newline at end of file + self.show() + + +if __name__ == "__main__": + from test.slicers.meshes_for_testing import location_test_mesh, location_test_points_x, location_test_points_y + + cell_indices = location_test_mesh.locate_points(location_test_points_x, location_test_points_y) + + print(cell_indices) + + for i in range(location_test_mesh.n_cells): + inds = cell_indices == i + plt.scatter( + location_test_points_x.reshape(-1)[inds], + location_test_points_y.reshape(-1)[inds]) + + location_test_mesh.show() \ No newline at end of file diff --git a/sasdata/data_util/slicing/meshes/meshmerge.py b/sasdata/data_util/slicing/meshes/meshmerge.py index 3ce52ba..2524c51 100644 --- a/sasdata/data_util/slicing/meshes/meshmerge.py +++ b/sasdata/data_util/slicing/meshes/meshmerge.py @@ -5,6 +5,8 @@ from sasdata.data_util.slicing.meshes.util import closed_loop_edges +import time + def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray]: """ Take two lists of polygons and find their intersections @@ -21,6 +23,8 @@ def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray] """ + t0 = time.time() + # Find intersections of all edges in mesh one with edges in mesh two new_x = [] @@ -89,6 +93,8 @@ def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray] new_y.append(y) + t1 = time.time() + print("Edge intersections:", t1 - t0) # Build list of all input points, in a way that we can check for coincident points @@ -108,6 +114,11 @@ def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray] output_mesh = delaunay_mesh(points[:, 0], points[:, 1]) + + t2 = time.time() + 
print("Delaunay:", t2 - t1) + + # Find centroids of all output triangles, and find which source cells they belong to ## step 1) Assign -1 to all cells of original meshes @@ -120,57 +131,72 @@ def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray] centroid = np.sum(output_mesh.points[cell, :]/3, axis=0) centroids.append(centroid) - ## step 3) Perform checks based on winding number method (see wikipedia Point in Polygon). - for mesh, assignments in [ - (mesh_a, assignments_a), - (mesh_b, assignments_b)]: - - for centroid_index, centroid in enumerate(centroids): - for cell_index, cell in enumerate(mesh.cells): - - # Bounding box check - points = mesh.points[cell, :] - if np.any(centroid < np.min(points, axis=0)): # x or y less than any in polygon - continue + centroids = np.array(centroids) - if np.any(centroid > np.max(points, axis=0)): # x or y greater than any in polygon - continue + t3 = time.time() + print("Centroids:", t3 - t2) - # Winding number check - count directional crossings of vertical half line from centroid - winding_number = 0 - for i1, i2 in closed_loop_edges(cell): - p1 = mesh.points[i1, :] - p2 = mesh.points[i2, :] - # if the section xs do not straddle the x=centroid_x coordinate, then the - # edge cannot cross the half line. - # If it does, then remember which way it was - # * Careful about ends - # * Also, note that the p1[0] == p2[0] -> (no contribution) case is covered by the strict inequality - if p1[0] > centroid[0] >= p2[0]: - left_right = -1 - elif p2[0] > centroid[0] >= p1[0]: - left_right = 1 - else: - continue - - # Find the y point that it crosses x=centroid at - # note: denominator cannot be zero because of strict inequality above - gradient = (p2[1] - p1[1]) / (p2[0] - p1[0]) - x_delta = centroid[0] - p1[0] - y = p1[1] + x_delta * gradient - - if y > centroid[1]: - winding_number += left_right - - - if abs(winding_number) > 0: - # Do assignment of input cell to output triangle index - assignments[centroid_index] = cell_index - - # end cell loop - - # end centroid loop + ## step 3) Perform checks based on winding number method (see wikipedia Point in Polygon). + # + # # TODO: Brute force search is sllllloooooooowwwwww - keeping track of which points are where would be better + # for mesh, assignments in [ + # (mesh_a, assignments_a), + # (mesh_b, assignments_b)]: + # + # for centroid_index, centroid in enumerate(centroids): + # for cell_index, cell in enumerate(mesh.cells): + # + # # Bounding box check + # points = mesh.points[cell, :] + # if np.any(centroid < np.min(points, axis=0)): # x or y less than any in polygon + # continue + # + # if np.any(centroid > np.max(points, axis=0)): # x or y greater than any in polygon + # continue + # + # # Winding number check - count directional crossings of vertical half line from centroid + # winding_number = 0 + # for i1, i2 in closed_loop_edges(cell): + # p1 = mesh.points[i1, :] + # p2 = mesh.points[i2, :] + # + # # if the section xs do not straddle the x=centroid_x coordinate, then the + # # edge cannot cross the half line. 
+ # # If it does, then remember which way it was + # # * Careful about ends + # # * Also, note that the p1[0] == p2[0] -> (no contribution) case is covered by the strict inequality + # if p1[0] > centroid[0] >= p2[0]: + # left_right = -1 + # elif p2[0] > centroid[0] >= p1[0]: + # left_right = 1 + # else: + # continue + # + # # Find the y point that it crosses x=centroid at + # # note: denominator cannot be zero because of strict inequality above + # gradient = (p2[1] - p1[1]) / (p2[0] - p1[0]) + # x_delta = centroid[0] - p1[0] + # y = p1[1] + x_delta * gradient + # + # if y > centroid[1]: + # winding_number += left_right + # + # + # if abs(winding_number) > 0: + # # Do assignment of input cell to output triangle index + # assignments[centroid_index] = cell_index + # break # point is assigned + # + # # end cell loop + # + # # end centroid loop + + assignments_a = mesh_a.locate_points(centroids[:, 0], centroids[:, 1]) + assignments_b = mesh_b.locate_points(centroids[:, 0], centroids[:, 1]) + + t4 = time.time() + print("Assignments:", t4 - t3) return output_mesh, assignments_a, assignments_b @@ -185,7 +211,7 @@ def main(): m2 = voronoi_mesh(np.random.random(n2), np.random.random(n2)) - mesh, _, _ = meshmerge(m1, m2) + mesh, assignement1, assignement2 = meshmerge(m1, m2) mesh.show() diff --git a/sasdata/data_util/slicing/rebinning.py b/sasdata/data_util/slicing/rebinning.py index 86818f7..7b6eea9 100644 --- a/sasdata/data_util/slicing/rebinning.py +++ b/sasdata/data_util/slicing/rebinning.py @@ -8,6 +8,7 @@ from sasdata.data_util.slicing.meshes.voronoi_mesh import voronoi_mesh from sasdata.data_util.slicing.meshes.meshmerge import meshmerge +import time @dataclass class CacheData: @@ -99,16 +100,13 @@ def _calculate(self, input_coordinates: np.ndarray, input_data: np.ndarray) -> n merged_mesh, merged_to_output, merged_to_input = merge_data # Calculate values according to the order parameter - + t0 = time.time() if self._order == 0: # Based on the overlap of cells only input_areas = input_coordinate_mesh.areas output = np.zeros(self.bin_mesh.n_cells, dtype=float) - print(np.max(merged_to_input)) - print(np.max(merged_to_output)) - for input_index, output_index, area in zip(merged_to_input, merged_to_output, merged_mesh.areas): if input_index == -1 or output_index == -1: # merged region does not correspond to anything of interest @@ -116,6 +114,7 @@ def _calculate(self, input_coordinates: np.ndarray, input_data: np.ndarray) -> n output[output_index] += input_data[input_index] * area / input_areas[input_index] + print("Main calc:", time.time() - t0) return output diff --git a/sasdata/data_util/slicing/slicer_demo.py b/sasdata/data_util/slicing/slicer_demo.py index e76e1c4..d60c0ac 100644 --- a/sasdata/data_util/slicing/slicer_demo.py +++ b/sasdata/data_util/slicing/slicer_demo.py @@ -33,7 +33,9 @@ def lobe_test_function(x, y): data_order_0 = [] - for index, size in enumerate(np.linspace(0.1, 1, 100)): + sizes = np.linspace(0.1, 1, 100) + + for index, size in enumerate(sizes): q0 = 0.75 - 0.6*size q1 = 0.75 + 0.6*size phi0 = np.pi/2 - size @@ -47,6 +49,10 @@ def lobe_test_function(x, y): plt.figure("Regions") rebinner.bin_mesh.show(actually_show=False) + plt.figure("Data") + + plt.plot(sizes, data_order_0) + plt.show() diff --git a/test/slicers/meshes_for_testing.py b/test/slicers/meshes_for_testing.py index c742624..7cb17b4 100644 --- a/test/slicers/meshes_for_testing.py +++ b/test/slicers/meshes_for_testing.py @@ -62,7 +62,31 @@ (80, 60) ] - +# +# Mesh location tests +# + 
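+# (These cells form four unit squares on a 3x3 grid of vertices; the 4x4 grid of query points defined below places four points inside each square.)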
+location_test_mesh_points = np.array([ + [0, 0], # 0 + [0, 1], # 1 + [0, 2], # 2 + [1, 0], # 3 + [1, 1], # 4 + [1, 2], # 5 + [2, 0], # 6 + [2, 1], # 7 + [2, 2]], dtype=float) + +location_test_mesh_cells = [ + [0, 1, 4, 3], + [1, 2, 5, 4], + [3, 4, 7, 6], + [4, 5, 8, 7]] + +location_test_mesh = Mesh(location_test_mesh_points, location_test_mesh_cells) + +test_coords = 0.25 + 0.5*np.arange(4) +location_test_points_x, location_test_points_y = np.meshgrid(test_coords, test_coords) if __name__ == "__main__": @@ -84,4 +108,8 @@ plt.xlim([-5, 5]) plt.ylim([-5, 5]) + plt.figure() + location_test_mesh.show(actually_show=False, show_labels=True) + plt.scatter(location_test_points_x, location_test_points_y) + plt.show() diff --git a/test/slicers/utest_point_assignment.py b/test/slicers/utest_point_assignment.py new file mode 100644 index 0000000..4ff53e7 --- /dev/null +++ b/test/slicers/utest_point_assignment.py @@ -0,0 +1,5 @@ + +from test.slicers.meshes_for_testing import location_test_mesh, location_test_points_x, location_test_points_y + +def test_location_assignment(): + pass \ No newline at end of file From 1f83877b37827174069fe774bd066c779f93eaa0 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Thu, 5 Oct 2023 12:49:13 +0100 Subject: [PATCH 10/35] Significantly faster edge crossing algorithm --- sasdata/data_util/slicing/meshes/meshmerge.py | 90 ++++++++----------- 1 file changed, 37 insertions(+), 53 deletions(-) diff --git a/sasdata/data_util/slicing/meshes/meshmerge.py b/sasdata/data_util/slicing/meshes/meshmerge.py index 2524c51..c0235dc 100644 --- a/sasdata/data_util/slicing/meshes/meshmerge.py +++ b/sasdata/data_util/slicing/meshes/meshmerge.py @@ -26,72 +26,56 @@ def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray] t0 = time.time() # Find intersections of all edges in mesh one with edges in mesh two + # TODO: Speed this up - new_x = [] - new_y = [] - for edge_a in mesh_a.edges: - for edge_b in mesh_b.edges: + # Fastest way might just be to calculate the intersections of all lines on edges, + # see whether we need filtering afterwards - p1 = mesh_a.points[edge_a[0]] - p2 = mesh_a.points[edge_a[1]] - p3 = mesh_b.points[edge_b[0]] - p4 = mesh_b.points[edge_b[1]] + edges_a = np.array(mesh_a.edges, dtype=int) + edges_b = np.array(mesh_b.edges, dtype=int) - # Bounding box check + edge_a_1 = mesh_a.points[edges_a[:, 0], :] + edge_a_2 = mesh_a.points[edges_a[:, 1], :] + edge_b_1 = mesh_b.points[edges_b[:, 0], :] + edge_b_2 = mesh_b.points[edges_b[:, 1], :] - # First edge entirely to left of other - if max((p1[0], p2[0])) < min((p3[0], p4[0])): - continue + a_grid, b_grid = np.mgrid[0:mesh_a.n_edges, 0:mesh_b.n_edges] + a_grid = a_grid.reshape(-1) + b_grid = b_grid.reshape(-1) - # First edge entirely below other - if max((p1[1], p2[1])) < min((p3[1], p4[1])): - continue - - # First edge entirely to right of other - if min((p1[0], p2[0])) > max((p3[0], p4[0])): - continue - - # First edge entirely above other - if min((p1[1], p2[1])) > max((p3[1], p4[1])): - continue - - # - # Parametric description of intersection in terms of position along lines - # - # Simultaneous eqns (to reflect current wiki notation) - # s(x2 - x1) - t(x4 - x3) = x3 - x1 - # s(y2 - y1) - t(y4 - y3) = y3 - y1 - # - # in matrix form: - # m.(s,t) = v - # + p1 = edge_a_1[a_grid, :] + p2 = edge_a_2[a_grid, :] + p3 = edge_b_1[b_grid, :] + p4 = edge_b_2[b_grid, :] + # + # Solve the equations + # + # z_a1 + s delta_z_a = z_b1 + t delta_z_b + # + # for z = (x, y) + # - m = np.array([ - [p2[0] - p1[0], 
p3[0] - p4[0]], - [p2[1] - p1[1], p3[1] - p4[1]]]) + start_point_diff = p1 - p3 - v = np.array([p3[0] - p1[0], p3[1] - p1[1]]) + delta1 = p2 - p1 + delta3 = p4 - p3 - if np.linalg.det(m) == 0: - # Lines don't intersect, or are colinear in a way that doesn't matter - continue + deltas = np.concatenate(([-delta1], [delta3]), axis=0) + deltas = np.moveaxis(deltas, 0, 2) - st = np.linalg.solve(m, v) + st = np.linalg.solve(deltas, start_point_diff) - # As the purpose of this is finding new points for the merged mesh, we don't - # want new points if they are right at the end of the lines, hence non-strict - # inequalities here - if np.any(st <= 0) or np.any(st >= 1): - # Exclude intection points, that are not on the *segments* - continue + # Find the points where s and t are in (0, 1) - x = p1[0] + (p2[0] - p1[0])*st[0] - y = p1[1] + (p2[1] - p1[1])*st[0] + intersection_inds = np.logical_and( + np.logical_and(0 < st[:, 0], st[:, 0] < 1), + np.logical_and(0 < st[:, 1], st[:, 1] < 1)) - new_x.append(x) - new_y.append(y) + start_points_for_intersections = p1[intersection_inds, :] + deltas_for_intersections = delta1[intersection_inds, :] + points_to_add = start_points_for_intersections + st[intersection_inds, 0].reshape(-1,1) * deltas_for_intersections t1 = time.time() print("Edge intersections:", t1 - t0) @@ -102,7 +86,7 @@ def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray] points = np.concatenate(( mesh_a.points, mesh_b.points, - np.array((new_x, new_y)).T + points_to_add )) From 555f76cf63d69a95fdc6ebb0205589f0eb5c3240 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Thu, 5 Oct 2023 13:57:08 +0100 Subject: [PATCH 11/35] Demo --- sasdata/data_util/slicing/meshes/mesh.py | 2 + sasdata/data_util/slicing/meshes/meshmerge.py | 68 ++--------- .../data_util/slicing/meshes/voronoi_mesh.py | 5 +- sasdata/data_util/slicing/rebinning.py | 41 +++---- sasdata/data_util/slicing/slicer_demo.py | 112 ++++++++++++++---- .../data_util/slicing/slicers/AnularSector.py | 6 +- 6 files changed, 126 insertions(+), 108 deletions(-) diff --git a/sasdata/data_util/slicing/meshes/mesh.py b/sasdata/data_util/slicing/meshes/mesh.py index 6b4df93..3ac23da 100644 --- a/sasdata/data_util/slicing/meshes/mesh.py +++ b/sasdata/data_util/slicing/meshes/mesh.py @@ -203,6 +203,8 @@ def show_data(self, colormap = cm.get_cmap(cmap, 256) + data = data.reshape(-1) + if density: data = data / self.areas diff --git a/sasdata/data_util/slicing/meshes/meshmerge.py b/sasdata/data_util/slicing/meshes/meshmerge.py index c0235dc..161c1e5 100644 --- a/sasdata/data_util/slicing/meshes/meshmerge.py +++ b/sasdata/data_util/slicing/meshes/meshmerge.py @@ -26,7 +26,6 @@ def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray] t0 = time.time() # Find intersections of all edges in mesh one with edges in mesh two - # TODO: Speed this up # Fastest way might just be to calculate the intersections of all lines on edges, # see whether we need filtering afterwards @@ -48,6 +47,10 @@ def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray] p3 = edge_b_1[b_grid, :] p4 = edge_b_2[b_grid, :] + # + # TODO: Investigate whether adding a bounding box check will help with speed, seems likely as most edges wont cross + # + # # Solve the equations # @@ -64,7 +67,9 @@ def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray] deltas = np.concatenate(([-delta1], [delta3]), axis=0) deltas = np.moveaxis(deltas, 0, 2) - st = np.linalg.solve(deltas, start_point_diff) + 
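+    # (edge pairs with parallel directions make this 2x2 system singular, so mask them out and solve only the non-singular systems)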
non_singular = np.linalg.det(deltas) != 0 + + st = np.linalg.solve(deltas[non_singular], start_point_diff[non_singular]) # Find the points where s and t are in (0, 1) @@ -72,8 +77,8 @@ def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray] np.logical_and(0 < st[:, 0], st[:, 0] < 1), np.logical_and(0 < st[:, 1], st[:, 1] < 1)) - start_points_for_intersections = p1[intersection_inds, :] - deltas_for_intersections = delta1[intersection_inds, :] + start_points_for_intersections = p1[non_singular][intersection_inds, :] + deltas_for_intersections = delta1[non_singular][intersection_inds, :] points_to_add = start_points_for_intersections + st[intersection_inds, 0].reshape(-1,1) * deltas_for_intersections @@ -121,60 +126,7 @@ def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray] print("Centroids:", t3 - t2) - ## step 3) Perform checks based on winding number method (see wikipedia Point in Polygon). - # - # # TODO: Brute force search is sllllloooooooowwwwww - keeping track of which points are where would be better - # for mesh, assignments in [ - # (mesh_a, assignments_a), - # (mesh_b, assignments_b)]: - # - # for centroid_index, centroid in enumerate(centroids): - # for cell_index, cell in enumerate(mesh.cells): - # - # # Bounding box check - # points = mesh.points[cell, :] - # if np.any(centroid < np.min(points, axis=0)): # x or y less than any in polygon - # continue - # - # if np.any(centroid > np.max(points, axis=0)): # x or y greater than any in polygon - # continue - # - # # Winding number check - count directional crossings of vertical half line from centroid - # winding_number = 0 - # for i1, i2 in closed_loop_edges(cell): - # p1 = mesh.points[i1, :] - # p2 = mesh.points[i2, :] - # - # # if the section xs do not straddle the x=centroid_x coordinate, then the - # # edge cannot cross the half line. 
- # # If it does, then remember which way it was - # # * Careful about ends - # # * Also, note that the p1[0] == p2[0] -> (no contribution) case is covered by the strict inequality - # if p1[0] > centroid[0] >= p2[0]: - # left_right = -1 - # elif p2[0] > centroid[0] >= p1[0]: - # left_right = 1 - # else: - # continue - # - # # Find the y point that it crosses x=centroid at - # # note: denominator cannot be zero because of strict inequality above - # gradient = (p2[1] - p1[1]) / (p2[0] - p1[0]) - # x_delta = centroid[0] - p1[0] - # y = p1[1] + x_delta * gradient - # - # if y > centroid[1]: - # winding_number += left_right - # - # - # if abs(winding_number) > 0: - # # Do assignment of input cell to output triangle index - # assignments[centroid_index] = cell_index - # break # point is assigned - # - # # end cell loop - # - # # end centroid loop + ## step 3) Find where points belong based on Mesh classes point location algorithm assignments_a = mesh_a.locate_points(centroids[:, 0], centroids[:, 1]) assignments_b = mesh_b.locate_points(centroids[:, 0], centroids[:, 1]) diff --git a/sasdata/data_util/slicing/meshes/voronoi_mesh.py b/sasdata/data_util/slicing/meshes/voronoi_mesh.py index 3497fbb..d3eb81d 100644 --- a/sasdata/data_util/slicing/meshes/voronoi_mesh.py +++ b/sasdata/data_util/slicing/meshes/voronoi_mesh.py @@ -24,6 +24,7 @@ def voronoi_mesh(x, y, debug_plot=False) -> Mesh: # 2) the bounding box of the grid # + # Use the median area of finite voronoi cells as an estimate voronoi = Voronoi(input_data) finite_cells = [region for region in voronoi.regions if -1 not in region and len(region) > 0] @@ -37,8 +38,8 @@ def voronoi_mesh(x, y, debug_plot=False) -> Mesh: x_max, y_max = np.max(input_data, axis=0) # Create a border - n_x = np.round((x_max - x_min)/gap).astype(int) - n_y = np.round((y_max - y_min)/gap).astype(int) + n_x = int(np.round((x_max - x_min)/gap)) + n_y = int(np.round((y_max - y_min)/gap)) top_bottom_xs = np.linspace(x_min - gap, x_max + gap, n_x + 3) left_right_ys = np.linspace(y_min, y_max, n_y + 1) diff --git a/sasdata/data_util/slicing/rebinning.py b/sasdata/data_util/slicing/rebinning.py index 7b6eea9..510535a 100644 --- a/sasdata/data_util/slicing/rebinning.py +++ b/sasdata/data_util/slicing/rebinning.py @@ -18,21 +18,17 @@ class CacheData: merged_mesh_data: tuple[Mesh, np.ndarray, np.ndarray] # mesh information about the merging -class Rebinner(): +class Rebinner(ABC): - def __init__(self, order): + def __init__(self): """ Base class for rebinning methods""" - self._order = order self._bin_mesh_cache: Optional[Mesh] = None # cached version of the output bin mesh # Output dependent caching self._input_cache: Optional[CacheData] = None - if order not in self.allowable_orders: - raise ValueError(f"Expected order to be in {self.allowable_orders}, got {order}") - @abstractmethod def _bin_coordinates(self) -> np.ndarray: @@ -60,17 +56,22 @@ def _post_processing(self, coordinates, values) -> tuple[np.ndarray, np.ndarray] # Default is to do nothing, override if needed return coordinates, values - def _do_binning(self, data): - """ Main binning algorithm """ - - def _calculate(self, input_coordinates: np.ndarray, input_data: np.ndarray) -> np.ndarray: + def _calculate(self, input_coordinates: np.ndarray, input_data: np.ndarray, order: int) -> np.ndarray: """ Main calculation """ - if self._order == -1: + if order == -1: # Construct the input output mapping just based on input points being the output cells, # Equivalent to the original binning method - pass + mesh = 
self.bin_mesh + bin_identities = mesh.locate_points(input_coordinates[:,0], input_coordinates[:, 1]) + output_data = np.zeros(mesh.n_cells, dtype=float) + + for index, bin in enumerate(bin_identities): + if bin >= 0: + output_data[bin] += input_data[index] + + return output_data else: # Use a mapping based on meshes @@ -87,7 +88,7 @@ def _calculate(self, input_coordinates: np.ndarray, input_data: np.ndarray) -> n else: # Calculate mesh data input_coordinate_mesh = voronoi_mesh(input_coordinates[:,0], input_coordinates[:, 1]) - self._data_mesh_cahce = input_coordinate_mesh + self._data_mesh_cache = input_coordinate_mesh merge_data = meshmerge(self.bin_mesh, input_coordinate_mesh) @@ -101,7 +102,7 @@ def _calculate(self, input_coordinates: np.ndarray, input_data: np.ndarray) -> n # Calculate values according to the order parameter t0 = time.time() - if self._order == 0: + if order == 0: # Based on the overlap of cells only input_areas = input_coordinate_mesh.areas @@ -118,7 +119,7 @@ def _calculate(self, input_coordinates: np.ndarray, input_data: np.ndarray) -> n return output - elif self._order == 1: + elif order == 1: # Linear interpolation requires the following relationship with the data, # as the input data is the total over the whole input cell, the linear # interpolation requires continuity at the vertices, and a constraint on the @@ -130,11 +131,11 @@ def _calculate(self, input_coordinates: np.ndarray, input_data: np.ndarray) -> n raise NotImplementedError("1st order (linear) interpolation currently not implemented") else: - raise ValueError(f"Expected order to be in {self.allowable_orders}, got {self._order}") + raise ValueError(f"Expected order to be in {self.allowable_orders}, got {order}") - def sum(self, x: np.ndarray, y: np.ndarray, data: np.ndarray) -> np.ndarray: + def sum(self, x: np.ndarray, y: np.ndarray, data: np.ndarray, order: int = 0) -> np.ndarray: """ Return the summed data in the output bins """ - return self._calculate(np.array((x.reshape(-1), y.reshape(-1))).T, data) + return self._calculate(np.array((x.reshape(-1), y.reshape(-1))).T, data.reshape(-1), order) def error_propagate(self, input_coordinates: np.ndarray, data: np.ndarray, errors) -> np.ndarray: raise NotImplementedError("Error propagation not implemented yet") @@ -142,7 +143,7 @@ def error_propagate(self, input_coordinates: np.ndarray, data: np.ndarray, error def resolution_propagate(self, input_coordinates: np.ndarray, data: np.ndarray, errors) -> np.ndarray: raise NotImplementedError("Resolution propagation not implemented yet") - def average(self, x: np.ndarray, y: np.ndarray, data: np.ndarray) -> np.ndarray: + def average(self, x: np.ndarray, y: np.ndarray, data: np.ndarray, order: int = 0) -> np.ndarray: """ Return the averaged data in the output bins """ - return self._calculate(np.array((x, y)).T, data) / self.bin_mesh.areas + return self._calculate(np.array((x, y)).T, data.reshape(-1), order) / self.bin_mesh.areas diff --git a/sasdata/data_util/slicing/slicer_demo.py b/sasdata/data_util/slicing/slicer_demo.py index d60c0ac..6096ca9 100644 --- a/sasdata/data_util/slicing/slicer_demo.py +++ b/sasdata/data_util/slicing/slicer_demo.py @@ -12,48 +12,110 @@ if __name__ == "__main__": q_range = 1.5 + demo1 = True + demo2 = True + # Demo of sums, annular sector over some not very circular data - x = (2*q_range)*(np.random.random(400)-0.5) - y = (2*q_range)*(np.random.random(400)-0.5) + if demo1: - display_mesh = voronoi_mesh(x, y) + x = (2 * q_range) * (np.random.random(400) - 0.5) + y = (2 * q_range) * 
(np.random.random(400) - 0.5) + + display_mesh = voronoi_mesh(x, y) + + def lobe_test_function(x, y): + return 1 + np.sin(x*np.pi/q_range)*np.sin(y*np.pi/q_range) + + + random_lobe_data = lobe_test_function(x, y) + + plt.figure("Input Dataset 1") + display_mesh.show_data(random_lobe_data, actually_show=False) + + data_order_0 = [] + data_order_neg1 = [] + + sizes = np.linspace(0.1, 1, 100) + + for index, size in enumerate(sizes): + q0 = 0.75 - 0.6*size + q1 = 0.75 + 0.6*size + phi0 = np.pi/2 - size + phi1 = np.pi/2 + size + + rebinner = AnularSector(q0, q1, phi0, phi1) + + data_order_neg1.append(rebinner.sum(x, y, random_lobe_data, order=-1)) + data_order_0.append(rebinner.sum(x, y, random_lobe_data, order=0)) + + if index % 10 == 0: + plt.figure("Regions 1") + rebinner.bin_mesh.show(actually_show=False) + + plt.title("Regions") + + plt.figure("Sum of region, dataset 1") + + plt.plot(sizes, data_order_neg1) + plt.plot(sizes, data_order_0) + + plt.legend(["Order -1", "Order 0"]) + plt.title("Sum over region") + + + # Demo of averaging, annular sector over ring shaped data + + if demo2: + + x, y = np.meshgrid(np.linspace(-q_range, q_range, 41), np.linspace(-q_range, q_range, 41)) + x = x.reshape(-1) + y = y.reshape(-1) + + display_mesh = voronoi_mesh(x, y) + + + def ring_test_function(x, y): + r = np.sqrt(x**2 + y**2) + return np.log(np.sinc(r*1.5)**2) + + + grid_ring_data = ring_test_function(x, y) + plt.figure("Input Dataset 2") + display_mesh.show_data(grid_ring_data, actually_show=False) + + data_order_0 = [] + data_order_neg1 = [] + + sizes = np.linspace(0.1, 1, 100) + + for index, size in enumerate(sizes): + q0 = 0.25 + q1 = 1.25 + + phi0 = np.pi/2 - size + phi1 = np.pi/2 + size + + rebinner = AnularSector(q0, q1, phi0, phi1) + + data_order_neg1.append(rebinner.average(x, y, grid_ring_data, order=-1)) + data_order_0.append(rebinner.average(x, y, grid_ring_data, order=0)) + + if index % 10 == 0: + plt.figure("Regions 2") + rebinner.bin_mesh.show(actually_show=False) + + plt.title("Regions") + + plt.figure("Average of region 2") + + plt.plot(sizes, data_order_neg1) + plt.plot(sizes, data_order_0) + + plt.legend(["Order -1", "Order 0"]) + plt.title("Average over region") + + plt.show() - # Demo of averaging, annular sector over ring shaped data \ No newline at end of file diff --git a/sasdata/data_util/slicing/slicers/AnularSector.py b/sasdata/data_util/slicing/slicers/AnularSector.py index e9f1377..6d034da 100644 --- a/sasdata/data_util/slicing/slicers/AnularSector.py +++ b/sasdata/data_util/slicing/slicers/AnularSector.py @@ -5,8 +5,8 @@ class AnularSector(Rebinner): """ A single annular sector (wedge sum)""" - def __init__(self, q0: float, q1: float, phi0: float, phi1: float, order: int=1, points_per_degree: int=2): - super().__init__(order) + def __init__(self, q0:
float, q1: float, phi0: float, phi1: float, points_per_degree: int=2): + super().__init__() self.q0 = q0 self.q1 = q1 @@ -17,7 +17,7 @@ def __init__(self, q0: float, q1: float, phi0: float, phi1: float, order: int=1, def _bin_mesh(self) -> Mesh: - n_points = int(1 + 180*self.points_per_degree*(self.phi1 - self.phi0) / np.pi) + n_points = np.max([int(1 + 180*self.points_per_degree*(self.phi1 - self.phi0) / np.pi), 2]) angles = np.linspace(self.phi0, self.phi1, n_points) From 55a1138a38bc43ece0dc1e9a20b4f269534d3d74 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Thu, 5 Oct 2023 14:38:34 +0100 Subject: [PATCH 12/35] Requirements --- requirements.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/requirements.txt b/requirements.txt index 184860a..6d5ff01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ lxml # Calculation numpy +scipy # Unit testing pytest @@ -12,3 +13,6 @@ unittest-xml-reporting # Documentation (future) sphinx html5lib + +# Other stuff +matplotlib \ No newline at end of file From 6735f87c9935b3f832d7f13527a71832de3f04b4 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Thu, 10 Oct 2024 13:04:08 +0100 Subject: [PATCH 13/35] Added matrix operations, needs tests --- sasdata/quantities/operations.py | 112 ++++++++++++++++++++++++++++++- 1 file changed, 111 insertions(+), 1 deletion(-) diff --git a/sasdata/quantities/operations.py b/sasdata/quantities/operations.py index 724e55d..67cb636 100644 --- a/sasdata/quantities/operations.py +++ b/sasdata/quantities/operations.py @@ -702,9 +702,119 @@ def __eq__(self, other): if isinstance(other, Pow): return self.a == other.a and self.power == other.power + + +# +# Matrix operations +# + +class Transpose(UnaryOperation): + """ Transpose operation - as per numpy""" + + serialisation_name = "transpose" + + def evaluate(self, variables: dict[int, T]) -> T: + return np.transpose(self.a.evaluate(variables)) + + def _derivative(self, hash_value: int) -> Operation: + return Transpose(self.a.derivative(hash_value)) # TODO: Check! 
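+        # (transposition is linear, so it should commute with taking derivatives; that is what the TODO above needs to confirm)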
+ + def _clean(self): + clean_a = self.a._clean() + return Transpose(clean_a) + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Transpose(Operation.deserialise_json(parameters["a"])) + + def _summary_open(self): + return "Transpose" + + def __eq__(self, other): + if isinstance(other, Transpose): + return other.a == self.a + + +class Dot(BinaryOperation): + """ Dot product - backed by numpy's dot method""" + + serialisation_name = "dot" + + def _self_cls(self) -> type: + return Dot + + def evaluate(self, variables: dict[int, T]) -> T: + return np.dot(self.a.evaluate(variables), self.b.evaluate(variables)) + + def _derivative(self, hash_value: int) -> Operation: + return Add( + Dot(self.a, + self.b._derivative(hash_value)), + Dot(self.a._derivative(hash_value), + self.b)) + + def _clean_ab(self, a, b): + return Dot(a, b) # Do nothing for now + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Dot(*BinaryOperation._deserialise_ab(parameters)) + + def _summary_open(self): + return "Dot" + + +# TODO: Add to base operation class, and to quantities +class MatMul(BinaryOperation): + """ Matrix multiplication, using __matmul__ dunder""" + + serialisation_name = "matmul" + + def _self_cls(self) -> type: + return MatMul + + def evaluate(self, variables: dict[int, T]) -> T: + return self.a.evaluate(variables) @ self.b.evaluate(variables) + + def _derivative(self, hash_value: int) -> Operation: + return Add( + MatMul(self.a, + self.b._derivative(hash_value)), + MatMul(self.a._derivative(hash_value), + self.b)) + + def _clean_ab(self, a, b): + + if isinstance(a, AdditiveIdentity) or isinstance(b, AdditiveIdentity): + # Convert 0@b or a@0 to 0 + return AdditiveIdentity() + + elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): + # Convert constant "a"@"b" to "a@b" + return Constant(a.evaluate({}) @ b.evaluate({}))._clean() + + elif isinstance(a, Neg): + return Neg(MatMul(a.a, b)) + + elif isinstance(b, Neg): + return Neg(MatMul(a, b.a)) + + return MatMul(a, b) + + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return MatMul(*BinaryOperation._deserialise_ab(parameters)) + + def _summary_open(self): + return "MatMul" + + + _serialisable_classes = [AdditiveIdentity, MultiplicativeIdentity, Constant, Variable, Neg, Inv, - Add, Sub, Mul, Div, Pow] + Add, Sub, Mul, Div, Pow, + Transpose, Dot, MatMul] _serialisation_lookup = {cls.serialisation_name: cls for cls in _serialisable_classes} From 96f9b86fd7e8b0b570689abebea71de35d93bd26 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Thu, 10 Oct 2024 13:05:22 +0100 Subject: [PATCH 14/35] Numpy import --- sasdata/quantities/operations.py | 1 + sasdata/quantities/quantity.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/sasdata/quantities/operations.py b/sasdata/quantities/operations.py index 67cb636..35f6fc7 100644 --- a/sasdata/quantities/operations.py +++ b/sasdata/quantities/operations.py @@ -1,4 +1,5 @@ from typing import Any, TypeVar, Union +import numpy as np import json diff --git a/sasdata/quantities/quantity.py b/sasdata/quantities/quantity.py index df70830..149beb6 100644 --- a/sasdata/quantities/quantity.py +++ b/sasdata/quantities/quantity.py @@ -233,6 +233,9 @@ def __rmul__(self: Self, other: ArrayLike | Self): self.history.operation_tree), self.history.references)) + + + def __truediv__(self: Self, other: float | Self) -> Self: if isinstance(other, Quantity): return DerivedQuantity( From 9fca154afaee7c7ae0a702e963d5ed96663d356a Mon Sep 17 00:00:00 2001
From: lucas-wilkins Date: Thu, 10 Oct 2024 13:12:27 +0100 Subject: [PATCH 15/35] Added __matmul__ and __rmatmul__ to quantities --- sasdata/quantities/quantity.py | 36 ++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/sasdata/quantities/quantity.py b/sasdata/quantities/quantity.py index 149beb6..cb1194d 100644 --- a/sasdata/quantities/quantity.py +++ b/sasdata/quantities/quantity.py @@ -234,6 +234,42 @@ def __rmul__(self: Self, other: ArrayLike | Self): self.history.references)) + def __matmul__(self, other: ArrayLike | Self): + if isinstance(other, Quantity): + return DerivedQuantity( + self.value @ other.value, + self.units * other.units, + history=QuantityHistory.apply_operation( + operations.MatMul, + self.history, + other.history)) + else: + return DerivedQuantity( + self.value @ other, + self.units, + QuantityHistory( + operations.MatMul( + self.history.operation_tree, + operations.Constant(other)), + self.history.references)) + + def __rmatmul__(self, other: ArrayLike | Self): + if isinstance(other, Quantity): + return DerivedQuantity( + other.value @ self.value, + other.units * self.units, + history=QuantityHistory.apply_operation( + operations.MatMul, + other.history, + self.history)) + + else: + return DerivedQuantity(other @ self.value, self.units, + QuantityHistory( + operations.MatMul( + operations.Constant(other), + self.history.operation_tree), + self.history.references)) def __truediv__(self: Self, other: float | Self) -> Self: From 9061715e600d51f8192cb47e8a43b77433d40816 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Thu, 10 Oct 2024 14:36:27 +0100 Subject: [PATCH 16/35] Entrypoint for rebinning --- sasdata/model_requirements.py | 2 +- sasdata/transforms/operation.py | 19 ------------------- sasdata/transforms/post_process.py | 0 sasdata/transforms/rebinning.py | 9 +++++++++ 4 files changed, 10 insertions(+), 20 deletions(-) delete mode 100644 sasdata/transforms/operation.py create mode 100644 sasdata/transforms/post_process.py create mode 100644 sasdata/transforms/rebinning.py diff --git a/sasdata/model_requirements.py b/sasdata/model_requirements.py index f186d2d..d043b2c 100644 --- a/sasdata/model_requirements.py +++ b/sasdata/model_requirements.py @@ -3,7 +3,7 @@ import numpy as np from sasdata.metadata import Metadata -from transforms.operation import Operation +from sasdata.quantities.operations import Operation @dataclass diff --git a/sasdata/transforms/operation.py b/sasdata/transforms/operation.py deleted file mode 100644 index 5912188..0000000 --- a/sasdata/transforms/operation.py +++ /dev/null @@ -1,19 +0,0 @@ -import numpy as np -from sasdata.quantities.quantity import Quantity - -class Operation: - """ Sketch of what model post-processing classes might look like """ - - children: list["Operation"] - named_children: dict[str, "Operation"] - - @property - def name(self) -> str: - raise NotImplementedError("No name for transform") - - def evaluate(self) -> Quantity[np.ndarray]: - pass - - def __call__(self, *children, **named_children): - self.children = children - self.named_children = named_children \ No newline at end of file diff --git a/sasdata/transforms/post_process.py b/sasdata/transforms/post_process.py new file mode 100644 index 0000000..e69de29 diff --git a/sasdata/transforms/rebinning.py b/sasdata/transforms/rebinning.py new file mode 100644 index 0000000..58d1912 --- /dev/null +++ b/sasdata/transforms/rebinning.py @@ -0,0 +1,9 @@ +""" Algorithms for interpolation and rebinning """ +from typing import TypeVar + 
+from numpy._typing import ArrayLike + +from sasdata.quantities.quantity import Quantity + +def rebin(data: Quantity[ArrayLike], axes: list[Quantity[ArrayLike]], new_axes: list[Quantity[ArrayLike]], interpolation_order=1): + pass \ No newline at end of file From 5b3561451538592b235e4335a81b48c952aad861 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Thu, 10 Oct 2024 14:42:42 +0100 Subject: [PATCH 17/35] Some better commenting on Quantity --- sasdata/quantities/quantity.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sasdata/quantities/quantity.py b/sasdata/quantities/quantity.py index cb1194d..6ee80d8 100644 --- a/sasdata/quantities/quantity.py +++ b/sasdata/quantities/quantity.py @@ -132,10 +132,10 @@ def __init__(self, self.hash_value = -1 """ Hash based on value and uncertainty for data, -1 if it is a derived hash value """ - """ Contains the variance if it is data driven, else it is """ + self._variance = None + """ Contains the variance if it is data driven """ if standard_error is None: - self._variance = None self.hash_value = hash_data_via_numpy(hash_seed, value) else: self._variance = standard_error ** 2 From 090808e8e89e1bfba53ee6aeb3d77101674bfe47 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Thu, 10 Oct 2024 18:10:36 +0100 Subject: [PATCH 18/35] Work towards rebinning methods --- sasdata/data.py | 21 ++++++----- sasdata/transforms/rebinning.py | 66 ++++++++++++++++++++++++++++++++- 2 files changed, 77 insertions(+), 10 deletions(-) diff --git a/sasdata/data.py b/sasdata/data.py index 7f0cbfb..544ba27 100644 --- a/sasdata/data.py +++ b/sasdata/data.py @@ -2,6 +2,8 @@ from typing import TypeVar, Any, Self from dataclasses import dataclass +import numpy as np + from quantities.quantity import NamedQuantity from sasdata.metadata import Metadata from sasdata.quantities.accessors import AccessorTarget @@ -9,7 +11,11 @@ class SasData: - def __init__(self, name: str, data_contents: list[NamedQuantity], raw_metadata: Group, verbose: bool=False): + def __init__(self, name: str, + data_contents: list[NamedQuantity], + raw_metadata: Group, + verbose: bool=False): + self.name = name self._data_contents = data_contents self._raw_metadata = raw_metadata @@ -17,14 +23,11 @@ def __init__(self, name: str, data_contents: list[NamedQuantity], raw_metadata: self.metadata = Metadata(AccessorTarget(raw_metadata, verbose=verbose)) - # TO IMPLEMENT - - # abscissae: list[NamedQuantity[np.ndarray]] - # ordinate: NamedQuantity[np.ndarray] - # other: list[NamedQuantity[np.ndarray]] - # - # metadata: Metadata - # model_requirements: ModellingRequirements + # Components that need to be organised after creation + self.ordinate: NamedQuantity[np.ndarray] = None # TODO: fill out + self.abscissae: list[NamedQuantity[np.ndarray]] = None # TODO: fill out + self.mask = None # TODO: fill out + self.model_requirements = None # TODO: fill out def summary(self, indent = " ", include_raw=False): s = f"{self.name}\n" diff --git a/sasdata/transforms/rebinning.py b/sasdata/transforms/rebinning.py index 58d1912..75f5376 100644 --- a/sasdata/transforms/rebinning.py +++ b/sasdata/transforms/rebinning.py @@ -1,9 +1,73 @@ """ Algorithms for interpolation and rebinning """ from typing import TypeVar +import numpy as np from numpy._typing import ArrayLike from sasdata.quantities.quantity import Quantity +from scipy.sparse import coo_matrix + +from enum import Enum + +class InterpolationOptions(Enum): + NEAREST_NEIGHBOUR = 0 + LINEAR = 1 + + + +def calculate_interpolation_matrix_1d(input_axis: 
Quantity[ArrayLike], + output_axis: Quantity[ArrayLike], + mask: ArrayLike | None = None, + order: InterpolationOptions = InterpolationOptions.NEAREST_NEIGHBOUR, + is_density=False): + + # We want the input values in terms of the output units, will implicitly check compatibility + + working_units = output_axis.units + + input_x = input_axis.in_units_of(working_units) + output_x = output_axis.in_units_of(working_units) + + # Get the array indices that will map the array to a sorted one + input_sort = np.argsort(input_x) + output_sort = np.argsort(output_x) + + output_unsort = np.arange(len(input_x), dtype=int)[output_sort] + sorted_in = input_x[input_sort] + sorted_out = output_x[output_sort] + + match order: + case InterpolationOptions.NEAREST_NEIGHBOUR: + + # COO Sparse matrix definition data + values = [] + j_entries = [] + i_entries = [] + + # Find the output values nearest to each of the input values + for x_in in sorted_in: + + + case _: + raise ValueError(f"Unsupported interpolation order: {order}") + +def calculate_interpolation_matrix(input_axes: list[Quantity[ArrayLike]], + output_axes: list[Quantity[ArrayLike]], + data: ArrayLike | None = None, + mask: ArrayLike | None = None): + + pass + + + +def rebin(data: Quantity[ArrayLike], + axes: list[Quantity[ArrayLike]], + new_axes: list[Quantity[ArrayLike]], + mask: ArrayLike | None = None, + interpolation_order: int = 1): + + """ This algorithm is only for operations that preserve dimensionality, + i.e. non-projective rebinning. + """ -def rebin(data: Quantity[ArrayLike], axes: list[Quantity[ArrayLike]], new_axes: list[Quantity[ArrayLike]], interpolation_order=1): pass \ No newline at end of file From 8b95885f08e3e1117a0c0f7c9cef82fbce542632 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Fri, 11 Oct 2024 11:27:53 +0100 Subject: [PATCH 19/35] Zeroth order rebinning sketch --- sasdata/transforms/rebinning.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/sasdata/transforms/rebinning.py b/sasdata/transforms/rebinning.py index 75f5376..c490f82 100644 --- a/sasdata/transforms/rebinning.py +++ b/sasdata/transforms/rebinning.py @@ -32,7 +32,9 @@ def calculate_interpolation_matrix_1d(input_axis: Quantity[ArrayLike], input_sort = np.argsort(input_x) output_sort = np.argsort(output_x) + input_unsort = np.arange(len(output_x), dtype=int)[input_sort] output_unsort = np.arange(len(input_x), dtype=int)[output_sort] + sorted_in = input_x[input_sort] sorted_out = output_x[output_sort] @@ -40,13 +42,33 @@ def calculate_interpolation_matrix_1d(input_axis: Quantity[ArrayLike], case InterpolationOptions.NEAREST_NEIGHBOUR: # COO Sparse matrix definition data - values = [] - j_entries = [] i_entries = [] + j_entries = [] + + crossing_points = 0.5*(sorted_out[1:] + sorted_out[:-1]) # Find the output values nearest to each of the input values - for x_in in sorted_in: - + n_i = len(sorted_in) + n_j = len(sorted_out) + i=0 + for k, crossing_point in enumerate(crossing_points): + while i < n_i and sorted_in[i] < crossing_point: + i_entries.append(i) + j_entries.append(k) + i += 1 + + # All the rest in the last bin + while i < n_i: + i_entries.append(i) + j_entries.append(n_j-1) + i += 1 + + i_entries = input_unsort[np.array(i_entries, dtype=int)] + j_entries = output_unsort[np.array(j_entries, dtype=int)] + values = np.ones_like(i_entries, dtype=float) + + return coo_matrix((values, (i_entries, j_entries)), shape=(n_i, n_j)) + case _: raise ValueError(f"Unsupported interpolation order: {order}") From
a1db35f71d18563cf668c6881bfdae1ff8dc04c3 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Fri, 11 Oct 2024 12:08:04 +0100 Subject: [PATCH 20/35] First order rebinning --- sasdata/transforms/rebinning.py | 59 +++++++++++++++++++++++++++++---- 1 file changed, 53 insertions(+), 6 deletions(-) diff --git a/sasdata/transforms/rebinning.py b/sasdata/transforms/rebinning.py index c490f82..cd05cfc 100644 --- a/sasdata/transforms/rebinning.py +++ b/sasdata/transforms/rebinning.py @@ -3,6 +3,7 @@ import numpy as np from numpy._typing import ArrayLike +from scipy.interpolate import interp1d from sasdata.quantities.quantity import Quantity from scipy.sparse import coo_matrix @@ -38,6 +39,11 @@ def calculate_interpolation_matrix_1d(input_axis: Quantity[ArrayLike], sorted_in = input_x[input_sort] sorted_out = output_x[output_sort] + n_in = len(sorted_in) + n_out = len(sorted_out) + + conversion_matrix = None # output + match order: case InterpolationOptions.NEAREST_NEIGHBOUR: @@ -48,31 +54,72 @@ def calculate_interpolation_matrix_1d(input_axis: Quantity[ArrayLike], crossing_points = 0.5*(sorted_out[1:] + sorted_out[:-1]) # Find the output values nearest to each of the input values - n_i = len(sorted_in) - n_j = len(sorted_out) i=0 for k, crossing_point in enumerate(crossing_points): - while i < n_i and sorted_in[i] < crossing_point: + while i < n_in and sorted_in[i] < crossing_point: i_entries.append(i) j_entries.append(k) i += 1 # All the rest in the last bin - while i < n_i: + while i < n_in: i_entries.append(i) - j_entries.append(n_j-1) + j_entries.append(n_out-1) i += 1 i_entries = input_unsort[np.array(i_entries, dtype=int)] j_entries = output_unsort[np.array(j_entries, dtype=int)] values = np.ones_like(i_entries, dtype=float) - return coo_matrix((values, (i_entries, j_entries)), shape=(n_i, n_j)) + conversion_matrix = coo_matrix((values, (i_entries, j_entries)), shape=(n_in, n_out)) + + case InterpolationOptions.LINEAR: + + # Leverage existing linear interpolation methods to get the mapping + # do a linear interpolation on indices + # the floor should give the left bin + # the ceil should give the right bin + # the fractional part should give the relative weightings + + input_indices = np.arange(n_in, dtype=int) + output_indices = np.arange(n_out, dtype=int) + + fractional = np.interp(x=sorted_out, xp=sorted_in, fp=input_indices, left=0, right=n_in-1) + + left_bins = np.floor(fractional, dtype=int) + right_bins = np.ceil(fractional, dtype=int) + + right_weight = fractional % 1 + left_weight = 1 - right_weight + + # There *should* be no repeated entries for both i and j in the main part, but maybe at the ends + # If left bin is the same as right bin, then we only want one entry, and the weight should be 1 + same = left_bins == right_bins + not_same = ~same + + same_bins = left_bins[same] # could equally be right bins, they're the same + + same_indices = output_indices[same] + not_same_indices = output_indices[not_same] + + j_entries_sorted = np.concatenate((same_indices, not_same_indices, not_same_indices)) + i_entries_sorted = np.concatenate((same_bins, left_bins[not_same], right_bins[not_same])) + + i_entries = input_unsort[i_entries_sorted] + j_entries = output_unsort[j_entries_sorted] + + # weights don't need to be unsorted # TODO: check this is right, it should become obvious if we use unsorted data + weights = np.concatenate((np.ones_like(same_bins, dtype=float), left_weight[not_same], right_weight[not_same])) + + conversion_matrix = coo_matrix((weights, (i_entries, j_entries)), 
shape=(n_in, n_out)) case _: raise ValueError(f"Unsupported interpolation order: {order}") + + return conversion_matrix + def calculate_interpolation_matrix(input_axes: list[Quantity[ArrayLike]], output_axes: list[Quantity[ArrayLike]], data: ArrayLike | None = None, From ed02586ba6ba86cb6cd05521f0363b5cfd8f1883 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Tue, 15 Oct 2024 16:26:44 +0100 Subject: [PATCH 21/35] Rebinning tests and extensions --- sasdata/manual_tests/__init__.py | 0 sasdata/manual_tests/interpolation.py | 44 ++++++++++++ sasdata/quantities/math.py | 5 ++ sasdata/quantities/plotting.py | 23 ++++++ sasdata/quantities/quantity.py | 18 +++++ sasdata/transforms/rebinning.py | 60 ++++++++++++++-- sasdata/transforms/test_interpolation.py | 91 ++++++++++++++++++++++++ 7 files changed, 234 insertions(+), 7 deletions(-) create mode 100644 sasdata/manual_tests/__init__.py create mode 100644 sasdata/manual_tests/interpolation.py create mode 100644 sasdata/quantities/math.py create mode 100644 sasdata/quantities/plotting.py create mode 100644 sasdata/transforms/test_interpolation.py diff --git a/sasdata/manual_tests/__init__.py b/sasdata/manual_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sasdata/manual_tests/interpolation.py b/sasdata/manual_tests/interpolation.py new file mode 100644 index 0000000..c6338a4 --- /dev/null +++ b/sasdata/manual_tests/interpolation.py @@ -0,0 +1,44 @@ +import numpy as np +import matplotlib.pyplot as plt + +from sasdata.quantities.quantity import NamedQuantity +from sasdata.quantities.plotting import quantity_plot +from sasdata.quantities import units + +from sasdata.transforms.rebinning import calculate_interpolation_matrix_1d +from sasdata.transforms.rebinning import InterpolationOptions + +def linear_interpolation_check(): + + for from_bins in [(-10, 10, 10), + (-10, 10, 1000), + (-15, 5, 10), + (15,5, 10)]: + for to_bins in [ + (-15, 0, 10), + (-15, 15, 10), + (0, 20, 100)]: + + plt.figure() + + x = NamedQuantity("x", np.linspace(*from_bins), units=units.meters) + y = x**2 + + quantity_plot(x, y) + + new_x = NamedQuantity("x_new", np.linspace(*to_bins), units=units.meters) + + rebin_mat = calculate_interpolation_matrix_1d(x, new_x, order=InterpolationOptions.LINEAR) + + new_y = y @ rebin_mat + + quantity_plot(new_x, new_y) + + # print(new_y.history.summary()) + + plt.show() + + + + +linear_interpolation_check() \ No newline at end of file diff --git a/sasdata/quantities/math.py b/sasdata/quantities/math.py new file mode 100644 index 0000000..d252ccc --- /dev/null +++ b/sasdata/quantities/math.py @@ -0,0 +1,5 @@ +""" Math module extended to allow operations on quantities """ + +# TODO Implementations for trig and exp +# TODO Implementations for linear algebra stuff + diff --git a/sasdata/quantities/plotting.py b/sasdata/quantities/plotting.py new file mode 100644 index 0000000..854e23f --- /dev/null +++ b/sasdata/quantities/plotting.py @@ -0,0 +1,23 @@ +import matplotlib.pyplot as plt +from numpy.typing import ArrayLike + +from sasdata.quantities.quantity import Quantity, NamedQuantity + + +def quantity_plot(x: Quantity[ArrayLike], y: Quantity[ArrayLike], *args, **kwargs): + plt.plot(x.value, y.value, *args, **kwargs) + + x_name = x.name if isinstance(x, NamedQuantity) else "x" + y_name = y.name if isinstance(y, NamedQuantity) else "y" + + plt.xlabel(f"{x_name} / {x.units}") + plt.ylabel(f"{y_name} / {y.units}") + +def quantity_scatter(x: Quantity[ArrayLike], y: Quantity[ArrayLike], *args, **kwargs): + 
plt.scatter(x.value, y.value, *args, **kwargs) + + x_name = x.name if isinstance(x, NamedQuantity) else "x" + y_name = y.name if isinstance(y, NamedQuantity) else "y" + + plt.xlabel(f"{x_name} / {x.units}") + plt.ylabel(f"{y_name} / {y.units}") diff --git a/sasdata/quantities/quantity.py b/sasdata/quantities/quantity.py index 6ee80d8..3def7e1 100644 --- a/sasdata/quantities/quantity.py +++ b/sasdata/quantities/quantity.py @@ -110,6 +110,16 @@ def has_variance(self): return False + def summary(self): + + variable_strings = [self.references[key].string_repr for key in self.references] + + s = "Variables: "+",".join(variable_strings) + s += "\n" + s += self.operation_tree.summary() + + return s + class Quantity[QuantityType]: @@ -398,6 +408,10 @@ def __repr__(self): def parse(number_or_string: str | ArrayLike, unit: str, absolute_temperature: False): pass + @property + def string_repr(self): + return str(self.hash_value) + class NamedQuantity[QuantityType](Quantity[QuantityType]): def __init__(self, @@ -432,6 +446,10 @@ def with_standard_error(self, standard_error: Quantity): f"are not compatible with value units ({self.units})") + @property + def string_repr(self): + return self.name + class DerivedQuantity[QuantityType](Quantity[QuantityType]): def __init__(self, value: QuantityType, units: Unit, history: QuantityHistory): super().__init__(value, units, standard_error=None) diff --git a/sasdata/transforms/rebinning.py b/sasdata/transforms/rebinning.py index cd05cfc..3335216 100644 --- a/sasdata/transforms/rebinning.py +++ b/sasdata/transforms/rebinning.py @@ -13,13 +13,17 @@ class InterpolationOptions(Enum): NEAREST_NEIGHBOUR = 0 LINEAR = 1 + CUBIC = 3 +class InterpolationError(Exception): + """ We probably want to raise exceptions because interpolation is not appropriate/well-defined, + not the same as numerical issues that will raise ValueErrors""" def calculate_interpolation_matrix_1d(input_axis: Quantity[ArrayLike], output_axis: Quantity[ArrayLike], mask: ArrayLike | None = None, - order: InterpolationOptions = InterpolationOptions.NEAREST_NEIGHBOUR, + order: InterpolationOptions = InterpolationOptions.LINEAR, is_density=False): # We want the input values in terms of the output units, will implicitly check compatability @@ -33,8 +37,8 @@ def calculate_interpolation_matrix_1d(input_axis: Quantity[ArrayLike], input_sort = np.argsort(input_x) output_sort = np.argsort(output_x) - input_unsort = np.arange(len(output_x), dtype=int)[input_sort] - output_unsort = np.arange(len(input_x), dtype=int)[output_sort] + input_unsort = np.arange(len(input_x), dtype=int)[input_sort] + output_unsort = np.arange(len(output_x), dtype=int)[output_sort] sorted_in = input_x[input_sort] sorted_out = output_x[output_sort] @@ -86,8 +90,8 @@ def calculate_interpolation_matrix_1d(input_axis: Quantity[ArrayLike], fractional = np.interp(x=sorted_out, xp=sorted_in, fp=input_indices, left=0, right=n_in-1) - left_bins = np.floor(fractional, dtype=int) - right_bins = np.ceil(fractional, dtype=int) + left_bins = np.floor(fractional).astype(int) + right_bins = np.ceil(fractional).astype(int) right_weight = fractional % 1 left_weight = 1 - right_weight @@ -114,18 +118,60 @@ def calculate_interpolation_matrix_1d(input_axis: Quantity[ArrayLike], conversion_matrix = coo_matrix((weights, (i_entries, j_entries)), shape=(n_in, n_out)) + case InterpolationOptions.CUBIC: + # Cubic interpolation, much harder to implement because we can't just cheat and use numpy + raise NotImplementedError("Cubic interpolation not implemented yet") 
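# A minimal, self-contained sketch of the fractional-index construction used in
# the LINEAR branch of calculate_interpolation_matrix_1d above, stripped of the
# Quantity/unit machinery. The name linear_rebin_matrix is illustrative and not
# part of sasdata. Note that scipy sums duplicate (i, j) entries when a COO
# matrix is applied, so the coincident-node case (left bin == right bin, with
# weights 1 and 0) resolves correctly even without special-casing.
import numpy as np
from scipy.sparse import coo_matrix

def linear_rebin_matrix(x_in: np.ndarray, x_out: np.ndarray) -> coo_matrix:
    """Sparse (len(x_in), len(x_out)) matrix M such that M.T @ y_in is the
    linear interpolation of y_in from the sorted grid x_in onto x_out."""
    n_in = len(x_in)
    # Fractional index of each output point within the input grid
    fractional = np.interp(x_out, x_in, np.arange(n_in), left=0, right=n_in - 1)
    left = np.floor(fractional).astype(int)   # floor/ceil then cast, as in the fix above
    right = np.ceil(fractional).astype(int)
    right_weight = fractional % 1
    left_weight = 1.0 - right_weight
    rows = np.concatenate((left, right))
    cols = np.tile(np.arange(len(x_out)), 2)
    weights = np.concatenate((left_weight, right_weight))
    return coo_matrix((weights, (rows, cols)), shape=(n_in, len(x_out)))

# Linear interpolation is exact on affine data, which makes a quick self-test easy:
x_in = np.linspace(0.0, 10.0, 11)
x_out = np.linspace(0.0, 10.0, 23)
m = linear_rebin_matrix(x_in, x_out)
assert np.allclose(m.T @ (3.0 * x_in + 1.0), 3.0 * x_out + 1.0)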
+ case _: - raise ValueError(f"Unsupported interpolation order: {order}") + raise InterpolationError(f"Unsupported interpolation order: {order}") return conversion_matrix +def calculate_interpolation_matrix_2d_axis_axis(input_1: Quantity[ArrayLike], + input_2: Quantity[ArrayLike], + output_1: Quantity[ArrayLike], + output_2: Quantity[ArrayLike], + mask, + order: InterpolationOptions = InterpolationOptions.LINEAR, + is_density: bool = False): + + match order: + case InterpolationOptions.NEAREST_NEIGHBOUR: + pass + + case InterpolationOptions.LINEAR: + pass + + case InterpolationOptions.CUBIC: + pass + + case _: + pass + + def calculate_interpolation_matrix(input_axes: list[Quantity[ArrayLike]], output_axes: list[Quantity[ArrayLike]], data: ArrayLike | None = None, mask: ArrayLike | None = None): - pass + # TODO: We probably should delete this, but lets keep it for now + + if len(input_axes) not in (1, 2): + raise InterpolationError("Interpolation is only supported for 1D and 2D data") + + if len(input_axes) == 1 and len(output_axes) == 1: + # Check for dimensionality + input_axis = input_axes[0] + output_axis = output_axes[0] + + if len(input_axis.value.shape) == 1: + if len(output_axis.value.shape) == 1: + calculate_interpolation_matrix_1d() + + if len(output_axes) != len(input_axes): + # Input or output axes might be 2D matrices + diff --git a/sasdata/transforms/test_interpolation.py b/sasdata/transforms/test_interpolation.py new file mode 100644 index 0000000..688da65 --- /dev/null +++ b/sasdata/transforms/test_interpolation.py @@ -0,0 +1,91 @@ +import pytest +import numpy as np +from matplotlib import pyplot as plt +from numpy.typing import ArrayLike +from typing import Callable + +from sasdata.quantities.plotting import quantity_plot +from sasdata.quantities.quantity import NamedQuantity, Quantity +from sasdata.quantities import units + +from sasdata.transforms.rebinning import calculate_interpolation_matrix_1d, InterpolationOptions + +test_functions = [ + lambda x: x**2, + lambda x: 2*x, + lambda x: x**3 +] + + +@pytest.mark.parametrize("fun", test_functions) +def test_linear_interpolate_matrix_inside(fun: Callable[[Quantity[ArrayLike]], Quantity[ArrayLike]]): + original_points = NamedQuantity("x_base", np.linspace(-10,10, 31), units.meters) + test_points = NamedQuantity("x_test", np.linspace(-5, 5, 11), units.meters) + + + mapping = calculate_interpolation_matrix_1d(original_points, test_points, order=InterpolationOptions.LINEAR) + + y_original = fun(original_points) + y_test = y_original @ mapping + y_expected = fun(test_points) + + test_units = y_expected.units + + y_values_test = y_test.in_units_of(test_units) + y_values_expected = y_expected.in_units_of(test_units) + + # print(y_values_test) + # print(y_values_expected) + # + # quantity_plot(original_points, y_original) + # quantity_plot(test_points, y_test) + # quantity_plot(test_points, y_expected) + # plt.show() + + assert len(y_values_test) == len(y_values_expected) + + for t, e in zip(y_values_test, y_values_expected): + assert t == pytest.approx(e, abs=2) + + +@pytest.mark.parametrize("fun", test_functions) +def test_linear_interpolate_different_units(fun: Callable[[Quantity[ArrayLike]], Quantity[ArrayLike]]): + original_points = NamedQuantity("x_base", np.linspace(-10,10, 107), units.meters) + test_points = NamedQuantity("x_test", np.linspace(-5000, 5000, 11), units.millimeters) + + mapping = calculate_interpolation_matrix_1d(original_points, test_points, order=InterpolationOptions.LINEAR) + + y_original = 
fun(original_points) + y_test = y_original @ mapping + y_expected = fun(test_points) + + test_units = y_expected.units + + y_values_test = y_test.in_units_of(test_units) + y_values_expected = y_expected.in_units_of(test_units) + # + # print(y_values_test) + # print(y_test.in_si()) + # print(y_values_expected) + # + # plt.plot(original_points.in_si(), y_original.in_si()) + # plt.plot(test_points.in_si(), y_test.in_si(), "x") + # plt.plot(test_points.in_si(), y_expected.in_si(), "o") + # plt.show() + + assert len(y_values_test) == len(y_values_expected) + + for t, e in zip(y_values_test, y_values_expected): + assert t == pytest.approx(e, rel=5e-2) + +def test_linearity_linear(): + """ Test linear interpolation between two points""" + x_and_y = NamedQuantity("x_base", np.linspace(-10, 10, 2), units.meters) + new_x = NamedQuantity("x_test", np.linspace(-5000, 5000, 101), units.millimeters) + + mapping = calculate_interpolation_matrix_1d(x_and_y, new_x, order=InterpolationOptions.LINEAR) + + linear_points = x_and_y @ mapping + + for t, e in zip(new_x.in_si(), linear_points.in_si()): + assert t == pytest.approx(e, rel=1e-3) \ No newline at end of file From 2327c04cf1cfe5b31392fa54341d660f44a35a75 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Tue, 15 Oct 2024 17:45:32 +0100 Subject: [PATCH 22/35] Notes --- sasdata/transforms/rebinning.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sasdata/transforms/rebinning.py b/sasdata/transforms/rebinning.py index 3335216..d5120b7 100644 --- a/sasdata/transforms/rebinning.py +++ b/sasdata/transforms/rebinning.py @@ -26,7 +26,11 @@ def calculate_interpolation_matrix_1d(input_axis: Quantity[ArrayLike], order: InterpolationOptions = InterpolationOptions.LINEAR, is_density=False): + """ Calculate the matrix that converts values recorded at points specified by input_axis to + values recorded at points specified by output_axis""" + # We want the input values in terms of the output units, will implicitly check compatability + # TODO: incorporate mask working_units = output_axis.units @@ -136,6 +140,8 @@ def calculate_interpolation_matrix_2d_axis_axis(input_1: Quantity[ArrayLike], order: InterpolationOptions = InterpolationOptions.LINEAR, is_density: bool = False): + # If it wasn't for the mask, this would be the same as just two sets of 1D interpolation + match order: case InterpolationOptions.NEAREST_NEIGHBOUR: pass From c1e4db25d42508babd0fc45f16287951cfe290bc Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Tue, 15 Oct 2024 17:46:18 +0100 Subject: [PATCH 23/35] No error --- sasdata/transforms/rebinning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sasdata/transforms/rebinning.py b/sasdata/transforms/rebinning.py index d5120b7..662a0b1 100644 --- a/sasdata/transforms/rebinning.py +++ b/sasdata/transforms/rebinning.py @@ -177,7 +177,7 @@ def calculate_interpolation_matrix(input_axes: list[Quantity[ArrayLike]], if len(output_axes) != len(input_axes): # Input or output axes might be 2D matrices - + pass From 30e3c302186fadfd05c5f1400119ec58a53eb8d9 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Thu, 17 Oct 2024 14:21:59 +0100 Subject: [PATCH 24/35] Moving some things around --- sasdata/{data_util => }/slicing/__init__.py | 0 sasdata/{data_util => }/slicing/geometry.py | 0 sasdata/{data_util => }/slicing/meshes/__init__.py | 0 .../slicing/meshes/delaunay_mesh.py | 2 +- sasdata/{data_util => }/slicing/meshes/mesh.py | 2 +- .../{data_util => }/slicing/meshes/meshmerge.py | 6 ++---- sasdata/{data_util => 
}/slicing/meshes/util.py | 0 .../{data_util => }/slicing/meshes/voronoi_mesh.py | 2 +- sasdata/{data_util => }/slicing/rebinning.py | 6 +++--- sasdata/{data_util => }/slicing/sample_polygons.py | 0 sasdata/{data_util => }/slicing/slicer_demo.py | 5 ++--- .../slicing/slicers/AnularSector.py | 4 ++-- sasdata/slicing/slicers/__init__.py | 0 sasdata/{data_util => }/slicing/transforms.py | 0 sasdata/transforms/rebinning.py | 14 ++++++++++++-- test/slicers/meshes_for_testing.py | 6 +++--- test/slicers/utest_meshmerge.py | 2 +- 17 files changed, 28 insertions(+), 21 deletions(-) rename sasdata/{data_util => }/slicing/__init__.py (100%) rename sasdata/{data_util => }/slicing/geometry.py (100%) rename sasdata/{data_util => }/slicing/meshes/__init__.py (100%) rename sasdata/{data_util => }/slicing/meshes/delaunay_mesh.py (89%) rename sasdata/{data_util => }/slicing/meshes/mesh.py (96%) rename sasdata/{data_util => }/slicing/meshes/meshmerge.py (92%) rename sasdata/{data_util => }/slicing/meshes/util.py (100%) rename sasdata/{data_util => }/slicing/meshes/voronoi_mesh.py (95%) rename sasdata/{data_util => }/slicing/rebinning.py (94%) rename sasdata/{data_util => }/slicing/sample_polygons.py (100%) rename sasdata/{data_util => }/slicing/slicer_demo.py (90%) rename sasdata/{data_util => }/slicing/slicers/AnularSector.py (87%) create mode 100644 sasdata/slicing/slicers/__init__.py rename sasdata/{data_util => }/slicing/transforms.py (100%) diff --git a/sasdata/data_util/slicing/__init__.py b/sasdata/slicing/__init__.py similarity index 100% rename from sasdata/data_util/slicing/__init__.py rename to sasdata/slicing/__init__.py diff --git a/sasdata/data_util/slicing/geometry.py b/sasdata/slicing/geometry.py similarity index 100% rename from sasdata/data_util/slicing/geometry.py rename to sasdata/slicing/geometry.py diff --git a/sasdata/data_util/slicing/meshes/__init__.py b/sasdata/slicing/meshes/__init__.py similarity index 100% rename from sasdata/data_util/slicing/meshes/__init__.py rename to sasdata/slicing/meshes/__init__.py diff --git a/sasdata/data_util/slicing/meshes/delaunay_mesh.py b/sasdata/slicing/meshes/delaunay_mesh.py similarity index 89% rename from sasdata/data_util/slicing/meshes/delaunay_mesh.py rename to sasdata/slicing/meshes/delaunay_mesh.py index 45e2087..a19c2ac 100644 --- a/sasdata/data_util/slicing/meshes/delaunay_mesh.py +++ b/sasdata/slicing/meshes/delaunay_mesh.py @@ -1,7 +1,7 @@ import numpy as np from scipy.spatial import Delaunay -from sasdata.data_util.slicing.meshes.mesh import Mesh +from sasdata.slicing.meshes.mesh import Mesh def delaunay_mesh(x, y) -> Mesh: """ Create a triangulated mesh based on input points """ diff --git a/sasdata/data_util/slicing/meshes/mesh.py b/sasdata/slicing/meshes/mesh.py similarity index 96% rename from sasdata/data_util/slicing/meshes/mesh.py rename to sasdata/slicing/meshes/mesh.py index 3ac23da..8176633 100644 --- a/sasdata/data_util/slicing/meshes/mesh.py +++ b/sasdata/slicing/meshes/mesh.py @@ -6,7 +6,7 @@ from matplotlib import cm from matplotlib.collections import LineCollection -from sasdata.data_util.slicing.meshes.util import closed_loop_edges +from sasdata.slicing.meshes.util import closed_loop_edges class Mesh: def __init__(self, diff --git a/sasdata/data_util/slicing/meshes/meshmerge.py b/sasdata/slicing/meshes/meshmerge.py similarity index 92% rename from sasdata/data_util/slicing/meshes/meshmerge.py rename to sasdata/slicing/meshes/meshmerge.py index 161c1e5..2060cc7 100644 --- a/sasdata/data_util/slicing/meshes/meshmerge.py 
+++ b/sasdata/slicing/meshes/meshmerge.py @@ -1,9 +1,7 @@ import numpy as np -from sasdata.data_util.slicing.meshes.mesh import Mesh -from sasdata.data_util.slicing.meshes.delaunay_mesh import delaunay_mesh -from sasdata.data_util.slicing.meshes.util import closed_loop_edges - +from sasdata.slicing.meshes.mesh import Mesh +from sasdata.slicing.meshes.delaunay_mesh import delaunay_mesh import time diff --git a/sasdata/data_util/slicing/meshes/util.py b/sasdata/slicing/meshes/util.py similarity index 100% rename from sasdata/data_util/slicing/meshes/util.py rename to sasdata/slicing/meshes/util.py diff --git a/sasdata/data_util/slicing/meshes/voronoi_mesh.py b/sasdata/slicing/meshes/voronoi_mesh.py similarity index 95% rename from sasdata/data_util/slicing/meshes/voronoi_mesh.py rename to sasdata/slicing/meshes/voronoi_mesh.py index d3eb81d..d47dc2c 100644 --- a/sasdata/data_util/slicing/meshes/voronoi_mesh.py +++ b/sasdata/slicing/meshes/voronoi_mesh.py @@ -2,7 +2,7 @@ from scipy.spatial import Voronoi -from sasdata.data_util.slicing.meshes.mesh import Mesh +from sasdata.slicing.meshes.mesh import Mesh def voronoi_mesh(x, y, debug_plot=False) -> Mesh: """ Create a mesh based on a voronoi diagram of points """ diff --git a/sasdata/data_util/slicing/rebinning.py b/sasdata/slicing/rebinning.py similarity index 94% rename from sasdata/data_util/slicing/rebinning.py rename to sasdata/slicing/rebinning.py index 510535a..f2c76de 100644 --- a/sasdata/data_util/slicing/rebinning.py +++ b/sasdata/slicing/rebinning.py @@ -4,9 +4,9 @@ import numpy as np -from sasdata.data_util.slicing.meshes.mesh import Mesh -from sasdata.data_util.slicing.meshes.voronoi_mesh import voronoi_mesh -from sasdata.data_util.slicing.meshes.meshmerge import meshmerge +from sasdata.slicing.meshes.mesh import Mesh +from sasdata.slicing.meshes.voronoi_mesh import voronoi_mesh +from sasdata.slicing.meshes.meshmerge import meshmerge import time diff --git a/sasdata/data_util/slicing/sample_polygons.py b/sasdata/slicing/sample_polygons.py similarity index 100% rename from sasdata/data_util/slicing/sample_polygons.py rename to sasdata/slicing/sample_polygons.py diff --git a/sasdata/data_util/slicing/slicer_demo.py b/sasdata/slicing/slicer_demo.py similarity index 90% rename from sasdata/data_util/slicing/slicer_demo.py rename to sasdata/slicing/slicer_demo.py index 6096ca9..af3ee98 100644 --- a/sasdata/data_util/slicing/slicer_demo.py +++ b/sasdata/slicing/slicer_demo.py @@ -4,9 +4,8 @@ import matplotlib.pyplot as plt -from sasdata.data_util.slicing.slicers.AnularSector import AnularSector -from sasdata.data_util.slicing.meshes.mesh import Mesh -from sasdata.data_util.slicing.meshes.voronoi_mesh import voronoi_mesh +from sasdata.slicing.slicers.AnularSector import AnularSector +from sasdata.slicing.meshes.voronoi_mesh import voronoi_mesh diff --git a/sasdata/data_util/slicing/slicers/AnularSector.py b/sasdata/slicing/slicers/AnularSector.py similarity index 87% rename from sasdata/data_util/slicing/slicers/AnularSector.py rename to sasdata/slicing/slicers/AnularSector.py index 6d034da..4ace344 100644 --- a/sasdata/data_util/slicing/slicers/AnularSector.py +++ b/sasdata/slicing/slicers/AnularSector.py @@ -1,7 +1,7 @@ import numpy as np -from sasdata.data_util.slicing.rebinning import Rebinner -from sasdata.data_util.slicing.meshes.mesh import Mesh +from sasdata.slicing.rebinning import Rebinner +from sasdata.slicing.meshes.mesh import Mesh class AnularSector(Rebinner): """ A single annular sector (wedge sum)""" diff --git 
a/sasdata/slicing/slicers/__init__.py b/sasdata/slicing/slicers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sasdata/data_util/slicing/transforms.py b/sasdata/slicing/transforms.py similarity index 100% rename from sasdata/data_util/slicing/transforms.py rename to sasdata/slicing/transforms.py diff --git a/sasdata/transforms/rebinning.py b/sasdata/transforms/rebinning.py index 662a0b1..7bdc662 100644 --- a/sasdata/transforms/rebinning.py +++ b/sasdata/transforms/rebinning.py @@ -130,7 +130,17 @@ def calculate_interpolation_matrix_1d(input_axis: Quantity[ArrayLike], raise InterpolationError(f"Unsupported interpolation order: {order}") - return conversion_matrix + if mask is None: + return conversion_matrix, None + else: + # Create a new mask + + # Convert to numerical values + # Conservative masking: anything touched by the previous mask is now masked + new_mask = (np.array(mask, dtype=float) @ conversion_matrix) != 0.0 + + return conversion_matrix, new_mask + def calculate_interpolation_matrix_2d_axis_axis(input_1: Quantity[ArrayLike], input_2: Quantity[ArrayLike], @@ -140,7 +150,7 @@ def calculate_interpolation_matrix_2d_axis_axis(input_1: Quantity[ArrayLike], order: InterpolationOptions = InterpolationOptions.LINEAR, is_density: bool = False): - # If it wasn't for the mask, this would be the same as just two sets of 1D interpolation + # This is just the same 1D matrices things match order: case InterpolationOptions.NEAREST_NEIGHBOUR: diff --git a/test/slicers/meshes_for_testing.py b/test/slicers/meshes_for_testing.py index 7cb17b4..fb346e7 100644 --- a/test/slicers/meshes_for_testing.py +++ b/test/slicers/meshes_for_testing.py @@ -4,9 +4,9 @@ import numpy as np -from sasdata.data_util.slicing.meshes.voronoi_mesh import voronoi_mesh -from sasdata.data_util.slicing.meshes.mesh import Mesh -from sasdata.data_util.slicing.meshes.meshmerge import meshmerge +from sasdata.slicing.meshes.voronoi_mesh import voronoi_mesh +from sasdata.slicing.meshes.mesh import Mesh +from sasdata.slicing.meshes.meshmerge import meshmerge coords = np.arange(-4, 5) grid_mesh = voronoi_mesh(*np.meshgrid(coords, coords)) diff --git a/test/slicers/utest_meshmerge.py b/test/slicers/utest_meshmerge.py index f745d02..21071c0 100644 --- a/test/slicers/utest_meshmerge.py +++ b/test/slicers/utest_meshmerge.py @@ -4,7 +4,7 @@ It's pretty hard to test componentwise, but we can do some tests of the general behaviour """ -from sasdata.data_util.slicing.meshes.meshmerge import meshmerge +from sasdata.slicing.meshes import meshmerge from test.slicers.meshes_for_testing import ( grid_mesh, shape_mesh, expected_grid_mappings, expected_shape_mappings) From 4955f38eb64ba64a60cfb64e60cd843a930cbd47 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Mon, 21 Oct 2024 14:26:04 +0100 Subject: [PATCH 25/35] Move math and operations into quantity --- sasdata/model_requirements.py | 2 +- sasdata/quantities/math.py | 5 - sasdata/quantities/operations.py | 821 --------------------- sasdata/quantities/operations_examples.py | 2 +- sasdata/quantities/operations_test.py | 2 +- sasdata/quantities/quantity.py | 857 +++++++++++++++++++++- 6 files changed, 858 insertions(+), 831 deletions(-) delete mode 100644 sasdata/quantities/math.py delete mode 100644 sasdata/quantities/operations.py diff --git a/sasdata/model_requirements.py b/sasdata/model_requirements.py index d043b2c..3fc19c4 100644 --- a/sasdata/model_requirements.py +++ b/sasdata/model_requirements.py @@ -3,7 +3,7 @@ import numpy as np from sasdata.metadata import 
Metadata -from sasdata.quantities.operations import Operation +from sasdata.quantities.quantity import Operation @dataclass diff --git a/sasdata/quantities/math.py b/sasdata/quantities/math.py deleted file mode 100644 index d252ccc..0000000 --- a/sasdata/quantities/math.py +++ /dev/null @@ -1,5 +0,0 @@ -""" Math module extended to allow operations on quantities """ - -# TODO Implementations for trig and exp -# TODO Implementations for linear algebra stuff - diff --git a/sasdata/quantities/operations.py b/sasdata/quantities/operations.py deleted file mode 100644 index 35f6fc7..0000000 --- a/sasdata/quantities/operations.py +++ /dev/null @@ -1,821 +0,0 @@ -from typing import Any, TypeVar, Union -import numpy as np - -import json - -T = TypeVar("T") - -def hash_and_name(hash_or_name: int | str): - """ Infer the name of a variable from a hash, or the hash from the name - - Note: hash_and_name(hash_and_name(number)[1]) is not the identity - however: hash_and_name(hash_and_name(number)) is - """ - - if isinstance(hash_or_name, str): - hash_value = hash(hash_or_name) - name = hash_or_name - - return hash_value, name - - elif isinstance(hash_or_name, int): - hash_value = hash_or_name - name = f"#{hash_or_name}" - - return hash_value, name - - elif isinstance(hash_or_name, tuple): - return hash_or_name - - else: - raise TypeError("Variable name_or_hash_value must be either str or int") - - -class Operation: - - serialisation_name = "unknown" - def summary(self, indent_amount: int = 0, indent: str=" "): - """ Summary of the operation tree""" - - s = f"{indent_amount*indent}{self._summary_open()}(\n" - - for chunk in self._summary_components(): - s += chunk.summary(indent_amount+1, indent) + "\n" - - s += f"{indent_amount*indent})" - - return s - def _summary_open(self): - """ First line of summary """ - - def _summary_components(self) -> list["Operation"]: - return [] - def evaluate(self, variables: dict[int, T]) -> T: - - """ Evaluate this operation """ - - def _derivative(self, hash_value: int) -> "Operation": - """ Get the derivative of this operation """ - - def _clean(self): - """ Clean up this operation - i.e. 
remove silly things like 1*x """ - return self - - def derivative(self, variable: Union[str, int, "Variable"], simplify=True): - if isinstance(variable, Variable): - hash_value = variable.hash_value - else: - hash_value, _ = hash_and_name(variable) - - derivative = self._derivative(hash_value) - - if not simplify: - return derivative - - derivative_string = derivative.serialise() - - # print("---------------") - # print("Base") - # print("---------------") - # print(derivative.summary()) - - # Inefficient way of doing repeated simplification, but it will work - for i in range(100): # set max iterations - - derivative = derivative._clean() - # - # print("-------------------") - # print("Iteration", i+1) - # print("-------------------") - # print(derivative.summary()) - # print("-------------------") - - new_derivative_string = derivative.serialise() - - if derivative_string == new_derivative_string: - break - - derivative_string = new_derivative_string - - return derivative - - @staticmethod - def deserialise(data: str) -> "Operation": - json_data = json.loads(data) - return Operation.deserialise_json(json_data) - - @staticmethod - def deserialise_json(json_data: dict) -> "Operation": - - operation = json_data["operation"] - parameters = json_data["parameters"] - cls = _serialisation_lookup[operation] - - try: - return cls._deserialise(parameters) - - except NotImplementedError: - raise NotImplementedError(f"No method to deserialise {operation} with {parameters} (cls={cls})") - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - raise NotImplementedError(f"Deserialise not implemented for this class") - - def serialise(self) -> str: - return json.dumps(self._serialise_json()) - - def _serialise_json(self) -> dict[str, Any]: - return {"operation": self.serialisation_name, - "parameters": self._serialise_parameters()} - - def _serialise_parameters(self) -> dict[str, Any]: - raise NotImplementedError("_serialise_parameters not implemented") - - def __eq__(self, other: "Operation"): - return NotImplemented - -class ConstantBase(Operation): - pass - -class AdditiveIdentity(ConstantBase): - - serialisation_name = "zero" - def evaluate(self, variables: dict[int, T]) -> T: - return 0 - - def _derivative(self, hash_value: int) -> Operation: - return AdditiveIdentity() - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return AdditiveIdentity() - - def _serialise_parameters(self) -> dict[str, Any]: - return {} - - def summary(self, indent_amount: int=0, indent=" "): - return f"{indent_amount*indent}0 [Add.Id.]" - - def __eq__(self, other): - if isinstance(other, AdditiveIdentity): - return True - elif isinstance(other, Constant): - if other.value == 0: - return True - - return False - - - -class MultiplicativeIdentity(ConstantBase): - - serialisation_name = "one" - - def evaluate(self, variables: dict[int, T]) -> T: - return 1 - - def _derivative(self, hash_value: int): - return AdditiveIdentity() - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return MultiplicativeIdentity() - - - def _serialise_parameters(self) -> dict[str, Any]: - return {} - - - def summary(self, indent_amount: int=0, indent=" "): - return f"{indent_amount*indent}1 [Mul.Id.]" - - def __eq__(self, other): - if isinstance(other, MultiplicativeIdentity): - return True - elif isinstance(other, Constant): - if other.value == 1: - return True - - return False - - -class Constant(ConstantBase): - - serialisation_name = "constant" - def __init__(self, value): - self.value = 
value - - def summary(self, indent_amount: int = 0, indent: str=" "): - pass - - def evaluate(self, variables: dict[int, T]) -> T: - return self.value - - def _derivative(self, hash_value: int): - return AdditiveIdentity() - - def _clean(self): - - if self.value == 0: - return AdditiveIdentity() - - elif self.value == 1: - return MultiplicativeIdentity() - - else: - return self - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - value = parameters["value"] - return Constant(value) - - - def _serialise_parameters(self) -> dict[str, Any]: - return {"value": self.value} - - - def summary(self, indent_amount: int=0, indent=" "): - return f"{indent_amount*indent}{self.value}" - - def __eq__(self, other): - if isinstance(other, AdditiveIdentity): - return self.value == 0 - - elif isinstance(other, MultiplicativeIdentity): - return self.value == 1 - - elif isinstance(other, Constant): - if other.value == self.value: - return True - - return False - - -class Variable(Operation): - - serialisation_name = "variable" - def __init__(self, name_or_hash_value: int | str | tuple[int, str]): - self.hash_value, self.name = hash_and_name(name_or_hash_value) - - def evaluate(self, variables: dict[int, T]) -> T: - try: - return variables[self.hash_value] - except KeyError: - raise ValueError(f"Variable dictionary didn't have an entry for {self.name} (hash={self.hash_value})") - - def _derivative(self, hash_value: int) -> Operation: - if hash_value == self.hash_value: - return MultiplicativeIdentity() - else: - return AdditiveIdentity() - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - hash_value = parameters["hash_value"] - name = parameters["name"] - - return Variable((hash_value, name)) - - def _serialise_parameters(self) -> dict[str, Any]: - return {"hash_value": self.hash_value, - "name": self.name} - - def summary(self, indent_amount: int = 0, indent: str=" "): - return f"{indent_amount*indent}{self.name}" - - def __eq__(self, other): - if isinstance(other, Variable): - return self.hash_value == other.hash_value - - return False - -class UnaryOperation(Operation): - - def __init__(self, a: Operation): - self.a = a - - def _serialise_parameters(self) -> dict[str, Any]: - return {"a": self.a._serialise_json()} - - def _summary_components(self) -> list["Operation"]: - return [self.a] - - - - -class Neg(UnaryOperation): - - serialisation_name = "neg" - def evaluate(self, variables: dict[int, T]) -> T: - return -self.a.evaluate(variables) - - def _derivative(self, hash_value: int): - return Neg(self.a._derivative(hash_value)) - - def _clean(self): - - clean_a = self.a._clean() - - if isinstance(clean_a, Neg): - # Removes double negations - return clean_a.a - - elif isinstance(clean_a, Constant): - return Constant(-clean_a.value)._clean() - - else: - return Neg(clean_a) - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Neg(Operation.deserialise_json(parameters["a"])) - - - def _summary_open(self): - return "Neg" - - def __eq__(self, other): - if isinstance(other, Neg): - return other.a == self.a - - -class Inv(UnaryOperation): - - serialisation_name = "reciprocal" - - def evaluate(self, variables: dict[int, T]) -> T: - return 1/self.a.evaluate(variables) - - def _derivative(self, hash_value: int) -> Operation: - return Neg(Div(self.a._derivative(hash_value), Mul(self.a, self.a))) - - def _clean(self): - clean_a = self.a._clean() - - if isinstance(clean_a, Inv): - # Removes double negations - return clean_a.a - - elif isinstance(clean_a, 
Neg): - # cannonicalise 1/-a to -(1/a) - # over multiple iterations this should have the effect of ordering and gathering Neg and Inv - return Neg(Inv(clean_a.a)) - - elif isinstance(clean_a, Constant): - return Constant(1/clean_a.value)._clean() - - else: - return Inv(clean_a) - - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Inv(Operation.deserialise_json(parameters["a"])) - - def _summary_open(self): - return "Inv" - - - def __eq__(self, other): - if isinstance(other, Inv): - return other.a == self.a - -class BinaryOperation(Operation): - def __init__(self, a: Operation, b: Operation): - self.a = a - self.b = b - - def _clean(self): - return self._clean_ab(self.a._clean(), self.b._clean()) - - def _clean_ab(self, a, b): - raise NotImplementedError("_clean_ab not implemented") - - def _serialise_parameters(self) -> dict[str, Any]: - return {"a": self.a._serialise_json(), - "b": self.b._serialise_json()} - - @staticmethod - def _deserialise_ab(parameters) -> tuple[Operation, Operation]: - return (Operation.deserialise_json(parameters["a"]), - Operation.deserialise_json(parameters["b"])) - - - def _summary_components(self) -> list["Operation"]: - return [self.a, self.b] - - def _self_cls(self) -> type: - """ Own class""" - def __eq__(self, other): - if isinstance(other, self._self_cls()): - return other.a == self.a and self.b == other.b - -class Add(BinaryOperation): - - serialisation_name = "add" - - def _self_cls(self) -> type: - return Add - def evaluate(self, variables: dict[int, T]) -> T: - return self.a.evaluate(variables) + self.b.evaluate(variables) - - def _derivative(self, hash_value: int) -> Operation: - return Add(self.a._derivative(hash_value), self.b._derivative(hash_value)) - - def _clean_ab(self, a, b): - - if isinstance(a, AdditiveIdentity): - # Convert 0 + b to b - return b - - elif isinstance(b, AdditiveIdentity): - # Convert a + 0 to a - return a - - elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): - # Convert constant "a"+"b" to "a+b" - return Constant(a.evaluate({}) + b.evaluate({}))._clean() - - elif isinstance(a, Neg): - if isinstance(b, Neg): - # Convert (-a)+(-b) to -(a+b) - return Neg(Add(a.a, b.a)) - else: - # Convert (-a) + b to b-a - return Sub(b, a.a) - - elif isinstance(b, Neg): - # Convert a+(-b) to a-b - return Sub(a, b.a) - - elif a == b: - return Mul(Constant(2), a) - - else: - return Add(a, b) - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Add(*BinaryOperation._deserialise_ab(parameters)) - - def _summary_open(self): - return "Add" - -class Sub(BinaryOperation): - - serialisation_name = "sub" - - - def _self_cls(self) -> type: - return Sub - def evaluate(self, variables: dict[int, T]) -> T: - return self.a.evaluate(variables) - self.b.evaluate(variables) - - def _derivative(self, hash_value: int) -> Operation: - return Sub(self.a._derivative(hash_value), self.b._derivative(hash_value)) - - def _clean_ab(self, a, b): - if isinstance(a, AdditiveIdentity): - # Convert 0 - b to -b - return Neg(b) - - elif isinstance(b, AdditiveIdentity): - # Convert a - 0 to a - return a - - elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): - # Convert constant pair "a" - "b" to "a-b" - return Constant(a.evaluate({}) - b.evaluate({}))._clean() - - elif isinstance(a, Neg): - if isinstance(b, Neg): - # Convert (-a)-(-b) to b-a - return Sub(b.a, a.a) - else: - # Convert (-a)-b to -(a+b) - return Neg(Add(a.a, b)) - - elif isinstance(b, Neg): - # Convert a-(-b) to a+b - return 
Add(a, b.a) - - elif a == b: - return AdditiveIdentity() - - else: - return Sub(a, b) - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Sub(*BinaryOperation._deserialise_ab(parameters)) - - - def _summary_open(self): - return "Sub" - -class Mul(BinaryOperation): - - serialisation_name = "mul" - - - def _self_cls(self) -> type: - return Mul - def evaluate(self, variables: dict[int, T]) -> T: - return self.a.evaluate(variables) * self.b.evaluate(variables) - - def _derivative(self, hash_value: int) -> Operation: - return Add(Mul(self.a, self.b._derivative(hash_value)), Mul(self.a._derivative(hash_value), self.b)) - - def _clean_ab(self, a, b): - - if isinstance(a, AdditiveIdentity) or isinstance(b, AdditiveIdentity): - # Convert 0*b or a*0 to 0 - return AdditiveIdentity() - - elif isinstance(a, MultiplicativeIdentity): - # Convert 1*b to b - return b - - elif isinstance(b, MultiplicativeIdentity): - # Convert a*1 to a - return a - - elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): - # Convert constant "a"*"b" to "a*b" - return Constant(a.evaluate({}) * b.evaluate({}))._clean() - - elif isinstance(a, Inv) and isinstance(b, Inv): - return Inv(Mul(a.a, b.a)) - - elif isinstance(a, Inv) and not isinstance(b, Inv): - return Div(b, a.a) - - elif not isinstance(a, Inv) and isinstance(b, Inv): - return Div(a, b.a) - - elif isinstance(a, Neg): - return Neg(Mul(a.a, b)) - - elif isinstance(b, Neg): - return Neg(Mul(a, b.a)) - - elif a == b: - return Pow(a, 2) - - elif isinstance(a, Pow) and a.a == b: - return Pow(b, a.power + 1) - - elif isinstance(b, Pow) and b.a == a: - return Pow(a, b.power + 1) - - elif isinstance(a, Pow) and isinstance(b, Pow) and a.a == b.a: - return Pow(a.a, a.power + b.power) - - else: - return Mul(a, b) - - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Mul(*BinaryOperation._deserialise_ab(parameters)) - - - def _summary_open(self): - return "Mul" - -class Div(BinaryOperation): - - serialisation_name = "div" - - - def _self_cls(self) -> type: - return Div - - def evaluate(self, variables: dict[int, T]) -> T: - return self.a.evaluate(variables) / self.b.evaluate(variables) - - def _derivative(self, hash_value: int) -> Operation: - return Sub(Div(self.a.derivative(hash_value), self.b), - Div(Mul(self.a, self.b.derivative(hash_value)), Mul(self.b, self.b))) - - def _clean_ab(self, a, b): - if isinstance(a, AdditiveIdentity): - # Convert 0/b to 0 - return AdditiveIdentity() - - elif isinstance(a, MultiplicativeIdentity): - # Convert 1/b to inverse of b - return Inv(b) - - elif isinstance(b, MultiplicativeIdentity): - # Convert a/1 to a - return a - - elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): - # Convert constants "a"/"b" to "a/b" - return Constant(self.a.evaluate({}) / self.b.evaluate({}))._clean() - - - elif isinstance(a, Inv) and isinstance(b, Inv): - return Div(b.a, a.a) - - elif isinstance(a, Inv) and not isinstance(b, Inv): - return Inv(Mul(a.a, b)) - - elif not isinstance(a, Inv) and isinstance(b, Inv): - return Mul(a, b.a) - - elif a == b: - return MultiplicativeIdentity() - - elif isinstance(a, Pow) and a.a == b: - return Pow(b, a.power - 1) - - elif isinstance(b, Pow) and b.a == a: - return Pow(a, 1 - b.power) - - elif isinstance(a, Pow) and isinstance(b, Pow) and a.a == b.a: - return Pow(a.a, a.power - b.power) - - else: - return Div(a, b) - - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Div(*BinaryOperation._deserialise_ab(parameters)) 
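# A short usage sketch of the expression-tree API being relocated in this patch:
# build f = x*y + y^2, differentiate with respect to y, and round-trip the JSON
# serialisation. The import path assumes the post-patch layout (these classes now
# live in sasdata.quantities.quantity; before this patch they were imported from
# sasdata.quantities.operations).
from sasdata.quantities.quantity import Variable, Add, Mul, Pow, Operation

x = Variable("x")
y = Variable("y")

f = Add(Mul(x, y), Pow(y, 2))

# derivative() applies _derivative and then repeatedly _clean()s the result,
# so df/dy comes out as the simplified tree for x + 2*y
print(f.derivative(y).summary())

# Serialisation round-trip; evaluate() takes a dict keyed by variable hash
g = Operation.deserialise(f.serialise())
assert g.evaluate({x.hash_value: 2.0, y.hash_value: 3.0}) == 15.0  # 2*3 + 3**2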
- - def _summary_open(self): - return "Div" - -class Pow(Operation): - - serialisation_name = "pow" - - def __init__(self, a: Operation, power: float): - self.a = a - self.power = power - def evaluate(self, variables: dict[int, T]) -> T: - return self.a.evaluate(variables) ** self.power - - def _derivative(self, hash_value: int) -> Operation: - if self.power == 0: - return AdditiveIdentity() - - elif self.power == 1: - return self.a._derivative(hash_value) - - else: - return Mul(Constant(self.power), Mul(Pow(self.a, self.power-1), self.a._derivative(hash_value))) - - def _clean(self) -> Operation: - a = self.a._clean() - - if self.power == 1: - return a - - elif self.power == 0: - return MultiplicativeIdentity() - - elif self.power == -1: - return Inv(a) - - else: - return Pow(a, self.power) - - - def _serialise_parameters(self) -> dict[str, Any]: - return {"a": Operation._serialise_json(self.a), - "power": self.power} - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Pow(Operation.deserialise_json(parameters["a"]), parameters["power"]) - - def summary(self, indent_amount: int=0, indent=" "): - return (f"{indent_amount*indent}Pow\n" + - self.a.summary(indent_amount+1, indent) + "\n" + - f"{(indent_amount+1)*indent}{self.power}\n" + - f"{indent_amount*indent})") - - def __eq__(self, other): - if isinstance(other, Pow): - return self.a == other.a and self.power == other.power - - - -# -# Matrix operations -# - -class Transpose(UnaryOperation): - """ Transpose operation - as per numpy""" - - serialisation_name = "transpose" - - def evaluate(self, variables: dict[int, T]) -> T: - return np.transpose(self.a.evaluate(variables)) - - def _derivative(self, hash_value: int) -> Operation: - return Transpose(self.a.derivative(hash_value)) # TODO: Check! 
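# On the "TODO: Check!" above: for differentiation with respect to a scalar
# variable, transpose and d/dx commute, so Transpose(a.derivative(x)) is the
# right rule. Similarly, the product rule used by MatMul._derivative,
# d(A @ B)/dt = A @ dB/dt + dA/dt @ B, can be spot-checked numerically with
# plain numpy (illustrative only, not part of sasdata):
import numpy as np

rng = np.random.default_rng(0)
A0, A1 = rng.normal(size=(2, 3, 3))   # A(t) = A0 + t*A1
B0, B1 = rng.normal(size=(2, 3, 3))   # B(t) = B0 + t*B1

def product(t: float) -> np.ndarray:
    return (A0 + t * A1) @ (B0 + t * B1)

t, h = 0.7, 1e-6
numeric = (product(t + h) - product(t - h)) / (2 * h)    # central difference
analytic = (A0 + t * A1) @ B1 + A1 @ (B0 + t * B1)       # A @ dB + dA @ B
assert np.allclose(numeric, analytic, atol=1e-5)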
- - def _clean(self): - clean_a = self.a._clean() - return Transpose(clean_a) - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Transpose(Operation.deserialise_json(parameters["a"])) - - def _summary_open(self): - return "Transpose" - - def __eq__(self, other): - if isinstance(other, Transpose): - return other.a == self.a - - -class Dot(BinaryOperation): - """ Dot product - backed by numpy's dot method""" - - serialisation_name = "dot" - - def _self_cls(self) -> type: - return Dot - - def evaluate(self, variables: dict[int, T]) -> T: - return np.dot(self.a.evaluate(variables) + self.b.evaluate(variables)) - - def _derivative(self, hash_value: int) -> Operation: - return Add( - Dot(self.a, - self.b._derivative(hash_value)), - Dot(self.a._derivative(hash_value), - self.b)) - - def _clean_ab(self, a, b): - return Dot(a, b) # Do nothing for now - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Dot(*BinaryOperation._deserialise_ab(parameters)) - - def _summary_open(self): - return "Dot" - - -# TODO: Add to base operation class, and to quantities -class MatMul(BinaryOperation): - """ Matrix multiplication, using __matmul__ dunder""" - - serialisation_name = "matmul" - - def _self_cls(self) -> type: - return MatMul - - def evaluate(self, variables: dict[int, T]) -> T: - return self.a.evaluate(variables) @ self.b.evaluate(variables) - - def _derivative(self, hash_value: int) -> Operation: - return Add( - MatMul(self.a, - self.b._derivative(hash_value)), - MatMul(self.a._derivative(hash_value), - self.b)) - - def _clean_ab(self, a, b): - - if isinstance(a, AdditiveIdentity) or isinstance(b, AdditiveIdentity): - # Convert 0*b or a*0 to 0 - return AdditiveIdentity() - - elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): - # Convert constant "a"@"b" to "a@b" - return Constant(a.evaluate({}) @ b.evaluate({}))._clean() - - elif isinstance(a, Neg): - return Neg(Mul(a.a, b)) - - elif isinstance(b, Neg): - return Neg(Mul(a, b.a)) - - return MatMul(a, b) - - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return MatMul(*BinaryOperation._deserialise_ab(parameters)) - - def _summary_open(self): - return "MatMul" - - - -_serialisable_classes = [AdditiveIdentity, MultiplicativeIdentity, Constant, - Variable, - Neg, Inv, - Add, Sub, Mul, Div, Pow, - Transpose, Dot, MatMul] - -_serialisation_lookup = {cls.serialisation_name: cls for cls in _serialisable_classes} diff --git a/sasdata/quantities/operations_examples.py b/sasdata/quantities/operations_examples.py index e20eef9..29ddd00 100644 --- a/sasdata/quantities/operations_examples.py +++ b/sasdata/quantities/operations_examples.py @@ -1,4 +1,4 @@ -from sasdata.quantities.operations import Variable, Mul +from sasdata.quantities.quantity import Variable, Mul x = Variable("x") y = Variable("y") diff --git a/sasdata/quantities/operations_test.py b/sasdata/quantities/operations_test.py index 6fffb36..854e865 100644 --- a/sasdata/quantities/operations_test.py +++ b/sasdata/quantities/operations_test.py @@ -1,6 +1,6 @@ import pytest -from sasdata.quantities.operations import Operation, \ +from sasdata.quantities.quantity import Operation, \ Neg, Inv, \ Add, Sub, Mul, Div, Pow, \ Variable, Constant, AdditiveIdentity, MultiplicativeIdentity diff --git a/sasdata/quantities/quantity.py b/sasdata/quantities/quantity.py index 3def7e1..279c6ca 100644 --- a/sasdata/quantities/quantity.py +++ b/sasdata/quantities/quantity.py @@ -4,13 +4,866 @@ import numpy as np from numpy._typing 
import ArrayLike -from sasdata.quantities.operations import Operation, Variable -from sasdata.quantities import operations, units +from sasdata.quantities import units from sasdata.quantities.units import Unit, NamedUnit import hashlib +from typing import Any, TypeVar, Union +import numpy as np + +import json + +T = TypeVar("T") + + + + + +################### Quantity based operations, need to be here to avoid cyclic dependencies ##################### + + +def transpose(a: Union["Quantity[ArrayLike]", ArrayLike]): + if isinstance(a, Quantity): + return + + +def tensordot(a: Union["Quantity[ArrayLike]", ArrayLike] | ArrayLike, b: Union["Quantity[ArrayLike]", ArrayLike], a_index: int, b_index: int): + pass + + +################### Operation Definitions ####################################### + +def hash_and_name(hash_or_name: int | str): + """ Infer the name of a variable from a hash, or the hash from the name + + Note: hash_and_name(hash_and_name(number)[1]) is not the identity + however: hash_and_name(hash_and_name(number)) is + """ + + if isinstance(hash_or_name, str): + hash_value = hash(hash_or_name) + name = hash_or_name + + return hash_value, name + + elif isinstance(hash_or_name, int): + hash_value = hash_or_name + name = f"#{hash_or_name}" + + return hash_value, name + + elif isinstance(hash_or_name, tuple): + return hash_or_name + + else: + raise TypeError("Variable name_or_hash_value must be either str or int") + + +class Operation: + + serialisation_name = "unknown" + def summary(self, indent_amount: int = 0, indent: str=" "): + """ Summary of the operation tree""" + + s = f"{indent_amount*indent}{self._summary_open()}(\n" + + for chunk in self._summary_components(): + s += chunk.summary(indent_amount+1, indent) + "\n" + + s += f"{indent_amount*indent})" + + return s + def _summary_open(self): + """ First line of summary """ + + def _summary_components(self) -> list["Operation"]: + return [] + def evaluate(self, variables: dict[int, T]) -> T: + + """ Evaluate this operation """ + + def _derivative(self, hash_value: int) -> "Operation": + """ Get the derivative of this operation """ + + def _clean(self): + """ Clean up this operation - i.e. 
remove silly things like 1*x """ + return self + + def derivative(self, variable: Union[str, int, "Variable"], simplify=True): + if isinstance(variable, Variable): + hash_value = variable.hash_value + else: + hash_value, _ = hash_and_name(variable) + + derivative = self._derivative(hash_value) + + if not simplify: + return derivative + + derivative_string = derivative.serialise() + + # print("---------------") + # print("Base") + # print("---------------") + # print(derivative.summary()) + + # Inefficient way of doing repeated simplification, but it will work + for i in range(100): # set max iterations + + derivative = derivative._clean() + # + # print("-------------------") + # print("Iteration", i+1) + # print("-------------------") + # print(derivative.summary()) + # print("-------------------") + + new_derivative_string = derivative.serialise() + + if derivative_string == new_derivative_string: + break + + derivative_string = new_derivative_string + + return derivative + + @staticmethod + def deserialise(data: str) -> "Operation": + json_data = json.loads(data) + return Operation.deserialise_json(json_data) + + @staticmethod + def deserialise_json(json_data: dict) -> "Operation": + + operation = json_data["operation"] + parameters = json_data["parameters"] + cls = _serialisation_lookup[operation] + + try: + return cls._deserialise(parameters) + + except NotImplementedError: + raise NotImplementedError(f"No method to deserialise {operation} with {parameters} (cls={cls})") + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + raise NotImplementedError(f"Deserialise not implemented for this class") + + def serialise(self) -> str: + return json.dumps(self._serialise_json()) + + def _serialise_json(self) -> dict[str, Any]: + return {"operation": self.serialisation_name, + "parameters": self._serialise_parameters()} + + def _serialise_parameters(self) -> dict[str, Any]: + raise NotImplementedError("_serialise_parameters not implemented") + + def __eq__(self, other: "Operation"): + return NotImplemented + +class ConstantBase(Operation): + pass + +class AdditiveIdentity(ConstantBase): + + serialisation_name = "zero" + def evaluate(self, variables: dict[int, T]) -> T: + return 0 + + def _derivative(self, hash_value: int) -> Operation: + return AdditiveIdentity() + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return AdditiveIdentity() + + def _serialise_parameters(self) -> dict[str, Any]: + return {} + + def summary(self, indent_amount: int=0, indent=" "): + return f"{indent_amount*indent}0 [Add.Id.]" + + def __eq__(self, other): + if isinstance(other, AdditiveIdentity): + return True + elif isinstance(other, Constant): + if other.value == 0: + return True + + return False + + + +class MultiplicativeIdentity(ConstantBase): + + serialisation_name = "one" + + def evaluate(self, variables: dict[int, T]) -> T: + return 1 + + def _derivative(self, hash_value: int): + return AdditiveIdentity() + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return MultiplicativeIdentity() + + + def _serialise_parameters(self) -> dict[str, Any]: + return {} + + + def summary(self, indent_amount: int=0, indent=" "): + return f"{indent_amount*indent}1 [Mul.Id.]" + + def __eq__(self, other): + if isinstance(other, MultiplicativeIdentity): + return True + elif isinstance(other, Constant): + if other.value == 1: + return True + + return False + + +class Constant(ConstantBase): + + serialisation_name = "constant" + def __init__(self, value): + self.value = 
value + + def summary(self, indent_amount: int = 0, indent: str=" "): + pass + + def evaluate(self, variables: dict[int, T]) -> T: + return self.value + + def _derivative(self, hash_value: int): + return AdditiveIdentity() + + def _clean(self): + + if self.value == 0: + return AdditiveIdentity() + + elif self.value == 1: + return MultiplicativeIdentity() + + else: + return self + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + value = parameters["value"] + return Constant(value) + + + def _serialise_parameters(self) -> dict[str, Any]: + return {"value": self.value} + + + def summary(self, indent_amount: int=0, indent=" "): + return f"{indent_amount*indent}{self.value}" + + def __eq__(self, other): + if isinstance(other, AdditiveIdentity): + return self.value == 0 + + elif isinstance(other, MultiplicativeIdentity): + return self.value == 1 + + elif isinstance(other, Constant): + if other.value == self.value: + return True + + return False + + +class Variable(Operation): + + serialisation_name = "variable" + def __init__(self, name_or_hash_value: int | str | tuple[int, str]): + self.hash_value, self.name = hash_and_name(name_or_hash_value) + + def evaluate(self, variables: dict[int, T]) -> T: + try: + return variables[self.hash_value] + except KeyError: + raise ValueError(f"Variable dictionary didn't have an entry for {self.name} (hash={self.hash_value})") + + def _derivative(self, hash_value: int) -> Operation: + if hash_value == self.hash_value: + return MultiplicativeIdentity() + else: + return AdditiveIdentity() + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + hash_value = parameters["hash_value"] + name = parameters["name"] + + return Variable((hash_value, name)) + + def _serialise_parameters(self) -> dict[str, Any]: + return {"hash_value": self.hash_value, + "name": self.name} + + def summary(self, indent_amount: int = 0, indent: str=" "): + return f"{indent_amount*indent}{self.name}" + + def __eq__(self, other): + if isinstance(other, Variable): + return self.hash_value == other.hash_value + + return False + +class UnaryOperation(Operation): + + def __init__(self, a: Operation): + self.a = a + + def _serialise_parameters(self) -> dict[str, Any]: + return {"a": self.a._serialise_json()} + + def _summary_components(self) -> list["Operation"]: + return [self.a] + + + + +class Neg(UnaryOperation): + + serialisation_name = "neg" + def evaluate(self, variables: dict[int, T]) -> T: + return -self.a.evaluate(variables) + + def _derivative(self, hash_value: int): + return Neg(self.a._derivative(hash_value)) + + def _clean(self): + + clean_a = self.a._clean() + + if isinstance(clean_a, Neg): + # Removes double negations + return clean_a.a + + elif isinstance(clean_a, Constant): + return Constant(-clean_a.value)._clean() + + else: + return Neg(clean_a) + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Neg(Operation.deserialise_json(parameters["a"])) + + + def _summary_open(self): + return "Neg" + + def __eq__(self, other): + if isinstance(other, Neg): + return other.a == self.a + + +class Inv(UnaryOperation): + + serialisation_name = "reciprocal" + + def evaluate(self, variables: dict[int, T]) -> T: + return 1/self.a.evaluate(variables) + + def _derivative(self, hash_value: int) -> Operation: + return Neg(Div(self.a._derivative(hash_value), Mul(self.a, self.a))) + + def _clean(self): + clean_a = self.a._clean() + + if isinstance(clean_a, Inv): + # Removes double negations + return clean_a.a + + elif isinstance(clean_a, 
Neg): + # cannonicalise 1/-a to -(1/a) + # over multiple iterations this should have the effect of ordering and gathering Neg and Inv + return Neg(Inv(clean_a.a)) + + elif isinstance(clean_a, Constant): + return Constant(1/clean_a.value)._clean() + + else: + return Inv(clean_a) + + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Inv(Operation.deserialise_json(parameters["a"])) + + def _summary_open(self): + return "Inv" + + + def __eq__(self, other): + if isinstance(other, Inv): + return other.a == self.a + +class BinaryOperation(Operation): + def __init__(self, a: Operation, b: Operation): + self.a = a + self.b = b + + def _clean(self): + return self._clean_ab(self.a._clean(), self.b._clean()) + + def _clean_ab(self, a, b): + raise NotImplementedError("_clean_ab not implemented") + + def _serialise_parameters(self) -> dict[str, Any]: + return {"a": self.a._serialise_json(), + "b": self.b._serialise_json()} + + @staticmethod + def _deserialise_ab(parameters) -> tuple[Operation, Operation]: + return (Operation.deserialise_json(parameters["a"]), + Operation.deserialise_json(parameters["b"])) + + + def _summary_components(self) -> list["Operation"]: + return [self.a, self.b] + + def _self_cls(self) -> type: + """ Own class""" + def __eq__(self, other): + if isinstance(other, self._self_cls()): + return other.a == self.a and self.b == other.b + +class Add(BinaryOperation): + + serialisation_name = "add" + + def _self_cls(self) -> type: + return Add + def evaluate(self, variables: dict[int, T]) -> T: + return self.a.evaluate(variables) + self.b.evaluate(variables) + + def _derivative(self, hash_value: int) -> Operation: + return Add(self.a._derivative(hash_value), self.b._derivative(hash_value)) + + def _clean_ab(self, a, b): + + if isinstance(a, AdditiveIdentity): + # Convert 0 + b to b + return b + + elif isinstance(b, AdditiveIdentity): + # Convert a + 0 to a + return a + + elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): + # Convert constant "a"+"b" to "a+b" + return Constant(a.evaluate({}) + b.evaluate({}))._clean() + + elif isinstance(a, Neg): + if isinstance(b, Neg): + # Convert (-a)+(-b) to -(a+b) + return Neg(Add(a.a, b.a)) + else: + # Convert (-a) + b to b-a + return Sub(b, a.a) + + elif isinstance(b, Neg): + # Convert a+(-b) to a-b + return Sub(a, b.a) + + elif a == b: + return Mul(Constant(2), a) + + else: + return Add(a, b) + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Add(*BinaryOperation._deserialise_ab(parameters)) + + def _summary_open(self): + return "Add" + +class Sub(BinaryOperation): + + serialisation_name = "sub" + + + def _self_cls(self) -> type: + return Sub + def evaluate(self, variables: dict[int, T]) -> T: + return self.a.evaluate(variables) - self.b.evaluate(variables) + + def _derivative(self, hash_value: int) -> Operation: + return Sub(self.a._derivative(hash_value), self.b._derivative(hash_value)) + + def _clean_ab(self, a, b): + if isinstance(a, AdditiveIdentity): + # Convert 0 - b to -b + return Neg(b) + + elif isinstance(b, AdditiveIdentity): + # Convert a - 0 to a + return a + + elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): + # Convert constant pair "a" - "b" to "a-b" + return Constant(a.evaluate({}) - b.evaluate({}))._clean() + + elif isinstance(a, Neg): + if isinstance(b, Neg): + # Convert (-a)-(-b) to b-a + return Sub(b.a, a.a) + else: + # Convert (-a)-b to -(a+b) + return Neg(Add(a.a, b)) + + elif isinstance(b, Neg): + # Convert a-(-b) to a+b + return 
Add(a, b.a) + + elif a == b: + return AdditiveIdentity() + + else: + return Sub(a, b) + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Sub(*BinaryOperation._deserialise_ab(parameters)) + + + def _summary_open(self): + return "Sub" + +class Mul(BinaryOperation): + + serialisation_name = "mul" + + + def _self_cls(self) -> type: + return Mul + def evaluate(self, variables: dict[int, T]) -> T: + return self.a.evaluate(variables) * self.b.evaluate(variables) + + def _derivative(self, hash_value: int) -> Operation: + return Add(Mul(self.a, self.b._derivative(hash_value)), Mul(self.a._derivative(hash_value), self.b)) + + def _clean_ab(self, a, b): + + if isinstance(a, AdditiveIdentity) or isinstance(b, AdditiveIdentity): + # Convert 0*b or a*0 to 0 + return AdditiveIdentity() + + elif isinstance(a, MultiplicativeIdentity): + # Convert 1*b to b + return b + + elif isinstance(b, MultiplicativeIdentity): + # Convert a*1 to a + return a + + elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): + # Convert constant "a"*"b" to "a*b" + return Constant(a.evaluate({}) * b.evaluate({}))._clean() + + elif isinstance(a, Inv) and isinstance(b, Inv): + return Inv(Mul(a.a, b.a)) + + elif isinstance(a, Inv) and not isinstance(b, Inv): + return Div(b, a.a) + + elif not isinstance(a, Inv) and isinstance(b, Inv): + return Div(a, b.a) + + elif isinstance(a, Neg): + return Neg(Mul(a.a, b)) + + elif isinstance(b, Neg): + return Neg(Mul(a, b.a)) + + elif a == b: + return Pow(a, 2) + + elif isinstance(a, Pow) and a.a == b: + return Pow(b, a.power + 1) + + elif isinstance(b, Pow) and b.a == a: + return Pow(a, b.power + 1) + + elif isinstance(a, Pow) and isinstance(b, Pow) and a.a == b.a: + return Pow(a.a, a.power + b.power) + + else: + return Mul(a, b) + + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Mul(*BinaryOperation._deserialise_ab(parameters)) + + + def _summary_open(self): + return "Mul" + +class Div(BinaryOperation): + + serialisation_name = "div" + + + def _self_cls(self) -> type: + return Div + + def evaluate(self, variables: dict[int, T]) -> T: + return self.a.evaluate(variables) / self.b.evaluate(variables) + + def _derivative(self, hash_value: int) -> Operation: + return Sub(Div(self.a._derivative(hash_value), self.b), + Div(Mul(self.a, self.b._derivative(hash_value)), Mul(self.b, self.b))) + + def _clean_ab(self, a, b): + if isinstance(a, AdditiveIdentity): + # Convert 0/b to 0 + return AdditiveIdentity() + + elif isinstance(a, MultiplicativeIdentity): + # Convert 1/b to inverse of b + return Inv(b) + + elif isinstance(b, MultiplicativeIdentity): + # Convert a/1 to a + return a + + elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): + # Convert constants "a"/"b" to "a/b" + return Constant(a.evaluate({}) / b.evaluate({}))._clean() + + + elif isinstance(a, Inv) and isinstance(b, Inv): + return Div(b.a, a.a) + + elif isinstance(a, Inv) and not isinstance(b, Inv): + return Inv(Mul(a.a, b)) + + elif not isinstance(a, Inv) and isinstance(b, Inv): + return Mul(a, b.a) + + elif a == b: + return MultiplicativeIdentity() + + elif isinstance(a, Pow) and a.a == b: + return Pow(b, a.power - 1) + + elif isinstance(b, Pow) and b.a == a: + return Pow(a, 1 - b.power) + + elif isinstance(a, Pow) and isinstance(b, Pow) and a.a == b.a: + return Pow(a.a, a.power - b.power) + + else: + return Div(a, b) + + + @staticmethod + def _deserialise(parameters: dict) -> "Operation":
+ + def _summary_open(self): + return "Div" + +class Pow(Operation): + + serialisation_name = "pow" + + def __init__(self, a: Operation, power: float): + self.a = a + self.power = power + + def evaluate(self, variables: dict[int, T]) -> T: + return self.a.evaluate(variables) ** self.power + + def _derivative(self, hash_value: int) -> Operation: + if self.power == 0: + return AdditiveIdentity() + + elif self.power == 1: + return self.a._derivative(hash_value) + + else: + return Mul(Constant(self.power), Mul(Pow(self.a, self.power-1), self.a._derivative(hash_value))) + + def _clean(self) -> Operation: + a = self.a._clean() + + if self.power == 1: + return a + + elif self.power == 0: + return MultiplicativeIdentity() + + elif self.power == -1: + return Inv(a) + + else: + return Pow(a, self.power) + + + def _serialise_parameters(self) -> dict[str, Any]: + return {"a": Operation._serialise_json(self.a), + "power": self.power} + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Pow(Operation.deserialise_json(parameters["a"]), parameters["power"]) + + def summary(self, indent_amount: int=0, indent=" "): + return (f"{indent_amount*indent}Pow\n" + + self.a.summary(indent_amount+1, indent) + "\n" + + f"{(indent_amount+1)*indent}{self.power}\n" + + f"{indent_amount*indent})") + + def __eq__(self, other): + if isinstance(other, Pow): + return self.a == other.a and self.power == other.power + + + +# +# Matrix operations +# + +class Transpose(UnaryOperation): + """ Transpose operation - as per numpy""" + + serialisation_name = "transpose" + + def evaluate(self, variables: dict[int, T]) -> T: + return np.transpose(self.a.evaluate(variables)) + + def _derivative(self, hash_value: int) -> Operation: + return Transpose(self.a.derivative(hash_value)) # TODO: Check! 
+ + def _clean(self): + clean_a = self.a._clean() + return Transpose(clean_a) + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Transpose(Operation.deserialise_json(parameters["a"])) + + def _summary_open(self): + return "Transpose" + + def __eq__(self, other): + if isinstance(other, Transpose): + return other.a == self.a + + +class Dot(BinaryOperation): + """ Dot product - backed by numpy's dot method""" + + serialisation_name = "dot" + + def evaluate(self, variables: dict[int, T]) -> T: + return np.dot(self.a.evaluate(variables) + self.b.evaluate(variables)) + + def _derivative(self, hash_value: int) -> Operation: + return Add( + Dot(self.a, + self.b._derivative(hash_value)), + Dot(self.a._derivative(hash_value), + self.b)) + + def _clean_ab(self, a, b): + return Dot(a, b) # Do nothing for now + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Dot(*BinaryOperation._deserialise_ab(parameters)) + + def _summary_open(self): + return "Dot" + + +# TODO: Add to base operation class, and to quantities +class MatMul(BinaryOperation): + """ Matrix multiplication, using __matmul__ dunder""" + + serialisation_name = "matmul" + + def evaluate(self, variables: dict[int, T]) -> T: + return self.a.evaluate(variables) @ self.b.evaluate(variables) + + def _derivative(self, hash_value: int) -> Operation: + return Add( + MatMul(self.a, + self.b._derivative(hash_value)), + MatMul(self.a._derivative(hash_value), + self.b)) + + def _clean_ab(self, a, b): + + if isinstance(a, AdditiveIdentity) or isinstance(b, AdditiveIdentity): + # Convert 0*b or a*0 to 0 + return AdditiveIdentity() + + elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): + # Convert constant "a"@"b" to "a@b" + return Constant(a.evaluate({}) @ b.evaluate({}))._clean() + + elif isinstance(a, Neg): + return Neg(Mul(a.a, b)) + + elif isinstance(b, Neg): + return Neg(Mul(a, b.a)) + + return MatMul(a, b) + + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return MatMul(*BinaryOperation._deserialise_ab(parameters)) + + def _summary_open(self): + return "MatMul" + +class TensorProduct(Operation): + serialisation_name = "tensor_product" + + def __init__(self, a: Operation, b: Operation, a_index: int, b_index: int): + self.a = a + self.b = b + self.a_index = a_index + self.b_index = b_index + + def evaluate(self, variables: dict[int, T]) -> T: + return np.tensordot(self.a, self.b, axes=(self.a_index, self.b_index)) + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + pass + + def _summary_open(self): + return "TensorProduct" + + +_serialisable_classes = [AdditiveIdentity, MultiplicativeIdentity, Constant, + Variable, + Neg, Inv, + Add, Sub, Mul, Div, Pow, + Transpose, Dot, MatMul] + +_serialisation_lookup = {cls.serialisation_name: cls for cls in _serialisable_classes} + + class UnitError(Exception): """ Errors caused by unit specification not being correct """ From eb4b0114a43c0a22aa34d2871393211be44409e5 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Mon, 21 Oct 2024 14:27:53 +0100 Subject: [PATCH 26/35] Fixes from move --- sasdata/quantities/quantity.py | 42 +++++++++++++++++----------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/sasdata/quantities/quantity.py b/sasdata/quantities/quantity.py index 279c6ca..09e2880 100644 --- a/sasdata/quantities/quantity.py +++ b/sasdata/quantities/quantity.py @@ -1068,14 +1068,14 @@ def __mul__(self: Self, other: ArrayLike | Self ) -> Self: return DerivedQuantity( 
self.value * other.value, self.units * other.units, - history=QuantityHistory.apply_operation(operations.Mul, self.history, other.history)) + history=QuantityHistory.apply_operation(Mul, self.history, other.history)) else: return DerivedQuantity(self.value * other, self.units, QuantityHistory( - operations.Mul( + Mul( self.history.operation_tree, - operations.Constant(other)), + Constant(other)), self.history.references)) def __rmul__(self: Self, other: ArrayLike | Self): @@ -1084,15 +1084,15 @@ def __rmul__(self: Self, other: ArrayLike | Self): other.value * self.value, other.units * self.units, history=QuantityHistory.apply_operation( - operations.Mul, + Mul, other.history, self.history)) else: return DerivedQuantity(other * self.value, self.units, QuantityHistory( - operations.Mul( - operations.Constant(other), + Mul( + Constant(other), self.history.operation_tree), self.history.references)) @@ -1103,7 +1103,7 @@ def __matmul__(self, other: ArrayLike | Self): self.value @ other.value, self.units * other.units, history=QuantityHistory.apply_operation( - operations.MatMul, + MatMul, self.history, other.history)) else: @@ -1111,9 +1111,9 @@ def __matmul__(self, other: ArrayLike | Self): self.value @ other, self.units, QuantityHistory( - operations.MatMul( + MatMul( self.history.operation_tree, - operations.Constant(other)), + Constant(other)), self.history.references)) def __rmatmul__(self, other: ArrayLike | Self): @@ -1122,15 +1122,15 @@ def __rmatmul__(self, other: ArrayLike | Self): other.value @ self.value, other.units * self.units, history=QuantityHistory.apply_operation( - operations.MatMul, + MatMul, other.history, self.history)) else: return DerivedQuantity(other @ self.value, self.units, QuantityHistory( - operations.MatMul( - operations.Constant(other), + MatMul( + Constant(other), self.history.operation_tree), self.history.references)) @@ -1141,15 +1141,15 @@ def __truediv__(self: Self, other: float | Self) -> Self: self.value / other.value, self.units / other.units, history=QuantityHistory.apply_operation( - operations.Div, + Div, self.history, other.history)) else: return DerivedQuantity(self.value / other, self.units, QuantityHistory( - operations.Div( - operations.Constant(other), + Div( + Constant(other), self.history.operation_tree), self.history.references)) @@ -1159,7 +1159,7 @@ def __rtruediv__(self: Self, other: float | Self) -> Self: other.value / self.value, other.units / self.units, history=QuantityHistory.apply_operation( - operations.Div, + Div, other.history, self.history )) @@ -1169,8 +1169,8 @@ def __rtruediv__(self: Self, other: float | Self) -> Self: other / self.value, self.units ** -1, QuantityHistory( - operations.Div( - operations.Constant(other), + Div( + Constant(other), self.history.operation_tree), self.history.references)) @@ -1181,7 +1181,7 @@ def __add__(self: Self, other: Self | ArrayLike) -> Self: self.value + (other.value * other.units.scale) / self.units.scale, self.units, QuantityHistory.apply_operation( - operations.Add, + Add, self.history, other.history)) else: @@ -1195,7 +1195,7 @@ def __add__(self: Self, other: Self | ArrayLike) -> Self: def __neg__(self): return DerivedQuantity(-self.value, self.units, QuantityHistory.apply_operation( - operations.Neg, + Neg, self.history )) @@ -1209,7 +1209,7 @@ def __pow__(self: Self, other: int | float): return DerivedQuantity(self.value ** other, self.units ** other, QuantityHistory( - operations.Pow( + Pow( self.history.operation_tree, other), self.history.references)) From 
b202e2114c9b4312c123f28d3387a09ff4b5aca4 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Mon, 21 Oct 2024 17:07:42 +0100 Subject: [PATCH 27/35] Tensor product implementation --- sasdata/quantities/quantity.py | 102 +++++++++++++++++++++++++++++---- 1 file changed, 90 insertions(+), 12 deletions(-) diff --git a/sasdata/quantities/quantity.py b/sasdata/quantities/quantity.py index 09e2880..852469d 100644 --- a/sasdata/quantities/quantity.py +++ b/sasdata/quantities/quantity.py @@ -25,12 +25,63 @@ def transpose(a: Union["Quantity[ArrayLike]", ArrayLike]): + """ Transpose an array or an array based quantity """ if isinstance(a, Quantity): - return + return DerivedQuantity(value=np.transpose(a.value), + units=a.units, + history=QuantityHistory.apply_operation(Transpose, a.history)) + else: + return np.transpose(a) + +def dot(a: Union["Quantity[ArrayLike]", ArrayLike], b: Union["Quantity[ArrayLike]", ArrayLike]): + """ Dot product of two arrays or two array based quantities """ + a_is_quantity = isinstance(a, Quantity) + b_is_quantity = isinstance(b, Quantity) + + if a_is_quantity or b_is_quantity: + + # If its only one of them that is a quantity, convert the other one + + if not a_is_quantity: + a = Quantity(a, units.dimensionless) + + if not b_is_quantity: + b = Quantity(b, units.dimensionless) + + return DerivedQuantity( + value=np.dot(a.value, b.value), + units=a.units * b.units, + history=QuantityHistory.apply_operation(Dot, a.history, b.history)) + else: + return np.dot(a, b) def tensordot(a: Union["Quantity[ArrayLike]", ArrayLike] | ArrayLike, b: Union["Quantity[ArrayLike]", ArrayLike], a_index: int, b_index: int): - pass + a_is_quantity = isinstance(a, Quantity) + b_is_quantity = isinstance(b, Quantity) + + if a_is_quantity or b_is_quantity: + + # If its only one of them that is a quantity, convert the other one + + if not a_is_quantity: + a = Quantity(a, units.dimensionless) + + if not b_is_quantity: + b = Quantity(b, units.dimensionless) + + return DerivedQuantity( + value=np.tensordot(a.value, b.value, axes=(a_index, b_index)), + units=a.units * b.units, + history=QuantityHistory.apply_operation( + TensorDot, + a.history, + b.history, + a_index=a_index, + b_index=b_index)) + + else: + return np.tensordot(a, b, axes=(a_index, b_index)) ################### Operation Definitions ####################################### @@ -773,7 +824,7 @@ class Dot(BinaryOperation): serialisation_name = "dot" def evaluate(self, variables: dict[int, T]) -> T: - return np.dot(self.a.evaluate(variables) + self.b.evaluate(variables)) + return dot(self.a.evaluate(variables), self.b.evaluate(variables)) def _derivative(self, hash_value: int) -> Operation: return Add( @@ -785,6 +836,7 @@ def _derivative(self, hash_value: int) -> Operation: def _clean_ab(self, a, b): return Dot(a, b) # Do nothing for now + @staticmethod def _deserialise(parameters: dict) -> "Operation": return Dot(*BinaryOperation._deserialise_ab(parameters)) @@ -835,7 +887,7 @@ def _deserialise(parameters: dict) -> "Operation": def _summary_open(self): return "MatMul" -class TensorProduct(Operation): +class TensorDot(Operation): serialisation_name = "tensor_product" def __init__(self, a: Operation, b: Operation, a_index: int, b_index: int): @@ -845,21 +897,32 @@ def __init__(self, a: Operation, b: Operation, a_index: int, b_index: int): self.b_index = b_index def evaluate(self, variables: dict[int, T]) -> T: - return np.tensordot(self.a, self.b, axes=(self.a_index, self.b_index)) + return tensordot(self.a, self.b, self.a_index, self.b_index) 
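A sketch of how the quantity-aware tensordot above is expected to behave at the quantity level; NamedQuantity and the units module are assumed to be available as in the tests added in the next patch of this series:

import numpy as np
from sasdata.quantities import units
from sasdata.quantities.quantity import NamedQuantity, tensordot

a = NamedQuantity("a", np.ones((3, 4)), units=units.meters)
b = np.eye(4)                        # a plain array is treated as dimensionless
c = tensordot(a, b, 1, 0)            # contract a's axis 1 with b's axis 0
assert c.value.shape == (3, 4)
assert c.units == units.meters       # the dimensionless factor leaves units unchanged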
+ + + def _serialise_parameters(self) -> dict[str, Any]: + return { + "a": self.a._serialise_json(), + "b": self.b._serialise_json(), + "a_index": self.a_index, + "b_index": self.b_index } @staticmethod def _deserialise(parameters: dict) -> "Operation": - pass + return TensorDot(a = Operation.deserialise_json(parameters["a"]), + b = Operation.deserialise_json(parameters["b"]), + a_index=int(parameters["a_index"]), + b_index=int(parameters["b_index"])) def _summary_open(self): return "TensorProduct" _serialisable_classes = [AdditiveIdentity, MultiplicativeIdentity, Constant, - Variable, - Neg, Inv, - Add, Sub, Mul, Div, Pow, - Transpose, Dot, MatMul] + Variable, + Neg, Inv, + Add, Sub, Mul, Div, Pow, + Transpose, Dot, MatMul, TensorDot] _serialisation_lookup = {cls.serialisation_name: cls for cls in _serialisable_classes} @@ -879,10 +942,25 @@ def hash_data_via_numpy(*data: ArrayLike): return int(md5_hash.hexdigest(), 16) + +##################################### +# # +# # +# # +# Quantities begin here # +# # +# # +# # +##################################### + + + QuantityType = TypeVar("QuantityType") class QuantityHistory: + """ Class that holds the information for keeping track of operations done on quantities """ + def __init__(self, operation_tree: Operation, references: dict[int, "Quantity"]): self.operation_tree = operation_tree self.references = references @@ -936,7 +1014,7 @@ def variable(quantity: "Quantity"): return QuantityHistory(Variable(quantity.hash_value), {quantity.hash_value: quantity}) @staticmethod - def apply_operation(operation: type[Operation], *histories: "QuantityHistory") -> "QuantityHistory": + def apply_operation(operation: type[Operation], *histories: "QuantityHistory", **extra_parameters) -> "QuantityHistory": """ Apply an operation to the history This is slightly unsafe as it is possible to attempt to apply an n-ary operation to a number of trees other @@ -953,7 +1031,7 @@ def apply_operation(operation: type[Operation], *histories: "QuantityHistory") - references.update(history.references) return QuantityHistory( - operation(*[history.operation_tree for history in histories]), + operation(*[history.operation_tree for history in histories], **extra_parameters), references) def has_variance(self): From 76a562650271bf6ee1a216c896d3a9a1c8ee18da Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Mon, 21 Oct 2024 19:39:47 +0100 Subject: [PATCH 28/35] Extended transpose, and tensor tests --- sasdata/quantities/math_operations_test.py | 152 +++++++++++++++++++++ sasdata/quantities/quantity.py | 80 ++++++++--- 2 files changed, 211 insertions(+), 21 deletions(-) create mode 100644 sasdata/quantities/math_operations_test.py diff --git a/sasdata/quantities/math_operations_test.py b/sasdata/quantities/math_operations_test.py new file mode 100644 index 0000000..5bda5a2 --- /dev/null +++ b/sasdata/quantities/math_operations_test.py @@ -0,0 +1,152 @@ +""" Tests for math operations """ + +import pytest + +import numpy as np +from sasdata.quantities.quantity import NamedQuantity, tensordot, transpose +from sasdata.quantities import units + +order_list = [ + [0, 1, 2, 3], + [0, 2, 1], + [1, 0], + [0, 1], + [2, 0, 1], + [3, 1, 2, 0] +] + +@pytest.mark.parametrize("order", order_list) +def test_transpose_raw(order: list[int]): + """ Check that the transpose operation changes the order of indices correctly - uses sizes as way of tracking""" + + input_shape = tuple([i+1 for i in range(len(order))]) + expected_shape = tuple([i+1 for i in order]) + + input_mat = np.zeros(input_shape) + 
+ measured_mat = transpose(input_mat, axes=tuple(order)) + + assert measured_mat.shape == expected_shape + + +@pytest.mark.parametrize("order", order_list) +def test_transpose_quantity(order: list[int]): + """ Check that the transpose operation changes the order of indices correctly - uses sizes as way of tracking""" + input_shape = tuple([i + 1 for i in range(len(order))]) + expected_shape = tuple([i + 1 for i in order]) + + input_mat = NamedQuantity("testmat", np.zeros(input_shape), units=units.none) + + measured_mat = transpose(input_mat, axes=tuple(order)) + + assert measured_mat.value.shape == expected_shape + + +rng_seed = 1979 +tensor_product_with_identity_sizes = (4,6,5) + +@pytest.mark.parametrize("index, size", [tup for tup in enumerate(tensor_product_with_identity_sizes)]) +def test_tensor_product_with_identity_quantities(index, size): + """ Check the correctness of the tensor product by multiplying by the identity (quantity, quantity)""" + np.random.seed(rng_seed) + + x = NamedQuantity("x", np.random.rand(*tensor_product_with_identity_sizes), units=units.meters) + y = NamedQuantity("y", np.eye(size), units.seconds) + + z = tensordot(x, y, index, 0) + + # Check units + assert z.units == units.meters * units.seconds + + # Expected sizes - last index gets moved to end + output_order = [i for i in (0, 1, 2) if i != index] + [index] + output_sizes = [tensor_product_with_identity_sizes[i] for i in output_order] + + assert z.value.shape == tuple(output_sizes) + + # Restore original order and check + reverse_order = [-1, -1, -1] + for to_index, from_index in enumerate(output_order): + reverse_order[from_index] = to_index + + z_reordered = transpose(z, axes = tuple(reverse_order)) + + assert z_reordered.value.shape == tensor_product_with_identity_sizes + + # Check values + + mat_in = x.in_si() + mat_out = transpose(z, axes=tuple(reverse_order)).in_si() + + assert np.all(np.abs(mat_in - mat_out) < 1e-10) + + +@pytest.mark.parametrize("index, size", [tup for tup in enumerate(tensor_product_with_identity_sizes)]) +def test_tensor_product_with_identity_quantity_matrix(index, size): + """ Check the correctness of the tensor product by multiplying by the identity (quantity, matrix)""" + np.random.seed(rng_seed) + + x = NamedQuantity("x", np.random.rand(*tensor_product_with_identity_sizes), units.meters) + y = np.eye(size) + + z = tensordot(x, y, index, 0) + + assert z.units == units.meters + + # Expected sizes - last index gets moved to end + output_order = [i for i in (0, 1, 2) if i != index] + [index] + output_sizes = [tensor_product_with_identity_sizes[i] for i in output_order] + + assert z.value.shape == tuple(output_sizes) + + # Restore original order and check + reverse_order = [-1, -1, -1] + for to_index, from_index in enumerate(output_order): + reverse_order[from_index] = to_index + + z_reordered = transpose(z, axes = tuple(reverse_order)) + + assert z_reordered.value.shape == tensor_product_with_identity_sizes + + # Check values + + mat_in = x.in_si() + mat_out = transpose(z, axes=tuple(reverse_order)).in_si() + + assert np.all(np.abs(mat_in - mat_out) < 1e-10) + + +@pytest.mark.parametrize("index, size", [tup for tup in enumerate(tensor_product_with_identity_sizes)]) +def test_tensor_product_with_identity_matrix_quantity(index, size): + """ Check the correctness of the tensor product by multiplying by the identity (matrix, quantity)""" + np.random.seed(rng_seed) + + x = np.random.rand(*tensor_product_with_identity_sizes) + y = NamedQuantity("y", np.eye(size), units.seconds) + + z =
tensordot(x, y, index, 0) + + assert z.units == units.seconds + + + # Expected sizes - last index gets moved to end + output_order = [i for i in (0, 1, 2) if i != index] + [index] + output_sizes = [tensor_product_with_identity_sizes[i] for i in output_order] + + assert z.value.shape == tuple(output_sizes) + + # Restore original order and check + reverse_order = [-1, -1, -1] + for to_index, from_index in enumerate(output_order): + reverse_order[from_index] = to_index + + z_reordered = transpose(z, axes = tuple(reverse_order)) + + assert z_reordered.value.shape == tensor_product_with_identity_sizes + + # Check values + + mat_in = x + mat_out = transpose(z, axes=tuple(reverse_order)).in_si() + + assert np.all(np.abs(mat_in - mat_out) < 1e-10) diff --git a/sasdata/quantities/quantity.py b/sasdata/quantities/quantity.py index 852469d..c4ba7da 100644 --- a/sasdata/quantities/quantity.py +++ b/sasdata/quantities/quantity.py @@ -23,15 +23,23 @@ ################### Quantity based operations, need to be here to avoid cyclic dependencies ##################### - -def transpose(a: Union["Quantity[ArrayLike]", ArrayLike]): - """ Transpose an array or an array based quantity """ +def transpose(a: Union["Quantity[ArrayLike]", ArrayLike], axes: tuple | None = None): + """ Transpose an array or an array based quantity, can also do reordering of axes""" if isinstance(a, Quantity): - return DerivedQuantity(value=np.transpose(a.value), - units=a.units, - history=QuantityHistory.apply_operation(Transpose, a.history)) + + if axes is None: + return DerivedQuantity(value=np.transpose(a.value, axes=axes), + units=a.units, + history=QuantityHistory.apply_operation(Transpose, a.history)) + + else: + return DerivedQuantity(value=np.transpose(a.value, axes=axes), + units=a.units, + history=QuantityHistory.apply_operation(Transpose, a.history, axes=axes)) + else: - return np.transpose(a) + return np.transpose(a, axes=axes) + def dot(a: Union["Quantity[ArrayLike]", ArrayLike], b: Union["Quantity[ArrayLike]", ArrayLike]): """ Dot product of two arrays or two array based quantities """ @@ -43,10 +51,10 @@ def dot(a: Union["Quantity[ArrayLike]", ArrayLike], b: Union["Quantity[ArrayLike # If its only one of them that is a quantity, convert the other one if not a_is_quantity: - a = Quantity(a, units.dimensionless) + a = Quantity(a, units.none) if not b_is_quantity: - b = Quantity(b, units.dimensionless) + b = Quantity(b, units.none) return DerivedQuantity( value=np.dot(a.value, b.value), @@ -57,6 +65,18 @@ def dot(a: Union["Quantity[ArrayLike]", ArrayLike], b: Union["Quantity[ArrayLike return np.dot(a, b) def tensordot(a: Union["Quantity[ArrayLike]", ArrayLike] | ArrayLike, b: Union["Quantity[ArrayLike]", ArrayLike], a_index: int, b_index: int): + """ Tensor dot product - equivalent to contracting two tensors, such as + + A_{i0, i1, i2, i3...} and B_{j0, j1, j2...} + + e.g. if a_index is 1 and b_index is zero, it will be the sum + + C_{i0, i2, i3 ..., j1, j2 ...} = sum_k A_{i0, k, i2, i3 ...} B_{k, j1, j2 ...} + + (I think, have to check what happens with indices TODO!) 
+ + """ + a_is_quantity = isinstance(a, Quantity) b_is_quantity = isinstance(b, Quantity) @@ -65,10 +85,10 @@ def tensordot(a: Union["Quantity[ArrayLike]", ArrayLike] | ArrayLike, b: Union[" # If its only one of them that is a quantity, convert the other one if not a_is_quantity: - a = Quantity(a, units.dimensionless) + a = Quantity(a, units.none) if not b_is_quantity: - b = Quantity(b, units.dimensionless) + b = Quantity(b, units.none) return DerivedQuantity( value=np.tensordot(a.value, b.value, axes=(a_index, b_index)), @@ -791,11 +811,15 @@ def __eq__(self, other): # Matrix operations # -class Transpose(UnaryOperation): +class Transpose(Operation): """ Transpose operation - as per numpy""" serialisation_name = "transpose" + def __init__(self, a: Operation, axes: tuple[int] | None = None): + self.a = a + self.axes = axes + def evaluate(self, variables: dict[int, T]) -> T: return np.transpose(self.a.evaluate(variables)) @@ -806,9 +830,27 @@ def _clean(self): clean_a = self.a._clean() return Transpose(clean_a) + + def _serialise_parameters(self) -> dict[str, Any]: + if self.axes is None: + return { "a": self.a._serialise_json() } + else: + return { + "a": self.a._serialise_json(), + "axes": list(self.axes) + } + + @staticmethod def _deserialise(parameters: dict) -> "Operation": - return Transpose(Operation.deserialise_json(parameters["a"])) + if "axes" in parameters: + return Transpose( + a=Operation.deserialise_json(parameters["a"]), + axes=tuple(parameters["axes"])) + else: + return Transpose( + a=Operation.deserialise_json(parameters["a"])) + def _summary_open(self): return "Transpose" @@ -974,6 +1016,10 @@ def jacobian(self) -> list[Operation]: # Use the hash value to specify the variable of differentiation return [self.operation_tree.derivative(key) for key in self.reference_key_list] + def _recalculate(self): + """ Recalculate the value of this object - primary use case is for testing """ + return self.operation_tree.evaluate(self.references) + def variance_propagate(self, quantity_units: Unit, covariances: dict[tuple[int, int]: "Quantity"] = {}): """ Do standard error propagation to calculate the uncertainties associated with this quantity @@ -985,14 +1031,6 @@ def variance_propagate(self, quantity_units: Unit, covariances: dict[tuple[int, raise NotImplementedError("User specified covariances not currently implemented") jacobian = self.jacobian() - # jacobian_units = [quantity_units / self.references[key].units for key in self.reference_key_list] - # - # # Evaluate the jacobian - # # TODO: should we use quantities here, does that work automatically? 
- # evaluated_jacobian = [Quantity( - # value=entry.evaluate(self.si_reference_values), - # units=unit.si_equivalent()) - # for entry, unit in zip(jacobian, jacobian_units)] evaluated_jacobian = [entry.evaluate(self.references) for entry in jacobian] From 56a476d1e1eae653e405e56a542c1a22196ea316 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Tue, 22 Oct 2024 15:36:48 +0100 Subject: [PATCH 29/35] Interpolation stuff --- sasdata/manual_tests/interpolation.py | 2 +- sasdata/quantities/math.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sasdata/manual_tests/interpolation.py b/sasdata/manual_tests/interpolation.py index c6338a4..c46078b 100644 --- a/sasdata/manual_tests/interpolation.py +++ b/sasdata/manual_tests/interpolation.py @@ -34,7 +34,7 @@ def linear_interpolation_check(): quantity_plot(new_x, new_y) - # print(new_y.history.summary()) + print(new_y.history.summary()) plt.show() diff --git a/sasdata/quantities/math.py b/sasdata/quantities/math.py index d252ccc..6ef5b29 100644 --- a/sasdata/quantities/math.py +++ b/sasdata/quantities/math.py @@ -2,4 +2,3 @@ # TODO Implementations for trig and exp # TODO Implementations for linear algebra stuff - From 4a90ef950cc8180877f5c4ed0c935e3adf58314c Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Tue, 22 Oct 2024 19:03:30 +0100 Subject: [PATCH 30/35] Encodings for numerical values --- sasdata/quantities/numerical_encoding.py | 40 ++++++++++++++ sasdata/quantities/quantity.py | 6 ++- sasdata/quantities/test_numerical_encoding.py | 54 +++++++++++++++++++ 3 files changed, 98 insertions(+), 2 deletions(-) create mode 100644 sasdata/quantities/numerical_encoding.py create mode 100644 sasdata/quantities/test_numerical_encoding.py diff --git a/sasdata/quantities/numerical_encoding.py b/sasdata/quantities/numerical_encoding.py new file mode 100644 index 0000000..e583d21 --- /dev/null +++ b/sasdata/quantities/numerical_encoding.py @@ -0,0 +1,40 @@ +import numpy as np + +import base64 +import struct + + +def numerical_encode(obj: int | float | np.ndarray): + + if isinstance(obj, int): + return {"type": "int", + "value": obj} + + elif isinstance(obj, float): + return {"type": "float", + "value": base64.b64encode(bytearray(struct.pack('d', obj)))} + + elif isinstance(obj, np.ndarray): + return { + "type": "numpy", + "value": base64.b64encode(obj.tobytes()), + "dtype": obj.dtype.str, + "shape": list(obj.shape) + } + + else: + raise TypeError(f"Cannot serialise object of type: {type(obj)}") + +def numerical_decode(data: dict[str, str | int | list[int]]) -> int | float | np.ndarray: + match data["type"]: + case "int": + return int(data["value"]) + + case "float": + return struct.unpack('d', base64.b64decode(data["value"]))[0] + + case "numpy": + value = base64.b64decode(data["value"]) + dtype = np.dtype(data["dtype"]) + shape = tuple(data["shape"]) + return np.frombuffer(value, dtype=dtype).reshape(*shape) diff --git a/sasdata/quantities/quantity.py b/sasdata/quantities/quantity.py index c4ba7da..7a0acbb 100644 --- a/sasdata/quantities/quantity.py +++ b/sasdata/quantities/quantity.py @@ -1,14 +1,17 @@ +from encodings.base64_codec import base64_decode from typing import Collection, Sequence, TypeVar, Generic, Self from dataclasses import dataclass import numpy as np +from lxml.etree import SerialisationError from numpy._typing import ArrayLike from sasdata.quantities import units from sasdata.quantities.units import Unit, NamedUnit import hashlib - +import base64 +import struct from typing import Any, TypeVar, Union import numpy as 
np @@ -131,7 +134,6 @@ def hash_and_name(hash_or_name: int | str): else: raise TypeError("Variable name_or_hash_value must be either str or int") - class Operation: serialisation_name = "unknown" diff --git a/sasdata/quantities/test_numerical_encoding.py b/sasdata/quantities/test_numerical_encoding.py new file mode 100644 index 0000000..4b17058 --- /dev/null +++ b/sasdata/quantities/test_numerical_encoding.py @@ -0,0 +1,54 @@ +import numpy as np +import pytest + +from sasdata.quantities.numerical_encoding import numerical_encode, numerical_decode + + +@pytest.mark.parametrize("value", [-100.0, -10.0, -1.0, 0.0, 0.5, 1.0, 10.0, 100.0, 1e100]) +def test_float_encode_decode(value: float): + + assert isinstance(value, float) # Make sure we have the right inputs + + encoded = numerical_encode(value) + decoded = numerical_decode(encoded) + + assert isinstance(decoded, float) + assert value == decoded + +@pytest.mark.parametrize("value", [-100, -10, -1, 0, 1, 10, 100, 1000000000000000000000000000000000]) +def test_int_encode_decode(value: int): + + assert isinstance(value, int) # Make sure we have the right inputs + + encoded = numerical_encode(value) + decoded = numerical_decode(encoded) + + assert isinstance(decoded, int) + assert value == decoded + +@pytest.mark.parametrize("shape", [ + (2,3,4), + (1,2), + (10,5,10), + (1,), + (4,), + (0, ) ]) +def test_numpy_float_encode_decode(shape): + np.random.seed(1776) + test_matrix = np.random.rand(*shape) + + encoded = numerical_encode(test_matrix) + decoded = numerical_decode(encoded) + + assert decoded.dtype == test_matrix.dtype + assert decoded.shape == test_matrix.shape + assert np.all(decoded == test_matrix) + +@pytest.mark.parametrize("dtype", [int, float, complex]) +def test_numpy_dtypes_encode_decode(dtype): + test_matrix = np.zeros((3,3), dtype=dtype) + + encoded = numerical_encode(test_matrix) + decoded = numerical_decode(encoded) + + assert decoded.dtype == test_matrix.dtype From d20e9f322652fffa707df8673ea8e6f82c734f35 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Tue, 22 Oct 2024 19:08:50 +0100 Subject: [PATCH 31/35] Tidying up --- sasdata/quantities/quantity.py | 16 +++++----------- sasdata/quantities/test_numerical_encoding.py | 2 ++ 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/sasdata/quantities/quantity.py b/sasdata/quantities/quantity.py index 7a0acbb..551ced5 100644 --- a/sasdata/quantities/quantity.py +++ b/sasdata/quantities/quantity.py @@ -1,20 +1,15 @@ -from encodings.base64_codec import base64_decode -from typing import Collection, Sequence, TypeVar, Generic, Self -from dataclasses import dataclass +from typing import Self import numpy as np -from lxml.etree import SerialisationError from numpy._typing import ArrayLike from sasdata.quantities import units +from sasdata.quantities.numerical_encoding import numerical_decode, numerical_encode from sasdata.quantities.units import Unit, NamedUnit import hashlib -import base64 -import struct from typing import Any, TypeVar, Union -import numpy as np import json @@ -309,7 +304,7 @@ def __init__(self, value): self.value = value def summary(self, indent_amount: int = 0, indent: str=" "): - pass + return repr(self.value) def evaluate(self, variables: dict[int, T]) -> T: return self.value @@ -330,13 +325,12 @@ def _clean(self): @staticmethod def _deserialise(parameters: dict) -> "Operation": - value = parameters["value"] + value = numerical_decode(parameters["value"]) return Constant(value) def _serialise_parameters(self) -> dict[str, Any]: - return {"value": 
self.value} - + return {"value": numerical_encode(self.value)} def summary(self, indent_amount: int=0, indent=" "): return f"{indent_amount*indent}{self.value}" diff --git a/sasdata/quantities/test_numerical_encoding.py b/sasdata/quantities/test_numerical_encoding.py index 4b17058..e1166ee 100644 --- a/sasdata/quantities/test_numerical_encoding.py +++ b/sasdata/quantities/test_numerical_encoding.py @@ -1,3 +1,5 @@ +""" Tests for the encoding and decoding of numerical data""" + import numpy as np import pytest From ebac04c751fb9b6397cb49a6e9e1744f77ef74e0 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Wed, 23 Oct 2024 08:05:42 +0100 Subject: [PATCH 32/35] Work on sparse matrix serialisation --- sasdata/quantities/numerical_encoding.py | 38 +++++++++++++++++-- sasdata/quantities/test_numerical_encoding.py | 9 +++++ 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/sasdata/quantities/numerical_encoding.py b/sasdata/quantities/numerical_encoding.py index e583d21..879880a 100644 --- a/sasdata/quantities/numerical_encoding.py +++ b/sasdata/quantities/numerical_encoding.py @@ -1,10 +1,11 @@ import numpy as np +from scipy.sparse import coo_matrix, csr_matrix, csc_matrix, coo_array, csr_array, csc_array import base64 import struct -def numerical_encode(obj: int | float | np.ndarray): +def numerical_encode(obj: int | float | np.ndarray | coo_matrix | coo_array | csr_matrix | csr_array | csc_matrix | csc_array): if isinstance(obj, int): return {"type": "int", @@ -22,11 +23,38 @@ def numerical_encode(obj: int | float | np.ndarray): "shape": list(obj.shape) } + elif isinstance(obj, (coo_matrix, coo_array, csr_matrix, csr_array, csc_matrix, csc_array)): + + output = { + "type": obj.__class__.__name__, # not robust to name changes, but more concise + "dtype": obj.dtype.str, + "shape": list(obj.shape) + } + + if isinstance(obj, (coo_array, coo_matrix)): + + output["data"] = numerical_encode(obj.data) + output["coords"] = [numerical_encode(coord) for coord in obj.coords] + + + elif isinstance(obj, (csr_array, csr_matrix)): + pass + + + elif isinstance(obj, (csc_array, csc_matrix)): + + pass + + + return output + else: raise TypeError(f"Cannot serialise object of type: {type(obj)}") -def numerical_decode(data: dict[str, str | int | list[int]]) -> int | float | np.ndarray: - match data["type"]: +def numerical_decode(data: dict[str, str | int | list[int]]) -> int | float | np.ndarray | coo_matrix | coo_array | csr_matrix | csr_array | csc_matrix | csc_array: + obj_type = data["type"] + + match obj_type: case "int": return int(data["value"]) @@ -38,3 +66,7 @@ def numerical_decode(data: dict[str, str | int | list[int]]) -> int | float | np dtype = np.dtype(data["dtype"]) shape = tuple(data["shape"]) return np.frombuffer(value, dtype=dtype).reshape(*shape) + + case _: + raise ValueError(f"Cannot decode objects of type '{obj_type}'") + diff --git a/sasdata/quantities/test_numerical_encoding.py b/sasdata/quantities/test_numerical_encoding.py index e1166ee..83fa5fe 100644 --- a/sasdata/quantities/test_numerical_encoding.py +++ b/sasdata/quantities/test_numerical_encoding.py @@ -54,3 +54,12 @@ def test_numpy_dtypes_encode_decode(dtype): decoded = numerical_decode(encoded) assert decoded.dtype == test_matrix.dtype + +@pytest.mark.parametrize("dtype", [int, float, complex]) +@pytest.mark.parametrize("shape, n, m", [ + ((8, 8), (1,3,5),(2,5,7)), + ((6, 8), (1,0,5),(0,5,0)), + ((6, 1), (1, 0, 5), (0, 0, 0)), +]) +def test_coo_matrix_encode_decode(shape, n, m, dtype): + test_matrix = np.arange() From 
fe9c67eb7e2d75a53d38e3116cb6dd8b6797bd59 Mon Sep 17 00:00:00 2001 From: lucas-wilkins Date: Wed, 23 Oct 2024 10:02:47 +0100 Subject: [PATCH 33/35] Updated final line endings --- sasdata/quantities/quantity.py | 2886 ++++++++++++++++---------------- 1 file changed, 1443 insertions(+), 1443 deletions(-) diff --git a/sasdata/quantities/quantity.py b/sasdata/quantities/quantity.py index 551ced5..584f3cf 100644 --- a/sasdata/quantities/quantity.py +++ b/sasdata/quantities/quantity.py @@ -1,1443 +1,1443 @@ -from typing import Self - -import numpy as np -from numpy._typing import ArrayLike - -from sasdata.quantities import units -from sasdata.quantities.numerical_encoding import numerical_decode, numerical_encode -from sasdata.quantities.units import Unit, NamedUnit - -import hashlib - -from typing import Any, TypeVar, Union - -import json - -T = TypeVar("T") - - - - - -################### Quantity based operations, need to be here to avoid cyclic dependencies ##################### - -def transpose(a: Union["Quantity[ArrayLike]", ArrayLike], axes: tuple | None = None): - """ Transpose an array or an array based quantity, can also do reordering of axes""" - if isinstance(a, Quantity): - - if axes is None: - return DerivedQuantity(value=np.transpose(a.value, axes=axes), - units=a.units, - history=QuantityHistory.apply_operation(Transpose, a.history)) - - else: - return DerivedQuantity(value=np.transpose(a.value, axes=axes), - units=a.units, - history=QuantityHistory.apply_operation(Transpose, a.history, axes=axes)) - - else: - return np.transpose(a, axes=axes) - - -def dot(a: Union["Quantity[ArrayLike]", ArrayLike], b: Union["Quantity[ArrayLike]", ArrayLike]): - """ Dot product of two arrays or two array based quantities """ - a_is_quantity = isinstance(a, Quantity) - b_is_quantity = isinstance(b, Quantity) - - if a_is_quantity or b_is_quantity: - - # If its only one of them that is a quantity, convert the other one - - if not a_is_quantity: - a = Quantity(a, units.none) - - if not b_is_quantity: - b = Quantity(b, units.none) - - return DerivedQuantity( - value=np.dot(a.value, b.value), - units=a.units * b.units, - history=QuantityHistory.apply_operation(Dot, a.history, b.history)) - - else: - return np.dot(a, b) - -def tensordot(a: Union["Quantity[ArrayLike]", ArrayLike] | ArrayLike, b: Union["Quantity[ArrayLike]", ArrayLike], a_index: int, b_index: int): - """ Tensor dot product - equivalent to contracting two tensors, such as - - A_{i0, i1, i2, i3...} and B_{j0, j1, j2...} - - e.g. if a_index is 1 and b_index is zero, it will be the sum - - C_{i0, i2, i3 ..., j1, j2 ...} = sum_k A_{i0, k, i2, i3 ...} B_{k, j1, j2 ...} - - (I think, have to check what happens with indices TODO!) 
- - """ - - a_is_quantity = isinstance(a, Quantity) - b_is_quantity = isinstance(b, Quantity) - - if a_is_quantity or b_is_quantity: - - # If its only one of them that is a quantity, convert the other one - - if not a_is_quantity: - a = Quantity(a, units.none) - - if not b_is_quantity: - b = Quantity(b, units.none) - - return DerivedQuantity( - value=np.tensordot(a.value, b.value, axes=(a_index, b_index)), - units=a.units * b.units, - history=QuantityHistory.apply_operation( - TensorDot, - a.history, - b.history, - a_index=a_index, - b_index=b_index)) - - else: - return np.tensordot(a, b, axes=(a_index, b_index)) - - -################### Operation Definitions ####################################### - -def hash_and_name(hash_or_name: int | str): - """ Infer the name of a variable from a hash, or the hash from the name - - Note: hash_and_name(hash_and_name(number)[1]) is not the identity - however: hash_and_name(hash_and_name(number)) is - """ - - if isinstance(hash_or_name, str): - hash_value = hash(hash_or_name) - name = hash_or_name - - return hash_value, name - - elif isinstance(hash_or_name, int): - hash_value = hash_or_name - name = f"#{hash_or_name}" - - return hash_value, name - - elif isinstance(hash_or_name, tuple): - return hash_or_name - - else: - raise TypeError("Variable name_or_hash_value must be either str or int") - -class Operation: - - serialisation_name = "unknown" - def summary(self, indent_amount: int = 0, indent: str=" "): - """ Summary of the operation tree""" - - s = f"{indent_amount*indent}{self._summary_open()}(\n" - - for chunk in self._summary_components(): - s += chunk.summary(indent_amount+1, indent) + "\n" - - s += f"{indent_amount*indent})" - - return s - def _summary_open(self): - """ First line of summary """ - - def _summary_components(self) -> list["Operation"]: - return [] - def evaluate(self, variables: dict[int, T]) -> T: - - """ Evaluate this operation """ - - def _derivative(self, hash_value: int) -> "Operation": - """ Get the derivative of this operation """ - - def _clean(self): - """ Clean up this operation - i.e. 
remove silly things like 1*x """ - return self - - def derivative(self, variable: Union[str, int, "Variable"], simplify=True): - if isinstance(variable, Variable): - hash_value = variable.hash_value - else: - hash_value, _ = hash_and_name(variable) - - derivative = self._derivative(hash_value) - - if not simplify: - return derivative - - derivative_string = derivative.serialise() - - # print("---------------") - # print("Base") - # print("---------------") - # print(derivative.summary()) - - # Inefficient way of doing repeated simplification, but it will work - for i in range(100): # set max iterations - - derivative = derivative._clean() - # - # print("-------------------") - # print("Iteration", i+1) - # print("-------------------") - # print(derivative.summary()) - # print("-------------------") - - new_derivative_string = derivative.serialise() - - if derivative_string == new_derivative_string: - break - - derivative_string = new_derivative_string - - return derivative - - @staticmethod - def deserialise(data: str) -> "Operation": - json_data = json.loads(data) - return Operation.deserialise_json(json_data) - - @staticmethod - def deserialise_json(json_data: dict) -> "Operation": - - operation = json_data["operation"] - parameters = json_data["parameters"] - cls = _serialisation_lookup[operation] - - try: - return cls._deserialise(parameters) - - except NotImplementedError: - raise NotImplementedError(f"No method to deserialise {operation} with {parameters} (cls={cls})") - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - raise NotImplementedError(f"Deserialise not implemented for this class") - - def serialise(self) -> str: - return json.dumps(self._serialise_json()) - - def _serialise_json(self) -> dict[str, Any]: - return {"operation": self.serialisation_name, - "parameters": self._serialise_parameters()} - - def _serialise_parameters(self) -> dict[str, Any]: - raise NotImplementedError("_serialise_parameters not implemented") - - def __eq__(self, other: "Operation"): - return NotImplemented - -class ConstantBase(Operation): - pass - -class AdditiveIdentity(ConstantBase): - - serialisation_name = "zero" - def evaluate(self, variables: dict[int, T]) -> T: - return 0 - - def _derivative(self, hash_value: int) -> Operation: - return AdditiveIdentity() - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return AdditiveIdentity() - - def _serialise_parameters(self) -> dict[str, Any]: - return {} - - def summary(self, indent_amount: int=0, indent=" "): - return f"{indent_amount*indent}0 [Add.Id.]" - - def __eq__(self, other): - if isinstance(other, AdditiveIdentity): - return True - elif isinstance(other, Constant): - if other.value == 0: - return True - - return False - - - -class MultiplicativeIdentity(ConstantBase): - - serialisation_name = "one" - - def evaluate(self, variables: dict[int, T]) -> T: - return 1 - - def _derivative(self, hash_value: int): - return AdditiveIdentity() - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return MultiplicativeIdentity() - - - def _serialise_parameters(self) -> dict[str, Any]: - return {} - - - def summary(self, indent_amount: int=0, indent=" "): - return f"{indent_amount*indent}1 [Mul.Id.]" - - def __eq__(self, other): - if isinstance(other, MultiplicativeIdentity): - return True - elif isinstance(other, Constant): - if other.value == 1: - return True - - return False - - -class Constant(ConstantBase): - - serialisation_name = "constant" - def __init__(self, value): - self.value = 
value - - def summary(self, indent_amount: int = 0, indent: str=" "): - return repr(self.value) - - def evaluate(self, variables: dict[int, T]) -> T: - return self.value - - def _derivative(self, hash_value: int): - return AdditiveIdentity() - - def _clean(self): - - if self.value == 0: - return AdditiveIdentity() - - elif self.value == 1: - return MultiplicativeIdentity() - - else: - return self - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - value = numerical_decode(parameters["value"]) - return Constant(value) - - - def _serialise_parameters(self) -> dict[str, Any]: - return {"value": numerical_encode(self.value)} - - def summary(self, indent_amount: int=0, indent=" "): - return f"{indent_amount*indent}{self.value}" - - def __eq__(self, other): - if isinstance(other, AdditiveIdentity): - return self.value == 0 - - elif isinstance(other, MultiplicativeIdentity): - return self.value == 1 - - elif isinstance(other, Constant): - if other.value == self.value: - return True - - return False - - -class Variable(Operation): - - serialisation_name = "variable" - def __init__(self, name_or_hash_value: int | str | tuple[int, str]): - self.hash_value, self.name = hash_and_name(name_or_hash_value) - - def evaluate(self, variables: dict[int, T]) -> T: - try: - return variables[self.hash_value] - except KeyError: - raise ValueError(f"Variable dictionary didn't have an entry for {self.name} (hash={self.hash_value})") - - def _derivative(self, hash_value: int) -> Operation: - if hash_value == self.hash_value: - return MultiplicativeIdentity() - else: - return AdditiveIdentity() - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - hash_value = parameters["hash_value"] - name = parameters["name"] - - return Variable((hash_value, name)) - - def _serialise_parameters(self) -> dict[str, Any]: - return {"hash_value": self.hash_value, - "name": self.name} - - def summary(self, indent_amount: int = 0, indent: str=" "): - return f"{indent_amount*indent}{self.name}" - - def __eq__(self, other): - if isinstance(other, Variable): - return self.hash_value == other.hash_value - - return False - -class UnaryOperation(Operation): - - def __init__(self, a: Operation): - self.a = a - - def _serialise_parameters(self) -> dict[str, Any]: - return {"a": self.a._serialise_json()} - - def _summary_components(self) -> list["Operation"]: - return [self.a] - - - - -class Neg(UnaryOperation): - - serialisation_name = "neg" - def evaluate(self, variables: dict[int, T]) -> T: - return -self.a.evaluate(variables) - - def _derivative(self, hash_value: int): - return Neg(self.a._derivative(hash_value)) - - def _clean(self): - - clean_a = self.a._clean() - - if isinstance(clean_a, Neg): - # Removes double negations - return clean_a.a - - elif isinstance(clean_a, Constant): - return Constant(-clean_a.value)._clean() - - else: - return Neg(clean_a) - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Neg(Operation.deserialise_json(parameters["a"])) - - - def _summary_open(self): - return "Neg" - - def __eq__(self, other): - if isinstance(other, Neg): - return other.a == self.a - - -class Inv(UnaryOperation): - - serialisation_name = "reciprocal" - - def evaluate(self, variables: dict[int, T]) -> T: - return 1/self.a.evaluate(variables) - - def _derivative(self, hash_value: int) -> Operation: - return Neg(Div(self.a._derivative(hash_value), Mul(self.a, self.a))) - - def _clean(self): - clean_a = self.a._clean() - - if isinstance(clean_a, Inv): - # Removes double 
negations - return clean_a.a - - elif isinstance(clean_a, Neg): - # cannonicalise 1/-a to -(1/a) - # over multiple iterations this should have the effect of ordering and gathering Neg and Inv - return Neg(Inv(clean_a.a)) - - elif isinstance(clean_a, Constant): - return Constant(1/clean_a.value)._clean() - - else: - return Inv(clean_a) - - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Inv(Operation.deserialise_json(parameters["a"])) - - def _summary_open(self): - return "Inv" - - - def __eq__(self, other): - if isinstance(other, Inv): - return other.a == self.a - -class BinaryOperation(Operation): - def __init__(self, a: Operation, b: Operation): - self.a = a - self.b = b - - def _clean(self): - return self._clean_ab(self.a._clean(), self.b._clean()) - - def _clean_ab(self, a, b): - raise NotImplementedError("_clean_ab not implemented") - - def _serialise_parameters(self) -> dict[str, Any]: - return {"a": self.a._serialise_json(), - "b": self.b._serialise_json()} - - @staticmethod - def _deserialise_ab(parameters) -> tuple[Operation, Operation]: - return (Operation.deserialise_json(parameters["a"]), - Operation.deserialise_json(parameters["b"])) - - - def _summary_components(self) -> list["Operation"]: - return [self.a, self.b] - - def _self_cls(self) -> type: - """ Own class""" - def __eq__(self, other): - if isinstance(other, self._self_cls()): - return other.a == self.a and self.b == other.b - -class Add(BinaryOperation): - - serialisation_name = "add" - - def _self_cls(self) -> type: - return Add - def evaluate(self, variables: dict[int, T]) -> T: - return self.a.evaluate(variables) + self.b.evaluate(variables) - - def _derivative(self, hash_value: int) -> Operation: - return Add(self.a._derivative(hash_value), self.b._derivative(hash_value)) - - def _clean_ab(self, a, b): - - if isinstance(a, AdditiveIdentity): - # Convert 0 + b to b - return b - - elif isinstance(b, AdditiveIdentity): - # Convert a + 0 to a - return a - - elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): - # Convert constant "a"+"b" to "a+b" - return Constant(a.evaluate({}) + b.evaluate({}))._clean() - - elif isinstance(a, Neg): - if isinstance(b, Neg): - # Convert (-a)+(-b) to -(a+b) - return Neg(Add(a.a, b.a)) - else: - # Convert (-a) + b to b-a - return Sub(b, a.a) - - elif isinstance(b, Neg): - # Convert a+(-b) to a-b - return Sub(a, b.a) - - elif a == b: - return Mul(Constant(2), a) - - else: - return Add(a, b) - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Add(*BinaryOperation._deserialise_ab(parameters)) - - def _summary_open(self): - return "Add" - -class Sub(BinaryOperation): - - serialisation_name = "sub" - - - def _self_cls(self) -> type: - return Sub - def evaluate(self, variables: dict[int, T]) -> T: - return self.a.evaluate(variables) - self.b.evaluate(variables) - - def _derivative(self, hash_value: int) -> Operation: - return Sub(self.a._derivative(hash_value), self.b._derivative(hash_value)) - - def _clean_ab(self, a, b): - if isinstance(a, AdditiveIdentity): - # Convert 0 - b to -b - return Neg(b) - - elif isinstance(b, AdditiveIdentity): - # Convert a - 0 to a - return a - - elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): - # Convert constant pair "a" - "b" to "a-b" - return Constant(a.evaluate({}) - b.evaluate({}))._clean() - - elif isinstance(a, Neg): - if isinstance(b, Neg): - # Convert (-a)-(-b) to b-a - return Sub(b.a, a.a) - else: - # Convert (-a)-b to -(a+b) - return Neg(Add(a.a, b)) - - elif 
isinstance(b, Neg): - # Convert a-(-b) to a+b - return Add(a, b.a) - - elif a == b: - return AdditiveIdentity() - - else: - return Sub(a, b) - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Sub(*BinaryOperation._deserialise_ab(parameters)) - - - def _summary_open(self): - return "Sub" - -class Mul(BinaryOperation): - - serialisation_name = "mul" - - - def _self_cls(self) -> type: - return Mul - def evaluate(self, variables: dict[int, T]) -> T: - return self.a.evaluate(variables) * self.b.evaluate(variables) - - def _derivative(self, hash_value: int) -> Operation: - return Add(Mul(self.a, self.b._derivative(hash_value)), Mul(self.a._derivative(hash_value), self.b)) - - def _clean_ab(self, a, b): - - if isinstance(a, AdditiveIdentity) or isinstance(b, AdditiveIdentity): - # Convert 0*b or a*0 to 0 - return AdditiveIdentity() - - elif isinstance(a, MultiplicativeIdentity): - # Convert 1*b to b - return b - - elif isinstance(b, MultiplicativeIdentity): - # Convert a*1 to a - return a - - elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): - # Convert constant "a"*"b" to "a*b" - return Constant(a.evaluate({}) * b.evaluate({}))._clean() - - elif isinstance(a, Inv) and isinstance(b, Inv): - return Inv(Mul(a.a, b.a)) - - elif isinstance(a, Inv) and not isinstance(b, Inv): - return Div(b, a.a) - - elif not isinstance(a, Inv) and isinstance(b, Inv): - return Div(a, b.a) - - elif isinstance(a, Neg): - return Neg(Mul(a.a, b)) - - elif isinstance(b, Neg): - return Neg(Mul(a, b.a)) - - elif a == b: - return Pow(a, 2) - - elif isinstance(a, Pow) and a.a == b: - return Pow(b, a.power + 1) - - elif isinstance(b, Pow) and b.a == a: - return Pow(a, b.power + 1) - - elif isinstance(a, Pow) and isinstance(b, Pow) and a.a == b.a: - return Pow(a.a, a.power + b.power) - - else: - return Mul(a, b) - - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Mul(*BinaryOperation._deserialise_ab(parameters)) - - - def _summary_open(self): - return "Mul" - -class Div(BinaryOperation): - - serialisation_name = "div" - - - def _self_cls(self) -> type: - return Div - - def evaluate(self, variables: dict[int, T]) -> T: - return self.a.evaluate(variables) / self.b.evaluate(variables) - - def _derivative(self, hash_value: int) -> Operation: - return Sub(Div(self.a.derivative(hash_value), self.b), - Div(Mul(self.a, self.b.derivative(hash_value)), Mul(self.b, self.b))) - - def _clean_ab(self, a, b): - if isinstance(a, AdditiveIdentity): - # Convert 0/b to 0 - return AdditiveIdentity() - - elif isinstance(a, MultiplicativeIdentity): - # Convert 1/b to inverse of b - return Inv(b) - - elif isinstance(b, MultiplicativeIdentity): - # Convert a/1 to a - return a - - elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): - # Convert constants "a"/"b" to "a/b" - return Constant(self.a.evaluate({}) / self.b.evaluate({}))._clean() - - - elif isinstance(a, Inv) and isinstance(b, Inv): - return Div(b.a, a.a) - - elif isinstance(a, Inv) and not isinstance(b, Inv): - return Inv(Mul(a.a, b)) - - elif not isinstance(a, Inv) and isinstance(b, Inv): - return Mul(a, b.a) - - elif a == b: - return MultiplicativeIdentity() - - elif isinstance(a, Pow) and a.a == b: - return Pow(b, a.power - 1) - - elif isinstance(b, Pow) and b.a == a: - return Pow(a, 1 - b.power) - - elif isinstance(a, Pow) and isinstance(b, Pow) and a.a == b.a: - return Pow(a.a, a.power - b.power) - - else: - return Div(a, b) - - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - 
return Div(*BinaryOperation._deserialise_ab(parameters)) - - def _summary_open(self): - return "Div" - -class Pow(Operation): - - serialisation_name = "pow" - - def __init__(self, a: Operation, power: float): - self.a = a - self.power = power - - def evaluate(self, variables: dict[int, T]) -> T: - return self.a.evaluate(variables) ** self.power - - def _derivative(self, hash_value: int) -> Operation: - if self.power == 0: - return AdditiveIdentity() - - elif self.power == 1: - return self.a._derivative(hash_value) - - else: - return Mul(Constant(self.power), Mul(Pow(self.a, self.power-1), self.a._derivative(hash_value))) - - def _clean(self) -> Operation: - a = self.a._clean() - - if self.power == 1: - return a - - elif self.power == 0: - return MultiplicativeIdentity() - - elif self.power == -1: - return Inv(a) - - else: - return Pow(a, self.power) - - - def _serialise_parameters(self) -> dict[str, Any]: - return {"a": Operation._serialise_json(self.a), - "power": self.power} - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Pow(Operation.deserialise_json(parameters["a"]), parameters["power"]) - - def summary(self, indent_amount: int=0, indent=" "): - return (f"{indent_amount*indent}Pow\n" + - self.a.summary(indent_amount+1, indent) + "\n" + - f"{(indent_amount+1)*indent}{self.power}\n" + - f"{indent_amount*indent})") - - def __eq__(self, other): - if isinstance(other, Pow): - return self.a == other.a and self.power == other.power - - - -# -# Matrix operations -# - -class Transpose(Operation): - """ Transpose operation - as per numpy""" - - serialisation_name = "transpose" - - def __init__(self, a: Operation, axes: tuple[int] | None = None): - self.a = a - self.axes = axes - - def evaluate(self, variables: dict[int, T]) -> T: - return np.transpose(self.a.evaluate(variables)) - - def _derivative(self, hash_value: int) -> Operation: - return Transpose(self.a.derivative(hash_value)) # TODO: Check! 
- - def _clean(self): - clean_a = self.a._clean() - return Transpose(clean_a) - - - def _serialise_parameters(self) -> dict[str, Any]: - if self.axes is None: - return { "a": self.a._serialise_json() } - else: - return { - "a": self.a._serialise_json(), - "axes": list(self.axes) - } - - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - if "axes" in parameters: - return Transpose( - a=Operation.deserialise_json(parameters["a"]), - axes=tuple(parameters["axes"])) - else: - return Transpose( - a=Operation.deserialise_json(parameters["a"])) - - - def _summary_open(self): - return "Transpose" - - def __eq__(self, other): - if isinstance(other, Transpose): - return other.a == self.a - - -class Dot(BinaryOperation): - """ Dot product - backed by numpy's dot method""" - - serialisation_name = "dot" - - def evaluate(self, variables: dict[int, T]) -> T: - return dot(self.a.evaluate(variables), self.b.evaluate(variables)) - - def _derivative(self, hash_value: int) -> Operation: - return Add( - Dot(self.a, - self.b._derivative(hash_value)), - Dot(self.a._derivative(hash_value), - self.b)) - - def _clean_ab(self, a, b): - return Dot(a, b) # Do nothing for now - - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Dot(*BinaryOperation._deserialise_ab(parameters)) - - def _summary_open(self): - return "Dot" - - -# TODO: Add to base operation class, and to quantities -class MatMul(BinaryOperation): - """ Matrix multiplication, using __matmul__ dunder""" - - serialisation_name = "matmul" - - def evaluate(self, variables: dict[int, T]) -> T: - return self.a.evaluate(variables) @ self.b.evaluate(variables) - - def _derivative(self, hash_value: int) -> Operation: - return Add( - MatMul(self.a, - self.b._derivative(hash_value)), - MatMul(self.a._derivative(hash_value), - self.b)) - - def _clean_ab(self, a, b): - - if isinstance(a, AdditiveIdentity) or isinstance(b, AdditiveIdentity): - # Convert 0*b or a*0 to 0 - return AdditiveIdentity() - - elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): - # Convert constant "a"@"b" to "a@b" - return Constant(a.evaluate({}) @ b.evaluate({}))._clean() - - elif isinstance(a, Neg): - return Neg(Mul(a.a, b)) - - elif isinstance(b, Neg): - return Neg(Mul(a, b.a)) - - return MatMul(a, b) - - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return MatMul(*BinaryOperation._deserialise_ab(parameters)) - - def _summary_open(self): - return "MatMul" - -class TensorDot(Operation): - serialisation_name = "tensor_product" - - def __init__(self, a: Operation, b: Operation, a_index: int, b_index: int): - self.a = a - self.b = b - self.a_index = a_index - self.b_index = b_index - - def evaluate(self, variables: dict[int, T]) -> T: - return tensordot(self.a, self.b, self.a_index, self.b_index) - - - def _serialise_parameters(self) -> dict[str, Any]: - return { - "a": self.a._serialise_json(), - "b": self.b._serialise_json(), - "a_index": self.a_index, - "b_index": self.b_index } - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return TensorDot(a = Operation.deserialise_json(parameters["a"]), - b = Operation.deserialise_json(parameters["b"]), - a_index=int(parameters["a_index"]), - b_index=int(parameters["b_index"])) - - def _summary_open(self): - return "TensorProduct" - - -_serialisable_classes = [AdditiveIdentity, MultiplicativeIdentity, Constant, - Variable, - Neg, Inv, - Add, Sub, Mul, Div, Pow, - Transpose, Dot, MatMul, TensorDot] - -_serialisation_lookup = 
{cls.serialisation_name: cls for cls in _serialisable_classes} - - -class UnitError(Exception): - """ Errors caused by unit specification not being correct """ - -def hash_data_via_numpy(*data: ArrayLike): - - md5_hash = hashlib.md5() - - for datum in data: - data_bytes = np.array(datum).tobytes() - md5_hash.update(data_bytes) - - # Hash function returns a hex string, we want an int - return int(md5_hash.hexdigest(), 16) - - - -##################################### -# # -# # -# # -# Quantities begin here # -# # -# # -# # -##################################### - - - -QuantityType = TypeVar("QuantityType") - - -class QuantityHistory: - """ Class that holds the information for keeping track of operations done on quantities """ - - def __init__(self, operation_tree: Operation, references: dict[int, "Quantity"]): - self.operation_tree = operation_tree - self.references = references - - self.reference_key_list = [key for key in self.references] - self.si_reference_values = {key: self.references[key].in_si() for key in self.references} - - def jacobian(self) -> list[Operation]: - """ Derivative of this quantity's operation history with respect to each of the references """ - - # Use the hash value to specify the variable of differentiation - return [self.operation_tree.derivative(key) for key in self.reference_key_list] - - def _recalculate(self): - """ Recalculate the value of this object - primary use case is for testing """ - return self.operation_tree.evaluate(self.references) - - def variance_propagate(self, quantity_units: Unit, covariances: dict[tuple[int, int]: "Quantity"] = {}): - """ Do standard error propagation to calculate the uncertainties associated with this quantity - - :param quantity_units: units in which the output should be calculated - :param covariances: off diagonal entries for the covariance matrix - """ - - if covariances: - raise NotImplementedError("User specified covariances not currently implemented") - - jacobian = self.jacobian() - - evaluated_jacobian = [entry.evaluate(self.references) for entry in jacobian] - - hash_values = [key for key in self.references] - output = None - - for hash_value, jac_component in zip(hash_values, evaluated_jacobian): - if output is None: - output = jac_component * (self.references[hash_value].variance * jac_component) - else: - output += jac_component * (self.references[hash_value].variance * jac_component) - - return output - - - @staticmethod - def variable(quantity: "Quantity"): - """ Create a history that starts with the provided data """ - return QuantityHistory(Variable(quantity.hash_value), {quantity.hash_value: quantity}) - - @staticmethod - def apply_operation(operation: type[Operation], *histories: "QuantityHistory", **extra_parameters) -> "QuantityHistory": - """ Apply an operation to the history - - This is slightly unsafe as it is possible to attempt to apply an n-ary operation to a number of trees other - than n, but it is relatively concise. Because it is concise we'll go with this for now and see if it causes - any problems down the line. It is a private static method to discourage misuse. - - """ - - # Copy references over, even though it overrides on collision, - # this should behave because only data based variables should be represented. 
- # Should not be a problem any more than losing histories - references = {} - for history in histories: - references.update(history.references) - - return QuantityHistory( - operation(*[history.operation_tree for history in histories], **extra_parameters), - references) - - def has_variance(self): - for key in self.references: - if self.references[key].has_variance: - return True - - return False - - def summary(self): - - variable_strings = [self.references[key].string_repr for key in self.references] - - s = "Variables: "+",".join(variable_strings) - s += "\n" - s += self.operation_tree.summary() - - return s - - -class Quantity[QuantityType]: - - - def __init__(self, - value: QuantityType, - units: Unit, - standard_error: QuantityType | None = None, - hash_seed = ""): - - self.value = value - """ Numerical value of this data, in the specified units""" - - self.units = units - """ Units of this data """ - - self._hash_seed = hash_seed - """ Retain this for copying operations""" - - self.hash_value = -1 - """ Hash based on value and uncertainty for data, -1 if it is a derived hash value """ - - self._variance = None - """ Contains the variance if it is data driven """ - - if standard_error is None: - self.hash_value = hash_data_via_numpy(hash_seed, value) - else: - self._variance = standard_error ** 2 - self.hash_value = hash_data_via_numpy(hash_seed, value, standard_error) - - self.history = QuantityHistory.variable(self) - - @property - def has_variance(self): - return self._variance is not None - - @property - def variance(self) -> "Quantity": - """ Get the variance of this object""" - if self._variance is None: - return Quantity(np.zeros_like(self.value), self.units**2) - else: - return Quantity(self._variance, self.units**2) - - def standard_deviation(self) -> "Quantity": - return self.variance ** 0.5 - - def in_units_of(self, units: Unit) -> QuantityType: - """ Get this quantity in other units """ - if self.units.equivalent(units): - return (self.units.scale / units.scale) * self.value - else: - raise UnitError(f"Target units ({units}) not compatible with existing units ({self.units}).") - - def to_units_of(self, new_units: Unit) -> "Quantity[QuantityType]": - new_value, new_error = self.in_units_of_with_standard_error(new_units) - return Quantity(value=new_value, - units=new_units, - standard_error=new_error, - hash_seed=self._hash_seed) - - def variance_in_units_of(self, units: Unit) -> QuantityType: - """ Get the variance of quantity in other units """ - variance = self.variance - if variance.units.equivalent(units): - return (variance.units.scale / units.scale) * variance - else: - raise UnitError(f"Target units ({units}) not compatible with existing units ({variance.units}).") - - def in_si(self): - si_units = self.units.si_equivalent() - return self.in_units_of(si_units) - - def in_units_of_with_standard_error(self, units): - variance = self.variance - units_squared = units**2 - - if variance.units.equivalent(units_squared): - - return self.in_units_of(units), np.sqrt(self.variance.in_units_of(units_squared)) - else: - raise UnitError(f"Target units ({units}) not compatible with existing units ({variance.units}).") - - def in_si_with_standard_error(self): - if self.has_variance: - return self.in_units_of_with_standard_error(self.units.si_equivalent()) - else: - return self.in_si(), None - - def __mul__(self: Self, other: ArrayLike | Self ) -> Self: - if isinstance(other, Quantity): - return DerivedQuantity( - self.value * other.value, - self.units * other.units, - 
history=QuantityHistory.apply_operation(Mul, self.history, other.history)) - - else: - return DerivedQuantity(self.value * other, self.units, - QuantityHistory( - Mul( - self.history.operation_tree, - Constant(other)), - self.history.references)) - - def __rmul__(self: Self, other: ArrayLike | Self): - if isinstance(other, Quantity): - return DerivedQuantity( - other.value * self.value, - other.units * self.units, - history=QuantityHistory.apply_operation( - Mul, - other.history, - self.history)) - - else: - return DerivedQuantity(other * self.value, self.units, - QuantityHistory( - Mul( - Constant(other), - self.history.operation_tree), - self.history.references)) - - - def __matmul__(self, other: ArrayLike | Self): - if isinstance(other, Quantity): - return DerivedQuantity( - self.value @ other.value, - self.units * other.units, - history=QuantityHistory.apply_operation( - MatMul, - self.history, - other.history)) - else: - return DerivedQuantity( - self.value @ other, - self.units, - QuantityHistory( - MatMul( - self.history.operation_tree, - Constant(other)), - self.history.references)) - - def __rmatmul__(self, other: ArrayLike | Self): - if isinstance(other, Quantity): - return DerivedQuantity( - other.value @ self.value, - other.units * self.units, - history=QuantityHistory.apply_operation( - MatMul, - other.history, - self.history)) - - else: - return DerivedQuantity(other @ self.value, self.units, - QuantityHistory( - MatMul( - Constant(other), - self.history.operation_tree), - self.history.references)) - - - def __truediv__(self: Self, other: float | Self) -> Self: - if isinstance(other, Quantity): - return DerivedQuantity( - self.value / other.value, - self.units / other.units, - history=QuantityHistory.apply_operation( - Div, - self.history, - other.history)) - - else: - return DerivedQuantity(self.value / other, self.units, - QuantityHistory( - Div( - Constant(other), - self.history.operation_tree), - self.history.references)) - - def __rtruediv__(self: Self, other: float | Self) -> Self: - if isinstance(other, Quantity): - return DerivedQuantity( - other.value / self.value, - other.units / self.units, - history=QuantityHistory.apply_operation( - Div, - other.history, - self.history - )) - - else: - return DerivedQuantity( - other / self.value, - self.units ** -1, - QuantityHistory( - Div( - Constant(other), - self.history.operation_tree), - self.history.references)) - - def __add__(self: Self, other: Self | ArrayLike) -> Self: - if isinstance(other, Quantity): - if self.units.equivalent(other.units): - return DerivedQuantity( - self.value + (other.value * other.units.scale) / self.units.scale, - self.units, - QuantityHistory.apply_operation( - Add, - self.history, - other.history)) - else: - raise UnitError(f"Units do not have the same dimensionality: {self.units} vs {other.units}") - - else: - raise UnitError(f"Cannot perform addition/subtraction non-quantity {type(other)} with quantity") - - # Don't need __radd__ because only quantity/quantity operations should be allowed - - def __neg__(self): - return DerivedQuantity(-self.value, self.units, - QuantityHistory.apply_operation( - Neg, - self.history - )) - - def __sub__(self: Self, other: Self | ArrayLike) -> Self: - return self + (-other) - - def __rsub__(self: Self, other: Self | ArrayLike) -> Self: - return (-self) + other - - def __pow__(self: Self, other: int | float): - return DerivedQuantity(self.value ** other, - self.units ** other, - QuantityHistory( - Pow( - self.history.operation_tree, - other), - 
self.history.references)) - - @staticmethod - def _array_repr_format(arr: np.ndarray): - """ Format the array """ - order = len(arr.shape) - reshaped = arr.reshape(-1) - if len(reshaped) <= 2: - numbers = ",".join([f"{n}" for n in reshaped]) - else: - numbers = f"{reshaped[0]} ... {reshaped[-1]}" - - # if len(reshaped) <= 4: - # numbers = ",".join([f"{n}" for n in reshaped]) - # else: - # numbers = f"{reshaped[0]}, {reshaped[1]} ... {reshaped[-2]}, {reshaped[-1]}" - - return "["*order + numbers + "]"*order - - def __repr__(self): - - if isinstance(self.units, NamedUnit): - - value = self.value - error = self.standard_deviation().in_units_of(self.units) - unit_string = self.units.symbol - - else: - value, error = self.in_si_with_standard_error() - unit_string = self.units.dimensions.si_repr() - - if isinstance(self.value, np.ndarray): - # Get the array in short form - numeric_string = self._array_repr_format(value) - - if self.has_variance: - numeric_string += " ± " + self._array_repr_format(error) - - else: - numeric_string = f"{value}" - if self.has_variance: - numeric_string += f" ± {error}" - - return numeric_string + " " + unit_string - - @staticmethod - def parse(number_or_string: str | ArrayLike, unit: str, absolute_temperature: False): - pass - - @property - def string_repr(self): - return str(self.hash_value) - - -class NamedQuantity[QuantityType](Quantity[QuantityType]): - def __init__(self, - name: str, - value: QuantityType, - units: Unit, - standard_error: QuantityType | None = None): - - super().__init__(value, units, standard_error=standard_error, hash_seed=name) - self.name = name - - def __repr__(self): - return f"[{self.name}] " + super().__repr__() - - def to_units_of(self, new_units: Unit) -> "NamedQuantity[QuantityType]": - new_value, new_error = self.in_units_of_with_standard_error(new_units) - return NamedQuantity(value=new_value, - units=new_units, - standard_error=new_error, - name=self.name) - - def with_standard_error(self, standard_error: Quantity): - if standard_error.units.equivalent(self.units): - return NamedQuantity( - value=self.value, - units=self.units, - standard_error=standard_error.in_units_of(self.units), - name=self.name) - - else: - raise UnitError(f"Standard error units ({standard_error.units}) " - f"are not compatible with value units ({self.units})") - - - @property - def string_repr(self): - return self.name - -class DerivedQuantity[QuantityType](Quantity[QuantityType]): - def __init__(self, value: QuantityType, units: Unit, history: QuantityHistory): - super().__init__(value, units, standard_error=None) - - self.history = history - self._variance_cache = None - self._has_variance = history.has_variance() - - - def to_units_of(self, new_units: Unit) -> "Quantity[QuantityType]": - # TODO: Lots of tests needed for this - return DerivedQuantity( - value=self.in_units_of(new_units), - units=new_units, - history=self.history) - - @property - def has_variance(self): - return self._has_variance - - @property - def variance(self) -> Quantity: - if self._variance_cache is None: - self._variance_cache = self.history.variance_propagate(self.units) - - return self._variance_cache +from typing import Self + +import numpy as np +from numpy._typing import ArrayLike + +from sasdata.quantities import units +from sasdata.quantities.numerical_encoding import numerical_decode, numerical_encode +from sasdata.quantities.units import Unit, NamedUnit + +import hashlib + +from typing import Any, TypeVar, Union + +import json + +T = TypeVar("T") + + + + + 
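+# Rough usage sketch of what this module builds towards (illustrative and
+# untested; NamedQuantity, DerivedQuantity and the operations it uses are
+# defined further down this file):
+#
+#     import numpy as np
+#     from sasdata.quantities import units
+#
+#     q = NamedQuantity("q", np.linspace(0.01, 0.5, 50), units.none,
+#                       standard_error=np.full(50, 0.001))
+#     intensity = 2.0 * q            # DerivedQuantity; history records a Mul node
+#     print(intensity.history.summary())
+#     value, error = intensity.in_si_with_standard_error()  # propagated errors
+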
+################### Quantity based operations, need to be here to avoid cyclic dependencies #####################
+
+def transpose(a: Union["Quantity[ArrayLike]", ArrayLike], axes: tuple | None = None):
+    """ Transpose an array or an array based quantity; can also reorder axes"""
+    if isinstance(a, Quantity):
+
+        if axes is None:
+            return DerivedQuantity(value=np.transpose(a.value, axes=axes),
+                                   units=a.units,
+                                   history=QuantityHistory.apply_operation(Transpose, a.history))
+
+        else:
+            return DerivedQuantity(value=np.transpose(a.value, axes=axes),
+                                   units=a.units,
+                                   history=QuantityHistory.apply_operation(Transpose, a.history, axes=axes))
+
+    else:
+        return np.transpose(a, axes=axes)
+
+
+def dot(a: Union["Quantity[ArrayLike]", ArrayLike], b: Union["Quantity[ArrayLike]", ArrayLike]):
+    """ Dot product of two arrays or two array based quantities """
+    a_is_quantity = isinstance(a, Quantity)
+    b_is_quantity = isinstance(b, Quantity)
+
+    if a_is_quantity or b_is_quantity:
+
+        # If only one of them is a quantity, convert the other to a dimensionless quantity
+
+        if not a_is_quantity:
+            a = Quantity(a, units.none)
+
+        if not b_is_quantity:
+            b = Quantity(b, units.none)
+
+        return DerivedQuantity(
+            value=np.dot(a.value, b.value),
+            units=a.units * b.units,
+            history=QuantityHistory.apply_operation(Dot, a.history, b.history))
+
+    else:
+        return np.dot(a, b)
+
+def tensordot(a: Union["Quantity[ArrayLike]", ArrayLike], b: Union["Quantity[ArrayLike]", ArrayLike], a_index: int, b_index: int):
+    """ Tensor dot product - contracts two tensors, such as
+
+    A_{i0, i1, i2, i3...} and B_{j0, j1, j2...}
+
+    e.g. if a_index is 1 and b_index is zero, it is the sum
+
+    C_{i0, i2, i3 ..., j1, j2 ...} = sum_k A_{i0, k, i2, i3 ...} B_{k, j1, j2 ...}
+
+    with the remaining indices of A coming before those of B, as in numpy's tensordot
+ + """ + + a_is_quantity = isinstance(a, Quantity) + b_is_quantity = isinstance(b, Quantity) + + if a_is_quantity or b_is_quantity: + + # If its only one of them that is a quantity, convert the other one + + if not a_is_quantity: + a = Quantity(a, units.none) + + if not b_is_quantity: + b = Quantity(b, units.none) + + return DerivedQuantity( + value=np.tensordot(a.value, b.value, axes=(a_index, b_index)), + units=a.units * b.units, + history=QuantityHistory.apply_operation( + TensorDot, + a.history, + b.history, + a_index=a_index, + b_index=b_index)) + + else: + return np.tensordot(a, b, axes=(a_index, b_index)) + + +################### Operation Definitions ####################################### + +def hash_and_name(hash_or_name: int | str): + """ Infer the name of a variable from a hash, or the hash from the name + + Note: hash_and_name(hash_and_name(number)[1]) is not the identity + however: hash_and_name(hash_and_name(number)) is + """ + + if isinstance(hash_or_name, str): + hash_value = hash(hash_or_name) + name = hash_or_name + + return hash_value, name + + elif isinstance(hash_or_name, int): + hash_value = hash_or_name + name = f"#{hash_or_name}" + + return hash_value, name + + elif isinstance(hash_or_name, tuple): + return hash_or_name + + else: + raise TypeError("Variable name_or_hash_value must be either str or int") + +class Operation: + + serialisation_name = "unknown" + def summary(self, indent_amount: int = 0, indent: str=" "): + """ Summary of the operation tree""" + + s = f"{indent_amount*indent}{self._summary_open()}(\n" + + for chunk in self._summary_components(): + s += chunk.summary(indent_amount+1, indent) + "\n" + + s += f"{indent_amount*indent})" + + return s + def _summary_open(self): + """ First line of summary """ + + def _summary_components(self) -> list["Operation"]: + return [] + def evaluate(self, variables: dict[int, T]) -> T: + + """ Evaluate this operation """ + + def _derivative(self, hash_value: int) -> "Operation": + """ Get the derivative of this operation """ + + def _clean(self): + """ Clean up this operation - i.e. 
remove silly things like 1*x """ + return self + + def derivative(self, variable: Union[str, int, "Variable"], simplify=True): + if isinstance(variable, Variable): + hash_value = variable.hash_value + else: + hash_value, _ = hash_and_name(variable) + + derivative = self._derivative(hash_value) + + if not simplify: + return derivative + + derivative_string = derivative.serialise() + + # print("---------------") + # print("Base") + # print("---------------") + # print(derivative.summary()) + + # Inefficient way of doing repeated simplification, but it will work + for i in range(100): # set max iterations + + derivative = derivative._clean() + # + # print("-------------------") + # print("Iteration", i+1) + # print("-------------------") + # print(derivative.summary()) + # print("-------------------") + + new_derivative_string = derivative.serialise() + + if derivative_string == new_derivative_string: + break + + derivative_string = new_derivative_string + + return derivative + + @staticmethod + def deserialise(data: str) -> "Operation": + json_data = json.loads(data) + return Operation.deserialise_json(json_data) + + @staticmethod + def deserialise_json(json_data: dict) -> "Operation": + + operation = json_data["operation"] + parameters = json_data["parameters"] + cls = _serialisation_lookup[operation] + + try: + return cls._deserialise(parameters) + + except NotImplementedError: + raise NotImplementedError(f"No method to deserialise {operation} with {parameters} (cls={cls})") + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + raise NotImplementedError(f"Deserialise not implemented for this class") + + def serialise(self) -> str: + return json.dumps(self._serialise_json()) + + def _serialise_json(self) -> dict[str, Any]: + return {"operation": self.serialisation_name, + "parameters": self._serialise_parameters()} + + def _serialise_parameters(self) -> dict[str, Any]: + raise NotImplementedError("_serialise_parameters not implemented") + + def __eq__(self, other: "Operation"): + return NotImplemented + +class ConstantBase(Operation): + pass + +class AdditiveIdentity(ConstantBase): + + serialisation_name = "zero" + def evaluate(self, variables: dict[int, T]) -> T: + return 0 + + def _derivative(self, hash_value: int) -> Operation: + return AdditiveIdentity() + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return AdditiveIdentity() + + def _serialise_parameters(self) -> dict[str, Any]: + return {} + + def summary(self, indent_amount: int=0, indent=" "): + return f"{indent_amount*indent}0 [Add.Id.]" + + def __eq__(self, other): + if isinstance(other, AdditiveIdentity): + return True + elif isinstance(other, Constant): + if other.value == 0: + return True + + return False + + + +class MultiplicativeIdentity(ConstantBase): + + serialisation_name = "one" + + def evaluate(self, variables: dict[int, T]) -> T: + return 1 + + def _derivative(self, hash_value: int): + return AdditiveIdentity() + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return MultiplicativeIdentity() + + + def _serialise_parameters(self) -> dict[str, Any]: + return {} + + + def summary(self, indent_amount: int=0, indent=" "): + return f"{indent_amount*indent}1 [Mul.Id.]" + + def __eq__(self, other): + if isinstance(other, MultiplicativeIdentity): + return True + elif isinstance(other, Constant): + if other.value == 1: + return True + + return False + + +class Constant(ConstantBase): + + serialisation_name = "constant" + def __init__(self, value): + self.value = 
value + + def summary(self, indent_amount: int = 0, indent: str=" "): + return repr(self.value) + + def evaluate(self, variables: dict[int, T]) -> T: + return self.value + + def _derivative(self, hash_value: int): + return AdditiveIdentity() + + def _clean(self): + + if self.value == 0: + return AdditiveIdentity() + + elif self.value == 1: + return MultiplicativeIdentity() + + else: + return self + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + value = numerical_decode(parameters["value"]) + return Constant(value) + + + def _serialise_parameters(self) -> dict[str, Any]: + return {"value": numerical_encode(self.value)} + + def summary(self, indent_amount: int=0, indent=" "): + return f"{indent_amount*indent}{self.value}" + + def __eq__(self, other): + if isinstance(other, AdditiveIdentity): + return self.value == 0 + + elif isinstance(other, MultiplicativeIdentity): + return self.value == 1 + + elif isinstance(other, Constant): + if other.value == self.value: + return True + + return False + + +class Variable(Operation): + + serialisation_name = "variable" + def __init__(self, name_or_hash_value: int | str | tuple[int, str]): + self.hash_value, self.name = hash_and_name(name_or_hash_value) + + def evaluate(self, variables: dict[int, T]) -> T: + try: + return variables[self.hash_value] + except KeyError: + raise ValueError(f"Variable dictionary didn't have an entry for {self.name} (hash={self.hash_value})") + + def _derivative(self, hash_value: int) -> Operation: + if hash_value == self.hash_value: + return MultiplicativeIdentity() + else: + return AdditiveIdentity() + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + hash_value = parameters["hash_value"] + name = parameters["name"] + + return Variable((hash_value, name)) + + def _serialise_parameters(self) -> dict[str, Any]: + return {"hash_value": self.hash_value, + "name": self.name} + + def summary(self, indent_amount: int = 0, indent: str=" "): + return f"{indent_amount*indent}{self.name}" + + def __eq__(self, other): + if isinstance(other, Variable): + return self.hash_value == other.hash_value + + return False + +class UnaryOperation(Operation): + + def __init__(self, a: Operation): + self.a = a + + def _serialise_parameters(self) -> dict[str, Any]: + return {"a": self.a._serialise_json()} + + def _summary_components(self) -> list["Operation"]: + return [self.a] + + + + +class Neg(UnaryOperation): + + serialisation_name = "neg" + def evaluate(self, variables: dict[int, T]) -> T: + return -self.a.evaluate(variables) + + def _derivative(self, hash_value: int): + return Neg(self.a._derivative(hash_value)) + + def _clean(self): + + clean_a = self.a._clean() + + if isinstance(clean_a, Neg): + # Removes double negations + return clean_a.a + + elif isinstance(clean_a, Constant): + return Constant(-clean_a.value)._clean() + + else: + return Neg(clean_a) + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Neg(Operation.deserialise_json(parameters["a"])) + + + def _summary_open(self): + return "Neg" + + def __eq__(self, other): + if isinstance(other, Neg): + return other.a == self.a + + +class Inv(UnaryOperation): + + serialisation_name = "reciprocal" + + def evaluate(self, variables: dict[int, T]) -> T: + return 1/self.a.evaluate(variables) + + def _derivative(self, hash_value: int) -> Operation: + return Neg(Div(self.a._derivative(hash_value), Mul(self.a, self.a))) + + def _clean(self): + clean_a = self.a._clean() + + if isinstance(clean_a, Inv): + # Removes double 
inversions
+            return clean_a.a
+
+        elif isinstance(clean_a, Neg):
+            # canonicalise 1/-a to -(1/a)
+            # over multiple iterations this should have the effect of ordering and gathering Neg and Inv
+            return Neg(Inv(clean_a.a))
+
+        elif isinstance(clean_a, Constant):
+            return Constant(1/clean_a.value)._clean()
+
+        else:
+            return Inv(clean_a)
+
+    @staticmethod
+    def _deserialise(parameters: dict) -> "Operation":
+        return Inv(Operation.deserialise_json(parameters["a"]))
+
+    def _summary_open(self):
+        return "Inv"
+
+    def __eq__(self, other):
+        if isinstance(other, Inv):
+            return other.a == self.a
+
+        return False
+
+class BinaryOperation(Operation):
+    def __init__(self, a: Operation, b: Operation):
+        self.a = a
+        self.b = b
+
+    def _clean(self):
+        return self._clean_ab(self.a._clean(), self.b._clean())
+
+    def _clean_ab(self, a, b):
+        raise NotImplementedError("_clean_ab not implemented")
+
+    def _serialise_parameters(self) -> dict[str, Any]:
+        return {"a": self.a._serialise_json(),
+                "b": self.b._serialise_json()}
+
+    @staticmethod
+    def _deserialise_ab(parameters) -> tuple[Operation, Operation]:
+        return (Operation.deserialise_json(parameters["a"]),
+                Operation.deserialise_json(parameters["b"]))
+
+    def _summary_components(self) -> list["Operation"]:
+        return [self.a, self.b]
+
+    def _self_cls(self) -> type:
+        """ Own class - the runtime type is a safe default, subclasses may override """
+        return type(self)
+
+    def __eq__(self, other):
+        if isinstance(other, self._self_cls()):
+            return other.a == self.a and self.b == other.b
+
+        return False
+
+class Add(BinaryOperation):
+
+    serialisation_name = "add"
+
+    def _self_cls(self) -> type:
+        return Add
+
+    def evaluate(self, variables: dict[int, T]) -> T:
+        return self.a.evaluate(variables) + self.b.evaluate(variables)
+
+    def _derivative(self, hash_value: int) -> Operation:
+        return Add(self.a._derivative(hash_value), self.b._derivative(hash_value))
+
+    def _clean_ab(self, a, b):
+
+        if isinstance(a, AdditiveIdentity):
+            # Convert 0 + b to b
+            return b
+
+        elif isinstance(b, AdditiveIdentity):
+            # Convert a + 0 to a
+            return a
+
+        elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase):
+            # Convert constant "a"+"b" to "a+b"
+            return Constant(a.evaluate({}) + b.evaluate({}))._clean()
+
+        elif isinstance(a, Neg):
+            if isinstance(b, Neg):
+                # Convert (-a)+(-b) to -(a+b)
+                return Neg(Add(a.a, b.a))
+            else:
+                # Convert (-a) + b to b-a
+                return Sub(b, a.a)
+
+        elif isinstance(b, Neg):
+            # Convert a+(-b) to a-b
+            return Sub(a, b.a)
+
+        elif a == b:
+            return Mul(Constant(2), a)
+
+        else:
+            return Add(a, b)
+
+    @staticmethod
+    def _deserialise(parameters: dict) -> "Operation":
+        return Add(*BinaryOperation._deserialise_ab(parameters))
+
+    def _summary_open(self):
+        return "Add"
+
+class Sub(BinaryOperation):
+
+    serialisation_name = "sub"
+
+    def _self_cls(self) -> type:
+        return Sub
+
+    def evaluate(self, variables: dict[int, T]) -> T:
+        return self.a.evaluate(variables) - self.b.evaluate(variables)
+
+    def _derivative(self, hash_value: int) -> Operation:
+        return Sub(self.a._derivative(hash_value), self.b._derivative(hash_value))
+
+    def _clean_ab(self, a, b):
+        if isinstance(a, AdditiveIdentity):
+            # Convert 0 - b to -b
+            return Neg(b)
+
+        elif isinstance(b, AdditiveIdentity):
+            # Convert a - 0 to a
+            return a
+
+        elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase):
+            # Convert constant pair "a" - "b" to "a-b"
+            return Constant(a.evaluate({}) - b.evaluate({}))._clean()
+
+        elif isinstance(a, Neg):
+            if isinstance(b, Neg):
+                # Convert (-a)-(-b) to b-a
+                return Sub(b.a, a.a)
+            else:
+                # Convert (-a)-b to -(a+b)
+                return Neg(Add(a.a, b))
+
+        elif isinstance(b, Neg):
+            # Convert a-(-b) to a+b
+            return Add(a, b.a)
+
+        elif a == b:
+            return AdditiveIdentity()
+
+        else:
+            return Sub(a, b)
+
+    @staticmethod
+    def _deserialise(parameters: dict) -> "Operation":
+        return Sub(*BinaryOperation._deserialise_ab(parameters))
+
+    def _summary_open(self):
+        return "Sub"
+
+class Mul(BinaryOperation):
+
+    serialisation_name = "mul"
+
+    def _self_cls(self) -> type:
+        return Mul
+
+    def evaluate(self, variables: dict[int, T]) -> T:
+        return self.a.evaluate(variables) * self.b.evaluate(variables)
+
+    def _derivative(self, hash_value: int) -> Operation:
+        # Product rule: (ab)' = ab' + a'b
+        return Add(Mul(self.a, self.b._derivative(hash_value)), Mul(self.a._derivative(hash_value), self.b))
+
+    def _clean_ab(self, a, b):
+
+        if isinstance(a, AdditiveIdentity) or isinstance(b, AdditiveIdentity):
+            # Convert 0*b or a*0 to 0
+            return AdditiveIdentity()
+
+        elif isinstance(a, MultiplicativeIdentity):
+            # Convert 1*b to b
+            return b
+
+        elif isinstance(b, MultiplicativeIdentity):
+            # Convert a*1 to a
+            return a
+
+        elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase):
+            # Convert constant "a"*"b" to "a*b"
+            return Constant(a.evaluate({}) * b.evaluate({}))._clean()
+
+        elif isinstance(a, Inv) and isinstance(b, Inv):
+            return Inv(Mul(a.a, b.a))
+
+        elif isinstance(a, Inv) and not isinstance(b, Inv):
+            return Div(b, a.a)
+
+        elif not isinstance(a, Inv) and isinstance(b, Inv):
+            return Div(a, b.a)
+
+        elif isinstance(a, Neg):
+            return Neg(Mul(a.a, b))
+
+        elif isinstance(b, Neg):
+            return Neg(Mul(a, b.a))
+
+        elif a == b:
+            return Pow(a, 2)
+
+        elif isinstance(a, Pow) and a.a == b:
+            return Pow(b, a.power + 1)
+
+        elif isinstance(b, Pow) and b.a == a:
+            return Pow(a, b.power + 1)
+
+        elif isinstance(a, Pow) and isinstance(b, Pow) and a.a == b.a:
+            return Pow(a.a, a.power + b.power)
+
+        else:
+            return Mul(a, b)
+
+    @staticmethod
+    def _deserialise(parameters: dict) -> "Operation":
+        return Mul(*BinaryOperation._deserialise_ab(parameters))
+
+    def _summary_open(self):
+        return "Mul"
+
+class Div(BinaryOperation):
+
+    serialisation_name = "div"
+
+    def _self_cls(self) -> type:
+        return Div
+
+    def evaluate(self, variables: dict[int, T]) -> T:
+        return self.a.evaluate(variables) / self.b.evaluate(variables)
+
+    def _derivative(self, hash_value: int) -> Operation:
+        # Quotient rule: (a/b)' = a'/b - a b'/b^2, using the raw _derivative
+        # so that no simplification pass runs mid-differentiation
+        return Sub(Div(self.a._derivative(hash_value), self.b),
+                   Div(Mul(self.a, self.b._derivative(hash_value)), Mul(self.b, self.b)))
+
+    def _clean_ab(self, a, b):
+        if isinstance(a, AdditiveIdentity):
+            # Convert 0/b to 0
+            return AdditiveIdentity()
+
+        elif isinstance(a, MultiplicativeIdentity):
+            # Convert 1/b to inverse of b
+            return Inv(b)
+
+        elif isinstance(b, MultiplicativeIdentity):
+            # Convert a/1 to a
+            return a
+
+        elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase):
+            # Convert constants "a"/"b" to "a/b" - note: fold the cleaned
+            # operands a and b, not self.a/self.b, which may not be constants
+            return Constant(a.evaluate({}) / b.evaluate({}))._clean()
+
+        elif isinstance(a, Inv) and isinstance(b, Inv):
+            return Div(b.a, a.a)
+
+        elif isinstance(a, Inv) and not isinstance(b, Inv):
+            return Inv(Mul(a.a, b))
+
+        elif not isinstance(a, Inv) and isinstance(b, Inv):
+            return Mul(a, b.a)
+
+        elif a == b:
+            return MultiplicativeIdentity()
+
+        elif isinstance(a, Pow) and a.a == b:
+            return Pow(b, a.power - 1)
+
+        elif isinstance(b, Pow) and b.a == a:
+            return Pow(a, 1 - b.power)
+
+        elif isinstance(a, Pow) and isinstance(b, Pow) and a.a == b.a:
+            return Pow(a.a, a.power - b.power)
+
+        else:
+            return Div(a, b)
+
+    @staticmethod
+    def _deserialise(parameters: dict) -> "Operation":
+        return Div(*BinaryOperation._deserialise_ab(parameters))
+
+    def _summary_open(self):
+        return "Div"
+
+class Pow(Operation):
+
+    serialisation_name = "pow"
+
+    def __init__(self, a: Operation, power: float):
+        self.a = a
+        self.power = power
+
+    def evaluate(self, variables: dict[int, T]) -> T:
+        return self.a.evaluate(variables) ** self.power
+
+    def _derivative(self, hash_value: int) -> Operation:
+        if self.power == 0:
+            return AdditiveIdentity()
+
+        elif self.power == 1:
+            return self.a._derivative(hash_value)
+
+        else:
+            # Chain rule: (a^n)' = n a^(n-1) a'
+            return Mul(Constant(self.power), Mul(Pow(self.a, self.power-1), self.a._derivative(hash_value)))
+
+    def _clean(self) -> Operation:
+        a = self.a._clean()
+
+        if self.power == 1:
+            return a
+
+        elif self.power == 0:
+            return MultiplicativeIdentity()
+
+        elif self.power == -1:
+            return Inv(a)
+
+        else:
+            return Pow(a, self.power)
+
+    def _serialise_parameters(self) -> dict[str, Any]:
+        return {"a": self.a._serialise_json(),
+                "power": self.power}
+
+    @staticmethod
+    def _deserialise(parameters: dict) -> "Operation":
+        return Pow(Operation.deserialise_json(parameters["a"]), parameters["power"])
+
+    def summary(self, indent_amount: int=0, indent="  "):
+        return (f"{indent_amount*indent}Pow(\n" +
+                self.a.summary(indent_amount+1, indent) + "\n" +
+                f"{(indent_amount+1)*indent}{self.power}\n" +
+                f"{indent_amount*indent})")
+
+    def __eq__(self, other):
+        if isinstance(other, Pow):
+            return self.a == other.a and self.power == other.power
+
+        return False
+
+
+#
+# Matrix operations
+#
+
+class Transpose(Operation):
+    """ Transpose operation - as per numpy"""
+
+    serialisation_name = "transpose"
+
+    def __init__(self, a: Operation, axes: tuple[int] | None = None):
+        self.a = a
+        self.axes = axes
+
+    def evaluate(self, variables: dict[int, T]) -> T:
+        # Pass the stored axis permutation through rather than always
+        # using the default
+        return np.transpose(self.a.evaluate(variables), axes=self.axes)
+
+    def _derivative(self, hash_value: int) -> Operation:
+        # Transposition is linear, so the derivative is the transpose of the
+        # derivative, with the same axis permutation
+        return Transpose(self.a._derivative(hash_value), axes=self.axes)
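+
+    # A rough sketch of this node in use (illustrative, untested):
+    #
+    #     m = Variable("m")
+    #     t = Transpose(m, axes=(1, 0))
+    #     t.evaluate({hash("m"): np.arange(6).reshape(2, 3)})  # -> shape (3, 2)
+    #     Operation.deserialise(t.serialise())                 # JSON round trip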
+
+    def _clean(self):
+        clean_a = self.a._clean()
+        # Keep the axis permutation when rebuilding the node
+        return Transpose(clean_a, axes=self.axes)
+
+    def _serialise_parameters(self) -> dict[str, Any]:
+        if self.axes is None:
+            return { "a": self.a._serialise_json() }
+        else:
+            return {
+                "a": self.a._serialise_json(),
+                "axes": list(self.axes)
+            }
+
+    @staticmethod
+    def _deserialise(parameters: dict) -> "Operation":
+        if "axes" in parameters:
+            return Transpose(
+                a=Operation.deserialise_json(parameters["a"]),
+                axes=tuple(parameters["axes"]))
+        else:
+            return Transpose(
+                a=Operation.deserialise_json(parameters["a"]))
+
+    def _summary_open(self):
+        return "Transpose"
+
+    def __eq__(self, other):
+        if isinstance(other, Transpose):
+            return other.a == self.a and other.axes == self.axes
+
+        return False
+
+
+class Dot(BinaryOperation):
+    """ Dot product - backed by numpy's dot method"""
+
+    serialisation_name = "dot"
+
+    def evaluate(self, variables: dict[int, T]) -> T:
+        return dot(self.a.evaluate(variables), self.b.evaluate(variables))
+
+    def _derivative(self, hash_value: int) -> Operation:
+        return Add(
+            Dot(self.a,
+                self.b._derivative(hash_value)),
+            Dot(self.a._derivative(hash_value),
+                self.b))
+
+    def _clean_ab(self, a, b):
+        return Dot(a, b) # Do nothing for now
+
+    @staticmethod
+    def _deserialise(parameters: dict) -> "Operation":
+        return Dot(*BinaryOperation._deserialise_ab(parameters))
+
+    def _summary_open(self):
+        return "Dot"
+
+
+# TODO: Add to base operation class, and to quantities
+class MatMul(BinaryOperation):
+    """ Matrix multiplication, using __matmul__ dunder"""
+
+    serialisation_name = "matmul"
+
+    def evaluate(self, variables: dict[int, T]) -> T:
+        return self.a.evaluate(variables) @ self.b.evaluate(variables)
+
+    def _derivative(self, hash_value: int) -> Operation:
+        return Add(
+            MatMul(self.a,
+                   self.b._derivative(hash_value)),
+            MatMul(self.a._derivative(hash_value),
+                   self.b))
+
+    def _clean_ab(self, a, b):
+
+        if isinstance(a, AdditiveIdentity) or isinstance(b, AdditiveIdentity):
+            # Convert 0@b or a@0 to 0
+            return AdditiveIdentity()
+
+        elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase):
+            # Convert constant "a"@"b" to "a@b"
+            return Constant(a.evaluate({}) @ b.evaluate({}))._clean()
+
+        elif isinstance(a, Neg):
+            # Pull negations outside, keeping the matrix product
+            return Neg(MatMul(a.a, b))
+
+        elif isinstance(b, Neg):
+            return Neg(MatMul(a, b.a))
+
+        return MatMul(a, b)
+
+    @staticmethod
+    def _deserialise(parameters: dict) -> "Operation":
+        return MatMul(*BinaryOperation._deserialise_ab(parameters))
+
+    def _summary_open(self):
+        return "MatMul"
+
+class TensorDot(Operation):
+    # NOTE: _derivative and _clean are not yet implemented for TensorDot
+    serialisation_name = "tensor_product"
+
+    def __init__(self, a: Operation, b: Operation, a_index: int, b_index: int):
+        self.a = a
+        self.b = b
+        self.a_index = a_index
+        self.b_index = b_index
+
+    def evaluate(self, variables: dict[int, T]) -> T:
+        # Evaluate the operands before contracting them
+        return tensordot(self.a.evaluate(variables), self.b.evaluate(variables), self.a_index, self.b_index)
+
+    def _serialise_parameters(self) -> dict[str, Any]:
+        return {
+            "a": self.a._serialise_json(),
+            "b": self.b._serialise_json(),
+            "a_index": self.a_index,
+            "b_index": self.b_index }
+
+    @staticmethod
+    def _deserialise(parameters: dict) -> "Operation":
+        return TensorDot(a = Operation.deserialise_json(parameters["a"]),
+                         b = Operation.deserialise_json(parameters["b"]),
+                         a_index=int(parameters["a_index"]),
+                         b_index=int(parameters["b_index"]))
+
+    def _summary_open(self):
+        return "TensorProduct"
+
+
+_serialisable_classes = [AdditiveIdentity, MultiplicativeIdentity, Constant,
+                         Variable,
+                         Neg, Inv,
+                         Add, Sub, Mul, Div, Pow,
+                         Transpose, Dot, MatMul, TensorDot]
+
+_serialisation_lookup = 
{cls.serialisation_name: cls for cls in _serialisable_classes} + + +class UnitError(Exception): + """ Errors caused by unit specification not being correct """ + +def hash_data_via_numpy(*data: ArrayLike): + + md5_hash = hashlib.md5() + + for datum in data: + data_bytes = np.array(datum).tobytes() + md5_hash.update(data_bytes) + + # Hash function returns a hex string, we want an int + return int(md5_hash.hexdigest(), 16) + + + +##################################### +# # +# # +# # +# Quantities begin here # +# # +# # +# # +##################################### + + + +QuantityType = TypeVar("QuantityType") + + +class QuantityHistory: + """ Class that holds the information for keeping track of operations done on quantities """ + + def __init__(self, operation_tree: Operation, references: dict[int, "Quantity"]): + self.operation_tree = operation_tree + self.references = references + + self.reference_key_list = [key for key in self.references] + self.si_reference_values = {key: self.references[key].in_si() for key in self.references} + + def jacobian(self) -> list[Operation]: + """ Derivative of this quantity's operation history with respect to each of the references """ + + # Use the hash value to specify the variable of differentiation + return [self.operation_tree.derivative(key) for key in self.reference_key_list] + + def _recalculate(self): + """ Recalculate the value of this object - primary use case is for testing """ + return self.operation_tree.evaluate(self.references) + + def variance_propagate(self, quantity_units: Unit, covariances: dict[tuple[int, int]: "Quantity"] = {}): + """ Do standard error propagation to calculate the uncertainties associated with this quantity + + :param quantity_units: units in which the output should be calculated + :param covariances: off diagonal entries for the covariance matrix + """ + + if covariances: + raise NotImplementedError("User specified covariances not currently implemented") + + jacobian = self.jacobian() + + evaluated_jacobian = [entry.evaluate(self.references) for entry in jacobian] + + hash_values = [key for key in self.references] + output = None + + for hash_value, jac_component in zip(hash_values, evaluated_jacobian): + if output is None: + output = jac_component * (self.references[hash_value].variance * jac_component) + else: + output += jac_component * (self.references[hash_value].variance * jac_component) + + return output + + + @staticmethod + def variable(quantity: "Quantity"): + """ Create a history that starts with the provided data """ + return QuantityHistory(Variable(quantity.hash_value), {quantity.hash_value: quantity}) + + @staticmethod + def apply_operation(operation: type[Operation], *histories: "QuantityHistory", **extra_parameters) -> "QuantityHistory": + """ Apply an operation to the history + + This is slightly unsafe as it is possible to attempt to apply an n-ary operation to a number of trees other + than n, but it is relatively concise. Because it is concise we'll go with this for now and see if it causes + any problems down the line. It is a private static method to discourage misuse. + + """ + + # Copy references over, even though it overrides on collision, + # this should behave because only data based variables should be represented. 
+ # Should not be a problem any more than losing histories + references = {} + for history in histories: + references.update(history.references) + + return QuantityHistory( + operation(*[history.operation_tree for history in histories], **extra_parameters), + references) + + def has_variance(self): + for key in self.references: + if self.references[key].has_variance: + return True + + return False + + def summary(self): + + variable_strings = [self.references[key].string_repr for key in self.references] + + s = "Variables: "+",".join(variable_strings) + s += "\n" + s += self.operation_tree.summary() + + return s + + +class Quantity[QuantityType]: + + + def __init__(self, + value: QuantityType, + units: Unit, + standard_error: QuantityType | None = None, + hash_seed = ""): + + self.value = value + """ Numerical value of this data, in the specified units""" + + self.units = units + """ Units of this data """ + + self._hash_seed = hash_seed + """ Retain this for copying operations""" + + self.hash_value = -1 + """ Hash based on value and uncertainty for data, -1 if it is a derived hash value """ + + self._variance = None + """ Contains the variance if it is data driven """ + + if standard_error is None: + self.hash_value = hash_data_via_numpy(hash_seed, value) + else: + self._variance = standard_error ** 2 + self.hash_value = hash_data_via_numpy(hash_seed, value, standard_error) + + self.history = QuantityHistory.variable(self) + + @property + def has_variance(self): + return self._variance is not None + + @property + def variance(self) -> "Quantity": + """ Get the variance of this object""" + if self._variance is None: + return Quantity(np.zeros_like(self.value), self.units**2) + else: + return Quantity(self._variance, self.units**2) + + def standard_deviation(self) -> "Quantity": + return self.variance ** 0.5 + + def in_units_of(self, units: Unit) -> QuantityType: + """ Get this quantity in other units """ + if self.units.equivalent(units): + return (self.units.scale / units.scale) * self.value + else: + raise UnitError(f"Target units ({units}) not compatible with existing units ({self.units}).") + + def to_units_of(self, new_units: Unit) -> "Quantity[QuantityType]": + new_value, new_error = self.in_units_of_with_standard_error(new_units) + return Quantity(value=new_value, + units=new_units, + standard_error=new_error, + hash_seed=self._hash_seed) + + def variance_in_units_of(self, units: Unit) -> QuantityType: + """ Get the variance of quantity in other units """ + variance = self.variance + if variance.units.equivalent(units): + return (variance.units.scale / units.scale) * variance + else: + raise UnitError(f"Target units ({units}) not compatible with existing units ({variance.units}).") + + def in_si(self): + si_units = self.units.si_equivalent() + return self.in_units_of(si_units) + + def in_units_of_with_standard_error(self, units): + variance = self.variance + units_squared = units**2 + + if variance.units.equivalent(units_squared): + + return self.in_units_of(units), np.sqrt(self.variance.in_units_of(units_squared)) + else: + raise UnitError(f"Target units ({units}) not compatible with existing units ({variance.units}).") + + def in_si_with_standard_error(self): + if self.has_variance: + return self.in_units_of_with_standard_error(self.units.si_equivalent()) + else: + return self.in_si(), None + + def __mul__(self: Self, other: ArrayLike | Self ) -> Self: + if isinstance(other, Quantity): + return DerivedQuantity( + self.value * other.value, + self.units * other.units, + 
history=QuantityHistory.apply_operation(Mul, self.history, other.history))
+
+        else:
+            return DerivedQuantity(self.value * other, self.units,
+                                   QuantityHistory(
+                                       Mul(
+                                           self.history.operation_tree,
+                                           Constant(other)),
+                                       self.history.references))
+
+    def __rmul__(self: Self, other: ArrayLike | Self):
+        if isinstance(other, Quantity):
+            return DerivedQuantity(
+                other.value * self.value,
+                other.units * self.units,
+                history=QuantityHistory.apply_operation(
+                    Mul,
+                    other.history,
+                    self.history))
+
+        else:
+            return DerivedQuantity(other * self.value, self.units,
+                                   QuantityHistory(
+                                       Mul(
+                                           Constant(other),
+                                           self.history.operation_tree),
+                                       self.history.references))
+
+    def __matmul__(self, other: ArrayLike | Self):
+        if isinstance(other, Quantity):
+            return DerivedQuantity(
+                self.value @ other.value,
+                self.units * other.units,
+                history=QuantityHistory.apply_operation(
+                    MatMul,
+                    self.history,
+                    other.history))
+        else:
+            return DerivedQuantity(
+                self.value @ other,
+                self.units,
+                QuantityHistory(
+                    MatMul(
+                        self.history.operation_tree,
+                        Constant(other)),
+                    self.history.references))
+
+    def __rmatmul__(self, other: ArrayLike | Self):
+        if isinstance(other, Quantity):
+            return DerivedQuantity(
+                other.value @ self.value,
+                other.units * self.units,
+                history=QuantityHistory.apply_operation(
+                    MatMul,
+                    other.history,
+                    self.history))
+
+        else:
+            return DerivedQuantity(other @ self.value, self.units,
+                                   QuantityHistory(
+                                       MatMul(
+                                           Constant(other),
+                                           self.history.operation_tree),
+                                       self.history.references))
+
+    def __truediv__(self: Self, other: float | Self) -> Self:
+        if isinstance(other, Quantity):
+            return DerivedQuantity(
+                self.value / other.value,
+                self.units / other.units,
+                history=QuantityHistory.apply_operation(
+                    Div,
+                    self.history,
+                    other.history))
+
+        else:
+            # Note the operand order: this is self / other, so the constant
+            # is the denominator in the recorded history
+            return DerivedQuantity(self.value / other, self.units,
+                                   QuantityHistory(
+                                       Div(
+                                           self.history.operation_tree,
+                                           Constant(other)),
+                                       self.history.references))
+
+    def __rtruediv__(self: Self, other: float | Self) -> Self:
+        if isinstance(other, Quantity):
+            return DerivedQuantity(
+                other.value / self.value,
+                other.units / self.units,
+                history=QuantityHistory.apply_operation(
+                    Div,
+                    other.history,
+                    self.history
+                ))
+
+        else:
+            return DerivedQuantity(
+                other / self.value,
+                self.units ** -1,
+                QuantityHistory(
+                    Div(
+                        Constant(other),
+                        self.history.operation_tree),
+                    self.history.references))
+
+    def __add__(self: Self, other: Self | ArrayLike) -> Self:
+        if isinstance(other, Quantity):
+            if self.units.equivalent(other.units):
+                return DerivedQuantity(
+                    self.value + (other.value * other.units.scale) / self.units.scale,
+                    self.units,
+                    QuantityHistory.apply_operation(
+                        Add,
+                        self.history,
+                        other.history))
+            else:
+                raise UnitError(f"Units do not have the same dimensionality: {self.units} vs {other.units}")
+
+        else:
+            raise UnitError(f"Cannot add or subtract a non-quantity {type(other)} and a quantity")
+
+    # Don't need __radd__ because only quantity/quantity operations should be allowed
+
+    def __neg__(self):
+        return DerivedQuantity(-self.value, self.units,
+                               QuantityHistory.apply_operation(
+                                   Neg,
+                                   self.history
+                               ))
+
+    def __sub__(self: Self, other: Self | ArrayLike) -> Self:
+        return self + (-other)
+
+    def __rsub__(self: Self, other: Self | ArrayLike) -> Self:
+        return (-self) + other
+
+    def __pow__(self: Self, other: int | float):
+        return DerivedQuantity(self.value ** other,
+                               self.units ** other,
+                               QuantityHistory(
+                                   Pow(
+                                       self.history.operation_tree,
+                                       other),
+                                   self.history.references))
+
+    @staticmethod
+    def _array_repr_format(arr: np.ndarray):
+        """ Format the array for display """
+        order = len(arr.shape)
+        reshaped = arr.reshape(-1)
+        if len(reshaped) <= 2:
+            numbers = ",".join([f"{n}" for n in reshaped])
+        else:
+            numbers = f"{reshaped[0]} ... {reshaped[-1]}"
+
+        return "["*order + numbers + "]"*order
+
+    def __repr__(self):
+
+        if isinstance(self.units, NamedUnit):
+
+            value = self.value
+            error = self.standard_deviation().in_units_of(self.units)
+            unit_string = self.units.symbol
+
+        else:
+            value, error = self.in_si_with_standard_error()
+            unit_string = self.units.dimensions.si_repr()
+
+        if isinstance(self.value, np.ndarray):
+            # Get the array in short form
+            numeric_string = self._array_repr_format(value)
+
+            if self.has_variance:
+                numeric_string += " ± " + self._array_repr_format(error)
+
+        else:
+            numeric_string = f"{value}"
+            if self.has_variance:
+                numeric_string += f" ± {error}"
+
+        return numeric_string + " " + unit_string
+
+    @staticmethod
+    def parse(number_or_string: str | ArrayLike, unit: str, absolute_temperature: bool = False):
+        pass
+
+    @property
+    def string_repr(self):
+        return str(self.hash_value)
+
+
+class NamedQuantity[QuantityType](Quantity[QuantityType]):
+    def __init__(self,
+                 name: str,
+                 value: QuantityType,
+                 units: Unit,
+                 standard_error: QuantityType | None = None):
+
+        super().__init__(value, units, standard_error=standard_error, hash_seed=name)
+        self.name = name
+
+    def __repr__(self):
+        return f"[{self.name}] " + super().__repr__()
+
+    def to_units_of(self, new_units: Unit) -> "NamedQuantity[QuantityType]":
+        new_value, new_error = self.in_units_of_with_standard_error(new_units)
+        return NamedQuantity(value=new_value,
+                             units=new_units,
+                             standard_error=new_error,
+                             name=self.name)
+
+    def with_standard_error(self, standard_error: Quantity):
+        if standard_error.units.equivalent(self.units):
+            return NamedQuantity(
+                value=self.value,
+                units=self.units,
+                standard_error=standard_error.in_units_of(self.units),
+                name=self.name)
+
+        else:
+            raise UnitError(f"Standard error units ({standard_error.units}) "
+                            f"are not compatible with value units ({self.units})")
+
+    @property
+    def string_repr(self):
+        return self.name
+
+class DerivedQuantity[QuantityType](Quantity[QuantityType]):
+    def __init__(self, value: QuantityType, units: Unit, history: QuantityHistory):
+        super().__init__(value, units, standard_error=None)
+
+        self.history = history
+        self._variance_cache = None
+        self._has_variance = history.has_variance()
+
+    def to_units_of(self, new_units: Unit) -> "Quantity[QuantityType]":
+        # TODO: Lots of tests needed for this
+        return DerivedQuantity(
+            value=self.in_units_of(new_units),
+            units=new_units,
+            history=self.history)
+
+    @property
+    def has_variance(self):
+        return self._has_variance
+
+    @property
+    def variance(self) -> Quantity:
+        if self._variance_cache is None:
+            self._variance_cache = self.history.variance_propagate(self.units)
+
+        return self._variance_cache

From fb6c63501dcd6b456364cf2ebc96776ba209f486 Mon Sep 17 00:00:00 2001
From: lucas-wilkins
Date: Wed, 23 Oct 2024 10:10:03 +0100
Subject: [PATCH 34/35] Fix test import

---
 test/slicers/utest_meshmerge.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/slicers/utest_meshmerge.py b/test/slicers/utest_meshmerge.py
index 21071c0..4e4ee83 100644
--- a/test/slicers/utest_meshmerge.py
+++ b/test/slicers/utest_meshmerge.py
@@ -4,7 +4,7 @@
 It's pretty hard to test componentwise, but we can do some tests of the general behaviour
 """
 
-from sasdata.slicing.meshes import meshmerge
+from sasdata.slicing.meshes.meshmerge import meshmerge
 
 from test.slicers.meshes_for_testing import (
     grid_mesh, shape_mesh, expected_grid_mappings, expected_shape_mappings)

From b134fcdea8509cd71ff922aaa291d58b9063b233 Mon Sep 17 00:00:00 2001
From: lucas-wilkins
Date: Fri, 25 Oct 2024 10:46:07 +0100
Subject: [PATCH 35/35] Work on tests for sparse matrix encoding

---
 sasdata/quantities/test_numerical_encoding.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/sasdata/quantities/test_numerical_encoding.py b/sasdata/quantities/test_numerical_encoding.py
index 83fa5fe..80cfbad 100644
--- a/sasdata/quantities/test_numerical_encoding.py
+++ b/sasdata/quantities/test_numerical_encoding.py
@@ -62,4 +62,12 @@ def test_numpy_dtypes_encode_decode(dtype):
     ((6, 1), (1, 0, 5), (0, 0, 0)),
 ])
 def test_coo_matrix_encode_decode(shape, n, m, dtype):
-    test_matrix = np.arange()
+    # Work in progress - a plausible completion: this assumes n and m are the
+    # row and column indices of the stored entries, and that coo_matrix is
+    # imported from scipy.sparse at the top of the file
+    values = np.ones(len(n), dtype=dtype)
+    test_matrix = coo_matrix((values, (n, m)), shape=shape)
+
+    decoded = numerical_decode(numerical_encode(test_matrix))
+
+    assert np.all(decoded.toarray() == test_matrix.toarray())
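
The round-trip property this last test is aiming at can be sketched independently of the patch. This is a minimal, untested sketch: it assumes only that numerical_encode/numerical_decode accept scipy COO matrices and that decoding returns a sparse matrix supporting toarray().

    import numpy as np
    from scipy.sparse import coo_matrix

    from sasdata.quantities.numerical_encoding import numerical_decode, numerical_encode

    # A small COO matrix: values 1.0 and 2.0 stored at (0, 1) and (3, 0)
    matrix = coo_matrix((np.array([1.0, 2.0]), ((0, 3), (1, 0))), shape=(4, 2))

    # Encode to a JSON-friendly structure, then decode it back again
    round_tripped = numerical_decode(numerical_encode(matrix))

    # The decoded matrix should match the original entry-by-entry
    assert (matrix.toarray() == round_tripped.toarray()).all()

Running the suite with pytest (e.g. pytest sasdata/quantities/test_numerical_encoding.py) exercises the same property across the parametrised shapes, indices and dtypes.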