diff --git a/requirements.txt b/requirements.txt index 0eb2649..389c900 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ lxml # Calculation numpy==1.* +scipy # Unit testing pytest @@ -11,3 +12,6 @@ unittest-xml-reporting # Documentation sphinx + +# Other stuff +matplotlib diff --git a/sasdata/data.py b/sasdata/data.py index 7f0cbfb..544ba27 100644 --- a/sasdata/data.py +++ b/sasdata/data.py @@ -2,6 +2,8 @@ from typing import TypeVar, Any, Self from dataclasses import dataclass +import numpy as np + from quantities.quantity import NamedQuantity from sasdata.metadata import Metadata from sasdata.quantities.accessors import AccessorTarget @@ -9,7 +11,11 @@ class SasData: - def __init__(self, name: str, data_contents: list[NamedQuantity], raw_metadata: Group, verbose: bool=False): + def __init__(self, name: str, + data_contents: list[NamedQuantity], + raw_metadata: Group, + verbose: bool=False): + self.name = name self._data_contents = data_contents self._raw_metadata = raw_metadata @@ -17,14 +23,11 @@ def __init__(self, name: str, data_contents: list[NamedQuantity], raw_metadata: self.metadata = Metadata(AccessorTarget(raw_metadata, verbose=verbose)) - # TO IMPLEMENT - - # abscissae: list[NamedQuantity[np.ndarray]] - # ordinate: NamedQuantity[np.ndarray] - # other: list[NamedQuantity[np.ndarray]] - # - # metadata: Metadata - # model_requirements: ModellingRequirements + # Components that need to be organised after creation + self.ordinate: NamedQuantity[np.ndarray] = None # TODO: fill out + self.abscissae: list[NamedQuantity[np.ndarray]] = None # TODO: fill out + self.mask = None # TODO: fill out + self.model_requirements = None # TODO: fill out def summary(self, indent = " ", include_raw=False): s = f"{self.name}\n" diff --git a/sasdata/manual_tests/__init__.py b/sasdata/manual_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sasdata/manual_tests/interpolation.py b/sasdata/manual_tests/interpolation.py new file mode 100644 index 0000000..c46078b --- /dev/null +++ b/sasdata/manual_tests/interpolation.py @@ -0,0 +1,44 @@ +import numpy as np +import matplotlib.pyplot as plt + +from sasdata.quantities.quantity import NamedQuantity +from sasdata.quantities.plotting import quantity_plot +from sasdata.quantities import units + +from sasdata.transforms.rebinning import calculate_interpolation_matrix_1d +from sasdata.transforms.rebinning import InterpolationOptions + +def linear_interpolation_check(): + + for from_bins in [(-10, 10, 10), + (-10, 10, 1000), + (-15, 5, 10), + (15,5, 10)]: + for to_bins in [ + (-15, 0, 10), + (-15, 15, 10), + (0, 20, 100)]: + + plt.figure() + + x = NamedQuantity("x", np.linspace(*from_bins), units=units.meters) + y = x**2 + + quantity_plot(x, y) + + new_x = NamedQuantity("x_new", np.linspace(*to_bins), units=units.meters) + + rebin_mat = calculate_interpolation_matrix_1d(x, new_x, order=InterpolationOptions.LINEAR) + + new_y = y @ rebin_mat + + quantity_plot(new_x, new_y) + + print(new_y.history.summary()) + + plt.show() + + + + +linear_interpolation_check() \ No newline at end of file diff --git a/sasdata/model_requirements.py b/sasdata/model_requirements.py index 5d68ad1..12ad545 100644 --- a/sasdata/model_requirements.py +++ b/sasdata/model_requirements.py @@ -3,7 +3,7 @@ import numpy as np from sasdata.metadata import Metadata -from transforms.operation import Operation +from sasdata.quantities.quantity import Operation @dataclass diff --git a/sasdata/quantities/math_operations_test.py 
b/sasdata/quantities/math_operations_test.py
new file mode 100644
index 0000000..5bda5a2
--- /dev/null
+++ b/sasdata/quantities/math_operations_test.py
@@ -0,0 +1,152 @@
+""" Tests for math operations """
+
+import pytest
+
+import numpy as np
+from sasdata.quantities.quantity import NamedQuantity, tensordot, transpose
+from sasdata.quantities import units
+
+order_list = [
+    [0, 1, 2, 3],
+    [0, 2, 1],
+    [1, 0],
+    [0, 1],
+    [2, 0, 1],
+    [3, 1, 2, 0]
+]
+
+@pytest.mark.parametrize("order", order_list)
+def test_transpose_raw(order: list[int]):
+    """ Check that the transpose operation changes the order of indices correctly - uses sizes as a way of tracking"""
+
+    input_shape = tuple([i+1 for i in range(len(order))])
+    expected_shape = tuple([i+1 for i in order])
+
+    input_mat = np.zeros(input_shape)
+
+    measured_mat = transpose(input_mat, axes=tuple(order))
+
+    assert measured_mat.shape == expected_shape
+
+
+@pytest.mark.parametrize("order", order_list)
+def test_transpose_quantity(order: list[int]):
+    """ Check that transposing a quantity changes the order of indices correctly - uses sizes as a way of tracking"""
+    input_shape = tuple([i + 1 for i in range(len(order))])
+    expected_shape = tuple([i + 1 for i in order])
+
+    input_mat = NamedQuantity("testmat", np.zeros(input_shape), units=units.none)
+
+    measured_mat = transpose(input_mat, axes=tuple(order))
+
+    assert measured_mat.value.shape == expected_shape
+
+
+rng_seed = 1979
+tensor_product_with_identity_sizes = (4,6,5)
+
+@pytest.mark.parametrize("index, size", [tup for tup in enumerate(tensor_product_with_identity_sizes)])
+def test_tensor_product_with_identity_quantities(index, size):
+    """ Check the correctness of the tensor product by multiplying by the identity (quantity, quantity)"""
+    np.random.seed(rng_seed)
+
+    x = NamedQuantity("x", np.random.rand(*tensor_product_with_identity_sizes), units=units.meters)
+    y = NamedQuantity("y", np.eye(size), units.seconds)
+
+    z = tensordot(x, y, index, 0)
+
+    # Check units
+    assert z.units == units.meters * units.seconds
+
+    # Expected sizes - last index gets moved to end
+    output_order = [i for i in (0, 1, 2) if i != index] + [index]
+    output_sizes = [tensor_product_with_identity_sizes[i] for i in output_order]
+
+    assert z.value.shape == tuple(output_sizes)
+
+    # Restore original order and check
+    reverse_order = [-1, -1, -1]
+    for to_index, from_index in enumerate(output_order):
+        reverse_order[from_index] = to_index
+
+    z_reordered = transpose(z, axes = tuple(reverse_order))
+
+    assert z_reordered.value.shape == tensor_product_with_identity_sizes
+
+    # Check values
+
+    mat_in = x.in_si()
+    mat_out = transpose(z, axes=tuple(reverse_order)).in_si()
+
+    assert np.all(np.abs(mat_in - mat_out) < 1e-10)
+
+
+@pytest.mark.parametrize("index, size", [tup for tup in enumerate(tensor_product_with_identity_sizes)])
+def test_tensor_product_with_identity_quantity_matrix(index, size):
+    """ Check the correctness of the tensor product by multiplying by the identity (quantity, matrix)"""
+    np.random.seed(rng_seed)
+
+    x = NamedQuantity("x", np.random.rand(*tensor_product_with_identity_sizes), units.meters)
+    y = np.eye(size)
+
+    z = tensordot(x, y, index, 0)
+
+    assert z.units == units.meters
+
+    # Expected sizes - last index gets moved to end
+    output_order = [i for i in (0, 1, 2) if i != index] + [index]
+    output_sizes = [tensor_product_with_identity_sizes[i] for i in output_order]
+
+    assert z.value.shape == tuple(output_sizes)
+
+    # Restore original order and check
+    reverse_order = [-1, -1, -1]
+    for 
to_index, from_index in enumerate(output_order): + reverse_order[from_index] = to_index + + z_reordered = transpose(z, axes = tuple(reverse_order)) + + assert z_reordered.value.shape == tensor_product_with_identity_sizes + + # Check values + + mat_in = x.in_si() + mat_out = transpose(z, axes=tuple(reverse_order)).in_si() + + assert np.all(np.abs(mat_in - mat_out) < 1e-10) + + +@pytest.mark.parametrize("index, size", [tup for tup in enumerate(tensor_product_with_identity_sizes)]) +def test_tensor_product_with_identity_matrix_quantity(index, size): + """ Check the correctness of the tensor product by multiplying by the identity (matrix, quantity)""" + np.random.seed(rng_seed) + + x = np.random.rand(*tensor_product_with_identity_sizes) + y = NamedQuantity("y", np.eye(size), units.seconds) + + z = tensordot(x, y, index, 0) + + assert z.units == units.seconds + + + # Expected sizes - last index gets moved to end + output_order = [i for i in (0, 1, 2) if i != index] + [index] + output_sizes = [tensor_product_with_identity_sizes[i] for i in output_order] + + assert z.value.shape == tuple(output_sizes) + + # Restore original order and check + reverse_order = [-1, -1, -1] + for to_index, from_index in enumerate(output_order): + reverse_order[from_index] = to_index + + z_reordered = transpose(z, axes = tuple(reverse_order)) + + assert z_reordered.value.shape == tensor_product_with_identity_sizes + + # Check values + + mat_in = x + mat_out = transpose(z, axes=tuple(reverse_order)).in_si() + + assert np.all(np.abs(mat_in - mat_out) < 1e-10) diff --git a/sasdata/quantities/numerical_encoding.py b/sasdata/quantities/numerical_encoding.py new file mode 100644 index 0000000..879880a --- /dev/null +++ b/sasdata/quantities/numerical_encoding.py @@ -0,0 +1,72 @@ +import numpy as np +from scipy.sparse import coo_matrix, csr_matrix, csc_matrix, coo_array, csr_array, csc_array + +import base64 +import struct + + +def numerical_encode(obj: int | float | np.ndarray | coo_matrix | coo_array | csr_matrix | csr_array | csc_matrix | csc_array): + + if isinstance(obj, int): + return {"type": "int", + "value": obj} + + elif isinstance(obj, float): + return {"type": "float", + "value": base64.b64encode(bytearray(struct.pack('d', obj)))} + + elif isinstance(obj, np.ndarray): + return { + "type": "numpy", + "value": base64.b64encode(obj.tobytes()), + "dtype": obj.dtype.str, + "shape": list(obj.shape) + } + + elif isinstance(obj, (coo_matrix, coo_array, csr_matrix, csr_array, csc_matrix, csc_array)): + + output = { + "type": obj.__class__.__name__, # not robust to name changes, but more concise + "dtype": obj.dtype.str, + "shape": list(obj.shape) + } + + if isinstance(obj, (coo_array, coo_matrix)): + + output["data"] = numerical_encode(obj.data) + output["coords"] = [numerical_encode(coord) for coord in obj.coords] + + + elif isinstance(obj, (csr_array, csr_matrix)): + pass + + + elif isinstance(obj, (csc_array, csc_matrix)): + + pass + + + return output + + else: + raise TypeError(f"Cannot serialise object of type: {type(obj)}") + +def numerical_decode(data: dict[str, str | int | list[int]]) -> int | float | np.ndarray | coo_matrix | coo_array | csr_matrix | csr_array | csc_matrix | csc_array: + obj_type = data["type"] + + match obj_type: + case "int": + return int(data["value"]) + + case "float": + return struct.unpack('d', base64.b64decode(data["value"]))[0] + + case "numpy": + value = base64.b64decode(data["value"]) + dtype = np.dtype(data["dtype"]) + shape = tuple(data["shape"]) + return np.frombuffer(value, 
dtype=dtype).reshape(*shape) + + case _: + raise ValueError(f"Cannot decode objects of type '{obj_type}'") + diff --git a/sasdata/quantities/operations.py b/sasdata/quantities/operations.py deleted file mode 100644 index 8dbd82f..0000000 --- a/sasdata/quantities/operations.py +++ /dev/null @@ -1,710 +0,0 @@ -from typing import Any, TypeVar, Union - -import json - -T = TypeVar("T") - -def hash_and_name(hash_or_name: int | str): - """ Infer the name of a variable from a hash, or the hash from the name - - Note: hash_and_name(hash_and_name(number)[1]) is not the identity - however: hash_and_name(hash_and_name(number)) is - """ - - if isinstance(hash_or_name, str): - hash_value = hash(hash_or_name) - name = hash_or_name - - return hash_value, name - - elif isinstance(hash_or_name, int): - hash_value = hash_or_name - name = f"#{hash_or_name}" - - return hash_value, name - - elif isinstance(hash_or_name, tuple): - return hash_or_name - - else: - raise TypeError("Variable name_or_hash_value must be either str or int") - - -class Operation: - - serialisation_name = "unknown" - def summary(self, indent_amount: int = 0, indent: str=" "): - """ Summary of the operation tree""" - - s = f"{indent_amount*indent}{self._summary_open()}(\n" - - for chunk in self._summary_components(): - s += chunk.summary(indent_amount+1, indent) + "\n" - - s += f"{indent_amount*indent})" - - return s - def _summary_open(self): - """ First line of summary """ - - def _summary_components(self) -> list["Operation"]: - return [] - def evaluate(self, variables: dict[int, T]) -> T: - - """ Evaluate this operation """ - - def _derivative(self, hash_value: int) -> "Operation": - """ Get the derivative of this operation """ - - def _clean(self): - """ Clean up this operation - i.e. remove silly things like 1*x """ - return self - - def derivative(self, variable: Union[str, int, "Variable"], simplify=True): - if isinstance(variable, Variable): - hash_value = variable.hash_value - else: - hash_value, _ = hash_and_name(variable) - - derivative = self._derivative(hash_value) - - if not simplify: - return derivative - - derivative_string = derivative.serialise() - - # print("---------------") - # print("Base") - # print("---------------") - # print(derivative.summary()) - - # Inefficient way of doing repeated simplification, but it will work - for i in range(100): # set max iterations - - derivative = derivative._clean() - # - # print("-------------------") - # print("Iteration", i+1) - # print("-------------------") - # print(derivative.summary()) - # print("-------------------") - - new_derivative_string = derivative.serialise() - - if derivative_string == new_derivative_string: - break - - derivative_string = new_derivative_string - - return derivative - - @staticmethod - def deserialise(data: str) -> "Operation": - json_data = json.loads(data) - return Operation.deserialise_json(json_data) - - @staticmethod - def deserialise_json(json_data: dict) -> "Operation": - - operation = json_data["operation"] - parameters = json_data["parameters"] - cls = _serialisation_lookup[operation] - - try: - return cls._deserialise(parameters) - - except NotImplementedError: - raise NotImplementedError(f"No method to deserialise {operation} with {parameters} (cls={cls})") - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - raise NotImplementedError(f"Deserialise not implemented for this class") - - def serialise(self) -> str: - return json.dumps(self._serialise_json()) - - def _serialise_json(self) -> dict[str, Any]: - return 
{"operation": self.serialisation_name, - "parameters": self._serialise_parameters()} - - def _serialise_parameters(self) -> dict[str, Any]: - raise NotImplementedError("_serialise_parameters not implemented") - - def __eq__(self, other: "Operation"): - return NotImplemented - -class ConstantBase(Operation): - pass - -class AdditiveIdentity(ConstantBase): - - serialisation_name = "zero" - def evaluate(self, variables: dict[int, T]) -> T: - return 0 - - def _derivative(self, hash_value: int) -> Operation: - return AdditiveIdentity() - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return AdditiveIdentity() - - def _serialise_parameters(self) -> dict[str, Any]: - return {} - - def summary(self, indent_amount: int=0, indent=" "): - return f"{indent_amount*indent}0 [Add.Id.]" - - def __eq__(self, other): - if isinstance(other, AdditiveIdentity): - return True - elif isinstance(other, Constant): - if other.value == 0: - return True - - return False - - - -class MultiplicativeIdentity(ConstantBase): - - serialisation_name = "one" - - def evaluate(self, variables: dict[int, T]) -> T: - return 1 - - def _derivative(self, hash_value: int): - return AdditiveIdentity() - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return MultiplicativeIdentity() - - - def _serialise_parameters(self) -> dict[str, Any]: - return {} - - - def summary(self, indent_amount: int=0, indent=" "): - return f"{indent_amount*indent}1 [Mul.Id.]" - - def __eq__(self, other): - if isinstance(other, MultiplicativeIdentity): - return True - elif isinstance(other, Constant): - if other.value == 1: - return True - - return False - - -class Constant(ConstantBase): - - serialisation_name = "constant" - def __init__(self, value): - self.value = value - - def summary(self, indent_amount: int = 0, indent: str=" "): - pass - - def evaluate(self, variables: dict[int, T]) -> T: - return self.value - - def _derivative(self, hash_value: int): - return AdditiveIdentity() - - def _clean(self): - - if self.value == 0: - return AdditiveIdentity() - - elif self.value == 1: - return MultiplicativeIdentity() - - else: - return self - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - value = parameters["value"] - return Constant(value) - - - def _serialise_parameters(self) -> dict[str, Any]: - return {"value": self.value} - - - def summary(self, indent_amount: int=0, indent=" "): - return f"{indent_amount*indent}{self.value}" - - def __eq__(self, other): - if isinstance(other, AdditiveIdentity): - return self.value == 0 - - elif isinstance(other, MultiplicativeIdentity): - return self.value == 1 - - elif isinstance(other, Constant): - if other.value == self.value: - return True - - return False - - -class Variable(Operation): - - serialisation_name = "variable" - def __init__(self, name_or_hash_value: int | str | tuple[int, str]): - self.hash_value, self.name = hash_and_name(name_or_hash_value) - - def evaluate(self, variables: dict[int, T]) -> T: - try: - return variables[self.hash_value] - except KeyError: - raise ValueError(f"Variable dictionary didn't have an entry for {self.name} (hash={self.hash_value})") - - def _derivative(self, hash_value: int) -> Operation: - if hash_value == self.hash_value: - return MultiplicativeIdentity() - else: - return AdditiveIdentity() - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - hash_value = parameters["hash_value"] - name = parameters["name"] - - return Variable((hash_value, name)) - - def _serialise_parameters(self) 
-> dict[str, Any]: - return {"hash_value": self.hash_value, - "name": self.name} - - def summary(self, indent_amount: int = 0, indent: str=" "): - return f"{indent_amount*indent}{self.name}" - - def __eq__(self, other): - if isinstance(other, Variable): - return self.hash_value == other.hash_value - - return False - -class UnaryOperation(Operation): - - def __init__(self, a: Operation): - self.a = a - - def _serialise_parameters(self) -> dict[str, Any]: - return {"a": self.a._serialise_json()} - - def _summary_components(self) -> list["Operation"]: - return [self.a] - - - - -class Neg(UnaryOperation): - - serialisation_name = "neg" - def evaluate(self, variables: dict[int, T]) -> T: - return -self.a.evaluate(variables) - - def _derivative(self, hash_value: int): - return Neg(self.a._derivative(hash_value)) - - def _clean(self): - - clean_a = self.a._clean() - - if isinstance(clean_a, Neg): - # Removes double negations - return clean_a.a - - elif isinstance(clean_a, Constant): - return Constant(-clean_a.value)._clean() - - else: - return Neg(clean_a) - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Neg(Operation.deserialise_json(parameters["a"])) - - - def _summary_open(self): - return "Neg" - - def __eq__(self, other): - if isinstance(other, Neg): - return other.a == self.a - - -class Inv(UnaryOperation): - - serialisation_name = "reciprocal" - - def evaluate(self, variables: dict[int, T]) -> T: - return 1/self.a.evaluate(variables) - - def _derivative(self, hash_value: int) -> Operation: - return Neg(Div(self.a._derivative(hash_value), Mul(self.a, self.a))) - - def _clean(self): - clean_a = self.a._clean() - - if isinstance(clean_a, Inv): - # Removes double negations - return clean_a.a - - elif isinstance(clean_a, Neg): - # cannonicalise 1/-a to -(1/a) - # over multiple iterations this should have the effect of ordering and gathering Neg and Inv - return Neg(Inv(clean_a.a)) - - elif isinstance(clean_a, Constant): - return Constant(1/clean_a.value)._clean() - - else: - return Inv(clean_a) - - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Inv(Operation.deserialise_json(parameters["a"])) - - def _summary_open(self): - return "Inv" - - - def __eq__(self, other): - if isinstance(other, Inv): - return other.a == self.a - -class BinaryOperation(Operation): - def __init__(self, a: Operation, b: Operation): - self.a = a - self.b = b - - def _clean(self): - return self._clean_ab(self.a._clean(), self.b._clean()) - - def _clean_ab(self, a, b): - raise NotImplementedError("_clean_ab not implemented") - - def _serialise_parameters(self) -> dict[str, Any]: - return {"a": self.a._serialise_json(), - "b": self.b._serialise_json()} - - @staticmethod - def _deserialise_ab(parameters) -> tuple[Operation, Operation]: - return (Operation.deserialise_json(parameters["a"]), - Operation.deserialise_json(parameters["b"])) - - - def _summary_components(self) -> list["Operation"]: - return [self.a, self.b] - - def _self_cls(self) -> type: - """ Own class""" - def __eq__(self, other): - if isinstance(other, self._self_cls()): - return other.a == self.a and self.b == other.b - -class Add(BinaryOperation): - - serialisation_name = "add" - - def _self_cls(self) -> type: - return Add - def evaluate(self, variables: dict[int, T]) -> T: - return self.a.evaluate(variables) + self.b.evaluate(variables) - - def _derivative(self, hash_value: int) -> Operation: - return Add(self.a._derivative(hash_value), self.b._derivative(hash_value)) - - def _clean_ab(self, a, 
b): - - if isinstance(a, AdditiveIdentity): - # Convert 0 + b to b - return b - - elif isinstance(b, AdditiveIdentity): - # Convert a + 0 to a - return a - - elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): - # Convert constant "a"+"b" to "a+b" - return Constant(a.evaluate({}) + b.evaluate({}))._clean() - - elif isinstance(a, Neg): - if isinstance(b, Neg): - # Convert (-a)+(-b) to -(a+b) - return Neg(Add(a.a, b.a)) - else: - # Convert (-a) + b to b-a - return Sub(b, a.a) - - elif isinstance(b, Neg): - # Convert a+(-b) to a-b - return Sub(a, b.a) - - elif a == b: - return Mul(Constant(2), a) - - else: - return Add(a, b) - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Add(*BinaryOperation._deserialise_ab(parameters)) - - def _summary_open(self): - return "Add" - -class Sub(BinaryOperation): - - serialisation_name = "sub" - - - def _self_cls(self) -> type: - return Sub - def evaluate(self, variables: dict[int, T]) -> T: - return self.a.evaluate(variables) - self.b.evaluate(variables) - - def _derivative(self, hash_value: int) -> Operation: - return Sub(self.a._derivative(hash_value), self.b._derivative(hash_value)) - - def _clean_ab(self, a, b): - if isinstance(a, AdditiveIdentity): - # Convert 0 - b to -b - return Neg(b) - - elif isinstance(b, AdditiveIdentity): - # Convert a - 0 to a - return a - - elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): - # Convert constant pair "a" - "b" to "a-b" - return Constant(a.evaluate({}) - b.evaluate({}))._clean() - - elif isinstance(a, Neg): - if isinstance(b, Neg): - # Convert (-a)-(-b) to b-a - return Sub(b.a, a.a) - else: - # Convert (-a)-b to -(a+b) - return Neg(Add(a.a, b)) - - elif isinstance(b, Neg): - # Convert a-(-b) to a+b - return Add(a, b.a) - - elif a == b: - return AdditiveIdentity() - - else: - return Sub(a, b) - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Sub(*BinaryOperation._deserialise_ab(parameters)) - - - def _summary_open(self): - return "Sub" - -class Mul(BinaryOperation): - - serialisation_name = "mul" - - - def _self_cls(self) -> type: - return Mul - def evaluate(self, variables: dict[int, T]) -> T: - return self.a.evaluate(variables) * self.b.evaluate(variables) - - def _derivative(self, hash_value: int) -> Operation: - return Add(Mul(self.a, self.b._derivative(hash_value)), Mul(self.a._derivative(hash_value), self.b)) - - def _clean_ab(self, a, b): - - if isinstance(a, AdditiveIdentity) or isinstance(b, AdditiveIdentity): - # Convert 0*b or a*0 to 0 - return AdditiveIdentity() - - elif isinstance(a, MultiplicativeIdentity): - # Convert 1*b to b - return b - - elif isinstance(b, MultiplicativeIdentity): - # Convert a*1 to a - return a - - elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): - # Convert constant "a"*"b" to "a*b" - return Constant(a.evaluate({}) * b.evaluate({}))._clean() - - elif isinstance(a, Inv) and isinstance(b, Inv): - return Inv(Mul(a.a, b.a)) - - elif isinstance(a, Inv) and not isinstance(b, Inv): - return Div(b, a.a) - - elif not isinstance(a, Inv) and isinstance(b, Inv): - return Div(a, b.a) - - elif isinstance(a, Neg): - return Neg(Mul(a.a, b)) - - elif isinstance(b, Neg): - return Neg(Mul(a, b.a)) - - elif a == b: - return Pow(a, 2) - - elif isinstance(a, Pow) and a.a == b: - return Pow(b, a.power + 1) - - elif isinstance(b, Pow) and b.a == a: - return Pow(a, b.power + 1) - - elif isinstance(a, Pow) and isinstance(b, Pow) and a.a == b.a: - return Pow(a.a, a.power + b.power) - - else: - 
return Mul(a, b) - - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Mul(*BinaryOperation._deserialise_ab(parameters)) - - - def _summary_open(self): - return "Mul" - -class Div(BinaryOperation): - - serialisation_name = "div" - - - def _self_cls(self) -> type: - return Div - - def evaluate(self, variables: dict[int, T]) -> T: - return self.a.evaluate(variables) / self.b.evaluate(variables) - - def _derivative(self, hash_value: int) -> Operation: - return Sub(Div(self.a.derivative(hash_value), self.b), - Div(Mul(self.a, self.b.derivative(hash_value)), Mul(self.b, self.b))) - - def _clean_ab(self, a, b): - if isinstance(a, AdditiveIdentity): - # Convert 0/b to 0 - return AdditiveIdentity() - - elif isinstance(a, MultiplicativeIdentity): - # Convert 1/b to inverse of b - return Inv(b) - - elif isinstance(b, MultiplicativeIdentity): - # Convert a/1 to a - return a - - elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): - # Convert constants "a"/"b" to "a/b" - return Constant(self.a.evaluate({}) / self.b.evaluate({}))._clean() - - - elif isinstance(a, Inv) and isinstance(b, Inv): - return Div(b.a, a.a) - - elif isinstance(a, Inv) and not isinstance(b, Inv): - return Inv(Mul(a.a, b)) - - elif not isinstance(a, Inv) and isinstance(b, Inv): - return Mul(a, b.a) - - elif a == b: - return MultiplicativeIdentity() - - elif isinstance(a, Pow) and a.a == b: - return Pow(b, a.power - 1) - - elif isinstance(b, Pow) and b.a == a: - return Pow(a, 1 - b.power) - - elif isinstance(a, Pow) and isinstance(b, Pow) and a.a == b.a: - return Pow(a.a, a.power - b.power) - - else: - return Div(a, b) - - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Div(*BinaryOperation._deserialise_ab(parameters)) - - def _summary_open(self): - return "Div" - -class Pow(Operation): - - serialisation_name = "pow" - - def __init__(self, a: Operation, power: float): - self.a = a - self.power = power - def evaluate(self, variables: dict[int, T]) -> T: - return self.a.evaluate(variables) ** self.power - - def _derivative(self, hash_value: int) -> Operation: - if self.power == 0: - return AdditiveIdentity() - - elif self.power == 1: - return self.a._derivative(hash_value) - - else: - return Mul(Constant(self.power), Mul(Pow(self.a, self.power-1), self.a._derivative(hash_value))) - - def _clean(self) -> Operation: - a = self.a._clean() - - if self.power == 1: - return a - - elif self.power == 0: - return MultiplicativeIdentity() - - elif self.power == -1: - return Inv(a) - - else: - return Pow(a, self.power) - - - def _serialise_parameters(self) -> dict[str, Any]: - return {"a": Operation._serialise_json(self.a), - "power": self.power} - - @staticmethod - def _deserialise(parameters: dict) -> "Operation": - return Pow(Operation.deserialise_json(parameters["a"]), parameters["power"]) - - def summary(self, indent_amount: int=0, indent=" "): - return (f"{indent_amount*indent}Pow\n" + - self.a.summary(indent_amount+1, indent) + "\n" + - f"{(indent_amount+1)*indent}{self.power}\n" + - f"{indent_amount*indent})") - - def __eq__(self, other): - if isinstance(other, Pow): - return self.a == other.a and self.power == other.power - -_serialisable_classes = [AdditiveIdentity, MultiplicativeIdentity, Constant, - Variable, - Neg, Inv, - Add, Sub, Mul, Div, Pow] - -_serialisation_lookup = {cls.serialisation_name: cls for cls in _serialisable_classes} diff --git a/sasdata/quantities/operations_examples.py b/sasdata/quantities/operations_examples.py index 4509a86..e2e2566 100644 
--- a/sasdata/quantities/operations_examples.py +++ b/sasdata/quantities/operations_examples.py @@ -1,4 +1,4 @@ -from sasdata.quantities.operations import Variable, Mul +from sasdata.quantities.quantity import Variable, Mul x = Variable("x") y = Variable("y") diff --git a/sasdata/quantities/operations_test.py b/sasdata/quantities/operations_test.py index 0899eee..6767e32 100644 --- a/sasdata/quantities/operations_test.py +++ b/sasdata/quantities/operations_test.py @@ -1,6 +1,6 @@ import pytest -from sasdata.quantities.operations import Operation, \ +from sasdata.quantities.quantity import Operation, \ Neg, Inv, \ Add, Sub, Mul, Div, Pow, \ Variable, Constant, AdditiveIdentity, MultiplicativeIdentity diff --git a/sasdata/quantities/plotting.py b/sasdata/quantities/plotting.py new file mode 100644 index 0000000..854e23f --- /dev/null +++ b/sasdata/quantities/plotting.py @@ -0,0 +1,23 @@ +import matplotlib.pyplot as plt +from numpy.typing import ArrayLike + +from sasdata.quantities.quantity import Quantity, NamedQuantity + + +def quantity_plot(x: Quantity[ArrayLike], y: Quantity[ArrayLike], *args, **kwargs): + plt.plot(x.value, y.value, *args, **kwargs) + + x_name = x.name if isinstance(x, NamedQuantity) else "x" + y_name = y.name if isinstance(y, NamedQuantity) else "y" + + plt.xlabel(f"{x_name} / {x.units}") + plt.ylabel(f"{y_name} / {y.units}") + +def quantity_scatter(x: Quantity[ArrayLike], y: Quantity[ArrayLike], *args, **kwargs): + plt.scatter(x.value, y.value, *args, **kwargs) + + x_name = x.name if isinstance(x, NamedQuantity) else "x" + y_name = y.name if isinstance(y, NamedQuantity) else "y" + + plt.xlabel(f"{x_name} / {x.units}") + plt.ylabel(f"{y_name} / {y.units}") diff --git a/sasdata/quantities/quantity.py b/sasdata/quantities/quantity.py index 20317cf..584f3cf 100644 --- a/sasdata/quantities/quantity.py +++ b/sasdata/quantities/quantity.py @@ -1,15 +1,969 @@ -from typing import Collection, Sequence, TypeVar, Generic, Self -from dataclasses import dataclass +from typing import Self import numpy as np from numpy._typing import ArrayLike -from sasdata.quantities.operations import Operation, Variable -from sasdata.quantities import operations, units +from sasdata.quantities import units +from sasdata.quantities.numerical_encoding import numerical_decode, numerical_encode from sasdata.quantities.units import Unit, NamedUnit import hashlib +from typing import Any, TypeVar, Union + +import json + +T = TypeVar("T") + + + + + +################### Quantity based operations, need to be here to avoid cyclic dependencies ##################### + +def transpose(a: Union["Quantity[ArrayLike]", ArrayLike], axes: tuple | None = None): + """ Transpose an array or an array based quantity, can also do reordering of axes""" + if isinstance(a, Quantity): + + if axes is None: + return DerivedQuantity(value=np.transpose(a.value, axes=axes), + units=a.units, + history=QuantityHistory.apply_operation(Transpose, a.history)) + + else: + return DerivedQuantity(value=np.transpose(a.value, axes=axes), + units=a.units, + history=QuantityHistory.apply_operation(Transpose, a.history, axes=axes)) + + else: + return np.transpose(a, axes=axes) + + +def dot(a: Union["Quantity[ArrayLike]", ArrayLike], b: Union["Quantity[ArrayLike]", ArrayLike]): + """ Dot product of two arrays or two array based quantities """ + a_is_quantity = isinstance(a, Quantity) + b_is_quantity = isinstance(b, Quantity) + + if a_is_quantity or b_is_quantity: + + # If its only one of them that is a quantity, convert the other one + + if not 
a_is_quantity: + a = Quantity(a, units.none) + + if not b_is_quantity: + b = Quantity(b, units.none) + + return DerivedQuantity( + value=np.dot(a.value, b.value), + units=a.units * b.units, + history=QuantityHistory.apply_operation(Dot, a.history, b.history)) + + else: + return np.dot(a, b) + +def tensordot(a: Union["Quantity[ArrayLike]", ArrayLike] | ArrayLike, b: Union["Quantity[ArrayLike]", ArrayLike], a_index: int, b_index: int): + """ Tensor dot product - equivalent to contracting two tensors, such as + + A_{i0, i1, i2, i3...} and B_{j0, j1, j2...} + + e.g. if a_index is 1 and b_index is zero, it will be the sum + + C_{i0, i2, i3 ..., j1, j2 ...} = sum_k A_{i0, k, i2, i3 ...} B_{k, j1, j2 ...} + + (I think, have to check what happens with indices TODO!) + + """ + + a_is_quantity = isinstance(a, Quantity) + b_is_quantity = isinstance(b, Quantity) + + if a_is_quantity or b_is_quantity: + + # If its only one of them that is a quantity, convert the other one + + if not a_is_quantity: + a = Quantity(a, units.none) + + if not b_is_quantity: + b = Quantity(b, units.none) + + return DerivedQuantity( + value=np.tensordot(a.value, b.value, axes=(a_index, b_index)), + units=a.units * b.units, + history=QuantityHistory.apply_operation( + TensorDot, + a.history, + b.history, + a_index=a_index, + b_index=b_index)) + + else: + return np.tensordot(a, b, axes=(a_index, b_index)) + + +################### Operation Definitions ####################################### + +def hash_and_name(hash_or_name: int | str): + """ Infer the name of a variable from a hash, or the hash from the name + + Note: hash_and_name(hash_and_name(number)[1]) is not the identity + however: hash_and_name(hash_and_name(number)) is + """ + + if isinstance(hash_or_name, str): + hash_value = hash(hash_or_name) + name = hash_or_name + + return hash_value, name + + elif isinstance(hash_or_name, int): + hash_value = hash_or_name + name = f"#{hash_or_name}" + + return hash_value, name + + elif isinstance(hash_or_name, tuple): + return hash_or_name + + else: + raise TypeError("Variable name_or_hash_value must be either str or int") + +class Operation: + + serialisation_name = "unknown" + def summary(self, indent_amount: int = 0, indent: str=" "): + """ Summary of the operation tree""" + + s = f"{indent_amount*indent}{self._summary_open()}(\n" + + for chunk in self._summary_components(): + s += chunk.summary(indent_amount+1, indent) + "\n" + + s += f"{indent_amount*indent})" + + return s + def _summary_open(self): + """ First line of summary """ + + def _summary_components(self) -> list["Operation"]: + return [] + def evaluate(self, variables: dict[int, T]) -> T: + + """ Evaluate this operation """ + + def _derivative(self, hash_value: int) -> "Operation": + """ Get the derivative of this operation """ + + def _clean(self): + """ Clean up this operation - i.e. 
remove silly things like 1*x """ + return self + + def derivative(self, variable: Union[str, int, "Variable"], simplify=True): + if isinstance(variable, Variable): + hash_value = variable.hash_value + else: + hash_value, _ = hash_and_name(variable) + + derivative = self._derivative(hash_value) + + if not simplify: + return derivative + + derivative_string = derivative.serialise() + + # print("---------------") + # print("Base") + # print("---------------") + # print(derivative.summary()) + + # Inefficient way of doing repeated simplification, but it will work + for i in range(100): # set max iterations + + derivative = derivative._clean() + # + # print("-------------------") + # print("Iteration", i+1) + # print("-------------------") + # print(derivative.summary()) + # print("-------------------") + + new_derivative_string = derivative.serialise() + + if derivative_string == new_derivative_string: + break + + derivative_string = new_derivative_string + + return derivative + + @staticmethod + def deserialise(data: str) -> "Operation": + json_data = json.loads(data) + return Operation.deserialise_json(json_data) + + @staticmethod + def deserialise_json(json_data: dict) -> "Operation": + + operation = json_data["operation"] + parameters = json_data["parameters"] + cls = _serialisation_lookup[operation] + + try: + return cls._deserialise(parameters) + + except NotImplementedError: + raise NotImplementedError(f"No method to deserialise {operation} with {parameters} (cls={cls})") + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + raise NotImplementedError(f"Deserialise not implemented for this class") + + def serialise(self) -> str: + return json.dumps(self._serialise_json()) + + def _serialise_json(self) -> dict[str, Any]: + return {"operation": self.serialisation_name, + "parameters": self._serialise_parameters()} + + def _serialise_parameters(self) -> dict[str, Any]: + raise NotImplementedError("_serialise_parameters not implemented") + + def __eq__(self, other: "Operation"): + return NotImplemented + +class ConstantBase(Operation): + pass + +class AdditiveIdentity(ConstantBase): + + serialisation_name = "zero" + def evaluate(self, variables: dict[int, T]) -> T: + return 0 + + def _derivative(self, hash_value: int) -> Operation: + return AdditiveIdentity() + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return AdditiveIdentity() + + def _serialise_parameters(self) -> dict[str, Any]: + return {} + + def summary(self, indent_amount: int=0, indent=" "): + return f"{indent_amount*indent}0 [Add.Id.]" + + def __eq__(self, other): + if isinstance(other, AdditiveIdentity): + return True + elif isinstance(other, Constant): + if other.value == 0: + return True + + return False + + + +class MultiplicativeIdentity(ConstantBase): + + serialisation_name = "one" + + def evaluate(self, variables: dict[int, T]) -> T: + return 1 + + def _derivative(self, hash_value: int): + return AdditiveIdentity() + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return MultiplicativeIdentity() + + + def _serialise_parameters(self) -> dict[str, Any]: + return {} + + + def summary(self, indent_amount: int=0, indent=" "): + return f"{indent_amount*indent}1 [Mul.Id.]" + + def __eq__(self, other): + if isinstance(other, MultiplicativeIdentity): + return True + elif isinstance(other, Constant): + if other.value == 1: + return True + + return False + + +class Constant(ConstantBase): + + serialisation_name = "constant" + def __init__(self, value): + self.value = 
value + + def summary(self, indent_amount: int = 0, indent: str=" "): + return repr(self.value) + + def evaluate(self, variables: dict[int, T]) -> T: + return self.value + + def _derivative(self, hash_value: int): + return AdditiveIdentity() + + def _clean(self): + + if self.value == 0: + return AdditiveIdentity() + + elif self.value == 1: + return MultiplicativeIdentity() + + else: + return self + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + value = numerical_decode(parameters["value"]) + return Constant(value) + + + def _serialise_parameters(self) -> dict[str, Any]: + return {"value": numerical_encode(self.value)} + + def summary(self, indent_amount: int=0, indent=" "): + return f"{indent_amount*indent}{self.value}" + + def __eq__(self, other): + if isinstance(other, AdditiveIdentity): + return self.value == 0 + + elif isinstance(other, MultiplicativeIdentity): + return self.value == 1 + + elif isinstance(other, Constant): + if other.value == self.value: + return True + + return False + + +class Variable(Operation): + + serialisation_name = "variable" + def __init__(self, name_or_hash_value: int | str | tuple[int, str]): + self.hash_value, self.name = hash_and_name(name_or_hash_value) + + def evaluate(self, variables: dict[int, T]) -> T: + try: + return variables[self.hash_value] + except KeyError: + raise ValueError(f"Variable dictionary didn't have an entry for {self.name} (hash={self.hash_value})") + + def _derivative(self, hash_value: int) -> Operation: + if hash_value == self.hash_value: + return MultiplicativeIdentity() + else: + return AdditiveIdentity() + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + hash_value = parameters["hash_value"] + name = parameters["name"] + + return Variable((hash_value, name)) + + def _serialise_parameters(self) -> dict[str, Any]: + return {"hash_value": self.hash_value, + "name": self.name} + + def summary(self, indent_amount: int = 0, indent: str=" "): + return f"{indent_amount*indent}{self.name}" + + def __eq__(self, other): + if isinstance(other, Variable): + return self.hash_value == other.hash_value + + return False + +class UnaryOperation(Operation): + + def __init__(self, a: Operation): + self.a = a + + def _serialise_parameters(self) -> dict[str, Any]: + return {"a": self.a._serialise_json()} + + def _summary_components(self) -> list["Operation"]: + return [self.a] + + + + +class Neg(UnaryOperation): + + serialisation_name = "neg" + def evaluate(self, variables: dict[int, T]) -> T: + return -self.a.evaluate(variables) + + def _derivative(self, hash_value: int): + return Neg(self.a._derivative(hash_value)) + + def _clean(self): + + clean_a = self.a._clean() + + if isinstance(clean_a, Neg): + # Removes double negations + return clean_a.a + + elif isinstance(clean_a, Constant): + return Constant(-clean_a.value)._clean() + + else: + return Neg(clean_a) + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Neg(Operation.deserialise_json(parameters["a"])) + + + def _summary_open(self): + return "Neg" + + def __eq__(self, other): + if isinstance(other, Neg): + return other.a == self.a + + +class Inv(UnaryOperation): + + serialisation_name = "reciprocal" + + def evaluate(self, variables: dict[int, T]) -> T: + return 1/self.a.evaluate(variables) + + def _derivative(self, hash_value: int) -> Operation: + return Neg(Div(self.a._derivative(hash_value), Mul(self.a, self.a))) + + def _clean(self): + clean_a = self.a._clean() + + if isinstance(clean_a, Inv): + # Removes double 
negations + return clean_a.a + + elif isinstance(clean_a, Neg): + # cannonicalise 1/-a to -(1/a) + # over multiple iterations this should have the effect of ordering and gathering Neg and Inv + return Neg(Inv(clean_a.a)) + + elif isinstance(clean_a, Constant): + return Constant(1/clean_a.value)._clean() + + else: + return Inv(clean_a) + + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Inv(Operation.deserialise_json(parameters["a"])) + + def _summary_open(self): + return "Inv" + + + def __eq__(self, other): + if isinstance(other, Inv): + return other.a == self.a + +class BinaryOperation(Operation): + def __init__(self, a: Operation, b: Operation): + self.a = a + self.b = b + + def _clean(self): + return self._clean_ab(self.a._clean(), self.b._clean()) + + def _clean_ab(self, a, b): + raise NotImplementedError("_clean_ab not implemented") + + def _serialise_parameters(self) -> dict[str, Any]: + return {"a": self.a._serialise_json(), + "b": self.b._serialise_json()} + + @staticmethod + def _deserialise_ab(parameters) -> tuple[Operation, Operation]: + return (Operation.deserialise_json(parameters["a"]), + Operation.deserialise_json(parameters["b"])) + + + def _summary_components(self) -> list["Operation"]: + return [self.a, self.b] + + def _self_cls(self) -> type: + """ Own class""" + def __eq__(self, other): + if isinstance(other, self._self_cls()): + return other.a == self.a and self.b == other.b + +class Add(BinaryOperation): + + serialisation_name = "add" + + def _self_cls(self) -> type: + return Add + def evaluate(self, variables: dict[int, T]) -> T: + return self.a.evaluate(variables) + self.b.evaluate(variables) + + def _derivative(self, hash_value: int) -> Operation: + return Add(self.a._derivative(hash_value), self.b._derivative(hash_value)) + + def _clean_ab(self, a, b): + + if isinstance(a, AdditiveIdentity): + # Convert 0 + b to b + return b + + elif isinstance(b, AdditiveIdentity): + # Convert a + 0 to a + return a + + elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): + # Convert constant "a"+"b" to "a+b" + return Constant(a.evaluate({}) + b.evaluate({}))._clean() + + elif isinstance(a, Neg): + if isinstance(b, Neg): + # Convert (-a)+(-b) to -(a+b) + return Neg(Add(a.a, b.a)) + else: + # Convert (-a) + b to b-a + return Sub(b, a.a) + + elif isinstance(b, Neg): + # Convert a+(-b) to a-b + return Sub(a, b.a) + + elif a == b: + return Mul(Constant(2), a) + + else: + return Add(a, b) + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Add(*BinaryOperation._deserialise_ab(parameters)) + + def _summary_open(self): + return "Add" + +class Sub(BinaryOperation): + + serialisation_name = "sub" + + + def _self_cls(self) -> type: + return Sub + def evaluate(self, variables: dict[int, T]) -> T: + return self.a.evaluate(variables) - self.b.evaluate(variables) + + def _derivative(self, hash_value: int) -> Operation: + return Sub(self.a._derivative(hash_value), self.b._derivative(hash_value)) + + def _clean_ab(self, a, b): + if isinstance(a, AdditiveIdentity): + # Convert 0 - b to -b + return Neg(b) + + elif isinstance(b, AdditiveIdentity): + # Convert a - 0 to a + return a + + elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): + # Convert constant pair "a" - "b" to "a-b" + return Constant(a.evaluate({}) - b.evaluate({}))._clean() + + elif isinstance(a, Neg): + if isinstance(b, Neg): + # Convert (-a)-(-b) to b-a + return Sub(b.a, a.a) + else: + # Convert (-a)-b to -(a+b) + return Neg(Add(a.a, b)) + + elif 
isinstance(b, Neg): + # Convert a-(-b) to a+b + return Add(a, b.a) + + elif a == b: + return AdditiveIdentity() + + else: + return Sub(a, b) + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Sub(*BinaryOperation._deserialise_ab(parameters)) + + + def _summary_open(self): + return "Sub" + +class Mul(BinaryOperation): + + serialisation_name = "mul" + + + def _self_cls(self) -> type: + return Mul + def evaluate(self, variables: dict[int, T]) -> T: + return self.a.evaluate(variables) * self.b.evaluate(variables) + + def _derivative(self, hash_value: int) -> Operation: + return Add(Mul(self.a, self.b._derivative(hash_value)), Mul(self.a._derivative(hash_value), self.b)) + + def _clean_ab(self, a, b): + + if isinstance(a, AdditiveIdentity) or isinstance(b, AdditiveIdentity): + # Convert 0*b or a*0 to 0 + return AdditiveIdentity() + + elif isinstance(a, MultiplicativeIdentity): + # Convert 1*b to b + return b + + elif isinstance(b, MultiplicativeIdentity): + # Convert a*1 to a + return a + + elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): + # Convert constant "a"*"b" to "a*b" + return Constant(a.evaluate({}) * b.evaluate({}))._clean() + + elif isinstance(a, Inv) and isinstance(b, Inv): + return Inv(Mul(a.a, b.a)) + + elif isinstance(a, Inv) and not isinstance(b, Inv): + return Div(b, a.a) + + elif not isinstance(a, Inv) and isinstance(b, Inv): + return Div(a, b.a) + + elif isinstance(a, Neg): + return Neg(Mul(a.a, b)) + + elif isinstance(b, Neg): + return Neg(Mul(a, b.a)) + + elif a == b: + return Pow(a, 2) + + elif isinstance(a, Pow) and a.a == b: + return Pow(b, a.power + 1) + + elif isinstance(b, Pow) and b.a == a: + return Pow(a, b.power + 1) + + elif isinstance(a, Pow) and isinstance(b, Pow) and a.a == b.a: + return Pow(a.a, a.power + b.power) + + else: + return Mul(a, b) + + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Mul(*BinaryOperation._deserialise_ab(parameters)) + + + def _summary_open(self): + return "Mul" + +class Div(BinaryOperation): + + serialisation_name = "div" + + + def _self_cls(self) -> type: + return Div + + def evaluate(self, variables: dict[int, T]) -> T: + return self.a.evaluate(variables) / self.b.evaluate(variables) + + def _derivative(self, hash_value: int) -> Operation: + return Sub(Div(self.a.derivative(hash_value), self.b), + Div(Mul(self.a, self.b.derivative(hash_value)), Mul(self.b, self.b))) + + def _clean_ab(self, a, b): + if isinstance(a, AdditiveIdentity): + # Convert 0/b to 0 + return AdditiveIdentity() + + elif isinstance(a, MultiplicativeIdentity): + # Convert 1/b to inverse of b + return Inv(b) + + elif isinstance(b, MultiplicativeIdentity): + # Convert a/1 to a + return a + + elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): + # Convert constants "a"/"b" to "a/b" + return Constant(self.a.evaluate({}) / self.b.evaluate({}))._clean() + + + elif isinstance(a, Inv) and isinstance(b, Inv): + return Div(b.a, a.a) + + elif isinstance(a, Inv) and not isinstance(b, Inv): + return Inv(Mul(a.a, b)) + + elif not isinstance(a, Inv) and isinstance(b, Inv): + return Mul(a, b.a) + + elif a == b: + return MultiplicativeIdentity() + + elif isinstance(a, Pow) and a.a == b: + return Pow(b, a.power - 1) + + elif isinstance(b, Pow) and b.a == a: + return Pow(a, 1 - b.power) + + elif isinstance(a, Pow) and isinstance(b, Pow) and a.a == b.a: + return Pow(a.a, a.power - b.power) + + else: + return Div(a, b) + + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + 
return Div(*BinaryOperation._deserialise_ab(parameters))
+
+    def _summary_open(self):
+        return "Div"
+
+class Pow(Operation):
+
+    serialisation_name = "pow"
+
+    def __init__(self, a: Operation, power: float):
+        self.a = a
+        self.power = power
+
+    def evaluate(self, variables: dict[int, T]) -> T:
+        return self.a.evaluate(variables) ** self.power
+
+    def _derivative(self, hash_value: int) -> Operation:
+        if self.power == 0:
+            return AdditiveIdentity()
+
+        elif self.power == 1:
+            return self.a._derivative(hash_value)
+
+        else:
+            return Mul(Constant(self.power), Mul(Pow(self.a, self.power-1), self.a._derivative(hash_value)))
+
+    def _clean(self) -> Operation:
+        a = self.a._clean()
+
+        if self.power == 1:
+            return a
+
+        elif self.power == 0:
+            return MultiplicativeIdentity()
+
+        elif self.power == -1:
+            return Inv(a)
+
+        else:
+            return Pow(a, self.power)
+
+
+    def _serialise_parameters(self) -> dict[str, Any]:
+        return {"a": Operation._serialise_json(self.a),
+                "power": self.power}
+
+    @staticmethod
+    def _deserialise(parameters: dict) -> "Operation":
+        return Pow(Operation.deserialise_json(parameters["a"]), parameters["power"])
+
+    def summary(self, indent_amount: int=0, indent=" "):
+        return (f"{indent_amount*indent}Pow\n" +
+                self.a.summary(indent_amount+1, indent) + "\n" +
+                f"{(indent_amount+1)*indent}{self.power}\n" +
+                f"{indent_amount*indent})")
+
+    def __eq__(self, other):
+        if isinstance(other, Pow):
+            return self.a == other.a and self.power == other.power
+
+
+
+#
+# Matrix operations
+#
+
+class Transpose(Operation):
+    """ Transpose operation - as per numpy"""
+
+    serialisation_name = "transpose"
+
+    def __init__(self, a: Operation, axes: tuple[int, ...] | None = None):
+        self.a = a
+        self.axes = axes
+
+    def evaluate(self, variables: dict[int, T]) -> T:
+        return np.transpose(self.a.evaluate(variables), axes=self.axes)
+
+    def _derivative(self, hash_value: int) -> Operation:
+        return Transpose(self.a._derivative(hash_value), axes=self.axes) # TODO: Check!
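+
+    # Example of the axes convention (matching numpy): for an operand of shape
+    # (2, 3, 4), axes=(2, 0, 1) evaluates to shape (4, 2, 3), while axes=None
+    # reverses the axis order. Derivatives are permuted with the same axes,
+    # since transposition is linear.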
+
+    def _clean(self):
+        clean_a = self.a._clean()
+        return Transpose(clean_a, axes=self.axes)
+
+
+    def _serialise_parameters(self) -> dict[str, Any]:
+        if self.axes is None:
+            return { "a": self.a._serialise_json() }
+        else:
+            return {
+                "a": self.a._serialise_json(),
+                "axes": list(self.axes)
+            }
+
+
+    @staticmethod
+    def _deserialise(parameters: dict) -> "Operation":
+        if "axes" in parameters:
+            return Transpose(
+                a=Operation.deserialise_json(parameters["a"]),
+                axes=tuple(parameters["axes"]))
+        else:
+            return Transpose(
+                a=Operation.deserialise_json(parameters["a"]))
+
+
+    def _summary_open(self):
+        return "Transpose"
+
+    def __eq__(self, other):
+        if isinstance(other, Transpose):
+            return other.a == self.a and other.axes == self.axes
+
+
+class Dot(BinaryOperation):
+    """ Dot product - backed by numpy's dot method"""
+
+    serialisation_name = "dot"
+
+    def evaluate(self, variables: dict[int, T]) -> T:
+        return dot(self.a.evaluate(variables), self.b.evaluate(variables))
+
+    def _derivative(self, hash_value: int) -> Operation:
+        return Add(
+            Dot(self.a,
+                self.b._derivative(hash_value)),
+            Dot(self.a._derivative(hash_value),
+                self.b))
+
+    def _clean_ab(self, a, b):
+        return Dot(a, b) # Do nothing for now
+
+
+    @staticmethod
+    def _deserialise(parameters: dict) -> "Operation":
+        return Dot(*BinaryOperation._deserialise_ab(parameters))
+
+    def _summary_open(self):
+        return "Dot"
+
+
+# TODO: Add to base operation class, and to quantities
+class MatMul(BinaryOperation):
+    """ Matrix multiplication, using __matmul__ dunder"""
+
+    serialisation_name = "matmul"
+
+    def evaluate(self, variables: dict[int, T]) -> T:
+        return self.a.evaluate(variables) @ self.b.evaluate(variables)
+
+    def _derivative(self, hash_value: int) -> Operation:
+        return Add(
+            MatMul(self.a,
+                   self.b._derivative(hash_value)),
+            MatMul(self.a._derivative(hash_value),
+                   self.b))
+
+    def _clean_ab(self, a, b):
+
+        if isinstance(a, AdditiveIdentity) or isinstance(b, AdditiveIdentity):
+            # Convert 0@b or a@0 to 0
+            return AdditiveIdentity()
+
+        elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase):
+            # Convert constant "a"@"b" to "a@b"
+            return Constant(a.evaluate({}) @ b.evaluate({}))._clean()
+
+        elif isinstance(a, Neg):
+            return Neg(MatMul(a.a, b))
+
+        elif isinstance(b, Neg):
+            return Neg(MatMul(a, b.a))
+
+        return MatMul(a, b)
+
+
+    @staticmethod
+    def _deserialise(parameters: dict) -> "Operation":
+        return MatMul(*BinaryOperation._deserialise_ab(parameters))
+
+    def _summary_open(self):
+        return "MatMul"
+
+class TensorDot(Operation):
+    serialisation_name = "tensor_product"
+
+    def __init__(self, a: Operation, b: Operation, a_index: int, b_index: int):
+        self.a = a
+        self.b = b
+        self.a_index = a_index
+        self.b_index = b_index
+
+    def evaluate(self, variables: dict[int, T]) -> T:
+        return tensordot(self.a.evaluate(variables), self.b.evaluate(variables), self.a_index, self.b_index)
+
+
+    def _serialise_parameters(self) -> dict[str, Any]:
+        return {
+            "a": self.a._serialise_json(),
+            "b": self.b._serialise_json(),
+            "a_index": self.a_index,
+            "b_index": self.b_index }
+
+    @staticmethod
+    def _deserialise(parameters: dict) -> "Operation":
+        return TensorDot(a = Operation.deserialise_json(parameters["a"]),
+                         b = Operation.deserialise_json(parameters["b"]),
+                         a_index=int(parameters["a_index"]),
+                         b_index=int(parameters["b_index"]))
+
+    def _summary_open(self):
+        return "TensorProduct"
+
+
+_serialisable_classes = [AdditiveIdentity, MultiplicativeIdentity, Constant,
+                         Variable,
+                         Neg, Inv,
+                         Add, Sub, Mul, Div, Pow,
+                         Transpose, Dot, MatMul, TensorDot]
+
+_serialisation_lookup = 
{cls.serialisation_name: cls for cls in _serialisable_classes} + class UnitError(Exception): """ Errors caused by unit specification not being correct """ @@ -26,10 +980,25 @@ def hash_data_via_numpy(*data: ArrayLike): return int(md5_hash.hexdigest(), 16) + +##################################### +# # +# # +# # +# Quantities begin here # +# # +# # +# # +##################################### + + + QuantityType = TypeVar("QuantityType") class QuantityHistory: + """ Class that holds the information for keeping track of operations done on quantities """ + def __init__(self, operation_tree: Operation, references: dict[int, "Quantity"]): self.operation_tree = operation_tree self.references = references @@ -43,6 +1012,10 @@ def jacobian(self) -> list[Operation]: # Use the hash value to specify the variable of differentiation return [self.operation_tree.derivative(key) for key in self.reference_key_list] + def _recalculate(self): + """ Recalculate the value of this object - primary use case is for testing """ + return self.operation_tree.evaluate(self.references) + def variance_propagate(self, quantity_units: Unit, covariances: dict[tuple[int, int]: "Quantity"] = {}): """ Do standard error propagation to calculate the uncertainties associated with this quantity @@ -54,14 +1027,6 @@ def variance_propagate(self, quantity_units: Unit, covariances: dict[tuple[int, raise NotImplementedError("User specified covariances not currently implemented") jacobian = self.jacobian() - # jacobian_units = [quantity_units / self.references[key].units for key in self.reference_key_list] - # - # # Evaluate the jacobian - # # TODO: should we use quantities here, does that work automatically? - # evaluated_jacobian = [Quantity( - # value=entry.evaluate(self.si_reference_values), - # units=unit.si_equivalent()) - # for entry, unit in zip(jacobian, jacobian_units)] evaluated_jacobian = [entry.evaluate(self.references) for entry in jacobian] @@ -83,7 +1048,7 @@ def variable(quantity: "Quantity"): return QuantityHistory(Variable(quantity.hash_value), {quantity.hash_value: quantity}) @staticmethod - def apply_operation(operation: type[Operation], *histories: "QuantityHistory") -> "QuantityHistory": + def apply_operation(operation: type[Operation], *histories: "QuantityHistory", **extra_parameters) -> "QuantityHistory": """ Apply an operation to the history This is slightly unsafe as it is possible to attempt to apply an n-ary operation to a number of trees other @@ -100,7 +1065,7 @@ def apply_operation(operation: type[Operation], *histories: "QuantityHistory") - references.update(history.references) return QuantityHistory( - operation(*[history.operation_tree for history in histories]), + operation(*[history.operation_tree for history in histories], **extra_parameters), references) def has_variance(self): @@ -110,6 +1075,16 @@ def has_variance(self): return False + def summary(self): + + variable_strings = [self.references[key].string_repr for key in self.references] + + s = "Variables: "+",".join(variable_strings) + s += "\n" + s += self.operation_tree.summary() + + return s + class Quantity[QuantityType]: @@ -132,10 +1107,10 @@ def __init__(self, self.hash_value = -1 """ Hash based on value and uncertainty for data, -1 if it is a derived hash value """ - """ Contains the variance if it is data driven, else it is """ + self._variance = None + """ Contains the variance if it is data driven """ if standard_error is None: - self._variance = None self.hash_value = hash_data_via_numpy(hash_seed, value) else: self._variance = 
standard_error ** 2
@@ -205,14 +1180,14 @@ def __mul__(self: Self, other: ArrayLike | Self) -> Self:
             return DerivedQuantity(
                 self.value * other.value,
                 self.units * other.units,
-                history=QuantityHistory.apply_operation(operations.Mul, self.history, other.history))
+                history=QuantityHistory.apply_operation(Mul, self.history, other.history))
         else:
             return DerivedQuantity(self.value * other,
                                    self.units,
                                    QuantityHistory(
-                                       operations.Mul(
+                                       Mul(
                                            self.history.operation_tree,
-                                           operations.Constant(other)),
+                                           Constant(other)),
                                        self.history.references))

     def __rmul__(self: Self, other: ArrayLike | Self):
@@ -221,33 +1196,72 @@ def __rmul__(self: Self, other: ArrayLike | Self):
                 other.value * self.value,
                 other.units * self.units,
                 history=QuantityHistory.apply_operation(
-                    operations.Mul,
+                    Mul,
                     other.history,
                     self.history))
         else:
             return DerivedQuantity(other * self.value,
                                    self.units,
                                    QuantityHistory(
-                                       operations.Mul(
-                                           operations.Constant(other),
+                                       Mul(
+                                           Constant(other),
                                            self.history.operation_tree),
                                        self.history.references))

+    def __matmul__(self, other: ArrayLike | Self):
+        if isinstance(other, Quantity):
+            return DerivedQuantity(
+                self.value @ other.value,
+                self.units * other.units,
+                history=QuantityHistory.apply_operation(
+                    MatMul,
+                    self.history,
+                    other.history))
+        else:
+            return DerivedQuantity(
+                self.value @ other,
+                self.units,
+                QuantityHistory(
+                    MatMul(
+                        self.history.operation_tree,
+                        Constant(other)),
+                    self.history.references))
+
+    def __rmatmul__(self, other: ArrayLike | Self):
+        if isinstance(other, Quantity):
+            return DerivedQuantity(
+                other.value @ self.value,
+                other.units * self.units,
+                history=QuantityHistory.apply_operation(
+                    MatMul,
+                    other.history,
+                    self.history))
+
+        else:
+            return DerivedQuantity(other @ self.value,
+                                   self.units,
+                                   QuantityHistory(
+                                       MatMul(
+                                           Constant(other),
+                                           self.history.operation_tree),
+                                       self.history.references))
+
     def __truediv__(self: Self, other: float | Self) -> Self:
         if isinstance(other, Quantity):
             return DerivedQuantity(
                 self.value / other.value,
                 self.units / other.units,
                 history=QuantityHistory.apply_operation(
-                    operations.Div,
+                    Div,
                     self.history,
                     other.history))
         else:
             return DerivedQuantity(self.value / other,
                                    self.units,
                                    QuantityHistory(
-                                       operations.Div(
-                                           operations.Constant(other),
-                                           self.history.operation_tree),
+                                       Div(
+                                           self.history.operation_tree,
+                                           Constant(other)),
                                        self.history.references))

@@ -257,7 +1271,7 @@ def __rtruediv__(self: Self, other: float | Self) -> Self:
                 other.value / self.value,
                 other.units / self.units,
                 history=QuantityHistory.apply_operation(
-                    operations.Div,
+                    Div,
                     other.history,
                     self.history
                 ))
@@ -267,8 +1281,8 @@ def __rtruediv__(self: Self, other: float | Self) -> Self:
                 other / self.value,
                 self.units ** -1,
                 QuantityHistory(
-                    operations.Div(
-                        operations.Constant(other),
+                    Div(
+                        Constant(other),
                         self.history.operation_tree),
                     self.history.references))

@@ -279,7 +1293,7 @@ def __add__(self: Self, other: Self | ArrayLike) -> Self:
                 self.value + (other.value * other.units.scale) / self.units.scale,
                 self.units,
                 QuantityHistory.apply_operation(
-                    operations.Add,
+                    Add,
                     self.history,
                     other.history))
         else:
@@ -293,7 +1307,7 @@ def __add__(self: Self, other: Self | ArrayLike) -> Self:

     def __neg__(self):
         return DerivedQuantity(-self.value, self.units,
                                QuantityHistory.apply_operation(
-                                   operations.Neg,
+                                   Neg,
                                    self.history
                                ))

@@ -307,7 +1321,7 @@ def __pow__(self: Self, other: int | float):
         return DerivedQuantity(self.value ** other,
                                self.units ** other,
                                QuantityHistory(
-                                   operations.Pow(
+                                   Pow(
                                        self.history.operation_tree,
                                        other),
                                       self.history.references))
@@ -359,6 +1373,10 @@ def __repr__(self):
     def parse(number_or_string: str | ArrayLike, unit: str, absolute_temperature: bool = False):
         pass

+    @property
+    def string_repr(self):
+        return str(self.hash_value)
+

 class NamedQuantity[QuantityType](Quantity[QuantityType]):
     def __init__(self,
@@ -393,6 +1411,10 @@ def with_standard_error(self, standard_error: Quantity):
                              f"are not compatible with value units ({self.units})")

+    @property
+    def string_repr(self):
+        return self.name
+

 class DerivedQuantity[QuantityType](Quantity[QuantityType]):
     def __init__(self, value: QuantityType, units: Unit, history: QuantityHistory):
         super().__init__(value, units, standard_error=None)
diff --git a/sasdata/quantities/test_numerical_encoding.py b/sasdata/quantities/test_numerical_encoding.py
new file mode 100644
index 0000000..80cfbad
--- /dev/null
+++ b/sasdata/quantities/test_numerical_encoding.py
@@ -0,0 +1,80 @@
+""" Tests for the encoding and decoding of numerical data """
+
+import numpy as np
+import pytest
+
+from scipy.sparse import coo_matrix
+
+from sasdata.quantities.numerical_encoding import numerical_encode, numerical_decode
+
+
+@pytest.mark.parametrize("value", [-100.0, -10.0, -1.0, 0.0, 0.5, 1.0, 10.0, 100.0, 1e100])
+def test_float_encode_decode(value: float):
+
+    assert isinstance(value, float)  # Make sure we have the right inputs
+
+    encoded = numerical_encode(value)
+    decoded = numerical_decode(encoded)
+
+    assert isinstance(decoded, float)
+    assert value == decoded
+
+@pytest.mark.parametrize("value", [-100, -10, -1, 0, 1, 10, 100, 1000000000000000000000000000000000])
+def test_int_encode_decode(value: int):
+
+    assert isinstance(value, int)  # Make sure we have the right inputs
+
+    encoded = numerical_encode(value)
+    decoded = numerical_decode(encoded)
+
+    assert isinstance(decoded, int)
+    assert value == decoded
+
+@pytest.mark.parametrize("shape", [
+    (2, 3, 4),
+    (1, 2),
+    (10, 5, 10),
+    (1,),
+    (4,),
+    (0,)])
+def test_numpy_float_encode_decode(shape):
+    np.random.seed(1776)
+    test_matrix = np.random.rand(*shape)
+
+    encoded = numerical_encode(test_matrix)
+    decoded = numerical_decode(encoded)
+
+    assert decoded.dtype == test_matrix.dtype
+    assert decoded.shape == test_matrix.shape
+    assert np.all(decoded == test_matrix)
+
+@pytest.mark.parametrize("dtype", [int, float, complex])
+def test_numpy_dtypes_encode_decode(dtype):
+    test_matrix = np.zeros((3, 3), dtype=dtype)
+
+    encoded = numerical_encode(test_matrix)
+    decoded = numerical_decode(encoded)
+
+    assert decoded.dtype == test_matrix.dtype
+
+@pytest.mark.parametrize("dtype", [int, float, complex])
+@pytest.mark.parametrize("shape, n, m", [
+    ((8, 8), (1, 3, 5), (2, 5, 7)),
+    ((6, 8), (1, 0, 5), (0, 5, 0)),
+    ((6, 1), (1, 0, 5), (0, 0, 0)),
+])
+def test_coo_matrix_encode_decode(shape, n, m, dtype):
+    # Build a small COO matrix from the parametrised row (n) and column (m)
+    # indices and round-trip it; this assumes numerical_encode/numerical_decode
+    # handle scipy sparse COO matrices, as the test name suggests
+    i_indices = np.array(n, dtype=int)
+    j_indices = np.array(m, dtype=int)
+    values = np.arange(1, len(n) + 1).astype(dtype)
+
+    test_matrix = coo_matrix((values, (i_indices, j_indices)), shape=shape)
+
+    encoded = numerical_encode(test_matrix)
+    decoded = numerical_decode(encoded)
+
+    assert decoded.shape == test_matrix.shape
+    assert np.all(decoded.toarray() == test_matrix.toarray())
diff --git a/sasdata/slicing/__init__.py b/sasdata/slicing/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/sasdata/slicing/geometry.py b/sasdata/slicing/geometry.py
new file mode 100644
index 0000000..e69de29
diff --git a/sasdata/slicing/meshes/__init__.py b/sasdata/slicing/meshes/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/sasdata/slicing/meshes/delaunay_mesh.py b/sasdata/slicing/meshes/delaunay_mesh.py
new file mode 100644
index 0000000..a19c2ac
--- /dev/null
+++ b/sasdata/slicing/meshes/delaunay_mesh.py
@@ -0,0 +1,32 @@
+import numpy as np
+from scipy.spatial import Delaunay
+
+from sasdata.slicing.meshes.mesh import Mesh
+
+def delaunay_mesh(x, y) -> 
Mesh: + """ Create a triangulated mesh based on input points """ + + input_data = np.array((x, y)).T + delaunay = Delaunay(input_data) + + return Mesh(points=input_data, cells=delaunay.simplices) + + +if __name__ == "__main__": + import matplotlib.pyplot as plt + + points = np.random.random((100, 2)) + mesh = delaunay_mesh(points[:,0], points[:,1]) + mesh.show(actually_show=False) + + print(mesh.cells[50]) + + # pick random cell to show + for cell in mesh.cells_to_edges[10]: + a, b = mesh.edges[cell] + plt.plot( + [mesh.points[a][0], mesh.points[b][0]], + [mesh.points[a][1], mesh.points[b][1]], + color='r') + + plt.show() diff --git a/sasdata/slicing/meshes/mesh.py b/sasdata/slicing/meshes/mesh.py new file mode 100644 index 0000000..8176633 --- /dev/null +++ b/sasdata/slicing/meshes/mesh.py @@ -0,0 +1,242 @@ +from typing import Sequence + +import numpy as np + +import matplotlib.pyplot as plt +from matplotlib import cm +from matplotlib.collections import LineCollection + +from sasdata.slicing.meshes.util import closed_loop_edges + +class Mesh: + def __init__(self, + points: np.ndarray, + cells: Sequence[Sequence[int]]): + + """ + Object representing a mesh. + + Parameters are the values: + mesh points + map from edge to points + map from cells to edges + + it is done this way to ensure a non-redundant representation of cells and edges, + however there are no checks for the topology of the mesh, this is assumed to be done by + whatever creates it. There are also no checks for ordering of cells. + + :param points: points in 2D forming vertices of the mesh + :param cells: ordered lists of indices of points forming each cell (face) + + """ + + self.points = points + self.cells = cells + + # Get edges + + edges = set() + for cell_index, cell in enumerate(cells): + + for a, b in closed_loop_edges(cell): + # make sure the representation is unique + if a > b: + edges.add((a, b)) + else: + edges.add((b, a)) + + self.edges = list(edges) + + # Associate edges with faces + + edge_lookup = {edge: i for i, edge in enumerate(self.edges)} + self.cells_to_edges = [] + self.cells_to_edges_signs = [] + + for cell in cells: + + this_cell_data = [] + this_sign_data = [] + + for a, b in closed_loop_edges(cell): + # make sure the representation is unique + if a > b: + this_cell_data.append(edge_lookup[(a, b)]) + this_sign_data.append(1) + else: + this_cell_data.append(edge_lookup[(b, a)]) + this_sign_data.append(-1) + + self.cells_to_edges.append(this_cell_data) + self.cells_to_edges_signs.append(this_sign_data) + + # Counts for elements + self.n_points = self.points.shape[0] + self.n_edges = len(self.edges) + self.n_cells = len(self.cells) + + # Areas + self._areas = None + + + @property + def areas(self): + """ Areas of cells """ + + if self._areas is None: + # Calculate areas + areas = [] + for cell in self.cells: + # Use triangle shoelace formula, basically calculate the + # determinant based on of triangles with one point at 0,0 + a_times_2 = 0.0 + for i1, i2 in closed_loop_edges(cell): + p1 = self.points[i1, :] + p2 = self.points[i2, :] + a_times_2 += p1[0]*p2[1] - p1[1]*p2[0] + + areas.append(0.5*np.abs(a_times_2)) + + # Save in cache + self._areas = np.array(areas) + + # Return cache + return self._areas + + + def show(self, actually_show=True, show_labels=False, **kwargs): + """ Show on a plot """ + ax = plt.gca() + segments = [[self.points[edge[0]], self.points[edge[1]]] for edge in self.edges] + line_collection = LineCollection(segments=segments, **kwargs) + ax.add_collection(line_collection) + + if 
show_labels: + text_color = kwargs["color"] if "color" in kwargs else 'k' + for i, cell in enumerate(self.cells): + xy = np.sum(self.points[cell, :], axis=0)/len(cell) + ax.text(xy[0], xy[1], str(i), horizontalalignment="center", verticalalignment="center", color=text_color) + + x_limits = [np.min(self.points[:,0]), np.max(self.points[:,0])] + y_limits = [np.min(self.points[:,1]), np.max(self.points[:,1])] + + plt.xlim(x_limits) + plt.ylim(y_limits) + + if actually_show: + plt.show() + + def locate_points(self, x: np.ndarray, y: np.ndarray): + """ Find the cells that contain the specified points""" + + x = x.reshape(-1) + y = y.reshape(-1) + + xy = np.concatenate(([x], [y]), axis=1) + + # The most simple implementation is not particularly fast, especially in python + # + # Less obvious, but hopefully faster strategy + # + # Ultimately, checking the inclusion of a point within a polygon + # requires checking the crossings of a half line with the polygon's + # edges. + # + # A fairly efficient thing to do is to check every edge for crossing + # the axis parallel lines x=point_x. + # Then these edges that cross can map back to the polygons they're in + # and a final check for inclusion can be done with the edge sign property + # and some explicit checking of the + # + # Basic idea is: + # 1) build a matrix for each point-edge pair + # True if the edge crosses the half-line above a point + # 2) for each cell get the winding number by evaluating the + # sum of the component edges, weighted 1/-1 according to direction + + + edges = np.array(self.edges) + + edge_xy_1 = self.points[edges[:, 0], :] + edge_xy_2 = self.points[edges[:, 1], :] + + edge_x_1 = edge_xy_1[:, 0] + edge_x_2 = edge_xy_2[:, 0] + + + + # Make an n_edges-by-n_inputs boolean matrix that indicates which of the + # edges cross x=points_x line + crossers = np.logical_xor( + edge_x_1.reshape(-1, 1) < x.reshape(1, -1), + edge_x_2.reshape(-1, 1) < x.reshape(1, -1)) + + # Calculate the gradients, some might be infs, but none that matter will be + # TODO: Disable warnings + gradients = (edge_xy_2[:, 1] - edge_xy_1[:, 1]) / (edge_xy_2[:, 0] - edge_xy_1[:, 0]) + + # Distance to crossing points edge 0 + delta_x = x.reshape(1, -1) - edge_x_1.reshape(-1, 1) + + # Signed distance from point to y (doesn't really matter which sign) + delta_y = gradients.reshape(-1, 1) * delta_x + edge_xy_1[:, 1:] - y.reshape(1, -1) + + score_matrix = np.logical_and(delta_y > 0, crossers) + + output = -np.ones(len(x), dtype=int) + for cell_index, (cell_edges, sign) in enumerate(zip(self.cells_to_edges, self.cells_to_edges_signs)): + cell_score = np.sum(score_matrix[cell_edges, :] * np.array(sign).reshape(-1, 1), axis=0) + points_in_cell = np.abs(cell_score) == 1 + output[points_in_cell] = cell_index + + return output + + def show_data(self, + data: np.ndarray, + cmap='winter', + mesh_color='white', + show_mesh=False, + actually_show=True, + density=False): + + """ Show with data """ + + colormap = cm.get_cmap(cmap, 256) + + data = data.reshape(-1) + + if density: + data = data / self.areas + + cmin = np.min(data) + cmax = np.max(data) + + color_index_map = np.array(255 * (data - cmin) / (cmax - cmin), dtype=int) + + for cell, color_index in zip(self.cells, color_index_map): + + color = colormap(color_index) + + plt.fill(self.points[cell, 0], self.points[cell, 1], color=color, edgecolor=None) + + if show_mesh: + self.show(actually_show=False, color=mesh_color) + + if actually_show: + self.show() + + +if __name__ == "__main__": + from 
test.slicers.meshes_for_testing import location_test_mesh, location_test_points_x, location_test_points_y + + cell_indices = location_test_mesh.locate_points(location_test_points_x, location_test_points_y) + + print(cell_indices) + + for i in range(location_test_mesh.n_cells): + inds = cell_indices == i + plt.scatter( + location_test_points_x.reshape(-1)[inds], + location_test_points_y.reshape(-1)[inds]) + + location_test_mesh.show() \ No newline at end of file diff --git a/sasdata/slicing/meshes/meshmerge.py b/sasdata/slicing/meshes/meshmerge.py new file mode 100644 index 0000000..2060cc7 --- /dev/null +++ b/sasdata/slicing/meshes/meshmerge.py @@ -0,0 +1,154 @@ +import numpy as np + +from sasdata.slicing.meshes.mesh import Mesh +from sasdata.slicing.meshes.delaunay_mesh import delaunay_mesh + +import time + +def meshmerge(mesh_a: Mesh, mesh_b: Mesh) -> tuple[Mesh, np.ndarray, np.ndarray]: + """ Take two lists of polygons and find their intersections + + Polygons in each of the input variables should not overlap i.e. a point in space should be assignable to + at most one polygon in mesh_a and at most one polygon in mesh_b + + Mesh topology should be sensible, otherwise bad things might happen, also, the cells of the input meshes + must be in order (which is assumed by the mesh class constructor anyway). + + :returns: + 1) A triangulated mesh based on both sets of polygons together + 2) The indices of the mesh_a polygon that corresponds to each triangle, -1 for nothing + 3) The indices of the mesh_b polygon that corresponds to each triangle, -1 for nothing + + """ + + t0 = time.time() + + # Find intersections of all edges in mesh one with edges in mesh two + + # Fastest way might just be to calculate the intersections of all lines on edges, + # see whether we need filtering afterwards + + edges_a = np.array(mesh_a.edges, dtype=int) + edges_b = np.array(mesh_b.edges, dtype=int) + + edge_a_1 = mesh_a.points[edges_a[:, 0], :] + edge_a_2 = mesh_a.points[edges_a[:, 1], :] + edge_b_1 = mesh_b.points[edges_b[:, 0], :] + edge_b_2 = mesh_b.points[edges_b[:, 1], :] + + a_grid, b_grid = np.mgrid[0:mesh_a.n_edges, 0:mesh_b.n_edges] + a_grid = a_grid.reshape(-1) + b_grid = b_grid.reshape(-1) + + p1 = edge_a_1[a_grid, :] + p2 = edge_a_2[a_grid, :] + p3 = edge_b_1[b_grid, :] + p4 = edge_b_2[b_grid, :] + + # + # TODO: Investigate whether adding a bounding box check will help with speed, seems likely as most edges wont cross + # + + # + # Solve the equations + # + # z_a1 + s delta_z_a = z_b1 + t delta_z_b + # + # for z = (x, y) + # + + start_point_diff = p1 - p3 + + delta1 = p2 - p1 + delta3 = p4 - p3 + + deltas = np.concatenate(([-delta1], [delta3]), axis=0) + deltas = np.moveaxis(deltas, 0, 2) + + non_singular = np.linalg.det(deltas) != 0 + + st = np.linalg.solve(deltas[non_singular], start_point_diff[non_singular]) + + # Find the points where s and t are in (0, 1) + + intersection_inds = np.logical_and( + np.logical_and(0 < st[:, 0], st[:, 0] < 1), + np.logical_and(0 < st[:, 1], st[:, 1] < 1)) + + start_points_for_intersections = p1[non_singular][intersection_inds, :] + deltas_for_intersections = delta1[non_singular][intersection_inds, :] + + points_to_add = start_points_for_intersections + st[intersection_inds, 0].reshape(-1,1) * deltas_for_intersections + + t1 = time.time() + print("Edge intersections:", t1 - t0) + + # Build list of all input points, in a way that we can check for coincident points + + + points = np.concatenate(( + mesh_a.points, + mesh_b.points, + points_to_add + )) + + + # Remove 
coincident points + + points = np.unique(points, axis=0) + + # Triangulate based on these intersections + + output_mesh = delaunay_mesh(points[:, 0], points[:, 1]) + + + t2 = time.time() + print("Delaunay:", t2 - t1) + + + # Find centroids of all output triangles, and find which source cells they belong to + + ## step 1) Assign -1 to all cells of original meshes + assignments_a = -np.ones(output_mesh.n_cells, dtype=int) + assignments_b = -np.ones(output_mesh.n_cells, dtype=int) + + ## step 2) Find centroids of triangulated mesh (just needs to be a point inside, but this is a good one) + centroids = [] + for cell in output_mesh.cells: + centroid = np.sum(output_mesh.points[cell, :]/3, axis=0) + centroids.append(centroid) + + centroids = np.array(centroids) + + t3 = time.time() + print("Centroids:", t3 - t2) + + + ## step 3) Find where points belong based on Mesh classes point location algorithm + + assignments_a = mesh_a.locate_points(centroids[:, 0], centroids[:, 1]) + assignments_b = mesh_b.locate_points(centroids[:, 0], centroids[:, 1]) + + t4 = time.time() + print("Assignments:", t4 - t3) + + return output_mesh, assignments_a, assignments_b + + +def main(): + from voronoi_mesh import voronoi_mesh + + n1 = 100 + n2 = 100 + + m1 = voronoi_mesh(np.random.random(n1), np.random.random(n1)) + m2 = voronoi_mesh(np.random.random(n2), np.random.random(n2)) + + + mesh, assignement1, assignement2 = meshmerge(m1, m2) + + mesh.show() + + +if __name__ == "__main__": + main() diff --git a/sasdata/slicing/meshes/util.py b/sasdata/slicing/meshes/util.py new file mode 100644 index 0000000..b78a9e0 --- /dev/null +++ b/sasdata/slicing/meshes/util.py @@ -0,0 +1,10 @@ +from typing import Sequence, TypeVar + +T = TypeVar("T") + +def closed_loop_edges(values: Sequence[T]) -> tuple[T, T]: + """ Generator for a closed loop of edge pairs """ + for pair in zip(values, values[1:]): + yield pair + + yield values[-1], values[0] \ No newline at end of file diff --git a/sasdata/slicing/meshes/voronoi_mesh.py b/sasdata/slicing/meshes/voronoi_mesh.py new file mode 100644 index 0000000..d47dc2c --- /dev/null +++ b/sasdata/slicing/meshes/voronoi_mesh.py @@ -0,0 +1,96 @@ +import numpy as np +from scipy.spatial import Voronoi + + +from sasdata.slicing.meshes.mesh import Mesh + +def voronoi_mesh(x, y, debug_plot=False) -> Mesh: + """ Create a mesh based on a voronoi diagram of points """ + + input_data = np.array((x.reshape(-1), y.reshape(-1))).T + + # Need to make sure mesh covers a finite region, probably not important for + # much data stuff, but is important for plotting + # + # * We want the cells at the edge of the mesh to have a reasonable size, definitely not infinite + # * The exact size doesn't matter that much + # * It should work well with a grid, but also + # * ...it should be robust so that if the data isn't on a grid, it doesn't cause any serious problems + # + # Plan: Create a square border of points that are totally around the points, this is + # at the distance it would be if it was an extra row of grid points + # to do this we'll need + # 1) an estimate of the grid spacing + # 2) the bounding box of the grid + # + + + # Use the median area of finite voronoi cells as an estimate + voronoi = Voronoi(input_data) + finite_cells = [region for region in voronoi.regions if -1 not in region and len(region) > 0] + premesh = Mesh(points=voronoi.vertices, cells=finite_cells) + + area_spacing = np.median(premesh.areas) + gap = np.sqrt(area_spacing) + + # Bounding box is easy + x_min, y_min = np.min(input_data, axis=0) 
+ x_max, y_max = np.max(input_data, axis=0) + + # Create a border + n_x = int(np.round((x_max - x_min)/gap)) + n_y = int(np.round((y_max - y_min)/gap)) + + top_bottom_xs = np.linspace(x_min - gap, x_max + gap, n_x + 3) + left_right_ys = np.linspace(y_min, y_max, n_y + 1) + + top = np.array([top_bottom_xs, (y_max + gap) * np.ones_like(top_bottom_xs)]) + bottom = np.array([top_bottom_xs, (y_min - gap) * np.ones_like(top_bottom_xs)]) + left = np.array([(x_min - gap) * np.ones_like(left_right_ys), left_right_ys]) + right = np.array([(x_max + gap) * np.ones_like(left_right_ys), left_right_ys]) + + added_points = np.concatenate((top, bottom, left, right), axis=1).T + + if debug_plot: + import matplotlib.pyplot as plt + plt.scatter(x, y) + plt.scatter(added_points[:, 0], added_points[:, 1]) + plt.show() + + new_points = np.concatenate((input_data, added_points), axis=0) + voronoi = Voronoi(new_points) + + # Remove the cells that correspond to the added edge points, + # Because the points on the edge of the square are (weakly) convex, these + # regions be infinite + + # finite_cells = [region for region in voronoi.regions if -1 not in region and len(region) > 0] + + # ... however, we can just use .region_points + input_regions = voronoi.point_region[:input_data.shape[0]] + cells = [voronoi.regions[region_index] for region_index in input_regions] + + return Mesh(points=voronoi.vertices, cells=cells) + + +def square_grid_check(): + values = np.linspace(-10, 10, 21) + x, y = np.meshgrid(values, values) + + mesh = voronoi_mesh(x, y) + + mesh.show(show_labels=True) + +def random_grid_check(): + import matplotlib.pyplot as plt + points = np.random.random((100, 2)) + mesh = voronoi_mesh(points[:, 0], points[:, 1], True) + mesh.show(actually_show=False) + plt.scatter(points[:, 0], points[:, 1]) + plt.show() + + +if __name__ == "__main__": + square_grid_check() + # random_grid_check() + diff --git a/sasdata/slicing/rebinning.py b/sasdata/slicing/rebinning.py new file mode 100644 index 0000000..f2c76de --- /dev/null +++ b/sasdata/slicing/rebinning.py @@ -0,0 +1,149 @@ +from abc import ABC, abstractmethod +from typing import Optional +from dataclasses import dataclass + +import numpy as np + +from sasdata.slicing.meshes.mesh import Mesh +from sasdata.slicing.meshes.voronoi_mesh import voronoi_mesh +from sasdata.slicing.meshes.meshmerge import meshmerge + +import time + +@dataclass +class CacheData: + """ Data cached for repeated calculations with the same coordinates """ + input_coordinates: np.ndarray # Input data + input_coordinates_mesh: Mesh # Mesh of the input data + merged_mesh_data: tuple[Mesh, np.ndarray, np.ndarray] # mesh information about the merging + + +class Rebinner(ABC): + + + def __init__(self): + """ Base class for rebinning methods""" + + self._bin_mesh_cache: Optional[Mesh] = None # cached version of the output bin mesh + + # Output dependent caching + self._input_cache: Optional[CacheData] = None + + + @abstractmethod + def _bin_coordinates(self) -> np.ndarray: + """ Coordinates for the output bins """ + + @abstractmethod + def _bin_mesh(self) -> Mesh: + """ Get the meshes used for binning """ + + @property + def allowable_orders(self) -> list[int]: + return [-1, 0, 1] + + @property + def bin_mesh(self) -> Mesh: + + if self._bin_mesh_cache is None: + bin_mesh = self._bin_mesh() + self._bin_mesh_cache = bin_mesh + + return self._bin_mesh_cache + + def _post_processing(self, coordinates, values) -> tuple[np.ndarray, np.ndarray]: + """ Perform post-processing on the mesh binned values """ + 
# Default is to do nothing, override if needed
+        return coordinates, values
+
+    def _calculate(self, input_coordinates: np.ndarray, input_data: np.ndarray, order: int) -> np.ndarray:
+        """ Main calculation """
+
+        if order == -1:
+            # Construct the input-output mapping just based on input points being in the output cells,
+            # equivalent to the original binning method
+
+            mesh = self.bin_mesh
+            bin_identities = mesh.locate_points(input_coordinates[:, 0], input_coordinates[:, 1])
+            output_data = np.zeros(mesh.n_cells, dtype=float)
+
+            for index, bin in enumerate(bin_identities):
+                if bin >= 0:
+                    output_data[bin] += input_data[index]
+
+            return output_data
+
+        else:
+            # Use a mapping based on meshes
+
+            # Either create the appropriate mesh data, or retrieve it from the cache.
+            # Why not use a hash? Hashing takes time and equality checks are pretty fast; we would
+            # need an equality check on a hit anyway to guard against the rare chance of collision,
+            # and hits are the most common case - we want this to work 100% of the time, not 99.9999%
+            if self._input_cache is not None and np.all(self._input_cache.input_coordinates == input_coordinates):
+
+                input_coordinate_mesh = self._input_cache.input_coordinates_mesh
+                merge_data = self._input_cache.merged_mesh_data
+
+            else:
+                # Calculate mesh data
+                input_coordinate_mesh = voronoi_mesh(input_coordinates[:, 0], input_coordinates[:, 1])
+                merge_data = meshmerge(self.bin_mesh, input_coordinate_mesh)
+
+                # Cache mesh data
+                self._input_cache = CacheData(
+                    input_coordinates=input_coordinates,
+                    input_coordinates_mesh=input_coordinate_mesh,
+                    merged_mesh_data=merge_data)
+
+            merged_mesh, merged_to_output, merged_to_input = merge_data
+
+            # Calculate values according to the order parameter
+            t0 = time.time()
+            if order == 0:
+                # Based on the overlap of cells only
+
+                input_areas = input_coordinate_mesh.areas
+                output = np.zeros(self.bin_mesh.n_cells, dtype=float)
+
+                for input_index, output_index, area in zip(merged_to_input, merged_to_output, merged_mesh.areas):
+                    if input_index == -1 or output_index == -1:
+                        # merged region does not correspond to anything of interest
+                        continue
+
+                    output[output_index] += input_data[input_index] * area / input_areas[input_index]
+
+                print("Main calc:", time.time() - t0)
+
+                return output
+
+            elif order == 1:
+                # Linear interpolation requires the following relationship with the data:
+                # as the input data is the total over the whole input cell, the linear
+                # interpolation requires continuity at the vertices, and a constraint on the
+                # integral.
+                #
+                # We can take each of the input points, and the associated values, and solve a system
+                # of linear equations that gives a total value.
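+                #
+                # One possible formulation (a sketch, not implemented here): take the
+                # piecewise-linear surface defined by values f_v at the mesh vertices;
+                # continuity at shared vertices is then automatic, and for each input
+                # cell i the integral constraint is linear in the f_v:
+                #
+                #     sum over triangles t in cell i of (area_t / 3) * (f_t1 + f_t2 + f_t3) = input_data[i]
+                #
+                # so the vertex values could be recovered with one sparse linear solve
+                # and then integrated over the output cells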
+ + raise NotImplementedError("1st order (linear) interpolation currently not implemented") + + else: + raise ValueError(f"Expected order to be in {self.allowable_orders}, got {order}") + + def sum(self, x: np.ndarray, y: np.ndarray, data: np.ndarray, order: int = 0) -> np.ndarray: + """ Return the summed data in the output bins """ + return self._calculate(np.array((x.reshape(-1), y.reshape(-1))).T, data.reshape(-1), order) + + def error_propagate(self, input_coordinates: np.ndarray, data: np.ndarray, errors) -> np.ndarray: + raise NotImplementedError("Error propagation not implemented yet") + + def resolution_propagate(self, input_coordinates: np.ndarray, data: np.ndarray, errors) -> np.ndarray: + raise NotImplementedError("Resolution propagation not implemented yet") + + def average(self, x: np.ndarray, y: np.ndarray, data: np.ndarray, order: int = 0) -> np.ndarray: + """ Return the averaged data in the output bins """ + return self._calculate(np.array((x, y)).T, data.reshape(-1), order) / self.bin_mesh.areas + diff --git a/sasdata/slicing/sample_polygons.py b/sasdata/slicing/sample_polygons.py new file mode 100644 index 0000000..e12fb1e --- /dev/null +++ b/sasdata/slicing/sample_polygons.py @@ -0,0 +1,31 @@ +import numpy as np + +def wedge(q0, q1, theta0, theta1, clockwise=False, n_points_per_degree=2): + + # Traverse a rectangle in curvilinear coordinates (q0, theta0), (q0, theta1), (q1, theta1), (q1, theta0) + if clockwise: + if theta1 > theta0: + theta0 += 2*np.pi + + else: + if theta0 > theta1: + theta1 += 2*np.pi + + subtended_angle = np.abs(theta1 - theta0) + n_points = int(subtended_angle*180*n_points_per_degree/np.pi)+1 + + angles = np.linspace(theta0, theta1, n_points) + + xs = np.concatenate((q0*np.cos(angles), q1*np.cos(angles[::-1]))) + ys = np.concatenate((q0*np.sin(angles), q1*np.sin(angles[::-1]))) + + return np.array((xs, ys)).T + + +if __name__ == "__main__": + import matplotlib.pyplot as plt + xy = wedge(0.3, 0.6, 2, 3) + + plt.plot(xy[:,0], xy[:,1]) + plt.show() + diff --git a/sasdata/slicing/slicer_demo.py b/sasdata/slicing/slicer_demo.py new file mode 100644 index 0000000..af3ee98 --- /dev/null +++ b/sasdata/slicing/slicer_demo.py @@ -0,0 +1,120 @@ +""" Dev docs: Demo to show the behaviour of the re-binning methods """ + +import numpy as np + +import matplotlib.pyplot as plt + +from sasdata.slicing.slicers.AnularSector import AnularSector +from sasdata.slicing.meshes.voronoi_mesh import voronoi_mesh + + + +if __name__ == "__main__": + q_range = 1.5 + demo1 = True + demo2 = True + + # Demo of sums, annular sector over some not very circular data + + if demo1: + + x = (2 * q_range) * (np.random.random(400) - 0.5) + y = (2 * q_range) * (np.random.random(400) - 0.5) + + display_mesh = voronoi_mesh(x, y) + + + def lobe_test_function(x, y): + return 1 + np.sin(x*np.pi/q_range)*np.sin(y*np.pi/q_range) + + + random_lobe_data = lobe_test_function(x, y) + + plt.figure("Input Dataset 1") + display_mesh.show_data(random_lobe_data, actually_show=False) + + data_order_0 = [] + data_order_neg1 = [] + + sizes = np.linspace(0.1, 1, 100) + + for index, size in enumerate(sizes): + q0 = 0.75 - 0.6*size + q1 = 0.75 + 0.6*size + phi0 = np.pi/2 - size + phi1 = np.pi/2 + size + + rebinner = AnularSector(q0, q1, phi0, phi1) + + data_order_neg1.append(rebinner.sum(x, y, random_lobe_data, order=-1)) + data_order_0.append(rebinner.sum(x, y, random_lobe_data, order=0)) + + if index % 10 == 0: + plt.figure("Regions 1") + rebinner.bin_mesh.show(actually_show=False) + + plt.title("Regions") + + 
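# Order -1 assigns each input point wholly to the bin that contains it, so the
+        # sum below changes in steps as points enter the growing wedge; order 0 weights
+        # each input cell by its fractional overlap with the wedge (see
+        # Rebinner._calculate), giving a smoother curve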
plt.figure("Sum of region, dataset 1") + + plt.plot(sizes, data_order_neg1) + plt.plot(sizes, data_order_0) + + plt.legend(["Order -1", "Order 0"]) + plt.title("Sum over region") + + + # Demo of averaging, annular sector over ring shaped data + + if demo2: + + x, y = np.meshgrid(np.linspace(-q_range, q_range, 41), np.linspace(-q_range, q_range, 41)) + x = x.reshape(-1) + y = y.reshape(-1) + + display_mesh = voronoi_mesh(x, y) + + + def ring_test_function(x, y): + r = np.sqrt(x**2 + y**2) + return np.log(np.sinc(r*1.5)**2) + + + grid_ring_data = ring_test_function(x, y) + + plt.figure("Input Dataset 2") + display_mesh.show_data(grid_ring_data, actually_show=False) + + data_order_0 = [] + data_order_neg1 = [] + + sizes = np.linspace(0.1, 1, 100) + + for index, size in enumerate(sizes): + q0 = 0.25 + q1 = 1.25 + + phi0 = np.pi/2 - size + phi1 = np.pi/2 + size + + rebinner = AnularSector(q0, q1, phi0, phi1) + + data_order_neg1.append(rebinner.average(x, y, grid_ring_data, order=-1)) + data_order_0.append(rebinner.average(x, y, grid_ring_data, order=0)) + + if index % 10 == 0: + plt.figure("Regions 2") + rebinner.bin_mesh.show(actually_show=False) + + plt.title("Regions") + + plt.figure("Average of region 2") + + plt.plot(sizes, data_order_neg1) + plt.plot(sizes, data_order_0) + + plt.legend(["Order -1", "Order 0"]) + plt.title("Sum over region") + + plt.show() + diff --git a/sasdata/slicing/slicers/AnularSector.py b/sasdata/slicing/slicers/AnularSector.py new file mode 100644 index 0000000..4ace344 --- /dev/null +++ b/sasdata/slicing/slicers/AnularSector.py @@ -0,0 +1,43 @@ +import numpy as np + +from sasdata.slicing.rebinning import Rebinner +from sasdata.slicing.meshes.mesh import Mesh + +class AnularSector(Rebinner): + """ A single annular sector (wedge sum)""" + def __init__(self, q0: float, q1: float, phi0: float, phi1: float, points_per_degree: int=2): + super().__init__() + + self.q0 = q0 + self.q1 = q1 + self.phi0 = phi0 + self.phi1 = phi1 + + self.points_per_degree = points_per_degree + + def _bin_mesh(self) -> Mesh: + + n_points = np.max([int(1 + 180*self.points_per_degree*(self.phi1 - self.phi0) / np.pi), 2]) + + angles = np.linspace(self.phi0, self.phi1, n_points) + + row1 = self.q0 * np.array([np.cos(angles), np.sin(angles)]) + row2 = self.q1 * np.array([np.cos(angles), np.sin(angles)])[:, ::-1] + + points = np.concatenate((row1, row2), axis=1).T + + cells = [[i for i in range(2*n_points)]] + + return Mesh(points=points, cells=cells) + + def _bin_coordinates(self) -> np.ndarray: + return np.array([], dtype=float) + + +def main(): + """ Just show a random example""" + AnularSector(1, 2, 1, 2).bin_mesh.show() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/sasdata/slicing/slicers/__init__.py b/sasdata/slicing/slicers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sasdata/slicing/transforms.py b/sasdata/slicing/transforms.py new file mode 100644 index 0000000..d04742d --- /dev/null +++ b/sasdata/slicing/transforms.py @@ -0,0 +1,58 @@ +import numpy as np +from scipy.spatial import Voronoi, Delaunay +import matplotlib.pyplot as plt +from matplotlib import cm + + +# Some test data + +qx_base_values = np.linspace(-10, 10, 21) +qy_base_values = np.linspace(-10, 10, 21) + +qx, qy = np.meshgrid(qx_base_values, qy_base_values) + +include = np.logical_not((np.abs(qx) < 2) & (np.abs(qy) < 2)) + +qx = qx[include] +qy = qy[include] + +r = np.sqrt(qx**2 + qy**2) + +data = np.log((1+np.cos(3*r))*np.exp(-r*r)) + +colormap = 
cm.get_cmap('winter', 256)

def get_data_mesh(x, y, data):

    input_data = np.array((x, y)).T
    voronoi = Voronoi(input_data)

    # plt.scatter(voronoi.vertices[:,0], voronoi.vertices[:,1])
    # plt.scatter(voronoi.points[:,0], voronoi.points[:,1])

    cmin = np.min(data)
    cmax = np.max(data)

    color_index_map = np.array(255 * (data - cmin) / (cmax - cmin), dtype=int)

    for point_index, points in enumerate(voronoi.points):

        region_index = voronoi.point_region[point_index]
        region = voronoi.regions[region_index]

        if len(region) > 0:

            if -1 in region:
                # Unbounded region, skip
                pass

            else:
                color = colormap(color_index_map[point_index])

                # Close the loop around the region before filling
                circly = region + [region[0]]
                plt.fill(voronoi.vertices[circly, 0], voronoi.vertices[circly, 1], color=color, edgecolor="white")

    plt.show()

get_data_mesh(qx.reshape(-1), qy.reshape(-1), data)
\ No newline at end of file
diff --git a/sasdata/transforms/operation.py b/sasdata/transforms/operation.py
deleted file mode 100644
index c06bb37..0000000
--- a/sasdata/transforms/operation.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import numpy as np
-from sasdata.quantities.quantity import Quantity
-
-class Operation:
-    """ Sketch of what model post-processing classes might look like """
-
-    children: list["Operation"]
-    named_children: dict[str, "Operation"]
-
-    @property
-    def name(self) -> str:
-        raise NotImplementedError("No name for transform")
-
-    def evaluate(self) -> Quantity[np.ndarray]:
-        pass
-
-    def __call__(self, *children, **named_children):
-        self.children = children
-        self.named_children = named_children
\ No newline at end of file
diff --git a/sasdata/transforms/post_process.py b/sasdata/transforms/post_process.py
new file mode 100644
index 0000000..e69de29
diff --git a/sasdata/transforms/rebinning.py b/sasdata/transforms/rebinning.py
new file mode 100644
index 0000000..7bdc662
--- /dev/null
+++ b/sasdata/transforms/rebinning.py
@@ -0,0 +1,204 @@
+""" Algorithms for interpolation and rebinning """
+from typing import TypeVar
+
+import numpy as np
+from numpy.typing import ArrayLike
+from scipy.interpolate import interp1d
+
+from sasdata.quantities.quantity import Quantity
+from scipy.sparse import coo_matrix
+
+from enum import Enum
+
+class InterpolationOptions(Enum):
+    NEAREST_NEIGHBOUR = 0
+    LINEAR = 1
+    CUBIC = 3
+
+class InterpolationError(Exception):
+    """ We probably want to raise exceptions when interpolation is not appropriate/well-defined,
+    as distinct from the numerical issues that will raise ValueErrors """
+
+
+def calculate_interpolation_matrix_1d(input_axis: Quantity[ArrayLike],
+                                      output_axis: Quantity[ArrayLike],
+                                      mask: ArrayLike | None = None,
+                                      order: InterpolationOptions = InterpolationOptions.LINEAR,
+                                      is_density=False):
+
+    """ Calculate the matrix that converts values recorded at points specified by input_axis to
+    values recorded at points specified by output_axis. When a mask is supplied, the mask mapped
+    through the matrix is returned alongside it. """
+
+    # We want the input values in terms of the output units, this will implicitly check compatibility
+    # TODO: incorporate mask
+
+    working_units = output_axis.units
+
+    input_x = input_axis.in_units_of(working_units)
+    output_x = output_axis.in_units_of(working_units)
+
+    # Get the array indices that will map the array to a sorted one
+    input_sort = np.argsort(input_x)
+    output_sort = np.argsort(output_x)
+
+    input_unsort = np.arange(len(input_x), dtype=int)[input_sort]
+    output_unsort = np.arange(len(output_x), dtype=int)[output_sort]
+
+    sorted_in = input_x[input_sort]
+    sorted_out = output_x[output_sort]
+
+    n_in = len(sorted_in)
+    n_out = len(sorted_out)
+
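+    # Convention note: the conversion matrices built below have shape (n_in, n_out),
+    # so values map by right-multiplication, y_new = y @ matrix, as exercised in
+    # sasdata/transforms/test_interpolation.py; when the output axis equals the
+    # input axis, both constructions below reduce to the identity matrix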
+    conversion_matrix = None  # output
+
+    match order:
+        case InterpolationOptions.NEAREST_NEIGHBOUR:
+
+            # COO sparse matrix definition data
+            i_entries = []
+            j_entries = []
+
+            crossing_points = 0.5*(sorted_out[1:] + sorted_out[:-1])
+
+            # Find the output value nearest to each of the input values
+            i = 0
+            for k, crossing_point in enumerate(crossing_points):
+                while i < n_in and sorted_in[i] < crossing_point:
+                    i_entries.append(i)
+                    j_entries.append(k)
+                    i += 1
+
+            # All the rest go in the last bin
+            while i < n_in:
+                i_entries.append(i)
+                j_entries.append(n_out-1)
+                i += 1
+
+            i_entries = input_unsort[np.array(i_entries, dtype=int)]
+            j_entries = output_unsort[np.array(j_entries, dtype=int)]
+            values = np.ones_like(i_entries, dtype=float)
+
+            conversion_matrix = coo_matrix((values, (i_entries, j_entries)), shape=(n_in, n_out))
+
+        case InterpolationOptions.LINEAR:
+
+            # Leverage existing linear interpolation methods to get the mapping:
+            # do a linear interpolation on indices -
+            # the floor should give the left bin,
+            # the ceil should give the right bin,
+            # the fractional part should give the relative weightings
+
+            input_indices = np.arange(n_in, dtype=int)
+            output_indices = np.arange(n_out, dtype=int)
+
+            fractional = np.interp(x=sorted_out, xp=sorted_in, fp=input_indices, left=0, right=n_in-1)
+
+            left_bins = np.floor(fractional).astype(int)
+            right_bins = np.ceil(fractional).astype(int)
+
+            right_weight = fractional % 1
+            left_weight = 1 - right_weight
+
+            # There *should* be no repeated entries for both i and j in the main part, but maybe at the ends.
+            # If the left bin is the same as the right bin, then we only want one entry, and the weight should be 1
+
+            same = left_bins == right_bins
+            not_same = ~same
+
+            same_bins = left_bins[same]  # could equally be right bins, they're the same
+
+            same_indices = output_indices[same]
+            not_same_indices = output_indices[not_same]
+
+            j_entries_sorted = np.concatenate((same_indices, not_same_indices, not_same_indices))
+            i_entries_sorted = np.concatenate((same_bins, left_bins[not_same], right_bins[not_same]))
+
+            i_entries = input_unsort[i_entries_sorted]
+            j_entries = output_unsort[j_entries_sorted]
+
+            # Weights don't need to be unsorted  # TODO: check this is right, it should become obvious if we use unsorted data
+            weights = np.concatenate((np.ones_like(same_bins, dtype=float), left_weight[not_same], right_weight[not_same]))
+
+            conversion_matrix = coo_matrix((weights, (i_entries, j_entries)), shape=(n_in, n_out))
+
+        case InterpolationOptions.CUBIC:
+            # Cubic interpolation, much harder to implement because we can't just cheat and use numpy
+            raise NotImplementedError("Cubic interpolation not implemented yet")
+
+        case _:
+            raise InterpolationError(f"Unsupported interpolation order: {order}")
+
+
+    if mask is None:
+        # No mask to propagate: return just the matrix - this is the convention
+        # relied on by the callers in test_interpolation.py and the manual tests
+        return conversion_matrix
+    else:
+        # Create a new mask, conservatively: anything touched by the previous mask is now masked
+        new_mask = (np.array(mask, dtype=float) @ conversion_matrix) != 0.0
+
+        return conversion_matrix, new_mask
+
+
+def calculate_interpolation_matrix_2d_axis_axis(input_1: Quantity[ArrayLike],
+                                                input_2: Quantity[ArrayLike],
+                                                output_1: Quantity[ArrayLike],
+                                                output_2: Quantity[ArrayLike],
+                                                mask,
+                                                order: InterpolationOptions = InterpolationOptions.LINEAR,
+                                                is_density: bool = False):
+
+    # This is just the same as the 1D matrix case, applied along each axis
+
+    match order:
+        case InterpolationOptions.NEAREST_NEIGHBOUR:
+            pass
+
+        case InterpolationOptions.LINEAR:
+            pass
+
+        case 
InterpolationOptions.CUBIC: + pass + + case _: + pass + + +def calculate_interpolation_matrix(input_axes: list[Quantity[ArrayLike]], + output_axes: list[Quantity[ArrayLike]], + data: ArrayLike | None = None, + mask: ArrayLike | None = None): + + # TODO: We probably should delete this, but lets keep it for now + + if len(input_axes) not in (1, 2): + raise InterpolationError("Interpolation is only supported for 1D and 2D data") + + if len(input_axes) == 1 and len(output_axes) == 1: + # Check for dimensionality + input_axis = input_axes[0] + output_axis = output_axes[0] + + if len(input_axis.value.shape) == 1: + if len(output_axis.value.shape) == 1: + calculate_interpolation_matrix_1d() + + if len(output_axes) != len(input_axes): + # Input or output axes might be 2D matrices + pass + + + +def rebin(data: Quantity[ArrayLike], + axes: list[Quantity[ArrayLike]], + new_axes: list[Quantity[ArrayLike]], + mask: ArrayLike | None = None, + interpolation_order: int = 1): + + """ This algorithm is only for operations that preserve dimensionality, + i.e. non-projective rebinning. + """ + + pass \ No newline at end of file diff --git a/sasdata/transforms/test_interpolation.py b/sasdata/transforms/test_interpolation.py new file mode 100644 index 0000000..688da65 --- /dev/null +++ b/sasdata/transforms/test_interpolation.py @@ -0,0 +1,91 @@ +import pytest +import numpy as np +from matplotlib import pyplot as plt +from numpy.typing import ArrayLike +from typing import Callable + +from sasdata.quantities.plotting import quantity_plot +from sasdata.quantities.quantity import NamedQuantity, Quantity +from sasdata.quantities import units + +from sasdata.transforms.rebinning import calculate_interpolation_matrix_1d, InterpolationOptions + +test_functions = [ + lambda x: x**2, + lambda x: 2*x, + lambda x: x**3 +] + + +@pytest.mark.parametrize("fun", test_functions) +def test_linear_interpolate_matrix_inside(fun: Callable[[Quantity[ArrayLike]], Quantity[ArrayLike]]): + original_points = NamedQuantity("x_base", np.linspace(-10,10, 31), units.meters) + test_points = NamedQuantity("x_test", np.linspace(-5, 5, 11), units.meters) + + + mapping = calculate_interpolation_matrix_1d(original_points, test_points, order=InterpolationOptions.LINEAR) + + y_original = fun(original_points) + y_test = y_original @ mapping + y_expected = fun(test_points) + + test_units = y_expected.units + + y_values_test = y_test.in_units_of(test_units) + y_values_expected = y_expected.in_units_of(test_units) + + # print(y_values_test) + # print(y_values_expected) + # + # quantity_plot(original_points, y_original) + # quantity_plot(test_points, y_test) + # quantity_plot(test_points, y_expected) + # plt.show() + + assert len(y_values_test) == len(y_values_expected) + + for t, e in zip(y_values_test, y_values_expected): + assert t == pytest.approx(e, abs=2) + + +@pytest.mark.parametrize("fun", test_functions) +def test_linear_interpolate_different_units(fun: Callable[[Quantity[ArrayLike]], Quantity[ArrayLike]]): + original_points = NamedQuantity("x_base", np.linspace(-10,10, 107), units.meters) + test_points = NamedQuantity("x_test", np.linspace(-5000, 5000, 11), units.millimeters) + + mapping = calculate_interpolation_matrix_1d(original_points, test_points, order=InterpolationOptions.LINEAR) + + y_original = fun(original_points) + y_test = y_original @ mapping + y_expected = fun(test_points) + + test_units = y_expected.units + + y_values_test = y_test.in_units_of(test_units) + y_values_expected = y_expected.in_units_of(test_units) + # + # 
print(y_values_test) + # print(y_test.in_si()) + # print(y_values_expected) + # + # plt.plot(original_points.in_si(), y_original.in_si()) + # plt.plot(test_points.in_si(), y_test.in_si(), "x") + # plt.plot(test_points.in_si(), y_expected.in_si(), "o") + # plt.show() + + assert len(y_values_test) == len(y_values_expected) + + for t, e in zip(y_values_test, y_values_expected): + assert t == pytest.approx(e, rel=5e-2) + +def test_linearity_linear(): + """ Test linear interpolation between two points""" + x_and_y = NamedQuantity("x_base", np.linspace(-10, 10, 2), units.meters) + new_x = NamedQuantity("x_test", np.linspace(-5000, 5000, 101), units.millimeters) + + mapping = calculate_interpolation_matrix_1d(x_and_y, new_x, order=InterpolationOptions.LINEAR) + + linear_points = x_and_y @ mapping + + for t, e in zip(new_x.in_si(), linear_points.in_si()): + assert t == pytest.approx(e, rel=1e-3) \ No newline at end of file diff --git a/test/slicers/__init__.py b/test/slicers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/slicers/meshes_for_testing.py b/test/slicers/meshes_for_testing.py new file mode 100644 index 0000000..fb346e7 --- /dev/null +++ b/test/slicers/meshes_for_testing.py @@ -0,0 +1,115 @@ +""" +Meshes used in testing along with some expected values +""" + +import numpy as np + +from sasdata.slicing.meshes.voronoi_mesh import voronoi_mesh +from sasdata.slicing.meshes.mesh import Mesh +from sasdata.slicing.meshes.meshmerge import meshmerge + +coords = np.arange(-4, 5) +grid_mesh = voronoi_mesh(*np.meshgrid(coords, coords)) + + +item_1 = np.array([ + [-3.5, -0.5], + [-0.5, 3.5], + [ 0.5, 3.5], + [ 3.5, -0.5], + [ 0.0, 1.5]], dtype=float) + +item_2 = np.array([ + [-1.0, -2.0], + [-2.0, -2.0], + [-2.0, -1.0], + [-1.0, -1.0]], dtype=float) + +mesh_points = np.concatenate((item_1, item_2), axis=0) +cells = [[0,1,2,3,4],[5,6,7,8]] + +shape_mesh = Mesh(mesh_points, cells) + +# Subset of the mappings that meshmerge should include +# This can be read off the plots generated below + + +expected_shape_mappings = [ + (100, -1), + (152, -1), + (141, -1), + (172, -1), + (170, -1), + (0, -1), + (1, -1), + (8, 0), + (9, 0), + (37, 0), + (83, 0), + (190, 1), + (186, 1), + (189, 1), + (193, 1) +] + +expected_grid_mappings = [ + (89, 0), + (90, 1), + (148, 16), + (175, 35), + (60, 47), + (44, 47), + (80, 60) +] + +# +# Mesh location tests +# + +location_test_mesh_points = np.array([ + [0, 0], # 0 + [0, 1], # 1 + [0, 2], # 2 + [1, 0], # 3 + [1, 1], # 4 + [1, 2], # 5 + [2, 0], # 6 + [2, 1], # 7 + [2, 2]], dtype=float) + +location_test_mesh_cells = [ + [0, 1, 4, 3], + [1, 2, 5, 4], + [3, 4, 7, 6], + [4, 5, 8, 7]] + +location_test_mesh = Mesh(location_test_mesh_points, location_test_mesh_cells) + +test_coords = 0.25 + 0.5*np.arange(4) +location_test_points_x, location_test_points_y = np.meshgrid(test_coords, test_coords) + +if __name__ == "__main__": + + import matplotlib.pyplot as plt + + combined_mesh, _, _ = meshmerge(grid_mesh, shape_mesh) + + plt.figure() + combined_mesh.show(actually_show=False, show_labels=True, color='k') + grid_mesh.show(actually_show=False, show_labels=True, color='r') + + plt.xlim([-5, 5]) + plt.ylim([-5, 5]) + + plt.figure() + combined_mesh.show(actually_show=False, show_labels=True, color='k') + shape_mesh.show(actually_show=False, show_labels=True, color='r') + + plt.xlim([-5, 5]) + plt.ylim([-5, 5]) + + plt.figure() + location_test_mesh.show(actually_show=False, show_labels=True) + plt.scatter(location_test_points_x, location_test_points_y) + + 
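# With test_coords at 0.25 + 0.5*np.arange(4), the scattered points land four to
+    # a cell in the labelled 2x2 location-test mesh - the layout that
+    # utest_point_assignment checks against locate_points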
+    plt.show()
diff --git a/test/slicers/utest_meshmerge.py b/test/slicers/utest_meshmerge.py
new file mode 100644
index 0000000..4e4ee83
--- /dev/null
+++ b/test/slicers/utest_meshmerge.py
@@ -0,0 +1,28 @@
+"""
+Tests for mesh merging operations.
+
+It's pretty hard to test component-wise, but we can do some tests of the general behaviour
+"""
+
+from sasdata.slicing.meshes.meshmerge import meshmerge
+from test.slicers.meshes_for_testing import (
+    grid_mesh, shape_mesh, expected_grid_mappings, expected_shape_mappings)
+
+
+def test_meshmerge_mappings():
+    """ Test that the output of meshmerge is correct
+
+    IMPORTANT IF TESTS FAIL!!!... The docs for scipy.spatial.Voronoi and Delaunay
+    say that the ordering of faces might depend on machine precision. Thus, these
+    tests might not be reliable... we'll see how they play out
+    """
+
+
+    combined_mesh, grid_mappings, shape_mappings = meshmerge(grid_mesh, shape_mesh)
+
+    for triangle_cell, grid_cell in expected_grid_mappings:
+        assert grid_mappings[triangle_cell] == grid_cell
+
+    for triangle_cell, shape_cell in expected_shape_mappings:
+        assert shape_mappings[triangle_cell] == shape_cell
+
diff --git a/test/slicers/utest_point_assignment.py b/test/slicers/utest_point_assignment.py
new file mode 100644
index 0000000..4ff53e7
--- /dev/null
+++ b/test/slicers/utest_point_assignment.py
@@ -0,0 +1,14 @@
+
+from test.slicers.meshes_for_testing import location_test_mesh, location_test_points_x, location_test_points_y
+
+def test_location_assignment():
+    # The location test mesh is a 2x2 grid of unit cells covering [0, 2] x [0, 2],
+    # with cell 0 at the bottom left, cell 1 above it, and cells 2 and 3 to their
+    # right, so the expected cell index follows directly from the point coordinates
+    cell_indices = location_test_mesh.locate_points(location_test_points_x, location_test_points_y)
+
+    for x, y, index in zip(location_test_points_x.reshape(-1),
+                           location_test_points_y.reshape(-1),
+                           cell_indices):
+        expected = (1 if y >= 1 else 0) + (2 if x >= 1 else 0)
+        assert index == expected
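The interpolation tests above only exercise the LINEAR branch of calculate_interpolation_matrix_1d; a natural companion is a nearest-neighbour check. A minimal sketch, assuming the bare-matrix return convention used when no mask is supplied, and that mapping an axis onto itself is lossless:

""" Sketch: nearest-neighbour counterpart to the linear interpolation tests """

import numpy as np

from sasdata.quantities.quantity import NamedQuantity
from sasdata.quantities import units
from sasdata.transforms.rebinning import calculate_interpolation_matrix_1d, InterpolationOptions


def test_nearest_neighbour_identity():
    """ Mapping an axis onto itself should reduce to the identity matrix """
    x = NamedQuantity("x", np.linspace(-5, 5, 21), units=units.meters)
    y = x ** 2

    # Each input point is nearest to its own output bin, so this matrix is the identity
    mapping = calculate_interpolation_matrix_1d(x, x, order=InterpolationOptions.NEAREST_NEIGHBOUR)

    y_new = y @ mapping

    assert np.all(np.abs(y_new.in_si() - y.in_si()) < 1e-10)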