From 8da51b9b6c747868958c298464de40596e7ee4a6 Mon Sep 17 00:00:00 2001 From: Joris Snellenburg Date: Fri, 23 Aug 2024 22:14:12 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20support=20for=20pfid=20megaco?= =?UTF-8?q?mplex=20(#1510)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `PFIDMegacomplex` is a megacomplex used for fitting perturbed free induction decay as described in Hamm 1995: https://doi.org/10.1016/0301-0104(95)00262-6 * 👌 Guard against zero standard errors for fixed parameters * 🧪👌 Add irf attribute to PFIDDatasetModel and update pfid unit tests - OneOscillationWithIrf - OneOscillationWithSequentialModel * ♻️ Refactor pfid_megacomplex to be always index dependent * 📚 Added change to changelog * 🩹 Modify unit test to be more susceptible to the ordering bug (fixed in PR #1512) * 🧪 Added test to catch AttributeError validating bad pfid model definition * 🩹 Fix AttributeError validating bad pfid model definition * 🧰 Remove pre-commit from dev dependencies Co-authored-by: Sebastian Weigand Co-authored-by: Jörn Weißenborn Co-authored-by: Ivo van Stokkum --- changelog.md | 1 + .../builtin/megacomplexes/pfid/__init__.py | 1 + .../megacomplexes/pfid/pfid_megacomplex.py | 273 +++++++++++++ .../pfid/test/test_pfid_model.py | 371 ++++++++++++++++++ glotaran/parameter/parameters.py | 4 + requirements_dev.txt | 1 - setup.cfg | 1 + 7 files changed, 651 insertions(+), 1 deletion(-) create mode 100644 glotaran/builtin/megacomplexes/pfid/__init__.py create mode 100644 glotaran/builtin/megacomplexes/pfid/pfid_megacomplex.py create mode 100644 glotaran/builtin/megacomplexes/pfid/test/test_pfid_model.py diff --git a/changelog.md b/changelog.md index ad8485392..9c5cf57cd 100644 --- a/changelog.md +++ b/changelog.md @@ -7,6 +7,7 @@ ### ✨ Features - ✨ Add official Python 3.12 support (#1437) +- ✨ Add support for pfid megacomplex (#1510) ### 🩹 Bug fixes diff --git a/glotaran/builtin/megacomplexes/pfid/__init__.py 
b/glotaran/builtin/megacomplexes/pfid/__init__.py new file mode 100644 index 000000000..cb0b0b2ef --- /dev/null +++ b/glotaran/builtin/megacomplexes/pfid/__init__.py @@ -0,0 +1 @@ +from glotaran.builtin.megacomplexes.pfid.pfid_megacomplex import PFIDMegacomplex diff --git a/glotaran/builtin/megacomplexes/pfid/pfid_megacomplex.py b/glotaran/builtin/megacomplexes/pfid/pfid_megacomplex.py new file mode 100644 index 000000000..9aa8ded6e --- /dev/null +++ b/glotaran/builtin/megacomplexes/pfid/pfid_megacomplex.py @@ -0,0 +1,273 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import xarray as xr +from scipy.special import erf + +from glotaran.builtin.megacomplexes.decay.irf import Irf +from glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian +from glotaran.model import DatasetModel +from glotaran.model import ItemIssue +from glotaran.model import Megacomplex +from glotaran.model import Model +from glotaran.model import ModelItemType +from glotaran.model import ParameterType +from glotaran.model import attribute +from glotaran.model import item +from glotaran.model import megacomplex + +if TYPE_CHECKING: + from glotaran.parameter import Parameters + from glotaran.typing.types import ArrayLike + + +class OscillationParameterIssue(ItemIssue): + def __init__(self, label: str, len_labels: int, len_frequencies: int, len_rates: int): + self.label = label + self.len_labels = len_labels + self.len_frequencies = len_frequencies + self.len_rates = len_rates + + def to_string(self) -> str: + return ( + f"The size of labels ({self.len_labels}), frequencies ({self.len_frequencies}), " + f"and rates ({self.len_rates}) does not match for pfid " + f"megacomplex '{self.label}'." 
+ ) + + +def validate_pfid_parameter( + labels: list[str], + pfid: PFIDMegacomplex, + model: Model, + parameters: Parameters | None, +) -> list[ItemIssue]: + issues = [] + + len_labels, len_frequencies, len_rates = ( + len(pfid.labels), + len(pfid.frequencies), + len(pfid.rates), + ) + + if len({len_labels, len_frequencies, len_rates}) > 1: + issues.append( + OscillationParameterIssue(pfid.label, len_labels, len_frequencies, len_rates) + ) + + return issues + + +@item +class PFIDDatasetModel(DatasetModel): + spectral_axis_inverted: bool = False + spectral_axis_scale: float = 1 + irf: ModelItemType[Irf] | None = None + + +@megacomplex(dataset_model_type=PFIDDatasetModel) +class PFIDMegacomplex(Megacomplex): + dimension: str = "time" + type: str = "pfid" + labels: list[str] = attribute(validator=validate_pfid_parameter) + frequencies: list[ParameterType] # omega_a + rates: list[ParameterType] # 1/T2 + + def calculate_matrix( + self, + dataset_model: DatasetModel, + global_axis: ArrayLike, + model_axis: ArrayLike, + **kwargs, + ): + clp_label = [f"{label}_cos" for label in self.labels] + [ + f"{label}_sin" for label in self.labels + ] + + frequencies = np.array(self.frequencies) + rates = np.array(self.rates) + + if dataset_model.spectral_axis_inverted: + frequencies = dataset_model.spectral_axis_scale / frequencies + elif dataset_model.spectral_axis_scale != 1: + frequencies = frequencies * dataset_model.spectral_axis_scale + + irf = dataset_model.irf + matrix_shape = (global_axis.size, model_axis.size, len(clp_label)) + matrix = np.zeros(matrix_shape, dtype=np.float64) + + if irf is None: + msg = "IRF is required for PFID megacomplex" + raise ValueError(msg) + if isinstance(irf, IrfMultiGaussian): + for i in range(global_axis.size): + calculate_pfid_matrix_gaussian_irf_on_index( + matrix[i], + frequencies, + rates, + irf, + i, + global_axis, + model_axis, + ) + else: + msg = "IRF should be instance of IrfMultiGaussian" + raise ValueError(msg) + return clp_label, 
matrix + + def finalize_data( + self, + dataset_model: DatasetModel, + dataset: xr.Dataset, + is_full_model: bool = False, + as_global: bool = False, + ): + if is_full_model: + return + + megacomplexes = ( + dataset_model.global_megacomplex if is_full_model else dataset_model.megacomplex + ) + unique = len([m for m in megacomplexes if isinstance(m, PFIDMegacomplex)]) < 2 + + prefix = "pfid" if unique else f"{self.label}_pfid" + + dataset.coords[f"{prefix}"] = self.labels + dataset.coords[f"{prefix}_frequency"] = (prefix, self.frequencies) + dataset.coords[f"{prefix}_rate"] = (prefix, self.rates) + + model_dimension = dataset.attrs["model_dimension"] + global_dimension = dataset.attrs["global_dimension"] + dim1 = dataset.coords[global_dimension].size + dim2 = len(self.labels) + pfid = np.zeros((dim1, dim2), dtype=np.float64) + phase = np.zeros((dim1, dim2), dtype=np.float64) + + for i, label in enumerate(self.labels): + sin = dataset.clp.sel(clp_label=f"{label}_sin") + cos = dataset.clp.sel(clp_label=f"{label}_cos") + pfid[:, i] = np.sqrt(sin * sin + cos * cos) + phase[:, i] = np.unwrap(np.arctan2(sin, cos)) + + dataset[f"{prefix}_associated_spectra"] = ( + (global_dimension, prefix), + pfid, + ) + + dataset[f"{prefix}_phase"] = ( + (global_dimension, prefix), + phase, + ) + + dataset[f"{prefix}_sin"] = ( + ( + global_dimension, + model_dimension, + prefix, + ), + dataset.matrix.sel(clp_label=[f"{label}_sin" for label in self.labels]).to_numpy(), + ) + + dataset[f"{prefix}_cos"] = ( + ( + global_dimension, + model_dimension, + prefix, + ), + dataset.matrix.sel(clp_label=[f"{label}_cos" for label in self.labels]).to_numpy(), + ) + + +def calculate_pfid_matrix_gaussian_irf_on_index( + matrix: ArrayLike, + frequencies: ArrayLike, + rates: ArrayLike, + irf: IrfMultiGaussian, + global_index: int | None, + global_axis: ArrayLike, + model_axis: ArrayLike, +): + centers, widths, scales, shift, _, _ = irf.parameter(global_index, global_axis) + for center, width, scale in 
zip(centers, widths, scales, strict=True): + matrix += calculate_pfid_matrix_gaussian_irf( + frequencies, + rates, + model_axis, + center, + width, + shift, + scale, + global_axis[global_index], + ) + matrix /= np.sum(scales) + + +def calculate_pfid_matrix_gaussian_irf( + frequencies: np.ndarray, + rates: np.ndarray, + model_axis: np.ndarray, + center: float, + width: float, + shift: float, + scale: float, + global_axis_value: float, +): + """Calculate the damped oscillation matrix taking into account a gaussian irf. + + Parameters + ---------- + frequencies : np.ndarray + an array of frequencies in THz, one per oscillation + rates : np.ndarray + an array of dephasing rates (negative), one per oscillation + model_axis : np.ndarray + the model axis (time) + center : float + the center of the gaussian IRF + width : float + the width (σ) parameter of the the IRF + shift : float + a shift parameter per item on the global axis + scale : float + the scale parameter to scale the matrix by + + Returns + ------- + np.ndarray + An array of the real and imaginary part of the oscillation matrix, + the shape being (len(model_axis), len(frequencies)). 
+ """ + shifted_axis = model_axis - center - shift + # For calculations using the negative rates we use the time axis + # from the beginning up to 5 σ from the irf center + # this is to guard again overflows + left_shifted_axis_indices = np.where(shifted_axis < 5 * width)[0] + left_shifted_axis = shifted_axis[left_shifted_axis_indices] + neg_idx = np.where(rates < 0)[0] + + # c multiply by 0.03 to convert wavenumber (cm-1) to frequency (THz) + # where 0.03 is the product of speed of light 3*10**10 cm/s and time-unit ps (10^-12) + # we postpone the conversion because the global axis is + # always expected to be in cm-1 for relevant experiments + frequency_diff = (global_axis_value - frequencies) * 0.03 * 2 * np.pi + d = width**2 + k = rates + 1j * frequency_diff + dk = k * d + sqwidth = np.sqrt(2) * width + + a = np.zeros((len(model_axis), len(rates)), dtype=np.complex128) + a[np.ix_(left_shifted_axis_indices, neg_idx)] = np.exp( + (-1 * left_shifted_axis[:, None] + 0.5 * dk[:]) * k[:] + ) + + b = np.zeros((len(model_axis), len(rates)), dtype=np.complex128) + # For negative rates we flip the sign of the `erf` by using `-sqwidth` in lieu of `sqwidth` + b[np.ix_(left_shifted_axis_indices, neg_idx)] = 1 + erf( + (left_shifted_axis[:, None] - dk[:]) / -sqwidth + ) + + osc = -(a * b) * scale + + return np.concatenate((osc.real, osc.imag), axis=1) diff --git a/glotaran/builtin/megacomplexes/pfid/test/test_pfid_model.py b/glotaran/builtin/megacomplexes/pfid/test/test_pfid_model.py new file mode 100644 index 000000000..447b149af --- /dev/null +++ b/glotaran/builtin/megacomplexes/pfid/test/test_pfid_model.py @@ -0,0 +1,371 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from glotaran.builtin.megacomplexes.decay import DecayMegacomplex +from glotaran.builtin.megacomplexes.pfid import PFIDMegacomplex +from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex +from glotaran.model import Model +from glotaran.optimization.optimize 
import optimize +from glotaran.parameter import Parameters +from glotaran.project import Scheme +from glotaran.simulation import simulate + + +class OneOscillationWithIrf: + pure_pfid_model = Model.create_class_from_megacomplexes([PFIDMegacomplex, SpectralMegacomplex]) + sim_model = pure_pfid_model( + **{ + "megacomplex": { + "pfid": { + "type": "pfid", + "labels": ["osc1"], + "frequencies": ["osc.freq"], + "rates": ["osc.rate"], + }, + "spectral": { + "type": "spectral", + "shape": { + "osc1_cos": "sh2", + "osc1_sin": "sh1", + }, + }, + }, + "shape": { + "sh1": { + "type": "gaussian", + "amplitude": "shapes.amps.1", + "location": "shapes.locs.1", + "width": "shapes.width.1", + }, + "sh2": { + "type": "gaussian", + "amplitude": "shapes.amps.2", + "location": "shapes.locs.2", + "width": "shapes.width.2", + }, + }, + "irf": { + "irf1": { + "type": "gaussian", + "center": "irf.center", + "width": "irf.width", + }, + }, + "dataset": { + "dataset1": { + "megacomplex": ["pfid"], + "global_megacomplex": ["spectral"], + "irf": "irf1", + } + }, + } + ) + + model = pure_pfid_model( + **{ + "megacomplex": { + "m1": { + "type": "pfid", + "labels": ["osc1"], + "frequencies": ["osc.freq"], + "rates": ["osc.rate"], + }, + }, + "irf": { + "irf1": { + "type": "gaussian", + "center": "irf.center", + "width": "irf.width", + }, + }, + "dataset": { + "dataset1": { + "megacomplex": ["m1"], + "irf": "irf1", + } + }, + } + ) + + wanted_parameter = Parameters.from_dict( + { + "osc": [ + ["freq", 1500], + ["rate", -2], + ], + "shapes": {"amps": [2, -2], "locs": [1490, 1510], "width": [4, 4]}, + "irf": [["center", 0.01], ["width", 0.05]], + } + ) + + parameter = Parameters.from_dict( + { + "osc": [ + ["freq", 1501], + ["rate", -2.1], + ], + "irf": [["center", 0.01], ["width", 0.05]], + } + ) + + time = np.arange(-4, 1, 0.01) + spectral = np.arange(1480, 1520, 1) + axis = {"time": time, "spectral": spectral} + + wanted_clp = ["pfid1"] + wanted_shape = (40, 1) + + +class 
OneOscillationWithSequentialModel: + decay_pfid_model = Model.create_class_from_megacomplexes( + [PFIDMegacomplex, DecayMegacomplex, SpectralMegacomplex] + ) + sim_model = decay_pfid_model( + **{ + "initial_concentration": { + "j1": {"compartments": ["s1", "s2"], "parameters": ["j.1", "j.0"]}, + }, + "k_matrix": { + "k1": { + "matrix": { + ("s2", "s1"): "kinetic.1", + ("s2", "s2"): "kinetic.2", + } + } + }, + "megacomplex": { + "m1": {"type": "decay", "k_matrix": ["k1"]}, + "m2": { + "type": "pfid", + "labels": ["osc1"], + "frequencies": ["osc.freq"], + "rates": ["osc.rate"], + }, + "m3": { + "type": "spectral", + "shape": { + "s1": "sh3", + "s2": "sh4", + "osc1_cos": "sh2", + "osc1_sin": "sh1", + }, + }, + }, + "shape": { + "sh1": { + "type": "gaussian", + "amplitude": "shapes.amps.1", + "location": "shapes.locs.1", + "width": "shapes.width.1", + }, + "sh2": { + "type": "gaussian", + "amplitude": "shapes.amps.2", + "location": "shapes.locs.2", + "width": "shapes.width.2", + }, + "sh3": { + "type": "gaussian", + "amplitude": "shapes.amps.3", + "location": "shapes.locs.3", + "width": "shapes.width.3", + }, + "sh4": { + "type": "gaussian", + "amplitude": "shapes.amps.4", + "location": "shapes.locs.4", + "width": "shapes.width.4", + }, + }, + "irf": { + "irf1": { + "type": "gaussian", + "center": "irf.center", + "width": "irf.width", + }, + }, + "dataset": { + "dataset1": { + "initial_concentration": "j1", + "irf": "irf1", + "megacomplex": ["m2", "m1"], + "global_megacomplex": ["m3"], + } + }, + } + ) + + model = decay_pfid_model( + **{ + "initial_concentration": { + "j1": {"compartments": ["s1", "s2"], "parameters": ["j.1", "j.0"]}, + }, + "k_matrix": { + "k1": { + "matrix": { + ("s2", "s1"): "kinetic.1", + ("s2", "s2"): "kinetic.2", + } + } + }, + "megacomplex": { + "m1": {"type": "decay", "k_matrix": ["k1"]}, + "m2": { + "type": "pfid", + "labels": ["osc1"], + "frequencies": ["osc.freq"], + "rates": ["osc.rate"], + }, + }, + "irf": { + "irf1": { + "type": 
"gaussian", + "center": "irf.center", + "width": "irf.width", + }, + }, + "dataset": { + "dataset1": { + "initial_concentration": "j1", + "irf": "irf1", + "megacomplex": ["m1", "m2"], + } + }, + } + ) + + wanted_parameter = Parameters.from_dict( + { + "j": [ + ["1", 1, {"vary": False, "non-negative": False}], + ["0", 0, {"vary": False, "non-negative": False}], + ], + "kinetic": [ + ["1", 0.05], + ["2", 0.001], + ], + "osc": [ + ["freq", 1500], + ["rate", -2], + ], + "shapes": { + "amps": [2, -2, 8, 9], + "locs": [1490, 1510, 1495, 1505], + "width": [4, 4, 3, 5], + }, + "irf": [["center", 0.01], ["width", 0.05]], + } + ) + + parameter = Parameters.from_dict( + { + "j": [ + ["1", 1, {"vary": False, "non-negative": False}], + ["0", 0, {"vary": False, "non-negative": False}], + ], + "kinetic": [ + ["1", 0.055], + ["2", 0.0015], + ], + "osc": [ + ["freq", 1501], + ["rate", -2.1], + ], + "irf": [["center", 0.01], ["width", 0.05]], + } + ) + + time = np.arange(-5, 80, 0.01) + spectral = np.arange(1480, 1520, 1) + axis = {"time": time, "spectral": spectral} + + wanted_clp = ["osc1_cos", "osc1_sin", "s1", "s2"] + wanted_shape = (600, 4) + + +@pytest.mark.parametrize( + "suite", + [ + OneOscillationWithIrf, + OneOscillationWithSequentialModel, + ], +) +def test_pfid_model(suite): + class_name = suite.__name__ + print(suite.sim_model.validate()) + assert suite.sim_model.valid() + + print(suite.model.validate()) + assert suite.model.valid() + + print(suite.sim_model.validate(suite.wanted_parameter)) + assert suite.sim_model.valid(suite.wanted_parameter) + + print(suite.model.validate(suite.parameter)) + assert suite.model.valid(suite.parameter) + + dataset = simulate( + suite.sim_model, + "dataset1", + suite.wanted_parameter, + suite.axis, + noise=True, + noise_std_dev=1e-8, + noise_seed=123, + ) + print(dataset) + + assert dataset.data.shape == (suite.axis["time"].size, suite.axis["spectral"].size) + + print(suite.parameter) + print(suite.wanted_parameter) + + data = 
{"dataset1": dataset} + scheme = Scheme( + model=suite.model, + parameters=suite.parameter, + data=data, + maximum_number_function_evaluations=5, + ) + result = optimize(scheme, raise_exception=True) + print(result.optimized_parameters) + + for param in result.optimized_parameters.all(): + assert np.allclose(param.value, suite.wanted_parameter.get(param.label).value, rtol=1e-1) + + resultdata = result.data["dataset1"] + assert np.array_equal(dataset["time"], resultdata["time"]) + assert np.array_equal(dataset["spectral"], resultdata["spectral"]) + assert dataset.data.shape == resultdata.fitted_data.shape + assert np.allclose(dataset.data, resultdata.fitted_data, atol=1e-5) + + assert "pfid_associated_spectra" in resultdata + assert "pfid_phase" in resultdata + + # Ensure that s1, s2 are not mixed up with osc1_cos and osc1_sin by checking amplitudes + if "OneOscillationWithSequentialModel" in class_name: + assert resultdata.species_associated_spectra.sel(species="s1").max() > 7 + assert resultdata.species_associated_spectra.sel(species="s2").max() > 8 + + +def test_pfid_model_validate(): + """An ``OscillationParameterIssue`` should be raised if there is a list length mismatch. + + List values are: ``labels``, ``frequencies``, ``rates``. + """ + pure_pfid_model = Model.create_class_from_megacomplexes([PFIDMegacomplex, SpectralMegacomplex]) + model_data = OneOscillationWithIrf.sim_model.as_dict() + model_data["megacomplex"]["pfid"]["labels"].append("extra-label") + model = pure_pfid_model(**model_data) + validation_msg = model.validate() + assert ( + validation_msg == "Your model has 1 problem:\n\n" + " * The size of labels (2), frequencies (1), and rates (1) does not match for pfid " + "megacomplex 'pfid'." 
+ ) + + +if __name__ == "__main__": + pytest.main([__file__, "-s", "-v"]) diff --git a/glotaran/parameter/parameters.py b/glotaran/parameter/parameters.py index 8a4571bd0..b5d6a5fa1 100644 --- a/glotaran/parameter/parameters.py +++ b/glotaran/parameter/parameters.py @@ -517,7 +517,11 @@ def param_dict_to_markdown( ] if label is not None: return_string += f"{node_indentation}* __{label}__:\n" + if isinstance(parameters, list): + for parameter in parameters: + if abs(parameter.standard_error) < 1e-15: + parameter.standard_error = np.nan parameter_rows = [ [ parameter.label_short, diff --git a/requirements_dev.txt b/requirements_dev.txt index 9004450ce..356565739 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -38,5 +38,4 @@ types-dataclasses>=0.1.7 # code quality assurance flake8>=3.8.3 -pre-commit>=2.9.0 setuptools>=65.5.1 # not directly required, pinned by Snyk to avoid a vulnerability diff --git a/setup.cfg b/setup.cfg index 60ab29267..655121dc1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -62,6 +62,7 @@ glotaran.plugins.megacomplexes = clp_guide = glotaran.builtin.megacomplexes.clp_guide coherent_artifact = glotaran.builtin.megacomplexes.coherent_artifact damped_oscillation = glotaran.builtin.megacomplexes.damped_oscillation + pfid = glotaran.builtin.megacomplexes.pfid decay = glotaran.builtin.megacomplexes.decay spectral = glotaran.builtin.megacomplexes.spectral glotaran.plugins.project_io =